// hermes_core/query/scoring.rs
//! Shared scoring abstractions for text and sparse vector search
//!
//! Provides common types and executors for efficient top-k retrieval:
//! - `TermCursor`: Unified cursor for both BM25 text and sparse vector posting lists
//! - `ScoreCollector`: Efficient min-heap for maintaining top-k results
//! - `MaxScoreExecutor`: Unified Block-Max MaxScore with conjunction optimization
//! - `ScoredDoc`: Result type with doc_id, score, and ordinal
9use std::cmp::Ordering;
10use std::collections::BinaryHeap;
11
12use log::{debug, warn};
13
14use crate::DocId;
15
/// Entry for top-k min-heap
///
/// Ordering is defined by the `Ord` impl below: reversed score (so the
/// lowest score surfaces first in `BinaryHeap`), then doc_id ascending.
#[derive(Clone, Copy)]
pub struct HeapEntry {
    /// Document id; `u32::MAX` doubles as a sentinel for `seed_threshold` entries.
    pub doc_id: DocId,
    /// Accumulated score for this document.
    pub score: f32,
    /// Ordinal for multi-valued fields (which value in the field matched).
    pub ordinal: u16,
}
23
24impl PartialEq for HeapEntry {
25    fn eq(&self, other: &Self) -> bool {
26        self.score == other.score && self.doc_id == other.doc_id
27    }
28}
29
30impl Eq for HeapEntry {}
31
32impl Ord for HeapEntry {
33    fn cmp(&self, other: &Self) -> Ordering {
34        // Min-heap: lower scores come first (to be evicted).
35        // total_cmp is branchless (compiles to a single comparison instruction).
36        other
37            .score
38            .total_cmp(&self.score)
39            .then(self.doc_id.cmp(&other.doc_id))
40    }
41}
42
43impl PartialOrd for HeapEntry {
44    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
45        Some(self.cmp(other))
46    }
47}
48
/// Efficient top-k collector using min-heap (internal, scoring-layer)
///
/// Maintains the k highest-scoring documents using a min-heap where the
/// lowest score is at the top for O(1) threshold lookup and O(log k) eviction.
/// No deduplication — caller must ensure each doc_id is inserted only once.
///
/// This is intentionally separate from `TopKCollector` in `collector.rs`:
/// `ScoreCollector` is used inside `MaxScoreExecutor` where only `(doc_id,
/// score, ordinal)` tuples exist — no `Scorer` trait, no position tracking,
/// and the threshold must be inlined for tight block-max loops.
/// `TopKCollector` wraps a `Scorer` and drives the full `DocSet`/`Scorer`
/// protocol, collecting positions on demand.
pub struct ScoreCollector {
    /// Min-heap of top-k entries (lowest score at top for eviction)
    heap: BinaryHeap<HeapEntry>,
    /// Number of results to keep. Public so executors can compare
    /// `collector.len() >= collector.k` in pruning checks.
    pub k: usize,
    /// Cached threshold: avoids repeated heap.peek() in hot loops.
    /// Updated only when the heap changes (insert/pop).
    /// 0.0 while the heap holds fewer than `k` entries.
    cached_threshold: f32,
}
69
70impl ScoreCollector {
71    /// Create a new collector for top-k results
72    pub fn new(k: usize) -> Self {
73        // Cap capacity to avoid allocation overflow for very large k
74        let capacity = k.saturating_add(1).min(1_000_000);
75        Self {
76            heap: BinaryHeap::with_capacity(capacity),
77            k,
78            cached_threshold: 0.0,
79        }
80    }
81
82    /// Current score threshold (minimum score to enter top-k)
83    #[inline]
84    pub fn threshold(&self) -> f32 {
85        self.cached_threshold
86    }
87
88    /// Recompute cached threshold from heap state
89    #[inline]
90    fn update_threshold(&mut self) {
91        self.cached_threshold = if self.heap.len() >= self.k {
92            self.heap.peek().map(|e| e.score).unwrap_or(0.0)
93        } else {
94            0.0
95        };
96    }
97
98    /// Insert a document score. Returns true if inserted in top-k.
99    /// Caller must ensure each doc_id is inserted only once.
100    #[inline]
101    pub fn insert(&mut self, doc_id: DocId, score: f32) -> bool {
102        self.insert_with_ordinal(doc_id, score, 0)
103    }
104
105    /// Insert a document score with ordinal. Returns true if inserted in top-k.
106    /// Caller must ensure each doc_id is inserted only once.
107    #[inline]
108    pub fn insert_with_ordinal(&mut self, doc_id: DocId, score: f32, ordinal: u16) -> bool {
109        if self.heap.len() < self.k {
110            self.heap.push(HeapEntry {
111                doc_id,
112                score,
113                ordinal,
114            });
115            // Only recompute threshold when heap just became full
116            if self.heap.len() == self.k {
117                self.update_threshold();
118            }
119            true
120        } else if score > self.cached_threshold {
121            self.heap.push(HeapEntry {
122                doc_id,
123                score,
124                ordinal,
125            });
126            self.heap.pop(); // Remove lowest
127            self.update_threshold();
128            true
129        } else {
130            false
131        }
132    }
133
134    /// Check if a score could potentially enter top-k
135    #[inline]
136    pub fn would_enter(&self, score: f32) -> bool {
137        self.heap.len() < self.k || score > self.cached_threshold
138    }
139
140    /// Get number of documents collected so far
141    #[inline]
142    pub fn len(&self) -> usize {
143        self.heap.len()
144    }
145
146    /// Check if collector is empty
147    #[inline]
148    pub fn is_empty(&self) -> bool {
149        self.heap.is_empty()
150    }
151
152    /// Seed the threshold from a cross-segment shared value.
153    ///
154    /// Pre-fills the heap with `k` dummy entries at the given score so that
155    /// pruning kicks in immediately. Only has effect if called before any
156    /// real inserts and `initial_threshold > 0.0`.
157    pub fn seed_threshold(&mut self, initial_threshold: f32) {
158        if initial_threshold > 0.0 && self.heap.is_empty() {
159            for _ in 0..self.k {
160                self.heap.push(HeapEntry {
161                    doc_id: u32::MAX,
162                    score: initial_threshold,
163                    ordinal: 0,
164                });
165            }
166            self.update_threshold();
167        }
168    }
169
170    /// Convert to sorted top-k results (descending by score).
171    /// Filters out sentinel entries (doc_id == u32::MAX) from threshold seeding.
172    pub fn into_sorted_results(self) -> Vec<(DocId, f32, u16)> {
173        let mut results: Vec<(DocId, f32, u16)> = self
174            .heap
175            .into_vec()
176            .into_iter()
177            .filter(|e| e.doc_id != u32::MAX)
178            .map(|e| (e.doc_id, e.score, e.ordinal))
179            .collect();
180
181        // Sort by score descending, then doc_id ascending
182        results.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then(a.0.cmp(&b.0)));
183
184        results
185    }
186}
187
/// Search result from MaxScore execution
#[derive(Debug, Clone, Copy)]
pub struct ScoredDoc {
    /// Matching document id.
    pub doc_id: DocId,
    /// Final accumulated score (BM25 or sparse dot product, per query type).
    pub score: f32,
    /// Ordinal for multi-valued fields (which vector in the field matched)
    pub ordinal: u16,
}
196
/// Unified Block-Max MaxScore executor for top-k retrieval
///
/// Works with both full-text (BM25) and sparse vector (dot product) queries
/// through the polymorphic `TermCursor`. Combines three optimizations:
/// 1. **MaxScore partitioning** (Turtle & Flood 1995): terms split into essential
///    (must check) and non-essential (only scored if candidate is promising)
/// 2. **Block-max pruning** (Ding & Suel 2011): skip blocks where per-block
///    upper bounds can't beat the current threshold
/// 3. **Conjunction optimization** (Lucene/Grand 2023): progressively intersect
///    essential terms as threshold rises, skipping docs that lack enough terms
pub struct MaxScoreExecutor<'a> {
    /// Per-term cursors; the essential/non-essential partition index is
    /// computed over this vector each outer-loop iteration.
    cursors: Vec<TermCursor<'a>>,
    /// `prefix_sums[i]` is an upper bound on the combined max_score of
    /// cursors `0..=i`; used as the non-essential contribution bound.
    prefix_sums: Vec<f32>,
    /// Top-k heap; its threshold drives all pruning decisions.
    collector: ScoreCollector,
    /// Multiplier applied to the collector threshold before pruning
    /// comparisons (presumably 1/heap_factor for approximate search —
    /// confirm at the construction site, not visible in this chunk).
    inv_heap_factor: f32,
    /// Optional document-level filter, checked after block-max pruning
    /// and before any scoring work.
    predicate: Option<super::DocPredicate<'a>>,
}
214
/// Unified term cursor for Block-Max MaxScore execution.
///
/// All per-position decode buffers (`doc_ids`, `scores`, `ordinals`) live in
/// the struct directly and are filled by `ensure_block_loaded`.
///
/// Skip-list metadata is **not** materialized — it is read lazily from the
/// underlying source (`BlockPostingList` for text, `SparseIndex` for sparse),
/// both backed by zero-copy mmap'd `OwnedBytes`.
pub(crate) struct TermCursor<'a> {
    /// Upper bound on the score any single posting of this term can
    /// contribute (BM25 bound from max TF, or |query_weight| * global max
    /// weight for sparse). Drives MaxScore partitioning.
    pub max_score: f32,
    /// Total number of blocks for this term; `block_idx == num_blocks` means exhausted.
    num_blocks: usize,
    // ── Per-position state (filled by ensure_block_loaded) ──────────
    /// Index of the current block (may not be decoded yet — see `block_loaded`).
    block_idx: usize,
    /// Decoded doc ids of the current block.
    doc_ids: Vec<u32>,
    /// Decoded scores of the current block (may be empty for text until
    /// `ensure_scores` runs the deferred TF decode).
    scores: Vec<f32>,
    /// Decoded ordinals of the current block (sparse only; see `lazy_ordinals`).
    ordinals: Vec<u16>,
    /// Position within the decoded block buffers.
    pos: usize,
    /// Whether the buffers above hold the block at `block_idx`.
    block_loaded: bool,
    /// Set once all blocks are consumed; `doc()` then returns `u32::MAX`.
    exhausted: bool,
    // ── Lazy ordinal decode (sparse only) ───────────────────────────
    /// When true, ordinal decode is deferred until ordinal_mut() is called.
    /// Set to true for MaxScoreExecutor cursors (most blocks never need ordinals).
    lazy_ordinals: bool,
    /// Whether ordinals have been decoded for the current block.
    ordinals_loaded: bool,
    /// Stored sparse block for deferred ordinal decode (cheap Arc clone of mmap data).
    current_sparse_block: Option<crate::structures::SparseBlock>,
    // ── Block decode + skip access source ───────────────────────────
    variant: CursorVariant<'a>,
}
245
/// Backing source for a `TermCursor`: where blocks are decoded from and how
/// skip metadata is read.
enum CursorVariant<'a> {
    /// Full-text BM25 — in-memory BlockPostingList (skip list + block data)
    Text {
        list: crate::structures::BlockPostingList,
        /// Inverse document frequency; also used for block-max upper bounds.
        idf: f32,
        /// Precomputed: idf * (BM25_K1 + 1.0) — numerator scale factor
        idf_times_k1_plus_1: f32,
        /// Precomputed: 1.0 + BM25_K1 * (BM25_B / avg_field_len) — denominator tf coefficient
        denom_tf_coeff: f32,
        /// Precomputed: BM25_K1 * (1.0 - BM25_B) — denominator constant
        denom_const: f32,
        /// Scratch buffer for decoded term frequencies.
        tfs: Vec<u32>,
        /// Deferred TF decode state: (block_offset, tf_start, count).
        /// Set when doc_ids are decoded but TFs/scores are not yet computed.
        deferred_tf: Option<(usize, usize, usize)>,
    },
    /// Sparse vector — mmap'd SparseIndex (skip entries + block data)
    Sparse {
        si: &'a crate::segment::SparseIndex,
        /// Query-side weight multiplied into every decoded posting weight.
        query_weight: f32,
        /// Base index of this term's skip entries in the shared skip table.
        skip_start: usize,
        /// Byte offset of this term's block data within the index.
        block_data_offset: u64,
    },
}
270
// ── TermCursor async/sync macros ──────────────────────────────────────────
//
// Parameterised on:
//   $load_block_fn – load_block_direct | load_block_direct_sync  (sparse I/O)
//   $ensure_fn     – ensure_block_loaded | ensure_block_loaded_sync
//   $($aw)*        – .await  (present for async, absent for sync)
277
/// Shared body of `ensure_block_loaded` / `ensure_block_loaded_sync`.
///
/// Decodes the block at `block_idx` into the cursor's buffers if not already
/// loaded. Returns `Ok(true)` while the cursor still has data, `Ok(false)`
/// once exhausted. Text postings decode doc ids only (TFs/scores deferred);
/// sparse postings decode doc ids + scores, optionally deferring ordinals.
macro_rules! cursor_ensure_block {
    ($self:ident, $load_block_fn:ident, $($aw:tt)*) => {{
        if $self.exhausted || $self.block_loaded {
            return Ok(!$self.exhausted);
        }
        match &mut $self.variant {
            CursorVariant::Text {
                list,
                deferred_tf,
                ..
            } => {
                // Decode doc ids only; TF unpack + BM25 scoring are deferred
                // until ensure_scores() so pruned blocks pay nothing extra.
                if let Some(state) = list.decode_block_doc_ids_only($self.block_idx, &mut $self.doc_ids) {
                    *deferred_tf = Some(state);
                    // Empty scores buffer signals "deferred" to ensure_scores().
                    $self.scores.clear();
                    $self.pos = 0;
                    $self.block_loaded = true;
                    Ok(true)
                } else {
                    $self.exhausted = true;
                    Ok(false)
                }
            }
            CursorVariant::Sparse {
                si,
                query_weight,
                skip_start,
                block_data_offset,
                ..
            } => {
                // Sparse blocks are loaded through (possibly async) I/O.
                let block = si
                    .$load_block_fn(*skip_start, *block_data_offset, $self.block_idx)
                    $($aw)* ?;
                match block {
                    Some(b) => {
                        b.decode_doc_ids_into(&mut $self.doc_ids);
                        b.decode_scored_weights_into(*query_weight, &mut $self.scores);
                        if $self.lazy_ordinals {
                            // Defer ordinal decode until ordinal_mut() is called.
                            // Stores cheap Arc-backed mmap slice, no copy.
                            $self.current_sparse_block = Some(b);
                            $self.ordinals_loaded = false;
                        } else {
                            b.decode_ordinals_into(&mut $self.ordinals);
                            $self.ordinals_loaded = true;
                            $self.current_sparse_block = None;
                        }
                        $self.pos = 0;
                        $self.block_loaded = true;
                        Ok(true)
                    }
                    None => {
                        $self.exhausted = true;
                        Ok(false)
                    }
                }
            }
        }
    }};
}
337
/// Shared body of `advance` / `advance_sync`: move to the next posting.
///
/// Returns `Ok(u32::MAX)` once the cursor is exhausted; otherwise makes sure
/// the current block is decoded and steps one position forward.
macro_rules! cursor_advance {
    ($self:ident, $ensure_fn:ident, $($aw:tt)*) => {{
        if $self.exhausted {
            return Ok(u32::MAX);
        }
        $self.$ensure_fn() $($aw)* ?;
        // ensure may have just discovered the list is finished.
        if $self.exhausted {
            return Ok(u32::MAX);
        }
        Ok($self.advance_pos())
    }};
}
350
/// Shared body of `seek` / `seek_sync`: position at the first doc >= target.
///
/// `seek_prepare` handles the in-block fast path and selects the target
/// block; `seek_finish` positions within the decoded block and reports
/// whether the search rolled over into the next block (second load needed).
macro_rules! cursor_seek {
    ($self:ident, $ensure_fn:ident, $target:expr, $($aw:tt)*) => {{
        if let Some(doc) = $self.seek_prepare($target) {
            return Ok(doc);
        }
        $self.$ensure_fn() $($aw)* ?;
        if $self.seek_finish($target) {
            $self.$ensure_fn() $($aw)* ?;
        }
        Ok($self.doc())
    }};
}
363
impl<'a> TermCursor<'a> {
    /// Create a full-text BM25 cursor (lazy — no blocks decoded yet).
    ///
    /// `max_score` is the BM25 upper bound derived from the list's maximum
    /// TF; the BM25 coefficients are precomputed here so per-posting scoring
    /// in `compute_deferred_scores` is a single mul/div.
    pub fn text(
        posting_list: crate::structures::BlockPostingList,
        idf: f32,
        avg_field_len: f32,
    ) -> Self {
        let max_tf = posting_list.max_tf() as f32;
        let max_score = super::bm25_upper_bound(max_tf.max(1.0), idf);
        let num_blocks = posting_list.num_blocks();
        // Guard against zero/negative average field length (divided below).
        let safe_avg = avg_field_len.max(1.0);
        Self {
            max_score,
            num_blocks,
            block_idx: 0,
            doc_ids: Vec::with_capacity(128),
            scores: Vec::with_capacity(128),
            ordinals: Vec::new(),
            pos: 0,
            block_loaded: false,
            exhausted: num_blocks == 0,
            lazy_ordinals: false,
            ordinals_loaded: true, // text cursors never have ordinals
            current_sparse_block: None,
            variant: CursorVariant::Text {
                list: posting_list,
                idf,
                idf_times_k1_plus_1: idf * (super::BM25_K1 + 1.0),
                denom_tf_coeff: 1.0 + super::BM25_K1 * (super::BM25_B / safe_avg),
                denom_const: super::BM25_K1 * (1.0 - super::BM25_B),
                tfs: Vec::with_capacity(128),
                deferred_tf: None,
            },
        }
    }

