use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Per-field boost factors keyed by field name, as built by `parse_field_weights` and
/// consumed by `bm25f_score` (fields missing from the map are scored with a boost of 1.0).
pub type FieldWeights = HashMap<String, f32>;
/// Tuning parameters for Okapi BM25 scoring (derives serde `Serialize`/`Deserialize`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BM25Config {
    /// Term-frequency saturation: larger values let repeated occurrences of a term keep
    /// raising its contribution for longer before it levels off.
    pub k1: f32,
    /// Strength of document-length normalization in `1 - b + b * len / avg_len`:
    /// `0.0` disables it, `1.0` normalizes fully by the length ratio.
    pub b: f32,
}

impl Default for BM25Config {
    /// Returns `k1 = 1.2`, `b = 0.75`.
    fn default() -> Self {
        BM25Config { k1: 1.2, b: 0.75 }
    }
}
/// Corpus-level statistics needed to score documents with BM25.
#[derive(Debug, Clone)]
pub struct BM25Stats {
    /// Mean document length, where a document's length is the sum of its values.
    pub avg_doc_length: f32,
    /// Inverse document frequency keyed by term index.
    pub idf: HashMap<usize, f32>,
    /// Number of documents the statistics were computed from.
    pub num_docs: usize,
}

impl BM25Stats {
    /// Computes BM25 statistics from an iterator of sparse documents.
    ///
    /// Each item is `(indices, values)`: the term indices present in a document and their
    /// values (term frequencies). A document's length is the sum of its `values`.
    ///
    /// IDF is `ln(1 + (N - df + 0.5) / (df + 0.5))`, where `N` is the number of documents
    /// and `df` the number of documents containing the term. A term index repeated inside
    /// one document counts only once toward `df`, so `df <= N` and every IDF is positive.
    ///
    /// An empty iterator yields `num_docs == 0`, `avg_doc_length == 0.0` and no IDF entries.
    pub fn from_corpus<'a, I>(documents: I) -> Self
    where
        I: Iterator<Item = (&'a [usize], &'a [f32])>,
    {
        let mut doc_count: HashMap<usize, usize> = HashMap::new();
        // Accumulate in f64: summing millions of document lengths in f32 loses precision.
        let mut total_doc_length = 0.0_f64;
        let mut num_docs = 0_usize;
        // Reused across documents to count each distinct term once per document.
        let mut seen_in_doc: std::collections::HashSet<usize> =
            std::collections::HashSet::new();

        for (indices, values) in documents {
            num_docs += 1;
            total_doc_length += f64::from(values.iter().sum::<f32>());
            seen_in_doc.clear();
            for &term_idx in indices {
                // `insert` returns false for a repeat within the same document.
                if seen_in_doc.insert(term_idx) {
                    *doc_count.entry(term_idx).or_insert(0) += 1;
                }
            }
        }

        let avg_doc_length = if num_docs > 0 {
            (total_doc_length / num_docs as f64) as f32
        } else {
            0.0
        };

        let n = num_docs as f32;
        let idf = doc_count
            .into_iter()
            .map(|(term_idx, df)| {
                let df = df as f32;
                (term_idx, ((n - df + 0.5) / (df + 0.5) + 1.0).ln())
            })
            .collect();

        BM25Stats {
            avg_doc_length,
            idf,
            num_docs,
        }
    }

