Skip to main content

model_context_protocol/
hub.rs

//! McpHub - Central hub for MCP tool routing across multiple servers.
//!
//! The McpHub manages connections to multiple MCP servers and provides
//! unified tool discovery, execution, and routing.
5
6use serde_json::Value;
7use std::collections::HashMap;
8use std::sync::{Arc, RwLock};
9use std::time::Duration;
10
11#[cfg(feature = "http")]
12use crate::http::HttpTransportAdapter;
13use crate::protocol::ToolDefinition;
14#[cfg(feature = "stdio")]
15use crate::stdio::StdioTransportAdapter;
16use crate::transport::{
17    McpServerConnectionConfig, McpTransport, McpTransportError, TransportTypeId,
18};
19
20/// Central hub for MCP tool routing across multiple servers.
21///
22/// The McpHub provides:
23/// - Connection management for multiple MCP servers
24/// - Tool discovery and caching
25/// - Automatic routing of tool calls to the correct server
26///
27/// # Example
28///
29/// ```rust,ignore
30/// use mcp::{McpHub, McpServerConnectionConfig};
31///
32/// let hub = McpHub::new();
33///
34/// // Connect to an external server
35/// let config = McpServerConnectionConfig::stdio("my-server", "node", vec!["server.js".into()]);
36/// hub.connect(config).await?;
37///
38/// // List all available tools
39/// let tools = hub.list_all_tools().await?;
40///
41/// // Call a tool (automatically routed to correct server)
42/// let result = hub.call_tool("my_tool", serde_json::json!({"arg": "value"})).await?;
43/// ```
44pub struct McpHub {
45    /// Server name → transport mapping
46    transports: Arc<RwLock<HashMap<String, Arc<dyn McpTransport>>>>,
47
48    /// Tool name → server name mapping for routing
49    tool_cache: Arc<RwLock<HashMap<String, String>>>,
50}
51
52impl Default for McpHub {
53    fn default() -> Self {
54        Self::new()
55    }
56}
57
58impl McpHub {
59    /// Create a new empty hub.
60    pub fn new() -> Self {
61        Self {
62            transports: Arc::new(RwLock::new(HashMap::new())),
63            tool_cache: Arc::new(RwLock::new(HashMap::new())),
64        }
65    }
66
67    /// Connect to an MCP server.
68    ///
69    /// This method:
70    /// 1. Creates the appropriate transport based on config
71    /// 2. Initializes the connection
72    /// 3. Discovers tools and caches the mapping
73    /// 4. Returns the transport for direct access if needed
74    pub async fn connect(
75        &self,
76        config: McpServerConnectionConfig,
77    ) -> Result<Arc<dyn McpTransport>, McpTransportError> {
78        let transport: Arc<dyn McpTransport> = match config.transport {
79            #[cfg(feature = "stdio")]
80            TransportTypeId::Stdio => {
81                let command = config.command.ok_or_else(|| {
82                    McpTransportError::TransportError(
83                        "Stdio transport requires command".to_string(),
84                    )
85                })?;
86
87                let transport = StdioTransportAdapter::connect_with_env(
88                    &command,
89                    &config.args,
90                    config.env,
91                    Some(config.config.clone()),
92                    Duration::from_secs(config.timeout_secs),
93                )
94                .await?;
95
96                Arc::new(transport)
97            }
98            #[cfg(not(feature = "stdio"))]
99            TransportTypeId::Stdio => {
100                return Err(McpTransportError::NotSupported(
101                    "Stdio transport not enabled. Enable the 'stdio' feature.".to_string(),
102                ));
103            }
104            #[cfg(feature = "http")]
105            TransportTypeId::Http | TransportTypeId::Sse => {
106                let url = config.url.ok_or_else(|| {
107                    McpTransportError::TransportError("HTTP transport requires URL".to_string())
108                })?;
109
110                let transport = HttpTransportAdapter::with_timeout(
111                    url,
112                    Duration::from_secs(config.timeout_secs),
113                )?;
114
115                Arc::new(transport)
116            }
117            #[cfg(not(feature = "http"))]
118            TransportTypeId::Http | TransportTypeId::Sse => {
119                return Err(McpTransportError::NotSupported(
120                    "HTTP transport not enabled. Enable the 'http' feature.".to_string(),
121                ));
122            }
123        };
124
125        // Discover tools and cache mappings
126        let tools = transport.list_tools().await?;
127
128        {
129            let mut cache = self.tool_cache.write().unwrap();
130            for tool in &tools {
131                cache.insert(tool.name.clone(), config.name.clone());
132            }
133            let mut transports = self.transports.write().unwrap();
134            transports.insert(config.name.clone(), transport.clone());
135        }
136
137        Ok(transport)
138    }
139
140    /// Call a tool, automatically routing to the correct server.
141    pub async fn call_tool(&self, name: &str, args: Value) -> Result<Value, McpTransportError> {
142        // Look up server for this tool
143        let server_name = self
144            .tool_cache
145            .read()
146            .unwrap()
147            .get(name)
148            .cloned()
149            .ok_or_else(|| McpTransportError::UnknownTool(name.to_string()))?;
150
151        // Get transport
152        let transport = self
153            .transports
154            .read()
155            .unwrap()
156            .get(&server_name)
157            .cloned()
158            .ok_or_else(|| McpTransportError::ServerNotFound(server_name.clone()))?;
159
160        // Forward call
161        transport.call_tool(name, args).await
162    }
163
164    /// List all tools from all connected servers.
165    pub async fn list_tools(&self) -> Result<Vec<(String, ToolDefinition)>, McpTransportError> {
166        let mut all_tools = Vec::new();
167
168        let transports = self.transports.read().unwrap().clone();
169        for (server_name, transport) in transports {
170            match transport.list_tools().await {
171                Ok(tools) => {
172                    let mut cache = self.tool_cache.write().unwrap();
173                    for tool in tools {
174                        cache.insert(tool.name.clone(), server_name.clone());
175                        all_tools.push((server_name.clone(), tool));
176                    }
177                }
178                Err(e) => {
179                    eprintln!(
180                        "Warning: Failed to list tools from '{}': {}",
181                        server_name, e
182                    );
183                }
184            }
185        }
186
187        Ok(all_tools)
188    }
189
190    /// Get all registered tools as a flat list.
191    pub async fn list_all_tools(&self) -> Result<Vec<ToolDefinition>, McpTransportError> {
192        let tools_with_servers = self.list_tools().await?;
193        Ok(tools_with_servers
194            .into_iter()
195            .map(|(_, tool)| tool)
196            .collect())
197    }
198
199    /// Populate the tool cache by querying all servers.
200    pub async fn refresh_tool_cache(&self) -> Result<(), McpTransportError> {
201        let _ = self.list_tools().await?;
202        Ok(())
203    }
204
205    /// Manually register a tool in the cache.
206    pub fn register_tool_sync(&self, tool_name: &str, server_name: &str) {
207        self.tool_cache
208            .write()
209            .unwrap()
210            .insert(tool_name.to_string(), server_name.to_string());
211    }
212
213    /// Shutdown all connected servers.
214    pub async fn shutdown_all(&self) -> Result<(), McpTransportError> {
215        let mut errors = Vec::new();
216
217        let transports = std::mem::take(&mut *self.transports.write().unwrap());
218        for (server_name, transport) in transports {
219            if let Err(e) = transport.shutdown().await {
220                errors.push(format!("{}: {}", server_name, e));
221            }
222        }
223        self.tool_cache.write().unwrap().clear();
224
225        if errors.is_empty() {
226            Ok(())
227        } else {
228            Err(McpTransportError::TransportError(errors.join("; ")))
229        }
230    }
231
232    /// Disconnect a specific server.
233    pub async fn disconnect(&self, server_name: &str) -> Result<(), McpTransportError> {
234        let transport = self
235            .transports
236            .write()
237            .unwrap()
238            .remove(server_name)
239            .ok_or_else(|| McpTransportError::ServerNotFound(server_name.to_string()))?;
240
241        // Remove tool cache entries for this server
242        self.tool_cache
243            .write()
244            .unwrap()
245            .retain(|_, server| server != server_name);
246
247        transport.shutdown().await
248    }
249
250    /// Get list of connected server names.
251    pub fn list_servers(&self) -> Vec<String> {
252        self.transports.read().unwrap().keys().cloned().collect()
253    }
254
255    /// Check if a server is connected.
256    pub fn is_connected(&self, server_name: &str) -> bool {
257        self.transports.read().unwrap().contains_key(server_name)
258    }
259
260    /// Get health status of all servers.
261    pub fn health_check(&self) -> Vec<(String, bool)> {
262        self.transports
263            .read()
264            .unwrap()
265            .iter()
266            .map(|(name, transport)| (name.clone(), transport.is_alive()))
267            .collect()
268    }
269
270    /// Get the server name that provides a specific tool.
271    pub fn server_for_tool(&self, tool_name: &str) -> Option<String> {
272        self.tool_cache.read().unwrap().get(tool_name).cloned()
273    }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_hub_creation() {
        let hub = McpHub::new();
        let servers = hub.list_servers();
        assert!(servers.is_empty());
    }

    #[test]
    fn test_default_is_empty() {
        let hub = McpHub::default();
        assert!(hub.list_servers().is_empty());
        assert!(hub.health_check().is_empty());
        assert!(!hub.is_connected("anything"));
    }

    #[tokio::test]
    async fn test_hub_unknown_tool() {
        let hub = McpHub::new();

        let result = hub
            .call_tool("nonexistent_tool", serde_json::json!({}))
            .await;
        assert!(matches!(result, Err(McpTransportError::UnknownTool(_))));
    }

    #[test]
    fn test_register_tool_sync_sets_and_overwrites_route() {
        let hub = McpHub::new();
        assert_eq!(hub.server_for_tool("echo"), None);

        hub.register_tool_sync("echo", "srv-a");
        assert_eq!(hub.server_for_tool("echo").as_deref(), Some("srv-a"));

        // A later registration replaces the route.
        hub.register_tool_sync("echo", "srv-b");
        assert_eq!(hub.server_for_tool("echo").as_deref(), Some("srv-b"));
    }

    #[tokio::test]
    async fn test_call_tool_routed_to_unconnected_server() {
        let hub = McpHub::new();
        hub.register_tool_sync("echo", "ghost");

        let result = hub.call_tool("echo", serde_json::json!({})).await;
        assert!(matches!(
            result,
            Err(McpTransportError::ServerNotFound(name)) if name == "ghost"
        ));
    }

    #[tokio::test]
    async fn test_disconnect_unknown_server() {
        let hub = McpHub::new();
        let result = hub.disconnect("ghost").await;
        assert!(matches!(result, Err(McpTransportError::ServerNotFound(_))));
    }

    #[tokio::test]
    async fn test_shutdown_all_on_empty_hub_clears_routes() {
        let hub = McpHub::new();
        hub.register_tool_sync("echo", "ghost");

        assert!(hub.shutdown_all().await.is_ok());
        assert_eq!(hub.server_for_tool("echo"), None);
    }

    #[tokio::test]
    async fn test_list_tools_on_empty_hub() {
        let hub = McpHub::new();
        assert!(matches!(hub.list_tools().await, Ok(ref tools) if tools.is_empty()));
        assert!(matches!(hub.list_all_tools().await, Ok(ref tools) if tools.is_empty()));
        assert!(hub.refresh_tool_cache().await.is_ok());
    }

    #[test]
    fn test_connection_config() {
        let config =
            McpServerConnectionConfig::stdio("test", "node", vec!["server.js".to_string()])
                .with_timeout(60);

        assert_eq!(config.name, "test");
        assert_eq!(config.timeout_secs, 60);
    }
}