// atomcode_core/tool/grep.rs

1use anyhow::Result;
2use async_trait::async_trait;
3use ignore::WalkBuilder;
4use regex::RegexBuilder;
5use serde::Deserialize;
6use serde_json::json;
7
8use super::{ApprovalRequirement, Tool, ToolContext, ToolDef, ToolResult};
9
/// Built-in `grep` tool: searches file contents for a pattern and returns the
/// matching lines with surrounding context.
///
/// The type is a field-less marker; all behavior lives in its `Tool`
/// implementation and in the private helpers on `GrepTool`.
#[derive(Debug, Clone, Copy, Default)]
pub struct GrepTool;
11
12#[derive(Deserialize)]
13struct GrepArgs {
14    pattern: String,
15    path: Option<String>,
16    #[serde(default = "default_max_results")]
17    max_results: usize,
18    #[serde(default = "default_context")]
19    context: usize,
20}
21
/// Serde default for `GrepArgs::context`: lines of context around each match
/// when the caller does not specify a value.
fn default_context() -> usize {
    const DEFAULT_CONTEXT_LINES: usize = 3;
    DEFAULT_CONTEXT_LINES
}
/// Serde default for `GrepArgs::max_results`: cap on collected matches when
/// the caller does not specify a value.
fn default_max_results() -> usize {
    const DEFAULT_MAX_RESULTS: usize = 50;
    DEFAULT_MAX_RESULTS
}
28
29#[async_trait]
30impl Tool for GrepTool {
31    fn definition(&self) -> ToolDef {
32        ToolDef {
33            name: "grep",
34            description: "Search file contents for a pattern. Returns matching lines with surrounding context.\n\
35                Usage:\n\
36                - Use this to find where a function, variable, string, or UI element is defined or used.\n\
37                - Use this BEFORE editing when the user's request is ambiguous — find ALL candidates first.\n\
38                - Pattern is regex by default (case-insensitive unless uppercase is used).\n\
39                - Escape special regex chars: . → \\\\. , ( → \\\\( , [ → \\\\[\n\
40                - If regex fails, the tool automatically retries with literal string matching.\n\
41                - NEVER use bash grep/rg — always use this tool.\n\
42                Examples:\n\
43                - Find a function: {\"pattern\": \"def process_data\"}\n\
44                - Find a string with dots: {\"pattern\": \"console\\\\.log\"}\n\
45                - Find across alternatives: {\"pattern\": \"upload|上传\"}\n\
46                - Search specific directory: {\"pattern\": \"import\", \"path\": \"src/views\"}".to_string(),
47            parameters: json!({
48                "type": "object",
49                "properties": {
50                    "pattern": { "type": "string", "description": "Search pattern (regex by default). Escape dots/parens: console\\.log\\(" },
51                    "path": { "type": "string", "description": "Directory or file to search (default: working directory)" },
52                    "max_results": { "type": "integer", "description": "Max results to return (default 50)" },
53                    "context": { "type": "integer", "description": "Lines of context around each match (default 3)" }
54                },
55                "required": ["pattern"]
56            }),
57        }
58    }
59
60    fn approval(&self, _args: &str) -> ApprovalRequirement {
61        ApprovalRequirement::AutoApprove
62    }
63
64    fn approval_with_context(&self, args: &str, ctx: &ToolContext) -> ApprovalRequirement {
65        let parsed = match serde_json::from_str::<GrepArgs>(args) {
66            Ok(parsed) => parsed,
67            Err(_) => return self.approval(args),
68        };
69        let working_dir = match ctx.working_dir.try_read() {
70            Ok(wd) => wd.clone(),
71            Err(_) => return self.approval(args),
72        };
73        let raw_path = parsed.path.as_deref().unwrap_or(".");
74        match super::approval_for_path(raw_path, &working_dir, super::ExternalPathAction::Read)
75        {
76            Ok(approval) => approval,
77            Err(_) => self.approval(args),
78        }
79    }
80
81    async fn execute(&self, args: &str, ctx: &ToolContext) -> Result<ToolResult> {
82        let parsed: GrepArgs = serde_json::from_str(args)?;
83        let path = parsed.path.as_deref().unwrap_or(".");
84
85        let wd = ctx.working_dir.read().await.clone();
86
87        // Graph enhancement: build a short header when pattern looks like an
88        // identifier and no specific path is given (project-wide search).
89        // The header is prepended to ripgrep results — it never replaces them.
90        let graph_header = if parsed.path.is_none() {
91            self.build_graph_header(&parsed.pattern, ctx, &wd).await
92        } else {
93            None
94        };
95
96        let max = parsed.max_results;
97        let context_lines = parsed.context.min(10);
98        let resolved = match super::inspect_path_access(path, &wd) {
99            Ok(access) => access.path,
100            Err(err) => {
101                return Ok(ToolResult {
102                    call_id: String::new(),
103                    output: err.to_string(),
104                    success: false,
105                });
106            }
107        };
108
109        if !resolved.exists() {
110            return Ok(ToolResult {
111                call_id: String::new(),
112                output: format!("Path not found: {}", resolved.display()),
113                success: false,
114            });
115        }
116
117        // Build regex (smart-case: case-insensitive if pattern has no uppercase)
118        let has_uppercase = parsed.pattern.chars().any(|c| c.is_uppercase());
119        let re = match RegexBuilder::new(&parsed.pattern)
120            .case_insensitive(!has_uppercase)
121            .build()
122        {
123            Ok(r) => r,
124            Err(_) => {
125                // Regex failed — try as literal
126                match RegexBuilder::new(&regex::escape(&parsed.pattern))
127                    .case_insensitive(!has_uppercase)
128                    .build()
129                {
130                    Ok(r) => r,
131                    Err(e) => {
132                        return Ok(ToolResult {
133                            call_id: String::new(),
134                            output: format!("Invalid pattern '{}': {}", parsed.pattern, e),
135                            success: false,
136                        });
137                    }
138                }
139            }
140        };
141
142        // Walk files using ignore crate (respects .gitignore, skips binary, multi-threaded)
143        let walker = WalkBuilder::new(&resolved)
144            .hidden(true) // skip hidden files
145            .git_ignore(true) // respect .gitignore
146            .git_global(true)
147            .git_exclude(true)
148            .build();
149
150        let mut matches: Vec<String> = Vec::new();
151        let mut files_searched = 0usize;
152        let mut match_count = 0usize;
153
154        for entry in walker {
155            let entry = match entry {
156                Ok(e) => e,
157                Err(_) => continue,
158            };
159
160            if !entry.file_type().map_or(false, |ft| ft.is_file()) {
161                continue;
162            }
163
164            let file_path = entry.path();
165
166            // Skip known noise directories/files not covered by .gitignore
167            let path_str = file_path.to_string_lossy();
168            if path_str.contains("/datalog/")
169                || path_str.ends_with(".log")
170                || path_str.contains("/target/")
171                || path_str.contains("/dist/")
172                || path_str.contains("/node_modules/")
173            {
174                continue;
175            }
176
177            // Read file (skip binary)
178            let content = match std::fs::read_to_string(file_path) {
179                Ok(c) => c,
180                Err(_) => continue, // binary or unreadable
181            };
182
183            files_searched += 1;
184            let lines: Vec<&str> = content.lines().collect();
185
186            // Find matching lines
187            let mut file_matches: Vec<usize> = Vec::new();
188            for (i, line) in lines.iter().enumerate() {
189                if re.is_match(line) {
190                    file_matches.push(i);
191                    if match_count + file_matches.len() >= max {
192                        break;
193                    }
194                }
195            }
196
197            if file_matches.is_empty() {
198                continue;
199            }
200
201            // Format matches with context
202            let rel_path = file_path
203                .strip_prefix(&wd)
204                .unwrap_or(file_path)
205                .to_string_lossy();
206
207            let mut shown: std::collections::HashSet<usize> = std::collections::HashSet::new();
208            for &match_line in &file_matches {
209                let start = match_line.saturating_sub(context_lines);
210                let end = (match_line + context_lines + 1).min(lines.len());
211
212                // Separator between non-contiguous chunks
213                if !shown.is_empty() && start > 0 && !shown.contains(&(start - 1)) {
214                    matches.push("--".to_string());
215                }
216
217                for i in start..end {
218                    if shown.contains(&i) {
219                        continue;
220                    }
221                    shown.insert(i);
222
223                    let prefix = if i == match_line {
224                        format!("{}:{}:", rel_path, i + 1)
225                    } else {
226                        format!("{}-{}-", rel_path, i + 1)
227                    };
228                    matches.push(format!("{}{}", prefix, lines[i]));
229                }
230            }
231
232            match_count += file_matches.len();
233            if match_count >= max {
234                break;
235            }
236        }
237
238        // Annotate matching lines with enclosing function name (tree-sitter)
239        let mut searcher = ctx.semantic.lock().await;
240        let mut annotated: Vec<String> = Vec::new();
241        let mut sym_cache: std::collections::HashMap<String, Vec<crate::semantic::Symbol>> =
242            std::collections::HashMap::new();
243
244        for line in &matches {
245            let parts: Vec<&str> = line.splitn(3, ':').collect();
246            if parts.len() >= 3 {
247                if let Ok(line_no) = parts[1].parse::<usize>() {
248                    let file = parts[0];
249                    let abs_file = if std::path::Path::new(file).is_absolute() {
250                        std::path::PathBuf::from(file)
251                    } else {
252                        wd.join(file)
253                    };
254                    let symbols = sym_cache
255                        .entry(file.to_string())
256                        .or_insert_with(|| searcher.list_symbols(&abs_file).unwrap_or_default());
257                    if let Some(sym) = symbols
258                        .iter()
259                        .find(|s| line_no >= s.start_line && line_no <= s.end_line)
260                    {
261                        annotated.push(format!("{}  ← in {}()", line, sym.name));
262                        continue;
263                    }
264                }
265            }
266            annotated.push(line.clone());
267        }
268        drop(searcher);
269
270        let mut output = String::new();
271
272        if let Some(header) = graph_header {
273            output.push_str(&header);
274            output.push('\n');
275        }
276
277        if annotated.is_empty() {
278            output.push_str(&format!(
279                "No matches found for '{}' in {}",
280                parsed.pattern, path
281            ));
282            output.push_str(&format!(" ({} files searched)", files_searched));
283        } else {
284            let total = annotated.len();
285            output.push_str(&annotated.join("\n"));
286            if total >= max {
287                output.push_str(&format!("\n\n[Results capped at {} matches]", max));
288            }
289        };
290
291        // success=true even on zero matches: empty result IS the answer
292        // (matches glob's "No files matching" semantics — same screenshot
293        // showed glob's empty-result line rendering normally while grep's
294        // identical-meaning line painted red ✗ as if the search failed).
295        // Real failures (bad path / bad regex) take the early-return
296        // branches above with their own success=false.
297        Ok(ToolResult {
298            call_id: String::new(),
299            output,
300            success: true,
301        })
302    }
303}
304
305impl GrepTool {
306    /// Build a short graph-based header for the grep output.
307    /// Returns None if graph is not ready, no identifier is extractable,
308    /// or no symbols match.
309    async fn build_graph_header(
310        &self,
311        pattern: &str,
312        ctx: &ToolContext,
313        wd: &std::path::Path,
314    ) -> Option<String> {
315        let query_word = extract_graph_candidates(pattern)?;
316        let graph = ctx.graph.read().await;
317        if !graph.is_ready() {
318            return None;
319        }
320
321        let symbols = graph.find_by_name(&query_word);
322        if symbols.is_empty() {
323            return None;
324        }
325
326        let mut out = format!(
327            "[Graph: {} definitions for '{}']\n",
328            symbols.len(),
329            query_word
330        );
331        for sym in symbols.iter().take(5) {
332            let rel = sym
333                .file
334                .strip_prefix(wd)
335                .unwrap_or(&sym.file)
336                .to_string_lossy();
337            out.push_str(&format!(
338                "  {} {:?} in {}:{}\n",
339                sym.name, sym.kind, rel, sym.start_line
340            ));
341        }
342
343        Some(out)
344    }
345}
346
/// Picks the word of a grep pattern that most looks like a code symbol, for
/// use as a graph lookup key.
///
/// The pattern is cut at every character that is not an ASCII letter, ASCII
/// digit or `_`. A resulting word qualifies only if it
/// - is at least four characters long,
/// - starts with an ASCII letter or `_`,
/// - is not (case-insensitively) one of the skipped keywords, and
/// - contains an underscore (snake_case) or a lowercase→uppercase transition
///   (CamelCase). A merely capitalized word does not count.
///
/// Of the qualifying words the longest is returned; on equal length the
/// earliest one wins.
///
/// Examples:
/// - "fetch_weather" → Some("fetch_weather")
/// - "SearchFilter"  → Some("SearchFilter")
/// - "NdarrayMixin"  → Some("NdarrayMixin")
/// - "weather|天气"  → None  (no identifier-like word)
/// - "Structured ndarray gets viewed" → None  (plain English)
/// - "error.*line" → None  (too generic)
/// - "console\\.log" → None  (too generic)
fn extract_graph_candidates(pattern: &str) -> Option<String> {
    const SKIP_KEYWORDS: &[&str] = &[
        "pub", "fn", "struct", "enum", "impl", "use", "let", "const", "async", "trait", "type",
        "mod", "crate", "self", "super", "for", "def", "class", "function", "var", "import",
        "from", "export", "return", "match", "where", "static", "mut", "ref", "true", "false",
        "none", "some", "null", "this", "not", "and", "the", "with",
    ];

    /// True when `word` (already restricted to `[A-Za-z0-9_]`) resembles a symbol.
    fn looks_like_symbol(word: &str) -> bool {
        let bytes = word.as_bytes();
        if bytes.len() < 4 {
            return false;
        }

        let starts_like_identifier =
            matches!(bytes.first(), Some(b) if b.is_ascii_alphabetic() || *b == b'_');
        if !starts_like_identifier {
            return false;
        }

        // The keyword list is all lowercase ASCII, so an ASCII case-insensitive
        // comparison is equivalent to lowercasing the word first.
        if SKIP_KEYWORDS.iter().any(|kw| kw.eq_ignore_ascii_case(word)) {
            return false;
        }

        // Require snake_case or a real lower→upper CamelCase transition
        // ("SearchFilter"), not just a capitalized word ("Structured").
        word.contains('_')
            || bytes
                .windows(2)
                .any(|pair| pair[0].is_ascii_lowercase() && pair[1].is_ascii_uppercase())
    }

    pattern
        .split(|c: char| !(c.is_ascii_alphanumeric() || c == '_'))
        .filter(|word| looks_like_symbol(word))
        // Strictly-greater comparison keeps the earliest word on a tie.
        .reduce(|longest, word| {
            if word.len() > longest.len() {
                word
            } else {
                longest
            }
        })
        .map(str::to_owned)
}
414
#[cfg(test)]
mod tests {
    use super::extract_graph_candidates;
    use super::GrepTool;
    use crate::tool::{ApprovalRequirement, Tool, ToolContext};
    use tempfile::TempDir;

