1use std::cmp::Ordering;
11use std::collections::BinaryHeap;
12use std::sync::Arc;
13
14use log::{debug, trace};
15
16use crate::DocId;
17use crate::structures::BlockSparsePostingList;
18
/// Cursor-style interface over a scored posting list, used by the MaxScore /
/// block-max executors below.
///
/// Exhaustion is signalled by the sentinel doc id `u32::MAX` (see
/// `is_exhausted`); `doc()` must return that value once the iterator has run
/// past its last posting.
pub trait ScoringIterator {
    /// Current document id, or `u32::MAX` when exhausted.
    fn doc(&self) -> DocId;

    /// Ordinal (sub-field index) of the current posting.
    /// Defaults to 0 for implementations without per-ordinal postings.
    fn ordinal(&mut self) -> u16 {
        0
    }

    /// Move to the next posting; returns the new `doc()`.
    fn advance(&mut self) -> DocId;

    /// Move to the first posting with doc id >= `target`; returns the new `doc()`.
    fn seek(&mut self, target: DocId) -> DocId;

    /// True once the cursor has passed its last posting (sentinel doc id).
    fn is_exhausted(&self) -> bool {
        self.doc() == u32::MAX
    }

    /// Score contribution of the current posting.
    fn score(&self) -> f32;

    /// Upper bound on `score()` over the whole posting list.
    fn max_score(&self) -> f32;

    /// Upper bound on `score()` within the current block only
    /// (tighter than `max_score()` for block-max pruning).
    fn current_block_max_score(&self) -> f32;

    /// Jump past the current block; returns the new `doc()`.
    /// Default falls back to a single `advance()` for implementations
    /// without block structure.
    fn skip_to_next_block(&mut self) -> DocId {
        self.advance()
    }
}
58
59#[derive(Clone, Copy)]
61pub struct HeapEntry {
62 pub doc_id: DocId,
63 pub score: f32,
64 pub ordinal: u16,
65}
66
67impl PartialEq for HeapEntry {
68 fn eq(&self, other: &Self) -> bool {
69 self.score == other.score && self.doc_id == other.doc_id
70 }
71}
72
73impl Eq for HeapEntry {}
74
75impl Ord for HeapEntry {
76 fn cmp(&self, other: &Self) -> Ordering {
77 other
79 .score
80 .partial_cmp(&self.score)
81 .unwrap_or(Ordering::Equal)
82 .then_with(|| self.doc_id.cmp(&other.doc_id))
83 }
84}
85
86impl PartialOrd for HeapEntry {
87 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
88 Some(self.cmp(other))
89 }
90}
91
/// Bounded top-k collector backed by a min-heap of [`HeapEntry`].
pub struct ScoreCollector {
    // Min-heap by score (via HeapEntry's reversed Ord); top is the weakest
    // entry currently kept.
    heap: BinaryHeap<HeapEntry>,
    // Number of results to keep.
    pub k: usize,
    // Cached copy of the weakest kept score once the heap holds k entries;
    // 0.0 before that. Kept in sync by update_threshold().
    cached_threshold: f32,
}
105
106impl ScoreCollector {
107 pub fn new(k: usize) -> Self {
109 let capacity = k.saturating_add(1).min(1_000_000);
111 Self {
112 heap: BinaryHeap::with_capacity(capacity),
113 k,
114 cached_threshold: 0.0,
115 }
116 }
117
118 #[inline]
120 pub fn threshold(&self) -> f32 {
121 self.cached_threshold
122 }
123
124 #[inline]
126 fn update_threshold(&mut self) {
127 self.cached_threshold = if self.heap.len() >= self.k {
128 self.heap.peek().map(|e| e.score).unwrap_or(0.0)
129 } else {
130 0.0
131 };
132 }
133
134 #[inline]
137 pub fn insert(&mut self, doc_id: DocId, score: f32) -> bool {
138 self.insert_with_ordinal(doc_id, score, 0)
139 }
140
141 #[inline]
144 pub fn insert_with_ordinal(&mut self, doc_id: DocId, score: f32, ordinal: u16) -> bool {
145 if self.heap.len() < self.k {
146 self.heap.push(HeapEntry {
147 doc_id,
148 score,
149 ordinal,
150 });
151 self.update_threshold();
152 true
153 } else if score > self.cached_threshold {
154 self.heap.push(HeapEntry {
155 doc_id,
156 score,
157 ordinal,
158 });
159 self.heap.pop(); self.update_threshold();
161 true
162 } else {
163 false
164 }
165 }
166
167 #[inline]
169 pub fn would_enter(&self, score: f32) -> bool {
170 self.heap.len() < self.k || score > self.cached_threshold
171 }
172
173 #[inline]
175 pub fn len(&self) -> usize {
176 self.heap.len()
177 }
178
179 #[inline]
181 pub fn is_empty(&self) -> bool {
182 self.heap.is_empty()
183 }
184
185 pub fn into_sorted_results(self) -> Vec<(DocId, f32, u16)> {
187 let heap_vec = self.heap.into_vec();
188 let mut results: Vec<(DocId, f32, u16)> = Vec::with_capacity(heap_vec.len());
189 for e in heap_vec {
190 results.push((e.doc_id, e.score, e.ordinal));
191 }
192
193 results.sort_by(|a, b| {
195 b.1.partial_cmp(&a.1)
196 .unwrap_or(Ordering::Equal)
197 .then_with(|| a.0.cmp(&b.0))
198 });
199
200 results
201 }
202}
203
/// Final scored result returned by the executors in this module.
#[derive(Debug, Clone, Copy)]
pub struct ScoredDoc {
    pub doc_id: DocId,
    pub score: f32,
    // Sub-field ordinal the score belongs to (0 when ordinals are unused).
    pub ordinal: u16,
}
212
/// Generic MaxScore / block-max top-k executor over any set of
/// [`ScoringIterator`]s.
pub struct MaxScoreExecutor<S: ScoringIterator> {
    // Scorers sorted ascending by max_score(); the prefix [0, partition) is
    // the "non-essential" set that alone cannot beat the threshold.
    scorers: Vec<S>,
    // prefix_sums[i] = sum of max_score() for scorers[0..=i].
    prefix_sums: Vec<f32>,
    // Top-k accumulator providing the pruning threshold.
    collector: ScoreCollector,
    // Multiplier in (0.0, 1.0] applied to the threshold; values < 1.0 trade
    // exactness for more aggressive (approximate) pruning.
    heap_factor: f32,
}
237
impl<S: ScoringIterator> MaxScoreExecutor<S> {
    /// Exact top-k execution (heap factor 1.0).
    pub fn new(scorers: Vec<S>, k: usize) -> Self {
        Self::with_heap_factor(scorers, k, 1.0)
    }