    /// Create a sparse vector cursor with lazy block loading.
    /// Skip entries are **not** copied — they are read from `SparseIndex` mmap on demand.
    ///
    /// `max_score = |query_weight| * global_max_weight` bounds this term's
    /// dot-product contribution.
    pub fn sparse(
        si: &'a crate::segment::SparseIndex,
        query_weight: f32,
        skip_start: usize,
        skip_count: usize,
        global_max_weight: f32,
        block_data_offset: u64,
    ) -> Self {
        Self {
            max_score: query_weight.abs() * global_max_weight,
            num_blocks: skip_count,
            block_idx: 0,
            doc_ids: Vec::with_capacity(256),
            scores: Vec::with_capacity(256),
            ordinals: Vec::with_capacity(256),
            pos: 0,
            block_loaded: false,
            exhausted: skip_count == 0,
            lazy_ordinals: false,
            ordinals_loaded: true,
            current_sparse_block: None,
            variant: CursorVariant::Sparse {
                si,
                query_weight,
                skip_start,
                block_data_offset,
            },
        }
    }

    // ── Skip-entry access (lazy, zero-copy for sparse) ──────────────────

    /// First doc id of block `idx`, read from skip metadata only (no block
    /// decode). Past-the-end text blocks yield `u32::MAX`.
    #[inline]
    fn block_first_doc(&self, idx: usize) -> DocId {
        match &self.variant {
            CursorVariant::Text { list, .. } => list.block_first_doc(idx).unwrap_or(u32::MAX),
            CursorVariant::Sparse { si, skip_start, .. } => {
                si.read_skip_entry(*skip_start + idx).first_doc
            }
        }
    }

