use brainwires_core::Tool;
/// Coarse grouping of tools, used by `ToolRegistry::get_by_category` to pick
/// tools by a fixed list of tool names per category.
///
/// `Hash` is derived (in addition to the comparison traits) so a category can
/// be used directly as a `HashMap` / `HashSet` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ToolCategory {
    /// File and directory tools such as `read_file`, `write_file`, `list_directory`.
    FileOps,
    /// Text search tools (`search_code`, `search_files`).
    Search,
    /// Index/query tools such as `index_codebase` and `query_codebase`.
    SemanticSearch,
    /// `git_*` tools (status, diff, log, stage, commit, ...).
    Git,
    /// `task_*` tools (create, start, complete, list, ...).
    TaskManager,
    /// `agent_*` tools (spawn, status, list, stop, await).
    AgentPool,
    /// The `fetch_url` tool.
    Web,
    /// `web_search`, `web_browse` and `web_scrape`.
    WebSearch,
    /// The `execute_command` tool.
    Bash,
    /// The `plan_task` tool.
    Planning,
    /// The `recall_context` tool.
    Context,
    /// The `execute_script` tool.
    Orchestrator,
    /// The `execute_code` tool.
    CodeExecution,
    /// The `task_list_write` tool.
    SessionTask,
    /// `check_duplicates`, `verify_build` and `check_syntax`.
    Validation,
}
/// Holds a list of [`Tool`] definitions and offers lookups over it
/// (by name, by category, by keyword, by deferred-loading flag).
pub struct ToolRegistry {
/// Registered tools in registration order; `register` / `register_tools`
/// only ever append, so duplicate names are not rejected here.
tools: Vec<Tool>,
}
impl ToolRegistry {
    /// Creates a registry that holds no tools.
    pub fn new() -> Self {
        Self { tools: Vec::new() }
    }

    /// Creates a registry pre-populated with the crate's built-in tools.
    ///
    /// The tool-search tools are always registered. Further groups are added
    /// only when the matching Cargo feature is enabled:
    /// - `native`: file operations, bash, git, web, search and validation tools
    /// - `orchestrator`: orchestrator tools
    /// - `interpreters`: code-execution tools
    /// - `rag`: semantic-search tools
    pub fn with_builtins() -> Self {
        let mut builtins = Self::new();
        builtins.register_tools(crate::ToolSearchTool::get_tools());

        #[cfg(feature = "native")]
        {
            builtins.register_tools(crate::FileOpsTool::get_tools());
            builtins.register_tools(crate::BashTool::get_tools());
            builtins.register_tools(crate::GitTool::get_tools());
            builtins.register_tools(crate::WebTool::get_tools());
            builtins.register_tools(crate::SearchTool::get_tools());
            builtins.register_tools(crate::get_validation_tools());
        }

        #[cfg(feature = "orchestrator")]
        builtins.register_tools(crate::OrchestratorTool::get_tools());

        #[cfg(feature = "interpreters")]
        builtins.register_tools(crate::CodeExecTool::get_tools());

        #[cfg(feature = "rag")]
        builtins.register_tools(crate::SemanticSearchTool::get_tools());

        builtins
    }

    /// Appends a single tool to the registry. No duplicate-name check is done.
    pub fn register(&mut self, tool: Tool) {
        self.tools.push(tool);
    }

    /// Appends every tool in `tools`, preserving their order.
    pub fn register_tools(&mut self, mut tools: Vec<Tool>) {
        self.tools.append(&mut tools);
    }

    /// Returns all registered tools in registration order.
    pub fn get_all(&self) -> &[Tool] {
        self.tools.as_slice()
    }

    /// Returns owned copies of all registered tools followed by copies of `extra`.
    pub fn get_all_with_extra(&self, extra: &[Tool]) -> Vec<Tool> {
        self.tools.iter().chain(extra).cloned().collect()
    }

    /// Looks up a tool by exact name; the first registered match wins.
    pub fn get(&self, name: &str) -> Option<&Tool> {
        self.tools
            .iter()
            .find(|candidate| candidate.name.as_str() == name)
    }

    /// Returns the tools whose `defer_loading` flag is not set.
    pub fn get_initial_tools(&self) -> Vec<&Tool> {
        let mut initial = Vec::new();
        for tool in &self.tools {
            if !tool.defer_loading {
                initial.push(tool);
            }
        }
        initial
    }

    /// Returns the tools whose `defer_loading` flag is set.
    pub fn get_deferred_tools(&self) -> Vec<&Tool> {
        let mut deferred = Vec::new();
        for tool in &self.tools {
            if tool.defer_loading {
                deferred.push(tool);
            }
        }
        deferred
    }

    /// Case-insensitive keyword search over tool names and descriptions.
    ///
    /// `query` is split on whitespace; a tool matches when at least one term
    /// is a substring of its lowercased name or description. A query without
    /// any terms therefore matches nothing.
    pub fn search_tools(&self, query: &str) -> Vec<&Tool> {
        let lowered_query = query.to_lowercase();
        let terms: Vec<&str> = lowered_query.split_whitespace().collect();

        let mut hits = Vec::new();
        for tool in &self.tools {
            let name = tool.name.to_lowercase();
            let description = tool.description.to_lowercase();
            let matched = terms
                .iter()
                .any(|&term| name.contains(term) || description.contains(term));
            if matched {
                hits.push(tool);
            }
        }
        hits
    }

