1use crate::mcp::client::McpClient;
6use crate::mcp::oauth;
7use crate::mcp::protocol::{
8 CallToolResult, McpServerConfig, McpTool, McpTransportConfig, OAuthConfig, ToolContent,
9};
10use crate::mcp::transport::http_sse::HttpSseTransport;
11use crate::mcp::transport::stdio::StdioTransport;
12use crate::mcp::transport::streamable_http::StreamableHttpTransport;
13use crate::mcp::transport::McpTransport;
14use anyhow::{anyhow, Result};
15use std::collections::HashMap;
16use std::sync::Arc;
17use tokio::sync::RwLock;
18
/// Point-in-time status of one registered MCP server, as assembled by
/// `McpManager::get_status`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct McpServerStatus {
    /// Server name (the key the config was registered under).
    pub name: String,
    /// `true` only when a client is stored for this server and that client reports itself
    /// connected.
    pub connected: bool,
    /// Copy of the `enabled` flag from the server's registered config.
    pub enabled: bool,
    /// Length of the client's cached tool list; `0` when no client is stored.
    pub tool_count: usize,
    /// Optional error description. `McpManager::get_status` currently always sets this to
    /// `None`.
    pub error: Option<String>,
}
28
/// Registry of MCP server configurations plus the live clients connected to them.
///
/// Both maps are keyed by server name and guarded by `tokio::sync::RwLock`, so every method
/// on the manager takes `&self`.
pub struct McpManager {
    /// Clients for servers that have been connected: inserted by `connect`, removed by
    /// `disconnect`.
    clients: RwLock<HashMap<String, Arc<McpClient>>>,
    /// Registered server configurations, including disabled ones.
    configs: RwLock<HashMap<String, McpServerConfig>>,
}
36
37impl McpManager {
38 pub fn new() -> Self {
40 Self {
41 clients: RwLock::new(HashMap::new()),
42 configs: RwLock::new(HashMap::new()),
43 }
44 }
45
46 pub async fn register_server(&self, config: McpServerConfig) {
48 let name = config.name.clone();
49 let mut configs = self.configs.write().await;
50 configs.insert(name.clone(), config);
51 tracing::info!("Registered MCP server: {}", name);
52 }
53
54 pub async fn connect(&self, name: &str) -> Result<()> {
56 let config = {
58 let configs = self.configs.read().await;
59 configs
60 .get(name)
61 .cloned()
62 .ok_or_else(|| anyhow!("MCP server not found: {}", name))?
63 };
64
65 if !config.enabled {
66 return Err(anyhow!("MCP server is disabled: {}", name));
67 }
68
69 let auth_header = Self::resolve_auth_header(config.oauth.as_ref()).await?;
71
72 let transport: Arc<dyn McpTransport> = match &config.transport {
74 McpTransportConfig::Stdio { command, args } => Arc::new(
75 StdioTransport::spawn_with_timeout(
76 command,
77 args,
78 &config.env,
79 config.tool_timeout_secs,
80 )
81 .await?,
82 ),
83 McpTransportConfig::Http { url, headers } => {
84 let mut merged = headers.clone();
85 if let Some((k, v)) = &auth_header {
86 merged.insert(k.clone(), v.clone());
87 }
88 Arc::new(
89 HttpSseTransport::connect_with_timeout(url, merged, config.tool_timeout_secs)
90 .await?,
91 )
92 }
93 McpTransportConfig::StreamableHttp { url, headers } => {
94 let mut merged = headers.clone();
95 if let Some((k, v)) = &auth_header {
96 merged.insert(k.clone(), v.clone());
97 }
98 Arc::new(
99 StreamableHttpTransport::connect_with_timeout(
100 url,
101 merged,
102 config.tool_timeout_secs,
103 )
104 .await?,
105 )
106 }
107 };
108
109 let client = Arc::new(McpClient::new(name.to_string(), transport));
111
112 client.initialize().await?;
114
115 let tools = client.list_tools().await?;
117 tracing::info!("MCP server '{}' connected with {} tools", name, tools.len());
118
119 {
121 let mut clients = self.clients.write().await;
122 clients.insert(name.to_string(), client);
123 }
124
125 Ok(())
126 }
127
128 pub async fn disconnect(&self, name: &str) -> Result<()> {
130 let client = {
131 let mut clients = self.clients.write().await;
132 clients.remove(name)
133 };
134
135 if let Some(client) = client {
136 client.close().await?;
137 tracing::info!("MCP server '{}' disconnected", name);
138 }
139
140 Ok(())
141 }
142
143 pub async fn all_configs(&self) -> Vec<McpServerConfig> {
145 self.configs.read().await.values().cloned().collect()
146 }
147
148 pub async fn get_all_tools(&self) -> Vec<(String, McpTool)> {
152 let clients = self.clients.read().await;
153 let mut all_tools = Vec::new();
154
155 for (server_name, client) in clients.iter() {
156 let tools = client.get_cached_tools().await;
157 for tool in tools {
158 let full_name = format!("mcp__{}__{}", server_name, tool.name);
159 all_tools.push((full_name, tool));
160 }
161 }
162
163 all_tools
164 }
165
166 pub async fn call_tool(
170 &self,
171 full_name: &str,
172 arguments: Option<serde_json::Value>,
173 ) -> Result<CallToolResult> {
174 let (server_name, tool_name) = Self::parse_tool_name(full_name)?;
176
177 let client = {
179 let clients = self.clients.read().await;
180 clients
181 .get(&server_name)
182 .cloned()
183 .ok_or_else(|| anyhow!("MCP server not connected: {}", server_name))?
184 };
185
186 client.call_tool(&tool_name, arguments).await
188 }
189
190 async fn resolve_auth_header(oauth: Option<&OAuthConfig>) -> Result<Option<(String, String)>> {
196 let Some(oauth) = oauth else {
197 return Ok(None);
198 };
199
200 let token = if let Some(static_token) = &oauth.access_token {
201 static_token.clone()
202 } else {
203 oauth::exchange_client_credentials(
204 &oauth.token_url,
205 &oauth.client_id,
206 oauth.client_secret.as_deref().unwrap_or(""),
207 &oauth.scopes,
208 )
209 .await?
210 };
211
212 Ok(Some((
213 "Authorization".to_string(),
214 format!("Bearer {}", token),
215 )))
216 }
217
218 fn parse_tool_name(full_name: &str) -> Result<(String, String)> {
220 if !full_name.starts_with("mcp__") {
222 return Err(anyhow!("Invalid MCP tool name: {}", full_name));
223 }
224
225 let rest = &full_name[5..]; let parts: Vec<&str> = rest.splitn(2, "__").collect();
227
228 if parts.len() != 2 {
229 return Err(anyhow!("Invalid MCP tool name format: {}", full_name));
230 }
231
232 Ok((parts[0].to_string(), parts[1].to_string()))
233 }
234
235 pub async fn get_status(&self) -> HashMap<String, McpServerStatus> {
237 let configs = self.configs.read().await;
238 let clients = self.clients.read().await;
239 let mut status = HashMap::new();
240
241 for (name, config) in configs.iter() {
242 let client = clients.get(name);
243 let (connected, tool_count) = if let Some(c) = client {
244 (c.is_connected(), c.get_cached_tools().await.len())
245 } else {
246 (false, 0)
247 };
248
249 status.insert(
250 name.clone(),
251 McpServerStatus {
252 name: name.clone(),
253 connected,
254 enabled: config.enabled,
255 tool_count,
256 error: None,
257 },
258 );
259 }
260
261 status
262 }
263
264 pub async fn get_client(&self, name: &str) -> Option<Arc<McpClient>> {
266 let clients = self.clients.read().await;
267 clients.get(name).cloned()
268 }
269
270 pub async fn is_connected(&self, name: &str) -> bool {
272 let clients = self.clients.read().await;
273 clients.get(name).map(|c| c.is_connected()).unwrap_or(false)
274 }
275
276 pub async fn list_connected(&self) -> Vec<String> {
278 let clients = self.clients.read().await;
279 clients.keys().cloned().collect()
280 }
281
282 pub async fn get_server_tools(&self, name: &str) -> Vec<McpTool> {
284 let clients = self.clients.read().await;
285 match clients.get(name) {
286 Some(client) => client.get_cached_tools().await,
287 None => Vec::new(),
288 }
289 }
290}
291
292impl Default for McpManager {
293 fn default() -> Self {
294 Self::new()
295 }
296}
297
298pub fn tool_result_to_string(result: &CallToolResult) -> String {
300 let mut output = String::new();
301
302 for content in &result.content {
303 match content {
304 ToolContent::Text { text } => {
305 output.push_str(text);
306 output.push('\n');
307 }
308 ToolContent::Image { data: _, mime_type } => {
309 output.push_str(&format!("[Image: {}]\n", mime_type));
310 }
311 ToolContent::Resource { resource } => {
312 if let Some(text) = &resource.text {
313 output.push_str(text);
314 output.push('\n');
315 } else {
316 output.push_str(&format!("[Resource: {}]\n", resource.uri));
317 }
318 }
319 }
320 }
321
322 output.trim_end().to_string()
323}
324
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mcp::protocol::ResourceContent;

    /// Builds an `echo`-based stdio config. No test connects an *enabled* config built here,
    /// so the command is never actually spawned.
    fn stdio_config(name: &str, enabled: bool) -> McpServerConfig {
        McpServerConfig {
            name: name.to_string(),
            transport: McpTransportConfig::Stdio {
                command: "echo".to_string(),
                args: vec![],
            },
            enabled,
            env: HashMap::new(),
            oauth: None,
            tool_timeout_secs: 60,
        }
    }

    /// Wraps content items in a non-error `CallToolResult`.
    fn result_of(content: Vec<ToolContent>) -> CallToolResult {
        CallToolResult {
            content,
            is_error: false,
        }
    }

    /// Builds a text content item.
    fn text(value: &str) -> ToolContent {
        ToolContent::Text {
            text: value.to_string(),
        }
    }

    // ---- parse_tool_name ----

    #[test]
    fn test_parse_tool_name() {
        let (server, tool) = McpManager::parse_tool_name("mcp__github__create_issue").unwrap();
        assert_eq!(server, "github");
        assert_eq!(tool, "create_issue");
    }

    #[test]
    fn test_parse_tool_name_with_underscores() {
        let (server, tool) = McpManager::parse_tool_name("mcp__my_server__my_tool_name").unwrap();
        assert_eq!(server, "my_server");
        assert_eq!(tool, "my_tool_name");
    }

    #[test]
    fn test_parse_tool_name_simple() {
        let (server, tool) = McpManager::parse_tool_name("mcp__server__tool").unwrap();
        assert_eq!(server, "server");
        assert_eq!(tool, "tool");
    }

    #[test]
    fn test_parse_tool_name_tool_may_contain_separator() {
        // Only the first `__` after the prefix separates server from tool.
        let (server, tool) = McpManager::parse_tool_name("mcp__srv__a__b").unwrap();
        assert_eq!(server, "srv");
        assert_eq!(tool, "a__b");
    }

    #[test]
    fn test_parse_tool_name_invalid() {
        assert!(McpManager::parse_tool_name("invalid_name").is_err());
        assert!(McpManager::parse_tool_name("mcp__nodelimiter").is_err());
    }

    #[test]
    fn test_parse_tool_name_missing_prefix() {
        assert!(McpManager::parse_tool_name("server__tool").is_err());
    }

    #[test]
    fn test_parse_tool_name_only_prefix() {
        assert!(McpManager::parse_tool_name("mcp__").is_err());
    }

    #[test]
    fn test_parse_tool_name_empty_string() {
        assert!(McpManager::parse_tool_name("").is_err());
    }

    // ---- tool_result_to_string ----

    #[test]
    fn test_tool_result_to_string_single_text() {
        let output = tool_result_to_string(&result_of(vec![text("Hello World")]));
        assert_eq!(output, "Hello World");
    }

    #[test]
    fn test_tool_result_to_string_multiple_text() {
        let output = tool_result_to_string(&result_of(vec![text("First line"), text("Second line")]));
        assert_eq!(output, "First line\nSecond line");
    }

    #[test]
    fn test_tool_result_to_string_empty() {
        assert_eq!(tool_result_to_string(&result_of(vec![])), "");
    }

    #[test]
    fn test_tool_result_to_string_trims_trailing_whitespace() {
        let output = tool_result_to_string(&result_of(vec![text("padded   ")]));
        assert_eq!(output, "padded");
    }

    #[test]
    fn test_tool_result_to_string_image() {
        let output = tool_result_to_string(&result_of(vec![ToolContent::Image {
            data: "base64data".to_string(),
            mime_type: "image/png".to_string(),
        }]));
        assert_eq!(output, "[Image: image/png]");
    }

    #[test]
    fn test_tool_result_to_string_resource() {
        let output = tool_result_to_string(&result_of(vec![ToolContent::Resource {
            resource: ResourceContent {
                uri: "file:///test.txt".to_string(),
                mime_type: Some("text/plain".to_string()),
                text: Some("Resource content".to_string()),
                blob: None,
            },
        }]));
        assert_eq!(output, "Resource content");
    }

    #[test]
    fn test_tool_result_to_string_resource_without_text() {
        let output = tool_result_to_string(&result_of(vec![ToolContent::Resource {
            resource: ResourceContent {
                uri: "file:///data.bin".to_string(),
                mime_type: None,
                text: None,
                blob: None,
            },
        }]));
        assert_eq!(output, "[Resource: file:///data.bin]");
    }

    #[test]
    fn test_tool_result_to_string_mixed_content() {
        let output = tool_result_to_string(&result_of(vec![
            text("Text content"),
            ToolContent::Image {
                data: "base64".to_string(),
                mime_type: "image/jpeg".to_string(),
            },
            ToolContent::Resource {
                resource: ResourceContent {
                    uri: "file:///doc.md".to_string(),
                    mime_type: Some("text/markdown".to_string()),
                    text: Some("Doc content".to_string()),
                    blob: None,
                },
            },
        ]));
        assert_eq!(output, "Text content\n[Image: image/jpeg]\nDoc content");
    }

    // ---- McpManager state ----

    #[tokio::test]
    async fn test_mcp_manager_new() {
        let manager = McpManager::new();
        assert!(manager.get_status().await.is_empty());
    }

    #[tokio::test]
    async fn test_mcp_manager_default() {
        let manager = McpManager::default();
        assert!(manager.get_status().await.is_empty());
    }

    #[tokio::test]
    async fn test_mcp_manager_register_server() {
        let manager = McpManager::new();
        manager.register_server(stdio_config("test", true)).await;

        let status = manager.get_status().await;
        assert!(status.contains_key("test"));
        assert!(!status["test"].connected);
    }

    #[tokio::test]
    async fn test_register_server_replaces_existing_config() {
        let manager = McpManager::new();
        manager.register_server(stdio_config("dup", true)).await;
        manager.register_server(stdio_config("dup", false)).await;

        assert_eq!(manager.all_configs().await.len(), 1);
        assert!(!manager.get_status().await["dup"].enabled);
    }

    #[tokio::test]
    async fn test_get_status_registered_server() {
        let manager = McpManager::new();
        manager.register_server(stdio_config("test_server", true)).await;

        let status = manager.get_status().await;
        assert!(status.contains_key("test_server"));
        assert!(!status["test_server"].connected);
        assert!(status["test_server"].enabled);
        assert_eq!(status["test_server"].name, "test_server");
    }

    #[tokio::test]
    async fn test_get_status_disabled_server() {
        let manager = McpManager::new();
        manager.register_server(stdio_config("disabled_server", false)).await;

        let status = manager.get_status().await;
        assert!(status.contains_key("disabled_server"));
        assert!(!status["disabled_server"].enabled);
        assert!(!status["disabled_server"].connected);
        assert_eq!(status["disabled_server"].tool_count, 0);
    }

    #[tokio::test]
    async fn test_list_connected_empty() {
        let manager = McpManager::new();
        assert!(manager.list_connected().await.is_empty());
    }

    #[tokio::test]
    async fn test_is_connected_false_for_unknown_server() {
        let manager = McpManager::new();
        assert!(!manager.is_connected("unknown_server").await);
    }

    #[tokio::test]
    async fn test_get_client_none_for_unknown_server() {
        let manager = McpManager::new();
        assert!(manager.get_client("unknown_server").await.is_none());
    }

    #[tokio::test]
    async fn test_get_all_tools_empty_manager() {
        let manager = McpManager::new();
        assert!(manager.get_all_tools().await.is_empty());
    }

    #[tokio::test]
    async fn test_get_server_tools_unknown_server_is_empty() {
        let manager = McpManager::new();
        assert!(manager.get_server_tools("unknown_server").await.is_empty());
    }

    // ---- McpManager error paths ----

    #[tokio::test]
    async fn test_connect_unknown_server_fails() {
        let manager = McpManager::new();
        let Err(err) = manager.connect("missing").await else {
            panic!("connecting an unregistered server should fail");
        };
        assert!(err.to_string().contains("MCP server not found"));
    }

    #[tokio::test]
    async fn test_connect_disabled_server_fails() {
        let manager = McpManager::new();
        manager.register_server(stdio_config("off", false)).await;

        let Err(err) = manager.connect("off").await else {
            panic!("connecting a disabled server should fail");
        };
        assert!(err.to_string().contains("MCP server is disabled"));
        assert!(!manager.is_connected("off").await);
        assert!(manager.list_connected().await.is_empty());
    }

    #[tokio::test]
    async fn test_disconnect_unknown_server_is_noop() {
        let manager = McpManager::new();
        assert!(manager.disconnect("missing").await.is_ok());
    }

    #[tokio::test]
    async fn test_call_tool_rejects_invalid_name() {
        let manager = McpManager::new();
        assert!(manager.call_tool("not_an_mcp_tool", None).await.is_err());
    }

    #[tokio::test]
    async fn test_call_tool_unknown_server_fails() {
        let manager = McpManager::new();
        let Err(err) = manager.call_tool("mcp__ghost__tool", None).await else {
            panic!("calling a tool on an unconnected server should fail");
        };
        assert!(err.to_string().contains("MCP server not connected: ghost"));
    }

    // ---- resolve_auth_header ----

    #[tokio::test]
    async fn test_resolve_auth_header_none_when_no_oauth() {
        let result = McpManager::resolve_auth_header(None).await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_resolve_auth_header_uses_static_token() {
        let oauth = OAuthConfig {
            auth_url: "https://example.com/auth".to_string(),
            token_url: "https://example.com/token".to_string(),
            client_id: "client".to_string(),
            client_secret: None,
            scopes: vec![],
            redirect_uri: "http://localhost/cb".to_string(),
            access_token: Some("my-static-token".to_string()),
        };
        let result = McpManager::resolve_auth_header(Some(&oauth)).await.unwrap();
        assert_eq!(
            result,
            Some((
                "Authorization".to_string(),
                "Bearer my-static-token".to_string()
            ))
        );
    }

    #[tokio::test]
    async fn test_resolve_auth_header_client_credentials_fails_gracefully() {
        // Port 1 on loopback is expected to refuse connections, so the token exchange fails.
        let oauth = OAuthConfig {
            auth_url: "https://127.0.0.1:1/auth".to_string(),
            token_url: "http://127.0.0.1:1/token".to_string(),
            client_id: "client".to_string(),
            client_secret: Some("secret".to_string()),
            scopes: vec!["read".to_string()],
            redirect_uri: "http://localhost/cb".to_string(),
            access_token: None,
        };
        let result = McpManager::resolve_auth_header(Some(&oauth)).await;
        assert!(result.is_err());
    }
}