Skip to main content

hermes_core/query/
scoring.rs

//! Shared scoring abstractions for text and sparse vector search
//!
//! Provides common traits and utilities for efficient top-k retrieval:
//! - `ScoringIterator`: Common interface for posting list iteration with scoring
//! - `ScoreCollector`: Efficient min-heap for maintaining top-k results
//! - `WandExecutor`: Generic MaxScore WAND algorithm
//! - `TextTermScorer`: ScoringIterator implementation for BM25 text terms
//! - `SparseTermScorer`: ScoringIterator implementation for sparse vectors
8
9use std::cmp::Ordering;
10use std::collections::BinaryHeap;
11use std::sync::Arc;
12
13use log::{debug, trace};
14
15use crate::DocId;
16use crate::structures::BlockSparsePostingList;
17
18/// Common interface for scoring iterators (text terms or sparse dimensions)
19///
20/// Abstracts the common operations needed for WAND-style top-k retrieval.
21pub trait ScoringIterator {
22    /// Current document ID (u32::MAX if exhausted)
23    fn doc(&self) -> DocId;
24
25    /// Current ordinal for multi-valued fields (default 0)
26    fn ordinal(&self) -> u16 {
27        0
28    }
29
30    /// Advance to next document, returns new doc ID
31    fn advance(&mut self) -> DocId;
32
33    /// Seek to first document >= target, returns new doc ID
34    fn seek(&mut self, target: DocId) -> DocId;
35
36    /// Check if iterator is exhausted
37    fn is_exhausted(&self) -> bool {
38        self.doc() == u32::MAX
39    }
40
41    /// Score contribution for current document
42    fn score(&self) -> f32;
43
44    /// Maximum possible score for this term/dimension (global upper bound)
45    fn max_score(&self) -> f32;
46
47    /// Current block's maximum score upper bound (for block-level pruning)
48    fn current_block_max_score(&self) -> f32;
49
50    /// Skip to the next block, returning the first doc_id in the new block.
51    /// Used for block-max WAND optimization when current block can't beat threshold.
52    /// Default implementation just advances (no block-level skipping).
53    fn skip_to_next_block(&mut self) -> DocId {
54        self.advance()
55    }
56}
57
58/// Entry for top-k min-heap
59#[derive(Clone, Copy)]
60pub struct HeapEntry {
61    pub doc_id: DocId,
62    pub score: f32,
63    pub ordinal: u16,
64}
65
66impl PartialEq for HeapEntry {
67    fn eq(&self, other: &Self) -> bool {
68        self.score == other.score && self.doc_id == other.doc_id
69    }
70}
71
72impl Eq for HeapEntry {}
73
74impl Ord for HeapEntry {
75    fn cmp(&self, other: &Self) -> Ordering {
76        // Min-heap: lower scores come first (to be evicted)
77        other
78            .score
79            .partial_cmp(&self.score)
80            .unwrap_or(Ordering::Equal)
81            .then_with(|| self.doc_id.cmp(&other.doc_id))
82    }
83}
84
85impl PartialOrd for HeapEntry {
86    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
87        Some(self.cmp(other))
88    }
89}
90
91/// Efficient top-k collector using min-heap
92///
93/// Maintains the k highest-scoring documents using a min-heap where the
94/// lowest score is at the top for O(1) threshold lookup and O(log k) eviction.
95/// No deduplication - caller must ensure each doc_id is inserted only once.
96pub struct ScoreCollector {
97    /// Min-heap of top-k entries (lowest score at top for eviction)
98    heap: BinaryHeap<HeapEntry>,
99    pub k: usize,
100}
101
102impl ScoreCollector {
103    /// Create a new collector for top-k results
104    pub fn new(k: usize) -> Self {
105        // Cap capacity to avoid allocation overflow for very large k
106        let capacity = k.saturating_add(1).min(1_000_000);
107        Self {
108            heap: BinaryHeap::with_capacity(capacity),
109            k,
110        }
111    }
112
113    /// Current score threshold (minimum score to enter top-k)
114    #[inline]
115    pub fn threshold(&self) -> f32 {
116        if self.heap.len() >= self.k {
117            self.heap.peek().map(|e| e.score).unwrap_or(0.0)
118        } else {
119            0.0
120        }
121    }
122
123    /// Insert a document score. Returns true if inserted in top-k.
124    /// Caller must ensure each doc_id is inserted only once.
125    #[inline]
126    pub fn insert(&mut self, doc_id: DocId, score: f32) -> bool {
127        self.insert_with_ordinal(doc_id, score, 0)
128    }
129
130    /// Insert a document score with ordinal. Returns true if inserted in top-k.
131    /// Caller must ensure each doc_id is inserted only once.
132    #[inline]
133    pub fn insert_with_ordinal(&mut self, doc_id: DocId, score: f32, ordinal: u16) -> bool {
134        if self.heap.len() < self.k {
135            self.heap.push(HeapEntry {
136                doc_id,
137                score,
138                ordinal,
139            });
140            true
141        } else if score > self.threshold() {
142            self.heap.push(HeapEntry {
143                doc_id,
144                score,
145                ordinal,
146            });
147            self.heap.pop(); // Remove lowest
148            true
149        } else {
150            false
151        }
152    }
153
154    /// Check if a score could potentially enter top-k
155    #[inline]
156    pub fn would_enter(&self, score: f32) -> bool {
157        self.heap.len() < self.k || score > self.threshold()
158    }
159
160    /// Get number of documents collected so far
161    #[inline]
162    pub fn len(&self) -> usize {
163        self.heap.len()
164    }
165
166    /// Check if collector is empty
167    #[inline]
168    pub fn is_empty(&self) -> bool {
169        self.heap.is_empty()
170    }
171
172    /// Convert to sorted top-k results (descending by score)
173    pub fn into_sorted_results(self) -> Vec<(DocId, f32, u16)> {
174        let mut results: Vec<_> = self
175            .heap
176            .into_vec()
177            .into_iter()
178            .map(|e| (e.doc_id, e.score, e.ordinal))
179            .collect();
180
181        // Sort by score descending, then doc_id ascending
182        results.sort_by(|a, b| {
183            b.1.partial_cmp(&a.1)
184                .unwrap_or(Ordering::Equal)
185                .then_with(|| a.0.cmp(&b.0))
186        });
187
188        results
189    }
190}
191
192/// Search result from WAND execution
193#[derive(Debug, Clone, Copy)]
194pub struct ScoredDoc {
195    pub doc_id: DocId,
196    pub score: f32,
197    /// Ordinal for multi-valued fields (which vector in the field matched)
198    pub ordinal: u16,
199}
200
201/// Generic MaxScore WAND executor for top-k retrieval
202///
203/// Works with any type implementing `ScoringIterator`.
204/// Implements:
205/// - WAND pivot-based pruning: skip documents that can't beat threshold
206/// - Block-max WAND: skip blocks that can't beat threshold
207/// - Efficient top-k collection
208pub struct WandExecutor<S: ScoringIterator> {
209    /// Scorers for each query term
210    scorers: Vec<S>,
211    /// Top-k collector
212    collector: ScoreCollector,
213    /// Heap factor for approximate search (SEISMIC-style)
214    /// A block/document is skipped if max_possible < heap_factor * threshold
215    /// - 1.0 = exact search (default)
216    /// - 0.8 = approximate, faster with minor recall loss
217    heap_factor: f32,
218}
219
220impl<S: ScoringIterator> WandExecutor<S> {
221    /// Create a new WAND executor with exact search (heap_factor = 1.0)
222    pub fn new(scorers: Vec<S>, k: usize) -> Self {
223        Self::with_heap_factor(scorers, k, 1.0)
224    }
225
226    /// Create a new WAND executor with approximate search
227    ///
228    /// `heap_factor` controls the trade-off between speed and recall:
229    /// - 1.0 = exact search
230    /// - 0.8 = ~20% faster, minor recall loss
231    /// - 0.5 = much faster, noticeable recall loss
232    pub fn with_heap_factor(scorers: Vec<S>, k: usize, heap_factor: f32) -> Self {
233        let total_upper: f32 = scorers.iter().map(|s| s.max_score()).sum();
234
235        debug!(
236            "Creating WandExecutor: num_scorers={}, k={}, total_upper={:.4}, heap_factor={:.2}",
237            scorers.len(),
238            k,
239            total_upper,
240            heap_factor
241        );
242
243        Self {
244            scorers,
245            collector: ScoreCollector::new(k),
246            heap_factor: heap_factor.clamp(0.0, 1.0),
247        }
248    }
249
250    /// Execute WAND and return top-k results
251    ///
252    /// Implements the WAND (Weak AND) algorithm with pivot-based pruning:
253    /// 1. Maintain iterators sorted by current docID (using sorted vector)
254    /// 2. Find pivot: first term where cumulative upper bounds > threshold
255    /// 3. If all iterators at pivot docID, fully score; otherwise skip to pivot
256    /// 4. Insert into collector and advance
257    ///
258    /// Reference: Broder et al., "Efficient Query Evaluation using a Two-Level
259    /// Retrieval Process" (CIKM 2003)
260    ///
261    /// Note: For small number of terms (typical queries), a sorted vector with
262    /// insertion sort is faster than a heap due to better cache locality.
263    /// The vector stays mostly sorted, so insertion sort is ~O(n) amortized.
264    pub fn execute(mut self) -> Vec<ScoredDoc> {
265        if self.scorers.is_empty() {
266            debug!("WandExecutor: no scorers, returning empty results");
267            return Vec::new();
268        }
269
270        let mut docs_scored = 0u64;
271        let mut docs_skipped = 0u64;
272        let mut blocks_skipped = 0u64;
273        let num_scorers = self.scorers.len();
274
275        // Indices sorted by current docID - initial sort O(n log n)
276        let mut sorted_indices: Vec<usize> = (0..num_scorers).collect();
277        sorted_indices.sort_by_key(|&i| self.scorers[i].doc());
278
279        loop {
280            // Find first non-exhausted iterator (they're sorted, so check first)
281            let first_active = sorted_indices
282                .iter()
283                .position(|&i| self.scorers[i].doc() != u32::MAX);
284
285            let first_active = match first_active {
286                Some(pos) => pos,
287                None => break, // All exhausted
288            };
289
290            // Early termination: if total upper bound can't beat (adjusted) threshold
291            // heap_factor < 1.0 makes pruning more aggressive (approximate search)
292            let total_upper: f32 = sorted_indices[first_active..]
293                .iter()
294                .map(|&i| self.scorers[i].max_score())
295                .sum();
296
297            let adjusted_threshold = self.collector.threshold() * self.heap_factor;
298            if self.collector.len() >= self.collector.k && total_upper <= adjusted_threshold {
299                debug!(
300                    "Early termination: upper_bound={:.4} <= adjusted_threshold={:.4}",
301                    total_upper, adjusted_threshold
302                );
303                break;
304            }
305
306            // Find pivot: first term where cumulative upper bounds > adjusted threshold
307            let mut cumsum = 0.0f32;
308            let mut pivot_pos = first_active;
309
310            for (pos, &idx) in sorted_indices.iter().enumerate().skip(first_active) {
311                cumsum += self.scorers[idx].max_score();
312                if cumsum > adjusted_threshold || self.collector.len() < self.collector.k {
313                    pivot_pos = pos;
314                    break;
315                }
316            }
317
318            let pivot_idx = sorted_indices[pivot_pos];
319            let pivot_doc = self.scorers[pivot_idx].doc();
320
321            if pivot_doc == u32::MAX {
322                break;
323            }
324
325            // Check if all iterators before pivot are at pivot_doc
326            let all_at_pivot = sorted_indices[first_active..=pivot_pos]
327                .iter()
328                .all(|&i| self.scorers[i].doc() == pivot_doc);
329
330            if all_at_pivot {
331                // Block-max WAND optimization: check if current blocks can beat threshold
332                // Sum the block-max scores for all iterators at pivot_doc
333                let block_max_sum: f32 = sorted_indices[first_active..=pivot_pos]
334                    .iter()
335                    .filter(|&&i| self.scorers[i].doc() == pivot_doc)
336                    .map(|&i| self.scorers[i].current_block_max_score())
337                    .sum();
338
339                if self.collector.len() >= self.collector.k && block_max_sum <= adjusted_threshold {
340                    // Block can't beat threshold - skip all iterators at pivot to next block
341                    debug!(
342                        "Block skip at doc {}: block_max={:.4} <= threshold={:.4}",
343                        pivot_doc, block_max_sum, adjusted_threshold
344                    );
345
346                    for (_pos, &idx) in sorted_indices.iter().enumerate().skip(first_active) {
347                        if self.scorers[idx].doc() == pivot_doc {
348                            self.scorers[idx].skip_to_next_block();
349                        } else if self.scorers[idx].doc() > pivot_doc {
350                            break;
351                        }
352                    }
353
354                    // Re-sort after block skips
355                    sorted_indices[first_active..].sort_by_key(|&i| self.scorers[i].doc());
356                    blocks_skipped += 1;
357                    continue;
358                }
359
360                // All terms up to pivot are at the same doc - fully score it
361                let mut score = 0.0f32;
362                let mut matching_terms = 0u32;
363                let mut ordinal: u16 = 0;
364
365                // Score from all iterators that have this document and advance them
366                // Collect indices that need re-sorting
367                let mut modified_positions: Vec<usize> = Vec::new();
368
369                for (pos, &idx) in sorted_indices.iter().enumerate().skip(first_active) {
370                    let doc = self.scorers[idx].doc();
371                    if doc == pivot_doc {
372                        score += self.scorers[idx].score();
373                        // Take ordinal from first matching scorer (all should have same ordinal)
374                        if matching_terms == 0 {
375                            ordinal = self.scorers[idx].ordinal();
376                        }
377                        matching_terms += 1;
378                        self.scorers[idx].advance();
379                        modified_positions.push(pos);
380                    } else if doc > pivot_doc {
381                        break;
382                    }
383                }
384
385                trace!(
386                    "Doc {}: score={:.4}, matching={}/{}, threshold={:.4}",
387                    pivot_doc, score, matching_terms, num_scorers, adjusted_threshold
388                );
389
390                if self
391                    .collector
392                    .insert_with_ordinal(pivot_doc, score, ordinal)
393                {
394                    docs_scored += 1;
395                } else {
396                    docs_skipped += 1;
397                }
398
399                // Re-sort modified iterators using insertion sort (efficient for nearly-sorted)
400                // Move each modified iterator to its correct position
401                for &pos in modified_positions.iter().rev() {
402                    let idx = sorted_indices[pos];
403                    let new_doc = self.scorers[idx].doc();
404                    // Bubble up to correct position
405                    let mut curr = pos;
406                    while curr + 1 < sorted_indices.len()
407                        && self.scorers[sorted_indices[curr + 1]].doc() < new_doc
408                    {
409                        sorted_indices.swap(curr, curr + 1);
410                        curr += 1;
411                    }
412                }
413            } else {
414                // Not all at pivot - skip the first iterator to pivot_doc
415                let first_pos = first_active;
416                let first_idx = sorted_indices[first_pos];
417                self.scorers[first_idx].seek(pivot_doc);
418                docs_skipped += 1;
419
420                // Re-sort the modified iterator
421                let new_doc = self.scorers[first_idx].doc();
422                let mut curr = first_pos;
423                while curr + 1 < sorted_indices.len()
424                    && self.scorers[sorted_indices[curr + 1]].doc() < new_doc
425                {
426                    sorted_indices.swap(curr, curr + 1);
427                    curr += 1;
428                }
429            }
430        }
431
432        let results: Vec<ScoredDoc> = self
433            .collector
434            .into_sorted_results()
435            .into_iter()
436            .map(|(doc_id, score, ordinal)| ScoredDoc {
437                doc_id,
438                score,
439                ordinal,
440            })
441            .collect();
442
443        debug!(
444            "WandExecutor completed: scored={}, skipped={}, blocks_skipped={}, returned={}, top_score={:.4}",
445            docs_scored,
446            docs_skipped,
447            blocks_skipped,
448            results.len(),
449            results.first().map(|r| r.score).unwrap_or(0.0)
450        );
451
452        results
453    }
454}
455
456/// Scorer for full-text terms using WAND optimization
457///
458/// Wraps a `BlockPostingList` with BM25 parameters to implement `ScoringIterator`.
459/// Enables MaxScore pruning for efficient top-k retrieval in OR queries.
460pub struct TextTermScorer {
461    /// Iterator over the posting list (owned)
462    iter: crate::structures::BlockPostingIterator<'static>,
463    /// IDF component for BM25
464    idf: f32,
465    /// Average field length for BM25 normalization
466    avg_field_len: f32,
467    /// Pre-computed max score (using max_tf from posting list)
468    max_score: f32,
469}
470
471impl TextTermScorer {
472    /// Create a new text term scorer with BM25 parameters
473    pub fn new(
474        posting_list: crate::structures::BlockPostingList,
475        idf: f32,
476        avg_field_len: f32,
477    ) -> Self {
478        // Compute max score using actual max_tf from posting list
479        let max_tf = posting_list.max_tf() as f32;
480        let doc_count = posting_list.doc_count();
481        let max_score = super::bm25_upper_bound(max_tf.max(1.0), idf);
482
483        debug!(
484            "Created TextTermScorer: doc_count={}, max_tf={:.0}, idf={:.4}, avg_field_len={:.2}, max_score={:.4}",
485            doc_count, max_tf, idf, avg_field_len, max_score
486        );
487
488        Self {
489            iter: posting_list.into_iterator(),
490            idf,
491            avg_field_len,
492            max_score,
493        }
494    }
495}
496
497impl ScoringIterator for TextTermScorer {
498    #[inline]
499    fn doc(&self) -> DocId {
500        self.iter.doc()
501    }
502
503    #[inline]
504    fn advance(&mut self) -> DocId {
505        self.iter.advance()
506    }
507
508    #[inline]
509    fn seek(&mut self, target: DocId) -> DocId {
510        self.iter.seek(target)
511    }
512
513    #[inline]
514    fn score(&self) -> f32 {
515        let tf = self.iter.term_freq() as f32;
516        // Use tf as proxy for doc length (common approximation when field lengths aren't stored)
517        super::bm25_score(tf, self.idf, tf, self.avg_field_len)
518    }
519
520    #[inline]
521    fn max_score(&self) -> f32 {
522        self.max_score
523    }
524
525    #[inline]
526    fn current_block_max_score(&self) -> f32 {
527        // Use per-block max_tf for tighter Block-Max WAND bounds
528        let block_max_tf = self.iter.current_block_max_tf() as f32;
529        super::bm25_upper_bound(block_max_tf.max(1.0), self.idf)
530    }
531
532    #[inline]
533    fn skip_to_next_block(&mut self) -> DocId {
534        self.iter.skip_to_next_block()
535    }
536}
537
538/// Scorer for sparse vector dimensions
539///
540/// Wraps a `BlockSparsePostingList` with a query weight to implement `ScoringIterator`.
541pub struct SparseTermScorer<'a> {
542    /// Iterator over the posting list
543    iter: crate::structures::BlockSparsePostingIterator<'a>,
544    /// Query weight for this dimension
545    query_weight: f32,
546    /// Global max score (query_weight * global_max_weight)
547    max_score: f32,
548}
549
550impl<'a> SparseTermScorer<'a> {
551    /// Create a new sparse term scorer
552    ///
553    /// Note: Assumes positive weights for WAND upper bound calculation.
554    /// For negative query weights, uses absolute value to ensure valid upper bound.
555    pub fn new(posting_list: &'a BlockSparsePostingList, query_weight: f32) -> Self {
556        // Upper bound must account for sign: |query_weight| * max_weight
557        // This ensures the bound is valid regardless of weight sign
558        let max_score = query_weight.abs() * posting_list.global_max_weight();
559        Self {
560            iter: posting_list.iterator(),
561            query_weight,
562            max_score,
563        }
564    }
565
566    /// Create from Arc reference (for use with shared posting lists)
567    pub fn from_arc(posting_list: &'a Arc<BlockSparsePostingList>, query_weight: f32) -> Self {
568        Self::new(posting_list.as_ref(), query_weight)
569    }
570}
571
572impl ScoringIterator for SparseTermScorer<'_> {
573    #[inline]
574    fn doc(&self) -> DocId {
575        self.iter.doc()
576    }
577
578    #[inline]
579    fn ordinal(&self) -> u16 {
580        self.iter.ordinal()
581    }
582
583    #[inline]
584    fn advance(&mut self) -> DocId {
585        self.iter.advance()
586    }
587
588    #[inline]
589    fn seek(&mut self, target: DocId) -> DocId {
590        self.iter.seek(target)
591    }
592
593    #[inline]
594    fn score(&self) -> f32 {
595        // Dot product contribution: query_weight * stored_weight
596        self.query_weight * self.iter.weight()
597    }
598
599    #[inline]
600    fn max_score(&self) -> f32 {
601        self.max_score
602    }
603
604    #[inline]
605    fn current_block_max_score(&self) -> f32 {
606        // Use abs() for valid upper bound with negative weights
607        self.iter
608            .current_block_max_contribution(self.query_weight.abs())
609    }
610}
611
#[cfg(test)]
mod tests {
    use super::*;

