// pawan/tools/search.rs

//! Search tools (glob and grep)
2
3use super::Tool;
4use async_trait::async_trait;
5use serde_json::{json, Value};
6use std::path::PathBuf;
7
/// Tool for finding files by glob pattern.
///
/// Patterns are matched against file paths expressed relative to
/// `workspace_root`.
#[derive(Debug, Clone)]
pub struct GlobSearchTool {
    /// Directory that relative `path` arguments are joined onto and that
    /// reported paths are made relative to.
    workspace_root: PathBuf,
}

impl GlobSearchTool {
    /// Creates a glob search tool rooted at `workspace_root`.
    pub fn new(workspace_root: PathBuf) -> Self {
        Self { workspace_root }
    }
}
18
19#[async_trait]
20impl Tool for GlobSearchTool {
21    fn name(&self) -> &str {
22        "glob_search"
23    }
24
25    fn description(&self) -> &str {
26        "Find files matching a glob pattern. Respects .gitignore. \
27         Examples: '**/*.rs', 'src/**/*.toml', 'Cargo.*'"
28    }
29
30    fn parameters_schema(&self) -> Value {
31        json!({
32            "type": "object",
33            "properties": {
34                "pattern": {
35                    "type": "string",
36                    "description": "Glob pattern to match files"
37                },
38                "path": {
39                    "type": "string",
40                    "description": "Directory to search in (optional, defaults to workspace root)"
41                },
42                "max_results": {
43                    "type": "integer",
44                    "description": "Maximum number of results (default: 100)"
45                }
46            },
47            "required": ["pattern"]
48        })
49    }
50
51    fn thulp_definition(&self) -> thulp_core::ToolDefinition {
52        use thulp_core::{Parameter, ParameterType};
53        thulp_core::ToolDefinition::builder("glob_search")
54            .description(self.description())
55            .parameter(
56                Parameter::builder("pattern")
57                    .param_type(ParameterType::String)
58                    .required(true)
59                    .description("Glob pattern to match files")
60                    .build(),
61            )
62            .parameter(
63                Parameter::builder("path")
64                    .param_type(ParameterType::String)
65                    .required(false)
66                    .description("Directory to search in (optional, defaults to workspace root)")
67                    .build(),
68            )
69            .parameter(
70                Parameter::builder("max_results")
71                    .param_type(ParameterType::Integer)
72                    .required(false)
73                    .description("Maximum number of results (default: 100)")
74                    .build(),
75            )
76            .build()
77    }
78
79    async fn execute(&self, args: Value) -> crate::Result<Value> {
80        let pattern = args["pattern"]
81            .as_str()
82            .ok_or_else(|| crate::PawanError::Tool("pattern is required".into()))?;
83
84        let base_path = args["path"]
85            .as_str()
86            .map(|p| self.workspace_root.join(p))
87            .unwrap_or_else(|| self.workspace_root.clone());
88
89        let max_results = args["max_results"].as_u64().unwrap_or(100) as usize;
90
91        // Use ignore crate to respect .gitignore
92        let mut builder = ignore::WalkBuilder::new(&base_path);
93        builder.hidden(false); // Include hidden files if explicitly matched
94
95        let mut matches = Vec::new();
96        let glob_matcher = glob::Pattern::new(pattern)
97            .map_err(|e| crate::PawanError::Tool(format!("Invalid glob pattern: {}", e)))?;
98
99        for result in builder.build() {
100            if matches.len() >= max_results {
101                break;
102            }
103
104            if let Ok(entry) = result {
105                let path = entry.path();
106                if path.is_file() {
107                    let relative = path.strip_prefix(&self.workspace_root).unwrap_or(path);
108                    let relative_str = relative.to_string_lossy();
109
110                    if glob_matcher.matches(&relative_str) {
111                        let metadata = path.metadata().ok();
112                        let size = metadata.as_ref().map(|m| m.len()).unwrap_or(0);
113                        let modified = metadata.and_then(|m| m.modified().ok()).map(|t| {
114                            t.duration_since(std::time::UNIX_EPOCH)
115                                .map(|d| d.as_secs())
116                                .unwrap_or(0)
117                        });
118                        matches.push(json!({
119                            "path": relative_str,
120                            "size": size,
121                            "modified": modified
122                        }));
123                    }
124                }
125            }
126        }
127
128        // Sort by modification time (newest first)
129        matches.sort_by(|a, b| {
130            let a_mod = a["modified"].as_u64().unwrap_or(0);
131            let b_mod = b["modified"].as_u64().unwrap_or(0);
132            b_mod.cmp(&a_mod)
133        });
134
135        Ok(json!({
136            "pattern": pattern,
137            "matches": matches,
138            "count": matches.len(),
139            "truncated": matches.len() >= max_results
140        }))
141    }
142}
143
/// Tool for searching file contents with a regular expression.
///
/// Reported file paths are expressed relative to `workspace_root`.
#[derive(Debug, Clone)]
pub struct GrepSearchTool {
    /// Directory that relative `path` arguments are joined onto and that
    /// reported paths are made relative to.
    workspace_root: PathBuf,
}