    /// Build an executor; `heap_factor` < 1.0 makes pruning approximate.
    pub fn with_heap_factor(mut scorers: Vec<S>, k: usize, heap_factor: f32) -> Self {
        // Sort ascending by per-term upper bound so the cheapest terms form
        // the non-essential prefix. NaN upper bounds compare as Equal.
        scorers.sort_by(|a, b| {
            a.max_score()
                .partial_cmp(&b.max_score())
                .unwrap_or(Ordering::Equal)
        });

        // prefix_sums[i] = cumulative upper bound of scorers[0..=i].
        let mut prefix_sums = Vec::with_capacity(scorers.len());
        let mut cumsum = 0.0f32;
        for s in &scorers {
            cumsum += s.max_score();
            prefix_sums.push(cumsum);
        }

        debug!(
            "Creating MaxScoreExecutor: num_scorers={}, k={}, total_upper={:.4}, heap_factor={:.2}",
            scorers.len(),
            k,
            cumsum,
            heap_factor
        );

        Self {
            scorers,
            prefix_sums,
            collector: ScoreCollector::new(k),
            heap_factor: heap_factor.clamp(0.0, 1.0),
        }
    }

    /// First index whose cumulative upper bound exceeds the (adjusted)
    /// threshold; scorers below it are non-essential (cannot alone reach
    /// the top-k).
    #[inline]
    fn find_partition(&self) -> usize {
        let threshold = self.collector.threshold() * self.heap_factor;
        self.prefix_sums.partition_point(|&sum| sum <= threshold)
    }

    /// Run the MaxScore loop to completion and return the sorted top-k.
    ///
    /// Per iteration: pick the minimum doc among essential scorers, then try
    /// three increasingly precise prunes (term upper bounds, block-max
    /// bounds, actual essential scores) before seeking the non-essential
    /// scorers into the doc and scoring per ordinal.
    pub fn execute(mut self) -> Vec<ScoredDoc> {
        if self.scorers.is_empty() {
            debug!("MaxScoreExecutor: no scorers, returning empty results");
            return Vec::new();
        }

        let n = self.scorers.len();
        let mut docs_scored = 0u64;
        let mut docs_skipped = 0u64;
        let mut blocks_skipped = 0u64;
        let mut conjunction_skipped = 0u64;

        // Scratch buffer of (ordinal, score) pairs for the current doc,
        // reused across iterations to avoid reallocation.
        let mut ordinal_scores: Vec<(u16, f32)> = Vec::with_capacity(n * 2);

        loop {
            let partition = self.find_partition();

            // All terms non-essential: no remaining doc can beat the
            // threshold even if every term matched it.
            if partition >= n {
                debug!("BlockMaxScore: all terms non-essential, early termination");
                break;
            }

            // Find the smallest current doc among essential scorers, and in
            // the same pass accumulate the term-level and block-level upper
            // bounds of the scorers positioned on that doc.
            let mut min_doc = u32::MAX;
            let mut present_upper = 0.0f32;
            let mut block_max_sum = 0.0f32;
            for i in partition..n {
                let doc = self.scorers[i].doc();
                if doc < min_doc {
                    min_doc = doc;
                    present_upper = self.scorers[i].max_score();
                    block_max_sum = self.scorers[i].current_block_max_score();
                } else if doc == min_doc {
                    present_upper += self.scorers[i].max_score();
                    block_max_sum += self.scorers[i].current_block_max_score();
                }
            }

            // All essential scorers exhausted.
            if min_doc == u32::MAX {
                break; }

            let non_essential_upper = if partition > 0 {
                self.prefix_sums[partition - 1]
            } else {
                0.0
            };
            let adjusted_threshold = self.collector.threshold() * self.heap_factor;

            // Prune 1 (cheapest): even the full term upper bounds of the
            // scorers on min_doc plus all non-essential terms cannot beat
            // the threshold — advance past the doc without decoding scores.
            if self.collector.len() >= self.collector.k
                && present_upper + non_essential_upper <= adjusted_threshold
            {
                for i in partition..n {
                    if self.scorers[i].doc() == min_doc {
                        self.scorers[i].advance();
                    }
                }
                conjunction_skipped += 1;
                continue;
            }

            // Prune 2: tighter block-max bounds still cannot beat the
            // threshold — skip whole blocks for the scorers on min_doc.
            if self.collector.len() >= self.collector.k
                && block_max_sum + non_essential_upper <= adjusted_threshold
            {
                for i in partition..n {
                    if self.scorers[i].doc() == min_doc {
                        self.scorers[i].skip_to_next_block();
                    }
                }
                blocks_skipped += 1;
                continue;
            }

            ordinal_scores.clear();

            // Collect actual (ordinal, score) contributions from every
            // essential scorer positioned on min_doc; a scorer may emit
            // several postings (one per ordinal) for the same doc.
            for i in partition..n {
                if self.scorers[i].doc() == min_doc {
                    while self.scorers[i].doc() == min_doc {
                        ordinal_scores.push((self.scorers[i].ordinal(), self.scorers[i].score()));
                        self.scorers[i].advance();
                    }
                }
            }

            let essential_total: f32 = ordinal_scores.iter().map(|(_, s)| *s).sum();

            // Prune 3: real essential scores plus non-essential upper bound
            // still below threshold — no need to seek non-essential scorers.
            if self.collector.len() >= self.collector.k
                && essential_total + non_essential_upper <= adjusted_threshold
            {
                docs_skipped += 1;
                continue;
            }

            // Seek every non-essential scorer to min_doc and add its actual
            // contributions.
            for i in 0..partition {
                let doc = self.scorers[i].seek(min_doc);
                if doc == min_doc {
                    while self.scorers[i].doc() == min_doc {
                        ordinal_scores.push((self.scorers[i].ordinal(), self.scorers[i].score()));
                        self.scorers[i].advance();
                    }
                }
            }

            // Group contributions by ordinal (sort + linear sweep) and offer
            // one candidate per (doc, ordinal) to the collector.
            ordinal_scores.sort_unstable_by_key(|(ord, _)| *ord);
            let mut j = 0;
            while j < ordinal_scores.len() {
                let current_ord = ordinal_scores[j].0;
                let mut score = 0.0f32;
                while j < ordinal_scores.len() && ordinal_scores[j].0 == current_ord {
                    score += ordinal_scores[j].1;
                    j += 1;
                }

                trace!(
                    "Doc {}: ordinal={}, score={:.4}, threshold={:.4}",
                    min_doc, current_ord, score, adjusted_threshold
                );

                if self
                    .collector
                    .insert_with_ordinal(min_doc, score, current_ord)
                {
                    docs_scored += 1;
                } else {
                    docs_skipped += 1;
                }
            }
        }

        let results: Vec<ScoredDoc> = self
            .collector
            .into_sorted_results()
            .into_iter()
            .map(|(doc_id, score, ordinal)| ScoredDoc {
                doc_id,
                score,
                ordinal,
            })
            .collect();

        debug!(
            "MaxScoreExecutor completed: scored={}, skipped={}, blocks_skipped={}, conjunction_skipped={}, returned={}, top_score={:.4}",
            docs_scored,
            docs_skipped,
            blocks_skipped,
            conjunction_skipped,
            results.len(),
            results.first().map(|r| r.score).unwrap_or(0.0)
        );

        results
    }
}
467
/// BM25 scorer over a single text term's block posting list.
pub struct TextTermScorer {
    // Owned-list iterator; NOTE(review): the 'static lifetime presumably
    // means the iterator owns (or references leaked/owned) posting data —
    // confirm against BlockPostingIterator's definition.
    iter: crate::structures::BlockPostingIterator<'static>,
    // Inverse document frequency of the term.
    idf: f32,
    // Average field length used by the BM25 normalization.
    avg_field_len: f32,
    // Precomputed BM25 upper bound over the whole list (from max_tf).
    max_score: f32,
}
482
483impl TextTermScorer {
484 pub fn new(
486 posting_list: crate::structures::BlockPostingList,
487 idf: f32,
488 avg_field_len: f32,
489 ) -> Self {
490 let max_tf = posting_list.max_tf() as f32;
492 let doc_count = posting_list.doc_count();
493 let max_score = super::bm25_upper_bound(max_tf.max(1.0), idf);
494
495 debug!(
496 "Created TextTermScorer: doc_count={}, max_tf={:.0}, idf={:.4}, avg_field_len={:.2}, max_score={:.4}",
497 doc_count, max_tf, idf, avg_field_len, max_score
498 );
499
500 Self {
501 iter: posting_list.into_iterator(),
502 idf,
503 avg_field_len,
504 max_score,
505 }
506 }
507}
508
impl ScoringIterator for TextTermScorer {
    #[inline]
    fn doc(&self) -> DocId {
        self.iter.doc()
    }

