use crate::util::small_float;
/// Per-field statistics for a whole collection, validated on construction.
///
/// Every instance has passed the checks in [`CollectionStatistics::new`], so
/// holders may rely on `0 < doc_count <= max_doc` and
/// `doc_count <= sum_doc_freq <= sum_total_term_freq`.
#[derive(Debug, Clone)]
pub struct CollectionStatistics {
    /// Name of the field these statistics belong to.
    field: String,
    /// Positive upper bound on `doc_count`.
    max_doc: i64,
    /// Positive document count; never larger than `max_doc`.
    doc_count: i64,
    /// Positive total; never smaller than `sum_doc_freq`.
    sum_total_term_freq: i64,
    /// Positive total; never smaller than `doc_count`.
    sum_doc_freq: i64,
}

impl CollectionStatistics {
    /// Builds the statistics after validating every counter.
    ///
    /// # Panics
    ///
    /// Panics with a descriptive message on the first violated rule, checked
    /// in this order:
    ///
    /// 1. `max_doc > 0`
    /// 2. `doc_count > 0`
    /// 3. `doc_count <= max_doc`
    /// 4. `sum_doc_freq > 0`
    /// 5. `sum_doc_freq >= doc_count`
    /// 6. `sum_total_term_freq > 0`
    /// 7. `sum_total_term_freq >= sum_doc_freq`
    pub fn new(
        field: String,
        max_doc: i64,
        doc_count: i64,
        sum_total_term_freq: i64,
        sum_doc_freq: i64,
    ) -> Self {
        if max_doc <= 0 {
            panic!("maxDoc must be positive, maxDoc: {}", max_doc);
        }
        if doc_count <= 0 {
            panic!("docCount must be positive, docCount: {}", doc_count);
        }
        if doc_count > max_doc {
            panic!(
                "docCount must not exceed maxDoc, docCount: {}, maxDoc: {}",
                doc_count, max_doc
            );
        }
        if sum_doc_freq <= 0 {
            panic!("sumDocFreq must be positive, sumDocFreq: {}", sum_doc_freq);
        }
        if sum_doc_freq < doc_count {
            panic!(
                "sumDocFreq must be at least docCount, sumDocFreq: {}, docCount: {}",
                sum_doc_freq, doc_count
            );
        }
        if sum_total_term_freq <= 0 {
            panic!(
                "sumTotalTermFreq must be positive, sumTotalTermFreq: {}",
                sum_total_term_freq
            );
        }
        if sum_total_term_freq < sum_doc_freq {
            panic!(
                "sumTotalTermFreq must be at least sumDocFreq, sumTotalTermFreq: {}, sumDocFreq: {}",
                sum_total_term_freq, sum_doc_freq
            );
        }

        CollectionStatistics {
            field,
            max_doc,
            doc_count,
            sum_total_term_freq,
            sum_doc_freq,
        }
    }

    /// Name of the field the statistics were built for.
    pub fn field(&self) -> &str {
        self.field.as_str()
    }

    /// The `max_doc` value given to [`CollectionStatistics::new`].
    pub fn max_doc(&self) -> i64 {
        self.max_doc
    }

    /// The `doc_count` value given to [`CollectionStatistics::new`].
    pub fn doc_count(&self) -> i64 {
        self.doc_count
    }

    /// The `sum_total_term_freq` value given to [`CollectionStatistics::new`].
    pub fn sum_total_term_freq(&self) -> i64 {
        self.sum_total_term_freq
    }

    /// The `sum_doc_freq` value given to [`CollectionStatistics::new`].
    pub fn sum_doc_freq(&self) -> i64 {
        self.sum_doc_freq
    }
}
/// Statistics for a single term, validated on construction.
///
/// Every instance has passed the checks in [`TermStatistics::new`], so
/// `0 < doc_freq <= total_term_freq` always holds.
#[derive(Debug, Clone)]
pub struct TermStatistics {
    /// Raw bytes of the term.
    term: Vec<u8>,
    /// Positive; never larger than `total_term_freq`.
    doc_freq: i64,
    /// Positive; never smaller than `doc_freq`.
    total_term_freq: i64,
}

