use crate::providers::ToolDefinition;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use super::{
DynTool, Tool, ToolEffect, agent, ask_user, bg_task_tools, boxed, file_tools, glob_tool, grep,
memory, recall, shell, skill_tools, todo, web_fetch, web_search,
};
/// Registry of built-in tools, with an optional MCP manager attached later.
///
/// Built-in tool definitions (`crate::providers::ToolDefinition`) and the
/// executable tool objects are stored in two separate maps.
pub struct ToolCatalog {
    /// Built-in tool definitions keyed by `ToolDefinition::name`.
    definitions: HashMap<String, ToolDefinition>,
    /// Executable built-in tools keyed by `Tool::name()`.
    tools: HashMap<&'static str, DynTool>,
    /// MCP manager attached via `set_mcp_manager`; `None` after construction.
    /// Wrapped in a `std::sync::RwLock` so it can be set through `&self`
    /// (e.g. on the `&'static` instance returned by `default_static`).
    mcp_manager: RwLock<Option<Arc<tokio::sync::RwLock<crate::mcp::McpManager>>>>,
}
impl ToolCatalog {
    /// Builds a catalog containing every built-in tool and no MCP manager.
    ///
    /// Definitions are keyed by `ToolDefinition::name` and tools by
    /// `Tool::name()`. If two definitions share a name, the one registered
    /// later in the chain below wins.
    pub fn new() -> Self {
        let definitions: HashMap<String, ToolDefinition> = file_tools::definitions()
            .into_iter()
            .chain(grep::definitions())
            .chain(shell::definitions())
            .chain(agent::definitions())
            .chain(bg_task_tools::definitions())
            .chain(ask_user::definitions())
            .chain(glob_tool::definitions())
            .chain(web_fetch::definitions())
            .chain(web_search::definitions())
            .chain(todo::definitions())
            .chain(memory::definitions())
            .chain(skill_tools::definitions())
            .chain(std::iter::once(recall::definition()))
            .map(|def| (def.name.clone(), def))
            .collect();

        let tools: HashMap<&'static str, DynTool> = [
            boxed(file_tools::ReadTool),
            boxed(file_tools::WriteTool),
            boxed(file_tools::EditTool),
            boxed(file_tools::DeleteTool),
            boxed(file_tools::ListTool),
            boxed(grep::GrepTool),
            boxed(glob_tool::GlobTool),
            boxed(shell::BashTool),
            boxed(web_fetch::WebFetchTool),
            boxed(web_search::WebSearchTool),
            boxed(memory::MemoryReadTool),
            boxed(memory::MemoryWriteTool),
            boxed(todo::TodoWriteTool),
            boxed(recall::RecallContextTool),
            boxed(agent::ListAgentsTool),
            boxed(agent::InvokeAgentTool),
            boxed(skill_tools::ListSkillsTool),
            boxed(skill_tools::ActivateSkillTool),
            boxed(ask_user::AskUserTool),
            boxed(bg_task_tools::ListBackgroundTasksTool),
            boxed(bg_task_tools::CancelTaskTool),
            boxed(bg_task_tools::WaitTaskTool),
        ]
        .into_iter()
        .map(|tool| (tool.name(), tool))
        .collect();

