use std::collections::HashMap;
use std::sync::Arc;
use serde_json::Value;
use tokio::sync::Mutex;
use tracing::{info, warn};
use super::client::McpClient;
use super::config::{self, ResolvedMcpServer};
use super::protocol::McpServerInfo;
/// Maximum number of MCP tools whose full definitions are sent eagerly.
///
/// When the total number of registered MCP tools is strictly greater than this
/// value, `McpManager::is_deferred_mode` returns `true`. In that case
/// `eager_tool_definitions` returns nothing, and the overview lists the tools
/// as deferred, to be fetched via tool search.
pub const MCP_EAGER_THRESHOLD: usize = 10;
/// Routing entry that maps a prefixed tool name (`mcp__{server}__{tool}`) back
/// to the server that owns it and to the tool's name as that server knows it.
#[derive(Debug, Clone)]
struct ToolRoute {
    /// Key into `McpManager::clients` used to pick the client for a call.
    server_name: String,
    /// Unprefixed tool name passed to the server's `call_tool`.
    original_name: String,
}
/// Per-server summary collected at connect time and used to render the
/// MCP server overview.
#[derive(Debug, Clone)]
pub struct ServerMeta {
    /// Text printed under the server heading in the overview (omitted when empty).
    /// NOTE(review): populated from the name the server reports during
    /// initialization, not from a configured description — confirm intent.
    pub description: String,
    /// Optional instructions reported by the server during initialization.
    pub instructions: Option<String>,
    /// Tool count shown in the server's overview heading.
    pub tool_count: usize,
    /// Prefixed names (`mcp__{server}__{tool}`) of this server's tools.
    pub tool_names: Vec<String>,
}
/// Owns the connected MCP clients and the tool registry built from them.
///
/// Tools from every server are exposed under prefixed names of the form
/// `mcp__{server}__{tool}`, which namespaces them by server and lets calls be
/// routed back to the owning client.
pub struct McpManager {
    /// Connected clients keyed by server name. Each client sits behind a mutex
    /// that is held for the duration of a call, so calls to one server are serialized.
    clients: HashMap<String, Arc<Mutex<McpClient>>>,
    /// Prefixed tool name -> owning server and original tool name.
    routes: HashMap<String, ToolRoute>,
    /// Tool definitions with `function.name` rewritten to the prefixed name.
    tool_defs: Vec<Value>,
    /// Per-server metadata keyed by server name.
    server_meta: HashMap<String, ServerMeta>,
}
impl McpManager {
pub fn empty() -> Self {
Self {
clients: HashMap::new(),
routes: HashMap::new(),
tool_defs: Vec::new(),
server_meta: HashMap::new(),
}
}
pub async fn connect_all(working_dir: &str) -> Self {
let servers = config::load_mcp_servers(working_dir).unwrap_or_default();
let handles: Vec<_> = servers
.into_iter()
.map(|server| {
tokio::spawn(async move {
let result = connect_server(&server).await;
(
server.name.clone(),
server.source.clone(),
server.description.clone(),
result,
)
})
})
.collect();
let mut clients = HashMap::new();
let mut routes = HashMap::new();
let mut tool_defs = Vec::new();
let mut server_meta = HashMap::new();
for handle in handles {
let (server_name, server_source, server_description, result) = match handle.await {
Ok(r) => r,
Err(e) => {
warn!(error = %e, "MCP server task panicked");
continue;
}
};
match result {
Ok((client, tools, mcp_info)) => {
info!(
server = %server_name,
tool_count = tools.len(),
source = ?server_source,
description = ?server_description,
"MCP server connected"
);
let mut tool_names = Vec::new();
for def in &tools {
let original_name =
def["function"]["name"].as_str().unwrap_or("").to_string();
let prefixed = format!("mcp__{}__{}", server_name, original_name);
routes.insert(
prefixed.clone(),
ToolRoute {
server_name: server_name.clone(),
original_name: original_name.clone(),
},
);
tool_names.push(prefixed.clone());
let mut patched = def.clone();
if let Some(func) = patched.get_mut("function") {
func["name"] = Value::String(prefixed);
}
tool_defs.push(patched);
}
server_meta.insert(
server_name.clone(),
ServerMeta {
description: mcp_info.name.clone(),
instructions: mcp_info.instructions.clone(),
tool_count: tools.len(),
tool_names,
},
);
clients.insert(server_name, Arc::new(Mutex::new(client)));
}
Err(e) => {
warn!(server = %server_name, error = %e, "Failed to connect MCP server");
}
}
}
Self {
clients,
routes,
tool_defs,
server_meta,
}
}
pub fn tool_definitions(&self) -> &[Value] {
&self.tool_defs
}
pub fn eager_tool_definitions(&self) -> Vec<Value> {
if self.is_deferred_mode() {
Vec::new()
} else {
self.tool_defs.clone()
}
}
pub fn is_deferred_mode(&self) -> bool {
self.total_tool_count() > MCP_EAGER_THRESHOLD
}
pub fn total_tool_count(&self) -> usize {
self.tool_defs.len()
}
pub fn is_mcp_tool(&self, name: &str) -> bool {
self.routes.contains_key(name)
}
pub fn search_tools(&self, query: &str) -> Vec<Value> {
let query_lower = query.to_lowercase();
let keywords: Vec<&str> = query_lower.split_whitespace().collect();
self.tool_defs
.iter()
.filter(|def| {
let name = def["function"]["name"]
.as_str()
.unwrap_or("")
.to_lowercase();
let desc = def["function"]["description"]
.as_str()
.unwrap_or("")
.to_lowercase();
let haystack = format!("{} {}", name, desc);
keywords.iter().all(|kw| haystack.contains(kw))
})
.cloned()
.collect()
}
pub fn server_overview(&self) -> String {
if self.server_meta.is_empty() {
return String::new();
}
let mut out = String::from("## MCP Servers\n\n");
for (name, meta) in &self.server_meta {
out.push_str(&format!("### {name} ({} tools)\n", meta.tool_count,));
if !meta.description.is_empty() {
out.push_str(&format!("{}\n", meta.description));
}
if let Some(ref instructions) = meta.instructions {
out.push_str(&format!("\n{instructions}\n"));
}
out.push('\n');
}
if self.is_deferred_mode() {
out.push_str("<available-deferred-tools>\n");
for (server_name, meta) in &self.server_meta {
for tool_name in &meta.tool_names {
let desc = self.tool_defs.iter().find_map(|def| {
let n = def["function"]["name"].as_str()?;
if n == tool_name {
def["function"]["description"].as_str()
} else {
None
}
});
if let Some(d) = desc {
out.push_str(&format!("- {tool_name}: {d}\n"));
} else {
out.push_str(&format!("- {tool_name} (server: {server_name})\n"));
}
}
}
out.push_str("</available-deferred-tools>\n\n");
out.push_str(
"Use the `tool_search` tool to fetch full schemas for deferred tools before calling them.\n",
);
}
out
}
pub fn server_meta(&self) -> &HashMap<String, ServerMeta> {
&self.server_meta
}
pub async fn call_tool(
&self,
prefixed_name: &str,
arguments: &str,
) -> crate::common::Result<String> {
let route = self.routes.get(prefixed_name).ok_or_else(|| {
crate::common::AgentError::InvalidArgument(format!(
"Unknown MCP tool: {}",
prefixed_name
))
})?;
let client_lock = self.clients.get(&route.server_name).ok_or_else(|| {
crate::common::AgentError::Internal(format!(
"MCP server '{}' not connected",
route.server_name
))
})?;
let args: Value = if arguments.trim().is_empty() {
Value::Object(Default::default())
} else {
match serde_json::from_str(arguments) {
Ok(v) => v,
Err(e) => {
warn!(
tool = %prefixed_name,
arguments = %arguments,
error = %e,
"Failed to parse MCP tool arguments — sending empty object"
);
Value::Object(Default::default())
}
}
};
let mut client = client_lock.lock().await;
client
.call_tool(&route.original_name, args)
.await
.map_err(|e| crate::common::AgentError::Internal(format!("MCP call failed: {e}")))
}
pub async fn shutdown_all(&self) {
for (name, client_lock) in &self.clients {
let mut client = client_lock.lock().await;
if let Err(e) = client.shutdown().await {
warn!(server = %name, error = %e, "MCP server shutdown error");
}
}
}
pub async fn child_pids(&self) -> Vec<u32> {
let mut pids = Vec::new();
for client in self.clients.values() {
if let Some(pid) = client.lock().await.pid() {
pids.push(pid);
}
}
pids
}
pub fn server_count(&self) -> usize {
self.clients.len()
}
}
/// Connects to one configured MCP server, runs the `initialize` handshake and
/// fetches its tool list.
///
/// A configured `command` (stdio transport) takes precedence over a `url`
/// (HTTP transport). On success, returns:
/// - the connected client;
/// - its tool definitions;
/// - the server info reported by `initialize`.
///
/// # Errors
/// Fails in any of these cases:
/// - the server has neither `command` nor `url`;
/// - the transport cannot be created;
/// - `initialize` (30 s limit) or `list_tools` (15 s limit) errors or times out.
async fn connect_server(
    server: &ResolvedMcpServer,
) -> anyhow::Result<(McpClient, Vec<Value>, McpServerInfo)> {
    const INITIALIZE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);
    const LIST_TOOLS_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(15);

    // Choose the transport; a command wins even when a URL is also configured.
    let mut client = match (&server.command, &server.url) {
        (Some(cmd), _) => {
            let argv: Vec<&str> = server.args.iter().map(|arg| arg.as_str()).collect();
            if server.env.is_empty() {
                McpClient::connect_stdio(cmd, &argv)?
            } else {
                McpClient::connect_stdio_with_env(cmd, &argv, &server.env)?
            }
        }
        (None, Some(url)) => {
            if server.headers.is_empty() {
                McpClient::connect_http(url)?
            } else {
                McpClient::connect_http_with_headers(url, &server.headers)?
            }
        }
        (None, None) => {
            anyhow::bail!("MCP server '{}' has neither command nor url", server.name)
        }
    };

    if let Some(url) = client.target_url() {
        tracing::debug!(server = %server.name, url = %url, "MCP HTTP transport target");
    }

    let info = match tokio::time::timeout(INITIALIZE_TIMEOUT, client.initialize()).await {
        Ok(outcome) => outcome?,
        Err(_) => anyhow::bail!("MCP initialize timed out for '{}'", server.name),
    };

    match tokio::time::timeout(LIST_TOOLS_TIMEOUT, client.list_tools()).await {
        Ok(listed) => {
            listed?;
        }
        Err(_) => anyhow::bail!("MCP list_tools timed out for '{}'", server.name),
    }

    let defs = client.tool_definitions();
    let raw_tool_count = client.raw_tools().len();
    if let Some(instructions) = client.server_instructions() {
        tracing::debug!(
            server = %server.name,
            raw_tools = raw_tool_count,
            instructions_len = instructions.len(),
            "MCP server has system instructions"
        );
    }
    Ok((client, defs, info))
}