brainwires_tool_system/
search.rs1use anyhow::Result;
2use ignore::WalkBuilder;
3use regex::Regex;
4use serde::Deserialize;
5use serde_json::{Value, json};
6use std::collections::HashMap;
7use std::fs;
8
9use brainwires_core::{Tool, ToolContext, ToolInputSchema, ToolResult};
10
/// Stateless provider of the code-search tool (`search_code`).
///
/// All functionality lives in associated functions; the type itself carries no data.
#[derive(Debug, Clone, Copy, Default)]
pub struct SearchTool;
14impl SearchTool {
15 pub fn get_tools() -> Vec<Tool> {
17 vec![Self::search_code_tool()]
18 }
19
20 fn search_code_tool() -> Tool {
21 let mut properties = HashMap::new();
22 properties.insert(
23 "pattern".to_string(),
24 json!({"type": "string", "description": "Regex pattern to search for"}),
25 );
26 properties.insert(
27 "path".to_string(),
28 json!({"type": "string", "description": "Path to search in", "default": "."}),
29 );
30 Tool {
31 name: "search_code".to_string(),
32 description: "Search for code patterns in files using regex.".to_string(),
33 input_schema: ToolInputSchema::object(properties, vec!["pattern".to_string()]),
34 requires_approval: false,
35 ..Default::default()
36 }
37 }
38
39 #[tracing::instrument(name = "tool.execute", skip(input, context), fields(tool_name))]
41 pub fn execute(
42 tool_use_id: &str,
43 tool_name: &str,
44 input: &Value,
45 context: &ToolContext,
46 ) -> ToolResult {
47 let result = match tool_name {
48 "search_code" => Self::search_code(input, context),
49 _ => Err(anyhow::anyhow!("Unknown search tool: {}", tool_name)),
50 };
51 match result {
52 Ok(output) => ToolResult::success(tool_use_id.to_string(), output),
53 Err(e) => ToolResult::error(tool_use_id.to_string(), format!("Search failed: {}", e)),
54 }
55 }
56
57 fn search_code(input: &Value, context: &ToolContext) -> Result<String> {
58 #[derive(Deserialize)]
59 struct Input {
60 pattern: String,
61 #[serde(default = "default_path")]
62 path: String,
63 }
64 fn default_path() -> String {
65 ".".to_string()
66 }
67
68 let params: Input = serde_json::from_value(input.clone())?;
69 let regex = Regex::new(¶ms.pattern)?;
70 let search_path = if params.path == "." {
71 &context.working_directory
72 } else {
73 ¶ms.path
74 };
75
76 let mut matches = Vec::new();
77 for entry in WalkBuilder::new(search_path).build() {
78 let entry = entry?;
79 if entry.path().is_file()
80 && let Ok(content) = fs::read_to_string(entry.path())
81 {
82 for (line_num, line) in content.lines().enumerate() {
83 if regex.is_match(line) {
84 matches.push(format!(
85 "{}:{} - {}",
86 entry.path().display(),
87 line_num + 1,
88 line.trim()
89 ));
90 if matches.len() >= 100 {
91 break;
92 }
93 }
94 }
95 }
96 }
97 Ok(format!(
98 "Search Results:\nPattern: {}\nMatches: {}\n\n{}",
99 params.pattern,
100 matches.len(),
101 matches.join("\n")
102 ))
103 }
104}
105
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ToolContext` whose working directory is the process's current directory.
    fn context_in_cwd() -> ToolContext {
        let cwd = std::env::current_dir().unwrap();
        let working_directory = cwd.to_str().unwrap().to_owned();
        ToolContext {
            working_directory,
            ..Default::default()
        }
    }

    #[test]
    fn test_get_tools() {
        let tools = SearchTool::get_tools();
        let names: Vec<&str> = tools.iter().map(|tool| tool.name.as_str()).collect();
        assert_eq!(names, ["search_code"]);
    }

    #[test]
    fn test_execute_unknown_tool() {
        let outcome = SearchTool::execute(
            "1",
            "unknown_tool",
            &json!({"pattern": "test"}),
            &context_in_cwd(),
        );
        assert!(outcome.is_error);
    }
}