1use regex::Regex;
4use serde::Deserialize;
5use serde_json::{Value, json};
6use std::collections::HashMap;
7
8use crate::ToolRegistry;
9use brainwires_core::{Tool, ToolContext, ToolInputSchema, ToolResult};
10
11#[derive(Debug, Clone, Deserialize, PartialEq, Eq, Default)]
13#[serde(rename_all = "lowercase")]
14pub enum SearchMode {
15 #[default]
17 Keyword,
18 Regex,
20 Semantic,
22}
23
24pub struct ToolSearchTool;
26
27impl ToolSearchTool {
28 pub fn get_tools() -> Vec<Tool> {
30 vec![Self::search_tools_tool()]
31 }
32
33 fn search_tools_tool() -> Tool {
34 let mut properties = HashMap::new();
35 properties.insert(
36 "query".to_string(),
37 json!({"type": "string", "description": "Search query to find relevant tools"}),
38 );
39 properties.insert("mode".to_string(), json!({"type": "string", "enum": ["keyword", "regex", "semantic"], "description": "Search mode: keyword (substring match), regex (pattern match), or semantic (embedding similarity, requires rag feature)", "default": "keyword"}));
40 properties.insert(
41 "include_deferred".to_string(),
42 json!({"type": "boolean", "description": "Include deferred tools", "default": true}),
43 );
44 properties.insert(
45 "limit".to_string(),
46 json!({"type": "integer", "description": "Maximum number of results to return (semantic mode only)", "default": 10}),
47 );
48 properties.insert(
49 "min_score".to_string(),
50 json!({"type": "number", "description": "Minimum similarity score 0.0-1.0 (semantic mode only)", "default": 0.3}),
51 );
52 Tool {
53 name: "search_tools".to_string(),
54 description: "Search for available tools by name or description.".to_string(),
55 input_schema: ToolInputSchema::object(properties, vec!["query".to_string()]),
56 requires_approval: false,
57 defer_loading: false,
58 ..Default::default()
59 }
60 }
61
62 #[tracing::instrument(
64 name = "tool.execute",
65 skip(input, _context, registry),
66 fields(tool_name)
67 )]
68 pub fn execute(
69 tool_use_id: &str,
70 tool_name: &str,
71 input: &Value,
72 _context: &ToolContext,
73 registry: &ToolRegistry,
74 ) -> ToolResult {
75 let result = match tool_name {
76 "search_tools" => Self::search_tools(input, registry),
77 _ => Err(anyhow::anyhow!("Unknown tool search tool: {}", tool_name)),
78 };
79 match result {
80 Ok(output) => ToolResult::success(tool_use_id.to_string(), output),
81 Err(e) => ToolResult::error(
82 tool_use_id.to_string(),
83 format!("Tool search failed: {}", e),
84 ),
85 }
86 }
87
88 fn search_tools(input: &Value, registry: &ToolRegistry) -> anyhow::Result<String> {
89 #[derive(Deserialize)]
90 #[allow(dead_code)] struct Input {
92 query: String,
93 #[serde(default)]
94 mode: SearchMode,
95 #[serde(default = "dt")]
96 include_deferred: bool,
97 #[serde(default = "default_limit")]
98 limit: usize,
99 #[serde(default = "default_min_score")]
100 min_score: f32,
101 }
102 fn dt() -> bool {
103 true
104 }
105 fn default_limit() -> usize {
106 10
107 }
108 fn default_min_score() -> f32 {
109 0.3
110 }
111
112 let params: Input = serde_json::from_value(input.clone())?;
113
114 #[cfg(feature = "rag")]
116 if params.mode == SearchMode::Semantic {
117 return Self::search_tools_semantic(
118 ¶ms.query,
119 registry,
120 params.include_deferred,
121 params.limit,
122 params.min_score,
123 );
124 }
125
126 #[cfg(not(feature = "rag"))]
127 if params.mode == SearchMode::Semantic {
128 return Err(anyhow::anyhow!(
129 "Semantic search mode requires the 'rag' feature to be enabled. Use 'keyword' or 'regex' mode instead."
130 ));
131 }
132
133 if params.mode == SearchMode::Regex && params.query.len() > 200 {
134 return Err(anyhow::anyhow!(
135 "Regex pattern exceeds maximum length of 200 characters (got {})",
136 params.query.len()
137 ));
138 }
139
140 let regex =
141 if params.mode == SearchMode::Regex {
142 Some(Regex::new(¶ms.query).map_err(|e| {
143 anyhow::anyhow!("Invalid regex pattern '{}': {}", params.query, e)
144 })?)
145 } else {
146 None
147 };
148
149 let query_lower = params.query.to_lowercase();
150 let query_terms: Vec<&str> = query_lower.split_whitespace().collect();
151
152 let matching_tools: Vec<&Tool> = registry
153 .get_all()
154 .iter()
155 .filter(|tool| {
156 if tool.defer_loading && !params.include_deferred {
157 return false;
158 }
159 let search_text = format!("{} {}", tool.name, tool.description);
160 match ®ex {
161 Some(re) => re.is_match(&search_text),
162 None => {
163 let name_lower = tool.name.to_lowercase();
164 let desc_lower = tool.description.to_lowercase();
165 query_terms
166 .iter()
167 .any(|term| name_lower.contains(term) || desc_lower.contains(term))
168 }
169 }
170 })
171 .collect();
172
173 if matching_tools.is_empty() {
174 return Ok(format!(
175 "No tools found matching query: \"{}\"",
176 params.query
177 ));
178 }
179
180 let mut result = format!(
181 "Found {} tools matching \"{}\":\n\n",
182 matching_tools.len(),
183 params.query
184 );
185 for tool in matching_tools {
186 Self::format_tool(&mut result, tool, None);
187 }
188 Ok(result)
189 }
190
191 fn format_tool(result: &mut String, tool: &Tool, score: Option<f32>) {
193 result.push_str(&format!("## {}\n", tool.name));
194 if let Some(s) = score {
195 result.push_str(&format!("**Similarity:** {:.2}\n", s));
196 }
197 result.push_str(&format!("**Description:** {}\n", tool.description));
198 if let Some(props) = &tool.input_schema.properties {
199 result.push_str("**Parameters:**\n");
200 for (name, schema) in props {
201 let desc = schema
202 .get("description")
203 .and_then(|v| v.as_str())
204 .unwrap_or("No description");
205 let ptype = schema
206 .get("type")
207 .and_then(|v| v.as_str())
208 .unwrap_or("unknown");
209 result.push_str(&format!(" - `{}` ({}): {}\n", name, ptype, desc));
210 }
211 }
212 result.push('\n');
213 }
214
215 #[cfg(feature = "rag")]
217 fn search_tools_semantic(
218 query: &str,
219 registry: &ToolRegistry,
220 include_deferred: bool,
221 limit: usize,
222 min_score: f32,
223 ) -> anyhow::Result<String> {
224 use crate::tool_embedding::ToolEmbeddingIndex;
225 use std::sync::OnceLock;
226
227 static CACHED_INDEX: OnceLock<(usize, ToolEmbeddingIndex)> = OnceLock::new();
229
230 let tools: Vec<&Tool> = registry
231 .get_all()
232 .iter()
233 .filter(|t| include_deferred || !t.defer_loading)
234 .collect();
235
236 let tool_pairs: Vec<(String, String)> = tools
238 .iter()
239 .map(|t| (t.name.clone(), t.description.clone()))
240 .collect();
241
242 let index = CACHED_INDEX.get_or_init(|| {
247 let idx = ToolEmbeddingIndex::build(&tool_pairs)
248 .expect("Failed to build tool embedding index");
249 (tool_pairs.len(), idx)
250 });
251
252 let search_results = if index.0 != tool_pairs.len() {
255 let fresh_index = ToolEmbeddingIndex::build(&tool_pairs)?;
256 fresh_index.search(query, limit, min_score)?
257 } else {
258 index.1.search(query, limit, min_score)?
259 };
260
261 if search_results.is_empty() {
262 return Ok(format!(
263 "No tools found semantically matching query: \"{}\" (min_score: {:.2})",
264 query, min_score
265 ));
266 }
267
268 let mut result = format!(
269 "Found {} tools semantically matching \"{}\":\n\n",
270 search_results.len(),
271 query
272 );
273
274 for (tool_name, score) in &search_results {
275 if let Some(tool) = registry.get(tool_name) {
276 Self::format_tool(&mut result, tool, Some(*score));
277 }
278 }
279
280 Ok(result)
281 }
282}
283
#[cfg(test)]
mod tests {
    use super::*;

