Skip to main content

a3s_code_core/mcp/
manager.rs

1//! MCP Manager
2//!
3//! Manages MCP server lifecycle and provides unified access to MCP tools.
4
5use crate::mcp::client::McpClient;
6use crate::mcp::protocol::{
7    CallToolResult, McpServerConfig, McpTool, McpTransportConfig, ToolContent,
8};
9use crate::mcp::transport::stdio::StdioTransport;
10use crate::mcp::transport::McpTransport;
11use anyhow::{anyhow, Result};
12use std::collections::HashMap;
13use std::sync::Arc;
14use tokio::sync::RwLock;
15
/// Point-in-time status snapshot for one registered MCP server.
#[derive(Clone, Debug)]
pub struct McpServerStatus {
    /// Name the server was registered under.
    pub name: String,
    /// True when a client is stored for this server and that client reports itself connected.
    pub connected: bool,
    /// Copy of the `enabled` flag from the server's configuration.
    pub enabled: bool,
    /// Number of tools in the client's cache; 0 when no client is stored.
    pub tool_count: usize,
    /// Optional error description for this server.
    pub error: Option<String>,
}
25
26/// MCP Manager for managing multiple MCP servers
27pub struct McpManager {
28    /// Connected clients
29    clients: RwLock<HashMap<String, Arc<McpClient>>>,
30    /// Server configurations
31    configs: RwLock<HashMap<String, McpServerConfig>>,
32}
33
34impl McpManager {
35    /// Create a new MCP manager
36    pub fn new() -> Self {
37        Self {
38            clients: RwLock::new(HashMap::new()),
39            configs: RwLock::new(HashMap::new()),
40        }
41    }
42
43    /// Register a server configuration
44    pub async fn register_server(&self, config: McpServerConfig) {
45        let name = config.name.clone();
46        let mut configs = self.configs.write().await;
47        configs.insert(name.clone(), config);
48        tracing::info!("Registered MCP server: {}", name);
49    }
50
51    /// Connect to a registered server
52    pub async fn connect(&self, name: &str) -> Result<()> {
53        // Get config
54        let config = {
55            let configs = self.configs.read().await;
56            configs
57                .get(name)
58                .cloned()
59                .ok_or_else(|| anyhow!("MCP server not found: {}", name))?
60        };
61
62        if !config.enabled {
63            return Err(anyhow!("MCP server is disabled: {}", name));
64        }
65
66        // Create transport based on config
67        let transport: Arc<dyn McpTransport> = match &config.transport {
68            McpTransportConfig::Stdio { command, args } => Arc::new(
69                StdioTransport::spawn_with_timeout(
70                    command,
71                    args,
72                    &config.env,
73                    config.tool_timeout_secs,
74                )
75                .await?,
76            ),
77            McpTransportConfig::Http { url: _, headers: _ } => {
78                // HTTP transport not implemented yet
79                return Err(anyhow!("HTTP transport not yet implemented"));
80            }
81        };
82
83        // Create client
84        let client = Arc::new(McpClient::new(name.to_string(), transport));
85
86        // Initialize
87        client.initialize().await?;
88
89        // Fetch tools
90        let tools = client.list_tools().await?;
91        tracing::info!("MCP server '{}' connected with {} tools", name, tools.len());
92
93        // Store client
94        {
95            let mut clients = self.clients.write().await;
96            clients.insert(name.to_string(), client);
97        }
98
99        Ok(())
100    }
101
102    /// Disconnect from a server
103    pub async fn disconnect(&self, name: &str) -> Result<()> {
104        let client = {
105            let mut clients = self.clients.write().await;
106            clients.remove(name)
107        };
108
109        if let Some(client) = client {
110            client.close().await?;
111            tracing::info!("MCP server '{}' disconnected", name);
112        }
113
114        Ok(())
115    }
116
117    /// Get all MCP tools with server prefix
118    ///
119    /// Returns tools with names like `mcp__github__create_issue`
120    pub async fn get_all_tools(&self) -> Vec<(String, McpTool)> {
121        let clients = self.clients.read().await;
122        let mut all_tools = Vec::new();
123
124        for (server_name, client) in clients.iter() {
125            let tools = client.get_cached_tools().await;
126            for tool in tools {
127                let full_name = format!("mcp__{}_{}", server_name, tool.name);
128                all_tools.push((full_name, tool));
129            }
130        }
131
132        all_tools
133    }
134
135    /// Call an MCP tool by full name
136    ///
137    /// Full name format: `mcp__<server>__<tool>`
138    pub async fn call_tool(
139        &self,
140        full_name: &str,
141        arguments: Option<serde_json::Value>,
142    ) -> Result<CallToolResult> {
143        // Parse full name
144        let (server_name, tool_name) = Self::parse_tool_name(full_name)?;
145
146        // Get client
147        let client = {
148            let clients = self.clients.read().await;
149            clients
150                .get(&server_name)
151                .cloned()
152                .ok_or_else(|| anyhow!("MCP server not connected: {}", server_name))?
153        };
154
155        // Call tool
156        client.call_tool(&tool_name, arguments).await
157    }
158
159    /// Parse MCP tool full name into (server, tool)
160    fn parse_tool_name(full_name: &str) -> Result<(String, String)> {
161        // Format: mcp__<server>__<tool>
162        if !full_name.starts_with("mcp__") {
163            return Err(anyhow!("Invalid MCP tool name: {}", full_name));
164        }
165
166        let rest = &full_name[5..]; // Skip "mcp__"
167        let parts: Vec<&str> = rest.splitn(2, "__").collect();
168
169        if parts.len() != 2 {
170            return Err(anyhow!("Invalid MCP tool name format: {}", full_name));
171        }
172
173        Ok((parts[0].to_string(), parts[1].to_string()))
174    }
175
176    /// Get status of all servers
177    pub async fn get_status(&self) -> HashMap<String, McpServerStatus> {
178        let configs = self.configs.read().await;
179        let clients = self.clients.read().await;
180        let mut status = HashMap::new();
181
182        for (name, config) in configs.iter() {
183            let client = clients.get(name);
184            let (connected, tool_count) = if let Some(c) = client {
185                (c.is_connected(), c.get_cached_tools().await.len())
186            } else {
187                (false, 0)
188            };
189
190            status.insert(
191                name.clone(),
192                McpServerStatus {
193                    name: name.clone(),
194                    connected,
195                    enabled: config.enabled,
196                    tool_count,
197                    error: None,
198                },
199            );
200        }
201
202        status
203    }
204
205    /// Get a specific client
206    pub async fn get_client(&self, name: &str) -> Option<Arc<McpClient>> {
207        let clients = self.clients.read().await;
208        clients.get(name).cloned()
209    }
210
211    /// Check if a server is connected
212    pub async fn is_connected(&self, name: &str) -> bool {
213        let clients = self.clients.read().await;
214        clients.get(name).map(|c| c.is_connected()).unwrap_or(false)
215    }
216
217    /// List connected server names
218    pub async fn list_connected(&self) -> Vec<String> {
219        let clients = self.clients.read().await;
220        clients.keys().cloned().collect()
221    }
222}
223
224impl Default for McpManager {
225    fn default() -> Self {
226        Self::new()
227    }
228}
229
230/// Convert MCP tool result to string output
231pub fn tool_result_to_string(result: &CallToolResult) -> String {
232    let mut output = String::new();
233
234    for content in &result.content {
235        match content {
236            ToolContent::Text { text } => {
237                output.push_str(text);
238                output.push('\n');
239            }
240            ToolContent::Image { data: _, mime_type } => {
241                output.push_str(&format!("[Image: {}]\n", mime_type));
242            }
243            ToolContent::Resource { resource } => {
244                if let Some(text) = &resource.text {
245                    output.push_str(text);
246                    output.push('\n');
247                } else {
248                    output.push_str(&format!("[Resource: {}]\n", resource.uri));
249                }
250            }
251        }
252    }
253
254    output.trim_end().to_string()
255}
256
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mcp::protocol::ResourceContent;

    /// Stdio server config that the tests register but never spawn.
    fn stdio_config(name: &str, enabled: bool) -> McpServerConfig {
        McpServerConfig {
            name: name.to_string(),
            transport: McpTransportConfig::Stdio {
                command: "echo".to_string(),
                args: vec![],
            },
            enabled,
            env: HashMap::new(),
            oauth: None,
            tool_timeout_secs: 60,
        }
    }

    /// Shorthand for a text content item.
    fn text(s: &str) -> ToolContent {
        ToolContent::Text {
            text: s.to_string(),
        }
    }

    /// Shorthand for a resource content item with optional inline text.
    fn resource(uri: &str, mime: &str, body: Option<&str>) -> ToolContent {
        ToolContent::Resource {
            resource: ResourceContent {
                uri: uri.to_string(),
                mime_type: Some(mime.to_string()),
                text: body.map(|b| b.to_string()),
                blob: None,
            },
        }
    }

    /// Wrap content items in a non-error result.
    fn ok_result(content: Vec<ToolContent>) -> CallToolResult {
        CallToolResult {
            content,
            is_error: false,
        }
    }

    // ---- parse_tool_name ----

    #[test]
    fn test_parse_tool_name() {
        let (server, tool) = McpManager::parse_tool_name("mcp__github__create_issue").unwrap();
        assert_eq!(server, "github");
        assert_eq!(tool, "create_issue");
    }

    #[test]
    fn test_parse_tool_name_simple() {
        let (server, tool) = McpManager::parse_tool_name("mcp__server__tool").unwrap();
        assert_eq!(server, "server");
        assert_eq!(tool, "tool");
    }

    #[test]
    fn test_parse_tool_name_with_underscores() {
        let (server, tool) = McpManager::parse_tool_name("mcp__my_server__my_tool_name").unwrap();
        assert_eq!(server, "my_server");
        assert_eq!(tool, "my_tool_name");
    }

    #[test]
    fn test_parse_tool_name_invalid() {
        assert!(McpManager::parse_tool_name("invalid_name").is_err());
        assert!(McpManager::parse_tool_name("mcp__nodelimiter").is_err());
    }

    #[test]
    fn test_parse_tool_name_missing_prefix() {
        assert!(McpManager::parse_tool_name("server__tool").is_err());
    }

    #[test]
    fn test_parse_tool_name_only_prefix() {
        assert!(McpManager::parse_tool_name("mcp__").is_err());
    }

    #[test]
    fn test_parse_tool_name_empty_string() {
        assert!(McpManager::parse_tool_name("").is_err());
    }

    // ---- tool_result_to_string ----

    #[test]
    fn test_tool_result_to_string() {
        let output = tool_result_to_string(&ok_result(vec![text("Line 1"), text("Line 2")]));
        assert_eq!(output, "Line 1\nLine 2");
    }

    #[test]
    fn test_tool_result_to_string_single_text() {
        let output = tool_result_to_string(&ok_result(vec![text("Hello World")]));
        assert_eq!(output, "Hello World");
    }

    #[test]
    fn test_tool_result_to_string_empty() {
        let output = tool_result_to_string(&ok_result(vec![]));
        assert_eq!(output, "");
    }

    #[test]
    fn test_tool_result_to_string_image() {
        let output = tool_result_to_string(&ok_result(vec![ToolContent::Image {
            data: "base64data".to_string(),
            mime_type: "image/png".to_string(),
        }]));
        assert_eq!(output, "[Image: image/png]");
    }

    #[test]
    fn test_tool_result_to_string_resource() {
        let output = tool_result_to_string(&ok_result(vec![resource(
            "file:///test.txt",
            "text/plain",
            Some("Resource content"),
        )]));
        assert_eq!(output, "Resource content");
    }

    #[test]
    fn test_tool_result_to_string_resource_without_text() {
        let output = tool_result_to_string(&ok_result(vec![resource(
            "file:///data.bin",
            "application/octet-stream",
            None,
        )]));
        assert_eq!(output, "[Resource: file:///data.bin]");
    }

    #[test]
    fn test_tool_result_to_string_mixed_content() {
        let output = tool_result_to_string(&ok_result(vec![
            text("Text content"),
            ToolContent::Image {
                data: "base64".to_string(),
                mime_type: "image/jpeg".to_string(),
            },
            resource("file:///doc.md", "text/markdown", Some("Doc content")),
        ]));
        assert_eq!(output, "Text content\n[Image: image/jpeg]\nDoc content");
    }

    // ---- McpManager ----

    #[tokio::test]
    async fn test_mcp_manager_new() {
        let manager = McpManager::new();
        assert!(manager.get_status().await.is_empty());
    }

    #[tokio::test]
    async fn test_mcp_manager_default() {
        let manager = McpManager::default();
        assert!(manager.get_status().await.is_empty());
    }

    #[tokio::test]
    async fn test_mcp_manager_register_server() {
        let manager = McpManager::new();
        manager.register_server(stdio_config("test", true)).await;

        let status = manager.get_status().await;
        assert!(status.contains_key("test"));
        assert!(!status["test"].connected);
    }

    #[tokio::test]
    async fn test_get_status_registered_server() {
        let manager = McpManager::new();
        manager.register_server(stdio_config("test_server", true)).await;

        let status = manager.get_status().await;
        let entry = &status["test_server"];
        assert_eq!(entry.name, "test_server");
        assert!(!entry.connected);
        assert!(entry.enabled);
        assert_eq!(entry.tool_count, 0);
        assert!(entry.error.is_none());
    }

    #[tokio::test]
    async fn test_get_status_disabled_server() {
        let manager = McpManager::new();
        manager.register_server(stdio_config("disabled_server", false)).await;

        let status = manager.get_status().await;
        assert!(status.contains_key("disabled_server"));
        assert!(!status["disabled_server"].enabled);
    }

    #[tokio::test]
    async fn test_list_connected_empty() {
        let manager = McpManager::new();
        assert!(manager.list_connected().await.is_empty());
    }

    #[tokio::test]
    async fn test_is_connected_false_for_unknown_server() {
        let manager = McpManager::new();
        assert!(!manager.is_connected("unknown_server").await);
    }

    #[tokio::test]
    async fn test_get_client_none_for_unknown_server() {
        let manager = McpManager::new();
        assert!(manager.get_client("unknown_server").await.is_none());
    }

    #[tokio::test]
    async fn test_get_all_tools_empty_manager() {
        let manager = McpManager::new();
        assert!(manager.get_all_tools().await.is_empty());
    }

    #[tokio::test]
    async fn test_connect_unknown_server_fails() {
        let manager = McpManager::new();
        let err = manager.connect("ghost").await.unwrap_err();
        assert!(err.to_string().contains("not found"));
    }

    #[tokio::test]
    async fn test_connect_disabled_server_fails_without_spawning() {
        let manager = McpManager::new();
        manager.register_server(stdio_config("off", false)).await;

        let err = manager.connect("off").await.unwrap_err();
        assert!(err.to_string().contains("disabled"));
        assert!(!manager.is_connected("off").await);
    }

    #[tokio::test]
    async fn test_call_tool_rejects_malformed_name() {
        let manager = McpManager::new();
        assert!(manager.call_tool("not_an_mcp_tool", None).await.is_err());
    }

    #[tokio::test]
    async fn test_call_tool_unconnected_server_fails() {
        let manager = McpManager::new();
        let err = manager.call_tool("mcp__ghost__tool", None).await.unwrap_err();
        assert!(err.to_string().contains("not connected"));
    }

    #[tokio::test]
    async fn test_disconnect_unknown_server_is_noop() {
        let manager = McpManager::new();
        assert!(manager.disconnect("ghost").await.is_ok());
    }
}