use super::config::McpServerConfig;
use super::types::{
CallToolParams, CallToolResult, GetPromptResult, ListPromptsResult, ListResourcesResult,
ListToolsResult, McpPrompt, McpResource, McpTool, ReadResourceResult, ServerCapabilities,
};
use async_trait::async_trait;
use std::collections::HashMap;
use thiserror::Error;
/// Errors returned by the MCP server operations declared in this module.
///
/// The `Display` text of each variant is the string given in its
/// `#[error(...)]` attribute (generated by `thiserror`).
#[derive(Debug, Error)]
pub enum McpError {
/// A connection attempt failed; the payload is interpolated after
/// `"Connection failed: "`.
#[error("Connection failed: {0}")]
ConnectionFailed(String),
/// Rendered as `"Server disconnected"`; carries no further detail.
#[error("Server disconnected")]
Disconnected,
/// Rendered as `"Request timed out"`; carries no further detail.
#[error("Request timed out")]
Timeout,
/// A response was judged invalid; the payload describes why.
// NOTE(review): the exact boundary between `InvalidResponse` and
// `ProtocolError` is not visible in this file — confirm with the callers.
#[error("Invalid response: {0}")]
InvalidResponse(String),
/// A protocol-level failure; the payload describes it.
#[error("Protocol error: {0}")]
ProtocolError(String),
/// A tool-related failure; the payload describes it.
#[error("Tool error: {0}")]
ToolError(String),
/// An error carrying a numeric `code` and a `message`, rendered as
/// `"Server error [<code>]: <message>"`.
// NOTE(review): presumably the error object returned by the remote server
// (the unit test uses the JSON-RPC code -32601) — confirm.
#[error("Server error [{code}]: {message}")]
ServerError { code: i32, message: String },
/// Wraps a `std::io::Error`. `#[from]` generates the `From` impl, so `?`
/// converts I/O errors into this variant automatically.
#[error("I/O error: {0}")]
IoError(#[from] std::io::Error),
/// Wraps a `serde_json::Error`. `#[from]` generates the `From` impl, so `?`
/// converts JSON errors into this variant automatically.
#[error("JSON error: {0}")]
JsonError(#[from] serde_json::Error),
}
/// Shorthand for a `Result` whose error type is [`McpError`].
pub type McpResult<T> = Result<T, McpError>;
/// Connection lifecycle state reported by an MCP server.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ServerState {
    Disconnected,
    Connecting,
    Connected,
    Failed,
    ShuttingDown,
}

/// Renders each state as a fixed lowercase, snake_case label.
impl std::fmt::Display for ServerState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the static label first, then emit it verbatim; width and
        // padding flags on the formatter are intentionally not applied.
        let label = match *self {
            ServerState::Disconnected => "disconnected",
            ServerState::Connecting => "connecting",
            ServerState::Connected => "connected",
            ServerState::Failed => "failed",
            ServerState::ShuttingDown => "shutting_down",
        };
        f.write_str(label)
    }
}
/// Descriptive data about one MCP server, exposed through `McpServer::info`.
#[derive(Debug, Clone)]
pub struct ServerInfo {
/// Name of the server entry.
// NOTE(review): presumably the locally configured name (the same value as
// `McpServer::name`) — confirm against the implementors.
pub name: String,
/// Optional server name.
// NOTE(review): presumably the name the remote server reports about
// itself; `None` when it is not known — confirm.
pub server_name: Option<String>,
/// Optional server version.
// NOTE(review): presumably reported by the remote server — confirm.
pub server_version: Option<String>,
/// Capabilities associated with the server (type defined in `super::types`).
pub capabilities: ServerCapabilities,
/// Connection state stored alongside this info.
pub state: ServerState,
}
/// Interface of a single MCP server connection, as stored and driven by
/// `McpServerManager`.
///
/// `#[async_trait]` makes the `async fn` items usable through
/// `Box<dyn McpServer>`; implementors must be `Send + Sync`.
#[async_trait]
pub trait McpServer: Send + Sync {
/// Name of this server. `McpServerManager::add_server` uses it as the
/// registry key.
fn name(&self) -> &str;
/// Configuration this server was created from.
fn config(&self) -> &McpServerConfig;
/// Current connection state.
fn state(&self) -> ServerState;
/// Descriptive information about the server, if any is available.
// NOTE(review): presumably `None` until a connection has been
// established — confirm with the implementors.
fn info(&self) -> Option<&ServerInfo>;
/// Connects to the server.
// NOTE(review): presumably leaves `state()` at `Connected` on success —
// that is up to each implementor; confirm.
async fn connect(&mut self) -> McpResult<()>;
/// Disconnects from the server.
async fn disconnect(&mut self) -> McpResult<()>;
/// Returns `true` exactly when `state()` is `ServerState::Connected`.
fn is_connected(&self) -> bool {
self.state() == ServerState::Connected
}
/// Lists the tools offered by the server.
async fn list_tools(&self) -> McpResult<ListToolsResult>;
/// Invokes a tool described by `params` and returns its result.
async fn call_tool(&self, params: CallToolParams) -> McpResult<CallToolResult>;
/// Lists the resources offered by the server.
async fn list_resources(&self) -> McpResult<ListResourcesResult>;
/// Reads the resource identified by `uri`.
async fn read_resource(&self, uri: &str) -> McpResult<ReadResourceResult>;
/// Lists the prompts offered by the server.
async fn list_prompts(&self) -> McpResult<ListPromptsResult>;
/// Fetches the prompt called `name`, optionally passing string-valued
/// `arguments` keyed by argument name.
async fn get_prompt(
&self,
name: &str,
arguments: Option<HashMap<String, String>>,
) -> McpResult<GetPromptResult>;
}
/// Registry of MCP servers keyed by name.
///
/// The derived `Default` produces an empty registry.
#[derive(Default)]
pub struct McpServerManager {
/// Registered servers, keyed by the value `McpServer::name` returned when
/// the server was added.
servers: HashMap<String, Box<dyn McpServer>>,
}
impl McpServerManager {
    /// Creates a manager with no registered servers.
    pub fn new() -> Self {
        let servers = HashMap::new();
        Self { servers }
    }

