use super::protocol::{McpMessage, McpRequest, RequestId, ToolDefinition, ToolResult};
#[cfg(unix)]
use super::transport::UnixSocketTransport;
use super::transport::{McpTransport, McpTransportType, StdioTransport, TcpTransport};
use anyhow::{anyhow, Result};
use parking_lot::RwLock as ParkingRwLock;
use serde_json::Value;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use tracing::{debug, info, warn};
/// Configuration describing one MCP server that an [`McpClientManager`] can connect to.
#[derive(Debug, Clone)]
pub struct McpServerConfig {
/// Server name; the manager uses it as the key for both the stored configuration and the live connection.
pub name: String,
/// Transport used to reach the server (stdio child process, TCP, or — on Unix targets — a Unix socket).
pub transport: McpTransportType,
/// NOTE(review): not read by any code visible here; presumably enables automatic reconnection elsewhere — confirm.
pub auto_reconnect: bool,
/// NOTE(review): not read by any code visible here; presumably the delay between reconnect attempts, in milliseconds — confirm.
pub reconnect_interval_ms: u64,
}
/// State kept for a server whose connection handshake completed.
struct ConnectedServer {
// Stored when the connection is established; nothing visible here reads it back, hence the lint allowance.
#[allow(dead_code)]
config: McpServerConfig,
/// Shared handle to the open transport; cloned out of the map so calls can run without holding the lock.
transport: Arc<dyn McpTransport>,
/// Tools attributed to this server: filled from the `tools/list` reply and replaced by `register_tools`.
tools: Vec<ToolDefinition>,
}
/// Registry of MCP server configurations, live connections, and tool routing.
///
/// All state sits behind `parking_lot` read/write locks, so every method takes `&self`.
pub struct McpClientManager {
/// Known server configurations, keyed by server name.
configs: ParkingRwLock<HashMap<String, McpServerConfig>>,
/// Currently connected servers, keyed by server name.
servers: ParkingRwLock<HashMap<String, ConnectedServer>>,
/// Routing table from tool name to the name of the server that handles it.
tool_mapping: ParkingRwLock<HashMap<String, String>>,
/// Source of JSON-RPC request ids; starts at 1 and is incremented for every request.
request_id_counter: AtomicU64,
}
impl Default for McpClientManager {
fn default() -> Self {
Self::new()
}
}
impl McpClientManager {
/// Creates a manager with no configured servers, no connections, and no tool routes.
///
/// Request ids handed out by the new manager start at 1.
pub fn new() -> Self {
    let configs = ParkingRwLock::new(HashMap::new());
    let servers = ParkingRwLock::new(HashMap::new());
    let tool_mapping = ParkingRwLock::new(HashMap::new());
    let request_id_counter = AtomicU64::new(1);
    Self {
        configs,
        servers,
        tool_mapping,
        request_id_counter,
    }
}
/// Returns the next numeric request id and advances the shared counter by one.
fn next_request_id(&self) -> RequestId {
    let raw_id = self.request_id_counter.fetch_add(1, Ordering::SeqCst);
    // The counter is unsigned while the protocol id is signed; the cast is kept as-is.
    RequestId::Number(raw_id as i64)
}
/// Stores `config` under its own `name`, replacing any configuration already
/// registered under that name.
///
/// This only records the configuration; it does not open a connection.
/// The current implementation always returns `Ok(())`.
pub async fn add_server(&self, config: McpServerConfig) -> Result<()> {
    let key = config.name.clone();
    self.configs.write().insert(key, config);
    Ok(())
}
pub async fn connect(&self, name: &str) -> Result<()> {
let config = {
let configs = self.configs.read();
configs
.get(name)
.ok_or_else(|| anyhow!("Server not found: {}", name))?
.clone()
};
info!(server = %name, transport = ?config.transport, "Connecting to MCP server");
let transport: Arc<dyn McpTransport> = match &config.transport {
McpTransportType::Stdio { command, args } => {
let stdio_transport = StdioTransport::new(command, args)?;
stdio_transport.start(command, args).await?;
info!(server = %name, command = %command, "Stdio transport started");
Arc::new(stdio_transport)
}
McpTransportType::Tcp { addr } => {
let tcp_transport = TcpTransport::connect(addr).await?;
info!(server = %name, addr = %addr, "TCP transport connected");
Arc::new(tcp_transport)
}
#[cfg(unix)]
McpTransportType::Unix { path } => {
let unix_transport = UnixSocketTransport::connect(path).await?;
info!(server = %name, path = %path, "Unix socket transport connected");
Arc::new(unix_transport)
}
};
let init_params = serde_json::json!({
"protocolVersion": "2024-11-05",
"capabilities": {
"roots": {
"listChanged": true
},
"sampling": {}
},
"clientInfo": {
"name": "continuum",
"version": env!("CARGO_PKG_VERSION")
}
});
let request = McpRequest {
id: self.next_request_id(),
method: "initialize".to_string(),
params: Some(init_params),
};
debug!(server = %name, "Sending initialize request");
transport.send(&McpMessage::Request(request)).await?;
let response =
tokio::time::timeout(std::time::Duration::from_secs(30), transport.receive())
.await
.map_err(|_| anyhow!("Initialize timeout for server: {}", name))??;
match response {
Some(McpMessage::Response(response)) => {
if let Some(error) = &response.error {
warn!(server = %name, code = ?error.code, message = %error.message, "Initialize failed");
return Err(anyhow!(
"Initialize failed (code {}): {}",
error.code,
error.message
));
}
if let Some(result) = &response.result {
debug!(server = %name, result = ?result, "Server capabilities received");
if let Some(server_info) = result.get("serverInfo") {
info!(server = %name, server_info = ?server_info, "Connected to MCP server");
}
}
}
Some(McpMessage::Error(error)) => {
warn!(server = %name, error = ?error, "Received error response");
return Err(anyhow!("Server error: {:?}", error));
}
Some(other) => {
warn!(server = %name, message = ?other, "Unexpected message type");
return Err(anyhow!("Unexpected response type during initialization"));
}
None => {
warn!(server = %name, "No response received");
return Err(anyhow!("No response from server during initialization"));
}
}
let notification = McpMessage::Notification(super::protocol::McpNotification {
method: "notifications/initialized".to_string(),
params: None,
});
transport.send(¬ification).await?;
debug!(server = %name, "Sent initialized notification");
let list_tools_request = McpRequest {
id: self.next_request_id(),
method: "tools/list".to_string(),
params: None,
};
transport
.send(&McpMessage::Request(list_tools_request))
.await?;
let tools_response =
tokio::time::timeout(std::time::Duration::from_secs(10), transport.receive())
.await
.map_err(|_| anyhow!("Tools list timeout for server: {}", name))??;
let tools = match tools_response {
Some(McpMessage::Response(response)) => {
if let Some(result) = &response.result {
if let Some(tools_array) = result.get("tools") {
match serde_json::from_value::<Vec<ToolDefinition>>(tools_array.clone()) {
Ok(t) => {
info!(server = %name, tool_count = t.len(), "Tools discovered");
t
}
Err(e) => {
warn!(server = %name, error = %e, "Failed to parse tools");
Vec::new()
}
}
} else {
Vec::new()
}
} else {
Vec::new()
}
}
_ => Vec::new(),
};
let mut servers = self.servers.write();
servers.insert(
name.to_string(),
ConnectedServer {
config,
transport,
tools,
},
);
info!(server = %name, "MCP connection established successfully");
Ok(())
}
pub async fn connect_all(&self) -> Result<Vec<String>> {
let names: Vec<String> = self.configs.read().keys().cloned().collect();
let mut results = Vec::new();
for name in &names {
if self.connect(name).await.is_ok() {
results.push(name.clone());
}
}
Ok(results)
}
/// Removes the connection for `name`, drops every tool route that points at
/// it, and closes its transport.
///
/// Disconnecting a server that is not connected is a no-op that returns
/// `Ok(())`.
///
/// # Errors
///
/// Returns the error reported by the transport's `close`. The server and its
/// tool routes have already been removed by then.
pub async fn disconnect(&self, name: &str) -> Result<()> {
    let removed = {
        // Lock order (servers, then tool_mapping) matches `register_tools`.
        let mut servers = self.servers.write();
        let mut tool_mapping = self.tool_mapping.write();
        // Without this, calls to the server's tools would still be routed to a
        // server that no longer exists.
        tool_mapping.retain(|_, owner| owner.as_str() != name);
        servers.remove(name)
    };
    // Both guards are released before awaiting on the transport.
    if let Some(server) = removed {
        server.transport.close().await?;
    }
    Ok(())
}
/// Reports whether `name` is currently connected.
///
/// Returns `Some(true)` when a connection for `name` exists and `None`
/// otherwise. As written, this never returns `Some(false)`.
pub fn get_server_status(&self, name: &str) -> Option<bool> {
    if self.servers.read().contains_key(name) {
        Some(true)
    } else {
        None
    }
}
/// Lists every configured server together with a flag telling whether it is
/// currently connected.
///
/// The order follows the iteration order of the configuration map.
pub fn list_servers(&self) -> Vec<(String, bool)> {
    let servers = self.servers.read();
    let configs = self.configs.read();
    let listing: Vec<(String, bool)> = configs
        .keys()
        .map(|name| (name.clone(), servers.contains_key(name)))
        .collect();
    listing
}
/// Collects the tools of every connected server as `(server name, tool)` pairs.
///
/// Tools are cloned, so the result stays valid after the lock is released.
pub fn list_all_tools(&self) -> Vec<(String, ToolDefinition)> {
    let servers = self.servers.read();
    let all_tools: Vec<(String, ToolDefinition)> = servers
        .iter()
        .flat_map(|(server_name, server)| {
            server
                .tools
                .iter()
                .map(move |tool| (server_name.clone(), tool.clone()))
        })
        .collect();
    all_tools
}
pub async fn call_tool(&self, tool_name: &str, arguments: Value) -> Result<ToolResult> {
let server_name = {
let tool_mapping = self.tool_mapping.read();
tool_mapping
.get(tool_name)
.ok_or_else(|| anyhow!("Tool not found: {}", tool_name))?
.clone()
};
let (transport, request_id) = {
let servers = self.servers.read();
let server = servers
.get(&server_name)
.ok_or_else(|| anyhow!("Server not found: {}", server_name))?;
(server.transport.clone(), self.next_request_id())
};
let params = serde_json::json!({
"name": tool_name,
"arguments": arguments
});
let request = McpRequest {
id: request_id,
method: "tools/call".to_string(),
params: Some(params),
};
transport.send(&McpMessage::Request(request)).await?;
match transport.receive().await? {
Some(McpMessage::Response(response)) => {
if let Some(error) = response.error {
return Err(anyhow!("Tool call error: {}", error.message));
}
if let Some(result) = response.result {
let tool_result: ToolResult =
serde_json::from_value(result).unwrap_or_else(|_| ToolResult {
is_error: false,
content: vec![super::protocol::ContentBlock::Text {
text: "Tool executed successfully".to_string(),
}],
});
Ok(tool_result)
} else {
Err(anyhow!("Empty response"))
}
}
Some(McpMessage::Error(error)) => Err(anyhow!("Error: {:?}", error)),
_ => Err(anyhow!("Unexpected response type")),
}
}
/// Replaces the tool list of the connected server `server_name` and routes
/// each of the given tools to it.
///
/// Routes that pointed at this server for tools no longer in `tools` are
/// removed, so the routing table always matches the server's current list. If
/// another server already routes a tool with the same name, that route is
/// overwritten.
///
/// # Errors
///
/// Returns an error when `server_name` is not connected; nothing is changed in
/// that case.
pub async fn register_tools(
    &self,
    server_name: &str,
    tools: Vec<ToolDefinition>,
) -> Result<()> {
    let mut servers = self.servers.write();
    let server = servers
        .get_mut(server_name)
        .ok_or_else(|| anyhow!("Server not found: {}", server_name))?;
    let mut tool_mapping = self.tool_mapping.write();
    // Drop this server's previous routes first; otherwise a tool it stopped
    // advertising would still be dispatched to it.
    tool_mapping.retain(|_, owner| owner.as_str() != server_name);
    for tool in &tools {
        tool_mapping.insert(tool.name.clone(), server_name.to_string());
    }
    server.tools = tools;
    Ok(())
}
/// Renders a human-readable, multi-line summary of all configured servers.
///
/// Each configured server gets one line with a connection marker, its name,
/// and the number of tools recorded for it (0 when it is not connected). When
/// nothing is configured, a single placeholder line is emitted instead.
pub fn render_status(&self) -> String {
    let servers = self.servers.read();
    let configs = self.configs.read();
    let mut report = String::from("MCP Servers:\n");
    if configs.is_empty() {
        report.push_str(" No servers configured\n");
        return report;
    }
    for name in configs.keys() {
        let (status, tool_count) = match servers.get(name) {
            Some(entry) => ("🟢", entry.tools.len()),
            None => ("🔴", 0),
        };
        let line = format!(" {} {} ({} tools)\n", status, name, tool_count);
        report.push_str(&line);
    }
    report
}
}
/// Returns the built-in server presets: `filesystem`, `github`, and `playwright`.
///
/// All presets launch a command over stdio, have `auto_reconnect` enabled, and
/// use a reconnect interval of 5000 ms.
pub fn preset_servers() -> Vec<McpServerConfig> {
    /// Builds one stdio preset with the shared reconnect settings.
    fn stdio_preset(name: &str, command: &str, args: &[&str]) -> McpServerConfig {
        McpServerConfig {
            name: name.to_string(),
            transport: McpTransportType::Stdio {
                command: command.to_string(),
                args: args.iter().map(|arg| arg.to_string()).collect(),
            },
            auto_reconnect: true,
            reconnect_interval_ms: 5000,
        }
    }

    vec![
        stdio_preset("filesystem", "mcp-server-filesystem", &["--root", "."]),
        stdio_preset("github", "mcp-server-github", &[]),
        stdio_preset(
            "playwright",
            "npx",
            &["@playwright/mcp@latest", "--headless", "--browser", "chrome"],
        ),
    ]
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created manager knows no servers.
    #[tokio::test]
    async fn test_manager_creation() {
        let manager = McpClientManager::new();
        assert!(manager.list_servers().is_empty());
    }

    /// A registered server is listed by name and reported as not connected.
    #[tokio::test]
    async fn test_add_server() {
        let manager = McpClientManager::new();
        let transport = McpTransportType::Stdio {
            command: "test-command".to_string(),
            args: vec![],
        };
        let config = McpServerConfig {
            name: "test".to_string(),
            transport,
            auto_reconnect: false,
            reconnect_interval_ms: 1000,
        };
        manager.add_server(config).await.unwrap();

        let listed = manager.list_servers();
        assert_eq!(listed, vec![("test".to_string(), false)]);
    }

    /// The presets are non-empty and contain the well-known entries.
    #[test]
    fn test_preset_servers() {
        let presets = preset_servers();
        let names: Vec<&str> = presets.iter().map(|preset| preset.name.as_str()).collect();
        assert!(!names.is_empty());
        for expected in ["filesystem", "github"] {
            assert!(names.contains(&expected));
        }
    }

    /// Consecutive request ids differ.
    #[test]
    fn test_request_id_generation() {
        let manager = McpClientManager::new();
        let first = manager.next_request_id();
        let second = manager.next_request_id();
        assert_ne!(first, second);
    }
}