use std::collections::HashMap;
/// Caller-chosen identifier of an indexed document.
pub type DocId = u64;
/// Three consecutive characters of a (lowercased) text.
pub type Trigram = (char, char, char);

/// Returns the distinct trigrams of `text`, sorted in ascending order.
///
/// Every character is lowercased with [`char::to_lowercase`] before
/// windowing, so the result is case-insensitive. A single input character
/// may expand to several lowercase characters. If fewer than three
/// characters remain after lowercasing, the result is empty.
pub fn trigrams_of(text: &str) -> Vec<Trigram> {
    let lowered: Vec<char> = text.chars().flat_map(char::to_lowercase).collect();
    let mut trigrams: Vec<Trigram> = Vec::with_capacity(lowered.len().saturating_sub(2));
    for window in lowered.windows(3) {
        if let &[first, second, third] = window {
            trigrams.push((first, second, third));
        }
    }
    // Sorting first makes `dedup` remove every repeat, not just adjacent ones.
    trigrams.sort_unstable();
    trigrams.dedup();
    trigrams
}
/// In-memory trigram index over a set of text documents.
///
/// Each document's distinct trigrams (as produced by [`trigrams_of`]) map to
/// a posting list of the ids of documents containing that trigram. Posting
/// lists are kept sorted ascending and free of duplicates. A posting list
/// that becomes empty is removed from the map, so every key in `postings`
/// has at least one document.
#[derive(Debug, Default)]
pub struct TrigramIndex {
    /// Trigram -> sorted, deduplicated ids of the documents containing it.
    postings: HashMap<Trigram, Vec<DocId>>,
    /// Original text of every indexed document, keyed by id.
    docs: HashMap<DocId, String>,
}

impl TrigramIndex {
    /// Creates an empty index.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the number of indexed documents.
    pub fn len(&self) -> usize {
        self.docs.len()
    }

    /// Returns `true` if no documents are indexed.
    pub fn is_empty(&self) -> bool {
        self.docs.is_empty()
    }

    /// Returns the number of distinct trigrams currently indexed.
    ///
    /// Only trigrams occurring in at least one document are counted.
    pub fn vocab_size(&self) -> usize {
        self.postings.len()
    }

    /// Returns `true` if a document with `doc_id` is indexed.
    pub fn contains(&self, doc_id: DocId) -> bool {
        self.docs.contains_key(&doc_id)
    }

    /// Returns the stored text of `doc_id`, or `None` if it is not indexed.
    pub fn doc_text(&self, doc_id: DocId) -> Option<&str> {
        self.docs.get(&doc_id).map(String::as_str)
    }

    /// Iterates over all `(id, text)` pairs, in unspecified (hash-map) order.
    pub fn documents(&self) -> impl Iterator<Item = (DocId, &str)> {
        self.docs.iter().map(|(id, text)| (*id, text.as_str()))
    }

    /// Indexes `text` under `doc_id`, replacing any document already stored
    /// under that id.
    pub fn insert(&mut self, doc_id: DocId, text: &str) {
        // Drop the previous version first so its stale trigrams do not keep
        // pointing at this id. `remove` is a no-op for unknown ids.
        self.remove(doc_id);
        for tri in trigrams_of(text) {
            let postings = self.postings.entry(tri).or_default();
            // Insert at the sorted position to keep the list ordered and unique.
            if let Err(pos) = postings.binary_search(&doc_id) {
                postings.insert(pos, doc_id);
            }
        }
        self.docs.insert(doc_id, text.to_string());
    }

    /// Removes `doc_id` and its postings from the index.
    ///
    /// Returns `true` if the document was present, `false` otherwise.
    pub fn remove(&mut self, doc_id: DocId) -> bool {
        let Some(text) = self.docs.remove(&doc_id) else {
            return false;
        };
        for tri in trigrams_of(&text) {
            if let Some(postings) = self.postings.get_mut(&tri) {
                if let Ok(pos) = postings.binary_search(&doc_id) {
                    postings.remove(pos);
                }
                // Drop exhausted lists so `vocab_size` and lookups stay accurate.
                if postings.is_empty() {
                    self.postings.remove(&tri);
                }
            }
        }
        true
    }

