1use std::cmp::Ordering;
11use std::collections::BinaryHeap;
12use std::sync::Arc;
13
14use log::{debug, trace};
15
16use crate::DocId;
17use crate::structures::BlockSparsePostingList;
18
/// Cursor over scored postings, as consumed by the executors in this file.
///
/// An implementation exposes the document it is currently positioned on, the
/// score at that position, and two upper bounds (`max_score` and
/// `current_block_max_score`) that callers compare against a threshold in
/// order to skip work.
pub trait ScoringIterator {
    /// Returns the document id at the current position.
    ///
    /// The default [`is_exhausted`](Self::is_exhausted) treats `u32::MAX` as
    /// the end-of-list marker.
    fn doc(&self) -> DocId;

    /// Returns the ordinal attached to the current posting.
    ///
    /// Defaults to `0` for implementations that do not override it.
    fn ordinal(&self) -> u16 {
        0
    }

    /// Moves the cursor forward by one posting and returns a document id
    /// (presumably the id at the new position — confirm with implementors).
    fn advance(&mut self) -> DocId;

    /// Repositions the cursor relative to `target` and returns a document id.
    ///
    /// NOTE(review): callers in this file compare the return value with
    /// `target` to detect a hit, so this presumably lands on the first
    /// posting whose id is `>= target` — confirm with implementors.
    fn seek(&mut self, target: DocId) -> DocId;

    /// Returns `true` once `doc()` reports `u32::MAX`.
    fn is_exhausted(&self) -> bool {
        self.doc() == u32::MAX
    }

    /// Returns the score of the current posting.
    fn score(&self) -> f32;

    /// Returns an upper bound on `score()` over the whole list.
    fn max_score(&self) -> f32;

    /// Returns an upper bound on `score()` for the block the cursor is in.
    fn current_block_max_score(&self) -> f32;

