use crate::core::{DocId, NO_MORE_DOCS, Result, ScoreMode, Scorer, TwoPhaseIterator};
use crate::query::{BoundQuery, Query, ScorerSupplier};
use crate::search::conjunction::ConjunctionScorer;
use crate::search::searcher::Searcher;
use crate::segment::reader::SegmentReader;
/// A boolean combination of sub-queries, grouped by how each clause participates.
///
/// `Query::bind` binds `must` and `should` clauses with the caller's score mode and
/// `must_not` / `filter` clauses with `ScoreMode::CompleteNoScores`. A query whose only
/// clauses are `must_not` clauses matches nothing.
pub struct BoolQuery {
// Clauses every matching document must match.
pub(crate) must: Vec<Box<dyn Query>>,
// Optional clauses; `minimum_should_match` controls how many of them must match.
pub(crate) should: Vec<Box<dyn Query>>,
// Clauses whose matching documents are removed from the result.
pub(crate) must_not: Vec<Box<dyn Query>>,
// Required like `must`, but bound without scoring.
pub(crate) filter: Vec<Box<dyn Query>>,
// Minimum number of `should` clauses a document must match. With `None` (or 0):
// if `must`/`filter` clauses exist, `should` clauses do not restrict matching;
// if none exist, at least one `should` clause must match.
pub(crate) minimum_should_match: Option<u32>,
}
impl Query for BoolQuery {
    /// Binds every clause against `searcher` and bundles them into a `BoundBoolQuery`.
    ///
    /// Scoring clauses (`must`, `should`) inherit `score_mode`; exclusion and filter
    /// clauses (`must_not`, `filter`) are always bound with `ScoreMode::CompleteNoScores`.
    ///
    /// # Errors
    /// Returns the first error produced while binding any clause.
    fn bind(&self, searcher: &Searcher, score_mode: ScoreMode) -> Result<Box<dyn BoundQuery>> {
        /// Binds one clause group with `mode`, stopping at the first failure.
        fn bind_all(
            clauses: &[Box<dyn Query>],
            searcher: &Searcher,
            mode: ScoreMode,
        ) -> Result<Vec<Box<dyn BoundQuery>>> {
            clauses
                .iter()
                .map(|clause| clause.bind(searcher, mode))
                .collect()
        }

        let must = bind_all(&self.must, searcher, score_mode)?;
        let should = bind_all(&self.should, searcher, score_mode)?;
        let must_not = bind_all(&self.must_not, searcher, ScoreMode::CompleteNoScores)?;
        let filter = bind_all(&self.filter, searcher, ScoreMode::CompleteNoScores)?;

        let bound = BoundBoolQuery {
            must,
            should,
            must_not,
            filter,
            minimum_should_match: self.minimum_should_match,
            score_mode,
        };
        Ok(Box::new(bound))
    }
}
/// A `BoolQuery` whose clauses have been bound against a `Searcher` by `BoolQuery::bind`.
struct BoundBoolQuery {
// Bound `must` clauses (bound with the query's score mode).
must: Vec<Box<dyn BoundQuery>>,
// Bound `should` clauses (bound with the query's score mode).
should: Vec<Box<dyn BoundQuery>>,
// Bound `must_not` clauses (bound with `ScoreMode::CompleteNoScores`).
must_not: Vec<Box<dyn BoundQuery>>,
// Bound `filter` clauses (bound with `ScoreMode::CompleteNoScores`).
filter: Vec<Box<dyn BoundQuery>>,
// Copied unchanged from `BoolQuery::minimum_should_match`.
minimum_should_match: Option<u32>,
// Score mode the query was bound with; forwarded to the per-segment supplier.
score_mode: ScoreMode,
}
impl BoundQuery for BoundBoolQuery {
fn bulk_score(
&self,
reader: &SegmentReader,
collector: &mut crate::search::collector::TopDocsCollector,
segment_id: crate::core::SegmentId,
) -> Result<Option<u64>> {
if !self.must.is_empty() || !self.filter.is_empty() || !self.must_not.is_empty() {
return Ok(None);
}
if self.minimum_should_match.map_or(false, |m| m > 1) {
return Ok(None);
}
if self.should.len() < 2 {
return Ok(None);
}
let mut scorers: Vec<Box<dyn crate::core::Scorer>> = Vec::new();
for w in &self.should {
if let Some(supplier) = w.scorer_supplier(reader)? {
scorers.push(supplier.scorer()?);
}
}
if scorers.len() < 2 {
return Ok(None);
}
let max_doc = reader.doc_count();
let mut bulk = crate::search::bulk::MaxScoreBulkScorer::new(scorers);
let hits = bulk.score(collector, segment_id, max_doc);
Ok(Some(hits))
}
fn scorer_supplier(&self, reader: &SegmentReader) -> Result<Option<Box<dyn ScorerSupplier>>> {
let mut must_suppliers: Vec<Box<dyn ScorerSupplier>> = Vec::new();
for w in &self.must {
match w.scorer_supplier(reader)? {
Some(s) => must_suppliers.push(s),
None => return Ok(None), }
}
let mut filter_suppliers: Vec<Box<dyn ScorerSupplier>> = Vec::new();
for w in &self.filter {
match w.scorer_supplier(reader)? {
Some(s) => filter_suppliers.push(s),
None => return Ok(None), }
}
let mut should_suppliers: Vec<Box<dyn ScorerSupplier>> = Vec::new();
for w in &self.should {
if let Some(s) = w.scorer_supplier(reader)? {
should_suppliers.push(s);
}
}
let mut must_not_suppliers: Vec<Box<dyn ScorerSupplier>> = Vec::new();
for w in &self.must_not {
if let Some(s) = w.scorer_supplier(reader)? {
must_not_suppliers.push(s);
}
}
if must_suppliers.is_empty() && filter_suppliers.is_empty() && should_suppliers.is_empty() {
return Ok(None);
}
let cost = must_suppliers
.iter()
.chain(filter_suppliers.iter())
.map(|s| s.cost())
.min()
.unwrap_or_else(|| should_suppliers.iter().map(|s| s.cost()).sum::<u64>());
Ok(Some(Box::new(BoolScorerSupplier {
must: must_suppliers,
should: should_suppliers,
must_not: must_not_suppliers,
filter: filter_suppliers,
minimum_should_match: self.minimum_should_match,
score_mode: self.score_mode,
cost,
})))
}
}
/// Per-segment scorer supplier for a `BoundBoolQuery`, built by
/// `BoundBoolQuery::scorer_supplier` from the clause suppliers available in one segment.
struct BoolScorerSupplier {
// Suppliers for every `must` clause (construction fails over to `None` if one is missing).
must: Vec<Box<dyn ScorerSupplier>>,
// Suppliers for the `should` clauses that have a supplier in this segment.
should: Vec<Box<dyn ScorerSupplier>>,
// Suppliers for the `must_not` clauses that have a supplier in this segment.
must_not: Vec<Box<dyn ScorerSupplier>>,
// Suppliers for every `filter` clause (construction fails over to `None` if one is missing).
filter: Vec<Box<dyn ScorerSupplier>>,
// Copied from the bound query's `minimum_should_match`.
minimum_should_match: Option<u32>,
// Score mode of the bound query; decides whether optional should clauses are scored.
score_mode: ScoreMode,
// Precomputed value returned by `ScorerSupplier::cost`: the minimum cost of the
// required clauses, or the summed cost of the should clauses when there are none.
cost: u64,
}
// SAFETY: NOTE(review): no justification is given for this impl. If the
// `ScorerSupplier` trait already requires `Send`, every field is `Send` and this impl
// is redundant. If it does not, this asserts thread-safety for clause suppliers that
// may not have it. Confirm against the `ScorerSupplier` definition and remove if possible.
unsafe impl Send for BoolScorerSupplier {}
impl ScorerSupplier for BoolScorerSupplier {
fn cost(&self) -> u64 {
self.cost
}
fn scorer(self: Box<Self>) -> Result<Box<dyn Scorer>> {
let mut required_scorers: Vec<Box<dyn Scorer>> = Vec::new();
let mut must_with_cost: Vec<_> = self
.must
.into_iter()
.map(|s| {
let c = s.cost();
(s, c)
})
.collect();
must_with_cost.sort_by_key(|(_, c)| *c);
for (supplier, _) in must_with_cost {
required_scorers.push(supplier.scorer()?);
}
let mut filter_with_cost: Vec<_> = self
.filter
.into_iter()
.map(|s| {
let c = s.cost();
(s, c)
})
.collect();
filter_with_cost.sort_by_key(|(_, c)| *c);
for (supplier, _) in filter_with_cost {
required_scorers.push(supplier.scorer()?);
}
let mut exclusion_scorers: Vec<Box<dyn Scorer>> = Vec::new();
for supplier in self.must_not {
exclusion_scorers.push(supplier.scorer()?);
}
let should_scorers: Vec<Box<dyn Scorer>> = self
.should
.into_iter()
.map(|s| s.scorer())
.collect::<Result<_>>()?;
let min_should = self.minimum_should_match.unwrap_or(0) as usize;
let mut base_scorer: Box<dyn Scorer> = if !required_scorers.is_empty() {
if required_scorers.len() == 1 {
required_scorers.pop().unwrap()
} else {
Box::new(ConjunctionScorer::new(required_scorers))
}
} else if !should_scorers.is_empty() {
let effective_min = if min_should > 0 { min_should } else { 1 };
let mut scorer = build_should_scorer(should_scorers, effective_min, self.score_mode)?;
if !exclusion_scorers.is_empty() {
scorer = Box::new(ExclusionScorer::new(scorer, exclusion_scorers));
}
return Ok(scorer);
} else {
return Ok(Box::new(EmptyScorer));
};
if !exclusion_scorers.is_empty() {
base_scorer = Box::new(ExclusionScorer::new(base_scorer, exclusion_scorers));
}
if !should_scorers.is_empty() {
if min_should > 0 {
let should_scorer =
build_should_scorer(should_scorers, min_should, self.score_mode)?;
base_scorer = Box::new(ConjunctionScorer::new(vec![base_scorer, should_scorer]));
} else if self.score_mode.needs_scores() {
base_scorer = Box::new(OptionalScorer::new(base_scorer, should_scorers));
}
}
Ok(base_scorer)
}
}
/// A scorer that matches no documents: it is exhausted from the moment it exists.
struct EmptyScorer;
impl Scorer for EmptyScorer {
    fn doc_id(&self) -> DocId {
        NO_MORE_DOCS
    }
    fn next(&mut self) -> DocId {
        // Already exhausted; stepping forward keeps it there.
        self.doc_id()
    }
    fn advance(&mut self, _target: DocId) -> DocId {
        // No document exists at or beyond any target.
        self.doc_id()
    }
    fn score(&mut self) -> f32 {
        // Never positioned on a real document, so report a neutral score.
        0.0
    }
    fn two_phase(&mut self) -> Option<&mut dyn TwoPhaseIterator> {
        None
    }
}
/// Restricts `base` to the documents on which none of `exclusions` matches
/// (the `must_not` clauses of a boolean query). Scoring is delegated to `base`.
struct ExclusionScorer {
    base: Box<dyn Scorer>,
    exclusions: Vec<Box<dyn Scorer>>,
}
impl ExclusionScorer {
    /// Wraps `base` and immediately moves it to its first non-excluded document.
    fn new(base: Box<dyn Scorer>, exclusions: Vec<Box<dyn Scorer>>) -> Self {
        let mut scorer = Self { base, exclusions };
        scorer.skip_excluded();
        scorer
    }

