pub mod client;
pub mod protocol;
pub use client::McpClient;
pub use protocol::*;
use std::collections::HashMap;
use std::sync::Arc;
use anyhow::{anyhow, Result};
use tokio::sync::RwLock;
/// Registry of MCP server configurations together with the clients that have
/// been initialized for them and a per-server cache of listed tools.
///
/// The server list sits behind a synchronous `parking_lot` lock because it is
/// read and written from non-async methods; the client map and the tool cache
/// sit behind `tokio::sync::RwLock` and are only touched from async methods.
pub struct McpBridge {
/// Registered server configurations, in registration order.
servers: parking_lot::RwLock<Vec<McpServer>>,
/// Clients whose `initialize()` call succeeded, keyed by server name.
clients: RwLock<HashMap<String, Arc<McpClient>>>,
/// Tools most recently listed for each server, keyed by server name.
tool_cache: RwLock<HashMap<String, Vec<McpTool>>>,
}
impl McpBridge {
pub fn new() -> Self {
Self {
servers: parking_lot::RwLock::new(Vec::new()),
clients: RwLock::new(HashMap::new()),
tool_cache: RwLock::new(HashMap::new()),
}
}
pub fn register_server(&self, server: McpServer) {
self.servers.write().push(server);
}
pub fn servers(&self) -> Vec<String> {
self.servers.read().iter().map(|s| s.name.clone()).collect()
}
pub fn get_server(&self, name: &str) -> Option<McpServer> {
self.servers.read().iter().find(|s| s.name == name).cloned()
}
pub async fn initialize_all(&self) -> Result<()> {
let mut errors = Vec::new();
let server_list: Vec<McpServer> = self.servers.read().iter().cloned().collect();
for server in server_list {
if !server.enabled {
tracing::debug!(server = %server.name, "Skipping disabled MCP server");
continue;
}
let client = Arc::new(McpClient::new(server.clone()));
match client.initialize().await {
Ok(()) => {
self.clients
.write()
.await
.insert(server.name.clone(), client);
tracing::info!(server = %server.name, "MCP server started");
}
Err(e) => {
tracing::error!(server = %server.name, error = %e, "Failed to initialize MCP server");
errors.push(format!("{}: {}", server.name, e));
}
}
}
if errors.is_empty() {
Ok(())
} else {
Err(anyhow!("MCP initialization failed: {}", errors.join("; ")))
}
}
pub async fn initialize_server(&self, name: &str) -> Result<()> {
let server = self
.servers
.read()
.iter()
.find(|s| s.name == name)
.cloned()
.ok_or_else(|| anyhow!("MCP server '{}' not found", name))?;
let client = Arc::new(McpClient::new(server));
client.initialize().await?;
self.clients.write().await.insert(name.to_string(), client);
Ok(())
}
pub async fn client(&self, name: &str) -> Option<Arc<McpClient>> {
self.clients.read().await.get(name).cloned()
}
pub async fn list_tools(&self) -> Result<Vec<McpTool>> {
let clients = self.clients.read().await;
let mut all_tools = Vec::new();
for (name, client) in clients.iter() {
if let Ok(mcp_tools) = client.list_tools().await {
let start = all_tools.len();
all_tools.extend(mcp_tools);
*self
.tool_cache
.write()
.await
.entry(name.clone())
.or_insert_with(Vec::new) = all_tools[start..].to_vec();
}
}
Ok(all_tools)
}
pub async fn cached_tools(&self, server_name: &str) -> Option<Vec<McpTool>> {
self.tool_cache.read().await.get(server_name).cloned()
}
pub async fn call_tool(
&self,
server_name: &str,
tool_name: &str,
args: serde_json::Value,
) -> Result<McpToolCallResult> {
let clients = self.clients.read().await;
let client = clients
.get(server_name)
.ok_or_else(|| anyhow!("MCP server '{}' not connected", server_name))?;
client.call_tool(tool_name, args).await
}
pub async fn shutdown_all(&self) -> Result<()> {
let mut clients = self.clients.write().await;
for (name, client) in clients.drain() {
if let Err(e) = client.shutdown().await {
tracing::warn!(server = %name, error = %e, "Error shutting down MCP server");
}
}
self.tool_cache.write().await.clear();
Ok(())
}
pub async fn refresh_tools(&self, server_name: &str) -> Result<Vec<McpTool>> {
let clients = self.clients.read().await;
let client = clients
.get(server_name)
.ok_or_else(|| anyhow!("MCP server '{}' not connected", server_name))?;
let mcp_tools = client.refresh_tools().await?;
*self
.tool_cache
.write()
.await
.entry(server_name.to_string())
.or_insert_with(Vec::new) = mcp_tools.clone();
Ok(mcp_tools)
}
pub async fn clear_cache(&self, server_name: &str) {
self.tool_cache.write().await.remove(server_name);
}
pub async fn clear_all_caches(&self) {
self.tool_cache.write().await.clear();
}
}
impl Default for McpBridge {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::Duration;

