use crate::bm25::config::Bm25Config;
/// BM25 relevance scorer.
///
/// Holds the two tuning parameters that the scoring methods read; the type is
/// a pair of floats and is therefore cheap to copy.
#[derive(Clone, Copy, Debug)]
pub struct Scorer {
    /// Scales the term-frequency component of each per-term score.
    k1: f32,
    /// Weight given to the document-length ratio in the score denominator.
    b: f32,
}
impl Scorer {
    /// Builds a scorer from the `k1` and `b` values stored in `config`.
    pub fn from_config(config: &Bm25Config) -> Self {
        Self::new(config.k1, config.b)
    }

    /// Builds a scorer from explicit parameters.
    ///
    /// `k1` scales the term-frequency component and `b` weights the
    /// document-length ratio (`b = 0.0` removes the length dependence).
    /// Both values are stored as given; no validation is performed.
    pub fn new(k1: f32, b: f32) -> Self {
        Self { k1, b }
    }

    /// Scores one document against a query.
    ///
    /// * `query_terms` - query terms; repeated terms are scored only once.
    /// * `term_frequencies` - occurrences of each term in the document. Terms
    ///   that are missing or have a count of `0` contribute nothing.
    /// * `doc_length` - length of the document being scored.
    /// * `avg_doc_length` - average document length across the corpus.
    /// * `doc_frequencies` - number of documents containing each term. A term
    ///   missing from this map is treated as having a document frequency of `0`.
    /// * `corpus_size` - number of documents in the corpus.
    ///
    /// Returns the sum over distinct matching terms of
    /// `idf * tf * (k1 + 1) / (tf + k1 * (1 - b + b * doc_length / avg_doc_length))`.
    ///
    /// Returns `0.0` when `avg_doc_length` is NaN or smaller than
    /// `f32::EPSILON` (which includes zero and negative values), because the
    /// length ratio is meaningless in those cases.
    pub fn score(
        &self,
        query_terms: &[String],
        term_frequencies: &std::collections::HashMap<String, u32>,
        doc_length: u32,
        avg_doc_length: f32,
        doc_frequencies: &std::collections::HashMap<String, u32>,
        corpus_size: u32,
    ) -> f32 {
        // NaN compares false against everything, so `<` alone would let it
        // through and turn the whole score into NaN.
        if avg_doc_length.is_nan() || avg_doc_length < f32::EPSILON {
            return 0.0;
        }

        // These depend only on the document and the corpus, not on the term,
        // so they are computed once instead of once per query term.
        let length_ratio = doc_length as f32 / avg_doc_length;
        let length_factor = self.k1 * (1.0 - self.b + self.b * length_ratio);
        let corpus_size = corpus_size as f32;

        let mut seen = std::collections::HashSet::with_capacity(query_terms.len());
        let mut total = 0.0_f32;
        for term in query_terms {
            // `insert` returns false for a term that was already scored.
            if !seen.insert(term) {
                continue;
            }
            let tf = match term_frequencies.get(term).copied() {
                Some(count) if count > 0 => count as f32,
                _ => continue,
            };
            let df = doc_frequencies.get(term).copied().unwrap_or(0) as f32;
            let saturation = tf * (self.k1 + 1.0) / (tf + length_factor);
            total += self.idf(df, corpus_size) * saturation;
        }
        total
    }

