use crate::Result;
use std::collections::HashMap;
/// Identifier under which a [`Document`] is stored in, and returned from, the index.
pub type DocumentId = String;
/// A unit of text that can be added to a [`BM25Retriever`] index.
#[derive(Debug, Clone)]
pub struct Document {
/// Identifier the retriever stores this document under and reports in search results.
pub id: DocumentId,
/// Text that is tokenized and indexed; search results carry a copy of it.
pub content: String,
/// String key/value pairs kept alongside the document. The retriever in this file
/// stores them but never reads them during indexing or search.
pub metadata: HashMap<String, String>,
}
/// One ranked hit returned by [`BM25Retriever::search`].
#[derive(Debug, Clone)]
pub struct BM25Result {
/// Id of the matching document.
pub doc_id: DocumentId,
/// Sum of the per-term scores of the query tokens found in the document; results are
/// returned in descending order of this value.
pub score: f32,
/// Copy of the matching document's `content`.
pub content: String,
}
/// In-memory keyword index that ranks documents with a BM25-style scoring function.
pub struct BM25Retriever {
/// Scoring parameter `k1`: multiplies the length-normalization term in the denominator
/// of the per-term score and appears as `k1 + 1` in its numerator.
k1: f32,
/// Scoring parameter `b`: weight given to the ratio of a document's length to the
/// average length.
b: f32,
/// Indexed documents keyed by their id.
documents: HashMap<DocumentId, Document>,
/// term -> (document id -> occurrences of the term divided by the document's token count).
term_frequencies: HashMap<String, HashMap<DocumentId, f32>>,
/// term -> counter incremented once for every indexed document that contains the term.
document_frequencies: HashMap<String, usize>,
/// document id -> number of tokens the document produced when it was indexed.
document_lengths: HashMap<DocumentId, usize>,
/// Sum of `document_lengths` divided by `total_docs`; `0.0` while the index is empty.
avg_doc_length: f32,
/// Document count used as `N` in the IDF and average-length calculations.
total_docs: usize,
}
impl BM25Retriever {
    /// Creates an empty retriever with the default parameters `k1 = 1.2` and `b = 0.75`.
    pub fn new() -> Self {
        Self::with_parameters(1.2, 0.75)
    }

    /// Creates an empty retriever with explicit scoring parameters.
    ///
    /// `k1` scales the length-normalization term of the per-term score and `b` weights the
    /// document-length ratio inside it (`0.0` ignores document length entirely). The values
    /// are stored as given; no range validation is performed.
    pub fn with_parameters(k1: f32, b: f32) -> Self {
        Self {
            k1,
            b,
            documents: HashMap::new(),
            term_frequencies: HashMap::new(),
            document_frequencies: HashMap::new(),
            document_lengths: HashMap::new(),
            avg_doc_length: 0.0,
            total_docs: 0,
        }
    }

    /// Adds `document` to the index.
    ///
    /// If a document with the same id is already indexed it is replaced: its old postings
    /// and statistics are removed first, so the document count, document frequencies and
    /// average length never count the same id twice.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok(())`.
    pub fn index_document(&mut self, document: Document) -> Result<()> {
        self.insert_document(document);
        self.update_avg_doc_length();
        Ok(())
    }

    /// Adds every document in `documents` to the index, in order.
    ///
    /// Later entries replace earlier ones that share an id. The average document length is
    /// recomputed once after the whole batch instead of after every insert.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok(())`.
    pub fn index_documents(&mut self, documents: &[Document]) -> Result<()> {
        for document in documents {
            self.insert_document(document.clone());
        }
        self.update_avg_doc_length();
        Ok(())
    }

    /// Removes the document stored under `doc_id` from the index and returns it.
    ///
    /// Returns `None`, leaving the index untouched, when no such document is indexed.
    pub fn remove_document(&mut self, doc_id: &str) -> Option<Document> {
        let removed = self.evict_document(doc_id)?;
        self.update_avg_doc_length();
        Some(removed)
    }

