use std::collections::HashMap;
use std::sync::RwLock;
use async_trait::async_trait;
use super::traits::MemorySearcher;
use crate::error::Result;
// BM25 term-frequency saturation parameter (standard default from the literature).
const K1: f32 = 1.2;
// BM25 document-length normalization strength (0 = none, 1 = full; 0.75 is the common default).
const B: f32 = 0.75;
// Heuristic per-term cap used to squash a raw BM25 sum into the [0, 1] range in `score`.
const MAX_BM25_SCORE_PER_TERM: f32 = 3.0;
/// BM25 keyword searcher backed by an in-memory inverted index.
pub struct Bm25Searcher {
    // RwLock gives interior mutability so `score` (read) and the indexing
    // methods (write) can all take `&self`. Lock acquisition uses `unwrap()`
    // throughout, so a poisoned lock (panic while held) will propagate.
    index: RwLock<Bm25Index>,
}
/// Inverted-index state backing `Bm25Searcher`.
struct Bm25Index {
    // term -> (document key -> term frequency within that document).
    term_docs: HashMap<String, HashMap<String, u32>>,
    // document key -> token count of that document.
    doc_lengths: HashMap<String, u32>,
    // Number of indexed documents; kept in sync with `doc_lengths` by the
    // index/remove methods.
    doc_count: u32,
}

impl Bm25Index {
    /// Creates an empty index.
    fn new() -> Self {
        Self {
            term_docs: HashMap::new(),
            doc_lengths: HashMap::new(),
            doc_count: 0,
        }
    }

    /// Average indexed document length in tokens.
    ///
    /// Returns 1.0 when no documents are indexed, and also when every indexed
    /// document tokenized to zero tokens. Previously an all-empty index
    /// produced an average of 0.0, which made `doc_len / avg_dl` infinite in
    /// the BM25 length normalization and collapsed every score to 0 — a single
    /// `index(key, "")` call silently disabled scoring.
    fn avg_doc_length(&self) -> f32 {
        let total: u32 = self.doc_lengths.values().sum();
        if self.doc_count == 0 || total == 0 {
            return 1.0;
        }
        total as f32 / self.doc_count as f32
    }
}
impl Default for Bm25Searcher {
fn default() -> Self {
Self::new()
}
}
impl Bm25Searcher {
    /// Creates a searcher with an empty in-memory index.
    pub fn new() -> Self {
        let index = RwLock::new(Bm25Index::new());
        Self { index }
    }

    /// Lowercases `text` and splits it on every non-ASCII-alphanumeric
    /// character, keeping only tokens of at least two bytes (this drops the
    /// empty fragments produced by consecutive separators as well as
    /// single-character tokens).
    fn tokenize(text: &str) -> Vec<String> {
        let lowered = text.to_lowercase();
        lowered
            .split(|c: char| !c.is_ascii_alphanumeric())
            .filter(|token| token.len() >= 2)
            .map(str::to_string)
            .collect()
    }
}
#[async_trait]
impl MemorySearcher for Bm25Searcher {
    /// Identifier for this searcher backend.
    fn name(&self) -> &str {
        "bm25"
    }

    /// Scores `chunk` against `query` with BM25, normalized into [0, 1].
    ///
    /// Statistics (document frequency, average length) come from whatever has
    /// been indexed so far; with an empty index the IDF degrades to a fixed
    /// positive constant, so matching terms still score above zero.
    fn score(&self, chunk: &str, query: &str) -> f32 {
        let terms = Self::tokenize(query);
        let tokens = Self::tokenize(chunk);
        let doc_len = tokens.len() as f32;
        // Nothing to match against — either side tokenized to nothing.
        if terms.is_empty() || doc_len == 0.0 {
            return 0.0;
        }

        // Term frequencies within the chunk, keyed by borrowed token text.
        let mut frequencies: HashMap<&str, u32> = HashMap::new();
        for token in &tokens {
            *frequencies.entry(token.as_str()).or_insert(0) += 1;
        }

        let guard = self.index.read().unwrap();
        let avg_dl = guard.avg_doc_length();
        // Floor at 1 so the IDF numerator stays positive on an empty index.
        let n = guard.doc_count.max(1) as f32;

        let raw: f32 = terms
            .iter()
            .map(|term| {
                let tf = frequencies.get(term.as_str()).copied().unwrap_or(0) as f32;
                if tf == 0.0 {
                    // Query term absent from the chunk contributes nothing.
                    return 0.0;
                }
                let df = guard
                    .term_docs
                    .get(term)
                    .map_or(0.0, |docs| docs.len() as f32);
                // BM25+ style IDF: the +1.0 inside ln keeps it non-negative.
                let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
                // Saturating TF with length normalization.
                let tf_norm = (tf * (K1 + 1.0)) / (tf + K1 * (1.0 - B + B * doc_len / avg_dl));
                idf * tf_norm
            })
            .sum();

        // Squash by a heuristic per-term maximum so the result is comparable
        // across queries of different lengths.
        let max_possible = terms.len() as f32 * MAX_BM25_SCORE_PER_TERM;
        (raw / max_possible).clamp(0.0, 1.0)
    }

