use std::collections::{HashMap, HashSet};
use serde::{Deserialize, Serialize};
/// Tunable parameters controlling how a [`BM25Index`] tokenizes and scores text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BM25Config {
/// BM25 `k1` constant; scales the term-frequency component of each term's score.
pub k1: f32,
/// BM25 `b` constant; weight given to the document-length ratio
/// (`document length / average document length`) during scoring.
pub b: f32,
/// Terms that occur in fewer than this many documents receive an IDF of zero.
pub min_df: usize,
/// Terms that occur in more than this fraction of documents receive an IDF of
/// zero. Only enforced once the index holds at least five documents.
pub max_df_ratio: f32,
/// When `true`, each token is passed through `simple_stem` (which also
/// lowercases it) before indexing and querying.
pub stem: bool,
/// When `true`, text is lowercased before it is split into tokens.
pub lowercase: bool,
}
impl Default for BM25Config {
fn default() -> Self {
Self {
k1: 1.5,
b: 0.75,
min_df: 1,
max_df_ratio: 0.85,
stem: false,
lowercase: true,
}
}
}
impl BM25Config {
    /// Preset for short documents: `k1 = 1.2`, `b = 0.5`; every other field
    /// keeps its [`Default`] value.
    #[must_use]
    pub fn for_short_docs() -> Self {
        let base = Self::default();
        Self {
            k1: 1.2,
            b: 0.5,
            ..base
        }
    }

    /// Preset for long documents: `k1 = 2.0`, `b = 0.75`; every other field
    /// keeps its [`Default`] value.
    #[must_use]
    pub fn for_long_docs() -> Self {
        let base = Self::default();
        Self {
            k1: 2.0,
            b: 0.75,
            ..base
        }
    }
}
/// A document as stored inside the index, together with its precomputed term statistics.
#[derive(Debug, Clone)]
struct IndexedDocument {
// Caller-supplied identifier; also the key under which the document is stored.
id: String,
// Original, unmodified text of the document.
content: String,
// Number of occurrences of each token produced by the tokenizer for this document.
term_freqs: HashMap<String, u32>,
// Total number of tokens produced by the tokenizer for this document.
length: usize,
}
/// A single hit returned by a BM25 search.
#[derive(Debug, Clone)]
pub struct BM25Result {
/// Identifier of the matching document.
pub id: String,
/// Full original text of the matching document.
pub content: String,
/// BM25 score of the document for the query; search only returns hits with a score above zero.
pub score: f32,
}
/// In-memory BM25 index over text documents keyed by string id.
pub struct BM25Index {
// Tokenization and scoring parameters.
config: BM25Config,
// All indexed documents, keyed by document id.
documents: HashMap<String, IndexedDocument>,
// For each term, the ids of the documents that contain it.
inverted_index: HashMap<String, HashSet<String>>,
// For each term, the number of documents that contain it.
doc_freqs: HashMap<String, usize>,
// Number of documents currently indexed.
total_docs: usize,
// Mean token count over all indexed documents (0.0 when the index is empty).
avg_doc_length: f32,
}
impl BM25Index {
    /// Minimum number of indexed documents before the `max_df_ratio` filter is
    /// applied; below this, every term that passes `min_df` keeps its IDF.
    const MIN_DOCS_FOR_MAX_DF_FILTER: usize = 5;

    /// Creates an empty index that tokenizes and scores according to `config`.
    #[must_use]
    pub fn new(config: BM25Config) -> Self {
        Self {
            config,
            documents: HashMap::new(),
            inverted_index: HashMap::new(),
            doc_freqs: HashMap::new(),
            total_docs: 0,
            avg_doc_length: 0.0,
        }
    }

    /// Creates an empty index using [`BM25Config::default`].
    #[must_use]
    pub fn with_defaults() -> Self {
        Self::new(BM25Config::default())
    }

    /// Returns the configuration this index was built with.
    #[must_use]
    pub fn config(&self) -> &BM25Config {
        &self.config
    }

