use super::algorithm::{compute_idf, BM25Params};
use crate::vector_stores::Document;
use std::collections::HashMap;
/// In-memory inverted index holding the per-document and per-term statistics used for
/// BM25 scoring.
///
/// Documents are identified by insertion order: the first document added has id `0`, the
/// next `1`, and so on. Callers supply each document's already-tokenized terms; the index
/// does no tokenization of its own.
pub struct BM25Index {
    /// Stored documents, indexed by document id.
    documents: Vec<Document>,
    /// term -> (document id -> number of occurrences of the term in that document).
    term_doc_freqs: HashMap<String, HashMap<usize, usize>>,
    /// Per-document term -> occurrence count, indexed by document id.
    doc_term_freqs: Vec<HashMap<String, usize>>,
    /// Number of terms (repeats included) in each document, indexed by document id.
    doc_lengths: Vec<usize>,
    /// Running sum of `doc_lengths`, so the average can be refreshed in O(1) per insert
    /// instead of re-summing every document.
    total_length: usize,
    /// Average document length in terms; `0.0` while the index is empty.
    avgdl: f64,
    /// Number of documents added so far (also the id the next document will receive).
    n_docs: usize,
    /// Memoized IDF values; cleared whenever the corpus changes.
    idf_cache: HashMap<String, f64>,
    /// BM25 parameters this index was built with, exposed through [`BM25Index::params`].
    params: BM25Params,
}

impl BM25Index {
    /// Creates an empty index using `BM25Params::default()`.
    pub fn new() -> Self {
        Self::with_params(BM25Params::default())
    }

    /// Creates an empty index that carries the given BM25 parameters.
    pub fn with_params(params: BM25Params) -> Self {
        Self {
            documents: Vec::new(),
            term_doc_freqs: HashMap::new(),
            doc_term_freqs: Vec::new(),
            doc_lengths: Vec::new(),
            total_length: 0,
            avgdl: 0.0,
            n_docs: 0,
            idf_cache: HashMap::new(),
            params,
        }
    }

    /// Adds `document` with its pre-tokenized `terms` and assigns it the next document id.
    ///
    /// The document length is `terms.len()`, so repeated terms count every time. The
    /// average document length is refreshed and the IDF cache is invalidated.
    pub fn add_document(&mut self, document: Document, terms: Vec<String>) {
        let doc_id = self.n_docs;
        let doc_length = terms.len();

        // Consume `terms` so each term string is moved into the map, not cloned per occurrence.
        let mut term_freq: HashMap<String, usize> = HashMap::new();
        for term in terms {
            *term_freq.entry(term).or_insert(0) += 1;
        }

        for (term, &freq) in &term_freq {
            self.term_doc_freqs
                .entry(term.clone())
                .or_default()
                .insert(doc_id, freq);
        }

        self.documents.push(document);
        self.doc_term_freqs.push(term_freq);
        self.doc_lengths.push(doc_length);
        self.total_length += doc_length;
        self.n_docs += 1;
        self.update_avgdl();
        // Every cached IDF depends on `n_docs`, so all of them are now stale.
        self.idf_cache.clear();
    }

    /// Adds each document paired with its term list, in order.
    ///
    /// If `documents` and `terms_list` have different lengths, the whole batch is silently
    /// ignored and nothing is indexed. Callers that need to detect this must compare the
    /// lengths themselves.
    pub fn add_documents(&mut self, documents: Vec<Document>, terms_list: Vec<Vec<String>>) {
        if documents.len() != terms_list.len() {
            return;
        }
        let additional = documents.len();
        self.documents.reserve(additional);
        self.doc_term_freqs.reserve(additional);
        self.doc_lengths.reserve(additional);
        for (document, terms) in documents.into_iter().zip(terms_list) {
            self.add_document(document, terms);
        }
    }

    /// Recomputes `avgdl` from the running total length and the document count.
    fn update_avgdl(&mut self) {
        self.avgdl = if self.n_docs == 0 {
            0.0
        } else {
            self.total_length as f64 / self.n_docs as f64
        };
    }

