Skip to main content

hermes_core/query/
boolean.rs

1//! Boolean query with MUST, SHOULD, and MUST_NOT clauses
2
3use std::sync::Arc;
4
5use crate::segment::SegmentReader;
6use crate::structures::TERMINATED;
7use crate::{DocId, Score};
8
9use super::{CountFuture, GlobalStats, MaxScoreExecutor, Query, ScoredDoc, Scorer, ScorerFuture};
10
11/// Boolean query with MUST, SHOULD, and MUST_NOT clauses
12///
13/// When all clauses are SHOULD term queries on the same field, automatically
14/// uses MaxScore optimization for efficient top-k retrieval.
15#[derive(Default, Clone)]
16pub struct BooleanQuery {
17    pub must: Vec<Arc<dyn Query>>,
18    pub should: Vec<Arc<dyn Query>>,
19    pub must_not: Vec<Arc<dyn Query>>,
20    /// Optional global statistics for cross-segment IDF
21    global_stats: Option<Arc<GlobalStats>>,
22}
23
24impl std::fmt::Debug for BooleanQuery {
25    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26        f.debug_struct("BooleanQuery")
27            .field("must_count", &self.must.len())
28            .field("should_count", &self.should.len())
29            .field("must_not_count", &self.must_not.len())
30            .field("has_global_stats", &self.global_stats.is_some())
31            .finish()
32    }
33}
34
35impl BooleanQuery {
36    pub fn new() -> Self {
37        Self::default()
38    }
39
40    pub fn must(mut self, query: impl Query + 'static) -> Self {
41        self.must.push(Arc::new(query));
42        self
43    }
44
45    pub fn should(mut self, query: impl Query + 'static) -> Self {
46        self.should.push(Arc::new(query));
47        self
48    }
49
50    pub fn must_not(mut self, query: impl Query + 'static) -> Self {
51        self.must_not.push(Arc::new(query));
52        self
53    }
54
55    /// Set global statistics for cross-segment IDF
56    pub fn with_global_stats(mut self, stats: Arc<GlobalStats>) -> Self {
57        self.global_stats = Some(stats);
58        self
59    }
60}
61
62/// Check if SHOULD clauses are eligible for MaxScore optimization.
63/// Returns (term_infos, field) if all are single-field term queries, None otherwise.
64fn maxscore_eligible(
65    should: &[Arc<dyn Query>],
66) -> Option<(Vec<super::TermQueryInfo>, crate::Field)> {
67    let term_infos: Vec<_> = should
68        .iter()
69        .filter_map(|q| q.as_term_query_info())
70        .collect();
71    if term_infos.len() != should.len() {
72        return None;
73    }
74    let first_field = term_infos[0].field;
75    if !term_infos.iter().all(|t| t.field == first_field) {
76        return None;
77    }
78    Some((term_infos, first_field))
79}
80
81/// Compute IDF for a posting list, preferring global stats.
82fn compute_idf(
83    posting_list: &crate::structures::BlockPostingList,
84    field: crate::Field,
85    term: &[u8],
86    num_docs: f32,
87    global_stats: Option<&Arc<GlobalStats>>,
88) -> f32 {
89    if let Some(stats) = global_stats {
90        let global_idf = stats.text_idf(field, &String::from_utf8_lossy(term));
91        if global_idf > 0.0 {
92            return global_idf;
93        }
94    }
95    let doc_freq = posting_list.doc_count() as f32;
96    super::bm25_idf(doc_freq, num_docs)
97}
98
99/// Build MaxScore scorer from pre-fetched posting lists.
100fn maxscore_scorer_from_postings<'a>(
101    posting_lists: Vec<(crate::structures::BlockPostingList, f32)>,
102    avg_field_len: f32,
103    limit: usize,
104    predicate: Option<super::DocPredicate<'a>>,
105) -> crate::Result<Box<dyn Scorer + 'a>> {
106    if posting_lists.is_empty() {
107        return Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>);
108    }
109    let mut executor = MaxScoreExecutor::text(posting_lists, avg_field_len, limit);
110    executor.set_predicate(predicate);
111    let results = executor.execute_sync()?;
112    Ok(Box::new(TopKResultScorer::new(results)) as Box<dyn Scorer + 'a>)
113}
114
115/// Try to create a MaxScore-optimized scorer for pure OR queries (async)
116async fn try_maxscore_scorer<'a>(
117    should: &[Arc<dyn Query>],
118    reader: &'a SegmentReader,
119    limit: usize,
120    global_stats: Option<&Arc<GlobalStats>>,
121    predicate: Option<super::DocPredicate<'a>>,
122) -> crate::Result<Option<Box<dyn Scorer + 'a>>> {
123    let (mut term_infos, field) = match maxscore_eligible(should) {
124        Some(v) => v,
125        None => return Ok(None),
126    };
127
128    let avg_field_len = global_stats
129        .map(|s| s.avg_field_len(field))
130        .unwrap_or_else(|| reader.avg_field_len(field));
131    let num_docs = reader.num_docs() as f32;
132
133    let mut posting_lists: Vec<(crate::structures::BlockPostingList, f32)> =
134        Vec::with_capacity(term_infos.len());
135    for info in term_infos.drain(..) {
136        if let Some(pl) = reader.get_postings(info.field, &info.term).await? {
137            let idf = compute_idf(&pl, info.field, &info.term, num_docs, global_stats);
138            posting_lists.push((pl, idf));
139        }
140    }
141
142    Ok(Some(maxscore_scorer_from_postings(
143        posting_lists,
144        avg_field_len,
145        limit,
146        predicate,
147    )?))
148}
149
/// Try to create a MaxScore-optimized scorer for pure OR queries (sync).
///
/// Blocking counterpart of `try_maxscore_scorer`: returns `Ok(None)` when the
/// SHOULD clauses are not MaxScore-eligible, otherwise fetches each term's
/// posting list synchronously (skipping terms with none in this segment) and
/// builds the scorer from what was found.
///
/// # Errors
/// Propagates posting-list lookup and executor errors.
#[cfg(feature = "sync")]
fn try_maxscore_scorer_sync<'a>(
    should: &[Arc<dyn Query>],
    reader: &'a SegmentReader,
    limit: usize,
    global_stats: Option<&Arc<GlobalStats>>,
    predicate: Option<super::DocPredicate<'a>>,
) -> crate::Result<Option<Box<dyn Scorer + 'a>>> {
    let Some((term_infos, field)) = maxscore_eligible(should) else {
        return Ok(None);
    };

    let avg_field_len = match global_stats {
        Some(stats) => stats.avg_field_len(field),
        None => reader.avg_field_len(field),
    };
    let num_docs = reader.num_docs() as f32;

    let mut posting_lists = Vec::with_capacity(term_infos.len());
    for info in term_infos {
        let Some(postings) = reader.get_postings_sync(info.field, &info.term)? else {
            // Term has no postings in this segment.
            continue;
        };
        let idf = compute_idf(&postings, info.field, &info.term, num_docs, global_stats);
        posting_lists.push((postings, idf));
    }

    maxscore_scorer_from_postings(posting_lists, avg_field_len, limit, predicate).map(Some)
}
185
186impl Query for BooleanQuery {
187    fn scorer<'a>(
188        &self,
189        reader: &'a SegmentReader,
190        limit: usize,
191        predicate: Option<super::DocPredicate<'a>>,
192    ) -> ScorerFuture<'a> {
193        // Clone Arc vectors - cheap reference counting
194        let must = self.must.clone();
195        let should = self.should.clone();
196        let must_not = self.must_not.clone();
197        let global_stats = self.global_stats.clone();
198
199        Box::pin(async move {
200            // Check if this is a pure OR query eligible for MaxScore optimization
201            // Conditions: no MUST, no MUST_NOT, multiple SHOULD clauses, all same field
202            if must.is_empty()
203                && must_not.is_empty()
204                && should.len() >= 2
205                && let Some(scorer) =
206                    try_maxscore_scorer(&should, reader, limit, global_stats.as_ref(), predicate)
207                        .await?
208            {
209                return Ok(scorer);
210            }
211
212            // Fall back to standard boolean scoring
213            // Predicate not passed to sub-scorers — it's only useful for executors
214            let mut must_scorers = Vec::with_capacity(must.len());
215            for q in &must {
216                must_scorers.push(q.scorer(reader, limit, None).await?);
217            }
218
219            let mut should_scorers = Vec::with_capacity(should.len());
220            for q in &should {
221                should_scorers.push(q.scorer(reader, limit, None).await?);
222            }
223
224            let mut must_not_scorers = Vec::with_capacity(must_not.len());
225            for q in &must_not {
226                must_not_scorers.push(q.scorer(reader, limit, None).await?);
227            }
228
229            let mut scorer = BooleanScorer {
230                must: must_scorers,
231                should: should_scorers,
232                must_not: must_not_scorers,
233                current_doc: 0,
234            };
235            // Initialize to first match
236            scorer.current_doc = scorer.find_next_match();
237            Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
238        })
239    }
240
241    #[cfg(feature = "sync")]
242    fn scorer_sync<'a>(
243        &self,
244        reader: &'a SegmentReader,
245        limit: usize,
246        predicate: Option<super::DocPredicate<'a>>,
247    ) -> crate::Result<Box<dyn Scorer + 'a>> {
248        // MaxScore optimization for pure OR queries
249        if self.must.is_empty()
250            && self.must_not.is_empty()
251            && self.should.len() >= 2
252            && let Some(scorer) = try_maxscore_scorer_sync(
253                &self.should,
254                reader,
255                limit,
256                self.global_stats.as_ref(),
257                predicate,
258            )?
259        {
260            return Ok(scorer);
261        }
262
263        // Fall back to standard boolean scoring
264        let mut must_scorers = Vec::with_capacity(self.must.len());
265        for q in &self.must {
266            must_scorers.push(q.scorer_sync(reader, limit, None)?);
267        }
268
269        let mut should_scorers = Vec::with_capacity(self.should.len());
270        for q in &self.should {
271            should_scorers.push(q.scorer_sync(reader, limit, None)?);
272        }
273
274        let mut must_not_scorers = Vec::with_capacity(self.must_not.len());
275        for q in &self.must_not {
276            must_not_scorers.push(q.scorer_sync(reader, limit, None)?);
277        }
278
279        let mut scorer = BooleanScorer {
280            must: must_scorers,
281            should: should_scorers,
282            must_not: must_not_scorers,
283            current_doc: 0,
284        };
285        scorer.current_doc = scorer.find_next_match();
286        Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
287    }
288
289    fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
290        let must = self.must.clone();
291        let should = self.should.clone();
292
293        Box::pin(async move {
294            if !must.is_empty() {
295                let mut estimates = Vec::with_capacity(must.len());
296                for q in &must {
297                    estimates.push(q.count_estimate(reader).await?);
298                }
299                estimates
300                    .into_iter()
301                    .min()
302                    .ok_or_else(|| crate::Error::Corruption("Empty must clause".to_string()))
303            } else if !should.is_empty() {
304                let mut sum = 0u32;
305                for q in &should {
306                    sum += q.count_estimate(reader).await?;
307                }
308                Ok(sum)
309            } else {
310                Ok(0)
311            }
312        })
313    }
314}
315
316struct BooleanScorer<'a> {
317    must: Vec<Box<dyn Scorer + 'a>>,
318    should: Vec<Box<dyn Scorer + 'a>>,
319    must_not: Vec<Box<dyn Scorer + 'a>>,
320    current_doc: DocId,
321}
322
323impl BooleanScorer<'_> {
324    fn find_next_match(&mut self) -> DocId {
325        if self.must.is_empty() && self.should.is_empty() {
326            return TERMINATED;
327        }
328
329        loop {
330            let candidate = if !self.must.is_empty() {
331                let mut max_doc = self
332                    .must
333                    .iter()
334                    .map(|s| s.doc())
335                    .max()
336                    .unwrap_or(TERMINATED);
337
338                if max_doc == TERMINATED {
339                    return TERMINATED;
340                }
341
342                loop {
343                    let mut all_match = true;
344                    for scorer in &mut self.must {
345                        let doc = scorer.seek(max_doc);
346                        if doc == TERMINATED {
347                            return TERMINATED;
348                        }
349                        if doc > max_doc {
350                            max_doc = doc;
351                            all_match = false;
352                            break;
353                        }
354                    }
355                    if all_match {
356                        break;
357                    }
358                }
359                max_doc
360            } else {
361                self.should
362                    .iter()
363                    .map(|s| s.doc())
364                    .filter(|&d| d != TERMINATED)
365                    .min()
366                    .unwrap_or(TERMINATED)
367            };
368
369            if candidate == TERMINATED {
370                return TERMINATED;
371            }
372
373            let excluded = self.must_not.iter_mut().any(|scorer| {
374                let doc = scorer.seek(candidate);
375                doc == candidate
376            });
377
378            if !excluded {
379                self.current_doc = candidate;
380                return candidate;
381            }
382
383            // Advance past excluded candidate
384            if !self.must.is_empty() {
385                for scorer in &mut self.must {
386                    scorer.advance();
387                }
388            } else {
389                // For SHOULD-only: seek all scorers past the excluded candidate
390                for scorer in &mut self.should {
391                    if scorer.doc() <= candidate && scorer.doc() != TERMINATED {
392                        scorer.seek(candidate + 1);
393                    }
394                }
395            }
396        }
397    }
398}
399
400impl Scorer for BooleanScorer<'_> {
401    fn doc(&self) -> DocId {
402        self.current_doc
403    }
404
405    fn score(&self) -> Score {
406        let mut total = 0.0;
407
408        for scorer in &self.must {
409            if scorer.doc() == self.current_doc {
410                total += scorer.score();
411            }
412        }
413
414        for scorer in &self.should {
415            if scorer.doc() == self.current_doc {
416                total += scorer.score();
417            }
418        }
419
420        total
421    }
422
423    fn advance(&mut self) -> DocId {
424        if !self.must.is_empty() {
425            for scorer in &mut self.must {
426                scorer.advance();
427            }
428        } else {
429            for scorer in &mut self.should {
430                if scorer.doc() == self.current_doc {
431                    scorer.advance();
432                }
433            }
434        }
435
436        self.current_doc = self.find_next_match();
437        self.current_doc
438    }
439
440    fn seek(&mut self, target: DocId) -> DocId {
441        for scorer in &mut self.must {
442            scorer.seek(target);
443        }
444
445        for scorer in &mut self.should {
446            scorer.seek(target);
447        }
448
449        self.current_doc = self.find_next_match();
450        self.current_doc
451    }
452
453    fn size_hint(&self) -> u32 {
454        if !self.must.is_empty() {
455            self.must.iter().map(|s| s.size_hint()).min().unwrap_or(0)
456        } else {
457            self.should.iter().map(|s| s.size_hint()).sum()
458        }
459    }
460}
461
462/// Scorer that iterates over pre-computed top-k results
463struct TopKResultScorer {
464    results: Vec<ScoredDoc>,
465    position: usize,
466}
467
468impl TopKResultScorer {
469    fn new(results: Vec<ScoredDoc>) -> Self {
470        Self {
471            results,
472            position: 0,
473        }
474    }
475}
476
477impl Scorer for TopKResultScorer {
478    fn doc(&self) -> DocId {
479        if self.position < self.results.len() {
480            self.results[self.position].doc_id
481        } else {
482            TERMINATED
483        }
484    }
485
486    fn score(&self) -> Score {
487        if self.position < self.results.len() {
488            self.results[self.position].score
489        } else {
490            0.0
491        }
492    }
493
494    fn advance(&mut self) -> DocId {
495        self.position += 1;
496        self.doc()
497    }
498
499    fn seek(&mut self, target: DocId) -> DocId {
500        while self.position < self.results.len() && self.results[self.position].doc_id < target {
501            self.position += 1;
502        }
503        self.doc()
504    }
505
506    fn size_hint(&self) -> u32 {
507        self.results.len() as u32
508    }
509}
510
/// Scorer that matches no documents; used when none of the query terms has
/// a posting list.
struct EmptyScorer;
513
514impl Scorer for EmptyScorer {
515    fn doc(&self) -> DocId {
516        TERMINATED
517    }
518
519    fn score(&self) -> Score {
520        0.0
521    }
522
523    fn advance(&mut self) -> DocId {
524        TERMINATED
525    }
526
527    fn seek(&mut self, _target: DocId) -> DocId {
528        TERMINATED
529    }
530
531    fn size_hint(&self) -> u32 {
532        0
533    }
534}
535
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::Field;
    use crate::query::TermQuery;

    #[test]
    fn test_maxscore_eligible_pure_or_same_field() {
        // Pure OR query with multiple terms in same field should be MaxScore-eligible
        let query = BooleanQuery::new()
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(0), "world"))
            .should(TermQuery::text(Field(0), "foo"));

        // All clauses should return term info
        assert!(
            query
                .should
                .iter()
                .all(|q| q.as_term_query_info().is_some())
        );

        // The eligibility check itself must accept the clauses and report
        // every term together with the shared field.
        let (infos, field) = maxscore_eligible(&query.should)
            .expect("same-field term clauses should be MaxScore-eligible");
        assert_eq!(infos.len(), 3);
        assert_eq!(field, Field(0));
        assert!(infos.iter().all(|i| i.field == Field(0)));
    }

    #[test]
    fn test_maxscore_not_eligible_different_fields() {
        // OR query with terms in different fields should NOT use MaxScore
        let query = BooleanQuery::new()
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(1), "world")); // Different field!

        let infos: Vec<_> = query
            .should
            .iter()
            .filter_map(|q| q.as_term_query_info())
            .collect();
        assert_eq!(infos.len(), 2);
        assert!(infos[0].field != infos[1].field);

        // Fields are different, so the eligibility check must reject them
        assert!(maxscore_eligible(&query.should).is_none());
    }

    #[test]
    fn test_maxscore_not_eligible_non_term_clause() {
        // A SHOULD clause that is not a term query rules out MaxScore
        let nested = BooleanQuery::new().should(TermQuery::text(Field(0), "world"));
        let query = BooleanQuery::new()
            .should(TermQuery::text(Field(0), "hello"))
            .should(nested);

        assert!(maxscore_eligible(&query.should).is_none());
    }

    // The three tests below only check the query shape: the MUST / MUST_NOT /
    // clause-count gating lives in `scorer`, which needs a segment reader.

    #[test]
    fn test_maxscore_not_eligible_with_must() {
        // Query with MUST clause should NOT use MaxScore optimization
        let query = BooleanQuery::new()
            .must(TermQuery::text(Field(0), "required"))
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(0), "world"));

        // Has MUST clause, so MaxScore optimization should not kick in
        assert!(!query.must.is_empty());
    }

    #[test]
    fn test_maxscore_not_eligible_with_must_not() {
        // Query with MUST_NOT clause should NOT use MaxScore optimization
        let query = BooleanQuery::new()
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(0), "world"))
            .must_not(TermQuery::text(Field(0), "excluded"));

        // Has MUST_NOT clause, so MaxScore optimization should not kick in
        assert!(!query.must_not.is_empty());
    }

    #[test]
    fn test_maxscore_not_eligible_single_term() {
        // Single SHOULD clause should NOT use MaxScore (no benefit)
        let query = BooleanQuery::new().should(TermQuery::text(Field(0), "hello"));

        // Only one term, MaxScore not beneficial
        assert_eq!(query.should.len(), 1);
    }

    #[test]
    fn test_term_query_info_extraction() {
        let term_query = TermQuery::text(Field(42), "test");
        let info = term_query
            .as_term_query_info()
            .expect("a term query should expose its term info");

        assert_eq!(info.field, Field(42));
        assert_eq!(info.term, b"test");
    }

    #[test]
    fn test_boolean_query_no_term_info() {
        // BooleanQuery itself should not return term info
        let query = BooleanQuery::new().should(TermQuery::text(Field(0), "hello"));

        assert!(query.as_term_query_info().is_none());
    }
}
636}