use regex::Regex;
use crate::candidate_gate::AllowedSet;
use crate::trigram_index::{DocId, Trigram, TrigramIndex, trigrams_of};
/// Scan budget that `GrepExecutor::new` installs by default: the largest corpus size
/// (as reported by `index.len()`) that `search` will scan in full when the pattern
/// yields no indexable literal.
pub const DEFAULT_MAX_SCAN: usize = 100_000;
/// Saturation constant of `tf_saturate`, which maps a count to `count / (count + GREP_K1)`.
/// NOTE(review): name and value look like the conventional BM25 `k1` - confirm intent.
const GREP_K1: f32 = 1.2;
/// Weight of the document-length normalisation in rank mode:
/// `norm = 1 - GREP_B + GREP_B * (doc_len / avg_len)`.
/// NOTE(review): name and value look like the conventional BM25 `b` - confirm intent.
const GREP_B: f32 = 0.75;
/// Selects how `GrepExecutor::search` reports the documents that match.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GrepMode {
/// Score every matching document; hits come back ordered by descending score,
/// ties broken by ascending document id.
Rank,
/// Only test for a match; every hit gets score `1.0` and `match_count` `1`,
/// and hits come back ordered by ascending document id.
Gate,
}
/// One document that matched a `GrepExecutor::search` pattern.
#[derive(Debug, Clone, PartialEq)]
pub struct GrepHit {
/// Identifier of the matching document.
pub doc_id: DocId,
/// Relevance score: the length-normalised score in `GrepMode::Rank`, always `1.0` in `GrepMode::Gate`.
pub score: f32,
/// Number of regex matches found in the document text in `GrepMode::Rank`;
/// always `1` in `GrepMode::Gate`, where matches are not counted.
pub match_count: usize,
}
/// Outcome of one `GrepExecutor::search` call.
#[derive(Debug, Clone)]
pub struct GrepResults {
/// Matching documents, already ordered and truncated according to the search mode and limit.
pub hits: Vec<GrepHit>,
/// `true` when the candidate documents came from trigram-index lookups; `false` when
/// every document was scanned or the search returned early on an empty allowed set.
pub used_index: bool,
}
impl GrepResults {
    /// Consumes the results and builds an [`AllowedSet`] from the document ids of the hits.
    ///
    /// Scores and match counts are discarded; only `doc_id` is forwarded.
    pub fn into_allowed_set(self) -> AllowedSet {
        let doc_ids = self.hits.into_iter().map(|hit| hit.doc_id);
        AllowedSet::from_iter(doc_ids)
    }
}
/// Errors that a grep search can report.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GrepError {
    /// The pattern could not be compiled; carries the textual compile error.
    InvalidRegex(String),
    /// The pattern has no indexable literal and the corpus is larger than the scan budget.
    DegeneratePattern { corpus: usize, max_scan: usize },
}

impl std::fmt::Display for GrepError {
    /// Writes the human-readable message for the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The regex case is a one-liner, so it leaves early; the remaining variant
        // hands its two numbers to the longer message below.
        let (corpus, max_scan) = match self {
            Self::InvalidRegex(e) => return write!(f, "invalid regex: {e}"),
            Self::DegeneratePattern { corpus, max_scan } => (corpus, max_scan),
        };
        write!(
            f,
            "degenerate pattern (no indexable literal) over a corpus of {corpus} documents \
             exceeds the scan budget of {max_scan}"
        )
    }
}

