Skip to main content

soul_coder/tools/
find.rs

1//! Find tool — search for files by name/glob pattern.
2//!
3//! Uses VirtualFs for WASM compatibility. Recursively walks directories
4//! and matches filenames against glob patterns.
5
6use std::sync::Arc;
7
8use async_trait::async_trait;
9use serde_json::json;
10use tokio::sync::mpsc;
11
12use soul_core::error::SoulResult;
13use soul_core::tool::{Tool, ToolOutput};
14use soul_core::types::ToolDefinition;
15use soul_core::vfs::VirtualFs;
16
17use crate::truncate::{truncate_head, MAX_BYTES};
18
/// Upper bound on the number of paths a single search returns; also the
/// value used when the caller does not supply a `limit`.
const MAX_RESULTS: usize = 1_000;
21
22use super::resolve_path;
23
24pub struct FindTool {
25    fs: Arc<dyn VirtualFs>,
26    cwd: String,
27}
28
29impl FindTool {
30    pub fn new(fs: Arc<dyn VirtualFs>, cwd: impl Into<String>) -> Self {
31        Self {
32            fs,
33            cwd: cwd.into(),
34        }
35    }
36}
37
/// Decide whether a file matches a glob pattern.
///
/// `name` is the file's final path component and `full_path` its complete
/// path. Whitespace around `pattern` is ignored.
///
/// * A pattern without `/` is compared against `name` only.
/// * A pattern containing `/` is compared component by component against the
///   trailing components of `full_path` (see [`path_matches_glob`]).
///
/// Inside a component `*` matches any run of characters (including none) and
/// `?` matches exactly one character. A component that is exactly `**`
/// matches zero or more whole path components, so `src/**/*.ts` matches both
/// `src/a.ts` and `src/x/y/a.ts`.
fn matches_glob(name: &str, full_path: &str, pattern: &str) -> bool {
    let pattern = pattern.trim();

    if pattern.contains('/') {
        path_matches_glob(full_path, pattern)
    } else {
        matches_simple_glob(name, pattern)
    }
}

/// Wildcard matcher shared by the character-level and component-level globs.
///
/// Returns `true` when `pattern` matches all of `text`. Pattern items for
/// which `is_star` holds match any number of text items (including zero);
/// every other pattern item must match exactly one text item according to
/// `matches_one`.
///
/// Only the most recent star ever needs to be revisited, so a single
/// backtrack point is enough and no recursion is required.
fn wildcard_match<P, T>(
    pattern: &[P],
    text: &[T],
    is_star: impl Fn(&P) -> bool,
    matches_one: impl Fn(&P, &T) -> bool,
) -> bool {
    let mut p = 0;
    let mut t = 0;
    // (pattern index just past the last star, text index that star currently
    // extends up to). `None` until a star has been seen.
    let mut resume: Option<(usize, usize)> = None;

    while t < text.len() {
        if p < pattern.len() && is_star(&pattern[p]) {
            // Start by letting the star match nothing.
            resume = Some((p + 1, t));
            p += 1;
        } else if p < pattern.len() && matches_one(&pattern[p], &text[t]) {
            p += 1;
            t += 1;
        } else if let Some((after_star, covered)) = resume {
            // Mismatch: make the last star absorb one more item and retry.
            p = after_star;
            t = covered + 1;
            resume = Some((after_star, covered + 1));
        } else {
            return false;
        }
    }

    // Text is exhausted; whatever is left of the pattern must be able to
    // match nothing.
    pattern[p..].iter().all(|item| is_star(item))
}

/// Match a single path component (or bare file name) against a pattern in
/// which `*` matches any run of characters and `?` matches one character.
///
/// Matching is case-sensitive and works on Unicode scalar values, so `?`
/// consumes one character rather than one byte.
fn matches_simple_glob(name: &str, pattern: &str) -> bool {
    let name_chars: Vec<char> = name.chars().collect();
    let pattern_chars: Vec<char> = pattern.chars().collect();

    wildcard_match(
        &pattern_chars,
        &name_chars,
        |symbol| *symbol == '*',
        |symbol, actual| *symbol == '?' || symbol == actual,
    )
}

