1use std::cmp::Ordering;
11use std::collections::BinaryHeap;
12use std::sync::Arc;
13
14use log::{debug, trace};
15
16use crate::DocId;
17use crate::structures::BlockSparsePostingList;
18
19pub trait ScoringIterator {
23 fn doc(&self) -> DocId;
25
26 fn ordinal(&mut self) -> u16 {
28 0
29 }
30
31 fn advance(&mut self) -> DocId;
33
34 fn seek(&mut self, target: DocId) -> DocId;
36
37 fn is_exhausted(&self) -> bool {
39 self.doc() == u32::MAX
40 }
41
42 fn score(&self) -> f32;
44
45 fn max_score(&self) -> f32;
47
48 fn current_block_max_score(&self) -> f32;
50
51 fn skip_to_next_block(&mut self) -> DocId {
55 self.advance()
56 }
57}
58
59#[derive(Clone, Copy)]
61pub struct HeapEntry {
62 pub doc_id: DocId,
63 pub score: f32,
64 pub ordinal: u16,
65}
66
67impl PartialEq for HeapEntry {
68 fn eq(&self, other: &Self) -> bool {
69 self.score == other.score && self.doc_id == other.doc_id
70 }
71}
72
73impl Eq for HeapEntry {}
74
75impl Ord for HeapEntry {
76 fn cmp(&self, other: &Self) -> Ordering {
77 other
79 .score
80 .partial_cmp(&self.score)
81 .unwrap_or(Ordering::Equal)
82 .then_with(|| self.doc_id.cmp(&other.doc_id))
83 }
84}
85
86impl PartialOrd for HeapEntry {
87 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
88 Some(self.cmp(other))
89 }
90}
91
92pub struct ScoreCollector {
98 heap: BinaryHeap<HeapEntry>,
100 pub k: usize,
101 cached_threshold: f32,
104}
105
106impl ScoreCollector {
107 pub fn new(k: usize) -> Self {
109 let capacity = k.saturating_add(1).min(1_000_000);
111 Self {
112 heap: BinaryHeap::with_capacity(capacity),
113 k,
114 cached_threshold: 0.0,
115 }
116 }
117
118 #[inline]
120 pub fn threshold(&self) -> f32 {
121 self.cached_threshold
122 }
123
124 #[inline]
126 fn update_threshold(&mut self) {
127 self.cached_threshold = if self.heap.len() >= self.k {
128 self.heap.peek().map(|e| e.score).unwrap_or(0.0)
129 } else {
130 0.0
131 };
132 }
133
134 #[inline]
137 pub fn insert(&mut self, doc_id: DocId, score: f32) -> bool {
138 self.insert_with_ordinal(doc_id, score, 0)
139 }
140
141 #[inline]
144 pub fn insert_with_ordinal(&mut self, doc_id: DocId, score: f32, ordinal: u16) -> bool {
145 if self.heap.len() < self.k {
146 self.heap.push(HeapEntry {
147 doc_id,
148 score,
149 ordinal,
150 });
151 self.update_threshold();
152 true
153 } else if score > self.cached_threshold {
154 self.heap.push(HeapEntry {
155 doc_id,
156 score,
157 ordinal,
158 });
159 self.heap.pop(); self.update_threshold();
161 true
162 } else {
163 false
164 }
165 }
166
167 #[inline]
169 pub fn would_enter(&self, score: f32) -> bool {
170 self.heap.len() < self.k || score > self.cached_threshold
171 }
172
173 #[inline]
175 pub fn len(&self) -> usize {
176 self.heap.len()
177 }
178
179 #[inline]
181 pub fn is_empty(&self) -> bool {
182 self.heap.is_empty()
183 }
184
185 pub fn into_sorted_results(self) -> Vec<(DocId, f32, u16)> {
187 let heap_vec = self.heap.into_vec();
188 let mut results: Vec<(DocId, f32, u16)> = Vec::with_capacity(heap_vec.len());
189 for e in heap_vec {
190 results.push((e.doc_id, e.score, e.ordinal));
191 }
192
193 results.sort_by(|a, b| {
195 b.1.partial_cmp(&a.1)
196 .unwrap_or(Ordering::Equal)
197 .then_with(|| a.0.cmp(&b.0))
198 });
199
200 results
201 }
202}
203
204#[derive(Debug, Clone, Copy)]
206pub struct ScoredDoc {
207 pub doc_id: DocId,
208 pub score: f32,
209 pub ordinal: u16,
211}
212
213pub struct MaxScoreExecutor<S: ScoringIterator> {
226 scorers: Vec<S>,
228 prefix_sums: Vec<f32>,
230 collector: ScoreCollector,
232 heap_factor: f32,
236}
237
238impl<S: ScoringIterator> MaxScoreExecutor<S> {
239 pub fn new(scorers: Vec<S>, k: usize) -> Self {
241 Self::with_heap_factor(scorers, k, 1.0)
242 }
243
244 pub fn with_heap_factor(mut scorers: Vec<S>, k: usize, heap_factor: f32) -> Self {
251 scorers.sort_by(|a, b| {
253 a.max_score()
254 .partial_cmp(&b.max_score())
255 .unwrap_or(Ordering::Equal)
256 });
257
258 let mut prefix_sums = Vec::with_capacity(scorers.len());
260 let mut cumsum = 0.0f32;
261 for s in &scorers {
262 cumsum += s.max_score();
263 prefix_sums.push(cumsum);
264 }
265
266 debug!(
267 "Creating MaxScoreExecutor: num_scorers={}, k={}, total_upper={:.4}, heap_factor={:.2}",
268 scorers.len(),
269 k,
270 cumsum,
271 heap_factor
272 );
273
274 Self {
275 scorers,
276 prefix_sums,
277 collector: ScoreCollector::new(k),
278 heap_factor: heap_factor.clamp(0.0, 1.0),
279 }
280 }
281
282 #[inline]
285 fn find_partition(&self) -> usize {
286 let threshold = self.collector.threshold() * self.heap_factor;
287 self.prefix_sums.partition_point(|&sum| sum <= threshold)
289 }
290
291 pub fn execute(mut self) -> Vec<ScoredDoc> {
301 if self.scorers.is_empty() {
302 debug!("MaxScoreExecutor: no scorers, returning empty results");
303 return Vec::new();
304 }
305
306 let n = self.scorers.len();
307 let mut docs_scored = 0u64;
308 let mut docs_skipped = 0u64;
309 let mut blocks_skipped = 0u64;
310 let mut conjunction_skipped = 0u64;
311
312 let mut ordinal_scores: Vec<(u16, f32)> = Vec::with_capacity(n * 2);
314
315 loop {
316 let partition = self.find_partition();
317
318 if partition >= n {
320 debug!("BlockMaxScore: all terms non-essential, early termination");
321 break;
322 }
323
324 let mut min_doc = u32::MAX;
328 let mut present_upper = 0.0f32;
329 let mut block_max_sum = 0.0f32;
330 for i in partition..n {
331 let doc = self.scorers[i].doc();
332 if doc < min_doc {
333 min_doc = doc;
334 present_upper = self.scorers[i].max_score();
336 block_max_sum = self.scorers[i].current_block_max_score();
337 } else if doc == min_doc {
338 present_upper += self.scorers[i].max_score();
339 block_max_sum += self.scorers[i].current_block_max_score();
340 }
341 }
342
343 if min_doc == u32::MAX {
344 break; }
346
347 let non_essential_upper = if partition > 0 {
348 self.prefix_sums[partition - 1]
349 } else {
350 0.0
351 };
352 let adjusted_threshold = self.collector.threshold() * self.heap_factor;
353
354 if self.collector.len() >= self.collector.k
357 && present_upper + non_essential_upper <= adjusted_threshold
358 {
359 for i in partition..n {
360 if self.scorers[i].doc() == min_doc {
361 self.scorers[i].advance();
362 }
363 }
364 conjunction_skipped += 1;
365 continue;
366 }
367
368 if self.collector.len() >= self.collector.k
371 && block_max_sum + non_essential_upper <= adjusted_threshold
372 {
373 for i in partition..n {
374 if self.scorers[i].doc() == min_doc {
375 self.scorers[i].skip_to_next_block();
376 }
377 }
378 blocks_skipped += 1;
379 continue;
380 }
381
382 ordinal_scores.clear();
385
386 for i in partition..n {
387 if self.scorers[i].doc() == min_doc {
388 while self.scorers[i].doc() == min_doc {
389 ordinal_scores.push((self.scorers[i].ordinal(), self.scorers[i].score()));
390 self.scorers[i].advance();
391 }
392 }
393 }
394
395 let essential_total: f32 = ordinal_scores.iter().map(|(_, s)| *s).sum();
397
398 if self.collector.len() >= self.collector.k
399 && essential_total + non_essential_upper <= adjusted_threshold
400 {
401 docs_skipped += 1;
402 continue;
403 }
404
405 for i in 0..partition {
407 let doc = self.scorers[i].seek(min_doc);
408 if doc == min_doc {
409 while self.scorers[i].doc() == min_doc {
410 ordinal_scores.push((self.scorers[i].ordinal(), self.scorers[i].score()));
411 self.scorers[i].advance();
412 }
413 }
414 }
415
416 if ordinal_scores.len() == 1 {
419 let (ord, score) = ordinal_scores[0];
420 trace!(
421 "Doc {}: ordinal={}, score={:.4}, threshold={:.4}",
422 min_doc, ord, score, adjusted_threshold
423 );
424 if self.collector.insert_with_ordinal(min_doc, score, ord) {
425 docs_scored += 1;
426 } else {
427 docs_skipped += 1;
428 }
429 } else if !ordinal_scores.is_empty() {
430 if ordinal_scores.len() > 2 {
431 ordinal_scores.sort_unstable_by_key(|(ord, _)| *ord);
432 } else if ordinal_scores[0].0 > ordinal_scores[1].0 {
433 ordinal_scores.swap(0, 1);
434 }
435 let mut j = 0;
436 while j < ordinal_scores.len() {
437 let current_ord = ordinal_scores[j].0;
438 let mut score = 0.0f32;
439 while j < ordinal_scores.len() && ordinal_scores[j].0 == current_ord {
440 score += ordinal_scores[j].1;
441 j += 1;
442 }
443
444 trace!(
445 "Doc {}: ordinal={}, score={:.4}, threshold={:.4}",
446 min_doc, current_ord, score, adjusted_threshold
447 );
448
449 if self
450 .collector
451 .insert_with_ordinal(min_doc, score, current_ord)
452 {
453 docs_scored += 1;
454 } else {
455 docs_skipped += 1;
456 }
457 }
458 }
459 }
460
461 let results: Vec<ScoredDoc> = self
462 .collector
463 .into_sorted_results()
464 .into_iter()
465 .map(|(doc_id, score, ordinal)| ScoredDoc {
466 doc_id,
467 score,
468 ordinal,
469 })
470 .collect();
471
472 debug!(
473 "MaxScoreExecutor completed: scored={}, skipped={}, blocks_skipped={}, conjunction_skipped={}, returned={}, top_score={:.4}",
474 docs_scored,
475 docs_skipped,
476 blocks_skipped,
477 conjunction_skipped,
478 results.len(),
479 results.first().map(|r| r.score).unwrap_or(0.0)
480 );
481
482 results
483 }
484}
485
486pub struct TextTermScorer {
491 iter: crate::structures::BlockPostingIterator<'static>,
493 idf: f32,
495 avg_field_len: f32,
497 max_score: f32,
499}
500
501impl TextTermScorer {
502 pub fn new(
504 posting_list: crate::structures::BlockPostingList,
505 idf: f32,
506 avg_field_len: f32,
507 ) -> Self {
508 let max_tf = posting_list.max_tf() as f32;
510 let doc_count = posting_list.doc_count();
511 let max_score = super::bm25_upper_bound(max_tf.max(1.0), idf);
512
513 debug!(
514 "Created TextTermScorer: doc_count={}, max_tf={:.0}, idf={:.4}, avg_field_len={:.2}, max_score={:.4}",
515 doc_count, max_tf, idf, avg_field_len, max_score
516 );
517
518 Self {
519 iter: posting_list.into_iterator(),
520 idf,
521 avg_field_len,
522 max_score,
523 }
524 }
525}
526
527impl ScoringIterator for TextTermScorer {
528 #[inline]
529 fn doc(&self) -> DocId {
530 self.iter.doc()
531 }
532
533 #[inline]
534 fn advance(&mut self) -> DocId {
535 self.iter.advance()
536 }
537
538 #[inline]
539 fn seek(&mut self, target: DocId) -> DocId {
540 self.iter.seek(target)
541 }
542
543 #[inline]
544 fn score(&self) -> f32 {
545 let tf = self.iter.term_freq() as f32;
546 super::bm25_score(tf, self.idf, tf, self.avg_field_len)
548 }
549
550 #[inline]
551 fn max_score(&self) -> f32 {
552 self.max_score
553 }
554
555 #[inline]
556 fn current_block_max_score(&self) -> f32 {
557 let block_max_tf = self.iter.current_block_max_tf() as f32;
559 super::bm25_upper_bound(block_max_tf.max(1.0), self.idf)
560 }
561
562 #[inline]
563 fn skip_to_next_block(&mut self) -> DocId {
564 self.iter.skip_to_next_block()
565 }
566}
567
568pub struct SparseTermScorer<'a> {
572 iter: crate::structures::BlockSparsePostingIterator<'a>,
574 query_weight: f32,
576 abs_query_weight: f32,
578 max_score: f32,
580}
581
582impl<'a> SparseTermScorer<'a> {
583 pub fn new(posting_list: &'a BlockSparsePostingList, query_weight: f32) -> Self {
588 let abs_qw = query_weight.abs();
591 let max_score = abs_qw * posting_list.global_max_weight();
592 Self {
593 iter: posting_list.iterator(),
594 query_weight,
595 abs_query_weight: abs_qw,
596 max_score,
597 }
598 }
599
600 pub fn from_arc(posting_list: &'a Arc<BlockSparsePostingList>, query_weight: f32) -> Self {
602 Self::new(posting_list.as_ref(), query_weight)
603 }
604}
605
606impl ScoringIterator for SparseTermScorer<'_> {
607 #[inline]
608 fn doc(&self) -> DocId {
609 self.iter.doc()
610 }
611
612 #[inline]
613 fn ordinal(&mut self) -> u16 {
614 self.iter.ordinal()
615 }
616
617 #[inline]
618 fn advance(&mut self) -> DocId {
619 self.iter.advance()
620 }
621
622 #[inline]
623 fn seek(&mut self, target: DocId) -> DocId {
624 self.iter.seek(target)
625 }
626
627 #[inline]
628 fn score(&self) -> f32 {
629 self.query_weight * self.iter.weight()
631 }
632
633 #[inline]
634 fn max_score(&self) -> f32 {
635 self.max_score
636 }
637
638 #[inline]
639 fn current_block_max_score(&self) -> f32 {
640 self.iter
641 .current_block_max_contribution(self.abs_query_weight)
642 }
643
644 #[inline]
645 fn skip_to_next_block(&mut self) -> DocId {
646 self.iter.skip_to_next_block()
647 }
648}
649
650pub struct BmpExecutor<'a> {
662 sparse_index: &'a crate::segment::SparseIndex,
664 query_terms: Vec<(u32, f32)>,
666 k: usize,
668 heap_factor: f32,
670}
671
/// Number of consecutive skip-list blocks grouped into one BMP superblock.
const BMP_SUPERBLOCK_SIZE: usize = 8;

