hermes_core/query/
boolean.rs1use crate::segment::SegmentReader;
4use crate::structures::TERMINATED;
5use crate::{DocId, Score};
6
7use super::{CountFuture, Query, ScoredDoc, Scorer, ScorerFuture, TextTermScorer, WandExecutor};
8
9#[derive(Default)]
11pub struct BooleanQuery {
12 pub must: Vec<Box<dyn Query>>,
13 pub should: Vec<Box<dyn Query>>,
14 pub must_not: Vec<Box<dyn Query>>,
15}
16
17impl std::fmt::Debug for BooleanQuery {
18 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19 f.debug_struct("BooleanQuery")
20 .field("must_count", &self.must.len())
21 .field("should_count", &self.should.len())
22 .field("must_not_count", &self.must_not.len())
23 .finish()
24 }
25}
26
27impl BooleanQuery {
28 pub fn new() -> Self {
29 Self::default()
30 }
31
32 pub fn must(mut self, query: impl Query + 'static) -> Self {
33 self.must.push(Box::new(query));
34 self
35 }
36
37 pub fn should(mut self, query: impl Query + 'static) -> Self {
38 self.should.push(Box::new(query));
39 self
40 }
41
42 pub fn must_not(mut self, query: impl Query + 'static) -> Self {
43 self.must_not.push(Box::new(query));
44 self
45 }
46
47 async fn try_wand_scorer<'a>(
52 &'a self,
53 reader: &'a SegmentReader,
54 limit: usize,
55 ) -> crate::Result<Option<Box<dyn Scorer + 'a>>> {
56 let mut term_infos: Vec<_> = self
58 .should
59 .iter()
60 .filter_map(|q| q.as_term_query_info())
61 .collect();
62
63 if term_infos.len() != self.should.len() {
65 return Ok(None);
66 }
67
68 let first_field = term_infos[0].field;
70 if !term_infos.iter().all(|t| t.field == first_field) {
71 return Ok(None);
72 }
73
74 let mut scorers: Vec<TextTermScorer> = Vec::with_capacity(term_infos.len());
76 let avg_field_len = reader.avg_field_len(first_field);
77 let num_docs = reader.num_docs() as f32;
78
79 for info in term_infos.drain(..) {
80 if let Some(posting_list) = reader.get_postings(info.field, &info.term).await? {
81 let doc_freq = posting_list.doc_count() as f32;
82 let idf = super::bm25_idf(doc_freq, num_docs);
83 scorers.push(TextTermScorer::new(posting_list, idf, avg_field_len));
84 }
85 }
86
87 if scorers.is_empty() {
88 return Ok(Some(Box::new(EmptyWandScorer) as Box<dyn Scorer + 'a>));
89 }
90
91 let results = WandExecutor::new(scorers, limit).execute();
93 Ok(Some(
94 Box::new(WandResultScorer::new(results)) as Box<dyn Scorer + 'a>
95 ))
96 }
97}
98
99impl Query for BooleanQuery {
100 fn scorer<'a>(&'a self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
101 Box::pin(async move {
102 if self.must.is_empty()
105 && self.must_not.is_empty()
106 && self.should.len() >= 2
107 && let Some(scorer) = self.try_wand_scorer(reader, limit).await?
108 {
109 return Ok(scorer);
110 }
111
112 let mut must_scorers = Vec::with_capacity(self.must.len());
114 for q in &self.must {
115 must_scorers.push(q.scorer(reader, limit).await?);
116 }
117
118 let mut should_scorers = Vec::with_capacity(self.should.len());
119 for q in &self.should {
120 should_scorers.push(q.scorer(reader, limit).await?);
121 }
122
123 let mut must_not_scorers = Vec::with_capacity(self.must_not.len());
124 for q in &self.must_not {
125 must_not_scorers.push(q.scorer(reader, limit).await?);
126 }
127
128 let mut scorer = BooleanScorer {
129 must: must_scorers,
130 should: should_scorers,
131 must_not: must_not_scorers,
132 current_doc: 0,
133 };
134 scorer.current_doc = scorer.find_next_match();
136 Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
137 })
138 }
139
140 fn count_estimate<'a>(&'a self, reader: &'a SegmentReader) -> CountFuture<'a> {
141 Box::pin(async move {
142 if !self.must.is_empty() {
143 let mut estimates = Vec::with_capacity(self.must.len());
144 for q in &self.must {
145 estimates.push(q.count_estimate(reader).await?);
146 }
147 estimates
148 .into_iter()
149 .min()
150 .ok_or_else(|| crate::Error::Corruption("Empty must clause".to_string()))
151 } else if !self.should.is_empty() {
152 let mut sum = 0u32;
153 for q in &self.should {
154 sum += q.count_estimate(reader).await?;
155 }
156 Ok(sum)
157 } else {
158 Ok(0)
159 }
160 })
161 }
162}
163
164struct BooleanScorer<'a> {
165 must: Vec<Box<dyn Scorer + 'a>>,
166 should: Vec<Box<dyn Scorer + 'a>>,
167 must_not: Vec<Box<dyn Scorer + 'a>>,
168 current_doc: DocId,
169}
170
171impl BooleanScorer<'_> {
172 fn find_next_match(&mut self) -> DocId {
173 if self.must.is_empty() && self.should.is_empty() {
174 return TERMINATED;
175 }
176
177 loop {
178 let candidate = if !self.must.is_empty() {
179 let mut max_doc = self
180 .must
181 .iter()
182 .map(|s| s.doc())
183 .max()
184 .unwrap_or(TERMINATED);
185
186 if max_doc == TERMINATED {
187 return TERMINATED;
188 }
189
190 loop {
191 let mut all_match = true;
192 for scorer in &mut self.must {
193 let doc = scorer.seek(max_doc);
194 if doc == TERMINATED {
195 return TERMINATED;
196 }
197 if doc > max_doc {
198 max_doc = doc;
199 all_match = false;
200 break;
201 }
202 }
203 if all_match {
204 break;
205 }
206 }
207 max_doc
208 } else {
209 self.should
210 .iter()
211 .map(|s| s.doc())
212 .filter(|&d| d != TERMINATED)
213 .min()
214 .unwrap_or(TERMINATED)
215 };
216
217 if candidate == TERMINATED {
218 return TERMINATED;
219 }
220
221 let excluded = self.must_not.iter_mut().any(|scorer| {
222 let doc = scorer.seek(candidate);
223 doc == candidate
224 });
225
226 if !excluded {
227 self.current_doc = candidate;
228 return candidate;
229 }
230
231 if !self.must.is_empty() {
233 for scorer in &mut self.must {
234 scorer.advance();
235 }
236 } else {
237 for scorer in &mut self.should {
239 if scorer.doc() <= candidate && scorer.doc() != TERMINATED {
240 scorer.seek(candidate + 1);
241 }
242 }
243 }
244 }
245 }
246}
247
248impl Scorer for BooleanScorer<'_> {
249 fn doc(&self) -> DocId {
250 self.current_doc
251 }
252
253 fn score(&self) -> Score {
254 let mut total = 0.0;
255
256 for scorer in &self.must {
257 if scorer.doc() == self.current_doc {
258 total += scorer.score();
259 }
260 }
261
262 for scorer in &self.should {
263 if scorer.doc() == self.current_doc {
264 total += scorer.score();
265 }
266 }
267
268 total
269 }
270
271 fn advance(&mut self) -> DocId {
272 if !self.must.is_empty() {
273 for scorer in &mut self.must {
274 scorer.advance();
275 }
276 } else {
277 for scorer in &mut self.should {
278 if scorer.doc() == self.current_doc {
279 scorer.advance();
280 }
281 }
282 }
283
284 self.find_next_match()
285 }
286
287 fn seek(&mut self, target: DocId) -> DocId {
288 for scorer in &mut self.must {
289 scorer.seek(target);
290 }
291
292 for scorer in &mut self.should {
293 scorer.seek(target);
294 }
295
296 self.find_next_match()
297 }
298
299 fn size_hint(&self) -> u32 {
300 if !self.must.is_empty() {
301 self.must.iter().map(|s| s.size_hint()).min().unwrap_or(0)
302 } else {
303 self.should.iter().map(|s| s.size_hint()).sum()
304 }
305 }
306}
307
308struct WandResultScorer {
310 results: Vec<ScoredDoc>,
311 position: usize,
312}
313
314impl WandResultScorer {
315 fn new(results: Vec<ScoredDoc>) -> Self {
316 Self {
317 results,
318 position: 0,
319 }
320 }
321}
322
323impl Scorer for WandResultScorer {
324 fn doc(&self) -> DocId {
325 if self.position < self.results.len() {
326 self.results[self.position].doc_id
327 } else {
328 TERMINATED
329 }
330 }
331
332 fn score(&self) -> Score {
333 if self.position < self.results.len() {
334 self.results[self.position].score
335 } else {
336 0.0
337 }
338 }
339
340 fn advance(&mut self) -> DocId {
341 self.position += 1;
342 self.doc()
343 }
344
345 fn seek(&mut self, target: DocId) -> DocId {
346 while self.position < self.results.len() && self.results[self.position].doc_id < target {
347 self.position += 1;
348 }
349 self.doc()
350 }
351
352 fn size_hint(&self) -> u32 {
353 self.results.len() as u32
354 }
355}
356
357struct EmptyWandScorer;
359
360impl Scorer for EmptyWandScorer {
361 fn doc(&self) -> DocId {
362 TERMINATED
363 }
364
365 fn score(&self) -> Score {
366 0.0
367 }
368
369 fn advance(&mut self) -> DocId {
370 TERMINATED
371 }
372
373 fn seek(&mut self, _target: DocId) -> DocId {
374 TERMINATED
375 }
376
377 fn size_hint(&self) -> u32 {
378 0
379 }
380}
381
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::Field;
    use crate::query::TermQuery;

    /// Fields of the `should` clauses that expose term-query info, in clause
    /// order.
    fn should_fields(query: &BooleanQuery) -> Vec<Field> {
        query
            .should
            .iter()
            .filter_map(|q| q.as_term_query_info())
            .map(|info| info.field)
            .collect()
    }

    #[test]
    fn test_wand_eligible_pure_or_same_field() {
        let query = BooleanQuery::new()
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(0), "world"))
            .should(TermQuery::text(Field(0), "foo"));

        // Every clause is a plain term query ...
        let every_clause_is_term = query
            .should
            .iter()
            .all(|q| q.as_term_query_info().is_some());
        assert!(every_clause_is_term);

        // ... and all of them target the same field.
        let fields = should_fields(&query);
        assert_eq!(fields.len(), 3);
        assert!(fields.iter().all(|f| *f == Field(0)));
    }

    #[test]
    fn test_wand_not_eligible_different_fields() {
        let query = BooleanQuery::new()
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(1), "world"));

        let fields = should_fields(&query);
        assert_eq!(fields.len(), 2);
        assert!(fields[0] != fields[1]);
    }

    #[test]
    fn test_wand_not_eligible_with_must() {
        let query = BooleanQuery::new()
            .must(TermQuery::text(Field(0), "required"))
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(0), "world"));

        // A required clause disqualifies the pure-OR WAND path.
        assert!(!query.must.is_empty());
    }

    #[test]
    fn test_wand_not_eligible_with_must_not() {
        let query = BooleanQuery::new()
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(0), "world"))
            .must_not(TermQuery::text(Field(0), "excluded"));

        // An exclusion clause disqualifies the pure-OR WAND path.
        assert!(!query.must_not.is_empty());
    }

    #[test]
    fn test_wand_not_eligible_single_term() {
        let query = BooleanQuery::new().should(TermQuery::text(Field(0), "hello"));

        // WAND is only attempted for two or more `should` clauses.
        assert_eq!(query.should.len(), 1);
    }

    #[test]
    fn test_term_query_info_extraction() {
        let term_query = TermQuery::text(Field(42), "test");
        let info = term_query
            .as_term_query_info()
            .expect("a TermQuery should expose its term info");

        assert_eq!(info.field, Field(42));
        assert_eq!(info.term, b"test");
    }

    #[test]
    fn test_boolean_query_no_term_info() {
        let query = BooleanQuery::new().should(TermQuery::text(Field(0), "hello"));

        // A boolean query is never itself treated as a single term query.
        assert!(query.as_term_query_info().is_none());
    }
}