Skip to main content

rab/extensions/mcp/
mod.rs

1//! MCP (Model Context Protocol) extension.
2//!
3//! Provides:
4//! - A unified `mcp` proxy tool (status, list, search, describe, connect, call, auth)
5//! - Direct tool adapters for servers with `directTools` enabled
6//! - Tool renderers for both proxy and direct MCP tools
7//!
8//! Mirrors pi-mcp-adapter's architecture but adapted to Rust/yoagent patterns.
9
10mod cache;
11mod config;
12mod renderer;
13pub mod server;
14pub mod types;
15
16use crate::agent::extension::{Extension, ToolDefinition};
17use cache::{has_valid_cache, load_cache, update_cache_entry};
18use renderer::{McpProxyToolRenderer, McpToolRenderer};
19use server::ServerManager;
20use std::borrow::Cow;
21use std::collections::HashMap;
22use std::path::Path;
23use std::sync::Arc;
24use tokio::sync::Mutex;
25use types::format_tool_name;
26use yoagent::mcp::types::{McpContent, McpToolInfo};
27use yoagent::types::{AgentTool, Content, ToolContext, ToolError, ToolResult};
28
29// ── Re-exports for external use ────────────────────────────────────
30pub use cache::{load_cache as load_metadata_cache, save_cache as save_metadata_cache};
31pub use config::load_mcp_config;
32pub use types::{CachedTool, McpConfig, McpSettings, MetadataCache, ServerCacheEntry, ServerEntry};
33
/// Maximum number of results returned by `mcp search`.
///
/// The search iterator stops after this many hits; further matches are not
/// reported in the result text or details.
const MAX_SEARCH_RESULTS: usize = 30;
36
37// ═══════════════════════════════════════════════════════════════════
38// MCP Extension
39// ═══════════════════════════════════════════════════════════════════
40
/// MCP Extension that provides MCP server management.
///
/// Provides:
/// - `mcp` proxy tool (gateway to all configured servers)
/// - Direct tools for servers with `directTools` enabled (optional)
/// - Tool renderers for MCP tool calls/results
///
/// The connection manager is shared (`Arc<Mutex<_>>`) with the proxy tool and
/// every direct tool this extension hands out; the in-memory tool cache is
/// shared with the proxy tool.
pub struct McpExtension {
    /// Loaded MCP configuration (server entries plus optional global settings).
    config: McpConfig,
    /// Connection manager for every configured server; cloned into each tool.
    manager: Arc<Mutex<ServerManager>>,
    /// Cached tool metadata by server name.
    tool_cache: Arc<Mutex<HashMap<String, Vec<McpToolInfo>>>>,
}
53
54impl McpExtension {
55    /// Create a new MCP extension from a loaded config.
56    pub fn new(config: McpConfig) -> Self {
57        let idle_timeout = config
58            .settings
59            .as_ref()
60            .map(|s| s.idle_timeout)
61            .unwrap_or(10);
62
63        let mut manager = ServerManager::new(idle_timeout);
64
65        // Register all servers from config
66        for (name, entry) in &config.mcp_servers {
67            let config_hash = config::compute_server_config_hash(entry);
68            manager.register(name, entry.clone(), config_hash);
69        }
70
71        Self {
72            config,
73            manager: Arc::new(Mutex::new(manager)),
74            tool_cache: Arc::new(Mutex::new(HashMap::new())),
75        }
76    }
77
78    /// Create from config loaded from disk at the given working directory.
79    pub fn from_cwd(cwd: &Path) -> Self {
80        let config = load_mcp_config(cwd);
81        Self::new(config)
82    }
83
84    /// Restore cached tool metadata from the on-disk cache.
85    /// Should be called once at startup to prime the cache without connecting.
86    pub async fn restore_cache(&self) {
87        let cache = load_cache();
88        let mut tool_cache = self.tool_cache.lock().await;
89        for (server_name, entry) in &cache.servers {
90            let def = self.config.mcp_servers.get(server_name);
91            let ch = def.map(config::compute_server_config_hash).unwrap_or(0);
92            if entry.config_hash != ch {
93                continue;
94            }
95            let tools: Vec<McpToolInfo> = entry
96                .tools
97                .iter()
98                .map(|t| McpToolInfo {
99                    name: t.name.clone(),
100                    description: t.description.clone(),
101                    input_schema: if t.input_schema.is_null() {
102                        serde_json::json!({"type": "object", "properties": {}})
103                    } else {
104                        t.input_schema.clone()
105                    },
106                })
107                .collect();
108            if !tools.is_empty() {
109                tool_cache.insert(server_name.clone(), tools);
110            }
111        }
112    }
113
114    /// Bootstrap direct tools — checks which servers have directTools configured
115    /// but no cached metadata yet, and logs a hint. Does NOT block startup on
116    /// network connections. The first connection (via `mcp({{ connect: ... }})`
117    /// or `mcp({{ server: ... }})`) populates the cache; on subsequent startups
118    /// direct tools are available automatically.
119    pub async fn bootstrap_direct_tools(&self) {
120        let global_direct_tools = self
121            .config
122            .settings
123            .as_ref()
124            .is_some_and(|s| s.direct_tools);
125        let missing_cache: Vec<String> = self
126            .config
127            .mcp_servers
128            .iter()
129            .filter(|(server_name, entry)| {
130                let has_direct = match entry.direct_tools.as_ref() {
131                    Some(v) if v.is_boolean() => v.as_bool().unwrap_or(false),
132                    Some(v) if v.is_array() => true,
133                    None => global_direct_tools,
134                    Some(_) => false,
135                };
136                if !has_direct {
137                    return false;
138                }
139                let config_hash = config::compute_server_config_hash(entry);
140                !has_valid_cache(server_name, config_hash)
141            })
142            .map(|(name, _)| name.clone())
143            .collect();
144
145        if !missing_cache.is_empty() {
146            eprintln!(
147                "MCP: direct tools configured for {} but no cached metadata yet. \
148                 Connect once via the mcp proxy tool, then restart.",
149                missing_cache.join(", ")
150            );
151        }
152    }
153}
154
155impl Extension for McpExtension {
156    fn name(&self) -> Cow<'static, str> {
157        "mcp".into()
158    }
159
160    fn tools(&self) -> Vec<ToolDefinition> {
161        let mut tools: Vec<ToolDefinition> = Vec::new();
162
163        // The proxy mcp tool is always available
164        tools.push(ToolDefinition {
165            tool: Box::new(McpProxyTool {
166                config: self.config.clone(),
167                manager: self.manager.clone(),
168                tool_cache: self.tool_cache.clone(),
169            }),
170            snippet: "MCP gateway - connect to MCP servers and call their tools. Non-MCP Pi tools should be called directly, not through mcp.",
171            guidelines: &[
172                "Use mcp to connect to external MCP tool servers",
173                "Direct tools for configured servers can be called directly without mcp",
174                "The proxy tool handles connect, list, search, describe, and call operations",
175            ],
176            prepare_arguments: None,
177            before_tool_call: None,
178            after_tool_call: None,
179            renderer: Some(std::sync::Arc::new(McpProxyToolRenderer)),
180        });
181
182        // Add direct tools for servers with directTools enabled.
183        // Per-server directTools takes precedence; falls back to global setting.
184        let global_direct_tools = self
185            .config
186            .settings
187            .as_ref()
188            .is_some_and(|s| s.direct_tools);
189        let cache = load_cache();
190        let prefix_mode = self
191            .config
192            .settings
193            .as_ref()
194            .map(|s| s.tool_prefix.as_str())
195            .unwrap_or("server");
196
197        for (server_name, entry) in &self.config.mcp_servers {
198            let direct = entry.direct_tools.as_ref();
199            let has_direct = match direct {
200                Some(v) if v.is_boolean() => v.as_bool().unwrap_or(false),
201                Some(v) if v.is_array() => true,
202                None => global_direct_tools,
203                Some(_) => false,
204            };
205
206            if !has_direct {
207                continue;
208            }
209
210            // Collect tool names for this server.
211            // When directTools is an array, use those names directly (no cache needed).
212            // When directTools is true, fall back to cached metadata.
213            let tool_names: Vec<String> = match direct {
214                Some(v) if v.is_array() => v
215                    .as_array()
216                    .unwrap()
217                    .iter()
218                    .filter_map(|s| s.as_str().map(String::from))
219                    .collect(),
220                _ => {
221                    // Need cache to know tool names
222                    let config_hash = config::compute_server_config_hash(entry);
223                    if !has_valid_cache(server_name, config_hash) {
224                        continue;
225                    }
226                    cache
227                        .servers
228                        .get(server_name)
229                        .map(|s| &s.tools)
230                        .into_iter()
231                        .flatten()
232                        .map(|ct| ct.name.clone())
233                        .collect()
234                }
235            };
236
237            if tool_names.is_empty() {
238                continue;
239            }
240
241            // Look up cached metadata for descriptions/schemas if available
242            let cached_tools: Vec<&CachedTool> = cache
243                .servers
244                .get(server_name)
245                .map(|s| s.tools.iter().collect())
246                .unwrap_or_default();
247            for tool_name in &tool_names {
248                let prefixed = format_tool_name(tool_name, server_name, prefix_mode);
249
250                // Use cached metadata if available, otherwise provide defaults
251                let (description, input_schema) = cached_tools
252                    .iter()
253                    .find(|ct| ct.name == *tool_name)
254                    .map(|ct| {
255                        let desc = ct
256                            .description
257                            .clone()
258                            .unwrap_or_else(|| "MCP tool".to_string());
259                        let schema = if ct.input_schema.is_null() {
260                            serde_json::json!({"type": "object", "properties": {}})
261                        } else {
262                            ct.input_schema.clone()
263                        };
264                        (desc, schema)
265                    })
266                    .unwrap_or_else(|| {
267                        (
268                            format!("MCP tool: {} on {}", tool_name, server_name),
269                            serde_json::json!({"type": "object", "properties": {}}),
270                        )
271                    });
272
273                tools.push(ToolDefinition {
274                    tool: Box::new(McpDirectTool {
275                        server_name: server_name.clone(),
276                        original_name: tool_name.clone(),
277                        display_name: prefixed.clone(),
278                        description,
279                        input_schema,
280                        manager: self.manager.clone(),
281                    }),
282                    snippet: "MCP direct tool",
283                    guidelines: &[],
284                    prepare_arguments: None,
285                    before_tool_call: None,
286                    after_tool_call: None,
287                    renderer: Some(std::sync::Arc::new(McpToolRenderer::new(&prefixed))),
288                });
289            }
290        }
291
292        tools
293    }
294}
295
296// ═══════════════════════════════════════════════════════════════════
297// Proxy `mcp` Tool
298// ═══════════════════════════════════════════════════════════════════
299
/// The unified `mcp` proxy tool — a gateway to all MCP servers.
///
/// Supports operations:
/// - `{ }` — show status
/// - `{ server: "name" }` — list tools from server
/// - `{ tool: "name", args: '{"key": "val"}' }` — call a tool
/// - `{ connect: "server-name" }` — connect to a server
/// - `{ describe: "tool_name" }` — show tool details
/// - `{ search: "query" }` — search tools by name/description
/// - `{ action: "ui-messages" }` — retrieve UI session messages (stub)
/// - `{ action: "auth-start", server: "name" }` — start OAuth (stub)
/// - `{ action: "auth-complete", server: "name", args: '{"redirectUrl":"..."}' }` — complete OAuth (stub)
struct McpProxyTool {
    /// Copy of the extension's MCP configuration, cloned when the tool list is built.
    config: McpConfig,
    /// Connection manager shared with the owning `McpExtension`.
    manager: Arc<Mutex<ServerManager>>,
    /// In-memory tool metadata by server name, shared with the owning extension.
    /// It is refreshed whenever this tool lists tools for a connected server.
    tool_cache: Arc<Mutex<HashMap<String, Vec<McpToolInfo>>>>,
}
317
318impl McpProxyTool {
319    /// Ensure a server is connected (lazy connect).
320    async fn ensure_connected(&self, name: &str) -> bool {
321        let mut manager = self.manager.lock().await;
322        manager.ensure_connected(name).await
323    }
324
325    /// Cache the tools for a server after successful connection.
326    async fn cache_tools(&self, server_name: &str) {
327        let manager = self.manager.lock().await;
328        let client = manager.get_client(server_name);
329        drop(manager);
330
331        if let Some(client) = client {
332            let client = client.lock().await;
333            if let Ok(tools) = client.list_tools().await {
334                let config_hash = self
335                    .config
336                    .mcp_servers
337                    .get(server_name)
338                    .map(config::compute_server_config_hash)
339                    .unwrap_or(0);
340
341                let mut tool_cache = self.tool_cache.lock().await;
342                tool_cache.insert(server_name.to_string(), tools.clone());
343                drop(tool_cache);
344
345                update_cache_entry(server_name, config_hash, &tools);
346            }
347        }
348    }
349
350    /// Call a tool on a connected server.
351    async fn call_tool(
352        &self,
353        server_name: &str,
354        tool_name: &str,
355        args: serde_json::Value,
356    ) -> Result<(Vec<Content>, bool), String> {
357        let manager = self.manager.lock().await;
358        let client = manager.get_client(server_name);
359        drop(manager);
360
361        let client = match client {
362            Some(c) => c,
363            None => return Err(format!("Server '{}' is not connected", server_name)),
364        };
365
366        let client = client.lock().await;
367        let result = client
368            .call_tool(tool_name, args)
369            .await
370            .map_err(|e| format!("MCP call failed: {}", e))?;
371
372        let is_error = result.is_error;
373        let content: Vec<Content> = result
374            .content
375            .into_iter()
376            .map(|c| match c {
377                McpContent::Text { text } => Content::Text { text },
378                McpContent::Image { data, mime_type } => Content::Image { data, mime_type },
379            })
380            .collect();
381
382        Ok((content, is_error))
383    }
384
385    /// Format search results as a text response.
386    fn format_search_results(query: &str, matches: &[(String, McpToolInfo)]) -> String {
387        let mut text = format!(
388            "Found {} tool{} matching \"{}\":\n\n",
389            matches.len(),
390            if matches.len() == 1 { "" } else { "s" },
391            query
392        );
393
394        for (server_name, tool) in matches {
395            text.push_str(&format!(
396                "{} @ {}\n  {}\n",
397                tool.name,
398                server_name,
399                tool.description.as_deref().unwrap_or("(no description)")
400            ));
401
402            let schema = &tool.input_schema;
403            if !schema.is_null()
404                && schema.is_object()
405                && let Some(props) = schema.get("properties").and_then(|p| p.as_object())
406                && !props.is_empty()
407            {
408                let required: std::collections::HashSet<&str> = schema
409                    .get("required")
410                    .and_then(|r| r.as_array())
411                    .map(|arr| arr.iter().filter_map(|v| v.as_str()).collect())
412                    .unwrap_or_default();
413
414                text.push_str("    Parameters:\n");
415                for (prop_name, prop_schema) in props {
416                    let is_req = required.contains(prop_name.as_str());
417                    let type_str = prop_schema
418                        .get("type")
419                        .and_then(|t| t.as_str())
420                        .unwrap_or("any");
421                    let desc = prop_schema
422                        .get("description")
423                        .and_then(|d| d.as_str())
424                        .unwrap_or("");
425                    text.push_str(&format!(
426                        "    - {} ({}){} {}\n",
427                        prop_name,
428                        type_str,
429                        if is_req { " *required*" } else { "" },
430                        if desc.is_empty() {
431                            String::new()
432                        } else {
433                            format!("- {}", desc)
434                        }
435                    ));
436                }
437            }
438            text.push('\n');
439        }
440
441        text.trim().to_string()
442    }
443
444    /// Execute the status operation: list all configured servers and their status.
445    async fn execute_status(&self) -> ToolResult {
446        let manager = self.manager.lock().await;
447        let tool_cache = self.tool_cache.lock().await;
448
449        let mut lines = vec![format!(
450            "MCP: {} servers configured",
451            self.config.mcp_servers.len()
452        )];
453        lines.push(String::new());
454
455        for name in self.config.mcp_servers.keys() {
456            let status = manager.status(name);
457            let tool_count = tool_cache.get(name).map(|v| v.len()).unwrap_or(0);
458            let status_str = match status {
459                Some(server::ConnectionStatus::Connected) => "✓ connected",
460                Some(server::ConnectionStatus::Idle) => "○ idle",
461                Some(server::ConnectionStatus::Failed) => "✗ failed",
462                None => "○ not connected",
463            };
464            lines.push(format!("{} {} ({} tools)", status_str, name, tool_count));
465        }
466
467        if !self.config.mcp_servers.is_empty() {
468            lines.push(String::new());
469            lines.push(
470                "mcp({ server: \"name\" }) to list tools, mcp({ search: \"...\" }) to search"
471                    .to_string(),
472            );
473        }
474
475        ToolResult {
476            content: vec![Content::Text {
477                text: lines.join("\n"),
478            }],
479            details: serde_json::json!({"mode": "status"}),
480        }
481    }
482
483    /// Execute the list operation: list tools for a specific server.
484    async fn execute_list(&self, server_name: &str) -> ToolResult {
485        // Ensure connected
486        let connected = {
487            let mut manager = self.manager.lock().await;
488            manager.ensure_connected(server_name).await
489        };
490
491        if connected {
492            // Cache tools after connecting
493            let manager = self.manager.lock().await;
494            let client = manager.get_client(server_name);
495            drop(manager);
496
497            if let Some(client) = client {
498                let client = client.lock().await;
499                if let Ok(tools) = client.list_tools().await {
500                    let config_hash = self
501                        .config
502                        .mcp_servers
503                        .get(server_name)
504                        .map(config::compute_server_config_hash)
505                        .unwrap_or(0);
506
507                    let mut tool_cache = self.tool_cache.lock().await;
508                    tool_cache.insert(server_name.to_string(), tools.clone());
509                    drop(tool_cache);
510
511                    update_cache_entry(server_name, config_hash, &tools);
512                }
513            }
514        }
515
516        let tool_cache = self.tool_cache.lock().await;
517        let tools = tool_cache.get(server_name);
518
519        match tools {
520            Some(tool_list) if !tool_list.is_empty() => {
521                let mut lines = vec![format!("{} ({} tools):", server_name, tool_list.len())];
522                lines.push(String::new());
523
524                for tool in tool_list {
525                    let desc = tool.description.as_deref().unwrap_or("");
526                    let truncated = if desc.len() > 80 {
527                        format!("{}...", &desc[..77])
528                    } else {
529                        desc.to_string()
530                    };
531                    lines.push(format!("- {}", tool.name));
532                    if !truncated.is_empty() {
533                        lines.push(format!("  {}", truncated));
534                    }
535                }
536
537                ToolResult {
538                    content: vec![Content::Text {
539                        text: lines.join("\n"),
540                    }],
541                    details: serde_json::json!({"mode": "list", "server": server_name, "tools": tool_list.len()}),
542                }
543            }
544            _ => {
545                if self.config.mcp_servers.contains_key(server_name) {
546                    ToolResult {
547                        content: vec![Content::Text {
548                            text: format!(
549                                "Server \"{}\" has no tools (or hasn't been connected yet). Use mcp({{ connect: \"{}\" }}) to connect.",
550                                server_name, server_name
551                            ),
552                        }],
553                        details: serde_json::json!({"mode": "list", "error": "no_tools", "server": server_name}),
554                    }
555                } else {
556                    ToolResult {
557                        content: vec![Content::Text {
558                            text: format!(
559                                "Server \"{}\" not found. Use mcp({{}}) to see available servers.",
560                                server_name
561                            ),
562                        }],
563                        details: serde_json::json!({"mode": "list", "error": "not_found"}),
564                    }
565                }
566            }
567        }
568    }
569
570    /// Execute the search operation.
571    async fn execute_search(
572        &self,
573        query: &str,
574        regex: bool,
575        filter_server: Option<&str>,
576    ) -> ToolResult {
577        let tool_cache = self.tool_cache.lock().await;
578        let query_lower = query.to_lowercase();
579
580        let matches: Vec<(String, McpToolInfo)> = tool_cache
581            .iter()
582            .filter(|(server_name, _)| filter_server.is_none_or(|s| server_name.as_str() == s))
583            .flat_map(|(server_name, tools)| {
584                let ql = query_lower.clone();
585                tools.iter().filter_map(move |tool| {
586                    let name_match = if regex {
587                        regex::Regex::new(query)
588                            .ok()
589                            .is_some_and(|re| re.is_match(&tool.name))
590                    } else {
591                        tool.name.to_lowercase().contains(&ql)
592                    };
593
594                    let desc_match = tool.description.as_ref().is_some_and(|desc| {
595                        if regex {
596                            regex::Regex::new(query)
597                                .ok()
598                                .is_some_and(|re| re.is_match(desc))
599                        } else {
600                            desc.to_lowercase().contains(&ql)
601                        }
602                    });
603
604                    if name_match || desc_match {
605                        Some((server_name.clone(), tool.clone()))
606                    } else {
607                        None
608                    }
609                })
610            })
611            .take(MAX_SEARCH_RESULTS)
612            .collect();
613
614        drop(tool_cache);
615
616        if matches.is_empty() {
617            return ToolResult {
618                content: vec![Content::Text {
619                    text: format!("No tools matching \"{}\"", query),
620                }],
621                details: serde_json::json!({"mode": "search", "matches": [], "query": query}),
622            };
623        }
624
625        let text = McpProxyTool::format_search_results(query, &matches);
626        let count = matches.len();
627        ToolResult {
628            content: vec![Content::Text { text }],
629            details: serde_json::json!({
630                "mode": "search",
631                "matches": matches.iter().map(|(s, t)| serde_json::json!({"server": s, "tool": t.name})).collect::<Vec<_>>(),
632                "count": count,
633                "query": query,
634            }),
635        }
636    }
637
638    /// Execute the describe operation.
639    async fn execute_describe(&self, tool_name: &str) -> ToolResult {
640        let tool_cache = self.tool_cache.lock().await;
641
642        for (server_name, tools) in tool_cache.iter() {
643            for tool in tools {
644                if tool.name == tool_name {
645                    let prefix = self
646                        .config
647                        .settings
648                        .as_ref()
649                        .map(|s| s.tool_prefix.as_str())
650                        .unwrap_or("server");
651                    let full_name = format_tool_name(&tool.name, server_name, prefix);
652
653                    let mut lines = vec![
654                        full_name,
655                        format!("Server: {}", server_name),
656                        String::new(),
657                        tool.description
658                            .clone()
659                            .unwrap_or_else(|| "(no description)".to_string()),
660                        String::new(),
661                    ];
662
663                    let schema = &tool.input_schema;
664                    if !schema.is_null() && schema.is_object() {
665                        if let Some(props) = schema.get("properties").and_then(|p| p.as_object()) {
666                            if props.is_empty() {
667                                lines.push("Parameters: (none)".to_string());
668                            } else {
669                                lines.push("Parameters:".to_string());
670                                let required: std::collections::HashSet<&str> = schema
671                                    .get("required")
672                                    .and_then(|r| r.as_array())
673                                    .map(|arr| arr.iter().filter_map(|v| v.as_str()).collect())
674                                    .unwrap_or_default();
675
676                                for (prop_name, prop_schema) in props {
677                                    let type_str = prop_schema
678                                        .get("type")
679                                        .and_then(|t| t.as_str())
680                                        .unwrap_or("any");
681                                    let desc = prop_schema
682                                        .get("description")
683                                        .and_then(|d| d.as_str())
684                                        .unwrap_or("");
685                                    let req = if required.contains(prop_name.as_str()) {
686                                        " *required*"
687                                    } else {
688                                        ""
689                                    };
690                                    lines.push(format!(
691                                        "  - {} ({}){}{}",
692                                        prop_name,
693                                        type_str,
694                                        req,
695                                        if desc.is_empty() {
696                                            String::new()
697                                        } else {
698                                            format!(" - {}", desc)
699                                        }
700                                    ));
701                                }
702                            }
703                        } else {
704                            lines.push("Parameters: (empty schema)".to_string());
705                        }
706                    } else {
707                        lines.push("Parameters: (none)".to_string());
708                    }
709
710                    return ToolResult {
711                        content: vec![Content::Text {
712                            text: lines.join("\n"),
713                        }],
714                        details: serde_json::json!({
715                            "mode": "describe",
716                            "server": server_name,
717                            "tool": tool_name,
718                        }),
719                    };
720                }
721            }
722        }
723
724        ToolResult {
725            content: vec![Content::Text {
726                text: format!(
727                    "Tool \"{}\" not found. Use mcp({{ search: \"...\" }}) to find tools.",
728                    tool_name
729                ),
730            }],
731            details: serde_json::json!({"mode": "describe", "error": "not_found"}),
732        }
733    }
734
735    /// Execute the connect operation.
736    async fn execute_connect(&self, server_name: &str) -> ToolResult {
737        if !self.config.mcp_servers.contains_key(server_name) {
738            return ToolResult {
739                content: vec![Content::Text {
740                    text: format!(
741                        "Server \"{}\" not found. Use mcp({{}}) to see available servers.",
742                        server_name
743                    ),
744                }],
745                details: serde_json::json!({"mode": "connect", "error": "not_found"}),
746            };
747        }
748
749        let connected = self.ensure_connected(server_name).await;
750        if connected {
751            self.cache_tools(server_name).await;
752
753            // Touch the server to mark it as recently used
754            let mut manager = self.manager.lock().await;
755            manager.touch(server_name);
756            drop(manager);
757
758            // List tools to show results
759            self.execute_list(server_name).await
760        } else {
761            ToolResult {
762                content: vec![Content::Text {
763                    text: format!(
764                        "Failed to connect to \"{}\". Check the server config.",
765                        server_name
766                    ),
767                }],
768                details: serde_json::json!({"mode": "connect", "error": "connect_failed", "server": server_name}),
769            }
770        }
771    }
772
773    /// Execute the tool call operation.
774    async fn execute_call(
775        &self,
776        tool_name: &str,
777        args_str: Option<&str>,
778        server_override: Option<&str>,
779    ) -> ToolResult {
780        // Parse args JSON if provided
781        let parsed_args: serde_json::Value = args_str
782            .and_then(|s| serde_json::from_str(s).ok())
783            .unwrap_or(serde_json::json!({}));
784
785        // Find the server and original tool name
786        let prefix_mode = self
787            .config
788            .settings
789            .as_ref()
790            .map(|s| s.tool_prefix.as_str())
791            .unwrap_or("server");
792
793        let (server_name, original_name) = if let Some(srv) = server_override {
794            // Server specified — lookup tool by original name
795            (srv.to_string(), tool_name.to_string())
796        } else {
797            // No server — search all
798            let tool_cache = self.tool_cache.lock().await;
799            let mut found = None;
800            for (srv, tools) in tool_cache.iter() {
801                for tool in tools {
802                    let prefixed = format_tool_name(&tool.name, srv, prefix_mode);
803                    if prefixed == tool_name || tool.name == tool_name {
804                        found = Some((srv.clone(), tool.name.clone()));
805                        break;
806                    }
807                }
808                if found.is_some() {
809                    break;
810                }
811            }
812            match found {
813                Some(f) => f,
814                None => {
815                    return ToolResult {
816                        content: vec![Content::Text {
817                            text: format!(
818                                "Tool \"{}\" not found. Use mcp({{ search: \"...\" }}) to find tools.",
819                                tool_name
820                            ),
821                        }],
822                        details: serde_json::json!({"mode": "call", "error": "tool_not_found"}),
823                    };
824                }
825            }
826        };
827
828        // Ensure connected
829        if !self.ensure_connected(&server_name).await {
830            return ToolResult {
831                content: vec![Content::Text {
832                    text: format!(
833                        "Server \"{}\" is not available. Use mcp({{ connect: \"{}\" }}) to connect.",
834                        server_name, server_name
835                    ),
836                }],
837                details: serde_json::json!({"mode": "call", "error": "server_unavailable"}),
838            };
839        }
840
841        // Touch the server
842        {
843            let mut manager = self.manager.lock().await;
844            manager.touch(&server_name);
845        }
846
847        // Call the tool
848        match self
849            .call_tool(&server_name, &original_name, parsed_args)
850            .await
851        {
852            Ok((content, is_error)) => {
853                let text: String = content
854                    .iter()
855                    .filter_map(|c| {
856                        if let Content::Text { text } = c {
857                            Some(text.as_str())
858                        } else {
859                            None
860                        }
861                    })
862                    .collect::<Vec<_>>()
863                    .join("\n");
864
865                if is_error {
866                    ToolResult {
867                        content: vec![Content::Text {
868                            text: format!("Error: {}", text),
869                        }],
870                        details: serde_json::json!({"mode": "call", "error": "tool_error", "server": server_name}),
871                    }
872                } else {
873                    ToolResult {
874                        content: vec![Content::Text { text }],
875                        details: serde_json::json!({"mode": "call", "server": server_name, "tool": original_name}),
876                    }
877                }
878            }
879            Err(e) => ToolResult {
880                content: vec![Content::Text {
881                    text: format!("Failed to call tool: {}", e),
882                }],
883                details: serde_json::json!({"mode": "call", "error": "call_failed", "server": server_name}),
884            },
885        }
886    }
887}
888
889#[async_trait::async_trait]
890impl AgentTool for McpProxyTool {
891    fn name(&self) -> &str {
892        "mcp"
893    }
894
895    fn label(&self) -> &str {
896        "mcp"
897    }
898
899    fn description(&self) -> &str {
900        "MCP gateway - connect to MCP servers and call their tools. Non-MCP Pi tools should be called directly, not through mcp.\n\n\
901         Direct tools available (call as normal tools): varies by configuration\n\n\
902         Servers: varies by configuration\n\n\
903         Usage:\n\
904           mcp({ })                              → Show server status\n\
905           mcp({ server: \"name\" })               → List tools from server\n\
906           mcp({ search: \"query\" })              → Search MCP tools by name/description\n\
907           mcp({ describe: \"tool_name\" })        → Show tool details and parameters\n\
908           mcp({ connect: \"server-name\" })       → Connect to a server and refresh metadata\n\
909           mcp({ tool: \"name\", args: '{\"key\": \"value\"}' })    → Call a tool (args is JSON string)\n\
910           mcp({ action: \"ui-messages\" })        → Retrieve accumulated messages from completed UI sessions\n\
911           mcp({ action: \"auth-start\", server: \"name\" })      → Start manual OAuth and get a browser URL\n\
912           mcp({ action: \"auth-complete\", server: \"name\", args: '{\"redirectUrl\":\"...\"}' }) → Complete OAuth\n\n\
913         Mode: action > tool (call) > connect > describe > search > server (list) > nothing (status)"
914    }
915
916    fn parameters_schema(&self) -> serde_json::Value {
917        serde_json::json!({
918            "type": "object",
919            "properties": {
920                "server": {
921                    "type": "string",
922                    "description": "Server name for listing tools"
923                },
924                "tool": {
925                    "type": "string",
926                    "description": "Tool name to call"
927                },
928                "args": {
929                    "type": "string",
930                    "description": "JSON string of arguments for the tool call"
931                },
932                "connect": {
933                    "type": "string",
934                    "description": "Server name to connect to"
935                },
936                "describe": {
937                    "type": "string",
938                    "description": "Tool name to describe"
939                },
940                "search": {
941                    "type": "string",
942                    "description": "Search query for finding tools"
943                },
944                "regex": {
945                    "type": "boolean",
946                    "description": "Treat search query as regex (default: false)"
947                },
948                "includeSchemas": {
949                    "type": "boolean",
950                    "description": "Include parameter schemas in search results (default: true)"
951                },
952                "action": {
953                    "type": "string",
954                    "description": "Action: 'ui-messages', 'auth-start', or 'auth-complete'"
955                }
956            }
957        })
958    }
959
960    async fn execute(
961        &self,
962        params: serde_json::Value,
963        _ctx: ToolContext,
964    ) -> Result<ToolResult, ToolError> {
965        // Determine mode: action > tool (call) > connect > describe > search > server (list) > nothing (status)
966        if let Some(action) = params.get("action").and_then(|v| v.as_str()) {
967            match action {
968                "ui-messages" => {
969                    return Ok(ToolResult {
970                        content: vec![Content::Text {
971                            text: "No UI session messages available. (UI sessions not yet implemented in this version.)".to_string(),
972                        }],
973                        details: serde_json::json!({"mode": "ui-messages", "sessions": 0}),
974                    });
975                }
976                "auth-start" => {
977                    let server_name = params.get("server").and_then(|v| v.as_str()).unwrap_or("");
978                    if server_name.is_empty() {
979                        return Err(ToolError::InvalidArgs(
980                            "Missing 'server' argument for auth-start action".into(),
981                        ));
982                    }
983                    return Ok(ToolResult {
984                        content: vec![Content::Text {
985                            text: format!(
986                                "OAuth authentication for \"{}\" is not yet implemented in this version. \
987                                 Please start the server manually and configure authentication.",
988                                server_name
989                            ),
990                        }],
991                        details: serde_json::json!({"mode": "auth-start", "error": "not_implemented"}),
992                    });
993                }
994                "auth-complete" => {
995                    let server_name = params.get("server").and_then(|v| v.as_str()).unwrap_or("");
996                    if server_name.is_empty() {
997                        return Err(ToolError::InvalidArgs(
998                            "Missing 'server' argument for auth-complete action".into(),
999                        ));
1000                    }
1001                    return Ok(ToolResult {
1002                        content: vec![Content::Text {
1003                            text: format!(
1004                                "OAuth completion for \"{}\" is not yet implemented in this version.",
1005                                server_name
1006                            ),
1007                        }],
1008                        details: serde_json::json!({"mode": "auth-complete", "error": "not_implemented"}),
1009                    });
1010                }
1011                _ => {
1012                    return Err(ToolError::InvalidArgs(format!(
1013                        "Unknown action '{}'. Supported: ui-messages, auth-start, auth-complete",
1014                        action
1015                    )));
1016                }
1017            }
1018        }
1019
1020        if let Some(tool_name) = params.get("tool").and_then(|v| v.as_str()) {
1021            let args_str = params.get("args").and_then(|v| v.as_str());
1022            let server_override = params.get("server").and_then(|v| v.as_str());
1023            return Ok(self
1024                .execute_call(tool_name, args_str, server_override)
1025                .await);
1026        }
1027
1028        if let Some(server_name) = params.get("connect").and_then(|v| v.as_str()) {
1029            return Ok(self.execute_connect(server_name).await);
1030        }
1031
1032        if let Some(tool_name) = params.get("describe").and_then(|v| v.as_str()) {
1033            return Ok(self.execute_describe(tool_name).await);
1034        }
1035
1036        if let Some(query) = params.get("search").and_then(|v| v.as_str()) {
1037            let regex = params
1038                .get("regex")
1039                .and_then(|v| v.as_bool())
1040                .unwrap_or(false);
1041            let server_filter = params.get("server").and_then(|v| v.as_str());
1042            return Ok(self.execute_search(query, regex, server_filter).await);
1043        }
1044
1045        if let Some(server_name) = params.get("server").and_then(|v| v.as_str()) {
1046            return Ok(self.execute_list(server_name).await);
1047        }
1048
1049        // Default: status
1050        Ok(self.execute_status().await)
1051    }
1052}
1053
1054// ═══════════════════════════════════════════════════════════════════
1055// Direct Tool Adapter
1056// ═══════════════════════════════════════════════════════════════════
1057
/// A direct MCP tool adapter that wraps a remote MCP server tool.
///
/// Exposes one tool of an MCP server as a normal agent tool, so it can be
/// called without going through the `mcp` proxy tool.
struct McpDirectTool {
    /// Name of the MCP server that hosts the tool. Used to connect, touch and
    /// look up the server's client through `manager`.
    server_name: String,
    /// Tool name on the server; this is the name sent in the remote call.
    original_name: String,
    /// Name presented to the agent (returned by both `name()` and `label()`).
    display_name: String,
    /// Tool description presented to the agent.
    description: String,
    /// Parameter schema presented to the agent via `parameters_schema()`.
    input_schema: serde_json::Value,
    /// Shared server manager used to connect and obtain the server's client.
    manager: Arc<Mutex<ServerManager>>,
}
1067
1068#[async_trait::async_trait]
1069impl AgentTool for McpDirectTool {
1070    fn name(&self) -> &str {
1071        &self.display_name
1072    }
1073
1074    fn label(&self) -> &str {
1075        &self.display_name
1076    }
1077
1078    fn description(&self) -> &str {
1079        &self.description
1080    }
1081
1082    fn parameters_schema(&self) -> serde_json::Value {
1083        self.input_schema.clone()
1084    }
1085
1086    async fn execute(
1087        &self,
1088        params: serde_json::Value,
1089        _ctx: ToolContext,
1090    ) -> Result<ToolResult, ToolError> {
1091        let mut manager = self.manager.lock().await;
1092
1093        // Ensure connected
1094        if !manager.ensure_connected(&self.server_name).await {
1095            return Err(ToolError::Failed(format!(
1096                "Server '{}' is not available",
1097                self.server_name
1098            )));
1099        }
1100
1101        manager.touch(&self.server_name);
1102        let client = manager.get_client(&self.server_name).ok_or_else(|| {
1103            ToolError::Failed(format!("Server '{}' has no client", self.server_name))
1104        })?;
1105
1106        drop(manager);
1107
1108        let client = client.lock().await;
1109        let result = client
1110            .call_tool(&self.original_name, params)
1111            .await
1112            .map_err(|e| ToolError::Failed(format!("MCP call failed: {}", e)))?;
1113
1114        if result.is_error {
1115            let error_text = result
1116                .content
1117                .iter()
1118                .filter_map(|c| match c {
1119                    McpContent::Text { text } => Some(text.as_str()),
1120                    _ => None,
1121                })
1122                .collect::<Vec<_>>()
1123                .join("\n");
1124            return Err(ToolError::Failed(error_text));
1125        }
1126
1127        let content: Vec<Content> = result
1128            .content
1129            .into_iter()
1130            .map(|c| match c {
1131                McpContent::Text { text } => Content::Text { text },
1132                McpContent::Image { data, mime_type } => Content::Image { data, mime_type },
1133            })
1134            .collect();
1135
1136        Ok(ToolResult {
1137            content,
1138            details: serde_json::Value::Null,
1139        })
1140    }
1141}