    /// Filling past capacity evicts the weakest entry and raises the threshold.
    #[test]
    fn test_score_collector_basic() {
        let mut collector = ScoreCollector::new(3);

        for (doc_id, score) in [(1, 1.0), (2, 2.0), (3, 3.0)] {
            collector.insert(doc_id, score);
        }
        assert_eq!(collector.threshold(), 1.0);

        collector.insert(4, 4.0);
        assert_eq!(collector.threshold(), 2.0);

        let ranked = collector.into_sorted_results();
        assert_eq!(ranked.len(), 3);
        // Best score first
        let doc_ids: Vec<_> = ranked.iter().map(|hit| hit.0).collect();
        assert_eq!(doc_ids[0], 4);
        assert_eq!(doc_ids[1], 3);
        assert_eq!(doc_ids[2], 2);
    }

    /// Scores at or below the threshold are rejected; higher ones replace the minimum.
    #[test]
    fn test_score_collector_threshold() {
        let mut collector = ScoreCollector::new(2);

        collector.insert(1, 5.0);
        collector.insert(2, 3.0);
        assert_eq!(collector.threshold(), 3.0);

        // Too low: must be turned away
        let low = 2.0;
        assert!(!collector.would_enter(low));
        assert!(!collector.insert(3, low));

        // High enough: must be accepted
        let high = 4.0;
        assert!(collector.would_enter(high));
        assert!(collector.insert(4, high));
        assert_eq!(collector.threshold(), 4.0);
    }

    /// `HeapEntry` ordering turns `BinaryHeap` into a min-heap on score.
    #[test]
    fn test_heap_entry_ordering() {
        let mut heap = BinaryHeap::new();
        for (doc_id, score) in [(1, 3.0), (2, 1.0), (3, 2.0)] {
            heap.push(HeapEntry {
                doc_id,
                score,
                ordinal: 0,
            });
        }

        // Lowest score leaves the heap first
        for expected in [1.0, 2.0, 3.0] {
            assert_eq!(heap.pop().unwrap().score, expected);
        }
    }
}