use std::collections::HashMap;
use std::sync::Arc;
use super::bridge::{McpBridgedTool, mcp_tool_name};
use super::client::{McpClientError, McpStdioClient};
use super::protocol::{McpServerConfig, McpToolDefinition};
/// Connection state of a configured MCP server.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum McpServerStatus {
    /// No live connection. Not produced by `McpManager::list_servers` in this module.
    Disconnected,
    /// Presumably a connection attempt in progress. Not produced by
    /// `McpManager::list_servers` in this module.
    Connecting,
    /// The server has a live connection registered with the manager.
    Connected,
    /// The most recent connection attempt failed.
    Error,
}

impl std::fmt::Display for McpServerStatus {
    /// Writes the status as a lowercase label, e.g. `connected`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            McpServerStatus::Disconnected => "disconnected",
            McpServerStatus::Connecting => "connecting",
            McpServerStatus::Connected => "connected",
            McpServerStatus::Error => "error",
        };
        f.write_str(label)
    }
}
#[derive(Debug, Clone)]
pub struct McpServerSummary {
pub name: String,
pub status: McpServerStatus,
pub server_version: Option<String>,
pub tool_count: usize,
pub error: Option<String>,
}
/// A live server connection plus the tool list captured when it connected.
struct ConnectedServer {
    /// Client handle; the same `Arc` is shared with every `McpBridgedTool`
    /// created for this server in `McpManager::connect`.
    client: Arc<McpStdioClient>,
    /// Tool definitions copied from the client at connect time.
    tools: Vec<McpToolDefinition>,
}
/// Tracks MCP server connections by name, plus the last connection error for
/// servers whose connection attempt failed.
pub struct McpManager {
    /// Connected servers keyed by server name.
    servers: HashMap<String, ConnectedServer>,
    /// Last connection error message keyed by server name. An entry is written when
    /// `connect` fails and removed when `connect` succeeds.
    errors: HashMap<String, String>,
}
impl McpManager {
pub fn new() -> Self {
Self {
servers: HashMap::new(),
errors: HashMap::new(),
}
}
pub async fn connect(
&mut self,
config: &McpServerConfig,
) -> Result<Vec<McpBridgedTool>, McpClientError> {
self.disconnect(&config.name).await;
let client = McpStdioClient::connect(config).await.inspect_err(|e| {
self.errors.insert(config.name.clone(), e.to_string());
})?;
let client = Arc::new(client);
let tools = client.tools().to_vec();
let bridged: Vec<McpBridgedTool> = tools
.iter()
.map(|tool_def| {
McpBridgedTool::new(config.name.clone(), tool_def.clone(), client.clone())
})
.collect();
self.servers.insert(
config.name.clone(),
ConnectedServer {
client,
tools: tools.clone(),
},
);
self.errors.remove(&config.name);
Ok(bridged)
}
pub async fn disconnect(&mut self, name: &str) {
if let Some(server) = self.servers.remove(name) {
server.client.shutdown().await;
}
}
pub async fn shutdown_all(&mut self) {
let names: Vec<String> = self.servers.keys().cloned().collect();
for name in names {
self.disconnect(&name).await;
}
}
pub fn list_servers(&self) -> Vec<McpServerSummary> {
let mut summaries: Vec<McpServerSummary> = self
.servers
.iter()
.map(|(name, server)| McpServerSummary {
name: name.clone(),
status: McpServerStatus::Connected,
server_version: server.client.server_info().map(|info| info.version.clone()),
tool_count: server.tools.len(),
error: None,
})
.collect();
for (name, error) in &self.errors {
if !self.servers.contains_key(name) {
summaries.push(McpServerSummary {
name: name.clone(),
status: McpServerStatus::Error,
server_version: None,
tool_count: 0,
error: Some(error.clone()),
});
}
}
summaries.sort_by(|a, b| a.name.cmp(&b.name));
summaries
}
pub fn all_tool_names(&self) -> Vec<String> {
self.servers
.iter()
.flat_map(|(name, server)| {
server
.tools
.iter()
.map(move |tool| mcp_tool_name(name, &tool.name))
})
.collect()
}
pub async fn call_tool(
&self,
server_name: &str,
tool_name: &str,
arguments: Option<serde_json::Value>,
) -> Result<super::protocol::McpToolCallResult, McpClientError> {
let server = self.servers.get(server_name).ok_or_else(|| {
McpClientError::ParseError(format!("MCP server '{}' not connected", server_name))
})?;
server.client.call_tool(tool_name, arguments).await
}
pub fn is_connected(&self, name: &str) -> bool {
self.servers.contains_key(name)
}
pub fn connected_count(&self) -> usize {
self.servers.len()
}
}
impl Default for McpManager {
fn default() -> Self {
Self::new()
}
}