hermes_core/query/
boolean.rs

1//! Boolean query with MUST, SHOULD, and MUST_NOT clauses
2
3use crate::segment::SegmentReader;
4use crate::structures::TERMINATED;
5use crate::{DocId, Score};
6
7use super::{CountFuture, Query, ScoredDoc, Scorer, ScorerFuture, TextTermScorer, WandExecutor};
8
9/// Boolean query with MUST, SHOULD, and MUST_NOT clauses
10#[derive(Default)]
11pub struct BooleanQuery {
12    pub must: Vec<Box<dyn Query>>,
13    pub should: Vec<Box<dyn Query>>,
14    pub must_not: Vec<Box<dyn Query>>,
15}
16
17impl std::fmt::Debug for BooleanQuery {
18    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19        f.debug_struct("BooleanQuery")
20            .field("must_count", &self.must.len())
21            .field("should_count", &self.should.len())
22            .field("must_not_count", &self.must_not.len())
23            .finish()
24    }
25}
26
27impl BooleanQuery {
28    pub fn new() -> Self {
29        Self::default()
30    }
31
32    pub fn must(mut self, query: impl Query + 'static) -> Self {
33        self.must.push(Box::new(query));
34        self
35    }
36
37    pub fn should(mut self, query: impl Query + 'static) -> Self {
38        self.should.push(Box::new(query));
39        self
40    }
41
42    pub fn must_not(mut self, query: impl Query + 'static) -> Self {
43        self.must_not.push(Box::new(query));
44        self
45    }
46
47    /// Try to create a WAND-optimized scorer for pure OR queries
48    ///
49    /// Returns Some(scorer) if all SHOULD clauses are term queries for the same field.
50    /// Returns None if WAND optimization is not applicable.
51    async fn try_wand_scorer<'a>(
52        &'a self,
53        reader: &'a SegmentReader,
54        limit: usize,
55    ) -> crate::Result<Option<Box<dyn Scorer + 'a>>> {
56        // Extract term info from all SHOULD clauses
57        let mut term_infos: Vec<_> = self
58            .should
59            .iter()
60            .filter_map(|q| q.as_term_query_info())
61            .collect();
62
63        // Check if all clauses are term queries
64        if term_infos.len() != self.should.len() {
65            return Ok(None);
66        }
67
68        // Check if all terms are for the same field
69        let first_field = term_infos[0].field;
70        if !term_infos.iter().all(|t| t.field == first_field) {
71            return Ok(None);
72        }
73
74        // Build WAND scorers for each term
75        let mut scorers: Vec<TextTermScorer> = Vec::with_capacity(term_infos.len());
76        let avg_field_len = reader.avg_field_len(first_field);
77        let num_docs = reader.num_docs() as f32;
78
79        for info in term_infos.drain(..) {
80            if let Some(posting_list) = reader.get_postings(info.field, &info.term).await? {
81                let doc_freq = posting_list.doc_count() as f32;
82                let idf = super::bm25_idf(doc_freq, num_docs);
83                scorers.push(TextTermScorer::new(posting_list, idf, avg_field_len));
84            }
85        }
86
87        if scorers.is_empty() {
88            return Ok(Some(Box::new(EmptyWandScorer) as Box<dyn Scorer + 'a>));
89        }
90
91        // Use WAND executor for efficient top-k
92        let results = WandExecutor::new(scorers, limit).execute();
93        Ok(Some(
94            Box::new(WandResultScorer::new(results)) as Box<dyn Scorer + 'a>
95        ))
96    }
97}
98
99impl Query for BooleanQuery {
100    fn scorer<'a>(&'a self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
101        Box::pin(async move {
102            // Check if this is a pure OR query eligible for WAND optimization
103            // Conditions: no MUST, no MUST_NOT, multiple SHOULD clauses, all same field
104            if self.must.is_empty()
105                && self.must_not.is_empty()
106                && self.should.len() >= 2
107                && let Some(scorer) = self.try_wand_scorer(reader, limit).await?
108            {
109                return Ok(scorer);
110            }
111
112            // Fall back to standard boolean scoring
113            let mut must_scorers = Vec::with_capacity(self.must.len());
114            for q in &self.must {
115                must_scorers.push(q.scorer(reader, limit).await?);
116            }
117
118            let mut should_scorers = Vec::with_capacity(self.should.len());
119            for q in &self.should {
120                should_scorers.push(q.scorer(reader, limit).await?);
121            }
122
123            let mut must_not_scorers = Vec::with_capacity(self.must_not.len());
124            for q in &self.must_not {
125                must_not_scorers.push(q.scorer(reader, limit).await?);
126            }
127
128            let mut scorer = BooleanScorer {
129                must: must_scorers,
130                should: should_scorers,
131                must_not: must_not_scorers,
132                current_doc: 0,
133            };
134            // Initialize to first match
135            scorer.current_doc = scorer.find_next_match();
136            Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
137        })
138    }
139
140    fn count_estimate<'a>(&'a self, reader: &'a SegmentReader) -> CountFuture<'a> {
141        Box::pin(async move {
142            if !self.must.is_empty() {
143                let mut estimates = Vec::with_capacity(self.must.len());
144                for q in &self.must {
145                    estimates.push(q.count_estimate(reader).await?);
146                }
147                estimates
148                    .into_iter()
149                    .min()
150                    .ok_or_else(|| crate::Error::Corruption("Empty must clause".to_string()))
151            } else if !self.should.is_empty() {
152                let mut sum = 0u32;
153                for q in &self.should {
154                    sum += q.count_estimate(reader).await?;
155                }
156                Ok(sum)
157            } else {
158                Ok(0)
159            }
160        })
161    }
162}
163
164struct BooleanScorer<'a> {
165    must: Vec<Box<dyn Scorer + 'a>>,
166    should: Vec<Box<dyn Scorer + 'a>>,
167    must_not: Vec<Box<dyn Scorer + 'a>>,
168    current_doc: DocId,
169}
170
171impl BooleanScorer<'_> {
172    fn find_next_match(&mut self) -> DocId {
173        if self.must.is_empty() && self.should.is_empty() {
174            return TERMINATED;
175        }
176
177        loop {
178            let candidate = if !self.must.is_empty() {
179                let mut max_doc = self
180                    .must
181                    .iter()
182                    .map(|s| s.doc())
183                    .max()
184                    .unwrap_or(TERMINATED);
185
186                if max_doc == TERMINATED {
187                    return TERMINATED;
188                }
189
190                loop {
191                    let mut all_match = true;
192                    for scorer in &mut self.must {
193                        let doc = scorer.seek(max_doc);
194                        if doc == TERMINATED {
195                            return TERMINATED;
196                        }
197                        if doc > max_doc {
198                            max_doc = doc;
199                            all_match = false;
200                            break;
201                        }
202                    }
203                    if all_match {
204                        break;
205                    }
206                }
207                max_doc
208            } else {
209                self.should
210                    .iter()
211                    .map(|s| s.doc())
212                    .filter(|&d| d != TERMINATED)
213                    .min()
214                    .unwrap_or(TERMINATED)
215            };
216
217            if candidate == TERMINATED {
218                return TERMINATED;
219            }
220
221            let excluded = self.must_not.iter_mut().any(|scorer| {
222                let doc = scorer.seek(candidate);
223                doc == candidate
224            });
225
226            if !excluded {
227                self.current_doc = candidate;
228                return candidate;
229            }
230
231            // Advance past excluded candidate
232            if !self.must.is_empty() {
233                for scorer in &mut self.must {
234                    scorer.advance();
235                }
236            } else {
237                // For SHOULD-only: seek all scorers past the excluded candidate
238                for scorer in &mut self.should {
239                    if scorer.doc() <= candidate && scorer.doc() != TERMINATED {
240                        scorer.seek(candidate + 1);
241                    }
242                }
243            }
244        }
245    }
246}
247
248impl Scorer for BooleanScorer<'_> {
249    fn doc(&self) -> DocId {
250        self.current_doc
251    }
252
253    fn score(&self) -> Score {
254        let mut total = 0.0;
255
256        for scorer in &self.must {
257            if scorer.doc() == self.current_doc {
258                total += scorer.score();
259            }
260        }
261
262        for scorer in &self.should {
263            if scorer.doc() == self.current_doc {
264                total += scorer.score();
265            }
266        }
267
268        total
269    }
270
271    fn advance(&mut self) -> DocId {
272        if !self.must.is_empty() {
273            for scorer in &mut self.must {
274                scorer.advance();
275            }
276        } else {
277            for scorer in &mut self.should {
278                if scorer.doc() == self.current_doc {
279                    scorer.advance();
280                }
281            }
282        }
283
284        self.find_next_match()
285    }
286
287    fn seek(&mut self, target: DocId) -> DocId {
288        for scorer in &mut self.must {
289            scorer.seek(target);
290        }
291
292        for scorer in &mut self.should {
293            scorer.seek(target);
294        }
295
296        self.find_next_match()
297    }
298
299    fn size_hint(&self) -> u32 {
300        if !self.must.is_empty() {
301            self.must.iter().map(|s| s.size_hint()).min().unwrap_or(0)
302        } else {
303            self.should.iter().map(|s| s.size_hint()).sum()
304        }
305    }
306}
307
308/// Scorer that iterates over pre-computed WAND results
309struct WandResultScorer {
310    results: Vec<ScoredDoc>,
311    position: usize,
312}
313
314impl WandResultScorer {
315    fn new(results: Vec<ScoredDoc>) -> Self {
316        Self {
317            results,
318            position: 0,
319        }
320    }
321}
322
323impl Scorer for WandResultScorer {
324    fn doc(&self) -> DocId {
325        if self.position < self.results.len() {
326            self.results[self.position].doc_id
327        } else {
328            TERMINATED
329        }
330    }
331
332    fn score(&self) -> Score {
333        if self.position < self.results.len() {
334            self.results[self.position].score
335        } else {
336            0.0
337        }
338    }
339
340    fn advance(&mut self) -> DocId {
341        self.position += 1;
342        self.doc()
343    }
344
345    fn seek(&mut self, target: DocId) -> DocId {
346        while self.position < self.results.len() && self.results[self.position].doc_id < target {
347            self.position += 1;
348        }
349        self.doc()
350    }
351
352    fn size_hint(&self) -> u32 {
353        self.results.len() as u32
354    }
355}
356
357/// Empty scorer for when no terms match
358struct EmptyWandScorer;
359
360impl Scorer for EmptyWandScorer {
361    fn doc(&self) -> DocId {
362        TERMINATED
363    }
364
365    fn score(&self) -> Score {
366        0.0
367    }
368
369    fn advance(&mut self) -> DocId {
370        TERMINATED
371    }
372
373    fn seek(&mut self, _target: DocId) -> DocId {
374        TERMINATED
375    }
376
377    fn size_hint(&self) -> u32 {
378        0
379    }
380}
381
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::Field;
    use crate::query::TermQuery;

    #[test]
    fn test_wand_eligible_pure_or_same_field() {
        // Three SHOULD term clauses on one field: exactly the shape WAND accepts.
        let query = BooleanQuery::new()
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(0), "world"))
            .should(TermQuery::text(Field(0), "foo"));

        let infos: Vec<_> = query
            .should
            .iter()
            .map(|clause| clause.as_term_query_info())
            .collect();

        // Every clause exposes term info...
        assert!(infos.iter().all(|info| info.is_some()));
        assert_eq!(infos.iter().flatten().count(), 3);
        // ...and every term targets the same field.
        assert!(infos.iter().flatten().all(|info| info.field == Field(0)));
    }

    #[test]
    fn test_wand_not_eligible_different_fields() {
        // Terms spread over two fields disqualify the WAND path.
        let query = BooleanQuery::new()
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(1), "world"));

        let fields: Vec<_> = query
            .should
            .iter()
            .filter_map(|clause| clause.as_term_query_info())
            .map(|info| info.field)
            .collect();

        assert_eq!(fields.len(), 2);
        assert!(fields[0] != fields[1]);
    }

    #[test]
    fn test_wand_not_eligible_with_must() {
        // A MUST clause turns the query into a conjunction, so WAND is skipped.
        let query = BooleanQuery::new()
            .must(TermQuery::text(Field(0), "required"))
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(0), "world"));

        assert_eq!(query.must.len(), 1);
    }

    #[test]
    fn test_wand_not_eligible_with_must_not() {
        // An exclusion clause also rules out the WAND path.
        let query = BooleanQuery::new()
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(0), "world"))
            .must_not(TermQuery::text(Field(0), "excluded"));

        assert_eq!(query.must_not.len(), 1);
    }

    #[test]
    fn test_wand_not_eligible_single_term() {
        // WAND needs at least two SHOULD clauses to be worthwhile.
        let query = BooleanQuery::new().should(TermQuery::text(Field(0), "hello"));

        assert_eq!(query.should.len(), 1);
    }

    #[test]
    fn test_term_query_info_extraction() {
        let term_query = TermQuery::text(Field(42), "test");
        let Some(info) = term_query.as_term_query_info() else {
            panic!("a TermQuery must expose its term info");
        };

        assert_eq!(info.field, Field(42));
        assert_eq!(info.term, b"test");
    }

    #[test]
    fn test_boolean_query_no_term_info() {
        // A BooleanQuery is a composite, never a single term.
        let query = BooleanQuery::new().should(TermQuery::text(Field(0), "hello"));

        assert!(query.as_term_query_info().is_none());
    }
}
482}