Skip to main content

soul_coder/tools/
grep.rs

1//! Grep tool — search file contents using regex or literal patterns.
2//!
//! Uses VirtualFs for WASM compatibility. In WASM mode, walks the files under the search
//! root in the VFS and matches each line by substring (patterns are not compiled as regexes
//! here). In native mode, can delegate to ripgrep via VirtualExecutor.
5
6use std::sync::Arc;
7
8use async_trait::async_trait;
9use serde_json::json;
10use tokio::sync::mpsc;
11
12use soul_core::error::SoulResult;
13use soul_core::tool::{Tool, ToolOutput};
14use soul_core::types::ToolDefinition;
15use soul_core::vfs::VirtualFs;
16
17use crate::truncate::{truncate_head, truncate_line, GREP_MAX_LINE_LENGTH, MAX_BYTES};
18
19/// Maximum number of matches returned.
20const MAX_MATCHES: usize = 100;
21
22use super::resolve_path;
23
24pub struct GrepTool {
25    fs: Arc<dyn VirtualFs>,
26    cwd: String,
27}
28
29impl GrepTool {
30    pub fn new(fs: Arc<dyn VirtualFs>, cwd: impl Into<String>) -> Self {
31        Self {
32            fs,
33            cwd: cwd.into(),
34        }
35    }
36}
37
/// Reports whether `line` contains `pattern` as a substring.
///
/// No regex engine is involved: `pattern` is always compared literally, so
/// regex metacharacters only match themselves.
///
/// * `literal` - accepted for interface compatibility; it does not change the
///   result because both modes use the same substring search.
/// * `ignore_case` - when `true`, both `line` and `pattern` are lowercased
///   before comparing.
///
/// An empty `pattern` matches every line.
fn matches_pattern(line: &str, pattern: &str, literal: bool, ignore_case: bool) -> bool {
    // The literal and "regex" modes used to be two identical branches. Keeping a
    // single code path makes it obvious that the flag has no effect here.
    let _ = literal;

    if ignore_case {
        line.to_lowercase().contains(&pattern.to_lowercase())
    } else {
        line.contains(pattern)
    }
}
56
57/// Recursively collect all file paths from a VFS directory.
58async fn collect_files(
59    fs: &dyn VirtualFs,
60    dir: &str,
61    files: &mut Vec<String>,
62    glob_filter: Option<&str>,
63) -> SoulResult<()> {
64    let entries = fs.read_dir(dir).await?;
65    for entry in entries {
66        let path = if dir == "/" || dir.is_empty() {
67            format!("/{}", entry.name)
68        } else {
69            format!("{}/{}", dir.trim_end_matches('/'), entry.name)
70        };
71
72        if entry.is_dir {
73            // Skip hidden dirs
74            if !entry.name.starts_with('.') {
75                Box::pin(collect_files(fs, &path, files, glob_filter)).await?;
76            }
77        } else if entry.is_file {
78            if let Some(glob) = glob_filter {
79                if matches_glob(&entry.name, glob) {
80                    files.push(path);
81                }
82            } else {
83                files.push(path);
84            }
85        }
86    }
87    Ok(())
88}
89
/// Tests a bare file name against a simple glob.
///
/// Supported syntax:
/// * `*` matches any run of characters (including none) and may appear any
///   number of times, e.g. `*.rs`, `test_*`, `*_test*`.
/// * Every other character, including `?`, matches itself.
///
/// Only file names are matched, so any directory part of `glob` is dropped and
/// just its last path component is used: `**/*.rs` behaves like `*.rs`.
/// A glob without `*` must equal the file name exactly.
fn matches_glob(filename: &str, glob: &str) -> bool {
    // `rsplit` always yields at least one item, so the fallback is never taken.
    let glob = glob.rsplit('/').next().unwrap_or(glob);

    if !glob.contains('*') {
        return filename == glob;
    }

    // A glob containing `*` splits into at least two pieces: a prefix, zero or
    // more middle pieces, and a suffix (any of which may be empty).
    let mut pieces: Vec<&str> = glob.split('*').collect();
    let suffix = pieces.pop().unwrap_or("");
    let mut middles = pieces.into_iter();
    let prefix = middles.next().unwrap_or("");

    // Consume the prefix first so that prefix and suffix can never overlap
    // (e.g. `ab*ba` must not match `aba`).
    let mut rest = match filename.strip_prefix(prefix) {
        Some(rest) => rest,
        None => return false,
    };

    // Taking the leftmost occurrence of each middle piece leaves the longest
    // possible remainder for the pieces that follow.
    for piece in middles {
        match rest.find(piece) {
            Some(at) => rest = &rest[at + piece.len()..],
            None => return false,
        }
    }

    rest.ends_with(suffix)
}
107
108#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
109#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
110impl Tool for GrepTool {
111    fn name(&self) -> &str {
112        "grep"
113    }
114
115    fn definition(&self) -> ToolDefinition {
116        ToolDefinition {
117            name: "grep".into(),
118            description: "Search file contents for a pattern. Returns matching lines with file paths and line numbers.".into(),
119            input_schema: json!({
120                "type": "object",
121                "properties": {
122                    "pattern": {
123                        "type": "string",
124                        "description": "Search pattern (literal string or regex)"
125                    },
126                    "path": {
127                        "type": "string",
128                        "description": "Directory to search in (defaults to working directory)"
129                    },
130                    "glob": {
131                        "type": "string",
132                        "description": "Glob pattern to filter files (e.g., '*.rs', '*.ts')"
133                    },
134                    "ignore_case": {
135                        "type": "boolean",
136                        "description": "Case-insensitive search"
137                    },
138                    "literal": {
139                        "type": "boolean",
140                        "description": "Treat pattern as literal string (no regex)"
141                    },
142                    "context": {
143                        "type": "integer",
144                        "description": "Number of context lines before and after each match"
145                    },
146                    "max_matches": {
147                        "type": "integer",
148                        "description": "Maximum number of matches to return (default: 100)"
149                    }
150                },
151                "required": ["pattern"]
152            }),
153        }
154    }
155
156    async fn execute(
157        &self,
158        _call_id: &str,
159        arguments: serde_json::Value,
160        _partial_tx: Option<mpsc::UnboundedSender<String>>,
161    ) -> SoulResult<ToolOutput> {
162        let pattern = arguments
163            .get("pattern")
164            .and_then(|v| v.as_str())
165            .unwrap_or("");
166
167        if pattern.is_empty() {
168            return Ok(ToolOutput::error("Missing required parameter: pattern"));
169        }
170
171        let search_path = arguments
172            .get("path")
173            .and_then(|v| v.as_str())
174            .map(|p| resolve_path(&self.cwd, p))
175            .unwrap_or_else(|| self.cwd.clone());
176
177        let glob_filter = arguments.get("glob").and_then(|v| v.as_str());
178        let ignore_case = arguments
179            .get("ignore_case")
180            .and_then(|v| v.as_bool())
181            .unwrap_or(false);
182        let literal = arguments
183            .get("literal")
184            .and_then(|v| v.as_bool())
185            .unwrap_or(false);
186        let context_lines = arguments
187            .get("context")
188            .and_then(|v| v.as_u64())
189            .unwrap_or(0) as usize;
190        let max_matches = arguments
191            .get("max_matches")
192            .and_then(|v| v.as_u64())
193            .map(|v| (v as usize).min(MAX_MATCHES))
194            .unwrap_or(MAX_MATCHES);
195
196        // Collect files to search
197        let mut files = Vec::new();
198        if let Err(e) = collect_files(self.fs.as_ref(), &search_path, &mut files, glob_filter).await
199        {
200            return Ok(ToolOutput::error(format!(
201                "Failed to enumerate files in {}: {}",
202                search_path, e
203            )));
204        }
205
206        files.sort();
207
208        let mut output = String::new();
209        let mut total_matches = 0;
210        let mut files_with_matches = 0;
211
212        'files: for file_path in &files {
213            let content = match self.fs.read_to_string(file_path).await {
214                Ok(c) => c,
215                Err(_) => continue, // Skip unreadable files
216            };
217
218            let lines: Vec<&str> = content.lines().collect();
219            let mut file_had_match = false;
220
221            for (line_idx, line) in lines.iter().enumerate() {
222                if matches_pattern(line, pattern, literal, ignore_case) {
223                    if !file_had_match {
224                        if !output.is_empty() {
225                            output.push('\n');
226                        }
227                        files_with_matches += 1;
228                        file_had_match = true;
229                    }
230
231                    // Context before
232                    let ctx_start = line_idx.saturating_sub(context_lines);
233                    for ctx_idx in ctx_start..line_idx {
234                        output.push_str(&format!(
235                            "{}:{}-{}\n",
236                            display_path(file_path, &self.cwd),
237                            ctx_idx + 1,
238                            truncate_line(lines[ctx_idx], GREP_MAX_LINE_LENGTH)
239                        ));
240                    }
241
242                    // Match line
243                    output.push_str(&format!(
244                        "{}:{}:{}\n",
245                        display_path(file_path, &self.cwd),
246                        line_idx + 1,
247                        truncate_line(line, GREP_MAX_LINE_LENGTH)
248                    ));
249
250                    // Context after
251                    let ctx_end = (line_idx + context_lines + 1).min(lines.len());
252                    for ctx_idx in (line_idx + 1)..ctx_end {
253                        output.push_str(&format!(
254                            "{}:{}-{}\n",
255                            display_path(file_path, &self.cwd),
256                            ctx_idx + 1,
257                            truncate_line(lines[ctx_idx], GREP_MAX_LINE_LENGTH)
258                        ));
259                    }
260
261                    total_matches += 1;
262                    if total_matches >= max_matches {
263                        break 'files;
264                    }
265                }
266            }
267        }
268
269        if total_matches == 0 {
270            return Ok(ToolOutput::success(format!(
271                "No matches found for pattern '{}' in {}",
272                pattern,
273                display_path(&search_path, &self.cwd)
274            ))
275            .with_metadata(json!({"matches": 0, "files": 0})));
276        }
277
278        // Apply byte truncation
279        let truncated = truncate_head(&output, total_matches + (total_matches * context_lines * 2), MAX_BYTES);
280
281        let notice = truncated.truncation_notice();
282        let is_truncated = truncated.is_truncated();
283        let mut result = truncated.content;
284        if total_matches >= max_matches {
285            result.push_str(&format!(
286                "\n[Reached max matches limit: {}]",
287                max_matches
288            ));
289        }
290        if let Some(notice) = notice {
291            result.push_str(&format!("\n{}", notice));
292        }
293
294        Ok(ToolOutput::success(result).with_metadata(json!({
295            "matches": total_matches,
296            "files_with_matches": files_with_matches,
297            "truncated": is_truncated,
298        })))
299    }
300}
301
/// Returns `path` relative to `cwd` when it lies strictly inside `cwd`.
///
/// Trailing slashes on `cwd` are ignored. A path that is outside `cwd`, or
/// equal to it, is returned unchanged.
fn display_path(path: &str, cwd: &str) -> String {
    let base = cwd.trim_end_matches('/');
    path.strip_prefix(base)
        // Require a separator right after the base so "/project" does not
        // count as a parent of "/projectile/...".
        .and_then(|tail| tail.strip_prefix('/'))
        .unwrap_or(path)
        .to_string()
}
311
#[cfg(test)]
mod tests {
    use super::*;
    use soul_core::vfs::MemoryFs;

