1use std::cmp::Ordering;
10use std::collections::BinaryHeap;
11
12use log::{debug, warn};
13
14use crate::DocId;
15
16#[derive(Clone, Copy)]
18pub struct HeapEntry {
19 pub doc_id: DocId,
20 pub score: f32,
21 pub ordinal: u16,
22}
23
24impl PartialEq for HeapEntry {
25 fn eq(&self, other: &Self) -> bool {
26 self.score == other.score && self.doc_id == other.doc_id
27 }
28}
29
30impl Eq for HeapEntry {}
31
32impl Ord for HeapEntry {
33 fn cmp(&self, other: &Self) -> Ordering {
34 other
37 .score
38 .total_cmp(&self.score)
39 .then(self.doc_id.cmp(&other.doc_id))
40 }
41}
42
43impl PartialOrd for HeapEntry {
44 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
45 Some(self.cmp(other))
46 }
47}
48
49pub struct ScoreCollector {
62 heap: BinaryHeap<HeapEntry>,
64 pub k: usize,
65 cached_threshold: f32,
68}
69
70impl ScoreCollector {
71 pub fn new(k: usize) -> Self {
73 let capacity = k.saturating_add(1).min(1_000_000);
75 Self {
76 heap: BinaryHeap::with_capacity(capacity),
77 k,
78 cached_threshold: 0.0,
79 }
80 }
81
82 #[inline]
84 pub fn threshold(&self) -> f32 {
85 self.cached_threshold
86 }
87
88 #[inline]
90 fn update_threshold(&mut self) {
91 self.cached_threshold = if self.heap.len() >= self.k {
92 self.heap.peek().map(|e| e.score).unwrap_or(0.0)
93 } else {
94 0.0
95 };
96 }
97
98 #[inline]
101 pub fn insert(&mut self, doc_id: DocId, score: f32) -> bool {
102 self.insert_with_ordinal(doc_id, score, 0)
103 }
104
105 #[inline]
108 pub fn insert_with_ordinal(&mut self, doc_id: DocId, score: f32, ordinal: u16) -> bool {
109 if self.heap.len() < self.k {
110 self.heap.push(HeapEntry {
111 doc_id,
112 score,
113 ordinal,
114 });
115 if self.heap.len() == self.k {
117 self.update_threshold();
118 }
119 true
120 } else if score > self.cached_threshold {
121 self.heap.push(HeapEntry {
122 doc_id,
123 score,
124 ordinal,
125 });
126 self.heap.pop(); self.update_threshold();
128 true
129 } else {
130 false
131 }
132 }
133
134 #[inline]
136 pub fn would_enter(&self, score: f32) -> bool {
137 self.heap.len() < self.k || score > self.cached_threshold
138 }
139
140 #[inline]
142 pub fn len(&self) -> usize {
143 self.heap.len()
144 }
145
146 #[inline]
148 pub fn is_empty(&self) -> bool {
149 self.heap.is_empty()
150 }
151
152 pub fn seed_threshold(&mut self, initial_threshold: f32) {
158 if initial_threshold > 0.0 && self.heap.is_empty() {
159 for _ in 0..self.k {
160 self.heap.push(HeapEntry {
161 doc_id: u32::MAX,
162 score: initial_threshold,
163 ordinal: 0,
164 });
165 }
166 self.update_threshold();
167 }
168 }
169
170 pub fn into_sorted_results(self) -> Vec<(DocId, f32, u16)> {
173 let mut results: Vec<(DocId, f32, u16)> = self
174 .heap
175 .into_vec()
176 .into_iter()
177 .filter(|e| e.doc_id != u32::MAX)
178 .map(|e| (e.doc_id, e.score, e.ordinal))
179 .collect();
180
181 results.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then(a.0.cmp(&b.0)));
183
184 results
185 }
186}
187
/// One result returned by [`MaxScoreExecutor`]: a document, its accumulated
/// score, and the ordinal under which that score was collected.
#[derive(Debug, Clone, Copy)]
pub struct ScoredDoc {
    /// Matched document.
    pub doc_id: DocId,
    /// Sum of the per-term scores collected for this `(doc_id, ordinal)`.
    pub score: f32,
    /// Ordinal the score belongs to (0 when the postings carry no ordinals).
    pub ordinal: u16,
}
196
/// Top-k query executor using MaxScore-style essential / non-essential term
/// partitioning plus block-max upper bounds to skip documents.
pub struct MaxScoreExecutor<'a> {
    /// One cursor per query term, sorted ascending by `max_score` in `new`.
    cursors: Vec<TermCursor<'a>>,
    /// `prefix_sums[i]` is the sum of `max_score` over `cursors[0..=i]`.
    prefix_sums: Vec<f32>,
    /// Bounded heap of the best results seen so far.
    collector: ScoreCollector,
    /// `1 / heap_factor` (heap_factor clamped to `[0.01, 1.0]`); thresholds are
    /// multiplied by it, so values above 1 prune more aggressively.
    inv_heap_factor: f32,
    /// Optional document filter; candidates for which it returns false are skipped.
    predicate: Option<super::DocPredicate<'a>>,
}
214
/// Cursor over one query term's postings (a text term or a sparse-vector
/// dimension), decoding one block at a time.
pub(crate) struct TermCursor<'a> {
    /// Upper bound on this term's per-posting score, used for pruning.
    pub max_score: f32,
    /// Number of blocks in the posting list.
    num_blocks: usize,
    /// Index of the current block.
    block_idx: usize,
    /// Decoded doc ids of the current block (meaningful while `block_loaded`).
    doc_ids: Vec<u32>,
    /// Per-posting scores of the current block; for text cursors this stays
    /// empty until `ensure_scores` computes it.
    scores: Vec<f32>,
    /// Per-posting ordinals of the current block (never filled for text cursors).
    ordinals: Vec<u16>,
    /// Position of the cursor within the current block.
    pos: usize,
    /// Whether `doc_ids` holds the decoded contents of block `block_idx`.
    block_loaded: bool,
    /// Set once the cursor has moved past its last block.
    exhausted: bool,
    /// When true, sparse ordinals are decoded on first use instead of at block load.
    lazy_ordinals: bool,
    /// False while the current block's ordinals are still deferred; `ordinals`
    /// may then still hold the previous block's values.
    ordinals_loaded: bool,
    /// Current sparse block, kept at block load when `lazy_ordinals` is on so
    /// its ordinals can be decoded on demand; `None` otherwise.
    current_sparse_block: Option<crate::structures::SparseBlock>,
    /// Posting-list kind together with its scoring parameters.
    variant: CursorVariant<'a>,
}
245
/// Source-specific state of a [`TermCursor`].
enum CursorVariant<'a> {
    /// Text posting list scored per posting as
    /// `idf_times_k1_plus_1 * tf / (denom_tf_coeff * tf + denom_const)`.
    Text {
        /// Block-encoded posting list.
        list: crate::structures::BlockPostingList,
        /// Inverse document frequency of the term.
        idf: f32,
        /// Precomputed `idf * (BM25_K1 + 1)`.
        idf_times_k1_plus_1: f32,
        /// Coefficient of `tf` in the score denominator.
        denom_tf_coeff: f32,
        /// Constant term of the score denominator, `BM25_K1 * (1 - BM25_B)`.
        denom_const: f32,
        /// Scratch buffer for decoded term frequencies.
        tfs: Vec<u32>,
        /// `(block_offset, tf_start, count)` returned by
        /// `decode_block_doc_ids_only`; consumed by `compute_deferred_scores`.
        deferred_tf: Option<(usize, usize, usize)>,
    },
    /// One dimension of a sparse-vector index.
    Sparse {
        /// Index the blocks are read from.
        si: &'a crate::segment::SparseIndex,
        /// Query-side weight of this dimension.
        query_weight: f32,
        /// Index of this dimension's first skip entry.
        skip_start: usize,
        /// Offset passed to the block loader alongside the skip index.
        block_data_offset: u64,
    },
}
270
/// Shared body of `TermCursor::ensure_block_loaded` and
/// `TermCursor::ensure_block_loaded_sync`.
///
/// Decodes block `block_idx` if it is not loaded yet. Evaluates to `Ok(true)`
/// when the cursor sits on a loaded block and `Ok(false)` once the list is
/// exhausted. `$load_block_fn` names the `SparseIndex` block loader to call,
/// and `$($aw)*` is either `.await` or empty so one body serves both flavours.
/// Note the early `return` returns from the enclosing method.
macro_rules! cursor_ensure_block {
    ($self:ident, $load_block_fn:ident, $($aw:tt)*) => {{
        // Already exhausted, or the current block is decoded: nothing to do.
        if $self.exhausted || $self.block_loaded {
            return Ok(!$self.exhausted);
        }
        match &mut $self.variant {
            CursorVariant::Text {
                list,
                deferred_tf,
                ..
            } => {
                // Only doc ids are decoded now; term frequencies are decoded
                // later by `compute_deferred_scores`. Clearing `scores` marks
                // them as pending for `ensure_scores`.
                if let Some(state) = list.decode_block_doc_ids_only($self.block_idx, &mut $self.doc_ids) {
                    *deferred_tf = Some(state);
                    $self.scores.clear();
                    $self.pos = 0;
                    $self.block_loaded = true;
                    Ok(true)
                } else {
                    $self.exhausted = true;
                    Ok(false)
                }
            }
            CursorVariant::Sparse {
                si,
                query_weight,
                skip_start,
                block_data_offset,
                ..
            } => {
                let block = si
                    .$load_block_fn(*skip_start, *block_data_offset, $self.block_idx)
                    $($aw)* ?;
                match block {
                    Some(b) => {
                        // Doc ids and scores are decoded eagerly for sparse blocks.
                        b.decode_doc_ids_into(&mut $self.doc_ids);
                        b.decode_scored_weights_into(*query_weight, &mut $self.scores);
                        if $self.lazy_ordinals {
                            // Keep the block so `ordinal_mut` can decode ordinals
                            // on demand; `ordinals` is not cleared here.
                            $self.current_sparse_block = Some(b);
                            $self.ordinals_loaded = false;
                        } else {
                            b.decode_ordinals_into(&mut $self.ordinals);
                            $self.ordinals_loaded = true;
                            $self.current_sparse_block = None;
                        }
                        $self.pos = 0;
                        $self.block_loaded = true;
                        Ok(true)
                    }
                    None => {
                        $self.exhausted = true;
                        Ok(false)
                    }
                }
            }
        }
    }};
}
337
/// Shared body of `TermCursor::advance` and `TermCursor::advance_sync`.
///
/// Moves the cursor one posting forward and evaluates to the new doc id, or
/// `u32::MAX` once the list is exhausted. `$ensure_fn` names the block loader
/// and `$($aw)*` is `.await` or empty. The early `return` leaves the
/// enclosing method.
macro_rules! cursor_advance {
    ($self:ident, $ensure_fn:ident, $($aw:tt)*) => {{
        if !$self.exhausted {
            // The current block must be decoded before stepping inside it;
            // loading it may itself discover that the list has ended.
            $self.$ensure_fn() $($aw)* ?;
        }
        if $self.exhausted {
            return Ok(u32::MAX);
        }
        Ok($self.advance_pos())
    }};
}
350
/// Shared body of `TermCursor::seek` and `TermCursor::seek_sync`.
///
/// Positions the cursor on the first doc `>= $target` and evaluates to it
/// (`u32::MAX` when exhausted). `seek_prepare` answers directly when it can;
/// otherwise it selects the target block, which is loaded and searched here.
/// The early `return` leaves the enclosing method.
macro_rules! cursor_seek {
    ($self:ident, $ensure_fn:ident, $target:expr, $($aw:tt)*) => {{
        if let Some(resolved) = $self.seek_prepare($target) {
            return Ok(resolved);
        }
        // Decode the block chosen by `seek_prepare`, then search inside it.
        $self.$ensure_fn() $($aw)* ?;
        let spilled_into_next_block = $self.seek_finish($target);
        if spilled_into_next_block {
            $self.$ensure_fn() $($aw)* ?;
        }
        Ok($self.doc())
    }};
}
363
364impl<'a> TermCursor<'a> {
365 pub fn text(
367 posting_list: crate::structures::BlockPostingList,
368 idf: f32,
369 avg_field_len: f32,
370 ) -> Self {
371 let max_tf = posting_list.max_tf() as f32;
372 let max_score = super::bm25_upper_bound(max_tf.max(1.0), idf);
373 let num_blocks = posting_list.num_blocks();
374 let safe_avg = avg_field_len.max(1.0);
375 Self {
376 max_score,
377 num_blocks,
378 block_idx: 0,
379 doc_ids: Vec::with_capacity(128),
380 scores: Vec::with_capacity(128),
381 ordinals: Vec::new(),
382 pos: 0,
383 block_loaded: false,
384 exhausted: num_blocks == 0,
385 lazy_ordinals: false,
386 ordinals_loaded: true, current_sparse_block: None,
388 variant: CursorVariant::Text {
389 list: posting_list,
390 idf,
391 idf_times_k1_plus_1: idf * (super::BM25_K1 + 1.0),
392 denom_tf_coeff: 1.0 + super::BM25_K1 * (super::BM25_B / safe_avg),
393 denom_const: super::BM25_K1 * (1.0 - super::BM25_B),
394 tfs: Vec::with_capacity(128),
395 deferred_tf: None,
396 },
397 }
398 }
399
400 pub fn sparse(
403 si: &'a crate::segment::SparseIndex,
404 query_weight: f32,
405 skip_start: usize,
406 skip_count: usize,
407 global_max_weight: f32,
408 block_data_offset: u64,
409 ) -> Self {
410 Self {
411 max_score: query_weight.abs() * global_max_weight,
412 num_blocks: skip_count,
413 block_idx: 0,
414 doc_ids: Vec::with_capacity(256),
415 scores: Vec::with_capacity(256),
416 ordinals: Vec::with_capacity(256),
417 pos: 0,
418 block_loaded: false,
419 exhausted: skip_count == 0,
420 lazy_ordinals: false,
421 ordinals_loaded: true,
422 current_sparse_block: None,
423 variant: CursorVariant::Sparse {
424 si,
425 query_weight,
426 skip_start,
427 block_data_offset,
428 },
429 }
430 }
431
432 #[inline]
435 fn block_first_doc(&self, idx: usize) -> DocId {
436 match &self.variant {
437 CursorVariant::Text { list, .. } => list.block_first_doc(idx).unwrap_or(u32::MAX),
438 CursorVariant::Sparse { si, skip_start, .. } => {
439 si.read_skip_entry(*skip_start + idx).first_doc
440 }
441 }
442 }
443
444 #[inline]
445 fn block_last_doc(&self, idx: usize) -> DocId {
446 match &self.variant {
447 CursorVariant::Text { list, .. } => list.block_last_doc(idx).unwrap_or(0),
448 CursorVariant::Sparse { si, skip_start, .. } => {
449 si.read_skip_entry(*skip_start + idx).last_doc
450 }
451 }
452 }
453
454 #[inline]
457 pub fn doc(&self) -> DocId {
458 if self.exhausted {
459 return u32::MAX;
460 }
461 if self.block_loaded {
462 debug_assert!(self.pos < self.doc_ids.len());
463 unsafe { *self.doc_ids.get_unchecked(self.pos) }
465 } else {
466 self.block_first_doc(self.block_idx)
467 }
468 }
469
470 #[inline]
471 pub fn ordinal(&self) -> u16 {
472 if !self.block_loaded || self.ordinals.is_empty() {
473 return 0;
474 }
475 debug_assert!(self.pos < self.ordinals.len());
476 unsafe { *self.ordinals.get_unchecked(self.pos) }
478 }
479
480 #[inline]
486 pub fn ordinal_mut(&mut self) -> u16 {
487 if !self.block_loaded {
488 return 0;
489 }
490 if !self.ordinals_loaded {
491 if let Some(ref block) = self.current_sparse_block {
492 block.decode_ordinals_into(&mut self.ordinals);
493 }
494 self.ordinals_loaded = true;
495 }
496 if self.ordinals.is_empty() {
497 return 0;
498 }
499 debug_assert!(self.pos < self.ordinals.len());
500 unsafe { *self.ordinals.get_unchecked(self.pos) }
501 }
502
503 #[inline]
504 pub fn score(&self) -> f32 {
505 if !self.block_loaded {
506 return 0.0;
507 }
508 debug_assert!(self.pos < self.scores.len());
509 unsafe { *self.scores.get_unchecked(self.pos) }
511 }
512
513 #[inline]
519 pub fn ensure_scores(&mut self) {
520 if self.block_loaded && self.scores.is_empty() {
521 self.compute_deferred_scores();
522 }
523 }
524
525 #[inline]
526 pub fn current_block_max_score(&self) -> f32 {
527 if self.exhausted {
528 return 0.0;
529 }
530 match &self.variant {
531 CursorVariant::Text { list, idf, .. } => {
532 let block_max_tf = list.block_max_tf(self.block_idx).unwrap_or(0) as f32;
533 super::bm25_upper_bound(block_max_tf.max(1.0), *idf)
534 }
535 CursorVariant::Sparse {
536 si,
537 query_weight,
538 skip_start,
539 ..
540 } => query_weight.abs() * si.read_skip_entry(*skip_start + self.block_idx).max_weight,
541 }
542 }
543
544 pub fn skip_to_next_block(&mut self) -> DocId {
547 if self.exhausted {
548 return u32::MAX;
549 }
550 self.block_idx += 1;
551 self.block_loaded = false;
552 if self.block_idx >= self.num_blocks {
553 self.exhausted = true;
554 return u32::MAX;
555 }
556 self.block_first_doc(self.block_idx)
557 }
558
559 #[inline]
560 fn advance_pos(&mut self) -> DocId {
561 self.pos += 1;
562 if self.pos >= self.doc_ids.len() {
563 self.block_idx += 1;
564 self.block_loaded = false;
565 if self.block_idx >= self.num_blocks {
566 self.exhausted = true;
567 return u32::MAX;
568 }
569 }
570 self.doc()
571 }
572
573 #[inline(never)]
575 fn compute_deferred_scores(&mut self) {
576 if let CursorVariant::Text {
577 list,
578 idf_times_k1_plus_1,
579 denom_tf_coeff,
580 denom_const,
581 tfs,
582 deferred_tf,
583 ..
584 } = &mut self.variant
585 && let Some((block_offset, tf_start, count)) = deferred_tf.take()
586 {
587 list.decode_block_tfs_deferred(block_offset, tf_start, count, tfs);
588 let num_scale = *idf_times_k1_plus_1;
589 let d_tf = *denom_tf_coeff;
590 let d_const = *denom_const;
591 self.scores.clear();
592 self.scores.resize(count, 0.0);
593 for i in 0..count {
594 let tf = unsafe { *tfs.get_unchecked(i) } as f32;
595 let score = (num_scale * tf) / (d_tf * tf + d_const);
596 unsafe {
597 *self.scores.get_unchecked_mut(i) = score;
598 }
599 }
600 }
601 }
602
603 pub async fn ensure_block_loaded(&mut self) -> crate::Result<bool> {
609 cursor_ensure_block!(self, load_block_direct, .await)
610 }
611
612 pub fn ensure_block_loaded_sync(&mut self) -> crate::Result<bool> {
613 cursor_ensure_block!(self, load_block_direct_sync,)
614 }
615
616 pub async fn advance(&mut self) -> crate::Result<DocId> {
617 cursor_advance!(self, ensure_block_loaded, .await)
618 }
619
620 pub fn advance_sync(&mut self) -> crate::Result<DocId> {
621 cursor_advance!(self, ensure_block_loaded_sync,)
622 }
623
624 pub async fn seek(&mut self, target: DocId) -> crate::Result<DocId> {
625 cursor_seek!(self, ensure_block_loaded, target, .await)
626 }
627
628 pub fn seek_sync(&mut self, target: DocId) -> crate::Result<DocId> {
629 cursor_seek!(self, ensure_block_loaded_sync, target,)
630 }
631
632 fn seek_prepare(&mut self, target: DocId) -> Option<DocId> {
633 if self.exhausted {
634 return Some(u32::MAX);
635 }
636
637 if self.block_loaded
639 && let Some(&last) = self.doc_ids.last()
640 {
641 if last >= target && self.doc_ids[self.pos] < target {
642 let remaining = &self.doc_ids[self.pos..];
643 self.pos += crate::structures::simd::find_first_ge_u32(remaining, target);
644 if self.pos >= self.doc_ids.len() {
645 self.block_idx += 1;
646 self.block_loaded = false;
647 if self.block_idx >= self.num_blocks {
648 self.exhausted = true;
649 return Some(u32::MAX);
650 }
651 }
652 return Some(self.doc());
653 }
654 if self.doc_ids[self.pos] >= target {
655 return Some(self.doc());
656 }
657 }
658
659 let lo = match &self.variant {
661 CursorVariant::Text { list, .. } => match list.seek_block(target, self.block_idx) {
663 Some(idx) => idx,
664 None => {
665 self.exhausted = true;
666 return Some(u32::MAX);
667 }
668 },
669 CursorVariant::Sparse { .. } => {
671 let mut lo = self.block_idx;
672 let mut hi = self.num_blocks;
673 while lo < hi {
674 let mid = lo + (hi - lo) / 2;
675 if self.block_last_doc(mid) < target {
676 lo = mid + 1;
677 } else {
678 hi = mid;
679 }
680 }
681 lo
682 }
683 };
684 if lo >= self.num_blocks {
685 self.exhausted = true;
686 return Some(u32::MAX);
687 }
688 if lo != self.block_idx || !self.block_loaded {
689 self.block_idx = lo;
690 self.block_loaded = false;
691 }
692 None
693 }
694
695 #[inline]
696 fn seek_finish(&mut self, target: DocId) -> bool {
697 if self.exhausted {
698 return false;
699 }
700 self.pos = crate::structures::simd::find_first_ge_u32(&self.doc_ids, target);
701 if self.pos >= self.doc_ids.len() {
702 self.block_idx += 1;
703 self.block_loaded = false;
704 if self.block_idx >= self.num_blocks {
705 self.exhausted = true;
706 return false;
707 }
708 return true;
709 }
710 false
711 }
712}
713
/// Shared body of `MaxScoreExecutor::execute` and `MaxScoreExecutor::execute_sync`.
///
/// `$ensure`, `$advance` and `$seek` name the cursor methods to use (async or
/// `_sync` flavour) and `$($aw)*` is either `.await` or empty. Each iteration:
///
/// 1. splits the cursors (sorted ascending by `max_score`) into a
///    non-essential prefix, whose summed bounds cannot beat the threshold, and
///    an essential suffix;
/// 2. takes the smallest doc among essential cursors as the candidate;
/// 3. once the heap is full, prunes the candidate with global upper bounds,
///    then with block-max upper bounds (skipping a provably bounded doc window);
/// 4. applies the optional predicate, scores the essential cursors, then
///    probes non-essential cursors while the candidate can still qualify;
/// 5. inserts one heap entry per distinct ordinal.
macro_rules! bms_execute_loop {
    ($self:ident, $ensure:ident, $advance:ident, $seek:ident, $($aw:tt)*) => {{
        let n = $self.cursors.len();

        // Decode the first block of every cursor up front.
        for cursor in &mut $self.cursors {
            cursor.$ensure() $($aw)* ?;
        }

        let mut docs_scored = 0u64;
        let mut docs_skipped = 0u64;
        let mut blocks_skipped = 0u64;
        let mut conjunction_skipped = 0u64;
        let mut ordinal_scores: Vec<(u16, f32)> = Vec::with_capacity(n * 2);
        // Essential cursors positioned on the current candidate, in ascending
        // index order. A list (not a `u64` bitmask) keeps queries with more
        // than 64 terms correct.
        let mut at_min: Vec<usize> = Vec::with_capacity(n);
        let _bms_start = std::time::Instant::now();

        let inv_heap_factor = $self.inv_heap_factor;
        // The 1e-6 slack keeps upper bounds within 1e-6 of the threshold from
        // being pruned (presumably to absorb float rounding).
        let mut adjusted_threshold = $self.collector.threshold() * inv_heap_factor - 1e-6;

        loop {
            let partition = $self.find_partition();
            if partition >= n {
                break;
            }

            // Candidate = smallest doc among essential cursors. Also track the
            // smallest essential doc strictly after it, which bounds how far a
            // block-max skip may jump.
            let mut min_doc = u32::MAX;
            let mut next_doc = u32::MAX;
            at_min.clear();
            for (i, cursor) in $self.cursors.iter().enumerate().skip(partition) {
                let doc = cursor.doc();
                match doc.cmp(&min_doc) {
                    std::cmp::Ordering::Less => {
                        next_doc = min_doc;
                        min_doc = doc;
                        at_min.clear();
                        at_min.push(i);
                    }
                    std::cmp::Ordering::Equal => at_min.push(i),
                    std::cmp::Ordering::Greater => next_doc = next_doc.min(doc),
                }
            }
            if min_doc == u32::MAX {
                break;
            }

            let non_essential_upper = if partition > 0 {
                $self.prefix_sums[partition - 1]
            } else {
                0.0
            };
            // Nothing is inserted before the end of this iteration, so heap
            // fullness is fixed for all the pruning checks below.
            let heap_full = $self.collector.len() >= $self.collector.k;

            if heap_full {
                // Global bound: the candidate's terms cannot beat the threshold
                // even at their maximum scores.
                let present_upper: f32 = at_min
                    .iter()
                    .map(|&i| $self.cursors[i].max_score)
                    .sum();
                if present_upper + non_essential_upper <= adjusted_threshold {
                    for &i in &at_min {
                        $self.cursors[i].$ensure() $($aw)* ?;
                        $self.cursors[i].$advance() $($aw)* ?;
                    }
                    conjunction_skipped += 1;
                    continue;
                }

                // Block-max bound for the candidate's current blocks.
                let block_max_sum: f32 = at_min
                    .iter()
                    .map(|&i| $self.cursors[i].current_block_max_score())
                    .sum();
                if block_max_sum + non_essential_upper <= adjusted_threshold {
                    // The bound only covers docs that are still inside every
                    // candidate cursor's current block AND come before the next
                    // essential doc (past that point other essential terms may
                    // contribute). Skip exactly that window, not whole blocks.
                    let mut window_end = next_doc;
                    for &i in &at_min {
                        let cursor = &$self.cursors[i];
                        let block_last = cursor.block_last_doc(cursor.block_idx);
                        window_end = window_end.min(block_last.saturating_add(1));
                    }
                    // `min_doc` itself is covered by the bound (min_doc < u32::MAX
                    // here), so always make progress.
                    let window_end = window_end.max(min_doc + 1);
                    for &i in &at_min {
                        $self.cursors[i].$seek(window_end) $($aw)* ?;
                    }
                    blocks_skipped += 1;
                    continue;
                }
            }

            // Evaluate the filter without holding a borrow of it across awaits.
            let rejected = $self
                .predicate
                .as_ref()
                .is_some_and(|pred| !pred(min_doc));
            if rejected {
                for &i in &at_min {
                    $self.cursors[i].$ensure() $($aw)* ?;
                    $self.cursors[i].$advance() $($aw)* ?;
                }
                continue;
            }

            // Score every posting of the candidate in the essential cursors (a
            // cursor can hold several postings per doc, one per ordinal).
            ordinal_scores.clear();
            for &i in &at_min {
                let cursor = &mut $self.cursors[i];
                cursor.$ensure() $($aw)* ?;
                while cursor.doc() == min_doc {
                    // Re-checked per posting: advancing may have loaded a new
                    // block whose text scores are still deferred.
                    cursor.ensure_scores();
                    let ord = cursor.ordinal_mut();
                    let sc = cursor.score();
                    ordinal_scores.push((ord, sc));
                    cursor.$advance() $($aw)* ?;
                }
            }

            let essential_total: f32 = ordinal_scores.iter().map(|&(_, s)| s).sum();
            if heap_full && essential_total + non_essential_upper <= adjusted_threshold {
                docs_skipped += 1;
                continue;
            }

            // Probe non-essential cursors from the highest index down, stopping
            // once the remaining prefix bound can no longer lift the candidate
            // over the threshold.
            let mut running_total = essential_total;
            for i in (0..partition).rev() {
                if heap_full && running_total + $self.prefix_sums[i] <= adjusted_threshold {
                    break;
                }

                let cursor = &mut $self.cursors[i];
                cursor.$seek(min_doc) $($aw)* ?;
                while cursor.doc() == min_doc {
                    cursor.ensure_scores();
                    let s = cursor.score();
                    running_total += s;
                    let ord = cursor.ordinal_mut();
                    ordinal_scores.push((ord, s));
                    cursor.$advance() $($aw)* ?;
                }
            }

            // One heap entry per distinct ordinal: group the collected
            // (ordinal, score) pairs by ordinal and sum each group.
            ordinal_scores.sort_unstable_by_key(|&(ord, _)| ord);
            for group in ordinal_scores.chunk_by(|a, b| a.0 == b.0) {
                let ord = group[0].0;
                let score = group.iter().fold(0.0f32, |acc, &(_, s)| acc + s);
                if $self.collector.insert_with_ordinal(min_doc, score, ord) {
                    docs_scored += 1;
                    adjusted_threshold = $self.collector.threshold() * inv_heap_factor - 1e-6;
                } else {
                    docs_skipped += 1;
                }
            }
        }

        let results: Vec<ScoredDoc> = $self
            .collector
            .into_sorted_results()
            .into_iter()
            .map(|(doc_id, score, ordinal)| ScoredDoc {
                doc_id,
                score,
                ordinal,
            })
            .collect();

        let _bms_elapsed_ms = _bms_start.elapsed().as_millis() as u64;
        if _bms_elapsed_ms > 500 {
            warn!(
                "slow MaxScore: {}ms, cursors={}, scored={}, skipped={}, blocks_skipped={}, conjunction_skipped={}, returned={}, top_score={:.4}",
                _bms_elapsed_ms,
                n,
                docs_scored,
                docs_skipped,
                blocks_skipped,
                conjunction_skipped,
                results.len(),
                results.first().map(|r| r.score).unwrap_or(0.0)
            );
        } else {
            debug!(
                "MaxScoreExecutor: {}ms, scored={}, skipped={}, blocks_skipped={}, conjunction_skipped={}, returned={}, top_score={:.4}",
                _bms_elapsed_ms,
                docs_scored,
                docs_skipped,
                blocks_skipped,
                conjunction_skipped,
                results.len(),
                results.first().map(|r| r.score).unwrap_or(0.0)
            );
        }

        Ok(results)
    }};
}
956
957impl<'a> MaxScoreExecutor<'a> {
958 pub(crate) fn new(mut cursors: Vec<TermCursor<'a>>, k: usize, heap_factor: f32) -> Self {
963 for c in &mut cursors {
966 c.lazy_ordinals = true;
967 }
968
969 cursors.sort_by(|a, b| {
971 a.max_score
972 .partial_cmp(&b.max_score)
973 .unwrap_or(Ordering::Equal)
974 });
975
976 let mut prefix_sums = Vec::with_capacity(cursors.len());
977 let mut cumsum = 0.0f32;
978 for c in &cursors {
979 cumsum += c.max_score;
980 prefix_sums.push(cumsum);
981 }
982
983 let clamped_heap_factor = heap_factor.clamp(0.01, 1.0);
984
985 debug!(
986 "Creating MaxScoreExecutor: num_cursors={}, k={}, total_upper={:.4}, heap_factor={:.2}",
987 cursors.len(),
988 k,
989 cumsum,
990 clamped_heap_factor
991 );
992
993 Self {
994 cursors,
995 prefix_sums,
996 collector: ScoreCollector::new(k),
997 inv_heap_factor: 1.0 / clamped_heap_factor,
998 predicate: None,
999 }
1000 }
1001
1002 pub fn sparse(
1006 sparse_index: &'a crate::segment::SparseIndex,
1007 query_terms: Vec<(u32, f32)>,
1008 k: usize,
1009 heap_factor: f32,
1010 ) -> Self {
1011 let cursors: Vec<TermCursor<'a>> = query_terms
1012 .iter()
1013 .filter_map(|&(dim_id, qw)| {
1014 let (skip_start, skip_count, global_max, block_data_offset) =
1015 sparse_index.get_skip_range_full(dim_id)?;
1016 Some(TermCursor::sparse(
1017 sparse_index,
1018 qw,
1019 skip_start,
1020 skip_count,
1021 global_max,
1022 block_data_offset,
1023 ))
1024 })
1025 .collect();
1026 Self::new(cursors, k, heap_factor)
1027 }
1028
1029 pub fn text(
1033 posting_lists: Vec<(crate::structures::BlockPostingList, f32)>,
1034 avg_field_len: f32,
1035 k: usize,
1036 ) -> Self {
1037 let cursors: Vec<TermCursor<'a>> = posting_lists
1038 .into_iter()
1039 .map(|(pl, idf)| TermCursor::text(pl, idf, avg_field_len))
1040 .collect();
1041 Self::new(cursors, k, 1.0)
1042 }
1043
1044 #[inline]
1045 fn find_partition(&self) -> usize {
1046 let threshold = self.collector.threshold() * self.inv_heap_factor;
1050 self.prefix_sums.partition_point(|&sum| sum <= threshold)
1051 }
1052
1053 pub fn with_predicate(mut self, predicate: super::DocPredicate<'a>) -> Self {
1059 self.predicate = Some(predicate);
1060 self
1061 }
1062
1063 pub fn seed_threshold(&mut self, initial_threshold: f32) {
1065 self.collector.seed_threshold(initial_threshold);
1066 }
1067
1068 pub async fn execute(mut self) -> crate::Result<Vec<ScoredDoc>> {
1070 if self.cursors.is_empty() {
1071 return Ok(Vec::new());
1072 }
1073 bms_execute_loop!(self, ensure_block_loaded, advance, seek, .await)
1074 }
1075
1076 pub fn execute_sync(mut self) -> crate::Result<Vec<ScoredDoc>> {
1078 if self.cursors.is_empty() {
1079 return Ok(Vec::new());
1080 }
1081 bms_execute_loop!(self, ensure_block_loaded_sync, advance_sync, seek_sync,)
1082 }
1083}
1084
#[cfg(test)]
mod tests {
    use super::*;