    /// Returns whether any exclusion scorer matches `base`'s current document.
    ///
    /// Exclusion scorers are only moved forward. One already at or past the target
    /// keeps its position, because advancing to a target at or behind the current
    /// document is left undefined by Lucene-style doc-id iterator contracts.
    fn is_excluded(&mut self) -> bool {
        let target = self.base.doc_id();
        self.exclusions.iter_mut().any(|exclusion| {
            let doc = if exclusion.doc_id().as_u32() < target.as_u32() {
                exclusion.advance(target)
            } else {
                exclusion.doc_id()
            };
            doc == target
        })
    }

    /// Steps `base` forward until it sits on a non-excluded document or is exhausted.
    fn skip_excluded(&mut self) {
        while self.base.doc_id() != NO_MORE_DOCS && self.is_excluded() {
            self.base.next();
        }
    }
}
impl Scorer for ExclusionScorer {
    fn doc_id(&self) -> DocId {
        self.base.doc_id()
    }
    fn next(&mut self) -> DocId {
        self.base.next();
        self.skip_excluded();
        self.base.doc_id()
    }
    fn advance(&mut self, target: DocId) -> DocId {
        self.base.advance(target);
        self.skip_excluded();
        self.base.doc_id()
    }
    fn score(&mut self) -> f32 {
        self.base.score()
    }
    fn two_phase(&mut self) -> Option<&mut dyn TwoPhaseIterator> {
        None
    }
}
/// Matches exactly the documents of `base`. Every `optionals` scorer that is also
/// positioned on the current document adds its score to `base`'s score.
/// Used for `should` clauses that do not restrict matching.
struct OptionalScorer {
    base: Box<dyn Scorer>,
    optionals: Vec<Box<dyn Scorer>>,
}
impl OptionalScorer {
    fn new(base: Box<dyn Scorer>, optionals: Vec<Box<dyn Scorer>>) -> Self {
        Self { base, optionals }
    }
}
impl Scorer for OptionalScorer {
    fn doc_id(&self) -> DocId {
        self.base.doc_id()
    }
    fn next(&mut self) -> DocId {
        self.base.next()
    }
    fn advance(&mut self, target: DocId) -> DocId {
        self.base.advance(target)
    }
    /// Sums `base`'s score with the scores of the optional scorers that match the
    /// current document.
    ///
    /// Optional scorers are advanced lazily here and only ever moved forward. One
    /// already at or past the current document (e.g. when `score` is called twice on
    /// the same document) keeps its position instead of being asked to advance to a
    /// non-increasing target.
    fn score(&mut self) -> f32 {
        let mut total = self.base.score();
        let target = self.base.doc_id();
        for optional in &mut self.optionals {
            let doc = if optional.doc_id().as_u32() < target.as_u32() {
                optional.advance(target)
            } else {
                optional.doc_id()
            };
            if doc == target {
                total += optional.score();
            }
        }
        total
    }
    fn two_phase(&mut self) -> Option<&mut dyn TwoPhaseIterator> {
        None
    }
}
/// Builds a scorer over `scorers` that matches documents on which at least
/// `min_match` of them (and at least one) match.
///
/// Implementation choice:
/// - no scorers, or `min_match` larger than their number → `EmptyScorer`;
/// - a single scorer → that scorer itself;
/// - `min_match` equal to the number of scorers → `ConjunctionScorer`;
/// - scores not needed and `min_match <= 1` → `BufferedUnionScorer`;
/// - otherwise a `WANDScorer`, using its minimum-should-match variant when `min_match > 1`.
fn build_should_scorer(
    mut scorers: Vec<Box<dyn Scorer>>,
    min_match: usize,
    score_mode: ScoreMode,
) -> Result<Box<dyn Scorer>> {
    // An empty disjunction matches nothing. Without this guard, `min_match == 0` would
    // fall through to a conjunction over zero scorers.
    if scorers.is_empty() || min_match > scorers.len() {
        return Ok(Box::new(EmptyScorer));
    }
    if scorers.len() == 1 {
        return Ok(scorers.pop().expect("exactly one scorer"));
    }
    if min_match == scorers.len() {
        return Ok(Box::new(ConjunctionScorer::new(scorers)));
    }
    if !score_mode.needs_scores() && min_match <= 1 {
        return Ok(Box::new(
            crate::search::buffered_union::BufferedUnionScorer::new(scorers),
        ));
    }
    if min_match <= 1 {
        return Ok(Box::new(crate::search::wand::WANDScorer::new(scorers)));
    }
    Ok(Box::new(
        crate::search::wand::WANDScorer::new_min_should_match(scorers, min_match),
    ))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::analysis::Token;
    use crate::core::{FieldId, SegmentId};
    use crate::mapping::{FieldType, Mapping};
    use crate::query::term::TermQuery;
    use crate::segment::builder::SegmentBuilder;

    /// Turns `terms` into tokens at consecutive positions, each spanning `0..term.len()`.
    fn make_tokens(terms: &[&str]) -> Vec<Token> {
        let mut tokens = Vec::with_capacity(terms.len());
        for (position, term) in terms.iter().enumerate() {
            tokens.push(Token::new(*term, 0, term.len(), position as u32));
        }
        tokens
    }

    /// Single-segment store shared by most tests:
    ///
    /// | doc | body          | tag |
    /// |-----|---------------|-----|
    /// | 0   | hello world   | a   |
    /// | 1   | hello luci    | b   |
    /// | 2   | goodbye world | a   |
    /// | 3   | luci search   | c   |
    fn build_test_store() -> crate::search::segment_store::SegmentStore {
        let schema = Mapping::builder()
            .field("body", FieldType::Text)
            .field("tag", FieldType::Keyword)
            .build();
        let docs: [(&[&str], &str); 4] = [
            (&["hello", "world"], "a"),
            (&["hello", "luci"], "b"),
            (&["goodbye", "world"], "a"),
            (&["luci", "search"], "c"),
        ];
        let mut builder = SegmentBuilder::new(SegmentId::new(1), &schema);
        for &(body, tag) in docs.iter() {
            builder.add_document(
                &[
                    (FieldId::new(0), make_tokens(body)),
                    (FieldId::new(1), make_tokens(&[tag])),
                ],
                b"{}",
            );
        }
        let reader = SegmentReader::open(builder.build()).unwrap();
        crate::search::segment_store::SegmentStore::new(
            vec![reader],
            crate::analysis::AnalyzerRegistry::new(),
            None,
            None,
        )
    }

    /// Drains `scorer` from its current position, returning every visited doc id.
    fn collect_doc_ids(scorer: &mut dyn Scorer) -> Vec<u32> {
        let mut ids = Vec::new();
        loop {
            let doc = scorer.doc_id();
            if doc == NO_MORE_DOCS {
                break;
            }
            ids.push(doc.as_u32());
            scorer.next();
        }
        ids
    }

    /// Boxes a term query on `field` for `value`.
    fn term(field: &'static str, value: &'static str) -> Box<dyn Query> {
        Box::new(TermQuery {
            field: field.into(),
            value: value.into(),
        })
    }

    /// Binds `query` with full scoring and asks for the first segment's supplier.
    fn supplier_for(query: &BoolQuery, searcher: &Searcher) -> Option<Box<dyn ScorerSupplier>> {
        let bound = query.bind(searcher, ScoreMode::Complete).unwrap();
        bound.scorer_supplier(&searcher.segments()[0]).unwrap()
    }

    /// Runs `query` over `build_test_store` and returns the matching doc ids.
    fn matching_ids(query: BoolQuery) -> Vec<u32> {
        let store = build_test_store();
        let searcher = Searcher::new(&store);
        let supplier = supplier_for(&query, &searcher).unwrap();
        let mut scorer = supplier.scorer().unwrap();
        collect_doc_ids(scorer.as_mut())
    }

    #[test]
    fn bool_must_two_clauses() {
        let query = BoolQuery {
            must: vec![term("body", "hello"), term("body", "world")],
            should: vec![],
            must_not: vec![],
            filter: vec![],
            minimum_should_match: None,
        };
        assert_eq!(matching_ids(query), vec![0]);
    }

    #[test]
    fn bool_should_two_clauses() {
        let query = BoolQuery {
            must: vec![],
            should: vec![term("body", "hello"), term("body", "goodbye")],
            must_not: vec![],
            filter: vec![],
            minimum_should_match: None,
        };
        assert_eq!(matching_ids(query), vec![0, 1, 2]);
    }

    #[test]
    fn bool_must_not() {
        let query = BoolQuery {
            must: vec![term("body", "hello")],
            should: vec![],
            must_not: vec![term("body", "world")],
            filter: vec![],
            minimum_should_match: None,
        };
        assert_eq!(matching_ids(query), vec![1]);
    }

    #[test]
    fn bool_filter_no_scores() {
        let query = BoolQuery {
            must: vec![],
            should: vec![],
            must_not: vec![],
            filter: vec![term("tag", "a")],
            minimum_should_match: None,
        };
        assert_eq!(matching_ids(query), vec![0, 2]);
    }

    #[test]
    fn bool_must_plus_filter() {
        let query = BoolQuery {
            must: vec![term("body", "hello")],
            should: vec![],
            must_not: vec![],
            filter: vec![term("tag", "a")],
            minimum_should_match: None,
        };
        assert_eq!(matching_ids(query), vec![0]);
    }

    #[test]
    fn bool_empty_must_returns_none() {
        let store = build_test_store();
        let searcher = Searcher::new(&store);
        let query = BoolQuery {
            must: vec![term("body", "nonexistent")],
            should: vec![],
            must_not: vec![],
            filter: vec![],
            minimum_should_match: None,
        };
        assert!(supplier_for(&query, &searcher).is_none());
    }

    #[test]
    fn min_should_match_scores_all_matching_clauses() {
        use crate::analysis::AnalyzerRegistry;
        let schema = Mapping::builder().field("body", FieldType::Text).build();
        let mut builder = SegmentBuilder::new(SegmentId::new(1), &schema);
        let bodies: [&[&str]; 3] = [&["aaa", "bbb", "ccc", "ddd"], &["aaa", "bbb"], &["aaa"]];
        for body in bodies.iter() {
            builder.add_document(&[(FieldId::new(0), make_tokens(body))], b"{}");
        }
        let reader = SegmentReader::open(builder.build()).unwrap();
        let store = crate::search::segment_store::SegmentStore::new(
            vec![reader],
            AnalyzerRegistry::new(),
            None,
            None,
        );
        let searcher = Searcher::new(&store);

        // Every term occurs in doc 0, so an MSM query's score there should be the sum
        // of the individual term scores.
        let terms = ["aaa", "bbb", "ccc", "ddd"];
        let mut expected_sum: f32 = 0.0;
        for term in &terms {
            let single = TermQuery {
                field: "body".into(),
                value: (*term).into(),
            };
            let bound = single.bind(&searcher, ScoreMode::Complete).unwrap();
            let supplier = bound
                .scorer_supplier(&searcher.segments()[0])
                .unwrap()
                .unwrap();
            let mut scorer = supplier.scorer().unwrap();
            assert_eq!(
                scorer.doc_id(),
                DocId::new(0),
                "term '{term}' must be in doc 0"
            );
            expected_sum += scorer.score();
        }

        let msm_query = BoolQuery {
            must: vec![],
            should: terms.iter().map(|t| term("body", *t)).collect(),
            must_not: vec![],
            filter: vec![],
            minimum_should_match: Some(2),
        };
        let supplier = supplier_for(&msm_query, &searcher).unwrap();
        let mut scorer = supplier.scorer().unwrap();
        assert_eq!(scorer.doc_id(), DocId::new(0));
        let msm_score = scorer.score();
        let difference = (msm_score - expected_sum).abs();
        assert!(
            difference < 1e-5,
            "MSM score ({msm_score}) must equal sum of all clause scores ({expected_sum}); \
             difference {} suggests tail entries were not scored",
            difference,
        );
    }
}