    /// Creates an empty in-memory file system and a grep tool rooted at `/project`.
    async fn setup() -> (Arc<MemoryFs>, GrepTool) {
        let memory = Arc::new(MemoryFs::new());
        let grep = GrepTool::new(memory.clone() as Arc<dyn VirtualFs>, "/project");
        (memory, grep)
    }

    /// Matches are reported as `path:line:text`, with paths relative to the cwd.
    #[tokio::test]
    async fn grep_simple_match() {
        let (fs, tool) = setup().await;
        let body = "hello world\nfoo bar\nhello again";
        fs.write("/project/file.txt", body).await.unwrap();

        let args = json!({"pattern": "hello"});
        let out = tool.execute("c1", args, None).await.unwrap();

        assert!(!out.is_error);
        for expected in ["file.txt:1:hello world", "file.txt:3:hello again"] {
            assert!(out.content.contains(expected));
        }
    }

    /// With `ignore_case`, differently cased lines all count as matches.
    #[tokio::test]
    async fn grep_case_insensitive() {
        let (fs, tool) = setup().await;
        let body = "Hello World\nhello world";
        fs.write("/project/file.txt", body).await.unwrap();

        let args = json!({"pattern": "HELLO", "ignore_case": true});
        let out = tool.execute("c2", args, None).await.unwrap();

        assert!(!out.is_error);
        assert_eq!(out.metadata["matches"].as_u64(), Some(2));
    }

