use std::collections::BTreeMap;
use serde::{Deserialize, Serialize};
use crate::bloom::BloomFilter;
use crate::postings::Postings;
use crate::prefix_extract;
use crate::regex_ast::{self, RegexError};
use crate::tre::{TreCompiledPattern, TreError, TreMatchOpts};
use crate::trigram;
/// Shortest query, in bytes, for which `TextIndex::search_substring` consults
/// the trigram postings. Shorter non-empty queries are answered by scanning
/// every stored document instead.
pub const MIN_TRIGRAM_QUERY_LEN: usize = 3;
/// Lower bound on the size handed to `BloomFilter::with_size_and_fp_rate` for a
/// per-document filter; a document with more trigrams than this uses its own
/// trigram count instead.
const DEFAULT_BLOOM_N: usize = 256;
/// Second argument handed to `BloomFilter::with_size_and_fp_rate` for every
/// per-document filter (presumably the target false-positive rate — confirm
/// against the `bloom` module).
const DEFAULT_BLOOM_FP: f64 = 0.01;
/// A stored document paired with a Bloom filter over its trigrams.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexedDoc {
/// Raw document bytes.
pub text: Vec<u8>,
/// Bloom filter; when built by `IndexedDoc::new` it holds the little-endian
/// bytes of every trigram extracted from `text`.
pub bloom: BloomFilter,
}
impl IndexedDoc {
    /// Wraps `text` together with a Bloom filter over its trigram set.
    ///
    /// The filter is sized for at least `DEFAULT_BLOOM_N` entries (or the
    /// number of extracted trigrams, whichever is larger) and each trigram is
    /// inserted as its little-endian byte representation.
    fn new(text: Vec<u8>) -> Self {
        let trigrams = trigram::extract_trigram_set(&text);
        let capacity = DEFAULT_BLOOM_N.max(trigrams.len());
        let mut bloom = BloomFilter::with_size_and_fp_rate(capacity, DEFAULT_BLOOM_FP);
        for tri in &trigrams {
            let key = tri.to_le_bytes();
            bloom.insert(&key);
        }
        Self { text, bloom }
    }
}
/// Trigram-backed index over byte documents, addressed by `u32` document ids.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextIndex {
// Trigram -> document-id postings; updated by `insert` and `remove`.
postings: Postings,
// Stored documents keyed by document id.
docs: BTreeMap<u32, IndexedDoc>,
// Id that the next call to `insert` will hand out.
next_doc_id: u32,
}
impl Default for TextIndex {
fn default() -> Self {
Self::new()
}
}
impl TextIndex {
    /// Creates an empty index; the first inserted document receives id `0`.
    #[must_use]
    pub fn new() -> Self {
        let postings = Postings::new();
        let docs = BTreeMap::new();
        Self {
            postings,
            docs,
            next_doc_id: 0,
        }
    }

    /// Number of documents currently stored.
    #[must_use]
    pub fn doc_count(&self) -> usize {
        self.docs.len()
    }

    /// Borrows the trigram postings.
    #[must_use]
    pub fn postings(&self) -> &Postings {
        &self.postings
    }

    /// Borrows the stored documents, keyed by document id.
    #[must_use]
    pub fn docs(&self) -> &BTreeMap<u32, IndexedDoc> {
        &self.docs
    }

    /// Stores `text`, records each of its trigrams in the postings, and
    /// returns the id assigned to the new document.
    ///
    /// # Panics
    ///
    /// Panics when the id counter cannot be advanced, i.e. when the id about
    /// to be handed out is `u32::MAX`. Nothing is modified in that case.
    pub fn insert(&mut self, text: Vec<u8>) -> u32 {
        let assigned = self.next_doc_id;
        let following = assigned
            .checked_add(1)
            .expect("invariant: doc ids fit in u32; saturate at 2^32-1 by removing old docs");
        self.next_doc_id = following;

        // NOTE(review): `IndexedDoc::new` extracts the same trigram set again
        // to build its Bloom filter, so every insert pays for extraction twice.
        let trigrams = trigram::extract_trigram_set(&text);
        for tri in &trigrams {
            self.postings.insert(*tri, assigned);
        }

        self.docs.insert(assigned, IndexedDoc::new(text));
        assigned
    }

    /// Removes the document `doc_id` and its postings entries, returning the
    /// document bytes, or `None` when no such document is stored.
    pub fn remove(&mut self, doc_id: u32) -> Option<Vec<u8>> {
        self.docs.remove(&doc_id).map(|removed| {
            // Re-derive the trigram set from the stored text so that exactly
            // the entries added by `insert` are withdrawn.
            let trigrams = trigram::extract_trigram_set(&removed.text);
            for tri in &trigrams {
                self.postings.remove(*tri, doc_id);
            }
            removed.text
        })
    }

