Skip to main content

krait/commands/
search.rs

1use std::path::{Path, PathBuf};
2
3use anyhow::Context as _;
4use rayon::prelude::*;
5use regex::{Regex, RegexBuilder};
6use serde::Serialize;
7
8use crate::detect::Language;
9
/// Size, in bytes, of the file prefix inspected when deciding whether a file is binary.
///
/// A NUL byte anywhere inside this prefix causes the file to be skipped.
/// Intentionally smaller than read.rs's `BINARY_SCAN_SIZE` — search is a hot path.
const BINARY_SCAN_BYTES: usize = 512;
13
14/// A single match found during search.
15#[derive(Debug, Serialize)]
16pub struct SearchMatch {
17    pub path: String,
18    pub line: u32,
19    pub column: u32,
20    pub preview: String,
21    pub context_before: Vec<String>,
22    pub context_after: Vec<String>,
23}
24
25/// Aggregated search results.
26#[derive(Debug, Serialize)]
27pub struct SearchOutput {
28    pub matches: Vec<SearchMatch>,
29    pub total_matches: usize,
30    pub files_searched: usize,
31    pub files_with_matches: usize,
32    pub truncated: bool,
33}
34
/// Knobs that control how a search is carried out.
// The booleans are independent on/off switches, so a plain struct of flags is
// clearer here than an enum or bitset.
#[allow(clippy::struct_excessive_bools)]
pub struct SearchOptions {
    /// Pattern to look for: a regex, or plain text when `literal` is set.
    pub pattern: String,
    /// Root to search instead of the project root, when given.
    pub path: Option<PathBuf>,
    /// Match case-insensitively.
    pub ignore_case: bool,
    /// Require word boundaries (`\b`) on both sides of the pattern.
    pub word: bool,
    /// Treat `pattern` as literal text rather than a regex.
    pub literal: bool,
    /// Number of lines of context to capture before and after each match.
    pub context: u32,
    /// Not consulted by the search itself in this module; presumably tells the
    /// output layer to list only file names — confirm against the caller.
    pub files_only: bool,
    /// Language name (e.g. `"ts"`, `"rust"`) restricting the search to that
    /// language's file extensions.
    pub lang_filter: Option<String>,
    /// Maximum number of matches to return before the result is truncated.
    pub max_matches: usize,
}
48
49/// Run the search and return aggregated results.
50///
51/// # Errors
52/// Returns an error if the regex pattern is invalid or file walking fails.
53pub fn run(opts: &SearchOptions, project_root: &Path) -> anyhow::Result<SearchOutput> {
54    let re = build_regex(opts)?;
55    let search_root = opts.path.as_deref().unwrap_or(project_root);
56    let files = collect_files(search_root, opts)?;
57    let files_searched = files.len();
58
59    // Parallel search: each file returns a Vec<SearchMatch>
60    let file_results: Vec<Vec<SearchMatch>> = files
61        .par_iter()
62        .map(|path| search_file(path, search_root, project_root, &re, opts))
63        .collect();
64
65    // Flatten, sort, truncate
66    let mut matches: Vec<SearchMatch> = Vec::new();
67    let mut files_with_matches: usize = 0;
68    let mut truncated = false;
69
70    for file_matches in file_results {
71        if file_matches.is_empty() {
72            continue;
73        }
74        files_with_matches += 1;
75        for m in file_matches {
76            if matches.len() >= opts.max_matches {
77                truncated = true;
78                break;
79            }
80            matches.push(m);
81        }
82        if truncated {
83            break;
84        }
85    }
86
87    let total_matches = matches.len();
88    Ok(SearchOutput {
89        matches,
90        total_matches,
91        files_searched,
92        files_with_matches,
93        truncated,
94    })
95}
96
97fn build_regex(opts: &SearchOptions) -> anyhow::Result<Regex> {
98    let pat = if opts.literal {
99        regex::escape(&opts.pattern)
100    } else {
101        normalize_grep_escapes(&opts.pattern)
102    };
103
104    let pat = if opts.word {
105        format!(r"\b{pat}\b")
106    } else {
107        pat
108    };
109
110    RegexBuilder::new(&pat)
111        .case_insensitive(opts.ignore_case)
112        .build()
113        .with_context(|| format!("invalid regex pattern: {}", opts.pattern))
114}
115
/// Convert grep/BRE-style escape sequences to Rust regex (ERE) syntax.
///
/// Agents trained on Unix tools often emit `\|` for alternation, `\+` for
/// one-or-more, etc. In Rust regex these would be taken literally — normalize
/// them silently so patterns just work. The rewritten sequences are `\|`, `\+`,
/// `\?`, `\(` and `\)`, each of which loses its backslash.
///
/// Every other escape is copied through unchanged *as a pair*, so an escaped
/// backslash never lends its second half to the following character: `\\|`
/// stays "literal backslash, then alternation" instead of collapsing to `\|`.
/// A lone trailing backslash is preserved as-is.
fn normalize_grep_escapes(pattern: &str) -> String {
    let mut out = String::with_capacity(pattern.len());
    let mut chars = pattern.chars();
    while let Some(c) = chars.next() {
        if c != '\\' {
            out.push(c);
            continue;
        }
        // Consume the escaped character together with its backslash.
        match chars.next() {
            Some(op @ ('|' | '+' | '?' | '(' | ')')) => out.push(op),
            Some(other) => {
                out.push('\\');
                out.push(other);
            }
            None => out.push('\\'),
        }
    }
    out
}
155
156fn collect_files(root: &Path, opts: &SearchOptions) -> anyhow::Result<Vec<PathBuf>> {
157    let extensions: Option<&[&str]> = opts.lang_filter.as_deref().map(extensions_for_lang);
158
159    let mut builder = ignore::WalkBuilder::new(root);
160    builder
161        .hidden(true)
162        .git_ignore(true)
163        .git_global(false)
164        .git_exclude(true);
165
166    let mut files: Vec<PathBuf> = Vec::new();
167    for entry in builder.build() {
168        let entry = entry?;
169        if !entry.file_type().is_some_and(|ft| ft.is_file()) {
170            continue;
171        }
172        let path = entry.path();
173        if let Some(exts) = extensions {
174            match path.extension().and_then(|e| e.to_str()) {
175                Some(ext) if exts.contains(&ext) => {}
176                _ => continue,
177            }
178        }
179        files.push(path.to_path_buf());
180    }
181
182    files.sort();
183    Ok(files)
184}
185
186fn search_file(
187    path: &Path,
188    search_root: &Path,
189    project_root: &Path,
190    re: &Regex,
191    opts: &SearchOptions,
192) -> Vec<SearchMatch> {
193    let Ok(bytes) = std::fs::read(path) else {
194        return vec![];
195    };
196
197    // Skip binary files: null byte in first BINARY_SCAN_BYTES
198    if bytes[..bytes.len().min(BINARY_SCAN_BYTES)].contains(&0u8) {
199        return vec![];
200    }
201
202    let Ok(content) = std::str::from_utf8(&bytes) else {
203        return vec![];
204    };
205
206    // Relative path for display
207    let rel = path
208        .strip_prefix(search_root)
209        .or_else(|_| path.strip_prefix(project_root))
210        .unwrap_or(path)
211        .to_string_lossy()
212        .into_owned();
213
214    let lines: Vec<&str> = content.lines().collect();
215    let mut result: Vec<SearchMatch> = Vec::new();
216
217    for (idx, line) in lines.iter().enumerate() {
218        let Some(m) = re.find(line) else { continue };
219
220        let line_no = u32::try_from(idx + 1).unwrap_or(u32::MAX);
221        let col = u32::try_from(m.start() + 1).unwrap_or(1);
222
223        let (context_before, context_after) = if opts.context > 0 {
224            let ctx = opts.context as usize;
225            let before: Vec<String> = lines[idx.saturating_sub(ctx)..idx]
226                .iter()
227                .map(ToString::to_string)
228                .collect();
229            let after: Vec<String> = lines[(idx + 1)..(idx + 1 + ctx).min(lines.len())]
230                .iter()
231                .map(ToString::to_string)
232                .collect();
233            (before, after)
234        } else {
235            (vec![], vec![])
236        };
237
238        result.push(SearchMatch {
239            path: rel.clone(),
240            line: line_no,
241            column: col,
242            preview: line.to_string(),
243            context_before,
244            context_after,
245        });
246    }
247
248    result
249}
250
251/// Map CLI language flag to file extensions via `Language::extensions()` — single source of truth.
252fn extensions_for_lang(lang: &str) -> &'static [&'static str] {
253    match lang {
254        "ts" | "typescript" => Language::TypeScript.extensions(),
255        "js" | "javascript" => Language::JavaScript.extensions(),
256        "rs" | "rust" => Language::Rust.extensions(),
257        "go" => Language::Go.extensions(),
258        "c" | "cpp" | "c++" | "cxx" => Language::Cpp.extensions(),
259        _ => &[],
260    }
261}
262
#[cfg(test)]
mod tests {
    use std::fs;

