// hermes_core/query/scoring.rs
//! Shared scoring abstractions for text and sparse vector search
//!
//! Provides common types and executors for efficient top-k retrieval:
//! - `TermCursor`: Unified cursor for both BM25 text and sparse vector posting lists
//! - `ScoreCollector`: Efficient min-heap for maintaining top-k results
//! - `MaxScoreExecutor`: Unified Block-Max MaxScore with conjunction optimization
//! - `ScoredDoc`: Result type with doc_id, score, and ordinal

9use std::cmp::Ordering;
10use std::collections::BinaryHeap;
11
12use log::debug;
13
14use crate::DocId;
15
/// Entry for top-k min-heap
///
/// Ordering (see the `Ord` impl below) is reversed on `score` so that a
/// `BinaryHeap<HeapEntry>` keeps the *lowest* score at the top, giving O(1)
/// access to the eviction threshold.
#[derive(Clone, Copy)]
pub struct HeapEntry {
    /// Document identifier within the segment.
    pub doc_id: DocId,
    /// Accumulated relevance score for this document.
    pub score: f32,
    /// Ordinal for multi-valued fields (which value in the field matched).
    pub ordinal: u16,
}
23
24impl PartialEq for HeapEntry {
25    fn eq(&self, other: &Self) -> bool {
26        self.score == other.score && self.doc_id == other.doc_id
27    }
28}
29
30impl Eq for HeapEntry {}
31
32impl Ord for HeapEntry {
33    fn cmp(&self, other: &Self) -> Ordering {
34        // Min-heap: lower scores come first (to be evicted).
35        // total_cmp is branchless (compiles to a single comparison instruction).
36        other
37            .score
38            .total_cmp(&self.score)
39            .then(self.doc_id.cmp(&other.doc_id))
40    }
41}
42
43impl PartialOrd for HeapEntry {
44    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
45        Some(self.cmp(other))
46    }
47}
48
/// Efficient top-k collector using min-heap (internal, scoring-layer)
///
/// Maintains the k highest-scoring documents using a min-heap where the
/// lowest score is at the top for O(1) threshold lookup and O(log k) eviction.
/// No deduplication — caller must ensure each doc_id is inserted only once.
///
/// This is intentionally separate from `TopKCollector` in `collector.rs`:
/// `ScoreCollector` is used inside `MaxScoreExecutor` where only `(doc_id,
/// score, ordinal)` tuples exist — no `Scorer` trait, no position tracking,
/// and the threshold must be inlined for tight block-max loops.
/// `TopKCollector` wraps a `Scorer` and drives the full `DocSet`/`Scorer`
/// protocol, collecting positions on demand.
pub struct ScoreCollector {
    /// Min-heap of top-k entries (lowest score at top for eviction)
    heap: BinaryHeap<HeapEntry>,
    /// Number of results to retain.
    pub k: usize,
    /// Cached threshold: avoids repeated heap.peek() in hot loops.
    /// Updated only when the heap changes (insert/pop).
    cached_threshold: f32,
}
69
70impl ScoreCollector {
71    /// Create a new collector for top-k results
72    pub fn new(k: usize) -> Self {
73        // Cap capacity to avoid allocation overflow for very large k
74        let capacity = k.saturating_add(1).min(1_000_000);
75        Self {
76            heap: BinaryHeap::with_capacity(capacity),
77            k,
78            cached_threshold: 0.0,
79        }
80    }
81
82    /// Current score threshold (minimum score to enter top-k)
83    #[inline]
84    pub fn threshold(&self) -> f32 {
85        self.cached_threshold
86    }
87
88    /// Recompute cached threshold from heap state
89    #[inline]
90    fn update_threshold(&mut self) {
91        self.cached_threshold = if self.heap.len() >= self.k {
92            self.heap.peek().map(|e| e.score).unwrap_or(0.0)
93        } else {
94            0.0
95        };
96    }
97
98    /// Insert a document score. Returns true if inserted in top-k.
99    /// Caller must ensure each doc_id is inserted only once.
100    #[inline]
101    pub fn insert(&mut self, doc_id: DocId, score: f32) -> bool {
102        self.insert_with_ordinal(doc_id, score, 0)
103    }
104
105    /// Insert a document score with ordinal. Returns true if inserted in top-k.
106    /// Caller must ensure each doc_id is inserted only once.
107    #[inline]
108    pub fn insert_with_ordinal(&mut self, doc_id: DocId, score: f32, ordinal: u16) -> bool {
109        if self.heap.len() < self.k {
110            self.heap.push(HeapEntry {
111                doc_id,
112                score,
113                ordinal,
114            });
115            // Only recompute threshold when heap just became full
116            if self.heap.len() == self.k {
117                self.update_threshold();
118            }
119            true
120        } else if score > self.cached_threshold {
121            self.heap.push(HeapEntry {
122                doc_id,
123                score,
124                ordinal,
125            });
126            self.heap.pop(); // Remove lowest
127            self.update_threshold();
128            true
129        } else {
130            false
131        }
132    }
133
134    /// Check if a score could potentially enter top-k
135    #[inline]
136    pub fn would_enter(&self, score: f32) -> bool {
137        self.heap.len() < self.k || score > self.cached_threshold
138    }
139
140    /// Get number of documents collected so far
141    #[inline]
142    pub fn len(&self) -> usize {
143        self.heap.len()
144    }
145
146    /// Check if collector is empty
147    #[inline]
148    pub fn is_empty(&self) -> bool {
149        self.heap.is_empty()
150    }
151
152    /// Convert to sorted top-k results (descending by score)
153    pub fn into_sorted_results(self) -> Vec<(DocId, f32, u16)> {
154        let mut results: Vec<(DocId, f32, u16)> = self
155            .heap
156            .into_vec()
157            .into_iter()
158            .map(|e| (e.doc_id, e.score, e.ordinal))
159            .collect();
160
161        // Sort by score descending, then doc_id ascending
162        results.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then(a.0.cmp(&b.0)));
163
164        results
165    }
166}
167
/// Search result from MaxScore execution
#[derive(Debug, Clone, Copy)]
pub struct ScoredDoc {
    /// Document identifier within the segment.
    pub doc_id: DocId,
    /// Final accumulated score for this (doc, ordinal) pair.
    pub score: f32,
    /// Ordinal for multi-valued fields (which vector in the field matched)
    pub ordinal: u16,
}
176
/// Unified Block-Max MaxScore executor for top-k retrieval
///
/// Works with both full-text (BM25) and sparse vector (dot product) queries
/// through the polymorphic `TermCursor`. Combines three optimizations:
/// 1. **MaxScore partitioning** (Turtle & Flood 1995): terms split into essential
///    (must check) and non-essential (only scored if candidate is promising)
/// 2. **Block-max pruning** (Ding & Suel 2011): skip blocks where per-block
///    upper bounds can't beat the current threshold
/// 3. **Conjunction optimization** (Lucene/Grand 2023): progressively intersect
///    essential terms as threshold rises, skipping docs that lack enough terms
pub struct MaxScoreExecutor<'a> {
    /// One cursor per query term. NOTE(review): the execute loop reads
    /// `prefix_sums[partition - 1]` as the non-essential bound, which
    /// presumes cursors are ordered by ascending `max_score` — confirm at
    /// the construction site (not visible in this file chunk).
    cursors: Vec<TermCursor<'a>>,
    /// `prefix_sums[i]` is used as an upper bound on the combined score of
    /// cursors `0..=i` (presumably cumulative sums of `max_score`;
    /// construction not visible here — confirm).
    prefix_sums: Vec<f32>,
    /// Top-k accumulator shared across the whole execution.
    collector: ScoreCollector,
    /// Multiplier applied to the collector threshold before pruning; values
    /// above 1.0 raise the effective threshold and prune more aggressively
    /// (approximate top-k), 1.0 keeps pruning exact.
    heap_factor: f32,
    /// Optional per-document filter; candidates failing it are skipped
    /// before scoring.
    predicate: Option<super::DocPredicate<'a>>,
}
194
/// Unified term cursor for Block-Max MaxScore execution.
///
/// All per-position decode buffers (`doc_ids`, `scores`, `ordinals`) live in
/// the struct directly and are filled by `ensure_block_loaded`.
///
/// Skip-list metadata is **not** materialized — it is read lazily from the
/// underlying source (`BlockPostingList` for text, `SparseIndex` for sparse),
/// both backed by zero-copy mmap'd `OwnedBytes`.
pub(crate) struct TermCursor<'a> {
    /// Global upper bound on any score this term can contribute.
    pub max_score: f32,
    /// Total number of blocks in this term's posting list.
    num_blocks: usize,
    // ── Per-position state (filled by ensure_block_loaded) ──────────
    /// Index of the current block (decoded only when `block_loaded`).
    block_idx: usize,
    doc_ids: Vec<u32>,
    scores: Vec<f32>,
    ordinals: Vec<u16>,
    /// Position within the decoded block. Invariant: `pos < doc_ids.len()`
    /// whenever `block_loaded` is true — relied upon by the unsafe
    /// `get_unchecked` accessors below.
    pos: usize,
    block_loaded: bool,
    exhausted: bool,
    // ── Lazy ordinal decode (sparse only) ───────────────────────────
    /// When true, ordinal decode is deferred until ordinal_mut() is called.
    /// Set to true for MaxScoreExecutor cursors (most blocks never need ordinals).
    lazy_ordinals: bool,
    /// Whether ordinals have been decoded for the current block.
    ordinals_loaded: bool,
    /// Stored sparse block for deferred ordinal decode (cheap Arc clone of mmap data).
    current_sparse_block: Option<crate::structures::SparseBlock>,
    // ── Block decode + skip access source ───────────────────────────
    variant: CursorVariant<'a>,
}
225
/// Source of block data and skip metadata for a `TermCursor`.
enum CursorVariant<'a> {
    /// Full-text BM25 — in-memory BlockPostingList (skip list + block data)
    Text {
        list: crate::structures::BlockPostingList,
        /// Inverse document frequency factor, multiplied into each tf-norm
        /// and used for block-max upper bounds.
        idf: f32,
        /// Precomputed: BM25_B / max(avg_field_len, 1.0)
        b_over_avgfl: f32,
        /// Precomputed: 1.0 - BM25_B
        one_minus_b: f32,
        tfs: Vec<u32>, // temp decode buffer, converted to scores
    },
    /// Sparse vector — mmap'd SparseIndex (skip entries + block data)
    Sparse {
        si: &'a crate::segment::SparseIndex,
        /// Query-side weight; multiplied into stored weights during decode.
        query_weight: f32,
        /// Index of this term's first entry in the shared skip table.
        skip_start: usize,
        /// Offset of this term's block data passed to block loads —
        /// presumably a byte offset into the index; confirm in `SparseIndex`.
        block_data_offset: u64,
    },
}
245
// ── TermCursor async/sync macros ──────────────────────────────────────────
//
// Parameterised on:
//   $load_block_fn – load_block_direct | load_block_direct_sync  (sparse I/O)
//   $ensure_fn     – ensure_block_loaded | ensure_block_loaded_sync
//   $($aw)*        – .await  (present for async, absent for sync)

