Skip to main content

batuta/agent/tool/
search.rs

//! Search tools for agent code discovery.
//!
//! Provides `glob` and `grep` tools for the `apr code` agentic
//! coding assistant. These are the agent's primary code navigation
//! tools, giving it the ability to find files by glob pattern and to
//! search file contents for a literal substring (`grep` does not
//! interpret regular expressions).
//!
//! Both tools require `Capability::FileRead` and respect path
//! restrictions (Poka-Yoke). Results are truncated to prevent
//! context overflow (Jidoka: bounded output).
11
12use std::path::PathBuf;
13
14use async_trait::async_trait;
15
16use crate::agent::capability::Capability;
17use crate::agent::driver::ToolDefinition;
18
19use super::{Tool, ToolResult};
20
/// Upper bound on the number of file paths the `glob` tool reports.
const MAX_GLOB_RESULTS: usize = 200;

/// Upper bound on the number of matching lines the `grep` tool reports.
const MAX_GREP_RESULTS: usize = 200;

/// Byte budget (32 KiB) for `grep` output; longer output is cut and flagged.
const MAX_GREP_BYTES: usize = 32 * 1024;
29
30// ─── GlobTool ───────────────────────────────────────────────
31
/// Find files by glob pattern.
///
/// Wraps the `glob` crate for file discovery. Matching files are
/// sorted by modification time (most recent first) and capped at
/// `MAX_GLOB_RESULTS`. Only a bounded number of glob entries is
/// examined per call, so on very large trees the ordering covers
/// the scanned subset only.
#[derive(Debug, Clone)]
pub struct GlobTool {
    /// Path prefixes that returned files must live under; an entry of
    /// `"*"` disables the check.
    allowed_paths: Vec<String>,
}

impl GlobTool {
    /// Create a glob tool restricted to `allowed_paths` (`"*"` allows any path).
    pub fn new(allowed_paths: Vec<String>) -> Self {
        Self { allowed_paths }
    }
}
46
47#[async_trait]
48impl Tool for GlobTool {
49    fn name(&self) -> &'static str {
50        "glob"
51    }
52
53    fn definition(&self) -> ToolDefinition {
54        ToolDefinition {
55            name: "glob".into(),
56            description:
57                "Find files matching a glob pattern. Returns paths sorted by modification time."
58                    .into(),
59            input_schema: serde_json::json!({
60                "type": "object",
61                "required": ["pattern"],
62                "properties": {
63                    "pattern": {
64                        "type": "string",
65                        "description": "Glob pattern (e.g., 'src/**/*.rs', '*.toml')"
66                    },
67                    "path": {
68                        "type": "string",
69                        "description": "Base directory to search in (default: current dir)"
70                    }
71                }
72            }),
73        }
74    }
75
76    async fn execute(&self, input: serde_json::Value) -> ToolResult {
77        let pattern = match input.get("pattern").and_then(|v| v.as_str()) {
78            Some(p) => p,
79            None => return ToolResult::error("missing required field 'pattern'"),
80        };
81
82        let base = input.get("path").and_then(|v| v.as_str()).unwrap_or(".");
83
84        // Construct full pattern
85        let full_pattern = if pattern.starts_with('/') {
86            pattern.to_string()
87        } else {
88            format!("{}/{}", base.trim_end_matches('/'), pattern)
89        };
90
91        let entries = match glob::glob(&full_pattern) {
92            Ok(paths) => paths,
93            Err(e) => return ToolResult::error(format!("invalid glob pattern: {e}")),
94        };
95
96        let mut results: Vec<(PathBuf, std::time::SystemTime)> = Vec::new();
97        for entry in entries.take(MAX_GLOB_RESULTS * 2) {
98            // overscan to allow filtering
99            let Ok(path) = entry else { continue };
100            if !path.is_file() {
101                continue;
102            }
103            // Validate against allowed paths
104            if !self.allowed_paths.iter().any(|p| p == "*") {
105                let Ok(canon) = path.canonicalize() else {
106                    continue;
107                };
108                let allowed = self.allowed_paths.iter().any(|prefix| {
109                    PathBuf::from(prefix)
110                        .canonicalize()
111                        .map(|pc| canon.starts_with(&pc))
112                        .unwrap_or(false)
113                });
114                if !allowed {
115                    continue;
116                }
117            }
118            let mtime = path.metadata().and_then(|m| m.modified()).unwrap_or(std::time::UNIX_EPOCH);
119            results.push((path, mtime));
120        }
121
122        // Sort by modification time (most recent first)
123        results.sort_by(|a, b| b.1.cmp(&a.1));
124        results.truncate(MAX_GLOB_RESULTS);
125
126        if results.is_empty() {
127            return ToolResult::success(format!("No files matching '{full_pattern}'"));
128        }
129
130        let output: String =
131            results.iter().map(|(p, _)| p.display().to_string()).collect::<Vec<_>>().join("\n");
132
133        let suffix = if results.len() == MAX_GLOB_RESULTS {
134            format!("\n\n[truncated at {MAX_GLOB_RESULTS} results]")
135        } else {
136            String::new()
137        };
138
139        ToolResult::success(format!("{output}{suffix}"))
140    }
141
142    fn required_capability(&self) -> Capability {
143        Capability::FileRead { allowed_paths: self.allowed_paths.clone() }
144    }
145}
146
147// ─── GrepTool ───────────────────────────────────────────────
148
/// Search file contents for a literal substring.
///
/// Walks a directory (or reads a single file) and reports every line
/// containing the pattern as `path:line:content`. Matching is plain
/// substring matching, optionally case-insensitive, and does not use a
/// regex engine. Results are capped at `MAX_GREP_RESULTS` lines and
/// `MAX_GREP_BYTES` bytes of output.
#[derive(Debug, Clone)]
pub struct GrepTool {
    /// Path allow-list, reported to the agent via `Capability::FileRead`.
    allowed_paths: Vec<String>,
}

