1use std::sync::Arc;
4
5use crate::segment::SegmentReader;
6use crate::structures::TERMINATED;
7use crate::{DocId, Score};
8
9use super::{CountFuture, GlobalStats, MaxScoreExecutor, Query, ScoredDoc, Scorer, ScorerFuture};
10
11#[derive(Default, Clone)]
16pub struct BooleanQuery {
17 pub must: Vec<Arc<dyn Query>>,
18 pub should: Vec<Arc<dyn Query>>,
19 pub must_not: Vec<Arc<dyn Query>>,
20 global_stats: Option<Arc<GlobalStats>>,
22}
23
24impl std::fmt::Debug for BooleanQuery {
25 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26 f.debug_struct("BooleanQuery")
27 .field("must_count", &self.must.len())
28 .field("should_count", &self.should.len())
29 .field("must_not_count", &self.must_not.len())
30 .field("has_global_stats", &self.global_stats.is_some())
31 .finish()
32 }
33}
34
35impl BooleanQuery {
36 pub fn new() -> Self {
37 Self::default()
38 }
39
40 pub fn must(mut self, query: impl Query + 'static) -> Self {
41 self.must.push(Arc::new(query));
42 self
43 }
44
45 pub fn should(mut self, query: impl Query + 'static) -> Self {
46 self.should.push(Arc::new(query));
47 self
48 }
49
50 pub fn must_not(mut self, query: impl Query + 'static) -> Self {
51 self.must_not.push(Arc::new(query));
52 self
53 }
54
55 pub fn with_global_stats(mut self, stats: Arc<GlobalStats>) -> Self {
57 self.global_stats = Some(stats);
58 self
59 }
60}
61
62fn maxscore_eligible(
65 should: &[Arc<dyn Query>],
66) -> Option<(Vec<super::TermQueryInfo>, crate::Field)> {
67 let term_infos: Vec<_> = should
68 .iter()
69 .filter_map(|q| q.as_term_query_info())
70 .collect();
71 if term_infos.len() != should.len() {
72 return None;
73 }
74 let first_field = term_infos[0].field;
75 if !term_infos.iter().all(|t| t.field == first_field) {
76 return None;
77 }
78 Some((term_infos, first_field))
79}
80
81fn compute_idf(
83 posting_list: &crate::structures::BlockPostingList,
84 field: crate::Field,
85 term: &[u8],
86 num_docs: f32,
87 global_stats: Option<&Arc<GlobalStats>>,
88) -> f32 {
89 if let Some(stats) = global_stats {
90 let global_idf = stats.text_idf(field, &String::from_utf8_lossy(term));
91 if global_idf > 0.0 {
92 return global_idf;
93 }
94 }
95 let doc_freq = posting_list.doc_count() as f32;
96 super::bm25_idf(doc_freq, num_docs)
97}
98
99fn maxscore_scorer_from_postings<'a>(
101 posting_lists: Vec<(crate::structures::BlockPostingList, f32)>,
102 avg_field_len: f32,
103 limit: usize,
104 predicate: Option<super::DocPredicate<'a>>,
105) -> crate::Result<Box<dyn Scorer + 'a>> {
106 if posting_lists.is_empty() {
107 return Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>);
108 }
109 let mut executor = MaxScoreExecutor::text(posting_lists, avg_field_len, limit);
110 executor.set_predicate(predicate);
111 let results = executor.execute_sync()?;
112 Ok(Box::new(TopKResultScorer::new(results)) as Box<dyn Scorer + 'a>)
113}
114
115async fn try_maxscore_scorer<'a>(
117 should: &[Arc<dyn Query>],
118 reader: &'a SegmentReader,
119 limit: usize,
120 global_stats: Option<&Arc<GlobalStats>>,
121 predicate: Option<super::DocPredicate<'a>>,
122) -> crate::Result<Option<Box<dyn Scorer + 'a>>> {
123 let (mut term_infos, field) = match maxscore_eligible(should) {
124 Some(v) => v,
125 None => return Ok(None),
126 };
127
128 let avg_field_len = global_stats
129 .map(|s| s.avg_field_len(field))
130 .unwrap_or_else(|| reader.avg_field_len(field));
131 let num_docs = reader.num_docs() as f32;
132
133 let mut posting_lists: Vec<(crate::structures::BlockPostingList, f32)> =
134 Vec::with_capacity(term_infos.len());
135 for info in term_infos.drain(..) {
136 if let Some(pl) = reader.get_postings(info.field, &info.term).await? {
137 let idf = compute_idf(&pl, info.field, &info.term, num_docs, global_stats);
138 posting_lists.push((pl, idf));
139 }
140 }
141
142 Ok(Some(maxscore_scorer_from_postings(
143 posting_lists,
144 avg_field_len,
145 limit,
146 predicate,
147 )?))
148}
149
/// Attempts the MaxScore fast path for a pure term disjunction (blocking variant).
///
/// Returns `Ok(None)` when `maxscore_eligible` rejects the clauses, letting the caller fall
/// back to the generic boolean scorer. Otherwise it fetches each term's postings
/// synchronously (terms absent from the segment are skipped) and runs the top-`limit`
/// executor.
///
/// Average field length comes from `global_stats` when provided, else from the segment.
///
/// # Errors
/// Propagates errors from posting retrieval and from the executor.
#[cfg(feature = "sync")]
fn try_maxscore_scorer_sync<'a>(
    should: &[Arc<dyn Query>],
    reader: &'a SegmentReader,
    limit: usize,
    global_stats: Option<&Arc<GlobalStats>>,
    predicate: Option<super::DocPredicate<'a>>,
) -> crate::Result<Option<Box<dyn Scorer + 'a>>> {
    let Some((term_infos, field)) = maxscore_eligible(should) else {
        return Ok(None);
    };

    let avg_field_len = match global_stats {
        Some(stats) => stats.avg_field_len(field),
        None => reader.avg_field_len(field),
    };
    let num_docs = reader.num_docs() as f32;

    let mut weighted_postings: Vec<(crate::structures::BlockPostingList, f32)> =
        Vec::with_capacity(term_infos.len());
    for info in term_infos {
        let Some(postings) = reader.get_postings_sync(info.field, &info.term)? else {
            continue;
        };
        let idf = compute_idf(&postings, info.field, &info.term, num_docs, global_stats);
        weighted_postings.push((postings, idf));
    }

    let scorer =
        maxscore_scorer_from_postings(weighted_postings, avg_field_len, limit, predicate)?;
    Ok(Some(scorer))
}
185
186impl Query for BooleanQuery {
187 fn scorer<'a>(
188 &self,
189 reader: &'a SegmentReader,
190 limit: usize,
191 predicate: Option<super::DocPredicate<'a>>,
192 ) -> ScorerFuture<'a> {
193 let must = self.must.clone();
195 let should = self.should.clone();
196 let must_not = self.must_not.clone();
197 let global_stats = self.global_stats.clone();
198
199 Box::pin(async move {
200 if must.is_empty()
203 && must_not.is_empty()
204 && should.len() >= 2
205 && let Some(scorer) =
206 try_maxscore_scorer(&should, reader, limit, global_stats.as_ref(), predicate)
207 .await?
208 {
209 return Ok(scorer);
210 }
211
212 let mut must_scorers = Vec::with_capacity(must.len());
215 for q in &must {
216 must_scorers.push(q.scorer(reader, limit, None).await?);
217 }
218
219 let mut should_scorers = Vec::with_capacity(should.len());
220 for q in &should {
221 should_scorers.push(q.scorer(reader, limit, None).await?);
222 }
223
224 let mut must_not_scorers = Vec::with_capacity(must_not.len());
225 for q in &must_not {
226 must_not_scorers.push(q.scorer(reader, limit, None).await?);
227 }
228
229 let mut scorer = BooleanScorer {
230 must: must_scorers,
231 should: should_scorers,
232 must_not: must_not_scorers,
233 current_doc: 0,
234 };
235 scorer.current_doc = scorer.find_next_match();
237 Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
238 })
239 }
240
241 #[cfg(feature = "sync")]
242 fn scorer_sync<'a>(
243 &self,
244 reader: &'a SegmentReader,
245 limit: usize,
246 predicate: Option<super::DocPredicate<'a>>,
247 ) -> crate::Result<Box<dyn Scorer + 'a>> {
248 if self.must.is_empty()
250 && self.must_not.is_empty()
251 && self.should.len() >= 2
252 && let Some(scorer) = try_maxscore_scorer_sync(
253 &self.should,
254 reader,
255 limit,
256 self.global_stats.as_ref(),
257 predicate,
258 )?
259 {
260 return Ok(scorer);
261 }
262
263 let mut must_scorers = Vec::with_capacity(self.must.len());
265 for q in &self.must {
266 must_scorers.push(q.scorer_sync(reader, limit, None)?);
267 }
268
269 let mut should_scorers = Vec::with_capacity(self.should.len());
270 for q in &self.should {
271 should_scorers.push(q.scorer_sync(reader, limit, None)?);
272 }
273
274 let mut must_not_scorers = Vec::with_capacity(self.must_not.len());
275 for q in &self.must_not {
276 must_not_scorers.push(q.scorer_sync(reader, limit, None)?);
277 }
278
279 let mut scorer = BooleanScorer {
280 must: must_scorers,
281 should: should_scorers,
282 must_not: must_not_scorers,
283 current_doc: 0,
284 };
285 scorer.current_doc = scorer.find_next_match();
286 Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
287 }
288
289 fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
290 let must = self.must.clone();
291 let should = self.should.clone();
292
293 Box::pin(async move {
294 if !must.is_empty() {
295 let mut estimates = Vec::with_capacity(must.len());
296 for q in &must {
297 estimates.push(q.count_estimate(reader).await?);
298 }
299 estimates
300 .into_iter()
301 .min()
302 .ok_or_else(|| crate::Error::Corruption("Empty must clause".to_string()))
303 } else if !should.is_empty() {
304 let mut sum = 0u32;
305 for q in &should {
306 sum += q.count_estimate(reader).await?;
307 }
308 Ok(sum)
309 } else {
310 Ok(0)
311 }
312 })
313 }
314}
315
316struct BooleanScorer<'a> {
317 must: Vec<Box<dyn Scorer + 'a>>,
318 should: Vec<Box<dyn Scorer + 'a>>,
319 must_not: Vec<Box<dyn Scorer + 'a>>,
320 current_doc: DocId,
321}
322
323impl BooleanScorer<'_> {
324 fn find_next_match(&mut self) -> DocId {
325 if self.must.is_empty() && self.should.is_empty() {
326 return TERMINATED;
327 }
328
329 loop {
330 let candidate = if !self.must.is_empty() {
331 let mut max_doc = self
332 .must
333 .iter()
334 .map(|s| s.doc())
335 .max()
336 .unwrap_or(TERMINATED);
337
338 if max_doc == TERMINATED {
339 return TERMINATED;
340 }
341
342 loop {
343 let mut all_match = true;
344 for scorer in &mut self.must {
345 let doc = scorer.seek(max_doc);
346 if doc == TERMINATED {
347 return TERMINATED;
348 }
349 if doc > max_doc {
350 max_doc = doc;
351 all_match = false;
352 break;
353 }
354 }
355 if all_match {
356 break;
357 }
358 }
359 max_doc
360 } else {
361 self.should
362 .iter()
363 .map(|s| s.doc())
364 .filter(|&d| d != TERMINATED)
365 .min()
366 .unwrap_or(TERMINATED)
367 };
368
369 if candidate == TERMINATED {
370 return TERMINATED;
371 }
372
373 let excluded = self.must_not.iter_mut().any(|scorer| {
374 let doc = scorer.seek(candidate);
375 doc == candidate
376 });
377
378 if !excluded {
379 self.current_doc = candidate;
380 return candidate;
381 }
382
383 if !self.must.is_empty() {
385 for scorer in &mut self.must {
386 scorer.advance();
387 }
388 } else {
389 for scorer in &mut self.should {
391 if scorer.doc() <= candidate && scorer.doc() != TERMINATED {
392 scorer.seek(candidate + 1);
393 }
394 }
395 }
396 }
397 }
398}
399
400impl Scorer for BooleanScorer<'_> {
401 fn doc(&self) -> DocId {
402 self.current_doc
403 }
404
405 fn score(&self) -> Score {
406 let mut total = 0.0;
407
408 for scorer in &self.must {
409 if scorer.doc() == self.current_doc {
410 total += scorer.score();
411 }
412 }
413
414 for scorer in &self.should {
415 if scorer.doc() == self.current_doc {
416 total += scorer.score();
417 }
418 }
419
420 total
421 }
422
423 fn advance(&mut self) -> DocId {
424 if !self.must.is_empty() {
425 for scorer in &mut self.must {
426 scorer.advance();
427 }
428 } else {
429 for scorer in &mut self.should {
430 if scorer.doc() == self.current_doc {
431 scorer.advance();
432 }
433 }
434 }
435
436 self.current_doc = self.find_next_match();
437 self.current_doc
438 }
439
440 fn seek(&mut self, target: DocId) -> DocId {
441 for scorer in &mut self.must {
442 scorer.seek(target);
443 }
444
445 for scorer in &mut self.should {
446 scorer.seek(target);
447 }
448
449 self.current_doc = self.find_next_match();
450 self.current_doc
451 }
452
453 fn size_hint(&self) -> u32 {
454 if !self.must.is_empty() {
455 self.must.iter().map(|s| s.size_hint()).min().unwrap_or(0)
456 } else {
457 self.should.iter().map(|s| s.size_hint()).sum()
458 }
459 }
460}
461
462struct TopKResultScorer {
464 results: Vec<ScoredDoc>,
465 position: usize,
466}
467
468impl TopKResultScorer {
469 fn new(results: Vec<ScoredDoc>) -> Self {
470 Self {
471 results,
472 position: 0,
473 }
474 }
475}
476
477impl Scorer for TopKResultScorer {
478 fn doc(&self) -> DocId {
479 if self.position < self.results.len() {
480 self.results[self.position].doc_id
481 } else {
482 TERMINATED
483 }
484 }
485
486 fn score(&self) -> Score {
487 if self.position < self.results.len() {
488 self.results[self.position].score
489 } else {
490 0.0
491 }
492 }
493
494 fn advance(&mut self) -> DocId {
495 self.position += 1;
496 self.doc()
497 }
498
499 fn seek(&mut self, target: DocId) -> DocId {
500 while self.position < self.results.len() && self.results[self.position].doc_id < target {
501 self.position += 1;
502 }
503 self.doc()
504 }
505
506 fn size_hint(&self) -> u32 {
507 self.results.len() as u32
508 }
509}
510
511struct EmptyScorer;
513
514impl Scorer for EmptyScorer {
515 fn doc(&self) -> DocId {
516 TERMINATED
517 }
518
519 fn score(&self) -> Score {
520 0.0
521 }
522
523 fn advance(&mut self) -> DocId {
524 TERMINATED
525 }
526
527 fn seek(&mut self, _target: DocId) -> DocId {
528 TERMINATED
529 }
530
531 fn size_hint(&self) -> u32 {
532 0
533 }
534}
535
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::Field;
    use crate::query::TermQuery;

    #[test]
    fn test_maxscore_eligible_pure_or_same_field() {
        let query = BooleanQuery::new()
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(0), "world"))
            .should(TermQuery::text(Field(0), "foo"));

        let (infos, field) = maxscore_eligible(&query.should)
            .expect("same-field term disjunction should be MaxScore-eligible");
        assert_eq!(field, Field(0));
        assert_eq!(infos.len(), 3);
        assert!(infos.iter().all(|i| i.field == Field(0)));
    }

    #[test]
    fn test_maxscore_not_eligible_different_fields() {
        let query = BooleanQuery::new()
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(1), "world"));

        // Every clause is a term query, but they target different fields.
        assert!(
            query
                .should
                .iter()
                .all(|q| q.as_term_query_info().is_some())
        );
        assert!(maxscore_eligible(&query.should).is_none());
    }

    #[test]
    fn test_maxscore_not_eligible_non_term_clause() {
        let nested = BooleanQuery::new().should(TermQuery::text(Field(0), "inner"));
        let query = BooleanQuery::new()
            .should(TermQuery::text(Field(0), "hello"))
            .should(nested);

        assert!(maxscore_eligible(&query.should).is_none());
    }

    #[test]
    fn test_maxscore_not_eligible_with_must() {
        let query = BooleanQuery::new()
            .must(TermQuery::text(Field(0), "required"))
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(0), "world"));

        // The MaxScore gate in `scorer` requires `must` to be empty.
        assert!(!query.must.is_empty());
    }

    #[test]
    fn test_maxscore_not_eligible_with_must_not() {
        let query = BooleanQuery::new()
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(0), "world"))
            .must_not(TermQuery::text(Field(0), "excluded"));

        // The MaxScore gate in `scorer` requires `must_not` to be empty.
        assert!(!query.must_not.is_empty());
    }

    #[test]
    fn test_maxscore_not_eligible_single_term() {
        let query = BooleanQuery::new().should(TermQuery::text(Field(0), "hello"));

        // The MaxScore gate in `scorer` requires at least two `should` clauses.
        assert_eq!(query.should.len(), 1);
    }

    #[test]
    fn test_term_query_info_extraction() {
        let term_query = TermQuery::text(Field(42), "test");
        let info = term_query.as_term_query_info();

        assert!(info.is_some());
        let info = info.unwrap();
        assert_eq!(info.field, Field(42));
        assert_eq!(info.term, b"test");
    }

    #[test]
    fn test_boolean_query_no_term_info() {
        let query = BooleanQuery::new().should(TermQuery::text(Field(0), "hello"));

        assert!(query.as_term_query_info().is_none());
    }
}