1pub mod client;
28pub mod protocol;
29
30pub use client::McpClient;
31pub use protocol::*;
32
33use std::collections::HashMap;
34use std::sync::Arc;
35
36use anyhow::{anyhow, Result};
37use tokio::sync::RwLock;
38
39pub struct McpBridge {
62 servers: parking_lot::RwLock<Vec<McpServer>>,
64 clients: RwLock<HashMap<String, Arc<McpClient>>>,
66 tool_cache: RwLock<HashMap<String, Vec<McpTool>>>,
68}
69
70impl McpBridge {
71 pub fn new() -> Self {
73 Self {
74 servers: parking_lot::RwLock::new(Vec::new()),
75 clients: RwLock::new(HashMap::new()),
76 tool_cache: RwLock::new(HashMap::new()),
77 }
78 }
79
80 pub fn register_server(&self, server: McpServer) {
82 self.servers.write().push(server);
83 }
84
85 pub fn servers(&self) -> Vec<String> {
87 self.servers.read().iter().map(|s| s.name.clone()).collect()
88 }
89
90 pub fn get_server(&self, name: &str) -> Option<McpServer> {
92 self.servers.read().iter().find(|s| s.name == name).cloned()
93 }
94
95 pub async fn initialize_all(&self) -> Result<()> {
100 let mut errors = Vec::new();
101
102 let server_list: Vec<McpServer> = self.servers.read().iter().cloned().collect();
103 for server in server_list {
104 if !server.enabled {
105 tracing::debug!(server = %server.name, "Skipping disabled MCP server");
106 continue;
107 }
108
109 let client = Arc::new(McpClient::new(server.clone()));
110 match client.initialize().await {
111 Ok(()) => {
112 self.clients
113 .write()
114 .await
115 .insert(server.name.clone(), client);
116 tracing::info!(server = %server.name, "MCP server started");
117 }
118 Err(e) => {
119 tracing::error!(server = %server.name, error = %e, "Failed to initialize MCP server");
120 errors.push(format!("{}: {}", server.name, e));
121 }
122 }
123 }
124
125 if errors.is_empty() {
126 Ok(())
127 } else {
128 Err(anyhow!("MCP initialization failed: {}", errors.join("; ")))
129 }
130 }
131
132 pub async fn initialize_server(&self, name: &str) -> Result<()> {
134 let server = self
135 .servers
136 .read()
137 .iter()
138 .find(|s| s.name == name)
139 .cloned()
140 .ok_or_else(|| anyhow!("MCP server '{name}' not found"))?;
141
142 let client = Arc::new(McpClient::new(server));
143 client.initialize().await?;
144
145 self.clients.write().await.insert(name.to_string(), client);
146 Ok(())
147 }
148
149 pub async fn client(&self, name: &str) -> Option<Arc<McpClient>> {
151 self.clients.read().await.get(name).cloned()
152 }
153
154 pub async fn list_tools(&self) -> Result<Vec<McpTool>> {
158 let clients = self.clients.read().await;
159 let mut all_tools = Vec::new();
160
161 for (name, client) in clients.iter() {
162 if let Ok(mcp_tools) = client.list_tools().await {
163 let start = all_tools.len();
164 all_tools.extend(mcp_tools);
165 *self
166 .tool_cache
167 .write()
168 .await
169 .entry(name.clone())
170 .or_insert_with(Vec::new) = all_tools[start..].to_vec();
171 }
172 }
173
174 Ok(all_tools)
175 }
176
177 pub async fn cached_tools(&self, server_name: &str) -> Option<Vec<McpTool>> {
179 self.tool_cache.read().await.get(server_name).cloned()
180 }
181
182 pub async fn call_tool(
184 &self,
185 server_name: &str,
186 tool_name: &str,
187 args: serde_json::Value,
188 ) -> Result<McpToolCallResult> {
189 let clients = self.clients.read().await;
190 let client = clients
191 .get(server_name)
192 .ok_or_else(|| anyhow!("MCP server '{server_name}' not connected"))?;
193
194 client.call_tool(tool_name, args).await
195 }
196
197 pub async fn shutdown_all(&self) -> Result<()> {
199 let mut clients = self.clients.write().await;
200
201 for (name, client) in clients.drain() {
202 if let Err(e) = client.shutdown().await {
203 tracing::warn!(server = %name, error = %e, "Error shutting down MCP server");
204 }
205 }
206
207 self.tool_cache.write().await.clear();
208 Ok(())
209 }
210
211 pub async fn refresh_tools(&self, server_name: &str) -> Result<Vec<McpTool>> {
213 let clients = self.clients.read().await;
214 let client = clients
215 .get(server_name)
216 .ok_or_else(|| anyhow!("MCP server '{server_name}' not connected"))?;
217
218 let mcp_tools = client.refresh_tools().await?;
219
220 *self
221 .tool_cache
222 .write()
223 .await
224 .entry(server_name.to_string())
225 .or_insert_with(Vec::new) = mcp_tools.clone();
226
227 Ok(mcp_tools)
228 }
229
230 pub async fn clear_cache(&self, server_name: &str) {
232 self.tool_cache.write().await.remove(server_name);
233 }
234
235 pub async fn remove_server(&self, name: &str) -> Result<()> {
237 if let Some(client) = self.clients.write().await.remove(name) {
239 if let Err(e) = client.shutdown().await {
240 tracing::warn!(server = %name, error = %e, "Error shutting down MCP server during removal");
241 }
242 }
243 let found = {
245 let mut servers = self.servers.write();
246 let len_before = servers.len();
247 servers.retain(|s| s.name != name);
248 servers.len() != len_before
249 };
250 if !found {
251 return Err(anyhow!("MCP server '{name}' not found"));
252 }
253 self.tool_cache.write().await.remove(name);
255 Ok(())
256 }
257
258 pub async fn toggle_server(&self, name: &str) -> Result<bool> {
260 let new_state = {
263 let mut servers = self.servers.write();
264 let server = servers
265 .iter_mut()
266 .find(|s| s.name == name)
267 .ok_or_else(|| anyhow!("MCP server '{name}' not found"))?;
268 server.enabled = !server.enabled;
269 server.enabled
270 };
271
272 if !new_state {
274 if let Some(client) = self.clients.write().await.remove(name) {
275 if let Err(e) = client.shutdown().await {
276 tracing::warn!(server = %name, error = %e, "Error shutting down MCP server on disable");
277 }
278 }
279 self.tool_cache.write().await.remove(name);
280 }
281
282 Ok(new_state)
283 }
284
285 pub async fn clear_all_caches(&self) {
287 self.tool_cache.write().await.clear();
288 }
289}
290
291impl Default for McpBridge {
292 fn default() -> Self {
293 Self::new()
294 }
295}
296
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::Duration;

    #[test]
    fn test_mcp_server_builder() {
        let server = McpServer::new("test-server", "npx")
            .with_args(vec!["-y".to_string(), "@anthropic/mcp-server".to_string()])
            .with_env("DEBUG", "true");

        assert_eq!(server.name, "test-server");
        assert_eq!(server.command, "npx");
        assert_eq!(server.args, vec!["-y", "@anthropic/mcp-server"]);
        assert_eq!(server.env.get("DEBUG"), Some(&"true".to_string()));
        // Servers are enabled by default.
        assert!(server.enabled);
    }

    #[test]
    fn test_mcp_request_serialization() {
        let request = McpRequest::new("tools/list");
        let json = serde_json::to_string(&request).unwrap();

        assert!(json.contains(r#""method":"tools/list""#));
        assert!(json.contains(r#""jsonrpc":"2.0""#));
    }

    #[test]
    fn test_mcp_request_with_params() {
        let request = McpRequest::new("tools/call").with_params(serde_json::json!({
            "name": "my_tool",
            "arguments": {"arg1": "value1"}
        }));

        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("my_tool"));
        assert!(json.contains("arg1"));
    }

    #[test]
    fn test_mcp_request_to_jsonl() {
        let request = McpRequest::new("initialize");
        let jsonl = request.to_jsonl().unwrap();

        // JSONL framing: the payload ends with a newline.
        assert_eq!(jsonl.last(), Some(&b'\n'));

        // Everything before the newline must round-trip as a request.
        let json_str = String::from_utf8_lossy(&jsonl[..jsonl.len() - 1]);
        let parsed: McpRequest = serde_json::from_str(&json_str).unwrap();
        assert_eq!(parsed.method, "initialize");
    }

    #[test]
    fn test_mcp_response_result() {
        let response = McpResponse {
            jsonrpc: "2.0".to_string(),
            id: serde_json::json!(1),
            result: Some(serde_json::json!({"tools": []})),
            error: None,
        };

        assert!(!response.is_error());
        let result = response.clone().into_result().unwrap();
        assert!(result.get("tools").is_some());
    }

    #[test]
    fn test_mcp_response_error() {
        let response = McpResponse {
            jsonrpc: "2.0".to_string(),
            id: serde_json::json!(2),
            result: None,
            error: Some(McpError::internal_error("Something went wrong")),
        };

        assert!(response.is_error());
        let err = response.into_result().unwrap_err();
        assert!(err.to_string().contains("internal error"));
    }

    #[test]
    fn test_mcp_error_codes() {
        // Standard JSON-RPC 2.0 error codes.
        assert_eq!(McpError::parse_error().code, -32700);
        assert_eq!(McpError::invalid_request("test").code, -32600);
        assert_eq!(McpError::method_not_found().code, -32601);
        assert_eq!(McpError::invalid_params().code, -32602);
        assert_eq!(McpError::internal_error("x").code, -32603);
        assert_eq!(McpError::server_error("x").code, -32000);
    }

    #[test]
    fn test_bridge_registration() {
        let bridge = McpBridge::new();

        bridge.register_server(McpServer::new("test", "echo"));

        assert_eq!(bridge.servers(), vec!["test"]);
        assert!(bridge.get_server("test").is_some());
        assert!(bridge.get_server("missing").is_none());
    }

    #[tokio::test]
    async fn test_mcp_client_non_existent_command() {
        let server = McpServer::new("ghost", "nonexistent-binary-xyz");
        let client = McpClient::new(server);

        let result = client.initialize().await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Failed to spawn"));
    }

    #[tokio::test]
    async fn test_mcp_client_shutdown_no_panic() {
        let server = McpServer::new("test-shutdown", "echo");
        let client = McpClient::new(server);

        // Shutting down a client that was never initialized is a no-op.
        client.shutdown().await.expect("shutdown should succeed");
        assert!(!client.is_initialized().await);
    }

    #[tokio::test]
    async fn test_mcp_client_with_timeout() {
        let server = McpServer::new("test", "sleep").with_args(vec!["999".to_string()]);
        let client = McpClient::new(server).with_timeout(Duration::from_millis(100));

        // `sleep` never answers the handshake, so initialization must time out.
        let result = client.initialize().await;
        assert!(result.is_err());

        // Best-effort cleanup so the `sleep` child is not left behind.
        let _ = client.shutdown().await;
    }

    #[tokio::test]
    async fn test_bridge_initialize_all_empty() {
        let bridge = McpBridge::new();
        bridge
            .initialize_all()
            .await
            .expect("empty bridge should initialize");
    }

    #[tokio::test]
    async fn test_bridge_initialize_all_fails_gracefully() {
        let bridge = McpBridge::new();
        bridge.register_server(McpServer::new("ghost", "nonexistent-cmd-xyz"));
        bridge.register_server(McpServer::new("ghost2", "nonexistent-cmd-abc"));

        let err = bridge.initialize_all().await.unwrap_err().to_string();

        // One failure must not stop the others from being attempted and reported.
        assert!(err.contains("ghost:"));
        assert!(err.contains("ghost2:"));
        assert!(bridge.client("ghost").await.is_none());
        assert!(bridge.client("ghost2").await.is_none());
    }

    #[tokio::test]
    async fn test_bridge_initialize_all_skips_disabled() {
        let bridge = McpBridge::new();
        bridge.register_server(McpServer::new("ghost", "nonexistent-cmd-xyz"));
        assert!(!bridge.toggle_server("ghost").await.unwrap());

        // The only server is disabled, so its bad command is never spawned.
        bridge
            .initialize_all()
            .await
            .expect("disabled servers should be skipped");
        assert!(bridge.client("ghost").await.is_none());
    }

    #[tokio::test]
    async fn test_bridge_shutdown_all_empty() {
        let bridge = McpBridge::new();
        bridge
            .shutdown_all()
            .await
            .expect("empty bridge shutdown should succeed");
    }

    #[tokio::test]
    async fn test_bridge_call_tool_no_server() {
        let bridge = McpBridge::new();
        let result = bridge
            .call_tool("ghost", "tool", serde_json::json!({}))
            .await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not connected"));
    }

    #[tokio::test]
    async fn test_bridge_refresh_tools_no_server() {
        let bridge = McpBridge::new();
        let err = bridge.refresh_tools("ghost").await.unwrap_err();
        assert!(err.to_string().contains("not connected"));
        assert!(bridge.cached_tools("ghost").await.is_none());
    }

    #[tokio::test]
    async fn test_bridge_list_tools_empty() {
        let bridge = McpBridge::new();
        assert!(bridge.list_tools().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_bridge_initialize_server_not_found() {
        let bridge = McpBridge::new();
        let err = bridge.initialize_server("missing").await.unwrap_err();
        assert!(err.to_string().contains("not found"));
    }

    #[tokio::test]
    async fn test_bridge_remove_server() {
        let bridge = McpBridge::new();
        bridge.register_server(McpServer::new("test", "echo"));

        bridge
            .remove_server("test")
            .await
            .expect("registered server should be removable");
        assert!(bridge.servers().is_empty());

        let err = bridge.remove_server("test").await.unwrap_err();
        assert!(err.to_string().contains("not found"));
    }

    #[tokio::test]
    async fn test_bridge_toggle_server() {
        let bridge = McpBridge::new();
        bridge.register_server(McpServer::new("test", "echo"));

        assert!(!bridge.toggle_server("test").await.unwrap());
        assert!(!bridge.get_server("test").unwrap().enabled);
        assert!(bridge.toggle_server("test").await.unwrap());
        assert!(bridge.get_server("test").unwrap().enabled);

        assert!(bridge.toggle_server("missing").await.is_err());
    }

    #[tokio::test]
    async fn test_bridge_cache_helpers_on_empty_cache() {
        let bridge = McpBridge::new();
        assert!(bridge.cached_tools("anything").await.is_none());

        // Clearing absent entries must be harmless.
        bridge.clear_cache("anything").await;
        bridge.clear_all_caches().await;
        assert!(bridge.cached_tools("anything").await.is_none());
    }

    #[test]
    fn test_mcp_client_debug() {
        let server = McpServer::new("debug-test", "echo");
        let client = McpClient::new(server);
        let debug = format!("{client:?}");
        assert!(debug.contains("debug-test"));
    }

    #[tokio::test]
    async fn test_mcp_client_double_init_ignored() {
        let server = McpServer::new("echo", "echo");
        let client = McpClient::new(server);

        // `echo` does not speak the protocol, so the outcome of each call
        // depends on the client; this only checks that repeated initialization
        // does not panic.
        let _ = client.initialize().await;
        let _ = client.initialize().await;

        let _ = client.shutdown().await;
    }
}