1use brainwires_core::Tool;
7
/// Coarse grouping of built-in tools, used by `ToolRegistry::get_by_category`.
///
/// Each category corresponds to a fixed list of tool names. Selecting a
/// category only yields tools that have actually been registered.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ToolCategory {
    /// File manipulation: `read_file`, `write_file`, `edit_file`, `patch_file`,
    /// `list_directory`, `search_files`, `delete_file`, `create_directory`.
    FileOps,
    /// Text search: `search_code`, `search_files`.
    Search,
    /// Embedding/RAG tools: `index_codebase`, `query_codebase`,
    /// `search_with_filters`, `get_rag_statistics`, `clear_rag_index`,
    /// `search_git_history`.
    SemanticSearch,
    /// Git tools: `git_status`, `git_diff`, `git_log`, `git_stage`, `git_commit`, ….
    Git,
    /// Task tools: `task_create`, `task_start`, `task_complete`, `task_list`, ….
    TaskManager,
    /// Agent tools: `agent_spawn`, `agent_status`, `agent_list`, `agent_stop`,
    /// `agent_await`.
    AgentPool,
    /// URL fetching: `fetch_url`.
    Web,
    /// Web search/browsing: `web_search`, `web_browse`, `web_scrape`.
    WebSearch,
    /// Shell execution: `execute_command`.
    Bash,
    /// Planning: `plan_task`.
    Planning,
    /// Context recall: `recall_context`.
    Context,
    /// Script orchestration: `execute_script`.
    Orchestrator,
    /// Code interpreters: `execute_code`.
    CodeExecution,
    /// Session task list: `task_list_write`.
    SessionTask,
    /// Validation: `check_duplicates`, `verify_build`, `check_syntax`.
    Validation,
}
42
43pub struct ToolRegistry {
59 tools: Vec<Tool>,
60}
61
62impl ToolRegistry {
63 pub fn new() -> Self {
65 Self { tools: vec![] }
66 }
67
68 pub fn with_builtins() -> Self {
70 let mut registry = Self::new();
71
72 registry.register_tools(crate::ToolSearchTool::get_tools());
74
75 #[cfg(feature = "native")]
77 {
78 registry.register_tools(crate::FileOpsTool::get_tools());
79 registry.register_tools(crate::BashTool::get_tools());
80 registry.register_tools(crate::GitTool::get_tools());
81 registry.register_tools(crate::WebTool::get_tools());
82 registry.register_tools(crate::SearchTool::get_tools());
83 registry.register_tools(crate::get_validation_tools());
84 }
85
86 #[cfg(feature = "orchestrator")]
88 registry.register_tools(crate::OrchestratorTool::get_tools());
89
90 #[cfg(feature = "interpreters")]
91 registry.register_tools(crate::CodeExecTool::get_tools());
92
93 #[cfg(feature = "rag")]
94 registry.register_tools(crate::SemanticSearchTool::get_tools());
95
96 registry
97 }
98
99 pub fn register(&mut self, tool: Tool) {
101 self.tools.push(tool);
102 }
103
104 pub fn register_tools(&mut self, tools: Vec<Tool>) {
106 self.tools.extend(tools);
107 }
108
109 pub fn get_all(&self) -> &[Tool] {
111 &self.tools
112 }
113
114 pub fn get_all_with_extra(&self, extra: &[Tool]) -> Vec<Tool> {
116 let mut all = self.tools.clone();
117 all.extend(extra.iter().cloned());
118 all
119 }
120
121 pub fn get(&self, name: &str) -> Option<&Tool> {
123 self.tools.iter().find(|t| t.name == name)
124 }
125
126 pub fn get_initial_tools(&self) -> Vec<&Tool> {
128 self.tools.iter().filter(|t| !t.defer_loading).collect()
129 }
130
131 pub fn get_deferred_tools(&self) -> Vec<&Tool> {
133 self.tools.iter().filter(|t| t.defer_loading).collect()
134 }
135
136 pub fn search_tools(&self, query: &str) -> Vec<&Tool> {
138 let query_lower = query.to_lowercase();
139 let query_terms: Vec<&str> = query_lower.split_whitespace().collect();
140
141 self.tools
142 .iter()
143 .filter(|tool| {
144 let name_lower = tool.name.to_lowercase();
145 let desc_lower = tool.description.to_lowercase();
146 query_terms
147 .iter()
148 .any(|term| name_lower.contains(term) || desc_lower.contains(term))
149 })
150 .collect()
151 }
152
153 pub fn get_by_category(&self, category: ToolCategory) -> Vec<&Tool> {
155 let names: &[&str] = match category {
156 ToolCategory::FileOps => &[
157 "read_file",
158 "write_file",
159 "edit_file",
160 "patch_file",
161 "list_directory",
162 "search_files",
163 "delete_file",
164 "create_directory",
165 ],
166 ToolCategory::Search => &["search_code", "search_files"],
167 ToolCategory::SemanticSearch => &[
168 "index_codebase",
169 "query_codebase",
170 "search_with_filters",
171 "get_rag_statistics",
172 "clear_rag_index",
173 "search_git_history",
174 ],
175 ToolCategory::Git => &[
176 "git_status",
177 "git_diff",
178 "git_log",
179 "git_stage",
180 "git_unstage",
181 "git_commit",
182 "git_push",
183 "git_pull",
184 "git_fetch",
185 "git_discard",
186 "git_branch",
187 ],
188 ToolCategory::TaskManager => &[
189 "task_create",
190 "task_start",
191 "task_complete",
192 "task_list",
193 "task_skip",
194 "task_add",
195 "task_block",
196 "task_depends",
197 "task_ready",
198 "task_time",
199 ],
200 ToolCategory::AgentPool => &[
201 "agent_spawn",
202 "agent_status",
203 "agent_list",
204 "agent_stop",
205 "agent_await",
206 ],
207 ToolCategory::Web => &["fetch_url"],
208 ToolCategory::WebSearch => &["web_search", "web_browse", "web_scrape"],
209 ToolCategory::Bash => &["execute_command"],
210 ToolCategory::Planning => &["plan_task"],
211 ToolCategory::Context => &["recall_context"],
212 ToolCategory::Orchestrator => &["execute_script"],
213 ToolCategory::CodeExecution => &["execute_code"],
214 ToolCategory::SessionTask => &["task_list_write"],
215 ToolCategory::Validation => &["check_duplicates", "verify_build", "check_syntax"],
216 };
217
218 self.tools
219 .iter()
220 .filter(|t| names.contains(&t.name.as_str()))
221 .collect()
222 }
223
224 pub fn get_all_with_mcp(&self, mcp_tools: &[Tool]) -> Vec<Tool> {
226 self.get_all_with_extra(mcp_tools)
227 }
228
229 pub const CORE_TOOL_NAMES: &'static [&'static str] = &[
235 "edit_file",
236 "execute_command",
237 "git_commit",
238 "git_diff",
239 "git_log",
240 "git_stage",
241 "git_status",
242 "index_codebase",
243 "list_directory",
244 "query_codebase",
245 "read_file",
246 "search_code",
247 "search_tools",
248 "write_file",
249 ];
250
251 pub fn get_core(&self) -> Vec<&Tool> {
255 Self::CORE_TOOL_NAMES
256 .iter()
257 .filter_map(|name| self.tools.iter().find(|t| t.name == *name))
258 .collect()
259 }
260
261 pub fn get_core_with_extras(&self, extra_names: &[String]) -> Vec<&Tool> {
265 let mut out = self.get_core();
266 for name in extra_names {
267 if Self::CORE_TOOL_NAMES.contains(&name.as_str()) {
268 continue; }
270 if let Some(tool) = self.tools.iter().find(|t| t.name == *name) {
271 out.push(tool);
272 }
273 }
274 out
275 }
276
277 pub fn get_primary(&self) -> Vec<&Tool> {
279 let primary_names = ["execute_script", "search_tools"];
280 self.tools
281 .iter()
282 .filter(|t| primary_names.contains(&t.name.as_str()))
283 .collect()
284 }
285
286 #[cfg(feature = "rag")]
291 pub fn semantic_search_tools(
292 &self,
293 query: &str,
294 limit: usize,
295 min_score: f32,
296 ) -> anyhow::Result<Vec<(&Tool, f32)>> {
297 let tool_pairs: Vec<(String, String)> = self
298 .tools
299 .iter()
300 .map(|t| (t.name.clone(), t.description.clone()))
301 .collect();
302
303 let index = crate::tool_embedding::ToolEmbeddingIndex::build(&tool_pairs)?;
304 let results = index.search(query, limit, min_score)?;
305
306 Ok(results
307 .into_iter()
308 .filter_map(|(name, score)| self.get(&name).map(|tool| (tool, score)))
309 .collect())
310 }
311
312 pub fn filtered_view(&self, allow: &[&str]) -> Vec<Tool> {
321 self.tools
322 .iter()
323 .filter(|t| allow.contains(&t.name.as_str()))
324 .cloned()
325 .collect()
326 }
327
328 pub fn len(&self) -> usize {
330 self.tools.len()
331 }
332
333 pub fn is_empty(&self) -> bool {
335 self.tools.is_empty()
336 }
337}
338
339impl Default for ToolRegistry {
340 fn default() -> Self {
341 Self::new()
342 }
343}
344
#[cfg(test)]
mod tests {
    use super::*;
    use brainwires_core::ToolInputSchema;
    use std::collections::HashMap;

