1use std::fmt::Write;
4
5use crate::{BuiltinTool, ToolContext, ToolError, ToolResult};
6use regex::Regex;
7use serde_json::Value;
8use std::path::{Path, PathBuf};
9use walkdir::WalkDir;
10
/// Maximum number of files with at least one match that a directory search reports
/// before it stops and appends a truncation notice.
const MAX_MATCHING_FILES: usize = 100;

/// Built-in `grep` tool: regex search over file contents with optional glob
/// filtering and context lines.
#[derive(Debug, Clone, Copy, Default)]
pub struct GrepTool;
16
17#[async_trait::async_trait]
18impl BuiltinTool for GrepTool {
19 fn name(&self) -> &'static str {
20 "grep"
21 }
22
23 fn description(&self) -> &'static str {
24 "Searches file contents using regex. Supports context lines and file type filtering. \
25 Returns matching lines in file:line:content format."
26 }
27
28 fn input_schema(&self) -> Value {
29 serde_json::json!({
30 "type": "object",
31 "properties": {
32 "pattern": {
33 "type": "string",
34 "description": "Regex pattern to search for"
35 },
36 "path": {
37 "type": "string",
38 "description": "File or directory to search in (defaults to workspace root)"
39 },
40 "glob": {
41 "type": "string",
42 "description": "Glob to filter files (e.g. \"*.rs\", \"*.{ts,tsx}\")"
43 },
44 "context": {
45 "type": "integer",
46 "description": "Number of context lines to show before and after matches"
47 },
48 "case_insensitive": {
49 "type": "boolean",
50 "description": "Case insensitive search (default: false)"
51 }
52 },
53 "required": ["pattern"]
54 })
55 }
56
57 #[allow(clippy::too_many_lines)]
58 async fn execute(&self, args: Value, ctx: &ToolContext) -> ToolResult {
59 let pattern_str = args
60 .get("pattern")
61 .and_then(Value::as_str)
62 .ok_or_else(|| ToolError::InvalidArguments("pattern is required".into()))?;
63
64 let case_insensitive = args
65 .get("case_insensitive")
66 .and_then(Value::as_bool)
67 .unwrap_or(false);
68
69 let regex_pattern = if case_insensitive {
70 format!("(?i){pattern_str}")
71 } else {
72 pattern_str.to_string()
73 };
74
75 let regex = Regex::new(®ex_pattern)
76 .map_err(|e| ToolError::InvalidArguments(format!("Invalid regex: {e}")))?;
77
78 let search_path = args
79 .get("path")
80 .and_then(Value::as_str)
81 .map_or_else(|| ctx.workspace_root.clone(), PathBuf::from);
82
83 if !search_path.exists() {
84 return Err(ToolError::PathNotFound(search_path.display().to_string()));
85 }
86
87 let search_path = search_path.canonicalize()?;
89
90 let context_lines = args
91 .get("context")
92 .and_then(Value::as_u64)
93 .map_or(0, |v| usize::try_from(v).unwrap_or(0));
94
95 let file_glob = args
96 .get("glob")
97 .and_then(Value::as_str)
98 .map(|g| {
99 globset::GlobBuilder::new(g)
100 .literal_separator(false)
101 .build()
102 .map(|gb| gb.compile_matcher())
103 })
104 .transpose()
105 .map_err(|e| ToolError::InvalidArguments(format!("Invalid file glob: {e}")))?;
106
107 if search_path.is_file() {
109 return search_file(&search_path, ®ex, context_lines);
110 }
111
112 let mut output = String::new();
114 let mut match_count: usize = 0;
115 let mut file_count: usize = 0;
116
117 for entry in WalkDir::new(&search_path)
118 .follow_links(false)
119 .into_iter()
120 .filter_entry(|e| {
121 if e.depth() == 0 {
123 return true;
124 }
125 e.file_name().to_str().is_none_or(|s| !s.starts_with('.'))
126 })
127 {
128 let Ok(entry) = entry else { continue };
129
130 if !entry.file_type().is_file() {
131 continue;
132 }
133
134 if let Some(ref glob) = file_glob {
136 let rel = entry
137 .path()
138 .strip_prefix(&search_path)
139 .unwrap_or(entry.path());
140 let file_name = entry.file_name().to_string_lossy();
141 if !glob.is_match(rel) && !glob.is_match(file_name.as_ref()) {
142 continue;
143 }
144 }
145
146 if let Ok(data) = std::fs::read(entry.path()) {
148 let check_len = data.len().min(512);
149 if data[..check_len].contains(&0) {
150 continue;
151 }
152 }
153
154 let Ok(content) = std::fs::read_to_string(entry.path()) else {
155 continue;
156 };
157
158 let lines: Vec<&str> = content.lines().collect();
159 let mut file_has_match = false;
160
161 for (idx, line) in lines.iter().enumerate() {
162 if regex.is_match(line) {
163 if !file_has_match {
164 file_has_match = true;
165 file_count = file_count.saturating_add(1);
166 if file_count > MAX_MATCHING_FILES {
167 let _ = write!(
168 output,
169 "\n(stopped after {MAX_MATCHING_FILES} files with matches)"
170 );
171 return Ok(output);
172 }
173 }
174
175 match_count = match_count.saturating_add(1);
176 write_context_lines(&mut output, entry.path(), &lines, idx, context_lines);
177 }
178 }
179 }
180
181 if match_count == 0 {
182 return Ok(format!("No matches for \"{pattern_str}\" found"));
183 }
184
185 let _ = write!(output, "\n({match_count} matches in {file_count} files)");
186 Ok(output)
187 }
188}
189
/// Appends one match, plus up to `context` lines on each side, to `output`.
///
/// Lines use 1-based numbers in grep style. The matching line is written as
/// `path:N:text` and surrounding context lines as `path:N-text`. Context is clipped
/// at the start and end of `lines`.
///
/// # Panics
///
/// Panics if `idx` is not a valid index into `lines`.
fn write_context_lines(
    output: &mut String,
    path: &Path,
    lines: &[&str],
    idx: usize,
    context: usize,
) {
    let shown = path.display();
    let hit = lines[idx];

    let first = idx.saturating_sub(context);
    let last = idx.saturating_add(1).saturating_add(context).min(lines.len());

    for (line_idx, text) in (first..last).zip(&lines[first..last]) {
        // `line_idx < lines.len()`, so this never actually saturates.
        let line_no = line_idx.saturating_add(1);
        if line_idx == idx {
            let _ = writeln!(output, "{shown}:{line_no}:{hit}");
        } else {
            let _ = writeln!(output, "{shown}:{line_no}-{text}");
        }
    }
}
229
230fn search_file(path: &Path, regex: &Regex, context_lines: usize) -> ToolResult {
232 let content = std::fs::read_to_string(path)?;
233 let lines: Vec<&str> = content.lines().collect();
234 let mut output = String::new();
235 let mut match_count: usize = 0;
236
237 for (idx, line) in lines.iter().enumerate() {
238 if regex.is_match(line) {
239 match_count = match_count.saturating_add(1);
240 write_context_lines(&mut output, path, &lines, idx, context_lines);
241 }
242 }
243
244 if match_count == 0 {
245 return Ok(format!("No matches found in {}", path.display()));
246 }
247
248 let _ = write!(output, "\n({match_count} matches)");
249 Ok(output)
250}
251
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write as IoWrite;
    use tempfile::{NamedTempFile, TempDir};

    /// Builds a `ToolContext` whose workspace root is `root`.
    fn ctx_with_root(root: &std::path::Path) -> ToolContext {
        ToolContext::new(root.to_path_buf())
    }

    #[tokio::test]
    async fn test_grep_basic() {
        let dir = TempDir::new().unwrap();
        std::fs::write(dir.path().join("a.rs"), "fn main() {}\nfn test() {}\n").unwrap();
        std::fs::write(dir.path().join("b.rs"), "fn helper() {}\n").unwrap();

        let ctx = ctx_with_root(dir.path());
        let result = GrepTool
            .execute(serde_json::json!({"pattern": "fn main"}), &ctx)
            .await
            .unwrap();

        assert!(result.contains("fn main"));
        assert!(result.contains("1 matches"));
    }

    #[tokio::test]
    async fn test_grep_with_glob_filter() {
        let dir = TempDir::new().unwrap();
        std::fs::write(dir.path().join("a.rs"), "fn main() {}\n").unwrap();
        std::fs::write(dir.path().join("b.txt"), "fn main() {}\n").unwrap();

        let ctx = ctx_with_root(dir.path());
        let result = GrepTool
            .execute(
                serde_json::json!({"pattern": "fn main", "glob": "*.rs"}),
                &ctx,
            )
            .await
            .unwrap();

        assert!(result.contains("a.rs"));
        assert!(!result.contains("b.txt"));
    }

    #[tokio::test]
    async fn test_grep_glob_matches_nested_files() {
        let dir = TempDir::new().unwrap();
        std::fs::create_dir(dir.path().join("src")).unwrap();
        std::fs::write(dir.path().join("src").join("lib.rs"), "needle\n").unwrap();
        std::fs::write(dir.path().join("notes.md"), "needle\n").unwrap();

        let ctx = ctx_with_root(dir.path());
        let result = GrepTool
            .execute(
                serde_json::json!({"pattern": "needle", "glob": "*.rs"}),
                &ctx,
            )
            .await
            .unwrap();

        assert!(result.contains("lib.rs"));
        assert!(!result.contains("notes.md"));
    }

    #[tokio::test]
    async fn test_grep_invalid_glob() {
        let dir = TempDir::new().unwrap();
        let ctx = ctx_with_root(dir.path());
        let result = GrepTool
            .execute(
                serde_json::json!({"pattern": "x", "glob": "[bad"}),
                &ctx,
            )
            .await;

        assert!(matches!(result, Err(ToolError::InvalidArguments(_))));
    }

    #[tokio::test]
    async fn test_grep_case_insensitive() {
        let dir = TempDir::new().unwrap();
        std::fs::write(dir.path().join("test.txt"), "Hello World\nhello world\n").unwrap();

        let ctx = ctx_with_root(dir.path());
        let result = GrepTool
            .execute(
                serde_json::json!({"pattern": "hello", "case_insensitive": true}),
                &ctx,
            )
            .await
            .unwrap();

        assert!(result.contains("Hello World"));
        assert!(result.contains("hello world"));
    }

    #[tokio::test]
    async fn test_grep_no_matches() {
        let dir = TempDir::new().unwrap();
        std::fs::write(dir.path().join("test.txt"), "hello world\n").unwrap();

        let ctx = ctx_with_root(dir.path());
        let result = GrepTool
            .execute(serde_json::json!({"pattern": "foobar"}), &ctx)
            .await
            .unwrap();

        assert!(result.contains("No matches"));
    }

    #[tokio::test]
    async fn test_grep_skips_binary_files() {
        let dir = TempDir::new().unwrap();
        std::fs::write(dir.path().join("bin.dat"), b"needle\0\x01\x02").unwrap();
        std::fs::write(dir.path().join("text.txt"), "needle\n").unwrap();

        let ctx = ctx_with_root(dir.path());
        let result = GrepTool
            .execute(serde_json::json!({"pattern": "needle"}), &ctx)
            .await
            .unwrap();

        assert!(result.contains("text.txt"));
        assert!(!result.contains("bin.dat"));
        assert!(result.contains("1 matches in 1 files"));
    }

    #[tokio::test]
    async fn test_grep_skips_hidden_entries() {
        let dir = TempDir::new().unwrap();
        std::fs::create_dir(dir.path().join(".git")).unwrap();
        std::fs::write(dir.path().join(".git").join("config"), "needle\n").unwrap();
        std::fs::write(dir.path().join(".env"), "needle\n").unwrap();
        std::fs::write(dir.path().join("visible.txt"), "needle\n").unwrap();

        let ctx = ctx_with_root(dir.path());
        let result = GrepTool
            .execute(serde_json::json!({"pattern": "needle"}), &ctx)
            .await
            .unwrap();

        assert!(result.contains("visible.txt"));
        assert!(!result.contains(".git"));
        assert!(!result.contains(".env"));
    }

    #[tokio::test]
    async fn test_grep_stops_after_max_matching_files() {
        let dir = TempDir::new().unwrap();
        for i in 0..=MAX_MATCHING_FILES {
            std::fs::write(dir.path().join(format!("f{i}.txt")), "needle\n").unwrap();
        }

        let ctx = ctx_with_root(dir.path());
        let result = GrepTool
            .execute(serde_json::json!({"pattern": "needle"}), &ctx)
            .await
            .unwrap();

        assert!(result.contains(&format!(
            "stopped after {MAX_MATCHING_FILES} files with matches"
        )));
    }

    #[tokio::test]
    async fn test_grep_context_lines() {
        let mut f = NamedTempFile::new().unwrap();
        writeln!(f, "line 1").unwrap();
        writeln!(f, "line 2").unwrap();
        writeln!(f, "MATCH").unwrap();
        writeln!(f, "line 4").unwrap();
        writeln!(f, "line 5").unwrap();

        let ctx = ctx_with_root(&std::env::temp_dir());
        let result = GrepTool
            .execute(
                serde_json::json!({
                    "pattern": "MATCH",
                    "path": f.path().to_str().unwrap(),
                    "context": 1
                }),
                &ctx,
            )
            .await
            .unwrap();

        assert!(result.contains("line 2"));
        assert!(result.contains("MATCH"));
        assert!(result.contains("line 4"));
    }

    #[tokio::test]
    async fn test_grep_context_clipped_at_file_start() {
        let mut f = NamedTempFile::new().unwrap();
        writeln!(f, "MATCH").unwrap();
        writeln!(f, "line 2").unwrap();

        let ctx = ctx_with_root(&std::env::temp_dir());
        let result = GrepTool
            .execute(
                serde_json::json!({
                    "pattern": "MATCH",
                    "path": f.path().to_str().unwrap(),
                    "context": 5
                }),
                &ctx,
            )
            .await
            .unwrap();

        assert!(result.contains(":1:MATCH"));
        assert!(result.contains(":2-line 2"));
        assert!(result.contains("(1 matches)"));
    }

    #[tokio::test]
    async fn test_grep_single_file() {
        let mut f = NamedTempFile::new().unwrap();
        writeln!(f, "hello").unwrap();
        writeln!(f, "world").unwrap();

        let ctx = ctx_with_root(&std::env::temp_dir());
        let result = GrepTool
            .execute(
                serde_json::json!({
                    "pattern": "hello",
                    "path": f.path().to_str().unwrap()
                }),
                &ctx,
            )
            .await
            .unwrap();

        assert!(result.contains("hello"));
        assert!(result.contains("1 matches"));
    }

    #[tokio::test]
    async fn test_grep_single_file_no_matches() {
        let mut f = NamedTempFile::new().unwrap();
        writeln!(f, "hello").unwrap();

        let ctx = ctx_with_root(&std::env::temp_dir());
        let result = GrepTool
            .execute(
                serde_json::json!({
                    "pattern": "absent",
                    "path": f.path().to_str().unwrap()
                }),
                &ctx,
            )
            .await
            .unwrap();

        assert!(result.contains("No matches found in"));
    }

    #[tokio::test]
    async fn test_grep_invalid_regex() {
        let ctx = ctx_with_root(&std::env::temp_dir());
        let result = GrepTool
            .execute(serde_json::json!({"pattern": "[invalid"}), &ctx)
            .await;

        assert!(matches!(result, Err(ToolError::InvalidArguments(_))));
    }

    #[tokio::test]
    async fn test_grep_missing_pattern() {
        let ctx = ctx_with_root(&std::env::temp_dir());
        let result = GrepTool.execute(serde_json::json!({}), &ctx).await;

        assert!(matches!(result, Err(ToolError::InvalidArguments(_))));
    }

    #[tokio::test]
    async fn test_grep_missing_path() {
        let dir = TempDir::new().unwrap();
        let missing = dir.path().join("does-not-exist");

        let ctx = ctx_with_root(dir.path());
        let result = GrepTool
            .execute(
                serde_json::json!({
                    "pattern": "x",
                    "path": missing.to_str().unwrap()
                }),
                &ctx,
            )
            .await;

        assert!(matches!(result, Err(ToolError::PathNotFound(_))));
    }
}