1mod client;
30mod protocol;
31
32pub use client::McpClient;
33pub use protocol::*;
34
35use std::collections::HashMap;
36use std::sync::Arc;
37
38use anyhow::{anyhow, Result};
39use tokio::sync::RwLock;
40
41use crate::program::ToolDef;
42
43pub struct McpBridge {
66 servers: parking_lot::RwLock<Vec<McpServer>>,
68 clients: RwLock<HashMap<String, Arc<McpClient>>>,
70 tool_cache: RwLock<HashMap<String, Vec<ToolDef>>>,
72}
73
74impl McpBridge {
75 pub fn new() -> Self {
77 Self {
78 servers: parking_lot::RwLock::new(Vec::new()),
79 clients: RwLock::new(HashMap::new()),
80 tool_cache: RwLock::new(HashMap::new()),
81 }
82 }
83
84 pub fn register_server(&self, server: McpServer) {
86 self.servers.write().push(server);
87 }
88
89 pub fn servers(&self) -> Vec<String> {
91 self.servers.read().iter().map(|s| s.name.clone()).collect()
92 }
93
94 pub fn get_server(&self, name: &str) -> Option<McpServer> {
96 self.servers.read().iter().find(|s| s.name == name).cloned()
97 }
98
99 pub async fn initialize_all(&self) -> Result<()> {
104 let mut errors = Vec::new();
105
106 let server_list: Vec<McpServer> = self.servers.read().iter().cloned().collect();
107 for server in server_list {
108 if !server.enabled {
109 tracing::debug!(server = %server.name, "Skipping disabled MCP server");
110 continue;
111 }
112
113 let client = Arc::new(McpClient::new(server.clone()));
114 match client.initialize().await {
115 Ok(()) => {
116 self.clients
117 .write()
118 .await
119 .insert(server.name.clone(), client);
120 tracing::info!(server = %server.name, "MCP server started");
121 }
122 Err(e) => {
123 tracing::error!(server = %server.name, error = %e, "Failed to initialize MCP server");
124 errors.push(format!("{}: {}", server.name, e));
125 }
126 }
127 }
128
129 if errors.is_empty() {
130 Ok(())
131 } else {
132 Err(anyhow!("MCP initialization failed: {}", errors.join("; ")))
133 }
134 }
135
136 pub async fn initialize_server(&self, name: &str) -> Result<()> {
138 let server = self
139 .servers
140 .read()
141 .iter()
142 .find(|s| s.name == name)
143 .cloned()
144 .ok_or_else(|| anyhow!("MCP server '{}' not found", name))?;
145
146 let client = Arc::new(McpClient::new(server));
147 client.initialize().await?;
148
149 self.clients.write().await.insert(name.to_string(), client);
150 Ok(())
151 }
152
153 pub async fn client(&self, name: &str) -> Option<Arc<McpClient>> {
155 self.clients.read().await.get(name).cloned()
156 }
157
158 pub async fn list_tools(&self) -> Result<Vec<ToolDef>> {
162 let clients = self.clients.read().await;
163 let mut all_tools = Vec::new();
164
165 for (name, client) in clients.iter() {
166 if let Ok(mcp_tools) = client.list_tools().await {
167 let defs: Vec<ToolDef> = mcp_tools.iter().map(|t| t.to_tool_def()).collect();
168 let start = all_tools.len();
169 all_tools.extend(defs);
170 *self
171 .tool_cache
172 .write()
173 .await
174 .entry(name.clone())
175 .or_insert_with(Vec::new) = all_tools[start..].to_vec();
176 }
177 }
178
179 Ok(all_tools)
180 }
181
182 pub async fn cached_tools(&self, server_name: &str) -> Option<Vec<ToolDef>> {
184 self.tool_cache.read().await.get(server_name).cloned()
185 }
186
187 pub async fn call_tool(
189 &self,
190 server_name: &str,
191 tool_name: &str,
192 args: serde_json::Value,
193 ) -> Result<McpToolCallResult> {
194 let clients = self.clients.read().await;
195 let client = clients
196 .get(server_name)
197 .ok_or_else(|| anyhow!("MCP server '{}' not connected", server_name))?;
198
199 client.call_tool(tool_name, args).await
200 }
201
202 pub async fn shutdown_all(&self) -> Result<()> {
204 let mut clients = self.clients.write().await;
205
206 for (name, client) in clients.drain() {
207 if let Err(e) = client.shutdown().await {
208 tracing::warn!(server = %name, error = %e, "Error shutting down MCP server");
209 }
210 }
211
212 self.tool_cache.write().await.clear();
213 Ok(())
214 }
215
216 pub async fn refresh_tools(&self, server_name: &str) -> Result<Vec<ToolDef>> {
218 let clients = self.clients.read().await;
219 let client = clients
220 .get(server_name)
221 .ok_or_else(|| anyhow!("MCP server '{}' not connected", server_name))?;
222
223 let mcp_tools = client.refresh_tools().await?;
224 let defs: Vec<ToolDef> = mcp_tools.iter().map(|t| t.to_tool_def()).collect();
225
226 *self
227 .tool_cache
228 .write()
229 .await
230 .entry(server_name.to_string())
231 .or_insert_with(Vec::new) = defs.clone();
232
233 Ok(defs)
234 }
235
236 pub async fn clear_cache(&self, server_name: &str) {
238 self.tool_cache.write().await.remove(server_name);
239 }
240
241 pub async fn clear_all_caches(&self) {
243 self.tool_cache.write().await.clear();
244 }
245}
246
247impl Default for McpBridge {
248 fn default() -> Self {
249 Self::new()
250 }
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::Duration;

    /// Builder methods populate every field, and a new server defaults to enabled.
    #[test]
    fn test_mcp_server_builder() {
        let server = McpServer::new("test-server", "npx")
            .with_args(vec!["-y".to_string(), "@anthropic/mcp-server".to_string()])
            .with_env("DEBUG", "true");

        assert_eq!(server.name, "test-server");
        assert_eq!(server.command, "npx");
        assert_eq!(server.args, vec!["-y", "@anthropic/mcp-server"]);
        assert_eq!(server.env.get("DEBUG"), Some(&"true".to_string()));
        assert!(server.enabled);
    }

    /// A request serializes with the method name and the JSON-RPC version tag.
    #[test]
    fn test_mcp_request_serialization() {
        let request = McpRequest::new("tools/list");
        let json = serde_json::to_string(&request).unwrap();

        assert!(json.contains(r#""method":"tools/list""#));
        assert!(json.contains(r#""jsonrpc":"2.0""#));
    }

    #[test]
    fn test_mcp_request_with_params() {
        let request = McpRequest::new("tools/call").with_params(serde_json::json!({
            "name": "my_tool",
            "arguments": {"arg1": "value1"}
        }));

        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("my_tool"));
        assert!(json.contains("arg1"));
    }

    /// JSONL framing ends with a newline, and the payload before it round-trips.
    #[test]
    fn test_mcp_request_to_jsonl() {
        let request = McpRequest::new("initialize");
        let jsonl = request.to_jsonl().unwrap();

        assert_eq!(jsonl.last(), Some(&b'\n'));

        let json_str = String::from_utf8_lossy(&jsonl[..jsonl.len() - 1]);
        let parsed: McpRequest = serde_json::from_str(&json_str).unwrap();
        assert_eq!(parsed.method, "initialize");
    }

    #[test]
    fn test_mcp_response_result() {
        let response = McpResponse {
            jsonrpc: "2.0".to_string(),
            id: serde_json::json!(1),
            result: Some(serde_json::json!({"tools": []})),
            error: None,
        };

        assert!(!response.is_error());
        let result = response.clone().into_result().unwrap();
        assert!(result.get("tools").is_some());
    }

    #[test]
    fn test_mcp_response_error() {
        let response = McpResponse {
            jsonrpc: "2.0".to_string(),
            id: serde_json::json!(2),
            result: None,
            error: Some(McpError::internal_error("Something went wrong")),
        };

        assert!(response.is_error());
        let err = response.into_result().unwrap_err();
        assert!(err.to_string().contains("internal error"));
    }

    /// Error constructors use the standard JSON-RPC 2.0 error codes.
    #[test]
    fn test_mcp_error_codes() {
        assert_eq!(McpError::parse_error().code, -32700);
        assert_eq!(McpError::invalid_request("test").code, -32600);
        assert_eq!(McpError::method_not_found().code, -32601);
        assert_eq!(McpError::invalid_params().code, -32602);
        assert_eq!(McpError::internal_error("x").code, -32603);
        assert_eq!(McpError::server_error("x").code, -32000);
    }

    /// An argument without a `default` is required; one with a `default` is optional.
    #[test]
    fn test_mcp_tool_conversion() {
        let mcp_tool = McpTool {
            name: "test_tool".to_string(),
            description: "A test tool".to_string(),
            input_schema: serde_json::json!({
                "arg1": {
                    "type": "string",
                    "description": "First argument"
                },
                "arg2": {
                    "type": "number",
                    "description": "Second argument",
                    "default": "42"
                }
            }),
        };

        let tool_def = mcp_tool.to_tool_def();

        assert_eq!(tool_def.name, "test_tool");
        assert_eq!(tool_def.description, "A test tool");
        assert_eq!(tool_def.arguments.len(), 2);

        let arg1 = tool_def
            .arguments
            .iter()
            .find(|a| a.name == "arg1")
            .unwrap();
        assert!(arg1.required);
        assert_eq!(arg1.description, "First argument");

        let arg2 = tool_def
            .arguments
            .iter()
            .find(|a| a.name == "arg2")
            .unwrap();
        assert!(!arg2.required);
        assert_eq!(arg2.default, Some("42".to_string()));
    }

    #[test]
    fn test_bridge_registration() {
        let bridge = McpBridge::new();

        bridge.register_server(McpServer::new("test", "echo"));

        assert_eq!(bridge.servers(), vec!["test"]);
        assert!(bridge.get_server("test").is_some());
        assert!(bridge.get_server("missing").is_none());
    }

    #[tokio::test]
    async fn test_mcp_client_non_existent_command() {
        let server = McpServer::new("ghost", "nonexistent-binary-xyz");
        let client = McpClient::new(server);

        let result = client.initialize().await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Failed to spawn"));
    }

    /// Shutting down a client that was never initialized must not panic.
    #[tokio::test]
    async fn test_mcp_client_shutdown_no_panic() {
        let server = McpServer::new("test-shutdown", "echo");
        let client = McpClient::new(server);

        client.shutdown().await.expect("shutdown should succeed");
        assert!(!client.is_initialized().await);
    }

    /// A server that never answers must hit the configured timeout instead of hanging.
    #[tokio::test]
    async fn test_mcp_client_with_timeout() {
        let server = McpServer::new("test", "sleep").with_args(vec!["999".to_string()]);
        let client = McpClient::new(server).with_timeout(Duration::from_millis(100));

        let result = client.initialize().await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_bridge_initialize_all_empty() {
        let bridge = McpBridge::new();
        bridge
            .initialize_all()
            .await
            .expect("empty bridge should initialize");
    }

    /// Every failing server is reported, and none is left registered as connected.
    #[tokio::test]
    async fn test_bridge_initialize_all_fails_gracefully() {
        let bridge = McpBridge::new();
        bridge.register_server(McpServer::new("ghost", "nonexistent-cmd-xyz"));
        bridge.register_server(McpServer::new("ghost2", "nonexistent-cmd-abc"));

        let result = bridge.initialize_all().await;
        let message = result.unwrap_err().to_string();
        assert!(message.contains("ghost"));
        assert!(message.contains("ghost2"));
        assert!(bridge.client("ghost").await.is_none());
        assert!(bridge.client("ghost2").await.is_none());
    }

    /// Disabled servers are skipped entirely, so even a bogus command cannot fail.
    #[tokio::test]
    async fn test_bridge_initialize_all_skips_disabled() {
        let bridge = McpBridge::new();
        let mut server = McpServer::new("off", "nonexistent-cmd-xyz");
        server.enabled = false;
        bridge.register_server(server);

        bridge
            .initialize_all()
            .await
            .expect("disabled servers should be skipped");
        assert!(bridge.client("off").await.is_none());
    }

    #[tokio::test]
    async fn test_bridge_shutdown_all_empty() {
        let bridge = McpBridge::new();
        bridge
            .shutdown_all()
            .await
            .expect("empty bridge shutdown should succeed");
    }

    #[tokio::test]
    async fn test_bridge_call_tool_no_server() {
        let bridge = McpBridge::new();
        let result = bridge
            .call_tool("ghost", "tool", serde_json::json!({}))
            .await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not connected"));
    }

    #[tokio::test]
    async fn test_bridge_refresh_tools_no_server() {
        let bridge = McpBridge::new();
        let err = bridge
            .refresh_tools("ghost")
            .await
            .err()
            .expect("refreshing an unconnected server should fail");
        assert!(err.to_string().contains("not connected"));
    }

    #[tokio::test]
    async fn test_bridge_list_tools_empty() {
        let bridge = McpBridge::new();
        let tools = bridge
            .list_tools()
            .await
            .expect("listing tools with no clients should succeed");
        assert!(tools.is_empty());
    }

    #[tokio::test]
    async fn test_bridge_cache_empty_by_default() {
        let bridge = McpBridge::new();
        assert!(bridge.cached_tools("ghost").await.is_none());

        // Clearing missing entries is a no-op, not an error.
        bridge.clear_cache("ghost").await;
        bridge.clear_all_caches().await;
        assert!(bridge.cached_tools("ghost").await.is_none());
    }

    #[tokio::test]
    async fn test_bridge_initialize_server_not_found() {
        let bridge = McpBridge::new();
        let result = bridge.initialize_server("missing").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not found"));
    }

    #[test]
    fn test_mcp_client_debug() {
        let server = McpServer::new("debug-test", "echo");
        let client = McpClient::new(server);
        let debug = format!("{:?}", client);
        assert!(debug.contains("debug-test"));
    }

    #[cfg(unix)]
    #[tokio::test]
    #[ignore = "Requires bash shell environment with executable script support"]
    async fn test_jsonrpc_echo_server() {
        use std::os::unix::fs::PermissionsExt;

        // Bind the TempDir guard so it lives for the whole test. Dropping it deletes the
        // directory, which previously happened before the script was even written.
        let temp_dir = tempfile::tempdir().unwrap();
        let temp_script = temp_dir.path().join("mcp_echo.sh");
        std::fs::write(
            &temp_script,
            r#"#!/bin/bash
while IFS= read -r line; do
    echo "$line"
done
"#,
        )
        .unwrap();
        std::fs::set_permissions(&temp_script, std::fs::Permissions::from_mode(0o755)).unwrap();

        let bridge = McpBridge::new();
        bridge.register_server(
            McpServer::new("echo-server", "bash")
                .with_args(vec![temp_script.to_string_lossy().to_string()]),
        );

        bridge.initialize_all().await.unwrap();

        let client = bridge.client("echo-server").await.unwrap();
        let request =
            McpRequest::new("tools/list").with_params(serde_json::json!({"test": "value"}));
        let response = client.send_request(request).await;

        // An echoed request is not necessarily a valid response, so success is only logged.
        if response.is_ok() {
            tracing::info!("Echo server responded successfully");
        }
    }

    /// Initializing the same client twice must not panic, whatever the first call returned.
    #[tokio::test]
    async fn test_mcp_client_double_init_ignored() {
        let server = McpServer::new("echo", "echo");
        let client = McpClient::new(server);

        let _ = client.initialize().await;
        let _ = client.initialize().await;
        let _ = client.is_initialized().await;
    }
}