    /// Indexes (or re-indexes) `text` under `key`. Upserts: previously stored
    /// postings for `key` are discarded first, so the document count is stable.
    async fn index(&self, key: &str, text: &str) -> Result<()> {
        let tokens = Self::tokenize(text);
        let mut guard = self.index.write().unwrap();

        // Upsert path: strip the old postings for this key before re-adding.
        if guard.doc_lengths.contains_key(key) {
            for postings in guard.term_docs.values_mut() {
                postings.remove(key);
            }
            guard.term_docs.retain(|_, postings| !postings.is_empty());
            guard.doc_count = guard.doc_count.saturating_sub(1);
        }

        // Per-term frequencies for the new document.
        let mut counts: HashMap<&str, u32> = HashMap::new();
        for token in &tokens {
            *counts.entry(token.as_str()).or_insert(0) += 1;
        }
        for (term, count) in counts {
            guard
                .term_docs
                .entry(term.to_string())
                .or_default()
                .insert(key.to_string(), count);
        }

        guard.doc_lengths.insert(key.to_string(), tokens.len() as u32);
        guard.doc_count += 1;
        Ok(())
    }

    /// Removes `key` and all its postings; a no-op for unknown keys.
    async fn remove(&self, key: &str) -> Result<()> {
        let mut guard = self.index.write().unwrap();
        if guard.doc_lengths.remove(key).is_none() {
            return Ok(());
        }
        for postings in guard.term_docs.values_mut() {
            postings.remove(key);
        }
        // Drop terms whose posting lists just became empty.
        guard.term_docs.retain(|_, postings| !postings.is_empty());
        guard.doc_count = guard.doc_count.saturating_sub(1);
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_name() {
        let bm25 = Bm25Searcher::new();
        assert_eq!(bm25.name(), "bm25");
    }

    #[test]
    fn test_score_matching_terms() {
        let bm25 = Bm25Searcher::new();
        let score = bm25.score("Rust programming language", "rust programming");
        assert!(score > 0.0, "Matching terms should score > 0: {}", score);
    }

    #[test]
    fn test_score_no_match() {
        // Disjoint vocabularies must score exactly zero.
        assert_eq!(Bm25Searcher::new().score("Hello World", "foobar baz"), 0.0);
    }

    #[test]
    fn test_score_empty_query() {
        assert_eq!(Bm25Searcher::new().score("some text", ""), 0.0);
    }

    #[test]
    fn test_score_empty_chunk() {
        assert_eq!(Bm25Searcher::new().score("", "query"), 0.0);
    }

    #[tokio::test]
    async fn test_index_improves_idf() {
        let bm25 = Bm25Searcher::new();
        let before = bm25.score("rust is fast", "rust");
        // Indexing documents shifts N and df, so the IDF term must move.
        for (key, text) in [
            ("doc1", "rust is fast"),
            ("doc2", "python is slow"),
            ("doc3", "java is verbose"),
        ] {
            bm25.index(key, text).await.unwrap();
        }
        let after = bm25.score("rust is fast", "rust");
        assert!(before > 0.0);
        assert!(after > 0.0);
        assert_ne!(before, after, "IDF weighting should change the score");
    }

    #[tokio::test]
    async fn test_remove_cleans_index() {
        let bm25 = Bm25Searcher::new();
        bm25.index("doc1", "hello world").await.unwrap();
        {
            let guard = bm25.index.read().unwrap();
            assert_eq!(guard.doc_count, 1);
            assert!(guard.doc_lengths.contains_key("doc1"));
        }
        bm25.remove("doc1").await.unwrap();
        {
            let guard = bm25.index.read().unwrap();
            assert_eq!(guard.doc_count, 0);
            assert!(!guard.doc_lengths.contains_key("doc1"));
        }
    }

    #[tokio::test]
    async fn test_index_upsert() {
        let bm25 = Bm25Searcher::new();
        bm25.index("doc1", "hello world").await.unwrap();
        // Re-indexing the same key replaces, not duplicates.
        bm25.index("doc1", "goodbye world").await.unwrap();
        let guard = bm25.index.read().unwrap();
        assert_eq!(guard.doc_count, 1, "Upsert should not increase doc count");
    }
}