use crate::core::{DocId, NO_MORE_DOCS, Scorer, TwoPhaseIterator};
use crate::inverted::norms::FieldNormsReader;
use crate::inverted::postings::{BlockMaxPostingListReader, PostingListReader};
/// BM25 `k1` parameter: scales the length-normalisation term that is added to
/// the term frequency in the score denominator, and (as `K1 + 1`) the
/// numerator. It governs how quickly the score saturates as `tf` grows.
const K1: f32 = 1.2;
/// BM25 `b` parameter: weight of the `dl / avgdl` length ratio inside the
/// normalisation term `1 - B + B * dl / avgdl`.
const B: f32 = 0.75;
/// Computes the BM25 inverse document frequency of a term.
///
/// Uses the smoothed form `ln(1 + (N - n + 0.5) / (n + 0.5))`, where `N` is
/// `total_docs` and `n` is `doc_freq`. The arithmetic is carried out in `f64`
/// and the result is narrowed to `f32`.
///
/// `doc_freq` is clamped to `total_docs`. If the two statistics are
/// inconsistent (`doc_freq > total_docs`) the unclamped formula produces a
/// negative IDF, and negative term scores break the assumption that score
/// upper bounds are non-negative. With the clamp the result is always finite
/// and strictly positive; for consistent statistics it is unchanged.
pub fn bm25_idf(total_docs: u32, doc_freq: u32) -> f32 {
    let big_n = f64::from(total_docs);
    // A term cannot occur in more documents than exist.
    let n = f64::from(doc_freq.min(total_docs));
    (1.0 + (big_n - n + 0.5) / (n + 0.5)).ln() as f32
}
/// Scores one term in one document with BM25.
///
/// Computes `idf * tf * (K1 + 1) / (tf + K1 * (1 - B + B * dl / avgdl))`.
///
/// * `idf` – inverse document frequency of the term.
/// * `tf` – frequency of the term in the document.
/// * `dl` – length of the document.
/// * `avgdl` – average document length.
///
/// When `avgdl` is not strictly positive (zero, negative or NaN) the ratio
/// `dl / avgdl` is meaningless and can turn the whole score into NaN
/// (`0 / 0`). In that case the document is treated as average-length, i.e.
/// the ratio is taken to be 1.
pub fn bm25_score(idf: f32, tf: f32, dl: f32, avgdl: f32) -> f32 {
    // Evaluated as `(B * dl) / avgdl`; the fallback `B` corresponds to a
    // length ratio of exactly 1.
    let length_term = if avgdl > 0.0 { B * dl / avgdl } else { B };
    let norm_tf = (tf * (K1 + 1.0)) / (tf + K1 * (1.0 - B + length_term));
    idf * norm_tf
}
/// Precomputed BM25 state shared by the scorers in this module.
///
/// Built by [`Bm25Weight::new`] from collection statistics; it bundles the
/// IDF with a per-norm-byte table so that scoring a posting needs only one
/// table lookup and a few floating-point operations.
#[derive(Clone, Debug)]
pub struct Bm25Weight {
    /// Inverse document frequency, as returned by [`bm25_idf`].
    pub idf: f32,
    /// Average field length the weight was built with (`avgdl` in BM25).
    pub avg_field_length: f32,
    /// `K1 * (1 - B + B * dl / avgdl)` for every possible norm byte, indexed
    /// by that byte; `dl` is the decoded length of the byte.
    norm_cache: [f32; 256],
}
impl Bm25Weight {
    /// Builds the weight from collection statistics.
    ///
    /// * `total_docs` / `doc_freq` – passed to [`bm25_idf`].
    /// * `avg_field_length` – average field length (`avgdl`), stored as-is
    ///   and used to fill the normalisation table.
    ///
    /// The table holds `K1 * (1 - B + B * dl / avgdl)` for each of the 256
    /// norm bytes, where `dl` is `decode_norm(byte)`.
    pub fn new(total_docs: u32, doc_freq: u32, avg_field_length: f32) -> Self {
        let idf = bm25_idf(total_docs, doc_freq);
        let mut norm_cache = [0.0f32; 256];
        for (norm_byte, slot) in (0..=u8::MAX).zip(norm_cache.iter_mut()) {
            let decoded_len = crate::inverted::norms::decode_norm(norm_byte);
            *slot = K1 * (1.0 - B + B * decoded_len / avg_field_length);
        }
        Self {
            idf,
            avg_field_length,
            norm_cache,
        }
    }

