1use std::cmp::Ordering;
10use std::collections::BinaryHeap;
11
12use log::debug;
13
14use crate::DocId;
15
16#[derive(Clone, Copy)]
18pub struct HeapEntry {
19 pub doc_id: DocId,
20 pub score: f32,
21 pub ordinal: u16,
22}
23
24impl PartialEq for HeapEntry {
25 fn eq(&self, other: &Self) -> bool {
26 self.score == other.score && self.doc_id == other.doc_id
27 }
28}
29
30impl Eq for HeapEntry {}
31
32impl Ord for HeapEntry {
33 fn cmp(&self, other: &Self) -> Ordering {
34 other
36 .score
37 .partial_cmp(&self.score)
38 .unwrap_or(Ordering::Equal)
39 .then_with(|| self.doc_id.cmp(&other.doc_id))
40 }
41}
42
43impl PartialOrd for HeapEntry {
44 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
45 Some(self.cmp(other))
46 }
47}
48
49pub struct ScoreCollector {
55 heap: BinaryHeap<HeapEntry>,
57 pub k: usize,
58 cached_threshold: f32,
61}
62
63impl ScoreCollector {
64 pub fn new(k: usize) -> Self {
66 let capacity = k.saturating_add(1).min(1_000_000);
68 Self {
69 heap: BinaryHeap::with_capacity(capacity),
70 k,
71 cached_threshold: 0.0,
72 }
73 }
74
75 #[inline]
77 pub fn threshold(&self) -> f32 {
78 self.cached_threshold
79 }
80
81 #[inline]
83 fn update_threshold(&mut self) {
84 self.cached_threshold = if self.heap.len() >= self.k {
85 self.heap.peek().map(|e| e.score).unwrap_or(0.0)
86 } else {
87 0.0
88 };
89 }
90
91 #[inline]
94 pub fn insert(&mut self, doc_id: DocId, score: f32) -> bool {
95 self.insert_with_ordinal(doc_id, score, 0)
96 }
97
98 #[inline]
101 pub fn insert_with_ordinal(&mut self, doc_id: DocId, score: f32, ordinal: u16) -> bool {
102 if self.heap.len() < self.k {
103 self.heap.push(HeapEntry {
104 doc_id,
105 score,
106 ordinal,
107 });
108 self.update_threshold();
109 true
110 } else if score > self.cached_threshold {
111 self.heap.push(HeapEntry {
112 doc_id,
113 score,
114 ordinal,
115 });
116 self.heap.pop(); self.update_threshold();
118 true
119 } else {
120 false
121 }
122 }
123
124 #[inline]
126 pub fn would_enter(&self, score: f32) -> bool {
127 self.heap.len() < self.k || score > self.cached_threshold
128 }
129
130 #[inline]
132 pub fn len(&self) -> usize {
133 self.heap.len()
134 }
135
136 #[inline]
138 pub fn is_empty(&self) -> bool {
139 self.heap.is_empty()
140 }
141
142 pub fn into_sorted_results(self) -> Vec<(DocId, f32, u16)> {
144 let heap_vec = self.heap.into_vec();
145 let mut results: Vec<(DocId, f32, u16)> = Vec::with_capacity(heap_vec.len());
146 for e in heap_vec {
147 results.push((e.doc_id, e.score, e.ordinal));
148 }
149
150 results.sort_by(|a, b| {
152 b.1.partial_cmp(&a.1)
153 .unwrap_or(Ordering::Equal)
154 .then_with(|| a.0.cmp(&b.0))
155 });
156
157 results
158 }
159}
160
161#[derive(Debug, Clone, Copy)]
163pub struct ScoredDoc {
164 pub doc_id: DocId,
165 pub score: f32,
166 pub ordinal: u16,
168}
169
170pub struct BmpExecutor<'a> {
182 sparse_index: &'a crate::segment::SparseIndex,
184 query_terms: Vec<(u32, f32)>,
186 k: usize,
188 heap_factor: f32,
190 predicate: Option<super::DocPredicate<'a>>,
192}
193
/// How many consecutive posting blocks of one term form a BMP superblock.
const BMP_SUPERBLOCK_SIZE: usize = 8;

/// How many consecutive superblocks of one term form a BMP megablock (the unit popped
/// from the priority queue).
const BMP_MEGABLOCK_SIZE: usize = 16;
202
/// A run of up to `BMP_SUPERBLOCK_SIZE` consecutive posting blocks of one term.
struct BmpSuperBlock {
    /// Sum over the run of `|query weight| * block max_weight`.
    contribution: f32,
    /// Index of the first block of the run, relative to the term's first skip entry.
    block_start: usize,
    /// Number of blocks in the run.
    block_count: usize,
}
212
/// Priority-queue entry for a run of up to `BMP_MEGABLOCK_SIZE` superblocks of one term.
struct BmpMegaBlockEntry {
    /// Sum of the contributions of the covered superblocks.
    contribution: f32,
    /// Index of the owning term in the query.
    term_idx: usize,
    /// Index of the first covered superblock in that term's superblock list.
    sb_start: usize,
    /// Number of covered superblocks.
    sb_count: usize,
}