    /// Builds a minimal tool named `name` with description `"A {name} tool"`.
    fn make_tool(name: &str, defer: bool) -> Tool {
        Tool {
            name: name.to_string(),
            description: format!("A {} tool", name),
            input_schema: ToolInputSchema::object(HashMap::new(), vec![]),
            requires_approval: false,
            defer_loading: defer,
            ..Default::default()
        }
    }

    #[test]
    fn test_new_is_empty() {
        let registry = ToolRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
    }

    #[test]
    fn test_register_single() {
        let mut registry = ToolRegistry::new();
        registry.register(make_tool("test_tool", false));
        assert_eq!(registry.len(), 1);
        assert!(registry.get("test_tool").is_some());
    }

    #[test]
    fn test_register_multiple() {
        let mut registry = ToolRegistry::new();
        registry.register_tools(vec![make_tool("tool1", false), make_tool("tool2", false)]);
        assert_eq!(registry.len(), 2);
    }

    #[test]
    fn test_get_by_name() {
        let mut registry = ToolRegistry::new();
        registry.register(make_tool("my_tool", false));

        assert!(registry.get("my_tool").is_some());
        assert!(registry.get("nonexistent").is_none());
    }

    #[test]
    fn test_get_core_preserves_canonical_order() {
        let mut registry = ToolRegistry::new();
        // Register in reverse so the result order can only come from
        // CORE_TOOL_NAMES, not from registration order.
        for name in ToolRegistry::CORE_TOOL_NAMES.iter().rev() {
            registry.register(make_tool(name, false));
        }

        let core_names: Vec<&str> = registry
            .get_core()
            .iter()
            .map(|t| t.name.as_str())
            .collect();
        let expected: Vec<&str> = ToolRegistry::CORE_TOOL_NAMES.to_vec();
        assert_eq!(core_names, expected);
    }