    /// The builder stores name, command, args and env, and the server starts
    /// out enabled.
    #[test]
    fn test_mcp_server_builder() {
        let cli_args = vec!["-y".to_string(), "@anthropic/mcp-server".to_string()];
        let built = McpServer::new("test-server", "npx")
            .with_args(cli_args)
            .with_env("DEBUG", "true");

        assert_eq!(built.name, "test-server");
        assert_eq!(built.command, "npx");
        assert_eq!(built.args, vec!["-y", "@anthropic/mcp-server"]);
        assert_eq!(built.env.get("DEBUG"), Some(&"true".to_string()));
        assert!(built.enabled);
    }

    /// A serialized request carries its method and the JSON-RPC version.
    #[test]
    fn test_mcp_request_serialization() {
        let encoded = serde_json::to_string(&McpRequest::new("tools/list")).unwrap();

        for fragment in [r#""method":"tools/list""#, r#""jsonrpc":"2.0""#].iter() {
            assert!(encoded.contains(fragment));
        }
    }

    /// Parameters attached to a request show up in its serialized form.
    #[test]
    fn test_mcp_request_with_params() {
        let params = serde_json::json!({
            "name": "my_tool",
            "arguments": {"arg1": "value1"}
        });
        let with_params = McpRequest::new("tools/call").with_params(params);
        let encoded = serde_json::to_string(&with_params).unwrap();

        assert!(encoded.contains("my_tool"));
        assert!(encoded.contains("arg1"));
    }

    /// The JSONL form ends in a newline and the rest parses back to the request.
    #[test]
    fn test_mcp_request_to_jsonl() {
        let framed = McpRequest::new("initialize").to_jsonl().unwrap();
        assert_eq!(framed.last(), Some(&b'\n'));

        let payload_len = framed.len() - 1;
        let payload = String::from_utf8_lossy(&framed[..payload_len]);
        let decoded: McpRequest = serde_json::from_str(&payload).unwrap();
        assert_eq!(decoded.method, "initialize");
    }

    /// A response with a result is not an error and yields that result.
    #[test]
    fn test_mcp_response_result() {
        let ok_response = McpResponse {
            jsonrpc: "2.0".to_string(),
            id: serde_json::json!(1),
            result: Some(serde_json::json!({"tools": []})),
            error: None,
        };

        assert!(!ok_response.is_error());
        let payload = ok_response.clone().into_result().unwrap();
        assert!(payload.get("tools").is_some());
    }

    /// A response with an error reports it and converts into an `Err`.
    #[test]
    fn test_mcp_response_error() {
        let failed_response = McpResponse {
            jsonrpc: "2.0".to_string(),
            id: serde_json::json!(2),
            result: None,
            error: Some(McpError::internal_error("Something went wrong")),
        };

        assert!(failed_response.is_error());
        let message = failed_response.into_result().unwrap_err().to_string();
        assert!(message.contains("internal error"));
    }