impl TermStatistics {
    /// Builds the statistics after validating both counters.
    ///
    /// # Panics
    ///
    /// Panics on the first violated rule, checked in this order:
    /// `doc_freq > 0`, `total_term_freq > 0`, `total_term_freq >= doc_freq`.
    pub fn new(term: Vec<u8>, doc_freq: i64, total_term_freq: i64) -> Self {
        if doc_freq <= 0 {
            panic!("docFreq must be positive, docFreq: {}", doc_freq);
        }
        if total_term_freq <= 0 {
            panic!(
                "totalTermFreq must be positive, totalTermFreq: {}",
                total_term_freq
            );
        }
        if total_term_freq < doc_freq {
            panic!(
                "totalTermFreq must be at least docFreq, totalTermFreq: {}, docFreq: {}",
                total_term_freq, doc_freq
            );
        }

        TermStatistics {
            term,
            doc_freq,
            total_term_freq,
        }
    }

    /// Bytes of the term, borrowed from this value.
    pub fn term(&self) -> &[u8] {
        self.term.as_slice()
    }

    /// The `doc_freq` value given to [`TermStatistics::new`].
    pub fn doc_freq(&self) -> i64 {
        self.doc_freq
    }

    /// The `total_term_freq` value given to [`TermStatistics::new`].
    pub fn total_term_freq(&self) -> i64 {
        self.total_term_freq
    }
}
/// Scores one (frequency, norm) pair at a time.
///
/// `Similarity::scorer` hands out implementations of this trait as
/// `Box<dyn SimScorer>`.
pub trait SimScorer {
/// Returns the score for frequency `freq` and norm value `norm`.
///
/// NOTE(review): `norm` is presumably a value produced by
/// `Similarity::compute_norm`; the BM25 implementation in this file only
/// looks at its low 8 bits. Confirm the contract with callers.
fn score(&self, freq: f32, norm: i64) -> f32;
/// Returns a new boxed scorer; the implementations in this file copy their
/// own state into it.
fn box_clone(&self) -> Box<dyn SimScorer>;
}
/// Scores several (frequency, norm) pairs in a single call.
pub trait BulkSimScorer {
/// Writes one score per input pair into `scores`, pairing `freqs[i]` with
/// `norms[i]`.
///
/// Both implementations in this file process only the first
/// `min(freqs.len(), norms.len(), scores.len())` entries and leave any
/// remaining elements of `scores` untouched.
fn score(&self, freqs: &[f32], norms: &[i64], scores: &mut [f32]);
}
/// Bulk-scoring wrapper around a borrowed [`SimScorer`].
///
/// The wrapper owns nothing: it only keeps a reference to the scorer it was
/// built from and therefore cannot outlive it.
pub struct DefaultBulkSimScorer<'a> {
    /// The per-pair scorer every bulk request is forwarded to.
    scorer: &'a dyn SimScorer,
}