    /// Returns the IDF of `term`.
    ///
    /// The IDF is computed as `compute_idf(df, n_docs)`, where `df` is the number of
    /// indexed documents containing `term` (`0` for unseen terms). The result is memoized
    /// until the next insertion or [`BM25Index::clear`].
    pub fn compute_idf_for_term(&mut self, term: &str) -> f64 {
        if let Some(&idf) = self.idf_cache.get(term) {
            return idf;
        }
        let doc_freq = self.term_doc_freqs.get(term).map_or(0, HashMap::len);
        let idf = compute_idf(doc_freq, self.n_docs);
        self.idf_cache.insert(term.to_string(), idf);
        idf
    }

    /// Returns a map from each term in `terms` to its IDF. Duplicate terms share one entry.
    pub fn compute_idf_for_terms(&mut self, terms: &[String]) -> HashMap<String, f64> {
        terms
            .iter()
            .map(|term| (term.clone(), self.compute_idf_for_term(term)))
            .collect()
    }

    /// Returns the document with id `doc_id`, or `None` if no such document exists.
    pub fn get_document(&self, doc_id: usize) -> Option<&Document> {
        self.documents.get(doc_id)
    }

    /// Returns all stored documents in id order.
    pub fn get_documents(&self) -> &[Document] {
        &self.documents
    }

    /// Returns the term -> count map for `doc_id`, or `None` if no such document exists.
    pub fn get_doc_term_freq(&self, doc_id: usize) -> Option<&HashMap<String, usize>> {
        self.doc_term_freqs.get(doc_id)
    }

    /// Returns the length in terms of document `doc_id`, or `0` if no such document exists.
    pub fn get_doc_length(&self, doc_id: usize) -> usize {
        self.doc_lengths.get(doc_id).copied().unwrap_or(0)
    }

    /// Average document length in terms (`0.0` when empty).
    pub fn avgdl(&self) -> f64 {
        self.avgdl
    }

    /// Number of indexed documents.
    pub fn n_docs(&self) -> usize {
        self.n_docs
    }

    /// BM25 parameters this index was created with.
    pub fn params(&self) -> &BM25Params {
        &self.params
    }

    /// Removes every document and statistic, keeping the configured parameters.
    pub fn clear(&mut self) {
        self.documents.clear();
        self.term_doc_freqs.clear();
        self.doc_term_freqs.clear();
        self.doc_lengths.clear();
        self.total_length = 0;
        self.avgdl = 0.0;
        self.n_docs = 0;
        self.idf_cache.clear();
    }
}
impl Default for BM25Index {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns borrowed term literals into the owned term list `add_document` consumes.
    fn owned(terms: &[&str]) -> Vec<String> {
        terms.iter().map(|t| t.to_string()).collect()
    }

    #[test]
    fn test_index_basic() {
        let mut idx = BM25Index::new();
        idx.add_document(
            Document::new("Rust programming"),
            owned(&["rust", "programming"]),
        );
        assert_eq!(idx.n_docs(), 1);
        assert_eq!(idx.get_doc_length(0), 2);
    }

    #[test]
    fn test_index_idf() {
        let mut idx = BM25Index::new();
        idx.add_document(
            Document::new("Rust programming language"),
            owned(&["rust", "programming", "language"]),
        );
        idx.add_document(
            Document::new("Python scripting language"),
            owned(&["python", "scripting", "language"]),
        );
        // "language" appears in both documents, "rust" in only one: rarer terms score higher.
        let common = idx.compute_idf_for_term("language");
        let rare = idx.compute_idf_for_term("rust");
        assert!(rare > common);
    }

    #[test]
    fn test_avgdl() {
        let mut idx = BM25Index::new();
        idx.add_document(Document::new("a"), owned(&["a"]));
        idx.add_document(Document::new("a b c"), owned(&["a", "b", "c"]));
        // Lengths 1 and 3 average to 2.
        assert_eq!(idx.avgdl(), 2.0);
    }
}