1use brainwires_core::Tool;
7
/// Broad functional groupings of built-in tools.
///
/// Used by `ToolRegistry::get_by_category` to select registered tools by a
/// fixed, per-category list of tool names. A tool name may appear in more than
/// one category (e.g. `search_files` is in both `FileOps` and `Search`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ToolCategory {
    /// File and directory tools (e.g. `read_file`, `write_file`, `edit_file`).
    FileOps,
    /// Plain text search tools (`search_code`, `search_files`).
    Search,
    /// Codebase indexing and retrieval tools (e.g. `index_codebase`, `query_codebase`).
    SemanticSearch,
    /// Git operations (e.g. `git_status`, `git_diff`, `git_commit`).
    Git,
    /// Task management tools (e.g. `task_create`, `task_complete`, `task_list`).
    TaskManager,
    /// Agent lifecycle tools (e.g. `agent_spawn`, `agent_status`, `agent_await`).
    AgentPool,
    /// URL fetching (`fetch_url`).
    Web,
    /// Web search, browsing and scraping (`web_search`, `web_browse`, `web_scrape`).
    WebSearch,
    /// Shell command execution (`execute_command`).
    Bash,
    /// Planning (`plan_task`).
    Planning,
    /// Context recall (`recall_context`).
    Context,
    /// Script orchestration (`execute_script`).
    Orchestrator,
    /// Code execution (`execute_code`).
    CodeExecution,
    /// Session task-list writing (`task_list_write`).
    SessionTask,
    /// Validation checks (`check_duplicates`, `verify_build`, `check_syntax`).
    Validation,
}
42
43pub struct ToolRegistry {
60 tools: Vec<Tool>,
61}
62
63impl ToolRegistry {
64 pub fn new() -> Self {
66 Self { tools: vec![] }
67 }
68
69 pub fn with_runtime_meta_tools() -> Self {
74 let mut registry = Self::new();
75 registry.register_tools(crate::ToolSearchTool::get_tools());
76 registry
77 }
78
79 pub fn register(&mut self, tool: Tool) {
81 self.tools.push(tool);
82 }
83
84 pub fn register_tools(&mut self, tools: Vec<Tool>) {
86 self.tools.extend(tools);
87 }
88
89 pub fn get_all(&self) -> &[Tool] {
91 &self.tools
92 }
93
94 pub fn get_all_with_extra(&self, extra: &[Tool]) -> Vec<Tool> {
96 let mut all = self.tools.clone();
97 all.extend(extra.iter().cloned());
98 all
99 }
100
101 pub fn get(&self, name: &str) -> Option<&Tool> {
103 self.tools.iter().find(|t| t.name == name)
104 }
105
106 pub fn get_initial_tools(&self) -> Vec<&Tool> {
108 self.tools.iter().filter(|t| !t.defer_loading).collect()
109 }
110
111 pub fn get_deferred_tools(&self) -> Vec<&Tool> {
113 self.tools.iter().filter(|t| t.defer_loading).collect()
114 }
115
116 pub fn search_tools(&self, query: &str) -> Vec<&Tool> {
118 let query_lower = query.to_lowercase();
119 let query_terms: Vec<&str> = query_lower.split_whitespace().collect();
120
121 self.tools
122 .iter()
123 .filter(|tool| {
124 let name_lower = tool.name.to_lowercase();
125 let desc_lower = tool.description.to_lowercase();
126 query_terms
127 .iter()
128 .any(|term| name_lower.contains(term) || desc_lower.contains(term))
129 })
130 .collect()
131 }
132
133 pub fn get_by_category(&self, category: ToolCategory) -> Vec<&Tool> {
135 let names: &[&str] = match category {
136 ToolCategory::FileOps => &[
137 "read_file",
138 "write_file",
139 "edit_file",
140 "patch_file",
141 "list_directory",
142 "search_files",
143 "delete_file",
144 "create_directory",
145 ],
146 ToolCategory::Search => &["search_code", "search_files"],
147 ToolCategory::SemanticSearch => &[
148 "index_codebase",
149 "query_codebase",
150 "search_with_filters",
151 "get_rag_statistics",
152 "clear_rag_index",
153 "search_git_history",
154 ],
155 ToolCategory::Git => &[
156 "git_status",
157 "git_diff",
158 "git_log",
159 "git_stage",
160 "git_unstage",
161 "git_commit",
162 "git_push",
163 "git_pull",
164 "git_fetch",
165 "git_discard",
166 "git_branch",
167 ],
168 ToolCategory::TaskManager => &[
169 "task_create",
170 "task_start",
171 "task_complete",
172 "task_list",
173 "task_skip",
174 "task_add",
175 "task_block",
176 "task_depends",
177 "task_ready",
178 "task_time",
179 ],
180 ToolCategory::AgentPool => &[
181 "agent_spawn",
182 "agent_status",
183 "agent_list",
184 "agent_stop",
185 "agent_await",
186 ],
187 ToolCategory::Web => &["fetch_url"],
188 ToolCategory::WebSearch => &["web_search", "web_browse", "web_scrape"],
189 ToolCategory::Bash => &["execute_command"],
190 ToolCategory::Planning => &["plan_task"],
191 ToolCategory::Context => &["recall_context"],
192 ToolCategory::Orchestrator => &["execute_script"],
193 ToolCategory::CodeExecution => &["execute_code"],
194 ToolCategory::SessionTask => &["task_list_write"],
195 ToolCategory::Validation => &["check_duplicates", "verify_build", "check_syntax"],
196 };
197
198 self.tools
199 .iter()
200 .filter(|t| names.contains(&t.name.as_str()))
201 .collect()
202 }
203
204 pub fn get_all_with_mcp(&self, mcp_tools: &[Tool]) -> Vec<Tool> {
206 self.get_all_with_extra(mcp_tools)
207 }
208
209 pub const CORE_TOOL_NAMES: &'static [&'static str] = &[
215 "edit_file",
216 "execute_command",
217 "git_commit",
218 "git_diff",
219 "git_log",
220 "git_stage",
221 "git_status",
222 "index_codebase",
223 "list_directory",
224 "query_codebase",
225 "read_file",
226 "search_code",
227 "search_tools",
228 "write_file",
229 ];
230
231 pub fn get_core(&self) -> Vec<&Tool> {
235 Self::CORE_TOOL_NAMES
236 .iter()
237 .filter_map(|name| self.tools.iter().find(|t| t.name == *name))
238 .collect()
239 }
240
241 pub fn get_core_with_extras(&self, extra_names: &[String]) -> Vec<&Tool> {
245 let mut out = self.get_core();
246 for name in extra_names {
247 if Self::CORE_TOOL_NAMES.contains(&name.as_str()) {
248 continue; }
250 if let Some(tool) = self.tools.iter().find(|t| t.name == *name) {
251 out.push(tool);
252 }
253 }
254 out
255 }
256
257 pub fn get_primary(&self) -> Vec<&Tool> {
259 let primary_names = ["execute_script", "search_tools"];
260 self.tools
261 .iter()
262 .filter(|t| primary_names.contains(&t.name.as_str()))
263 .collect()
264 }
265
266 #[cfg(feature = "rag")]
271 pub fn semantic_search_tools(
272 &self,
273 query: &str,
274 limit: usize,
275 min_score: f32,
276 ) -> anyhow::Result<Vec<(&Tool, f32)>> {
277 let tool_pairs: Vec<(String, String)> = self
278 .tools
279 .iter()
280 .map(|t| (t.name.clone(), t.description.clone()))
281 .collect();
282
283 let index = crate::tool_embedding::ToolEmbeddingIndex::build(&tool_pairs)?;
284 let results = index.search(query, limit, min_score)?;
285
286 Ok(results
287 .into_iter()
288 .filter_map(|(name, score)| self.get(&name).map(|tool| (tool, score)))
289 .collect())
290 }
291
292 pub fn filtered_view(&self, allow: &[&str]) -> Vec<Tool> {
301 self.tools
302 .iter()
303 .filter(|t| allow.contains(&t.name.as_str()))
304 .cloned()
305 .collect()
306 }
307
308 pub fn len(&self) -> usize {
310 self.tools.len()
311 }
312
313 pub fn is_empty(&self) -> bool {
315 self.tools.is_empty()
316 }
317}
318
319impl Default for ToolRegistry {
320 fn default() -> Self {
321 Self::new()
322 }
323}
324
#[cfg(test)]
mod tests {
    // Unit tests for `ToolRegistry` registration, lookup and filtering.
    use super::*;
    use brainwires_core::ToolInputSchema;
    use std::collections::HashMap;