impl std::error::Error for GrepError {}
/// Runs regex searches over the documents held by a borrowed `TrigramIndex`.
pub struct GrepExecutor<'a> {
/// Index that supplies candidate document ids and the document text to match against.
index: &'a TrigramIndex,
/// Largest corpus size (`index.len()`) that may be scanned in full when the pattern
/// yields no indexable literal; larger corpora make `search` return `DegeneratePattern`.
max_scan: usize,
}
impl<'a> GrepExecutor<'a> {
/// Creates an executor over `index` with the scan budget set to `DEFAULT_MAX_SCAN`.
pub fn new(index: &'a TrigramIndex) -> Self {
Self {
index,
max_scan: DEFAULT_MAX_SCAN,
}
}
/// Builder-style setter: returns the executor with its scan budget replaced by `max_scan`.
pub fn with_max_scan(mut self, max_scan: usize) -> Self {
self.max_scan = max_scan;
self
}
/// Searches the indexed documents for `pattern`, restricted to the ids in `allowed`.
///
/// * `pattern` - regular expression, compiled with `Regex::new`.
/// * `allowed` - only candidates for which `allowed.contains(id)` holds are examined.
/// * `limit` - maximum number of hits to return; `0` means no limit.
/// * `mode` - `GrepMode::Gate` for match-only results, `GrepMode::Rank` for scored results.
///
/// Returns `GrepError::InvalidRegex` when the pattern does not compile (checked before
/// anything else, even for an empty allowed set), and `GrepError::DegeneratePattern`
/// when no literal could be extracted and `index.len()` exceeds the scan budget.
pub fn search(
&self,
pattern: &str,
allowed: &AllowedSet,
limit: usize,
mode: GrepMode,
) -> Result<GrepResults, GrepError> {
let re = Regex::new(pattern).map_err(|e| GrepError::InvalidRegex(e.to_string()))?;
// Nothing can be returned for an empty allowed set, so skip all index work.
if allowed.is_empty() {
return Ok(GrepResults {
hits: Vec::new(),
used_index: false,
});
}
// Literal extraction works on the pattern text with a leading `(?flags)` group removed;
// matching itself always uses the full compiled pattern `re`.
let extract = strip_leading_inline_flags(pattern);
let (terms, is_alternation) = literal_terms(extract);
// (lower-cased term, candidate count) pairs; they feed the per-term weights in rank mode.
let mut term_df: Vec<(String, usize)> = Vec::new();
let (candidates, used_index): (Vec<DocId>, bool) = if terms.is_empty() {
// No indexable literal: every document must be scanned, which is only allowed
// while the corpus fits the scan budget.
if self.index.len() > self.max_scan {
return Err(GrepError::DegeneratePattern {
corpus: self.index.len(),
max_scan: self.max_scan,
});
}
(self.index.documents().map(|(id, _)| id).collect(), false)
} else if is_alternation {
// Alternation of literals: a document qualifies if any branch may occur in it,
// so the per-branch candidate lists are merged, sorted and de-duplicated.
let mut union: Vec<DocId> = Vec::new();
for term in &terms {
let branch = self.index.candidates(&trigrams_of(term));
// Clamp to at least 1 so the weight formula never sees a zero frequency.
term_df.push((term.to_lowercase(), branch.len().max(1)));
union.extend(branch);
}
union.sort_unstable();
union.dedup();
(union, true)
} else {
// All literal runs are mandatory: look up the combined, de-duplicated trigram set.
// NOTE(review): presumably `candidates` returns documents containing every given
// trigram - confirm against `TrigramIndex::candidates`.
let mut trigrams: Vec<Trigram> = Vec::new();
for term in &terms {
// Per-term candidate count, used only for weighting.
let df = self.index.candidates(&trigrams_of(term)).len().max(1);
term_df.push((term.to_lowercase(), df));
trigrams.extend(trigrams_of(term));
}
trigrams.sort_unstable();
trigrams.dedup();
(self.index.candidates(&trigrams), true)
};
// Gate mode: keep every allowed candidate whose text matches; no scoring.
if mode == GrepMode::Gate {
let mut hits: Vec<GrepHit> = Vec::new();
for doc_id in candidates {
if !allowed.contains(doc_id) {
continue;
}
if let Some(text) = self.index.doc_text(doc_id) {
if re.is_match(text) {
hits.push(GrepHit {
doc_id,
score: 1.0,
match_count: 1,
});
}
}
}
// Ascending document id, then apply the limit (0 = unlimited).
hits.sort_by(|a, b| a.doc_id.cmp(&b.doc_id));
if limit > 0 && hits.len() > limit {
hits.truncate(limit);
}
return Ok(GrepResults { hits, used_index });
}
// Rank mode. Per-term weight: ln(1 + (n - df + 0.5) / (df + 0.5)), clamped at 0,
// where n is the corpus size (at least 1) and df the term's candidate count.
let n = self.index.len().max(1) as f32;
let term_idf: Vec<(String, f32)> = term_df
.iter()
.map(|(t, df)| {
let dff = *df as f32;
let idf = (1.0 + (n - dff + 0.5) / (dff + 0.5)).ln();
(t.clone(), idf.max(0.0))
})
.collect();
// A matching document whose score still awaits length normalisation; the average
// length is only known after all documents have been visited.
struct Pending {
doc_id: DocId,
len: f32,
raw: f32,
match_count: usize,
}
let mut pending: Vec<Pending> = Vec::new();
let mut total_len = 0.0f32;
// Per-branch match counters, reused across documents (reset per document below).
let mut counts: Vec<u32> = vec![0; term_idf.len()];
for doc_id in candidates {
if !allowed.contains(doc_id) {
continue;
}
let Some(text) = self.index.doc_text(doc_id) else {
continue;
};
let mut match_count = 0usize;
if is_alternation {
for c in counts.iter_mut() {
*c = 0;
}
// Attribute each regex match to the first branch literal it equals
// (ASCII case-insensitively). Matches equal to no branch literal still
// count towards `match_count` but add nothing to the score.
for m in re.find_iter(text) {
match_count += 1;
let ms = m.as_str();
for (i, (term_lc, _)) in term_idf.iter().enumerate() {
if eq_ci_ascii(ms, term_lc) {
counts[i] += 1;
break;
}
}
}
} else {
match_count = re.find_iter(text).count();
}
// Candidates are only a pre-filter; documents without a real match are dropped.
if match_count == 0 {
continue;
}
// Document length in characters, at least 1.
let len = text.chars().count().max(1) as f32;
let raw = if term_idf.is_empty() {
// Full-scan case: no terms to weight, so use the saturated match count alone.
tf_saturate(match_count as f32)
} else if is_alternation {
// Sum of weight * saturated per-branch match count.
let mut s = 0.0f32;
for (i, (_, idf)) in term_idf.iter().enumerate() {
if counts[i] > 0 {
s += idf * tf_saturate(counts[i] as f32);
}
}
s
} else {
// Sum of weight * saturated occurrence count of each mandatory literal,
// counted ASCII case-insensitively in the document text.
let mut s = 0.0f32;
for (term_lc, idf) in &term_idf {
let c = count_ci_ascii(text, term_lc);
if c > 0 {
s += idf * tf_saturate(c as f32);
}
}
s
};
total_len += len;
pending.push(Pending {
doc_id,
len,
raw,
match_count,
});
}
// Average length is taken over the matching documents only, never below 1.0.
let avg_len = if pending.is_empty() {
1.0
} else {
(total_len / pending.len() as f32).max(1.0)
};
let mut hits: Vec<GrepHit> = pending
.into_iter()
.map(|p| {
// Length normalisation: documents longer than average are scaled down.
let norm = 1.0 - GREP_B + GREP_B * (p.len / avg_len);
GrepHit {
doc_id: p.doc_id,
score: if norm > 0.0 { p.raw / norm } else { p.raw },
match_count: p.match_count,
}
})
.collect();
// Highest score first; equal scores are ordered by ascending document id.
hits.sort_by(|a, b| {
b.score
.total_cmp(&a.score)
.then_with(|| a.doc_id.cmp(&b.doc_id))
});
if limit > 0 && hits.len() > limit {
hits.truncate(limit);
}
Ok(GrepResults { hits, used_index })
}
}
/// Maps a raw occurrence count to `count / (count + GREP_K1)`, so that additional
/// occurrences add progressively less to the result.
fn tf_saturate(count: f32) -> f32 {
    let denominator = count + GREP_K1;
    count / denominator
}
/// Counts non-overlapping occurrences of `needle` in `hay`, folding ASCII case of `hay` only.
///
/// `needle` is compared as-is, so it must already be lower-case to match anything that
/// contains ASCII letters. Returns `0` for an empty needle or a needle longer than `hay`.
fn count_ci_ascii(hay: &str, needle: &str) -> usize {
    let hay_bytes = hay.as_bytes();
    let pat = needle.as_bytes();
    let width = pat.len();
    if width == 0 || width > hay_bytes.len() {
        return 0;
    }

    let mut found = 0;
    let mut start = 0;
    // `get` yields `None` as soon as a full-width window no longer fits, ending the scan.
    while let Some(window) = hay_bytes.get(start..start + width) {
        let is_match = window
            .iter()
            .zip(pat)
            .all(|(&h, &p)| h.to_ascii_lowercase() == p);
        if is_match {
            found += 1;
            // Jump past the whole occurrence so matches never overlap.
            start += width;
        } else {
            start += 1;
        }
    }
    found
}
/// Returns `true` when `a`, with its ASCII letters lower-cased, is byte-for-byte equal to `b`.
///
/// Only `a` is folded; `b` is expected to be lower-case already.
fn eq_ci_ascii(a: &str, b: &str) -> bool {
    if a.len() != b.len() {
        return false;
    }
    for (lhs, rhs) in a.bytes().zip(b.bytes()) {
        if lhs.to_ascii_lowercase() != rhs {
            return false;
        }
    }
    true
}
/// Removes a leading inline flag group such as `(?i)` or `(?i-u)` from `pattern`,
/// so that literal extraction can look at the remaining pattern text.
///
/// The pattern is returned unchanged when:
/// * it does not start with `(?`, or the group is not closed;
/// * the group contains anything but ASCII letters and `-` (for example the scoped
///   form `(?i:...)`);
/// * the group switches on `x` (verbose mode). In verbose mode whitespace and
///   `#` comments in the rest of the pattern are not literal text, so the remainder
///   must not be handed to the literal extractor as if it were.
fn strip_leading_inline_flags(pattern: &str) -> &str {
    let Some(rest) = pattern.strip_prefix("(?") else {
        return pattern;
    };
    let Some(close) = rest.find(')') else {
        return pattern;
    };
    let flags = &rest[..close];
    let well_formed =
        !flags.is_empty() && flags.bytes().all(|b| b.is_ascii_alphabetic() || b == b'-');
    if !well_formed {
        return pattern;
    }
    // Flags written before the first `-` are switched on; those after it are switched off.
    let enabled = flags.split('-').next().unwrap_or("");
    if enabled.contains('x') {
        return pattern;
    }
    &rest[close + 1..]
}
/// Extracts the literal terms of `pattern` together with a flag saying how they combine.
///
/// Returns `(branches, true)` when the pattern is an alternation of literals,
/// `(runs, false)` when it has mandatory literal runs, and `(vec![], false)` when
/// neither extraction succeeds.
fn literal_terms(pattern: &str) -> (Vec<String>, bool) {
    match literal_alternation(pattern) {
        Some(branches) => (branches, true),
        None => (required_literals(pattern).unwrap_or_default(), false),
    }
}
/// Splits `pattern` at every `|` that is not escaped by a preceding backslash.
///
/// A backslash escapes exactly the next character, so `\\|` (escaped backslash followed
/// by a pipe) still splits while `\|` does not. Always returns at least one piece.
fn split_unescaped_pipes(pattern: &str) -> Vec<&str> {
    let mut pieces = Vec::new();
    let mut piece_start = 0;
    let mut escaped = false;
    for (idx, ch) in pattern.char_indices() {
        if escaped {
            escaped = false;
            continue;
        }
        match ch {
            '\\' => escaped = true,
            '|' => {
                pieces.push(&pattern[piece_start..idx]);
                // `|` is one byte wide, so the next piece starts right after it.
                piece_start = idx + 1;
            }
            _ => {}
        }
    }
    pieces.push(&pattern[piece_start..]);
    pieces
}

