// mermaid_cli/providers/tool/mod.rs

//! Tool executors — one type per tool the model can call.
//!
//! The trait is small: `execute(args, ctx) -> ToolOutcome` for
//! dispatch, plus `name()` and `schema() -> ToolDefinition` for
//! advertising the tool to the model. Everything else (cancellation,
//! progress, identity, workdir) rides inside `ExecContext`.
//!
//! Adding a tool:
//!   1. New file under `src/providers/tool/`.
//!   2. Impl `ToolExecutor` for a struct (usually a unit struct) —
//!      `name`, `schema`, and `execute`.
//!   3. Register it. Always-on tools that need no configuration belong
//!      in `ToolRegistry::default()`; make sure the config-aware
//!      `ToolRegistry::build()` registers them as well. Tools that depend
//!      on config, run mode, or shared state are registered only in
//!      `build()`.
//!
//! Because `schema()` lives on the same trait as `execute()`, the
//! name + JSON schema the model sees cannot drift from the handler
//! that runs when the model calls it. Single source of truth.
17
18pub mod computer_use;
19pub mod exec;
20pub mod filesystem;
21pub mod mcp;
22pub mod subagent;
23pub mod web;
24pub mod web_client;
25
26use async_trait::async_trait;
27use std::collections::HashMap;
28use std::sync::Arc;
29
30use crate::domain::{ToolDefinition, ToolOutcome};
31
32use super::ctx::ExecContext;
33
34/// Implemented by every tool that the model can call. All tools are
35/// `Send + Sync` — they run across tokio `select!` branches inside
36/// the effect runner.
37#[async_trait]
38pub trait ToolExecutor: Send + Sync {
39    /// Canonical name the model uses to call this tool. Matches
40    /// `schema().name` exactly.
41    fn name(&self) -> &'static str;
42
43    /// JSON-schema description the model sees in the outgoing
44    /// request. Adapters translate this into provider-native shape
45    /// (Anthropic's `type: "custom"`, Gemini's `function_declarations`,
46    /// OpenAI's flat `tools`, Ollama's function calling). The same
47    /// `ToolDefinition` feeds all four.
48    fn schema(&self) -> ToolDefinition;
49
50    /// True for tools that exist for internal dispatch only and
51    /// should NOT be advertised to the model (e.g. the MCP proxy
52    /// router, which fronts every `mcp__server__tool` call — the
53    /// individual MCP tools are advertised separately from
54    /// `state.mcp.servers`). Default `false`.
55    fn is_internal(&self) -> bool {
56        false
57    }
58
59    /// Run the tool. The returned `ToolOutcome` is passed verbatim
60    /// into `Msg::ToolFinished` — there's no error-to-outcome
61    /// conversion happening outside this function.
62    async fn execute(&self, args: serde_json::Value, ctx: ExecContext) -> ToolOutcome;
63}
64
65/// Registry of dispatchable tools. Single source of truth for what
66/// the model sees AND what handles a call when the model issues it.
67/// Built once at startup; read-only after that.
68pub struct ToolRegistry {
69    entries: HashMap<&'static str, Arc<dyn ToolExecutor>>,
70}
71
72impl ToolRegistry {
73    pub fn new() -> Self {
74        Self {
75            entries: HashMap::new(),
76        }
77    }
78
79    pub fn register(&mut self, tool: Arc<dyn ToolExecutor>) {
80        self.entries.insert(tool.name(), tool);
81    }
82
83    pub fn get(&self, name: &str) -> Option<Arc<dyn ToolExecutor>> {
84        self.entries.get(name).cloned()
85    }
86
87    pub fn len(&self) -> usize {
88        self.entries.len()
89    }
90
91    pub fn is_empty(&self) -> bool {
92        self.entries.is_empty()
93    }
94
95    pub fn names(&self) -> impl Iterator<Item = &'static str> + '_ {
96        self.entries.keys().copied()
97    }
98
99    /// Emit every user-facing tool's schema, for inclusion in an
100    /// outgoing `ChatRequest.tools`. Effect runner calls this before
101    /// dispatching `Cmd::CallModel` so the model always sees the
102    /// same list the runner can dispatch. Internal routers (the MCP
103    /// proxy) are filtered out.
104    pub fn describe_all(&self) -> Vec<ToolDefinition> {
105        self.entries
106            .values()
107            .filter(|t| !t.is_internal())
108            .map(|t| t.schema())
109            .collect()
110    }
111}
112
113impl Default for ToolRegistry {
114    fn default() -> Self {
115        let mut r = Self::new();
116        r.register(Arc::new(filesystem::ReadFileTool));
117        r.register(Arc::new(filesystem::WriteFileTool));
118        r.register(Arc::new(filesystem::EditFileTool));
119        r.register(Arc::new(filesystem::DeleteFileTool));
120        r.register(Arc::new(filesystem::CreateDirectoryTool));
121        r.register(Arc::new(exec::ExecuteCommandTool));
122        // MCP proxy is the dispatcher for every mcp__server__tool
123        // call; it's internal (not advertised) but MUST be registered
124        // so runtime lookups succeed.
125        r.register(Arc::new(mcp::McpToolProxy));
126        r
127    }
128}
129
/// How the host mermaid process is being driven, which decides which
/// tools are registered.
///
/// Headless runs never get GUI / computer-use tools, even if a display
/// probes as alive: a CI job has nobody watching the screenshots.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TuiMode {
    /// The interactive terminal UI, with a user present.
    Interactive,
    /// One-shot `mermaid run <prompt>` or CI execution.
    Headless,
}
140
141impl ToolRegistry {
142    /// Config-aware factory. Always registers filesystem + exec +
143    /// the MCP proxy + the subagent tool. Conditionally registers:
144    ///
145    ///   - `web_search` + `web_fetch` iff `OLLAMA_API_KEY` resolves
146    ///     (via `utils::resolve_api_key`). Without a key, the tools
147    ///     would error on every call — so we don't advertise them at
148    ///     all.
149    ///   - All seven computer-use tools iff `mode == Interactive`
150    ///     AND `computer_use::probe()` returns a usable backend.
151    ///
152    /// `providers` is the shared `ProviderFactory` that the effect
153    /// runner also holds; the `SubagentSpawner` needs it so child
154    /// reducer loops hit the same provider cache.
155    ///
156    /// Returns `Arc<Self>` so the effect runner can share a handle
157    /// across turns without cloning the underlying HashMap.
158    pub fn build(
159        _config: &crate::app::Config,
160        mode: TuiMode,
161        providers: Arc<crate::providers::ProviderFactory>,
162    ) -> Arc<Self> {
163        let mut r = Self::new();
164        r.register(Arc::new(filesystem::ReadFileTool));
165        r.register(Arc::new(filesystem::WriteFileTool));
166        r.register(Arc::new(filesystem::EditFileTool));
167        r.register(Arc::new(filesystem::DeleteFileTool));
168        r.register(Arc::new(filesystem::CreateDirectoryTool));
169        r.register(Arc::new(exec::ExecuteCommandTool));
170        r.register(Arc::new(mcp::McpToolProxy));
171
172        if let Some(key) = crate::utils::resolve_api_key("OLLAMA_API_KEY", None) {
173            r.register(Arc::new(web::WebSearchTool::new(key.clone())));
174            r.register(Arc::new(web::WebFetchTool::new(key)));
175        }
176
177        // Computer-use tools only register when (a) the process runs
178        // interactively (Headless CI has no user to watch a screenshot)
179        // AND (b) a display backend passes the startup probe. Failed
180        // probe → tools aren't advertised → model can't call them.
181        if mode == TuiMode::Interactive {
182            let backend = computer_use::probe();
183            if backend.is_usable() {
184                let driver = Arc::new(computer_use::ComputerUseDriver::new(backend));
185                r.register(Arc::new(computer_use::ScreenshotTool::new(driver.clone())));
186                r.register(Arc::new(computer_use::ClickTool::new(driver.clone())));
187                r.register(Arc::new(computer_use::TypeTextTool::new(driver.clone())));
188                r.register(Arc::new(computer_use::PressKeyTool::new(driver.clone())));
189                r.register(Arc::new(computer_use::ScrollTool::new(driver.clone())));
190                r.register(Arc::new(computer_use::MouseMoveTool::new(driver.clone())));
191                r.register(Arc::new(computer_use::ListWindowsTool::new(driver)));
192            }
193        }
194
195        // Subagents: always register. Depth + breadth caps live on
196        // `SubagentSpawner`; the tool itself is harmless when nobody
197        // calls it. Headless runs do register the agent — a CI prompt
198        // may still delegate to subagents for batched work.
199        let spawner = Arc::new(subagent::SubagentSpawner::new(providers));
200        r.register(Arc::new(subagent::SubagentTool::new(spawner)));
201
202        Arc::new(r)
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_registry_has_builtin_tools() {
        let r = ToolRegistry::default();
        for name in &[
            "read_file",
            "write_file",
            "edit_file",
            "delete_file",
            "create_directory",
            "execute_command",
        ] {
            assert!(r.get(name).is_some(), "missing: {}", name);
        }
        assert!(r.get("not_a_tool").is_none());
        assert!(r.len() >= 6);
    }

    #[test]
    fn describe_all_returns_one_per_user_facing_tool() {
        let r = ToolRegistry::default();
        let schemas = r.describe_all();
        // mcp_proxy is registered but internal, so describe_all filters it
        // out. len() therefore includes it but the schemas don't.
        let visible = r
            .names()
            .filter(|n| r.get(n).map(|t| !t.is_internal()).unwrap_or(false))
            .count();
        assert_eq!(schemas.len(), visible);
        for schema in &schemas {
            assert!(
                r.get(&schema.name).is_some(),
                "schema for unknown tool: {}",
                schema.name
            );
        }
    }

    #[test]
    fn mcp_proxy_is_registered_but_internal() {
        let r = ToolRegistry::default();
        let proxy = r.get("mcp_proxy").expect("mcp_proxy registered");
        assert!(proxy.is_internal());
        assert!(!r.describe_all().iter().any(|s| s.name == "mcp_proxy"));
    }

    #[test]
    fn schema_name_matches_executor_name() {
        let r = ToolRegistry::default();
        for name in r.names() {
            let tool = r.get(name).unwrap();
            assert_eq!(tool.name(), tool.schema().name.as_str());
        }
    }

    /// Serialization guard for tests that touch the `OLLAMA_API_KEY` env
    /// var. Cargo's default test harness runs tests on parallel threads
    /// inside one process. Without this mutex, two env-touching tests
    /// would race and occasionally flip each other's expectations.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// RAII guard that snapshots `OLLAMA_API_KEY` when created and
    /// restores it on drop.
    ///
    /// Restoration also happens when the test body panics on a failed
    /// assertion, so a failing test cannot leak its fake key into the
    /// rest of the process. `var_os` is used so a non-UTF-8 prior value
    /// is restored rather than dropped.
    ///
    /// Bind it *after* the `ENV_LOCK` guard. Locals drop in reverse
    /// order, so the variable is then restored while the lock is still
    /// held.
    struct OllamaKeyRestore {
        prior: Option<std::ffi::OsString>,
    }

    impl OllamaKeyRestore {
        fn capture() -> Self {
            Self {
                prior: std::env::var_os("OLLAMA_API_KEY"),
            }
        }
    }

    impl Drop for OllamaKeyRestore {
        fn drop(&mut self) {
            // SAFETY: `set_var`/`remove_var` are unsafe because another
            // thread may read the environment concurrently. `ENV_LOCK`
            // serialises every test in this module that touches
            // OLLAMA_API_KEY, and this guard is dropped while it is held.
            // NOTE(review): the lock does not stop unrelated tests on other
            // harness threads from reading the environment; this relies on
            // none doing so concurrently.
            unsafe {
                match self.prior.take() {
                    Some(v) => std::env::set_var("OLLAMA_API_KEY", v),
                    None => std::env::remove_var("OLLAMA_API_KEY"),
                }
            }
        }
    }

    #[test]
    fn build_registers_web_tools_when_key_present() {
        let _lock = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let _restore = OllamaKeyRestore::capture();
        // SAFETY: same reasoning as `OllamaKeyRestore::drop`; ENV_LOCK is held.
        unsafe {
            std::env::set_var("OLLAMA_API_KEY", "test-key-build");
        }
        let cfg = crate::app::Config::default();
        let providers = Arc::new(crate::providers::ProviderFactory::new(cfg.clone()));
        // Headless mode: web-tool registration does not depend on the mode.
        // Headless also skips `computer_use::probe()`, so the test does not
        // depend on the host's display.
        let r = ToolRegistry::build(&cfg, TuiMode::Headless, providers);
        assert!(r.get("web_search").is_some(), "web_search registered");
        assert!(r.get("web_fetch").is_some(), "web_fetch registered");
    }

    #[test]
    fn build_skips_web_tools_without_key() {
        let _lock = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let _restore = OllamaKeyRestore::capture();
        // SAFETY: same reasoning as `OllamaKeyRestore::drop`; ENV_LOCK is held.
        unsafe {
            std::env::remove_var("OLLAMA_API_KEY");
        }
        let cfg = crate::app::Config::default();
        let providers = Arc::new(crate::providers::ProviderFactory::new(cfg.clone()));
        let r = ToolRegistry::build(&cfg, TuiMode::Headless, providers);
        assert!(r.get("web_search").is_none(), "web_search skipped");
        assert!(r.get("web_fetch").is_none(), "web_fetch skipped");
        assert!(r.get("read_file").is_some());
        assert!(r.get("execute_command").is_some());
    }

    /// `build` must never register fewer always-on tools than `default`.
    /// This guards against the two constructors drifting apart.
    #[test]
    fn build_includes_every_default_tool() {
        // `build` reads OLLAMA_API_KEY, so serialise with the env-mutating tests.
        let _lock = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let cfg = crate::app::Config::default();
        let providers = Arc::new(crate::providers::ProviderFactory::new(cfg.clone()));
        let built = ToolRegistry::build(&cfg, TuiMode::Headless, providers);
        for name in ToolRegistry::default().names() {
            assert!(built.get(name).is_some(), "build() missing default tool: {}", name);
        }
    }
}