use crate::core::{DocId, NO_MORE_DOCS, Scorer, SegmentId};
use crate::search::collector::TopDocsCollector;
/// Number of consecutive doc ids scored together in one window.
///
/// Must be a multiple of 64: each offset in a window owns one bit of
/// `window_matches`, which holds `WINDOW_SIZE / 64` words.
const WINDOW_SIZE: usize = 2048;

/// Bulk scorer for a disjunction of sub-scorers with MaxScore-style pruning.
///
/// All per-scorer vectors share one index space: scorers ordered by ascending
/// global `max_score()`. Scorers `[0, partition_idx)` are "non-essential"; they
/// are only advanced on demand for docs already matched by an essential scorer
/// (`[partition_idx, len)`).
pub struct MaxScoreBulkScorer {
    // ---- per-scorer state ----
    /// The sub-scorers, in ascending order of global max score.
    scorers: Vec<Box<dyn Scorer>>,
    /// Raw id of the doc each scorer is currently positioned on.
    doc_ids: Vec<u32>,
    /// Global `max_score()` of each scorer.
    max_scores: Vec<f32>,
    /// `max_score_prefix[i]` is the sum of `max_scores[0..=i]`.
    max_score_prefix: Vec<f32>,

    // ---- pruning state ----
    /// First essential scorer; everything before it is non-essential.
    partition_idx: usize,
    /// Collector threshold that the current partition was computed against.
    min_competitive_score: f32,

    // ---- per-window accumulators ----
    /// One bit per window offset that at least one essential scorer matched.
    window_matches: [u64; WINDOW_SIZE / 64],
    /// Summed essential scores per window offset; reset to zero once consumed.
    window_scores: Box<[f32; WINDOW_SIZE]>,
}
impl MaxScoreBulkScorer {
    /// Creates a bulk scorer over `scorers`.
    ///
    /// Scorers that are already exhausted are dropped. The rest are ordered by
    /// ascending `max_score()`, so the lowest-impact scorers form a prefix that
    /// can be demoted to "non-essential" once the collector's threshold is
    /// known. `max_score_prefix[i]` caches the sum of the first `i + 1` max
    /// scores for that purpose.
    pub fn new(mut scorers: Vec<Box<dyn Scorer>>) -> Self {
        scorers.retain(|s| s.doc_id() != NO_MORE_DOCS);
        // `total_cmp` is a genuine total order (NaN included). Mapping
        // incomparable pairs to `Equal` is not transitive, and `sort_by` is
        // documented as allowed to panic on such comparators.
        scorers.sort_by(|a, b| a.max_score().total_cmp(&b.max_score()));
        let doc_ids: Vec<u32> = scorers.iter().map(|s| s.doc_id().as_u32()).collect();
        let max_scores: Vec<f32> = scorers.iter().map(|s| s.max_score()).collect();
        let max_score_prefix: Vec<f32> = max_scores
            .iter()
            .scan(0.0f32, |cum, &ms| {
                *cum += ms;
                Some(*cum)
            })
            .collect();
        Self {
            scorers,
            doc_ids,
            max_scores,
            max_score_prefix,
            partition_idx: 0,
            min_competitive_score: 0.0,
            window_scores: Box::new([0.0; WINDOW_SIZE]),
            window_matches: [0u64; WINDOW_SIZE / 64],
        }
    }

    /// Scores every matching doc below `max_doc` and passes it to `collector`.
    ///
    /// Docs are examined in windows of `WINDOW_SIZE` ids. Before each window the
    /// scorers are re-partitioned against `collector.min_score()`. Windows in
    /// which no essential scorer matches are skipped.
    ///
    /// Returns the number of docs handed to `collector.collect`.
    pub fn score(
        &mut self,
        collector: &mut TopDocsCollector,
        segment_id: SegmentId,
        max_doc: u32,
    ) -> u64 {
        let n = self.scorers.len();
        if n == 0 {
            return 0;
        }
        let no_more = NO_MORE_DOCS.as_u32();
        let mut total_hits: u64 = 0;
        let mut window_base = self.min_live_doc(0).unwrap_or(max_doc);
        while window_base < max_doc {
            // Saturate so a window near `u32::MAX` cannot wrap around.
            let window_end = window_base
                .saturating_add(WINDOW_SIZE as u32)
                .min(max_doc);
            self.update_partition_for_window(collector.min_score(), DocId::new(window_base));

            if self.partition_idx >= n {
                // Nothing in *this* window can reach the threshold. That ends
                // the whole segment only if the global bounds agree; per-window
                // block maxima may be higher further on.
                if self.live_max_sum(0..n) < self.min_competitive_score {
                    break;
                }
                window_base = window_end;
                continue;
            }

            let essential_in_window = self.doc_ids[self.partition_idx..]
                .iter()
                .any(|&d| d < window_end && d != no_more);
            if !essential_in_window {
                let next = self.min_live_doc(self.partition_idx).unwrap_or(max_doc);
                // Jumping straight to `next` skips windows whose partition was
                // never computed. That is only safe if the non-essential
                // scorers cannot reach the threshold on their own anywhere.
                let can_jump = self.partition_idx == 0
                    || self.live_max_sum(0..self.partition_idx) < self.min_competitive_score;
                window_base = if can_jump { next } else { window_end };
                continue;
            }

            total_hits += if n - self.partition_idx == 1 {
                self.score_single_essential(window_base, window_end, collector, segment_id)
            } else {
                self.score_multi_essential(window_base, window_end, collector, segment_id)
            };
            window_base = window_end;
        }
        total_hits
    }