/// Match a `/`-separated pattern against the trailing components of `path`.
///
/// Empty components (leading, trailing or doubled slashes) are ignored on
/// both sides. The pattern is not anchored at the start of `path`: it only
/// has to match the path's last components, so `src/*.rs` matches
/// `/project/src/main.rs`. A `**` component matches zero or more components.
/// A pattern with no non-empty components matches every path.
fn path_matches_glob(path: &str, pattern: &str) -> bool {
    let path_parts: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();

    // An implicit leading `**` expresses "match at any depth".
    let mut pattern_parts: Vec<&str> = vec!["**"];
    pattern_parts.extend(pattern.split('/').filter(|s| !s.is_empty()));

    if pattern_parts.len() == 1 {
        return true;
    }

    wildcard_match(
        &pattern_parts,
        &path_parts,
        |part| *part == "**",
        |part, segment| matches_simple_glob(*segment, *part),
    )
}
110
111/// Recursively collect matching files.
112async fn find_files(
113    fs: &dyn VirtualFs,
114    dir: &str,
115    pattern: &str,
116    results: &mut Vec<String>,
117    limit: usize,
118) -> SoulResult<()> {
119    if results.len() >= limit {
120        return Ok(());
121    }
122
123    let entries = match fs.read_dir(dir).await {
124        Ok(e) => e,
125        Err(_) => return Ok(()), // Skip unreadable dirs
126    };
127
128    for entry in entries {
129        if results.len() >= limit {
130            break;
131        }
132
133        let path = if dir == "/" || dir.is_empty() {
134            format!("/{}", entry.name)
135        } else {
136            format!("{}/{}", dir.trim_end_matches('/'), entry.name)
137        };
138
139        if entry.is_dir {
140            if !entry.name.starts_with('.') {
141                Box::pin(find_files(fs, &path, pattern, results, limit)).await?;
142            }
143        } else if entry.is_file && matches_glob(&entry.name, &path, pattern) {
144            results.push(path);
145        }
146    }
147
148    Ok(())
149}
150
151#[async_trait]
152impl Tool for FindTool {
153    fn name(&self) -> &str {
154        "find"
155    }
156
157    fn definition(&self) -> ToolDefinition {
158        ToolDefinition {
159            name: "find".into(),
160            description: "Find files matching a glob pattern. Returns matching file paths.".into(),
161            input_schema: json!({
162                "type": "object",
163                "properties": {
164                    "pattern": {
165                        "type": "string",
166                        "description": "Glob pattern to match files (e.g., '*.rs', 'src/**/*.ts', 'Cargo.toml')"
167                    },
168                    "path": {
169                        "type": "string",
170                        "description": "Directory to search in (defaults to working directory)"
171                    },
172                    "limit": {
173                        "type": "integer",
174                        "description": "Maximum number of results (default: 1000)"
175                    }
176                },
177                "required": ["pattern"]
178            }),
179        }
180    }
181
182    async fn execute(
183        &self,
184        _call_id: &str,
185        arguments: serde_json::Value,
186        _partial_tx: Option<mpsc::UnboundedSender<String>>,
187    ) -> SoulResult<ToolOutput> {
188        let pattern = arguments
189            .get("pattern")
190            .and_then(|v| v.as_str())
191            .unwrap_or("");
192
193        if pattern.is_empty() {
194            return Ok(ToolOutput::error("Missing required parameter: pattern"));
195        }
196
197        let search_path = arguments
198            .get("path")
199            .and_then(|v| v.as_str())
200            .map(|p| resolve_path(&self.cwd, p))
201            .unwrap_or_else(|| self.cwd.clone());
202
203        let limit = arguments
204            .get("limit")
205            .and_then(|v| v.as_u64())
206            .map(|v| (v as usize).min(MAX_RESULTS))
207            .unwrap_or(MAX_RESULTS);
208
209        let mut results = Vec::new();
210        if let Err(e) =
211            find_files(self.fs.as_ref(), &search_path, pattern, &mut results, limit).await
212        {
213            return Ok(ToolOutput::error(format!(
214                "Failed to search {}: {}",
215                search_path, e
216            )));
217        }
218
219        results.sort();
220
221        if results.is_empty() {
222            return Ok(ToolOutput::success(format!(
223                "No files matching '{}' found",
224                pattern
225            ))
226            .with_metadata(json!({"count": 0})));
227        }
228
229        // Make paths relative to cwd
230        let cwd_prefix = format!("{}/", self.cwd.trim_end_matches('/'));
231        let relative: Vec<String> = results
232            .iter()
233            .map(|p| {
234                if p.starts_with(&cwd_prefix) {
235                    p[cwd_prefix.len()..].to_string()
236                } else {
237                    p.clone()
238                }
239            })
240            .collect();
241
242        let output = relative.join("\n");
243        let truncated = truncate_head(&output, results.len(), MAX_BYTES);
244
245        let notice = truncated.truncation_notice();
246        let mut result = truncated.content;
247        if results.len() >= limit {
248            result.push_str(&format!("\n[Reached limit: {} results]", limit));
249        }
250        if let Some(notice) = notice {
251            result.push_str(&format!("\n{}", notice));
252        }
253
254        Ok(ToolOutput::success(result).with_metadata(json!({
255            "count": results.len(),
256            "limit_reached": results.len() >= limit,
257        })))
258    }
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    use soul_core::vfs::MemoryFs;

    /// Path and contents of every file in the sample project, in the order
    /// they are written.
    const SAMPLE_FILES: &[(&str, &str)] = &[
        ("/project/src/main.rs", "fn main() {}"),
        ("/project/src/lib.rs", "pub mod foo;"),
        ("/project/src/utils.ts", "export {}"),
        ("/project/Cargo.toml", "[package]"),
        ("/project/README.md", "# readme"),
    ];

    /// A `FindTool` with working directory `/project` over an empty
    /// in-memory filesystem, plus a handle to that filesystem.
    fn empty_project() -> (Arc<MemoryFs>, FindTool) {
        let fs = Arc::new(MemoryFs::new());
        let as_vfs: Arc<dyn VirtualFs> = fs.clone();
        (fs, FindTool::new(as_vfs, "/project"))
    }

    /// A `FindTool` whose filesystem already contains `SAMPLE_FILES`.
    async fn sample_project() -> FindTool {
        let (fs, tool) = empty_project();
        for &(path, contents) in SAMPLE_FILES {
            fs.write(path, contents).await.unwrap();
        }
        tool
    }

    /// Execute the tool without a partial-output channel and unwrap the result.
    async fn run(tool: &FindTool, call_id: &str, args: serde_json::Value) -> ToolOutput {
        tool.execute(call_id, args, None).await.unwrap()
    }

    #[tokio::test]
    async fn find_by_extension() {
        let tool = sample_project().await;
        let output = run(&tool, "c1", json!({"pattern": "*.rs"})).await;

        assert!(!output.is_error);
        assert!(output.content.contains("main.rs"));
        assert!(output.content.contains("lib.rs"));
        assert!(!output.content.contains("utils.ts"));
    }

    #[tokio::test]
    async fn find_exact_name() {
        let tool = sample_project().await;
        let output = run(&tool, "c2", json!({"pattern": "Cargo.toml"})).await;

        assert!(!output.is_error);
        assert!(output.content.contains("Cargo.toml"));
        assert_eq!(output.metadata["count"].as_u64().unwrap(), 1);
    }

    #[tokio::test]
    async fn find_no_results() {
        let tool = sample_project().await;
        let output = run(&tool, "c3", json!({"pattern": "*.py"})).await;

        assert!(!output.is_error);
        assert!(output.content.contains("No files"));
    }

    #[tokio::test]
    async fn find_with_limit() {
        let tool = sample_project().await;
        let output = run(&tool, "c4", json!({"pattern": "*", "limit": 2})).await;

        assert!(!output.is_error);
        assert_eq!(output.metadata["count"].as_u64().unwrap(), 2);
    }

    #[tokio::test]
    async fn find_empty_pattern() {
        let (_fs, tool) = empty_project();
        let output = run(&tool, "c5", json!({"pattern": ""})).await;

        assert!(output.is_error);
    }

    #[test]
    fn glob_extensions() {
        assert!(matches_glob("file.rs", "/src/file.rs", "*.rs"));
        assert!(!matches_glob("file.ts", "/src/file.ts", "*.rs"));
    }

    #[test]
    fn glob_prefix() {
        assert!(matches_glob("Cargo.toml", "/Cargo.toml", "Cargo*"));
        assert!(!matches_glob("package.json", "/package.json", "Cargo*"));
    }

    #[test]
    fn glob_exact() {
        // Matching is case-sensitive.
        assert!(matches_glob("Makefile", "/Makefile", "Makefile"));
        assert!(!matches_glob("makefile", "/makefile", "Makefile"));
    }

    #[tokio::test]
    async fn tool_name_and_definition() {
        let (_fs, tool) = empty_project();

        assert_eq!(tool.name(), "find");
        assert_eq!(tool.definition().name, "find");
    }
}