// hermes_core/query/scoring.rs
1//! Shared scoring abstractions for text and sparse vector search
2//!
3//! Provides common types and executors for efficient top-k retrieval:
4//! - `TermCursor`: Unified cursor for both BM25 text and sparse vector posting lists
5//! - `ScoreCollector`: Efficient min-heap for maintaining top-k results
6//! - `MaxScoreExecutor`: Unified Block-Max MaxScore with conjunction optimization
7//! - `BmpExecutor`: Block-at-a-time executor for learned sparse retrieval (12+ terms)
8
9use std::cmp::Ordering;
10use std::collections::BinaryHeap;
11
12use log::debug;
13
14use crate::DocId;
15
/// Entry for top-k min-heap
///
/// The `Ord` impl below reverses the score comparison so that
/// `BinaryHeap` (a max-heap) acts as a min-heap: the lowest-scoring
/// entry surfaces at the top, ready for O(log k) eviction.
#[derive(Clone, Copy)]
pub struct HeapEntry {
    // Document identifier of this candidate.
    pub doc_id: DocId,
    // Accumulated relevance score for this (doc, ordinal) pair.
    pub score: f32,
    // Ordinal for multi-valued fields (which value within the field matched).
    pub ordinal: u16,
}
23
24impl PartialEq for HeapEntry {
25    fn eq(&self, other: &Self) -> bool {
26        self.score == other.score && self.doc_id == other.doc_id
27    }
28}
29
30impl Eq for HeapEntry {}
31
32impl Ord for HeapEntry {
33    fn cmp(&self, other: &Self) -> Ordering {
34        // Min-heap: lower scores come first (to be evicted)
35        other
36            .score
37            .partial_cmp(&self.score)
38            .unwrap_or(Ordering::Equal)
39            .then_with(|| self.doc_id.cmp(&other.doc_id))
40    }
41}
42
43impl PartialOrd for HeapEntry {
44    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
45        Some(self.cmp(other))
46    }
47}
48
/// Efficient top-k collector using min-heap
///
/// Maintains the k highest-scoring documents using a min-heap where the
/// lowest score is at the top for O(1) threshold lookup and O(log k) eviction.
/// No deduplication - caller must ensure each doc_id is inserted only once.
pub struct ScoreCollector {
    /// Min-heap of top-k entries (lowest score at top for eviction)
    heap: BinaryHeap<HeapEntry>,
    /// Requested result count (heap never grows beyond k entries).
    pub k: usize,
    /// Cached threshold: avoids repeated heap.peek() in hot loops.
    /// Updated only when the heap changes (insert/pop).
    /// Stays 0.0 until the heap holds k entries.
    cached_threshold: f32,
}
62
63impl ScoreCollector {
64    /// Create a new collector for top-k results
65    pub fn new(k: usize) -> Self {
66        // Cap capacity to avoid allocation overflow for very large k
67        let capacity = k.saturating_add(1).min(1_000_000);
68        Self {
69            heap: BinaryHeap::with_capacity(capacity),
70            k,
71            cached_threshold: 0.0,
72        }
73    }
74
75    /// Current score threshold (minimum score to enter top-k)
76    #[inline]
77    pub fn threshold(&self) -> f32 {
78        self.cached_threshold
79    }
80
81    /// Recompute cached threshold from heap state
82    #[inline]
83    fn update_threshold(&mut self) {
84        self.cached_threshold = if self.heap.len() >= self.k {
85            self.heap.peek().map(|e| e.score).unwrap_or(0.0)
86        } else {
87            0.0
88        };
89    }
90
91    /// Insert a document score. Returns true if inserted in top-k.
92    /// Caller must ensure each doc_id is inserted only once.
93    #[inline]
94    pub fn insert(&mut self, doc_id: DocId, score: f32) -> bool {
95        self.insert_with_ordinal(doc_id, score, 0)
96    }
97
98    /// Insert a document score with ordinal. Returns true if inserted in top-k.
99    /// Caller must ensure each doc_id is inserted only once.
100    #[inline]
101    pub fn insert_with_ordinal(&mut self, doc_id: DocId, score: f32, ordinal: u16) -> bool {
102        if self.heap.len() < self.k {
103            self.heap.push(HeapEntry {
104                doc_id,
105                score,
106                ordinal,
107            });
108            self.update_threshold();
109            true
110        } else if score > self.cached_threshold {
111            self.heap.push(HeapEntry {
112                doc_id,
113                score,
114                ordinal,
115            });
116            self.heap.pop(); // Remove lowest
117            self.update_threshold();
118            true
119        } else {
120            false
121        }
122    }
123
124    /// Check if a score could potentially enter top-k
125    #[inline]
126    pub fn would_enter(&self, score: f32) -> bool {
127        self.heap.len() < self.k || score > self.cached_threshold
128    }
129
130    /// Get number of documents collected so far
131    #[inline]
132    pub fn len(&self) -> usize {
133        self.heap.len()
134    }
135
136    /// Check if collector is empty
137    #[inline]
138    pub fn is_empty(&self) -> bool {
139        self.heap.is_empty()
140    }
141
142    /// Convert to sorted top-k results (descending by score)
143    pub fn into_sorted_results(self) -> Vec<(DocId, f32, u16)> {
144        let heap_vec = self.heap.into_vec();
145        let mut results: Vec<(DocId, f32, u16)> = Vec::with_capacity(heap_vec.len());
146        for e in heap_vec {
147            results.push((e.doc_id, e.score, e.ordinal));
148        }
149
150        // Sort by score descending, then doc_id ascending
151        results.sort_by(|a, b| {
152            b.1.partial_cmp(&a.1)
153                .unwrap_or(Ordering::Equal)
154                .then_with(|| a.0.cmp(&b.0))
155        });
156
157        results
158    }
159}
160
/// Search result from MaxScore execution
#[derive(Debug, Clone, Copy)]
pub struct ScoredDoc {
    // Matching document identifier.
    pub doc_id: DocId,
    // Final accumulated relevance score.
    pub score: f32,
    /// Ordinal for multi-valued fields (which vector in the field matched)
    pub ordinal: u16,
}
169
/// Block-Max Pruning (BMP) executor for learned sparse retrieval
///
/// Processes blocks in score-descending order using a priority queue.
/// Best for queries with many terms (20+), like SPLADE expansions.
/// Uses document accumulators (FxHashMap) instead of per-term iterators.
///
/// **Memory-efficient**: Only skip entries (block metadata) are kept in memory.
/// Actual block data is loaded on-demand via mmap range reads during execution.
///
/// Reference: Mallia et al., "Faster Learned Sparse Retrieval with
/// Block-Max Pruning" (SIGIR 2024)
pub struct BmpExecutor<'a> {
    /// Sparse index for on-demand block loading
    sparse_index: &'a crate::segment::SparseIndex,
    /// Query terms: (dim_id, query_weight) for each matched dimension
    query_terms: Vec<(u32, f32)>,
    /// Number of results to return
    k: usize,
    /// Heap factor for approximate search: multiplies the current top-k
    /// threshold when testing pruning bounds. Clamped to [0.0, 1.0] in `new`.
    heap_factor: f32,
    /// Optional filter predicate (checked at final collection only —
    /// rejected documents still influence thresholds during execution)
    predicate: Option<super::DocPredicate<'a>>,
}
193
/// Superblock size: group S consecutive blocks into one priority queue entry.
/// Reduces heap operations by S× (e.g. 8× fewer push/pop for S=8).
const BMP_SUPERBLOCK_SIZE: usize = 8;

