// astrid_tools/grep.rs

1//! Grep tool — searches file contents with regex.
2
3use std::fmt::Write;
4
5use crate::{BuiltinTool, ToolContext, ToolError, ToolResult};
6use regex::Regex;
7use serde_json::Value;
8use std::path::{Path, PathBuf};
9use walkdir::WalkDir;
10
/// Cap on how many distinct files with matches a directory search reports.
///
/// When a match turns up in one more file than this, the directory walk stops
/// and a "(stopped after … files with matches)" notice is appended to the
/// output instead of any further results.
const MAX_MATCHING_FILES: usize = 100;
13
/// Built-in tool for searching file contents with a regular expression.
///
/// The type carries no state; all behavior lives in its `BuiltinTool`
/// implementation. The derives make it printable in diagnostics and trivially
/// constructible/copyable.
#[derive(Debug, Clone, Copy, Default)]
pub struct GrepTool;
16
17#[async_trait::async_trait]
18impl BuiltinTool for GrepTool {
19    fn name(&self) -> &'static str {
20        "grep"
21    }
22
23    fn description(&self) -> &'static str {
24        "Searches file contents using regex. Supports context lines and file type filtering. \
25         Returns matching lines in file:line:content format."
26    }
27
28    fn input_schema(&self) -> Value {
29        serde_json::json!({
30            "type": "object",
31            "properties": {
32                "pattern": {
33                    "type": "string",
34                    "description": "Regex pattern to search for"
35                },
36                "path": {
37                    "type": "string",
38                    "description": "File or directory to search in (defaults to workspace root)"
39                },
40                "glob": {
41                    "type": "string",
42                    "description": "Glob to filter files (e.g. \"*.rs\", \"*.{ts,tsx}\")"
43                },
44                "context": {
45                    "type": "integer",
46                    "description": "Number of context lines to show before and after matches"
47                },
48                "case_insensitive": {
49                    "type": "boolean",
50                    "description": "Case insensitive search (default: false)"
51                }
52            },
53            "required": ["pattern"]
54        })
55    }
56
57    #[allow(clippy::too_many_lines)]
58    async fn execute(&self, args: Value, ctx: &ToolContext) -> ToolResult {
59        let pattern_str = args
60            .get("pattern")
61            .and_then(Value::as_str)
62            .ok_or_else(|| ToolError::InvalidArguments("pattern is required".into()))?;
63
64        let case_insensitive = args
65            .get("case_insensitive")
66            .and_then(Value::as_bool)
67            .unwrap_or(false);
68
69        let regex_pattern = if case_insensitive {
70            format!("(?i){pattern_str}")
71        } else {
72            pattern_str.to_string()
73        };
74
75        let regex = Regex::new(&regex_pattern)
76            .map_err(|e| ToolError::InvalidArguments(format!("Invalid regex: {e}")))?;
77
78        let search_path = args
79            .get("path")
80            .and_then(Value::as_str)
81            .map_or_else(|| ctx.workspace_root.clone(), PathBuf::from);
82
83        if !search_path.exists() {
84            return Err(ToolError::PathNotFound(search_path.display().to_string()));
85        }
86
87        // Canonicalize to handle symlinks (e.g. /var -> /private/var on macOS)
88        let search_path = search_path.canonicalize()?;
89
90        let context_lines = args
91            .get("context")
92            .and_then(Value::as_u64)
93            .map_or(0, |v| usize::try_from(v).unwrap_or(0));
94
95        let file_glob = args
96            .get("glob")
97            .and_then(Value::as_str)
98            .map(|g| {
99                globset::GlobBuilder::new(g)
100                    .literal_separator(false)
101                    .build()
102                    .map(|gb| gb.compile_matcher())
103            })
104            .transpose()
105            .map_err(|e| ToolError::InvalidArguments(format!("Invalid file glob: {e}")))?;
106
107        // If search_path is a file, just search that file
108        if search_path.is_file() {
109            return search_file(&search_path, &regex, context_lines);
110        }
111
112        // Walk directory
113        let mut output = String::new();
114        let mut match_count: usize = 0;
115        let mut file_count: usize = 0;
116
117        for entry in WalkDir::new(&search_path)
118            .follow_links(false)
119            .into_iter()
120            .filter_entry(|e| {
121                // Skip hidden directories (but not the root entry)
122                if e.depth() == 0 {
123                    return true;
124                }
125                e.file_name().to_str().is_none_or(|s| !s.starts_with('.'))
126            })
127        {
128            let Ok(entry) = entry else { continue };
129
130            if !entry.file_type().is_file() {
131                continue;
132            }
133
134            // Apply file glob filter
135            if let Some(ref glob) = file_glob {
136                let rel = entry
137                    .path()
138                    .strip_prefix(&search_path)
139                    .unwrap_or(entry.path());
140                let file_name = entry.file_name().to_string_lossy();
141                if !glob.is_match(rel) && !glob.is_match(file_name.as_ref()) {
142                    continue;
143                }
144            }
145
146            // Skip binary files (check first 512 bytes)
147            if let Ok(data) = std::fs::read(entry.path()) {
148                let check_len = data.len().min(512);
149                if data[..check_len].contains(&0) {
150                    continue;
151                }
152            }
153
154            let Ok(content) = std::fs::read_to_string(entry.path()) else {
155                continue;
156            };
157
158            let lines: Vec<&str> = content.lines().collect();
159            let mut file_has_match = false;
160
161            for (idx, line) in lines.iter().enumerate() {
162                if regex.is_match(line) {
163                    if !file_has_match {
164                        file_has_match = true;
165                        file_count = file_count.saturating_add(1);
166                        if file_count > MAX_MATCHING_FILES {
167                            let _ = write!(
168                                output,
169                                "\n(stopped after {MAX_MATCHING_FILES} files with matches)"
170                            );
171                            return Ok(output);
172                        }
173                    }
174
175                    match_count = match_count.saturating_add(1);
176                    write_context_lines(&mut output, entry.path(), &lines, idx, context_lines);
177                }
178            }
179        }
180
181        if match_count == 0 {
182            return Ok(format!("No matches for \"{pattern_str}\" found"));
183        }
184
185        let _ = write!(output, "\n({match_count} matches in {file_count} files)");
186        Ok(output)
187    }
188}
189
/// Append the match at `lines[idx]`, plus up to `context` lines on either
/// side of it, to `output`.
///
/// Every row is `path:line_number<sep>content` followed by a newline, where
/// line numbers are 1-based and `<sep>` is `:` for the matching line and `-`
/// for context lines. The context window is clipped to the bounds of `lines`.
///
/// If `idx` is not a valid index into `lines`, nothing is written.
fn write_context_lines(
    output: &mut String,
    path: &Path,
    lines: &[&str],
    idx: usize,
    context: usize,
) {
    // Guard instead of indexing so a bad index cannot panic.
    if idx >= lines.len() {
        return;
    }

    // Half-open window [first, last) around the match; saturating arithmetic
    // keeps very large `context` values from overflowing.
    let first = idx.saturating_sub(context);
    let last = idx
        .saturating_add(1)
        .saturating_add(context)
        .min(lines.len());

    let shown_path = path.display();
    for (i, line) in lines.iter().enumerate().take(last).skip(first) {
        let line_num = i.saturating_add(1);
        let sep = if i == idx { ':' } else { '-' };
        // Writing into a `String` cannot fail, so the result is discarded.
        let _ = writeln!(output, "{shown_path}:{line_num}{sep}{line}");
    }
}
229
230/// Search a single file for matches.
231fn search_file(path: &Path, regex: &Regex, context_lines: usize) -> ToolResult {
232    let content = std::fs::read_to_string(path)?;
233    let lines: Vec<&str> = content.lines().collect();
234    let mut output = String::new();
235    let mut match_count: usize = 0;
236
237    for (idx, line) in lines.iter().enumerate() {
238        if regex.is_match(line) {
239            match_count = match_count.saturating_add(1);
240            write_context_lines(&mut output, path, &lines, idx, context_lines);
241        }
242    }
243
244    if match_count == 0 {
245        return Ok(format!("No matches found in {}", path.display()));
246    }
247
248    let _ = write!(output, "\n({match_count} matches)");
249    Ok(output)
250}
251
/// Integration-style tests that drive `GrepTool::execute` against temporary
/// files and directories.
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write as IoWrite;
    use tempfile::{NamedTempFile, TempDir};

    /// Build a tool context whose workspace root is `root`.
    fn ctx_with_root(root: &std::path::Path) -> ToolContext {
        ToolContext::new(root.to_path_buf())
    }

    #[tokio::test]
    async fn test_grep_basic() {
        let dir = TempDir::new().unwrap();
        std::fs::write(dir.path().join("a.rs"), "fn main() {}\nfn test() {}\n").unwrap();
        std::fs::write(dir.path().join("b.rs"), "fn helper() {}\n").unwrap();

        let ctx = ctx_with_root(dir.path());
        let result = GrepTool
            .execute(serde_json::json!({"pattern": "fn main"}), &ctx)
            .await
            .unwrap();

        assert!(result.contains("fn main"));
        assert!(result.contains("1 matches"));
    }

    #[tokio::test]
    async fn test_grep_with_glob_filter() {
        let dir = TempDir::new().unwrap();
        std::fs::write(dir.path().join("a.rs"), "fn main() {}\n").unwrap();
        std::fs::write(dir.path().join("b.txt"), "fn main() {}\n").unwrap();

        let ctx = ctx_with_root(dir.path());
        let result = GrepTool
            .execute(
                serde_json::json!({"pattern": "fn main", "glob": "*.rs"}),
                &ctx,
            )
            .await
            .unwrap();

        assert!(result.contains("a.rs"));
        assert!(!result.contains("b.txt"));
    }

    #[tokio::test]
    async fn test_grep_case_insensitive() {
        let dir = TempDir::new().unwrap();
        std::fs::write(dir.path().join("test.txt"), "Hello World\nhello world\n").unwrap();

        let ctx = ctx_with_root(dir.path());
        let result = GrepTool
            .execute(
                serde_json::json!({"pattern": "hello", "case_insensitive": true}),
                &ctx,
            )
            .await
            .unwrap();

        assert!(result.contains("Hello World"));
        assert!(result.contains("hello world"));
    }

    #[tokio::test]
    async fn test_grep_no_matches() {
        let dir = TempDir::new().unwrap();
        std::fs::write(dir.path().join("test.txt"), "hello world\n").unwrap();

        let ctx = ctx_with_root(dir.path());
        let result = GrepTool
            .execute(serde_json::json!({"pattern": "foobar"}), &ctx)
            .await
            .unwrap();

        assert!(result.contains("No matches"));
    }

    #[tokio::test]
    async fn test_grep_context_lines() {
        let mut f = NamedTempFile::new().unwrap();
        writeln!(f, "line 1").unwrap();
        writeln!(f, "line 2").unwrap();
        writeln!(f, "MATCH").unwrap();
        writeln!(f, "line 4").unwrap();
        writeln!(f, "line 5").unwrap();

        let ctx = ctx_with_root(&std::env::temp_dir());
        let result = GrepTool
            .execute(
                serde_json::json!({
                    "pattern": "MATCH",
                    "path": f.path().to_str().unwrap(),
                    "context": 1
                }),
                &ctx,
            )
            .await
            .unwrap();

        assert!(result.contains("line 2"));
        assert!(result.contains("MATCH"));
        assert!(result.contains("line 4"));
        // Lines outside the one-line context window must not be reported.
        assert!(!result.contains("line 1"));
        assert!(!result.contains("line 5"));
    }

    #[tokio::test]
    async fn test_grep_single_file() {
        let mut f = NamedTempFile::new().unwrap();
        writeln!(f, "hello").unwrap();
        writeln!(f, "world").unwrap();

        let ctx = ctx_with_root(&std::env::temp_dir());
        let result = GrepTool
            .execute(
                serde_json::json!({
                    "pattern": "hello",
                    "path": f.path().to_str().unwrap()
                }),
                &ctx,
            )
            .await
            .unwrap();

        assert!(result.contains("hello"));
        assert!(result.contains("1 matches"));
    }

    #[tokio::test]
    async fn test_grep_invalid_regex() {
        let ctx = ctx_with_root(&std::env::temp_dir());
        let result = GrepTool
            .execute(serde_json::json!({"pattern": "[invalid"}), &ctx)
            .await;

        // Pin the error variant, not just the fact that it failed.
        assert!(matches!(result, Err(ToolError::InvalidArguments(_))));
    }

    #[tokio::test]
    async fn test_grep_missing_pattern() {
        let ctx = ctx_with_root(&std::env::temp_dir());
        let result = GrepTool.execute(serde_json::json!({}), &ctx).await;

        assert!(matches!(result, Err(ToolError::InvalidArguments(_))));
    }

    #[tokio::test]
    async fn test_grep_missing_path() {
        let dir = TempDir::new().unwrap();
        let missing = dir.path().join("does_not_exist");

        let ctx = ctx_with_root(dir.path());
        let result = GrepTool
            .execute(
                serde_json::json!({
                    "pattern": "anything",
                    "path": missing.to_str().unwrap()
                }),
                &ctx,
            )
            .await;

        assert!(matches!(result, Err(ToolError::PathNotFound(_))));
    }

    #[tokio::test]
    async fn test_grep_skips_binary_files() {
        let dir = TempDir::new().unwrap();
        std::fs::write(dir.path().join("a.rs"), "fn main() {}\n").unwrap();
        // A NUL byte near the start marks the file as binary.
        std::fs::write(dir.path().join("bin.dat"), b"fn main\0binary").unwrap();

        let ctx = ctx_with_root(dir.path());
        let result = GrepTool
            .execute(serde_json::json!({"pattern": "fn main"}), &ctx)
            .await
            .unwrap();

        assert!(result.contains("a.rs"));
        assert!(!result.contains("bin.dat"));
        assert!(result.contains("1 matches in 1 files"));
    }

    #[tokio::test]
    async fn test_grep_skips_hidden_entries() {
        let dir = TempDir::new().unwrap();
        std::fs::write(dir.path().join("visible.rs"), "fn main() {}\n").unwrap();
        std::fs::write(dir.path().join(".dotfile.rs"), "fn main() {}\n").unwrap();
        std::fs::create_dir(dir.path().join(".hidden")).unwrap();
        std::fs::write(dir.path().join(".hidden").join("inner.rs"), "fn main() {}\n").unwrap();

        let ctx = ctx_with_root(dir.path());
        let result = GrepTool
            .execute(serde_json::json!({"pattern": "fn main"}), &ctx)
            .await
            .unwrap();

        assert!(result.contains("visible.rs"));
        assert!(!result.contains(".dotfile.rs"));
        assert!(!result.contains("inner.rs"));
    }
}