// model_context_protocol/server_hub.rs

//! McpServerHub - A hub that aggregates multiple MCP servers into a single server.
//!
//! The McpServerHub connects to multiple external MCP servers and exposes their
//! tools as a unified MCP server that can be wrapped by McpStdioServer or McpHttpServer.
//!
//! # Example
//!
//! ```rust,ignore
//! use std::sync::Arc;
//!
//! use mcp::{McpServerHub, McpServerConnectionConfig};
//! use mcp::server::stdio::McpStdioServer;
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     // Create a hub that aggregates multiple servers. `connect` takes
//!     // `self: &Arc<Self>`, so the hub has to live inside an `Arc`.
//!     let hub = Arc::new(McpServerHub::new("aggregator"));
//!
//!     // Connect to external servers
//!     let calc_config = McpServerConnectionConfig::stdio(
//!         "calculator",
//!         "node",
//!         vec!["calc-server.js".into()],
//!     );
//!     hub.connect(calc_config).await?;
//!
//!     let files_config = McpServerConnectionConfig::stdio(
//!         "files",
//!         "python",
//!         vec!["files-server.py".into()],
//!     );
//!     hub.connect(files_config).await?;
//!
//!     // Run as a stdio server - all connected tools are now exposed.
//!     // `to_config` keeps the `Arc` usable; the argument is the server version.
//!     McpStdioServer::run(hub.to_config("1.0.0")).await?;
//!     Ok(())
//! }
//! ```

use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::time::Duration;

use serde_json::Value;

use crate::circuit_breaker::CircuitBreakerStats;
use crate::hub_common::{HubConnections, ManagedConnection};
use crate::protocol::{McpToolDefinition, ToolContent};
use crate::server::McpServerConfig;
use crate::tool::{BoxFuture, DynTool, McpTool, ToolCallResult, ToolProvider};
use crate::transport::{McpServerConnectionConfig, McpTransportError};

