// hermes_core/query/scoring.rs

//! Shared scoring abstractions for text and sparse vector search
//!
//! Provides common traits and utilities for efficient top-k retrieval:
//! - `ScoringIterator`: Common interface for posting list iteration with scoring
//! - `ScoreCollector`: Efficient min-heap for maintaining top-k results
//! - `WandExecutor`: Generic MaxScore WAND algorithm
//! - `TextTermScorer`: ScoringIterator implementation for text terms (BM25)
//! - `SparseTermScorer`: ScoringIterator implementation for sparse vectors

use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::sync::Arc;

use log::{debug, trace};

use crate::DocId;
use crate::structures::BlockSparsePostingList;

/// Common interface for scoring iterators (text terms or sparse dimensions)
///
/// Abstracts the common operations needed for WAND-style top-k retrieval.
pub trait ScoringIterator {
    /// Current document ID (u32::MAX if exhausted)
    fn doc(&self) -> DocId;

    /// Advance to next document, returns new doc ID
    fn advance(&mut self) -> DocId;

    /// Seek to first document >= target, returns new doc ID
    fn seek(&mut self, target: DocId) -> DocId;

    /// Check if iterator is exhausted
    fn is_exhausted(&self) -> bool {
        self.doc() == u32::MAX
    }

    /// Score contribution for current document
    fn score(&self) -> f32;

    /// Maximum possible score for this term/dimension (global upper bound)
    fn max_score(&self) -> f32;

    /// Current block's maximum score upper bound (for block-level pruning)
    fn current_block_max_score(&self) -> f32;

    /// Skip to the next block, returning the first doc_id in the new block.
    /// Used for block-max WAND optimization when current block can't beat threshold.
    /// Default implementation just advances (no block-level skipping).
    fn skip_to_next_block(&mut self) -> DocId {
        self.advance()
    }
}
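
// Illustrative sketch (test-only, not part of the library surface): a minimal
// `ScoringIterator` over an in-memory, doc_id-sorted `Vec<(DocId, f32)>`. The
// `VecScorer` name and this module are hypothetical additions for illustration;
// real implementations wrap block posting lists (see `TextTermScorer` and
// `SparseTermScorer` below).
#[cfg(test)]
mod scoring_iterator_sketch {
    use super::*;

    pub(crate) struct VecScorer {
        postings: Vec<(DocId, f32)>,
        pos: usize,
    }

    impl VecScorer {
        pub(crate) fn new(postings: Vec<(DocId, f32)>) -> Self {
            Self { postings, pos: 0 }
        }
    }

    impl ScoringIterator for VecScorer {
        fn doc(&self) -> DocId {
            // u32::MAX signals exhaustion, matching the trait contract.
            self.postings.get(self.pos).map(|p| p.0).unwrap_or(u32::MAX)
        }

        fn advance(&mut self) -> DocId {
            self.pos += 1;
            self.doc()
        }

        fn seek(&mut self, target: DocId) -> DocId {
            // A linear scan is enough for a sketch; real scorers use block skip data.
            while self.doc() < target {
                self.pos += 1;
            }
            self.doc()
        }

        fn score(&self) -> f32 {
            self.postings[self.pos].1
        }

        fn max_score(&self) -> f32 {
            self.postings.iter().map(|p| p.1).fold(0.0, f32::max)
        }

        fn current_block_max_score(&self) -> f32 {
            // Single conceptual block: fall back to the global bound.
            self.max_score()
        }
    }

    #[test]
    fn vec_scorer_follows_the_trait_contract() {
        let mut s = VecScorer::new(vec![(2, 0.5), (7, 1.5)]);
        assert_eq!(s.doc(), 2);
        assert_eq!(s.seek(5), 7);
        assert!((s.score() - 1.5).abs() < 1e-6);
        assert_eq!(s.advance(), u32::MAX);
        assert!(s.is_exhausted());
    }
}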

/// Entry for top-k min-heap
#[derive(Clone, Copy)]
pub struct HeapEntry {
    pub doc_id: DocId,
    pub score: f32,
}

impl PartialEq for HeapEntry {
    fn eq(&self, other: &Self) -> bool {
        self.score == other.score && self.doc_id == other.doc_id
    }
}

impl Eq for HeapEntry {}

impl Ord for HeapEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        // Min-heap: lower scores come first (to be evicted)
        other
            .score
            .partial_cmp(&self.score)
            .unwrap_or(Ordering::Equal)
            .then_with(|| self.doc_id.cmp(&other.doc_id))
    }
}

impl PartialOrd for HeapEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

/// Efficient top-k collector using min-heap
///
/// Maintains the k highest-scoring documents using a min-heap where the
/// lowest score is at the top for O(1) threshold lookup and O(log k) eviction.
/// No deduplication - caller must ensure each doc_id is inserted only once.
pub struct ScoreCollector {
    /// Min-heap of top-k entries (lowest score at top for eviction)
    heap: BinaryHeap<HeapEntry>,
    pub k: usize,
}

impl ScoreCollector {
    /// Create a new collector for top-k results
    pub fn new(k: usize) -> Self {
        Self {
            heap: BinaryHeap::with_capacity(k + 1),
            k,
        }
    }

    /// Current score threshold (minimum score to enter top-k)
    #[inline]
    pub fn threshold(&self) -> f32 {
        if self.heap.len() >= self.k {
            self.heap.peek().map(|e| e.score).unwrap_or(0.0)
        } else {
            0.0
        }
    }

    /// Insert a document score. Returns true if inserted in top-k.
    /// Caller must ensure each doc_id is inserted only once.
    #[inline]
    pub fn insert(&mut self, doc_id: DocId, score: f32) -> bool {
        if self.heap.len() < self.k {
            self.heap.push(HeapEntry { doc_id, score });
            true
        } else if score > self.threshold() {
            self.heap.push(HeapEntry { doc_id, score });
            self.heap.pop(); // Remove lowest
            true
        } else {
            false
        }
    }

    /// Check if a score could potentially enter top-k
    #[inline]
    pub fn would_enter(&self, score: f32) -> bool {
        self.heap.len() < self.k || score > self.threshold()
    }

    /// Get number of documents collected so far
    #[inline]
    pub fn len(&self) -> usize {
        self.heap.len()
    }

    /// Check if collector is empty
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.heap.is_empty()
    }

    /// Convert to sorted top-k results (descending by score)
    pub fn into_sorted_results(self) -> Vec<(DocId, f32)> {
        let mut results: Vec<_> = self
            .heap
            .into_vec()
            .into_iter()
            .map(|e| (e.doc_id, e.score))
            .collect();

        // Sort by score descending, then doc_id ascending
        results.sort_by(|a, b| {
            b.1.partial_cmp(&a.1)
                .unwrap_or(Ordering::Equal)
                .then_with(|| a.0.cmp(&b.0))
        });

        results
    }
}

/// Search result from WAND execution
#[derive(Debug, Clone, Copy)]
pub struct ScoredDoc {
    pub doc_id: DocId,
    pub score: f32,
}

/// Generic MaxScore WAND executor for top-k retrieval
///
/// Works with any type implementing `ScoringIterator`.
/// Implements:
/// - WAND pivot-based pruning: skip documents that can't beat the current threshold
/// - MaxScore-style early termination once the remaining upper bound can't beat the threshold
/// - Efficient top-k collection via `ScoreCollector`
pub struct WandExecutor<S: ScoringIterator> {
    /// Scorers for each query term
    scorers: Vec<S>,
    /// Top-k collector
    collector: ScoreCollector,
    /// Heap factor for approximate search (SEISMIC-style)
    /// A candidate is skipped if heap_factor * max_possible <= threshold
    /// - 1.0 = exact search (default)
    /// - 0.8 = approximate, faster with minor recall loss
    heap_factor: f32,
}
impl<S: ScoringIterator> WandExecutor<S> {
    /// Create a new WAND executor with exact search (heap_factor = 1.0)
    pub fn new(scorers: Vec<S>, k: usize) -> Self {
        Self::with_heap_factor(scorers, k, 1.0)
    }

    /// Create a new WAND executor with approximate search
    ///
    /// `heap_factor` controls the trade-off between speed and recall:
    /// - 1.0 = exact search
    /// - 0.8 = ~20% faster, minor recall loss
    /// - 0.5 = much faster, noticeable recall loss
    pub fn with_heap_factor(scorers: Vec<S>, k: usize, heap_factor: f32) -> Self {
        let total_upper: f32 = scorers.iter().map(|s| s.max_score()).sum();

        debug!(
            "Creating WandExecutor: num_scorers={}, k={}, total_upper={:.4}, heap_factor={:.2}",
            scorers.len(),
            k,
            total_upper,
            heap_factor
        );

        Self {
            scorers,
            collector: ScoreCollector::new(k),
            heap_factor: heap_factor.clamp(0.0, 1.0),
        }
    }

    /// Execute WAND and return top-k results
    ///
    /// Implements the WAND (Weak AND) algorithm with pivot-based pruning:
    /// 1. Maintain iterators sorted by current docID (using a sorted vector)
    /// 2. Find pivot: first term where cumulative upper bounds > threshold
    /// 3. If all iterators at pivot docID, fully score; otherwise skip to pivot
    /// 4. Insert into collector and advance
    ///
    /// Reference: Broder et al., "Efficient Query Evaluation using a Two-Level
    /// Retrieval Process" (CIKM 2003)
    ///
    /// Note: For a small number of terms (typical queries), a sorted vector with
    /// insertion sort is faster than a heap due to better cache locality.
    /// The vector stays mostly sorted, so insertion sort is ~O(n) amortized.
    pub fn execute(mut self) -> Vec<ScoredDoc> {
        if self.scorers.is_empty() {
            debug!("WandExecutor: no scorers, returning empty results");
            return Vec::new();
        }

        let mut docs_scored = 0u64;
        let mut docs_skipped = 0u64;
        let num_scorers = self.scorers.len();

        // Indices sorted by current docID - initial sort O(n log n)
        let mut sorted_indices: Vec<usize> = (0..num_scorers).collect();
        sorted_indices.sort_by_key(|&i| self.scorers[i].doc());

        loop {
            // Find first non-exhausted iterator (they're sorted, so check first)
            let first_active = sorted_indices
                .iter()
                .position(|&i| self.scorers[i].doc() != u32::MAX);

            let first_active = match first_active {
                Some(pos) => pos,
                None => break, // All exhausted
            };

            // Early termination: stop when even the total remaining upper bound cannot
            // beat the current threshold. heap_factor < 1.0 scales the upper bound down,
            // making pruning more aggressive (approximate search).
            let total_upper: f32 = sorted_indices[first_active..]
                .iter()
                .map(|&i| self.scorers[i].max_score())
                .sum();

            let threshold = self.collector.threshold();
            if self.collector.len() >= self.collector.k
                && total_upper * self.heap_factor <= threshold
            {
                debug!(
                    "Early termination: scaled_upper_bound={:.4} <= threshold={:.4}",
                    total_upper * self.heap_factor,
                    threshold
                );
                break;
            }

            // Find pivot: first term where the (scaled) cumulative upper bounds exceed the threshold
            let mut cumsum = 0.0f32;
            let mut pivot_pos = first_active;

            for (pos, &idx) in sorted_indices.iter().enumerate().skip(first_active) {
                cumsum += self.scorers[idx].max_score();
                if cumsum * self.heap_factor > threshold
                    || self.collector.len() < self.collector.k
                {
                    pivot_pos = pos;
                    break;
                }
            }

            let pivot_idx = sorted_indices[pivot_pos];
            let pivot_doc = self.scorers[pivot_idx].doc();

            if pivot_doc == u32::MAX {
                break;
            }

            // Check if all iterators before pivot are at pivot_doc
            let all_at_pivot = sorted_indices[first_active..=pivot_pos]
                .iter()
                .all(|&i| self.scorers[i].doc() == pivot_doc);

            if all_at_pivot {
                // All terms up to pivot are at the same doc - fully score it
                let mut score = 0.0f32;
                let mut matching_terms = 0u32;

                // Score from all iterators that have this document and advance them
                // Collect indices that need re-sorting
                let mut modified_positions: Vec<usize> = Vec::new();

                for (pos, &idx) in sorted_indices.iter().enumerate().skip(first_active) {
                    let doc = self.scorers[idx].doc();
                    if doc == pivot_doc {
                        score += self.scorers[idx].score();
                        matching_terms += 1;
                        self.scorers[idx].advance();
                        modified_positions.push(pos);
                    } else if doc > pivot_doc {
                        break;
                    }
                }

                trace!(
                    "Doc {}: score={:.4}, matching={}/{}, threshold={:.4}",
                    pivot_doc, score, matching_terms, num_scorers, threshold
                );

                if self.collector.insert(pivot_doc, score) {
                    docs_scored += 1;
                } else {
                    docs_skipped += 1;
                }

                // Re-sort modified iterators using insertion sort (efficient for nearly-sorted)
                // Move each modified iterator to its correct position
                for &pos in modified_positions.iter().rev() {
                    let idx = sorted_indices[pos];
                    let new_doc = self.scorers[idx].doc();
                    // Bubble up to correct position
                    let mut curr = pos;
                    while curr + 1 < sorted_indices.len()
                        && self.scorers[sorted_indices[curr + 1]].doc() < new_doc
                    {
                        sorted_indices.swap(curr, curr + 1);
                        curr += 1;
                    }
                }
            } else {
                // Not all at pivot - skip the first iterator to pivot_doc
                let first_pos = first_active;
                let first_idx = sorted_indices[first_pos];
                self.scorers[first_idx].seek(pivot_doc);
                docs_skipped += 1;

                // Re-sort the modified iterator
                let new_doc = self.scorers[first_idx].doc();
                let mut curr = first_pos;
                while curr + 1 < sorted_indices.len()
                    && self.scorers[sorted_indices[curr + 1]].doc() < new_doc
                {
                    sorted_indices.swap(curr, curr + 1);
                    curr += 1;
                }
            }
        }

        let results: Vec<ScoredDoc> = self
            .collector
            .into_sorted_results()
            .into_iter()
            .map(|(doc_id, score)| ScoredDoc { doc_id, score })
            .collect();

        debug!(
            "WandExecutor completed: scored={}, skipped={}, returned={}, top_score={:.4}",
            docs_scored,
            docs_skipped,
            results.len(),
            results.first().map(|r| r.score).unwrap_or(0.0)
        );

        results
    }
}
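
// End-to-end sketch (test-only): drive `WandExecutor` with the in-memory
// `VecScorer` defined in the `scoring_iterator_sketch` module after the trait
// above. The postings and expected ranking below are made up for illustration;
// the test only checks that candidates are scored by summing per-term
// contributions and returned in descending score order.
#[cfg(test)]
mod wand_executor_sketch {
    use super::scoring_iterator_sketch::VecScorer;
    use super::*;

    #[test]
    fn wand_executor_returns_top_k_by_summed_score() {
        // Two "terms": doc 1 matches both, docs 2 and 3 match one each.
        // Full OR scores: doc 1 = 2.5, doc 2 = 1.0, doc 3 = 2.0.
        let term_a = VecScorer::new(vec![(1, 1.0), (3, 2.0)]);
        let term_b = VecScorer::new(vec![(1, 1.5), (2, 1.0)]);

        let results = WandExecutor::new(vec![term_a, term_b], 2).execute();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].doc_id, 1);
        assert!((results[0].score - 2.5).abs() < 1e-6);
        assert_eq!(results[1].doc_id, 3);
        assert!((results[1].score - 2.0).abs() < 1e-6);
    }
}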

/// Scorer for full-text terms using WAND optimization
///
/// Wraps a `BlockPostingList` with BM25 parameters to implement `ScoringIterator`.
/// Enables MaxScore pruning for efficient top-k retrieval in OR queries.
pub struct TextTermScorer {
    /// Iterator over the posting list (owned)
    iter: crate::structures::BlockPostingIterator<'static>,
    /// IDF component for BM25
    idf: f32,
    /// Average field length for BM25 normalization
    avg_field_len: f32,
    /// Pre-computed max score (using max_tf from posting list)
    max_score: f32,
}

impl TextTermScorer {
    /// Create a new text term scorer with BM25 parameters
    pub fn new(
        posting_list: crate::structures::BlockPostingList,
        idf: f32,
        avg_field_len: f32,
    ) -> Self {
        // Compute max score using actual max_tf from posting list
        let max_tf = posting_list.max_tf() as f32;
        let doc_count = posting_list.doc_count();
        let max_score = super::bm25_upper_bound(max_tf.max(1.0), idf);

        debug!(
            "Created TextTermScorer: doc_count={}, max_tf={:.0}, idf={:.4}, avg_field_len={:.2}, max_score={:.4}",
            doc_count, max_tf, idf, avg_field_len, max_score
        );

        Self {
            iter: posting_list.into_iterator(),
            idf,
            avg_field_len,
            max_score,
        }
    }
}

impl ScoringIterator for TextTermScorer {
    #[inline]
    fn doc(&self) -> DocId {
        self.iter.doc()
    }

    #[inline]
    fn advance(&mut self) -> DocId {
        self.iter.advance()
    }

    #[inline]
    fn seek(&mut self, target: DocId) -> DocId {
        self.iter.seek(target)
    }

    #[inline]
    fn score(&self) -> f32 {
        let tf = self.iter.term_freq() as f32;
        // Use tf as proxy for doc length (common approximation when field lengths aren't stored)
        super::bm25_score(tf, self.idf, tf, self.avg_field_len)
    }

    #[inline]
    fn max_score(&self) -> f32 {
        self.max_score
    }

    #[inline]
    fn current_block_max_score(&self) -> f32 {
        // Use per-block max_tf for tighter Block-Max WAND bounds
        let block_max_tf = self.iter.current_block_max_tf() as f32;
        super::bm25_upper_bound(block_max_tf.max(1.0), self.idf)
    }

    #[inline]
    fn skip_to_next_block(&mut self) -> DocId {
        self.iter.skip_to_next_block()
    }
}

/// Scorer for sparse vector dimensions
///
/// Wraps a `BlockSparsePostingList` with a query weight to implement `ScoringIterator`.
pub struct SparseTermScorer<'a> {
    /// Iterator over the posting list
    iter: crate::structures::BlockSparsePostingIterator<'a>,
    /// Query weight for this dimension
    query_weight: f32,
    /// Global max score (query_weight * global_max_weight)
    max_score: f32,
}

impl<'a> SparseTermScorer<'a> {
    /// Create a new sparse term scorer
    pub fn new(posting_list: &'a BlockSparsePostingList, query_weight: f32) -> Self {
        let max_score = query_weight * posting_list.global_max_weight();
        Self {
            iter: posting_list.iterator(),
            query_weight,
            max_score,
        }
    }

    /// Create from Arc reference (for use with shared posting lists)
    pub fn from_arc(posting_list: &'a Arc<BlockSparsePostingList>, query_weight: f32) -> Self {
        Self::new(posting_list.as_ref(), query_weight)
    }
}

impl ScoringIterator for SparseTermScorer<'_> {
    #[inline]
    fn doc(&self) -> DocId {
        self.iter.doc()
    }

    #[inline]
    fn advance(&mut self) -> DocId {
        self.iter.advance()
    }

    #[inline]
    fn seek(&mut self, target: DocId) -> DocId {
        self.iter.seek(target)
    }

    #[inline]
    fn score(&self) -> f32 {
        // Dot product contribution: query_weight * stored_weight
        self.query_weight * self.iter.weight()
    }

    #[inline]
    fn max_score(&self) -> f32 {
        self.max_score
    }

    #[inline]
    fn current_block_max_score(&self) -> f32 {
        self.iter.current_block_max_contribution(self.query_weight)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_score_collector_basic() {
        let mut collector = ScoreCollector::new(3);

        collector.insert(1, 1.0);
        collector.insert(2, 2.0);
        collector.insert(3, 3.0);
        assert_eq!(collector.threshold(), 1.0);

        collector.insert(4, 4.0);
        assert_eq!(collector.threshold(), 2.0);

        let results = collector.into_sorted_results();
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].0, 4); // Highest score
        assert_eq!(results[1].0, 3);
        assert_eq!(results[2].0, 2);
    }

    #[test]
    fn test_score_collector_threshold() {
        let mut collector = ScoreCollector::new(2);

        collector.insert(1, 5.0);
        collector.insert(2, 3.0);
        assert_eq!(collector.threshold(), 3.0);

        // Should not enter (score too low)
        assert!(!collector.would_enter(2.0));
        assert!(!collector.insert(3, 2.0));

        // Should enter (score high enough)
        assert!(collector.would_enter(4.0));
        assert!(collector.insert(4, 4.0));
        assert_eq!(collector.threshold(), 4.0);
    }

    #[test]
    fn test_heap_entry_ordering() {
        let mut heap = BinaryHeap::new();
        heap.push(HeapEntry {
            doc_id: 1,
            score: 3.0,
        });
        heap.push(HeapEntry {
            doc_id: 2,
            score: 1.0,
        });
        heap.push(HeapEntry {
            doc_id: 3,
            score: 2.0,
        });

        // Min-heap: lowest score should come out first
        assert_eq!(heap.pop().unwrap().score, 1.0);
        assert_eq!(heap.pop().unwrap().score, 2.0);
        assert_eq!(heap.pop().unwrap().score, 3.0);
593    }
594}