// Decode the current block into the cursor's buffers if not already done.
// Expands to a body returning Ok(true) when a block is decoded and `pos`
// is reset to 0, Ok(false) once the cursor is exhausted.
macro_rules! cursor_ensure_block {
    ($self:ident, $load_block_fn:ident, $($aw:tt)*) => {{
        // Fast exit: nothing to do if already decoded or past the end.
        if $self.exhausted || $self.block_loaded {
            return Ok(!$self.exhausted);
        }
        match &mut $self.variant {
            CursorVariant::Text {
                list,
                idf,
                b_over_avgfl,
                one_minus_b,
                tfs,
            } => {
                if list.decode_block_into($self.block_idx, &mut $self.doc_ids, tfs) {
                    let idf_val = *idf;
                    let b_avg = *b_over_avgfl;
                    let one_b = *one_minus_b;
                    $self.scores.clear();
                    $self.scores.reserve(tfs.len());
                    // Precomputed BM25: length_norm = one_minus_b + b_over_avgfl * tf
                    // (tf is used as both term frequency and doc length — a known approx)
                    for &tf in tfs.iter() {
                        let tf = tf as f32;
                        let length_norm = one_b + b_avg * tf;
                        let tf_norm = (tf * (super::BM25_K1 + 1.0))
                            / (tf + super::BM25_K1 * length_norm);
                        $self.scores.push(idf_val * tf_norm);
                    }
                    $self.pos = 0;
                    $self.block_loaded = true;
                    Ok(true)
                } else {
                    // decode_block_into returned false: no block at this index.
                    $self.exhausted = true;
                    Ok(false)
                }
            }
            CursorVariant::Sparse {
                si,
                query_weight,
                skip_start,
                block_data_offset,
                ..
            } => {
                // Sparse I/O is the only async point — $load_block_fn is the
                // async or sync loader depending on the stamped variant.
                let block = si
                    .$load_block_fn(*skip_start, *block_data_offset, $self.block_idx)
                    $($aw)* ?;
                match block {
                    Some(b) => {
                        b.decode_doc_ids_into(&mut $self.doc_ids);
                        b.decode_scored_weights_into(*query_weight, &mut $self.scores);
                        if $self.lazy_ordinals {
                            // Defer ordinal decode until ordinal_mut() is called.
                            // Stores cheap Arc-backed mmap slice, no copy.
                            $self.current_sparse_block = Some(b);
                            $self.ordinals_loaded = false;
                        } else {
                            b.decode_ordinals_into(&mut $self.ordinals);
                            $self.ordinals_loaded = true;
                            $self.current_sparse_block = None;
                        }
                        $self.pos = 0;
                        $self.block_loaded = true;
                        Ok(true)
                    }
                    None => {
                        $self.exhausted = true;
                        Ok(false)
                    }
                }
            }
        }
    }};
}
326
// Advance to the next posting. Expands to a body returning the new doc id,
// or u32::MAX once the cursor is exhausted.
macro_rules! cursor_advance {
    ($self:ident, $ensure_fn:ident, $($aw:tt)*) => {{
        if $self.exhausted {
            return Ok(u32::MAX);
        }
        // Make sure the current block is decoded before stepping `pos`.
        $self.$ensure_fn() $($aw)* ?;
        if $self.exhausted {
            return Ok(u32::MAX);
        }
        Ok($self.advance_pos())
    }};
}
339
// Seek to the first doc id >= $target. Three-step protocol:
//   1. seek_prepare: in-block fast path + block-level seek; returns
//      Some(doc) when the answer is known without decoding a block.
//   2. ensure + seek_finish: decode the candidate block and position within
//      it; a `true` return means the cursor rolled into the next (still
//      undecoded) block, which must be decoded too.
//   3. doc(): the resulting doc id (u32::MAX if exhausted).
macro_rules! cursor_seek {
    ($self:ident, $ensure_fn:ident, $target:expr, $($aw:tt)*) => {{
        if let Some(doc) = $self.seek_prepare($target) {
            return Ok(doc);
        }
        $self.$ensure_fn() $($aw)* ?;
        if $self.seek_finish($target) {
            $self.$ensure_fn() $($aw)* ?;
        }
        Ok($self.doc())
    }};
}
352
impl<'a> TermCursor<'a> {
    /// Create a full-text BM25 cursor (lazy — no blocks decoded yet).
    ///
    /// `max_score` is the global BM25 upper bound derived from the list's
    /// maximum term frequency; `avg_field_len` is clamped to >= 1.0 so the
    /// precomputed `b_over_avgfl` stays finite.
    pub fn text(
        posting_list: crate::structures::BlockPostingList,
        idf: f32,
        avg_field_len: f32,
    ) -> Self {
        let max_tf = posting_list.max_tf() as f32;
        let max_score = super::bm25_upper_bound(max_tf.max(1.0), idf);
        let num_blocks = posting_list.num_blocks();
        let safe_avg = avg_field_len.max(1.0);
        Self {
            max_score,
            num_blocks,
            block_idx: 0,
            doc_ids: Vec::with_capacity(128),
            scores: Vec::with_capacity(128),
            ordinals: Vec::new(),
            pos: 0,
            block_loaded: false,
            exhausted: num_blocks == 0,
            lazy_ordinals: false,
            ordinals_loaded: true, // text cursors never have ordinals
            current_sparse_block: None,
            variant: CursorVariant::Text {
                list: posting_list,
                idf,
                b_over_avgfl: super::BM25_B / safe_avg,
                one_minus_b: 1.0 - super::BM25_B,
                tfs: Vec::with_capacity(128),
            },
        }
    }

