1use std::sync::Arc;
7
8use async_trait::async_trait;
9use serde_json::json;
10use tokio::sync::mpsc;
11
12use soul_core::error::SoulResult;
13use soul_core::tool::{Tool, ToolOutput};
14use soul_core::types::ToolDefinition;
15use soul_core::vfs::VirtualFs;
16
17use crate::truncate::{truncate_head, MAX_BYTES};
18
/// Hard cap on the number of matches a single search returns; also used as the
/// `limit` when the caller does not supply one.
const MAX_RESULTS: usize = 1_000;
21
22use super::resolve_path;
23
24pub struct FindTool {
25 fs: Arc<dyn VirtualFs>,
26 cwd: String,
27}
28
29impl FindTool {
30 pub fn new(fs: Arc<dyn VirtualFs>, cwd: impl Into<String>) -> Self {
31 Self {
32 fs,
33 cwd: cwd.into(),
34 }
35 }
36}
37
/// Returns `true` when a file matches the glob `pattern`.
///
/// * `name` — the file's base name (e.g. `main.rs`).
/// * `full_path` — the file's absolute path (e.g. `/project/src/main.rs`).
///
/// Surrounding whitespace in `pattern` is ignored. A pattern without `/` is matched
/// against `name` only, so `*.rs` finds Rust files at any depth. A pattern containing
/// `/` is matched segment-by-segment against the trailing segments of `full_path`
/// (see [`path_matches_glob`]), where a `**` segment spans any number of directories.
fn matches_glob(name: &str, full_path: &str, pattern: &str) -> bool {
    let pattern = pattern.trim();

    if pattern.contains('/') {
        path_matches_glob(full_path, pattern)
    } else {
        matches_simple_glob(name, pattern)
    }
}

/// Matches a single path segment (or bare file name) against a glob `pattern`.
///
/// `*` matches any run of characters (including none) and `?` matches exactly one
/// character; every other character matches itself, case-sensitively. Wildcards may
/// appear anywhere and any number of times (`foo*.rs`, `*.test.*`, `a?c`). Matching
/// works on `char`s, so `?` consumes one whole Unicode scalar value.
fn matches_simple_glob(name: &str, pattern: &str) -> bool {
    let text: Vec<char> = name.chars().collect();
    let pat: Vec<char> = pattern.chars().collect();

    let (mut t, mut p) = (0, 0);
    // Most recent `*` in `pat` and the text position it is currently absorbing up to.
    let mut backtrack: Option<(usize, usize)> = None;

    while t < text.len() {
        match pat.get(p).copied() {
            Some('*') => {
                backtrack = Some((p, t));
                p += 1;
            }
            Some(c) if c == '?' || c == text[t] => {
                p += 1;
                t += 1;
            }
            _ => match backtrack {
                Some((star_p, star_t)) => {
                    // Let the last `*` swallow one more character and retry after it.
                    p = star_p + 1;
                    t = star_t + 1;
                    backtrack = Some((star_p, star_t + 1));
                }
                None => return false,
            },
        }
    }

    // Text exhausted: whatever pattern remains must be able to match nothing.
    pat[p..].iter().all(|&c| c == '*')
}