impl<'a> DefaultBulkSimScorer<'a> {
    /// Creates a wrapper that forwards to `scorer`.
    pub fn new(scorer: &'a dyn SimScorer) -> Self {
        DefaultBulkSimScorer { scorer }
    }
}
impl BulkSimScorer for DefaultBulkSimScorer<'_> {
    /// Scores each `(freqs[i], norms[i])` pair with the wrapped scorer and
    /// stores the result in `scores[i]`.
    ///
    /// Only as many entries as the shortest of the three slices are
    /// processed; any further elements of `scores` keep their old values.
    fn score(&self, freqs: &[f32], norms: &[i64], scores: &mut [f32]) {
        // `zip` ends with the shortest input, which gives the same bound as
        // taking the minimum of the three lengths.
        let pairs = freqs.iter().zip(norms);
        for (slot, (&freq, &norm)) in scores.iter_mut().zip(pairs) {
            *slot = self.scorer.score(freq, norm);
        }
    }
}
/// Factory for scorers plus the norm encoding that goes with them.
pub trait Similarity {
/// Returns the implementation's discount-overlaps setting.
///
/// NOTE(review): nothing in this file reads the flag beyond returning it;
/// presumably the indexing side consults it when counting terms. Confirm
/// with callers.
fn get_discount_overlaps(&self) -> bool;
/// Encodes `num_terms` into the norm value that is later passed to
/// `SimScorer::score`.
fn compute_norm(&self, num_terms: i32) -> i64;
/// Builds a scorer for the given `boost`, collection statistics and term
/// statistics.
///
/// `term_stats` is a slice: the BM25 implementation in this file adds up
/// the contribution of every entry.
fn scorer(
&self,
boost: f32,
collection_stats: &CollectionStatistics,
term_stats: &[TermStatistics],
) -> Box<dyn SimScorer>;
/// Returns a new boxed similarity; the BM25 implementation in this file
/// rebuilds itself from its own parameters.
fn box_clone(&self) -> Box<dyn Similarity>;
}
/// For every possible one-byte norm code, the value
/// `small_float::byte4_to_int` decodes it to, stored as `f32`.
///
/// The table is evaluated at compile time and indexed by the low byte of an
/// encoded norm.
static LENGTH_TABLE: [f32; 256] = {
    let mut lengths = [0.0f32; 256];
    // `for` loops cannot be used in a constant initializer, so walk the
    // indices by hand.
    let mut code = 0usize;
    while code < lengths.len() {
        lengths[code] = small_float::byte4_to_int(code as u8) as f32;
        code += 1;
    }
    lengths
};
/// BM25 similarity parameters.
///
/// Both `k1` and `b` are validated by [`BM25Similarity::new`].
#[derive(Debug)]
pub struct BM25Similarity {
    /// Non-negative, finite factor applied to the whole length-normalisation
    /// term.
    k1: f32,
    /// Weight in `[0, 1]` given to the length ratio inside the normalisation
    /// term; the remaining `1 - b` is length independent.
    b: f32,
    /// Value reported by `get_discount_overlaps`.
    discount_overlaps: bool,
}

impl BM25Similarity {
    /// Creates a similarity from explicit parameters.
    ///
    /// # Panics
    ///
    /// Panics when `k1` is negative, infinite or NaN, or when `b` lies
    /// outside `[0, 1]` (NaN included).
    pub fn new(k1: f32, b: f32, discount_overlaps: bool) -> Self {
        let k1_is_legal = k1.is_finite() && k1 >= 0.0;
        if !k1_is_legal {
            panic!(
                "illegal k1 value: {}, must be a non-negative finite value",
                k1
            );
        }

        // A NaN compares false against both bounds, so the range test alone
        // rejects it as well.
        let b_is_legal = (0.0..=1.0).contains(&b);
        if !b_is_legal {
            panic!("illegal b value: {}, must be between 0 and 1", b);
        }

        BM25Similarity {
            k1,
            b,
            discount_overlaps,
        }
    }

    /// Same as [`BM25Similarity::new`] with `discount_overlaps` set to `true`.
    pub fn new_with_defaults(k1: f32, b: f32) -> Self {
        Self::new(k1, b, true)
    }

    /// Inverse document frequency:
    /// `ln(1 + (doc_count - doc_freq + 0.5) / (doc_freq + 0.5))`.
    ///
    /// The arithmetic runs in `f64`; only the final result is narrowed.
    fn idf(doc_freq: i64, doc_count: i64) -> f32 {
        let df = doc_freq as f64;
        let docs = doc_count as f64;
        let ratio = (docs - df + 0.5) / (df + 0.5);
        (1.0_f64 + ratio).ln() as f32
    }

    /// Average field length: `sum_total_term_freq / doc_count`, divided in
    /// `f64` and then narrowed to `f32`.
    fn avg_field_length(collection_stats: &CollectionStatistics) -> f32 {
        let total_terms = collection_stats.sum_total_term_freq() as f64;
        let docs = collection_stats.doc_count() as f64;
        (total_terms / docs) as f32
    }

    /// The `k1` parameter.
    pub fn get_k1(&self) -> f32 {
        self.k1
    }

