reasonkit/mcp/
registry.rs

1//! MCP Server Registry
2//!
3//! Dynamic server discovery, registration, and health monitoring.
4
5use super::server::{McpServerTrait, ServerStatus};
6#[cfg(feature = "daemon")]
7use super::tools::ToolResult;
8use super::tools::{GetPromptResult, Prompt, Tool};
9use super::types::*;
10use crate::error::{Error, Result};
11use chrono::{DateTime, Utc};
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::sync::Arc;
15use tokio::sync::RwLock;
16use uuid::Uuid;
17
/// Health check status
///
/// Serialized in `snake_case`, so `Healthy` is written as `"healthy"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum HealthStatus {
    /// Server is healthy (the registry maps `ServerStatus::Running` to this)
    Healthy,
    /// Server is degraded (mapped from `ServerStatus::Degraded`)
    Degraded,
    /// Server is unhealthy (mapped from `ServerStatus::Unhealthy` and `ServerStatus::Failed`)
    Unhealthy,
    /// Health check in progress
    // NOTE(review): no code visible in this file assigns this variant; `statistics` counts it
    // together with `Unknown`.
    Checking,
    /// Unknown status (used for every `ServerStatus` the registry does not map explicitly)
    Unknown,
}
33
/// Health check result
///
/// One record is produced per probe and stored as the registration's `last_health_check`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HealthCheck {
    /// Server ID (the registry key of the probed server)
    pub server_id: Uuid,
    /// Server name, copied from the registration at check time
    pub server_name: String,
    /// Health status derived from the server's reported `ServerStatus`
    pub status: HealthStatus,
    /// Last check timestamp (taken with `Utc::now()` when the record is built)
    pub checked_at: DateTime<Utc>,
    /// Response time in milliseconds, measured around the `health_check()` call
    pub response_time_ms: Option<f64>,
    /// Error message if unhealthy; the registry sets this when `health_check()` returns `false`
    pub error: Option<String>,
}
50
/// Server registration
///
/// Metadata the registry keeps alongside each registered server handle.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerRegistration {
    /// Server ID (a fresh v4 UUID generated by `register_server`)
    pub id: Uuid,
    /// Server name (copied from `info.name` at registration time)
    pub name: String,
    /// Server info as reported by the server when it was registered
    pub info: ServerInfo,
    /// Server capabilities as reported by the server when it was registered
    pub capabilities: ServerCapabilities,
    /// Registration timestamp
    pub registered_at: DateTime<Utc>,
    /// Last health check; `None` until the first check has been recorded
    pub last_health_check: Option<HealthCheck>,
    /// Tags for categorization (matched exactly by `find_servers_by_tag`)
    pub tags: Vec<String>,
}
69
/// MCP server registry
///
/// Holds the live server handles and their metadata in two maps keyed by the same
/// registration `Uuid`. Both maps sit behind `Arc<RwLock<..>>` so the background health
/// monitor task can share them with the registry.
pub struct McpRegistry {
    /// Registered servers
    servers: Arc<RwLock<HashMap<Uuid, Arc<dyn McpServerTrait>>>>,
    /// Server registrations (metadata)
    registrations: Arc<RwLock<HashMap<Uuid, ServerRegistration>>>,
    /// Health check interval in seconds
    health_check_interval_secs: u64,
    /// Background health check handle (`None` while no monitor task is tracked)
    health_check_handle: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
}
81
82impl McpRegistry {
83    /// Create a new registry
84    pub fn new() -> Self {
85        Self {
86            servers: Arc::new(RwLock::new(HashMap::new())),
87            registrations: Arc::new(RwLock::new(HashMap::new())),
88            health_check_interval_secs: crate::mcp::DEFAULT_HEALTH_CHECK_INTERVAL_SECS,
89            health_check_handle: Arc::new(RwLock::new(None)),
90        }
91    }
92
93    /// Create a new registry with custom health check interval
94    pub fn with_health_check_interval(interval_secs: u64) -> Self {
95        Self {
96            servers: Arc::new(RwLock::new(HashMap::new())),
97            registrations: Arc::new(RwLock::new(HashMap::new())),
98            health_check_interval_secs: interval_secs,
99            health_check_handle: Arc::new(RwLock::new(None)),
100        }
101    }
102
103    /// Register a server
104    pub async fn register_server(
105        &self,
106        server: Arc<dyn McpServerTrait>,
107        tags: Vec<String>,
108    ) -> Result<Uuid> {
109        let info = server.server_info().await;
110        let capabilities = server.capabilities().await;
111
112        let registration = ServerRegistration {
113            id: Uuid::new_v4(),
114            name: info.name.clone(),
115            info,
116            capabilities,
117            registered_at: Utc::now(),
118            last_health_check: None,
119            tags,
120        };
121
122        let id = registration.id;
123
124        let mut servers = self.servers.write().await;
125        let mut regs = self.registrations.write().await;
126
127        servers.insert(id, server);
128        regs.insert(id, registration);
129
130        Ok(id)
131    }
132
133    /// Unregister a server
134    pub async fn unregister_server(&self, id: Uuid) -> Result<()> {
135        let mut servers = self.servers.write().await;
136        let mut regs = self.registrations.write().await;
137
138        if let Some(server) = servers.remove(&id) {
139            // Attempt graceful shutdown via Arc
140            // Note: We can't unwrap Arc<dyn Trait> as trait objects aren't Sized
141            // Just drop the Arc - if it's the last reference, the server will be dropped
142            drop(server);
143        }
144
145        regs.remove(&id);
146
147        Ok(())
148    }
149
150    /// Get a server by ID
151    pub async fn get_server(&self, id: Uuid) -> Option<Arc<dyn McpServerTrait>> {
152        let servers = self.servers.read().await;
153        servers.get(&id).cloned()
154    }
155
156    /// List all registered servers
157    pub async fn list_servers(&self) -> Vec<ServerRegistration> {
158        let regs = self.registrations.read().await;
159        regs.values().cloned().collect()
160    }
161
162    /// Find servers by tag
163    pub async fn find_servers_by_tag(&self, tag: &str) -> Vec<ServerRegistration> {
164        let regs = self.registrations.read().await;
165        regs.values()
166            .filter(|r| r.tags.iter().any(|t| t == tag))
167            .cloned()
168            .collect()
169    }
170
171    /// List all tools from all servers
172    pub async fn list_all_tools(&self) -> Result<Vec<Tool>> {
173        let servers = self.servers.read().await;
174        let mut all_tools = Vec::new();
175
176        for (id, server) in servers.iter() {
177            let regs = self.registrations.read().await;
178            let server_name = regs.get(id).map(|r| r.name.clone()).unwrap_or_default();
179
180            // Query tools/list from server
181            let request = McpRequest::new(
182                RequestId::String(Uuid::new_v4().to_string()),
183                "tools/list",
184                None,
185            );
186
187            match server.send_request(request).await {
188                Ok(response) => {
189                    if let Some(result) = response.result {
190                        if let Ok(tools_response) =
191                            serde_json::from_value::<ToolsListResponse>(result)
192                        {
193                            for mut tool in tools_response.tools {
194                                tool.server_id = Some(*id);
195                                tool.server_name = Some(server_name.clone());
196                                all_tools.push(tool);
197                            }
198                        }
199                    }
200                }
201                Err(_) => {
202                    // Server didn't respond - skip
203                    continue;
204                }
205            }
206        }
207
208        Ok(all_tools)
209    }
210
211    /// List all prompts from all servers
212    pub async fn list_all_prompts(&self) -> Result<Vec<Prompt>> {
213        let servers = self.servers.read().await;
214        let mut all_prompts = Vec::new();
215
216        for (_, server) in servers.iter() {
217            // Query prompts/list from server
218            let request = McpRequest::new(
219                RequestId::String(Uuid::new_v4().to_string()),
220                "prompts/list",
221                None,
222            );
223
224            match server.send_request(request).await {
225                Ok(response) => {
226                    if let Some(result) = response.result {
227                        if let Ok(prompts_response) =
228                            serde_json::from_value::<PromptsListResponse>(result)
229                        {
230                            all_prompts.extend(prompts_response.prompts);
231                        }
232                    }
233                }
234                Err(_) => {
235                    continue;
236                }
237            }
238        }
239
240        Ok(all_prompts)
241    }
242
243    /// Get a prompt from a specific server (or find by name)
244    pub async fn get_prompt(
245        &self,
246        prompt_name: &str,
247        arguments: HashMap<String, String>,
248        server_id: Option<Uuid>,
249    ) -> Result<GetPromptResult> {
250        let servers = self.servers.read().await;
251
252        // If server_id is provided, query that server directly
253        if let Some(id) = server_id {
254            if let Some(server) = servers.get(&id) {
255                return self
256                    .get_prompt_from_server(server.clone(), prompt_name, arguments)
257                    .await;
258            } else {
259                return Err(Error::NotFound {
260                    resource: format!("Server {}", id),
261                });
262            }
263        }
264
265        // Otherwise, broadcast to find the prompt
266        // Note: This is inefficient; in a real registry, we'd cache prompt->server mapping
267        for (_, server) in servers.iter() {
268            if let Ok(result) = self
269                .get_prompt_from_server(server.clone(), prompt_name, arguments.clone())
270                .await
271            {
272                return Ok(result);
273            }
274        }
275
276        Err(Error::NotFound {
277            resource: format!("Prompt {}", prompt_name),
278        })
279    }
280
281    async fn get_prompt_from_server(
282        &self,
283        server: Arc<dyn McpServerTrait>,
284        prompt_name: &str,
285        arguments: HashMap<String, String>,
286    ) -> Result<GetPromptResult> {
287        let params = serde_json::json!({
288            "name": prompt_name,
289            "arguments": arguments
290        });
291
292        let request = McpRequest::new(
293            RequestId::String(Uuid::new_v4().to_string()),
294            "prompts/get",
295            Some(params),
296        );
297
298        let response = server.send_request(request).await?;
299
300        if let Some(error) = response.error {
301            return Err(Error::Mcp(error.message));
302        }
303
304        if let Some(result) = response.result {
305            let prompt_result: GetPromptResult =
306                serde_json::from_value(result).map_err(Error::Json)?;
307            Ok(prompt_result)
308        } else {
309            Err(Error::Mcp("Empty response from server".to_string()))
310        }
311    }
312
313    /// Perform health check on a specific server
314    pub async fn check_server_health(&self, id: Uuid) -> Result<HealthCheck> {
315        let server = self.get_server(id).await.ok_or_else(|| Error::NotFound {
316            resource: format!("Server {}", id),
317        })?;
318
319        let regs = self.registrations.read().await;
320        let server_name = regs.get(&id).map(|r| r.name.clone()).unwrap_or_default();
321        drop(regs);
322
323        let start = std::time::Instant::now();
324        let is_healthy = server.health_check().await?;
325        let response_time_ms = start.elapsed().as_millis() as f64;
326
327        let status = match server.status().await {
328            ServerStatus::Running => HealthStatus::Healthy,
329            ServerStatus::Degraded => HealthStatus::Degraded,
330            ServerStatus::Unhealthy | ServerStatus::Failed => HealthStatus::Unhealthy,
331            _ => HealthStatus::Unknown,
332        };
333
334        let health_check = HealthCheck {
335            server_id: id,
336            server_name,
337            status,
338            checked_at: Utc::now(),
339            response_time_ms: Some(response_time_ms),
340            error: if !is_healthy {
341                Some("Health check failed".to_string())
342            } else {
343                None
344            },
345        };
346
347        // Update registration
348        let mut regs = self.registrations.write().await;
349        if let Some(reg) = regs.get_mut(&id) {
350            reg.last_health_check = Some(health_check.clone());
351        }
352
353        Ok(health_check)
354    }
355
356    /// Perform health checks on all servers
357    pub async fn check_all_health(&self) -> Vec<HealthCheck> {
358        let servers = self.servers.read().await;
359        let server_ids: Vec<Uuid> = servers.keys().copied().collect();
360        drop(servers);
361
362        let mut checks = Vec::new();
363        for id in server_ids {
364            if let Ok(check) = self.check_server_health(id).await {
365                checks.push(check);
366            }
367        }
368
369        checks
370    }
371
372    /// Start background health checking
373    pub async fn start_health_monitoring(&self) {
374        let servers = self.servers.clone();
375        let registrations = self.registrations.clone();
376        let interval_secs = self.health_check_interval_secs;
377
378        let handle = tokio::spawn(async move {
379            let mut interval = tokio::time::interval(std::time::Duration::from_secs(interval_secs));
380
381            loop {
382                interval.tick().await;
383
384                let servers_guard = servers.read().await;
385                let server_ids: Vec<Uuid> = servers_guard.keys().copied().collect();
386                drop(servers_guard);
387
388                for id in server_ids {
389                    let servers_guard = servers.read().await;
390                    if let Some(server) = servers_guard.get(&id).cloned() {
391                        drop(servers_guard);
392
393                        let start = std::time::Instant::now();
394                        let is_healthy = server.health_check().await.unwrap_or(false);
395                        let response_time_ms = start.elapsed().as_millis() as f64;
396
397                        let status = match server.status().await {
398                            ServerStatus::Running => HealthStatus::Healthy,
399                            ServerStatus::Degraded => HealthStatus::Degraded,
400                            ServerStatus::Unhealthy | ServerStatus::Failed => {
401                                HealthStatus::Unhealthy
402                            }
403                            _ => HealthStatus::Unknown,
404                        };
405
406                        let mut regs = registrations.write().await;
407                        if let Some(reg) = regs.get_mut(&id) {
408                            let health_check = HealthCheck {
409                                server_id: id,
410                                server_name: reg.name.clone(),
411                                status,
412                                checked_at: Utc::now(),
413                                response_time_ms: Some(response_time_ms),
414                                error: if !is_healthy {
415                                    Some("Health check failed".to_string())
416                                } else {
417                                    None
418                                },
419                            };
420                            reg.last_health_check = Some(health_check);
421                        }
422                    }
423                }
424            }
425        });
426
427        let mut handle_lock = self.health_check_handle.write().await;
428        *handle_lock = Some(handle);
429    }
430
431    /// Stop background health monitoring
432    pub async fn stop_health_monitoring(&self) {
433        let mut handle_lock = self.health_check_handle.write().await;
434        if let Some(handle) = handle_lock.take() {
435            handle.abort();
436        }
437    }
438
439    /// Get registry statistics
440    pub async fn statistics(&self) -> RegistryStatistics {
441        let regs = self.registrations.read().await;
442
443        let mut healthy = 0;
444        let mut degraded = 0;
445        let mut unhealthy = 0;
446        let mut unknown = 0;
447
448        for reg in regs.values() {
449            if let Some(check) = &reg.last_health_check {
450                match check.status {
451                    HealthStatus::Healthy => healthy += 1,
452                    HealthStatus::Degraded => degraded += 1,
453                    HealthStatus::Unhealthy => unhealthy += 1,
454                    _ => unknown += 1,
455                }
456            } else {
457                unknown += 1;
458            }
459        }
460
461        RegistryStatistics {
462            total_servers: regs.len(),
463            healthy_servers: healthy,
464            degraded_servers: degraded,
465            unhealthy_servers: unhealthy,
466            unknown_servers: unknown,
467        }
468    }
469
470    // ─── Daemon-required methods ───────────────────────────────────────
471
472    /// Ping a server to check connectivity (lightweight health check)
473    ///
474    /// Returns Ok(true) if server responds, Ok(false) if no response,
475    /// or Err if server not found.
476    #[cfg(feature = "daemon")]
477    pub async fn ping_server(&self, id: &Uuid) -> Result<bool> {
478        let server = self.get_server(*id).await.ok_or_else(|| Error::NotFound {
479            resource: format!("Server {}", id),
480        })?;
481
482        // Use health_check as ping
483        server.health_check().await
484    }
485
486    /// Attempt to reconnect to a disconnected server
487    ///
488    /// Note: For now, this just performs a health check. Full reconnection
489    /// logic would require re-initializing the server process.
490    #[cfg(feature = "daemon")]
491    pub async fn reconnect_server(&self, id: &Uuid) -> Result<()> {
492        let server = self.get_server(*id).await.ok_or_else(|| Error::NotFound {
493            resource: format!("Server {}", id),
494        })?;
495
496        // Health check is our best option without full server restart support
497        let healthy = server.health_check().await?;
498        if healthy {
499            Ok(())
500        } else {
501            Err(Error::network(
502                "Server reconnection failed - health check returned false",
503            ))
504        }
505    }
506
507    /// Call a tool by name across registered servers
508    ///
509    /// Searches all registered servers for the tool and calls the first match.
510    #[cfg(feature = "daemon")]
511    pub async fn call_tool_by_name(
512        &self,
513        tool_name: &str,
514        args: serde_json::Value,
515    ) -> Result<ToolResult> {
516        use std::collections::HashMap;
517
518        let servers = self.servers.read().await;
519
520        // Convert Value to HashMap<String, Value>
521        let args_map: HashMap<String, serde_json::Value> = match args {
522            serde_json::Value::Object(obj) => obj.into_iter().collect(),
523            _ => HashMap::new(),
524        };
525
526        for (_id, server) in servers.iter() {
527            // Get tools from this server
528            let tools = server.list_tools().await;
529            if tools.iter().any(|t| t.name == tool_name) {
530                // Found the tool, call it
531                return server.call_tool(tool_name, args_map).await;
532            }
533        }
534
535        Err(Error::NotFound {
536            resource: format!("Tool {}", tool_name),
537        })
538    }
539
540    /// Disconnect from a server (graceful shutdown)
541    #[cfg(feature = "daemon")]
542    pub async fn disconnect_server(&self, id: &Uuid) -> Result<()> {
543        // Get mutable access to unregister the server
544        self.unregister_server(*id).await
545    }
546}
547
548impl Default for McpRegistry {
549    fn default() -> Self {
550        Self::new()
551    }
552}
553
/// Registry statistics
///
/// Snapshot produced by `McpRegistry::statistics`. The four per-status counters partition
/// the registrations, so they add up to `total_servers`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegistryStatistics {
    /// Total registered servers
    pub total_servers: usize,
    /// Healthy servers (last recorded status is `Healthy`)
    pub healthy_servers: usize,
    /// Degraded servers (last recorded status is `Degraded`)
    pub degraded_servers: usize,
    /// Unhealthy servers (last recorded status is `Unhealthy`)
    pub unhealthy_servers: usize,
    /// Unknown status servers (never checked, or last status is `Checking` or `Unknown`)
    pub unknown_servers: usize,
}
568
569/// Tools list response (from MCP spec)
570#[derive(Debug, Deserialize)]
571struct ToolsListResponse {
572    tools: Vec<Tool>,
573    #[allow(dead_code)]
574    next_cursor: Option<String>,
575}
576
577/// Prompts list response (from MCP spec)
578#[derive(Debug, Deserialize)]
579struct PromptsListResponse {
580    prompts: Vec<Prompt>,
581    #[allow(dead_code)]
582    next_cursor: Option<String>,
583}
584
#[cfg(test)]
mod tests {
    use super::*;

    /// `HealthStatus` uses the `snake_case` serde renaming, so `Healthy` serializes to
    /// the JSON string `"healthy"`.
    #[test]
    fn test_health_status() {
        let status = HealthStatus::Healthy;
        let json = serde_json::to_string(&status).unwrap();
        assert_eq!(json, "\"healthy\"");
    }

    /// A freshly created registry contains no servers.
    #[tokio::test]
    async fn test_registry_creation() {
        let registry = McpRegistry::new();
        let stats = registry.statistics().await;
        assert_eq!(stats.total_servers, 0);
    }
}