Skip to main content

oxi_agent/tools/
grep.rs

1use super::path_security::PathGuard;
2/// Grep tool - search files for patterns
3use super::{AgentTool, AgentToolResult, ToolContext, ToolError};
4use async_trait::async_trait;
5use regex::RegexBuilder;
6use serde_json::{json, Value};
7use std::path::{Path, PathBuf};
8use tokio::fs;
9use tokio::sync::oneshot;
10
/// Maximum length of a single line in grep output, measured in bytes of UTF-8.
///
/// For ASCII text this is the same as the number of characters.
const GREP_MAX_LINE_LENGTH: usize = 500;

/// Truncate a single line to at most `GREP_MAX_LINE_LENGTH` bytes.
///
/// Returns the (possibly shortened) text and a flag telling whether truncation
/// happened. Truncated text gets a `"... [truncated]"` suffix.
///
/// The cut point is moved back to the nearest UTF-8 character boundary, so a
/// multi-byte character straddling the limit is dropped as a whole instead of
/// causing a slicing panic.
fn truncate_line(line: &str) -> (String, bool) {
    if line.len() <= GREP_MAX_LINE_LENGTH {
        return (line.to_string(), false);
    }

    // Index 0 is always a character boundary, so this loop terminates.
    let mut cut = GREP_MAX_LINE_LENGTH;
    while !line.is_char_boundary(cut) {
        cut -= 1;
    }

    (format!("{}... [truncated]", &line[..cut]), true)
}
25
/// Tool that searches files for a pattern and reports matching lines.
pub struct GrepTool {
    /// Explicit search root set through `with_cwd`; when `None`, `execute`
    /// falls back to `ctx.root()`.
    root_dir: Option<PathBuf>,
}
30
31impl GrepTool {
32    /// Create with no explicit root (uses ToolContext.workspace_dir at runtime).
33    pub fn new() -> Self {
34        Self { root_dir: None }
35    }
36
37    /// Create with a specific working directory (overrides ToolContext).
38    pub fn with_cwd(cwd: PathBuf) -> Self {
39        Self {
40            root_dir: Some(cwd),
41        }
42    }
43
44    /// Check if a filename matches a simple glob pattern like "*.rs", "*.ts"
45    fn matches_glob(file_name: &str, pattern: &str) -> bool {
46        if let Some(ext) = pattern.strip_prefix("*.") {
47            file_name.ends_with(ext)
48        } else if pattern.contains('*') {
49            // Simple wildcard matching
50            let parts: Vec<&str> = pattern.split('*').collect();
51            if parts.len() == 2 {
52                file_name.starts_with(parts[0]) && file_name.ends_with(parts[1])
53            } else {
54                file_name == pattern
55            }
56        } else {
57            file_name == pattern
58        }
59    }
60
61    #[allow(clippy::type_complexity)]
62    #[allow(clippy::too_many_arguments)]
63    async fn grep_impl(
64        root_dir: &Path,
65        pattern: &str,
66        path: &str,
67        case_insensitive: bool,
68        literal: bool,
69        context_before: usize,
70        context_after: usize,
71        include: Option<&str>,
72        max_results: usize,
73    ) -> Result<(String, bool), ToolError> {
74        // Security: validate path with PathGuard
75        let guard = PathGuard::new(root_dir);
76        let root = guard
77            .validate_traversal(Path::new(path))
78            .map_err(|e| e.to_string())?;
79
80        if !root.exists() {
81            return Err(format!("Path not found: {}", path));
82        }
83
84        // Escape the pattern for literal matching if needed
85        let pattern = if literal {
86            regex::escape(pattern)
87        } else {
88            pattern.to_string()
89        };
90
91        let re = RegexBuilder::new(&pattern)
92            .case_insensitive(case_insensitive)
93            .build()
94            .map_err(|e| format!("Invalid pattern '{}': {}", pattern, e))?;
95
96        let mut matches: Vec<String> = Vec::new();
97        let mut lines_truncated = false;
98        Self::grep_walk(
99            &root,
100            &root,
101            &re,
102            include,
103            context_before,
104            context_after,
105            max_results,
106            &mut matches,
107            &mut lines_truncated,
108        )
109        .await?;
110
111        if matches.is_empty() {
112            Ok(("No matches found".to_string(), false))
113        } else {
114            let header = format!("Found {} matches:\n", matches.len());
115            Ok((header + &matches.join("\n"), lines_truncated))
116        }
117    }
118
119    /// Read a file and return lines as a vector
120    async fn read_file_lines(path: &Path) -> Result<Vec<String>, ToolError> {
121        match fs::read_to_string(path).await {
122            Ok(content) => {
123                // Normalize line endings: replace CRLF and standalone CR with LF, then split
124                let normalized = content.replace("\r\n", "\n").replace('\r', "\n");
125                Ok(normalized.lines().map(|s| s.to_string()).collect())
126            }
127            Err(e) => Err(format!("Cannot read file: {}", e)),
128        }
129    }
130
131    #[allow(clippy::too_many_arguments)]
132    async fn grep_walk(
133        root: &Path,
134        current: &Path,
135        re: &regex::Regex,
136        include: Option<&str>,
137        context_before: usize,
138        context_after: usize,
139        max_results: usize,
140        matches: &mut Vec<String>,
141        lines_truncated: &mut bool,
142    ) -> Result<(), ToolError> {
143        if matches.len() >= max_results {
144            return Ok(());
145        }
146
147        // Detect and skip broken symlinks - they cause read_dir to fail
148        // and should not cause the entire search to fail
149        if current
150            .symlink_metadata()
151            .map(|m| m.file_type().is_symlink())
152            .unwrap_or(false)
153            && !current.exists()
154        {
155            return Ok(());
156        }
157
158        if current.is_file() {
159            // Check include filter
160            if let Some(glob) = include {
161                let file_name = current
162                    .file_name()
163                    .map(|n| n.to_string_lossy().to_string())
164                    .unwrap_or_default();
165                if !Self::matches_glob(&file_name, glob) {
166                    return Ok(());
167                }
168            }
169
170            // Try to read and search the file
171            match Self::read_file_lines(current).await {
172                Ok(lines) => {
173                    let relative = current.strip_prefix(root).unwrap_or(current).display();
174
175                    for (i, line) in lines.iter().enumerate() {
176                        if re.is_match(line) {
177                            // Check if adding this match would exceed max_results
178                            // We may need to add context lines too
179                            let context_lines_count = if context_before > 0 || context_after > 0 {
180                                let start = if context_before > 0 {
181                                    i.saturating_sub(context_before)
182                                } else {
183                                    i
184                                };
185                                let end = std::cmp::min(lines.len(), i + context_after + 1);
186                                end - start
187                            } else {
188                                1
189                            };
190
191                            if matches.len() + context_lines_count > max_results {
192                                // Can't add this match with its context, stop
193                                return Ok(());
194                            }
195
196                            // Add context lines before match
197                            if context_before > 0 && i > 0 {
198                                let start = i.saturating_sub(context_before);
199                                for (j, context_line) in
200                                    lines.iter().enumerate().take(i).skip(start)
201                                {
202                                    let (truncated_text, was_truncated) =
203                                        truncate_line(context_line);
204                                    if was_truncated {
205                                        *lines_truncated = true;
206                                    }
207                                    matches.push(format!(
208                                        "{}-{}- {}",
209                                        relative,
210                                        j + 1,
211                                        truncated_text
212                                    ));
213                                }
214                            }
215
216                            // Add the match line
217                            let (truncated_text, was_truncated) = truncate_line(line);
218                            if was_truncated {
219                                *lines_truncated = true;
220                            }
221                            matches.push(format!("{}:{}: {}", relative, i + 1, truncated_text));
222
223                            // Add context lines after match
224                            if context_after > 0 {
225                                let end = std::cmp::min(lines.len(), i + context_after + 1);
226                                for (j, context_line) in
227                                    lines.iter().enumerate().take(end).skip(i + 1)
228                                {
229                                    let (truncated_text, was_truncated) =
230                                        truncate_line(context_line);
231                                    if was_truncated {
232                                        *lines_truncated = true;
233                                    }
234                                    matches.push(format!(
235                                        "{}-{}- {}",
236                                        relative,
237                                        j + 1,
238                                        truncated_text
239                                    ));
240                                }
241                            }
242
243                            if matches.len() >= max_results {
244                                return Ok(());
245                            }
246                        }
247                    }
248                }
249                Err(_) => {
250                    // Skip files we can't read (binary, permissions, etc.)
251                }
252            }
253            return Ok(());
254        }
255
256        // Directory: walk entries
257        let mut entries = fs::read_dir(current)
258            .await
259            .map_err(|e| format!("Cannot read directory {}: {}", current.display(), e))?;
260
261        while let Some(entry) = entries
262            .next_entry()
263            .await
264            .map_err(|e| format!("Error reading entry: {}", e))?
265        {
266            let entry_path = entry.path();
267
268            // Skip hidden files/dirs
269            if entry_path
270                .file_name()
271                .map(|n| n.to_string_lossy().starts_with('.'))
272                .unwrap_or(false)
273            {
274                continue;
275            }
276
277            // Skip common non-searchable dirs
278            if entry_path.is_dir() {
279                let dir_name = entry_path
280                    .file_name()
281                    .map(|n| n.to_string_lossy().to_string())
282                    .unwrap_or_default();
283                if matches!(
284                    dir_name.as_str(),
285                    "node_modules"
286                        | "target"
287                        | ".git"
288                        | "dist"
289                        | "build"
290                        | "__pycache__"
291                        | ".venv"
292                        | "venv"
293                ) {
294                    continue;
295                }
296            }
297
298            Box::pin(Self::grep_walk(
299                root,
300                &entry_path,
301                re,
302                include,
303                context_before,
304                context_after,
305                max_results,
306                matches,
307                lines_truncated,
308            ))
309            .await?;
310        }
311
312        Ok(())
313    }
314}
315
316impl Default for GrepTool {
317    fn default() -> Self {
318        Self::new()
319    }
320}
321
322#[async_trait]
323impl AgentTool for GrepTool {
324    fn name(&self) -> &str {
325        "grep"
326    }
327
328    fn label(&self) -> &str {
329        "Grep"
330    }
331
332    fn essential(&self) -> bool {
333        true
334    }
335    fn description(&self) -> &str {
336        "Search files for a pattern. Returns matching lines with file paths and line numbers. Use literal=true to treat pattern as a literal string. Use context=n to show n lines before and after matches. Long lines are truncated to 500 chars."
337    }
338
339    fn parameters_schema(&self) -> Value {
340        json!({
341            "type": "object",
342            "properties": {
343                "pattern": {
344                    "type": "string",
345                    "description": "The pattern to search for (regex by default, or literal string if literal=true)"
346                },
347                "path": {
348                    "type": "string",
349                    "description": "The directory or file to search in",
350                    "default": "."
351                },
352                "case_insensitive": {
353                    "type": "boolean",
354                    "description": "If true, perform case-insensitive search",
355                    "default": false
356                },
357                "literal": {
358                    "type": "boolean",
359                    "description": "If true, treat pattern as a literal string instead of regex",
360                    "default": false
361                },
362                "context": {
363                    "type": "integer",
364                    "description": "Number of lines to show before and after each match",
365                    "default": 0
366                },
367                "include": {
368                    "type": "string",
369                    "description": "Glob pattern to filter files (e.g., '*.rs', '*.ts')"
370                },
371                "max_results": {
372                    "type": "integer",
373                    "description": "Maximum number of results to return",
374                    "default": 100
375                }
376            },
377            "required": ["pattern"]
378        })
379    }
380
381    async fn execute(
382        &self,
383        _tool_call_id: &str,
384        params: Value,
385        _signal: Option<oneshot::Receiver<()>>,
386        ctx: &ToolContext,
387    ) -> Result<AgentToolResult, ToolError> {
388        let pattern = params
389            .get("pattern")
390            .and_then(|v: &Value| v.as_str())
391            .ok_or_else(|| "Missing required parameter: pattern".to_string())?;
392
393        let path = params
394            .get("path")
395            .and_then(|v: &Value| v.as_str())
396            .unwrap_or(".");
397
398        let case_insensitive = params
399            .get("case_insensitive")
400            .and_then(|v: &Value| v.as_bool())
401            .unwrap_or(false);
402
403        let literal = params
404            .get("literal")
405            .and_then(|v: &Value| v.as_bool())
406            .unwrap_or(false);
407
408        let context = params
409            .get("context")
410            .and_then(|v: &Value| v.as_u64())
411            .unwrap_or(0) as usize;
412
413        let include = params.get("include").and_then(|v: &Value| v.as_str());
414
415        let max_results = params
416            .get("max_results")
417            .and_then(|v: &Value| v.as_u64())
418            .unwrap_or(100) as usize;
419
420        // Use root_dir if set, else ctx.root()
421        let root = self.root_dir.as_deref().unwrap_or(ctx.root());
422
423        match Self::grep_impl(
424            root,
425            pattern,
426            path,
427            case_insensitive,
428            literal,
429            context,
430            context,
431            include,
432            max_results,
433        )
434        .await
435        {
436            Ok((output, lines_truncated)) => {
437                let mut result = AgentToolResult::success(output);
438                if lines_truncated {
439                    result.metadata = Some(json!({
440                        "lines_truncated": true,
441                        "message": "Some lines truncated to 500 chars. Use read tool to see full lines."
442                    }));
443                }
444                Ok(result)
445            }
446            Err(e) => Ok(AgentToolResult::error(e)),
447        }
448    }
449}