    /// Returns the IDF of `term_idx`, or `0.0` if the term never occurred in the corpus.
    pub fn get_idf(&self, term_idx: usize) -> f32 {
        self.idf.get(&term_idx).copied().unwrap_or(0.0)
    }
}
/// Scores one document against a weighted query with Okapi BM25.
///
/// `query_indices`/`query_weights` and `doc_indices`/`doc_values` are parallel sparse
/// vectors; entries past the shorter slice of each pair are ignored. Document values are
/// term frequencies, and their sum is the document length. Repeated document indices have
/// their values summed, so a term's frequency agrees with the length that counts it.
///
/// Each query term found in the document contributes
/// `idf * query_weight * tf * (k1 + 1) / (tf + k1 * (1 - b + b * len / avg_len))`;
/// terms without an IDF entry in `stats` contribute 0.
///
/// If `stats.avg_doc_length` is not positive (e.g. statistics from an empty corpus), the
/// document is treated as average length instead of dividing by zero, which would
/// otherwise make the score `NaN`.
pub fn bm25_score(
    query_indices: &[usize],
    query_weights: &[f32],
    doc_indices: &[usize],
    doc_values: &[f32],
    stats: &BM25Stats,
    config: &BM25Config,
) -> f32 {
    let mut doc_terms: HashMap<usize, f32> = HashMap::with_capacity(doc_indices.len());
    for (&idx, &val) in doc_indices.iter().zip(doc_values) {
        *doc_terms.entry(idx).or_insert(0.0) += val;
    }
    let doc_length: f32 = doc_values.iter().sum();

    // The length-normalization factor is identical for every query term; compute it once.
    let length_term = if stats.avg_doc_length > 0.0 {
        config.b * doc_length / stats.avg_doc_length
    } else {
        // Degenerate stats: use a length ratio of 1 rather than 0/0 or x/0.
        config.b
    };
    let length_norm = config.k1 * (1.0 - config.b + length_term);

    let mut score = 0.0;
    for (&term_idx, &query_weight) in query_indices.iter().zip(query_weights) {
        if let Some(&term_freq) = doc_terms.get(&term_idx) {
            let idf = stats.get_idf(term_idx);
            let numerator = term_freq * (config.k1 + 1.0);
            let denominator = term_freq + length_norm;
            score += idf * query_weight * (numerator / denominator);
        }
    }
    score
}
/// Scores a document against an unweighted query using BM25 term-frequency saturation only.
///
/// No corpus statistics are available here, so every term's IDF is taken as 1 and the
/// document serves as its own average length. BM25's length normalization
/// `1 - b + b * len / avg_len` is therefore exactly 1 and `config.b` has no effect: each
/// query term found in the document contributes `tf * (k1 + 1) / (tf + k1)`. The factor is
/// applied as the constant 1 rather than computed as `len / len`, which is `NaN` for a
/// zero-length document.
///
/// `doc_indices`/`doc_values` are parallel slices (extra entries in the longer one are
/// ignored); repeated document indices have their values summed. A query index listed more
/// than once is scored once per occurrence.
pub fn bm25_score_simple(
    query_indices: &[usize],
    doc_indices: &[usize],
    doc_values: &[f32],
    config: &BM25Config,
) -> f32 {
    let mut doc_terms: HashMap<usize, f32> = HashMap::with_capacity(doc_indices.len());
    for (&idx, &val) in doc_indices.iter().zip(doc_values) {
        *doc_terms.entry(idx).or_insert(0.0) += val;
    }

    let mut score = 0.0;
    for term_idx in query_indices {
        if let Some(&term_freq) = doc_terms.get(term_idx) {
            score += term_freq * (config.k1 + 1.0) / (term_freq + config.k1);
        }
    }
    score
}
/// Scores a multi-field document with a simplified BM25F.
///
/// For every field in `doc_fields`, each term frequency and the field length (sum of its
/// values) are multiplied by the field's boost from `field_weights` (1.0 when absent). The
/// boosted frequencies and lengths are summed across fields, and the result is scored with
/// the single-field BM25 formula:
/// `idf * query_weight * tf * (k1 + 1) / (tf + k1 * (1 - b + b * len / avg_len))`.
/// Length normalization is applied once to the combined document, not per field.
///
/// NOTE(review): the boosted combined length is compared against `stats.avg_doc_length`,
/// which presumably was computed from unboosted documents — confirm this is intended.
///
/// If `stats.avg_doc_length` is not positive, the document is treated as average length
/// instead of dividing by zero, which would otherwise make the score `NaN`.
pub fn bm25f_score(
    query_indices: &[usize],
    query_weights: &[f32],
    doc_fields: &HashMap<String, (Vec<usize>, Vec<f32>)>,
    field_weights: &FieldWeights,
    stats: &BM25Stats,
    config: &BM25Config,
) -> f32 {
    let mut combined_tf: HashMap<usize, f32> = HashMap::new();
    let mut total_doc_length = 0.0_f32;
    for (field_name, (indices, values)) in doc_fields {
        let boost = field_weights.get(field_name).copied().unwrap_or(1.0);
        let field_length: f32 = values.iter().sum();
        total_doc_length += field_length * boost;
        for (&term_idx, &freq) in indices.iter().zip(values) {
            *combined_tf.entry(term_idx).or_insert(0.0) += freq * boost;
        }
    }

    // The length-normalization factor is identical for every query term; compute it once.
    let length_term = if stats.avg_doc_length > 0.0 {
        config.b * total_doc_length / stats.avg_doc_length
    } else {
        // Degenerate stats: use a length ratio of 1 rather than 0/0 or x/0.
        config.b
    };
    let length_norm = config.k1 * (1.0 - config.b + length_term);

    let mut score = 0.0;
    for (&term_idx, &query_weight) in query_indices.iter().zip(query_weights) {
        if let Some(&term_freq) = combined_tf.get(&term_idx) {
            let idf = stats.get_idf(term_idx);
            let numerator = term_freq * (config.k1 + 1.0);
            let denominator = term_freq + length_norm;
            score += idf * query_weight * (numerator / denominator);
        }
    }
    score
}
/// Parses a `field^boost` spec (e.g. `"title^3"`) into `(field, boost)`.
///
/// The spec is split at the first `^`; surrounding whitespace is trimmed from both parts.
/// A spec without `^` yields a boost of `1.0`. The boost also falls back to `1.0` when it
/// does not parse as a number, or parses to a non-finite (`NaN`, `inf`) or negative value.
/// Such boosts would otherwise propagate into every BM25F score.
pub fn parse_field_weight(field_spec: &str) -> (&str, f32) {
    match field_spec.split_once('^') {
        Some((field, weight_str)) => {
            let weight = weight_str
                .trim()
                .parse::<f32>()
                .ok()
                .filter(|w| w.is_finite() && *w >= 0.0)
                .unwrap_or(1.0);
            (field.trim(), weight)
        }
        None => (field_spec.trim(), 1.0),
    }
}
pub fn parse_field_weights(field_specs: &[&str]) -> FieldWeights {
field_specs
.iter()
.map(|spec| {
let (field, weight) = parse_field_weight(spec);
(field.to_string(), weight)
})
.collect()
}
#[cfg(test)]
mod tests {
    // Unit tests for BM25 / BM25F scoring and field-weight parsing.
    use super::*;