    #[test]
    fn grep_outside_workspace_non_sensitive_requires_approval() {
        let workspace = TempDir::new().unwrap();
        let outside = TempDir::new().unwrap();
        let ctx = ToolContext::new(workspace.path().to_path_buf());
        // Build the JSON with a serializer instead of string interpolation so
        // that backslashes or quotes in the temp path (e.g. on Windows) cannot
        // produce invalid JSON — which would silently take the auto-approve
        // fallback instead of exercising the path check.
        let args = serde_json::json!({
            "pattern": "foo",
            "path": outside.path().to_string_lossy()
        })
        .to_string();
        assert!(matches!(
            GrepTool.approval_with_context(&args, &ctx),
            ApprovalRequirement::RequireApproval(_)
        ));
    }

    #[test]
    fn grep_sensitive_path_still_requires_always() {
        let workspace = TempDir::new().unwrap();
        let ctx = ToolContext::new(workspace.path().to_path_buf());
        // /etc is in the system-protected prefixes list.
        let args = r#"{"pattern":"PermitRoot","path":"/etc"}"#;
        assert!(matches!(
            GrepTool.approval_with_context(args, &ctx),
            ApprovalRequirement::RequireApprovalAlways(_)
        ));
    }

    // Regression: zero-match grep used to return success=false, which
    // made the TUI render the result row red ✗ as if the search itself
    // failed (screenshot 43.png had two such red rows next to genuine
    // path-not-found errors, making the agent's normal exploration look
    // like a wall of failures). "No matches" is a valid answer — it's
    // the *path* / *regex* errors that are real failures (those still
    // return success=false via the early-return branches). Pin the
    // semantics so future churn doesn't regress.
    #[tokio::test]
    async fn grep_zero_matches_reports_success_true() {
        let workspace = TempDir::new().unwrap();
        // Seed a file so the search has something to walk; the pattern
        // intentionally won't match.
        std::fs::write(
            workspace.path().join("a.rs"),
            "fn alpha() {}\nfn beta() {}\n",
        )
        .unwrap();
        let ctx = ToolContext::new(workspace.path().to_path_buf());
        let args = r#"{"pattern":"definitely_not_in_file_xyz"}"#;
        let result = GrepTool.execute(args, &ctx).await.unwrap();
        assert!(
            result.success,
            "zero-match grep must return success=true so TUI doesn't \
             paint it red. Output was: {}",
            result.output
        );
        assert!(
            result.output.contains("No matches found"),
            "zero-match output should explain what happened, got: {}",
            result.output
        );
    }