    #[inline]
    fn advance(&mut self) -> DocId {
        self.iter.advance()
    }

    #[inline]
    fn seek(&mut self, target: DocId) -> DocId {
        self.iter.seek(target)
    }

    #[inline]
    fn score(&self) -> f32 {
        let tf = self.iter.term_freq() as f32;
        // NOTE(review): `tf` is passed both as term frequency and as the
        // third argument (field/doc length?) of bm25_score — verify against
        // bm25_score's signature; if the third parameter is document length
        // this is using tf as a stand-in.
        super::bm25_score(tf, self.idf, tf, self.avg_field_len)
    }

    #[inline]
    fn max_score(&self) -> f32 {
        self.max_score
    }

    #[inline]
    fn current_block_max_score(&self) -> f32 {
        // Tighter per-block bound from the block's max tf (floored at 1).
        let block_max_tf = self.iter.current_block_max_tf() as f32;
        super::bm25_upper_bound(block_max_tf.max(1.0), self.idf)
    }

    #[inline]
    fn skip_to_next_block(&mut self) -> DocId {
        self.iter.skip_to_next_block()
    }
}
549
/// Dot-product scorer for one sparse-vector query dimension over a borrowed
/// block sparse posting list.
pub struct SparseTermScorer<'a> {
    iter: crate::structures::BlockSparsePostingIterator<'a>,
    // Signed query weight; score() = query_weight * stored weight.
    query_weight: f32,
    // |query_weight|, used for upper-bound computations.
    abs_query_weight: f32,
    // abs_query_weight * list-wide max stored weight.
    max_score: f32,
}
563
564impl<'a> SparseTermScorer<'a> {
565 pub fn new(posting_list: &'a BlockSparsePostingList, query_weight: f32) -> Self {
570 let abs_qw = query_weight.abs();
573 let max_score = abs_qw * posting_list.global_max_weight();
574 Self {
575 iter: posting_list.iterator(),
576 query_weight,
577 abs_query_weight: abs_qw,
578 max_score,
579 }
580 }
581
582 pub fn from_arc(posting_list: &'a Arc<BlockSparsePostingList>, query_weight: f32) -> Self {
584 Self::new(posting_list.as_ref(), query_weight)
585 }
586}
587
impl ScoringIterator for SparseTermScorer<'_> {
    #[inline]
    fn doc(&self) -> DocId {
        self.iter.doc()
    }

    #[inline]
    fn ordinal(&mut self) -> u16 {
        self.iter.ordinal()
    }

    #[inline]
    fn advance(&mut self) -> DocId {
        self.iter.advance()
    }

    #[inline]
    fn seek(&mut self, target: DocId) -> DocId {
        self.iter.seek(target)
    }

    #[inline]
    fn score(&self) -> f32 {
        // Signed contribution: negative query weights are allowed.
        self.query_weight * self.iter.weight()
    }

    #[inline]
    fn max_score(&self) -> f32 {
        self.max_score
    }

    #[inline]
    fn current_block_max_score(&self) -> f32 {
        // Per-block bound computed from |query_weight| by the iterator.
        self.iter
            .current_block_max_contribution(self.abs_query_weight)
    }

    #[inline]
    fn skip_to_next_block(&mut self) -> DocId {
        self.iter.skip_to_next_block()
    }
}
631
/// Block-max-pruning (BMP-style) executor: processes blocks in descending
/// contribution order via a priority queue rather than document-at-a-time.
pub struct BmpExecutor<'a> {
    sparse_index: &'a crate::segment::SparseIndex,
    // (dimension id, signed query weight) pairs.
    query_terms: Vec<(u32, f32)>,
    // Number of results to return.
    k: usize,
    // Threshold multiplier in [0.0, 1.0]; < 1.0 enables approximate pruning.
    heap_factor: f32,
}
653
// Number of skip-list blocks grouped into one superblock for BMP pruning.
const BMP_SUPERBLOCK_SIZE: usize = 8;

// Number of superblocks grouped into one mega-block (priority-queue unit).
const BMP_MEGABLOCK_SIZE: usize = 16;
662
/// A run of consecutive blocks of one term, with their summed upper bound.
struct BmpSuperBlock {
    // Sum of |query_weight| * block max_weight over the covered blocks.
    contribution: f32,
    // Index of the first covered block within the term's skip list.
    block_start: usize,
    // Number of blocks covered (<= BMP_SUPERBLOCK_SIZE).
    block_count: usize,
}
672
/// Priority-queue entry covering a run of superblocks of one term.
///
/// Ordering (and equality) is keyed on `contribution` alone, so the
/// `BinaryHeap` pops the highest-contribution mega-block first.
struct BmpMegaBlockEntry {
    // Summed upper-bound contribution of the covered superblocks.
    contribution: f32,
    // Index into the query-term arrays.
    term_idx: usize,
    // First covered superblock and how many follow (<= BMP_MEGABLOCK_SIZE).
    sb_start: usize,
    sb_count: usize,
}

impl PartialEq for BmpMegaBlockEntry {
    fn eq(&self, other: &Self) -> bool {
        // Equality mirrors Ord: contribution only (NaN is never equal).
        self.contribution.partial_cmp(&other.contribution) == Some(Ordering::Equal)
    }
}

