// hermes_core/query/boolean.rs
//! Boolean query with MUST, SHOULD, and MUST_NOT clauses

3use std::sync::Arc;
4
5use crate::segment::SegmentReader;
6use crate::structures::TERMINATED;
7use crate::{DocId, Score};
8
9use super::{CountFuture, Query, ScoredDoc, Scorer, ScorerFuture, TextTermScorer, WandExecutor};
10
11/// Boolean query with MUST, SHOULD, and MUST_NOT clauses
12#[derive(Default, Clone)]
13pub struct BooleanQuery {
14    pub must: Vec<Arc<dyn Query>>,
15    pub should: Vec<Arc<dyn Query>>,
16    pub must_not: Vec<Arc<dyn Query>>,
17}
18
19impl std::fmt::Debug for BooleanQuery {
20    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21        f.debug_struct("BooleanQuery")
22            .field("must_count", &self.must.len())
23            .field("should_count", &self.should.len())
24            .field("must_not_count", &self.must_not.len())
25            .finish()
26    }
27}
28
29impl BooleanQuery {
30    pub fn new() -> Self {
31        Self::default()
32    }
33
34    pub fn must(mut self, query: impl Query + 'static) -> Self {
35        self.must.push(Arc::new(query));
36        self
37    }
38
39    pub fn should(mut self, query: impl Query + 'static) -> Self {
40        self.should.push(Arc::new(query));
41        self
42    }
43
44    pub fn must_not(mut self, query: impl Query + 'static) -> Self {
45        self.must_not.push(Arc::new(query));
46        self
47    }
48}
49
50/// Try to create a WAND-optimized scorer for pure OR queries
51async fn try_wand_scorer<'a>(
52    should: &[Arc<dyn Query>],
53    reader: &'a SegmentReader,
54    limit: usize,
55) -> crate::Result<Option<Box<dyn Scorer + 'a>>> {
56    // Extract term info from all SHOULD clauses
57    let mut term_infos: Vec<_> = should
58        .iter()
59        .filter_map(|q| q.as_term_query_info())
60        .collect();
61
62    // Check if all clauses are term queries
63    if term_infos.len() != should.len() {
64        return Ok(None);
65    }
66
67    // Check if all terms are for the same field
68    let first_field = term_infos[0].field;
69    if !term_infos.iter().all(|t| t.field == first_field) {
70        return Ok(None);
71    }
72
73    // Build WAND scorers for each term
74    let mut scorers: Vec<TextTermScorer> = Vec::with_capacity(term_infos.len());
75    let avg_field_len = reader.avg_field_len(first_field);
76    let num_docs = reader.num_docs() as f32;
77
78    for info in term_infos.drain(..) {
79        if let Some(posting_list) = reader.get_postings(info.field, &info.term).await? {
80            let doc_freq = posting_list.doc_count() as f32;
81            let idf = super::bm25_idf(doc_freq, num_docs);
82            scorers.push(TextTermScorer::new(posting_list, idf, avg_field_len));
83        }
84    }
85
86    if scorers.is_empty() {
87        return Ok(Some(Box::new(EmptyWandScorer) as Box<dyn Scorer + 'a>));
88    }
89
90    // Use WAND executor for efficient top-k
91    let results = WandExecutor::new(scorers, limit).execute();
92    Ok(Some(
93        Box::new(WandResultScorer::new(results)) as Box<dyn Scorer + 'a>
94    ))
95}
96
97impl Query for BooleanQuery {
98    fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
99        // Clone Arc vectors - cheap reference counting
100        let must = self.must.clone();
101        let should = self.should.clone();
102        let must_not = self.must_not.clone();
103
104        Box::pin(async move {
105            // Check if this is a pure OR query eligible for WAND optimization
106            // Conditions: no MUST, no MUST_NOT, multiple SHOULD clauses, all same field
107            if must.is_empty()
108                && must_not.is_empty()
109                && should.len() >= 2
110                && let Some(scorer) = try_wand_scorer(&should, reader, limit).await?
111            {
112                return Ok(scorer);
113            }
114
115            // Fall back to standard boolean scoring
116            let mut must_scorers = Vec::with_capacity(must.len());
117            for q in &must {
118                must_scorers.push(q.scorer(reader, limit).await?);
119            }
120
121            let mut should_scorers = Vec::with_capacity(should.len());
122            for q in &should {
123                should_scorers.push(q.scorer(reader, limit).await?);
124            }
125
126            let mut must_not_scorers = Vec::with_capacity(must_not.len());
127            for q in &must_not {
128                must_not_scorers.push(q.scorer(reader, limit).await?);
129            }
130
131            let mut scorer = BooleanScorer {
132                must: must_scorers,
133                should: should_scorers,
134                must_not: must_not_scorers,
135                current_doc: 0,
136            };
137            // Initialize to first match
138            scorer.current_doc = scorer.find_next_match();
139            Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
140        })
141    }
142
143    fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
144        let must = self.must.clone();
145        let should = self.should.clone();
146
147        Box::pin(async move {
148            if !must.is_empty() {
149                let mut estimates = Vec::with_capacity(must.len());
150                for q in &must {
151                    estimates.push(q.count_estimate(reader).await?);
152                }
153                estimates
154                    .into_iter()
155                    .min()
156                    .ok_or_else(|| crate::Error::Corruption("Empty must clause".to_string()))
157            } else if !should.is_empty() {
158                let mut sum = 0u32;
159                for q in &should {
160                    sum += q.count_estimate(reader).await?;
161                }
162                Ok(sum)
163            } else {
164                Ok(0)
165            }
166        })
167    }
168}
169
170struct BooleanScorer<'a> {
171    must: Vec<Box<dyn Scorer + 'a>>,
172    should: Vec<Box<dyn Scorer + 'a>>,
173    must_not: Vec<Box<dyn Scorer + 'a>>,
174    current_doc: DocId,
175}
176
177impl BooleanScorer<'_> {
178    fn find_next_match(&mut self) -> DocId {
179        if self.must.is_empty() && self.should.is_empty() {
180            return TERMINATED;
181        }
182
183        loop {
184            let candidate = if !self.must.is_empty() {
185                let mut max_doc = self
186                    .must
187                    .iter()
188                    .map(|s| s.doc())
189                    .max()
190                    .unwrap_or(TERMINATED);
191
192                if max_doc == TERMINATED {
193                    return TERMINATED;
194                }
195
196                loop {
197                    let mut all_match = true;
198                    for scorer in &mut self.must {
199                        let doc = scorer.seek(max_doc);
200                        if doc == TERMINATED {
201                            return TERMINATED;
202                        }
203                        if doc > max_doc {
204                            max_doc = doc;
205                            all_match = false;
206                            break;
207                        }
208                    }
209                    if all_match {
210                        break;
211                    }
212                }
213                max_doc
214            } else {
215                self.should
216                    .iter()
217                    .map(|s| s.doc())
218                    .filter(|&d| d != TERMINATED)
219                    .min()
220                    .unwrap_or(TERMINATED)
221            };
222
223            if candidate == TERMINATED {
224                return TERMINATED;
225            }
226
227            let excluded = self.must_not.iter_mut().any(|scorer| {
228                let doc = scorer.seek(candidate);
229                doc == candidate
230            });
231
232            if !excluded {
233                self.current_doc = candidate;
234                return candidate;
235            }
236
237            // Advance past excluded candidate
238            if !self.must.is_empty() {
239                for scorer in &mut self.must {
240                    scorer.advance();
241                }
242            } else {
243                // For SHOULD-only: seek all scorers past the excluded candidate
244                for scorer in &mut self.should {
245                    if scorer.doc() <= candidate && scorer.doc() != TERMINATED {
246                        scorer.seek(candidate + 1);
247                    }
248                }
249            }
250        }
251    }
252}
253
254impl Scorer for BooleanScorer<'_> {
255    fn doc(&self) -> DocId {
256        self.current_doc
257    }
258
259    fn score(&self) -> Score {
260        let mut total = 0.0;
261
262        for scorer in &self.must {
263            if scorer.doc() == self.current_doc {
264                total += scorer.score();
265            }
266        }
267
268        for scorer in &self.should {
269            if scorer.doc() == self.current_doc {
270                total += scorer.score();
271            }
272        }
273
274        total
275    }
276
277    fn advance(&mut self) -> DocId {
278        if !self.must.is_empty() {
279            for scorer in &mut self.must {
280                scorer.advance();
281            }
282        } else {
283            for scorer in &mut self.should {
284                if scorer.doc() == self.current_doc {
285                    scorer.advance();
286                }
287            }
288        }
289
290        self.find_next_match()
291    }
292
293    fn seek(&mut self, target: DocId) -> DocId {
294        for scorer in &mut self.must {
295            scorer.seek(target);
296        }
297
298        for scorer in &mut self.should {
299            scorer.seek(target);
300        }
301
302        self.find_next_match()
303    }
304
305    fn size_hint(&self) -> u32 {
306        if !self.must.is_empty() {
307            self.must.iter().map(|s| s.size_hint()).min().unwrap_or(0)
308        } else {
309            self.should.iter().map(|s| s.size_hint()).sum()
310        }
311    }
312}
313
314/// Scorer that iterates over pre-computed WAND results
315struct WandResultScorer {
316    results: Vec<ScoredDoc>,
317    position: usize,
318}
319
320impl WandResultScorer {
321    fn new(results: Vec<ScoredDoc>) -> Self {
322        Self {
323            results,
324            position: 0,
325        }
326    }
327}
328
329impl Scorer for WandResultScorer {
330    fn doc(&self) -> DocId {
331        if self.position < self.results.len() {
332            self.results[self.position].doc_id
333        } else {
334            TERMINATED
335        }
336    }
337
338    fn score(&self) -> Score {
339        if self.position < self.results.len() {
340            self.results[self.position].score
341        } else {
342            0.0
343        }
344    }
345
346    fn advance(&mut self) -> DocId {
347        self.position += 1;
348        self.doc()
349    }
350
351    fn seek(&mut self, target: DocId) -> DocId {
352        while self.position < self.results.len() && self.results[self.position].doc_id < target {
353            self.position += 1;
354        }
355        self.doc()
356    }
357
358    fn size_hint(&self) -> u32 {
359        self.results.len() as u32
360    }
361}
362
363/// Empty scorer for when no terms match
364struct EmptyWandScorer;
365
366impl Scorer for EmptyWandScorer {
367    fn doc(&self) -> DocId {
368        TERMINATED
369    }
370
371    fn score(&self) -> Score {
372        0.0
373    }
374
375    fn advance(&mut self) -> DocId {
376        TERMINATED
377    }
378
379    fn seek(&mut self, _target: DocId) -> DocId {
380        TERMINATED
381    }
382
383    fn size_hint(&self) -> u32 {
384        0
385    }
386}
387
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::Field;
    use crate::query::TermQuery;

    /// Fields of the SHOULD clauses that expose term-query info, in clause order.
    fn should_term_fields(query: &BooleanQuery) -> Vec<Field> {
        query
            .should
            .iter()
            .filter_map(|q| q.as_term_query_info())
            .map(|info| info.field)
            .collect()
    }

    #[test]
    fn test_wand_eligible_pure_or_same_field() {
        // A pure OR over a single field is the shape the WAND path accepts.
        let query = ["hello", "world", "foo"]
            .into_iter()
            .fold(BooleanQuery::new(), |q, term| {
                q.should(TermQuery::text(Field(0), term))
            });

        // Every clause exposes term info...
        for clause in &query.should {
            assert!(clause.as_term_query_info().is_some());
        }

        // ...and all of them target the same field.
        let fields = should_term_fields(&query);
        assert_eq!(fields.len(), 3);
        assert!(fields.iter().all(|f| *f == Field(0)));
    }

    #[test]
    fn test_wand_not_eligible_different_fields() {
        // Terms spread over two fields must not take the WAND path.
        let query = BooleanQuery::new()
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(1), "world"));

        let fields = should_term_fields(&query);
        assert_eq!(fields.len(), 2);
        assert_ne!(fields[0], fields[1]);
    }

    #[test]
    fn test_wand_not_eligible_with_must() {
        // A MUST clause disqualifies the query from the WAND optimization.
        let query = BooleanQuery::new()
            .must(TermQuery::text(Field(0), "required"))
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(0), "world"));

        assert_eq!(query.must.len(), 1);
    }

    #[test]
    fn test_wand_not_eligible_with_must_not() {
        // A MUST_NOT clause disqualifies the query from the WAND optimization.
        let query = BooleanQuery::new()
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(0), "world"))
            .must_not(TermQuery::text(Field(0), "excluded"));

        assert_eq!(query.must_not.len(), 1);
    }

    #[test]
    fn test_wand_not_eligible_single_term() {
        // With a single SHOULD clause WAND brings no benefit.
        let query = BooleanQuery::new().should(TermQuery::text(Field(0), "hello"));

        assert_eq!(query.should.len(), 1);
    }

    #[test]
    fn test_term_query_info_extraction() {
        let term_query = TermQuery::text(Field(42), "test");
        let Some(info) = term_query.as_term_query_info() else {
            panic!("a term query should expose its term info");
        };

        assert_eq!(info.field, Field(42));
        assert_eq!(info.term, b"test");
    }

    #[test]
    fn test_boolean_query_no_term_info() {
        // A BooleanQuery is not itself a term query, even with a single clause.
        let query = BooleanQuery::new().should(TermQuery::text(Field(0), "hello"));

        assert!(query.as_term_query_info().is_none());
    }
}
488}