Skip to main content

hermes_core/query/
scoring.rs

1//! Shared scoring abstractions for text and sparse vector search
2//!
3//! Provides common types and executors for efficient top-k retrieval:
4//! - `TermCursor`: Unified cursor for both BM25 text and sparse vector posting lists
5//! - `ScoreCollector`: Efficient min-heap for maintaining top-k results
6//! - `MaxScoreExecutor`: Unified Block-Max MaxScore with conjunction optimization
7//! - `ScoredDoc`: Result type with doc_id, score, and ordinal
8
9use std::cmp::Ordering;
10use std::collections::BinaryHeap;
11
12use log::debug;
13
14use crate::DocId;
15
/// Entry for top-k min-heap
#[derive(Clone, Copy)]
pub struct HeapEntry {
    /// Candidate document id.
    pub doc_id: DocId,
    /// Score used for heap ordering (lower = closer to eviction).
    pub score: f32,
    /// Ordinal for multi-valued fields (which value in the field matched).
    pub ordinal: u16,
}
23
24impl PartialEq for HeapEntry {
25    fn eq(&self, other: &Self) -> bool {
26        self.score == other.score && self.doc_id == other.doc_id
27    }
28}
29
30impl Eq for HeapEntry {}
31
32impl Ord for HeapEntry {
33    fn cmp(&self, other: &Self) -> Ordering {
34        // Min-heap: lower scores come first (to be evicted).
35        // total_cmp is branchless (compiles to a single comparison instruction).
36        other
37            .score
38            .total_cmp(&self.score)
39            .then(self.doc_id.cmp(&other.doc_id))
40    }
41}
42
43impl PartialOrd for HeapEntry {
44    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
45        Some(self.cmp(other))
46    }
47}
48
/// Efficient top-k collector using min-heap (internal, scoring-layer)
///
/// Maintains the k highest-scoring documents using a min-heap where the
/// lowest score is at the top for O(1) threshold lookup and O(log k) eviction.
/// No deduplication — caller must ensure each doc_id is inserted only once.
///
/// This is intentionally separate from `TopKCollector` in `collector.rs`:
/// `ScoreCollector` is used inside `MaxScoreExecutor` where only `(doc_id,
/// score, ordinal)` tuples exist — no `Scorer` trait, no position tracking,
/// and the threshold must be inlined for tight block-max loops.
/// `TopKCollector` wraps a `Scorer` and drives the full `DocSet`/`Scorer`
/// protocol, collecting positions on demand.
pub struct ScoreCollector {
    /// Min-heap of top-k entries (lowest score at top for eviction)
    heap: BinaryHeap<HeapEntry>,
    /// Requested result count; the heap holds at most `k` entries
    /// (except transiently inside `insert_with_ordinal`).
    pub k: usize,
    /// Cached threshold: avoids repeated heap.peek() in hot loops.
    /// Updated only when the heap changes (insert/pop).
    cached_threshold: f32,
}
69
70impl ScoreCollector {
71    /// Create a new collector for top-k results
72    pub fn new(k: usize) -> Self {
73        // Cap capacity to avoid allocation overflow for very large k
74        let capacity = k.saturating_add(1).min(1_000_000);
75        Self {
76            heap: BinaryHeap::with_capacity(capacity),
77            k,
78            cached_threshold: 0.0,
79        }
80    }
81
82    /// Current score threshold (minimum score to enter top-k)
83    #[inline]
84    pub fn threshold(&self) -> f32 {
85        self.cached_threshold
86    }
87
88    /// Recompute cached threshold from heap state
89    #[inline]
90    fn update_threshold(&mut self) {
91        self.cached_threshold = if self.heap.len() >= self.k {
92            self.heap.peek().map(|e| e.score).unwrap_or(0.0)
93        } else {
94            0.0
95        };
96    }
97
98    /// Insert a document score. Returns true if inserted in top-k.
99    /// Caller must ensure each doc_id is inserted only once.
100    #[inline]
101    pub fn insert(&mut self, doc_id: DocId, score: f32) -> bool {
102        self.insert_with_ordinal(doc_id, score, 0)
103    }
104
105    /// Insert a document score with ordinal. Returns true if inserted in top-k.
106    /// Caller must ensure each doc_id is inserted only once.
107    #[inline]
108    pub fn insert_with_ordinal(&mut self, doc_id: DocId, score: f32, ordinal: u16) -> bool {
109        if self.heap.len() < self.k {
110            self.heap.push(HeapEntry {
111                doc_id,
112                score,
113                ordinal,
114            });
115            // Only recompute threshold when heap just became full
116            if self.heap.len() == self.k {
117                self.update_threshold();
118            }
119            true
120        } else if score > self.cached_threshold {
121            self.heap.push(HeapEntry {
122                doc_id,
123                score,
124                ordinal,
125            });
126            self.heap.pop(); // Remove lowest
127            self.update_threshold();
128            true
129        } else {
130            false
131        }
132    }
133
134    /// Check if a score could potentially enter top-k
135    #[inline]
136    pub fn would_enter(&self, score: f32) -> bool {
137        self.heap.len() < self.k || score > self.cached_threshold
138    }
139
140    /// Get number of documents collected so far
141    #[inline]
142    pub fn len(&self) -> usize {
143        self.heap.len()
144    }
145
146    /// Check if collector is empty
147    #[inline]
148    pub fn is_empty(&self) -> bool {
149        self.heap.is_empty()
150    }
151
152    /// Convert to sorted top-k results (descending by score)
153    pub fn into_sorted_results(self) -> Vec<(DocId, f32, u16)> {
154        let heap_vec = self.heap.into_vec();
155        let mut results: Vec<(DocId, f32, u16)> = Vec::with_capacity(heap_vec.len());
156        for e in heap_vec {
157            results.push((e.doc_id, e.score, e.ordinal));
158        }
159
160        // Sort by score descending, then doc_id ascending
161        results.sort_by(|a, b| b.1.total_cmp(&a.1).then(a.0.cmp(&b.0)));
162
163        results
164    }
165}
166
/// Search result from MaxScore execution
#[derive(Debug, Clone, Copy)]
pub struct ScoredDoc {
    /// Matching document id.
    pub doc_id: DocId,
    /// Score under which this document entered the top-k.
    pub score: f32,
    /// Ordinal for multi-valued fields (which vector in the field matched)
    pub ordinal: u16,
}
175
/// Unified Block-Max MaxScore executor for top-k retrieval
///
/// Works with both full-text (BM25) and sparse vector (dot product) queries
/// through the polymorphic `TermCursor`. Combines three optimizations:
/// 1. **MaxScore partitioning** (Turtle & Flood 1995): terms split into essential
///    (must check) and non-essential (only scored if candidate is promising)
/// 2. **Block-max pruning** (Ding & Suel 2011): skip blocks where per-block
///    upper bounds can't beat the current threshold
/// 3. **Conjunction optimization** (Lucene/Grand 2023): progressively intersect
///    essential terms as threshold rises, skipping docs that lack enough terms
pub struct MaxScoreExecutor<'a> {
    /// Term cursors driving the scoring loop (text or sparse).
    cursors: Vec<TermCursor<'a>>,
    /// `prefix_sums[i]` is used as an upper bound on the combined score of
    /// cursors `0..=i` (the non-essential prefix); presumably cumulative
    /// `max_score` values — confirm against the constructor /
    /// `find_partition` (not visible in this chunk).
    prefix_sums: Vec<f32>,
    /// Top-k accumulator; its threshold drives all pruning decisions.
    collector: ScoreCollector,
    /// Multiplier applied to the collector threshold before pruning.
    /// NOTE(review): values > 1.0 appear to prune more aggressively
    /// (approximate top-k) — confirm intended range at call sites.
    heap_factor: f32,
    /// Optional per-doc filter, applied after block-max pruning and
    /// before scoring.
    predicate: Option<super::DocPredicate<'a>>,
}
193
/// Unified term cursor for Block-Max MaxScore execution.
///
/// All per-position decode buffers (`doc_ids`, `scores`, `ordinals`) live in
/// the struct directly and are filled by `ensure_block_loaded`.
///
/// Skip-list metadata is **not** materialized — it is read lazily from the
/// underlying source (`BlockPostingList` for text, `SparseIndex` for sparse),
/// both backed by zero-copy mmap'd `OwnedBytes`.
pub(crate) struct TermCursor<'a> {
    /// Upper bound on this term's score for any document
    /// (BM25 upper bound for text, |query_weight| × global max weight for sparse).
    pub max_score: f32,
    /// Total number of blocks in this posting list.
    num_blocks: usize,
    // ── Per-position state (filled by ensure_block_loaded) ──────────
    /// Index of the block the cursor currently sits on.
    block_idx: usize,
    /// Decoded doc ids of the current block.
    doc_ids: Vec<u32>,
    /// Per-posting scores for the current block (precomputed BM25 for text,
    /// query_weight × stored weight for sparse).
    scores: Vec<f32>,
    /// Per-posting ordinals (sparse only; stays empty for text cursors).
    ordinals: Vec<u16>,
    /// Position within the decoded block buffers.
    pos: usize,
    /// Whether the decode buffers currently hold `block_idx`'s data.
    block_loaded: bool,
    /// True once the cursor has run past the last block; `doc()` then
    /// reports the `u32::MAX` sentinel.
    exhausted: bool,
    // ── Lazy ordinal decode (sparse only) ───────────────────────────
    /// When true, ordinal decode is deferred until ordinal_mut() is called.
    /// Set to true for MaxScoreExecutor cursors (most blocks never need ordinals).
    lazy_ordinals: bool,
    /// Whether ordinals have been decoded for the current block.
    ordinals_loaded: bool,
    /// Stored sparse block for deferred ordinal decode (cheap Arc clone of mmap data).
    current_sparse_block: Option<crate::structures::SparseBlock>,
    // ── Block decode + skip access source ───────────────────────────
    /// Text vs sparse backing data and their scoring parameters.
    variant: CursorVariant<'a>,
}
224
enum CursorVariant<'a> {
    /// Full-text BM25 — in-memory BlockPostingList (skip list + block data)
    Text {
        /// Posting list providing both skip metadata and encoded blocks.
        list: crate::structures::BlockPostingList,
        /// Inverse document frequency for this term; scales every score.
        idf: f32,
        /// Precomputed: BM25_B / max(avg_field_len, 1.0)
        b_over_avgfl: f32,
        /// Precomputed: 1.0 - BM25_B
        one_minus_b: f32,
        tfs: Vec<u32>, // temp decode buffer, converted to scores
    },
    /// Sparse vector — mmap'd SparseIndex (skip entries + block data)
    Sparse {
        /// Borrowed mmap-backed index; skip entries and blocks are read on demand.
        si: &'a crate::segment::SparseIndex,
        /// Query-side weight for this dimension; multiplied into decoded weights.
        query_weight: f32,
        /// Index of this term's first entry in the shared skip table.
        skip_start: usize,
        /// Base offset handed to `load_block_direct` for block reads.
        block_data_offset: u64,
    },
}
244
245// ── TermCursor async/sync macros ──────────────────────────────────────────
246//
247// Parameterised on:
248//   $load_block_fn – load_block_direct | load_block_direct_sync  (sparse I/O)
249//   $ensure_fn     – ensure_block_loaded | ensure_block_loaded_sync
250//   $($aw)*        – .await  (present for async, absent for sync)
251
// Decodes the cursor's current block into the struct buffers if it is not
// already loaded. Evaluates to Ok(true) when a block is available (cursor
// positioned at its first posting) and Ok(false) once the list is exhausted.
macro_rules! cursor_ensure_block {
    ($self:ident, $load_block_fn:ident, $($aw:tt)*) => {{
        // Idempotent fast path: already decoded, or nothing left to decode.
        if $self.exhausted || $self.block_loaded {
            return Ok(!$self.exhausted);
        }
        match &mut $self.variant {
            CursorVariant::Text {
                list,
                idf,
                b_over_avgfl,
                one_minus_b,
                tfs,
            } => {
                // decode_block_into returns false when block_idx is past the end.
                if list.decode_block_into($self.block_idx, &mut $self.doc_ids, tfs) {
                    // Copy the scoring constants out so the closure-free loop
                    // below doesn't re-deref through the enum each iteration.
                    let idf_val = *idf;
                    let b_avg = *b_over_avgfl;
                    let one_b = *one_minus_b;
                    $self.scores.clear();
                    $self.scores.reserve(tfs.len());
                    // Precomputed BM25: length_norm = one_minus_b + b_over_avgfl * tf
                    // (tf is used as both term frequency and doc length — a known approx)
                    for &tf in tfs.iter() {
                        let tf = tf as f32;
                        let length_norm = one_b + b_avg * tf;
                        let tf_norm = (tf * (super::BM25_K1 + 1.0))
                            / (tf + super::BM25_K1 * length_norm);
                        $self.scores.push(idf_val * tf_norm);
                    }
                    $self.pos = 0;
                    $self.block_loaded = true;
                    Ok(true)
                } else {
                    $self.exhausted = true;
                    Ok(false)
                }
            }
            CursorVariant::Sparse {
                si,
                query_weight,
                skip_start,
                block_data_offset,
                ..
            } => {
                // Sparse blocks come from mmap'd storage; the load may be
                // async ($aw = .await) or sync depending on the caller.
                let block = si
                    .$load_block_fn(*skip_start, *block_data_offset, $self.block_idx)
                    $($aw)* ?;
                match block {
                    Some(b) => {
                        b.decode_doc_ids_into(&mut $self.doc_ids);
                        b.decode_scored_weights_into(*query_weight, &mut $self.scores);
                        if $self.lazy_ordinals {
                            // Defer ordinal decode until ordinal_mut() is called.
                            // Stores cheap Arc-backed mmap slice, no copy.
                            $self.current_sparse_block = Some(b);
                            $self.ordinals_loaded = false;
                        } else {
                            b.decode_ordinals_into(&mut $self.ordinals);
                            $self.ordinals_loaded = true;
                            $self.current_sparse_block = None;
                        }
                        $self.pos = 0;
                        $self.block_loaded = true;
                        Ok(true)
                    }
                    None => {
                        // No block at this index: the list is exhausted.
                        $self.exhausted = true;
                        Ok(false)
                    }
                }
            }
        }
    }};
}
325
// Advances the cursor one posting. Evaluates to Ok(next doc id), or
// Ok(u32::MAX) once the posting list is exhausted.
macro_rules! cursor_advance {
    ($self:ident, $ensure_fn:ident, $($aw:tt)*) => {{
        // Already past the end: keep reporting the sentinel.
        if $self.exhausted {
            return Ok(u32::MAX);
        }
        // Make sure the current block is decoded before stepping `pos`;
        // the load itself may discover exhaustion (no block at block_idx).
        $self.$ensure_fn() $($aw)* ?;
        if $self.exhausted {
            return Ok(u32::MAX);
        }
        Ok($self.advance_pos())
    }};
}
338
// Positions the cursor on the first posting with doc id >= $target.
// Evaluates to Ok(that doc id), or Ok(u32::MAX) when no such posting exists.
macro_rules! cursor_seek {
    ($self:ident, $ensure_fn:ident, $target:expr, $($aw:tt)*) => {{
        // seek_prepare resolves the I/O-free fast paths (exhausted, target
        // inside the currently loaded block) and otherwise selects the
        // candidate block; Some(doc) means the seek is already done.
        if let Some(doc) = $self.seek_prepare($target) {
            return Ok(doc);
        }
        // Decode the candidate block and scan within it. seek_finish
        // returns true when target lies beyond this block, in which case
        // the next block (already selected) must be loaded too.
        $self.$ensure_fn() $($aw)* ?;
        if $self.seek_finish($target) {
            $self.$ensure_fn() $($aw)* ?;
        }
        Ok($self.doc())
    }};
}
351
impl<'a> TermCursor<'a> {
    /// Create a full-text BM25 cursor (lazy — no blocks decoded yet).
    ///
    /// `idf` scales all scores; `avg_field_len` feeds the BM25 length
    /// normalization (clamped to >= 1.0 to avoid division blow-up).
    pub fn text(
        posting_list: crate::structures::BlockPostingList,
        idf: f32,
        avg_field_len: f32,
    ) -> Self {
        // Global upper bound: the BM25 score a posting with the list's
        // maximum tf could reach (max(1.0) guards a degenerate max_tf of 0).
        let max_tf = posting_list.max_tf() as f32;
        let max_score = super::bm25_upper_bound(max_tf.max(1.0), idf);
        let num_blocks = posting_list.num_blocks();
        let safe_avg = avg_field_len.max(1.0);
        Self {
            max_score,
            num_blocks,
            block_idx: 0,
            doc_ids: Vec::with_capacity(128),
            scores: Vec::with_capacity(128),
            ordinals: Vec::new(),
            pos: 0,
            block_loaded: false,
            exhausted: num_blocks == 0,
            lazy_ordinals: false,
            ordinals_loaded: true, // text cursors never have ordinals
            current_sparse_block: None,
            variant: CursorVariant::Text {
                list: posting_list,
                idf,
                b_over_avgfl: super::BM25_B / safe_avg,
                one_minus_b: 1.0 - super::BM25_B,
                tfs: Vec::with_capacity(128),
            },
        }
    }

