
hermes_core/query/scoring.rs

//! Shared scoring abstractions for text and sparse vector search
//!
//! Provides common types and executors for efficient top-k retrieval:
//! - `TermCursor`: Unified cursor for both BM25 text and sparse vector posting lists
//! - `ScoreCollector`: Efficient min-heap for maintaining top-k results
//! - `MaxScoreExecutor`: Unified Block-Max MaxScore with conjunction optimization
//! - `ScoredDoc`: Result type with doc_id, score, and ordinal
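//!
//! Illustrative flow for a BM25 query (a sketch, not a doctest: how the
//! `posting_lists` and `avg_field_len` inputs are obtained from the segment
//! is assumed here, not shown):
//!
//! ```ignore
//! // One (BlockPostingList, idf) pair per query term, top 10 hits.
//! let executor = MaxScoreExecutor::text(posting_lists, avg_field_len, 10);
//! let hits: Vec<ScoredDoc> = executor.execute_sync()?;
//! for hit in &hits {
//!     println!("doc={} score={:.4} ord={}", hit.doc_id, hit.score, hit.ordinal);
//! }
//! ```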

use std::cmp::Ordering;
use std::collections::BinaryHeap;

use log::{debug, warn};

use crate::DocId;

/// Entry for top-k min-heap
#[derive(Clone, Copy)]
pub struct HeapEntry {
    pub doc_id: DocId,
    pub score: f32,
    pub ordinal: u16,
}

impl PartialEq for HeapEntry {
    fn eq(&self, other: &Self) -> bool {
        self.score == other.score && self.doc_id == other.doc_id
    }
}

impl Eq for HeapEntry {}

impl Ord for HeapEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        // Min-heap: lower scores come first (to be evicted).
        // total_cmp is branchless and NaN-safe (a total order over f32).
        other
            .score
            .total_cmp(&self.score)
            .then(self.doc_id.cmp(&other.doc_id))
    }
}

impl PartialOrd for HeapEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

/// Efficient top-k collector using min-heap (internal, scoring-layer)
///
/// Maintains the k highest-scoring documents using a min-heap where the
/// lowest score is at the top for O(1) threshold lookup and O(log k) eviction.
/// No deduplication — caller must ensure each doc_id is inserted only once.
///
/// This is intentionally separate from `TopKCollector` in `collector.rs`:
/// `ScoreCollector` is used inside `MaxScoreExecutor` where only `(doc_id,
/// score, ordinal)` tuples exist — no `Scorer` trait, no position tracking,
/// and the threshold must be inlined for tight block-max loops.
/// `TopKCollector` wraps a `Scorer` and drives the full `DocSet`/`Scorer`
/// protocol, collecting positions on demand.
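///
/// # Example
///
/// Illustrative usage (a sketch mirroring the unit tests below; the exact
/// import path is assumed):
///
/// ```ignore
/// let mut collector = ScoreCollector::new(2);
/// collector.insert(10, 1.5);
/// collector.insert(11, 3.0);
/// assert_eq!(collector.threshold(), 1.5); // heap full: the min score is the bar
/// collector.insert(12, 2.0); // beats 1.5, evicts doc 10
/// let top = collector.into_sorted_results();
/// assert_eq!(top, vec![(11, 3.0, 0), (12, 2.0, 0)]);
/// ```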
pub struct ScoreCollector {
    /// Min-heap of top-k entries (lowest score at top for eviction)
    heap: BinaryHeap<HeapEntry>,
    pub k: usize,
    /// Cached threshold: avoids repeated heap.peek() in hot loops.
    /// Updated only when the heap changes (insert/pop).
    cached_threshold: f32,
}

impl ScoreCollector {
    /// Create a new collector for top-k results
    pub fn new(k: usize) -> Self {
        // Cap capacity to avoid allocation overflow for very large k
        let capacity = k.saturating_add(1).min(1_000_000);
        Self {
            heap: BinaryHeap::with_capacity(capacity),
            k,
            cached_threshold: 0.0,
        }
    }

    /// Current score threshold: the score a candidate must exceed to enter
    /// the top-k once the heap is full (0.0 while it is not yet full)
    #[inline]
    pub fn threshold(&self) -> f32 {
        self.cached_threshold
    }

    /// Recompute cached threshold from heap state
    #[inline]
    fn update_threshold(&mut self) {
        self.cached_threshold = if self.heap.len() >= self.k {
            self.heap.peek().map(|e| e.score).unwrap_or(0.0)
        } else {
            0.0
        };
    }

    /// Insert a document score. Returns true if inserted in top-k.
    /// Caller must ensure each doc_id is inserted only once.
    #[inline]
    pub fn insert(&mut self, doc_id: DocId, score: f32) -> bool {
        self.insert_with_ordinal(doc_id, score, 0)
    }

    /// Insert a document score with ordinal. Returns true if inserted in top-k.
    /// Caller must ensure each doc_id is inserted only once.
    #[inline]
    pub fn insert_with_ordinal(&mut self, doc_id: DocId, score: f32, ordinal: u16) -> bool {
        if self.heap.len() < self.k {
            self.heap.push(HeapEntry {
                doc_id,
                score,
                ordinal,
            });
            // Only recompute threshold when heap just became full
            if self.heap.len() == self.k {
                self.update_threshold();
            }
            true
        } else if score > self.cached_threshold {
            self.heap.push(HeapEntry {
                doc_id,
                score,
                ordinal,
            });
            self.heap.pop(); // Remove lowest
            self.update_threshold();
            true
        } else {
            false
        }
    }

    /// Check if a score could potentially enter top-k
    #[inline]
    pub fn would_enter(&self, score: f32) -> bool {
        self.heap.len() < self.k || score > self.cached_threshold
    }

    /// Get number of documents collected so far
    #[inline]
    pub fn len(&self) -> usize {
        self.heap.len()
    }

    /// Check if collector is empty
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.heap.is_empty()
    }

    /// Convert to sorted top-k results (descending by score)
    pub fn into_sorted_results(self) -> Vec<(DocId, f32, u16)> {
        let mut results: Vec<(DocId, f32, u16)> = self
            .heap
            .into_vec()
            .into_iter()
            .map(|e| (e.doc_id, e.score, e.ordinal))
            .collect();

        // Sort by score descending, then doc_id ascending
        results.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then(a.0.cmp(&b.0)));

        results
    }
}

/// Search result from MaxScore execution
#[derive(Debug, Clone, Copy)]
pub struct ScoredDoc {
    pub doc_id: DocId,
    pub score: f32,
    /// Ordinal for multi-valued fields (which vector in the field matched)
    pub ordinal: u16,
}

/// Unified Block-Max MaxScore executor for top-k retrieval
///
/// Works with both full-text (BM25) and sparse vector (dot product) queries
/// through the polymorphic `TermCursor`. Combines three optimizations:
/// 1. **MaxScore partitioning** (Turtle & Flood 1995): terms split into essential
///    (must check) and non-essential (only scored if candidate is promising)
/// 2. **Block-max pruning** (Ding & Suel 2011): skip blocks where per-block
///    upper bounds can't beat the current threshold
/// 3. **Conjunction optimization** (Lucene/Grand 2023): progressively intersect
///    essential terms as threshold rises, skipping docs that lack enough terms
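///
/// Worked partitioning example: with cursor upper bounds sorted ascending as
/// `[0.5, 1.5, 2.0]`, the prefix sums are `[0.5, 2.0, 4.0]`. If the
/// heap-factor-adjusted threshold is `1.9`, `find_partition` returns `1`:
/// cursor 0 is non-essential (a doc matching only it tops out at 0.5 and can
/// never enter the top-k), while cursors 1 and 2 are essential and drive the
/// document-at-a-time loop.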
pub struct MaxScoreExecutor<'a> {
    cursors: Vec<TermCursor<'a>>,
    prefix_sums: Vec<f32>,
    collector: ScoreCollector,
    heap_factor: f32,
    predicate: Option<super::DocPredicate<'a>>,
}

/// Unified term cursor for Block-Max MaxScore execution.
///
/// All per-position decode buffers (`doc_ids`, `scores`, `ordinals`) live in
/// the struct directly and are filled by `ensure_block_loaded`.
///
/// Skip-list metadata is **not** materialized — it is read lazily from the
/// underlying source (`BlockPostingList` for text, `SparseIndex` for sparse),
/// both backed by zero-copy mmap'd `OwnedBytes`.
pub(crate) struct TermCursor<'a> {
    pub max_score: f32,
    num_blocks: usize,
    // ── Per-position state (filled by ensure_block_loaded) ──────────
    block_idx: usize,
    doc_ids: Vec<u32>,
    scores: Vec<f32>,
    ordinals: Vec<u16>,
    pos: usize,
    block_loaded: bool,
    exhausted: bool,
    // ── Lazy ordinal decode (sparse only) ───────────────────────────
    /// When true, ordinal decode is deferred until ordinal_mut() is called.
    /// Set to true for MaxScoreExecutor cursors (most blocks never need ordinals).
    lazy_ordinals: bool,
    /// Whether ordinals have been decoded for the current block.
    ordinals_loaded: bool,
    /// Stored sparse block for deferred ordinal decode (cheap Arc clone of mmap data).
    current_sparse_block: Option<crate::structures::SparseBlock>,
    // ── Block decode + skip access source ───────────────────────────
    variant: CursorVariant<'a>,
}

enum CursorVariant<'a> {
    /// Full-text BM25 — in-memory BlockPostingList (skip list + block data)
    Text {
        list: crate::structures::BlockPostingList,
        idf: f32,
        /// Precomputed: BM25_B / max(avg_field_len, 1.0)
        b_over_avgfl: f32,
        /// Precomputed: 1.0 - BM25_B
        one_minus_b: f32,
        tfs: Vec<u32>, // temp decode buffer, converted to scores
    },
    /// Sparse vector — mmap'd SparseIndex (skip entries + block data)
    Sparse {
        si: &'a crate::segment::SparseIndex,
        query_weight: f32,
        skip_start: usize,
        block_data_offset: u64,
    },
}
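
// The two precomputed fields in `CursorVariant::Text` fold the BM25 length
// normalization into constants, so the per-posting loop in
// `cursor_ensure_block!` below evaluates (with doc length approximated by tf,
// as noted there):
//
//     score = idf * tf * (K1 + 1) / (tf + K1 * (one_minus_b + b_over_avgfl * tf))
//           = idf * tf * (K1 + 1) / (tf + K1 * (1 - B + B * tf / avg_field_len))
//
// where K1 = BM25_K1 and B = BM25_B.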

// ── TermCursor async/sync macros ──────────────────────────────────────────
//
// Parameterised on:
//   $load_block_fn – load_block_direct | load_block_direct_sync  (sparse I/O)
//   $ensure_fn     – ensure_block_loaded | ensure_block_loaded_sync
//   $($aw)*        – .await  (present for async, absent for sync)

macro_rules! cursor_ensure_block {
    ($self:ident, $load_block_fn:ident, $($aw:tt)*) => {{
        if $self.exhausted || $self.block_loaded {
            return Ok(!$self.exhausted);
        }
        match &mut $self.variant {
            CursorVariant::Text {
                list,
                idf,
                b_over_avgfl,
                one_minus_b,
                tfs,
            } => {
                if list.decode_block_into($self.block_idx, &mut $self.doc_ids, tfs) {
                    let idf_val = *idf;
                    let b_avg = *b_over_avgfl;
                    let one_b = *one_minus_b;
                    $self.scores.clear();
                    $self.scores.reserve(tfs.len());
                    // Precomputed BM25: length_norm = one_minus_b + b_over_avgfl * tf
                    // (tf is used as both term frequency and doc length — a known approx)
                    for &tf in tfs.iter() {
                        let tf = tf as f32;
                        let length_norm = one_b + b_avg * tf;
                        let tf_norm = (tf * (super::BM25_K1 + 1.0))
                            / (tf + super::BM25_K1 * length_norm);
                        $self.scores.push(idf_val * tf_norm);
                    }
                    $self.pos = 0;
                    $self.block_loaded = true;
                    Ok(true)
                } else {
                    $self.exhausted = true;
                    Ok(false)
                }
            }
            CursorVariant::Sparse {
                si,
                query_weight,
                skip_start,
                block_data_offset,
                ..
            } => {
                let block = si
                    .$load_block_fn(*skip_start, *block_data_offset, $self.block_idx)
                    $($aw)* ?;
                match block {
                    Some(b) => {
                        b.decode_doc_ids_into(&mut $self.doc_ids);
                        b.decode_scored_weights_into(*query_weight, &mut $self.scores);
                        if $self.lazy_ordinals {
                            // Defer ordinal decode until ordinal_mut() is called.
                            // Stores cheap Arc-backed mmap slice, no copy.
                            $self.current_sparse_block = Some(b);
                            $self.ordinals_loaded = false;
                        } else {
                            b.decode_ordinals_into(&mut $self.ordinals);
                            $self.ordinals_loaded = true;
                            $self.current_sparse_block = None;
                        }
                        $self.pos = 0;
                        $self.block_loaded = true;
                        Ok(true)
                    }
                    None => {
                        $self.exhausted = true;
                        Ok(false)
                    }
                }
            }
        }
    }};
}

macro_rules! cursor_advance {
    ($self:ident, $ensure_fn:ident, $($aw:tt)*) => {{
        if $self.exhausted {
            return Ok(u32::MAX);
        }
        $self.$ensure_fn() $($aw)* ?;
        if $self.exhausted {
            return Ok(u32::MAX);
        }
        Ok($self.advance_pos())
    }};
}

macro_rules! cursor_seek {
    ($self:ident, $ensure_fn:ident, $target:expr, $($aw:tt)*) => {{
        if let Some(doc) = $self.seek_prepare($target) {
            return Ok(doc);
        }
        $self.$ensure_fn() $($aw)* ?;
        if $self.seek_finish($target) {
            $self.$ensure_fn() $($aw)* ?;
        }
        Ok($self.doc())
    }};
}
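
// For reference, `cursor_advance!(self, ensure_block_loaded_sync,)` expands to
// roughly the body of `advance_sync` below:
//
//     if self.exhausted { return Ok(u32::MAX); }
//     self.ensure_block_loaded_sync()?;
//     if self.exhausted { return Ok(u32::MAX); }
//     Ok(self.advance_pos())
//
// The async variants differ only in using the async loader and awaiting it.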

impl<'a> TermCursor<'a> {
    /// Create a full-text BM25 cursor (lazy — no blocks decoded yet).
    pub fn text(
        posting_list: crate::structures::BlockPostingList,
        idf: f32,
        avg_field_len: f32,
    ) -> Self {
        let max_tf = posting_list.max_tf() as f32;
        let max_score = super::bm25_upper_bound(max_tf.max(1.0), idf);
        let num_blocks = posting_list.num_blocks();
        let safe_avg = avg_field_len.max(1.0);
        Self {
            max_score,
            num_blocks,
            block_idx: 0,
            doc_ids: Vec::with_capacity(128),
            scores: Vec::with_capacity(128),
            ordinals: Vec::new(),
            pos: 0,
            block_loaded: false,
            exhausted: num_blocks == 0,
            lazy_ordinals: false,
            ordinals_loaded: true, // text cursors never have ordinals
            current_sparse_block: None,
            variant: CursorVariant::Text {
                list: posting_list,
                idf,
                b_over_avgfl: super::BM25_B / safe_avg,
                one_minus_b: 1.0 - super::BM25_B,
                tfs: Vec::with_capacity(128),
            },
        }
    }

    /// Create a sparse vector cursor with lazy block loading.
    /// Skip entries are **not** copied — they are read from `SparseIndex` mmap on demand.
    pub fn sparse(
        si: &'a crate::segment::SparseIndex,
        query_weight: f32,
        skip_start: usize,
        skip_count: usize,
        global_max_weight: f32,
        block_data_offset: u64,
    ) -> Self {
        Self {
            max_score: query_weight.abs() * global_max_weight,
            num_blocks: skip_count,
            block_idx: 0,
            doc_ids: Vec::with_capacity(256),
            scores: Vec::with_capacity(256),
            ordinals: Vec::with_capacity(256),
            pos: 0,
            block_loaded: false,
            exhausted: skip_count == 0,
            lazy_ordinals: false,
            ordinals_loaded: true,
            current_sparse_block: None,
            variant: CursorVariant::Sparse {
                si,
                query_weight,
                skip_start,
                block_data_offset,
            },
        }
    }

    // ── Skip-entry access (lazy, zero-copy for sparse) ──────────────────

    #[inline]
    fn block_first_doc(&self, idx: usize) -> DocId {
        match &self.variant {
            CursorVariant::Text { list, .. } => list.block_first_doc(idx).unwrap_or(u32::MAX),
            CursorVariant::Sparse { si, skip_start, .. } => {
                si.read_skip_entry(*skip_start + idx).first_doc
            }
        }
    }

    #[inline]
    fn block_last_doc(&self, idx: usize) -> DocId {
        match &self.variant {
            CursorVariant::Text { list, .. } => list.block_last_doc(idx).unwrap_or(0),
            CursorVariant::Sparse { si, skip_start, .. } => {
                si.read_skip_entry(*skip_start + idx).last_doc
            }
        }
    }

    // ── Read-only accessors ─────────────────────────────────────────────

    #[inline]
    pub fn doc(&self) -> DocId {
        if self.exhausted {
            return u32::MAX;
        }
        if self.block_loaded {
            debug_assert!(self.pos < self.doc_ids.len());
            // SAFETY: pos < doc_ids.len() is maintained by advance_pos/ensure_block_loaded.
            unsafe { *self.doc_ids.get_unchecked(self.pos) }
        } else {
            self.block_first_doc(self.block_idx)
        }
    }

    #[inline]
    pub fn ordinal(&self) -> u16 {
        if !self.block_loaded || self.ordinals.is_empty() {
            return 0;
        }
        debug_assert!(self.pos < self.ordinals.len());
        // SAFETY: pos < ordinals.len() is maintained by advance_pos/ensure_block_loaded.
        unsafe { *self.ordinals.get_unchecked(self.pos) }
    }

    /// Lazily-decoded ordinal accessor for MaxScore executor.
    ///
    /// When `lazy_ordinals=true`, ordinals are not decoded during block loading.
    /// This method triggers the deferred decode on first access, amortized over
    /// the block. Subsequent calls within the same block are free.
    #[inline]
    pub fn ordinal_mut(&mut self) -> u16 {
        if !self.block_loaded {
            return 0;
        }
        if !self.ordinals_loaded {
            if let Some(ref block) = self.current_sparse_block {
                block.decode_ordinals_into(&mut self.ordinals);
            }
            self.ordinals_loaded = true;
        }
        if self.ordinals.is_empty() {
            return 0;
        }
        debug_assert!(self.pos < self.ordinals.len());
        unsafe { *self.ordinals.get_unchecked(self.pos) }
    }

    #[inline]
    pub fn score(&self) -> f32 {
        if !self.block_loaded {
            return 0.0;
        }
        debug_assert!(self.pos < self.scores.len());
        // SAFETY: pos < scores.len() is maintained by advance_pos/ensure_block_loaded.
        unsafe { *self.scores.get_unchecked(self.pos) }
    }

    #[inline]
    pub fn current_block_max_score(&self) -> f32 {
        if self.exhausted {
            return 0.0;
        }
        match &self.variant {
            CursorVariant::Text { list, idf, .. } => {
                let block_max_tf = list.block_max_tf(self.block_idx).unwrap_or(0) as f32;
                super::bm25_upper_bound(block_max_tf.max(1.0), *idf)
            }
            CursorVariant::Sparse {
                si,
                query_weight,
                skip_start,
                ..
            } => query_weight.abs() * si.read_skip_entry(*skip_start + self.block_idx).max_weight,
        }
    }

    // ── Block navigation ────────────────────────────────────────────────

    pub fn skip_to_next_block(&mut self) -> DocId {
        if self.exhausted {
            return u32::MAX;
        }
        self.block_idx += 1;
        self.block_loaded = false;
        if self.block_idx >= self.num_blocks {
            self.exhausted = true;
            return u32::MAX;
        }
        self.block_first_doc(self.block_idx)
    }

    #[inline]
    fn advance_pos(&mut self) -> DocId {
        self.pos += 1;
        if self.pos >= self.doc_ids.len() {
            self.block_idx += 1;
            self.block_loaded = false;
            if self.block_idx >= self.num_blocks {
                self.exhausted = true;
                return u32::MAX;
            }
        }
        self.doc()
    }

    // ── Block loading / advance / seek ─────────────────────────────────
    //
    // Macros parameterised on sparse I/O method + optional .await to
    // stamp out both async and sync variants without duplication.

    pub async fn ensure_block_loaded(&mut self) -> crate::Result<bool> {
        cursor_ensure_block!(self, load_block_direct, .await)
    }

    pub fn ensure_block_loaded_sync(&mut self) -> crate::Result<bool> {
        cursor_ensure_block!(self, load_block_direct_sync,)
    }

    pub async fn advance(&mut self) -> crate::Result<DocId> {
        cursor_advance!(self, ensure_block_loaded, .await)
    }

    pub fn advance_sync(&mut self) -> crate::Result<DocId> {
        cursor_advance!(self, ensure_block_loaded_sync,)
    }

    pub async fn seek(&mut self, target: DocId) -> crate::Result<DocId> {
        cursor_seek!(self, ensure_block_loaded, target, .await)
    }

    pub fn seek_sync(&mut self, target: DocId) -> crate::Result<DocId> {
        cursor_seek!(self, ensure_block_loaded_sync, target,)
    }

    fn seek_prepare(&mut self, target: DocId) -> Option<DocId> {
        if self.exhausted {
            return Some(u32::MAX);
        }

        // Fast path: target is within the currently loaded block
        if self.block_loaded
            && let Some(&last) = self.doc_ids.last()
        {
            if last >= target && self.doc_ids[self.pos] < target {
                let remaining = &self.doc_ids[self.pos..];
                self.pos += crate::structures::simd::find_first_ge_u32(remaining, target);
                if self.pos >= self.doc_ids.len() {
                    self.block_idx += 1;
                    self.block_loaded = false;
                    if self.block_idx >= self.num_blocks {
                        self.exhausted = true;
                        return Some(u32::MAX);
                    }
                }
                return Some(self.doc());
            }
            if self.doc_ids[self.pos] >= target {
                return Some(self.doc());
            }
        }

        // Seek to the block containing target
        let lo = match &self.variant {
            // Text: SIMD-accelerated 2-level seek (L1 + L0)
            CursorVariant::Text { list, .. } => match list.seek_block(target, self.block_idx) {
                Some(idx) => idx,
                None => {
                    self.exhausted = true;
                    return Some(u32::MAX);
                }
            },
            // Sparse: binary search on skip entries (lazy mmap reads)
            CursorVariant::Sparse { .. } => {
                let mut lo = self.block_idx;
                let mut hi = self.num_blocks;
                while lo < hi {
                    let mid = lo + (hi - lo) / 2;
                    if self.block_last_doc(mid) < target {
                        lo = mid + 1;
                    } else {
                        hi = mid;
                    }
                }
                lo
            }
        };
        if lo >= self.num_blocks {
            self.exhausted = true;
            return Some(u32::MAX);
        }
        if lo != self.block_idx || !self.block_loaded {
            self.block_idx = lo;
            self.block_loaded = false;
        }
        None
    }

    #[inline]
    fn seek_finish(&mut self, target: DocId) -> bool {
        if self.exhausted {
            return false;
        }
        self.pos = crate::structures::simd::find_first_ge_u32(&self.doc_ids, target);
        if self.pos >= self.doc_ids.len() {
            self.block_idx += 1;
            self.block_loaded = false;
            if self.block_idx >= self.num_blocks {
                self.exhausted = true;
                return false;
            }
            return true;
        }
        false
    }
}

/// Macro to stamp out the Block-Max MaxScore loop for both async and sync paths.
///
/// `$ensure`, `$advance`, `$seek` are cursor method idents (async or _sync variants).
/// `$($aw:tt)*` captures `.await` for async or nothing for sync.
macro_rules! bms_execute_loop {
    ($self:ident, $ensure:ident, $advance:ident, $seek:ident, $($aw:tt)*) => {{
        let n = $self.cursors.len();

        // Load first block for each cursor (ensures doc() returns real values)
        for cursor in &mut $self.cursors {
            cursor.$ensure() $($aw)* ?;
        }

        let mut docs_scored = 0u64;
        let mut docs_skipped = 0u64;
        let mut blocks_skipped = 0u64;
        let mut conjunction_skipped = 0u64;
        let mut ordinal_scores: Vec<(u16, f32)> = Vec::with_capacity(n * 2);
        let _bms_start = std::time::Instant::now();

        loop {
            let partition = $self.find_partition();
            if partition >= n {
                break;
            }

            // Find minimum doc_id across essential cursors and collect
            // which cursors are at min_doc (avoids redundant re-checks in
            // conjunction, block-max, predicate, and scoring passes).
            let mut min_doc = u32::MAX;
            let mut at_min_mask = 0u64; // bitset of cursor indices at min_doc
            for i in partition..n {
                let doc = $self.cursors[i].doc();
                match doc.cmp(&min_doc) {
                    std::cmp::Ordering::Less => {
                        min_doc = doc;
                        at_min_mask = 1u64 << (i as u32);
                    }
                    std::cmp::Ordering::Equal => {
                        at_min_mask |= 1u64 << (i as u32);
                    }
                    _ => {}
                }
            }
            if min_doc == u32::MAX {
                break;
            }

            let non_essential_upper = if partition > 0 {
                $self.prefix_sums[partition - 1]
            } else {
                0.0
            };
            // Small epsilon to guard against FP rounding in score accumulation.
            // Without this, a document whose true score equals the threshold can
            // be incorrectly pruned due to rounding in the heap_factor multiply
            // or in the prefix_sum additions.
            let adjusted_threshold = $self.collector.threshold() * $self.heap_factor - 1e-6;

            // --- Conjunction optimization ---
            if $self.collector.len() >= $self.collector.k {
                let mut present_upper: f32 = 0.0;
                let mut mask = at_min_mask;
                while mask != 0 {
                    let i = mask.trailing_zeros() as usize;
                    present_upper += $self.cursors[i].max_score;
                    mask &= mask - 1;
                }

                if present_upper + non_essential_upper <= adjusted_threshold {
                    let mut mask = at_min_mask;
                    while mask != 0 {
                        let i = mask.trailing_zeros() as usize;
                        $self.cursors[i].$ensure() $($aw)* ?;
                        $self.cursors[i].$advance() $($aw)* ?;
                        mask &= mask - 1;
                    }
                    conjunction_skipped += 1;
                    continue;
                }
            }

            // --- Block-max pruning ---
            if $self.collector.len() >= $self.collector.k {
                let mut block_max_sum: f32 = 0.0;
                let mut mask = at_min_mask;
                while mask != 0 {
                    let i = mask.trailing_zeros() as usize;
                    block_max_sum += $self.cursors[i].current_block_max_score();
                    mask &= mask - 1;
                }

                if block_max_sum + non_essential_upper <= adjusted_threshold {
                    let mut mask = at_min_mask;
                    while mask != 0 {
                        let i = mask.trailing_zeros() as usize;
                        $self.cursors[i].skip_to_next_block();
                        $self.cursors[i].$ensure() $($aw)* ?;
                        mask &= mask - 1;
                    }
                    blocks_skipped += 1;
                    continue;
                }
            }

            // --- Predicate filter (after block-max, before scoring) ---
            if let Some(ref pred) = $self.predicate {
                if !pred(min_doc) {
                    let mut mask = at_min_mask;
                    while mask != 0 {
                        let i = mask.trailing_zeros() as usize;
                        $self.cursors[i].$ensure() $($aw)* ?;
                        $self.cursors[i].$advance() $($aw)* ?;
                        mask &= mask - 1;
                    }
                    continue;
                }
            }

            // --- Score essential cursors ---
            ordinal_scores.clear();
            {
                let mut mask = at_min_mask;
                while mask != 0 {
                    let i = mask.trailing_zeros() as usize;
                    $self.cursors[i].$ensure() $($aw)* ?;
                    while $self.cursors[i].doc() == min_doc {
                        let ord = $self.cursors[i].ordinal_mut();
                        let sc = $self.cursors[i].score();
                        ordinal_scores.push((ord, sc));
                        $self.cursors[i].$advance() $($aw)* ?;
                    }
                    mask &= mask - 1;
                }
            }

            let essential_total: f32 = ordinal_scores.iter().map(|(_, s)| *s).sum();
            if $self.collector.len() >= $self.collector.k
                && essential_total + non_essential_upper <= adjusted_threshold
            {
                docs_skipped += 1;
                continue;
            }

            // --- Score non-essential cursors (highest max_score first for early exit) ---
            let mut running_total = essential_total;
            for i in (0..partition).rev() {
                if $self.collector.len() >= $self.collector.k
                    && running_total + $self.prefix_sums[i] <= adjusted_threshold
                {
                    break;
                }

                let doc = $self.cursors[i].$seek(min_doc) $($aw)* ?;
                if doc == min_doc {
                    while $self.cursors[i].doc() == min_doc {
                        let s = $self.cursors[i].score();
                        running_total += s;
                        let ord = $self.cursors[i].ordinal_mut();
                        ordinal_scores.push((ord, s));
                        $self.cursors[i].$advance() $($aw)* ?;
                    }
                }
            }

            // --- Group by ordinal and insert ---
            // Fast path: single entry (common for single-valued fields) — skip sort + grouping
            if ordinal_scores.len() == 1 {
                let (ord, score) = ordinal_scores[0];
                if $self.collector.insert_with_ordinal(min_doc, score, ord) {
                    docs_scored += 1;
                } else {
                    docs_skipped += 1;
                }
            } else if !ordinal_scores.is_empty() {
                if ordinal_scores.len() > 2 {
                    ordinal_scores.sort_unstable_by_key(|(ord, _)| *ord);
                } else if ordinal_scores.len() == 2 && ordinal_scores[0].0 > ordinal_scores[1].0 {
                    ordinal_scores.swap(0, 1);
                }
                let mut j = 0;
                while j < ordinal_scores.len() {
                    let current_ord = ordinal_scores[j].0;
                    let mut score = 0.0f32;
                    while j < ordinal_scores.len() && ordinal_scores[j].0 == current_ord {
                        score += ordinal_scores[j].1;
                        j += 1;
                    }
                    if $self
                        .collector
                        .insert_with_ordinal(min_doc, score, current_ord)
                    {
                        docs_scored += 1;
                    } else {
                        docs_skipped += 1;
                    }
                }
            }
        }

        let results: Vec<ScoredDoc> = $self
            .collector
            .into_sorted_results()
            .into_iter()
            .map(|(doc_id, score, ordinal)| ScoredDoc {
                doc_id,
                score,
                ordinal,
            })
            .collect();

        let _bms_elapsed_ms = _bms_start.elapsed().as_millis() as u64;
        if _bms_elapsed_ms > 500 {
            warn!(
                "slow MaxScore: {}ms, cursors={}, scored={}, skipped={}, blocks_skipped={}, conjunction_skipped={}, returned={}, top_score={:.4}",
                _bms_elapsed_ms,
                n,
                docs_scored,
                docs_skipped,
                blocks_skipped,
                conjunction_skipped,
                results.len(),
                results.first().map(|r| r.score).unwrap_or(0.0)
            );
        } else {
            debug!(
                "MaxScoreExecutor: {}ms, scored={}, skipped={}, blocks_skipped={}, conjunction_skipped={}, returned={}, top_score={:.4}",
                _bms_elapsed_ms,
                docs_scored,
                docs_skipped,
                blocks_skipped,
                conjunction_skipped,
                results.len(),
                results.first().map(|r| r.score).unwrap_or(0.0)
            );
        }

        Ok(results)
    }};
}

impl<'a> MaxScoreExecutor<'a> {
    /// Create a new executor from pre-built cursors.
    ///
    /// Cursors are sorted by max_score ascending (non-essential first) and
    /// prefix sums are computed for the MaxScore partitioning.
    pub(crate) fn new(mut cursors: Vec<TermCursor<'a>>, k: usize, heap_factor: f32) -> Self {
        // Enable lazy ordinal decode — ordinals are only decoded when a doc
        // actually reaches the scoring phase (saves ~100ns per skipped block).
        for c in &mut cursors {
            c.lazy_ordinals = true;
        }

        // Sort by max_score ascending (non-essential first)
        cursors.sort_by(|a, b| {
            a.max_score
                .partial_cmp(&b.max_score)
                .unwrap_or(Ordering::Equal)
        });

        let mut prefix_sums = Vec::with_capacity(cursors.len());
        let mut cumsum = 0.0f32;
        for c in &cursors {
            cumsum += c.max_score;
            prefix_sums.push(cumsum);
        }

        debug!(
            "Creating MaxScoreExecutor: num_cursors={}, k={}, total_upper={:.4}, heap_factor={:.2}",
            cursors.len(),
            k,
            cumsum,
            heap_factor
        );

        Self {
            cursors,
            prefix_sums,
            collector: ScoreCollector::new(k),
            heap_factor: heap_factor.clamp(0.0, 1.0),
            predicate: None,
        }
    }

    /// Create an executor for sparse vector queries.
    ///
    /// Builds `TermCursor::Sparse` for each matched dimension.
    pub fn sparse(
        sparse_index: &'a crate::segment::SparseIndex,
        query_terms: Vec<(u32, f32)>,
        k: usize,
        heap_factor: f32,
    ) -> Self {
        let cursors: Vec<TermCursor<'a>> = query_terms
            .iter()
            .filter_map(|&(dim_id, qw)| {
                let (skip_start, skip_count, global_max, block_data_offset) =
                    sparse_index.get_skip_range_full(dim_id)?;
                Some(TermCursor::sparse(
                    sparse_index,
                    qw,
                    skip_start,
                    skip_count,
                    global_max,
                    block_data_offset,
                ))
            })
            .collect();
        Self::new(cursors, k, heap_factor)
    }

    /// Create an executor for full-text BM25 queries.
    ///
    /// Builds `TermCursor::Text` for each posting list.
    pub fn text(
        posting_lists: Vec<(crate::structures::BlockPostingList, f32)>,
        avg_field_len: f32,
        k: usize,
    ) -> Self {
        let cursors: Vec<TermCursor<'a>> = posting_lists
            .into_iter()
            .map(|(pl, idf)| TermCursor::text(pl, idf, avg_field_len))
            .collect();
        Self::new(cursors, k, 1.0)
    }

    #[inline]
    fn find_partition(&self) -> usize {
        let threshold = self.collector.threshold() * self.heap_factor;
        self.prefix_sums.partition_point(|&sum| sum <= threshold)
    }

    /// Attach a per-doc predicate filter to this executor.
    ///
    /// Docs failing the predicate are skipped after block-max pruning but
    /// before scoring. The predicate does not affect thresholds or block-max
    /// comparisons — the heap stores pure sparse/text scores.
    pub fn with_predicate(mut self, predicate: super::DocPredicate<'a>) -> Self {
        self.predicate = Some(predicate);
        self
    }

    /// Execute Block-Max MaxScore and return top-k results (async).
    pub async fn execute(mut self) -> crate::Result<Vec<ScoredDoc>> {
        if self.cursors.is_empty() {
            return Ok(Vec::new());
        }
        bms_execute_loop!(self, ensure_block_loaded, advance, seek, .await)
    }

    /// Synchronous execution — works when all cursors are text or mmap-backed sparse.
    pub fn execute_sync(mut self) -> crate::Result<Vec<ScoredDoc>> {
        if self.cursors.is_empty() {
            return Ok(Vec::new());
        }
        bms_execute_loop!(self, ensure_block_loaded_sync, advance_sync, seek_sync,)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_score_collector_basic() {
        let mut collector = ScoreCollector::new(3);

        collector.insert(1, 1.0);
        collector.insert(2, 2.0);
        collector.insert(3, 3.0);
        assert_eq!(collector.threshold(), 1.0);

        collector.insert(4, 4.0);
        assert_eq!(collector.threshold(), 2.0);

        let results = collector.into_sorted_results();
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].0, 4); // Highest score
        assert_eq!(results[1].0, 3);
        assert_eq!(results[2].0, 2);
    }

    #[test]
    fn test_score_collector_threshold() {
        let mut collector = ScoreCollector::new(2);

        collector.insert(1, 5.0);
        collector.insert(2, 3.0);
        assert_eq!(collector.threshold(), 3.0);

        // Should not enter (score too low)
        assert!(!collector.would_enter(2.0));
        assert!(!collector.insert(3, 2.0));

        // Should enter (score high enough)
        assert!(collector.would_enter(4.0));
        assert!(collector.insert(4, 4.0));
        assert_eq!(collector.threshold(), 4.0);
    }
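
    // Illustrative checks for behaviors documented on `ScoreCollector`: the
    // threshold stays at 0.0 until the heap is full, and equal scores are
    // tie-broken by ascending doc_id in the sorted output.
    #[test]
    fn test_score_collector_ties_and_lazy_threshold() {
        let mut collector = ScoreCollector::new(3);
        assert_eq!(collector.threshold(), 0.0);

        collector.insert(7, 2.0);
        collector.insert(3, 2.0);
        // Threshold is only recomputed once the heap holds k entries.
        assert_eq!(collector.threshold(), 0.0);

        collector.insert(5, 2.0);
        assert_eq!(collector.threshold(), 2.0);

        let results = collector.into_sorted_results();
        assert_eq!(results[0].0, 3);
        assert_eq!(results[1].0, 5);
        assert_eq!(results[2].0, 7);
    }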

    #[test]
    fn test_heap_entry_ordering() {
        let mut heap = BinaryHeap::new();
        heap.push(HeapEntry {
            doc_id: 1,
            score: 3.0,
            ordinal: 0,
        });
        heap.push(HeapEntry {
            doc_id: 2,
            score: 1.0,
            ordinal: 0,
        });
        heap.push(HeapEntry {
            doc_id: 3,
            score: 2.0,
            ordinal: 0,
        });

        // Min-heap: lowest score should come out first
        assert_eq!(heap.pop().unwrap().score, 1.0);
        assert_eq!(heap.pop().unwrap().score, 2.0);
        assert_eq!(heap.pop().unwrap().score, 3.0);
    }
}
1083}