    use tempfile::TempDir;

    use super::*;

    /// Build a temporary project containing the given `(relative path, content)` files.
    fn make_project(files: &[(&str, &str)]) -> TempDir {
        let dir = tempfile::tempdir().unwrap();
        for &(name, content) in files {
            let target = dir.path().join(name);
            if let Some(parent) = target.parent() {
                fs::create_dir_all(parent).unwrap();
            }
            fs::write(&target, content).unwrap();
        }
        dir
    }

    /// Default options for `pattern`: regex mode, no context, cap of 200 matches.
    fn opts(pattern: &str) -> SearchOptions {
        SearchOptions {
            pattern: pattern.to_string(),
            path: None,
            ignore_case: false,
            word: false,
            literal: false,
            context: 0,
            files_only: false,
            lang_filter: None,
            max_matches: 200,
        }
    }

    /// Run a search rooted at the temporary project, panicking on failure.
    fn search(dir: &TempDir, options: &SearchOptions) -> SearchOutput {
        run(options, dir.path()).unwrap()
    }

    #[test]
    fn finds_literal_match() {
        let dir = make_project(&[("src/lib.rs", "fn hello() {}\nfn world() {}")]);
        let mut options = opts("hello");
        options.literal = true;

        let out = search(&dir, &options);
        assert_eq!(out.total_matches, 1);
        assert_eq!(out.matches[0].line, 1);
        assert!(out.matches[0].preview.contains("hello"));
    }

