use super::error::{Result, ToolError};
use super::r#trait::{Tool, ToolExecutionContext, ToolResult};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
/// Parameter-name aliases, stored as `(tool, alias, canonical)` triples.
///
/// For the tool named in the first field, an input key spelled `alias` is
/// renamed to `canonical` by `normalize_tool_input`, provided the canonical
/// key is not already present.
///
/// Entries are applied in table order. When one input carries several aliases
/// of the same canonical key, the earliest entry for that tool wins, so the
/// relative order of a tool's entries is significant.
const PARAM_ALIASES: &[(&str, &str, &str)] = &[
    // grep / glob: `query` -> `pattern`
    ("grep", "query", "pattern"),
    ("glob", "query", "pattern"),
    // read_file: path spellings
    ("read_file", "file", "path"),
    ("read_file", "file_path", "path"),
    ("read_file", "filepath", "path"),
    // write_file: path spellings, then content spellings
    ("write_file", "file", "path"),
    ("write_file", "file_path", "path"),
    ("write_file", "filepath", "path"),
    ("write_file", "text", "content"),
    ("write_file", "body", "content"),
    // edit_file: path spellings, then old/new text spellings
    ("edit_file", "file", "path"),
    ("edit_file", "file_path", "path"),
    ("edit_file", "filepath", "path"),
    ("edit_file", "old_string", "old_text"),
    ("edit_file", "new_string", "new_text"),
    // doc_parser: path spellings
    ("doc_parser", "file", "path"),
    ("doc_parser", "file_path", "path"),
    // bash: `cmd` -> `command`
    ("bash", "cmd", "command"),
    // search tools: `pattern` -> `query`
    ("web_search", "pattern", "query"),
    ("exa_search", "pattern", "query"),
    ("brave_search", "pattern", "query"),
    ("memory_search", "pattern", "query"),
];
fn normalize_tool_input(tool_name: &str, mut input: Value) -> Value {
if let Some(obj) = input.as_object_mut() {
for &(tool, wrong, correct) in PARAM_ALIASES {
if tool == tool_name
&& !obj.contains_key(correct)
&& let Some(val) = obj.remove(wrong)
{
tracing::debug!(
"Normalized tool param: {}.{} → {}.{}",
tool_name,
wrong,
tool_name,
correct
);
obj.insert(correct.to_string(), val);
}
}
}
input
}
/// Registry of tools together with per-session activation state.
///
/// Both maps sit behind their own `RwLock`, so every method takes `&self`.
pub struct ToolRegistry {
    /// Registered tools, keyed by tool name.
    tools: RwLock<HashMap<String, Arc<dyn Tool>>>,
    /// Names of the tools activated so far, per session id.
    session_active: RwLock<HashMap<uuid::Uuid, std::collections::HashSet<String>>>,
}
/// Registration, lookup, discovery and execution of tools.
///
/// Every method that touches `tools` or `session_active` unwraps the lock
/// result and therefore panics if the corresponding `RwLock` is poisoned.
impl ToolRegistry {
    /// Creates a registry with no tools and no session activations.
    pub fn new() -> Self {
        Self {
            tools: RwLock::new(HashMap::new()),
            session_active: RwLock::new(HashMap::new()),
        }
    }

    /// Adds `names` to the active-tool set of `session_id`, creating the set
    /// on first use.
    pub fn activate_tools(&self, session_id: uuid::Uuid, names: impl IntoIterator<Item = String>) {
        let mut sessions = self.session_active.write().unwrap();
        sessions.entry(session_id).or_default().extend(names);
    }

    /// Returns a copy of the active-tool set of `session_id`; the set is empty
    /// when nothing has been activated for that session.
    pub fn active_tools(&self, session_id: uuid::Uuid) -> std::collections::HashSet<String> {
        self.session_active
            .read()
            .unwrap()
            .get(&session_id)
            .cloned()
            .unwrap_or_default()
    }

    /// Registers `tool` under the name it reports, replacing any tool that was
    /// previously registered under the same name.
    pub fn register(&self, tool: Arc<dyn Tool>) {
        let name = tool.name().to_string();
        tracing::debug!("Registered tool: {}", name);
        self.tools.write().unwrap().insert(name, tool);
    }

    /// Removes the tool called `name`; returns `true` if one was registered.
    pub fn unregister(&self, name: &str) -> bool {
        self.tools.write().unwrap().remove(name).is_some()
    }