impl Eq for BmpMegaBlockEntry {}

impl PartialOrd for BmpMegaBlockEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for BmpMegaBlockEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        // Natural (max-heap) order on contribution; NaN compares as Equal.
        match self.contribution.partial_cmp(&other.contribution) {
            Some(ordering) => ordering,
            None => Ordering::Equal,
        }
    }
}
707
impl<'a> BmpExecutor<'a> {
    /// Build a BMP executor; `heap_factor` is clamped into [0.0, 1.0].
    pub fn new(
        sparse_index: &'a crate::segment::SparseIndex,
        query_terms: Vec<(u32, f32)>,
        k: usize,
        heap_factor: f32,
    ) -> Self {
        Self {
            sparse_index,
            query_terms,
            k,
            heap_factor: heap_factor.clamp(0.0, 1.0),
        }
    }

    /// Run block-max pruning over all query terms and return up to `k`
    /// results sorted by descending score.
    ///
    /// Overview: build per-term superblocks (with summed upper bounds),
    /// group them into mega-blocks in a max-heap, then drain the heap from
    /// highest contribution down. `total_remaining` tracks the upper bound
    /// of everything still in the queue, enabling early termination once the
    /// top-k threshold can no longer be beaten. Scores accumulate either in
    /// a flat array indexed by (doc - global_min_doc) for ordinal-0 postings
    /// (when the doc range is small enough) or in a hash map keyed by
    /// (doc_id << 16 | ordinal). Final results are rebuilt from these
    /// accumulators; `top_k` is used only to maintain the pruning threshold.
    pub async fn execute(self) -> crate::Result<Vec<ScoredDoc>> {
        use rustc_hash::FxHashMap;

        if self.query_terms.is_empty() {
            return Ok(Vec::new());
        }

        let num_terms = self.query_terms.len();
        let si = self.sparse_index;

        // Phase 1: per-term superblock construction from skip entries.
        let mut term_superblocks: Vec<Vec<BmpSuperBlock>> = Vec::with_capacity(num_terms);
        let mut term_skip_starts: Vec<usize> = Vec::with_capacity(num_terms);
        let mut global_min_doc = u32::MAX;
        let mut global_max_doc = 0u32;
        let mut total_remaining = 0.0f32;

        for &(dim_id, qw) in &self.query_terms {
            let mut term_skip_start = 0usize;
            let mut superblocks = Vec::new();

            let abs_qw = qw.abs();
            if let Some((skip_start, skip_count, _global_max)) = si.get_skip_range(dim_id) {
                term_skip_start = skip_start;
                let mut sb_start = 0;
                while sb_start < skip_count {
                    let sb_count = (skip_count - sb_start).min(BMP_SUPERBLOCK_SIZE);
                    let mut sb_contribution = 0.0f32;
                    for j in 0..sb_count {
                        let skip = si.read_skip_entry(skip_start + sb_start + j);
                        sb_contribution += abs_qw * skip.max_weight;
                        global_min_doc = global_min_doc.min(skip.first_doc);
                        global_max_doc = global_max_doc.max(skip.last_doc);
                    }
                    total_remaining += sb_contribution;
                    superblocks.push(BmpSuperBlock {
                        contribution: sb_contribution,
                        block_start: sb_start,
                        block_count: sb_count,
                    });
                    sb_start += sb_count;
                }
            }
            term_skip_starts.push(term_skip_start);
            term_superblocks.push(superblocks);
        }

        // Phase 2: group superblocks into mega-blocks and heap them by
        // summed contribution (max first).
        let mut mega_queue: BinaryHeap<BmpMegaBlockEntry> = BinaryHeap::new();
        for (term_idx, superblocks) in term_superblocks.iter().enumerate() {
            let mut mb_start = 0;
            while mb_start < superblocks.len() {
                let mb_count = (superblocks.len() - mb_start).min(BMP_MEGABLOCK_SIZE);
                let mb_contribution: f32 = superblocks[mb_start..mb_start + mb_count]
                    .iter()
                    .map(|sb| sb.contribution)
                    .sum();
                mega_queue.push(BmpMegaBlockEntry {
                    contribution: mb_contribution,
                    term_idx,
                    sb_start: mb_start,
                    sb_count: mb_count,
                });
                mb_start += mb_count;
            }
        }

        // Choose the accumulator layout: a flat f32 array when the observed
        // doc-id range is small enough (<= 256K entries), else hash-map only.
        let doc_range = if global_max_doc >= global_min_doc {
            (global_max_doc - global_min_doc + 1) as usize
        } else {
            0
        };
        let use_flat = doc_range > 0 && doc_range <= 256 * 1024;
        let mut flat_scores: Vec<f32> = if use_flat {
            vec![0.0; doc_range]
        } else {
            Vec::new()
        };
        // Doc ids whose flat slot became non-zero (insertion order).
        let mut dirty: Vec<u32> = if use_flat {
            Vec::with_capacity(4096)
        } else {
            Vec::new()
        };
        // Accumulators keyed by (doc_id << 16 | ordinal) for non-zero
        // ordinals (and for everything when !use_flat).
        let mut multi_ord_accumulators: FxHashMap<u64, f32> = FxHashMap::default();

        let mut blocks_processed = 0u64;
        let mut blocks_skipped = 0u64;

        // Threshold-only collector; final results come from the accumulators.
        let mut top_k = ScoreCollector::new(self.k);

        // Reusable decode buffers.
        let mut doc_ids_buf: Vec<u32> = Vec::with_capacity(256);
        let mut weights_buf: Vec<f32> = Vec::with_capacity(256);
        let mut ordinals_buf: Vec<u16> = Vec::with_capacity(256);

        // Warm-up: don't early-terminate until min(k, num_terms) distinct
        // terms have contributed, so the threshold is meaningful.
        let mut terms_warmed = vec![false; num_terms];
        let mut warmup_remaining = self.k.min(num_terms);

        while let Some(mega) = mega_queue.pop() {
            // This mega-block's bound leaves the "remaining" pool whether we
            // process or prune it.
            total_remaining -= mega.contribution;

            if !terms_warmed[mega.term_idx] {
                terms_warmed[mega.term_idx] = true;
                warmup_remaining = warmup_remaining.saturating_sub(1);
            }

            if warmup_remaining == 0 {
                let adjusted_threshold = top_k.threshold() * self.heap_factor;
                if top_k.len() >= self.k && total_remaining <= adjusted_threshold {
                    // Everything still queued is provably below threshold.
                    let remaining_blocks: u64 = mega_queue
                        .iter()
                        .map(|m| {
                            let sbs =
                                &term_superblocks[m.term_idx][m.sb_start..m.sb_start + m.sb_count];
                            sbs.iter().map(|sb| sb.block_count as u64).sum::<u64>()
                        })
                        .sum();
                    blocks_skipped += remaining_blocks;
                    debug!(
                        "BMP early termination after {} blocks: remaining={:.4} <= threshold={:.4}",
                        blocks_processed, total_remaining, adjusted_threshold
                    );
                    break;
                }
            }

            let dim_id = self.query_terms[mega.term_idx].0;
            let qw = self.query_terms[mega.term_idx].1;
            let abs_qw = qw.abs();
            let skip_start = term_skip_starts[mega.term_idx];

            for sb in term_superblocks[mega.term_idx]
                .iter()
                .skip(mega.sb_start)
                .take(mega.sb_count)
            {
                // Superblock-level prune.
                if top_k.len() >= self.k {
                    let adjusted_threshold = top_k.threshold() * self.heap_factor;
                    if sb.contribution + total_remaining <= adjusted_threshold {
                        blocks_skipped += sb.block_count as u64;
                        continue;
                    }
                }

                let sb_blocks = si
                    .get_blocks_range(dim_id, sb.block_start, sb.block_count)
                    .await?;

                // Threshold snapshot taken once per superblock for the
                // per-block prune below.
                let adjusted_threshold2 = top_k.threshold() * self.heap_factor;

                // Remember where this superblock's dirty entries begin so we
                // can offer only the newly touched docs to top_k afterwards.
                let dirty_start = dirty.len();

                for (blk_offset, block) in sb_blocks.into_iter().enumerate() {
                    let blk_idx = sb.block_start + blk_offset;

                    // Block-level prune using the skip entry's max weight.
                    if top_k.len() >= self.k {
                        let skip = si.read_skip_entry(skip_start + blk_idx);
                        let blk_contrib = abs_qw * skip.max_weight;
                        if blk_contrib + total_remaining <= adjusted_threshold2 {
                            blocks_skipped += 1;
                            continue;
                        }
                    }

                    block.decode_doc_ids_into(&mut doc_ids_buf);

                    if block.header.ordinal_bits == 0 && use_flat {
                        // Fast path: ordinal-free block accumulated directly
                        // into the flat array (dirty list maintained inside).
                        block.accumulate_scored_weights(
                            qw,
                            &doc_ids_buf,
                            &mut flat_scores,
                            global_min_doc,
                            &mut dirty,
                        );
                    } else {
                        block.decode_scored_weights_into(qw, &mut weights_buf);
                        let count = block.header.count as usize;

                        block.decode_ordinals_into(&mut ordinals_buf);
                        if use_flat {
                            for i in 0..count {
                                let doc_id = doc_ids_buf[i];
                                let ordinal = ordinals_buf[i];
                                let score_contribution = weights_buf[i];

                                if ordinal == 0 {
                                    let off = (doc_id - global_min_doc) as usize;
                                    // NOTE(review): a slot that sums back to
                                    // exactly 0.0 could be pushed to `dirty`
                                    // twice; harmless for top_k but worth
                                    // confirming intent.
                                    if flat_scores[off] == 0.0 {
                                        dirty.push(doc_id);
                                    }
                                    flat_scores[off] += score_contribution;
                                } else {
                                    let key = (doc_id as u64) << 16 | ordinal as u64;
                                    let acc = multi_ord_accumulators.entry(key).or_insert(0.0);
                                    *acc += score_contribution;
                                    // Partial sums feed the threshold only;
                                    // final scores are rebuilt below.
                                    top_k.insert_with_ordinal(doc_id, *acc, ordinal);
                                }
                            }
                        } else {
                            for i in 0..count {
                                let key = (doc_ids_buf[i] as u64) << 16 | ordinals_buf[i] as u64;
                                let acc = multi_ord_accumulators.entry(key).or_insert(0.0);
                                *acc += weights_buf[i];
                                top_k.insert_with_ordinal(doc_ids_buf[i], *acc, ordinals_buf[i]);
                            }
                        }
                    }

                    blocks_processed += 1;
                }

                // Feed this superblock's newly dirtied flat scores into the
                // threshold collector.
                for &doc_id in &dirty[dirty_start..] {
                    let off = (doc_id - global_min_doc) as usize;
                    top_k.insert_with_ordinal(doc_id, flat_scores[off], 0);
                }
            }
        }

        // Phase 3: materialize results from the accumulators (not top_k).
        let mut scored: Vec<ScoredDoc> = Vec::new();

        let num_accumulators = if use_flat {
            scored.reserve(dirty.len() + multi_ord_accumulators.len());
            for &doc_id in &dirty {
                let off = (doc_id - global_min_doc) as usize;
                let score = flat_scores[off];
                if score > 0.0 {
                    scored.push(ScoredDoc {
                        doc_id,
                        score,
                        ordinal: 0,
                    });
                }
            }
            dirty.len() + multi_ord_accumulators.len()
        } else {
            multi_ord_accumulators.len()
        };

        scored.extend(
            multi_ord_accumulators
                .into_iter()
                .map(|(key, score)| ScoredDoc {
                    doc_id: (key >> 16) as DocId,
                    score,
                    ordinal: (key & 0xFFFF) as u16,
                }),
        );

        scored.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
        scored.truncate(self.k);
        let results = scored;

        debug!(
            "BmpExecutor completed: blocks_processed={}, blocks_skipped={}, accumulators={}, flat={}, returned={}, top_score={:.4}",
            blocks_processed,
            blocks_skipped,
            num_accumulators,
            use_flat,
            results.len(),
            results.first().map(|r| r.score).unwrap_or(0.0)
        );

        Ok(results)
    }
}
1026
/// MaxScore executor over sparse postings with lazily loaded blocks
/// (cursors fetch block data on demand via the sparse index).
pub struct SparseMaxScoreExecutor<'a> {
    sparse_index: &'a crate::segment::SparseIndex,
    // Cursors sorted ascending by max_score (same layout as MaxScoreExecutor).
    cursors: Vec<LazyTermCursor>,
    // prefix_sums[i] = sum of max_score for cursors[0..=i].
    prefix_sums: Vec<f32>,
    collector: ScoreCollector,
    // Threshold multiplier in [0.0, 1.0].
    heap_factor: f32,
}
1045
/// Cursor over one query dimension's posting blocks, decoding a block only
/// when its contents are actually needed.
struct LazyTermCursor {
    // Signed query weight and its absolute value (for upper bounds).
    query_weight: f32,
    abs_query_weight: f32,
    // abs_query_weight * global max stored weight.
    max_score: f32,
    // Range of this term's entries in the shared skip list.
    skip_start: usize,
    skip_count: usize,
    // Byte offset of this term's block data in the segment.
    block_data_offset: u64,
    // Index of the current block within [0, skip_count).
    block_idx: usize,
    // Decoded contents of the current block (valid when block_loaded).
    doc_ids: Vec<u32>,
    ordinals: Vec<u16>,
    weights: Vec<f32>,
    // Position within the decoded block.
    pos: usize,
    // Whether doc_ids/ordinals/weights reflect block_idx.
    block_loaded: bool,
    // Set once the cursor has run past its last block.
    exhausted: bool,
}
1070
impl LazyTermCursor {
    /// Build a cursor positioned before block 0; a term with no blocks
    /// starts exhausted.
    fn new(
        query_weight: f32,
        skip_start: usize,
        skip_count: usize,
        global_max_weight: f32,
        block_data_offset: u64,
    ) -> Self {
        let exhausted = skip_count == 0;
        let abs_qw = query_weight.abs();
        Self {
            query_weight,
            abs_query_weight: abs_qw,
            max_score: abs_qw * global_max_weight,
            skip_start,
            skip_count,
            block_data_offset,
            block_idx: 0,
            doc_ids: Vec::with_capacity(256),
            ordinals: Vec::with_capacity(256),
            weights: Vec::with_capacity(256),
            pos: 0,
            block_loaded: false,
            exhausted,
        }
    }

