use std::fmt;
use std::io;
use std::rc::Rc;
use super::collector::{DocAndFloatFeatureBuffer, LeafCollector, ScoreContext, ScoreMode};
use super::doc_id_set_iterator::DocIdSetIterator;
use super::index_searcher::IndexSearcher;
use super::scorer::Scorer;
use crate::index::directory_reader::LeafReaderContext;
/// A search query that can be compiled into an executable [`Weight`].
pub trait Query: fmt::Debug {
/// Builds the [`Weight`] for this query from the given searcher, score mode
/// and boost factor.
///
/// # Errors
/// Propagates any `io::Error` raised while the weight is being built.
fn create_weight(
&self,
searcher: &IndexSearcher,
score_mode: ScoreMode,
boost: f32,
) -> io::Result<Box<dyn Weight>>;
/// Optionally produces a replacement for this query.
///
/// The default implementation ignores the searcher and returns `Ok(None)`.
/// NOTE(review): `None` presumably means "no rewrite needed, keep using
/// `self`" — confirm against the callers of `rewrite`.
fn rewrite(&self, _searcher: &IndexSearcher) -> io::Result<Option<Box<dyn Query>>> {
Ok(None)
}
}
/// Drives a [`LeafCollector`] over a range of documents in one call.
pub trait BulkScorer: fmt::Debug {
/// Feeds matching documents to `collector`, bounded by `min` and `max`.
///
/// The implementations in this file position themselves on the first
/// document at or after `min`, collect while the current document id is
/// below `max`, and return the document id they stopped on.
///
/// # Errors
/// Propagates any `io::Error` raised by the scorer or the collector.
fn score(&mut self, collector: &mut dyn LeafCollector, min: i32, max: i32) -> io::Result<i32>;
/// Returns a cost figure for this bulk scorer.
///
/// NOTE(review): presumably an estimate of the number of documents that
/// will be visited; both implementations in this file return `0` — confirm
/// what callers expect.
fn cost(&self) -> i64;
}
/// Deferred factory for a [`Scorer`] (or a [`BulkScorer`]) bound to lifetime `'a`.
pub trait ScorerSupplier<'a>: fmt::Debug + 'a {
    /// Produces the scorer. `lead_cost` is forwarded by callers; the default
    /// code paths in this file pass `i64::MAX`.
    ///
    /// # Errors
    /// Propagates any `io::Error` raised while creating the scorer.
    fn get(&mut self, lead_cost: i64) -> io::Result<Box<dyn Scorer + 'a>>;

