Skip to main content

bamboo_tools/tools/
grep.rs

1use async_trait::async_trait;
2use bamboo_agent_core::{Tool, ToolError, ToolExecutionContext, ToolResult};
3use globset::{GlobBuilder, GlobSet};
4use regex::{Regex, RegexBuilder};
5use serde::Deserialize;
6use serde_json::json;
7use std::collections::{BTreeSet, HashMap};
8use std::path::{Path, PathBuf};
9use walkdir::WalkDir;
10
11use super::workspace_state;
12
/// Number of output entries returned when the caller does not supply `head_limit`.
const DEFAULT_HEAD_LIMIT: usize = 200;
/// Largest joined result, in bytes, that is returned; anything bigger is rejected.
const MAX_RESULT_BYTES: usize = 256 * 1024;
/// Largest total match count across all files before the search is rejected.
const MAX_MATCHES: usize = 2_000;
/// Maximum number of files gathered from a single directory walk.
const MAX_SCANNED_FILES: usize = 50_000;
/// Files larger than this many bytes are skipped instead of being read.
const MAX_FILE_BYTES: u64 = 2 * 1024 * 1024;
/// Directory names that are pruned from directory walks.
const SKIP_DIRS: [&str; 8] = [
    ".git",
    "node_modules",
    "target",
    "dist",
    "build",
    ".next",
    ".cache",
    "coverage",
];
/// Error text used when `content` output is requested without any path/glob/type hint.
const SEARCH_SCOPE_TOO_BROAD_ERROR: &str =
    "Search scope too broad. Add path/glob/type or reduce pattern.";