    /// Decode a fetched block into the reusable buffers and reset position.
    #[inline]
    fn decode_block(&mut self, block: crate::structures::SparseBlock) {
        block.decode_doc_ids_into(&mut self.doc_ids);
        block.decode_ordinals_into(&mut self.ordinals);
        block.decode_scored_weights_into(self.query_weight, &mut self.weights);
        self.pos = 0;
        self.block_loaded = true;
    }

    /// Common tail of the sync/async block loaders: decode on Some, mark
    /// exhausted on None. Returns whether a block is now loaded.
    #[inline]
    fn handle_block_result(
        &mut self,
        block: Option<crate::structures::SparseBlock>,
    ) -> crate::Result<bool> {
        match block {
            Some(b) => {
                self.decode_block(b);
                Ok(true)
            }
            None => {
                self.exhausted = true;
                Ok(false)
            }
        }
    }

    /// Step one posting forward within the loaded block, rolling over to the
    /// next (not yet loaded) block at the end. Returns the new doc().
    #[inline]
    fn advance_pos(&mut self) -> DocId {
        self.pos += 1;
        if self.pos >= self.doc_ids.len() {
            self.block_idx += 1;
            self.block_loaded = false;
            if self.block_idx >= self.skip_count {
                self.exhausted = true;
                return u32::MAX;
            }
        }
        self.doc()
    }