    /// Returns the number of indexed documents.
    #[must_use]
    pub fn len(&self) -> usize {
        self.total_docs
    }

    /// Returns `true` when no documents are indexed.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.total_docs == 0
    }

    /// Returns the number of distinct terms present in at least one document.
    #[must_use]
    pub fn vocabulary_size(&self) -> usize {
        self.inverted_index.len()
    }

    /// Splits `text` into tokens on every non-alphanumeric character.
    ///
    /// Text is lowercased first when `config.lowercase` is set, tokens of at
    /// most one byte are dropped, and each remaining token is stemmed when
    /// `config.stem` is set.
    fn tokenize(&self, text: &str) -> Vec<String> {
        // Only allocate a normalized copy when lowercasing is requested.
        let normalized: std::borrow::Cow<'_, str> = if self.config.lowercase {
            std::borrow::Cow::Owned(text.to_lowercase())
        } else {
            std::borrow::Cow::Borrowed(text)
        };
        normalized
            .split(|c: char| !c.is_alphanumeric())
            // `len() > 1` already excludes the empty pieces `split` produces.
            .filter(|token| token.len() > 1)
            .map(|token| {
                if self.config.stem {
                    simple_stem(token)
                } else {
                    token.to_string()
                }
            })
            .collect()
    }

    /// Indexes `content` under `id`, replacing any document already stored
    /// under the same id.
    pub fn add_document(&mut self, id: impl Into<String>, content: impl Into<String>) {
        self.insert_document(id.into(), content.into());
        self.update_avg_length();
    }

    /// Indexes every `(id, content)` pair, with the same replace-on-duplicate
    /// semantics as [`BM25Index::add_document`].
    ///
    /// The average document length is recomputed once after the whole batch
    /// rather than after every document.
    pub fn add_documents<I, S1, S2>(&mut self, documents: I)
    where
        I: IntoIterator<Item = (S1, S2)>,
        S1: Into<String>,
        S2: Into<String>,
    {
        for (id, content) in documents {
            self.insert_document(id.into(), content.into());
        }
        self.update_avg_length();
    }

    /// Removes the document stored under `id`.
    ///
    /// Returns `true` if a document was removed and `false` if the id was unknown.
    pub fn remove_document(&mut self, id: &str) -> bool {
        let removed = self.detach_document(id);
        if removed {
            self.update_avg_length();
        }
        removed
    }

    /// Stores a document and updates the postings and document frequencies.
    ///
    /// Does NOT refresh `avg_doc_length`; callers must invoke
    /// `update_avg_length` once they are done mutating the index.
    fn insert_document(&mut self, id: String, content: String) {
        // Re-adding an existing id replaces the previous version.
        self.detach_document(&id);

        let tokens = self.tokenize(&content);
        let length = tokens.len();
        let mut term_freqs: HashMap<String, u32> = HashMap::new();
        for token in tokens {
            *term_freqs.entry(token).or_insert(0) += 1;
        }

        for term in term_freqs.keys() {
            self.inverted_index
                .entry(term.clone())
                .or_default()
                .insert(id.clone());
            *self.doc_freqs.entry(term.clone()).or_insert(0) += 1;
        }

        let doc = IndexedDocument {
            id: id.clone(),
            content,
            term_freqs,
            length,
        };
        self.documents.insert(id, doc);
        self.total_docs += 1;
    }

    /// Drops a document together with its postings and document-frequency
    /// contributions, returning whether the id was present.
    ///
    /// Does NOT refresh `avg_doc_length`; see `insert_document`.
    fn detach_document(&mut self, id: &str) -> bool {
        let doc = match self.documents.remove(id) {
            Some(doc) => doc,
            None => return false,
        };

        for term in doc.term_freqs.keys() {
            if let Some(postings) = self.inverted_index.get_mut(term) {
                postings.remove(id);
                // Keep the vocabulary limited to terms that still occur somewhere.
                if postings.is_empty() {
                    self.inverted_index.remove(term);
                }
            }
            if let Some(df) = self.doc_freqs.get_mut(term) {
                *df = df.saturating_sub(1);
                if *df == 0 {
                    self.doc_freqs.remove(term);
                }
            }
        }

        self.total_docs -= 1;
        true
    }

    /// Recomputes the mean token count over all stored documents
    /// (0.0 for an empty index).
    fn update_avg_length(&mut self) {
        if self.total_docs == 0 {
            self.avg_doc_length = 0.0;
        } else {
            let total_length: usize = self.documents.values().map(|d| d.length).sum();
            self.avg_doc_length = total_length as f32 / self.total_docs as f32;
        }
    }

    /// Inverse document frequency of `term`.
    ///
    /// Returns `0.0` when the term occurs in fewer than `config.min_df`
    /// documents, or — once the index holds at least
    /// `MIN_DOCS_FOR_MAX_DF_FILTER` documents — in more than
    /// `config.max_df_ratio` of them. Otherwise returns
    /// `ln((N - df + 0.5) / (df + 0.5) + 1)`.
    fn idf(&self, term: &str) -> f32 {
        let df = self.doc_freqs.get(term).copied().unwrap_or(0);
        if df < self.config.min_df {
            return 0.0;
        }
        if self.total_docs >= Self::MIN_DOCS_FOR_MAX_DF_FILTER {
            let df_ratio = df as f32 / self.total_docs as f32;
            if df_ratio > self.config.max_df_ratio {
                return 0.0;
            }
        }
        let n = self.total_docs as f32;
        let df = df as f32;
        ((n - df + 0.5) / (df + 0.5) + 1.0).ln()
    }

    /// BM25 score of `doc` for the given query terms.
    ///
    /// `weighted_terms` holds each query term with its precomputed IDF, in
    /// query order; a term repeated in the query contributes once per
    /// occurrence.
    fn score_document(&self, doc: &IndexedDocument, weighted_terms: &[(&str, f32)]) -> f32 {
        let k1 = self.config.k1;
        let b = self.config.b;
        let avgdl = self.avg_doc_length;
        let dl = doc.length as f32;
        let mut score = 0.0;
        for &(term, idf) in weighted_terms {
            let tf = doc.term_freqs.get(term).copied().unwrap_or(0) as f32;
            if tf == 0.0 {
                continue;
            }
            let numerator = tf * (k1 + 1.0);
            let denominator = tf + k1 * (1.0 - b + b * dl / avgdl);
            score += idf * numerator / denominator;
        }
        score
    }

    /// Returns up to `top_k` documents matching `query`, best first.
    ///
    /// Only documents with a score above zero are returned. Results are
    /// ordered by descending score; documents with equal scores are ordered by
    /// ascending id so that the ranking is reproducible across calls.
    pub fn search(&self, query: &str, top_k: usize) -> Vec<BM25Result> {
        if self.total_docs == 0 || top_k == 0 {
            return Vec::new();
        }

        // Compute each term's IDF once instead of once per candidate document,
        // and drop the terms that cannot contribute to any score.
        let query_terms = self.tokenize(query);
        let weighted_terms: Vec<(&str, f32)> = query_terms
            .iter()
            .map(|term| (term.as_str(), self.idf(term)))
            .filter(|&(_, idf)| idf != 0.0)
            .collect();
        if weighted_terms.is_empty() {
            return Vec::new();
        }

        // Every document containing at least one contributing term.
        let candidates: HashSet<&str> = weighted_terms
            .iter()
            .filter_map(|&(term, _)| self.inverted_index.get(term))
            .flatten()
            .map(String::as_str)
            .collect();

        let mut scored: Vec<(f32, &IndexedDocument)> = candidates
            .into_iter()
            .filter_map(|id| self.documents.get(id))
            .map(|doc| (self.score_document(doc, &weighted_terms), doc))
            .filter(|&(score, _)| score > 0.0)
            .collect();

        scored.sort_by(|a, b| {
            b.0.partial_cmp(&a.0)
                .unwrap_or(std::cmp::Ordering::Equal)
                // Candidates come out of a hash set; break ties on id so the
                // order does not depend on hash iteration order.
                .then_with(|| a.1.id.cmp(&b.1.id))
        });
        scored.truncate(top_k);

        // Clone ids and contents only for the hits that are actually returned.
        scored
            .into_iter()
            .map(|(score, doc)| BM25Result {
                id: doc.id.clone(),
                content: doc.content.clone(),
                score,
            })
            .collect()
    }

    /// Like [`BM25Index::search`], but drops hits scoring below `min_score`.
    ///
    /// The threshold is applied after taking the best `top_k` hits, so fewer
    /// than `top_k` results may be returned.
    pub fn search_with_threshold(
        &self,
        query: &str,
        top_k: usize,
        min_score: f32,
    ) -> Vec<BM25Result> {
        self.search(query, top_k)
            .into_iter()
            .filter(|r| r.score >= min_score)
            .collect()
    }

    /// Tokenizes `query` and returns the IDF of each distinct resulting term.
    pub fn get_term_idfs(&self, query: &str) -> HashMap<String, f32> {
        self.tokenize(query)
            .into_iter()
            .map(|term| {
                let idf = self.idf(&term);
                (term, idf)
            })
            .collect()
    }

    /// Removes every document and resets all statistics; the configuration is kept.
    pub fn clear(&mut self) {
        self.documents.clear();
        self.inverted_index.clear();
        self.doc_freqs.clear();
        self.total_docs = 0;
        self.avg_doc_length = 0.0;
    }
}
impl Default for BM25Index {
fn default() -> Self {
Self::with_defaults()
}
}
/// Lowercases `word` and strips at most one common English suffix.
///
/// A suffix is only removed when more than two bytes of the word would remain,
/// so very short words are returned unchanged (apart from lowercasing).
/// This is a crude heuristic, not a linguistic stemmer: e.g. `"running"`
/// becomes `"runn"`.
fn simple_stem(word: &str) -> String {
    // Order matters: the first matching suffix wins, so a suffix that ends in
    // another one must come first ("ness" and "es" before "s"). Otherwise the
    // longer suffix could never match.
    const SUFFIXES: [&str; 9] = [
        "ment", "ness", "tion", "sion", "ing", "ed", "es", "ly", "s",
    ];

    let mut stemmed = word.to_lowercase();
    let cut = SUFFIXES
        .iter()
        .find(|suffix| stemmed.len() > suffix.len() + 2 && stemmed.ends_with(**suffix))
        .map(|suffix| stemmed.len() - suffix.len());
    if let Some(end) = cut {
        // `end` is a char boundary: the removed suffix is pure ASCII.
        stemmed.truncate(end);
    }
    stemmed
}
/// Combines BM25 scores from an owned [`BM25Index`] with externally supplied dense scores.
pub struct HybridRetriever {
// Sparse index queried during hybrid search.
bm25_index: BM25Index,
// Multiplier applied to the normalized BM25 score in the hybrid score.
bm25_weight: f32,
// Multiplier applied to the normalized dense score in the hybrid score.
dense_weight: f32,
}
impl HybridRetriever {
    /// Creates a retriever with an empty BM25 index built from `bm25_config`.
    ///
    /// The weights are stored exactly as given; unlike
    /// [`HybridRetriever::set_bm25_weight`] and
    /// [`HybridRetriever::set_dense_weight`], this constructor does not clamp them.
    #[must_use]
    pub fn new(bm25_config: BM25Config, bm25_weight: f32, dense_weight: f32) -> Self {
        Self {
            bm25_index: BM25Index::new(bm25_config),
            bm25_weight,
            dense_weight,
        }
    }

