use std::collections::HashMap;
use std::sync::Arc;
#[cfg(not(target_has_atomic = "64"))]
use std::sync::atomic::AtomicU32;
#[cfg(target_has_atomic = "64")]
use std::sync::atomic::AtomicU64;
use std::sync::atomic::Ordering;
use anyhow::{Context, Result, anyhow, bail};
use serde_json::json;
use tokio::sync::Mutex;
use tokio::time::{Duration, timeout};
use crate::config::schema::McpServerConfig;
use crate::tools::mcp_protocol::{
JsonRpcRequest, MCP_PROTOCOL_VERSION, McpToolDef, McpToolsListResult,
};
use crate::tools::mcp_transport::{McpTransportConn, create_transport};
/// Seconds to wait for the server's reply to each handshake request
/// (`initialize` and `tools/list`) issued by `McpServer::connect`.
const RECV_TIMEOUT_SECS: u64 = 30;
/// Per-call tool timeout, in seconds, used by `McpServer::call_tool` when the
/// server config does not set `tool_timeout_secs`.
const DEFAULT_TOOL_TIMEOUT_SECS: u64 = 180;
/// Upper bound, in seconds, applied to the tool-call timeout; larger
/// configured values are clamped down to this.
const MAX_TOOL_TIMEOUT_SECS: u64 = 600;
/// State of a single MCP server connection. `McpServer` stores it behind an
/// `Arc<Mutex<..>>`, so every access goes through that lock.
struct McpServerInner {
/// Configuration this connection was created from.
config: McpServerConfig,
/// Transport used to exchange JSON-RPC messages with the server.
transport: Box<dyn McpTransportConn>,
/// Source of JSON-RPC request ids for tool calls (64-bit atomic where the
/// target supports one).
#[cfg(target_has_atomic = "64")]
next_id: AtomicU64,
/// Source of JSON-RPC request ids for tool calls (32-bit fallback for
/// targets without 64-bit atomics).
#[cfg(not(target_has_atomic = "64"))]
next_id: AtomicU32,
/// Tool definitions parsed from the `tools/list` response during
/// `McpServer::connect`.
tools: Vec<McpToolDef>,
}
/// Handle to one connected MCP server.
///
/// Cloning copies only the `Arc`, so all clones refer to the same underlying
/// connection state and contend on the same async mutex.
#[derive(Clone)]
pub struct McpServer {
/// Shared, mutex-guarded connection state.
inner: Arc<Mutex<McpServerInner>>,
}
impl McpServer {
pub async fn connect(config: McpServerConfig) -> Result<Self> {
let mut transport = create_transport(&config).with_context(|| {
format!(
"failed to create transport for MCP server `{}`",
config.name
)
})?;
let id = 1u64;
let init_req = JsonRpcRequest::new(
id,
"initialize",
json!({
"protocolVersion": MCP_PROTOCOL_VERSION,
"capabilities": {},
"clientInfo": {
"name": "zeroclaw",
"version": env!("CARGO_PKG_VERSION")
}
}),
);
let init_resp = timeout(
Duration::from_secs(RECV_TIMEOUT_SECS),
transport.send_and_recv(&init_req),
)
.await
.with_context(|| {
format!(
"MCP server `{}` timed out after {}s waiting for initialize response",
config.name, RECV_TIMEOUT_SECS
)
})??;
if init_resp.error.is_some() {
bail!(
"MCP server `{}` rejected initialize: {:?}",
config.name,
init_resp.error
);
}
let notif = JsonRpcRequest::notification("notifications/initialized", json!({}));
let _ = transport.send_and_recv(¬if).await;
let id = 2u64;
let list_req = JsonRpcRequest::new(id, "tools/list", json!({}));
let list_resp = timeout(
Duration::from_secs(RECV_TIMEOUT_SECS),
transport.send_and_recv(&list_req),
)
.await
.with_context(|| {
format!(
"MCP server `{}` timed out after {}s waiting for tools/list response",
config.name, RECV_TIMEOUT_SECS
)
})??;
let result = list_resp
.result
.ok_or_else(|| anyhow!("tools/list returned no result from `{}`", config.name))?;
let tool_list: McpToolsListResult = serde_json::from_value(result)
.with_context(|| format!("failed to parse tools/list from `{}`", config.name))?;
let tool_count = tool_list.tools.len();
let inner = McpServerInner {
config,
transport,
#[cfg(target_has_atomic = "64")]
next_id: AtomicU64::new(3), #[cfg(not(target_has_atomic = "64"))]
next_id: AtomicU32::new(3), tools: tool_list.tools,
};
tracing::info!(
"MCP server `{}` connected — {} tool(s) available",
inner.config.name,
tool_count
);
Ok(Self {
inner: Arc::new(Mutex::new(inner)),
})
}
pub async fn tools(&self) -> Vec<McpToolDef> {
self.inner.lock().await.tools.clone()
}
pub async fn name(&self) -> String {
self.inner.lock().await.config.name.clone()
}
pub async fn call_tool(
&self,
tool_name: &str,
arguments: serde_json::Value,
) -> Result<serde_json::Value> {
let mut inner = self.inner.lock().await;
let id = inner.next_id.fetch_add(1, Ordering::Relaxed) as u64;
let req = JsonRpcRequest::new(
id,
"tools/call",
json!({ "name": tool_name, "arguments": arguments }),
);
let tool_timeout = inner
.config
.tool_timeout_secs
.unwrap_or(DEFAULT_TOOL_TIMEOUT_SECS)
.min(MAX_TOOL_TIMEOUT_SECS);
let resp = timeout(
Duration::from_secs(tool_timeout),
inner.transport.send_and_recv(&req),
)
.await
.map_err(|_| {
anyhow!(
"MCP server `{}` timed out after {}s during tool call `{tool_name}`",
inner.config.name,
tool_timeout
)
})?
.with_context(|| {
format!(
"MCP server `{}` error during tool call `{tool_name}`",
inner.config.name
)
})?;
if let Some(err) = resp.error {
bail!("MCP tool `{tool_name}` error {}: {}", err.code, err.message);
}
Ok(resp.result.unwrap_or(serde_json::Value::Null))
}
}
/// Collection of connected MCP servers plus a lookup table from prefixed
/// tool names to the server that provides each tool.
pub struct McpRegistry {
/// Servers that connected successfully, in the order they were connected.
servers: Vec<McpServer>,
/// Maps `"{server_name}__{tool_name}"` to `(index into servers, original
/// tool name)`.
tool_index: HashMap<String, (usize, String)>,
}
impl McpRegistry {
    /// Connects to every server in `configs`, one after another, and indexes
    /// the tools of those that connect successfully under the name
    /// `"{server_name}__{tool_name}"`.
    ///
    /// A server that fails to connect is logged and skipped; it does not
    /// make this function fail. If two registrations produce the same
    /// prefixed name, the later one replaces the earlier one and a warning
    /// is logged.
    ///
    /// # Errors
    ///
    /// As written this always returns `Ok`; the `Result` return type is kept
    /// for interface compatibility.
    pub async fn connect_all(configs: &[McpServerConfig]) -> Result<Self> {
        let mut servers = Vec::with_capacity(configs.len());
        let mut tool_index: HashMap<String, (usize, String)> = HashMap::new();
        for config in configs {
            match McpServer::connect(config.clone()).await {
                Ok(server) => {
                    let server_idx = servers.len();
                    let tools = server.tools().await;
                    for tool in &tools {
                        let prefixed = format!("{}__{}", config.name, tool.name);
                        let replaced =
                            tool_index.insert(prefixed, (server_idx, tool.name.clone()));
                        if replaced.is_some() {
                            // Previously overwritten silently; keep the
                            // last-wins behavior but make it visible.
                            tracing::warn!(
                                "duplicate MCP tool name `{}__{}`; later registration overrides the earlier one",
                                config.name,
                                tool.name
                            );
                        }
                    }
                    servers.push(server);
                }
                Err(e) => {
                    tracing::error!("Failed to connect to MCP server `{}`: {:#}", config.name, e);
                }
            }
        }
        Ok(Self {
            servers,
            tool_index,
        })
    }

