use std::fmt::Debug;
use crate::error::Result;
use crate::lexical::query::Query;
use crate::lexical::query::matcher::Matcher;
use crate::util::simd;
// Storage for the clauses of a `BooleanScorer`: one `(scorer, matcher)` pair per
// sub-query. The list lives in a `RefCell` because `Scorer::score` only receives
// `&self`, yet `BooleanScorer::score` has to advance each matcher via `skip_to`.
type BooleanScorerClauses = std::cell::RefCell<Vec<(Box<dyn Scorer>, Box<dyn Matcher>)>>;
/// Common interface of the relevance scorers in this module.
///
/// Implementors must be `Send` and `Debug`.
pub trait Scorer: Send + Debug {
/// Returns the score for `doc_id`.
///
/// `term_freq` and `field_length` are inputs that an implementation may use
/// or ignore: `ConstantScorer` ignores all three arguments, while
/// `BM25Scorer` ignores only `doc_id`.
fn score(&self, doc_id: u64, term_freq: f32, field_length: Option<f32>) -> f32;
/// Returns the boost multiplier currently applied by this scorer.
fn boost(&self) -> f32;
/// Replaces the boost multiplier.
fn set_boost(&mut self, boost: f32);
/// Returns the largest score this scorer reports for any document.
fn max_score(&self) -> f32;
/// Returns a short static identifier for the scorer (e.g. `"BM25"`).
fn name(&self) -> &'static str;
}
/// Scorer implementing the BM25 ranking formula for a single term.
///
/// The IDF component depends only on `doc_freq` and `total_docs`, which never
/// change after construction, so it is computed once and cached.
#[derive(Debug, Clone)]
pub struct BM25Scorer {
    /// Document frequency fed into the IDF formula (`df`).
    doc_freq: u64,
    /// Stored at construction; not read by the scoring code in this file.
    #[allow(dead_code)]
    total_term_freq: u64,
    /// Stored at construction; not read by the scoring code in this file.
    #[allow(dead_code)]
    field_doc_count: u64,
    /// Average field length used for length normalisation.
    avg_field_length: f64,
    /// Collection size fed into the IDF formula (`N`).
    total_docs: u64,
    /// Multiplier applied to the final score.
    boost: f32,
    /// Term-frequency saturation parameter.
    k1: f32,
    /// Length-normalisation strength.
    b: f32,
    /// IDF value computed once by the constructors.
    cached_idf: f32,
}

impl BM25Scorer {
    /// Computes the IDF weight `ln((N - df + 0.5) / (df + 0.5))`, floored at `0.01`.
    ///
    /// Returns `0.0` when either count is zero.
    fn compute_idf(doc_freq: u64, total_docs: u64) -> f32 {
        // Smallest IDF handed out for a term that occurs at all. `f32::max`
        // also yields this floor when the logarithm is NaN (df larger than N).
        const IDF_FLOOR: f32 = 0.01;

        match (doc_freq, total_docs) {
            (0, _) | (_, 0) => 0.0,
            (df, n) => {
                let n = n as f32;
                let df = df as f32;
                let ratio = (n - df + 0.5) / (df + 0.5);
                ratio.ln().max(IDF_FLOOR)
            }
        }
    }

    /// Creates a scorer with the default parameters `k1 = 1.2` and `b = 0.75`.
    pub fn new(
        doc_freq: u64,
        total_term_freq: u64,
        field_doc_count: u64,
        avg_field_length: f64,
        total_docs: u64,
        boost: f32,
    ) -> Self {
        Self::with_params(
            doc_freq,
            total_term_freq,
            field_doc_count,
            avg_field_length,
            total_docs,
            boost,
            1.2,
            0.75,
        )
    }

    /// Creates a scorer with caller-supplied `k1` and `b` parameters.
    #[allow(clippy::too_many_arguments)]
    pub fn with_params(
        doc_freq: u64,
        total_term_freq: u64,
        field_doc_count: u64,
        avg_field_length: f64,
        total_docs: u64,
        boost: f32,
        k1: f32,
        b: f32,
    ) -> Self {
        Self {
            cached_idf: Self::compute_idf(doc_freq, total_docs),
            doc_freq,
            total_term_freq,
            field_doc_count,
            avg_field_length,
            total_docs,
            boost,
            k1,
            b,
        }
    }