    /// Registers `server` under the name it reports.
    ///
    /// If a server with the same name is already registered it is replaced;
    /// the previous entry is dropped without `disconnect` being called on it.
    pub fn add_server(&mut self, server: Box<dyn McpServer>) {
        let key = String::from(server.name());
        self.servers.insert(key, server);
    }

    /// Unregisters the server called `name` and hands it back, or returns
    /// `None` when no such server is registered.
    pub fn remove_server(&mut self, name: &str) -> Option<Box<dyn McpServer>> {
        self.servers.remove(name)
    }

    /// Borrows the server called `name`, if registered.
    pub fn get(&self, name: &str) -> Option<&dyn McpServer> {
        let boxed = self.servers.get(name)?;
        Some(&**boxed)
    }

    /// Mutably borrows the boxed server called `name`, if registered.
    pub fn get_mut(&mut self, name: &str) -> Option<&mut Box<dyn McpServer>> {
        self.servers.get_mut(name)
    }

    /// Names of all registered servers, in the map's (unspecified) iteration
    /// order.
    pub fn names(&self) -> Vec<&str> {
        self.servers.keys().map(String::as_str).collect()
    }

    /// Borrows every server whose `is_connected()` currently returns `true`.
    pub fn connected(&self) -> Vec<&dyn McpServer> {
        let mut live: Vec<&dyn McpServer> = Vec::new();
        for server in self.servers.values() {
            if server.is_connected() {
                live.push(&**server);
            }
        }
        live
    }

    /// Calls `connect` on every registered server, one after another.
    ///
    /// A failure does not stop the remaining attempts. The returned vector
    /// holds one outcome per server in the map's iteration order; outcomes
    /// are not labelled with server names.
    pub async fn connect_all(&mut self) -> Vec<McpResult<()>> {
        let mut outcomes = Vec::with_capacity(self.servers.len());
        for server in self.servers.values_mut() {
            let outcome = server.connect().await;
            outcomes.push(outcome);
        }
        outcomes
    }