impl PartialEq for BmpMegaBlockEntry {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so that `Eq` and `Ord` always agree.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for BmpMegaBlockEntry {}

impl Ord for BmpMegaBlockEntry {
    /// Orders by `contribution` only, so a `BinaryHeap` pops the largest bound first.
    ///
    /// `f32::total_cmp` keeps this a lawful total order even for NaN; the previous
    /// `partial_cmp(..).unwrap_or(Equal)` made NaN "equal" to every value, which breaks
    /// transitivity.
    fn cmp(&self, other: &Self) -> Ordering {
        self.contribution.total_cmp(&other.contribution)
    }
}

impl PartialOrd for BmpMegaBlockEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
247
/// Shared implementation of [`BmpExecutor::execute`] and `BmpExecutor::execute_sync`.
///
/// `$get_blocks` names the `SparseIndex` method that fetches a run of posting blocks.
/// `$($aw)*` is spliced directly after that call: `.await` for the async entry point,
/// nothing for the sync one. The expansion evaluates to `crate::Result<Vec<ScoredDoc>>`.
///
/// Outline:
/// 1. Group every term's skip entries into superblocks of up to `BMP_SUPERBLOCK_SIZE`
///    blocks. A superblock's bound is the sum of `|qw| * max_weight` over its blocks.
///    Superblocks are grouped into megablocks of up to `BMP_MEGABLOCK_SIZE`.
/// 2. Pop megablocks from a max-heap, largest bound first. Stop once the bound of
///    everything still unprocessed cannot exceed the heap-factor adjusted threshold.
/// 3. Inside a megablock, skip superblocks and single blocks whose bound plus the
///    remaining bound cannot exceed the threshold.
/// 4. Accumulate the rest in one of two places:
///    - a dense array, when the doc-id span is small and the ordinal is 0;
///    - a hash map keyed by `(doc_id << 16) | ordinal`, otherwise.
/// 5. Re-rank every accumulated score into the final top-k, applying the predicate.
macro_rules! bmp_execute_loop {
    ($self:ident, $get_blocks:ident, $($aw:tt)*) => {{
        use rustc_hash::FxHashMap;

        let num_terms = $self.query_terms.len();
        let si = $self.sparse_index;

        // Per term: its superblocks and the index of its first skip entry.
        let mut term_superblocks: Vec<Vec<BmpSuperBlock>> = Vec::with_capacity(num_terms);
        let mut term_skip_starts: Vec<usize> = Vec::with_capacity(num_terms);
        let mut global_min_doc = u32::MAX;
        let mut global_max_doc = 0u32;
        // Sum of the bounds of all blocks not yet processed or skipped.
        let mut total_remaining = 0.0f32;

        for &(dim_id, qw) in &$self.query_terms {
            let mut term_skip_start = 0usize;
            let mut superblocks = Vec::new();

            let abs_qw = qw.abs();
            if let Some((skip_start, skip_count, _global_max)) = si.get_skip_range(dim_id) {
                term_skip_start = skip_start;
                let mut sb_start = 0;
                while sb_start < skip_count {
                    let sb_count = (skip_count - sb_start).min(BMP_SUPERBLOCK_SIZE);
                    let mut sb_contribution = 0.0f32;
                    for j in 0..sb_count {
                        let skip = si.read_skip_entry(skip_start + sb_start + j);
                        sb_contribution += abs_qw * skip.max_weight;
                        global_min_doc = global_min_doc.min(skip.first_doc);
                        global_max_doc = global_max_doc.max(skip.last_doc);
                    }
                    total_remaining += sb_contribution;
                    superblocks.push(BmpSuperBlock {
                        contribution: sb_contribution,
                        block_start: sb_start,
                        block_count: sb_count,
                    });
                    sb_start += sb_count;
                }
            }
            term_skip_starts.push(term_skip_start);
            term_superblocks.push(superblocks);
        }

        let mut mega_queue: BinaryHeap<BmpMegaBlockEntry> = BinaryHeap::new();
        for (term_idx, superblocks) in term_superblocks.iter().enumerate() {
            let mut mb_start = 0;
            while mb_start < superblocks.len() {
                let mb_count = (superblocks.len() - mb_start).min(BMP_MEGABLOCK_SIZE);
                let mb_contribution: f32 = superblocks[mb_start..mb_start + mb_count]
                    .iter()
                    .map(|sb| sb.contribution)
                    .sum();
                mega_queue.push(BmpMegaBlockEntry {
                    contribution: mb_contribution,
                    term_idx,
                    sb_start: mb_start,
                    sb_count: mb_count,
                });
                mb_start += mb_count;
            }
        }

        // Widen to usize before adding 1 so a span covering the whole u32 range
        // cannot overflow.
        let doc_range = if global_max_doc >= global_min_doc {
            ((global_max_doc - global_min_doc) as usize).saturating_add(1)
        } else {
            0
        };
        // Dense accumulation indexed by `doc_id - global_min_doc` when the span is small.
        let use_flat = doc_range > 0 && doc_range <= 256 * 1024;
        let mut flat_scores: Vec<f32> = if use_flat {
            vec![0.0; doc_range]
        } else {
            Vec::new()
        };
        // Docs whose flat score became non-zero, in first-touch order.
        let mut dirty: Vec<u32> = if use_flat {
            Vec::with_capacity(4096)
        } else {
            Vec::new()
        };
        let mut multi_ord_accumulators: FxHashMap<u64, f32> = FxHashMap::default();

        let mut blocks_processed = 0u64;
        let mut blocks_skipped = 0u64;

        // Pruning heap, fed with partial scores; only its threshold is used.
        let mut top_k = ScoreCollector::new($self.k);

        let mut doc_ids_buf: Vec<u32> = Vec::with_capacity(256);
        let mut weights_buf: Vec<f32> = Vec::with_capacity(256);
        let mut ordinals_buf: Vec<u16> = Vec::with_capacity(256);

        // Early termination is held back until megablocks of up to `k` distinct terms
        // have been popped. Only terms that have at least one superblock can ever be
        // popped, so counting terms without postings would keep early termination
        // disabled for the whole query.
        let mut terms_warmed = vec![false; num_terms];
        let terms_with_postings = term_superblocks.iter().filter(|sbs| !sbs.is_empty()).count();
        let mut warmup_remaining = $self.k.min(terms_with_postings);

        while let Some(mega) = mega_queue.pop() {
            total_remaining -= mega.contribution;

            if !terms_warmed[mega.term_idx] {
                terms_warmed[mega.term_idx] = true;
                warmup_remaining = warmup_remaining.saturating_sub(1);
            }

            if warmup_remaining == 0 {
                let adjusted_threshold = top_k.threshold() * $self.heap_factor;
                // The megablock just popped has not been processed yet, so it belongs in
                // the remaining bound and in the skipped-block count.
                let remaining_bound = mega.contribution + total_remaining;
                if top_k.len() >= $self.k && remaining_bound <= adjusted_threshold {
                    let remaining_blocks: u64 = std::iter::once(&mega)
                        .chain(mega_queue.iter())
                        .map(|m| {
                            let sbs =
                                &term_superblocks[m.term_idx][m.sb_start..m.sb_start + m.sb_count];
                            sbs.iter().map(|sb| sb.block_count as u64).sum::<u64>()
                        })
                        .sum();
                    blocks_skipped += remaining_blocks;
                    debug!(
                        "BMP early termination after {} blocks: remaining={:.4} <= threshold={:.4}",
                        blocks_processed, remaining_bound, adjusted_threshold
                    );
                    break;
                }
            }

            let dim_id = $self.query_terms[mega.term_idx].0;
            let qw = $self.query_terms[mega.term_idx].1;
            let abs_qw = qw.abs();
            let skip_start = term_skip_starts[mega.term_idx];

            for sb in term_superblocks[mega.term_idx]
                .iter()
                .skip(mega.sb_start)
                .take(mega.sb_count)
            {
                if top_k.len() >= $self.k {
                    let adjusted_threshold = top_k.threshold() * $self.heap_factor;
                    if sb.contribution + total_remaining <= adjusted_threshold {
                        blocks_skipped += sb.block_count as u64;
                        continue;
                    }
                }

                let sb_blocks = si
                    .$get_blocks(dim_id, sb.block_start, sb.block_count)
                    $($aw)*?;

                // Snapshot for the per-block checks below. A stale (lower) threshold only
                // makes those checks more conservative.
                let adjusted_threshold2 = top_k.threshold() * $self.heap_factor;
                let dirty_start = dirty.len();

                for (blk_offset, block) in sb_blocks.into_iter().enumerate() {
                    let blk_idx = sb.block_start + blk_offset;

                    if top_k.len() >= $self.k {
                        let skip = si.read_skip_entry(skip_start + blk_idx);
                        let blk_contrib = abs_qw * skip.max_weight;
                        if blk_contrib + total_remaining <= adjusted_threshold2 {
                            blocks_skipped += 1;
                            continue;
                        }
                    }

                    block.decode_doc_ids_into(&mut doc_ids_buf);

                    if block.header.ordinal_bits == 0 && use_flat {
                        block.accumulate_scored_weights(
                            qw,
                            &doc_ids_buf,
                            &mut flat_scores,
                            global_min_doc,
                            &mut dirty,
                        );
                    } else {
                        block.decode_scored_weights_into(qw, &mut weights_buf);
                        let count = block.header.count as usize;

                        block.decode_ordinals_into(&mut ordinals_buf);
                        if use_flat {
                            for i in 0..count {
                                let doc_id = doc_ids_buf[i];
                                let ordinal = ordinals_buf[i];
                                let score_contribution = weights_buf[i];

                                if ordinal == 0 {
                                    let off = (doc_id - global_min_doc) as usize;
                                    if flat_scores[off] == 0.0 {
                                        dirty.push(doc_id);
                                    }
                                    flat_scores[off] += score_contribution;
                                } else {
                                    let key = (doc_id as u64) << 16 | ordinal as u64;
                                    let acc = multi_ord_accumulators.entry(key).or_insert(0.0);
                                    *acc += score_contribution;
                                    let acc_score = *acc;
                                    // Only documents that can be returned may raise the
                                    // pruning threshold.
                                    let eligible = match $self.predicate {
                                        Some(ref pred) => pred(doc_id),
                                        None => true,
                                    };
                                    if eligible {
                                        // NOTE(review): every increment re-inserts this
                                        // (doc, ordinal), so top_k may hold duplicates
                                        // and its threshold can be optimistic.
                                        top_k.insert_with_ordinal(doc_id, acc_score, ordinal);
                                    }
                                }
                            }
                        } else {
                            for i in 0..count {
                                let doc_id = doc_ids_buf[i];
                                let ordinal = ordinals_buf[i];
                                let key = (doc_id as u64) << 16 | ordinal as u64;
                                let acc = multi_ord_accumulators.entry(key).or_insert(0.0);
                                *acc += weights_buf[i];
                                let acc_score = *acc;
                                let eligible = match $self.predicate {
                                    Some(ref pred) => pred(doc_id),
                                    None => true,
                                };
                                if eligible {
                                    top_k.insert_with_ordinal(doc_id, acc_score, ordinal);
                                }
                            }
                        }
                    }

                    blocks_processed += 1;
                }

                // Feed documents first touched in this superblock into the pruning heap.
                // Documents rejected by the predicate are left out so they cannot raise
                // the threshold.
                for &doc_id in &dirty[dirty_start..] {
                    let eligible = match $self.predicate {
                        Some(ref pred) => pred(doc_id),
                        None => true,
                    };
                    if eligible {
                        let off = (doc_id - global_min_doc) as usize;
                        top_k.insert_with_ordinal(doc_id, flat_scores[off], 0);
                    }
                }
            }
        }

        // Final ranking over the complete accumulated scores.
        let mut final_top_k = ScoreCollector::new($self.k);

        let num_accumulators = if use_flat {
            for &doc_id in &dirty {
                if let Some(ref pred) = $self.predicate
                    && !pred(doc_id)
                {
                    continue;
                }
                let off = (doc_id - global_min_doc) as usize;
                let score = flat_scores[off];
                if score > 0.0 {
                    final_top_k.insert_with_ordinal(doc_id, score, 0);
                }
            }
            dirty.len() + multi_ord_accumulators.len()
        } else {
            multi_ord_accumulators.len()
        };

        for (key, score) in &multi_ord_accumulators {
            let doc_id = (*key >> 16) as crate::DocId;
            if let Some(ref pred) = $self.predicate
                && !pred(doc_id)
            {
                continue;
            }
            final_top_k.insert_with_ordinal(doc_id, *score, (*key & 0xFFFF) as u16);
        }

        let results: Vec<ScoredDoc> = final_top_k
            .into_sorted_results()
            .into_iter()
            .map(|(doc_id, score, ordinal)| ScoredDoc {
                doc_id,
                score,
                ordinal,
            })
            .collect();

        debug!(
            "BmpExecutor completed: blocks_processed={}, blocks_skipped={}, accumulators={}, flat={}, returned={}, top_score={:.4}",
            blocks_processed,
            blocks_skipped,
            num_accumulators,
            use_flat,
            results.len(),
            results.first().map(|r| r.score).unwrap_or(0.0)
        );

        Ok(results)
    }};
}
523
524impl<'a> BmpExecutor<'a> {
525 pub fn new(
530 sparse_index: &'a crate::segment::SparseIndex,
531 query_terms: Vec<(u32, f32)>,
532 k: usize,
533 heap_factor: f32,
534 ) -> Self {
535 Self {
536 sparse_index,
537 query_terms,
538 k,
539 heap_factor: heap_factor.clamp(0.0, 1.0),
540 predicate: None,
541 }
542 }
543
544 pub fn set_predicate(&mut self, predicate: Option<super::DocPredicate<'a>>) {
546 self.predicate = predicate;
547 }
548
549 pub async fn execute(self) -> crate::Result<Vec<ScoredDoc>> {
551 if self.query_terms.is_empty() {
552 return Ok(Vec::new());
553 }
554 bmp_execute_loop!(self, get_blocks_range, .await)
555 }
556
557 #[cfg(feature = "sync")]
559 pub fn execute_sync(self) -> crate::Result<Vec<ScoredDoc>> {
560 if self.query_terms.is_empty() {
561 return Ok(Vec::new());
562 }
563 bmp_execute_loop!(self, get_blocks_range_sync,)
564 }
565}
566
567pub struct MaxScoreExecutor<'a> {
578 cursors: Vec<TermCursor<'a>>,
579 prefix_sums: Vec<f32>,
580 collector: ScoreCollector,
581 heap_factor: f32,
582 predicate: Option<super::DocPredicate<'a>>,
583}
584
585pub(crate) struct TermCursor<'a> {
594 pub max_score: f32,
595 num_blocks: usize,
596 block_idx: usize,
598 doc_ids: Vec<u32>,
599 scores: Vec<f32>,
600 ordinals: Vec<u16>,
601 pos: usize,
602 block_loaded: bool,
603 exhausted: bool,
604 variant: CursorVariant<'a>,
606}
607
608enum CursorVariant<'a> {
609 Text {
611 list: crate::structures::BlockPostingList,
612 idf: f32,
613 avg_field_len: f32,
614 tfs: Vec<u32>, },
616 Sparse {
618 si: &'a crate::segment::SparseIndex,
619 query_weight: f32,
620 skip_start: usize,
621 block_data_offset: u64,
622 },
623}
624
impl<'a> TermCursor<'a> {
    /// Creates a cursor over a full-text posting list scored with BM25.
    ///
    /// `max_score` is `bm25_upper_bound` evaluated at the list's maximum term frequency
    /// (raised to at least 1). The cursor starts exhausted when the list has no blocks.
    pub fn text(
        posting_list: crate::structures::BlockPostingList,
        idf: f32,
        avg_field_len: f32,
    ) -> Self {
        let max_tf = posting_list.max_tf() as f32;
        let max_score = super::bm25_upper_bound(max_tf.max(1.0), idf);
        let num_blocks = posting_list.num_blocks();
        Self {
            max_score,
            num_blocks,
            block_idx: 0,
            doc_ids: Vec::with_capacity(128),
            scores: Vec::with_capacity(128),
            ordinals: Vec::new(),
            pos: 0,
            block_loaded: false,
            exhausted: num_blocks == 0,
            variant: CursorVariant::Text {
                list: posting_list,
                idf,
                avg_field_len,
                tfs: Vec::with_capacity(128),
            },
        }
    }