impl GrepTool {
    /// Create a grep tool carrying the given path allow-list.
    pub fn new(allowed_paths: Vec<String>) -> Self {
        Self { allowed_paths }
    }
}
163
164#[async_trait]
165impl Tool for GrepTool {
166    fn name(&self) -> &'static str {
167        "grep"
168    }
169
170    fn definition(&self) -> ToolDefinition {
171        ToolDefinition {
172            name: "grep".into(),
173            description:
174                "Search file contents with regex. Returns matching lines with file:line:content."
175                    .into(),
176            input_schema: serde_json::json!({
177                "type": "object",
178                "required": ["pattern"],
179                "properties": {
180                    "pattern": {
181                        "type": "string",
182                        "description": "Regex pattern to search for"
183                    },
184                    "path": {
185                        "type": "string",
186                        "description": "File or directory to search (default: current dir)"
187                    },
188                    "glob": {
189                        "type": "string",
190                        "description": "Glob to filter files (e.g., '*.rs', '*.toml')"
191                    },
192                    "case_insensitive": {
193                        "type": "boolean",
194                        "description": "Case-insensitive search (default: false)"
195                    }
196                }
197            }),
198        }
199    }
200
201    async fn execute(&self, input: serde_json::Value) -> ToolResult {
202        let pattern_str = match input.get("pattern").and_then(|v| v.as_str()) {
203            Some(p) => p,
204            None => return ToolResult::error("missing required field 'pattern'"),
205        };
206
207        let search_path = input.get("path").and_then(|v| v.as_str()).unwrap_or(".");
208
209        let file_glob = input.get("glob").and_then(|v| v.as_str());
210        let case_insensitive =
211            input.get("case_insensitive").and_then(|v| v.as_bool()).unwrap_or(false);
212
213        let matcher = PatternMatcher::new(pattern_str, case_insensitive);
214
215        let root = PathBuf::from(search_path);
216        if !root.exists() {
217            return ToolResult::error(format!("path '{}' not found", root.display()));
218        }
219
220        let mut output = String::new();
221        let mut match_count = 0;
222
223        // Single file
224        if root.is_file() {
225            search_file(&root, &matcher, &mut output, &mut match_count);
226            return finish_grep(output, match_count);
227        }
228
229        // Directory walk
230        let walker = walkdir::WalkDir::new(&root)
231            .max_depth(20)
232            .follow_links(false)
233            .into_iter()
234            .filter_map(|e| e.ok())
235            .filter(|e| e.file_type().is_file());
236
237        // Compile file glob filter if provided
238        let file_pattern = file_glob.and_then(|g| glob::Pattern::new(g).ok());
239
240        for entry in walker {
241            if match_count >= MAX_GREP_RESULTS {
242                break;
243            }
244
245            let path = entry.path();
246
247            // Filter by file glob
248            if let Some(ref pat) = file_pattern {
249                let name = path.file_name().and_then(|n| n.to_str()).unwrap_or("");
250                if !pat.matches(name) {
251                    continue;
252                }
253            }
254
255            // Skip binary files (quick heuristic: NUL byte in first 512 bytes)
256            if is_likely_binary(path) {
257                continue;
258            }
259
260            search_file(path, &matcher, &mut output, &mut match_count);
261        }
262
263        finish_grep(output, match_count)
264    }
265
266    fn required_capability(&self) -> Capability {
267        Capability::FileRead { allowed_paths: self.allowed_paths.clone() }
268    }
269}
270
/// Literal substring matcher, optionally case-insensitive.
///
/// This is deliberately not a regex engine. Plain `str::contains` avoids a
/// regex dependency and covers most of the agent's searches.
struct PatternMatcher {
    /// Needle to look for; already lowercased when `case_insensitive` is set.
    pattern: String,
    case_insensitive: bool,
}

