use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
/// Corpus cap used when `TRUSTY_BM25_CORPUS_CAP` is unset, not a valid
/// `usize`, or zero.
const DEFAULT_BM25_CORPUS_CAP: usize = 50_000;

/// Maximum number of live documents the BM25 index will accept.
///
/// The `TRUSTY_BM25_CORPUS_CAP` environment variable is consulted on every
/// call. Only a strictly positive integer overrides the default; anything else
/// (missing, unparsable, `0`) yields [`DEFAULT_BM25_CORPUS_CAP`].
fn bm25_corpus_cap() -> usize {
    match std::env::var("TRUSTY_BM25_CORPUS_CAP") {
        Ok(raw) => match raw.parse::<usize>() {
            Ok(n) if n > 0 => n,
            _ => DEFAULT_BM25_CORPUS_CAP,
        },
        Err(_) => DEFAULT_BM25_CORPUS_CAP,
    }
}

/// Process-wide latch, flipped the first time the corpus-cap warning is
/// emitted, so the warning is logged at most once.
static BM25_CAP_LOGGED: AtomicBool = AtomicBool::new(false);
/// Tokenizes `text` for BM25 indexing and querying.
///
/// The text is split on every non-alphanumeric character. Each non-empty word
/// contributes its lowercased form. When a word actually splits, it also
/// contributes its lowercased camelCase parts (via `split_camel_case`) and its
/// lowercased letter/digit runs (via `split_on_digits`).
///
/// The returned tokens are sorted and de-duplicated, so each token appears
/// exactly once.
pub fn tokenize(text: &str) -> Vec<String> {
    let mut out: Vec<String> = Vec::new();
    let words = text
        .split(|ch: char| !ch.is_alphanumeric())
        .filter(|word| !word.is_empty());
    for word in words {
        out.push(word.to_lowercase());
        // A single-element split is just the word itself, already pushed above.
        for parts in [split_camel_case(word), split_on_digits(word)] {
            if parts.len() > 1 {
                out.extend(parts.into_iter().map(str::to_lowercase));
            }
        }
    }
    out.sort_unstable();
    out.dedup();
    out
}
/// Splits an identifier at camelCase / PascalCase word boundaries.
///
/// A new part starts at an uppercase character that follows either a
/// lowercase letter or an ASCII digit (`fooBar`, `HTTP2Client`). A new part
/// also starts at the last capital of an uppercase run when a lowercase letter
/// follows it (`HTTPSClient` → `HTTPS`, `Client`).
///
/// Inputs shorter than two characters are returned unchanged as a single
/// slice. All returned slices borrow from `s`.
fn split_camel_case(s: &str) -> Vec<&str> {
    let indexed: Vec<(usize, char)> = s.char_indices().collect();
    if indexed.len() < 2 {
        return vec![s];
    }
    let mut parts: Vec<&str> = Vec::new();
    let mut start = 0;
    for (pos, pair) in indexed.windows(2).enumerate() {
        let before = pair[0].1;
        let (byte_at, here) = pair[1];
        if !here.is_uppercase() {
            continue;
        }
        // Lookahead for the acronym rule: the character after `here`, if any.
        let next_is_lower =
            matches!(indexed.get(pos + 2), Some(&(_, after)) if after.is_lowercase());
        let after_lower_or_digit = before.is_lowercase() || before.is_ascii_digit();
        let acronym_ends = before.is_uppercase() && next_is_lower;
        if after_lower_or_digit || acronym_ends {
            parts.push(&s[start..byte_at]);
            start = byte_at;
        }
    }
    parts.push(&s[start..]);
    parts.retain(|p| !p.is_empty());
    parts
}
/// Splits `s` wherever an alphabetic character meets an ASCII digit, in either
/// direction (`HTTP2Client` → `HTTP`, `2`, `Client`).
///
/// Runs of digits or letters stay together. Inputs shorter than two characters
/// are returned unchanged as a single slice. All returned slices borrow from
/// `s`.
fn split_on_digits(s: &str) -> Vec<&str> {
    let indexed: Vec<(usize, char)> = s.char_indices().collect();
    if indexed.len() < 2 {
        return vec![s];
    }
    let mut pieces: Vec<&str> = Vec::new();
    let mut from = 0;
    for pair in indexed.windows(2) {
        let prev = pair[0].1;
        let (at, cur) = pair[1];
        let letter_then_digit = prev.is_alphabetic() && cur.is_ascii_digit();
        let digit_then_letter = prev.is_ascii_digit() && cur.is_alphabetic();
        if !(letter_then_digit || digit_then_letter) {
            continue;
        }
        pieces.push(&s[from..at]);
        from = at;
    }
    pieces.push(&s[from..]);
    pieces.retain(|p| !p.is_empty());
    pieces
}
/// Incremental BM25 inverted index keyed by caller-supplied string ids.
///
/// Each document occupies a numbered internal *slot*. Slots freed by
/// `remove_document` are reused by later inserts, so a slot number is only
/// meaningful while its document is live.
#[derive(Debug)]
pub struct BM25Index {
    /// BM25 term-frequency saturation parameter.
    k1: f32,
    /// BM25 document-length normalisation parameter (0 disables it).
    b: f32,
    /// Number of live documents containing each term.
    doc_freqs: HashMap<String, usize>,
    /// Token count per slot; `None` for a free slot.
    doc_lengths: Vec<Option<usize>>,
    /// Postings per term: `(slot, occurrences of the term in that document)`.
    inverted: HashMap<String, Vec<(usize, usize)>>,
    /// Sum of the live documents' lengths, used for the average length.
    total_doc_length: u64,
    /// Caller id → slot.
    id_to_slot: HashMap<String, usize>,
    /// Slot → caller id; `None` for a free slot.
    slot_to_id: Vec<Option<String>>,
    /// Slots available for reuse by the next insert.
    free_slots: Vec<usize>,
    /// Tokens indexed for each slot, kept so removal can undo its
    /// contribution to `doc_freqs` and `inverted`.
    doc_terms: Vec<Option<Vec<String>>>,
    /// Number of occupied slots.
    live_docs: usize,
}
impl BM25Index {
    /// Creates an empty index with the customary BM25 parameters
    /// `k1 = 1.5` and `b = 0.75`.
    pub fn new() -> Self {
        Self {
            k1: 1.5,
            b: 0.75,
            doc_freqs: HashMap::new(),
            doc_lengths: Vec::new(),
            inverted: HashMap::new(),
            total_doc_length: 0,
            id_to_slot: HashMap::new(),
            slot_to_id: Vec::new(),
            free_slots: Vec::new(),
            doc_terms: Vec::new(),
            live_docs: 0,
        }
    }