    /// Inverse document frequency of a term:
    /// `ln((corpus_size - doc_frequency + 0.5) / (doc_frequency + 0.5) + 1)`,
    /// clamped so the result is never negative.
    ///
    /// Returns `0.0` when `corpus_size` is smaller than `f32::EPSILON`.
    pub fn idf(&self, doc_frequency: f32, corpus_size: f32) -> f32 {
        if corpus_size < f32::EPSILON {
            return 0.0;
        }
        let ratio = (corpus_size - doc_frequency + 0.5) / (doc_frequency + 0.5);
        // The logarithm goes negative when `doc_frequency` exceeds
        // `corpus_size`; `max` clamps that to zero. `f32::max` also returns
        // the non-NaN operand, so a NaN result collapses to 0.0 as well.
        (ratio + 1.0).ln().max(0.0)
    }
}
impl Default for Scorer {
fn default() -> Self {
Self::from_config(&Bm25Config::default())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Scorer under test, built from the default configuration.
    fn scorer() -> Scorer {
        Scorer::default()
    }

    /// Turns borrowed `(term, count)` pairs into the owned map that
    /// `Scorer::score` takes.
    fn counts(pairs: &[(&str, u32)]) -> HashMap<String, u32> {
        let mut map = HashMap::with_capacity(pairs.len());
        for &(term, count) in pairs {
            map.insert(term.to_string(), count);
        }
        map
    }

    /// Per-document term frequencies.
    fn tf(pairs: &[(&str, u32)]) -> HashMap<String, u32> {
        counts(pairs)
    }

    /// Corpus-wide document frequencies.
    fn df(pairs: &[(&str, u32)]) -> HashMap<String, u32> {
        counts(pairs)
    }

    /// Owned query terms from string literals.
    fn query(terms: &[&str]) -> Vec<String> {
        terms.iter().map(|term| term.to_string()).collect()
    }

    /// Scores the one-term query `rust` with `scorer` against a document that
    /// contains the term `occurrences` times, in a 100-document corpus where
    /// 5 documents contain the term and the average length is 10.
    fn rust_score(scorer: Scorer, occurrences: u32, doc_length: u32) -> f32 {
        scorer.score(
            &query(&["rust"]),
            &tf(&[("rust", occurrences)]),
            doc_length,
            10.0,
            &df(&[("rust", 5)]),
            100,
        )
    }

    // ---- idf ----

    #[test]
    fn idf_zero_corpus_is_zero() {
        let value = scorer().idf(0.0, 0.0);
        assert_eq!(value, 0.0);
    }

    #[test]
    fn idf_term_in_all_docs_approaches_zero() {
        let value = scorer().idf(100.0, 100.0);
        assert!(value >= 0.0);
        assert!(value < 0.01);
    }

    #[test]
    fn idf_term_in_one_doc_of_large_corpus_is_high() {
        let value = scorer().idf(1.0, 1000.0);
        assert!(value > 5.0);
    }

    #[test]
    fn idf_is_non_negative() {
        let frequencies = [0.0, 0.5, 1.0, 10.0, 100.0];
        for doc_frequency in frequencies {
            let value = scorer().idf(doc_frequency, 100.0);
            assert!(value >= 0.0);
        }
    }

    #[test]
    fn idf_increases_as_doc_frequency_decreases() {
        let s = scorer();
        let (rare, common) = (s.idf(1.0, 100.0), s.idf(50.0, 100.0));
        assert!(rare > common);
    }

    // ---- score: degenerate inputs ----

    #[test]
    fn score_zero_when_avg_doc_length_is_zero() {
        let value = scorer().score(
            &query(&["rust"]),
            &tf(&[("rust", 3)]),
            10,
            0.0,
            &df(&[("rust", 5)]),
            100,
        );
        assert_eq!(value, 0.0);
    }

    #[test]
    fn score_zero_when_term_not_in_document() {
        let value = scorer().score(
            &query(&["absent"]),
            &tf(&[]),
            10,
            10.0,
            &df(&[("absent", 2)]),
            100,
        );
        assert_eq!(value, 0.0);
    }

    #[test]
    fn score_zero_when_query_empty() {
        let value = scorer().score(
            &query(&[]),
            &tf(&[("rust", 3)]),
            10,
            10.0,
            &df(&[("rust", 5)]),
            100,
        );
        assert_eq!(value, 0.0);
    }

    // ---- score: single-term behaviour ----

    #[test]
    fn score_is_positive_for_matching_term() {
        assert!(rust_score(scorer(), 3, 10) > 0.0);
    }

    #[test]
    fn score_increases_with_term_frequency() {
        let low = rust_score(scorer(), 1, 10);
        let high = rust_score(scorer(), 10, 10);
        assert!(high > low);
    }

    #[test]
    fn score_saturates_at_high_term_frequency() {
        let s10 = rust_score(scorer(), 10, 10);
        let s1000 = rust_score(scorer(), 1000, 10);
        assert!(s1000 / s10 < 5.0);
    }

    #[test]
    fn score_decreases_for_longer_documents() {
        let short = rust_score(scorer(), 2, 5);
        let long = rust_score(scorer(), 2, 50);
        assert!(short > long);
    }

    // ---- score: multi-term behaviour ----

    #[test]
    fn score_accumulates_across_multiple_terms() {
        let doc = tf(&[("rust", 2), ("programming", 1)]);
        let corpus = df(&[("rust", 5), ("programming", 3)]);
        let score_for =
            |terms: &[&str]| scorer().score(&query(terms), &doc, 10, 10.0, &corpus, 100);

        let single = score_for(&["rust"]);
        let multi = score_for(&["rust", "programming"]);
        assert!(multi > single);
    }

    #[test]
    fn duplicate_query_terms_are_counted_once() {
        let single = rust_score(scorer(), 2, 10);
        let doubled = scorer().score(
            &query(&["rust", "rust"]),
            &tf(&[("rust", 2)]),
            10,
            10.0,
            &df(&[("rust", 5)]),
            100,
        );
        assert!((single - doubled).abs() < 1e-5);
    }

    // ---- parameters ----

    #[test]
    fn b_zero_disables_length_normalisation() {
        let scorer_no_norm = Scorer::new(1.2, 0.0);
        let short = rust_score(scorer_no_norm, 2, 5);
        let long = rust_score(scorer_no_norm, 2, 100);
        assert!((short - long).abs() < 1e-5);
    }
}