    /// Returns the registered tools whose name belongs to the fixed name list
    /// of `category`, in registration order.
    pub fn get_by_category(&self, category: ToolCategory) -> Vec<&Tool> {
        let wanted: &[&str] = match category {
            ToolCategory::FileOps => &[
                "read_file",
                "write_file",
                "edit_file",
                "patch_file",
                "list_directory",
                "search_files",
                "delete_file",
                "create_directory",
            ],
            ToolCategory::Search => &["search_code", "search_files"],
            ToolCategory::SemanticSearch => &[
                "index_codebase",
                "query_codebase",
                "search_with_filters",
                "get_rag_statistics",
                "clear_rag_index",
                "search_git_history",
            ],
            ToolCategory::Git => &[
                "git_status",
                "git_diff",
                "git_log",
                "git_stage",
                "git_unstage",
                "git_commit",
                "git_push",
                "git_pull",
                "git_fetch",
                "git_discard",
                "git_branch",
            ],
            ToolCategory::TaskManager => &[
                "task_create",
                "task_start",
                "task_complete",
                "task_list",
                "task_skip",
                "task_add",
                "task_block",
                "task_depends",
                "task_ready",
                "task_time",
            ],
            ToolCategory::AgentPool => &[
                "agent_spawn",
                "agent_status",
                "agent_list",
                "agent_stop",
                "agent_await",
            ],
            ToolCategory::Web => &["fetch_url"],
            ToolCategory::WebSearch => &["web_search", "web_browse", "web_scrape"],
            ToolCategory::Bash => &["execute_command"],
            ToolCategory::Planning => &["plan_task"],
            ToolCategory::Context => &["recall_context"],
            ToolCategory::Orchestrator => &["execute_script"],
            ToolCategory::CodeExecution => &["execute_code"],
            ToolCategory::SessionTask => &["task_list_write"],
            ToolCategory::Validation => &["check_duplicates", "verify_build", "check_syntax"],
        };
        self.select_named(wanted)
    }

    /// Convenience wrapper that forwards `mcp_tools` to [`Self::get_all_with_extra`].
    pub fn get_all_with_mcp(&self, mcp_tools: &[Tool]) -> Vec<Tool> {
        self.get_all_with_extra(mcp_tools)
    }

    /// Returns the registered tools that belong to the fixed "core" name list
    /// (basic file, search, command, git, tool-search and codebase-index tools).
    pub fn get_core(&self) -> Vec<&Tool> {
        const CORE_NAMES: &[&str] = &[
            "read_file",
            "write_file",
            "edit_file",
            "list_directory",
            "search_code",
            "execute_command",
            "git_status",
            "git_diff",
            "git_log",
            "git_stage",
            "git_commit",
            "search_tools",
            "index_codebase",
            "query_codebase",
        ];
        self.select_named(CORE_NAMES)
    }

    /// Returns the registered tools named `execute_script` or `search_tools`.
    pub fn get_primary(&self) -> Vec<&Tool> {
        const PRIMARY_NAMES: &[&str] = &["execute_script", "search_tools"];
        self.select_named(PRIMARY_NAMES)
    }

    /// Embedding-based tool search (only with the `rag` feature).
    ///
    /// A `ToolEmbeddingIndex` is built from every tool's name and description
    /// on each call; `query`, `limit` and `min_score` are forwarded to its
    /// `search`. Each returned name is mapped back to a registered tool and
    /// paired with its score; names without a registered tool are dropped.
    ///
    /// # Errors
    /// Propagates any error from building or searching the index.
    #[cfg(feature = "rag")]
    pub fn semantic_search_tools(
        &self,
        query: &str,
        limit: usize,
        min_score: f32,
    ) -> anyhow::Result<Vec<(&Tool, f32)>> {
        let mut corpus: Vec<(String, String)> = Vec::with_capacity(self.tools.len());
        for tool in &self.tools {
            corpus.push((tool.name.clone(), tool.description.clone()));
        }

        let index = crate::tool_embedding::ToolEmbeddingIndex::build(&corpus)?;
        let ranked = index.search(query, limit, min_score)?;

        let mut scored = Vec::new();
        for (name, score) in ranked {
            if let Some(tool) = self.get(&name) {
                scored.push((tool, score));
            }
        }
        Ok(scored)
    }

    /// Returns owned copies of the registered tools whose name is in `allow`.
    /// Names in `allow` that match no tool are ignored.
    pub fn filtered_view(&self, allow: &[&str]) -> Vec<Tool> {
        self.select_named(allow).into_iter().cloned().collect()
    }

    /// Number of registered tools.
    pub fn len(&self) -> usize {
        self.tools.len()
    }

    /// `true` when no tool is registered.
    pub fn is_empty(&self) -> bool {
        self.tools.is_empty()
    }