    // Real failure modes (bad regex / missing path) MUST still be
    // success=false — those genuinely indicate the model needs to
    // change inputs, not just learn that "no rows match".
    #[tokio::test]
    async fn grep_path_not_found_reports_success_false() {
        let workspace = TempDir::new().unwrap();
        let ctx = ToolContext::new(workspace.path().to_path_buf());
        let args = r#"{"pattern":"foo","path":"/nonexistent/path/xyz123"}"#;
        let result = GrepTool.execute(args, &ctx).await.unwrap();
        assert!(
            !result.success,
            "path-not-found must remain success=false — output: {}",
            result.output
        );
    }

    // Default path "." resolves to the workspace itself — must auto-approve.
    #[test]
    fn grep_default_path_auto_approves() {
        let workspace = TempDir::new().unwrap();
        let ctx = ToolContext::new(workspace.path().to_path_buf());
        let args = r#"{"pattern":"foo"}"#;
        assert!(matches!(
            GrepTool.approval_with_context(args, &ctx),
            ApprovalRequirement::AutoApprove
        ));
    }

    #[test]
    fn snake_case_identifier() {
        assert_eq!(
            extract_graph_candidates("fetch_weather"),
            Some("fetch_weather".into())
        );
    }

    #[test]
    fn camel_case_identifier() {
        assert_eq!(
            extract_graph_candidates("SearchFilter"),
            Some("SearchFilter".into())
        );
        assert_eq!(
            extract_graph_candidates("NdarrayMixin"),
            Some("NdarrayMixin".into())
        );
    }