    /// Calls `disconnect` on every registered server, one after another.
    ///
    /// A failure does not stop the remaining attempts. The returned vector
    /// holds one outcome per server in the map's iteration order; outcomes
    /// are not labelled with server names.
    pub async fn disconnect_all(&mut self) -> Vec<McpResult<()>> {
        let mut outcomes = Vec::with_capacity(self.servers.len());
        for server in self.servers.values_mut() {
            let outcome = server.disconnect().await;
            outcomes.push(outcome);
        }
        outcomes
    }

    /// Collects the tools of every connected server as `(server name, tools)`
    /// pairs.
    ///
    /// Servers that are not connected are skipped, and servers whose
    /// `list_tools` call fails are left out silently (best effort).
    pub async fn list_all_tools(&self) -> Vec<(String, Vec<McpTool>)> {
        let mut per_server = Vec::new();
        for (name, server) in &self.servers {
            if !server.is_connected() {
                continue;
            }
            if let Ok(listing) = server.list_tools().await {
                per_server.push((name.clone(), listing.tools));
            }
        }
        per_server
    }

    /// Collects the resources of every connected server as
    /// `(server name, resources)` pairs.
    ///
    /// Servers that are not connected are skipped, and servers whose
    /// `list_resources` call fails are left out silently (best effort).
    pub async fn list_all_resources(&self) -> Vec<(String, Vec<McpResource>)> {
        let mut per_server = Vec::new();
        for (name, server) in &self.servers {
            if !server.is_connected() {
                continue;
            }
            if let Ok(listing) = server.list_resources().await {
                per_server.push((name.clone(), listing.resources));
            }
        }
        per_server
    }

    /// Collects the prompts of every connected server as
    /// `(server name, prompts)` pairs.
    ///
    /// Servers that are not connected are skipped, and servers whose
    /// `list_prompts` call fails are left out silently (best effort).
    pub async fn list_all_prompts(&self) -> Vec<(String, Vec<McpPrompt>)> {
        let mut per_server = Vec::new();
        for (name, server) in &self.servers {
            if !server.is_connected() {
                continue;
            }
            if let Ok(listing) = server.list_prompts().await {
                per_server.push((name.clone(), listing.prompts));
            }
        }
        per_server
    }
}
/// Formats the manager as `McpServerManager { servers: [..names..] }`.
///
/// Written by hand because the stored `Box<dyn McpServer>` values carry no
/// `Debug` bound; only the registered names are shown.
impl std::fmt::Debug for McpServerManager {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `names()` follows HashMap iteration order, which is unspecified and
        // varies between runs. Sort so log lines and snapshots are stable.
        let mut server_names = self.names();
        server_names.sort_unstable();
        f.debug_struct("McpServerManager")
            .field("servers", &server_names)
            .finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `ServerState` variant renders as its lowercase label.
    #[test]
    fn test_server_state_display() {
        assert_eq!(ServerState::Disconnected.to_string(), "disconnected");
        assert_eq!(ServerState::Connecting.to_string(), "connecting");
        assert_eq!(ServerState::Connected.to_string(), "connected");
        assert_eq!(ServerState::Failed.to_string(), "failed");
        assert_eq!(ServerState::ShuttingDown.to_string(), "shutting_down");
    }

    /// Error messages follow the `#[error(...)]` templates on `McpError`.
    #[test]
    fn test_mcp_error_display() {
        let err = McpError::ConnectionFailed("timeout".to_string());
        assert!(err.to_string().contains("timeout"));
        assert_eq!(err.to_string(), "Connection failed: timeout");

        let err = McpError::ServerError {
            code: -32601,
            message: "Method not found".to_string(),
        };
        assert!(err.to_string().contains("-32601"));
        assert!(err.to_string().contains("Method not found"));
        assert_eq!(err.to_string(), "Server error [-32601]: Method not found");

        // Payload-free variants render their fixed text.
        assert_eq!(McpError::Disconnected.to_string(), "Server disconnected");
        assert_eq!(McpError::Timeout.to_string(), "Request timed out");
    }

    /// A freshly created manager has no servers of any kind.
    #[test]
    fn test_server_manager_new() {
        let mut manager = McpServerManager::new();
        assert!(manager.names().is_empty());
        assert!(manager.connected().is_empty());
        assert!(manager.get("missing").is_none());
        assert!(manager.remove_server("missing").is_none());
    }
}