1use std::cmp::Ordering;
11use std::collections::BinaryHeap;
12use std::sync::Arc;
13
14use log::{debug, trace};
15
16use crate::DocId;
17use crate::structures::BlockSparsePostingList;
18
/// A posting-list iterator that can score the current document and expose
/// score upper bounds for Block-Max-WAND style pruning.
pub trait ScoringIterator {
    /// Current document id, or `u32::MAX` once the iterator is exhausted.
    fn doc(&self) -> DocId;

    /// Ordinal (sub-value index) of the current posting. Scorers without
    /// per-value ordinals use the default of 0.
    fn ordinal(&mut self) -> u16 {
        0
    }

    /// Moves to the next posting and returns the new `doc()`.
    fn advance(&mut self) -> DocId;

    /// Moves to the first posting with doc id >= `target` and returns it.
    fn seek(&mut self, target: DocId) -> DocId;

    /// True when the iterator has run past its last posting
    /// (`doc()` is the `u32::MAX` sentinel).
    fn is_exhausted(&self) -> bool {
        self.doc() == u32::MAX
    }

    /// Score contribution of the current posting.
    fn score(&self) -> f32;

    /// Global upper bound on `score()` over the whole posting list.
    fn max_score(&self) -> f32;

    /// Upper bound on `score()` within the current block only.
    fn current_block_max_score(&self) -> f32;

    /// Skips past the remainder of the current block. The default simply
    /// advances one posting; block-aware implementations override this.
    fn skip_to_next_block(&mut self) -> DocId {
        self.advance()
    }
}
58
/// Entry in the top-k heap. Ordering is reversed on score (see the `Ord`
/// impl) so that `BinaryHeap`'s max-heap behaves as a min-heap on score.
#[derive(Clone, Copy)]
pub struct HeapEntry {
    pub doc_id: DocId,
    pub score: f32,
    pub ordinal: u16,
}
66
67impl PartialEq for HeapEntry {
68 fn eq(&self, other: &Self) -> bool {
69 self.score == other.score && self.doc_id == other.doc_id
70 }
71}
72
73impl Eq for HeapEntry {}
74
75impl Ord for HeapEntry {
76 fn cmp(&self, other: &Self) -> Ordering {
77 other
79 .score
80 .partial_cmp(&self.score)
81 .unwrap_or(Ordering::Equal)
82 .then_with(|| self.doc_id.cmp(&other.doc_id))
83 }
84}
85
86impl PartialOrd for HeapEntry {
87 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
88 Some(self.cmp(other))
89 }
90}
91
/// Collects the top `k` highest-scoring `(doc, score, ordinal)` entries.
pub struct ScoreCollector {
    // Min-heap on score (via `HeapEntry`'s reversed `Ord`); holds at most k
    // entries, so `peek()` is the current k-th best score.
    heap: BinaryHeap<HeapEntry>,
    /// Requested result count.
    pub k: usize,
}
102
103impl ScoreCollector {
104 pub fn new(k: usize) -> Self {
106 let capacity = k.saturating_add(1).min(1_000_000);
108 Self {
109 heap: BinaryHeap::with_capacity(capacity),
110 k,
111 }
112 }
113
114 #[inline]
116 pub fn threshold(&self) -> f32 {
117 if self.heap.len() >= self.k {
118 self.heap.peek().map(|e| e.score).unwrap_or(0.0)
119 } else {
120 0.0
121 }
122 }
123
124 #[inline]
127 pub fn insert(&mut self, doc_id: DocId, score: f32) -> bool {
128 self.insert_with_ordinal(doc_id, score, 0)
129 }
130
131 #[inline]
134 pub fn insert_with_ordinal(&mut self, doc_id: DocId, score: f32, ordinal: u16) -> bool {
135 if self.heap.len() < self.k {
136 self.heap.push(HeapEntry {
137 doc_id,
138 score,
139 ordinal,
140 });
141 true
142 } else if score > self.threshold() {
143 self.heap.push(HeapEntry {
144 doc_id,
145 score,
146 ordinal,
147 });
148 self.heap.pop(); true
150 } else {
151 false
152 }
153 }
154
155 #[inline]
157 pub fn would_enter(&self, score: f32) -> bool {
158 self.heap.len() < self.k || score > self.threshold()
159 }
160
161 #[inline]
163 pub fn len(&self) -> usize {
164 self.heap.len()
165 }
166
167 #[inline]
169 pub fn is_empty(&self) -> bool {
170 self.heap.is_empty()
171 }
172
173 pub fn into_sorted_results(self) -> Vec<(DocId, f32, u16)> {
175 let heap_vec = self.heap.into_vec();
176 let mut results: Vec<(DocId, f32, u16)> = Vec::with_capacity(heap_vec.len());
177 for e in heap_vec {
178 results.push((e.doc_id, e.score, e.ordinal));
179 }
180
181 results.sort_by(|a, b| {
183 b.1.partial_cmp(&a.1)
184 .unwrap_or(Ordering::Equal)
185 .then_with(|| a.0.cmp(&b.0))
186 });
187
188 results
189 }
190}
191
/// A scored search result: one document (and, for multi-valued fields, one
/// ordinal within that document) with its aggregated score.
#[derive(Debug, Clone, Copy)]
pub struct ScoredDoc {
    pub doc_id: DocId,
    pub score: f32,
    pub ordinal: u16,
}
200
/// Block-Max-WAND style top-k executor over a set of `ScoringIterator`s.
pub struct BlockMaxScoreExecutor<S: ScoringIterator> {
    // Scorers sorted by ascending max_score (see `with_heap_factor`).
    scorers: Vec<S>,
    // prefix_sums[i] = sum of max_score for scorers[0..=i]; used to split
    // scorers into non-essential (can never beat the threshold alone) and
    // essential sets.
    prefix_sums: Vec<f32>,
    // Top-k accumulator; its threshold drives all pruning decisions.
    collector: ScoreCollector,
    // Threshold relaxation factor in [0, 1]; values < 1 prune more
    // aggressively at the cost of approximate results.
    heap_factor: f32,
}
226
/// Legacy alias kept for callers that still refer to the WAND name.
pub type WandExecutor<S> = BlockMaxScoreExecutor<S>;
229
impl<S: ScoringIterator> BlockMaxScoreExecutor<S> {
    /// Creates an executor with exact pruning (`heap_factor = 1.0`).
    pub fn new(scorers: Vec<S>, k: usize) -> Self {
        Self::with_heap_factor(scorers, k, 1.0)
    }

