mermaid_cli/providers/tool/
mod.rs1pub mod computer_use;
19pub mod exec;
20pub mod filesystem;
21pub mod mcp;
22pub mod subagent;
23pub mod web;
24pub mod web_client;
25
26use async_trait::async_trait;
27use std::collections::HashMap;
28use std::sync::Arc;
29
30use crate::domain::{ToolDefinition, ToolOutcome};
31
32use super::ctx::ExecContext;
33
34#[async_trait]
38pub trait ToolExecutor: Send + Sync {
39 fn name(&self) -> &'static str;
42
43 fn schema(&self) -> ToolDefinition;
49
50 fn is_internal(&self) -> bool {
56 false
57 }
58
59 async fn execute(&self, args: serde_json::Value, ctx: ExecContext) -> ToolOutcome;
63}
64
65pub struct ToolRegistry {
69 entries: HashMap<&'static str, Arc<dyn ToolExecutor>>,
70}
71
72impl ToolRegistry {
73 pub fn new() -> Self {
74 Self {
75 entries: HashMap::new(),
76 }
77 }
78
79 pub fn register(&mut self, tool: Arc<dyn ToolExecutor>) {
80 self.entries.insert(tool.name(), tool);
81 }
82
83 pub fn get(&self, name: &str) -> Option<Arc<dyn ToolExecutor>> {
84 self.entries.get(name).cloned()
85 }
86
87 pub fn len(&self) -> usize {
88 self.entries.len()
89 }
90
91 pub fn is_empty(&self) -> bool {
92 self.entries.is_empty()
93 }
94
95 pub fn names(&self) -> impl Iterator<Item = &'static str> + '_ {
96 self.entries.keys().copied()
97 }
98
99 pub fn describe_all(&self) -> Vec<ToolDefinition> {
105 self.entries
106 .values()
107 .filter(|t| !t.is_internal())
108 .map(|t| t.schema())
109 .collect()
110 }
111}
112
113impl Default for ToolRegistry {
114 fn default() -> Self {
115 let mut r = Self::new();
116 r.register(Arc::new(filesystem::ReadFileTool));
117 r.register(Arc::new(filesystem::WriteFileTool));
118 r.register(Arc::new(filesystem::EditFileTool));
119 r.register(Arc::new(filesystem::DeleteFileTool));
120 r.register(Arc::new(filesystem::CreateDirectoryTool));
121 r.register(Arc::new(exec::ExecuteCommandTool));
122 r.register(Arc::new(mcp::McpToolProxy));
126 r
127 }
128}
129
/// Session mode passed to `ToolRegistry::build`.
///
/// Presumably this mirrors whether a terminal UI is attached (confirm against
/// callers). Within this module its only effect is to gate the computer-use
/// tools.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TuiMode {
    /// Interactive session. `ToolRegistry::build` probes for and may register
    /// the computer-use tools.
    Interactive,
    /// Headless session. `ToolRegistry::build` never registers the
    /// computer-use tools.
    Headless,
}
140
141impl ToolRegistry {
142 pub fn build(
159 _config: &crate::app::Config,
160 mode: TuiMode,
161 providers: Arc<crate::providers::ProviderFactory>,
162 ) -> Arc<Self> {
163 let mut r = Self::new();
164 r.register(Arc::new(filesystem::ReadFileTool));
165 r.register(Arc::new(filesystem::WriteFileTool));
166 r.register(Arc::new(filesystem::EditFileTool));
167 r.register(Arc::new(filesystem::DeleteFileTool));
168 r.register(Arc::new(filesystem::CreateDirectoryTool));
169 r.register(Arc::new(exec::ExecuteCommandTool));
170 r.register(Arc::new(mcp::McpToolProxy));
171
172 if let Some(key) = crate::utils::resolve_api_key("OLLAMA_API_KEY", None) {
173 r.register(Arc::new(web::WebSearchTool::new(key.clone())));
174 r.register(Arc::new(web::WebFetchTool::new(key)));
175 }
176
177 if mode == TuiMode::Interactive {
182 let backend = computer_use::probe();
183 if backend.is_usable() {
184 let driver = Arc::new(computer_use::ComputerUseDriver::new(backend));
185 r.register(Arc::new(computer_use::ScreenshotTool::new(driver.clone())));
186 r.register(Arc::new(computer_use::ClickTool::new(driver.clone())));
187 r.register(Arc::new(computer_use::TypeTextTool::new(driver.clone())));
188 r.register(Arc::new(computer_use::PressKeyTool::new(driver.clone())));
189 r.register(Arc::new(computer_use::ScrollTool::new(driver.clone())));
190 r.register(Arc::new(computer_use::MouseMoveTool::new(driver.clone())));
191 r.register(Arc::new(computer_use::ListWindowsTool::new(driver)));
192 }
193 }
194
195 let spawner = Arc::new(subagent::SubagentSpawner::new(providers));
200 r.register(Arc::new(subagent::SubagentTool::new(spawner)));
201
202 Arc::new(r)
203 }
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_registry_has_builtin_tools() {
        let r = ToolRegistry::default();
        for name in &[
            "read_file",
            "write_file",
            "edit_file",
            "delete_file",
            "create_directory",
            "execute_command",
        ] {
            assert!(r.get(name).is_some(), "missing: {}", name);
        }
        assert!(r.get("not_a_tool").is_none());
        assert!(r.len() >= 6);
    }

    #[test]
    fn describe_all_returns_one_per_user_facing_tool() {
        let r = ToolRegistry::default();
        let schemas = r.describe_all();
        let visible = r
            .names()
            .filter(|n| r.get(n).map(|t| !t.is_internal()).unwrap_or(false))
            .count();
        assert_eq!(schemas.len(), visible);
        for schema in &schemas {
            assert!(
                r.get(&schema.name).is_some(),
                "schema for unknown tool: {}",
                schema.name
            );
        }
    }

    #[test]
    fn mcp_proxy_is_registered_but_internal() {
        let r = ToolRegistry::default();
        let proxy = r.get("mcp_proxy").expect("mcp_proxy registered");
        assert!(proxy.is_internal());
        assert!(!r.describe_all().iter().any(|s| s.name == "mcp_proxy"));
    }

    #[test]
    fn schema_name_matches_executor_name() {
        let r = ToolRegistry::default();
        for name in r.names() {
            let tool = r.get(name).unwrap();
            assert_eq!(tool.name(), tool.schema().name.as_str());
        }
    }

    /// Serializes the tests in this module that mutate process environment
    /// variables.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Sets or clears one environment variable for the guard's lifetime.
    ///
    /// On drop, the guard restores whatever value was present before. Drop
    /// also runs while a failing assertion unwinds, so a test failure cannot
    /// leak the temporary value into later tests. The prior value is captured
    /// with `var_os`, so non-UTF-8 values survive the round trip.
    struct EnvVarGuard {
        key: &'static str,
        prior: Option<std::ffi::OsString>,
    }

    impl EnvVarGuard {
        /// Sets `key` to `value`, or removes it when `value` is `None`.
        /// The caller must hold `ENV_LOCK`.
        fn new(key: &'static str, value: Option<&str>) -> Self {
            let prior = std::env::var_os(key);
            // SAFETY: callers hold ENV_LOCK, which serializes environment
            // mutation among the tests in this module.
            // NOTE(review): tests elsewhere in the crate that read the
            // environment are not covered by this lock.
            unsafe {
                match value {
                    Some(v) => std::env::set_var(key, v),
                    None => std::env::remove_var(key),
                }
            }
            Self { key, prior }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: the guard is declared after the ENV_LOCK guard in each
            // test, so it is dropped first, while the lock is still held.
            unsafe {
                match self.prior.take() {
                    Some(v) => std::env::set_var(self.key, v),
                    None => std::env::remove_var(self.key),
                }
            }
        }
    }

    #[test]
    fn build_registers_web_tools_when_key_present() {
        let _lock = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        // Bound to a named variable (not `_`) so it lives until the end of
        // the test; it is dropped before `_lock`.
        let _env = EnvVarGuard::new("OLLAMA_API_KEY", Some("test-key-build"));
        let cfg = crate::app::Config::default();
        let providers = Arc::new(crate::providers::ProviderFactory::new(cfg.clone()));
        let r = ToolRegistry::build(&cfg, TuiMode::Interactive, providers);
        assert!(r.get("web_search").is_some(), "web_search registered");
        assert!(r.get("web_fetch").is_some(), "web_fetch registered");
    }

    #[test]
    fn build_skips_web_tools_without_key() {
        let _lock = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let _env = EnvVarGuard::new("OLLAMA_API_KEY", None);
        let cfg = crate::app::Config::default();
        let providers = Arc::new(crate::providers::ProviderFactory::new(cfg.clone()));
        let r = ToolRegistry::build(&cfg, TuiMode::Headless, providers);
        assert!(r.get("web_search").is_none(), "web_search skipped");
        assert!(r.get("web_fetch").is_none(), "web_fetch skipped");
        assert!(r.get("read_file").is_some());
        assert!(r.get("execute_command").is_some());
    }
}