1use std::sync::Arc;
4
5use crate::segment::SegmentReader;
6use crate::structures::TERMINATED;
7use crate::{DocId, Score};
8
9use super::planner::{
10 build_sparse_maxscore_executor, chain_predicates, combine_sparse_results, compute_idf,
11 extract_all_sparse_infos, finish_text_maxscore, prepare_per_field_grouping,
12 prepare_text_maxscore,
13};
14use super::{CountFuture, EmptyScorer, GlobalStats, Query, Scorer, ScorerFuture};
15
/// Boolean combination of sub-queries, grouped into three clause lists.
///
/// Clauses are appended with the consuming builder methods `must`, `should`
/// and `must_not`; `Display` renders them with a `+` prefix, no prefix and a
/// `-` prefix respectively.
#[derive(Default, Clone)]
pub struct BooleanQuery {
    /// Clauses that every matching document has to satisfy.
    pub must: Vec<Arc<dyn Query>>,
    /// Optional clauses; when there are no `must` clauses, a document matches
    /// if at least one of these does.
    pub should: Vec<Arc<dyn Query>>,
    /// Clauses whose matching documents are excluded from the result.
    pub must_not: Vec<Arc<dyn Query>>,
    /// Optional shared statistics, handed to the planner helpers when a scorer
    /// is built. Set through `with_global_stats`.
    global_stats: Option<Arc<GlobalStats>>,
}
28
29impl std::fmt::Debug for BooleanQuery {
30 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31 f.debug_struct("BooleanQuery")
32 .field("must_count", &self.must.len())
33 .field("should_count", &self.should.len())
34 .field("must_not_count", &self.must_not.len())
35 .field("has_global_stats", &self.global_stats.is_some())
36 .finish()
37 }
38}
39
40impl std::fmt::Display for BooleanQuery {
41 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42 write!(f, "Boolean(")?;
43 let mut first = true;
44 for q in &self.must {
45 if !first {
46 write!(f, " ")?;
47 }
48 write!(f, "+{}", q)?;
49 first = false;
50 }
51 for q in &self.should {
52 if !first {
53 write!(f, " ")?;
54 }
55 write!(f, "{}", q)?;
56 first = false;
57 }
58 for q in &self.must_not {
59 if !first {
60 write!(f, " ")?;
61 }
62 write!(f, "-{}", q)?;
63 first = false;
64 }
65 write!(f, ")")
66 }
67}
68
69impl BooleanQuery {
70 pub fn new() -> Self {
71 Self::default()
72 }
73
74 pub fn must(mut self, query: impl Query + 'static) -> Self {
75 self.must.push(Arc::new(query));
76 self
77 }
78
79 pub fn should(mut self, query: impl Query + 'static) -> Self {
80 self.should.push(Arc::new(query));
81 self
82 }
83
84 pub fn must_not(mut self, query: impl Query + 'static) -> Self {
85 self.must_not.push(Arc::new(query));
86 self
87 }
88
89 pub fn with_global_stats(mut self, stats: Arc<GlobalStats>) -> Self {
91 self.global_stats = Some(stats);
92 self
93 }
94}
95
96fn build_should_scorer<'a>(scorers: Vec<Box<dyn Scorer + 'a>>) -> Box<dyn Scorer + 'a> {
98 if scorers.is_empty() {
99 return Box::new(EmptyScorer);
100 }
101 if scorers.len() == 1 {
102 return scorers.into_iter().next().unwrap();
103 }
104 let mut scorer = BooleanScorer {
105 must: vec![],
106 should: scorers,
107 must_not: vec![],
108 current_doc: 0,
109 };
110 scorer.current_doc = scorer.find_next_match();
111 Box::new(scorer)
112}
113
/// Shared planning logic behind `BooleanQuery::scorer` and
/// `BooleanQuery::scorer_sync`.
///
/// Arguments, in order:
/// * `$must`, `$should`, `$must_not` - expressions borrowed as `&[Arc<dyn Query>]`
/// * `$global_stats` - an `Option<&Arc<GlobalStats>>`
/// * `$reader` - a `&SegmentReader`
/// * `$limit` - a `usize` result limit forwarded to sub-scorers and planner helpers
/// * `$scorer_fn` - name of the per-query method that builds a scorer
/// * `$get_postings_fn` - name of the reader method that fetches a posting list
/// * `$execute_fn` - name of the executor method that runs a sparse plan
/// * optional trailing tokens (`$aw`) - each is appended as `.token` after
///   every call above; the async caller passes `await`, the sync caller
///   passes nothing
///
/// The expansion contains `return` statements and the `?` operator, so it must
/// be the body of a function or async block that returns
/// `Result<Box<dyn Scorer>, _>`.
///
/// Planning order:
/// 1. a single MUST or single SHOULD clause without MUST_NOT is delegated
///    directly to that clause;
/// 2. pure SHOULD queries (two or more clauses) try, in turn,
///    `prepare_text_maxscore`, `extract_all_sparse_infos` +
///    `build_sparse_maxscore_executor`, and `prepare_per_field_grouping`;
/// 3. MUST + SHOULD queries turn clauses into doc predicates where
///    `as_doc_predicate` allows it and drive the SHOULD scorer, filtered by
///    those predicates / verifier scorers;
/// 4. anything else falls back to a plain `BooleanScorer` over all clauses.
macro_rules! boolean_plan {
    ($must:expr, $should:expr, $must_not:expr, $global_stats:expr,
     $reader:expr, $limit:expr,
     $scorer_fn:ident, $get_postings_fn:ident, $execute_fn:ident
     $(, $aw:tt)*) => {{
        let must: &[Arc<dyn Query>] = &$must;
        let should: &[Arc<dyn Query>] = &$should;
        let must_not: &[Arc<dyn Query>] = &$must_not;
        let global_stats: Option<&Arc<GlobalStats>> = $global_stats;
        let reader: &SegmentReader = $reader;
        let limit: usize = $limit;

        // Step 1: a lone clause with nothing to exclude needs no boolean wrapper.
        if must_not.is_empty() {
            if must.len() == 1 && should.is_empty() {
                return must[0].$scorer_fn(reader, limit) $(. $aw)* ;
            }
            if should.len() == 1 && must.is_empty() {
                return should[0].$scorer_fn(reader, limit) $(. $aw)* ;
            }
        }

        // Step 2: pure SHOULD query with at least two clauses.
        if must.is_empty() && must_not.is_empty() && should.len() >= 2 {
            // 2a: text MaxScore path, taken when `prepare_text_maxscore` accepts
            // the clauses.
            if let Some((mut infos, _field, avg_field_len, num_docs)) =
                prepare_text_maxscore(should, reader, global_stats)
            {
                let mut posting_lists = Vec::with_capacity(infos.len());
                for info in infos.drain(..) {
                    // Terms without a posting list are skipped.
                    if let Some(pl) = reader.$get_postings_fn(info.field, &info.term)
                        $(. $aw)* ?
                    {
                        let idf = compute_idf(&pl, info.field, &info.term, num_docs, global_stats);
                        posting_lists.push((pl, idf));
                    }
                }
                return finish_text_maxscore(posting_lists, avg_field_len, limit);
            }

            // 2b: sparse MaxScore path without a predicate.
            if let Some(infos) = extract_all_sparse_infos(should) {
                if let Some((executor, info)) =
                    build_sparse_maxscore_executor(&infos, reader, limit, None)
                {
                    let raw = executor.$execute_fn() $(. $aw)* ?;
                    return Ok(combine_sparse_results(raw, info.combiner, info.field, limit));
                }
            }

            // 2c: per-field grouping; each multi-term group gets its own text
            // MaxScore scorer, remaining clauses get their ordinary scorer, and
            // everything is unioned by `build_should_scorer`.
            if let Some(grouping) = prepare_per_field_grouping(should, reader, limit, global_stats)
            {
                let mut scorers: Vec<Box<dyn Scorer + '_>> = Vec::new();
                for (field, avg_field_len, infos) in &grouping.multi_term_groups {
                    let mut posting_lists = Vec::with_capacity(infos.len());
                    for info in infos {
                        if let Some(pl) = reader.$get_postings_fn(info.field, &info.term)
                            $(. $aw)* ?
                        {
                            let idf = compute_idf(
                                &pl, *field, &info.term, grouping.num_docs, global_stats,
                            );
                            posting_lists.push((pl, idf));
                        }
                    }
                    if !posting_lists.is_empty() {
                        scorers.push(finish_text_maxscore(
                            posting_lists,
                            *avg_field_len,
                            grouping.per_field_limit,
                        )?);
                    }
                }
                for &idx in &grouping.fallback_indices {
                    scorers.push(should[idx].$scorer_fn(reader, limit) $(. $aw)* ?);
                }
                return Ok(build_should_scorer(scorers));
            }
        }

        // Step 3: MUST + SHOULD. The `usize::MAX / 4` guard keeps the
        // `limit * 4` over-fetch below from overflowing.
        if !should.is_empty() && !must.is_empty() && limit < usize::MAX / 4 {
            // MUST clauses become doc predicates when they can, otherwise they
            // are kept as verifier scorers.
            let mut predicates: Vec<super::DocPredicate<'_>> = Vec::new();
            let mut must_verifiers: Vec<Box<dyn super::Scorer + '_>> = Vec::new();
            for q in must {
                if let Some(pred) = q.as_doc_predicate(reader) {
                    predicates.push(pred);
                } else {
                    must_verifiers.push(q.$scorer_fn(reader, limit) $(. $aw)* ?);
                }
            }
            // MUST_NOT clauses: same split, with the predicate negated.
            let mut must_not_verifiers: Vec<Box<dyn super::Scorer + '_>> = Vec::new();
            for q in must_not {
                if let Some(pred) = q.as_doc_predicate(reader) {
                    let negated: super::DocPredicate<'_> =
                        Box::new(move |doc_id| !pred(doc_id));
                    predicates.push(negated);
                } else {
                    must_not_verifiers.push(q.$scorer_fn(reader, limit) $(. $aw)* ?);
                }
            }

            // 3a: every filter is a predicate, so try the predicate-aware sparse
            // executor.
            if must_verifiers.is_empty()
                && must_not_verifiers.is_empty()
                && !predicates.is_empty()
            {
                if let Some(infos) = extract_all_sparse_infos(should) {
                    let combined = chain_predicates(predicates);
                    if let Some((executor, info)) =
                        build_sparse_maxscore_executor(&infos, reader, limit, Some(combined))
                    {
                        log::debug!(
                            "BooleanQuery planner: predicate-aware sparse MaxScore, {} dims",
                            infos.len()
                        );
                        let raw = executor.$execute_fn() $(. $aw)* ?;
                        return Ok(combine_sparse_results(raw, info.combiner, info.field, limit));
                    }
                    // `chain_predicates` consumed the vector and no executor was
                    // built, so the predicates are rebuilt for the path below.
                    predicates = Vec::new();
                    for q in must {
                        if let Some(pred) = q.as_doc_predicate(reader) {
                            predicates.push(pred);
                        }
                    }
                    for q in must_not {
                        if let Some(pred) = q.as_doc_predicate(reader) {
                            let negated: super::DocPredicate<'_> =
                                Box::new(move |doc_id| !pred(doc_id));
                            predicates.push(negated);
                        }
                    }
                }
            }

            // 3b: build the SHOULD scorer, over-fetching 4x when predicates are
            // present.
            let should_limit = if !predicates.is_empty() { limit * 4 } else { limit };
            let should_scorer = if should.len() == 1 {
                should[0].$scorer_fn(reader, should_limit) $(. $aw)* ?
            } else {
                let sub = BooleanQuery {
                    must: Vec::new(),
                    should: should.to_vec(),
                    must_not: Vec::new(),
                    global_stats: global_stats.cloned(),
                };
                sub.$scorer_fn(reader, should_limit) $(. $aw)* ?
            };

            // Saturating conversion: a plain `limit as u32` would wrap for
            // limits above `u32::MAX` and compare the size hint against an
            // unrelated small number.
            let limit_hint = u32::try_from(limit).unwrap_or(u32::MAX);
            let use_predicated =
                must_verifiers.is_empty() || should_scorer.size_hint() >= limit_hint;

            if use_predicated {
                log::debug!(
                    "BooleanQuery planner: PredicatedScorer {} preds + {} must_v + {} must_not_v, \
                     SHOULD size_hint={}, over_fetch={}",
                    predicates.len(), must_verifiers.len(), must_not_verifiers.len(),
                    should_scorer.size_hint(), should_limit
                );
                return Ok(Box::new(super::PredicatedScorer::new(
                    should_scorer, predicates, must_verifiers, must_not_verifiers,
                )));
            }

            // 3c: otherwise drive a BooleanScorer from the verifier scorers.
            // NOTE(review): the collected `predicates` are not applied on this
            // path - confirm that is intended.
            let mut scorer = BooleanScorer {
                must: must_verifiers,
                should: vec![should_scorer],
                must_not: must_not_verifiers,
                current_doc: 0,
            };
            scorer.current_doc = scorer.find_next_match();
            return Ok(Box::new(scorer));
        }

        // Step 4: generic fallback - one scorer per clause, combined by
        // BooleanScorer.
        let mut must_scorers = Vec::with_capacity(must.len());
        for q in must {
            must_scorers.push(q.$scorer_fn(reader, limit) $(. $aw)* ?);
        }
        let mut should_scorers = Vec::with_capacity(should.len());
        for q in should {
            should_scorers.push(q.$scorer_fn(reader, limit) $(. $aw)* ?);
        }
        let mut must_not_scorers = Vec::with_capacity(must_not.len());
        for q in must_not {
            must_not_scorers.push(q.$scorer_fn(reader, limit) $(. $aw)* ?);
        }
        let mut scorer = BooleanScorer {
            must: must_scorers,
            should: should_scorers,
            must_not: must_not_scorers,
            current_doc: 0,
        };
        scorer.current_doc = scorer.find_next_match();
        Ok(Box::new(scorer) as Box<dyn Scorer + '_>)
    }};
}
330
331impl Query for BooleanQuery {
332 fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
333 let must = self.must.clone();
334 let should = self.should.clone();
335 let must_not = self.must_not.clone();
336 let global_stats = self.global_stats.clone();
337 Box::pin(async move {
338 boolean_plan!(
339 must,
340 should,
341 must_not,
342 global_stats.as_ref(),
343 reader,
344 limit,
345 scorer,
346 get_postings,
347 execute,
348 await
349 )
350 })
351 }
352
353 #[cfg(feature = "sync")]
354 fn scorer_sync<'a>(
355 &self,
356 reader: &'a SegmentReader,
357 limit: usize,
358 ) -> crate::Result<Box<dyn Scorer + 'a>> {
359 boolean_plan!(
360 self.must,
361 self.should,
362 self.must_not,
363 self.global_stats.as_ref(),
364 reader,
365 limit,
366 scorer_sync,
367 get_postings_sync,
368 execute_sync
369 )
370 }
371
372 fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
373 let must = self.must.clone();
374 let should = self.should.clone();
375
376 Box::pin(async move {
377 if !must.is_empty() {
378 let mut estimates = Vec::with_capacity(must.len());
379 for q in &must {
380 estimates.push(q.count_estimate(reader).await?);
381 }
382 estimates
383 .into_iter()
384 .min()
385 .ok_or_else(|| crate::Error::Corruption("Empty must clause".to_string()))
386 } else if !should.is_empty() {
387 let mut sum = 0u32;
388 for q in &should {
389 sum = sum.saturating_add(q.count_estimate(reader).await?);
390 }
391 Ok(sum)
392 } else {
393 Ok(0)
394 }
395 })
396 }
397}
398
/// Scorer that combines per-clause scorers for a boolean query.
///
/// When `must` is non-empty the must scorers pick the candidate documents;
/// otherwise the should scorers do. Documents matched by any `must_not`
/// scorer are skipped.
struct BooleanScorer<'a> {
    /// Scorers that all have to be positioned on a document for it to match.
    must: Vec<Box<dyn Scorer + 'a>>,
    /// Optional scorers; they add to the score when positioned on the current
    /// document.
    should: Vec<Box<dyn Scorer + 'a>>,
    /// Scorers whose documents are excluded.
    must_not: Vec<Box<dyn Scorer + 'a>>,
    /// Document the scorer is currently positioned on (`TERMINATED` once
    /// `find_next_match` runs out of matches).
    current_doc: DocId,
}
405
406impl BooleanScorer<'_> {
407 fn find_next_match(&mut self) -> DocId {
408 if self.must.is_empty() && self.should.is_empty() {
409 return TERMINATED;
410 }
411
412 loop {
413 let candidate = if !self.must.is_empty() {
414 let mut max_doc = self
415 .must
416 .iter()
417 .map(|s| s.doc())
418 .max()
419 .unwrap_or(TERMINATED);
420
421 if max_doc == TERMINATED {
422 return TERMINATED;
423 }
424
425 loop {
426 let mut all_match = true;
427 for scorer in &mut self.must {
428 let doc = scorer.seek(max_doc);
429 if doc == TERMINATED {
430 return TERMINATED;
431 }
432 if doc > max_doc {
433 max_doc = doc;
434 all_match = false;
435 break;
436 }
437 }
438 if all_match {
439 break;
440 }
441 }
442 max_doc
443 } else {
444 self.should
445 .iter()
446 .map(|s| s.doc())
447 .filter(|&d| d != TERMINATED)
448 .min()
449 .unwrap_or(TERMINATED)
450 };
451
452 if candidate == TERMINATED {
453 return TERMINATED;
454 }
455
456 let excluded = self.must_not.iter_mut().any(|scorer| {
457 let doc = scorer.seek(candidate);
458 doc == candidate
459 });
460
461 if !excluded {
462 for scorer in &mut self.should {
464 scorer.seek(candidate);
465 }
466 self.current_doc = candidate;
467 return candidate;
468 }
469
470 if !self.must.is_empty() {
472 for scorer in &mut self.must {
473 scorer.advance();
474 }
475 } else {
476 for scorer in &mut self.should {
478 if scorer.doc() <= candidate && scorer.doc() != TERMINATED {
479 scorer.seek(candidate + 1);
480 }
481 }
482 }
483 }
484 }
485}
486
487impl super::docset::DocSet for BooleanScorer<'_> {
488 fn doc(&self) -> DocId {
489 self.current_doc
490 }
491
492 fn advance(&mut self) -> DocId {
493 if !self.must.is_empty() {
494 for scorer in &mut self.must {
495 scorer.advance();
496 }
497 } else {
498 for scorer in &mut self.should {
499 if scorer.doc() == self.current_doc {
500 scorer.advance();
501 }
502 }
503 }
504
505 self.current_doc = self.find_next_match();
506 self.current_doc
507 }
508
509 fn seek(&mut self, target: DocId) -> DocId {
510 for scorer in &mut self.must {
511 scorer.seek(target);
512 }
513
514 for scorer in &mut self.should {
515 scorer.seek(target);
516 }
517
518 self.current_doc = self.find_next_match();
519 self.current_doc
520 }
521
522 fn size_hint(&self) -> u32 {
523 if !self.must.is_empty() {
524 self.must.iter().map(|s| s.size_hint()).min().unwrap_or(0)
525 } else {
526 self.should.iter().map(|s| s.size_hint()).sum()
527 }
528 }
529}
530
531impl Scorer for BooleanScorer<'_> {
532 fn score(&self) -> Score {
533 let mut total = 0.0;
534
535 for scorer in &self.must {
536 if scorer.doc() == self.current_doc {
537 total += scorer.score();
538 }
539 }
540
541 for scorer in &self.should {
542 if scorer.doc() == self.current_doc {
543 total += scorer.score();
544 }
545 }
546
547 total
548 }
549
550 fn matched_positions(&self) -> Option<super::MatchedPositions> {
551 let mut all_positions: super::MatchedPositions = Vec::new();
552
553 for scorer in &self.must {
554 if scorer.doc() == self.current_doc
555 && let Some(positions) = scorer.matched_positions()
556 {
557 all_positions.extend(positions);
558 }
559 }
560
561 for scorer in &self.should {
562 if scorer.doc() == self.current_doc
563 && let Some(positions) = scorer.matched_positions()
564 {
565 all_positions.extend(positions);
566 }
567 }
568
569 if all_positions.is_empty() {
570 None
571 } else {
572 Some(all_positions)
573 }
574 }
575}
576
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::Field;
    use crate::query::{QueryDecomposition, TermQuery};

    /// Three SHOULD text terms on one field all decompose to `TextTerm` infos
    /// for that field.
    #[test]
    fn test_maxscore_eligible_pure_or_same_field() {
        let query = BooleanQuery::new()
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(0), "world"))
            .should(TermQuery::text(Field(0), "foo"));

        let every_clause_is_text_term = query
            .should
            .iter()
            .all(|clause| matches!(clause.decompose(), QueryDecomposition::TextTerm(_)));
        assert!(every_clause_is_text_term);

        let mut infos = Vec::new();
        for clause in &query.should {
            if let QueryDecomposition::TextTerm(info) = clause.decompose() {
                infos.push(info);
            }
        }
        assert_eq!(infos.len(), 3);
        for info in &infos {
            assert!(info.field == Field(0));
        }
    }

    /// SHOULD text terms on two different fields decompose to infos with
    /// different fields.
    #[test]
    fn test_maxscore_not_eligible_different_fields() {
        let query = BooleanQuery::new()
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(1), "world"));

        let mut infos = Vec::new();
        for clause in &query.should {
            if let QueryDecomposition::TextTerm(info) = clause.decompose() {
                infos.push(info);
            }
        }
        assert_eq!(infos.len(), 2);
        assert_ne!(infos[0].field, infos[1].field);
    }

    /// Adding a MUST clause leaves the `must` list non-empty.
    #[test]
    fn test_maxscore_not_eligible_with_must() {
        let query = BooleanQuery::new()
            .must(TermQuery::text(Field(0), "required"))
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(0), "world"));

        let has_must = !query.must.is_empty();
        assert!(has_must);
    }

    /// Adding a MUST_NOT clause leaves the `must_not` list non-empty.
    #[test]
    fn test_maxscore_not_eligible_with_must_not() {
        let query = BooleanQuery::new()
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(0), "world"))
            .must_not(TermQuery::text(Field(0), "excluded"));

        let has_must_not = !query.must_not.is_empty();
        assert!(has_must_not);
    }

    /// A query built from one SHOULD clause holds exactly one SHOULD clause.
    #[test]
    fn test_maxscore_not_eligible_single_term() {
        let single = BooleanQuery::new().should(TermQuery::text(Field(0), "hello"));

        assert_eq!(single.should.len(), 1);
    }

    /// A text `TermQuery` decomposes to a `TextTerm` carrying its field and
    /// term bytes.
    #[test]
    fn test_term_query_info_extraction() {
        let term_query = TermQuery::text(Field(42), "test");
        let QueryDecomposition::TextTerm(info) = term_query.decompose() else {
            panic!("Expected TextTerm decomposition");
        };
        assert_eq!(info.field, Field(42));
        assert_eq!(info.term, b"test");
    }

    /// A `BooleanQuery` itself decomposes to `Opaque`.
    #[test]
    fn test_boolean_query_no_term_info() {
        let query = BooleanQuery::new().should(TermQuery::text(Field(0), "hello"));

        let decomposition = query.decompose();
        assert!(matches!(decomposition, QueryDecomposition::Opaque));
    }
}