use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{RwLock, watch};
use tracing::{debug, info, warn};
use roboticus_core::config::{McpServerConfig, McpServerSpec, McpTransport};
use super::bridge::bridge_tools;
use super::client::{LiveMcpConnection, McpClientError};
use crate::capability::{Capability, CapabilityRegistry};
/// Point-in-time, serializable snapshot of one tracked MCP server.
///
/// Field order is also the serialized field order.
#[derive(Debug, Clone, serde::Serialize)]
pub struct McpServerStatus {
/// Name the server is tracked under (its configured name).
pub name: String,
/// Whether the connection reported itself alive when the snapshot was taken.
pub connected: bool,
/// Number of tools the connection exposes.
pub tool_count: usize,
/// Server name as reported by the connection (presumably from the MCP handshake).
pub server_name: String,
/// Server version as reported by the connection (presumably from the MCP handshake).
pub server_version: String,
}
/// Bookkeeping for one tracked MCP server.
struct ServerEntry {
/// Shared connection handle; clones of this same `Arc` are handed to the bridged
/// capabilities, so tool calls and health checks see one connection.
connection: Arc<RwLock<LiveMcpConnection>>,
/// Configuration the connection was created from; reused by the health-check
/// loop when reconnecting.
config: McpServerConfig,
}
/// Tracks live MCP server connections and keeps their bridged tools registered
/// in a `CapabilityRegistry`.
///
/// Also owns a `watch` channel used to signal cancellation to background tasks
/// such as the health-check loop.
pub struct McpConnectionManager {
/// Tracked servers, keyed by configured server name.
servers: RwLock<HashMap<String, ServerEntry>>,
/// Sending half of the cancellation flag; set to `true` by `cancel`.
cancel_tx: watch::Sender<bool>,
/// Receiving half retained by the manager. It is cloned for subscribers and keeps
/// the channel open for as long as the manager lives.
cancel_rx: watch::Receiver<bool>,
}
impl Default for McpConnectionManager {
fn default() -> Self {
Self::new()
}
}
impl McpConnectionManager {
    /// Creates an empty manager whose cancellation signal has not fired.
    pub fn new() -> Self {
        let (cancel_tx, cancel_rx) = watch::channel(false);
        Self {
            servers: RwLock::new(HashMap::new()),
            cancel_tx,
            cancel_rx,
        }
    }

    /// Returns `true` once [`cancel`](Self::cancel) has been called.
    pub fn is_cancelled(&self) -> bool {
        *self.cancel_rx.borrow()
    }

    /// Signals cancellation to every receiver obtained from
    /// [`subscribe_cancel`](Self::subscribe_cancel).
    pub fn cancel(&self) {
        // `self.cancel_rx` keeps the channel open, so `send` cannot fail here.
        let _ = self.cancel_tx.send(true);
    }

    /// Returns a receiver that observes [`cancel`](Self::cancel); suitable for
    /// passing to [`health_check_loop`](Self::health_check_loop).
    pub fn subscribe_cancel(&self) -> watch::Receiver<bool> {
        self.cancel_rx.clone()
    }

    /// Opens a live connection for `config`, then tracks it and registers its tools.
    ///
    /// Returns the number of tools the server advertised. `config.enabled` is not
    /// consulted here; see [`connect_all`](Self::connect_all).
    ///
    /// # Errors
    /// Returns the [`McpClientError`] from [`LiveMcpConnection::connect`] when the
    /// connection cannot be established. Registry failures are logged, not returned.
    pub async fn connect_server(
        &self,
        config: &McpServerConfig,
        registry: &CapabilityRegistry,
    ) -> Result<usize, McpClientError> {
        let conn = LiveMcpConnection::connect(config).await?;
        self.register_connected_server(config, registry, conn).await
    }

