1use std::sync::Arc;
7
8use async_trait::async_trait;
9use serde_json::json;
10use tokio::sync::mpsc;
11
12use soul_core::error::SoulResult;
13use soul_core::tool::{Tool, ToolOutput};
14use soul_core::types::ToolDefinition;
15use soul_core::vfs::VirtualFs;
16
17use crate::truncate::{truncate_head, MAX_BYTES};
18
/// Upper bound on how many paths one `find` call may return; also the limit used when the
/// caller does not pass `limit`.
const MAX_RESULTS: usize = 1_000;
21
22use super::resolve_path;
23
24pub struct FindTool {
25 fs: Arc<dyn VirtualFs>,
26 cwd: String,
27}
28
29impl FindTool {
30 pub fn new(fs: Arc<dyn VirtualFs>, cwd: impl Into<String>) -> Self {
31 Self {
32 fs,
33 cwd: cwd.into(),
34 }
35 }
36}
37
/// Returns `true` if a file matches the glob `pattern`. Surrounding whitespace in the
/// pattern is ignored.
///
/// * A pattern without `/` (e.g. `*.rs`, `Cargo.toml`) is matched against the file's base
///   `name` only, so it finds files at any depth.
/// * A pattern with `/` (e.g. `src/*.rs`, `src/**/*.ts`, `**/*.rs`) is matched segment by
///   segment against the trailing segments of `full_path`. A `**` segment matches zero or
///   more whole segments (see [`path_matches_glob`]).
///
/// Within one segment, `*` matches any run of characters and `?` matches exactly one
/// character (see [`matches_simple_glob`]).
fn matches_glob(name: &str, full_path: &str, pattern: &str) -> bool {
    let pattern = pattern.trim();
    if pattern.contains('/') {
        path_matches_glob(full_path, pattern)
    } else {
        matches_simple_glob(name, pattern)
    }
}

/// Matches a single path segment `name` against a glob `pattern`.
///
/// `*` matches any (possibly empty) run of characters and `?` matches exactly one character.
/// Every other character must match literally, and case matters. Matching works on `char`s,
/// so `?` consumes one Unicode scalar value, not one byte.
fn matches_simple_glob(name: &str, pattern: &str) -> bool {
    let name: Vec<char> = name.chars().collect();
    let pattern: Vec<char> = pattern.chars().collect();

    let (mut n, mut p) = (0, 0);
    // Tracks the most recent `*`: its index in `pattern`, and the `name` index it resumes from.
    // On a mismatch, that `*` absorbs one more character and matching retries from there.
    // This is iterative (no recursion); the worst case is O(len(name) * len(pattern)).
    let mut star: Option<(usize, usize)> = None;

    while n < name.len() {
        match pattern.get(p).copied() {
            Some('*') => {
                star = Some((p, n));
                p += 1;
            }
            Some(c) if c == '?' || c == name[n] => {
                p += 1;
                n += 1;
            }
            _ => match star {
                Some((star_p, star_n)) => {
                    star = Some((star_p, star_n + 1));
                    p = star_p + 1;
                    n = star_n + 1;
                }
                None => return false,
            },
        }
    }

    // The name is used up, so only trailing `*`s may remain in the pattern.
    pattern[p..].iter().all(|&c| c == '*')
}

/// Returns `true` if `pattern` matches some trailing run of `path`'s `/`-separated segments.
///
/// The match is anchored at the end of `path` but not at the start. For example, `src/*.rs`
/// matches both `/project/src/main.rs` and `/project/crates/x/src/main.rs`. Empty segments
/// (from leading, trailing or doubled `/`) are ignored, so a pattern made only of slashes
/// matches every path.
fn path_matches_glob(path: &str, pattern: &str) -> bool {
    let path_parts: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();

    // An unanchored start is the same as an implicit leading `**`. Collapsing runs of `**`
    // keeps the backtracking in `segments_match` cheap.
    let mut pattern_parts: Vec<&str> = std::iter::once("**")
        .chain(pattern.split('/').filter(|s| !s.is_empty()))
        .collect();
    pattern_parts.dedup_by(|a, b| *a == "**" && *b == "**");

    segments_match(&path_parts, &pattern_parts)
}