    /// Each error constructor maps to its JSON-RPC error code.
    #[test]
    fn test_mcp_error_codes() {
        let expectations = vec![
            (McpError::parse_error(), -32700),
            (McpError::invalid_request("test"), -32600),
            (McpError::method_not_found(), -32601),
            (McpError::invalid_params(), -32602),
            (McpError::internal_error("x"), -32603),
            (McpError::server_error("x"), -32000),
        ];

        for (error, expected_code) in expectations {
            assert_eq!(error.code, expected_code);
        }
    }

    /// Registered servers can be listed and looked up by name.
    #[test]
    fn test_bridge_registration() {
        let registry = McpBridge::new();
        registry.register_server(McpServer::new("test", "echo"));

        assert_eq!(registry.servers(), vec!["test"]);
        assert!(registry.get_server("test").is_some());
        assert!(registry.get_server("missing").is_none());
    }

    /// Initializing a client whose command does not exist reports a spawn failure.
    #[tokio::test]
    async fn test_mcp_client_non_existent_command() {
        let client = McpClient::new(McpServer::new("ghost", "nonexistent-binary-xyz"));

        let outcome = client.initialize().await;
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Failed to spawn"));
    }

    /// Shutting down a client that was never initialized succeeds.
    #[tokio::test]
    async fn test_mcp_client_shutdown_no_panic() {
        let client = McpClient::new(McpServer::new("test-shutdown", "echo"));

        client.shutdown().await.expect("shutdown should succeed");
        let initialized = client.is_initialized().await;
        assert!(!initialized);
    }

    /// Initialization against `sleep 999` with a 100 ms timeout fails.
    #[tokio::test]
    async fn test_mcp_client_with_timeout() {
        let sleeper = McpServer::new("test", "sleep").with_args(vec!["999".to_string()]);
        let client = McpClient::new(sleeper).with_timeout(Duration::from_millis(100));

        let outcome = client.initialize().await;
        assert!(outcome.is_err());
    }

    /// `initialize_all` on a bridge without servers succeeds.
    #[tokio::test]
    async fn test_bridge_initialize_all_empty() {
        let empty = McpBridge::new();
        let outcome = empty.initialize_all().await;
        outcome.expect("empty bridge should initialize");
    }

    /// `initialize_all` returns an error when the registered commands do not exist.
    #[tokio::test]
    async fn test_bridge_initialize_all_fails_gracefully() {
        let registry = McpBridge::new();
        registry.register_server(McpServer::new("ghost", "nonexistent-cmd-xyz"));
        registry.register_server(McpServer::new("ghost2", "nonexistent-cmd-abc"));

        let outcome = registry.initialize_all().await;
        assert!(outcome.is_err());
    }

    /// `shutdown_all` on a bridge without clients succeeds.
    #[tokio::test]
    async fn test_bridge_shutdown_all_empty() {
        let empty = McpBridge::new();
        let outcome = empty.shutdown_all().await;
        outcome.expect("empty bridge shutdown should succeed");
    }

    /// Calling a tool on an unknown server reports that it is not connected.
    #[tokio::test]
    async fn test_bridge_call_tool_no_server() {
        let registry = McpBridge::new();
        let no_args = serde_json::json!({});

        let outcome = registry.call_tool("ghost", "tool", no_args).await;
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("not connected"));
    }

    /// Initializing a server that was never registered is an error.
    #[tokio::test]
    async fn test_bridge_initialize_server_not_found() {
        let registry = McpBridge::new();
        let outcome = registry.initialize_server("missing").await;
        assert!(outcome.is_err());
    }

    /// The client's `Debug` output includes the server name.
    #[test]
    fn test_mcp_client_debug() {
        let client = McpClient::new(McpServer::new("debug-test", "echo"));
        let rendered = format!("{:?}", client);
        assert!(rendered.contains("debug-test"));
    }

    /// Smoke test: initializing against `echo` and then querying the
    /// initialized flag completes; both results are deliberately ignored.
    #[tokio::test]
    async fn test_mcp_client_double_init_ignored() {
        let client = McpClient::new(McpServer::new("echo", "echo"));
        let _ = client.initialize().await;
        let _ = client.is_initialized().await;
    }
}