    /// Creates an executor.
    ///
    /// `heap_factor` is clamped to [0, 1]; values below 1 relax the pruning
    /// threshold for faster, approximate top-k.
    pub fn with_heap_factor(mut scorers: Vec<S>, k: usize, heap_factor: f32) -> Self {
        // Sort scorers by ascending max_score so prefix_sums[i] bounds the
        // combined contribution of the i+1 lowest-impact terms. Those become
        // the "non-essential" set once the threshold rises above the prefix.
        scorers.sort_by(|a, b| {
            a.max_score()
                .partial_cmp(&b.max_score())
                .unwrap_or(Ordering::Equal)
        });

        let mut prefix_sums = Vec::with_capacity(scorers.len());
        let mut cumsum = 0.0f32;
        for s in &scorers {
            cumsum += s.max_score();
            prefix_sums.push(cumsum);
        }

        debug!(
            "Creating BlockMaxScoreExecutor: num_scorers={}, k={}, total_upper={:.4}, heap_factor={:.2}",
            scorers.len(),
            k,
            cumsum,
            heap_factor
        );

        Self {
            scorers,
            prefix_sums,
            collector: ScoreCollector::new(k),
            heap_factor: heap_factor.clamp(0.0, 1.0),
        }
    }

    /// First scorer index whose cumulative max-score prefix exceeds the
    /// (relaxed) threshold. Scorers before this index are "non-essential":
    /// even all of them together cannot lift a document past the threshold.
    #[inline]
    fn find_partition(&self) -> usize {
        let threshold = self.collector.threshold() * self.heap_factor;
        self.prefix_sums.partition_point(|&sum| sum <= threshold)
    }

    /// Runs the Block-Max-WAND loop to completion and returns the top-k
    /// documents sorted by descending score.
    pub fn execute(mut self) -> Vec<ScoredDoc> {
        if self.scorers.is_empty() {
            debug!("BlockMaxScoreExecutor: no scorers, returning empty results");
            return Vec::new();
        }

        let n = self.scorers.len();
        // Statistics for the completion debug log below.
        let mut docs_scored = 0u64;
        let mut docs_skipped = 0u64;
        let mut blocks_skipped = 0u64;
        let mut conjunction_skipped = 0u64;

        // Reused per-candidate buffer of (ordinal, score) pairs.
        let mut ordinal_scores: Vec<(u16, f32)> = Vec::with_capacity(n * 2);

        loop {
            let partition = self.find_partition();

            // All terms non-essential: no remaining doc can beat the
            // threshold, so we are done.
            if partition >= n {
                debug!("BlockMaxScore: all terms non-essential, early termination");
                break;
            }

            // Candidate = smallest current doc id among essential scorers.
            let mut min_doc = u32::MAX;
            for i in partition..n {
                let doc = self.scorers[i].doc();
                if doc < min_doc {
                    min_doc = doc;
                }
            }

            // All essential scorers exhausted.
            if min_doc == u32::MAX {
                break;
            }

            // Best-case contribution of every non-essential term.
            let non_essential_upper = if partition > 0 {
                self.prefix_sums[partition - 1]
            } else {
                0.0
            };
            let adjusted_threshold = self.collector.threshold() * self.heap_factor;

            // Pruning step 1 (term-level): sum the global max_score of the
            // essential terms actually present on min_doc. If even that upper
            // bound cannot beat the threshold, advance past min_doc cheaply.
            if self.collector.len() >= self.collector.k {
                let present_upper: f32 = (partition..n)
                    .filter(|&i| self.scorers[i].doc() == min_doc)
                    .map(|i| self.scorers[i].max_score())
                    .sum();

                if present_upper + non_essential_upper <= adjusted_threshold {
                    for i in partition..n {
                        if self.scorers[i].doc() == min_doc {
                            self.scorers[i].advance();
                        }
                    }
                    conjunction_skipped += 1;
                    continue;
                }
            }

            // Pruning step 2 (block-level): tighter bound using per-block
            // maxima; on failure, skip whole blocks rather than one posting.
            if self.collector.len() >= self.collector.k {
                let block_max_sum: f32 = (partition..n)
                    .filter(|&i| self.scorers[i].doc() == min_doc)
                    .map(|i| self.scorers[i].current_block_max_score())
                    .sum();

                if block_max_sum + non_essential_upper <= adjusted_threshold {
                    for i in partition..n {
                        if self.scorers[i].doc() == min_doc {
                            self.scorers[i].skip_to_next_block();
                        }
                    }
                    blocks_skipped += 1;
                    continue;
                }
            }

            ordinal_scores.clear();

            // Score all essential postings on min_doc (a scorer may emit
            // several postings for the same doc with different ordinals),
            // advancing each scorer past min_doc as we go.
            for i in partition..n {
                if self.scorers[i].doc() == min_doc {
                    while self.scorers[i].doc() == min_doc {
                        ordinal_scores.push((self.scorers[i].ordinal(), self.scorers[i].score()));
                        self.scorers[i].advance();
                    }
                }
            }

            let essential_total: f32 = ordinal_scores.iter().map(|(_, s)| *s).sum();

            // Pruning step 3: exact essential score plus non-essential upper
            // bound still cannot enter the top-k -> skip the expensive seeks.
            if self.collector.len() >= self.collector.k
                && essential_total + non_essential_upper <= adjusted_threshold
            {
                docs_skipped += 1;
                continue;
            }

            // Seek every non-essential scorer to min_doc and add its postings.
            for i in 0..partition {
                let doc = self.scorers[i].seek(min_doc);
                if doc == min_doc {
                    while self.scorers[i].doc() == min_doc {
                        ordinal_scores.push((self.scorers[i].ordinal(), self.scorers[i].score()));
                        self.scorers[i].advance();
                    }
                }
            }

            // Group contributions by ordinal and insert one candidate per
            // (doc, ordinal) pair.
            ordinal_scores.sort_unstable_by_key(|(ord, _)| *ord);
            let mut j = 0;
            while j < ordinal_scores.len() {
                let current_ord = ordinal_scores[j].0;
                let mut score = 0.0f32;
                while j < ordinal_scores.len() && ordinal_scores[j].0 == current_ord {
                    score += ordinal_scores[j].1;
                    j += 1;
                }

                trace!(
                    "Doc {}: ordinal={}, score={:.4}, threshold={:.4}",
                    min_doc, current_ord, score, adjusted_threshold
                );

                if self
                    .collector
                    .insert_with_ordinal(min_doc, score, current_ord)
                {
                    docs_scored += 1;
                } else {
                    docs_skipped += 1;
                }
            }
        }

        let results: Vec<ScoredDoc> = self
            .collector
            .into_sorted_results()
            .into_iter()
            .map(|(doc_id, score, ordinal)| ScoredDoc {
                doc_id,
                score,
                ordinal,
            })
            .collect();

        debug!(
            "BlockMaxScoreExecutor completed: scored={}, skipped={}, blocks_skipped={}, conjunction_skipped={}, returned={}, top_score={:.4}",
            docs_scored,
            docs_skipped,
            blocks_skipped,
            conjunction_skipped,
            results.len(),
            results.first().map(|r| r.score).unwrap_or(0.0)
        );

        results
    }
}
463
/// BM25 scorer for a single text term, backed by a block posting list.
pub struct TextTermScorer {
    iter: crate::structures::BlockPostingIterator<'static>,
    // Inverse document frequency of the term.
    idf: f32,
    // Average field length used by the BM25 length normalization.
    avg_field_len: f32,
    // Global upper bound on this term's BM25 score (from the list's max tf).
    max_score: f32,
}
478
impl TextTermScorer {
    /// Builds a BM25 scorer over an owned posting list.
    ///
    /// The global score upper bound comes from the list's maximum term
    /// frequency (clamped to at least 1) combined with the term's IDF.
    pub fn new(
        posting_list: crate::structures::BlockPostingList,
        idf: f32,
        avg_field_len: f32,
    ) -> Self {
        let max_tf = posting_list.max_tf() as f32;
        let doc_count = posting_list.doc_count();
        let max_score = super::bm25_upper_bound(max_tf.max(1.0), idf);

        debug!(
            "Created TextTermScorer: doc_count={}, max_tf={:.0}, idf={:.4}, avg_field_len={:.2}, max_score={:.4}",
            doc_count, max_tf, idf, avg_field_len, max_score
        );

        Self {
            iter: posting_list.into_iterator(),
            idf,
            avg_field_len,
            max_score,
        }
    }
}
504
impl ScoringIterator for TextTermScorer {
    #[inline]
    fn doc(&self) -> DocId {
        self.iter.doc()
    }

