1use std::cmp::Ordering;
10use std::collections::BinaryHeap;
11
12use log::debug;
13
14use crate::DocId;
15
16#[derive(Clone, Copy)]
18pub struct HeapEntry {
19 pub doc_id: DocId,
20 pub score: f32,
21 pub ordinal: u16,
22}
23
24impl PartialEq for HeapEntry {
25 fn eq(&self, other: &Self) -> bool {
26 self.score == other.score && self.doc_id == other.doc_id
27 }
28}
29
// Marker impl: `Ord` has `Eq` as a supertrait, and `BinaryHeap<HeapEntry>` requires `Ord`.
impl Eq for HeapEntry {}
31
32impl Ord for HeapEntry {
33 fn cmp(&self, other: &Self) -> Ordering {
34 other
37 .score
38 .total_cmp(&self.score)
39 .then(self.doc_id.cmp(&other.doc_id))
40 }
41}
42
43impl PartialOrd for HeapEntry {
44 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
45 Some(self.cmp(other))
46 }
47}
48
/// Bounded top-`k` collector backed by a binary heap of [`HeapEntry`].
pub struct ScoreCollector {
    /// Retained entries; with `HeapEntry`'s reversed score ordering the heap
    /// top is the lowest-scoring entry.
    heap: BinaryHeap<HeapEntry>,
    /// Maximum number of entries to retain.
    pub k: usize,
    /// Score of the heap top once at least `k` entries are held, `0.0`
    /// otherwise; refreshed by `update_threshold`.
    cached_threshold: f32,
}
69
70impl ScoreCollector {
71 pub fn new(k: usize) -> Self {
73 let capacity = k.saturating_add(1).min(1_000_000);
75 Self {
76 heap: BinaryHeap::with_capacity(capacity),
77 k,
78 cached_threshold: 0.0,
79 }
80 }
81
82 #[inline]
84 pub fn threshold(&self) -> f32 {
85 self.cached_threshold
86 }
87
88 #[inline]
90 fn update_threshold(&mut self) {
91 self.cached_threshold = if self.heap.len() >= self.k {
92 self.heap.peek().map(|e| e.score).unwrap_or(0.0)
93 } else {
94 0.0
95 };
96 }
97
98 #[inline]
101 pub fn insert(&mut self, doc_id: DocId, score: f32) -> bool {
102 self.insert_with_ordinal(doc_id, score, 0)
103 }
104
105 #[inline]
108 pub fn insert_with_ordinal(&mut self, doc_id: DocId, score: f32, ordinal: u16) -> bool {
109 if self.heap.len() < self.k {
110 self.heap.push(HeapEntry {
111 doc_id,
112 score,
113 ordinal,
114 });
115 if self.heap.len() == self.k {
117 self.update_threshold();
118 }
119 true
120 } else if score > self.cached_threshold {
121 self.heap.push(HeapEntry {
122 doc_id,
123 score,
124 ordinal,
125 });
126 self.heap.pop(); self.update_threshold();
128 true
129 } else {
130 false
131 }
132 }
133
134 #[inline]
136 pub fn would_enter(&self, score: f32) -> bool {
137 self.heap.len() < self.k || score > self.cached_threshold
138 }
139
140 #[inline]
142 pub fn len(&self) -> usize {
143 self.heap.len()
144 }
145
146 #[inline]
148 pub fn is_empty(&self) -> bool {
149 self.heap.is_empty()
150 }
151
152 pub fn into_sorted_results(self) -> Vec<(DocId, f32, u16)> {
154 let heap_vec = self.heap.into_vec();
155 let mut results: Vec<(DocId, f32, u16)> = Vec::with_capacity(heap_vec.len());
156 for e in heap_vec {
157 results.push((e.doc_id, e.score, e.ordinal));
158 }
159
160 results.sort_by(|a, b| b.1.total_cmp(&a.1).then(a.0.cmp(&b.0)));
162
163 results
164 }
165}
166
/// One result row returned by [`MaxScoreExecutor`]'s `execute` / `execute_sync`.
#[derive(Debug, Clone, Copy)]
pub struct ScoredDoc {
    /// Identifier of the matching document.
    pub doc_id: DocId,
    /// Final score assigned to this document/ordinal pair.
    pub score: f32,
    /// Ordinal the score was accumulated for.
    pub ordinal: u16,
}
175
/// Top-k query executor over a set of [`TermCursor`]s.
pub struct MaxScoreExecutor<'a> {
    /// Term cursors, sorted by ascending `max_score` in `new`.
    cursors: Vec<TermCursor<'a>>,
    /// `prefix_sums[i]` is the sum of `max_score` over `cursors[0..=i]`.
    prefix_sums: Vec<f32>,
    /// Bounded collector holding the current top-k candidates.
    collector: ScoreCollector,
    /// Multiplier applied to the collector threshold when pruning; clamped to
    /// `[0.0, 1.0]` in `new`.
    heap_factor: f32,
    /// Optional document filter; documents for which it returns `false` are
    /// skipped without being scored.
    predicate: Option<super::DocPredicate<'a>>,
}
193
/// Block-at-a-time cursor over one term's postings (text or sparse).
pub(crate) struct TermCursor<'a> {
    /// Upper bound on any score this cursor can produce (set by the constructors).
    pub max_score: f32,
    /// Total number of blocks in the posting list.
    num_blocks: usize,
    /// Index of the block the cursor is currently positioned on.
    block_idx: usize,
    /// Decoded document ids of the currently loaded block.
    doc_ids: Vec<u32>,
    /// Scores of the currently loaded block, filled by the same load as `doc_ids`.
    scores: Vec<f32>,
    /// Decoded ordinals of the current block (sparse cursors only; stays empty for text).
    ordinals: Vec<u16>,
    /// Position inside the loaded block.
    pos: usize,
    /// `true` once the block at `block_idx` has been decoded into the buffers.
    block_loaded: bool,
    /// `true` once the cursor has run past its last block.
    exhausted: bool,
    /// When `true`, a sparse block load defers ordinal decoding until `ordinal_mut`.
    lazy_ordinals: bool,
    /// Whether `ordinals` reflects the currently loaded block.
    ordinals_loaded: bool,
    /// Sparse block kept alive so its ordinals can be decoded lazily.
    current_sparse_block: Option<crate::structures::SparseBlock>,
    /// Backend-specific state (text posting list or sparse index slice).
    variant: CursorVariant<'a>,
}
224
/// Backend-specific data of a [`TermCursor`].
enum CursorVariant<'a> {
    /// Text postings scored with a BM25-style formula when a block is loaded.
    Text {
        /// Owned block posting list for the term.
        list: crate::structures::BlockPostingList,
        /// Inverse document frequency factor for the term.
        idf: f32,
        /// Precomputed `BM25_B / max(avg_field_len, 1.0)`.
        b_over_avgfl: f32,
        /// Precomputed `1.0 - BM25_B`.
        one_minus_b: f32,
        /// Scratch buffer for the decoded term frequencies of the current block.
        tfs: Vec<u32>,
    },
    /// Sparse-vector postings read from a shared sparse index.
    Sparse {
        /// Sparse index the blocks are loaded from.
        si: &'a crate::segment::SparseIndex,
        /// Query-side weight passed to `decode_scored_weights_into` on block load.
        query_weight: f32,
        /// Index of this term's first skip entry.
        skip_start: usize,
        /// Offset passed to the block loader together with `skip_start`.
        block_data_offset: u64,
    },
}
244
// Shared body of `ensure_block_loaded` / `ensure_block_loaded_sync`.
//
// Arguments: the cursor (`self`), the name of the sparse block loader method,
// and the tokens to splice after the loader call (`.await` or nothing).
//
// Evaluates to `Ok(true)` when a block is loaded, or `Ok(false)` (marking the
// cursor exhausted) when no block is available. Returns early from the
// enclosing function if the cursor is already exhausted or loaded.
macro_rules! cursor_ensure_block {
    ($self:ident, $load_block_fn:ident, $($aw:tt)*) => {{
        // Nothing to do: report whether the cursor still has data.
        if $self.exhausted || $self.block_loaded {
            return Ok(!$self.exhausted);
        }
        match &mut $self.variant {
            CursorVariant::Text {
                list,
                idf,
                b_over_avgfl,
                one_minus_b,
                tfs,
            } => {
                if list.decode_block_into($self.block_idx, &mut $self.doc_ids, tfs) {
                    let idf_val = *idf;
                    let b_avg = *b_over_avgfl;
                    let one_b = *one_minus_b;
                    // Precompute one score per decoded term frequency.
                    $self.scores.clear();
                    $self.scores.reserve(tfs.len());
                    for &tf in tfs.iter() {
                        let tf = tf as f32;
                        // NOTE(review): the length normalisation is fed `tf`, not a
                        // per-document field length; presumably a deliberate
                        // approximation, confirm.
                        let length_norm = one_b + b_avg * tf;
                        let tf_norm = (tf * (super::BM25_K1 + 1.0))
                            / (tf + super::BM25_K1 * length_norm);
                        $self.scores.push(idf_val * tf_norm);
                    }
                    $self.pos = 0;
                    $self.block_loaded = true;
                    Ok(true)
                } else {
                    // The decoder reported no block at `block_idx`.
                    $self.exhausted = true;
                    Ok(false)
                }
            }
            CursorVariant::Sparse {
                si,
                query_weight,
                skip_start,
                block_data_offset,
                ..
            } => {
                let block = si
                    .$load_block_fn(*skip_start, *block_data_offset, $self.block_idx)
                    $($aw)* ?;
                match block {
                    Some(b) => {
                        b.decode_doc_ids_into(&mut $self.doc_ids);
                        b.decode_scored_weights_into(*query_weight, &mut $self.scores);
                        if $self.lazy_ordinals {
                            // Keep the block so `ordinal_mut` can decode ordinals on
                            // demand. `$self.ordinals` is not touched here, so it
                            // still holds whatever the previous block left behind.
                            $self.current_sparse_block = Some(b);
                            $self.ordinals_loaded = false;
                        } else {
                            b.decode_ordinals_into(&mut $self.ordinals);
                            $self.ordinals_loaded = true;
                            $self.current_sparse_block = None;
                        }
                        $self.pos = 0;
                        $self.block_loaded = true;
                        Ok(true)
                    }
                    None => {
                        $self.exhausted = true;
                        Ok(false)
                    }
                }
            }
        }
    }};
}
325
// Shared body of `advance` / `advance_sync`.
//
// Arguments: the cursor, the block-loading method to call, and the tokens to
// splice after that call (`.await` or nothing). Yields `Ok(next_doc)`, or
// returns `Ok(u32::MAX)` from the enclosing function once the cursor is
// exhausted.
macro_rules! cursor_advance {
    ($cursor:ident, $load:ident, $($suffix:tt)*) => {{
        // Only touch storage while the cursor still has blocks left.
        if !$cursor.exhausted {
            $cursor.$load() $($suffix)* ?;
        }
        // Covers both "already exhausted" and "loading found no block".
        if $cursor.exhausted {
            return Ok(u32::MAX);
        }
        Ok($cursor.advance_pos())
    }};
}
338
// Shared body of `seek` / `seek_sync`.
//
// Arguments: the cursor, the block-loading method, the target doc id, and the
// tokens to splice after each load call (`.await` or nothing). Yields
// `Ok(doc)` for the cursor position after seeking.
macro_rules! cursor_seek {
    ($cursor:ident, $load:ident, $target:expr, $($suffix:tt)*) => {{
        // Fast path: `seek_prepare` resolved the seek without loading a block.
        if let Some(resolved) = $cursor.seek_prepare($target) {
            return Ok(resolved);
        }
        $cursor.$load() $($suffix)* ?;
        // `seek_finish` reports `true` when it moved on to the following block,
        // which then has to be loaded too.
        let moved_to_next_block = $cursor.seek_finish($target);
        if moved_to_next_block {
            $cursor.$load() $($suffix)* ?;
        }
        Ok($cursor.doc())
    }};
}
351
352impl<'a> TermCursor<'a> {
353 pub fn text(
355 posting_list: crate::structures::BlockPostingList,
356 idf: f32,
357 avg_field_len: f32,
358 ) -> Self {
359 let max_tf = posting_list.max_tf() as f32;
360 let max_score = super::bm25_upper_bound(max_tf.max(1.0), idf);
361 let num_blocks = posting_list.num_blocks();
362 let safe_avg = avg_field_len.max(1.0);
363 Self {
364 max_score,
365 num_blocks,
366 block_idx: 0,
367 doc_ids: Vec::with_capacity(128),
368 scores: Vec::with_capacity(128),
369 ordinals: Vec::new(),
370 pos: 0,
371 block_loaded: false,
372 exhausted: num_blocks == 0,
373 lazy_ordinals: false,
374 ordinals_loaded: true, current_sparse_block: None,
376 variant: CursorVariant::Text {
377 list: posting_list,
378 idf,
379 b_over_avgfl: super::BM25_B / safe_avg,
380 one_minus_b: 1.0 - super::BM25_B,
381 tfs: Vec::with_capacity(128),
382 },
383 }
384 }
385
386 pub fn sparse(
389 si: &'a crate::segment::SparseIndex,
390 query_weight: f32,
391 skip_start: usize,
392 skip_count: usize,
393 global_max_weight: f32,
394 block_data_offset: u64,
395 ) -> Self {
396 Self {
397 max_score: query_weight.abs() * global_max_weight,
398 num_blocks: skip_count,
399 block_idx: 0,
400 doc_ids: Vec::with_capacity(256),
401 scores: Vec::with_capacity(256),
402 ordinals: Vec::with_capacity(256),
403 pos: 0,
404 block_loaded: false,
405 exhausted: skip_count == 0,
406 lazy_ordinals: false,
407 ordinals_loaded: true,
408 current_sparse_block: None,
409 variant: CursorVariant::Sparse {
410 si,
411 query_weight,
412 skip_start,
413 block_data_offset,
414 },
415 }
416 }
417
418 #[inline]
421 fn block_first_doc(&self, idx: usize) -> DocId {
422 match &self.variant {
423 CursorVariant::Text { list, .. } => list.block_first_doc(idx).unwrap_or(u32::MAX),
424 CursorVariant::Sparse { si, skip_start, .. } => {
425 si.read_skip_entry(*skip_start + idx).first_doc
426 }
427 }
428 }
429
430 #[inline]
431 fn block_last_doc(&self, idx: usize) -> DocId {
432 match &self.variant {
433 CursorVariant::Text { list, .. } => list.block_last_doc(idx).unwrap_or(0),
434 CursorVariant::Sparse { si, skip_start, .. } => {
435 si.read_skip_entry(*skip_start + idx).last_doc
436 }
437 }
438 }
439
440 #[inline]
443 pub fn doc(&self) -> DocId {
444 if self.exhausted {
445 return u32::MAX;
446 }
447 if self.block_loaded {
448 debug_assert!(self.pos < self.doc_ids.len());
449 unsafe { *self.doc_ids.get_unchecked(self.pos) }
451 } else {
452 self.block_first_doc(self.block_idx)
453 }
454 }
455
456 #[inline]
457 pub fn ordinal(&self) -> u16 {
458 if !self.block_loaded || self.ordinals.is_empty() {
459 return 0;
460 }
461 debug_assert!(self.pos < self.ordinals.len());
462 unsafe { *self.ordinals.get_unchecked(self.pos) }
464 }
465
466 #[inline]
472 pub fn ordinal_mut(&mut self) -> u16 {
473 if !self.block_loaded {
474 return 0;
475 }
476 if !self.ordinals_loaded {
477 if let Some(ref block) = self.current_sparse_block {
478 block.decode_ordinals_into(&mut self.ordinals);
479 }
480 self.ordinals_loaded = true;
481 }
482 if self.ordinals.is_empty() {
483 return 0;
484 }
485 debug_assert!(self.pos < self.ordinals.len());
486 unsafe { *self.ordinals.get_unchecked(self.pos) }
487 }
488
489 #[inline]
490 pub fn score(&self) -> f32 {
491 if !self.block_loaded {
492 return 0.0;
493 }
494 debug_assert!(self.pos < self.scores.len());
495 unsafe { *self.scores.get_unchecked(self.pos) }
497 }
498
499 #[inline]
500 pub fn current_block_max_score(&self) -> f32 {
501 if self.exhausted {
502 return 0.0;
503 }
504 match &self.variant {
505 CursorVariant::Text { list, idf, .. } => {
506 let block_max_tf = list.block_max_tf(self.block_idx).unwrap_or(0) as f32;
507 super::bm25_upper_bound(block_max_tf.max(1.0), *idf)
508 }
509 CursorVariant::Sparse {
510 si,
511 query_weight,
512 skip_start,
513 ..
514 } => query_weight.abs() * si.read_skip_entry(*skip_start + self.block_idx).max_weight,
515 }
516 }
517
518 pub fn skip_to_next_block(&mut self) -> DocId {
521 if self.exhausted {
522 return u32::MAX;
523 }
524 self.block_idx += 1;
525 self.block_loaded = false;
526 if self.block_idx >= self.num_blocks {
527 self.exhausted = true;
528 return u32::MAX;
529 }
530 self.block_first_doc(self.block_idx)
531 }
532
533 #[inline]
534 fn advance_pos(&mut self) -> DocId {
535 self.pos += 1;
536 if self.pos >= self.doc_ids.len() {
537 self.block_idx += 1;
538 self.block_loaded = false;
539 if self.block_idx >= self.num_blocks {
540 self.exhausted = true;
541 return u32::MAX;
542 }
543 }
544 self.doc()
545 }
546
547 pub async fn ensure_block_loaded(&mut self) -> crate::Result<bool> {
553 cursor_ensure_block!(self, load_block_direct, .await)
554 }
555
556 pub fn ensure_block_loaded_sync(&mut self) -> crate::Result<bool> {
557 cursor_ensure_block!(self, load_block_direct_sync,)
558 }
559
560 pub async fn advance(&mut self) -> crate::Result<DocId> {
561 cursor_advance!(self, ensure_block_loaded, .await)
562 }
563
564 pub fn advance_sync(&mut self) -> crate::Result<DocId> {
565 cursor_advance!(self, ensure_block_loaded_sync,)
566 }
567
568 pub async fn seek(&mut self, target: DocId) -> crate::Result<DocId> {
569 cursor_seek!(self, ensure_block_loaded, target, .await)
570 }
571
572 pub fn seek_sync(&mut self, target: DocId) -> crate::Result<DocId> {
573 cursor_seek!(self, ensure_block_loaded_sync, target,)
574 }
575
576 fn seek_prepare(&mut self, target: DocId) -> Option<DocId> {
577 if self.exhausted {
578 return Some(u32::MAX);
579 }
580
581 if self.block_loaded
583 && let Some(&last) = self.doc_ids.last()
584 {
585 if last >= target && self.doc_ids[self.pos] < target {
586 let remaining = &self.doc_ids[self.pos..];
587 self.pos += crate::structures::simd::find_first_ge_u32(remaining, target);
588 if self.pos >= self.doc_ids.len() {
589 self.block_idx += 1;
590 self.block_loaded = false;
591 if self.block_idx >= self.num_blocks {
592 self.exhausted = true;
593 return Some(u32::MAX);
594 }
595 }
596 return Some(self.doc());
597 }
598 if self.doc_ids[self.pos] >= target {
599 return Some(self.doc());
600 }
601 }
602
603 let lo = match &self.variant {
605 CursorVariant::Text { list, .. } => match list.seek_block(target, self.block_idx) {
607 Some(idx) => idx,
608 None => {
609 self.exhausted = true;
610 return Some(u32::MAX);
611 }
612 },
613 CursorVariant::Sparse { .. } => {
615 let mut lo = self.block_idx;
616 let mut hi = self.num_blocks;
617 while lo < hi {
618 let mid = lo + (hi - lo) / 2;
619 if self.block_last_doc(mid) < target {
620 lo = mid + 1;
621 } else {
622 hi = mid;
623 }
624 }
625 lo
626 }
627 };
628 if lo >= self.num_blocks {
629 self.exhausted = true;
630 return Some(u32::MAX);
631 }
632 if lo != self.block_idx || !self.block_loaded {
633 self.block_idx = lo;
634 self.block_loaded = false;
635 }
636 None
637 }
638
639 #[inline]
640 fn seek_finish(&mut self, target: DocId) -> bool {
641 if self.exhausted {
642 return false;
643 }
644 self.pos = crate::structures::simd::find_first_ge_u32(&self.doc_ids, target);
645 if self.pos >= self.doc_ids.len() {
646 self.block_idx += 1;
647 self.block_loaded = false;
648 if self.block_idx >= self.num_blocks {
649 self.exhausted = true;
650 return false;
651 }
652 return true;
653 }
654 false
655 }
656}
657
// Shared body of `MaxScoreExecutor::execute` / `execute_sync`.
//
// Arguments: the executor (`self`), the cursor method names for loading a
// block, advancing and seeking, and the tokens to splice after each of those
// calls (`.await` or nothing). Evaluates to `Ok(Vec<ScoredDoc>)` sorted by
// `ScoreCollector::into_sorted_results`.
//
// Cursors are sorted by ascending `max_score`; `find_partition` splits them
// into a non-essential prefix (`0..partition`) and an essential suffix
// (`partition..n`). Candidates come from the essential cursors only.
//
// NOTE(review): cursors at the minimum doc are tracked in a `u64` bitmask
// indexed by cursor position, so `1u64 << i` overflows for `i >= 64`. Nothing
// in this file limits the cursor count to 64 — confirm it is enforced by callers.
macro_rules! bms_execute_loop {
    ($self:ident, $ensure:ident, $advance:ident, $seek:ident, $($aw:tt)*) => {{
        let n = $self.cursors.len();

        // Load the first block of every cursor up front.
        for cursor in &mut $self.cursors {
            cursor.$ensure() $($aw)* ?;
        }

        // Counters reported in the debug log at the end.
        let mut docs_scored = 0u64;
        let mut docs_skipped = 0u64;
        let mut blocks_skipped = 0u64;
        let mut conjunction_skipped = 0u64;
        // Reused per-document buffer of (ordinal, score) contributions.
        let mut ordinal_scores: Vec<(u16, f32)> = Vec::with_capacity(n * 2);

        loop {
            // All cursors non-essential: no remaining doc can beat the threshold.
            let partition = $self.find_partition();
            if partition >= n {
                break;
            }

            // Smallest current doc among essential cursors, plus the set of
            // cursors positioned on it.
            let mut min_doc = u32::MAX;
            let mut at_min_mask = 0u64;
            for i in partition..n {
                let doc = $self.cursors[i].doc();
                match doc.cmp(&min_doc) {
                    std::cmp::Ordering::Less => {
                        min_doc = doc;
                        at_min_mask = 1u64 << (i as u32);
                    }
                    std::cmp::Ordering::Equal => {
                        at_min_mask |= 1u64 << (i as u32);
                    }
                    _ => {}
                }
            }
            // Every essential cursor is exhausted.
            if min_doc == u32::MAX {
                break;
            }

            // Best possible contribution of all non-essential cursors together.
            let non_essential_upper = if partition > 0 {
                $self.prefix_sums[partition - 1]
            } else {
                0.0
            };
            // Score a candidate must exceed; the small epsilon is subtracted
            // so borderline candidates are not pruned.
            let adjusted_threshold = $self.collector.threshold() * $self.heap_factor - 1e-6;

            // Pruning 1: global upper bounds of the cursors at `min_doc`.
            if $self.collector.len() >= $self.collector.k {
                let mut present_upper: f32 = 0.0;
                let mut mask = at_min_mask;
                while mask != 0 {
                    let i = mask.trailing_zeros() as usize;
                    present_upper += $self.cursors[i].max_score;
                    // Clear the lowest set bit.
                    mask &= mask - 1;
                }

                if present_upper + non_essential_upper <= adjusted_threshold {
                    let mut mask = at_min_mask;
                    while mask != 0 {
                        let i = mask.trailing_zeros() as usize;
                        $self.cursors[i].$ensure() $($aw)* ?;
                        $self.cursors[i].$advance() $($aw)* ?;
                        mask &= mask - 1;
                    }
                    conjunction_skipped += 1;
                    continue;
                }
            }

            // Pruning 2: per-block upper bounds of the cursors at `min_doc`.
            if $self.collector.len() >= $self.collector.k {
                let mut block_max_sum: f32 = 0.0;
                let mut mask = at_min_mask;
                while mask != 0 {
                    let i = mask.trailing_zeros() as usize;
                    block_max_sum += $self.cursors[i].current_block_max_score();
                    mask &= mask - 1;
                }

                if block_max_sum + non_essential_upper <= adjusted_threshold {
                    // NOTE(review): this drops the rest of each at-min cursor's
                    // block based only on the block maxima of the cursors at
                    // `min_doc`. Later docs in those blocks that also occur in
                    // other essential cursors are skipped too — confirm this
                    // approximation is intended.
                    let mut mask = at_min_mask;
                    while mask != 0 {
                        let i = mask.trailing_zeros() as usize;
                        $self.cursors[i].skip_to_next_block();
                        $self.cursors[i].$ensure() $($aw)* ?;
                        mask &= mask - 1;
                    }
                    blocks_skipped += 1;
                    continue;
                }
            }

            // Filtered-out documents are stepped over without scoring.
            if let Some(ref pred) = $self.predicate {
                if !pred(min_doc) {
                    let mut mask = at_min_mask;
                    while mask != 0 {
                        let i = mask.trailing_zeros() as usize;
                        $self.cursors[i].$ensure() $($aw)* ?;
                        $self.cursors[i].$advance() $($aw)* ?;
                        mask &= mask - 1;
                    }
                    continue;
                }
            }

            // Collect the contributions of the essential cursors at `min_doc`.
            ordinal_scores.clear();
            {
                let mut mask = at_min_mask;
                while mask != 0 {
                    let i = mask.trailing_zeros() as usize;
                    $self.cursors[i].$ensure() $($aw)* ?;
                    // A cursor may hold several postings for the same doc.
                    // NOTE(review): if such a run crosses a block boundary, the
                    // next block is not loaded here, so `score()` / `ordinal_mut()`
                    // return 0 for its first posting — confirm blocks never
                    // split a document.
                    while $self.cursors[i].doc() == min_doc {
                        let ord = $self.cursors[i].ordinal_mut();
                        let sc = $self.cursors[i].score();
                        ordinal_scores.push((ord, sc));
                        $self.cursors[i].$advance() $($aw)* ?;
                    }
                    mask &= mask - 1;
                }
            }

            // Pruning 3: actual essential score plus the non-essential bound.
            let essential_total: f32 = ordinal_scores.iter().map(|(_, s)| *s).sum();
            if $self.collector.len() >= $self.collector.k
                && essential_total + non_essential_upper <= adjusted_threshold
            {
                docs_skipped += 1;
                continue;
            }

            // Probe non-essential cursors from the highest bound downwards,
            // stopping once the remaining bounds cannot lift the doc over the
            // threshold.
            let mut running_total = essential_total;
            for i in (0..partition).rev() {
                if $self.collector.len() >= $self.collector.k
                    && running_total + $self.prefix_sums[i] <= adjusted_threshold
                {
                    break;
                }

                let doc = $self.cursors[i].$seek(min_doc) $($aw)* ?;
                if doc == min_doc {
                    while $self.cursors[i].doc() == min_doc {
                        let s = $self.cursors[i].score();
                        running_total += s;
                        let ord = $self.cursors[i].ordinal_mut();
                        ordinal_scores.push((ord, s));
                        $self.cursors[i].$advance() $($aw)* ?;
                    }
                }
            }

            if ordinal_scores.len() == 1 {
                // Single contribution: no grouping needed.
                let (ord, score) = ordinal_scores[0];
                if $self.collector.insert_with_ordinal(min_doc, score, ord) {
                    docs_scored += 1;
                } else {
                    docs_skipped += 1;
                }
            } else if !ordinal_scores.is_empty() {
                // Order by ordinal so equal ordinals are adjacent.
                if ordinal_scores.len() > 2 {
                    ordinal_scores.sort_unstable_by_key(|(ord, _)| *ord);
                } else if ordinal_scores.len() == 2 && ordinal_scores[0].0 > ordinal_scores[1].0 {
                    ordinal_scores.swap(0, 1);
                }
                // Sum each run of equal ordinals and offer it as one entry.
                let mut j = 0;
                while j < ordinal_scores.len() {
                    let current_ord = ordinal_scores[j].0;
                    let mut score = 0.0f32;
                    while j < ordinal_scores.len() && ordinal_scores[j].0 == current_ord {
                        score += ordinal_scores[j].1;
                        j += 1;
                    }
                    if $self
                        .collector
                        .insert_with_ordinal(min_doc, score, current_ord)
                    {
                        docs_scored += 1;
                    } else {
                        docs_skipped += 1;
                    }
                }
            }
        }

        let results: Vec<ScoredDoc> = $self
            .collector
            .into_sorted_results()
            .into_iter()
            .map(|(doc_id, score, ordinal)| ScoredDoc {
                doc_id,
                score,
                ordinal,
            })
            .collect();

        debug!(
            "MaxScoreExecutor: scored={}, skipped={}, blocks_skipped={}, conjunction_skipped={}, returned={}, top_score={:.4}",
            docs_scored,
            docs_skipped,
            blocks_skipped,
            conjunction_skipped,
            results.len(),
            results.first().map(|r| r.score).unwrap_or(0.0)
        );

        Ok(results)
    }};
}
881
882impl<'a> MaxScoreExecutor<'a> {
883 pub(crate) fn new(mut cursors: Vec<TermCursor<'a>>, k: usize, heap_factor: f32) -> Self {
888 for c in &mut cursors {
891 c.lazy_ordinals = true;
892 }
893
894 cursors.sort_by(|a, b| {
896 a.max_score
897 .partial_cmp(&b.max_score)
898 .unwrap_or(Ordering::Equal)
899 });
900
901 let mut prefix_sums = Vec::with_capacity(cursors.len());
902 let mut cumsum = 0.0f32;
903 for c in &cursors {
904 cumsum += c.max_score;
905 prefix_sums.push(cumsum);
906 }
907
908 debug!(
909 "Creating MaxScoreExecutor: num_cursors={}, k={}, total_upper={:.4}, heap_factor={:.2}",
910 cursors.len(),
911 k,
912 cumsum,
913 heap_factor
914 );
915
916 Self {
917 cursors,
918 prefix_sums,
919 collector: ScoreCollector::new(k),
920 heap_factor: heap_factor.clamp(0.0, 1.0),
921 predicate: None,
922 }
923 }
924
925 pub fn sparse(
929 sparse_index: &'a crate::segment::SparseIndex,
930 query_terms: Vec<(u32, f32)>,
931 k: usize,
932 heap_factor: f32,
933 ) -> Self {
934 let cursors: Vec<TermCursor<'a>> = query_terms
935 .iter()
936 .filter_map(|&(dim_id, qw)| {
937 let (skip_start, skip_count, global_max, block_data_offset) =
938 sparse_index.get_skip_range_full(dim_id)?;
939 Some(TermCursor::sparse(
940 sparse_index,
941 qw,
942 skip_start,
943 skip_count,
944 global_max,
945 block_data_offset,
946 ))
947 })
948 .collect();
949 Self::new(cursors, k, heap_factor)
950 }
951
952 pub fn text(
956 posting_lists: Vec<(crate::structures::BlockPostingList, f32)>,
957 avg_field_len: f32,
958 k: usize,
959 ) -> Self {
960 let cursors: Vec<TermCursor<'a>> = posting_lists
961 .into_iter()
962 .map(|(pl, idf)| TermCursor::text(pl, idf, avg_field_len))
963 .collect();
964 Self::new(cursors, k, 1.0)
965 }
966
967 #[inline]
968 fn find_partition(&self) -> usize {
969 let threshold = self.collector.threshold() * self.heap_factor;
970 self.prefix_sums.partition_point(|&sum| sum <= threshold)
971 }
972
973 pub fn with_predicate(mut self, predicate: super::DocPredicate<'a>) -> Self {
979 self.predicate = Some(predicate);
980 self
981 }
982
983 pub async fn execute(mut self) -> crate::Result<Vec<ScoredDoc>> {
985 if self.cursors.is_empty() {
986 return Ok(Vec::new());
987 }
988 bms_execute_loop!(self, ensure_block_loaded, advance, seek, .await)
989 }
990
991 pub fn execute_sync(mut self) -> crate::Result<Vec<ScoredDoc>> {
993 if self.cursors.is_empty() {
994 return Ok(Vec::new());
995 }
996 bms_execute_loop!(self, ensure_block_loaded_sync, advance_sync, seek_sync,)
997 }
998}
999
#[cfg(test)]
mod tests {
    use super::*;