    /// Create a sparse vector cursor with lazy block loading.
    /// Skip entries are **not** copied — they are read from `SparseIndex` mmap on demand.
    ///
    /// `max_score = |query_weight| * global_max_weight` bounds this term's
    /// contribution to any dot product.
    pub fn sparse(
        si: &'a crate::segment::SparseIndex,
        query_weight: f32,
        skip_start: usize,
        skip_count: usize,
        global_max_weight: f32,
        block_data_offset: u64,
    ) -> Self {
        Self {
            max_score: query_weight.abs() * global_max_weight,
            num_blocks: skip_count,
            block_idx: 0,
            doc_ids: Vec::with_capacity(256),
            scores: Vec::with_capacity(256),
            ordinals: Vec::with_capacity(256),
            pos: 0,
            block_loaded: false,
            exhausted: skip_count == 0,
            lazy_ordinals: false,
            ordinals_loaded: true,
            current_sparse_block: None,
            variant: CursorVariant::Sparse {
                si,
                query_weight,
                skip_start,
                block_data_offset,
            },
        }
    }

    // ── Skip-entry access (lazy, zero-copy for sparse) ──────────────────

    /// First doc id of block `idx`, read from skip metadata (no decode).
    #[inline]
    fn block_first_doc(&self, idx: usize) -> DocId {
        match &self.variant {
            CursorVariant::Text { list, .. } => list.block_first_doc(idx).unwrap_or(u32::MAX),
            CursorVariant::Sparse { si, skip_start, .. } => {
                si.read_skip_entry(*skip_start + idx).first_doc
            }
        }
    }

