use std::collections::HashMap;
use parking_lot::RwLock;
use super::tokenizer::tokenize;
/// Tunable constants of the BM25 ranking formula.
#[derive(Clone, Copy, Debug)]
pub struct Bm25Params {
    /// Term-frequency saturation constant (`k1` in the BM25 formula).
    pub k1: f64,
    /// Document-length normalization weight (`b` in the BM25 formula).
    pub b: f64,
}
impl Default for Bm25Params {
fn default() -> Self {
Self { k1: 1.2, b: 0.75 }
}
}
/// One ranked result: a document id paired with its relevance score.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Bm25Score {
    /// Identifier of the scored document.
    pub doc_id: u64,
    /// Relevance score computed for the document.
    pub score: f64,
}
/// Mutable contents of the index; starts out empty via `Default`.
#[derive(Default)]
struct IndexState {
    /// Term -> list of `(doc_id, occurrences of the term in that document)`.
    postings: HashMap<String, Vec<(u64, u32)>>,
    /// Document id -> number of tokens in that document.
    doc_len: HashMap<u64, u32>,
    /// Running sum of the token counts of all indexed documents.
    total_tokens: u64,
}
/// BM25 index whose contents sit behind a reader-writer lock, so documents
/// can be added, removed and scored through a shared reference.
pub struct Bm25Index {
    /// Scoring parameters supplied at construction.
    params: Bm25Params,
    /// Postings and per-document statistics, guarded by the lock.
    state: RwLock<IndexState>,
}
impl Bm25Index {
#[must_use]
pub fn new() -> Self {
Self::with_params(Bm25Params::default())
}
#[must_use]
pub fn with_params(params: Bm25Params) -> Self {
Self {
params,
state: RwLock::new(IndexState::default()),
}
}
pub fn add_document(&self, doc_id: u64, text: &str) {
let tokens = tokenize(text);
let mut tf: HashMap<String, u32> = HashMap::new();
for tok in &tokens {
*tf.entry(tok.clone()).or_insert(0) += 1;
}
let len = tokens.len() as u32;
let mut s = self.state.write();
for (term, count) in tf {
s.postings.entry(term).or_default().push((doc_id, count));
}
s.doc_len.insert(doc_id, len);
s.total_tokens += u64::from(len);
}
pub fn remove_document(&self, doc_id: u64) -> bool {
let mut s = self.state.write();
let Some(len) = s.doc_len.remove(&doc_id) else {
return false;
};
s.total_tokens = s.total_tokens.saturating_sub(u64::from(len));
for postings in s.postings.values_mut() {
postings.retain(|(d, _)| *d != doc_id);
}
s.postings.retain(|_, p| !p.is_empty());
true
}
#[must_use]
pub fn doc_count(&self) -> usize {
self.state.read().doc_len.len()
}
#[must_use]
pub fn average_doc_length(&self) -> f64 {
let s = self.state.read();
if s.doc_len.is_empty() {
0.0
} else {
s.total_tokens as f64 / s.doc_len.len() as f64
}
}
pub fn score(&self, query: &str, limit: Option<usize>) -> Vec<Bm25Score> {
let q_tokens = tokenize(query);
if q_tokens.is_empty() {
return Vec::new();
}
let s = self.state.read();
let n = s.doc_len.len() as f64;
if n == 0.0 {
return Vec::new();
}
let avgdl = s.total_tokens as f64 / n;
let Bm25Params { k1, b } = self.params;
let mut scores: HashMap<u64, f64> = HashMap::new();
let mut seen: std::collections::HashSet<&str> = std::collections::HashSet::with_capacity(q_tokens.len());
for q in &q_tokens {
if !seen.insert(q.as_str()) {
continue;
}
let Some(postings) = s.postings.get(q) else {
continue;
};
let df = postings.len() as f64;
let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
for &(doc_id, tf) in postings {
let dl = f64::from(*s.doc_len.get(&doc_id).unwrap_or(&0));
let tf_f = f64::from(tf);
let denom = tf_f + k1 * (1.0 - b + b * (dl / avgdl).max(0.0));
let contribution = idf * (tf_f * (k1 + 1.0)) / denom.max(f64::MIN_POSITIVE);
*scores.entry(doc_id).or_insert(0.0) += contribution;
}
}
let mut out: Vec<Bm25Score> = scores
.into_iter()
.map(|(doc_id, score)| Bm25Score { doc_id, score })
.collect();
out.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
.then(a.doc_id.cmp(&b.doc_id))
});
if let Some(k) = limit {
out.truncate(k);
}
out
}
pub fn matches(&self, doc_id: u64, query: &str) -> bool {
let q_tokens = tokenize(query);
if q_tokens.is_empty() {
return false;
}
let s = self.state.read();
for q in &q_tokens {
if let Some(postings) = s.postings.get(q) {
if postings.iter().any(|(d, _)| *d == doc_id) {
return true;
}
}
}
false
}
}
impl Default for Bm25Index {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the four-document fixture shared by the tests in this module.
    fn corpus() -> Bm25Index {
        let index = Bm25Index::new();
        let docs = [
            (1, "the quick brown fox jumps over the lazy dog"),
            (2, "a fast brown fox leaps over a sleepy dog"),
            (3, "the lazy cat sat on the mat"),
            (4, "stock market closes higher on tuesday"),
        ];
        for (id, text) in docs {
            index.add_document(id, text);
        }
        index
    }