    #[inline]
    fn advance(&mut self) -> DocId {
        self.iter.advance()
    }

    #[inline]
    fn seek(&mut self, target: DocId) -> DocId {
        self.iter.seek(target)
    }

    /// BM25 score of the current posting.
    #[inline]
    fn score(&self) -> f32 {
        let tf = self.iter.term_freq() as f32;
        // NOTE(review): `tf` is passed both as the term frequency and as the
        // third argument (which, given `avg_field_len` follows, looks like a
        // field-length slot) — confirm `bm25_score`'s signature intends this.
        super::bm25_score(tf, self.idf, tf, self.avg_field_len)
    }

    #[inline]
    fn max_score(&self) -> f32 {
        self.max_score
    }

    /// Upper bound within the current block, from the block's max tf.
    #[inline]
    fn current_block_max_score(&self) -> f32 {
        let block_max_tf = self.iter.current_block_max_tf() as f32;
        super::bm25_upper_bound(block_max_tf.max(1.0), self.idf)
    }

    #[inline]
    fn skip_to_next_block(&mut self) -> DocId {
        self.iter.skip_to_next_block()
    }
}
545
/// Scorer for one dimension of a sparse (e.g. learned-sparse) query vector:
/// score = query_weight * stored weight.
pub struct SparseTermScorer<'a> {
    iter: crate::structures::BlockSparsePostingIterator<'a>,
    // Signed query weight applied to each stored weight.
    query_weight: f32,
    // |query_weight|, used for upper-bound arithmetic.
    abs_query_weight: f32,
    // Global upper bound: |query_weight| * global max stored weight.
    max_score: f32,
}
559
560impl<'a> SparseTermScorer<'a> {
561 pub fn new(posting_list: &'a BlockSparsePostingList, query_weight: f32) -> Self {
566 let abs_qw = query_weight.abs();
569 let max_score = abs_qw * posting_list.global_max_weight();
570 Self {
571 iter: posting_list.iterator(),
572 query_weight,
573 abs_query_weight: abs_qw,
574 max_score,
575 }
576 }
577
578 pub fn from_arc(posting_list: &'a Arc<BlockSparsePostingList>, query_weight: f32) -> Self {
580 Self::new(posting_list.as_ref(), query_weight)
581 }
582}
583
impl ScoringIterator for SparseTermScorer<'_> {
    #[inline]
    fn doc(&self) -> DocId {
        self.iter.doc()
    }

    #[inline]
    fn ordinal(&mut self) -> u16 {
        self.iter.ordinal()
    }

    #[inline]
    fn advance(&mut self) -> DocId {
        self.iter.advance()
    }

    #[inline]
    fn seek(&mut self, target: DocId) -> DocId {
        self.iter.seek(target)
    }

    /// Dot-product contribution: query weight times the stored weight.
    #[inline]
    fn score(&self) -> f32 {
        self.query_weight * self.iter.weight()
    }

    #[inline]
    fn max_score(&self) -> f32 {
        self.max_score
    }

    /// Block-level upper bound, scaled by |query_weight| by the iterator.
    #[inline]
    fn current_block_max_score(&self) -> f32 {
        self.iter
            .current_block_max_contribution(self.abs_query_weight)
    }

    #[inline]
    fn skip_to_next_block(&mut self) -> DocId {
        self.iter.skip_to_next_block()
    }
}
627
/// Block-max pruning ("BMP") executor: visits superblocks in descending order
/// of their score upper bound and terminates once the remaining upper bound
/// cannot beat the current top-k threshold.
pub struct BmpExecutor<'a> {
    sparse_index: &'a crate::segment::SparseIndex,
    // (dimension id, query weight) pairs of the sparse query vector.
    query_terms: Vec<(u32, f32)>,
    // Number of results to return.
    k: usize,
    // Threshold relaxation factor in [0, 1]; lower prunes more aggressively.
    heap_factor: f32,
}
649
/// Number of consecutive posting blocks grouped into one BMP work unit.
const BMP_SUPERBLOCK_SIZE: usize = 8;

