use regex::Regex;
use serde::Deserialize;
use serde_json::{Value, json};
use std::collections::HashMap;
use crate::ToolRegistry;
use brainwires_core::{Tool, ToolContext, ToolInputSchema, ToolResult};
/// Matching strategy selected by the `mode` argument of the `search_tools` tool.
///
/// Deserialized from the lowercase variant names (`"keyword"`, `"regex"`,
/// `"semantic"`); `Keyword` is used when no mode is supplied.
#[derive(Debug, Clone, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "lowercase")]
pub enum SearchMode {
/// Case-insensitive substring match of each whitespace-separated query term
/// against the tool name and description.
#[default]
Keyword,
/// Regular-expression match against the text `"<name> <description>"`.
Regex,
/// Embedding-similarity search; rejected with an error unless the crate is
/// built with the `rag` feature.
Semantic,
}
/// Stateless holder for the `search_tools` tool: its definition and its
/// executor are exposed as associated functions on this type.
pub struct ToolSearchTool;
impl ToolSearchTool {
/// Returns every tool definition contributed by this module — currently
/// just the single `search_tools` tool.
pub fn get_tools() -> Vec<Tool> {
    let search = Self::search_tools_tool();
    vec![search]
}
/// Builds the definition of the `search_tools` tool.
///
/// The input schema declares five properties (`query`, `mode`,
/// `include_deferred`, `limit`, `min_score`); only `query` is required.
/// The tool needs no approval and is never deferred.
fn search_tools_tool() -> Tool {
    let properties = HashMap::from([
        (
            "query".to_string(),
            json!({"type": "string", "description": "Search query to find relevant tools"}),
        ),
        (
            "mode".to_string(),
            json!({"type": "string", "enum": ["keyword", "regex", "semantic"], "description": "Search mode: keyword (substring match), regex (pattern match), or semantic (embedding similarity, requires rag feature)", "default": "keyword"}),
        ),
        (
            "include_deferred".to_string(),
            json!({"type": "boolean", "description": "Include deferred tools", "default": true}),
        ),
        (
            "limit".to_string(),
            json!({"type": "integer", "description": "Maximum number of results to return (semantic mode only)", "default": 10}),
        ),
        (
            "min_score".to_string(),
            json!({"type": "number", "description": "Minimum similarity score 0.0-1.0 (semantic mode only)", "default": 0.3}),
        ),
    ]);
    let required = vec!["query".to_string()];

    Tool {
        name: String::from("search_tools"),
        description: String::from("Search for available tools by name or description."),
        input_schema: ToolInputSchema::object(properties, required),
        requires_approval: false,
        defer_loading: false,
        ..Default::default()
    }
}
/// Executes a tool owned by this module and wraps the outcome in a `ToolResult`.
///
/// * `tool_use_id` – identifier echoed back in the returned `ToolResult`.
/// * `tool_name` – tool to run; only `"search_tools"` is recognised.
/// * `input` – JSON arguments for the tool.
/// * `_context` – unused by this tool.
/// * `registry` – registry whose tools are searched.
///
/// Never panics on bad input: an unknown tool name or a failed search is
/// reported as an error result whose message starts with
/// `"Tool search failed: "`.
// `fields(tool_name)` without a value declares an *empty* span field and also
// suppresses the automatic capture of the same-named argument, so the tool
// name was never recorded. Giving the field an explicit value fixes that.
#[tracing::instrument(
    name = "tool.execute",
    skip(input, _context, registry),
    fields(tool_name = %tool_name)
)]
pub fn execute(
    tool_use_id: &str,
    tool_name: &str,
    input: &Value,
    _context: &ToolContext,
    registry: &ToolRegistry,
) -> ToolResult {
    let outcome = match tool_name {
        "search_tools" => Self::search_tools(input, registry),
        other => Err(anyhow::anyhow!("Unknown tool search tool: {}", other)),
    };

    match outcome {
        Ok(output) => ToolResult::success(tool_use_id.to_string(), output),
        Err(e) => ToolResult::error(
            tool_use_id.to_string(),
            format!("Tool search failed: {}", e),
        ),
    }
}
fn search_tools(input: &Value, registry: &ToolRegistry) -> anyhow::Result<String> {
#[derive(Deserialize)]
#[allow(dead_code)] struct Input {
query: String,
#[serde(default)]
mode: SearchMode,
#[serde(default = "dt")]
include_deferred: bool,
#[serde(default = "default_limit")]
limit: usize,
#[serde(default = "default_min_score")]
min_score: f32,
}
fn dt() -> bool {
true
}
fn default_limit() -> usize {
10
}
fn default_min_score() -> f32 {
0.3
}
let params: Input = serde_json::from_value(input.clone())?;
#[cfg(feature = "rag")]
if params.mode == SearchMode::Semantic {
return Self::search_tools_semantic(
¶ms.query,
registry,
params.include_deferred,
params.limit,
params.min_score,
);
}
#[cfg(not(feature = "rag"))]
if params.mode == SearchMode::Semantic {
return Err(anyhow::anyhow!(
"Semantic search mode requires the 'rag' feature to be enabled. Use 'keyword' or 'regex' mode instead."
));
}
if params.mode == SearchMode::Regex && params.query.len() > 200 {
return Err(anyhow::anyhow!(
"Regex pattern exceeds maximum length of 200 characters (got {})",
params.query.len()
));
}
let regex =
if params.mode == SearchMode::Regex {
Some(Regex::new(¶ms.query).map_err(|e| {
anyhow::anyhow!("Invalid regex pattern '{}': {}", params.query, e)
})?)
} else {
None
};
let query_lower = params.query.to_lowercase();
let query_terms: Vec<&str> = query_lower.split_whitespace().collect();
let matching_tools: Vec<&Tool> = registry
.get_all()
.iter()
.filter(|tool| {
if tool.defer_loading && !params.include_deferred {
return false;
}
let search_text = format!("{} {}", tool.name, tool.description);
match ®ex {
Some(re) => re.is_match(&search_text),
None => {
let name_lower = tool.name.to_lowercase();
let desc_lower = tool.description.to_lowercase();
query_terms
.iter()
.any(|term| name_lower.contains(term) || desc_lower.contains(term))
}
}
})
.collect();
if matching_tools.is_empty() {
return Ok(format!(
"No tools found matching query: \"{}\"",
params.query
));
}
let mut result = format!(
"Found {} tools matching \"{}\":\n\n",
matching_tools.len(),
params.query
);
for tool in matching_tools {
Self::format_tool(&mut result, tool, None);
}
Ok(result)
}
/// Appends a Markdown section describing `tool` to `result`.
///
/// The section consists of a `## <name>` heading, an optional
/// `**Similarity:**` line (two decimals) when `score` is given, the
/// description, and — when the schema declares properties — one bullet per
/// parameter with its `type` and `description` (falling back to `unknown`
/// and `No description`). A blank line terminates the section.
///
/// Parameters are listed in ascending name order so the output is
/// reproducible instead of following map iteration order.
fn format_tool(result: &mut String, tool: &Tool, score: Option<f32>) {
    result.push_str(&format!("## {}\n", tool.name));
    if let Some(s) = score {
        result.push_str(&format!("**Similarity:** {:.2}\n", s));
    }
    result.push_str(&format!("**Description:** {}\n", tool.description));

    if let Some(props) = &tool.input_schema.properties {
        result.push_str("**Parameters:**\n");

        let mut params: Vec<_> = props.iter().collect();
        params.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));

        for (name, schema) in params {
            let desc = schema
                .get("description")
                .and_then(|v| v.as_str())
                .unwrap_or("No description");
            let ptype = schema
                .get("type")
                .and_then(|v| v.as_str())
                .unwrap_or("unknown");
            result.push_str(&format!("  - `{}` ({}): {}\n", name, ptype, desc));
        }
    }
    result.push('\n');
}
/// Embedding-similarity search over the registry (only compiled with the
/// `rag` feature).
///
/// Builds, or reuses from a process-wide cache, a `ToolEmbeddingIndex` over
/// the (name, description) pairs of the candidate tools. It then asks the
/// index for up to `limit` matches scoring at least `min_score` and formats
/// them with their similarity.
///
/// The cache is written once. It is reused only when the current candidate
/// set is exactly the set it was built from; otherwise a throw-away index is
/// built for this call.
///
/// # Errors
///
/// Propagates failures from building the index or running the search. An
/// index build failure is returned as an error rather than panicking.
#[cfg(feature = "rag")]
fn search_tools_semantic(
    query: &str,
    registry: &ToolRegistry,
    include_deferred: bool,
    limit: usize,
    min_score: f32,
) -> anyhow::Result<String> {
    use crate::tool_embedding::ToolEmbeddingIndex;
    use std::sync::OnceLock;

    // The pairs the cached index was built from are stored next to it, so
    // staleness is detected by content, not merely by the number of tools.
    static CACHED_INDEX: OnceLock<(Vec<(String, String)>, ToolEmbeddingIndex)> =
        OnceLock::new();

    let tool_pairs: Vec<(String, String)> = registry
        .get_all()
        .iter()
        .filter(|t| include_deferred || !t.defer_loading)
        .map(|t| (t.name.clone(), t.description.clone()))
        .collect();

    let cached = match CACHED_INDEX.get() {
        Some(entry) => entry,
        None => {
            // Build outside `get_or_init` so a failure can be returned with `?`.
            let built = ToolEmbeddingIndex::build(&tool_pairs)?;
            // If another thread initialised the cache meanwhile, its value
            // wins and ours is dropped; the freshness check below still holds.
            CACHED_INDEX.get_or_init(|| (tool_pairs.clone(), built))
        }
    };

    let search_results = if cached.0 == tool_pairs {
        cached.1.search(query, limit, min_score)?
    } else {
        let fresh_index = ToolEmbeddingIndex::build(&tool_pairs)?;
        fresh_index.search(query, limit, min_score)?
    };

    if search_results.is_empty() {
        return Ok(format!(
            "No tools found semantically matching query: \"{}\" (min_score: {:.2})",
            query, min_score
        ));
    }

    let mut result = format!(
        "Found {} tools semantically matching \"{}\":\n\n",
        search_results.len(),
        query
    );
    for (tool_name, score) in &search_results {
        if let Some(tool) = registry.get(tool_name) {
            Self::format_tool(&mut result, tool, Some(*score));
        }
    }
    Ok(result)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The module exposes exactly one tool, named `search_tools`.
    #[test]
    fn test_get_tools() {
        let tools = ToolSearchTool::get_tools();
        let names: Vec<&str> = tools.iter().map(|tool| tool.name.as_str()).collect();
        assert_eq!(names, ["search_tools"]);
    }

    /// `SearchMode` defaults to keyword matching.
    #[test]
    fn test_search_mode_default() {
        assert_eq!(SearchMode::default(), SearchMode::Keyword);
    }
}