    #[test]
    fn test_bm25_config_default() {
        let config = BM25Config::default();
        assert_eq!(config.k1, 1.2);
        assert_eq!(config.b, 0.75);
    }

    #[test]
    fn test_bm25_stats_from_corpus() {
        let corpus = vec![
            (vec![1, 2, 3], vec![1.0, 1.0, 1.0]),
            (vec![1, 2], vec![1.0, 1.0]),
            (vec![1, 4], vec![1.0, 1.0]),
        ];
        let docs: Vec<(&[usize], &[f32])> = corpus
            .iter()
            .map(|(indices, values)| (indices.as_slice(), values.as_slice()))
            .collect();
        let stats = BM25Stats::from_corpus(docs.into_iter());
        assert_eq!(stats.num_docs, 3);
        assert_eq!(stats.avg_doc_length, (3.0 + 2.0 + 2.0) / 3.0);
        // Term 1 occurs in every document and term 2 in only two, so term 2 is rarer.
        let idf_1 = stats.get_idf(1);
        assert!(idf_1 > 0.0);
        let idf_2 = stats.get_idf(2);
        assert!(idf_2 > idf_1);
        // Terms never seen in the corpus have zero IDF.
        let idf_5 = stats.get_idf(5);
        assert_eq!(idf_5, 0.0);
    }

    #[test]
    fn test_bm25_score_exact_match() {
        let mut idf = HashMap::new();
        idf.insert(1, 1.0);
        idf.insert(2, 1.0);
        let stats = BM25Stats {
            avg_doc_length: 2.0,
            idf,
            num_docs: 100,
        };
        let query_indices = vec![1, 2];
        let query_weights = vec![1.0, 1.0];
        let doc_indices = vec![1, 2];
        let doc_values = vec![1.0, 1.0];
        let score = bm25_score(
            &query_indices,
            &query_weights,
            &doc_indices,
            &doc_values,
            &stats,
            &BM25Config::default(),
        );
        assert!(score > 0.0);
        // Document length equals the average, so each term contributes
        // idf * tf * (k1 + 1) / (tf + k1) = 1.0 with tf = 1.
        assert!((score - 2.0).abs() < 1e-5);
    }