/// Priority-queue entry describing one superblock of one query term.
struct BmpBlockEntry {
    /// Upper bound on this superblock's total score contribution.
    contribution: f32,
    /// Index of the owning term within `BmpExecutor::query_terms`.
    term_idx: usize,
    /// First block index of the superblock (relative to the term's blocks).
    block_start: usize,
    /// Number of blocks in the superblock.
    block_count: usize,
}
665
666impl PartialEq for BmpBlockEntry {
667 fn eq(&self, other: &Self) -> bool {
668 self.contribution == other.contribution
669 }
670}
671
672impl Eq for BmpBlockEntry {}
673
674impl Ord for BmpBlockEntry {
675 fn cmp(&self, other: &Self) -> Ordering {
676 self.contribution
678 .partial_cmp(&other.contribution)
679 .unwrap_or(Ordering::Equal)
680 }
681}
682
683impl PartialOrd for BmpBlockEntry {
684 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
685 Some(self.cmp(other))
686 }
687}
688
impl<'a> BmpExecutor<'a> {
    /// Creates a BMP executor over `sparse_index` for the given query vector.
    /// `heap_factor` is clamped to [0, 1].
    pub fn new(
        sparse_index: &'a crate::segment::SparseIndex,
        query_terms: Vec<(u32, f32)>,
        k: usize,
        heap_factor: f32,
    ) -> Self {
        Self {
            sparse_index,
            query_terms,
            k,
            heap_factor: heap_factor.clamp(0.0, 1.0),
        }
    }

