pub mod computer_use;
pub mod exec;
pub mod filesystem;
pub mod mcp;
pub mod subagent;
pub mod web;
pub mod web_client;
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
use crate::domain::{ToolDefinition, ToolOutcome};
use super::ctx::ExecContext;
/// A tool that can be looked up by name in a [`ToolRegistry`] and invoked with JSON
/// arguments.
///
/// Implementors must be `Send + Sync`; the registry stores them as
/// `Arc<dyn ToolExecutor>`.
#[async_trait]
pub trait ToolExecutor: Send + Sync {
    /// Unique key under which the tool is stored in a [`ToolRegistry`].
    fn name(&self) -> &'static str;

    /// Runs the tool with the given JSON arguments inside the supplied execution context.
    async fn execute(&self, args: serde_json::Value, ctx: ExecContext) -> ToolOutcome;

    /// Definition describing this tool. Its `name` is expected to equal [`Self::name`].
    fn schema(&self) -> ToolDefinition;

    /// When `true`, the tool can still be fetched with [`ToolRegistry::get`] but is left
    /// out of [`ToolRegistry::describe_all`]. Defaults to `false`.
    fn is_internal(&self) -> bool {
        false
    }
}
/// Name-keyed collection of tool executors.
///
/// Tools are held as `Arc<dyn ToolExecutor>`, so lookups hand out reference-counted
/// clones rather than copies of the tool.
pub struct ToolRegistry {
    /// Registered tools, keyed by the value of [`ToolExecutor::name`] at registration time.
    entries: HashMap<&'static str, Arc<dyn ToolExecutor>>,
}
impl ToolRegistry {
    /// Creates an empty registry with no tools.
    pub fn new() -> Self {
        Self {
            entries: HashMap::new(),
        }
    }

    /// Adds `tool` under the key returned by [`ToolExecutor::name`].
    ///
    /// Registering a second tool with an existing name replaces the earlier one.
    pub fn register(&mut self, tool: Arc<dyn ToolExecutor>) {
        self.entries.insert(tool.name(), tool);
    }

    /// Looks up a tool by name, returning a shared handle to it if one is registered.
    pub fn get(&self, name: &str) -> Option<Arc<dyn ToolExecutor>> {
        self.entries.get(name).cloned()
    }

    /// Number of registered tools, including internal ones.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// `true` when no tools are registered.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Names of all registered tools, including internal ones, in unspecified order.
    pub fn names(&self) -> impl Iterator<Item = &'static str> + '_ {
        self.entries.keys().copied()
    }