    /// Last doc id of block `idx`, read from skip metadata (no decode).
    #[inline]
    fn block_last_doc(&self, idx: usize) -> DocId {
        match &self.variant {
            CursorVariant::Text { list, .. } => list.block_last_doc(idx).unwrap_or(0),
            CursorVariant::Sparse { si, skip_start, .. } => {
                si.read_skip_entry(*skip_start + idx).last_doc
            }
        }
    }

    // ── Read-only accessors ─────────────────────────────────────────────

    /// Current doc id, or u32::MAX when exhausted. Before the current block
    /// is decoded, falls back to the block's first doc from skip metadata.
    #[inline]
    pub fn doc(&self) -> DocId {
        if self.exhausted {
            return u32::MAX;
        }
        if self.block_loaded {
            debug_assert!(self.pos < self.doc_ids.len());
            // SAFETY: pos < doc_ids.len() is maintained by advance_pos/ensure_block_loaded.
            unsafe { *self.doc_ids.get_unchecked(self.pos) }
        } else {
            self.block_first_doc(self.block_idx)
        }
    }

    /// Ordinal at the current position; 0 when the block is not loaded or
    /// has no ordinals (text cursors leave `ordinals` empty).
    #[inline]
    pub fn ordinal(&self) -> u16 {
        if !self.block_loaded || self.ordinals.is_empty() {
            return 0;
        }
        debug_assert!(self.pos < self.ordinals.len());
        // SAFETY: pos < ordinals.len() is maintained by advance_pos/ensure_block_loaded.
        unsafe { *self.ordinals.get_unchecked(self.pos) }
    }

    /// Lazily-decoded ordinal accessor for MaxScore executor.
    ///
    /// When `lazy_ordinals=true`, ordinals are not decoded during block loading.
    /// This method triggers the deferred decode on first access, amortized over
    /// the block. Subsequent calls within the same block are free.
    #[inline]
    pub fn ordinal_mut(&mut self) -> u16 {
        if !self.block_loaded {
            return 0;
        }
        if !self.ordinals_loaded {
            if let Some(ref block) = self.current_sparse_block {
                block.decode_ordinals_into(&mut self.ordinals);
            }
            self.ordinals_loaded = true;
        }
        if self.ordinals.is_empty() {
            return 0;
        }
        debug_assert!(self.pos < self.ordinals.len());
        // SAFETY: pos < ordinals.len() — same invariant as ordinal().
        unsafe { *self.ordinals.get_unchecked(self.pos) }
    }

