1use std::cmp::Ordering;
10use std::collections::BinaryHeap;
11
12use log::debug;
13
14use crate::DocId;
15
16#[derive(Clone, Copy)]
18pub struct HeapEntry {
19 pub doc_id: DocId,
20 pub score: f32,
21 pub ordinal: u16,
22}
23
24impl PartialEq for HeapEntry {
25 fn eq(&self, other: &Self) -> bool {
26 self.score == other.score && self.doc_id == other.doc_id
27 }
28}
29
30impl Eq for HeapEntry {}
31
32impl Ord for HeapEntry {
33 fn cmp(&self, other: &Self) -> Ordering {
34 other
37 .score
38 .total_cmp(&self.score)
39 .then(self.doc_id.cmp(&other.doc_id))
40 }
41}
42
43impl PartialOrd for HeapEntry {
44 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
45 Some(self.cmp(other))
46 }
47}
48
49pub struct ScoreCollector {
62 heap: BinaryHeap<HeapEntry>,
64 pub k: usize,
65 cached_threshold: f32,
68}
69
70impl ScoreCollector {
71 pub fn new(k: usize) -> Self {
73 let capacity = k.saturating_add(1).min(1_000_000);
75 Self {
76 heap: BinaryHeap::with_capacity(capacity),
77 k,
78 cached_threshold: 0.0,
79 }
80 }
81
82 #[inline]
84 pub fn threshold(&self) -> f32 {
85 self.cached_threshold
86 }
87
88 #[inline]
90 fn update_threshold(&mut self) {
91 self.cached_threshold = if self.heap.len() >= self.k {
92 self.heap.peek().map(|e| e.score).unwrap_or(0.0)
93 } else {
94 0.0
95 };
96 }
97
98 #[inline]
101 pub fn insert(&mut self, doc_id: DocId, score: f32) -> bool {
102 self.insert_with_ordinal(doc_id, score, 0)
103 }
104
105 #[inline]
108 pub fn insert_with_ordinal(&mut self, doc_id: DocId, score: f32, ordinal: u16) -> bool {
109 if self.heap.len() < self.k {
110 self.heap.push(HeapEntry {
111 doc_id,
112 score,
113 ordinal,
114 });
115 if self.heap.len() == self.k {
117 self.update_threshold();
118 }
119 true
120 } else if score > self.cached_threshold {
121 self.heap.push(HeapEntry {
122 doc_id,
123 score,
124 ordinal,
125 });
126 self.heap.pop(); self.update_threshold();
128 true
129 } else {
130 false
131 }
132 }
133
134 #[inline]
136 pub fn would_enter(&self, score: f32) -> bool {
137 self.heap.len() < self.k || score > self.cached_threshold
138 }
139
140 #[inline]
142 pub fn len(&self) -> usize {
143 self.heap.len()
144 }
145
146 #[inline]
148 pub fn is_empty(&self) -> bool {
149 self.heap.is_empty()
150 }
151
152 pub fn into_sorted_results(self) -> Vec<(DocId, f32, u16)> {
154 let mut results: Vec<(DocId, f32, u16)> = self
155 .heap
156 .into_vec()
157 .into_iter()
158 .map(|e| (e.doc_id, e.score, e.ordinal))
159 .collect();
160
161 results.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then(a.0.cmp(&b.0)));
163
164 results
165 }
166}
167
168#[derive(Debug, Clone, Copy)]
170pub struct ScoredDoc {
171 pub doc_id: DocId,
172 pub score: f32,
173 pub ordinal: u16,
175}
176
177pub struct MaxScoreExecutor<'a> {
188 cursors: Vec<TermCursor<'a>>,
189 prefix_sums: Vec<f32>,
190 collector: ScoreCollector,
191 heap_factor: f32,
192 predicate: Option<super::DocPredicate<'a>>,
193}
194
195pub(crate) struct TermCursor<'a> {
204 pub max_score: f32,
205 num_blocks: usize,
206 block_idx: usize,
208 doc_ids: Vec<u32>,
209 scores: Vec<f32>,
210 ordinals: Vec<u16>,
211 pos: usize,
212 block_loaded: bool,
213 exhausted: bool,
214 lazy_ordinals: bool,
218 ordinals_loaded: bool,
220 current_sparse_block: Option<crate::structures::SparseBlock>,
222 variant: CursorVariant<'a>,
224}
225
226enum CursorVariant<'a> {
227 Text {
229 list: crate::structures::BlockPostingList,
230 idf: f32,
231 b_over_avgfl: f32,
233 one_minus_b: f32,
235 tfs: Vec<u32>, },
237 Sparse {
239 si: &'a crate::segment::SparseIndex,
240 query_weight: f32,
241 skip_start: usize,
242 block_data_offset: u64,
243 },
244}
245
/// Shared body of `TermCursor::ensure_block_loaded` and `ensure_block_loaded_sync`.
///
/// Decodes block `block_idx` into `doc_ids` / `scores` (and `ordinals` unless in lazy
/// mode) if it is not already loaded. Evaluates to `Ok(true)` when a block is available
/// and `Ok(false)` once the cursor is exhausted. `$load_block_fn` names the sparse-index
/// loader to call, and `$($aw)*` is `.await` or nothing.
macro_rules! cursor_ensure_block {
    ($self:ident, $load_block_fn:ident, $($aw:tt)*) => {{
        // Already decoded (or nothing left): report whether a block is available.
        if $self.exhausted || $self.block_loaded {
            return Ok(!$self.exhausted);
        }
        match &mut $self.variant {
            CursorVariant::Text {
                list,
                idf,
                b_over_avgfl,
                one_minus_b,
                tfs,
            } => {
                if list.decode_block_into($self.block_idx, &mut $self.doc_ids, tfs) {
                    let idf_val = *idf;
                    let b_avg = *b_over_avgfl;
                    let one_b = *one_minus_b;
                    $self.scores.clear();
                    $self.scores.reserve(tfs.len());
                    for &tf in tfs.iter() {
                        // NOTE(review): the length-normalisation term uses `tf` where
                        // textbook BM25 uses the document length; presumably per-doc
                        // lengths are not available here — confirm this is intended.
                        let tf = tf as f32;
                        let length_norm = one_b + b_avg * tf;
                        let tf_norm = (tf * (super::BM25_K1 + 1.0))
                            / (tf + super::BM25_K1 * length_norm);
                        $self.scores.push(idf_val * tf_norm);
                    }
                    $self.pos = 0;
                    $self.block_loaded = true;
                    Ok(true)
                } else {
                    // The posting list has no block at this index.
                    $self.exhausted = true;
                    Ok(false)
                }
            }
            CursorVariant::Sparse {
                si,
                query_weight,
                skip_start,
                block_data_offset,
                ..
            } => {
                let block = si
                    .$load_block_fn(*skip_start, *block_data_offset, $self.block_idx)
                    $($aw)* ?;
                match block {
                    Some(b) => {
                        b.decode_doc_ids_into(&mut $self.doc_ids);
                        // Scores come back already multiplied by the query weight.
                        b.decode_scored_weights_into(*query_weight, &mut $self.scores);
                        if $self.lazy_ordinals {
                            // Keep the block so `ordinal_mut` can decode ordinals only
                            // for postings that are actually scored. `ordinals` is left
                            // untouched and may still hold the previous block's data.
                            $self.current_sparse_block = Some(b);
                            $self.ordinals_loaded = false;
                        } else {
                            b.decode_ordinals_into(&mut $self.ordinals);
                            $self.ordinals_loaded = true;
                            $self.current_sparse_block = None;
                        }
                        $self.pos = 0;
                        $self.block_loaded = true;
                        Ok(true)
                    }
                    None => {
                        $self.exhausted = true;
                        Ok(false)
                    }
                }
            }
        }
    }};
}
326
/// Shared body of `TermCursor::advance` / `advance_sync`.
///
/// Makes sure the current block is decoded, then steps to the next posting. Returns
/// (from the enclosing function) `Ok(u32::MAX)` once the cursor is exhausted.
/// `$($aw)*` is `.await` or nothing.
macro_rules! cursor_advance {
    ($self:ident, $ensure_fn:ident, $($aw:tt)*) => {{
        // An exhausted cursor has nothing to load; skip straight to the exhaustion check.
        if !$self.exhausted {
            $self.$ensure_fn() $($aw)* ?;
        }
        if $self.exhausted {
            return Ok(u32::MAX);
        }
        Ok($self.advance_pos())
    }};
}
339
/// Shared body of `TermCursor::seek` / `seek_sync`.
///
/// `seek_prepare` either answers directly (returned from the enclosing function) or
/// selects the block to search. The block is loaded and searched by `seek_finish`; if
/// that moves to the following block, it is loaded too. Evaluates to the cursor's doc
/// afterwards. `$($aw)*` is `.await` or nothing.
macro_rules! cursor_seek {
    ($self:ident, $ensure_fn:ident, $target:expr, $($aw:tt)*) => {{
        if let Some(early) = $self.seek_prepare($target) {
            return Ok(early);
        }
        $self.$ensure_fn() $($aw)* ?;
        let crossed_into_next_block = $self.seek_finish($target);
        if crossed_into_next_block {
            $self.$ensure_fn() $($aw)* ?;
        }
        Ok($self.doc())
    }};
}
352
353impl<'a> TermCursor<'a> {
354 pub fn text(
356 posting_list: crate::structures::BlockPostingList,
357 idf: f32,
358 avg_field_len: f32,
359 ) -> Self {
360 let max_tf = posting_list.max_tf() as f32;
361 let max_score = super::bm25_upper_bound(max_tf.max(1.0), idf);
362 let num_blocks = posting_list.num_blocks();
363 let safe_avg = avg_field_len.max(1.0);
364 Self {
365 max_score,
366 num_blocks,
367 block_idx: 0,
368 doc_ids: Vec::with_capacity(128),
369 scores: Vec::with_capacity(128),
370 ordinals: Vec::new(),
371 pos: 0,
372 block_loaded: false,
373 exhausted: num_blocks == 0,
374 lazy_ordinals: false,
375 ordinals_loaded: true, current_sparse_block: None,
377 variant: CursorVariant::Text {
378 list: posting_list,
379 idf,
380 b_over_avgfl: super::BM25_B / safe_avg,
381 one_minus_b: 1.0 - super::BM25_B,
382 tfs: Vec::with_capacity(128),
383 },
384 }
385 }
386
387 pub fn sparse(
390 si: &'a crate::segment::SparseIndex,
391 query_weight: f32,
392 skip_start: usize,
393 skip_count: usize,
394 global_max_weight: f32,
395 block_data_offset: u64,
396 ) -> Self {
397 Self {
398 max_score: query_weight.abs() * global_max_weight,
399 num_blocks: skip_count,
400 block_idx: 0,
401 doc_ids: Vec::with_capacity(256),
402 scores: Vec::with_capacity(256),
403 ordinals: Vec::with_capacity(256),
404 pos: 0,
405 block_loaded: false,
406 exhausted: skip_count == 0,
407 lazy_ordinals: false,
408 ordinals_loaded: true,
409 current_sparse_block: None,
410 variant: CursorVariant::Sparse {
411 si,
412 query_weight,
413 skip_start,
414 block_data_offset,
415 },
416 }
417 }
418
419 #[inline]
422 fn block_first_doc(&self, idx: usize) -> DocId {
423 match &self.variant {
424 CursorVariant::Text { list, .. } => list.block_first_doc(idx).unwrap_or(u32::MAX),
425 CursorVariant::Sparse { si, skip_start, .. } => {
426 si.read_skip_entry(*skip_start + idx).first_doc
427 }
428 }
429 }
430
431 #[inline]
432 fn block_last_doc(&self, idx: usize) -> DocId {
433 match &self.variant {
434 CursorVariant::Text { list, .. } => list.block_last_doc(idx).unwrap_or(0),
435 CursorVariant::Sparse { si, skip_start, .. } => {
436 si.read_skip_entry(*skip_start + idx).last_doc
437 }
438 }
439 }
440
441 #[inline]
444 pub fn doc(&self) -> DocId {
445 if self.exhausted {
446 return u32::MAX;
447 }
448 if self.block_loaded {
449 debug_assert!(self.pos < self.doc_ids.len());
450 unsafe { *self.doc_ids.get_unchecked(self.pos) }
452 } else {
453 self.block_first_doc(self.block_idx)
454 }
455 }
456
457 #[inline]
458 pub fn ordinal(&self) -> u16 {
459 if !self.block_loaded || self.ordinals.is_empty() {
460 return 0;
461 }
462 debug_assert!(self.pos < self.ordinals.len());
463 unsafe { *self.ordinals.get_unchecked(self.pos) }
465 }
466
467 #[inline]
473 pub fn ordinal_mut(&mut self) -> u16 {
474 if !self.block_loaded {
475 return 0;
476 }
477 if !self.ordinals_loaded {
478 if let Some(ref block) = self.current_sparse_block {
479 block.decode_ordinals_into(&mut self.ordinals);
480 }
481 self.ordinals_loaded = true;
482 }
483 if self.ordinals.is_empty() {
484 return 0;
485 }
486 debug_assert!(self.pos < self.ordinals.len());
487 unsafe { *self.ordinals.get_unchecked(self.pos) }
488 }
489
490 #[inline]
491 pub fn score(&self) -> f32 {
492 if !self.block_loaded {
493 return 0.0;
494 }
495 debug_assert!(self.pos < self.scores.len());
496 unsafe { *self.scores.get_unchecked(self.pos) }
498 }
499
500 #[inline]
501 pub fn current_block_max_score(&self) -> f32 {
502 if self.exhausted {
503 return 0.0;
504 }
505 match &self.variant {
506 CursorVariant::Text { list, idf, .. } => {
507 let block_max_tf = list.block_max_tf(self.block_idx).unwrap_or(0) as f32;
508 super::bm25_upper_bound(block_max_tf.max(1.0), *idf)
509 }
510 CursorVariant::Sparse {
511 si,
512 query_weight,
513 skip_start,
514 ..
515 } => query_weight.abs() * si.read_skip_entry(*skip_start + self.block_idx).max_weight,
516 }
517 }
518
519 pub fn skip_to_next_block(&mut self) -> DocId {
522 if self.exhausted {
523 return u32::MAX;
524 }
525 self.block_idx += 1;
526 self.block_loaded = false;
527 if self.block_idx >= self.num_blocks {
528 self.exhausted = true;
529 return u32::MAX;
530 }
531 self.block_first_doc(self.block_idx)
532 }
533
534 #[inline]
535 fn advance_pos(&mut self) -> DocId {
536 self.pos += 1;
537 if self.pos >= self.doc_ids.len() {
538 self.block_idx += 1;
539 self.block_loaded = false;
540 if self.block_idx >= self.num_blocks {
541 self.exhausted = true;
542 return u32::MAX;
543 }
544 }
545 self.doc()
546 }
547
548 pub async fn ensure_block_loaded(&mut self) -> crate::Result<bool> {
554 cursor_ensure_block!(self, load_block_direct, .await)
555 }
556
557 pub fn ensure_block_loaded_sync(&mut self) -> crate::Result<bool> {
558 cursor_ensure_block!(self, load_block_direct_sync,)
559 }
560
561 pub async fn advance(&mut self) -> crate::Result<DocId> {
562 cursor_advance!(self, ensure_block_loaded, .await)
563 }
564
565 pub fn advance_sync(&mut self) -> crate::Result<DocId> {
566 cursor_advance!(self, ensure_block_loaded_sync,)
567 }
568
569 pub async fn seek(&mut self, target: DocId) -> crate::Result<DocId> {
570 cursor_seek!(self, ensure_block_loaded, target, .await)
571 }
572
573 pub fn seek_sync(&mut self, target: DocId) -> crate::Result<DocId> {
574 cursor_seek!(self, ensure_block_loaded_sync, target,)
575 }
576
577 fn seek_prepare(&mut self, target: DocId) -> Option<DocId> {
578 if self.exhausted {
579 return Some(u32::MAX);
580 }
581
582 if self.block_loaded
584 && let Some(&last) = self.doc_ids.last()
585 {
586 if last >= target && self.doc_ids[self.pos] < target {
587 let remaining = &self.doc_ids[self.pos..];
588 self.pos += crate::structures::simd::find_first_ge_u32(remaining, target);
589 if self.pos >= self.doc_ids.len() {
590 self.block_idx += 1;
591 self.block_loaded = false;
592 if self.block_idx >= self.num_blocks {
593 self.exhausted = true;
594 return Some(u32::MAX);
595 }
596 }
597 return Some(self.doc());
598 }
599 if self.doc_ids[self.pos] >= target {
600 return Some(self.doc());
601 }
602 }
603
604 let lo = match &self.variant {
606 CursorVariant::Text { list, .. } => match list.seek_block(target, self.block_idx) {
608 Some(idx) => idx,
609 None => {
610 self.exhausted = true;
611 return Some(u32::MAX);
612 }
613 },
614 CursorVariant::Sparse { .. } => {
616 let mut lo = self.block_idx;
617 let mut hi = self.num_blocks;
618 while lo < hi {
619 let mid = lo + (hi - lo) / 2;
620 if self.block_last_doc(mid) < target {
621 lo = mid + 1;
622 } else {
623 hi = mid;
624 }
625 }
626 lo
627 }
628 };
629 if lo >= self.num_blocks {
630 self.exhausted = true;
631 return Some(u32::MAX);
632 }
633 if lo != self.block_idx || !self.block_loaded {
634 self.block_idx = lo;
635 self.block_loaded = false;
636 }
637 None
638 }
639
640 #[inline]
641 fn seek_finish(&mut self, target: DocId) -> bool {
642 if self.exhausted {
643 return false;
644 }
645 self.pos = crate::structures::simd::find_first_ge_u32(&self.doc_ids, target);
646 if self.pos >= self.doc_ids.len() {
647 self.block_idx += 1;
648 self.block_loaded = false;
649 if self.block_idx >= self.num_blocks {
650 self.exhausted = true;
651 return false;
652 }
653 return true;
654 }
655 false
656 }
657}
658
/// Shared body of `MaxScoreExecutor::execute` and `execute_sync`.
///
/// `$ensure`, `$advance` and `$seek` name the cursor methods to call, and `$($aw)*` is
/// `.await` or nothing, so one implementation serves both entry points. It uses `?` and
/// evaluates to `crate::Result<Vec<ScoredDoc>>`.
///
/// Cursors are sorted by ascending `max_score`. The prefix whose summed bounds cannot beat
/// the current threshold is "non-essential": those cursors are only probed (by seeking)
/// for documents surfaced by the remaining "essential" cursors.
macro_rules! bms_execute_loop {
    ($self:ident, $ensure:ident, $advance:ident, $seek:ident, $($aw:tt)*) => {{
        let n = $self.cursors.len();

        for cursor in &mut $self.cursors {
            cursor.$ensure() $($aw)* ?;
        }

        let mut docs_scored = 0u64;
        let mut docs_skipped = 0u64;
        let mut blocks_skipped = 0u64;
        let mut conjunction_skipped = 0u64;
        let mut ordinal_scores: Vec<(u16, f32)> = Vec::with_capacity(n * 2);
        // Indices (ascending) of the essential cursors positioned on `min_doc`. An index
        // list instead of a `u64` bitmask stays correct for queries with more than 64 terms.
        let mut at_min: Vec<usize> = Vec::with_capacity(n);

        loop {
            let partition = $self.find_partition();
            if partition >= n {
                break;
            }

            // Smallest doc among the essential cursors, the cursors sitting on it, and the
            // smallest essential doc strictly after it.
            let mut min_doc = u32::MAX;
            let mut next_doc = u32::MAX;
            at_min.clear();
            for i in partition..n {
                let doc = $self.cursors[i].doc();
                match doc.cmp(&min_doc) {
                    std::cmp::Ordering::Less => {
                        next_doc = min_doc;
                        min_doc = doc;
                        at_min.clear();
                        at_min.push(i);
                    }
                    std::cmp::Ordering::Equal => at_min.push(i),
                    std::cmp::Ordering::Greater => next_doc = next_doc.min(doc),
                }
            }
            if min_doc == u32::MAX {
                break;
            }

            let non_essential_upper = if partition > 0 {
                $self.prefix_sums[partition - 1]
            } else {
                0.0
            };
            let adjusted_threshold = $self.collector.threshold() * $self.heap_factor - 1e-6;
            // Nothing is inserted before the end of this iteration, so this stays valid.
            let heap_full = $self.collector.len() >= $self.collector.k;

            if heap_full {
                // Global bound: even the full per-term bounds cannot reach the threshold.
                let mut present_upper: f32 = 0.0;
                for &i in &at_min {
                    present_upper += $self.cursors[i].max_score;
                }
                if present_upper + non_essential_upper <= adjusted_threshold {
                    for &i in &at_min {
                        $self.cursors[i].$ensure() $($aw)* ?;
                        $self.cursors[i].$advance() $($aw)* ?;
                    }
                    conjunction_skipped += 1;
                    continue;
                }

                // Block bound: the current blocks' maxima cannot reach the threshold.
                let mut block_max_sum: f32 = 0.0;
                for &i in &at_min {
                    block_max_sum += $self.cursors[i].current_block_max_score();
                }
                if block_max_sum + non_essential_upper <= adjusted_threshold {
                    // A doc in [min_doc, skip_to) can only be matched by the at-min cursors
                    // (each still inside its current block) and the non-essential lists, so
                    // it is bounded by the sum above. Other essential cursors sit at
                    // `next_doc` or later; skipping beyond that would drop these cursors'
                    // postings for documents still to be surfaced by the others.
                    let mut skip_to = next_doc;
                    for &i in &at_min {
                        let cursor = &$self.cursors[i];
                        let block_last = cursor.block_last_doc(cursor.block_idx);
                        skip_to = skip_to.min(block_last.saturating_add(1));
                    }
                    // `min_doc` itself is ruled out, so always make progress past it.
                    let skip_to = skip_to.max(min_doc + 1);
                    for &i in &at_min {
                        let cursor = &mut $self.cursors[i];
                        if cursor.block_last_doc(cursor.block_idx) < skip_to {
                            // The rest of this block lies before `skip_to`: jump it using
                            // skip data only.
                            cursor.skip_to_next_block();
                            cursor.$ensure() $($aw)* ?;
                        } else {
                            cursor.$seek(skip_to) $($aw)* ?;
                        }
                    }
                    blocks_skipped += 1;
                    continue;
                }
            }

            if let Some(ref pred) = $self.predicate {
                if !pred(min_doc) {
                    for &i in &at_min {
                        $self.cursors[i].$ensure() $($aw)* ?;
                        $self.cursors[i].$advance() $($aw)* ?;
                    }
                    continue;
                }
            }

            // Collect every (ordinal, score) the essential cursors hold for `min_doc`. The
            // block is re-ensured before each read: if a document's postings straddle a
            // block boundary, `advance` leaves the cursor on an undecoded block whose first
            // doc is still `min_doc`, and `score()` would report 0.0 there.
            ordinal_scores.clear();
            for &i in &at_min {
                loop {
                    $self.cursors[i].$ensure() $($aw)* ?;
                    if $self.cursors[i].doc() != min_doc {
                        break;
                    }
                    let ord = $self.cursors[i].ordinal_mut();
                    let sc = $self.cursors[i].score();
                    ordinal_scores.push((ord, sc));
                    $self.cursors[i].$advance() $($aw)* ?;
                }
            }

            let essential_total: f32 = ordinal_scores.iter().map(|(_, s)| *s).sum();
            if heap_full && essential_total + non_essential_upper <= adjusted_threshold {
                docs_skipped += 1;
                continue;
            }

            // Probe non-essential cursors from the largest bound down, stopping as soon as
            // the remaining bounds cannot lift the document over the threshold.
            let mut running_total = essential_total;
            for i in (0..partition).rev() {
                if heap_full && running_total + $self.prefix_sums[i] <= adjusted_threshold {
                    break;
                }
                $self.cursors[i].$seek(min_doc) $($aw)* ?;
                loop {
                    $self.cursors[i].$ensure() $($aw)* ?;
                    if $self.cursors[i].doc() != min_doc {
                        break;
                    }
                    let s = $self.cursors[i].score();
                    running_total += s;
                    let ord = $self.cursors[i].ordinal_mut();
                    ordinal_scores.push((ord, s));
                    $self.cursors[i].$advance() $($aw)* ?;
                }
            }

            if ordinal_scores.len() == 1 {
                let (ord, score) = ordinal_scores[0];
                if $self.collector.insert_with_ordinal(min_doc, score, ord) {
                    docs_scored += 1;
                } else {
                    docs_skipped += 1;
                }
            } else if !ordinal_scores.is_empty() {
                // Group by ordinal and insert one entry per ordinal.
                if ordinal_scores.len() > 2 {
                    ordinal_scores.sort_unstable_by_key(|(ord, _)| *ord);
                } else if ordinal_scores.len() == 2 && ordinal_scores[0].0 > ordinal_scores[1].0 {
                    ordinal_scores.swap(0, 1);
                }
                let mut j = 0;
                while j < ordinal_scores.len() {
                    let current_ord = ordinal_scores[j].0;
                    let mut score = 0.0f32;
                    while j < ordinal_scores.len() && ordinal_scores[j].0 == current_ord {
                        score += ordinal_scores[j].1;
                        j += 1;
                    }
                    if $self
                        .collector
                        .insert_with_ordinal(min_doc, score, current_ord)
                    {
                        docs_scored += 1;
                    } else {
                        docs_skipped += 1;
                    }
                }
            }
        }

        let results: Vec<ScoredDoc> = $self
            .collector
            .into_sorted_results()
            .into_iter()
            .map(|(doc_id, score, ordinal)| ScoredDoc {
                doc_id,
                score,
                ordinal,
            })
            .collect();

        debug!(
            "MaxScoreExecutor: scored={}, skipped={}, blocks_skipped={}, conjunction_skipped={}, returned={}, top_score={:.4}",
            docs_scored,
            docs_skipped,
            blocks_skipped,
            conjunction_skipped,
            results.len(),
            results.first().map(|r| r.score).unwrap_or(0.0)
        );

        Ok(results)
    }};
}
882
883impl<'a> MaxScoreExecutor<'a> {
884 pub(crate) fn new(mut cursors: Vec<TermCursor<'a>>, k: usize, heap_factor: f32) -> Self {
889 for c in &mut cursors {
892 c.lazy_ordinals = true;
893 }
894
895 cursors.sort_by(|a, b| {
897 a.max_score
898 .partial_cmp(&b.max_score)
899 .unwrap_or(Ordering::Equal)
900 });
901
902 let mut prefix_sums = Vec::with_capacity(cursors.len());
903 let mut cumsum = 0.0f32;
904 for c in &cursors {
905 cumsum += c.max_score;
906 prefix_sums.push(cumsum);
907 }
908
909 debug!(
910 "Creating MaxScoreExecutor: num_cursors={}, k={}, total_upper={:.4}, heap_factor={:.2}",
911 cursors.len(),
912 k,
913 cumsum,
914 heap_factor
915 );
916
917 Self {
918 cursors,
919 prefix_sums,
920 collector: ScoreCollector::new(k),
921 heap_factor: heap_factor.clamp(0.0, 1.0),
922 predicate: None,
923 }
924 }
925
926 pub fn sparse(
930 sparse_index: &'a crate::segment::SparseIndex,
931 query_terms: Vec<(u32, f32)>,
932 k: usize,
933 heap_factor: f32,
934 ) -> Self {
935 let cursors: Vec<TermCursor<'a>> = query_terms
936 .iter()
937 .filter_map(|&(dim_id, qw)| {
938 let (skip_start, skip_count, global_max, block_data_offset) =
939 sparse_index.get_skip_range_full(dim_id)?;
940 Some(TermCursor::sparse(
941 sparse_index,
942 qw,
943 skip_start,
944 skip_count,
945 global_max,
946 block_data_offset,
947 ))
948 })
949 .collect();
950 Self::new(cursors, k, heap_factor)
951 }
952
953 pub fn text(
957 posting_lists: Vec<(crate::structures::BlockPostingList, f32)>,
958 avg_field_len: f32,
959 k: usize,
960 ) -> Self {
961 let cursors: Vec<TermCursor<'a>> = posting_lists
962 .into_iter()
963 .map(|(pl, idf)| TermCursor::text(pl, idf, avg_field_len))
964 .collect();
965 Self::new(cursors, k, 1.0)
966 }
967
968 #[inline]
969 fn find_partition(&self) -> usize {
970 let threshold = self.collector.threshold() * self.heap_factor;
971 self.prefix_sums.partition_point(|&sum| sum <= threshold)
972 }
973
974 pub fn with_predicate(mut self, predicate: super::DocPredicate<'a>) -> Self {
980 self.predicate = Some(predicate);
981 self
982 }
983
984 pub async fn execute(mut self) -> crate::Result<Vec<ScoredDoc>> {
986 if self.cursors.is_empty() {
987 return Ok(Vec::new());
988 }
989 bms_execute_loop!(self, ensure_block_loaded, advance, seek, .await)
990 }
991
992 pub fn execute_sync(mut self) -> crate::Result<Vec<ScoredDoc>> {
994 if self.cursors.is_empty() {
995 return Ok(Vec::new());
996 }
997 bms_execute_loop!(self, ensure_block_loaded_sync, advance_sync, seek_sync,)
998 }
999}
1000
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_score_collector_basic() {
        let mut collector = ScoreCollector::new(3);
        for (doc_id, score) in [(1, 1.0), (2, 2.0), (3, 3.0)] {
            collector.insert(doc_id, score);
        }
        assert_eq!(collector.threshold(), 1.0);

        // A stronger document evicts the weakest one and raises the threshold.
        collector.insert(4, 4.0);
        assert_eq!(collector.threshold(), 2.0);

        let doc_ids: Vec<_> = collector
            .into_sorted_results()
            .into_iter()
            .map(|(doc_id, _, _)| doc_id)
            .collect();
        assert_eq!(doc_ids, vec![4, 3, 2]);
    }

    #[test]
    fn test_score_collector_threshold() {
        let mut collector = ScoreCollector::new(2);
        collector.insert(1, 5.0);
        collector.insert(2, 3.0);
        assert_eq!(collector.threshold(), 3.0);

        // Below the threshold: rejected.
        assert!(!collector.would_enter(2.0));
        assert!(!collector.insert(3, 2.0));

        // Above the threshold: accepted, and the threshold moves up.
        assert!(collector.would_enter(4.0));
        assert!(collector.insert(4, 4.0));
        assert_eq!(collector.threshold(), 4.0);
    }

    #[test]
    fn test_heap_entry_ordering() {
        let mut heap: BinaryHeap<HeapEntry> = [(1, 3.0), (2, 1.0), (3, 2.0)]
            .into_iter()
            .map(|(doc_id, score)| HeapEntry {
                doc_id,
                score,
                ordinal: 0,
            })
            .collect();

        // Inverted ordering: the lowest score pops first.
        for expected in [1.0, 2.0, 3.0] {
            assert_eq!(heap.pop().unwrap().score, expected);
        }
    }
}