use crate::bloom::BloomFilter;
use crate::postings::Postings;
use crate::prefix_extract;
use crate::regex_ast::Ast;
use crate::trigram;
/// Trigram-based pre-filter used to prune documents before approximate
/// (edit-tolerant) pattern matching.
///
/// Holds hashed trigrams taken from the pattern's literal runs, together with
/// the number of them a document must contain to remain a candidate.
#[derive(Debug, Clone)]
pub struct ApproxFilter {
    /// Trigram hashes of the pattern's literal runs. As produced by
    /// `ApproxFilter::build`, this list is sorted and deduplicated.
    pub trigrams: Vec<u64>,
    /// Minimum number of entries of `trigrams` a document must contain to
    /// pass the filter; `0` means the filter is inactive and rejects nothing.
    pub min_required: usize,
}
impl ApproxFilter {
    /// Builds a trigram pre-filter for matching `ast` with at most `k` edits.
    ///
    /// Each literal run reported by `prefix_extract::extract_literal_runs` is
    /// split into overlapping windows of `trigram::TRIGRAM_LEN` bytes. Every
    /// window is hashed, and the hashes are sorted and deduplicated.
    ///
    /// The threshold follows the q-gram lemma, applied to *distinct*
    /// trigrams:
    /// - One edit can invalidate at most `TRIGRAM_LEN` overlapping windows.
    /// - A distinct trigram is missing from a matching text only when every
    ///   one of its windows was invalidated.
    /// - So at most `k * TRIGRAM_LEN` distinct trigrams can be lost.
    ///
    /// A matching document therefore still contains at least
    /// `distinct - k * TRIGRAM_LEN` of them (saturating at zero, which
    /// deactivates the filter).
    ///
    /// NOTE(review): the bound assumes each edit is a single-character
    /// insertion, deletion or substitution (a transposition counted as one
    /// edit could touch one more window). It also assumes every literal run
    /// returned by `extract_literal_runs` must appear in a match. Confirm
    /// both against the matcher and the extractor.
    #[must_use]
    pub fn build(ast: &Ast, k: u16) -> Self {
        let runs = prefix_extract::extract_literal_runs(ast);
        let mut trigrams: Vec<u64> = Vec::new();
        for run in &runs {
            if run.len() < trigram::TRIGRAM_LEN {
                continue;
            }
            for w in run.windows(trigram::TRIGRAM_LEN) {
                trigrams.push(trigram::hash_trigram(w));
            }
        }
        trigrams.sort_unstable();
        trigrams.dedup();

        // Base the threshold on *distinct* trigrams. `passes` counts each
        // distinct hash at most once, so a threshold derived from occurrence
        // counts (which include repeats) could demand more distinct hits than
        // a genuine k-edit match is guaranteed to supply.
        let destroyed = usize::from(k).saturating_mul(trigram::TRIGRAM_LEN);
        let min_required = trigrams.len().saturating_sub(destroyed);
        Self {
            trigrams,
            min_required,
        }
    }

    /// Returns `true` when the filter can actually reject documents, i.e. it
    /// has trigrams and requires at least one of them to be present.
    #[must_use]
    pub fn is_active(&self) -> bool {
        self.min_required > 0 && !self.trigrams.is_empty()
    }

    /// Returns the document ids listed in the posting list of any of this
    /// filter's trigrams.
    ///
    /// When the filter is active (`min_required >= 1`), every document that
    /// can pass the filter contains at least one trigram, so this union is a
    /// superset of the passing documents.
    ///
    /// NOTE(review): when the filter is inactive, documents containing none of
    /// the trigrams may still match, so this union is NOT a safe candidate
    /// set. Callers are expected to check [`Self::is_active`] first.
    #[must_use]
    pub fn candidates(&self, postings: &Postings) -> Vec<u32> {
        let union = postings.union(&self.trigrams);
        union.iter().collect()
    }

    /// Checks a document's Bloom filter against the trigram threshold.
    ///
    /// Returns `true` when the filter is inactive, or when at least
    /// `min_required` of the trigram hashes (as little-endian bytes) are
    /// reported present by `bloom`. Returns `false` otherwise.
    ///
    /// The scan stops early once the threshold is met, or once the trigrams
    /// still to be checked can no longer reach it.
    #[must_use]
    pub fn passes(&self, bloom: &BloomFilter) -> bool {
        if !self.is_active() {
            return true;
        }
        let mut hits = 0_usize;
        let mut remaining = self.trigrams.len();
        for t in &self.trigrams {
            if bloom.contains(&t.to_le_bytes()) {
                hits += 1;
                if hits >= self.min_required {
                    return true;
                }
            }
            remaining -= 1;
            // Even if every unchecked trigram hit, the threshold is out of reach.
            if hits + remaining < self.min_required {
                return false;
            }
        }
        hits >= self.min_required
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::regex_ast::parse;

    /// Parses `pattern` (panicking if the parser rejects it) and builds the
    /// approximate filter allowing `k` edits.
    fn filter_for(pattern: &str, k: u16) -> ApproxFilter {
        let parsed = parse(pattern).expect("parses");
        ApproxFilter::build(&parsed, k)
    }

    #[test]
    fn k0_filter_requires_all_trigrams() {
        // "hello" yields three windows; with no edits every one is required.
        let filter = filter_for("hello", 0);
        assert_eq!((filter.trigrams.len(), filter.min_required), (3, 3));
    }

    #[test]
    fn k1_filter_loosens_min_required() {
        // A single edit in the middle of "hello" touches all three windows.
        let filter = filter_for("hello", 1);
        assert_eq!((filter.trigrams.len(), filter.min_required), (3, 0));
        assert!(!filter.is_active());
    }

    #[test]
    fn long_pattern_under_k2_keeps_filter_active() {
        let filter = filter_for("hello world", 2);
        let required = filter.min_required;
        assert!(required >= 3, "expected min_required >= 3, got {}", required);
        assert!(filter.is_active());
    }

    #[test]
    fn k0_filter_for_unsupported_pattern_is_inactive() {
        let filter = filter_for(".*", 0);
        assert!(filter.trigrams.is_empty());
        assert!(!filter.is_active());
    }

    #[test]
    fn passes_returns_true_when_inactive() {
        let empty_bloom = BloomFilter::with_size_and_fp_rate(64, 0.01);
        assert!(filter_for(".*", 0).passes(&empty_bloom));
    }

    #[test]
    fn passes_rejects_doc_below_threshold() {
        let empty_bloom = BloomFilter::with_size_and_fp_rate(64, 0.01);
        assert!(!filter_for("hello", 0).passes(&empty_bloom));
    }

    #[test]
    fn passes_accepts_doc_at_threshold() {
        let filter = filter_for("hello", 0);
        let mut seeded = BloomFilter::with_size_and_fp_rate(256, 0.001);
        for hash in &filter.trigrams {
            seeded.insert(&hash.to_le_bytes());
        }
        assert!(filter.passes(&seeded));
    }

    #[test]
    fn min_required_never_exceeds_unique_trigrams() {
        // "aaaaa" has three windows but only one distinct trigram.
        let filter = filter_for("aaaaa", 0);
        assert_eq!((filter.trigrams.len(), filter.min_required), (1, 1));
    }
}