    #[test]
    fn mixed_cjk_ascii_no_identifier() {
        assert_eq!(extract_graph_candidates("weather|天气"), None);
    }

    #[test]
    fn plain_english_rejected() {
        assert_eq!(
            extract_graph_candidates("Structured ndarray gets viewed"),
            None
        );
        assert_eq!(extract_graph_candidates("error"), None);
        assert_eq!(extract_graph_candidates("table"), None);
        assert_eq!(extract_graph_candidates("search"), None);
        assert_eq!(extract_graph_candidates("console"), None);
    }

    #[test]
    fn regex_pattern_rejected() {
        assert_eq!(extract_graph_candidates("error.*line"), None);
        assert_eq!(extract_graph_candidates("console\\.log"), None);
    }

    #[test]
    fn keywords_rejected() {
        assert_eq!(extract_graph_candidates("pub struct"), None);
        assert_eq!(extract_graph_candidates("import from"), None);
    }

    #[test]
    fn keyword_then_identifier() {
        assert_eq!(
            extract_graph_candidates("pub struct QueryIntent"),
            Some("QueryIntent".into())
        );
        assert_eq!(
            extract_graph_candidates("def process_data"),
            Some("process_data".into())
        );
    }

    #[test]
    fn pure_chinese_rejected() {
        assert_eq!(extract_graph_candidates("(科技|财经|体育)"), None);
    }

    #[test]
    fn short_words_rejected() {
        assert_eq!(extract_graph_candidates("fn"), None);
        assert_eq!(extract_graph_candidates("FnX"), None); // 3 chars, below threshold of 4
    }

    #[test]
    fn or_pattern_picks_best_identifier() {
        assert_eq!(
            extract_graph_candidates("SearchFilter|from_intent"),
            Some("SearchFilter".into()),
        );
    }

    // The leading "not" is skipped (too short, and a keyword), while
    // "data_is_mixin" contains an underscore and is therefore returned.
    #[test]
    fn not_keyword_skipped_identifier_kept() {
        assert_eq!(
            extract_graph_candidates("not data_is_mixin"),
            Some("data_is_mixin".into())
        );
    }
}