    /// Returns the ids of all documents containing every trigram in
    /// `required`, sorted ascending.
    ///
    /// The result is a candidate set, i.e. a superset of the documents that
    /// truly match a query. Callers are expected to verify each candidate
    /// against the actual text.
    ///
    /// If `required` is empty (for example, because the query is shorter than
    /// three characters), no document can be ruled out. In that case every
    /// indexed id is returned rather than none.
    pub fn candidates(&self, required: &[Trigram]) -> Vec<DocId> {
        if required.is_empty() {
            // The intersection over zero constraints is the whole corpus.
            let mut all: Vec<DocId> = self.docs.keys().copied().collect();
            all.sort_unstable();
            return all;
        }

        // A trigram absent from the index means no document can contain it.
        let Some(mut lists) = required
            .iter()
            .map(|tri| self.postings.get(tri).map(Vec::as_slice))
            .collect::<Option<Vec<&[DocId]>>>()
        else {
            return Vec::new();
        };

        // Intersect from the shortest list up: the accumulator never exceeds
        // the smallest list, and an empty result stops the loop early.
        lists.sort_unstable_by_key(|list| list.len());
        // `required` is non-empty, so `lists` holds at least one entry.
        let mut acc = lists[0].to_vec();
        for list in &lists[1..] {
            acc = sorted_intersect(&acc, list);
            if acc.is_empty() {
                break;
            }
        }
        acc
    }
}
/// Intersects two ascending id slices with a single linear merge pass.
///
/// Returns the ids present in both `a` and `b`, in ascending order. Runs in
/// `O(a.len() + b.len())` time.
fn sorted_intersect(a: &[DocId], b: &[DocId]) -> Vec<DocId> {
    use std::cmp::Ordering;

    let mut common = Vec::with_capacity(a.len().min(b.len()));
    let mut left = a.iter().peekable();
    let mut right = b.iter().peekable();
    // Advance whichever side holds the smaller head; emit on equal heads.
    while let (Some(&&x), Some(&&y)) = (left.peek(), right.peek()) {
        match x.cmp(&y) {
            Ordering::Less => {
                left.next();
            }
            Ordering::Greater => {
                right.next();
            }
            Ordering::Equal => {
                common.push(x);
                left.next();
                right.next();
            }
        }
    }
    common
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an index by inserting the given `(id, text)` pairs in order.
    fn index_of(entries: &[(DocId, &str)]) -> TrigramIndex {
        let mut index = TrigramIndex::new();
        for &(id, text) in entries {
            index.insert(id, text);
        }
        index
    }

    #[test]
    fn test_trigrams_basic() {
        let nothing: Vec<Trigram> = Vec::new();
        assert_eq!(trigrams_of("ab"), nothing);
        assert_eq!(trigrams_of("abc"), vec![('a', 'b', 'c')]);
        assert_eq!(trigrams_of("AAA"), vec![('a', 'a', 'a')]);
    }

    #[test]
    fn test_insert_and_candidates() {
        let index = index_of(&[(1, "hello world"), (2, "help me"), (3, "world peace")]);
        let hel = index.candidates(&trigrams_of("hel"));
        assert_eq!(hel, vec![1, 2]);
        let world = index.candidates(&trigrams_of("world"));
        assert_eq!(world, vec![1, 3]);
        assert!(index.candidates(&trigrams_of("xyz")).is_empty());
    }

    #[test]
    fn test_candidates_never_drop_a_true_match() {
        let index = index_of(&[
            (1, "fn parse_query() {}"),
            (2, "let parser = build();"),
            (3, "totally unrelated text"),
            (4, "PARSE in caps"),
        ]);
        let hits = index.candidates(&trigrams_of("parse"));
        for expected in [1, 4] {
            assert!(hits.contains(&expected));
        }
        assert!(!hits.contains(&3));
    }

    #[test]
    fn test_remove_is_clean() {
        let mut index = index_of(&[(1, "alpha beta"), (2, "alpha gamma")]);
        let initial_vocab = index.vocab_size();

        assert!(index.remove(1));
        assert!(!index.contains(1));
        assert_eq!(index.len(), 1);
        assert!(index.candidates(&trigrams_of("beta")).is_empty());
        assert_eq!(index.candidates(&trigrams_of("alpha")), vec![2]);

        assert!(index.remove(2));
        assert!(index.is_empty());
        assert_eq!(index.vocab_size(), 0);
        assert!(initial_vocab > 0);
    }

    #[test]
    fn test_reinsert_replaces() {
        let index = index_of(&[(1, "before"), (1, "after change")]);
        assert_eq!(index.len(), 1);
        assert!(index.candidates(&trigrams_of("before")).is_empty());
        assert_eq!(index.candidates(&trigrams_of("change")), vec![1]);
    }
}