    /// Create a sparse vector cursor with lazy block loading.
    /// Skip entries are **not** copied — they are read from `SparseIndex` mmap on demand.
    pub fn sparse(
        si: &'a crate::segment::SparseIndex,
        query_weight: f32,
        skip_start: usize,
        skip_count: usize,
        global_max_weight: f32,
        block_data_offset: u64,
    ) -> Self {
        Self {
            // |query_weight| × largest stored weight: score upper bound
            // regardless of the query weight's sign.
            max_score: query_weight.abs() * global_max_weight,
            num_blocks: skip_count,
            block_idx: 0,
            doc_ids: Vec::with_capacity(256),
            scores: Vec::with_capacity(256),
            ordinals: Vec::with_capacity(256),
            pos: 0,
            block_loaded: false,
            exhausted: skip_count == 0,
            lazy_ordinals: false,
            ordinals_loaded: true,
            current_sparse_block: None,
            variant: CursorVariant::Sparse {
                si,
                query_weight,
                skip_start,
                block_data_offset,
            },
        }
    }

    // ── Skip-entry access (lazy, zero-copy for sparse) ──────────────────

    /// First doc id of block `idx`, read from skip metadata without
    /// decoding the block. Returns `u32::MAX` when the text skip lookup
    /// has no entry for `idx`.
    #[inline]
    fn block_first_doc(&self, idx: usize) -> DocId {
        match &self.variant {
            CursorVariant::Text { list, .. } => list.block_first_doc(idx).unwrap_or(u32::MAX),
            CursorVariant::Sparse { si, skip_start, .. } => {
                si.read_skip_entry(*skip_start + idx).first_doc
            }
        }
    }

