use std::collections::BinaryHeap;
use std::io;
use super::collector::DocAndScoreAccBuffer;
use super::scorer::Scorer;
use crate::util::math_util;
pub fn min_required_score(
max_remaining_score: f64,
min_competitive_score: f32,
num_scorers: i32,
) -> f64 {
let mut min_required_score = min_competitive_score as f64 - max_remaining_score;
let subtraction = ulp_f32(min_competitive_score) as f64;
while min_required_score > 0.0
&& (math_util::sum_upper_bound(min_required_score + max_remaining_score, num_scorers)
as f32)
>= min_competitive_score
{
min_required_score -= subtraction;
}
min_required_score
}
pub fn cost_with_min_should_match(costs: &[i64], num_scorers: usize, min_should_match: i32) -> i64 {
let capacity = num_scorers as i64 - min_should_match as i64 + 1;
if capacity <= 0 {
return 0;
}
let capacity = capacity as usize;
let mut pq = BinaryHeap::with_capacity(capacity + 1);
for &cost in costs {
pq.push(cost);
if pq.len() > capacity {
pq.pop(); }
}
pq.into_iter().sum()
}
pub fn filter_competitive_hits(
buffer: &mut DocAndScoreAccBuffer,
max_remaining_score: f64,
min_competitive_score: f32,
num_scorers: i32,
) {
let min_req = min_required_score(max_remaining_score, min_competitive_score, num_scorers);
if min_req <= 0.0 {
return;
}
buffer.size = filter_by_score(&mut buffer.docs, &mut buffer.scores, min_req, buffer.size);
}
pub fn apply_required_clause(
buffer: &mut DocAndScoreAccBuffer,
scorer: &mut dyn Scorer,
) -> io::Result<()> {
let mut intersection_size = 0;
let mut cur_doc = scorer.doc_id();
for i in 0..buffer.size {
let target_doc = buffer.docs[i];
if cur_doc < target_doc {
cur_doc = scorer.iterator().advance(target_doc)?;
}
if cur_doc == target_doc {
buffer.docs[intersection_size] = target_doc;
buffer.scores[intersection_size] = buffer.scores[i] + scorer.score()? as f64;
intersection_size += 1;
}
}
buffer.size = intersection_size;
Ok(())
}
fn filter_by_score(
docs: &mut [i32],
scores: &mut [f64],
min_score_inclusive: f64,
up_to: usize,
) -> usize {
let mut new_size = 0;
for i in 0..up_to {
let doc = docs[i];
let score = scores[i];
docs[new_size] = doc;
scores[new_size] = score;
if score >= min_score_inclusive {
new_size += 1;
}
}
new_size
}
fn ulp_f32(value: f32) -> f32 {
if value.is_nan() {
return f32::NAN;
}
if value.is_infinite() {
return f32::INFINITY;
}
let abs = value.abs();
let next = f32::from_bits(abs.to_bits() + 1);
next - abs
}
#[cfg(test)]
mod tests {
use std::fmt;
use super::*;
use crate::search::collector::DocAndFloatFeatureBuffer;
use crate::search::doc_id_set_iterator::{DocIdSetIterator, NO_MORE_DOCS};
use crate::search::scorable::Scorable;
use assertables::*;
#[test]
fn test_ulp_f32_one() {
let u = ulp_f32(1.0);
assert_gt!(u, 0.0);
assert_in_delta!(u, 1.1920929e-7, 1e-14);
}
#[test]
fn test_ulp_f32_zero() {
let u = ulp_f32(0.0);
assert_gt!(u, 0.0);
assert_eq!(u, f32::from_bits(1));
}
#[test]
fn test_filter_by_score() {
let mut docs = vec![1, 2, 3, 4, 5];
let mut scores = vec![0.5, 1.5, 0.3, 2.0, 0.8];
let new_size = filter_by_score(&mut docs, &mut scores, 1.0, 5);
assert_eq!(new_size, 2);
assert_eq!(docs[0], 2);
assert_eq!(docs[1], 4);
assert_in_delta!(scores[0], 1.5, 1e-10);
assert_in_delta!(scores[1], 2.0, 1e-10);
}
#[test]
fn test_filter_by_score_all_pass() {
let mut docs = vec![1, 2, 3];
let mut scores = vec![1.0, 2.0, 3.0];
let new_size = filter_by_score(&mut docs, &mut scores, 0.5, 3);
assert_eq!(new_size, 3);
}
#[test]
fn test_filter_by_score_none_pass() {
let mut docs = vec![1, 2, 3];
let mut scores = vec![0.1, 0.2, 0.3];
let new_size = filter_by_score(&mut docs, &mut scores, 1.0, 3);
assert_eq!(new_size, 0);
}
#[test]
fn test_min_required_score_no_filtering() {
let result = min_required_score(100.0, 1.0, 3);
assert_le!(result, 0.0);
}
#[test]
fn test_min_required_score_tight() {
let result = min_required_score(0.0, 1.0, 3);
assert_gt!(result, 0.0);
}
#[test]
fn test_cost_with_min_should_match_msm_1() {
let costs = vec![10, 20, 30];
let result = cost_with_min_should_match(&costs, 3, 1);
assert_eq!(result, 60);
}
#[test]
fn test_cost_with_min_should_match_msm_2() {
let costs = vec![10, 20, 30];
let result = cost_with_min_should_match(&costs, 3, 2);
assert_eq!(result, 30);
}
#[test]
fn test_cost_with_min_should_match_msm_equals_num() {
let costs = vec![10, 20, 30];
let result = cost_with_min_should_match(&costs, 3, 3);
assert_eq!(result, 10);
}
#[test]
fn test_cost_with_min_should_match_single_scorer() {
let costs = vec![42];
let result = cost_with_min_should_match(&costs, 1, 1);
assert_eq!(result, 42);
}
#[test]
fn test_cost_with_min_should_match_zero_capacity() {
let costs = vec![10, 20];
let result = cost_with_min_should_match(&costs, 2, 3);
assert_eq!(result, 0);
}
#[test]
fn test_filter_competitive_hits_none_filtered() {
let mut buffer = DocAndScoreAccBuffer::new();
buffer.grow_no_copy(3);
buffer.docs[0] = 1;
buffer.docs[1] = 2;
buffer.docs[2] = 3;
buffer.scores[0] = 5.0;
buffer.scores[1] = 6.0;
buffer.scores[2] = 7.0;
buffer.size = 3;
filter_competitive_hits(&mut buffer, 100.0, 1.0, 3);
assert_eq!(buffer.size, 3);
}
#[test]
fn test_apply_required_clause() {
struct MockScorer {
docs: Vec<i32>,
idx: usize,
}
impl fmt::Debug for MockScorer {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("MockScorer").finish()
}
}
impl Scorable for MockScorer {
fn score(&mut self) -> io::Result<f32> {
Ok(0.5)
}
}
impl Scorer for MockScorer {
fn doc_id(&self) -> i32 {
if self.idx < self.docs.len() {
self.docs[self.idx]
} else {
NO_MORE_DOCS
}
}
fn iterator(&mut self) -> &mut dyn DocIdSetIterator {
self
}
fn get_max_score(&mut self, _up_to: i32) -> io::Result<f32> {
Ok(f32::MAX)
}
fn next_docs_and_scores(
&mut self,
_up_to: i32,
_buffer: &mut DocAndFloatFeatureBuffer,
) -> io::Result<()> {
Ok(())
}
}
impl DocIdSetIterator for MockScorer {
fn doc_id(&self) -> i32 {
Scorer::doc_id(self)
}
fn next_doc(&mut self) -> io::Result<i32> {
self.idx += 1;
Ok(Scorer::doc_id(self))
}
fn advance(&mut self, target: i32) -> io::Result<i32> {
while Scorer::doc_id(self) < target {
self.next_doc()?;
}
Ok(Scorer::doc_id(self))
}
fn cost(&self) -> i64 {
self.docs.len() as i64
}
}
let mut buffer = DocAndScoreAccBuffer::new();
buffer.grow_no_copy(5);
buffer.docs[0] = 1;
buffer.docs[1] = 2;
buffer.docs[2] = 3;
buffer.docs[3] = 4;
buffer.docs[4] = 5;
buffer.scores[0] = 1.0;
buffer.scores[1] = 1.0;
buffer.scores[2] = 1.0;
buffer.scores[3] = 1.0;
buffer.scores[4] = 1.0;
buffer.size = 5;
let mut scorer = MockScorer {
docs: vec![1, 3, 5],
idx: 0,
};
apply_required_clause(&mut buffer, &mut scorer).unwrap();
assert_eq!(buffer.size, 3);
assert_eq!(buffer.docs[0], 1);
assert_eq!(buffer.docs[1], 3);
assert_eq!(buffer.docs[2], 5);
assert_in_delta!(buffer.scores[0], 1.5, 1e-10);
assert_in_delta!(buffer.scores[1], 1.5, 1e-10);
assert_in_delta!(buffer.scores[2], 1.5, 1e-10);
}
}