impl PatternMatcher {
    /// Build a matcher, normalising the needle up front for case-insensitive mode.
    fn new(pattern: &str, case_insensitive: bool) -> Self {
        let needle = match case_insensitive {
            true => pattern.to_lowercase(),
            false => pattern.to_owned(),
        };
        Self { pattern: needle, case_insensitive }
    }

    /// Whether `line` contains the needle. In case-insensitive mode the line
    /// is lowercased before comparing.
    fn is_match(&self, line: &str) -> bool {
        if !self.case_insensitive {
            return line.contains(self.pattern.as_str());
        }
        line.to_lowercase().contains(self.pattern.as_str())
    }
}
295
296/// Search a single file for pattern matches, appending results to `output`.
297fn search_file(
298    path: &std::path::Path,
299    matcher: &PatternMatcher,
300    output: &mut String,
301    match_count: &mut usize,
302) {
303    let Ok(content) = std::fs::read_to_string(path) else {
304        return;
305    };
306    for (line_num, line) in content.lines().enumerate() {
307        if *match_count >= MAX_GREP_RESULTS {
308            break;
309        }
310        if matcher.is_match(line) {
311            use std::fmt::Write;
312            let _ = writeln!(output, "{}:{}:{}", path.display(), line_num + 1, line);
313            *match_count += 1;
314        }
315    }
316}
317
/// Heuristic binary check: true if a NUL byte appears in the first 512 bytes.
///
/// Files that cannot be opened or read are also reported as binary, so
/// callers skip them. Up to 512 bytes are read in full via `take` +
/// `read_to_end`, which retries interrupted reads and tolerates short
/// reads. A single `read` call may return fewer bytes than available and
/// miss a NUL inside the window.
fn is_likely_binary(path: &std::path::Path) -> bool {
    use std::io::Read;

    const SNIFF_LEN: usize = 512;

    let Ok(file) = std::fs::File::open(path) else {
        return true;
    };
    let mut head = Vec::with_capacity(SNIFF_LEN);
    if file.take(SNIFF_LEN as u64).read_to_end(&mut head).is_err() {
        return true;
    }
    head.contains(&0)
}
329
330/// Format the final grep result with truncation info.
331fn finish_grep(mut output: String, match_count: usize) -> ToolResult {
332    if match_count == 0 {
333        return ToolResult::success("No matches found.");
334    }
335
336    if output.len() > MAX_GREP_BYTES {
337        output.truncate(MAX_GREP_BYTES);
338        output.push_str("\n\n[output truncated]");
339    }
340
341    if match_count >= MAX_GREP_RESULTS {
342        output.push_str(&format!("\n\n[truncated at {MAX_GREP_RESULTS} matches]"));
343    }
344
345    ToolResult::success(output)
346}
347
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write as IoWrite;
    use tempfile::TempDir;

    /// Populate `dir` with a miniature Cargo project:
    /// `src/main.rs`, `src/lib.rs` and `Cargo.toml`.
    fn create_project(dir: &std::path::Path) {
        std::fs::create_dir_all(dir.join("src")).unwrap();
        let files: [(&str, &[u8]); 3] = [
            ("src/main.rs", b"fn main() {\n    println!(\"hello\");\n}\n"),
            ("src/lib.rs", b"pub fn add(a: i32, b: i32) -> i32 {\n    a + b\n}\n"),
            ("Cargo.toml", b"[package]\nname = \"test\"\nversion = \"0.1.0\"\n"),
        ];
        for (rel, bytes) in files {
            let mut handle = std::fs::File::create(dir.join(rel)).unwrap();
            handle.write_all(bytes).unwrap();
        }
    }

    /// Fresh temp dir holding the fixture project (deleted when dropped).
    fn fixture() -> TempDir {
        let dir = TempDir::new().unwrap();
        create_project(dir.path());
        dir
    }

    /// The temp dir path as the `&str` the tools expect in their JSON input.
    fn root_of(dir: &TempDir) -> &str {
        dir.path().to_str().unwrap()
    }

    /// Glob tool with the `"*"` (unrestricted) allow-list.
    fn unrestricted_glob() -> GlobTool {
        GlobTool::new(vec!["*".into()])
    }

    /// Grep tool with the `"*"` (unrestricted) allow-list.
    fn unrestricted_grep() -> GrepTool {
        GrepTool::new(vec!["*".into()])
    }

    // ─── GlobTool tests ─────────────────────────────────

    #[tokio::test]
    async fn test_glob_find_rust_files() {
        let dir = fixture();
        let result = unrestricted_glob()
            .execute(serde_json::json!({ "pattern": "**/*.rs", "path": root_of(&dir) }))
            .await;
        assert!(!result.is_error, "error: {}", result.content);
        for expected in ["main.rs", "lib.rs"] {
            assert!(result.content.contains(expected), "missing {expected}");
        }
        assert!(!result.content.contains("Cargo.toml"));
    }

    #[tokio::test]
    async fn test_glob_find_toml() {
        let dir = fixture();
        let result = unrestricted_glob()
            .execute(serde_json::json!({ "pattern": "*.toml", "path": root_of(&dir) }))
            .await;
        assert!(!result.is_error);
        assert!(result.content.contains("Cargo.toml"));
        assert!(!result.content.contains(".rs"));
    }

    #[tokio::test]
    async fn test_glob_no_matches() {
        let dir = fixture();
        let result = unrestricted_glob()
            .execute(serde_json::json!({ "pattern": "**/*.py", "path": root_of(&dir) }))
            .await;
        assert!(!result.is_error);
        assert!(result.content.contains("No files matching"));
    }

    #[tokio::test]
    async fn test_glob_invalid_pattern() {
        let result = unrestricted_glob().execute(serde_json::json!({ "pattern": "[invalid" })).await;
        assert!(result.is_error);
        assert!(result.content.contains("invalid glob"));
    }

    #[tokio::test]
    async fn test_glob_missing_pattern() {
        let result = unrestricted_glob().execute(serde_json::json!({ "path": "." })).await;
        assert!(result.is_error);
        assert!(result.content.contains("missing"));
    }

    #[test]
    fn test_glob_tool_metadata() {
        let tool = GlobTool::new(vec!["/home".into()]);
        assert_eq!(tool.name(), "glob");
        assert_eq!(tool.definition().name, "glob");
        match tool.required_capability() {
            Capability::FileRead { allowed_paths } => {
                assert_eq!(allowed_paths, vec!["/home".to_string()]);
            }
            other => panic!("expected FileRead, got: {other:?}"),
        }
    }

    // ─── GrepTool tests ─────────────────────────────────

    #[tokio::test]
    async fn test_grep_find_pattern() {
        let dir = fixture();
        let result = unrestricted_grep()
            .execute(serde_json::json!({ "pattern": "println", "path": root_of(&dir) }))
            .await;
        assert!(!result.is_error, "error: {}", result.content);
        assert!(result.content.contains("main.rs"));
        assert!(result.content.contains("println"));
    }

    #[tokio::test]
    async fn test_grep_with_file_glob() {
        let dir = fixture();
        let result = unrestricted_grep()
            .execute(serde_json::json!({
                "pattern": "fn",
                "path": root_of(&dir),
                "glob": "*.rs"
            }))
            .await;
        assert!(!result.is_error);
        for expected in ["main.rs", "lib.rs"] {
            assert!(result.content.contains(expected), "missing {expected}");
        }
        // The *.rs filter must keep Cargo.toml out of the search.
        assert!(!result.content.contains("Cargo.toml"));
    }

    #[tokio::test]
    async fn test_grep_case_insensitive() {
        let dir = fixture();
        let result = unrestricted_grep()
            .execute(serde_json::json!({
                "pattern": "PRINTLN",
                "path": root_of(&dir),
                "case_insensitive": true
            }))
            .await;
        assert!(!result.is_error);
        assert!(result.content.contains("println"));
    }

    #[tokio::test]
    async fn test_grep_no_matches() {
        let dir = fixture();
        let result = unrestricted_grep()
            .execute(serde_json::json!({ "pattern": "ZZZZZ_NONEXISTENT", "path": root_of(&dir) }))
            .await;
        assert!(!result.is_error);
        assert!(result.content.contains("No matches"));
    }

    #[tokio::test]
    async fn test_grep_special_chars_in_pattern() {
        let dir = fixture();
        // The pattern is a literal substring, so an unbalanced bracket is not an error.
        let result = unrestricted_grep()
            .execute(serde_json::json!({ "pattern": "[invalid", "path": root_of(&dir) }))
            .await;
        assert!(!result.is_error);
        assert!(result.content.contains("No matches"));
    }

    #[tokio::test]
    async fn test_grep_single_file() {
        let dir = fixture();
        let main_rs = dir.path().join("src/main.rs");
        let result = unrestricted_grep()
            .execute(serde_json::json!({ "pattern": "fn", "path": main_rs.to_str().unwrap() }))
            .await;
        assert!(!result.is_error);
        assert!(result.content.contains("fn main"));
    }

    #[tokio::test]
    async fn test_grep_nonexistent_path() {
        let result = unrestricted_grep()
            .execute(serde_json::json!({ "pattern": "test", "path": "/nonexistent_dir_xyz" }))
            .await;
        assert!(result.is_error);
        assert!(result.content.contains("not found"));
    }

    #[tokio::test]
    async fn test_grep_missing_pattern() {
        let result = unrestricted_grep().execute(serde_json::json!({ "path": "." })).await;
        assert!(result.is_error);
        assert!(result.content.contains("missing"));
    }

    #[test]
    fn test_grep_tool_metadata() {
        let tool = GrepTool::new(vec!["/project".into()]);
        assert_eq!(tool.name(), "grep");
        assert_eq!(tool.definition().name, "grep");
        match tool.required_capability() {
            Capability::FileRead { allowed_paths } => {
                assert_eq!(allowed_paths, vec!["/project".to_string()]);
            }
            other => panic!("expected FileRead, got: {other:?}"),
        }
    }

    // ─── Helper tests ───────────────────────────────────

    #[test]
    fn test_is_likely_binary_text() {
        let dir = TempDir::new().unwrap();
        let text_file = dir.path().join("text.txt");
        std::fs::write(&text_file, "hello world").unwrap();
        assert!(!is_likely_binary(&text_file));
    }

    #[test]
    fn test_is_likely_binary_binary() {
        let dir = TempDir::new().unwrap();
        let blob = dir.path().join("binary.bin");
        std::fs::write(&blob, &[0u8, 1, 2, 0, 3, 4]).unwrap();
        assert!(is_likely_binary(&blob));
    }

    #[test]
    fn test_is_likely_binary_nonexistent() {
        let missing = std::path::Path::new("/no_such_file_xyz");
        assert!(is_likely_binary(missing));
    }
}