use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::sync::Arc;

use log::{debug, trace};

use crate::DocId;
use crate::structures::BlockSparsePostingList;

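/// A scorer that walks a posting list in increasing `DocId` order and exposes
/// per-document scores plus global and per-block upper bounds, which the
/// executors below use for dynamic pruning.
///
/// A minimal sketch of draining a scorer (assuming some value `scorer`
/// implementing this trait):
///
/// ```ignore
/// let mut best = f32::MIN;
/// while !scorer.is_exhausted() {
///     best = best.max(scorer.score());
///     scorer.advance();
/// }
/// ```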
pub trait ScoringIterator {
    /// Current document, or `u32::MAX` once the iterator is exhausted.
    fn doc(&self) -> DocId;

    /// Ordinal attached to the current posting (defaults to 0).
    fn ordinal(&self) -> u16 {
        0
    }

    /// Move to the next document and return it.
    fn advance(&mut self) -> DocId;

    /// Move to the first document >= `target` and return it.
    fn seek(&mut self, target: DocId) -> DocId;

    /// True once the iterator has run past its last posting.
    fn is_exhausted(&self) -> bool {
        self.doc() == u32::MAX
    }

    /// Score contribution of the current document.
    fn score(&self) -> f32;

    /// Upper bound on `score()` over the whole posting list.
    fn max_score(&self) -> f32;

    /// Upper bound on `score()` within the current block.
    fn current_block_max_score(&self) -> f32;

    /// Skip past the current block; the default simply advances one document.
    fn skip_to_next_block(&mut self) -> DocId {
        self.advance()
    }
}

/// Entry in the top-k heap.
#[derive(Clone, Copy)]
pub struct HeapEntry {
    pub doc_id: DocId,
    pub score: f32,
    pub ordinal: u16,
}

impl PartialEq for HeapEntry {
    fn eq(&self, other: &Self) -> bool {
        self.score == other.score && self.doc_id == other.doc_id
    }
}

impl Eq for HeapEntry {}

impl Ord for HeapEntry {
    // Reversed score comparison so that `BinaryHeap` acts as a min-heap:
    // the weakest retained entry sits at the top and defines the threshold.
    fn cmp(&self, other: &Self) -> Ordering {
        other
            .score
            .partial_cmp(&self.score)
            .unwrap_or(Ordering::Equal)
            .then_with(|| self.doc_id.cmp(&other.doc_id))
    }
}

impl PartialOrd for HeapEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

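/// Bounded top-k collector backed by a min-heap of `HeapEntry` values.
/// Once the heap holds `k` entries, `threshold()` is the weakest retained
/// score and candidates must strictly beat it to enter.
///
/// A small sketch of the threshold behaviour:
///
/// ```ignore
/// let mut collector = ScoreCollector::new(2);
/// collector.insert(1, 0.5);
/// collector.insert(2, 1.5);
/// assert_eq!(collector.threshold(), 0.5); // weakest retained score
/// assert!(!collector.insert(3, 0.5));     // must strictly beat the threshold
/// ```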
pub struct ScoreCollector {
    heap: BinaryHeap<HeapEntry>,
    pub k: usize,
}

impl ScoreCollector {
    pub fn new(k: usize) -> Self {
        // Reserve k + 1 slots (push-then-pop), capped to keep huge k values sane.
        let capacity = k.saturating_add(1).min(1_000_000);
        Self {
            heap: BinaryHeap::with_capacity(capacity),
            k,
        }
    }

    /// Lowest score currently needed to enter the top-k (0.0 while the heap is not full).
    #[inline]
    pub fn threshold(&self) -> f32 {
        if self.heap.len() >= self.k {
            self.heap.peek().map(|e| e.score).unwrap_or(0.0)
        } else {
            0.0
        }
    }

    /// Insert with ordinal 0. Returns true if the document entered the top-k.
    #[inline]
    pub fn insert(&mut self, doc_id: DocId, score: f32) -> bool {
        self.insert_with_ordinal(doc_id, score, 0)
    }

    /// Insert a scored document; when the heap is full, a candidate must strictly
    /// beat the threshold and evicts the weakest entry. Returns true on entry.
    #[inline]
    pub fn insert_with_ordinal(&mut self, doc_id: DocId, score: f32, ordinal: u16) -> bool {
        if self.heap.len() < self.k {
            self.heap.push(HeapEntry {
                doc_id,
                score,
                ordinal,
            });
            true
        } else if score > self.threshold() {
            self.heap.push(HeapEntry {
                doc_id,
                score,
                ordinal,
            });
            self.heap.pop();
            true
        } else {
            false
        }
    }

    /// True if a candidate with this score would currently be accepted.
    #[inline]
    pub fn would_enter(&self, score: f32) -> bool {
        self.heap.len() < self.k || score > self.threshold()
    }

    #[inline]
    pub fn len(&self) -> usize {
        self.heap.len()
    }

    #[inline]
    pub fn is_empty(&self) -> bool {
        self.heap.is_empty()
    }

    /// Consume the collector and return `(doc_id, score, ordinal)` triples sorted
    /// by score descending, then doc_id ascending.
    pub fn into_sorted_results(self) -> Vec<(DocId, f32, u16)> {
        let mut results: Vec<_> = self
            .heap
            .into_vec()
            .into_iter()
            .map(|e| (e.doc_id, e.score, e.ordinal))
            .collect();

        results.sort_by(|a, b| {
            b.1.partial_cmp(&a.1)
                .unwrap_or(Ordering::Equal)
                .then_with(|| a.0.cmp(&b.0))
        });

        results
    }
}

/// A single search hit.
#[derive(Debug, Clone, Copy)]
pub struct ScoredDoc {
    pub doc_id: DocId,
    pub score: f32,
    pub ordinal: u16,
}

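/// Top-k executor combining MaxScore-style term partitioning with block-max
/// pruning: scorers are sorted by ascending `max_score()`, terms whose cumulative
/// upper bound cannot beat the current threshold become "non-essential" and are
/// only consulted via `seek()`, and whole blocks are skipped when their
/// block-level upper bounds fall below the threshold.
///
/// A usage sketch (assuming `scorers` is a `Vec` of `ScoringIterator`s built
/// elsewhere, e.g. one `SparseTermScorer` per query term):
///
/// ```ignore
/// let executor = BlockMaxScoreExecutor::new(scorers, 10);
/// for hit in executor.execute() {
///     println!("doc={} ordinal={} score={:.4}", hit.doc_id, hit.ordinal, hit.score);
/// }
/// ```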
pub struct BlockMaxScoreExecutor<S: ScoringIterator> {
    /// Scorers sorted by ascending `max_score()`.
    scorers: Vec<S>,
    /// `prefix_sums[i]` = sum of `max_score()` for scorers `0..=i`.
    prefix_sums: Vec<f32>,
    collector: ScoreCollector,
    /// Scaling factor applied to the collector threshold when pruning (clamped to [0, 1]).
    heap_factor: f32,
}

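/// Convenience alias for [`BlockMaxScoreExecutor`].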
pub type WandExecutor<S> = BlockMaxScoreExecutor<S>;

impl<S: ScoringIterator> BlockMaxScoreExecutor<S> {
    pub fn new(scorers: Vec<S>, k: usize) -> Self {
        Self::with_heap_factor(scorers, k, 1.0)
    }

    pub fn with_heap_factor(mut scorers: Vec<S>, k: usize, heap_factor: f32) -> Self {
        // Sort by ascending upper bound so that a prefix of "weak" scorers can be
        // treated as non-essential once the top-k threshold grows.
        scorers.sort_by(|a, b| {
            a.max_score()
                .partial_cmp(&b.max_score())
                .unwrap_or(Ordering::Equal)
        });

        // Cumulative upper bounds used by `find_partition`.
        let mut prefix_sums = Vec::with_capacity(scorers.len());
        let mut cumsum = 0.0f32;
        for s in &scorers {
            cumsum += s.max_score();
            prefix_sums.push(cumsum);
        }

        debug!(
            "Creating BlockMaxScoreExecutor: num_scorers={}, k={}, total_upper={:.4}, heap_factor={:.2}",
            scorers.len(),
            k,
            cumsum,
            heap_factor
        );

        Self {
            scorers,
            prefix_sums,
            collector: ScoreCollector::new(k),
            heap_factor: heap_factor.clamp(0.0, 1.0),
        }
    }

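    /// Index of the first "essential" scorer. Scorers `0..partition` are
    /// non-essential: their combined upper bound does not exceed the scaled
    /// threshold, so they can never lift a document into the top-k on their own
    /// and are only visited via `seek()`. Returns `scorers.len()` when every
    /// term is non-essential.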
    #[inline]
    fn find_partition(&self) -> usize {
        let threshold = self.collector.threshold() * self.heap_factor;
        self.prefix_sums
            .iter()
            .position(|&sum| sum > threshold)
            .unwrap_or(self.scorers.len())
    }

    pub fn execute(mut self) -> Vec<ScoredDoc> {
        if self.scorers.is_empty() {
            debug!("BlockMaxScoreExecutor: no scorers, returning empty results");
            return Vec::new();
        }

        let n = self.scorers.len();
        let mut docs_scored = 0u64;
        let mut docs_skipped = 0u64;
        let mut blocks_skipped = 0u64;
        let mut conjunction_skipped = 0u64;

        // Reused buffer of (ordinal, score) pairs for the current pivot document.
        let mut ordinal_scores: Vec<(u16, f32)> = Vec::with_capacity(n * 2);

        loop {
            // Split scorers into non-essential [0, partition) and essential [partition, n).
            let partition = self.find_partition();

            if partition >= n {
                debug!("BlockMaxScore: all terms non-essential, early termination");
                break;
            }

            // Pivot: the smallest current document among the essential scorers.
            let mut min_doc = u32::MAX;
            for i in partition..n {
                let doc = self.scorers[i].doc();
                if doc < min_doc {
                    min_doc = doc;
                }
            }

            if min_doc == u32::MAX {
                // All essential scorers are exhausted.
                break;
            }

            // Upper bound contributed by all non-essential scorers combined.
            let non_essential_upper = if partition > 0 {
                self.prefix_sums[partition - 1]
            } else {
                0.0
            };
            let adjusted_threshold = self.collector.threshold() * self.heap_factor;

            if self.collector.len() >= self.collector.k {
                // Prune on per-term upper bounds: if even the full max_score of every
                // essential scorer at min_doc cannot beat the threshold, skip the doc.
                let present_upper: f32 = (partition..n)
                    .filter(|&i| self.scorers[i].doc() == min_doc)
                    .map(|i| self.scorers[i].max_score())
                    .sum();

                if present_upper + non_essential_upper <= adjusted_threshold {
                    for i in partition..n {
                        if self.scorers[i].doc() == min_doc {
                            self.scorers[i].advance();
                        }
                    }
                    conjunction_skipped += 1;
                    continue;
                }
            }

            if self.collector.len() >= self.collector.k {
                // Tighter prune on block-level upper bounds: skip whole blocks when
                // their combined block maxima cannot beat the threshold.
                let block_max_sum: f32 = (partition..n)
                    .filter(|&i| self.scorers[i].doc() == min_doc)
                    .map(|i| self.scorers[i].current_block_max_score())
                    .sum();

                if block_max_sum + non_essential_upper <= adjusted_threshold {
                    for i in partition..n {
                        if self.scorers[i].doc() == min_doc {
                            self.scorers[i].skip_to_next_block();
                        }
                    }
                    blocks_skipped += 1;
                    continue;
                }
            }

            // Score the essential scorers exactly at min_doc, grouped by ordinal.
            ordinal_scores.clear();

            for i in partition..n {
                if self.scorers[i].doc() == min_doc {
                    while self.scorers[i].doc() == min_doc {
                        ordinal_scores.push((self.scorers[i].ordinal(), self.scorers[i].score()));
                        self.scorers[i].advance();
                    }
                }
            }

            let essential_total: f32 = ordinal_scores.iter().map(|(_, s)| *s).sum();

            // If the exact essential score plus the non-essential upper bound still
            // cannot beat the threshold, skip seeking the non-essential scorers.
            if self.collector.len() >= self.collector.k
                && essential_total + non_essential_upper <= adjusted_threshold
            {
                docs_skipped += 1;
                continue;
            }

            // Complete the score with the non-essential scorers via seek().
            for i in 0..partition {
                let doc = self.scorers[i].seek(min_doc);
                if doc == min_doc {
                    while self.scorers[i].doc() == min_doc {
                        ordinal_scores.push((self.scorers[i].ordinal(), self.scorers[i].score()));
                        self.scorers[i].advance();
                    }
                }
            }

            // Sum contributions per ordinal and offer each total to the collector.
            ordinal_scores.sort_unstable_by_key(|(ord, _)| *ord);
            let mut j = 0;
            while j < ordinal_scores.len() {
                let current_ord = ordinal_scores[j].0;
                let mut score = 0.0f32;
                while j < ordinal_scores.len() && ordinal_scores[j].0 == current_ord {
                    score += ordinal_scores[j].1;
                    j += 1;
                }

                trace!(
                    "Doc {}: ordinal={}, score={:.4}, threshold={:.4}",
                    min_doc, current_ord, score, adjusted_threshold
                );

                if self
                    .collector
                    .insert_with_ordinal(min_doc, score, current_ord)
                {
                    docs_scored += 1;
                } else {
                    docs_skipped += 1;
                }
            }
        }

        let results: Vec<ScoredDoc> = self
            .collector
            .into_sorted_results()
            .into_iter()
            .map(|(doc_id, score, ordinal)| ScoredDoc {
                doc_id,
                score,
                ordinal,
            })
            .collect();

        debug!(
            "BlockMaxScoreExecutor completed: scored={}, skipped={}, blocks_skipped={}, conjunction_skipped={}, returned={}, top_score={:.4}",
            docs_scored,
            docs_skipped,
            blocks_skipped,
            conjunction_skipped,
            results.len(),
            results.first().map(|r| r.score).unwrap_or(0.0)
        );

        results
    }
}

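/// BM25 scorer over a block-compressed text posting list. The global upper
/// bound is derived from the list-wide maximum term frequency, and per-block
/// bounds from each block's maximum term frequency.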
pub struct TextTermScorer {
    iter: crate::structures::BlockPostingIterator<'static>,
    idf: f32,
    avg_field_len: f32,
    max_score: f32,
}

impl TextTermScorer {
    pub fn new(
        posting_list: crate::structures::BlockPostingList,
        idf: f32,
        avg_field_len: f32,
    ) -> Self {
        // Global upper bound from the list-wide maximum term frequency.
        let max_tf = posting_list.max_tf() as f32;
        let doc_count = posting_list.doc_count();
        let max_score = super::bm25_upper_bound(max_tf.max(1.0), idf);

        debug!(
            "Created TextTermScorer: doc_count={}, max_tf={:.0}, idf={:.4}, avg_field_len={:.2}, max_score={:.4}",
            doc_count, max_tf, idf, avg_field_len, max_score
        );

        Self {
            iter: posting_list.into_iterator(),
            idf,
            avg_field_len,
            max_score,
        }
    }
}

impl ScoringIterator for TextTermScorer {
    #[inline]
    fn doc(&self) -> DocId {
        self.iter.doc()
    }

    #[inline]
    fn advance(&mut self) -> DocId {
        self.iter.advance()
    }

    #[inline]
    fn seek(&mut self, target: DocId) -> DocId {
        self.iter.seek(target)
    }

    #[inline]
    fn score(&self) -> f32 {
        let tf = self.iter.term_freq() as f32;
        super::bm25_score(tf, self.idf, tf, self.avg_field_len)
    }

    #[inline]
    fn max_score(&self) -> f32 {
        self.max_score
    }

    #[inline]
    fn current_block_max_score(&self) -> f32 {
        let block_max_tf = self.iter.current_block_max_tf() as f32;
        super::bm25_upper_bound(block_max_tf.max(1.0), self.idf)
    }

    #[inline]
    fn skip_to_next_block(&mut self) -> DocId {
        self.iter.skip_to_next_block()
    }
}

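/// Scorer for one term of a sparse query vector: a document's score is
/// `query_weight * stored_weight`, and the upper bounds are the query weight
/// magnitude times the list's global and per-block maximum weights.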
pub struct SparseTermScorer<'a> {
    iter: crate::structures::BlockSparsePostingIterator<'a>,
    query_weight: f32,
    max_score: f32,
}

impl<'a> SparseTermScorer<'a> {
    pub fn new(posting_list: &'a BlockSparsePostingList, query_weight: f32) -> Self {
        // Upper bound: query weight magnitude times the list's global maximum weight.
        let max_score = query_weight.abs() * posting_list.global_max_weight();
        Self {
            iter: posting_list.iterator(),
            query_weight,
            max_score,
        }
    }

    pub fn from_arc(posting_list: &'a Arc<BlockSparsePostingList>, query_weight: f32) -> Self {
        Self::new(posting_list.as_ref(), query_weight)
    }
}

impl ScoringIterator for SparseTermScorer<'_> {
    #[inline]
    fn doc(&self) -> DocId {
        self.iter.doc()
    }

    #[inline]
    fn ordinal(&self) -> u16 {
        self.iter.ordinal()
    }

    #[inline]
    fn advance(&mut self) -> DocId {
        self.iter.advance()
    }

    #[inline]
    fn seek(&mut self, target: DocId) -> DocId {
        self.iter.seek(target)
    }

    #[inline]
    fn score(&self) -> f32 {
        self.query_weight * self.iter.weight()
    }

    #[inline]
    fn max_score(&self) -> f32 {
        self.max_score
    }

    #[inline]
    fn current_block_max_score(&self) -> f32 {
        self.iter
            .current_block_max_contribution(self.query_weight.abs())
    }

    #[inline]
    fn skip_to_next_block(&mut self) -> DocId {
        self.iter.skip_to_next_block()
    }
}

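/// Block-at-a-time executor for sparse queries: every (term, block) pair is
/// ranked by its maximum possible contribution, blocks are processed in that
/// order while exact partial scores accumulate per `(doc_id, ordinal)`, and
/// before each block the loop checks whether the contributions still queued can
/// beat the collector threshold (scaled by `heap_factor`).
///
/// A usage sketch (assuming `lists: Vec<Arc<BlockSparsePostingList>>` and
/// `weights: Vec<f32>` are aligned, one entry per query term):
///
/// ```ignore
/// let top = BmpExecutor::new(lists, weights, 10, 1.0).execute();
/// ```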
pub struct BmpExecutor {
    /// One block-compressed sparse posting list per query term.
    posting_lists: Vec<Arc<BlockSparsePostingList>>,
    /// One query weight per posting list, in the same order.
    query_weights: Vec<f32>,
    /// Number of results to return.
    k: usize,
    /// Scaling factor applied to the collector threshold for early termination.
    heap_factor: f32,
}

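/// One (term, block) candidate in the BMP priority queue, ordered by its
/// maximum possible score contribution so that `BinaryHeap` pops the most
/// promising block first.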
struct BmpBlockEntry {
    contribution: f32,
    term_idx: usize,
    block_idx: usize,
}

impl PartialEq for BmpBlockEntry {
    fn eq(&self, other: &Self) -> bool {
        self.contribution == other.contribution
    }
}

impl Eq for BmpBlockEntry {}

impl Ord for BmpBlockEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        self.contribution
            .partial_cmp(&other.contribution)
            .unwrap_or(Ordering::Equal)
    }
}

impl PartialOrd for BmpBlockEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl BmpExecutor {
    pub fn new(
        posting_lists: Vec<Arc<BlockSparsePostingList>>,
        query_weights: Vec<f32>,
        k: usize,
        heap_factor: f32,
    ) -> Self {
        Self {
            posting_lists,
            query_weights,
            k,
            heap_factor: heap_factor.clamp(0.0, 1.0),
        }
    }

    pub fn execute(self) -> Vec<ScoredDoc> {
        use rustc_hash::FxHashMap;

        if self.posting_lists.is_empty() {
            return Vec::new();
        }

        let num_terms = self.posting_lists.len();

        // Max-heap of (term, block) pairs ordered by maximum possible contribution.
        let mut block_queue: BinaryHeap<BmpBlockEntry> = BinaryHeap::new();

        // Per-term sum of contributions of the blocks still waiting in the queue.
        let mut remaining_max: Vec<f32> = Vec::with_capacity(num_terms);

        for (term_idx, pl) in self.posting_lists.iter().enumerate() {
            let qw = self.query_weights[term_idx].abs();
            let mut term_remaining = 0.0f32;

            for block_idx in 0..pl.num_blocks() {
                if let Some(block_max_weight) = pl.block_max_weight(block_idx) {
                    let contribution = qw * block_max_weight;
                    term_remaining += contribution;
                    block_queue.push(BmpBlockEntry {
                        contribution,
                        term_idx,
                        block_idx,
                    });
                }
            }
            remaining_max.push(term_remaining);
        }

        // Exact partial scores keyed by (doc_id, ordinal).
        let mut accumulators: FxHashMap<(DocId, u16), f32> = FxHashMap::default();
        let mut collector = ScoreCollector::new(self.k);
        let mut blocks_processed = 0u64;

        // Drain (term, block) pairs in order of decreasing maximum contribution.
        while let Some(entry) = block_queue.pop() {
            remaining_max[entry.term_idx] -= entry.contribution;

            let total_remaining: f32 = remaining_max.iter().sum();
            let adjusted_threshold = collector.threshold() * self.heap_factor;
            if collector.len() >= self.k && total_remaining <= adjusted_threshold {
                debug!(
                    "BMP early termination after {} blocks: remaining={:.4} <= threshold={:.4}",
                    blocks_processed, total_remaining, adjusted_threshold
                );
                break;
            }

            // Add this block's exact contributions to the per-(doc, ordinal) accumulators.
            let pl = &self.posting_lists[entry.term_idx];
            let block = &pl.blocks[entry.block_idx];
            let doc_ids = block.decode_doc_ids();
            let weights = block.decode_weights();
            let ordinals = block.decode_ordinals();
            let qw = self.query_weights[entry.term_idx];

            for i in 0..block.header.count as usize {
                let score_contribution = qw * weights[i];
                *accumulators
                    .entry((doc_ids[i], ordinals[i]))
                    .or_insert(0.0) += score_contribution;
            }

            blocks_processed += 1;
        }

        // Offer every accumulated (doc, ordinal) score to the top-k collector.
        for (&(doc_id, ordinal), &score) in &accumulators {
            collector.insert_with_ordinal(doc_id, score, ordinal);
        }

        let results: Vec<ScoredDoc> = collector
            .into_sorted_results()
            .into_iter()
            .map(|(doc_id, score, ordinal)| ScoredDoc {
                doc_id,
                score,
                ordinal,
            })
            .collect();

        debug!(
            "BmpExecutor completed: blocks_processed={}, accumulators={}, returned={}, top_score={:.4}",
            blocks_processed,
            accumulators.len(),
            results.len(),
            results.first().map(|r| r.score).unwrap_or(0.0)
        );

        results
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_score_collector_basic() {
        let mut collector = ScoreCollector::new(3);

        collector.insert(1, 1.0);
        collector.insert(2, 2.0);
        collector.insert(3, 3.0);
        assert_eq!(collector.threshold(), 1.0);

        collector.insert(4, 4.0);
        assert_eq!(collector.threshold(), 2.0);

        let results = collector.into_sorted_results();
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].0, 4);
        assert_eq!(results[1].0, 3);
        assert_eq!(results[2].0, 2);
    }

    #[test]
    fn test_score_collector_threshold() {
        let mut collector = ScoreCollector::new(2);

        collector.insert(1, 5.0);
        collector.insert(2, 3.0);
        assert_eq!(collector.threshold(), 3.0);

        assert!(!collector.would_enter(2.0));
        assert!(!collector.insert(3, 2.0));

        assert!(collector.would_enter(4.0));
        assert!(collector.insert(4, 4.0));
        assert_eq!(collector.threshold(), 4.0);
    }

    #[test]
    fn test_heap_entry_ordering() {
        let mut heap = BinaryHeap::new();
        heap.push(HeapEntry {
            doc_id: 1,
            score: 3.0,
            ordinal: 0,
        });
        heap.push(HeapEntry {
            doc_id: 2,
            score: 1.0,
            ordinal: 0,
        });
        heap.push(HeapEntry {
            doc_id: 3,
            score: 2.0,
            ordinal: 0,
        });

        assert_eq!(heap.pop().unwrap().score, 1.0);
        assert_eq!(heap.pop().unwrap().score, 2.0);
        assert_eq!(heap.pop().unwrap().score, 3.0);
    }
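
    // Additional check that follows directly from the collector code above: when
    // the heap is full, a better score evicts the current weakest entry, and
    // results come back sorted by score descending, then doc_id ascending.
    #[test]
    fn test_score_collector_eviction_order() {
        let mut collector = ScoreCollector::new(2);

        collector.insert(5, 1.0);
        collector.insert(3, 2.0);
        assert_eq!(collector.threshold(), 1.0);

        // 1.5 beats the weakest retained score (1.0), so doc 5 is evicted.
        assert!(collector.insert(7, 1.5));
        assert_eq!(collector.threshold(), 1.5);

        let results = collector.into_sorted_results();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, 3);
        assert_eq!(results[1].0, 7);
    }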
}