Skip to main content

roboticus_agent/mcp/
manager.rs

1//! MCP connection manager — lifecycle, health checks, and CapabilityRegistry bridging.
2//!
3//! Coordinates raw MCP connections (`LiveMcpConnection`) with the
4//! `CapabilityRegistry` (tool registration) and provides a cancellable
5//! health-check loop driven by a `tokio::sync::watch` channel.
6
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::Duration;
10
11use tokio::sync::{RwLock, watch};
12use tracing::{debug, info, warn};
13
14use roboticus_core::config::{McpServerConfig, McpServerSpec, McpTransport};
15
16use super::bridge::bridge_tools;
17use super::client::{LiveMcpConnection, McpClientError};
18use crate::capability::{Capability, CapabilityRegistry};
19
20// ── Status ────────────────────────────────────────────────────────────────────
21
22/// Status of a single MCP server connection for dashboard/WebUI reporting.
23#[derive(Debug, Clone, serde::Serialize)]
24pub struct McpServerStatus {
25    pub name: String,
26    pub connected: bool,
27    pub tool_count: usize,
28    pub server_name: String,
29    pub server_version: String,
30}
31
32// ── Internal storage ──────────────────────────────────────────────────────────
33
34/// An entry for a managed MCP server connection.
35struct ServerEntry {
36    /// Shared connection handle — held by this manager and by `McpCapability` instances.
37    connection: Arc<RwLock<LiveMcpConnection>>,
38    /// Original config kept for reconnection.
39    config: McpServerConfig,
40}
41
42// ── Manager ───────────────────────────────────────────────────────────────────
43
44/// Manages MCP connections with lifecycle, health checks, and CapabilityRegistry bridging.
45///
46/// `McpConnectionManager` owns the shared `Arc<RwLock<LiveMcpConnection>>`
47/// handles. The same `Arc` is handed to `McpCapability` instances so that
48/// tool calls go to the correct underlying server without locking the manager
49/// itself.
50///
51/// Cancellation is provided through a `tokio::sync::watch` channel; call
52/// [`McpConnectionManager::cancel`] to signal the health-check loop to stop.
53pub struct McpConnectionManager {
54    servers: RwLock<HashMap<String, ServerEntry>>,
55    /// Sender half of the cancellation channel.  `true` means "stop".
56    cancel_tx: watch::Sender<bool>,
57    /// Receiver half shared with `health_check_loop`.
58    cancel_rx: watch::Receiver<bool>,
59}
60
61impl Default for McpConnectionManager {
62    fn default() -> Self {
63        Self::new()
64    }
65}
66
67impl McpConnectionManager {
68    /// Create a new manager with its own cancellation signal.
69    pub fn new() -> Self {
70        let (cancel_tx, cancel_rx) = watch::channel(false);
71        Self {
72            servers: RwLock::new(HashMap::new()),
73            cancel_tx,
74            cancel_rx,
75        }
76    }
77
78    /// Returns `true` if [`cancel`](Self::cancel) has been called.
79    pub fn is_cancelled(&self) -> bool {
80        *self.cancel_rx.borrow()
81    }
82
83    /// Signal the health-check loop (and any other watchers) to stop.
84    pub fn cancel(&self) {
85        // Ignore the error — it just means all receivers have been dropped.
86        let _ = self.cancel_tx.send(true);
87    }
88
89    /// Returns a new `watch::Receiver` that fires when the manager is cancelled.
90    ///
91    /// Useful for spawning the health-check loop on a separate task:
92    /// ```ignore
93    /// let mut cancel = manager.subscribe_cancel();
94    /// tokio::spawn(async move {
95    ///     manager.health_check_loop(registry, interval, cancel).await;
96    /// });
97    /// ```
98    pub fn subscribe_cancel(&self) -> watch::Receiver<bool> {
99        self.cancel_rx.clone()
100    }
101
102    // ── Connection lifecycle ──────────────────────────────────────────────────
103
104    /// Connect to a server and register its tools with the `CapabilityRegistry`.
105    ///
106    /// Returns the number of tools that were registered.
107    pub async fn connect_server(
108        &self,
109        config: &McpServerConfig,
110        registry: &CapabilityRegistry,
111    ) -> Result<usize, McpClientError> {
112        let conn = LiveMcpConnection::connect(config).await?;
113        self.register_connected_server(config, registry, conn).await
114    }
115
116    async fn register_connected_server(
117        &self,
118        config: &McpServerConfig,
119        registry: &CapabilityRegistry,
120        conn: LiveMcpConnection,
121    ) -> Result<usize, McpClientError> {
122        let tool_count = conn.tools().len();
123
124        let transport = match &config.spec {
125            McpServerSpec::Stdio { .. } => McpTransport::Stdio,
126            McpServerSpec::Sse { .. } => McpTransport::Sse,
127        };
128
129        let conn_arc = Arc::new(RwLock::new(conn));
130
131        {
132            let conn_read = conn_arc.read().await;
133            let caps = bridge_tools(
134                &config.name,
135                conn_read.tools(),
136                transport,
137                Arc::clone(&conn_arc),
138            );
139            let cap_arcs: Vec<Arc<dyn Capability>> =
140                caps.into_iter().map(|c| Arc::new(c) as _).collect();
141
142            if let Err(e) = registry.reload_mcp_server(&config.name, cap_arcs).await {
143                warn!(
144                    server = %config.name,
145                    error = %e,
146                    "failed to register MCP tools in CapabilityRegistry"
147                );
148            }
149        }
150
151        let mut servers = self.servers.write().await;
152        // TOCTOU guard: if another caller already reconnected this server
153        // between our connect() and this write lock acquisition, skip the
154        // redundant insert. The existing connection wins.
155        if let Some(existing) = servers.get(&config.name)
156            && let Ok(existing_conn) = existing.connection.try_read()
157            && existing_conn.is_alive()
158        {
159            debug!(
160                server = %config.name,
161                "MCP server already reconnected by another caller; dropping duplicate"
162            );
163            return Ok(tool_count);
164        }
165        servers.insert(
166            config.name.clone(),
167            ServerEntry {
168                connection: conn_arc,
169                config: config.clone(),
170            },
171        );
172
173        info!(
174            server = %config.name,
175            tool_count,
176            "MCP server connected and tools registered"
177        );
178        Ok(tool_count)
179    }
180
181    /// Disconnect a server and unregister all its tools from the registry.
182    pub async fn disconnect_server(&self, name: &str, registry: &CapabilityRegistry) {
183        let mut servers = self.servers.write().await;
184        if servers.remove(name).is_some() {
185            // Reload with an empty list atomically removes all tools for this server.
186            if let Err(e) = registry.reload_mcp_server(name, vec![]).await {
187                warn!(server = %name, error = %e, "error unregistering MCP tools on disconnect");
188            }
189            info!(server = %name, "MCP server disconnected");
190        }
191    }
192
193    /// Connect to all *enabled* servers in `configs`, logging warnings for failures.
194    pub async fn connect_all(&self, configs: &[McpServerConfig], registry: &CapabilityRegistry) {
195        for cfg in configs {
196            if !cfg.enabled {
197                debug!(name = %cfg.name, "skipping disabled MCP server");
198                continue;
199            }
200            if let Err(e) = self.connect_server(cfg, registry).await {
201                warn!(name = %cfg.name, error = %e, "failed to connect MCP server at startup");
202            }
203        }
204    }
205
206    // ── Status / introspection ────────────────────────────────────────────────
207
208    /// Status snapshot of all managed servers.
209    pub async fn server_statuses(&self) -> Vec<McpServerStatus> {
210        let servers = self.servers.read().await;
211        let mut statuses = Vec::with_capacity(servers.len());
212        for (name, entry) in servers.iter() {
213            let conn = entry.connection.read().await;
214            statuses.push(McpServerStatus {
215                name: name.clone(),
216                connected: conn.is_alive(),
217                tool_count: conn.tools().len(),
218                server_name: conn.server_name().to_string(),
219                server_version: conn.server_version().to_string(),
220            });
221        }
222        statuses
223    }
224
225    /// Number of servers whose transport channel is still alive.
226    pub async fn connected_count(&self) -> usize {
227        let servers = self.servers.read().await;
228        let mut count = 0;
229        for entry in servers.values() {
230            if entry.connection.read().await.is_alive() {
231                count += 1;
232            }
233        }
234        count
235    }
236
237    /// Total number of registered servers (alive or not).
238    pub async fn total_count(&self) -> usize {
239        self.servers.read().await.len()
240    }
241
242    /// Get the shared connection arc for a server, for direct tool dispatch.
243    pub async fn get_connection(&self, name: &str) -> Option<Arc<RwLock<LiveMcpConnection>>> {
244        self.servers
245            .read()
246            .await
247            .get(name)
248            .map(|e| Arc::clone(&e.connection))
249    }
250
251    // ── Health-check loop ─────────────────────────────────────────────────────
252
253    /// Periodically pings all connections and reconnects any that have dropped.
254    ///
255    /// This method runs until [`cancel`](Self::cancel) is called (or the
256    /// provided `cancel_rx` fires). Intended to be spawned as a background task:
257    ///
258    /// ```ignore
259    /// let cancel = manager.subscribe_cancel();
260    /// tokio::spawn(async move {
261    ///     manager.health_check_loop(registry, Duration::from_secs(30), cancel).await;
262    /// });
263    /// ```
264    pub async fn health_check_loop(
265        &self,
266        registry: &CapabilityRegistry,
267        interval: Duration,
268        mut cancel_rx: watch::Receiver<bool>,
269    ) {
270        loop {
271            tokio::select! {
272                _ = tokio::time::sleep(interval) => {}
273                _ = cancel_rx.changed() => {
274                    if *cancel_rx.borrow() {
275                        debug!("MCP health-check loop cancelled");
276                        return;
277                    }
278                }
279            }
280
281            // Collect names of dead connections while holding a read lock.
282            let dead: Vec<McpServerConfig> = {
283                let servers = self.servers.read().await;
284                servers
285                    .values()
286                    .filter_map(|entry| {
287                        // We need a blocking check here; `is_alive` is sync.
288                        // We cannot `.await` inside a closure, so we do a
289                        // try_read and fall back to assuming alive on contention.
290                        if let Ok(conn) = entry.connection.try_read()
291                            && !conn.is_alive()
292                        {
293                            return Some(entry.config.clone());
294                        }
295                        None
296                    })
297                    .collect()
298            };
299
300            for cfg in dead {
301                warn!(server = %cfg.name, "MCP server connection lost — attempting reconnect");
302                match self.connect_server(&cfg, registry).await {
303                    Ok(tool_count) => {
304                        info!(
305                            server = %cfg.name,
306                            tool_count,
307                            "MCP server reconnected — tools re-registered"
308                        );
309                    }
310                    Err(e) => {
311                        warn!(server = %cfg.name, error = %e, "MCP reconnect failed");
312                    }
313                }
314            }
315        }
316    }
317}
318
319// ── Tests ─────────────────────────────────────────────────────────────────────
320
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mcp::client::test_support;
    use std::time::Duration;

    /// Builds an SSE config whose URL is never dialed. Tests inject an in-memory
    /// connection from `test_support` via `register_connected_server` instead.
    fn test_sse_config(name: &str, enabled: bool) -> McpServerConfig {
        McpServerConfig {
            name: name.into(),
            spec: McpServerSpec::Sse {
                url: "http://in-memory-test.invalid/mcp".into(),
            },
            enabled,
            auth_token_env: None,
            tool_allowlist: Vec::new(),
        }
    }

    #[tokio::test]
    async fn manager_new_is_empty() {
        let mgr = McpConnectionManager::new();
        assert_eq!(mgr.total_count().await, 0);
        assert_eq!(mgr.connected_count().await, 0);
        assert!(mgr.server_statuses().await.is_empty());
    }

    #[test]
    fn manager_cancellation_works() {
        let mgr = McpConnectionManager::new();
        assert!(!mgr.is_cancelled());
        mgr.cancel();
        assert!(mgr.is_cancelled());
    }

    #[test]
    fn server_status_serializes() {
        let status = McpServerStatus {
            name: "github".into(),
            connected: true,
            tool_count: 5,
            server_name: "github-mcp".into(),
            server_version: "1.0.0".into(),
        };
        let json = serde_json::to_string(&status).unwrap();
        assert!(json.contains("\"name\":\"github\""));
        assert!(json.contains("\"connected\":true"));
        assert!(json.contains("\"tool_count\":5"));
        assert!(json.contains("\"server_name\":\"github-mcp\""));
        assert!(json.contains("\"server_version\":\"1.0.0\""));
    }

    #[tokio::test]
    async fn manager_default_matches_new() {
        let mgr = McpConnectionManager::default();
        assert_eq!(mgr.total_count().await, 0);
        assert!(!mgr.is_cancelled());
    }

    #[test]
    fn subscribe_cancel_receiver_fires() {
        let mgr = McpConnectionManager::new();
        let rx = mgr.subscribe_cancel();
        assert!(!*rx.borrow());
        mgr.cancel();
        // After cancel(), the watch value is true.
        assert!(*rx.borrow());
        // The receiver holds an unseen update, so `changed().await` would resolve
        // immediately; `has_changed()` verifies that synchronously.
        assert!(rx.has_changed().unwrap());
    }

    #[tokio::test]
    async fn connect_server_registers_registry_and_status() {
        let registry = CapabilityRegistry::new();
        let mgr = McpConnectionManager::new();
        let config = test_sse_config("remote-test", true);
        let (conn, server_handle) = test_support::echo_connection(&config.name).await.unwrap();

        let tool_count = mgr
            .register_connected_server(&config, &registry, conn)
            .await
            .unwrap();
        assert_eq!(tool_count, 1);
        assert_eq!(mgr.total_count().await, 1);
        assert_eq!(mgr.connected_count().await, 1);
        assert!(mgr.get_connection("remote-test").await.is_some());
        assert!(registry.get("remote-test::echo").await.is_some());

        let statuses = mgr.server_statuses().await;
        assert_eq!(statuses.len(), 1);
        assert_eq!(statuses[0].name, "remote-test");
        assert!(statuses[0].connected);
        assert_eq!(statuses[0].tool_count, 1);

        server_handle.abort();
        let _ = server_handle.await;
    }

    #[tokio::test]
    async fn disconnect_server_unregisters_registry_capabilities() {
        let registry = CapabilityRegistry::new();
        let mgr = McpConnectionManager::new();
        let config = test_sse_config("remote-test", true);
        let (conn, server_handle) = test_support::echo_connection(&config.name).await.unwrap();
        mgr.register_connected_server(&config, &registry, conn)
            .await
            .unwrap();

        mgr.disconnect_server("remote-test", &registry).await;
        assert_eq!(mgr.total_count().await, 0);
        assert!(mgr.get_connection("remote-test").await.is_none());
        assert!(registry.get("remote-test::echo").await.is_none());

        server_handle.abort();
        let _ = server_handle.await;
    }

    #[tokio::test]
    async fn connect_all_skips_disabled_servers() {
        let registry = CapabilityRegistry::new();
        let mgr = McpConnectionManager::new();
        let disabled_cfg = test_sse_config("disabled-test", false);
        mgr.connect_all(std::slice::from_ref(&disabled_cfg), &registry)
            .await;

        assert_eq!(mgr.total_count().await, 0);
        assert!(mgr.get_connection("disabled-test").await.is_none());
        assert!(registry.get("disabled-test::echo").await.is_none());
        assert!(!disabled_cfg.enabled);
    }

    #[tokio::test]
    async fn register_connected_server_supports_connect_all_style_registry_state() {
        let registry = CapabilityRegistry::new();
        let mgr = McpConnectionManager::new();
        let enabled_cfg = test_sse_config("enabled-test", true);
        let (enabled_conn, enabled_handle) = test_support::echo_connection(&enabled_cfg.name)
            .await
            .unwrap();

        mgr.register_connected_server(&enabled_cfg, &registry, enabled_conn)
            .await
            .unwrap();

        assert_eq!(mgr.total_count().await, 1);
        assert!(mgr.get_connection("enabled-test").await.is_some());
        assert!(mgr.get_connection("disabled-test").await.is_none());
        assert!(registry.get("enabled-test::echo").await.is_some());

        enabled_handle.abort();
        let _ = enabled_handle.await;
    }

    #[tokio::test]
    async fn health_check_loop_exits_when_cancelled() {
        let registry = CapabilityRegistry::new();
        let mgr = McpConnectionManager::new();
        let cancel = mgr.subscribe_cancel();
        mgr.cancel();

        tokio::time::timeout(
            Duration::from_secs(1),
            mgr.health_check_loop(&registry, Duration::from_millis(10), cancel),
        )
        .await
        .expect("health loop should exit promptly after cancellation");
    }
}
494}