Skip to main content

brainwires_tool_system/
search.rs

1use anyhow::Result;
2use ignore::WalkBuilder;
3use regex::Regex;
4use serde::Deserialize;
5use serde_json::{Value, json};
6use std::collections::HashMap;
7use std::fs;
8
9use brainwires_core::{Tool, ToolContext, ToolInputSchema, ToolResult};
10
/// Tool provider that searches source files for lines matching a regular expression.
///
/// This is a field-less marker type; it carries no state of its own.
pub struct SearchTool;
13
14impl SearchTool {
15    /// Return tool definitions for code search.
16    pub fn get_tools() -> Vec<Tool> {
17        vec![Self::search_code_tool()]
18    }
19
20    fn search_code_tool() -> Tool {
21        let mut properties = HashMap::new();
22        properties.insert(
23            "pattern".to_string(),
24            json!({"type": "string", "description": "Regex pattern to search for"}),
25        );
26        properties.insert(
27            "path".to_string(),
28            json!({"type": "string", "description": "Path to search in", "default": "."}),
29        );
30        Tool {
31            name: "search_code".to_string(),
32            description: "Search for code patterns in files using regex.".to_string(),
33            input_schema: ToolInputSchema::object(properties, vec!["pattern".to_string()]),
34            requires_approval: false,
35            ..Default::default()
36        }
37    }
38
39    /// Execute a search tool by name.
40    #[tracing::instrument(name = "tool.execute", skip(input, context), fields(tool_name))]
41    pub fn execute(
42        tool_use_id: &str,
43        tool_name: &str,
44        input: &Value,
45        context: &ToolContext,
46    ) -> ToolResult {
47        let result = match tool_name {
48            "search_code" => Self::search_code(input, context),
49            _ => Err(anyhow::anyhow!("Unknown search tool: {}", tool_name)),
50        };
51        match result {
52            Ok(output) => ToolResult::success(tool_use_id.to_string(), output),
53            Err(e) => ToolResult::error(tool_use_id.to_string(), format!("Search failed: {}", e)),
54        }
55    }
56
57    fn search_code(input: &Value, context: &ToolContext) -> Result<String> {
58        #[derive(Deserialize)]
59        struct Input {
60            pattern: String,
61            #[serde(default = "default_path")]
62            path: String,
63        }
64        fn default_path() -> String {
65            ".".to_string()
66        }
67
68        let params: Input = serde_json::from_value(input.clone())?;
69        let regex = Regex::new(&params.pattern)?;
70        let search_path = if params.path == "." {
71            &context.working_directory
72        } else {
73            &params.path
74        };
75
76        let mut matches = Vec::new();
77        for entry in WalkBuilder::new(search_path).build() {
78            let entry = entry?;
79            if entry.path().is_file()
80                && let Ok(content) = fs::read_to_string(entry.path())
81            {
82                for (line_num, line) in content.lines().enumerate() {
83                    if regex.is_match(line) {
84                        matches.push(format!(
85                            "{}:{} - {}",
86                            entry.path().display(),
87                            line_num + 1,
88                            line.trim()
89                        ));
90                        if matches.len() >= 100 {
91                            break;
92                        }
93                    }
94                }
95            }
96        }
97        Ok(format!(
98            "Search Results:\nPattern: {}\nMatches: {}\n\n{}",
99            params.pattern,
100            matches.len(),
101            matches.join("\n")
102        ))
103    }
104}
105
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `ToolContext` whose working directory is the process's current directory.
    fn create_test_context() -> ToolContext {
        ToolContext {
            // `to_string_lossy` avoids a panic if the CWD is not valid UTF-8.
            working_directory: std::env::current_dir()
                .expect("current directory should be accessible in tests")
                .to_string_lossy()
                .into_owned(),
            ..Default::default()
        }
    }

    #[test]
    fn test_get_tools() {
        let tools = SearchTool::get_tools();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "search_code");
        assert!(!tools[0].requires_approval);
    }

    #[test]
    fn test_execute_unknown_tool() {
        let context = create_test_context();
        let input = json!({"pattern": "test"});
        let result = SearchTool::execute("1", "unknown_tool", &input, &context);
        assert!(result.is_error);
    }

    #[test]
    fn test_invalid_regex_is_error() {
        // An unbalanced group fails to compile before any directory walking happens.
        let context = create_test_context();
        let input = json!({"pattern": "("});
        let result = SearchTool::execute("1", "search_code", &input, &context);
        assert!(result.is_error);
    }

    #[test]
    fn test_missing_pattern_is_error() {
        // `pattern` is required; deserialization fails before any directory walking.
        let context = create_test_context();
        let input = json!({"path": "."});
        let result = SearchTool::execute("1", "search_code", &input, &context);
        assert!(result.is_error);
    }
}