    /// Returns the IDF value cached at construction time.
    #[inline(always)]
    fn idf(&self) -> f32 {
        self.cached_idf
    }

    /// Computes the saturated, length-normalised term-frequency component.
    ///
    /// Returns `0.0` for a zero `term_freq`. When the average field length is
    /// zero, the normalisation factor is fixed at `1.0`.
    fn tf(&self, term_freq: f32, field_length: f32) -> f32 {
        if term_freq == 0.0 {
            return 0.0;
        }

        let average = self.avg_field_length as f32;
        let length_norm = if average != 0.0 {
            1.0 - self.b + self.b * (field_length / average)
        } else {
            1.0
        };

        let numerator = term_freq * (self.k1 + 1.0);
        let denominator = term_freq + self.k1 * length_norm;
        numerator / denominator
    }

    /// Returns the `k1` parameter.
    pub fn k1(&self) -> f32 {
        self.k1
    }

    /// Returns the `b` parameter.
    pub fn b(&self) -> f32 {
        self.b
    }

    /// Overrides the `k1` parameter; the cached IDF does not depend on it.
    pub fn set_k1(&mut self, k1: f32) {
        self.k1 = k1;
    }

    /// Overrides the `b` parameter; the cached IDF does not depend on it.
    pub fn set_b(&mut self, b: f32) {
        self.b = b;
    }
}
impl Scorer for BM25Scorer {
    /// Returns `boost * idf * tf` for one document.
    ///
    /// The result is `0.0` when the collection statistics are empty or the
    /// term does not occur. A missing `field_length` is replaced by the
    /// average field length.
    fn score(&self, _doc_id: u64, term_freq: f32, field_length: Option<f32>) -> f32 {
        let has_statistics = self.doc_freq != 0 && self.total_docs != 0;
        if !has_statistics || term_freq == 0.0 {
            return 0.0;
        }

        let length = match field_length {
            Some(len) => len,
            None => self.avg_field_length as f32,
        };
        self.boost * self.idf() * self.tf(term_freq, length)
    }

    /// Returns the current boost multiplier.
    fn boost(&self) -> f32 {
        self.boost
    }

    /// Replaces the boost multiplier.
    fn set_boost(&mut self, boost: f32) {
        self.boost = boost;
    }

    /// Returns `boost * idf * (k1 + 1)`, or `0.0` when the collection
    /// statistics are empty.
    fn max_score(&self) -> f32 {
        if self.doc_freq == 0 || self.total_docs == 0 {
            return 0.0;
        }

        // `k1 + 1` is the limit of the tf component as term_freq grows.
        let tf_ceiling = self.k1 + 1.0;
        self.boost * self.idf() * tf_ceiling
    }

    /// Returns the identifier `"BM25"`.
    fn name(&self) -> &'static str {
        "BM25"
    }
}
impl BM25Scorer {
    /// Scores many documents for the same term in one call.
    ///
    /// `term_freqs[i]` and `field_lengths[i]` describe the i-th document; the
    /// returned vector holds `boost * idf * tf` for each of them, in order.
    /// Batches of four or more entries go through the SIMD helpers, smaller
    /// ones are scored one by one.
    ///
    /// # Panics
    ///
    /// Panics if the two slices differ in length.
    pub fn batch_score(&self, term_freqs: &[f32], field_lengths: &[f32]) -> Vec<f32> {
        assert_eq!(term_freqs.len(), field_lengths.len());

        if term_freqs.len() >= 4 {
            return self.batch_score_optimized(term_freqs, field_lengths);
        }

        // The IDF is identical for every document, so read it once.
        let idf = self.idf();
        term_freqs
            .iter()
            .zip(field_lengths)
            .map(|(&term_freq, &field_length)| self.boost * idf * self.tf(term_freq, field_length))
            .collect()
    }

