use std::collections::HashMap;
/// Tuning constants consumed by the BM25 scoring function in this module.
///
/// Both fields are plain public `f64`s; nothing in this type validates them.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Bm25Params {
    /// Term-frequency saturation constant: appears as `k1 + 1` in the
    /// numerator and scales the length-dependent part of the denominator.
    pub k1: f64,
    /// Weight given to the `doc_len / avg_doc_len` ratio when normalising
    /// for document length (`0.0` disables length normalisation).
    pub b: f64,
}
impl Default for Bm25Params {
fn default() -> Self {
Self { k1: 1.5, b: 0.75 }
}
}
/// Computes the BM25 score of a single document against a query.
///
/// For every entry of `query_terms` that occurs in the document, the
/// contribution is
///
/// ```text
/// idf  = ln(1 + (N - n_t + 0.5) / (n_t + 0.5))
/// part = idf * tf * (k1 + 1) / (tf + k1 * (1 - b + b * doc_len / avg_doc_len))
/// ```
///
/// and the result is the sum of those contributions in query order.
///
/// # Parameters
/// - `query_terms`: query tokens; a token listed twice contributes twice.
/// - `term_freq`: occurrences of each term in the document being scored.
///   Terms that are absent or mapped to `0` contribute nothing.
/// - `doc_len`: length of the document being scored.
/// - `avg_doc_len`: average document length; when it is not `> 0.0` the
///   `b * doc_len / avg_doc_len` term is replaced by `0.0`.
/// - `n_docs_with`: number of documents containing each term (`n_t`); a term
///   missing from this map is treated as `n_t = 0`.
/// - `total_docs`: corpus size `N`.
/// - `params`: the `k1` / `b` constants.
///
/// # Returns
/// The summed score, or `0.0` when `query_terms` is empty or `total_docs`
/// is zero.
pub fn score(
    query_terms: &[String],
    term_freq: &HashMap<String, u32>,
    doc_len: u32,
    avg_doc_len: f64,
    n_docs_with: &HashMap<String, u32>,
    total_docs: u32,
    params: &Bm25Params,
) -> f64 {
    if total_docs == 0 || query_terms.is_empty() {
        return 0.0;
    }

    let (k1, b) = (params.k1, params.b);
    let corpus_size = f64::from(total_docs);

    // Length-dependent part of the denominator; identical for every term,
    // so it is computed once up front.
    let relative_len = if avg_doc_len > 0.0 {
        b * (f64::from(doc_len) / avg_doc_len)
    } else {
        0.0
    };
    let saturation = k1 * (1.0 - b + relative_len);

    query_terms
        .iter()
        .filter_map(|term| {
            // Terms the document does not contain are skipped entirely.
            let occurrences = f64::from(term_freq.get(term).copied().filter(|&c| c > 0)?);
            let doc_freq = n_docs_with.get(term).map_or(0.0, |&d| f64::from(d));
            let idf = (1.0 + (corpus_size - doc_freq + 0.5) / (doc_freq + 0.5)).ln();
            Some(idf * (occurrences * (k1 + 1.0) / (occurrences + saturation)))
        })
        // Explicit `0.0` seed with left-to-right accumulation, so a query
        // with no matching terms yields exactly `0.0`.
        .fold(0.0, |running, part| running + part)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parameters shared by every test: whatever `Bm25Params::default()` yields.
    fn defaults() -> Bm25Params {
        Bm25Params::default()
    }

    /// Builds an owned `term -> count` map from borrowed pairs. Used both for
    /// per-document term frequencies and for corpus document frequencies.
    fn counts(pairs: &[(&str, u32)]) -> HashMap<String, u32> {
        pairs
            .iter()
            .map(|&(term, count)| (term.to_owned(), count))
            .collect()
    }

    /// Converts string literals into the owned term list that `score` takes.
    fn query(terms: &[&str]) -> Vec<String> {
        terms.iter().map(|&term| term.to_owned()).collect()
    }

    #[test]
    fn empty_query_or_corpus_returns_zero() {
        // Nothing to search for and nothing to search in.
        let no_counts = counts(&[]);
        let empty_everything = score(&[], &no_counts, 0, 0.0, &no_counts, 0, &defaults());
        assert_eq!(empty_everything, 0.0);

        // The document matches, but the corpus reports zero documents.
        let empty_corpus = score(
            &query(&["rust"]),
            &counts(&[("rust", 3)]),
            10,
            10.0,
            &counts(&[("rust", 1)]),
            0,
            &defaults(),
        );
        assert_eq!(empty_corpus, 0.0);
    }

    #[test]
    fn zero_term_freq_yields_zero_score() {
        // The query term exists in the corpus but not in this document.
        let terms = query(&["rust"]);
        let doc = counts(&[("python", 5)]);
        let doc_freq = counts(&[("rust", 1), ("python", 1)]);
        let s = score(&terms, &doc, 10, 10.0, &doc_freq, 5, &defaults());
        assert_eq!(s, 0.0);
    }

    #[test]
    fn higher_tf_strictly_higher_score_at_fixed_length() {
        let terms = query(&["rust"]);
        let doc_freq = counts(&[("rust", 2)]);
        // Only the in-document count of "rust" varies between the two calls.
        let score_with_tf = |count: u32| {
            score(
                &terms,
                &counts(&[("rust", count)]),
                10,
                10.0,
                &doc_freq,
                100,
                &defaults(),
            )
        };
        let s_low = score_with_tf(1);
        let s_hi = score_with_tf(5);
        assert!(s_hi > s_low, "tf=5 ({}) should beat tf=1 ({})", s_hi, s_low);
    }

    #[test]
    fn longer_doc_scores_lower_at_same_tf() {
        let terms = query(&["rust"]);
        let doc_freq = counts(&[("rust", 2)]);
        // Only the document length varies; the average stays at 50.
        let score_at_len = |len: u32| {
            score(
                &terms,
                &counts(&[("rust", 3)]),
                len,
                50.0,
                &doc_freq,
                100,
                &defaults(),
            )
        };
        let s_short = score_at_len(10);
        let s_long = score_at_len(200);
        assert!(
            s_short > s_long,
            "short ({}) should beat long ({}) at same tf",
            s_short,
            s_long
        );
    }

    #[test]
    fn rare_term_dominates_common_term() {
        // "the" is in every one of the 1000 documents, "quasar" in just one.
        let doc_freq = counts(&[("the", 1000), ("quasar", 1)]);
        let single_term_score = |term: &str| {
            score(
                &query(&[term]),
                &counts(&[(term, 2)]),
                20,
                20.0,
                &doc_freq,
                1000,
                &defaults(),
            )
        };
        let s_common = single_term_score("the");
        let s_rare = single_term_score("quasar");
        assert!(
            s_rare > s_common * 5.0,
            "rare term ({}) should dominate common term ({})",
            s_rare,
            s_common
        );
    }

    #[test]
    fn hand_computed_reference_three_doc_corpus() {
        let terms = query(&["rust"]);
        let doc_freq = counts(&[
            ("rust", 2),
            ("db", 3),
            ("lang", 1),
            ("python", 1),
            ("tool", 1),
        ]);
        // Every document has length 3 and the average is 3.0, so the
        // length-dependent denominator offset collapses to k1 (1.5).
        let score_doc =
            |doc: &[(&str, u32)]| score(&terms, &counts(doc), 3, 3.0, &doc_freq, 3, &defaults());

        let s1 = score_doc(&[("rust", 2), ("db", 1)]);
        let s2 = score_doc(&[("rust", 1), ("db", 1), ("lang", 1)]);
        let s3 = score_doc(&[("python", 1), ("db", 1), ("tool", 1)]);

        // Reference values worked out by hand: N = 3, n_t("rust") = 2.
        let idf = (1.0_f64 + (3.0 - 2.0 + 0.5) / (2.0 + 0.5)).ln();
        let expected_s1 = idf * (2.0 * (1.5 + 1.0)) / (2.0 + 1.5);
        let expected_s2 = idf * (1.0 * (1.5 + 1.0)) / (1.0 + 1.5);
        let tol = f64::EPSILON * 16.0;

        assert!(
            (s1 - expected_s1).abs() < tol,
            "doc1 score {} vs expected {}",
            s1,
            expected_s1
        );
        assert!(
            (s2 - expected_s2).abs() < tol,
            "doc2 score {} vs expected {}",
            s2,
            expected_s2
        );
        assert_eq!(s3, 0.0);
        assert!(s1 > s2, "doc1 (tf=2) should outrank doc2 (tf=1)");
    }

    #[test]
    fn duplicate_query_tokens_compound() {
        let doc = counts(&[("rust", 1)]);
        let doc_freq = counts(&[("rust", 2)]);
        let s1 = score(&query(&["rust"]), &doc, 5, 5.0, &doc_freq, 10, &defaults());
        let s2 = score(
            &query(&["rust", "rust"]),
            &doc,
            5,
            5.0,
            &doc_freq,
            10,
            &defaults(),
        );
        assert!(
            (s2 - 2.0 * s1).abs() < f64::EPSILON * 8.0,
            "duplicated query token should double the score: 2*s1={}, s2={}",
            2.0 * s1,
            s2
        );
    }
}