#![allow(clippy::cast_possible_truncation)] #![allow(clippy::cast_precision_loss)]
use roaring::RoaringBitmap;
use rustc_hash::{FxHashMap, FxHashSet};
use std::collections::HashSet;
/// Three consecutive bytes of space-padded text.
pub type Trigram = [u8; 3];

/// Extracts the trigram set used to index a document.
///
/// The text is padded with two leading and two trailing ASCII spaces before a
/// 3-byte window slides over it, so the start and end of the text each
/// contribute boundary trigrams. Empty text yields an empty set.
#[must_use]
pub fn extract_trigrams(text: &str) -> HashSet<Trigram> {
    const PAD_TRAILING: bool = true;
    extract_trigrams_internal(text, PAD_TRAILING)
}

/// Extracts the trigram set for a search pattern.
///
/// Same as [`extract_trigrams`] except that no trailing padding is added, so
/// the pattern's trigrams do not require the match to end where the document
/// text ends. The two-space leading padding is still applied.
#[must_use]
pub fn extract_trigrams_for_pattern(text: &str) -> HashSet<Trigram> {
    const PAD_TRAILING: bool = false;
    extract_trigrams_internal(text, PAD_TRAILING)
}

/// Shared trigram extractor.
///
/// Operates on the UTF-8 bytes of `text`. It prepends two spaces, appends two
/// more when `trailing_padding` is set, and collects every 3-byte window of
/// the padded buffer. Multi-byte characters are therefore split byte-wise
/// across trigrams.
fn extract_trigrams_internal(text: &str, trailing_padding: bool) -> HashSet<Trigram> {
    const PAD: u8 = b' ';

    if text.is_empty() {
        return HashSet::new();
    }

    let trailing = if trailing_padding { 2 } else { 0 };
    let mut padded: Vec<u8> = Vec::with_capacity(2 + text.len() + trailing);
    padded.extend_from_slice(&[PAD, PAD]);
    padded.extend_from_slice(text.as_bytes());
    padded.resize(padded.len() + trailing, PAD);

    // A non-empty text guarantees at least 3 padded bytes, hence at least one window.
    padded
        .windows(3)
        .map(|window| [window[0], window[1], window[2]])
        .collect()
}
/// Size summary of a trigram index, as produced by its `stats` method.
#[derive(Clone, Debug, Default)]
pub struct TrigramStats {
    /// Number of indexed documents.
    pub doc_count: u64,
    /// Number of distinct trigrams currently present in the inverted index.
    pub trigram_count: usize,
    /// Rough estimate of memory usage in bytes; an approximation, not a measurement.
    pub memory_bytes: usize,
}
/// In-memory inverted index from trigrams to document ids.
///
/// Document ids are `u64` in the public API but are stored as `u32` inside
/// [`RoaringBitmap`]s, so ids above `u32::MAX` are rejected by
/// [`TrigramIndex::insert`].
#[derive(Debug, Default)]
pub struct TrigramIndex {
    /// Posting lists: trigram -> ids of documents containing it. A posting
    /// list is removed once it becomes empty.
    inverted: FxHashMap<Trigram, RoaringBitmap>,
    /// Forward index: document id -> its trigram set. Used by `remove` to find
    /// the postings to clear, and by `score_jaccard`.
    doc_trigrams: FxHashMap<u64, FxHashSet<Trigram>>,
    /// Every indexed document id.
    all_docs: RoaringBitmap,
}