    /// Schemas of every tool that is not [internal](ToolExecutor::is_internal), sorted by
    /// tool name.
    ///
    /// The backing `HashMap` iterates in an unspecified order that differs between runs.
    /// Sorting keeps the returned list identical across runs for the same set of tools.
    pub fn describe_all(&self) -> Vec<ToolDefinition> {
        let mut visible: Vec<(&'static str, &Arc<dyn ToolExecutor>)> = self
            .entries
            .iter()
            .filter(|(_, tool)| !tool.is_internal())
            .map(|(name, tool)| (*name, tool))
            .collect();
        // Registry keys are unique, so an unstable sort is fully deterministic here.
        visible.sort_unstable_by_key(|&(name, _)| name);
        visible.into_iter().map(|(_, tool)| tool.schema()).collect()
    }
}
impl Default for ToolRegistry {
fn default() -> Self {
let mut r = Self::new();
r.register(Arc::new(filesystem::ReadFileTool));
r.register(Arc::new(filesystem::WriteFileTool));
r.register(Arc::new(filesystem::EditFileTool));
r.register(Arc::new(filesystem::DeleteFileTool));
r.register(Arc::new(filesystem::CreateDirectoryTool));
r.register(Arc::new(exec::ExecuteCommandTool));
r.register(Arc::new(mcp::McpToolProxy));
r
}
}
/// Whether the application runs with its interactive UI or headless.
///
/// [`ToolRegistry::build`] only probes for computer-use tools in `Interactive` mode.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TuiMode {
    /// Running with the interactive UI.
    Interactive,
    /// Running without the interactive UI.
    Headless,
}
impl ToolRegistry {
    /// Builds the complete tool registry for a session.
    ///
    /// Starts from [`ToolRegistry::default`] (the built-in tools) and then adds:
    /// - `web_search` and `web_fetch`, when `crate::utils::resolve_api_key` yields a key
    ///   for `OLLAMA_API_KEY`;
    /// - the computer-use tools (screenshot, click, type text, press key, scroll, mouse
    ///   move, list windows), only in [`TuiMode::Interactive`] and only when the probed
    ///   backend reports itself usable;
    /// - the subagent tool, which is always registered and receives `providers`.
    ///
    /// `_config` is currently not read.
    pub fn build(
        _config: &crate::app::Config,
        mode: TuiMode,
        providers: Arc<crate::providers::ProviderFactory>,
    ) -> Arc<Self> {
        // Reuse `Default` so the built-in tool list is maintained in exactly one place.
        let mut r = Self::default();

        if let Some(key) = crate::utils::resolve_api_key("OLLAMA_API_KEY", None) {
            r.register(Arc::new(web::WebSearchTool::new(key.clone())));
            r.register(Arc::new(web::WebFetchTool::new(key)));
        }

        // The backend is only probed in interactive mode; headless sessions never get
        // computer-use tools.
        if mode == TuiMode::Interactive {
            let backend = computer_use::probe();
            if backend.is_usable() {
                // All computer-use tools share one driver instance.
                let driver = Arc::new(computer_use::ComputerUseDriver::new(backend));
                r.register(Arc::new(computer_use::ScreenshotTool::new(Arc::clone(&driver))));
                r.register(Arc::new(computer_use::ClickTool::new(Arc::clone(&driver))));
                r.register(Arc::new(computer_use::TypeTextTool::new(Arc::clone(&driver))));
                r.register(Arc::new(computer_use::PressKeyTool::new(Arc::clone(&driver))));
                r.register(Arc::new(computer_use::ScrollTool::new(Arc::clone(&driver))));
                r.register(Arc::new(computer_use::MouseMoveTool::new(Arc::clone(&driver))));
                r.register(Arc::new(computer_use::ListWindowsTool::new(driver)));
            }
        }

        let spawner = Arc::new(subagent::SubagentSpawner::new(providers));
        r.register(Arc::new(subagent::SubagentTool::new(spawner)));
        Arc::new(r)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_registry_has_builtin_tools() {
        let r = ToolRegistry::default();
        for name in &[
            "read_file",
            "write_file",
            "edit_file",
            "delete_file",
            "create_directory",
            "execute_command",
        ] {
            assert!(r.get(name).is_some(), "missing: {}", name);
        }
        assert!(r.get("not_a_tool").is_none());
        assert!(r.len() >= 6);
    }

    #[test]
    fn describe_all_returns_one_per_user_facing_tool() {
        let r = ToolRegistry::default();
        let schemas = r.describe_all();
        let visible = r
            .names()
            .filter(|n| r.get(n).map(|t| !t.is_internal()).unwrap_or(false))
            .count();
        assert_eq!(schemas.len(), visible);
        for schema in &schemas {
            assert!(
                r.get(&schema.name).is_some(),
                "schema for unknown tool: {}",
                schema.name
            );
        }
    }

    #[test]
    fn mcp_proxy_is_registered_but_internal() {
        let r = ToolRegistry::default();
        let proxy = r.get("mcp_proxy").expect("mcp_proxy registered");
        assert!(proxy.is_internal());
        assert!(!r.describe_all().iter().any(|s| s.name == "mcp_proxy"));
    }

    #[test]
    fn schema_name_matches_executor_name() {
        let r = ToolRegistry::default();
        for name in r.names() {
            let tool = r.get(name).unwrap();
            assert_eq!(tool.name(), tool.schema().name.as_str());
        }
    }

    /// Serializes the tests in this module that mutate process environment variables.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Sets (or removes) an environment variable and restores its previous value on drop.
    ///
    /// Restoring in `Drop` puts the variable back even when an assertion in the test
    /// panics. Otherwise a failed test would leave the modified value behind for the next
    /// test that recovers the poisoned `ENV_LOCK`.
    struct EnvVarGuard {
        key: &'static str,
        /// Value before the guard was created, or `None` if the variable was unset.
        /// Kept as `OsString` so non-UTF-8 values are restored faithfully.
        prior: Option<std::ffi::OsString>,
    }

    impl EnvVarGuard {
        /// Sets `key` to `value`, or removes it when `value` is `None`.
        ///
        /// The caller must hold `ENV_LOCK` for the whole lifetime of the returned guard.
        /// Declare the lock guard first so that it is dropped after this one.
        fn set(key: &'static str, value: Option<&str>) -> Self {
            let prior = std::env::var_os(key);
            // SAFETY: `set_var`/`remove_var` are unsound if another thread accesses the
            // environment concurrently. The caller holds `ENV_LOCK`, which serializes the
            // env-mutating tests in this module.
            // NOTE(review): tests elsewhere in the crate are not covered by this lock; this
            // assumes none of them touch the environment while these tests run.
            unsafe {
                match value {
                    Some(v) => std::env::set_var(key, v),
                    None => std::env::remove_var(key),
                }
            }
            Self { key, prior }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: same invariant as in `EnvVarGuard::set`. `ENV_LOCK` is still held
            // because its guard was declared before this one and is dropped after it.
            unsafe {
                match self.prior.take() {
                    Some(v) => std::env::set_var(self.key, v),
                    None => std::env::remove_var(self.key),
                }
            }
        }
    }

    #[test]
    fn build_registers_web_tools_when_key_present() {
        let _lock = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let _env = EnvVarGuard::set("OLLAMA_API_KEY", Some("test-key-build"));
        let cfg = crate::app::Config::default();
        let providers = Arc::new(crate::providers::ProviderFactory::new(cfg.clone()));
        let r = ToolRegistry::build(&cfg, TuiMode::Interactive, providers);
        assert!(r.get("web_search").is_some(), "web_search registered");
        assert!(r.get("web_fetch").is_some(), "web_fetch registered");
    }

    #[test]
    fn build_skips_web_tools_without_key() {
        let _lock = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let _env = EnvVarGuard::set("OLLAMA_API_KEY", None);
        let cfg = crate::app::Config::default();
        let providers = Arc::new(crate::providers::ProviderFactory::new(cfg.clone()));
        let r = ToolRegistry::build(&cfg, TuiMode::Headless, providers);
        assert!(r.get("web_search").is_none(), "web_search skipped");
        assert!(r.get("web_fetch").is_none(), "web_fetch skipped");
        assert!(r.get("read_file").is_some());
        assert!(r.get("execute_command").is_some());
    }
}