1use std::sync::Arc;
4
5use crate::segment::SegmentReader;
6use crate::structures::TERMINATED;
7use crate::{DocId, Score};
8
9use super::{
10 CountFuture, GlobalStats, MaxScoreExecutor, Query, ScoredDoc, Scorer, ScorerFuture,
11 SparseTermQueryInfo,
12};
13
14#[derive(Default, Clone)]
19pub struct BooleanQuery {
20 pub must: Vec<Arc<dyn Query>>,
21 pub should: Vec<Arc<dyn Query>>,
22 pub must_not: Vec<Arc<dyn Query>>,
23 global_stats: Option<Arc<GlobalStats>>,
25}
26
27impl std::fmt::Debug for BooleanQuery {
28 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 f.debug_struct("BooleanQuery")
30 .field("must_count", &self.must.len())
31 .field("should_count", &self.should.len())
32 .field("must_not_count", &self.must_not.len())
33 .field("has_global_stats", &self.global_stats.is_some())
34 .finish()
35 }
36}
37
38impl std::fmt::Display for BooleanQuery {
39 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40 write!(f, "Boolean(")?;
41 let mut first = true;
42 for q in &self.must {
43 if !first {
44 write!(f, " ")?;
45 }
46 write!(f, "+{}", q)?;
47 first = false;
48 }
49 for q in &self.should {
50 if !first {
51 write!(f, " ")?;
52 }
53 write!(f, "{}", q)?;
54 first = false;
55 }
56 for q in &self.must_not {
57 if !first {
58 write!(f, " ")?;
59 }
60 write!(f, "-{}", q)?;
61 first = false;
62 }
63 write!(f, ")")
64 }
65}
66
67impl BooleanQuery {
68 pub fn new() -> Self {
69 Self::default()
70 }
71
72 pub fn must(mut self, query: impl Query + 'static) -> Self {
73 self.must.push(Arc::new(query));
74 self
75 }
76
77 pub fn should(mut self, query: impl Query + 'static) -> Self {
78 self.should.push(Arc::new(query));
79 self
80 }
81
82 pub fn must_not(mut self, query: impl Query + 'static) -> Self {
83 self.must_not.push(Arc::new(query));
84 self
85 }
86
87 pub fn with_global_stats(mut self, stats: Arc<GlobalStats>) -> Self {
89 self.global_stats = Some(stats);
90 self
91 }
92}
93
94fn compute_idf(
96 posting_list: &crate::structures::BlockPostingList,
97 field: crate::Field,
98 term: &[u8],
99 num_docs: f32,
100 global_stats: Option<&Arc<GlobalStats>>,
101) -> f32 {
102 if let Some(stats) = global_stats {
103 let global_idf = stats.text_idf(field, &String::from_utf8_lossy(term));
104 if global_idf > 0.0 {
105 return global_idf;
106 }
107 }
108 let doc_freq = posting_list.doc_count() as f32;
109 super::bm25_idf(doc_freq, num_docs)
110}
111
112fn prepare_text_maxscore(
115 should: &[Arc<dyn Query>],
116 reader: &SegmentReader,
117 global_stats: Option<&Arc<GlobalStats>>,
118) -> Option<(Vec<super::TermQueryInfo>, crate::Field, f32, f32)> {
119 let infos: Vec<_> = should
120 .iter()
121 .filter_map(|q| q.as_term_query_info())
122 .collect();
123 if infos.len() != should.len() {
124 return None;
125 }
126 let field = infos[0].field;
127 if !infos.iter().all(|t| t.field == field) {
128 return None;
129 }
130 let avg_field_len = global_stats
131 .map(|s| s.avg_field_len(field))
132 .unwrap_or_else(|| reader.avg_field_len(field));
133 let num_docs = reader.num_docs() as f32;
134 Some((infos, field, avg_field_len, num_docs))
135}
136
137fn finish_text_maxscore<'a>(
139 posting_lists: Vec<(crate::structures::BlockPostingList, f32)>,
140 avg_field_len: f32,
141 limit: usize,
142) -> crate::Result<Box<dyn Scorer + 'a>> {
143 if posting_lists.is_empty() {
144 return Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>);
145 }
146 let results = MaxScoreExecutor::text(posting_lists, avg_field_len, limit).execute_sync()?;
147 Ok(Box::new(TopKResultScorer::new(results)) as Box<dyn Scorer + 'a>)
148}
149
150async fn try_maxscore_scorer<'a>(
152 should: &[Arc<dyn Query>],
153 reader: &'a SegmentReader,
154 limit: usize,
155 global_stats: Option<&Arc<GlobalStats>>,
156) -> crate::Result<Option<Box<dyn Scorer + 'a>>> {
157 let (mut infos, _field, avg_field_len, num_docs) =
158 match prepare_text_maxscore(should, reader, global_stats) {
159 Some(v) => v,
160 None => return Ok(None),
161 };
162 let mut posting_lists = Vec::with_capacity(infos.len());
163 for info in infos.drain(..) {
164 if let Some(pl) = reader.get_postings(info.field, &info.term).await? {
165 let idf = compute_idf(&pl, info.field, &info.term, num_docs, global_stats);
166 posting_lists.push((pl, idf));
167 }
168 }
169 Ok(Some(finish_text_maxscore(
170 posting_lists,
171 avg_field_len,
172 limit,
173 )?))
174}
175
/// Blocking counterpart of `try_maxscore_scorer`: same eligibility rules and result,
/// but postings are loaded with `get_postings_sync`.
#[cfg(feature = "sync")]
fn try_maxscore_scorer_sync<'a>(
    should: &[Arc<dyn Query>],
    reader: &'a SegmentReader,
    limit: usize,
    global_stats: Option<&Arc<GlobalStats>>,
) -> crate::Result<Option<Box<dyn Scorer + 'a>>> {
    let prepared = prepare_text_maxscore(should, reader, global_stats);
    let (infos, _field, avg_field_len, num_docs) = match prepared {
        None => return Ok(None),
        Some(parts) => parts,
    };

    let mut weighted_postings = Vec::with_capacity(infos.len());
    for info in infos {
        let maybe_postings = reader.get_postings_sync(info.field, &info.term)?;
        if let Some(postings) = maybe_postings {
            let idf = compute_idf(&postings, info.field, &info.term, num_docs, global_stats);
            weighted_postings.push((postings, idf));
        }
    }

    let scorer = finish_text_maxscore(weighted_postings, avg_field_len, limit)?;
    Ok(Some(scorer))
}
202
203fn prepare_sparse_maxscore<'a>(
206 should: &[Arc<dyn Query>],
207 reader: &'a SegmentReader,
208 limit: usize,
209) -> Option<Result<MaxScoreExecutor<'a>, Box<dyn Scorer + 'a>>> {
210 let infos: Vec<SparseTermQueryInfo> = should
211 .iter()
212 .filter_map(|q| q.as_sparse_term_query_info())
213 .collect();
214 if infos.len() != should.len() {
215 return None;
216 }
217 let field = infos[0].field;
218 if !infos.iter().all(|t| t.field == field) {
219 return None;
220 }
221 let si = match reader.sparse_index(field) {
222 Some(si) => si,
223 None => return Some(Err(Box::new(EmptyScorer))),
224 };
225 let query_terms: Vec<(u32, f32)> = infos
226 .iter()
227 .filter(|info| si.has_dimension(info.dim_id))
228 .map(|info| (info.dim_id, info.weight))
229 .collect();
230 if query_terms.is_empty() {
231 return Some(Err(Box::new(EmptyScorer)));
232 }
233 let executor_limit = (limit as f32 * infos[0].over_fetch_factor).ceil() as usize;
234 Some(Ok(MaxScoreExecutor::sparse(
235 si,
236 query_terms,
237 executor_limit,
238 infos[0].heap_factor,
239 )))
240}
241
242fn combine_sparse_results<'a>(
244 raw: Vec<ScoredDoc>,
245 combiner: super::MultiValueCombiner,
246 limit: usize,
247) -> Box<dyn Scorer + 'a> {
248 let combined = crate::segment::combine_ordinal_results(
249 raw.into_iter().map(|r| (r.doc_id, r.ordinal, r.score)),
250 combiner,
251 limit,
252 );
253 let scored: Vec<ScoredDoc> = combined
254 .into_iter()
255 .map(|r| ScoredDoc {
256 doc_id: r.doc_id,
257 score: r.score,
258 ordinal: 0,
259 })
260 .collect();
261 Box::new(TopKResultScorer::new(scored))
262}
263
264async fn try_sparse_maxscore_scorer<'a>(
266 should: &[Arc<dyn Query>],
267 reader: &'a SegmentReader,
268 limit: usize,
269) -> crate::Result<Option<Box<dyn Scorer + 'a>>> {
270 let executor = match prepare_sparse_maxscore(should, reader, limit) {
271 None => return Ok(None),
272 Some(Err(empty)) => return Ok(Some(empty)),
273 Some(Ok(e)) => e,
274 };
275 let combiner = should[0].as_sparse_term_query_info().unwrap().combiner;
276 let raw = executor.execute().await?;
277 Ok(Some(combine_sparse_results(raw, combiner, limit)))
278}
279
/// Blocking counterpart of `try_sparse_maxscore_scorer`; runs the executor with
/// `execute_sync`.
#[cfg(feature = "sync")]
fn try_sparse_maxscore_scorer_sync<'a>(
    should: &[Arc<dyn Query>],
    reader: &'a SegmentReader,
    limit: usize,
) -> crate::Result<Option<Box<dyn Scorer + 'a>>> {
    let executor = match prepare_sparse_maxscore(should, reader, limit) {
        Some(Ok(executor)) => executor,
        Some(Err(empty)) => return Ok(Some(empty)),
        None => return Ok(None),
    };
    // An executor is only produced when every clause is a sparse term query, so the
    // first clause always carries sparse info here.
    let combiner = should[0]
        .as_sparse_term_query_info()
        .map(|info| info.combiner)
        .unwrap();
    let raw_hits = executor.execute_sync()?;
    Ok(Some(combine_sparse_results(raw_hits, combiner, limit)))
}
296
impl Query for BooleanQuery {
    /// Builds the scorer for this query on one segment.
    ///
    /// Strategy, in order:
    /// 1. With no `must_not` clauses and exactly one positive clause (one `must` and no
    ///    `should`, or one `should` and no `must`), delegate to that clause directly.
    /// 2. For a pure disjunction (only `should`, at least two clauses), try the text
    ///    MaxScore path, then the sparse MaxScore path; each yields `None` when the
    ///    clauses are not eligible.
    /// 3. Otherwise build a scorer per clause and combine them in a `BooleanScorer`.
    ///
    /// NOTE(review): every sub-scorer receives the same `limit`. If a sub-query returns
    /// only its own top-`limit` documents (as the MaxScore paths in this file do), a
    /// conjunction/exclusion built on top of it may miss documents — confirm intended.
    fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
        // The returned future is tied only to `'a` (the reader), not to `&self`, so the
        // clause lists are cloned (cheap `Arc` clones) and moved into it.
        let must = self.must.clone();
        let should = self.should.clone();
        let must_not = self.must_not.clone();
        let global_stats = self.global_stats.clone();

