oxify_mcp/
registry.rs

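//! Server registry for managing multiple MCP servers
//!
//! This module provides a registry for managing multiple MCP servers with
//! load balancing, health-status tracking, and failover capabilities. The
//! registry does not probe servers itself; health is whatever was last set
//! through `McpRegistry::set_server_health`.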
5
6use crate::{McpError, McpServer, Result, ToolSchema};
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use std::collections::HashMap;
11use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
12use std::sync::Arc;
13use std::time::Instant;
14use tokio::sync::RwLock;
15
/// Type alias for the server map to reduce complexity.
///
/// Maps each server ID to its [`ServerEntry`]. The map is shared (`Arc`) and guarded by
/// an async `RwLock`, so readers and writers must `.await` the lock.
type ServerMap = Arc<RwLock<HashMap<String, ServerEntry>>>;
18
19/// Load balancing strategy for selecting servers
20#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
21pub enum LoadBalanceStrategy {
22    /// Simple round-robin selection
23    #[default]
24    RoundRobin,
25    /// Select server with least active connections
26    LeastConnections,
27    /// Random selection
28    Random,
29    /// Weighted round-robin based on server weights
30    WeightedRoundRobin,
31}
32
33/// Health status of a server
34#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
35pub enum ServerHealth {
36    #[default]
37    Healthy,
38    Degraded,
39    Unhealthy,
40}
41
42/// Server entry with metadata for load balancing
43#[derive(Clone)]
44pub struct ServerEntry {
45    /// The actual MCP server
46    pub server: Arc<Box<dyn McpServer>>,
47    /// Server weight for weighted load balancing
48    pub weight: u32,
49    /// Current health status
50    pub health: ServerHealth,
51    /// Number of active connections
52    pub active_connections: Arc<AtomicU64>,
53    /// Total request count
54    pub request_count: Arc<AtomicU64>,
55    /// Total error count
56    pub error_count: Arc<AtomicU64>,
57    /// Average response time in milliseconds
58    pub avg_response_time_ms: Arc<AtomicU64>,
59    /// Last health check time
60    pub last_health_check: Option<Instant>,
61    /// Server group for affinity
62    pub group: Option<String>,
63    /// Tags for filtering
64    pub tags: Vec<String>,
65}
66
67impl ServerEntry {
68    fn new(server: Arc<Box<dyn McpServer>>) -> Self {
69        Self {
70            server,
71            weight: 1,
72            health: ServerHealth::Healthy,
73            active_connections: Arc::new(AtomicU64::new(0)),
74            request_count: Arc::new(AtomicU64::new(0)),
75            error_count: Arc::new(AtomicU64::new(0)),
76            avg_response_time_ms: Arc::new(AtomicU64::new(0)),
77            last_health_check: None,
78            group: None,
79            tags: Vec::new(),
80        }
81    }
82
83    fn with_weight(mut self, weight: u32) -> Self {
84        self.weight = weight;
85        self
86    }
87
88    fn with_group(mut self, group: String) -> Self {
89        self.group = Some(group);
90        self
91    }
92
93    fn with_tags(mut self, tags: Vec<String>) -> Self {
94        self.tags = tags;
95        self
96    }
97}
98
99/// Registry for managing multiple MCP servers (both built-in and external)
100/// with load balancing and health checking capabilities
101pub struct McpRegistry {
102    servers: ServerMap,
103    /// Round-robin counter
104    rr_counter: AtomicUsize,
105    /// Default load balancing strategy
106    default_strategy: LoadBalanceStrategy,
107    /// Health check interval in seconds
108    health_check_interval_secs: u64,
109    /// Unhealthy threshold (consecutive failures)
110    unhealthy_threshold: u32,
111    /// Recovery threshold (consecutive successes)
112    recovery_threshold: u32,
113}
114
115impl McpRegistry {
116    /// Create a new empty registry
117    pub fn new() -> Self {
118        Self {
119            servers: Arc::new(RwLock::new(HashMap::new())),
120            rr_counter: AtomicUsize::new(0),
121            default_strategy: LoadBalanceStrategy::RoundRobin,
122            health_check_interval_secs: 30,
123            unhealthy_threshold: 3,
124            recovery_threshold: 2,
125        }
126    }
127
128    /// Create a new registry with custom configuration
129    pub fn with_config(
130        strategy: LoadBalanceStrategy,
131        health_check_interval_secs: u64,
132        unhealthy_threshold: u32,
133        recovery_threshold: u32,
134    ) -> Self {
135        Self {
136            servers: Arc::new(RwLock::new(HashMap::new())),
137            rr_counter: AtomicUsize::new(0),
138            default_strategy: strategy,
139            health_check_interval_secs,
140            unhealthy_threshold,
141            recovery_threshold,
142        }
143    }
144
145    /// Register a new MCP server
146    pub async fn register<S: McpServer + 'static>(
147        &self,
148        server_id: String,
149        server: S,
150    ) -> Result<()> {
151        let mut servers = self.servers.write().await;
152        let entry = ServerEntry::new(Arc::new(Box::new(server)));
153        servers.insert(server_id.clone(), entry);
154        tracing::info!("Registered MCP server: {}", server_id);
155        Ok(())
156    }
157
158    /// Register a server with custom weight
159    pub async fn register_with_weight<S: McpServer + 'static>(
160        &self,
161        server_id: String,
162        server: S,
163        weight: u32,
164    ) -> Result<()> {
165        let mut servers = self.servers.write().await;
166        let entry = ServerEntry::new(Arc::new(Box::new(server))).with_weight(weight);
167        servers.insert(server_id.clone(), entry);
168        tracing::info!(
169            "Registered MCP server: {} with weight {}",
170            server_id,
171            weight
172        );
173        Ok(())
174    }
175
176    /// Register a server with group affinity
177    pub async fn register_with_group<S: McpServer + 'static>(
178        &self,
179        server_id: String,
180        server: S,
181        group: String,
182    ) -> Result<()> {
183        let mut servers = self.servers.write().await;
184        let entry = ServerEntry::new(Arc::new(Box::new(server))).with_group(group.clone());
185        servers.insert(server_id.clone(), entry);
186        tracing::info!("Registered MCP server: {} in group {}", server_id, group);
187        Ok(())
188    }
189
190    /// Register a server with tags
191    pub async fn register_with_tags<S: McpServer + 'static>(
192        &self,
193        server_id: String,
194        server: S,
195        tags: Vec<String>,
196    ) -> Result<()> {
197        let mut servers = self.servers.write().await;
198        let entry = ServerEntry::new(Arc::new(Box::new(server))).with_tags(tags.clone());
199        servers.insert(server_id.clone(), entry);
200        tracing::info!("Registered MCP server: {} with tags {:?}", server_id, tags);
201        Ok(())
202    }
203
204    /// Unregister a server by ID
205    pub async fn unregister(&self, server_id: &str) -> Result<()> {
206        let mut servers = self.servers.write().await;
207        servers
208            .remove(server_id)
209            .ok_or_else(|| McpError::ServerError(format!("Server '{}' not found", server_id)))?;
210        tracing::info!("Unregistered MCP server: {}", server_id);
211        Ok(())
212    }
213
214    /// Get a server by ID
215    pub async fn get_server(&self, server_id: &str) -> Option<Arc<Box<dyn McpServer>>> {
216        let servers = self.servers.read().await;
217        servers.get(server_id).map(|entry| entry.server.clone())
218    }
219
220    /// Get a server entry by ID (includes metadata)
221    pub async fn get_server_entry(&self, server_id: &str) -> Option<ServerEntry> {
222        let servers = self.servers.read().await;
223        servers.get(server_id).cloned()
224    }
225
226    /// Set server health status
227    pub async fn set_server_health(&self, server_id: &str, health: ServerHealth) -> Result<()> {
228        let mut servers = self.servers.write().await;
229        let entry = servers
230            .get_mut(server_id)
231            .ok_or_else(|| McpError::ServerError(format!("Server '{}' not found", server_id)))?;
232        entry.health = health;
233        entry.last_health_check = Some(Instant::now());
234        tracing::info!("Updated server {} health to {:?}", server_id, health);
235        Ok(())
236    }
237
238    /// Set server weight
239    pub async fn set_server_weight(&self, server_id: &str, weight: u32) -> Result<()> {
240        let mut servers = self.servers.write().await;
241        let entry = servers
242            .get_mut(server_id)
243            .ok_or_else(|| McpError::ServerError(format!("Server '{}' not found", server_id)))?;
244        entry.weight = weight;
245        tracing::info!("Updated server {} weight to {}", server_id, weight);
246        Ok(())
247    }
248
249    /// Select a server using the specified load balancing strategy
250    pub async fn select_server(&self, strategy: Option<LoadBalanceStrategy>) -> Option<String> {
251        let strategy = strategy.unwrap_or(self.default_strategy);
252        let servers = self.servers.read().await;
253
254        // Filter healthy servers
255        let healthy_servers: Vec<(&String, &ServerEntry)> = servers
256            .iter()
257            .filter(|(_, e)| e.health == ServerHealth::Healthy)
258            .collect();
259
260        if healthy_servers.is_empty() {
261            return None;
262        }
263
264        match strategy {
265            LoadBalanceStrategy::RoundRobin => {
266                let idx = self.rr_counter.fetch_add(1, Ordering::Relaxed) % healthy_servers.len();
267                healthy_servers.get(idx).map(|(id, _)| (*id).clone())
268            }
269            LoadBalanceStrategy::LeastConnections => healthy_servers
270                .iter()
271                .min_by_key(|(_, e)| e.active_connections.load(Ordering::Relaxed))
272                .map(|(id, _)| (*id).clone()),
273            LoadBalanceStrategy::Random => {
274                use std::collections::hash_map::DefaultHasher;
275                use std::hash::{Hash, Hasher};
276                let mut hasher = DefaultHasher::new();
277                std::time::Instant::now().hash(&mut hasher);
278                let hash = hasher.finish() as usize;
279                let idx = hash % healthy_servers.len();
280                healthy_servers.get(idx).map(|(id, _)| (*id).clone())
281            }
282            LoadBalanceStrategy::WeightedRoundRobin => {
283                let total_weight: u32 = healthy_servers.iter().map(|(_, e)| e.weight).sum();
284                if total_weight == 0 {
285                    return healthy_servers.first().map(|(id, _)| (*id).clone());
286                }
287                let idx = self.rr_counter.fetch_add(1, Ordering::Relaxed);
288                let mut position = (idx as u32) % total_weight;
289
290                for (id, entry) in &healthy_servers {
291                    if position < entry.weight {
292                        return Some((*id).clone());
293                    }
294                    position -= entry.weight;
295                }
296                healthy_servers.first().map(|(id, _)| (*id).clone())
297            }
298        }
299    }
300
301    /// Select a server from a specific group
302    pub async fn select_server_from_group(&self, group: &str) -> Option<String> {
303        let servers = self.servers.read().await;
304
305        let group_servers: Vec<(&String, &ServerEntry)> = servers
306            .iter()
307            .filter(|(_, e)| e.health == ServerHealth::Healthy && e.group.as_deref() == Some(group))
308            .collect();
309
310        if group_servers.is_empty() {
311            return None;
312        }
313
314        let idx = self.rr_counter.fetch_add(1, Ordering::Relaxed) % group_servers.len();
315        group_servers.get(idx).map(|(id, _)| (*id).clone())
316    }
317
318    /// Select a server by tag
319    pub async fn select_server_by_tag(&self, tag: &str) -> Option<String> {
320        let servers = self.servers.read().await;
321
322        let tagged_servers: Vec<(&String, &ServerEntry)> = servers
323            .iter()
324            .filter(|(_, e)| e.health == ServerHealth::Healthy && e.tags.contains(&tag.to_string()))
325            .collect();
326
327        if tagged_servers.is_empty() {
328            return None;
329        }
330
331        let idx = self.rr_counter.fetch_add(1, Ordering::Relaxed) % tagged_servers.len();
332        tagged_servers.get(idx).map(|(id, _)| (*id).clone())
333    }
334
335    /// Invoke a tool with automatic failover
336    pub async fn invoke_tool_with_failover(
337        &self,
338        tool_name: &str,
339        arguments: Value,
340        max_retries: u32,
341    ) -> Result<Value> {
342        // Find servers that have this tool
343        let server_ids = self.find_tool(tool_name).await?;
344        if server_ids.is_empty() {
345            return Err(McpError::ToolNotFound(format!(
346                "Tool '{}' not found in any server",
347                tool_name
348            )));
349        }
350
351        let mut last_error = None;
352        let mut tried_servers: Vec<String> = Vec::new();
353
354        for _ in 0..max_retries.min(server_ids.len() as u32) {
355            // Select a healthy server that hasn't been tried
356            let servers = self.servers.read().await;
357            let available_server = server_ids
358                .iter()
359                .filter(|id| !tried_servers.contains(id))
360                .find(|id| {
361                    servers
362                        .get(*id)
363                        .map(|e| e.health == ServerHealth::Healthy)
364                        .unwrap_or(false)
365                })
366                .cloned();
367            drop(servers);
368
369            let server_id = match available_server {
370                Some(id) => id,
371                None => break, // No more healthy servers to try
372            };
373
374            tried_servers.push(server_id.clone());
375
376            // Increment active connections
377            {
378                let servers = self.servers.read().await;
379                if let Some(entry) = servers.get(&server_id) {
380                    entry.active_connections.fetch_add(1, Ordering::Relaxed);
381                    entry.request_count.fetch_add(1, Ordering::Relaxed);
382                }
383            }
384
385            let start_time = Instant::now();
386            let result = self
387                .invoke_tool(&server_id, tool_name, arguments.clone())
388                .await;
389
390            // Decrement active connections and update metrics
391            {
392                let servers = self.servers.read().await;
393                if let Some(entry) = servers.get(&server_id) {
394                    entry.active_connections.fetch_sub(1, Ordering::Relaxed);
395                    let elapsed_ms = start_time.elapsed().as_millis() as u64;
396                    // Simple moving average
397                    let old_avg = entry.avg_response_time_ms.load(Ordering::Relaxed);
398                    let new_avg = (old_avg + elapsed_ms) / 2;
399                    entry.avg_response_time_ms.store(new_avg, Ordering::Relaxed);
400                }
401            }
402
403            match result {
404                Ok(value) => return Ok(value),
405                Err(e) => {
406                    tracing::warn!(
407                        "Tool invocation failed on server {}: {}. Trying failover...",
408                        server_id,
409                        e
410                    );
411
412                    // Update error count
413                    {
414                        let servers = self.servers.read().await;
415                        if let Some(entry) = servers.get(&server_id) {
416                            entry.error_count.fetch_add(1, Ordering::Relaxed);
417                        }
418                    }
419
420                    last_error = Some(e);
421                }
422            }
423        }
424
425        Err(last_error.unwrap_or_else(|| {
426            McpError::ServerError("All servers failed or unavailable".to_string())
427        }))
428    }
429
430    /// Get all servers with their health status
431    pub async fn get_server_health_status(&self) -> HashMap<String, ServerHealth> {
432        let servers = self.servers.read().await;
433        servers
434            .iter()
435            .map(|(id, entry)| (id.clone(), entry.health))
436            .collect()
437    }
438
439    /// Get server metrics
440    pub async fn get_server_metrics(&self, server_id: &str) -> Option<ServerMetrics> {
441        let servers = self.servers.read().await;
442        servers.get(server_id).map(|entry| ServerMetrics {
443            server_id: server_id.to_string(),
444            health: entry.health,
445            weight: entry.weight,
446            active_connections: entry.active_connections.load(Ordering::Relaxed),
447            request_count: entry.request_count.load(Ordering::Relaxed),
448            error_count: entry.error_count.load(Ordering::Relaxed),
449            avg_response_time_ms: entry.avg_response_time_ms.load(Ordering::Relaxed),
450            group: entry.group.clone(),
451            tags: entry.tags.clone(),
452        })
453    }
454
455    /// Get metrics for all servers
456    pub async fn get_all_server_metrics(&self) -> Vec<ServerMetrics> {
457        let servers = self.servers.read().await;
458        servers
459            .iter()
460            .map(|(id, entry)| ServerMetrics {
461                server_id: id.clone(),
462                health: entry.health,
463                weight: entry.weight,
464                active_connections: entry.active_connections.load(Ordering::Relaxed),
465                request_count: entry.request_count.load(Ordering::Relaxed),
466                error_count: entry.error_count.load(Ordering::Relaxed),
467                avg_response_time_ms: entry.avg_response_time_ms.load(Ordering::Relaxed),
468                group: entry.group.clone(),
469                tags: entry.tags.clone(),
470            })
471            .collect()
472    }
473
474    /// List servers in a specific group
475    pub async fn list_servers_in_group(&self, group: &str) -> Vec<String> {
476        let servers = self.servers.read().await;
477        servers
478            .iter()
479            .filter(|(_, e)| e.group.as_deref() == Some(group))
480            .map(|(id, _)| id.clone())
481            .collect()
482    }
483
484    /// List servers with a specific tag
485    pub async fn list_servers_with_tag(&self, tag: &str) -> Vec<String> {
486        let servers = self.servers.read().await;
487        servers
488            .iter()
489            .filter(|(_, e)| e.tags.contains(&tag.to_string()))
490            .map(|(id, _)| id.clone())
491            .collect()
492    }
493
494    /// Get load balancing configuration
495    pub fn get_config(&self) -> LoadBalanceConfig {
496        LoadBalanceConfig {
497            default_strategy: self.default_strategy,
498            health_check_interval_secs: self.health_check_interval_secs,
499            unhealthy_threshold: self.unhealthy_threshold,
500            recovery_threshold: self.recovery_threshold,
501        }
502    }
503
504    /// List all registered server IDs
505    pub async fn list_server_ids(&self) -> Vec<String> {
506        let servers = self.servers.read().await;
507        servers.keys().cloned().collect()
508    }
509
510    /// Get all tools from all registered servers
511    pub async fn list_all_tools(&self) -> Result<HashMap<String, Vec<ToolSchema>>> {
512        let servers = self.servers.read().await;
513        let mut all_tools = HashMap::new();
514
515        for (server_id, entry) in servers.iter() {
516            let tools_json = entry.server.list_tools().await?;
517            let tools: Vec<ToolSchema> = tools_json
518                .into_iter()
519                .filter_map(|v| serde_json::from_value(v).ok())
520                .collect();
521            all_tools.insert(server_id.clone(), tools);
522        }
523
524        Ok(all_tools)
525    }
526
527    /// Get tools from a specific server
528    pub async fn list_tools(&self, server_id: &str) -> Result<Vec<ToolSchema>> {
529        let server = self
530            .get_server(server_id)
531            .await
532            .ok_or_else(|| McpError::ServerError(format!("Server '{}' not found", server_id)))?;
533
534        let tools_json = server.list_tools().await?;
535        let tools: Vec<ToolSchema> = tools_json
536            .into_iter()
537            .filter_map(|v| serde_json::from_value(v).ok())
538            .collect();
539
540        Ok(tools)
541    }
542
543    /// Invoke a tool on a specific server
544    pub async fn invoke_tool(
545        &self,
546        server_id: &str,
547        tool_name: &str,
548        arguments: Value,
549    ) -> Result<Value> {
550        let server = self
551            .get_server(server_id)
552            .await
553            .ok_or_else(|| McpError::ServerError(format!("Server '{}' not found", server_id)))?;
554
555        server.call_tool(tool_name, arguments).await
556    }
557
558    /// Find a tool across all servers
559    pub async fn find_tool(&self, tool_name: &str) -> Result<Vec<String>> {
560        let all_tools = self.list_all_tools().await?;
561        let mut server_ids = Vec::new();
562
563        for (server_id, tools) in all_tools {
564            if tools.iter().any(|t| t.name == tool_name) {
565                server_ids.push(server_id);
566            }
567        }
568
569        Ok(server_ids)
570    }
571
572    /// Get server statistics
573    pub async fn get_stats(&self) -> RegistryStats {
574        let servers = self.servers.read().await;
575        let server_count = servers.len();
576
577        let mut total_tools = 0;
578        for entry in servers.values() {
579            if let Ok(tools) = entry.server.list_tools().await {
580                total_tools += tools.len();
581            }
582        }
583
584        RegistryStats {
585            server_count,
586            total_tools,
587        }
588    }
589
590    /// Clear all registered servers
591    pub async fn clear(&self) {
592        let mut servers = self.servers.write().await;
593        servers.clear();
594        tracing::info!("Cleared all registered MCP servers");
595    }
596}
597
598impl Default for McpRegistry {
599    fn default() -> Self {
600        Self::new()
601    }
602}
603
/// Summary counts about the registry, as produced by `McpRegistry::get_stats`.
///
/// Derives `Default`, `PartialEq` and `Eq` so callers can build and compare snapshots.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RegistryStats {
    /// Number of registered servers.
    pub server_count: usize,
    /// Total number of tool entries reported by the servers that listed successfully.
    pub total_tools: usize,
}
610
611/// Server metrics for monitoring and load balancing decisions
612#[derive(Debug, Clone, Serialize, Deserialize)]
613pub struct ServerMetrics {
614    pub server_id: String,
615    pub health: ServerHealth,
616    pub weight: u32,
617    pub active_connections: u64,
618    pub request_count: u64,
619    pub error_count: u64,
620    pub avg_response_time_ms: u64,
621    pub group: Option<String>,
622    pub tags: Vec<String>,
623}
624
625impl ServerMetrics {
626    /// Calculate error rate (0.0 - 1.0)
627    pub fn error_rate(&self) -> f64 {
628        if self.request_count == 0 {
629            return 0.0;
630        }
631        self.error_count as f64 / self.request_count as f64
632    }
633
634    /// Check if server should be marked unhealthy based on error rate
635    pub fn should_mark_unhealthy(&self, error_rate_threshold: f64) -> bool {
636        self.error_rate() > error_rate_threshold && self.request_count > 10
637    }
638}
639
640/// Load balancing configuration
641#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
642pub struct LoadBalanceConfig {
643    pub default_strategy: LoadBalanceStrategy,
644    pub health_check_interval_secs: u64,
645    pub unhealthy_threshold: u32,
646    pub recovery_threshold: u32,
647}
648
649impl Default for LoadBalanceConfig {
650    fn default() -> Self {
651        Self {
652            default_strategy: LoadBalanceStrategy::RoundRobin,
653            health_check_interval_secs: 30,
654            unhealthy_threshold: 3,
655            recovery_threshold: 2,
656        }
657    }
658}
659
660#[async_trait]
661impl McpServer for McpRegistry {
662    /// Call a tool on a registered server
663    /// Arguments must include "server_id" field
664    async fn call_tool(&self, name: &str, arguments: Value) -> Result<Value> {
665        let server_id = arguments
666            .get("server_id")
667            .and_then(|v| v.as_str())
668            .ok_or_else(|| McpError::InvalidRequest("Missing 'server_id' field".to_string()))?
669            .to_string();
670
671        self.invoke_tool(&server_id, name, arguments).await
672    }
673
674    /// List all tools from all registered servers
675    async fn list_tools(&self) -> Result<Vec<Value>> {
676        let all_tools = self.list_all_tools().await?;
677        let mut tools = Vec::new();
678
679        for (server_id, server_tools) in all_tools {
680            for tool in server_tools {
681                tools.push(serde_json::json!({
682                    "server_id": server_id,
683                    "name": tool.name,
684                    "description": tool.description,
685                    "inputSchema": tool.input_schema,
686                }));
687            }
688        }
689
690        Ok(tools)
691    }
692}
693
#[cfg(test)]
mod tests {
    use super::*;
    use crate::servers::FilesystemServer;
    use std::path::PathBuf;

    /// Registry holding a single `FilesystemServer` rooted at `/tmp`, registered as "fs".
    async fn registry_with_fs() -> McpRegistry {
        let registry = McpRegistry::new();
        registry
            .register(
                "fs".to_string(),
                FilesystemServer::new(PathBuf::from("/tmp")),
            )
            .await
            .unwrap();
        registry
    }

    #[tokio::test]
    async fn test_registry_creation() {
        let registry = McpRegistry::new();
        assert!(registry.list_server_ids().await.is_empty());
    }

    #[tokio::test]
    async fn test_register_server() {
        let registry = registry_with_fs().await;
        assert_eq!(registry.list_server_ids().await, vec!["fs".to_string()]);
    }

    #[tokio::test]
    async fn test_unregister_server() {
        let registry = registry_with_fs().await;
        registry.unregister("fs").await.unwrap();
        assert!(registry.list_server_ids().await.is_empty());
    }

    #[tokio::test]
    async fn test_list_tools() {
        let registry = registry_with_fs().await;
        let tools = registry.list_tools("fs").await.unwrap();
        assert!(!tools.is_empty());
        assert!(tools.iter().any(|tool| tool.name == "fs_read"));
    }

    #[tokio::test]
    async fn test_find_tool() {
        let registry = registry_with_fs().await;
        let hosts = registry.find_tool("fs_read").await.unwrap();
        assert_eq!(hosts, vec!["fs".to_string()]);
    }

    #[tokio::test]
    async fn test_get_stats() {
        let registry = registry_with_fs().await;
        let stats = registry.get_stats().await;
        assert_eq!(stats.server_count, 1);
        assert!(stats.total_tools > 0);
    }

    #[tokio::test]
    async fn test_clear() {
        let registry = registry_with_fs().await;
        registry.clear().await;
        assert!(registry.list_server_ids().await.is_empty());
    }
}