    /// Runs BMP retrieval: builds a priority queue of superblocks ordered by
    /// upper-bound contribution, processes them best-first, and stops once
    /// the total remaining upper bound cannot beat the top-k threshold.
    pub async fn execute(self) -> crate::Result<Vec<ScoredDoc>> {
        use rustc_hash::FxHashMap;

        if self.query_terms.is_empty() {
            return Ok(Vec::new());
        }

        let num_terms = self.query_terms.len();
        let si = self.sparse_index;

        // Max-heap of superblocks by upper-bound contribution.
        let mut block_queue: BinaryHeap<BmpBlockEntry> = BinaryHeap::new();
        // Per-term sum of not-yet-processed superblock upper bounds.
        let mut remaining_max: Vec<f32> = Vec::with_capacity(num_terms);
        // Doc-id range covered by any query term (for the flat accumulator).
        let mut global_min_doc = u32::MAX;
        let mut global_max_doc = 0u32;

        // Phase 1: walk each term's skip entries, grouping blocks into
        // superblocks and computing each superblock's upper bound.
        for (term_idx, &(dim_id, qw)) in self.query_terms.iter().enumerate() {
            let mut term_remaining = 0.0f32;

            let abs_qw = qw.abs();
            if let Some((skip_start, skip_count, _global_max)) = si.get_skip_range(dim_id) {
                let mut sb_start = 0;
                while sb_start < skip_count {
                    let sb_count = (skip_count - sb_start).min(BMP_SUPERBLOCK_SIZE);
                    let mut sb_contribution = 0.0f32;
                    for j in 0..sb_count {
                        let skip = si.read_skip_entry(skip_start + sb_start + j);
                        sb_contribution += abs_qw * skip.max_weight;
                        global_min_doc = global_min_doc.min(skip.first_doc);
                        global_max_doc = global_max_doc.max(skip.last_doc);
                    }
                    term_remaining += sb_contribution;
                    block_queue.push(BmpBlockEntry {
                        contribution: sb_contribution,
                        term_idx,
                        block_start: sb_start,
                        block_count: sb_count,
                    });
                    sb_start += sb_count;
                }
            }
            remaining_max.push(term_remaining);
        }

        // Choose the accumulator strategy: a flat per-doc array when the doc
        // id range is small enough, otherwise a hash map keyed by
        // (doc_id << 16 | ordinal).
        let doc_range = if global_max_doc >= global_min_doc {
            (global_max_doc - global_min_doc + 1) as usize
        } else {
            0
        };
        let use_flat = doc_range > 0 && doc_range <= 256 * 1024;
        let mut flat_scores: Vec<f32> = if use_flat {
            vec![0.0; doc_range]
        } else {
            Vec::new()
        };
        // Doc ids whose flat slot became non-zero (so we can harvest them
        // without scanning the whole array).
        let mut dirty: Vec<u32> = if use_flat {
            Vec::with_capacity(4096)
        } else {
            Vec::new()
        };
        // Accumulators for ordinal != 0 (and for everything when !use_flat).
        let mut multi_ord_accumulators: FxHashMap<u64, f32> = FxHashMap::default();

        let mut total_remaining: f32 = remaining_max.iter().sum();
        let mut blocks_processed = 0u64;
        let mut blocks_skipped = 0u64;

        let mut top_k = ScoreCollector::new(self.k);

        // Decode buffers reused across blocks.
        let mut doc_ids_buf: Vec<u32> = Vec::with_capacity(128);
        let mut weights_buf: Vec<f32> = Vec::with_capacity(128);
        let mut ordinals_buf: Vec<u16> = Vec::with_capacity(128);

        // Phase 2: process superblocks best-first.
        while let Some(entry) = block_queue.pop() {
            // This superblock is now being consumed; remove its bound from
            // the remaining-work estimate.
            remaining_max[entry.term_idx] -= entry.contribution;
            total_remaining -= entry.contribution;

            // Early termination: nothing left in the queue can beat the
            // current (relaxed) threshold.
            let adjusted_threshold = top_k.threshold() * self.heap_factor;
            if top_k.len() >= self.k && total_remaining <= adjusted_threshold {
                blocks_skipped += block_queue.len() as u64;
                debug!(
                    "BMP early termination after {} blocks: remaining={:.4} <= threshold={:.4}",
                    blocks_processed, total_remaining, adjusted_threshold
                );
                break;
            }

            let dim_id = self.query_terms[entry.term_idx].0;
            let qw = self.query_terms[entry.term_idx].1;
            let abs_qw = qw.abs();

            // Load all blocks of this superblock in one I/O call.
            let sb_blocks = si
                .get_blocks_range(dim_id, entry.block_start, entry.block_count)
                .await?;

            let skip_start_opt = si.get_skip_range(dim_id).map(|(s, _, _)| s);
            let adjusted_threshold2 = top_k.threshold() * self.heap_factor;

            for (blk_offset, block) in sb_blocks.into_iter().enumerate() {
                let blk_idx = entry.block_start + blk_offset;

                // Per-block pruning inside the superblock, using the block's
                // own (tighter) max weight.
                if top_k.len() >= self.k
                    && let Some(skip_start) = skip_start_opt
                {
                    let skip = si.read_skip_entry(skip_start + blk_idx);
                    let blk_contrib = abs_qw * skip.max_weight;
                    if blk_contrib + total_remaining <= adjusted_threshold2 {
                        blocks_skipped += 1;
                        continue;
                    }
                }

                block.decode_doc_ids_into(&mut doc_ids_buf);
                block.decode_scored_weights_into(qw, &mut weights_buf);
                let count = block.header.count as usize;

                if block.header.ordinal_bits == 0 && use_flat {
                    // Fast path: single-ordinal block accumulated in the flat
                    // array; no ordinal decoding needed.
                    for i in 0..count {
                        let doc_id = doc_ids_buf[i];
                        let off = (doc_id - global_min_doc) as usize;
                        // NOTE(review): a slot that returns to exactly 0.0
                        // (possible with negative query weights) would be
                        // pushed to `dirty` again — confirm weights here are
                        // non-negative, or deduplicate at harvest time.
                        if flat_scores[off] == 0.0 {
                            dirty.push(doc_id);
                        }
                        flat_scores[off] += weights_buf[i];
                        top_k.insert_with_ordinal(doc_id, flat_scores[off], 0);
                    }
                } else {
                    block.decode_ordinals_into(&mut ordinals_buf);
                    if use_flat {
                        // Mixed path: ordinal 0 goes to the flat array, other
                        // ordinals go to the hash accumulators.
                        for i in 0..count {
                            let doc_id = doc_ids_buf[i];
                            let ordinal = ordinals_buf[i];
                            let score_contribution = weights_buf[i];

                            if ordinal == 0 {
                                let off = (doc_id - global_min_doc) as usize;
                                if flat_scores[off] == 0.0 {
                                    dirty.push(doc_id);
                                }
                                flat_scores[off] += score_contribution;
                                top_k.insert_with_ordinal(doc_id, flat_scores[off], 0);
                            } else {
                                let key = (doc_id as u64) << 16 | ordinal as u64;
                                let acc = multi_ord_accumulators.entry(key).or_insert(0.0);
                                *acc += score_contribution;
                                top_k.insert_with_ordinal(doc_id, *acc, ordinal);
                            }
                        }
                    } else {
                        // Hash-only path for wide doc-id ranges.
                        for i in 0..count {
                            let key = (doc_ids_buf[i] as u64) << 16 | ordinals_buf[i] as u64;
                            let acc = multi_ord_accumulators.entry(key).or_insert(0.0);
                            *acc += weights_buf[i];
                            top_k.insert_with_ordinal(doc_ids_buf[i], *acc, ordinals_buf[i]);
                        }
                    }
                }

                blocks_processed += 1;
            }
        }

        // Phase 3: harvest final scores from the accumulators. `top_k` was
        // only used for its threshold; the definitive scores live here.
        let mut scored: Vec<ScoredDoc> = Vec::new();

        let num_accumulators = if use_flat {
            scored.reserve(dirty.len() + multi_ord_accumulators.len());
            for &doc_id in &dirty {
                let off = (doc_id - global_min_doc) as usize;
                let score = flat_scores[off];
                if score > 0.0 {
                    scored.push(ScoredDoc {
                        doc_id,
                        score,
                        ordinal: 0,
                    });
                }
            }
            dirty.len() + multi_ord_accumulators.len()
        } else {
            multi_ord_accumulators.len()
        };

        scored.extend(
            multi_ord_accumulators
                .into_iter()
                .map(|(key, score)| ScoredDoc {
                    doc_id: (key >> 16) as DocId,
                    score,
                    ordinal: (key & 0xFFFF) as u16,
                }),
        );

        scored.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
        scored.truncate(self.k);
        let results = scored;

        debug!(
            "BmpExecutor completed: blocks_processed={}, blocks_skipped={}, accumulators={}, flat={}, returned={}, top_score={:.4}",
            blocks_processed,
            blocks_skipped,
            num_accumulators,
            use_flat,
            results.len(),
            results.first().map(|r| r.score).unwrap_or(0.0)
        );

        Ok(results)
    }
}
939
/// Block-Max-WAND executor over lazily loaded sparse posting blocks: block
/// payloads are decoded on demand, with skip entries providing doc-id bounds
/// and block maxima before any payload I/O happens.
pub struct LazyBlockMaxScoreExecutor<'a> {
    sparse_index: &'a crate::segment::SparseIndex,
    // Cursors sorted by ascending max_score (see `new`).
    cursors: Vec<LazyTermCursor>,
    // prefix_sums[i] = sum of max_score for cursors[0..=i].
    prefix_sums: Vec<f32>,
    collector: ScoreCollector,
    // Threshold relaxation factor in [0, 1].
    heap_factor: f32,
}
958
/// Cursor over one query term's posting blocks. Block payloads are decoded
/// lazily; until a block is loaded, only its skip entry (first/last doc,
/// max weight) is consulted.
struct LazyTermCursor {
    // Signed query weight; weights are pre-scaled by it on decode.
    query_weight: f32,
    // |query_weight|, for upper-bound arithmetic.
    abs_query_weight: f32,
    // Global score upper bound: |query_weight| * global max weight.
    max_score: f32,
    // First skip-entry index for this term.
    skip_start: usize,
    // Number of blocks (skip entries) for this term.
    skip_count: usize,
    // Offset passed to `SparseIndex::load_block_direct` for this term's data.
    block_data_offset: u32,
    // Current block index in [0, skip_count).
    block_idx: usize,
    // Decoded buffers for the current block (valid when `block_loaded`).
    doc_ids: Vec<u32>,
    ordinals: Vec<u16>,
    weights: Vec<f32>,
    // Position within the decoded block.
    pos: usize,
    // Whether the current block's payload has been decoded.
    block_loaded: bool,
    // Set once the cursor has run past its last block.
    exhausted: bool,
}
983
impl LazyTermCursor {
    /// Creates a cursor positioned before the first block; no I/O happens yet.
    fn new(
        query_weight: f32,
        skip_start: usize,
        skip_count: usize,
        global_max_weight: f32,
        block_data_offset: u32,
    ) -> Self {
        let exhausted = skip_count == 0;
        let abs_qw = query_weight.abs();
        Self {
            query_weight,
            abs_query_weight: abs_qw,
            max_score: abs_qw * global_max_weight,
            skip_start,
            skip_count,
            block_data_offset,
            block_idx: 0,
            doc_ids: Vec::with_capacity(128),
            ordinals: Vec::with_capacity(128),
            weights: Vec::with_capacity(128),
            pos: 0,
            block_loaded: false,
            exhausted,
        }
    }

