hermes_core/query/
boolean.rs1use std::sync::Arc;
4
5use crate::segment::SegmentReader;
6use crate::structures::TERMINATED;
7use crate::{DocId, Score};
8
9use super::{CountFuture, Query, ScoredDoc, Scorer, ScorerFuture, TextTermScorer, WandExecutor};
10
11#[derive(Default, Clone)]
13pub struct BooleanQuery {
14 pub must: Vec<Arc<dyn Query>>,
15 pub should: Vec<Arc<dyn Query>>,
16 pub must_not: Vec<Arc<dyn Query>>,
17}
18
19impl std::fmt::Debug for BooleanQuery {
20 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21 f.debug_struct("BooleanQuery")
22 .field("must_count", &self.must.len())
23 .field("should_count", &self.should.len())
24 .field("must_not_count", &self.must_not.len())
25 .finish()
26 }
27}
28
29impl BooleanQuery {
30 pub fn new() -> Self {
31 Self::default()
32 }
33
34 pub fn must(mut self, query: impl Query + 'static) -> Self {
35 self.must.push(Arc::new(query));
36 self
37 }
38
39 pub fn should(mut self, query: impl Query + 'static) -> Self {
40 self.should.push(Arc::new(query));
41 self
42 }
43
44 pub fn must_not(mut self, query: impl Query + 'static) -> Self {
45 self.must_not.push(Arc::new(query));
46 self
47 }
48}
49
50async fn try_wand_scorer<'a>(
52 should: &[Arc<dyn Query>],
53 reader: &'a SegmentReader,
54 limit: usize,
55) -> crate::Result<Option<Box<dyn Scorer + 'a>>> {
56 let mut term_infos: Vec<_> = should
58 .iter()
59 .filter_map(|q| q.as_term_query_info())
60 .collect();
61
62 if term_infos.len() != should.len() {
64 return Ok(None);
65 }
66
67 let first_field = term_infos[0].field;
69 if !term_infos.iter().all(|t| t.field == first_field) {
70 return Ok(None);
71 }
72
73 let mut scorers: Vec<TextTermScorer> = Vec::with_capacity(term_infos.len());
75 let avg_field_len = reader.avg_field_len(first_field);
76 let num_docs = reader.num_docs() as f32;
77
78 for info in term_infos.drain(..) {
79 if let Some(posting_list) = reader.get_postings(info.field, &info.term).await? {
80 let doc_freq = posting_list.doc_count() as f32;
81 let idf = super::bm25_idf(doc_freq, num_docs);
82 scorers.push(TextTermScorer::new(posting_list, idf, avg_field_len));
83 }
84 }
85
86 if scorers.is_empty() {
87 return Ok(Some(Box::new(EmptyWandScorer) as Box<dyn Scorer + 'a>));
88 }
89
90 let results = WandExecutor::new(scorers, limit).execute();
92 Ok(Some(
93 Box::new(WandResultScorer::new(results)) as Box<dyn Scorer + 'a>
94 ))
95}
96
97impl Query for BooleanQuery {
98 fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
99 let must = self.must.clone();
101 let should = self.should.clone();
102 let must_not = self.must_not.clone();
103
104 Box::pin(async move {
105 if must.is_empty()
108 && must_not.is_empty()
109 && should.len() >= 2
110 && let Some(scorer) = try_wand_scorer(&should, reader, limit).await?
111 {
112 return Ok(scorer);
113 }
114
115 let mut must_scorers = Vec::with_capacity(must.len());
117 for q in &must {
118 must_scorers.push(q.scorer(reader, limit).await?);
119 }
120
121 let mut should_scorers = Vec::with_capacity(should.len());
122 for q in &should {
123 should_scorers.push(q.scorer(reader, limit).await?);
124 }
125
126 let mut must_not_scorers = Vec::with_capacity(must_not.len());
127 for q in &must_not {
128 must_not_scorers.push(q.scorer(reader, limit).await?);
129 }
130
131 let mut scorer = BooleanScorer {
132 must: must_scorers,
133 should: should_scorers,
134 must_not: must_not_scorers,
135 current_doc: 0,
136 };
137 scorer.current_doc = scorer.find_next_match();
139 Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
140 })
141 }
142
143 fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
144 let must = self.must.clone();
145 let should = self.should.clone();
146
147 Box::pin(async move {
148 if !must.is_empty() {
149 let mut estimates = Vec::with_capacity(must.len());
150 for q in &must {
151 estimates.push(q.count_estimate(reader).await?);
152 }
153 estimates
154 .into_iter()
155 .min()
156 .ok_or_else(|| crate::Error::Corruption("Empty must clause".to_string()))
157 } else if !should.is_empty() {
158 let mut sum = 0u32;
159 for q in &should {
160 sum += q.count_estimate(reader).await?;
161 }
162 Ok(sum)
163 } else {
164 Ok(0)
165 }
166 })
167 }
168}
169
170struct BooleanScorer<'a> {
171 must: Vec<Box<dyn Scorer + 'a>>,
172 should: Vec<Box<dyn Scorer + 'a>>,
173 must_not: Vec<Box<dyn Scorer + 'a>>,
174 current_doc: DocId,
175}
176
177impl BooleanScorer<'_> {
178 fn find_next_match(&mut self) -> DocId {
179 if self.must.is_empty() && self.should.is_empty() {
180 return TERMINATED;
181 }
182
183 loop {
184 let candidate = if !self.must.is_empty() {
185 let mut max_doc = self
186 .must
187 .iter()
188 .map(|s| s.doc())
189 .max()
190 .unwrap_or(TERMINATED);
191
192 if max_doc == TERMINATED {
193 return TERMINATED;
194 }
195
196 loop {
197 let mut all_match = true;
198 for scorer in &mut self.must {
199 let doc = scorer.seek(max_doc);
200 if doc == TERMINATED {
201 return TERMINATED;
202 }
203 if doc > max_doc {
204 max_doc = doc;
205 all_match = false;
206 break;
207 }
208 }
209 if all_match {
210 break;
211 }
212 }
213 max_doc
214 } else {
215 self.should
216 .iter()
217 .map(|s| s.doc())
218 .filter(|&d| d != TERMINATED)
219 .min()
220 .unwrap_or(TERMINATED)
221 };
222
223 if candidate == TERMINATED {
224 return TERMINATED;
225 }
226
227 let excluded = self.must_not.iter_mut().any(|scorer| {
228 let doc = scorer.seek(candidate);
229 doc == candidate
230 });
231
232 if !excluded {
233 self.current_doc = candidate;
234 return candidate;
235 }
236
237 if !self.must.is_empty() {
239 for scorer in &mut self.must {
240 scorer.advance();
241 }
242 } else {
243 for scorer in &mut self.should {
245 if scorer.doc() <= candidate && scorer.doc() != TERMINATED {
246 scorer.seek(candidate + 1);
247 }
248 }
249 }
250 }
251 }
252}
253
254impl Scorer for BooleanScorer<'_> {
255 fn doc(&self) -> DocId {
256 self.current_doc
257 }
258
259 fn score(&self) -> Score {
260 let mut total = 0.0;
261
262 for scorer in &self.must {
263 if scorer.doc() == self.current_doc {
264 total += scorer.score();
265 }
266 }
267
268 for scorer in &self.should {
269 if scorer.doc() == self.current_doc {
270 total += scorer.score();
271 }
272 }
273
274 total
275 }
276
277 fn advance(&mut self) -> DocId {
278 if !self.must.is_empty() {
279 for scorer in &mut self.must {
280 scorer.advance();
281 }
282 } else {
283 for scorer in &mut self.should {
284 if scorer.doc() == self.current_doc {
285 scorer.advance();
286 }
287 }
288 }
289
290 self.find_next_match()
291 }
292
293 fn seek(&mut self, target: DocId) -> DocId {
294 for scorer in &mut self.must {
295 scorer.seek(target);
296 }
297
298 for scorer in &mut self.should {
299 scorer.seek(target);
300 }
301
302 self.find_next_match()
303 }
304
305 fn size_hint(&self) -> u32 {
306 if !self.must.is_empty() {
307 self.must.iter().map(|s| s.size_hint()).min().unwrap_or(0)
308 } else {
309 self.should.iter().map(|s| s.size_hint()).sum()
310 }
311 }
312}
313
314struct WandResultScorer {
316 results: Vec<ScoredDoc>,
317 position: usize,
318}
319
320impl WandResultScorer {
321 fn new(results: Vec<ScoredDoc>) -> Self {
322 Self {
323 results,
324 position: 0,
325 }
326 }
327}
328
329impl Scorer for WandResultScorer {
330 fn doc(&self) -> DocId {
331 if self.position < self.results.len() {
332 self.results[self.position].doc_id
333 } else {
334 TERMINATED
335 }
336 }
337
338 fn score(&self) -> Score {
339 if self.position < self.results.len() {
340 self.results[self.position].score
341 } else {
342 0.0
343 }
344 }
345
346 fn advance(&mut self) -> DocId {
347 self.position += 1;
348 self.doc()
349 }
350
351 fn seek(&mut self, target: DocId) -> DocId {
352 while self.position < self.results.len() && self.results[self.position].doc_id < target {
353 self.position += 1;
354 }
355 self.doc()
356 }
357
358 fn size_hint(&self) -> u32 {
359 self.results.len() as u32
360 }
361}
362
363struct EmptyWandScorer;
365
366impl Scorer for EmptyWandScorer {
367 fn doc(&self) -> DocId {
368 TERMINATED
369 }
370
371 fn score(&self) -> Score {
372 0.0
373 }
374
375 fn advance(&mut self) -> DocId {
376 TERMINATED
377 }
378
379 fn seek(&mut self, _target: DocId) -> DocId {
380 TERMINATED
381 }
382
383 fn size_hint(&self) -> u32 {
384 0
385 }
386}
387
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::Field;
    use crate::query::TermQuery;

    /// Three same-field term clauses: every clause exposes term info on Field(0).
    #[test]
    fn test_wand_eligible_pure_or_same_field() {
        let query = BooleanQuery::new()
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(0), "world"))
            .should(TermQuery::text(Field(0), "foo"));

        let every_clause_is_term = query
            .should
            .iter()
            .all(|q| q.as_term_query_info().is_some());
        assert!(every_clause_is_term);

        let fields: Vec<_> = query
            .should
            .iter()
            .filter_map(|q| q.as_term_query_info())
            .map(|info| info.field)
            .collect();
        assert_eq!(fields.len(), 3);
        assert!(fields.iter().all(|field| *field == Field(0)));
    }

    /// Term clauses on two different fields report distinct fields.
    #[test]
    fn test_wand_not_eligible_different_fields() {
        let query = BooleanQuery::new()
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(1), "world"));

        let fields: Vec<_> = query
            .should
            .iter()
            .filter_map(|q| q.as_term_query_info())
            .map(|info| info.field)
            .collect();
        assert_eq!(fields.len(), 2);
        assert_ne!(fields[0], fields[1]);
    }

    /// A `must` clause is recorded, which rules out the pure-OR WAND path.
    #[test]
    fn test_wand_not_eligible_with_must() {
        let query = BooleanQuery::new()
            .must(TermQuery::text(Field(0), "required"))
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(0), "world"));

        assert!(!query.must.is_empty());
    }

    /// A `must_not` clause is recorded, which rules out the pure-OR WAND path.
    #[test]
    fn test_wand_not_eligible_with_must_not() {
        let query = BooleanQuery::new()
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(0), "world"))
            .must_not(TermQuery::text(Field(0), "excluded"));

        assert!(!query.must_not.is_empty());
    }

    /// A single `should` clause is below the two-clause WAND threshold.
    #[test]
    fn test_wand_not_eligible_single_term() {
        let query = BooleanQuery::new().should(TermQuery::text(Field(0), "hello"));

        assert_eq!(query.should.len(), 1);
    }

    /// A term query exposes its field and raw term bytes.
    #[test]
    fn test_term_query_info_extraction() {
        let info = TermQuery::text(Field(42), "test")
            .as_term_query_info()
            .expect("a term query must expose term info");

        assert_eq!(info.field, Field(42));
        assert_eq!(info.term, b"test");
    }

    /// A boolean query is not itself a term query.
    #[test]
    fn test_boolean_query_no_term_info() {
        let query = BooleanQuery::new().should(TermQuery::text(Field(0), "hello"));

        assert!(query.as_term_query_info().is_none());
    }
}