    /// Last doc id of block `idx`, read from skip metadata without
    /// decoding the block. Returns 0 when the text skip lookup has no
    /// entry for `idx`.
    #[inline]
    fn block_last_doc(&self, idx: usize) -> DocId {
        match &self.variant {
            CursorVariant::Text { list, .. } => list.block_last_doc(idx).unwrap_or(0),
            CursorVariant::Sparse { si, skip_start, .. } => {
                si.read_skip_entry(*skip_start + idx).last_doc
            }
        }
    }

    // ── Read-only accessors ─────────────────────────────────────────────

    /// Current doc id: `u32::MAX` when exhausted; the posting under `pos`
    /// when a block is decoded; otherwise the current block's first doc id
    /// from skip metadata (no decode needed).
    #[inline]
    pub fn doc(&self) -> DocId {
        if self.exhausted {
            return u32::MAX;
        }
        if self.block_loaded {
            debug_assert!(self.pos < self.doc_ids.len());
            // SAFETY: pos < doc_ids.len() is maintained by advance_pos/ensure_block_loaded.
            unsafe { *self.doc_ids.get_unchecked(self.pos) }
        } else {
            self.block_first_doc(self.block_idx)
        }
    }

    /// Ordinal of the current posting; 0 when no block is loaded or the
    /// cursor has no ordinals (text, or sparse with deferred decode).
    #[inline]
    pub fn ordinal(&self) -> u16 {
        if !self.block_loaded || self.ordinals.is_empty() {
            return 0;
        }
        debug_assert!(self.pos < self.ordinals.len());
        // SAFETY: pos < ordinals.len() is maintained by advance_pos/ensure_block_loaded.
        unsafe { *self.ordinals.get_unchecked(self.pos) }
    }