    /// Builds a minimal tool with an empty object schema and a generated description.
    fn make_tool(name: &str, defer: bool) -> Tool {
        Tool {
            name: name.to_string(),
            description: format!("A {} tool", name),
            input_schema: ToolInputSchema::object(HashMap::new(), vec![]),
            requires_approval: false,
            defer_loading: defer,
            ..Default::default()
        }
    }

    /// Collects tool names from a slice of tool references.
    fn names_of<'a>(tools: &[&'a Tool]) -> Vec<&'a str> {
        tools.iter().map(|t| t.name.as_str()).collect()
    }

    #[test]
    fn test_new_is_empty() {
        let registry = ToolRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
    }

    #[test]
    fn test_register_single() {
        let mut registry = ToolRegistry::new();
        registry.register(make_tool("test_tool", false));
        assert_eq!(registry.len(), 1);
        assert!(registry.get("test_tool").is_some());
    }

    #[test]
    fn test_register_multiple() {
        let mut registry = ToolRegistry::new();
        registry.register_tools(vec![make_tool("tool1", false), make_tool("tool2", false)]);
        assert_eq!(registry.len(), 2);
    }

    #[test]
    fn test_get_by_name() {
        let mut registry = ToolRegistry::new();
        registry.register(make_tool("my_tool", false));

        assert!(registry.get("my_tool").is_some());
        assert!(registry.get("nonexistent").is_none());
    }

    #[test]
    fn test_get_core_preserves_canonical_order() {
        let mut registry = ToolRegistry::new();
        // Register core tools in reverse so that registration order differs
        // from CORE_TOOL_NAMES order; get_core must follow the constant.
        for name in ToolRegistry::CORE_TOOL_NAMES.iter().rev() {
            registry.register(make_tool(name, false));
        }

        let core_names: Vec<&str> = registry
            .get_core()
            .iter()
            .map(|t| t.name.as_str())
            .collect();
        let expected: Vec<&str> = ToolRegistry::CORE_TOOL_NAMES.to_vec();
        assert_eq!(core_names, expected);
    }

    #[test]
    fn test_get_core_skips_unregistered_core_names() {
        let mut registry = ToolRegistry::new();
        registry.register(make_tool("write_file", false));
        registry.register(make_tool("not_core", false));
        registry.register(make_tool("read_file", false));

        // Only registered core tools appear, in CORE_TOOL_NAMES order.
        assert_eq!(names_of(&registry.get_core()), vec!["read_file", "write_file"]);
    }

    #[test]
    fn test_get_core_with_extras_appends_unknown_core() {
        let mut registry = ToolRegistry::new();
        for name in ToolRegistry::CORE_TOOL_NAMES {
            registry.register(make_tool(name, false));
        }
        registry.register(make_tool("extra_one", false));
        registry.register(make_tool("extra_two", false));

        let extras = vec![
            // `read_file` is already core and `does_not_exist` is unregistered;
            // both must be skipped without disturbing the order of the others.
            "extra_one".to_string(),
            "read_file".to_string(),
            "does_not_exist".to_string(),
            "extra_two".to_string(),
        ];
        let names: Vec<&str> = registry
            .get_core_with_extras(&extras)
            .iter()
            .map(|t| t.name.as_str())
            .collect();

        let mut expected: Vec<&str> = ToolRegistry::CORE_TOOL_NAMES.to_vec();
        expected.push("extra_one");
        expected.push("extra_two");
        assert_eq!(names, expected);
    }

    #[test]
    fn test_initial_vs_deferred() {
        let mut registry = ToolRegistry::new();
        registry.register(make_tool("initial", false));
        registry.register(make_tool("deferred", true));

        assert_eq!(registry.get_initial_tools().len(), 1);
        assert_eq!(registry.get_initial_tools()[0].name, "initial");

        assert_eq!(registry.get_deferred_tools().len(), 1);
        assert_eq!(registry.get_deferred_tools()[0].name, "deferred");
    }

    #[test]
    fn test_search_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(Tool {
            name: "read_file".to_string(),
            description: "Read a file from disk".to_string(),
            input_schema: ToolInputSchema::object(HashMap::new(), vec![]),
            ..Default::default()
        });
        registry.register(Tool {
            name: "write_file".to_string(),
            description: "Write content to a file".to_string(),
            input_schema: ToolInputSchema::object(HashMap::new(), vec![]),
            ..Default::default()
        });
        registry.register(Tool {
            name: "execute_command".to_string(),
            description: "Execute a bash command".to_string(),
            input_schema: ToolInputSchema::object(HashMap::new(), vec![]),
            ..Default::default()
        });

        let results = registry.search_tools("file");
        assert_eq!(results.len(), 2);

        let results = registry.search_tools("bash");
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn test_search_tools_empty_query_matches_nothing() {
        let mut registry = ToolRegistry::new();
        registry.register(make_tool("read_file", false));

        assert!(registry.search_tools("").is_empty());
        assert!(registry.search_tools("   ").is_empty());
    }

    #[test]
    fn test_get_by_category_filters_by_name_list() {
        let mut registry = ToolRegistry::new();
        registry.register(make_tool("git_status", false));
        registry.register(make_tool("read_file", false));
        registry.register(make_tool("search_files", false));
        registry.register(make_tool("git_commit", false));

        // Results follow registration order.
        assert_eq!(
            names_of(&registry.get_by_category(ToolCategory::Git)),
            vec!["git_status", "git_commit"]
        );

        // `search_files` belongs to both FileOps and Search.
        assert_eq!(
            names_of(&registry.get_by_category(ToolCategory::FileOps)),
            vec!["read_file", "search_files"]
        );
        assert_eq!(
            names_of(&registry.get_by_category(ToolCategory::Search)),
            vec!["search_files"]
        );

        assert!(registry.get_by_category(ToolCategory::Bash).is_empty());
    }

    #[test]
    fn test_get_primary() {
        let mut registry = ToolRegistry::new();
        registry.register(make_tool("search_tools", false));
        registry.register(make_tool("read_file", false));
        registry.register(make_tool("execute_script", false));

        assert_eq!(
            names_of(&registry.get_primary()),
            vec!["search_tools", "execute_script"]
        );
    }

    #[test]
    fn test_get_all_with_extra() {
        let mut registry = ToolRegistry::new();
        registry.register(make_tool("builtin", false));

        let extra = vec![make_tool("mcp_tool", false)];
        let all = registry.get_all_with_extra(&extra);
        assert_eq!(all.len(), 2);
    }

    #[test]
    fn filtered_view_returns_only_named_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(make_tool("read_file", false));
        registry.register(make_tool("write_file", false));
        registry.register(make_tool("execute_command", false));

        let view = registry.filtered_view(&["read_file", "execute_command"]);
        assert_eq!(view.len(), 2);
        let names: Vec<&str> = view.iter().map(|t| t.name.as_str()).collect();
        assert!(names.contains(&"read_file"));
        assert!(names.contains(&"execute_command"));
        assert!(!names.contains(&"write_file"));
    }

    #[test]
    fn filtered_view_unknown_names_are_silently_skipped() {
        let mut registry = ToolRegistry::new();
        registry.register(make_tool("read_file", false));

        let view = registry.filtered_view(&["read_file", "nonexistent"]);
        // The unknown name contributes nothing; no error or panic.
        assert_eq!(view.len(), 1);
        assert_eq!(view[0].name, "read_file");
    }

    #[test]
    fn filtered_view_empty_allow_list_returns_empty() {
        let mut registry = ToolRegistry::new();
        registry.register(make_tool("read_file", false));

        let view = registry.filtered_view(&[]);
        assert!(view.is_empty());
    }
}