1use crate::core::{DocId, NO_MORE_DOCS, Result, ScoreMode, Scorer, TwoPhaseIterator};
11
12use crate::query::{BoundQuery, Query, ScorerSupplier};
13use crate::search::conjunction::ConjunctionScorer;
14use crate::search::searcher::Searcher;
15use crate::segment::reader::SegmentReader;
16
17pub struct BoolQuery {
19 pub(crate) must: Vec<Box<dyn Query>>,
20 pub(crate) should: Vec<Box<dyn Query>>,
21 pub(crate) must_not: Vec<Box<dyn Query>>,
22 pub(crate) filter: Vec<Box<dyn Query>>,
23 pub(crate) minimum_should_match: Option<u32>,
24}
25
26impl Query for BoolQuery {
27 fn bind(&self, searcher: &Searcher, score_mode: ScoreMode) -> Result<Box<dyn BoundQuery>> {
28 let must_weights: Vec<Box<dyn BoundQuery>> = self
29 .must
30 .iter()
31 .map(|q| q.bind(searcher, score_mode))
32 .collect::<Result<_>>()?;
33
34 let should_weights: Vec<Box<dyn BoundQuery>> = self
35 .should
36 .iter()
37 .map(|q| q.bind(searcher, score_mode))
38 .collect::<Result<_>>()?;
39
40 let must_not_weights: Vec<Box<dyn BoundQuery>> = self
41 .must_not
42 .iter()
43 .map(|q| q.bind(searcher, ScoreMode::CompleteNoScores))
44 .collect::<Result<_>>()?;
45
46 let filter_weights: Vec<Box<dyn BoundQuery>> = self
47 .filter
48 .iter()
49 .map(|q| q.bind(searcher, ScoreMode::CompleteNoScores))
50 .collect::<Result<_>>()?;
51
52 Ok(Box::new(BoundBoolQuery {
53 must: must_weights,
54 should: should_weights,
55 must_not: must_not_weights,
56 filter: filter_weights,
57 minimum_should_match: self.minimum_should_match,
58 score_mode,
59 }))
60 }
61}
62
63struct BoundBoolQuery {
64 must: Vec<Box<dyn BoundQuery>>,
65 should: Vec<Box<dyn BoundQuery>>,
66 must_not: Vec<Box<dyn BoundQuery>>,
67 filter: Vec<Box<dyn BoundQuery>>,
68 minimum_should_match: Option<u32>,
69 score_mode: ScoreMode,
70}
71
72impl BoundQuery for BoundBoolQuery {
73 fn bulk_score(
74 &self,
75 reader: &SegmentReader,
76 collector: &mut crate::search::collector::TopDocsCollector,
77 segment_id: crate::core::SegmentId,
78 ) -> Result<Option<u64>> {
79 if !self.must.is_empty() || !self.filter.is_empty() || !self.must_not.is_empty() {
81 return Ok(None);
82 }
83 if self.minimum_should_match.map_or(false, |m| m > 1) {
84 return Ok(None);
85 }
86 if self.should.len() < 2 {
87 return Ok(None);
88 }
89
90 let mut scorers: Vec<Box<dyn crate::core::Scorer>> = Vec::new();
91 for w in &self.should {
92 if let Some(supplier) = w.scorer_supplier(reader)? {
93 scorers.push(supplier.scorer()?);
94 }
95 }
96 if scorers.len() < 2 {
97 return Ok(None);
98 }
99
100 let max_doc = reader.doc_count();
101 let mut bulk = crate::search::bulk::MaxScoreBulkScorer::new(scorers);
102 let hits = bulk.score(collector, segment_id, max_doc);
103 Ok(Some(hits))
104 }
105
106 fn scorer_supplier(&self, reader: &SegmentReader) -> Result<Option<Box<dyn ScorerSupplier>>> {
107 let mut must_suppliers: Vec<Box<dyn ScorerSupplier>> = Vec::new();
109 for w in &self.must {
110 match w.scorer_supplier(reader)? {
111 Some(s) => must_suppliers.push(s),
112 None => return Ok(None), }
114 }
115
116 let mut filter_suppliers: Vec<Box<dyn ScorerSupplier>> = Vec::new();
118 for w in &self.filter {
119 match w.scorer_supplier(reader)? {
120 Some(s) => filter_suppliers.push(s),
121 None => return Ok(None), }
123 }
124
125 let mut should_suppliers: Vec<Box<dyn ScorerSupplier>> = Vec::new();
127 for w in &self.should {
128 if let Some(s) = w.scorer_supplier(reader)? {
129 should_suppliers.push(s);
130 }
131 }
132
133 let mut must_not_suppliers: Vec<Box<dyn ScorerSupplier>> = Vec::new();
135 for w in &self.must_not {
136 if let Some(s) = w.scorer_supplier(reader)? {
137 must_not_suppliers.push(s);
138 }
139 }
140
141 if must_suppliers.is_empty() && filter_suppliers.is_empty() && should_suppliers.is_empty() {
143 return Ok(None);
144 }
145
146 let cost = must_suppliers
148 .iter()
149 .chain(filter_suppliers.iter())
150 .map(|s| s.cost())
151 .min()
152 .unwrap_or_else(|| should_suppliers.iter().map(|s| s.cost()).sum::<u64>());
153
154 Ok(Some(Box::new(BoolScorerSupplier {
155 must: must_suppliers,
156 should: should_suppliers,
157 must_not: must_not_suppliers,
158 filter: filter_suppliers,
159 minimum_should_match: self.minimum_should_match,
160 score_mode: self.score_mode,
161 cost,
162 })))
163 }
164}
165
166struct BoolScorerSupplier {
167 must: Vec<Box<dyn ScorerSupplier>>,
168 should: Vec<Box<dyn ScorerSupplier>>,
169 must_not: Vec<Box<dyn ScorerSupplier>>,
170 filter: Vec<Box<dyn ScorerSupplier>>,
171 minimum_should_match: Option<u32>,
172 score_mode: ScoreMode,
173 cost: u64,
174}
175
176unsafe impl Send for BoolScorerSupplier {}
178
179impl ScorerSupplier for BoolScorerSupplier {
180 fn cost(&self) -> u64 {
181 self.cost
182 }
183
184 fn scorer(self: Box<Self>) -> Result<Box<dyn Scorer>> {
185 let mut required_scorers: Vec<Box<dyn Scorer>> = Vec::new();
187
188 let mut must_with_cost: Vec<_> = self
191 .must
192 .into_iter()
193 .map(|s| {
194 let c = s.cost();
195 (s, c)
196 })
197 .collect();
198 must_with_cost.sort_by_key(|(_, c)| *c);
199 for (supplier, _) in must_with_cost {
200 required_scorers.push(supplier.scorer()?);
201 }
202
203 let mut filter_with_cost: Vec<_> = self
205 .filter
206 .into_iter()
207 .map(|s| {
208 let c = s.cost();
209 (s, c)
210 })
211 .collect();
212 filter_with_cost.sort_by_key(|(_, c)| *c);
213 for (supplier, _) in filter_with_cost {
214 required_scorers.push(supplier.scorer()?);
215 }
216
217 let mut exclusion_scorers: Vec<Box<dyn Scorer>> = Vec::new();
219 for supplier in self.must_not {
220 exclusion_scorers.push(supplier.scorer()?);
221 }
222
223 let should_scorers: Vec<Box<dyn Scorer>> = self
225 .should
226 .into_iter()
227 .map(|s| s.scorer())
228 .collect::<Result<_>>()?;
229
230 let min_should = self.minimum_should_match.unwrap_or(0) as usize;
231
232 let mut base_scorer: Box<dyn Scorer> = if !required_scorers.is_empty() {
234 if required_scorers.len() == 1 {
236 required_scorers.pop().unwrap()
237 } else {
238 Box::new(ConjunctionScorer::new(required_scorers))
239 }
240 } else if !should_scorers.is_empty() {
241 let effective_min = if min_should > 0 { min_should } else { 1 };
243
244 let mut scorer = build_should_scorer(should_scorers, effective_min, self.score_mode)?;
245
246 if !exclusion_scorers.is_empty() {
248 scorer = Box::new(ExclusionScorer::new(scorer, exclusion_scorers));
249 }
250
251 return Ok(scorer);
252 } else {
253 return Ok(Box::new(EmptyScorer));
254 };
255
256 if !exclusion_scorers.is_empty() {
258 base_scorer = Box::new(ExclusionScorer::new(base_scorer, exclusion_scorers));
259 }
260
261 if !should_scorers.is_empty() {
268 if min_should > 0 {
269 let should_scorer =
272 build_should_scorer(should_scorers, min_should, self.score_mode)?;
273 base_scorer = Box::new(ConjunctionScorer::new(vec![base_scorer, should_scorer]));
274 } else if self.score_mode.needs_scores() {
275 base_scorer = Box::new(OptionalScorer::new(base_scorer, should_scorers));
277 }
278 }
280
281 Ok(base_scorer)
282 }
283}
284
285struct EmptyScorer;
287
288impl Scorer for EmptyScorer {
289 fn doc_id(&self) -> DocId {
290 NO_MORE_DOCS
291 }
292 fn next(&mut self) -> DocId {
293 NO_MORE_DOCS
294 }
295 fn advance(&mut self, _: DocId) -> DocId {
296 NO_MORE_DOCS
297 }
298 fn score(&mut self) -> f32 {
299 0.0
300 }
301 fn two_phase(&mut self) -> Option<&mut dyn TwoPhaseIterator> {
302 None
303 }
304}
305
306struct ExclusionScorer {
308 base: Box<dyn Scorer>,
309 exclusions: Vec<Box<dyn Scorer>>,
310}
311
312impl ExclusionScorer {
313 fn new(base: Box<dyn Scorer>, exclusions: Vec<Box<dyn Scorer>>) -> Self {
314 let mut s = Self { base, exclusions };
315 s.skip_excluded();
316 s
317 }
318
319 fn is_excluded(&mut self) -> bool {
320 let target = self.base.doc_id();
321 for exc in &mut self.exclusions {
322 let doc = exc.advance(target);
323 if doc == target {
324 return true;
325 }
326 }
327 false
328 }
329
330 fn skip_excluded(&mut self) {
331 while self.base.doc_id() != NO_MORE_DOCS && self.is_excluded() {
332 self.base.next();
333 }
334 }
335}
336
337impl Scorer for ExclusionScorer {
338 fn doc_id(&self) -> DocId {
339 self.base.doc_id()
340 }
341 fn next(&mut self) -> DocId {
342 self.base.next();
343 self.skip_excluded();
344 self.base.doc_id()
345 }
346 fn advance(&mut self, target: DocId) -> DocId {
347 self.base.advance(target);
348 self.skip_excluded();
349 self.base.doc_id()
350 }
351 fn score(&mut self) -> f32 {
352 self.base.score()
353 }
354 fn two_phase(&mut self) -> Option<&mut dyn TwoPhaseIterator> {
355 None
356 }
357}
358
359struct OptionalScorer {
361 base: Box<dyn Scorer>,
362 optionals: Vec<Box<dyn Scorer>>,
363}
364
365impl OptionalScorer {
366 fn new(base: Box<dyn Scorer>, optionals: Vec<Box<dyn Scorer>>) -> Self {
367 Self { base, optionals }
368 }
369}
370
371impl Scorer for OptionalScorer {
372 fn doc_id(&self) -> DocId {
373 self.base.doc_id()
374 }
375 fn next(&mut self) -> DocId {
376 self.base.next()
377 }
378 fn advance(&mut self, target: DocId) -> DocId {
379 self.base.advance(target)
380 }
381 fn score(&mut self) -> f32 {
382 let mut score = self.base.score();
383 let target = self.base.doc_id();
384 for opt in &mut self.optionals {
385 if opt.advance(target) == target {
386 score += opt.score();
387 }
388 }
389 score
390 }
391 fn two_phase(&mut self) -> Option<&mut dyn TwoPhaseIterator> {
392 None
393 }
394}
395
396fn build_should_scorer(
407 mut scorers: Vec<Box<dyn Scorer>>,
408 min_match: usize,
409 score_mode: ScoreMode,
410) -> Result<Box<dyn Scorer>> {
411 if min_match > scorers.len() {
412 return Ok(Box::new(EmptyScorer));
413 }
414 if scorers.len() == 1 {
415 return Ok(scorers.pop().unwrap());
416 }
417 if min_match == scorers.len() {
418 return Ok(Box::new(ConjunctionScorer::new(scorers)));
420 }
421 if !score_mode.needs_scores() && min_match <= 1 {
424 return Ok(Box::new(
425 crate::search::buffered_union::BufferedUnionScorer::new(scorers),
426 ));
427 }
428 if min_match <= 1 {
429 return Ok(Box::new(crate::search::wand::WANDScorer::new(scorers)));
430 }
431 Ok(Box::new(
432 crate::search::wand::WANDScorer::new_min_should_match(scorers, min_match),
433 ))
434}
435
#[cfg(test)]
mod tests {
    use super::*;
    use crate::query::term::TermQuery;

