Skip to main content

hermes_core/query/
boolean.rs

1//! Boolean query with MUST, SHOULD, and MUST_NOT clauses
2
3use std::sync::Arc;
4
5use crate::segment::SegmentReader;
6use crate::structures::TERMINATED;
7use crate::{DocId, Score};
8
9use super::{
10    CountFuture, GlobalStats, MaxScoreExecutor, Query, ScoredDoc, Scorer, ScorerFuture,
11    SparseTermQueryInfo,
12};
13
/// Boolean query with MUST, SHOULD, and MUST_NOT clauses
///
/// When all clauses are SHOULD term queries on the same field, automatically
/// uses MaxScore optimization for efficient top-k retrieval.
///
/// Clauses are usually added with the chaining builder methods
/// (`must`, `should`, `must_not`), each of which wraps the query in an `Arc`.
#[derive(Default, Clone)]
pub struct BooleanQuery {
    /// Clauses that a document has to match; the scorer only yields documents
    /// on which every MUST scorer is positioned.
    pub must: Vec<Arc<dyn Query>>,
    /// Optional clauses; their scores are added when they are positioned on the
    /// current document. With no MUST clauses they drive iteration themselves.
    pub should: Vec<Arc<dyn Query>>,
    /// Exclusion clauses; a candidate that any of these matches is skipped.
    pub must_not: Vec<Arc<dyn Query>>,
    /// Optional global statistics for cross-segment IDF
    global_stats: Option<Arc<GlobalStats>>,
}
26
27impl std::fmt::Debug for BooleanQuery {
28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29        f.debug_struct("BooleanQuery")
30            .field("must_count", &self.must.len())
31            .field("should_count", &self.should.len())
32            .field("must_not_count", &self.must_not.len())
33            .field("has_global_stats", &self.global_stats.is_some())
34            .finish()
35    }
36}
37
38impl std::fmt::Display for BooleanQuery {
39    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40        write!(f, "Boolean(")?;
41        let mut first = true;
42        for q in &self.must {
43            if !first {
44                write!(f, " ")?;
45            }
46            write!(f, "+{}", q)?;
47            first = false;
48        }
49        for q in &self.should {
50            if !first {
51                write!(f, " ")?;
52            }
53            write!(f, "{}", q)?;
54            first = false;
55        }
56        for q in &self.must_not {
57            if !first {
58                write!(f, " ")?;
59            }
60            write!(f, "-{}", q)?;
61            first = false;
62        }
63        write!(f, ")")
64    }
65}
66
67impl BooleanQuery {
68    pub fn new() -> Self {
69        Self::default()
70    }
71
72    pub fn must(mut self, query: impl Query + 'static) -> Self {
73        self.must.push(Arc::new(query));
74        self
75    }
76
77    pub fn should(mut self, query: impl Query + 'static) -> Self {
78        self.should.push(Arc::new(query));
79        self
80    }
81
82    pub fn must_not(mut self, query: impl Query + 'static) -> Self {
83        self.must_not.push(Arc::new(query));
84        self
85    }
86
87    /// Set global statistics for cross-segment IDF
88    pub fn with_global_stats(mut self, stats: Arc<GlobalStats>) -> Self {
89        self.global_stats = Some(stats);
90        self
91    }
92}
93
94/// Compute IDF for a posting list, preferring global stats.
95fn compute_idf(
96    posting_list: &crate::structures::BlockPostingList,
97    field: crate::Field,
98    term: &[u8],
99    num_docs: f32,
100    global_stats: Option<&Arc<GlobalStats>>,
101) -> f32 {
102    if let Some(stats) = global_stats {
103        let global_idf = stats.text_idf(field, &String::from_utf8_lossy(term));
104        if global_idf > 0.0 {
105            return global_idf;
106        }
107    }
108    let doc_freq = posting_list.doc_count() as f32;
109    super::bm25_idf(doc_freq, num_docs)
110}
111
112/// Shared pre-check for text MaxScore: extract term infos, field, avg_field_len, num_docs.
113/// Returns None if not all SHOULD clauses are single-field term queries.
114fn prepare_text_maxscore(
115    should: &[Arc<dyn Query>],
116    reader: &SegmentReader,
117    global_stats: Option<&Arc<GlobalStats>>,
118) -> Option<(Vec<super::TermQueryInfo>, crate::Field, f32, f32)> {
119    let infos: Vec<_> = should
120        .iter()
121        .filter_map(|q| q.as_term_query_info())
122        .collect();
123    if infos.len() != should.len() {
124        return None;
125    }
126    let field = infos[0].field;
127    if !infos.iter().all(|t| t.field == field) {
128        return None;
129    }
130    let avg_field_len = global_stats
131        .map(|s| s.avg_field_len(field))
132        .unwrap_or_else(|| reader.avg_field_len(field));
133    let num_docs = reader.num_docs() as f32;
134    Some((infos, field, avg_field_len, num_docs))
135}
136
137/// Build a TopK scorer from fetched posting lists via text MaxScore.
138fn finish_text_maxscore<'a>(
139    posting_lists: Vec<(crate::structures::BlockPostingList, f32)>,
140    avg_field_len: f32,
141    limit: usize,
142) -> crate::Result<Box<dyn Scorer + 'a>> {
143    if posting_lists.is_empty() {
144        return Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>);
145    }
146    let results = MaxScoreExecutor::text(posting_lists, avg_field_len, limit).execute_sync()?;
147    Ok(Box::new(TopKResultScorer::new(results)) as Box<dyn Scorer + 'a>)
148}
149
150/// Try text MaxScore for pure OR queries (async).
151async fn try_maxscore_scorer<'a>(
152    should: &[Arc<dyn Query>],
153    reader: &'a SegmentReader,
154    limit: usize,
155    global_stats: Option<&Arc<GlobalStats>>,
156) -> crate::Result<Option<Box<dyn Scorer + 'a>>> {
157    let (mut infos, _field, avg_field_len, num_docs) =
158        match prepare_text_maxscore(should, reader, global_stats) {
159            Some(v) => v,
160            None => return Ok(None),
161        };
162    let mut posting_lists = Vec::with_capacity(infos.len());
163    for info in infos.drain(..) {
164        if let Some(pl) = reader.get_postings(info.field, &info.term).await? {
165            let idf = compute_idf(&pl, info.field, &info.term, num_docs, global_stats);
166            posting_lists.push((pl, idf));
167        }
168    }
169    Ok(Some(finish_text_maxscore(
170        posting_lists,
171        avg_field_len,
172        limit,
173    )?))
174}
175
/// Try text MaxScore for pure OR queries (sync).
///
/// Blocking counterpart of the async variant: yields `Ok(None)` when the
/// SHOULD clauses are not eligible, and skips terms that have no posting
/// list in this segment.
#[cfg(feature = "sync")]
fn try_maxscore_scorer_sync<'a>(
    should: &[Arc<dyn Query>],
    reader: &'a SegmentReader,
    limit: usize,
    global_stats: Option<&Arc<GlobalStats>>,
) -> crate::Result<Option<Box<dyn Scorer + 'a>>> {
    let Some((term_infos, _, avg_field_len, num_docs)) =
        prepare_text_maxscore(should, reader, global_stats)
    else {
        return Ok(None);
    };

    let mut weighted_postings = Vec::with_capacity(term_infos.len());
    for term_info in term_infos {
        let Some(postings) = reader.get_postings_sync(term_info.field, &term_info.term)? else {
            continue;
        };
        let idf = compute_idf(
            &postings,
            term_info.field,
            &term_info.term,
            num_docs,
            global_stats,
        );
        weighted_postings.push((postings, idf));
    }

    finish_text_maxscore(weighted_postings, avg_field_len, limit).map(Some)
}
202
/// Shared grouping result for per-field MaxScore.
///
/// Produced by `prepare_per_field_grouping` and consumed by both the async
/// and sync per-field MaxScore builders.
struct PerFieldGrouping {
    /// (field, avg_field_len, term_infos) for groups with 2+ terms
    multi_term_groups: Vec<(crate::Field, f32, Vec<super::TermQueryInfo>)>,
    /// Original indices of single-term and non-term SHOULD clauses (fallback scorers)
    fallback_indices: Vec<usize>,
    /// Limit per field group (over-fetched to compensate for cross-field scoring)
    per_field_limit: usize,
    /// Document count of the segment, as a float, for IDF computation.
    num_docs: f32,
}
213
214/// Group SHOULD clauses by field for per-field MaxScore.
215/// Returns None if no group has 2+ terms (no optimization benefit).
216fn prepare_per_field_grouping(
217    should: &[Arc<dyn Query>],
218    reader: &SegmentReader,
219    limit: usize,
220    global_stats: Option<&Arc<GlobalStats>>,
221) -> Option<PerFieldGrouping> {
222    let mut field_groups: rustc_hash::FxHashMap<crate::Field, Vec<(usize, super::TermQueryInfo)>> =
223        rustc_hash::FxHashMap::default();
224    let mut non_term_indices: Vec<usize> = Vec::new();
225
226    for (i, q) in should.iter().enumerate() {
227        if let Some(info) = q.as_term_query_info() {
228            field_groups.entry(info.field).or_default().push((i, info));
229        } else {
230            non_term_indices.push(i);
231        }
232    }
233
234    if !field_groups.values().any(|g| g.len() >= 2) {
235        return None;
236    }
237
238    let num_groups = field_groups.len() + non_term_indices.len();
239    let per_field_limit = limit * num_groups;
240    let num_docs = reader.num_docs() as f32;
241
242    let mut multi_term_groups = Vec::new();
243    let mut fallback_indices = non_term_indices;
244
245    for group in field_groups.into_values() {
246        if group.len() >= 2 {
247            let field = group[0].1.field;
248            let avg_field_len = global_stats
249                .map(|s| s.avg_field_len(field))
250                .unwrap_or_else(|| reader.avg_field_len(field));
251            let infos: Vec<_> = group.into_iter().map(|(_, info)| info).collect();
252            multi_term_groups.push((field, avg_field_len, infos));
253        } else {
254            fallback_indices.push(group[0].0);
255        }
256    }
257
258    Some(PerFieldGrouping {
259        multi_term_groups,
260        fallback_indices,
261        per_field_limit,
262        num_docs,
263    })
264}
265
266/// Build a SHOULD-only scorer from a vec of optimized scorers.
267fn build_should_scorer<'a>(scorers: Vec<Box<dyn Scorer + 'a>>) -> Box<dyn Scorer + 'a> {
268    if scorers.is_empty() {
269        return Box::new(EmptyScorer);
270    }
271    if scorers.len() == 1 {
272        return scorers.into_iter().next().unwrap();
273    }
274    let mut scorer = BooleanScorer {
275        must: vec![],
276        should: scorers,
277        must_not: vec![],
278        current_doc: 0,
279    };
280    scorer.current_doc = scorer.find_next_match();
281    Box::new(scorer)
282}
283
284/// Per-field MaxScore grouping for multi-field SHOULD queries (async).
285///
286/// When SHOULD clauses span multiple fields (e.g., "hello world" across title, body, desc),
287/// single-field MaxScore can't apply. This groups TermQuery clauses by field, runs MaxScore
288/// per group, and returns a compact scorer per field.
289async fn try_per_field_maxscore<'a>(
290    should: &[Arc<dyn Query>],
291    reader: &'a SegmentReader,
292    limit: usize,
293    global_stats: Option<&Arc<GlobalStats>>,
294) -> crate::Result<Option<Box<dyn Scorer + 'a>>> {
295    let grouping = match prepare_per_field_grouping(should, reader, limit, global_stats) {
296        Some(g) => g,
297        None => return Ok(None),
298    };
299
300    let mut scorers: Vec<Box<dyn Scorer + 'a>> = Vec::new();
301
302    for (field, avg_field_len, infos) in &grouping.multi_term_groups {
303        let mut posting_lists = Vec::with_capacity(infos.len());
304        for info in infos {
305            if let Some(pl) = reader.get_postings(info.field, &info.term).await? {
306                let idf = compute_idf(&pl, *field, &info.term, grouping.num_docs, global_stats);
307                posting_lists.push((pl, idf));
308            }
309        }
310        if !posting_lists.is_empty() {
311            scorers.push(finish_text_maxscore(
312                posting_lists,
313                *avg_field_len,
314                grouping.per_field_limit,
315            )?);
316        }
317    }
318
319    for &idx in &grouping.fallback_indices {
320        scorers.push(should[idx].scorer(reader, limit).await?);
321    }
322
323    Ok(Some(build_should_scorer(scorers)))
324}
325
/// Per-field MaxScore grouping for multi-field SHOULD queries (sync).
///
/// Blocking counterpart of the async variant: yields `Ok(None)` when no field
/// has at least two term clauses, runs MaxScore once per multi-term field
/// group, and builds regular scorers for every other clause.
#[cfg(feature = "sync")]
fn try_per_field_maxscore_sync<'a>(
    should: &[Arc<dyn Query>],
    reader: &'a SegmentReader,
    limit: usize,
    global_stats: Option<&Arc<GlobalStats>>,
) -> crate::Result<Option<Box<dyn Scorer + 'a>>> {
    let Some(grouping) = prepare_per_field_grouping(should, reader, limit, global_stats) else {
        return Ok(None);
    };
    let PerFieldGrouping {
        multi_term_groups,
        fallback_indices,
        per_field_limit,
        num_docs,
    } = grouping;

    let mut scorers: Vec<Box<dyn Scorer + 'a>> = Vec::new();

    for (field, avg_field_len, term_infos) in multi_term_groups {
        let mut weighted_postings = Vec::with_capacity(term_infos.len());
        for term_info in &term_infos {
            if let Some(postings) = reader.get_postings_sync(term_info.field, &term_info.term)? {
                let idf = compute_idf(&postings, field, &term_info.term, num_docs, global_stats);
                weighted_postings.push((postings, idf));
            }
        }
        // A group whose terms are all absent from this segment contributes nothing.
        if weighted_postings.is_empty() {
            continue;
        }
        scorers.push(finish_text_maxscore(
            weighted_postings,
            avg_field_len,
            per_field_limit,
        )?);
    }

    for clause_index in fallback_indices {
        scorers.push(should[clause_index].scorer_sync(reader, limit)?);
    }

    Ok(Some(build_should_scorer(scorers)))
}
364
365/// Try to build a sparse MaxScoreExecutor from SHOULD clauses.
366/// Returns None if not eligible, Some(Err) for empty segment, Some(Ok) otherwise.
367fn prepare_sparse_maxscore<'a>(
368    should: &[Arc<dyn Query>],
369    reader: &'a SegmentReader,
370    limit: usize,
371) -> Option<Result<MaxScoreExecutor<'a>, Box<dyn Scorer + 'a>>> {
372    let infos: Vec<SparseTermQueryInfo> = should
373        .iter()
374        .filter_map(|q| q.as_sparse_term_query_info())
375        .collect();
376    if infos.len() != should.len() {
377        return None;
378    }
379    let field = infos[0].field;
380    if !infos.iter().all(|t| t.field == field) {
381        return None;
382    }
383    let si = match reader.sparse_index(field) {
384        Some(si) => si,
385        None => return Some(Err(Box::new(EmptyScorer))),
386    };
387    let query_terms: Vec<(u32, f32)> = infos
388        .iter()
389        .filter(|info| si.has_dimension(info.dim_id))
390        .map(|info| (info.dim_id, info.weight))
391        .collect();
392    if query_terms.is_empty() {
393        return Some(Err(Box::new(EmptyScorer)));
394    }
395    let executor_limit = (limit as f32 * infos[0].over_fetch_factor).ceil() as usize;
396    Some(Ok(MaxScoreExecutor::sparse(
397        si,
398        query_terms,
399        executor_limit,
400        infos[0].heap_factor,
401    )))
402}
403
404/// Combine raw MaxScore results with ordinal deduplication into a scorer.
405fn combine_sparse_results<'a>(
406    raw: Vec<ScoredDoc>,
407    combiner: super::MultiValueCombiner,
408    field: crate::Field,
409    limit: usize,
410) -> Box<dyn Scorer + 'a> {
411    let combined = crate::segment::combine_ordinal_results(
412        raw.into_iter().map(|r| (r.doc_id, r.ordinal, r.score)),
413        combiner,
414        limit,
415    );
416    Box::new(VectorTopKResultScorer::new(combined, field.0))
417}
418
419/// Build MaxScore scorer from sparse term infos (async).
420async fn try_sparse_maxscore_scorer<'a>(
421    should: &[Arc<dyn Query>],
422    reader: &'a SegmentReader,
423    limit: usize,
424) -> crate::Result<Option<Box<dyn Scorer + 'a>>> {
425    let executor = match prepare_sparse_maxscore(should, reader, limit) {
426        None => return Ok(None),
427        Some(Err(empty)) => return Ok(Some(empty)),
428        Some(Ok(e)) => e,
429    };
430    let info = should[0].as_sparse_term_query_info().unwrap();
431    let raw = executor.execute().await?;
432    Ok(Some(combine_sparse_results(
433        raw,
434        info.combiner,
435        info.field,
436        limit,
437    )))
438}
439
/// Build MaxScore scorer from sparse term infos (sync).
///
/// Blocking counterpart of the async variant: yields `Ok(None)` when the
/// SHOULD clauses are not eligible, and an exhausted scorer when they are
/// eligible but nothing can match in this segment.
#[cfg(feature = "sync")]
fn try_sparse_maxscore_scorer_sync<'a>(
    should: &[Arc<dyn Query>],
    reader: &'a SegmentReader,
    limit: usize,
) -> crate::Result<Option<Box<dyn Scorer + 'a>>> {
    let executor = match prepare_sparse_maxscore(should, reader, limit) {
        Some(Ok(ready)) => ready,
        Some(Err(empty_scorer)) => return Ok(Some(empty_scorer)),
        None => return Ok(None),
    };

    // Preparation succeeded, so every clause (including the first) is a sparse term query.
    let head = should[0].as_sparse_term_query_info().unwrap();
    let hits = executor.execute_sync()?;
    let scorer = combine_sparse_results(hits, head.combiner, head.field, limit);
    Ok(Some(scorer))
}
461
462impl Query for BooleanQuery {
463    fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
464        // Clone Arc vectors - cheap reference counting
465        let must = self.must.clone();
466        let should = self.should.clone();
467        let must_not = self.must_not.clone();
468        let global_stats = self.global_stats.clone();
469
470        Box::pin(async move {
471            // Single-clause optimization: unwrap to inner scorer directly
472            if must_not.is_empty() {
473                if must.len() == 1 && should.is_empty() {
474                    return must[0].scorer(reader, limit).await;
475                }
476                if should.len() == 1 && must.is_empty() {
477                    return should[0].scorer(reader, limit).await;
478                }
479            }
480
481            // Check if this is a pure OR query eligible for MaxScore optimization
482            // Conditions: no MUST, no MUST_NOT, multiple SHOULD clauses, all same field
483            if must.is_empty() && must_not.is_empty() && should.len() >= 2 {
484                // Try text MaxScore first
485                if let Some(scorer) =
486                    try_maxscore_scorer(&should, reader, limit, global_stats.as_ref()).await?
487                {
488                    return Ok(scorer);
489                }
490                // Try sparse MaxScore
491                if let Some(scorer) = try_sparse_maxscore_scorer(&should, reader, limit).await? {
492                    return Ok(scorer);
493                }
494                // Try per-field MaxScore grouping for multi-field text queries
495                if let Some(scorer) =
496                    try_per_field_maxscore(&should, reader, limit, global_stats.as_ref()).await?
497                {
498                    return Ok(scorer);
499                }
500            }
501
502            // Fall back to standard boolean scoring
503            let mut must_scorers = Vec::with_capacity(must.len());
504            for q in &must {
505                must_scorers.push(q.scorer(reader, limit).await?);
506            }
507
508            let mut should_scorers = Vec::with_capacity(should.len());
509            for q in &should {
510                should_scorers.push(q.scorer(reader, limit).await?);
511            }
512
513            let mut must_not_scorers = Vec::with_capacity(must_not.len());
514            for q in &must_not {
515                must_not_scorers.push(q.scorer(reader, limit).await?);
516            }
517
518            let mut scorer = BooleanScorer {
519                must: must_scorers,
520                should: should_scorers,
521                must_not: must_not_scorers,
522                current_doc: 0,
523            };
524            // Initialize to first match
525            scorer.current_doc = scorer.find_next_match();
526            Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
527        })
528    }
529
530    #[cfg(feature = "sync")]
531    fn scorer_sync<'a>(
532        &self,
533        reader: &'a SegmentReader,
534        limit: usize,
535    ) -> crate::Result<Box<dyn Scorer + 'a>> {
536        // Single-clause optimization: unwrap to inner scorer directly
537        if self.must_not.is_empty() {
538            if self.must.len() == 1 && self.should.is_empty() {
539                return self.must[0].scorer_sync(reader, limit);
540            }
541            if self.should.len() == 1 && self.must.is_empty() {
542                return self.should[0].scorer_sync(reader, limit);
543            }
544        }
545
546        // MaxScore optimization for pure OR queries
547        if self.must.is_empty() && self.must_not.is_empty() && self.should.len() >= 2 {
548            if let Some(scorer) =
549                try_maxscore_scorer_sync(&self.should, reader, limit, self.global_stats.as_ref())?
550            {
551                return Ok(scorer);
552            }
553            if let Some(scorer) = try_sparse_maxscore_scorer_sync(&self.should, reader, limit)? {
554                return Ok(scorer);
555            }
556            // Try per-field MaxScore grouping for multi-field text queries
557            if let Some(scorer) = try_per_field_maxscore_sync(
558                &self.should,
559                reader,
560                limit,
561                self.global_stats.as_ref(),
562            )? {
563                return Ok(scorer);
564            }
565        }
566
567        // Fall back to standard boolean scoring
568        let mut must_scorers = Vec::with_capacity(self.must.len());
569        for q in &self.must {
570            must_scorers.push(q.scorer_sync(reader, limit)?);
571        }
572
573        let mut should_scorers = Vec::with_capacity(self.should.len());
574        for q in &self.should {
575            should_scorers.push(q.scorer_sync(reader, limit)?);
576        }
577
578        let mut must_not_scorers = Vec::with_capacity(self.must_not.len());
579        for q in &self.must_not {
580            must_not_scorers.push(q.scorer_sync(reader, limit)?);
581        }
582
583        let mut scorer = BooleanScorer {
584            must: must_scorers,
585            should: should_scorers,
586            must_not: must_not_scorers,
587            current_doc: 0,
588        };
589        scorer.current_doc = scorer.find_next_match();
590        Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
591    }
592
593    fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
594        let must = self.must.clone();
595        let should = self.should.clone();
596
597        Box::pin(async move {
598            if !must.is_empty() {
599                let mut estimates = Vec::with_capacity(must.len());
600                for q in &must {
601                    estimates.push(q.count_estimate(reader).await?);
602                }
603                estimates
604                    .into_iter()
605                    .min()
606                    .ok_or_else(|| crate::Error::Corruption("Empty must clause".to_string()))
607            } else if !should.is_empty() {
608                let mut sum = 0u32;
609                for q in &should {
610                    sum = sum.saturating_add(q.count_estimate(reader).await?);
611                }
612                Ok(sum)
613            } else {
614                Ok(0)
615            }
616        })
617    }
618}
619
/// Scorer that combines child scorers according to boolean clause semantics.
///
/// `current_doc` caches the document the scorer is positioned on; it is
/// `TERMINATED` once no further match exists.
struct BooleanScorer<'a> {
    /// Child scorers that must all be positioned on a document for it to match.
    must: Vec<Box<dyn Scorer + 'a>>,
    /// Optional child scorers; they drive iteration only when `must` is empty.
    should: Vec<Box<dyn Scorer + 'a>>,
    /// Child scorers whose matches exclude a candidate document.
    must_not: Vec<Box<dyn Scorer + 'a>>,
    /// Document currently pointed at, as last returned by `find_next_match`.
    current_doc: DocId,
}
626
impl BooleanScorer<'_> {
    /// Finds the next matching document at or after the children's current positions.
    ///
    /// A candidate is produced either by aligning all MUST scorers on a common
    /// document or, with no MUST scorers, by taking the smallest live SHOULD
    /// document. Candidates hit by any MUST_NOT scorer are skipped. On success
    /// the SHOULD scorers are sought to the candidate, `current_doc` is
    /// updated, and the candidate is returned; otherwise `TERMINATED` is returned.
    fn find_next_match(&mut self) -> DocId {
        // Only MUST_NOT clauses (or none at all): nothing can positively match.
        if self.must.is_empty() && self.should.is_empty() {
            return TERMINATED;
        }

        loop {
            let candidate = if !self.must.is_empty() {
                // Start from the furthest-ahead MUST scorer; no earlier doc can match all of them.
                let mut max_doc = self
                    .must
                    .iter()
                    .map(|s| s.doc())
                    .max()
                    .unwrap_or(TERMINATED);

                if max_doc == TERMINATED {
                    return TERMINATED;
                }

                // Leapfrog: seek every MUST scorer to `max_doc`; whenever one overshoots,
                // raise the target and restart until all of them agree.
                loop {
                    let mut all_match = true;
                    for scorer in &mut self.must {
                        let doc = scorer.seek(max_doc);
                        if doc == TERMINATED {
                            return TERMINATED;
                        }
                        if doc > max_doc {
                            max_doc = doc;
                            all_match = false;
                            break;
                        }
                    }
                    if all_match {
                        break;
                    }
                }
                max_doc
            } else {
                // SHOULD-only: the smallest document any non-exhausted scorer points at.
                self.should
                    .iter()
                    .map(|s| s.doc())
                    .filter(|&d| d != TERMINATED)
                    .min()
                    .unwrap_or(TERMINATED)
            };

            if candidate == TERMINATED {
                return TERMINATED;
            }

            // `any` short-circuits, so MUST_NOT scorers after the first hit are not sought.
            let excluded = self.must_not.iter_mut().any(|scorer| {
                let doc = scorer.seek(candidate);
                doc == candidate
            });

            if !excluded {
                // Seek SHOULD scorers to candidate so score() can see their contributions
                for scorer in &mut self.should {
                    scorer.seek(candidate);
                }
                self.current_doc = candidate;
                return candidate;
            }

            // Advance past excluded candidate
            if !self.must.is_empty() {
                // All MUST scorers sit on the candidate here, so stepping each one moves past it.
                for scorer in &mut self.must {
                    scorer.advance();
                }
            } else {
                // For SHOULD-only: seek all scorers past the excluded candidate
                for scorer in &mut self.should {
                    if scorer.doc() <= candidate && scorer.doc() != TERMINATED {
                        scorer.seek(candidate + 1);
                    }
                }
            }
        }
    }
}
707
708impl super::docset::DocSet for BooleanScorer<'_> {
709    fn doc(&self) -> DocId {
710        self.current_doc
711    }
712
713    fn advance(&mut self) -> DocId {
714        if !self.must.is_empty() {
715            for scorer in &mut self.must {
716                scorer.advance();
717            }
718        } else {
719            for scorer in &mut self.should {
720                if scorer.doc() == self.current_doc {
721                    scorer.advance();
722                }
723            }
724        }
725
726        self.current_doc = self.find_next_match();
727        self.current_doc
728    }
729
730    fn seek(&mut self, target: DocId) -> DocId {
731        for scorer in &mut self.must {
732            scorer.seek(target);
733        }
734
735        for scorer in &mut self.should {
736            scorer.seek(target);
737        }
738
739        self.current_doc = self.find_next_match();
740        self.current_doc
741    }
742
743    fn size_hint(&self) -> u32 {
744        if !self.must.is_empty() {
745            self.must.iter().map(|s| s.size_hint()).min().unwrap_or(0)
746        } else {
747            self.should.iter().map(|s| s.size_hint()).sum()
748        }
749    }
750}
751
752impl Scorer for BooleanScorer<'_> {
753    fn score(&self) -> Score {
754        let mut total = 0.0;
755
756        for scorer in &self.must {
757            if scorer.doc() == self.current_doc {
758                total += scorer.score();
759            }
760        }
761
762        for scorer in &self.should {
763            if scorer.doc() == self.current_doc {
764                total += scorer.score();
765            }
766        }
767
768        total
769    }
770
771    fn matched_positions(&self) -> Option<super::MatchedPositions> {
772        let mut all_positions: super::MatchedPositions = Vec::new();
773
774        for scorer in &self.must {
775            if scorer.doc() == self.current_doc
776                && let Some(positions) = scorer.matched_positions()
777            {
778                all_positions.extend(positions);
779            }
780        }
781
782        for scorer in &self.should {
783            if scorer.doc() == self.current_doc
784                && let Some(positions) = scorer.matched_positions()
785            {
786                all_positions.extend(positions);
787            }
788        }
789
790        if all_positions.is_empty() {
791            None
792        } else {
793            Some(all_positions)
794        }
795    }
796}
797
/// Scorer that iterates over pre-computed top-k results
///
/// The results are held sorted by ascending `doc_id` (see `new`), and
/// `position` is the index of the entry the scorer currently points at.
struct TopKResultScorer {
    /// Pre-computed hits, sorted by ascending document id.
    results: Vec<ScoredDoc>,
    /// Cursor into `results`; at or beyond the end means the scorer is exhausted.
    position: usize,
}
803
804impl TopKResultScorer {
805    fn new(mut results: Vec<ScoredDoc>) -> Self {
806        // Sort by doc_id ascending — required for DocSet seek() correctness
807        results.sort_unstable_by_key(|r| r.doc_id);
808        Self {
809            results,
810            position: 0,
811        }
812    }
813}
814
815impl super::docset::DocSet for TopKResultScorer {
816    fn doc(&self) -> DocId {
817        if self.position < self.results.len() {
818            self.results[self.position].doc_id
819        } else {
820            TERMINATED
821        }
822    }
823
824    fn advance(&mut self) -> DocId {
825        self.position += 1;
826        self.doc()
827    }
828
829    fn seek(&mut self, target: DocId) -> DocId {
830        let remaining = &self.results[self.position..];
831        self.position += remaining.partition_point(|r| r.doc_id < target);
832        self.doc()
833    }
834
835    fn size_hint(&self) -> u32 {
836        (self.results.len() - self.position) as u32
837    }
838}
839
840impl Scorer for TopKResultScorer {
841    fn score(&self) -> Score {
842        if self.position < self.results.len() {
843            self.results[self.position].score
844        } else {
845            0.0
846        }
847    }
848}
849
/// Scorer that iterates over pre-computed vector results with ordinal information.
/// Used by sparse MaxScore path to preserve per-ordinal scores for matched_positions().
struct VectorTopKResultScorer {
    /// Pre-computed hits, sorted by ascending document id (see `new`).
    results: Vec<crate::segment::VectorSearchResult>,
    /// Cursor into `results`; at or beyond the end means the scorer is exhausted.
    position: usize,
    /// Field id reported alongside the positions in `matched_positions()`.
    field_id: u32,
}
857
858impl VectorTopKResultScorer {
859    fn new(mut results: Vec<crate::segment::VectorSearchResult>, field_id: u32) -> Self {
860        results.sort_unstable_by_key(|r| r.doc_id);
861        Self {
862            results,
863            position: 0,
864            field_id,
865        }
866    }
867}
868
869impl super::docset::DocSet for VectorTopKResultScorer {
870    fn doc(&self) -> DocId {
871        if self.position < self.results.len() {
872            self.results[self.position].doc_id
873        } else {
874            TERMINATED
875        }
876    }
877
878    fn advance(&mut self) -> DocId {
879        self.position += 1;
880        self.doc()
881    }
882
883    fn seek(&mut self, target: DocId) -> DocId {
884        let remaining = &self.results[self.position..];
885        self.position += remaining.partition_point(|r| r.doc_id < target);
886        self.doc()
887    }
888
889    fn size_hint(&self) -> u32 {
890        (self.results.len() - self.position) as u32
891    }
892}
893
894impl Scorer for VectorTopKResultScorer {
895    fn score(&self) -> Score {
896        if self.position < self.results.len() {
897            self.results[self.position].score
898        } else {
899            0.0
900        }
901    }
902
903    fn matched_positions(&self) -> Option<super::MatchedPositions> {
904        if self.position >= self.results.len() {
905            return None;
906        }
907        let result = &self.results[self.position];
908        let scored_positions: Vec<super::ScoredPosition> = result
909            .ordinals
910            .iter()
911            .map(|&(ordinal, score)| super::ScoredPosition::new(ordinal, score))
912            .collect();
913        Some(vec![(self.field_id, scored_positions)])
914    }
915}
916
/// Empty scorer for when no terms match
///
/// It is exhausted from the start: it always reports `TERMINATED` and a score of `0.0`.
struct EmptyScorer;
919
impl super::docset::DocSet for EmptyScorer {
    /// Always `TERMINATED`: there is never a current document.
    fn doc(&self) -> DocId {
        TERMINATED
    }

