use std::collections::{HashMap, HashSet};

#[cfg(all(not(target_arch = "wasm32"), feature = "parallel"))]
use rayon::prelude::*;
/// Tuning parameters for the BM25 ranking function.
///
/// `k1` controls term-frequency saturation; `b` controls how strongly scores
/// are normalized by document length (0 = no normalization, 1 = full).
#[derive(Debug, Clone, Copy)]
pub struct Bm25Config {
/// Term-frequency saturation parameter (default 1.2).
pub k1: f32,
/// Document-length normalization strength in `[0, 1]` (default 0.75).
pub b: f32,
}
impl Default for Bm25Config {
fn default() -> Self {
Self { k1: 1.2, b: 0.75 }
}
}
/// One indexed document: its id, per-term counts, and token length.
#[derive(Debug, Clone)]
struct Document {
/// Caller-supplied document identifier.
id: String,
/// Number of occurrences of each term in this document.
term_freqs: HashMap<String, u32>,
/// Total number of tokens the document was indexed with.
length: usize,
}
/// An in-memory BM25 inverted index supporting add, remove, and search.
#[derive(Debug, Clone)]
pub struct Bm25Index {
/// All indexed documents; order is not meaningful (removal uses swap_remove).
documents: Vec<Document>,
/// Maps document id -> position in `documents`.
doc_index: HashMap<String, usize>,
/// Number of documents containing each term (document frequency).
doc_freqs: HashMap<String, u32>,
/// Sum of all document lengths, used to derive the average document length.
total_length: usize,
/// BM25 parameters applied at search time.
config: Bm25Config,
}
impl Default for Bm25Index {
fn default() -> Self {
Self::new()
}
}
impl Bm25Index {
pub fn new() -> Self {
Self::with_config(Bm25Config::default())
}
pub fn with_config(config: Bm25Config) -> Self {
Self {
documents: Vec::new(),
doc_index: HashMap::new(),
doc_freqs: HashMap::new(),
total_length: 0,
config,
}
}
pub fn add_document<T: AsRef<str>>(&mut self, id: &str, tokens: &[T]) {
if let Some(&idx) = self.doc_index.get(id) {
self.remove_document_at(idx);
}
let mut term_freqs: HashMap<String, u32> = HashMap::new();
for token in tokens {
*term_freqs.entry(token.as_ref().to_string()).or_insert(0) += 1;
}
for term in term_freqs.keys() {
*self.doc_freqs.entry(term.clone()).or_insert(0) += 1;
}
let doc = Document {
id: id.to_string(),
term_freqs,
length: tokens.len(),
};
let idx = self.documents.len();
self.documents.push(doc);
self.doc_index.insert(id.to_string(), idx);
self.total_length += tokens.len();
}
pub fn remove_document(&mut self, id: &str) {
if let Some(&idx) = self.doc_index.get(id) {
self.remove_document_at(idx);
}
}
fn remove_document_at(&mut self, idx: usize) {
let doc = &self.documents[idx];
for term in doc.term_freqs.keys() {
if let Some(df) = self.doc_freqs.get_mut(term) {
*df = df.saturating_sub(1);
}
}
self.total_length = self.total_length.saturating_sub(doc.length);
let id = doc.id.clone();
self.doc_index.remove(&id);
self.documents.swap_remove(idx);
if idx < self.documents.len() {
let swapped_id = &self.documents[idx].id;
self.doc_index.insert(swapped_id.clone(), idx);
}
}
pub fn search<T: AsRef<str>>(&self, query_tokens: &[T], top_k: usize) -> Vec<(String, f32)> {
if self.documents.is_empty() || query_tokens.is_empty() {
return Vec::new();
}
let n = self.documents.len() as f32;
let avgdl = self.total_length as f32 / n;
let mut query_terms = Vec::new();
let mut idf_values = Vec::new();
let mut seen_terms = HashMap::new();
for token in query_tokens {
seen_terms.entry(token.as_ref()).or_insert(());
}
for term in seen_terms.keys() {
let df = self.doc_freqs.get(*term).copied().unwrap_or(0) as f32;
let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
if idf > 0.0 {
query_terms.push(*term);
idf_values.push(idf);
}
}
if query_terms.is_empty() {
return Vec::new();
}
let k1 = self.config.k1;
let b = self.config.b;
let k1_plus_1 = k1 + 1.0;
let c1 = k1 * (1.0 - b);
let c2 = k1 * b / avgdl;
#[cfg(all(not(target_arch = "wasm32"), feature = "parallel"))]
let mut scores: Vec<(usize, f32)> = self
.documents
.par_iter()
.enumerate()
.map(|(idx, doc)| {
let score = self.score_document(doc, &query_terms, &idf_values, k1_plus_1, c1, c2);
(idx, score)
})
.filter(|(_, score)| *score > 0.0)
.collect();
#[cfg(any(target_arch = "wasm32", not(feature = "parallel")))]
let mut scores: Vec<(usize, f32)> = self
.documents
.iter()
.enumerate()
.map(|(idx, doc)| {
let score = self.score_document(doc, &query_terms, &idf_values, k1_plus_1, c1, c2);
(idx, score)
})
.filter(|(_, score)| *score > 0.0)
.collect();
scores.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
scores.truncate(top_k);
scores
.into_iter()
.map(|(idx, score)| (self.documents[idx].id.clone(), score))
.collect()
}
fn score_document(
&self,
doc: &Document,
query_terms: &[&str],
idf_values: &[f32],
k1_plus_1: f32,
c1: f32,
c2: f32,
) -> f32 {
let mut score = 0.0;
let doc_len = doc.length as f32;
for (i, term) in query_terms.iter().enumerate() {
let tf = match doc.term_freqs.get(*term) {
Some(&tf) => tf as f32,
None => continue,
};
let idf = idf_values[i];
let numerator = tf * k1_plus_1;
let denominator = tf + c1 + c2 * doc_len;
score += idf * numerator / denominator;
}
score
}
pub fn clear(&mut self) {
self.documents.clear();
self.doc_index.clear();
self.doc_freqs.clear();
self.total_length = 0;
}
pub fn len(&self) -> usize {
self.documents.len()
}
pub fn is_empty(&self) -> bool {
self.documents.is_empty()
}
pub fn avg_doc_length(&self) -> f32 {
if self.documents.is_empty() {
0.0
} else {
self.total_length as f32 / self.documents.len() as f32
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an index from `(id, tokens)` pairs.
    fn build(docs: &[(&str, &[&str])]) -> Bm25Index {
        let mut index = Bm25Index::new();
        for &(id, tokens) in docs {
            index.add_document(id, tokens);
        }
        index
    }

    #[test]
    fn test_add_document() {
        let index = build(&[("doc1", &["hello", "world"])]);
        assert_eq!(index.len(), 1);
    }

    #[test]
    fn test_search_exact_match() {
        let index = build(&[("doc1", &["hello", "world"]), ("doc2", &["hello", "rust"])]);
        let results = index.search(&["hello", "world"], 10);
        // The document containing both query terms must rank first.
        assert_eq!(results[0].0, "doc1");
    }

    #[test]
    fn test_search_partial_match() {
        let index = build(&[("doc1", &["hello", "world"]), ("doc2", &["goodbye", "world"])]);
        let results = index.search(&["hello"], 10);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, "doc1");
    }

    #[test]
    fn test_search_empty_index() {
        let index = Bm25Index::default();
        assert!(index.search(&["hello"], 10).is_empty());
    }

    #[test]
    fn test_search_empty_query() {
        let index = build(&[("doc1", &["hello", "world"])]);
        let no_tokens: &[&str] = &[];
        assert!(index.search(no_tokens, 10).is_empty());
    }

    #[test]
    fn test_remove_document() {
        let mut index = build(&[("doc1", &["hello", "world"]), ("doc2", &["hello", "rust"])]);
        index.remove_document("doc1");
        assert_eq!(index.len(), 1);
        // "world" only appeared in the removed document.
        assert!(index.search(&["world"], 10).is_empty());
    }

    #[test]
    fn test_replace_document() {
        let mut index = build(&[("doc1", &["hello", "world"])]);
        // Re-adding under the same id replaces the previous content.
        index.add_document("doc1", &["goodbye", "rust"]);
        assert_eq!(index.len(), 1);
        let results = index.search(&["rust"], 10);
        assert_eq!(results[0].0, "doc1");
    }

    #[test]
    fn test_top_k() {
        let index = build(&[
            ("doc1", &["hello", "world"]),
            ("doc2", &["hello", "rust"]),
            ("doc3", &["hello", "python"]),
        ]);
        assert_eq!(index.search(&["hello"], 2).len(), 2);
    }

    #[test]
    fn test_idf_rare_term_higher_score() {
        let index = build(&[
            ("doc1", &["rare", "common"]),
            ("doc2", &["common"]),
            ("doc3", &["common"]),
        ]);
        // The rare term carries more IDF weight, so doc1 wins.
        let results = index.search(&["rare", "common"], 10);
        assert_eq!(results[0].0, "doc1");
    }

    #[test]
    fn test_doc_length_normalization() {
        let index = build(&[
            ("short", &["hello"]),
            (
                "long",
                &["hello", "hello", "hello", "hello", "hello", "other", "words", "here"],
            ),
        ]);
        assert_eq!(index.search(&["hello"], 10).len(), 2);
    }

    #[test]
    fn test_clear() {
        let mut index = build(&[("doc1", &["hello", "world"])]);
        index.clear();
        assert!(index.is_empty());
        assert!(index.search(&["hello"], 10).is_empty());
    }

    #[test]
    fn test_avg_doc_length() {
        let index = build(&[("doc1", &["a", "b", "c"]), ("doc2", &["x", "y"])]);
        assert_eq!(index.avg_doc_length(), 2.5);
    }

    #[test]
    fn test_custom_config() {
        let index = Bm25Index::with_config(Bm25Config { k1: 2.0, b: 0.5 });
        assert_eq!(index.config.k1, 2.0);
        assert_eq!(index.config.b, 0.5);
    }
}