    use crate::analysis::{AnalyzerRegistry, Token};
    use crate::core::{FieldId, SegmentId};
    use crate::mapping::{FieldType, Mapping};
    use crate::search::segment_store::SegmentStore;
    use crate::segment::builder::SegmentBuilder;

    /// Builds one token per term, with positions assigned in order.
    fn make_tokens(terms: &[&str]) -> Vec<Token> {
        let mut tokens = Vec::with_capacity(terms.len());
        for (position, term) in terms.iter().enumerate() {
            tokens.push(Token::new(*term, 0, term.len(), position as u32));
        }
        tokens
    }

    /// A boxed `TermQuery` clause on `field` for `value`.
    fn term_query(field: &'static str, value: &'static str) -> Box<dyn Query> {
        Box::new(TermQuery {
            field: field.into(),
            value: value.into(),
        })
    }

    /// Binds `query` with `ScoreMode::Complete` and builds its scorer for the first segment.
    fn first_segment_scorer(query: &dyn Query, searcher: &Searcher) -> Box<dyn Scorer> {
        let weight = query.bind(searcher, ScoreMode::Complete).unwrap();
        let supplier = weight
            .scorer_supplier(&searcher.segments()[0])
            .unwrap()
            .unwrap();
        supplier.scorer().unwrap()
    }