/// Recognises a top-level alternation whose every branch carries exactly one mandatory
/// literal, and returns those literals in branch order.
///
/// Returns `None` when the pattern has no unescaped `|`, contains any grouping, class or
/// repetition bracket (`( ) [ ] { }`), or when some branch yields zero or several
/// mandatory literals. An escaped pipe (`\|`) is literal text, not a branch separator,
/// so `foo\|bar` is not an alternation.
fn literal_alternation(pattern: &str) -> Option<Vec<String>> {
    if pattern.contains(['(', ')', '[', ']', '{', '}']) {
        return None;
    }
    let branches = split_unescaped_pipes(pattern);
    if branches.len() < 2 {
        return None;
    }
    // Collecting into `Option` stops at the first branch that does not qualify.
    branches
        .into_iter()
        .map(|branch| {
            let mut lits = required_literals(branch)?;
            if lits.len() == 1 {
                lits.pop()
            } else {
                None
            }
        })
        .collect()
}
/// Returns the sorted, de-duplicated trigrams of every mandatory literal of `pattern`.
///
/// Returns `None` when no mandatory literal can be extracted or when the literals
/// produce no trigram at all.
pub fn required_trigrams(pattern: &str) -> Option<Vec<Trigram>> {
    let literals = required_literals(pattern)?;
    let mut collected: Vec<Trigram> = Vec::new();
    literals
        .iter()
        .for_each(|literal| collected.extend(trigrams_of(literal)));
    collected.sort_unstable();
    collected.dedup();
    if collected.is_empty() {
        None
    } else {
        Some(collected)
    }
}
/// Extracts the literal runs that every match of `pattern` must contain.
///
/// The pattern text is scanned once; plain characters accumulate into the current run
/// and regex operators end it:
/// * `|`, grouping, classes and counted repetition (`( ) [ ] { }`) make the pattern
///   un-analysable and yield `None`;
/// * `*` and `?` make the preceding character optional, so it is removed from the run;
/// * `.`, `^`, `$` and `+` end the run but keep the preceding character;
/// * `\` followed by an ASCII letter or digit (`\d`, `\w`, `\b`, ...) ends the run;
/// * `\x`, `\u`, `\U`, `\p` and `\P` take arguments (hex digits, a class name) that are
///   not literal text, so the pattern is treated as un-analysable and yields `None`;
/// * `\<` and `\>` are word-boundary assertions and end the run;
/// * `\` followed by any other character contributes that character literally.
///
/// Runs shorter than three characters cannot form a trigram and are dropped.
/// Returns `None` when no run of at least three characters remains.
fn required_literals(pattern: &str) -> Option<Vec<String>> {
    let mut runs: Vec<String> = Vec::new();
    let mut cur = String::new();
    let mut chars = pattern.chars();
    while let Some(c) = chars.next() {
        match c {
            '|' | '(' | ')' | '[' | ']' | '{' | '}' => return None,
            '\\' => match chars.next() {
                // The characters after these escapes are arguments, not literal text;
                // pushing them into a run would demand text that matches need not contain.
                Some('x' | 'u' | 'U' | 'p' | 'P') => return None,
                // Start-of-word / end-of-word assertions match no character.
                Some('<' | '>') => flush(&mut cur, &mut runs),
                Some(n) if n.is_ascii_alphanumeric() => flush(&mut cur, &mut runs),
                Some(n) => cur.push(n),
                None => {}
            },
            '*' | '?' => {
                // The quantified character may be absent, so it is not mandatory.
                cur.pop();
                flush(&mut cur, &mut runs);
            }
            '.' | '^' | '$' | '+' => flush(&mut cur, &mut runs),
            _ => cur.push(c),
        }
    }
    flush(&mut cur, &mut runs);
    let mandatory: Vec<String> = runs
        .into_iter()
        .filter(|r| r.chars().count() >= 3)
        .collect();
    if mandatory.is_empty() {
        None
    } else {
        Some(mandatory)
    }
}