    #[test]
    fn test_get_core_with_extras_appends_unknown_core() {
        let mut registry = ToolRegistry::new();
        for name in ToolRegistry::CORE_TOOL_NAMES {
            registry.register(make_tool(name, false));
        }
        registry.register(make_tool("extra_one", false));
        registry.register(make_tool("extra_two", false));

        // A core name and an unregistered name are skipped; the remaining
        // extras are appended in request order.
        let extras = vec![
            "extra_one".to_string(),
            "read_file".to_string(),
            "does_not_exist".to_string(),
            "extra_two".to_string(),
        ];
        let names: Vec<&str> = registry
            .get_core_with_extras(&extras)
            .iter()
            .map(|t| t.name.as_str())
            .collect();

        let mut expected: Vec<&str> = ToolRegistry::CORE_TOOL_NAMES.to_vec();
        expected.push("extra_one");
        expected.push("extra_two");
        assert_eq!(names, expected);
    }

    #[test]
    fn test_initial_vs_deferred() {
        let mut registry = ToolRegistry::new();
        registry.register(make_tool("initial", false));
        registry.register(make_tool("deferred", true));

        assert_eq!(registry.get_initial_tools().len(), 1);
        assert_eq!(registry.get_initial_tools()[0].name, "initial");

        assert_eq!(registry.get_deferred_tools().len(), 1);
        assert_eq!(registry.get_deferred_tools()[0].name, "deferred");
    }

    #[test]
    fn test_search_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(Tool {
            name: "read_file".to_string(),
            description: "Read a file from disk".to_string(),
            input_schema: ToolInputSchema::object(HashMap::new(), vec![]),
            ..Default::default()
        });
        registry.register(Tool {
            name: "write_file".to_string(),
            description: "Write content to a file".to_string(),
            input_schema: ToolInputSchema::object(HashMap::new(), vec![]),
            ..Default::default()
        });
        registry.register(Tool {
            name: "execute_command".to_string(),
            description: "Execute a bash command".to_string(),
            input_schema: ToolInputSchema::object(HashMap::new(), vec![]),
            ..Default::default()
        });

