1use async_trait::async_trait;
8use parking_lot::RwLock;
9use serde::{Deserialize, Serialize};
10use serde_json::{Value, json};
11use std::collections::HashMap;
12
13use rmcp::model as mcp_model;
14use rmcp::service::{Peer, RunningService};
15use rmcp::{RoleClient, ServiceExt};
16
17use ai_agents_core::{Tool, ToolResult};
18
19#[derive(Debug, Clone)]
21pub(crate) struct DiscoveredFunction {
22 pub(crate) name: String,
24 pub(crate) description: String,
26 pub(crate) input_schema: Value,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct MCPWrapperConfig {
33 pub name: String,
35
36 #[serde(flatten)]
38 pub transport: MCPWrapperTransport,
39
40 #[serde(default)]
42 pub env: HashMap<String, String>,
43
44 #[serde(default = "default_startup_timeout")]
46 pub startup_timeout_ms: u64,
47
48 #[serde(default)]
50 pub security: MCPWrapperSecurity,
51
52 #[serde(default)]
55 pub description: Option<String>,
56
57 #[serde(default)]
59 pub views: HashMap<String, MCPViewConfig>,
60}
61
/// Serde default for `MCPWrapperConfig::startup_timeout_ms`: 30 seconds, in milliseconds.
fn default_startup_timeout() -> u64 {
    const DEFAULT_STARTUP_TIMEOUT_MS: u64 = 30 * 1_000;
    DEFAULT_STARTUP_TIMEOUT_MS
}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
68#[serde(tag = "transport", rename_all = "lowercase")]
69pub enum MCPWrapperTransport {
70 Stdio {
71 command: String,
72 #[serde(default)]
73 args: Vec<String>,
74 },
75 Http {
76 url: String,
77 #[serde(default)]
78 headers: HashMap<String, String>,
79 },
80 #[serde(alias = "sse")]
81 Sse {
82 url: String,
83 #[serde(default)]
84 headers: HashMap<String, String>,
85 },
86}
87
88#[derive(Debug, Clone, Serialize, Deserialize, Default)]
90pub struct MCPWrapperSecurity {
91 #[serde(default)]
93 pub blocked_functions: Vec<String>,
94
95 #[serde(default)]
97 pub hitl_functions: Vec<String>,
98}
99
100#[derive(Debug, Clone, Serialize, Deserialize)]
102pub struct MCPViewConfig {
103 pub functions: Vec<String>,
105 #[serde(default)]
107 pub description: Option<String>,
108}
109
110pub struct MCPWrapperTool {
117 config: MCPWrapperConfig,
118 description: String,
120 schema: Value,
122 _running: RwLock<Option<RunningService<RoleClient, ()>>>,
124 peer: RwLock<Option<Peer<RoleClient>>>,
126 functions: Vec<DiscoveredFunction>,
128}
129
130impl MCPWrapperTool {
131 pub fn new(config: MCPWrapperConfig) -> Self {
134 let desc = config
135 .description
136 .clone()
137 .unwrap_or_else(|| format!("{} operations via MCP", config.name));
138 Self {
139 config,
140 description: desc,
141 schema: json!({"type": "object"}),
142 _running: RwLock::new(None),
143 peer: RwLock::new(None),
144 functions: Vec::new(),
145 }
146 }
147
148 pub async fn initialized(mut self) -> Result<Self, String> {
151 let running = match &self.config.transport {
152 MCPWrapperTransport::Stdio { command, args } => {
153 Self::connect_stdio(command, args, &self.config.env, &self.config.name).await?
154 }
155 MCPWrapperTransport::Http { url, headers }
156 | MCPWrapperTransport::Sse { url, headers } => {
157 Self::connect_http(url, headers, &self.config.name).await?
158 }
159 };
160
161 let peer = running.peer().clone();
162
163 let tool_list = peer
165 .list_all_tools()
166 .await
167 .map_err(|e| format!("Failed to list tools from '{}': {}", self.config.name, e))?;
168
169 let mut functions = Vec::new();
170 for tool in &tool_list {
171 let name = tool.name.to_string();
172
173 if self.config.security.blocked_functions.contains(&name) {
175 tracing::debug!(
176 server = %self.config.name,
177 function = %name,
178 "Skipping blocked MCP function"
179 );
180 continue;
181 }
182
183 let description = tool
184 .description
185 .as_ref()
186 .map(|d| d.to_string())
187 .unwrap_or_default();
188
189 let input_schema = Value::Object(tool.input_schema.as_ref().clone());
190
191 functions.push(DiscoveredFunction {
192 name,
193 description,
194 input_schema,
195 });
196 }
197
198 tracing::info!(
199 server = %self.config.name,
200 functions = functions.len(),
201 "MCP wrapper tool initialized"
202 );
203
204 self.schema = Self::build_schema(&self.config.name, &functions);
206 self.description = Self::build_description(
207 &self.config.name,
208 self.config.description.as_deref(),
209 &functions,
210 );
211 self.functions = functions;
212 *self.peer.write() = Some(peer);
213 *self._running.write() = Some(running);
214
215 Ok(self)
216 }
217
218 pub(crate) fn build_schema(server_name: &str, functions: &[DiscoveredFunction]) -> Value {
220 let function_names: Vec<Value> = functions
221 .iter()
222 .map(|f| Value::String(f.name.clone()))
223 .collect();
224
225 let mut params_description =
227 String::from("Parameters for the selected function. See function list for details.");
228
229 if functions.len() <= 30 {
230 params_description = String::from("Parameters for the selected function:\n");
231 for f in functions {
232 if let Some(props) = f.input_schema.get("properties") {
233 let prop_names: Vec<&str> = props
234 .as_object()
235 .map(|obj| obj.keys().map(|k| k.as_str()).collect())
236 .unwrap_or_default();
237 if !prop_names.is_empty() {
238 params_description.push_str(&format!(
239 " - {}: {{{}}}\n",
240 f.name,
241 prop_names.join(", ")
242 ));
243 } else {
244 params_description.push_str(&format!(" - {}: (no parameters)\n", f.name));
245 }
246 }
247 }
248 }
249
250 json!({
251 "type": "object",
252 "required": ["function"],
253 "properties": {
254 "function": {
255 "type": "string",
256 "description": format!("The function to call inside the '{}' tool. Pass this as arguments.function, NOT as the tool name.", server_name),
257 "enum": function_names
258 },
259 "params": {
260 "type": "object",
261 "description": params_description,
262 "additionalProperties": true
263 }
264 }
265 })
266 }
267
268 pub(crate) fn build_description(
270 server_name: &str,
271 custom: Option<&str>,
272 functions: &[DiscoveredFunction],
273 ) -> String {
274 let mut desc = match custom {
275 Some(c) if !c.is_empty() => c.to_string(),
276 _ => format!("{} operations via MCP.", server_name),
277 };
278
279 if !functions.is_empty() {
280 desc.push_str(&format!(
283 " Use tool '{}' with arguments.function set to one of: ",
284 server_name
285 ));
286 let names: Vec<&str> = functions.iter().map(|f| f.name.as_str()).collect();
287 desc.push_str(&names.join(", "));
288 desc.push('.');
289
290 if functions.len() <= 20 {
292 desc.push_str("\n\nFunction details:");
293 for f in functions {
294 if !f.description.is_empty() {
295 desc.push_str(&format!("\n- {}: {}", f.name, f.description));
296 } else {
297 desc.push_str(&format!("\n- {}", f.name));
298 }
299 }
300 }
301 }
302
303 desc
304 }
305
306 pub(crate) async fn call_function(&self, function: &str, params: Value) -> ToolResult {
308 if !self.functions.iter().any(|f| f.name == function) {
310 let available: Vec<&str> = self.functions.iter().map(|f| f.name.as_str()).collect();
311 return ToolResult::error(format!(
312 "Unknown function '{}'. Available functions: {}",
313 function,
314 available.join(", ")
315 ));
316 }
317
318 let peer = {
319 let peer_guard = self.peer.read();
320 match peer_guard.as_ref() {
321 Some(p) => p.clone(),
322 None => {
323 return ToolResult::error(format!(
324 "MCP server '{}' not initialized",
325 self.config.name
326 ));
327 }
328 }
329 };
330
331 let mut call_params = mcp_model::CallToolRequestParams::new(function.to_string());
332 if let Value::Object(map) = params {
333 call_params.arguments = Some(map.into_iter().collect());
334 }
335
336 match peer.call_tool(call_params).await {
337 Ok(result) => {
338 let output = result
339 .content
340 .iter()
341 .filter_map(|c| match &c.raw {
342 mcp_model::RawContent::Text(t) => Some(t.text.as_str()),
343 _ => None,
344 })
345 .collect::<Vec<_>>()
346 .join("\n");
347
348 if result.is_error.unwrap_or(false) {
349 ToolResult::error(output)
350 } else {
351 ToolResult::ok(output)
352 }
353 }
354 Err(e) => ToolResult::error(format!("MCP function '{}' failed: {}", function, e)),
355 }
356 }
357
358 pub(crate) fn get_functions_filtered(&self, names: &[String]) -> Vec<DiscoveredFunction> {
360 self.functions
361 .iter()
362 .filter(|f| names.iter().any(|n| n == &f.name))
363 .cloned()
364 .collect()
365 }
366
367 async fn connect_stdio(
369 command: &str,
370 args: &[String],
371 env: &HashMap<String, String>,
372 server_name: &str,
373 ) -> Result<RunningService<RoleClient, ()>, String> {
374 use rmcp::transport::TokioChildProcess;
375 use tokio::process::Command;
376
377 let mut cmd = Command::new(command);
378 cmd.args(args);
379 for (key, value) in env {
380 cmd.env(key, value);
381 }
382
383 let transport = TokioChildProcess::new(cmd)
384 .map_err(|e| format!("Failed to spawn '{}': {}", command, e))?;
385
386 let running: RunningService<RoleClient, ()> = ()
387 .serve(transport)
388 .await
389 .map_err(|e| format!("Failed MCP handshake with '{}': {}", server_name, e))?;
390
391 Ok(running)
392 }
393
394 async fn connect_http(
396 url: &str,
397 headers: &HashMap<String, String>,
398 server_name: &str,
399 ) -> Result<RunningService<RoleClient, ()>, String> {
400 use rmcp::transport::streamable_http_client::{
401 StreamableHttpClientTransport, StreamableHttpClientTransportConfig,
402 };
403
404 if headers.is_empty() {
405 let transport = StreamableHttpClientTransport::from_uri(url);
406 let running: RunningService<RoleClient, ()> = ()
407 .serve(transport)
408 .await
409 .map_err(|e| format!("Failed HTTP MCP connection to '{}': {}", server_name, e))?;
410 Ok(running)
411 } else {
412 use reqwest::header::{HeaderName, HeaderValue};
413
414 let mut custom_headers = HashMap::new();
415 for (key, value) in headers {
416 let header_name = HeaderName::try_from(key.as_str())
417 .map_err(|e| format!("Invalid header name '{}': {}", key, e))?;
418 let header_value = HeaderValue::try_from(value.as_str())
419 .map_err(|e| format!("Invalid header value for '{}': {}", key, e))?;
420 custom_headers.insert(header_name, header_value);
421 }
422
423 let config =
424 StreamableHttpClientTransportConfig::with_uri(url).custom_headers(custom_headers);
425 let transport = StreamableHttpClientTransport::from_config(config);
426
427 let running: RunningService<RoleClient, ()> = ()
428 .serve(transport)
429 .await
430 .map_err(|e| format!("Failed HTTP MCP connection to '{}': {}", server_name, e))?;
431 Ok(running)
432 }
433 }
434
435 pub async fn shutdown(&self) {
437 let running = self._running.write().take();
438 if let Some(r) = running {
439 let _ = r.cancel().await;
440 }
441 self.peer.write().take();
442 }
443
444 pub fn requires_hitl(&self, function_name: &str) -> bool {
446 self.config
447 .security
448 .hitl_functions
449 .iter()
450 .any(|f| f == function_name)
451 }
452
453 pub fn function_count(&self) -> usize {
455 self.functions.len()
456 }
457
458 pub fn function_names(&self) -> Vec<&str> {
460 self.functions.iter().map(|f| f.name.as_str()).collect()
461 }
462}
463
464#[async_trait]
465impl Tool for MCPWrapperTool {
466 fn id(&self) -> &str {
467 &self.config.name
468 }
469
470 fn name(&self) -> &str {
471 &self.config.name
472 }
473
474 fn description(&self) -> &str {
475 &self.description
476 }
477
478 fn input_schema(&self) -> Value {
479 self.schema.clone()
480 }
481
482 async fn execute(&self, args: Value) -> ToolResult {
483 let function = match args.get("function").and_then(|v| v.as_str()) {
485 Some(f) => f.to_string(),
486 None => {
487 let available: Vec<&str> = self.functions.iter().map(|f| f.name.as_str()).collect();
488 return ToolResult::error(format!(
489 "'function' is required. Available functions: {}",
490 available.join(", ")
491 ));
492 }
493 };
494
495 let params = args.get("params").cloned().unwrap_or_else(|| json!({}));
497
498 if self.requires_hitl(&function) {
502 return ToolResult::ok_with_metadata(
503 format!(
504 "Function '{}' on MCP server '{}' requires approval before execution.",
505 function, self.config.name
506 ),
507 HashMap::from([
508 ("_hitl_required".to_string(), json!(true)),
509 ("_hitl_function".to_string(), json!(function)),
510 ("_hitl_params".to_string(), params.clone()),
511 ("_hitl_tool".to_string(), json!(self.config.name)),
512 ]),
513 );
514 }
515
516 self.call_function(&function, params).await
517 }
518}
519
520#[cfg(test)]
521mod tests {
522 use super::*;
523
/// A stdio config with env, timeout, and security settings deserializes from YAML.
#[test]
fn test_mcp_wrapper_config_deserialize_stdio() {
    let source = r#"
name: github
type: mcp
transport: stdio
command: npx
args: ["-y", "@modelcontextprotocol/server-github"]
env:
  GITHUB_TOKEN: "test-token"
startup_timeout_ms: 15000
security:
  blocked_functions: [delete_repo]
  hitl_functions: [create_issue]
"#;
    let parsed = serde_yaml::from_str::<MCPWrapperConfig>(source).unwrap();

    assert_eq!(parsed.name, "github");
    assert_eq!(parsed.startup_timeout_ms, 15000);

    let security = &parsed.security;
    assert_eq!(security.blocked_functions, ["delete_repo"]);
    assert_eq!(security.hitl_functions, ["create_issue"]);
}
545
/// An HTTP config with a url and headers deserializes from YAML.
#[test]
fn test_mcp_wrapper_config_deserialize_http() {
    let source = r#"
name: custom_api
type: mcp
transport: http
url: "http://localhost:3000/mcp"
headers:
  Authorization: "Bearer test"
"#;
    let parsed = serde_yaml::from_str::<MCPWrapperConfig>(source).unwrap();
    assert_eq!(parsed.name, "custom_api");
}
559
/// The generated schema is an object that requires `function` and enumerates every
/// function name.
#[test]
fn test_build_schema() {
    let create_issue = DiscoveredFunction {
        name: "create_issue".to_string(),
        description: "Create a new issue".to_string(),
        input_schema: json!({
            "type": "object",
            "properties": {
                "repo": {"type": "string"},
                "title": {"type": "string"},
                "body": {"type": "string"}
            },
            "required": ["repo", "title"]
        }),
    };
    let list_repos = DiscoveredFunction {
        name: "list_repos".to_string(),
        description: "List repositories".to_string(),
        input_schema: json!({
            "type": "object",
            "properties": {
                "org": {"type": "string"}
            }
        }),
    };

    let schema = MCPWrapperTool::build_schema("github", &[create_issue, list_repos]);

    assert_eq!(schema["type"], "object");

    let required = schema["required"].as_array().unwrap();
    assert!(required.contains(&json!("function")));

    let function_enum = schema["properties"]["function"]["enum"]
        .as_array()
        .unwrap();
    assert!(function_enum.contains(&json!("create_issue")));
    assert!(function_enum.contains(&json!("list_repos")));
}
606
/// Without a custom description, the default text, usage hint, function names, and
/// function details are all present.
#[test]
fn test_build_description() {
    let make = |name: &str, description: &str| DiscoveredFunction {
        name: name.to_string(),
        description: description.to_string(),
        input_schema: json!({}),
    };
    let functions = [
        make("create_issue", "Create a new issue"),
        make("list_repos", "List repositories"),
    ];

    let description = MCPWrapperTool::build_description("github", None, &functions);

    for expected in [
        "github operations via MCP",
        "Use tool 'github'",
        "create_issue",
        "list_repos",
        "Create a new issue",
    ] {
        assert!(description.contains(expected));
    }
}
630
/// Only functions listed in `hitl_functions` require approval.
#[test]
fn test_requires_hitl() {
    let security = MCPWrapperSecurity {
        blocked_functions: vec![],
        hitl_functions: vec!["create_issue".to_string()],
    };
    let transport = MCPWrapperTransport::Stdio {
        command: "npx".to_string(),
        args: vec![],
    };
    let wrapper = MCPWrapperTool::new(MCPWrapperConfig {
        name: "github".to_string(),
        transport,
        env: HashMap::new(),
        startup_timeout_ms: 30000,
        security,
        description: None,
        views: HashMap::new(),
    });

    assert!(wrapper.requires_hitl("create_issue"));
    assert!(!wrapper.requires_hitl("list_repos"));
}
653
/// A freshly constructed tool without a configured description uses the name-based default.
#[test]
fn test_default_description() {
    let transport = MCPWrapperTransport::Stdio {
        command: "npx".to_string(),
        args: vec![],
    };
    let wrapper = MCPWrapperTool::new(MCPWrapperConfig {
        name: "github".to_string(),
        transport,
        env: HashMap::new(),
        startup_timeout_ms: 30000,
        security: MCPWrapperSecurity::default(),
        description: None,
        views: HashMap::new(),
    });

    assert_eq!(wrapper.description(), "github operations via MCP");
}
671
/// A configured description is used verbatim by a freshly constructed tool.
#[test]
fn test_custom_description() {
    let transport = MCPWrapperTransport::Stdio {
        command: "npx".to_string(),
        args: vec![],
    };
    let wrapper = MCPWrapperTool::new(MCPWrapperConfig {
        name: "github".to_string(),
        transport,
        env: HashMap::new(),
        startup_timeout_ms: 30000,
        security: MCPWrapperSecurity::default(),
        description: Some("GitHub integration for DevOps".to_string()),
        views: HashMap::new(),
    });

    assert_eq!(wrapper.description(), "GitHub integration for DevOps");
}
689
/// A view with functions and a description deserializes from YAML.
#[test]
fn test_view_config_deserialize() {
    let source = r#"
functions: [create_issue, list_issues]
description: "Issue management"
"#;
    let view = serde_yaml::from_str::<MCPViewConfig>(source).unwrap();

    assert_eq!(view.functions, ["create_issue", "list_issues"]);
    assert_eq!(view.description.as_deref(), Some("Issue management"));
}
700
/// A view without a description deserializes with `description == None`.
#[test]
fn test_view_config_no_description() {
    let source = r#"
functions: [search_code, get_pull_request]
"#;
    let view = serde_yaml::from_str::<MCPViewConfig>(source).unwrap();

    assert_eq!(view.functions, ["search_code", "get_pull_request"]);
    assert!(view.description.is_none());
}
710
/// A config with two views deserializes, keeping each view's functions and description.
#[test]
fn test_mcp_config_with_views() {
    let source = r#"
name: github
type: mcp
transport: stdio
command: npx
args: ["-y", "@modelcontextprotocol/server-github"]
views:
  github_issues:
    functions: [create_issue, list_issues]
  github_code:
    functions: [search_code]
    description: "Code search"
"#;
    let parsed = serde_yaml::from_str::<MCPWrapperConfig>(source).unwrap();
    let views = &parsed.views;

    assert_eq!(views.len(), 2);
    assert_eq!(
        views["github_issues"].functions,
        ["create_issue", "list_issues"]
    );
    assert_eq!(
        views["github_code"].description.as_deref(),
        Some("Code search")
    );
}
737
/// Omitting `views` yields an empty map.
#[test]
fn test_mcp_config_without_views() {
    let source = r#"
name: github
type: mcp
transport: stdio
command: npx
args: ["-y", "@modelcontextprotocol/server-github"]
"#;
    let parsed = serde_yaml::from_str::<MCPWrapperConfig>(source).unwrap();
    assert!(parsed.views.is_empty());
}
750
/// A full tool entry (env plus views) deserializes and keeps the view function lists.
#[test]
fn test_tool_entry_mcp_with_views() {
    let source = r#"
name: github
type: mcp
transport: stdio
command: npx
args: ["-y", "@modelcontextprotocol/server-github"]
env:
  GITHUB_TOKEN: "test"
views:
  github_issues:
    functions: [create_issue, list_issues]
  github_code:
    functions: [search_code]
    description: "Code search"
"#;
    let parsed = serde_yaml::from_str::<MCPWrapperConfig>(source).unwrap();
    let views = &parsed.views;

    assert_eq!(views.len(), 2);
    assert_eq!(
        views["github_issues"].functions,
        ["create_issue", "list_issues"]
    );
}
775
/// A schema built for a view's functions enumerates exactly those functions.
#[test]
fn test_view_schema_filtered() {
    let create_issue = DiscoveredFunction {
        name: "create_issue".to_string(),
        description: "Create a new issue".to_string(),
        input_schema: json!({
            "type": "object",
            "properties": {
                "repo": {"type": "string"},
                "title": {"type": "string"}
            }
        }),
    };
    let list_issues = DiscoveredFunction {
        name: "list_issues".to_string(),
        description: "List issues".to_string(),
        input_schema: json!({
            "type": "object",
            "properties": {
                "repo": {"type": "string"}
            }
        }),
    };

    let schema = MCPWrapperTool::build_schema("github_issues", &[create_issue, list_issues]);
    let function_enum = schema["properties"]["function"]["enum"]
        .as_array()
        .unwrap();

    assert_eq!(function_enum.len(), 2);
    for expected in ["create_issue", "list_issues"] {
        assert!(function_enum.contains(&json!(expected)));
    }
}
808
/// A custom description leads the generated text and the function names still follow.
#[test]
fn test_view_description_custom() {
    let only_function = DiscoveredFunction {
        name: "create_issue".to_string(),
        description: "Create a new issue".to_string(),
        input_schema: json!({}),
    };

    let description = MCPWrapperTool::build_description(
        "github_issues",
        Some("Issue management"),
        std::slice::from_ref(&only_function),
    );

    assert!(description.starts_with("Issue management"));
    assert!(description.contains("create_issue"));
}
825}