Skip to main content

hermes_core/query/
scoring.rs

//! Shared scoring abstractions for text and sparse vector search
//!
//! Provides common traits and utilities for efficient top-k retrieval:
//! - `ScoringIterator`: Common interface for posting list iteration with scoring
//! - `ScoreCollector`: Efficient min-heap for maintaining top-k results
//! - `WandExecutor`: Generic WAND (Weak AND) top-k executor with pivot-based pruning
//! - `TextTermScorer`: ScoringIterator implementation for full-text (BM25) terms
//! - `SparseTermScorer`: ScoringIterator implementation for sparse vectors
8
9use std::cmp::Ordering;
10use std::collections::BinaryHeap;
11use std::sync::Arc;
12
13use log::{debug, trace};
14
15use crate::DocId;
16use crate::structures::BlockSparsePostingList;
17
18/// Common interface for scoring iterators (text terms or sparse dimensions)
19///
20/// Abstracts the common operations needed for WAND-style top-k retrieval.
21pub trait ScoringIterator {
22    /// Current document ID (u32::MAX if exhausted)
23    fn doc(&self) -> DocId;
24
25    /// Advance to next document, returns new doc ID
26    fn advance(&mut self) -> DocId;
27
28    /// Seek to first document >= target, returns new doc ID
29    fn seek(&mut self, target: DocId) -> DocId;
30
31    /// Check if iterator is exhausted
32    fn is_exhausted(&self) -> bool {
33        self.doc() == u32::MAX
34    }
35
36    /// Score contribution for current document
37    fn score(&self) -> f32;
38
39    /// Maximum possible score for this term/dimension (global upper bound)
40    fn max_score(&self) -> f32;
41
42    /// Current block's maximum score upper bound (for block-level pruning)
43    fn current_block_max_score(&self) -> f32;
44
45    /// Skip to the next block, returning the first doc_id in the new block.
46    /// Used for block-max WAND optimization when current block can't beat threshold.
47    /// Default implementation just advances (no block-level skipping).
48    fn skip_to_next_block(&mut self) -> DocId {
49        self.advance()
50    }
51}
52
53/// Entry for top-k min-heap
54#[derive(Clone, Copy)]
55pub struct HeapEntry {
56    pub doc_id: DocId,
57    pub score: f32,
58}
59
60impl PartialEq for HeapEntry {
61    fn eq(&self, other: &Self) -> bool {
62        self.score == other.score && self.doc_id == other.doc_id
63    }
64}
65
66impl Eq for HeapEntry {}
67
68impl Ord for HeapEntry {
69    fn cmp(&self, other: &Self) -> Ordering {
70        // Min-heap: lower scores come first (to be evicted)
71        other
72            .score
73            .partial_cmp(&self.score)
74            .unwrap_or(Ordering::Equal)
75            .then_with(|| self.doc_id.cmp(&other.doc_id))
76    }
77}
78
79impl PartialOrd for HeapEntry {
80    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
81        Some(self.cmp(other))
82    }
83}
84
85/// Efficient top-k collector using min-heap
86///
87/// Maintains the k highest-scoring documents using a min-heap where the
88/// lowest score is at the top for O(1) threshold lookup and O(log k) eviction.
89/// No deduplication - caller must ensure each doc_id is inserted only once.
90pub struct ScoreCollector {
91    /// Min-heap of top-k entries (lowest score at top for eviction)
92    heap: BinaryHeap<HeapEntry>,
93    pub k: usize,
94}
95
96impl ScoreCollector {
97    /// Create a new collector for top-k results
98    pub fn new(k: usize) -> Self {
99        // Cap capacity to avoid allocation overflow for very large k
100        let capacity = k.saturating_add(1).min(1_000_000);
101        Self {
102            heap: BinaryHeap::with_capacity(capacity),
103            k,
104        }
105    }
106
107    /// Current score threshold (minimum score to enter top-k)
108    #[inline]
109    pub fn threshold(&self) -> f32 {
110        if self.heap.len() >= self.k {
111            self.heap.peek().map(|e| e.score).unwrap_or(0.0)
112        } else {
113            0.0
114        }
115    }
116
117    /// Insert a document score. Returns true if inserted in top-k.
118    /// Caller must ensure each doc_id is inserted only once.
119    #[inline]
120    pub fn insert(&mut self, doc_id: DocId, score: f32) -> bool {
121        if self.heap.len() < self.k {
122            self.heap.push(HeapEntry { doc_id, score });
123            true
124        } else if score > self.threshold() {
125            self.heap.push(HeapEntry { doc_id, score });
126            self.heap.pop(); // Remove lowest
127            true
128        } else {
129            false
130        }
131    }
132
133    /// Check if a score could potentially enter top-k
134    #[inline]
135    pub fn would_enter(&self, score: f32) -> bool {
136        self.heap.len() < self.k || score > self.threshold()
137    }
138
139    /// Get number of documents collected so far
140    #[inline]
141    pub fn len(&self) -> usize {
142        self.heap.len()
143    }
144
145    /// Check if collector is empty
146    #[inline]
147    pub fn is_empty(&self) -> bool {
148        self.heap.is_empty()
149    }
150
151    /// Convert to sorted top-k results (descending by score)
152    pub fn into_sorted_results(self) -> Vec<(DocId, f32)> {
153        let mut results: Vec<_> = self
154            .heap
155            .into_vec()
156            .into_iter()
157            .map(|e| (e.doc_id, e.score))
158            .collect();
159
160        // Sort by score descending, then doc_id ascending
161        results.sort_by(|a, b| {
162            b.1.partial_cmp(&a.1)
163                .unwrap_or(Ordering::Equal)
164                .then_with(|| a.0.cmp(&b.0))
165        });
166
167        results
168    }
169}
170
171/// Search result from WAND execution
172#[derive(Debug, Clone, Copy)]
173pub struct ScoredDoc {
174    pub doc_id: DocId,
175    pub score: f32,
176}
177
178/// Generic MaxScore WAND executor for top-k retrieval
179///
180/// Works with any type implementing `ScoringIterator`.
181/// Implements:
182/// - WAND pivot-based pruning: skip documents that can't beat threshold
183/// - Block-max WAND: skip blocks that can't beat threshold
184/// - Efficient top-k collection
185pub struct WandExecutor<S: ScoringIterator> {
186    /// Scorers for each query term
187    scorers: Vec<S>,
188    /// Top-k collector
189    collector: ScoreCollector,
190    /// Heap factor for approximate search (SEISMIC-style)
191    /// A block/document is skipped if max_possible < heap_factor * threshold
192    /// - 1.0 = exact search (default)
193    /// - 0.8 = approximate, faster with minor recall loss
194    heap_factor: f32,
195}
196
197impl<S: ScoringIterator> WandExecutor<S> {
198    /// Create a new WAND executor with exact search (heap_factor = 1.0)
199    pub fn new(scorers: Vec<S>, k: usize) -> Self {
200        Self::with_heap_factor(scorers, k, 1.0)
201    }
202
203    /// Create a new WAND executor with approximate search
204    ///
205    /// `heap_factor` controls the trade-off between speed and recall:
206    /// - 1.0 = exact search
207    /// - 0.8 = ~20% faster, minor recall loss
208    /// - 0.5 = much faster, noticeable recall loss
209    pub fn with_heap_factor(scorers: Vec<S>, k: usize, heap_factor: f32) -> Self {
210        let total_upper: f32 = scorers.iter().map(|s| s.max_score()).sum();
211
212        debug!(
213            "Creating WandExecutor: num_scorers={}, k={}, total_upper={:.4}, heap_factor={:.2}",
214            scorers.len(),
215            k,
216            total_upper,
217            heap_factor
218        );
219
220        Self {
221            scorers,
222            collector: ScoreCollector::new(k),
223            heap_factor: heap_factor.clamp(0.0, 1.0),
224        }
225    }
226
227    /// Execute WAND and return top-k results
228    ///
229    /// Implements the WAND (Weak AND) algorithm with pivot-based pruning:
230    /// 1. Maintain iterators sorted by current docID (using sorted vector)
231    /// 2. Find pivot: first term where cumulative upper bounds > threshold
232    /// 3. If all iterators at pivot docID, fully score; otherwise skip to pivot
233    /// 4. Insert into collector and advance
234    ///
235    /// Reference: Broder et al., "Efficient Query Evaluation using a Two-Level
236    /// Retrieval Process" (CIKM 2003)
237    ///
238    /// Note: For small number of terms (typical queries), a sorted vector with
239    /// insertion sort is faster than a heap due to better cache locality.
240    /// The vector stays mostly sorted, so insertion sort is ~O(n) amortized.
241    pub fn execute(mut self) -> Vec<ScoredDoc> {
242        if self.scorers.is_empty() {
243            debug!("WandExecutor: no scorers, returning empty results");
244            return Vec::new();
245        }
246
247        let mut docs_scored = 0u64;
248        let mut docs_skipped = 0u64;
249        let num_scorers = self.scorers.len();
250
251        // Indices sorted by current docID - initial sort O(n log n)
252        let mut sorted_indices: Vec<usize> = (0..num_scorers).collect();
253        sorted_indices.sort_by_key(|&i| self.scorers[i].doc());
254
255        loop {
256            // Find first non-exhausted iterator (they're sorted, so check first)
257            let first_active = sorted_indices
258                .iter()
259                .position(|&i| self.scorers[i].doc() != u32::MAX);
260
261            let first_active = match first_active {
262                Some(pos) => pos,
263                None => break, // All exhausted
264            };
265
266            // Early termination: if total upper bound can't beat (adjusted) threshold
267            // heap_factor < 1.0 makes pruning more aggressive (approximate search)
268            let total_upper: f32 = sorted_indices[first_active..]
269                .iter()
270                .map(|&i| self.scorers[i].max_score())
271                .sum();
272
273            let adjusted_threshold = self.collector.threshold() * self.heap_factor;
274            if self.collector.len() >= self.collector.k && total_upper <= adjusted_threshold {
275                debug!(
276                    "Early termination: upper_bound={:.4} <= adjusted_threshold={:.4}",
277                    total_upper, adjusted_threshold
278                );
279                break;
280            }
281
282            // Find pivot: first term where cumulative upper bounds > adjusted threshold
283            let mut cumsum = 0.0f32;
284            let mut pivot_pos = first_active;
285
286            for (pos, &idx) in sorted_indices.iter().enumerate().skip(first_active) {
287                cumsum += self.scorers[idx].max_score();
288                if cumsum > adjusted_threshold || self.collector.len() < self.collector.k {
289                    pivot_pos = pos;
290                    break;
291                }
292            }
293
294            let pivot_idx = sorted_indices[pivot_pos];
295            let pivot_doc = self.scorers[pivot_idx].doc();
296
297            if pivot_doc == u32::MAX {
298                break;
299            }
300
301            // Check if all iterators before pivot are at pivot_doc
302            let all_at_pivot = sorted_indices[first_active..=pivot_pos]
303                .iter()
304                .all(|&i| self.scorers[i].doc() == pivot_doc);
305
306            if all_at_pivot {
307                // All terms up to pivot are at the same doc - fully score it
308                let mut score = 0.0f32;
309                let mut matching_terms = 0u32;
310
311                // Score from all iterators that have this document and advance them
312                // Collect indices that need re-sorting
313                let mut modified_positions: Vec<usize> = Vec::new();
314
315                for (pos, &idx) in sorted_indices.iter().enumerate().skip(first_active) {
316                    let doc = self.scorers[idx].doc();
317                    if doc == pivot_doc {
318                        score += self.scorers[idx].score();
319                        matching_terms += 1;
320                        self.scorers[idx].advance();
321                        modified_positions.push(pos);
322                    } else if doc > pivot_doc {
323                        break;
324                    }
325                }
326
327                trace!(
328                    "Doc {}: score={:.4}, matching={}/{}, threshold={:.4}",
329                    pivot_doc, score, matching_terms, num_scorers, adjusted_threshold
330                );
331
332                if self.collector.insert(pivot_doc, score) {
333                    docs_scored += 1;
334                } else {
335                    docs_skipped += 1;
336                }
337
338                // Re-sort modified iterators using insertion sort (efficient for nearly-sorted)
339                // Move each modified iterator to its correct position
340                for &pos in modified_positions.iter().rev() {
341                    let idx = sorted_indices[pos];
342                    let new_doc = self.scorers[idx].doc();
343                    // Bubble up to correct position
344                    let mut curr = pos;
345                    while curr + 1 < sorted_indices.len()
346                        && self.scorers[sorted_indices[curr + 1]].doc() < new_doc
347                    {
348                        sorted_indices.swap(curr, curr + 1);
349                        curr += 1;
350                    }
351                }
352            } else {
353                // Not all at pivot - skip the first iterator to pivot_doc
354                let first_pos = first_active;
355                let first_idx = sorted_indices[first_pos];
356                self.scorers[first_idx].seek(pivot_doc);
357                docs_skipped += 1;
358
359                // Re-sort the modified iterator
360                let new_doc = self.scorers[first_idx].doc();
361                let mut curr = first_pos;
362                while curr + 1 < sorted_indices.len()
363                    && self.scorers[sorted_indices[curr + 1]].doc() < new_doc
364                {
365                    sorted_indices.swap(curr, curr + 1);
366                    curr += 1;
367                }
368            }
369        }
370
371        let results: Vec<ScoredDoc> = self
372            .collector
373            .into_sorted_results()
374            .into_iter()
375            .map(|(doc_id, score)| ScoredDoc { doc_id, score })
376            .collect();
377
378        debug!(
379            "WandExecutor completed: scored={}, skipped={}, returned={}, top_score={:.4}",
380            docs_scored,
381            docs_skipped,
382            results.len(),
383            results.first().map(|r| r.score).unwrap_or(0.0)
384        );
385
386        results
387    }
388}
389
390/// Scorer for full-text terms using WAND optimization
391///
392/// Wraps a `BlockPostingList` with BM25 parameters to implement `ScoringIterator`.
393/// Enables MaxScore pruning for efficient top-k retrieval in OR queries.
394pub struct TextTermScorer {
395    /// Iterator over the posting list (owned)
396    iter: crate::structures::BlockPostingIterator<'static>,
397    /// IDF component for BM25
398    idf: f32,
399    /// Average field length for BM25 normalization
400    avg_field_len: f32,
401    /// Pre-computed max score (using max_tf from posting list)
402    max_score: f32,
403}
404
405impl TextTermScorer {
406    /// Create a new text term scorer with BM25 parameters
407    pub fn new(
408        posting_list: crate::structures::BlockPostingList,
409        idf: f32,
410        avg_field_len: f32,
411    ) -> Self {
412        // Compute max score using actual max_tf from posting list
413        let max_tf = posting_list.max_tf() as f32;
414        let doc_count = posting_list.doc_count();
415        let max_score = super::bm25_upper_bound(max_tf.max(1.0), idf);
416
417        debug!(
418            "Created TextTermScorer: doc_count={}, max_tf={:.0}, idf={:.4}, avg_field_len={:.2}, max_score={:.4}",
419            doc_count, max_tf, idf, avg_field_len, max_score
420        );
421
422        Self {
423            iter: posting_list.into_iterator(),
424            idf,
425            avg_field_len,
426            max_score,
427        }
428    }
429}
430
431impl ScoringIterator for TextTermScorer {
432    #[inline]
433    fn doc(&self) -> DocId {
434        self.iter.doc()
435    }
436
437    #[inline]
438    fn advance(&mut self) -> DocId {
439        self.iter.advance()
440    }
441
442    #[inline]
443    fn seek(&mut self, target: DocId) -> DocId {
444        self.iter.seek(target)
445    }
446
447    #[inline]
448    fn score(&self) -> f32 {
449        let tf = self.iter.term_freq() as f32;
450        // Use tf as proxy for doc length (common approximation when field lengths aren't stored)
451        super::bm25_score(tf, self.idf, tf, self.avg_field_len)
452    }
453
454    #[inline]
455    fn max_score(&self) -> f32 {
456        self.max_score
457    }
458
459    #[inline]
460    fn current_block_max_score(&self) -> f32 {
461        // Use per-block max_tf for tighter Block-Max WAND bounds
462        let block_max_tf = self.iter.current_block_max_tf() as f32;
463        super::bm25_upper_bound(block_max_tf.max(1.0), self.idf)
464    }
465
466    #[inline]
467    fn skip_to_next_block(&mut self) -> DocId {
468        self.iter.skip_to_next_block()
469    }
470}
471
472/// Scorer for sparse vector dimensions
473///
474/// Wraps a `BlockSparsePostingList` with a query weight to implement `ScoringIterator`.
475pub struct SparseTermScorer<'a> {
476    /// Iterator over the posting list
477    iter: crate::structures::BlockSparsePostingIterator<'a>,
478    /// Query weight for this dimension
479    query_weight: f32,
480    /// Global max score (query_weight * global_max_weight)
481    max_score: f32,
482}
483
484impl<'a> SparseTermScorer<'a> {
485    /// Create a new sparse term scorer
486    pub fn new(posting_list: &'a BlockSparsePostingList, query_weight: f32) -> Self {
487        let max_score = query_weight * posting_list.global_max_weight();
488        Self {
489            iter: posting_list.iterator(),
490            query_weight,
491            max_score,
492        }
493    }
494
495    /// Create from Arc reference (for use with shared posting lists)
496    pub fn from_arc(posting_list: &'a Arc<BlockSparsePostingList>, query_weight: f32) -> Self {
497        Self::new(posting_list.as_ref(), query_weight)
498    }
499}
500
501impl ScoringIterator for SparseTermScorer<'_> {
502    #[inline]
503    fn doc(&self) -> DocId {
504        self.iter.doc()
505    }
506
507    #[inline]
508    fn advance(&mut self) -> DocId {
509        self.iter.advance()
510    }
511
512    #[inline]
513    fn seek(&mut self, target: DocId) -> DocId {
514        self.iter.seek(target)
515    }
516
517    #[inline]
518    fn score(&self) -> f32 {
519        // Dot product contribution: query_weight * stored_weight
520        self.query_weight * self.iter.weight()
521    }
522
523    #[inline]
524    fn max_score(&self) -> f32 {
525        self.max_score
526    }
527
528    #[inline]
529    fn current_block_max_score(&self) -> f32 {
530        self.iter.current_block_max_contribution(self.query_weight)
531    }
532}
533
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_score_collector_basic() {
        let mut collector = ScoreCollector::new(3);
        for &(doc, score) in &[(1, 1.0), (2, 2.0), (3, 3.0)] {
            collector.insert(doc, score);
        }
        assert_eq!(collector.threshold(), 1.0);

        // A better score evicts the current minimum (doc 1).
        collector.insert(4, 4.0);
        assert_eq!(collector.threshold(), 2.0);

        let ranked: Vec<_> = collector
            .into_sorted_results()
            .into_iter()
            .map(|(doc, _)| doc)
            .collect();
        // Highest score first.
        assert_eq!(ranked, vec![4, 3, 2]);
    }

    #[test]
    fn test_score_collector_threshold() {
        let mut collector = ScoreCollector::new(2);
        collector.insert(1, 5.0);
        collector.insert(2, 3.0);
        assert_eq!(collector.threshold(), 3.0);

        // Below the current minimum: rejected.
        let low = 2.0;
        assert!(!collector.would_enter(low));
        assert!(!collector.insert(3, low));

        // Above the minimum: accepted, and it becomes the new threshold.
        let high = 4.0;
        assert!(collector.would_enter(high));
        assert!(collector.insert(4, high));
        assert_eq!(collector.threshold(), high);
    }

    #[test]
    fn test_heap_entry_ordering() {
        let mut heap: BinaryHeap<HeapEntry> = [(1, 3.0), (2, 1.0), (3, 2.0)]
            .iter()
            .map(|&(doc_id, score)| HeapEntry { doc_id, score })
            .collect();

        // Min-heap: scores come out lowest first.
        let popped: Vec<f32> = std::iter::from_fn(|| heap.pop())
            .map(|entry| entry.score)
            .collect();
        assert_eq!(popped, vec![1.0, 2.0, 3.0]);
    }
}