use std::collections::BTreeMap;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use crate::bloom::BloomFilter;
use crate::postings::Postings;
use crate::prefix_extract;
use crate::regex_ast::{self, RegexError};
use crate::tiling::ApproxFilter;
use crate::tre::{TreCompiledPattern, TreError, TreMatchOpts};
use crate::trigram;
/// Queries shorter than this many bytes skip the trigram postings and use a linear scan.
pub const MIN_TRIGRAM_QUERY_LEN: usize = 3;
/// Lower bound on the item count passed to `BloomFilter::with_size_and_fp_rate` for a document.
const DEFAULT_BLOOM_N: usize = 256;
/// False-positive-rate argument passed to `BloomFilter::with_size_and_fp_rate` for a document.
const DEFAULT_BLOOM_FP: f64 = 0.01;
/// Survivor count at or above which `search_regex_approx` rechecks candidates in parallel.
const PARALLEL_RECHECK_THRESHOLD: usize = 1024;
/// Chunk size given to `par_chunks` during the parallel recheck.
const PARALLEL_RECHECK_CHUNK_SIZE: usize = 256;
/// A stored document: its raw bytes plus a per-document bloom filter over its trigrams.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexedDoc {
/// The document bytes exactly as handed to `TextIndex::insert`.
pub text: Vec<u8>,
/// Holds the little-endian bytes of every trigram extracted from `text`; searches use it
/// to screen candidates before the byte-level/regex confirmation.
pub bloom: BloomFilter,
}
impl IndexedDoc {
fn new(text: Vec<u8>) -> Self {
let tris = trigram::extract_trigram_set(&text);
let mut bloom =
BloomFilter::with_size_and_fp_rate(DEFAULT_BLOOM_N.max(tris.len()), DEFAULT_BLOOM_FP);
for t in &tris {
bloom.insert(&t.to_le_bytes());
}
Self { text, bloom }
}
}
/// Trigram-indexed in-memory document store supporting substring, regex, and approximate
/// (TRE) regex search over byte documents.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextIndex {
/// Trigram -> document-id postings; queried with `intersect` to find candidates.
postings: Postings,
/// Stored documents keyed by id; being a `BTreeMap`, iteration is in ascending id order.
docs: BTreeMap<u32, IndexedDoc>,
/// Id for the next inserted document. Only `insert` changes it and it only increases, so
/// ids are never reused after `remove`.
next_doc_id: u32,
}
impl Default for TextIndex {
fn default() -> Self {
Self::new()
}
}
impl TextIndex {
    /// Creates an empty index; the first inserted document receives id 0.
    #[must_use]
    pub fn new() -> Self {
        Self {
            postings: Postings::new(),
            docs: BTreeMap::new(),
            next_doc_id: 0,
        }
    }

    /// Number of documents currently stored (inserted and not removed).
    #[must_use]
    pub fn doc_count(&self) -> usize {
        self.docs.len()
    }

    /// Read-only access to the trigram postings.
    #[must_use]
    pub fn postings(&self) -> &Postings {
        &self.postings
    }

    /// Read-only access to the stored documents, keyed by id in ascending order.
    #[must_use]
    pub fn docs(&self) -> &BTreeMap<u32, IndexedDoc> {
        &self.docs
    }

    /// Stores `text` as a new document and returns its id.
    ///
    /// Ids are handed out sequentially from 0 and are never reused, even after
    /// [`TextIndex::remove`].
    ///
    /// # Panics
    ///
    /// Panics when the id counter cannot advance past `u32::MAX`, i.e. after `u32::MAX`
    /// documents have ever been inserted into this index. Removing documents does not free
    /// ids. The index is left unmodified when this happens.
    pub fn insert(&mut self, text: Vec<u8>) -> u32 {
        let doc_id = self.next_doc_id;
        // Advance the counter before touching postings/docs so a panic leaves no partial state.
        self.next_doc_id = doc_id.checked_add(1).expect(
            "TextIndex doc id space exhausted: ids are never reused, so removing documents does not free any",
        );
        let tris = trigram::extract_trigram_set(&text);
        for t in &tris {
            self.postings.insert(*t, doc_id);
        }
        // NOTE(review): `IndexedDoc::new` re-extracts the same trigram set for the bloom filter.
        self.docs.insert(doc_id, IndexedDoc::new(text));
        doc_id
    }

    /// Removes document `doc_id` and its postings entries, returning the document text.
    ///
    /// Returns `None` if no document with that id is stored.
    pub fn remove(&mut self, doc_id: u32) -> Option<Vec<u8>> {
        let doc = self.docs.remove(&doc_id)?;
        let tris = trigram::extract_trigram_set(&doc.text);
        for t in &tris {
            self.postings.remove(*t, doc_id);
        }
        Some(doc.text)
    }