    /// Default BM25 configuration with weights 0.5 (BM25) / 0.5 (dense).
    #[must_use]
    pub fn with_equal_weights() -> Self {
        Self::new(BM25Config::default(), 0.5, 0.5)
    }

    /// Default BM25 configuration with weights 0.3 (BM25) / 0.7 (dense).
    #[must_use]
    pub fn dense_heavy() -> Self {
        Self::new(BM25Config::default(), 0.3, 0.7)
    }

    /// Default BM25 configuration with weights 0.7 (BM25) / 0.3 (dense).
    #[must_use]
    pub fn sparse_heavy() -> Self {
        Self::new(BM25Config::default(), 0.7, 0.3)
    }

    /// Shared access to the underlying BM25 index.
    #[must_use]
    pub fn bm25_index(&self) -> &BM25Index {
        &self.bm25_index
    }

    /// Mutable access to the underlying BM25 index.
    pub fn bm25_index_mut(&mut self) -> &mut BM25Index {
        &mut self.bm25_index
    }

    /// Number of documents in the BM25 index.
    #[must_use]
    pub fn document_count(&self) -> usize {
        self.bm25_index.len()
    }

    /// Adds (or replaces) a document in the BM25 index.
    pub fn add_document(&mut self, id: impl Into<String>, content: impl Into<String>) {
        self.bm25_index.add_document(id, content);
    }

