model_context_protocol/
server_hub.rs1use serde_json::Value;
39use std::sync::atomic::Ordering;
40use std::sync::Arc;
41use std::time::Duration;
42
43use crate::hub_common::HubConnections;
44use crate::protocol::McpToolDefinition;
45use crate::server::McpServerConfig;
46use crate::tool::{BoxFuture, DynTool, McpTool, ToolCallResult, ToolProvider};
47use crate::transport::{McpServerConnectionConfig, McpTransportError};
48
49pub struct McpServerHub {
61 name: String,
63 connections: HubConnections,
65 timeout: Duration,
67}
68
69impl McpServerHub {
70 pub fn new(name: impl Into<String>) -> Self {
72 Self {
73 name: name.into(),
74 connections: HubConnections::new(),
75 timeout: Duration::from_secs(30),
76 }
77 }
78
79 pub fn with_timeout(name: impl Into<String>, timeout: Duration) -> Self {
81 Self {
82 name: name.into(),
83 connections: HubConnections::new(),
84 timeout,
85 }
86 }
87
88 pub async fn connect(
96 self: &Arc<Self>,
97 config: McpServerConnectionConfig,
98 ) -> Result<(), McpTransportError> {
99 let server_name = config.name.clone();
100 let restart_enabled = config.restart_policy.enabled;
101
102 let connection = self.connections.connect(config).await?;
104
105 if restart_enabled {
107 let hub = Arc::clone(self);
108 let conn = Arc::clone(&connection);
109 let name = server_name.clone();
110
111 tokio::spawn(async move {
112 hub.restart_monitor(name, conn).await;
113 });
114 }
115
116 Ok(())
117 }
118
119 async fn restart_monitor(&self, name: String, conn: Arc<crate::hub_common::ManagedConnection>) {
121 let policy = &conn.config.restart_policy;
122
123 loop {
124 tokio::select! {
126 _ = conn.restart_notify.notified() => {}
127 _ = tokio::time::sleep(Duration::from_secs(5)) => {
128 if conn.is_alive().await {
129 continue;
130 }
131 }
132 }
133
134 if conn.shutdown_requested.load(Ordering::SeqCst) {
136 break;
137 }
138
139 if conn.is_alive().await {
141 continue;
142 }
143
144 conn.notify_failure();
146
147 let attempt = conn.restart_count.fetch_add(1, Ordering::SeqCst);
149
150 if let Some(max) = policy.max_attempts {
152 if attempt >= max {
153 eprintln!(
154 "[McpServerHub] Server '{}' exceeded max restart attempts ({})",
155 name, max
156 );
157 break;
158 }
159 }
160
161 let delay = policy.delay_for_attempt(attempt);
163
164 eprintln!(
165 "[McpServerHub] Server '{}' disconnected. Restarting in {}ms (attempt {}/{})",
166 name,
167 delay,
168 attempt + 1,
169 policy
170 .max_attempts
171 .map(|m| m.to_string())
172 .unwrap_or_else(|| "∞".into())
173 );
174
175 tokio::time::sleep(Duration::from_millis(delay)).await;
176
177 if conn.shutdown_requested.load(Ordering::SeqCst) {
179 break;
180 }
181
182 match self.connections.establish_connection(&conn).await {
184 Ok(_) => {
185 eprintln!("[McpServerHub] Server '{}' reconnected successfully", name);
186 conn.restart_count.store(0, Ordering::SeqCst);
187 }
188 Err(e) => {
189 eprintln!(
190 "[McpServerHub] Server '{}' failed to reconnect: {}",
191 name, e
192 );
193 }
194 }
195 }
196 }
197
198 pub fn trigger_restart(&self, server_name: &str) {
200 if let Some(conn) = self.connections.get(server_name) {
201 conn.restart_notify.notify_one();
202 }
203 }
204
205 pub async fn call_tool(&self, name: &str, args: Value) -> Result<Value, McpTransportError> {
211 self.connections.call_tool(name, args).await
212 }
213
214 pub async fn list_tools(&self) -> Result<Vec<(String, McpToolDefinition)>, McpTransportError> {
216 Ok(self.connections.list_tools())
217 }
218
219 pub async fn list_all_tools(&self) -> Result<Vec<McpToolDefinition>, McpTransportError> {
221 Ok(self.connections.list_tool_definitions())
222 }
223
224 pub async fn discover_tools_parallel(
226 &self,
227 ) -> Result<Vec<(String, McpToolDefinition)>, McpTransportError> {
228 self.connections.discover_tools_parallel(self.timeout).await
229 }
230
231 pub async fn refresh_tools(&self) -> Result<(), McpTransportError> {
233 self.connections.refresh_tools_parallel(self.timeout).await
234 }
235
236 pub fn list_servers(&self) -> Vec<String> {
238 self.connections.list_servers()
239 }
240
241 pub fn is_connected(&self, server_name: &str) -> bool {
243 self.connections.is_connected(server_name)
244 }
245
246 pub async fn is_alive(&self, server_name: &str) -> bool {
248 if let Some(conn) = self.connections.get(server_name) {
249 conn.is_alive().await
250 } else {
251 false
252 }
253 }
254
255 pub async fn health_check(&self) -> Vec<(String, bool)> {
257 self.connections.health_check().await
258 }
259
260 pub fn server_for_tool(&self, tool_name: &str) -> Option<String> {
262 self.connections.server_for_tool(tool_name)
263 }
264
265 pub async fn disconnect(&self, server_name: &str) -> Result<(), McpTransportError> {
267 let connection = self
268 .connections
269 .remove(server_name)
270 .ok_or_else(|| McpTransportError::ServerNotFound(server_name.to_string()))?;
271
272 connection.shutdown_requested.store(true, Ordering::SeqCst);
274 connection.restart_notify.notify_one();
275
276 self.connections.clear_tools_for_server(server_name);
278
279 if let Some(transport) = connection.get_transport().await {
281 transport.shutdown().await?;
282 }
283
284 Ok(())
285 }
286
287 pub async fn shutdown_all(&self) -> Result<(), McpTransportError> {
289 let names: Vec<String> = self.list_servers();
290 let mut errors = Vec::new();
291
292 for name in names {
293 if let Err(e) = self.disconnect(&name).await {
294 errors.push(format!("{}: {}", name, e));
295 }
296 }
297
298 if errors.is_empty() {
299 Ok(())
300 } else {
301 Err(McpTransportError::TransportError(errors.join("; ")))
302 }
303 }
304
305 pub fn into_config(self, version: &str) -> McpServerConfig {
310 let hub = Arc::new(self);
311 let provider = HubToolProvider {
312 hub: Arc::clone(&hub),
313 };
314
315 McpServerConfig::builder()
316 .name(&hub.name)
317 .version(version)
318 .with_tools_from(provider)
319 .build()
320 }
321
322 pub fn to_config(self: &Arc<Self>, version: &str) -> McpServerConfig {
326 let provider = HubToolProvider {
327 hub: Arc::clone(self),
328 };
329
330 McpServerConfig::builder()
331 .name(&self.name)
332 .version(version)
333 .with_tools_from(provider)
334 .build()
335 }
336
337 pub fn proxy_tools(self: &Arc<Self>) -> Vec<DynTool> {
349 let provider = HubToolProvider {
350 hub: Arc::clone(self),
351 };
352 provider.tools()
353 }
354
355 pub fn circuit_breaker_stats(
357 &self,
358 server_name: &str,
359 ) -> Option<crate::circuit_breaker::CircuitBreakerStats> {
360 self.connections.circuit_breaker_stats(server_name)
361 }
362
363 pub fn reset_circuit_breaker(&self, server_name: &str) {
365 self.connections.reset_circuit_breaker(server_name);
366 }
367}
368
369struct HubToolProvider {
371 hub: Arc<McpServerHub>,
372}
373
374impl ToolProvider for HubToolProvider {
375 fn tools(&self) -> Vec<DynTool> {
376 self.hub
377 .connections
378 .list_tools()
379 .into_iter()
380 .map(|(_, def)| {
381 let tool: DynTool = Arc::new(ProxyTool {
382 name: def.name.clone(),
383 definition: def,
384 hub: Arc::clone(&self.hub),
385 });
386 tool
387 })
388 .collect()
389 }
390}
391
392struct ProxyTool {
394 name: String,
395 definition: McpToolDefinition,
396 hub: Arc<McpServerHub>,
397}
398
399impl McpTool for ProxyTool {
400 fn definition(&self) -> McpToolDefinition {
401 self.definition.clone()
402 }
403
404 fn call<'a>(&'a self, args: Value) -> BoxFuture<'a, ToolCallResult> {
405 let name = self.name.clone();
406 let hub = Arc::clone(&self.hub);
407
408 Box::pin(async move {
409 match hub.call_tool(&name, args).await {
410 Ok(value) => {
411 if let Some(s) = value.as_str() {
413 Ok(vec![crate::protocol::ToolContent::text(s)])
414 } else {
415 Ok(vec![crate::protocol::ToolContent::text(value.to_string())])
416 }
417 }
418 Err(e) => Err(e.to_string()),
419 }
420 })
421 }
422}
423
#[cfg(test)]
mod tests {
    use super::*;