    /// Last doc id of block `idx`, read from skip metadata only (no block decode).
    #[inline]
    fn block_last_doc(&self, idx: usize) -> DocId {
        match &self.variant {
            CursorVariant::Text { list, .. } => list.block_last_doc(idx).unwrap_or(0),
            CursorVariant::Sparse { si, skip_start, .. } => {
                si.read_skip_entry(*skip_start + idx).last_doc
            }
        }
    }

    // ── Read-only accessors ─────────────────────────────────────────────

    /// Current doc id. `u32::MAX` when exhausted. If the current block has
    /// not been decoded yet, falls back to its first doc from skip metadata
    /// (a valid lower bound for min-doc selection without any decode).
    #[inline]
    pub fn doc(&self) -> DocId {
        if self.exhausted {
            return u32::MAX;
        }
        if self.block_loaded {
            debug_assert!(self.pos < self.doc_ids.len());
            // SAFETY: pos < doc_ids.len() is maintained by advance_pos/ensure_block_loaded.
            unsafe { *self.doc_ids.get_unchecked(self.pos) }
        } else {
            self.block_first_doc(self.block_idx)
        }
    }

    /// Ordinal at the current position; 0 when the block is not loaded or
    /// has no ordinals (text cursors, or sparse with deferred decode).
    #[inline]
    pub fn ordinal(&self) -> u16 {
        if !self.block_loaded || self.ordinals.is_empty() {
            return 0;
        }
        debug_assert!(self.pos < self.ordinals.len());
        // SAFETY: pos < ordinals.len() is maintained by advance_pos/ensure_block_loaded.
        unsafe { *self.ordinals.get_unchecked(self.pos) }
    }