    /// Creates a cursor over one sparse dimension covering skip entries
    /// `skip_start..skip_start + skip_count`.
    ///
    /// `max_score` is `|query_weight| * global_max_weight`. The cursor starts exhausted
    /// when `skip_count == 0`.
    pub fn sparse(
        si: &'a crate::segment::SparseIndex,
        query_weight: f32,
        skip_start: usize,
        skip_count: usize,
        global_max_weight: f32,
        block_data_offset: u64,
    ) -> Self {
        Self {
            max_score: query_weight.abs() * global_max_weight,
            num_blocks: skip_count,
            block_idx: 0,
            doc_ids: Vec::with_capacity(256),
            scores: Vec::with_capacity(256),
            ordinals: Vec::with_capacity(256),
            pos: 0,
            block_loaded: false,
            exhausted: skip_count == 0,
            variant: CursorVariant::Sparse {
                si,
                query_weight,
                skip_start,
                block_data_offset,
            },
        }
    }

    /// First doc id of block `idx`, read from block metadata without decoding the block.
    /// Text lists fall back to `u32::MAX` when no value is available.
    #[inline]
    fn block_first_doc(&self, idx: usize) -> DocId {
        match &self.variant {
            CursorVariant::Text { list, .. } => list.block_first_doc(idx).unwrap_or(u32::MAX),
            CursorVariant::Sparse { si, skip_start, .. } => {
                si.read_skip_entry(*skip_start + idx).first_doc
            }
        }
    }