    /// Scores one window when exactly one scorer (`partition_idx`) is essential.
    ///
    /// Iterates that scorer's docs directly (no window buffers). Non-essential
    /// scorers are consulted only for docs whose upper bound can still reach
    /// the threshold.
    fn score_single_essential(
        &mut self,
        window_base: u32,
        window_end: u32,
        collector: &mut TopDocsCollector,
        segment_id: SegmentId,
    ) -> u64 {
        let n = self.scorers.len();
        let ess_idx = self.partition_idx;
        let no_more = NO_MORE_DOCS.as_u32();
        // The essential scorer may lag behind this window if it was
        // non-essential, or skipped, earlier. Catch it up so docs from earlier
        // windows are never re-scored or re-collected.
        if self.doc_ids[ess_idx] < window_base {
            self.doc_ids[ess_idx] = self.scorers[ess_idx]
                .advance(DocId::new(window_base))
                .as_u32();
        }
        // Stays valid for the whole window: `update_partition` never lowers the
        // partition, and raising it past `ess_idx` ends the loop below.
        let non_ess_max = self.non_essential_max();
        let mut total_hits: u64 = 0;
        while self.doc_ids[ess_idx] < window_end && self.doc_ids[ess_idx] != no_more {
            let doc = self.doc_ids[ess_idx];
            let score = self.scorers[ess_idx].score();
            if self.min_competitive_score == 0.0
                || score + non_ess_max >= self.min_competitive_score
            {
                let score = self.add_non_essential_scores(doc, score);
                collector.collect(DocId::new(doc), segment_id, score);
                total_hits += 1;
                self.update_partition(collector.min_score());
            }
            self.doc_ids[ess_idx] = self.scorers[ess_idx].next().as_u32();
            if self.partition_idx >= n {
                // Even the sum of every global max score is below the threshold.
                break;
            }
        }
        total_hits
    }

    /// Scores one window when two or more scorers are essential.
    ///
    /// First drains every essential scorer's docs in `[window_base, window_end)`
    /// into `window_scores` / `window_matches`. Then it visits the matched
    /// offsets in ascending doc order, adding non-essential contributions to
    /// docs that can still compete.
    fn score_multi_essential(
        &mut self,
        window_base: u32,
        window_end: u32,
        collector: &mut TopDocsCollector,
        segment_id: SegmentId,
    ) -> u64 {
        let no_more = NO_MORE_DOCS.as_u32();
        for i in self.partition_idx..self.scorers.len() {
            if self.doc_ids[i] < window_base {
                self.doc_ids[i] = self.scorers[i].advance(DocId::new(window_base)).as_u32();
            }
            while self.doc_ids[i] < window_end && self.doc_ids[i] != no_more {
                let idx = (self.doc_ids[i] - window_base) as usize;
                self.window_scores[idx] += self.scorers[i].score();
                self.window_matches[idx / 64] |= 1u64 << (idx % 64);
                self.doc_ids[i] = self.scorers[i].next().as_u32();
            }
        }

        let non_ess_max = self.non_essential_max();
        let mut total_hits: u64 = 0;
        for word_idx in 0..self.window_matches.len() {
            // Taking the word also clears it for the next window.
            let mut bits = std::mem::take(&mut self.window_matches[word_idx]);
            while bits != 0 {
                let idx = word_idx * 64 + bits.trailing_zeros() as usize;
                bits &= bits - 1;
                let doc = window_base + idx as u32;
                // Taking the score resets the slot to 0.0 for the next window.
                let score = std::mem::take(&mut self.window_scores[idx]);
                if self.min_competitive_score == 0.0
                    || score + non_ess_max >= self.min_competitive_score
                {
                    let score = self.add_non_essential_scores(doc, score);
                    collector.collect(DocId::new(doc), segment_id, score);
                    total_hits += 1;
                    // Tighten pruning for the rest of the window. Only the
                    // threshold moves: essential scores are already accumulated,
                    // so the partition must stay fixed until the window ends.
                    let new_min = collector.min_score();
                    if new_min > self.min_competitive_score {
                        self.min_competitive_score = new_min;
                    }
                }
            }
        }
        total_hits
    }