    /// I/O-free part of seek: resolve the target within the loaded block if
    /// possible, otherwise binary-search the skip entries to pick the block
    /// that may contain `target` (leaving it unloaded).
    ///
    /// Returns Some(doc) when seek is complete without loading a block;
    /// None when the caller must load `block_idx` and then run seek_finish.
    fn seek_prepare(
        &mut self,
        si: &crate::segment::SparseIndex,
        target: DocId,
    ) -> crate::Result<Option<DocId>> {
        if self.exhausted {
            return Ok(Some(u32::MAX));
        }

        if self.block_loaded
            && let Some(&last) = self.doc_ids.last()
        {
            // Target falls inside the currently decoded block: SIMD scan
            // forward from the current position.
            if last >= target && self.doc_ids[self.pos] < target {
                let remaining = &self.doc_ids[self.pos..];
                let offset = crate::structures::simd::find_first_ge_u32(remaining, target);
                self.pos += offset;
                if self.pos >= self.doc_ids.len() {
                    self.block_idx += 1;
                    self.block_loaded = false;
                    if self.block_idx >= self.skip_count {
                        self.exhausted = true;
                        return Ok(Some(u32::MAX));
                    }
                }
                return Ok(Some(self.doc()));
            }
            // Already at or past the target.
            if self.doc_ids[self.pos] >= target {
                return Ok(Some(self.doc()));
            }
        }

        // Binary search skip entries for the first block whose last_doc
        // reaches the target.
        let mut lo = self.block_idx;
        let mut hi = self.skip_count;
        while lo < hi {
            let mid = lo + (hi - lo) / 2;
            if si.read_skip_entry(self.skip_start + mid).last_doc < target {
                lo = mid + 1;
            } else {
                hi = mid;
            }
        }
        if lo >= self.skip_count {
            self.exhausted = true;
            return Ok(Some(u32::MAX));
        }
        if lo != self.block_idx || !self.block_loaded {
            self.block_idx = lo;
            self.block_loaded = false;
        }
        Ok(None)
    }

    /// Finish a seek after the block chosen by seek_prepare was loaded.
    /// Returns true when the position rolled into the next block and another
    /// load is required.
    #[inline]
    fn seek_finish(&mut self, target: DocId) -> bool {
        if self.exhausted {
            return false;
        }
        self.pos = crate::structures::simd::find_first_ge_u32(&self.doc_ids, target);
        if self.pos >= self.doc_ids.len() {
            self.block_idx += 1;
            self.block_loaded = false;
            if self.block_idx >= self.skip_count {
                self.exhausted = true;
                return false;
            }
            return true; }
        false
    }

    /// Async: fetch and decode the current block if needed. Returns whether
    /// the cursor still has data.
    async fn ensure_block_loaded(
        &mut self,
        si: &crate::segment::SparseIndex,
    ) -> crate::Result<bool> {
        if self.exhausted || self.block_loaded {
            return Ok(!self.exhausted);
        }
        let block = si
            .load_block_direct(self.skip_start, self.block_data_offset, self.block_idx)
            .await?;
        self.handle_block_result(block)
    }

    /// Async advance: load the current block if needed, then step forward.
    async fn advance(&mut self, si: &crate::segment::SparseIndex) -> crate::Result<DocId> {
        if self.exhausted {
            return Ok(u32::MAX);
        }
        self.ensure_block_loaded(si).await?;
        if self.exhausted {
            return Ok(u32::MAX);
        }
        Ok(self.advance_pos())
    }

    /// Async seek to the first posting with doc id >= target.
    async fn seek(
        &mut self,
        si: &crate::segment::SparseIndex,
        target: DocId,
    ) -> crate::Result<DocId> {
        if let Some(doc) = self.seek_prepare(si, target)? {
            return Ok(doc);
        }
        self.ensure_block_loaded(si).await?;
        if self.seek_finish(target) {
            // Rolled past the loaded block; load the next one.
            self.ensure_block_loaded(si).await?;
        }
        Ok(self.doc())
    }

    /// Sync twin of ensure_block_loaded.
    fn ensure_block_loaded_sync(
        &mut self,
        si: &crate::segment::SparseIndex,
    ) -> crate::Result<bool> {
        if self.exhausted || self.block_loaded {
            return Ok(!self.exhausted);
        }
        let block =
            si.load_block_direct_sync(self.skip_start, self.block_data_offset, self.block_idx)?;
        self.handle_block_result(block)
    }

    /// Sync twin of advance.
    fn advance_sync(&mut self, si: &crate::segment::SparseIndex) -> crate::Result<DocId> {
        if self.exhausted {
            return Ok(u32::MAX);
        }
        self.ensure_block_loaded_sync(si)?;
        if self.exhausted {
            return Ok(u32::MAX);
        }
        Ok(self.advance_pos())
    }