    /// Last doc id of block `idx`, read from block metadata without decoding the block.
    /// Text lists fall back to `0` when no value is available.
    #[inline]
    fn block_last_doc(&self, idx: usize) -> DocId {
        match &self.variant {
            CursorVariant::Text { list, .. } => list.block_last_doc(idx).unwrap_or(0),
            CursorVariant::Sparse { si, skip_start, .. } => {
                si.read_skip_entry(*skip_start + idx).last_doc
            }
        }
    }

    /// Current doc id, or `u32::MAX` once exhausted.
    ///
    /// When the current block is not decoded yet, this is the block's first doc taken
    /// from metadata.
    #[inline]
    pub fn doc(&self) -> DocId {
        if self.exhausted {
            return u32::MAX;
        }
        if self.block_loaded {
            self.doc_ids.get(self.pos).copied().unwrap_or(u32::MAX)
        } else {
            self.block_first_doc(self.block_idx)
        }
    }

    /// Ordinal of the current posting; `0` when the block is not loaded or carries no
    /// ordinals.
    #[inline]
    pub fn ordinal(&self) -> u16 {
        if !self.block_loaded || self.ordinals.is_empty() {
            return 0;
        }
        self.ordinals.get(self.pos).copied().unwrap_or(0)
    }

    /// Score of the current posting; `0.0` when the block is not loaded.
    #[inline]
    pub fn score(&self) -> f32 {
        if !self.block_loaded {
            return 0.0;
        }
        self.scores.get(self.pos).copied().unwrap_or(0.0)
    }

