1use std::cmp::Ordering;
11use std::collections::BinaryHeap;
12use std::sync::Arc;
13
14use log::{debug, trace};
15
16use crate::DocId;
17use crate::structures::BlockSparsePostingList;
18
/// A cursor over a posting list that yields postings in ascending doc-id
/// order, plus the upper bounds needed for WAND / block-max pruning.
///
/// Exhaustion is signalled with the sentinel doc id `u32::MAX`.
pub trait ScoringIterator {
    /// Current document id, or `u32::MAX` once the iterator is exhausted.
    fn doc(&self) -> DocId;

    /// Ordinal (sub-value index) of the current posting.
    /// Defaults to 0 for scorers without multi-value postings.
    fn ordinal(&self) -> u16 {
        0
    }

    /// Move to the next posting and return its doc id.
    fn advance(&mut self) -> DocId;

    /// Move to the first posting with doc id >= `target`; returns the
    /// doc id landed on (`u32::MAX` when none remains).
    fn seek(&mut self, target: DocId) -> DocId;

    /// True once the cursor has run past the last posting
    /// (i.e. `doc()` returns the `u32::MAX` sentinel).
    fn is_exhausted(&self) -> bool {
        self.doc() == u32::MAX
    }

    /// Score contribution of the current posting.
    fn score(&self) -> f32;

    /// Upper bound on `score()` over the entire posting list.
    fn max_score(&self) -> f32;

    /// Upper bound on `score()` within the current block only.
    fn current_block_max_score(&self) -> f32;

    /// Jump past the current block; the default falls back to a single
    /// `advance` for scorers without block structure.
    fn skip_to_next_block(&mut self) -> DocId {
        self.advance()
    }
}
58
59#[derive(Clone, Copy)]
61pub struct HeapEntry {
62 pub doc_id: DocId,
63 pub score: f32,
64 pub ordinal: u16,
65}
66
67impl PartialEq for HeapEntry {
68 fn eq(&self, other: &Self) -> bool {
69 self.score == other.score && self.doc_id == other.doc_id
70 }
71}
72
73impl Eq for HeapEntry {}
74
75impl Ord for HeapEntry {
76 fn cmp(&self, other: &Self) -> Ordering {
77 other
79 .score
80 .partial_cmp(&self.score)
81 .unwrap_or(Ordering::Equal)
82 .then_with(|| self.doc_id.cmp(&other.doc_id))
83 }
84}
85
86impl PartialOrd for HeapEntry {
87 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
88 Some(self.cmp(other))
89 }
90}
91
92pub struct ScoreCollector {
98 heap: BinaryHeap<HeapEntry>,
100 pub k: usize,
101}
102
103impl ScoreCollector {
104 pub fn new(k: usize) -> Self {
106 let capacity = k.saturating_add(1).min(1_000_000);
108 Self {
109 heap: BinaryHeap::with_capacity(capacity),
110 k,
111 }
112 }
113
114 #[inline]
116 pub fn threshold(&self) -> f32 {
117 if self.heap.len() >= self.k {
118 self.heap.peek().map(|e| e.score).unwrap_or(0.0)
119 } else {
120 0.0
121 }
122 }
123
124 #[inline]
127 pub fn insert(&mut self, doc_id: DocId, score: f32) -> bool {
128 self.insert_with_ordinal(doc_id, score, 0)
129 }
130
131 #[inline]
134 pub fn insert_with_ordinal(&mut self, doc_id: DocId, score: f32, ordinal: u16) -> bool {
135 if self.heap.len() < self.k {
136 self.heap.push(HeapEntry {
137 doc_id,
138 score,
139 ordinal,
140 });
141 true
142 } else if score > self.threshold() {
143 self.heap.push(HeapEntry {
144 doc_id,
145 score,
146 ordinal,
147 });
148 self.heap.pop(); true
150 } else {
151 false
152 }
153 }
154
155 #[inline]
157 pub fn would_enter(&self, score: f32) -> bool {
158 self.heap.len() < self.k || score > self.threshold()
159 }
160
161 #[inline]
163 pub fn len(&self) -> usize {
164 self.heap.len()
165 }
166
167 #[inline]
169 pub fn is_empty(&self) -> bool {
170 self.heap.is_empty()
171 }
172
173 pub fn into_sorted_results(self) -> Vec<(DocId, f32, u16)> {
175 let heap_vec = self.heap.into_vec();
176 let mut results: Vec<(DocId, f32, u16)> = Vec::with_capacity(heap_vec.len());
177 for e in heap_vec {
178 results.push((e.doc_id, e.score, e.ordinal));
179 }
180
181 results.sort_by(|a, b| {
183 b.1.partial_cmp(&a.1)
184 .unwrap_or(Ordering::Equal)
185 .then_with(|| a.0.cmp(&b.0))
186 });
187
188 results
189 }
190}
191
/// Final scored hit returned by the executors in this module.
#[derive(Debug, Clone, Copy)]
pub struct ScoredDoc {
    pub doc_id: DocId,
    // Accumulated score for this (doc, ordinal) pair.
    pub score: f32,
    // Sub-value ordinal the score belongs to (0 when unused).
    pub ordinal: u16,
}
200
/// Block-max WAND / MaxScore-style top-k executor over a set of term
/// scorers.
pub struct BlockMaxScoreExecutor<S: ScoringIterator> {
    // Scorers sorted by ascending max_score; a low-end prefix is treated
    // as "non-essential" during pruning.
    scorers: Vec<S>,
    // prefix_sums[i] = sum of max_score over scorers[0..=i].
    prefix_sums: Vec<f32>,
    // Top-k accumulator whose threshold drives pruning.
    collector: ScoreCollector,
    // Multiplier in [0, 1] applied to the threshold; values below 1.0
    // prune more aggressively (approximate top-k).
    heap_factor: f32,
}

/// Historical name kept for backward compatibility.
pub type WandExecutor<S> = BlockMaxScoreExecutor<S>;
229
impl<S: ScoringIterator> BlockMaxScoreExecutor<S> {
    /// Exact top-k execution (equivalent to `with_heap_factor(.., 1.0)`).
    pub fn new(scorers: Vec<S>, k: usize) -> Self {
        Self::with_heap_factor(scorers, k, 1.0)
    }