    /// Doc ids of `results`, in order, for compact assertions.
    fn doc_ids(results: &[(DocId, f32, u16)]) -> Vec<DocId> {
        results.iter().map(|&(doc, _, _)| doc).collect()
    }

    #[test]
    fn test_score_collector_basic() {
        let mut collector = ScoreCollector::new(3);
        for (doc, score) in [(1, 1.0), (2, 2.0), (3, 3.0)] {
            collector.insert(doc, score);
        }
        assert_eq!(collector.threshold(), 1.0);

        // A better entry evicts the weakest and raises the bar.
        collector.insert(4, 4.0);
        assert_eq!(collector.threshold(), 2.0);

        let ranked = collector.into_sorted_results();
        assert_eq!(ranked.len(), 3);
        assert_eq!(doc_ids(&ranked), vec![4, 3, 2]);
    }

    #[test]
    fn test_score_collector_threshold() {
        let mut collector = ScoreCollector::new(2);
        collector.insert(1, 5.0);
        collector.insert(2, 3.0);
        assert_eq!(collector.threshold(), 3.0);

        // Below the bar: rejected.
        assert!(!collector.would_enter(2.0));
        assert!(!collector.insert(3, 2.0));

        // Above the bar: admitted, and the bar moves up.
        assert!(collector.would_enter(4.0));
        assert!(collector.insert(4, 4.0));
        assert_eq!(collector.threshold(), 4.0);
    }

    #[test]
    fn test_heap_entry_ordering() {
        // The reversed `Ord` makes `BinaryHeap` pop the lowest score first.
        let mut heap: BinaryHeap<HeapEntry> = [(1, 3.0), (2, 1.0), (3, 2.0)]
            .into_iter()
            .map(|(doc_id, score)| HeapEntry {
                doc_id,
                score,
                ordinal: 0,
            })
            .collect();

        let popped: Vec<f32> = std::iter::from_fn(|| heap.pop())
            .map(|entry| entry.score)
            .collect();
        assert_eq!(popped, vec![1.0, 2.0, 3.0]);
    }
}