use crate::core::{DocId, NO_MORE_DOCS, Scorer, TwoPhaseIterator};
/// Scorer that matches only the documents matched by every one of its
/// sub-scorers (a logical AND).
pub struct ConjunctionScorer {
    /// Sub-scorers that must all agree on a document. The scorer at index 0
    /// is the one stepped first when the conjunction moves forward.
    scorers: Vec<Box<dyn Scorer>>,
    /// Document the conjunction is currently positioned on, or
    /// `NO_MORE_DOCS` when no further common document exists.
    current: DocId,
}
impl ConjunctionScorer {
    /// Builds a conjunction over `scorers` and positions it on the first
    /// document that every sub-scorer matches.
    ///
    /// After construction `doc_id()` reports that document, or
    /// `NO_MORE_DOCS` when the sub-scorers have no document in common.
    ///
    /// # Panics
    ///
    /// Panics if `scorers` is empty.
    pub fn new(scorers: Vec<Box<dyn Scorer>>) -> Self {
        assert!(!scorers.is_empty(), "conjunction needs at least one scorer");
        let mut conjunction = Self {
            scorers,
            // Placeholder only; replaced by the alignment pass right below.
            current: NO_MORE_DOCS,
        };
        // Share the single alignment routine instead of duplicating its loop.
        conjunction.current = conjunction.advance_to_common();
        conjunction
    }

    /// Moves the sub-scorers forward until they all sit on the same document.
    ///
    /// The search starts from the current document of the scorer at index 0.
    /// Every other scorer is advanced to that candidate; whenever one of them
    /// lands past it, the scorer at index 0 is advanced to the new position
    /// and the check restarts from the first of the other scorers.
    ///
    /// Returns the common document, or `NO_MORE_DOCS` as soon as any scorer
    /// runs out (or when there are no scorers at all). This method does not
    /// touch `self.current`; callers store the returned value.
    fn advance_to_common(&mut self) -> DocId {
        let (lead, followers) = match self.scorers.split_first_mut() {
            Some(parts) => parts,
            None => return NO_MORE_DOCS,
        };

        let mut candidate = lead.doc_id();
        'align: loop {
            if candidate == NO_MORE_DOCS {
                return NO_MORE_DOCS;
            }
            for follower in followers.iter_mut() {
                let doc = follower.advance(candidate);
                if doc != candidate {
                    if doc == NO_MORE_DOCS {
                        return NO_MORE_DOCS;
                    }
                    // `follower` skipped past the candidate: move the lead up
                    // to it and re-check every follower against the new one.
                    candidate = lead.advance(doc);
                    continue 'align;
                }
            }
            // Every follower agreed with the lead.
            return candidate;
        }
    }
}
impl Scorer for ConjunctionScorer {
    /// Returns the document the conjunction is currently positioned on.
    fn doc_id(&self) -> DocId {
        self.current
    }

    /// Steps to the next document matched by all sub-scorers.
    ///
    /// Once the conjunction has reported `NO_MORE_DOCS` this is a no-op that
    /// keeps returning `NO_MORE_DOCS`.
    fn next(&mut self) -> DocId {
        if self.current != NO_MORE_DOCS {
            // Step the first scorer off the current match, then realign.
            self.scorers[0].next();
            self.current = self.advance_to_common();
        }
        self.current
    }

    /// Advances the first sub-scorer to `target`, realigns the rest, and
    /// returns the resulting common document (or `NO_MORE_DOCS`).
    fn advance(&mut self, target: DocId) -> DocId {
        self.scorers[0].advance(target);
        let aligned = self.advance_to_common();
        self.current = aligned;
        aligned
    }

    /// Sum of the scores reported by every sub-scorer.
    fn score(&mut self) -> f32 {
        let per_clause = self.scorers.iter_mut().map(|clause| clause.score());
        per_clause.sum()
    }

