1use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::Duration;
10
11use tokio::sync::{RwLock, watch};
12use tracing::{debug, info, warn};
13
14use roboticus_core::config::{McpServerConfig, McpServerSpec, McpTransport};
15
16use super::bridge::bridge_tools;
17use super::client::{LiveMcpConnection, McpClientError};
18use crate::capability::{Capability, CapabilityRegistry};
19
/// Serializable snapshot of one tracked MCP server, as produced by
/// `McpConnectionManager::server_statuses`.
#[derive(Debug, Clone, serde::Serialize)]
pub struct McpServerStatus {
    /// Key under which the manager tracks the server (the configured server name).
    pub name: String,
    /// Whether the connection reported itself alive when the snapshot was taken.
    pub connected: bool,
    /// Number of tools the connection exposed when the snapshot was taken.
    pub tool_count: usize,
    /// Name reported by the connection (`LiveMcpConnection::server_name`).
    pub server_name: String,
    /// Version reported by the connection (`LiveMcpConnection::server_version`).
    pub server_version: String,
}
31
/// Bookkeeping the manager stores for each tracked server.
struct ServerEntry {
    /// Shared handle to the live connection; handed out by `get_connection`
    /// and passed to the bridged tool capabilities.
    connection: Arc<RwLock<LiveMcpConnection>>,
    /// Configuration the server was connected with; the health-check loop
    /// reuses it to reconnect a server whose connection is no longer alive.
    config: McpServerConfig,
}
41
/// Tracks live MCP server connections by name and owns the cancellation
/// signal observed by the health-check loop.
pub struct McpConnectionManager {
    /// Tracked servers, keyed by the configured server name.
    servers: RwLock<HashMap<String, ServerEntry>>,
    /// Sending half of the cancellation flag; `cancel` sets it to `true`.
    cancel_tx: watch::Sender<bool>,
    /// Receiving half of the cancellation flag; read by `is_cancelled` and
    /// cloned for callers of `subscribe_cancel`.
    cancel_rx: watch::Receiver<bool>,
}
60
61impl Default for McpConnectionManager {
62 fn default() -> Self {
63 Self::new()
64 }
65}
66
67impl McpConnectionManager {
68 pub fn new() -> Self {
70 let (cancel_tx, cancel_rx) = watch::channel(false);
71 Self {
72 servers: RwLock::new(HashMap::new()),
73 cancel_tx,
74 cancel_rx,
75 }
76 }
77
78 pub fn is_cancelled(&self) -> bool {
80 *self.cancel_rx.borrow()
81 }
82
83 pub fn cancel(&self) {
85 let _ = self.cancel_tx.send(true);
87 }
88
89 pub fn subscribe_cancel(&self) -> watch::Receiver<bool> {
99 self.cancel_rx.clone()
100 }
101
102 pub async fn connect_server(
108 &self,
109 config: &McpServerConfig,
110 registry: &CapabilityRegistry,
111 ) -> Result<usize, McpClientError> {
112 let conn = LiveMcpConnection::connect(config).await?;
113 self.register_connected_server(config, registry, conn).await
114 }
115
116 async fn register_connected_server(
117 &self,
118 config: &McpServerConfig,
119 registry: &CapabilityRegistry,
120 conn: LiveMcpConnection,
121 ) -> Result<usize, McpClientError> {
122 let tool_count = conn.tools().len();
123
124 let transport = match &config.spec {
125 McpServerSpec::Stdio { .. } => McpTransport::Stdio,
126 McpServerSpec::Sse { .. } => McpTransport::Sse,
127 };
128
129 let conn_arc = Arc::new(RwLock::new(conn));
130
131 {
132 let conn_read = conn_arc.read().await;
133 let caps = bridge_tools(
134 &config.name,
135 conn_read.tools(),
136 transport,
137 Arc::clone(&conn_arc),
138 );
139 let cap_arcs: Vec<Arc<dyn Capability>> =
140 caps.into_iter().map(|c| Arc::new(c) as _).collect();
141
142 if let Err(e) = registry.reload_mcp_server(&config.name, cap_arcs).await {
143 warn!(
144 server = %config.name,
145 error = %e,
146 "failed to register MCP tools in CapabilityRegistry"
147 );
148 }
149 }
150
151 let mut servers = self.servers.write().await;
152 if let Some(existing) = servers.get(&config.name)
156 && let Ok(existing_conn) = existing.connection.try_read()
157 && existing_conn.is_alive()
158 {
159 debug!(
160 server = %config.name,
161 "MCP server already reconnected by another caller; dropping duplicate"
162 );
163 return Ok(tool_count);
164 }
165 servers.insert(
166 config.name.clone(),
167 ServerEntry {
168 connection: conn_arc,
169 config: config.clone(),
170 },
171 );
172
173 info!(
174 server = %config.name,
175 tool_count,
176 "MCP server connected and tools registered"
177 );
178 Ok(tool_count)
179 }
180
181 pub async fn disconnect_server(&self, name: &str, registry: &CapabilityRegistry) {
183 let mut servers = self.servers.write().await;
184 if servers.remove(name).is_some() {
185 if let Err(e) = registry.reload_mcp_server(name, vec![]).await {
187 warn!(server = %name, error = %e, "error unregistering MCP tools on disconnect");
188 }
189 info!(server = %name, "MCP server disconnected");
190 }
191 }
192
193 pub async fn connect_all(&self, configs: &[McpServerConfig], registry: &CapabilityRegistry) {
195 for cfg in configs {
196 if !cfg.enabled {
197 debug!(name = %cfg.name, "skipping disabled MCP server");
198 continue;
199 }
200 if let Err(e) = self.connect_server(cfg, registry).await {
201 warn!(name = %cfg.name, error = %e, "failed to connect MCP server at startup");
202 }
203 }
204 }
205
206 pub async fn server_statuses(&self) -> Vec<McpServerStatus> {
210 let servers = self.servers.read().await;
211 let mut statuses = Vec::with_capacity(servers.len());
212 for (name, entry) in servers.iter() {
213 let conn = entry.connection.read().await;
214 statuses.push(McpServerStatus {
215 name: name.clone(),
216 connected: conn.is_alive(),
217 tool_count: conn.tools().len(),
218 server_name: conn.server_name().to_string(),
219 server_version: conn.server_version().to_string(),
220 });
221 }
222 statuses
223 }
224
225 pub async fn connected_count(&self) -> usize {
227 let servers = self.servers.read().await;
228 let mut count = 0;
229 for entry in servers.values() {
230 if entry.connection.read().await.is_alive() {
231 count += 1;
232 }
233 }
234 count
235 }
236
237 pub async fn total_count(&self) -> usize {
239 self.servers.read().await.len()
240 }
241
242 pub async fn get_connection(&self, name: &str) -> Option<Arc<RwLock<LiveMcpConnection>>> {
244 self.servers
245 .read()
246 .await
247 .get(name)
248 .map(|e| Arc::clone(&e.connection))
249 }
250
251 pub async fn health_check_loop(
265 &self,
266 registry: &CapabilityRegistry,
267 interval: Duration,
268 mut cancel_rx: watch::Receiver<bool>,
269 ) {
270 loop {
271 tokio::select! {
272 _ = tokio::time::sleep(interval) => {}
273 _ = cancel_rx.changed() => {
274 if *cancel_rx.borrow() {
275 debug!("MCP health-check loop cancelled");
276 return;
277 }
278 }
279 }
280
281 let dead: Vec<McpServerConfig> = {
283 let servers = self.servers.read().await;
284 servers
285 .values()
286 .filter_map(|entry| {
287 if let Ok(conn) = entry.connection.try_read()
291 && !conn.is_alive()
292 {
293 return Some(entry.config.clone());
294 }
295 None
296 })
297 .collect()
298 };
299
300 for cfg in dead {
301 warn!(server = %cfg.name, "MCP server connection lost — attempting reconnect");
302 match self.connect_server(&cfg, registry).await {
303 Ok(tool_count) => {
304 info!(
305 server = %cfg.name,
306 tool_count,
307 "MCP server reconnected — tools re-registered"
308 );
309 }
310 Err(e) => {
311 warn!(server = %cfg.name, error = %e, "MCP reconnect failed");
312 }
313 }
314 }
315 }
316 }
317}
318
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mcp::client::test_support;
    use std::time::Duration;

    // Builds an SSE server config pointing at a placeholder URL; the tests
    // below never dial it because they inject an in-memory connection.
    fn test_sse_config(name: &str, enabled: bool) -> McpServerConfig {
        McpServerConfig {
            name: name.into(),
            spec: McpServerSpec::Sse {
                url: "http://in-memory-test.invalid/mcp".into(),
            },
            enabled,
            auth_token_env: None,
            tool_allowlist: Vec::new(),
        }
    }

    #[tokio::test]
    async fn manager_new_is_empty() {
        let mgr = McpConnectionManager::new();
        assert_eq!(mgr.total_count().await, 0);
        assert_eq!(mgr.connected_count().await, 0);
        assert!(mgr.server_statuses().await.is_empty());
    }

    #[test]
    fn manager_cancellation_works() {
        let mgr = McpConnectionManager::new();
        assert!(!mgr.is_cancelled());
        mgr.cancel();
        assert!(mgr.is_cancelled());
    }

    #[test]
    fn server_status_serializes() {
        let status = McpServerStatus {
            name: "github".into(),
            connected: true,
            tool_count: 5,
            server_name: "github-mcp".into(),
            server_version: "1.0.0".into(),
        };
        let json = serde_json::to_string(&status).unwrap();
        assert!(json.contains("\"name\":\"github\""));
        assert!(json.contains("\"connected\":true"));
        assert!(json.contains("\"tool_count\":5"));
        assert!(json.contains("\"server_name\":\"github-mcp\""));
        assert!(json.contains("\"server_version\":\"1.0.0\""));
    }

    #[tokio::test]
    async fn manager_default_matches_new() {
        let mgr = McpConnectionManager::default();
        assert_eq!(mgr.total_count().await, 0);
        assert!(!mgr.is_cancelled());
    }

    #[test]
    fn subscribe_cancel_receiver_fires() {
        let mgr = McpConnectionManager::new();
        let rx = mgr.subscribe_cancel();
        assert!(!*rx.borrow());
        mgr.cancel();
        assert!(*rx.borrow());
        // `borrow` does not mark the value as seen, so the change is still pending.
        assert!(rx.has_changed().unwrap());
    }

    #[tokio::test]
    async fn connect_server_registers_registry_and_status() {
        let registry = CapabilityRegistry::new();
        let mgr = McpConnectionManager::new();
        let config = test_sse_config("remote-test", true);
        let (conn, server_handle) = test_support::echo_connection(&config.name).await.unwrap();

        let tool_count = mgr
            .register_connected_server(&config, &registry, conn)
            .await
            .unwrap();
        assert_eq!(tool_count, 1);
        assert_eq!(mgr.total_count().await, 1);
        assert_eq!(mgr.connected_count().await, 1);
        assert!(mgr.get_connection("remote-test").await.is_some());
        assert!(registry.get("remote-test::echo").await.is_some());

        let statuses = mgr.server_statuses().await;
        assert_eq!(statuses.len(), 1);
        assert_eq!(statuses[0].name, "remote-test");
        assert!(statuses[0].connected);
        assert_eq!(statuses[0].tool_count, 1);

        server_handle.abort();
        let _ = server_handle.await;
    }

    #[tokio::test]
    async fn disconnect_server_unregisters_registry_capabilities() {
        let registry = CapabilityRegistry::new();
        let mgr = McpConnectionManager::new();
        let config = test_sse_config("remote-test", true);
        let (conn, server_handle) = test_support::echo_connection(&config.name).await.unwrap();
        mgr.register_connected_server(&config, &registry, conn)
            .await
            .unwrap();

        mgr.disconnect_server("remote-test", &registry).await;
        assert_eq!(mgr.total_count().await, 0);
        assert!(mgr.get_connection("remote-test").await.is_none());
        assert!(registry.get("remote-test::echo").await.is_none());

        server_handle.abort();
        let _ = server_handle.await;
    }

    #[tokio::test]
    async fn connect_all_skips_disabled_servers() {
        let registry = CapabilityRegistry::new();
        let mgr = McpConnectionManager::new();
        let disabled_cfg = test_sse_config("disabled-test", false);
        mgr.connect_all(std::slice::from_ref(&disabled_cfg), &registry)
            .await;

        assert_eq!(mgr.total_count().await, 0);
        assert!(mgr.get_connection("disabled-test").await.is_none());
        assert!(registry.get("disabled-test::echo").await.is_none());
        assert!(!disabled_cfg.enabled);
    }

    #[tokio::test]
    async fn register_connected_server_supports_connect_all_style_registry_state() {
        let registry = CapabilityRegistry::new();
        let mgr = McpConnectionManager::new();
        let enabled_cfg = test_sse_config("enabled-test", true);
        let (enabled_conn, enabled_handle) = test_support::echo_connection(&enabled_cfg.name)
            .await
            .unwrap();

        mgr.register_connected_server(&enabled_cfg, &registry, enabled_conn)
            .await
            .unwrap();

        assert_eq!(mgr.total_count().await, 1);
        assert!(mgr.get_connection("enabled-test").await.is_some());
        assert!(mgr.get_connection("disabled-test").await.is_none());
        assert!(registry.get("enabled-test::echo").await.is_some());

        enabled_handle.abort();
        let _ = enabled_handle.await;
    }

    #[tokio::test]
    async fn health_check_loop_exits_when_cancelled() {
        let registry = CapabilityRegistry::new();
        let mgr = McpConnectionManager::new();
        // Subscribe first, then cancel, so the receiver has an unseen change.
        let cancel = mgr.subscribe_cancel();
        mgr.cancel();

        tokio::time::timeout(
            Duration::from_secs(1),
            mgr.health_check_loop(&registry, Duration::from_millis(10), cancel),
        )
        .await
        .expect("health loop should exit promptly after cancellation");
    }
}