    #[test]
    fn empty_query_returns_empty() {
        let index = corpus();
        for query in ["", " "] {
            assert!(index.score(query, None).is_empty());
        }
    }

    #[test]
    fn single_term_picks_relevant_doc() {
        let hits = corpus().score("market", None);
        assert_eq!(hits.len(), 1);
        let only = hits[0];
        assert_eq!(only.doc_id, 4);
        assert!(only.score > 0.0);
    }

    #[test]
    fn multi_term_ranks_dog_docs_first() {
        let hits = corpus().score("brown fox dog", None);
        assert!(hits.len() >= 2);
        // Documents 1 and 2 must occupy the top two slots, in either order.
        let leaders: Vec<u64> = hits.iter().take(2).map(|h| h.doc_id).collect();
        for expected in [1, 2] {
            assert!(leaders.contains(&expected));
        }
        // Documents sharing no query term must not appear at all.
        for absent in [3, 4] {
            assert!(hits.iter().all(|h| h.doc_id != absent));
        }
    }

    #[test]
    fn limit_caps_result_count() {
        let hits = corpus().score("the dog cat fox market", Some(2));
        assert!(hits.len() <= 2);
    }

    #[test]
    fn matches_returns_true_for_known_doc() {
        let index = corpus();
        assert!(index.matches(1, "fox"));
        assert!(!index.matches(1, "elephant"));
        assert!(!index.matches(99, "fox"));
    }

    #[test]
    fn doc_count_and_average_length_are_correct() {
        let index = corpus();
        assert_eq!(index.doc_count(), 4);
        let deviation = (index.average_doc_length() - 7.75).abs();
        assert!(deviation < 0.01);
    }

    #[test]
    fn remove_document_drops_postings_and_score() {
        let index = corpus();
        assert!(index.remove_document(4));
        assert_eq!(index.doc_count(), 3);
        let hits = index.score("market", None);
        assert!(hits.is_empty(), "doc 4 should be gone");
        // A second removal of the same id reports that nothing was there.
        assert!(!index.remove_document(4));
    }

    #[test]
    fn duplicate_query_tokens_dont_double_count() {
        let index = corpus();
        let once = index.score("fox", None);
        let thrice = index.score("fox fox fox", None);
        assert_eq!(once.len(), thrice.len());
        for (lhs, rhs) in once.iter().zip(thrice.iter()) {
            assert_eq!(lhs.doc_id, rhs.doc_id);
            assert!((lhs.score - rhs.score).abs() < 1e-9);
        }
    }

    #[test]
    fn idf_is_non_negative_for_common_term() {
        let hits = corpus().score("the", None);
        assert!(hits.iter().all(|h| h.score >= 0.0));
    }
}