agcodex_file_search/
lib.rs

1use ignore::WalkBuilder;
2use ignore::overrides::OverrideBuilder;
3use nucleo_matcher::Matcher;
4use nucleo_matcher::Utf32Str;
5use nucleo_matcher::pattern::AtomKind;
6use nucleo_matcher::pattern::CaseMatching;
7use nucleo_matcher::pattern::Normalization;
8use nucleo_matcher::pattern::Pattern;
9use serde::Serialize;
10use std::cell::UnsafeCell;
11use std::cmp::Reverse;
12use std::collections::BinaryHeap;
13use std::num::NonZero;
14use std::path::Path;
15use std::sync::Arc;
16use std::sync::atomic::AtomicBool;
17use std::sync::atomic::AtomicUsize;
18use std::sync::atomic::Ordering;
19use tokio::process::Command;
20
21mod cli;
22
23pub use cli::Cli;
24
25/// A single match result returned from the search.
26///
27/// * `score` – Relevance score returned by `nucleo_matcher`.
28/// * `path`  – Path to the matched file (relative to the search directory).
29/// * `indices` – Optional list of character indices that matched the query.
30///   These are only filled when the caller of [`run`] sets
31///   `compute_indices` to `true`.  The indices vector follows the
32///   guidance from `nucleo_matcher::Pattern::indices`: they are
33///   unique and sorted in ascending order so that callers can use
34///   them directly for highlighting.
35#[derive(Debug, Clone, Serialize)]
36pub struct FileMatch {
37    pub score: u32,
38    pub path: String,
39    #[serde(skip_serializing_if = "Option::is_none")]
40    pub indices: Option<Vec<u32>>, // Sorted & deduplicated when present
41}
42
43pub struct FileSearchResults {
44    pub matches: Vec<FileMatch>,
45    pub total_match_count: usize,
46}
47
48pub trait Reporter {
49    fn report_match(&self, file_match: &FileMatch);
50    fn warn_matches_truncated(&self, total_match_count: usize, shown_match_count: usize);
51    fn warn_no_search_pattern(&self, search_directory: &Path);
52}
53
54pub async fn run_main<T: Reporter>(
55    Cli {
56        pattern,
57        limit,
58        cwd,
59        compute_indices,
60        json: _,
61        exclude,
62        threads,
63    }: Cli,
64    reporter: T,
65) -> anyhow::Result<()> {
66    let search_directory = match cwd {
67        Some(dir) => dir,
68        None => std::env::current_dir()?,
69    };
70    let pattern_text = match pattern {
71        Some(pattern) => pattern,
72        None => {
73            reporter.warn_no_search_pattern(&search_directory);
74            #[cfg(unix)]
75            Command::new("ls")
76                .arg("-al")
77                .current_dir(search_directory)
78                .stdout(std::process::Stdio::inherit())
79                .stderr(std::process::Stdio::inherit())
80                .status()
81                .await?;
82            #[cfg(windows)]
83            {
84                Command::new("cmd")
85                    .arg("/c")
86                    .arg(search_directory)
87                    .stdout(std::process::Stdio::inherit())
88                    .stderr(std::process::Stdio::inherit())
89                    .status()
90                    .await?;
91            }
92            return Ok(());
93        }
94    };
95
96    let cancel_flag = Arc::new(AtomicBool::new(false));
97    let FileSearchResults {
98        total_match_count,
99        matches,
100    } = run(
101        &pattern_text,
102        limit,
103        &search_directory,
104        exclude,
105        threads,
106        cancel_flag,
107        compute_indices,
108    )?;
109    let match_count = matches.len();
110    let matches_truncated = total_match_count > match_count;
111
112    for file_match in matches {
113        reporter.report_match(&file_match);
114    }
115    if matches_truncated {
116        reporter.warn_matches_truncated(total_match_count, match_count);
117    }
118
119    Ok(())
120}
121
122/// The worker threads will periodically check `cancel_flag` to see if they
123/// should stop processing files.
124pub fn run(
125    pattern_text: &str,
126    limit: NonZero<usize>,
127    search_directory: &Path,
128    exclude: Vec<String>,
129    threads: NonZero<usize>,
130    cancel_flag: Arc<AtomicBool>,
131    compute_indices: bool,
132) -> anyhow::Result<FileSearchResults> {
133    let pattern = create_pattern(pattern_text);
134    // Create one BestMatchesList per worker thread so that each worker can
135    // operate independently. The results across threads will be merged when
136    // the traversal is complete.
137    let WorkerCount {
138        num_walk_builder_threads,
139        num_best_matches_lists,
140    } = create_worker_count(threads);
141    let best_matchers_per_worker: Vec<UnsafeCell<BestMatchesList>> = (0..num_best_matches_lists)
142        .map(|_| {
143            UnsafeCell::new(BestMatchesList::new(
144                limit.get(),
145                pattern.clone(),
146                Matcher::new(nucleo_matcher::Config::DEFAULT),
147            ))
148        })
149        .collect();
150
151    // Use the ignore crate's tree-walker for efficient parallel file traversal.
152    // This provides high-performance directory walking with built-in parallelism.
153    let mut walk_builder = WalkBuilder::new(search_directory);
154    walk_builder.threads(num_walk_builder_threads);
155    if !exclude.is_empty() {
156        let mut override_builder = OverrideBuilder::new(search_directory);
157        for exclude in exclude {
158            // The `!` prefix is used to indicate an exclude pattern.
159            let exclude_pattern = format!("!{exclude}");
160            override_builder.add(&exclude_pattern)?;
161        }
162        let override_matcher = override_builder.build()?;
163        walk_builder.overrides(override_matcher);
164    }
165    let walker = walk_builder.build_parallel();
166
167    // Each worker created by `WalkParallel::run()` will have its own
168    // `BestMatchesList` to update.
169    let index_counter = AtomicUsize::new(0);
170    walker.run(|| {
171        let index = index_counter.fetch_add(1, Ordering::Relaxed);
172        let best_list_ptr = best_matchers_per_worker[index].get();
173        let best_list = unsafe { &mut *best_list_ptr };
174
175        // Each worker keeps a local counter so we only read the atomic flag
176        // every N entries which is cheaper than checking on every file.
177        const CHECK_INTERVAL: usize = 1024;
178        let mut processed = 0;
179
180        let cancel = cancel_flag.clone();
181
182        Box::new(move |entry| {
183            if let Some(path) = get_file_path(&entry, search_directory) {
184                best_list.insert(path);
185            }
186
187            processed += 1;
188            if processed % CHECK_INTERVAL == 0 && cancel.load(Ordering::Relaxed) {
189                ignore::WalkState::Quit
190            } else {
191                ignore::WalkState::Continue
192            }
193        })
194    });
195
196    fn get_file_path<'a>(
197        entry_result: &'a Result<ignore::DirEntry, ignore::Error>,
198        search_directory: &std::path::Path,
199    ) -> Option<&'a str> {
200        let entry = match entry_result {
201            Ok(e) => e,
202            Err(_) => return None,
203        };
204        if entry.file_type().is_some_and(|ft| ft.is_dir()) {
205            return None;
206        }
207        let path = entry.path();
208        match path.strip_prefix(search_directory) {
209            Ok(rel_path) => rel_path.to_str(),
210            Err(_) => None,
211        }
212    }
213
214    // If the cancel flag is set, we return early with an empty result.
215    if cancel_flag.load(Ordering::Relaxed) {
216        return Ok(FileSearchResults {
217            matches: Vec::new(),
218            total_match_count: 0,
219        });
220    }
221
222    // Merge results across best_matchers_per_worker.
223    let mut global_heap: BinaryHeap<Reverse<(u32, String)>> = BinaryHeap::new();
224    let mut total_match_count = 0;
225    for best_list_cell in best_matchers_per_worker.iter() {
226        let best_list = unsafe { &*best_list_cell.get() };
227        total_match_count += best_list.num_matches;
228        for &Reverse((score, ref line)) in best_list.binary_heap.iter() {
229            if global_heap.len() < limit.get() {
230                global_heap.push(Reverse((score, line.clone())));
231            } else if let Some(min_element) = global_heap.peek()
232                && score > min_element.0.0
233            {
234                global_heap.pop();
235                global_heap.push(Reverse((score, line.clone())));
236            }
237        }
238    }
239
240    let mut raw_matches: Vec<(u32, String)> = global_heap.into_iter().map(|r| r.0).collect();
241    sort_matches(&mut raw_matches);
242
243    // Transform into `FileMatch`, optionally computing indices.
244    let mut matcher = if compute_indices {
245        Some(Matcher::new(nucleo_matcher::Config::DEFAULT))
246    } else {
247        None
248    };
249
250    let matches: Vec<FileMatch> = raw_matches
251        .into_iter()
252        .map(|(score, path)| {
253            let indices = if compute_indices {
254                let mut buf = Vec::<char>::new();
255                let haystack: Utf32Str<'_> = Utf32Str::new(&path, &mut buf);
256                let mut idx_vec: Vec<u32> = Vec::new();
257                if let Some(ref mut m) = matcher {
258                    // Ignore the score returned from indices – we already have `score`.
259                    pattern.indices(haystack, m, &mut idx_vec);
260                }
261                idx_vec.sort_unstable();
262                idx_vec.dedup();
263                Some(idx_vec)
264            } else {
265                None
266            };
267
268            FileMatch {
269                score,
270                path,
271                indices,
272            }
273        })
274        .collect();
275
276    Ok(FileSearchResults {
277        matches,
278        total_match_count,
279    })
280}
281
/// Sorts matches in place by descending score, then by ascending path.
///
/// The ordering is total over `(score, path)`. Any two elements that
/// compare equal are therefore identical, so an unstable sort yields exactly
/// the same result as a stable one.
fn sort_matches(matches: &mut [(u32, String)]) {
    matches.sort_unstable_by(|(score_a, path_a), (score_b, path_b)| {
        score_b.cmp(score_a).then_with(|| path_a.cmp(path_b))
    });
}
289
290/// Maintains the `max_count` best matches for a given pattern.
291struct BestMatchesList {
292    max_count: usize,
293    num_matches: usize,
294    pattern: Pattern,
295    matcher: Matcher,
296    binary_heap: BinaryHeap<Reverse<(u32, String)>>,
297
298    /// Internal buffer for converting strings to UTF-32.
299    utf32buf: Vec<char>,
300}
301
302impl BestMatchesList {
303    const fn new(max_count: usize, pattern: Pattern, matcher: Matcher) -> Self {
304        Self {
305            max_count,
306            num_matches: 0,
307            pattern,
308            matcher,
309            binary_heap: BinaryHeap::new(),
310            utf32buf: Vec::<char>::new(),
311        }
312    }
313
314    fn insert(&mut self, line: &str) {
315        let haystack: Utf32Str<'_> = Utf32Str::new(line, &mut self.utf32buf);
316        if let Some(score) = self.pattern.score(haystack, &mut self.matcher) {
317            // In the tests below, we verify that score() returns None for a
318            // non-match, so we can categorically increment the count here.
319            self.num_matches += 1;
320
321            if self.binary_heap.len() < self.max_count {
322                self.binary_heap.push(Reverse((score, line.to_string())));
323            } else if let Some(min_element) = self.binary_heap.peek()
324                && score > min_element.0.0
325            {
326                self.binary_heap.pop();
327                self.binary_heap.push(Reverse((score, line.to_string())));
328            }
329        }
330    }
331}
332
/// Sizing for the parallel walk, derived from the requested thread count.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct WorkerCount {
    /// Value handed to `WalkBuilder::threads`.
    num_walk_builder_threads: usize,
    /// Number of `BestMatchesList`s to allocate: one per call the parallel
    /// walker makes to its visitor-builder closure.
    num_best_matches_lists: usize,
}
337
338const fn create_worker_count(num_workers: NonZero<usize>) -> WorkerCount {
339    // It appears that the number of times the function passed to
340    // `WalkParallel::run()` is called is: the number of threads specified to
341    // the builder PLUS ONE.
342    //
343    // In `WalkParallel::visit()`, the builder function gets called once here:
344    // Reference: ignore crate's WalkParallel::visit() implementation
345    //
346    // And then once for every worker here:
347    // Reference: ignore crate's worker thread spawn logic
348    let num_walk_builder_threads = num_workers.get();
349    let num_best_matches_lists = num_walk_builder_threads + 1;
350
351    WorkerCount {
352        num_walk_builder_threads,
353        num_best_matches_lists,
354    }
355}
356
357fn create_pattern(pattern: &str) -> Pattern {
358    Pattern::new(
359        pattern,
360        CaseMatching::Smart,
361        Normalization::Smart,
362        AtomKind::Fuzzy,
363    )
364}
365
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn verify_score_is_none_for_non_match() {
        let mut utf32buf = Vec::<char>::new();
        let line = "hello";
        let mut matcher = Matcher::new(nucleo_matcher::Config::DEFAULT);
        let haystack: Utf32Str<'_> = Utf32Str::new(line, &mut utf32buf);
        let pattern = create_pattern("zzz");
        let score = pattern.score(haystack, &mut matcher);
        assert_eq!(score, None);
    }

    #[test]
    fn tie_breakers_sort_by_path_when_scores_equal() {
        let mut matches = vec![
            (100, "b_path".to_string()),
            (100, "a_path".to_string()),
            (90, "zzz".to_string()),
        ];

        sort_matches(&mut matches);

        // Highest score first; ties broken alphabetically.
        let expected = vec![
            (100, "a_path".to_string()),
            (100, "b_path".to_string()),
            (90, "zzz".to_string()),
        ];

        assert_eq!(matches, expected);
    }

    #[test]
    fn best_matches_list_keeps_highest_scores_and_counts_every_match() {
        // "xyz" and "cba" cannot match "abc" (wrong letters / wrong order).
        let candidates = ["src/abc.rs", "a/b/c", "xyz", "abcdef", "cba"];
        let pattern = create_pattern("abc");
        let mut list = BestMatchesList::new(
            2,
            pattern.clone(),
            Matcher::new(nucleo_matcher::Config::DEFAULT),
        );
        for candidate in candidates {
            list.insert(candidate);
        }

        // Score each candidate independently to derive the expected outcome.
        let mut matcher = Matcher::new(nucleo_matcher::Config::DEFAULT);
        let mut utf32buf = Vec::<char>::new();
        let mut expected_scores: Vec<u32> = Vec::new();
        for candidate in candidates {
            let haystack = Utf32Str::new(candidate, &mut utf32buf);
            if let Some(score) = pattern.score(haystack, &mut matcher) {
                expected_scores.push(score);
            }
        }
        assert_eq!(expected_scores.len(), 3);
        assert_eq!(list.num_matches, expected_scores.len());

        expected_scores.sort_unstable_by(|a, b| b.cmp(a));
        expected_scores.truncate(2);
        let mut kept_scores: Vec<u32> = list
            .binary_heap
            .iter()
            .map(|Reverse((score, _))| *score)
            .collect();
        kept_scores.sort_unstable_by(|a, b| b.cmp(a));
        assert_eq!(kept_scores, expected_scores);
    }

    #[test]
    fn worker_count_reserves_one_extra_list() {
        let single = create_worker_count(NonZero::<usize>::MIN);
        assert_eq!(single.num_walk_builder_threads, 1);
        assert_eq!(single.num_best_matches_lists, 2);

        let four = create_worker_count(NonZero::<usize>::MIN.saturating_add(3));
        assert_eq!(four.num_walk_builder_threads, 4);
        assert_eq!(four.num_best_matches_lists, 5);
    }
}