use std::collections::HashMap;
use serde::{Deserialize, Serialize};
/// Posting list for a single term: one `(doc_id, term_frequency)` pair for each
/// document in which the term occurs.
type DocFrequency = Vec<(String, u32)>;
/// Inverted index holding the statistics needed for BM25 ranking.
///
/// The whole structure is serializable with serde, which is how it is
/// persisted to and restored from disk.
#[derive(Debug, Serialize, Deserialize)]
pub struct Bm25Index {
    /// Maps every indexed term to the documents containing it, together with
    /// the number of occurrences in each of those documents.
    term_to_counts: HashMap<String, DocFrequency>,
    /// Length in tokens of every indexed document, keyed by document id.
    doc_lengths: HashMap<String, u32>,
    /// Mean document length in tokens; `0.0` while the index is empty.
    avg_dl: f32,
    /// Number of documents currently indexed.
    n_docs: u32,
}
impl Default for Bm25Index {
    /// Equivalent to [`Bm25Index::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl Bm25Index {
    /// Creates an empty index containing no documents.
    pub fn new() -> Self {
        Bm25Index {
            term_to_counts: HashMap::new(),
            doc_lengths: HashMap::new(),
            avg_dl: 0.0,
            n_docs: 0,
        }
    }

    /// Loads an index from the JSON file at `path`.
    ///
    /// If the file cannot be read (for instance because it does not exist
    /// yet) an empty index is returned, so a first run starts from scratch.
    ///
    /// # Panics
    ///
    /// Panics if the file could be read but does not contain a valid
    /// serialized index.
    pub fn load(path: &str) -> Self {
        match std::fs::read_to_string(path) {
            Ok(contents) => serde_json::from_str(&contents).unwrap_or_else(|err| {
                panic!("failed to parse BM25 index at {}: {}", path, err)
            }),
            Err(_) => Self::new(),
        }
    }

    /// Serializes the index as JSON and writes it to `path`, replacing any
    /// existing file.
    ///
    /// # Panics
    ///
    /// Panics if serialization fails or the file cannot be written.
    pub fn save(&self, path: &str) {
        let contents = serde_json::to_string(self)
            .unwrap_or_else(|err| panic!("failed to serialize BM25 index: {}", err));
        std::fs::write(path, contents)
            .unwrap_or_else(|err| panic!("failed to write BM25 index to {}: {}", path, err));
    }

    /// Adds the document `doc_id`, consisting of `tokens`, to the index.
    ///
    /// Indexing an id that is already present replaces the previous version
    /// of that document instead of adding duplicate postings.
    ///
    /// The average document length is recomputed exactly from all stored
    /// lengths, which costs time proportional to the number of documents.
    pub fn index_record(&mut self, doc_id: &str, tokens: &[String]) {
        // The tokens of the previous version are unknown here, so every
        // posting list has to be scanned to drop its old entries.
        if self.doc_lengths.contains_key(doc_id) {
            self.purge_postings(doc_id);
        }

        let mut term_frequencies: HashMap<&str, u32> = HashMap::new();
        for token in tokens {
            *term_frequencies.entry(token.as_str()).or_insert(0) += 1;
        }
        for (term, term_frequency) in term_frequencies {
            self.term_to_counts
                .entry(term.to_string())
                .or_default()
                .push((doc_id.to_string(), term_frequency));
        }

        // Saturate rather than silently wrap for absurdly long documents.
        let length = u32::try_from(tokens.len()).unwrap_or(u32::MAX);
        self.doc_lengths.insert(doc_id.to_string(), length);
        self.refresh_stats();
    }

    /// Removes the document `doc_id` from the index.
    ///
    /// `tokens` must be the tokens the document was indexed with; only the
    /// posting lists of those terms are visited. Postings under terms that
    /// are not listed stay in the index: `score` will not return the removed
    /// document, but those leftovers still count towards the document
    /// frequency of their terms.
    pub fn remove_record(&mut self, doc_id: &str, tokens: &[String]) {
        // Each distinct term needs to be cleaned only once, however often it
        // is repeated in `tokens`.
        let mut visited: std::collections::HashSet<&str> = std::collections::HashSet::new();
        for token in tokens {
            let term = token.as_str();
            if !visited.insert(term) {
                continue;
            }
            if let Some(postings) = self.term_to_counts.get_mut(term) {
                postings.retain(|(id, _)| id != doc_id);
                if postings.is_empty() {
                    self.term_to_counts.remove(term);
                }
            }
        }

        if self.doc_lengths.remove(doc_id).is_some() {
            self.refresh_stats();
        }
    }