    /// Lazily-decoded ordinal accessor for MaxScore executor.
    ///
    /// When `lazy_ordinals=true`, ordinals are not decoded during block loading.
    /// This method triggers the deferred decode on first access, amortized over
    /// the block. Subsequent calls within the same block are free.
    #[inline]
    pub fn ordinal_mut(&mut self) -> u16 {
        if !self.block_loaded {
            return 0;
        }
        if !self.ordinals_loaded {
            // Deferred decode of the retained sparse block (set by
            // ensure_block_loaded when lazy_ordinals is enabled).
            if let Some(ref block) = self.current_sparse_block {
                block.decode_ordinals_into(&mut self.ordinals);
            }
            self.ordinals_loaded = true;
        }
        if self.ordinals.is_empty() {
            return 0;
        }
        debug_assert!(self.pos < self.ordinals.len());
        // SAFETY: pos < ordinals.len() — same invariant as ordinal().
        unsafe { *self.ordinals.get_unchecked(self.pos) }
    }

    /// Score of the current posting; 0.0 when no block is loaded.
    #[inline]
    pub fn score(&self) -> f32 {
        if !self.block_loaded {
            return 0.0;
        }
        debug_assert!(self.pos < self.scores.len());
        // SAFETY: pos < scores.len() is maintained by advance_pos/ensure_block_loaded.
        unsafe { *self.scores.get_unchecked(self.pos) }
    }