/// Anchored segment-wise match in which every segment of `path` must be consumed.
///
/// A `**` pattern segment matches zero or more path segments. Any other segment must match
/// exactly one path segment via [`matches_simple_glob`].
fn segments_match(path: &[&str], pattern: &[&str]) -> bool {
    match pattern.split_first() {
        None => path.is_empty(),
        Some((&"**", rest)) => (0..=path.len()).any(|skip| segments_match(&path[skip..], rest)),
        Some((segment, rest)) => match path.split_first() {
            Some((head, tail)) => matches_simple_glob(head, segment) && segments_match(tail, rest),
            None => false,
        },
    }
}
110
111async fn find_files(
113 fs: &dyn VirtualFs,
114 dir: &str,
115 pattern: &str,
116 results: &mut Vec<String>,
117 limit: usize,
118) -> SoulResult<()> {
119 if results.len() >= limit {
120 return Ok(());
121 }
122
123 let entries = match fs.read_dir(dir).await {
124 Ok(e) => e,
125 Err(_) => return Ok(()), };
127
128 for entry in entries {
129 if results.len() >= limit {
130 break;
131 }
132
133 let path = if dir == "/" || dir.is_empty() {
134 format!("/{}", entry.name)
135 } else {
136 format!("{}/{}", dir.trim_end_matches('/'), entry.name)
137 };
138
139 if entry.is_dir {
140 if !entry.name.starts_with('.') {
141 Box::pin(find_files(fs, &path, pattern, results, limit)).await?;
142 }
143 } else if entry.is_file && matches_glob(&entry.name, &path, pattern) {
144 results.push(path);
145 }
146 }
147
148 Ok(())
149}
150
151#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
152#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
153impl Tool for FindTool {
154 fn name(&self) -> &str {
155 "find"
156 }
157
158 fn definition(&self) -> ToolDefinition {
159 ToolDefinition {
160 name: "find".into(),
161 description: "Find files matching a glob pattern. Returns matching file paths.".into(),
162 input_schema: json!({
163 "type": "object",
164 "properties": {
165 "pattern": {
166 "type": "string",
167 "description": "Glob pattern to match files (e.g., '*.rs', 'src/**/*.ts', 'Cargo.toml')"
168 },
169 "path": {
170 "type": "string",
171 "description": "Directory to search in (defaults to working directory)"
172 },
173 "limit": {
174 "type": "integer",
175 "description": "Maximum number of results (default: 1000)"
176 }
177 },
178 "required": ["pattern"]
179 }),
180 }
181 }
182
183 async fn execute(
184 &self,
185 _call_id: &str,
186 arguments: serde_json::Value,
187 _partial_tx: Option<mpsc::UnboundedSender<String>>,
188 ) -> SoulResult<ToolOutput> {
189 let pattern = arguments
190 .get("pattern")
191 .and_then(|v| v.as_str())
192 .unwrap_or("");
193
194 if pattern.is_empty() {
195 return Ok(ToolOutput::error("Missing required parameter: pattern"));
196 }
197
198 let search_path = arguments
199 .get("path")
200 .and_then(|v| v.as_str())
201 .map(|p| resolve_path(&self.cwd, p))
202 .unwrap_or_else(|| self.cwd.clone());
203
204 let limit = arguments
205 .get("limit")
206 .and_then(|v| v.as_u64())
207 .map(|v| (v as usize).min(MAX_RESULTS))
208 .unwrap_or(MAX_RESULTS);
209
210 let mut results = Vec::new();
211 if let Err(e) =
212 find_files(self.fs.as_ref(), &search_path, pattern, &mut results, limit).await
213 {
214 return Ok(ToolOutput::error(format!(
215 "Failed to search {}: {}",
216 search_path, e
217 )));
218 }
219
220 results.sort();
221
222 if results.is_empty() {
223 return Ok(ToolOutput::success(format!(
224 "No files matching '{}' found",
225 pattern
226 ))
227 .with_metadata(json!({"count": 0})));
228 }
229
230 let cwd_prefix = format!("{}/", self.cwd.trim_end_matches('/'));
232 let relative: Vec<String> = results
233 .iter()
234 .map(|p| {
235 if p.starts_with(&cwd_prefix) {
236 p[cwd_prefix.len()..].to_string()
237 } else {
238 p.clone()
239 }
240 })
241 .collect();
242
243 let output = relative.join("\n");
244 let truncated = truncate_head(&output, results.len(), MAX_BYTES);
245
246 let notice = truncated.truncation_notice();
247 let mut result = truncated.content;
248 if results.len() >= limit {
249 result.push_str(&format!("\n[Reached limit: {} results]", limit));
250 }
251 if let Some(notice) = notice {
252 result.push_str(&format!("\n{}", notice));
253 }
254
255 Ok(ToolOutput::success(result).with_metadata(json!({
256 "count": results.len(),
257 "limit_reached": results.len() >= limit,
258 })))
259 }
260}
261
#[cfg(test)]
mod tests {
    use super::*;
    use soul_core::vfs::MemoryFs;

