hermes_core/query/
boolean.rs

1//! Boolean query with MUST, SHOULD, and MUST_NOT clauses
2
3use crate::segment::SegmentReader;
4use crate::structures::TERMINATED;
5use crate::{DocId, Score};
6
7use super::{CountFuture, Query, Scorer, ScorerFuture};
8
9/// Boolean query with MUST, SHOULD, and MUST_NOT clauses
10#[derive(Default)]
11pub struct BooleanQuery {
12    pub must: Vec<Box<dyn Query>>,
13    pub should: Vec<Box<dyn Query>>,
14    pub must_not: Vec<Box<dyn Query>>,
15}
16
17impl std::fmt::Debug for BooleanQuery {
18    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19        f.debug_struct("BooleanQuery")
20            .field("must_count", &self.must.len())
21            .field("should_count", &self.should.len())
22            .field("must_not_count", &self.must_not.len())
23            .finish()
24    }
25}
26
27impl BooleanQuery {
28    pub fn new() -> Self {
29        Self::default()
30    }
31
32    pub fn must(mut self, query: impl Query + 'static) -> Self {
33        self.must.push(Box::new(query));
34        self
35    }
36
37    pub fn should(mut self, query: impl Query + 'static) -> Self {
38        self.should.push(Box::new(query));
39        self
40    }
41
42    pub fn must_not(mut self, query: impl Query + 'static) -> Self {
43        self.must_not.push(Box::new(query));
44        self
45    }
46}
47
48impl Query for BooleanQuery {
49    fn scorer<'a>(&'a self, reader: &'a SegmentReader) -> ScorerFuture<'a> {
50        Box::pin(async move {
51            let mut must_scorers = Vec::with_capacity(self.must.len());
52            for q in &self.must {
53                must_scorers.push(q.scorer(reader).await?);
54            }
55
56            let mut should_scorers = Vec::with_capacity(self.should.len());
57            for q in &self.should {
58                should_scorers.push(q.scorer(reader).await?);
59            }
60
61            let mut must_not_scorers = Vec::with_capacity(self.must_not.len());
62            for q in &self.must_not {
63                must_not_scorers.push(q.scorer(reader).await?);
64            }
65
66            let mut scorer = BooleanScorer {
67                must: must_scorers,
68                should: should_scorers,
69                must_not: must_not_scorers,
70                current_doc: 0,
71            };
72            // Initialize to first match
73            scorer.current_doc = scorer.find_next_match();
74            Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
75        })
76    }
77
78    fn count_estimate<'a>(&'a self, reader: &'a SegmentReader) -> CountFuture<'a> {
79        Box::pin(async move {
80            if !self.must.is_empty() {
81                let mut estimates = Vec::with_capacity(self.must.len());
82                for q in &self.must {
83                    estimates.push(q.count_estimate(reader).await?);
84                }
85                estimates
86                    .into_iter()
87                    .min()
88                    .ok_or_else(|| crate::Error::Corruption("Empty must clause".to_string()))
89            } else if !self.should.is_empty() {
90                let mut sum = 0u32;
91                for q in &self.should {
92                    sum += q.count_estimate(reader).await?;
93                }
94                Ok(sum)
95            } else {
96                Ok(0)
97            }
98        })
99    }
100}
101
102struct BooleanScorer<'a> {
103    must: Vec<Box<dyn Scorer + 'a>>,
104    should: Vec<Box<dyn Scorer + 'a>>,
105    must_not: Vec<Box<dyn Scorer + 'a>>,
106    current_doc: DocId,
107}
108
109impl BooleanScorer<'_> {
110    fn find_next_match(&mut self) -> DocId {
111        if self.must.is_empty() && self.should.is_empty() {
112            return TERMINATED;
113        }
114
115        loop {
116            let candidate = if !self.must.is_empty() {
117                let mut max_doc = self
118                    .must
119                    .iter()
120                    .map(|s| s.doc())
121                    .max()
122                    .unwrap_or(TERMINATED);
123
124                if max_doc == TERMINATED {
125                    return TERMINATED;
126                }
127
128                loop {
129                    let mut all_match = true;
130                    for scorer in &mut self.must {
131                        let doc = scorer.seek(max_doc);
132                        if doc == TERMINATED {
133                            return TERMINATED;
134                        }
135                        if doc > max_doc {
136                            max_doc = doc;
137                            all_match = false;
138                            break;
139                        }
140                    }
141                    if all_match {
142                        break;
143                    }
144                }
145                max_doc
146            } else {
147                self.should
148                    .iter()
149                    .map(|s| s.doc())
150                    .filter(|&d| d != TERMINATED)
151                    .min()
152                    .unwrap_or(TERMINATED)
153            };
154
155            if candidate == TERMINATED {
156                return TERMINATED;
157            }
158
159            let excluded = self.must_not.iter_mut().any(|scorer| {
160                let doc = scorer.seek(candidate);
161                doc == candidate
162            });
163
164            if !excluded {
165                self.current_doc = candidate;
166                return candidate;
167            }
168
169            // Advance past excluded candidate
170            if !self.must.is_empty() {
171                for scorer in &mut self.must {
172                    scorer.advance();
173                }
174            } else {
175                // For SHOULD-only: seek all scorers past the excluded candidate
176                for scorer in &mut self.should {
177                    if scorer.doc() <= candidate && scorer.doc() != TERMINATED {
178                        scorer.seek(candidate + 1);
179                    }
180                }
181            }
182        }
183    }
184}
185
186impl Scorer for BooleanScorer<'_> {
187    fn doc(&self) -> DocId {
188        self.current_doc
189    }
190
191    fn score(&self) -> Score {
192        let mut total = 0.0;
193
194        for scorer in &self.must {
195            if scorer.doc() == self.current_doc {
196                total += scorer.score();
197            }
198        }
199
200        for scorer in &self.should {
201            if scorer.doc() == self.current_doc {
202                total += scorer.score();
203            }
204        }
205
206        total
207    }
208
209    fn advance(&mut self) -> DocId {
210        if !self.must.is_empty() {
211            for scorer in &mut self.must {
212                scorer.advance();
213            }
214        } else {
215            for scorer in &mut self.should {
216                if scorer.doc() == self.current_doc {
217                    scorer.advance();
218                }
219            }
220        }
221
222        self.find_next_match()
223    }
224
225    fn seek(&mut self, target: DocId) -> DocId {
226        for scorer in &mut self.must {
227            scorer.seek(target);
228        }
229
230        for scorer in &mut self.should {
231            scorer.seek(target);
232        }
233
234        self.find_next_match()
235    }
236
237    fn size_hint(&self) -> u32 {
238        if !self.must.is_empty() {
239            self.must.iter().map(|s| s.size_hint()).min().unwrap_or(0)
240        } else {
241            self.should.iter().map(|s| s.size_hint()).sum()
242        }
243    }
244}