use std::collections::HashMap;
use std::sync::Arc;
use parking_lot::RwLock;
/// A single ranked search hit: the matching document and its BM25 relevance score.
#[derive(Clone, Debug, PartialEq)]
pub struct BM25Score {
    /// Identifier of the matching document, as supplied when it was indexed.
    pub rid: String,
    /// Accumulated relevance score for the document; higher sorts first in search results.
    pub score: f32,
}
/// A keyword index that ranks documents against a free-text query.
///
/// Every method takes `&self`, and implementors must be `Send + Sync`, so an index can be
/// shared behind an `Arc<dyn BM25Index>`.
pub trait BM25Index: Send + Sync {
    /// Indexes `text` under the document identifier `rid`.
    ///
    /// The in-memory implementation in this file replaces any document previously stored
    /// under the same `rid`.
    fn index(&self, rid: &str, text: &str);

    /// Removes the document `rid` from the index.
    ///
    /// The in-memory implementation in this file returns `true` if the document was present
    /// and `false` otherwise.
    fn delete(&self, rid: &str) -> bool;

    /// Returns scored hits for `query`, limited by `top_k`.
    ///
    /// The in-memory implementation in this file returns at most `top_k` hits, best first.
    fn search(&self, query: &str, top_k: usize) -> Vec<BM25Score>;

    /// Returns the number of documents currently held by the index.
    fn doc_count(&self) -> usize;
}
/// Splits `s` into lowercase tokens made of ASCII letters and digits.
///
/// Every character that is not ASCII-alphanumeric (punctuation, whitespace, `_`, and all
/// non-ASCII characters) acts as a separator. Empty tokens are never produced.
pub fn ascii_lower_tokens(s: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    let mut current = String::new();
    for ch in s.chars() {
        if ch.is_ascii_alphanumeric() {
            current.push(ch.to_ascii_lowercase());
        } else if !current.is_empty() {
            // A separator ends the token in progress; `take` leaves `current` empty for reuse.
            tokens.push(std::mem::take(&mut current));
        }
    }
    // Flush a token that runs to the end of the input.
    if !current.is_empty() {
        tokens.push(current);
    }
    tokens
}
#[derive(Default)]
pub struct InMemoryBM25Index {
inner: Arc<RwLock<State>>,
k1: f32,
b: f32,
}
/// Mutable index contents guarded by the index's lock.
#[derive(Default)]
struct State {
    /// Per-document statistics, keyed by document id.
    docs: HashMap<String, DocStats>,
    /// Inverted index: term -> (document id -> occurrences of the term in that document).
    inverted: HashMap<String, HashMap<String, u32>>,
    /// Sum of the token counts of all indexed documents; used for the average length.
    total_tokens: u64,
}
/// Statistics recorded for one indexed document.
struct DocStats {
    /// Occurrence count of each distinct term in the document.
    tf: HashMap<String, u32>,
    /// Total number of tokens in the document.
    length: u32,
}
impl InMemoryBM25Index {
    /// Creates an empty index with `k1 = 1.2` and `b = 0.75`.
    pub fn new() -> Self {
        let state = RwLock::new(State::default());
        Self {
            inner: Arc::new(state),
            k1: 1.2,
            b: 0.75,
        }
    }