    /// The conjunction does not expose a two-phase iterator.
    fn two_phase(&mut self) -> Option<&mut dyn TwoPhaseIterator> {
        None
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Test double that replays a fixed list of `(doc, score)` entries.
    ///
    /// `pos` indexes the entry the scorer currently sits on; once it reaches
    /// `docs.len()` the scorer reports `NO_MORE_DOCS` with a score of `0.0`.
    struct VecScorer {
        docs: Vec<(DocId, f32)>,
        pos: usize,
    }

    impl VecScorer {
        /// Wraps raw `(id, score)` pairs in a boxed scorer positioned on the
        /// first entry.
        fn new(docs: Vec<(u32, f32)>) -> Box<dyn Scorer> {
            let docs = docs
                .into_iter()
                .map(|(raw_id, score)| (DocId::new(raw_id), score))
                .collect::<Vec<_>>();
            Box::new(Self { docs, pos: 0 })
        }

        /// Entry under the cursor, or the exhausted sentinel pair.
        fn current(&self) -> (DocId, f32) {
            match self.docs.get(self.pos) {
                Some(&entry) => entry,
                None => (NO_MORE_DOCS, 0.0),
            }
        }
    }

    impl Scorer for VecScorer {
        fn doc_id(&self) -> DocId {
            let (doc, _) = self.current();
            doc
        }

        fn next(&mut self) -> DocId {
            let exhausted = self.pos >= self.docs.len();
            if !exhausted {
                self.pos += 1;
            }
            self.doc_id()
        }

        fn advance(&mut self, target: DocId) -> DocId {
            // Count the leading entries still below `target` and skip them.
            let behind = self.docs[self.pos..]
                .iter()
                .take_while(|entry| entry.0 < target)
                .count();
            self.pos += behind;
            self.doc_id()
        }

        fn score(&mut self) -> f32 {
            let (_, score) = self.current();
            score
        }

        fn two_phase(&mut self) -> Option<&mut dyn TwoPhaseIterator> {
            None
        }
    }

    #[test]
    fn two_scorers_intersection() {
        let evens = VecScorer::new(vec![(0, 1.0), (2, 1.0), (4, 1.0), (6, 1.0)]);
        let mixed = VecScorer::new(vec![(1, 2.0), (2, 2.0), (5, 2.0), (6, 2.0)]);
        let mut conj = ConjunctionScorer::new(vec![evens, mixed]);

        assert_eq!(conj.doc_id(), DocId::new(2));
        assert_eq!(conj.score(), 3.0);

        assert_eq!(conj.next(), DocId::new(6));
        assert_eq!(conj.score(), 3.0);

        assert_eq!(conj.next(), NO_MORE_DOCS);
    }

    #[test]
    fn three_scorers() {
        let first = VecScorer::new(vec![(1, 1.0), (3, 1.0), (5, 1.0), (7, 1.0)]);
        let second = VecScorer::new(vec![(3, 2.0), (5, 2.0), (7, 2.0), (9, 2.0)]);
        let third = VecScorer::new(vec![(5, 3.0), (7, 3.0)]);
        let mut conj = ConjunctionScorer::new(vec![first, second, third]);

        assert_eq!(conj.doc_id(), DocId::new(5));
        assert_eq!(conj.score(), 6.0);
        assert_eq!(conj.next(), DocId::new(7));
        assert_eq!(conj.next(), NO_MORE_DOCS);
    }

    #[test]
    fn no_intersection() {
        let evens = VecScorer::new(vec![(0, 1.0), (2, 1.0)]);
        let odds = VecScorer::new(vec![(1, 2.0), (3, 2.0)]);
        let conj = ConjunctionScorer::new(vec![evens, odds]);

        assert_eq!(conj.doc_id(), NO_MORE_DOCS);
    }

    #[test]
    fn advance_conjunction() {
        let left = VecScorer::new(vec![(0, 1.0), (5, 1.0), (10, 1.0), (15, 1.0)]);
        let right = VecScorer::new(vec![(5, 2.0), (10, 2.0), (15, 2.0), (20, 2.0)]);
        let mut conj = ConjunctionScorer::new(vec![left, right]);

        assert_eq!(conj.doc_id(), DocId::new(5));
        assert_eq!(conj.advance(DocId::new(12)), DocId::new(15));
        assert_eq!(conj.next(), NO_MORE_DOCS);
    }

    #[test]
    fn single_scorer() {
        let only = VecScorer::new(vec![(0, 1.0), (1, 2.0)]);
        let mut conj = ConjunctionScorer::new(vec![only]);

        assert_eq!(conj.doc_id(), DocId::new(0));
        assert_eq!(conj.next(), DocId::new(1));
        assert_eq!(conj.next(), NO_MORE_DOCS);
    }
}