    /// Returns the number of live (indexed and not removed) documents.
    pub fn len(&self) -> usize {
        self.live_docs
    }

    /// Returns `true` when the index holds no live documents.
    pub fn is_empty(&self) -> bool {
        self.live_docs == 0
    }

    /// Mean token count of the live documents, or `0.0` for an empty index.
    fn avg_doc_len(&self) -> f32 {
        if self.live_docs == 0 {
            0.0
        } else {
            self.total_doc_length as f32 / self.live_docs as f32
        }
    }

    /// String id under which a legacy numeric id is stored in `id_to_slot`.
    ///
    /// Shared by `add_document` and `score`, so both resolve the same `usize`
    /// to the same document.
    fn legacy_key(doc_id: usize) -> String {
        format!("__legacy:{doc_id}")
    }

    /// Smoothed BM25 inverse document frequency of a term that appears in
    /// `df` of `n` live documents.
    fn idf(n: f32, df: f32) -> f32 {
        ((n - df + 0.5) / (df + 0.5) + 1.0).ln()
    }

    /// Saturated, length-normalised term frequency.
    ///
    /// `tf` is the term's count in the document, `dl` is the document length
    /// and `avg` is the (clamped) average document length.
    fn tf_norm(&self, tf: f32, dl: f32, avg: f32) -> f32 {
        tf * (self.k1 + 1.0) / (tf + self.k1 * (1.0 - self.b + self.b * dl / avg))
    }

    /// Reserves a slot for `doc_id`, reusing a freed slot when one is available.
    ///
    /// The slot starts with length `Some(0)` and an empty term list. The caller
    /// must register it in `id_to_slot` and fill in its contents.
    fn allocate_slot(&mut self, doc_id: &str) -> usize {
        if let Some(slot) = self.free_slots.pop() {
            self.slot_to_id[slot] = Some(doc_id.to_string());
            self.doc_lengths[slot] = Some(0);
            self.doc_terms[slot] = Some(Vec::new());
            slot
        } else {
            let slot = self.slot_to_id.len();
            self.slot_to_id.push(Some(doc_id.to_string()));
            self.doc_lengths.push(Some(0));
            self.doc_terms.push(Some(Vec::new()));
            slot
        }
    }