    /// No-op; the scorer stays exhausted.
    fn advance(&mut self) -> DocId {
        TERMINATED
    }

    /// No-op regardless of the target; the scorer stays exhausted.
    fn seek(&mut self, _target: DocId) -> DocId {
        TERMINATED
    }

    /// Zero: there are no documents to iterate.
    fn size_hint(&self) -> u32 {
        0
    }
}
937
938impl Scorer for EmptyScorer {
939    fn score(&self) -> Score {
940        0.0
941    }
942}
943
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::Field;
    use crate::query::TermQuery;

    /// A pure OR query whose terms all target one field exposes term info for
    /// every clause, which is the precondition for MaxScore.
    #[test]
    fn test_maxscore_eligible_pure_or_same_field() {
        let query = BooleanQuery::new()
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(0), "world"))
            .should(TermQuery::text(Field(0), "foo"));

        // Every SHOULD clause must be able to describe itself as a term query.
        for clause in &query.should {
            assert!(clause.as_term_query_info().is_some());
        }

        // ...and every one of them must point at the same field.
        let infos: Vec<_> = query
            .should
            .iter()
            .filter_map(|clause| clause.as_term_query_info())
            .collect();
        assert_eq!(infos.len(), 3);
        for info in &infos {
            assert_eq!(info.field, Field(0));
        }
    }

    /// SHOULD terms spread over different fields must not qualify for MaxScore.
    #[test]
    fn test_maxscore_not_eligible_different_fields() {
        let query = BooleanQuery::new()
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(1), "world")); // Different field!

        let infos: Vec<_> = query
            .should
            .iter()
            .filter_map(|clause| clause.as_term_query_info())
            .collect();
        assert_eq!(infos.len(), 2);
        // Differing fields rule out the MaxScore path.
        assert_ne!(infos[0].field, infos[1].field);
    }

    /// A MUST clause disqualifies the query from the MaxScore optimization.
    #[test]
    fn test_maxscore_not_eligible_with_must() {
        let query = BooleanQuery::new()
            .must(TermQuery::text(Field(0), "required"))
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(0), "world"));

        // The MUST clause was recorded, so MaxScore should not kick in.
        assert_ne!(query.must.len(), 0);
    }

    /// A MUST_NOT clause disqualifies the query from the MaxScore optimization.
    #[test]
    fn test_maxscore_not_eligible_with_must_not() {
        let query = BooleanQuery::new()
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(0), "world"))
            .must_not(TermQuery::text(Field(0), "excluded"));

        // The MUST_NOT clause was recorded, so MaxScore should not kick in.
        assert_ne!(query.must_not.len(), 0);
    }

    /// A lone SHOULD clause gains nothing from MaxScore.
    #[test]
    fn test_maxscore_not_eligible_single_term() {
        let query = BooleanQuery::new().should(TermQuery::text(Field(0), "hello"));

        // Exactly one term: MaxScore is not beneficial.
        assert_eq!(query.should.len(), 1);
    }

    /// A term query reports the field and term bytes it was built from.
    #[test]
    fn test_term_query_info_extraction() {
        let term_query = TermQuery::text(Field(42), "test");
        let maybe_info = term_query.as_term_query_info();

        assert!(maybe_info.is_some());
        let info = maybe_info.unwrap();
        assert_eq!(info.field, Field(42));
        assert_eq!(info.term, b"test");
    }

    /// A `BooleanQuery` is not itself a term query, so it exposes no term info.
    #[test]
    fn test_boolean_query_no_term_info() {
        let query = BooleanQuery::new().should(TermQuery::text(Field(0), "hello"));

        assert!(query.as_term_query_info().is_none());
    }
}