    /// Filling a k=3 collector and then beating its lowest entry.
    #[test]
    fn test_score_collector_basic() {
        let mut collector = ScoreCollector::new(3);

        for (doc_id, score) in [(1, 1.0), (2, 2.0), (3, 3.0)] {
            collector.insert(doc_id, score);
        }
        assert_eq!(collector.threshold(), 1.0);

        // A better document evicts the lowest one and raises the threshold.
        collector.insert(4, 4.0);
        assert_eq!(collector.threshold(), 2.0);

        let ranked: Vec<_> = collector
            .into_sorted_results()
            .into_iter()
            .map(|(doc_id, _, _)| doc_id)
            .collect();
        assert_eq!(ranked, [4, 3, 2]);
    }

    /// `would_enter` and `insert` must agree around the threshold.
    #[test]
    fn test_score_collector_threshold() {
        let mut collector = ScoreCollector::new(2);

        collector.insert(1, 5.0);
        collector.insert(2, 3.0);
        assert_eq!(collector.threshold(), 3.0);

        // Below the threshold: rejected.
        let too_low = 2.0;
        assert!(!collector.would_enter(too_low));
        assert!(!collector.insert(3, too_low));

        // Above the threshold: accepted, and the threshold moves up.
        let high_enough = 4.0;
        assert!(collector.would_enter(high_enough));
        assert!(collector.insert(4, high_enough));
        assert_eq!(collector.threshold(), 4.0);
    }

    /// The heap pops entries in ascending score order.
    #[test]
    fn test_heap_entry_ordering() {
        let mut heap = BinaryHeap::new();
        for (doc_id, score) in [(1, 3.0), (2, 1.0), (3, 2.0)] {
            heap.push(HeapEntry {
                doc_id,
                score,
                ordinal: 0,
            });
        }

        let popped: Vec<f32> = std::iter::from_fn(|| heap.pop())
            .map(|entry| entry.score)
            .collect();
        assert_eq!(popped, [1.0, 2.0, 3.0]);
    }
}