    /// Returns the ids of every document whose bytes contain `query`, in
    /// ascending id order.
    ///
    /// * An empty query matches every document.
    /// * A query shorter than [`MIN_TRIGRAM_QUERY_LEN`], or one that yields no
    ///   query trigrams, is answered by scanning every document.
    /// * Otherwise the postings are intersected over the query trigrams and
    ///   each candidate is confirmed against its Bloom filter and then by a
    ///   direct byte comparison.
    #[must_use]
    pub fn search_substring(&self, query: &[u8]) -> Vec<u32> {
        if query.is_empty() {
            return self.docs.keys().copied().collect();
        }
        if query.len() < MIN_TRIGRAM_QUERY_LEN {
            return self.full_scan(query);
        }

        let needed = trigram::extract_query_trigram_set(query);
        if needed.is_empty() {
            return self.full_scan(query);
        }

        let shortlist = self.postings.intersect(&needed);
        if shortlist.is_empty() {
            return Vec::new();
        }

        let mut matched: Vec<u32> = Vec::new();
        for id in &shortlist {
            // A posting may refer to an id that is not in `docs`; skip it.
            if let Some(doc) = self.docs.get(&id) {
                let bloom_ok = needed
                    .iter()
                    .all(|tri| doc.bloom.contains(&tri.to_le_bytes()));
                if bloom_ok && Self::contains_substring(&doc.text, query) {
                    matched.push(id);
                }
            }
        }
        matched.sort_unstable();
        matched
    }

    /// Checks every stored document for `query`, returning matching ids in
    /// the map's (ascending) key order.
    fn full_scan(&self, query: &[u8]) -> Vec<u32> {
        self.docs
            .iter()
            .filter(|(_, doc)| Self::contains_substring(&doc.text, query))
            .map(|(id, _)| *id)
            .collect()
    }

    /// Returns the ids of every document matched by the regular expression
    /// `pattern`, in ascending id order.
    ///
    /// When `regex_ast::parse` accepts the pattern, the trigram hashes it
    /// reports as required are used to narrow the candidates through the
    /// postings and the per-document Bloom filters. If that parse fails, or
    /// no hashes are reported, every document is tested.
    ///
    /// # Errors
    ///
    /// Returns [`RegexError::Parse`] when `regex::bytes::Regex` rejects
    /// `pattern`.
    pub fn search_regex(&self, pattern: &str) -> Result<Vec<u32>, RegexError> {
        let compiled = match regex::bytes::Regex::new(pattern) {
            Ok(re) => re,
            Err(err) => return Err(RegexError::Parse(err.to_string())),
        };

        // A failed AST parse is deliberately not an error: it only disables
        // the trigram prefilter.
        let required: Vec<u64> = regex_ast::parse(pattern)
            .map(|ast| prefix_extract::required_trigram_hashes(&ast))
            .unwrap_or_default();

        let shortlist: Vec<u32> = if required.is_empty() {
            self.docs.keys().copied().collect()
        } else {
            self.postings.intersect(&required).iter().collect()
        };

        let mut matched: Vec<u32> = shortlist
            .into_iter()
            .filter(|id| {
                let Some(doc) = self.docs.get(id) else {
                    return false;
                };
                // `all` over an empty list is true, so with no required
                // hashes this reduces to the regex test alone.
                required
                    .iter()
                    .all(|hash| doc.bloom.contains(&hash.to_le_bytes()))
                    && compiled.is_match(&doc.text)
            })
            .collect();
        matched.sort_unstable();
        Ok(matched)
    }

    /// Reports whether `needle` occurs as a contiguous run of bytes inside
    /// `haystack`. An empty needle is contained in everything.
    fn contains_substring(haystack: &[u8], needle: &[u8]) -> bool {
        match needle.len() {
            // `windows(0)` would panic, so the empty needle is answered first.
            0 => true,
            len if len > haystack.len() => false,
            len => haystack.windows(len).any(|window| window == needle),
        }
    }

