1use std::path::PathBuf;
13
14use async_trait::async_trait;
15
16use crate::agent::capability::Capability;
17use crate::agent::driver::ToolDefinition;
18
19use super::{Tool, ToolResult};
20
/// Maximum number of paths the glob tool returns; at most twice this many
/// glob entries are examined before the scan stops.
const MAX_GLOB_RESULTS: usize = 200;

/// Maximum number of matching lines the grep tool collects.
const MAX_GREP_RESULTS: usize = 200;

/// Maximum size, in bytes, of the grep output before it is cut short.
const MAX_GREP_BYTES: usize = 32_768;
29
/// Tool that expands a glob pattern into a list of matching files.
pub struct GlobTool {
    /// Path prefixes that results are restricted to; an entry equal to `"*"`
    /// lifts the restriction.
    allowed_paths: Vec<String>,
}

impl GlobTool {
    /// Builds a glob tool whose results are limited to `allowed_paths`.
    pub fn new(allowed_paths: Vec<String>) -> Self {
        GlobTool { allowed_paths }
    }
}
46
47#[async_trait]
48impl Tool for GlobTool {
49 fn name(&self) -> &'static str {
50 "glob"
51 }
52
53 fn definition(&self) -> ToolDefinition {
54 ToolDefinition {
55 name: "glob".into(),
56 description:
57 "Find files matching a glob pattern. Returns paths sorted by modification time."
58 .into(),
59 input_schema: serde_json::json!({
60 "type": "object",
61 "required": ["pattern"],
62 "properties": {
63 "pattern": {
64 "type": "string",
65 "description": "Glob pattern (e.g., 'src/**/*.rs', '*.toml')"
66 },
67 "path": {
68 "type": "string",
69 "description": "Base directory to search in (default: current dir)"
70 }
71 }
72 }),
73 }
74 }
75
76 async fn execute(&self, input: serde_json::Value) -> ToolResult {
77 let pattern = match input.get("pattern").and_then(|v| v.as_str()) {
78 Some(p) => p,
79 None => return ToolResult::error("missing required field 'pattern'"),
80 };
81
82 let base = input.get("path").and_then(|v| v.as_str()).unwrap_or(".");
83
84 let full_pattern = if pattern.starts_with('/') {
86 pattern.to_string()
87 } else {
88 format!("{}/{}", base.trim_end_matches('/'), pattern)
89 };
90
91 let entries = match glob::glob(&full_pattern) {
92 Ok(paths) => paths,
93 Err(e) => return ToolResult::error(format!("invalid glob pattern: {e}")),
94 };
95
96 let mut results: Vec<(PathBuf, std::time::SystemTime)> = Vec::new();
97 for entry in entries.take(MAX_GLOB_RESULTS * 2) {
98 let Ok(path) = entry else { continue };
100 if !path.is_file() {
101 continue;
102 }
103 if !self.allowed_paths.iter().any(|p| p == "*") {
105 let Ok(canon) = path.canonicalize() else {
106 continue;
107 };
108 let allowed = self.allowed_paths.iter().any(|prefix| {
109 PathBuf::from(prefix)
110 .canonicalize()
111 .map(|pc| canon.starts_with(&pc))
112 .unwrap_or(false)
113 });
114 if !allowed {
115 continue;
116 }
117 }
118 let mtime = path.metadata().and_then(|m| m.modified()).unwrap_or(std::time::UNIX_EPOCH);
119 results.push((path, mtime));
120 }
121
122 results.sort_by(|a, b| b.1.cmp(&a.1));
124 results.truncate(MAX_GLOB_RESULTS);
125
126 if results.is_empty() {
127 return ToolResult::success(format!("No files matching '{full_pattern}'"));
128 }
129
130 let output: String =
131 results.iter().map(|(p, _)| p.display().to_string()).collect::<Vec<_>>().join("\n");
132
133 let suffix = if results.len() == MAX_GLOB_RESULTS {
134 format!("\n\n[truncated at {MAX_GLOB_RESULTS} results]")
135 } else {
136 String::new()
137 };
138
139 ToolResult::success(format!("{output}{suffix}"))
140 }
141
142 fn required_capability(&self) -> Capability {
143 Capability::FileRead { allowed_paths: self.allowed_paths.clone() }
144 }
145}
146
/// Tool that searches file contents for lines containing a pattern.
pub struct GrepTool {
    /// Path prefixes this tool may read, as reported by `required_capability`.
    allowed_paths: Vec<String>,
}

