// Source: a3s_code_core/mcp/manager.rs

//! MCP Manager
//!
//! Manages the lifecycle of MCP servers (registration, connection, disconnection)
//! and exposes all of their tools through a single `mcp__<server>__<tool>` namespace.

5use crate::mcp::client::McpClient;
6use crate::mcp::protocol::{
7    CallToolResult, McpServerConfig, McpTool, McpTransportConfig, ToolContent,
8};
9use crate::mcp::transport::http_sse::HttpSseTransport;
10use crate::mcp::transport::stdio::StdioTransport;
11use crate::mcp::transport::streamable_http::StreamableHttpTransport;
12use crate::mcp::transport::McpTransport;
13use anyhow::{anyhow, Result};
14use std::collections::HashMap;
15use std::sync::Arc;
16use tokio::sync::RwLock;
17
18/// MCP server status
19#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
20pub struct McpServerStatus {
21    pub name: String,
22    pub connected: bool,
23    pub enabled: bool,
24    pub tool_count: usize,
25    pub error: Option<String>,
26}
27
28/// MCP Manager for managing multiple MCP servers
29pub struct McpManager {
30    /// Connected clients
31    clients: RwLock<HashMap<String, Arc<McpClient>>>,
32    /// Server configurations
33    configs: RwLock<HashMap<String, McpServerConfig>>,
34}
35
36impl McpManager {
37    /// Create a new MCP manager
38    pub fn new() -> Self {
39        Self {
40            clients: RwLock::new(HashMap::new()),
41            configs: RwLock::new(HashMap::new()),
42        }
43    }
44
45    /// Register a server configuration
46    pub async fn register_server(&self, config: McpServerConfig) {
47        let name = config.name.clone();
48        let mut configs = self.configs.write().await;
49        configs.insert(name.clone(), config);
50        tracing::info!("Registered MCP server: {}", name);
51    }
52
53    /// Connect to a registered server
54    pub async fn connect(&self, name: &str) -> Result<()> {
55        // Get config
56        let config = {
57            let configs = self.configs.read().await;
58            configs
59                .get(name)
60                .cloned()
61                .ok_or_else(|| anyhow!("MCP server not found: {}", name))?
62        };
63
64        if !config.enabled {
65            return Err(anyhow!("MCP server is disabled: {}", name));
66        }
67
68        // Create transport based on config
69        let transport: Arc<dyn McpTransport> = match &config.transport {
70            McpTransportConfig::Stdio { command, args } => Arc::new(
71                StdioTransport::spawn_with_timeout(
72                    command,
73                    args,
74                    &config.env,
75                    config.tool_timeout_secs,
76                )
77                .await?,
78            ),
79            McpTransportConfig::Http { url, headers } => Arc::new(
80                HttpSseTransport::connect_with_timeout(
81                    url,
82                    headers.clone(),
83                    config.tool_timeout_secs,
84                )
85                .await?,
86            ),
87            McpTransportConfig::StreamableHttp { url, headers } => Arc::new(
88                StreamableHttpTransport::connect_with_timeout(
89                    url,
90                    headers.clone(),
91                    config.tool_timeout_secs,
92                )
93                .await?,
94            ),
95        };
96
97        // Create client
98        let client = Arc::new(McpClient::new(name.to_string(), transport));
99
100        // Initialize
101        client.initialize().await?;
102
103        // Fetch tools
104        let tools = client.list_tools().await?;
105        tracing::info!("MCP server '{}' connected with {} tools", name, tools.len());
106
107        // Store client
108        {
109            let mut clients = self.clients.write().await;
110            clients.insert(name.to_string(), client);
111        }
112
113        Ok(())
114    }
115
116    /// Disconnect from a server
117    pub async fn disconnect(&self, name: &str) -> Result<()> {
118        let client = {
119            let mut clients = self.clients.write().await;
120            clients.remove(name)
121        };
122
123        if let Some(client) = client {
124            client.close().await?;
125            tracing::info!("MCP server '{}' disconnected", name);
126        }
127
128        Ok(())
129    }
130
131    /// Get all registered server configurations
132    pub async fn all_configs(&self) -> Vec<McpServerConfig> {
133        self.configs.read().await.values().cloned().collect()
134    }
135
136    /// Get all MCP tools with server prefix
137    ///
138    /// Returns tools with names like `mcp__github__create_issue`
139    pub async fn get_all_tools(&self) -> Vec<(String, McpTool)> {
140        let clients = self.clients.read().await;
141        let mut all_tools = Vec::new();
142
143        for (server_name, client) in clients.iter() {
144            let tools = client.get_cached_tools().await;
145            for tool in tools {
146                let full_name = format!("mcp__{}__{}", server_name, tool.name);
147                all_tools.push((full_name, tool));
148            }
149        }
150
151        all_tools
152    }
153
154    /// Call an MCP tool by full name
155    ///
156    /// Full name format: `mcp__<server>__<tool>`
157    pub async fn call_tool(
158        &self,
159        full_name: &str,
160        arguments: Option<serde_json::Value>,
161    ) -> Result<CallToolResult> {
162        // Parse full name
163        let (server_name, tool_name) = Self::parse_tool_name(full_name)?;
164
165        // Get client
166        let client = {
167            let clients = self.clients.read().await;
168            clients
169                .get(&server_name)
170                .cloned()
171                .ok_or_else(|| anyhow!("MCP server not connected: {}", server_name))?
172        };
173
174        // Call tool
175        client.call_tool(&tool_name, arguments).await
176    }
177
178    /// Parse MCP tool full name into (server, tool)
179    fn parse_tool_name(full_name: &str) -> Result<(String, String)> {
180        // Format: mcp__<server>__<tool>
181        if !full_name.starts_with("mcp__") {
182            return Err(anyhow!("Invalid MCP tool name: {}", full_name));
183        }
184
185        let rest = &full_name[5..]; // Skip "mcp__"
186        let parts: Vec<&str> = rest.splitn(2, "__").collect();
187
188        if parts.len() != 2 {
189            return Err(anyhow!("Invalid MCP tool name format: {}", full_name));
190        }
191
192        Ok((parts[0].to_string(), parts[1].to_string()))
193    }
194
195    /// Get status of all servers
196    pub async fn get_status(&self) -> HashMap<String, McpServerStatus> {
197        let configs = self.configs.read().await;
198        let clients = self.clients.read().await;
199        let mut status = HashMap::new();
200
201        for (name, config) in configs.iter() {
202            let client = clients.get(name);
203            let (connected, tool_count) = if let Some(c) = client {
204                (c.is_connected(), c.get_cached_tools().await.len())
205            } else {
206                (false, 0)
207            };
208
209            status.insert(
210                name.clone(),
211                McpServerStatus {
212                    name: name.clone(),
213                    connected,
214                    enabled: config.enabled,
215                    tool_count,
216                    error: None,
217                },
218            );
219        }
220
221        status
222    }
223
224    /// Get a specific client
225    pub async fn get_client(&self, name: &str) -> Option<Arc<McpClient>> {
226        let clients = self.clients.read().await;
227        clients.get(name).cloned()
228    }
229
230    /// Check if a server is connected
231    pub async fn is_connected(&self, name: &str) -> bool {
232        let clients = self.clients.read().await;
233        clients.get(name).map(|c| c.is_connected()).unwrap_or(false)
234    }
235
236    /// List connected server names
237    pub async fn list_connected(&self) -> Vec<String> {
238        let clients = self.clients.read().await;
239        clients.keys().cloned().collect()
240    }
241
242    /// Get cached tools for a specific connected server.
243    pub async fn get_server_tools(&self, name: &str) -> Vec<McpTool> {
244        let clients = self.clients.read().await;
245        match clients.get(name) {
246            Some(client) => client.get_cached_tools().await,
247            None => Vec::new(),
248        }
249    }
250}
251
252impl Default for McpManager {
253    fn default() -> Self {
254        Self::new()
255    }
256}
257
258/// Convert MCP tool result to string output
259pub fn tool_result_to_string(result: &CallToolResult) -> String {
260    let mut output = String::new();
261
262    for content in &result.content {
263        match content {
264            ToolContent::Text { text } => {
265                output.push_str(text);
266                output.push('\n');
267            }
268            ToolContent::Image { data: _, mime_type } => {
269                output.push_str(&format!("[Image: {}]\n", mime_type));
270            }
271            ToolContent::Resource { resource } => {
272                if let Some(text) = &resource.text {
273                    output.push_str(text);
274                    output.push('\n');
275                } else {
276                    output.push_str(&format!("[Resource: {}]\n", resource.uri));
277                }
278            }
279        }
280    }
281
282    output.trim_end().to_string()
283}
284
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mcp::protocol::ResourceContent;

    /// Stdio config pointing at `echo`; tests only register it and never spawn it.
    fn stdio_config(name: &str, enabled: bool) -> McpServerConfig {
        McpServerConfig {
            name: name.to_string(),
            transport: McpTransportConfig::Stdio {
                command: "echo".to_string(),
                args: vec![],
            },
            enabled,
            env: HashMap::new(),
            oauth: None,
            tool_timeout_secs: 60,
        }
    }

    /// Shorthand for a text content item.
    fn text(s: &str) -> ToolContent {
        ToolContent::Text {
            text: s.to_string(),
        }
    }

    /// Wrap content items in a non-error `CallToolResult`.
    fn ok_result(content: Vec<ToolContent>) -> CallToolResult {
        CallToolResult {
            content,
            is_error: false,
        }
    }

    // --- parse_tool_name -------------------------------------------------

    #[test]
    fn test_parse_tool_name() {
        let (server, tool) = McpManager::parse_tool_name("mcp__github__create_issue").unwrap();
        assert_eq!(server, "github");
        assert_eq!(tool, "create_issue");
    }

    #[test]
    fn test_parse_tool_name_simple() {
        let (server, tool) = McpManager::parse_tool_name("mcp__server__tool").unwrap();
        assert_eq!(server, "server");
        assert_eq!(tool, "tool");
    }

    #[test]
    fn test_parse_tool_name_with_underscores() {
        // Single underscores belong to the names; only `__` separates them.
        let (server, tool) = McpManager::parse_tool_name("mcp__my_server__my_tool_name").unwrap();
        assert_eq!(server, "my_server");
        assert_eq!(tool, "my_tool_name");
    }

    #[test]
    fn test_parse_tool_name_splits_at_first_separator() {
        let (server, tool) = McpManager::parse_tool_name("mcp__srv__a__b").unwrap();
        assert_eq!(server, "srv");
        assert_eq!(tool, "a__b");
    }

    #[test]
    fn test_parse_tool_name_invalid() {
        assert!(McpManager::parse_tool_name("invalid_name").is_err());
        assert!(McpManager::parse_tool_name("mcp__nodelimiter").is_err());
    }

    #[test]
    fn test_parse_tool_name_missing_prefix() {
        assert!(McpManager::parse_tool_name("server__tool").is_err());
    }

    #[test]
    fn test_parse_tool_name_only_prefix() {
        assert!(McpManager::parse_tool_name("mcp__").is_err());
    }

    #[test]
    fn test_parse_tool_name_empty_string() {
        assert!(McpManager::parse_tool_name("").is_err());
    }

    // --- tool_result_to_string ---------------------------------------------

    #[test]
    fn test_tool_result_to_string() {
        let result = ok_result(vec![text("Line 1"), text("Line 2")]);
        assert_eq!(tool_result_to_string(&result), "Line 1\nLine 2");
    }

    #[test]
    fn test_tool_result_to_string_single_text() {
        let result = ok_result(vec![text("Hello World")]);
        assert_eq!(tool_result_to_string(&result), "Hello World");
    }

    #[test]
    fn test_tool_result_to_string_empty() {
        assert_eq!(tool_result_to_string(&ok_result(vec![])), "");
    }

    #[test]
    fn test_tool_result_to_string_trims_trailing_whitespace() {
        let result = ok_result(vec![text("done\n\n")]);
        assert_eq!(tool_result_to_string(&result), "done");
    }

    #[test]
    fn test_tool_result_to_string_image() {
        let result = ok_result(vec![ToolContent::Image {
            data: "base64data".to_string(),
            mime_type: "image/png".to_string(),
        }]);
        assert_eq!(tool_result_to_string(&result), "[Image: image/png]");
    }

    #[test]
    fn test_tool_result_to_string_resource() {
        let result = ok_result(vec![ToolContent::Resource {
            resource: ResourceContent {
                uri: "file:///test.txt".to_string(),
                mime_type: Some("text/plain".to_string()),
                text: Some("Resource content".to_string()),
                blob: None,
            },
        }]);
        assert_eq!(tool_result_to_string(&result), "Resource content");
    }

    #[test]
    fn test_tool_result_to_string_resource_without_text() {
        let result = ok_result(vec![ToolContent::Resource {
            resource: ResourceContent {
                uri: "file:///bin.dat".to_string(),
                mime_type: None,
                text: None,
                blob: None,
            },
        }]);
        assert_eq!(tool_result_to_string(&result), "[Resource: file:///bin.dat]");
    }

    #[test]
    fn test_tool_result_to_string_mixed_content() {
        let result = ok_result(vec![
            text("Text content"),
            ToolContent::Image {
                data: "base64".to_string(),
                mime_type: "image/jpeg".to_string(),
            },
            ToolContent::Resource {
                resource: ResourceContent {
                    uri: "file:///doc.md".to_string(),
                    mime_type: Some("text/markdown".to_string()),
                    text: Some("Doc content".to_string()),
                    blob: None,
                },
            },
        ]);
        assert_eq!(
            tool_result_to_string(&result),
            "Text content\n[Image: image/jpeg]\nDoc content"
        );
    }

    // --- McpManager (no servers are ever spawned) ---------------------------

    #[tokio::test]
    async fn test_mcp_manager_new() {
        let manager = McpManager::new();
        assert!(manager.get_status().await.is_empty());
    }

    #[tokio::test]
    async fn test_mcp_manager_default() {
        let manager = McpManager::default();
        assert!(manager.get_status().await.is_empty());
    }

    #[tokio::test]
    async fn test_mcp_manager_register_server() {
        let manager = McpManager::new();
        manager.register_server(stdio_config("test", true)).await;

        let configs = manager.all_configs().await;
        assert_eq!(configs.len(), 1);
        assert_eq!(configs[0].name, "test");

        let status = manager.get_status().await;
        assert!(status.contains_key("test"));
        assert!(!status["test"].connected);
    }

    #[tokio::test]
    async fn test_register_server_twice_replaces_config() {
        let manager = McpManager::new();
        manager.register_server(stdio_config("dup", true)).await;
        manager.register_server(stdio_config("dup", false)).await;

        let configs = manager.all_configs().await;
        assert_eq!(configs.len(), 1);
        assert!(!configs[0].enabled);
    }

    #[tokio::test]
    async fn test_get_status_registered_server() {
        let manager = McpManager::new();
        manager.register_server(stdio_config("test_server", true)).await;

        let status = manager.get_status().await;
        let entry = &status["test_server"];
        assert_eq!(entry.name, "test_server");
        assert!(!entry.connected);
        assert!(entry.enabled);
        assert_eq!(entry.tool_count, 0);
        assert!(entry.error.is_none());
    }

    #[tokio::test]
    async fn test_get_status_disabled_server() {
        let manager = McpManager::new();
        manager.register_server(stdio_config("disabled_server", false)).await;

        let status = manager.get_status().await;
        assert!(status.contains_key("disabled_server"));
        assert!(!status["disabled_server"].enabled);
    }

    #[tokio::test]
    async fn test_connect_unregistered_server_fails() {
        let manager = McpManager::new();
        assert!(manager.connect("missing").await.is_err());
    }

    #[tokio::test]
    async fn test_connect_disabled_server_fails_without_spawning() {
        let manager = McpManager::new();
        manager.register_server(stdio_config("off", false)).await;

        let err = manager.connect("off").await.unwrap_err();
        assert!(err.to_string().contains("disabled"));
        assert!(!manager.is_connected("off").await);
    }

    #[tokio::test]
    async fn test_call_tool_on_unconnected_server_fails() {
        let manager = McpManager::new();
        let err = manager.call_tool("mcp__ghost__tool", None).await.unwrap_err();
        assert!(err.to_string().contains("not connected"));
    }

    #[tokio::test]
    async fn test_call_tool_with_invalid_name_fails() {
        let manager = McpManager::new();
        assert!(manager.call_tool("not_an_mcp_tool", None).await.is_err());
    }

    #[tokio::test]
    async fn test_disconnect_unknown_server_is_ok() {
        let manager = McpManager::new();
        assert!(manager.disconnect("unknown_server").await.is_ok());
    }

    #[tokio::test]
    async fn test_list_connected_empty() {
        let manager = McpManager::new();
        assert!(manager.list_connected().await.is_empty());
    }

    #[tokio::test]
    async fn test_is_connected_false_for_unknown_server() {
        let manager = McpManager::new();
        assert!(!manager.is_connected("unknown_server").await);
    }

    #[tokio::test]
    async fn test_get_client_none_for_unknown_server() {
        let manager = McpManager::new();
        assert!(manager.get_client("unknown_server").await.is_none());
    }

    #[tokio::test]
    async fn test_get_server_tools_empty_for_unknown_server() {
        let manager = McpManager::new();
        assert!(manager.get_server_tools("unknown_server").await.is_empty());
    }

    #[tokio::test]
    async fn test_get_all_tools_empty_manager() {
        let manager = McpManager::new();
        assert!(manager.get_all_tools().await.is_empty());
    }
}
573}