    /// SIMD-backed implementation of [`BM25Scorer::batch_score`].
    ///
    /// The caller guarantees that both slices have the same length.
    fn batch_score_optimized(&self, term_freqs: &[f32], field_lengths: &[f32]) -> Vec<f32> {
        let avg_len = self.avg_field_length as f32;

        // Mirror `tf()`: with a zero average length there is nothing to
        // normalise against, so use a neutral factor instead of dividing by
        // zero (which would feed inf/NaN into the SIMD kernels).
        let norm_factors: Vec<f32> = if avg_len == 0.0 {
            vec![1.0; field_lengths.len()]
        } else {
            field_lengths
                .iter()
                .map(|&field_len| 1.0 - self.b + self.b * (field_len / avg_len))
                .collect()
        };

        let tf_scores = simd::numeric::batch_bm25_tf(term_freqs, self.k1, &norm_factors);

        // The final-score kernel takes one IDF and one boost per document.
        let idf_scores = vec![self.idf(); tf_scores.len()];
        let boosts = vec![self.boost; tf_scores.len()];
        simd::numeric::batch_bm25_final_score(&tf_scores, &idf_scores, &boosts)
    }

    /// Sums the per-document scores of several terms.
    ///
    /// Each entry of `term_data` is a `(term_freqs, field_lengths)` pair that
    /// is scored with [`BM25Scorer::batch_score`]. Scores at the same index
    /// are added together; if a later term covers more documents than the
    /// earlier ones, the result grows to the longest length seen.
    ///
    /// # Panics
    ///
    /// Panics if the two vectors of any entry differ in length.
    pub fn batch_multi_term_score(&self, term_data: &[(Vec<f32>, Vec<f32>)]) -> Vec<f32> {
        let mut final_scores: Vec<f32> = Vec::new();

        for (term_freqs, field_lengths) in term_data {
            let term_scores = self.batch_score(term_freqs, field_lengths);
            let overlap = final_scores.len().min(term_scores.len());

            // Accumulate into the documents both vectors cover ...
            for (total, score) in final_scores.iter_mut().zip(&term_scores) {
                *total += *score;
            }
            // ... and append whatever the running totals did not cover yet.
            final_scores.extend_from_slice(&term_scores[overlap..]);
        }

        final_scores
    }
}
/// Scorer that reports the same value for every document.
#[derive(Debug, Clone)]
pub struct ConstantScorer {
    /// Base score before the boost is applied.
    score: f32,
    /// Multiplier applied to the base score.
    boost: f32,
}

impl ConstantScorer {
    /// Creates a scorer with the given base score and a boost of `1.0`.
    pub fn new(score: f32) -> Self {
        Self::with_boost(score, 1.0)
    }

    /// Creates a scorer with an explicit base score and boost.
    pub fn with_boost(score: f32, boost: f32) -> Self {
        Self { score, boost }
    }

    /// Returns the base score, without the boost applied.
    pub fn score_value(&self) -> f32 {
        self.score
    }

    /// Replaces the base score.
    pub fn set_score_value(&mut self, score: f32) {
        self.score = score;
    }
}
impl Scorer for ConstantScorer {
    /// Returns `score * boost`; every argument is ignored.
    fn score(&self, _doc_id: u64, _term_freq: f32, _field_length: Option<f32>) -> f32 {
        let Self { score, boost } = self;
        score * boost
    }

    /// Returns the current boost multiplier.
    fn boost(&self) -> f32 {
        self.boost
    }

    /// Replaces the boost multiplier.
    fn set_boost(&mut self, boost: f32) {
        self.boost = boost;
    }

    /// Returns `score * boost`, the same value `score` reports for any document.
    fn max_score(&self) -> f32 {
        let Self { score, boost } = self;
        score * boost
    }