        Box::pin(async move {
            // Single positive clause and nothing to exclude: no combining needed.
            if must_not.is_empty() {
                if must.len() == 1 && should.is_empty() {
                    return must[0].scorer(reader, limit).await;
                }
                if should.len() == 1 && must.is_empty() {
                    return should[0].scorer(reader, limit).await;
                }
            }

            // Pure OR of two or more clauses: try the top-k MaxScore fast paths.
            if must.is_empty() && must_not.is_empty() && should.len() >= 2 {
                if let Some(scorer) =
                    try_maxscore_scorer(&should, reader, limit, global_stats.as_ref()).await?
                {
                    return Ok(scorer);
                }
                if let Some(scorer) = try_sparse_maxscore_scorer(&should, reader, limit).await? {
                    return Ok(scorer);
                }
            }

            // General case: one sub-scorer per clause, built sequentially.
            let mut must_scorers = Vec::with_capacity(must.len());
            for q in &must {
                must_scorers.push(q.scorer(reader, limit).await?);
            }

            let mut should_scorers = Vec::with_capacity(should.len());
            for q in &should {
                should_scorers.push(q.scorer(reader, limit).await?);
            }

            let mut must_not_scorers = Vec::with_capacity(must_not.len());
            for q in &must_not {
                must_not_scorers.push(q.scorer(reader, limit).await?);
            }

            let mut scorer = BooleanScorer {
                must: must_scorers,
                should: should_scorers,
                must_not: must_not_scorers,
                current_doc: 0,
            };
            // Position on the first match so `doc()` is valid immediately.
            scorer.current_doc = scorer.find_next_match();
            Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
        })
    }

    /// Blocking counterpart of `scorer`, following the same three-step strategy.
    #[cfg(feature = "sync")]
    fn scorer_sync<'a>(
        &self,
        reader: &'a SegmentReader,
        limit: usize,
    ) -> crate::Result<Box<dyn Scorer + 'a>> {
        // Single positive clause and nothing to exclude: no combining needed.
        if self.must_not.is_empty() {
            if self.must.len() == 1 && self.should.is_empty() {
                return self.must[0].scorer_sync(reader, limit);
            }
            if self.should.len() == 1 && self.must.is_empty() {
                return self.should[0].scorer_sync(reader, limit);
            }
        }

        // Pure OR of two or more clauses: try the top-k MaxScore fast paths.
        if self.must.is_empty() && self.must_not.is_empty() && self.should.len() >= 2 {
            if let Some(scorer) =
                try_maxscore_scorer_sync(&self.should, reader, limit, self.global_stats.as_ref())?
            {
                return Ok(scorer);
            }
            if let Some(scorer) = try_sparse_maxscore_scorer_sync(&self.should, reader, limit)? {
                return Ok(scorer);
            }
        }

        // General case: one sub-scorer per clause.
        let mut must_scorers = Vec::with_capacity(self.must.len());
        for q in &self.must {
            must_scorers.push(q.scorer_sync(reader, limit)?);
        }

        let mut should_scorers = Vec::with_capacity(self.should.len());
        for q in &self.should {
            should_scorers.push(q.scorer_sync(reader, limit)?);
        }

        let mut must_not_scorers = Vec::with_capacity(self.must_not.len());
        for q in &self.must_not {
            must_not_scorers.push(q.scorer_sync(reader, limit)?);
        }

        let mut scorer = BooleanScorer {
            must: must_scorers,
            should: should_scorers,
            must_not: must_not_scorers,
            current_doc: 0,
        };
        // Position on the first match so `doc()` is valid immediately.
        scorer.current_doc = scorer.find_next_match();
        Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
    }

    /// Estimates the number of matching documents in `reader`.
    ///
    /// With `must` clauses: the smallest estimate among them. Otherwise the saturating
    /// sum of the `should` estimates, or 0 when there are no positive clauses.
    /// `must_not` clauses are not taken into account.
    fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
        let must = self.must.clone();
        let should = self.should.clone();

        Box::pin(async move {
            if !must.is_empty() {
                let mut estimates = Vec::with_capacity(must.len());
                for q in &must {
                    estimates.push(q.count_estimate(reader).await?);
                }
                // `must` is non-empty here, so `min()` always yields a value.
                estimates
                    .into_iter()
                    .min()
                    .ok_or_else(|| crate::Error::Corruption("Empty must clause".to_string()))
            } else if !should.is_empty() {
                let mut sum = 0u32;
                for q in &should {
                    sum = sum.saturating_add(q.count_estimate(reader).await?);
                }
                Ok(sum)
            } else {
                Ok(0)
            }
        })
    }
}
439
/// General-purpose scorer combining `must`, `should` and `must_not` sub-scorers.
///
/// `current_doc` holds the document the scorer is positioned on; `TERMINATED` means
/// the scorer is exhausted.
struct BooleanScorer<'a> {
    // All of these must be positioned on a document for it to match.
    must: Vec<Box<dyn Scorer + 'a>>,
    // Optional clauses: they add to the score when positioned on the current document,
    // and drive candidate selection when `must` is empty.
    should: Vec<Box<dyn Scorer + 'a>>,
    // A candidate any of these lands on is rejected.
    must_not: Vec<Box<dyn Scorer + 'a>>,
    current_doc: DocId,
}