    /// Build an executor.
    ///
    /// `heap_factor` is clamped to [0, 1] and scales the pruning threshold:
    /// 1.0 gives exact top-k; lower values prune harder and may return
    /// approximate results.
    pub fn with_heap_factor(mut scorers: Vec<S>, k: usize, heap_factor: f32) -> Self {
        // MaxScore invariant: scorers ordered by ascending upper bound so
        // that a prefix of them can be classified as non-essential.
        scorers.sort_by(|a, b| {
            a.max_score()
                .partial_cmp(&b.max_score())
                .unwrap_or(Ordering::Equal)
        });

        // prefix_sums[i] = combined upper bound of scorers[0..=i].
        let mut prefix_sums = Vec::with_capacity(scorers.len());
        let mut cumsum = 0.0f32;
        for s in &scorers {
            cumsum += s.max_score();
            prefix_sums.push(cumsum);
        }

        debug!(
            "Creating BlockMaxScoreExecutor: num_scorers={}, k={}, total_upper={:.4}, heap_factor={:.2}",
            scorers.len(),
            k,
            cumsum,
            heap_factor
        );

        Self {
            scorers,
            prefix_sums,
            collector: ScoreCollector::new(k),
            heap_factor: heap_factor.clamp(0.0, 1.0),
        }
    }

    /// Index of the first "essential" scorer: the smallest prefix whose
    /// combined upper bound exceeds the (scaled) threshold. Scorers below
    /// this index can never lift a document above the threshold on their
    /// own. Returns `scorers.len()` when even the full sum cannot, which
    /// permits early termination.
    #[inline]
    fn find_partition(&self) -> usize {
        let threshold = self.collector.threshold() * self.heap_factor;
        self.prefix_sums
            .iter()
            .position(|&sum| sum > threshold)
            .unwrap_or(self.scorers.len())
    }

    /// Document-at-a-time evaluation with three pruning levels:
    /// 1) list-level upper-bound skip over the candidate doc,
    /// 2) block-max skip (jump whole blocks),
    /// 3) post-scoring skip before seeking non-essential terms.
    /// Returns the top-k results sorted by descending score.
    pub fn execute(mut self) -> Vec<ScoredDoc> {
        if self.scorers.is_empty() {
            debug!("BlockMaxScoreExecutor: no scorers, returning empty results");
            return Vec::new();
        }

        let n = self.scorers.len();
        let mut docs_scored = 0u64;
        let mut docs_skipped = 0u64;
        let mut blocks_skipped = 0u64;
        let mut conjunction_skipped = 0u64;

        // Reused scratch: (ordinal, partial score) pairs for the current doc.
        let mut ordinal_scores: Vec<(u16, f32)> = Vec::with_capacity(n * 2);

        loop {
            let partition = self.find_partition();

            // No essential scorers remain: nothing can beat the threshold.
            if partition >= n {
                debug!("BlockMaxScore: all terms non-essential, early termination");
                break;
            }

            // Candidate = smallest current doc among the essential scorers.
            let mut min_doc = u32::MAX;
            for i in partition..n {
                let doc = self.scorers[i].doc();
                if doc < min_doc {
                    min_doc = doc;
                }
            }

            // All essential scorers are exhausted.
            if min_doc == u32::MAX {
                break;
            }

            // Combined upper bound of all non-essential scorers.
            let non_essential_upper = if partition > 0 {
                self.prefix_sums[partition - 1]
            } else {
                0.0
            };
            let adjusted_threshold = self.collector.threshold() * self.heap_factor;

            // Pruning level 1: even the full list-level bounds of the
            // essential scorers positioned on min_doc (plus all of the
            // non-essential mass) cannot beat the threshold -> advance
            // past min_doc without scoring it.
            if self.collector.len() >= self.collector.k {
                let present_upper: f32 = (partition..n)
                    .filter(|&i| self.scorers[i].doc() == min_doc)
                    .map(|i| self.scorers[i].max_score())
                    .sum();

                if present_upper + non_essential_upper <= adjusted_threshold {
                    for i in partition..n {
                        if self.scorers[i].doc() == min_doc {
                            self.scorers[i].advance();
                        }
                    }
                    conjunction_skipped += 1;
                    continue;
                }
            }

            // Pruning level 2: tighter bound from per-block maxima; when it
            // also fails, jump the participating scorers a whole block.
            if self.collector.len() >= self.collector.k {
                let block_max_sum: f32 = (partition..n)
                    .filter(|&i| self.scorers[i].doc() == min_doc)
                    .map(|i| self.scorers[i].current_block_max_score())
                    .sum();

                if block_max_sum + non_essential_upper <= adjusted_threshold {
                    for i in partition..n {
                        if self.scorers[i].doc() == min_doc {
                            self.scorers[i].skip_to_next_block();
                        }
                    }
                    blocks_skipped += 1;
                    continue;
                }
            }

            ordinal_scores.clear();

            // Score every essential scorer positioned on min_doc; the inner
            // while consumes multiple postings per scorer (one per ordinal).
            for i in partition..n {
                if self.scorers[i].doc() == min_doc {
                    while self.scorers[i].doc() == min_doc {
                        ordinal_scores.push((self.scorers[i].ordinal(), self.scorers[i].score()));
                        self.scorers[i].advance();
                    }
                }
            }

            let essential_total: f32 = ordinal_scores.iter().map(|(_, s)| *s).sum();

            // Pruning level 3: actual essential score plus the non-essential
            // upper bound still below threshold -> no need to seek the
            // non-essential scorers (they were not consumed, only the
            // essential ones advanced past min_doc above).
            if self.collector.len() >= self.collector.k
                && essential_total + non_essential_upper <= adjusted_threshold
            {
                docs_skipped += 1;
                continue;
            }

            // Bring the non-essential scorers to min_doc and collect their
            // postings as well.
            for i in 0..partition {
                let doc = self.scorers[i].seek(min_doc);
                if doc == min_doc {
                    while self.scorers[i].doc() == min_doc {
                        ordinal_scores.push((self.scorers[i].ordinal(), self.scorers[i].score()));
                        self.scorers[i].advance();
                    }
                }
            }

            // Group partial scores by ordinal (sorting makes equal ordinals
            // adjacent) and insert one candidate per (doc, ordinal) pair.
            ordinal_scores.sort_unstable_by_key(|(ord, _)| *ord);
            let mut j = 0;
            while j < ordinal_scores.len() {
                let current_ord = ordinal_scores[j].0;
                let mut score = 0.0f32;
                while j < ordinal_scores.len() && ordinal_scores[j].0 == current_ord {
                    score += ordinal_scores[j].1;
                    j += 1;
                }

                trace!(
                    "Doc {}: ordinal={}, score={:.4}, threshold={:.4}",
                    min_doc, current_ord, score, adjusted_threshold
                );

                if self
                    .collector
                    .insert_with_ordinal(min_doc, score, current_ord)
                {
                    docs_scored += 1;
                } else {
                    docs_skipped += 1;
                }
            }
        }

        let results: Vec<ScoredDoc> = self
            .collector
            .into_sorted_results()
            .into_iter()
            .map(|(doc_id, score, ordinal)| ScoredDoc {
                doc_id,
                score,
                ordinal,
            })
            .collect();

        debug!(
            "BlockMaxScoreExecutor completed: scored={}, skipped={}, blocks_skipped={}, conjunction_skipped={}, returned={}, top_score={:.4}",
            docs_scored,
            docs_skipped,
            blocks_skipped,
            conjunction_skipped,
            results.len(),
            results.first().map(|r| r.score).unwrap_or(0.0)
        );

        results
    }
}
465
/// BM25 scorer over a single text term's block posting list.
pub struct TextTermScorer {
    // Owning iterator over the posting list ('static: list moved in).
    iter: crate::structures::BlockPostingIterator<'static>,
    // Inverse document frequency of the term.
    idf: f32,
    // Average field length used by the BM25 normalization.
    avg_field_len: f32,
    // List-wide score upper bound, precomputed from max_tf at build time.
    max_score: f32,
}
480
481impl TextTermScorer {
482 pub fn new(
484 posting_list: crate::structures::BlockPostingList,
485 idf: f32,
486 avg_field_len: f32,
487 ) -> Self {
488 let max_tf = posting_list.max_tf() as f32;
490 let doc_count = posting_list.doc_count();
491 let max_score = super::bm25_upper_bound(max_tf.max(1.0), idf);
492
493 debug!(
494 "Created TextTermScorer: doc_count={}, max_tf={:.0}, idf={:.4}, avg_field_len={:.2}, max_score={:.4}",
495 doc_count, max_tf, idf, avg_field_len, max_score
496 );
497
498 Self {
499 iter: posting_list.into_iterator(),
500 idf,
501 avg_field_len,
502 max_score,
503 }
504 }
505}
506
impl ScoringIterator for TextTermScorer {
    // Delegates cursor movement to the underlying block posting iterator.
    #[inline]
    fn doc(&self) -> DocId {
        self.iter.doc()
    }