    /// Per-block score upper bound for block-max pruning, read from skip
    /// metadata — no block decode required.
    #[inline]
    pub fn current_block_max_score(&self) -> f32 {
        if self.exhausted {
            return 0.0;
        }
        match &self.variant {
            CursorVariant::Text { list, idf, .. } => {
                let block_max_tf = list.block_max_tf(self.block_idx).unwrap_or(0) as f32;
                super::bm25_upper_bound(block_max_tf.max(1.0), *idf)
            }
            CursorVariant::Sparse {
                si,
                query_weight,
                skip_start,
                ..
            } => query_weight.abs() * si.read_skip_entry(*skip_start + self.block_idx).max_weight,
        }
    }

    // ── Block navigation ────────────────────────────────────────────────

    /// Abandon the current block and move to the next one (left undecoded).
    /// Returns the next block's first doc id, or `u32::MAX` once exhausted.
    pub fn skip_to_next_block(&mut self) -> DocId {
        if self.exhausted {
            return u32::MAX;
        }
        self.block_idx += 1;
        self.block_loaded = false;
        if self.block_idx >= self.num_blocks {
            self.exhausted = true;
            return u32::MAX;
        }
        self.block_first_doc(self.block_idx)
    }

    /// Step to the next posting, rolling over to the next (undecoded)
    /// block when the current one is consumed. Returns the new current
    /// doc id, or `u32::MAX` once exhausted.
    #[inline]
    fn advance_pos(&mut self) -> DocId {
        self.pos += 1;
        if self.pos >= self.doc_ids.len() {
            self.block_idx += 1;
            self.block_loaded = false;
            if self.block_idx >= self.num_blocks {
                self.exhausted = true;
                return u32::MAX;
            }
        }
        self.doc()
    }

    // ── Block loading / advance / seek ─────────────────────────────────
    //
    // Macros parameterised on sparse I/O method + optional .await to
    // stamp out both async and sync variants without duplication.

    /// Async variant: decode the current block if needed (sparse I/O may await).
    pub async fn ensure_block_loaded(&mut self) -> crate::Result<bool> {
        cursor_ensure_block!(self, load_block_direct, .await)
    }

    /// Sync variant of [`Self::ensure_block_loaded`].
    pub fn ensure_block_loaded_sync(&mut self) -> crate::Result<bool> {
        cursor_ensure_block!(self, load_block_direct_sync,)
    }

    /// Async variant: advance one posting; `u32::MAX` once exhausted.
    pub async fn advance(&mut self) -> crate::Result<DocId> {
        cursor_advance!(self, ensure_block_loaded, .await)
    }

    /// Sync variant of [`Self::advance`].
    pub fn advance_sync(&mut self) -> crate::Result<DocId> {
        cursor_advance!(self, ensure_block_loaded_sync,)
    }

    /// Async variant: position on the first posting with doc id >= `target`.
    pub async fn seek(&mut self, target: DocId) -> crate::Result<DocId> {
        cursor_seek!(self, ensure_block_loaded, target, .await)
    }

    /// Sync variant of [`Self::seek`].
    pub fn seek_sync(&mut self, target: DocId) -> crate::Result<DocId> {
        cursor_seek!(self, ensure_block_loaded_sync, target,)
    }

    /// First phase of seek: resolve the I/O-free fast paths.
    ///
    /// Returns `Some(doc)` when the seek is fully resolved (exhausted, or
    /// `target` found within the currently loaded block); `None` when the
    /// caller must decode `block_idx` and run `seek_finish`.
    fn seek_prepare(&mut self, target: DocId) -> Option<DocId> {
        if self.exhausted {
            return Some(u32::MAX);
        }

        // Fast path: target is within the currently loaded block
        if self.block_loaded
            && let Some(&last) = self.doc_ids.last()
        {
            if last >= target && self.doc_ids[self.pos] < target {
                // SIMD scan of the remaining postings for the first id >= target.
                let remaining = &self.doc_ids[self.pos..];
                self.pos += crate::structures::simd::find_first_ge_u32(remaining, target);
                if self.pos >= self.doc_ids.len() {
                    self.block_idx += 1;
                    self.block_loaded = false;
                    if self.block_idx >= self.num_blocks {
                        self.exhausted = true;
                        return Some(u32::MAX);
                    }
                }
                return Some(self.doc());
            }
            if self.doc_ids[self.pos] >= target {
                // Already at or past target — nothing to do.
                return Some(self.doc());
            }
        }

        // Seek to the block containing target
        let lo = match &self.variant {
            // Text: SIMD-accelerated 2-level seek (L1 + L0)
            CursorVariant::Text { list, .. } => match list.seek_block(target, self.block_idx) {
                Some(idx) => idx,
                None => {
                    self.exhausted = true;
                    return Some(u32::MAX);
                }
            },
            // Sparse: binary search on skip entries (lazy mmap reads)
            CursorVariant::Sparse { .. } => {
                // Lower-bound search: first block whose last_doc >= target.
                let mut lo = self.block_idx;
                let mut hi = self.num_blocks;
                while lo < hi {
                    let mid = lo + (hi - lo) / 2;
                    if self.block_last_doc(mid) < target {
                        lo = mid + 1;
                    } else {
                        hi = mid;
                    }
                }
                lo
            }
        };
        if lo >= self.num_blocks {
            self.exhausted = true;
            return Some(u32::MAX);
        }
        if lo != self.block_idx || !self.block_loaded {
            // Moved to a different (or not-yet-decoded) block: force reload.
            self.block_idx = lo;
            self.block_loaded = false;
        }
        None
    }