    /// Decodes the current block's payload if not already loaded.
    /// Returns `Ok(true)` when a block is available, `Ok(false)` once the
    /// cursor is exhausted.
    async fn ensure_block_loaded(
        &mut self,
        sparse_index: &crate::segment::SparseIndex,
    ) -> crate::Result<bool> {
        if self.exhausted || self.block_loaded {
            return Ok(!self.exhausted);
        }
        match sparse_index
            .load_block_direct(self.skip_start, self.block_data_offset, self.block_idx)
            .await?
        {
            Some(block) => {
                block.decode_doc_ids_into(&mut self.doc_ids);
                block.decode_ordinals_into(&mut self.ordinals);
                // Weights come back pre-multiplied by query_weight.
                block.decode_scored_weights_into(self.query_weight, &mut self.weights);
                self.pos = 0;
                self.block_loaded = true;
                Ok(true)
            }
            None => {
                self.exhausted = true;
                Ok(false)
            }
        }
    }

    /// Current doc id without forcing a block load: falls back to the skip
    /// entry's `first_doc` (a lower bound that is exact when `pos == 0`).
    #[inline]
    fn doc_with_si(&self, si: &crate::segment::SparseIndex) -> DocId {
        if self.exhausted {
            return u32::MAX;
        }
        if !self.block_loaded {
            return si
                .read_skip_entry(self.skip_start + self.block_idx)
                .first_doc;
        }
        self.doc_ids.get(self.pos).copied().unwrap_or(u32::MAX)
    }

    /// Current doc id from the decoded block only.
    ///
    /// NOTE(review): returns `u32::MAX` whenever the current block is not
    /// loaded — even if more blocks remain — so callers must use
    /// `doc_with_si` (as the executor does) to distinguish "unloaded" from
    /// "exhausted".
    #[inline]
    fn doc(&self) -> DocId {
        if self.exhausted {
            return u32::MAX;
        }
        if self.block_loaded {
            return self.doc_ids.get(self.pos).copied().unwrap_or(u32::MAX);
        }
        u32::MAX
    }

    /// Ordinal of the current posting (0 when no block is loaded).
    #[inline]
    fn ordinal(&self) -> u16 {
        if !self.block_loaded {
            return 0;
        }
        self.ordinals.get(self.pos).copied().unwrap_or(0)
    }

    /// Pre-scaled score of the current posting (0.0 when no block is loaded).
    #[inline]
    fn score(&self) -> f32 {
        if !self.block_loaded {
            return 0.0;
        }
        self.weights.get(self.pos).copied().unwrap_or(0.0)
    }

    /// Upper bound on this term's score within the current block, straight
    /// from the skip entry (no block load required).
    #[inline]
    fn current_block_max_score(&self, si: &crate::segment::SparseIndex) -> f32 {
        if self.exhausted || self.block_idx >= self.skip_count {
            return 0.0;
        }
        self.abs_query_weight
            * si.read_skip_entry(self.skip_start + self.block_idx)
                .max_weight
    }

    /// Advances one posting, moving to the next (unloaded) block at a block
    /// boundary. The returned doc id is `u32::MAX` while the next block is
    /// unloaded; callers re-check via `doc_with_si`.
    async fn advance(
        &mut self,
        sparse_index: &crate::segment::SparseIndex,
    ) -> crate::Result<DocId> {
        if self.exhausted {
            return Ok(u32::MAX);
        }
        self.ensure_block_loaded(sparse_index).await?;
        if self.exhausted {
            return Ok(u32::MAX);
        }
        self.pos += 1;
        if self.pos >= self.doc_ids.len() {
            // Crossed a block boundary; defer loading the next block.
            self.block_idx += 1;
            self.block_loaded = false;
            if self.block_idx >= self.skip_count {
                self.exhausted = true;
                return Ok(u32::MAX);
            }
        }
        Ok(self.doc())
    }