    /// Divides `score` by `max`, or returns 0.0 when `max` is not positive.
    fn normalize(score: f32, max: f32) -> f32 {
        if max > 0.0 {
            score / max
        } else {
            0.0
        }
    }

    /// Ranking order for hybrid results: descending hybrid score, NaN scores
    /// last, ties broken by ascending id.
    ///
    /// Treating NaN as simply "equal" would not be a total order, which
    /// `sort_by` requires; it could scramble the ranking or panic.
    fn rank_order(a: &HybridResult, b: &HybridResult) -> std::cmp::Ordering {
        b.hybrid_score
            .partial_cmp(&a.hybrid_score)
            .unwrap_or_else(|| a.hybrid_score.is_nan().cmp(&b.hybrid_score.is_nan()))
            .then_with(|| a.id.cmp(&b.id))
    }

    /// Fuses BM25 hits for `query` with caller-provided dense hits.
    ///
    /// `dense_results` holds `(document id, dense score)` pairs; if an id
    /// appears more than once, its last score is used. Both score sets are
    /// divided by their own maximum (a set whose maximum is not positive
    /// normalizes to all zeros), and the hybrid score is
    /// `bm25_weight * bm25 + dense_weight * dense`. A document missing from
    /// one side scores 0.0 on that side.
    ///
    /// Returns at most `top_k` results ordered by descending hybrid score,
    /// with ties broken by ascending id.
    pub fn hybrid_search(
        &self,
        query: &str,
        dense_results: &[(String, f32)],
        top_k: usize,
    ) -> Vec<HybridResult> {
        // Over-fetch sparse hits so documents ranked lower by BM25 can still be
        // fused; saturate so a huge `top_k` cannot overflow.
        let bm25_limit = dense_results.len().max(top_k.saturating_mul(2));
        let bm25_results = self.bm25_index.search(query, bm25_limit);

        let max_bm25 = bm25_results.iter().map(|r| r.score).fold(0.0_f32, f32::max);
        let bm25_scores: HashMap<String, f32> = bm25_results
            .into_iter()
            .map(|r| (r.id, Self::normalize(r.score, max_bm25)))
            .collect();

        let max_dense = dense_results
            .iter()
            .map(|(_, score)| *score)
            .fold(0.0_f32, f32::max);
        let dense_scores: HashMap<&str, f32> = dense_results
            .iter()
            .map(|(id, score)| (id.as_str(), Self::normalize(*score, max_dense)))
            .collect();

        // Union of the ids seen on either side.
        let all_ids: HashSet<&str> = bm25_scores
            .keys()
            .map(String::as_str)
            .chain(dense_scores.keys().copied())
            .collect();

        let mut results: Vec<HybridResult> = all_ids
            .into_iter()
            .map(|id| {
                let bm25 = bm25_scores.get(id).copied().unwrap_or(0.0);
                let dense = dense_scores.get(id).copied().unwrap_or(0.0);
                let hybrid = self.bm25_weight * bm25 + self.dense_weight * dense;
                HybridResult {
                    id: id.to_string(),
                    bm25_score: bm25,
                    dense_score: dense,
                    hybrid_score: hybrid,
                }
            })
            .collect();

        results.sort_by(Self::rank_order);
        results.truncate(top_k);
        results
    }