    /// Score at the current position; 0.0 when the block is not loaded.
    #[inline]
    pub fn score(&self) -> f32 {
        if !self.block_loaded {
            return 0.0;
        }
        debug_assert!(self.pos < self.scores.len());
        // SAFETY: pos < scores.len() is maintained by advance_pos/ensure_block_loaded.
        unsafe { *self.scores.get_unchecked(self.pos) }
    }

    /// Upper bound on any score within the current block, read from skip
    /// metadata — no block decode required.
    #[inline]
    pub fn current_block_max_score(&self) -> f32 {
        if self.exhausted {
            return 0.0;
        }
        match &self.variant {
            CursorVariant::Text { list, idf, .. } => {
                let block_max_tf = list.block_max_tf(self.block_idx).unwrap_or(0) as f32;
                super::bm25_upper_bound(block_max_tf.max(1.0), *idf)
            }
            CursorVariant::Sparse {
                si,
                query_weight,
                skip_start,
                ..
            } => query_weight.abs() * si.read_skip_entry(*skip_start + self.block_idx).max_weight,
        }
    }

    // ── Block navigation ────────────────────────────────────────────────

    /// Jump past the current block without decoding it (block-max pruning).
    /// Returns the next block's first doc, or u32::MAX when exhausted.
    pub fn skip_to_next_block(&mut self) -> DocId {
        if self.exhausted {
            return u32::MAX;
        }
        self.block_idx += 1;
        self.block_loaded = false;
        if self.block_idx >= self.num_blocks {
            self.exhausted = true;
            return u32::MAX;
        }
        self.block_first_doc(self.block_idx)
    }

    /// Step one posting forward; rolls into the next (undecoded) block when
    /// the current one is consumed. Returns the new doc() value.
    #[inline]
    fn advance_pos(&mut self) -> DocId {
        self.pos += 1;
        if self.pos >= self.doc_ids.len() {
            self.block_idx += 1;
            self.block_loaded = false;
            if self.block_idx >= self.num_blocks {
                self.exhausted = true;
                return u32::MAX;
            }
        }
        self.doc()
    }

    // ── Block loading / advance / seek ─────────────────────────────────
    //
    // Macros parameterised on sparse I/O method + optional .await to
    // stamp out both async and sync variants without duplication.

    pub async fn ensure_block_loaded(&mut self) -> crate::Result<bool> {
        cursor_ensure_block!(self, load_block_direct, .await)
    }

    pub fn ensure_block_loaded_sync(&mut self) -> crate::Result<bool> {
        cursor_ensure_block!(self, load_block_direct_sync,)
    }

    pub async fn advance(&mut self) -> crate::Result<DocId> {
        cursor_advance!(self, ensure_block_loaded, .await)
    }

    pub fn advance_sync(&mut self) -> crate::Result<DocId> {
        cursor_advance!(self, ensure_block_loaded_sync,)
    }

    pub async fn seek(&mut self, target: DocId) -> crate::Result<DocId> {
        cursor_seek!(self, ensure_block_loaded, target, .await)
    }

    pub fn seek_sync(&mut self, target: DocId) -> crate::Result<DocId> {
        cursor_seek!(self, ensure_block_loaded_sync, target,)
    }

    /// First phase of seek (see `cursor_seek!`): resolve without decoding
    /// where possible. Returns `Some(doc)` when the seek is fully resolved
    /// (in-block fast path, already positioned, or exhausted); `None` when
    /// the caller must decode `block_idx` and call `seek_finish`.
    fn seek_prepare(&mut self, target: DocId) -> Option<DocId> {
        if self.exhausted {
            return Some(u32::MAX);
        }

        // Fast path: target is within the currently loaded block
        if self.block_loaded
            && let Some(&last) = self.doc_ids.last()
        {
            if last >= target && self.doc_ids[self.pos] < target {
                let remaining = &self.doc_ids[self.pos..];
                self.pos += crate::structures::simd::find_first_ge_u32(remaining, target);
                // Defensive: last >= target should keep pos in-bounds.
                if self.pos >= self.doc_ids.len() {
                    self.block_idx += 1;
                    self.block_loaded = false;
                    if self.block_idx >= self.num_blocks {
                        self.exhausted = true;
                        return Some(u32::MAX);
                    }
                }
                return Some(self.doc());
            }
            if self.doc_ids[self.pos] >= target {
                return Some(self.doc());
            }
        }

        // Seek to the block containing target
        let lo = match &self.variant {
            // Text: SIMD-accelerated 2-level seek (L1 + L0)
            CursorVariant::Text { list, .. } => match list.seek_block(target, self.block_idx) {
                Some(idx) => idx,
                None => {
                    self.exhausted = true;
                    return Some(u32::MAX);
                }
            },
            // Sparse: binary search on skip entries (lazy mmap reads)
            CursorVariant::Sparse { .. } => {
                let mut lo = self.block_idx;
                let mut hi = self.num_blocks;
                while lo < hi {
                    let mid = lo + (hi - lo) / 2;
                    if self.block_last_doc(mid) < target {
                        lo = mid + 1;
                    } else {
                        hi = mid;
                    }
                }
                lo
            }
        };
        if lo >= self.num_blocks {
            self.exhausted = true;
            return Some(u32::MAX);
        }
        if lo != self.block_idx || !self.block_loaded {
            self.block_idx = lo;
            self.block_loaded = false;
        }
        None
    }

