Skip to main content

hermes_core/query/
scoring.rs

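//! Shared scoring abstractions for text and sparse vector search
//!
//! Provides common traits and utilities for efficient top-k retrieval:
//! - `ScoringIterator`: common interface for posting-list iteration with scoring
//! - `ScoreCollector`: min-heap that maintains the current top-k results
//! - `WandExecutor`: generic WAND / block-max WAND executor
//! - `MaxScoreExecutor`: MaxScore executor with essential/non-essential term partitioning
//! - `BmpExecutor`: block-ordered executor for long learned-sparse queries
//! - `TextTermScorer`: `ScoringIterator` implementation for BM25 text terms
//! - `SparseTermScorer`: `ScoringIterator` implementation for sparse vectors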
8
9use std::cmp::Ordering;
10use std::collections::BinaryHeap;
11use std::sync::Arc;
12
13use log::{debug, trace};
14
15use crate::DocId;
16use crate::structures::BlockSparsePostingList;
17
18/// Common interface for scoring iterators (text terms or sparse dimensions)
19///
20/// Abstracts the common operations needed for WAND-style top-k retrieval.
21pub trait ScoringIterator {
22    /// Current document ID (u32::MAX if exhausted)
23    fn doc(&self) -> DocId;
24
25    /// Current ordinal for multi-valued fields (default 0)
26    fn ordinal(&self) -> u16 {
27        0
28    }
29
30    /// Advance to next document, returns new doc ID
31    fn advance(&mut self) -> DocId;
32
33    /// Seek to first document >= target, returns new doc ID
34    fn seek(&mut self, target: DocId) -> DocId;
35
36    /// Check if iterator is exhausted
37    fn is_exhausted(&self) -> bool {
38        self.doc() == u32::MAX
39    }
40
41    /// Score contribution for current document
42    fn score(&self) -> f32;
43
44    /// Maximum possible score for this term/dimension (global upper bound)
45    fn max_score(&self) -> f32;
46
47    /// Current block's maximum score upper bound (for block-level pruning)
48    fn current_block_max_score(&self) -> f32;
49
50    /// Skip to the next block, returning the first doc_id in the new block.
51    /// Used for block-max WAND optimization when current block can't beat threshold.
52    /// Default implementation just advances (no block-level skipping).
53    fn skip_to_next_block(&mut self) -> DocId {
54        self.advance()
55    }
56}
57
58/// Entry for top-k min-heap
59#[derive(Clone, Copy)]
60pub struct HeapEntry {
61    pub doc_id: DocId,
62    pub score: f32,
63    pub ordinal: u16,
64}
65
66impl PartialEq for HeapEntry {
67    fn eq(&self, other: &Self) -> bool {
68        self.score == other.score && self.doc_id == other.doc_id
69    }
70}
71
72impl Eq for HeapEntry {}
73
74impl Ord for HeapEntry {
75    fn cmp(&self, other: &Self) -> Ordering {
76        // Min-heap: lower scores come first (to be evicted)
77        other
78            .score
79            .partial_cmp(&self.score)
80            .unwrap_or(Ordering::Equal)
81            .then_with(|| self.doc_id.cmp(&other.doc_id))
82    }
83}
84
85impl PartialOrd for HeapEntry {
86    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
87        Some(self.cmp(other))
88    }
89}
90
91/// Efficient top-k collector using min-heap
92///
93/// Maintains the k highest-scoring documents using a min-heap where the
94/// lowest score is at the top for O(1) threshold lookup and O(log k) eviction.
95/// No deduplication - caller must ensure each doc_id is inserted only once.
96pub struct ScoreCollector {
97    /// Min-heap of top-k entries (lowest score at top for eviction)
98    heap: BinaryHeap<HeapEntry>,
99    pub k: usize,
100}
101
102impl ScoreCollector {
103    /// Create a new collector for top-k results
104    pub fn new(k: usize) -> Self {
105        // Cap capacity to avoid allocation overflow for very large k
106        let capacity = k.saturating_add(1).min(1_000_000);
107        Self {
108            heap: BinaryHeap::with_capacity(capacity),
109            k,
110        }
111    }
112
113    /// Current score threshold (minimum score to enter top-k)
114    #[inline]
115    pub fn threshold(&self) -> f32 {
116        if self.heap.len() >= self.k {
117            self.heap.peek().map(|e| e.score).unwrap_or(0.0)
118        } else {
119            0.0
120        }
121    }
122
123    /// Insert a document score. Returns true if inserted in top-k.
124    /// Caller must ensure each doc_id is inserted only once.
125    #[inline]
126    pub fn insert(&mut self, doc_id: DocId, score: f32) -> bool {
127        self.insert_with_ordinal(doc_id, score, 0)
128    }
129
130    /// Insert a document score with ordinal. Returns true if inserted in top-k.
131    /// Caller must ensure each doc_id is inserted only once.
132    #[inline]
133    pub fn insert_with_ordinal(&mut self, doc_id: DocId, score: f32, ordinal: u16) -> bool {
134        if self.heap.len() < self.k {
135            self.heap.push(HeapEntry {
136                doc_id,
137                score,
138                ordinal,
139            });
140            true
141        } else if score > self.threshold() {
142            self.heap.push(HeapEntry {
143                doc_id,
144                score,
145                ordinal,
146            });
147            self.heap.pop(); // Remove lowest
148            true
149        } else {
150            false
151        }
152    }
153
154    /// Check if a score could potentially enter top-k
155    #[inline]
156    pub fn would_enter(&self, score: f32) -> bool {
157        self.heap.len() < self.k || score > self.threshold()
158    }
159
160    /// Get number of documents collected so far
161    #[inline]
162    pub fn len(&self) -> usize {
163        self.heap.len()
164    }
165
166    /// Check if collector is empty
167    #[inline]
168    pub fn is_empty(&self) -> bool {
169        self.heap.is_empty()
170    }
171
172    /// Convert to sorted top-k results (descending by score)
173    pub fn into_sorted_results(self) -> Vec<(DocId, f32, u16)> {
174        let mut results: Vec<_> = self
175            .heap
176            .into_vec()
177            .into_iter()
178            .map(|e| (e.doc_id, e.score, e.ordinal))
179            .collect();
180
181        // Sort by score descending, then doc_id ascending
182        results.sort_by(|a, b| {
183            b.1.partial_cmp(&a.1)
184                .unwrap_or(Ordering::Equal)
185                .then_with(|| a.0.cmp(&b.0))
186        });
187
188        results
189    }
190}
191
192/// Search result from WAND execution
193#[derive(Debug, Clone, Copy)]
194pub struct ScoredDoc {
195    pub doc_id: DocId,
196    pub score: f32,
197    /// Ordinal for multi-valued fields (which vector in the field matched)
198    pub ordinal: u16,
199}
200
201/// Generic MaxScore WAND executor for top-k retrieval
202///
203/// Works with any type implementing `ScoringIterator`.
204/// Implements:
205/// - WAND pivot-based pruning: skip documents that can't beat threshold
206/// - Block-max WAND: skip blocks that can't beat threshold
207/// - Efficient top-k collection
208pub struct WandExecutor<S: ScoringIterator> {
209    /// Scorers for each query term
210    scorers: Vec<S>,
211    /// Top-k collector
212    collector: ScoreCollector,
213    /// Heap factor for approximate search (SEISMIC-style)
214    /// A block/document is skipped if max_possible < heap_factor * threshold
215    /// - 1.0 = exact search (default)
216    /// - 0.8 = approximate, faster with minor recall loss
217    heap_factor: f32,
218}
219
220impl<S: ScoringIterator> WandExecutor<S> {
221    /// Create a new WAND executor with exact search (heap_factor = 1.0)
222    pub fn new(scorers: Vec<S>, k: usize) -> Self {
223        Self::with_heap_factor(scorers, k, 1.0)
224    }
225
226    /// Create a new WAND executor with approximate search
227    ///
228    /// `heap_factor` controls the trade-off between speed and recall:
229    /// - 1.0 = exact search
230    /// - 0.8 = ~20% faster, minor recall loss
231    /// - 0.5 = much faster, noticeable recall loss
232    pub fn with_heap_factor(scorers: Vec<S>, k: usize, heap_factor: f32) -> Self {
233        let total_upper: f32 = scorers.iter().map(|s| s.max_score()).sum();
234
235        debug!(
236            "Creating WandExecutor: num_scorers={}, k={}, total_upper={:.4}, heap_factor={:.2}",
237            scorers.len(),
238            k,
239            total_upper,
240            heap_factor
241        );
242
243        Self {
244            scorers,
245            collector: ScoreCollector::new(k),
246            heap_factor: heap_factor.clamp(0.0, 1.0),
247        }
248    }
249
250    /// Execute WAND and return top-k results
251    ///
252    /// Implements the WAND (Weak AND) algorithm with pivot-based pruning:
253    /// 1. Maintain iterators sorted by current docID (using sorted vector)
254    /// 2. Find pivot: first term where cumulative upper bounds > threshold
255    /// 3. If all iterators at pivot docID, fully score; otherwise skip to pivot
256    /// 4. Insert into collector and advance
257    ///
258    /// Reference: Broder et al., "Efficient Query Evaluation using a Two-Level
259    /// Retrieval Process" (CIKM 2003)
260    ///
261    /// Note: For small number of terms (typical queries), a sorted vector with
262    /// insertion sort is faster than a heap due to better cache locality.
263    /// The vector stays mostly sorted, so insertion sort is ~O(n) amortized.
264    pub fn execute(mut self) -> Vec<ScoredDoc> {
265        if self.scorers.is_empty() {
266            debug!("WandExecutor: no scorers, returning empty results");
267            return Vec::new();
268        }
269
270        let mut docs_scored = 0u64;
271        let mut docs_skipped = 0u64;
272        let mut blocks_skipped = 0u64;
273        let num_scorers = self.scorers.len();
274
275        // Indices sorted by current docID - initial sort O(n log n)
276        let mut sorted_indices: Vec<usize> = (0..num_scorers).collect();
277        sorted_indices.sort_by_key(|&i| self.scorers[i].doc());
278
279        loop {
280            // Find first non-exhausted iterator (they're sorted, so check first)
281            let first_active = sorted_indices
282                .iter()
283                .position(|&i| self.scorers[i].doc() != u32::MAX);
284
285            let first_active = match first_active {
286                Some(pos) => pos,
287                None => break, // All exhausted
288            };
289
290            // Early termination: if total upper bound can't beat (adjusted) threshold
291            // heap_factor < 1.0 makes pruning more aggressive (approximate search)
292            let total_upper: f32 = sorted_indices[first_active..]
293                .iter()
294                .map(|&i| self.scorers[i].max_score())
295                .sum();
296
297            let adjusted_threshold = self.collector.threshold() * self.heap_factor;
298            if self.collector.len() >= self.collector.k && total_upper <= adjusted_threshold {
299                debug!(
300                    "Early termination: upper_bound={:.4} <= adjusted_threshold={:.4}",
301                    total_upper, adjusted_threshold
302                );
303                break;
304            }
305
306            // Find pivot: first term where cumulative upper bounds > adjusted threshold
307            let mut cumsum = 0.0f32;
308            let mut pivot_pos = first_active;
309
310            for (pos, &idx) in sorted_indices.iter().enumerate().skip(first_active) {
311                cumsum += self.scorers[idx].max_score();
312                if cumsum > adjusted_threshold || self.collector.len() < self.collector.k {
313                    pivot_pos = pos;
314                    break;
315                }
316            }
317
318            let pivot_idx = sorted_indices[pivot_pos];
319            let pivot_doc = self.scorers[pivot_idx].doc();
320
321            if pivot_doc == u32::MAX {
322                break;
323            }
324
325            // Check if all iterators before pivot are at pivot_doc
326            let all_at_pivot = sorted_indices[first_active..=pivot_pos]
327                .iter()
328                .all(|&i| self.scorers[i].doc() == pivot_doc);
329
330            if all_at_pivot {
331                // Block-max WAND optimization: check if current blocks can beat threshold
332                // Sum the block-max scores for all iterators at pivot_doc
333                let block_max_sum: f32 = sorted_indices[first_active..=pivot_pos]
334                    .iter()
335                    .filter(|&&i| self.scorers[i].doc() == pivot_doc)
336                    .map(|&i| self.scorers[i].current_block_max_score())
337                    .sum();
338
339                if self.collector.len() >= self.collector.k && block_max_sum <= adjusted_threshold {
340                    // Block can't beat threshold - skip all iterators at pivot to next block
341                    debug!(
342                        "Block skip at doc {}: block_max={:.4} <= threshold={:.4}",
343                        pivot_doc, block_max_sum, adjusted_threshold
344                    );
345
346                    for (_pos, &idx) in sorted_indices.iter().enumerate().skip(first_active) {
347                        if self.scorers[idx].doc() == pivot_doc {
348                            self.scorers[idx].skip_to_next_block();
349                        } else if self.scorers[idx].doc() > pivot_doc {
350                            break;
351                        }
352                    }
353
354                    // Re-sort after block skips
355                    sorted_indices[first_active..].sort_by_key(|&i| self.scorers[i].doc());
356                    blocks_skipped += 1;
357                    continue;
358                }
359
360                // All terms up to pivot are at the same doc - fully score it
361                let mut score = 0.0f32;
362                let mut matching_terms = 0u32;
363                let mut ordinal: u16 = 0;
364
365                // Score from all iterators that have this document and advance them
366                // Collect indices that need re-sorting
367                let mut modified_positions: Vec<usize> = Vec::new();
368
369                for (pos, &idx) in sorted_indices.iter().enumerate().skip(first_active) {
370                    let doc = self.scorers[idx].doc();
371                    if doc == pivot_doc {
372                        score += self.scorers[idx].score();
373                        // Take ordinal from first matching scorer (all should have same ordinal)
374                        if matching_terms == 0 {
375                            ordinal = self.scorers[idx].ordinal();
376                        }
377                        matching_terms += 1;
378                        self.scorers[idx].advance();
379                        modified_positions.push(pos);
380                    } else if doc > pivot_doc {
381                        break;
382                    }
383                }
384
385                trace!(
386                    "Doc {}: score={:.4}, matching={}/{}, threshold={:.4}",
387                    pivot_doc, score, matching_terms, num_scorers, adjusted_threshold
388                );
389
390                if self
391                    .collector
392                    .insert_with_ordinal(pivot_doc, score, ordinal)
393                {
394                    docs_scored += 1;
395                } else {
396                    docs_skipped += 1;
397                }
398
399                // Re-sort modified iterators using insertion sort (efficient for nearly-sorted)
400                // Move each modified iterator to its correct position
401                for &pos in modified_positions.iter().rev() {
402                    let idx = sorted_indices[pos];
403                    let new_doc = self.scorers[idx].doc();
404                    // Bubble up to correct position
405                    let mut curr = pos;
406                    while curr + 1 < sorted_indices.len()
407                        && self.scorers[sorted_indices[curr + 1]].doc() < new_doc
408                    {
409                        sorted_indices.swap(curr, curr + 1);
410                        curr += 1;
411                    }
412                }
413            } else {
414                // Not all at pivot - skip the first iterator to pivot_doc
415                let first_pos = first_active;
416                let first_idx = sorted_indices[first_pos];
417                self.scorers[first_idx].seek(pivot_doc);
418                docs_skipped += 1;
419
420                // Re-sort the modified iterator
421                let new_doc = self.scorers[first_idx].doc();
422                let mut curr = first_pos;
423                while curr + 1 < sorted_indices.len()
424                    && self.scorers[sorted_indices[curr + 1]].doc() < new_doc
425                {
426                    sorted_indices.swap(curr, curr + 1);
427                    curr += 1;
428                }
429            }
430        }
431
432        let results: Vec<ScoredDoc> = self
433            .collector
434            .into_sorted_results()
435            .into_iter()
436            .map(|(doc_id, score, ordinal)| ScoredDoc {
437                doc_id,
438                score,
439                ordinal,
440            })
441            .collect();
442
443        debug!(
444            "WandExecutor completed: scored={}, skipped={}, blocks_skipped={}, returned={}, top_score={:.4}",
445            docs_scored,
446            docs_skipped,
447            blocks_skipped,
448            results.len(),
449            results.first().map(|r| r.score).unwrap_or(0.0)
450        );
451
452        results
453    }
454}
455
456/// Scorer for full-text terms using WAND optimization
457///
458/// Wraps a `BlockPostingList` with BM25 parameters to implement `ScoringIterator`.
459/// Enables MaxScore pruning for efficient top-k retrieval in OR queries.
460pub struct TextTermScorer {
461    /// Iterator over the posting list (owned)
462    iter: crate::structures::BlockPostingIterator<'static>,
463    /// IDF component for BM25
464    idf: f32,
465    /// Average field length for BM25 normalization
466    avg_field_len: f32,
467    /// Pre-computed max score (using max_tf from posting list)
468    max_score: f32,
469}
470
471impl TextTermScorer {
472    /// Create a new text term scorer with BM25 parameters
473    pub fn new(
474        posting_list: crate::structures::BlockPostingList,
475        idf: f32,
476        avg_field_len: f32,
477    ) -> Self {
478        // Compute max score using actual max_tf from posting list
479        let max_tf = posting_list.max_tf() as f32;
480        let doc_count = posting_list.doc_count();
481        let max_score = super::bm25_upper_bound(max_tf.max(1.0), idf);
482
483        debug!(
484            "Created TextTermScorer: doc_count={}, max_tf={:.0}, idf={:.4}, avg_field_len={:.2}, max_score={:.4}",
485            doc_count, max_tf, idf, avg_field_len, max_score
486        );
487
488        Self {
489            iter: posting_list.into_iterator(),
490            idf,
491            avg_field_len,
492            max_score,
493        }
494    }
495}
496
497impl ScoringIterator for TextTermScorer {
498    #[inline]
499    fn doc(&self) -> DocId {
500        self.iter.doc()
501    }
502
503    #[inline]
504    fn advance(&mut self) -> DocId {
505        self.iter.advance()
506    }
507
508    #[inline]
509    fn seek(&mut self, target: DocId) -> DocId {
510        self.iter.seek(target)
511    }
512
513    #[inline]
514    fn score(&self) -> f32 {
515        let tf = self.iter.term_freq() as f32;
516        // Use tf as proxy for doc length (common approximation when field lengths aren't stored)
517        super::bm25_score(tf, self.idf, tf, self.avg_field_len)
518    }
519
520    #[inline]
521    fn max_score(&self) -> f32 {
522        self.max_score
523    }
524
525    #[inline]
526    fn current_block_max_score(&self) -> f32 {
527        // Use per-block max_tf for tighter Block-Max WAND bounds
528        let block_max_tf = self.iter.current_block_max_tf() as f32;
529        super::bm25_upper_bound(block_max_tf.max(1.0), self.idf)
530    }
531
532    #[inline]
533    fn skip_to_next_block(&mut self) -> DocId {
534        self.iter.skip_to_next_block()
535    }
536}
537
538/// Scorer for sparse vector dimensions
539///
540/// Wraps a `BlockSparsePostingList` with a query weight to implement `ScoringIterator`.
541pub struct SparseTermScorer<'a> {
542    /// Iterator over the posting list
543    iter: crate::structures::BlockSparsePostingIterator<'a>,
544    /// Query weight for this dimension
545    query_weight: f32,
546    /// Global max score (query_weight * global_max_weight)
547    max_score: f32,
548}
549
550impl<'a> SparseTermScorer<'a> {
551    /// Create a new sparse term scorer
552    ///
553    /// Note: Assumes positive weights for WAND upper bound calculation.
554    /// For negative query weights, uses absolute value to ensure valid upper bound.
555    pub fn new(posting_list: &'a BlockSparsePostingList, query_weight: f32) -> Self {
556        // Upper bound must account for sign: |query_weight| * max_weight
557        // This ensures the bound is valid regardless of weight sign
558        let max_score = query_weight.abs() * posting_list.global_max_weight();
559        Self {
560            iter: posting_list.iterator(),
561            query_weight,
562            max_score,
563        }
564    }
565
566    /// Create from Arc reference (for use with shared posting lists)
567    pub fn from_arc(posting_list: &'a Arc<BlockSparsePostingList>, query_weight: f32) -> Self {
568        Self::new(posting_list.as_ref(), query_weight)
569    }
570}
571
572impl ScoringIterator for SparseTermScorer<'_> {
573    #[inline]
574    fn doc(&self) -> DocId {
575        self.iter.doc()
576    }
577
578    #[inline]
579    fn ordinal(&self) -> u16 {
580        self.iter.ordinal()
581    }
582
583    #[inline]
584    fn advance(&mut self) -> DocId {
585        self.iter.advance()
586    }
587
588    #[inline]
589    fn seek(&mut self, target: DocId) -> DocId {
590        self.iter.seek(target)
591    }
592
593    #[inline]
594    fn score(&self) -> f32 {
595        // Dot product contribution: query_weight * stored_weight
596        self.query_weight * self.iter.weight()
597    }
598
599    #[inline]
600    fn max_score(&self) -> f32 {
601        self.max_score
602    }
603
604    #[inline]
605    fn current_block_max_score(&self) -> f32 {
606        // Use abs() for valid upper bound with negative weights
607        self.iter
608            .current_block_max_contribution(self.query_weight.abs())
609    }
610
611    #[inline]
612    fn skip_to_next_block(&mut self) -> DocId {
613        self.iter.skip_to_next_block()
614    }
615}
616
617/// Block-Max Pruning (BMP) executor for learned sparse retrieval
618///
619/// Processes blocks in score-descending order using a priority queue.
620/// Best for queries with many terms (20+), like SPLADE expansions.
621/// Uses document accumulators (FxHashMap) instead of per-term iterators.
622///
623/// Reference: Mallia et al., "Faster Learned Sparse Retrieval with
624/// Block-Max Pruning" (SIGIR 2024)
625pub struct BmpExecutor {
626    /// Posting lists for each query term
627    posting_lists: Vec<Arc<BlockSparsePostingList>>,
628    /// Query weight for each term
629    query_weights: Vec<f32>,
630    /// Number of results to return
631    k: usize,
632    /// Heap factor for approximate search
633    heap_factor: f32,
634}
635
636/// Entry in the BMP priority queue: (term_index, block_index)
637struct BmpBlockEntry {
638    /// Upper bound contribution of this block
639    contribution: f32,
640    /// Index into posting_lists
641    term_idx: usize,
642    /// Block index within the posting list
643    block_idx: usize,
644}
645
646impl PartialEq for BmpBlockEntry {
647    fn eq(&self, other: &Self) -> bool {
648        self.contribution == other.contribution
649    }
650}
651
652impl Eq for BmpBlockEntry {}
653
654impl Ord for BmpBlockEntry {
655    fn cmp(&self, other: &Self) -> Ordering {
656        // Max-heap: higher contributions come first
657        self.contribution
658            .partial_cmp(&other.contribution)
659            .unwrap_or(Ordering::Equal)
660    }
661}
662
663impl PartialOrd for BmpBlockEntry {
664    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
665        Some(self.cmp(other))
666    }
667}
668
669impl BmpExecutor {
670    /// Create a new BMP executor
671    pub fn new(
672        posting_lists: Vec<Arc<BlockSparsePostingList>>,
673        query_weights: Vec<f32>,
674        k: usize,
675        heap_factor: f32,
676    ) -> Self {
677        Self {
678            posting_lists,
679            query_weights,
680            k,
681            heap_factor: heap_factor.clamp(0.0, 1.0),
682        }
683    }
684
685    /// Execute BMP and return top-k results
686    pub fn execute(self) -> Vec<ScoredDoc> {
687        use rustc_hash::FxHashMap;
688
689        if self.posting_lists.is_empty() {
690            return Vec::new();
691        }
692
693        let num_terms = self.posting_lists.len();
694
695        // Build priority queue of all blocks across all terms
696        let mut block_queue: BinaryHeap<BmpBlockEntry> = BinaryHeap::new();
697
698        // Track remaining upper bound per term (sum of unprocessed block contributions)
699        let mut remaining_max: Vec<f32> = Vec::with_capacity(num_terms);
700
701        for (term_idx, pl) in self.posting_lists.iter().enumerate() {
702            let qw = self.query_weights[term_idx].abs();
703            let mut term_remaining = 0.0f32;
704
705            for block_idx in 0..pl.num_blocks() {
706                if let Some(block_max_weight) = pl.block_max_weight(block_idx) {
707                    let contribution = qw * block_max_weight;
708                    term_remaining += contribution;
709                    block_queue.push(BmpBlockEntry {
710                        contribution,
711                        term_idx,
712                        block_idx,
713                    });
714                }
715            }
716            remaining_max.push(term_remaining);
717        }
718
719        // Document accumulators: doc_id -> (accumulated_score, best_ordinal)
720        let mut accumulators: FxHashMap<DocId, (f32, u16)> = FxHashMap::default();
721        let mut collector = ScoreCollector::new(self.k);
722        let mut blocks_processed = 0u64;
723
724        // Process blocks in contribution-descending order
725        while let Some(entry) = block_queue.pop() {
726            // Update remaining max for this term
727            remaining_max[entry.term_idx] -= entry.contribution;
728
729            // Early termination: check if total remaining across all terms
730            // can beat the current threshold
731            let total_remaining: f32 = remaining_max.iter().sum();
732            let adjusted_threshold = collector.threshold() * self.heap_factor;
733            if collector.len() >= self.k && total_remaining <= adjusted_threshold {
734                debug!(
735                    "BMP early termination after {} blocks: remaining={:.4} <= threshold={:.4}",
736                    blocks_processed, total_remaining, adjusted_threshold
737                );
738                break;
739            }
740
741            // Decode this block and accumulate scores
742            let pl = &self.posting_lists[entry.term_idx];
743            let block = &pl.blocks[entry.block_idx];
744            let doc_ids = block.decode_doc_ids();
745            let weights = block.decode_weights();
746            let ordinals = block.decode_ordinals();
747            let qw = self.query_weights[entry.term_idx];
748
749            for i in 0..block.header.count as usize {
750                let score_contribution = qw * weights[i];
751                let acc = accumulators.entry(doc_ids[i]).or_insert((0.0, ordinals[i]));
752                acc.0 += score_contribution;
753            }
754
755            blocks_processed += 1;
756        }
757
758        // Flush accumulators to collector
759        for (doc_id, (score, ordinal)) in &accumulators {
760            collector.insert_with_ordinal(*doc_id, *score, *ordinal);
761        }
762
763        let results: Vec<ScoredDoc> = collector
764            .into_sorted_results()
765            .into_iter()
766            .map(|(doc_id, score, ordinal)| ScoredDoc {
767                doc_id,
768                score,
769                ordinal,
770            })
771            .collect();
772
773        debug!(
774            "BmpExecutor completed: blocks_processed={}, accumulators={}, returned={}, top_score={:.4}",
775            blocks_processed,
776            accumulators.len(),
777            results.len(),
778            results.first().map(|r| r.score).unwrap_or(0.0)
779        );
780
781        results
782    }
783}
784
785/// MaxScore executor with essential/non-essential term partitioning
786///
787/// For medium-length queries (6-20 terms), partitions terms into:
788/// - Essential terms: must be checked for every candidate document
789/// - Non-essential terms: only scored when candidate could enter top-k
790///
791/// Reference: Turtle & Flood, "Query Evaluation: Strategies and
792/// Optimizations" (Information Processing & Management, 1995)
793pub struct MaxScoreExecutor<S: ScoringIterator> {
794    /// Scorers sorted by max_score ascending
795    scorers: Vec<S>,
796    /// Cumulative max_score prefix sums (prefix_sums[i] = sum of max_scores[0..=i])
797    prefix_sums: Vec<f32>,
798    /// Top-k collector
799    collector: ScoreCollector,
800    /// Heap factor for approximate search
801    heap_factor: f32,
802}
803
804impl<S: ScoringIterator> MaxScoreExecutor<S> {
805    /// Create a new MaxScore executor
806    pub fn new(scorers: Vec<S>, k: usize) -> Self {
807        Self::with_heap_factor(scorers, k, 1.0)
808    }
809
810    /// Create a new MaxScore executor with approximate search
811    pub fn with_heap_factor(mut scorers: Vec<S>, k: usize, heap_factor: f32) -> Self {
812        // Sort scorers by max_score ascending (non-essential terms first)
813        scorers.sort_by(|a, b| {
814            a.max_score()
815                .partial_cmp(&b.max_score())
816                .unwrap_or(Ordering::Equal)
817        });
818
819        // Compute prefix sums of max_scores
820        let mut prefix_sums = Vec::with_capacity(scorers.len());
821        let mut cumsum = 0.0f32;
822        for s in &scorers {
823            cumsum += s.max_score();
824            prefix_sums.push(cumsum);
825        }
826
827        debug!(
828            "Creating MaxScoreExecutor: num_scorers={}, k={}, total_upper={:.4}, heap_factor={:.2}",
829            scorers.len(),
830            k,
831            cumsum,
832            heap_factor
833        );
834
835        Self {
836            scorers,
837            prefix_sums,
838            collector: ScoreCollector::new(k),
839            heap_factor: heap_factor.clamp(0.0, 1.0),
840        }
841    }
842
843    /// Find partition point: [0..partition) = non-essential, [partition..) = essential
844    /// Non-essential terms have cumulative max_score <= threshold
845    fn find_partition(&self) -> usize {
846        let threshold = self.collector.threshold() * self.heap_factor;
847        // Find first index where prefix_sums[i] > threshold
848        // Everything before that is non-essential
849        self.prefix_sums
850            .iter()
851            .position(|&sum| sum > threshold)
852            .unwrap_or(self.scorers.len())
853    }
854
855    /// Execute MaxScore and return top-k results
856    pub fn execute(mut self) -> Vec<ScoredDoc> {
857        if self.scorers.is_empty() {
858            return Vec::new();
859        }
860
861        let n = self.scorers.len();
862        let mut docs_scored = 0u64;
863        let mut docs_skipped = 0u64;
864
865        loop {
866            let partition = self.find_partition();
867
868            // If all terms are non-essential, we're done
869            if partition >= n {
870                debug!("MaxScore: all terms non-essential, early termination");
871                break;
872            }
873
874            // Find minimum doc_id across essential scorers [partition..n)
875            let mut min_doc = u32::MAX;
876            for i in partition..n {
877                let doc = self.scorers[i].doc();
878                if doc < min_doc {
879                    min_doc = doc;
880                }
881            }
882
883            if min_doc == u32::MAX {
884                break; // All essential scorers exhausted
885            }
886
887            // Score from essential scorers and advance them
888            let mut score = 0.0f32;
889            let mut ordinal = 0u16;
890            let mut first_match = true;
891
892            for i in partition..n {
893                if self.scorers[i].doc() == min_doc {
894                    score += self.scorers[i].score();
895                    if first_match {
896                        ordinal = self.scorers[i].ordinal();
897                        first_match = false;
898                    }
899                    self.scorers[i].advance();
900                }
901            }
902
903            // Check if score + non-essential upper bound could beat threshold
904            let non_essential_upper = if partition > 0 {
905                self.prefix_sums[partition - 1]
906            } else {
907                0.0
908            };
909
910            let adjusted_threshold = self.collector.threshold() * self.heap_factor;
911
912            if self.collector.len() >= self.collector.k
913                && score + non_essential_upper <= adjusted_threshold
914            {
915                docs_skipped += 1;
916                continue;
917            }
918
919            // Score from non-essential scorers (seek to min_doc)
920            for i in 0..partition {
921                let doc = self.scorers[i].seek(min_doc);
922                if doc == min_doc {
923                    score += self.scorers[i].score();
924                    self.scorers[i].advance();
925                }
926            }
927
928            self.collector.insert_with_ordinal(min_doc, score, ordinal);
929            docs_scored += 1;
930        }
931
932        let results: Vec<ScoredDoc> = self
933            .collector
934            .into_sorted_results()
935            .into_iter()
936            .map(|(doc_id, score, ordinal)| ScoredDoc {
937                doc_id,
938                score,
939                ordinal,
940            })
941            .collect();
942
943        debug!(
944            "MaxScoreExecutor completed: scored={}, skipped={}, returned={}, top_score={:.4}",
945            docs_scored,
946            docs_skipped,
947            results.len(),
948            results.first().map(|r| r.score).unwrap_or(0.0)
949        );
950
951        results
952    }
953}
954
#[cfg(test)]
mod tests {
    use super::*;

    /// Filling a k=3 collector and then overflowing it drops the weakest entry.
    /// The threshold tracks the smallest retained score.
    #[test]
    fn test_score_collector_basic() {
        let mut collector = ScoreCollector::new(3);
        for (doc, score) in [(1, 1.0), (2, 2.0), (3, 3.0)] {
            collector.insert(doc, score);
        }
        assert_eq!(collector.threshold(), 1.0);

        collector.insert(4, 4.0);
        assert_eq!(collector.threshold(), 2.0);

        // Sorted output is highest score first.
        let ids: Vec<_> = collector
            .into_sorted_results()
            .iter()
            .map(|entry| entry.0)
            .collect();
        assert_eq!(ids, vec![4, 3, 2]);
    }

    /// With k=2 the threshold is the lower of the two retained scores.
    /// A candidate below it is rejected; one above it raises the threshold.
    #[test]
    fn test_score_collector_threshold() {
        let mut top2 = ScoreCollector::new(2);
        top2.insert(1, 5.0);
        top2.insert(2, 3.0);
        assert_eq!(top2.threshold(), 3.0);

        // 2.0 is below the current threshold of 3.0.
        assert!(!top2.would_enter(2.0));
        assert!(!top2.insert(3, 2.0));

        // 4.0 clears the threshold; the retained minimum becomes 4.0.
        assert!(top2.would_enter(4.0));
        assert!(top2.insert(4, 4.0));
        assert_eq!(top2.threshold(), 4.0);
    }

    /// `HeapEntry` is ordered so that `BinaryHeap` pops the lowest score first.
    #[test]
    fn test_heap_entry_ordering() {
        let mut heap = BinaryHeap::new();
        for (doc_id, score) in [(1, 3.0), (2, 1.0), (3, 2.0)] {
            heap.push(HeapEntry {
                doc_id,
                score,
                ordinal: 0,
            });
        }

        let popped: Vec<_> = std::iter::from_fn(|| heap.pop().map(|entry| entry.score)).collect();
        assert_eq!(popped, vec![1.0, 2.0, 3.0]);
    }
}