    /// Sets the BM25 weight, clamped to `[0.0, 1.0]`.
    pub fn set_bm25_weight(&mut self, weight: f32) {
        self.bm25_weight = weight.clamp(0.0, 1.0);
    }

    /// Sets the dense weight, clamped to `[0.0, 1.0]`.
    pub fn set_dense_weight(&mut self, weight: f32) {
        self.dense_weight = weight.clamp(0.0, 1.0);
    }
}
/// A single fused hit produced by `HybridRetriever::hybrid_search`.
#[derive(Debug, Clone)]
pub struct HybridResult {
/// Identifier of the document.
pub id: String,
/// BM25 score divided by the best BM25 score of the query (0.0 if the document had no BM25 hit).
pub bm25_score: f32,
/// Dense score divided by the largest supplied dense score (0.0 if the document had no dense hit).
pub dense_score: f32,
/// Weighted sum of `bm25_score` and `dense_score`.
pub hybrid_score: f32,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bm25_config_default() {
        let config = BM25Config::default();
        assert!((config.k1 - 1.5).abs() < 0.001);
        assert!((config.b - 0.75).abs() < 0.001);
        assert!(config.lowercase);
    }

    #[test]
    fn test_bm25_index_new() {
        let index = BM25Index::with_defaults();
        assert!(index.is_empty());
        assert_eq!(index.len(), 0);
        assert_eq!(index.vocabulary_size(), 0);
    }