    /// Returns the identifier `"Constant"`.
    fn name(&self) -> &'static str {
        "Constant"
    }
}
/// Scorer that adds up the scores of several sub-query clauses.
///
/// Each clause keeps the scorer and the matcher of one sub-query. A clause
/// contributes to a document's score only when its matcher can be positioned
/// exactly on that document.
#[derive(Debug)]
pub struct BooleanScorer {
// `(scorer, matcher)` pairs; wrapped in a `RefCell` so that `score(&self)`
// can advance the matchers.
clauses: BooleanScorerClauses,
// Multiplier applied to the summed clause scores.
boost: f32,
}
// SAFETY / NOTE(review): this impl asserts that a `BooleanScorer`, including
// every boxed scorer and matcher it owns, may be moved to another thread.
// `Scorer` requires `Send`, but whether every `Matcher` implementation is
// `Send` is not visible from this file — TODO confirm. If `Matcher` already
// requires `Send`, this impl is redundant and can be dropped.
unsafe impl Send for BooleanScorer {}
impl BooleanScorer {
pub fn new(
reader: &dyn crate::lexical::reader::LexicalIndexReader,
queries: Vec<Box<dyn Query>>,
) -> Result<Self> {
let mut clauses = Vec::new();
for query in queries {
let matcher = query.matcher(reader)?;
let scorer = query.scorer(reader)?;
clauses.push((scorer, matcher));
}
Ok(BooleanScorer {
clauses: std::cell::RefCell::new(clauses),
boost: 1.0,
})
}
}
impl Scorer for BooleanScorer {
    /// Returns the boosted sum of the scores of all clauses matching `doc_id`.
    ///
    /// Every matcher is advanced with `skip_to(doc_id)`. A clause counts only
    /// if that call returns `Ok(true)` and the matcher then sits exactly on
    /// `doc_id`; errors from `skip_to` are treated as "no match". The
    /// `_term_freq` argument is ignored because each clause takes the term
    /// frequency from its own matcher.
    fn score(&self, doc_id: u64, _term_freq: f32, field_length: Option<f32>) -> f32 {
        let mut sum: f32 = 0.0;
        let mut clauses = self.clauses.borrow_mut();

        for (scorer, matcher) in clauses.iter_mut() {
            let advanced = matches!(matcher.skip_to(doc_id), Ok(true));
            if !advanced || matcher.doc_id() != doc_id {
                continue;
            }

            let clause_tf = matcher.term_freq() as f32;
            sum += scorer.score(doc_id, clause_tf, field_length);
        }

        sum * self.boost
    }

    /// Returns the current boost multiplier.
    fn boost(&self) -> f32 {
        self.boost
    }

    /// Replaces the boost multiplier.
    fn set_boost(&mut self, boost: f32) {
        self.boost = boost;
    }

    /// Returns the boosted sum of the maximum scores of all clauses.
    fn max_score(&self) -> f32 {
        let sum = self
            .clauses
            .borrow()
            .iter()
            .fold(0.0_f32, |acc, (scorer, _)| acc + scorer.max_score());
        sum * self.boost
    }

