1use ignore::WalkBuilder;
2use ignore::overrides::OverrideBuilder;
3use nucleo_matcher::Matcher;
4use nucleo_matcher::Utf32Str;
5use nucleo_matcher::pattern::AtomKind;
6use nucleo_matcher::pattern::CaseMatching;
7use nucleo_matcher::pattern::Normalization;
8use nucleo_matcher::pattern::Pattern;
9use serde::Serialize;
10use std::cell::UnsafeCell;
11use std::cmp::Reverse;
12use std::collections::BinaryHeap;
13use std::num::NonZero;
14use std::path::Path;
15use std::sync::Arc;
16use std::sync::atomic::AtomicBool;
17use std::sync::atomic::AtomicUsize;
18use std::sync::atomic::Ordering;
19use tokio::process::Command;
20
21mod cli;
22
23pub use cli::Cli;
24
/// A single file whose path matched the search pattern.
#[derive(Debug, Clone, Serialize)]
pub struct FileMatch {
    /// Score assigned by the fuzzy matcher; results are reported in
    /// descending score order.
    pub score: u32,
    /// Path of the matched file, relative to the searched directory.
    pub path: String,
    /// Positions reported by the matcher for this path, sorted and
    /// deduplicated. `None` when index computation was not requested, in
    /// which case the field is omitted from serialized output.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub indices: Option<Vec<u32>>,
}
42
/// Outcome of a search performed by [`run`].
pub struct FileSearchResults {
    /// The best matches, ordered by descending score and then by ascending
    /// path; never longer than the `limit` passed to [`run`].
    pub matches: Vec<FileMatch>,
    /// Number of files that matched the pattern across all workers. This can
    /// exceed `matches.len()` when the results were truncated to the limit.
    pub total_match_count: usize,
}
47
/// Sink for the user-facing output produced by [`run_main`].
pub trait Reporter {
    /// Called once per returned match, in result order.
    fn report_match(&self, file_match: &FileMatch);
    /// Called after all matches were reported when more files matched
    /// (`total_match_count`) than were actually shown (`shown_match_count`).
    fn warn_matches_truncated(&self, total_match_count: usize, shown_match_count: usize);
    /// Called when no search pattern was supplied, before the contents of
    /// `search_directory` are listed instead of searching.
    fn warn_no_search_pattern(&self, search_directory: &Path);
}
53
54pub async fn run_main<T: Reporter>(
55 Cli {
56 pattern,
57 limit,
58 cwd,
59 compute_indices,
60 json: _,
61 exclude,
62 threads,
63 }: Cli,
64 reporter: T,
65) -> anyhow::Result<()> {
66 let search_directory = match cwd {
67 Some(dir) => dir,
68 None => std::env::current_dir()?,
69 };
70 let pattern_text = match pattern {
71 Some(pattern) => pattern,
72 None => {
73 reporter.warn_no_search_pattern(&search_directory);
74 #[cfg(unix)]
75 Command::new("ls")
76 .arg("-al")
77 .current_dir(search_directory)
78 .stdout(std::process::Stdio::inherit())
79 .stderr(std::process::Stdio::inherit())
80 .status()
81 .await?;
82 #[cfg(windows)]
83 {
84 Command::new("cmd")
85 .arg("/c")
86 .arg(search_directory)
87 .stdout(std::process::Stdio::inherit())
88 .stderr(std::process::Stdio::inherit())
89 .status()
90 .await?;
91 }
92 return Ok(());
93 }
94 };
95
96 let cancel_flag = Arc::new(AtomicBool::new(false));
97 let FileSearchResults {
98 total_match_count,
99 matches,
100 } = run(
101 &pattern_text,
102 limit,
103 &search_directory,
104 exclude,
105 threads,
106 cancel_flag,
107 compute_indices,
108 )?;
109 let match_count = matches.len();
110 let matches_truncated = total_match_count > match_count;
111
112 for file_match in matches {
113 reporter.report_match(&file_match);
114 }
115 if matches_truncated {
116 reporter.warn_matches_truncated(total_match_count, match_count);
117 }
118
119 Ok(())
120}
121
122pub fn run(
125 pattern_text: &str,
126 limit: NonZero<usize>,
127 search_directory: &Path,
128 exclude: Vec<String>,
129 threads: NonZero<usize>,
130 cancel_flag: Arc<AtomicBool>,
131 compute_indices: bool,
132) -> anyhow::Result<FileSearchResults> {
133 let pattern = create_pattern(pattern_text);
134 let WorkerCount {
138 num_walk_builder_threads,
139 num_best_matches_lists,
140 } = create_worker_count(threads);
141 let best_matchers_per_worker: Vec<UnsafeCell<BestMatchesList>> = (0..num_best_matches_lists)
142 .map(|_| {
143 UnsafeCell::new(BestMatchesList::new(
144 limit.get(),
145 pattern.clone(),
146 Matcher::new(nucleo_matcher::Config::DEFAULT),
147 ))
148 })
149 .collect();
150
151 let mut walk_builder = WalkBuilder::new(search_directory);
154 walk_builder.threads(num_walk_builder_threads);
155 if !exclude.is_empty() {
156 let mut override_builder = OverrideBuilder::new(search_directory);
157 for exclude in exclude {
158 let exclude_pattern = format!("!{exclude}");
160 override_builder.add(&exclude_pattern)?;
161 }
162 let override_matcher = override_builder.build()?;
163 walk_builder.overrides(override_matcher);
164 }
165 let walker = walk_builder.build_parallel();
166
167 let index_counter = AtomicUsize::new(0);
170 walker.run(|| {
171 let index = index_counter.fetch_add(1, Ordering::Relaxed);
172 let best_list_ptr = best_matchers_per_worker[index].get();
173 let best_list = unsafe { &mut *best_list_ptr };
174
175 const CHECK_INTERVAL: usize = 1024;
178 let mut processed = 0;
179
180 let cancel = cancel_flag.clone();
181
182 Box::new(move |entry| {
183 if let Some(path) = get_file_path(&entry, search_directory) {
184 best_list.insert(path);
185 }
186
187 processed += 1;
188 if processed % CHECK_INTERVAL == 0 && cancel.load(Ordering::Relaxed) {
189 ignore::WalkState::Quit
190 } else {
191 ignore::WalkState::Continue
192 }
193 })
194 });
195
196 fn get_file_path<'a>(
197 entry_result: &'a Result<ignore::DirEntry, ignore::Error>,
198 search_directory: &std::path::Path,
199 ) -> Option<&'a str> {
200 let entry = match entry_result {
201 Ok(e) => e,
202 Err(_) => return None,
203 };
204 if entry.file_type().is_some_and(|ft| ft.is_dir()) {
205 return None;
206 }
207 let path = entry.path();
208 match path.strip_prefix(search_directory) {
209 Ok(rel_path) => rel_path.to_str(),
210 Err(_) => None,
211 }
212 }
213
214 if cancel_flag.load(Ordering::Relaxed) {
216 return Ok(FileSearchResults {
217 matches: Vec::new(),
218 total_match_count: 0,
219 });
220 }
221
222 let mut global_heap: BinaryHeap<Reverse<(u32, String)>> = BinaryHeap::new();
224 let mut total_match_count = 0;
225 for best_list_cell in best_matchers_per_worker.iter() {
226 let best_list = unsafe { &*best_list_cell.get() };
227 total_match_count += best_list.num_matches;
228 for &Reverse((score, ref line)) in best_list.binary_heap.iter() {
229 if global_heap.len() < limit.get() {
230 global_heap.push(Reverse((score, line.clone())));
231 } else if let Some(min_element) = global_heap.peek()
232 && score > min_element.0.0
233 {
234 global_heap.pop();
235 global_heap.push(Reverse((score, line.clone())));
236 }
237 }
238 }
239
240 let mut raw_matches: Vec<(u32, String)> = global_heap.into_iter().map(|r| r.0).collect();
241 sort_matches(&mut raw_matches);
242
243 let mut matcher = if compute_indices {
245 Some(Matcher::new(nucleo_matcher::Config::DEFAULT))
246 } else {
247 None
248 };
249
250 let matches: Vec<FileMatch> = raw_matches
251 .into_iter()
252 .map(|(score, path)| {
253 let indices = if compute_indices {
254 let mut buf = Vec::<char>::new();
255 let haystack: Utf32Str<'_> = Utf32Str::new(&path, &mut buf);
256 let mut idx_vec: Vec<u32> = Vec::new();
257 if let Some(ref mut m) = matcher {
258 pattern.indices(haystack, m, &mut idx_vec);
260 }
261 idx_vec.sort_unstable();
262 idx_vec.dedup();
263 Some(idx_vec)
264 } else {
265 None
266 };
267
268 FileMatch {
269 score,
270 path,
271 indices,
272 }
273 })
274 .collect();
275
276 Ok(FileSearchResults {
277 matches,
278 total_match_count,
279 })
280}
281
/// Orders `(score, path)` pairs by descending score, breaking ties by
/// ascending path.
fn sort_matches(matches: &mut [(u32, String)]) {
    matches.sort_by(|lhs, rhs| {
        let by_score_desc = rhs.0.cmp(&lhs.0);
        by_score_desc.then_with(|| lhs.1.cmp(&rhs.1))
    });
}
289
/// Bounded collection of the highest-scoring lines seen by one walker
/// visitor.
struct BestMatchesList {
    /// Maximum number of entries kept in `binary_heap`.
    max_count: usize,
    /// Number of inserted lines that matched the pattern, including those
    /// that were not kept or were later evicted.
    num_matches: usize,
    /// Pattern every inserted line is scored against.
    pattern: Pattern,
    /// Matcher state used when scoring.
    matcher: Matcher,
    /// Kept entries as `(score, line)`, wrapped in `Reverse` so that `peek`
    /// and `pop` yield the lowest-scored entry.
    binary_heap: BinaryHeap<Reverse<(u32, String)>>,