    #[test]
    fn finds_regex_match() {
        let dir = make_project(&[("a.rs", "foo123\nbar456")]);
        let out = search(&dir, &opts(r"\d+"));
        assert_eq!(out.total_matches, 2);
    }

    #[test]
    fn ignore_case_works() {
        let dir = make_project(&[("a.rs", "Hello\nhello\nHELLO")]);
        let mut options = opts("hello");
        options.ignore_case = true;

        let out = search(&dir, &options);
        assert_eq!(out.total_matches, 3);
    }

    #[test]
    fn word_boundary_works() {
        let dir = make_project(&[("a.rs", "foobar\nfoo bar\nfoo")]);
        let mut options = opts("foo");
        options.word = true;

        let out = search(&dir, &options);
        // "foobar" should NOT match, "foo bar" and "foo" should match
        assert_eq!(out.total_matches, 2);
    }

    #[test]
    fn respects_max_matches() {
        let content: String = (1..=10).map(|i| format!("line{i}\n")).collect();
        let dir = make_project(&[("a.rs", content.as_str())]);
        let mut options = opts("line");
        options.max_matches = 3;

        let out = search(&dir, &options);
        assert_eq!(out.total_matches, 3);
        assert!(out.truncated);
    }

    #[test]
    fn skips_binary_files() {
        let dir = tempfile::tempdir().unwrap();
        let mut binary = vec![0u8; 100];
        binary.extend_from_slice(b"hello");
        fs::write(dir.path().join("file.bin"), &binary).unwrap();

        let out = search(&dir, &opts("hello"));
        assert_eq!(out.total_matches, 0);
    }

    #[test]
    fn lang_filter_ts_only() {
        let dir = make_project(&[("a.ts", "const foo = 1;"), ("b.rs", "const foo: i32 = 1;")]);
        let mut options = opts("foo");
        options.lang_filter = Some("ts".to_string());

        let out = search(&dir, &options);
        assert_eq!(out.total_matches, 1);
        assert!(out.matches[0].path.ends_with("a.ts"));
    }

    #[test]
    fn context_lines_correct() {
        let dir = make_project(&[("a.rs", "line1\nline2\ntarget\nline4\nline5")]);
        let mut options = opts("target");
        options.context = 1;

        let out = search(&dir, &options);
        assert_eq!(out.total_matches, 1);
        let hit = &out.matches[0];
        assert_eq!(hit.context_before, vec!["line2"]);
        assert_eq!(hit.context_after, vec!["line4"]);
    }

    #[test]
    fn files_only_mode() {
        let dir = make_project(&[("a.rs", "needle"), ("b.rs", "haystack")]);
        let mut options = opts("needle");
        options.files_only = true;

        let out = search(&dir, &options);
        assert_eq!(out.files_with_matches, 1);
    }

    #[test]
    fn no_matches_returns_empty() {
        let dir = make_project(&[("a.rs", "hello world")]);
        let out = search(&dir, &opts("nonexistent_xyz"));
        assert_eq!(out.total_matches, 0);
        assert!(!out.truncated);
    }

    #[test]
    fn grep_escape_alternation() {
        // \| is grep BRE syntax; should be treated as | (alternation)
        let dir = make_project(&[("a.ts", "import foo\nfrom bar")]);
        let out = search(&dir, &opts(r"foo\|bar"));
        assert_eq!(out.total_matches, 2);
    }

    #[test]
    fn grep_escape_plus_and_parens() {
        let cases = [(r"foo\+", "foo+"), (r"\(foo\)", "(foo)"), (r"a\?b", "a?b")];
        for (input, expected) in cases {
            assert_eq!(normalize_grep_escapes(input), expected);
        }
    }

    #[test]
    fn real_pipe_unaffected() {
        // a plain | should still work as alternation
        let dir = make_project(&[("a.ts", "import foo\nfrom bar")]);
        let out = search(&dir, &opts("foo|bar"));
        assert_eq!(out.total_matches, 2);
    }
}
451}