    /// Returns the identifier `"Boolean"`.
    fn name(&self) -> &'static str {
        "Boolean"
    }
}
// Unit tests for `BM25Scorer` and `ConstantScorer`. `BooleanScorer` is not
// covered here.
#[cfg(test)]
mod tests {
use super::*;
// `new` applies the default parameters k1 = 1.2 and b = 0.75.
#[test]
fn test_bm25_scorer_creation() {
let scorer = BM25Scorer::new(10, 100, 50, 10.0, 1000, 1.0);
assert_eq!(scorer.boost(), 1.0);
assert_eq!(scorer.k1(), 1.2);
assert_eq!(scorer.b(), 0.75);
assert_eq!(scorer.name(), "BM25");
}
// `with_params` stores the caller-supplied boost, k1 and b.
#[test]
fn test_bm25_scorer_with_params() {
let scorer = BM25Scorer::with_params(10, 100, 50, 10.0, 1000, 2.0, 1.5, 0.8);
assert_eq!(scorer.boost(), 2.0);
assert_eq!(scorer.k1(), 1.5);
assert_eq!(scorer.b(), 0.8);
}
// IDF is positive for a rare term and zero when the statistics are empty.
#[test]
fn test_bm25_scorer_idf() {
let scorer = BM25Scorer::new(10, 100, 50, 10.0, 1000, 1.0);
let idf = scorer.idf();
assert!(idf > 0.0);
let scorer_zero = BM25Scorer::new(0, 0, 0, 0.0, 0, 1.0);
assert_eq!(scorer_zero.idf(), 0.0);
}
// The tf component grows with the term frequency and is zero for tf = 0.
#[test]
fn test_bm25_scorer_tf() {
let scorer = BM25Scorer::new(10, 100, 50, 10.0, 1000, 1.0);
let tf1 = scorer.tf(1.0, 10.0);
let tf2 = scorer.tf(2.0, 10.0);
assert!(tf2 > tf1);
assert_eq!(scorer.tf(0.0, 10.0), 0.0);
}
// The full score grows with the term frequency and is zero for tf = 0.
#[test]
fn test_bm25_scorer_score() {
let scorer = BM25Scorer::new(10, 100, 50, 10.0, 1000, 1.0);
let score1 = scorer.score(0, 1.0, None);
let score2 = scorer.score(0, 2.0, None);
assert!(score2 > score1);
assert_eq!(scorer.score(0, 0.0, None), 0.0);
}
// Doubling the boost doubles the score exactly.
#[test]
fn test_bm25_scorer_boost() {
let mut scorer = BM25Scorer::new(10, 100, 50, 10.0, 1000, 1.0);
let original_score = scorer.score(0, 1.0, None);
scorer.set_boost(2.0);
let boosted_score = scorer.score(0, 1.0, None);
assert_eq!(scorer.boost(), 2.0);
assert_eq!(boosted_score, original_score * 2.0);
}
// `max_score` is not smaller than an actual score.
#[test]
fn test_bm25_scorer_max_score() {
let scorer = BM25Scorer::new(10, 100, 50, 10.0, 1000, 1.0);
let max_score = scorer.max_score();
let actual_score = scorer.score(0, 1.0, None);
assert!(max_score >= actual_score);
}
// A constant scorer returns its base score regardless of the arguments.
#[test]
fn test_constant_scorer() {
let scorer = ConstantScorer::new(5.0);
assert_eq!(scorer.score_value(), 5.0);
assert_eq!(scorer.boost(), 1.0);
assert_eq!(scorer.name(), "Constant");
assert_eq!(scorer.score(0, 1.0, None), 5.0);
assert_eq!(scorer.score(100, 10.0, None), 5.0);
assert_eq!(scorer.score(0, 0.0, None), 5.0);
}
// The boost multiplies both `score` and `max_score`.
#[test]
fn test_constant_scorer_with_boost() {
let scorer = ConstantScorer::with_boost(5.0, 2.0);
assert_eq!(scorer.score_value(), 5.0);
assert_eq!(scorer.boost(), 2.0);
assert_eq!(scorer.score(0, 1.0, None), 10.0);
assert_eq!(scorer.max_score(), 10.0);
}
// The setters for base score and boost take effect on later `score` calls.
#[test]
fn test_constant_scorer_mutation() {
let mut scorer = ConstantScorer::new(5.0);
scorer.set_score_value(3.0);
assert_eq!(scorer.score_value(), 3.0);
assert_eq!(scorer.score(0, 1.0, None), 3.0);
scorer.set_boost(2.0);
assert_eq!(scorer.boost(), 2.0);
assert_eq!(scorer.score(0, 1.0, None), 6.0);
}
// Four entries: exercises the `>= 4` branch of `batch_score`.
#[test]
fn test_bm25_batch_score() {
let scorer = BM25Scorer::new(10, 100, 50, 10.0, 1000, 1.0);
let term_freqs = vec![1.0, 2.0, 3.0, 4.0];
let field_lengths = vec![10.0, 15.0, 8.0, 12.0];
let batch_scores = scorer.batch_score(&term_freqs, &field_lengths);
for &score in &batch_scores {
assert!(score > 0.0);
}
assert_eq!(batch_scores.len(), term_freqs.len());
}
// Two entries: exercises the per-element branch of `batch_score`.
#[test]
fn test_bm25_batch_small() {
let scorer = BM25Scorer::new(5, 50, 25, 10.0, 500, 1.5);
let term_freqs = vec![1.5, 2.5];
let field_lengths = vec![8.0, 12.0];
let batch_scores = scorer.batch_score(&term_freqs, &field_lengths);
assert_eq!(batch_scores.len(), 2);
assert!(batch_scores[0] > 0.0);
assert!(batch_scores[1] > 0.0);
}
// Scores of two terms over the same two documents are summed per document.
#[test]
fn test_bm25_multi_term_score() {
let scorer = BM25Scorer::new(10, 100, 50, 10.0, 1000, 1.0);
let term_data = vec![
(vec![1.0, 2.0], vec![10.0, 15.0]),
(vec![2.0, 1.0], vec![10.0, 15.0]),
];
let multi_scores = scorer.batch_multi_term_score(&term_data);
assert_eq!(multi_scores.len(), 2);
assert!(multi_scores[0] > 0.0);
assert!(multi_scores[1] > 0.0);
}
}