    /// Lazily-decoded ordinal accessor for MaxScore executor.
    ///
    /// When `lazy_ordinals=true`, ordinals are not decoded during block loading.
    /// This method triggers the deferred decode on first access, amortized over
    /// the block. Subsequent calls within the same block are free.
    #[inline]
    pub fn ordinal_mut(&mut self) -> u16 {
        if !self.block_loaded {
            return 0;
        }
        if !self.ordinals_loaded {
            if let Some(ref block) = self.current_sparse_block {
                block.decode_ordinals_into(&mut self.ordinals);
            }
            self.ordinals_loaded = true;
        }
        if self.ordinals.is_empty() {
            return 0;
        }
        debug_assert!(self.pos < self.ordinals.len());
        // SAFETY: pos < ordinals.len() after decode; maintained as in ordinal().
        unsafe { *self.ordinals.get_unchecked(self.pos) }
    }

    /// Score at the current position. Returns 0.0 if the block is not loaded;
    /// for text cursors, call `ensure_scores` first or this may read an
    /// empty/deferred buffer.
    #[inline]
    pub fn score(&self) -> f32 {
        if !self.block_loaded {
            return 0.0;
        }
        debug_assert!(self.pos < self.scores.len());
        // SAFETY: pos < scores.len() is maintained by advance_pos/ensure_block_loaded.
        unsafe { *self.scores.get_unchecked(self.pos) }
    }

    /// Ensure BM25 scores are computed for the current block (lazy TF decode).
    ///
    /// For text cursors, TF unpacking and BM25 scoring are deferred from block
    /// loading until this method is called, saving work for blocks skipped by
    /// block-max or conjunction pruning. No-op for sparse cursors.
    #[inline]
    pub fn ensure_scores(&mut self) {
        // Empty scores buffer == "deferred" marker set by cursor_ensure_block.
        if self.block_loaded && self.scores.is_empty() {
            self.compute_deferred_scores();
        }
    }