49/// A hub that aggregates multiple MCP servers into a single MCP server.
50///
51/// This allows you to:
52/// - Connect to multiple external MCP servers
53/// - Expose all their tools through a single unified server
54/// - Wrap the hub with McpStdioServer or McpHttpServer
55/// - Automatically restart servers on failure
56/// - Parallel tool discovery for performance
57///
58/// Tools from connected servers are automatically discovered and made available
59/// through the hub's server interface.
60pub struct McpServerHub {
61    /// Hub server name
62    name: String,
63    /// Shared connection infrastructure
64    connections: HubConnections,
65    /// Default timeout for operations
66    timeout: Duration,
67}
68
69impl McpServerHub {
70    /// Create a new hub with the given name.
71    pub fn new(name: impl Into<String>) -> Self {
72        Self {
73            name: name.into(),
74            connections: HubConnections::new(),
75            timeout: Duration::from_secs(30),
76        }
77    }
78
79    /// Create a hub with a custom timeout.
80    pub fn with_timeout(name: impl Into<String>, timeout: Duration) -> Self {
81        Self {
82            name: name.into(),
83            connections: HubConnections::new(),
84            timeout,
85        }
86    }
87
88    /// Connect to an external MCP server.
89    ///
90    /// This method:
91    /// 1. Creates the appropriate transport based on config
92    /// 2. Initializes the connection
93    /// 3. Discovers tools and creates proxy tools for them
94    /// 4. Starts a restart monitor if restart policy is enabled
95    pub async fn connect(
96        self: &Arc<Self>,
97        config: McpServerConnectionConfig,
98    ) -> Result<(), McpTransportError> {
99        let server_name = config.name.clone();
100        let restart_enabled = config.restart_policy.enabled;
101
102        // Connect using shared infrastructure
103        let connection = self.connections.connect(config).await?;
104
105        // Start restart monitor if enabled
106        if restart_enabled {
107            let hub = Arc::clone(self);
108            let conn = Arc::clone(&connection);
109            let name = server_name.clone();
110
111            tokio::spawn(async move {
112                hub.restart_monitor(name, conn).await;
113            });
114        }
115
116        Ok(())
117    }
118
119    /// Monitor a connection and restart on failure.
120    async fn restart_monitor(&self, name: String, conn: Arc<crate::hub_common::ManagedConnection>) {
121        let policy = &conn.config.restart_policy;
122
123        loop {
124            // Wait for health check interval or restart notification
125            tokio::select! {
126                _ = conn.restart_notify.notified() => {}
127                _ = tokio::time::sleep(Duration::from_secs(5)) => {
128                    if conn.is_alive().await {
129                        continue;
130                    }
131                }
132            }
133
134            // Check if shutdown was requested
135            if conn.shutdown_requested.load(Ordering::SeqCst) {
136                break;
137            }
138
139            // Check if transport is still alive (double-check)
140            if conn.is_alive().await {
141                continue;
142            }
143
144            // Server is dead - notify all pending requests to fail immediately
145            conn.notify_failure();
146
147            // Get current attempt count
148            let attempt = conn.restart_count.fetch_add(1, Ordering::SeqCst);
149
150            // Check if we've exceeded max attempts
151            if let Some(max) = policy.max_attempts {
152                if attempt >= max {
153                    eprintln!(
154                        "[McpServerHub] Server '{}' exceeded max restart attempts ({})",
155                        name, max
156                    );
157                    break;
158                }
159            }
160
161            // Calculate delay with exponential backoff
162            let delay = policy.delay_for_attempt(attempt);
163
164            eprintln!(
165                "[McpServerHub] Server '{}' disconnected. Restarting in {}ms (attempt {}/{})",
166                name,
167                delay,
168                attempt + 1,
169                policy
170                    .max_attempts
171                    .map(|m| m.to_string())
172                    .unwrap_or_else(|| "∞".into())
173            );
174
175            tokio::time::sleep(Duration::from_millis(delay)).await;
176
177            // Check again if shutdown was requested during sleep
178            if conn.shutdown_requested.load(Ordering::SeqCst) {
179                break;
180            }
181
182            // Attempt reconnection
183            match self.connections.establish_connection(&conn).await {
184                Ok(_) => {
185                    eprintln!("[McpServerHub] Server '{}' reconnected successfully", name);
186                    conn.restart_count.store(0, Ordering::SeqCst);
187                }
188                Err(e) => {
189                    eprintln!(
190                        "[McpServerHub] Server '{}' failed to reconnect: {}",
191                        name, e
192                    );
193                }
194            }
195        }
196    }
197
198    /// Trigger an immediate restart for a specific server.
199    pub fn trigger_restart(&self, server_name: &str) {
200        if let Some(conn) = self.connections.get(server_name) {
201            conn.restart_notify.notify_one();
202        }
203    }
204
205    /// Call a tool by name, routing to the correct server.
206    ///
207    /// If the server restarts while a request is pending, the request will
208    /// immediately fail with a `ServerRestarting` error rather than timing out.
209    /// Uses circuit breaker to prevent cascading failures.
210    pub async fn call_tool(&self, name: &str, args: Value) -> Result<Value, McpTransportError> {
211        self.connections.call_tool(name, args).await
212    }
213
214    /// List all tools from all connected servers.
215    pub async fn list_tools(&self) -> Result<Vec<(String, McpToolDefinition)>, McpTransportError> {
216        Ok(self.connections.list_tools())
217    }
218
219    /// List all tool definitions.
220    pub async fn list_all_tools(&self) -> Result<Vec<McpToolDefinition>, McpTransportError> {
221        Ok(self.connections.list_tool_definitions())
222    }
223
224    /// Discover tools from all servers in parallel.
225    pub async fn discover_tools_parallel(
226        &self,
227    ) -> Result<Vec<(String, McpToolDefinition)>, McpTransportError> {
228        self.connections.discover_tools_parallel(self.timeout).await
229    }
230
231    /// Refresh tool cache by re-querying all servers (parallel).
232    pub async fn refresh_tools(&self) -> Result<(), McpTransportError> {
233        self.connections.refresh_tools_parallel(self.timeout).await
234    }
235
236    /// Get list of connected server names.
237    pub fn list_servers(&self) -> Vec<String> {
238        self.connections.list_servers()
239    }
240
241    /// Check if a server is connected.
242    pub fn is_connected(&self, server_name: &str) -> bool {
243        self.connections.is_connected(server_name)
244    }
245
246    /// Check if a server is connected and alive.
247    pub async fn is_alive(&self, server_name: &str) -> bool {
248        if let Some(conn) = self.connections.get(server_name) {
249            conn.is_alive().await
250        } else {
251            false
252        }
253    }
254
255    /// Get health status of all servers.
256    pub async fn health_check(&self) -> Vec<(String, bool)> {
257        self.connections.health_check().await
258    }
259
260    /// Get the server name that provides a specific tool.
261    pub fn server_for_tool(&self, tool_name: &str) -> Option<String> {
262        self.connections.server_for_tool(tool_name)
263    }
264
265    /// Disconnect a specific server (stops restart monitor).
266    pub async fn disconnect(&self, server_name: &str) -> Result<(), McpTransportError> {
267        let connection = self
268            .connections
269            .remove(server_name)
270            .ok_or_else(|| McpTransportError::ServerNotFound(server_name.to_string()))?;
271
272        // Signal shutdown to restart monitor
273        connection.shutdown_requested.store(true, Ordering::SeqCst);
274        connection.restart_notify.notify_one();
275
276        // Clear tools for this server
277        self.connections.clear_tools_for_server(server_name);
278
279        // Shutdown transport
280        if let Some(transport) = connection.get_transport().await {
281            transport.shutdown().await?;
282        }
283
284        Ok(())
285    }
286
287    /// Shutdown all connected servers.
288    pub async fn shutdown_all(&self) -> Result<(), McpTransportError> {
289        let names: Vec<String> = self.list_servers();
290        let mut errors = Vec::new();
291
292        for name in names {
293            if let Err(e) = self.disconnect(&name).await {
294                errors.push(format!("{}: {}", name, e));
295            }
296        }
297
298        if errors.is_empty() {
299            Ok(())
300        } else {
301            Err(McpTransportError::TransportError(errors.join("; ")))
302        }
303    }
304
305    /// Convert this hub into an McpServerConfig that can be used with
306    /// McpStdioServer or McpHttpServer.
307    ///
308    /// This creates proxy tools that route calls to the connected external servers.
309    pub fn into_config(self, version: &str) -> McpServerConfig {
310        let hub = Arc::new(self);
311        let provider = HubToolProvider {
312            hub: Arc::clone(&hub),
313        };
314
315        McpServerConfig::builder()
316            .name(&hub.name)
317            .version(version)
318            .with_tools_from(provider)
319            .build()
320    }
321
322    /// Create an McpServerConfig from this hub (keeps hub accessible).
323    ///
324    /// Use this when you need to keep a reference to the hub for direct access.
325    pub fn to_config(self: &Arc<Self>, version: &str) -> McpServerConfig {
326        let provider = HubToolProvider {
327            hub: Arc::clone(self),
328        };
329
330        McpServerConfig::builder()
331            .name(&self.name)
332            .version(version)
333            .with_tools_from(provider)
334            .build()
335    }
336
337    /// Get proxy tools for all connected servers.
338    ///
339    /// Use this when you want to combine hub tools with local tools:
340    /// ```rust,ignore
341    /// let config = McpServerConfig::builder()
342    ///     .name("my-server")
343    ///     .version("1.0.0")
344    ///     .register_tools_in_group("local")  // Local tools
345    ///     .with_tools(hub.proxy_tools())     // Proxied tools
346    ///     .build();
347    /// ```
348    pub fn proxy_tools(self: &Arc<Self>) -> Vec<DynTool> {
349        let provider = HubToolProvider {
350            hub: Arc::clone(self),
351        };
352        provider.tools()
353    }
354
355    /// Get circuit breaker statistics for a server.
356    pub fn circuit_breaker_stats(
357        &self,
358        server_name: &str,
359    ) -> Option<crate::circuit_breaker::CircuitBreakerStats> {
360        self.connections.circuit_breaker_stats(server_name)
361    }
362
363    /// Reset circuit breaker for a server.
364    pub fn reset_circuit_breaker(&self, server_name: &str) {
365        self.connections.reset_circuit_breaker(server_name);
366    }
367}
368
369/// Tool provider that creates proxy tools for all tools in the hub.
370struct HubToolProvider {
371    hub: Arc<McpServerHub>,
372}
373
374impl ToolProvider for HubToolProvider {
375    fn tools(&self) -> Vec<DynTool> {
376        self.hub
377            .connections
378            .list_tools()
379            .into_iter()
380            .map(|(_, def)| {
381                let tool: DynTool = Arc::new(ProxyTool {
382                    name: def.name.clone(),
383                    definition: def,
384                    hub: Arc::clone(&self.hub),
385                });
386                tool
387            })
388            .collect()
389    }
390}
391
392/// A proxy tool that forwards calls to an external MCP server via the hub.
393struct ProxyTool {
394    name: String,
395    definition: McpToolDefinition,
396    hub: Arc<McpServerHub>,
397}
398
399impl McpTool for ProxyTool {
400    fn definition(&self) -> McpToolDefinition {
401        self.definition.clone()
402    }
403
404    fn call<'a>(&'a self, args: Value) -> BoxFuture<'a, ToolCallResult> {
405        let name = self.name.clone();
406        let hub = Arc::clone(&self.hub);
407
408        Box::pin(async move {
409            match hub.call_tool(&name, args).await {
410                Ok(value) => {
411                    // Convert the Value to ToolContent
412                    if let Some(s) = value.as_str() {
413                        Ok(vec![crate::protocol::ToolContent::text(s)])
414                    } else {
415                        Ok(vec![crate::protocol::ToolContent::text(value.to_string())])
416                    }
417                }
418                Err(e) => Err(e.to_string()),
419            }
420        })
421    }
422}
423
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hub_creation() {
        let hub = McpServerHub::new("test-hub");
        assert_eq!(hub.name, "test-hub");
        assert_eq!(hub.timeout, Duration::from_secs(30));
        assert!(hub.list_servers().is_empty());
    }

    #[test]
    fn test_hub_with_timeout() {
        let hub = McpServerHub::with_timeout("test-hub", Duration::from_millis(250));
        assert_eq!(hub.name, "test-hub");
        assert_eq!(hub.timeout, Duration::from_millis(250));
    }

    #[tokio::test]
    async fn test_hub_into_config() {
        let hub = McpServerHub::new("test-hub");
        let config = hub.into_config("1.0.0");
        assert_eq!(config.name(), "test-hub");
        assert_eq!(config.version(), "1.0.0");
    }

    #[tokio::test]
    async fn test_hub_unknown_tool() {
        let hub = McpServerHub::new("test");
        let result = hub.call_tool("nonexistent", serde_json::json!({})).await;
        assert!(matches!(result, Err(McpTransportError::UnknownTool(_))));
    }

    #[tokio::test]
    async fn test_unknown_server_queries() {
        let hub = McpServerHub::new("test");
        assert!(!hub.is_connected("missing"));
        assert!(!hub.is_alive("missing").await);
        // Must be a silent no-op for unknown servers.
        hub.trigger_restart("missing");
        let result = hub.disconnect("missing").await;
        assert!(matches!(result, Err(McpTransportError::ServerNotFound(_))));
    }

    #[tokio::test]
    async fn test_shutdown_all_empty_hub() {
        let hub = McpServerHub::new("test");
        assert!(hub.shutdown_all().await.is_ok());
    }
}