hermes_core/query/wand.rs

//! BlockWAND and MaxScore query optimization with BM25F scoring
//!
//! Implements efficient top-k retrieval using:
//! - MaxScore: skips terms that can't contribute to top-k
//! - BlockWAND: block-level score upper bounds for early termination
//! - BM25F: field-aware scoring with length normalization
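//!
//! # Example
//!
//! A minimal usage sketch; the posting lists and IDF values come from the
//! indexing side, so their construction here is illustrative only:
//!
//! ```ignore
//! let scorers = vec![
//!     TermScorer::new(&posting_list_a, idf_a, 0),
//!     TermScorer::new(&posting_list_b, idf_b, 1),
//! ];
//! let top_10 = MaxScoreWand::new(scorers, 10).execute();
//! for hit in &top_10 {
//!     println!("doc {} scored {:.3}", hit.doc_id, hit.score);
//! }
//! ```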

use std::cmp::Ordering;
use std::collections::BinaryHeap;

use crate::structures::{BitpackedPostingIterator, BitpackedPostingList};
use crate::{DocId, Score};

/// BM25F parameters for WAND scoring
pub const WAND_K1: f32 = 1.2;
pub const WAND_B: f32 = 0.75;

/// Term scorer with MaxScore info and BM25F support
pub struct TermScorer<'a> {
    /// Iterator over postings
    pub iter: BitpackedPostingIterator<'a>,
    /// Maximum possible score for this term (computed with BM25F upper bound)
    pub max_score: f32,
    /// IDF component
    pub idf: f32,
    /// Term index (for tracking)
    pub term_idx: usize,
    /// Field boost for BM25F
    pub field_boost: f32,
    /// Average field length for BM25F length normalization
    pub avg_field_len: f32,
}

impl<'a> TermScorer<'a> {
    /// Create a new term scorer with default BM25F parameters
    pub fn new(posting_list: &'a BitpackedPostingList, idf: f32, term_idx: usize) -> Self {
        Self {
            iter: posting_list.iterator(),
            max_score: posting_list.max_score,
            idf,
            term_idx,
            field_boost: 1.0,
            avg_field_len: 1.0, // Default: no length normalization effect
        }
    }

    /// Create a term scorer with BM25F parameters
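    ///
    /// A minimal sketch; the `2.0` boost and `6.0` average field length are
    /// illustrative values, not defaults:
    ///
    /// ```ignore
    /// let scorer = TermScorer::with_bm25f(&posting_list, idf, 0, 2.0, 6.0);
    /// ```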
    pub fn with_bm25f(
        posting_list: &'a BitpackedPostingList,
        idf: f32,
        term_idx: usize,
        field_boost: f32,
        avg_field_len: f32,
    ) -> Self {
        // Recompute max_score with field boost
        let max_score = if field_boost != 1.0 {
            // Find max_tf across all blocks and recompute upper bound
            let max_tf = posting_list
                .blocks
                .iter()
                .map(|b| b.max_tf)
                .max()
                .unwrap_or(1);
            BitpackedPostingList::compute_bm25f_upper_bound(max_tf, idf, field_boost)
        } else {
            posting_list.max_score
        };

        Self {
            iter: posting_list.iterator(),
            max_score,
            idf,
            term_idx,
            field_boost,
            avg_field_len,
        }
    }

    /// Current document
    #[inline]
    pub fn doc(&self) -> DocId {
        self.iter.doc()
    }

    /// Compute BM25F score for current document
    #[inline]
    pub fn score(&self) -> Score {
        let tf = self.iter.term_freq() as f32;

        // BM25F scoring with length normalization
        // Since we don't have per-doc field length, we approximate using tf
        // This is a common approximation when field lengths aren't stored per-posting
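        // Effectively:
        //   score = idf * (tf * boost * (k1 + 1))
        //               / (tf * boost + k1 * (1 - b + b * tf / max(avg_field_len, 1)))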
        let length_norm = 1.0 - WAND_B + WAND_B * (tf / self.avg_field_len.max(1.0));
        let tf_norm = (tf * self.field_boost * (WAND_K1 + 1.0))
            / (tf * self.field_boost + WAND_K1 * length_norm);

        self.idf * tf_norm
    }

    /// Get current block's max score (for block-level pruning)
    #[inline]
    pub fn current_block_max_score(&self) -> f32 {
        if self.field_boost == 1.0 {
            self.iter.current_block_max_score()
        } else {
            // Recompute with field boost
            let block_max_tf = self.iter.current_block_max_tf();
            BitpackedPostingList::compute_bm25f_upper_bound(
                block_max_tf,
                self.idf,
                self.field_boost,
            )
        }
    }

    /// Advance to next document
    #[inline]
    pub fn advance(&mut self) -> DocId {
        self.iter.advance()
    }

    /// Seek to doc >= target
    #[inline]
    pub fn seek(&mut self, target: DocId) -> DocId {
        self.iter.seek(target)
    }

    /// Is this scorer exhausted?
    #[inline]
    pub fn is_exhausted(&self) -> bool {
        self.doc() == u32::MAX
    }
}

/// Result entry for top-k heap
#[derive(Clone, Copy)]
struct HeapEntry {
    doc_id: DocId,
    score: Score,
}

impl PartialEq for HeapEntry {
    fn eq(&self, other: &Self) -> bool {
        self.score == other.score && self.doc_id == other.doc_id
    }
}

impl Eq for HeapEntry {}

impl PartialOrd for HeapEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for HeapEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        // Min-heap: lower scores come first (to be evicted)
        other
            .score
            .partial_cmp(&self.score)
            .unwrap_or(Ordering::Equal)
            .then_with(|| self.doc_id.cmp(&other.doc_id))
    }
}

/// Search result
#[derive(Debug, Clone, Copy)]
pub struct WandResult {
    pub doc_id: DocId,
    pub score: Score,
}

/// MaxScore WAND algorithm for efficient top-k retrieval
///
/// Key optimizations:
/// 1. Terms sorted by max_score descending
/// 2. Threshold tracking: skip terms whose max_score < current threshold
/// 3. Pivot-based skipping: scorers behind the pivot seek directly to the pivot document
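///
/// # Example
///
/// A minimal sketch; `pl1` and `pl2` stand in for posting lists obtained from the index:
///
/// ```ignore
/// let scorers = vec![TermScorer::new(&pl1, 1.0, 0), TermScorer::new(&pl2, 1.5, 1)];
/// let top = MaxScoreWand::new(scorers, 10).execute();
/// ```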
pub struct MaxScoreWand<'a> {
    /// Term scorers sorted by current doc_id
    scorers: Vec<TermScorer<'a>>,
    /// Top-k results heap
    heap: BinaryHeap<HeapEntry>,
    /// Number of results to return
    k: usize,
    /// Current score threshold (min score in top-k)
    threshold: Score,
    /// Sum of max_scores for "essential" terms (reserved for future use)
    #[allow(dead_code)]
    essential_max_sum: Score,
}

impl<'a> MaxScoreWand<'a> {
    /// Create a new MaxScore WAND executor
    pub fn new(mut scorers: Vec<TermScorer<'a>>, k: usize) -> Self {
        // Sort scorers by max_score descending
        scorers.sort_by(|a, b| {
            b.max_score
                .partial_cmp(&a.max_score)
                .unwrap_or(Ordering::Equal)
        });

        let essential_max_sum: Score = scorers.iter().map(|s| s.max_score).sum();

        Self {
            scorers,
            heap: BinaryHeap::with_capacity(k + 1),
            k,
            threshold: 0.0,
            essential_max_sum,
        }
    }

    /// Execute the query and return top-k results
    pub fn execute(mut self) -> Vec<WandResult> {
        if self.scorers.is_empty() {
            return Vec::new();
        }

        // Remove exhausted scorers
        self.scorers.retain(|s| !s.is_exhausted());

        while !self.scorers.is_empty() {
            // Sort by current doc_id
            self.scorers.sort_by_key(|s| s.doc());

            // Find pivot: first position where cumulative max_score >= threshold
            let Some(pivot_idx) = self.find_pivot() else {
                break;
            };
            let pivot_doc = self.scorers[pivot_idx].doc();

            if pivot_doc == u32::MAX {
                break;
            }

            // Check if all scorers up to pivot are at pivot_doc
            let all_at_pivot = self.scorers[..=pivot_idx]
                .iter()
                .all(|s| s.doc() == pivot_doc);

            if all_at_pivot {
                // Score this document
                let score = self.score_document(pivot_doc);
                self.maybe_insert(pivot_doc, score);

                // Advance all scorers at pivot_doc
                for scorer in &mut self.scorers {
                    if scorer.doc() == pivot_doc {
                        scorer.advance();
                    }
                }
            } else {
                // Advance scorers before pivot to pivot_doc
                for i in 0..pivot_idx {
                    if self.scorers[i].doc() < pivot_doc {
                        self.scorers[i].seek(pivot_doc);
                    }
                }
            }

            // Remove exhausted scorers
            self.scorers.retain(|s| !s.is_exhausted());
        }

        self.into_results()
    }

    /// Find pivot index where cumulative max_score >= threshold
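    ///
    /// For example, with per-term max_scores of 2.0, 1.5, and 0.5 (in current
    /// doc_id order) and a threshold of 3.2, the pivot is index 1 because
    /// 2.0 + 1.5 = 3.5 >= 3.2.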
    fn find_pivot(&self) -> Option<usize> {
        let mut cumsum = 0.0f32;

        for (i, scorer) in self.scorers.iter().enumerate() {
            cumsum += scorer.max_score;
            if cumsum >= self.threshold {
                return Some(i);
            }
        }

        // If we can't reach threshold, we're done
        if cumsum < self.threshold {
            None
        } else {
            Some(self.scorers.len() - 1)
        }
    }

    /// Score a document across all matching scorers
    fn score_document(&self, doc_id: DocId) -> Score {
        let mut score = 0.0;
        for scorer in &self.scorers {
            if scorer.doc() == doc_id {
                score += scorer.score();
            }
        }
        score
    }

    /// Insert into top-k heap if score is high enough
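    ///
    /// For example, with k = 3 and current top-k scores {1.2, 1.5, 2.0}, the
    /// threshold is 1.2; a new score of 1.4 evicts 1.2 and raises the threshold
    /// to 1.4.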
    fn maybe_insert(&mut self, doc_id: DocId, score: Score) {
        if self.heap.len() < self.k {
            self.heap.push(HeapEntry { doc_id, score });
            if self.heap.len() == self.k {
                self.threshold = self.heap.peek().map(|e| e.score).unwrap_or(0.0);
            }
        } else if score > self.threshold {
            self.heap.pop();
            self.heap.push(HeapEntry { doc_id, score });
            self.threshold = self.heap.peek().map(|e| e.score).unwrap_or(0.0);
        }
    }

    /// Convert heap to sorted results
    fn into_results(self) -> Vec<WandResult> {
        let mut results: Vec<_> = self
            .heap
            .into_vec()
            .into_iter()
            .map(|e| WandResult {
                doc_id: e.doc_id,
                score: e.score,
            })
            .collect();

        results.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(Ordering::Equal)
                .then_with(|| a.doc_id.cmp(&b.doc_id))
        });

        results
    }
}

/// BlockWAND: Block-level WAND with early termination
///
/// Uses block-level max scores for more aggressive skipping
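///
/// # Example
///
/// A minimal sketch, with scorers built the same way as for [`MaxScoreWand`]:
///
/// ```ignore
/// let scorers = vec![TermScorer::new(&pl1, 1.0, 0), TermScorer::new(&pl2, 1.5, 1)];
/// let top = BlockWand::new(scorers, 10).execute();
/// ```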
pub struct BlockWand<'a> {
    scorers: Vec<TermScorer<'a>>,
    heap: BinaryHeap<HeapEntry>,
    k: usize,
    threshold: Score,
}

impl<'a> BlockWand<'a> {
    pub fn new(scorers: Vec<TermScorer<'a>>, k: usize) -> Self {
        Self {
            scorers,
            heap: BinaryHeap::with_capacity(k + 1),
            k,
            threshold: 0.0,
        }
    }

    /// Execute with block-level skipping
    pub fn execute(mut self) -> Vec<WandResult> {
        if self.scorers.is_empty() {
            return Vec::new();
        }

        self.scorers.retain(|s| !s.is_exhausted());

        while !self.scorers.is_empty() {
            // Sort by current doc
            self.scorers.sort_by_key(|s| s.doc());

            // Find minimum doc across all scorers
            let min_doc = self.scorers[0].doc();
            if min_doc == u32::MAX {
                break;
            }

            // Compute upper bound score for this doc using block max scores (BM25F aware)
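            // Scorers positioned exactly at min_doc contribute their exact score;
            // the remaining scorers contribute their current block's upper bound, so
            // the estimate is conservative (never below the true score of min_doc).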
            let upper_bound: Score = self
                .scorers
                .iter()
                .filter(|s| s.doc() <= min_doc || s.current_block_max_score() > 0.0)
                .map(|s| {
                    if s.doc() == min_doc {
                        s.score() // Exact BM25F score
                    } else {
                        s.current_block_max_score() // BM25F upper bound
                    }
                })
                .sum();

            if upper_bound >= self.threshold {
                // Need to evaluate this document
                // First, advance all scorers to min_doc
                for scorer in &mut self.scorers {
                    if scorer.doc() < min_doc {
                        scorer.seek(min_doc);
                    }
                }

                // Score document
                let score = self.score_document(min_doc);
                self.maybe_insert(min_doc, score);
            }

            // Advance scorers at min_doc
            for scorer in &mut self.scorers {
                if scorer.doc() == min_doc {
                    scorer.advance();
                }
            }

            self.scorers.retain(|s| !s.is_exhausted());
        }

        self.into_results()
    }

    fn score_document(&self, doc_id: DocId) -> Score {
        self.scorers
            .iter()
            .filter(|s| s.doc() == doc_id)
            .map(|s| s.score())
            .sum()
    }

    fn maybe_insert(&mut self, doc_id: DocId, score: Score) {
        if self.heap.len() < self.k {
            self.heap.push(HeapEntry { doc_id, score });
            if self.heap.len() == self.k {
                self.threshold = self.heap.peek().map(|e| e.score).unwrap_or(0.0);
            }
        } else if score > self.threshold {
            self.heap.pop();
            self.heap.push(HeapEntry { doc_id, score });
            self.threshold = self.heap.peek().map(|e| e.score).unwrap_or(0.0);
        }
    }

    fn into_results(self) -> Vec<WandResult> {
        let mut results: Vec<_> = self
            .heap
            .into_vec()
            .into_iter()
            .map(|e| WandResult {
                doc_id: e.doc_id,
                score: e.score,
            })
            .collect();

        results.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(Ordering::Equal)
                .then_with(|| a.doc_id.cmp(&b.doc_id))
        });

        results
    }
}

/// Simple DAAT (Document-At-A-Time) scorer for comparison
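///
/// Exhaustively scores every document that matches at least one term, so it serves
/// as a correctness baseline for the WAND variants:
///
/// ```ignore
/// let results = daat_or(&mut scorers, 10);
/// ```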
pub fn daat_or<'a>(scorers: &mut [TermScorer<'a>], k: usize) -> Vec<WandResult> {
    let mut heap: BinaryHeap<HeapEntry> = BinaryHeap::with_capacity(k + 1);
    let mut threshold = 0.0f32;

    loop {
        // Find minimum doc
        let min_doc = scorers
            .iter()
            .filter(|s| !s.is_exhausted())
            .map(|s| s.doc())
            .min();

        let min_doc = match min_doc {
            Some(d) if d != u32::MAX => d,
            _ => break,
        };

        // Score this document
        let score: Score = scorers
            .iter()
            .filter(|s| s.doc() == min_doc)
            .map(|s| s.score())
            .sum();

        // Insert if good enough
        if heap.len() < k {
            heap.push(HeapEntry {
                doc_id: min_doc,
                score,
            });
            if heap.len() == k {
                threshold = heap.peek().map(|e| e.score).unwrap_or(0.0);
            }
        } else if score > threshold {
            heap.pop();
            heap.push(HeapEntry {
                doc_id: min_doc,
                score,
            });
            threshold = heap.peek().map(|e| e.score).unwrap_or(0.0);
        }

        // Advance scorers at min_doc
        for scorer in scorers.iter_mut() {
            if scorer.doc() == min_doc {
                scorer.advance();
            }
        }
    }

    let mut results: Vec<_> = heap
        .into_vec()
        .into_iter()
        .map(|e| WandResult {
            doc_id: e.doc_id,
            score: e.score,
        })
        .collect();

    results.sort_by(|a, b| {
        b.score
            .partial_cmp(&a.score)
            .unwrap_or(Ordering::Equal)
            .then_with(|| a.doc_id.cmp(&b.doc_id))
    });

    results
}

#[cfg(test)]
mod tests {
    use super::*;

    fn create_test_posting_list(
        doc_ids: &[u32],
        term_freqs: &[u32],
        idf: f32,
    ) -> BitpackedPostingList {
        BitpackedPostingList::from_postings(doc_ids, term_freqs, idf)
    }

    #[test]
    fn test_maxscore_wand_basic() {
        // Term 1: docs 1, 3, 5, 7
        let pl1 = create_test_posting_list(&[1, 3, 5, 7], &[2, 1, 3, 1], 1.0);
        // Term 2: docs 2, 3, 6, 7
        let pl2 = create_test_posting_list(&[2, 3, 6, 7], &[1, 2, 1, 2], 1.5);

        let scorers = vec![TermScorer::new(&pl1, 1.0, 0), TermScorer::new(&pl2, 1.5, 1)];

        let results = MaxScoreWand::new(scorers, 3).execute();

        assert!(!results.is_empty());
        // Docs 3 and 7 should have the highest scores (they match both terms)
        let top_docs: Vec<_> = results.iter().map(|r| r.doc_id).collect();
        assert!(top_docs.contains(&3) || top_docs.contains(&7));
    }

    #[test]
    fn test_block_wand_basic() {
        let pl1 = create_test_posting_list(&[1, 3, 5, 7, 9], &[1, 2, 1, 3, 1], 1.0);
        let pl2 = create_test_posting_list(&[2, 3, 7, 8], &[1, 1, 2, 1], 1.2);

        let scorers = vec![TermScorer::new(&pl1, 1.0, 0), TermScorer::new(&pl2, 1.2, 1)];

        let results = BlockWand::new(scorers, 5).execute();

        assert!(!results.is_empty());
        // Should find documents from both lists
        let doc_ids: Vec<_> = results.iter().map(|r| r.doc_id).collect();
        assert!(doc_ids.iter().any(|&d| d == 3 || d == 7)); // Intersection docs
    }

    #[test]
    fn test_daat_or() {
        let pl1 = create_test_posting_list(&[1, 2, 3], &[1, 1, 1], 1.0);
        let pl2 = create_test_posting_list(&[2, 3, 4], &[1, 1, 1], 1.0);

        let mut scorers = vec![TermScorer::new(&pl1, 1.0, 0), TermScorer::new(&pl2, 1.0, 1)];

        let results = daat_or(&mut scorers, 10);

        assert_eq!(results.len(), 4); // Docs 1, 2, 3, 4

        // Docs 2 and 3 should have higher scores (match both)
        assert!(results[0].doc_id == 2 || results[0].doc_id == 3);
        assert!(results[1].doc_id == 2 || results[1].doc_id == 3);
    }

    #[test]
    fn test_maxscore_threshold_pruning() {
        // Create posting lists where MaxScore can prune effectively
        // High-scoring term
        let pl1 = create_test_posting_list(&[1, 100, 200], &[10, 10, 10], 2.0);
        // Low-scoring term with many docs
        let pl2 = create_test_posting_list(&(0..50).collect::<Vec<_>>(), &vec![1; 50], 0.1);

        let scorers = vec![TermScorer::new(&pl1, 2.0, 0), TermScorer::new(&pl2, 0.1, 1)];

        let results = MaxScoreWand::new(scorers, 3).execute();

        // Top results should be from pl1 (higher scores)
        assert!(
            results
                .iter()
                .any(|r| r.doc_id == 1 || r.doc_id == 100 || r.doc_id == 200)
        );
    }
}