    /// Skips ahead to the next block and returns a document id.
    ///
    /// The default simply calls `advance()`, i.e. it moves by one posting.
    fn skip_to_next_block(&mut self) -> DocId {
        self.advance()
    }
}
58
59#[derive(Clone, Copy)]
61pub struct HeapEntry {
62 pub doc_id: DocId,
63 pub score: f32,
64 pub ordinal: u16,
65}
66
67impl PartialEq for HeapEntry {
68 fn eq(&self, other: &Self) -> bool {
69 self.score == other.score && self.doc_id == other.doc_id
70 }
71}
72
// Marker impl required by `Ord`. NOTE(review): `eq` compares `f32` scores, so
// a NaN score is not equal to itself — confirm NaN scores cannot occur.
impl Eq for HeapEntry {}
74
75impl Ord for HeapEntry {
76 fn cmp(&self, other: &Self) -> Ordering {
77 other
79 .score
80 .partial_cmp(&self.score)
81 .unwrap_or(Ordering::Equal)
82 .then_with(|| self.doc_id.cmp(&other.doc_id))
83 }
84}
85
/// Delegates to [`Ord::cmp`], so the partial and total orderings always agree.
impl PartialOrd for HeapEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
91
/// Bounded collector that retains the `k` best-scoring entries seen so far.
pub struct ScoreCollector {
    /// Retained entries. `HeapEntry`'s ordering is reversed on score, so the
    /// heap's top element is the lowest-scoring retained entry.
    heap: BinaryHeap<HeapEntry>,
    /// Maximum number of entries to retain.
    pub k: usize,
}
102
103impl ScoreCollector {
104 pub fn new(k: usize) -> Self {
106 let capacity = k.saturating_add(1).min(1_000_000);
108 Self {
109 heap: BinaryHeap::with_capacity(capacity),
110 k,
111 }
112 }
113
114 #[inline]
116 pub fn threshold(&self) -> f32 {
117 if self.heap.len() >= self.k {
118 self.heap.peek().map(|e| e.score).unwrap_or(0.0)
119 } else {
120 0.0
121 }
122 }
123
124 #[inline]
127 pub fn insert(&mut self, doc_id: DocId, score: f32) -> bool {
128 self.insert_with_ordinal(doc_id, score, 0)
129 }
130
131 #[inline]
134 pub fn insert_with_ordinal(&mut self, doc_id: DocId, score: f32, ordinal: u16) -> bool {
135 if self.heap.len() < self.k {
136 self.heap.push(HeapEntry {
137 doc_id,
138 score,
139 ordinal,
140 });
141 true
142 } else if score > self.threshold() {
143 self.heap.push(HeapEntry {
144 doc_id,
145 score,
146 ordinal,
147 });
148 self.heap.pop(); true
150 } else {
151 false
152 }
153 }
154
155 #[inline]
157 pub fn would_enter(&self, score: f32) -> bool {
158 self.heap.len() < self.k || score > self.threshold()
159 }
160
161 #[inline]
163 pub fn len(&self) -> usize {
164 self.heap.len()
165 }
166
167 #[inline]
169 pub fn is_empty(&self) -> bool {
170 self.heap.is_empty()
171 }
172
173 pub fn into_sorted_results(self) -> Vec<(DocId, f32, u16)> {
175 let heap_vec = self.heap.into_vec();
176 let mut results: Vec<(DocId, f32, u16)> = Vec::with_capacity(heap_vec.len());
177 for e in heap_vec {
178 results.push((e.doc_id, e.score, e.ordinal));
179 }
180
181 results.sort_by(|a, b| {
183 b.1.partial_cmp(&a.1)
184 .unwrap_or(Ordering::Equal)
185 .then_with(|| a.0.cmp(&b.0))
186 });
187
188 results
189 }
190}
191
/// A single result returned by the executors in this file.
#[derive(Debug, Clone, Copy)]
pub struct ScoredDoc {
    /// Identifier of the matching document.
    pub doc_id: DocId,
    /// Score assigned to this document/ordinal pair.
    pub score: f32,
    /// Ordinal reported by the scorer(s) that produced the score.
    pub ordinal: u16,
}
200
/// Top-k executor over a set of [`ScoringIterator`]s that prunes using
/// per-list and per-block score upper bounds.
pub struct BlockMaxScoreExecutor<S: ScoringIterator> {
    /// Scorers sorted ascending by `max_score()` (see `with_heap_factor`).
    scorers: Vec<S>,
    /// `prefix_sums[i]` is the sum of `max_score()` over `scorers[0..=i]`.
    prefix_sums: Vec<f32>,
    /// Holds the current top-k entries and supplies the pruning threshold.
    collector: ScoreCollector,
    /// Multiplier applied to the collector threshold before every pruning
    /// comparison; stored clamped to `[0.0, 1.0]`.
    heap_factor: f32,
}
226
/// Alias for [`BlockMaxScoreExecutor`] under the `WandExecutor` name
/// (presumably kept for callers using the older name — confirm).
pub type WandExecutor<S> = BlockMaxScoreExecutor<S>;
229
impl<S: ScoringIterator> BlockMaxScoreExecutor<S> {
    /// Creates an executor that collects the top `k` results with a heap
    /// factor of `1.0`.
    pub fn new(scorers: Vec<S>, k: usize) -> Self {
        Self::with_heap_factor(scorers, k, 1.0)
    }

    /// Creates an executor with an explicit `heap_factor`.
    ///
    /// `scorers` are sorted ascending by `max_score()` and the running sum of
    /// those bounds is stored in `prefix_sums`. The stored factor is clamped
    /// to `[0.0, 1.0]`; note that the debug log below prints the value as
    /// passed in, before clamping.
    pub fn with_heap_factor(mut scorers: Vec<S>, k: usize, heap_factor: f32) -> Self {
        // Ascending by upper bound; incomparable (NaN) bounds count as equal.
        scorers.sort_by(|a, b| {
            a.max_score()
                .partial_cmp(&b.max_score())
                .unwrap_or(Ordering::Equal)
        });

        // prefix_sums[i] = total upper bound of scorers[0..=i].
        let mut prefix_sums = Vec::with_capacity(scorers.len());
        let mut cumsum = 0.0f32;
        for s in &scorers {
            cumsum += s.max_score();
            prefix_sums.push(cumsum);
        }

        debug!(
            "Creating BlockMaxScoreExecutor: num_scorers={}, k={}, total_upper={:.4}, heap_factor={:.2}",
            scorers.len(),
            k,
            cumsum,
            heap_factor
        );

        Self {
            scorers,
            prefix_sums,
            collector: ScoreCollector::new(k),
            heap_factor: heap_factor.clamp(0.0, 1.0),
        }
    }

