use std::collections::HashMap;
use anyhow::{Context, Result};
use serde_json::Value;
use super::client::{McpClient, McpClientStatus};
use super::config::{self, McpServerConfig};
use super::tool_bridge::{McpToolAnnotations, parse_mcp_tool_name};
use crate::db::Database;
use crate::providers::ToolDefinition;
use crate::tools::ToolEffect;
/// Owns one [`McpClient`] per configured MCP server, plus a cache of the
/// annotations of every tool those clients have reported.
pub struct McpManager {
/// Clients keyed by server name.
clients: HashMap<String, McpClient>,
/// Tool annotations keyed by the qualified tool name (the `server__tool`
/// form that `parse_mcp_tool_name` splits); read by `classify_tool` and
/// `has_tool`.
annotations: HashMap<String, McpToolAnnotations>,
}
impl Default for McpManager {
fn default() -> Self {
Self::new()
}
}
impl McpManager {
pub fn new() -> Self {
Self {
clients: HashMap::new(),
annotations: HashMap::new(),
}
}
pub async fn start_from_db(db: &Database) -> Result<Self> {
let configs = config::load_mcp_configs(db).await?;
if configs.is_empty() {
tracing::debug!("no MCP servers configured");
return Ok(Self::new());
}
tracing::info!(
count = configs.len(),
servers = ?configs.keys().collect::<Vec<_>>(),
"starting MCP servers"
);
let mut manager = Self::new();
manager.connect_all(configs).await;
Ok(manager)
}
async fn connect_all(&mut self, configs: HashMap<String, McpServerConfig>) {
let handles: Vec<_> = configs
.into_iter()
.map(|(name, config)| {
tokio::spawn(async move {
let mut client = McpClient::new(name.clone(), config);
let result = client.connect().await;
(name, client, result)
})
})
.collect();
for handle in handles {
match handle.await {
Ok((name, client, result)) => {
if let Err(e) = &result {
tracing::warn!(
server = %name,
error = %e,
"MCP server failed to connect (non-fatal)"
);
}
for tool in client.tools() {
self.annotations
.insert(tool.definition.name.clone(), tool.annotations.clone());
}
self.clients.insert(name, client);
}
Err(e) => {
tracing::error!(error = %e, "MCP server connect task panicked");
}
}
}
let connected = self
.clients
.values()
.filter(|c| c.status() == McpClientStatus::Connected)
.count();
let total = self.clients.len();
tracing::info!(connected, total, "MCP server startup complete");
}
pub fn all_tool_definitions(&self) -> Vec<ToolDefinition> {
self.clients
.values()
.filter(|c| c.status() == McpClientStatus::Connected)
.flat_map(|c| c.tools().iter().map(|t| t.definition.clone()))
.collect()
}
pub fn server_instructions(&self) -> Vec<(String, String)> {
let mut out: Vec<(String, String)> = self
.clients
.iter()
.filter(|(_, c)| c.status() == McpClientStatus::Connected)
.filter_map(|(name, c)| {
c.instructions()
.map(|instr| (name.clone(), instr.to_string()))
})
.collect();
out.sort_by(|a, b| a.0.cmp(&b.0));
out
}
pub fn classify_tool(&self, qualified_name: &str) -> ToolEffect {
let annotations = self.annotations.get(qualified_name);
super::tool_bridge::classify_mcp_tool(annotations)
}
pub async fn call_tool(&self, qualified_name: &str, arguments: Value) -> Result<String> {
let (server_name, tool_name) = parse_mcp_tool_name(qualified_name)
.context("invalid MCP tool name format (expected server__tool)")?;
let client = self
.clients
.get(server_name)
.context(format!("MCP server '{server_name}' not found"))?;
if client.status() != McpClientStatus::Connected {
anyhow::bail!(
"MCP server '{server_name}' is not connected (status: {:?})",
client.status()
);
}
let result = client.call_tool(tool_name, arguments).await?;
let output = call_tool_result_to_string(&result);
if result.is_error.unwrap_or(false) {
anyhow::bail!("MCP tool error: {output}");
}
Ok(output)
}
pub fn has_tool(&self, qualified_name: &str) -> bool {
self.annotations.contains_key(qualified_name)
}
pub fn status_summary(&self) -> Vec<McpServerStatus> {
self.clients
.values()
.map(|c| McpServerStatus {
name: c.name().to_string(),
status: c.status(),
tool_count: c.tools().len(),
error: c.last_error().map(String::from),
})
.collect()
}
pub fn status_bar_summary(&self) -> Option<McpStatusBarInfo> {
if self.clients.is_empty() {
return None;
}
let total = self.clients.len();
let connected = self
.clients
.values()
.filter(|c| c.status() == McpClientStatus::Connected)
.count();
let failed = self
.clients
.values()
.filter(|c| c.status() == McpClientStatus::Failed)
.count();
Some(McpStatusBarInfo {
connected,
failed,
total,
})
}
pub async fn reconnect_server(&mut self, name: &str) -> Result<usize> {
let client = self
.clients
.get_mut(name)
.with_context(|| format!("MCP server '{name}' not found"))?;
client.disconnect().await;
let prefix = format!("{name}__");
self.annotations.retain(|k, _| !k.starts_with(&prefix));
client.connect().await?;
for tool in client.tools() {
self.annotations
.insert(tool.definition.name.clone(), tool.annotations.clone());
}
Ok(client.tools().len())
}
pub async fn shutdown(&mut self) {
for client in self.clients.values_mut() {
client.disconnect().await;
}
self.annotations.clear();
tracing::info!("all MCP servers disconnected");
}
pub async fn add_server(&mut self, name: String, config: McpServerConfig) -> Result<()> {
if let Some(mut old) = self.clients.remove(&name) {
old.disconnect().await;
self.annotations.retain(|k, _| {
parse_mcp_tool_name(k)
.map(|(s, _)| s != name)
.unwrap_or(true)
});
}
let mut client = McpClient::new(name.clone(), config);
client.connect().await?;
for tool in client.tools() {
self.annotations
.insert(tool.definition.name.clone(), tool.annotations.clone());
}
self.clients.insert(name, client);
Ok(())
}
pub async fn remove_server(&mut self, name: &str) -> bool {
if let Some(mut client) = self.clients.remove(name) {
client.disconnect().await;
self.annotations.retain(|k, _| {
parse_mcp_tool_name(k)
.map(|(s, _)| s != name)
.unwrap_or(true)
});
tracing::info!(server = %name, "MCP server removed");
true
} else {
false
}
}
pub fn is_empty(&self) -> bool {
self.clients.is_empty()
}
pub fn connected_count(&self) -> usize {
self.clients
.values()
.filter(|c| c.status() == McpClientStatus::Connected)
.count()
}
#[cfg(feature = "test-support")]
pub fn insert_client_for_test(&mut self, client: McpClient) {
let name = client.name().to_string();
for tool in client.tools() {
self.annotations
.insert(tool.definition.name.clone(), tool.annotations.clone());
}
self.clients.insert(name, client);
}
}
/// Per-server snapshot produced by `McpManager::status_summary`.
#[derive(Debug, Clone)]
pub struct McpServerStatus {
/// Server name as reported by the client.
pub name: String,
/// The client's status at the time of the snapshot.
pub status: McpClientStatus,
/// Number of tools the client reports via `tools()`, whatever its status.
pub tool_count: usize,
/// The client's last recorded error, if any.
pub error: Option<String>,
}
/// Aggregate server counts produced by `McpManager::status_bar_summary`.
///
/// Only `Connected` and `Failed` clients are counted individually, so `total`
/// can exceed `connected + failed` when some clients are in another state.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct McpStatusBarInfo {
/// Clients in the `Connected` state.
pub connected: usize,
/// Clients in the `Failed` state.
pub failed: usize,
/// All registered clients.
pub total: usize,
}
/// Flattens an MCP `CallToolResult` into a single string.
///
/// Text blocks are joined with `\n` in their original order. Every other
/// content block becomes a `[non-text content: <kind>]` placeholder (and is
/// logged at debug level). A result with no content yields `"(no output)"`.
fn call_tool_result_to_string(result: &rmcp::model::CallToolResult) -> String {
    use rmcp::model::RawContent;

    let mut parts: Vec<String> = Vec::new();
    for content in &result.content {
        match &content.raw {
            RawContent::Text(text) => parts.push(text.text.clone()),
            other => {
                // `Debug` for `mem::discriminant` only prints an opaque
                // `Discriminant(n)`, so name the well-known kinds explicitly and
                // keep the discriminant purely as a fallback for the rest.
                let kind = match other {
                    RawContent::Image(_) => "image".to_string(),
                    RawContent::Resource(_) => "resource".to_string(),
                    _ => format!("{:?}", std::mem::discriminant(other)),
                };
                tracing::debug!(content_type = %kind, "MCP tool returned non-text content");
                parts.push(format!("[non-text content: {kind}]"));
            }
        }
    }
    if parts.is_empty() {
        "(no output)".to_string()
    } else {
        parts.join("\n")
    }
}
// Unit tests for `McpManager`: annotation-cache bookkeeping, status reporting,
// tool-call error paths, and `call_tool_result_to_string` formatting. No test
// here establishes a real MCP connection.
#[cfg(test)]
mod tests {
use super::super::config::{McpServerConfig, McpTransport};
use super::*;
use std::collections::HashMap;
// Stdio config that runs the `false` command with 1-second timeouts.
// NOTE(review): presumably `false` exits immediately, so any real `connect()`
// attempt fails (the collision test below relies on that).
fn dummy_config() -> McpServerConfig {
McpServerConfig {
transport: McpTransport::Stdio {
command: "false".into(),
args: vec![],
env: HashMap::new(),
cwd: None,
},
startup_timeout_sec: 1,
tool_timeout_sec: 1,
enabled_tools: None,
disabled_tools: None,
}
}
// Two clients: `server_a` in its default (not connected) state and `server_b`
// forced to `Failed` with a recorded last error.
fn manager_with_mixed_clients() -> McpManager {
let mut mgr = McpManager::new();
let c1 = McpClient::new("server_a".into(), dummy_config());
mgr.insert_client_for_test(c1);
let mut c2 = McpClient::new("server_b".into(), dummy_config());
c2.set_status_for_test(McpClientStatus::Failed);
c2.set_last_error_for_test(Some("connection refused".into()));
mgr.insert_client_for_test(c2);
mgr
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn call_tool_on_disconnected_server_returns_err() {
let mgr = manager_with_mixed_clients();
let result = mgr
.call_tool("server_a__some_tool", serde_json::json!({}))
.await;
assert!(result.is_err());
let msg = result.unwrap_err().to_string();
assert!(
msg.contains("not connected"),
"expected 'not connected' in: {msg}"
);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn call_tool_on_nonexistent_server_returns_err() {
let mgr = McpManager::new();
let result = mgr.call_tool("ghost__tool", serde_json::json!({})).await;
assert!(result.is_err());
let msg = result.unwrap_err().to_string();
assert!(msg.contains("not found"), "expected 'not found' in: {msg}");
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn call_tool_with_invalid_name_returns_err() {
let mgr = McpManager::new();
let result = mgr.call_tool("no_separator", serde_json::json!({})).await;
assert!(result.is_err());
let msg = result.unwrap_err().to_string();
assert!(
msg.contains("invalid MCP tool name"),
"expected parse error in: {msg}"
);
}
// Annotations are injected directly into the private cache so the purge is
// tested independently of any tools the client itself reports.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn remove_server_purges_annotations() {
let mut mgr = McpManager::new();
let c = McpClient::new("myserver".into(), dummy_config());
mgr.insert_client_for_test(c);
mgr.annotations
.insert("myserver__list_files".into(), McpToolAnnotations::default());
assert!(mgr.has_tool("myserver__list_files"));
let removed = mgr.remove_server("myserver").await;
assert!(removed);
assert!(
!mgr.has_tool("myserver__list_files"),
"annotation cache must be purged after remove"
);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn remove_nonexistent_server_returns_false() {
let mut mgr = McpManager::new();
assert!(!mgr.remove_server("ghost").await);
}
// The outcome of `add_server` is deliberately ignored: the purge must happen
// whether or not the replacement connects.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn add_server_collision_purges_old_annotations() {
let mut mgr = McpManager::new();
let c = McpClient::new("myserver".into(), dummy_config());
mgr.insert_client_for_test(c);
mgr.annotations
.insert("myserver__old_tool".into(), McpToolAnnotations::default());
assert!(mgr.has_tool("myserver__old_tool"), "precondition");
let result = mgr.add_server("myserver".into(), dummy_config()).await;
let _ = result;
assert!(
!mgr.has_tool("myserver__old_tool"),
"stale annotation must be purged on collision, even if reconnect fails"
);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn reconnect_server_purges_stale_annotations() {
let mut mgr = McpManager::new();
let c = McpClient::new("db".into(), dummy_config());
mgr.insert_client_for_test(c);
mgr.annotations
.insert("db__query".into(), McpToolAnnotations::default());
mgr.annotations
.insert("db__insert".into(), McpToolAnnotations::default());
assert!(mgr.has_tool("db__query"), "precondition");
let _ = mgr.reconnect_server("db").await;
assert!(
!mgr.has_tool("db__query"),
"db__query annotation must be purged after reconnect attempt"
);
assert!(
!mgr.has_tool("db__insert"),
"db__insert annotation must be purged after reconnect attempt"
);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn reconnect_server_nonexistent_returns_err() {
let mut mgr = McpManager::new();
let result = mgr.reconnect_server("ghost").await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("not found"));
}
#[test]
fn status_bar_summary_empty_returns_none() {
let mgr = McpManager::new();
assert!(mgr.status_bar_summary().is_none());
}
#[test]
fn status_bar_summary_all_connected() {
let mut mgr = McpManager::new();
let mut c1 = McpClient::new("a".into(), dummy_config());
c1.set_status_for_test(McpClientStatus::Connected);
mgr.insert_client_for_test(c1);
let mut c2 = McpClient::new("b".into(), dummy_config());
c2.set_status_for_test(McpClientStatus::Connected);
mgr.insert_client_for_test(c2);
let info = mgr.status_bar_summary().unwrap();
assert_eq!(info.connected, 2);
assert_eq!(info.failed, 0);
assert_eq!(info.total, 2);
}
// One default-state client and one failed client: the default-state client
// counts toward `total` only.
#[test]
fn status_bar_summary_partial_failure() {
let mgr = manager_with_mixed_clients();
let info = mgr.status_bar_summary().unwrap();
assert_eq!(info.connected, 0);
assert_eq!(info.failed, 1);
assert_eq!(info.total, 2);
}
#[test]
fn status_bar_summary_all_failed() {
let mut mgr = McpManager::new();
let mut c = McpClient::new("bad".into(), dummy_config());
c.set_status_for_test(McpClientStatus::Failed);
mgr.insert_client_for_test(c);
let info = mgr.status_bar_summary().unwrap();
assert_eq!(info.connected, 0);
assert_eq!(info.failed, 1);
assert_eq!(info.total, 1);
}
#[test]
fn result_to_string_text_content() {
use rmcp::model::{CallToolResult, Content};
let result = CallToolResult::success(vec![Content::text("hello"), Content::text("world")]);
assert_eq!(call_tool_result_to_string(&result), "hello\nworld");
}
#[test]
fn result_to_string_empty_content() {
let result = rmcp::model::CallToolResult::success(vec![]);
assert_eq!(call_tool_result_to_string(&result), "(no output)");
}
#[test]
fn result_to_string_non_text_content() {
use rmcp::model::{CallToolResult, Content};
let result = CallToolResult::success(vec![Content::image("iVBOR", "image/png")]);
let output = call_tool_result_to_string(&result);
assert!(
output.contains("non-text content"),
"expected non-text placeholder in: {output}"
);
}
use super::super::client::DiscoveredTool;
// Builds a `DiscoveredTool` named `qualified_name`; `original_name` is the
// segment after the first `__` (or the whole name if there is none).
fn fake_tool(qualified_name: &str, read_only: Option<bool>) -> DiscoveredTool {
let original = qualified_name.split("__").nth(1).unwrap_or(qualified_name);
DiscoveredTool {
definition: ToolDefinition {
name: qualified_name.into(),
description: "fake test tool".into(),
parameters: serde_json::json!({
"type": "object",
"properties": {}
}),
},
annotations: McpToolAnnotations {
read_only_hint: read_only,
destructive_hint: None,
},
original_name: original.into(),
}
}
// A client forced into `Connected` that reports `tools`, without any real
// connection.
fn connected_client_with_tools(name: &str, tools: Vec<DiscoveredTool>) -> McpClient {
let mut c = McpClient::new(name.into(), dummy_config());
c.set_status_for_test(McpClientStatus::Connected);
c.set_tools_for_test(tools);
c
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn shutdown_clears_all_annotations() {
let mut mgr = McpManager::new();
let c1 = connected_client_with_tools("alpha", vec![fake_tool("alpha__read", Some(true))]);
let c2 = connected_client_with_tools("beta", vec![fake_tool("beta__write", Some(false))]);
mgr.insert_client_for_test(c1);
mgr.insert_client_for_test(c2);
assert!(mgr.has_tool("alpha__read"));
assert!(mgr.has_tool("beta__write"));
mgr.shutdown().await;
assert!(
!mgr.has_tool("alpha__read") && !mgr.has_tool("beta__write"),
"shutdown must purge ALL cached annotations"
);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn shutdown_on_empty_manager_is_noop() {
let mut mgr = McpManager::new();
mgr.shutdown().await;
assert!(mgr.is_empty());
}
// The summary is sorted locally so the assertions do not depend on map order.
#[test]
fn status_summary_reports_per_server_details() {
let mut mgr = McpManager::new();
let c1 = connected_client_with_tools(
"good",
vec![fake_tool("good__t1", None), fake_tool("good__t2", None)],
);
mgr.insert_client_for_test(c1);
let mut c2 = McpClient::new("bad".into(), dummy_config());
c2.set_status_for_test(McpClientStatus::Failed);
c2.set_last_error_for_test(Some("connect timeout".into()));
mgr.insert_client_for_test(c2);
let mut summary = mgr.status_summary();
summary.sort_by(|a, b| a.name.cmp(&b.name));
assert_eq!(summary.len(), 2);
assert_eq!(summary[0].name, "bad");
assert_eq!(summary[0].status, McpClientStatus::Failed);
assert_eq!(summary[0].tool_count, 0);
assert_eq!(summary[0].error.as_deref(), Some("connect timeout"));
assert_eq!(summary[1].name, "good");
assert_eq!(summary[1].status, McpClientStatus::Connected);
assert_eq!(summary[1].tool_count, 2);
assert_eq!(summary[1].error, None);
}
#[test]
fn status_summary_empty_returns_empty_vec() {
let mgr = McpManager::new();
assert!(mgr.status_summary().is_empty());
}
// "zombie" keeps its default status but still reports tools; neither it nor
// the failed client may leak tools into the definitions list.
#[test]
fn all_tool_definitions_excludes_disconnected_and_failed_clients() {
let mut mgr = McpManager::new();
let connected = connected_client_with_tools(
"live",
vec![fake_tool("live__a", None), fake_tool("live__b", None)],
);
mgr.insert_client_for_test(connected);
let mut zombie = McpClient::new("zombie".into(), dummy_config());
zombie.set_tools_for_test(vec![fake_tool("zombie__ghost", None)]);
mgr.insert_client_for_test(zombie);
let mut failed = McpClient::new("broken".into(), dummy_config());
failed.set_status_for_test(McpClientStatus::Failed);
failed.set_tools_for_test(vec![fake_tool("broken__nope", None)]);
mgr.insert_client_for_test(failed);
let defs = mgr.all_tool_definitions();
let names: std::collections::HashSet<&str> = defs.iter().map(|d| d.name.as_str()).collect();
assert_eq!(
defs.len(),
2,
"only the Connected client's tools may be exposed; got {names:?}"
);
assert!(names.contains("live__a") && names.contains("live__b"));
assert!(!names.contains("zombie__ghost"));
assert!(!names.contains("broken__nope"));
}
#[test]
fn classify_tool_uses_cached_annotations_when_present() {
let mut mgr = McpManager::new();
let c = connected_client_with_tools("db", vec![fake_tool("db__select", Some(true))]);
mgr.insert_client_for_test(c);
assert!(mgr.has_tool("db__select"));
let expected = super::super::tool_bridge::classify_mcp_tool(Some(&McpToolAnnotations {
read_only_hint: Some(true),
destructive_hint: None,
}));
assert_eq!(mgr.classify_tool("db__select"), expected);
}
#[test]
fn classify_tool_unknown_falls_back_to_default() {
let mgr = McpManager::new();
let unknown = mgr.classify_tool("does_not__exist");
let expected = super::super::tool_bridge::classify_mcp_tool(None);
assert_eq!(
unknown, expected,
"unknown tool must use the same default classification as a `None` annotation"
);
}
#[test]
fn is_empty_and_connected_count_reflect_state() {
let mut mgr = McpManager::new();
assert!(mgr.is_empty());
assert_eq!(mgr.connected_count(), 0);
let zombie = McpClient::new("zombie".into(), dummy_config());
mgr.insert_client_for_test(zombie);
assert!(!mgr.is_empty());
assert_eq!(mgr.connected_count(), 0);
let live = connected_client_with_tools("live", vec![]);
mgr.insert_client_for_test(live);
assert_eq!(mgr.connected_count(), 1);
let live2 = connected_client_with_tools("live2", vec![]);
mgr.insert_client_for_test(live2);
assert_eq!(mgr.connected_count(), 2);
}
// Regression guard: "db" is a string prefix of "db_archive", so purging must
// match on the parsed server name rather than a raw prefix.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn remove_server_does_not_touch_other_servers_annotations() {
let mut mgr = McpManager::new();
let s1 = connected_client_with_tools("db", vec![fake_tool("db__select", None)]);
let s2 =
connected_client_with_tools("db_archive", vec![fake_tool("db_archive__list", None)]);
mgr.insert_client_for_test(s1);
mgr.insert_client_for_test(s2);
assert!(mgr.has_tool("db__select"));
assert!(mgr.has_tool("db_archive__list"));
let removed = mgr.remove_server("db").await;
assert!(removed);
assert!(
!mgr.has_tool("db__select"),
"removed server's tool must be gone"
);
assert!(
mgr.has_tool("db_archive__list"),
"sibling server's tool MUST survive removal of a name-prefix-sharing server"
);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn add_server_collision_does_not_touch_prefix_neighbour() {
let mut mgr = McpManager::new();
let s1 = connected_client_with_tools("db", vec![fake_tool("db__old", None)]);
let s2 =
connected_client_with_tools("db_archive", vec![fake_tool("db_archive__keep", None)]);
mgr.insert_client_for_test(s1);
mgr.insert_client_for_test(s2);
let _ = mgr.add_server("db".into(), dummy_config()).await;
assert!(!mgr.has_tool("db__old"));
assert!(
mgr.has_tool("db_archive__keep"),
"add_server collision must not purge sibling-prefix server's tools"
);
}
#[test]
fn server_instructions_empty_when_no_clients() {
let mgr = McpManager::new();
assert!(mgr.server_instructions().is_empty());
}
#[test]
fn server_instructions_skips_unconnected_servers() {
let mut mgr = McpManager::new();
let mut c = McpClient::new("halfdead".into(), dummy_config());
c.set_status_for_test(McpClientStatus::Failed);
c.set_instructions_for_test(Some("will be ignored".into()));
mgr.insert_client_for_test(c);
assert!(
mgr.server_instructions().is_empty(),
"only Connected servers should contribute instructions"
);
}
#[test]
fn server_instructions_skips_connected_without_instructions() {
let mut mgr = McpManager::new();
let mut c = McpClient::new("silent".into(), dummy_config());
c.set_status_for_test(McpClientStatus::Connected);
mgr.insert_client_for_test(c);
assert!(mgr.server_instructions().is_empty());
}
// Servers are inserted out of order; the result must come back sorted by name.
#[test]
fn server_instructions_returns_connected_with_instructions_sorted() {
let mut mgr = McpManager::new();
for (name, instr) in [
("zebra", "Z guidance"),
("alpha", "A guidance"),
("middle", "M guidance"),
] {
let mut c = McpClient::new(name.into(), dummy_config());
c.set_status_for_test(McpClientStatus::Connected);
c.set_instructions_for_test(Some(instr.into()));
mgr.insert_client_for_test(c);
}
let result = mgr.server_instructions();
assert_eq!(
result,
vec![
("alpha".to_string(), "A guidance".to_string()),
("middle".to_string(), "M guidance".to_string()),
("zebra".to_string(), "Z guidance".to_string()),
],
"results must be sorted by server name for stable prompt rendering"
);
}
}