    /// Advances each non-essential scorer to `doc` and adds the scores of those
    /// positioned on it to `score`, in scorer order.
    ///
    /// Scorers already past `doc` are left where they are.
    fn add_non_essential_scores(&mut self, doc: u32, mut score: f32) -> f32 {
        for i in 0..self.partition_idx {
            if self.doc_ids[i] < doc {
                self.doc_ids[i] = self.scorers[i].advance(DocId::new(doc)).as_u32();
            }
            if self.doc_ids[i] == doc {
                score += self.scorers[i].score();
            }
        }
        score
    }

    /// Global upper bound on the combined score of the non-essential scorers.
    fn non_essential_max(&self) -> f32 {
        self.partition_idx
            .checked_sub(1)
            .map_or(0.0, |last| self.max_score_prefix[last])
    }

    /// Smallest current doc id among non-exhausted scorers at indices `from..`.
    fn min_live_doc(&self, from: usize) -> Option<u32> {
        let no_more = NO_MORE_DOCS.as_u32();
        self.doc_ids[from..]
            .iter()
            .copied()
            .filter(|&d| d != no_more)
            .min()
    }

    /// Sum of the global max scores of the non-exhausted scorers in `range`.
    ///
    /// This is an upper bound on what those scorers can add to any doc from
    /// their current position onward.
    fn live_max_sum(&self, range: std::ops::Range<usize>) -> f32 {
        let no_more = NO_MORE_DOCS.as_u32();
        range
            .filter(|&i| self.doc_ids[i] != no_more)
            .fold(0.0f32, |acc, i| acc + self.max_scores[i])
    }

    /// Raises the threshold to `min_score` and grows the non-essential prefix
    /// to match, using global max scores.
    ///
    /// Does nothing unless `min_score` exceeds the current threshold. The
    /// partition is never lowered: it may come from tighter per-window bounds,
    /// and demoting scorers to essential mid-window would break the scoring
    /// loops.
    fn update_partition(&mut self, min_score: f32) {
        if min_score <= self.min_competitive_score {
            return;
        }
        self.min_competitive_score = min_score;
        let global_idx = self
            .max_score_prefix
            .iter()
            .position(|&p| p >= min_score)
            .unwrap_or(self.max_score_prefix.len());
        self.partition_idx = self.partition_idx.max(global_idx);
    }