    #[test]
    fn test_bm25_add_document() {
        let mut index = BM25Index::with_defaults();
        index.add_document("doc1", "The quick brown fox");
        assert_eq!(index.len(), 1);
        assert!(!index.is_empty());
        assert!(index.vocabulary_size() > 0);
    }

    #[test]
    fn test_bm25_add_multiple_documents() {
        let mut index = BM25Index::with_defaults();
        index.add_documents([
            ("doc1", "The quick brown fox"),
            ("doc2", "The lazy dog"),
            ("doc3", "A quick lazy fox"),
        ]);
        assert_eq!(index.len(), 3);
    }

    #[test]
    fn test_bm25_remove_document() {
        let mut index = BM25Index::with_defaults();
        index.add_document("doc1", "The quick brown fox");
        index.add_document("doc2", "The lazy dog");
        assert!(index.remove_document("doc1"));
        assert_eq!(index.len(), 1);
        assert!(!index.remove_document("nonexistent"));
    }

    #[test]
    fn test_bm25_search_basic() {
        let mut index = BM25Index::with_defaults();
        index.add_documents([
            ("doc1", "The quick brown fox jumps over the lazy dog"),
            ("doc2", "A lazy cat sleeps all day"),
            ("doc3", "The fox is quick and smart"),
        ]);
        let results = index.search("quick fox", 10);
        assert!(!results.is_empty());
        assert!(results.len() >= 2);
    }

    #[test]
    fn test_bm25_search_empty_query() {
        let mut index = BM25Index::with_defaults();
        index.add_document("doc1", "The quick brown fox");
        let results = index.search("", 10);
        assert!(results.is_empty());
    }

    #[test]
    fn test_bm25_search_no_matches() {
        let mut index = BM25Index::with_defaults();
        index.add_document("doc1", "The quick brown fox");
        let results = index.search("elephant", 10);
        assert!(results.is_empty());
    }