    /// Sync twin of seek.
    fn seek_sync(
        &mut self,
        si: &crate::segment::SparseIndex,
        target: DocId,
    ) -> crate::Result<DocId> {
        if let Some(doc) = self.seek_prepare(si, target)? {
            return Ok(doc);
        }
        self.ensure_block_loaded_sync(si)?;
        if self.seek_finish(target) {
            self.ensure_block_loaded_sync(si)?;
        }
        Ok(self.doc())
    }

    /// Current doc id without forcing a block load: falls back to the skip
    /// entry's first_doc when the block isn't decoded yet.
    #[inline]
    fn doc_with_si(&self, si: &crate::segment::SparseIndex) -> DocId {
        if self.exhausted {
            return u32::MAX;
        }
        if !self.block_loaded {
            return si
                .read_skip_entry(self.skip_start + self.block_idx)
                .first_doc;
        }
        self.doc_ids.get(self.pos).copied().unwrap_or(u32::MAX)
    }

    /// Current doc id; u32::MAX when exhausted or no block is decoded.
    #[inline]
    fn doc(&self) -> DocId {
        if self.exhausted {
            return u32::MAX;
        }
        if self.block_loaded {
            return self.doc_ids.get(self.pos).copied().unwrap_or(u32::MAX);
        }
        u32::MAX
    }

    /// Ordinal of the current posting (0 when no block is decoded).
    #[inline]
    fn ordinal(&self) -> u16 {
        if !self.block_loaded {
            return 0;
        }
        self.ordinals.get(self.pos).copied().unwrap_or(0)
    }

    /// Pre-scored contribution of the current posting (0.0 when no block
    /// is decoded).
    #[inline]
    fn score(&self) -> f32 {
        if !self.block_loaded {
            return 0.0;
        }
        self.weights.get(self.pos).copied().unwrap_or(0.0)
    }

    /// Upper bound for the current block from its skip entry.
    #[inline]
    fn current_block_max_score(&self, si: &crate::segment::SparseIndex) -> f32 {
        if self.exhausted || self.block_idx >= self.skip_count {
            return 0.0;
        }
        self.abs_query_weight
            * si.read_skip_entry(self.skip_start + self.block_idx)
                .max_weight
    }