    #[test]
    fn test_bm25_score_no_match() {
        let mut idf = HashMap::new();
        idf.insert(1, 1.0);
        idf.insert(2, 1.0);
        idf.insert(3, 1.0);
        idf.insert(4, 1.0);
        let stats = BM25Stats {
            avg_doc_length: 2.0,
            idf,
            num_docs: 100,
        };
        let query_indices = vec![1, 2];
        let query_weights = vec![1.0, 1.0];
        let doc_indices = vec![3, 4];
        let doc_values = vec![1.0, 1.0];
        let score = bm25_score(
            &query_indices,
            &query_weights,
            &doc_indices,
            &doc_values,
            &stats,
            &BM25Config::default(),
        );
        assert_eq!(score, 0.0);
    }

    #[test]
    fn test_bm25_score_partial_match() {
        let mut idf = HashMap::new();
        idf.insert(1, 2.0);
        idf.insert(2, 2.0);
        idf.insert(3, 2.0);
        let stats = BM25Stats {
            avg_doc_length: 2.0,
            idf,
            num_docs: 100,
        };
        let query_indices = vec![1, 2];
        let query_weights = vec![1.0, 1.0];
        let doc_indices = vec![1, 3];
        let doc_values = vec![1.0, 1.0];
        let score = bm25_score(
            &query_indices,
            &query_weights,
            &doc_indices,
            &doc_values,
            &stats,
            &BM25Config::default(),
        );
        assert!(score > 0.0);
    }

    #[test]
    fn test_bm25_score_frequency_matters() {
        let mut idf = HashMap::new();
        idf.insert(1, 2.0);
        let stats = BM25Stats {
            avg_doc_length: 5.0,
            idf,
            num_docs: 100,
        };
        let query_indices = vec![1];
        let query_weights = vec![1.0];

        let doc1_indices = vec![1];
        let doc1_values = vec![1.0];
        let score1 = bm25_score(
            &query_indices,
            &query_weights,
            &doc1_indices,
            &doc1_values,
            &stats,
            &BM25Config::default(),
        );

        let doc2_indices = vec![1];
        let doc2_values = vec![5.0];
        let score2 = bm25_score(
            &query_indices,
            &query_weights,
            &doc2_indices,
            &doc2_values,
            &stats,
            &BM25Config::default(),
        );

        assert!(score2 > score1);
    }

    #[test]
    fn test_bm25_score_simple() {
        let query_indices = vec![1, 2];
        let doc_indices = vec![1, 2, 3];
        let doc_values = vec![2.0, 1.0, 1.0];
        let score = bm25_score_simple(
            &query_indices,
            &doc_indices,
            &doc_values,
            &BM25Config::default(),
        );
        assert!(score > 0.0);
        // Term 1 (tf 2): 2 * 2.2 / 3.2 = 1.375; term 2 (tf 1): 2.2 / 2.2 = 1.0.
        assert!((score - 2.375).abs() < 1e-5);
    }

    #[test]
    fn test_bm25_k1_parameter() {
        let mut idf = HashMap::new();
        idf.insert(1, 1.0);
        let stats = BM25Stats {
            avg_doc_length: 10.0,
            idf,
            num_docs: 100,
        };
        let query_indices = vec![1];
        let query_weights = vec![1.0];
        let doc_indices = vec![1];
        let doc_values = vec![10.0];

        let config_low = BM25Config { k1: 0.5, b: 0.75 };
        let score_low = bm25_score(
            &query_indices,
            &query_weights,
            &doc_indices,
            &doc_values,
            &stats,
            &config_low,
        );

        let config_high = BM25Config { k1: 3.0, b: 0.75 };
        let score_high = bm25_score(
            &query_indices,
            &query_weights,
            &doc_indices,
            &doc_values,
            &stats,
            &config_high,
        );

        assert!(score_high > score_low);
    }

    #[test]
    fn test_parse_field_weight_with_boost() {
        let (field, weight) = parse_field_weight("title^3");
        assert_eq!(field, "title");
        assert_eq!(weight, 3.0);
    }