    /// The `b` parameter.
    pub fn get_b(&self) -> f32 {
        self.b
    }
}
impl Default for BM25Similarity {
fn default() -> Self {
Self::new(1.2, 0.75, true)
}
}
impl Similarity for BM25Similarity {
fn get_discount_overlaps(&self) -> bool {
self.discount_overlaps
}
fn compute_norm(&self, num_terms: i32) -> i64 {
small_float::int_to_byte4(num_terms) as i64
}
fn scorer(
&self,
boost: f32,
collection_stats: &CollectionStatistics,
term_stats: &[TermStatistics],
) -> Box<dyn SimScorer> {
let idf = if term_stats.len() == 1 {
Self::idf(term_stats[0].doc_freq(), collection_stats.doc_count())
} else {
let mut idf_sum = 0.0_f64;
for ts in term_stats {
idf_sum += Self::idf(ts.doc_freq(), collection_stats.doc_count()) as f64;
}
idf_sum as f32
};
let avgdl = Self::avg_field_length(collection_stats);
let mut cache = [0.0f32; 256];
for (i, entry) in cache.iter_mut().enumerate() {
*entry = 1.0 / (self.k1 * ((1.0 - self.b) + self.b * LENGTH_TABLE[i] / avgdl));
}
let weight = boost * idf;
Box::new(BM25Scorer { weight, cache })
}
fn box_clone(&self) -> Box<dyn Similarity> {
Box::new(BM25Similarity::new(self.k1, self.b, self.discount_overlaps))
}
}
/// Scorer state for BM25: a weight and one precomputed factor per norm code.
struct BM25Scorer {
    /// Upper limit the score approaches as the frequency grows.
    weight: f32,
    /// Inverse length-normalisation factor for each one-byte norm code.
    cache: [f32; 256],
}

impl BM25Scorer {
    /// Computes `weight - weight / (1 + freq * norm_inverse)`.
    ///
    /// The result is zero when `freq * norm_inverse` is zero and approaches
    /// `weight` as that product grows.
    fn do_score(&self, freq: f32, norm_inverse: f32) -> f32 {
        let weight = self.weight;
        let denominator = 1.0 + freq * norm_inverse;
        weight - weight / denominator
    }
}
impl SimScorer for BM25Scorer {
    /// Looks up the cached factor for `encoded_norm` and scores `freq`
    /// against it.
    fn score(&self, freq: f32, encoded_norm: i64) -> f32 {
        // Only the low byte selects the cache entry; higher bits are dropped.
        let slot = usize::from(encoded_norm as u8);
        self.do_score(freq, self.cache[slot])
    }

    /// Returns a boxed scorer holding a copy of the weight and cache.
    fn box_clone(&self) -> Box<dyn SimScorer> {
        let Self { weight, cache } = *self;
        Box::new(Self { weight, cache })
    }
}
/// Bulk scorer holding a BM25 weight and its per-norm cache.
// NOTE(review): the `expect(dead_code)` markers are presumably here because
// nothing visible in this file constructs the type yet; drop them once it
// gains a caller.
#[expect(dead_code)]
pub(crate) struct BM25BulkSimScorer {
    /// Upper limit the score approaches as the frequency grows.
    weight: f32,
    /// Inverse length-normalisation factor for each one-byte norm code.
    cache: [f32; 256],
    /// Starts out empty; `score` takes `&self` and does not read it.
    norm_inverses: Vec<f32>,
}

