1pub mod client;
28pub mod protocol;
29
30pub use client::McpClient;
31pub use protocol::*;
32
33use std::collections::HashMap;
34use std::sync::Arc;
35
36use anyhow::{Result, anyhow};
37use tokio::sync::RwLock;
38
39pub struct McpBridge {
62 servers: parking_lot::RwLock<Vec<McpServer>>,
64 clients: RwLock<HashMap<String, Arc<McpClient>>>,
66 tool_cache: RwLock<HashMap<String, Vec<McpTool>>>,
68}
69
70impl McpBridge {
71 pub fn new() -> Self {
73 Self {
74 servers: parking_lot::RwLock::new(Vec::new()),
75 clients: RwLock::new(HashMap::new()),
76 tool_cache: RwLock::new(HashMap::new()),
77 }
78 }
79
80 pub fn register_server(&self, server: McpServer) {
87 let mut servers = self.servers.write();
88 let name = server.name.clone();
89 if let Some(existing) = servers.iter_mut().find(|s| s.name == name) {
90 tracing::warn!(
91 server = %name,
92 "Overwriting duplicate MCP server registration"
93 );
94 *existing = server;
95 } else {
96 servers.push(server);
97 }
98 }
99
100 pub fn servers(&self) -> Vec<String> {
102 self.servers.read().iter().map(|s| s.name.clone()).collect()
103 }
104
105 pub fn get_server(&self, name: &str) -> Option<McpServer> {
107 self.servers.read().iter().find(|s| s.name == name).cloned()
108 }
109
110 pub async fn initialize_all(&self) -> Result<()> {
115 let mut errors = Vec::new();
116
117 let server_list: Vec<McpServer> = self.servers.read().iter().cloned().collect();
118 for server in server_list {
119 if !server.enabled {
120 tracing::debug!(server = %server.name, "Skipping disabled MCP server");
121 continue;
122 }
123
124 let client = Arc::new(McpClient::new(server.clone()));
125 match client.initialize().await {
126 Ok(()) => {
127 self.clients
128 .write()
129 .await
130 .insert(server.name.clone(), client);
131 tracing::info!(server = %server.name, "MCP server started");
132 }
133 Err(e) => {
134 tracing::error!(server = %server.name, error = %e, "Failed to initialize MCP server");
135 errors.push(format!("{}: {}", server.name, e));
136 }
137 }
138 }
139
140 if errors.is_empty() {
141 Ok(())
142 } else {
143 Err(anyhow!("MCP initialization failed: {}", errors.join("; ")))
144 }
145 }
146
147 pub async fn initialize_server(&self, name: &str) -> Result<()> {
149 let server = self
150 .servers
151 .read()
152 .iter()
153 .find(|s| s.name == name)
154 .cloned()
155 .ok_or_else(|| anyhow!("MCP server '{name}' not found"))?;
156
157 let client = Arc::new(McpClient::new(server));
158 client.initialize().await?;
159
160 self.clients.write().await.insert(name.to_string(), client);
161 Ok(())
162 }
163
164 pub async fn client(&self, name: &str) -> Option<Arc<McpClient>> {
166 self.clients.read().await.get(name).cloned()
167 }
168
169 pub async fn list_tools(&self) -> Result<Vec<McpTool>> {
177 let clients: Vec<(String, Arc<McpClient>)> = self
178 .clients
179 .read()
180 .await
181 .iter()
182 .map(|(k, v)| (k.clone(), v.clone()))
183 .collect();
184
185 let mut all_tools = Vec::new();
186 for (name, client) in &clients {
187 if let Ok(mcp_tools) = client.list_tools().await {
188 let start = all_tools.len();
189 all_tools.extend(mcp_tools);
190 *self
191 .tool_cache
192 .write()
193 .await
194 .entry(name.clone())
195 .or_insert_with(Vec::new) = all_tools[start..].to_vec();
196 }
197 }
198
199 Ok(all_tools)
200 }
201
202 pub async fn cached_tools(&self, server_name: &str) -> Option<Vec<McpTool>> {
204 self.tool_cache.read().await.get(server_name).cloned()
205 }
206
207 pub async fn call_tool(
211 &self,
212 server_name: &str,
213 tool_name: &str,
214 args: serde_json::Value,
215 ) -> Result<McpToolCallResult> {
216 let client = {
217 let clients = self.clients.read().await;
218 clients
219 .get(server_name)
220 .cloned()
221 .ok_or_else(|| anyhow!("MCP server '{server_name}' not connected"))?
222 };
223
224 client.call_tool(tool_name, args).await
225 }
226
227 pub async fn shutdown_all(&self) -> Result<()> {
229 let mut clients = self.clients.write().await;
230
231 for (name, client) in clients.drain() {
232 if let Err(e) = client.shutdown().await {
233 tracing::warn!(server = %name, error = %e, "Error shutting down MCP server");
234 }
235 }
236
237 self.tool_cache.write().await.clear();
238 Ok(())
239 }
240
241 pub async fn refresh_tools(&self, server_name: &str) -> Result<Vec<McpTool>> {
245 let client = {
246 let clients = self.clients.read().await;
247 clients
248 .get(server_name)
249 .cloned()
250 .ok_or_else(|| anyhow!("MCP server '{server_name}' not connected"))?
251 };
252
253 let mcp_tools = client.refresh_tools().await?;
254
255 *self
256 .tool_cache
257 .write()
258 .await
259 .entry(server_name.to_string())
260 .or_insert_with(Vec::new) = mcp_tools.clone();
261
262 Ok(mcp_tools)
263 }
264
265 pub async fn clear_cache(&self, server_name: &str) {
267 self.tool_cache.write().await.remove(server_name);
268 }
269
270 pub async fn remove_server(&self, name: &str) -> Result<()> {
272 if let Some(client) = self.clients.write().await.remove(name)
274 && let Err(e) = client.shutdown().await
275 {
276 tracing::warn!(server = %name, error = %e, "Error shutting down MCP server during removal");
277 }
278 let found = {
280 let mut servers = self.servers.write();
281 let len_before = servers.len();
282 servers.retain(|s| s.name != name);
283 servers.len() != len_before
284 };
285 if !found {
286 return Err(anyhow!("MCP server '{name}' not found"));
287 }
288 self.tool_cache.write().await.remove(name);
290 Ok(())
291 }
292
293 pub async fn toggle_server(&self, name: &str) -> Result<bool> {
295 let new_state = {
298 let mut servers = self.servers.write();
299 let server = servers
300 .iter_mut()
301 .find(|s| s.name == name)
302 .ok_or_else(|| anyhow!("MCP server '{name}' not found"))?;
303 server.enabled = !server.enabled;
304 server.enabled
305 };
306
307 if !new_state {
309 if let Some(client) = self.clients.write().await.remove(name)
310 && let Err(e) = client.shutdown().await
311 {
312 tracing::warn!(server = %name, error = %e, "Error shutting down MCP server on disable");
313 }
314 self.tool_cache.write().await.remove(name);
315 }
316
317 Ok(new_state)
318 }
319
320 pub async fn clear_all_caches(&self) {
322 self.tool_cache.write().await.clear();
323 }
324}
325
326impl Default for McpBridge {
327 fn default() -> Self {
328 Self::new()
329 }
330}
331
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::Duration;

    #[test]
    fn test_mcp_server_builder() {
        let server = McpServer::new("test-server", "npx")
            .with_args(vec!["-y".to_string(), "@anthropic/mcp-server".to_string()])
            .with_env("DEBUG", "true");

        assert_eq!(server.name, "test-server");
        assert_eq!(server.command, "npx");
        assert_eq!(server.args, vec!["-y", "@anthropic/mcp-server"]);
        assert_eq!(server.env.get("DEBUG"), Some(&"true".to_string()));
        assert!(server.enabled);
    }

    #[test]
    fn test_mcp_request_serialization() {
        let request = McpRequest::new("tools/list");
        let json = serde_json::to_string(&request).unwrap();

        assert!(json.contains(r#""method":"tools/list""#));
        assert!(json.contains(r#""jsonrpc":"2.0""#));
    }

    #[test]
    fn test_mcp_request_with_params() {
        let request = McpRequest::new("tools/call").with_params(serde_json::json!({
            "name": "my_tool",
            "arguments": {"arg1": "value1"}
        }));

        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("my_tool"));
        assert!(json.contains("arg1"));
    }

    #[test]
    fn test_mcp_request_to_jsonl() {
        let request = McpRequest::new("initialize");
        let jsonl = request.to_jsonl().unwrap();

        // JSONL framing: exactly one trailing newline after the JSON payload.
        assert_eq!(jsonl.last(), Some(&b'\n'));

        let json_str = String::from_utf8_lossy(&jsonl[..jsonl.len() - 1]);
        let parsed: McpRequest = serde_json::from_str(&json_str).unwrap();
        assert_eq!(parsed.method, "initialize");
    }

    #[test]
    fn test_mcp_response_result() {
        let response = McpResponse {
            jsonrpc: "2.0".to_string(),
            id: serde_json::json!(1),
            result: Some(serde_json::json!({"tools": []})),
            error: None,
        };

        assert!(!response.is_error());
        let result = response.clone().into_result().unwrap();
        assert!(result.get("tools").is_some());
    }

    #[test]
    fn test_mcp_response_error() {
        let response = McpResponse {
            jsonrpc: "2.0".to_string(),
            id: serde_json::json!(2),
            result: None,
            error: Some(McpError::internal_error("Something went wrong")),
        };

        assert!(response.is_error());
        let err = response.into_result().unwrap_err();
        assert!(err.to_string().contains("internal error"));
    }

    #[test]
    fn test_mcp_error_codes() {
        assert_eq!(McpError::parse_error().code, -32700);
        assert_eq!(McpError::invalid_request("test").code, -32600);
        assert_eq!(McpError::method_not_found().code, -32601);
        assert_eq!(McpError::invalid_params().code, -32602);
        assert_eq!(McpError::internal_error("x").code, -32603);
        assert_eq!(McpError::server_error("x").code, -32000);
    }

    #[test]
    fn test_bridge_registration() {
        let bridge = McpBridge::new();

        bridge.register_server(McpServer::new("test", "echo"));

        assert_eq!(bridge.servers(), vec!["test"]);
        assert!(bridge.get_server("test").is_some());
        assert!(bridge.get_server("missing").is_none());
    }

    #[test]
    fn test_bridge_register_duplicate_overwrites() {
        let bridge = McpBridge::new();
        bridge.register_server(McpServer::new("dup", "echo"));
        bridge.register_server(McpServer::new("dup", "cat"));

        // Still a single entry, now carrying the second configuration.
        assert_eq!(bridge.servers(), vec!["dup"]);
        assert!(bridge.get_server("dup").is_some_and(|s| s.command == "cat"));
    }

    #[tokio::test]
    async fn test_mcp_client_non_existent_command() {
        let server = McpServer::new("ghost", "nonexistent-binary-xyz");
        let client = McpClient::new(server);

        let result = client.initialize().await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Failed to spawn"));
    }

    #[tokio::test]
    async fn test_mcp_client_shutdown_no_panic() {
        let server = McpServer::new("test-shutdown", "echo");
        let client = McpClient::new(server);

        // Shutting down a client that was never initialized must be harmless.
        client.shutdown().await.expect("shutdown should succeed");
        assert!(!client.is_initialized().await);
    }

    #[tokio::test]
    async fn test_mcp_client_with_timeout() {
        let server = McpServer::new("test", "sleep").with_args(vec!["999".to_string()]);
        let client = McpClient::new(server).with_timeout(Duration::from_millis(100));

        // `sleep` never answers the handshake, so the short timeout must fire.
        let result = client.initialize().await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_bridge_initialize_all_empty() {
        let bridge = McpBridge::new();
        bridge
            .initialize_all()
            .await
            .expect("empty bridge should initialize");
    }

    #[tokio::test]
    async fn test_bridge_initialize_all_fails_gracefully() {
        let bridge = McpBridge::new();
        bridge.register_server(McpServer::new("ghost", "nonexistent-cmd-xyz"));
        bridge.register_server(McpServer::new("ghost2", "nonexistent-cmd-abc"));

        let result = bridge.initialize_all().await;
        let message = result.expect_err("both servers should fail").to_string();
        // Every failing server is reported, and none is left registered as connected.
        assert!(message.contains("ghost:"));
        assert!(message.contains("ghost2:"));
        assert!(bridge.client("ghost").await.is_none());
        assert!(bridge.client("ghost2").await.is_none());
    }

    #[tokio::test]
    async fn test_bridge_initialize_all_skips_disabled() {
        let bridge = McpBridge::new();
        bridge.register_server(McpServer::new("ghost", "nonexistent-cmd-xyz"));
        bridge.toggle_server("ghost").await.expect("registered server");

        bridge
            .initialize_all()
            .await
            .expect("disabled servers are skipped, not failed");
        assert!(bridge.client("ghost").await.is_none());
    }

    #[tokio::test]
    async fn test_bridge_shutdown_all_empty() {
        let bridge = McpBridge::new();
        bridge
            .shutdown_all()
            .await
            .expect("empty bridge shutdown should succeed");
    }

    #[tokio::test]
    async fn test_bridge_call_tool_no_server() {
        let bridge = McpBridge::new();
        let result = bridge
            .call_tool("ghost", "tool", serde_json::json!({}))
            .await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not connected"));
    }

    #[tokio::test]
    async fn test_bridge_refresh_tools_no_server() {
        let bridge = McpBridge::new();
        let err = bridge.refresh_tools("ghost").await.unwrap_err();
        assert!(err.to_string().contains("not connected"));
    }

    #[tokio::test]
    async fn test_bridge_list_tools_empty() {
        let bridge = McpBridge::new();
        let tools = bridge
            .list_tools()
            .await
            .expect("no connected servers means nothing can fail");
        assert!(tools.is_empty());
        assert!(bridge.cached_tools("anything").await.is_none());
    }

    #[tokio::test]
    async fn test_bridge_initialize_server_not_found() {
        let bridge = McpBridge::new();
        let result = bridge.initialize_server("missing").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_bridge_toggle_server() {
        let bridge = McpBridge::new();
        bridge.register_server(McpServer::new("toggle", "echo"));

        // Servers start enabled, so the first toggle disables.
        assert!(!bridge.toggle_server("toggle").await.expect("registered server"));
        assert!(bridge.get_server("toggle").is_some_and(|s| !s.enabled));
        assert!(bridge.toggle_server("toggle").await.expect("registered server"));
        assert!(bridge.toggle_server("missing").await.is_err());
    }

    #[tokio::test]
    async fn test_bridge_remove_server() {
        let bridge = McpBridge::new();
        bridge.register_server(McpServer::new("gone", "echo"));

        bridge.remove_server("gone").await.expect("registered server");
        assert!(bridge.servers().is_empty());

        let err = bridge.remove_server("gone").await.unwrap_err();
        assert!(err.to_string().contains("not found"));
    }

    #[test]
    fn test_mcp_client_debug() {
        let server = McpServer::new("debug-test", "echo");
        let client = McpClient::new(server);
        let debug = format!("{client:?}");
        assert!(debug.contains("debug-test"));
    }

    #[tokio::test]
    async fn test_mcp_client_double_init_ignored() {
        let server = McpServer::new("echo", "echo");
        let client = McpClient::new(server);

        // `echo` is not a real MCP server, so the outcome of each attempt is not
        // asserted. The test checks that a second initialize call after the
        // first one (successful or not) does not panic.
        let _ = client.initialize().await;
        let _ = client.initialize().await;
        let _ = client.is_initialized().await;
    }
}