    /// Indexes `text` under `doc_id`, replacing any existing document with that id.
    ///
    /// Replacing an existing id is always allowed. A *new* id is silently
    /// dropped once `len()` reaches `bm25_corpus_cap()`. The first such drop
    /// in the process logs a warning.
    ///
    /// NOTE(review): `tokenize` returns sorted, de-duplicated tokens. As a
    /// result, every per-document term count below is 1 and a document's
    /// length is its number of distinct tokens, so BM25's term-frequency
    /// saturation is effectively unused. Confirm this is intended.
    pub fn upsert_document(&mut self, doc_id: &str, text: &str) {
        if self.id_to_slot.contains_key(doc_id) {
            self.remove_document(doc_id);
        } else {
            let cap = bm25_corpus_cap();
            if self.live_docs >= cap {
                if !BM25_CAP_LOGGED.swap(true, Ordering::Relaxed) {
                    tracing::warn!(
                        cap,
                        live_docs = self.live_docs,
                        "BM25 corpus cap reached — dropping further new documents \
                         (override with TRUSTY_BM25_CORPUS_CAP)"
                    );
                }
                return;
            }
        }
        let slot = self.allocate_slot(doc_id);
        self.id_to_slot.insert(doc_id.to_string(), slot);
        self.live_docs += 1;
        let tokens = tokenize(text);
        let doc_len = tokens.len();
        self.doc_lengths[slot] = Some(doc_len);
        self.total_doc_length = self.total_doc_length.saturating_add(doc_len as u64);
        let mut term_counts: HashMap<&str, usize> = HashMap::new();
        for t in &tokens {
            *term_counts.entry(t.as_str()).or_default() += 1;
        }
        for (term, count) in term_counts {
            *self.doc_freqs.entry(term.to_string()).or_default() += 1;
            self.inverted
                .entry(term.to_string())
                .or_default()
                .push((slot, count));
        }
        self.doc_terms[slot] = Some(tokens);
    }

    /// Legacy entry point: indexes `text` under the numeric id `doc_id`.
    ///
    /// Equivalent to `upsert_document` with the synthetic id
    /// `"__legacy:{doc_id}"`. Pass the same `doc_id` to `score` to score it.
    pub fn add_document(&mut self, doc_id: usize, text: &str) {
        self.upsert_document(&Self::legacy_key(doc_id), text);
    }

    /// Removes `doc_id` from the index and frees its slot for reuse.
    ///
    /// Updates document frequencies, postings and length statistics.
    /// Unknown ids are ignored.
    pub fn remove_document(&mut self, doc_id: &str) {
        let Some(slot) = self.id_to_slot.remove(doc_id) else {
            return;
        };
        if let Some(mut terms) = self.doc_terms[slot].take() {
            // `take()` already gave us ownership, so sort/dedup in place (no
            // clone). Each term's document frequency is then decremented once.
            terms.sort_unstable();
            terms.dedup();
            for term in &terms {
                if let Some(df) = self.doc_freqs.get_mut(term) {
                    *df = df.saturating_sub(1);
                    if *df == 0 {
                        self.doc_freqs.remove(term);
                    }
                }
                if let Some(postings) = self.inverted.get_mut(term) {
                    postings.retain(|(s, _)| *s != slot);
                    if postings.is_empty() {
                        self.inverted.remove(term);
                    }
                }
            }
        }
        if let Some(old_len) = self.doc_lengths[slot] {
            self.total_doc_length = self.total_doc_length.saturating_sub(old_len as u64);
        }
        self.doc_lengths[slot] = None;
        self.slot_to_id[slot] = None;
        self.free_slots.push(slot);
        self.live_docs = self.live_docs.saturating_sub(1);
    }

