1use crate::mcp::client::McpClient;
6use crate::mcp::oauth;
7use crate::mcp::protocol::{
8 CallToolResult, McpServerConfig, McpTool, McpTransportConfig, OAuthConfig, ToolContent,
9};
10use crate::mcp::transport::http_sse::HttpSseTransport;
11use crate::mcp::transport::stdio::StdioTransport;
12use crate::mcp::transport::streamable_http::StreamableHttpTransport;
13use crate::mcp::transport::McpTransport;
14use anyhow::{anyhow, Result};
15use std::collections::HashMap;
16use std::sync::Arc;
17use tokio::sync::RwLock;
18
19#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
21pub struct McpServerStatus {
22 pub name: String,
23 pub connected: bool,
24 pub enabled: bool,
25 pub tool_count: usize,
26 pub error: Option<String>,
27}
28
29pub struct McpManager {
31 clients: RwLock<HashMap<String, Arc<McpClient>>>,
33 configs: RwLock<HashMap<String, McpServerConfig>>,
35 connect_errors: RwLock<HashMap<String, String>>,
37 last_used_at_ms: RwLock<HashMap<String, u64>>,
43}
44
45impl McpManager {
46 pub fn new() -> Self {
48 Self {
49 clients: RwLock::new(HashMap::new()),
50 configs: RwLock::new(HashMap::new()),
51 connect_errors: RwLock::new(HashMap::new()),
52 last_used_at_ms: RwLock::new(HashMap::new()),
53 }
54 }
55
56 pub async fn register_server(&self, config: McpServerConfig) {
58 let name = config.name.clone();
59 let mut configs = self.configs.write().await;
60 configs.insert(name.clone(), config);
61 tracing::info!("Registered MCP server: {}", name);
62 }
63
64 pub async fn connect(&self, name: &str) -> Result<()> {
69 let result = self.do_connect(name).await;
70 match &result {
71 Ok(_) => {
72 self.connect_errors.write().await.remove(name);
73 }
74 Err(e) => {
75 self.connect_errors
76 .write()
77 .await
78 .insert(name.to_string(), e.to_string());
79 }
80 }
81 result
82 }
83
84 async fn do_connect(&self, name: &str) -> Result<()> {
85 let config = {
87 let configs = self.configs.read().await;
88 configs
89 .get(name)
90 .cloned()
91 .ok_or_else(|| anyhow!("MCP server not found: {}", name))?
92 };
93
94 if !config.enabled {
95 return Err(anyhow!("MCP server is disabled: {}", name));
96 }
97
98 let auth_header = Self::resolve_auth_header(config.oauth.as_ref()).await?;
100
101 let transport: Arc<dyn McpTransport> = match &config.transport {
103 McpTransportConfig::Stdio { command, args } => Arc::new(
104 StdioTransport::spawn_with_timeout(
105 command,
106 args,
107 &config.env,
108 config.tool_timeout_secs,
109 )
110 .await?,
111 ),
112 McpTransportConfig::Http { url, headers } => {
113 let mut merged = headers.clone();
114 if let Some((k, v)) = &auth_header {
115 merged.insert(k.clone(), v.clone());
116 }
117 Arc::new(
118 HttpSseTransport::connect_with_timeout(url, merged, config.tool_timeout_secs)
119 .await?,
120 )
121 }
122 McpTransportConfig::StreamableHttp { url, headers } => {
123 let mut merged = headers.clone();
124 if let Some((k, v)) = &auth_header {
125 merged.insert(k.clone(), v.clone());
126 }
127 Arc::new(
128 StreamableHttpTransport::connect_with_timeout(
129 url,
130 merged,
131 config.tool_timeout_secs,
132 )
133 .await?,
134 )
135 }
136 };
137
138 let client = Arc::new(McpClient::new(name.to_string(), transport));
140
141 client.initialize().await?;
143
144 let tools = client.list_tools().await?;
146 tracing::info!("MCP server '{}' connected with {} tools", name, tools.len());
147
148 {
151 let mut clients = self.clients.write().await;
152 clients.insert(name.to_string(), client);
153 }
154 self.last_used_at_ms
155 .write()
156 .await
157 .insert(name.to_string(), now_epoch_ms());
158
159 Ok(())
160 }
161
162 pub async fn disconnect(&self, name: &str) -> Result<()> {
164 let client = {
165 let mut clients = self.clients.write().await;
166 clients.remove(name)
167 };
168 self.last_used_at_ms.write().await.remove(name);
169
170 if let Some(client) = client {
171 client.close().await?;
172 tracing::info!("MCP server '{}' disconnected", name);
173 }
174
175 Ok(())
176 }
177
178 pub async fn last_used_at_ms(&self, name: &str) -> Option<u64> {
181 self.last_used_at_ms.read().await.get(name).copied()
182 }
183
184 pub async fn touch(&self, name: &str) {
190 self.last_used_at_ms
191 .write()
192 .await
193 .insert(name.to_string(), now_epoch_ms());
194 }
195
196 pub async fn disconnect_idle(&self, idle_threshold_ms: u64) -> Vec<String> {
213 let cutoff = now_epoch_ms().saturating_sub(idle_threshold_ms);
214 let candidates: Vec<String> = {
216 let clients = self.clients.read().await;
217 let last_used = self.last_used_at_ms.read().await;
218 clients
219 .keys()
220 .filter(|name| match last_used.get(*name) {
221 Some(ts) => *ts < cutoff,
222 None => true,
225 })
226 .cloned()
227 .collect()
228 };
229 let mut disconnected = Vec::with_capacity(candidates.len());
230 for name in candidates {
231 match self.disconnect(&name).await {
232 Ok(()) => disconnected.push(name),
233 Err(e) => tracing::warn!(
234 server = %name,
235 error = %e,
236 "MCP idle disconnect failed; entry already removed from registry"
237 ),
238 }
239 }
240 {
247 let clients = self.clients.read().await;
248 self.last_used_at_ms
249 .write()
250 .await
251 .retain(|name, _| clients.contains_key(name));
252 }
253 disconnected
254 }
255
256 pub async fn all_configs(&self) -> Vec<McpServerConfig> {
258 self.configs.read().await.values().cloned().collect()
259 }
260
261 pub async fn get_all_tools(&self) -> Vec<(String, McpTool)> {
266 let clients = self.clients.read().await;
267 let mut all_tools = Vec::new();
268
269 for (server_name, client) in clients.iter() {
270 let tools = client.get_cached_tools().await;
271 for tool in tools {
272 all_tools.push((server_name.clone(), tool));
273 }
274 }
275
276 all_tools
277 }
278
279 pub async fn call_tool(
283 &self,
284 full_name: &str,
285 arguments: Option<serde_json::Value>,
286 ) -> Result<CallToolResult> {
287 let (server_name, tool_name) = Self::parse_tool_name(full_name)?;
289
290 let client = {
292 let clients = self.clients.read().await;
293 clients
294 .get(&server_name)
295 .cloned()
296 .ok_or_else(|| anyhow!("MCP server not connected: {}", server_name))?
297 };
298
299 self.last_used_at_ms
302 .write()
303 .await
304 .insert(server_name.clone(), now_epoch_ms());
305
306 client.call_tool(&tool_name, arguments).await
308 }
309
310 async fn resolve_auth_header(oauth: Option<&OAuthConfig>) -> Result<Option<(String, String)>> {
316 let Some(oauth) = oauth else {
317 return Ok(None);
318 };
319
320 let token = if let Some(static_token) = &oauth.access_token {
321 static_token.clone()
322 } else {
323 oauth::exchange_client_credentials(
324 &oauth.token_url,
325 &oauth.client_id,
326 oauth.client_secret.as_deref().unwrap_or(""),
327 &oauth.scopes,
328 )
329 .await?
330 };
331
332 Ok(Some((
333 "Authorization".to_string(),
334 format!("Bearer {}", token),
335 )))
336 }
337
338 fn parse_tool_name(full_name: &str) -> Result<(String, String)> {
340 if !full_name.starts_with("mcp__") {
342 return Err(anyhow!("Invalid MCP tool name: {}", full_name));
343 }
344
345 let rest = &full_name[5..]; let parts: Vec<&str> = rest.splitn(2, "__").collect();
347
348 if parts.len() != 2 {
349 return Err(anyhow!("Invalid MCP tool name format: {}", full_name));
350 }
351
352 Ok((parts[0].to_string(), parts[1].to_string()))
353 }
354
355 pub async fn get_status(&self) -> HashMap<String, McpServerStatus> {
357 let configs = self.configs.read().await;
358 let clients = self.clients.read().await;
359 let errors = self.connect_errors.read().await;
360 let mut status = HashMap::new();
361
362 for (name, config) in configs.iter() {
363 let client = clients.get(name);
364 let (connected, tool_count) = if let Some(c) = client {
365 (c.is_connected(), c.get_cached_tools().await.len())
366 } else {
367 (false, 0)
368 };
369
370 status.insert(
371 name.clone(),
372 McpServerStatus {
373 name: name.clone(),
374 connected,
375 enabled: config.enabled,
376 tool_count,
377 error: errors.get(name).cloned(),
378 },
379 );
380 }
381
382 status
383 }
384
385 pub async fn get_client(&self, name: &str) -> Option<Arc<McpClient>> {
387 let clients = self.clients.read().await;
388 clients.get(name).cloned()
389 }
390
391 pub async fn is_connected(&self, name: &str) -> bool {
393 let clients = self.clients.read().await;
394 clients.get(name).map(|c| c.is_connected()).unwrap_or(false)
395 }
396
397 pub async fn list_connected(&self) -> Vec<String> {
399 let clients = self.clients.read().await;
400 clients.keys().cloned().collect()
401 }
402
403 pub async fn get_server_tools(&self, name: &str) -> Vec<McpTool> {
405 let clients = self.clients.read().await;
406 match clients.get(name) {
407 Some(client) => client.get_cached_tools().await,
408 None => Vec::new(),
409 }
410 }
411}
412
413impl Default for McpManager {
414 fn default() -> Self {
415 Self::new()
416 }
417}
418
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns 0 when the system clock reports a time before the epoch. The `u128` millisecond
/// count saturates at `u64::MAX` instead of being silently truncated by an `as` cast.
fn now_epoch_ms() -> u64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|elapsed| u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX))
        .unwrap_or(0)
}
430
431pub fn tool_result_to_string(result: &CallToolResult) -> String {
433 let mut output = String::new();
434
435 for content in &result.content {
436 match content {
437 ToolContent::Text { text } => {
438 output.push_str(text);
439 output.push('\n');
440 }
441 ToolContent::Image { data: _, mime_type } => {
442 output.push_str(&format!("[Image: {}]\n", mime_type));
443 }
444 ToolContent::Resource { resource } => {
445 if let Some(text) = &resource.text {
446 output.push_str(text);
447 output.push('\n');
448 } else {
449 output.push_str(&format!("[Resource: {}]\n", resource.uri));
450 }
451 }
452 }
453 }
454
455 output.trim_end().to_string()
456}
457
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mcp::protocol::ResourceContent;

    /// Builds a stdio server config; nothing is spawned until `connect` is called.
    fn stdio_config(name: &str, command: &str, enabled: bool) -> McpServerConfig {
        McpServerConfig {
            name: name.to_string(),
            transport: McpTransportConfig::Stdio {
                command: command.to_string(),
                args: vec![],
            },
            enabled,
            env: HashMap::new(),
            oauth: None,
            tool_timeout_secs: 60,
        }
    }

    /// Wraps content items into a non-error tool result.
    fn result_of(content: Vec<ToolContent>) -> CallToolResult {
        CallToolResult {
            content,
            is_error: false,
        }
    }

    fn text(s: &str) -> ToolContent {
        ToolContent::Text {
            text: s.to_string(),
        }
    }

    // ---- parse_tool_name -------------------------------------------------------------

    #[test]
    fn test_parse_tool_name() {
        let (server, tool) = McpManager::parse_tool_name("mcp__github__create_issue").unwrap();
        assert_eq!(server, "github");
        assert_eq!(tool, "create_issue");
    }

    #[test]
    fn test_parse_tool_name_with_underscores() {
        let (server, tool) = McpManager::parse_tool_name("mcp__my_server__my_tool_name").unwrap();
        assert_eq!(server, "my_server");
        assert_eq!(tool, "my_tool_name");
    }

    #[test]
    fn test_parse_tool_name_simple() {
        let (server, tool) = McpManager::parse_tool_name("mcp__server__tool").unwrap();
        assert_eq!(server, "server");
        assert_eq!(tool, "tool");
    }

    #[test]
    fn test_parse_tool_name_splits_at_first_delimiter() {
        // Only the first `__` after the prefix separates server from tool.
        let (server, tool) = McpManager::parse_tool_name("mcp__srv__tool__part").unwrap();
        assert_eq!(server, "srv");
        assert_eq!(tool, "tool__part");
    }

    #[test]
    fn test_parse_tool_name_invalid() {
        assert!(McpManager::parse_tool_name("invalid_name").is_err());
        assert!(McpManager::parse_tool_name("mcp__nodelimiter").is_err());
    }

    #[test]
    fn test_parse_tool_name_missing_prefix() {
        assert!(McpManager::parse_tool_name("server__tool").is_err());
    }

    #[test]
    fn test_parse_tool_name_only_prefix() {
        assert!(McpManager::parse_tool_name("mcp__").is_err());
    }

    #[test]
    fn test_parse_tool_name_empty_string() {
        assert!(McpManager::parse_tool_name("").is_err());
    }

    // ---- tool_result_to_string ---------------------------------------------------------

    #[test]
    fn test_tool_result_to_string() {
        let result = result_of(vec![text("Line 1"), text("Line 2")]);
        assert_eq!(tool_result_to_string(&result), "Line 1\nLine 2");
    }

    #[test]
    fn test_tool_result_to_string_single_text() {
        let result = result_of(vec![text("Hello World")]);
        assert_eq!(tool_result_to_string(&result), "Hello World");
    }

    #[test]
    fn test_tool_result_to_string_multiple_text() {
        let result = result_of(vec![text("First line"), text("Second line")]);
        assert_eq!(tool_result_to_string(&result), "First line\nSecond line");
    }

    #[test]
    fn test_tool_result_to_string_empty() {
        assert_eq!(tool_result_to_string(&result_of(vec![])), "");
    }

    #[test]
    fn test_tool_result_to_string_image() {
        let result = result_of(vec![ToolContent::Image {
            data: "base64data".to_string(),
            mime_type: "image/png".to_string(),
        }]);
        assert_eq!(tool_result_to_string(&result), "[Image: image/png]");
    }

    #[test]
    fn test_tool_result_to_string_resource() {
        let result = result_of(vec![ToolContent::Resource {
            resource: ResourceContent {
                uri: "file:///test.txt".to_string(),
                mime_type: Some("text/plain".to_string()),
                text: Some("Resource content".to_string()),
                blob: None,
            },
        }]);
        assert_eq!(tool_result_to_string(&result), "Resource content");
    }

    #[test]
    fn test_tool_result_to_string_resource_without_text_shows_uri() {
        let result = result_of(vec![ToolContent::Resource {
            resource: ResourceContent {
                uri: "file:///bin.dat".to_string(),
                mime_type: None,
                text: None,
                blob: None,
            },
        }]);
        assert_eq!(tool_result_to_string(&result), "[Resource: file:///bin.dat]");
    }

    #[test]
    fn test_tool_result_to_string_mixed_content() {
        let result = result_of(vec![
            text("Text content"),
            ToolContent::Image {
                data: "base64".to_string(),
                mime_type: "image/jpeg".to_string(),
            },
            ToolContent::Resource {
                resource: ResourceContent {
                    uri: "file:///doc.md".to_string(),
                    mime_type: Some("text/markdown".to_string()),
                    text: Some("Doc content".to_string()),
                    blob: None,
                },
            },
        ]);
        assert_eq!(
            tool_result_to_string(&result),
            "Text content\n[Image: image/jpeg]\nDoc content"
        );
    }

    // ---- registry and status -----------------------------------------------------------

    #[tokio::test]
    async fn test_mcp_manager_new() {
        let manager = McpManager::new();
        assert!(manager.get_status().await.is_empty());
    }

    #[tokio::test]
    async fn test_mcp_manager_default() {
        let manager = McpManager::default();
        assert!(manager.get_status().await.is_empty());
    }

    #[tokio::test]
    async fn test_mcp_manager_register_server() {
        let manager = McpManager::new();
        manager.register_server(stdio_config("test", "echo", true)).await;

        let status = manager.get_status().await;
        assert!(status.contains_key("test"));
        assert!(!status["test"].connected);
    }

    #[tokio::test]
    async fn test_list_connected_empty() {
        let manager = McpManager::new();
        assert!(manager.list_connected().await.is_empty());
    }

    #[tokio::test]
    async fn test_is_connected_false_for_unknown_server() {
        let manager = McpManager::new();
        assert!(!manager.is_connected("unknown_server").await);
    }

    #[tokio::test]
    async fn test_get_client_none_for_unknown_server() {
        let manager = McpManager::new();
        assert!(manager.get_client("unknown_server").await.is_none());
    }

    #[tokio::test]
    async fn test_get_status_registered_server() {
        let manager = McpManager::new();
        manager
            .register_server(stdio_config("test_server", "echo", true))
            .await;

        let status = manager.get_status().await;
        let s = &status["test_server"];
        assert_eq!(s.name, "test_server");
        assert!(!s.connected);
        assert!(s.enabled);
        assert_eq!(s.tool_count, 0);
        assert!(s.error.is_none());
    }

    #[tokio::test]
    async fn test_get_status_disabled_server() {
        let manager = McpManager::new();
        manager
            .register_server(stdio_config("disabled_server", "echo", false))
            .await;

        let status = manager.get_status().await;
        assert!(status.contains_key("disabled_server"));
        assert!(!status["disabled_server"].enabled);
    }

    #[tokio::test]
    async fn test_get_all_tools_empty_manager() {
        let manager = McpManager::new();
        assert!(manager.get_all_tools().await.is_empty());
    }

    #[tokio::test]
    async fn test_get_all_tools_returns_server_name_not_full_name() {
        // Without a live server there are no tools, so this only guards the shape of
        // results once clients exist.
        let manager = McpManager::new();
        for (name, _tool) in &manager.get_all_tools().await {
            assert!(
                !name.starts_with("mcp__"),
                "get_all_tools() must return server names, not prefixed full names; got '{name}'"
            );
        }
    }

    // ---- auth header resolution --------------------------------------------------------

    #[tokio::test]
    async fn test_resolve_auth_header_none_when_no_oauth() {
        let result = McpManager::resolve_auth_header(None).await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_resolve_auth_header_uses_static_token() {
        let oauth = OAuthConfig {
            auth_url: "https://example.com/auth".to_string(),
            token_url: "https://example.com/token".to_string(),
            client_id: "client".to_string(),
            client_secret: None,
            scopes: vec![],
            redirect_uri: "http://localhost/cb".to_string(),
            access_token: Some("my-static-token".to_string()),
        };
        let (key, value) = McpManager::resolve_auth_header(Some(&oauth))
            .await
            .unwrap()
            .expect("static token must produce a header");
        assert_eq!(key, "Authorization");
        assert_eq!(value, "Bearer my-static-token");
    }

    #[tokio::test]
    async fn test_resolve_auth_header_client_credentials_fails_gracefully() {
        // Nothing listens on port 1, so the token exchange must fail with an error.
        let oauth = OAuthConfig {
            auth_url: "https://127.0.0.1:1/auth".to_string(),
            token_url: "http://127.0.0.1:1/token".to_string(),
            client_id: "client".to_string(),
            client_secret: Some("secret".to_string()),
            scopes: vec!["read".to_string()],
            redirect_uri: "http://localhost/cb".to_string(),
            access_token: None,
        };
        assert!(McpManager::resolve_auth_header(Some(&oauth)).await.is_err());
    }

    // ---- connect / disconnect ----------------------------------------------------------

    #[tokio::test]
    async fn test_connect_unknown_server_fails_without_status_entry() {
        let manager = McpManager::new();
        let err = manager.connect("missing").await.unwrap_err();
        assert!(err.to_string().contains("not found"), "unexpected error: {err}");
        // Status is keyed by registered configs, so an unknown name never shows up.
        assert!(!manager.get_status().await.contains_key("missing"));
    }

    #[tokio::test]
    async fn test_connect_disabled_server_records_error() {
        let manager = McpManager::new();
        manager.register_server(stdio_config("off", "echo", false)).await;

        assert!(manager.connect("off").await.is_err());

        let status = manager.get_status().await;
        let error = status["off"].error.as_deref().unwrap_or_default();
        assert!(error.contains("disabled"), "unexpected error: {error:?}");
        assert!(!manager.is_connected("off").await);
    }

    #[tokio::test]
    async fn test_connect_error_recorded_in_status() {
        let manager = McpManager::new();
        // `true` exits immediately, so the MCP handshake cannot succeed.
        let mut config = stdio_config("bad-server", "true", true);
        config.tool_timeout_secs = 5;

        manager.register_server(config).await;
        let _ = manager.connect("bad-server").await;

        let status = manager.get_status().await;
        let s = &status["bad-server"];
        assert!(!s.connected, "server should not be connected");
        assert!(
            s.error.is_some(),
            "error should be recorded after failed connect"
        );
    }

    #[tokio::test]
    async fn test_disconnect_unknown_server_is_ok() {
        let manager = McpManager::new();
        assert!(manager.disconnect("nobody").await.is_ok());
    }

    // ---- usage tracking and idle sweep -------------------------------------------------

    #[tokio::test]
    async fn touch_updates_last_used_at_ms() {
        let manager = McpManager::new();
        assert!(manager.last_used_at_ms("svc-a").await.is_none());

        manager.touch("svc-a").await;
        let t1 = manager.last_used_at_ms("svc-a").await.expect("set");
        assert!(t1 > 0);

        manager.touch("svc-a").await;
        let t2 = manager.last_used_at_ms("svc-a").await.expect("set again");
        assert!(t2 >= t1);
    }

    #[tokio::test]
    async fn disconnect_idle_without_clients_purges_orphan_timestamps() {
        let manager = McpManager::new();
        manager.touch("fresh-svc").await;
        assert!(manager.last_used_at_ms("fresh-svc").await.is_some());
        assert!(manager.last_used_at_ms("never-touched").await.is_none());

        let dropped = manager.disconnect_idle(0).await;
        assert!(
            dropped.is_empty(),
            "no clients connected -> nothing to disconnect, got {dropped:?}"
        );
        assert!(
            manager.last_used_at_ms("fresh-svc").await.is_none(),
            "orphan timestamp (touched, never connected) must be purged by disconnect_idle"
        );
    }

    #[tokio::test]
    async fn disconnect_removes_timestamp_even_without_client() {
        let manager = McpManager::new();
        manager.touch("svc").await;
        assert!(manager.last_used_at_ms("svc").await.is_some());

        let _ = manager.disconnect("svc").await;
        assert!(manager.last_used_at_ms("svc").await.is_none());
    }
}
898}