Skip to main content

ai_agents_tools/mcp/
wrapper.rs

//! MCP wrapper tool — presents an MCP server as a single builtin Tool.
//!
//! Each instance wraps one MCP server connection and presents all of the
//! server's functions (except any listed in `security.blocked_functions`)
//! through a single tool with a `function` discriminator field, matching the
//! pattern used by `datetime`, `math`, `json`, etc.

7use async_trait::async_trait;
8use parking_lot::RwLock;
9use serde::{Deserialize, Serialize};
10use serde_json::{Value, json};
11use std::collections::HashMap;
12
13use rmcp::model as mcp_model;
14use rmcp::service::{Peer, RunningService};
15use rmcp::{RoleClient, ServiceExt};
16
17use ai_agents_core::{Tool, ToolResult};
18
19/// A discovered function from an MCP server (name, description, schema).
20#[derive(Debug, Clone)]
21pub(crate) struct DiscoveredFunction {
22    /// Original function name as reported by the MCP server.
23    pub(crate) name: String,
24    /// Human-readable description of the function.
25    pub(crate) description: String,
26    /// JSON Schema for the function's parameters.
27    pub(crate) input_schema: Value,
28}
29
30/// Configuration for the MCP wrapper tool, deserialized from YAML.
31#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct MCPWrapperConfig {
33    /// Display name for this tool (also used as the tool ID).
34    pub name: String,
35
36    /// Transport configuration (stdio, http, or sse).
37    #[serde(flatten)]
38    pub transport: MCPWrapperTransport,
39
40    /// Environment variables passed to the server process.
41    #[serde(default)]
42    pub env: HashMap<String, String>,
43
44    /// Startup timeout in milliseconds.
45    #[serde(default = "default_startup_timeout")]
46    pub startup_timeout_ms: u64,
47
48    /// Security settings for function-level blocking and HITL.
49    #[serde(default)]
50    pub security: MCPWrapperSecurity,
51
52    /// Optional custom description override.
53    /// If not set, auto-generated from discovered functions.
54    #[serde(default)]
55    pub description: Option<String>,
56
57    /// Named views: subsets of this server's functions registered as separate tools.
58    #[serde(default)]
59    pub views: HashMap<String, MCPViewConfig>,
60}
61
/// Serde default for [`MCPWrapperConfig::startup_timeout_ms`]: thirty seconds.
fn default_startup_timeout() -> u64 {
    const THIRTY_SECONDS_MS: u64 = 30 * 1_000;
    THIRTY_SECONDS_MS
}
65
66/// Transport configuration for connecting to an MCP server.
67#[derive(Debug, Clone, Serialize, Deserialize)]
68#[serde(tag = "transport", rename_all = "lowercase")]
69pub enum MCPWrapperTransport {
70    Stdio {
71        command: String,
72        #[serde(default)]
73        args: Vec<String>,
74    },
75    Http {
76        url: String,
77        #[serde(default)]
78        headers: HashMap<String, String>,
79    },
80    #[serde(alias = "sse")]
81    Sse {
82        url: String,
83        #[serde(default)]
84        headers: HashMap<String, String>,
85    },
86}
87
88/// Security settings for the MCP wrapper tool.
89#[derive(Debug, Clone, Serialize, Deserialize, Default)]
90pub struct MCPWrapperSecurity {
91    /// Functions that should never be exposed to the LLM.
92    #[serde(default)]
93    pub blocked_functions: Vec<String>,
94
95    /// Functions that require HITL approval before execution.
96    #[serde(default)]
97    pub hitl_functions: Vec<String>,
98}
99
100/// Configuration for a single MCP view (a named function subset).
101#[derive(Debug, Clone, Serialize, Deserialize)]
102pub struct MCPViewConfig {
103    /// Whitelist of function names to include in this view.
104    pub functions: Vec<String>,
105    /// Optional custom description for this view tool.
106    #[serde(default)]
107    pub description: Option<String>,
108}
109
110/// An MCP server exposed as a single builtin Tool.
111///
112/// Connects to an MCP server at initialization, discovers available functions
113/// via `peer.list_tools()`, and builds a dynamic `input_schema()` with
114/// `function` as an enum of discovered names and `params` as per-function
115/// parameters. Uses two-phase construction: `new()` then `initialized()`.
116pub struct MCPWrapperTool {
117    config: MCPWrapperConfig,
118    /// Immutable after `initialized()` — tool description shown to the LLM.
119    description: String,
120    /// Immutable after `initialized()` — JSON Schema for the tool's input.
121    schema: Value,
122    /// Running service handle — kept alive so the background task is not dropped.
123    _running: RwLock<Option<RunningService<RoleClient, ()>>>,
124    /// Peer for issuing MCP requests.
125    peer: RwLock<Option<Peer<RoleClient>>>,
126    /// Discovered functions from the MCP server (populated after init).
127    functions: Vec<DiscoveredFunction>,
128}
129
130impl MCPWrapperTool {
131    /// Create a new wrapper tool from configuration.
132    /// The tool is NOT connected yet — call `initialized()` to connect and discover.
133    pub fn new(config: MCPWrapperConfig) -> Self {
134        let desc = config
135            .description
136            .clone()
137            .unwrap_or_else(|| format!("{} operations via MCP", config.name));
138        Self {
139            config,
140            description: desc,
141            schema: json!({"type": "object"}),
142            _running: RwLock::new(None),
143            peer: RwLock::new(None),
144            functions: Vec::new(),
145        }
146    }
147
148    /// Connect to the MCP server, discover functions, and build schema/description.
149    /// Returns a new `MCPWrapperTool` with the discovered state baked in.
150    pub async fn initialized(mut self) -> Result<Self, String> {
151        let running = match &self.config.transport {
152            MCPWrapperTransport::Stdio { command, args } => {
153                Self::connect_stdio(command, args, &self.config.env, &self.config.name).await?
154            }
155            MCPWrapperTransport::Http { url, headers }
156            | MCPWrapperTransport::Sse { url, headers } => {
157                Self::connect_http(url, headers, &self.config.name).await?
158            }
159        };
160
161        let peer = running.peer().clone();
162
163        // Discover functions from the MCP server
164        let tool_list = peer
165            .list_all_tools()
166            .await
167            .map_err(|e| format!("Failed to list tools from '{}': {}", self.config.name, e))?;
168
169        let mut functions = Vec::new();
170        for tool in &tool_list {
171            let name = tool.name.to_string();
172
173            // Skip blocked functions
174            if self.config.security.blocked_functions.contains(&name) {
175                tracing::debug!(
176                    server = %self.config.name,
177                    function = %name,
178                    "Skipping blocked MCP function"
179                );
180                continue;
181            }
182
183            let description = tool
184                .description
185                .as_ref()
186                .map(|d| d.to_string())
187                .unwrap_or_default();
188
189            let input_schema = Value::Object(tool.input_schema.as_ref().clone());
190
191            functions.push(DiscoveredFunction {
192                name,
193                description,
194                input_schema,
195            });
196        }
197
198        tracing::info!(
199            server = %self.config.name,
200            functions = functions.len(),
201            "MCP wrapper tool initialized"
202        );
203
204        // Build immutable schema and description
205        self.schema = Self::build_schema(&self.config.name, &functions);
206        self.description = Self::build_description(
207            &self.config.name,
208            self.config.description.as_deref(),
209            &functions,
210        );
211        self.functions = functions;
212        *self.peer.write() = Some(peer);
213        *self._running.write() = Some(running);
214
215        Ok(self)
216    }
217
218    /// Build the dynamic input schema from discovered functions.
219    pub(crate) fn build_schema(server_name: &str, functions: &[DiscoveredFunction]) -> Value {
220        let function_names: Vec<Value> = functions
221            .iter()
222            .map(|f| Value::String(f.name.clone()))
223            .collect();
224
225        // Build per-function parameter hints for the LLM
226        let mut params_description =
227            String::from("Parameters for the selected function. See function list for details.");
228
229        if functions.len() <= 30 {
230            params_description = String::from("Parameters for the selected function:\n");
231            for f in functions {
232                if let Some(props) = f.input_schema.get("properties") {
233                    let prop_names: Vec<&str> = props
234                        .as_object()
235                        .map(|obj| obj.keys().map(|k| k.as_str()).collect())
236                        .unwrap_or_default();
237                    if !prop_names.is_empty() {
238                        params_description.push_str(&format!(
239                            "  - {}: {{{}}}\n",
240                            f.name,
241                            prop_names.join(", ")
242                        ));
243                    } else {
244                        params_description.push_str(&format!("  - {}: (no parameters)\n", f.name));
245                    }
246                }
247            }
248        }
249
250        json!({
251            "type": "object",
252            "required": ["function"],
253            "properties": {
254                "function": {
255                    "type": "string",
256                    "description": format!("The function to call inside the '{}' tool. Pass this as arguments.function, NOT as the tool name.", server_name),
257                    "enum": function_names
258                },
259                "params": {
260                    "type": "object",
261                    "description": params_description,
262                    "additionalProperties": true
263                }
264            }
265        })
266    }
267
268    /// Build a rich description listing all available functions.
269    pub(crate) fn build_description(
270        server_name: &str,
271        custom: Option<&str>,
272        functions: &[DiscoveredFunction],
273    ) -> String {
274        let mut desc = match custom {
275            Some(c) if !c.is_empty() => c.to_string(),
276            _ => format!("{} operations via MCP.", server_name),
277        };
278
279        if !functions.is_empty() {
280            // Clarify the dispatch pattern: the tool name is server_name,
281            // function names go inside arguments.function.
282            desc.push_str(&format!(
283                " Use tool '{}' with arguments.function set to one of: ",
284                server_name
285            ));
286            let names: Vec<&str> = functions.iter().map(|f| f.name.as_str()).collect();
287            desc.push_str(&names.join(", "));
288            desc.push('.');
289
290            // Add per-function descriptions for smaller function sets
291            if functions.len() <= 20 {
292                desc.push_str("\n\nFunction details:");
293                for f in functions {
294                    if !f.description.is_empty() {
295                        desc.push_str(&format!("\n- {}: {}", f.name, f.description));
296                    } else {
297                        desc.push_str(&format!("\n- {}", f.name));
298                    }
299                }
300            }
301        }
302
303        desc
304    }
305
306    /// Execute a function call on the MCP server.
307    pub(crate) async fn call_function(&self, function: &str, params: Value) -> ToolResult {
308        // Validate that the function exists
309        if !self.functions.iter().any(|f| f.name == function) {
310            let available: Vec<&str> = self.functions.iter().map(|f| f.name.as_str()).collect();
311            return ToolResult::error(format!(
312                "Unknown function '{}'. Available functions: {}",
313                function,
314                available.join(", ")
315            ));
316        }
317
318        let peer = {
319            let peer_guard = self.peer.read();
320            match peer_guard.as_ref() {
321                Some(p) => p.clone(),
322                None => {
323                    return ToolResult::error(format!(
324                        "MCP server '{}' not initialized",
325                        self.config.name
326                    ));
327                }
328            }
329        };
330
331        let mut call_params = mcp_model::CallToolRequestParams::new(function.to_string());
332        if let Value::Object(map) = params {
333            call_params.arguments = Some(map.into_iter().collect());
334        }
335
336        match peer.call_tool(call_params).await {
337            Ok(result) => {
338                let output = result
339                    .content
340                    .iter()
341                    .filter_map(|c| match &c.raw {
342                        mcp_model::RawContent::Text(t) => Some(t.text.as_str()),
343                        _ => None,
344                    })
345                    .collect::<Vec<_>>()
346                    .join("\n");
347
348                if result.is_error.unwrap_or(false) {
349                    ToolResult::error(output)
350                } else {
351                    ToolResult::ok(output)
352                }
353            }
354            Err(e) => ToolResult::error(format!("MCP function '{}' failed: {}", function, e)),
355        }
356    }
357
358    /// Return the subset of discovered functions matching the given names.
359    pub(crate) fn get_functions_filtered(&self, names: &[String]) -> Vec<DiscoveredFunction> {
360        self.functions
361            .iter()
362            .filter(|f| names.iter().any(|n| n == &f.name))
363            .cloned()
364            .collect()
365    }
366
367    /// Connect to an MCP server via stdio transport.
368    async fn connect_stdio(
369        command: &str,
370        args: &[String],
371        env: &HashMap<String, String>,
372        server_name: &str,
373    ) -> Result<RunningService<RoleClient, ()>, String> {
374        use rmcp::transport::TokioChildProcess;
375        use tokio::process::Command;
376
377        let mut cmd = Command::new(command);
378        cmd.args(args);
379        for (key, value) in env {
380            cmd.env(key, value);
381        }
382
383        let transport = TokioChildProcess::new(cmd)
384            .map_err(|e| format!("Failed to spawn '{}': {}", command, e))?;
385
386        let running: RunningService<RoleClient, ()> = ()
387            .serve(transport)
388            .await
389            .map_err(|e| format!("Failed MCP handshake with '{}': {}", server_name, e))?;
390
391        Ok(running)
392    }
393
394    /// Connect to an MCP server via HTTP/SSE transport.
395    async fn connect_http(
396        url: &str,
397        headers: &HashMap<String, String>,
398        server_name: &str,
399    ) -> Result<RunningService<RoleClient, ()>, String> {
400        use rmcp::transport::streamable_http_client::{
401            StreamableHttpClientTransport, StreamableHttpClientTransportConfig,
402        };
403
404        if headers.is_empty() {
405            let transport = StreamableHttpClientTransport::from_uri(url);
406            let running: RunningService<RoleClient, ()> = ()
407                .serve(transport)
408                .await
409                .map_err(|e| format!("Failed HTTP MCP connection to '{}': {}", server_name, e))?;
410            Ok(running)
411        } else {
412            use reqwest::header::{HeaderName, HeaderValue};
413
414            let mut custom_headers = HashMap::new();
415            for (key, value) in headers {
416                let header_name = HeaderName::try_from(key.as_str())
417                    .map_err(|e| format!("Invalid header name '{}': {}", key, e))?;
418                let header_value = HeaderValue::try_from(value.as_str())
419                    .map_err(|e| format!("Invalid header value for '{}': {}", key, e))?;
420                custom_headers.insert(header_name, header_value);
421            }
422
423            let config =
424                StreamableHttpClientTransportConfig::with_uri(url).custom_headers(custom_headers);
425            let transport = StreamableHttpClientTransport::from_config(config);
426
427            let running: RunningService<RoleClient, ()> = ()
428                .serve(transport)
429                .await
430                .map_err(|e| format!("Failed HTTP MCP connection to '{}': {}", server_name, e))?;
431            Ok(running)
432        }
433    }
434
435    /// Gracefully shut down the MCP server connection.
436    pub async fn shutdown(&self) {
437        let running = self._running.write().take();
438        if let Some(r) = running {
439            let _ = r.cancel().await;
440        }
441        self.peer.write().take();
442    }
443
444    /// Check if a specific function requires HITL approval.
445    pub fn requires_hitl(&self, function_name: &str) -> bool {
446        self.config
447            .security
448            .hitl_functions
449            .iter()
450            .any(|f| f == function_name)
451    }
452
453    /// Get the number of discovered functions.
454    pub fn function_count(&self) -> usize {
455        self.functions.len()
456    }
457
458    /// Get the list of discovered function names.
459    pub fn function_names(&self) -> Vec<&str> {
460        self.functions.iter().map(|f| f.name.as_str()).collect()
461    }
462}
463
464#[async_trait]
465impl Tool for MCPWrapperTool {
466    fn id(&self) -> &str {
467        &self.config.name
468    }
469
470    fn name(&self) -> &str {
471        &self.config.name
472    }
473
474    fn description(&self) -> &str {
475        &self.description
476    }
477
478    fn input_schema(&self) -> Value {
479        self.schema.clone()
480    }
481
482    async fn execute(&self, args: Value) -> ToolResult {
483        // Extract the `function` field from input
484        let function = match args.get("function").and_then(|v| v.as_str()) {
485            Some(f) => f.to_string(),
486            None => {
487                let available: Vec<&str> = self.functions.iter().map(|f| f.name.as_str()).collect();
488                return ToolResult::error(format!(
489                    "'function' is required. Available functions: {}",
490                    available.join(", ")
491                ));
492            }
493        };
494
495        // Extract optional `params` field (defaults to empty object)
496        let params = args.get("params").cloned().unwrap_or_else(|| json!({}));
497
498        // Per-function HITL: signal the runtime via metadata if approval is needed.
499        // The runtime's HITL engine sees the tool ID ("github"), not the function.
500        // For per-function granularity, we return metadata that the runtime can inspect.
501        if self.requires_hitl(&function) {
502            return ToolResult::ok_with_metadata(
503                format!(
504                    "Function '{}' on MCP server '{}' requires approval before execution.",
505                    function, self.config.name
506                ),
507                HashMap::from([
508                    ("_hitl_required".to_string(), json!(true)),
509                    ("_hitl_function".to_string(), json!(function)),
510                    ("_hitl_params".to_string(), params.clone()),
511                    ("_hitl_tool".to_string(), json!(self.config.name)),
512                ]),
513            );
514        }
515
516        self.call_function(&function, params).await
517    }
518}
519
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `DiscoveredFunction`.
    fn func(name: &str, description: &str, input_schema: Value) -> DiscoveredFunction {
        DiscoveredFunction {
            name: name.to_string(),
            description: description.to_string(),
            input_schema,
        }
    }

    /// A stdio (`npx`) config named `github` with the given security rules and
    /// description override.
    fn stdio_config(security: MCPWrapperSecurity, description: Option<&str>) -> MCPWrapperConfig {
        MCPWrapperConfig {
            name: "github".to_string(),
            transport: MCPWrapperTransport::Stdio {
                command: "npx".to_string(),
                args: vec![],
            },
            env: HashMap::new(),
            startup_timeout_ms: 30000,
            security,
            description: description.map(str::to_string),
            views: HashMap::new(),
        }
    }

    #[test]
    fn test_mcp_wrapper_config_deserialize_stdio() {
        let yaml = r#"
name: github
type: mcp
transport: stdio
command: npx
args: ["-y", "@modelcontextprotocol/server-github"]
env:
  GITHUB_TOKEN: "test-token"
startup_timeout_ms: 15000
security:
  blocked_functions: [delete_repo]
  hitl_functions: [create_issue]
"#;
        let config: MCPWrapperConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.name, "github");
        assert_eq!(config.startup_timeout_ms, 15000);
        assert_eq!(config.security.blocked_functions, vec!["delete_repo"]);
        assert_eq!(config.security.hitl_functions, vec!["create_issue"]);
        assert_eq!(
            config.env.get("GITHUB_TOKEN").map(String::as_str),
            Some("test-token")
        );
        match &config.transport {
            MCPWrapperTransport::Stdio { command, args } => {
                assert_eq!(command, "npx");
                assert_eq!(args, &["-y", "@modelcontextprotocol/server-github"]);
            }
            other => panic!("expected stdio transport, got {other:?}"),
        }
    }

    #[test]
    fn test_mcp_wrapper_config_deserialize_http() {
        let yaml = r#"
name: custom_api
type: mcp
transport: http
url: "http://localhost:3000/mcp"
headers:
  Authorization: "Bearer test"
"#;
        let config: MCPWrapperConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.name, "custom_api");
        match &config.transport {
            MCPWrapperTransport::Http { url, headers } => {
                assert_eq!(url, "http://localhost:3000/mcp");
                assert_eq!(
                    headers.get("Authorization").map(String::as_str),
                    Some("Bearer test")
                );
            }
            other => panic!("expected http transport, got {other:?}"),
        }
    }

    #[test]
    fn test_mcp_wrapper_config_deserialize_sse() {
        let yaml = r#"
name: remote
transport: sse
url: "http://localhost:3000/sse"
"#;
        let config: MCPWrapperConfig = serde_yaml::from_str(yaml).unwrap();
        match &config.transport {
            MCPWrapperTransport::Sse { url, headers } => {
                assert_eq!(url, "http://localhost:3000/sse");
                assert!(headers.is_empty());
            }
            other => panic!("expected sse transport, got {other:?}"),
        }
    }

    #[test]
    fn test_build_schema() {
        let functions = vec![
            func(
                "create_issue",
                "Create a new issue",
                json!({
                    "type": "object",
                    "properties": {
                        "repo": {"type": "string"},
                        "title": {"type": "string"},
                        "body": {"type": "string"}
                    },
                    "required": ["repo", "title"]
                }),
            ),
            func(
                "list_repos",
                "List repositories",
                json!({
                    "type": "object",
                    "properties": {
                        "org": {"type": "string"}
                    }
                }),
            ),
        ];

        let schema = MCPWrapperTool::build_schema("github", &functions);

        assert_eq!(schema["type"], "object");
        assert!(
            schema["required"]
                .as_array()
                .unwrap()
                .contains(&json!("function"))
        );
        let func_enum = schema["properties"]["function"]["enum"].as_array().unwrap();
        assert!(func_enum.contains(&json!("create_issue")));
        assert!(func_enum.contains(&json!("list_repos")));

        let params_desc = schema["properties"]["params"]["description"]
            .as_str()
            .unwrap();
        assert!(params_desc.contains("create_issue:"));
        assert!(params_desc.contains("org"));
    }

    #[test]
    fn test_build_description() {
        let functions = vec![
            func("create_issue", "Create a new issue", json!({})),
            func("list_repos", "List repositories", json!({})),
        ];

        let desc = MCPWrapperTool::build_description("github", None, &functions);

        assert!(desc.contains("github operations via MCP"));
        assert!(desc.contains("Use tool 'github'"));
        assert!(desc.contains("create_issue"));
        assert!(desc.contains("list_repos"));
        assert!(desc.contains("Create a new issue"));
    }

    #[test]
    fn test_large_function_sets_use_compact_text() {
        // 31 functions exceeds both the schema (30) and description (20) limits.
        let functions: Vec<DiscoveredFunction> = (0..31)
            .map(|i| {
                func(
                    &format!("fn_{i}"),
                    "does things",
                    json!({"type": "object", "properties": {"x": {"type": "string"}}}),
                )
            })
            .collect();

        let schema = MCPWrapperTool::build_schema("big", &functions);
        assert_eq!(
            schema["properties"]["params"]["description"],
            "Parameters for the selected function. See function list for details."
        );

        let desc = MCPWrapperTool::build_description("big", None, &functions);
        assert!(desc.contains("fn_30"));
        assert!(!desc.contains("Function details"));
    }

    #[test]
    fn test_requires_hitl() {
        let security = MCPWrapperSecurity {
            blocked_functions: vec![],
            hitl_functions: vec!["create_issue".to_string()],
        };
        let tool = MCPWrapperTool::new(stdio_config(security, None));

        assert!(tool.requires_hitl("create_issue"));
        assert!(!tool.requires_hitl("list_repos"));
    }

    #[test]
    fn test_default_description() {
        let tool = MCPWrapperTool::new(stdio_config(MCPWrapperSecurity::default(), None));
        assert_eq!(tool.description(), "github operations via MCP");
    }

    #[test]
    fn test_custom_description() {
        let tool = MCPWrapperTool::new(stdio_config(
            MCPWrapperSecurity::default(),
            Some("GitHub integration for DevOps"),
        ));
        assert_eq!(tool.description(), "GitHub integration for DevOps");
    }

    #[test]
    fn test_get_functions_filtered_and_names() {
        let mut tool = MCPWrapperTool::new(stdio_config(MCPWrapperSecurity::default(), None));
        tool.functions = vec![
            func("create_issue", "Create a new issue", json!({})),
            func("list_issues", "List issues", json!({})),
            func("search_code", "Search code", json!({})),
        ];

        assert_eq!(tool.function_count(), 3);
        assert_eq!(
            tool.function_names(),
            vec!["create_issue", "list_issues", "search_code"]
        );

        let filtered =
            tool.get_functions_filtered(&["search_code".to_string(), "missing".to_string()]);
        let names: Vec<&str> = filtered.iter().map(|f| f.name.as_str()).collect();
        assert_eq!(names, vec!["search_code"]);
    }

    #[test]
    fn test_view_config_deserialize() {
        let yaml = r#"
functions: [create_issue, list_issues]
description: "Issue management"
"#;
        let config: MCPViewConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.functions, vec!["create_issue", "list_issues"]);
        assert_eq!(config.description.as_deref(), Some("Issue management"));
    }

    #[test]
    fn test_view_config_no_description() {
        let yaml = r#"
functions: [search_code, get_pull_request]
"#;
        let config: MCPViewConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.functions, vec!["search_code", "get_pull_request"]);
        assert!(config.description.is_none());
    }

    #[test]
    fn test_mcp_config_with_views() {
        let yaml = r#"
name: github
type: mcp
transport: stdio
command: npx
args: ["-y", "@modelcontextprotocol/server-github"]
views:
  github_issues:
    functions: [create_issue, list_issues]
  github_code:
    functions: [search_code]
    description: "Code search"
"#;
        let config: MCPWrapperConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.views.len(), 2);
        assert_eq!(
            config.views["github_issues"].functions,
            vec!["create_issue", "list_issues"]
        );
        assert_eq!(
            config.views["github_code"].description.as_deref(),
            Some("Code search")
        );
    }

    #[test]
    fn test_mcp_config_without_views() {
        let yaml = r#"
name: github
type: mcp
transport: stdio
command: npx
args: ["-y", "@modelcontextprotocol/server-github"]
"#;
        let config: MCPWrapperConfig = serde_yaml::from_str(yaml).unwrap();
        assert!(config.views.is_empty());
        // Omitted optional keys fall back to their serde defaults.
        assert_eq!(config.startup_timeout_ms, 30_000);
        assert!(config.env.is_empty());
        assert!(config.security.blocked_functions.is_empty());
        assert!(config.security.hitl_functions.is_empty());
        assert!(config.description.is_none());
    }

    #[test]
    fn test_tool_entry_mcp_with_views() {
        let yaml = r#"
name: github
type: mcp
transport: stdio
command: npx
args: ["-y", "@modelcontextprotocol/server-github"]
env:
  GITHUB_TOKEN: "test"
views:
  github_issues:
    functions: [create_issue, list_issues]
  github_code:
    functions: [search_code]
    description: "Code search"
"#;
        let config: MCPWrapperConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.views.len(), 2);
        assert_eq!(
            config.views["github_issues"].functions,
            vec!["create_issue", "list_issues"]
        );
        assert_eq!(
            config.env.get("GITHUB_TOKEN").map(String::as_str),
            Some("test")
        );
    }

    #[test]
    fn test_view_schema_filtered() {
        let functions = vec![
            func(
                "create_issue",
                "Create a new issue",
                json!({
                    "type": "object",
                    "properties": {
                        "repo": {"type": "string"},
                        "title": {"type": "string"}
                    }
                }),
            ),
            func(
                "list_issues",
                "List issues",
                json!({
                    "type": "object",
                    "properties": {
                        "repo": {"type": "string"}
                    }
                }),
            ),
        ];

        let schema = MCPWrapperTool::build_schema("github_issues", &functions);
        let func_enum = schema["properties"]["function"]["enum"].as_array().unwrap();
        assert_eq!(func_enum.len(), 2);
        assert!(func_enum.contains(&json!("create_issue")));
        assert!(func_enum.contains(&json!("list_issues")));
    }

    #[test]
    fn test_view_description_custom() {
        let functions = vec![func("create_issue", "Create a new issue", json!({}))];

        let desc = MCPWrapperTool::build_description(
            "github_issues",
            Some("Issue management"),
            &functions,
        );
        assert!(desc.starts_with("Issue management"));
        assert!(desc.contains("create_issue"));
    }
}