    /// Upper bound on the score of any posting in the current block, read
    /// from skip metadata (block max TF for text, block max weight for sparse).
    #[inline]
    pub fn current_block_max_score(&self) -> f32 {
        if self.exhausted {
            return 0.0;
        }
        match &self.variant {
            CursorVariant::Text { list, idf, .. } => {
                let block_max_tf = list.block_max_tf(self.block_idx).unwrap_or(0) as f32;
                super::bm25_upper_bound(block_max_tf.max(1.0), *idf)
            }
            CursorVariant::Sparse {
                si,
                query_weight,
                skip_start,
                ..
            } => query_weight.abs() * si.read_skip_entry(*skip_start + self.block_idx).max_weight,
        }
    }

    // ── Block navigation ────────────────────────────────────────────────

    /// Abandon the current block and move to the next one (used by block-max
    /// pruning). Returns the next block's first doc id from skip metadata,
    /// or `u32::MAX` once no blocks remain.
    pub fn skip_to_next_block(&mut self) -> DocId {
        if self.exhausted {
            return u32::MAX;
        }
        self.block_idx += 1;
        self.block_loaded = false;
        if self.block_idx >= self.num_blocks {
            self.exhausted = true;
            return u32::MAX;
        }
        self.block_first_doc(self.block_idx)
    }

    /// Step one posting forward within the loaded block. Rolling past the
    /// block's end moves to the next (not-yet-decoded) block, so the returned
    /// doc then comes from skip metadata via `doc()`; `u32::MAX` if exhausted.
    #[inline]
    fn advance_pos(&mut self) -> DocId {
        self.pos += 1;
        if self.pos >= self.doc_ids.len() {
            self.block_idx += 1;
            self.block_loaded = false;
            if self.block_idx >= self.num_blocks {
                self.exhausted = true;
                return u32::MAX;
            }
        }
        self.doc()
    }

    /// Compute BM25 scores from deferred TF data (lazy decode for text cursors).
    ///
    /// Uses the coefficients precomputed in `text()`:
    /// score = (idf * (k1+1) * tf) / (denom_tf_coeff * tf + denom_const).
    /// `#[inline(never)]` keeps this cold path out of the hot accessors.
    #[inline(never)]
    fn compute_deferred_scores(&mut self) {
        if let CursorVariant::Text {
            list,
            idf_times_k1_plus_1,
            denom_tf_coeff,
            denom_const,
            tfs,
            deferred_tf,
            ..
        } = &mut self.variant
            && let Some((block_offset, tf_start, count)) = deferred_tf.take()
        {
            list.decode_block_tfs_deferred(block_offset, tf_start, count, tfs);
            let num_scale = *idf_times_k1_plus_1;
            let d_tf = *denom_tf_coeff;
            let d_const = *denom_const;
            self.scores.clear();
            self.scores.resize(count, 0.0);
            for i in 0..count {
                // SAFETY: i < count == tfs.len() after decode_block_tfs_deferred,
                // and scores was resized to count above.
                let tf = unsafe { *tfs.get_unchecked(i) } as f32;
                let score = (num_scale * tf) / (d_tf * tf + d_const);
                unsafe {
                    *self.scores.get_unchecked_mut(i) = score;
                }
            }
        }
    }

    // ── Block loading / advance / seek ─────────────────────────────────
    //
    // Macros parameterised on sparse I/O method + optional .await to
    // stamp out both async and sync variants without duplication.

    /// Load the current block if needed (async sparse I/O). Ok(true) while data remains.
    pub async fn ensure_block_loaded(&mut self) -> crate::Result<bool> {
        cursor_ensure_block!(self, load_block_direct, .await)
    }

    /// Load the current block if needed (sync sparse I/O). Ok(true) while data remains.
    pub fn ensure_block_loaded_sync(&mut self) -> crate::Result<bool> {
        cursor_ensure_block!(self, load_block_direct_sync,)
    }

    /// Advance to the next posting (async). Ok(u32::MAX) once exhausted.
    pub async fn advance(&mut self) -> crate::Result<DocId> {
        cursor_advance!(self, ensure_block_loaded, .await)
    }

    /// Advance to the next posting (sync). Ok(u32::MAX) once exhausted.
    pub fn advance_sync(&mut self) -> crate::Result<DocId> {
        cursor_advance!(self, ensure_block_loaded_sync,)
    }

    /// Seek to the first doc >= target (async). Ok(u32::MAX) once exhausted.
    pub async fn seek(&mut self, target: DocId) -> crate::Result<DocId> {
        cursor_seek!(self, ensure_block_loaded, target, .await)
    }

    /// Seek to the first doc >= target (sync). Ok(u32::MAX) once exhausted.
    pub fn seek_sync(&mut self, target: DocId) -> crate::Result<DocId> {
        cursor_seek!(self, ensure_block_loaded_sync, target,)
    }

    /// First half of seek: in-block fast path plus target-block selection.
    ///
    /// Returns `Some(doc)` when the seek completed without needing a block
    /// load (already positioned, found within the loaded block, or
    /// exhausted); `None` when the caller must load `block_idx` and run
    /// `seek_finish`.
    fn seek_prepare(&mut self, target: DocId) -> Option<DocId> {
        if self.exhausted {
            return Some(u32::MAX);
        }

        // Fast path: target is within the currently loaded block
        if self.block_loaded
            && let Some(&last) = self.doc_ids.last()
        {
            if last >= target && self.doc_ids[self.pos] < target {
                let remaining = &self.doc_ids[self.pos..];
                self.pos += crate::structures::simd::find_first_ge_u32(remaining, target);
                if self.pos >= self.doc_ids.len() {
                    self.block_idx += 1;
                    self.block_loaded = false;
                    if self.block_idx >= self.num_blocks {
                        self.exhausted = true;
                        return Some(u32::MAX);
                    }
                }
                return Some(self.doc());
            }
            if self.doc_ids[self.pos] >= target {
                return Some(self.doc());
            }
        }

        // Seek to the block containing target
        let lo = match &self.variant {
            // Text: SIMD-accelerated 2-level seek (L1 + L0)
            CursorVariant::Text { list, .. } => match list.seek_block(target, self.block_idx) {
                Some(idx) => idx,
                None => {
                    self.exhausted = true;
                    return Some(u32::MAX);
                }
            },
            // Sparse: binary search on skip entries (lazy mmap reads)
            CursorVariant::Sparse { .. } => {
                // Invariant: first block with block_last_doc(mid) >= target.
                let mut lo = self.block_idx;
                let mut hi = self.num_blocks;
                while lo < hi {
                    let mid = lo + (hi - lo) / 2;
                    if self.block_last_doc(mid) < target {
                        lo = mid + 1;
                    } else {
                        hi = mid;
                    }
                }
                lo
            }
        };
        if lo >= self.num_blocks {
            self.exhausted = true;
            return Some(u32::MAX);
        }
        if lo != self.block_idx || !self.block_loaded {
            self.block_idx = lo;
            self.block_loaded = false;
        }
        None
    }