    /// Score of a document of length `1.0` containing the term `max_tf`
    /// times; used as an upper bound for postings whose term frequency does
    /// not exceed `max_tf`.
    ///
    /// NOTE(review): this is only a true upper bound if no scored document
    /// decodes to a length below 1 and `idf` is non-negative — confirm
    /// against `decode_norm`.
    pub fn max_score_for_tf(&self, max_tf: f32) -> f32 {
        let shortest_len = 1.0;
        bm25_score(self.idf, max_tf, shortest_len, self.avg_field_length)
    }

    /// Returns `idf * (K1 + 1)`, which does not depend on term frequency or
    /// document length.
    pub fn max_score_unbounded(&self) -> f32 {
        let saturation_limit = K1 + 1.0;
        self.idf * saturation_limit
    }
}
/// Scores the documents of one posting list with BM25.
///
/// The scorer is a cursor: it always sits on one posting (or on
/// `NO_MORE_DOCS` once the list is exhausted) and `score` evaluates that
/// posting.
pub struct Bm25Scorer<'a> {
// IDF and the per-norm-byte normalisation table.
weight: Bm25Weight,
// Source of `(doc_id, term_frequency)` pairs.
postings: PostingListReader<'a>,
// Looked up by document id to obtain the raw norm byte.
norms: FieldNormsReader<'a>,
// Document the cursor is on; `NO_MORE_DOCS` after the last posting.
current_doc_id: DocId,
// Term frequency of the current posting; 0 when exhausted.
current_tf: u32,
// Precomputed score for postings with `tf == 1`, present only when
// `norms.uniform_norm()` returned a value at construction time.
constant_score: Option<f32>,
}
impl<'a> Bm25Scorer<'a> {
    /// Creates a scorer and moves it onto the first posting.
    ///
    /// If `norms.uniform_norm()` yields a length, the score for `tf == 1` at
    /// that length is computed once up front and reused by `score`. If the
    /// posting list is empty the scorer starts on `NO_MORE_DOCS`.
    pub fn new(
        weight: Bm25Weight,
        postings: PostingListReader<'a>,
        norms: FieldNormsReader<'a>,
    ) -> Self {
        let (idf, avgdl) = (weight.idf, weight.avg_field_length);
        let constant_score = norms
            .uniform_norm()
            .map(|uniform_len| bm25_score(idf, 1.0, uniform_len, avgdl));
        let mut this = Self {
            weight,
            postings,
            norms,
            current_doc_id: DocId::new(0),
            current_tf: 0,
            constant_score,
        };
        this.read_next();
        this
    }