    /// Produces a bulk scorer.
    ///
    /// The default obtains a scorer via `get(i64::MAX)` and wraps it in a
    /// [`DefaultBulkScorer`].
    fn bulk_scorer(&mut self) -> io::Result<Box<dyn BulkScorer + 'a>> {
        self.get(i64::MAX)
            .map(|scorer| -> Box<dyn BulkScorer + 'a> {
                Box::new(DefaultBulkScorer::new(scorer))
            })
    }

    /// Returns the cost figure associated with the scorer this supplier produces.
    fn cost(&self) -> i64;

    /// Hook invoked by [`Weight::bulk_scorer`] before it requests the bulk
    /// scorer. The default does nothing and reports success.
    fn set_top_level_scoring_clause(&mut self) -> io::Result<()> {
        Ok(())
    }
}
/// Per-search compiled form of a query that hands out scorers for one leaf
/// (segment) at a time.
pub trait Weight: fmt::Debug {
    /// Returns a supplier of scorers for `context`, or `Ok(None)` when no
    /// supplier is available for that leaf.
    fn scorer_supplier<'a>(
        &self,
        context: &'a LeafReaderContext,
    ) -> io::Result<Option<Box<dyn ScorerSupplier<'a> + 'a>>>;

    /// Convenience wrapper: asks the supplier (if any) for a scorer using a
    /// lead cost of `i64::MAX`.
    fn scorer<'a>(
        &self,
        context: &'a LeafReaderContext,
    ) -> io::Result<Option<Box<dyn Scorer + 'a>>> {
        let Some(mut supplier) = self.scorer_supplier(context)? else {
            return Ok(None);
        };
        let scorer = supplier.get(i64::MAX)?;
        Ok(Some(scorer))
    }

    /// Convenience wrapper: marks the supplier (if any) as the top-level
    /// scoring clause and then asks it for a bulk scorer.
    fn bulk_scorer<'a>(
        &self,
        context: &'a LeafReaderContext,
    ) -> io::Result<Option<Box<dyn BulkScorer + 'a>>> {
        let Some(mut supplier) = self.scorer_supplier(context)? else {
            return Ok(None);
        };
        supplier.set_top_level_scoring_clause()?;
        let bulk = supplier.bulk_scorer()?;
        Ok(Some(bulk))
    }

    /// Returns a match count for `context`.
    ///
    /// The default ignores the context and returns `-1`.
    /// NOTE(review): `-1` presumably means "count not cheaply available" —
    /// confirm against callers.
    fn count(&self, _context: &LeafReaderContext) -> io::Result<i32> {
        Ok(-1)
    }
}
/// [`BulkScorer`] that walks a single boxed [`Scorer`] one document at a time.
pub struct DefaultBulkScorer<'a> {
// The wrapped scorer; its iterator supplies the documents that get collected.
scorer: Box<dyn Scorer + 'a>,
}
impl fmt::Debug for DefaultBulkScorer<'_> {
    /// Prints only the type name; the wrapped scorer is not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut repr = f.debug_struct("DefaultBulkScorer");
        repr.finish()
    }
}
impl<'a> DefaultBulkScorer<'a> {
pub fn new(scorer: Box<dyn Scorer + 'a>) -> Self {
Self { scorer }
}
fn score_iterator(
collector: &mut dyn LeafCollector,
iterator: &mut dyn DocIdSetIterator,
max: i32,
) -> io::Result<()> {
let mut doc = iterator.doc_id();
while doc < max {
collector.collect(doc)?;
doc = iterator.next_doc()?;
}
Ok(())
}
fn score_competitive_iterator(
collector: &mut dyn LeafCollector,
iterator: &mut dyn DocIdSetIterator,
competitive_iterator: &mut dyn DocIdSetIterator,
max: i32,
) -> io::Result<()> {
let mut doc = iterator.doc_id();
while doc < max {
debug_assert!(competitive_iterator.doc_id() <= doc);
if competitive_iterator.doc_id() < doc {
let competitive_next = competitive_iterator.advance(doc)?;
if competitive_next != doc {
doc = iterator.advance(competitive_next)?;
continue;
}
}
collector.collect(doc)?;
doc = iterator.next_doc()?;
}
Ok(())
}
}
impl BulkScorer for DefaultBulkScorer<'_> {
    /// Collects the wrapped scorer's documents in `[min, max)`, publishing each
    /// document's score through a fresh [`ScoreContext`] before it is collected.
    ///
    /// If the collector supplies a competitive iterator, documents it does not
    /// land on are skipped. Returns the scorer's document id after the loop.
    fn score(&mut self, collector: &mut dyn LeafCollector, min: i32, max: i32) -> io::Result<i32> {
        let ctx = ScoreContext::new();
        collector.set_scorer(Rc::clone(&ctx))?;
        let competitive = collector.competitive_iterator();

        // Position the scorer on the first document >= min. A single
        // `next_doc` is used when the scorer sits directly before `min`.
        if self.scorer.iterator().doc_id() < min {
            let directly_before = self.scorer.iterator().doc_id() == min - 1;
            if directly_before {
                self.scorer.iterator().next_doc()?;
            } else {
                self.scorer.iterator().advance(min)?;
            }
        }

        // Plain path: no competitive iterator, score and collect every hit.
        let Some(mut competitive) = competitive else {
            while self.scorer.doc_id() < max {
                ctx.score.set(self.scorer.score()?);
                collector.collect(self.scorer.doc_id())?;
                self.scorer.iterator().next_doc()?;
            }
            return Ok(self.scorer.doc_id());
        };

        // The competitive iterator may already be past `min`; start there,
        // but never beyond `max`.
        let competitive_start = competitive.doc_id();
        let lower_bound = if competitive_start > min {
            competitive_start.min(max)
        } else {
            min
        };
        if self.scorer.iterator().doc_id() < lower_bound {
            self.scorer.iterator().advance(lower_bound)?;
        }

        while self.scorer.doc_id() < max {
            debug_assert!(competitive.doc_id() <= self.scorer.doc_id());
            if competitive.doc_id() < self.scorer.doc_id() {
                let target = competitive.advance(self.scorer.doc_id())?;
                if target != self.scorer.doc_id() {
                    // Current document is not competitive: jump forward.
                    self.scorer.iterator().advance(target)?;
                    continue;
                }
            }
            ctx.score.set(self.scorer.score()?);
            collector.collect(self.scorer.doc_id())?;
            self.scorer.iterator().next_doc()?;
        }
        Ok(self.scorer.doc_id())
    }

    /// Always reports `0`.
    fn cost(&self) -> i64 {
        0
    }
}
/// [`ScorerSupplier`] holding an already-built scorer that can be handed out once.
pub struct DefaultScorerSupplier<'a> {
// `Some` until `get` takes the scorer out; `None` afterwards.
scorer: Option<Box<dyn Scorer + 'a>>,
// Value reported by `cost()`, fixed at construction.
cost: i64,
}
impl fmt::Debug for DefaultScorerSupplier<'_> {
    /// Prints the type name and the `cost` field; the scorer is not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut repr = f.debug_struct("DefaultScorerSupplier");
        repr.field("cost", &self.cost);
        repr.finish()
    }
}
impl<'a> DefaultScorerSupplier<'a> {
pub fn new(scorer: Box<dyn Scorer + 'a>, cost: i64) -> Self {
Self {
scorer: Some(scorer),
cost,
}
}
}
impl<'a> ScorerSupplier<'a> for DefaultScorerSupplier<'a> {
    /// Hands out the stored scorer; `_lead_cost` is ignored.
    ///
    /// # Errors
    /// Returns an `io::Error` if the scorer has already been taken by an
    /// earlier call.
    fn get(&mut self, _lead_cost: i64) -> io::Result<Box<dyn Scorer + 'a>> {
        match self.scorer.take() {
            Some(scorer) => Ok(scorer),
            None => Err(io::Error::other(
                "ScorerSupplier.get() called more than once",
            )),
        }
    }

    /// Returns the cost given at construction.
    fn cost(&self) -> i64 {
        self.cost
    }
}
/// [`BulkScorer`] that pulls documents and scores from its scorer in batches.
pub struct BatchScoreBulkScorer<'a> {
// The wrapped scorer that produces the batches.
scorer: Box<dyn Scorer + 'a>,
// Buffer handed to `next_docs_and_scores` on every batch; its `size`, `docs`
// and `features` fields are read back after each call.
buffer: DocAndFloatFeatureBuffer,
}
impl fmt::Debug for BatchScoreBulkScorer<'_> {
    /// Prints only the type name; scorer and buffer are not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut repr = f.debug_struct("BatchScoreBulkScorer");
        repr.finish()
    }
}
impl<'a> BatchScoreBulkScorer<'a> {
pub fn new(scorer: Box<dyn Scorer + 'a>) -> Self {
Self {
scorer,
buffer: DocAndFloatFeatureBuffer::new(),
}
}
}
impl BulkScorer for BatchScoreBulkScorer<'_> {
    /// Collects the wrapped scorer's documents in `[min, max)`.
    ///
    /// Two strategies are used:
    ///
    /// * **Collector has a competitive iterator** — fall back to
    ///   document-at-a-time scoring. Each document that the competitive
    ///   iterator also lands on has its score published through the
    ///   [`ScoreContext`] and is then collected.
    /// * **No competitive iterator** — pull batches of documents and scores via
    ///   `next_docs_and_scores`, collect every entry whose score is at least
    ///   the collector's current minimum competitive score, and forward that
    ///   minimum to the scorer after each batch.
    ///
    /// Returns the scorer's document id after the loop.
    ///
    /// # Errors
    /// Propagates any `io::Error` raised by the scorer, the competitive
    /// iterator or the collector.
    fn score(&mut self, collector: &mut dyn LeafCollector, min: i32, max: i32) -> io::Result<i32> {
        let score_context = ScoreContext::new();
        // Hand the context over first, then ask for the competitive iterator
        // exactly once (same order as `DefaultBulkScorer::score`).
        collector.set_scorer(Rc::clone(&score_context))?;
        let competitive_iterator = collector.competitive_iterator();

        if let Some(mut ci) = competitive_iterator {
            // Document-at-a-time fallback. The iterator-only helpers on
            // `DefaultBulkScorer` are deliberately not used here: they collect
            // without publishing a score, so the collector would read a stale
            // value from the score context for every hit.
            let start = self.scorer.iterator().doc_id();
            if start < min {
                if start == min - 1 {
                    self.scorer.iterator().next_doc()?;
                } else {
                    self.scorer.iterator().advance(min)?;
                }
            }

            // The competitive iterator may already be past `min`; start
            // there, but never beyond `max`.
            let ci_doc = ci.doc_id();
            let effective_min = if ci_doc > min { ci_doc.min(max) } else { min };
            if self.scorer.iterator().doc_id() < effective_min {
                self.scorer.iterator().advance(effective_min)?;
            }

            let mut doc = self.scorer.iterator().doc_id();
            while doc < max {
                debug_assert!(ci.doc_id() <= doc);
                if ci.doc_id() < doc {
                    let competitive_next = ci.advance(doc)?;
                    if competitive_next != doc {
                        // `doc` is not competitive: skip ahead without scoring.
                        doc = self.scorer.iterator().advance(competitive_next)?;
                        continue;
                    }
                }
                // Publish the score before the collector sees the document.
                score_context.score.set(self.scorer.score()?);
                collector.collect(doc)?;
                doc = self.scorer.iterator().next_doc()?;
            }
            return Ok(self.scorer.iterator().doc_id());
        }

        // Batch path.
        self.scorer
            .set_min_competitive_score(score_context.min_competitive_score.get())?;
        if self.scorer.doc_id() < min {
            self.scorer.iterator().advance(min)?;
        }
        loop {
            self.scorer.next_docs_and_scores(max, &mut self.buffer)?;
            if self.buffer.size == 0 {
                // An empty batch ends the loop.
                break;
            }
            let size = self.buffer.size;
            for i in 0..size {
                let score = self.buffer.features[i];
                score_context.score.set(score);
                // The collector may raise the bar while collecting, so the
                // threshold is re-read for every entry.
                if score >= score_context.min_competitive_score.get() {
                    collector.collect(self.buffer.docs[i])?;
                }
            }
            self.scorer
                .set_min_competitive_score(score_context.min_competitive_score.get())?;
        }
        Ok(self.scorer.doc_id())
    }

    /// Always reports `0`.
    ///
    /// NOTE(review): the wrapped scorer's iterator exposes a `cost()`, but it
    /// is only reachable through `&mut self`; confirm whether `0` is intended.
    fn cost(&self) -> i64 {
        0
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::search::doc_id_set_iterator::NO_MORE_DOCS;
    use crate::search::scorable::Scorable;
    use assertables::*;

    /// Iterator over a fixed list of document ids.
    ///
    /// `index` is one past the current position: `0` means "not positioned
    /// yet" (doc id `-1`), and any value above `docs.len()` means exhausted.
    #[derive(Debug)]
    struct MockScorerIterator {
        docs: Vec<i32>,
        index: usize,
    }

    impl DocIdSetIterator for MockScorerIterator {
        fn doc_id(&self) -> i32 {
            if self.index == 0 {
                -1
            } else if self.index > self.docs.len() {
                NO_MORE_DOCS
            } else {
                self.docs[self.index - 1]
            }
        }

        fn next_doc(&mut self) -> io::Result<i32> {
            if self.index >= self.docs.len() {
                // Park past the end so `doc_id` reports exhaustion.
                self.index = self.docs.len() + 1;
                return Ok(NO_MORE_DOCS);
            }
            self.index += 1;
            Ok(self.docs[self.index - 1])
        }

        /// Steps forward with `next_doc` until a document at or beyond
        /// `target` is reached. Always moves at least one step.
        fn advance(&mut self, target: i32) -> io::Result<i32> {
            loop {
                let doc = self.next_doc()?;
                if doc >= target {
                    return Ok(doc);
                }
            }
        }

        fn cost(&self) -> i64 {
            self.docs.len() as i64
        }
    }

    /// Scorer over parallel `docs` / `scores` vectors.
    #[derive(Debug)]
    struct FullMockScorer {
        iter: MockScorerIterator,
        scores: Vec<f32>,
    }

    impl FullMockScorer {
        fn new(docs: Vec<i32>, scores: Vec<f32>) -> Self {
            // `docs` is owned and not needed afterwards, so it is moved into
            // the iterator rather than cloned.
            Self {
                iter: MockScorerIterator { docs, index: 0 },
                scores,
            }
        }
    }

    impl Scorable for FullMockScorer {
        /// Returns the score stored for the current document, or `0.0` when
        /// the iterator is unpositioned, exhausted, or the document is unknown.
        fn score(&mut self) -> io::Result<f32> {
            let doc = self.iter.doc_id();
            if doc < 0 || doc == NO_MORE_DOCS {
                return Ok(0.0);
            }
            let slot = self.iter.docs.iter().position(|&d| d == doc);
            Ok(slot.map_or(0.0, |i| self.scores[i]))
        }
    }

    impl Scorer for FullMockScorer {
        fn doc_id(&self) -> i32 {
            self.iter.doc_id()
        }

        fn iterator(&mut self) -> &mut dyn DocIdSetIterator {
            &mut self.iter
        }

        fn get_max_score(&mut self, _up_to: i32) -> io::Result<f32> {
            Ok(f32::MAX)
        }
    }

    #[test]
    fn test_default_scorer_supplier_get() {
        let scorer = FullMockScorer::new(vec![0, 1, 2], vec![1.0, 2.0, 3.0]);
        let mut supplier = DefaultScorerSupplier::new(Box::new(scorer), 3);
        assert_eq!(supplier.cost(), 3);
        let s = supplier.get(100);
        assert_ok!(s);
    }

    #[test]
    fn test_default_scorer_supplier_get_twice_fails() {
        let scorer = FullMockScorer::new(vec![0], vec![1.0]);
        let mut supplier = DefaultScorerSupplier::new(Box::new(scorer), 1);
        supplier.get(100).unwrap();
        let result = supplier.get(100);
        assert_err!(result);
    }

    #[test]
    fn test_default_scorer_supplier_bulk_scorer() {
        let scorer = FullMockScorer::new(vec![0, 1, 2], vec![1.0, 2.0, 3.0]);
        let mut supplier = DefaultScorerSupplier::new(Box::new(scorer), 3);
        let bs = supplier.bulk_scorer();
        assert_ok!(bs);
    }

    /// Collector that records collected document ids and ignores scores.
    #[derive(Debug)]
    struct DocCollector {
        docs: Vec<i32>,
    }

    impl DocCollector {
        fn new() -> Self {
            Self { docs: Vec::new() }
        }
    }

    impl LeafCollector for DocCollector {
        fn set_scorer(&mut self, _score_context: Rc<ScoreContext>) -> io::Result<()> {
            Ok(())
        }

        fn collect(&mut self, doc: i32) -> io::Result<()> {
            self.docs.push(doc);
            Ok(())
        }
    }

    #[test]
    fn test_default_bulk_scorer_scores_all_docs() {
        let scorer = FullMockScorer::new(vec![0, 5, 10], vec![1.0, 2.0, 3.0]);
        let mut bulk = DefaultBulkScorer::new(Box::new(scorer));
        let mut collector = DocCollector::new();
        bulk.score(&mut collector, 0, NO_MORE_DOCS).unwrap();
        assert_eq!(collector.docs, vec![0, 5, 10]);
    }

    #[test]
    fn test_default_bulk_scorer_respects_range() {
        let scorer = FullMockScorer::new(vec![0, 5, 10, 15], vec![1.0, 2.0, 3.0, 4.0]);
        let mut bulk = DefaultBulkScorer::new(Box::new(scorer));
        let mut collector = DocCollector::new();
        bulk.score(&mut collector, 3, 12).unwrap();
        assert_eq!(collector.docs, vec![5, 10]);
    }

    #[test]
    fn test_default_bulk_scorer_empty_range() {
        let scorer = FullMockScorer::new(vec![10, 20], vec![1.0, 2.0]);
        let mut bulk = DefaultBulkScorer::new(Box::new(scorer));
        let mut collector = DocCollector::new();
        bulk.score(&mut collector, 0, 5).unwrap();
        assert_is_empty!(collector.docs);
    }

    /// Collector that records document ids together with the score read from
    /// the shared `ScoreContext` at collection time.
    #[derive(Debug)]
    struct ScoreCollector {
        docs: Vec<i32>,
        scores: Vec<f32>,
        score_context: Option<Rc<ScoreContext>>,
    }

    impl ScoreCollector {
        fn new() -> Self {
            Self {
                docs: Vec::new(),
                scores: Vec::new(),
                score_context: None,
            }
        }
    }

    impl LeafCollector for ScoreCollector {
        fn set_scorer(&mut self, score_context: Rc<ScoreContext>) -> io::Result<()> {
            self.score_context = Some(score_context);
            Ok(())
        }

        fn collect(&mut self, doc: i32) -> io::Result<()> {
            self.docs.push(doc);
            if let Some(ref ctx) = self.score_context {
                self.scores.push(ctx.score.get());
            }
            Ok(())
        }
    }

    #[test]
    fn test_batch_score_bulk_scorer_collects_all_docs() {
        let scorer = FullMockScorer::new(vec![0, 5, 10], vec![1.0, 2.0, 3.0]);
        let mut bulk = BatchScoreBulkScorer::new(Box::new(scorer));
        let mut collector = ScoreCollector::new();
        bulk.score(&mut collector, 0, NO_MORE_DOCS).unwrap();
        assert_eq!(collector.docs, vec![0, 5, 10]);
        assert_eq!(collector.scores.len(), 3);
        for &s in &collector.scores {
            assert!(s > 0.0, "expected positive score, got {s}");
        }
    }

    #[test]
    fn test_batch_score_bulk_scorer_respects_range() {
        let scorer = FullMockScorer::new(vec![0, 5, 10, 15], vec![1.0, 2.0, 3.0, 4.0]);
        let mut bulk = BatchScoreBulkScorer::new(Box::new(scorer));
        let mut collector = DocCollector::new();
        bulk.score(&mut collector, 3, 12).unwrap();
        assert_eq!(collector.docs, vec![5, 10]);
    }

    #[test]
    fn test_batch_score_bulk_scorer_empty_range() {
        let scorer = FullMockScorer::new(vec![10, 20], vec![1.0, 2.0]);
        let mut bulk = BatchScoreBulkScorer::new(Box::new(scorer));
        let mut collector = DocCollector::new();
        bulk.score(&mut collector, 0, 5).unwrap();
        assert_is_empty!(collector.docs);
    }

    #[test]
    fn test_batch_score_bulk_scorer_min_competitive_score() {
        let scorer = FullMockScorer::new(vec![0, 1, 2, 3], vec![0.5, 1.5, 0.3, 2.0]);
        let mut bulk = BatchScoreBulkScorer::new(Box::new(scorer));

        /// Collector that raises the minimum competitive score to 1.0 after
        /// its first collected document.
        #[derive(Debug)]
        struct FilteringCollector {
            docs: Vec<i32>,
            score_context: Option<Rc<ScoreContext>>,
            first_collect: bool,
        }

        impl LeafCollector for FilteringCollector {
            fn set_scorer(&mut self, score_context: Rc<ScoreContext>) -> io::Result<()> {
                self.score_context = Some(score_context);
                Ok(())
            }

            fn collect(&mut self, doc: i32) -> io::Result<()> {
                self.docs.push(doc);
                if !self.first_collect {
                    self.first_collect = true;
                    if let Some(ref ctx) = self.score_context {
                        ctx.min_competitive_score.set(1.0);
                    }
                }
                Ok(())
            }
        }

        let mut collector = FilteringCollector {
            docs: Vec::new(),
            score_context: None,
            first_collect: false,
        };
        bulk.score(&mut collector, 0, NO_MORE_DOCS).unwrap();
        // Doc 0 (0.5) is collected before the bar is raised; doc 2 (0.3) is
        // then filtered out, while docs 1 (1.5) and 3 (2.0) pass.
        assert_eq!(collector.docs, vec![0, 1, 3]);
    }
}