    /// Second phase of seek: position within the freshly decoded block.
    /// Returns true when `target` lies beyond this block and the caller
    /// must load the next block (already selected here); false otherwise.
    #[inline]
    fn seek_finish(&mut self, target: DocId) -> bool {
        if self.exhausted {
            return false;
        }
        self.pos = crate::structures::simd::find_first_ge_u32(&self.doc_ids, target);
        if self.pos >= self.doc_ids.len() {
            self.block_idx += 1;
            self.block_loaded = false;
            if self.block_idx >= self.num_blocks {
                self.exhausted = true;
                return false;
            }
            return true;
        }
        false
    }
}
657
658/// Macro to stamp out the Block-Max MaxScore loop for both async and sync paths.
659///
660/// `$ensure`, `$advance`, `$seek` are cursor method idents (async or _sync variants).
661/// `$($aw:tt)*` captures `.await` for async or nothing for sync.
662macro_rules! bms_execute_loop {
663    ($self:ident, $ensure:ident, $advance:ident, $seek:ident, $($aw:tt)*) => {{
664        let n = $self.cursors.len();
665
666        // Load first block for each cursor (ensures doc() returns real values)
667        for cursor in &mut $self.cursors {
668            cursor.$ensure() $($aw)* ?;
669        }
670
671        let mut docs_scored = 0u64;
672        let mut docs_skipped = 0u64;
673        let mut blocks_skipped = 0u64;
674        let mut conjunction_skipped = 0u64;
675        let mut ordinal_scores: Vec<(u16, f32)> = Vec::with_capacity(n * 2);
676
677        loop {
678            let partition = $self.find_partition();
679            if partition >= n {
680                break;
681            }
682
683            // Find minimum doc_id across essential cursors and collect
684            // which cursors are at min_doc (avoids redundant re-checks in
685            // conjunction, block-max, predicate, and scoring passes).
686            let mut min_doc = u32::MAX;
687            let mut at_min_mask = 0u64; // bitset of cursor indices at min_doc
688            for i in partition..n {
689                let doc = $self.cursors[i].doc();
690                match doc.cmp(&min_doc) {
691                    std::cmp::Ordering::Less => {
692                        min_doc = doc;
693                        at_min_mask = 1u64 << (i as u32);
694                    }
695                    std::cmp::Ordering::Equal => {
696                        at_min_mask |= 1u64 << (i as u32);
697                    }
698                    _ => {}
699                }
700            }
701            if min_doc == u32::MAX {
702                break;
703            }
704
705            let non_essential_upper = if partition > 0 {
706                $self.prefix_sums[partition - 1]
707            } else {
708                0.0
709            };
710            // Small epsilon to guard against FP rounding in score accumulation.
711            // Without this, a document whose true score equals the threshold can
712            // be incorrectly pruned due to rounding in the heap_factor multiply
713            // or in the prefix_sum additions.
714            let adjusted_threshold = $self.collector.threshold() * $self.heap_factor - 1e-6;
715
716            // --- Conjunction optimization ---
717            if $self.collector.len() >= $self.collector.k {
718                let mut present_upper: f32 = 0.0;
719                let mut mask = at_min_mask;
720                while mask != 0 {
721                    let i = mask.trailing_zeros() as usize;
722                    present_upper += $self.cursors[i].max_score;
723                    mask &= mask - 1;
724                }
725
726                if present_upper + non_essential_upper <= adjusted_threshold {
727                    let mut mask = at_min_mask;
728                    while mask != 0 {
729                        let i = mask.trailing_zeros() as usize;
730                        $self.cursors[i].$ensure() $($aw)* ?;
731                        $self.cursors[i].$advance() $($aw)* ?;
732                        mask &= mask - 1;
733                    }
734                    conjunction_skipped += 1;
735                    continue;
736                }
737            }
738
739            // --- Block-max pruning ---
740            if $self.collector.len() >= $self.collector.k {
741                let mut block_max_sum: f32 = 0.0;
742                let mut mask = at_min_mask;
743                while mask != 0 {
744                    let i = mask.trailing_zeros() as usize;
745                    block_max_sum += $self.cursors[i].current_block_max_score();
746                    mask &= mask - 1;
747                }
748
749                if block_max_sum + non_essential_upper <= adjusted_threshold {
750                    let mut mask = at_min_mask;
751                    while mask != 0 {
752                        let i = mask.trailing_zeros() as usize;
753                        $self.cursors[i].skip_to_next_block();
754                        $self.cursors[i].$ensure() $($aw)* ?;
755                        mask &= mask - 1;
756                    }
757                    blocks_skipped += 1;
758                    continue;
759                }
760            }
761
762            // --- Predicate filter (after block-max, before scoring) ---
763            if let Some(ref pred) = $self.predicate {
764                if !pred(min_doc) {
765                    let mut mask = at_min_mask;
766                    while mask != 0 {
767                        let i = mask.trailing_zeros() as usize;
768                        $self.cursors[i].$ensure() $($aw)* ?;
769                        $self.cursors[i].$advance() $($aw)* ?;
770                        mask &= mask - 1;
771                    }
772                    continue;
773                }
774            }
775
776            // --- Score essential cursors ---
777            ordinal_scores.clear();
778            {
779                let mut mask = at_min_mask;
780                while mask != 0 {
781                    let i = mask.trailing_zeros() as usize;
782                    $self.cursors[i].$ensure() $($aw)* ?;
783                    while $self.cursors[i].doc() == min_doc {
784                        let ord = $self.cursors[i].ordinal_mut();
785                        let sc = $self.cursors[i].score();
786                        ordinal_scores.push((ord, sc));
787                        $self.cursors[i].$advance() $($aw)* ?;
788                    }
789                    mask &= mask - 1;
790                }
791            }
792
793            let essential_total: f32 = ordinal_scores.iter().map(|(_, s)| *s).sum();
794            if $self.collector.len() >= $self.collector.k
795                && essential_total + non_essential_upper <= adjusted_threshold
796            {
797                docs_skipped += 1;
798                continue;
799            }
800
801            // --- Score non-essential cursors (highest max_score first for early exit) ---
802            let mut running_total = essential_total;
803            for i in (0..partition).rev() {
804                if $self.collector.len() >= $self.collector.k
805                    && running_total + $self.prefix_sums[i] <= adjusted_threshold
806                {
807                    break;
808                }
809
810                let doc = $self.cursors[i].$seek(min_doc) $($aw)* ?;
811                if doc == min_doc {
812                    while $self.cursors[i].doc() == min_doc {
813                        let s = $self.cursors[i].score();
814                        running_total += s;
815                        let ord = $self.cursors[i].ordinal_mut();
816                        ordinal_scores.push((ord, s));
817                        $self.cursors[i].$advance() $($aw)* ?;
818                    }
819                }
820            }
821
822            // --- Group by ordinal and insert ---
823            // Fast path: single entry (common for single-valued fields) — skip sort + grouping
824            if ordinal_scores.len() == 1 {
825                let (ord, score) = ordinal_scores[0];
826                if $self.collector.insert_with_ordinal(min_doc, score, ord) {
827                    docs_scored += 1;
828                } else {
829                    docs_skipped += 1;
830                }
831            } else if !ordinal_scores.is_empty() {
832                if ordinal_scores.len() > 2 {
833                    ordinal_scores.sort_unstable_by_key(|(ord, _)| *ord);
834                } else if ordinal_scores.len() == 2 && ordinal_scores[0].0 > ordinal_scores[1].0 {
835                    ordinal_scores.swap(0, 1);
836                }
837                let mut j = 0;
838                while j < ordinal_scores.len() {
839                    let current_ord = ordinal_scores[j].0;
840                    let mut score = 0.0f32;
841                    while j < ordinal_scores.len() && ordinal_scores[j].0 == current_ord {
842                        score += ordinal_scores[j].1;
843                        j += 1;
844                    }
845                    if $self
846                        .collector
847                        .insert_with_ordinal(min_doc, score, current_ord)
848                    {
849                        docs_scored += 1;
850                    } else {
851                        docs_skipped += 1;
852                    }
853                }
854            }
855        }
856
857        let results: Vec<ScoredDoc> = $self
858            .collector
859            .into_sorted_results()
860            .into_iter()
861            .map(|(doc_id, score, ordinal)| ScoredDoc {
862                doc_id,
863                score,
864                ordinal,
865            })
866            .collect();
867
868        debug!(
869            "MaxScoreExecutor: scored={}, skipped={}, blocks_skipped={}, conjunction_skipped={}, returned={}, top_score={:.4}",
870            docs_scored,
871            docs_skipped,
872            blocks_skipped,
873            conjunction_skipped,
874            results.len(),
875            results.first().map(|r| r.score).unwrap_or(0.0)
876        );
877
878        Ok(results)
879    }};
880}
881
882impl<'a> MaxScoreExecutor<'a> {
883    /// Create a new executor from pre-built cursors.
884    ///
885    /// Cursors are sorted by max_score ascending (non-essential first) and
886    /// prefix sums are computed for the MaxScore partitioning.
887    pub(crate) fn new(mut cursors: Vec<TermCursor<'a>>, k: usize, heap_factor: f32) -> Self {
888        // Enable lazy ordinal decode — ordinals are only decoded when a doc
889        // actually reaches the scoring phase (saves ~100ns per skipped block).
890        for c in &mut cursors {
891            c.lazy_ordinals = true;
892        }
893
894        // Sort by max_score ascending (non-essential first)
895        cursors.sort_by(|a, b| {
896            a.max_score
897                .partial_cmp(&b.max_score)
898                .unwrap_or(Ordering::Equal)
899        });
900
901        let mut prefix_sums = Vec::with_capacity(cursors.len());
902        let mut cumsum = 0.0f32;
903        for c in &cursors {
904            cumsum += c.max_score;
905            prefix_sums.push(cumsum);
906        }
907
908        debug!(
909            "Creating MaxScoreExecutor: num_cursors={}, k={}, total_upper={:.4}, heap_factor={:.2}",
910            cursors.len(),
911            k,
912            cumsum,
913            heap_factor
914        );
915
916        Self {
917            cursors,
918            prefix_sums,
919            collector: ScoreCollector::new(k),
920            heap_factor: heap_factor.clamp(0.0, 1.0),
921            predicate: None,
922        }
923    }
924
925    /// Create an executor for sparse vector queries.
926    ///
927    /// Builds `TermCursor::Sparse` for each matched dimension.
928    pub fn sparse(
929        sparse_index: &'a crate::segment::SparseIndex,
930        query_terms: Vec<(u32, f32)>,
931        k: usize,
932        heap_factor: f32,
933    ) -> Self {
934        let cursors: Vec<TermCursor<'a>> = query_terms
935            .iter()
936            .filter_map(|&(dim_id, qw)| {
937                let (skip_start, skip_count, global_max, block_data_offset) =
938                    sparse_index.get_skip_range_full(dim_id)?;
939                Some(TermCursor::sparse(
940                    sparse_index,
941                    qw,
942                    skip_start,
943                    skip_count,
944                    global_max,
945                    block_data_offset,
946                ))
947            })
948            .collect();
949        Self::new(cursors, k, heap_factor)
950    }
951
952    /// Create an executor for full-text BM25 queries.
953    ///
954    /// Builds `TermCursor::Text` for each posting list.
955    pub fn text(
956        posting_lists: Vec<(crate::structures::BlockPostingList, f32)>,
957        avg_field_len: f32,
958        k: usize,
959    ) -> Self {
960        let cursors: Vec<TermCursor<'a>> = posting_lists
961            .into_iter()
962            .map(|(pl, idf)| TermCursor::text(pl, idf, avg_field_len))
963            .collect();
964        Self::new(cursors, k, 1.0)
965    }
966
967    #[inline]
968    fn find_partition(&self) -> usize {
969        let threshold = self.collector.threshold() * self.heap_factor;
970        self.prefix_sums.partition_point(|&sum| sum <= threshold)
971    }
972
973    /// Attach a per-doc predicate filter to this executor.
974    ///
975    /// Docs failing the predicate are skipped after block-max pruning but
976    /// before scoring. The predicate does not affect thresholds or block-max
977    /// comparisons — the heap stores pure sparse/text scores.
978    pub fn with_predicate(mut self, predicate: super::DocPredicate<'a>) -> Self {
979        self.predicate = Some(predicate);
980        self
981    }
982
983    /// Execute Block-Max MaxScore and return top-k results (async).
984    pub async fn execute(mut self) -> crate::Result<Vec<ScoredDoc>> {
985        if self.cursors.is_empty() {
986            return Ok(Vec::new());
987        }
988        bms_execute_loop!(self, ensure_block_loaded, advance, seek, .await)
989    }
990
991    /// Synchronous execution — works when all cursors are text or mmap-backed sparse.
992    pub fn execute_sync(mut self) -> crate::Result<Vec<ScoredDoc>> {
993        if self.cursors.is_empty() {
994            return Ok(Vec::new());
995        }
996        bms_execute_loop!(self, ensure_block_loaded_sync, advance_sync, seek_sync,)
997    }
998}
999
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_score_collector_basic() {
        let mut collector = ScoreCollector::new(3);

        // Fill to capacity; threshold is the lowest retained score.
        for (doc, score) in [(1, 1.0), (2, 2.0), (3, 3.0)] {
            collector.insert(doc, score);
        }
        assert_eq!(collector.threshold(), 1.0);

        // A fourth, higher-scoring doc evicts the minimum and raises the bar.
        collector.insert(4, 4.0);
        assert_eq!(collector.threshold(), 2.0);

        // Results come back sorted by descending score.
        let results = collector.into_sorted_results();
        assert_eq!(results.len(), 3);
        let ids: Vec<_> = results.iter().map(|r| r.0).collect();
        assert_eq!(ids, vec![4, 3, 2]);
    }

    #[test]
    fn test_score_collector_threshold() {
        let mut collector = ScoreCollector::new(2);

        collector.insert(1, 5.0);
        collector.insert(2, 3.0);
        assert_eq!(collector.threshold(), 3.0);

        // Below the current minimum: rejected on both paths.
        assert!(!collector.would_enter(2.0));
        assert!(!collector.insert(3, 2.0));

        // Above the minimum: accepted, and the threshold moves up.
        assert!(collector.would_enter(4.0));
        assert!(collector.insert(4, 4.0));
        assert_eq!(collector.threshold(), 4.0);
    }

    #[test]
    fn test_heap_entry_ordering() {
        let mut heap = BinaryHeap::new();
        for (doc_id, score) in [(1, 3.0), (2, 1.0), (3, 2.0)] {
            heap.push(HeapEntry {
                doc_id,
                score,
                ordinal: 0,
            });
        }

        // Min-heap semantics: entries pop in ascending score order.
        let popped: Vec<f32> = std::iter::from_fn(|| heap.pop().map(|e| e.score)).collect();
        assert_eq!(popped, vec![1.0, 2.0, 3.0]);
    }
}