    #[inline]
    fn advance(&mut self) -> DocId {
        self.iter.advance()
    }

    #[inline]
    fn seek(&mut self, target: DocId) -> DocId {
        self.iter.seek(target)
    }

    #[inline]
    fn score(&self) -> f32 {
        let tf = self.iter.term_freq() as f32;
        // NOTE(review): `tf` is passed both as the term frequency and as
        // the third argument of bm25_score — presumably a "no stored doc
        // lengths" approximation where tf stands in for the field length;
        // confirm against super::bm25_score's parameter order.
        super::bm25_score(tf, self.idf, tf, self.avg_field_len)
    }

    /// List-wide upper bound precomputed in `new`.
    #[inline]
    fn max_score(&self) -> f32 {
        self.max_score
    }

    /// Per-block upper bound from the block's maximum term frequency,
    /// clamped to at least 1 like the list-wide bound.
    #[inline]
    fn current_block_max_score(&self) -> f32 {
        let block_max_tf = self.iter.current_block_max_tf() as f32;
        super::bm25_upper_bound(block_max_tf.max(1.0), self.idf)
    }

    #[inline]
    fn skip_to_next_block(&mut self) -> DocId {
        self.iter.skip_to_next_block()
    }
}
547
/// Scorer for one query dimension of a sparse (learned-sparse) vector,
/// borrowing its posting list.
pub struct SparseTermScorer<'a> {
    iter: crate::structures::BlockSparsePostingIterator<'a>,
    // Signed weight of this dimension in the query.
    query_weight: f32,
    // |query_weight| * global max stored weight, precomputed in `new`.
    max_score: f32,
}
559
560impl<'a> SparseTermScorer<'a> {
561 pub fn new(posting_list: &'a BlockSparsePostingList, query_weight: f32) -> Self {
566 let max_score = query_weight.abs() * posting_list.global_max_weight();
569 Self {
570 iter: posting_list.iterator(),
571 query_weight,
572 max_score,
573 }
574 }
575
576 pub fn from_arc(posting_list: &'a Arc<BlockSparsePostingList>, query_weight: f32) -> Self {
578 Self::new(posting_list.as_ref(), query_weight)
579 }
580}
581
impl ScoringIterator for SparseTermScorer<'_> {
    // Cursor movement delegates to the block sparse posting iterator.
    #[inline]
    fn doc(&self) -> DocId {
        self.iter.doc()
    }

    /// Ordinal comes from the posting itself (multi-value support).
    #[inline]
    fn ordinal(&self) -> u16 {
        self.iter.ordinal()
    }

    #[inline]
    fn advance(&mut self) -> DocId {
        self.iter.advance()
    }

    #[inline]
    fn seek(&mut self, target: DocId) -> DocId {
        self.iter.seek(target)
    }

    /// Dot-product contribution: signed query weight times stored weight.
    #[inline]
    fn score(&self) -> f32 {
        self.query_weight * self.iter.weight()
    }

    /// List-wide upper bound precomputed in `new`.
    #[inline]
    fn max_score(&self) -> f32 {
        self.max_score
    }

    /// Per-block upper bound; the iterator scales the block's max stored
    /// weight by the (absolute) query weight.
    #[inline]
    fn current_block_max_score(&self) -> f32 {
        self.iter
            .current_block_max_contribution(self.query_weight.abs())
    }

    #[inline]
    fn skip_to_next_block(&mut self) -> DocId {
        self.iter.skip_to_next_block()
    }
}
626
/// Block-at-a-time (BMP-style) executor for sparse queries: superblocks
/// are processed in descending upper-bound order, and execution stops
/// once the remaining bound cannot beat the current threshold.
pub struct BmpExecutor<'a> {
    sparse_index: &'a crate::segment::SparseIndex,
    // (dimension id, signed query weight) pairs.
    query_terms: Vec<(u32, f32)>,
    // Number of results requested.
    k: usize,
    // Threshold multiplier in [0, 1]; below 1.0 = approximate top-k.
    heap_factor: f32,
}