/// Error text used when a multiline search is not restricted to a narrowed path.
const MULTILINE_REQUIRES_NARROWED_PATH_ERROR: &str = "Multiline grep requires narrowed path.";
/// Error text used when the match count or the result size exceeds its limit.
const RESULT_TOO_LARGE_ERROR: &str = "Result too large; refine query and retry.";
32
/// Shape of the tool output, deserialized from the snake_case `output_mode` argument
/// (`"content"`, `"files_with_matches"`, `"count"`).
#[derive(Debug, Deserialize, Clone, Copy, Default)]
#[serde(rename_all = "snake_case")]
enum OutputMode {
    /// Emit the matching lines themselves (plus any requested context lines).
    Content,
    /// Emit one path per file that contains at least one match. This is the default.
    #[default]
    FilesWithMatches,
    /// Emit one `path:count` row per file that contains at least one match.
    Count,
}
41
/// Arguments accepted by the Grep tool, deserialized from the JSON tool call.
///
/// Flag-style fields keep their ripgrep-like JSON names (`-A`, `-B`, `-C`, `-n`, `-i`)
/// through `serde(rename)`. Every field except `pattern` is optional.
#[derive(Debug, Deserialize)]
struct GrepArgs {
    /// Regular expression to search for.
    pattern: String,
    /// File or directory to search. Relative paths are resolved against the workspace
    /// directory, which is also the default when this is absent.
    #[serde(default)]
    path: Option<String>,
    /// Glob pattern that candidate files must match.
    #[serde(default)]
    glob: Option<String>,
    /// Requested output shape; defaults to `OutputMode::FilesWithMatches`.
    #[serde(default)]
    output_mode: Option<OutputMode>,
    /// Number of context lines before each match; falls back to `-C` when absent.
    #[serde(rename = "-B", default)]
    before: Option<usize>,
    /// Number of context lines after each match; falls back to `-C` when absent.
    #[serde(rename = "-A", default)]
    after: Option<usize>,
    /// Number of context lines on both sides of each match; defaults to 0.
    #[serde(rename = "-C", default)]
    context: Option<usize>,
    /// Whether content rows carry a 1-based line number; defaults to `false`.
    #[serde(rename = "-n", default)]
    line_numbers: Option<bool>,
    /// Whether the pattern is matched case-insensitively; defaults to `false`.
    #[serde(rename = "-i", default)]
    case_insensitive: Option<bool>,
    /// Name of a file-type group from `GrepTool::extension_map` (for example `rust`).
    #[serde(default)]
    r#type: Option<String>,
    /// Maximum number of output entries; defaults to `DEFAULT_HEAD_LIMIT`.
    #[serde(default)]
    head_limit: Option<usize>,
    /// Whether the pattern may match across line boundaries; defaults to `false`.
    #[serde(default)]
    multiline: Option<bool>,
}
68
/// Stateless tool that searches file contents with a regular expression.
pub struct GrepTool;
70
71impl GrepTool {
72    pub fn new() -> Self {
73        Self
74    }
75
76    fn extension_map() -> HashMap<&'static str, &'static [&'static str]> {
77        HashMap::from([
78            ("js", &["js", "mjs", "cjs"] as &[_]),
79            ("ts", &["ts", "tsx"]),
80            ("py", &["py"]),
81            ("rust", &["rs"]),
82            ("go", &["go"]),
83            ("java", &["java"]),
84            ("cpp", &["cc", "cpp", "cxx", "hpp", "h"]),
85            ("c", &["c", "h"]),
86            ("json", &["json"]),
87            ("yaml", &["yaml", "yml"]),
88            ("toml", &["toml"]),
89            ("md", &["md", "markdown"]),
90        ])
91    }
92
93    fn collect_files(base: &Path, type_filter: Option<&str>) -> Vec<PathBuf> {
94        let ext_map = Self::extension_map();
95        let allowed_ext = type_filter.and_then(|name| ext_map.get(name).copied());
96
97        let mut files = Vec::new();
98        for entry in WalkDir::new(base)
99            .follow_links(false)
100            .into_iter()
101            .filter_entry(|entry| {
102                !entry.file_type().is_dir() || !Self::should_skip_dir(entry.path())
103            })
104            .filter_map(|entry| entry.ok())
105        {
106            if !entry.file_type().is_file() {
107                continue;
108            }
109            if files.len() >= MAX_SCANNED_FILES {
110                break;
111            }
112            let path = entry.path();
113
114            if let Some(extensions) = allowed_ext {
115                let ext = path
116                    .extension()
117                    .and_then(|v| v.to_str())
118                    .unwrap_or_default();
119                if !extensions.iter().any(|candidate| candidate == &ext) {
120                    continue;
121                }
122            }
123
124            files.push(path.to_path_buf());
125        }
126
127        files
128    }
129
130    fn should_skip_dir(path: &Path) -> bool {
131        path.file_name()
132            .and_then(|name| name.to_str())
133            .map(|name| SKIP_DIRS.contains(&name))
134            .unwrap_or(false)
135    }
136
137    fn compile_glob(glob: Option<&str>) -> Result<Option<GlobSet>, ToolError> {
138        let Some(pattern) = glob else {
139            return Ok(None);
140        };
141
142        let mut builder = globset::GlobSetBuilder::new();
143        let glob = GlobBuilder::new(pattern)
144            .literal_separator(false)
145            .build()
146            .map_err(|e| ToolError::InvalidArguments(format!("Invalid glob pattern: {}", e)))?;
147        builder.add(glob);
148        builder
149            .build()
150            .map(Some)
151            .map_err(|e| ToolError::Execution(format!("Failed to compile glob: {}", e)))
152    }
153
154    fn compile_regex(
155        pattern: &str,
156        case_insensitive: bool,
157        multiline: bool,
158    ) -> Result<Regex, ToolError> {
159        let mut builder = RegexBuilder::new(pattern);
160        builder.case_insensitive(case_insensitive);
161        builder.dot_matches_new_line(multiline);
162        builder.multi_line(multiline);
163        builder
164            .build()
165            .map_err(|e| ToolError::InvalidArguments(format!("Invalid regex pattern: {}", e)))
166    }
167
168    fn byte_to_line(line_starts: &[usize], byte: usize) -> usize {
169        match line_starts.binary_search(&byte) {
170            Ok(idx) => idx,
171            Err(idx) => idx.saturating_sub(1),
172        }
173    }
174
175    fn format_content_hits(
176        path: &Path,
177        content: &str,
178        regex: &Regex,
179        multiline: bool,
180        before: usize,
181        after: usize,
182        line_numbers: bool,
183    ) -> Vec<String> {
184        let lines: Vec<&str> = content.lines().collect();
185        if lines.is_empty() {
186            return Vec::new();
187        }
188
189        let mut selected_lines = BTreeSet::new();
190
191        if multiline {
192            let mut line_starts = vec![0usize];
193            for (idx, byte) in content.bytes().enumerate() {
194                if byte == b'\n' {
195                    line_starts.push(idx + 1);
196                }
197            }
198
199            for mat in regex.find_iter(content) {
200                let start_line = Self::byte_to_line(&line_starts, mat.start());
201                let end_line = Self::byte_to_line(&line_starts, mat.end().saturating_sub(1));
202                let range_start = start_line.saturating_sub(before);
203                let range_end = (end_line + after).min(lines.len().saturating_sub(1));
204                for line_idx in range_start..=range_end {
205                    selected_lines.insert(line_idx);
206                }
207            }
208        } else {
209            for (idx, line) in lines.iter().enumerate() {
210                if regex.is_match(line) {
211                    let range_start = idx.saturating_sub(before);
212                    let range_end = (idx + after).min(lines.len().saturating_sub(1));
213                    for line_idx in range_start..=range_end {
214                        selected_lines.insert(line_idx);
215                    }
216                }
217            }
218        }
219
220        selected_lines
221            .into_iter()
222            .map(|idx| {
223                let display_path = bamboo_infrastructure::paths::path_to_display_string(path);
224                if line_numbers {
225                    format!("{}:{}:{}", display_path, idx + 1, lines[idx])
226                } else {
227                    format!("{}:{}", display_path, lines[idx])
228                }
229            })
230            .collect()
231    }
232
233    fn resolve_search_root(path: Option<&str>, cwd: &Path) -> PathBuf {
234        match path {
235            Some(path) => {
236                let candidate = PathBuf::from(path);
237                if candidate.is_absolute() {
238                    candidate
239                } else {
240                    cwd.join(candidate)
241                }
242            }
243            None => cwd.to_path_buf(),
244        }
245    }
246
247    fn validate_scope(
248        args: &GrepArgs,
249        output_mode: OutputMode,
250        multiline: bool,
251        cwd: &Path,
252    ) -> Result<(), ToolError> {
253        if matches!(output_mode, OutputMode::Content)
254            && args.path.is_none()
255            && args.glob.is_none()
256            && args.r#type.is_none()
257        {
258            return Err(ToolError::InvalidArguments(
259                SEARCH_SCOPE_TOO_BROAD_ERROR.to_string(),
260            ));
261        }
262
263        if multiline {
264            let Some(path) = args.path.as_deref() else {
265                return Err(ToolError::InvalidArguments(
266                    MULTILINE_REQUIRES_NARROWED_PATH_ERROR.to_string(),
267                ));
268            };
269
270            let resolved = Self::resolve_search_root(Some(path), cwd);
271            if resolved.is_dir() {
272                if let (Ok(resolved_canonical), Ok(cwd_canonical)) =
273                    (resolved.canonicalize(), cwd.canonicalize())
274                {
275                    if resolved_canonical == cwd_canonical {
276                        return Err(ToolError::InvalidArguments(
277                            MULTILINE_REQUIRES_NARROWED_PATH_ERROR.to_string(),
278                        ));
279                    }
280                }
281            }
282        }
283
284        Ok(())
285    }
286}
287
288impl Default for GrepTool {
289    fn default() -> Self {
290        Self::new()
291    }
292}
293
294#[async_trait]
295impl Tool for GrepTool {
    /// Name under which the tool is exposed: `"Grep"`.
    fn name(&self) -> &str {
        "Grep"
    }
299
    /// Human-readable summary of the tool, including guidance to start with
    /// `files_with_matches` or a narrowed scope before using content or multiline mode.
    fn description(&self) -> &str {
        "Search file contents using ripgrep-style regex parameters. Start with files_with_matches or a narrowed path/glob/type before using content or multiline mode."
    }
303
    /// Declares the tool as read-only.
    fn mutability(&self) -> crate::ToolMutability {
        crate::ToolMutability::ReadOnly
    }
307
    /// Reports the tool as concurrency-safe.
    // NOTE(review): presumably this lets the host run Grep alongside other tool calls —
    // confirm against the `Tool` trait documentation.
    fn concurrency_safe(&self) -> bool {
        true
    }
311
    /// JSON Schema describing the accepted arguments.
    ///
    /// The property names mirror the serde names of `GrepArgs`; only `pattern` is required
    /// and unknown properties are disallowed by the schema.
    fn parameters_schema(&self) -> serde_json::Value {
        json!({
            "type": "object",
            "properties": {
                "pattern": { "type": "string", "description": "Regex pattern" },
                "path": { "type": "string", "description": "File or directory to search. Narrow this for expensive or multiline searches." },
                "glob": { "type": "string", "description": "Glob file filter used to limit candidate files" },
                "output_mode": {
                    "type": "string",
                    "enum": ["content", "files_with_matches", "count"],
                    "description": "Output mode. Prefer files_with_matches for broad discovery, then refine with Read or content mode."
                },
                "-B": { "type": "number", "description": "Lines before match" },
                "-A": { "type": "number", "description": "Lines after match" },
                "-C": { "type": "number", "description": "Lines before and after match" },
                "-n": { "type": "boolean", "description": "Show line numbers" },
                "-i": { "type": "boolean", "description": "Case insensitive" },
                "type": { "type": "string", "description": "File type filter (for example rust, js, ts, py)" },
                "head_limit": { "type": "number", "description": "Limit output entries. Keep this small for broad queries." },
                "multiline": { "type": "boolean", "description": "Enable multiline regex. Requires a narrowed path." }
            },
            "required": ["pattern"],
            "additionalProperties": false
        })
    }
337
338    async fn execute(&self, args: serde_json::Value) -> Result<ToolResult, ToolError> {
339        self.execute_with_context(args, ToolExecutionContext::none("Grep"))
340            .await
341    }
342
343    async fn execute_with_context(
344        &self,
345        args: serde_json::Value,
346        ctx: ToolExecutionContext<'_>,
347    ) -> Result<ToolResult, ToolError> {
348        let parsed: GrepArgs = serde_json::from_value(args)
349            .map_err(|e| ToolError::InvalidArguments(format!("Invalid Grep args: {}", e)))?;
350
351        let cwd = workspace_state::workspace_or_process_cwd(ctx.session_id);
352        let root = Self::resolve_search_root(parsed.path.as_deref(), &cwd);
353
354        let output_mode = parsed.output_mode.unwrap_or_default();
355        let context = parsed.context.unwrap_or(0);
356        let before = parsed.before.unwrap_or(context);
357        let after = parsed.after.unwrap_or(context);
358        let line_numbers = parsed.line_numbers.unwrap_or(false);
359        let case_insensitive = parsed.case_insensitive.unwrap_or(false);
360        let multiline = parsed.multiline.unwrap_or(false);
361        let head_limit = parsed.head_limit.unwrap_or(DEFAULT_HEAD_LIMIT);
362
363        Self::validate_scope(&parsed, output_mode, multiline, &cwd)?;
364
365        let regex = Self::compile_regex(&parsed.pattern, case_insensitive, multiline)?;
366        let glob_filter = Self::compile_glob(parsed.glob.as_deref())?;
367
368        let files = if root.is_file() {
369            vec![root.clone()]
370        } else if root.is_dir() {
371            Self::collect_files(&root, parsed.r#type.as_deref())
372        } else {
373            return Err(ToolError::Execution(format!(
374                "Path does not exist: {}",
375                root.display()
376            )));
377        };
378
379        let mut matched_files = Vec::new();
380        let mut count_rows = Vec::new();
381        let mut content_rows = Vec::new();
382        let mut total_matches = 0usize;
383        let mut partial = false;
384
385        for file in files {
386            if let Some(filter) = &glob_filter {
387                let relative = file.strip_prefix(&root).unwrap_or(&file);
388                if !filter.is_match(relative) && !filter.is_match(&file) {
389                    continue;
390                }
391            }
392
393            let Ok(metadata) = tokio::fs::metadata(&file).await else {
394                continue;
395            };
396            if metadata.len() > MAX_FILE_BYTES {
397                continue;
398            }
399
400            let Ok(content) = tokio::fs::read_to_string(&file).await else {
401                continue;
402            };
403
404            if content.contains('\0') {
405                continue;
406            }
407
408            let match_count = if multiline {
409                regex.find_iter(&content).count()
410            } else {
411                content.lines().filter(|line| regex.is_match(line)).count()
412            };
413
414            if match_count == 0 {
415                continue;
416            }
417
418            total_matches = total_matches.saturating_add(match_count);
419            if total_matches > MAX_MATCHES {
420                return Err(ToolError::Execution(RESULT_TOO_LARGE_ERROR.to_string()));
421            }
422
423            matched_files.push(bamboo_infrastructure::paths::path_to_display_string(&file));
424            count_rows.push(format!(
425                "{}:{}",
426                bamboo_infrastructure::paths::path_to_display_string(&file),
427                match_count
428            ));
429
430            if matches!(output_mode, OutputMode::Content) {
431                content_rows.extend(Self::format_content_hits(
432                    &file,
433                    &content,
434                    &regex,
435                    multiline,
436                    before,
437                    after,
438                    line_numbers,
439                ));
440                if content_rows.len() >= head_limit {
441                    content_rows.truncate(head_limit);
442                    partial = true;
443                    break;
444                }
445            }
446
447            if matches!(
448                output_mode,
449                OutputMode::FilesWithMatches | OutputMode::Count
450            ) && matched_files.len() >= head_limit
451            {
452                partial = true;
453                break;
454            }
455        }
456
457        let mut result_lines = match output_mode {
458            OutputMode::FilesWithMatches => matched_files,
459            OutputMode::Count => count_rows,
460            OutputMode::Content => content_rows,
461        };
462
463        if result_lines.len() > head_limit {
464            result_lines.truncate(head_limit);
465            partial = true;
466        }
467        if partial {
468            result_lines
469                .push("[PARTIAL] Output was truncated. Narrow path/pattern and retry.".to_string());
470        }
471
472        let result = result_lines.join("\n");
473        if result.len() > MAX_RESULT_BYTES {
474            return Err(ToolError::Execution(RESULT_TOO_LARGE_ERROR.to_string()));
475        }
476
477        Ok(ToolResult {
478            success: true,
479            result,
480            display_preference: Some("Collapsible".to_string()),
481        })
482    }
483}
484
485#[cfg(test)]
486mod tests {
487    use super::*;
488    use serde_json::json;
489
490    fn result_lines(result: &ToolResult) -> Vec<&str> {
491        result
492            .result
493            .lines()
494            .filter(|line| !line.is_empty())
495            .collect()
496    }
497
498    fn non_partial_lines(result: &ToolResult) -> Vec<&str> {
499        result_lines(result)
500            .into_iter()
501            .filter(|line| !line.starts_with("[PARTIAL]"))
502            .collect()
503    }
504
    /// Without an `output_mode`, only the path of the file containing the match is listed.
    #[tokio::test]
    async fn grep_defaults_to_files_with_matches() {
        let dir = tempfile::tempdir().unwrap();
        let file_hit = dir.path().join("match.rs");
        let file_miss = dir.path().join("miss.txt");
        tokio::fs::write(&file_hit, "let value = 1;\nneedle\n")
            .await
            .unwrap();
        tokio::fs::write(&file_miss, "nothing to see\n")
            .await
            .unwrap();

        let tool = GrepTool::new();
        let result = tool
            .execute(json!({
                "pattern": "needle",
                "path": dir.path()
            }))
            .await
            .unwrap();

        assert!(result.success);
        let lines = result_lines(&result);
        assert_eq!(lines.len(), 1);
        assert!(lines[0].contains("match.rs"));
    }
531
    /// Content mode with `-C 1` and `-n` prints the match plus one numbered line on each
    /// side, and nothing further away.
    #[tokio::test]
    async fn grep_content_mode_supports_context_and_line_numbers() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("content.txt");
        tokio::fs::write(&file, "one\ntwo\nneedle\nfour\nfive\n")
            .await
            .unwrap();

        let tool = GrepTool::new();
        let result = tool
            .execute(json!({
                "pattern": "needle",
                "path": file,
                "output_mode": "content",
                "-C": 1,
                "-n": true
            }))
            .await
            .unwrap();

        let output = result.result;
        assert!(output.contains(":2:two"));
        assert!(output.contains(":3:needle"));
        assert!(output.contains(":4:four"));
        assert!(!output.contains(":1:one"));
        assert!(!output.contains(":5:five"));
    }
559
    /// Count mode with `type: rust` and `head_limit: 1` yields a single `.rs` count row
    /// plus the `[PARTIAL]` marker, and never the `.txt` file.
    #[tokio::test]
    async fn grep_count_mode_respects_type_filter_and_head_limit() {
        let dir = tempfile::tempdir().unwrap();
        let file_rs_a = dir.path().join("a.rs");
        let file_rs_b = dir.path().join("b.rs");
        let file_txt = dir.path().join("c.txt");
        tokio::fs::write(&file_rs_a, "foo\nfoo\n").await.unwrap();
        tokio::fs::write(&file_rs_b, "foo\n").await.unwrap();
        tokio::fs::write(&file_txt, "foo\n").await.unwrap();

        let tool = GrepTool::new();
        let result = tool
            .execute(json!({
                "pattern": "foo",
                "path": dir.path(),
                "output_mode": "count",
                "type": "rust",
                "head_limit": 1
            }))
            .await
            .unwrap();

        let lines = non_partial_lines(&result);
        assert_eq!(lines.len(), 1);
        assert!(lines[0].contains(".rs:"));
        assert!(!lines[0].contains("c.txt"));
        assert!(result.result.contains("[PARTIAL]"));
    }
588
    /// A case-insensitive multiline pattern matches across the line break in `one.rs`,
    /// while the glob keeps `two.rs` out of the result.
    #[tokio::test]
    async fn grep_multiline_and_case_insensitive_work_with_glob_filter() {
        let dir = tempfile::tempdir().unwrap();
        let file_one = dir.path().join("one.rs");
        let file_two = dir.path().join("two.rs");
        tokio::fs::write(&file_one, "Hello\nWORLD\n").await.unwrap();
        tokio::fs::write(&file_two, "Hello\nplanet\n")
            .await
            .unwrap();

        let tool = GrepTool::new();
        let result = tool
            .execute(json!({
                "pattern": "hello\\s+world",
                "path": dir.path(),
                "glob": "**/one.rs",
                "-i": true,
                "multiline": true
            }))
            .await
            .unwrap();

        let output = result.result;
        assert!(output.contains("one.rs"));
        assert!(!output.contains("two.rs"));
    }
615
    /// Content mode without any of path/glob/type is rejected as too broad.
    #[tokio::test]
    async fn grep_content_mode_requires_scope_hint() {
        let tool = GrepTool::new();
        let error = tool
            .execute(json!({
                "pattern": "needle",
                "output_mode": "content"
            }))
            .await
            .expect_err("content mode without scope should fail");

        assert!(matches!(error, ToolError::InvalidArguments(_)));
        assert!(error.to_string().contains(SEARCH_SCOPE_TOO_BROAD_ERROR));
    }
630
    /// Multiline mode is rejected both when no path is given and when the path is the
    /// current working directory itself.
    #[tokio::test]
    async fn grep_multiline_requires_explicit_narrowed_path() {
        let tool = GrepTool::new();
        // Case 1: no path at all.
        let error = tool
            .execute(json!({
                "pattern": "a\\s+b",
                "multiline": true
            }))
            .await
            .expect_err("multiline without path should fail");
        assert!(matches!(error, ToolError::InvalidArguments(_)));
        assert!(error
            .to_string()
            .contains(MULTILINE_REQUIRES_NARROWED_PATH_ERROR));

        // Case 2: an explicit path that is the working directory.
        let cwd = std::env::current_dir().unwrap();
        let error = tool
            .execute(json!({
                "pattern": "a\\s+b",
                "multiline": true,
                "path": cwd
            }))
            .await
            .expect_err("multiline at workspace root should fail");
        assert!(matches!(error, ToolError::InvalidArguments(_)));
        assert!(error
            .to_string()
            .contains(MULTILINE_REQUIRES_NARROWED_PATH_ERROR));
    }
660
    /// With 260 matching files and no `head_limit`, exactly 200 paths are returned and the
    /// output carries the `[PARTIAL]` marker.
    #[tokio::test]
    async fn grep_defaults_head_limit_to_200() {
        let dir = tempfile::tempdir().unwrap();
        for idx in 0..260 {
            let file = dir.path().join(format!("file-{idx}.txt"));
            tokio::fs::write(&file, "needle\n").await.unwrap();
        }

        let tool = GrepTool::new();
        let result = tool
            .execute(json!({
                "pattern": "needle",
                "path": dir.path()
            }))
            .await
            .unwrap();

        let lines = non_partial_lines(&result);
        assert_eq!(lines.len(), 200);
        assert!(result.result.contains("[PARTIAL]"));
    }
682
    /// A single file with more than `MAX_MATCHES` matching lines makes the search fail
    /// with the "result too large" execution error.
    #[tokio::test]
    async fn grep_rejects_excessive_match_volume() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("huge.txt");
        let mut content = String::new();
        for _ in 0..(MAX_MATCHES + 1) {
            content.push_str("needle\n");
        }
        tokio::fs::write(&file, content).await.unwrap();

        let tool = GrepTool::new();
        let error = tool
            .execute(json!({
                "pattern": "needle",
                "path": file
            }))
            .await
            .expect_err("should reject oversized results");

        assert!(matches!(error, ToolError::Execution(_)));
        assert!(error.to_string().contains(RESULT_TOO_LARGE_ERROR));
    }
705}