        let results = registry.search_tools("file");
        assert_eq!(results.len(), 2);

        let results = registry.search_tools("bash");
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn search_tools_is_case_insensitive_and_matches_any_term() {
        let mut registry = ToolRegistry::new();
        registry.register_tools(vec![
            make_tool("read_file", false),
            make_tool("git_status", false),
            make_tool("execute_command", false),
        ]);

        // Upper-case query still matches the lower-case name.
        assert_eq!(registry.search_tools("READ").len(), 1);
        // Terms are OR-ed: matching either term is enough.
        assert_eq!(registry.search_tools("read git").len(), 2);
        // A blank query has no terms and therefore matches nothing.
        assert!(registry.search_tools("   ").is_empty());
    }

    #[test]
    fn test_get_all_with_extra() {
        let mut registry = ToolRegistry::new();
        registry.register(make_tool("builtin", false));

        let extra = vec![make_tool("mcp_tool", false)];
        let all = registry.get_all_with_extra(&extra);
        assert_eq!(all.len(), 2);
    }

    #[test]
    fn get_all_with_mcp_appends_after_builtins() {
        let mut registry = ToolRegistry::new();
        registry.register(make_tool("builtin", false));

        let mcp = vec![make_tool("mcp_tool", false)];
        let names: Vec<String> = registry
            .get_all_with_mcp(&mcp)
            .into_iter()
            .map(|t| t.name)
            .collect();
        assert_eq!(names, vec!["builtin".to_string(), "mcp_tool".to_string()]);
    }

    #[test]
    fn get_by_category_returns_only_matching_registered_tools() {
        let mut registry = ToolRegistry::new();
        registry.register_tools(vec![
            make_tool("git_status", false),
            make_tool("git_diff", false),
            make_tool("read_file", false),
            make_tool("execute_command", false),
        ]);

        let git: Vec<&str> = registry
            .get_by_category(ToolCategory::Git)
            .iter()
            .map(|t| t.name.as_str())
            .collect();
        assert_eq!(git, vec!["git_status", "git_diff"]);

        let bash = registry.get_by_category(ToolCategory::Bash);
        assert_eq!(bash.len(), 1);
        assert_eq!(bash[0].name, "execute_command");

        // A category with no registered members yields nothing.
        assert!(registry.get_by_category(ToolCategory::Validation).is_empty());
    }

    #[test]
    fn get_primary_returns_only_primary_tools() {
        let mut registry = ToolRegistry::new();
        registry.register_tools(vec![
            make_tool("read_file", false),
            make_tool("search_tools", false),
            make_tool("execute_script", false),
        ]);

        let names: Vec<&str> = registry
            .get_primary()
            .iter()
            .map(|t| t.name.as_str())
            .collect();
        assert_eq!(names, vec!["search_tools", "execute_script"]);
    }

    #[test]
    fn test_no_duplicate_names_in_builtins() {
        let registry = ToolRegistry::with_builtins();
        let mut seen = std::collections::HashSet::new();
        for tool in registry.get_all() {
            assert!(
                seen.insert(tool.name.clone()),
                "Duplicate tool name: {}",
                tool.name
            );
        }
    }

    #[test]
    fn filtered_view_returns_only_named_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(make_tool("read_file", false));
        registry.register(make_tool("write_file", false));
        registry.register(make_tool("execute_command", false));

        let view = registry.filtered_view(&["read_file", "execute_command"]);
        assert_eq!(view.len(), 2);
        let names: Vec<&str> = view.iter().map(|t| t.name.as_str()).collect();
        assert!(names.contains(&"read_file"));
        assert!(names.contains(&"execute_command"));
        assert!(!names.contains(&"write_file"));
    }

    #[test]
    fn filtered_view_unknown_names_are_silently_skipped() {
        let mut registry = ToolRegistry::new();
        registry.register(make_tool("read_file", false));

        let view = registry.filtered_view(&["read_file", "nonexistent"]);
        assert_eq!(view.len(), 1);
        assert_eq!(view[0].name, "read_file");
    }

    #[test]
    fn filtered_view_empty_allow_list_returns_empty() {
        let mut registry = ToolRegistry::new();
        registry.register(make_tool("read_file", false));

        let view = registry.filtered_view(&[]);
        assert!(view.is_empty());
    }
}