impl TrigramIndex {
    /// Creates an empty index.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns `true` when no documents are indexed.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.all_docs.is_empty()
    }

    /// Number of indexed documents.
    #[must_use]
    pub fn doc_count(&self) -> u64 {
        self.all_docs.len()
    }

    /// Indexes `text` under `doc_id`, replacing any previous content for that id.
    ///
    /// # Panics
    ///
    /// Panics if `doc_id` does not fit in a `u32`.
    pub fn insert(&mut self, doc_id: u64, text: &str) {
        // Single checked conversion; replaces a separate assert plus a truncating `as` cast.
        let Ok(doc_id_u32) = u32::try_from(doc_id) else {
            panic!(
                "TrigramIndex: doc_id {doc_id} exceeds u32::MAX limit. Maximum 4B documents supported."
            );
        };
        // Re-indexing must drop the old postings first; otherwise trigrams that
        // vanished from the new text would still point at this document.
        if self.doc_trigrams.contains_key(&doc_id) {
            self.remove(doc_id);
        }
        let trigrams: FxHashSet<Trigram> = extract_trigrams(text).into_iter().collect();
        for &trigram in &trigrams {
            self.inverted.entry(trigram).or_default().insert(doc_id_u32);
        }
        self.doc_trigrams.insert(doc_id, trigrams);
        self.all_docs.insert(doc_id_u32);
    }

    /// Removes `doc_id` and all of its postings.
    ///
    /// Removing an id that is not indexed (including any id above `u32::MAX`,
    /// which `insert` never accepts) has no effect.
    pub fn remove(&mut self, doc_id: u64) {
        let Ok(doc_id_u32) = u32::try_from(doc_id) else {
            return;
        };
        if let Some(trigrams) = self.doc_trigrams.remove(&doc_id) {
            for trigram in trigrams {
                if let Some(bitmap) = self.inverted.get_mut(&trigram) {
                    bitmap.remove(doc_id_u32);
                    // Drop empty posting lists so `inverted` only holds live trigrams.
                    if bitmap.is_empty() {
                        self.inverted.remove(&trigram);
                    }
                }
            }
        }
        self.all_docs.remove(doc_id_u32);
    }

    /// Returns the ids of documents that contain every trigram of `pattern`.
    ///
    /// Pattern trigrams come from `extract_trigrams_for_pattern`, which keeps
    /// the two-space leading padding. A document therefore matches only if it
    /// also has those leading-padding trigrams. In practice the pattern must
    /// appear at the start of the text (or after two consecutive spaces).
    /// NOTE(review): confirm this prefix-anchored behaviour is what callers
    /// expect.
    ///
    /// An empty pattern returns every document. The result is a trigram-level
    /// candidate set; the index keeps no text, so matches are not verified.
    #[must_use]
    pub fn search_like(&self, pattern: &str) -> RoaringBitmap {
        let trigrams = extract_trigrams_for_pattern(pattern);
        self.intersect_trigram_bitmaps(pattern, &trigrams)
    }

    /// Intersects the posting lists of all `trigrams`.
    ///
    /// Returns every document when `pattern` or `trigrams` is empty. Returns an
    /// empty bitmap when any trigram has no postings.
    fn intersect_trigram_bitmaps(
        &self,
        pattern: &str,
        trigrams: &HashSet<Trigram>,
    ) -> RoaringBitmap {
        if pattern.is_empty() || trigrams.is_empty() {
            return self.all_docs.clone();
        }
        // A trigram missing from the index means no document can match.
        let Some(mut postings) = trigrams
            .iter()
            .map(|trigram| self.inverted.get(trigram))
            .collect::<Option<Vec<&RoaringBitmap>>>()
        else {
            return RoaringBitmap::new();
        };
        // Start from the rarest trigram: cloning the smallest bitmap keeps the
        // intermediate result small, and an empty result lets us stop early.
        postings.sort_unstable_by_key(|bitmap| bitmap.len());
        let Some((rarest, rest)) = postings.split_first() else {
            return RoaringBitmap::new();
        };
        let mut result = RoaringBitmap::clone(rarest);
        for bitmap in rest {
            if result.is_empty() {
                break;
            }
            result &= *bitmap;
        }
        result
    }

    /// Jaccard similarity between the stored trigrams of `doc_id` and `query_trigrams`.
    ///
    /// Returns `0.0` for unknown documents or when either trigram set is empty.
    #[must_use]
    pub fn score_jaccard(&self, doc_id: u64, query_trigrams: &HashSet<Trigram>) -> f32 {
        let Some(doc_trigrams) = self.doc_trigrams.get(&doc_id) else {
            return 0.0;
        };
        if doc_trigrams.is_empty() || query_trigrams.is_empty() {
            return 0.0;
        }
        let intersection = query_trigrams
            .iter()
            .filter(|&trigram| doc_trigrams.contains(trigram))
            .count();
        // Both sets are non-empty here, so the union is at least 1.
        let union = doc_trigrams.len() + query_trigrams.len() - intersection;
        intersection as f32 / union as f32
    }

    /// Finds candidates for `pattern`, scores them with
    /// [`TrigramIndex::score_jaccard`], and sorts them by descending score.
    ///
    /// Only candidates with `score >= threshold` are kept. Ties keep the
    /// bitmap's ascending-id order because the sort is stable.
    ///
    /// Notes on behaviour:
    /// - An empty pattern returns every document with score `0.0`, and
    ///   `threshold` is *not* applied in that case.
    /// - Stored document trigrams include trailing padding, but pattern
    ///   trigrams do not. So even a document whose text equals the pattern
    ///   scores below `1.0`.
    #[must_use]
    pub fn search_like_ranked(&self, pattern: &str, threshold: f32) -> Vec<(u64, f32)> {
        if pattern.is_empty() {
            return self
                .all_docs
                .iter()
                .map(|id| (u64::from(id), 0.0f32))
                .collect();
        }
        let query_trigrams = extract_trigrams_for_pattern(pattern);
        let candidates = self.intersect_trigram_bitmaps(pattern, &query_trigrams);
        let mut results: Vec<(u64, f32)> = candidates
            .iter()
            .map(u64::from)
            .map(|doc_id| (doc_id, self.score_jaccard(doc_id, &query_trigrams)))
            .filter(|&(_, score)| score >= threshold)
            .collect();
        results.sort_by(|a, b| b.1.total_cmp(&a.1));
        results
    }

    /// Returns document and trigram counts plus a rough memory estimate.
    ///
    /// `memory_bytes` is the sum of three parts:
    /// - a fixed per-trigram key overhead;
    /// - the serialized size of every posting bitmap;
    /// - a flat per-document figure for the forward index.
    ///
    /// The per-document figure does not grow with the number of trigrams a
    /// document holds, so this is an approximation rather than a measurement.
    #[must_use]
    pub fn stats(&self) -> TrigramStats {
        // 3-byte trigram key plus an assumed 8 bytes of per-entry overhead.
        const POSTING_ENTRY_BYTES: usize = 3 + 8;
        // Flat estimate per forward-index entry.
        const DOC_ENTRY_BYTES: usize = 64;

        let inverted_size = self.inverted.len() * POSTING_ENTRY_BYTES;
        let bitmap_size: usize = self
            .inverted
            .values()
            .map(RoaringBitmap::serialized_size)
            .sum();
        let doc_trigrams_size = self.doc_trigrams.len() * DOC_ENTRY_BYTES;
        TrigramStats {
            doc_count: self.all_docs.len(),
            trigram_count: self.inverted.len(),
            memory_bytes: inverted_size + bitmap_size + doc_trigrams_size,
        }
    }
}