impl GrepSearchTool {
    /// Creates a content search tool rooted at `workspace_root`.
    pub fn new(workspace_root: PathBuf) -> Self {
        Self { workspace_root }
    }
}
154
155#[async_trait]
156impl Tool for GrepSearchTool {
157    fn name(&self) -> &str {
158        "grep_search"
159    }
160
161    fn description(&self) -> &str {
162        "Search file contents for a pattern. Supports regex. \
163         Returns file paths and line numbers with matches."
164    }
165
166    fn parameters_schema(&self) -> Value {
167        json!({
168            "type": "object",
169            "properties": {
170                "pattern": {
171                    "type": "string",
172                    "description": "Pattern to search for (supports regex)"
173                },
174                "path": {
175                    "type": "string",
176                    "description": "Directory to search in (optional, defaults to workspace root)"
177                },
178                "include": {
179                    "type": "string",
180                    "description": "File pattern to include (e.g., '*.rs', '*.{ts,tsx}')"
181                },
182                "max_results": {
183                    "type": "integer",
184                    "description": "Maximum number of matching files (default: 50)"
185                },
186                "context_lines": {
187                    "type": "integer",
188                    "description": "Lines of context around matches (default: 0)"
189                }
190            },
191            "required": ["pattern"]
192        })
193    }
194
195    fn thulp_definition(&self) -> thulp_core::ToolDefinition {
196        use thulp_core::{Parameter, ParameterType};
197        thulp_core::ToolDefinition::builder("grep_search")
198            .description(self.description())
199            .parameter(
200                Parameter::builder("pattern")
201                    .param_type(ParameterType::String)
202                    .required(true)
203                    .description("Pattern to search for (supports regex)")
204                    .build(),
205            )
206            .parameter(
207                Parameter::builder("path")
208                    .param_type(ParameterType::String)
209                    .required(false)
210                    .description("Directory to search in (optional, defaults to workspace root)")
211                    .build(),
212            )
213            .parameter(
214                Parameter::builder("include")
215                    .param_type(ParameterType::String)
216                    .required(false)
217                    .description("File pattern to include (e.g., '*.rs', '*.{ts,tsx}')")
218                    .build(),
219            )
220            .parameter(
221                Parameter::builder("max_results")
222                    .param_type(ParameterType::Integer)
223                    .required(false)
224                    .description("Maximum number of matching files (default: 50)")
225                    .build(),
226            )
227            .parameter(
228                Parameter::builder("context_lines")
229                    .param_type(ParameterType::Integer)
230                    .required(false)
231                    .description("Lines of context around matches (default: 0)")
232                    .build(),
233            )
234            .build()
235    }
236
237    async fn execute(&self, args: Value) -> crate::Result<Value> {
238        let pattern = args["pattern"]
239            .as_str()
240            .ok_or_else(|| crate::PawanError::Tool("pattern is required".into()))?;
241
242        let base_path = args["path"]
243            .as_str()
244            .map(|p| self.workspace_root.join(p))
245            .unwrap_or_else(|| self.workspace_root.clone());
246
247        let include = args["include"].as_str();
248        let max_results = args["max_results"].as_u64().unwrap_or(50) as usize;
249        let context_lines = args["context_lines"].as_u64().unwrap_or(0) as usize;
250
251        // Build regex
252        let regex = regex::Regex::new(pattern)
253            .map_err(|e| crate::PawanError::Tool(format!("Invalid regex: {}", e)))?;
254
255        // Build glob matcher for include filter
256        let include_matcher = include
257            .map(glob::Pattern::new)
258            .transpose()
259            .map_err(|e| crate::PawanError::Tool(format!("Invalid include pattern: {}", e)))?;
260
261        let mut file_matches = Vec::new();
262
263        // Walk directory
264        let mut builder = ignore::WalkBuilder::new(&base_path);
265        builder.hidden(false);
266
267        for result in builder.build() {
268            if file_matches.len() >= max_results {
269                break;
270            }
271
272            if let Ok(entry) = result {
273                let path = entry.path();
274                if !path.is_file() {
275                    continue;
276                }
277
278                let relative = path.strip_prefix(&self.workspace_root).unwrap_or(path);
279                let relative_str = relative.to_string_lossy();
280
281                // Check include filter
282                if let Some(ref matcher) = include_matcher {
283                    // Match against filename only
284                    let filename = path
285                        .file_name()
286                        .map(|n| n.to_string_lossy())
287                        .unwrap_or_default();
288                    if !matcher.matches(&filename) && !matcher.matches(&relative_str) {
289                        continue;
290                    }
291                }
292
293                // Read and search file
294                if let Ok(content) = std::fs::read_to_string(path) {
295                    let mut line_matches = Vec::new();
296                    let lines: Vec<&str> = content.lines().collect();
297
298                    for (line_num, line) in lines.iter().enumerate() {
299                        if regex.is_match(line) {
300                            let mut match_info = json!({
301                                "line": line_num + 1,
302                                "content": line.chars().take(200).collect::<String>()
303                            });
304
305                            // Add context if requested
306                            if context_lines > 0 {
307                                let start = line_num.saturating_sub(context_lines);
308                                let end = (line_num + context_lines + 1).min(lines.len());
309                                let context: Vec<String> = lines[start..end]
310                                    .iter()
311                                    .enumerate()
312                                    .map(|(i, l)| format!("{}: {}", start + i + 1, l))
313                                    .collect();
314                                match_info["context"] = json!(context);
315                            }
316
317                            line_matches.push(match_info);
318                        }
319                    }
320
321                    if !line_matches.is_empty() {
322                        file_matches.push(json!({
323                            "path": relative_str,
324                            "matches": line_matches,
325                            "match_count": line_matches.len()
326                        }));
327                    }
328                }
329            }
330        }
331
332        // Sort by match count (most matches first)
333        file_matches.sort_by(|a, b| {
334            let a_count = a["match_count"].as_u64().unwrap_or(0);
335            let b_count = b["match_count"].as_u64().unwrap_or(0);
336            b_count.cmp(&a_count)
337        });
338
339        let total_matches: u64 = file_matches
340            .iter()
341            .map(|f| f["match_count"].as_u64().unwrap_or(0))
342            .sum();
343
344        Ok(json!({
345            "pattern": pattern,
346            "files": file_matches,
347            "file_count": file_matches.len(),
348            "total_matches": total_matches,
349            "truncated": file_matches.len() >= max_results
350        }))
351    }
352}
353
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Creates a fresh, empty temporary workspace.
    fn workspace() -> TempDir {
        TempDir::new().unwrap()
    }

    /// Writes `body` to `name` (relative to the workspace root).
    fn put(ws: &TempDir, name: &str, body: &str) {
        std::fs::write(ws.path().join(name), body).unwrap();
    }

    /// Builds a glob tool rooted at the workspace.
    fn glob_tool(ws: &TempDir) -> GlobSearchTool {
        GlobSearchTool::new(ws.path().to_path_buf())
    }

    /// Builds a grep tool rooted at the workspace.
    fn grep_tool(ws: &TempDir) -> GrepSearchTool {
        GrepSearchTool::new(ws.path().to_path_buf())
    }

    #[tokio::test]
    async fn test_glob_search() {
        let ws = workspace();
        put(&ws, "file1.rs", "rust code");
        put(&ws, "file2.rs", "more rust");
        put(&ws, "file3.txt", "text file");

        let output = glob_tool(&ws)
            .execute(json!({"pattern": "*.rs"}))
            .await
            .unwrap();

        assert_eq!(output["count"], 2);
    }

    #[tokio::test]
    async fn test_grep_search() {
        let ws = workspace();
        put(&ws, "test.rs", "fn main() {\n    println!(\"hello\");\n}");

        let args = json!({
            "pattern": "println",
            "include": "*.rs"
        });
        let output = grep_tool(&ws).execute(args).await.unwrap();

        assert_eq!(output["file_count"], 1);
        assert_eq!(output["total_matches"], 1);
    }

    // --- GlobSearchTool expanded tests ---

    #[tokio::test]
    async fn test_glob_no_matches() {
        let ws = workspace();
        put(&ws, "file.txt", "text");

        let output = glob_tool(&ws)
            .execute(json!({"pattern": "*.rs"}))
            .await
            .unwrap();

        assert_eq!(output["count"], 0);
        assert_eq!(output["truncated"], false);
    }

    #[tokio::test]
    async fn test_glob_invalid_pattern() {
        let ws = workspace();
        let outcome = glob_tool(&ws).execute(json!({"pattern": "[invalid"})).await;
        assert!(outcome.is_err(), "Invalid glob should error");
    }

    #[tokio::test]
    async fn test_glob_missing_pattern() {
        let ws = workspace();
        let outcome = glob_tool(&ws).execute(json!({})).await;
        assert!(outcome.is_err(), "Missing pattern should error");
    }

    #[tokio::test]
    async fn test_glob_max_results() {
        let ws = workspace();
        (0..10).for_each(|i| put(&ws, &format!("f{}.rs", i), "code"));

        let output = glob_tool(&ws)
            .execute(json!({"pattern": "*.rs", "max_results": 3}))
            .await
            .unwrap();

        assert_eq!(output["count"], 3);
        assert_eq!(output["truncated"], true);
    }

    #[tokio::test]
    async fn test_glob_subdirectory() {
        let ws = workspace();
        std::fs::create_dir(ws.path().join("sub")).unwrap();
        put(&ws, "sub/a.rs", "code");
        put(&ws, "b.rs", "code");

        // Restrict the walk to sub/ so only the nested file is seen.
        let output = glob_tool(&ws)
            .execute(json!({"pattern": "*.rs", "path": "sub"}))
            .await
            .unwrap();

        assert_eq!(output["count"], 1);
    }

    #[tokio::test]
    async fn test_glob_result_has_metadata() {
        let ws = workspace();
        put(&ws, "f.rs", "hello world");

        let output = glob_tool(&ws)
            .execute(json!({"pattern": "*.rs"}))
            .await
            .unwrap();

        let entry = &output["matches"][0];
        assert!(entry["path"].as_str().is_some());
        assert!(entry["size"].as_u64().unwrap() > 0);
        assert!(entry["modified"].as_u64().is_some());
    }

    // --- GrepSearchTool expanded tests ---

    #[tokio::test]
    async fn test_grep_no_matches() {
        let ws = workspace();
        put(&ws, "f.rs", "fn main() {}");

        let output = grep_tool(&ws)
            .execute(json!({"pattern": "nonexistent_string_xyz"}))
            .await
            .unwrap();

        assert_eq!(output["file_count"], 0);
        assert_eq!(output["total_matches"], 0);
    }

    #[tokio::test]
    async fn test_grep_regex() {
        let ws = workspace();
        put(&ws, "f.rs", "fn foo() {}\nfn bar() {}\nfn baz() {}");

        let output = grep_tool(&ws)
            .execute(json!({"pattern": "fn \\w+\\(\\)"}))
            .await
            .unwrap();

        assert_eq!(output["total_matches"], 3);
    }

    #[tokio::test]
    async fn test_grep_invalid_regex() {
        let ws = workspace();
        let outcome = grep_tool(&ws).execute(json!({"pattern": "[invalid"})).await;
        assert!(outcome.is_err(), "Invalid regex should error");
    }

    #[tokio::test]
    async fn test_grep_missing_pattern() {
        let ws = workspace();
        let outcome = grep_tool(&ws).execute(json!({})).await;
        assert!(outcome.is_err(), "Missing pattern should error");
    }

    #[tokio::test]
    async fn test_grep_include_filter() {
        let ws = workspace();
        put(&ws, "a.rs", "hello");
        put(&ws, "b.txt", "hello");

        let output = grep_tool(&ws)
            .execute(json!({"pattern": "hello", "include": "*.rs"}))
            .await
            .unwrap();

        assert_eq!(output["file_count"], 1);
        let hit_path = output["files"][0]["path"].as_str().unwrap();
        assert!(hit_path.ends_with(".rs"));
    }

    #[tokio::test]
    async fn test_grep_context_lines() {
        let ws = workspace();
        put(&ws, "f.rs", "line1\nline2\nTARGET\nline4\nline5");

        let output = grep_tool(&ws)
            .execute(json!({"pattern": "TARGET", "context_lines": 1}))
            .await
            .unwrap();

        let hits = output["files"][0]["matches"].as_array().unwrap();
        assert!(hits[0]["context"].is_array());
        let surrounding = hits[0]["context"].as_array().unwrap();
        assert_eq!(surrounding.len(), 3); // 1 before + match + 1 after
    }

    #[tokio::test]
    async fn test_grep_max_results() {
        let ws = workspace();
        (0..10).for_each(|i| put(&ws, &format!("f{}.rs", i), "match_me"));

        let output = grep_tool(&ws)
            .execute(json!({"pattern": "match_me", "max_results": 3}))
            .await
            .unwrap();

        assert_eq!(output["file_count"], 3);
        assert_eq!(output["truncated"], true);
    }

    #[tokio::test]
    async fn test_grep_multiple_matches_in_file() {
        let ws = workspace();
        put(&ws, "f.rs", "foo\nbar\nfoo\nbaz\nfoo");

        let output = grep_tool(&ws)
            .execute(json!({"pattern": "foo"}))
            .await
            .unwrap();

        assert_eq!(output["files"][0]["match_count"], 3);
        assert_eq!(output["total_matches"], 3);
    }

    #[tokio::test]
    async fn test_grep_line_truncation() {
        let ws = workspace();
        let long_line = "x".repeat(500);
        put(&ws, "f.rs", &long_line);

        let output = grep_tool(&ws)
            .execute(json!({"pattern": "x+"}))
            .await
            .unwrap();

        let reported = output["files"][0]["matches"][0]["content"]
            .as_str()
            .unwrap();
        assert_eq!(
            reported.len(),
            200,
            "Line content should be truncated to 200 chars"
        );
    }

    #[tokio::test]
    async fn test_grep_sorted_by_match_count() {
        let ws = workspace();
        put(&ws, "few.rs", "x");
        put(&ws, "many.rs", "x\nx\nx\nx\nx");

        let output = grep_tool(&ws)
            .execute(json!({"pattern": "x"}))
            .await
            .unwrap();

        let files = output["files"].as_array().unwrap();
        assert!(files.len() == 2);
        // The file with more matches must come first.
        let leading = files[0]["match_count"].as_u64().unwrap();
        let trailing = files[1]["match_count"].as_u64().unwrap();
        assert!(
            leading >= trailing,
            "Results should be sorted by match count desc"
        );
    }
}