    /// Shared filter: registered tools whose name appears in `names`,
    /// in registration order.
    fn select_named(&self, names: &[&str]) -> Vec<&Tool> {
        let mut selected = Vec::new();
        for tool in &self.tools {
            if names.contains(&tool.name.as_str()) {
                selected.push(tool);
            }
        }
        selected
    }
}
impl Default for ToolRegistry {
    /// The default registry is empty, exactly like `ToolRegistry::new()`.
    fn default() -> Self {
        Self { tools: Vec::new() }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use brainwires_core::ToolInputSchema;
    use std::collections::{HashMap, HashSet};

    /// Builds a tool named `name` with an empty object schema, no approval
    /// requirement, and the given deferred-loading flag.
    fn make_tool(name: &str, defer: bool) -> Tool {
        let description = format!("A {} tool", name);
        Tool {
            name: name.to_string(),
            description,
            input_schema: ToolInputSchema::object(HashMap::new(), vec![]),
            requires_approval: false,
            defer_loading: defer,
            ..Default::default()
        }
    }

    /// Builds a tool with an explicit description; all other fields are defaults.
    fn described_tool(name: &str, description: &str) -> Tool {
        Tool {
            name: name.to_string(),
            description: description.to_string(),
            input_schema: ToolInputSchema::object(HashMap::new(), vec![]),
            ..Default::default()
        }
    }

    /// Registry holding one non-deferred tool per entry of `names`.
    fn registry_of(names: &[&str]) -> ToolRegistry {
        let mut registry = ToolRegistry::new();
        for &name in names {
            registry.register(make_tool(name, false));
        }
        registry
    }

    #[test]
    fn test_new_is_empty() {
        let fresh = ToolRegistry::new();
        assert_eq!(fresh.len(), 0);
        assert!(fresh.is_empty());
    }

    #[test]
    fn test_register_single() {
        let registry = registry_of(&["test_tool"]);
        assert_eq!(registry.len(), 1);
        assert!(registry.get("test_tool").is_some());
    }

    #[test]
    fn test_register_multiple() {
        let batch = vec![make_tool("tool1", false), make_tool("tool2", false)];
        let mut registry = ToolRegistry::new();
        registry.register_tools(batch);
        assert_eq!(registry.len(), 2);
    }

    #[test]
    fn test_get_by_name() {
        let registry = registry_of(&["my_tool"]);
        assert!(registry.get("nonexistent").is_none());
        assert!(registry.get("my_tool").is_some());
    }

    #[test]
    fn test_initial_vs_deferred() {
        let mut registry = ToolRegistry::new();
        registry.register(make_tool("initial", false));
        registry.register(make_tool("deferred", true));

        let initial = registry.get_initial_tools();
        assert_eq!(initial.len(), 1);
        assert_eq!(initial[0].name, "initial");

        let deferred = registry.get_deferred_tools();
        assert_eq!(deferred.len(), 1);
        assert_eq!(deferred[0].name, "deferred");
    }

    #[test]
    fn test_search_tools() {
        let fixtures = [
            ("read_file", "Read a file from disk"),
            ("write_file", "Write content to a file"),
            ("execute_command", "Execute a bash command"),
        ];
        let mut registry = ToolRegistry::new();
        for &(name, description) in &fixtures {
            registry.register(described_tool(name, description));
        }

        // "file" appears in two names/descriptions, "bash" in one description.
        assert_eq!(registry.search_tools("file").len(), 2);
        assert_eq!(registry.search_tools("bash").len(), 1);
    }

    #[test]
    fn test_get_all_with_extra() {
        let registry = registry_of(&["builtin"]);
        let extra = vec![make_tool("mcp_tool", false)];
        let combined = registry.get_all_with_extra(&extra);
        assert_eq!(combined.len(), 2);
    }

    #[test]
    fn test_no_duplicate_names_in_builtins() {
        let registry = ToolRegistry::with_builtins();
        let mut seen = HashSet::new();
        for tool in registry.get_all() {
            let first_time = seen.insert(tool.name.clone());
            assert!(first_time, "Duplicate tool name: {}", tool.name);
        }
    }

    #[test]
    fn filtered_view_returns_only_named_tools() {
        let registry = registry_of(&["read_file", "write_file", "execute_command"]);
        let view = registry.filtered_view(&["read_file", "execute_command"]);
        assert_eq!(view.len(), 2);

        let has = |wanted: &str| view.iter().any(|tool| tool.name.as_str() == wanted);
        assert!(has("read_file"));
        assert!(has("execute_command"));
        assert!(!has("write_file"));
    }

    #[test]
    fn filtered_view_unknown_names_are_silently_skipped() {
        let registry = registry_of(&["read_file"]);
        let view = registry.filtered_view(&["read_file", "nonexistent"]);
        assert_eq!(view.len(), 1);
        assert_eq!(view[0].name, "read_file");
    }

    #[test]
    fn filtered_view_empty_allow_list_returns_empty() {
        let registry = registry_of(&["read_file"]);
        let view = registry.filtered_view(&[]);
        assert!(view.is_empty());
    }
}