    /// Scores every live document against `query` and returns the best `top_k`.
    ///
    /// Results are `(doc_id, score)` pairs with strictly positive scores. They
    /// are sorted by descending score, with ties broken by ascending id.
    /// Returns an empty vector when the index is empty or `top_k == 0`.
    pub fn score_query_all(&self, query: &str, top_k: usize) -> Vec<(String, f32)> {
        if self.live_docs == 0 || top_k == 0 {
            return Vec::new();
        }
        let n = self.live_docs as f32;
        let avg = self.avg_doc_len().max(1.0);
        let mut acc: HashMap<usize, f32> = HashMap::new();
        for term in tokenize(query) {
            let df = match self.doc_freqs.get(&term) {
                Some(d) if *d > 0 => *d as f32,
                _ => continue,
            };
            // Computed once per query term and reused across its postings.
            let idf = Self::idf(n, df);
            let Some(postings) = self.inverted.get(&term) else {
                continue;
            };
            for (slot, count) in postings {
                let dl = match self.doc_lengths.get(*slot).and_then(|x| *x) {
                    Some(l) => l as f32,
                    None => continue,
                };
                *acc.entry(*slot).or_insert(0.0) += idf * self.tf_norm(*count as f32, dl, avg);
            }
        }
        let mut scored: Vec<(String, f32)> = acc
            .into_iter()
            .filter(|(_, s)| *s > 0.0)
            .filter_map(|(slot, score)| {
                self.slot_to_id
                    .get(slot)
                    .and_then(|o| o.clone())
                    .map(|id| (id, score))
            })
            .collect();
        scored.sort_by(|a, b| {
            b.1.partial_cmp(&a.1)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| a.0.cmp(&b.0))
        });
        scored.truncate(top_k);
        scored
    }

    /// Legacy scorer: returns the BM25 score of `query` against the document
    /// indexed via `add_document(doc_id, …)`.
    ///
    /// Returns `0.0` when no such document is indexed.
    pub fn score(&self, query: &str, doc_id: usize) -> f32 {
        // Resolve the caller's id the same way `add_document` stored it. Slots
        // are allocated in insertion order and reused after removal, so a
        // legacy id is not, in general, equal to its slot.
        let Some(&slot) = self.id_to_slot.get(&Self::legacy_key(doc_id)) else {
            return 0.0;
        };
        let dl = match self.doc_lengths.get(slot).and_then(|x| *x) {
            Some(l) => l as f32,
            None => return 0.0,
        };
        let n = self.live_docs as f32;
        // Independent of the query term, so computed once rather than per term.
        let avg = self.avg_doc_len().max(1.0);
        let mut score = 0.0f32;
        for term in tokenize(query) {
            let df = *self.doc_freqs.get(&term).unwrap_or(&0) as f32;
            if df == 0.0 {
                continue;
            }
            let tf = self
                .inverted
                .get(&term)
                .and_then(|v| v.iter().find(|(s, _)| *s == slot))
                .map(|(_, c)| *c as f32)
                .unwrap_or(0.0);
            score += Self::idf(n, df) * self.tf_norm(tf, dl, avg);
        }
        score
    }
}
impl Default for BM25Index {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn bm25_scores_relevant_doc_higher() {
        let mut idx = BM25Index::new();
        idx.add_document(0, "authentication login password secure");
        idx.add_document(1, "rendering ui components svelte");
        let s0 = idx.score("authentication", 0);
        let s1 = idx.score("authentication", 1);
        assert!(s0 > s1, "relevant doc should score higher: {s0} vs {s1}");
    }

    #[test]
    fn tokenize_splits_code() {
        let tokens = tokenize("fn search_hybrid(query: &str) -> Vec<Hit>");
        assert!(tokens.contains(&"search".to_string()));
        assert!(tokens.contains(&"hybrid".to_string()));
        assert!(tokens.contains(&"query".to_string()));
    }

    #[test]
    fn tokenize_camel_case_pascal() {
        let tokens = tokenize("CodeIndexer");
        assert!(tokens.contains(&"code".to_string()), "got {tokens:?}");
        assert!(tokens.contains(&"indexer".to_string()), "got {tokens:?}");
        assert!(
            tokens.contains(&"codeindexer".to_string()),
            "got {tokens:?}"
        );
    }

    #[test]
    fn tokenize_pascal_two_words() {
        let tokens = tokenize("UsearchStore");
        assert!(tokens.contains(&"usearch".to_string()), "got {tokens:?}");
        assert!(tokens.contains(&"store".to_string()), "got {tokens:?}");
    }

    #[test]
    fn tokenize_snake_case() {
        let tokens = tokenize("use_kg_first");
        assert!(tokens.contains(&"use".to_string()), "got {tokens:?}");
        assert!(tokens.contains(&"kg".to_string()), "got {tokens:?}");
        assert!(tokens.contains(&"first".to_string()), "got {tokens:?}");
    }

    #[test]
    fn tokenize_alpha_digit_split() {
        let tokens = tokenize("HTTP2Client");
        assert!(tokens.contains(&"http".to_string()), "got {tokens:?}");
        assert!(tokens.contains(&"2".to_string()), "got {tokens:?}");
        assert!(tokens.contains(&"client".to_string()), "got {tokens:?}");
    }