/// Number of consecutive superblocks grouped into one priority-queue (megablock) entry.
const BMP_MEGABLOCK_SIZE: usize = 16;
680
/// A run of up to `BMP_SUPERBLOCK_SIZE` consecutive blocks from one term's posting list.
struct BmpSuperBlock {
    /// Sum over the run's blocks of `|query_weight| * block max_weight`: an upper bound on
    /// what the run can add to any single doc.
    contribution: f32,
    /// First block of the run, relative to the term's first skip entry.
    block_start: usize,
    /// Number of blocks in the run.
    block_count: usize,
}
690
/// Priority-queue entry covering up to `BMP_MEGABLOCK_SIZE` consecutive superblocks of one
/// term.
///
/// Equality and ordering consider only `contribution`, so a `BinaryHeap` pops the entry with
/// the largest bound first.
struct BmpMegaBlockEntry {
    /// Sum of the covered superblocks' contributions.
    contribution: f32,
    /// Index into the executor's query terms.
    term_idx: usize,
    /// First covered superblock (index into that term's superblock list).
    sb_start: usize,
    /// Number of covered superblocks.
    sb_count: usize,
}

impl PartialEq for BmpMegaBlockEntry {
    fn eq(&self, other: &Self) -> bool {
        other.contribution == self.contribution
    }
}