    /// Seeks to the first posting with doc id >= `target`.
    async fn seek(
        &mut self,
        sparse_index: &crate::segment::SparseIndex,
        target: DocId,
    ) -> crate::Result<DocId> {
        if self.exhausted {
            return Ok(u32::MAX);
        }

        // Fast path: the target lies within the already-decoded block.
        if self.block_loaded
            && let Some(&last) = self.doc_ids.last()
        {
            if last >= target && self.doc_ids[self.pos] < target {
                // `last >= target` guarantees a match inside this block.
                let remaining = &self.doc_ids[self.pos..];
                let offset = crate::structures::simd::find_first_ge_u32(remaining, target);
                self.pos += offset;
                if self.pos >= self.doc_ids.len() {
                    self.block_idx += 1;
                    self.block_loaded = false;
                    if self.block_idx >= self.skip_count {
                        self.exhausted = true;
                        return Ok(u32::MAX);
                    }
                }
                return Ok(self.doc());
            }
            if self.doc_ids[self.pos] >= target {
                return Ok(self.doc());
            }
        }

        // Slow path: binary search the skip entries for the first block whose
        // last_doc can contain the target.
        let mut lo = self.block_idx;
        let mut hi = self.skip_count;
        while lo < hi {
            let mid = lo + (hi - lo) / 2;
            if sparse_index.read_skip_entry(self.skip_start + mid).last_doc < target {
                lo = mid + 1;
            } else {
                hi = mid;
            }
        }
        if lo >= self.skip_count {
            self.exhausted = true;
            return Ok(u32::MAX);
        }
        if lo != self.block_idx || !self.block_loaded {
            self.block_idx = lo;
            self.block_loaded = false;
        }
        self.ensure_block_loaded(sparse_index).await?;
        if self.exhausted {
            return Ok(u32::MAX);
        }
        self.pos = crate::structures::simd::find_first_ge_u32(&self.doc_ids, target);
        if self.pos >= self.doc_ids.len() {
            // Target fell in the gap after this block; move on and load next.
            self.block_idx += 1;
            self.block_loaded = false;
            if self.block_idx >= self.skip_count {
                self.exhausted = true;
                return Ok(u32::MAX);
            }
            self.ensure_block_loaded(sparse_index).await?;
        }
        Ok(self.doc())
    }

    /// Skips the rest of the current block without loading the next one;
    /// returns the next block's `first_doc` from its skip entry.
    fn skip_to_next_block(&mut self, si: &crate::segment::SparseIndex) -> DocId {
        if self.exhausted {
            return u32::MAX;
        }
        self.block_idx += 1;
        self.block_loaded = false;
        if self.block_idx >= self.skip_count {
            self.exhausted = true;
            return u32::MAX;
        }
        si.read_skip_entry(self.skip_start + self.block_idx)
            .first_doc
    }
}
1203
impl<'a> LazyBlockMaxScoreExecutor<'a> {
    /// Creates the executor: builds one lazy cursor per query term that has
    /// postings, sorted by ascending score upper bound. `heap_factor` is
    /// clamped to [0, 1].
    pub fn new(
        sparse_index: &'a crate::segment::SparseIndex,
        query_terms: Vec<(u32, f32)>,
        k: usize,
        heap_factor: f32,
    ) -> Self {
        let mut cursors: Vec<LazyTermCursor> = query_terms
            .iter()
            .filter_map(|&(dim_id, qw)| {
                // Terms absent from the index are silently dropped.
                let (skip_start, skip_count, global_max, block_data_offset) =
                    sparse_index.get_skip_range_full(dim_id)?;
                Some(LazyTermCursor::new(
                    qw,
                    skip_start,
                    skip_count,
                    global_max,
                    block_data_offset,
                ))
            })
            .collect();

        // Ascending max_score so prefix_sums bounds the lowest-impact terms,
        // mirroring BlockMaxScoreExecutor.
        cursors.sort_by(|a, b| {
            a.max_score
                .partial_cmp(&b.max_score)
                .unwrap_or(Ordering::Equal)
        });

        let mut prefix_sums = Vec::with_capacity(cursors.len());
        let mut cumsum = 0.0f32;
        for c in &cursors {
            cumsum += c.max_score;
            prefix_sums.push(cumsum);
        }

        debug!(
            "Creating LazyBlockMaxScoreExecutor: num_terms={}, k={}, total_upper={:.4}, heap_factor={:.2}",
            cursors.len(),
            k,
            cumsum,
            heap_factor
        );

        Self {
            sparse_index,
            cursors,
            prefix_sums,
            collector: ScoreCollector::new(k),
            heap_factor: heap_factor.clamp(0.0, 1.0),
        }
    }

    /// First cursor index whose cumulative max-score prefix exceeds the
    /// (relaxed) threshold; cursors before it are non-essential.
    #[inline]
    fn find_partition(&self) -> usize {
        let threshold = self.collector.threshold() * self.heap_factor;
        self.prefix_sums.partition_point(|&sum| sum <= threshold)
    }

