
hermes_core/query/scoring.rs

//! Shared scoring abstractions for text and sparse vector search
//!
//! Provides common traits and utilities for efficient top-k retrieval:
//! - `ScoringIterator`: Common interface for posting list iteration with scoring
//! - `ScoreCollector`: Efficient min-heap for maintaining top-k results
//! - `BlockMaxScoreExecutor`: Unified Block-Max MaxScore with conjunction optimization
//! - `BmpExecutor`: Block-at-a-time executor for learned sparse retrieval (12+ terms)
//! - `SparseTermScorer`: ScoringIterator implementation for sparse vectors

use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::sync::Arc;

use log::{debug, trace};

use crate::DocId;
use crate::structures::BlockSparsePostingList;

/// Common interface for scoring iterators (text terms or sparse dimensions)
///
/// Abstracts the common operations needed for WAND-style top-k retrieval.
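///
/// # Example
///
/// A minimal sketch of a plain drain loop over a single scorer (no pruning),
/// shown only to illustrate the trait contract; not compiled as a doctest:
///
/// ```ignore
/// fn drain_scores<S: ScoringIterator>(mut scorer: S) -> Vec<(DocId, f32)> {
///     let mut hits = Vec::new();
///     while !scorer.is_exhausted() {
///         hits.push((scorer.doc(), scorer.score()));
///         scorer.advance();
///     }
///     hits
/// }
/// ```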
pub trait ScoringIterator {
    /// Current document ID (u32::MAX if exhausted)
    fn doc(&self) -> DocId;

    /// Current ordinal for multi-valued fields (default 0)
    fn ordinal(&self) -> u16 {
        0
    }

    /// Advance to next document, returns new doc ID
    fn advance(&mut self) -> DocId;

    /// Seek to first document >= target, returns new doc ID
    fn seek(&mut self, target: DocId) -> DocId;

    /// Check if iterator is exhausted
    fn is_exhausted(&self) -> bool {
        self.doc() == u32::MAX
    }

    /// Score contribution for current document
    fn score(&self) -> f32;

    /// Maximum possible score for this term/dimension (global upper bound)
    fn max_score(&self) -> f32;

    /// Current block's maximum score upper bound (for block-level pruning)
    fn current_block_max_score(&self) -> f32;

    /// Skip to the next block, returning the first doc_id in the new block.
    /// Used for block-max WAND optimization when current block can't beat threshold.
    /// Default implementation just advances (no block-level skipping).
    fn skip_to_next_block(&mut self) -> DocId {
        self.advance()
    }
}

/// Entry for top-k min-heap
#[derive(Clone, Copy)]
pub struct HeapEntry {
    pub doc_id: DocId,
    pub score: f32,
    pub ordinal: u16,
}

impl PartialEq for HeapEntry {
    fn eq(&self, other: &Self) -> bool {
        self.score == other.score && self.doc_id == other.doc_id
    }
}

impl Eq for HeapEntry {}

impl Ord for HeapEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        // Min-heap: lower scores come first (to be evicted)
        other
            .score
            .partial_cmp(&self.score)
            .unwrap_or(Ordering::Equal)
            .then_with(|| self.doc_id.cmp(&other.doc_id))
    }
}

impl PartialOrd for HeapEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

/// Efficient top-k collector using min-heap
///
/// Maintains the k highest-scoring documents using a min-heap where the
/// lowest score is at the top for O(1) threshold lookup and O(log k) eviction.
/// No deduplication - caller must ensure each doc_id is inserted only once.
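///
/// A minimal usage sketch (not compiled as a doctest):
///
/// ```ignore
/// let mut collector = ScoreCollector::new(2);
/// collector.insert(10, 1.5);
/// collector.insert(11, 3.0);
/// // Heap is full, so only scores above the current threshold (1.5) can enter.
/// assert!(!collector.would_enter(1.0));
/// assert!(collector.insert(12, 2.0)); // evicts doc 10
/// let top = collector.into_sorted_results(); // [(11, 3.0, 0), (12, 2.0, 0)]
/// ```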
pub struct ScoreCollector {
    /// Min-heap of top-k entries (lowest score at top for eviction)
    heap: BinaryHeap<HeapEntry>,
    pub k: usize,
}

impl ScoreCollector {
    /// Create a new collector for top-k results
    pub fn new(k: usize) -> Self {
        // Cap capacity to avoid allocation overflow for very large k
        let capacity = k.saturating_add(1).min(1_000_000);
        Self {
            heap: BinaryHeap::with_capacity(capacity),
            k,
        }
    }

    /// Current score threshold (minimum score to enter top-k)
    #[inline]
    pub fn threshold(&self) -> f32 {
        if self.heap.len() >= self.k {
            self.heap.peek().map(|e| e.score).unwrap_or(0.0)
        } else {
            0.0
        }
    }

    /// Insert a document score. Returns true if inserted in top-k.
    /// Caller must ensure each doc_id is inserted only once.
    #[inline]
    pub fn insert(&mut self, doc_id: DocId, score: f32) -> bool {
        self.insert_with_ordinal(doc_id, score, 0)
    }

    /// Insert a document score with ordinal. Returns true if inserted in top-k.
    /// Caller must ensure each doc_id is inserted only once.
    #[inline]
    pub fn insert_with_ordinal(&mut self, doc_id: DocId, score: f32, ordinal: u16) -> bool {
        if self.heap.len() < self.k {
            self.heap.push(HeapEntry {
                doc_id,
                score,
                ordinal,
            });
            true
        } else if score > self.threshold() {
            self.heap.push(HeapEntry {
                doc_id,
                score,
                ordinal,
            });
            self.heap.pop(); // Remove lowest
            true
        } else {
            false
        }
    }

    /// Check if a score could potentially enter top-k
    #[inline]
    pub fn would_enter(&self, score: f32) -> bool {
        self.heap.len() < self.k || score > self.threshold()
    }

    /// Get number of documents collected so far
    #[inline]
    pub fn len(&self) -> usize {
        self.heap.len()
    }

    /// Check if collector is empty
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.heap.is_empty()
    }

    /// Convert to sorted top-k results (descending by score)
    pub fn into_sorted_results(self) -> Vec<(DocId, f32, u16)> {
        let mut results: Vec<_> = self
            .heap
            .into_vec()
            .into_iter()
            .map(|e| (e.doc_id, e.score, e.ordinal))
            .collect();

        // Sort by score descending, then doc_id ascending
        results.sort_by(|a, b| {
            b.1.partial_cmp(&a.1)
                .unwrap_or(Ordering::Equal)
                .then_with(|| a.0.cmp(&b.0))
        });

        results
    }
}

/// Search result from WAND execution
#[derive(Debug, Clone, Copy)]
pub struct ScoredDoc {
    pub doc_id: DocId,
    pub score: f32,
    /// Ordinal for multi-valued fields (which vector in the field matched)
    pub ordinal: u16,
}

/// Unified Block-Max MaxScore executor for top-k retrieval
///
/// Combines three optimizations from the literature into one executor:
/// 1. **MaxScore partitioning** (Turtle & Flood 1995): terms split into essential
///    (must check) and non-essential (only scored if candidate is promising)
/// 2. **Block-max pruning** (Ding & Suel 2011): skip blocks where per-block
///    upper bounds can't beat the current threshold
/// 3. **Conjunction optimization** (Lucene/Grand 2023): progressively intersect
///    essential terms as threshold rises, skipping docs that lack enough terms
///
/// Works with any type implementing `ScoringIterator` (text or sparse).
/// Replaces separate WAND and MaxScore executors with better performance
/// across all query lengths.
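///
/// # Example
///
/// Illustrative sketch (assumes `query_terms` and `postings`, a map from each
/// query dimension to a `BlockSparsePostingList` loaded for this segment, are
/// available from the caller):
///
/// ```ignore
/// let scorers: Vec<SparseTermScorer> = query_terms
///     .iter()
///     .map(|(dim, weight)| SparseTermScorer::new(&postings[dim], *weight))
///     .collect();
/// let top_docs = BlockMaxScoreExecutor::new(scorers, 10).execute();
/// for hit in &top_docs {
///     println!("doc={} ordinal={} score={:.4}", hit.doc_id, hit.ordinal, hit.score);
/// }
/// ```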
pub struct BlockMaxScoreExecutor<S: ScoringIterator> {
    /// Scorers sorted by max_score ascending (non-essential first)
    scorers: Vec<S>,
    /// Cumulative max_score prefix sums: prefix_sums[i] = sum(max_score[0..=i])
    prefix_sums: Vec<f32>,
    /// Top-k collector
    collector: ScoreCollector,
    /// Heap factor for approximate search (SEISMIC-style)
    /// - 1.0 = exact search (default)
    /// - 0.8 = approximate, faster with minor recall loss
    heap_factor: f32,
}

/// Backwards-compatible alias for `BlockMaxScoreExecutor`
pub type WandExecutor<S> = BlockMaxScoreExecutor<S>;

impl<S: ScoringIterator> BlockMaxScoreExecutor<S> {
    /// Create a new executor with exact search (heap_factor = 1.0)
    pub fn new(scorers: Vec<S>, k: usize) -> Self {
        Self::with_heap_factor(scorers, k, 1.0)
    }

    /// Create a new executor with approximate search
    ///
    /// `heap_factor` controls the trade-off between speed and recall:
    /// - 1.0 = exact search
    /// - 0.8 = ~20% faster, minor recall loss
    /// - 0.5 = much faster, noticeable recall loss
    pub fn with_heap_factor(mut scorers: Vec<S>, k: usize, heap_factor: f32) -> Self {
        // Sort scorers by max_score ascending (non-essential terms first)
        scorers.sort_by(|a, b| {
            a.max_score()
                .partial_cmp(&b.max_score())
                .unwrap_or(Ordering::Equal)
        });

        // Compute prefix sums of max_scores
        let mut prefix_sums = Vec::with_capacity(scorers.len());
        let mut cumsum = 0.0f32;
        for s in &scorers {
            cumsum += s.max_score();
            prefix_sums.push(cumsum);
        }

        debug!(
            "Creating BlockMaxScoreExecutor: num_scorers={}, k={}, total_upper={:.4}, heap_factor={:.2}",
            scorers.len(),
            k,
            cumsum,
            heap_factor
        );

        Self {
            scorers,
            prefix_sums,
            collector: ScoreCollector::new(k),
            heap_factor: heap_factor.clamp(0.0, 1.0),
        }
    }

    /// Find partition point: [0..partition) = non-essential, [partition..n) = essential
    /// Non-essential terms have cumulative max_score <= threshold
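    ///
    /// Worked example: with max_scores `[0.5, 1.0, 2.0]` the prefix sums are
    /// `[0.5, 1.5, 3.5]`; at threshold 1.2 the partition is 1, so only the
    /// first scorer (cumulative 0.5 <= 1.2) is non-essential.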
    #[inline]
    fn find_partition(&self) -> usize {
        let threshold = self.collector.threshold() * self.heap_factor;
        self.prefix_sums
            .iter()
            .position(|&sum| sum > threshold)
            .unwrap_or(self.scorers.len())
    }

    /// Execute Block-Max MaxScore and return top-k results
    ///
    /// Algorithm:
    /// 1. Partition terms into essential/non-essential based on max_score
    /// 2. Find min_doc across essential scorers
    /// 3. Conjunction check: skip if not enough essential terms present
    /// 4. Block-max check: skip if block upper bounds can't beat threshold
    /// 5. Score essential scorers, check if non-essential scoring is needed
    /// 6. Score non-essential scorers, group by ordinal, insert results
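    ///
    /// Worked example of the pruning arithmetic (steps 3-5): with threshold 5.0,
    /// a non-essential upper bound of 1.5, and essential block maxima at min_doc
    /// summing to 3.0, the candidate is skipped because 3.0 + 1.5 <= 5.0; if the
    /// block maxima instead summed to 4.0, the document would be fully scored,
    /// since 4.0 + 1.5 > 5.0.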
    pub fn execute(mut self) -> Vec<ScoredDoc> {
        if self.scorers.is_empty() {
            debug!("BlockMaxScoreExecutor: no scorers, returning empty results");
            return Vec::new();
        }

        let n = self.scorers.len();
        let mut docs_scored = 0u64;
        let mut docs_skipped = 0u64;
        let mut blocks_skipped = 0u64;
        let mut conjunction_skipped = 0u64;

        // Pre-allocate scratch buffers outside the loop
        let mut ordinal_scores: Vec<(u16, f32)> = Vec::with_capacity(n * 2);

        loop {
            let partition = self.find_partition();

            // If all terms are non-essential, we're done
            if partition >= n {
                debug!("BlockMaxScore: all terms non-essential, early termination");
                break;
            }

            // Find minimum doc_id across essential scorers [partition..n)
            let mut min_doc = u32::MAX;
            for i in partition..n {
                let doc = self.scorers[i].doc();
                if doc < min_doc {
                    min_doc = doc;
                }
            }

            if min_doc == u32::MAX {
                break; // All essential scorers exhausted
            }

            let non_essential_upper = if partition > 0 {
                self.prefix_sums[partition - 1]
            } else {
                0.0
            };
            let adjusted_threshold = self.collector.threshold() * self.heap_factor;

            // --- Conjunction optimization (Lucene-style) ---
            // Check if enough essential terms are present at min_doc.
            // Sum max_scores of essential terms AT min_doc. If that plus
            // non-essential upper can't beat threshold, skip this doc.
            if self.collector.len() >= self.collector.k {
                let present_upper: f32 = (partition..n)
                    .filter(|&i| self.scorers[i].doc() == min_doc)
                    .map(|i| self.scorers[i].max_score())
                    .sum();

                if present_upper + non_essential_upper <= adjusted_threshold {
                    // Not enough essential terms present - advance past min_doc
                    for i in partition..n {
                        if self.scorers[i].doc() == min_doc {
                            self.scorers[i].advance();
                        }
                    }
                    conjunction_skipped += 1;
                    continue;
                }
            }

            // --- Block-max pruning ---
            // Sum block-max scores for essential scorers at min_doc.
            // If block-max sum + non-essential upper can't beat threshold, skip blocks.
            if self.collector.len() >= self.collector.k {
                let block_max_sum: f32 = (partition..n)
                    .filter(|&i| self.scorers[i].doc() == min_doc)
                    .map(|i| self.scorers[i].current_block_max_score())
                    .sum();

                if block_max_sum + non_essential_upper <= adjusted_threshold {
                    for i in partition..n {
                        if self.scorers[i].doc() == min_doc {
                            self.scorers[i].skip_to_next_block();
                        }
                    }
                    blocks_skipped += 1;
                    continue;
                }
            }

            // --- Score essential scorers ---
            // Drain all entries for min_doc from each essential scorer
            ordinal_scores.clear();

            for i in partition..n {
                if self.scorers[i].doc() == min_doc {
                    while self.scorers[i].doc() == min_doc {
                        ordinal_scores.push((self.scorers[i].ordinal(), self.scorers[i].score()));
                        self.scorers[i].advance();
                    }
                }
            }

            // Check if essential score + non-essential upper could beat threshold
            let essential_total: f32 = ordinal_scores.iter().map(|(_, s)| *s).sum();

            if self.collector.len() >= self.collector.k
                && essential_total + non_essential_upper <= adjusted_threshold
            {
                docs_skipped += 1;
                continue;
            }

            // --- Score non-essential scorers ---
            for i in 0..partition {
                let doc = self.scorers[i].seek(min_doc);
                if doc == min_doc {
                    while self.scorers[i].doc() == min_doc {
                        ordinal_scores.push((self.scorers[i].ordinal(), self.scorers[i].score()));
                        self.scorers[i].advance();
                    }
                }
            }

            // --- Group by ordinal and insert ---
            ordinal_scores.sort_unstable_by_key(|(ord, _)| *ord);
            let mut j = 0;
            while j < ordinal_scores.len() {
                let current_ord = ordinal_scores[j].0;
                let mut score = 0.0f32;
                while j < ordinal_scores.len() && ordinal_scores[j].0 == current_ord {
                    score += ordinal_scores[j].1;
                    j += 1;
                }

                trace!(
                    "Doc {}: ordinal={}, score={:.4}, threshold={:.4}",
                    min_doc, current_ord, score, adjusted_threshold
                );

                if self
                    .collector
                    .insert_with_ordinal(min_doc, score, current_ord)
                {
                    docs_scored += 1;
                } else {
                    docs_skipped += 1;
                }
            }
        }

        let results: Vec<ScoredDoc> = self
            .collector
            .into_sorted_results()
            .into_iter()
            .map(|(doc_id, score, ordinal)| ScoredDoc {
                doc_id,
                score,
                ordinal,
            })
            .collect();

        debug!(
            "BlockMaxScoreExecutor completed: scored={}, skipped={}, blocks_skipped={}, conjunction_skipped={}, returned={}, top_score={:.4}",
            docs_scored,
            docs_skipped,
            blocks_skipped,
            conjunction_skipped,
            results.len(),
            results.first().map(|r| r.score).unwrap_or(0.0)
        );

        results
    }
}

/// Scorer for full-text terms using WAND optimization
///
/// Wraps a `BlockPostingList` with BM25 parameters to implement `ScoringIterator`.
/// Enables MaxScore pruning for efficient top-k retrieval in OR queries.
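///
/// Illustrative sketch (assumes `posting_list`, a precomputed BM25 `idf`, and
/// `avg_field_len` are available from the segment's statistics):
///
/// ```ignore
/// let scorer = TextTermScorer::new(posting_list, idf, avg_field_len);
/// let top_docs = BlockMaxScoreExecutor::new(vec![scorer], 10).execute();
/// ```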
pub struct TextTermScorer {
    /// Iterator over the posting list (owned)
    iter: crate::structures::BlockPostingIterator<'static>,
    /// IDF component for BM25
    idf: f32,
    /// Average field length for BM25 normalization
    avg_field_len: f32,
    /// Pre-computed max score (using max_tf from posting list)
    max_score: f32,
}

impl TextTermScorer {
    /// Create a new text term scorer with BM25 parameters
    pub fn new(
        posting_list: crate::structures::BlockPostingList,
        idf: f32,
        avg_field_len: f32,
    ) -> Self {
        // Compute max score using actual max_tf from posting list
        let max_tf = posting_list.max_tf() as f32;
        let doc_count = posting_list.doc_count();
        let max_score = super::bm25_upper_bound(max_tf.max(1.0), idf);

        debug!(
            "Created TextTermScorer: doc_count={}, max_tf={:.0}, idf={:.4}, avg_field_len={:.2}, max_score={:.4}",
            doc_count, max_tf, idf, avg_field_len, max_score
        );

        Self {
            iter: posting_list.into_iterator(),
            idf,
            avg_field_len,
            max_score,
        }
    }
}

impl ScoringIterator for TextTermScorer {
    #[inline]
    fn doc(&self) -> DocId {
        self.iter.doc()
    }

    #[inline]
    fn advance(&mut self) -> DocId {
        self.iter.advance()
    }

    #[inline]
    fn seek(&mut self, target: DocId) -> DocId {
        self.iter.seek(target)
    }

    #[inline]
    fn score(&self) -> f32 {
        let tf = self.iter.term_freq() as f32;
        // Use tf as proxy for doc length (common approximation when field lengths aren't stored)
        super::bm25_score(tf, self.idf, tf, self.avg_field_len)
    }

    #[inline]
    fn max_score(&self) -> f32 {
        self.max_score
    }

    #[inline]
    fn current_block_max_score(&self) -> f32 {
        // Use per-block max_tf for tighter Block-Max WAND bounds
        let block_max_tf = self.iter.current_block_max_tf() as f32;
        super::bm25_upper_bound(block_max_tf.max(1.0), self.idf)
    }

    #[inline]
    fn skip_to_next_block(&mut self) -> DocId {
        self.iter.skip_to_next_block()
    }
}

/// Scorer for sparse vector dimensions
///
/// Wraps a `BlockSparsePostingList` with a query weight to implement `ScoringIterator`.
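///
/// Illustrative sketch of the two constructors (assumes a posting list for one
/// query dimension is at hand):
///
/// ```ignore
/// // Borrowed posting list:
/// let scorer = SparseTermScorer::new(&posting_list, 0.42);
/// // Shared posting list behind an Arc:
/// let scorer = SparseTermScorer::from_arc(&shared_posting_list, 0.42);
/// ```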
pub struct SparseTermScorer<'a> {
    /// Iterator over the posting list
    iter: crate::structures::BlockSparsePostingIterator<'a>,
    /// Query weight for this dimension
    query_weight: f32,
    /// Global max score (|query_weight| * global_max_weight)
    max_score: f32,
}

impl<'a> SparseTermScorer<'a> {
    /// Create a new sparse term scorer
    ///
    /// Note: the WAND upper bound is computed as `|query_weight| * global_max_weight`,
    /// so it remains valid (just looser) when the query weight is negative.
    pub fn new(posting_list: &'a BlockSparsePostingList, query_weight: f32) -> Self {
        // Upper bound must account for sign: |query_weight| * max_weight
        // This ensures the bound is valid regardless of weight sign
        let max_score = query_weight.abs() * posting_list.global_max_weight();
        Self {
            iter: posting_list.iterator(),
            query_weight,
            max_score,
        }
    }

    /// Create from Arc reference (for use with shared posting lists)
    pub fn from_arc(posting_list: &'a Arc<BlockSparsePostingList>, query_weight: f32) -> Self {
        Self::new(posting_list.as_ref(), query_weight)
    }
}

impl ScoringIterator for SparseTermScorer<'_> {
    #[inline]
    fn doc(&self) -> DocId {
        self.iter.doc()
    }

    #[inline]
    fn ordinal(&self) -> u16 {
        self.iter.ordinal()
    }

    #[inline]
    fn advance(&mut self) -> DocId {
        self.iter.advance()
    }

    #[inline]
    fn seek(&mut self, target: DocId) -> DocId {
        self.iter.seek(target)
    }

    #[inline]
    fn score(&self) -> f32 {
        // Dot product contribution: query_weight * stored_weight
        self.query_weight * self.iter.weight()
    }

    #[inline]
    fn max_score(&self) -> f32 {
        self.max_score
    }

    #[inline]
    fn current_block_max_score(&self) -> f32 {
        // Use abs() for valid upper bound with negative weights
        self.iter
            .current_block_max_contribution(self.query_weight.abs())
    }

    #[inline]
    fn skip_to_next_block(&mut self) -> DocId {
        self.iter.skip_to_next_block()
    }
}

/// Block-Max Pruning (BMP) executor for learned sparse retrieval
///
/// Processes blocks in descending order of their upper-bound contribution using
/// a priority queue. Best for queries with many terms (20+), like SPLADE expansions.
/// Uses document accumulators (FxHashMap) instead of per-term iterators.
///
/// Reference: Mallia et al., "Faster Learned Sparse Retrieval with
/// Block-Max Pruning" (SIGIR 2024)
pub struct BmpExecutor {
    /// Posting lists for each query term
    posting_lists: Vec<Arc<BlockSparsePostingList>>,
    /// Query weight for each term
    query_weights: Vec<f32>,
    /// Number of results to return
    k: usize,
    /// Heap factor for approximate search
    heap_factor: f32,
}

/// Entry in the BMP priority queue: (term_index, block_index)
struct BmpBlockEntry {
    /// Upper bound contribution of this block
    contribution: f32,
    /// Index into posting_lists
    term_idx: usize,
    /// Block index within the posting list
    block_idx: usize,
}

impl PartialEq for BmpBlockEntry {
    fn eq(&self, other: &Self) -> bool {
        self.contribution == other.contribution
    }
}

impl Eq for BmpBlockEntry {}

impl Ord for BmpBlockEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        // Max-heap: higher contributions come first
        self.contribution
            .partial_cmp(&other.contribution)
            .unwrap_or(Ordering::Equal)
    }
}

impl PartialOrd for BmpBlockEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl BmpExecutor {
    /// Create a new BMP executor
    pub fn new(
        posting_lists: Vec<Arc<BlockSparsePostingList>>,
        query_weights: Vec<f32>,
        k: usize,
        heap_factor: f32,
    ) -> Self {
        Self {
            posting_lists,
            query_weights,
            k,
            heap_factor: heap_factor.clamp(0.0, 1.0),
        }
    }

    /// Execute BMP and return top-k results
    pub fn execute(self) -> Vec<ScoredDoc> {
        use rustc_hash::FxHashMap;

        if self.posting_lists.is_empty() {
            return Vec::new();
        }

        let num_terms = self.posting_lists.len();

        // Build priority queue of all blocks across all terms
        let mut block_queue: BinaryHeap<BmpBlockEntry> = BinaryHeap::new();

        // Track remaining upper bound per term (sum of unprocessed block contributions)
        let mut remaining_max: Vec<f32> = Vec::with_capacity(num_terms);

        for (term_idx, pl) in self.posting_lists.iter().enumerate() {
            let qw = self.query_weights[term_idx].abs();
            let mut term_remaining = 0.0f32;

            for block_idx in 0..pl.num_blocks() {
                if let Some(block_max_weight) = pl.block_max_weight(block_idx) {
                    let contribution = qw * block_max_weight;
                    term_remaining += contribution;
                    block_queue.push(BmpBlockEntry {
                        contribution,
                        term_idx,
                        block_idx,
                    });
                }
            }
            remaining_max.push(term_remaining);
        }

        // Document accumulators: (doc_id, ordinal) -> accumulated_score
        // Using (doc_id, ordinal) as key ensures scores from different ordinals
        // are NOT mixed together for multi-valued sparse vector fields.
        let mut accumulators: FxHashMap<(DocId, u16), f32> = FxHashMap::default();
        let mut collector = ScoreCollector::new(self.k);
        let mut blocks_processed = 0u64;

        // Process blocks in contribution-descending order
        while let Some(entry) = block_queue.pop() {
            // Update remaining max for this term
            remaining_max[entry.term_idx] -= entry.contribution;

            // Early termination: check if total remaining across all terms
            // can beat the current threshold
            let total_remaining: f32 = remaining_max.iter().sum();
            let adjusted_threshold = collector.threshold() * self.heap_factor;
            if collector.len() >= self.k && total_remaining <= adjusted_threshold {
                debug!(
                    "BMP early termination after {} blocks: remaining={:.4} <= threshold={:.4}",
                    blocks_processed, total_remaining, adjusted_threshold
                );
                break;
            }

            // Decode this block and accumulate scores
            let pl = &self.posting_lists[entry.term_idx];
            let block = &pl.blocks[entry.block_idx];
            let doc_ids = block.decode_doc_ids();
            let weights = block.decode_weights();
            let ordinals = block.decode_ordinals();
            let qw = self.query_weights[entry.term_idx];

            for i in 0..block.header.count as usize {
                let score_contribution = qw * weights[i];
                *accumulators.entry((doc_ids[i], ordinals[i])).or_insert(0.0) += score_contribution;
            }

            blocks_processed += 1;
        }

        // Flush accumulators to collector
        for (&(doc_id, ordinal), &score) in &accumulators {
            collector.insert_with_ordinal(doc_id, score, ordinal);
        }

        let results: Vec<ScoredDoc> = collector
            .into_sorted_results()
            .into_iter()
            .map(|(doc_id, score, ordinal)| ScoredDoc {
                doc_id,
                score,
                ordinal,
            })
            .collect();

        debug!(
            "BmpExecutor completed: blocks_processed={}, accumulators={}, returned={}, top_score={:.4}",
            blocks_processed,
            accumulators.len(),
            results.len(),
            results.first().map(|r| r.score).unwrap_or(0.0)
        );

        results
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_score_collector_basic() {
        let mut collector = ScoreCollector::new(3);

        collector.insert(1, 1.0);
        collector.insert(2, 2.0);
        collector.insert(3, 3.0);
        assert_eq!(collector.threshold(), 1.0);

        collector.insert(4, 4.0);
        assert_eq!(collector.threshold(), 2.0);

        let results = collector.into_sorted_results();
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].0, 4); // Highest score
        assert_eq!(results[1].0, 3);
        assert_eq!(results[2].0, 2);
    }

    #[test]
    fn test_score_collector_threshold() {
        let mut collector = ScoreCollector::new(2);

        collector.insert(1, 5.0);
        collector.insert(2, 3.0);
        assert_eq!(collector.threshold(), 3.0);

        // Should not enter (score too low)
        assert!(!collector.would_enter(2.0));
        assert!(!collector.insert(3, 2.0));

        // Should enter (score high enough)
        assert!(collector.would_enter(4.0));
        assert!(collector.insert(4, 4.0));
        assert_eq!(collector.threshold(), 4.0);
    }

    #[test]
    fn test_heap_entry_ordering() {
        let mut heap = BinaryHeap::new();
        heap.push(HeapEntry {
            doc_id: 1,
            score: 3.0,
            ordinal: 0,
        });
        heap.push(HeapEntry {
            doc_id: 2,
            score: 1.0,
            ordinal: 0,
        });
        heap.push(HeapEntry {
            doc_id: 3,
            score: 2.0,
            ordinal: 0,
        });

        // Min-heap: lowest score should come out first
        assert_eq!(heap.pop().unwrap().score, 1.0);
        assert_eq!(heap.pop().unwrap().score, 2.0);
        assert_eq!(heap.pop().unwrap().score, 3.0);
    }
    }
}
862}