    #[test]
    fn tokenize_acronym_then_word() {
        let tokens = tokenize("HTTPSClient");
        assert!(tokens.contains(&"https".to_string()), "got {tokens:?}");
        assert!(tokens.contains(&"client".to_string()), "got {tokens:?}");
    }

    #[test]
    fn bm25_incremental_upsert_and_remove() {
        let mut idx = BM25Index::new();
        idx.upsert_document("a", "authentication login password");
        idx.upsert_document("b", "rendering ui components svelte");
        idx.upsert_document("c", "database connection pool postgres");
        assert_eq!(idx.len(), 3);
        let hits = idx.score_query_all("authentication", 10);
        assert!(hits.iter().any(|(id, _)| id == "a"));
        assert!(!hits.iter().any(|(id, _)| id == "b"));
        idx.remove_document("a");
        assert_eq!(idx.len(), 2);
        let hits_after = idx.score_query_all("authentication", 10);
        assert!(!hits_after.iter().any(|(id, _)| id == "a"));
        let svelte_hits = idx.score_query_all("svelte", 10);
        assert!(svelte_hits.iter().any(|(id, _)| id == "b"));
    }

    #[test]
    fn bm25_upsert_replaces_existing_doc() {
        let mut idx = BM25Index::new();
        idx.upsert_document("a", "alpha beta gamma");
        idx.upsert_document("a", "delta epsilon");
        assert_eq!(idx.len(), 1);
        assert!(idx.score_query_all("alpha", 10).is_empty());
        assert!(!idx.score_query_all("delta", 10).is_empty());
    }

    #[test]
    fn score_query_all_returns_sorted_unique_results() {
        let mut idx = BM25Index::new();
        idx.upsert_document("a", "search rust async tokio");
        idx.upsert_document("b", "search rust");
        idx.upsert_document("c", "unrelated content");
        let hits = idx.score_query_all("rust async", 10);
        for w in hits.windows(2) {
            assert!(w[0].1 >= w[1].1, "results must be sorted desc: {hits:?}");
        }
        let mut ids: Vec<&str> = hits.iter().map(|(id, _)| id.as_str()).collect();
        ids.sort();
        let unique = ids.len();
        ids.dedup();
        assert_eq!(unique, ids.len());
    }

    #[test]
    fn tokenize_dedups_and_sorts() {
        let tokens = tokenize("foo foo bar");
        let foos: Vec<&String> = tokens.iter().filter(|t| t.as_str() == "foo").collect();
        assert_eq!(foos.len(), 1, "duplicates must collapse: {tokens:?}");
        let mut sorted = tokens.clone();
        sorted.sort();
        assert_eq!(tokens, sorted, "tokens must be sorted: {tokens:?}");
    }

    #[test]
    fn bm25_corpus_cap_env_override() {
        /// Restores the pre-test value of `TRUSTY_BM25_CORPUS_CAP` when
        /// dropped. The restore therefore also happens when an assertion below
        /// panics, so a failure cannot leak the override into other tests.
        struct RestoreEnv(Option<String>);

        impl Drop for RestoreEnv {
            fn drop(&mut self) {
                match self.0.take() {
                    // SAFETY: the other tests in this binary touch the
                    // environment only through `std::env` (via
                    // `bm25_corpus_cap`), which std synchronises with its own
                    // env accessors. NOTE(review): this assumes no foreign
                    // code reads the environment concurrently — confirm.
                    Some(v) => unsafe { std::env::set_var("TRUSTY_BM25_CORPUS_CAP", v) },
                    // SAFETY: same reasoning as the `set_var` arm above.
                    None => unsafe { std::env::remove_var("TRUSTY_BM25_CORPUS_CAP") },
                }
            }
        }

        let _restore = RestoreEnv(std::env::var("TRUSTY_BM25_CORPUS_CAP").ok());

        // SAFETY: concurrent environment access in this test binary goes
        // through `std::env` only (see `RestoreEnv::drop`).
        unsafe {
            std::env::set_var("TRUSTY_BM25_CORPUS_CAP", "0");
        }
        assert_eq!(
            bm25_corpus_cap(),
            DEFAULT_BM25_CORPUS_CAP,
            "zero must fall back to default"
        );

        // SAFETY: as above.
        unsafe {
            std::env::set_var("TRUSTY_BM25_CORPUS_CAP", "123");
        }
        assert_eq!(bm25_corpus_cap(), 123, "positive value must be honoured");
        // `_restore` puts the original value back when it goes out of scope.
    }
}