    /// A glob restricts the search to files whose name satisfies it.
    #[tokio::test]
    async fn grep_with_glob_filter() {
        let (fs, tool) = setup().await;
        for file in ["/project/code.rs", "/project/readme.md"] {
            fs.write(file, "fn main() {}").await.unwrap();
        }

        let args = json!({"pattern": "fn main", "glob": "*.rs"});
        let out = tool.execute("c3", args, None).await.unwrap();

        assert!(!out.is_error);
        assert!(out.content.contains("code.rs"));
        assert!(!out.content.contains("readme.md"));
    }

    /// Finding nothing is a successful result with an explanatory message.
    #[tokio::test]
    async fn grep_no_matches() {
        let (fs, tool) = setup().await;
        fs.write("/project/file.txt", "nothing here").await.unwrap();

        let args = json!({"pattern": "missing"});
        let out = tool.execute("c4", args, None).await.unwrap();

        assert!(!out.is_error);
        assert!(out.content.contains("No matches"));
    }

    /// An empty pattern is rejected as an error result.
    #[tokio::test]
    async fn grep_empty_pattern() {
        let (_fs, tool) = setup().await;

        let args = json!({"pattern": ""});
        let out = tool.execute("c5", args, None).await.unwrap();

        assert!(out.is_error);
    }