    /// One segment with four documents (`body` text field, `tag` keyword field):
    ///
    /// | doc | body          | tag |
    /// |-----|---------------|-----|
    /// | 0   | hello world   | a   |
    /// | 1   | hello luci    | b   |
    /// | 2   | goodbye world | a   |
    /// | 3   | luci search   | c   |
    fn build_test_store() -> SegmentStore {
        let schema = Mapping::builder()
            .field("body", FieldType::Text)
            .field("tag", FieldType::Keyword)
            .build();
        let mut builder = SegmentBuilder::new(SegmentId::new(1), &schema);

        let docs = [
            (["hello", "world"], "a"),
            (["hello", "luci"], "b"),
            (["goodbye", "world"], "a"),
            (["luci", "search"], "c"),
        ];
        for (body, tag) in docs {
            builder.add_document(
                &[
                    (FieldId::new(0), make_tokens(&body)),
                    (FieldId::new(1), make_tokens(&[tag])),
                ],
                b"{}",
            );
        }

        let reader = SegmentReader::open(builder.build()).unwrap();
        SegmentStore::new(vec![reader], AnalyzerRegistry::new(), None, None)
    }

    /// Drains `scorer` and returns every document id it was positioned on.
    fn collect_doc_ids(scorer: &mut dyn Scorer) -> Vec<u32> {
        let mut ids = Vec::new();
        loop {
            let doc = scorer.doc_id();
            if doc == NO_MORE_DOCS {
                return ids;
            }
            ids.push(doc.as_u32());
            scorer.next();
        }
    }