    #[test]
    fn test_parse_field_weight_with_float_boost() {
        let (field, weight) = parse_field_weight("abstract^2.5");
        assert_eq!(field, "abstract");
        assert_eq!(weight, 2.5);
    }

    #[test]
    fn test_parse_field_weight_without_boost() {
        let (field, weight) = parse_field_weight("content");
        assert_eq!(field, "content");
        assert_eq!(weight, 1.0);
    }

    #[test]
    fn test_parse_field_weight_invalid_boost() {
        let (field, weight) = parse_field_weight("title^invalid");
        assert_eq!(field, "title");
        // An unparsable boost falls back to 1.0.
        assert_eq!(weight, 1.0);
    }

    #[test]
    fn test_parse_field_weights_multiple() {
        let specs = vec!["title^3", "abstract^2", "content"];
        let weights = parse_field_weights(&specs);
        assert_eq!(weights.len(), 3);
        assert_eq!(weights.get("title"), Some(&3.0));
        assert_eq!(weights.get("abstract"), Some(&2.0));
        assert_eq!(weights.get("content"), Some(&1.0));
    }

    #[test]
    fn test_parse_field_weights_empty() {
        let specs: Vec<&str> = vec![];
        let weights = parse_field_weights(&specs);
        assert_eq!(weights.len(), 0);
    }

    #[test]
    fn test_bm25f_single_field_matches_regular_bm25() {
        let mut idf = HashMap::new();
        idf.insert(1, 2.0);
        idf.insert(2, 1.5);
        let stats = BM25Stats {
            avg_doc_length: 10.0,
            idf,
            num_docs: 100,
        };
        let query_indices = vec![1, 2];
        let query_weights = vec![1.0, 1.0];
        let doc_indices = vec![1, 2, 3];
        let doc_values = vec![2.0, 1.0, 1.0];

        let regular_score = bm25_score(
            &query_indices,
            &query_weights,
            &doc_indices,
            &doc_values,
            &stats,
            &BM25Config::default(),
        );

        // A single field with boost 1.0 must reduce to plain BM25.
        let mut doc_fields = HashMap::new();
        doc_fields.insert(
            "content".to_string(),
            (doc_indices.clone(), doc_values.clone()),
        );
        let mut field_weights = HashMap::new();
        field_weights.insert("content".to_string(), 1.0);
        let bm25f_score_result = bm25f_score(
            &query_indices,
            &query_weights,
            &doc_fields,
            &field_weights,
            &stats,
            &BM25Config::default(),
        );

        assert!((regular_score - bm25f_score_result).abs() < 0.01);
    }

    #[test]
    fn test_bm25f_multiple_fields() {
        let mut idf = HashMap::new();
        idf.insert(1, 2.0);
        idf.insert(2, 1.5);
        idf.insert(3, 1.0);
        let stats = BM25Stats {
            avg_doc_length: 10.0,
            idf,
            num_docs: 100,
        };
        let query_indices = vec![1, 2];
        let query_weights = vec![1.0, 1.0];

        let mut doc_fields = HashMap::new();
        doc_fields.insert("title".to_string(), (vec![1, 2], vec![1.0, 1.0]));
        doc_fields.insert("abstract".to_string(), (vec![1, 3], vec![1.0, 1.0]));
        doc_fields.insert("content".to_string(), (vec![2, 3], vec![1.0, 1.0]));

        let mut field_weights = HashMap::new();
        field_weights.insert("title".to_string(), 1.0);
        field_weights.insert("abstract".to_string(), 1.0);
        field_weights.insert("content".to_string(), 1.0);

        let score = bm25f_score(
            &query_indices,
            &query_weights,
            &doc_fields,
            &field_weights,
            &stats,
            &BM25Config::default(),
        );
        assert!(score > 0.0);
    }

