Skip to main content

a3s_code_core/mcp/
manager.rs

1//! MCP Manager
2//!
3//! Manages MCP server lifecycle and provides unified access to MCP tools.
4
5use crate::mcp::client::McpClient;
6use crate::mcp::oauth;
7use crate::mcp::protocol::{
8    CallToolResult, McpServerConfig, McpTool, McpTransportConfig, OAuthConfig, ToolContent,
9};
10use crate::mcp::transport::http_sse::HttpSseTransport;
11use crate::mcp::transport::stdio::StdioTransport;
12use crate::mcp::transport::streamable_http::StreamableHttpTransport;
13use crate::mcp::transport::McpTransport;
14use anyhow::{anyhow, Result};
15use std::collections::HashMap;
16use std::sync::Arc;
17use tokio::sync::RwLock;
18
19/// MCP server status
20#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
21pub struct McpServerStatus {
22    pub name: String,
23    pub connected: bool,
24    pub enabled: bool,
25    pub tool_count: usize,
26    pub error: Option<String>,
27}
28
29/// MCP Manager for managing multiple MCP servers
30pub struct McpManager {
31    /// Connected clients
32    clients: RwLock<HashMap<String, Arc<McpClient>>>,
33    /// Server configurations
34    configs: RwLock<HashMap<String, McpServerConfig>>,
35    /// Last connection error per server, cleared on successful connect
36    connect_errors: RwLock<HashMap<String, String>>,
37}
38
39impl McpManager {
40    /// Create a new MCP manager
41    pub fn new() -> Self {
42        Self {
43            clients: RwLock::new(HashMap::new()),
44            configs: RwLock::new(HashMap::new()),
45            connect_errors: RwLock::new(HashMap::new()),
46        }
47    }
48
49    /// Register a server configuration
50    pub async fn register_server(&self, config: McpServerConfig) {
51        let name = config.name.clone();
52        let mut configs = self.configs.write().await;
53        configs.insert(name.clone(), config);
54        tracing::info!("Registered MCP server: {}", name);
55    }
56
57    /// Connect to a registered server, recording success or failure internally.
58    ///
59    /// On success the stored error (if any) is cleared; on failure the error
60    /// message is stored and visible via [`McpManager::get_status`].
61    pub async fn connect(&self, name: &str) -> Result<()> {
62        let result = self.do_connect(name).await;
63        match &result {
64            Ok(_) => {
65                self.connect_errors.write().await.remove(name);
66            }
67            Err(e) => {
68                self.connect_errors
69                    .write()
70                    .await
71                    .insert(name.to_string(), e.to_string());
72            }
73        }
74        result
75    }
76
77    async fn do_connect(&self, name: &str) -> Result<()> {
78        // Get config
79        let config = {
80            let configs = self.configs.read().await;
81            configs
82                .get(name)
83                .cloned()
84                .ok_or_else(|| anyhow!("MCP server not found: {}", name))?
85        };
86
87        if !config.enabled {
88            return Err(anyhow!("MCP server is disabled: {}", name));
89        }
90
91        // Resolve OAuth token into an Authorization header (if configured)
92        let auth_header = Self::resolve_auth_header(config.oauth.as_ref()).await?;
93
94        // Create transport based on config
95        let transport: Arc<dyn McpTransport> = match &config.transport {
96            McpTransportConfig::Stdio { command, args } => Arc::new(
97                StdioTransport::spawn_with_timeout(
98                    command,
99                    args,
100                    &config.env,
101                    config.tool_timeout_secs,
102                )
103                .await?,
104            ),
105            McpTransportConfig::Http { url, headers } => {
106                let mut merged = headers.clone();
107                if let Some((k, v)) = &auth_header {
108                    merged.insert(k.clone(), v.clone());
109                }
110                Arc::new(
111                    HttpSseTransport::connect_with_timeout(url, merged, config.tool_timeout_secs)
112                        .await?,
113                )
114            }
115            McpTransportConfig::StreamableHttp { url, headers } => {
116                let mut merged = headers.clone();
117                if let Some((k, v)) = &auth_header {
118                    merged.insert(k.clone(), v.clone());
119                }
120                Arc::new(
121                    StreamableHttpTransport::connect_with_timeout(
122                        url,
123                        merged,
124                        config.tool_timeout_secs,
125                    )
126                    .await?,
127                )
128            }
129        };
130
131        // Create client
132        let client = Arc::new(McpClient::new(name.to_string(), transport));
133
134        // Initialize
135        client.initialize().await?;
136
137        // Fetch tools
138        let tools = client.list_tools().await?;
139        tracing::info!("MCP server '{}' connected with {} tools", name, tools.len());
140
141        // Store client
142        {
143            let mut clients = self.clients.write().await;
144            clients.insert(name.to_string(), client);
145        }
146
147        Ok(())
148    }
149
150    /// Disconnect from a server
151    pub async fn disconnect(&self, name: &str) -> Result<()> {
152        let client = {
153            let mut clients = self.clients.write().await;
154            clients.remove(name)
155        };
156
157        if let Some(client) = client {
158            client.close().await?;
159            tracing::info!("MCP server '{}' disconnected", name);
160        }
161
162        Ok(())
163    }
164
165    /// Get all registered server configurations
166    pub async fn all_configs(&self) -> Vec<McpServerConfig> {
167        self.configs.read().await.values().cloned().collect()
168    }
169
170    /// Get all MCP tools, grouped by server name.
171    ///
172    /// Returns `(server_name, tool)` pairs — the caller is responsible for
173    /// constructing the `mcp__<server>__<tool>` prefix (e.g. via [`create_mcp_tools`]).
174    pub async fn get_all_tools(&self) -> Vec<(String, McpTool)> {
175        let clients = self.clients.read().await;
176        let mut all_tools = Vec::new();
177
178        for (server_name, client) in clients.iter() {
179            let tools = client.get_cached_tools().await;
180            for tool in tools {
181                all_tools.push((server_name.clone(), tool));
182            }
183        }
184
185        all_tools
186    }
187
188    /// Call an MCP tool by full name
189    ///
190    /// Full name format: `mcp__<server>__<tool>`
191    pub async fn call_tool(
192        &self,
193        full_name: &str,
194        arguments: Option<serde_json::Value>,
195    ) -> Result<CallToolResult> {
196        // Parse full name
197        let (server_name, tool_name) = Self::parse_tool_name(full_name)?;
198
199        // Get client
200        let client = {
201            let clients = self.clients.read().await;
202            clients
203                .get(&server_name)
204                .cloned()
205                .ok_or_else(|| anyhow!("MCP server not connected: {}", server_name))?
206        };
207
208        // Call tool
209        client.call_tool(&tool_name, arguments).await
210    }
211
212    /// Resolve an OAuth config into a `(header-name, header-value)` pair.
213    ///
214    /// - If `oauth.access_token` is set, uses it directly (static token).
215    /// - Otherwise, performs a client credentials exchange.
216    /// - If `oauth` is `None`, returns `Ok(None)` (no auth needed).
217    async fn resolve_auth_header(oauth: Option<&OAuthConfig>) -> Result<Option<(String, String)>> {
218        let Some(oauth) = oauth else {
219            return Ok(None);
220        };
221
222        let token = if let Some(static_token) = &oauth.access_token {
223            static_token.clone()
224        } else {
225            oauth::exchange_client_credentials(
226                &oauth.token_url,
227                &oauth.client_id,
228                oauth.client_secret.as_deref().unwrap_or(""),
229                &oauth.scopes,
230            )
231            .await?
232        };
233
234        Ok(Some((
235            "Authorization".to_string(),
236            format!("Bearer {}", token),
237        )))
238    }
239
240    /// Parse MCP tool full name into (server, tool)
241    fn parse_tool_name(full_name: &str) -> Result<(String, String)> {
242        // Format: mcp__<server>__<tool>
243        if !full_name.starts_with("mcp__") {
244            return Err(anyhow!("Invalid MCP tool name: {}", full_name));
245        }
246
247        let rest = &full_name[5..]; // Skip "mcp__"
248        let parts: Vec<&str> = rest.splitn(2, "__").collect();
249
250        if parts.len() != 2 {
251            return Err(anyhow!("Invalid MCP tool name format: {}", full_name));
252        }
253
254        Ok((parts[0].to_string(), parts[1].to_string()))
255    }
256
257    /// Get status of all servers
258    pub async fn get_status(&self) -> HashMap<String, McpServerStatus> {
259        let configs = self.configs.read().await;
260        let clients = self.clients.read().await;
261        let errors = self.connect_errors.read().await;
262        let mut status = HashMap::new();
263
264        for (name, config) in configs.iter() {
265            let client = clients.get(name);
266            let (connected, tool_count) = if let Some(c) = client {
267                (c.is_connected(), c.get_cached_tools().await.len())
268            } else {
269                (false, 0)
270            };
271
272            status.insert(
273                name.clone(),
274                McpServerStatus {
275                    name: name.clone(),
276                    connected,
277                    enabled: config.enabled,
278                    tool_count,
279                    error: errors.get(name).cloned(),
280                },
281            );
282        }
283
284        status
285    }
286
287    /// Get a specific client
288    pub async fn get_client(&self, name: &str) -> Option<Arc<McpClient>> {
289        let clients = self.clients.read().await;
290        clients.get(name).cloned()
291    }
292
293    /// Check if a server is connected
294    pub async fn is_connected(&self, name: &str) -> bool {
295        let clients = self.clients.read().await;
296        clients.get(name).map(|c| c.is_connected()).unwrap_or(false)
297    }
298
299    /// List connected server names
300    pub async fn list_connected(&self) -> Vec<String> {
301        let clients = self.clients.read().await;
302        clients.keys().cloned().collect()
303    }
304
305    /// Get cached tools for a specific connected server.
306    pub async fn get_server_tools(&self, name: &str) -> Vec<McpTool> {
307        let clients = self.clients.read().await;
308        match clients.get(name) {
309            Some(client) => client.get_cached_tools().await,
310            None => Vec::new(),
311        }
312    }
313}
314
315impl Default for McpManager {
316    fn default() -> Self {
317        Self::new()
318    }
319}
320
321/// Convert MCP tool result to string output
322pub fn tool_result_to_string(result: &CallToolResult) -> String {
323    let mut output = String::new();
324
325    for content in &result.content {
326        match content {
327            ToolContent::Text { text } => {
328                output.push_str(text);
329                output.push('\n');
330            }
331            ToolContent::Image { data: _, mime_type } => {
332                output.push_str(&format!("[Image: {}]\n", mime_type));
333            }
334            ToolContent::Resource { resource } => {
335                if let Some(text) = &resource.text {
336                    output.push_str(text);
337                    output.push('\n');
338                } else {
339                    output.push_str(&format!("[Resource: {}]\n", resource.uri));
340                }
341            }
342        }
343    }
344
345    output.trim_end().to_string()
346}
347
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mcp::protocol::ResourceContent;

    /// Stdio server config for tests; `command` is never a real MCP server here.
    fn stdio_config(name: &str, command: &str, enabled: bool) -> McpServerConfig {
        McpServerConfig {
            name: name.to_string(),
            transport: McpTransportConfig::Stdio {
                command: command.to_string(),
                args: vec![],
            },
            enabled,
            env: HashMap::new(),
            oauth: None,
            tool_timeout_secs: 60,
        }
    }

    /// Build a non-error tool result from `content`.
    fn tool_result(content: Vec<ToolContent>) -> CallToolResult {
        CallToolResult {
            content,
            is_error: false,
        }
    }

    /// Text content item.
    fn text(s: &str) -> ToolContent {
        ToolContent::Text {
            text: s.to_string(),
        }
    }

    // ---------------------------------------------------------------- names

    #[test]
    fn test_parse_tool_name() {
        let (server, tool) = McpManager::parse_tool_name("mcp__github__create_issue").unwrap();
        assert_eq!(server, "github");
        assert_eq!(tool, "create_issue");
    }

    #[test]
    fn test_parse_tool_name_simple() {
        let (server, tool) = McpManager::parse_tool_name("mcp__server__tool").unwrap();
        assert_eq!(server, "server");
        assert_eq!(tool, "tool");
    }

    #[test]
    fn test_parse_tool_name_with_underscores() {
        let (server, tool) = McpManager::parse_tool_name("mcp__my_server__my_tool_name").unwrap();
        assert_eq!(server, "my_server");
        assert_eq!(tool, "my_tool_name");
    }

    #[test]
    fn test_parse_tool_name_splits_at_first_delimiter() {
        let (server, tool) = McpManager::parse_tool_name("mcp__srv__a__b").unwrap();
        assert_eq!(server, "srv");
        assert_eq!(tool, "a__b");
    }

    #[test]
    fn test_parse_tool_name_invalid() {
        assert!(McpManager::parse_tool_name("invalid_name").is_err());
        assert!(McpManager::parse_tool_name("mcp__nodelimiter").is_err());
    }

    #[test]
    fn test_parse_tool_name_missing_prefix() {
        assert!(McpManager::parse_tool_name("server__tool").is_err());
    }

    #[test]
    fn test_parse_tool_name_only_prefix() {
        assert!(McpManager::parse_tool_name("mcp__").is_err());
    }

    #[test]
    fn test_parse_tool_name_empty_string() {
        assert!(McpManager::parse_tool_name("").is_err());
    }

    // ------------------------------------------------------- result rendering

    #[test]
    fn test_tool_result_to_string_single_text() {
        let output = tool_result_to_string(&tool_result(vec![text("Hello World")]));
        assert_eq!(output, "Hello World");
    }

    #[test]
    fn test_tool_result_to_string_multiple_text() {
        let output = tool_result_to_string(&tool_result(vec![
            text("First line"),
            text("Second line"),
        ]));
        assert_eq!(output, "First line\nSecond line");
    }

    #[test]
    fn test_tool_result_to_string_empty() {
        assert_eq!(tool_result_to_string(&tool_result(vec![])), "");
    }

    #[test]
    fn test_tool_result_to_string_trims_trailing_whitespace() {
        let output = tool_result_to_string(&tool_result(vec![text("hello  \n")]));
        assert_eq!(output, "hello");
    }

    #[test]
    fn test_tool_result_to_string_image() {
        let output = tool_result_to_string(&tool_result(vec![ToolContent::Image {
            data: "base64data".to_string(),
            mime_type: "image/png".to_string(),
        }]));
        assert_eq!(output, "[Image: image/png]");
    }

    #[test]
    fn test_tool_result_to_string_resource() {
        let output = tool_result_to_string(&tool_result(vec![ToolContent::Resource {
            resource: ResourceContent {
                uri: "file:///test.txt".to_string(),
                mime_type: Some("text/plain".to_string()),
                text: Some("Resource content".to_string()),
                blob: None,
            },
        }]));
        assert_eq!(output, "Resource content");
    }

    #[test]
    fn test_tool_result_to_string_resource_without_text() {
        let output = tool_result_to_string(&tool_result(vec![ToolContent::Resource {
            resource: ResourceContent {
                uri: "file:///bin.dat".to_string(),
                mime_type: None,
                text: None,
                blob: None,
            },
        }]));
        assert_eq!(output, "[Resource: file:///bin.dat]");
    }

    #[test]
    fn test_tool_result_to_string_mixed_content() {
        let output = tool_result_to_string(&tool_result(vec![
            text("Text content"),
            ToolContent::Image {
                data: "base64".to_string(),
                mime_type: "image/jpeg".to_string(),
            },
            ToolContent::Resource {
                resource: ResourceContent {
                    uri: "file:///doc.md".to_string(),
                    mime_type: Some("text/markdown".to_string()),
                    text: Some("Doc content".to_string()),
                    blob: None,
                },
            },
        ]));
        assert_eq!(output, "Text content\n[Image: image/jpeg]\nDoc content");
    }

    // ----------------------------------------------------------- manager

    #[tokio::test]
    async fn test_mcp_manager_new() {
        let manager = McpManager::new();
        assert!(manager.get_status().await.is_empty());
    }

    #[tokio::test]
    async fn test_mcp_manager_default() {
        let manager = McpManager::default();
        assert!(manager.get_status().await.is_empty());
    }

    #[tokio::test]
    async fn test_list_connected_empty() {
        let manager = McpManager::new();
        assert!(manager.list_connected().await.is_empty());
    }

    #[tokio::test]
    async fn test_is_connected_false_for_unknown_server() {
        let manager = McpManager::new();
        assert!(!manager.is_connected("unknown_server").await);
    }

    #[tokio::test]
    async fn test_get_client_none_for_unknown_server() {
        let manager = McpManager::new();
        assert!(manager.get_client("unknown_server").await.is_none());
    }

    #[tokio::test]
    async fn test_get_server_tools_empty_for_unknown_server() {
        let manager = McpManager::new();
        assert!(manager.get_server_tools("unknown_server").await.is_empty());
    }

    #[tokio::test]
    async fn test_get_all_tools_empty_manager() {
        let manager = McpManager::new();
        assert!(manager.get_all_tools().await.is_empty());
    }

    #[tokio::test]
    async fn test_get_status_registered_server() {
        let manager = McpManager::new();
        manager
            .register_server(stdio_config("test_server", "echo", true))
            .await;

        let status = manager.get_status().await;
        let s = &status["test_server"];
        assert_eq!(s.name, "test_server");
        assert!(!s.connected);
        assert!(s.enabled);
        assert_eq!(s.tool_count, 0);
        assert!(s.error.is_none());
    }

    #[tokio::test]
    async fn test_get_status_disabled_server() {
        let manager = McpManager::new();
        manager
            .register_server(stdio_config("disabled_server", "echo", false))
            .await;

        let status = manager.get_status().await;
        assert!(status.contains_key("disabled_server"));
        assert!(!status["disabled_server"].enabled);
    }

    #[tokio::test]
    async fn test_connect_unknown_server_fails() {
        let manager = McpManager::new();
        let err = manager.connect("ghost").await.unwrap_err();
        assert!(err.to_string().contains("not found"));
        // Errors for unregistered names are not surfaced in the status map.
        assert!(!manager.get_status().await.contains_key("ghost"));
    }

    #[tokio::test]
    async fn test_connect_disabled_server_records_error() {
        let manager = McpManager::new();
        manager.register_server(stdio_config("off", "echo", false)).await;

        assert!(manager.connect("off").await.is_err());

        let status = manager.get_status().await;
        let error = status["off"].error.as_deref().expect("error should be recorded");
        assert!(error.contains("disabled"));
        assert!(!status["off"].connected);
    }

    #[tokio::test]
    async fn test_connect_error_recorded_in_status() {
        let manager = McpManager::new();
        // `true` exits immediately — not a valid MCP server.
        let config = McpServerConfig {
            tool_timeout_secs: 5,
            ..stdio_config("bad-server", "true", true)
        };
        manager.register_server(config).await;

        // connect() will fail; the error must be stored.
        let _ = manager.connect("bad-server").await;

        let status = manager.get_status().await;
        let s = &status["bad-server"];
        assert!(!s.connected, "server should not be connected");
        assert!(
            s.error.is_some(),
            "error should be recorded after failed connect"
        );
    }

    #[tokio::test]
    async fn test_call_tool_requires_connected_server() {
        let manager = McpManager::new();

        let err = manager
            .call_tool("mcp__missing__tool", None)
            .await
            .err()
            .expect("call on an unconnected server must fail");
        assert!(err.to_string().contains("not connected"));

        assert!(manager.call_tool("bogus", None).await.is_err());
    }

    #[tokio::test]
    async fn test_disconnect_unknown_server_is_ok() {
        let manager = McpManager::new();
        assert!(manager.disconnect("ghost").await.is_ok());
    }

    #[tokio::test]
    async fn test_get_all_tools_returns_server_name_not_full_name() {
        // get_all_tools() must return (server_name, tool) — not (mcp__server__tool, tool) —
        // so that create_mcp_tools() can build the full name without a double prefix.
        // With no connected servers the result is empty; real server injection is
        // covered by the #[ignore] tests in integration_mcp.rs.
        let manager = McpManager::new();
        for (name, _tool) in &manager.get_all_tools().await {
            assert!(
                !name.starts_with("mcp__"),
                "get_all_tools() must return server names, not prefixed full names; got '{name}'"
            );
        }
    }

    // --------------------------------------------------------------- oauth

    #[tokio::test]
    async fn test_resolve_auth_header_none_when_no_oauth() {
        let result = McpManager::resolve_auth_header(None).await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_resolve_auth_header_uses_static_token() {
        let oauth = OAuthConfig {
            auth_url: "https://example.com/auth".to_string(),
            token_url: "https://example.com/token".to_string(),
            client_id: "client".to_string(),
            client_secret: None,
            scopes: vec![],
            redirect_uri: "http://localhost/cb".to_string(),
            access_token: Some("my-static-token".to_string()),
        };
        let (key, value) = McpManager::resolve_auth_header(Some(&oauth))
            .await
            .unwrap()
            .expect("static token must produce a header");
        assert_eq!(key, "Authorization");
        assert_eq!(value, "Bearer my-static-token");
    }

    #[tokio::test]
    async fn test_resolve_auth_header_client_credentials_fails_gracefully() {
        // No static token + unreachable token_url → must return an error, not panic.
        let oauth = OAuthConfig {
            auth_url: "https://127.0.0.1:1/auth".to_string(),
            token_url: "http://127.0.0.1:1/token".to_string(),
            client_id: "client".to_string(),
            client_secret: Some("secret".to_string()),
            scopes: vec!["read".to_string()],
            redirect_uri: "http://localhost/cb".to_string(),
            access_token: None,
        };
        assert!(McpManager::resolve_auth_header(Some(&oauth)).await.is_err());
    }
}