    /// Upper bound on the score of any posting in the current block, from block metadata
    /// (`0.0` once exhausted).
    #[inline]
    pub fn current_block_max_score(&self) -> f32 {
        if self.exhausted {
            return 0.0;
        }
        match &self.variant {
            CursorVariant::Text { list, idf, .. } => {
                let block_max_tf = list.block_max_tf(self.block_idx).unwrap_or(0) as f32;
                super::bm25_upper_bound(block_max_tf.max(1.0), *idf)
            }
            CursorVariant::Sparse {
                si,
                query_weight,
                skip_start,
                ..
            } => query_weight.abs() * si.read_skip_entry(*skip_start + self.block_idx).max_weight,
        }
    }

    /// Moves to the start of the next block without decoding it.
    ///
    /// Returns that block's first doc, or `u32::MAX` if there is no next block.
    pub fn skip_to_next_block(&mut self) -> DocId {
        if self.exhausted {
            return u32::MAX;
        }
        self.block_idx += 1;
        self.block_loaded = false;
        if self.block_idx >= self.num_blocks {
            self.exhausted = true;
            return u32::MAX;
        }
        self.block_first_doc(self.block_idx)
    }

    /// Steps one posting forward inside the loaded block.
    ///
    /// At the end of the block it moves to the next (not yet decoded) block. `pos` is
    /// reset when that block gets loaded.
    #[inline]
    fn advance_pos(&mut self) -> DocId {
        self.pos += 1;
        if self.pos >= self.doc_ids.len() {
            self.block_idx += 1;
            self.block_loaded = false;
            if self.block_idx >= self.num_blocks {
                self.exhausted = true;
                return u32::MAX;
            }
        }
        self.doc()
    }

    /// Decodes block `block_idx` if it is not loaded yet (async loader for sparse).
    ///
    /// Returns `Ok(false)` when the cursor is (or becomes) exhausted.
    pub async fn ensure_block_loaded(&mut self) -> crate::Result<bool> {
        if self.exhausted || self.block_loaded {
            return Ok(!self.exhausted);
        }
        match &mut self.variant {
            CursorVariant::Text {
                list,
                idf,
                avg_field_len,
                tfs,
            } => {
                if list.decode_block_into(self.block_idx, &mut self.doc_ids, tfs) {
                    self.scores.clear();
                    self.scores.reserve(tfs.len());
                    for &tf in tfs.iter() {
                        let tf = tf as f32;
                        // NOTE(review): `tf` is also passed as the third argument; if
                        // bm25_score expects a document length here, confirm this
                        // approximation is intended.
                        self.scores
                            .push(super::bm25_score(tf, *idf, tf, *avg_field_len));
                    }
                    self.pos = 0;
                    self.block_loaded = true;
                    Ok(true)
                } else {
                    self.exhausted = true;
                    Ok(false)
                }
            }
            CursorVariant::Sparse {
                si,
                query_weight,
                skip_start,
                block_data_offset,
                ..
            } => {
                let block = si
                    .load_block_direct(*skip_start, *block_data_offset, self.block_idx)
                    .await?;
                match block {
                    Some(b) => {
                        b.decode_doc_ids_into(&mut self.doc_ids);
                        b.decode_ordinals_into(&mut self.ordinals);
                        b.decode_scored_weights_into(*query_weight, &mut self.scores);
                        self.pos = 0;
                        self.block_loaded = true;
                        Ok(true)
                    }
                    None => {
                        self.exhausted = true;
                        Ok(false)
                    }
                }
            }
        }
    }

    /// Synchronous counterpart of [`Self::ensure_block_loaded`]; the text path is
    /// identical.
    pub fn ensure_block_loaded_sync(&mut self) -> crate::Result<bool> {
        if self.exhausted || self.block_loaded {
            return Ok(!self.exhausted);
        }
        match &mut self.variant {
            CursorVariant::Text {
                list,
                idf,
                avg_field_len,
                tfs,
            } => {
                if list.decode_block_into(self.block_idx, &mut self.doc_ids, tfs) {
                    self.scores.clear();
                    self.scores.reserve(tfs.len());
                    for &tf in tfs.iter() {
                        let tf = tf as f32;
                        self.scores
                            .push(super::bm25_score(tf, *idf, tf, *avg_field_len));
                    }
                    self.pos = 0;
                    self.block_loaded = true;
                    Ok(true)
                } else {
                    self.exhausted = true;
                    Ok(false)
                }
            }
            CursorVariant::Sparse {
                si,
                query_weight,
                skip_start,
                block_data_offset,
                ..
            } => {
                let block =
                    si.load_block_direct_sync(*skip_start, *block_data_offset, self.block_idx)?;
                match block {
                    Some(b) => {
                        b.decode_doc_ids_into(&mut self.doc_ids);
                        b.decode_ordinals_into(&mut self.ordinals);
                        b.decode_scored_weights_into(*query_weight, &mut self.scores);
                        self.pos = 0;
                        self.block_loaded = true;
                        Ok(true)
                    }
                    None => {
                        self.exhausted = true;
                        Ok(false)
                    }
                }
            }
        }
    }