    #[test]
    fn bool_must_two_clauses() {
        let store = build_test_store();
        let searcher = Searcher::new(&store);
        let query = BoolQuery {
            must: vec![term_query("body", "hello"), term_query("body", "world")],
            should: vec![],
            must_not: vec![],
            filter: vec![],
            minimum_should_match: None,
        };

        let mut scorer = first_segment_scorer(&query, &searcher);
        assert_eq!(collect_doc_ids(scorer.as_mut()), vec![0]);
    }

    #[test]
    fn bool_should_two_clauses() {
        let store = build_test_store();
        let searcher = Searcher::new(&store);
        let query = BoolQuery {
            must: vec![],
            should: vec![term_query("body", "hello"), term_query("body", "goodbye")],
            must_not: vec![],
            filter: vec![],
            minimum_should_match: None,
        };

        let mut scorer = first_segment_scorer(&query, &searcher);
        assert_eq!(collect_doc_ids(scorer.as_mut()), vec![0, 1, 2]);
    }

    #[test]
    fn bool_must_not() {
        let store = build_test_store();
        let searcher = Searcher::new(&store);
        let query = BoolQuery {
            must: vec![term_query("body", "hello")],
            should: vec![],
            must_not: vec![term_query("body", "world")],
            filter: vec![],
            minimum_should_match: None,
        };

        let mut scorer = first_segment_scorer(&query, &searcher);
        assert_eq!(collect_doc_ids(scorer.as_mut()), vec![1]);
    }