impl Eq for BmpMegaBlockEntry {}

impl Ord for BmpMegaBlockEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        // Incomparable (NaN) contributions are treated as ties.
        match self.contribution.partial_cmp(&other.contribution) {
            Some(ord) => ord,
            None => Ordering::Equal,
        }
    }
}

impl PartialOrd for BmpMegaBlockEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
725
726impl<'a> BmpExecutor<'a> {
727 pub fn new(
732 sparse_index: &'a crate::segment::SparseIndex,
733 query_terms: Vec<(u32, f32)>,
734 k: usize,
735 heap_factor: f32,
736 ) -> Self {
737 Self {
738 sparse_index,
739 query_terms,
740 k,
741 heap_factor: heap_factor.clamp(0.0, 1.0),
742 }
743 }
744
745 pub async fn execute(self) -> crate::Result<Vec<ScoredDoc>> {
753 use rustc_hash::FxHashMap;
754
755 if self.query_terms.is_empty() {
756 return Ok(Vec::new());
757 }
758
759 let num_terms = self.query_terms.len();
760 let si = self.sparse_index;
761
762 let mut term_superblocks: Vec<Vec<BmpSuperBlock>> = Vec::with_capacity(num_terms);
766 let mut term_skip_starts: Vec<usize> = Vec::with_capacity(num_terms);
767 let mut global_min_doc = u32::MAX;
768 let mut global_max_doc = 0u32;
769 let mut total_remaining = 0.0f32;
770
771 for &(dim_id, qw) in &self.query_terms {
772 let mut term_skip_start = 0usize;
773 let mut superblocks = Vec::new();
774
775 let abs_qw = qw.abs();
776 if let Some((skip_start, skip_count, _global_max)) = si.get_skip_range(dim_id) {
777 term_skip_start = skip_start;
778 let mut sb_start = 0;
780 while sb_start < skip_count {
781 let sb_count = (skip_count - sb_start).min(BMP_SUPERBLOCK_SIZE);
782 let mut sb_contribution = 0.0f32;
783 for j in 0..sb_count {
784 let skip = si.read_skip_entry(skip_start + sb_start + j);
785 sb_contribution += abs_qw * skip.max_weight;
786 global_min_doc = global_min_doc.min(skip.first_doc);
787 global_max_doc = global_max_doc.max(skip.last_doc);
788 }
789 total_remaining += sb_contribution;
790 superblocks.push(BmpSuperBlock {
791 contribution: sb_contribution,
792 block_start: sb_start,
793 block_count: sb_count,
794 });
795 sb_start += sb_count;
796 }
797 }
798 term_skip_starts.push(term_skip_start);
799 term_superblocks.push(superblocks);
800 }
801
802 let mut mega_queue: BinaryHeap<BmpMegaBlockEntry> = BinaryHeap::new();
804 for (term_idx, superblocks) in term_superblocks.iter().enumerate() {
805 let mut mb_start = 0;
806 while mb_start < superblocks.len() {
807 let mb_count = (superblocks.len() - mb_start).min(BMP_MEGABLOCK_SIZE);
808 let mb_contribution: f32 = superblocks[mb_start..mb_start + mb_count]
809 .iter()
810 .map(|sb| sb.contribution)
811 .sum();
812 mega_queue.push(BmpMegaBlockEntry {
813 contribution: mb_contribution,
814 term_idx,
815 sb_start: mb_start,
816 sb_count: mb_count,
817 });
818 mb_start += mb_count;
819 }
820 }
821
822 let doc_range = if global_max_doc >= global_min_doc {
824 (global_max_doc - global_min_doc + 1) as usize
825 } else {
826 0
827 };
828 let use_flat = doc_range > 0 && doc_range <= 256 * 1024;
830 let mut flat_scores: Vec<f32> = if use_flat {
831 vec![0.0; doc_range]
832 } else {
833 Vec::new()
834 };
835 let mut dirty: Vec<u32> = if use_flat {
837 Vec::with_capacity(4096)
838 } else {
839 Vec::new()
840 };
841 let mut multi_ord_accumulators: FxHashMap<u64, f32> = FxHashMap::default();
843
844 let mut blocks_processed = 0u64;
845 let mut blocks_skipped = 0u64;
846
847 let mut top_k = ScoreCollector::new(self.k);
849
850 let mut doc_ids_buf: Vec<u32> = Vec::with_capacity(256);
852 let mut weights_buf: Vec<f32> = Vec::with_capacity(256);
853 let mut ordinals_buf: Vec<u16> = Vec::with_capacity(256);
854
855 let mut terms_warmed = vec![false; num_terms];
859 let mut warmup_remaining = self.k.min(num_terms);
860
861 while let Some(mega) = mega_queue.pop() {
863 total_remaining -= mega.contribution;
864
865 if !terms_warmed[mega.term_idx] {
867 terms_warmed[mega.term_idx] = true;
868 warmup_remaining = warmup_remaining.saturating_sub(1);
869 }
870
871 if warmup_remaining == 0 {
873 let adjusted_threshold = top_k.threshold() * self.heap_factor;
874 if top_k.len() >= self.k && total_remaining <= adjusted_threshold {
875 let remaining_blocks: u64 = mega_queue
877 .iter()
878 .map(|m| {
879 let sbs =
880 &term_superblocks[m.term_idx][m.sb_start..m.sb_start + m.sb_count];
881 sbs.iter().map(|sb| sb.block_count as u64).sum::<u64>()
882 })
883 .sum();
884 blocks_skipped += remaining_blocks;
885 debug!(
886 "BMP early termination after {} blocks: remaining={:.4} <= threshold={:.4}",
887 blocks_processed, total_remaining, adjusted_threshold
888 );
889 break;
890 }
891 }
892
893 let dim_id = self.query_terms[mega.term_idx].0;
894 let qw = self.query_terms[mega.term_idx].1;
895 let abs_qw = qw.abs();
896 let skip_start = term_skip_starts[mega.term_idx];
897
898 for sb in term_superblocks[mega.term_idx]
900 .iter()
901 .skip(mega.sb_start)
902 .take(mega.sb_count)
903 {
904 if top_k.len() >= self.k {
906 let adjusted_threshold = top_k.threshold() * self.heap_factor;
907 if sb.contribution + total_remaining <= adjusted_threshold {
908 blocks_skipped += sb.block_count as u64;
909 continue;
910 }
911 }
912
913 let sb_blocks = si
915 .get_blocks_range(dim_id, sb.block_start, sb.block_count)
916 .await?;
917
918 let adjusted_threshold2 = top_k.threshold() * self.heap_factor;
919
920 let dirty_start = dirty.len();
922
923 for (blk_offset, block) in sb_blocks.into_iter().enumerate() {
924 let blk_idx = sb.block_start + blk_offset;
925
926 if top_k.len() >= self.k {
928 let skip = si.read_skip_entry(skip_start + blk_idx);
929 let blk_contrib = abs_qw * skip.max_weight;
930 if blk_contrib + total_remaining <= adjusted_threshold2 {
931 blocks_skipped += 1;
932 continue;
933 }
934 }
935
936 block.decode_doc_ids_into(&mut doc_ids_buf);
937
938 if block.header.ordinal_bits == 0 && use_flat {
940 block.accumulate_scored_weights(
941 qw,
942 &doc_ids_buf,
943 &mut flat_scores,
944 global_min_doc,
945 &mut dirty,
946 );
947 } else {
948 block.decode_scored_weights_into(qw, &mut weights_buf);
949 let count = block.header.count as usize;
950
951 block.decode_ordinals_into(&mut ordinals_buf);
952 if use_flat {
953 for i in 0..count {
954 let doc_id = doc_ids_buf[i];
955 let ordinal = ordinals_buf[i];
956 let score_contribution = weights_buf[i];
957
958 if ordinal == 0 {
959 let off = (doc_id - global_min_doc) as usize;
960 if flat_scores[off] == 0.0 {
961 dirty.push(doc_id);
962 }
963 flat_scores[off] += score_contribution;
964 } else {
965 let key = (doc_id as u64) << 16 | ordinal as u64;
966 let acc = multi_ord_accumulators.entry(key).or_insert(0.0);
967 *acc += score_contribution;
968 top_k.insert_with_ordinal(doc_id, *acc, ordinal);
969 }
970 }
971 } else {
972 for i in 0..count {
973 let key = (doc_ids_buf[i] as u64) << 16 | ordinals_buf[i] as u64;
974 let acc = multi_ord_accumulators.entry(key).or_insert(0.0);
975 *acc += weights_buf[i];
976 top_k.insert_with_ordinal(doc_ids_buf[i], *acc, ordinals_buf[i]);
977 }
978 }
979 }
980
981 blocks_processed += 1;
982 }
983
984 for &doc_id in &dirty[dirty_start..] {
988 let off = (doc_id - global_min_doc) as usize;
989 top_k.insert_with_ordinal(doc_id, flat_scores[off], 0);
990 }
991 }
992 }
993
994 let mut scored: Vec<ScoredDoc> = Vec::new();
996
997 let num_accumulators = if use_flat {
998 scored.reserve(dirty.len() + multi_ord_accumulators.len());
1000 for &doc_id in &dirty {
1001 let off = (doc_id - global_min_doc) as usize;
1002 let score = flat_scores[off];
1003 if score > 0.0 {
1004 scored.push(ScoredDoc {
1005 doc_id,
1006 score,
1007 ordinal: 0,
1008 });
1009 }
1010 }
1011 dirty.len() + multi_ord_accumulators.len()
1012 } else {
1013 multi_ord_accumulators.len()
1014 };
1015
1016 scored.extend(
1018 multi_ord_accumulators
1019 .into_iter()
1020 .map(|(key, score)| ScoredDoc {
1021 doc_id: (key >> 16) as DocId,
1022 score,
1023 ordinal: (key & 0xFFFF) as u16,
1024 }),
1025 );
1026
1027 scored.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
1028 scored.truncate(self.k);
1029 let results = scored;
1030
1031 debug!(
1032 "BmpExecutor completed: blocks_processed={}, blocks_skipped={}, accumulators={}, flat={}, returned={}, top_score={:.4}",
1033 blocks_processed,
1034 blocks_skipped,
1035 num_accumulators,
1036 use_flat,
1037 results.len(),
1038 results.first().map(|r| r.score).unwrap_or(0.0)
1039 );
1040
1041 Ok(results)
1042 }
1043}
1044
1045pub struct SparseMaxScoreExecutor<'a> {
1057 sparse_index: &'a crate::segment::SparseIndex,
1058 cursors: Vec<LazyTermCursor>,
1059 prefix_sums: Vec<f32>,
1060 collector: ScoreCollector,
1061 heap_factor: f32,
1062}
1063
/// Cursor over one query dimension whose posting blocks are loaded on demand.
struct LazyTermCursor {
    /// Signed query weight used when decoding block weights.
    query_weight: f32,
    /// `|query_weight|`, used for block bounds.
    abs_query_weight: f32,
    /// `|query_weight| * global max weight`: bound over the whole list.
    max_score: f32,
    /// Index of the term's first skip entry in the index.
    skip_start: usize,
    /// Number of blocks (skip entries) of the term.
    skip_count: usize,
    /// Offset handed to the block loaders.
    block_data_offset: u64,
    /// Current block, relative to `skip_start`.
    block_idx: usize,
    /// Decoded doc ids of the current block.
    doc_ids: Vec<u32>,
    /// Decoded ordinals of the current block.
    ordinals: Vec<u16>,
    /// Decoded, query-weighted scores of the current block.
    weights: Vec<f32>,
    /// Position inside the decoded buffers.
    pos: usize,
    /// Whether the buffers hold the block at `block_idx`.
    block_loaded: bool,
    /// Set once the cursor ran past its last block.
    exhausted: bool,
}
1088
1089impl LazyTermCursor {
1090 fn new(
1091 query_weight: f32,
1092 skip_start: usize,
1093 skip_count: usize,
1094 global_max_weight: f32,
1095 block_data_offset: u64,
1096 ) -> Self {
1097 let exhausted = skip_count == 0;
1098 let abs_qw = query_weight.abs();
1099 Self {
1100 query_weight,
1101 abs_query_weight: abs_qw,
1102 max_score: abs_qw * global_max_weight,
1103 skip_start,
1104 skip_count,
1105 block_data_offset,
1106 block_idx: 0,
1107 doc_ids: Vec::with_capacity(256),
1108 ordinals: Vec::with_capacity(256),
1109 weights: Vec::with_capacity(256),
1110 pos: 0,
1111 block_loaded: false,
1112 exhausted,
1113 }
1114 }
1115
1116 #[inline]
1120 fn decode_block(&mut self, block: crate::structures::SparseBlock) {
1121 block.decode_doc_ids_into(&mut self.doc_ids);
1122 block.decode_ordinals_into(&mut self.ordinals);
1123 block.decode_scored_weights_into(self.query_weight, &mut self.weights);
1124 self.pos = 0;
1125 self.block_loaded = true;
1126 }
1127
1128 #[inline]
1130 fn handle_block_result(
1131 &mut self,
1132 block: Option<crate::structures::SparseBlock>,
1133 ) -> crate::Result<bool> {
1134 match block {
1135 Some(b) => {
1136 self.decode_block(b);
1137 Ok(true)
1138 }
1139 None => {
1140 self.exhausted = true;
1141 Ok(false)
1142 }
1143 }
1144 }
1145
1146 #[inline]
1149 fn advance_pos(&mut self) -> DocId {
1150 self.pos += 1;
1151 if self.pos >= self.doc_ids.len() {
1152 self.block_idx += 1;
1153 self.block_loaded = false;
1154 if self.block_idx >= self.skip_count {
1155 self.exhausted = true;
1156 return u32::MAX;
1157 }
1158 }
1159 self.doc()
1160 }
1161
1162 fn seek_prepare(
1166 &mut self,
1167 si: &crate::segment::SparseIndex,
1168 target: DocId,
1169 ) -> crate::Result<Option<DocId>> {
1170 if self.exhausted {
1171 return Ok(Some(u32::MAX));
1172 }
1173
1174 if self.block_loaded
1176 && let Some(&last) = self.doc_ids.last()
1177 {
1178 if last >= target && self.doc_ids[self.pos] < target {
1179 let remaining = &self.doc_ids[self.pos..];
1180 let offset = crate::structures::simd::find_first_ge_u32(remaining, target);
1181 self.pos += offset;
1182 if self.pos >= self.doc_ids.len() {
1183 self.block_idx += 1;
1184 self.block_loaded = false;
1185 if self.block_idx >= self.skip_count {
1186 self.exhausted = true;
1187 return Ok(Some(u32::MAX));
1188 }
1189 }
1190 return Ok(Some(self.doc()));
1191 }
1192 if self.doc_ids[self.pos] >= target {
1193 return Ok(Some(self.doc()));
1194 }
1195 }
1196
1197 let mut lo = self.block_idx;
1199 let mut hi = self.skip_count;
1200 while lo < hi {
1201 let mid = lo + (hi - lo) / 2;
1202 if si.read_skip_entry(self.skip_start + mid).last_doc < target {
1203 lo = mid + 1;
1204 } else {
1205 hi = mid;
1206 }
1207 }
1208 if lo >= self.skip_count {
1209 self.exhausted = true;
1210 return Ok(Some(u32::MAX));
1211 }
1212 if lo != self.block_idx || !self.block_loaded {
1213 self.block_idx = lo;
1214 self.block_loaded = false;
1215 }
1216 Ok(None)
1218 }
1219
1220 #[inline]
1223 fn seek_finish(&mut self, target: DocId) -> bool {
1224 if self.exhausted {
1225 return false;
1226 }
1227 self.pos = crate::structures::simd::find_first_ge_u32(&self.doc_ids, target);
1228 if self.pos >= self.doc_ids.len() {
1229 self.block_idx += 1;
1230 self.block_loaded = false;
1231 if self.block_idx >= self.skip_count {
1232 self.exhausted = true;
1233 return false;
1234 }
1235 return true; }
1237 false
1238 }
1239
1240 async fn ensure_block_loaded(
1243 &mut self,
1244 si: &crate::segment::SparseIndex,
1245 ) -> crate::Result<bool> {
1246 if self.exhausted || self.block_loaded {
1247 return Ok(!self.exhausted);
1248 }
1249 let block = si
1250 .load_block_direct(self.skip_start, self.block_data_offset, self.block_idx)
1251 .await?;
1252 self.handle_block_result(block)
1253 }
1254
1255 async fn advance(&mut self, si: &crate::segment::SparseIndex) -> crate::Result<DocId> {
1256 if self.exhausted {
1257 return Ok(u32::MAX);
1258 }
1259 self.ensure_block_loaded(si).await?;
1260 if self.exhausted {
1261 return Ok(u32::MAX);
1262 }
1263 Ok(self.advance_pos())
1264 }
1265
1266 async fn seek(
1267 &mut self,
1268 si: &crate::segment::SparseIndex,
1269 target: DocId,
1270 ) -> crate::Result<DocId> {
1271 if let Some(doc) = self.seek_prepare(si, target)? {
1272 return Ok(doc);
1273 }
1274 self.ensure_block_loaded(si).await?;
1275 if self.seek_finish(target) {
1276 self.ensure_block_loaded(si).await?;
1277 }
1278 Ok(self.doc())
1279 }
1280
1281 fn ensure_block_loaded_sync(
1284 &mut self,
1285 si: &crate::segment::SparseIndex,
1286 ) -> crate::Result<bool> {
1287 if self.exhausted || self.block_loaded {
1288 return Ok(!self.exhausted);
1289 }
1290 let block =
1291 si.load_block_direct_sync(self.skip_start, self.block_data_offset, self.block_idx)?;
1292 self.handle_block_result(block)
1293 }
1294
1295 fn advance_sync(&mut self, si: &crate::segment::SparseIndex) -> crate::Result<DocId> {
1296 if self.exhausted {
1297 return Ok(u32::MAX);
1298 }
1299 self.ensure_block_loaded_sync(si)?;
1300 if self.exhausted {
1301 return Ok(u32::MAX);
1302 }
1303 Ok(self.advance_pos())
1304 }
1305
1306 fn seek_sync(
1307 &mut self,
1308 si: &crate::segment::SparseIndex,
1309 target: DocId,
1310 ) -> crate::Result<DocId> {
1311 if let Some(doc) = self.seek_prepare(si, target)? {
1312 return Ok(doc);
1313 }
1314 self.ensure_block_loaded_sync(si)?;
1315 if self.seek_finish(target) {
1316 self.ensure_block_loaded_sync(si)?;
1317 }
1318 Ok(self.doc())
1319 }
1320
1321 #[inline]
1324 fn doc_with_si(&self, si: &crate::segment::SparseIndex) -> DocId {
1325 if self.exhausted {
1326 return u32::MAX;
1327 }
1328 if !self.block_loaded {
1329 return si
1330 .read_skip_entry(self.skip_start + self.block_idx)
1331 .first_doc;
1332 }
1333 self.doc_ids.get(self.pos).copied().unwrap_or(u32::MAX)
1334 }
1335
1336 #[inline]
1337 fn doc(&self) -> DocId {
1338 if self.exhausted {
1339 return u32::MAX;
1340 }
1341 if self.block_loaded {
1342 return self.doc_ids.get(self.pos).copied().unwrap_or(u32::MAX);
1343 }
1344 u32::MAX
1345 }
1346
1347 #[inline]
1348 fn ordinal(&self) -> u16 {
1349 if !self.block_loaded {
1350 return 0;
1351 }
1352 self.ordinals.get(self.pos).copied().unwrap_or(0)
1353 }
1354
1355 #[inline]
1356 fn score(&self) -> f32 {
1357 if !self.block_loaded {
1358 return 0.0;
1359 }
1360 self.weights.get(self.pos).copied().unwrap_or(0.0)
1361 }
1362
1363 #[inline]
1364 fn current_block_max_score(&self, si: &crate::segment::SparseIndex) -> f32 {
1365 if self.exhausted || self.block_idx >= self.skip_count {
1366 return 0.0;
1367 }
1368 self.abs_query_weight
1369 * si.read_skip_entry(self.skip_start + self.block_idx)
1370 .max_weight
1371 }
1372
1373 fn skip_to_next_block(&mut self, si: &crate::segment::SparseIndex) -> DocId {
1375 if self.exhausted {
1376 return u32::MAX;
1377 }
1378 self.block_idx += 1;
1379 self.block_loaded = false;
1380 if self.block_idx >= self.skip_count {
1381 self.exhausted = true;
1382 return u32::MAX;
1383 }
1384 si.read_skip_entry(self.skip_start + self.block_idx)
1385 .first_doc
1386 }
1387}
1388
// Shared body of the sparse Block-Max MaxScore loop over `LazyTermCursor`s.
//
// Arguments:
// - `$self`: the executor. It must expose `cursors`, `sparse_index`, `prefix_sums`,
//   `collector`, `heap_factor` and `find_partition()`.
// - `$ensure` / `$advance` / `$seek`: cursor method names to call (presumably the async
//   `ensure_block_loaded` / `advance` / `seek` or their `_sync` variants).
// - `$($aw)*`: tokens placed after each cursor call before `?` (e.g. `.await`, or nothing).
//
// The expansion uses `?` and ends with `Ok(results)`, so it must be expanded inside a
// function returning a `Result`.
macro_rules! bms_execute_loop {
    ($self:ident, $ensure:ident, $advance:ident, $seek:ident, $($aw:tt)*) => {{
        let n = $self.cursors.len();
        let si = $self.sparse_index;

        // Decode the first block of every cursor up front.
        for cursor in &mut $self.cursors {
            cursor.$ensure(si) $($aw)* ?;
        }

        let mut docs_scored = 0u64;
        let mut docs_skipped = 0u64;
        let mut blocks_skipped = 0u64;
        let mut blocks_loaded = 0u64;
        let mut conjunction_skipped = 0u64;
        let mut ordinal_scores: Vec<(u16, f32)> = Vec::with_capacity(n * 2);

        loop {
            let partition = $self.find_partition();
            if partition >= n {
                break;
            }

            // Candidate = smallest position among essential cursors. `doc_with_si` reads the
            // skip list for cursors whose current block is not decoded.
            let mut min_doc = u32::MAX;
            for i in partition..n {
                let doc = $self.cursors[i].doc_with_si(si);
                if doc < min_doc {
                    min_doc = doc;
                }
            }
            if min_doc == u32::MAX {
                break;
            }

            let non_essential_upper = if partition > 0 {
                $self.prefix_sums[partition - 1]
            } else {
                0.0
            };
            let adjusted_threshold = $self.collector.threshold() * $self.heap_factor;

            // 1) The candidate alone cannot reach the threshold: step just past it.
            if $self.collector.len() >= $self.collector.k {
                let present_upper: f32 = (partition..n)
                    .filter(|&i| $self.cursors[i].doc_with_si(si) == min_doc)
                    .map(|i| $self.cursors[i].max_score)
                    .sum();

                if present_upper + non_essential_upper <= adjusted_threshold {
                    for i in partition..n {
                        if $self.cursors[i].doc_with_si(si) == min_doc {
                            $self.cursors[i].$ensure(si) $($aw)* ?;
                            $self.cursors[i].$advance(si) $($aw)* ?;
                            blocks_loaded += u64::from($self.cursors[i].block_loaded);
                        }
                    }
                    conjunction_skipped += 1;
                    continue;
                }
            }

            // 2) Block skip: cursors on the candidate abandon the rest of their current block.
            //    Docs in that range may also be matched by live essential cursors positioned
            //    past the candidate. Such docs would later be scored without the abandoned
            //    contributions, so those cursors' global bounds (absent_upper) must be part
            //    of the test for the skip to be safe.
            if $self.collector.len() >= $self.collector.k {
                let mut block_max_sum = 0.0f32;
                let mut absent_upper = 0.0f32;
                for i in partition..n {
                    let doc = $self.cursors[i].doc_with_si(si);
                    if doc == min_doc {
                        block_max_sum += $self.cursors[i].current_block_max_score(si);
                    } else if doc != u32::MAX {
                        absent_upper += $self.cursors[i].max_score;
                    }
                }

                if block_max_sum + absent_upper + non_essential_upper <= adjusted_threshold {
                    for i in partition..n {
                        if $self.cursors[i].doc_with_si(si) == min_doc {
                            $self.cursors[i].skip_to_next_block(si);
                            $self.cursors[i].$ensure(si) $($aw)* ?;
                            blocks_loaded += 1;
                        }
                    }
                    blocks_skipped += 1;
                    continue;
                }
            }

            // 3) Score the candidate on the essential cursors.
            ordinal_scores.clear();
            for i in partition..n {
                if $self.cursors[i].doc_with_si(si) == min_doc {
                    $self.cursors[i].$ensure(si) $($aw)* ?;
                    while $self.cursors[i].doc_with_si(si) == min_doc {
                        ordinal_scores.push(($self.cursors[i].ordinal(), $self.cursors[i].score()));
                        $self.cursors[i].$advance(si) $($aw)* ?;
                    }
                }
            }

            let essential_total: f32 = ordinal_scores.iter().map(|(_, s)| *s).sum();
            if $self.collector.len() >= $self.collector.k
                && essential_total + non_essential_upper <= adjusted_threshold
            {
                docs_skipped += 1;
                continue;
            }

            // 4) Probe non-essential cursors from the largest bound down. Stop once the
            //    remaining bounds cannot lift the doc above the threshold.
            let mut running_total = essential_total;
            for i in (0..partition).rev() {
                if $self.collector.len() >= $self.collector.k
                    && running_total + $self.prefix_sums[i] <= adjusted_threshold
                {
                    break;
                }

                let doc = $self.cursors[i].$seek(si, min_doc) $($aw)* ?;
                if doc == min_doc {
                    while $self.cursors[i].doc_with_si(si) == min_doc {
                        let s = $self.cursors[i].score();
                        running_total += s;
                        ordinal_scores.push(($self.cursors[i].ordinal(), s));
                        $self.cursors[i].$advance(si) $($aw)* ?;
                    }
                }
            }

            // 5) Emit one collector entry per distinct ordinal.
            if ordinal_scores.len() == 1 {
                let (ord, score) = ordinal_scores[0];
                if $self.collector.insert_with_ordinal(min_doc, score, ord) {
                    docs_scored += 1;
                } else {
                    docs_skipped += 1;
                }
            } else if !ordinal_scores.is_empty() {
                if ordinal_scores.len() > 2 {
                    ordinal_scores.sort_unstable_by_key(|(ord, _)| *ord);
                } else if ordinal_scores[0].0 > ordinal_scores[1].0 {
                    ordinal_scores.swap(0, 1);
                }
                let mut j = 0;
                while j < ordinal_scores.len() {
                    let current_ord = ordinal_scores[j].0;
                    let mut score = 0.0f32;
                    while j < ordinal_scores.len() && ordinal_scores[j].0 == current_ord {
                        score += ordinal_scores[j].1;
                        j += 1;
                    }
                    if $self
                        .collector
                        .insert_with_ordinal(min_doc, score, current_ord)
                    {
                        docs_scored += 1;
                    } else {
                        docs_skipped += 1;
                    }
                }
            }
        }

        let results: Vec<ScoredDoc> = $self
            .collector
            .into_sorted_results()
            .into_iter()
            .map(|(doc_id, score, ordinal)| ScoredDoc {
                doc_id,
                score,
                ordinal,
            })
            .collect();

        debug!(
            "SparseMaxScoreExecutor: scored={}, skipped={}, blocks_skipped={}, blocks_loaded={}, conjunction_skipped={}, returned={}, top_score={:.4}",
            docs_scored,
            docs_skipped,
            blocks_skipped,
            blocks_loaded,
            conjunction_skipped,
            results.len(),
            results.first().map(|r| r.score).unwrap_or(0.0)
        );

        Ok(results)
    }};
}
1575
1576impl<'a> SparseMaxScoreExecutor<'a> {
1577 pub fn new(
1582 sparse_index: &'a crate::segment::SparseIndex,
1583 query_terms: Vec<(u32, f32)>,
1584 k: usize,
1585 heap_factor: f32,
1586 ) -> Self {
1587 let mut cursors: Vec<LazyTermCursor> = query_terms
1588 .iter()
1589 .filter_map(|&(dim_id, qw)| {
1590 let (skip_start, skip_count, global_max, block_data_offset) =
1591 sparse_index.get_skip_range_full(dim_id)?;
1592 Some(LazyTermCursor::new(
1593 qw,
1594 skip_start,
1595 skip_count,
1596 global_max,
1597 block_data_offset,
1598 ))
1599 })
1600 .collect();
1601
1602 cursors.sort_by(|a, b| {
1604 a.max_score
1605 .partial_cmp(&b.max_score)
1606 .unwrap_or(Ordering::Equal)
1607 });
1608
1609 let mut prefix_sums = Vec::with_capacity(cursors.len());
1610 let mut cumsum = 0.0f32;
1611 for c in &cursors {
1612 cumsum += c.max_score;
1613 prefix_sums.push(cumsum);
1614 }
1615
1616 debug!(
1617 "Creating SparseMaxScoreExecutor: num_terms={}, k={}, total_upper={:.4}, heap_factor={:.2}",
1618 cursors.len(),
1619 k,
1620 cumsum,
1621 heap_factor
1622 );
1623
1624 Self {
1625 sparse_index,
1626 cursors,
1627 prefix_sums,
1628 collector: ScoreCollector::new(k),
1629 heap_factor: heap_factor.clamp(0.0, 1.0),
1630 }
1631 }
1632
1633 #[inline]
1634 fn find_partition(&self) -> usize {
1635 let threshold = self.collector.threshold() * self.heap_factor;
1636 self.prefix_sums.partition_point(|&sum| sum <= threshold)
1638 }
1639
1640 pub async fn execute(mut self) -> crate::Result<Vec<ScoredDoc>> {
1642 if self.cursors.is_empty() {
1643 return Ok(Vec::new());
1644 }
1645 bms_execute_loop!(self, ensure_block_loaded, advance, seek, .await)
1646 }
1647
1648 pub fn execute_sync(mut self) -> crate::Result<Vec<ScoredDoc>> {
1651 if self.cursors.is_empty() {
1652 return Ok(Vec::new());
1653 }
1654 bms_execute_loop!(self, ensure_block_loaded_sync, advance_sync, seek_sync,)
1655 }
1656}
1657
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_score_collector_basic() {
        let mut collector = ScoreCollector::new(3);
        // A fresh collector has no threshold yet.
        assert_eq!(collector.threshold(), 0.0);

        collector.insert(1, 1.0);
        collector.insert(2, 2.0);
        collector.insert(3, 3.0);
        // Full at k = 3: the threshold is the smallest retained score.
        assert_eq!(collector.threshold(), 1.0);

        // A better score pushes the previous minimum out.
        collector.insert(4, 4.0);
        assert_eq!(collector.threshold(), 2.0);

        let results = collector.into_sorted_results();
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].0, 4);
        assert_eq!(results[1].0, 3);
        assert_eq!(results[2].0, 2);
    }

    #[test]
    fn test_score_collector_threshold() {
        let mut collector = ScoreCollector::new(2);

        collector.insert(1, 5.0);
        collector.insert(2, 3.0);
        assert_eq!(collector.threshold(), 3.0);

        // Below the current threshold: rejected.
        assert!(!collector.would_enter(2.0));
        assert!(!collector.insert(3, 2.0));

        // Above the current threshold: accepted, and the threshold rises.
        assert!(collector.would_enter(4.0));
        assert!(collector.insert(4, 4.0));
        assert_eq!(collector.threshold(), 4.0);
    }

    #[test]
    fn test_heap_entry_ordering() {
        // `HeapEntry`'s `Ord` reverses the score comparison, so the max-heap
        // `BinaryHeap` pops the lowest score first.
        let mut heap = BinaryHeap::new();
        heap.push(HeapEntry {
            doc_id: 1,
            score: 3.0,
            ordinal: 0,
        });
        heap.push(HeapEntry {
            doc_id: 2,
            score: 1.0,
            ordinal: 0,
        });
        heap.push(HeapEntry {
            doc_id: 3,
            score: 2.0,
            ordinal: 0,
        });

        assert_eq!(heap.pop().unwrap().score, 1.0);
        assert_eq!(heap.pop().unwrap().score, 2.0);
        assert_eq!(heap.pop().unwrap().score, 3.0);
    }

    #[test]
    fn test_heap_entry_tie_break_on_doc_id() {
        // Equal scores fall back to comparing doc_id. The larger doc_id
        // compares greater, so it is popped first.
        let mut heap = BinaryHeap::new();
        heap.push(HeapEntry {
            doc_id: 5,
            score: 1.0,
            ordinal: 0,
        });
        heap.push(HeapEntry {
            doc_id: 7,
            score: 1.0,
            ordinal: 0,
        });

        assert_eq!(heap.pop().unwrap().doc_id, 7);
        assert_eq!(heap.pop().unwrap().doc_id, 5);
    }
}