    /// Loads the current block if needed, then moves to the next posting.
    ///
    /// Returns the new doc (`u32::MAX` once exhausted).
    pub async fn advance(&mut self) -> crate::Result<DocId> {
        if self.exhausted {
            return Ok(u32::MAX);
        }
        self.ensure_block_loaded().await?;
        if self.exhausted {
            return Ok(u32::MAX);
        }
        Ok(self.advance_pos())
    }

    /// Synchronous counterpart of [`Self::advance`].
    pub fn advance_sync(&mut self) -> crate::Result<DocId> {
        if self.exhausted {
            return Ok(u32::MAX);
        }
        self.ensure_block_loaded_sync()?;
        if self.exhausted {
            return Ok(u32::MAX);
        }
        Ok(self.advance_pos())
    }

    /// Moves to the first posting with doc id `>= target`; returns it, or `u32::MAX`.
    ///
    /// This is a two-phase seek:
    /// 1. `seek_prepare` answers from the loaded block or picks the target block.
    /// 2. The chosen block is loaded, and `seek_finish` positions inside it.
    ///
    /// When `seek_finish` reports that no doc in the block qualified, the following
    /// block is loaded as well.
    pub async fn seek(&mut self, target: DocId) -> crate::Result<DocId> {
        if let Some(doc) = self.seek_prepare(target) {
            return Ok(doc);
        }
        self.ensure_block_loaded().await?;
        if self.seek_finish(target) {
            self.ensure_block_loaded().await?;
        }
        Ok(self.doc())
    }

    /// Synchronous counterpart of [`Self::seek`].
    pub fn seek_sync(&mut self, target: DocId) -> crate::Result<DocId> {
        if let Some(doc) = self.seek_prepare(target) {
            return Ok(doc);
        }
        self.ensure_block_loaded_sync()?;
        if self.seek_finish(target) {
            self.ensure_block_loaded_sync()?;
        }
        Ok(self.doc())
    }

    /// First phase of a seek.
    ///
    /// Returns `Some(doc)` when the seek is already resolved. That happens when the
    /// cursor is exhausted, when the target lies in the loaded block, or when the
    /// current doc is already `>= target`.
    ///
    /// Otherwise it selects the first block whose last doc is `>= target`, marks it
    /// unloaded, and returns `None` so the caller loads it.
    fn seek_prepare(&mut self, target: DocId) -> Option<DocId> {
        if self.exhausted {
            return Some(u32::MAX);
        }

        // Fast path: the target lies inside the already-decoded block.
        if self.block_loaded
            && let Some(&last) = self.doc_ids.last()
        {
            if last >= target && self.doc_ids[self.pos] < target {
                let remaining = &self.doc_ids[self.pos..];
                self.pos += crate::structures::simd::find_first_ge_u32(remaining, target);
                if self.pos >= self.doc_ids.len() {
                    self.block_idx += 1;
                    self.block_loaded = false;
                    if self.block_idx >= self.num_blocks {
                        self.exhausted = true;
                        return Some(u32::MAX);
                    }
                }
                return Some(self.doc());
            }
            if self.doc_ids[self.pos] >= target {
                return Some(self.doc());
            }
        }

        // Choose the target block from metadata, never moving backwards.
        let lo = match &self.variant {
            CursorVariant::Text { list, .. } => match list.seek_block(target, self.block_idx) {
                Some(idx) => idx,
                None => {
                    self.exhausted = true;
                    return Some(u32::MAX);
                }
            },
            CursorVariant::Sparse { .. } => {
                // Binary search for the first block (from the current one) whose last
                // doc is >= target.
                let mut lo = self.block_idx;
                let mut hi = self.num_blocks;
                while lo < hi {
                    let mid = lo + (hi - lo) / 2;
                    if self.block_last_doc(mid) < target {
                        lo = mid + 1;
                    } else {
                        hi = mid;
                    }
                }
                lo
            }
        };
        if lo >= self.num_blocks {
            self.exhausted = true;
            return Some(u32::MAX);
        }
        if lo != self.block_idx || !self.block_loaded {
            self.block_idx = lo;
            self.block_loaded = false;
        }
        None
    }