    /// Pulls the next posting into the cursor fields, or parks the cursor on
    /// `NO_MORE_DOCS` with a term frequency of 0 when the list is exhausted.
    fn read_next(&mut self) {
        let (doc_id, tf) = self.postings.next().unwrap_or((NO_MORE_DOCS, 0));
        self.current_doc_id = doc_id;
        self.current_tf = tf;
    }
}
impl Scorer for Bm25Scorer<'_> {
    /// Document the cursor is currently on.
    fn doc_id(&self) -> DocId {
        self.current_doc_id
    }

    /// Moves to the following posting and returns its document id
    /// (`NO_MORE_DOCS` when the list is exhausted).
    fn next(&mut self) -> DocId {
        self.read_next();
        self.doc_id()
    }

    /// Steps forward one posting at a time until the cursor is on a document
    /// `>= target` or the list is exhausted; returns the resulting position.
    /// Does nothing if the cursor is already there.
    fn advance(&mut self, target: DocId) -> DocId {
        loop {
            let current = self.current_doc_id;
            if current == NO_MORE_DOCS || current >= target {
                return current;
            }
            self.read_next();
        }
    }

    /// BM25 score of the current posting.
    ///
    /// Uses the precomputed constant when one exists and the term frequency
    /// is exactly 1; otherwise looks up the document's norm byte and
    /// evaluates the formula with the cached normalisation term.
    fn score(&mut self) -> f32 {
        match self.constant_score {
            Some(constant) if self.current_tf == 1 => constant,
            _ => {
                let tf = self.current_tf as f32;
                let norm_byte = self.norms.raw_byte(self.current_doc_id);
                let length_norm = self.weight.norm_cache[norm_byte as usize];
                self.weight.idf * (tf * (K1 + 1.0)) / (tf + length_norm)
            }
        }
    }

    /// This scorer has no two-phase iterator.
    fn two_phase(&mut self) -> Option<&mut dyn TwoPhaseIterator> {
        None
    }

    /// Upper bound that does not depend on term frequency: `idf * (K1 + 1)`.
    fn max_score(&self) -> f32 {
        self.weight.max_score_unbounded()
    }
}
/// BM25 scorer over a block-structured posting list that additionally
/// exposes a score upper bound per block.
///
/// Like `Bm25Scorer` it is a cursor positioned on one posting (or on
/// `NO_MORE_DOCS` once exhausted).
pub struct BlockMaxBm25Scorer<'a> {
// IDF and the per-norm-byte normalisation table.
weight: Bm25Weight,
// Source of `(doc_id, term_frequency)` pairs and of per-block metadata.
postings: BlockMaxPostingListReader<'a>,
// Looked up by document id to obtain the raw norm byte.
norms: FieldNormsReader<'a>,
// Document the cursor is on; `NO_MORE_DOCS` after the last posting.
current_doc_id: DocId,
// Term frequency of the current posting; 0 when exhausted.
current_tf: u32,
// Largest entry of `block_max_scores`, never below 0.0.
term_max_score: f32,
// One upper bound per block, derived from that block's maximum term
// frequency at construction time.
block_max_scores: Vec<f32>,
// Precomputed score for postings with `tf == 1`, present only when
// `norms.uniform_norm()` returned a value at construction time.
constant_score: Option<f32>,
}
impl<'a> BlockMaxBm25Scorer<'a> {
    /// Creates a scorer, precomputes the per-block score bounds and moves the
    /// cursor onto the first posting.
    ///
    /// Each block's bound is `weight.max_score_for_tf` applied to the block's
    /// maximum term frequency; the term-level bound is the largest of those,
    /// starting from 0.0. If `norms.uniform_norm()` yields a length, the
    /// score for `tf == 1` at that length is also computed once up front.
    pub fn new(
        weight: Bm25Weight,
        postings: BlockMaxPostingListReader<'a>,
        norms: FieldNormsReader<'a>,
    ) -> Self {
        let block_max_scores: Vec<f32> = (0..postings.num_blocks())
            .map(|block| weight.max_score_for_tf(postings.block_max_tf(block) as f32))
            .collect();
        // Explicit `>` comparison: a bound that is NaN or not above the
        // running maximum leaves the maximum untouched.
        let term_max_score = block_max_scores
            .iter()
            .copied()
            .fold(0.0f32, |best, bound| if bound > best { bound } else { best });
        let (idf, avgdl) = (weight.idf, weight.avg_field_length);
        let constant_score = norms
            .uniform_norm()
            .map(|uniform_len| bm25_score(idf, 1.0, uniform_len, avgdl));
        let mut this = Self {
            weight,
            postings,
            norms,
            current_doc_id: DocId::new(0),
            current_tf: 0,
            term_max_score,
            block_max_scores,
            constant_score,
        };
        this.read_next();
        this
    }

