use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::sync::Arc;

use log::{debug, trace};

use crate::DocId;
use crate::structures::BlockSparsePostingList;

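/// A cursor over a posting list that can score the document it is positioned
/// on and report upper bounds on future scores. This is the abstraction the
/// dynamic-pruning executors below (WAND-style, BMP, MaxScore) are written
/// against: `max_score()` bounds the whole list, `current_block_max_score()`
/// bounds the current block, and an exhausted iterator reports `u32::MAX`
/// as its current doc.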
pub trait ScoringIterator {
    fn doc(&self) -> DocId;

    fn ordinal(&self) -> u16 {
        0
    }

    fn advance(&mut self) -> DocId;

    fn seek(&mut self, target: DocId) -> DocId;

    fn is_exhausted(&self) -> bool {
        self.doc() == u32::MAX
    }

    fn score(&self) -> f32;

    fn max_score(&self) -> f32;

    fn current_block_max_score(&self) -> f32;

    fn skip_to_next_block(&mut self) -> DocId {
        self.advance()
    }
}

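/// Entry stored in the top-k heap: a document id, its accumulated score, and
/// the ordinal carried over from the scorer that first matched it.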
#[derive(Clone, Copy)]
pub struct HeapEntry {
    pub doc_id: DocId,
    pub score: f32,
    pub ordinal: u16,
}

impl PartialEq for HeapEntry {
    fn eq(&self, other: &Self) -> bool {
        self.score == other.score && self.doc_id == other.doc_id
    }
}

impl Eq for HeapEntry {}

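// The score comparison is intentionally reversed so that `BinaryHeap`, a
// max-heap, behaves as a min-heap on score: `peek()` returns the weakest
// retained entry, which is what `ScoreCollector::threshold()` relies on.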
impl Ord for HeapEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        other
            .score
            .partial_cmp(&self.score)
            .unwrap_or(Ordering::Equal)
            .then_with(|| self.doc_id.cmp(&other.doc_id))
    }
}

impl PartialOrd for HeapEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

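/// Bounded top-k collector backed by a min-heap on score. Once `k` entries
/// are present, `threshold()` reports the weakest score in the heap and new
/// candidates must strictly beat it to enter.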
pub struct ScoreCollector {
    heap: BinaryHeap<HeapEntry>,
    pub k: usize,
}

impl ScoreCollector {
    pub fn new(k: usize) -> Self {
        let capacity = k.saturating_add(1).min(1_000_000);
        Self {
            heap: BinaryHeap::with_capacity(capacity),
            k,
        }
    }

    #[inline]
    pub fn threshold(&self) -> f32 {
        if self.heap.len() >= self.k {
            self.heap.peek().map(|e| e.score).unwrap_or(0.0)
        } else {
            0.0
        }
    }

    #[inline]
    pub fn insert(&mut self, doc_id: DocId, score: f32) -> bool {
        self.insert_with_ordinal(doc_id, score, 0)
    }

    #[inline]
    pub fn insert_with_ordinal(&mut self, doc_id: DocId, score: f32, ordinal: u16) -> bool {
        if self.heap.len() < self.k {
            self.heap.push(HeapEntry {
                doc_id,
                score,
                ordinal,
            });
            true
        } else if score > self.threshold() {
            self.heap.push(HeapEntry {
                doc_id,
                score,
                ordinal,
            });
            self.heap.pop();
            true
        } else {
            false
        }
    }

    #[inline]
    pub fn would_enter(&self, score: f32) -> bool {
        self.heap.len() < self.k || score > self.threshold()
    }

    #[inline]
    pub fn len(&self) -> usize {
        self.heap.len()
    }

    #[inline]
    pub fn is_empty(&self) -> bool {
        self.heap.is_empty()
    }

    pub fn into_sorted_results(self) -> Vec<(DocId, f32, u16)> {
        let mut results: Vec<_> = self
            .heap
            .into_vec()
            .into_iter()
            .map(|e| (e.doc_id, e.score, e.ordinal))
            .collect();

        results.sort_by(|a, b| {
            b.1.partial_cmp(&a.1)
                .unwrap_or(Ordering::Equal)
                .then_with(|| a.0.cmp(&b.0))
        });

        results
    }
}

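/// A single search result: document id, aggregated score, and ordinal.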
#[derive(Debug, Clone, Copy)]
pub struct ScoredDoc {
    pub doc_id: DocId,
    pub score: f32,
    pub ordinal: u16,
}

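/// Top-k executor over a set of `ScoringIterator`s using WAND-style dynamic
/// pruning with block-level upper bounds. `heap_factor` scales the heap
/// threshold used for pruning: 1.0 keeps the pruning bound intact, while
/// values below 1.0 prune more aggressively at the cost of possibly missing
/// some true top-k documents.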
pub struct WandExecutor<S: ScoringIterator> {
    scorers: Vec<S>,
    collector: ScoreCollector,
    heap_factor: f32,
}

impl<S: ScoringIterator> WandExecutor<S> {
    pub fn new(scorers: Vec<S>, k: usize) -> Self {
        Self::with_heap_factor(scorers, k, 1.0)
    }

    pub fn with_heap_factor(scorers: Vec<S>, k: usize, heap_factor: f32) -> Self {
        let total_upper: f32 = scorers.iter().map(|s| s.max_score()).sum();

        debug!(
            "Creating WandExecutor: num_scorers={}, k={}, total_upper={:.4}, heap_factor={:.2}",
            scorers.len(),
            k,
            total_upper,
            heap_factor
        );

        Self {
            scorers,
            collector: ScoreCollector::new(k),
            heap_factor: heap_factor.clamp(0.0, 1.0),
        }
    }

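    /// Runs the WAND loop to completion and returns the top-k results in
    /// descending score order. On each iteration the scorers are kept sorted
    /// by their current doc id, a pivot is chosen as the first position where
    /// the cumulative `max_score()` exceeds the pruning threshold, and the
    /// pivot document is either fully scored, skipped block-by-block when the
    /// block-level bound cannot beat the threshold, or used as a seek target
    /// for the lagging scorers.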
    pub fn execute(mut self) -> Vec<ScoredDoc> {
        if self.scorers.is_empty() {
            debug!("WandExecutor: no scorers, returning empty results");
            return Vec::new();
        }

        let mut docs_scored = 0u64;
        let mut docs_skipped = 0u64;
        let mut blocks_skipped = 0u64;
        let num_scorers = self.scorers.len();

        let mut sorted_indices: Vec<usize> = (0..num_scorers).collect();
        sorted_indices.sort_by_key(|&i| self.scorers[i].doc());

        loop {
            let first_active = sorted_indices
                .iter()
                .position(|&i| self.scorers[i].doc() != u32::MAX);

            let first_active = match first_active {
                Some(pos) => pos,
                None => break,
            };

            let total_upper: f32 = sorted_indices[first_active..]
                .iter()
                .map(|&i| self.scorers[i].max_score())
                .sum();

            let adjusted_threshold = self.collector.threshold() * self.heap_factor;
            if self.collector.len() >= self.collector.k && total_upper <= adjusted_threshold {
                debug!(
                    "Early termination: upper_bound={:.4} <= adjusted_threshold={:.4}",
                    total_upper, adjusted_threshold
                );
                break;
            }

            let mut cumsum = 0.0f32;
            let mut pivot_pos = first_active;

            for (pos, &idx) in sorted_indices.iter().enumerate().skip(first_active) {
                cumsum += self.scorers[idx].max_score();
                if cumsum > adjusted_threshold || self.collector.len() < self.collector.k {
                    pivot_pos = pos;
                    break;
                }
            }

            let pivot_idx = sorted_indices[pivot_pos];
            let pivot_doc = self.scorers[pivot_idx].doc();

            if pivot_doc == u32::MAX {
                break;
            }

            let all_at_pivot = sorted_indices[first_active..=pivot_pos]
                .iter()
                .all(|&i| self.scorers[i].doc() == pivot_doc);

            if all_at_pivot {
                let block_max_sum: f32 = sorted_indices[first_active..=pivot_pos]
                    .iter()
                    .filter(|&&i| self.scorers[i].doc() == pivot_doc)
                    .map(|&i| self.scorers[i].current_block_max_score())
                    .sum();

                if self.collector.len() >= self.collector.k && block_max_sum <= adjusted_threshold {
                    debug!(
                        "Block skip at doc {}: block_max={:.4} <= threshold={:.4}",
                        pivot_doc, block_max_sum, adjusted_threshold
                    );

                    for (_pos, &idx) in sorted_indices.iter().enumerate().skip(first_active) {
                        if self.scorers[idx].doc() == pivot_doc {
                            self.scorers[idx].skip_to_next_block();
                        } else if self.scorers[idx].doc() > pivot_doc {
                            break;
                        }
                    }

                    sorted_indices[first_active..].sort_by_key(|&i| self.scorers[i].doc());
                    blocks_skipped += 1;
                    continue;
                }

                let mut score = 0.0f32;
                let mut matching_terms = 0u32;
                let mut ordinal: u16 = 0;

                let mut modified_positions: Vec<usize> = Vec::new();

                for (pos, &idx) in sorted_indices.iter().enumerate().skip(first_active) {
                    let doc = self.scorers[idx].doc();
                    if doc == pivot_doc {
                        score += self.scorers[idx].score();
                        if matching_terms == 0 {
                            ordinal = self.scorers[idx].ordinal();
                        }
                        matching_terms += 1;
                        self.scorers[idx].advance();
                        modified_positions.push(pos);
                    } else if doc > pivot_doc {
                        break;
                    }
                }

                trace!(
                    "Doc {}: score={:.4}, matching={}/{}, threshold={:.4}",
                    pivot_doc, score, matching_terms, num_scorers, adjusted_threshold
                );

                if self
                    .collector
                    .insert_with_ordinal(pivot_doc, score, ordinal)
                {
                    docs_scored += 1;
                } else {
                    docs_skipped += 1;
                }

                for &pos in modified_positions.iter().rev() {
                    let idx = sorted_indices[pos];
                    let new_doc = self.scorers[idx].doc();
                    let mut curr = pos;
                    while curr + 1 < sorted_indices.len()
                        && self.scorers[sorted_indices[curr + 1]].doc() < new_doc
                    {
                        sorted_indices.swap(curr, curr + 1);
                        curr += 1;
                    }
                }
            } else {
                let first_pos = first_active;
                let first_idx = sorted_indices[first_pos];
                self.scorers[first_idx].seek(pivot_doc);
                docs_skipped += 1;

                let new_doc = self.scorers[first_idx].doc();
                let mut curr = first_pos;
                while curr + 1 < sorted_indices.len()
                    && self.scorers[sorted_indices[curr + 1]].doc() < new_doc
                {
                    sorted_indices.swap(curr, curr + 1);
                    curr += 1;
                }
            }
        }

        let results: Vec<ScoredDoc> = self
            .collector
            .into_sorted_results()
            .into_iter()
            .map(|(doc_id, score, ordinal)| ScoredDoc {
                doc_id,
                score,
                ordinal,
            })
            .collect();

        debug!(
            "WandExecutor completed: scored={}, skipped={}, blocks_skipped={}, returned={}, top_score={:.4}",
            docs_scored,
            docs_skipped,
            blocks_skipped,
            results.len(),
            results.first().map(|r| r.score).unwrap_or(0.0)
        );

        results
    }
}

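/// BM25 scorer for a single text term, backed by a block posting list. The
/// global upper bound is derived from the maximum term frequency in the list,
/// and per-block bounds from each block's maximum term frequency.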
pub struct TextTermScorer {
    iter: crate::structures::BlockPostingIterator<'static>,
    idf: f32,
    avg_field_len: f32,
    max_score: f32,
}

impl TextTermScorer {
    pub fn new(
        posting_list: crate::structures::BlockPostingList,
        idf: f32,
        avg_field_len: f32,
    ) -> Self {
        let max_tf = posting_list.max_tf() as f32;
        let doc_count = posting_list.doc_count();
        let max_score = super::bm25_upper_bound(max_tf.max(1.0), idf);

        debug!(
            "Created TextTermScorer: doc_count={}, max_tf={:.0}, idf={:.4}, avg_field_len={:.2}, max_score={:.4}",
            doc_count, max_tf, idf, avg_field_len, max_score
        );

        Self {
            iter: posting_list.into_iterator(),
            idf,
            avg_field_len,
            max_score,
        }
    }
}

impl ScoringIterator for TextTermScorer {
    #[inline]
    fn doc(&self) -> DocId {
        self.iter.doc()
    }

    #[inline]
    fn advance(&mut self) -> DocId {
        self.iter.advance()
    }

    #[inline]
    fn seek(&mut self, target: DocId) -> DocId {
        self.iter.seek(target)
    }

    #[inline]
    fn score(&self) -> f32 {
        let tf = self.iter.term_freq() as f32;
        super::bm25_score(tf, self.idf, tf, self.avg_field_len)
    }

    #[inline]
    fn max_score(&self) -> f32 {
        self.max_score
    }

    #[inline]
    fn current_block_max_score(&self) -> f32 {
        let block_max_tf = self.iter.current_block_max_tf() as f32;
        super::bm25_upper_bound(block_max_tf.max(1.0), self.idf)
    }

    #[inline]
    fn skip_to_next_block(&mut self) -> DocId {
        self.iter.skip_to_next_block()
    }
}

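/// Scorer for a single sparse-vector query term: a document's contribution is
/// `query_weight * stored_weight`, with `|query_weight| * global_max_weight`
/// as the list-wide upper bound.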
pub struct SparseTermScorer<'a> {
    iter: crate::structures::BlockSparsePostingIterator<'a>,
    query_weight: f32,
    max_score: f32,
}

impl<'a> SparseTermScorer<'a> {
    pub fn new(posting_list: &'a BlockSparsePostingList, query_weight: f32) -> Self {
        let max_score = query_weight.abs() * posting_list.global_max_weight();
        Self {
            iter: posting_list.iterator(),
            query_weight,
            max_score,
        }
    }

    pub fn from_arc(posting_list: &'a Arc<BlockSparsePostingList>, query_weight: f32) -> Self {
        Self::new(posting_list.as_ref(), query_weight)
    }
}

impl ScoringIterator for SparseTermScorer<'_> {
    #[inline]
    fn doc(&self) -> DocId {
        self.iter.doc()
    }

    #[inline]
    fn ordinal(&self) -> u16 {
        self.iter.ordinal()
    }

    #[inline]
    fn advance(&mut self) -> DocId {
        self.iter.advance()
    }

    #[inline]
    fn seek(&mut self, target: DocId) -> DocId {
        self.iter.seek(target)
    }

    #[inline]
    fn score(&self) -> f32 {
        self.query_weight * self.iter.weight()
    }

    #[inline]
    fn max_score(&self) -> f32 {
        self.max_score
    }

    #[inline]
    fn current_block_max_score(&self) -> f32 {
        self.iter
            .current_block_max_contribution(self.query_weight.abs())
    }

    #[inline]
    fn skip_to_next_block(&mut self) -> DocId {
        self.iter.skip_to_next_block()
    }
}

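/// Block-at-a-time top-k executor over sparse posting lists. Blocks from all
/// terms are ranked by their maximum possible contribution and processed in
/// that order, accumulating scores per document; execution stops early once
/// the sum of the remaining block bounds can no longer beat the (scaled)
/// heap threshold.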
pub struct BmpExecutor {
    posting_lists: Vec<Arc<BlockSparsePostingList>>,
    query_weights: Vec<f32>,
    k: usize,
    heap_factor: f32,
}

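/// Max-heap entry used by `BmpExecutor` to pop the block with the highest
/// upper-bound contribution first.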
struct BmpBlockEntry {
    contribution: f32,
    term_idx: usize,
    block_idx: usize,
}

impl PartialEq for BmpBlockEntry {
    fn eq(&self, other: &Self) -> bool {
        self.contribution == other.contribution
    }
}

impl Eq for BmpBlockEntry {}

impl Ord for BmpBlockEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        self.contribution
            .partial_cmp(&other.contribution)
            .unwrap_or(Ordering::Equal)
    }
}

impl PartialOrd for BmpBlockEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl BmpExecutor {
    pub fn new(
        posting_lists: Vec<Arc<BlockSparsePostingList>>,
        query_weights: Vec<f32>,
        k: usize,
        heap_factor: f32,
    ) -> Self {
        Self {
            posting_lists,
            query_weights,
            k,
            heap_factor: heap_factor.clamp(0.0, 1.0),
        }
    }

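    /// Processes blocks in descending order of upper-bound contribution,
    /// accumulating per-document scores in a hash map, then drains the
    /// accumulators into a `ScoreCollector` and returns the top-k results.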
    pub fn execute(self) -> Vec<ScoredDoc> {
        use rustc_hash::FxHashMap;

        if self.posting_lists.is_empty() {
            return Vec::new();
        }

        let num_terms = self.posting_lists.len();

        let mut block_queue: BinaryHeap<BmpBlockEntry> = BinaryHeap::new();

        let mut remaining_max: Vec<f32> = Vec::with_capacity(num_terms);

        for (term_idx, pl) in self.posting_lists.iter().enumerate() {
            let qw = self.query_weights[term_idx].abs();
            let mut term_remaining = 0.0f32;

            for block_idx in 0..pl.num_blocks() {
                if let Some(block_max_weight) = pl.block_max_weight(block_idx) {
                    let contribution = qw * block_max_weight;
                    term_remaining += contribution;
                    block_queue.push(BmpBlockEntry {
                        contribution,
                        term_idx,
                        block_idx,
                    });
                }
            }
            remaining_max.push(term_remaining);
        }

        let mut accumulators: FxHashMap<DocId, (f32, u16)> = FxHashMap::default();
        let mut collector = ScoreCollector::new(self.k);
        let mut blocks_processed = 0u64;

        while let Some(entry) = block_queue.pop() {
            remaining_max[entry.term_idx] -= entry.contribution;

            let total_remaining: f32 = remaining_max.iter().sum();
            let adjusted_threshold = collector.threshold() * self.heap_factor;
            if collector.len() >= self.k && total_remaining <= adjusted_threshold {
                debug!(
                    "BMP early termination after {} blocks: remaining={:.4} <= threshold={:.4}",
                    blocks_processed, total_remaining, adjusted_threshold
                );
                break;
            }

            let pl = &self.posting_lists[entry.term_idx];
            let block = &pl.blocks[entry.block_idx];
            let doc_ids = block.decode_doc_ids();
            let weights = block.decode_weights();
            let ordinals = block.decode_ordinals();
            let qw = self.query_weights[entry.term_idx];

            for i in 0..block.header.count as usize {
                let score_contribution = qw * weights[i];
                let acc = accumulators.entry(doc_ids[i]).or_insert((0.0, ordinals[i]));
                acc.0 += score_contribution;
            }

            blocks_processed += 1;
        }

        for (doc_id, (score, ordinal)) in &accumulators {
            collector.insert_with_ordinal(*doc_id, *score, *ordinal);
        }

        let results: Vec<ScoredDoc> = collector
            .into_sorted_results()
            .into_iter()
            .map(|(doc_id, score, ordinal)| ScoredDoc {
                doc_id,
                score,
                ordinal,
            })
            .collect();

        debug!(
            "BmpExecutor completed: blocks_processed={}, accumulators={}, returned={}, top_score={:.4}",
            blocks_processed,
            accumulators.len(),
            results.len(),
            results.first().map(|r| r.score).unwrap_or(0.0)
        );

        results
    }
}

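/// Top-k executor implementing the MaxScore strategy: scorers are sorted by
/// ascending `max_score()` and split into a non-essential prefix and an
/// essential suffix. Only essential scorers drive iteration; non-essential
/// ones are probed with `seek()` only when a candidate could still make the
/// top-k.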
pub struct MaxScoreExecutor<S: ScoringIterator> {
    scorers: Vec<S>,
    prefix_sums: Vec<f32>,
    collector: ScoreCollector,
    heap_factor: f32,
}

impl<S: ScoringIterator> MaxScoreExecutor<S> {
    pub fn new(scorers: Vec<S>, k: usize) -> Self {
        Self::with_heap_factor(scorers, k, 1.0)
    }

    pub fn with_heap_factor(mut scorers: Vec<S>, k: usize, heap_factor: f32) -> Self {
        scorers.sort_by(|a, b| {
            a.max_score()
                .partial_cmp(&b.max_score())
                .unwrap_or(Ordering::Equal)
        });

        let mut prefix_sums = Vec::with_capacity(scorers.len());
        let mut cumsum = 0.0f32;
        for s in &scorers {
            cumsum += s.max_score();
            prefix_sums.push(cumsum);
        }

        debug!(
            "Creating MaxScoreExecutor: num_scorers={}, k={}, total_upper={:.4}, heap_factor={:.2}",
            scorers.len(),
            k,
            cumsum,
            heap_factor
        );

        Self {
            scorers,
            prefix_sums,
            collector: ScoreCollector::new(k),
            heap_factor: heap_factor.clamp(0.0, 1.0),
        }
    }

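    /// Returns the index of the first essential scorer: the smallest position
    /// whose prefix sum of upper bounds exceeds the (scaled) heap threshold.
    /// Everything before it cannot, on its own, lift a document into the
    /// top-k.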
    fn find_partition(&self) -> usize {
        let threshold = self.collector.threshold() * self.heap_factor;
        self.prefix_sums
            .iter()
            .position(|&sum| sum > threshold)
            .unwrap_or(self.scorers.len())
    }

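    /// Runs the MaxScore loop: pick the smallest current doc among the
    /// essential scorers, score it with the essential terms, and only seek
    /// the non-essential terms to it if the partial score plus their combined
    /// upper bound can still beat the threshold.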
    pub fn execute(mut self) -> Vec<ScoredDoc> {
        if self.scorers.is_empty() {
            return Vec::new();
        }

        let n = self.scorers.len();
        let mut docs_scored = 0u64;
        let mut docs_skipped = 0u64;

        loop {
            let partition = self.find_partition();

            if partition >= n {
                debug!("MaxScore: all terms non-essential, early termination");
                break;
            }

            let mut min_doc = u32::MAX;
            for i in partition..n {
                let doc = self.scorers[i].doc();
                if doc < min_doc {
                    min_doc = doc;
                }
            }

            if min_doc == u32::MAX {
                break;
            }

            let mut score = 0.0f32;
            let mut ordinal = 0u16;
            let mut first_match = true;

            for i in partition..n {
                if self.scorers[i].doc() == min_doc {
                    score += self.scorers[i].score();
                    if first_match {
                        ordinal = self.scorers[i].ordinal();
                        first_match = false;
                    }
                    self.scorers[i].advance();
                }
            }

            let non_essential_upper = if partition > 0 {
                self.prefix_sums[partition - 1]
            } else {
                0.0
            };

            let adjusted_threshold = self.collector.threshold() * self.heap_factor;

            if self.collector.len() >= self.collector.k
                && score + non_essential_upper <= adjusted_threshold
            {
                docs_skipped += 1;
                continue;
            }

            for i in 0..partition {
                let doc = self.scorers[i].seek(min_doc);
                if doc == min_doc {
                    score += self.scorers[i].score();
                    self.scorers[i].advance();
                }
            }

            self.collector.insert_with_ordinal(min_doc, score, ordinal);
            docs_scored += 1;
        }

        let results: Vec<ScoredDoc> = self
            .collector
            .into_sorted_results()
            .into_iter()
            .map(|(doc_id, score, ordinal)| ScoredDoc {
                doc_id,
                score,
                ordinal,
            })
            .collect();

        debug!(
            "MaxScoreExecutor completed: scored={}, skipped={}, returned={}, top_score={:.4}",
            docs_scored,
            docs_skipped,
            results.len(),
            results.first().map(|r| r.score).unwrap_or(0.0)
        );

        results
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_score_collector_basic() {
        let mut collector = ScoreCollector::new(3);

        collector.insert(1, 1.0);
        collector.insert(2, 2.0);
        collector.insert(3, 3.0);
        assert_eq!(collector.threshold(), 1.0);

        collector.insert(4, 4.0);
        assert_eq!(collector.threshold(), 2.0);

        let results = collector.into_sorted_results();
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].0, 4);
        assert_eq!(results[1].0, 3);
        assert_eq!(results[2].0, 2);
    }

    #[test]
    fn test_score_collector_threshold() {
        let mut collector = ScoreCollector::new(2);

        collector.insert(1, 5.0);
        collector.insert(2, 3.0);
        assert_eq!(collector.threshold(), 3.0);

        assert!(!collector.would_enter(2.0));
        assert!(!collector.insert(3, 2.0));

        assert!(collector.would_enter(4.0));
        assert!(collector.insert(4, 4.0));
        assert_eq!(collector.threshold(), 4.0);
    }

    #[test]
    fn test_heap_entry_ordering() {
        let mut heap = BinaryHeap::new();
        heap.push(HeapEntry {
            doc_id: 1,
            score: 3.0,
            ordinal: 0,
        });
        heap.push(HeapEntry {
            doc_id: 2,
            score: 1.0,
            ordinal: 0,
        });
        heap.push(HeapEntry {
            doc_id: 3,
            score: 2.0,
            ordinal: 0,
        });

        assert_eq!(heap.pop().unwrap().score, 1.0);
        assert_eq!(heap.pop().unwrap().score, 2.0);
        assert_eq!(heap.pop().unwrap().score, 3.0);
    }
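
    // A minimal in-memory ScoringIterator used to exercise WandExecutor
    // end-to-end. This mock (here called `VecScorer`) is illustrative only
    // and is not one of the crate's posting-list structures; it reports its
    // single largest posting weight as both the global and per-block bound.
    struct VecScorer {
        postings: Vec<(DocId, f32)>,
        pos: usize,
        max: f32,
    }

    impl VecScorer {
        fn new(postings: Vec<(DocId, f32)>) -> Self {
            let max = postings.iter().map(|p| p.1).fold(0.0f32, f32::max);
            Self { postings, pos: 0, max }
        }
    }

    impl ScoringIterator for VecScorer {
        fn doc(&self) -> DocId {
            self.postings.get(self.pos).map(|p| p.0).unwrap_or(u32::MAX)
        }

        fn advance(&mut self) -> DocId {
            self.pos += 1;
            self.doc()
        }

        fn seek(&mut self, target: DocId) -> DocId {
            while self.doc() < target {
                self.advance();
            }
            self.doc()
        }

        fn score(&self) -> f32 {
            self.postings[self.pos].1
        }

        fn max_score(&self) -> f32 {
            self.max
        }

        fn current_block_max_score(&self) -> f32 {
            self.max
        }
    }

    #[test]
    fn test_wand_executor_with_mock_scorers() {
        // Term A matches docs 1, 3, 5; term B matches docs 3, 4.
        // Doc 3 is matched by both terms and should rank first (2.0 + 1.5).
        let scorers = vec![
            VecScorer::new(vec![(1, 1.0), (3, 2.0), (5, 0.5)]),
            VecScorer::new(vec![(3, 1.5), (4, 1.0)]),
        ];

        let results = WandExecutor::new(scorers, 2).execute();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].doc_id, 3);
        assert!((results[0].score - 3.5).abs() < 1e-6);
        assert_eq!(results[1].doc_id, 1);
    }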
}