    /// Returns up to `limit` documents matching `query`, best score first.
    ///
    /// The query is tokenized exactly like document content. Every query token occurrence
    /// that appears in the index adds its per-term score to each document containing it;
    /// documents containing none of the tokens are not returned. Documents with equal
    /// scores are ordered by ascending id so that results are reproducible.
    pub fn search(&self, query: &str, limit: usize) -> Vec<BM25Result> {
        if self.total_docs == 0 || limit == 0 {
            return Vec::new();
        }

        let query_tokens = self.tokenize(query);

        // Keys borrow the ids already stored in the postings, so accumulating scores does
        // not allocate a new string per posting.
        let mut doc_scores: HashMap<&DocumentId, f32> = HashMap::new();
        for token in &query_tokens {
            if let Some(postings) = self.term_frequencies.get(token) {
                let idf = self.calculate_idf(token);
                for (doc_id, &tf) in postings {
                    let doc_length = *self.document_lengths.get(doc_id).unwrap_or(&0);
                    *doc_scores.entry(doc_id).or_insert(0.0) +=
                        self.calculate_bm25_term_score(tf, doc_length, idf);
                }
            }
        }

        let mut ranked: Vec<(&DocumentId, &Document, f32)> = doc_scores
            .into_iter()
            .filter_map(|(doc_id, score)| {
                self.documents.get(doc_id).map(|doc| (doc_id, doc, score))
            })
            .collect();

        // Descending score; the id tie-break removes the dependence on hash-map iteration
        // order, which would otherwise make tied results (and the truncation) vary per run.
        ranked.sort_by(|a, b| {
            b.2.partial_cmp(&a.2)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| a.0.cmp(b.0))
        });
        ranked.truncate(limit);