/// Number of consecutive blocks grouped and prioritized as one unit.
const BMP_SUPERBLOCK_SIZE: usize = 8;

/// Priority-queue entry: one superblock of one term, ordered by its
/// upper-bound contribution so the best superblock is processed first.
struct BmpBlockEntry {
    // Sum of per-block upper bounds in this superblock.
    contribution: f32,
    // Index into `query_terms`.
    term_idx: usize,
    // First block index and number of blocks in the superblock.
    block_start: usize,
    block_count: usize,
}
664
665impl PartialEq for BmpBlockEntry {
666 fn eq(&self, other: &Self) -> bool {
667 self.contribution == other.contribution
668 }
669}
670
671impl Eq for BmpBlockEntry {}
672
673impl Ord for BmpBlockEntry {
674 fn cmp(&self, other: &Self) -> Ordering {
675 self.contribution
677 .partial_cmp(&other.contribution)
678 .unwrap_or(Ordering::Equal)
679 }
680}
681
682impl PartialOrd for BmpBlockEntry {
683 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
684 Some(self.cmp(other))
685 }
686}
687
impl<'a> BmpExecutor<'a> {
    /// Build an executor over `sparse_index` for the given
    /// (dimension, weight) query terms. `heap_factor` is clamped to [0, 1].
    pub fn new(
        sparse_index: &'a crate::segment::SparseIndex,
        query_terms: Vec<(u32, f32)>,
        k: usize,
        heap_factor: f32,
    ) -> Self {
        Self {
            sparse_index,
            query_terms,
            k,
            heap_factor: heap_factor.clamp(0.0, 1.0),
        }
    }