/// Megablock size: group M superblocks into one outer priority queue entry.
/// Two-level pruning: megablock-level (coarse) → superblock-level (fine).
/// Reduces outer heap operations by M× compared to single-level superblocks.
/// A megablock therefore spans up to S × M = 128 blocks.
const BMP_MEGABLOCK_SIZE: usize = 16;
202
/// Superblock entry (stored per-term, not in the heap directly)
struct BmpSuperBlock {
    /// Upper bound contribution of this superblock (sum of constituent blocks)
    contribution: f32,
    /// First block index in this superblock, relative to the term's own
    /// skip range (not a global skip-entry index)
    block_start: usize,
    /// Number of blocks in this superblock (1..=BMP_SUPERBLOCK_SIZE)
    block_count: usize,
}
212
/// Entry in the BMP outer priority queue: represents a megablock (group of superblocks)
struct BmpMegaBlockEntry {
    /// Upper bound contribution of this megablock (sum of constituent superblocks)
    contribution: f32,
    /// Index into query_terms
    term_idx: usize,
    /// First superblock index within term_superblocks[term_idx]
    sb_start: usize,
    /// Number of superblocks in this megablock (1..=BMP_MEGABLOCK_SIZE)
    sb_count: usize,
}
224
225impl PartialEq for BmpMegaBlockEntry {
226    fn eq(&self, other: &Self) -> bool {
227        self.contribution == other.contribution
228    }
229}
230
231impl Eq for BmpMegaBlockEntry {}
232
233impl Ord for BmpMegaBlockEntry {
234    fn cmp(&self, other: &Self) -> Ordering {
235        // Max-heap: higher contributions come first
236        self.contribution
237            .partial_cmp(&other.contribution)
238            .unwrap_or(Ordering::Equal)
239    }
240}
241
242impl PartialOrd for BmpMegaBlockEntry {
243    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
244        Some(self.cmp(other))
245    }
246}
247
/// Macro to stamp out the BMP execution loop for both async and sync paths.
///
/// `$get_blocks:ident` is the SparseIndex method (get_blocks_range or get_blocks_range_sync).
/// `$($aw:tt)*` captures `.await` for async or nothing for sync.
///
/// The expansion runs four phases:
/// 1. Build per-term superblocks from skip entries (upper bounds only, no I/O).
/// 2. Group superblocks into megablocks and push them onto a max-heap
///    ordered by upper-bound contribution.
/// 3. Drain megablocks best-first, accumulating scores and pruning at
///    megablock, superblock, and block granularity against the running
///    top-k threshold (scaled by `heap_factor`).
/// 4. Re-collect the final top-k from the accumulators, applying the
///    optional predicate filter.
macro_rules! bmp_execute_loop {
    ($self:ident, $get_blocks:ident, $($aw:tt)*) => {{
        use rustc_hash::FxHashMap;

        let num_terms = $self.query_terms.len();
        let si = $self.sparse_index;

        // Two-level queue construction:
        // 1. Build superblocks per term (flat Vecs)
        // 2. Group superblocks into megablocks, push to outer BinaryHeap
        let mut term_superblocks: Vec<Vec<BmpSuperBlock>> = Vec::with_capacity(num_terms);
        let mut term_skip_starts: Vec<usize> = Vec::with_capacity(num_terms);
        // Doc-id range across all query terms, used to size the flat accumulator.
        let mut global_min_doc = u32::MAX;
        let mut global_max_doc = 0u32;
        // Sum of all not-yet-popped megablock upper bounds; shrinks as the
        // outer heap drains, enabling global early termination.
        let mut total_remaining = 0.0f32;

        for &(dim_id, qw) in &$self.query_terms {
            let mut term_skip_start = 0usize;
            let mut superblocks = Vec::new();

            // abs() keeps upper bounds valid for negative query weights.
            let abs_qw = qw.abs();
            if let Some((skip_start, skip_count, _global_max)) = si.get_skip_range(dim_id) {
                term_skip_start = skip_start;
                let mut sb_start = 0;
                while sb_start < skip_count {
                    let sb_count = (skip_count - sb_start).min(BMP_SUPERBLOCK_SIZE);
                    let mut sb_contribution = 0.0f32;
                    for j in 0..sb_count {
                        let skip = si.read_skip_entry(skip_start + sb_start + j);
                        sb_contribution += abs_qw * skip.max_weight;
                        global_min_doc = global_min_doc.min(skip.first_doc);
                        global_max_doc = global_max_doc.max(skip.last_doc);
                    }
                    total_remaining += sb_contribution;
                    superblocks.push(BmpSuperBlock {
                        contribution: sb_contribution,
                        block_start: sb_start,
                        block_count: sb_count,
                    });
                    sb_start += sb_count;
                }
            }
            term_skip_starts.push(term_skip_start);
            term_superblocks.push(superblocks);
        }

        // Step 2: Group superblocks into megablocks and build outer priority queue
        let mut mega_queue: BinaryHeap<BmpMegaBlockEntry> = BinaryHeap::new();
        for (term_idx, superblocks) in term_superblocks.iter().enumerate() {
            let mut mb_start = 0;
            while mb_start < superblocks.len() {
                let mb_count = (superblocks.len() - mb_start).min(BMP_MEGABLOCK_SIZE);
                let mb_contribution: f32 = superblocks[mb_start..mb_start + mb_count]
                    .iter()
                    .map(|sb| sb.contribution)
                    .sum();
                mega_queue.push(BmpMegaBlockEntry {
                    contribution: mb_contribution,
                    term_idx,
                    sb_start: mb_start,
                    sb_count: mb_count,
                });
                mb_start += mb_count;
            }
        }

        // Hybrid accumulator: flat array for ordinal=0, FxHashMap for multi-ordinal
        let doc_range = if global_max_doc >= global_min_doc {
            (global_max_doc - global_min_doc + 1) as usize
        } else {
            0
        };
        // Flat path only for doc ranges up to 256K entries (1 MiB of f32).
        let use_flat = doc_range > 0 && doc_range <= 256 * 1024;
        let mut flat_scores: Vec<f32> = if use_flat {
            vec![0.0; doc_range]
        } else {
            Vec::new()
        };
        // Doc ids whose flat_scores slot became non-zero (touched-set for
        // sparse iteration over the flat array at collection time).
        let mut dirty: Vec<u32> = if use_flat {
            Vec::with_capacity(4096)
        } else {
            Vec::new()
        };
        // (doc_id << 16 | ordinal) -> accumulated score for ordinal != 0
        // (or for everything when the flat path is disabled).
        let mut multi_ord_accumulators: FxHashMap<u64, f32> = FxHashMap::default();

        let mut blocks_processed = 0u64;
        let mut blocks_skipped = 0u64;

        // Running top-k used only to derive the pruning threshold; the final
        // result set is rebuilt from the accumulators after the loop.
        let mut top_k = ScoreCollector::new($self.k);

        let mut doc_ids_buf: Vec<u32> = Vec::with_capacity(256);
        let mut weights_buf: Vec<f32> = Vec::with_capacity(256);
        let mut ordinals_buf: Vec<u16> = Vec::with_capacity(256);

        // Warmup: defer early termination until megablocks from at least
        // min(k, num_terms) distinct terms have been popped.
        let mut terms_warmed = vec![false; num_terms];
        let mut warmup_remaining = $self.k.min(num_terms);

        while let Some(mega) = mega_queue.pop() {
            total_remaining -= mega.contribution;

            if !terms_warmed[mega.term_idx] {
                terms_warmed[mega.term_idx] = true;
                warmup_remaining = warmup_remaining.saturating_sub(1);
            }

            if warmup_remaining == 0 {
                let adjusted_threshold = top_k.threshold() * $self.heap_factor;
                // Global early termination: nothing left in the queue can
                // beat the (scaled) current k-th best score.
                if top_k.len() >= $self.k && total_remaining <= adjusted_threshold {
                    let remaining_blocks: u64 = mega_queue
                        .iter()
                        .map(|m| {
                            let sbs =
                                &term_superblocks[m.term_idx][m.sb_start..m.sb_start + m.sb_count];
                            sbs.iter().map(|sb| sb.block_count as u64).sum::<u64>()
                        })
                        .sum();
                    blocks_skipped += remaining_blocks;
                    debug!(
                        "BMP early termination after {} blocks: remaining={:.4} <= threshold={:.4}",
                        blocks_processed, total_remaining, adjusted_threshold
                    );
                    break;
                }
            }

            let dim_id = $self.query_terms[mega.term_idx].0;
            let qw = $self.query_terms[mega.term_idx].1;
            let abs_qw = qw.abs();
            let skip_start = term_skip_starts[mega.term_idx];

            for sb in term_superblocks[mega.term_idx]
                .iter()
                .skip(mega.sb_start)
                .take(mega.sb_count)
            {
                // Superblock-level pruning.
                if top_k.len() >= $self.k {
                    let adjusted_threshold = top_k.threshold() * $self.heap_factor;
                    if sb.contribution + total_remaining <= adjusted_threshold {
                        blocks_skipped += sb.block_count as u64;
                        continue;
                    }
                }

                // Coalesced superblock loading — async or sync dispatch point
                let sb_blocks = si
                    .$get_blocks(dim_id, sb.block_start, sb.block_count)
                    $($aw)*?;

                // Threshold snapshot reused for block-level pruning below.
                let adjusted_threshold2 = top_k.threshold() * $self.heap_factor;
                let dirty_start = dirty.len();

                for (blk_offset, block) in sb_blocks.into_iter().enumerate() {
                    let blk_idx = sb.block_start + blk_offset;

                    // Block-level pruning via the block's own max weight.
                    if top_k.len() >= $self.k {
                        let skip = si.read_skip_entry(skip_start + blk_idx);
                        let blk_contrib = abs_qw * skip.max_weight;
                        if blk_contrib + total_remaining <= adjusted_threshold2 {
                            blocks_skipped += 1;
                            continue;
                        }
                    }

                    block.decode_doc_ids_into(&mut doc_ids_buf);

                    if block.header.ordinal_bits == 0 && use_flat {
                        // Fast path: whole block is ordinal 0, accumulate
                        // straight into the flat array.
                        block.accumulate_scored_weights(
                            qw,
                            &doc_ids_buf,
                            &mut flat_scores,
                            global_min_doc,
                            &mut dirty,
                        );
                    } else {
                        block.decode_scored_weights_into(qw, &mut weights_buf);
                        let count = block.header.count as usize;

                        block.decode_ordinals_into(&mut ordinals_buf);
                        if use_flat {
                            for i in 0..count {
                                let doc_id = doc_ids_buf[i];
                                let ordinal = ordinals_buf[i];
                                let score_contribution = weights_buf[i];

                                if ordinal == 0 {
                                    let off = (doc_id - global_min_doc) as usize;
                                    // NOTE(review): "slot still 0.0" is used as
                                    // "not yet dirty" — a contribution of exactly
                                    // 0.0, or cancellation from negative weights,
                                    // would push the doc into `dirty` twice.
                                    // Benign for correctness (final collection
                                    // re-reads the slot) but worth confirming.
                                    if flat_scores[off] == 0.0 {
                                        dirty.push(doc_id);
                                    }
                                    flat_scores[off] += score_contribution;
                                } else {
                                    let key = (doc_id as u64) << 16 | ordinal as u64;
                                    let acc = multi_ord_accumulators.entry(key).or_insert(0.0);
                                    *acc += score_contribution;
                                    top_k.insert_with_ordinal(doc_id, *acc, ordinal);
                                }
                            }
                        } else {
                            for i in 0..count {
                                let key = (doc_ids_buf[i] as u64) << 16 | ordinals_buf[i] as u64;
                                let acc = multi_ord_accumulators.entry(key).or_insert(0.0);
                                *acc += weights_buf[i];
                                top_k.insert_with_ordinal(doc_ids_buf[i], *acc, ordinals_buf[i]);
                            }
                        }
                    }

                    blocks_processed += 1;
                }

                // Feed docs newly touched by this superblock into the running
                // top-k so the pruning threshold can rise.
                for &doc_id in &dirty[dirty_start..] {
                    let off = (doc_id - global_min_doc) as usize;
                    top_k.insert_with_ordinal(doc_id, flat_scores[off], 0);
                }
            }
        }

        // Collect final top-k with predicate filtering
        let mut final_top_k = ScoreCollector::new($self.k);

        // num_accumulators is reported in the debug line below.
        let num_accumulators = if use_flat {
            for &doc_id in &dirty {
                if let Some(ref pred) = $self.predicate
                    && !pred(doc_id)
                {
                    continue;
                }
                let off = (doc_id - global_min_doc) as usize;
                let score = flat_scores[off];
                if score > 0.0 {
                    final_top_k.insert_with_ordinal(doc_id, score, 0);
                }
            }
            dirty.len() + multi_ord_accumulators.len()
        } else {
            multi_ord_accumulators.len()
        };

        for (key, score) in &multi_ord_accumulators {
            let doc_id = (*key >> 16) as crate::DocId;
            if let Some(ref pred) = $self.predicate
                && !pred(doc_id)
            {
                continue;
            }
            final_top_k.insert_with_ordinal(doc_id, *score, (*key & 0xFFFF) as u16);
        }

        let results: Vec<ScoredDoc> = final_top_k
            .into_sorted_results()
            .into_iter()
            .map(|(doc_id, score, ordinal)| ScoredDoc {
                doc_id,
                score,
                ordinal,
            })
            .collect();

        debug!(
            "BmpExecutor completed: blocks_processed={}, blocks_skipped={}, accumulators={}, flat={}, returned={}, top_score={:.4}",
            blocks_processed,
            blocks_skipped,
            num_accumulators,
            use_flat,
            results.len(),
            results.first().map(|r| r.score).unwrap_or(0.0)
        );

        Ok(results)
    }};
}
523
524impl<'a> BmpExecutor<'a> {
525    /// Create a new BMP executor with lazy block loading
526    ///
527    /// `query_terms` should contain only dimensions that exist in the index.
528    /// Block metadata (skip entries) is read from the sparse index directly.
529    pub fn new(
530        sparse_index: &'a crate::segment::SparseIndex,
531        query_terms: Vec<(u32, f32)>,
532        k: usize,
533        heap_factor: f32,
534    ) -> Self {
535        Self {
536            sparse_index,
537            query_terms,
538            k,
539            heap_factor: heap_factor.clamp(0.0, 1.0),
540            predicate: None,
541        }
542    }
543
544    /// Set a filter predicate that rejects documents at final collection.
545    pub fn set_predicate(&mut self, predicate: Option<super::DocPredicate<'a>>) {
546        self.predicate = predicate;
547    }
548
549    /// Execute BMP and return top-k results (async).
550    pub async fn execute(self) -> crate::Result<Vec<ScoredDoc>> {
551        if self.query_terms.is_empty() {
552            return Ok(Vec::new());
553        }
554        bmp_execute_loop!(self, get_blocks_range, .await)
555    }
556
557    /// Synchronous BMP execution — works when sparse index is mmap-backed.
558    #[cfg(feature = "sync")]
559    pub fn execute_sync(self) -> crate::Result<Vec<ScoredDoc>> {
560        if self.query_terms.is_empty() {
561            return Ok(Vec::new());
562        }
563        bmp_execute_loop!(self, get_blocks_range_sync,)
564    }
565}
566
/// Unified Block-Max MaxScore executor for top-k retrieval
///
/// Works with both full-text (BM25) and sparse vector (dot product) queries
/// through the polymorphic `TermCursor`. Combines three optimizations:
/// 1. **MaxScore partitioning** (Turtle & Flood 1995): terms split into essential
///    (must check) and non-essential (only scored if candidate is promising)
/// 2. **Block-max pruning** (Ding & Suel 2011): skip blocks where per-block
///    upper bounds can't beat the current threshold
/// 3. **Conjunction optimization** (Lucene/Grand 2023): progressively intersect
///    essential terms as threshold rises, skipping docs that lack enough terms
pub struct MaxScoreExecutor<'a> {
    /// One cursor per query term (text or sparse).
    cursors: Vec<TermCursor<'a>>,
    /// NOTE(review): presumably cumulative sums of per-term max scores used
    /// for the essential/non-essential split — the impl is not in this
    /// chunk; confirm against the executor's run loop.
    prefix_sums: Vec<f32>,
    /// Running top-k collector providing the pruning threshold.
    collector: ScoreCollector,
    /// Threshold multiplier for approximate search.
    heap_factor: f32,
    /// Optional filter predicate for excluding documents.
    predicate: Option<super::DocPredicate<'a>>,
}
584
/// Unified term cursor for Block-Max MaxScore execution.
///
/// All per-position decode buffers (`doc_ids`, `scores`, `ordinals`) live in
/// the struct directly and are filled by `ensure_block_loaded`.
///
/// Skip-list metadata is **not** materialized — it is read lazily from the
/// underlying source (`BlockPostingList` for text, `SparseIndex` for sparse),
/// both backed by zero-copy mmap'd `OwnedBytes`.
pub(crate) struct TermCursor<'a> {
    // Upper bound on this term's score contribution (for MaxScore partitioning).
    pub max_score: f32,
    // Total number of blocks in this term's posting list.
    num_blocks: usize,
    // ── Per-position state (filled by ensure_block_loaded) ──────────
    // Index of the current block (0..num_blocks).
    block_idx: usize,
    // Decoded doc ids of the current block.
    doc_ids: Vec<u32>,
    // Precomputed per-posting scores of the current block.
    scores: Vec<f32>,
    // Decoded ordinals of the current block (empty for text cursors).
    ordinals: Vec<u16>,
    // Position within the current block's buffers.
    pos: usize,
    // Whether the buffers above reflect block_idx.
    block_loaded: bool,
    // Set once block_idx runs past the last block; doc() then returns u32::MAX.
    exhausted: bool,
    // ── Block decode + skip access source ───────────────────────────
    variant: CursorVariant<'a>,
}
607
enum CursorVariant<'a> {
    /// Full-text BM25 — in-memory BlockPostingList (skip list + block data)
    Text {
        list: crate::structures::BlockPostingList,
        // Inverse document frequency for this term's BM25 scoring.
        idf: f32,
        // Average field length used in BM25 length normalization.
        avg_field_len: f32,
        tfs: Vec<u32>, // temp decode buffer, converted to scores
    },
    /// Sparse vector — mmap'd SparseIndex (skip entries + block data)
    Sparse {
        si: &'a crate::segment::SparseIndex,
        // Query weight for this dimension; multiplied into decoded weights.
        query_weight: f32,
        // First skip-entry index of this dimension within the sparse index.
        skip_start: usize,
        // NOTE(review): presumably the byte offset of this dimension's block
        // data region, forwarded to load_block_direct — confirm against
        // SparseIndex.
        block_data_offset: u64,
    },
}
624
625impl<'a> TermCursor<'a> {
626    /// Create a full-text BM25 cursor (lazy — no blocks decoded yet).
627    pub fn text(
628        posting_list: crate::structures::BlockPostingList,
629        idf: f32,
630        avg_field_len: f32,
631    ) -> Self {
632        let max_tf = posting_list.max_tf() as f32;
633        let max_score = super::bm25_upper_bound(max_tf.max(1.0), idf);
634        let num_blocks = posting_list.num_blocks();
635        Self {
636            max_score,
637            num_blocks,
638            block_idx: 0,
639            doc_ids: Vec::with_capacity(128),
640            scores: Vec::with_capacity(128),
641            ordinals: Vec::new(),
642            pos: 0,
643            block_loaded: false,
644            exhausted: num_blocks == 0,
645            variant: CursorVariant::Text {
646                list: posting_list,
647                idf,
648                avg_field_len,
649                tfs: Vec::with_capacity(128),
650            },
651        }
652    }
653
654    /// Create a sparse vector cursor with lazy block loading.
655    /// Skip entries are **not** copied — they are read from `SparseIndex` mmap on demand.
656    pub fn sparse(
657        si: &'a crate::segment::SparseIndex,
658        query_weight: f32,
659        skip_start: usize,
660        skip_count: usize,
661        global_max_weight: f32,
662        block_data_offset: u64,
663    ) -> Self {
664        Self {
665            max_score: query_weight.abs() * global_max_weight,
666            num_blocks: skip_count,
667            block_idx: 0,
668            doc_ids: Vec::with_capacity(256),
669            scores: Vec::with_capacity(256),
670            ordinals: Vec::with_capacity(256),
671            pos: 0,
672            block_loaded: false,
673            exhausted: skip_count == 0,
674            variant: CursorVariant::Sparse {
675                si,
676                query_weight,
677                skip_start,
678                block_data_offset,
679            },
680        }
681    }
682
683    // ── Skip-entry access (lazy, zero-copy for sparse) ──────────────────
684
685    #[inline]
686    fn block_first_doc(&self, idx: usize) -> DocId {
687        match &self.variant {
688            CursorVariant::Text { list, .. } => list.block_first_doc(idx).unwrap_or(u32::MAX),
689            CursorVariant::Sparse { si, skip_start, .. } => {
690                si.read_skip_entry(*skip_start + idx).first_doc
691            }
692        }
693    }
694
695    #[inline]
696    fn block_last_doc(&self, idx: usize) -> DocId {
697        match &self.variant {
698            CursorVariant::Text { list, .. } => list.block_last_doc(idx).unwrap_or(0),
699            CursorVariant::Sparse { si, skip_start, .. } => {
700                si.read_skip_entry(*skip_start + idx).last_doc
701            }
702        }
703    }
704
705    // ── Read-only accessors ─────────────────────────────────────────────
706
707    #[inline]
708    pub fn doc(&self) -> DocId {
709        if self.exhausted {
710            return u32::MAX;
711        }
712        if self.block_loaded {
713            self.doc_ids.get(self.pos).copied().unwrap_or(u32::MAX)
714        } else {
715            self.block_first_doc(self.block_idx)
716        }
717    }
718
719    #[inline]
720    pub fn ordinal(&self) -> u16 {
721        if !self.block_loaded || self.ordinals.is_empty() {
722            return 0;
723        }
724        self.ordinals.get(self.pos).copied().unwrap_or(0)
725    }
726
727    #[inline]
728    pub fn score(&self) -> f32 {
729        if !self.block_loaded {
730            return 0.0;
731        }
732        self.scores.get(self.pos).copied().unwrap_or(0.0)
733    }
734
735    #[inline]
736    pub fn current_block_max_score(&self) -> f32 {
737        if self.exhausted {
738            return 0.0;
739        }
740        match &self.variant {
741            CursorVariant::Text { list, idf, .. } => {
742                let block_max_tf = list.block_max_tf(self.block_idx).unwrap_or(0) as f32;
743                super::bm25_upper_bound(block_max_tf.max(1.0), *idf)
744            }
745            CursorVariant::Sparse {
746                si,
747                query_weight,
748                skip_start,
749                ..
750            } => query_weight.abs() * si.read_skip_entry(*skip_start + self.block_idx).max_weight,
751        }
752    }
753
754    // ── Block navigation ────────────────────────────────────────────────
755
756    pub fn skip_to_next_block(&mut self) -> DocId {
757        if self.exhausted {
758            return u32::MAX;
759        }
760        self.block_idx += 1;
761        self.block_loaded = false;
762        if self.block_idx >= self.num_blocks {
763            self.exhausted = true;
764            return u32::MAX;
765        }
766        self.block_first_doc(self.block_idx)
767    }
768
769    #[inline]
770    fn advance_pos(&mut self) -> DocId {
771        self.pos += 1;
772        if self.pos >= self.doc_ids.len() {
773            self.block_idx += 1;
774            self.block_loaded = false;
775            if self.block_idx >= self.num_blocks {
776                self.exhausted = true;
777                return u32::MAX;
778            }
779        }
780        self.doc()
781    }
782
    // ── Block loading (dispatch: decode format + I/O differ) ────────────

    /// Decode the current block into the cursor's buffers if not yet loaded.
    ///
    /// Returns `Ok(true)` when a block is loaded and positioned at its first
    /// posting, `Ok(false)` when the cursor is (or becomes) exhausted.
    /// No-op when the current block is already decoded.
    pub async fn ensure_block_loaded(&mut self) -> crate::Result<bool> {
        if self.exhausted || self.block_loaded {
            return Ok(!self.exhausted);
        }
        match &mut self.variant {
            CursorVariant::Text {
                list,
                idf,
                avg_field_len,
                tfs,
            } => {
                if list.decode_block_into(self.block_idx, &mut self.doc_ids, tfs) {
                    // Convert raw term frequencies to BM25 scores up front so
                    // score() is a plain buffer read.
                    self.scores.clear();
                    self.scores.reserve(tfs.len());
                    for &tf in tfs.iter() {
                        let tf = tf as f32;
                        // NOTE(review): `tf` is passed as both the frequency
                        // and the third argument (field length?); presumably
                        // per-doc field length isn't stored here — confirm
                        // against bm25_score's signature.
                        self.scores
                            .push(super::bm25_score(tf, *idf, tf, *avg_field_len));
                    }
                    self.pos = 0;
                    self.block_loaded = true;
                    Ok(true)
                } else {
                    // Decode failure past the end: mark exhausted.
                    self.exhausted = true;
                    Ok(false)
                }
            }
            CursorVariant::Sparse {
                si,
                query_weight,
                skip_start,
                block_data_offset,
                ..
            } => {
                let block = si
                    .load_block_direct(*skip_start, *block_data_offset, self.block_idx)
                    .await?;
                match block {
                    Some(b) => {
                        // Decode ids, ordinals, and pre-multiplied scores.
                        b.decode_doc_ids_into(&mut self.doc_ids);
                        b.decode_ordinals_into(&mut self.ordinals);
                        b.decode_scored_weights_into(*query_weight, &mut self.scores);
                        self.pos = 0;
                        self.block_loaded = true;
                        Ok(true)
                    }
                    None => {
                        self.exhausted = true;
                        Ok(false)
                    }
                }
            }
        }
    }
839
840    pub fn ensure_block_loaded_sync(&mut self) -> crate::Result<bool> {
841        if self.exhausted || self.block_loaded {
842            return Ok(!self.exhausted);
843        }
844        match &mut self.variant {
845            CursorVariant::Text {
846                list,
847                idf,
848                avg_field_len,
849                tfs,
850            } => {
851                if list.decode_block_into(self.block_idx, &mut self.doc_ids, tfs) {
852                    self.scores.clear();
853                    self.scores.reserve(tfs.len());
854                    for &tf in tfs.iter() {
855                        let tf = tf as f32;
856                        self.scores
857                            .push(super::bm25_score(tf, *idf, tf, *avg_field_len));
858                    }
859                    self.pos = 0;
860                    self.block_loaded = true;
861                    Ok(true)
862                } else {
863                    self.exhausted = true;
864                    Ok(false)
865                }
866            }
867            CursorVariant::Sparse {
868                si,
869                query_weight,
870                skip_start,
871                block_data_offset,
872                ..
873            } => {
874                let block =
875                    si.load_block_direct_sync(*skip_start, *block_data_offset, self.block_idx)?;
876                match block {
877                    Some(b) => {
878                        b.decode_doc_ids_into(&mut self.doc_ids);
879                        b.decode_ordinals_into(&mut self.ordinals);
880                        b.decode_scored_weights_into(*query_weight, &mut self.scores);
881                        self.pos = 0;
882                        self.block_loaded = true;
883                        Ok(true)
884                    }
885                    None => {
886                        self.exhausted = true;
887                        Ok(false)
888                    }
889                }
890            }
891        }
892    }
893
894    // ── Advance / Seek ──────────────────────────────────────────────────
895
896    pub async fn advance(&mut self) -> crate::Result<DocId> {
897        if self.exhausted {
898            return Ok(u32::MAX);
899        }
900        self.ensure_block_loaded().await?;
901        if self.exhausted {
902            return Ok(u32::MAX);
903        }
904        Ok(self.advance_pos())
905    }
906
907    pub fn advance_sync(&mut self) -> crate::Result<DocId> {
908        if self.exhausted {
909            return Ok(u32::MAX);
910        }
911        self.ensure_block_loaded_sync()?;
912        if self.exhausted {
913            return Ok(u32::MAX);
914        }
915        Ok(self.advance_pos())
916    }
917
918    pub async fn seek(&mut self, target: DocId) -> crate::Result<DocId> {
919        if let Some(doc) = self.seek_prepare(target) {
920            return Ok(doc);
921        }
922        self.ensure_block_loaded().await?;
923        if self.seek_finish(target) {
924            self.ensure_block_loaded().await?;
925        }
926        Ok(self.doc())
927    }
928
929    pub fn seek_sync(&mut self, target: DocId) -> crate::Result<DocId> {
930        if let Some(doc) = self.seek_prepare(target) {
931            return Ok(doc);
932        }
933        self.ensure_block_loaded_sync()?;
934        if self.seek_finish(target) {
935            self.ensure_block_loaded_sync()?;
936        }
937        Ok(self.doc())
938    }
939
    /// Shared first phase of `seek`/`seek_sync` (no I/O).
    ///
    /// Either resolves the seek entirely — returning `Some(doc)` — or
    /// positions `block_idx` on the block that should contain `target` and
    /// returns `None`, leaving the caller to load the block and finish with
    /// `seek_finish`.
    fn seek_prepare(&mut self, target: DocId) -> Option<DocId> {
        if self.exhausted {
            return Some(u32::MAX);
        }

        // Fast path: target is within the currently loaded block
        if self.block_loaded
            && let Some(&last) = self.doc_ids.last()
        {
            if last >= target && self.doc_ids[self.pos] < target {
                // Scan forward only from the current position.
                let remaining = &self.doc_ids[self.pos..];
                self.pos += crate::structures::simd::find_first_ge_u32(remaining, target);
                if self.pos >= self.doc_ids.len() {
                    // Ran off the end of the block: step to the next one.
                    self.block_idx += 1;
                    self.block_loaded = false;
                    if self.block_idx >= self.num_blocks {
                        self.exhausted = true;
                        return Some(u32::MAX);
                    }
                }
                // NOTE(review): when we just crossed into an unloaded block,
                // this relies on `doc()` being valid with block_loaded=false
                // (presumably served from skip metadata) — confirm.
                return Some(self.doc());
            }
            if self.doc_ids[self.pos] >= target {
                // Already at or past target — nothing to do.
                return Some(self.doc());
            }
        }

        // Seek to the block containing target
        let lo = match &self.variant {
            // Text: SIMD-accelerated 2-level seek (L1 + L0)
            CursorVariant::Text { list, .. } => match list.seek_block(target, self.block_idx) {
                Some(idx) => idx,
                None => {
                    self.exhausted = true;
                    return Some(u32::MAX);
                }
            },
            // Sparse: binary search on skip entries (lazy mmap reads)
            CursorVariant::Sparse { .. } => {
                // Invariant: first block whose last doc id is >= target.
                let mut lo = self.block_idx;
                let mut hi = self.num_blocks;
                while lo < hi {
                    let mid = lo + (hi - lo) / 2;
                    if self.block_last_doc(mid) < target {
                        lo = mid + 1;
                    } else {
                        hi = mid;
                    }
                }
                lo
            }
        };
        if lo >= self.num_blocks {
            self.exhausted = true;
            return Some(u32::MAX);
        }
        if lo != self.block_idx || !self.block_loaded {
            // Mark the found block as needing a (re)load by the caller.
            self.block_idx = lo;
            self.block_loaded = false;
        }
        None
    }
1002
1003    #[inline]
1004    fn seek_finish(&mut self, target: DocId) -> bool {
1005        if self.exhausted {
1006            return false;
1007        }
1008        self.pos = crate::structures::simd::find_first_ge_u32(&self.doc_ids, target);
1009        if self.pos >= self.doc_ids.len() {
1010            self.block_idx += 1;
1011            self.block_loaded = false;
1012            if self.block_idx >= self.num_blocks {
1013                self.exhausted = true;
1014                return false;
1015            }
1016            return true;
1017        }
1018        false
1019    }
1020}
1021
/// Macro to stamp out the Block-Max MaxScore loop for both async and sync paths.
///
/// `$ensure`, `$advance`, `$seek` are cursor method idents (async or _sync variants).
/// `$($aw:tt)*` captures `.await` for async or nothing for sync.
///
/// Each iteration: partition cursors into non-essential/essential by comparing
/// max-score prefix sums against the collector threshold, pick the minimum doc
/// among essential cursors, apply the filter predicate, then try to prune via
/// the conjunction bound and the (tighter) block-max bound before scoring.
macro_rules! bms_execute_loop {
    ($self:ident, $ensure:ident, $advance:ident, $seek:ident, $($aw:tt)*) => {{
        let n = $self.cursors.len();

        // Load first block for each cursor (ensures doc() returns real values)
        for cursor in &mut $self.cursors {
            cursor.$ensure() $($aw)* ?;
        }

        // Diagnostics only — reported via debug! at the end.
        let mut docs_scored = 0u64;
        let mut docs_skipped = 0u64;
        let mut blocks_skipped = 0u64;
        let mut conjunction_skipped = 0u64;
        // (ordinal, score) pairs for the current doc; reused across iterations.
        let mut ordinal_scores: Vec<(u16, f32)> = Vec::with_capacity(n * 2);

        loop {
            // Cursors [0, partition) are non-essential: even their combined
            // upper bound cannot lift a doc past the current threshold.
            let partition = $self.find_partition();
            if partition >= n {
                break;
            }

            // Find minimum doc_id across essential cursors
            let mut min_doc = u32::MAX;
            for i in partition..n {
                let doc = $self.cursors[i].doc();
                if doc < min_doc {
                    min_doc = doc;
                }
            }
            if min_doc == u32::MAX {
                break;
            }

            // --- Filter predicate check (before any scoring) ---
            if let Some(ref pred) = $self.predicate {
                if !pred(min_doc) {
                    // Advance essential cursors past this doc
                    for i in partition..n {
                        if $self.cursors[i].doc() == min_doc {
                            $self.cursors[i].$ensure() $($aw)* ?;
                            $self.cursors[i].$advance() $($aw)* ?;
                        }
                    }
                    docs_skipped += 1;
                    continue;
                }
            }

            // Upper bound contributed by all non-essential cursors combined.
            let non_essential_upper = if partition > 0 {
                $self.prefix_sums[partition - 1]
            } else {
                0.0
            };
            // heap_factor (clamped to [0, 1]) scales the threshold used for
            // pruning decisions.
            let adjusted_threshold = $self.collector.threshold() * $self.heap_factor;

            // --- Conjunction optimization ---
            // Bound using the per-list global maxima of the cursors actually
            // aligned on min_doc; skips the doc without touching block data.
            if $self.collector.len() >= $self.collector.k {
                let present_upper: f32 = (partition..n)
                    .filter(|&i| $self.cursors[i].doc() == min_doc)
                    .map(|i| $self.cursors[i].max_score)
                    .sum();

                if present_upper + non_essential_upper <= adjusted_threshold {
                    for i in partition..n {
                        if $self.cursors[i].doc() == min_doc {
                            $self.cursors[i].$ensure() $($aw)* ?;
                            $self.cursors[i].$advance() $($aw)* ?;
                        }
                    }
                    conjunction_skipped += 1;
                    continue;
                }
            }

            // --- Block-max pruning ---
            // Tighter bound using per-block maxima; on success the whole
            // current block of each aligned cursor is skipped.
            if $self.collector.len() >= $self.collector.k {
                let block_max_sum: f32 = (partition..n)
                    .filter(|&i| $self.cursors[i].doc() == min_doc)
                    .map(|i| $self.cursors[i].current_block_max_score())
                    .sum();

                if block_max_sum + non_essential_upper <= adjusted_threshold {
                    for i in partition..n {
                        if $self.cursors[i].doc() == min_doc {
                            $self.cursors[i].skip_to_next_block();
                            $self.cursors[i].$ensure() $($aw)* ?;
                        }
                    }
                    blocks_skipped += 1;
                    continue;
                }
            }

            // --- Score essential cursors ---
            // A cursor may yield several postings for the same doc
            // (presumably one per field ordinal — see grouping below).
            ordinal_scores.clear();
            for i in partition..n {
                if $self.cursors[i].doc() == min_doc {
                    $self.cursors[i].$ensure() $($aw)* ?;
                    while $self.cursors[i].doc() == min_doc {
                        ordinal_scores.push(($self.cursors[i].ordinal(), $self.cursors[i].score()));
                        $self.cursors[i].$advance() $($aw)* ?;
                    }
                }
            }

            // Early exit: even with the full non-essential upper bound this
            // doc cannot make the heap.
            let essential_total: f32 = ordinal_scores.iter().map(|(_, s)| *s).sum();
            if $self.collector.len() >= $self.collector.k
                && essential_total + non_essential_upper <= adjusted_threshold
            {
                docs_skipped += 1;
                continue;
            }

            // --- Score non-essential cursors (highest max_score first for early exit) ---
            let mut running_total = essential_total;
            for i in (0..partition).rev() {
                if $self.collector.len() >= $self.collector.k
                    && running_total + $self.prefix_sums[i] <= adjusted_threshold
                {
                    break;
                }

                let doc = $self.cursors[i].$seek(min_doc) $($aw)* ?;
                if doc == min_doc {
                    while $self.cursors[i].doc() == min_doc {
                        let s = $self.cursors[i].score();
                        running_total += s;
                        ordinal_scores.push(($self.cursors[i].ordinal(), s));
                        $self.cursors[i].$advance() $($aw)* ?;
                    }
                }
            }

            // --- Group by ordinal and insert ---
            // Fast path: single entry (common for single-valued fields) — skip sort + grouping
            if ordinal_scores.len() == 1 {
                let (ord, score) = ordinal_scores[0];
                if $self.collector.insert_with_ordinal(min_doc, score, ord) {
                    docs_scored += 1;
                } else {
                    docs_skipped += 1;
                }
            } else if !ordinal_scores.is_empty() {
                // Sort by ordinal so equal ordinals are adjacent; a 2-element
                // list only needs a conditional swap.
                if ordinal_scores.len() > 2 {
                    ordinal_scores.sort_unstable_by_key(|(ord, _)| *ord);
                } else if ordinal_scores[0].0 > ordinal_scores[1].0 {
                    ordinal_scores.swap(0, 1);
                }
                // Sum contiguous runs sharing an ordinal, insert one entry per run.
                let mut j = 0;
                while j < ordinal_scores.len() {
                    let current_ord = ordinal_scores[j].0;
                    let mut score = 0.0f32;
                    while j < ordinal_scores.len() && ordinal_scores[j].0 == current_ord {
                        score += ordinal_scores[j].1;
                        j += 1;
                    }
                    if $self
                        .collector
                        .insert_with_ordinal(min_doc, score, current_ord)
                    {
                        docs_scored += 1;
                    } else {
                        docs_skipped += 1;
                    }
                }
            }
        }

        let results: Vec<ScoredDoc> = $self
            .collector
            .into_sorted_results()
            .into_iter()
            .map(|(doc_id, score, ordinal)| ScoredDoc {
                doc_id,
                score,
                ordinal,
            })
            .collect();

        debug!(
            "MaxScoreExecutor: scored={}, skipped={}, blocks_skipped={}, conjunction_skipped={}, returned={}, top_score={:.4}",
            docs_scored,
            docs_skipped,
            blocks_skipped,
            conjunction_skipped,
            results.len(),
            results.first().map(|r| r.score).unwrap_or(0.0)
        );

        Ok(results)
    }};
}
1218
1219impl<'a> MaxScoreExecutor<'a> {
1220    /// Create a new executor from pre-built cursors.
1221    ///
1222    /// Cursors are sorted by max_score ascending (non-essential first) and
1223    /// prefix sums are computed for the MaxScore partitioning.
1224    pub(crate) fn new(mut cursors: Vec<TermCursor<'a>>, k: usize, heap_factor: f32) -> Self {
1225        // Sort by max_score ascending (non-essential first)
1226        cursors.sort_by(|a, b| {
1227            a.max_score
1228                .partial_cmp(&b.max_score)
1229                .unwrap_or(Ordering::Equal)
1230        });
1231
1232        let mut prefix_sums = Vec::with_capacity(cursors.len());
1233        let mut cumsum = 0.0f32;
1234        for c in &cursors {
1235            cumsum += c.max_score;
1236            prefix_sums.push(cumsum);
1237        }
1238
1239        debug!(
1240            "Creating MaxScoreExecutor: num_cursors={}, k={}, total_upper={:.4}, heap_factor={:.2}",
1241            cursors.len(),
1242            k,
1243            cumsum,
1244            heap_factor
1245        );
1246
1247        Self {
1248            cursors,
1249            prefix_sums,
1250            collector: ScoreCollector::new(k),
1251            heap_factor: heap_factor.clamp(0.0, 1.0),
1252            predicate: None,
1253        }
1254    }
1255
1256    /// Set a filter predicate that rejects documents before scoring.
1257    pub fn set_predicate(&mut self, predicate: Option<super::DocPredicate<'a>>) {
1258        self.predicate = predicate;
1259    }
1260
1261    /// Create an executor for sparse vector queries.
1262    ///
1263    /// Builds `TermCursor::Sparse` for each matched dimension.
1264    pub fn sparse(
1265        sparse_index: &'a crate::segment::SparseIndex,
1266        query_terms: Vec<(u32, f32)>,
1267        k: usize,
1268        heap_factor: f32,
1269    ) -> Self {
1270        let cursors: Vec<TermCursor<'a>> = query_terms
1271            .iter()
1272            .filter_map(|&(dim_id, qw)| {
1273                let (skip_start, skip_count, global_max, block_data_offset) =
1274                    sparse_index.get_skip_range_full(dim_id)?;
1275                Some(TermCursor::sparse(
1276                    sparse_index,
1277                    qw,
1278                    skip_start,
1279                    skip_count,
1280                    global_max,
1281                    block_data_offset,
1282                ))
1283            })
1284            .collect();
1285        Self::new(cursors, k, heap_factor)
1286    }
1287
1288    /// Create an executor for full-text BM25 queries.
1289    ///
1290    /// Builds `TermCursor::Text` for each posting list.
1291    pub fn text(
1292        posting_lists: Vec<(crate::structures::BlockPostingList, f32)>,
1293        avg_field_len: f32,
1294        k: usize,
1295    ) -> Self {
1296        let cursors: Vec<TermCursor<'a>> = posting_lists
1297            .into_iter()
1298            .map(|(pl, idf)| TermCursor::text(pl, idf, avg_field_len))
1299            .collect();
1300        Self::new(cursors, k, 1.0)
1301    }
1302
1303    #[inline]
1304    fn find_partition(&self) -> usize {
1305        let threshold = self.collector.threshold() * self.heap_factor;
1306        self.prefix_sums.partition_point(|&sum| sum <= threshold)
1307    }
1308
1309    /// Execute Block-Max MaxScore and return top-k results (async).
1310    pub async fn execute(mut self) -> crate::Result<Vec<ScoredDoc>> {
1311        if self.cursors.is_empty() {
1312            return Ok(Vec::new());
1313        }
1314        bms_execute_loop!(self, ensure_block_loaded, advance, seek, .await)
1315    }
1316
1317    /// Synchronous execution — works when all cursors are text or mmap-backed sparse.
1318    pub fn execute_sync(mut self) -> crate::Result<Vec<ScoredDoc>> {
1319        if self.cursors.is_empty() {
1320            return Ok(Vec::new());
1321        }
1322        bms_execute_loop!(self, ensure_block_loaded_sync, advance_sync, seek_sync,)
1323    }
1324}
1325
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_score_collector_basic() {
        let mut collector = ScoreCollector::new(3);

        // Fill to capacity; the threshold is the weakest retained score.
        for (doc, score) in [(1u32, 1.0f32), (2, 2.0), (3, 3.0)] {
            collector.insert(doc, score);
        }
        assert_eq!(collector.threshold(), 1.0);

        // A stronger doc evicts the weakest entry and raises the bar.
        collector.insert(4, 4.0);
        assert_eq!(collector.threshold(), 2.0);

        // Results come back in descending score order.
        let results = collector.into_sorted_results();
        let ids: Vec<_> = results.iter().map(|entry| entry.0).collect();
        assert_eq!(ids, vec![4, 3, 2]);
    }

    #[test]
    fn test_score_collector_threshold() {
        let mut collector = ScoreCollector::new(2);
        collector.insert(1, 5.0);
        collector.insert(2, 3.0);
        assert_eq!(collector.threshold(), 3.0);

        // Below the threshold: rejected without mutating the heap.
        assert!(!collector.would_enter(2.0));
        assert!(!collector.insert(3, 2.0));

        // Above the threshold: accepted, and the threshold rises.
        assert!(collector.would_enter(4.0));
        assert!(collector.insert(4, 4.0));
        assert_eq!(collector.threshold(), 4.0);
    }

    #[test]
    fn test_heap_entry_ordering() {
        // doc_ids 1..=3 paired with out-of-order scores.
        let mut heap: BinaryHeap<HeapEntry> = [3.0f32, 1.0, 2.0]
            .iter()
            .enumerate()
            .map(|(i, &score)| HeapEntry {
                doc_id: (i + 1) as DocId,
                score,
                ordinal: 0,
            })
            .collect();

        // Min-heap semantics: pops ascend by score.
        let mut popped = Vec::new();
        while let Some(entry) = heap.pop() {
            popped.push(entry.score);
        }
        assert_eq!(popped, vec![1.0, 2.0, 3.0]);
    }
}