impl BM25BulkSimScorer {
    /// Creates a bulk scorer from a weight and a 256-entry cache, with an
    /// empty `norm_inverses` buffer.
    #[expect(dead_code)]
    pub(crate) fn new(weight: f32, cache: [f32; 256]) -> Self {
        let norm_inverses = Vec::new();
        BM25BulkSimScorer {
            weight,
            cache,
            norm_inverses,
        }
    }
}
impl BulkSimScorer for BM25BulkSimScorer {
    /// Scores each `(freqs[i], norms[i])` pair as
    /// `weight - weight / (1 + freqs[i] * cache[low byte of norms[i]])` and
    /// stores the result in `scores[i]`.
    ///
    /// Only as many entries as the shortest of the three slices are
    /// processed; any further elements of `scores` keep their old values.
    ///
    /// The method performs no heap allocation: the cached factors are staged
    /// in `scores` itself before being replaced by the final values.
    fn score(&self, freqs: &[f32], norms: &[i64], scores: &mut [f32]) {
        let size = freqs.len().min(norms.len()).min(scores.len());
        // Trim all three slices once so both passes cover exactly `size`
        // entries.
        let (freqs, norms, scores) = (&freqs[..size], &norms[..size], &mut scores[..size]);

        // Pass 1: table lookups. Only the low byte of each norm selects the
        // cache entry. The output slice doubles as scratch space, so no
        // temporary buffer is needed.
        for (slot, &norm) in scores.iter_mut().zip(norms) {
            *slot = self.cache[usize::from(norm as u8)];
        }

        // Pass 2: pure arithmetic over contiguous data, kept separate from
        // the lookups above.
        let weight = self.weight;
        for (slot, &freq) in scores.iter_mut().zip(freqs) {
            *slot = weight - weight / (1.0 + freq * *slot);
        }
    }
}
// Unit tests for the statistics holders, the scorer traits and the BM25
// similarity defined in the parent module.
#[cfg(test)]
mod tests {
use std::slice;
use super::*;
use assertables::*;
// ---- CollectionStatistics: accessors and constructor validation ----
#[test]
fn test_collection_stats_valid() {
let stats = CollectionStatistics::new("body".to_string(), 100, 50, 500, 200);
assert_eq!(stats.field(), "body");
assert_eq!(stats.max_doc(), 100);
assert_eq!(stats.doc_count(), 50);
assert_eq!(stats.sum_total_term_freq(), 500);
assert_eq!(stats.sum_doc_freq(), 200);
}
// Each `should_panic` case below expects the message of the first
// constructor check that its arguments fail.
#[test]
#[should_panic(expected = "maxDoc must be positive")]
fn test_collection_stats_max_doc_zero() {
CollectionStatistics::new("f".to_string(), 0, 1, 1, 1);
}
#[test]
#[should_panic(expected = "maxDoc must be positive")]
fn test_collection_stats_max_doc_negative() {
CollectionStatistics::new("f".to_string(), -1, 1, 1, 1);
}
#[test]
#[should_panic(expected = "docCount must be positive")]
fn test_collection_stats_doc_count_zero() {
CollectionStatistics::new("f".to_string(), 10, 0, 1, 1);
}
#[test]
#[should_panic(expected = "docCount must not exceed maxDoc")]
fn test_collection_stats_doc_count_exceeds_max_doc() {
CollectionStatistics::new("f".to_string(), 10, 11, 20, 15);
}
#[test]
#[should_panic(expected = "sumDocFreq must be positive")]
fn test_collection_stats_sum_doc_freq_zero() {
CollectionStatistics::new("f".to_string(), 10, 5, 10, 0);
}
#[test]
#[should_panic(expected = "sumDocFreq must be at least docCount")]
fn test_collection_stats_sum_doc_freq_less_than_doc_count() {
CollectionStatistics::new("f".to_string(), 10, 5, 10, 3);
}
#[test]
#[should_panic(expected = "sumTotalTermFreq must be positive")]
fn test_collection_stats_sum_total_term_freq_zero() {
CollectionStatistics::new("f".to_string(), 10, 5, 0, 5);
}
#[test]
#[should_panic(expected = "sumTotalTermFreq must be at least sumDocFreq")]
fn test_collection_stats_sum_total_term_freq_less_than_sum_doc_freq() {
CollectionStatistics::new("f".to_string(), 10, 5, 4, 5);
}
// All counters at 1 is the smallest input that passes every check.
#[test]
fn test_collection_stats_minimum_valid() {
let stats = CollectionStatistics::new("f".to_string(), 1, 1, 1, 1);
assert_eq!(stats.max_doc(), 1);
assert_eq!(stats.doc_count(), 1);
assert_eq!(stats.sum_total_term_freq(), 1);
assert_eq!(stats.sum_doc_freq(), 1);
}
// ---- TermStatistics: accessors and constructor validation ----
// Helper: the bytes of `s` as an owned vector.
fn term(s: &str) -> Vec<u8> {
s.as_bytes().to_vec()
}
#[test]
fn test_term_stats_valid() {
let stats = TermStatistics::new(term("hello"), 10, 50);
assert_eq!(stats.term(), &term("hello"));
assert_eq!(stats.doc_freq(), 10);
assert_eq!(stats.total_term_freq(), 50);
}
#[test]
#[should_panic(expected = "docFreq must be positive")]
fn test_term_stats_doc_freq_zero() {
TermStatistics::new(term("t"), 0, 1);
}
#[test]
#[should_panic(expected = "docFreq must be positive")]
fn test_term_stats_doc_freq_negative() {
TermStatistics::new(term("t"), -1, 1);
}
#[test]
#[should_panic(expected = "totalTermFreq must be positive")]
fn test_term_stats_total_term_freq_zero() {
TermStatistics::new(term("t"), 1, 0);
}
#[test]
#[should_panic(expected = "totalTermFreq must be positive")]
fn test_term_stats_total_term_freq_negative() {
TermStatistics::new(term("t"), 1, -1);
}
#[test]
#[should_panic(expected = "totalTermFreq must be at least docFreq")]
fn test_term_stats_total_term_freq_less_than_doc_freq() {
TermStatistics::new(term("t"), 10, 5);
}
#[test]
fn test_term_stats_minimum_valid() {
let stats = TermStatistics::new(term("t"), 1, 1);
assert_eq!(stats.doc_freq(), 1);
assert_eq!(stats.total_term_freq(), 1);
}
// ---- SimScorer / DefaultBulkSimScorer ----
// Test double that ignores its inputs and always returns `value`.
struct ConstantSimScorer {
value: f32,
}
impl SimScorer for ConstantSimScorer {
fn score(&self, _freq: f32, _norm: i64) -> f32 {
self.value
}
fn box_clone(&self) -> Box<dyn SimScorer> {
Box::new(ConstantSimScorer { value: self.value })
}
}
#[test]
fn test_sim_scorer() {
let scorer = ConstantSimScorer { value: 2.5 };
assert_eq!(scorer.score(1.0, 1), 2.5);
assert_eq!(scorer.score(5.0, 100), 2.5);
}
#[test]
fn test_default_bulk_sim_scorer() {
let scorer = ConstantSimScorer { value: 3.0 };
let bulk = DefaultBulkSimScorer::new(&scorer);
let freqs = [1.0, 2.0, 3.0];
let norms = [1, 2, 3];
let mut scores = [0.0f32; 3];
bulk.score(&freqs, &norms, &mut scores);
assert_eq!(scores, [3.0, 3.0, 3.0]);
}
// A frequency-dependent scorer shows that each output slot receives the
// score of its own input pair.
#[test]
fn test_default_bulk_sim_scorer_varying_scores() {
struct LinearSimScorer;
impl SimScorer for LinearSimScorer {
fn score(&self, freq: f32, _norm: i64) -> f32 {
freq * 2.0
}
fn box_clone(&self) -> Box<dyn SimScorer> {
Box::new(LinearSimScorer)
}
}
let scorer = LinearSimScorer;
let bulk = DefaultBulkSimScorer::new(&scorer);
let freqs = [1.0, 2.0, 4.0];
let norms = [1, 1, 1];
let mut scores = [0.0f32; 3];
bulk.score(&freqs, &norms, &mut scores);
assert_eq!(scores, [2.0, 4.0, 8.0]);
}
// ---- BM25Similarity ----
// Fixtures (not tests despite the `test_` prefix): 100 documents, all with
// the field, 1000 term occurrences in total.
fn test_collection_stats() -> CollectionStatistics {
CollectionStatistics::new("body".to_string(), 100, 100, 1000, 500)
}
fn test_term_stats() -> TermStatistics {
TermStatistics::new(term("test"), 10, 50)
}
#[test]
fn test_bm25_default_parameters() {
let sim = BM25Similarity::default();
assert_in_delta!(sim.get_k1(), 1.2, 0.001);
assert_in_delta!(sim.get_b(), 0.75, 0.001);
assert!(sim.get_discount_overlaps());
}
#[test]
fn test_bm25_idf_single_doc() {
let result = BM25Similarity::idf(1, 1);
assert_in_delta!(result, (1.0_f64 + 0.5 / 1.5).ln() as f32, 0.0001);
}
#[test]
fn test_bm25_idf_rare_term() {
let result = BM25Similarity::idf(1, 10000);
let expected = (1.0_f64 + (10000.0 - 1.0 + 0.5) / (1.0 + 0.5)).ln() as f32;
assert_in_delta!(result, expected, 0.0001);
assert_gt!(result, 0.0);
}
// A term present in almost every document still gets a positive IDF, but a
// smaller one than a rare term.
#[test]
fn test_bm25_idf_common_term() {
let result = BM25Similarity::idf(9999, 10000);
assert_gt!(result, 0.0);
assert_lt!(result, BM25Similarity::idf(1, 10000));
}
#[test]
fn test_bm25_score_increases_with_freq() {
let sim = BM25Similarity::default();
let coll = test_collection_stats();
let ts = test_term_stats();
let scorer = sim.scorer(1.0, &coll, &[ts]);
let norm: i64 = 10;
let score1 = scorer.score(1.0, norm);
let score2 = scorer.score(5.0, norm);
let score3 = scorer.score(20.0, norm);
assert_gt!(score1, 0.0);
assert_gt!(score2, score1);
assert_gt!(score3, score2);
}
// Doubling the boost must double the score for the same inputs.
#[test]
fn test_bm25_score_with_boost() {
let sim = BM25Similarity::default();
let coll = test_collection_stats();
let ts = test_term_stats();
let scorer1 = sim.scorer(1.0, &coll, slice::from_ref(&ts));
let scorer2 = sim.scorer(2.0, &coll, &[ts]);
let norm: i64 = 10;
let score1 = scorer1.score(5.0, norm);
let score2 = scorer2.score(5.0, norm);
assert_in_delta!(score2, score1 * 2.0, 0.001);
}
// ---- LENGTH_TABLE ----
#[test]
fn test_bm25_length_table_small_values() {
assert_in_delta!(LENGTH_TABLE[0], 0.0, 0.001);
assert_in_delta!(LENGTH_TABLE[1], 1.0, 0.001);
assert_in_delta!(LENGTH_TABLE[10], 10.0, 0.001);
}
#[test]
fn test_bm25_length_table_monotonic() {
for i in 1..256 {
assert_ge!(
LENGTH_TABLE[i],
LENGTH_TABLE[i - 1],
"LENGTH_TABLE not monotonic at index {i}"
);
}
}
// ---- BM25Similarity::new parameter validation ----
#[test]
#[should_panic(expected = "illegal k1 value")]
fn test_bm25_negative_k1() {
BM25Similarity::new(-1.0, 0.75, true);
}
#[test]
#[should_panic(expected = "illegal k1 value")]
fn test_bm25_infinite_k1() {
BM25Similarity::new(f32::INFINITY, 0.75, true);
}
#[test]
#[should_panic(expected = "illegal b value")]
fn test_bm25_b_out_of_range() {
BM25Similarity::new(1.2, 1.5, true);
}
#[test]
#[should_panic(expected = "illegal b value")]
fn test_bm25_negative_b() {
BM25Similarity::new(1.2, -0.1, true);
}
// Wraps the BM25 scorer in `DefaultBulkSimScorer` and checks the bulk path
// against one-at-a-time scoring. `BM25BulkSimScorer` is not exercised here.
#[test]
fn test_bm25_bulk_scorer_matches_individual() {
let sim = BM25Similarity::default();
let coll = test_collection_stats();
let ts = test_term_stats();
let scorer = sim.scorer(1.0, &coll, &[ts]);
let bulk = DefaultBulkSimScorer::new(scorer.as_ref());
let freqs = [1.0, 3.0, 5.0, 10.0];
let norms = [10i64, 20, 30, 40];
let mut bulk_scores = [0.0f32; 4];
bulk.score(&freqs, &norms, &mut bulk_scores);
for i in 0..4 {
let individual = scorer.score(freqs[i], norms[i]);
assert_in_delta!(bulk_scores[i], individual, 0.0001);
}
}
}