    /// Second half of seek: position within the decoded block.
    ///
    /// Returns `true` when target lies past this block's last doc, i.e. the
    /// caller must load the next block (pos is then reset by that load).
    #[inline]
    fn seek_finish(&mut self, target: DocId) -> bool {
        if self.exhausted {
            return false;
        }
        self.pos = crate::structures::simd::find_first_ge_u32(&self.doc_ids, target);
        if self.pos >= self.doc_ids.len() {
            self.block_idx += 1;
            self.block_loaded = false;
            if self.block_idx >= self.num_blocks {
                self.exhausted = true;
                return false;
            }
            return true;
        }
        false
    }
}
713
714/// Macro to stamp out the Block-Max MaxScore loop for both async and sync paths.
715///
716/// `$ensure`, `$advance`, `$seek` are cursor method idents (async or _sync variants).
717/// `$($aw:tt)*` captures `.await` for async or nothing for sync.
718macro_rules! bms_execute_loop {
719    ($self:ident, $ensure:ident, $advance:ident, $seek:ident, $($aw:tt)*) => {{
720        let n = $self.cursors.len();
721
722        // Load first block for each cursor (ensures doc() returns real values)
723        for cursor in &mut $self.cursors {
724            cursor.$ensure() $($aw)* ?;
725        }
726
727        let mut docs_scored = 0u64;
728        let mut docs_skipped = 0u64;
729        let mut blocks_skipped = 0u64;
730        let mut conjunction_skipped = 0u64;
731        let mut ordinal_scores: Vec<(u16, f32)> = Vec::with_capacity(n * 2);
732        let _bms_start = std::time::Instant::now();
733
734        let inv_heap_factor = $self.inv_heap_factor;
735        let mut adjusted_threshold = $self.collector.threshold() * inv_heap_factor - 1e-6;
736
737        loop {
738            let partition = $self.find_partition();
739            if partition >= n {
740                break;
741            }
742
743            // Find minimum doc_id across essential cursors and collect
744            // which cursors are at min_doc (avoids redundant re-checks in
745            // conjunction, block-max, predicate, and scoring passes).
746            let mut min_doc = u32::MAX;
747            let mut at_min_mask = 0u64; // bitset of cursor indices at min_doc
748            for i in partition..n {
749                let doc = $self.cursors[i].doc();
750                match doc.cmp(&min_doc) {
751                    std::cmp::Ordering::Less => {
752                        min_doc = doc;
753                        at_min_mask = 1u64 << (i as u32);
754                    }
755                    std::cmp::Ordering::Equal => {
756                        at_min_mask |= 1u64 << (i as u32);
757                    }
758                    _ => {}
759                }
760            }
761            if min_doc == u32::MAX {
762                break;
763            }
764
765            let non_essential_upper = if partition > 0 {
766                $self.prefix_sums[partition - 1]
767            } else {
768                0.0
769            };
770
771            // --- Conjunction optimization ---
772            if $self.collector.len() >= $self.collector.k {
773                let mut present_upper: f32 = 0.0;
774                let mut mask = at_min_mask;
775                while mask != 0 {
776                    let i = mask.trailing_zeros() as usize;
777                    present_upper += $self.cursors[i].max_score;
778                    mask &= mask - 1;
779                }
780
781                if present_upper + non_essential_upper <= adjusted_threshold {
782                    let mut mask = at_min_mask;
783                    while mask != 0 {
784                        let i = mask.trailing_zeros() as usize;
785                        $self.cursors[i].$ensure() $($aw)* ?;
786                        $self.cursors[i].$advance() $($aw)* ?;
787                        mask &= mask - 1;
788                    }
789                    conjunction_skipped += 1;
790                    continue;
791                }
792            }
793
794            // --- Block-max pruning ---
795            if $self.collector.len() >= $self.collector.k {
796                let mut block_max_sum: f32 = 0.0;
797                let mut mask = at_min_mask;
798                while mask != 0 {
799                    let i = mask.trailing_zeros() as usize;
800                    block_max_sum += $self.cursors[i].current_block_max_score();
801                    mask &= mask - 1;
802                }
803
804                if block_max_sum + non_essential_upper <= adjusted_threshold {
805                    let mut mask = at_min_mask;
806                    while mask != 0 {
807                        let i = mask.trailing_zeros() as usize;
808                        $self.cursors[i].skip_to_next_block();
809                        $self.cursors[i].$ensure() $($aw)* ?;
810                        mask &= mask - 1;
811                    }
812                    blocks_skipped += 1;
813                    continue;
814                }
815            }
816
817            // --- Predicate filter (after block-max, before scoring) ---
818            if let Some(ref pred) = $self.predicate {
819                if !pred(min_doc) {
820                    let mut mask = at_min_mask;
821                    while mask != 0 {
822                        let i = mask.trailing_zeros() as usize;
823                        $self.cursors[i].$ensure() $($aw)* ?;
824                        $self.cursors[i].$advance() $($aw)* ?;
825                        mask &= mask - 1;
826                    }
827                    continue;
828                }
829            }
830
831            // --- Score essential cursors ---
832            ordinal_scores.clear();
833            {
834                let mut mask = at_min_mask;
835                while mask != 0 {
836                    let i = mask.trailing_zeros() as usize;
837                    $self.cursors[i].$ensure() $($aw)* ?;
838                    $self.cursors[i].ensure_scores();
839                    while $self.cursors[i].doc() == min_doc {
840                        let ord = $self.cursors[i].ordinal_mut();
841                        let sc = $self.cursors[i].score();
842                        ordinal_scores.push((ord, sc));
843                        $self.cursors[i].$advance() $($aw)* ?;
844                    }
845                    mask &= mask - 1;
846                }
847            }
848
849            let essential_total: f32 = ordinal_scores.iter().map(|(_, s)| *s).sum();
850            if $self.collector.len() >= $self.collector.k
851                && essential_total + non_essential_upper <= adjusted_threshold
852            {
853                docs_skipped += 1;
854                continue;
855            }
856
857            // --- Score non-essential cursors (highest max_score first for early exit) ---
858            let mut running_total = essential_total;
859            for i in (0..partition).rev() {
860                if $self.collector.len() >= $self.collector.k
861                    && running_total + $self.prefix_sums[i] <= adjusted_threshold
862                {
863                    break;
864                }
865
866                let doc = $self.cursors[i].$seek(min_doc) $($aw)* ?;
867                if doc == min_doc {
868                    $self.cursors[i].ensure_scores();
869                    while $self.cursors[i].doc() == min_doc {
870                        let s = $self.cursors[i].score();
871                        running_total += s;
872                        let ord = $self.cursors[i].ordinal_mut();
873                        ordinal_scores.push((ord, s));
874                        $self.cursors[i].$advance() $($aw)* ?;
875                    }
876                }
877            }
878
879            // --- Group by ordinal and insert ---
880            // Fast path: single entry (common for single-valued fields) — skip sort + grouping
881            if ordinal_scores.len() == 1 {
882                let (ord, score) = ordinal_scores[0];
883                if $self.collector.insert_with_ordinal(min_doc, score, ord) {
884                    docs_scored += 1;
885                    adjusted_threshold = $self.collector.threshold() * inv_heap_factor - 1e-6;
886                } else {
887                    docs_skipped += 1;
888                }
889            } else if !ordinal_scores.is_empty() {
890                if ordinal_scores.len() > 2 {
891                    ordinal_scores.sort_unstable_by_key(|(ord, _)| *ord);
892                } else if ordinal_scores.len() == 2 && ordinal_scores[0].0 > ordinal_scores[1].0 {
893                    ordinal_scores.swap(0, 1);
894                }
895                let mut j = 0;
896                while j < ordinal_scores.len() {
897                    let current_ord = ordinal_scores[j].0;
898                    let mut score = 0.0f32;
899                    while j < ordinal_scores.len() && ordinal_scores[j].0 == current_ord {
900                        score += ordinal_scores[j].1;
901                        j += 1;
902                    }
903                    if $self
904                        .collector
905                        .insert_with_ordinal(min_doc, score, current_ord)
906                    {
907                        docs_scored += 1;
908                        adjusted_threshold = $self.collector.threshold() * inv_heap_factor - 1e-6;
909                    } else {
910                        docs_skipped += 1;
911                    }
912                }
913            }
914        }
915
916        let results: Vec<ScoredDoc> = $self
917            .collector
918            .into_sorted_results()
919            .into_iter()
920            .map(|(doc_id, score, ordinal)| ScoredDoc {
921                doc_id,
922                score,
923                ordinal,
924            })
925            .collect();
926
927        let _bms_elapsed_ms = _bms_start.elapsed().as_millis() as u64;
928        if _bms_elapsed_ms > 500 {
929            warn!(
930                "slow MaxScore: {}ms, cursors={}, scored={}, skipped={}, blocks_skipped={}, conjunction_skipped={}, returned={}, top_score={:.4}",
931                _bms_elapsed_ms,
932                n,
933                docs_scored,
934                docs_skipped,
935                blocks_skipped,
936                conjunction_skipped,
937                results.len(),
938                results.first().map(|r| r.score).unwrap_or(0.0)
939            );
940        } else {
941            debug!(
942                "MaxScoreExecutor: {}ms, scored={}, skipped={}, blocks_skipped={}, conjunction_skipped={}, returned={}, top_score={:.4}",
943                _bms_elapsed_ms,
944                docs_scored,
945                docs_skipped,
946                blocks_skipped,
947                conjunction_skipped,
948                results.len(),
949                results.first().map(|r| r.score).unwrap_or(0.0)
950            );
951        }
952
953        Ok(results)
954    }};
955}
956
957impl<'a> MaxScoreExecutor<'a> {
958    /// Create a new executor from pre-built cursors.
959    ///
960    /// Cursors are sorted by max_score ascending (non-essential first) and
961    /// prefix sums are computed for the MaxScore partitioning.
962    pub(crate) fn new(mut cursors: Vec<TermCursor<'a>>, k: usize, heap_factor: f32) -> Self {
963        // Enable lazy ordinal decode — ordinals are only decoded when a doc
964        // actually reaches the scoring phase (saves ~100ns per skipped block).
965        for c in &mut cursors {
966            c.lazy_ordinals = true;
967        }
968
969        // Sort by max_score ascending (non-essential first)
970        cursors.sort_by(|a, b| {
971            a.max_score
972                .partial_cmp(&b.max_score)
973                .unwrap_or(Ordering::Equal)
974        });
975
976        let mut prefix_sums = Vec::with_capacity(cursors.len());
977        let mut cumsum = 0.0f32;
978        for c in &cursors {
979            cumsum += c.max_score;
980            prefix_sums.push(cumsum);
981        }
982
983        let clamped_heap_factor = heap_factor.clamp(0.01, 1.0);
984
985        debug!(
986            "Creating MaxScoreExecutor: num_cursors={}, k={}, total_upper={:.4}, heap_factor={:.2}",
987            cursors.len(),
988            k,
989            cumsum,
990            clamped_heap_factor
991        );
992
993        Self {
994            cursors,
995            prefix_sums,
996            collector: ScoreCollector::new(k),
997            inv_heap_factor: 1.0 / clamped_heap_factor,
998            predicate: None,
999        }
1000    }
1001
1002    /// Create an executor for sparse vector queries.
1003    ///
1004    /// Builds `TermCursor::Sparse` for each matched dimension.
1005    pub fn sparse(
1006        sparse_index: &'a crate::segment::SparseIndex,
1007        query_terms: Vec<(u32, f32)>,
1008        k: usize,
1009        heap_factor: f32,
1010    ) -> Self {
1011        let cursors: Vec<TermCursor<'a>> = query_terms
1012            .iter()
1013            .filter_map(|&(dim_id, qw)| {
1014                let (skip_start, skip_count, global_max, block_data_offset) =
1015                    sparse_index.get_skip_range_full(dim_id)?;
1016                Some(TermCursor::sparse(
1017                    sparse_index,
1018                    qw,
1019                    skip_start,
1020                    skip_count,
1021                    global_max,
1022                    block_data_offset,
1023                ))
1024            })
1025            .collect();
1026        Self::new(cursors, k, heap_factor)
1027    }
1028
1029    /// Create an executor for full-text BM25 queries.
1030    ///
1031    /// Builds `TermCursor::Text` for each posting list.
1032    pub fn text(
1033        posting_lists: Vec<(crate::structures::BlockPostingList, f32)>,
1034        avg_field_len: f32,
1035        k: usize,
1036    ) -> Self {
1037        let cursors: Vec<TermCursor<'a>> = posting_lists
1038            .into_iter()
1039            .map(|(pl, idf)| TermCursor::text(pl, idf, avg_field_len))
1040            .collect();
1041        Self::new(cursors, k, 1.0)
1042    }
1043
1044    #[inline]
1045    fn find_partition(&self) -> usize {
1046        // Alpha < 1.0 raises the effective threshold → more terms become
1047        // non-essential → more aggressive pruning (approximate retrieval).
1048        // Use multiplication by reciprocal (cheaper than division).
1049        let threshold = self.collector.threshold() * self.inv_heap_factor;
1050        self.prefix_sums.partition_point(|&sum| sum <= threshold)
1051    }
1052
1053    /// Attach a per-doc predicate filter to this executor.
1054    ///
1055    /// Docs failing the predicate are skipped after block-max pruning but
1056    /// before scoring. The predicate does not affect thresholds or block-max
1057    /// comparisons — the heap stores pure sparse/text scores.
1058    pub fn with_predicate(mut self, predicate: super::DocPredicate<'a>) -> Self {
1059        self.predicate = Some(predicate);
1060        self
1061    }
1062
1063    /// Seed the collector with an initial threshold for tighter early pruning.
1064    pub fn seed_threshold(&mut self, initial_threshold: f32) {
1065        self.collector.seed_threshold(initial_threshold);
1066    }
1067
1068    /// Execute Block-Max MaxScore and return top-k results (async).
1069    pub async fn execute(mut self) -> crate::Result<Vec<ScoredDoc>> {
1070        if self.cursors.is_empty() {
1071            return Ok(Vec::new());
1072        }
1073        bms_execute_loop!(self, ensure_block_loaded, advance, seek, .await)
1074    }
1075
1076    /// Synchronous execution — works when all cursors are text or mmap-backed sparse.
1077    pub fn execute_sync(mut self) -> crate::Result<Vec<ScoredDoc>> {
1078        if self.cursors.is_empty() {
1079            return Ok(Vec::new());
1080        }
1081        bms_execute_loop!(self, ensure_block_loaded_sync, advance_sync, seek_sync,)
1082    }
1083}
1084
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_score_collector_basic() {
        let mut topk = ScoreCollector::new(3);

        // Fill to capacity; the threshold is the lowest retained score.
        for (doc, score) in [(1, 1.0), (2, 2.0), (3, 3.0)] {
            topk.insert(doc, score);
        }
        assert_eq!(topk.threshold(), 1.0);

        // A higher score evicts the current minimum and raises the threshold.
        topk.insert(4, 4.0);
        assert_eq!(topk.threshold(), 2.0);

        // Results come back sorted by score, highest first.
        let results = topk.into_sorted_results();
        let doc_ids: Vec<_> = results.iter().map(|r| r.0).collect();
        assert_eq!(doc_ids, vec![4, 3, 2]);
    }

    #[test]
    fn test_score_collector_threshold() {
        let mut topk = ScoreCollector::new(2);
        topk.insert(1, 5.0);
        topk.insert(2, 3.0);
        assert_eq!(topk.threshold(), 3.0);

        // Below the current minimum: rejected without entering the heap.
        assert!(!topk.would_enter(2.0));
        assert!(!topk.insert(3, 2.0));

        // Above the current minimum: accepted, and the threshold rises.
        assert!(topk.would_enter(4.0));
        assert!(topk.insert(4, 4.0));
        assert_eq!(topk.threshold(), 4.0);
    }

    #[test]
    fn test_heap_entry_ordering() {
        // Push in arbitrary order; min-heap semantics mean the lowest score
        // must pop first regardless of insertion order.
        let mut heap = BinaryHeap::new();
        for (id, sc) in [(1, 3.0), (2, 1.0), (3, 2.0)] {
            heap.push(HeapEntry {
                doc_id: id,
                score: sc,
                ordinal: 0,
            });
        }

        let mut popped = Vec::new();
        while let Some(entry) = heap.pop() {
            popped.push(entry.score);
        }
        assert_eq!(popped, vec![1.0, 2.0, 3.0]);
    }
}