    /// Second phase of a seek: positions on the first doc `>= target` in the freshly
    /// loaded block.
    ///
    /// Returns `true` when no such doc exists there and the cursor moved to the next
    /// block, which the caller must then load.
    #[inline]
    fn seek_finish(&mut self, target: DocId) -> bool {
        if self.exhausted {
            return false;
        }
        self.pos = crate::structures::simd::find_first_ge_u32(&self.doc_ids, target);
        if self.pos >= self.doc_ids.len() {
            self.block_idx += 1;
            self.block_loaded = false;
            if self.block_idx >= self.num_blocks {
                self.exhausted = true;
                return false;
            }
            return true;
        }
        false
    }
}
1021
/// Shared implementation of [`MaxScoreExecutor::execute`] and `execute_sync`.
///
/// `$ensure`, `$advance` and `$seek` name the cursor methods to call. The trailing
/// `$($aw)*` tokens (`.await` or nothing) are spliced after each of those calls. The
/// expansion evaluates to `crate::Result<Vec<ScoredDoc>>`.
///
/// Each iteration:
/// 1. Recompute the essential / non-essential partition from the current threshold.
/// 2. Take the smallest doc among the essential cursors as the candidate.
/// 3. Try the following in order:
///    - predicate rejection;
///    - a bound using global term maxima;
///    - a bound using current block maxima;
///    - full scoring: essential terms first, then non-essential terms only while they
///      can still lift the candidate over the threshold.
macro_rules! bms_execute_loop {
    ($self:ident, $ensure:ident, $advance:ident, $seek:ident, $($aw:tt)*) => {{
        let n = $self.cursors.len();

        // Decode the first block of every cursor up front.
        for cursor in &mut $self.cursors {
            cursor.$ensure() $($aw)* ?;
        }

        let mut docs_scored = 0u64;
        let mut docs_skipped = 0u64;
        let mut blocks_skipped = 0u64;
        let mut conjunction_skipped = 0u64;
        // (ordinal, score) contributions gathered for the current candidate.
        let mut ordinal_scores: Vec<(u16, f32)> = Vec::with_capacity(n * 2);

        loop {
            // Cursors [0, partition) are non-essential: even together their upper bounds
            // cannot beat the threshold, so they never nominate candidates.
            let partition = $self.find_partition();
            if partition >= n {
                break;
            }

            let mut min_doc = u32::MAX;
            for i in partition..n {
                let doc = $self.cursors[i].doc();
                if doc < min_doc {
                    min_doc = doc;
                }
            }
            if min_doc == u32::MAX {
                break;
            }

            if let Some(ref pred) = $self.predicate {
                if !pred(min_doc) {
                    for i in partition..n {
                        if $self.cursors[i].doc() == min_doc {
                            $self.cursors[i].$ensure() $($aw)* ?;
                            $self.cursors[i].$advance() $($aw)* ?;
                        }
                    }
                    docs_skipped += 1;
                    continue;
                }
            }

            let non_essential_upper = if partition > 0 {
                $self.prefix_sums[partition - 1]
            } else {
                0.0
            };
            let adjusted_threshold = $self.collector.threshold() * $self.heap_factor;

            // Bound on `min_doc` using each matching term's global maximum.
            if $self.collector.len() >= $self.collector.k {
                let present_upper: f32 = (partition..n)
                    .filter(|&i| $self.cursors[i].doc() == min_doc)
                    .map(|i| $self.cursors[i].max_score)
                    .sum();

                if present_upper + non_essential_upper <= adjusted_threshold {
                    for i in partition..n {
                        if $self.cursors[i].doc() == min_doc {
                            $self.cursors[i].$ensure() $($aw)* ?;
                            $self.cursors[i].$advance() $($aw)* ?;
                        }
                    }
                    conjunction_skipped += 1;
                    continue;
                }
            }

            // Bound using the current block maxima of the cursors positioned on `min_doc`.
            if $self.collector.len() >= $self.collector.k {
                let block_max_sum: f32 = (partition..n)
                    .filter(|&i| $self.cursors[i].doc() == min_doc)
                    .map(|i| $self.cursors[i].current_block_max_score())
                    .sum();

                if block_max_sum + non_essential_upper <= adjusted_threshold {
                    // The bound only holds for docs that are covered by the current block
                    // of every pivot cursor and that no other essential cursor can
                    // contain. So skip up to the end of the shortest pivot block, but no
                    // further than the next doc of any other essential cursor. Skipping
                    // whole pivot blocks regardless of that would drop documents those
                    // other cursors (or later pivot blocks) could still push over the
                    // threshold.
                    let mut skip_target = u32::MAX;
                    for i in partition..n {
                        let cursor = &$self.cursors[i];
                        let doc = cursor.doc();
                        let limit = if doc == min_doc {
                            cursor.block_last_doc(cursor.block_idx).saturating_add(1)
                        } else {
                            doc
                        };
                        skip_target = skip_target.min(limit);
                    }
                    // Always make progress, even if block metadata is unavailable
                    // (`block_last_doc` falls back to 0 for text lists).
                    let skip_target = skip_target.max(min_doc + 1);
                    for i in partition..n {
                        if $self.cursors[i].doc() == min_doc {
                            $self.cursors[i].$seek(skip_target) $($aw)* ?;
                        }
                    }
                    blocks_skipped += 1;
                    continue;
                }
            }

            // Score the essential terms; a term may hold several postings (ordinals) for
            // one doc.
            ordinal_scores.clear();
            for i in partition..n {
                if $self.cursors[i].doc() == min_doc {
                    $self.cursors[i].$ensure() $($aw)* ?;
                    while $self.cursors[i].doc() == min_doc {
                        ordinal_scores.push(($self.cursors[i].ordinal(), $self.cursors[i].score()));
                        $self.cursors[i].$advance() $($aw)* ?;
                    }
                }
            }

            let essential_total: f32 = ordinal_scores.iter().map(|(_, s)| *s).sum();
            if $self.collector.len() >= $self.collector.k
                && essential_total + non_essential_upper <= adjusted_threshold
            {
                docs_skipped += 1;
                continue;
            }

            // Probe non-essential terms from the largest bound down. Stop as soon as even
            // all remaining ones (prefix_sums[i]) cannot lift the candidate over the
            // threshold.
            let mut running_total = essential_total;
            for i in (0..partition).rev() {
                if $self.collector.len() >= $self.collector.k
                    && running_total + $self.prefix_sums[i] <= adjusted_threshold
                {
                    break;
                }

                let doc = $self.cursors[i].$seek(min_doc) $($aw)* ?;
                if doc == min_doc {
                    while $self.cursors[i].doc() == min_doc {
                        let s = $self.cursors[i].score();
                        running_total += s;
                        ordinal_scores.push(($self.cursors[i].ordinal(), s));
                        $self.cursors[i].$advance() $($aw)* ?;
                    }
                }
            }

            // Sum contributions per ordinal and offer each (doc, ordinal) to the collector.
            if ordinal_scores.len() == 1 {
                let (ord, score) = ordinal_scores[0];
                if $self.collector.insert_with_ordinal(min_doc, score, ord) {
                    docs_scored += 1;
                } else {
                    docs_skipped += 1;
                }
            } else if !ordinal_scores.is_empty() {
                if ordinal_scores.len() > 2 {
                    ordinal_scores.sort_unstable_by_key(|(ord, _)| *ord);
                } else if ordinal_scores[0].0 > ordinal_scores[1].0 {
                    ordinal_scores.swap(0, 1);
                }
                let mut j = 0;
                while j < ordinal_scores.len() {
                    let current_ord = ordinal_scores[j].0;
                    let mut score = 0.0f32;
                    while j < ordinal_scores.len() && ordinal_scores[j].0 == current_ord {
                        score += ordinal_scores[j].1;
                        j += 1;
                    }
                    if $self
                        .collector
                        .insert_with_ordinal(min_doc, score, current_ord)
                    {
                        docs_scored += 1;
                    } else {
                        docs_skipped += 1;
                    }
                }
            }
        }

        let results: Vec<ScoredDoc> = $self
            .collector
            .into_sorted_results()
            .into_iter()
            .map(|(doc_id, score, ordinal)| ScoredDoc {
                doc_id,
                score,
                ordinal,
            })
            .collect();

        debug!(
            "MaxScoreExecutor: scored={}, skipped={}, blocks_skipped={}, conjunction_skipped={}, returned={}, top_score={:.4}",
            docs_scored,
            docs_skipped,
            blocks_skipped,
            conjunction_skipped,
            results.len(),
            results.first().map(|r| r.score).unwrap_or(0.0)
        );

        Ok(results)
    }};
}
1218
1219impl<'a> MaxScoreExecutor<'a> {
1220 pub(crate) fn new(mut cursors: Vec<TermCursor<'a>>, k: usize, heap_factor: f32) -> Self {
1225 cursors.sort_by(|a, b| {
1227 a.max_score
1228 .partial_cmp(&b.max_score)
1229 .unwrap_or(Ordering::Equal)
1230 });
1231
1232 let mut prefix_sums = Vec::with_capacity(cursors.len());
1233 let mut cumsum = 0.0f32;
1234 for c in &cursors {
1235 cumsum += c.max_score;
1236 prefix_sums.push(cumsum);
1237 }
1238
1239 debug!(
1240 "Creating MaxScoreExecutor: num_cursors={}, k={}, total_upper={:.4}, heap_factor={:.2}",
1241 cursors.len(),
1242 k,
1243 cumsum,
1244 heap_factor
1245 );
1246
1247 Self {
1248 cursors,
1249 prefix_sums,
1250 collector: ScoreCollector::new(k),
1251 heap_factor: heap_factor.clamp(0.0, 1.0),
1252 predicate: None,
1253 }
1254 }
1255
1256 pub fn set_predicate(&mut self, predicate: Option<super::DocPredicate<'a>>) {
1258 self.predicate = predicate;
1259 }
1260
1261 pub fn sparse(
1265 sparse_index: &'a crate::segment::SparseIndex,
1266 query_terms: Vec<(u32, f32)>,
1267 k: usize,
1268 heap_factor: f32,
1269 ) -> Self {
1270 let cursors: Vec<TermCursor<'a>> = query_terms
1271 .iter()
1272 .filter_map(|&(dim_id, qw)| {
1273 let (skip_start, skip_count, global_max, block_data_offset) =
1274 sparse_index.get_skip_range_full(dim_id)?;
1275 Some(TermCursor::sparse(
1276 sparse_index,
1277 qw,
1278 skip_start,
1279 skip_count,
1280 global_max,
1281 block_data_offset,
1282 ))
1283 })
1284 .collect();
1285 Self::new(cursors, k, heap_factor)
1286 }
1287
1288 pub fn text(
1292 posting_lists: Vec<(crate::structures::BlockPostingList, f32)>,
1293 avg_field_len: f32,
1294 k: usize,
1295 ) -> Self {
1296 let cursors: Vec<TermCursor<'a>> = posting_lists
1297 .into_iter()
1298 .map(|(pl, idf)| TermCursor::text(pl, idf, avg_field_len))
1299 .collect();
1300 Self::new(cursors, k, 1.0)
1301 }
1302
1303 #[inline]
1304 fn find_partition(&self) -> usize {
1305 let threshold = self.collector.threshold() * self.heap_factor;
1306 self.prefix_sums.partition_point(|&sum| sum <= threshold)
1307 }
1308
1309 pub async fn execute(mut self) -> crate::Result<Vec<ScoredDoc>> {
1311 if self.cursors.is_empty() {
1312 return Ok(Vec::new());
1313 }
1314 bms_execute_loop!(self, ensure_block_loaded, advance, seek, .await)
1315 }
1316
1317 pub fn execute_sync(mut self) -> crate::Result<Vec<ScoredDoc>> {
1319 if self.cursors.is_empty() {
1320 return Ok(Vec::new());
1321 }
1322 bms_execute_loop!(self, ensure_block_loaded_sync, advance_sync, seek_sync,)
1323 }
1324}
1325
#[cfg(test)]
mod tests {
    use super::*;