    /// Returns the ids of all documents containing `query` as a contiguous byte sequence,
    /// in ascending id order.
    ///
    /// An empty `query` matches every document. Queries shorter than
    /// `MIN_TRIGRAM_QUERY_LEN`, or that yield no query trigrams, use a linear scan.
    /// Otherwise, candidates come from intersecting the trigram postings. They are then
    /// screened against each document's bloom filter and confirmed byte-for-byte.
    #[must_use]
    pub fn search_substring(&self, query: &[u8]) -> Vec<u32> {
        if query.is_empty() {
            return self.docs.keys().copied().collect();
        }
        if query.len() < MIN_TRIGRAM_QUERY_LEN {
            return self.full_scan(query);
        }
        let qtris = trigram::extract_query_trigram_set(query);
        if qtris.is_empty() {
            return self.full_scan(query);
        }
        let candidates = self.postings.intersect(&qtris);
        if candidates.is_empty() {
            return Vec::new();
        }
        let mut hits: Vec<u32> = Vec::new();
        for doc_id in &candidates {
            // Defensive: skip any candidate id that has no stored document.
            let Some(doc) = self.docs.get(&doc_id) else {
                continue;
            };
            if !qtris.iter().all(|t| doc.bloom.contains(&t.to_le_bytes())) {
                continue;
            }
            // Sharing every trigram does not imply containment; confirm on the raw bytes.
            if Self::contains_substring(&doc.text, query) {
                hits.push(doc_id);
            }
        }
        hits.sort_unstable();
        hits
    }

    /// Linear scan over every stored document. Results are in ascending id order because
    /// `docs` is a `BTreeMap`.
    fn full_scan(&self, query: &[u8]) -> Vec<u32> {
        self.docs
            .iter()
            .filter(|(_, doc)| Self::contains_substring(&doc.text, query))
            .map(|(id, _)| *id)
            .collect()
    }

    /// Returns the ids of documents matched by `pattern` (compiled with
    /// `regex::bytes::Regex`), in ascending id order.
    ///
    /// When `regex_ast` can parse the pattern, two prefilters run before the regex itself:
    /// the required trigram hashes narrow the candidate set, and an anchored byte prefix
    /// (if any) must begin the document. If it cannot parse the pattern, every document is
    /// tested with the regex.
    ///
    /// # Errors
    ///
    /// Returns [`RegexError::Parse`] if the `regex` crate rejects `pattern`.
    pub fn search_regex(&self, pattern: &str) -> Result<Vec<u32>, RegexError> {
        let re = regex::bytes::Regex::new(pattern).map_err(|e| RegexError::Parse(e.to_string()))?;
        // The AST only drives prefiltering; a parse failure merely disables it.
        let ast = regex_ast::parse(pattern).ok();
        let trigram_hashes: Vec<u64> = ast
            .as_ref()
            .map(prefix_extract::required_trigram_hashes)
            .unwrap_or_default();
        let anchored_prefix = ast.as_ref().and_then(prefix_extract::anchored_prefix);
        let candidates: Vec<u32> = if trigram_hashes.is_empty() {
            self.docs.keys().copied().collect()
        } else {
            self.postings.intersect(&trigram_hashes).iter().collect()
        };
        let mut hits: Vec<u32> = Vec::new();
        for doc_id in candidates {
            let Some(doc) = self.docs.get(&doc_id) else {
                continue;
            };
            // Vacuously true when no trigrams are required.
            if !trigram_hashes
                .iter()
                .all(|t| doc.bloom.contains(&t.to_le_bytes()))
            {
                continue;
            }
            if let Some(prefix) = anchored_prefix.as_ref() {
                if !doc.text.starts_with(prefix.as_slice()) {
                    continue;
                }
            }
            if re.is_match(&doc.text) {
                hits.push(doc_id);
            }
        }
        hits.sort_unstable();
        Ok(hits)
    }

    /// Naive byte-substring test. An empty `needle` always matches; the explicit check also
    /// avoids `windows(0)`, which panics.
    fn contains_substring(haystack: &[u8], needle: &[u8]) -> bool {
        if needle.is_empty() {
            return true;
        }
        needle.len() <= haystack.len() && haystack.windows(needle.len()).any(|w| w == needle)
    }