/// Moves the current run into `runs` and leaves `cur` empty; does nothing for an empty run.
fn flush(cur: &mut String, runs: &mut Vec<String>) {
    if !cur.is_empty() {
        runs.push(std::mem::take(cur));
    }
}
// Unit tests for literal extraction, candidate selection and ranking.
#[cfg(test)]
mod tests {
use super::*;
// Shared fixture: five small documents with ids 1 to 5. Documents 1 and 2 contain
// lower-case "parse", document 5 contains upper-case "PARSE", document 4 contains "timeout".
fn build_index() -> TrigramIndex {
let mut idx = TrigramIndex::new();
idx.insert(1, "fn parse_query(input: &str) -> Query");
idx.insert(2, "let parser = build();");
idx.insert(3, "// completely unrelated comment");
idx.insert(4, "error: connection timeout occurred");
idx.insert(5, "PARSE_MODE constant");
idx
}
// Mandatory-literal extraction: plain literals, wildcard splits, escapes, optional
// characters, and the constructs that make a pattern un-analysable.
#[test]
fn test_required_literals_extraction() {
assert_eq!(required_literals("parse"), Some(vec!["parse".to_string()]));
assert_eq!(
required_literals("parse.*query"),
Some(vec!["parse".to_string(), "query".to_string()])
);
assert_eq!(
required_literals(r"config\.toml"),
Some(vec!["config.toml".to_string()])
);
assert_eq!(required_literals("colou?r"), Some(vec!["colo".to_string()]));
assert_eq!(required_literals("cat|dog"), None);
assert_eq!(required_literals("(foo)bar"), None);
assert_eq!(required_literals("a[bc]def"), None);
// Both runs are shorter than three characters, so nothing is indexable.
assert_eq!(required_literals("a.b"), None);
}
// A plain literal is served from the index, and the regex stays case-sensitive.
#[test]
fn test_grep_substring_uses_index() {
let idx = build_index();
let exec = GrepExecutor::new(&idx);
let res = exec
.search("parse", &AllowedSet::All, 0, GrepMode::Rank)
.unwrap();
assert!(res.used_index, "a pure literal must use the trigram index");
let ids: Vec<DocId> = res.hits.iter().map(|h| h.doc_id).collect();
assert!(ids.contains(&1));
assert!(ids.contains(&2));
assert!(!ids.contains(&5));
assert!(!ids.contains(&3));
}
// With a leading `(?i)` the upper-case document must be found as well.
#[test]
fn test_grep_case_insensitive_pattern() {
let idx = build_index();
let exec = GrepExecutor::new(&idx);
let res = exec
.search("(?i)parse", &AllowedSet::All, 0, GrepMode::Rank)
.unwrap();
let ids: Vec<DocId> = res.hits.iter().map(|h| h.doc_id).collect();
assert!(ids.contains(&5));
}
// Two mandatory literals joined by a wildcard: only the document holding both matches.
#[test]
fn test_grep_regex_with_wildcard() {
let idx = build_index();
let exec = GrepExecutor::new(&idx);
let res = exec
.search("parse.*query", &AllowedSet::All, 0, GrepMode::Rank)
.unwrap();
assert!(res.used_index);
let ids: Vec<DocId> = res.hits.iter().map(|h| h.doc_id).collect();
assert_eq!(ids, vec![1]);
}
// Hits are restricted to the ids of the allowed set.
#[test]
fn test_allowed_set_pushdown() {
let idx = build_index();
let exec = GrepExecutor::new(&idx);
let allowed = AllowedSet::from_iter([2u64]);
let res = exec.search("parse", &allowed, 0, GrepMode::Rank).unwrap();
let ids: Vec<DocId> = res.hits.iter().map(|h| h.doc_id).collect();
assert_eq!(ids, vec![2]);
}
// Gate-mode results convert into an allowed set holding exactly the matching ids.
#[test]
fn test_gate_mode_to_allowed_set() {
let idx = build_index();
let exec = GrepExecutor::new(&idx);
let res = exec
.search("parse", &AllowedSet::All, 0, GrepMode::Gate)
.unwrap();
let gate = res.into_allowed_set();
assert!(gate.contains(1));
assert!(gate.contains(2));
assert!(!gate.contains(3));
}
// A pattern that does not compile is reported as `InvalidRegex`.
#[test]
fn test_invalid_regex_errors() {
let idx = build_index();
let exec = GrepExecutor::new(&idx);
let err = exec
.search("(unclosed", &AllowedSet::All, 0, GrepMode::Rank)
.unwrap_err();
assert!(matches!(err, GrepError::InvalidRegex(_)));
}
// No indexable literal and a corpus larger than the budget: the search is refused.
#[test]
fn test_degenerate_pattern_rejected_over_budget() {
let idx = build_index();
let exec = GrepExecutor::new(&idx).with_max_scan(1);
let err = exec
.search("a.", &AllowedSet::All, 0, GrepMode::Rank)
.unwrap_err();
assert!(matches!(err, GrepError::DegeneratePattern { .. }));
}
// No indexable literal but a corpus within budget: every document is scanned.
#[test]
fn test_degenerate_pattern_scans_within_budget() {
let idx = build_index();
let exec = GrepExecutor::new(&idx).with_max_scan(1000);
let res = exec
.search("er.", &AllowedSet::All, 0, GrepMode::Rank)
.unwrap();
assert!(!res.used_index, "degenerate pattern must full-scan");
assert!(!res.hits.is_empty());
}
// Alternation detection: needs a `|`, no grouping, and exactly one literal of at
// least three characters per branch.
#[test]
fn test_literal_alternation_extraction() {
assert_eq!(
literal_alternation("parse|timeout"),
Some(vec!["parse".to_string(), "timeout".to_string()])
);
assert_eq!(literal_alternation("parse"), None);
assert_eq!(literal_alternation("(parse|query)x"), None);
assert_eq!(literal_alternation("parse|ab"), None);
assert_eq!(literal_alternation("parse|foo.*bar"), None);
}
// Only a leading, flags-only group is stripped; scoped groups and plain groups stay.
#[test]
fn test_strip_leading_inline_flags() {
assert_eq!(
strip_leading_inline_flags("(?i)parse|timeout"),
"parse|timeout"
);
assert_eq!(strip_leading_inline_flags("(?ims)parse"), "parse");
assert_eq!(strip_leading_inline_flags("(?i-u)parse|x"), "parse|x");
assert_eq!(strip_leading_inline_flags("(?i:parse|x)y"), "(?i:parse|x)y");
assert_eq!(strip_leading_inline_flags("parse|timeout"), "parse|timeout");
assert_eq!(strip_leading_inline_flags("(parse)"), "(parse)");
}
// A flagged alternation still goes through the index and honours case-insensitivity.
#[test]
fn test_case_insensitive_alternation_uses_index() {
let idx = build_index();
let exec = GrepExecutor::new(&idx);
let res = exec
.search("(?i)parse|timeout", &AllowedSet::All, 0, GrepMode::Rank)
.unwrap();
assert!(
res.used_index,
"flagged alternation must still use the index"
);
let ids: Vec<DocId> = res.hits.iter().map(|h| h.doc_id).collect();
assert!(ids.contains(&1));
assert!(ids.contains(&4));
assert!(
ids.contains(&5),
"case-insensitive match must include PARSE"
);
}
// An unflagged alternation returns documents from both branches and stays case-sensitive.
#[test]
fn test_alternation_uses_index_and_unions_branches() {
let idx = build_index();
let exec = GrepExecutor::new(&idx);
let res = exec
.search("parse|timeout", &AllowedSet::All, 0, GrepMode::Rank)
.unwrap();
assert!(res.used_index, "literal alternation must use the index");
let ids: Vec<DocId> = res.hits.iter().map(|h| h.doc_id).collect();
assert!(ids.contains(&1));
assert!(ids.contains(&2));
assert!(ids.contains(&4));
assert!(!ids.contains(&5));
}
// Ranking: a single occurrence of a rare term must beat many occurrences of a common one.
#[test]
fn test_rank_prefers_rarer_term_over_common_frequent_term() {
let mut idx = TrigramIndex::new();
idx.insert(1, "alpha alpha alpha alpha");
for i in 2..=8u64 {
idx.insert(i, "alpha context");
}
idx.insert(9, "zeta marker present here");
let exec = GrepExecutor::new(&idx);
let res = exec
.search("alpha|zeta", &AllowedSet::All, 0, GrepMode::Rank)
.unwrap();
assert!(res.used_index);
assert_eq!(res.hits.first().map(|h| h.doc_id), Some(9));
let score_rare = res.hits.iter().find(|h| h.doc_id == 9).unwrap().score;
let score_common = res.hits.iter().find(|h| h.doc_id == 1).unwrap().score;
assert!(
score_rare > score_common,
"rare-term doc {score_rare} must outrank frequent common-term doc {score_common}"
);
}
// Ranking: more matches score higher, but far less than proportionally.
#[test]
fn test_rank_saturates_repeated_matches() {
let mut idx = TrigramIndex::new();
idx.insert(1, "zebra xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx");
idx.insert(2, "zebra zebra zebra zebra zebra zebra zebra zebra");
let exec = GrepExecutor::new(&idx);
let res = exec
.search("zebra", &AllowedSet::All, 0, GrepMode::Rank)
.unwrap();
let s1 = res.hits.iter().find(|h| h.doc_id == 1).unwrap().score;
let s2 = res.hits.iter().find(|h| h.doc_id == 2).unwrap().score;
assert!(s2 > s1, "more matches should still score higher");
assert!(
s2 < 4.0 * s1,
"saturation must keep 8x matches well under 8x score (got {s2} vs {s1})"
);
}
}