    /// The module exposes exactly one tool, `search_tools`, loaded eagerly and
    /// without an approval gate.
    #[test]
    fn test_get_tools() {
        let tools = ToolSearchTool::get_tools();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "search_tools");
        assert!(!tools[0].requires_approval);
        assert!(!tools[0].defer_loading);
    }

    #[test]
    fn test_search_mode_default() {
        let mode = SearchMode::default();
        assert_eq!(mode, SearchMode::Keyword);
    }

    /// The schema advertises `["keyword", "regex", "semantic"]`; each must
    /// deserialize to the matching variant.
    #[test]
    fn test_search_mode_deserializes_lowercase_names() {
        for (raw, expected) in [
            ("keyword", SearchMode::Keyword),
            ("regex", SearchMode::Regex),
            ("semantic", SearchMode::Semantic),
        ] {
            let mode: SearchMode =
                serde_json::from_value(serde_json::Value::from(raw)).expect("known mode");
            assert_eq!(mode, expected);
        }
    }

    /// Unknown or differently-cased mode names are rejected rather than
    /// silently falling back to the default.
    #[test]
    fn test_search_mode_rejects_unknown_names() {
        assert!(serde_json::from_value::<SearchMode>(serde_json::Value::from("fuzzy")).is_err());
        assert!(serde_json::from_value::<SearchMode>(serde_json::Value::from("Keyword")).is_err());
    }
}