    /// Ranks the indexed documents against `query_tokens` using BM25.
    ///
    /// Returns `(doc_id, score)` pairs for every document containing at least
    /// one query term, ordered by descending score; ties are broken by
    /// ascending document id so the order is deterministic. Query terms that
    /// are not in the index contribute nothing, and a term repeated in the
    /// query contributes once per repetition.
    pub fn score(&self, query_tokens: &[String]) -> Vec<(String, f32)> {
        const K: f32 = 1.5;
        const B: f32 = 0.75;

        let n_docs = self.n_docs as f32;
        let mut scores: HashMap<&str, f32> = HashMap::new();

        for token in query_tokens {
            let Some(postings) = self.term_to_counts.get(token) else { continue };

            let documents_count = postings.len() as f32;
            let idf = ((n_docs - documents_count + 0.5) / (documents_count + 0.5) + 1.0).ln();

            for (doc_id, term_frequency) in postings {
                // A posting without a recorded length belongs to a document
                // that is no longer indexed; it must not be ranked.
                let Some(&doc_len) = self.doc_lengths.get(doc_id) else { continue };

                // With a zero average there is nothing to normalise against;
                // a neutral factor avoids a 0/0 division producing NaN.
                let length_norm = if self.avg_dl > 0.0 {
                    1.0 - B + B * doc_len as f32 / self.avg_dl
                } else {
                    1.0
                };

                let term_frequency = *term_frequency as f32;
                let tf_norm = term_frequency * (K + 1.0) / (term_frequency + K * length_norm);
                *scores.entry(doc_id.as_str()).or_insert(0.0) += idf * tf_norm;
            }
        }

        let mut results: Vec<(String, f32)> = scores
            .into_iter()
            .map(|(doc_id, score)| (doc_id.to_string(), score))
            .collect();
        // `total_cmp` is a total order over floats, so sorting cannot panic.
        results.sort_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
        results
    }

    /// Drops every posting that refers to `doc_id`, removing posting lists
    /// that become empty.
    fn purge_postings(&mut self, doc_id: &str) {
        self.term_to_counts.retain(|_, postings| {
            postings.retain(|(id, _)| id != doc_id);
            !postings.is_empty()
        });
    }

    /// Recomputes `n_docs` and `avg_dl` from the stored document lengths.
    fn refresh_stats(&mut self) {
        let count = self.doc_lengths.len();
        self.n_docs = u32::try_from(count).unwrap_or(u32::MAX);
        self.avg_dl = if count == 0 {
            0.0
        } else {
            // Sum in u64 and divide in f64 so large corpora neither overflow
            // nor lose precision before the final narrowing to f32.
            let total: u64 = self.doc_lengths.values().map(|&len| u64::from(len)).sum();
            (total as f64 / count as f64) as f32
        };
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned token list from string literals.
    fn toks(words: &[&str]) -> Vec<String> {
        words.iter().map(|s| s.to_string()).collect()
    }

    #[test]
    fn matching_doc_is_returned() {
        let mut idx = Bm25Index::new();
        idx.index_record("doc1", &toks(&["rust", "fast"]));
        idx.index_record("doc2", &toks(&["python", "slow"]));
        let scores = idx.score(&toks(&["rust"]));
        assert_eq!(scores.len(), 1);
        assert_eq!(scores[0].0, "doc1");
    }

    #[test]
    fn results_are_sorted_by_score_descending() {
        let mut idx = Bm25Index::new();
        idx.index_record("doc1", &toks(&["rust"]));
        idx.index_record("doc2", &toks(&["rust", "rust", "rust"]));
        let scores = idx.score(&toks(&["rust"]));
        assert_eq!(scores.len(), 2);
        assert_eq!(scores[0].0, "doc2");
        assert!(scores[0].1 > scores[1].1);
    }

    #[test]
    fn no_match_returns_empty() {
        let mut idx = Bm25Index::new();
        idx.index_record("doc1", &toks(&["rust"]));
        assert!(idx.score(&toks(&["python"])).is_empty());
    }

    #[test]
    fn query_on_empty_index_returns_empty() {
        let idx = Bm25Index::new();
        assert!(idx.score(&toks(&["rust"])).is_empty());
    }

    #[test]
    fn remove_record_excludes_it_from_results() {
        let mut idx = Bm25Index::new();
        idx.index_record("doc1", &toks(&["rust"]));
        idx.remove_record("doc1", &toks(&["rust"]));
        assert!(idx.score(&toks(&["rust"])).is_empty());
    }

    #[test]
    fn remove_one_of_two_docs_leaves_the_other() {
        let mut idx = Bm25Index::new();
        idx.index_record("doc1", &toks(&["rust"]));
        idx.index_record("doc2", &toks(&["rust"]));
        idx.remove_record("doc1", &toks(&["rust"]));
        let scores = idx.score(&toks(&["rust"]));
        assert_eq!(scores.len(), 1);
        assert_eq!(scores[0].0, "doc2");
    }

    #[test]
    fn rare_term_scores_higher_than_common_term() {
        let mut idx = Bm25Index::new();
        idx.index_record("doc1", &toks(&["common", "rare"]));
        idx.index_record("doc2", &toks(&["common"]));
        idx.index_record("doc3", &toks(&["common"]));

        let rare = idx.score(&toks(&["rare"]));
        assert_eq!(rare.len(), 1);
        assert_eq!(rare[0].0, "doc1");
        assert!(rare[0].1 > 0.0);

        // Both terms occur exactly once in doc1, so the term-frequency part
        // of the score is the same and only the IDF differs.
        let common = idx.score(&toks(&["common"]));
        let common_for_doc1 = common
            .iter()
            .find(|entry| entry.0 == "doc1")
            .map(|entry| entry.1)
            .expect("doc1 contains the common term");
        assert!(rare[0].1 > common_for_doc1);
    }
}