    /// Runs the lazy Block-Max-WAND loop and returns the top-k documents
    /// sorted by descending score.
    pub async fn execute(mut self) -> crate::Result<Vec<ScoredDoc>> {
        if self.cursors.is_empty() {
            return Ok(Vec::new());
        }

        let n = self.cursors.len();
        let si = self.sparse_index;

        // Prime every cursor with its first block so doc ids are exact.
        for cursor in &mut self.cursors {
            cursor.ensure_block_loaded(si).await?;
        }

        // Statistics for the completion debug log below.
        let mut docs_scored = 0u64;
        let mut docs_skipped = 0u64;
        let mut blocks_skipped = 0u64;
        let mut blocks_loaded = 0u64;
        let mut conjunction_skipped = 0u64;
        // Reused per-candidate buffer of (ordinal, score) pairs.
        let mut ordinal_scores: Vec<(u16, f32)> = Vec::with_capacity(n * 2);

        loop {
            let partition = self.find_partition();
            if partition >= n {
                // All terms non-essential: done.
                break;
            }

            // Candidate = smallest current doc among essential cursors
            // (skip-entry first_doc is used for unloaded blocks).
            let mut min_doc = u32::MAX;
            for i in partition..n {
                let doc = self.cursors[i].doc_with_si(si);
                if doc < min_doc {
                    min_doc = doc;
                }
            }
            if min_doc == u32::MAX {
                break;
            }

            let non_essential_upper = if partition > 0 {
                self.prefix_sums[partition - 1]
            } else {
                0.0
            };
            let adjusted_threshold = self.collector.threshold() * self.heap_factor;

            // Pruning step 1 (term-level, no block load needed): global
            // max_score of essential terms present on min_doc.
            if self.collector.len() >= self.collector.k {
                let present_upper: f32 = (partition..n)
                    .filter(|&i| self.cursors[i].doc_with_si(si) == min_doc)
                    .map(|i| self.cursors[i].max_score)
                    .sum();

                if present_upper + non_essential_upper <= adjusted_threshold {
                    for i in partition..n {
                        if self.cursors[i].doc_with_si(si) == min_doc {
                            self.cursors[i].ensure_block_loaded(si).await?;
                            self.cursors[i].advance(si).await?;
                            blocks_loaded += u64::from(self.cursors[i].block_loaded);
                        }
                    }
                    conjunction_skipped += 1;
                    continue;
                }
            }

            // Pruning step 2 (block-level, from skip entries only).
            if self.collector.len() >= self.collector.k {
                let block_max_sum: f32 = (partition..n)
                    .filter(|&i| self.cursors[i].doc_with_si(si) == min_doc)
                    .map(|i| self.cursors[i].current_block_max_score(si))
                    .sum();

                if block_max_sum + non_essential_upper <= adjusted_threshold {
                    for i in partition..n {
                        if self.cursors[i].doc_with_si(si) == min_doc {
                            self.cursors[i].skip_to_next_block(si);
                            self.cursors[i].ensure_block_loaded(si).await?;
                            blocks_loaded += 1;
                        }
                    }
                    blocks_skipped += 1;
                    continue;
                }
            }

            // Score all essential postings on min_doc, advancing past it.
            ordinal_scores.clear();
            for i in partition..n {
                if self.cursors[i].doc_with_si(si) == min_doc {
                    self.cursors[i].ensure_block_loaded(si).await?;
                    while self.cursors[i].doc_with_si(si) == min_doc {
                        ordinal_scores.push((self.cursors[i].ordinal(), self.cursors[i].score()));
                        self.cursors[i].advance(si).await?;
                    }
                }
            }

            // Pruning step 3: exact essential score plus non-essential upper
            // bound still cannot enter the top-k.
            let essential_total: f32 = ordinal_scores.iter().map(|(_, s)| *s).sum();
            if self.collector.len() >= self.collector.k
                && essential_total + non_essential_upper <= adjusted_threshold
            {
                docs_skipped += 1;
                continue;
            }

            // Seek non-essential cursors in descending max_score order,
            // stopping early once even the remaining prefix bound cannot lift
            // the running total past the threshold.
            let mut running_total = essential_total;
            for i in (0..partition).rev() {
                if self.collector.len() >= self.collector.k
                    && running_total + self.prefix_sums[i] <= adjusted_threshold
                {
                    break;
                }

                let doc = self.cursors[i].seek(si, min_doc).await?;
                if doc == min_doc {
                    while self.cursors[i].doc_with_si(si) == min_doc {
                        let s = self.cursors[i].score();
                        running_total += s;
                        ordinal_scores.push((self.cursors[i].ordinal(), s));
                        self.cursors[i].advance(si).await?;
                    }
                }
            }

            // Group contributions by ordinal; one candidate per (doc, ordinal).
            ordinal_scores.sort_unstable_by_key(|(ord, _)| *ord);
            let mut j = 0;
            while j < ordinal_scores.len() {
                let current_ord = ordinal_scores[j].0;
                let mut score = 0.0f32;
                while j < ordinal_scores.len() && ordinal_scores[j].0 == current_ord {
                    score += ordinal_scores[j].1;
                    j += 1;
                }
                if self
                    .collector
                    .insert_with_ordinal(min_doc, score, current_ord)
                {
                    docs_scored += 1;
                } else {
                    docs_skipped += 1;
                }
            }
        }

        let results: Vec<ScoredDoc> = self
            .collector
            .into_sorted_results()
            .into_iter()
            .map(|(doc_id, score, ordinal)| ScoredDoc {
                doc_id,
                score,
                ordinal,
            })
            .collect();

        debug!(
            "LazyBlockMaxScoreExecutor completed: scored={}, skipped={}, blocks_skipped={}, blocks_loaded={}, conjunction_skipped={}, returned={}, top_score={:.4}",
            docs_scored,
            docs_skipped,
            blocks_skipped,
            blocks_loaded,
            conjunction_skipped,
            results.len(),
            results.first().map(|r| r.score).unwrap_or(0.0)
        );

        Ok(results)
    }
}
1446
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_score_collector_basic() {
        let mut collector = ScoreCollector::new(3);

        // Fill to capacity; the threshold is the smallest retained score.
        for (doc, score) in [(1, 1.0), (2, 2.0), (3, 3.0)] {
            collector.insert(doc, score);
        }
        assert_eq!(collector.threshold(), 1.0);

        // A better candidate evicts the current minimum.
        collector.insert(4, 4.0);
        assert_eq!(collector.threshold(), 2.0);

        // Results come back sorted by descending score.
        let results = collector.into_sorted_results();
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].0, 4);
        assert_eq!(results[1].0, 3);
        assert_eq!(results[2].0, 2);
    }

    #[test]
    fn test_score_collector_threshold() {
        let mut collector = ScoreCollector::new(2);

        collector.insert(1, 5.0);
        collector.insert(2, 3.0);
        assert_eq!(collector.threshold(), 3.0);

        // At or below the threshold: rejected.
        assert!(!collector.would_enter(2.0));
        assert!(!collector.insert(3, 2.0));

        // Above the threshold: accepted, and the bar rises.
        assert!(collector.would_enter(4.0));
        assert!(collector.insert(4, 4.0));
        assert_eq!(collector.threshold(), 4.0);
    }

    #[test]
    fn test_heap_entry_ordering() {
        let mut heap = BinaryHeap::new();
        for (doc_id, score) in [(1, 3.0), (2, 1.0), (3, 2.0)] {
            heap.push(HeapEntry {
                doc_id,
                score,
                ordinal: 0,
            });
        }

        // Reversed ordering: the heap pops the lowest score first.
        let popped: Vec<f32> = std::iter::from_fn(|| heap.pop().map(|e| e.score)).collect();
        assert_eq!(popped, vec![1.0, 2.0, 3.0]);
    }
}