Skip to main content

a3s_code_core/mcp/
manager.rs

1//! MCP Manager
2//!
3//! Manages MCP server lifecycle and provides unified access to MCP tools.
4
5use crate::mcp::client::McpClient;
6use crate::mcp::oauth;
7use crate::mcp::protocol::{
8    CallToolResult, McpServerConfig, McpTool, McpTransportConfig, OAuthConfig, ToolContent,
9};
10use crate::mcp::transport::http_sse::HttpSseTransport;
11use crate::mcp::transport::stdio::StdioTransport;
12use crate::mcp::transport::streamable_http::StreamableHttpTransport;
13use crate::mcp::transport::McpTransport;
14use anyhow::{anyhow, Result};
15use std::collections::HashMap;
16use std::sync::Arc;
17use tokio::sync::RwLock;
18
19/// MCP server status
20#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
21pub struct McpServerStatus {
22    pub name: String,
23    pub connected: bool,
24    pub enabled: bool,
25    pub tool_count: usize,
26    pub error: Option<String>,
27}
28
/// MCP Manager for managing multiple MCP servers.
///
/// Server configurations (added by `register_server`) are stored separately
/// from live clients (added by `connect`, removed by `disconnect`). Both maps
/// are keyed by server name and guarded by `tokio` async read/write locks, so
/// all operations work through `&self`.
pub struct McpManager {
    /// Connected clients, keyed by server name.
    clients: RwLock<HashMap<String, Arc<McpClient>>>,
    /// Registered server configurations, keyed by server name.
    configs: RwLock<HashMap<String, McpServerConfig>>,
}
36
37impl McpManager {
38    /// Create a new MCP manager
39    pub fn new() -> Self {
40        Self {
41            clients: RwLock::new(HashMap::new()),
42            configs: RwLock::new(HashMap::new()),
43        }
44    }
45
46    /// Register a server configuration
47    pub async fn register_server(&self, config: McpServerConfig) {
48        let name = config.name.clone();
49        let mut configs = self.configs.write().await;
50        configs.insert(name.clone(), config);
51        tracing::info!("Registered MCP server: {}", name);
52    }
53
54    /// Connect to a registered server
55    pub async fn connect(&self, name: &str) -> Result<()> {
56        // Get config
57        let config = {
58            let configs = self.configs.read().await;
59            configs
60                .get(name)
61                .cloned()
62                .ok_or_else(|| anyhow!("MCP server not found: {}", name))?
63        };
64
65        if !config.enabled {
66            return Err(anyhow!("MCP server is disabled: {}", name));
67        }
68
69        // Resolve OAuth token into an Authorization header (if configured)
70        let auth_header = Self::resolve_auth_header(config.oauth.as_ref()).await?;
71
72        // Create transport based on config
73        let transport: Arc<dyn McpTransport> = match &config.transport {
74            McpTransportConfig::Stdio { command, args } => Arc::new(
75                StdioTransport::spawn_with_timeout(
76                    command,
77                    args,
78                    &config.env,
79                    config.tool_timeout_secs,
80                )
81                .await?,
82            ),
83            McpTransportConfig::Http { url, headers } => {
84                let mut merged = headers.clone();
85                if let Some((k, v)) = &auth_header {
86                    merged.insert(k.clone(), v.clone());
87                }
88                Arc::new(
89                    HttpSseTransport::connect_with_timeout(url, merged, config.tool_timeout_secs)
90                        .await?,
91                )
92            }
93            McpTransportConfig::StreamableHttp { url, headers } => {
94                let mut merged = headers.clone();
95                if let Some((k, v)) = &auth_header {
96                    merged.insert(k.clone(), v.clone());
97                }
98                Arc::new(
99                    StreamableHttpTransport::connect_with_timeout(
100                        url,
101                        merged,
102                        config.tool_timeout_secs,
103                    )
104                    .await?,
105                )
106            }
107        };
108
109        // Create client
110        let client = Arc::new(McpClient::new(name.to_string(), transport));
111
112        // Initialize
113        client.initialize().await?;
114
115        // Fetch tools
116        let tools = client.list_tools().await?;
117        tracing::info!("MCP server '{}' connected with {} tools", name, tools.len());
118
119        // Store client
120        {
121            let mut clients = self.clients.write().await;
122            clients.insert(name.to_string(), client);
123        }
124
125        Ok(())
126    }
127
128    /// Disconnect from a server
129    pub async fn disconnect(&self, name: &str) -> Result<()> {
130        let client = {
131            let mut clients = self.clients.write().await;
132            clients.remove(name)
133        };
134
135        if let Some(client) = client {
136            client.close().await?;
137            tracing::info!("MCP server '{}' disconnected", name);
138        }
139
140        Ok(())
141    }
142
143    /// Get all registered server configurations
144    pub async fn all_configs(&self) -> Vec<McpServerConfig> {
145        self.configs.read().await.values().cloned().collect()
146    }
147
148    /// Get all MCP tools with server prefix
149    ///
150    /// Returns tools with names like `mcp__github__create_issue`
151    pub async fn get_all_tools(&self) -> Vec<(String, McpTool)> {
152        let clients = self.clients.read().await;
153        let mut all_tools = Vec::new();
154
155        for (server_name, client) in clients.iter() {
156            let tools = client.get_cached_tools().await;
157            for tool in tools {
158                let full_name = format!("mcp__{}__{}", server_name, tool.name);
159                all_tools.push((full_name, tool));
160            }
161        }
162
163        all_tools
164    }
165
166    /// Call an MCP tool by full name
167    ///
168    /// Full name format: `mcp__<server>__<tool>`
169    pub async fn call_tool(
170        &self,
171        full_name: &str,
172        arguments: Option<serde_json::Value>,
173    ) -> Result<CallToolResult> {
174        // Parse full name
175        let (server_name, tool_name) = Self::parse_tool_name(full_name)?;
176
177        // Get client
178        let client = {
179            let clients = self.clients.read().await;
180            clients
181                .get(&server_name)
182                .cloned()
183                .ok_or_else(|| anyhow!("MCP server not connected: {}", server_name))?
184        };
185
186        // Call tool
187        client.call_tool(&tool_name, arguments).await
188    }
189
190    /// Resolve an OAuth config into a `(header-name, header-value)` pair.
191    ///
192    /// - If `oauth.access_token` is set, uses it directly (static token).
193    /// - Otherwise, performs a client credentials exchange.
194    /// - If `oauth` is `None`, returns `Ok(None)` (no auth needed).
195    async fn resolve_auth_header(oauth: Option<&OAuthConfig>) -> Result<Option<(String, String)>> {
196        let Some(oauth) = oauth else {
197            return Ok(None);
198        };
199
200        let token = if let Some(static_token) = &oauth.access_token {
201            static_token.clone()
202        } else {
203            oauth::exchange_client_credentials(
204                &oauth.token_url,
205                &oauth.client_id,
206                oauth.client_secret.as_deref().unwrap_or(""),
207                &oauth.scopes,
208            )
209            .await?
210        };
211
212        Ok(Some((
213            "Authorization".to_string(),
214            format!("Bearer {}", token),
215        )))
216    }
217
218    /// Parse MCP tool full name into (server, tool)
219    fn parse_tool_name(full_name: &str) -> Result<(String, String)> {
220        // Format: mcp__<server>__<tool>
221        if !full_name.starts_with("mcp__") {
222            return Err(anyhow!("Invalid MCP tool name: {}", full_name));
223        }
224
225        let rest = &full_name[5..]; // Skip "mcp__"
226        let parts: Vec<&str> = rest.splitn(2, "__").collect();
227
228        if parts.len() != 2 {
229            return Err(anyhow!("Invalid MCP tool name format: {}", full_name));
230        }
231
232        Ok((parts[0].to_string(), parts[1].to_string()))
233    }
234
235    /// Get status of all servers
236    pub async fn get_status(&self) -> HashMap<String, McpServerStatus> {
237        let configs = self.configs.read().await;
238        let clients = self.clients.read().await;
239        let mut status = HashMap::new();
240
241        for (name, config) in configs.iter() {
242            let client = clients.get(name);
243            let (connected, tool_count) = if let Some(c) = client {
244                (c.is_connected(), c.get_cached_tools().await.len())
245            } else {
246                (false, 0)
247            };
248
249            status.insert(
250                name.clone(),
251                McpServerStatus {
252                    name: name.clone(),
253                    connected,
254                    enabled: config.enabled,
255                    tool_count,
256                    error: None,
257                },
258            );
259        }
260
261        status
262    }
263
264    /// Get a specific client
265    pub async fn get_client(&self, name: &str) -> Option<Arc<McpClient>> {
266        let clients = self.clients.read().await;
267        clients.get(name).cloned()
268    }
269
270    /// Check if a server is connected
271    pub async fn is_connected(&self, name: &str) -> bool {
272        let clients = self.clients.read().await;
273        clients.get(name).map(|c| c.is_connected()).unwrap_or(false)
274    }
275
276    /// List connected server names
277    pub async fn list_connected(&self) -> Vec<String> {
278        let clients = self.clients.read().await;
279        clients.keys().cloned().collect()
280    }
281
282    /// Get cached tools for a specific connected server.
283    pub async fn get_server_tools(&self, name: &str) -> Vec<McpTool> {
284        let clients = self.clients.read().await;
285        match clients.get(name) {
286            Some(client) => client.get_cached_tools().await,
287            None => Vec::new(),
288        }
289    }
290}
291
292impl Default for McpManager {
293    fn default() -> Self {
294        Self::new()
295    }
296}
297
298/// Convert MCP tool result to string output
299pub fn tool_result_to_string(result: &CallToolResult) -> String {
300    let mut output = String::new();
301
302    for content in &result.content {
303        match content {
304            ToolContent::Text { text } => {
305                output.push_str(text);
306                output.push('\n');
307            }
308            ToolContent::Image { data: _, mime_type } => {
309                output.push_str(&format!("[Image: {}]\n", mime_type));
310            }
311            ToolContent::Resource { resource } => {
312                if let Some(text) = &resource.text {
313                    output.push_str(text);
314                    output.push('\n');
315                } else {
316                    output.push_str(&format!("[Resource: {}]\n", resource.uri));
317                }
318            }
319        }
320    }
321
322    output.trim_end().to_string()
323}
324
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mcp::protocol::ResourceContent;

    /// A stdio server config; its command is never spawned by these tests.
    fn stdio_config(name: &str, enabled: bool) -> McpServerConfig {
        McpServerConfig {
            name: name.to_string(),
            transport: McpTransportConfig::Stdio {
                command: "echo".to_string(),
                args: vec![],
            },
            enabled,
            env: HashMap::new(),
            oauth: None,
            tool_timeout_secs: 60,
        }
    }

    /// Wrap content items in a non-error `CallToolResult`.
    fn ok_result(content: Vec<ToolContent>) -> CallToolResult {
        CallToolResult {
            content,
            is_error: false,
        }
    }

    fn text(s: &str) -> ToolContent {
        ToolContent::Text {
            text: s.to_string(),
        }
    }

    fn image(mime_type: &str) -> ToolContent {
        ToolContent::Image {
            data: "base64data".to_string(),
            mime_type: mime_type.to_string(),
        }
    }

    fn resource(uri: &str, body: Option<&str>) -> ToolContent {
        ToolContent::Resource {
            resource: ResourceContent {
                uri: uri.to_string(),
                mime_type: Some("text/plain".to_string()),
                text: body.map(str::to_string),
                blob: None,
            },
        }
    }

    // ---- parse_tool_name ----

    #[test]
    fn test_parse_tool_name() {
        let (server, tool) = McpManager::parse_tool_name("mcp__github__create_issue").unwrap();
        assert_eq!(server, "github");
        assert_eq!(tool, "create_issue");
    }

    #[test]
    fn test_parse_tool_name_simple() {
        let (server, tool) = McpManager::parse_tool_name("mcp__server__tool").unwrap();
        assert_eq!(server, "server");
        assert_eq!(tool, "tool");
    }

    #[test]
    fn test_parse_tool_name_with_underscores() {
        let (server, tool) = McpManager::parse_tool_name("mcp__my_server__my_tool_name").unwrap();
        assert_eq!(server, "my_server");
        assert_eq!(tool, "my_tool_name");
    }

    #[test]
    fn test_parse_tool_name_tool_may_contain_separator() {
        let (server, tool) = McpManager::parse_tool_name("mcp__srv__a__b").unwrap();
        assert_eq!(server, "srv");
        assert_eq!(tool, "a__b");
    }

    #[test]
    fn test_parse_tool_name_invalid() {
        assert!(McpManager::parse_tool_name("invalid_name").is_err());
        assert!(McpManager::parse_tool_name("mcp__nodelimiter").is_err());
    }

    #[test]
    fn test_parse_tool_name_missing_prefix() {
        assert!(McpManager::parse_tool_name("server__tool").is_err());
    }

    #[test]
    fn test_parse_tool_name_only_prefix() {
        assert!(McpManager::parse_tool_name("mcp__").is_err());
    }

    #[test]
    fn test_parse_tool_name_empty_string() {
        assert!(McpManager::parse_tool_name("").is_err());
    }

    // ---- tool_result_to_string ----

    #[test]
    fn test_tool_result_to_string() {
        let output = tool_result_to_string(&ok_result(vec![text("Line 1"), text("Line 2")]));
        assert_eq!(output, "Line 1\nLine 2");
    }

    #[test]
    fn test_tool_result_to_string_single_text() {
        let output = tool_result_to_string(&ok_result(vec![text("Hello World")]));
        assert_eq!(output, "Hello World");
    }

    #[test]
    fn test_tool_result_to_string_empty() {
        assert_eq!(tool_result_to_string(&ok_result(vec![])), "");
    }

    #[test]
    fn test_tool_result_to_string_image() {
        let output = tool_result_to_string(&ok_result(vec![image("image/png")]));
        assert_eq!(output, "[Image: image/png]");
    }

    #[test]
    fn test_tool_result_to_string_resource() {
        let output = tool_result_to_string(&ok_result(vec![resource(
            "file:///test.txt",
            Some("Resource content"),
        )]));
        assert_eq!(output, "Resource content");
    }

    #[test]
    fn test_tool_result_to_string_resource_without_text() {
        let output = tool_result_to_string(&ok_result(vec![resource("file:///data.bin", None)]));
        assert_eq!(output, "[Resource: file:///data.bin]");
    }

    #[test]
    fn test_tool_result_to_string_mixed_content() {
        let output = tool_result_to_string(&ok_result(vec![
            text("Text content"),
            image("image/jpeg"),
            resource("file:///doc.md", Some("Doc content")),
        ]));
        assert_eq!(output, "Text content\n[Image: image/jpeg]\nDoc content");
    }

    // ---- McpManager (no server is ever spawned) ----

    #[tokio::test]
    async fn test_mcp_manager_new() {
        let manager = McpManager::new();
        assert!(manager.get_status().await.is_empty());
    }

    #[tokio::test]
    async fn test_mcp_manager_default() {
        let manager = McpManager::default();
        assert!(manager.get_status().await.is_empty());
    }

    #[tokio::test]
    async fn test_mcp_manager_register_server() {
        let manager = McpManager::new();
        manager.register_server(stdio_config("test", true)).await;

        let status = manager.get_status().await;
        let entry = status.get("test").expect("registered server must have a status");
        assert_eq!(entry.name, "test");
        assert!(!entry.connected);
        assert!(entry.enabled);
        assert_eq!(entry.tool_count, 0);
        assert!(entry.error.is_none());
    }

    #[tokio::test]
    async fn test_get_status_disabled_server() {
        let manager = McpManager::new();
        manager.register_server(stdio_config("disabled_server", false)).await;

        let status = manager.get_status().await;
        assert!(status.contains_key("disabled_server"));
        assert!(!status["disabled_server"].enabled);
    }

    #[tokio::test]
    async fn test_all_configs_returns_registered() {
        let manager = McpManager::new();
        manager.register_server(stdio_config("b", true)).await;
        manager.register_server(stdio_config("a", false)).await;

        let mut names: Vec<String> = manager
            .all_configs()
            .await
            .into_iter()
            .map(|config| config.name)
            .collect();
        names.sort();
        assert_eq!(names, vec!["a".to_string(), "b".to_string()]);
    }

    #[tokio::test]
    async fn test_list_connected_empty() {
        let manager = McpManager::new();
        assert!(manager.list_connected().await.is_empty());
    }

    #[tokio::test]
    async fn test_is_connected_false_for_unknown_server() {
        let manager = McpManager::new();
        assert!(!manager.is_connected("unknown_server").await);
    }

    #[tokio::test]
    async fn test_get_client_none_for_unknown_server() {
        let manager = McpManager::new();
        assert!(manager.get_client("unknown_server").await.is_none());
    }

    #[tokio::test]
    async fn test_get_all_tools_empty_manager() {
        let manager = McpManager::new();
        assert!(manager.get_all_tools().await.is_empty());
    }

    #[tokio::test]
    async fn test_get_server_tools_empty_for_unknown_server() {
        let manager = McpManager::new();
        assert!(manager.get_server_tools("unknown_server").await.is_empty());
    }

    #[tokio::test]
    async fn test_connect_unregistered_server_fails() {
        let manager = McpManager::new();
        let err = manager.connect("ghost").await.err().expect("connect must fail");
        assert!(err.to_string().contains("not found"));
    }

    #[tokio::test]
    async fn test_connect_disabled_server_fails() {
        let manager = McpManager::new();
        manager.register_server(stdio_config("off", false)).await;

        let err = manager.connect("off").await.err().expect("connect must fail");
        assert!(err.to_string().contains("disabled"));
        assert!(!manager.is_connected("off").await);
    }

    #[tokio::test]
    async fn test_disconnect_unknown_server_is_ok() {
        let manager = McpManager::new();
        assert!(manager.disconnect("ghost").await.is_ok());
    }

    #[tokio::test]
    async fn test_call_tool_requires_connected_server() {
        let manager = McpManager::new();
        let err = manager
            .call_tool("mcp__ghost__tool", None)
            .await
            .err()
            .expect("call must fail");
        assert!(err.to_string().contains("not connected"));
    }

    #[tokio::test]
    async fn test_call_tool_rejects_invalid_name() {
        let manager = McpManager::new();
        assert!(manager.call_tool("not_an_mcp_tool", None).await.is_err());
    }

    // ---- resolve_auth_header ----

    #[tokio::test]
    async fn test_resolve_auth_header_none_when_no_oauth() {
        let result = McpManager::resolve_auth_header(None).await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_resolve_auth_header_uses_static_token() {
        let oauth = OAuthConfig {
            auth_url: "https://example.com/auth".to_string(),
            token_url: "https://example.com/token".to_string(),
            client_id: "client".to_string(),
            client_secret: None,
            scopes: vec![],
            redirect_uri: "http://localhost/cb".to_string(),
            access_token: Some("my-static-token".to_string()),
        };
        let (key, value) = McpManager::resolve_auth_header(Some(&oauth))
            .await
            .unwrap()
            .expect("static token must produce a header");
        assert_eq!(key, "Authorization");
        assert_eq!(value, "Bearer my-static-token");
    }

    #[tokio::test]
    async fn test_resolve_auth_header_client_credentials_fails_gracefully() {
        // No static token + unreachable token_url → the exchange must error out.
        let oauth = OAuthConfig {
            auth_url: "https://127.0.0.1:1/auth".to_string(),
            token_url: "http://127.0.0.1:1/token".to_string(),
            client_id: "client".to_string(),
            client_secret: Some("secret".to_string()),
            scopes: vec!["read".to_string()],
            redirect_uri: "http://localhost/cb".to_string(),
            access_token: None,
        };
        assert!(McpManager::resolve_auth_header(Some(&oauth)).await.is_err());
    }
}
655}