impl GrepTool {
    /// Builds a grep tool associated with `allowed_paths`.
    pub fn new(allowed_paths: Vec<String>) -> Self {
        GrepTool { allowed_paths }
    }
}
163
164#[async_trait]
165impl Tool for GrepTool {
166 fn name(&self) -> &'static str {
167 "grep"
168 }
169
170 fn definition(&self) -> ToolDefinition {
171 ToolDefinition {
172 name: "grep".into(),
173 description:
174 "Search file contents with regex. Returns matching lines with file:line:content."
175 .into(),
176 input_schema: serde_json::json!({
177 "type": "object",
178 "required": ["pattern"],
179 "properties": {
180 "pattern": {
181 "type": "string",
182 "description": "Regex pattern to search for"
183 },
184 "path": {
185 "type": "string",
186 "description": "File or directory to search (default: current dir)"
187 },
188 "glob": {
189 "type": "string",
190 "description": "Glob to filter files (e.g., '*.rs', '*.toml')"
191 },
192 "case_insensitive": {
193 "type": "boolean",
194 "description": "Case-insensitive search (default: false)"
195 }
196 }
197 }),
198 }
199 }
200
201 async fn execute(&self, input: serde_json::Value) -> ToolResult {
202 let pattern_str = match input.get("pattern").and_then(|v| v.as_str()) {
203 Some(p) => p,
204 None => return ToolResult::error("missing required field 'pattern'"),
205 };
206
207 let search_path = input.get("path").and_then(|v| v.as_str()).unwrap_or(".");
208
209 let file_glob = input.get("glob").and_then(|v| v.as_str());
210 let case_insensitive =
211 input.get("case_insensitive").and_then(|v| v.as_bool()).unwrap_or(false);
212
213 let matcher = PatternMatcher::new(pattern_str, case_insensitive);
214
215 let root = PathBuf::from(search_path);
216 if !root.exists() {
217 return ToolResult::error(format!("path '{}' not found", root.display()));
218 }
219
220 let mut output = String::new();
221 let mut match_count = 0;
222
223 if root.is_file() {
225 search_file(&root, &matcher, &mut output, &mut match_count);
226 return finish_grep(output, match_count);
227 }
228
229 let walker = walkdir::WalkDir::new(&root)
231 .max_depth(20)
232 .follow_links(false)
233 .into_iter()
234 .filter_map(|e| e.ok())
235 .filter(|e| e.file_type().is_file());
236
237 let file_pattern = file_glob.and_then(|g| glob::Pattern::new(g).ok());
239
240 for entry in walker {
241 if match_count >= MAX_GREP_RESULTS {
242 break;
243 }
244
245 let path = entry.path();
246
247 if let Some(ref pat) = file_pattern {
249 let name = path.file_name().and_then(|n| n.to_str()).unwrap_or("");
250 if !pat.matches(name) {
251 continue;
252 }
253 }
254
255 if is_likely_binary(path) {
257 continue;
258 }
259
260 search_file(path, &matcher, &mut output, &mut match_count);
261 }
262
263 finish_grep(output, match_count)
264 }
265
266 fn required_capability(&self) -> Capability {
267 Capability::FileRead { allowed_paths: self.allowed_paths.clone() }
268 }
269}
270
/// Literal substring matcher with optional case folding.
struct PatternMatcher {
    /// The text to look for; stored lowercased when `case_insensitive` is set.
    pattern: String,
    /// Whether lines are lowercased before being compared.
    case_insensitive: bool,
}

impl PatternMatcher {
    /// Builds a matcher for `pattern`, lowercasing it up front when the search
    /// is case-insensitive so that each line only needs to be folded once.
    fn new(pattern: &str, case_insensitive: bool) -> Self {
        let mut needle = pattern.to_owned();
        if case_insensitive {
            needle = needle.to_lowercase();
        }
        PatternMatcher { pattern: needle, case_insensitive }
    }