        Self {
            definitions,
            tools,
            mcp_manager: RwLock::new(None),
        }
    }

    /// Attaches (or replaces) the MCP manager consulted for MCP tools.
    ///
    /// A poisoned lock is recovered instead of silently skipping the update.
    /// The slot is only ever written by a single assignment, so a panic in
    /// another holder cannot leave it half-updated.
    pub fn set_mcp_manager(&self, manager: Arc<tokio::sync::RwLock<crate::mcp::McpManager>>) {
        let mut slot = self
            .mcp_manager
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        *slot = Some(manager);
    }

    /// Returns the attached MCP manager, if any.
    ///
    /// This is a cheap `Arc` clone. A poisoned lock is recovered rather than
    /// being reported as "no manager".
    pub fn mcp_manager(&self) -> Option<Arc<tokio::sync::RwLock<crate::mcp::McpManager>>> {
        self.mcp_manager
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .as_ref()
            .map(Arc::clone)
    }

    /// Classifies `name` without call arguments.
    ///
    /// Passes `serde_json::Value::Null` as the arguments. For tools whose
    /// effect depends on arguments (e.g. `Bash`), prefer [`Self::classify_call`].
    pub fn classify_tool_with_mcp(&self, name: &str) -> ToolEffect {
        self.classify_call(name, &serde_json::Value::Null)
    }

    /// Classifies a specific tool call by its effect.
    ///
    /// - Names recognised by `crate::mcp::is_mcp_tool_name` are classified by
    ///   the attached MCP manager. If there is no manager, or its lock cannot
    ///   be taken without waiting, the result is `ToolEffect::RemoteAction`.
    /// - Built-in tools delegate to `Tool::classify(args)`.
    /// - Unknown names are treated as `ToolEffect::LocalMutation`.
    pub fn classify_call(&self, name: &str, args: &serde_json::Value) -> ToolEffect {
        if crate::mcp::is_mcp_tool_name(name) {
            if let Some(mgr) = self.mcp_manager()
                && let Ok(mgr) = mgr.try_read()
            {
                return mgr.classify_tool(name);
            }
            return ToolEffect::RemoteAction;
        }
        match self.get_tool(name) {
            Some(tool) => tool.classify(args),
            None => ToolEffect::LocalMutation,
        }
    }

    /// Returns `true` when [`Self::classify_call`] reports a mutating effect.
    pub fn is_mutating_call(&self, name: &str, args: &serde_json::Value) -> bool {
        self.classify_call(name, args).is_mutating()
    }

    /// Returns a process-wide catalog, built on first use via `OnceLock`.
    pub fn default_static() -> &'static Self {
        use std::sync::OnceLock;
        static CATALOG: OnceLock<ToolCatalog> = OnceLock::new();
        CATALOG.get_or_init(ToolCatalog::new)
    }

    /// Returns the names of all built-in tool definitions, sorted ascending.
    pub fn all_builtin_tool_names(&self) -> Vec<String> {
        let mut names: Vec<String> = self.definitions.keys().cloned().collect();
        names.sort();
        names
    }

    /// Returns `true` if a built-in definition named `name` exists.
    ///
    /// MCP tools are not consulted.
    pub fn has_tool(&self, name: &str) -> bool {
        self.definitions.contains_key(name)
    }

    /// Looks up the executable built-in tool registered under `name`.
    pub fn get_tool(&self, name: &str) -> Option<&dyn Tool> {
        self.tools.get(name).map(|tool| &**tool as &dyn Tool)
    }

    /// Returns tool definitions filtered by an allowlist or a denylist.
    ///
    /// Filtering rules:
    /// - A non-empty `allowed` wins: only those names are kept and `denied` is
    ///   ignored.
    /// - Otherwise, every name not in `denied` is kept.
    ///
    /// Ordering and duplicates:
    /// - With an allowlist, built-ins follow the allowlist order. Unknown and
    ///   repeated names are skipped.
    /// - Without an allowlist, built-ins are sorted by name, so the list is
    ///   stable across runs instead of following `HashMap` iteration order.
    ///
    /// MCP definitions:
    /// - They are appended after the built-ins and filtered by the same rule.
    /// - They are included only when an MCP manager is attached and its lock
    ///   can be taken without waiting (`try_read`). Otherwise they are omitted
    ///   from this call.
    pub fn get_definitions(&self, allowed: &[String], denied: &[String]) -> Vec<ToolDefinition> {
        let permits = |name: &str| -> bool {
            if allowed.is_empty() {
                !denied.iter().any(|d| d.as_str() == name)
            } else {
                allowed.iter().any(|a| a.as_str() == name)
            }
        };

        let mut defs: Vec<ToolDefinition> = if allowed.is_empty() {
            let mut builtin: Vec<ToolDefinition> = self
                .definitions
                .values()
                .filter(|def| permits(def.name.as_str()))
                .cloned()
                .collect();
            builtin.sort_by(|a, b| a.name.cmp(&b.name));
            builtin
        } else {
            // Duplicate tool names would be sent twice; keep only the first occurrence.
            let mut seen = std::collections::HashSet::new();
            allowed
                .iter()
                .filter(|name| seen.insert(name.as_str()))
                .filter_map(|name| self.definitions.get(name).cloned())
                .collect()
        };

        if let Some(mgr) = self.mcp_manager()
            && let Ok(mgr) = mgr.try_read()
        {
            defs.extend(
                mgr.all_tool_definitions()
                    .into_iter()
                    .filter(|def| permits(def.name.as_str())),
            );
        }
        defs
    }
}
impl Default for ToolCatalog {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new_registers_every_builtin_tool() {
        let catalog = ToolCatalog::new();
        let names = catalog.all_builtin_tool_names();
        for expected in [
            "Read",
            "Write",
            "Edit",
            "Delete",
            "List",
            "Grep",
            "Glob",
            "Bash",
            "InvokeAgent",
            "ListBackgroundTasks",
            "CancelTask",
            "WaitTask",
            "AskUser",
            "WebFetch",
            "WebSearch",
            "TodoWrite",
            "MemoryRead",
            "MemoryWrite",
            "ListSkills",
            "ActivateSkill",
            "RecallContext",
        ] {
            assert!(
                names.contains(&expected.to_string()),
                "missing built-in tool {expected:?} (got {names:?})"
            );
        }
    }