    /// Returns this index with its `k1` and `b` scoring parameters replaced.
    ///
    /// The index contents are carried over unchanged; the values are stored as given.
    pub fn with_params(self, k1: f32, b: f32) -> Self {
        Self { k1, b, ..self }
    }
}
impl BM25Index for InMemoryBM25Index {
fn index(&self, rid: &str, text: &str) {
let tokens = ascii_lower_tokens(text);
let length = tokens.len() as u32;
let mut tf: HashMap<String, u32> = HashMap::new();
for t in &tokens {
*tf.entry(t.clone()).or_insert(0) += 1;
}
let mut g = self.inner.write();
if let Some(prev) = g.docs.remove(rid) {
g.total_tokens = g.total_tokens.saturating_sub(prev.length as u64);
for (term, count) in prev.tf.iter() {
if let Some(postings) = g.inverted.get_mut(term) {
postings.remove(rid);
if postings.is_empty() {
g.inverted.remove(term);
}
}
let _ = count; }
}
for (term, count) in &tf {
g.inverted
.entry(term.clone())
.or_default()
.insert(rid.to_string(), *count);
}
g.total_tokens = g.total_tokens.saturating_add(length as u64);
g.docs.insert(rid.to_string(), DocStats { tf, length });
}
fn delete(&self, rid: &str) -> bool {
let mut g = self.inner.write();
let Some(prev) = g.docs.remove(rid) else {
return false;
};
g.total_tokens = g.total_tokens.saturating_sub(prev.length as u64);
for term in prev.tf.keys() {
if let Some(postings) = g.inverted.get_mut(term) {
postings.remove(rid);
if postings.is_empty() {
g.inverted.remove(term);
}
}
}
true
}
fn search(&self, query: &str, top_k: usize) -> Vec<BM25Score> {
let q_tokens = ascii_lower_tokens(query);
if q_tokens.is_empty() {
return Vec::new();
}
let g = self.inner.read();
let n_docs = g.docs.len() as f32;
if n_docs == 0.0 {
return Vec::new();
}
let avg_dl = (g.total_tokens as f32) / n_docs;
let mut scores: HashMap<String, f32> = HashMap::new();
for term in &q_tokens {
let Some(postings) = g.inverted.get(term) else {
continue;
};
let df = postings.len() as f32;
let idf = (((n_docs - df + 0.5) / (df + 0.5)) + 1.0).ln();
for (rid, &tf_u) in postings {
let tf = tf_u as f32;
let dl = g.docs.get(rid).map(|d| d.length as f32).unwrap_or(0.0);
let denom = tf + self.k1 * (1.0 - self.b + self.b * (dl / avg_dl.max(1.0)));
let term_score = idf * ((tf * (self.k1 + 1.0)) / denom.max(1e-9));
*scores.entry(rid.clone()).or_insert(0.0) += term_score;
}
}
let mut hits: Vec<BM25Score> = scores
.into_iter()
.map(|(rid, score)| BM25Score { rid, score })
.collect();
hits.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
.then_with(|| a.rid.cmp(&b.rid))
});
hits.truncate(top_k);
hits
}
fn doc_count(&self) -> usize {
self.inner.read().docs.len()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `rid -> score` lookup from a list of hits.
    fn scores_by_rid(hits: &[BM25Score]) -> HashMap<&str, f32> {
        hits.iter().map(|h| (h.rid.as_str(), h.score)).collect()
    }

    /// Tokenizer lowercases and splits on anything that is not an ASCII letter or digit.
    #[test]
    fn ascii_lower_tokens_basic() {
        assert_eq!(
            ascii_lower_tokens("Hello World! FOO_bar 123"),
            vec!["hello", "world", "foo", "bar", "123"]
        );
    }

    /// A query without tokens matches nothing, even when documents exist.
    #[test]
    fn empty_query_yields_empty_results() {
        let index = InMemoryBM25Index::new();
        index.index("a", "the quick brown fox");
        let results = index.search("", 10);
        assert!(results.is_empty());
    }

    /// Searching an index with no documents returns nothing.
    #[test]
    fn empty_index_yields_empty_results() {
        let index = InMemoryBM25Index::new();
        let results = index.search("anything", 10);
        assert!(results.is_empty());
    }

    /// Only documents that contain the query term are returned.
    #[test]
    fn search_finds_documents_containing_term() {
        let index = InMemoryBM25Index::new();
        let corpus = [
            ("a", "rust programming language"),
            ("b", "python programming language"),
            ("c", "fishing tackle reviews"),
        ];
        for (rid, text) in corpus {
            index.index(rid, text);
        }
        let results = index.search("rust", 10);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].rid, "a");
    }

    /// The single document holding a rare term is the top hit for that term.
    #[test]
    fn rare_term_outranks_common_term() {
        let index = InMemoryBM25Index::new();
        (0..10).for_each(|i| index.index(&format!("common-{}", i), "the"));
        index.index("rare", "rust the");
        let results = index.search("rust", 10);
        assert_eq!(results[0].rid, "rare");
    }

    /// More occurrences score higher, but with diminishing returns.
    #[test]
    fn term_frequency_increases_score_but_saturates() {
        let index = InMemoryBM25Index::new();
        index.index("once", "rust");
        index.index("twice", "rust rust");
        index.index("ten", "rust rust rust rust rust rust rust rust rust rust");
        let results = index.search("rust", 10);
        let by_rid = scores_by_rid(&results);
        let (once, twice, ten) = (by_rid["once"], by_rid["twice"], by_rid["ten"]);
        assert!(once < twice);
        assert!(twice < ten);
        let gap_low = twice - once;
        let gap_high = ten - twice;
        assert!(
            gap_low > gap_high,
            "expected saturation: low gap {}, high gap {}",
            gap_low,
            gap_high
        );
    }

    /// With equal term frequency, the shorter document scores higher.
    #[test]
    fn longer_documents_penalized_via_length_norm() {
        let index = InMemoryBM25Index::new();
        index.index("short", "rust");
        index.index("long", "rust the the the the the the the the the");
        let results = index.search("rust", 10);
        let by_rid = scores_by_rid(&results);
        assert!(by_rid["short"] > by_rid["long"]);
    }

    /// Deleting removes the document from results and reports whether it existed.
    #[test]
    fn delete_removes_doc() {
        let index = InMemoryBM25Index::new();
        index.index("a", "rust programming");
        index.index("b", "rust language");
        assert_eq!(index.search("rust", 10).len(), 2);

        let first_delete = index.delete("a");
        assert!(first_delete);
        assert_eq!(index.search("rust", 10).len(), 1);

        let second_delete = index.delete("a");
        assert!(!second_delete);
    }

    /// Indexing an existing id replaces its old text entirely.
    #[test]
    fn re_index_replaces_old_text() {
        let index = InMemoryBM25Index::new();
        index.index("a", "rust");
        index.index("a", "python");
        assert!(index.search("rust", 10).is_empty());
        assert_eq!(index.search("python", 10).len(), 1);
    }

    /// `doc_count` follows inserts and deletes.
    #[test]
    fn doc_count_tracks_inserts_and_deletes() {
        let index = InMemoryBM25Index::new();
        assert_eq!(index.doc_count(), 0);
        index.index("a", "x");
        index.index("b", "y");
        assert_eq!(index.doc_count(), 2);
        index.delete("a");
        assert_eq!(index.doc_count(), 1);
    }

    /// No more than `top_k` hits are returned.
    #[test]
    fn truncates_to_top_k() {
        let index = InMemoryBM25Index::new();
        (0..10).for_each(|i| index.index(&format!("d{}", i), "rust"));
        let results = index.search("rust", 3);
        assert_eq!(results.len(), 3);
    }

    /// A document matching both query terms outranks documents matching only one.
    #[test]
    fn multi_term_query_combines_scores() {
        let index = InMemoryBM25Index::new();
        for (rid, text) in [("a", "rust"), ("b", "programming"), ("c", "rust programming")] {
            index.index(rid, text);
        }
        let results = index.search("rust programming", 10);
        assert_eq!(results[0].rid, "c");
    }

    /// Query and document text are matched without regard to ASCII case.
    #[test]
    fn case_insensitive_matching() {
        let index = InMemoryBM25Index::new();
        index.index("a", "Rust Programming Language");
        let results = index.search("RUST", 10);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].rid, "a");
    }

    /// The index is usable through a `dyn BM25Index` trait object.
    #[test]
    fn dyn_dispatch() {
        let shared: Arc<dyn BM25Index> = Arc::new(InMemoryBM25Index::new());
        shared.index("a", "rust");
        assert_eq!(shared.search("rust", 10).len(), 1);
    }

    /// Equal scores are ordered by ascending document id.
    #[test]
    fn deterministic_tie_break_by_rid() {
        let index = InMemoryBM25Index::new();
        for rid in ["zzz", "aaa", "mmm"] {
            index.index(rid, "rust");
        }
        let results = index.search("rust", 10);
        let ids: Vec<&str> = results.iter().map(|h| h.rid.as_str()).collect();
        assert_eq!(ids, vec!["aaa", "mmm", "zzz"]);
    }
}