    #[test]
    fn test_bm25_search_with_threshold() {
        let mut index = BM25Index::with_defaults();
        index.add_documents([
            ("doc1", "machine learning artificial intelligence"),
            ("doc2", "machine parts factory"),
            ("doc3", "deep learning neural networks"),
        ]);
        let results = index.search_with_threshold("machine learning", 10, 0.5);
        for result in &results {
            assert!(result.score >= 0.5);
        }
    }

    #[test]
    fn test_bm25_clear() {
        let mut index = BM25Index::with_defaults();
        index.add_document("doc1", "test document");
        index.clear();
        assert!(index.is_empty());
        assert_eq!(index.vocabulary_size(), 0);
    }

    #[test]
    fn test_bm25_get_term_idfs() {
        let mut index = BM25Index::with_defaults();
        index.add_documents([
            ("doc1", "common rare unique"),
            ("doc2", "common word"),
            ("doc3", "common another"),
        ]);
        let idfs = index.get_term_idfs("common rare unique");
        // A term found in one document must weigh more than one found in all of them.
        assert!(idfs.get("rare").unwrap_or(&0.0) > idfs.get("common").unwrap_or(&f32::MAX));
    }

    #[test]
    fn test_simple_stem() {
        assert_eq!(simple_stem("running"), "runn");
        assert_eq!(simple_stem("played"), "play");
        assert_eq!(simple_stem("cats"), "cat");
        assert_eq!(simple_stem("quickly"), "quick");
    }

    #[test]
    fn test_hybrid_retriever_new() {
        let retriever = HybridRetriever::with_equal_weights();
        assert!(retriever.bm25_index().is_empty());
    }

    #[test]
    fn test_hybrid_retriever_add_document() {
        let mut retriever = HybridRetriever::with_equal_weights();
        retriever.add_document("doc1", "test content");
        assert_eq!(retriever.bm25_index().len(), 1);
    }

    #[test]
    fn test_hybrid_search() {
        let mut retriever = HybridRetriever::with_equal_weights();
        retriever.add_document("doc1", "machine learning algorithms");
        retriever.add_document("doc2", "deep learning neural networks");
        retriever.add_document("doc3", "learning to code");
        let dense_results = vec![
            ("doc2".to_string(), 0.9),
            ("doc1".to_string(), 0.7),
            ("doc3".to_string(), 0.3),
        ];
        let results = retriever.hybrid_search("machine learning", &dense_results, 10);
        assert!(!results.is_empty());
        for result in &results {
            assert!(result.hybrid_score >= 0.0);
        }
    }

    #[test]
    fn test_hybrid_weights() {
        let mut retriever = HybridRetriever::dense_heavy();

        // In-range weights are stored as given.
        retriever.set_bm25_weight(0.4);
        retriever.set_dense_weight(0.6);
        assert!((retriever.bm25_weight - 0.4).abs() < f32::EPSILON);
        assert!((retriever.dense_weight - 0.6).abs() < f32::EPSILON);

        // Out-of-range weights are clamped to [0.0, 1.0].
        retriever.set_bm25_weight(1.5);
        retriever.set_dense_weight(-0.1);
        assert!((retriever.bm25_weight - 1.0).abs() < f32::EPSILON);
        assert!(retriever.dense_weight.abs() < f32::EPSILON);
    }

    #[test]
    fn test_bm25_score_ordering() {
        let mut index = BM25Index::with_defaults();
        index.add_documents([
            ("doc1", "fox fox fox fox fox"),
            ("doc2", "fox"),
            ("doc3", "the quick brown animal jumps"),
        ]);
        let results = index.search("fox", 10);
        assert_eq!(results.len(), 2);
        assert!(results[0].score >= results[1].score);
    }

    #[test]
    fn test_bm25_document_update() {
        let mut index = BM25Index::with_defaults();
        index.add_document("doc1", "original content");
        index.add_document("doc1", "updated content new");
        assert_eq!(index.len(), 1);
        let results = index.search("updated", 10);
        assert!(!results.is_empty());
        let results = index.search("original", 10);
        assert!(results.is_empty());
    }
}