Skip to main content

hermes_core/query/
boolean.rs

//! Boolean query with MUST, SHOULD, and MUST_NOT clauses
2
3use std::sync::Arc;
4
5use crate::segment::SegmentReader;
6use crate::structures::TERMINATED;
7use crate::{DocId, Score};
8
9use super::planner::{
10    build_sparse_maxscore_executor, chain_predicates, combine_sparse_results, compute_idf,
11    extract_all_sparse_infos, finish_text_maxscore, prepare_per_field_grouping,
12    prepare_text_maxscore,
13};
14use super::{CountFuture, EmptyScorer, GlobalStats, Query, Scorer, ScorerFuture};
15
16/// Boolean query with MUST, SHOULD, and MUST_NOT clauses
17///
18/// When all clauses are SHOULD term queries on the same field, automatically
19/// uses MaxScore optimization for efficient top-k retrieval.
20#[derive(Default, Clone)]
21pub struct BooleanQuery {
22    pub must: Vec<Arc<dyn Query>>,
23    pub should: Vec<Arc<dyn Query>>,
24    pub must_not: Vec<Arc<dyn Query>>,
25    /// Optional global statistics for cross-segment IDF
26    global_stats: Option<Arc<GlobalStats>>,
27}
28
29impl std::fmt::Debug for BooleanQuery {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        f.debug_struct("BooleanQuery")
32            .field("must_count", &self.must.len())
33            .field("should_count", &self.should.len())
34            .field("must_not_count", &self.must_not.len())
35            .field("has_global_stats", &self.global_stats.is_some())
36            .finish()
37    }
38}
39
40impl std::fmt::Display for BooleanQuery {
41    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42        write!(f, "Boolean(")?;
43        let mut first = true;
44        for q in &self.must {
45            if !first {
46                write!(f, " ")?;
47            }
48            write!(f, "+{}", q)?;
49            first = false;
50        }
51        for q in &self.should {
52            if !first {
53                write!(f, " ")?;
54            }
55            write!(f, "{}", q)?;
56            first = false;
57        }
58        for q in &self.must_not {
59            if !first {
60                write!(f, " ")?;
61            }
62            write!(f, "-{}", q)?;
63            first = false;
64        }
65        write!(f, ")")
66    }
67}
68
69impl BooleanQuery {
70    pub fn new() -> Self {
71        Self::default()
72    }
73
74    pub fn must(mut self, query: impl Query + 'static) -> Self {
75        self.must.push(Arc::new(query));
76        self
77    }
78
79    pub fn should(mut self, query: impl Query + 'static) -> Self {
80        self.should.push(Arc::new(query));
81        self
82    }
83
84    pub fn must_not(mut self, query: impl Query + 'static) -> Self {
85        self.must_not.push(Arc::new(query));
86        self
87    }
88
89    /// Set global statistics for cross-segment IDF
90    pub fn with_global_stats(mut self, stats: Arc<GlobalStats>) -> Self {
91        self.global_stats = Some(stats);
92        self
93    }
94}
95
96/// Build a SHOULD-only scorer from a vec of optimized scorers.
97fn build_should_scorer<'a>(scorers: Vec<Box<dyn Scorer + 'a>>) -> Box<dyn Scorer + 'a> {
98    if scorers.is_empty() {
99        return Box::new(EmptyScorer);
100    }
101    if scorers.len() == 1 {
102        return scorers.into_iter().next().unwrap();
103    }
104    let mut scorer = BooleanScorer {
105        must: vec![],
106        should: scorers,
107        must_not: vec![],
108        current_doc: 0,
109    };
110    scorer.current_doc = scorer.find_next_match();
111    Box::new(scorer)
112}
113
// ── Planner macro ────────────────────────────────────────────────────────
//
// Unified planner for both async and sync paths.  Parameterised on:
//   $must, $should, $must_not – clause lists (borrowed as `&[Arc<dyn Query>]`)
//   $global_stats    – `Option<&Arc<GlobalStats>>`
//   $reader, $limit  – segment reader and requested number of results
//   $scorer_fn       – scorer | scorer_sync
//   $get_postings_fn – get_postings | get_postings_sync
//   $execute_fn      – execute | execute_sync
//   $($aw)*          – await  (present for async, absent for sync)
//
// The expansion uses `return` and `?`, so it must be the body of a function
// or async block that returns a `Result` of a boxed scorer.
//
// Decision order:
//   1. Single-clause unwrap
//   2. Pure OR → text MaxScore | sparse MaxScore | per-field MaxScore
//   3. Filter push-down → predicate-aware sparse MaxScore | PredicatedScorer
//   4. Standard BooleanScorer fallback
macro_rules! boolean_plan {
    ($must:expr, $should:expr, $must_not:expr, $global_stats:expr,
     $reader:expr, $limit:expr,
     $scorer_fn:ident, $get_postings_fn:ident, $execute_fn:ident
     $(, $aw:tt)*) => {{
        let must: &[Arc<dyn Query>] = &$must;
        let should: &[Arc<dyn Query>] = &$should;
        let must_not: &[Arc<dyn Query>] = &$must_not;
        let global_stats: Option<&Arc<GlobalStats>> = $global_stats;
        let reader: &SegmentReader = $reader;
        let limit: usize = $limit;

        // ── 1. Single-clause optimisation ────────────────────────────────
        // A lone MUST or lone SHOULD clause is just that clause's own scorer.
        if must_not.is_empty() {
            if must.len() == 1 && should.is_empty() {
                return must[0].$scorer_fn(reader, limit) $(. $aw)* ;
            }
            if should.len() == 1 && must.is_empty() {
                return should[0].$scorer_fn(reader, limit) $(. $aw)* ;
            }
        }

        // ── 2. Pure OR → MaxScore optimisations ──────────────────────────
        if must.is_empty() && must_not.is_empty() && should.len() >= 2 {
            // 2a. Text MaxScore (single-field, all term queries)
            if let Some((mut infos, _field, avg_field_len, num_docs)) =
                prepare_text_maxscore(should, reader, global_stats)
            {
                let mut posting_lists = Vec::with_capacity(infos.len());
                for info in infos.drain(..) {
                    // Terms without a posting list are simply left out.
                    if let Some(pl) = reader.$get_postings_fn(info.field, &info.term)
                        $(. $aw)* ?
                    {
                        let idf = compute_idf(&pl, info.field, &info.term, num_docs, global_stats);
                        posting_lists.push((pl, idf));
                    }
                }
                return finish_text_maxscore(posting_lists, avg_field_len, limit);
            }

            // 2b. Sparse MaxScore (single-field, all sparse term queries)
            if let Some(infos) = extract_all_sparse_infos(should) {
                if let Some((executor, info)) =
                    build_sparse_maxscore_executor(&infos, reader, limit, None)
                {
                    let raw = executor.$execute_fn() $(. $aw)* ?;
                    return Ok(combine_sparse_results(raw, info.combiner, info.field, limit));
                }
            }

            // 2c. Per-field text MaxScore (multi-field term grouping)
            if let Some(grouping) = prepare_per_field_grouping(should, reader, limit, global_stats)
            {
                let mut scorers: Vec<Box<dyn Scorer + '_>> = Vec::new();
                for (field, avg_field_len, infos) in &grouping.multi_term_groups {
                    let mut posting_lists = Vec::with_capacity(infos.len());
                    for info in infos {
                        if let Some(pl) = reader.$get_postings_fn(info.field, &info.term)
                            $(. $aw)* ?
                        {
                            let idf = compute_idf(
                                &pl, *field, &info.term, grouping.num_docs, global_stats,
                            );
                            posting_lists.push((pl, idf));
                        }
                    }
                    if !posting_lists.is_empty() {
                        scorers.push(finish_text_maxscore(
                            posting_lists,
                            *avg_field_len,
                            grouping.per_field_limit,
                        )?);
                    }
                }
                // Clauses that could not be grouped keep their own scorer.
                for &idx in &grouping.fallback_indices {
                    scorers.push(should[idx].$scorer_fn(reader, limit) $(. $aw)* ?);
                }
                return Ok(build_should_scorer(scorers));
            }
        }

        // ── 3. Filter push-down (MUST + SHOULD) ─────────────────────────
        // The `usize::MAX / 4` guard keeps the 4x over-fetch below from overflowing.
        if !should.is_empty() && !must.is_empty() && limit < usize::MAX / 4 {
            // 3a. Compile MUST → predicates (O(1)) vs verifier scorers (seek)
            let mut predicates: Vec<super::DocPredicate<'_>> = Vec::new();
            let mut must_verifiers: Vec<Box<dyn super::Scorer + '_>> = Vec::new();
            // Clauses that were compiled to predicates.  They are remembered so
            // that the BooleanScorer fallback in 3c, which cannot evaluate
            // predicates, can still enforce them through regular scorers.
            let mut predicate_must: Vec<&Arc<dyn Query>> = Vec::new();
            let mut predicate_must_not: Vec<&Arc<dyn Query>> = Vec::new();
            for q in must {
                if let Some(pred) = q.as_doc_predicate(reader) {
                    predicates.push(pred);
                    predicate_must.push(q);
                } else {
                    must_verifiers.push(q.$scorer_fn(reader, limit) $(. $aw)* ?);
                }
            }
            // Compile MUST_NOT → negated predicates vs verifier scorers
            let mut must_not_verifiers: Vec<Box<dyn super::Scorer + '_>> = Vec::new();
            for q in must_not {
                if let Some(pred) = q.as_doc_predicate(reader) {
                    let negated: super::DocPredicate<'_> =
                        Box::new(move |doc_id| !pred(doc_id));
                    predicates.push(negated);
                    predicate_must_not.push(q);
                } else {
                    must_not_verifiers.push(q.$scorer_fn(reader, limit) $(. $aw)* ?);
                }
            }

            // 3b. Fast path: pure predicates + sparse SHOULD → MaxScore w/ predicate
            if must_verifiers.is_empty()
                && must_not_verifiers.is_empty()
                && !predicates.is_empty()
            {
                if let Some(infos) = extract_all_sparse_infos(should) {
                    let combined = chain_predicates(predicates);
                    if let Some((executor, info)) =
                        build_sparse_maxscore_executor(&infos, reader, limit, Some(combined))
                    {
                        log::debug!(
                            "BooleanQuery planner: predicate-aware sparse MaxScore, {} dims",
                            infos.len()
                        );
                        let raw = executor.$execute_fn() $(. $aw)* ?;
                        return Ok(combine_sparse_results(raw, info.combiner, info.field, limit));
                    }
                    // predicates consumed — cannot fall through; rebuild them
                    // (this path only triggers if sparse index is absent)
                    predicates = Vec::new();
                    for q in must {
                        if let Some(pred) = q.as_doc_predicate(reader) {
                            predicates.push(pred);
                        }
                    }
                    for q in must_not {
                        if let Some(pred) = q.as_doc_predicate(reader) {
                            let negated: super::DocPredicate<'_> =
                                Box::new(move |doc_id| !pred(doc_id));
                            predicates.push(negated);
                        }
                    }
                }
            }

            // 3c. PredicatedScorer fallback (over-fetch 4x when predicates present)
            let should_limit = if !predicates.is_empty() { limit * 4 } else { limit };
            let should_scorer = if should.len() == 1 {
                should[0].$scorer_fn(reader, should_limit) $(. $aw)* ?
            } else {
                // Re-plan the SHOULD clauses on their own so they can take the
                // pure-OR optimisations of step 2.
                let sub = BooleanQuery {
                    must: Vec::new(),
                    should: should.to_vec(),
                    must_not: Vec::new(),
                    global_stats: global_stats.cloned(),
                };
                sub.$scorer_fn(reader, should_limit) $(. $aw)* ?
            };

            // Saturate rather than truncate: a plain `limit as u32` wraps for
            // limits above `u32::MAX` (e.g. 2^32 becomes 0), which would make
            // the comparison below spuriously true.
            let limit_hint = u32::try_from(limit).unwrap_or(u32::MAX);
            let use_predicated =
                must_verifiers.is_empty() || should_scorer.size_hint() >= limit_hint;

            if use_predicated {
                log::debug!(
                    "BooleanQuery planner: PredicatedScorer {} preds + {} must_v + {} must_not_v, \
                     SHOULD size_hint={}, over_fetch={}",
                    predicates.len(), must_verifiers.len(), must_not_verifiers.len(),
                    should_scorer.size_hint(), should_limit
                );
                return Ok(Box::new(super::PredicatedScorer::new(
                    should_scorer, predicates, must_verifiers, must_not_verifiers,
                )));
            }

            // size_hint < limit with verifiers → BooleanScorer
            //
            // BooleanScorer only understands scorers.  Every clause that was
            // compiled to a predicate in 3a is therefore turned into a scorer
            // here; otherwise its filter would be silently dropped and the
            // result could contain documents that violate it.
            let mut must_scorers = must_verifiers;
            for q in predicate_must {
                must_scorers.push(q.$scorer_fn(reader, limit) $(. $aw)* ?);
            }
            let mut must_not_scorers = must_not_verifiers;
            for q in predicate_must_not {
                must_not_scorers.push(q.$scorer_fn(reader, limit) $(. $aw)* ?);
            }
            let mut scorer = BooleanScorer {
                must: must_scorers,
                should: vec![should_scorer],
                must_not: must_not_scorers,
                current_doc: 0,
            };
            scorer.current_doc = scorer.find_next_match();
            return Ok(Box::new(scorer));
        }

        // ── 4. Standard BooleanScorer fallback ───────────────────────────
        let mut must_scorers = Vec::with_capacity(must.len());
        for q in must {
            must_scorers.push(q.$scorer_fn(reader, limit) $(. $aw)* ?);
        }
        let mut should_scorers = Vec::with_capacity(should.len());
        for q in should {
            should_scorers.push(q.$scorer_fn(reader, limit) $(. $aw)* ?);
        }
        let mut must_not_scorers = Vec::with_capacity(must_not.len());
        for q in must_not {
            must_not_scorers.push(q.$scorer_fn(reader, limit) $(. $aw)* ?);
        }
        let mut scorer = BooleanScorer {
            must: must_scorers,
            should: should_scorers,
            must_not: must_not_scorers,
            current_doc: 0,
        };
        // Position on the first matching document before handing it out.
        scorer.current_doc = scorer.find_next_match();
        Ok(Box::new(scorer) as Box<dyn Scorer + '_>)
    }};
}
330
331impl Query for BooleanQuery {
332    fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
333        let must = self.must.clone();
334        let should = self.should.clone();
335        let must_not = self.must_not.clone();
336        let global_stats = self.global_stats.clone();
337        Box::pin(async move {
338            boolean_plan!(
339                must,
340                should,
341                must_not,
342                global_stats.as_ref(),
343                reader,
344                limit,
345                scorer,
346                get_postings,
347                execute,
348                await
349            )
350        })
351    }
352
353    #[cfg(feature = "sync")]
354    fn scorer_sync<'a>(
355        &self,
356        reader: &'a SegmentReader,
357        limit: usize,
358    ) -> crate::Result<Box<dyn Scorer + 'a>> {
359        boolean_plan!(
360            self.must,
361            self.should,
362            self.must_not,
363            self.global_stats.as_ref(),
364            reader,
365            limit,
366            scorer_sync,
367            get_postings_sync,
368            execute_sync
369        )
370    }
371
372    fn as_doc_predicate<'a>(&self, reader: &'a SegmentReader) -> Option<super::DocPredicate<'a>> {
373        // Need at least some clauses
374        if self.must.is_empty() && self.should.is_empty() {
375            return None;
376        }
377
378        // Try converting all clauses to predicates; bail if any child can't
379        let must_preds: Vec<_> = self
380            .must
381            .iter()
382            .map(|q| q.as_doc_predicate(reader))
383            .collect::<Option<Vec<_>>>()?;
384        let should_preds: Vec<_> = self
385            .should
386            .iter()
387            .map(|q| q.as_doc_predicate(reader))
388            .collect::<Option<Vec<_>>>()?;
389        let must_not_preds: Vec<_> = self
390            .must_not
391            .iter()
392            .map(|q| q.as_doc_predicate(reader))
393            .collect::<Option<Vec<_>>>()?;
394
395        let has_must = !must_preds.is_empty();
396
397        Some(Box::new(move |doc_id| {
398            // All MUST predicates must pass
399            if !must_preds.iter().all(|p| p(doc_id)) {
400                return false;
401            }
402            // When there are no MUST clauses, at least one SHOULD must pass
403            if !has_must && !should_preds.is_empty() && !should_preds.iter().any(|p| p(doc_id)) {
404                return false;
405            }
406            // No MUST_NOT predicate should pass
407            must_not_preds.iter().all(|p| !p(doc_id))
408        }))
409    }
410
411    fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
412        let must = self.must.clone();
413        let should = self.should.clone();
414
415        Box::pin(async move {
416            if !must.is_empty() {
417                let mut estimates = Vec::with_capacity(must.len());
418                for q in &must {
419                    estimates.push(q.count_estimate(reader).await?);
420                }
421                estimates
422                    .into_iter()
423                    .min()
424                    .ok_or_else(|| crate::Error::Corruption("Empty must clause".to_string()))
425            } else if !should.is_empty() {
426                let mut sum = 0u32;
427                for q in &should {
428                    sum = sum.saturating_add(q.count_estimate(reader).await?);
429                }
430                Ok(sum)
431            } else {
432                Ok(0)
433            }
434        })
435    }
436}
437
438struct BooleanScorer<'a> {
439    must: Vec<Box<dyn Scorer + 'a>>,
440    should: Vec<Box<dyn Scorer + 'a>>,
441    must_not: Vec<Box<dyn Scorer + 'a>>,
442    current_doc: DocId,
443}
444
445impl BooleanScorer<'_> {
446    fn find_next_match(&mut self) -> DocId {
447        if self.must.is_empty() && self.should.is_empty() {
448            return TERMINATED;
449        }
450
451        loop {
452            let candidate = if !self.must.is_empty() {
453                let mut max_doc = self
454                    .must
455                    .iter()
456                    .map(|s| s.doc())
457                    .max()
458                    .unwrap_or(TERMINATED);
459
460                if max_doc == TERMINATED {
461                    return TERMINATED;
462                }
463
464                loop {
465                    let mut all_match = true;
466                    for scorer in &mut self.must {
467                        let doc = scorer.seek(max_doc);
468                        if doc == TERMINATED {
469                            return TERMINATED;
470                        }
471                        if doc > max_doc {
472                            max_doc = doc;
473                            all_match = false;
474                            break;
475                        }
476                    }
477                    if all_match {
478                        break;
479                    }
480                }
481                max_doc
482            } else {
483                self.should
484                    .iter()
485                    .map(|s| s.doc())
486                    .filter(|&d| d != TERMINATED)
487                    .min()
488                    .unwrap_or(TERMINATED)
489            };
490
491            if candidate == TERMINATED {
492                return TERMINATED;
493            }
494
495            let excluded = self.must_not.iter_mut().any(|scorer| {
496                let doc = scorer.seek(candidate);
497                doc == candidate
498            });
499
500            if !excluded {
501                // Seek SHOULD scorers to candidate so score() can see their contributions
502                for scorer in &mut self.should {
503                    scorer.seek(candidate);
504                }
505                self.current_doc = candidate;
506                return candidate;
507            }
508
509            // Advance past excluded candidate
510            if !self.must.is_empty() {
511                for scorer in &mut self.must {
512                    scorer.advance();
513                }
514            } else {
515                // For SHOULD-only: seek all scorers past the excluded candidate
516                for scorer in &mut self.should {
517                    if scorer.doc() <= candidate && scorer.doc() != TERMINATED {
518                        scorer.seek(candidate + 1);
519                    }
520                }
521            }
522        }
523    }
524}
525
526impl super::docset::DocSet for BooleanScorer<'_> {
527    fn doc(&self) -> DocId {
528        self.current_doc
529    }
530
531    fn advance(&mut self) -> DocId {
532        if !self.must.is_empty() {
533            for scorer in &mut self.must {
534                scorer.advance();
535            }
536        } else {
537            for scorer in &mut self.should {
538                if scorer.doc() == self.current_doc {
539                    scorer.advance();
540                }
541            }
542        }
543
544        self.current_doc = self.find_next_match();
545        self.current_doc
546    }
547
548    fn seek(&mut self, target: DocId) -> DocId {
549        for scorer in &mut self.must {
550            scorer.seek(target);
551        }
552
553        for scorer in &mut self.should {
554            scorer.seek(target);
555        }
556
557        self.current_doc = self.find_next_match();
558        self.current_doc
559    }
560
561    fn size_hint(&self) -> u32 {
562        if !self.must.is_empty() {
563            self.must.iter().map(|s| s.size_hint()).min().unwrap_or(0)
564        } else {
565            self.should.iter().map(|s| s.size_hint()).sum()
566        }
567    }
568}
569
570impl Scorer for BooleanScorer<'_> {
571    fn score(&self) -> Score {
572        let mut total = 0.0;
573
574        for scorer in &self.must {
575            if scorer.doc() == self.current_doc {
576                total += scorer.score();
577            }
578        }
579
580        for scorer in &self.should {
581            if scorer.doc() == self.current_doc {
582                total += scorer.score();
583            }
584        }
585
586        total
587    }
588
589    fn matched_positions(&self) -> Option<super::MatchedPositions> {
590        let mut all_positions: super::MatchedPositions = Vec::new();
591
592        for scorer in &self.must {
593            if scorer.doc() == self.current_doc
594                && let Some(positions) = scorer.matched_positions()
595            {
596                all_positions.extend(positions);
597            }
598        }
599
600        for scorer in &self.should {
601            if scorer.doc() == self.current_doc
602                && let Some(positions) = scorer.matched_positions()
603            {
604                all_positions.extend(positions);
605            }
606        }
607
608        if all_positions.is_empty() {
609            None
610        } else {
611            Some(all_positions)
612        }
613    }
614}
615
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::Field;
    use crate::query::{QueryDecomposition, TermQuery};

    /// Fields of all SHOULD clauses that decompose into a text term, in clause order.
    fn should_text_term_fields(query: &BooleanQuery) -> Vec<Field> {
        query
            .should
            .iter()
            .filter_map(|clause| match clause.decompose() {
                QueryDecomposition::TextTerm(info) => Some(info.field),
                _ => None,
            })
            .collect()
    }

    /// A pure OR over several terms of one field has the shape MaxScore needs.
    #[test]
    fn test_maxscore_eligible_pure_or_same_field() {
        let query = BooleanQuery::new()
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(0), "world"))
            .should(TermQuery::text(Field(0), "foo"));

        // Every clause must expose term info ...
        for clause in &query.should {
            assert!(matches!(
                clause.decompose(),
                QueryDecomposition::TextTerm(_)
            ));
        }

        // ... and all of them must target the same field.
        let fields = should_text_term_fields(&query);
        assert_eq!(fields.len(), 3);
        assert!(fields.iter().all(|field| *field == Field(0)));
    }

    /// Terms spread over different fields do not qualify for single-field MaxScore.
    #[test]
    fn test_maxscore_not_eligible_different_fields() {
        let query = BooleanQuery::new()
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(1), "world")); // Different field!

        let fields = should_text_term_fields(&query);
        assert_eq!(fields.len(), 2);
        // Fields are different, MaxScore should not be used
        assert!(fields[0] != fields[1]);
    }

    /// A MUST clause rules out the pure-OR MaxScore path.
    #[test]
    fn test_maxscore_not_eligible_with_must() {
        let query = BooleanQuery::new()
            .must(TermQuery::text(Field(0), "required"))
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(0), "world"));

        // Has MUST clause, so MaxScore optimization should not kick in
        assert!(!query.must.is_empty());
    }

    /// A MUST_NOT clause rules out the pure-OR MaxScore path.
    #[test]
    fn test_maxscore_not_eligible_with_must_not() {
        let query = BooleanQuery::new()
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(0), "world"))
            .must_not(TermQuery::text(Field(0), "excluded"));

        // Has MUST_NOT clause, so MaxScore optimization should not kick in
        assert!(!query.must_not.is_empty());
    }

    /// A single SHOULD clause gains nothing from MaxScore.
    #[test]
    fn test_maxscore_not_eligible_single_term() {
        let query = BooleanQuery::new().should(TermQuery::text(Field(0), "hello"));

        // Only one term, MaxScore not beneficial
        assert_eq!(query.should.len(), 1);
    }

    /// A text `TermQuery` decomposes into its field and term bytes.
    #[test]
    fn test_term_query_info_extraction() {
        let term_query = TermQuery::text(Field(42), "test");
        let QueryDecomposition::TextTerm(info) = term_query.decompose() else {
            panic!("Expected TextTerm decomposition");
        };
        assert_eq!(info.field, Field(42));
        assert_eq!(info.term, b"test");
    }

    /// A `BooleanQuery` itself is opaque to decomposition, even with one term.
    #[test]
    fn test_boolean_query_no_term_info() {
        let query = BooleanQuery::new().should(TermQuery::text(Field(0), "hello"));

        assert!(matches!(query.decompose(), QueryDecomposition::Opaque));
    }
}