    /// Returns the ids of documents matched by the TRE pattern `pattern` with at most
    /// `max_errors` errors, in ascending id order.
    ///
    /// When `regex_ast` parses the pattern, documents are first screened in two ways. An
    /// [`ApproxFilter`] built for `max_errors` selects candidates from the postings and
    /// checks their bloom filters. An anchored prefix, if present, is checked with
    /// `anchor_prefix_compatible`. Survivors are then rechecked with the compiled pattern,
    /// using worker threads once there are at least `PARALLEL_RECHECK_THRESHOLD` of them.
    ///
    /// # Errors
    ///
    /// Returns the [`TreError`] from compiling `pattern`, or `TreError::Internal` if a
    /// parallel worker fails to compile it.
    pub fn search_regex_approx(
        &self,
        pattern: &str,
        max_errors: u16,
    ) -> Result<Vec<u32>, TreError> {
        let opts = TreMatchOpts {
            max_errors,
            ..TreMatchOpts::default()
        };
        // Compiled up front so invalid patterns fail fast; also used for the sequential recheck.
        let pat = TreCompiledPattern::compile(pattern.as_bytes(), opts)?;
        let ast = regex_ast::parse(pattern).ok();
        let filter = ast.as_ref().map_or_else(
            || ApproxFilter {
                trigrams: Vec::new(),
                min_required: 0,
            },
            |a| ApproxFilter::build(a, max_errors),
        );
        let anchored_prefix = ast.as_ref().and_then(prefix_extract::anchored_prefix);
        let prefilter = |doc: &IndexedDoc| -> bool {
            if filter.is_active() && !filter.passes(&doc.bloom) {
                return false;
            }
            match anchored_prefix.as_ref() {
                Some(prefix) => anchor_prefix_compatible(&doc.text, prefix, max_errors),
                None => true,
            }
        };
        let survivors: Vec<(u32, &[u8])> = if filter.is_active() {
            let candidates = filter.candidates(&self.postings);
            candidates
                .into_iter()
                .filter_map(|doc_id| {
                    let doc = self.docs.get(&doc_id)?;
                    if !prefilter(doc) {
                        return None;
                    }
                    Some((doc_id, doc.text.as_slice()))
                })
                .collect()
        } else {
            self.docs
                .iter()
                .filter_map(|(id, doc)| {
                    if !prefilter(doc) {
                        return None;
                    }
                    Some((*id, doc.text.as_slice()))
                })
                .collect()
        };
        let mut hits: Vec<u32> = if survivors.len() >= PARALLEL_RECHECK_THRESHOLD {
            run_parallel_recheck(pattern, opts, &survivors)?
        } else {
            survivors
                .into_iter()
                .filter_map(|(id, text)| pat.is_match(text).then_some(id))
                .collect()
        };
        hits.sort_unstable();
        Ok(hits)
    }
}
/// Prefilter for approximate matching of a pattern that has an anchored prefix.
///
/// With `max_errors == 0` the document must begin with `prefix` exactly. Otherwise the
/// first `prefix.len() + max_errors` bytes of `doc` (or all of `doc`, if shorter) are
/// compared with `prefix` via `bounded_edit_distance`. The document passes when that
/// distance is at most `max_errors`.
fn anchor_prefix_compatible(doc: &[u8], prefix: &[u8], max_errors: u16) -> bool {
    let budget = usize::from(max_errors);
    if budget == 0 {
        return doc.starts_with(prefix);
    }
    let head_len = doc.len().min(prefix.len() + budget);
    bounded_edit_distance(prefix, &doc[..head_len], budget) <= budget
}
/// Banded DP edit distance between `pat` and a prefix of `txt`, capped at `k + 1`.
///
/// Row 0 lets `pat` start at any of the first `k` text positions at zero cost. This is more
/// permissive than strict anchored edit distance, so the result never exceeds it. The
/// answer is the minimum over end columns `pat.len() - k ..= pat.len() + k`, clamped to
/// `txt`. Cells more than `k` off the diagonal are treated as unreachable, and any result
/// above `k` is reported as `k + 1`, so callers should only compare it against `k`. An empty
/// `pat` always yields 0.
fn bounded_edit_distance(pat: &[u8], txt: &[u8], k: usize) -> usize {
    if pat.is_empty() {
        return 0;
    }
    let cols = txt.len() + 1;
    // Previous DP row; row 0 makes any start within the first `k` text bytes free.
    let mut above: Vec<usize> = (0..cols)
        .map(|col| if col <= k { 0 } else { usize::MAX })
        .collect();
    // Row currently being filled.
    let mut row = vec![usize::MAX; cols];
    for (i, &pc) in (1usize..).zip(pat) {
        row[0] = i;
        let band = i.saturating_sub(k)..=(i + k).min(txt.len());
        for col in 1..cols {
            row[col] = if band.contains(&col) {
                let substitute = above[col - 1].saturating_add(usize::from(pc != txt[col - 1]));
                let skip_pat = above[col].saturating_add(1);
                let skip_txt = row[col - 1].saturating_add(1);
                substitute.min(skip_pat).min(skip_txt)
            } else {
                usize::MAX
            };
        }
        std::mem::swap(&mut above, &mut row);
    }
    let last = pat.len();
    let lo = last.saturating_sub(k);
    let hi = (last + k).min(txt.len());
    above
        .get(lo..=hi)
        .and_then(|ends| ends.iter().copied().min())
        .unwrap_or(usize::MAX)
        .min(k + 1)
}
/// Re-runs the approximate matcher over `survivors` on rayon workers and returns the ids
/// that match, in the same order as `survivors`.
///
/// Each worker compiles its own `TreCompiledPattern` from `pattern`/`opts` through
/// `map_init`. Survivors are split with `par_chunks(PARALLEL_RECHECK_CHUNK_SIZE)`.
///
/// # Errors
///
/// Returns `TreError::Internal` if any worker fails to compile the pattern.
fn run_parallel_recheck(
    pattern: &str,
    opts: TreMatchOpts,
    survivors: &[(u32, &[u8])],
) -> Result<Vec<u32>, TreError> {
    use std::sync::atomic::{AtomicBool, Ordering};

    let any_worker_failed = AtomicBool::new(false);
    let matched: Vec<u32> = survivors
        .par_chunks(PARALLEL_RECHECK_CHUNK_SIZE)
        .map_init(
            || TreCompiledPattern::compile(pattern.as_bytes(), opts).ok(),
            |compiled: &mut Option<TreCompiledPattern>, batch: &[(u32, &[u8])]| -> Vec<u32> {
                match compiled.as_ref() {
                    Some(re) => batch
                        .iter()
                        .filter(|&&(_, text)| re.is_match(text))
                        .map(|&(id, _)| id)
                        .collect(),
                    None => {
                        any_worker_failed.store(true, Ordering::Relaxed);
                        Vec::new()
                    }
                }
            },
        )
        .flatten()
        .collect();
    // `collect` returns only after every chunk has run, so the flag is final here.
    if any_worker_failed.load(Ordering::Relaxed) {
        Err(TreError::Internal(
            "parallel worker failed to compile pattern".into(),
        ))
    } else {
        Ok(matched)
    }
}
#[cfg(test)]
mod tests {
    //! Behavioral tests for `TextIndex` search/insert/remove and the approximate-match helpers.
    use super::*;