    #[test]
    fn bool_filter_no_scores() {
        let store = build_test_store();
        let searcher = Searcher::new(&store);
        let query = BoolQuery {
            must: vec![],
            should: vec![],
            must_not: vec![],
            filter: vec![term_query("tag", "a")],
            minimum_should_match: None,
        };

        let mut scorer = first_segment_scorer(&query, &searcher);
        assert_eq!(collect_doc_ids(scorer.as_mut()), vec![0, 2]);
    }

    #[test]
    fn bool_must_plus_filter() {
        let store = build_test_store();
        let searcher = Searcher::new(&store);
        let query = BoolQuery {
            must: vec![term_query("body", "hello")],
            should: vec![],
            must_not: vec![],
            filter: vec![term_query("tag", "a")],
            minimum_should_match: None,
        };

        let mut scorer = first_segment_scorer(&query, &searcher);
        assert_eq!(collect_doc_ids(scorer.as_mut()), vec![0]);
    }

    #[test]
    fn bool_empty_must_returns_none() {
        let store = build_test_store();
        let searcher = Searcher::new(&store);
        let query = BoolQuery {
            must: vec![term_query("body", "nonexistent")],
            should: vec![],
            must_not: vec![],
            filter: vec![],
            minimum_should_match: None,
        };

        let weight = query.bind(&searcher, ScoreMode::Complete).unwrap();
        let supplier = weight.scorer_supplier(&searcher.segments()[0]).unwrap();
        assert!(supplier.is_none());
    }

    /// With `minimum_should_match = 2`, doc 0 matches all four clauses and must be scored
    /// as the sum of every clause score, not only the clauses that met the minimum.
    #[test]
    fn min_should_match_scores_all_matching_clauses() {
        let schema = Mapping::builder().field("body", FieldType::Text).build();
        let mut builder = SegmentBuilder::new(SegmentId::new(1), &schema);
        for body in [
            &["aaa", "bbb", "ccc", "ddd"][..],
            &["aaa", "bbb"][..],
            &["aaa"][..],
        ] {
            builder.add_document(&[(FieldId::new(0), make_tokens(body))], b"{}");
        }

        let reader = SegmentReader::open(builder.build()).unwrap();
        let store = SegmentStore::new(vec![reader], AnalyzerRegistry::new(), None, None);
        let searcher = Searcher::new(&store);

        let terms = ["aaa", "bbb", "ccc", "ddd"];

        // Reference value: each clause scored on its own for doc 0, summed in clause order.
        let mut expected_sum: f32 = 0.0;
        for term in terms {
            let mut scorer = first_segment_scorer(&*term_query("body", term), &searcher);
            assert_eq!(
                scorer.doc_id(),
                DocId::new(0),
                "term '{term}' must be in doc 0"
            );
            expected_sum += scorer.score();
        }

        let msm_query = BoolQuery {
            must: vec![],
            should: terms
                .iter()
                .copied()
                .map(|t| term_query("body", t))
                .collect(),
            must_not: vec![],
            filter: vec![],
            minimum_should_match: Some(2),
        };

        let mut scorer = first_segment_scorer(&msm_query, &searcher);
        assert_eq!(scorer.doc_id(), DocId::new(0));
        let msm_score = scorer.score();

        let difference = (msm_score - expected_sum).abs();
        assert!(
            difference < 1e-5,
            "MSM score ({msm_score}) must equal sum of all clause scores ({expected_sum}); \
             difference {} suggests tail entries were not scored",
            difference,
        );
    }
}