    #[test]
    fn all_builtin_tool_names_returns_sorted() {
        let names = ToolCatalog::new().all_builtin_tool_names();
        let mut sorted = names.clone();
        sorted.sort();
        assert_eq!(
            names, sorted,
            "names must be sorted for stable test snapshots"
        );
    }

    #[test]
    fn has_tool_matches_builtin_set() {
        let catalog = ToolCatalog::new();
        assert!(catalog.has_tool("Read"), "Read must be registered");
        assert!(catalog.has_tool("Bash"), "Bash must be registered");
        assert!(!catalog.has_tool("definitely_not_a_real_tool"));
    }

    #[test]
    fn get_tool_is_keyed_by_tool_name() {
        let catalog = ToolCatalog::new();
        for name in ["Read", "Bash"] {
            let tool = catalog
                .get_tool(name)
                .unwrap_or_else(|| panic!("{name} must be executable"));
            assert_eq!(tool.name(), name);
        }
        assert!(catalog.get_tool("NoSuchTool").is_none());
    }

    #[test]
    fn get_definitions_no_filter_returns_all() {
        let catalog = ToolCatalog::new();
        let defs = catalog.get_definitions(&[], &[]);
        let names: std::collections::HashSet<_> = defs.iter().map(|d| d.name.clone()).collect();
        for name in catalog.all_builtin_tool_names() {
            assert!(names.contains(&name), "missing {name} in no-filter result");
        }
    }

    #[test]
    fn get_definitions_allowlist_only_returns_allowed() {
        let catalog = ToolCatalog::new();
        let defs = catalog.get_definitions(&["Read".to_string(), "Write".to_string()], &[]);
        let names: Vec<_> = defs.iter().map(|d| d.name.clone()).collect();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"Read".to_string()));
        assert!(names.contains(&"Write".to_string()));
    }

    #[test]
    fn get_definitions_allowlist_skips_unknown_names() {
        let catalog = ToolCatalog::new();
        let defs = catalog.get_definitions(&["Read".to_string(), "NoSuchTool".to_string()], &[]);
        let names: Vec<_> = defs.iter().map(|d| d.name.clone()).collect();
        assert_eq!(names, vec!["Read".to_string()]);
    }

    #[test]
    fn get_definitions_denylist_excludes_denied() {
        let catalog = ToolCatalog::new();
        let defs = catalog.get_definitions(&[], &["Bash".to_string()]);
        let names: std::collections::HashSet<_> = defs.iter().map(|d| d.name.clone()).collect();
        assert!(!names.contains("Bash"), "Bash should be filtered out");
        assert!(names.contains("Read"), "Read should still be present");
    }

    #[test]
    fn get_definitions_allowlist_wins_over_denylist() {
        let catalog = ToolCatalog::new();
        let defs = catalog.get_definitions(&["Read".to_string()], &["Read".to_string()]);
        let names: Vec<_> = defs.iter().map(|d| d.name.clone()).collect();
        assert_eq!(names, vec!["Read".to_string()], "allowlist must win");
    }

    #[test]
    fn classify_tool_with_mcp_falls_back_for_builtins() {
        let catalog = ToolCatalog::new();
        assert_eq!(catalog.classify_tool_with_mcp("Read"), ToolEffect::ReadOnly);
        assert_eq!(
            catalog.classify_tool_with_mcp("Write"),
            ToolEffect::LocalMutation
        );
        assert_eq!(
            catalog.classify_tool_with_mcp("Delete"),
            ToolEffect::Destructive
        );
    }

    #[test]
    fn classify_call_args_aware_for_bash() {
        let catalog = ToolCatalog::new();
        let echo = serde_json::json!({"command": "echo hi"});
        let rm = serde_json::json!({"command": "rm -rf /tmp/foo"});
        assert_eq!(catalog.classify_call("Bash", &echo), ToolEffect::ReadOnly);
        assert_eq!(catalog.classify_call("Bash", &rm), ToolEffect::Destructive);
    }

    #[test]
    fn default_static_returns_same_instance() {
        let a = ToolCatalog::default_static();
        let b = ToolCatalog::default_static();
        assert!(std::ptr::eq(a, b));
        assert!(a.has_tool("Read"));
        assert!(a.has_tool("Bash"));
    }

    #[test]
    fn classify_tool_with_mcp_unknown_mcp_returns_remote_action() {
        let catalog = ToolCatalog::new();
        let effect = catalog.classify_tool_with_mcp("someserver__sometool");
        assert_eq!(effect, ToolEffect::RemoteAction);
    }

    #[test]
    fn mcp_manager_starts_empty() {
        let catalog = ToolCatalog::new();
        assert!(catalog.mcp_manager().is_none());
    }
}