Skip to main content

a3s_code_core/mcp/
manager.rs

1//! MCP Manager
2//!
3//! Manages MCP server lifecycle and provides unified access to MCP tools.
4
5use crate::mcp::client::McpClient;
6use crate::mcp::protocol::{
7    CallToolResult, McpServerConfig, McpTool, McpTransportConfig, ToolContent,
8};
9use crate::mcp::transport::stdio::StdioTransport;
10use crate::mcp::transport::McpTransport;
11use anyhow::{anyhow, Result};
12use std::collections::HashMap;
13use std::sync::Arc;
14use tokio::sync::RwLock;
15
16/// MCP server status
17#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
18pub struct McpServerStatus {
19    pub name: String,
20    pub connected: bool,
21    pub enabled: bool,
22    pub tool_count: usize,
23    pub error: Option<String>,
24}
25
26/// MCP Manager for managing multiple MCP servers
27pub struct McpManager {
28    /// Connected clients
29    clients: RwLock<HashMap<String, Arc<McpClient>>>,
30    /// Server configurations
31    configs: RwLock<HashMap<String, McpServerConfig>>,
32}
33
34impl McpManager {
35    /// Create a new MCP manager
36    pub fn new() -> Self {
37        Self {
38            clients: RwLock::new(HashMap::new()),
39            configs: RwLock::new(HashMap::new()),
40        }
41    }
42
43    /// Register a server configuration
44    pub async fn register_server(&self, config: McpServerConfig) {
45        let name = config.name.clone();
46        let mut configs = self.configs.write().await;
47        configs.insert(name.clone(), config);
48        tracing::info!("Registered MCP server: {}", name);
49    }
50
51    /// Connect to a registered server
52    pub async fn connect(&self, name: &str) -> Result<()> {
53        // Get config
54        let config = {
55            let configs = self.configs.read().await;
56            configs
57                .get(name)
58                .cloned()
59                .ok_or_else(|| anyhow!("MCP server not found: {}", name))?
60        };
61
62        if !config.enabled {
63            return Err(anyhow!("MCP server is disabled: {}", name));
64        }
65
66        // Create transport based on config
67        let transport: Arc<dyn McpTransport> = match &config.transport {
68            McpTransportConfig::Stdio { command, args } => Arc::new(
69                StdioTransport::spawn_with_timeout(
70                    command,
71                    args,
72                    &config.env,
73                    config.tool_timeout_secs,
74                )
75                .await?,
76            ),
77            McpTransportConfig::Http { url: _, headers: _ } => {
78                // HTTP transport not implemented yet
79                return Err(anyhow!("HTTP transport not yet implemented"));
80            }
81        };
82
83        // Create client
84        let client = Arc::new(McpClient::new(name.to_string(), transport));
85
86        // Initialize
87        client.initialize().await?;
88
89        // Fetch tools
90        let tools = client.list_tools().await?;
91        tracing::info!("MCP server '{}' connected with {} tools", name, tools.len());
92
93        // Store client
94        {
95            let mut clients = self.clients.write().await;
96            clients.insert(name.to_string(), client);
97        }
98
99        Ok(())
100    }
101
102    /// Disconnect from a server
103    pub async fn disconnect(&self, name: &str) -> Result<()> {
104        let client = {
105            let mut clients = self.clients.write().await;
106            clients.remove(name)
107        };
108
109        if let Some(client) = client {
110            client.close().await?;
111            tracing::info!("MCP server '{}' disconnected", name);
112        }
113
114        Ok(())
115    }
116
117    /// Get all MCP tools with server prefix
118    ///
119    /// Returns tools with names like `mcp__github__create_issue`
120    pub async fn get_all_tools(&self) -> Vec<(String, McpTool)> {
121        let clients = self.clients.read().await;
122        let mut all_tools = Vec::new();
123
124        for (server_name, client) in clients.iter() {
125            let tools = client.get_cached_tools().await;
126            for tool in tools {
127                let full_name = format!("mcp__{}__{}", server_name, tool.name);
128                all_tools.push((full_name, tool));
129            }
130        }
131
132        all_tools
133    }
134
135    /// Call an MCP tool by full name
136    ///
137    /// Full name format: `mcp__<server>__<tool>`
138    pub async fn call_tool(
139        &self,
140        full_name: &str,
141        arguments: Option<serde_json::Value>,
142    ) -> Result<CallToolResult> {
143        // Parse full name
144        let (server_name, tool_name) = Self::parse_tool_name(full_name)?;
145
146        // Get client
147        let client = {
148            let clients = self.clients.read().await;
149            clients
150                .get(&server_name)
151                .cloned()
152                .ok_or_else(|| anyhow!("MCP server not connected: {}", server_name))?
153        };
154
155        // Call tool
156        client.call_tool(&tool_name, arguments).await
157    }
158
159    /// Parse MCP tool full name into (server, tool)
160    fn parse_tool_name(full_name: &str) -> Result<(String, String)> {
161        // Format: mcp__<server>__<tool>
162        if !full_name.starts_with("mcp__") {
163            return Err(anyhow!("Invalid MCP tool name: {}", full_name));
164        }
165
166        let rest = &full_name[5..]; // Skip "mcp__"
167        let parts: Vec<&str> = rest.splitn(2, "__").collect();
168
169        if parts.len() != 2 {
170            return Err(anyhow!("Invalid MCP tool name format: {}", full_name));
171        }
172
173        Ok((parts[0].to_string(), parts[1].to_string()))
174    }
175
176    /// Get status of all servers
177    pub async fn get_status(&self) -> HashMap<String, McpServerStatus> {
178        let configs = self.configs.read().await;
179        let clients = self.clients.read().await;
180        let mut status = HashMap::new();
181
182        for (name, config) in configs.iter() {
183            let client = clients.get(name);
184            let (connected, tool_count) = if let Some(c) = client {
185                (c.is_connected(), c.get_cached_tools().await.len())
186            } else {
187                (false, 0)
188            };
189
190            status.insert(
191                name.clone(),
192                McpServerStatus {
193                    name: name.clone(),
194                    connected,
195                    enabled: config.enabled,
196                    tool_count,
197                    error: None,
198                },
199            );
200        }
201
202        status
203    }
204
205    /// Get a specific client
206    pub async fn get_client(&self, name: &str) -> Option<Arc<McpClient>> {
207        let clients = self.clients.read().await;
208        clients.get(name).cloned()
209    }
210
211    /// Check if a server is connected
212    pub async fn is_connected(&self, name: &str) -> bool {
213        let clients = self.clients.read().await;
214        clients.get(name).map(|c| c.is_connected()).unwrap_or(false)
215    }
216
217    /// List connected server names
218    pub async fn list_connected(&self) -> Vec<String> {
219        let clients = self.clients.read().await;
220        clients.keys().cloned().collect()
221    }
222
223    /// Get cached tools for a specific connected server.
224    pub async fn get_server_tools(&self, name: &str) -> Vec<McpTool> {
225        let clients = self.clients.read().await;
226        match clients.get(name) {
227            Some(client) => client.get_cached_tools().await,
228            None => Vec::new(),
229        }
230    }
231}
232
233impl Default for McpManager {
234    fn default() -> Self {
235        Self::new()
236    }
237}
238
239/// Convert MCP tool result to string output
240pub fn tool_result_to_string(result: &CallToolResult) -> String {
241    let mut output = String::new();
242
243    for content in &result.content {
244        match content {
245            ToolContent::Text { text } => {
246                output.push_str(text);
247                output.push('\n');
248            }
249            ToolContent::Image { data: _, mime_type } => {
250                output.push_str(&format!("[Image: {}]\n", mime_type));
251            }
252            ToolContent::Resource { resource } => {
253                if let Some(text) = &resource.text {
254                    output.push_str(text);
255                    output.push('\n');
256                } else {
257                    output.push_str(&format!("[Resource: {}]\n", resource.uri));
258                }
259            }
260        }
261    }
262
263    output.trim_end().to_string()
264}
265
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mcp::protocol::ResourceContent;

    /// Build a stdio server configuration that would run `echo`.
    ///
    /// None of the tests below get as far as spawning the command.
    fn stdio_config(name: &str, enabled: bool) -> McpServerConfig {
        McpServerConfig {
            name: name.to_string(),
            transport: McpTransportConfig::Stdio {
                command: "echo".to_string(),
                args: vec![],
            },
            enabled,
            env: HashMap::new(),
            oauth: None,
            tool_timeout_secs: 60,
        }
    }

    /// Wrap content items in a non-error tool result.
    fn result_of(content: Vec<ToolContent>) -> CallToolResult {
        CallToolResult {
            content,
            is_error: false,
        }
    }

    /// Shorthand for a text content item.
    fn text(value: &str) -> ToolContent {
        ToolContent::Text {
            text: value.to_string(),
        }
    }

    /// Shorthand for a resource content item with an optional text body.
    fn resource(uri: &str, mime_type: &str, body: Option<&str>) -> ToolContent {
        ToolContent::Resource {
            resource: ResourceContent {
                uri: uri.to_string(),
                mime_type: Some(mime_type.to_string()),
                text: body.map(str::to_string),
                blob: None,
            },
        }
    }

    // ---- parse_tool_name -------------------------------------------------

    #[test]
    fn test_parse_tool_name() {
        let (server, tool) = McpManager::parse_tool_name("mcp__github__create_issue").unwrap();
        assert_eq!(server, "github");
        assert_eq!(tool, "create_issue");
    }

    #[test]
    fn test_parse_tool_name_with_underscores() {
        let (server, tool) = McpManager::parse_tool_name("mcp__my_server__my_tool_name").unwrap();
        assert_eq!(server, "my_server");
        assert_eq!(tool, "my_tool_name");
    }

    #[test]
    fn test_parse_tool_name_invalid() {
        assert!(McpManager::parse_tool_name("invalid_name").is_err());
        assert!(McpManager::parse_tool_name("mcp__nodelimiter").is_err());
    }

    #[test]
    fn test_parse_tool_name_simple() {
        let (server, tool) = McpManager::parse_tool_name("mcp__server__tool").unwrap();
        assert_eq!(server, "server");
        assert_eq!(tool, "tool");
    }

    #[test]
    fn test_parse_tool_name_multiple_underscores() {
        // Only the first "__" after the prefix separates server from tool;
        // later ones belong to the tool name.
        let (server, tool) =
            McpManager::parse_tool_name("mcp__srv__tool__with__double").unwrap();
        assert_eq!(server, "srv");
        assert_eq!(tool, "tool__with__double");
    }

    #[test]
    fn test_parse_tool_name_missing_prefix() {
        let result = McpManager::parse_tool_name("server__tool");
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_tool_name_only_prefix() {
        let result = McpManager::parse_tool_name("mcp__");
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_tool_name_empty_string() {
        let result = McpManager::parse_tool_name("");
        assert!(result.is_err());
    }

    // ---- tool_result_to_string -------------------------------------------

    #[test]
    fn test_tool_result_to_string() {
        // Trailing whitespace of the combined output is trimmed.
        let output = tool_result_to_string(&result_of(vec![text("Line 1"), text("Line 2\n")]));
        assert_eq!(output, "Line 1\nLine 2");
    }

    #[test]
    fn test_tool_result_to_string_single_text() {
        let output = tool_result_to_string(&result_of(vec![text("Hello World")]));
        assert_eq!(output, "Hello World");
    }

    #[test]
    fn test_tool_result_to_string_multiple_text() {
        let output =
            tool_result_to_string(&result_of(vec![text("First line"), text("Second line")]));
        assert_eq!(output, "First line\nSecond line");
    }

    #[test]
    fn test_tool_result_to_string_empty() {
        let output = tool_result_to_string(&result_of(vec![]));
        assert_eq!(output, "");
    }

    #[test]
    fn test_tool_result_to_string_image() {
        let result = result_of(vec![ToolContent::Image {
            data: "base64data".to_string(),
            mime_type: "image/png".to_string(),
        }]);
        assert_eq!(tool_result_to_string(&result), "[Image: image/png]");
    }

    #[test]
    fn test_tool_result_to_string_resource() {
        let result = result_of(vec![resource(
            "file:///test.txt",
            "text/plain",
            Some("Resource content"),
        )]);
        assert_eq!(tool_result_to_string(&result), "Resource content");
    }

    #[test]
    fn test_tool_result_to_string_resource_without_text() {
        // Without a text body the resource falls back to its URI.
        let result = result_of(vec![resource(
            "file:///blob.bin",
            "application/octet-stream",
            None,
        )]);
        assert_eq!(tool_result_to_string(&result), "[Resource: file:///blob.bin]");
    }

    #[test]
    fn test_tool_result_to_string_mixed_content() {
        let result = result_of(vec![
            text("Text content"),
            ToolContent::Image {
                data: "base64".to_string(),
                mime_type: "image/jpeg".to_string(),
            },
            resource("file:///doc.md", "text/markdown", Some("Doc content")),
        ]);
        assert_eq!(
            tool_result_to_string(&result),
            "Text content\n[Image: image/jpeg]\nDoc content"
        );
    }

    // ---- McpManager ------------------------------------------------------

    #[tokio::test]
    async fn test_mcp_manager_new() {
        let manager = McpManager::new();
        let status = manager.get_status().await;
        assert!(status.is_empty());
    }

    #[tokio::test]
    async fn test_mcp_manager_default() {
        let manager = McpManager::default();
        let status = manager.get_status().await;
        assert!(status.is_empty());
    }

    #[tokio::test]
    async fn test_mcp_manager_register_server() {
        let manager = McpManager::new();
        manager.register_server(stdio_config("test", true)).await;

        let status = manager.get_status().await;
        assert!(status.contains_key("test"));
        assert!(!status["test"].connected);
    }

    #[tokio::test]
    async fn test_get_status_registered_server() {
        let manager = McpManager::new();
        manager
            .register_server(stdio_config("test_server", true))
            .await;

        let status = manager.get_status().await;
        let entry = &status["test_server"];
        assert_eq!(entry.name, "test_server");
        assert!(!entry.connected);
        assert!(entry.enabled);
        assert_eq!(entry.tool_count, 0);
        assert!(entry.error.is_none());
    }

    #[tokio::test]
    async fn test_get_status_disabled_server() {
        let manager = McpManager::new();
        manager
            .register_server(stdio_config("disabled_server", false))
            .await;

        let status = manager.get_status().await;
        assert!(status.contains_key("disabled_server"));
        assert!(!status["disabled_server"].enabled);
    }

    #[tokio::test]
    async fn test_list_connected_empty() {
        let manager = McpManager::new();
        let connected = manager.list_connected().await;
        assert!(connected.is_empty());
    }

    #[tokio::test]
    async fn test_is_connected_false_for_unknown_server() {
        let manager = McpManager::new();
        let connected = manager.is_connected("unknown_server").await;
        assert!(!connected);
    }

    #[tokio::test]
    async fn test_get_client_none_for_unknown_server() {
        let manager = McpManager::new();
        let client = manager.get_client("unknown_server").await;
        assert!(client.is_none());
    }

    #[tokio::test]
    async fn test_get_all_tools_empty_manager() {
        let manager = McpManager::new();
        let tools = manager.get_all_tools().await;
        assert!(tools.is_empty());
    }

    #[tokio::test]
    async fn test_get_server_tools_unknown_server() {
        let manager = McpManager::new();
        assert!(manager.get_server_tools("unknown_server").await.is_empty());
    }

    #[tokio::test]
    async fn test_connect_unknown_server_fails() {
        let manager = McpManager::new();
        assert!(manager.connect("unknown_server").await.is_err());
    }

    #[tokio::test]
    async fn test_connect_disabled_server_fails() {
        // The enabled check runs before any transport is created.
        let manager = McpManager::new();
        manager
            .register_server(stdio_config("disabled_server", false))
            .await;

        assert!(manager.connect("disabled_server").await.is_err());
        assert!(manager.list_connected().await.is_empty());
    }

    #[tokio::test]
    async fn test_disconnect_unknown_server_is_ok() {
        let manager = McpManager::new();
        assert!(manager.disconnect("unknown_server").await.is_ok());
    }

    #[tokio::test]
    async fn test_call_tool_server_not_connected() {
        let manager = McpManager::new();
        assert!(manager.call_tool("mcp__ghost__tool", None).await.is_err());
    }
}