use crate::core::{DocId, NO_MORE_DOCS, Scorer, TwoPhaseIterator};
/// Number of consecutive doc ids buffered in one bitset window.
const WINDOW_SIZE: usize = 2048;
/// Number of `u64` words needed to hold one bit per doc of a window.
const WINDOW_WORDS: usize = WINDOW_SIZE / u64::BITS as usize;
/// `WINDOW_SIZE` as a `u32`, for arithmetic against raw doc ids.
const WINDOW_SIZE_U32: u32 = WINDOW_SIZE as u32;
/// Disjunction scorer that ORs its sub-scorers together through a fixed-size bitset window.
///
/// Docs are gathered `WINDOW_SIZE` at a time. Every sub-scorer is drained of the docs that
/// fall in `[window_base, window_base + WINDOW_SIZE)` into `bitset`. The window is then
/// replayed in increasing doc order.
pub struct BufferedUnionScorer {
    /// Constant score reported for every matching doc.
    boost: f32,
    /// Doc the union is positioned on, or `NO_MORE_DOCS` once exhausted.
    current: DocId,
    /// First doc id covered by `bitset`.
    window_base: u32,
    /// Window offset at which the next scan for a set bit begins.
    cursor: usize,
    /// One bit per doc of the window; bit `i` set means doc `window_base + i` matched.
    bitset: [u64; WINDOW_WORDS],
    /// Sub-scorers whose docs are merged into the union.
    scorers: Vec<Box<dyn Scorer>>,
}
impl BufferedUnionScorer {
    /// Creates a union over `scorers` in which every match scores `1.0`.
    pub fn new(scorers: Vec<Box<dyn Scorer>>) -> Self {
        Self::with_boost(scorers, 1.0)
    }

    /// Creates a union over `scorers` in which every match scores `boost`.
    ///
    /// Sub-scorers that already report `NO_MORE_DOCS` are dropped up front. If any remain,
    /// the union is positioned on its first matching doc before returning. Otherwise it
    /// starts exhausted.
    pub fn with_boost(mut scorers: Vec<Box<dyn Scorer>>, boost: f32) -> Self {
        scorers.retain(|s| s.doc_id() != NO_MORE_DOCS);
        let mut union = Self {
            scorers,
            bitset: [0u64; WINDOW_WORDS],
            window_base: 0,
            current: NO_MORE_DOCS,
            cursor: 0,
            boost,
        };
        if !union.scorers.is_empty() {
            union.advance_to_next_set_bit();
        }
        union
    }

    /// Returns the smallest current doc among the sub-scorers, or `NO_MORE_DOCS` if none has one.
    fn min_scorer_doc(&self) -> DocId {
        self.scorers
            .iter()
            .map(|s| s.doc_id())
            .fold(NO_MORE_DOCS, |min, d| {
                if d != NO_MORE_DOCS && d < min {
                    d
                } else {
                    min
                }
            })
    }

    /// Rebuilds the bitset for the window `[base, base + WINDOW_SIZE)` (saturating at `u32::MAX`).
    ///
    /// Each sub-scorer is first advanced to at least `base`. It is then stepped through every
    /// doc inside the window, setting the matching bit. Afterwards each sub-scorer sits at or
    /// beyond the window end, or is exhausted. Exhausted sub-scorers are removed so later min
    /// scans and refills do not keep visiting them.
    fn refill_from(&mut self, base: u32) {
        self.window_base = base;
        self.bitset = [0u64; WINDOW_WORDS];
        self.cursor = 0;
        let window_end = base.saturating_add(WINDOW_SIZE_U32);
        for scorer in &mut self.scorers {
            let mut doc = scorer.doc_id();
            if doc != NO_MORE_DOCS && doc.as_u32() < base {
                doc = scorer.advance(DocId::new(base));
            }
            while doc != NO_MORE_DOCS && doc.as_u32() < window_end {
                let idx = (doc.as_u32() - base) as usize;
                self.bitset[idx / 64] |= 1u64 << (idx % 64);
                doc = scorer.next();
            }
        }
        // A sub-scorer that ran out can never contribute again; stop dispatching to it.
        self.scorers.retain(|s| s.doc_id() != NO_MORE_DOCS);
    }