    /// Returns a shared handle to the tool called `name`, if registered.
    pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
        self.tools.read().unwrap().get(name).cloned()
    }

    /// Reports whether a tool called `name` is registered.
    pub fn has_tool(&self, name: &str) -> bool {
        self.tools.read().unwrap().contains_key(name)
    }

    /// Returns the names of all registered tools, in unspecified order.
    pub fn list_tools(&self) -> Vec<String> {
        self.tools.read().unwrap().keys().cloned().collect()
    }

    /// Builds the provider-facing definition (name, description, input schema)
    /// of a single tool.
    fn definition_of(tool: &dyn Tool) -> crate::brain::provider::Tool {
        crate::brain::provider::Tool {
            name: tool.name().to_string(),
            description: tool.description().to_string(),
            input_schema: tool.input_schema(),
        }
    }

    /// Returns provider definitions for every registered tool.
    pub fn get_tool_definitions(&self) -> Vec<crate::brain::provider::Tool> {
        self.tools
            .read()
            .unwrap()
            .values()
            .map(|tool| Self::definition_of(tool.as_ref()))
            .collect()
    }

    /// Returns provider definitions for the tools that are either core
    /// (according to `catalog::is_core`) or named in `active_extended`.
    pub fn get_tool_definitions_filtered(
        &self,
        active_extended: &std::collections::HashSet<String>,
    ) -> Vec<crate::brain::provider::Tool> {
        use crate::brain::tools::catalog;
        self.tools
            .read()
            .unwrap()
            .values()
            .filter(|tool| {
                let name = tool.name();
                catalog::is_core(name) || active_extended.contains(name)
            })
            .map(|tool| Self::definition_of(tool.as_ref()))
            .collect()
    }

    /// Searches the non-core tools for `query` and returns up to `limit`
    /// `(name, category, description)` triples, best match first.
    ///
    /// The query is ASCII-lowercased and split on whitespace; terms of at most
    /// one byte are ignored. Each remaining term contributes, per tool:
    /// 5 if the name contains it, otherwise 3 if the category contains it,
    /// otherwise 1 if it occurs anywhere in `"{name} {category} {desc}"`.
    /// All comparisons are ASCII-case-insensitive. Tools scoring 0 are dropped;
    /// the rest are ordered by descending score, then ascending name.
    pub fn search_tools(&self, query: &str, limit: usize) -> Vec<(String, String, String)> {
        use crate::brain::tools::catalog;

        let query = query.to_ascii_lowercase();
        let terms: Vec<&str> = query.split_whitespace().filter(|t| t.len() > 1).collect();
        // Without usable terms every score is 0, and a zero limit keeps nothing,
        // so skip taking the lock and scanning the tools.
        if terms.is_empty() || limit == 0 {
            return Vec::new();
        }

        let mut scored: Vec<(i32, String, String, String)> = self
            .tools
            .read()
            .unwrap()
            .values()
            .filter(|tool| !catalog::is_core(tool.name()))
            .filter_map(|tool| {
                let name = tool.name().to_string();
                let desc = tool.description().to_string();
                let category = catalog::tool_category(&name).to_string();

                // Lowercase once per tool rather than once per term. The terms
                // are already lowercased, so every field they are compared
                // against must be lowercased as well.
                let name_lc = name.to_ascii_lowercase();
                let category_lc = category.to_ascii_lowercase();
                let haystack = format!("{name} {category} {desc}").to_ascii_lowercase();

                let score: i32 = terms
                    .iter()
                    .map(|term| {
                        if name_lc.contains(*term) {
                            5
                        } else if category_lc.contains(*term) {
                            3
                        } else if haystack.contains(*term) {
                            1
                        } else {
                            0
                        }
                    })
                    .sum();

                (score > 0).then_some((score, name, category, desc))
            })
            .collect();

        scored.sort_by(|a, b| b.0.cmp(&a.0).then_with(|| a.1.cmp(&b.1)));
        scored
            .into_iter()
            .take(limit)
            .map(|(_, name, category, desc)| (name, category, desc))
            .collect()
    }

    /// Returns provider definitions for the registered tools whose names
    /// appear in `names`; unknown names are silently skipped.
    pub fn definitions_for(
        &self,
        names: &std::collections::HashSet<String>,
    ) -> Vec<crate::brain::provider::Tool> {
        self.tools
            .read()
            .unwrap()
            .values()
            .filter(|tool| names.contains(tool.name()))
            .map(|tool| Self::definition_of(tool.as_ref()))
            .collect()
    }

    /// Resolves `name` to a registered tool and runs it with `input`.
    ///
    /// Steps, in order:
    /// 1. Look the tool up by exact name; on a miss, ask
    ///    `tool_name_heal::resolve_tool_name` for a registered near-match.
    /// 2. If the resolved tool is not core, add it to the session's active set.
    ///    This happens before validation, so it sticks even when the call fails.
    /// 3. Rewrite aliased parameter names via `normalize_tool_input`.
    /// 4. Validate the input, then enforce the approval requirement.
    /// 5. Execute the tool and log the outcome.
    ///
    /// # Errors
    /// * `ToolError::NotFound` when no tool can be resolved for `name`.
    /// * Whatever the tool's `validate_input` returns for a rejected input.
    /// * `ToolError::ApprovalRequired` when the tool requires approval and
    ///   `context.auto_approve` is false.
    /// * Any error returned by the tool's own `execute`.
    pub async fn execute(
        &self,
        name: &str,
        input: Value,
        context: &ToolExecutionContext,
    ) -> Result<ToolResult> {
        // `get` hands back a cloned `Arc`, so no registry lock is held across
        // the `.await` below.
        let (tool, resolved_name) = match self.get(name) {
            Some(tool) => (tool, name.to_string()),
            None => {
                let registered = self.list_tools();
                let healed = super::tool_name_heal::resolve_tool_name(name, &registered)
                    .ok_or_else(|| ToolError::NotFound(name.to_string()))?;
                tracing::warn!(
                    "Self-healed tool name: '{}' → '{}' (model called a near-miss name)",
                    name,
                    healed
                );
                // The tool may have been unregistered since `list_tools` ran.
                let tool = self
                    .get(&healed)
                    .ok_or_else(|| ToolError::NotFound(name.to_string()))?;
                (tool, healed)
            }
        };
        let name = resolved_name.as_str();

        if !crate::brain::tools::catalog::is_core(name) {
            self.activate_tools(context.session_id, [name.to_string()]);
        }

        let input = normalize_tool_input(name, input);
        tool.validate_input(&input)?;

        if tool.requires_approval() && !context.auto_approve {
            return Err(ToolError::ApprovalRequired(format!(
                "Tool '{}' requires approval before execution",
                name
            )));
        }

        tracing::info!("Executing tool: {}", name);
        let result = tool.execute(input, context).await?;
        if result.success {
            tracing::info!("Tool '{}' executed successfully", name);
        } else {
            tracing::warn!(
                "Tool '{}' failed: {:?}",
                name,
                result.error.as_deref().unwrap_or("unknown error")
            );
        }
        Ok(result)
    }

    /// Returns the number of registered tools.
    pub fn count(&self) -> usize {
        self.tools.read().unwrap().len()
    }
}
impl Default for ToolRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::brain::tools::r#trait::ToolCapability;
    use async_trait::async_trait;
    use uuid::Uuid;

    /// Test double with a configurable name and approval flag; executing it
    /// always succeeds.
    struct MockTool {
        name: String,
        requires_approval: bool,
    }

    /// Builds a shareable `MockTool`.
    fn mock_tool(name: &str, requires_approval: bool) -> Arc<MockTool> {
        Arc::new(MockTool {
            name: name.to_string(),
            requires_approval,
        })
    }

    #[async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "A mock tool for testing"
        }

        fn input_schema(&self) -> Value {
            serde_json::json!({
                "type": "object",
                "properties": {
                    "message": {
                        "type": "string",
                        "description": "Test message"
                    }
                },
                "required": ["message"]
            })
        }

        fn capabilities(&self) -> Vec<ToolCapability> {
            vec![ToolCapability::ReadFiles]
        }

        fn requires_approval(&self) -> bool {
            self.requires_approval
        }

        async fn execute(
            &self,
            _input: Value,
            _context: &ToolExecutionContext,
        ) -> Result<ToolResult> {
            Ok(ToolResult::success("Mock execution successful".to_string()))
        }
    }

    /// A fresh registry holds no tools.
    #[test]
    fn test_registry_creation() {
        assert_eq!(ToolRegistry::new().count(), 0);
    }

    /// Registering a tool makes it countable and discoverable by name.
    #[test]
    fn test_register_tool() {
        let registry = ToolRegistry::new();
        registry.register(mock_tool("test_tool", false));

        assert_eq!(registry.count(), 1);
        assert!(registry.has_tool("test_tool"));
        assert!(!registry.has_tool("nonexistent"));
    }

    /// `list_tools` reports every registered name.
    #[test]
    fn test_list_tools() {
        let registry = ToolRegistry::new();
        registry.register(mock_tool("tool1", false));
        registry.register(mock_tool("tool2", false));

        let names = registry.list_tools();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"tool1".to_string()));
        assert!(names.contains(&"tool2".to_string()));
    }

    /// A registered tool that needs no approval runs and returns its output.
    #[tokio::test]
    async fn test_execute_tool() {
        let registry = ToolRegistry::new();
        registry.register(mock_tool("test_tool", false));
        let context = ToolExecutionContext::new(Uuid::new_v4());

        let result = registry
            .execute("test_tool", serde_json::json!({ "message": "test" }), &context)
            .await
            .unwrap();

        assert!(result.success);
        assert_eq!(result.output, "Mock execution successful");
    }

    /// Executing an unknown name on an empty registry yields `NotFound`.
    #[tokio::test]
    async fn test_execute_nonexistent_tool() {
        let registry = ToolRegistry::new();
        let context = ToolExecutionContext::new(Uuid::new_v4());

        let outcome = registry
            .execute("nonexistent", serde_json::json!({}), &context)
            .await;

        assert!(matches!(outcome, Err(ToolError::NotFound(_))));
    }

    /// With a default context, a tool that requires approval is refused.
    #[tokio::test]
    async fn test_execute_requires_approval() {
        let registry = ToolRegistry::new();
        registry.register(mock_tool("dangerous_tool", true));
        let context = ToolExecutionContext::new(Uuid::new_v4());

        let outcome = registry
            .execute(
                "dangerous_tool",
                serde_json::json!({ "message": "test" }),
                &context,
            )
            .await;

        assert!(matches!(outcome, Err(ToolError::ApprovalRequired(_))));
    }

    /// Tool whose validation always rejects the input, so `execute` is never
    /// reached through the registry.
    struct ValidateFailTool;

    #[async_trait]
    impl Tool for ValidateFailTool {
        fn name(&self) -> &str {
            "extended_blind_tool"
        }

        fn description(&self) -> &str {
            "always fails validation"
        }

        fn input_schema(&self) -> Value {
            serde_json::json!({ "type": "object" })
        }

        fn capabilities(&self) -> Vec<ToolCapability> {
            vec![ToolCapability::ReadFiles]
        }

        fn validate_input(&self, _input: &Value) -> Result<()> {
            Err(ToolError::InvalidInput("missing required param".into()))
        }

        async fn execute(
            &self,
            _input: Value,
            _context: &ToolExecutionContext,
        ) -> Result<ToolResult> {
            Ok(ToolResult::success("unreachable".to_string()))
        }
    }

    /// A non-core tool is added to the session's active set even when the
    /// call itself fails validation.
    #[tokio::test]
    async fn test_execute_jit_activates_extended_tool_even_on_failure() {
        let registry = ToolRegistry::new();
        registry.register(Arc::new(ValidateFailTool));
        let session_id = Uuid::new_v4();
        let context = ToolExecutionContext::new(session_id);

        let active_before = registry.active_tools(session_id);
        assert!(!active_before.contains("extended_blind_tool"));

        let outcome = registry
            .execute("extended_blind_tool", serde_json::json!({}), &context)
            .await;
        assert!(matches!(outcome, Err(ToolError::InvalidInput(_))));

        let active_after = registry.active_tools(session_id);
        assert!(
            active_after.contains("extended_blind_tool"),
            "extended tool must be activated before validation, even on a failing call"
        );
    }

    /// Running a core tool leaves the session's active set untouched.
    #[tokio::test]
    async fn test_execute_does_not_activate_core_tool() {
        let registry = ToolRegistry::new();
        registry.register(mock_tool("bash", false));
        let session_id = Uuid::new_v4();
        let context = ToolExecutionContext::new(session_id);

        let result = registry
            .execute("bash", serde_json::json!({ "message": "hi" }), &context)
            .await
            .unwrap();

        assert!(result.success);
        assert!(
            !registry.active_tools(session_id).contains("bash"),
            "core tools must never be added to the session active set"
        );
    }

    /// An auto-approving context lets an approval-gated tool run.
    #[tokio::test]
    async fn test_execute_with_auto_approve() {
        let registry = ToolRegistry::new();
        registry.register(mock_tool("dangerous_tool", true));
        let context = ToolExecutionContext::new(Uuid::new_v4()).with_auto_approve(true);

        let result = registry
            .execute(
                "dangerous_tool",
                serde_json::json!({ "message": "test" }),
                &context,
            )
            .await
            .unwrap();

        assert!(result.success);
    }
}