    /// Returns the index of the first scorer whose prefix sum exceeds the
    /// adjusted threshold, or `scorers.len()` if none does.
    ///
    /// Scorers before the returned index ("non-essential") have a combined
    /// upper bound that does not exceed the adjusted threshold.
    #[inline]
    fn find_partition(&self) -> usize {
        let threshold = self.collector.threshold() * self.heap_factor;
        self.prefix_sums
            .iter()
            .position(|&sum| sum > threshold)
            .unwrap_or(self.scorers.len())
    }

    /// Runs the query to completion and returns the collected results sorted
    /// by descending score (ties by ascending document id).
    ///
    /// Each loop iteration picks the smallest current document among the
    /// essential scorers (`partition..n`), tries three progressively more
    /// expensive pruning checks, and only then scores the document fully.
    /// Scores are grouped by ordinal, so one document may yield several
    /// collector entries.
    pub fn execute(mut self) -> Vec<ScoredDoc> {
        if self.scorers.is_empty() {
            debug!("BlockMaxScoreExecutor: no scorers, returning empty results");
            return Vec::new();
        }

        let n = self.scorers.len();
        // Counters below are only reported in the final debug log.
        let mut docs_scored = 0u64;
        let mut docs_skipped = 0u64;
        let mut blocks_skipped = 0u64;
        let mut conjunction_skipped = 0u64;

        // Scratch buffer of (ordinal, partial score) pairs, reused per document.
        let mut ordinal_scores: Vec<(u16, f32)> = Vec::with_capacity(n * 2);

        loop {
            // The partition is recomputed every iteration because the
            // collector threshold changes as entries are inserted.
            let partition = self.find_partition();

            if partition >= n {
                debug!("BlockMaxScore: all terms non-essential, early termination");
                break;
            }

            // Smallest current document among the essential scorers.
            let mut min_doc = u32::MAX;
            for i in partition..n {
                let doc = self.scorers[i].doc();
                if doc < min_doc {
                    min_doc = doc;
                }
            }

            // Every essential scorer reports the end-of-list marker.
            if min_doc == u32::MAX {
                break;
            }

            // Most the non-essential scorers could add to any document.
            let non_essential_upper = if partition > 0 {
                self.prefix_sums[partition - 1]
            } else {
                0.0
            };
            let adjusted_threshold = self.collector.threshold() * self.heap_factor;

            // Check 1: list-level bounds of the essential scorers positioned
            // on `min_doc`. Only applied once the collector is full.
            if self.collector.len() >= self.collector.k {
                let present_upper: f32 = (partition..n)
                    .filter(|&i| self.scorers[i].doc() == min_doc)
                    .map(|i| self.scorers[i].max_score())
                    .sum();

                if present_upper + non_essential_upper <= adjusted_threshold {
                    // Step the scorers sitting on `min_doc` past it.
                    for i in partition..n {
                        if self.scorers[i].doc() == min_doc {
                            self.scorers[i].advance();
                        }
                    }
                    conjunction_skipped += 1;
                    continue;
                }
            }

            // Check 2: same comparison using the current block's bounds.
            if self.collector.len() >= self.collector.k {
                let block_max_sum: f32 = (partition..n)
                    .filter(|&i| self.scorers[i].doc() == min_doc)
                    .map(|i| self.scorers[i].current_block_max_score())
                    .sum();

                if block_max_sum + non_essential_upper <= adjusted_threshold {
                    // NOTE(review): this moves each matching scorer to its next
                    // block, which presumably skips documents after `min_doc`
                    // in the same block as well — confirm this is intended.
                    for i in partition..n {
                        if self.scorers[i].doc() == min_doc {
                            self.scorers[i].skip_to_next_block();
                        }
                    }
                    blocks_skipped += 1;
                    continue;
                }
            }

            ordinal_scores.clear();

            // Score `min_doc` on every essential scorer positioned on it. The
            // inner loop consumes consecutive postings with the same document
            // id, collecting one (ordinal, score) pair per posting.
            for i in partition..n {
                if self.scorers[i].doc() == min_doc {
                    while self.scorers[i].doc() == min_doc {
                        ordinal_scores.push((self.scorers[i].ordinal(), self.scorers[i].score()));
                        self.scorers[i].advance();
                    }
                }
            }

            // Check 3: actual essential score plus the non-essential bound.
            // The sum is taken across all ordinals of the document.
            let essential_total: f32 = ordinal_scores.iter().map(|(_, s)| *s).sum();

            if self.collector.len() >= self.collector.k
                && essential_total + non_essential_upper <= adjusted_threshold
            {
                docs_skipped += 1;
                continue;
            }

            // The document may still qualify: look it up in the non-essential
            // scorers and add their contributions.
            for i in 0..partition {
                let doc = self.scorers[i].seek(min_doc);
                if doc == min_doc {
                    while self.scorers[i].doc() == min_doc {
                        ordinal_scores.push((self.scorers[i].ordinal(), self.scorers[i].score()));
                        self.scorers[i].advance();
                    }
                }
            }

            // Group the pairs by ordinal and offer one summed score per
            // (document, ordinal) to the collector.
            ordinal_scores.sort_unstable_by_key(|(ord, _)| *ord);
            let mut j = 0;
            while j < ordinal_scores.len() {
                let current_ord = ordinal_scores[j].0;
                let mut score = 0.0f32;
                while j < ordinal_scores.len() && ordinal_scores[j].0 == current_ord {
                    score += ordinal_scores[j].1;
                    j += 1;
                }

                trace!(
                    "Doc {}: ordinal={}, score={:.4}, threshold={:.4}",
                    min_doc, current_ord, score, adjusted_threshold
                );

                if self
                    .collector
                    .insert_with_ordinal(min_doc, score, current_ord)
                {
                    docs_scored += 1;
                } else {
                    docs_skipped += 1;
                }
            }
        }

        let results: Vec<ScoredDoc> = self
            .collector
            .into_sorted_results()
            .into_iter()
            .map(|(doc_id, score, ordinal)| ScoredDoc {
                doc_id,
                score,
                ordinal,
            })
            .collect();

        debug!(
            "BlockMaxScoreExecutor completed: scored={}, skipped={}, blocks_skipped={}, conjunction_skipped={}, returned={}, top_score={:.4}",
            docs_scored,
            docs_skipped,
            blocks_skipped,
            conjunction_skipped,
            results.len(),
            results.first().map(|r| r.score).unwrap_or(0.0)
        );

        results
    }
}
465
/// [`ScoringIterator`] over a text term's block posting list, scored with the
/// BM25 helpers from the parent module.
pub struct TextTermScorer {
    /// Owning iterator over the term's postings.
    iter: crate::structures::BlockPostingIterator<'static>,
    /// Inverse document frequency passed to the BM25 helpers.
    idf: f32,
    /// Average field length passed to `bm25_score`.
    avg_field_len: f32,
    /// List-wide score upper bound, computed once in `new`.
    max_score: f32,
}
480
481impl TextTermScorer {
482 pub fn new(
484 posting_list: crate::structures::BlockPostingList,
485 idf: f32,
486 avg_field_len: f32,
487 ) -> Self {
488 let max_tf = posting_list.max_tf() as f32;
490 let doc_count = posting_list.doc_count();
491 let max_score = super::bm25_upper_bound(max_tf.max(1.0), idf);
492
493 debug!(
494 "Created TextTermScorer: doc_count={}, max_tf={:.0}, idf={:.4}, avg_field_len={:.2}, max_score={:.4}",
495 doc_count, max_tf, idf, avg_field_len, max_score
496 );
497
498 Self {
499 iter: posting_list.into_iterator(),
500 idf,
501 avg_field_len,
502 max_score,
503 }
504 }
505}
506
507impl ScoringIterator for TextTermScorer {
508 #[inline]
509 fn doc(&self) -> DocId {
510 self.iter.doc()
511 }
512
513 #[inline]
514 fn advance(&mut self) -> DocId {
515 self.iter.advance()
516 }
517
518 #[inline]
519 fn seek(&mut self, target: DocId) -> DocId {
520 self.iter.seek(target)
521 }
522
523 #[inline]
524 fn score(&self) -> f32 {
525 let tf = self.iter.term_freq() as f32;
526 super::bm25_score(tf, self.idf, tf, self.avg_field_len)
528 }
529
530 #[inline]
531 fn max_score(&self) -> f32 {
532 self.max_score
533 }
534
535 #[inline]
536 fn current_block_max_score(&self) -> f32 {
537 let block_max_tf = self.iter.current_block_max_tf() as f32;
539 super::bm25_upper_bound(block_max_tf.max(1.0), self.idf)
540 }
541
542 #[inline]
543 fn skip_to_next_block(&mut self) -> DocId {
544 self.iter.skip_to_next_block()
545 }
546}
547
/// [`ScoringIterator`] over one dimension of a sparse vector index, scoring
/// each posting as `query_weight * stored weight`.
pub struct SparseTermScorer<'a> {
    /// Borrowing iterator over the dimension's postings.
    iter: crate::structures::BlockSparsePostingIterator<'a>,
    /// Weight of this dimension in the query.
    query_weight: f32,
    /// List-wide score upper bound, computed once in `new`.
    max_score: f32,
}
559
560impl<'a> SparseTermScorer<'a> {
561 pub fn new(posting_list: &'a BlockSparsePostingList, query_weight: f32) -> Self {
566 let max_score = query_weight.abs() * posting_list.global_max_weight();
569 Self {
570 iter: posting_list.iterator(),
571 query_weight,
572 max_score,
573 }
574 }
575
576 pub fn from_arc(posting_list: &'a Arc<BlockSparsePostingList>, query_weight: f32) -> Self {
578 Self::new(posting_list.as_ref(), query_weight)
579 }
580}
581
582impl ScoringIterator for SparseTermScorer<'_> {
583 #[inline]
584 fn doc(&self) -> DocId {
585 self.iter.doc()
586 }
587
588 #[inline]
589 fn ordinal(&self) -> u16 {
590 self.iter.ordinal()
591 }
592
593 #[inline]
594 fn advance(&mut self) -> DocId {
595 self.iter.advance()
596 }
597
598 #[inline]
599 fn seek(&mut self, target: DocId) -> DocId {
600 self.iter.seek(target)
601 }
602
603 #[inline]
604 fn score(&self) -> f32 {
605 self.query_weight * self.iter.weight()
607 }
608
609 #[inline]
610 fn max_score(&self) -> f32 {
611 self.max_score
612 }
613
614 #[inline]
615 fn current_block_max_score(&self) -> f32 {
616 self.iter
618 .current_block_max_contribution(self.query_weight.abs())
619 }
620
621 #[inline]
622 fn skip_to_next_block(&mut self) -> DocId {
623 self.iter.skip_to_next_block()
624 }
625}
626
/// Top-k executor over a sparse index that visits blocks in descending order
/// of their maximum possible contribution.
pub struct BmpExecutor<'a> {
    /// Index providing per-dimension skip lists and blocks.
    sparse_index: &'a crate::segment::SparseIndex,
    /// Query as `(dimension id, query weight)` pairs.
    query_terms: Vec<(u32, f32)>,
    /// Number of results to return.
    k: usize,
    /// Multiplier applied to the collector threshold before the early
    /// termination comparison; stored clamped to `[0.0, 1.0]`.
    heap_factor: f32,
}
648
/// Priority-queue entry identifying one block of one query term.
struct BmpBlockEntry {
    /// Upper bound on what this block can add to a score:
    /// `|query weight| * block max weight`.
    contribution: f32,
    /// Index into the executor's `query_terms`.
    term_idx: usize,
    /// Index of the block within the term's skip list.
    block_idx: usize,
}
658
659impl PartialEq for BmpBlockEntry {
660 fn eq(&self, other: &Self) -> bool {
661 self.contribution == other.contribution
662 }
663}
664
// Marker impl required by `Ord`. NOTE(review): equality is based on an `f32`,
// so a NaN contribution is not equal to itself — confirm NaN cannot occur.
impl Eq for BmpBlockEntry {}
666
667impl Ord for BmpBlockEntry {
668 fn cmp(&self, other: &Self) -> Ordering {
669 self.contribution
671 .partial_cmp(&other.contribution)
672 .unwrap_or(Ordering::Equal)
673 }
674}
675
/// Delegates to [`Ord::cmp`], so the partial and total orderings always agree.
impl PartialOrd for BmpBlockEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
681
682impl<'a> BmpExecutor<'a> {
683 pub fn new(
688 sparse_index: &'a crate::segment::SparseIndex,
689 query_terms: Vec<(u32, f32)>,
690 k: usize,
691 heap_factor: f32,
692 ) -> Self {
693 Self {
694 sparse_index,
695 query_terms,
696 k,
697 heap_factor: heap_factor.clamp(0.0, 1.0),
698 }
699 }
700
701 pub async fn execute(self) -> crate::Result<Vec<ScoredDoc>> {
706 use rustc_hash::FxHashMap;
707
708 if self.query_terms.is_empty() {
709 return Ok(Vec::new());
710 }
711
712 let num_terms = self.query_terms.len();
713
714 let mut block_queue: BinaryHeap<BmpBlockEntry> = BinaryHeap::new();
716 let mut remaining_max: Vec<f32> = Vec::with_capacity(num_terms);
717
718 for (term_idx, &(dim_id, qw)) in self.query_terms.iter().enumerate() {
719 let mut term_remaining = 0.0f32;
720
721 if let Some((skip_entries, _global_max)) = self.sparse_index.get_skip_list(dim_id) {
722 for (block_idx, skip) in skip_entries.iter().enumerate() {
723 let contribution = qw.abs() * skip.max_weight;
724 term_remaining += contribution;
725 block_queue.push(BmpBlockEntry {
726 contribution,
727 term_idx,
728 block_idx,
729 });
730 }
731 }
732 remaining_max.push(term_remaining);
733 }
734
735 let mut accumulators: FxHashMap<u64, f32> = FxHashMap::default();
739 let mut blocks_processed = 0u64;
740 let mut blocks_skipped = 0u64;
741
742 let mut top_k = ScoreCollector::new(self.k);
745
746 let mut doc_ids_buf: Vec<u32> = Vec::with_capacity(128);
748 let mut weights_buf: Vec<f32> = Vec::with_capacity(128);
749 let mut ordinals_buf: Vec<u16> = Vec::with_capacity(128);
750
751 while let Some(entry) = block_queue.pop() {
753 remaining_max[entry.term_idx] -= entry.contribution;
755
756 let total_remaining: f32 = remaining_max.iter().sum();
759 let adjusted_threshold = top_k.threshold() * self.heap_factor;
760 if top_k.len() >= self.k && total_remaining <= adjusted_threshold {
761 blocks_skipped += block_queue.len() as u64;
762 debug!(
763 "BMP early termination after {} blocks: remaining={:.4} <= threshold={:.4}",
764 blocks_processed, total_remaining, adjusted_threshold
765 );
766 break;
767 }
768
769 let dim_id = self.query_terms[entry.term_idx].0;
771 let block = match self.sparse_index.get_block(dim_id, entry.block_idx).await? {
772 Some(b) => b,
773 None => continue,
774 };
775
776 block.decode_doc_ids_into(&mut doc_ids_buf);
778 block.decode_weights_into(&mut weights_buf);
779 block.decode_ordinals_into(&mut ordinals_buf);
780 let qw = self.query_terms[entry.term_idx].1;
781
782 for i in 0..block.header.count as usize {
783 let score_contribution = qw * weights_buf[i];
784 let key = (doc_ids_buf[i] as u64) << 16 | ordinals_buf[i] as u64;
785 let acc = accumulators.entry(key).or_insert(0.0);
786 *acc += score_contribution;
787 top_k.insert_with_ordinal(doc_ids_buf[i], *acc, ordinals_buf[i]);
791 }
792
793 blocks_processed += 1;
794 }
795
796 let num_accumulators = accumulators.len();
798 let mut scored: Vec<ScoredDoc> = accumulators
799 .into_iter()
800 .map(|(key, score)| ScoredDoc {
801 doc_id: (key >> 16) as DocId,
802 score,
803 ordinal: (key & 0xFFFF) as u16,
804 })
805 .collect();
806 scored.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
807 scored.truncate(self.k);
808 let results = scored;
809
810 debug!(
811 "BmpExecutor completed: blocks_processed={}, blocks_skipped={}, accumulators={}, returned={}, top_score={:.4}",
812 blocks_processed,
813 blocks_skipped,
814 num_accumulators,
815 results.len(),
816 results.first().map(|r| r.score).unwrap_or(0.0)
817 );
818
819 Ok(results)
820 }
821}
822
#[cfg(test)]
mod tests {
    use super::*;