    /// Context lines around a match are part of the output.
    #[tokio::test]
    async fn grep_with_context() {
        let (fs, tool) = setup().await;
        fs.write("/project/file.txt", "a\nb\nc\nd\ne").await.unwrap();

        let args = json!({"pattern": "c", "context": 1});
        let out = tool.execute("c6", args, None).await.unwrap();

        assert!(!out.is_error);
        assert!(out.content.contains("b")); // line before the match
        assert!(out.content.contains("d")); // line after the match
    }

    /// `*.ext` globs compare against the end of the file name.
    #[test]
    fn glob_matching() {
        let cases = [
            ("file.rs", "*.rs", true),
            ("file.ts", "*.rs", false),
            ("test.spec.ts", "*.ts", true),
        ];
        for (name, glob, expected) in cases {
            assert!(matches_glob(name, glob) == expected);
        }
    }

    /// Paths inside the cwd are shown relative to it; others stay absolute.
    #[test]
    fn display_path_relative() {
        let inside = display_path("/project/src/main.rs", "/project");
        let outside = display_path("/other/file.txt", "/project");

        assert_eq!(inside, "src/main.rs");
        assert_eq!(outside, "/other/file.txt");
    }

    /// The tool reports the same name directly and through its definition.
    #[tokio::test]
    async fn tool_name_and_definition() {
        let (_fs, tool) = setup().await;

        assert_eq!(tool.name(), "grep");
        assert_eq!(tool.definition().name, "grep");
    }
}