        // Content is cloned only for the hits that survive truncation.
        ranked
            .into_iter()
            .map(|(doc_id, doc, score)| BM25Result {
                doc_id: doc_id.clone(),
                score,
                content: doc.content.clone(),
            })
            .collect()
    }

    /// Returns the indexed document stored under `doc_id`, if any.
    pub fn get_document(&self, doc_id: &str) -> Option<&Document> {
        self.documents.get(doc_id)
    }

    /// Returns the number of documents currently in the index.
    pub fn document_count(&self) -> usize {
        self.total_docs
    }

    /// Returns the number of distinct terms currently in the index.
    pub fn term_count(&self) -> usize {
        self.term_frequencies.len()
    }

    /// Indexes `document` without refreshing `avg_doc_length`; callers must do that.
    fn insert_document(&mut self, document: Document) {
        // Re-indexing an id must replace the old entry rather than stack on top of it.
        self.evict_document(&document.id);

        let tokens = self.tokenize(&document.content);
        let doc_length = tokens.len();

        let mut term_counts: HashMap<String, usize> = HashMap::new();
        for token in tokens {
            *term_counts.entry(token).or_insert(0) += 1;
        }

        // `term_counts` is empty when `doc_length` is 0, so the division below never sees a
        // zero denominator.
        for (term, count) in term_counts {
            // NOTE: the stored value is the relative frequency (count / document length),
            // not the raw count used by textbook BM25. Kept as-is so scores stay comparable
            // with earlier versions of this index.
            let relative_freq = count as f32 / doc_length as f32;
            *self.document_frequencies.entry(term.clone()).or_insert(0) += 1;
            self.term_frequencies
                .entry(term)
                .or_default()
                .insert(document.id.clone(), relative_freq);
        }

        self.document_lengths.insert(document.id.clone(), doc_length);
        self.documents.insert(document.id.clone(), document);
        self.total_docs += 1;
    }

    /// Removes a document and all of its postings without refreshing `avg_doc_length`.
    fn evict_document(&mut self, doc_id: &str) -> Option<Document> {
        let removed = self.documents.remove(doc_id)?;
        self.document_lengths.remove(doc_id);
        self.total_docs = self.total_docs.saturating_sub(1);

        // Tokenization is deterministic, so re-tokenizing the stored content yields exactly
        // the terms that were posted when the document was indexed.
        let mut terms = self.tokenize(&removed.content);
        terms.sort_unstable();
        terms.dedup();

        for term in &terms {
            let postings_empty = match self.term_frequencies.get_mut(term) {
                Some(postings) => {
                    postings.remove(doc_id);
                    postings.is_empty()
                }
                None => false,
            };
            if postings_empty {
                self.term_frequencies.remove(term);
            }

            let unreferenced = match self.document_frequencies.get_mut(term) {
                Some(doc_freq) => {
                    *doc_freq = doc_freq.saturating_sub(1);
                    *doc_freq == 0
                }
                None => false,
            };
            if unreferenced {
                self.document_frequencies.remove(term);
            }
        }

        Some(removed)
    }

    /// Returns `ln(N / df) + 1` for `term`, or `0.0` when no indexed document contains it.
    fn calculate_idf(&self, term: &str) -> f32 {
        let doc_freq = self.document_frequencies.get(term).unwrap_or(&0);
        if *doc_freq == 0 {
            return 0.0;
        }
        (self.total_docs as f32 / *doc_freq as f32).ln() + 1.0
    }

    /// Computes one term's contribution to a document's score:
    /// `idf * tf * (k1 + 1) / (tf + k1 * (1 - b + b * doc_length / avg_doc_length))`.
    fn calculate_bm25_term_score(&self, tf: f32, doc_length: usize, idf: f32) -> f32 {
        // With no average to compare against, treat the document as average-length instead
        // of dividing by zero.
        let length_ratio = if self.avg_doc_length > 0.0 {
            doc_length as f32 / self.avg_doc_length
        } else {
            1.0
        };
        let tf_component = (tf * (self.k1 + 1.0))
            / (tf + self.k1 * (1.0 - self.b + self.b * length_ratio));
        idf * tf_component
    }

    /// Recomputes `avg_doc_length` from the stored lengths; `0.0` for an empty index.
    fn update_avg_doc_length(&mut self) {
        self.avg_doc_length = if self.total_docs == 0 {
            0.0
        } else {
            let total_length: usize = self.document_lengths.values().sum();
            total_length as f32 / self.total_docs as f32
        };
    }

    /// Splits `text` into index terms.
    ///
    /// The text is lowercased and split on whitespace; non-alphanumeric characters are
    /// stripped from each piece; pieces shorter than three characters and stop words are
    /// dropped.
    fn tokenize(&self, text: &str) -> Vec<String> {
        text.to_lowercase()
            .split_whitespace()
            .map(|word| {
                word.chars()
                    .filter(|c| c.is_alphanumeric())
                    .collect::<String>()
            })
            // Length is measured in characters, not UTF-8 bytes, so short non-ASCII words
            // are filtered exactly like short ASCII ones.
            .filter(|token| token.chars().count() > 2 && !self.is_stop_word(token))
            .collect()
    }

    /// Returns `true` when `word` (expected lowercase) is in the built-in stop-word list.
    fn is_stop_word(&self, word: &str) -> bool {
        const STOP_WORDS: &[&str] = &[
            "the", "be", "to", "of", "and", "a", "in", "that", "have", "i", "it", "for", "not",
            "on", "with", "he", "as", "you", "do", "at", "this", "but", "his", "by", "from",
            "they", "we", "say", "her", "she", "or", "an", "will", "my", "one", "all", "would",
            "there", "their", "what", "so", "up", "out", "if", "about", "who", "get", "which",
            "go", "me", "when", "make", "can", "like", "time", "no", "just", "him", "know", "take",
            "people", "into", "year", "your", "good", "some", "could", "them", "see", "other",
            "than", "then", "now", "look", "only", "come", "its", "over", "think", "also", "back",
            "after", "use", "two", "how", "our", "work", "first", "well", "way", "even", "new",
            "want", "because", "any", "these", "give", "day", "most", "us",
        ];
        STOP_WORDS.contains(&word)
    }

    /// Removes every document and all derived statistics; `k1` and `b` are kept.
    pub fn clear(&mut self) {
        self.documents.clear();
        self.term_frequencies.clear();
        self.document_frequencies.clear();
        self.document_lengths.clear();
        self.avg_doc_length = 0.0;
        self.total_docs = 0;
    }

    /// Returns a snapshot of the index size, average document length and parameters.
    pub fn get_statistics(&self) -> BM25Statistics {
        BM25Statistics {
            total_documents: self.total_docs,
            total_terms: self.term_frequencies.len(),
            avg_doc_length: self.avg_doc_length,
            parameters: BM25Parameters {
                k1: self.k1,
                b: self.b,
            },
        }
    }
}
/// The default retriever is the one produced by [`BM25Retriever::new`].
impl Default for BM25Retriever {
/// Delegates to [`BM25Retriever::new`].
fn default() -> Self {
Self::new()
}
}
/// The two scoring parameters of a [`BM25Retriever`], as reported in [`BM25Statistics`].
#[derive(Debug, Clone)]
pub struct BM25Parameters {
/// The retriever's `k1` value.
pub k1: f32,
/// The retriever's `b` value.
pub b: f32,
}
/// Snapshot of an index's size and configuration, produced by
/// [`BM25Retriever::get_statistics`].
#[derive(Debug, Clone)]
pub struct BM25Statistics {
/// Document count of the index at the time of the snapshot.
pub total_documents: usize,
/// Number of distinct terms in the index.
pub total_terms: usize,
/// Average document length, measured in tokens.
pub avg_doc_length: f32,
/// The `k1` and `b` values the retriever scores with.
pub parameters: BM25Parameters,
}
impl BM25Statistics {
    /// Writes a five-line, human-readable summary of the snapshot to standard output.
    ///
    /// Floating-point values are printed with two decimal places.
    pub fn print(&self) {
        let Self {
            total_documents,
            total_terms,
            avg_doc_length,
            parameters,
        } = self;

        println!("BM25 Index Statistics:");
        println!(" Total documents: {}", total_documents);
        println!(" Total terms: {}", total_terms);
        println!(" Average document length: {:.2} tokens", avg_doc_length);
        println!(" Parameters: k1={:.2}, b={:.2}", parameters.k1, parameters.b);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a metadata-free document from an id and its text.
    fn doc(id: &str, content: &str) -> Document {
        Document {
            id: id.to_string(),
            content: content.to_string(),
            metadata: HashMap::new(),
        }
    }

    /// Three-document corpus: two documents mention "brown", only the first mentions "fox".
    fn create_test_documents() -> Vec<Document> {
        vec![
            doc("doc1", "The quick brown fox jumps over the lazy dog"),
            doc("doc2", "A fast brown animal leaps across a sleeping canine"),
            doc("doc3", "The weather is nice today"),
        ]
    }

    /// Returns a default retriever with the fixture corpus already indexed.
    fn indexed_retriever() -> BM25Retriever {
        let mut retriever = BM25Retriever::new();
        retriever
            .index_documents(&create_test_documents())
            .unwrap();
        retriever
    }

    #[test]
    fn test_bm25_creation() {
        let retriever = BM25Retriever::new();

        assert_eq!(retriever.document_count(), 0);
        assert_eq!(retriever.term_count(), 0);
    }

    #[test]
    fn test_document_indexing() {
        let retriever = indexed_retriever();

        assert_eq!(retriever.document_count(), 3);
        assert!(retriever.term_count() > 0);
    }

    #[test]
    fn test_search() {
        let retriever = indexed_retriever();

        let results = retriever.search("brown fox", 10);

        assert!(!results.is_empty());
        // Only doc1 contains both query terms, so it must rank first.
        assert_eq!(results[0].doc_id, "doc1");
        assert!(results[0].score > 0.0);
    }

    #[test]
    fn test_tokenization() {
        let tokens = BM25Retriever::new().tokenize("The quick, brown fox!");
        let has = |word: &str| tokens.iter().any(|token| token == word);

        // Punctuation is stripped and content words survive...
        assert!(has("quick"));
        assert!(has("brown"));
        assert!(has("fox"));
        // ...while stop words are dropped.
        assert!(!has("the"));
    }

    #[test]
    fn test_empty_search() {
        let results = BM25Retriever::new().search("test", 10);

        assert!(results.is_empty());
    }
}