    #[test]
    fn insert_then_search_finds_the_doc() {
        let mut idx = TextIndex::new();
        let id = idx.insert(b"hello world".to_vec());
        let hits = idx.search_substring(b"hello");
        assert_eq!(hits, vec![id]);
    }

    #[test]
    fn search_substring_returns_only_true_positives() {
        let mut idx = TextIndex::new();
        let a = idx.insert(b"the quick brown fox".to_vec());
        let _b = idx.insert(b"jumped over a lazy dog".to_vec());
        let c = idx.insert(b"a brown fox is quick".to_vec());
        let hits = idx.search_substring(b"brown fox");
        assert!(hits.contains(&a));
        assert!(hits.contains(&c));
        assert_eq!(hits.len(), 2);
    }

    #[test]
    fn search_substring_no_false_negatives_on_corpus() {
        let mut store = TextIndex::new();
        let corpus: &[&[u8]] = &[
            b"alpha beta gamma",
            b"beta cake",
            b"the alphabet starts with alpha",
            b"omega only",
        ];
        let ids: Vec<u32> = corpus.iter().map(|t| store.insert(t.to_vec())).collect();
        for q in [b"alpha".as_slice(), b"beta", b"omega", b"the"] {
            let hits = store.search_substring(q);
            for (i, doc) in corpus.iter().enumerate() {
                if doc.windows(q.len()).any(|w| w == q) {
                    assert!(
                        hits.contains(&ids[i]),
                        "false negative: query {q:?} should hit doc {i} {doc:?}",
                    );
                }
            }
        }
    }

    #[test]
    fn search_returns_results_in_insertion_order() {
        let mut idx = TextIndex::new();
        let id_a = idx.insert(b"hello a".to_vec());
        let id_b = idx.insert(b"hello b".to_vec());
        let id_c = idx.insert(b"hello c".to_vec());
        let hits = idx.search_substring(b"hello");
        assert_eq!(hits, vec![id_a, id_b, id_c]);
    }