    /// Tracks an already-established connection and registers its bridged tools.
    ///
    /// If a live connection for the same name is already tracked (e.g. a concurrent
    /// reconnect won the race), `conn` is dropped and neither the map nor the registry
    /// is modified. Returns `conn`'s tool count in both cases.
    async fn register_connected_server(
        &self,
        config: &McpServerConfig,
        registry: &CapabilityRegistry,
        conn: LiveMcpConnection,
    ) -> Result<usize, McpClientError> {
        let tool_count = conn.tools().len();

        // Hold the map's write lock across check → register → insert so the registry
        // and the map always refer to the same connection. Registering before the
        // duplicate check would leave the registry bound to a connection the map
        // discards. `disconnect_server` also holds this lock across its registry update.
        let mut servers = self.servers.write().await;
        if let Some(existing) = servers.get(&config.name)
            && let Ok(existing_conn) = existing.connection.try_read()
            && existing_conn.is_alive()
        {
            debug!(
                server = %config.name,
                "MCP server already reconnected by another caller; dropping duplicate"
            );
            return Ok(tool_count);
        }

        let transport = match &config.spec {
            McpServerSpec::Stdio { .. } => McpTransport::Stdio,
            McpServerSpec::Sse { .. } => McpTransport::Sse,
        };
        let conn_arc = Arc::new(RwLock::new(conn));
        let cap_arcs: Vec<Arc<dyn Capability>> = {
            let conn_read = conn_arc.read().await;
            let caps = bridge_tools(
                &config.name,
                conn_read.tools(),
                transport,
                Arc::clone(&conn_arc),
            );
            caps.into_iter().map(|c| Arc::new(c) as _).collect()
        };
        // Best-effort: a registry failure is logged but the connection is still tracked.
        if let Err(e) = registry.reload_mcp_server(&config.name, cap_arcs).await {
            warn!(
                server = %config.name,
                error = %e,
                "failed to register MCP tools in CapabilityRegistry"
            );
        }

        servers.insert(
            config.name.clone(),
            ServerEntry {
                connection: conn_arc,
                config: config.clone(),
            },
        );
        info!(
            server = %config.name,
            tool_count,
            "MCP server connected and tools registered"
        );
        Ok(tool_count)
    }

    /// Stops tracking `name` and removes its tools from `registry`.
    ///
    /// Does nothing if `name` is not tracked. Registry errors are logged, not returned.
    pub async fn disconnect_server(&self, name: &str, registry: &CapabilityRegistry) {
        let mut servers = self.servers.write().await;
        if servers.remove(name).is_some() {
            // Reloading with an empty capability set unregisters the server's tools.
            if let Err(e) = registry.reload_mcp_server(name, vec![]).await {
                warn!(server = %name, error = %e, "error unregistering MCP tools on disconnect");
            }
            info!(server = %name, "MCP server disconnected");
        }
    }

    /// Connects every enabled server in `configs`, one after another.
    ///
    /// Disabled servers are skipped. A failure is logged and does not stop the
    /// remaining servers from being attempted.
    pub async fn connect_all(&self, configs: &[McpServerConfig], registry: &CapabilityRegistry) {
        for cfg in configs {
            if !cfg.enabled {
                debug!(name = %cfg.name, "skipping disabled MCP server");
                continue;
            }
            if let Err(e) = self.connect_server(cfg, registry).await {
                warn!(name = %cfg.name, error = %e, "failed to connect MCP server at startup");
            }
        }
    }

    /// Clones every `(name, connection)` pair out of the map.
    ///
    /// Callers can then await per-connection locks without holding the map lock. Holding
    /// it would stall `connect_server`/`disconnect_server` behind one slow connection.
    async fn snapshot_connections(&self) -> Vec<(String, Arc<RwLock<LiveMcpConnection>>)> {
        self.servers
            .read()
            .await
            .iter()
            .map(|(name, entry)| (name.clone(), Arc::clone(&entry.connection)))
            .collect()
    }

    /// Returns a status snapshot for every tracked server, sorted by name.
    pub async fn server_statuses(&self) -> Vec<McpServerStatus> {
        let entries = self.snapshot_connections().await;
        let mut statuses = Vec::with_capacity(entries.len());
        for (name, connection) in entries {
            let conn = connection.read().await;
            statuses.push(McpServerStatus {
                name,
                connected: conn.is_alive(),
                tool_count: conn.tools().len(),
                server_name: conn.server_name().to_string(),
                server_version: conn.server_version().to_string(),
            });
        }
        // HashMap iteration order is arbitrary; sort for stable output.
        statuses.sort_unstable_by(|a, b| a.name.cmp(&b.name));
        statuses
    }

    /// Number of tracked servers whose connection currently reports itself alive.
    pub async fn connected_count(&self) -> usize {
        let mut count = 0;
        for (_, connection) in self.snapshot_connections().await {
            if connection.read().await.is_alive() {
                count += 1;
            }
        }
        count
    }

    /// Number of tracked servers, whether or not their connection is alive.
    pub async fn total_count(&self) -> usize {
        self.servers.read().await.len()
    }

    /// Returns the shared connection handle for `name`, if tracked.
    pub async fn get_connection(&self, name: &str) -> Option<Arc<RwLock<LiveMcpConnection>>> {
        self.servers
            .read()
            .await
            .get(name)
            .map(|e| Arc::clone(&e.connection))
    }

