1pub mod client;
28pub mod protocol;
29
30pub use client::McpClient;
31pub use protocol::*;
32
33use std::collections::HashMap;
34use std::sync::Arc;
35
36use anyhow::{anyhow, Result};
37use tokio::sync::RwLock;
38
39pub struct McpBridge {
62 servers: parking_lot::RwLock<Vec<McpServer>>,
64 clients: RwLock<HashMap<String, Arc<McpClient>>>,
66 tool_cache: RwLock<HashMap<String, Vec<McpTool>>>,
68}
69
70impl McpBridge {
71 pub fn new() -> Self {
73 Self {
74 servers: parking_lot::RwLock::new(Vec::new()),
75 clients: RwLock::new(HashMap::new()),
76 tool_cache: RwLock::new(HashMap::new()),
77 }
78 }
79
80 pub fn register_server(&self, server: McpServer) {
82 self.servers.write().push(server);
83 }
84
85 pub fn servers(&self) -> Vec<String> {
87 self.servers.read().iter().map(|s| s.name.clone()).collect()
88 }
89
90 pub fn get_server(&self, name: &str) -> Option<McpServer> {
92 self.servers.read().iter().find(|s| s.name == name).cloned()
93 }
94
95 pub async fn initialize_all(&self) -> Result<()> {
100 let mut errors = Vec::new();
101
102 let server_list: Vec<McpServer> = self.servers.read().iter().cloned().collect();
103 for server in server_list {
104 if !server.enabled {
105 tracing::debug!(server = %server.name, "Skipping disabled MCP server");
106 continue;
107 }
108
109 let client = Arc::new(McpClient::new(server.clone()));
110 match client.initialize().await {
111 Ok(()) => {
112 self.clients
113 .write()
114 .await
115 .insert(server.name.clone(), client);
116 tracing::info!(server = %server.name, "MCP server started");
117 }
118 Err(e) => {
119 tracing::error!(server = %server.name, error = %e, "Failed to initialize MCP server");
120 errors.push(format!("{}: {}", server.name, e));
121 }
122 }
123 }
124
125 if errors.is_empty() {
126 Ok(())
127 } else {
128 Err(anyhow!("MCP initialization failed: {}", errors.join("; ")))
129 }
130 }
131
132 pub async fn initialize_server(&self, name: &str) -> Result<()> {
134 let server = self
135 .servers
136 .read()
137 .iter()
138 .find(|s| s.name == name)
139 .cloned()
140 .ok_or_else(|| anyhow!("MCP server '{}' not found", name))?;
141
142 let client = Arc::new(McpClient::new(server));
143 client.initialize().await?;
144
145 self.clients.write().await.insert(name.to_string(), client);
146 Ok(())
147 }
148
149 pub async fn client(&self, name: &str) -> Option<Arc<McpClient>> {
151 self.clients.read().await.get(name).cloned()
152 }
153
154 pub async fn list_tools(&self) -> Result<Vec<McpTool>> {
158 let clients = self.clients.read().await;
159 let mut all_tools = Vec::new();
160
161 for (name, client) in clients.iter() {
162 if let Ok(mcp_tools) = client.list_tools().await {
163 let start = all_tools.len();
164 all_tools.extend(mcp_tools);
165 *self
166 .tool_cache
167 .write()
168 .await
169 .entry(name.clone())
170 .or_insert_with(Vec::new) = all_tools[start..].to_vec();
171 }
172 }
173
174 Ok(all_tools)
175 }
176
177 pub async fn cached_tools(&self, server_name: &str) -> Option<Vec<McpTool>> {
179 self.tool_cache.read().await.get(server_name).cloned()
180 }
181
182 pub async fn call_tool(
184 &self,
185 server_name: &str,
186 tool_name: &str,
187 args: serde_json::Value,
188 ) -> Result<McpToolCallResult> {
189 let clients = self.clients.read().await;
190 let client = clients
191 .get(server_name)
192 .ok_or_else(|| anyhow!("MCP server '{}' not connected", server_name))?;
193
194 client.call_tool(tool_name, args).await
195 }
196
197 pub async fn shutdown_all(&self) -> Result<()> {
199 let mut clients = self.clients.write().await;
200
201 for (name, client) in clients.drain() {
202 if let Err(e) = client.shutdown().await {
203 tracing::warn!(server = %name, error = %e, "Error shutting down MCP server");
204 }
205 }
206
207 self.tool_cache.write().await.clear();
208 Ok(())
209 }
210
211 pub async fn refresh_tools(&self, server_name: &str) -> Result<Vec<McpTool>> {
213 let clients = self.clients.read().await;
214 let client = clients
215 .get(server_name)
216 .ok_or_else(|| anyhow!("MCP server '{}' not connected", server_name))?;
217
218 let mcp_tools = client.refresh_tools().await?;
219
220 *self
221 .tool_cache
222 .write()
223 .await
224 .entry(server_name.to_string())
225 .or_insert_with(Vec::new) = mcp_tools.clone();
226
227 Ok(mcp_tools)
228 }
229
230 pub async fn clear_cache(&self, server_name: &str) {
232 self.tool_cache.write().await.remove(server_name);
233 }
234
235 pub async fn clear_all_caches(&self) {
237 self.tool_cache.write().await.clear();
238 }
239}
240
241impl Default for McpBridge {
242 fn default() -> Self {
243 Self::new()
244 }
245}
246
/// Unit tests for the protocol types, the client's failure paths and the
/// bridge's bookkeeping. No real MCP server is required.
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::Duration;

    /// The builder records name, command, args and env, and enables the server.
    #[test]
    fn test_mcp_server_builder() {
        let server = McpServer::new("test-server", "npx")
            .with_args(vec!["-y".to_string(), "@anthropic/mcp-server".to_string()])
            .with_env("DEBUG", "true");

        assert_eq!(server.name, "test-server");
        assert_eq!(server.command, "npx");
        assert_eq!(server.args, vec!["-y", "@anthropic/mcp-server"]);
        assert_eq!(server.env.get("DEBUG"), Some(&"true".to_string()));
        assert!(server.enabled);
    }

    /// A serialized request carries the method and the JSON-RPC version.
    #[test]
    fn test_mcp_request_serialization() {
        let request = McpRequest::new("tools/list");
        let json = serde_json::to_string(&request).unwrap();

        assert!(json.contains(r#""method":"tools/list""#));
        assert!(json.contains(r#""jsonrpc":"2.0""#));
    }

    /// Parameters attached with `with_params` appear in the serialized request.
    #[test]
    fn test_mcp_request_with_params() {
        let request = McpRequest::new("tools/call").with_params(serde_json::json!({
            "name": "my_tool",
            "arguments": {"arg1": "value1"}
        }));

        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("my_tool"));
        assert!(json.contains("arg1"));
    }

    /// `to_jsonl` emits one newline-terminated JSON document that parses back.
    #[test]
    fn test_mcp_request_to_jsonl() {
        let request = McpRequest::new("initialize");
        let jsonl = request.to_jsonl().unwrap();

        // `split_last` replaces the manual `len() - 1` slice, so an empty
        // buffer fails with a clear message rather than an underflow panic.
        let (terminator, body) = jsonl
            .split_last()
            .expect("to_jsonl must produce at least the trailing newline");
        assert_eq!(*terminator, b'\n');

        let parsed: McpRequest = serde_json::from_slice(body).unwrap();
        assert_eq!(parsed.method, "initialize");
    }

    /// A response with a `result` is not an error and yields that result.
    #[test]
    fn test_mcp_response_result() {
        let response = McpResponse {
            jsonrpc: "2.0".to_string(),
            id: serde_json::json!(1),
            result: Some(serde_json::json!({"tools": []})),
            error: None,
        };

        assert!(!response.is_error());
        let result = response.into_result().unwrap();
        assert!(result.get("tools").is_some());
    }

    /// A response with an `error` is reported as such by `into_result`.
    #[test]
    fn test_mcp_response_error() {
        let response = McpResponse {
            jsonrpc: "2.0".to_string(),
            id: serde_json::json!(2),
            result: None,
            error: Some(McpError::internal_error("Something went wrong")),
        };

        assert!(response.is_error());
        let err = response.into_result().unwrap_err();
        assert!(err.to_string().contains("internal error"));
    }

    /// Each error constructor uses the expected JSON-RPC error code.
    #[test]
    fn test_mcp_error_codes() {
        assert_eq!(McpError::parse_error().code, -32700);
        assert_eq!(McpError::invalid_request("test").code, -32600);
        assert_eq!(McpError::method_not_found().code, -32601);
        assert_eq!(McpError::invalid_params().code, -32602);
        assert_eq!(McpError::internal_error("x").code, -32603);
        assert_eq!(McpError::server_error("x").code, -32000);
    }

    /// Registered servers are listed and retrievable by name.
    #[test]
    fn test_bridge_registration() {
        let bridge = McpBridge::new();

        bridge.register_server(McpServer::new("test", "echo"));

        assert_eq!(bridge.servers(), vec!["test"]);
        assert!(bridge.get_server("test").is_some());
        assert!(bridge.get_server("missing").is_none());
    }

    /// Initializing against a binary that does not exist reports a spawn error.
    #[tokio::test]
    async fn test_mcp_client_non_existent_command() {
        let server = McpServer::new("ghost", "nonexistent-binary-xyz");
        let client = McpClient::new(server);

        let result = client.initialize().await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Failed to spawn"));
    }

    /// Shutting down a client that was never initialized succeeds.
    #[tokio::test]
    async fn test_mcp_client_shutdown_no_panic() {
        let server = McpServer::new("test-shutdown", "echo");
        let client = McpClient::new(server);

        client.shutdown().await.expect("shutdown should succeed");
        assert!(!client.is_initialized().await);
    }

    /// A server that never answers makes `initialize` fail.
    ///
    /// NOTE(review): presumably the 100 ms timeout is what fails here; on a
    /// system without a `sleep` binary the spawn itself fails instead.
    #[tokio::test]
    async fn test_mcp_client_with_timeout() {
        let server = McpServer::new("test", "sleep").with_args(vec!["999".to_string()]);
        let client = McpClient::new(server).with_timeout(Duration::from_millis(100));

        let result = client.initialize().await;
        assert!(result.is_err());
    }

    /// With no registered servers, `initialize_all` succeeds.
    #[tokio::test]
    async fn test_bridge_initialize_all_empty() {
        let bridge = McpBridge::new();
        bridge
            .initialize_all()
            .await
            .expect("empty bridge should initialize");
    }

    /// Every failing server is attempted and reported, and none is registered
    /// as connected.
    #[tokio::test]
    async fn test_bridge_initialize_all_fails_gracefully() {
        let bridge = McpBridge::new();
        bridge.register_server(McpServer::new("ghost", "nonexistent-cmd-xyz"));
        bridge.register_server(McpServer::new("ghost2", "nonexistent-cmd-abc"));

        let err = bridge
            .initialize_all()
            .await
            .expect_err("unspawnable servers must surface an error");

        // Failures are reported as `name: cause`; the colon keeps "ghost" from
        // matching inside "ghost2".
        let message = err.to_string();
        assert!(message.contains("ghost:"), "unexpected message: {message}");
        assert!(message.contains("ghost2:"), "unexpected message: {message}");

        assert!(bridge.client("ghost").await.is_none());
        assert!(bridge.client("ghost2").await.is_none());
    }

    /// With no connected clients, `shutdown_all` succeeds.
    #[tokio::test]
    async fn test_bridge_shutdown_all_empty() {
        let bridge = McpBridge::new();
        bridge
            .shutdown_all()
            .await
            .expect("empty bridge shutdown should succeed");
    }

    /// Calling a tool on an unknown server reports "not connected".
    #[tokio::test]
    async fn test_bridge_call_tool_no_server() {
        let bridge = McpBridge::new();
        let result = bridge
            .call_tool("ghost", "tool", serde_json::json!({}))
            .await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not connected"));
    }

    /// Initializing an unregistered server reports "not found".
    #[tokio::test]
    async fn test_bridge_initialize_server_not_found() {
        let bridge = McpBridge::new();
        let err = bridge
            .initialize_server("missing")
            .await
            .expect_err("unregistered server must not initialize");
        assert!(err.to_string().contains("not found"));
    }

    /// The client's `Debug` output includes the server name.
    #[test]
    fn test_mcp_client_debug() {
        let server = McpServer::new("debug-test", "echo");
        let client = McpClient::new(server);
        let debug = format!("{:?}", client);
        assert!(debug.contains("debug-test"));
    }

    /// Smoke test: initializing against `echo` and querying the state must not
    /// panic; both outcomes are deliberately ignored.
    ///
    /// NOTE(review): despite the name, `initialize` is only called once here —
    /// confirm whether a second call was intended.
    #[tokio::test]
    async fn test_mcp_client_double_init_ignored() {
        let server = McpServer::new("echo", "echo");
        let client = McpClient::new(server);

        let _ = client.initialize().await;
        let _ = client.is_initialized().await;
    }
}