use crate::mcp::{
adapter::McpToolAdapter,
config::{McpConfig, McpServerConfig},
};
use autoagents::core::tool::ToolT;
use rmcp::{
model::ClientInfo,
service::{RoleClient, RunningService, ServiceExt},
transport::{ConfigureCommandExt, TokioChildProcess},
};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::process::Command;
use tokio::sync::RwLock;
/// A live connection to a single MCP server, identified by name.
#[derive(Debug)]
pub struct McpServerConnection {
/// Server name copied from the server's configuration entry; the manager uses it as the
/// key of its connection map.
pub name: String,
/// Shared handle to the running rmcp client service. It is reference-counted so that every
/// tool adapter created for this server can hold its own clone of the handle.
pub service: Arc<RunningService<RoleClient, ClientInfo>>,
}
/// Owns the connections to configured MCP servers and a cached list of the tools they expose.
///
/// Both collections sit behind `Arc<RwLock<..>>`, so all accessors take `&self` and are async.
#[derive(Debug)]
pub struct McpToolsManager {
/// Established server connections, keyed by server name.
connections: Arc<RwLock<HashMap<String, McpServerConnection>>>,
/// Cached tool adapters collected from the connected servers.
tools: Arc<RwLock<Vec<Arc<dyn ToolT>>>>,
}
/// Errors reported by the MCP tools manager.
///
/// `Display` text for each variant comes from its `#[error(...)]` attribute.
#[derive(Debug, thiserror::Error)]
pub enum McpError {
/// No connection is registered under the requested server name.
#[error("Server not found: {0}")]
ServerNotFound(String),
/// A connection attempt failed. (Not constructed in the visible part of this module.)
#[error("Connection failed: {0}")]
ConnectionFailed(String),
/// The child-process transport for a stdio server could not be created.
#[error("Transport error: {0}")]
TransportError(String),
/// The configuration could not be loaded, failed validation, or names an unsupported
/// protocol.
#[error("Configuration error: {0}")]
ConfigError(String),
/// A tool-level failure. (Not constructed in the visible part of this module.)
#[error("Tool error: {0}")]
ToolError(String),
/// An error value produced by rmcp; converted automatically via `From`.
#[error("Rmcp error: {0}")]
RmcpError(#[from] rmcp::ErrorData),
/// A JSON (de)serialization error; converted automatically via `From`.
#[error("JSON error: {0}")]
JsonError(#[from] serde_json::Error),
/// Any other failure, e.g. the client handshake or tool listing failing, carried as text.
#[error("Generic error: {0}")]
GenericError(String),
}
impl Default for McpToolsManager {
    /// Builds a manager that has no server connections and an empty tool list.
    fn default() -> Self {
        let connections = Arc::new(RwLock::new(HashMap::new()));
        let tools = Arc::new(RwLock::new(Vec::new()));
        Self { connections, tools }
    }
}
impl McpToolsManager {
    /// Creates an empty manager with no connected servers and no tools.
    pub fn new() -> Self {
        Self::default()
    }

    /// Loads an [`McpConfig`] from `path` and connects to every server it lists.
    ///
    /// # Errors
    ///
    /// Returns [`McpError::ConfigError`] when the configuration cannot be loaded, or the first
    /// error produced by [`Self::connect_servers`].
    pub async fn from_config_file<P: AsRef<std::path::Path>>(path: P) -> Result<Self, McpError> {
        let config =
            McpConfig::from_file(path).map_err(|e| McpError::ConfigError(e.to_string()))?;
        let manager = Self::new();
        manager.connect_servers(&config).await?;
        Ok(manager)
    }

    /// Connects to each server in `config`, in order, and registers its tools.
    ///
    /// Each server is connected and queried for its tools *without* holding the manager's
    /// locks; the locks are taken only to record the result, so readers such as
    /// [`Self::get_tools`] are not blocked while child processes start and handshake.
    ///
    /// # Errors
    ///
    /// Stops at the first server that fails to connect or to list its tools and returns that
    /// error. Servers registered before the failure stay registered.
    pub async fn connect_servers(&self, config: &McpConfig) -> Result<(), McpError> {
        for server_config in &config.servers {
            let connection = match self.connect_server(server_config).await {
                Ok(connection) => connection,
                Err(e) => {
                    log::error!(
                        "Failed to connect to server '{}': {}",
                        server_config.name,
                        e
                    );
                    return Err(e);
                }
            };

            let tools = match self.load_server_tools(&connection).await {
                Ok(tools) => tools,
                Err(e) => {
                    log::error!(
                        "Failed to load tools from server '{}': {}",
                        connection.name,
                        e
                    );
                    return Err(e);
                }
            };

            log::info!(
                "Connected to MCP server '{}' with {} tools",
                connection.name,
                tools.len()
            );

            // Lock order (connections, then tools) matches `refresh_tools`. Both guards are
            // dropped at the end of this iteration, before the next server is contacted.
            let mut connections = self.connections.write().await;
            let mut all_tools = self.tools.write().await;
            all_tools.extend(tools);
            connections.insert(connection.name.clone(), connection);
        }
        Ok(())
    }

    /// Validates `server_config` and opens a connection using its configured protocol.
    ///
    /// # Errors
    ///
    /// Returns [`McpError::ConfigError`] when validation fails or the protocol is anything
    /// other than `"stdio"`, or the error from [`Self::connect_stdio_server`].
    async fn connect_server(
        &self,
        server_config: &McpServerConfig,
    ) -> Result<McpServerConnection, McpError> {
        server_config.validate().map_err(McpError::ConfigError)?;
        let service = match server_config.protocol.as_str() {
            "stdio" => self.connect_stdio_server(server_config).await?,
            _ => {
                return Err(McpError::ConfigError(format!(
                    "Unsupported protocol: {}. Currently only 'stdio' is supported.",
                    server_config.protocol
                )));
            }
        };
        Ok(McpServerConnection {
            name: server_config.name.clone(),
            service,
        })
    }

    /// Spawns the configured command as a child process and starts an MCP client over it.
    ///
    /// The command's arguments, optional working directory and environment variables are
    /// taken from `config`.
    ///
    /// # Errors
    ///
    /// Returns [`McpError::TransportError`] when the child-process transport cannot be
    /// created, and [`McpError::GenericError`] when starting the client service fails.
    async fn connect_stdio_server(
        &self,
        config: &McpServerConfig,
    ) -> Result<Arc<RunningService<RoleClient, ClientInfo>>, McpError> {
        let command = Command::new(&config.command).configure(|cmd| {
            // Passing an empty argument list is a no-op, so no emptiness check is needed.
            cmd.args(&config.args);
            if let Some(cwd) = &config.cwd {
                cmd.current_dir(cwd);
            }
            for (key, value) in &config.env {
                cmd.env(key, value);
            }
        });
        let transport =
            TokioChildProcess::new(command).map_err(|e| McpError::TransportError(e.to_string()))?;
        let service = ClientInfo::default().serve(transport).await.map_err(|e| {
            McpError::GenericError(format!("Failed to connect to MCP server: {:?}", e))
        })?;
        Ok(Arc::new(service))
    }

    /// Asks the server behind `connection` for its tools and wraps each in an
    /// [`McpToolAdapter`].
    ///
    /// Uses `list_all_tools`, which keeps requesting pages until the server reports no further
    /// cursor, so servers that paginate their tool list are fully enumerated.
    ///
    /// # Errors
    ///
    /// Returns [`McpError::GenericError`] when the listing request fails.
    async fn load_server_tools(
        &self,
        connection: &McpServerConnection,
    ) -> Result<Vec<Arc<dyn ToolT>>, McpError> {
        let tools = connection
            .service
            .list_all_tools()
            .await
            .map_err(|e| McpError::GenericError(format!("Failed to list tools: {:?}", e)))?;
        Ok(tools
            .into_iter()
            .map(|tool| {
                let adapter = McpToolAdapter::new(tool, Arc::clone(&connection.service));
                Arc::new(adapter) as Arc<dyn ToolT>
            })
            .collect())
    }

    /// Returns a snapshot of the cached tools from all connected servers.
    pub async fn get_tools(&self) -> Vec<Arc<dyn ToolT>> {
        self.tools.read().await.clone()
    }

    /// Queries the named server for its current tools.
    ///
    /// This contacts the server on every call; it neither reads nor updates the cached tool
    /// list.
    ///
    /// # Errors
    ///
    /// Returns [`McpError::ServerNotFound`] when no connection is registered under
    /// `server_name`, or the error from listing the server's tools.
    pub async fn get_server_tools(
        &self,
        server_name: &str,
    ) -> Result<Vec<Arc<dyn ToolT>>, McpError> {
        let connections = self.connections.read().await;
        let connection = connections
            .get(server_name)
            .ok_or_else(|| McpError::ServerNotFound(server_name.to_string()))?;
        self.load_server_tools(connection).await
    }

    /// Returns the first cached tool whose name equals `tool_name`, if any.
    pub async fn get_tool(&self, tool_name: &str) -> Option<Arc<dyn ToolT>> {
        let tools = self.tools.read().await;
        tools.iter().find(|tool| tool.name() == tool_name).cloned()
    }

    /// Rebuilds the cached tool list by re-querying every connected server.
    ///
    /// # Errors
    ///
    /// Returns the first listing error encountered; in that case the cached list is left
    /// unchanged because it is only replaced after every server has answered.
    pub async fn refresh_tools(&self) -> Result<(), McpError> {
        let connections = self.connections.read().await;
        let mut all_tools = Vec::new();
        for connection in connections.values() {
            all_tools.extend(self.load_server_tools(connection).await?);
        }
        *self.tools.write().await = all_tools;
        Ok(())
    }

    /// Returns the names of all registered server connections (in no particular order).
    pub async fn connected_servers(&self) -> Vec<String> {
        self.connections.read().await.keys().cloned().collect()
    }

    /// Reports whether a connection is registered under `server_name`.
    pub async fn is_server_connected(&self, server_name: &str) -> bool {
        self.connections.read().await.contains_key(server_name)
    }

    /// Returns the number of cached tools.
    pub async fn tool_count(&self) -> usize {
        self.tools.read().await.len()
    }

    /// Returns the names of all cached tools, in cache order.
    pub async fn tool_names(&self) -> Vec<String> {
        self.tools
            .read()
            .await
            .iter()
            .map(|tool| tool.name().to_string())
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a server config that uses the unsupported `"http"` protocol.
    fn invalid_protocol_config() -> McpServerConfig {
        McpServerConfig::new(
            String::from("bad_server"),
            String::from("http"),
            String::from("noop"),
        )
    }

    /// A freshly created manager reports no tools and no servers.
    #[tokio::test]
    async fn test_manager_basic_operations() {
        let mgr = McpToolsManager::default();

        let count = mgr.tool_count().await;
        assert_eq!(count, 0);

        let names = mgr.tool_names().await;
        assert!(names.is_empty());

        let servers = mgr.connected_servers().await;
        assert!(servers.is_empty());

        let connected = mgr.is_server_connected("nonexistent").await;
        assert!(!connected);

        let tool = mgr.get_tool("nonexistent").await;
        assert!(tool.is_none());
    }

    /// Asking for the tools of an unknown server yields `ServerNotFound`.
    #[tokio::test]
    async fn test_get_server_tools_missing() {
        let mgr = McpToolsManager::default();
        let outcome = mgr.get_server_tools("missing").await;
        assert!(matches!(outcome, Err(McpError::ServerNotFound(_))));
    }

    /// `connect_server` rejects any protocol other than stdio with a `ConfigError`.
    #[tokio::test]
    async fn test_connect_server_rejects_unsupported_protocol() {
        let mgr = McpToolsManager::default();
        let server = invalid_protocol_config();
        let error = mgr.connect_server(&server).await.unwrap_err();
        assert!(matches!(error, McpError::ConfigError(_)));
        let message = error.to_string();
        assert!(message.contains("Unsupported protocol"));
    }

    /// `connect_servers` surfaces the per-server protocol error to its caller.
    #[tokio::test]
    async fn test_connect_servers_returns_error_on_invalid_protocol() {
        let mgr = McpToolsManager::default();
        let mut cfg = McpConfig::new();
        cfg.add_server(invalid_protocol_config());
        let error = mgr.connect_servers(&cfg).await.unwrap_err();
        let message = error.to_string();
        assert!(message.contains("Unsupported protocol"));
    }

    /// Refreshing with no connections succeeds and leaves the tool list empty.
    #[tokio::test]
    async fn test_refresh_tools_on_empty_manager() {
        let mgr = McpToolsManager::default();
        mgr.refresh_tools().await.unwrap();
        assert!(mgr.get_tools().await.is_empty());
    }
}