    /// Returns the ids of every document matched by `pattern` when up to
    /// `max_errors` errors are allowed, in the map's (ascending) key order.
    ///
    /// Every stored document is tested; the trigram postings are not used.
    ///
    /// # Errors
    ///
    /// Propagates the error from `TreCompiledPattern::compile` when the
    /// pattern cannot be compiled.
    pub fn search_regex_approx(
        &self,
        pattern: &str,
        max_errors: u16,
    ) -> Result<Vec<u32>, TreError> {
        let opts = TreMatchOpts {
            max_errors,
            ..TreMatchOpts::default()
        };
        let compiled = TreCompiledPattern::compile(pattern.as_bytes(), opts)?;
        let matched = self
            .docs
            .iter()
            .filter(|(_, doc)| compiled.is_match(&doc.text))
            .map(|(id, _)| *id)
            .collect();
        Ok(matched)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an index holding `texts`, returning it together with the ids
    /// assigned to each text, in insertion order.
    fn build(texts: &[&[u8]]) -> (TextIndex, Vec<u32>) {
        let mut index = TextIndex::new();
        let ids: Vec<u32> = texts.iter().map(|t| index.insert(t.to_vec())).collect();
        (index, ids)
    }

    #[test]
    fn insert_then_search_finds_the_doc() {
        let (index, ids) = build(&[b"hello world"]);
        assert_eq!(index.search_substring(b"hello"), vec![ids[0]]);
    }

    #[test]
    fn search_substring_returns_only_true_positives() {
        let (index, ids) = build(&[
            b"the quick brown fox",
            b"jumped over a lazy dog",
            b"a brown fox is quick",
        ]);
        let found = index.search_substring(b"brown fox");
        assert!(found.contains(&ids[0]));
        assert!(found.contains(&ids[2]));
        assert_eq!(found.len(), 2);
    }

    #[test]
    fn search_substring_no_false_negatives_on_corpus() {
        let corpus: &[&[u8]] = &[
            b"alpha beta gamma",
            b"beta cake",
            b"the alphabet starts with alpha",
            b"omega only",
        ];
        let (store, ids) = build(corpus);
        let queries: [&[u8]; 4] = [b"alpha", b"beta", b"omega", b"the"];
        for q in queries {
            let hits = store.search_substring(q);
            for (i, doc) in corpus.iter().enumerate() {
                // Ground truth comes from a naive scan of the raw corpus.
                let should_hit = doc.windows(q.len()).any(|w| w == q);
                if should_hit {
                    assert!(
                        hits.contains(&ids[i]),
                        "false negative: query {q:?} should hit doc {i} {doc:?}",
                    );
                }
            }
        }
    }

    #[test]
    fn search_returns_results_in_insertion_order() {
        let (index, ids) = build(&[b"hello a", b"hello b", b"hello c"]);
        assert_eq!(index.search_substring(b"hello"), vec![ids[0], ids[1], ids[2]]);
    }

    #[test]
    fn remove_excludes_doc_from_subsequent_searches() {
        let (mut index, ids) = build(&[b"the quick brown fox", b"another brown fox here"]);
        let removed = index.remove(ids[0]).expect("doc a present");
        assert_eq!(removed, b"the quick brown fox");
        assert_eq!(index.search_substring(b"brown fox"), vec![ids[1]]);
    }

    #[test]
    fn remove_garbage_collects_unique_trigrams() {
        let (mut index, ids) = build(&[b"unique-string-only-here"]);
        let postings_before = index.postings().len();
        assert!(postings_before > 0);
        index.remove(ids[0]);
        assert_eq!(index.postings().len(), 0);
        assert_eq!(index.doc_count(), 0);
    }

    #[test]
    fn remove_missing_doc_id_returns_none() {
        let (mut index, _) = build(&[b"abc"]);
        assert!(index.remove(9999).is_none());
    }

    #[test]
    fn query_shorter_than_three_chars_uses_full_scan() {
        let (index, ids) = build(&[b"abcdef", b"xyz", b"ab"]);
        let found = index.search_substring(b"ab");
        assert!(found.contains(&ids[0]));
        assert!(found.contains(&ids[2]));
        assert_eq!(found.len(), 2);
    }

    #[test]
    fn empty_query_matches_every_doc() {
        let (index, ids) = build(&[b"x", b"y"]);
        assert_eq!(index.search_substring(b""), vec![ids[0], ids[1]]);
    }

    #[test]
    fn unicode_query_byte_level_works() {
        let (index, ids) = build(&[b"caf\xc3\xa9 noir", b"cafe noir"]);
        // Only the first document contains the two-byte UTF-8 sequence.
        assert_eq!(index.search_substring(b"\xc3\xa9"), vec![ids[0]]);
        let found = index.search_substring(b"noir");
        assert!(found.contains(&ids[0]));
        assert!(found.contains(&ids[1]));
        assert_eq!(found.len(), 2);
    }

    #[test]
    fn search_for_nonexistent_substring_returns_empty() {
        let (index, _) = build(&[b"hello world", b"another doc"]);
        assert!(index.search_substring(b"completely-absent").is_empty());
    }

    #[test]
    fn search_on_empty_index_returns_empty() {
        let index = TextIndex::new();
        assert!(index.search_substring(b"anything").is_empty());
        assert!(index.search_substring(b"").is_empty());
    }
}