    /// Returns the window offset of the first set bit at or after `cursor`, if any.
    fn next_set_bit_in_window(&self) -> Option<usize> {
        let first = self.cursor / 64;
        if first >= WINDOW_WORDS {
            return None;
        }
        // Only the first word can contain bits below the cursor; mask those off.
        let head = self.bitset[first] & (u64::MAX << (self.cursor % 64));
        if head != 0 {
            return Some(first * 64 + head.trailing_zeros() as usize);
        }
        self.bitset
            .iter()
            .enumerate()
            .skip(first + 1)
            .find_map(|(i, &w)| (w != 0).then(|| i * 64 + w.trailing_zeros() as usize))
    }

    /// Positions the union on the next set bit of the current window.
    ///
    /// Returns `false`, leaving all state untouched, when no set bit remains at or after `cursor`.
    fn take_next_set_bit(&mut self) -> bool {
        match self.next_set_bit_in_window() {
            Some(idx) => {
                self.cursor = idx + 1;
                self.current = DocId::new(self.window_base + idx as u32);
                true
            }
            None => false,
        }
    }

    /// Moves `current` to the next matching doc, refilling windows as needed.
    ///
    /// Sets `current` to `NO_MORE_DOCS` once every sub-scorer is exhausted.
    fn advance_to_next_set_bit(&mut self) {
        while !self.take_next_set_bit() {
            let next_min = self.min_scorer_doc();
            if next_min == NO_MORE_DOCS {
                self.current = NO_MORE_DOCS;
                return;
            }
            // The window starts at the smallest pending doc, so it always yields a bit.
            self.refill_from(next_min.as_u32());
        }
    }
}
impl Scorer for BufferedUnionScorer {
    /// Returns the doc the union is currently positioned on.
    fn doc_id(&self) -> DocId {
        self.current
    }

    /// Steps to the next doc matched by any sub-scorer; stays at `NO_MORE_DOCS` once exhausted.
    fn next(&mut self) -> DocId {
        if self.current != NO_MORE_DOCS {
            self.advance_to_next_set_bit();
        }
        self.current
    }

    /// Moves to the first matching doc `>= target`. This is a no-op when already there or exhausted.
    fn advance(&mut self, target: DocId) -> DocId {
        if self.current == NO_MORE_DOCS || self.current >= target {
            return self.current;
        }
        let target_raw = target.as_u32();
        let base = self.window_base;
        let inside_window =
            target_raw >= base && target_raw < base.saturating_add(WINDOW_SIZE_U32);
        if inside_window {
            // The buffered window already holds every doc up to its end; just skip ahead in it.
            self.cursor = (target_raw - base) as usize;
        } else {
            // Start a fresh window at the target; the search below falls through to later
            // windows if this one is empty.
            self.refill_from(target_raw);
        }
        self.advance_to_next_set_bit();
        self.current
    }

    /// Every matching doc scores the constant boost.
    fn score(&mut self) -> f32 {
        self.boost
    }