    /// Skip the rest of the current block; returns the next block's
    /// first_doc (from the skip entry) or u32::MAX once exhausted.
    fn skip_to_next_block(&mut self, si: &crate::segment::SparseIndex) -> DocId {
        if self.exhausted {
            return u32::MAX;
        }
        self.block_idx += 1;
        self.block_loaded = false;
        if self.block_idx >= self.skip_count {
            self.exhausted = true;
            return u32::MAX;
        }
        si.read_skip_entry(self.skip_start + self.block_idx)
            .first_doc
    }
}
1370
/// Shared body of SparseMaxScoreExecutor's sync and async execute paths.
///
/// `$ensure`/`$advance`/`$seek` name the cursor methods to call
/// (sync or async variants), and `$($aw)*` expands to `.await` in the async
/// instantiation and to nothing in the sync one. The loop itself mirrors
/// MaxScoreExecutor::execute: pick the min doc among essential cursors,
/// apply term-bound / block-max / actual-score prunes, seek non-essential
/// cursors, then score per ordinal.
macro_rules! bms_execute_loop {
    ($self:ident, $ensure:ident, $advance:ident, $seek:ident, $($aw:tt)*) => {{
        let n = $self.cursors.len();
        let si = $self.sparse_index;

        // Prime every cursor so doc_with_si() has a loaded (or known
        // exhausted) starting state.
        for cursor in &mut $self.cursors {
            cursor.$ensure(si) $($aw)* ?;
        }

        let mut docs_scored = 0u64;
        let mut docs_skipped = 0u64;
        let mut blocks_skipped = 0u64;
        let mut blocks_loaded = 0u64;
        let mut conjunction_skipped = 0u64;
        let mut ordinal_scores: Vec<(u16, f32)> = Vec::with_capacity(n * 2);

        loop {
            let partition = $self.find_partition();
            // All terms non-essential: nothing left can beat the threshold.
            if partition >= n {
                break;
            }

            // Min doc among essential cursors (skip-entry first_doc is used
            // for unloaded blocks, avoiding I/O here).
            let mut min_doc = u32::MAX;
            for i in partition..n {
                let doc = $self.cursors[i].doc_with_si(si);
                if doc < min_doc {
                    min_doc = doc;
                }
            }
            if min_doc == u32::MAX {
                break;
            }

            let non_essential_upper = if partition > 0 {
                $self.prefix_sums[partition - 1]
            } else {
                0.0
            };
            let adjusted_threshold = $self.collector.threshold() * $self.heap_factor;

            // Prune 1: full term upper bounds of cursors on min_doc.
            if $self.collector.len() >= $self.collector.k {
                let present_upper: f32 = (partition..n)
                    .filter(|&i| $self.cursors[i].doc_with_si(si) == min_doc)
                    .map(|i| $self.cursors[i].max_score)
                    .sum();

                if present_upper + non_essential_upper <= adjusted_threshold {
                    for i in partition..n {
                        if $self.cursors[i].doc_with_si(si) == min_doc {
                            $self.cursors[i].$ensure(si) $($aw)* ?;
                            $self.cursors[i].$advance(si) $($aw)* ?;
                            blocks_loaded += u64::from($self.cursors[i].block_loaded);
                        }
                    }
                    conjunction_skipped += 1;
                    continue;
                }
            }

            // Prune 2: tighter per-block bounds; skip whole blocks.
            if $self.collector.len() >= $self.collector.k {
                let block_max_sum: f32 = (partition..n)
                    .filter(|&i| $self.cursors[i].doc_with_si(si) == min_doc)
                    .map(|i| $self.cursors[i].current_block_max_score(si))
                    .sum();

                if block_max_sum + non_essential_upper <= adjusted_threshold {
                    for i in partition..n {
                        if $self.cursors[i].doc_with_si(si) == min_doc {
                            $self.cursors[i].skip_to_next_block(si);
                            $self.cursors[i].$ensure(si) $($aw)* ?;
                            blocks_loaded += 1;
                        }
                    }
                    blocks_skipped += 1;
                    continue;
                }
            }

            // Collect actual (ordinal, score) pairs from essential cursors
            // positioned on min_doc (forcing block loads as needed).
            ordinal_scores.clear();
            for i in partition..n {
                if $self.cursors[i].doc_with_si(si) == min_doc {
                    $self.cursors[i].$ensure(si) $($aw)* ?;
                    while $self.cursors[i].doc_with_si(si) == min_doc {
                        ordinal_scores.push(($self.cursors[i].ordinal(), $self.cursors[i].score()));
                        $self.cursors[i].$advance(si) $($aw)* ?;
                    }
                }
            }

            // Prune 3: actual essential total plus non-essential bound.
            let essential_total: f32 = ordinal_scores.iter().map(|(_, s)| *s).sum();
            if $self.collector.len() >= $self.collector.k
                && essential_total + non_essential_upper <= adjusted_threshold
            {
                docs_skipped += 1;
                continue;
            }

            // Seek non-essential cursors from strongest to weakest, bailing
            // out as soon as the running total plus the remaining prefix
            // bound cannot reach the threshold.
            let mut running_total = essential_total;
            for i in (0..partition).rev() {
                if $self.collector.len() >= $self.collector.k
                    && running_total + $self.prefix_sums[i] <= adjusted_threshold
                {
                    break;
                }

                let doc = $self.cursors[i].$seek(si, min_doc) $($aw)* ?;
                if doc == min_doc {
                    while $self.cursors[i].doc_with_si(si) == min_doc {
                        let s = $self.cursors[i].score();
                        running_total += s;
                        ordinal_scores.push(($self.cursors[i].ordinal(), s));
                        $self.cursors[i].$advance(si) $($aw)* ?;
                    }
                }
            }

            // Group by ordinal and offer each (doc, ordinal) candidate.
            ordinal_scores.sort_unstable_by_key(|(ord, _)| *ord);
            let mut j = 0;
            while j < ordinal_scores.len() {
                let current_ord = ordinal_scores[j].0;
                let mut score = 0.0f32;
                while j < ordinal_scores.len() && ordinal_scores[j].0 == current_ord {
                    score += ordinal_scores[j].1;
                    j += 1;
                }
                if $self
                    .collector
                    .insert_with_ordinal(min_doc, score, current_ord)
                {
                    docs_scored += 1;
                } else {
                    docs_skipped += 1;
                }
            }
        }

        let results: Vec<ScoredDoc> = $self
            .collector
            .into_sorted_results()
            .into_iter()
            .map(|(doc_id, score, ordinal)| ScoredDoc {
                doc_id,
                score,
                ordinal,
            })
            .collect();

        debug!(
            "SparseMaxScoreExecutor: scored={}, skipped={}, blocks_skipped={}, blocks_loaded={}, conjunction_skipped={}, returned={}, top_score={:.4}",
            docs_scored,
            docs_skipped,
            blocks_skipped,
            blocks_loaded,
            conjunction_skipped,
            results.len(),
            results.first().map(|r| r.score).unwrap_or(0.0)
        );

        Ok(results)
    }};
}
1543
1544impl<'a> SparseMaxScoreExecutor<'a> {
1545 pub fn new(
1550 sparse_index: &'a crate::segment::SparseIndex,
1551 query_terms: Vec<(u32, f32)>,
1552 k: usize,
1553 heap_factor: f32,
1554 ) -> Self {
1555 let mut cursors: Vec<LazyTermCursor> = query_terms
1556 .iter()
1557 .filter_map(|&(dim_id, qw)| {
1558 let (skip_start, skip_count, global_max, block_data_offset) =
1559 sparse_index.get_skip_range_full(dim_id)?;
1560 Some(LazyTermCursor::new(
1561 qw,
1562 skip_start,
1563 skip_count,
1564 global_max,
1565 block_data_offset,
1566 ))
1567 })
1568 .collect();
1569
1570 cursors.sort_by(|a, b| {
1572 a.max_score
1573 .partial_cmp(&b.max_score)
1574 .unwrap_or(Ordering::Equal)
1575 });
1576
1577 let mut prefix_sums = Vec::with_capacity(cursors.len());
1578 let mut cumsum = 0.0f32;
1579 for c in &cursors {
1580 cumsum += c.max_score;
1581 prefix_sums.push(cumsum);
1582 }
1583
1584 debug!(
1585 "Creating SparseMaxScoreExecutor: num_terms={}, k={}, total_upper={:.4}, heap_factor={:.2}",
1586 cursors.len(),
1587 k,
1588 cumsum,
1589 heap_factor
1590 );
1591
1592 Self {
1593 sparse_index,
1594 cursors,
1595 prefix_sums,
1596 collector: ScoreCollector::new(k),
1597 heap_factor: heap_factor.clamp(0.0, 1.0),
1598 }
1599 }
1600
1601 #[inline]
1602 fn find_partition(&self) -> usize {
1603 let threshold = self.collector.threshold() * self.heap_factor;
1604 self.prefix_sums.partition_point(|&sum| sum <= threshold)
1606 }
1607
1608 pub async fn execute(mut self) -> crate::Result<Vec<ScoredDoc>> {
1610 if self.cursors.is_empty() {
1611 return Ok(Vec::new());
1612 }
1613 bms_execute_loop!(self, ensure_block_loaded, advance, seek, .await)
1614 }
1615
1616 pub fn execute_sync(mut self) -> crate::Result<Vec<ScoredDoc>> {
1619 if self.cursors.is_empty() {
1620 return Ok(Vec::new());
1621 }
1622 bms_execute_loop!(self, ensure_block_loaded_sync, advance_sync, seek_sync,)
1623 }
1624}
1625
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_score_collector_basic() {
        let mut collector = ScoreCollector::new(3);

        // Fill to capacity; the weakest of the three sets the threshold.
        for (doc, score) in [(1, 1.0), (2, 2.0), (3, 3.0)] {
            collector.insert(doc, score);
        }
        assert_eq!(collector.threshold(), 1.0);

        // A stronger doc evicts the weakest entry and raises the bar.
        collector.insert(4, 4.0);
        assert_eq!(collector.threshold(), 2.0);

        // Results come back ordered best-first.
        let results = collector.into_sorted_results();
        assert_eq!(results.len(), 3);
        let ids: Vec<_> = results.iter().map(|r| r.0).collect();
        assert_eq!(ids, vec![4, 3, 2]);
    }

    #[test]
    fn test_score_collector_threshold() {
        let mut collector = ScoreCollector::new(2);

        collector.insert(1, 5.0);
        collector.insert(2, 3.0);
        assert_eq!(collector.threshold(), 3.0);

        // Below the bar: rejected both by the probe and by insertion.
        assert!(!collector.would_enter(2.0));
        assert!(!collector.insert(3, 2.0));

        // Above the bar: accepted, and the bar moves up.
        assert!(collector.would_enter(4.0));
        assert!(collector.insert(4, 4.0));
        assert_eq!(collector.threshold(), 4.0);
    }

    #[test]
    fn test_heap_entry_ordering() {
        // HeapEntry's reversed Ord turns BinaryHeap into a min-heap on score.
        let mut heap: BinaryHeap<HeapEntry> = [(1, 3.0), (2, 1.0), (3, 2.0)]
            .into_iter()
            .map(|(doc_id, score)| HeapEntry {
                doc_id,
                score,
                ordinal: 0,
            })
            .collect();

        for expected in [1.0, 2.0, 3.0] {
            assert_eq!(heap.pop().unwrap().score, expected);
        }
    }
}