// ai_agent/tools/grep.rs

use std::path::Path;
use std::process::Output;

use tokio::process::Command;

use crate::types::*;

/// Tool that searches file contents for a regex pattern, using ripgrep (`rg`) when it is
/// available and falling back to `grep` otherwise.
///
/// The tool is stateless, so it is cheap to copy and `Default` is equivalent to `GrepTool::new()`.
#[derive(Debug, Default, Clone, Copy)]
pub struct GrepTool;

7impl GrepTool {
8    pub fn new() -> Self {
9        Self
10    }
11
12    pub fn name(&self) -> &str {
13        "Grep"
14    }
15
16    pub fn description(&self) -> &str {
17        "Search file contents using regex patterns. Uses ripgrep (rg) if available, falls back to grep."
18    }
19
20    pub fn user_facing_name(&self, _input: Option<&serde_json::Value>) -> String {
21        "Search".to_string()
22    }
23
24    pub fn get_tool_use_summary(&self, input: Option<&serde_json::Value>) -> Option<String> {
25        input.and_then(|inp| inp["pattern"].as_str().map(String::from))
26    }
27
28    pub fn render_tool_result_message(
29        &self,
30        content: &serde_json::Value,
31    ) -> Option<String> {
32        let text = content["content"].as_str()?;
33        let lines = text.lines().count();
34        Some(format!("{} {}", lines, if lines == 1 { "match" } else { "matches" }))
35    }
36
37    pub fn input_schema(&self) -> ToolInputSchema {
38        ToolInputSchema {
39            schema_type: "object".to_string(),
40            properties: serde_json::json!({
41                "pattern": {
42                    "type": "string",
43                    "description": "The regex pattern to search for"
44                },
45                "path": {
46                    "type": "string",
47                    "description": "File or directory to search in (defaults to cwd)"
48                },
49                "glob": {
50                    "type": "string",
51                    "description": "Glob pattern to filter files (e.g., \"*.ts\", \"*.{js,jsx}\")"
52                },
53                "type": {
54                    "type": "string",
55                    "description": "File type filter (e.g., \"ts\", \"py\", \"js\")"
56                },
57                "output_mode": {
58                    "type": "string",
59                    "enum": ["content", "files_with_matches", "count"],
60                    "description": "Output mode (default: files_with_matches)"
61                },
62                "-i": {
63                    "type": "boolean",
64                    "description": "Case insensitive search"
65                },
66                "-n": {
67                    "type": "boolean",
68                    "description": "Show line numbers (default: true)"
69                },
70                "-A": {
71                    "type": "number",
72                    "description": "Lines after match"
73                },
74                "-B": {
75                    "type": "number",
76                    "description": "Lines before match"
77                },
78                "-C": {
79                    "type": "number",
80                    "description": "Context lines"
81                },
82                "context": {
83                    "type": "number",
84                    "description": "Context lines (alias for -C)"
85                },
86                "head_limit": {
87                    "type": "number",
88                    "description": "Limit output entries (default: 250)"
89                }
90            }),
91            required: Some(vec!["pattern".to_string()]),
92        }
93    }
94
95    pub async fn execute(
96        &self,
97        input: serde_json::Value,
98        context: &ToolContext,
99    ) -> Result<ToolResult, crate::error::AgentError> {
100        let pattern = input["pattern"]
101            .as_str()
102            .ok_or_else(|| crate::error::AgentError::Tool("pattern is required".to_string()))?;
103
104        let search_path = input["path"]
105            .as_str()
106            .map(|p| {
107                if Path::new(p).is_absolute() {
108                    p.to_string()
109                } else {
110                    Path::new(&context.cwd)
111                        .join(p)
112                        .to_string_lossy()
113                        .to_string()
114                }
115            })
116            .unwrap_or_else(|| context.cwd.clone());
117
118        let output_mode = input["output_mode"]
119            .as_str()
120            .unwrap_or("files_with_matches");
121
122        let head_limit = input["head_limit"].as_u64().unwrap_or(250) as usize;
123
124        // Try ripgrep first
125        let result = self
126            .run_rg(
127                input.clone(),
128                pattern,
129                &search_path,
130                output_mode,
131                head_limit,
132            )
133            .await;
134
135        match result {
136            Ok(output) => Ok(output),
137            Err(_) => {
138                // Fall back to grep
139                self.run_grep(
140                    input.clone(),
141                    pattern,
142                    &search_path,
143                    output_mode,
144                    head_limit,
145                )
146                .await
147            }
148        }
149    }
150
151    async fn run_rg(
152        &self,
153        input: serde_json::Value,
154        pattern: &str,
155        search_path: &str,
156        output_mode: &str,
157        head_limit: usize,
158    ) -> Result<ToolResult, crate::error::AgentError> {
159        let mut args: Vec<String> = vec![];
160
161        if output_mode == "files_with_matches" {
162            args.push("--files-with-matches".to_string());
163        } else if output_mode == "count" {
164            args.push("--count".to_string());
165        } else if output_mode == "content" && input["-n"].as_bool().unwrap_or(true) {
166            args.push("--line-number".to_string());
167        }
168
169        if input["-i"].as_bool().unwrap_or(false) {
170            args.push("--ignore-case".to_string());
171        }
172
173        if let Some(n) = input["-A"].as_u64() {
174            args.push("-A".to_string());
175            args.push(n.to_string());
176        }
177
178        if let Some(n) = input["-B"].as_u64() {
179            args.push("-B".to_string());
180            args.push(n.to_string());
181        }
182
183        let ctx = input["-C"].as_u64().or_else(|| input["context"].as_u64());
184        if let Some(n) = ctx {
185            args.push("-C".to_string());
186            args.push(n.to_string());
187        }
188
189        if let Some(glob) = input["glob"].as_str() {
190            args.push("--glob".to_string());
191            args.push(glob.to_string());
192        }
193
194        if let Some(t) = input["type"].as_str() {
195            args.push("--type".to_string());
196            args.push(t.to_string());
197        }
198
199        args.push("--".to_string());
200        args.push(pattern.to_string());
201        args.push(search_path.to_string());
202
203        let output = Command::new("rg")
204            .args(&args)
205            .output()
206            .await
207            .map_err(|e| crate::error::AgentError::Tool(e.to_string()))?;
208
209        if !output.status.success() {
210            return Err(crate::error::AgentError::Tool("rg failed".to_string()));
211        }
212
213        let stdout = String::from_utf8_lossy(&output.stdout);
214        let result = stdout.trim();
215
216        if result.is_empty() {
217            return Ok(ToolResult {
218                result_type: "text".to_string(),
219                tool_use_id: "".to_string(),
220                content: format!("No matches found for pattern \"{}\"", pattern),
221                is_error: None,
222                was_persisted: None,
223            });
224        }
225
226        // Apply head limit
227        let content = self.apply_head_limit(result, head_limit);
228
229        Ok(ToolResult {
230            result_type: "text".to_string(),
231            tool_use_id: "".to_string(),
232            content,
233            is_error: None,
234            was_persisted: None,
235        })
236    }
237
238    async fn run_grep(
239        &self,
240        input: serde_json::Value,
241        pattern: &str,
242        search_path: &str,
243        output_mode: &str,
244        head_limit: usize,
245    ) -> Result<ToolResult, crate::error::AgentError> {
246        let mut args: Vec<String> = vec!["-r".to_string()];
247
248        if input["-i"].as_bool().unwrap_or(false) {
249            args.push("-i".to_string());
250        }
251
252        if output_mode == "files_with_matches" {
253            args.push("-l".to_string());
254        } else if output_mode == "count" {
255            args.push("-c".to_string());
256        } else if output_mode == "content" && input["-n"].as_bool().unwrap_or(true) {
257            args.push("-n".to_string());
258        }
259
260        if let Some(glob) = input["glob"].as_str() {
261            args.push("--include".to_string());
262            args.push(glob.to_string());
263        }
264
265        args.push("--".to_string());
266        args.push(pattern.to_string());
267        args.push(search_path.to_string());
268
269        let output = Command::new("grep")
270            .args(&args)
271            .output()
272            .await
273            .map_err(|e| crate::error::AgentError::Tool(e.to_string()))?;
274
275        let stdout = String::from_utf8_lossy(&output.stdout);
276        let result = stdout.trim();
277
278        if result.is_empty() {
279            return Ok(ToolResult {
280                result_type: "text".to_string(),
281                tool_use_id: "".to_string(),
282                content: format!("No matches found for pattern \"{}\"", pattern),
283                is_error: None,
284                was_persisted: None,
285            });
286        }
287
288        // Apply head limit
289        let content = self.apply_head_limit(result, head_limit);
290
291        Ok(ToolResult {
292            result_type: "text".to_string(),
293            tool_use_id: "".to_string(),
294            content,
295            is_error: None,
296            was_persisted: None,
297        })
298    }
299
300    fn apply_head_limit(&self, result: &str, head_limit: usize) -> String {
301        if head_limit > 0 {
302            let lines: Vec<&str> = result.lines().collect();
303            if lines.len() > head_limit {
304                let limited: Vec<&str> = lines.iter().take(head_limit).cloned().collect();
305                let remaining = lines.len() - head_limit;
306                return format!("{}\n... ({} more)", limited.join("\n"), remaining);
307            }
308        }
309        result.to_string()
310    }
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// Tool context rooted at the platform temp directory, where fixtures are written.
    fn temp_context() -> ToolContext {
        ToolContext {
            cwd: std::env::temp_dir().to_string_lossy().into_owned(),
            abort_signal: Default::default(),
        }
    }

    /// Writes `contents` to a per-process fixture file derived from `name` and returns its path.
    /// The process id keeps concurrent test runs from clobbering each other's files.
    async fn write_fixture(name: &str, contents: &str) -> std::path::PathBuf {
        let path = std::env::temp_dir().join(format!("{}_{}.txt", name, std::process::id()));
        tokio::fs::write(&path, contents).await.unwrap();
        path
    }

    #[tokio::test]
    async fn test_grep_tool() {
        let path = write_fixture("test_grep", "hello world\nfoo bar\ntest line").await;
        let path_str = path.to_string_lossy().into_owned();

        let tool = GrepTool::new();
        let result = tool
            .execute(
                serde_json::json!({"pattern": "hello", "path": path_str, "output_mode": "content"}),
                &temp_context(),
            )
            .await;
        // Remove the fixture before asserting so a failed assertion does not leak it.
        let _ = tokio::fs::remove_file(&path).await;

        let content = result.expect("search should succeed").content;
        assert!(content.contains("hello"));
    }

    #[tokio::test]
    async fn test_grep_tool_no_match() {
        let path = write_fixture("test_grep_no_match", "hello world\nfoo bar").await;
        let path_str = path.to_string_lossy().into_owned();

        let tool = GrepTool::new();
        let result = tool
            .execute(
                serde_json::json!({"pattern": "absent_token_xyz", "path": path_str}),
                &temp_context(),
            )
            .await;
        let _ = tokio::fs::remove_file(&path).await;

        let content = result.expect("a search with no hits should still succeed").content;
        assert!(content.starts_with("No matches found for pattern"));
    }
}