    /// Position within the freshly decoded block. Returns `true` when the
    /// target lies past this block's entries — the cursor then points at the
    /// next (undecoded) block and the caller must reload before reading.
    #[inline]
    fn seek_finish(&mut self, target: DocId) -> bool {
        if self.exhausted {
            return false;
        }
        self.pos = crate::structures::simd::find_first_ge_u32(&self.doc_ids, target);
        if self.pos >= self.doc_ids.len() {
            self.block_idx += 1;
            self.block_loaded = false;
            if self.block_idx >= self.num_blocks {
                self.exhausted = true;
                return false;
            }
            return true;
        }
        false
    }
}
658
/// Macro to stamp out the Block-Max MaxScore loop for both async and sync paths.
///
/// `$ensure`, `$advance`, `$seek` are cursor method idents (async or _sync variants).
/// `$($aw:tt)*` captures `.await` for async or nothing for sync.
663macro_rules! bms_execute_loop {
664    ($self:ident, $ensure:ident, $advance:ident, $seek:ident, $($aw:tt)*) => {{
665        let n = $self.cursors.len();
666
667        // Load first block for each cursor (ensures doc() returns real values)
668        for cursor in &mut $self.cursors {
669            cursor.$ensure() $($aw)* ?;
670        }
671
672        let mut docs_scored = 0u64;
673        let mut docs_skipped = 0u64;
674        let mut blocks_skipped = 0u64;
675        let mut conjunction_skipped = 0u64;
676        let mut ordinal_scores: Vec<(u16, f32)> = Vec::with_capacity(n * 2);
677
678        loop {
679            let partition = $self.find_partition();
680            if partition >= n {
681                break;
682            }
683
684            // Find minimum doc_id across essential cursors and collect
685            // which cursors are at min_doc (avoids redundant re-checks in
686            // conjunction, block-max, predicate, and scoring passes).
687            let mut min_doc = u32::MAX;
688            let mut at_min_mask = 0u64; // bitset of cursor indices at min_doc
689            for i in partition..n {
690                let doc = $self.cursors[i].doc();
691                match doc.cmp(&min_doc) {
692                    std::cmp::Ordering::Less => {
693                        min_doc = doc;
694                        at_min_mask = 1u64 << (i as u32);
695                    }
696                    std::cmp::Ordering::Equal => {
697                        at_min_mask |= 1u64 << (i as u32);
698                    }
699                    _ => {}
700                }
701            }
702            if min_doc == u32::MAX {
703                break;
704            }
705
706            let non_essential_upper = if partition > 0 {
707                $self.prefix_sums[partition - 1]
708            } else {
709                0.0
710            };
711            // Small epsilon to guard against FP rounding in score accumulation.
712            // Without this, a document whose true score equals the threshold can
713            // be incorrectly pruned due to rounding in the heap_factor multiply
714            // or in the prefix_sum additions.
715            let adjusted_threshold = $self.collector.threshold() * $self.heap_factor - 1e-6;
716
717            // --- Conjunction optimization ---
718            if $self.collector.len() >= $self.collector.k {
719                let mut present_upper: f32 = 0.0;
720                let mut mask = at_min_mask;
721                while mask != 0 {
722                    let i = mask.trailing_zeros() as usize;
723                    present_upper += $self.cursors[i].max_score;
724                    mask &= mask - 1;
725                }
726
727                if present_upper + non_essential_upper <= adjusted_threshold {
728                    let mut mask = at_min_mask;
729                    while mask != 0 {
730                        let i = mask.trailing_zeros() as usize;
731                        $self.cursors[i].$ensure() $($aw)* ?;
732                        $self.cursors[i].$advance() $($aw)* ?;
733                        mask &= mask - 1;
734                    }
735                    conjunction_skipped += 1;
736                    continue;
737                }
738            }
739
740            // --- Block-max pruning ---
741            if $self.collector.len() >= $self.collector.k {
742                let mut block_max_sum: f32 = 0.0;
743                let mut mask = at_min_mask;
744                while mask != 0 {
745                    let i = mask.trailing_zeros() as usize;
746                    block_max_sum += $self.cursors[i].current_block_max_score();
747                    mask &= mask - 1;
748                }
749
750                if block_max_sum + non_essential_upper <= adjusted_threshold {
751                    let mut mask = at_min_mask;
752                    while mask != 0 {
753                        let i = mask.trailing_zeros() as usize;
754                        $self.cursors[i].skip_to_next_block();
755                        $self.cursors[i].$ensure() $($aw)* ?;
756                        mask &= mask - 1;
757                    }
758                    blocks_skipped += 1;
759                    continue;
760                }
761            }
762
763            // --- Predicate filter (after block-max, before scoring) ---
764            if let Some(ref pred) = $self.predicate {
765                if !pred(min_doc) {
766                    let mut mask = at_min_mask;
767                    while mask != 0 {
768                        let i = mask.trailing_zeros() as usize;
769                        $self.cursors[i].$ensure() $($aw)* ?;
770                        $self.cursors[i].$advance() $($aw)* ?;
771                        mask &= mask - 1;
772                    }
773                    continue;
774                }
775            }
776
777            // --- Score essential cursors ---
778            ordinal_scores.clear();
779            {
780                let mut mask = at_min_mask;
781                while mask != 0 {
782                    let i = mask.trailing_zeros() as usize;
783                    $self.cursors[i].$ensure() $($aw)* ?;
784                    while $self.cursors[i].doc() == min_doc {
785                        let ord = $self.cursors[i].ordinal_mut();
786                        let sc = $self.cursors[i].score();
787                        ordinal_scores.push((ord, sc));
788                        $self.cursors[i].$advance() $($aw)* ?;
789                    }
790                    mask &= mask - 1;
791                }
792            }
793
794            let essential_total: f32 = ordinal_scores.iter().map(|(_, s)| *s).sum();
795            if $self.collector.len() >= $self.collector.k
796                && essential_total + non_essential_upper <= adjusted_threshold
797            {
798                docs_skipped += 1;
799                continue;
800            }
801
802            // --- Score non-essential cursors (highest max_score first for early exit) ---
803            let mut running_total = essential_total;
804            for i in (0..partition).rev() {
805                if $self.collector.len() >= $self.collector.k
806                    && running_total + $self.prefix_sums[i] <= adjusted_threshold
807                {
808                    break;
809                }
810
811                let doc = $self.cursors[i].$seek(min_doc) $($aw)* ?;
812                if doc == min_doc {
813                    while $self.cursors[i].doc() == min_doc {
814                        let s = $self.cursors[i].score();
815                        running_total += s;
816                        let ord = $self.cursors[i].ordinal_mut();
817                        ordinal_scores.push((ord, s));
818                        $self.cursors[i].$advance() $($aw)* ?;
819                    }
820                }
821            }
822
823            // --- Group by ordinal and insert ---
824            // Fast path: single entry (common for single-valued fields) — skip sort + grouping
825            if ordinal_scores.len() == 1 {
826                let (ord, score) = ordinal_scores[0];
827                if $self.collector.insert_with_ordinal(min_doc, score, ord) {
828                    docs_scored += 1;
829                } else {
830                    docs_skipped += 1;
831                }
832            } else if !ordinal_scores.is_empty() {
833                if ordinal_scores.len() > 2 {
834                    ordinal_scores.sort_unstable_by_key(|(ord, _)| *ord);
835                } else if ordinal_scores.len() == 2 && ordinal_scores[0].0 > ordinal_scores[1].0 {
836                    ordinal_scores.swap(0, 1);
837                }
838                let mut j = 0;
839                while j < ordinal_scores.len() {
840                    let current_ord = ordinal_scores[j].0;
841                    let mut score = 0.0f32;
842                    while j < ordinal_scores.len() && ordinal_scores[j].0 == current_ord {
843                        score += ordinal_scores[j].1;
844                        j += 1;
845                    }
846                    if $self
847                        .collector
848                        .insert_with_ordinal(min_doc, score, current_ord)
849                    {
850                        docs_scored += 1;
851                    } else {
852                        docs_skipped += 1;
853                    }
854                }
855            }
856        }
857
858        let results: Vec<ScoredDoc> = $self
859            .collector
860            .into_sorted_results()
861            .into_iter()
862            .map(|(doc_id, score, ordinal)| ScoredDoc {
863                doc_id,
864                score,
865                ordinal,
866            })
867            .collect();
868
869        debug!(
870            "MaxScoreExecutor: scored={}, skipped={}, blocks_skipped={}, conjunction_skipped={}, returned={}, top_score={:.4}",
871            docs_scored,
872            docs_skipped,
873            blocks_skipped,
874            conjunction_skipped,
875            results.len(),
876            results.first().map(|r| r.score).unwrap_or(0.0)
877        );
878
879        Ok(results)
880    }};
881}
882
883impl<'a> MaxScoreExecutor<'a> {
884    /// Create a new executor from pre-built cursors.
885    ///
886    /// Cursors are sorted by max_score ascending (non-essential first) and
887    /// prefix sums are computed for the MaxScore partitioning.
888    pub(crate) fn new(mut cursors: Vec<TermCursor<'a>>, k: usize, heap_factor: f32) -> Self {
889        // Enable lazy ordinal decode — ordinals are only decoded when a doc
890        // actually reaches the scoring phase (saves ~100ns per skipped block).
891        for c in &mut cursors {
892            c.lazy_ordinals = true;
893        }
894
895        // Sort by max_score ascending (non-essential first)
896        cursors.sort_by(|a, b| {
897            a.max_score
898                .partial_cmp(&b.max_score)
899                .unwrap_or(Ordering::Equal)
900        });
901
902        let mut prefix_sums = Vec::with_capacity(cursors.len());
903        let mut cumsum = 0.0f32;
904        for c in &cursors {
905            cumsum += c.max_score;
906            prefix_sums.push(cumsum);
907        }
908
909        debug!(
910            "Creating MaxScoreExecutor: num_cursors={}, k={}, total_upper={:.4}, heap_factor={:.2}",
911            cursors.len(),
912            k,
913            cumsum,
914            heap_factor
915        );
916
917        Self {
918            cursors,
919            prefix_sums,
920            collector: ScoreCollector::new(k),
921            heap_factor: heap_factor.clamp(0.0, 1.0),
922            predicate: None,
923        }
924    }
925
926    /// Create an executor for sparse vector queries.
927    ///
928    /// Builds `TermCursor::Sparse` for each matched dimension.
929    pub fn sparse(
930        sparse_index: &'a crate::segment::SparseIndex,
931        query_terms: Vec<(u32, f32)>,
932        k: usize,
933        heap_factor: f32,
934    ) -> Self {
935        let cursors: Vec<TermCursor<'a>> = query_terms
936            .iter()
937            .filter_map(|&(dim_id, qw)| {
938                let (skip_start, skip_count, global_max, block_data_offset) =
939                    sparse_index.get_skip_range_full(dim_id)?;
940                Some(TermCursor::sparse(
941                    sparse_index,
942                    qw,
943                    skip_start,
944                    skip_count,
945                    global_max,
946                    block_data_offset,
947                ))
948            })
949            .collect();
950        Self::new(cursors, k, heap_factor)
951    }
952
953    /// Create an executor for full-text BM25 queries.
954    ///
955    /// Builds `TermCursor::Text` for each posting list.
956    pub fn text(
957        posting_lists: Vec<(crate::structures::BlockPostingList, f32)>,
958        avg_field_len: f32,
959        k: usize,
960    ) -> Self {
961        let cursors: Vec<TermCursor<'a>> = posting_lists
962            .into_iter()
963            .map(|(pl, idf)| TermCursor::text(pl, idf, avg_field_len))
964            .collect();
965        Self::new(cursors, k, 1.0)
966    }
967
968    #[inline]
969    fn find_partition(&self) -> usize {
970        let threshold = self.collector.threshold() * self.heap_factor;
971        self.prefix_sums.partition_point(|&sum| sum <= threshold)
972    }
973
974    /// Attach a per-doc predicate filter to this executor.
975    ///
976    /// Docs failing the predicate are skipped after block-max pruning but
977    /// before scoring. The predicate does not affect thresholds or block-max
978    /// comparisons — the heap stores pure sparse/text scores.
979    pub fn with_predicate(mut self, predicate: super::DocPredicate<'a>) -> Self {
980        self.predicate = Some(predicate);
981        self
982    }
983
984    /// Execute Block-Max MaxScore and return top-k results (async).
985    pub async fn execute(mut self) -> crate::Result<Vec<ScoredDoc>> {
986        if self.cursors.is_empty() {
987            return Ok(Vec::new());
988        }
989        bms_execute_loop!(self, ensure_block_loaded, advance, seek, .await)
990    }
991
992    /// Synchronous execution — works when all cursors are text or mmap-backed sparse.
993    pub fn execute_sync(mut self) -> crate::Result<Vec<ScoredDoc>> {
994        if self.cursors.is_empty() {
995            return Ok(Vec::new());
996        }
997        bms_execute_loop!(self, ensure_block_loaded_sync, advance_sync, seek_sync,)
998    }
999}
1000
#[cfg(test)]
mod tests {
    use super::*;

