use std::collections::HashMap;
use serde::{Deserialize, Serialize};
/// In-memory inverted index that ranks documents with a BM25-style score.
///
/// All fields are private; the index is populated and queried through the
/// methods of `BM25Index`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BM25Index {
    /// Posting lists: term -> (document id -> occurrences of the term in that document).
    inverted_index: HashMap<String, HashMap<String, u32>>,
    /// Document id -> number of tokens the document produced when it was indexed.
    doc_lengths: HashMap<String, u32>,
    /// Term -> number of indexed documents that contain the term.
    term_dfs: HashMap<String, u32>,
    /// Number of documents currently held by the index.
    total_docs: u64,
    /// Multiplier applied to the length-normalisation part of the score denominator.
    k1: f32,
    /// Weight of the document-length / average-length ratio within that normalisation.
    b: f32,
}
impl BM25Index {
    /// Creates an empty index with `k1 = 1.5` and `b = 0.75`.
    #[inline]
    #[must_use]
    pub fn new() -> Self {
        Self {
            inverted_index: HashMap::new(),
            doc_lengths: HashMap::new(),
            term_dfs: HashMap::new(),
            total_docs: 0,
            k1: 1.5,
            b: 0.75,
        }
    }

    /// Splits `text` into lowercase tokens.
    ///
    /// The text is lowercased, split on whitespace and ASCII punctuation, and
    /// each piece is trimmed of leading/trailing non-alphanumeric characters.
    /// Pieces whose UTF-8 encoding is shorter than two bytes are discarded, so
    /// single ASCII letters and digits never become tokens.
    #[inline]
    pub fn tokenize(text: &str) -> Vec<String> {
        text.to_lowercase()
            .split(|c: char| c.is_whitespace() || c.is_ascii_punctuation())
            .map(|piece| piece.trim_matches(|c: char| !c.is_alphanumeric()))
            // `len() > 1` already rules out the empty string; filtering before
            // the allocation avoids building Strings that are thrown away.
            .filter(|piece| piece.len() > 1)
            .map(String::from)
            .collect()
    }

    /// Indexes `text` under `doc_id`.
    ///
    /// If `doc_id` is already present, the previous document is removed first,
    /// so the call behaves as an upsert.
    pub fn insert_doc(&mut self, doc_id: &str, text: &str) {
        // Drop any earlier version so counts are not double-booked.
        self.delete_doc(doc_id);

        let tokens = Self::tokenize(text);
        // Saturate instead of silently truncating on absurdly large inputs.
        let doc_len = tokens.len().min(u32::MAX as usize) as u32;

        // Tokens are moved into the map, so no per-token clone is needed.
        let mut term_freqs: HashMap<String, u32> = HashMap::new();
        for token in tokens {
            let tf = term_freqs.entry(token).or_insert(0);
            *tf = tf.saturating_add(1);
        }

        for (term, tf) in term_freqs {
            let df = self.term_dfs.entry(term.clone()).or_insert(0);
            *df = df.saturating_add(1);
            self.inverted_index
                .entry(term)
                .or_default()
                .insert(doc_id.to_string(), tf);
        }

        self.doc_lengths.insert(doc_id.to_string(), doc_len);
        self.total_docs += 1;
    }

    /// Removes the document stored under `doc_id`; does nothing if it is unknown.
    ///
    /// Posting lists that become empty are dropped together with their
    /// document-frequency entry, so deleted vocabulary does not accumulate in
    /// the index (or in its serialized form).
    pub fn delete_doc(&mut self, doc_id: &str) {
        if self.doc_lengths.remove(doc_id).is_none() {
            return;
        }

        // Borrow the sibling field separately so the closure can update it
        // while `inverted_index` is being traversed.
        let term_dfs = &mut self.term_dfs;
        self.inverted_index.retain(|term, postings| {
            let removed = postings.remove(doc_id).is_some();
            if postings.is_empty() {
                // No document references this term any more.
                term_dfs.remove(term);
                return false;
            }
            if removed {
                if let Some(df) = term_dfs.get_mut(term) {
                    *df = df.saturating_sub(1);
                }
            }
            true
        });

        self.total_docs = self.total_docs.saturating_sub(1);
    }