    /// Recomputes threshold and partition for the window starting at
    /// `window_start`.
    ///
    /// The partition is the longest prefix whose summed per-window maxima stay
    /// below `min_score`. Each maximum is `block_max_score(window_start)` capped
    /// by the global max; exhausted scorers count as 0.
    /// NOTE(review): assumes `block_max_score(window_start)` bounds the
    /// scorer's scores over the whole window — TODO confirm against the
    /// `Scorer` contract.
    fn update_partition_for_window(&mut self, min_score: f32, window_start: DocId) {
        if min_score <= 0.0 {
            self.min_competitive_score = 0.0;
            self.partition_idx = 0;
            return;
        }
        self.min_competitive_score = min_score;
        let no_more = NO_MORE_DOCS.as_u32();
        let mut window_prefix = 0.0f32;
        let mut split = self.scorers.len();
        for i in 0..self.scorers.len() {
            let window_max = if self.doc_ids[i] == no_more {
                0.0
            } else {
                self.scorers[i]
                    .block_max_score(window_start)
                    .min(self.max_scores[i])
            };
            window_prefix += window_max;
            if window_prefix >= min_score {
                split = i;
                break;
            }
        }
        self.partition_idx = split;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::TwoPhaseIterator;

    /// In-memory scorer over a fixed postings list, ordered by doc id.
    struct VecScorer {
        postings: Vec<(DocId, f32)>,
        cursor: usize,
        upper_bound: f32,
    }

    impl VecScorer {
        fn new(docs: Vec<(u32, f32)>, max: f32) -> Box<dyn Scorer> {
            let postings = docs
                .into_iter()
                .map(|(id, s)| (DocId::new(id), s))
                .collect();
            Box::new(Self {
                postings,
                cursor: 0,
                upper_bound: max,
            })
        }

        /// Posting under the cursor, or `(NO_MORE_DOCS, 0.0)` once exhausted.
        fn current(&self) -> (DocId, f32) {
            self.postings
                .get(self.cursor)
                .copied()
                .unwrap_or((NO_MORE_DOCS, 0.0))
        }
    }

    impl Scorer for VecScorer {
        fn doc_id(&self) -> DocId {
            self.current().0
        }

        fn next(&mut self) -> DocId {
            // The cursor never moves past one-beyond-the-end.
            self.cursor = (self.cursor + 1).min(self.postings.len());
            self.current().0
        }

        fn advance(&mut self, target: DocId) -> DocId {
            let skipped = self.postings[self.cursor..]
                .iter()
                .take_while(|(doc, _)| *doc < target)
                .count();
            self.cursor += skipped;
            self.current().0
        }

        fn score(&mut self) -> f32 {
            self.current().1
        }

        fn two_phase(&mut self) -> Option<&mut dyn TwoPhaseIterator> {
            None
        }

        fn max_score(&self) -> f32 {
            self.upper_bound
        }
    }

    /// Runs the bulk scorer with an effectively unbounded collector and returns
    /// every collected `(doc, score)` pair, ordered by doc id.
    fn collect_all(scorers: Vec<Box<dyn Scorer>>, max_doc: u32) -> Vec<(u32, f32)> {
        let mut collector = TopDocsCollector::new(100_000);
        MaxScoreBulkScorer::new(scorers).score(&mut collector, SegmentId::new(1), max_doc);
        let mut hits: Vec<(u32, f32)> = collector
            .into_sorted_results()
            .iter()
            .map(|hit| (hit.doc_id.as_u32(), hit.score))
            .collect();
        hits.sort_by_key(|&(doc, _)| doc);
        hits
    }

    #[test]
    fn matches_wand_output() {
        use crate::search::wand::WANDScorer;
        let docs1 = vec![(0, 1.0), (2, 1.5), (5, 0.8), (10, 2.0), (15, 1.0)];
        let docs2 = vec![(1, 2.0), (2, 0.5), (5, 1.5), (7, 3.0), (15, 0.5)];
        let docs3 = vec![(0, 0.3), (5, 0.7), (15, 2.0)];
        let build = || -> Vec<Box<dyn Scorer>> {
            vec![
                VecScorer::new(docs1.clone(), 2.0),
                VecScorer::new(docs2.clone(), 3.0),
                VecScorer::new(docs3.clone(), 2.0),
            ]
        };

        let mut wand = WANDScorer::new(build());
        let mut expected = Vec::new();
        while wand.doc_id() != NO_MORE_DOCS {
            expected.push((wand.doc_id().as_u32(), wand.score()));
            wand.next();
        }
        expected.sort_by_key(|&(doc, _)| doc);

        assert_eq!(expected, collect_all(build(), 20));
    }

    #[test]
    fn two_scorers() {
        let hits = collect_all(
            vec![
                VecScorer::new(vec![(0, 1.0), (2, 1.0), (4, 1.0)], 1.0),
                VecScorer::new(vec![(1, 2.0), (2, 2.0), (3, 2.0)], 2.0),
            ],
            10,
        );
        let expected = vec![(0, 1.0), (1, 2.0), (2, 3.0), (3, 2.0), (4, 1.0)];
        assert_eq!(hits, expected);
    }

    #[test]
    fn window_boundary() {
        let hits = collect_all(
            vec![
                VecScorer::new(vec![(2047, 1.0), (2048, 2.0)], 2.0),
                VecScorer::new(vec![(2047, 0.5), (2049, 0.5)], 0.5),
            ],
            4096,
        );
        let expected = vec![(2047, 1.5), (2048, 2.0), (2049, 0.5)];
        assert_eq!(hits, expected);
    }

    #[test]
    fn empty() {
        assert!(collect_all(vec![], 100).is_empty());
    }

    #[test]
    fn single_scorer() {
        let hits = collect_all(vec![VecScorer::new(vec![(0, 1.0), (5, 2.0)], 2.0)], 10);
        assert_eq!(hits, vec![(0, 1.0), (5, 2.0)]);
    }
}