1use crate::mcp::client::McpClient;
6use crate::mcp::protocol::{
7 CallToolResult, McpServerConfig, McpTool, McpTransportConfig, ToolContent,
8};
9use crate::mcp::transport::http_sse::HttpSseTransport;
10use crate::mcp::transport::stdio::StdioTransport;
11use crate::mcp::transport::streamable_http::StreamableHttpTransport;
12use crate::mcp::transport::McpTransport;
13use anyhow::{anyhow, Result};
14use std::collections::HashMap;
15use std::sync::Arc;
16use tokio::sync::RwLock;
17
/// Point-in-time status of one registered MCP server, as assembled by
/// `McpManager::get_status`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct McpServerStatus {
    /// Server name (the key the config was registered under).
    pub name: String,
    /// `true` only when a client exists for this server and that client reports itself
    /// connected.
    pub connected: bool,
    /// Copied from the server's `McpServerConfig::enabled`.
    pub enabled: bool,
    /// Number of tools in the client's cached tool list; `0` when no client exists.
    pub tool_count: usize,
    /// Error message associated with the server, if any.
    /// NOTE(review): `McpManager::get_status` currently always reports `None` here.
    pub error: Option<String>,
}
27
/// Keeps track of registered MCP server configurations and the clients of the servers
/// that are currently connected.
///
/// Both maps sit behind `tokio::sync::RwLock`, so the manager can be shared (e.g. in an
/// `Arc`) and used from async code.
pub struct McpManager {
    /// Live clients keyed by server name: inserted by `connect`, removed by `disconnect`.
    clients: RwLock<HashMap<String, Arc<McpClient>>>,
    /// Registered server configurations keyed by server name: inserted by
    /// `register_server`.
    configs: RwLock<HashMap<String, McpServerConfig>>,
}
35
36impl McpManager {
37 pub fn new() -> Self {
39 Self {
40 clients: RwLock::new(HashMap::new()),
41 configs: RwLock::new(HashMap::new()),
42 }
43 }
44
45 pub async fn register_server(&self, config: McpServerConfig) {
47 let name = config.name.clone();
48 let mut configs = self.configs.write().await;
49 configs.insert(name.clone(), config);
50 tracing::info!("Registered MCP server: {}", name);
51 }
52
53 pub async fn connect(&self, name: &str) -> Result<()> {
55 let config = {
57 let configs = self.configs.read().await;
58 configs
59 .get(name)
60 .cloned()
61 .ok_or_else(|| anyhow!("MCP server not found: {}", name))?
62 };
63
64 if !config.enabled {
65 return Err(anyhow!("MCP server is disabled: {}", name));
66 }
67
68 let transport: Arc<dyn McpTransport> = match &config.transport {
70 McpTransportConfig::Stdio { command, args } => Arc::new(
71 StdioTransport::spawn_with_timeout(
72 command,
73 args,
74 &config.env,
75 config.tool_timeout_secs,
76 )
77 .await?,
78 ),
79 McpTransportConfig::Http { url, headers } => Arc::new(
80 HttpSseTransport::connect_with_timeout(
81 url,
82 headers.clone(),
83 config.tool_timeout_secs,
84 )
85 .await?,
86 ),
87 McpTransportConfig::StreamableHttp { url, headers } => Arc::new(
88 StreamableHttpTransport::connect_with_timeout(
89 url,
90 headers.clone(),
91 config.tool_timeout_secs,
92 )
93 .await?,
94 ),
95 };
96
97 let client = Arc::new(McpClient::new(name.to_string(), transport));
99
100 client.initialize().await?;
102
103 let tools = client.list_tools().await?;
105 tracing::info!("MCP server '{}' connected with {} tools", name, tools.len());
106
107 {
109 let mut clients = self.clients.write().await;
110 clients.insert(name.to_string(), client);
111 }
112
113 Ok(())
114 }
115
116 pub async fn disconnect(&self, name: &str) -> Result<()> {
118 let client = {
119 let mut clients = self.clients.write().await;
120 clients.remove(name)
121 };
122
123 if let Some(client) = client {
124 client.close().await?;
125 tracing::info!("MCP server '{}' disconnected", name);
126 }
127
128 Ok(())
129 }
130
131 pub async fn all_configs(&self) -> Vec<McpServerConfig> {
133 self.configs.read().await.values().cloned().collect()
134 }
135
136 pub async fn get_all_tools(&self) -> Vec<(String, McpTool)> {
140 let clients = self.clients.read().await;
141 let mut all_tools = Vec::new();
142
143 for (server_name, client) in clients.iter() {
144 let tools = client.get_cached_tools().await;
145 for tool in tools {
146 let full_name = format!("mcp__{}__{}", server_name, tool.name);
147 all_tools.push((full_name, tool));
148 }
149 }
150
151 all_tools
152 }
153
154 pub async fn call_tool(
158 &self,
159 full_name: &str,
160 arguments: Option<serde_json::Value>,
161 ) -> Result<CallToolResult> {
162 let (server_name, tool_name) = Self::parse_tool_name(full_name)?;
164
165 let client = {
167 let clients = self.clients.read().await;
168 clients
169 .get(&server_name)
170 .cloned()
171 .ok_or_else(|| anyhow!("MCP server not connected: {}", server_name))?
172 };
173
174 client.call_tool(&tool_name, arguments).await
176 }
177
178 fn parse_tool_name(full_name: &str) -> Result<(String, String)> {
180 if !full_name.starts_with("mcp__") {
182 return Err(anyhow!("Invalid MCP tool name: {}", full_name));
183 }
184
185 let rest = &full_name[5..]; let parts: Vec<&str> = rest.splitn(2, "__").collect();
187
188 if parts.len() != 2 {
189 return Err(anyhow!("Invalid MCP tool name format: {}", full_name));
190 }
191
192 Ok((parts[0].to_string(), parts[1].to_string()))
193 }
194
195 pub async fn get_status(&self) -> HashMap<String, McpServerStatus> {
197 let configs = self.configs.read().await;
198 let clients = self.clients.read().await;
199 let mut status = HashMap::new();
200
201 for (name, config) in configs.iter() {
202 let client = clients.get(name);
203 let (connected, tool_count) = if let Some(c) = client {
204 (c.is_connected(), c.get_cached_tools().await.len())
205 } else {
206 (false, 0)
207 };
208
209 status.insert(
210 name.clone(),
211 McpServerStatus {
212 name: name.clone(),
213 connected,
214 enabled: config.enabled,
215 tool_count,
216 error: None,
217 },
218 );
219 }
220
221 status
222 }
223
224 pub async fn get_client(&self, name: &str) -> Option<Arc<McpClient>> {
226 let clients = self.clients.read().await;
227 clients.get(name).cloned()
228 }
229
230 pub async fn is_connected(&self, name: &str) -> bool {
232 let clients = self.clients.read().await;
233 clients.get(name).map(|c| c.is_connected()).unwrap_or(false)
234 }
235
236 pub async fn list_connected(&self) -> Vec<String> {
238 let clients = self.clients.read().await;
239 clients.keys().cloned().collect()
240 }
241
242 pub async fn get_server_tools(&self, name: &str) -> Vec<McpTool> {
244 let clients = self.clients.read().await;
245 match clients.get(name) {
246 Some(client) => client.get_cached_tools().await,
247 None => Vec::new(),
248 }
249 }
250}
251
252impl Default for McpManager {
253 fn default() -> Self {
254 Self::new()
255 }
256}
257
258pub fn tool_result_to_string(result: &CallToolResult) -> String {
260 let mut output = String::new();
261
262 for content in &result.content {
263 match content {
264 ToolContent::Text { text } => {
265 output.push_str(text);
266 output.push('\n');
267 }
268 ToolContent::Image { data: _, mime_type } => {
269 output.push_str(&format!("[Image: {}]\n", mime_type));
270 }
271 ToolContent::Resource { resource } => {
272 if let Some(text) = &resource.text {
273 output.push_str(text);
274 output.push('\n');
275 } else {
276 output.push_str(&format!("[Resource: {}]\n", resource.uri));
277 }
278 }
279 }
280 }
281
282 output.trim_end().to_string()
283}
284
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mcp::protocol::ResourceContent;

    /// Builds a stdio config that runs `echo`. No test here actually spawns it.
    fn stdio_config(name: &str, enabled: bool) -> McpServerConfig {
        McpServerConfig {
            name: name.to_string(),
            transport: McpTransportConfig::Stdio {
                command: "echo".to_string(),
                args: vec![],
            },
            enabled,
            env: HashMap::new(),
            oauth: None,
            tool_timeout_secs: 60,
        }
    }

    /// Wraps `content` in a non-error `CallToolResult`.
    fn ok_result(content: Vec<ToolContent>) -> CallToolResult {
        CallToolResult {
            content,
            is_error: false,
        }
    }

    /// Shorthand for a text content item.
    fn text(s: &str) -> ToolContent {
        ToolContent::Text {
            text: s.to_string(),
        }
    }

    // ---- parse_tool_name -------------------------------------------------------------

    #[test]
    fn test_parse_tool_name() {
        let (server, tool) = McpManager::parse_tool_name("mcp__github__create_issue").unwrap();
        assert_eq!(server, "github");
        assert_eq!(tool, "create_issue");
    }

    #[test]
    fn test_parse_tool_name_simple() {
        let (server, tool) = McpManager::parse_tool_name("mcp__server__tool").unwrap();
        assert_eq!(server, "server");
        assert_eq!(tool, "tool");
    }

    #[test]
    fn test_parse_tool_name_with_underscores() {
        // Single underscores belong to the names; only `__` separates server from tool.
        let (server, tool) = McpManager::parse_tool_name("mcp__my_server__my_tool_name").unwrap();
        assert_eq!(server, "my_server");
        assert_eq!(tool, "my_tool_name");
    }

    #[test]
    fn test_parse_tool_name_invalid() {
        assert!(McpManager::parse_tool_name("invalid_name").is_err());
        assert!(McpManager::parse_tool_name("mcp__nodelimiter").is_err());
    }

    #[test]
    fn test_parse_tool_name_missing_prefix() {
        assert!(McpManager::parse_tool_name("server__tool").is_err());
    }

    #[test]
    fn test_parse_tool_name_only_prefix() {
        assert!(McpManager::parse_tool_name("mcp__").is_err());
    }

    #[test]
    fn test_parse_tool_name_empty_string() {
        assert!(McpManager::parse_tool_name("").is_err());
    }

    // ---- tool_result_to_string -------------------------------------------------------

    #[test]
    fn test_tool_result_to_string() {
        let result = ok_result(vec![text("Line 1"), text("Line 2")]);
        assert_eq!(tool_result_to_string(&result), "Line 1\nLine 2");
    }

    #[test]
    fn test_tool_result_to_string_single_text() {
        let result = ok_result(vec![text("Hello World")]);
        assert_eq!(tool_result_to_string(&result), "Hello World");
    }

    #[test]
    fn test_tool_result_to_string_multiple_text() {
        let result = ok_result(vec![text("First line"), text("Second line")]);
        assert_eq!(tool_result_to_string(&result), "First line\nSecond line");
    }

    #[test]
    fn test_tool_result_to_string_empty() {
        let result = ok_result(vec![]);
        assert_eq!(tool_result_to_string(&result), "");
    }

    #[test]
    fn test_tool_result_to_string_trims_trailing_whitespace() {
        let result = ok_result(vec![text("value\n\n  ")]);
        assert_eq!(tool_result_to_string(&result), "value");
    }

    #[test]
    fn test_tool_result_to_string_image() {
        let result = ok_result(vec![ToolContent::Image {
            data: "base64data".to_string(),
            mime_type: "image/png".to_string(),
        }]);
        assert_eq!(tool_result_to_string(&result), "[Image: image/png]");
    }

    #[test]
    fn test_tool_result_to_string_resource() {
        let result = ok_result(vec![ToolContent::Resource {
            resource: ResourceContent {
                uri: "file:///test.txt".to_string(),
                mime_type: Some("text/plain".to_string()),
                text: Some("Resource content".to_string()),
                blob: None,
            },
        }]);
        assert_eq!(tool_result_to_string(&result), "Resource content");
    }

    #[test]
    fn test_tool_result_to_string_resource_without_text() {
        let result = ok_result(vec![ToolContent::Resource {
            resource: ResourceContent {
                uri: "file:///bin.dat".to_string(),
                mime_type: None,
                text: None,
                blob: None,
            },
        }]);
        assert_eq!(tool_result_to_string(&result), "[Resource: file:///bin.dat]");
    }

    #[test]
    fn test_tool_result_to_string_mixed_content() {
        let result = ok_result(vec![
            text("Text content"),
            ToolContent::Image {
                data: "base64".to_string(),
                mime_type: "image/jpeg".to_string(),
            },
            ToolContent::Resource {
                resource: ResourceContent {
                    uri: "file:///doc.md".to_string(),
                    mime_type: Some("text/markdown".to_string()),
                    text: Some("Doc content".to_string()),
                    blob: None,
                },
            },
        ]);
        assert_eq!(
            tool_result_to_string(&result),
            "Text content\n[Image: image/jpeg]\nDoc content"
        );
    }

    // ---- McpManager ------------------------------------------------------------------

    #[tokio::test]
    async fn test_mcp_manager_new() {
        let manager = McpManager::new();
        assert!(manager.get_status().await.is_empty());
    }

    #[tokio::test]
    async fn test_mcp_manager_default() {
        let manager = McpManager::default();
        assert!(manager.get_status().await.is_empty());
    }

    #[tokio::test]
    async fn test_mcp_manager_register_server() {
        let manager = McpManager::new();
        manager.register_server(stdio_config("test", true)).await;

        let status = manager.get_status().await;
        assert!(status.contains_key("test"));
        assert!(!status["test"].connected);
        assert_eq!(status["test"].tool_count, 0);
    }

    #[tokio::test]
    async fn test_get_status_registered_server() {
        let manager = McpManager::new();
        manager.register_server(stdio_config("test_server", true)).await;

        let status = manager.get_status().await;
        assert!(status.contains_key("test_server"));
        assert!(!status["test_server"].connected);
        assert!(status["test_server"].enabled);
    }

    #[tokio::test]
    async fn test_get_status_disabled_server() {
        let manager = McpManager::new();
        manager.register_server(stdio_config("disabled_server", false)).await;

        let status = manager.get_status().await;
        assert!(status.contains_key("disabled_server"));
        assert!(!status["disabled_server"].enabled);
    }

    #[tokio::test]
    async fn test_list_connected_empty() {
        let manager = McpManager::new();
        assert!(manager.list_connected().await.is_empty());
    }

    #[tokio::test]
    async fn test_is_connected_false_for_unknown_server() {
        let manager = McpManager::new();
        assert!(!manager.is_connected("unknown_server").await);
    }

    #[tokio::test]
    async fn test_get_client_none_for_unknown_server() {
        let manager = McpManager::new();
        assert!(manager.get_client("unknown_server").await.is_none());
    }

    #[tokio::test]
    async fn test_get_all_tools_empty_manager() {
        let manager = McpManager::new();
        assert!(manager.get_all_tools().await.is_empty());
    }

    #[tokio::test]
    async fn test_connect_unknown_server_fails() {
        let manager = McpManager::new();
        let err = manager.connect("missing").await.unwrap_err();
        assert!(err.to_string().contains("not found"));
    }

    #[tokio::test]
    async fn test_connect_disabled_server_fails_without_spawning() {
        let manager = McpManager::new();
        manager.register_server(stdio_config("off", false)).await;

        let err = manager.connect("off").await.unwrap_err();
        assert!(err.to_string().contains("disabled"));
        assert!(manager.get_client("off").await.is_none());
    }

    #[tokio::test]
    async fn test_call_tool_invalid_name_fails() {
        let manager = McpManager::new();
        assert!(manager.call_tool("not_an_mcp_tool", None).await.is_err());
    }

    #[tokio::test]
    async fn test_call_tool_unconnected_server_fails() {
        let manager = McpManager::new();
        let err = manager
            .call_tool("mcp__github__create_issue", None)
            .await
            .err()
            .expect("calling a tool on an unconnected server must fail");
        assert!(err.to_string().contains("not connected"));
    }
}