Skip to main content

codetether_agent/tool/
codesearch.rs

1//! Code Search Tool - Search code in the workspace using ripgrep-style patterns.
2
3use super::{Tool, ToolResult};
4use anyhow::{Context, Result};
5use async_trait::async_trait;
6use serde::Deserialize;
7use serde_json::{Value, json};
8use std::path::PathBuf;
9use walkdir::WalkDir;
10
/// Upper bound on the number of matches included in a single tool response.
const MAX_RESULTS: usize = 50;
/// Largest number of context lines shown on each side of a match, whatever the request asks for.
const MAX_CONTEXT_LINES: usize = 3;
13
/// Tool that searches text files beneath a root directory for lines matching a regex.
pub struct CodeSearchTool {
    /// Directory searches are resolved against; also the prefix stripped from reported paths.
    root: PathBuf,
}
17
18impl Default for CodeSearchTool {
19    fn default() -> Self {
20        Self::new()
21    }
22}
23
24#[allow(dead_code)]
25impl CodeSearchTool {
26    pub fn new() -> Self {
27        Self {
28            root: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
29        }
30    }
31
32    pub fn with_root(root: PathBuf) -> Self {
33        Self { root }
34    }
35
36    fn should_skip(&self, path: &std::path::Path) -> bool {
37        let skip_dirs = [
38            ".git",
39            "node_modules",
40            "target",
41            "dist",
42            ".next",
43            "__pycache__",
44            ".venv",
45            "vendor",
46        ];
47        path.components()
48            .any(|c| skip_dirs.contains(&c.as_os_str().to_str().unwrap_or("")))
49    }
50
51    fn is_text_file(&self, path: &std::path::Path) -> bool {
52        let text_exts = [
53            "rs", "ts", "js", "tsx", "jsx", "py", "go", "java", "c", "cpp", "h", "hpp", "md",
54            "txt", "json", "yaml", "yml", "toml", "sh", "bash", "zsh", "html", "css", "scss",
55        ];
56        path.extension()
57            .and_then(|e| e.to_str())
58            .map(|e| text_exts.contains(&e))
59            .unwrap_or(false)
60    }
61
62    fn search_file(
63        &self,
64        path: &std::path::Path,
65        pattern: &regex::Regex,
66        context: usize,
67    ) -> Result<Vec<Match>> {
68        let content = std::fs::read_to_string(path)?;
69        let lines: Vec<&str> = content.lines().collect();
70        let mut matches = Vec::new();
71
72        for (idx, line) in lines.iter().enumerate() {
73            if pattern.is_match(line) {
74                let start = idx.saturating_sub(context);
75                let end = (idx + context + 1).min(lines.len());
76                let context_lines: Vec<String> = lines[start..end]
77                    .iter()
78                    .enumerate()
79                    .map(|(i, l)| {
80                        let line_num = start + i + 1;
81                        let marker = if start + i == idx { ">" } else { " " };
82                        format!("{} {:4}: {}", marker, line_num, l)
83                    })
84                    .collect();
85
86                matches.push(Match {
87                    path: path
88                        .strip_prefix(&self.root)
89                        .unwrap_or(path)
90                        .to_string_lossy()
91                        .to_string(),
92                    line: idx + 1,
93                    matched_line: line.to_string(),
94                    context: context_lines.join("\n"),
95                });
96            }
97        }
98        Ok(matches)
99    }
100}
101
/// A single matching line found in a searched file.
#[derive(Debug)]
struct Match {
    /// Path of the file containing the match, as displayed to the caller.
    path: String,
    /// 1-based line number of the matching line.
    line: usize,
    /// Full text of the matching line.
    matched_line: String,
    /// Matching line plus its surrounding lines, already formatted and newline-joined.
    context: String,
}
109
110impl std::fmt::Display for Match {
111    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
112        write!(f, "{}:{}: {}", self.path, self.line, self.matched_line)
113    }
114}
115
116#[derive(Deserialize)]
117struct Params {
118    pattern: String,
119    #[serde(default)]
120    path: Option<String>,
121    #[serde(default)]
122    file_pattern: Option<String>,
123    #[serde(default = "default_context")]
124    context_lines: usize,
125    #[serde(default)]
126    case_sensitive: bool,
127}
128
/// Number of context lines used when a request does not specify `context_lines`.
fn default_context() -> usize {
    const DEFAULT_CONTEXT_LINES: usize = 2;
    DEFAULT_CONTEXT_LINES
}
132
133#[async_trait]
134impl Tool for CodeSearchTool {
135    fn id(&self) -> &str {
136        "codesearch"
137    }
138    fn name(&self) -> &str {
139        "Code Search"
140    }
141    fn description(&self) -> &str {
142        "Search for code patterns in the workspace. Supports regex."
143    }
144    fn parameters(&self) -> Value {
145        json!({
146            "type": "object",
147            "properties": {
148                "pattern": {"type": "string", "description": "Search pattern (regex supported)"},
149                "path": {"type": "string", "description": "Subdirectory to search in"},
150                "file_pattern": {"type": "string", "description": "Glob pattern for files (e.g., *.rs)"},
151                "context_lines": {"type": "integer", "default": 2, "description": "Lines of context"},
152                "case_sensitive": {"type": "boolean", "default": false}
153            },
154            "required": ["pattern"]
155        })
156    }
157
158    async fn execute(&self, params: Value) -> Result<ToolResult> {
159        let p: Params = serde_json::from_value(params).context("Invalid params")?;
160
161        let regex = regex::RegexBuilder::new(&p.pattern)
162            .case_insensitive(!p.case_sensitive)
163            .build()
164            .context("Invalid regex pattern")?;
165
166        let search_root = match &p.path {
167            Some(subpath) => self.root.join(subpath),
168            None => self.root.clone(),
169        };
170
171        let file_glob = p
172            .file_pattern
173            .as_ref()
174            .and_then(|pat| glob::Pattern::new(pat).ok());
175
176        let mut all_matches = Vec::new();
177
178        for entry in WalkDir::new(&search_root)
179            .into_iter()
180            .filter_map(|e| e.ok())
181        {
182            let path = entry.path();
183            if !path.is_file() || self.should_skip(path) || !self.is_text_file(path) {
184                continue;
185            }
186
187            if let Some(ref glob) = file_glob
188                && !glob.matches_path(path)
189            {
190                continue;
191            }
192
193            if let Ok(matches) =
194                self.search_file(path, &regex, p.context_lines.min(MAX_CONTEXT_LINES))
195            {
196                all_matches.extend(matches);
197                if all_matches.len() >= MAX_RESULTS {
198                    break;
199                }
200            }
201        }
202
203        if all_matches.is_empty() {
204            return Ok(ToolResult::success(format!(
205                "No matches found for pattern: {}",
206                p.pattern
207            )));
208        }
209
210        let output = all_matches
211            .iter()
212            .take(MAX_RESULTS)
213            .map(|m| format!("{}:{}\n{}", m.path, m.line, m.context))
214            .collect::<Vec<_>>()
215            .join("\n\n");
216
217        let truncated = all_matches.len() > MAX_RESULTS;
218        let msg = if truncated {
219            format!(
220                "Found {} matches (showing first {}):\n\n{}",
221                all_matches.len(),
222                MAX_RESULTS,
223                output
224            )
225        } else {
226            format!("Found {} matches:\n\n{}", all_matches.len(), output)
227        };
228
229        Ok(ToolResult::success(msg).with_metadata("match_count", json!(all_matches.len())))
230    }
231}