    /// A full collector evicts its weakest entry when a better one arrives,
    /// and reports results best-first.
    #[test]
    fn test_score_collector_basic() {
        let mut top3 = ScoreCollector::new(3);

        let seed: [(DocId, f32); 3] = [(1, 1.0), (2, 2.0), (3, 3.0)];
        for &(doc, score) in &seed {
            top3.insert(doc, score);
        }
        assert_eq!(top3.threshold(), 1.0);

        top3.insert(4, 4.0);
        assert_eq!(top3.threshold(), 2.0);

        let ranked = top3.into_sorted_results();
        assert_eq!(ranked.len(), 3);
        let doc_order: Vec<DocId> = ranked.iter().map(|r| r.0).collect();
        assert_eq!(doc_order, vec![4, 3, 2]);
    }

    /// `would_enter` and `insert` agree on whether a score beats the
    /// threshold of a full collector.
    #[test]
    fn test_score_collector_threshold() {
        let mut top2 = ScoreCollector::new(2);

        top2.insert(1, 5.0);
        top2.insert(2, 3.0);
        assert_eq!(top2.threshold(), 3.0);

        // Below the threshold: predicted and actually rejected.
        let rejected_prediction = top2.would_enter(2.0);
        assert!(!rejected_prediction);
        let rejected = top2.insert(3, 2.0);
        assert!(!rejected);

        // Above the threshold: predicted and actually accepted.
        let accepted_prediction = top2.would_enter(4.0);
        assert!(accepted_prediction);
        let accepted = top2.insert(4, 4.0);
        assert!(accepted);
        assert_eq!(top2.threshold(), 4.0);
    }

    /// `HeapEntry`'s reversed ordering makes a `BinaryHeap` pop the lowest
    /// score first.
    #[test]
    fn test_heap_entry_ordering() {
        let seed: [(DocId, f32); 3] = [(1, 3.0), (2, 1.0), (3, 2.0)];
        let mut heap = BinaryHeap::new();
        for &(doc_id, score) in &seed {
            heap.push(HeapEntry {
                doc_id,
                score,
                ordinal: 0,
            });
        }

        for expected in [1.0f32, 2.0, 3.0].iter() {
            assert_eq!(heap.pop().unwrap().score, *expected);
        }
    }
}