// hermes_core/query/scoring.rs

//! Shared scoring abstractions for text and sparse vector search
//!
//! Provides common traits and utilities for efficient top-k retrieval:
//! - `ScoringIterator`: Common interface for posting list iteration with scoring
//! - `ScoreCollector`: Efficient min-heap collector for maintaining top-k results
//! - `BlockMaxScoreExecutor`: Unified Block-Max MaxScore with conjunction optimization
//! - `LazyBlockMaxScoreExecutor`: Block-Max MaxScore with lazy block loading (1-11 terms)
//! - `BmpExecutor`: Block-at-a-time executor for learned sparse retrieval (12+ terms)
//! - `TextTermScorer` / `SparseTermScorer`: `ScoringIterator` implementations for
//!   text terms and sparse vector dimensions

use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::sync::Arc;

use log::{debug, trace};

use crate::DocId;
use crate::structures::BlockSparsePostingList;

/// Common interface for scoring iterators (text terms or sparse dimensions)
///
/// Abstracts the common operations needed for WAND-style top-k retrieval.
pub trait ScoringIterator {
    /// Current document ID (u32::MAX if exhausted)
    fn doc(&self) -> DocId;

    /// Current ordinal for multi-valued fields (default 0)
    fn ordinal(&self) -> u16 {
        0
    }

    /// Advance to next document, returns new doc ID
    fn advance(&mut self) -> DocId;

    /// Seek to first document >= target, returns new doc ID
    fn seek(&mut self, target: DocId) -> DocId;

    /// Check if iterator is exhausted
    fn is_exhausted(&self) -> bool {
        self.doc() == u32::MAX
    }

    /// Score contribution for current document
    fn score(&self) -> f32;

    /// Maximum possible score for this term/dimension (global upper bound)
    fn max_score(&self) -> f32;

    /// Current block's maximum score upper bound (for block-level pruning)
    fn current_block_max_score(&self) -> f32;

    /// Skip to the next block, returning the first doc_id in the new block.
    /// Used for block-max WAND optimization when current block can't beat threshold.
    /// Default implementation just advances (no block-level skipping).
    fn skip_to_next_block(&mut self) -> DocId {
        self.advance()
    }
}
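
// Illustrative only: a minimal in-memory `ScoringIterator` showing the
// contract the executors below rely on. `VecScorer` and its postings are
// hypothetical test fixtures, not part of the public API; real
// implementations (`TextTermScorer`, `SparseTermScorer`) wrap
// block-compressed posting lists.
#[cfg(test)]
mod scoring_iterator_example {
    use super::*;

    /// Toy scorer over (doc_id, score) pairs sorted by doc_id.
    struct VecScorer {
        postings: Vec<(DocId, f32)>,
        pos: usize,
        max_score: f32,
    }

    impl VecScorer {
        fn new(postings: Vec<(DocId, f32)>) -> Self {
            let max_score = postings.iter().map(|&(_, s)| s).fold(0.0, f32::max);
            Self {
                postings,
                pos: 0,
                max_score,
            }
        }
    }

    impl ScoringIterator for VecScorer {
        fn doc(&self) -> DocId {
            self.postings.get(self.pos).map(|&(d, _)| d).unwrap_or(u32::MAX)
        }

        fn advance(&mut self) -> DocId {
            self.pos += 1;
            self.doc()
        }

        fn seek(&mut self, target: DocId) -> DocId {
            while self.doc() < target {
                self.pos += 1;
            }
            self.doc()
        }

        fn score(&self) -> f32 {
            self.postings.get(self.pos).map(|&(_, s)| s).unwrap_or(0.0)
        }

        fn max_score(&self) -> f32 {
            self.max_score
        }

        // One "block" spanning the whole list, so the global bound doubles
        // as the block bound and block-max pruning degenerates gracefully.
        fn current_block_max_score(&self) -> f32 {
            self.max_score
        }
    }

    #[test]
    fn toy_scorers_produce_top_k() {
        let a = VecScorer::new(vec![(1, 1.0), (3, 2.0), (7, 0.5)]);
        let b = VecScorer::new(vec![(3, 1.5), (7, 3.25)]);
        let results = BlockMaxScoreExecutor::new(vec![a, b], 2).execute();
        assert_eq!(results[0].doc_id, 7); // 0.5 + 3.25 = 3.75
        assert_eq!(results[1].doc_id, 3); // 2.0 + 1.5 = 3.5
    }
}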

/// Entry for top-k min-heap
#[derive(Clone, Copy)]
pub struct HeapEntry {
    pub doc_id: DocId,
    pub score: f32,
    pub ordinal: u16,
}

impl PartialEq for HeapEntry {
    fn eq(&self, other: &Self) -> bool {
        self.score == other.score && self.doc_id == other.doc_id
    }
}

impl Eq for HeapEntry {}

impl Ord for HeapEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        // Min-heap: lower scores come first (to be evicted)
        other
            .score
            .partial_cmp(&self.score)
            .unwrap_or(Ordering::Equal)
            .then_with(|| self.doc_id.cmp(&other.doc_id))
    }
}

impl PartialOrd for HeapEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

/// Efficient top-k collector using min-heap
///
/// Maintains the k highest-scoring documents using a min-heap where the
/// lowest score is at the top for O(1) threshold lookup and O(log k) eviction.
/// No deduplication - caller must ensure each doc_id is inserted only once.
pub struct ScoreCollector {
    /// Min-heap of top-k entries (lowest score at top for eviction)
    heap: BinaryHeap<HeapEntry>,
    pub k: usize,
}

impl ScoreCollector {
    /// Create a new collector for top-k results
    pub fn new(k: usize) -> Self {
        // Cap capacity to avoid allocation overflow for very large k
        let capacity = k.saturating_add(1).min(1_000_000);
        Self {
            heap: BinaryHeap::with_capacity(capacity),
            k,
        }
    }

    /// Current score threshold (minimum score to enter top-k)
    #[inline]
    pub fn threshold(&self) -> f32 {
        if self.heap.len() >= self.k {
            self.heap.peek().map(|e| e.score).unwrap_or(0.0)
        } else {
            0.0
        }
    }

    /// Insert a document score. Returns true if inserted in top-k.
    /// Caller must ensure each doc_id is inserted only once.
    #[inline]
    pub fn insert(&mut self, doc_id: DocId, score: f32) -> bool {
        self.insert_with_ordinal(doc_id, score, 0)
    }

    /// Insert a document score with ordinal. Returns true if inserted in top-k.
    /// Caller must ensure each doc_id is inserted only once.
    #[inline]
    pub fn insert_with_ordinal(&mut self, doc_id: DocId, score: f32, ordinal: u16) -> bool {
        if self.heap.len() < self.k {
            self.heap.push(HeapEntry {
                doc_id,
                score,
                ordinal,
            });
            true
        } else if score > self.threshold() {
            self.heap.push(HeapEntry {
                doc_id,
                score,
                ordinal,
            });
            self.heap.pop(); // Remove lowest
            true
        } else {
            false
        }
    }

    /// Check if a score could potentially enter top-k
    #[inline]
    pub fn would_enter(&self, score: f32) -> bool {
        self.heap.len() < self.k || score > self.threshold()
    }

    /// Get number of documents collected so far
    #[inline]
    pub fn len(&self) -> usize {
        self.heap.len()
    }

    /// Check if collector is empty
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.heap.is_empty()
    }

    /// Convert to sorted top-k results (descending by score)
    pub fn into_sorted_results(self) -> Vec<(DocId, f32, u16)> {
        let heap_vec = self.heap.into_vec();
        let mut results: Vec<(DocId, f32, u16)> = Vec::with_capacity(heap_vec.len());
        for e in heap_vec {
            results.push((e.doc_id, e.score, e.ordinal));
        }

        // Sort by score descending, then doc_id ascending
        results.sort_by(|a, b| {
            b.1.partial_cmp(&a.1)
                .unwrap_or(Ordering::Equal)
                .then_with(|| a.0.cmp(&b.0))
        });

        results
    }
}

/// Search result from WAND execution
#[derive(Debug, Clone, Copy)]
pub struct ScoredDoc {
    pub doc_id: DocId,
    pub score: f32,
    /// Ordinal for multi-valued fields (which vector in the field matched)
    pub ordinal: u16,
}

/// Unified Block-Max MaxScore executor for top-k retrieval
///
/// Combines three optimizations from the literature into one executor:
/// 1. **MaxScore partitioning** (Turtle & Flood 1995): terms split into essential
///    (must check) and non-essential (only scored if candidate is promising)
/// 2. **Block-max pruning** (Ding & Suel 2011): skip blocks where per-block
///    upper bounds can't beat the current threshold
/// 3. **Conjunction optimization** (Lucene/Grand 2023): progressively intersect
///    essential terms as threshold rises, skipping docs that lack enough terms
///
/// Works with any type implementing `ScoringIterator` (text or sparse).
/// Replaces separate WAND and MaxScore executors with better performance
/// across all query lengths.
pub struct BlockMaxScoreExecutor<S: ScoringIterator> {
    /// Scorers sorted by max_score ascending (non-essential first)
    scorers: Vec<S>,
    /// Cumulative max_score prefix sums: prefix_sums[i] = sum(max_score[0..=i])
    prefix_sums: Vec<f32>,
    /// Top-k collector
    collector: ScoreCollector,
    /// Heap factor for approximate search (SEISMIC-style)
    /// - 1.0 = exact search (default)
    /// - 0.8 = approximate, faster with minor recall loss
    heap_factor: f32,
}
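
// Worked example of the heap-factor approximation (illustrative numbers):
// with k results collected and a true threshold of 2.0, heap_factor = 0.8
// raises the pruning bar to 2.0 / 0.8 = 2.5. Any candidate whose upper bound
// is <= 2.5 is then skipped, including one that might truly score 2.3 and
// belong in the top-k — evaluations are traded for a small recall loss.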

/// Backwards-compatible alias for `BlockMaxScoreExecutor`
pub type WandExecutor<S> = BlockMaxScoreExecutor<S>;

impl<S: ScoringIterator> BlockMaxScoreExecutor<S> {
    /// Create a new executor with exact search (heap_factor = 1.0)
    pub fn new(scorers: Vec<S>, k: usize) -> Self {
        Self::with_heap_factor(scorers, k, 1.0)
    }

    /// Create a new executor with approximate search
    ///
    /// `heap_factor` controls the trade-off between speed and recall:
    /// - 1.0 = exact search
    /// - 0.8 = ~20% faster, minor recall loss
    /// - 0.5 = much faster, noticeable recall loss
    pub fn with_heap_factor(mut scorers: Vec<S>, k: usize, heap_factor: f32) -> Self {
        // Sort scorers by max_score ascending (non-essential terms first)
        scorers.sort_by(|a, b| {
            a.max_score()
                .partial_cmp(&b.max_score())
                .unwrap_or(Ordering::Equal)
        });

        // Compute prefix sums of max_scores
        let mut prefix_sums = Vec::with_capacity(scorers.len());
        let mut cumsum = 0.0f32;
        for s in &scorers {
            cumsum += s.max_score();
            prefix_sums.push(cumsum);
        }

        debug!(
            "Creating BlockMaxScoreExecutor: num_scorers={}, k={}, total_upper={:.4}, heap_factor={:.2}",
            scorers.len(),
            k,
            cumsum,
            heap_factor
        );

        Self {
            scorers,
            prefix_sums,
            collector: ScoreCollector::new(k),
            // Lower bound keeps the pruning bar (threshold / heap_factor) finite
            heap_factor: heap_factor.clamp(0.01, 1.0),
        }
    }

    /// Find partition point: [0..partition) = non-essential, [partition..n) = essential
    /// Non-essential terms have cumulative max_score <= the pruning bar
    #[inline]
    fn find_partition(&self) -> usize {
        // heap_factor < 1.0 raises the bar, treating more terms as
        // non-essential (SEISMIC-style approximation)
        let threshold = self.collector.threshold() / self.heap_factor;
        self.prefix_sums
            .iter()
            .position(|&sum| sum > threshold)
            .unwrap_or(self.scorers.len())
    }
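
    // Worked example (illustrative numbers): three scorers with max_scores
    // [0.5, 1.0, 2.0] give prefix_sums [0.5, 1.5, 3.5]. With a threshold of
    // 1.2 and heap_factor 1.0, the first sum exceeding 1.2 is 1.5 at index 1,
    // so partition = 1: scorer 0 alone (0.5) can never lift a document past
    // the threshold and is only scored once an essential term matches.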

    /// Execute Block-Max MaxScore and return top-k results
    ///
    /// Algorithm:
    /// 1. Partition terms into essential/non-essential based on max_score
    /// 2. Find min_doc across essential scorers
    /// 3. Conjunction check: skip if not enough essential terms present
    /// 4. Block-max check: skip if block upper bounds can't beat threshold
    /// 5. Score essential scorers, check if non-essential scoring is needed
    /// 6. Score non-essential scorers, group by ordinal, insert results
    pub fn execute(mut self) -> Vec<ScoredDoc> {
        if self.scorers.is_empty() {
            debug!("BlockMaxScoreExecutor: no scorers, returning empty results");
            return Vec::new();
        }

        let n = self.scorers.len();
        let mut docs_scored = 0u64;
        let mut docs_skipped = 0u64;
        let mut blocks_skipped = 0u64;
        let mut conjunction_skipped = 0u64;

        // Pre-allocate scratch buffers outside the loop
        let mut ordinal_scores: Vec<(u16, f32)> = Vec::with_capacity(n * 2);

        loop {
            let partition = self.find_partition();

            // If all terms are non-essential, we're done
            if partition >= n {
                debug!("BlockMaxScore: all terms non-essential, early termination");
                break;
            }

            // Find minimum doc_id across essential scorers [partition..n)
            let mut min_doc = u32::MAX;
            for i in partition..n {
                let doc = self.scorers[i].doc();
                if doc < min_doc {
                    min_doc = doc;
                }
            }

            if min_doc == u32::MAX {
                break; // All essential scorers exhausted
            }

            let non_essential_upper = if partition > 0 {
                self.prefix_sums[partition - 1]
            } else {
                0.0
            };
            // heap_factor < 1.0 raises the pruning bar (approximate search)
            let adjusted_threshold = self.collector.threshold() / self.heap_factor;

            // --- Conjunction optimization (Lucene-style) ---
            // Check if enough essential terms are present at min_doc.
            // Sum max_scores of essential terms AT min_doc. If that plus
            // non-essential upper can't beat threshold, skip this doc.
            if self.collector.len() >= self.collector.k {
                let present_upper: f32 = (partition..n)
                    .filter(|&i| self.scorers[i].doc() == min_doc)
                    .map(|i| self.scorers[i].max_score())
                    .sum();

                if present_upper + non_essential_upper <= adjusted_threshold {
                    // Not enough essential terms present - advance past min_doc
                    for i in partition..n {
                        if self.scorers[i].doc() == min_doc {
                            self.scorers[i].advance();
                        }
                    }
                    conjunction_skipped += 1;
                    continue;
                }
            }

            // --- Block-max pruning ---
            // Sum block-max scores for essential scorers at min_doc.
            // If block-max sum + non-essential upper can't beat threshold, skip blocks.
            if self.collector.len() >= self.collector.k {
                let block_max_sum: f32 = (partition..n)
                    .filter(|&i| self.scorers[i].doc() == min_doc)
                    .map(|i| self.scorers[i].current_block_max_score())
                    .sum();

                if block_max_sum + non_essential_upper <= adjusted_threshold {
                    for i in partition..n {
                        if self.scorers[i].doc() == min_doc {
                            self.scorers[i].skip_to_next_block();
                        }
                    }
                    blocks_skipped += 1;
                    continue;
                }
            }

            // --- Score essential scorers ---
            // Drain all entries for min_doc from each essential scorer
            ordinal_scores.clear();

            for i in partition..n {
                if self.scorers[i].doc() == min_doc {
                    while self.scorers[i].doc() == min_doc {
                        ordinal_scores.push((self.scorers[i].ordinal(), self.scorers[i].score()));
                        self.scorers[i].advance();
                    }
                }
            }

            // Check if essential score + non-essential upper could beat threshold
            let essential_total: f32 = ordinal_scores.iter().map(|(_, s)| *s).sum();

            if self.collector.len() >= self.collector.k
                && essential_total + non_essential_upper <= adjusted_threshold
            {
                docs_skipped += 1;
                continue;
            }

            // --- Score non-essential scorers ---
            for i in 0..partition {
                let doc = self.scorers[i].seek(min_doc);
                if doc == min_doc {
                    while self.scorers[i].doc() == min_doc {
                        ordinal_scores.push((self.scorers[i].ordinal(), self.scorers[i].score()));
                        self.scorers[i].advance();
                    }
                }
            }

            // --- Group by ordinal and insert ---
            ordinal_scores.sort_unstable_by_key(|(ord, _)| *ord);
            let mut j = 0;
            while j < ordinal_scores.len() {
                let current_ord = ordinal_scores[j].0;
                let mut score = 0.0f32;
                while j < ordinal_scores.len() && ordinal_scores[j].0 == current_ord {
                    score += ordinal_scores[j].1;
                    j += 1;
                }

                trace!(
                    "Doc {}: ordinal={}, score={:.4}, threshold={:.4}",
                    min_doc, current_ord, score, adjusted_threshold
                );

                if self
                    .collector
                    .insert_with_ordinal(min_doc, score, current_ord)
                {
                    docs_scored += 1;
                } else {
                    docs_skipped += 1;
                }
            }
        }

        let results: Vec<ScoredDoc> = self
            .collector
            .into_sorted_results()
            .into_iter()
            .map(|(doc_id, score, ordinal)| ScoredDoc {
                doc_id,
                score,
                ordinal,
            })
            .collect();

        debug!(
            "BlockMaxScoreExecutor completed: scored={}, skipped={}, blocks_skipped={}, conjunction_skipped={}, returned={}, top_score={:.4}",
            docs_scored,
            docs_skipped,
            blocks_skipped,
            conjunction_skipped,
            results.len(),
            results.first().map(|r| r.score).unwrap_or(0.0)
        );

        results
    }
}

/// Scorer for full-text terms using WAND optimization
///
/// Wraps a `BlockPostingList` with BM25 parameters to implement `ScoringIterator`.
/// Enables MaxScore pruning for efficient top-k retrieval in OR queries.
pub struct TextTermScorer {
    /// Iterator over the posting list (owned)
    iter: crate::structures::BlockPostingIterator<'static>,
    /// IDF component for BM25
    idf: f32,
    /// Average field length for BM25 normalization
    avg_field_len: f32,
    /// Pre-computed max score (using max_tf from posting list)
    max_score: f32,
}

impl TextTermScorer {
    /// Create a new text term scorer with BM25 parameters
    pub fn new(
        posting_list: crate::structures::BlockPostingList,
        idf: f32,
        avg_field_len: f32,
    ) -> Self {
        // Compute max score using actual max_tf from posting list
        let max_tf = posting_list.max_tf() as f32;
        let doc_count = posting_list.doc_count();
        let max_score = super::bm25_upper_bound(max_tf.max(1.0), idf);

        debug!(
            "Created TextTermScorer: doc_count={}, max_tf={:.0}, idf={:.4}, avg_field_len={:.2}, max_score={:.4}",
            doc_count, max_tf, idf, avg_field_len, max_score
        );

        Self {
            iter: posting_list.into_iterator(),
            idf,
            avg_field_len,
            max_score,
        }
    }
}

impl ScoringIterator for TextTermScorer {
    #[inline]
    fn doc(&self) -> DocId {
        self.iter.doc()
    }

    #[inline]
    fn advance(&mut self) -> DocId {
        self.iter.advance()
    }

    #[inline]
    fn seek(&mut self, target: DocId) -> DocId {
        self.iter.seek(target)
    }

    #[inline]
    fn score(&self) -> f32 {
        let tf = self.iter.term_freq() as f32;
        // Use tf as proxy for doc length (common approximation when field lengths aren't stored)
        super::bm25_score(tf, self.idf, tf, self.avg_field_len)
    }

    #[inline]
    fn max_score(&self) -> f32 {
        self.max_score
    }

    #[inline]
    fn current_block_max_score(&self) -> f32 {
        // Use per-block max_tf for tighter Block-Max WAND bounds
        let block_max_tf = self.iter.current_block_max_tf() as f32;
        super::bm25_upper_bound(block_max_tf.max(1.0), self.idf)
    }

    #[inline]
    fn skip_to_next_block(&mut self) -> DocId {
        self.iter.skip_to_next_block()
    }
}

/// Scorer for sparse vector dimensions
///
/// Wraps a `BlockSparsePostingList` with a query weight to implement `ScoringIterator`.
pub struct SparseTermScorer<'a> {
    /// Iterator over the posting list
    iter: crate::structures::BlockSparsePostingIterator<'a>,
    /// Query weight for this dimension
    query_weight: f32,
    /// Global max score (query_weight * global_max_weight)
    max_score: f32,
}

impl<'a> SparseTermScorer<'a> {
    /// Create a new sparse term scorer
    ///
    /// Note: Assumes positive weights for WAND upper bound calculation.
    /// For negative query weights, uses absolute value to ensure valid upper bound.
    pub fn new(posting_list: &'a BlockSparsePostingList, query_weight: f32) -> Self {
        // Upper bound must account for sign: |query_weight| * max_weight
        // This ensures the bound is valid regardless of weight sign
        let max_score = query_weight.abs() * posting_list.global_max_weight();
        Self {
            iter: posting_list.iterator(),
            query_weight,
            max_score,
        }
    }

    /// Create from Arc reference (for use with shared posting lists)
    pub fn from_arc(posting_list: &'a Arc<BlockSparsePostingList>, query_weight: f32) -> Self {
        Self::new(posting_list.as_ref(), query_weight)
    }
}
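
// Worked example of the sign-safe bound (illustrative numbers): assuming
// stored weights are non-negative, as the WAND note above does, a query
// weight of -0.8 makes every contribution <= 0, and |-0.8| * 0.5 = 0.4
// (for global_max_weight = 0.5) is still a valid, if loose, upper bound.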

impl ScoringIterator for SparseTermScorer<'_> {
    #[inline]
    fn doc(&self) -> DocId {
        self.iter.doc()
    }

    #[inline]
    fn ordinal(&self) -> u16 {
        self.iter.ordinal()
    }

    #[inline]
    fn advance(&mut self) -> DocId {
        self.iter.advance()
    }

    #[inline]
    fn seek(&mut self, target: DocId) -> DocId {
        self.iter.seek(target)
    }

    #[inline]
    fn score(&self) -> f32 {
        // Dot product contribution: query_weight * stored_weight
        self.query_weight * self.iter.weight()
    }

    #[inline]
    fn max_score(&self) -> f32 {
        self.max_score
    }

    #[inline]
    fn current_block_max_score(&self) -> f32 {
        // Use abs() for valid upper bound with negative weights
        self.iter
            .current_block_max_contribution(self.query_weight.abs())
    }

    #[inline]
    fn skip_to_next_block(&mut self) -> DocId {
        self.iter.skip_to_next_block()
    }
}

/// Block-Max Pruning (BMP) executor for learned sparse retrieval
///
/// Processes blocks in score-descending order using a priority queue.
/// Best for queries with many terms (12+), like SPLADE expansions.
/// Uses document accumulators (FxHashMap) instead of per-term iterators.
///
/// **Memory-efficient**: Only skip entries (block metadata) are kept in memory.
/// Actual block data is loaded on-demand via mmap range reads during execution.
///
/// Reference: Mallia et al., "Faster Learned Sparse Retrieval with
/// Block-Max Pruning" (SIGIR 2024)
pub struct BmpExecutor<'a> {
    /// Sparse index for on-demand block loading
    sparse_index: &'a crate::segment::SparseIndex,
    /// Query terms: (dim_id, query_weight) for each matched dimension
    query_terms: Vec<(u32, f32)>,
    /// Number of results to return
    k: usize,
    /// Heap factor for approximate search
    heap_factor: f32,
}

/// Entry in the BMP priority queue, ordered by its block's upper-bound contribution
struct BmpBlockEntry {
    /// Upper bound contribution of this block
    contribution: f32,
    /// Index into query_terms
    term_idx: usize,
    /// Block index within the posting list
    block_idx: usize,
}

impl PartialEq for BmpBlockEntry {
    fn eq(&self, other: &Self) -> bool {
        self.contribution == other.contribution
    }
}

impl Eq for BmpBlockEntry {}

impl Ord for BmpBlockEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        // Max-heap: higher contributions come first
        self.contribution
            .partial_cmp(&other.contribution)
            .unwrap_or(Ordering::Equal)
    }
}

impl PartialOrd for BmpBlockEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl<'a> BmpExecutor<'a> {
    /// Create a new BMP executor with lazy block loading
    ///
    /// `query_terms` should contain only dimensions that exist in the index.
    /// Block metadata (skip entries) is read from the sparse index directly.
    pub fn new(
        sparse_index: &'a crate::segment::SparseIndex,
        query_terms: Vec<(u32, f32)>,
        k: usize,
        heap_factor: f32,
    ) -> Self {
        Self {
            sparse_index,
            query_terms,
            k,
            // Lower bound keeps the pruning bar (threshold / heap_factor) finite
            heap_factor: heap_factor.clamp(0.01, 1.0),
        }
    }
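
    // Intended call shape (sketch — assumes a `SparseIndex` and resolved
    // query terms are in scope; the variable names are illustrative):
    //
    //     let hits = BmpExecutor::new(&sparse_index, vec![(dim_id, weight)], 10, 1.0)
    //         .execute()
    //         .await?;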

    /// Execute BMP and return top-k results
    ///
    /// Builds the priority queue from skip entries (already in memory),
    /// then loads blocks on-demand via mmap range reads as they are visited.
    pub async fn execute(self) -> crate::Result<Vec<ScoredDoc>> {
        use rustc_hash::FxHashMap;

        if self.query_terms.is_empty() {
            return Ok(Vec::new());
        }

        let num_terms = self.query_terms.len();

        // Build priority queue from skip entries (already in memory — no I/O)
        let mut block_queue: BinaryHeap<BmpBlockEntry> = BinaryHeap::new();
        let mut remaining_max: Vec<f32> = Vec::with_capacity(num_terms);

        for (term_idx, &(dim_id, qw)) in self.query_terms.iter().enumerate() {
            let mut term_remaining = 0.0f32;

            if let Some((skip_entries, _global_max)) = self.sparse_index.get_skip_list(dim_id) {
                for (block_idx, skip) in skip_entries.iter().enumerate() {
                    let contribution = qw.abs() * skip.max_weight;
                    term_remaining += contribution;
                    block_queue.push(BmpBlockEntry {
                        contribution,
                        term_idx,
                        block_idx,
                    });
                }
            }
            remaining_max.push(term_remaining);
        }

        // Document accumulators: packed (doc_id << 16 | ordinal) -> accumulated_score
        // Using packed u64 key: single-word FxHash vs tuple hashing overhead.
        // (doc_id, ordinal) ensures scores from different ordinals are NOT mixed.
        let mut accumulators: FxHashMap<u64, f32> = FxHashMap::default();
        let mut blocks_processed = 0u64;
        let mut blocks_skipped = 0u64;

        // Incremental top-k tracker for threshold — O(log k) per insert vs
        // the old O(n) select_nth_unstable every 32 blocks.
        let mut top_k = ScoreCollector::new(self.k);

        // Reusable decode buffers — avoids 3 allocations per block
        let mut doc_ids_buf: Vec<u32> = Vec::with_capacity(128);
        let mut weights_buf: Vec<f32> = Vec::with_capacity(128);
        let mut ordinals_buf: Vec<u16> = Vec::with_capacity(128);

        // Process blocks in contribution-descending order, loading each on-demand
        while let Some(entry) = block_queue.pop() {
            // Update remaining max for this term
            remaining_max[entry.term_idx] -= entry.contribution;

            // Early termination: check if total remaining across all terms
            // can beat the current k-th best accumulated score
            let total_remaining: f32 = remaining_max.iter().sum();
            // heap_factor < 1.0 raises the bar, terminating earlier
            let adjusted_threshold = top_k.threshold() / self.heap_factor;
            if top_k.len() >= self.k && total_remaining <= adjusted_threshold {
                blocks_skipped += block_queue.len() as u64;
                debug!(
                    "BMP early termination after {} blocks: remaining={:.4} <= threshold={:.4}",
                    blocks_processed, total_remaining, adjusted_threshold
                );
                break;
            }

            // Load this single block on-demand via mmap range read
            let dim_id = self.query_terms[entry.term_idx].0;
            let block = match self.sparse_index.get_block(dim_id, entry.block_idx).await? {
                Some(b) => b,
                None => continue,
            };

            // Decode into reusable buffers (avoids alloc per block)
            let qw = self.query_terms[entry.term_idx].1;
            block.decode_doc_ids_into(&mut doc_ids_buf);
            block.decode_scored_weights_into(qw, &mut weights_buf);
            block.decode_ordinals_into(&mut ordinals_buf);

            for i in 0..block.header.count as usize {
                let score_contribution = weights_buf[i];
                let key = (doc_ids_buf[i] as u64) << 16 | ordinals_buf[i] as u64;
                let acc = accumulators.entry(key).or_insert(0.0);
                *acc += score_contribution;
                // Update the threshold tracker with the new accumulated score.
                // The collector does not deduplicate, so stale lower entries
                // for the same (doc, ordinal) may linger until evicted; this
                // can only inflate the threshold, making early termination
                // slightly approximate. Final scores come from `accumulators`.
                top_k.insert_with_ordinal(doc_ids_buf[i], *acc, ordinals_buf[i]);
            }

            blocks_processed += 1;
        }

        // Collect top-k directly from accumulators (use final accumulated scores)
        let num_accumulators = accumulators.len();
        let mut scored: Vec<ScoredDoc> = accumulators
            .into_iter()
            .map(|(key, score)| ScoredDoc {
                doc_id: (key >> 16) as DocId,
                score,
                ordinal: (key & 0xFFFF) as u16,
            })
            .collect();
        scored.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
        scored.truncate(self.k);
        let results = scored;

        debug!(
            "BmpExecutor completed: blocks_processed={}, blocks_skipped={}, accumulators={}, returned={}, top_score={:.4}",
            blocks_processed,
            blocks_skipped,
            num_accumulators,
            results.len(),
            results.first().map(|r| r.score).unwrap_or(0.0)
        );

        Ok(results)
    }
}

/// Lazy Block-Max MaxScore executor for sparse retrieval (1-11 terms)
///
/// Combines BlockMaxScore's cursor-based document-at-a-time traversal with
/// BMP's lazy block loading. Skip entries (already in memory via zero-copy
/// mmap) drive block-level navigation; actual block data is loaded on-demand
/// only when the cursor visits that block.
///
/// For typical 1-11 term queries with MaxScore pruning, many blocks are
/// skipped entirely — lazy loading avoids the I/O and decode cost for those
/// blocks. This hybrid achieves BMP's memory efficiency with BlockMaxScore's
/// superior pruning for few-term queries.
pub struct LazyBlockMaxScoreExecutor<'a> {
    sparse_index: &'a crate::segment::SparseIndex,
    cursors: Vec<LazyTermCursor>,
    prefix_sums: Vec<f32>,
    collector: ScoreCollector,
    heap_factor: f32,
}
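
// Intended call shape (sketch — assumes a `SparseIndex` and resolved query
// terms are in scope; `sparse_index` and `terms` are illustrative names):
//
//     let hits = LazyBlockMaxScoreExecutor::new(&sparse_index, terms, 10, 1.0)
//         .execute()
//         .await?;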

/// Per-term cursor state for lazy block loading
struct LazyTermCursor {
    dim_id: u32,
    query_weight: f32,
    max_score: f32,
    /// Skip entries (small, pre-loaded from zero-copy mmap section)
    skip_entries: Vec<crate::structures::SparseSkipEntry>,
    /// Current block index in skip_entries
    block_idx: usize,
    /// Decoded block data (loaded on demand, reused across seeks)
    doc_ids: Vec<u32>,
    ordinals: Vec<u16>,
    weights: Vec<f32>,
    /// Position within current decoded block
    pos: usize,
    /// Whether block at block_idx is decoded into doc_ids/ordinals/weights
    block_loaded: bool,
    exhausted: bool,
}

impl LazyTermCursor {
    fn new(
        dim_id: u32,
        query_weight: f32,
        skip_entries: Vec<crate::structures::SparseSkipEntry>,
        global_max_weight: f32,
    ) -> Self {
        let exhausted = skip_entries.is_empty();
        Self {
            dim_id,
            query_weight,
            max_score: query_weight.abs() * global_max_weight,
            skip_entries,
            block_idx: 0,
            doc_ids: Vec::new(),
            ordinals: Vec::new(),
            weights: Vec::new(),
            pos: 0,
            block_loaded: false,
            exhausted,
        }
    }

    /// Ensure current block is loaded and decoded
    async fn ensure_block_loaded(
        &mut self,
        sparse_index: &crate::segment::SparseIndex,
    ) -> crate::Result<bool> {
        if self.exhausted || self.block_loaded {
            return Ok(!self.exhausted);
        }
        match sparse_index.get_block(self.dim_id, self.block_idx).await? {
            Some(block) => {
                block.decode_doc_ids_into(&mut self.doc_ids);
                block.decode_ordinals_into(&mut self.ordinals);
                block.decode_scored_weights_into(self.query_weight, &mut self.weights);
                self.pos = 0;
                self.block_loaded = true;
                Ok(true)
            }
            None => {
                self.exhausted = true;
                Ok(false)
            }
        }
    }

    #[inline]
    fn doc(&self) -> DocId {
        if self.exhausted {
            return u32::MAX;
        }
        if !self.block_loaded {
            // Block not yet loaded — return first_doc of current skip entry
            // as a lower bound (actual doc may be higher after decode)
            return self.skip_entries[self.block_idx].first_doc;
        }
        self.doc_ids.get(self.pos).copied().unwrap_or(u32::MAX)
    }

    #[inline]
    fn ordinal(&self) -> u16 {
        if !self.block_loaded {
            return 0;
        }
        self.ordinals.get(self.pos).copied().unwrap_or(0)
    }

    #[inline]
    fn score(&self) -> f32 {
        if !self.block_loaded {
            return 0.0;
        }
        self.weights.get(self.pos).copied().unwrap_or(0.0)
    }

    #[inline]
    fn current_block_max_score(&self) -> f32 {
        if self.exhausted {
            return 0.0;
        }
        self.query_weight.abs()
            * self
                .skip_entries
                .get(self.block_idx)
                .map(|e| e.max_weight)
                .unwrap_or(0.0)
    }

    /// Advance to next posting within current block, or move to next block
    async fn advance(
        &mut self,
        sparse_index: &crate::segment::SparseIndex,
    ) -> crate::Result<DocId> {
        if self.exhausted {
            return Ok(u32::MAX);
        }
        self.ensure_block_loaded(sparse_index).await?;
        if self.exhausted {
            return Ok(u32::MAX);
        }
        self.pos += 1;
        if self.pos >= self.doc_ids.len() {
            self.block_idx += 1;
            self.block_loaded = false;
            if self.block_idx >= self.skip_entries.len() {
                self.exhausted = true;
                return Ok(u32::MAX);
            }
            // Don't load next block yet — lazy
        }
        Ok(self.doc())
    }

    /// Seek to first doc >= target using skip entries for block navigation
    async fn seek(
        &mut self,
        sparse_index: &crate::segment::SparseIndex,
        target: DocId,
    ) -> crate::Result<DocId> {
        if self.exhausted {
            return Ok(u32::MAX);
        }

        // If block is loaded and target is within current block range
        if self.block_loaded
            && let Some(&last) = self.doc_ids.last()
        {
            if last >= target && self.doc_ids[self.pos] < target {
                // Binary search within current block
                let remaining = &self.doc_ids[self.pos..];
                let offset = crate::structures::simd::find_first_ge_u32(remaining, target);
                self.pos += offset;
                if self.pos >= self.doc_ids.len() {
                    self.block_idx += 1;
                    self.block_loaded = false;
                    if self.block_idx >= self.skip_entries.len() {
                        self.exhausted = true;
                        return Ok(u32::MAX);
                    }
                }
                return Ok(self.doc());
            }
            if self.doc_ids[self.pos] >= target {
                return Ok(self.doc());
            }
        }

        // Scan skip entries for the first block whose last_doc >= target
        // (linear scan — per-dimension skip lists are small)
        let bi = self.skip_entries.iter().position(|e| e.last_doc >= target);
        match bi {
            Some(idx) => {
                if idx != self.block_idx || !self.block_loaded {
                    self.block_idx = idx;
                    self.block_loaded = false;
                }
                self.ensure_block_loaded(sparse_index).await?;
                if self.exhausted {
                    return Ok(u32::MAX);
                }
                let offset = crate::structures::simd::find_first_ge_u32(&self.doc_ids, target);
                self.pos = offset;
                if self.pos >= self.doc_ids.len() {
                    self.block_idx += 1;
                    self.block_loaded = false;
                    if self.block_idx >= self.skip_entries.len() {
                        self.exhausted = true;
                        return Ok(u32::MAX);
                    }
                    self.ensure_block_loaded(sparse_index).await?;
                }
                Ok(self.doc())
            }
            None => {
                self.exhausted = true;
                Ok(u32::MAX)
            }
        }
    }

    /// Skip to next block without loading it (for block-max pruning)
    fn skip_to_next_block(&mut self) -> DocId {
        if self.exhausted {
            return u32::MAX;
        }
        self.block_idx += 1;
        self.block_loaded = false;
        if self.block_idx >= self.skip_entries.len() {
            self.exhausted = true;
            return u32::MAX;
        }
        // Return first_doc of next block as lower bound
        self.skip_entries[self.block_idx].first_doc
    }
}

impl<'a> LazyBlockMaxScoreExecutor<'a> {
    /// Create a new lazy executor
    ///
    /// `query_terms` should contain only dimensions present in the index.
    /// Skip entries are read from the zero-copy mmap section (no I/O).
    pub fn new(
        sparse_index: &'a crate::segment::SparseIndex,
        query_terms: Vec<(u32, f32)>,
        k: usize,
        heap_factor: f32,
    ) -> Self {
        let mut cursors: Vec<LazyTermCursor> = query_terms
            .iter()
            .filter_map(|&(dim_id, qw)| {
                let (skip_entries, global_max) = sparse_index.get_skip_list(dim_id)?;
                Some(LazyTermCursor::new(dim_id, qw, skip_entries, global_max))
            })
            .collect();

        // Sort by max_score ascending (non-essential first)
        cursors.sort_by(|a, b| {
            a.max_score
                .partial_cmp(&b.max_score)
                .unwrap_or(Ordering::Equal)
        });

        let mut prefix_sums = Vec::with_capacity(cursors.len());
        let mut cumsum = 0.0f32;
        for c in &cursors {
            cumsum += c.max_score;
            prefix_sums.push(cumsum);
        }

        debug!(
            "Creating LazyBlockMaxScoreExecutor: num_terms={}, k={}, total_upper={:.4}, heap_factor={:.2}",
            cursors.len(),
            k,
            cumsum,
            heap_factor
        );

        Self {
            sparse_index,
            cursors,
            prefix_sums,
            collector: ScoreCollector::new(k),
            // Lower bound keeps the pruning bar (threshold / heap_factor) finite
            heap_factor: heap_factor.clamp(0.01, 1.0),
        }
    }

    #[inline]
    fn find_partition(&self) -> usize {
        // heap_factor < 1.0 raises the bar (SEISMIC-style approximation)
        let threshold = self.collector.threshold() / self.heap_factor;
        self.prefix_sums
            .iter()
            .position(|&sum| sum > threshold)
            .unwrap_or(self.cursors.len())
    }

    /// Execute lazy Block-Max MaxScore and return top-k results
    pub async fn execute(mut self) -> crate::Result<Vec<ScoredDoc>> {
        if self.cursors.is_empty() {
            return Ok(Vec::new());
        }

        let n = self.cursors.len();
        let si = self.sparse_index;

        // Load first block for each cursor (ensures doc() returns real values)
        for cursor in &mut self.cursors {
            cursor.ensure_block_loaded(si).await?;
        }

        let mut docs_scored = 0u64;
        let mut docs_skipped = 0u64;
        let mut blocks_skipped = 0u64;
        let mut blocks_loaded = 0u64;
        let mut conjunction_skipped = 0u64;
        let mut ordinal_scores: Vec<(u16, f32)> = Vec::with_capacity(n * 2);

        loop {
            let partition = self.find_partition();
            if partition >= n {
                break;
            }

            // Find minimum doc_id across essential cursors
            let mut min_doc = u32::MAX;
            for i in partition..n {
                let doc = self.cursors[i].doc();
                if doc < min_doc {
                    min_doc = doc;
                }
            }
            if min_doc == u32::MAX {
                break;
            }

            let non_essential_upper = if partition > 0 {
                self.prefix_sums[partition - 1]
            } else {
                0.0
            };
            // heap_factor < 1.0 raises the pruning bar (approximate search)
            let adjusted_threshold = self.collector.threshold() / self.heap_factor;

            // --- Conjunction optimization ---
            if self.collector.len() >= self.collector.k {
                let present_upper: f32 = (partition..n)
                    .filter(|&i| self.cursors[i].doc() == min_doc)
                    .map(|i| self.cursors[i].max_score)
                    .sum();

                if present_upper + non_essential_upper <= adjusted_threshold {
                    for i in partition..n {
                        if self.cursors[i].doc() == min_doc {
                            self.cursors[i].advance(si).await?;
                            blocks_loaded += u64::from(self.cursors[i].block_loaded);
                        }
                    }
                    conjunction_skipped += 1;
                    continue;
                }
            }

            // --- Block-max pruning ---
            if self.collector.len() >= self.collector.k {
                let block_max_sum: f32 = (partition..n)
                    .filter(|&i| self.cursors[i].doc() == min_doc)
                    .map(|i| self.cursors[i].current_block_max_score())
                    .sum();

                if block_max_sum + non_essential_upper <= adjusted_threshold {
                    for i in partition..n {
                        if self.cursors[i].doc() == min_doc {
                            self.cursors[i].skip_to_next_block();
                            // Ensure next block is loaded for doc() to return real value
                            self.cursors[i].ensure_block_loaded(si).await?;
                            blocks_loaded += 1;
                        }
                    }
                    blocks_skipped += 1;
                    continue;
                }
            }

            // --- Score essential cursors ---
            ordinal_scores.clear();
            for i in partition..n {
                while self.cursors[i].doc() == min_doc {
                    // The cursor may sit at an undecoded block whose skip-entry
                    // first_doc matches min_doc only as a lower bound; decode
                    // before reading ordinal()/score(), which return 0 for an
                    // unloaded block, and re-check the actual doc after decode.
                    if !self.cursors[i].ensure_block_loaded(si).await? {
                        break;
                    }
                    if self.cursors[i].doc() != min_doc {
                        break;
                    }
                    ordinal_scores.push((self.cursors[i].ordinal(), self.cursors[i].score()));
                    self.cursors[i].advance(si).await?;
                }
            }

            let essential_total: f32 = ordinal_scores.iter().map(|(_, s)| *s).sum();
            if self.collector.len() >= self.collector.k
                && essential_total + non_essential_upper <= adjusted_threshold
            {
                docs_skipped += 1;
                continue;
            }

            // --- Score non-essential cursors ---
            for i in 0..partition {
                let doc = self.cursors[i].seek(si, min_doc).await?;
                if doc == min_doc {
                    while self.cursors[i].doc() == min_doc {
                        // Same guard as above: seek can stop at an undecoded
                        // block boundary.
                        if !self.cursors[i].ensure_block_loaded(si).await? {
                            break;
                        }
                        if self.cursors[i].doc() != min_doc {
                            break;
                        }
                        ordinal_scores.push((self.cursors[i].ordinal(), self.cursors[i].score()));
                        self.cursors[i].advance(si).await?;
                    }
                }
            }

            // --- Group by ordinal and insert ---
            ordinal_scores.sort_unstable_by_key(|(ord, _)| *ord);
            let mut j = 0;
            while j < ordinal_scores.len() {
                let current_ord = ordinal_scores[j].0;
                let mut score = 0.0f32;
                while j < ordinal_scores.len() && ordinal_scores[j].0 == current_ord {
                    score += ordinal_scores[j].1;
                    j += 1;
                }
                if self
                    .collector
                    .insert_with_ordinal(min_doc, score, current_ord)
                {
                    docs_scored += 1;
                } else {
                    docs_skipped += 1;
                }
            }
        }

        let results: Vec<ScoredDoc> = self
            .collector
            .into_sorted_results()
            .into_iter()
            .map(|(doc_id, score, ordinal)| ScoredDoc {
                doc_id,
                score,
                ordinal,
            })
            .collect();

        debug!(
            "LazyBlockMaxScoreExecutor completed: scored={}, skipped={}, blocks_skipped={}, blocks_loaded={}, conjunction_skipped={}, returned={}, top_score={:.4}",
            docs_scored,
            docs_skipped,
            blocks_skipped,
            blocks_loaded,
            conjunction_skipped,
            results.len(),
            results.first().map(|r| r.score).unwrap_or(0.0)
        );

        Ok(results)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_score_collector_basic() {
        let mut collector = ScoreCollector::new(3);

        collector.insert(1, 1.0);
        collector.insert(2, 2.0);
        collector.insert(3, 3.0);
        assert_eq!(collector.threshold(), 1.0);

        collector.insert(4, 4.0);
        assert_eq!(collector.threshold(), 2.0);

        let results = collector.into_sorted_results();
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].0, 4); // Highest score
        assert_eq!(results[1].0, 3);
        assert_eq!(results[2].0, 2);
    }

    #[test]
    fn test_score_collector_threshold() {
        let mut collector = ScoreCollector::new(2);

        collector.insert(1, 5.0);
        collector.insert(2, 3.0);
        assert_eq!(collector.threshold(), 3.0);

        // Should not enter (score too low)
        assert!(!collector.would_enter(2.0));
        assert!(!collector.insert(3, 2.0));

        // Should enter (score high enough)
        assert!(collector.would_enter(4.0));
        assert!(collector.insert(4, 4.0));
        assert_eq!(collector.threshold(), 4.0);
    }

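    // Hypothetical sanity check, mirroring the packed (doc_id << 16 | ordinal)
    // accumulator key used by `BmpExecutor::execute`; the constants are
    // arbitrary.
    #[test]
    fn test_packed_doc_ordinal_key_roundtrip() {
        let (doc_id, ordinal): (DocId, u16) = (123_456, 7);
        let key = (doc_id as u64) << 16 | ordinal as u64;
        assert_eq!((key >> 16) as DocId, doc_id);
        assert_eq!((key & 0xFFFF) as u16, ordinal);
    }
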
    #[test]
    fn test_heap_entry_ordering() {
        let mut heap = BinaryHeap::new();
        heap.push(HeapEntry {
            doc_id: 1,
            score: 3.0,
            ordinal: 0,
        });
        heap.push(HeapEntry {
            doc_id: 2,
            score: 1.0,
            ordinal: 0,
        });
        heap.push(HeapEntry {
            doc_id: 3,
            score: 2.0,
            ordinal: 0,
        });

        // Min-heap: lowest score should come out first
        assert_eq!(heap.pop().unwrap().score, 1.0);
        assert_eq!(heap.pop().unwrap().score, 2.0);
        assert_eq!(heap.pop().unwrap().score, 3.0);
    }
}
1345}