    /// Run best-first block-at-a-time evaluation and return the top-k
    /// results sorted by descending score.
    pub async fn execute(self) -> crate::Result<Vec<ScoredDoc>> {
        use rustc_hash::FxHashMap;

        if self.query_terms.is_empty() {
            return Ok(Vec::new());
        }

        let num_terms = self.query_terms.len();
        let si = self.sparse_index;

        // Best-first queue of superblocks, plus per-term remaining upper
        // bounds; also track the global doc-id range for the flat array.
        let mut block_queue: BinaryHeap<BmpBlockEntry> = BinaryHeap::new();
        let mut remaining_max: Vec<f32> = Vec::with_capacity(num_terms);
        let mut global_min_doc = u32::MAX;
        let mut global_max_doc = 0u32;

        // Planning pass: group each term's blocks into superblocks and
        // enqueue them with their summed upper-bound contribution.
        for (term_idx, &(dim_id, qw)) in self.query_terms.iter().enumerate() {
            let mut term_remaining = 0.0f32;

            if let Some((skip_start, skip_count, _global_max)) = si.get_skip_range(dim_id) {
                let mut sb_start = 0;
                while sb_start < skip_count {
                    let sb_count = (skip_count - sb_start).min(BMP_SUPERBLOCK_SIZE);
                    let mut sb_contribution = 0.0f32;
                    for j in 0..sb_count {
                        let skip = si.read_skip_entry(skip_start + sb_start + j);
                        sb_contribution += qw.abs() * skip.max_weight;
                        global_min_doc = global_min_doc.min(skip.first_doc);
                        global_max_doc = global_max_doc.max(skip.last_doc);
                    }
                    term_remaining += sb_contribution;
                    block_queue.push(BmpBlockEntry {
                        contribution: sb_contribution,
                        term_idx,
                        block_start: sb_start,
                        block_count: sb_count,
                    });
                    sb_start += sb_count;
                }
            }
            remaining_max.push(term_remaining);
        }

        // Use a flat score array when the doc-id range is small enough
        // (<= 256Ki slots); otherwise fall back to a hash accumulator.
        let doc_range = if global_max_doc >= global_min_doc {
            (global_max_doc - global_min_doc + 1) as usize
        } else {
            0
        };
        let use_flat = doc_range > 0 && doc_range <= 256 * 1024;
        let mut flat_scores: Vec<f32> = if use_flat {
            vec![0.0; doc_range]
        } else {
            Vec::new()
        };
        // Doc ids whose flat slot has been touched (for sparse readout).
        let mut dirty: Vec<u32> = if use_flat {
            Vec::with_capacity(4096)
        } else {
            Vec::new()
        };
        // Accumulators keyed by (doc_id << 16 | ordinal) for non-zero
        // ordinals (and for everything in the non-flat path).
        let mut multi_ord_accumulators: FxHashMap<u64, f32> = FxHashMap::default();

        let mut blocks_processed = 0u64;
        let mut blocks_skipped = 0u64;

        let mut top_k = ScoreCollector::new(self.k);

        // Reusable per-block decode buffers.
        let mut doc_ids_buf: Vec<u32> = Vec::with_capacity(128);
        let mut weights_buf: Vec<f32> = Vec::with_capacity(128);
        let mut ordinals_buf: Vec<u16> = Vec::with_capacity(128);

        while let Some(entry) = block_queue.pop() {
            // This superblock is now "in flight": remove it from the
            // remaining upper-bound mass of its term.
            remaining_max[entry.term_idx] -= entry.contribution;

            // Global early termination: nothing left in any queue entry
            // can beat the (scaled) threshold.
            let total_remaining: f32 = remaining_max.iter().sum();
            let adjusted_threshold = top_k.threshold() * self.heap_factor;
            if top_k.len() >= self.k && total_remaining <= adjusted_threshold {
                blocks_skipped += block_queue.len() as u64;
                debug!(
                    "BMP early termination after {} blocks: remaining={:.4} <= threshold={:.4}",
                    blocks_processed, total_remaining, adjusted_threshold
                );
                break;
            }

            let dim_id = self.query_terms[entry.term_idx].0;
            let qw = self.query_terms[entry.term_idx].1;
            // Threshold snapshot used for per-block skip checks below.
            let adjusted_threshold2 = top_k.threshold() * self.heap_factor;

            // Fetch the whole superblock in one async call.
            let sb_blocks = si
                .get_blocks_range(dim_id, entry.block_start, entry.block_count)
                .await?;

            for (blk_offset, block) in sb_blocks.into_iter().enumerate() {
                let blk_idx = entry.block_start + blk_offset;

                // Per-block refinement: skip a single block whose bound,
                // plus everything still unprocessed, cannot reach the
                // threshold.
                // NOTE(review): `total_remaining` below already excludes
                // this superblock (subtracted at pop time), so the extra
                // `- entry.contribution` appears to subtract it a second
                // time, making the bound tighter than the usual BMP
                // derivation — confirm intent.
                if top_k.len() >= self.k
                    && let Some((skip_start, _, _)) = si.get_skip_range(dim_id)
                {
                    let skip = si.read_skip_entry(skip_start + blk_idx);
                    let blk_contrib = qw.abs() * skip.max_weight;
                    let total_remaining: f32 = remaining_max.iter().sum();
                    if blk_contrib + total_remaining - entry.contribution <= adjusted_threshold2 {
                        blocks_skipped += 1;
                        continue;
                    }
                }

                block.decode_doc_ids_into(&mut doc_ids_buf);
                block.decode_scored_weights_into(qw, &mut weights_buf);
                block.decode_ordinals_into(&mut ordinals_buf);

                if use_flat {
                    for i in 0..block.header.count as usize {
                        let doc_id = doc_ids_buf[i];
                        let ordinal = ordinals_buf[i];
                        let score_contribution = weights_buf[i];

                        if ordinal == 0 {
                            let off = (doc_id - global_min_doc) as usize;
                            // First touch of this slot: remember the doc.
                            // NOTE(review): uses "slot == 0.0" as the
                            // first-touch test; a zero or sign-cancelling
                            // contribution could push the doc twice —
                            // harmless for output (filtered by score > 0.0
                            // below) but worth confirming.
                            if flat_scores[off] == 0.0 {
                                dirty.push(doc_id);
                            }
                            flat_scores[off] += score_contribution;
                            top_k.insert_with_ordinal(doc_id, flat_scores[off], 0);
                        } else {
                            // Non-zero ordinals go to the hash accumulator,
                            // keyed by (doc_id << 16 | ordinal).
                            let key = (doc_id as u64) << 16 | ordinal as u64;
                            let acc = multi_ord_accumulators.entry(key).or_insert(0.0);
                            *acc += score_contribution;
                            top_k.insert_with_ordinal(doc_id, *acc, ordinal);
                        }
                    }
                } else {
                    // Non-flat path: everything goes through the hash map.
                    for i in 0..block.header.count as usize {
                        let key = (doc_ids_buf[i] as u64) << 16 | ordinals_buf[i] as u64;
                        let acc = multi_ord_accumulators.entry(key).or_insert(0.0);
                        *acc += weights_buf[i];
                        top_k.insert_with_ordinal(doc_ids_buf[i], *acc, ordinals_buf[i]);
                    }
                }

                blocks_processed += 1;
            }
        }

        // Readout: materialize final accumulated scores. The ScoreCollector
        // above only guided pruning; final ranking re-sorts everything.
        let mut scored: Vec<ScoredDoc> = Vec::new();

        let num_accumulators = if use_flat {
            scored.reserve(dirty.len() + multi_ord_accumulators.len());
            for &doc_id in &dirty {
                let off = (doc_id - global_min_doc) as usize;
                let score = flat_scores[off];
                // Only positive totals survive to the result set.
                if score > 0.0 {
                    scored.push(ScoredDoc {
                        doc_id,
                        score,
                        ordinal: 0,
                    });
                }
            }
            dirty.len() + multi_ord_accumulators.len()
        } else {
            multi_ord_accumulators.len()
        };

        scored.extend(
            multi_ord_accumulators
                .into_iter()
                .map(|(key, score)| ScoredDoc {
                    doc_id: (key >> 16) as DocId,
                    score,
                    ordinal: (key & 0xFFFF) as u16,
                }),
        );

        // Final descending sort and truncation to k.
        scored.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
        scored.truncate(self.k);
        let results = scored;

        debug!(
            "BmpExecutor completed: blocks_processed={}, blocks_skipped={}, accumulators={}, flat={}, returned={}, top_score={:.4}",
            blocks_processed,
            blocks_skipped,
            num_accumulators,
            use_flat,
            results.len(),
            results.first().map(|r| r.score).unwrap_or(0.0)
        );

        Ok(results)
    }
}
918
/// Variant of [`BlockMaxScoreExecutor`] that fetches and decodes posting
/// blocks lazily (async) from a sparse index instead of holding fully
/// decoded iterators.
pub struct LazyBlockMaxScoreExecutor<'a> {
    sparse_index: &'a crate::segment::SparseIndex,
    // Per-term cursors, sorted by ascending max_score.
    cursors: Vec<LazyTermCursor>,
    // prefix_sums[i] = sum of max_score over cursors[0..=i].
    prefix_sums: Vec<f32>,
    collector: ScoreCollector,
    // Threshold multiplier in [0, 1]; below 1.0 = approximate top-k.
    heap_factor: f32,
}
937
/// Cursor over one query dimension's posting blocks; a block is decoded
/// only when actually needed (`block_loaded` tracks that state).
struct LazyTermCursor {
    dim_id: u32,
    // Signed weight of this dimension in the query.
    query_weight: f32,
    // |query_weight| * global max stored weight for the dimension.
    max_score: f32,
    // First skip-entry index and number of blocks for this dimension.
    skip_start: usize,
    skip_count: usize,
    // Index of the current block within [0, skip_count).
    block_idx: usize,
    // Decoded contents of the current block (valid only when
    // block_loaded is true).
    doc_ids: Vec<u32>,
    ordinals: Vec<u16>,
    // Stored weights pre-multiplied by query_weight at decode time.
    weights: Vec<f32>,
    // Position within the decoded block.
    pos: usize,
    block_loaded: bool,
    exhausted: bool,
}
959
impl LazyTermCursor {
    /// Build a cursor positioned before block 0; a dimension with no
    /// blocks starts exhausted.
    fn new(
        dim_id: u32,
        query_weight: f32,
        skip_start: usize,
        skip_count: usize,
        global_max_weight: f32,
    ) -> Self {
        let exhausted = skip_count == 0;
        Self {
            dim_id,
            query_weight,
            max_score: query_weight.abs() * global_max_weight,
            skip_start,
            skip_count,
            block_idx: 0,
            doc_ids: Vec::with_capacity(128),
            ordinals: Vec::with_capacity(128),
            weights: Vec::with_capacity(128),
            pos: 0,
            block_loaded: false,
            exhausted,
        }
    }