    #[test]
    fn remove_excludes_doc_from_subsequent_searches() {
        let mut idx = TextIndex::new();
        let a = idx.insert(b"the quick brown fox".to_vec());
        let b = idx.insert(b"another brown fox here".to_vec());
        let removed = idx.remove(a).expect("doc a present");
        assert_eq!(removed, b"the quick brown fox");
        let hits = idx.search_substring(b"brown fox");
        assert_eq!(hits, vec![b]);
    }

    #[test]
    fn remove_garbage_collects_unique_trigrams() {
        let mut idx = TextIndex::new();
        let a = idx.insert(b"unique-string-only-here".to_vec());
        let postings_before = idx.postings().len();
        assert!(postings_before > 0);
        idx.remove(a);
        assert_eq!(idx.postings().len(), 0);
        assert_eq!(idx.doc_count(), 0);
    }

    #[test]
    fn remove_missing_doc_id_returns_none() {
        let mut idx = TextIndex::new();
        idx.insert(b"abc".to_vec());
        assert!(idx.remove(9999).is_none());
    }

    #[test]
    fn query_shorter_than_three_chars_uses_full_scan() {
        let mut idx = TextIndex::new();
        let a = idx.insert(b"abcdef".to_vec());
        let _b = idx.insert(b"xyz".to_vec());
        let c = idx.insert(b"ab".to_vec());
        let hits = idx.search_substring(b"ab");
        assert!(hits.contains(&a));
        assert!(hits.contains(&c));
        assert_eq!(hits.len(), 2);
    }

    #[test]
    fn empty_query_matches_every_doc() {
        let mut idx = TextIndex::new();
        let a = idx.insert(b"x".to_vec());
        let b = idx.insert(b"y".to_vec());
        let hits = idx.search_substring(b"");
        assert_eq!(hits, vec![a, b]);
    }

    #[test]
    fn unicode_query_byte_level_works() {
        let mut idx = TextIndex::new();
        let a = idx.insert(b"caf\xc3\xa9 noir".to_vec());
        let b = idx.insert(b"cafe noir".to_vec());
        let hits = idx.search_substring(b"\xc3\xa9");
        assert_eq!(hits, vec![a]);
        let hits = idx.search_substring(b"noir");
        assert!(hits.contains(&a));
        assert!(hits.contains(&b));
        assert_eq!(hits.len(), 2);
    }

    #[test]
    fn search_for_nonexistent_substring_returns_empty() {
        let mut idx = TextIndex::new();
        idx.insert(b"hello world".to_vec());
        idx.insert(b"another doc".to_vec());
        assert!(idx.search_substring(b"completely-absent").is_empty());
    }

    #[test]
    fn search_on_empty_index_returns_empty() {
        let idx = TextIndex::new();
        assert!(idx.search_substring(b"anything").is_empty());
        assert!(idx.search_substring(b"").is_empty());
    }

    #[test]
    fn search_regex_honors_anchor_and_reports_parse_errors() {
        let mut idx = TextIndex::new();
        let a = idx.insert(b"hello world".to_vec());
        let b = idx.insert(b"say hello".to_vec());
        let _c = idx.insert(b"goodbye".to_vec());
        assert_eq!(idx.search_regex("hello").ok(), Some(vec![a, b]));
        assert_eq!(idx.search_regex("^hello").ok(), Some(vec![a]));
        assert!(idx.search_regex("(").is_err());
    }

    #[test]
    fn anchor_prefix_compatible_exact_and_fuzzy() {
        assert!(anchor_prefix_compatible(b"hello world", b"hello", 0));
        assert!(!anchor_prefix_compatible(b"say hello", b"hello", 0));
        assert!(!anchor_prefix_compatible(b"hel", b"hello", 0));
        assert!(anchor_prefix_compatible(b"hallo world", b"hello", 1));
        assert!(!anchor_prefix_compatible(b"xyz", b"hello", 1));
    }

    #[test]
    fn bounded_edit_distance_is_capped_at_k_plus_one() {
        assert_eq!(bounded_edit_distance(b"", b"anything", 0), 0);
        assert_eq!(bounded_edit_distance(b"abc", b"abc", 1), 0);
        assert_eq!(bounded_edit_distance(b"abc", b"abx", 1), 1);
        assert_eq!(bounded_edit_distance(b"abc", b"xyz", 1), 2);
        // The free leading shift makes this at most the strict anchored distance (1).
        assert!(bounded_edit_distance(b"abc", b"xabc", 1) <= 1);
    }
}