    /// Returns `true` when `line` contains the pattern as a substring.
    fn is_match(&self, line: &str) -> bool {
        let needle = self.pattern.as_str();
        if !self.case_insensitive {
            return line.contains(needle);
        }
        line.to_lowercase().contains(needle)
    }
}
295
296fn search_file(
298 path: &std::path::Path,
299 matcher: &PatternMatcher,
300 output: &mut String,
301 match_count: &mut usize,
302) {
303 let Ok(content) = std::fs::read_to_string(path) else {
304 return;
305 };
306 for (line_num, line) in content.lines().enumerate() {
307 if *match_count >= MAX_GREP_RESULTS {
308 break;
309 }
310 if matcher.is_match(line) {
311 use std::fmt::Write;
312 let _ = writeln!(output, "{}:{}:{}", path.display(), line_num + 1, line);
313 *match_count += 1;
314 }
315 }
316}
317
/// Guesses whether `path` is a binary file by looking for a NUL byte in the
/// bytes returned by a single read of up to 512 bytes.
///
/// A file that cannot be opened or read is reported as binary.
fn is_likely_binary(path: &std::path::Path) -> bool {
    use std::io::Read;

    let mut head = [0u8; 512];
    let read = std::fs::File::open(path).and_then(|mut file| file.read(&mut head));
    match read {
        Ok(len) => head[..len].contains(&0),
        Err(_) => true,
    }
}
329
330fn finish_grep(mut output: String, match_count: usize) -> ToolResult {
332 if match_count == 0 {
333 return ToolResult::success("No matches found.");
334 }
335
336 if output.len() > MAX_GREP_BYTES {
337 output.truncate(MAX_GREP_BYTES);
338 output.push_str("\n\n[output truncated]");
339 }
340
341 if match_count >= MAX_GREP_RESULTS {
342 output.push_str(&format!("\n\n[truncated at {MAX_GREP_RESULTS} matches]"));
343 }
344
345 ToolResult::success(output)
346}
347
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Populates `dir` with a tiny Cargo-style project: two Rust sources
    /// under `src/` and a manifest at the top level.
    fn create_project(dir: &std::path::Path) {
        let src = dir.join("src");
        std::fs::create_dir_all(&src).unwrap();

        std::fs::write(src.join("main.rs"), b"fn main() {\n println!(\"hello\");\n}\n").unwrap();
        std::fs::write(src.join("lib.rs"), b"pub fn add(a: i32, b: i32) -> i32 {\n a + b\n}\n")
            .unwrap();
        std::fs::write(
            dir.join("Cargo.toml"),
            b"[package]\nname = \"test\"\nversion = \"0.1.0\"\n",
        )
        .unwrap();
    }

    /// Creates a fresh temporary directory holding the sample project.
    fn sample_project() -> TempDir {
        let dir = TempDir::new().unwrap();
        create_project(dir.path());
        dir
    }

    /// Allowed-path list that lifts every restriction.
    fn unrestricted() -> Vec<String> {
        vec!["*".into()]
    }

    // ---- GlobTool ----

    #[tokio::test]
    async fn test_glob_find_rust_files() {
        let project = sample_project();
        let input = serde_json::json!({
            "pattern": "**/*.rs",
            "path": project.path().to_str().unwrap()
        });

        let outcome = GlobTool::new(unrestricted()).execute(input).await;

        assert!(!outcome.is_error, "error: {}", outcome.content);
        assert!(outcome.content.contains("main.rs"));
        assert!(outcome.content.contains("lib.rs"));
        assert!(!outcome.content.contains("Cargo.toml"));
    }

    #[tokio::test]
    async fn test_glob_find_toml() {
        let project = sample_project();
        let input = serde_json::json!({
            "pattern": "*.toml",
            "path": project.path().to_str().unwrap()
        });

        let outcome = GlobTool::new(unrestricted()).execute(input).await;

        assert!(!outcome.is_error);
        assert!(outcome.content.contains("Cargo.toml"));
        assert!(!outcome.content.contains(".rs"));
    }

    #[tokio::test]
    async fn test_glob_no_matches() {
        let project = sample_project();
        let input = serde_json::json!({
            "pattern": "**/*.py",
            "path": project.path().to_str().unwrap()
        });

        let outcome = GlobTool::new(unrestricted()).execute(input).await;

        assert!(!outcome.is_error);
        assert!(outcome.content.contains("No files matching"));
    }

    #[tokio::test]
    async fn test_glob_invalid_pattern() {
        let input = serde_json::json!({"pattern": "[invalid"});

        let outcome = GlobTool::new(unrestricted()).execute(input).await;

        assert!(outcome.is_error);
        assert!(outcome.content.contains("invalid glob"));
    }

    #[tokio::test]
    async fn test_glob_missing_pattern() {
        let input = serde_json::json!({"path": "."});

        let outcome = GlobTool::new(unrestricted()).execute(input).await;

        assert!(outcome.is_error);
        assert!(outcome.content.contains("missing"));
    }

    #[test]
    fn test_glob_tool_metadata() {
        let glob_tool = GlobTool::new(vec!["/home".into()]);

        assert_eq!(glob_tool.name(), "glob");
        assert_eq!(glob_tool.definition().name, "glob");

        let capability = glob_tool.required_capability();
        match capability {
            Capability::FileRead { allowed_paths } => {
                assert_eq!(allowed_paths, vec!["/home".to_string()]);
            }
            other => panic!("expected FileRead, got: {other:?}"),
        }
    }

    // ---- GrepTool ----

    #[tokio::test]
    async fn test_grep_find_pattern() {
        let project = sample_project();
        let input = serde_json::json!({
            "pattern": "println",
            "path": project.path().to_str().unwrap()
        });

        let outcome = GrepTool::new(unrestricted()).execute(input).await;

        assert!(!outcome.is_error, "error: {}", outcome.content);
        assert!(outcome.content.contains("main.rs"));
        assert!(outcome.content.contains("println"));
    }

    #[tokio::test]
    async fn test_grep_with_file_glob() {
        let project = sample_project();
        let input = serde_json::json!({
            "pattern": "fn",
            "path": project.path().to_str().unwrap(),
            "glob": "*.rs"
        });

        let outcome = GrepTool::new(unrestricted()).execute(input).await;

        assert!(!outcome.is_error);
        assert!(outcome.content.contains("main.rs"));
        assert!(outcome.content.contains("lib.rs"));
        assert!(!outcome.content.contains("Cargo.toml"));
    }

    #[tokio::test]
    async fn test_grep_case_insensitive() {
        let project = sample_project();
        let input = serde_json::json!({
            "pattern": "PRINTLN",
            "path": project.path().to_str().unwrap(),
            "case_insensitive": true
        });

        let outcome = GrepTool::new(unrestricted()).execute(input).await;

        assert!(!outcome.is_error);
        assert!(outcome.content.contains("println"));
    }

    #[tokio::test]
    async fn test_grep_no_matches() {
        let project = sample_project();
        let input = serde_json::json!({
            "pattern": "ZZZZZ_NONEXISTENT",
            "path": project.path().to_str().unwrap()
        });

        let outcome = GrepTool::new(unrestricted()).execute(input).await;

        assert!(!outcome.is_error);
        assert!(outcome.content.contains("No matches"));
    }

    /// A pattern that would be an invalid regex is still accepted: it is
    /// searched for as plain text and simply finds nothing here.
    #[tokio::test]
    async fn test_grep_special_chars_in_pattern() {
        let project = sample_project();
        let input = serde_json::json!({
            "pattern": "[invalid",
            "path": project.path().to_str().unwrap()
        });

        let outcome = GrepTool::new(unrestricted()).execute(input).await;

        assert!(!outcome.is_error);
        assert!(outcome.content.contains("No matches"));
    }

    #[tokio::test]
    async fn test_grep_single_file() {
        let project = sample_project();
        let target = project.path().join("src/main.rs");
        let input = serde_json::json!({
            "pattern": "fn",
            "path": target.to_str().unwrap()
        });

        let outcome = GrepTool::new(unrestricted()).execute(input).await;

        assert!(!outcome.is_error);
        assert!(outcome.content.contains("fn main"));
    }

    #[tokio::test]
    async fn test_grep_nonexistent_path() {
        let input = serde_json::json!({
            "pattern": "test",
            "path": "/nonexistent_dir_xyz"
        });

        let outcome = GrepTool::new(unrestricted()).execute(input).await;

        assert!(outcome.is_error);
        assert!(outcome.content.contains("not found"));
    }

    #[tokio::test]
    async fn test_grep_missing_pattern() {
        let input = serde_json::json!({"path": "."});

        let outcome = GrepTool::new(unrestricted()).execute(input).await;

        assert!(outcome.is_error);
        assert!(outcome.content.contains("missing"));
    }

    #[test]
    fn test_grep_tool_metadata() {
        let grep_tool = GrepTool::new(vec!["/project".into()]);

        assert_eq!(grep_tool.name(), "grep");
        assert_eq!(grep_tool.definition().name, "grep");

        let capability = grep_tool.required_capability();
        match capability {
            Capability::FileRead { allowed_paths } => {
                assert_eq!(allowed_paths, vec!["/project".to_string()]);
            }
            other => panic!("expected FileRead, got: {other:?}"),
        }
    }

    // ---- is_likely_binary ----

    #[test]
    fn test_is_likely_binary_text() {
        let scratch = TempDir::new().unwrap();
        let file = scratch.path().join("text.txt");
        std::fs::write(&file, "hello world").unwrap();

        assert!(!is_likely_binary(&file));
    }

    #[test]
    fn test_is_likely_binary_binary() {
        let scratch = TempDir::new().unwrap();
        let file = scratch.path().join("binary.bin");
        std::fs::write(&file, &[0u8, 1, 2, 0, 3, 4]).unwrap();

        assert!(is_likely_binary(&file));
    }

    #[test]
    fn test_is_likely_binary_nonexistent() {
        let missing = std::path::Path::new("/no_such_file_xyz");

        assert!(is_likely_binary(missing));
    }
}