    /// Files created by [`populate`], as `(absolute path, contents)` in write order.
    const FIXTURES: [(&str, &str); 5] = [
        ("/project/src/main.rs", "fn main() {}"),
        ("/project/src/lib.rs", "pub mod foo;"),
        ("/project/src/utils.ts", "export {}"),
        ("/project/Cargo.toml", "[package]"),
        ("/project/README.md", "# readme"),
    ];

    /// Returns an empty in-memory filesystem and a `FindTool` whose working directory is `/project`.
    fn setup() -> (Arc<MemoryFs>, FindTool) {
        let fs = Arc::new(MemoryFs::new());
        let shared: Arc<dyn VirtualFs> = fs.clone();
        (fs, FindTool::new(shared, "/project"))
    }

    /// Writes every entry of [`FIXTURES`] into `fs`.
    async fn populate(fs: &MemoryFs) {
        for &(path, contents) in FIXTURES.iter() {
            fs.write(path, contents).await.unwrap();
        }
    }

    /// Calls `execute` with `args` and unwraps the outer `SoulResult`.
    async fn run(tool: &FindTool, call_id: &str, args: serde_json::Value) -> ToolOutput {
        tool.execute(call_id, args, None).await.unwrap()
    }

    #[tokio::test]
    async fn find_by_extension() {
        let (fs, tool) = setup();
        populate(&fs).await;

        let out = run(&tool, "c1", json!({"pattern": "*.rs"})).await;

        assert!(!out.is_error);
        assert!(out.content.contains("main.rs") && out.content.contains("lib.rs"));
        assert!(!out.content.contains("utils.ts"));
    }

    #[tokio::test]
    async fn find_exact_name() {
        let (fs, tool) = setup();
        populate(&fs).await;

        let out = run(&tool, "c2", json!({"pattern": "Cargo.toml"})).await;

        assert!(!out.is_error);
        assert!(out.content.contains("Cargo.toml"));
        assert_eq!(out.metadata["count"].as_u64(), Some(1));
    }

    #[tokio::test]
    async fn find_no_results() {
        let (fs, tool) = setup();
        populate(&fs).await;

        let out = run(&tool, "c3", json!({"pattern": "*.py"})).await;

        assert!(!out.is_error);
        assert!(out.content.contains("No files"));
    }

    #[tokio::test]
    async fn find_with_limit() {
        let (fs, tool) = setup();
        populate(&fs).await;

        let out = run(&tool, "c4", json!({"pattern": "*", "limit": 2})).await;

        assert!(!out.is_error);
        assert_eq!(out.metadata["count"].as_u64(), Some(2));
    }

    #[tokio::test]
    async fn find_empty_pattern() {
        let (_fs, tool) = setup();
        let out = run(&tool, "c5", json!({"pattern": ""})).await;
        assert!(out.is_error);
    }

    /// Asserts `matches_glob(name, path, pattern) == expected` for every case.
    fn check_globs(cases: &[(&str, &str, &str, bool)]) {
        for &(name, path, pattern, expected) in cases {
            assert_eq!(
                matches_glob(name, path, pattern),
                expected,
                "matches_glob({:?}, {:?}, {:?})",
                name,
                path,
                pattern
            );
        }
    }

    #[test]
    fn glob_extensions() {
        check_globs(&[
            ("file.rs", "/src/file.rs", "*.rs", true),
            ("file.ts", "/src/file.ts", "*.rs", false),
        ]);
    }

    #[test]
    fn glob_prefix() {
        check_globs(&[
            ("Cargo.toml", "/Cargo.toml", "Cargo*", true),
            ("package.json", "/package.json", "Cargo*", false),
        ]);
    }

    #[test]
    fn glob_exact() {
        check_globs(&[
            ("Makefile", "/Makefile", "Makefile", true),
            ("makefile", "/makefile", "Makefile", false),
        ]);
    }

    #[tokio::test]
    async fn tool_name_and_definition() {
        let (_fs, tool) = setup();
        assert_eq!(tool.name(), "find");
        assert_eq!(tool.definition().name, "find");
    }
}