1mod client;
30mod protocol;
31
32pub use client::McpClient;
33pub use protocol::*;
34
35use std::collections::HashMap;
36use std::sync::Arc;
37
38use anyhow::{anyhow, Result};
39use tokio::sync::RwLock;
40
41use crate::program::ToolDef;
42
43pub struct McpBridge {
66 servers: parking_lot::RwLock<Vec<McpServer>>,
68 clients: RwLock<HashMap<String, Arc<McpClient>>>,
70 tool_cache: RwLock<HashMap<String, Vec<ToolDef>>>,
72}
73
74impl McpBridge {
75 pub fn new() -> Self {
77 Self {
78 servers: parking_lot::RwLock::new(Vec::new()),
79 clients: RwLock::new(HashMap::new()),
80 tool_cache: RwLock::new(HashMap::new()),
81 }
82 }
83
84 pub fn register_server(&self, server: McpServer) {
86 self.servers.write().push(server);
87 }
88
89 pub fn servers(&self) -> Vec<String> {
91 self.servers.read().iter().map(|s| s.name.clone()).collect()
92 }
93
94 pub fn get_server(&self, name: &str) -> Option<McpServer> {
96 self.servers.read().iter().find(|s| s.name == name).cloned()
97 }
98
99 pub async fn initialize_all(&self) -> Result<()> {
104 let mut errors = Vec::new();
105
106 let server_list: Vec<McpServer> = self.servers.read().iter().cloned().collect();
107 for server in server_list {
108 if !server.enabled {
109 tracing::debug!(server = %server.name, "Skipping disabled MCP server");
110 continue;
111 }
112
113 let client = Arc::new(McpClient::new(server.clone()));
114 match client.initialize().await {
115 Ok(()) => {
116 self.clients
117 .write()
118 .await
119 .insert(server.name.clone(), client);
120 tracing::info!(server = %server.name, "MCP server started");
121 }
122 Err(e) => {
123 tracing::error!(server = %server.name, error = %e, "Failed to initialize MCP server");
124 errors.push(format!("{}: {}", server.name, e));
125 }
126 }
127 }
128
129 if errors.is_empty() {
130 Ok(())
131 } else {
132 Err(anyhow!("MCP initialization failed: {}", errors.join("; ")))
133 }
134 }
135
136 pub async fn initialize_server(&self, name: &str) -> Result<()> {
138 let server = self
139 .servers
140 .read()
141 .iter()
142 .find(|s| s.name == name)
143 .cloned()
144 .ok_or_else(|| anyhow!("MCP server '{}' not found", name))?;
145
146 let client = Arc::new(McpClient::new(server));
147 client.initialize().await?;
148
149 self.clients.write().await.insert(name.to_string(), client);
150 Ok(())
151 }
152
153 pub async fn client(&self, name: &str) -> Option<Arc<McpClient>> {
155 self.clients.read().await.get(name).cloned()
156 }
157
158 pub async fn list_tools(&self) -> Result<Vec<ToolDef>> {
162 let clients = self.clients.read().await;
163 let mut all_tools = Vec::new();
164
165 for (name, client) in clients.iter() {
166 if let Ok(mcp_tools) = client.list_tools().await {
167 let defs: Vec<ToolDef> = mcp_tools.iter().map(|t| t.to_tool_def()).collect();
168 let start = all_tools.len();
169 all_tools.extend(defs);
170 *self
171 .tool_cache
172 .write()
173 .await
174 .entry(name.clone())
175 .or_insert_with(Vec::new) = all_tools[start..].to_vec();
176 }
177 }
178
179 Ok(all_tools)
180 }
181
182 pub async fn cached_tools(&self, server_name: &str) -> Option<Vec<ToolDef>> {
184 self.tool_cache.read().await.get(server_name).cloned()
185 }
186
187 pub async fn call_tool(
189 &self,
190 server_name: &str,
191 tool_name: &str,
192 args: serde_json::Value,
193 ) -> Result<McpToolCallResult> {
194 let clients = self.clients.read().await;
195 let client = clients
196 .get(server_name)
197 .ok_or_else(|| anyhow!("MCP server '{}' not connected", server_name))?;
198
199 client.call_tool(tool_name, args).await
200 }
201
202 pub async fn shutdown_all(&self) -> Result<()> {
204 let mut clients = self.clients.write().await;
205
206 for (name, client) in clients.drain() {
207 if let Err(e) = client.shutdown().await {
208 tracing::warn!(server = %name, error = %e, "Error shutting down MCP server");
209 }
210 }
211
212 self.tool_cache.write().await.clear();
213 Ok(())
214 }
215
216 pub async fn refresh_tools(&self, server_name: &str) -> Result<Vec<ToolDef>> {
218 let clients = self.clients.read().await;
219 let client = clients
220 .get(server_name)
221 .ok_or_else(|| anyhow!("MCP server '{}' not connected", server_name))?;
222
223 let mcp_tools = client.refresh_tools().await?;
224 let defs: Vec<ToolDef> = mcp_tools.iter().map(|t| t.to_tool_def()).collect();
225
226 *self
227 .tool_cache
228 .write()
229 .await
230 .entry(server_name.to_string())
231 .or_insert_with(Vec::new) = defs.clone();
232
233 Ok(defs)
234 }
235
236 pub async fn clear_cache(&self, server_name: &str) {
238 self.tool_cache.write().await.remove(server_name);
239 }
240
241 pub async fn clear_all_caches(&self) {
243 self.tool_cache.write().await.clear();
244 }
245}
246
247impl Default for McpBridge {
248 fn default() -> Self {
249 Self::new()
250 }
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::Duration;

    /// The builder sets name, command, args and env, and leaves the server enabled.
    #[test]
    fn test_mcp_server_builder() {
        let server = McpServer::new("test-server", "npx")
            .with_args(vec!["-y".to_string(), "@anthropic/mcp-server".to_string()])
            .with_env("DEBUG", "true");

        assert_eq!(server.name, "test-server");
        assert_eq!(server.command, "npx");
        assert_eq!(server.args, vec!["-y", "@anthropic/mcp-server"]);
        assert_eq!(server.env.get("DEBUG"), Some(&"true".to_string()));
        assert!(server.enabled);
    }

    /// A serialized request carries the method name and the JSON-RPC version.
    #[test]
    fn test_mcp_request_serialization() {
        let request = McpRequest::new("tools/list");
        let json = serde_json::to_string(&request).unwrap();

        assert!(json.contains(r#""method":"tools/list""#));
        assert!(json.contains(r#""jsonrpc":"2.0""#));
    }

    #[test]
    fn test_mcp_request_with_params() {
        let request = McpRequest::new("tools/call").with_params(serde_json::json!({
            "name": "my_tool",
            "arguments": {"arg1": "value1"}
        }));

        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("my_tool"));
        assert!(json.contains("arg1"));
    }

    /// `to_jsonl` ends with a newline and the payload before it round-trips through serde.
    #[test]
    fn test_mcp_request_to_jsonl() {
        let request = McpRequest::new("initialize");
        let jsonl = request.to_jsonl().unwrap();

        assert_eq!(jsonl.last(), Some(&b'\n'));

        let json_str = String::from_utf8_lossy(&jsonl[..jsonl.len() - 1]);
        let parsed: McpRequest = serde_json::from_str(&json_str).unwrap();
        assert_eq!(parsed.method, "initialize");
    }

    #[test]
    fn test_mcp_response_result() {
        let response = McpResponse {
            jsonrpc: "2.0".to_string(),
            id: serde_json::json!(1),
            result: Some(serde_json::json!({"tools": []})),
            error: None,
        };

        assert!(!response.is_error());
        let result = response.clone().into_result().unwrap();
        assert!(result.get("tools").is_some());
    }

    #[test]
    fn test_mcp_response_error() {
        let response = McpResponse {
            jsonrpc: "2.0".to_string(),
            id: serde_json::json!(2),
            result: None,
            error: Some(McpError::internal_error("Something went wrong")),
        };

        assert!(response.is_error());
        let err = response.into_result().unwrap_err();
        assert!(err.to_string().contains("internal error"));
    }

    /// Constructor helpers map to the standard JSON-RPC error codes.
    #[test]
    fn test_mcp_error_codes() {
        assert_eq!(McpError::parse_error().code, -32700);
        assert_eq!(McpError::invalid_request("test").code, -32600);
        assert_eq!(McpError::method_not_found().code, -32601);
        assert_eq!(McpError::invalid_params().code, -32602);
        assert_eq!(McpError::internal_error("x").code, -32603);
        assert_eq!(McpError::server_error("x").code, -32000);
    }

    /// JSON-schema properties become `ToolDef` arguments with required flags and defaults.
    #[test]
    fn test_mcp_tool_conversion() {
        let mcp_tool = McpTool {
            name: "test_tool".to_string(),
            description: "A test tool".to_string(),
            input_schema: serde_json::json!({
                "type": "object",
                "properties": {
                    "arg1": {
                        "type": "string",
                        "description": "First argument"
                    },
                    "arg2": {
                        "type": "number",
                        "description": "Second argument",
                        "default": "42"
                    }
                },
                "required": ["arg1"]
            }),
        };

        let tool_def = mcp_tool.to_tool_def();

        assert_eq!(tool_def.name, "test_tool");
        assert_eq!(tool_def.description, "A test tool");
        assert_eq!(tool_def.arguments.len(), 2);

        let arg1 = tool_def
            .arguments
            .iter()
            .find(|a| a.name == "arg1")
            .unwrap();
        assert!(arg1.required);
        assert_eq!(arg1.description, "First argument");

        let arg2 = tool_def
            .arguments
            .iter()
            .find(|a| a.name == "arg2")
            .unwrap();
        assert!(!arg2.required);
        assert_eq!(arg2.default, Some("42".to_string()));
    }

    #[test]
    fn test_bridge_registration() {
        let bridge = McpBridge::new();

        bridge.register_server(McpServer::new("test", "echo"));

        assert_eq!(bridge.servers(), vec!["test"]);
        assert!(bridge.get_server("test").is_some());
        assert!(bridge.get_server("missing").is_none());
    }

    #[tokio::test]
    async fn test_mcp_client_non_existent_command() {
        let server = McpServer::new("ghost", "nonexistent-binary-xyz");
        let client = McpClient::new(server);

        let result = client.initialize().await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Failed to spawn"));
    }

    /// Shutting down a client that was never initialized succeeds.
    #[tokio::test]
    async fn test_mcp_client_shutdown_no_panic() {
        let server = McpServer::new("test-shutdown", "echo");
        let client = McpClient::new(server);

        client.shutdown().await.expect("shutdown should succeed");
        assert!(!client.is_initialized().await);
    }

    /// A server that never answers makes `initialize` fail within the configured timeout.
    #[tokio::test]
    async fn test_mcp_client_with_timeout() {
        let server = McpServer::new("test", "sleep").with_args(vec!["999".to_string()]);
        let client = McpClient::new(server).with_timeout(Duration::from_millis(100));

        let result = client.initialize().await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_bridge_initialize_all_empty() {
        let bridge = McpBridge::new();
        bridge
            .initialize_all()
            .await
            .expect("empty bridge should initialize");
    }

    /// Every failing server is attempted and reported, and none is left connected.
    #[tokio::test]
    async fn test_bridge_initialize_all_fails_gracefully() {
        let bridge = McpBridge::new();
        bridge.register_server(McpServer::new("ghost", "nonexistent-cmd-xyz"));
        bridge.register_server(McpServer::new("ghost2", "nonexistent-cmd-abc"));

        let result = bridge.initialize_all().await;
        assert!(result.is_err());

        let message = result.unwrap_err().to_string();
        assert!(message.contains("ghost:"));
        assert!(message.contains("ghost2:"));
        assert!(bridge.client("ghost").await.is_none());
        assert!(bridge.client("ghost2").await.is_none());
    }

    #[tokio::test]
    async fn test_bridge_shutdown_all_empty() {
        let bridge = McpBridge::new();
        bridge
            .shutdown_all()
            .await
            .expect("empty bridge shutdown should succeed");
    }

    #[tokio::test]
    async fn test_bridge_call_tool_no_server() {
        let bridge = McpBridge::new();
        let result = bridge
            .call_tool("ghost", "tool", serde_json::json!({}))
            .await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not connected"));
    }

    #[tokio::test]
    async fn test_bridge_initialize_server_not_found() {
        let bridge = McpBridge::new();
        let result = bridge.initialize_server("missing").await;
        assert!(result.is_err());
    }

    #[test]
    fn test_mcp_client_debug() {
        let server = McpServer::new("debug-test", "echo");
        let client = McpClient::new(server);
        let debug = format!("{:?}", client);
        assert!(debug.contains("debug-test"));
    }

    #[cfg(unix)]
    #[tokio::test]
    #[ignore = "Requires bash shell environment with executable script support"]
    async fn test_jsonrpc_echo_server() {
        use std::os::unix::fs::PermissionsExt;

        // Keep the `TempDir` guard alive for the whole test. It deletes the directory when
        // dropped, so chaining `.path()` off a temporary would remove the directory before the
        // script could be written into it.
        let temp_dir = tempfile::tempdir().unwrap();
        let temp_script = temp_dir.path().join("mcp_echo.sh");
        std::fs::write(
            &temp_script,
            r#"#!/bin/bash
while IFS= read -r line; do
    echo "$line"
done
"#,
        )
        .unwrap();
        std::fs::set_permissions(&temp_script, std::fs::Permissions::from_mode(0o755)).unwrap();

        let bridge = McpBridge::new();
        bridge.register_server(
            McpServer::new("echo-server", "bash")
                .with_args(vec![temp_script.to_string_lossy().into_owned()]),
        );

        bridge.initialize_all().await.unwrap();

        let client = bridge.client("echo-server").await.unwrap();
        let request =
            McpRequest::new("tools/list").with_params(serde_json::json!({"test": "value"}));
        let response = client.send_request(request).await;

        // The script replays requests rather than producing real responses, so success is only
        // logged, not asserted.
        if response.is_ok() {
            tracing::info!("Echo server responded successfully");
        }

        // Shut the client down before the temporary directory is removed.
        bridge.shutdown_all().await.expect("shutdown should succeed");
    }

    /// Calling `initialize` twice on one client must not panic. `echo` is not an MCP server,
    /// so both attempts presumably fail; their results are deliberately ignored.
    #[tokio::test]
    async fn test_mcp_client_double_init_ignored() {
        let server = McpServer::new("echo", "echo");
        let client = McpClient::new(server);

        let _ = client.initialize().await;
        let _ = client.initialize().await;
        let _ = client.is_initialized().await;
    }
}