    /// Filling the collector sets the threshold to the worst kept score; a better entry
    /// evicts that worst entry.
    #[test]
    fn test_score_collector_basic() {
        let mut collector = ScoreCollector::new(3);

        for (doc, score) in [(1, 1.0), (2, 2.0), (3, 3.0)] {
            collector.insert(doc, score);
        }
        assert_eq!(collector.threshold(), 1.0);

        collector.insert(4, 4.0);
        assert_eq!(collector.threshold(), 2.0);

        let ids: Vec<DocId> = collector
            .into_sorted_results()
            .iter()
            .map(|&(doc_id, _, _)| doc_id)
            .collect();
        assert_eq!(ids, vec![4, 3, 2]);
    }

    /// Entries at or below the threshold are rejected; entries above it get in.
    #[test]
    fn test_score_collector_threshold() {
        let mut collector = ScoreCollector::new(2);

        collector.insert(1, 5.0);
        collector.insert(2, 3.0);
        assert_eq!(collector.threshold(), 3.0);

        let low = 2.0;
        assert!(!collector.would_enter(low));
        assert!(!collector.insert(3, low));

        let high = 4.0;
        assert!(collector.would_enter(high));
        assert!(collector.insert(4, high));
        assert_eq!(collector.threshold(), 4.0);
    }

    /// The reversed `HeapEntry` ordering makes `BinaryHeap` pop the lowest score first.
    #[test]
    fn test_heap_entry_ordering() {
        let mut heap: BinaryHeap<HeapEntry> = [(1, 3.0), (2, 1.0), (3, 2.0)]
            .into_iter()
            .map(|(doc_id, score)| HeapEntry {
                doc_id,
                score,
                ordinal: 0,
            })
            .collect();

        let popped: Vec<f32> = std::iter::from_fn(|| heap.pop().map(|e| e.score)).collect();
        assert_eq!(popped, vec![1.0, 2.0, 3.0]);
    }
}