Skip to main content

model_context_protocol/
hub.rs

//! McpHub - Central hub for MCP tool routing across multiple servers.
//!
//! The McpHub manages connections to multiple MCP servers and provides
//! unified tool discovery, execution, and routing.
5
6use serde_json::Value;
7use std::sync::Arc;
8use std::time::Duration;
9
10use crate::circuit_breaker::CircuitBreakerStats;
11use crate::hub_common::HubConnections;
12use crate::protocol::McpToolDefinition;
13use crate::transport::{McpServerConnectionConfig, McpTransport, McpTransportError};
14
15/// Central hub for MCP tool routing across multiple servers.
16///
17/// The McpHub provides:
18/// - Connection management for multiple MCP servers
19/// - Tool discovery and caching
20/// - Automatic routing of tool calls to the correct server
21/// - Circuit breaker protection for resilience
22/// - Parallel tool discovery for performance
23///
24/// # Example
25///
26/// ```rust,ignore
27/// use mcp::{McpHub, McpServerConnectionConfig};
28///
29/// let hub = McpHub::new();
30///
31/// // Connect to an external server
32/// let config = McpServerConnectionConfig::stdio("my-server", "node", vec!["server.js".into()]);
33/// hub.connect(config).await?;
34///
35/// // List all available tools
36/// let tools = hub.list_all_tools().await?;
37///
38/// // Call a tool (automatically routed to correct server)
39/// let result = hub.call_tool("my_tool", serde_json::json!({"arg": "value"})).await?;
40/// ```
41pub struct McpHub {
42    /// Shared connection infrastructure
43    connections: HubConnections,
44    /// Default timeout for tool discovery
45    discovery_timeout: Duration,
46}
47
48impl Default for McpHub {
49    fn default() -> Self {
50        Self::new()
51    }
52}
53
54impl McpHub {
55    /// Create a new empty hub.
56    pub fn new() -> Self {
57        Self {
58            connections: HubConnections::new(),
59            discovery_timeout: Duration::from_secs(30),
60        }
61    }
62
63    /// Create a hub with a custom discovery timeout.
64    pub fn with_discovery_timeout(timeout: Duration) -> Self {
65        Self {
66            connections: HubConnections::new(),
67            discovery_timeout: timeout,
68        }
69    }
70
71    /// Connect to an MCP server.
72    ///
73    /// This method:
74    /// 1. Creates the appropriate transport based on config
75    /// 2. Initializes the connection
76    /// 3. Discovers tools and caches the mapping
77    /// 4. Returns the transport for direct access if needed
78    pub async fn connect(
79        &self,
80        config: McpServerConnectionConfig,
81    ) -> Result<Arc<dyn McpTransport>, McpTransportError> {
82        let conn = self.connections.connect(config).await?;
83        conn.get_transport()
84            .await
85            .ok_or(McpTransportError::ConnectionClosed)
86    }
87
88    /// Call a tool, automatically routing to the correct server.
89    ///
90    /// Uses circuit breaker to prevent cascading failures - if a server is
91    /// unhealthy, requests will be rejected immediately.
92    pub async fn call_tool(&self, name: &str, args: Value) -> Result<Value, McpTransportError> {
93        self.connections.call_tool(name, args).await
94    }
95
96    /// List all tools from all connected servers.
97    pub async fn list_tools(&self) -> Result<Vec<(String, McpToolDefinition)>, McpTransportError> {
98        Ok(self.connections.list_tools())
99    }
100
101    /// Get all registered tools as a flat list.
102    pub async fn list_all_tools(&self) -> Result<Vec<McpToolDefinition>, McpTransportError> {
103        Ok(self.connections.list_tool_definitions())
104    }
105
106    /// Discover tools from all servers in parallel.
107    ///
108    /// This is faster than sequential discovery when connecting to many servers.
109    pub async fn discover_tools_parallel(
110        &self,
111    ) -> Result<Vec<(String, McpToolDefinition)>, McpTransportError> {
112        self.connections
113            .discover_tools_parallel(self.discovery_timeout)
114            .await
115    }
116
117    /// Populate the tool cache by querying all servers (parallel).
118    pub async fn refresh_tool_cache(&self) -> Result<(), McpTransportError> {
119        self.connections
120            .refresh_tools_parallel(self.discovery_timeout)
121            .await
122    }
123
124    /// Shutdown all connected servers.
125    pub async fn shutdown_all(&self) -> Result<(), McpTransportError> {
126        let mut errors = Vec::new();
127
128        for (server_name, conn) in self.connections.iter() {
129            if let Some(transport) = conn.get_transport().await {
130                if let Err(e) = transport.shutdown().await {
131                    errors.push(format!("{}: {}", server_name, e));
132                }
133            }
134        }
135        self.connections.clear();
136
137        if errors.is_empty() {
138            Ok(())
139        } else {
140            Err(McpTransportError::TransportError(errors.join("; ")))
141        }
142    }
143
144    /// Disconnect a specific server.
145    pub async fn disconnect(&self, server_name: &str) -> Result<(), McpTransportError> {
146        let conn = self
147            .connections
148            .remove(server_name)
149            .ok_or_else(|| McpTransportError::ServerNotFound(server_name.to_string()))?;
150
151        self.connections.clear_tools_for_server(server_name);
152
153        if let Some(transport) = conn.get_transport().await {
154            transport.shutdown().await?;
155        }
156        Ok(())
157    }
158
159    /// Get list of connected server names.
160    pub fn list_servers(&self) -> Vec<String> {
161        self.connections.list_servers()
162    }
163
164    /// Check if a server is connected.
165    pub fn is_connected(&self, server_name: &str) -> bool {
166        self.connections.is_connected(server_name)
167    }
168
169    /// Get health status of all servers (includes circuit breaker state).
170    pub async fn health_check(&self) -> Vec<(String, bool)> {
171        self.connections.health_check().await
172    }
173
174    /// Get the server name that provides a specific tool.
175    pub fn server_for_tool(&self, tool_name: &str) -> Option<String> {
176        self.connections.server_for_tool(tool_name)
177    }
178
179    /// Get circuit breaker statistics for a server.
180    pub fn circuit_breaker_stats(&self, server_name: &str) -> Option<CircuitBreakerStats> {
181        self.connections.circuit_breaker_stats(server_name)
182    }
183
184    /// Reset circuit breaker for a server (e.g., after manual recovery).
185    pub fn reset_circuit_breaker(&self, server_name: &str) {
186        self.connections.reset_circuit_breaker(server_name);
187    }
188}
189
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created hub has no servers registered.
    #[tokio::test]
    async fn test_hub_creation() {
        let servers = McpHub::new().list_servers();
        assert!(servers.is_empty());
    }

    /// Calling a tool that no server provides yields `UnknownTool`.
    #[tokio::test]
    async fn test_hub_unknown_tool() {
        let hub = McpHub::new();
        let empty_args = serde_json::json!({});

        let result = hub.call_tool("nonexistent_tool", empty_args).await;
        assert!(matches!(result, Err(McpTransportError::UnknownTool(_))));
    }

    /// The stdio config builder keeps the name and the overridden timeout.
    #[test]
    fn test_connection_config() {
        let args = vec![String::from("server.js")];
        let config = McpServerConnectionConfig::stdio("test", "node", args).with_timeout(60);

        assert_eq!(config.name, "test");
        assert_eq!(config.timeout_secs, 60);
    }
}