/// Matches an absolute `path` against a `/`-separated glob `pattern`.
///
/// The pattern is anchored only at the END of the path: it must match the path's
/// trailing segments, so the relative pattern `src/*.rs` matches
/// `/project/src/main.rs` wherever the search started. Empty segments (leading,
/// trailing or doubled `/`) are ignored on both sides. A pattern segment that is
/// exactly `**` matches zero or more whole path segments; every other segment is
/// matched against one path segment with [`matches_simple_glob`].
fn path_matches_glob(path: &str, pattern: &str) -> bool {
    let segments: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
    // An implicit leading `**` gives the end-anchored (suffix) semantics.
    let parts: Vec<&str> = std::iter::once("**")
        .chain(pattern.split('/').filter(|s| !s.is_empty()))
        .collect();

    let (mut s, mut p) = (0, 0);
    // Most recent `**` in `parts` and the segment index it is absorbing up to.
    let mut backtrack: Option<(usize, usize)> = None;

    while s < segments.len() {
        match parts.get(p).copied() {
            Some("**") => {
                backtrack = Some((p, s));
                p += 1;
            }
            Some(glob) if matches_simple_glob(segments[s], glob) => {
                p += 1;
                s += 1;
            }
            _ => match backtrack {
                Some((star_p, star_s)) => {
                    // Let the last `**` swallow one more segment and retry after it.
                    p = star_p + 1;
                    s = star_s + 1;
                    backtrack = Some((star_p, star_s + 1));
                }
                None => return false,
            },
        }
    }

    // Path exhausted: only `**` segments may remain in the pattern.
    parts[p..].iter().all(|&part| part == "**")
}
110
111async fn find_files(
113 fs: &dyn VirtualFs,
114 dir: &str,
115 pattern: &str,
116 results: &mut Vec<String>,
117 limit: usize,
118) -> SoulResult<()> {
119 if results.len() >= limit {
120 return Ok(());
121 }
122
123 let entries = match fs.read_dir(dir).await {
124 Ok(e) => e,
125 Err(_) => return Ok(()), };
127
128 for entry in entries {
129 if results.len() >= limit {
130 break;
131 }
132
133 let path = if dir == "/" || dir.is_empty() {
134 format!("/{}", entry.name)
135 } else {
136 format!("{}/{}", dir.trim_end_matches('/'), entry.name)
137 };
138
139 if entry.is_dir {
140 if !entry.name.starts_with('.') {
141 Box::pin(find_files(fs, &path, pattern, results, limit)).await?;
142 }
143 } else if entry.is_file && matches_glob(&entry.name, &path, pattern) {
144 results.push(path);
145 }
146 }
147
148 Ok(())
149}
150
151#[async_trait]
152impl Tool for FindTool {
153 fn name(&self) -> &str {
154 "find"
155 }
156
157 fn definition(&self) -> ToolDefinition {
158 ToolDefinition {
159 name: "find".into(),
160 description: "Find files matching a glob pattern. Returns matching file paths.".into(),
161 input_schema: json!({
162 "type": "object",
163 "properties": {
164 "pattern": {
165 "type": "string",
166 "description": "Glob pattern to match files (e.g., '*.rs', 'src/**/*.ts', 'Cargo.toml')"
167 },
168 "path": {
169 "type": "string",
170 "description": "Directory to search in (defaults to working directory)"
171 },
172 "limit": {
173 "type": "integer",
174 "description": "Maximum number of results (default: 1000)"
175 }
176 },
177 "required": ["pattern"]
178 }),
179 }
180 }
181
182 async fn execute(
183 &self,
184 _call_id: &str,
185 arguments: serde_json::Value,
186 _partial_tx: Option<mpsc::UnboundedSender<String>>,
187 ) -> SoulResult<ToolOutput> {
188 let pattern = arguments
189 .get("pattern")
190 .and_then(|v| v.as_str())
191 .unwrap_or("");
192
193 if pattern.is_empty() {
194 return Ok(ToolOutput::error("Missing required parameter: pattern"));
195 }
196
197 let search_path = arguments
198 .get("path")
199 .and_then(|v| v.as_str())
200 .map(|p| resolve_path(&self.cwd, p))
201 .unwrap_or_else(|| self.cwd.clone());
202
203 let limit = arguments
204 .get("limit")
205 .and_then(|v| v.as_u64())
206 .map(|v| (v as usize).min(MAX_RESULTS))
207 .unwrap_or(MAX_RESULTS);
208
209 let mut results = Vec::new();
210 if let Err(e) =
211 find_files(self.fs.as_ref(), &search_path, pattern, &mut results, limit).await
212 {
213 return Ok(ToolOutput::error(format!(
214 "Failed to search {}: {}",
215 search_path, e
216 )));
217 }
218
219 results.sort();
220
221 if results.is_empty() {
222 return Ok(ToolOutput::success(format!(
223 "No files matching '{}' found",
224 pattern
225 ))
226 .with_metadata(json!({"count": 0})));
227 }
228
229 let cwd_prefix = format!("{}/", self.cwd.trim_end_matches('/'));
231 let relative: Vec<String> = results
232 .iter()
233 .map(|p| {
234 if p.starts_with(&cwd_prefix) {
235 p[cwd_prefix.len()..].to_string()
236 } else {
237 p.clone()
238 }
239 })
240 .collect();
241
242 let output = relative.join("\n");
243 let truncated = truncate_head(&output, results.len(), MAX_BYTES);
244
245 let notice = truncated.truncation_notice();
246 let mut result = truncated.content;
247 if results.len() >= limit {
248 result.push_str(&format!("\n[Reached limit: {} results]", limit));
249 }
250 if let Some(notice) = notice {
251 result.push_str(&format!("\n{}", notice));
252 }
253
254 Ok(ToolOutput::success(result).with_metadata(json!({
255 "count": results.len(),
256 "limit_reached": results.len() >= limit,
257 })))
258 }
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    use soul_core::vfs::MemoryFs;

    /// Fresh in-memory filesystem plus a `FindTool` whose working directory is `/project`.
    async fn setup() -> (Arc<MemoryFs>, FindTool) {
        let fs = Arc::new(MemoryFs::new());
        let tool = FindTool::new(fs.clone() as Arc<dyn VirtualFs>, "/project");
        (fs, tool)
    }

    /// Writes the fixture tree: `src/main.rs`, `src/lib.rs` and `src/utils.ts`, plus
    /// `Cargo.toml` and `README.md` at the project root.
    async fn populate(fs: &MemoryFs) {
        fs.write("/project/src/main.rs", "fn main() {}")
            .await
            .unwrap();
        fs.write("/project/src/lib.rs", "pub mod foo;")
            .await
            .unwrap();
        fs.write("/project/src/utils.ts", "export {}")
            .await
            .unwrap();
        fs.write("/project/Cargo.toml", "[package]").await.unwrap();
        fs.write("/project/README.md", "# readme").await.unwrap();
    }

    #[tokio::test]
    async fn find_by_extension() {
        let (fs, tool) = setup().await;
        populate(&*fs).await;

        let result = tool
            .execute("c1", json!({"pattern": "*.rs"}), None)
            .await
            .unwrap();

        assert!(!result.is_error);
        assert!(result.content.contains("main.rs"));
        assert!(result.content.contains("lib.rs"));
        assert!(!result.content.contains("utils.ts"));
    }

    #[tokio::test]
    async fn find_exact_name() {
        let (fs, tool) = setup().await;
        populate(&*fs).await;

        let result = tool
            .execute("c2", json!({"pattern": "Cargo.toml"}), None)
            .await
            .unwrap();

        assert!(!result.is_error);
        assert!(result.content.contains("Cargo.toml"));
        assert_eq!(result.metadata["count"].as_u64().unwrap(), 1);
    }

    #[tokio::test]
    async fn find_no_results() {
        let (fs, tool) = setup().await;
        populate(&*fs).await;

        let result = tool
            .execute("c3", json!({"pattern": "*.py"}), None)
            .await
            .unwrap();

        assert!(!result.is_error);
        assert!(result.content.contains("No files"));
    }

    #[tokio::test]
    async fn find_with_limit() {
        let (fs, tool) = setup().await;
        populate(&*fs).await;

        let result = tool
            .execute("c4", json!({"pattern": "*", "limit": 2}), None)
            .await
            .unwrap();

        assert!(!result.is_error);
        assert_eq!(result.metadata["count"].as_u64().unwrap(), 2);
        assert!(result.metadata["limit_reached"].as_bool().unwrap());
    }

    #[tokio::test]
    async fn find_empty_pattern() {
        let (_fs, tool) = setup().await;
        let result = tool
            .execute("c5", json!({"pattern": ""}), None)
            .await
            .unwrap();
        assert!(result.is_error);
    }

    #[tokio::test]
    async fn find_recursive_pattern_reports_relative_paths() {
        let (fs, tool) = setup().await;
        populate(&*fs).await;

        let result = tool
            .execute("c6", json!({"pattern": "**/*.rs"}), None)
            .await
            .unwrap();

        assert!(!result.is_error);
        assert_eq!(result.metadata["count"].as_u64().unwrap(), 2);
        // Matches under the working directory are shown relative to it.
        assert!(result.content.contains("src/main.rs"));
        assert!(!result.content.contains("/project/"));
    }

    #[test]
    fn glob_extensions() {
        assert!(matches_glob("file.rs", "/src/file.rs", "*.rs"));
        assert!(!matches_glob("file.ts", "/src/file.ts", "*.rs"));
    }

    #[test]
    fn glob_prefix() {
        assert!(matches_glob("Cargo.toml", "/Cargo.toml", "Cargo*"));
        assert!(!matches_glob("package.json", "/package.json", "Cargo*"));
    }

    #[test]
    fn glob_exact() {
        assert!(matches_glob("Makefile", "/Makefile", "Makefile"));
        assert!(!matches_glob("makefile", "/makefile", "Makefile"));
    }

    #[test]
    fn glob_suffix_and_contains() {
        assert!(matches_glob("main_test.rs", "/src/main_test.rs", "*_test.rs"));
        assert!(matches_glob("my_test_file.rs", "/src/my_test_file.rs", "*test*"));
        assert!(!matches_glob("main.rs", "/src/main.rs", "*test*"));
    }

    #[test]
    fn glob_path_pattern() {
        assert!(matches_glob("main.rs", "/project/src/main.rs", "src/*.rs"));
        assert!(!matches_glob("lib.rs", "/project/tests/lib.rs", "src/*.rs"));
    }

    #[tokio::test]
    async fn tool_name_and_definition() {
        let (_fs, tool) = setup().await;
        assert_eq!(tool.name(), "find");
        let def = tool.definition();
        assert_eq!(def.name, "find");
    }
}
380}