model_context_protocol/hub_common.rs

//! Common hub infrastructure shared between McpHub and McpServerHub.
//!
//! This module provides:
//! - Shared connection state management
//! - Circuit breaker integration
//! - Tool cache management
//! - Parallel tool discovery
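//!
//! A minimal usage sketch (the server config is assumed to be built by the
//! caller, and crate paths are elided, so the example is not compile-checked):
//!
//! ```ignore
//! use std::time::Duration;
//!
//! async fn demo(config: McpServerConnectionConfig) -> Result<(), McpTransportError> {
//!     let hub = HubConnections::new();
//!
//!     // Connect and cache the server's tools.
//!     hub.connect(config).await?;
//!
//!     // Re-discover tools from every connected server, 5s timeout per server.
//!     let tools = hub.discover_tools_parallel(Duration::from_secs(5)).await?;
//!     println!("discovered {} tools", tools.len());
//!
//!     // Route a call by tool name; the hub resolves the owning server.
//!     let result = hub.call_tool("some_tool", serde_json::json!({"arg": 1})).await?;
//!     println!("{result}");
//!     Ok(())
//! }
//! ```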

use dashmap::DashMap;
use futures::future::join_all;
use serde_json::Value;
use std::sync::atomic::{AtomicBool, AtomicU32};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{broadcast, Notify};

use crate::circuit_breaker::CircuitBreaker;
use crate::protocol::McpToolDefinition;
use crate::transport::{McpServerConnectionConfig, McpTransport, McpTransportError};
use crate::transport_factory::TransportFactory;

/// Connection state for a managed server.
pub struct ManagedConnection {
    /// Original configuration (used for restarts)
    pub config: McpServerConnectionConfig,
    /// Current transport (may be replaced on restart)
    pub transport: tokio::sync::RwLock<Option<Arc<dyn McpTransport>>>,
    /// Circuit breaker for resilient connections
    pub circuit_breaker: CircuitBreaker,
    /// Number of restart attempts
    pub restart_count: AtomicU32,
    /// Whether shutdown has been requested
    pub shutdown_requested: AtomicBool,
    /// Notifier for restart events
    pub restart_notify: Notify,
    /// Broadcast channel to notify pending requests of failure
    pub failure_tx: broadcast::Sender<()>,
}

impl ManagedConnection {
    /// Create a new managed connection.
    pub fn new(config: McpServerConnectionConfig) -> Self {
        let (failure_tx, _) = broadcast::channel(16);
        Self {
            config,
            transport: tokio::sync::RwLock::new(None),
            circuit_breaker: CircuitBreaker::new(),
            restart_count: AtomicU32::new(0),
            shutdown_requested: AtomicBool::new(false),
            restart_notify: Notify::new(),
            failure_tx,
        }
    }

    /// Check if the connection is alive.
    pub async fn is_alive(&self) -> bool {
        if let Some(transport) = self.transport.read().await.as_ref() {
            transport.is_alive()
        } else {
            false
        }
    }

    /// Get the transport if available.
    pub async fn get_transport(&self) -> Option<Arc<dyn McpTransport>> {
        self.transport.read().await.clone()
    }

    /// Subscribe to failure notifications.
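    ///
    /// Receivers typically race this against an in-flight request so callers
    /// do not hang on a dead transport; `HubConnections::call_tool` below
    /// uses exactly that pattern.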
    pub fn subscribe_failures(&self) -> broadcast::Receiver<()> {
        self.failure_tx.subscribe()
    }

    /// Notify pending requests of failure.
    pub fn notify_failure(&self) {
        let _ = self.failure_tx.send(());
    }
}

/// Managed hub connections with shared infrastructure.
pub struct HubConnections {
    /// Server name → connection mapping
    connections: DashMap<String, Arc<ManagedConnection>>,
    /// Tool name → (server name, optional tool definition)
    tool_cache: DashMap<String, (String, Option<McpToolDefinition>)>,
}

impl Default for HubConnections {
    fn default() -> Self {
        Self::new()
    }
}

impl HubConnections {
    /// Create a new connection manager.
    pub fn new() -> Self {
        Self {
            connections: DashMap::new(),
            tool_cache: DashMap::new(),
        }
    }

    /// Establish a connection to a server.
    pub async fn connect(
        &self,
        config: McpServerConnectionConfig,
    ) -> Result<Arc<ManagedConnection>, McpTransportError> {
        let server_name = config.name.clone();
        let connection = Arc::new(ManagedConnection::new(config));

        // Establish initial connection
        self.establish_connection(&connection).await?;

        // Store connection
        self.connections
            .insert(server_name, Arc::clone(&connection));

        Ok(connection)
    }

    /// Establish or re-establish a connection.
    pub async fn establish_connection(
        &self,
        conn: &ManagedConnection,
    ) -> Result<(), McpTransportError> {
        let config = &conn.config;
        let server_name = config.name.clone();

        // Use transport factory for unified transport creation
        let transport = TransportFactory::create(config).await?;

        // Discover the server's tools
        let tools = transport.list_tools().await?;

        // Remove old tools for this server, then add the new ones
        self.tool_cache.retain(|_, (srv, _)| srv != &server_name);
        for tool in tools {
            self.tool_cache
                .insert(tool.name.clone(), (server_name.clone(), Some(tool)));
        }

        // Store transport
        *conn.transport.write().await = Some(transport);

        Ok(())
    }

    /// Get a connection by server name.
    pub fn get(&self, server_name: &str) -> Option<Arc<ManagedConnection>> {
        self.connections.get(server_name).map(|r| r.value().clone())
    }

    /// Remove a connection.
    pub fn remove(&self, server_name: &str) -> Option<Arc<ManagedConnection>> {
        self.connections.remove(server_name).map(|(_, v)| v)
    }

    /// Get server name for a tool.
    pub fn server_for_tool(&self, tool_name: &str) -> Option<String> {
        self.tool_cache.get(tool_name).map(|r| r.value().0.clone())
    }

    /// Get tool definition by name.
    pub fn get_tool_definition(&self, tool_name: &str) -> Option<McpToolDefinition> {
        self.tool_cache
            .get(tool_name)
            .and_then(|r| r.value().1.clone())
    }

    /// List all server names.
    pub fn list_servers(&self) -> Vec<String> {
        self.connections.iter().map(|r| r.key().clone()).collect()
    }

    /// List all tools with their server names.
    pub fn list_tools(&self) -> Vec<(String, McpToolDefinition)> {
        self.tool_cache
            .iter()
            .filter_map(|r| r.value().1.clone().map(|def| (r.value().0.clone(), def)))
            .collect()
    }

    /// List all tool definitions.
    pub fn list_tool_definitions(&self) -> Vec<McpToolDefinition> {
        self.tool_cache
            .iter()
            .filter_map(|r| r.value().1.clone())
            .collect()
    }

    /// Check if a server is connected.
    pub fn is_connected(&self, server_name: &str) -> bool {
        self.connections.contains_key(server_name)
    }

    /// Clear tool cache for a server.
    pub fn clear_tools_for_server(&self, server_name: &str) {
        self.tool_cache.retain(|_, (srv, _)| srv != server_name);
    }

    /// Clear all connections and tool cache.
    pub fn clear(&self) {
        self.connections.clear();
        self.tool_cache.clear();
    }

    /// Iterate over all connections.
    pub fn iter(&self) -> impl Iterator<Item = (String, Arc<ManagedConnection>)> + '_ {
        self.connections
            .iter()
            .map(|r| (r.key().clone(), r.value().clone()))
    }

    /// Call a tool with circuit breaker and failure handling.
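    ///
    /// Errors with `UnknownTool` if no server advertises the tool,
    /// `ServerNotFound` if the owning connection has been removed,
    /// `ServerError` while the circuit breaker is open, and
    /// `ServerRestarting` if a failure notification arrives mid-call.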
    pub async fn call_tool(&self, name: &str, args: Value) -> Result<Value, McpTransportError> {
        let server_name = self
            .server_for_tool(name)
            .ok_or_else(|| McpTransportError::UnknownTool(name.to_string()))?;

        let connection = self
            .get(&server_name)
            .ok_or_else(|| McpTransportError::ServerNotFound(server_name.clone()))?;

        // Check circuit breaker
        if !connection.circuit_breaker.allow_request() {
            return Err(McpTransportError::ServerError(format!(
                "Server '{}' circuit breaker is open - server is unhealthy",
                server_name
            )));
        }

        // Subscribe to failure notifications before getting transport
        let mut failure_rx = connection.subscribe_failures();

        let transport = connection
            .get_transport()
            .await
            .ok_or(McpTransportError::ConnectionClosed)?;

        // Race between the actual tool call and a failure notification
        let result = tokio::select! {
            result = transport.call_tool(name, args) => result,
            _ = failure_rx.recv() => {
                Err(McpTransportError::ServerRestarting(server_name.clone()))
            }
        };

        // Record result in circuit breaker
        match &result {
            Ok(_) => connection.circuit_breaker.record_success(),
            Err(_) => connection.circuit_breaker.record_failure(),
        }

        result
    }

    /// Discover tools from all servers in parallel.
    ///
    /// This is much faster than sequential discovery when connecting to many servers.
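    ///
    /// Per-server failures (including timeouts) are logged to stderr and
    /// skipped, so the call succeeds even when individual servers fail.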
    pub async fn discover_tools_parallel(
        &self,
        timeout: Duration,
    ) -> Result<Vec<(String, McpToolDefinition)>, McpTransportError> {
        let connections: Vec<_> = self.iter().collect();

        // Create futures for each server's tool discovery
        let futures: Vec<_> = connections
            .into_iter()
            .map(|(server_name, conn)| async move {
                let result = tokio::time::timeout(timeout, async {
                    if let Some(transport) = conn.get_transport().await {
                        transport.list_tools().await
                    } else {
                        Err(McpTransportError::ConnectionClosed)
                    }
                })
                .await;

                match result {
                    Ok(Ok(tools)) => (server_name, Ok(tools)),
                    Ok(Err(e)) => (server_name, Err(e)),
                    Err(_) => (
                        server_name.clone(),
                        Err(McpTransportError::Timeout(format!(
                            "Tool discovery for '{}' timed out",
                            server_name
                        ))),
                    ),
                }
            })
            .collect();

        // Run all discoveries in parallel
        let results = join_all(futures).await;

        // Collect results and update cache
        let mut all_tools = Vec::new();

        for (server_name, result) in results {
            match result {
                Ok(tools) => {
                    for tool in tools {
                        self.tool_cache
                            .insert(tool.name.clone(), (server_name.clone(), Some(tool.clone())));
                        all_tools.push((server_name.clone(), tool));
                    }
                }
                Err(e) => {
                    eprintln!(
                        "Warning: Failed to discover tools from '{}': {}",
                        server_name, e
                    );
                }
            }
        }

        Ok(all_tools)
    }

    /// Refresh tool cache from all servers in parallel.
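    ///
    /// Note: the cache is empty between the clear and rediscovery, so
    /// concurrent lookups during that window will miss.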
    pub async fn refresh_tools_parallel(&self, timeout: Duration) -> Result<(), McpTransportError> {
        // Clear existing cache
        self.tool_cache.clear();
        // Discover tools
        let _ = self.discover_tools_parallel(timeout).await?;
        Ok(())
    }

    /// Get health status of all servers.
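    ///
    /// A server counts as healthy only when its transport is alive and its
    /// circuit breaker is currently admitting requests.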
    pub async fn health_check(&self) -> Vec<(String, bool)> {
        let connections: Vec<_> = self.iter().collect();
        let mut results = Vec::new();

        for (name, conn) in connections {
            let transport_alive = conn.is_alive().await;
            let circuit_ok = conn.circuit_breaker.allow_request();
            results.push((name, transport_alive && circuit_ok));
        }

        results
    }

    /// Get circuit breaker statistics for a server.
    pub fn circuit_breaker_stats(
        &self,
        server_name: &str,
    ) -> Option<crate::circuit_breaker::CircuitBreakerStats> {
        self.get(server_name).map(|c| c.circuit_breaker.stats())
    }

    /// Reset circuit breaker for a server.
    pub fn reset_circuit_breaker(&self, server_name: &str) {
        if let Some(conn) = self.get(server_name) {
            conn.circuit_breaker.reset();
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::transport::RestartPolicy;

    #[test]
    fn test_restart_policy_delay() {
        let policy = RestartPolicy {
            enabled: true,
            max_attempts: Some(5),
            delay_ms: 1000,
            max_delay_ms: 30_000,
            backoff_multiplier: 2.0,
        };

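        // Expected: delay = min(delay_ms * backoff_multiplier^attempt, max_delay_ms)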
        assert_eq!(policy.delay_for_attempt(0), 1000);
        assert_eq!(policy.delay_for_attempt(1), 2000);
        assert_eq!(policy.delay_for_attempt(2), 4000);
        assert_eq!(policy.delay_for_attempt(5), 30_000); // 32_000 capped to max_delay_ms
    }

    #[test]
    fn test_hub_connections_creation() {
        let conns = HubConnections::new();
        assert!(conns.list_servers().is_empty());
        assert!(conns.list_tool_definitions().is_empty());
    }
}