    /// Scratch buffer handed to `Utf32Str::new` on every insert.
    utf32buf: Vec<char>,
}
301
302impl BestMatchesList {
303 const fn new(max_count: usize, pattern: Pattern, matcher: Matcher) -> Self {
304 Self {
305 max_count,
306 num_matches: 0,
307 pattern,
308 matcher,
309 binary_heap: BinaryHeap::new(),
310 utf32buf: Vec::<char>::new(),
311 }
312 }
313
314 fn insert(&mut self, line: &str) {
315 let haystack: Utf32Str<'_> = Utf32Str::new(line, &mut self.utf32buf);
316 if let Some(score) = self.pattern.score(haystack, &mut self.matcher) {
317 self.num_matches += 1;
320
321 if self.binary_heap.len() < self.max_count {
322 self.binary_heap.push(Reverse((score, line.to_string())));
323 } else if let Some(min_element) = self.binary_heap.peek()
324 && score > min_element.0.0
325 {
326 self.binary_heap.pop();
327 self.binary_heap.push(Reverse((score, line.to_string())));
328 }
329 }
330 }
331}
332
/// Sizing derived from the requested thread count by `create_worker_count`.
struct WorkerCount {
    /// Thread count passed to the walk builder.
    num_walk_builder_threads: usize,
    /// Number of per-visitor best-match lists to allocate.
    num_best_matches_lists: usize,
}
337
338const fn create_worker_count(num_workers: NonZero<usize>) -> WorkerCount {
339 let num_walk_builder_threads = num_workers.get();
349 let num_best_matches_lists = num_walk_builder_threads + 1;
350
351 WorkerCount {
352 num_walk_builder_threads,
353 num_best_matches_lists,
354 }
355}
356
357fn create_pattern(pattern: &str) -> Pattern {
358 Pattern::new(
359 pattern,
360 CaseMatching::Smart,
361 Normalization::Smart,
362 AtomKind::Fuzzy,
363 )
364}
365
#[cfg(test)]
mod tests {
    use super::*;

    /// A pattern that does not match the haystack must produce no score.
    #[test]
    fn verify_score_is_none_for_non_match() {
        let pattern = create_pattern("zzz");
        let mut matcher = Matcher::new(nucleo_matcher::Config::DEFAULT);
        let mut utf32buf: Vec<char> = Vec::new();
        let haystack: Utf32Str<'_> = Utf32Str::new("hello", &mut utf32buf);

        let score = pattern.score(haystack, &mut matcher);

        assert_eq!(score, None);
    }

    /// Equal scores fall back to ascending path order; higher scores still
    /// come first.
    #[test]
    fn tie_breakers_sort_by_path_when_scores_equal() {
        let entry = |score: u32, path: &str| (score, path.to_string());

        let mut matches = vec![
            entry(100, "b_path"),
            entry(100, "a_path"),
            entry(90, "zzz"),
        ];
        sort_matches(&mut matches);

        let expected = vec![
            entry(100, "a_path"),
            entry(100, "b_path"),
            entry(90, "zzz"),
        ];
        assert_eq!(matches, expected);
    }
}