    #[test]
    fn test_bm25f_title_boost() {
        let mut idf = HashMap::new();
        idf.insert(1, 2.0);
        let stats = BM25Stats {
            avg_doc_length: 10.0,
            idf,
            num_docs: 100,
        };
        let query_indices = vec![1];
        let query_weights = vec![1.0];

        let mut doc_fields = HashMap::new();
        doc_fields.insert("title".to_string(), (vec![1], vec![1.0]));
        doc_fields.insert("content".to_string(), (vec![1], vec![1.0]));

        let mut field_weights_no_boost = HashMap::new();
        field_weights_no_boost.insert("title".to_string(), 1.0);
        field_weights_no_boost.insert("content".to_string(), 1.0);
        let score_no_boost = bm25f_score(
            &query_indices,
            &query_weights,
            &doc_fields,
            &field_weights_no_boost,
            &stats,
            &BM25Config::default(),
        );

        let mut field_weights_with_boost = HashMap::new();
        field_weights_with_boost.insert("title".to_string(), 3.0);
        field_weights_with_boost.insert("content".to_string(), 1.0);
        let score_with_boost = bm25f_score(
            &query_indices,
            &query_weights,
            &doc_fields,
            &field_weights_with_boost,
            &stats,
            &BM25Config::default(),
        );

        assert!(score_with_boost > score_no_boost);
    }

    #[test]
    fn test_bm25f_missing_field_weight() {
        let mut idf = HashMap::new();
        idf.insert(1, 2.0);
        let stats = BM25Stats {
            avg_doc_length: 10.0,
            idf,
            num_docs: 100,
        };
        let query_indices = vec![1];
        let query_weights = vec![1.0];

        let mut doc_fields = HashMap::new();
        doc_fields.insert("title".to_string(), (vec![1], vec![1.0]));
        doc_fields.insert("content".to_string(), (vec![1], vec![1.0]));

        // "content" has no explicit weight and must still be scored.
        let mut field_weights = HashMap::new();
        field_weights.insert("title".to_string(), 2.0);

        let score = bm25f_score(
            &query_indices,
            &query_weights,
            &doc_fields,
            &field_weights,
            &stats,
            &BM25Config::default(),
        );
        assert!(score > 0.0);
    }

    #[test]
    fn test_bm25f_no_matching_terms() {
        let mut idf = HashMap::new();
        idf.insert(1, 2.0);
        idf.insert(2, 1.5);
        let stats = BM25Stats {
            avg_doc_length: 10.0,
            idf,
            num_docs: 100,
        };
        let query_indices = vec![1, 2];
        let query_weights = vec![1.0, 1.0];

        let mut doc_fields = HashMap::new();
        doc_fields.insert("title".to_string(), (vec![3, 4], vec![1.0, 1.0]));
        let mut field_weights = HashMap::new();
        field_weights.insert("title".to_string(), 1.0);

        let score = bm25f_score(
            &query_indices,
            &query_weights,
            &doc_fields,
            &field_weights,
            &stats,
            &BM25Config::default(),
        );
        assert_eq!(score, 0.0);
    }

    #[test]
    fn test_bm25f_empty_fields() {
        let mut idf = HashMap::new();
        idf.insert(1, 2.0);
        let stats = BM25Stats {
            avg_doc_length: 10.0,
            idf,
            num_docs: 100,
        };
        let query_indices = vec![1];
        let query_weights = vec![1.0];
        let doc_fields = HashMap::new();
        let field_weights = HashMap::new();

        let score = bm25f_score(
            &query_indices,
            &query_weights,
            &doc_fields,
            &field_weights,
            &stats,
            &BM25Config::default(),
        );
        assert_eq!(score, 0.0);
    }

    #[test]
    fn test_bm25f_realistic_document() {
        let mut idf = HashMap::new();
        idf.insert(100, 2.5);
        idf.insert(200, 2.0);
        idf.insert(300, 1.8);
        let stats = BM25Stats {
            avg_doc_length: 50.0,
            idf,
            num_docs: 1000,
        };
        let query_indices = vec![100, 200, 300];
        let query_weights = vec![1.0, 1.0, 1.0];

        let mut doc_fields = HashMap::new();
        doc_fields.insert("title".to_string(), (vec![100, 200], vec![1.0, 1.0]));
        doc_fields.insert("abstract".to_string(), (vec![200, 300], vec![1.0, 1.0]));
        doc_fields.insert(
            "content".to_string(),
            (vec![100, 200, 300], vec![2.0, 3.0, 1.0]),
        );

        let field_weights = parse_field_weights(&["title^3", "abstract^2", "content"]);
        let score = bm25f_score(
            &query_indices,
            &query_weights,
            &doc_fields,
            &field_weights,
            &stats,
            &BM25Config::default(),
        );
        assert!(score > 5.0);
    }
}