    /// Decode the current block if it is not already decoded.
    /// Returns Ok(true) when a block is loaded, Ok(false) when the
    /// cursor became (or already was) exhausted.
    async fn ensure_block_loaded(
        &mut self,
        sparse_index: &crate::segment::SparseIndex,
    ) -> crate::Result<bool> {
        if self.exhausted || self.block_loaded {
            return Ok(!self.exhausted);
        }
        match sparse_index.get_block(self.dim_id, self.block_idx).await? {
            Some(block) => {
                block.decode_doc_ids_into(&mut self.doc_ids);
                block.decode_ordinals_into(&mut self.ordinals);
                block.decode_scored_weights_into(self.query_weight, &mut self.weights);
                self.pos = 0;
                self.block_loaded = true;
                Ok(true)
            }
            None => {
                self.exhausted = true;
                Ok(false)
            }
        }
    }

    /// Doc id without forcing a decode: when the current block is not
    /// loaded, fall back to its first doc id from the skip entry (a
    /// lower bound on the true current doc).
    #[inline]
    fn doc_with_si(&self, si: &crate::segment::SparseIndex) -> DocId {
        if self.exhausted {
            return u32::MAX;
        }
        if !self.block_loaded {
            return si
                .read_skip_entry(self.skip_start + self.block_idx)
                .first_doc;
        }
        self.doc_ids.get(self.pos).copied().unwrap_or(u32::MAX)
    }

    /// Doc id from decoded data only; returns the `u32::MAX` sentinel
    /// when the current block is not loaded, even if postings remain.
    /// Callers needing a cheap position estimate use `doc_with_si`.
    #[inline]
    fn doc(&self) -> DocId {
        if self.exhausted {
            return u32::MAX;
        }
        if self.block_loaded {
            return self.doc_ids.get(self.pos).copied().unwrap_or(u32::MAX);
        }
        u32::MAX
    }

    /// Ordinal of the current posting; 0 when no block is loaded.
    #[inline]
    fn ordinal(&self) -> u16 {
        if !self.block_loaded {
            return 0;
        }
        self.ordinals.get(self.pos).copied().unwrap_or(0)
    }

    /// Pre-multiplied score of the current posting; 0.0 when no block
    /// is loaded.
    #[inline]
    fn score(&self) -> f32 {
        if !self.block_loaded {
            return 0.0;
        }
        self.weights.get(self.pos).copied().unwrap_or(0.0)
    }

    /// Upper bound of the current block, from its skip entry (no decode).
    #[inline]
    fn current_block_max_score(&self, si: &crate::segment::SparseIndex) -> f32 {
        if self.exhausted || self.block_idx >= self.skip_count {
            return 0.0;
        }
        self.query_weight.abs()
            * si.read_skip_entry(self.skip_start + self.block_idx)
                .max_weight
    }

    /// Step to the next posting, loading the current block on demand.
    /// When the step crosses a block boundary, the next block is left
    /// unloaded, so the returned value is `u32::MAX` (via `doc()`) until
    /// it is loaded — use `doc_with_si` for a position estimate.
    async fn advance(
        &mut self,
        sparse_index: &crate::segment::SparseIndex,
    ) -> crate::Result<DocId> {
        if self.exhausted {
            return Ok(u32::MAX);
        }
        self.ensure_block_loaded(sparse_index).await?;
        if self.exhausted {
            return Ok(u32::MAX);
        }
        self.pos += 1;
        if self.pos >= self.doc_ids.len() {
            // Crossed into the next (not yet loaded) block.
            self.block_idx += 1;
            self.block_loaded = false;
            if self.block_idx >= self.skip_count {
                self.exhausted = true;
                return Ok(u32::MAX);
            }
        }
        Ok(self.doc())
    }

    /// Move to the first posting with doc id >= `target`.
    async fn seek(
        &mut self,
        sparse_index: &crate::segment::SparseIndex,
        target: DocId,
    ) -> crate::Result<DocId> {
        if self.exhausted {
            return Ok(u32::MAX);
        }

        // Fast path: target lies inside the already-decoded block.
        if self.block_loaded
            && let Some(&last) = self.doc_ids.last()
        {
            if last >= target && self.doc_ids[self.pos] < target {
                let remaining = &self.doc_ids[self.pos..];
                let offset = crate::structures::simd::find_first_ge_u32(remaining, target);
                self.pos += offset;
                if self.pos >= self.doc_ids.len() {
                    self.block_idx += 1;
                    self.block_loaded = false;
                    if self.block_idx >= self.skip_count {
                        self.exhausted = true;
                        return Ok(u32::MAX);
                    }
                }
                return Ok(self.doc());
            }
            if self.doc_ids[self.pos] >= target {
                return Ok(self.doc());
            }
        }

        // Slow path: linear scan of the skip entries (from block 0) for
        // the first block whose last doc can contain the target.
        let mut found_idx = None;
        for i in 0..self.skip_count {
            if sparse_index.read_skip_entry(self.skip_start + i).last_doc >= target {
                found_idx = Some(i);
                break;
            }
        }
        match found_idx {
            Some(idx) => {
                if idx != self.block_idx || !self.block_loaded {
                    self.block_idx = idx;
                    self.block_loaded = false;
                }
                self.ensure_block_loaded(sparse_index).await?;
                if self.exhausted {
                    return Ok(u32::MAX);
                }
                let offset = crate::structures::simd::find_first_ge_u32(&self.doc_ids, target);
                self.pos = offset;
                if self.pos >= self.doc_ids.len() {
                    // Target past the end of this block: move to the next
                    // block and eagerly load it.
                    self.block_idx += 1;
                    self.block_loaded = false;
                    if self.block_idx >= self.skip_count {
                        self.exhausted = true;
                        return Ok(u32::MAX);
                    }
                    self.ensure_block_loaded(sparse_index).await?;
                }
                Ok(self.doc())
            }
            None => {
                self.exhausted = true;
                Ok(u32::MAX)
            }
        }
    }

