model_context_protocol/
hub.rs1use serde_json::Value;
7use std::collections::HashMap;
8use std::sync::{Arc, RwLock};
9use std::time::Duration;
10
11#[cfg(feature = "http")]
12use crate::http::HttpTransportAdapter;
13use crate::protocol::ToolDefinition;
14#[cfg(feature = "stdio")]
15use crate::stdio::StdioTransportAdapter;
16use crate::transport::{
17 McpServerConnectionConfig, McpTransport, McpTransportError, TransportTypeId,
18};
19
20pub struct McpHub {
45 transports: Arc<RwLock<HashMap<String, Arc<dyn McpTransport>>>>,
47
48 tool_cache: Arc<RwLock<HashMap<String, String>>>,
50}
51
52impl Default for McpHub {
53 fn default() -> Self {
54 Self::new()
55 }
56}
57
58impl McpHub {
59 pub fn new() -> Self {
61 Self {
62 transports: Arc::new(RwLock::new(HashMap::new())),
63 tool_cache: Arc::new(RwLock::new(HashMap::new())),
64 }
65 }
66
67 pub async fn connect(
75 &self,
76 config: McpServerConnectionConfig,
77 ) -> Result<Arc<dyn McpTransport>, McpTransportError> {
78 let transport: Arc<dyn McpTransport> = match config.transport {
79 #[cfg(feature = "stdio")]
80 TransportTypeId::Stdio => {
81 let command = config.command.ok_or_else(|| {
82 McpTransportError::TransportError(
83 "Stdio transport requires command".to_string(),
84 )
85 })?;
86
87 let transport = StdioTransportAdapter::connect_with_env(
88 &command,
89 &config.args,
90 config.env,
91 Some(config.config.clone()),
92 Duration::from_secs(config.timeout_secs),
93 )
94 .await?;
95
96 Arc::new(transport)
97 }
98 #[cfg(not(feature = "stdio"))]
99 TransportTypeId::Stdio => {
100 return Err(McpTransportError::NotSupported(
101 "Stdio transport not enabled. Enable the 'stdio' feature.".to_string(),
102 ));
103 }
104 #[cfg(feature = "http")]
105 TransportTypeId::Http | TransportTypeId::Sse => {
106 let url = config.url.ok_or_else(|| {
107 McpTransportError::TransportError("HTTP transport requires URL".to_string())
108 })?;
109
110 let transport = HttpTransportAdapter::with_timeout(
111 url,
112 Duration::from_secs(config.timeout_secs),
113 )?;
114
115 Arc::new(transport)
116 }
117 #[cfg(not(feature = "http"))]
118 TransportTypeId::Http | TransportTypeId::Sse => {
119 return Err(McpTransportError::NotSupported(
120 "HTTP transport not enabled. Enable the 'http' feature.".to_string(),
121 ));
122 }
123 };
124
125 let tools = transport.list_tools().await?;
127
128 {
129 let mut cache = self.tool_cache.write().unwrap();
130 for tool in &tools {
131 cache.insert(tool.name.clone(), config.name.clone());
132 }
133 let mut transports = self.transports.write().unwrap();
134 transports.insert(config.name.clone(), transport.clone());
135 }
136
137 Ok(transport)
138 }
139
140 pub async fn call_tool(&self, name: &str, args: Value) -> Result<Value, McpTransportError> {
142 let server_name = self
144 .tool_cache
145 .read()
146 .unwrap()
147 .get(name)
148 .cloned()
149 .ok_or_else(|| McpTransportError::UnknownTool(name.to_string()))?;
150
151 let transport = self
153 .transports
154 .read()
155 .unwrap()
156 .get(&server_name)
157 .cloned()
158 .ok_or_else(|| McpTransportError::ServerNotFound(server_name.clone()))?;
159
160 transport.call_tool(name, args).await
162 }
163
164 pub async fn list_tools(&self) -> Result<Vec<(String, ToolDefinition)>, McpTransportError> {
166 let mut all_tools = Vec::new();
167
168 let transports = self.transports.read().unwrap().clone();
169 for (server_name, transport) in transports {
170 match transport.list_tools().await {
171 Ok(tools) => {
172 let mut cache = self.tool_cache.write().unwrap();
173 for tool in tools {
174 cache.insert(tool.name.clone(), server_name.clone());
175 all_tools.push((server_name.clone(), tool));
176 }
177 }
178 Err(e) => {
179 eprintln!(
180 "Warning: Failed to list tools from '{}': {}",
181 server_name, e
182 );
183 }
184 }
185 }
186
187 Ok(all_tools)
188 }
189
190 pub async fn list_all_tools(&self) -> Result<Vec<ToolDefinition>, McpTransportError> {
192 let tools_with_servers = self.list_tools().await?;
193 Ok(tools_with_servers
194 .into_iter()
195 .map(|(_, tool)| tool)
196 .collect())
197 }
198
199 pub async fn refresh_tool_cache(&self) -> Result<(), McpTransportError> {
201 let _ = self.list_tools().await?;
202 Ok(())
203 }
204
205 pub fn register_tool_sync(&self, tool_name: &str, server_name: &str) {
207 self.tool_cache
208 .write()
209 .unwrap()
210 .insert(tool_name.to_string(), server_name.to_string());
211 }
212
213 pub async fn shutdown_all(&self) -> Result<(), McpTransportError> {
215 let mut errors = Vec::new();
216
217 let transports = std::mem::take(&mut *self.transports.write().unwrap());
218 for (server_name, transport) in transports {
219 if let Err(e) = transport.shutdown().await {
220 errors.push(format!("{}: {}", server_name, e));
221 }
222 }
223 self.tool_cache.write().unwrap().clear();
224
225 if errors.is_empty() {
226 Ok(())
227 } else {
228 Err(McpTransportError::TransportError(errors.join("; ")))
229 }
230 }
231
232 pub async fn disconnect(&self, server_name: &str) -> Result<(), McpTransportError> {
234 let transport = self
235 .transports
236 .write()
237 .unwrap()
238 .remove(server_name)
239 .ok_or_else(|| McpTransportError::ServerNotFound(server_name.to_string()))?;
240
241 self.tool_cache
243 .write()
244 .unwrap()
245 .retain(|_, server| server != server_name);
246
247 transport.shutdown().await
248 }
249
250 pub fn list_servers(&self) -> Vec<String> {
252 self.transports.read().unwrap().keys().cloned().collect()
253 }
254
255 pub fn is_connected(&self, server_name: &str) -> bool {
257 self.transports.read().unwrap().contains_key(server_name)
258 }
259
260 pub fn health_check(&self) -> Vec<(String, bool)> {
262 self.transports
263 .read()
264 .unwrap()
265 .iter()
266 .map(|(name, transport)| (name.clone(), transport.is_alive()))
267 .collect()
268 }
269
270 pub fn server_for_tool(&self, tool_name: &str) -> Option<String> {
272 self.tool_cache.read().unwrap().get(tool_name).cloned()
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed hub has no registered servers.
    #[tokio::test]
    async fn test_hub_creation() {
        let registered = McpHub::new().list_servers();
        assert!(registered.is_empty());
    }

    /// Calling a tool that has no cached route fails with `UnknownTool`.
    #[tokio::test]
    async fn test_hub_unknown_tool() {
        let hub = McpHub::new();
        let empty_args = serde_json::json!({});

        let outcome = hub.call_tool("nonexistent_tool", empty_args).await;

        assert!(matches!(outcome, Err(McpTransportError::UnknownTool(_))));
    }

    /// The stdio config builder keeps the given name and applies the timeout
    /// set through `with_timeout`.
    #[test]
    fn test_connection_config() {
        let args = vec!["server.js".to_string()];
        let config = McpServerConnectionConfig::stdio("test", "node", args).with_timeout(60);

        assert_eq!(config.name, "test");
        assert_eq!(config.timeout_secs, 60);
    }
}