    /// Periodically reconnects tracked servers whose connection reports not alive.
    ///
    /// Every `interval`, the loop collects the stored configs of dead connections and
    /// calls [`connect_server`](Self::connect_server) for each one. Connections whose lock
    /// cannot be acquired immediately are skipped for that round.
    ///
    /// Returns when `cancel_rx` observes `true`, or when its sender has been dropped.
    pub async fn health_check_loop(
        &self,
        registry: &CapabilityRegistry,
        interval: Duration,
        mut cancel_rx: watch::Receiver<bool>,
    ) {
        loop {
            if *cancel_rx.borrow() {
                debug!("MCP health-check loop cancelled");
                return;
            }
            tokio::select! {
                _ = tokio::time::sleep(interval) => {}
                changed = cancel_rx.changed() => {
                    // A closed channel can never signal cancellation. Without this exit,
                    // `changed()` would resolve `Err` immediately on every iteration and
                    // spin this loop with no sleep.
                    if changed.is_err() {
                        debug!("MCP health-check cancel channel closed; stopping loop");
                        return;
                    }
                    if *cancel_rx.borrow() {
                        debug!("MCP health-check loop cancelled");
                        return;
                    }
                }
            }

            let dead: Vec<McpServerConfig> = {
                let servers = self.servers.read().await;
                servers
                    .values()
                    .filter_map(|entry| {
                        if let Ok(conn) = entry.connection.try_read()
                            && !conn.is_alive()
                        {
                            return Some(entry.config.clone());
                        }
                        None
                    })
                    .collect()
            };

            for cfg in dead {
                // Do not start new reconnects once shutdown has been requested.
                if *cancel_rx.borrow() {
                    debug!("MCP health-check loop cancelled");
                    return;
                }
                warn!(server = %cfg.name, "MCP server connection lost — attempting reconnect");
                match self.connect_server(&cfg, registry).await {
                    Ok(tool_count) => {
                        info!(
                            server = %cfg.name,
                            tool_count,
                            "MCP server reconnected — tools re-registered"
                        );
                    }
                    Err(e) => {
                        warn!(server = %cfg.name, error = %e, "MCP reconnect failed");
                    }
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mcp::client::test_support;
    use std::time::Duration;

    /// Builds an SSE config that points at a reserved `.invalid` host.
    ///
    /// Tests inject in-memory connections via `register_connected_server` instead of
    /// dialing this URL.
    fn test_sse_config(name: &str, enabled: bool) -> McpServerConfig {
        McpServerConfig {
            name: name.into(),
            spec: McpServerSpec::Sse {
                url: "http://in-memory-test.invalid/mcp".into(),
            },
            enabled,
            auth_token_env: None,
            tool_allowlist: Vec::new(),
        }
    }

    #[tokio::test]
    async fn manager_new_is_empty() {
        let mgr = McpConnectionManager::new();
        assert_eq!(mgr.total_count().await, 0);
        assert_eq!(mgr.connected_count().await, 0);
        assert!(mgr.server_statuses().await.is_empty());
    }

    #[test]
    fn manager_cancellation_works() {
        let mgr = McpConnectionManager::new();
        assert!(!mgr.is_cancelled());
        mgr.cancel();
        assert!(mgr.is_cancelled());
    }

    #[test]
    fn server_status_serializes() {
        let status = McpServerStatus {
            name: "github".into(),
            connected: true,
            tool_count: 5,
            server_name: "github-mcp".into(),
            server_version: "1.0.0".into(),
        };
        let json = serde_json::to_string(&status).unwrap();
        assert!(json.contains("\"name\":\"github\""));
        assert!(json.contains("\"connected\":true"));
        assert!(json.contains("\"tool_count\":5"));
        assert!(json.contains("\"server_name\":\"github-mcp\""));
        assert!(json.contains("\"server_version\":\"1.0.0\""));
    }

    #[tokio::test]
    async fn manager_default_matches_new() {
        let mgr = McpConnectionManager::default();
        assert_eq!(mgr.total_count().await, 0);
        assert!(!mgr.is_cancelled());
    }

    #[test]
    fn subscribe_cancel_receiver_fires() {
        let mgr = McpConnectionManager::new();
        let rx = mgr.subscribe_cancel();
        assert!(!*rx.borrow());
        mgr.cancel();
        assert!(*rx.borrow());
        assert!(rx.has_changed().unwrap());
    }

    #[tokio::test]
    async fn connect_server_registers_registry_and_status() {
        let registry = CapabilityRegistry::new();
        let mgr = McpConnectionManager::new();
        let config = test_sse_config("remote-test", true);
        let (conn, server_handle) = test_support::echo_connection(&config.name).await.unwrap();
        let tool_count = mgr
            .register_connected_server(&config, &registry, conn)
            .await
            .unwrap();
        assert_eq!(tool_count, 1);
        assert_eq!(mgr.total_count().await, 1);
        assert_eq!(mgr.connected_count().await, 1);
        assert!(mgr.get_connection("remote-test").await.is_some());
        assert!(registry.get("remote-test::echo").await.is_some());
        let statuses = mgr.server_statuses().await;
        assert_eq!(statuses.len(), 1);
        assert_eq!(statuses[0].name, "remote-test");
        assert!(statuses[0].connected);
        assert_eq!(statuses[0].tool_count, 1);
        server_handle.abort();
        let _ = server_handle.await;
    }

    #[tokio::test]
    async fn register_connected_server_keeps_existing_live_connection() {
        let registry = CapabilityRegistry::new();
        let mgr = McpConnectionManager::new();
        let config = test_sse_config("remote-test", true);
        let (first, first_handle) = test_support::echo_connection(&config.name).await.unwrap();
        mgr.register_connected_server(&config, &registry, first)
            .await
            .unwrap();
        let original = mgr.get_connection("remote-test").await.unwrap();

        // A second registration while the first connection is alive is a duplicate.
        let (second, second_handle) = test_support::echo_connection(&config.name).await.unwrap();
        let tool_count = mgr
            .register_connected_server(&config, &registry, second)
            .await
            .unwrap();
        assert_eq!(tool_count, 1);
        assert_eq!(mgr.total_count().await, 1);
        let current = mgr.get_connection("remote-test").await.unwrap();
        assert!(Arc::ptr_eq(&original, &current));
        assert!(registry.get("remote-test::echo").await.is_some());

        first_handle.abort();
        second_handle.abort();
        let _ = first_handle.await;
        let _ = second_handle.await;
    }

    #[tokio::test]
    async fn disconnect_server_unregisters_registry_capabilities() {
        let registry = CapabilityRegistry::new();
        let mgr = McpConnectionManager::new();
        let config = test_sse_config("remote-test", true);
        let (conn, server_handle) = test_support::echo_connection(&config.name).await.unwrap();
        mgr.register_connected_server(&config, &registry, conn)
            .await
            .unwrap();
        mgr.disconnect_server("remote-test", &registry).await;
        assert_eq!(mgr.total_count().await, 0);
        assert!(mgr.get_connection("remote-test").await.is_none());
        assert!(registry.get("remote-test::echo").await.is_none());
        server_handle.abort();
        let _ = server_handle.await;
    }

    #[tokio::test]
    async fn connect_all_skips_disabled_servers() {
        let registry = CapabilityRegistry::new();
        let mgr = McpConnectionManager::new();
        let disabled_cfg = test_sse_config("disabled-test", false);
        mgr.connect_all(std::slice::from_ref(&disabled_cfg), &registry)
            .await;
        assert_eq!(mgr.total_count().await, 0);
        assert!(mgr.get_connection("disabled-test").await.is_none());
        assert!(registry.get("disabled-test::echo").await.is_none());
        assert!(!disabled_cfg.enabled);
    }

    #[tokio::test]
    async fn register_connected_server_supports_connect_all_style_registry_state() {
        let registry = CapabilityRegistry::new();
        let mgr = McpConnectionManager::new();
        let enabled_cfg = test_sse_config("enabled-test", true);
        let (enabled_conn, enabled_handle) = test_support::echo_connection(&enabled_cfg.name)
            .await
            .unwrap();
        mgr.register_connected_server(&enabled_cfg, &registry, enabled_conn)
            .await
            .unwrap();
        assert_eq!(mgr.total_count().await, 1);
        assert!(mgr.get_connection("enabled-test").await.is_some());
        assert!(mgr.get_connection("disabled-test").await.is_none());
        assert!(registry.get("enabled-test::echo").await.is_some());
        enabled_handle.abort();
        let _ = enabled_handle.await;
    }

    #[tokio::test]
    async fn health_check_loop_exits_when_cancelled() {
        let registry = CapabilityRegistry::new();
        let mgr = McpConnectionManager::new();
        let cancel = mgr.subscribe_cancel();
        mgr.cancel();
        tokio::time::timeout(
            Duration::from_secs(1),
            mgr.health_check_loop(&registry, Duration::from_millis(10), cancel),
        )
        .await
        .expect("health loop should exit promptly after cancellation");
    }
}