    /// A new hub keeps its name and starts with no servers.
    #[test]
    fn test_hub_creation() {
        let hub = McpServerHub::new("test-hub");
        assert_eq!(hub.name, "test-hub");
        assert!(hub.list_servers().is_empty());
    }

    /// `new` uses the 30-second default; `with_timeout` stores the supplied value.
    #[test]
    fn test_hub_timeouts() {
        assert_eq!(McpServerHub::new("a").timeout, Duration::from_secs(30));

        let custom = McpServerHub::with_timeout("b", Duration::from_millis(250));
        assert_eq!(custom.name, "b");
        assert_eq!(custom.timeout, Duration::from_millis(250));
    }

    /// Consuming the hub into a config carries over its name and the given version.
    #[tokio::test]
    async fn test_hub_into_config() {
        let hub = McpServerHub::new("test-hub");
        let config = hub.into_config("1.0.0");
        assert_eq!(config.name(), "test-hub");
        assert_eq!(config.version(), "1.0.0");
    }

    /// Building a config from a shared hub carries over its name and the given version.
    #[tokio::test]
    async fn test_hub_to_config() {
        let hub = Arc::new(McpServerHub::new("shared-hub"));
        let config = hub.to_config("2.0.0");
        assert_eq!(config.name(), "shared-hub");
        assert_eq!(config.version(), "2.0.0");
    }

    /// Calling a tool no server provides yields `UnknownTool`.
    #[tokio::test]
    async fn test_hub_unknown_tool() {
        let hub = McpServerHub::new("test");
        let result = hub.call_tool("nonexistent", serde_json::json!({})).await;
        assert!(matches!(result, Err(McpTransportError::UnknownTool(_))));
    }

    /// Server-scoped operations on a name the hub never connected to behave gracefully.
    #[tokio::test]
    async fn test_hub_unknown_server() {
        let hub = McpServerHub::new("test");

        assert!(!hub.is_alive("missing").await);

        // Must be a silent no-op for unknown servers.
        hub.trigger_restart("missing");

        let result = hub.disconnect("missing").await;
        assert!(matches!(
            &result,
            Err(McpTransportError::ServerNotFound(name)) if name == "missing"
        ));

        assert!(hub.shutdown_all().await.is_ok());
    }
}