    /// Fill to capacity, then overflow: the threshold tracks the k-th best
    /// score and the lowest-scoring entry is evicted.
    #[test]
    fn test_score_collector_basic() {
        let mut topk = ScoreCollector::new(3);

        for (doc, score) in [(1, 1.0), (2, 2.0), (3, 3.0)] {
            topk.insert(doc, score);
        }
        assert_eq!(topk.threshold(), 1.0);

        // Fourth insert evicts doc 1 (score 1.0); new floor is 2.0.
        topk.insert(4, 4.0);
        assert_eq!(topk.threshold(), 2.0);

        let ranked = topk.into_sorted_results();
        assert_eq!(ranked.len(), 3);
        let ids: Vec<_> = ranked.iter().map(|entry| entry.0).collect();
        assert_eq!(ids, [4, 3, 2]); // descending by score
    }

    /// Once full, `would_enter`/`insert` reject scores at or below the
    /// current threshold and accept scores above it.
    #[test]
    fn test_score_collector_threshold() {
        let mut topk = ScoreCollector::new(2);

        topk.insert(1, 5.0);
        topk.insert(2, 3.0);
        assert_eq!(topk.threshold(), 3.0);

        // Too low: rejected.
        assert!(!topk.would_enter(2.0));
        assert!(!topk.insert(3, 2.0));

        // High enough: accepted, threshold rises.
        assert!(topk.would_enter(4.0));
        assert!(topk.insert(4, 4.0));
        assert_eq!(topk.threshold(), 4.0);
    }

    /// `BinaryHeap<HeapEntry>` acts as a min-heap on score: pops come out
    /// lowest-score-first.
    #[test]
    fn test_heap_entry_ordering() {
        let mut heap = BinaryHeap::new();
        for (doc_id, score) in [(1, 3.0), (2, 1.0), (3, 2.0)] {
            heap.push(HeapEntry {
                doc_id,
                score,
                ordinal: 0,
            });
        }

        let popped: Vec<f32> = std::iter::from_fn(|| heap.pop())
            .map(|entry| entry.score)
            .collect();
        assert_eq!(popped, [1.0, 2.0, 3.0]);
    }
}