    /// Pulls the next posting into the cursor fields, or parks the cursor on
    /// `NO_MORE_DOCS` with a term frequency of 0 when the list is exhausted.
    fn read_next(&mut self) {
        let (doc_id, tf) = self.postings.next().unwrap_or((NO_MORE_DOCS, 0));
        self.current_doc_id = doc_id;
        self.current_tf = tf;
    }
}
impl Scorer for BlockMaxBm25Scorer<'_> {
    /// Document the cursor is currently on.
    fn doc_id(&self) -> DocId {
        self.current_doc_id
    }

    /// Moves to the following posting and returns its document id
    /// (`NO_MORE_DOCS` when the list is exhausted).
    fn next(&mut self) -> DocId {
        self.read_next();
        self.doc_id()
    }

    /// Positions the cursor on the first document `>= target`.
    ///
    /// If the cursor is already there (or exhausted) nothing happens.
    /// Otherwise the reader is first asked to jump to the block for `target`
    /// and postings are then read until one reaches `target` or the list
    /// ends.
    fn advance(&mut self, target: DocId) -> DocId {
        let start = self.current_doc_id;
        if start == NO_MORE_DOCS || start >= target {
            return start;
        }
        self.postings.advance_to_block(target);
        loop {
            let posting = self.postings.next();
            let exhausted = posting.is_none();
            let (doc_id, tf) = posting.unwrap_or((NO_MORE_DOCS, 0));
            self.current_doc_id = doc_id;
            self.current_tf = tf;
            if exhausted || doc_id >= target {
                return doc_id;
            }
        }
    }

    /// BM25 score of the current posting.
    ///
    /// Uses the precomputed constant when one exists and the term frequency
    /// is exactly 1; otherwise looks up the document's norm byte and
    /// evaluates the formula with the cached normalisation term.
    fn score(&mut self) -> f32 {
        match self.constant_score {
            Some(constant) if self.current_tf == 1 => constant,
            _ => {
                let tf = self.current_tf as f32;
                let norm_byte = self.norms.raw_byte(self.current_doc_id);
                let length_norm = self.weight.norm_cache[norm_byte as usize];
                self.weight.idf * (tf * (K1 + 1.0)) / (tf + length_norm)
            }
        }
    }

    /// This scorer has no two-phase iterator.
    fn two_phase(&mut self) -> Option<&mut dyn TwoPhaseIterator> {
        None
    }

    /// Largest per-block bound computed at construction time.
    fn max_score(&self) -> f32 {
        self.term_max_score
    }

    /// Score bound for the block the reader reports after a shallow advance
    /// to `doc`; falls back to the term-level bound when that block index is
    /// outside the precomputed table.
    fn block_max_score(&mut self, doc: DocId) -> f32 {
        self.postings.advance_shallow(doc);
        let block = self.postings.current_block_idx() as usize;
        self.block_max_scores
            .get(block)
            .copied()
            .unwrap_or(self.term_max_score)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::FieldId;
    use crate::inverted::norms::FieldNormsWriter;
    use crate::inverted::postings::PostingListWriter;

    /// Encodes `(doc_id, term_frequency)` pairs with `PostingListWriter`.
    fn make_postings(docs: &[(u32, u32)]) -> Vec<u8> {
        let mut builder = PostingListWriter::new();
        for (doc, freq) in docs.iter().copied() {
            builder.add(DocId::new(doc), freq);
        }
        builder.finish()
    }

    /// Encodes one field length per document with `FieldNormsWriter`.
    fn make_norms(lengths: &[u32]) -> Vec<u8> {
        let mut builder = FieldNormsWriter::new(FieldId::new(0));
        for length in lengths.iter().copied() {
            builder.add(length);
        }
        builder.finish()
    }

    /// A term found in few documents gets a larger IDF than a common one.
    #[test]
    fn idf_rare_term_higher() {
        let common_idf = bm25_idf(1000, 900);
        let rare_idf = bm25_idf(1000, 10);
        assert!(rare_idf > common_idf);
    }

    /// A document frequency of zero still yields a positive, finite IDF.
    #[test]
    fn idf_zero_doc_freq() {
        let zero_df_idf = bm25_idf(1000, 0);
        assert!(zero_df_idf > 0.0);
        assert!(zero_df_idf.is_finite());
    }

    /// A term present in every document never gets a negative IDF.
    #[test]
    fn idf_all_docs_match() {
        let everywhere_idf = bm25_idf(1000, 1000);
        assert!(everywhere_idf >= 0.0);
    }

    /// With everything else fixed, the score grows with term frequency.
    #[test]
    fn higher_tf_higher_score() {
        let idf = bm25_idf(100, 10);
        let score_at = |tf: f32| bm25_score(idf, tf, 10.0, 10.0);
        let (s1, s2, s3) = (score_at(1.0), score_at(5.0), score_at(20.0));
        assert!(s2 > s1, "higher TF should give higher score");
        assert!(s3 > s2, "even higher TF should give higher score");
    }

    /// Each tenfold increase in term frequency adds less than the previous.
    #[test]
    fn tf_saturation() {
        let idf = bm25_idf(100, 10);
        let score_at = |tf: f32| bm25_score(idf, tf, 10.0, 10.0);
        let (s10, s100, s1000) = (score_at(10.0), score_at(100.0), score_at(1000.0));
        assert!(s100 > s10);
        assert!(s1000 > s100);
        let first_gain = s100 - s10;
        let second_gain = s1000 - s100;
        assert!(first_gain > second_gain);
    }

    /// With everything else fixed, the score shrinks as the document grows.
    #[test]
    fn longer_docs_lower_score() {
        let idf = bm25_idf(100, 10);
        let score_for_len = |dl: f32| bm25_score(idf, 2.0, dl, 10.0);
        let (short, avg, long) = (score_for_len(5.0), score_for_len(10.0), score_for_len(20.0));
        assert!(short > avg, "shorter doc should score higher");
        assert!(avg > long, "average doc should score higher than long");
    }

    /// `next` walks the postings in order and ends on `NO_MORE_DOCS`.
    #[test]
    fn scorer_iterates_docs() {
        let posting_bytes = make_postings(&[(0, 1), (5, 2), (10, 1)]);
        let norm_bytes = make_norms(&[3; 11]);
        let mut scorer = Bm25Scorer::new(
            Bm25Weight::new(100, 3, 3.0),
            PostingListReader::new(&posting_bytes),
            FieldNormsReader::open(&norm_bytes),
        );
        assert_eq!(scorer.doc_id(), DocId::new(0));
        assert_eq!(scorer.next(), DocId::new(5));
        assert_eq!(scorer.next(), DocId::new(10));
        assert_eq!(scorer.next(), NO_MORE_DOCS);
    }

    /// `advance` lands on the first document at or after the target.
    #[test]
    fn scorer_advance() {
        let posting_bytes = make_postings(&[(0, 1), (5, 2), (10, 1), (20, 3)]);
        let norm_bytes = make_norms(&[5u32; 21]);
        let mut scorer = Bm25Scorer::new(
            Bm25Weight::new(100, 4, 5.0),
            PostingListReader::new(&posting_bytes),
            FieldNormsReader::open(&norm_bytes),
        );
        let steps = [
            (5, DocId::new(5)),
            (15, DocId::new(20)),
            (21, NO_MORE_DOCS),
        ];
        for (target, landed) in steps {
            assert_eq!(scorer.advance(DocId::new(target)), landed);
        }
    }

    /// Advancing beyond the last posting exhausts the scorer.
    #[test]
    fn scorer_advance_past_end() {
        let posting_bytes = make_postings(&[(0, 1), (1, 1)]);
        let norm_bytes = make_norms(&[5, 5]);
        let mut scorer = Bm25Scorer::new(
            Bm25Weight::new(10, 2, 5.0),
            PostingListReader::new(&posting_bytes),
            FieldNormsReader::open(&norm_bytes),
        );
        assert_eq!(scorer.advance(DocId::new(100)), NO_MORE_DOCS);
    }

    /// Postings with different term frequencies and lengths get different,
    /// positive scores.
    #[test]
    fn scorer_scores_correctly() {
        let posting_bytes = make_postings(&[(0, 3), (5, 1)]);
        let norm_bytes = make_norms(&[10, 10, 10, 10, 10, 5, 10, 10, 10, 10]);
        let avg_dl = 10.0 * 9.0 / 10.0 + 5.0 / 10.0;
        let mut scorer = Bm25Scorer::new(
            Bm25Weight::new(10, 2, avg_dl),
            PostingListReader::new(&posting_bytes),
            FieldNormsReader::open(&norm_bytes),
        );

        assert_eq!(scorer.doc_id(), DocId::new(0));
        let score0 = scorer.score();

        scorer.next();
        assert_eq!(scorer.doc_id(), DocId::new(5));
        let score5 = scorer.score();

        assert!(score0 > 0.0);
        assert!(score5 > 0.0);
        assert_ne!(score0, score5);
    }

    /// The plain scorer never exposes a two-phase iterator.
    #[test]
    fn scorer_no_two_phase() {
        let posting_bytes = make_postings(&[(0, 1)]);
        let norm_bytes = make_norms(&[5]);
        let mut scorer = Bm25Scorer::new(
            Bm25Weight::new(10, 1, 5.0),
            PostingListReader::new(&posting_bytes),
            FieldNormsReader::open(&norm_bytes),
        );
        assert!(scorer.two_phase().is_none());
    }

    /// Compares `bm25_idf` and `bm25_score` with values worked out by hand.
    #[test]
    fn hand_computed_bm25() {
        let idf = bm25_idf(100, 10);
        let score = bm25_score(idf, 2.0, 15.0, 10.0);

        let expected_idf = (1.0f64 + 90.5 / 10.5).ln() as f32;
        assert!(
            (idf - expected_idf).abs() < 0.001,
            "idf={idf} expected={expected_idf}"
        );

        let expected = expected_idf * (2.0 * 2.2) / (2.0 + 1.2 * (0.25 + 0.75 * 15.0 / 10.0));
        assert!(
            (score - expected).abs() < 0.01,
            "score={score} expected={expected}"
        );
    }
}