
hermes_core/query/scoring.rs

//! Shared scoring abstractions for text and sparse vector search
//!
//! Provides common traits and utilities for efficient top-k retrieval:
//! - `ScoringIterator`: Common interface for posting list iteration with scoring
//! - `ScoreCollector`: Efficient min-heap for maintaining top-k results
//! - `MaxScoreExecutor`: Unified Block-Max MaxScore with conjunction optimization
//! - `BmpExecutor`: Block-at-a-time executor for learned sparse retrieval (12+ terms)
//! - `SparseMaxScoreExecutor`: Lazy Block-Max MaxScore for sparse retrieval (1-11 terms)
//! - `TextTermScorer` / `SparseTermScorer`: `ScoringIterator` implementations

use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::sync::Arc;

use log::{debug, trace};

use crate::DocId;
use crate::structures::BlockSparsePostingList;

/// Common interface for scoring iterators (text terms or sparse dimensions)
///
/// Abstracts the common operations needed for MaxScore top-k retrieval.
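///
/// # Example
///
/// A minimal sketch of driving a single scorer to exhaustion; the `drain`
/// helper is illustrative, not part of this module:
///
/// ```ignore
/// fn drain<S: ScoringIterator>(mut s: S, out: &mut ScoreCollector) {
///     while !s.is_exhausted() {
///         out.insert(s.doc(), s.score());
///         s.advance();
///     }
/// }
/// ```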
pub trait ScoringIterator {
    /// Current document ID (u32::MAX if exhausted)
    fn doc(&self) -> DocId;

    /// Current ordinal for multi-valued fields (default 0)
    fn ordinal(&mut self) -> u16 {
        0
    }

    /// Advance to next document, returns new doc ID
    fn advance(&mut self) -> DocId;

    /// Seek to first document >= target, returns new doc ID
    fn seek(&mut self, target: DocId) -> DocId;

    /// Check if iterator is exhausted
    fn is_exhausted(&self) -> bool {
        self.doc() == u32::MAX
    }

    /// Score contribution for current document
    fn score(&self) -> f32;

    /// Maximum possible score for this term/dimension (global upper bound)
    fn max_score(&self) -> f32;

    /// Current block's maximum score upper bound (for block-level pruning)
    fn current_block_max_score(&self) -> f32;

    /// Skip to the next block, returning the first doc_id in the new block.
    /// Used for block-max pruning when current block can't beat threshold.
    /// Default implementation just advances (no block-level skipping).
    fn skip_to_next_block(&mut self) -> DocId {
        self.advance()
    }
}

/// Entry for top-k min-heap
#[derive(Clone, Copy)]
pub struct HeapEntry {
    pub doc_id: DocId,
    pub score: f32,
    pub ordinal: u16,
}

impl PartialEq for HeapEntry {
    fn eq(&self, other: &Self) -> bool {
        self.score == other.score && self.doc_id == other.doc_id
    }
}

impl Eq for HeapEntry {}

impl Ord for HeapEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        // Min-heap: lower scores come first (to be evicted)
        other
            .score
            .partial_cmp(&self.score)
            .unwrap_or(Ordering::Equal)
            .then_with(|| self.doc_id.cmp(&other.doc_id))
    }
}

impl PartialOrd for HeapEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

/// Efficient top-k collector using min-heap
///
/// Maintains the k highest-scoring documents using a min-heap where the
/// lowest score is at the top for O(1) threshold lookup and O(log k) eviction.
/// No deduplication - caller must ensure each doc_id is inserted only once.
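///
/// # Example
///
/// A small usage sketch (doc IDs and scores are illustrative):
///
/// ```ignore
/// let mut top = ScoreCollector::new(2);
/// top.insert(7, 1.5);
/// top.insert(3, 0.9);
/// top.insert(9, 2.1); // heap is full, so doc 3 (lowest score) is evicted
/// assert_eq!(top.threshold(), 1.5);
/// let ranked = top.into_sorted_results(); // [(9, 2.1, 0), (7, 1.5, 0)]
/// ```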
pub struct ScoreCollector {
    /// Min-heap of top-k entries (lowest score at top for eviction)
    heap: BinaryHeap<HeapEntry>,
    pub k: usize,
    /// Cached threshold: avoids repeated heap.peek() in hot loops.
    /// Updated only when the heap changes (insert/pop).
    cached_threshold: f32,
}

impl ScoreCollector {
    /// Create a new collector for top-k results
    pub fn new(k: usize) -> Self {
        // Cap capacity to avoid allocation overflow for very large k
        let capacity = k.saturating_add(1).min(1_000_000);
        Self {
            heap: BinaryHeap::with_capacity(capacity),
            k,
            cached_threshold: 0.0,
        }
    }

    /// Current score threshold (minimum score to enter top-k)
    #[inline]
    pub fn threshold(&self) -> f32 {
        self.cached_threshold
    }

    /// Recompute cached threshold from heap state
    #[inline]
    fn update_threshold(&mut self) {
        self.cached_threshold = if self.heap.len() >= self.k {
            self.heap.peek().map(|e| e.score).unwrap_or(0.0)
        } else {
            0.0
        };
    }

    /// Insert a document score. Returns true if inserted in top-k.
    /// Caller must ensure each doc_id is inserted only once.
    #[inline]
    pub fn insert(&mut self, doc_id: DocId, score: f32) -> bool {
        self.insert_with_ordinal(doc_id, score, 0)
    }

    /// Insert a document score with ordinal. Returns true if inserted in top-k.
    /// Caller must ensure each doc_id is inserted only once.
    #[inline]
    pub fn insert_with_ordinal(&mut self, doc_id: DocId, score: f32, ordinal: u16) -> bool {
        if self.heap.len() < self.k {
            self.heap.push(HeapEntry {
                doc_id,
                score,
                ordinal,
            });
            self.update_threshold();
            true
        } else if score > self.cached_threshold {
            self.heap.push(HeapEntry {
                doc_id,
                score,
                ordinal,
            });
            self.heap.pop(); // Remove lowest
            self.update_threshold();
            true
        } else {
            false
        }
    }

    /// Check if a score could potentially enter top-k
    #[inline]
    pub fn would_enter(&self, score: f32) -> bool {
        self.heap.len() < self.k || score > self.cached_threshold
    }

    /// Get number of documents collected so far
    #[inline]
    pub fn len(&self) -> usize {
        self.heap.len()
    }

    /// Check if collector is empty
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.heap.is_empty()
    }

    /// Convert to sorted top-k results (descending by score)
    pub fn into_sorted_results(self) -> Vec<(DocId, f32, u16)> {
        let heap_vec = self.heap.into_vec();
        let mut results: Vec<(DocId, f32, u16)> = Vec::with_capacity(heap_vec.len());
        for e in heap_vec {
            results.push((e.doc_id, e.score, e.ordinal));
        }

        // Sort by score descending, then doc_id ascending
        results.sort_by(|a, b| {
            b.1.partial_cmp(&a.1)
                .unwrap_or(Ordering::Equal)
                .then_with(|| a.0.cmp(&b.0))
        });

        results
    }
}

/// Search result from MaxScore execution
#[derive(Debug, Clone, Copy)]
pub struct ScoredDoc {
    pub doc_id: DocId,
    pub score: f32,
    /// Ordinal for multi-valued fields (which vector in the field matched)
    pub ordinal: u16,
}

/// Unified Block-Max MaxScore executor for top-k retrieval
///
/// Combines three optimizations from the literature into one executor:
/// 1. **MaxScore partitioning** (Turtle & Flood 1995): terms split into essential
///    (must check) and non-essential (only scored if candidate is promising)
/// 2. **Block-max pruning** (Ding & Suel 2011): skip blocks where per-block
///    upper bounds can't beat the current threshold
/// 3. **Conjunction optimization** (Lucene/Grand 2023): progressively intersect
///    essential terms as threshold rises, skipping docs that lack enough terms
///
/// Works with any type implementing `ScoringIterator` (text or sparse).
/// A single unified executor that performs well across all query lengths.
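///
/// # Example
///
/// A sketch of running the executor over sparse scorers; the posting lists
/// and query weights (`dims`) are assumed to come from the caller:
///
/// ```ignore
/// let scorers: Vec<SparseTermScorer> = dims
///     .iter()
///     .map(|(posting_list, weight)| SparseTermScorer::new(posting_list, *weight))
///     .collect();
/// let top_10 = MaxScoreExecutor::new(scorers, 10).execute();
/// ```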
pub struct MaxScoreExecutor<S: ScoringIterator> {
    /// Scorers sorted by max_score ascending (non-essential first)
    scorers: Vec<S>,
    /// Cumulative max_score prefix sums: prefix_sums[i] = sum(max_score[0..=i])
    prefix_sums: Vec<f32>,
    /// Top-k collector
    collector: ScoreCollector,
    /// Heap factor for approximate search (SEISMIC-style)
    /// - 1.0 = exact search (default)
    /// - 0.8 = approximate, faster with minor recall loss
    heap_factor: f32,
}

impl<S: ScoringIterator> MaxScoreExecutor<S> {
    /// Create a new executor with exact search (heap_factor = 1.0)
    pub fn new(scorers: Vec<S>, k: usize) -> Self {
        Self::with_heap_factor(scorers, k, 1.0)
    }

    /// Create a new executor with approximate search
    ///
    /// `heap_factor` controls the trade-off between speed and recall:
    /// - 1.0 = exact search
    /// - 0.8 = ~20% faster, minor recall loss
    /// - 0.5 = much faster, noticeable recall loss
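    ///
    /// ```ignore
    /// // Approximate search sketch: prune against 80% of the exact
    /// // threshold in exchange for more aggressive skipping.
    /// let exec = MaxScoreExecutor::with_heap_factor(scorers, 10, 0.8);
    /// ```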
    pub fn with_heap_factor(mut scorers: Vec<S>, k: usize, heap_factor: f32) -> Self {
        // Sort scorers by max_score ascending (non-essential terms first)
        scorers.sort_by(|a, b| {
            a.max_score()
                .partial_cmp(&b.max_score())
                .unwrap_or(Ordering::Equal)
        });

        // Compute prefix sums of max_scores
        let mut prefix_sums = Vec::with_capacity(scorers.len());
        let mut cumsum = 0.0f32;
        for s in &scorers {
            cumsum += s.max_score();
            prefix_sums.push(cumsum);
        }

        debug!(
            "Creating MaxScoreExecutor: num_scorers={}, k={}, total_upper={:.4}, heap_factor={:.2}",
            scorers.len(),
            k,
            cumsum,
            heap_factor
        );

        Self {
            scorers,
            prefix_sums,
            collector: ScoreCollector::new(k),
            heap_factor: heap_factor.clamp(0.0, 1.0),
        }
    }

    /// Find partition point: [0..partition) = non-essential, [partition..n) = essential
    /// Non-essential terms have cumulative max_score <= threshold
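    ///
    /// For example, with max_scores [0.5, 1.0, 2.0] the prefix sums are
    /// [0.5, 1.5, 3.5]; a threshold of 1.5 gives partition = 2, leaving
    /// only the last term essential.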
    #[inline]
    fn find_partition(&self) -> usize {
        let threshold = self.collector.threshold() * self.heap_factor;
        // Binary search: prefix_sums is monotonically increasing
        self.prefix_sums.partition_point(|&sum| sum <= threshold)
    }

    /// Execute Block-Max MaxScore and return top-k results
    ///
    /// Algorithm:
    /// 1. Partition terms into essential/non-essential based on max_score
    /// 2. Find min_doc across essential scorers
    /// 3. Conjunction check: skip if not enough essential terms present
    /// 4. Block-max check: skip if block upper bounds can't beat threshold
    /// 5. Score essential scorers, check if non-essential scoring is needed
    /// 6. Score non-essential scorers, group by ordinal, insert results
    pub fn execute(mut self) -> Vec<ScoredDoc> {
        if self.scorers.is_empty() {
            debug!("MaxScoreExecutor: no scorers, returning empty results");
            return Vec::new();
        }

        let n = self.scorers.len();
        let mut docs_scored = 0u64;
        let mut docs_skipped = 0u64;
        let mut blocks_skipped = 0u64;
        let mut conjunction_skipped = 0u64;

        // Pre-allocate scratch buffers outside the loop
        let mut ordinal_scores: Vec<(u16, f32)> = Vec::with_capacity(n * 2);

        loop {
            let partition = self.find_partition();

            // If all terms are non-essential, we're done
            if partition >= n {
                debug!("BlockMaxScore: all terms non-essential, early termination");
                break;
            }

            // Single fused pass over essential scorers: find min_doc and
            // accumulate conjunction/block-max upper bounds simultaneously.
            // This replaces 3 separate iterations with 1, reducing cache misses.
            let mut min_doc = u32::MAX;
            let mut present_upper = 0.0f32;
            let mut block_max_sum = 0.0f32;
            for i in partition..n {
                let doc = self.scorers[i].doc();
                if doc < min_doc {
                    min_doc = doc;
                    // New min_doc: reset accumulators to this scorer only
                    present_upper = self.scorers[i].max_score();
                    block_max_sum = self.scorers[i].current_block_max_score();
                } else if doc == min_doc {
                    present_upper += self.scorers[i].max_score();
                    block_max_sum += self.scorers[i].current_block_max_score();
                }
            }

            if min_doc == u32::MAX {
                break; // All essential scorers exhausted
            }

            let non_essential_upper = if partition > 0 {
                self.prefix_sums[partition - 1]
            } else {
                0.0
            };
            let adjusted_threshold = self.collector.threshold() * self.heap_factor;

            // --- Conjunction optimization (Lucene-style) ---
            // Skip min_doc when even the essential terms present at it, taken
            // at their global max scores, plus all non-essential upper bounds,
            // cannot beat the current threshold.
            if self.collector.len() >= self.collector.k
                && present_upper + non_essential_upper <= adjusted_threshold
            {
                for i in partition..n {
                    if self.scorers[i].doc() == min_doc {
                        self.scorers[i].advance();
                    }
                }
                conjunction_skipped += 1;
                continue;
            }

            // --- Block-max pruning ---
            // If block-max sum + non-essential upper can't beat threshold, skip blocks.
            if self.collector.len() >= self.collector.k
                && block_max_sum + non_essential_upper <= adjusted_threshold
            {
                for i in partition..n {
                    if self.scorers[i].doc() == min_doc {
                        self.scorers[i].skip_to_next_block();
                    }
                }
                blocks_skipped += 1;
                continue;
            }

            // --- Score essential scorers ---
            // Drain all entries for min_doc from each essential scorer
            ordinal_scores.clear();

            for i in partition..n {
                if self.scorers[i].doc() == min_doc {
                    while self.scorers[i].doc() == min_doc {
                        ordinal_scores.push((self.scorers[i].ordinal(), self.scorers[i].score()));
                        self.scorers[i].advance();
                    }
                }
            }

            // Check if essential score + non-essential upper could beat threshold
            let essential_total: f32 = ordinal_scores.iter().map(|(_, s)| *s).sum();

            if self.collector.len() >= self.collector.k
                && essential_total + non_essential_upper <= adjusted_threshold
            {
                docs_skipped += 1;
                continue;
            }

            // --- Score non-essential scorers ---
            for i in 0..partition {
                let doc = self.scorers[i].seek(min_doc);
                if doc == min_doc {
                    while self.scorers[i].doc() == min_doc {
                        ordinal_scores.push((self.scorers[i].ordinal(), self.scorers[i].score()));
                        self.scorers[i].advance();
                    }
                }
            }

            // --- Group by ordinal and insert ---
            // Fast path: single entry (common for single-valued fields), skip sort + grouping
            if ordinal_scores.len() == 1 {
                let (ord, score) = ordinal_scores[0];
                trace!(
                    "Doc {}: ordinal={}, score={:.4}, threshold={:.4}",
                    min_doc, ord, score, adjusted_threshold
                );
                if self.collector.insert_with_ordinal(min_doc, score, ord) {
                    docs_scored += 1;
                } else {
                    docs_skipped += 1;
                }
            } else if !ordinal_scores.is_empty() {
                if ordinal_scores.len() > 2 {
                    ordinal_scores.sort_unstable_by_key(|(ord, _)| *ord);
                } else if ordinal_scores[0].0 > ordinal_scores[1].0 {
                    ordinal_scores.swap(0, 1);
                }
                let mut j = 0;
                while j < ordinal_scores.len() {
                    let current_ord = ordinal_scores[j].0;
                    let mut score = 0.0f32;
                    while j < ordinal_scores.len() && ordinal_scores[j].0 == current_ord {
                        score += ordinal_scores[j].1;
                        j += 1;
                    }

                    trace!(
                        "Doc {}: ordinal={}, score={:.4}, threshold={:.4}",
                        min_doc, current_ord, score, adjusted_threshold
                    );

                    if self
                        .collector
                        .insert_with_ordinal(min_doc, score, current_ord)
                    {
                        docs_scored += 1;
                    } else {
                        docs_skipped += 1;
                    }
                }
            }
        }

        let results: Vec<ScoredDoc> = self
            .collector
            .into_sorted_results()
            .into_iter()
            .map(|(doc_id, score, ordinal)| ScoredDoc {
                doc_id,
                score,
                ordinal,
            })
            .collect();

        debug!(
            "MaxScoreExecutor completed: scored={}, skipped={}, blocks_skipped={}, conjunction_skipped={}, returned={}, top_score={:.4}",
            docs_scored,
            docs_skipped,
            blocks_skipped,
            conjunction_skipped,
            results.len(),
            results.first().map(|r| r.score).unwrap_or(0.0)
        );

        results
    }
}

/// Scorer for full-text terms using MaxScore optimization
///
/// Wraps a `BlockPostingList` with BM25 parameters to implement `ScoringIterator`.
/// Enables MaxScore pruning for efficient top-k retrieval in OR queries.
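///
/// # Example
///
/// A construction sketch; `posting_list`, `num_docs`, `doc_freq`, and
/// `avg_field_len` are assumed to come from the index, and the idf shown is
/// one common BM25 variant, not necessarily the one this crate uses:
///
/// ```ignore
/// let idf = (1.0 + (num_docs - doc_freq + 0.5) / (doc_freq + 0.5)).ln();
/// let scorer = TextTermScorer::new(posting_list, idf, avg_field_len);
/// ```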
pub struct TextTermScorer {
    /// Iterator over the posting list (owned)
    iter: crate::structures::BlockPostingIterator<'static>,
    /// IDF component for BM25
    idf: f32,
    /// Average field length for BM25 normalization
    avg_field_len: f32,
    /// Pre-computed max score (using max_tf from posting list)
    max_score: f32,
}

impl TextTermScorer {
    /// Create a new text term scorer with BM25 parameters
    pub fn new(
        posting_list: crate::structures::BlockPostingList,
        idf: f32,
        avg_field_len: f32,
    ) -> Self {
        // Compute max score using actual max_tf from posting list
        let max_tf = posting_list.max_tf() as f32;
        let doc_count = posting_list.doc_count();
        let max_score = super::bm25_upper_bound(max_tf.max(1.0), idf);

        debug!(
            "Created TextTermScorer: doc_count={}, max_tf={:.0}, idf={:.4}, avg_field_len={:.2}, max_score={:.4}",
            doc_count, max_tf, idf, avg_field_len, max_score
        );

        Self {
            iter: posting_list.into_iterator(),
            idf,
            avg_field_len,
            max_score,
        }
    }
}

impl ScoringIterator for TextTermScorer {
    #[inline]
    fn doc(&self) -> DocId {
        self.iter.doc()
    }

    #[inline]
    fn advance(&mut self) -> DocId {
        self.iter.advance()
    }

    #[inline]
    fn seek(&mut self, target: DocId) -> DocId {
        self.iter.seek(target)
    }

    #[inline]
    fn score(&self) -> f32 {
        let tf = self.iter.term_freq() as f32;
        // Use tf as proxy for doc length (common approximation when field lengths aren't stored)
        super::bm25_score(tf, self.idf, tf, self.avg_field_len)
    }

    #[inline]
    fn max_score(&self) -> f32 {
        self.max_score
    }

    #[inline]
    fn current_block_max_score(&self) -> f32 {
        // Use per-block max_tf for tighter block-max bounds
        let block_max_tf = self.iter.current_block_max_tf() as f32;
        super::bm25_upper_bound(block_max_tf.max(1.0), self.idf)
    }

    #[inline]
    fn skip_to_next_block(&mut self) -> DocId {
        self.iter.skip_to_next_block()
    }
}

/// Scorer for sparse vector dimensions
///
/// Wraps a `BlockSparsePostingList` with a query weight to implement `ScoringIterator`.
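///
/// # Example
///
/// A usage sketch; the posting list comes from the index and 0.83 is an
/// illustrative query weight:
///
/// ```ignore
/// let scorer = SparseTermScorer::new(&posting_list, 0.83);
/// // Each posting contributes query_weight * stored_weight to the dot product.
/// ```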
pub struct SparseTermScorer<'a> {
    /// Iterator over the posting list
    iter: crate::structures::BlockSparsePostingIterator<'a>,
    /// Query weight for this dimension
    query_weight: f32,
    /// Pre-computed |query_weight| to avoid repeated .abs() in hot paths
    abs_query_weight: f32,
    /// Global max score (|query_weight| * global_max_weight)
    max_score: f32,
}

impl<'a> SparseTermScorer<'a> {
    /// Create a new sparse term scorer
    ///
    /// Note: Assumes positive weights for MaxScore upper bound calculation.
    /// For negative query weights, uses absolute value to ensure valid upper bound.
    pub fn new(posting_list: &'a BlockSparsePostingList, query_weight: f32) -> Self {
        // Upper bound must account for sign: |query_weight| * max_weight
        // This ensures the bound is valid regardless of weight sign
        let abs_qw = query_weight.abs();
        let max_score = abs_qw * posting_list.global_max_weight();
        Self {
            iter: posting_list.iterator(),
            query_weight,
            abs_query_weight: abs_qw,
            max_score,
        }
    }

    /// Create from Arc reference (for use with shared posting lists)
    pub fn from_arc(posting_list: &'a Arc<BlockSparsePostingList>, query_weight: f32) -> Self {
        Self::new(posting_list.as_ref(), query_weight)
    }
}

impl ScoringIterator for SparseTermScorer<'_> {
    #[inline]
    fn doc(&self) -> DocId {
        self.iter.doc()
    }

    #[inline]
    fn ordinal(&mut self) -> u16 {
        self.iter.ordinal()
    }

    #[inline]
    fn advance(&mut self) -> DocId {
        self.iter.advance()
    }

    #[inline]
    fn seek(&mut self, target: DocId) -> DocId {
        self.iter.seek(target)
    }

    #[inline]
    fn score(&self) -> f32 {
        // Dot product contribution: query_weight * stored_weight
        self.query_weight * self.iter.weight()
    }

    #[inline]
    fn max_score(&self) -> f32 {
        self.max_score
    }

    #[inline]
    fn current_block_max_score(&self) -> f32 {
        self.iter
            .current_block_max_contribution(self.abs_query_weight)
    }

    #[inline]
    fn skip_to_next_block(&mut self) -> DocId {
        self.iter.skip_to_next_block()
    }
}

/// Block-Max Pruning (BMP) executor for learned sparse retrieval
///
/// Processes blocks in score-descending order using a priority queue.
/// Best for queries with many terms (12+), like SPLADE expansions.
/// Uses document accumulators (FxHashMap) instead of per-term iterators.
///
/// **Memory-efficient**: Only skip entries (block metadata) are kept in memory.
/// Actual block data is loaded on-demand via mmap range reads during execution.
///
/// Reference: Mallia et al., "Faster Learned Sparse Retrieval with
/// Block-Max Pruning" (SIGIR 2024)
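///
/// # Example
///
/// An execution sketch; the sparse index and the `(dim_id, weight)` query
/// expansion are assumed to come from the caller:
///
/// ```ignore
/// let exec = BmpExecutor::new(&sparse_index, query_terms, 10, 1.0);
/// let hits = exec.execute().await?;
/// ```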
pub struct BmpExecutor<'a> {
    /// Sparse index for on-demand block loading
    sparse_index: &'a crate::segment::SparseIndex,
    /// Query terms: (dim_id, query_weight) for each matched dimension
    query_terms: Vec<(u32, f32)>,
    /// Number of results to return
    k: usize,
    /// Heap factor for approximate search
    heap_factor: f32,
}

/// Superblock size: group S consecutive blocks into one priority queue entry.
/// Reduces heap operations by S× (e.g. 8× fewer push/pop for S=8).
const BMP_SUPERBLOCK_SIZE: usize = 8;

/// Megablock size: group M superblocks into one outer priority queue entry.
/// Two-level pruning: megablock-level (coarse) → superblock-level (fine).
/// Reduces outer heap operations by M× compared to single-level superblocks.
const BMP_MEGABLOCK_SIZE: usize = 16;
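
// Worked example of the two-level grouping: a dimension with 100 blocks forms
// ceil(100 / 8) = 13 superblocks, which form ceil(13 / 16) = 1 megablock, so
// the outer heap sees one entry instead of 100.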

/// Superblock entry (stored per-term, not in the heap directly)
struct BmpSuperBlock {
    /// Upper bound contribution of this superblock (sum of constituent blocks)
    contribution: f32,
    /// First block index in this superblock
    block_start: usize,
    /// Number of blocks in this superblock (1..=BMP_SUPERBLOCK_SIZE)
    block_count: usize,
}

/// Entry in the BMP outer priority queue: represents a megablock (group of superblocks)
struct BmpMegaBlockEntry {
    /// Upper bound contribution of this megablock (sum of constituent superblocks)
    contribution: f32,
    /// Index into query_terms
    term_idx: usize,
    /// First superblock index within term_superblocks[term_idx]
    sb_start: usize,
    /// Number of superblocks in this megablock (1..=BMP_MEGABLOCK_SIZE)
    sb_count: usize,
}

impl PartialEq for BmpMegaBlockEntry {
    fn eq(&self, other: &Self) -> bool {
        self.contribution == other.contribution
    }
}

impl Eq for BmpMegaBlockEntry {}

impl Ord for BmpMegaBlockEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        // Max-heap: higher contributions come first
        self.contribution
            .partial_cmp(&other.contribution)
            .unwrap_or(Ordering::Equal)
    }
}

impl PartialOrd for BmpMegaBlockEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl<'a> BmpExecutor<'a> {
    /// Create a new BMP executor with lazy block loading
    ///
    /// `query_terms` should contain only dimensions that exist in the index.
    /// Block metadata (skip entries) is read from the sparse index directly.
    pub fn new(
        sparse_index: &'a crate::segment::SparseIndex,
        query_terms: Vec<(u32, f32)>,
        k: usize,
        heap_factor: f32,
    ) -> Self {
        Self {
            sparse_index,
            query_terms,
            k,
            heap_factor: heap_factor.clamp(0.0, 1.0),
        }
    }

    /// Execute BMP and return top-k results
    ///
    /// Builds the priority queue from skip entries (already in memory),
    /// then loads blocks on-demand via mmap range reads as they are visited.
    ///
    /// Uses a hybrid accumulator: flat `Vec<f32>` for single-ordinal (ordinal=0)
    /// entries with O(1) insert, FxHashMap fallback for multi-ordinal entries.
    pub async fn execute(self) -> crate::Result<Vec<ScoredDoc>> {
        use rustc_hash::FxHashMap;

        if self.query_terms.is_empty() {
            return Ok(Vec::new());
        }

        let num_terms = self.query_terms.len();
        let si = self.sparse_index;

        // Two-level queue construction:
        // 1. Build superblocks per term (flat Vecs)
        // 2. Group superblocks into megablocks, push to outer BinaryHeap
        let mut term_superblocks: Vec<Vec<BmpSuperBlock>> = Vec::with_capacity(num_terms);
        let mut term_skip_starts: Vec<usize> = Vec::with_capacity(num_terms);
        let mut global_min_doc = u32::MAX;
        let mut global_max_doc = 0u32;
        let mut total_remaining = 0.0f32;

        for &(dim_id, qw) in &self.query_terms {
            let mut term_skip_start = 0usize;
            let mut superblocks = Vec::new();

            let abs_qw = qw.abs();
            if let Some((skip_start, skip_count, _global_max)) = si.get_skip_range(dim_id) {
                term_skip_start = skip_start;
                // Step 1: Build superblock entries
                let mut sb_start = 0;
                while sb_start < skip_count {
                    let sb_count = (skip_count - sb_start).min(BMP_SUPERBLOCK_SIZE);
                    let mut sb_contribution = 0.0f32;
                    for j in 0..sb_count {
                        let skip = si.read_skip_entry(skip_start + sb_start + j);
                        sb_contribution += abs_qw * skip.max_weight;
                        global_min_doc = global_min_doc.min(skip.first_doc);
                        global_max_doc = global_max_doc.max(skip.last_doc);
                    }
                    total_remaining += sb_contribution;
                    superblocks.push(BmpSuperBlock {
                        contribution: sb_contribution,
                        block_start: sb_start,
                        block_count: sb_count,
                    });
                    sb_start += sb_count;
                }
            }
            term_skip_starts.push(term_skip_start);
            term_superblocks.push(superblocks);
        }

        // Step 2: Group superblocks into megablocks and build outer priority queue
        let mut mega_queue: BinaryHeap<BmpMegaBlockEntry> = BinaryHeap::new();
        for (term_idx, superblocks) in term_superblocks.iter().enumerate() {
            let mut mb_start = 0;
            while mb_start < superblocks.len() {
                let mb_count = (superblocks.len() - mb_start).min(BMP_MEGABLOCK_SIZE);
                let mb_contribution: f32 = superblocks[mb_start..mb_start + mb_count]
                    .iter()
                    .map(|sb| sb.contribution)
                    .sum();
                mega_queue.push(BmpMegaBlockEntry {
                    contribution: mb_contribution,
                    term_idx,
                    sb_start: mb_start,
                    sb_count: mb_count,
                });
                mb_start += mb_count;
            }
        }

        // Hybrid accumulator: flat array for ordinal=0, FxHashMap for multi-ordinal
        let doc_range = if global_max_doc >= global_min_doc {
            (global_max_doc - global_min_doc + 1) as usize
        } else {
            0
        };
        // Use flat array if range is reasonable (< 256K docs)
        let use_flat = doc_range > 0 && doc_range <= 256 * 1024;
        let mut flat_scores: Vec<f32> = if use_flat {
            vec![0.0; doc_range]
        } else {
            Vec::new()
        };
        // Dirty list: track which doc offsets were touched (avoids scanning full array)
        let mut dirty: Vec<u32> = if use_flat {
            Vec::with_capacity(4096)
        } else {
            Vec::new()
        };
        // FxHashMap fallback for multi-ordinal entries or when flat array is too large
        let mut multi_ord_accumulators: FxHashMap<u64, f32> = FxHashMap::default();
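        // Key layout for multi_ord_accumulators: (doc_id << 16) | ordinal.
        // For example, doc_id 5 with ordinal 2 packs to 0x0005_0002; the
        // result collection below unpacks with (key >> 16) and (key & 0xFFFF).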

        let mut blocks_processed = 0u64;
        let mut blocks_skipped = 0u64;

        // Incremental top-k tracker for threshold
        let mut top_k = ScoreCollector::new(self.k);

        // Reusable decode buffers
        let mut doc_ids_buf: Vec<u32> = Vec::with_capacity(256);
        let mut weights_buf: Vec<f32> = Vec::with_capacity(256);
        let mut ordinals_buf: Vec<u16> = Vec::with_capacity(256);

        // Warm-start: ensure at least one megablock per term is processed before
        // enabling early termination. This diversifies initial scores across terms,
        // giving the top-k heap a better starting threshold for pruning.
        let mut terms_warmed = vec![false; num_terms];
        let mut warmup_remaining = self.k.min(num_terms);

        // Two-level processing: outer loop pops megablocks, inner loop iterates superblocks
        while let Some(mega) = mega_queue.pop() {
            total_remaining -= mega.contribution;

            // Track warm-start progress: count unique terms seen
            if !terms_warmed[mega.term_idx] {
                terms_warmed[mega.term_idx] = true;
                warmup_remaining = warmup_remaining.saturating_sub(1);
            }

            // Megablock-level early termination (coarse)
            if warmup_remaining == 0 {
                let adjusted_threshold = top_k.threshold() * self.heap_factor;
                if top_k.len() >= self.k && total_remaining <= adjusted_threshold {
                    // Count remaining blocks across all unprocessed megablocks
                    let remaining_blocks: u64 = mega_queue
                        .iter()
                        .map(|m| {
                            let sbs =
                                &term_superblocks[m.term_idx][m.sb_start..m.sb_start + m.sb_count];
                            sbs.iter().map(|sb| sb.block_count as u64).sum::<u64>()
                        })
                        .sum();
                    blocks_skipped += remaining_blocks;
                    debug!(
                        "BMP early termination after {} blocks: remaining={:.4} <= threshold={:.4}",
                        blocks_processed, total_remaining, adjusted_threshold
                    );
                    break;
                }
            }

            let dim_id = self.query_terms[mega.term_idx].0;
            let qw = self.query_terms[mega.term_idx].1;
            let abs_qw = qw.abs();
            let skip_start = term_skip_starts[mega.term_idx];

            // Inner loop: iterate superblocks within this megablock
            for sb in term_superblocks[mega.term_idx]
                .iter()
                .skip(mega.sb_start)
                .take(mega.sb_count)
            {
                // Superblock-level pruning (fine-grained)
                if top_k.len() >= self.k {
                    let adjusted_threshold = top_k.threshold() * self.heap_factor;
                    if sb.contribution + total_remaining <= adjusted_threshold {
                        blocks_skipped += sb.block_count as u64;
                        continue;
                    }
                }

                // Coalesced superblock loading: single mmap read for all blocks
                let sb_blocks = si
                    .get_blocks_range(dim_id, sb.block_start, sb.block_count)
                    .await?;

                let adjusted_threshold2 = top_k.threshold() * self.heap_factor;

                // Track dirty start for deferred top_k insertion at superblock boundary
                let dirty_start = dirty.len();

                for (blk_offset, block) in sb_blocks.into_iter().enumerate() {
                    let blk_idx = sb.block_start + blk_offset;

                    // Per-block pruning within superblock
                    if top_k.len() >= self.k {
                        let skip = si.read_skip_entry(skip_start + blk_idx);
                        let blk_contrib = abs_qw * skip.max_weight;
                        if blk_contrib + total_remaining <= adjusted_threshold2 {
                            blocks_skipped += 1;
                            continue;
                        }
                    }

                    block.decode_doc_ids_into(&mut doc_ids_buf);

                    // Fast path: ordinal=0 + flat accumulator → fused decode+scatter
                    if block.header.ordinal_bits == 0 && use_flat {
                        block.accumulate_scored_weights(
                            qw,
                            &doc_ids_buf,
                            &mut flat_scores,
                            global_min_doc,
                            &mut dirty,
                        );
                    } else {
                        block.decode_scored_weights_into(qw, &mut weights_buf);
                        let count = block.header.count as usize;

                        block.decode_ordinals_into(&mut ordinals_buf);
                        if use_flat {
                            for i in 0..count {
                                let doc_id = doc_ids_buf[i];
                                let ordinal = ordinals_buf[i];
                                let score_contribution = weights_buf[i];

                                if ordinal == 0 {
                                    let off = (doc_id - global_min_doc) as usize;
                                    if flat_scores[off] == 0.0 {
                                        dirty.push(doc_id);
                                    }
                                    flat_scores[off] += score_contribution;
                                } else {
                                    let key = (doc_id as u64) << 16 | ordinal as u64;
                                    let acc = multi_ord_accumulators.entry(key).or_insert(0.0);
                                    *acc += score_contribution;
                                    top_k.insert_with_ordinal(doc_id, *acc, ordinal);
                                }
                            }
                        } else {
                            for i in 0..count {
                                let key = (doc_ids_buf[i] as u64) << 16 | ordinals_buf[i] as u64;
                                let acc = multi_ord_accumulators.entry(key).or_insert(0.0);
                                *acc += weights_buf[i];
                                top_k.insert_with_ordinal(doc_ids_buf[i], *acc, ordinals_buf[i]);
                            }
                        }
                    }

                    blocks_processed += 1;
                }

                // Deferred top_k insertion at superblock boundary:
                // Scan only newly-dirty entries (first-touch docs from this superblock).
                // Eliminates ~90% of per-posting heap operations.
                for &doc_id in &dirty[dirty_start..] {
                    let off = (doc_id - global_min_doc) as usize;
                    top_k.insert_with_ordinal(doc_id, flat_scores[off], 0);
                }
            }
        }

        // Collect results from both accumulators
        let mut scored: Vec<ScoredDoc> = Vec::new();

        let num_accumulators = if use_flat {
            // Flat array entries (ordinal=0)
            scored.reserve(dirty.len() + multi_ord_accumulators.len());
            for &doc_id in &dirty {
                let off = (doc_id - global_min_doc) as usize;
                let score = flat_scores[off];
                if score > 0.0 {
                    scored.push(ScoredDoc {
                        doc_id,
                        score,
                        ordinal: 0,
                    });
                }
            }
            dirty.len() + multi_ord_accumulators.len()
        } else {
            multi_ord_accumulators.len()
        };

        // Multi-ordinal entries
        scored.extend(
            multi_ord_accumulators
                .into_iter()
                .map(|(key, score)| ScoredDoc {
                    doc_id: (key >> 16) as DocId,
                    score,
                    ordinal: (key & 0xFFFF) as u16,
                }),
        );

        scored.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
        scored.truncate(self.k);
        let results = scored;

        debug!(
            "BmpExecutor completed: blocks_processed={}, blocks_skipped={}, accumulators={}, flat={}, returned={}, top_score={:.4}",
            blocks_processed,
            blocks_skipped,
            num_accumulators,
            use_flat,
            results.len(),
            results.first().map(|r| r.score).unwrap_or(0.0)
        );

        Ok(results)
    }
}

/// Lazy Block-Max MaxScore executor for sparse retrieval (1-11 terms)
///
/// Combines BlockMaxScore's cursor-based document-at-a-time traversal with
/// BMP's lazy block loading. Skip entries (already in memory via zero-copy
/// mmap) drive block-level navigation; actual block data is loaded on-demand
/// only when the cursor visits that block.
///
/// For typical 1-11 term queries with MaxScore pruning, many blocks are
/// skipped entirely, so lazy loading avoids the I/O and decode cost for those
/// blocks. This hybrid achieves BMP's memory efficiency with BlockMaxScore's
/// superior pruning for few-term queries.
pub struct SparseMaxScoreExecutor<'a> {
    sparse_index: &'a crate::segment::SparseIndex,
    cursors: Vec<LazyTermCursor>,
    prefix_sums: Vec<f32>,
    collector: ScoreCollector,
    heap_factor: f32,
}

/// Per-term cursor state for lazy block loading
struct LazyTermCursor {
    query_weight: f32,
    /// Pre-computed |query_weight| to avoid repeated .abs() in hot paths
    abs_query_weight: f32,
    max_score: f32,
    /// Index of first skip entry in the SparseIndex skip section (zero-alloc)
    skip_start: usize,
    /// Number of skip entries (blocks) for this dimension
    skip_count: usize,
    /// Base byte offset for block data (pre-resolved, avoids dim_id lookup per load)
    block_data_offset: u64,
    /// Current block index (0-based relative to this dimension's blocks)
    block_idx: usize,
    /// Decoded block data (loaded on demand, reused across seeks)
    doc_ids: Vec<u32>,
    ordinals: Vec<u16>,
    weights: Vec<f32>,
    /// Position within current decoded block
    pos: usize,
    /// Whether block at block_idx is decoded into doc_ids/ordinals/weights
    block_loaded: bool,
    exhausted: bool,
}

impl LazyTermCursor {
    fn new(
        query_weight: f32,
        skip_start: usize,
        skip_count: usize,
        global_max_weight: f32,
        block_data_offset: u64,
    ) -> Self {
        let exhausted = skip_count == 0;
        let abs_qw = query_weight.abs();
        Self {
            query_weight,
            abs_query_weight: abs_qw,
            max_score: abs_qw * global_max_weight,
            skip_start,
            skip_count,
            block_data_offset,
            block_idx: 0,
            doc_ids: Vec::with_capacity(256),
            ordinals: Vec::with_capacity(256),
            weights: Vec::with_capacity(256),
            pos: 0,
            block_loaded: false,
            exhausted,
        }
    }

    // --- Shared non-I/O helpers ---

    /// Decode a loaded block into the cursor's buffers
    #[inline]
    fn decode_block(&mut self, block: crate::structures::SparseBlock) {
        block.decode_doc_ids_into(&mut self.doc_ids);
        block.decode_ordinals_into(&mut self.ordinals);
        block.decode_scored_weights_into(self.query_weight, &mut self.weights);
        self.pos = 0;
        self.block_loaded = true;
    }

    /// Handle a loaded block result (Some/None). Returns Ok(true) if block was loaded.
    #[inline]
    fn handle_block_result(
        &mut self,
        block: Option<crate::structures::SparseBlock>,
    ) -> crate::Result<bool> {
        match block {
            Some(b) => {
                self.decode_block(b);
                Ok(true)
            }
            None => {
                self.exhausted = true;
                Ok(false)
            }
        }
    }

    /// Advance position within current block, moving to next block if needed.
    /// Does NOT load the next block (lazy). Returns current doc.
    #[inline]
    fn advance_pos(&mut self) -> DocId {
        self.pos += 1;
        if self.pos >= self.doc_ids.len() {
            self.block_idx += 1;
            self.block_loaded = false;
            if self.block_idx >= self.skip_count {
                self.exhausted = true;
                return u32::MAX;
            }
        }
        self.doc()
    }

    /// Seek preparation: handle in-block seek and binary search on skip entries.
    /// Returns `Ok(Some(doc))` if seek resolved without needing a block load,
    /// or `Ok(None)` if a block load is needed (block_idx updated, block_loaded = false).
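    ///
    /// For example, if the remaining blocks have last_docs [99, 199, 299]
    /// and the target is 150, the binary search lands on the block with
    /// last_doc 199, which is then loaded by the caller.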
    fn seek_prepare(
        &mut self,
        si: &crate::segment::SparseIndex,
        target: DocId,
    ) -> crate::Result<Option<DocId>> {
        if self.exhausted {
            return Ok(Some(u32::MAX));
        }

        // If block is loaded and target is within current block range
        if self.block_loaded
            && let Some(&last) = self.doc_ids.last()
        {
            if last >= target && self.doc_ids[self.pos] < target {
                let remaining = &self.doc_ids[self.pos..];
                let offset = crate::structures::simd::find_first_ge_u32(remaining, target);
                self.pos += offset;
                if self.pos >= self.doc_ids.len() {
                    self.block_idx += 1;
                    self.block_loaded = false;
                    if self.block_idx >= self.skip_count {
                        self.exhausted = true;
                        return Ok(Some(u32::MAX));
                    }
                }
                return Ok(Some(self.doc()));
            }
            if self.doc_ids[self.pos] >= target {
                return Ok(Some(self.doc()));
            }
        }

        // Binary search on skip entries: find first block where last_doc >= target.
        let mut lo = self.block_idx;
        let mut hi = self.skip_count;
        while lo < hi {
            let mid = lo + (hi - lo) / 2;
            if si.read_skip_entry(self.skip_start + mid).last_doc < target {
                lo = mid + 1;
            } else {
                hi = mid;
            }
        }
        if lo >= self.skip_count {
            self.exhausted = true;
            return Ok(Some(u32::MAX));
        }
        if lo != self.block_idx || !self.block_loaded {
            self.block_idx = lo;
            self.block_loaded = false;
        }
        // Need a block load; caller handles sync/async
        Ok(None)
    }

    /// Finish seek after block load: position within the loaded block.
    /// Returns `true` if a second block load is needed.
    #[inline]
    fn seek_finish(&mut self, target: DocId) -> bool {
        if self.exhausted {
            return false;
        }
        self.pos = crate::structures::simd::find_first_ge_u32(&self.doc_ids, target);
        if self.pos >= self.doc_ids.len() {
            self.block_idx += 1;
            self.block_loaded = false;
            if self.block_idx >= self.skip_count {
                self.exhausted = true;
                return false;
            }
            return true; // need second block load
        }
        false
    }

    // --- Async methods (delegate to shared helpers) ---

    async fn ensure_block_loaded(
        &mut self,
        si: &crate::segment::SparseIndex,
    ) -> crate::Result<bool> {
        if self.exhausted || self.block_loaded {
            return Ok(!self.exhausted);
        }
        let block = si
            .load_block_direct(self.skip_start, self.block_data_offset, self.block_idx)
            .await?;
        self.handle_block_result(block)
    }

    async fn advance(&mut self, si: &crate::segment::SparseIndex) -> crate::Result<DocId> {
        if self.exhausted {
            return Ok(u32::MAX);
        }
        self.ensure_block_loaded(si).await?;
        if self.exhausted {
            return Ok(u32::MAX);
        }
        Ok(self.advance_pos())
    }

    async fn seek(
        &mut self,
        si: &crate::segment::SparseIndex,
        target: DocId,
    ) -> crate::Result<DocId> {
        if let Some(doc) = self.seek_prepare(si, target)? {
            return Ok(doc);
        }
        self.ensure_block_loaded(si).await?;
        if self.seek_finish(target) {
            self.ensure_block_loaded(si).await?;
        }
        Ok(self.doc())
    }

    // --- Sync methods (delegate to same shared helpers) ---

    fn ensure_block_loaded_sync(
        &mut self,
        si: &crate::segment::SparseIndex,
    ) -> crate::Result<bool> {
        if self.exhausted || self.block_loaded {
            return Ok(!self.exhausted);
        }
        let block =
            si.load_block_direct_sync(self.skip_start, self.block_data_offset, self.block_idx)?;
        self.handle_block_result(block)
    }

    fn advance_sync(&mut self, si: &crate::segment::SparseIndex) -> crate::Result<DocId> {
        if self.exhausted {
            return Ok(u32::MAX);
        }
        self.ensure_block_loaded_sync(si)?;
        if self.exhausted {
            return Ok(u32::MAX);
        }
        Ok(self.advance_pos())
    }

    fn seek_sync(
        &mut self,
        si: &crate::segment::SparseIndex,
        target: DocId,
    ) -> crate::Result<DocId> {
        if let Some(doc) = self.seek_prepare(si, target)? {
            return Ok(doc);
        }
        self.ensure_block_loaded_sync(si)?;
        if self.seek_finish(target) {
            self.ensure_block_loaded_sync(si)?;
        }
        Ok(self.doc())
    }

    // --- Read-only accessors (shared) ---

    #[inline]
    fn doc_with_si(&self, si: &crate::segment::SparseIndex) -> DocId {
        if self.exhausted {
            return u32::MAX;
        }
        if !self.block_loaded {
            return si
                .read_skip_entry(self.skip_start + self.block_idx)
                .first_doc;
        }
        self.doc_ids.get(self.pos).copied().unwrap_or(u32::MAX)
    }

    #[inline]
    fn doc(&self) -> DocId {
        if self.exhausted {
            return u32::MAX;
        }
        if self.block_loaded {
            return self.doc_ids.get(self.pos).copied().unwrap_or(u32::MAX);
        }
        u32::MAX
    }

    #[inline]
    fn ordinal(&self) -> u16 {
        if !self.block_loaded {
            return 0;
        }
        self.ordinals.get(self.pos).copied().unwrap_or(0)
    }

    #[inline]
    fn score(&self) -> f32 {
        if !self.block_loaded {
            return 0.0;
        }
        self.weights.get(self.pos).copied().unwrap_or(0.0)
    }

    #[inline]
    fn current_block_max_score(&self, si: &crate::segment::SparseIndex) -> f32 {
        if self.exhausted || self.block_idx >= self.skip_count {
            return 0.0;
        }
        self.abs_query_weight
            * si.read_skip_entry(self.skip_start + self.block_idx)
                .max_weight
    }

    /// Skip to next block without loading it (for block-max pruning)
    fn skip_to_next_block(&mut self, si: &crate::segment::SparseIndex) -> DocId {
        if self.exhausted {
            return u32::MAX;
        }
        self.block_idx += 1;
        self.block_loaded = false;
        if self.block_idx >= self.skip_count {
            self.exhausted = true;
            return u32::MAX;
        }
        si.read_skip_entry(self.skip_start + self.block_idx)
            .first_doc
    }
}

/// Macro to stamp out the Block-Max MaxScore loop for both async and sync paths.
///
/// `$ensure`, `$advance`, `$seek` are cursor method idents (async or _sync variants).
/// `$($aw:tt)*` captures `.await` for async or nothing for sync.
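///
/// Invocation sketch (illustrative; presumably how the async and sync
/// execute paths stamp this loop out):
///
/// ```ignore
/// bms_execute_loop!(self, ensure_block_loaded, advance, seek, .await)
/// bms_execute_loop!(self, ensure_block_loaded_sync, advance_sync, seek_sync,)
/// ```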
1393macro_rules! bms_execute_loop {
1394    ($self:ident, $ensure:ident, $advance:ident, $seek:ident, $($aw:tt)*) => {{
1395        let n = $self.cursors.len();
1396        let si = $self.sparse_index;
1397
1398        // Load first block for each cursor (ensures doc() returns real values)
1399        for cursor in &mut $self.cursors {
1400            cursor.$ensure(si) $($aw)* ?;
1401        }
1402
1403        let mut docs_scored = 0u64;
1404        let mut docs_skipped = 0u64;
1405        let mut blocks_skipped = 0u64;
1406        let mut blocks_loaded = 0u64;
1407        let mut conjunction_skipped = 0u64;
1408        let mut ordinal_scores: Vec<(u16, f32)> = Vec::with_capacity(n * 2);
1409
1410        loop {
1411            let partition = $self.find_partition();
1412            if partition >= n {
1413                break;
1414            }
1415
1416            // Find minimum doc_id across essential cursors
1417            let mut min_doc = u32::MAX;
1418            for i in partition..n {
1419                let doc = $self.cursors[i].doc_with_si(si);
1420                if doc < min_doc {
1421                    min_doc = doc;
1422                }
1423            }
1424            if min_doc == u32::MAX {
1425                break;
1426            }
1427
1428            let non_essential_upper = if partition > 0 {
1429                $self.prefix_sums[partition - 1]
1430            } else {
1431                0.0
1432            };
1433            let adjusted_threshold = $self.collector.threshold() * $self.heap_factor;
1434
1435            // --- Conjunction optimization ---
1436            if $self.collector.len() >= $self.collector.k {
1437                let present_upper: f32 = (partition..n)
1438                    .filter(|&i| $self.cursors[i].doc_with_si(si) == min_doc)
1439                    .map(|i| $self.cursors[i].max_score)
1440                    .sum();
1441
1442                if present_upper + non_essential_upper <= adjusted_threshold {
1443                    for i in partition..n {
1444                        if $self.cursors[i].doc_with_si(si) == min_doc {
1445                            $self.cursors[i].$ensure(si) $($aw)* ?;
1446                            $self.cursors[i].$advance(si) $($aw)* ?;
1447                            blocks_loaded += u64::from($self.cursors[i].block_loaded);
1448                        }
1449                    }
1450                    conjunction_skipped += 1;
1451                    continue;
1452                }
1453            }
1454
1455            // --- Block-max pruning ---
            if $self.collector.len() >= $self.collector.k {
                let block_max_sum: f32 = (partition..n)
                    .filter(|&i| $self.cursors[i].doc_with_si(si) == min_doc)
                    .map(|i| $self.cursors[i].current_block_max_score(si))
                    .sum();

                if block_max_sum + non_essential_upper <= adjusted_threshold {
                    for i in partition..n {
                        if $self.cursors[i].doc_with_si(si) == min_doc {
                            $self.cursors[i].skip_to_next_block(si);
                            $self.cursors[i].$ensure(si) $($aw)* ?;
                            blocks_loaded += u64::from($self.cursors[i].block_loaded);
                        }
                    }
                    blocks_skipped += 1;
                    continue;
                }
            }

            // --- Score essential cursors ---
            ordinal_scores.clear();
            for i in partition..n {
                if $self.cursors[i].doc_with_si(si) == min_doc {
                    $self.cursors[i].$ensure(si) $($aw)* ?;
                    while $self.cursors[i].doc_with_si(si) == min_doc {
                        ordinal_scores.push(($self.cursors[i].ordinal(), $self.cursors[i].score()));
                        $self.cursors[i].$advance(si) $($aw)* ?;
                    }
                }
            }

            let essential_total: f32 = ordinal_scores.iter().map(|(_, s)| *s).sum();
            if $self.collector.len() >= $self.collector.k
                && essential_total + non_essential_upper <= adjusted_threshold
            {
                docs_skipped += 1;
                continue;
            }

            // --- Score non-essential cursors (highest max_score first for early exit) ---
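            // Walking the non-essential list from the highest max_score down
            // tightens the `running_total + prefix_sums[i]` bound fastest:
            // once the remaining cursors 0..=i cannot lift the doc over the
            // threshold, the rest are skipped without seeking.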
            let mut running_total = essential_total;
            for i in (0..partition).rev() {
                if $self.collector.len() >= $self.collector.k
                    && running_total + $self.prefix_sums[i] <= adjusted_threshold
                {
                    break;
                }

                let doc = $self.cursors[i].$seek(si, min_doc) $($aw)* ?;
                if doc == min_doc {
                    while $self.cursors[i].doc_with_si(si) == min_doc {
                        let s = $self.cursors[i].score();
                        running_total += s;
                        ordinal_scores.push(($self.cursors[i].ordinal(), s));
                        $self.cursors[i].$advance(si) $($aw)* ?;
                    }
                }
            }

            // --- Group by ordinal and insert ---
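            // For multi-valued fields each (doc, ordinal) pair is its own
            // candidate: scores are summed within an ordinal, then inserted
            // separately per ordinal.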
            // Fast path: single entry (common for single-valued fields) — skip sort + grouping
            if ordinal_scores.len() == 1 {
                let (ord, score) = ordinal_scores[0];
                if $self.collector.insert_with_ordinal(min_doc, score, ord) {
                    docs_scored += 1;
                } else {
                    docs_skipped += 1;
                }
            } else if !ordinal_scores.is_empty() {
                if ordinal_scores.len() > 2 {
                    ordinal_scores.sort_unstable_by_key(|(ord, _)| *ord);
                } else if ordinal_scores[0].0 > ordinal_scores[1].0 {
                    ordinal_scores.swap(0, 1);
                }
                let mut j = 0;
                while j < ordinal_scores.len() {
                    let current_ord = ordinal_scores[j].0;
                    let mut score = 0.0f32;
                    while j < ordinal_scores.len() && ordinal_scores[j].0 == current_ord {
                        score += ordinal_scores[j].1;
                        j += 1;
                    }
                    if $self
                        .collector
                        .insert_with_ordinal(min_doc, score, current_ord)
                    {
                        docs_scored += 1;
                    } else {
                        docs_skipped += 1;
                    }
                }
            }
        }

        let results: Vec<ScoredDoc> = $self
            .collector
            .into_sorted_results()
            .into_iter()
            .map(|(doc_id, score, ordinal)| ScoredDoc {
                doc_id,
                score,
                ordinal,
            })
            .collect();

        debug!(
            "SparseMaxScoreExecutor: scored={}, skipped={}, blocks_skipped={}, blocks_loaded={}, conjunction_skipped={}, returned={}, top_score={:.4}",
            docs_scored,
            docs_skipped,
            blocks_skipped,
            blocks_loaded,
            conjunction_skipped,
            results.len(),
            results.first().map(|r| r.score).unwrap_or(0.0)
        );

        Ok(results)
    }};
}

impl<'a> SparseMaxScoreExecutor<'a> {
    /// Create a new lazy executor
    ///
    /// `query_terms` should contain only dimensions present in the index.
    /// Skip entries are read from the zero-copy mmap section (no I/O).
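    ///
    /// Usage sketch (hypothetical dimension IDs and weights; `heap_factor`
    /// scales the pruning threshold, and 1.0 uses the collector threshold
    /// unchanged):
    /// ```ignore
    /// let executor = SparseMaxScoreExecutor::new(
    ///     &sparse_index,
    ///     vec![(42, 0.8), (1337, 0.3)], // (dim_id, query_weight)
    ///     10,                           // top-k
    ///     1.0,                          // heap_factor
    /// );
    /// let hits = executor.execute_sync()?;
    /// ```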
    pub fn new(
        sparse_index: &'a crate::segment::SparseIndex,
        query_terms: Vec<(u32, f32)>,
        k: usize,
        heap_factor: f32,
    ) -> Self {
        let mut cursors: Vec<LazyTermCursor> = query_terms
            .iter()
            .filter_map(|&(dim_id, qw)| {
                let (skip_start, skip_count, global_max, block_data_offset) =
                    sparse_index.get_skip_range_full(dim_id)?;
                Some(LazyTermCursor::new(
                    qw,
                    skip_start,
                    skip_count,
                    global_max,
                    block_data_offset,
                ))
            })
            .collect();

        // Sort by max_score ascending (non-essential first)
        cursors.sort_by(|a, b| {
            a.max_score
                .partial_cmp(&b.max_score)
                .unwrap_or(Ordering::Equal)
        });

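        // prefix_sums[i] = cumulative max_score of cursors 0..=i; these are
        // the MaxScore upper bounds consulted by `find_partition` and the
        // pruning checks in the main loop.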
        let mut prefix_sums = Vec::with_capacity(cursors.len());
        let mut cumsum = 0.0f32;
        for c in &cursors {
            cumsum += c.max_score;
            prefix_sums.push(cumsum);
        }

        debug!(
            "Creating SparseMaxScoreExecutor: num_terms={}, k={}, total_upper={:.4}, heap_factor={:.2}",
            cursors.len(),
            k,
            cumsum,
            heap_factor
        );

        Self {
            sparse_index,
            cursors,
            prefix_sums,
            collector: ScoreCollector::new(k),
            heap_factor: heap_factor.clamp(0.0, 1.0),
        }
    }

    #[inline]
    fn find_partition(&self) -> usize {
        let threshold = self.collector.threshold() * self.heap_factor;
        // Binary search: prefix_sums is monotonically increasing
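        // e.g. prefix_sums = [0.4, 1.1, 2.0], threshold = 1.0:
        // partition_point returns 1, so cursor 0 (cumulative bound 0.4) is
        // non-essential and cursors 1.. are essential.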
        self.prefix_sums.partition_point(|&sum| sum <= threshold)
    }

    /// Execute lazy Block-Max MaxScore and return top-k results
    pub async fn execute(mut self) -> crate::Result<Vec<ScoredDoc>> {
        if self.cursors.is_empty() {
            return Ok(Vec::new());
        }
        bms_execute_loop!(self, ensure_block_loaded, advance, seek, .await)
    }

    /// Synchronous execution — only works when the sparse index has an Inline (mmap/RAM) handle.
    /// Bypasses all async overhead for mmap-backed indexes.
    pub fn execute_sync(mut self) -> crate::Result<Vec<ScoredDoc>> {
        if self.cursors.is_empty() {
            return Ok(Vec::new());
        }
        bms_execute_loop!(self, ensure_block_loaded_sync, advance_sync, seek_sync,)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_score_collector_basic() {
        let mut collector = ScoreCollector::new(3);

        collector.insert(1, 1.0);
        collector.insert(2, 2.0);
        collector.insert(3, 3.0);
        assert_eq!(collector.threshold(), 1.0);

        collector.insert(4, 4.0);
        assert_eq!(collector.threshold(), 2.0);

        let results = collector.into_sorted_results();
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].0, 4); // Highest score
        assert_eq!(results[1].0, 3);
        assert_eq!(results[2].0, 2);
    }

    #[test]
    fn test_score_collector_threshold() {
        let mut collector = ScoreCollector::new(2);

        collector.insert(1, 5.0);
        collector.insert(2, 3.0);
        assert_eq!(collector.threshold(), 3.0);

        // Should not enter (score too low)
        assert!(!collector.would_enter(2.0));
        assert!(!collector.insert(3, 2.0));

        // Should enter (score high enough)
        assert!(collector.would_enter(4.0));
        assert!(collector.insert(4, 4.0));
        assert_eq!(collector.threshold(), 4.0);
    }

    #[test]
    fn test_heap_entry_ordering() {
        let mut heap = BinaryHeap::new();
        heap.push(HeapEntry {
            doc_id: 1,
            score: 3.0,
            ordinal: 0,
        });
        heap.push(HeapEntry {
            doc_id: 2,
            score: 1.0,
            ordinal: 0,
        });
        heap.push(HeapEntry {
            doc_id: 3,
            score: 2.0,
            ordinal: 0,
        });

        // Min-heap: lowest score should come out first
        assert_eq!(heap.pop().unwrap().score, 1.0);
        assert_eq!(heap.pop().unwrap().score, 2.0);
        assert_eq!(heap.pop().unwrap().score, 3.0);
    }
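
    // Sketch of a tie-breaking check: with equal scores, `HeapEntry`'s
    // `Ord` falls back to doc_id, so the larger doc_id pops first.
    #[test]
    fn test_heap_entry_score_tie_breaks_on_doc_id() {
        let mut heap = BinaryHeap::new();
        for doc_id in [7, 3, 5] {
            heap.push(HeapEntry {
                doc_id,
                score: 1.0,
                ordinal: 0,
            });
        }
        assert_eq!(heap.pop().unwrap().doc_id, 7);
        assert_eq!(heap.pop().unwrap().doc_id, 5);
        assert_eq!(heap.pop().unwrap().doc_id, 3);
    }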
1723}