1use crate::mcp::client::McpClient;
6use crate::mcp::oauth;
7use crate::mcp::protocol::{
8 CallToolResult, McpServerConfig, McpTool, McpTransportConfig, OAuthConfig, ToolContent,
9};
10use crate::mcp::transport::http_sse::HttpSseTransport;
11use crate::mcp::transport::stdio::StdioTransport;
12use crate::mcp::transport::streamable_http::StreamableHttpTransport;
13use crate::mcp::transport::McpTransport;
14use anyhow::{anyhow, Result};
15use std::collections::HashMap;
16use std::sync::Arc;
17use tokio::sync::RwLock;
18
/// Status snapshot for one registered MCP server, as built by `McpManager::get_status`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct McpServerStatus {
    /// Name the server was registered under.
    pub name: String,
    /// `true` only when the manager holds a client for this server and that client
    /// reports itself connected.
    pub connected: bool,
    /// The `enabled` flag copied from the server's configuration.
    pub enabled: bool,
    /// Number of tools in the client's cache; 0 when no client is held.
    pub tool_count: usize,
    /// Message from the most recent failed `connect`; cleared by a successful one.
    pub error: Option<String>,
}
28
/// Tracks MCP server configurations, their live clients, and the last connection
/// error for each server. Every map is keyed by server name.
pub struct McpManager {
    /// Live clients; inserted by `connect`, removed by `disconnect`.
    clients: RwLock<HashMap<String, Arc<McpClient>>>,
    /// Registered server configurations; written by `register_server`.
    configs: RwLock<HashMap<String, McpServerConfig>>,
    /// Error message from the most recent failed `connect` for each server.
    connect_errors: RwLock<HashMap<String, String>>,
}
38
39impl McpManager {
40 pub fn new() -> Self {
42 Self {
43 clients: RwLock::new(HashMap::new()),
44 configs: RwLock::new(HashMap::new()),
45 connect_errors: RwLock::new(HashMap::new()),
46 }
47 }
48
49 pub async fn register_server(&self, config: McpServerConfig) {
51 let name = config.name.clone();
52 let mut configs = self.configs.write().await;
53 configs.insert(name.clone(), config);
54 tracing::info!("Registered MCP server: {}", name);
55 }
56
57 pub async fn connect(&self, name: &str) -> Result<()> {
62 let result = self.do_connect(name).await;
63 match &result {
64 Ok(_) => {
65 self.connect_errors.write().await.remove(name);
66 }
67 Err(e) => {
68 self.connect_errors
69 .write()
70 .await
71 .insert(name.to_string(), e.to_string());
72 }
73 }
74 result
75 }
76
77 async fn do_connect(&self, name: &str) -> Result<()> {
78 let config = {
80 let configs = self.configs.read().await;
81 configs
82 .get(name)
83 .cloned()
84 .ok_or_else(|| anyhow!("MCP server not found: {}", name))?
85 };
86
87 if !config.enabled {
88 return Err(anyhow!("MCP server is disabled: {}", name));
89 }
90
91 let auth_header = Self::resolve_auth_header(config.oauth.as_ref()).await?;
93
94 let transport: Arc<dyn McpTransport> = match &config.transport {
96 McpTransportConfig::Stdio { command, args } => Arc::new(
97 StdioTransport::spawn_with_timeout(
98 command,
99 args,
100 &config.env,
101 config.tool_timeout_secs,
102 )
103 .await?,
104 ),
105 McpTransportConfig::Http { url, headers } => {
106 let mut merged = headers.clone();
107 if let Some((k, v)) = &auth_header {
108 merged.insert(k.clone(), v.clone());
109 }
110 Arc::new(
111 HttpSseTransport::connect_with_timeout(url, merged, config.tool_timeout_secs)
112 .await?,
113 )
114 }
115 McpTransportConfig::StreamableHttp { url, headers } => {
116 let mut merged = headers.clone();
117 if let Some((k, v)) = &auth_header {
118 merged.insert(k.clone(), v.clone());
119 }
120 Arc::new(
121 StreamableHttpTransport::connect_with_timeout(
122 url,
123 merged,
124 config.tool_timeout_secs,
125 )
126 .await?,
127 )
128 }
129 };
130
131 let client = Arc::new(McpClient::new(name.to_string(), transport));
133
134 client.initialize().await?;
136
137 let tools = client.list_tools().await?;
139 tracing::info!("MCP server '{}' connected with {} tools", name, tools.len());
140
141 {
143 let mut clients = self.clients.write().await;
144 clients.insert(name.to_string(), client);
145 }
146
147 Ok(())
148 }
149
150 pub async fn disconnect(&self, name: &str) -> Result<()> {
152 let client = {
153 let mut clients = self.clients.write().await;
154 clients.remove(name)
155 };
156
157 if let Some(client) = client {
158 client.close().await?;
159 tracing::info!("MCP server '{}' disconnected", name);
160 }
161
162 Ok(())
163 }
164
165 pub async fn all_configs(&self) -> Vec<McpServerConfig> {
167 self.configs.read().await.values().cloned().collect()
168 }
169
170 pub async fn get_all_tools(&self) -> Vec<(String, McpTool)> {
175 let clients = self.clients.read().await;
176 let mut all_tools = Vec::new();
177
178 for (server_name, client) in clients.iter() {
179 let tools = client.get_cached_tools().await;
180 for tool in tools {
181 all_tools.push((server_name.clone(), tool));
182 }
183 }
184
185 all_tools
186 }
187
188 pub async fn call_tool(
192 &self,
193 full_name: &str,
194 arguments: Option<serde_json::Value>,
195 ) -> Result<CallToolResult> {
196 let (server_name, tool_name) = Self::parse_tool_name(full_name)?;
198
199 let client = {
201 let clients = self.clients.read().await;
202 clients
203 .get(&server_name)
204 .cloned()
205 .ok_or_else(|| anyhow!("MCP server not connected: {}", server_name))?
206 };
207
208 client.call_tool(&tool_name, arguments).await
210 }
211
212 async fn resolve_auth_header(oauth: Option<&OAuthConfig>) -> Result<Option<(String, String)>> {
218 let Some(oauth) = oauth else {
219 return Ok(None);
220 };
221
222 let token = if let Some(static_token) = &oauth.access_token {
223 static_token.clone()
224 } else {
225 oauth::exchange_client_credentials(
226 &oauth.token_url,
227 &oauth.client_id,
228 oauth.client_secret.as_deref().unwrap_or(""),
229 &oauth.scopes,
230 )
231 .await?
232 };
233
234 Ok(Some((
235 "Authorization".to_string(),
236 format!("Bearer {}", token),
237 )))
238 }
239
240 fn parse_tool_name(full_name: &str) -> Result<(String, String)> {
242 if !full_name.starts_with("mcp__") {
244 return Err(anyhow!("Invalid MCP tool name: {}", full_name));
245 }
246
247 let rest = &full_name[5..]; let parts: Vec<&str> = rest.splitn(2, "__").collect();
249
250 if parts.len() != 2 {
251 return Err(anyhow!("Invalid MCP tool name format: {}", full_name));
252 }
253
254 Ok((parts[0].to_string(), parts[1].to_string()))
255 }
256
257 pub async fn get_status(&self) -> HashMap<String, McpServerStatus> {
259 let configs = self.configs.read().await;
260 let clients = self.clients.read().await;
261 let errors = self.connect_errors.read().await;
262 let mut status = HashMap::new();
263
264 for (name, config) in configs.iter() {
265 let client = clients.get(name);
266 let (connected, tool_count) = if let Some(c) = client {
267 (c.is_connected(), c.get_cached_tools().await.len())
268 } else {
269 (false, 0)
270 };
271
272 status.insert(
273 name.clone(),
274 McpServerStatus {
275 name: name.clone(),
276 connected,
277 enabled: config.enabled,
278 tool_count,
279 error: errors.get(name).cloned(),
280 },
281 );
282 }
283
284 status
285 }
286
287 pub async fn get_client(&self, name: &str) -> Option<Arc<McpClient>> {
289 let clients = self.clients.read().await;
290 clients.get(name).cloned()
291 }
292
293 pub async fn is_connected(&self, name: &str) -> bool {
295 let clients = self.clients.read().await;
296 clients.get(name).map(|c| c.is_connected()).unwrap_or(false)
297 }
298
299 pub async fn list_connected(&self) -> Vec<String> {
301 let clients = self.clients.read().await;
302 clients.keys().cloned().collect()
303 }
304
305 pub async fn get_server_tools(&self, name: &str) -> Vec<McpTool> {
307 let clients = self.clients.read().await;
308 match clients.get(name) {
309 Some(client) => client.get_cached_tools().await,
310 None => Vec::new(),
311 }
312 }
313}
314
315impl Default for McpManager {
316 fn default() -> Self {
317 Self::new()
318 }
319}
320
321pub fn tool_result_to_string(result: &CallToolResult) -> String {
323 let mut output = String::new();
324
325 for content in &result.content {
326 match content {
327 ToolContent::Text { text } => {
328 output.push_str(text);
329 output.push('\n');
330 }
331 ToolContent::Image { data: _, mime_type } => {
332 output.push_str(&format!("[Image: {}]\n", mime_type));
333 }
334 ToolContent::Resource { resource } => {
335 if let Some(text) = &resource.text {
336 output.push_str(text);
337 output.push('\n');
338 } else {
339 output.push_str(&format!("[Resource: {}]\n", resource.uri));
340 }
341 }
342 }
343 }
344
345 output.trim_end().to_string()
346}
347
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mcp::protocol::ResourceContent;

    /// Stdio server config with an empty env, no OAuth and a 60-second tool timeout.
    fn stdio_config(name: &str, command: &str, enabled: bool) -> McpServerConfig {
        McpServerConfig {
            name: name.to_string(),
            transport: McpTransportConfig::Stdio {
                command: command.to_string(),
                args: vec![],
            },
            enabled,
            env: HashMap::new(),
            oauth: None,
            tool_timeout_secs: 60,
        }
    }

    /// A `CallToolResult` carrying `content` and not flagged as an error.
    fn ok_result(content: Vec<ToolContent>) -> CallToolResult {
        CallToolResult {
            content,
            is_error: false,
        }
    }

    // ---- parse_tool_name ----

    #[test]
    fn test_parse_tool_name() {
        let (server, tool) = McpManager::parse_tool_name("mcp__github__create_issue").unwrap();
        assert_eq!(server, "github");
        assert_eq!(tool, "create_issue");
    }

    #[test]
    fn test_parse_tool_name_with_underscores() {
        let (server, tool) = McpManager::parse_tool_name("mcp__my_server__my_tool_name").unwrap();
        assert_eq!(server, "my_server");
        assert_eq!(tool, "my_tool_name");
    }

    #[test]
    fn test_parse_tool_name_invalid() {
        assert!(McpManager::parse_tool_name("invalid_name").is_err());
        assert!(McpManager::parse_tool_name("mcp__nodelimiter").is_err());
    }

    #[test]
    fn test_parse_tool_name_simple() {
        let (server, tool) = McpManager::parse_tool_name("mcp__server__tool").unwrap();
        assert_eq!(server, "server");
        assert_eq!(tool, "tool");
    }

    #[test]
    fn test_parse_tool_name_splits_at_first_delimiter() {
        // Everything after the first `__` belongs to the tool name.
        let (server, tool) = McpManager::parse_tool_name("mcp__server__tool__extra").unwrap();
        assert_eq!(server, "server");
        assert_eq!(tool, "tool__extra");
    }

    #[test]
    fn test_parse_tool_name_missing_prefix() {
        assert!(McpManager::parse_tool_name("server__tool").is_err());
    }

    #[test]
    fn test_parse_tool_name_only_prefix() {
        assert!(McpManager::parse_tool_name("mcp__").is_err());
    }

    #[test]
    fn test_parse_tool_name_empty_string() {
        assert!(McpManager::parse_tool_name("").is_err());
    }

    // ---- tool_result_to_string ----

    #[test]
    fn test_tool_result_to_string() {
        let result = ok_result(vec![
            ToolContent::Text {
                text: "Line 1".to_string(),
            },
            ToolContent::Text {
                text: "Line 2".to_string(),
            },
        ]);
        let output = tool_result_to_string(&result);
        assert!(output.contains("Line 1"));
        assert!(output.contains("Line 2"));
    }

    #[test]
    fn test_tool_result_to_string_single_text() {
        let result = ok_result(vec![ToolContent::Text {
            text: "Hello World".to_string(),
        }]);
        assert_eq!(tool_result_to_string(&result), "Hello World");
    }

    #[test]
    fn test_tool_result_to_string_multiple_text() {
        let result = ok_result(vec![
            ToolContent::Text {
                text: "First line".to_string(),
            },
            ToolContent::Text {
                text: "Second line".to_string(),
            },
        ]);
        assert_eq!(tool_result_to_string(&result), "First line\nSecond line");
    }

    #[test]
    fn test_tool_result_to_string_empty() {
        assert_eq!(tool_result_to_string(&ok_result(vec![])), "");
    }

    #[test]
    fn test_tool_result_to_string_trims_trailing_whitespace() {
        let result = ok_result(vec![ToolContent::Text {
            text: "padded  \n\n".to_string(),
        }]);
        assert_eq!(tool_result_to_string(&result), "padded");
    }

    #[test]
    fn test_tool_result_to_string_image() {
        let result = ok_result(vec![ToolContent::Image {
            data: "base64data".to_string(),
            mime_type: "image/png".to_string(),
        }]);
        assert_eq!(tool_result_to_string(&result), "[Image: image/png]");
    }

    #[test]
    fn test_tool_result_to_string_resource() {
        let result = ok_result(vec![ToolContent::Resource {
            resource: ResourceContent {
                uri: "file:///test.txt".to_string(),
                mime_type: Some("text/plain".to_string()),
                text: Some("Resource content".to_string()),
                blob: None,
            },
        }]);
        assert_eq!(tool_result_to_string(&result), "Resource content");
    }

    #[test]
    fn test_tool_result_to_string_resource_without_text_uses_uri() {
        let result = ok_result(vec![ToolContent::Resource {
            resource: ResourceContent {
                uri: "file:///bin.dat".to_string(),
                mime_type: None,
                text: None,
                blob: None,
            },
        }]);
        assert_eq!(tool_result_to_string(&result), "[Resource: file:///bin.dat]");
    }

    #[test]
    fn test_tool_result_to_string_mixed_content() {
        let result = ok_result(vec![
            ToolContent::Text {
                text: "Text content".to_string(),
            },
            ToolContent::Image {
                data: "base64".to_string(),
                mime_type: "image/jpeg".to_string(),
            },
            ToolContent::Resource {
                resource: ResourceContent {
                    uri: "file:///doc.md".to_string(),
                    mime_type: Some("text/markdown".to_string()),
                    text: Some("Doc content".to_string()),
                    blob: None,
                },
            },
        ]);
        assert_eq!(
            tool_result_to_string(&result),
            "Text content\n[Image: image/jpeg]\nDoc content"
        );
    }

    // ---- manager state ----

    #[tokio::test]
    async fn test_mcp_manager_new() {
        let manager = McpManager::new();
        assert!(manager.get_status().await.is_empty());
    }

    #[tokio::test]
    async fn test_mcp_manager_default() {
        let manager = McpManager::default();
        assert!(manager.get_status().await.is_empty());
    }

    #[tokio::test]
    async fn test_mcp_manager_register_server() {
        let manager = McpManager::new();
        manager.register_server(stdio_config("test", "echo", true)).await;

        let status = manager.get_status().await;
        assert!(status.contains_key("test"));
        assert!(!status["test"].connected);
    }

    #[tokio::test]
    async fn test_get_status_registered_server() {
        let manager = McpManager::new();
        manager
            .register_server(stdio_config("test_server", "echo", true))
            .await;

        let status = manager.get_status().await;
        assert!(status.contains_key("test_server"));
        assert!(!status["test_server"].connected);
        assert!(status["test_server"].enabled);
        assert_eq!(status["test_server"].tool_count, 0);
        assert!(status["test_server"].error.is_none());
    }

    #[tokio::test]
    async fn test_get_status_disabled_server() {
        let manager = McpManager::new();
        manager
            .register_server(stdio_config("disabled_server", "echo", false))
            .await;

        let status = manager.get_status().await;
        assert!(status.contains_key("disabled_server"));
        assert!(!status["disabled_server"].enabled);
    }

    #[tokio::test]
    async fn test_list_connected_empty() {
        let manager = McpManager::new();
        assert!(manager.list_connected().await.is_empty());
    }

    #[tokio::test]
    async fn test_is_connected_false_for_unknown_server() {
        let manager = McpManager::new();
        assert!(!manager.is_connected("unknown_server").await);
    }

    #[tokio::test]
    async fn test_get_client_none_for_unknown_server() {
        let manager = McpManager::new();
        assert!(manager.get_client("unknown_server").await.is_none());
    }

    #[tokio::test]
    async fn test_get_server_tools_empty_for_unknown_server() {
        let manager = McpManager::new();
        assert!(manager.get_server_tools("unknown_server").await.is_empty());
    }

    #[tokio::test]
    async fn test_get_all_tools_empty_manager() {
        let manager = McpManager::new();
        assert!(manager.get_all_tools().await.is_empty());
    }

    #[tokio::test]
    async fn test_get_all_tools_returns_server_name_not_full_name() {
        // NOTE: with no connected server the loop body never runs, so this only guards
        // the return shape once a fixture server is wired in.
        let manager = McpManager::new();
        let tools = manager.get_all_tools().await;
        for (name, _tool) in &tools {
            assert!(
                !name.starts_with("mcp__"),
                "get_all_tools() must return server names, not prefixed full names; got '{name}'"
            );
        }
    }

    // ---- connect / disconnect / call_tool ----

    #[tokio::test]
    async fn test_connect_unknown_server_fails() {
        let manager = McpManager::new();
        let err = manager.connect("ghost").await.unwrap_err();
        assert!(err.to_string().contains("not found"), "unexpected error: {err}");
    }

    #[tokio::test]
    async fn test_connect_disabled_server_records_error() {
        let manager = McpManager::new();
        manager.register_server(stdio_config("off", "echo", false)).await;

        let err = manager.connect("off").await.unwrap_err();
        assert!(err.to_string().contains("disabled"), "unexpected error: {err}");

        let status = manager.get_status().await;
        assert!(!status["off"].connected);
        let recorded = status["off"]
            .error
            .clone()
            .expect("error should be recorded after failed connect");
        assert!(recorded.contains("disabled"));
    }

    #[tokio::test]
    async fn test_connect_error_recorded_in_status() {
        let manager = McpManager::new();
        // `true` exits immediately without speaking MCP, so the connect cannot succeed.
        let mut config = stdio_config("bad-server", "true", true);
        config.tool_timeout_secs = 5;

        manager.register_server(config).await;
        let _ = manager.connect("bad-server").await;

        let status = manager.get_status().await;
        let s = &status["bad-server"];
        assert!(!s.connected, "server should not be connected");
        assert!(
            s.error.is_some(),
            "error should be recorded after failed connect"
        );
    }

    #[tokio::test]
    async fn test_disconnect_unknown_server_is_ok() {
        let manager = McpManager::new();
        assert!(manager.disconnect("ghost").await.is_ok());
    }

    #[tokio::test]
    async fn test_call_tool_rejects_invalid_name() {
        let manager = McpManager::new();
        let err = manager
            .call_tool("not_an_mcp_tool", None)
            .await
            .err()
            .expect("invalid tool name must be rejected");
        assert!(err.to_string().contains("Invalid MCP tool name"));
    }

    #[tokio::test]
    async fn test_call_tool_unconnected_server_fails() {
        let manager = McpManager::new();
        let err = manager
            .call_tool("mcp__ghost__tool", None)
            .await
            .err()
            .expect("call to an unconnected server must fail");
        assert!(err.to_string().contains("not connected: ghost"));
    }

    // ---- resolve_auth_header ----

    #[tokio::test]
    async fn test_resolve_auth_header_none_when_no_oauth() {
        let result = McpManager::resolve_auth_header(None).await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_resolve_auth_header_uses_static_token() {
        let oauth = OAuthConfig {
            auth_url: "https://example.com/auth".to_string(),
            token_url: "https://example.com/token".to_string(),
            client_id: "client".to_string(),
            client_secret: None,
            scopes: vec![],
            redirect_uri: "http://localhost/cb".to_string(),
            access_token: Some("my-static-token".to_string()),
        };
        let (key, value) = McpManager::resolve_auth_header(Some(&oauth))
            .await
            .unwrap()
            .expect("a static token must produce a header");
        assert_eq!(key, "Authorization");
        assert_eq!(value, "Bearer my-static-token");
    }

    #[tokio::test]
    async fn test_resolve_auth_header_client_credentials_fails_gracefully() {
        // Port 1 on loopback is expected to refuse, so the exchange must error, not panic.
        let oauth = OAuthConfig {
            auth_url: "http://127.0.0.1:1/auth".to_string(),
            token_url: "http://127.0.0.1:1/token".to_string(),
            client_id: "client".to_string(),
            client_secret: Some("secret".to_string()),
            scopes: vec!["read".to_string()],
            redirect_uri: "http://localhost/cb".to_string(),
            access_token: None,
        };
        assert!(McpManager::resolve_auth_header(Some(&oauth)).await.is_err());
    }
}