Skip to main content

hermes_core/query/
boolean.rs

1//! Boolean query with MUST, SHOULD, and MUST_NOT clauses
2
3use std::sync::Arc;
4
5use crate::segment::SegmentReader;
6use crate::structures::TERMINATED;
7use crate::{DocId, Score};
8
9use super::{
10    CountFuture, GlobalStats, MaxScoreExecutor, Query, ScoredDoc, Scorer, ScorerFuture,
11    TextTermScorer,
12};
13
14/// Boolean query with MUST, SHOULD, and MUST_NOT clauses
15///
16/// When all clauses are SHOULD term queries on the same field, automatically
17/// uses MaxScore optimization for efficient top-k retrieval.
18#[derive(Default, Clone)]
19pub struct BooleanQuery {
20    pub must: Vec<Arc<dyn Query>>,
21    pub should: Vec<Arc<dyn Query>>,
22    pub must_not: Vec<Arc<dyn Query>>,
23    /// Optional global statistics for cross-segment IDF
24    global_stats: Option<Arc<GlobalStats>>,
25}
26
27impl std::fmt::Debug for BooleanQuery {
28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29        f.debug_struct("BooleanQuery")
30            .field("must_count", &self.must.len())
31            .field("should_count", &self.should.len())
32            .field("must_not_count", &self.must_not.len())
33            .field("has_global_stats", &self.global_stats.is_some())
34            .finish()
35    }
36}
37
38impl BooleanQuery {
39    pub fn new() -> Self {
40        Self::default()
41    }
42
43    pub fn must(mut self, query: impl Query + 'static) -> Self {
44        self.must.push(Arc::new(query));
45        self
46    }
47
48    pub fn should(mut self, query: impl Query + 'static) -> Self {
49        self.should.push(Arc::new(query));
50        self
51    }
52
53    pub fn must_not(mut self, query: impl Query + 'static) -> Self {
54        self.must_not.push(Arc::new(query));
55        self
56    }
57
58    /// Set global statistics for cross-segment IDF
59    pub fn with_global_stats(mut self, stats: Arc<GlobalStats>) -> Self {
60        self.global_stats = Some(stats);
61        self
62    }
63}
64
65/// Try to create a MaxScore-optimized scorer for pure OR queries
66async fn try_maxscore_scorer<'a>(
67    should: &[Arc<dyn Query>],
68    reader: &'a SegmentReader,
69    limit: usize,
70    global_stats: Option<&Arc<GlobalStats>>,
71) -> crate::Result<Option<Box<dyn Scorer + 'a>>> {
72    // Extract term info from all SHOULD clauses
73    let mut term_infos: Vec<_> = should
74        .iter()
75        .filter_map(|q| q.as_term_query_info())
76        .collect();
77
78    // Check if all clauses are term queries
79    if term_infos.len() != should.len() {
80        return Ok(None);
81    }
82
83    // Check if all terms are for the same field
84    let first_field = term_infos[0].field;
85    if !term_infos.iter().all(|t| t.field == first_field) {
86        return Ok(None);
87    }
88
89    // Build scorers for each term
90    let mut scorers: Vec<TextTermScorer> = Vec::with_capacity(term_infos.len());
91    let avg_field_len = global_stats
92        .map(|s| s.avg_field_len(first_field))
93        .unwrap_or_else(|| reader.avg_field_len(first_field));
94    let num_docs = reader.num_docs() as f32;
95
96    for info in term_infos.drain(..) {
97        if let Some(posting_list) = reader.get_postings(info.field, &info.term).await? {
98            let doc_freq = posting_list.doc_count() as f32;
99            let idf = if let Some(stats) = global_stats {
100                let global_idf = stats.text_idf(info.field, &String::from_utf8_lossy(&info.term));
101                if global_idf > 0.0 {
102                    global_idf
103                } else {
104                    super::bm25_idf(doc_freq, num_docs)
105                }
106            } else {
107                super::bm25_idf(doc_freq, num_docs)
108            };
109            scorers.push(TextTermScorer::new(posting_list, idf, avg_field_len));
110        }
111    }
112
113    if scorers.is_empty() {
114        return Ok(Some(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>));
115    }
116
117    // Use MaxScore executor for efficient top-k
118    let results = MaxScoreExecutor::new(scorers, limit).execute();
119    Ok(Some(
120        Box::new(TopKResultScorer::new(results)) as Box<dyn Scorer + 'a>
121    ))
122}
123
124impl Query for BooleanQuery {
125    fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
126        // Clone Arc vectors - cheap reference counting
127        let must = self.must.clone();
128        let should = self.should.clone();
129        let must_not = self.must_not.clone();
130        let global_stats = self.global_stats.clone();
131
132        Box::pin(async move {
133            // Check if this is a pure OR query eligible for MaxScore optimization
134            // Conditions: no MUST, no MUST_NOT, multiple SHOULD clauses, all same field
135            if must.is_empty()
136                && must_not.is_empty()
137                && should.len() >= 2
138                && let Some(scorer) =
139                    try_maxscore_scorer(&should, reader, limit, global_stats.as_ref()).await?
140            {
141                return Ok(scorer);
142            }
143
144            // Fall back to standard boolean scoring
145            let mut must_scorers = Vec::with_capacity(must.len());
146            for q in &must {
147                must_scorers.push(q.scorer(reader, limit).await?);
148            }
149
150            let mut should_scorers = Vec::with_capacity(should.len());
151            for q in &should {
152                should_scorers.push(q.scorer(reader, limit).await?);
153            }
154
155            let mut must_not_scorers = Vec::with_capacity(must_not.len());
156            for q in &must_not {
157                must_not_scorers.push(q.scorer(reader, limit).await?);
158            }
159
160            let mut scorer = BooleanScorer {
161                must: must_scorers,
162                should: should_scorers,
163                must_not: must_not_scorers,
164                current_doc: 0,
165            };
166            // Initialize to first match
167            scorer.current_doc = scorer.find_next_match();
168            Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
169        })
170    }
171
172    fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
173        let must = self.must.clone();
174        let should = self.should.clone();
175
176        Box::pin(async move {
177            if !must.is_empty() {
178                let mut estimates = Vec::with_capacity(must.len());
179                for q in &must {
180                    estimates.push(q.count_estimate(reader).await?);
181                }
182                estimates
183                    .into_iter()
184                    .min()
185                    .ok_or_else(|| crate::Error::Corruption("Empty must clause".to_string()))
186            } else if !should.is_empty() {
187                let mut sum = 0u32;
188                for q in &should {
189                    sum += q.count_estimate(reader).await?;
190                }
191                Ok(sum)
192            } else {
193                Ok(0)
194            }
195        })
196    }
197}
198
199struct BooleanScorer<'a> {
200    must: Vec<Box<dyn Scorer + 'a>>,
201    should: Vec<Box<dyn Scorer + 'a>>,
202    must_not: Vec<Box<dyn Scorer + 'a>>,
203    current_doc: DocId,
204}
205
206impl BooleanScorer<'_> {
207    fn find_next_match(&mut self) -> DocId {
208        if self.must.is_empty() && self.should.is_empty() {
209            return TERMINATED;
210        }
211
212        loop {
213            let candidate = if !self.must.is_empty() {
214                let mut max_doc = self
215                    .must
216                    .iter()
217                    .map(|s| s.doc())
218                    .max()
219                    .unwrap_or(TERMINATED);
220
221                if max_doc == TERMINATED {
222                    return TERMINATED;
223                }
224
225                loop {
226                    let mut all_match = true;
227                    for scorer in &mut self.must {
228                        let doc = scorer.seek(max_doc);
229                        if doc == TERMINATED {
230                            return TERMINATED;
231                        }
232                        if doc > max_doc {
233                            max_doc = doc;
234                            all_match = false;
235                            break;
236                        }
237                    }
238                    if all_match {
239                        break;
240                    }
241                }
242                max_doc
243            } else {
244                self.should
245                    .iter()
246                    .map(|s| s.doc())
247                    .filter(|&d| d != TERMINATED)
248                    .min()
249                    .unwrap_or(TERMINATED)
250            };
251
252            if candidate == TERMINATED {
253                return TERMINATED;
254            }
255
256            let excluded = self.must_not.iter_mut().any(|scorer| {
257                let doc = scorer.seek(candidate);
258                doc == candidate
259            });
260
261            if !excluded {
262                self.current_doc = candidate;
263                return candidate;
264            }
265
266            // Advance past excluded candidate
267            if !self.must.is_empty() {
268                for scorer in &mut self.must {
269                    scorer.advance();
270                }
271            } else {
272                // For SHOULD-only: seek all scorers past the excluded candidate
273                for scorer in &mut self.should {
274                    if scorer.doc() <= candidate && scorer.doc() != TERMINATED {
275                        scorer.seek(candidate + 1);
276                    }
277                }
278            }
279        }
280    }
281}
282
283impl Scorer for BooleanScorer<'_> {
284    fn doc(&self) -> DocId {
285        self.current_doc
286    }
287
288    fn score(&self) -> Score {
289        let mut total = 0.0;
290
291        for scorer in &self.must {
292            if scorer.doc() == self.current_doc {
293                total += scorer.score();
294            }
295        }
296
297        for scorer in &self.should {
298            if scorer.doc() == self.current_doc {
299                total += scorer.score();
300            }
301        }
302
303        total
304    }
305
306    fn advance(&mut self) -> DocId {
307        if !self.must.is_empty() {
308            for scorer in &mut self.must {
309                scorer.advance();
310            }
311        } else {
312            for scorer in &mut self.should {
313                if scorer.doc() == self.current_doc {
314                    scorer.advance();
315                }
316            }
317        }
318
319        self.find_next_match()
320    }
321
322    fn seek(&mut self, target: DocId) -> DocId {
323        for scorer in &mut self.must {
324            scorer.seek(target);
325        }
326
327        for scorer in &mut self.should {
328            scorer.seek(target);
329        }
330
331        self.find_next_match()
332    }
333
334    fn size_hint(&self) -> u32 {
335        if !self.must.is_empty() {
336            self.must.iter().map(|s| s.size_hint()).min().unwrap_or(0)
337        } else {
338            self.should.iter().map(|s| s.size_hint()).sum()
339        }
340    }
341}
342
343/// Scorer that iterates over pre-computed top-k results
344struct TopKResultScorer {
345    results: Vec<ScoredDoc>,
346    position: usize,
347}
348
349impl TopKResultScorer {
350    fn new(results: Vec<ScoredDoc>) -> Self {
351        Self {
352            results,
353            position: 0,
354        }
355    }
356}
357
358impl Scorer for TopKResultScorer {
359    fn doc(&self) -> DocId {
360        if self.position < self.results.len() {
361            self.results[self.position].doc_id
362        } else {
363            TERMINATED
364        }
365    }
366
367    fn score(&self) -> Score {
368        if self.position < self.results.len() {
369            self.results[self.position].score
370        } else {
371            0.0
372        }
373    }
374
375    fn advance(&mut self) -> DocId {
376        self.position += 1;
377        self.doc()
378    }
379
380    fn seek(&mut self, target: DocId) -> DocId {
381        while self.position < self.results.len() && self.results[self.position].doc_id < target {
382            self.position += 1;
383        }
384        self.doc()
385    }
386
387    fn size_hint(&self) -> u32 {
388        self.results.len() as u32
389    }
390}
391
392/// Empty scorer for when no terms match
393struct EmptyScorer;
394
395impl Scorer for EmptyScorer {
396    fn doc(&self) -> DocId {
397        TERMINATED
398    }
399
400    fn score(&self) -> Score {
401        0.0
402    }
403
404    fn advance(&mut self) -> DocId {
405        TERMINATED
406    }
407
408    fn seek(&mut self, _target: DocId) -> DocId {
409        TERMINATED
410    }
411
412    fn size_hint(&self) -> u32 {
413        0
414    }
415}
416
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::Field;
    use crate::query::TermQuery;

    #[test]
    fn test_maxscore_eligible_pure_or_same_field() {
        // Three SHOULD term queries on one field: the MaxScore preconditions hold.
        let q = BooleanQuery::new()
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(0), "world"))
            .should(TermQuery::text(Field(0), "foo"));

        let per_clause: Vec<_> = q
            .should
            .iter()
            .map(|clause| clause.as_term_query_info())
            .collect();

        // Every clause exposes term info ...
        assert!(per_clause.iter().all(Option::is_some));
        assert_eq!(per_clause.iter().flatten().count(), 3);
        // ... and all of them target the same field.
        assert!(per_clause.iter().flatten().all(|info| info.field == Field(0)));
    }

    #[test]
    fn test_maxscore_not_eligible_different_fields() {
        // Terms on two different fields disqualify the MaxScore path.
        let q = BooleanQuery::new()
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(1), "world"));

        let fields: Vec<_> = q
            .should
            .iter()
            .filter_map(|clause| clause.as_term_query_info())
            .map(|info| info.field)
            .collect();

        assert_eq!(fields.len(), 2);
        assert!(fields[0] != fields[1]);
    }

    #[test]
    fn test_maxscore_not_eligible_with_must() {
        // A MUST clause rules out the pure-OR MaxScore path.
        let q = BooleanQuery::new()
            .must(TermQuery::text(Field(0), "required"))
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(0), "world"));

        assert!(!q.must.is_empty());
    }

    #[test]
    fn test_maxscore_not_eligible_with_must_not() {
        // A MUST_NOT clause rules out the pure-OR MaxScore path.
        let q = BooleanQuery::new()
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(0), "world"))
            .must_not(TermQuery::text(Field(0), "excluded"));

        assert!(!q.must_not.is_empty());
    }

    #[test]
    fn test_maxscore_not_eligible_single_term() {
        // A lone SHOULD clause gains nothing from MaxScore.
        let q = BooleanQuery::new().should(TermQuery::text(Field(0), "hello"));

        assert_eq!(q.should.len(), 1);
    }

    #[test]
    fn test_term_query_info_extraction() {
        let tq = TermQuery::text(Field(42), "test");
        let extracted = tq.as_term_query_info();

        assert!(extracted.is_some());
        let extracted = extracted.unwrap();
        assert_eq!(extracted.field, Field(42));
        assert_eq!(extracted.term, b"test");
    }

    #[test]
    fn test_boolean_query_no_term_info() {
        // A composite query is never treated as a single term.
        let q = BooleanQuery::new().should(TermQuery::text(Field(0), "hello"));

        assert!(q.as_term_query_info().is_none());
    }
}
517}