    /// Jump past the current block without decoding it; returns the next
    /// block's first doc id from its skip entry (block stays unloaded).
    fn skip_to_next_block(&mut self, si: &crate::segment::SparseIndex) -> DocId {
        if self.exhausted {
            return u32::MAX;
        }
        self.block_idx += 1;
        self.block_loaded = false;
        if self.block_idx >= self.skip_count {
            self.exhausted = true;
            return u32::MAX;
        }
        si.read_skip_entry(self.skip_start + self.block_idx)
            .first_doc
    }
}
1175
impl<'a> LazyBlockMaxScoreExecutor<'a> {
    /// Build an executor: one lazy cursor per query dimension present in
    /// the index, sorted by ascending upper bound (MaxScore invariant).
    /// `heap_factor` is clamped to [0, 1]; below 1.0 = approximate top-k.
    pub fn new(
        sparse_index: &'a crate::segment::SparseIndex,
        query_terms: Vec<(u32, f32)>,
        k: usize,
        heap_factor: f32,
    ) -> Self {
        // Dimensions absent from the index are silently dropped.
        let mut cursors: Vec<LazyTermCursor> = query_terms
            .iter()
            .filter_map(|&(dim_id, qw)| {
                let (skip_start, skip_count, global_max) = sparse_index.get_skip_range(dim_id)?;
                Some(LazyTermCursor::new(
                    dim_id, qw, skip_start, skip_count, global_max,
                ))
            })
            .collect();

        cursors.sort_by(|a, b| {
            a.max_score
                .partial_cmp(&b.max_score)
                .unwrap_or(Ordering::Equal)
        });

        // prefix_sums[i] = combined upper bound of cursors[0..=i].
        let mut prefix_sums = Vec::with_capacity(cursors.len());
        let mut cumsum = 0.0f32;
        for c in &cursors {
            cumsum += c.max_score;
            prefix_sums.push(cumsum);
        }

        debug!(
            "Creating LazyBlockMaxScoreExecutor: num_terms={}, k={}, total_upper={:.4}, heap_factor={:.2}",
            cursors.len(),
            k,
            cumsum,
            heap_factor
        );

        Self {
            sparse_index,
            cursors,
            prefix_sums,
            collector: ScoreCollector::new(k),
            heap_factor: heap_factor.clamp(0.0, 1.0),
        }
    }

    /// First index whose prefix upper bound exceeds the (scaled)
    /// threshold; cursors below it are non-essential. Returns
    /// `cursors.len()` when no prefix can beat the threshold.
    #[inline]
    fn find_partition(&self) -> usize {
        let threshold = self.collector.threshold() * self.heap_factor;
        self.prefix_sums
            .iter()
            .position(|&sum| sum > threshold)
            .unwrap_or(self.cursors.len())
    }