impl BooleanScorer<'_> {
    /// Moves sub-scorers forward to the next matching document at or after their
    /// current positions, stores it in `current_doc`, and returns it (or `TERMINATED`).
    ///
    /// * With `must` clauses, the candidate is found by leapfrogging all `must` scorers
    ///   until they agree on one document.
    /// * Without, the candidate is the smallest non-terminated `should` document.
    ///
    /// A candidate is rejected when any `must_not` scorer lands exactly on it. Before an
    /// accepted candidate is returned, every `should` scorer is sought to it so `score()`
    /// can see which of them match.
    fn find_next_match(&mut self) -> DocId {
        // No positive clauses (e.g. a purely negative query): nothing can match.
        if self.must.is_empty() && self.should.is_empty() {
            return TERMINATED;
        }

        loop {
            let candidate = if !self.must.is_empty() {
                // Start from the furthest-ahead must scorer (assumes sub-scorers only
                // move forward, so no earlier document can be shared by all of them).
                let mut max_doc = self
                    .must
                    .iter()
                    .map(|s| s.doc())
                    .max()
                    .unwrap_or(TERMINATED);

                // Some must scorer is exhausted, so the conjunction is too.
                if max_doc == TERMINATED {
                    return TERMINATED;
                }

                // Seek each must scorer to `max_doc`; if one overshoots, raise the target
                // and start over. Exits once every must scorer sits on `max_doc`.
                loop {
                    let mut all_match = true;
                    for scorer in &mut self.must {
                        let doc = scorer.seek(max_doc);
                        if doc == TERMINATED {
                            return TERMINATED;
                        }
                        if doc > max_doc {
                            max_doc = doc;
                            all_match = false;
                            break;
                        }
                    }
                    if all_match {
                        break;
                    }
                }
                max_doc
            } else {
                // Disjunction: the earliest document any should scorer is on.
                self.should
                    .iter()
                    .map(|s| s.doc())
                    .filter(|&d| d != TERMINATED)
                    .min()
                    .unwrap_or(TERMINATED)
            };

            if candidate == TERMINATED {
                return TERMINATED;
            }

            // `any` short-circuits, so later must_not scorers are only moved forward on
            // a subsequent candidate.
            let excluded = self.must_not.iter_mut().any(|scorer| {
                let doc = scorer.seek(candidate);
                doc == candidate
            });

            if !excluded {
                // Line up the optional clauses so `score()` can include those that match.
                for scorer in &mut self.should {
                    scorer.seek(candidate);
                }
                self.current_doc = candidate;
                return candidate;
            }

            // Rejected candidate: step past it and try again.
            if !self.must.is_empty() {
                // Every must scorer sits on `candidate`, so advancing each moves past it.
                for scorer in &mut self.must {
                    scorer.advance();
                }
            } else {
                // Move any live should scorer at or before the candidate beyond it.
                for scorer in &mut self.should {
                    if scorer.doc() <= candidate && scorer.doc() != TERMINATED {
                        scorer.seek(candidate + 1);
                    }
                }
            }
        }
    }
}
527
528impl super::docset::DocSet for BooleanScorer<'_> {
529 fn doc(&self) -> DocId {
530 self.current_doc
531 }
532
533 fn advance(&mut self) -> DocId {
534 if !self.must.is_empty() {
535 for scorer in &mut self.must {
536 scorer.advance();
537 }
538 } else {
539 for scorer in &mut self.should {
540 if scorer.doc() == self.current_doc {
541 scorer.advance();
542 }
543 }
544 }
545
546 self.current_doc = self.find_next_match();
547 self.current_doc
548 }
549
550 fn seek(&mut self, target: DocId) -> DocId {
551 for scorer in &mut self.must {
552 scorer.seek(target);
553 }
554
555 for scorer in &mut self.should {
556 scorer.seek(target);
557 }
558
559 self.current_doc = self.find_next_match();
560 self.current_doc
561 }
562
563 fn size_hint(&self) -> u32 {
564 if !self.must.is_empty() {
565 self.must.iter().map(|s| s.size_hint()).min().unwrap_or(0)
566 } else {
567 self.should.iter().map(|s| s.size_hint()).sum()
568 }
569 }
570}
571
572impl Scorer for BooleanScorer<'_> {
573 fn score(&self) -> Score {
574 let mut total = 0.0;
575
576 for scorer in &self.must {
577 if scorer.doc() == self.current_doc {
578 total += scorer.score();
579 }
580 }
581
582 for scorer in &self.should {
583 if scorer.doc() == self.current_doc {
584 total += scorer.score();
585 }
586 }
587
588 total
589 }
590}
591
592struct TopKResultScorer {
594 results: Vec<ScoredDoc>,
595 position: usize,
596}
597
598impl TopKResultScorer {
599 fn new(mut results: Vec<ScoredDoc>) -> Self {
600 results.sort_unstable_by_key(|r| r.doc_id);
602 Self {
603 results,
604 position: 0,
605 }
606 }
607}
608
609impl super::docset::DocSet for TopKResultScorer {
610 fn doc(&self) -> DocId {
611 if self.position < self.results.len() {
612 self.results[self.position].doc_id
613 } else {
614 TERMINATED
615 }
616 }
617
618 fn advance(&mut self) -> DocId {
619 self.position += 1;
620 self.doc()
621 }
622
623 fn seek(&mut self, target: DocId) -> DocId {
624 let remaining = &self.results[self.position..];
625 self.position += remaining.partition_point(|r| r.doc_id < target);
626 self.doc()
627 }
628
629 fn size_hint(&self) -> u32 {
630 (self.results.len() - self.position) as u32
631 }
632}
633
634impl Scorer for TopKResultScorer {
635 fn score(&self) -> Score {
636 if self.position < self.results.len() {
637 self.results[self.position].score
638 } else {
639 0.0
640 }
641 }
642}
643
644struct EmptyScorer;
646
647impl super::docset::DocSet for EmptyScorer {
648 fn doc(&self) -> DocId {
649 TERMINATED
650 }
651
652 fn advance(&mut self) -> DocId {
653 TERMINATED
654 }
655
656 fn seek(&mut self, _target: DocId) -> DocId {
657 TERMINATED
658 }
659
660 fn size_hint(&self) -> u32 {
661 0
662 }
663}
664
665impl Scorer for EmptyScorer {
666 fn score(&self) -> Score {
667 0.0
668 }
669}
670
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::Field;
    use crate::query::TermQuery;

    #[test]
    fn test_maxscore_eligible_pure_or_same_field() {
        let query = BooleanQuery::new()
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(0), "world"))
            .should(TermQuery::text(Field(0), "foo"));

        assert!(
            query
                .should
                .iter()
                .all(|q| q.as_term_query_info().is_some())
        );

        let infos: Vec<_> = query
            .should
            .iter()
            .filter_map(|q| q.as_term_query_info())
            .collect();
        assert_eq!(infos.len(), 3);
        assert!(infos.iter().all(|i| i.field == Field(0)));
    }

    #[test]
    fn test_maxscore_not_eligible_different_fields() {
        let query = BooleanQuery::new()
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(1), "world"));

        let infos: Vec<_> = query
            .should
            .iter()
            .filter_map(|q| q.as_term_query_info())
            .collect();
        assert_eq!(infos.len(), 2);
        assert!(infos[0].field != infos[1].field);
    }

    #[test]
    fn test_maxscore_not_eligible_with_must() {
        let query = BooleanQuery::new()
            .must(TermQuery::text(Field(0), "required"))
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(0), "world"));

        assert!(!query.must.is_empty());
    }

    #[test]
    fn test_maxscore_not_eligible_with_must_not() {
        let query = BooleanQuery::new()
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(0), "world"))
            .must_not(TermQuery::text(Field(0), "excluded"));

        assert!(!query.must_not.is_empty());
    }

    #[test]
    fn test_maxscore_not_eligible_single_term() {
        let query = BooleanQuery::new().should(TermQuery::text(Field(0), "hello"));

        assert_eq!(query.should.len(), 1);
    }

    #[test]
    fn test_term_query_info_extraction() {
        let term_query = TermQuery::text(Field(42), "test");
        let info = term_query.as_term_query_info();

        assert!(info.is_some());
        let info = info.unwrap();
        assert_eq!(info.field, Field(42));
        assert_eq!(info.term, b"test");
    }

    #[test]
    fn test_boolean_query_no_term_info() {
        let query = BooleanQuery::new().should(TermQuery::text(Field(0), "hello"));

        assert!(query.as_term_query_info().is_none());
    }

    #[test]
    fn test_debug_reports_clause_counts() {
        let query = BooleanQuery::new()
            .must(TermQuery::text(Field(0), "a"))
            .should(TermQuery::text(Field(0), "b"))
            .should(TermQuery::text(Field(0), "c"));

        assert_eq!(
            format!("{:?}", query),
            "BooleanQuery { must_count: 1, should_count: 2, must_not_count: 0, has_global_stats: false }"
        );
    }

    #[test]
    fn test_display_empty_query() {
        assert_eq!(BooleanQuery::new().to_string(), "Boolean()");
    }
}