    /// Returns all prefixed tool names, in the hash map's arbitrary order.
    pub fn tool_names(&self) -> Vec<String> {
        self.tool_index.keys().cloned().collect()
    }

    /// Looks up the definition of the tool registered as `prefixed_name`.
    ///
    /// Returns `None` when the name is not in the index or the owning server
    /// has no cached tool with the recorded original name.
    pub async fn get_tool_def(&self, prefixed_name: &str) -> Option<McpToolDef> {
        let (server_idx, original_name) = self.tool_index.get(prefixed_name)?;
        // Indices are assigned in `connect_all` and no method in this impl
        // changes `servers` afterwards, so the index is in bounds.
        let inner = self.servers[*server_idx].inner.lock().await;
        inner
            .tools
            .iter()
            .find(|t| &t.name == original_name)
            .cloned()
    }

    /// Calls the tool registered as `prefixed_name` with `arguments` and
    /// returns its result pretty-printed as JSON text.
    ///
    /// # Errors
    ///
    /// Returns an error if the name is unknown, if the underlying
    /// [`McpServer::call_tool`] fails, or if the result cannot be serialized.
    pub async fn call_tool(
        &self,
        prefixed_name: &str,
        arguments: serde_json::Value,
    ) -> Result<String> {
        let (server_idx, original_name) = self
            .tool_index
            .get(prefixed_name)
            .ok_or_else(|| anyhow!("unknown MCP tool `{prefixed_name}`"))?;
        // See `get_tool_def` for why this index is in bounds.
        let result = self.servers[*server_idx]
            .call_tool(original_name, arguments)
            .await?;
        serde_json::to_string_pretty(&result)
            .with_context(|| format!("failed to serialize result of MCP tool `{prefixed_name}`"))
    }