    /// Same three-level pruning as `BlockMaxScoreExecutor::execute`, but
    /// with lazily loaded blocks: candidate selection uses `doc_with_si`
    /// (skip entries) so undecoded blocks never force a fetch.
    pub async fn execute(mut self) -> crate::Result<Vec<ScoredDoc>> {
        if self.cursors.is_empty() {
            return Ok(Vec::new());
        }

        let n = self.cursors.len();
        let si = self.sparse_index;

        // Prime every cursor with its first block.
        for cursor in &mut self.cursors {
            cursor.ensure_block_loaded(si).await?;
        }

        let mut docs_scored = 0u64;
        let mut docs_skipped = 0u64;
        let mut blocks_skipped = 0u64;
        let mut blocks_loaded = 0u64;
        let mut conjunction_skipped = 0u64;
        // Reused scratch: (ordinal, partial score) pairs for current doc.
        let mut ordinal_scores: Vec<(u16, f32)> = Vec::with_capacity(n * 2);

        loop {
            let partition = self.find_partition();
            if partition >= n {
                // All terms non-essential: early termination.
                break;
            }

            // Candidate = smallest current doc among essential cursors
            // (for unloaded blocks this is the skip entry's first_doc).
            let mut min_doc = u32::MAX;
            for i in partition..n {
                let doc = self.cursors[i].doc_with_si(si);
                if doc < min_doc {
                    min_doc = doc;
                }
            }
            if min_doc == u32::MAX {
                break;
            }

            let non_essential_upper = if partition > 0 {
                self.prefix_sums[partition - 1]
            } else {
                0.0
            };
            let adjusted_threshold = self.collector.threshold() * self.heap_factor;

            // Pruning level 1: list-level upper bounds of cursors on
            // min_doc (plus non-essential mass) cannot beat the threshold
            // -> advance past min_doc without scoring.
            if self.collector.len() >= self.collector.k {
                let present_upper: f32 = (partition..n)
                    .filter(|&i| self.cursors[i].doc_with_si(si) == min_doc)
                    .map(|i| self.cursors[i].max_score)
                    .sum();

                if present_upper + non_essential_upper <= adjusted_threshold {
                    for i in partition..n {
                        if self.cursors[i].doc_with_si(si) == min_doc {
                            self.cursors[i].ensure_block_loaded(si).await?;
                            self.cursors[i].advance(si).await?;
                            // Count a load only when the cursor still has a
                            // decoded block after advancing.
                            blocks_loaded += u64::from(self.cursors[i].block_loaded);
                        }
                    }
                    conjunction_skipped += 1;
                    continue;
                }
            }

            // Pruning level 2: per-block maxima (from skip entries, no
            // decode needed); on failure, jump whole blocks and eagerly
            // load the next block of each participant.
            if self.collector.len() >= self.collector.k {
                let block_max_sum: f32 = (partition..n)
                    .filter(|&i| self.cursors[i].doc_with_si(si) == min_doc)
                    .map(|i| self.cursors[i].current_block_max_score(si))
                    .sum();

                if block_max_sum + non_essential_upper <= adjusted_threshold {
                    for i in partition..n {
                        if self.cursors[i].doc_with_si(si) == min_doc {
                            self.cursors[i].skip_to_next_block(si);
                            self.cursors[i].ensure_block_loaded(si).await?;
                            blocks_loaded += 1;
                        }
                    }
                    blocks_skipped += 1;
                    continue;
                }
            }

            // Score the essential cursors positioned on min_doc.
            // NOTE(review): after `advance` crosses a block boundary the
            // new block is unloaded; if its skip-entry first_doc also
            // equals min_doc, `ordinal()`/`score()` return 0 on the next
            // loop check here — confirm a doc cannot span two adjacent
            // blocks of the same dimension.
            ordinal_scores.clear();
            for i in partition..n {
                if self.cursors[i].doc_with_si(si) == min_doc {
                    self.cursors[i].ensure_block_loaded(si).await?;
                    while self.cursors[i].doc_with_si(si) == min_doc {
                        ordinal_scores.push((self.cursors[i].ordinal(), self.cursors[i].score()));
                        self.cursors[i].advance(si).await?;
                    }
                }
            }

            // Pruning level 3: actual essential score plus non-essential
            // upper bound still below threshold -> skip the seeks.
            let essential_total: f32 = ordinal_scores.iter().map(|(_, s)| *s).sum();
            if self.collector.len() >= self.collector.k
                && essential_total + non_essential_upper <= adjusted_threshold
            {
                docs_skipped += 1;
                continue;
            }

            // Bring the non-essential cursors to min_doc and collect their
            // postings as well.
            for i in 0..partition {
                let doc = self.cursors[i].seek(si, min_doc).await?;
                if doc == min_doc {
                    while self.cursors[i].doc_with_si(si) == min_doc {
                        ordinal_scores.push((self.cursors[i].ordinal(), self.cursors[i].score()));
                        self.cursors[i].advance(si).await?;
                    }
                }
            }

            // Group partial scores by ordinal (sorting makes equal
            // ordinals adjacent) and insert one candidate per pair.
            ordinal_scores.sort_unstable_by_key(|(ord, _)| *ord);
            let mut j = 0;
            while j < ordinal_scores.len() {
                let current_ord = ordinal_scores[j].0;
                let mut score = 0.0f32;
                while j < ordinal_scores.len() && ordinal_scores[j].0 == current_ord {
                    score += ordinal_scores[j].1;
                    j += 1;
                }
                if self
                    .collector
                    .insert_with_ordinal(min_doc, score, current_ord)
                {
                    docs_scored += 1;
                } else {
                    docs_skipped += 1;
                }
            }
        }

        let results: Vec<ScoredDoc> = self
            .collector
            .into_sorted_results()
            .into_iter()
            .map(|(doc_id, score, ordinal)| ScoredDoc {
                doc_id,
                score,
                ordinal,
            })
            .collect();

        debug!(
            "LazyBlockMaxScoreExecutor completed: scored={}, skipped={}, blocks_skipped={}, blocks_loaded={}, conjunction_skipped={}, returned={}, top_score={:.4}",
            docs_scored,
            docs_skipped,
            blocks_skipped,
            blocks_loaded,
            conjunction_skipped,
            results.len(),
            results.first().map(|r| r.score).unwrap_or(0.0)
        );

        Ok(results)
    }
}
1401
#[cfg(test)]
mod tests {
    use super::*;

    /// Filling to capacity then overflowing must evict the weakest entry
    /// and raise the threshold accordingly.
    #[test]
    fn test_score_collector_basic() {
        let mut collector = ScoreCollector::new(3);

        for (doc, score) in [(1u32, 1.0f32), (2, 2.0), (3, 3.0)] {
            collector.insert(doc, score);
        }
        assert_eq!(collector.threshold(), 1.0);

        collector.insert(4, 4.0);
        assert_eq!(collector.threshold(), 2.0);

        let results = collector.into_sorted_results();
        assert_eq!(results.len(), 3);
        let ids: Vec<u32> = results.iter().map(|r| r.0).collect();
        assert_eq!(ids, vec![4, 3, 2]);
    }

    /// `would_enter` and `insert` must agree with the live threshold.
    #[test]
    fn test_score_collector_threshold() {
        let mut collector = ScoreCollector::new(2);

        collector.insert(1, 5.0);
        collector.insert(2, 3.0);
        assert_eq!(collector.threshold(), 3.0);

        // Below threshold: rejected on both paths.
        assert!(!collector.would_enter(2.0));
        assert!(!collector.insert(3, 2.0));

        // Above threshold: accepted, and the threshold moves up.
        assert!(collector.would_enter(4.0));
        assert!(collector.insert(4, 4.0));
        assert_eq!(collector.threshold(), 4.0);
    }

    /// The reversed `Ord` makes the std max-heap pop the *lowest* score
    /// first (min-heap semantics).
    #[test]
    fn test_heap_entry_ordering() {
        let mut heap = BinaryHeap::new();
        for (doc_id, score) in [(1u32, 3.0f32), (2, 1.0), (3, 2.0)] {
            heap.push(HeapEntry {
                doc_id,
                score,
                ordinal: 0,
            });
        }

        let popped: Vec<f32> = std::iter::from_fn(|| heap.pop()).map(|e| e.score).collect();
        assert_eq!(popped, vec![1.0, 2.0, 3.0]);
    }
}