use libm::Libm;
/// Tunable parameters of the BM25 ranking function.
///
/// Both fields are plain `f32` values, so the struct is cheap to copy and can
/// be compared and debug-printed directly.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Bm25Params {
    /// Term-frequency saturation parameter. It scales the length-normalisation
    /// term in the score denominator and appears as `k1 + 1` in the numerator.
    pub k1: f32,
    /// Length-normalisation weight: the share of the normalisation factor that
    /// is driven by `doc_length / avg_doc_length` (the remaining `1 - b` is
    /// constant).
    pub b: f32,
}
impl Default for Bm25Params {
fn default() -> Self {
Self { k1: 1.2, b: 0.75 }
}
}
/// Computes the inverse document frequency of a term.
///
/// The value is `ln((N - df + 0.5) / (df + 0.5) + 1)`, where `N` is
/// `total_docs` and `df` is `doc_freq`. For consistent inputs (`df <= N`) the
/// logarithm's argument is at least `1`, so the result is never negative.
///
/// # Arguments
///
/// * `total_docs` - number of documents in the collection.
/// * `doc_freq` - number of documents that contain the term.
///
/// # Returns
///
/// `0.0` when the term occurs in no document (or the collection is empty),
/// otherwise a non-negative weight that grows as the term gets rarer.
///
/// Counts are converted to `f32`, so very large values are rounded.
pub fn idf(total_docs: u64, doc_freq: u32) -> f32 {
    // A term cannot occur in more documents than exist. Inconsistent or stale
    // statistics (`doc_freq > total_docs`) would otherwise push the logarithm's
    // argument below 1 and yield a negative weight.
    let doc_freq = u64::from(doc_freq).min(total_docs);
    if doc_freq == 0 {
        return 0.0;
    }

    let n = total_docs as f32;
    let df = doc_freq as f32;
    let numerator = n - df + 0.5;
    let denominator = df + 0.5;
    Libm::<f32>::log(numerator / denominator + 1.0)
}
/// Computes the BM25 score contribution of a single term for one document.
///
/// The score is
/// `idf * tf * (k1 + 1) / (tf + k1 * (1 - b + b * doc_length / avg_doc_length))`.
///
/// # Arguments
///
/// * `term_freq` - occurrences of the term in the document.
/// * `doc_length` - length of the document.
/// * `avg_doc_length` - average document length in the collection; values
///   that are not greater than zero are replaced by `1.0`.
/// * `total_docs` - number of documents in the collection (passed to [`idf`]).
/// * `doc_freq` - number of documents containing the term (passed to [`idf`]).
/// * `params` - the `k1` and `b` tuning parameters.
///
/// # Returns
///
/// The term's score; `0.0` when the term does not occur in the document.
pub fn bm25_score(
    term_freq: u32,
    doc_length: u32,
    avg_doc_length: f32,
    total_docs: u64,
    doc_freq: u32,
    params: &Bm25Params,
) -> f32 {
    // An absent term contributes nothing. Returning early also avoids
    // `0.0 / 0.0 = NaN` when the denominator vanishes, e.g. `b = 1` with an
    // empty document, or `k1 = 0`.
    if term_freq == 0 {
        return 0.0;
    }

    let tf = term_freq as f32;
    let dl = doc_length as f32;
    // Guard the division below against a zero, negative or NaN average length.
    let avgdl = if avg_doc_length > 0.0 {
        avg_doc_length
    } else {
        1.0
    };

    let idf_score = idf(total_docs, doc_freq);
    let length_norm = 1.0 - params.b + params.b * (dl / avgdl);
    let denominator = tf + params.k1 * length_norm;
    let numerator = tf * (params.k1 + 1.0);
    idf_score * (numerator / denominator)
}
/// Combines per-term BM25 scores into a single document score.
///
/// The result is the plain sum of every value yielded by `term_scores`.
pub fn multi_term_bm25_score(term_scores: impl Iterator<Item = f32>) -> f32 {
    Iterator::sum::<f32>(term_scores)
}
/// Scales `score` relative to `max_score` and clamps the result to `[0, 1]`.
///
/// # Returns
///
/// * `0.0` when `max_score` is zero, negative or NaN.
/// * `0.0` when the ratio itself is NaN (a NaN `score`, or infinity divided
///   by infinity).
/// * Otherwise `score / max_score`, clamped to the range `0.0..=1.0`.
pub fn normalize_score(score: f32, max_score: f32) -> f32 {
    if max_score.is_nan() || max_score <= 0.0 {
        return 0.0;
    }

    let ratio = score / max_score;
    // `f32::clamp` passes NaN through unchanged, which would break the
    // `[0, 1]` guarantee and any ordering built on the result.
    if ratio.is_nan() {
        return 0.0;
    }
    ratio.clamp(0.0, 1.0)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A rare term must get a larger, strictly positive IDF than a common one.
    #[test]
    fn test_idf_basic() {
        let rare_idf = idf(1000, 10);
        let common_idf = idf(1000, 500);
        assert!(rare_idf > common_idf);
        assert!(rare_idf > 0.0);
    }

    /// A term found in a single document outweighs one found in every document.
    #[test]
    fn test_idf_edge_cases() {
        let all_docs = idf(100, 100);
        let few_docs = idf(100, 1);
        assert!(few_docs > all_docs);
    }

    /// With everything else equal, a higher term frequency scores higher.
    #[test]
    fn test_bm25_score() {
        let params = Bm25Params::default();
        let high_tf = bm25_score(10, 100, 100.0, 1000, 50, &params);
        let low_tf = bm25_score(1, 100, 100.0, 1000, 50, &params);
        assert!(high_tf > low_tf);
    }

    /// With everything else equal, a shorter document scores higher than a
    /// longer one.
    #[test]
    fn test_bm25_length_normalization() {
        let params = Bm25Params::default();
        let short_doc = bm25_score(5, 50, 100.0, 1000, 50, &params);
        let long_doc = bm25_score(5, 200, 100.0, 1000, 50, &params);
        assert!(short_doc > long_doc);
    }
}