    /// Returns `true` when no server is connected.
    pub fn is_empty(&self) -> bool {
        self.servers.is_empty()
    }

    /// Returns the number of connected servers.
    pub fn server_count(&self) -> usize {
        self.servers.len()
    }

    /// Returns the number of entries in the prefixed tool index.
    pub fn tool_count(&self) -> usize {
        self.tool_index.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::schema::McpTransport;

    /// Builds a stdio server config with the given name and command and every
    /// other field empty/unset.
    fn stdio_config(name: &str, command: &str) -> McpServerConfig {
        McpServerConfig {
            name: name.to_string(),
            command: command.to_string(),
            args: vec![],
            env: Default::default(),
            tool_timeout_secs: None,
            transport: McpTransport::Stdio,
            url: None,
            headers: Default::default(),
        }
    }

    /// Builds a config named `test` for `transport`, leaving everything else
    /// (including the URL) at its default.
    fn default_config_for(transport: McpTransport) -> McpServerConfig {
        McpServerConfig {
            name: "test".into(),
            transport,
            ..Default::default()
        }
    }

    /// Runs `connect_all` with no server configs.
    async fn connect_nothing() -> Result<McpRegistry> {
        McpRegistry::connect_all(&[]).await
    }

    /// The prefixed tool name is `<server>__<tool>`.
    #[test]
    fn tool_name_prefix_format() {
        let (server, tool) = ("filesystem", "read_file");
        assert_eq!(format!("{}__{}", server, tool), "filesystem__read_file");
    }

    /// Connecting with a command that cannot be launched reports a
    /// transport-creation error.
    #[tokio::test]
    async fn connect_nonexistent_command_fails_cleanly() {
        let outcome = McpServer::connect(stdio_config(
            "nonexistent",
            "/usr/bin/this_binary_does_not_exist_zeroclaw_test",
        ))
        .await;
        assert!(outcome.is_err());
        let msg = outcome.err().unwrap().to_string();
        assert!(msg.contains("failed to create transport"), "got: {msg}");
    }

    /// A server that fails to connect is skipped instead of failing the
    /// whole registry.
    #[tokio::test]
    async fn connect_all_nonfatal_on_single_failure() {
        let configs = vec![stdio_config("bad", "/usr/bin/does_not_exist_zc_test")];
        let registry = McpRegistry::connect_all(&configs)
            .await
            .expect("connect_all should not fail");
        assert!(registry.is_empty());
        assert_eq!(registry.tool_count(), 0);
    }

    /// The HTTP transport cannot be created without a URL.
    #[test]
    fn http_transport_requires_url() {
        let config = default_config_for(McpTransport::Http);
        assert!(create_transport(&config).is_err());
    }

    /// The SSE transport cannot be created without a URL.
    #[test]
    fn sse_transport_requires_url() {
        let config = default_config_for(McpTransport::Sse);
        assert!(create_transport(&config).is_err());
    }

    /// A registry built from no configs has no servers and no tools.
    #[tokio::test]
    async fn empty_registry_is_empty() {
        let registry = connect_nothing()
            .await
            .expect("connect_all on empty slice should succeed");
        assert!(registry.is_empty());
        assert_eq!(registry.server_count(), 0);
        assert_eq!(registry.tool_count(), 0);
    }

    /// A registry built from no configs lists no tool names.
    #[tokio::test]
    async fn empty_registry_tool_names_is_empty() {
        let registry = connect_nothing()
            .await
            .expect("connect_all should succeed");
        assert!(registry.tool_names().is_empty());
    }

    /// Looking up a tool definition in an empty registry yields `None`.
    #[tokio::test]
    async fn empty_registry_get_tool_def_returns_none() {
        let registry = connect_nothing()
            .await
            .expect("connect_all should succeed");
        let found = registry.get_tool_def("nonexistent__tool").await;
        assert!(found.is_none());
    }

    /// Calling a tool that is not in the index fails with an
    /// "unknown MCP tool" error.
    #[tokio::test]
    async fn empty_registry_call_tool_unknown_name_returns_error() {
        let registry = connect_nothing()
            .await
            .expect("connect_all should succeed");
        let err = registry
            .call_tool("nonexistent__tool", serde_json::json!({}))
            .await
            .expect_err("should fail for unknown tool");
        assert!(err.to_string().contains("unknown MCP tool"), "got: {err}");
    }

    /// `connect_all` on an empty slice reports zero servers and zero tools.
    #[tokio::test]
    async fn connect_all_empty_gives_zero_servers() {
        let registry = connect_nothing()
            .await
            .expect("connect_all should succeed");
        assert_eq!(registry.server_count(), 0);
        assert_eq!(registry.tool_count(), 0);
        assert!(registry.is_empty());
    }
}