    /// Returns up to `limit` `(doc_id, score)` pairs for `query`, best first.
    ///
    /// Every document containing at least one query token is a candidate; its
    /// score is the sum of the per-token scores (a token repeated in the query
    /// is counted once per repetition). The idf factor is clamped at zero, so a
    /// term that occurs in at least half of the documents contributes `0.0`;
    /// such documents are still returned. Equal scores are ordered by ascending
    /// document id so the output is deterministic.
    ///
    /// Returns an empty vector when the index is empty, when `query` yields no
    /// tokens, or when `limit` is zero.
    pub fn search(&self, query: &str, limit: usize) -> Vec<(String, f32)> {
        if limit == 0 || self.total_docs == 0 {
            return Vec::new();
        }
        let query_terms = Self::tokenize(query);
        if query_terms.is_empty() {
            return Vec::new();
        }

        // Sum in u64: a u32 accumulator can overflow on large corpora.
        let total_len: u64 = self.doc_lengths.values().map(|&len| u64::from(len)).sum();
        let avgdl = total_len as f32 / self.total_docs as f32;
        let n_docs = self.total_docs as f32;

        // Keys borrow from the index; only the returned ids are cloned.
        let mut doc_scores: HashMap<&str, f32> = HashMap::new();
        for term in &query_terms {
            let postings = match self.inverted_index.get(term) {
                Some(postings) => postings,
                None => continue,
            };
            let df = self.term_dfs.get(term).copied().unwrap_or(0);
            let idf = if df > 0 {
                let df = df as f32;
                ((n_docs - df + 0.5) / (df + 0.5)).ln().max(0.0)
            } else {
                0.0
            };
            for (doc_id, &tf) in postings {
                if let Some(&doc_len) = self.doc_lengths.get(doc_id) {
                    let score =
                        self.calculate_bm25_score(tf, doc_len, df, self.total_docs, avgdl, idf);
                    *doc_scores.entry(doc_id.as_str()).or_insert(0.0) += score;
                }
            }
        }

        let mut ranked: Vec<(&str, f32)> = doc_scores.into_iter().collect();
        ranked.sort_by(|a, b| {
            b.1.partial_cmp(&a.1)
                .unwrap_or(std::cmp::Ordering::Equal)
                // HashMap iteration order is arbitrary; break ties by id.
                .then_with(|| a.0.cmp(b.0))
        });
        ranked.truncate(limit);
        ranked
            .into_iter()
            .map(|(doc_id, score)| (doc_id.to_string(), score))
            .collect()
    }

    /// Computes one term's contribution to a document's score:
    /// `idf * tf * (k1 + 1) / (tf + k1 * (1 - b + b * doc_len / avgdl))`.
    ///
    /// `_df` and `_total_docs` are not used (the caller passes the already
    /// computed `idf`); they are kept so the signature stays unchanged. When
    /// `avgdl` is not positive the length ratio is taken as `1.0` to avoid a
    /// division by zero.
    fn calculate_bm25_score(
        &self,
        tf: u32,
        doc_len: u32,
        _df: u32,
        _total_docs: u64,
        avgdl: f32,
        idf: f32,
    ) -> f32 {
        let tf = tf as f32;
        let length_ratio = if avgdl > 0.0 {
            doc_len as f32 / avgdl
        } else {
            1.0
        };
        let numerator = tf * (self.k1 + 1.0);
        let denominator = tf + self.k1 * (1.0 - self.b + self.b * length_ratio);
        idf * (numerator / denominator)
    }

    /// Number of documents currently indexed.
    #[inline]
    #[must_use]
    pub fn len(&self) -> usize {
        self.doc_lengths.len()
    }

    /// Returns `true` when no document is indexed.
    #[inline]
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.doc_lengths.is_empty()
    }
}
impl Default for BM25Index {
fn default() -> Self {
Self::new()
}
}