1use dashmap::DashMap;
10use futures::future::join_all;
11use serde_json::Value;
12use std::sync::atomic::{AtomicBool, AtomicU32};
13use std::sync::Arc;
14use std::time::Duration;
15use tokio::sync::{broadcast, Notify};
16
17use crate::circuit_breaker::CircuitBreaker;
18use crate::protocol::McpToolDefinition;
19use crate::transport::{McpServerConnectionConfig, McpTransport, McpTransportError};
20use crate::transport_factory::TransportFactory;
21
/// One server connection plus the shared state used to supervise it.
pub struct ManagedConnection {
    /// Configuration this connection was created from.
    pub config: McpServerConnectionConfig,
    /// Currently installed transport; `None` until one has been stored.
    pub transport: tokio::sync::RwLock<Option<Arc<dyn McpTransport>>>,
    /// Circuit breaker tracking the health of this server.
    pub circuit_breaker: CircuitBreaker,
    /// Starts at 0. Presumably counts restarts of this server; it is updated
    /// outside this section — TODO confirm.
    pub restart_count: AtomicU32,
    /// Starts as `false`. Presumably set when the connection should stop being
    /// restarted; it is updated outside this section — TODO confirm.
    pub shutdown_requested: AtomicBool,
    /// Notification handle; presumably used to wake a restart task — TODO confirm.
    pub restart_notify: Notify,
    /// Sending half of the broadcast channel on which failures are announced.
    pub failure_tx: broadcast::Sender<()>,
}
39
40impl ManagedConnection {
41 pub fn new(config: McpServerConnectionConfig) -> Self {
43 let (failure_tx, _) = broadcast::channel(16);
44 Self {
45 config,
46 transport: tokio::sync::RwLock::new(None),
47 circuit_breaker: CircuitBreaker::new(),
48 restart_count: AtomicU32::new(0),
49 shutdown_requested: AtomicBool::new(false),
50 restart_notify: Notify::new(),
51 failure_tx,
52 }
53 }
54
55 pub async fn is_alive(&self) -> bool {
57 if let Some(transport) = self.transport.read().await.as_ref() {
58 transport.is_alive()
59 } else {
60 false
61 }
62 }
63
64 pub async fn get_transport(&self) -> Option<Arc<dyn McpTransport>> {
66 self.transport.read().await.clone()
67 }
68
69 pub fn subscribe_failures(&self) -> broadcast::Receiver<()> {
71 self.failure_tx.subscribe()
72 }
73
74 pub fn notify_failure(&self) {
76 let _ = self.failure_tx.send(());
77 }
78}
79
/// Registry of server connections and of the tools those servers expose.
pub struct HubConnections {
    /// Connections keyed by server name (the `name` field of the connection config).
    connections: DashMap<String, Arc<ManagedConnection>>,
    /// Tool name -> (owning server name, tool definition if known).
    tool_cache: DashMap<String, (String, Option<McpToolDefinition>)>,
}
87
88impl Default for HubConnections {
89 fn default() -> Self {
90 Self::new()
91 }
92}
93
94impl HubConnections {
95 pub fn new() -> Self {
97 Self {
98 connections: DashMap::new(),
99 tool_cache: DashMap::new(),
100 }
101 }
102
103 pub async fn connect(
105 &self,
106 config: McpServerConnectionConfig,
107 ) -> Result<Arc<ManagedConnection>, McpTransportError> {
108 let server_name = config.name.clone();
109 let connection = Arc::new(ManagedConnection::new(config));
110
111 self.establish_connection(&connection).await?;
113
114 self.connections
116 .insert(server_name, Arc::clone(&connection));
117
118 Ok(connection)
119 }
120
121 pub async fn establish_connection(
123 &self,
124 conn: &ManagedConnection,
125 ) -> Result<(), McpTransportError> {
126 let config = &conn.config;
127 let server_name = config.name.clone();
128
129 let transport = TransportFactory::create(config).await?;
131
132 let tools = transport.list_tools().await?;
134
135 self.tool_cache.retain(|_, (srv, _)| srv != &server_name);
137 for tool in tools {
138 self.tool_cache
139 .insert(tool.name.clone(), (server_name.clone(), Some(tool)));
140 }
141
142 *conn.transport.write().await = Some(transport);
144
145 Ok(())
146 }
147
148 pub fn get(&self, server_name: &str) -> Option<Arc<ManagedConnection>> {
150 self.connections.get(server_name).map(|r| r.value().clone())
151 }
152
153 pub fn remove(&self, server_name: &str) -> Option<Arc<ManagedConnection>> {
155 self.connections.remove(server_name).map(|(_, v)| v)
156 }
157
158 pub fn server_for_tool(&self, tool_name: &str) -> Option<String> {
160 self.tool_cache.get(tool_name).map(|r| r.value().0.clone())
161 }
162
163 pub fn get_tool_definition(&self, tool_name: &str) -> Option<McpToolDefinition> {
165 self.tool_cache
166 .get(tool_name)
167 .and_then(|r| r.value().1.clone())
168 }
169
170 pub fn list_servers(&self) -> Vec<String> {
172 self.connections.iter().map(|r| r.key().clone()).collect()
173 }
174
175 pub fn list_tools(&self) -> Vec<(String, McpToolDefinition)> {
177 self.tool_cache
178 .iter()
179 .filter_map(|r| r.value().1.clone().map(|def| (r.value().0.clone(), def)))
180 .collect()
181 }
182
183 pub fn list_tool_definitions(&self) -> Vec<McpToolDefinition> {
185 self.tool_cache
186 .iter()
187 .filter_map(|r| r.value().1.clone())
188 .collect()
189 }
190
191 pub fn is_connected(&self, server_name: &str) -> bool {
193 self.connections.contains_key(server_name)
194 }
195
196 pub fn clear_tools_for_server(&self, server_name: &str) {
198 self.tool_cache.retain(|_, (srv, _)| srv != server_name);
199 }
200
201 pub fn clear(&self) {
203 self.connections.clear();
204 self.tool_cache.clear();
205 }
206
207 pub fn iter(&self) -> impl Iterator<Item = (String, Arc<ManagedConnection>)> + '_ {
209 self.connections
210 .iter()
211 .map(|r| (r.key().clone(), r.value().clone()))
212 }
213
214 pub async fn call_tool(&self, name: &str, args: Value) -> Result<Value, McpTransportError> {
216 let server_name = self
217 .server_for_tool(name)
218 .ok_or_else(|| McpTransportError::UnknownTool(name.to_string()))?;
219
220 let connection = self
221 .get(&server_name)
222 .ok_or_else(|| McpTransportError::ServerNotFound(server_name.clone()))?;
223
224 if !connection.circuit_breaker.allow_request() {
226 return Err(McpTransportError::ServerError(format!(
227 "Server '{}' circuit breaker is open - server is unhealthy",
228 server_name
229 )));
230 }
231
232 let mut failure_rx = connection.subscribe_failures();
234
235 let transport = connection
236 .get_transport()
237 .await
238 .ok_or(McpTransportError::ConnectionClosed)?;
239
240 let result = tokio::select! {
242 result = transport.call_tool(name, args) => result,
243 _ = failure_rx.recv() => {
244 Err(McpTransportError::ServerRestarting(server_name.clone()))
245 }
246 };
247
248 match &result {
250 Ok(_) => connection.circuit_breaker.record_success(),
251 Err(_) => connection.circuit_breaker.record_failure(),
252 }
253
254 result
255 }
256
257 pub async fn discover_tools_parallel(
261 &self,
262 timeout: Duration,
263 ) -> Result<Vec<(String, McpToolDefinition)>, McpTransportError> {
264 let connections: Vec<_> = self.iter().collect();
265
266 let futures: Vec<_> = connections
268 .into_iter()
269 .map(|(server_name, conn)| {
270 let server_name = server_name.clone();
271 async move {
272 let result = tokio::time::timeout(timeout, async {
273 if let Some(transport) = conn.get_transport().await {
274 transport.list_tools().await
275 } else {
276 Err(McpTransportError::ConnectionClosed)
277 }
278 })
279 .await;
280
281 match result {
282 Ok(Ok(tools)) => (server_name, Ok(tools)),
283 Ok(Err(e)) => (server_name, Err(e)),
284 Err(_) => (
285 server_name.clone(),
286 Err(McpTransportError::Timeout(format!(
287 "Tool discovery for '{}' timed out",
288 server_name
289 ))),
290 ),
291 }
292 }
293 })
294 .collect();
295
296 let results = join_all(futures).await;
298
299 let mut all_tools = Vec::new();
301
302 for (server_name, result) in results {
303 match result {
304 Ok(tools) => {
305 for tool in tools {
306 self.tool_cache
307 .insert(tool.name.clone(), (server_name.clone(), Some(tool.clone())));
308 all_tools.push((server_name.clone(), tool));
309 }
310 }
311 Err(e) => {
312 eprintln!(
313 "Warning: Failed to discover tools from '{}': {}",
314 server_name, e
315 );
316 }
317 }
318 }
319
320 Ok(all_tools)
321 }
322
323 pub async fn refresh_tools_parallel(&self, timeout: Duration) -> Result<(), McpTransportError> {
325 self.tool_cache.clear();
327 let _ = self.discover_tools_parallel(timeout).await?;
329 Ok(())
330 }
331
332 pub async fn health_check(&self) -> Vec<(String, bool)> {
334 let connections: Vec<_> = self.iter().collect();
335 let mut results = Vec::new();
336
337 for (name, conn) in connections {
338 let transport_alive = conn.is_alive().await;
339 let circuit_ok = conn.circuit_breaker.allow_request();
340 results.push((name, transport_alive && circuit_ok));
341 }
342
343 results
344 }
345
346 pub fn circuit_breaker_stats(
348 &self,
349 server_name: &str,
350 ) -> Option<crate::circuit_breaker::CircuitBreakerStats> {
351 self.get(server_name).map(|c| c.circuit_breaker.stats())
352 }
353
354 pub fn reset_circuit_breaker(&self, server_name: &str) {
356 if let Some(conn) = self.get(server_name) {
357 conn.circuit_breaker.reset();
358 }
359 }
360}
361
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transport::RestartPolicy;

    /// The restart delay doubles per attempt starting from `delay_ms` and is
    /// capped at `max_delay_ms`.
    #[test]
    fn test_restart_policy_delay() {
        let policy = RestartPolicy {
            enabled: true,
            max_attempts: Some(5),
            delay_ms: 1000,
            max_delay_ms: 30_000,
            backoff_multiplier: 2.0,
        };

        // (attempt, expected delay in ms); the last row hits the cap.
        let expectations = [(0, 1000), (1, 2000), (2, 4000), (5, 30_000)];
        for &(attempt, expected_delay) in expectations.iter() {
            assert_eq!(policy.delay_for_attempt(attempt), expected_delay);
        }
    }

    /// A freshly created hub has no servers and no cached tool definitions.
    #[test]
    fn test_hub_connections_creation() {
        let hub = HubConnections::new();
        assert!(hub.list_servers().is_empty());
        assert!(hub.list_tool_definitions().is_empty());
    }
}