    /// This scorer has no approximation phase.
    fn two_phase(&mut self) -> Option<&mut dyn TwoPhaseIterator> {
        None
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Test sub-scorer backed by a sorted list of doc ids, positioned on its first doc.
    struct VecScorer {
        docs: Vec<DocId>,
        pos: usize,
    }

    impl VecScorer {
        fn new(docs: Vec<u32>) -> Box<dyn Scorer> {
            Box::new(Self {
                docs: docs.into_iter().map(DocId::new).collect(),
                pos: 0,
            })
        }
    }

    impl Scorer for VecScorer {
        fn doc_id(&self) -> DocId {
            if self.pos < self.docs.len() {
                self.docs[self.pos]
            } else {
                NO_MORE_DOCS
            }
        }
        fn next(&mut self) -> DocId {
            if self.pos < self.docs.len() {
                self.pos += 1;
            }
            self.doc_id()
        }
        fn advance(&mut self, target: DocId) -> DocId {
            while self.pos < self.docs.len() && self.docs[self.pos] < target {
                self.pos += 1;
            }
            self.doc_id()
        }
        fn score(&mut self) -> f32 {
            1.0
        }
        fn two_phase(&mut self) -> Option<&mut dyn TwoPhaseIterator> {
            None
        }
    }

    /// Drains `scorer` from its current position and returns every doc it yields.
    fn collect(scorer: &mut dyn Scorer) -> Vec<u32> {
        let mut out = Vec::new();
        while scorer.doc_id() != NO_MORE_DOCS {
            out.push(scorer.doc_id().as_u32());
            scorer.next();
        }
        out
    }

    #[test]
    fn buffered_union_two_scorers() {
        let s1 = VecScorer::new(vec![0, 2, 4]);
        let s2 = VecScorer::new(vec![1, 2, 3]);
        let mut union = BufferedUnionScorer::new(vec![s1, s2]);
        assert_eq!(collect(&mut union), vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn buffered_union_three_scorers_overlap() {
        let s1 = VecScorer::new(vec![5]);
        let s2 = VecScorer::new(vec![5]);
        let s3 = VecScorer::new(vec![5]);
        let mut union = BufferedUnionScorer::new(vec![s1, s2, s3]);
        assert_eq!(collect(&mut union), vec![5]);
    }

    #[test]
    fn buffered_union_single_scorer() {
        let s = VecScorer::new(vec![0, 1, 2]);
        let mut union = BufferedUnionScorer::new(vec![s]);
        assert_eq!(collect(&mut union), vec![0, 1, 2]);
    }

    #[test]
    fn buffered_union_no_overlap() {
        let s1 = VecScorer::new(vec![0, 2, 4, 6]);
        let s2 = VecScorer::new(vec![1, 3, 5, 7]);
        let mut union = BufferedUnionScorer::new(vec![s1, s2]);
        assert_eq!(collect(&mut union), vec![0, 1, 2, 3, 4, 5, 6, 7]);
    }

    #[test]
    fn buffered_union_empty_scorer_filtered() {
        let s1 = VecScorer::new(vec![5]);
        let s2 = VecScorer::new(vec![]);
        let mut union = BufferedUnionScorer::new(vec![s1, s2]);
        assert_eq!(collect(&mut union), vec![5]);
    }

    #[test]
    fn buffered_union_no_scorers() {
        let mut union = BufferedUnionScorer::new(vec![]);
        assert_eq!(union.doc_id(), NO_MORE_DOCS);
        assert_eq!(union.next(), NO_MORE_DOCS);
        assert_eq!(union.advance(DocId::new(10)), NO_MORE_DOCS);
    }

    #[test]
    fn buffered_union_all_scorers_empty() {
        let mut union =
            BufferedUnionScorer::new(vec![VecScorer::new(vec![]), VecScorer::new(vec![])]);
        assert_eq!(union.doc_id(), NO_MORE_DOCS);
        assert!(collect(&mut union).is_empty());
    }

    #[test]
    fn buffered_union_advance_within_window() {
        let s1 = VecScorer::new(vec![0, 5, 10, 15]);
        let s2 = VecScorer::new(vec![1, 6, 11, 16]);
        let mut union = BufferedUnionScorer::new(vec![s1, s2]);
        assert_eq!(union.doc_id(), DocId::new(0));
        assert_eq!(union.advance(DocId::new(8)), DocId::new(10));
        assert_eq!(union.next(), DocId::new(11));
        assert_eq!(union.next(), DocId::new(15));
        assert_eq!(union.next(), DocId::new(16));
        assert_eq!(union.next(), NO_MORE_DOCS);
    }

    #[test]
    fn buffered_union_next_across_windows() {
        let s1 = VecScorer::new(vec![0, 5000, 10000]);
        let s2 = VecScorer::new(vec![100, 5100, 10100]);
        let mut union = BufferedUnionScorer::new(vec![s1, s2]);
        assert_eq!(union.doc_id(), DocId::new(0));
        assert_eq!(union.next(), DocId::new(100));
        assert_eq!(union.next(), DocId::new(5000));
        assert_eq!(union.next(), DocId::new(5100));
        assert_eq!(union.next(), DocId::new(10000));
        assert_eq!(union.next(), DocId::new(10100));
        assert_eq!(union.next(), NO_MORE_DOCS);
    }

    #[test]
    fn buffered_union_advance_across_windows() {
        // The target lies outside the initial window, exercising the refill-from-target path.
        let s1 = VecScorer::new(vec![0, 5000, 10000]);
        let s2 = VecScorer::new(vec![100, 5100, 10100]);
        let mut union = BufferedUnionScorer::new(vec![s1, s2]);
        assert_eq!(union.advance(DocId::new(5050)), DocId::new(5100));
        assert_eq!(union.next(), DocId::new(10000));
        assert_eq!(union.next(), DocId::new(10100));
        assert_eq!(union.next(), NO_MORE_DOCS);
    }

    #[test]
    fn buffered_union_advance_lands_on_window_start() {
        let s1 = VecScorer::new(vec![10, 3000, 3001]);
        let mut union = BufferedUnionScorer::new(vec![s1]);
        assert_eq!(union.advance(DocId::new(3000)), DocId::new(3000));
        assert_eq!(union.next(), DocId::new(3001));
        assert_eq!(union.next(), NO_MORE_DOCS);
    }

    #[test]
    fn buffered_union_window_jump() {
        let s1 = VecScorer::new(vec![5, 100000]);
        let s2 = VecScorer::new(vec![10, 100010]);
        let mut union = BufferedUnionScorer::new(vec![s1, s2]);
        assert_eq!(collect(&mut union), vec![5, 10, 100000, 100010]);
    }

    #[test]
    fn buffered_union_advance_to_existing_doc() {
        let s1 = VecScorer::new(vec![0, 5, 10]);
        let s2 = VecScorer::new(vec![1, 6, 11]);
        let mut union = BufferedUnionScorer::new(vec![s1, s2]);
        assert_eq!(union.advance(DocId::new(5)), DocId::new(5));
        assert_eq!(union.next(), DocId::new(6));
    }

    #[test]
    fn buffered_union_advance_to_current_is_noop() {
        let s1 = VecScorer::new(vec![3, 7]);
        let mut union = BufferedUnionScorer::new(vec![s1]);
        assert_eq!(union.advance(DocId::new(3)), DocId::new(3));
        assert_eq!(union.advance(DocId::new(1)), DocId::new(3));
        assert_eq!(union.next(), DocId::new(7));
    }

    #[test]
    fn buffered_union_advance_past_end() {
        let s1 = VecScorer::new(vec![0, 5]);
        let mut union = BufferedUnionScorer::new(vec![s1]);
        assert_eq!(union.advance(DocId::new(100)), NO_MORE_DOCS);
    }

    #[test]
    fn buffered_union_advance_past_end_across_windows() {
        let s1 = VecScorer::new(vec![0, 5000]);
        let mut union = BufferedUnionScorer::new(vec![s1]);
        assert_eq!(union.advance(DocId::new(6000)), NO_MORE_DOCS);
        assert_eq!(union.next(), NO_MORE_DOCS);
    }

    #[test]
    fn buffered_union_score_is_constant() {
        let s1 = VecScorer::new(vec![0, 1, 2]);
        let mut union = BufferedUnionScorer::new(vec![s1]);
        assert_eq!(union.score(), 1.0);
        union.next();
        assert_eq!(union.score(), 1.0);
        union.next();
        assert_eq!(union.score(), 1.0);
    }

    #[test]
    fn buffered_union_custom_boost() {
        let s1 = VecScorer::new(vec![0]);
        let mut union = BufferedUnionScorer::with_boost(vec![s1], 2.5);
        assert_eq!(union.score(), 2.5);
    }

    #[test]
    fn buffered_union_dense_window() {
        let docs: Vec<u32> = (0..100).collect();
        let s = VecScorer::new(docs.clone());
        let mut union = BufferedUnionScorer::new(vec![s]);
        assert_eq!(collect(&mut union), docs);
    }

    #[test]
    fn buffered_union_window_boundary() {
        let s1 = VecScorer::new(vec![2047, 2048, 2049]);
        let mut union = BufferedUnionScorer::new(vec![s1]);
        assert_eq!(collect(&mut union), vec![2047, 2048, 2049]);
    }

    #[test]
    fn buffered_union_window_edge_from_zero() {
        // The first window is [0, 2048), so 2048 and 4096 each start a new window.
        let docs = vec![0, 2047, 2048, 4095, 4096];
        let s1 = VecScorer::new(docs.clone());
        let mut union = BufferedUnionScorer::new(vec![s1]);
        assert_eq!(collect(&mut union), docs);
    }

    #[test]
    fn buffered_union_many_scorers() {
        let scorers: Vec<Box<dyn Scorer>> = (0..10)
            .map(|i| VecScorer::new(vec![i * 100, i * 100 + 50, i * 100 + 99]))
            .collect();
        let mut union = BufferedUnionScorer::new(scorers);
        let docs = collect(&mut union);
        assert_eq!(docs.len(), 30);
        for w in docs.windows(2) {
            assert!(w[0] < w[1], "not sorted: {:?}", w);
        }
    }
}