1use std::cmp::Ordering;
11use std::collections::BinaryHeap;
12use std::sync::Arc;
13
14use log::{debug, trace};
15
16use crate::DocId;
17use crate::structures::BlockSparsePostingList;
18
/// A document-at-a-time posting cursor that yields `(doc, score)` pairs in
/// non-decreasing doc-id order, plus the score upper bounds needed by
/// block-max pruning executors (WAND / Block-Max WAND style).
pub trait ScoringIterator {
    /// Current document id, or `u32::MAX` once the iterator is exhausted.
    fn doc(&self) -> DocId;

    /// Ordinal (sub-field / multi-value index) of the current posting.
    /// Defaults to 0 for scorers without ordinal information.
    fn ordinal(&mut self) -> u16 {
        0
    }

    /// Move to the next posting; returns the new doc id (`u32::MAX` at end).
    fn advance(&mut self) -> DocId;

    /// Move to the first posting with doc id >= `target`; returns the doc
    /// landed on (`u32::MAX` if none remains).
    fn seek(&mut self, target: DocId) -> DocId;

    /// True once `doc()` has reached the `u32::MAX` end sentinel.
    fn is_exhausted(&self) -> bool {
        self.doc() == u32::MAX
    }

    /// Score of the current posting.
    fn score(&self) -> f32;

    /// Global upper bound of `score()` over the entire posting list.
    fn max_score(&self) -> f32;

    /// Upper bound of `score()` within the current block only.
    fn current_block_max_score(&self) -> f32;

    /// Skip past the current block entirely. The default falls back to a
    /// single `advance()` for scorers without block structure.
    fn skip_to_next_block(&mut self) -> DocId {
        self.advance()
    }
}
58
/// One top-k candidate held in `ScoreCollector`'s heap.
#[derive(Clone, Copy)]
pub struct HeapEntry {
    pub doc_id: DocId,
    pub score: f32,
    // Sub-field / multi-value ordinal this score was accumulated for.
    pub ordinal: u16,
}
66
67impl PartialEq for HeapEntry {
68 fn eq(&self, other: &Self) -> bool {
69 self.score == other.score && self.doc_id == other.doc_id
70 }
71}
72
// NOTE(review): asserting `Eq` over an f32 field assumes scores are never
// NaN — confirm upstream scorers cannot produce NaN.
impl Eq for HeapEntry {}
74
75impl Ord for HeapEntry {
76 fn cmp(&self, other: &Self) -> Ordering {
77 other
79 .score
80 .partial_cmp(&self.score)
81 .unwrap_or(Ordering::Equal)
82 .then_with(|| self.doc_id.cmp(&other.doc_id))
83 }
84}
85
impl PartialOrd for HeapEntry {
    // Delegate to the total order defined by `Ord`, keeping the two
    // implementations consistent.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
91
/// Top-k collector backed by a score min-heap, with a cached entry threshold
/// so hot-path pruning checks avoid peeking the heap.
pub struct ScoreCollector {
    // Min-heap by score (via HeapEntry's reversed Ord): peek() is the worst
    // currently-kept entry.
    heap: BinaryHeap<HeapEntry>,
    /// Number of results to keep.
    pub k: usize,
    // k-th best score once the heap holds k entries, 0.0 before that.
    cached_threshold: f32,
}
105
106impl ScoreCollector {
107 pub fn new(k: usize) -> Self {
109 let capacity = k.saturating_add(1).min(1_000_000);
111 Self {
112 heap: BinaryHeap::with_capacity(capacity),
113 k,
114 cached_threshold: 0.0,
115 }
116 }
117
118 #[inline]
120 pub fn threshold(&self) -> f32 {
121 self.cached_threshold
122 }
123
124 #[inline]
126 fn update_threshold(&mut self) {
127 self.cached_threshold = if self.heap.len() >= self.k {
128 self.heap.peek().map(|e| e.score).unwrap_or(0.0)
129 } else {
130 0.0
131 };
132 }
133
134 #[inline]
137 pub fn insert(&mut self, doc_id: DocId, score: f32) -> bool {
138 self.insert_with_ordinal(doc_id, score, 0)
139 }
140
141 #[inline]
144 pub fn insert_with_ordinal(&mut self, doc_id: DocId, score: f32, ordinal: u16) -> bool {
145 if self.heap.len() < self.k {
146 self.heap.push(HeapEntry {
147 doc_id,
148 score,
149 ordinal,
150 });
151 self.update_threshold();
152 true
153 } else if score > self.cached_threshold {
154 self.heap.push(HeapEntry {
155 doc_id,
156 score,
157 ordinal,
158 });
159 self.heap.pop(); self.update_threshold();
161 true
162 } else {
163 false
164 }
165 }
166
167 #[inline]
169 pub fn would_enter(&self, score: f32) -> bool {
170 self.heap.len() < self.k || score > self.cached_threshold
171 }
172
173 #[inline]
175 pub fn len(&self) -> usize {
176 self.heap.len()
177 }
178
179 #[inline]
181 pub fn is_empty(&self) -> bool {
182 self.heap.is_empty()
183 }
184
185 pub fn into_sorted_results(self) -> Vec<(DocId, f32, u16)> {
187 let heap_vec = self.heap.into_vec();
188 let mut results: Vec<(DocId, f32, u16)> = Vec::with_capacity(heap_vec.len());
189 for e in heap_vec {
190 results.push((e.doc_id, e.score, e.ordinal));
191 }
192
193 results.sort_by(|a, b| {
195 b.1.partial_cmp(&a.1)
196 .unwrap_or(Ordering::Equal)
197 .then_with(|| a.0.cmp(&b.0))
198 });
199
200 results
201 }
202}
203
/// One final retrieval result.
#[derive(Debug, Clone, Copy)]
pub struct ScoredDoc {
    pub doc_id: DocId,
    pub score: f32,
    /// Sub-field / multi-value ordinal the score was accumulated for.
    pub ordinal: u16,
}
212
/// Block-Max WAND-style top-k executor over in-memory `ScoringIterator`s.
pub struct BlockMaxScoreExecutor<S: ScoringIterator> {
    // Sorted ascending by max_score(); a prefix of this list forms the
    // "non-essential" set for the current threshold.
    scorers: Vec<S>,
    // prefix_sums[i] = sum of max_score() over scorers[0..=i].
    prefix_sums: Vec<f32>,
    collector: ScoreCollector,
    // Threshold multiplier in [0, 1]; values below 1.0 prune more
    // aggressively (approximate results).
    heap_factor: f32,
}
238
/// Legacy name kept as an alias for existing call sites.
pub type WandExecutor<S> = BlockMaxScoreExecutor<S>;
241
impl<S: ScoringIterator> BlockMaxScoreExecutor<S> {
    /// Exact top-k executor (heap_factor = 1.0).
    pub fn new(scorers: Vec<S>, k: usize) -> Self {
        Self::with_heap_factor(scorers, k, 1.0)
    }

    /// Top-k executor with a pruning factor; `heap_factor` is clamped into
    /// [0, 1], and values below 1.0 shrink the effective threshold used in
    /// pruning checks (faster, approximate).
    pub fn with_heap_factor(mut scorers: Vec<S>, k: usize, heap_factor: f32) -> Self {
        // Sort ascending by global upper bound so the non-essential set is
        // always a prefix of the list.
        scorers.sort_by(|a, b| {
            a.max_score()
                .partial_cmp(&b.max_score())
                .unwrap_or(Ordering::Equal)
        });

        // prefix_sums[i] = cumulative max_score of scorers[0..=i].
        let mut prefix_sums = Vec::with_capacity(scorers.len());
        let mut cumsum = 0.0f32;
        for s in &scorers {
            cumsum += s.max_score();
            prefix_sums.push(cumsum);
        }

        debug!(
            "Creating BlockMaxScoreExecutor: num_scorers={}, k={}, total_upper={:.4}, heap_factor={:.2}",
            scorers.len(),
            k,
            cumsum,
            heap_factor
        );

        Self {
            scorers,
            prefix_sums,
            collector: ScoreCollector::new(k),
            heap_factor: heap_factor.clamp(0.0, 1.0),
        }
    }

    /// Index of the first "essential" scorer: every scorer below it has a
    /// prefix-summed upper bound <= the adjusted threshold, so those scorers
    /// alone can never lift a document into the top-k.
    #[inline]
    fn find_partition(&self) -> usize {
        let threshold = self.collector.threshold() * self.heap_factor;
        self.prefix_sums.partition_point(|&sum| sum <= threshold)
    }

    /// Run document-at-a-time block-max retrieval to completion and return
    /// the top-k results sorted by descending score.
    pub fn execute(mut self) -> Vec<ScoredDoc> {
        if self.scorers.is_empty() {
            debug!("BlockMaxScoreExecutor: no scorers, returning empty results");
            return Vec::new();
        }

        let n = self.scorers.len();
        // Telemetry counters for the completion debug line.
        let mut docs_scored = 0u64;
        let mut docs_skipped = 0u64;
        let mut blocks_skipped = 0u64;
        let mut conjunction_skipped = 0u64;

        // Reused (ordinal, score) scratch buffer for the current document.
        let mut ordinal_scores: Vec<(u16, f32)> = Vec::with_capacity(n * 2);

        loop {
            let partition = self.find_partition();

            // Every scorer became non-essential: no remaining doc can beat
            // the threshold.
            if partition >= n {
                debug!("BlockMaxScore: all terms non-essential, early termination");
                break;
            }

            // Find the smallest current doc among essential scorers, and for
            // the scorers positioned on it accumulate two upper bounds: the
            // sum of their global maxima (present_upper) and the sum of their
            // current-block maxima (block_max_sum).
            let mut min_doc = u32::MAX;
            let mut present_upper = 0.0f32;
            let mut block_max_sum = 0.0f32;
            for i in partition..n {
                let doc = self.scorers[i].doc();
                if doc < min_doc {
                    min_doc = doc;
                    present_upper = self.scorers[i].max_score();
                    block_max_sum = self.scorers[i].current_block_max_score();
                } else if doc == min_doc {
                    present_upper += self.scorers[i].max_score();
                    block_max_sum += self.scorers[i].current_block_max_score();
                }
            }

            // All essential scorers exhausted.
            if min_doc == u32::MAX {
                break;
            }

            // Upper bound contributed by the entire non-essential prefix.
            let non_essential_upper = if partition > 0 {
                self.prefix_sums[partition - 1]
            } else {
                0.0
            };
            let adjusted_threshold = self.collector.threshold() * self.heap_factor;

            // Cheapest check: even with full global maxima everywhere,
            // min_doc cannot be competitive — step past it.
            if self.collector.len() >= self.collector.k
                && present_upper + non_essential_upper <= adjusted_threshold
            {
                for i in partition..n {
                    if self.scorers[i].doc() == min_doc {
                        self.scorers[i].advance();
                    }
                }
                conjunction_skipped += 1;
                continue;
            }

            // Block-max check: the current blocks cannot be competitive —
            // jump whole blocks instead of single postings.
            if self.collector.len() >= self.collector.k
                && block_max_sum + non_essential_upper <= adjusted_threshold
            {
                for i in partition..n {
                    if self.scorers[i].doc() == min_doc {
                        self.scorers[i].skip_to_next_block();
                    }
                }
                blocks_skipped += 1;
                continue;
            }

            // Score min_doc: collect per-ordinal contributions from every
            // essential scorer positioned on it. A scorer may emit several
            // postings for one doc (one per ordinal), hence the inner loop.
            ordinal_scores.clear();

            for i in partition..n {
                if self.scorers[i].doc() == min_doc {
                    while self.scorers[i].doc() == min_doc {
                        ordinal_scores.push((self.scorers[i].ordinal(), self.scorers[i].score()));
                        self.scorers[i].advance();
                    }
                }
            }

            let essential_total: f32 = ordinal_scores.iter().map(|(_, s)| *s).sum();

            // Even granting the full non-essential upper bound, this doc
            // cannot enter the heap — skip seeking the non-essential scorers.
            if self.collector.len() >= self.collector.k
                && essential_total + non_essential_upper <= adjusted_threshold
            {
                docs_skipped += 1;
                continue;
            }

            // Bring non-essential scorers up to min_doc and collect their
            // actual contributions.
            for i in 0..partition {
                let doc = self.scorers[i].seek(min_doc);
                if doc == min_doc {
                    while self.scorers[i].doc() == min_doc {
                        ordinal_scores.push((self.scorers[i].ordinal(), self.scorers[i].score()));
                        self.scorers[i].advance();
                    }
                }
            }

            // Group contributions by ordinal; each ordinal becomes a
            // separate top-k candidate for this document.
            ordinal_scores.sort_unstable_by_key(|(ord, _)| *ord);
            let mut j = 0;
            while j < ordinal_scores.len() {
                let current_ord = ordinal_scores[j].0;
                let mut score = 0.0f32;
                while j < ordinal_scores.len() && ordinal_scores[j].0 == current_ord {
                    score += ordinal_scores[j].1;
                    j += 1;
                }

                trace!(
                    "Doc {}: ordinal={}, score={:.4}, threshold={:.4}",
                    min_doc, current_ord, score, adjusted_threshold
                );

                if self
                    .collector
                    .insert_with_ordinal(min_doc, score, current_ord)
                {
                    docs_scored += 1;
                } else {
                    docs_skipped += 1;
                }
            }
        }

        let results: Vec<ScoredDoc> = self
            .collector
            .into_sorted_results()
            .into_iter()
            .map(|(doc_id, score, ordinal)| ScoredDoc {
                doc_id,
                score,
                ordinal,
            })
            .collect();

        debug!(
            "BlockMaxScoreExecutor completed: scored={}, skipped={}, blocks_skipped={}, conjunction_skipped={}, returned={}, top_score={:.4}",
            docs_scored,
            docs_skipped,
            blocks_skipped,
            conjunction_skipped,
            results.len(),
            results.first().map(|r| r.score).unwrap_or(0.0)
        );

        results
    }
}
471
/// BM25 scorer over a text-term block posting list.
pub struct TextTermScorer {
    // `'static` because the scorer consumes the posting list in `new` and
    // owns its data through the iterator.
    iter: crate::structures::BlockPostingIterator<'static>,
    idf: f32,
    avg_field_len: f32,
    // Global BM25 upper bound, precomputed from the list-wide max tf.
    max_score: f32,
}
486
487impl TextTermScorer {
488 pub fn new(
490 posting_list: crate::structures::BlockPostingList,
491 idf: f32,
492 avg_field_len: f32,
493 ) -> Self {
494 let max_tf = posting_list.max_tf() as f32;
496 let doc_count = posting_list.doc_count();
497 let max_score = super::bm25_upper_bound(max_tf.max(1.0), idf);
498
499 debug!(
500 "Created TextTermScorer: doc_count={}, max_tf={:.0}, idf={:.4}, avg_field_len={:.2}, max_score={:.4}",
501 doc_count, max_tf, idf, avg_field_len, max_score
502 );
503
504 Self {
505 iter: posting_list.into_iterator(),
506 idf,
507 avg_field_len,
508 max_score,
509 }
510 }
511}
512
impl ScoringIterator for TextTermScorer {
    #[inline]
    fn doc(&self) -> DocId {
        self.iter.doc()
    }

    #[inline]
    fn advance(&mut self) -> DocId {
        self.iter.advance()
    }

    #[inline]
    fn seek(&mut self, target: DocId) -> DocId {
        self.iter.seek(target)
    }

    #[inline]
    fn score(&self) -> f32 {
        let tf = self.iter.term_freq() as f32;
        // NOTE(review): `tf` is passed both as the term frequency and as the
        // third argument of `bm25_score` — confirm against `bm25_score`'s
        // signature that the third argument is meant to be `tf` here and not
        // a field length (possible copy-paste slip).
        super::bm25_score(tf, self.idf, tf, self.avg_field_len)
    }

    #[inline]
    fn max_score(&self) -> f32 {
        self.max_score
    }

    #[inline]
    fn current_block_max_score(&self) -> f32 {
        // Clamp to >= 1.0, mirroring the global bound computed in `new`.
        let block_max_tf = self.iter.current_block_max_tf() as f32;
        super::bm25_upper_bound(block_max_tf.max(1.0), self.idf)
    }

    #[inline]
    fn skip_to_next_block(&mut self) -> DocId {
        self.iter.skip_to_next_block()
    }
}
553
/// Scorer for one sparse-vector query dimension: contribution is
/// `query_weight * stored_weight`.
pub struct SparseTermScorer<'a> {
    iter: crate::structures::BlockSparsePostingIterator<'a>,
    query_weight: f32,
    // |query_weight|, used for upper bounds so negative query weights still
    // prune correctly.
    abs_query_weight: f32,
    // abs_query_weight * global max stored weight.
    max_score: f32,
}
567
568impl<'a> SparseTermScorer<'a> {
569 pub fn new(posting_list: &'a BlockSparsePostingList, query_weight: f32) -> Self {
574 let abs_qw = query_weight.abs();
577 let max_score = abs_qw * posting_list.global_max_weight();
578 Self {
579 iter: posting_list.iterator(),
580 query_weight,
581 abs_query_weight: abs_qw,
582 max_score,
583 }
584 }
585
586 pub fn from_arc(posting_list: &'a Arc<BlockSparsePostingList>, query_weight: f32) -> Self {
588 Self::new(posting_list.as_ref(), query_weight)
589 }
590}
591
// Thin delegation to the underlying block-sparse posting iterator; only
// `score`/`max_score`/`current_block_max_score` add the query-weight factor.
impl ScoringIterator for SparseTermScorer<'_> {
    #[inline]
    fn doc(&self) -> DocId {
        self.iter.doc()
    }

    #[inline]
    fn ordinal(&mut self) -> u16 {
        self.iter.ordinal()
    }

    #[inline]
    fn advance(&mut self) -> DocId {
        self.iter.advance()
    }

    #[inline]
    fn seek(&mut self, target: DocId) -> DocId {
        self.iter.seek(target)
    }

    #[inline]
    fn score(&self) -> f32 {
        // Signed weight: negative query weights produce negative scores.
        self.query_weight * self.iter.weight()
    }

    #[inline]
    fn max_score(&self) -> f32 {
        self.max_score
    }

    #[inline]
    fn current_block_max_score(&self) -> f32 {
        // Absolute weight keeps the bound non-negative for pruning.
        self.iter
            .current_block_max_contribution(self.abs_query_weight)
    }

    #[inline]
    fn skip_to_next_block(&mut self) -> DocId {
        self.iter.skip_to_next_block()
    }
}
635
/// Block-max-pruning (BMP-style) executor: visits superblocks in decreasing
/// impact order rather than document-at-a-time.
pub struct BmpExecutor<'a> {
    sparse_index: &'a crate::segment::SparseIndex,
    /// (dimension id, query weight) pairs of the query vector.
    query_terms: Vec<(u32, f32)>,
    k: usize,
    // Threshold multiplier in [0, 1]; below 1.0 prunes more aggressively.
    heap_factor: f32,
}
657
/// Number of consecutive skip-list blocks grouped into one priority-queue
/// entry (a "superblock"), amortizing queue and fetch overhead.
const BMP_SUPERBLOCK_SIZE: usize = 8;
661
/// Priority-queue entry: one superblock of one term's posting list, keyed by
/// its maximum possible score contribution.
struct BmpBlockEntry {
    // |query_weight| * sum of per-block max weights in this superblock.
    contribution: f32,
    // Index into `BmpExecutor::query_terms`.
    term_idx: usize,
    // First block index (relative to the term's skip range).
    block_start: usize,
    // Number of blocks in this superblock (<= BMP_SUPERBLOCK_SIZE).
    block_count: usize,
}
673
674impl PartialEq for BmpBlockEntry {
675 fn eq(&self, other: &Self) -> bool {
676 self.contribution == other.contribution
677 }
678}
679
// NOTE(review): `Eq` over an f32 key assumes contributions are never NaN —
// confirm skip-entry max weights and query weights cannot produce NaN.
impl Eq for BmpBlockEntry {}
681
682impl Ord for BmpBlockEntry {
683 fn cmp(&self, other: &Self) -> Ordering {
684 self.contribution
686 .partial_cmp(&other.contribution)
687 .unwrap_or(Ordering::Equal)
688 }
689}
690
impl PartialOrd for BmpBlockEntry {
    // Delegate to the total order defined by `Ord`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
696
impl<'a> BmpExecutor<'a> {
    /// Create an executor; `heap_factor` is clamped into [0, 1].
    pub fn new(
        sparse_index: &'a crate::segment::SparseIndex,
        query_terms: Vec<(u32, f32)>,
        k: usize,
        heap_factor: f32,
    ) -> Self {
        Self {
            sparse_index,
            query_terms,
            k,
            heap_factor: heap_factor.clamp(0.0, 1.0),
        }
    }

    /// Run BMP retrieval: superblocks are processed in decreasing
    /// upper-bound order, scores accumulated per (doc, ordinal), and the
    /// queue abandoned once nothing left can beat the current threshold.
    pub async fn execute(self) -> crate::Result<Vec<ScoredDoc>> {
        use rustc_hash::FxHashMap;

        if self.query_terms.is_empty() {
            return Ok(Vec::new());
        }

        let num_terms = self.query_terms.len();
        let si = self.sparse_index;

        // Max-heap of superblocks keyed by score contribution.
        let mut block_queue: BinaryHeap<BmpBlockEntry> = BinaryHeap::new();
        // Per-term sum of not-yet-processed superblock contributions.
        let mut remaining_max: Vec<f32> = Vec::with_capacity(num_terms);
        // Per-term start offset into the segment-wide skip-entry array.
        let mut term_skip_starts: Vec<usize> = Vec::with_capacity(num_terms);
        // Observed doc-id range, used below to choose the flat accumulator.
        let mut global_min_doc = u32::MAX;
        let mut global_max_doc = 0u32;

        // Seed the queue with one entry per superblock (runs of up to
        // BMP_SUPERBLOCK_SIZE consecutive blocks) per query term.
        for (term_idx, &(dim_id, qw)) in self.query_terms.iter().enumerate() {
            let mut term_remaining = 0.0f32;
            let mut term_skip_start = 0usize;

            let abs_qw = qw.abs();
            if let Some((skip_start, skip_count, _global_max)) = si.get_skip_range(dim_id) {
                term_skip_start = skip_start;
                let mut sb_start = 0;
                while sb_start < skip_count {
                    let sb_count = (skip_count - sb_start).min(BMP_SUPERBLOCK_SIZE);
                    let mut sb_contribution = 0.0f32;
                    for j in 0..sb_count {
                        let skip = si.read_skip_entry(skip_start + sb_start + j);
                        sb_contribution += abs_qw * skip.max_weight;
                        global_min_doc = global_min_doc.min(skip.first_doc);
                        global_max_doc = global_max_doc.max(skip.last_doc);
                    }
                    term_remaining += sb_contribution;
                    block_queue.push(BmpBlockEntry {
                        contribution: sb_contribution,
                        term_idx,
                        block_start: sb_start,
                        block_count: sb_count,
                    });
                    sb_start += sb_count;
                }
            }
            remaining_max.push(term_remaining);
            term_skip_starts.push(term_skip_start);
        }

        // Accumulator strategy: a flat per-doc array when the doc-id range is
        // small enough (<= 256K slots), otherwise a hash map keyed by
        // (doc_id << 16 | ordinal). The flat path still uses the map for
        // non-zero ordinals.
        let doc_range = if global_max_doc >= global_min_doc {
            (global_max_doc - global_min_doc + 1) as usize
        } else {
            0
        };
        let use_flat = doc_range > 0 && doc_range <= 256 * 1024;
        let mut flat_scores: Vec<f32> = if use_flat {
            vec![0.0; doc_range]
        } else {
            Vec::new()
        };
        // Doc ids whose flat slot became non-zero, so we can harvest without
        // scanning the whole array.
        let mut dirty: Vec<u32> = if use_flat {
            Vec::with_capacity(4096)
        } else {
            Vec::new()
        };
        let mut multi_ord_accumulators: FxHashMap<u64, f32> = FxHashMap::default();

        let mut total_remaining: f32 = remaining_max.iter().sum();
        let mut blocks_processed = 0u64;
        let mut blocks_skipped = 0u64;

        let mut top_k = ScoreCollector::new(self.k);

        // Reused per-block decode buffers.
        let mut doc_ids_buf: Vec<u32> = Vec::with_capacity(128);
        let mut weights_buf: Vec<f32> = Vec::with_capacity(128);
        let mut ordinals_buf: Vec<u16> = Vec::with_capacity(128);

        while let Some(entry) = block_queue.pop() {
            // Account for this superblock before deciding to stop.
            remaining_max[entry.term_idx] -= entry.contribution;
            total_remaining -= entry.contribution;

            // Early termination: remaining mass cannot beat the threshold.
            // NOTE(review): this compares only `total_remaining`, which
            // already excludes the just-popped entry, whereas the per-block
            // check below adds the block's own contribution. Confirm that
            // skipping `entry.contribution` here is intentional and cannot
            // drop a competitive doc from the popped superblock.
            let adjusted_threshold = top_k.threshold() * self.heap_factor;
            if top_k.len() >= self.k && total_remaining <= adjusted_threshold {
                blocks_skipped += block_queue.len() as u64;
                debug!(
                    "BMP early termination after {} blocks: remaining={:.4} <= threshold={:.4}",
                    blocks_processed, total_remaining, adjusted_threshold
                );
                break;
            }

            let dim_id = self.query_terms[entry.term_idx].0;
            let qw = self.query_terms[entry.term_idx].1;
            let abs_qw = qw.abs();

            // Fetch the superblock's blocks (async I/O).
            let sb_blocks = si
                .get_blocks_range(dim_id, entry.block_start, entry.block_count)
                .await?;

            let skip_start = term_skip_starts[entry.term_idx];
            // Threshold may have moved since the check above; re-read once
            // for the whole superblock.
            let adjusted_threshold2 = top_k.threshold() * self.heap_factor;

            for (blk_offset, block) in sb_blocks.into_iter().enumerate() {
                let blk_idx = entry.block_start + blk_offset;

                // Per-block skip: this block alone plus everything remaining
                // cannot beat the threshold.
                if top_k.len() >= self.k {
                    let skip = si.read_skip_entry(skip_start + blk_idx);
                    let blk_contrib = abs_qw * skip.max_weight;
                    if blk_contrib + total_remaining <= adjusted_threshold2 {
                        blocks_skipped += 1;
                        continue;
                    }
                }

                block.decode_doc_ids_into(&mut doc_ids_buf);
                block.decode_scored_weights_into(qw, &mut weights_buf);
                let count = block.header.count as usize;

                if block.header.ordinal_bits == 0 && use_flat {
                    // Fast path: single-ordinal block, flat accumulator.
                    for i in 0..count {
                        let doc_id = doc_ids_buf[i];
                        let off = (doc_id - global_min_doc) as usize;
                        if flat_scores[off] == 0.0 {
                            dirty.push(doc_id);
                        }
                        flat_scores[off] += weights_buf[i];
                        // Feed the running partial score so the threshold
                        // tightens as blocks are processed.
                        top_k.insert_with_ordinal(doc_id, flat_scores[off], 0);
                    }
                } else {
                    block.decode_ordinals_into(&mut ordinals_buf);
                    if use_flat {
                        // Mixed path: ordinal 0 goes to the flat array,
                        // other ordinals to the hash accumulator.
                        for i in 0..count {
                            let doc_id = doc_ids_buf[i];
                            let ordinal = ordinals_buf[i];
                            let score_contribution = weights_buf[i];

                            if ordinal == 0 {
                                let off = (doc_id - global_min_doc) as usize;
                                if flat_scores[off] == 0.0 {
                                    dirty.push(doc_id);
                                }
                                flat_scores[off] += score_contribution;
                                top_k.insert_with_ordinal(doc_id, flat_scores[off], 0);
                            } else {
                                let key = (doc_id as u64) << 16 | ordinal as u64;
                                let acc = multi_ord_accumulators.entry(key).or_insert(0.0);
                                *acc += score_contribution;
                                top_k.insert_with_ordinal(doc_id, *acc, ordinal);
                            }
                        }
                    } else {
                        // Hash-only path for wide doc-id ranges.
                        for i in 0..count {
                            let key = (doc_ids_buf[i] as u64) << 16 | ordinals_buf[i] as u64;
                            let acc = multi_ord_accumulators.entry(key).or_insert(0.0);
                            *acc += weights_buf[i];
                            top_k.insert_with_ordinal(doc_ids_buf[i], *acc, ordinals_buf[i]);
                        }
                    }
                }

                blocks_processed += 1;
            }
        }

        // Harvest final scores from the accumulators (the collector only
        // tracked running partials and is used for its threshold, not the
        // final ranking).
        let mut scored: Vec<ScoredDoc> = Vec::new();

        let num_accumulators = if use_flat {
            scored.reserve(dirty.len() + multi_ord_accumulators.len());
            for &doc_id in &dirty {
                let off = (doc_id - global_min_doc) as usize;
                let score = flat_scores[off];
                if score > 0.0 {
                    scored.push(ScoredDoc {
                        doc_id,
                        score,
                        ordinal: 0,
                    });
                }
            }
            dirty.len() + multi_ord_accumulators.len()
        } else {
            multi_ord_accumulators.len()
        };

        scored.extend(
            multi_ord_accumulators
                .into_iter()
                .map(|(key, score)| ScoredDoc {
                    doc_id: (key >> 16) as DocId,
                    score,
                    ordinal: (key & 0xFFFF) as u16,
                }),
        );

        // Final ranking: descending score, truncated to k.
        scored.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
        scored.truncate(self.k);
        let results = scored;

        debug!(
            "BmpExecutor completed: blocks_processed={}, blocks_skipped={}, accumulators={}, flat={}, returned={}, top_score={:.4}",
            blocks_processed,
            blocks_skipped,
            num_accumulators,
            use_flat,
            results.len(),
            results.first().map(|r| r.score).unwrap_or(0.0)
        );

        Ok(results)
    }
}
949
/// Block-max executor whose posting blocks are loaded lazily (on demand,
/// async) from the sparse index instead of being fully materialized.
pub struct LazyBlockMaxScoreExecutor<'a> {
    sparse_index: &'a crate::segment::SparseIndex,
    // Sorted ascending by max_score; prefix = non-essential set.
    cursors: Vec<LazyTermCursor>,
    // prefix_sums[i] = cumulative max_score of cursors[0..=i].
    prefix_sums: Vec<f32>,
    collector: ScoreCollector,
    // Threshold multiplier in [0, 1].
    heap_factor: f32,
}
968
/// Cursor over one term's posting list that defers block decoding until a
/// block is actually needed; skip entries alone answer "first doc of the
/// current block" queries.
struct LazyTermCursor {
    query_weight: f32,
    // |query_weight|, for non-negative pruning bounds.
    abs_query_weight: f32,
    // |query_weight| * global max stored weight.
    max_score: f32,
    // Start offset of this term's entries in the skip array.
    skip_start: usize,
    // Number of skip entries (blocks) for this term.
    skip_count: usize,
    // Byte offset of this term's block data in the segment.
    block_data_offset: u32,
    // Index of the current block within [0, skip_count).
    block_idx: usize,
    // Decoded contents of the current block (valid when block_loaded).
    doc_ids: Vec<u32>,
    ordinals: Vec<u16>,
    weights: Vec<f32>,
    // Position within the decoded block.
    pos: usize,
    // Whether doc_ids/ordinals/weights hold block `block_idx`.
    block_loaded: bool,
    exhausted: bool,
}
993
impl LazyTermCursor {
    /// Create a cursor positioned before the first block; nothing is loaded
    /// yet. A term with zero blocks starts exhausted.
    fn new(
        query_weight: f32,
        skip_start: usize,
        skip_count: usize,
        global_max_weight: f32,
        block_data_offset: u32,
    ) -> Self {
        let exhausted = skip_count == 0;
        let abs_qw = query_weight.abs();
        Self {
            query_weight,
            abs_query_weight: abs_qw,
            max_score: abs_qw * global_max_weight,
            skip_start,
            skip_count,
            block_data_offset,
            block_idx: 0,
            doc_ids: Vec::with_capacity(128),
            ordinals: Vec::with_capacity(128),
            weights: Vec::with_capacity(128),
            pos: 0,
            block_loaded: false,
            exhausted,
        }
    }

    /// Load and decode the current block if it isn't already in memory.
    /// Returns Ok(true) when positioned on a loaded block, Ok(false) when
    /// the cursor is (or just became) exhausted.
    async fn ensure_block_loaded(
        &mut self,
        sparse_index: &crate::segment::SparseIndex,
    ) -> crate::Result<bool> {
        if self.exhausted || self.block_loaded {
            return Ok(!self.exhausted);
        }
        match sparse_index
            .load_block_direct(self.skip_start, self.block_data_offset, self.block_idx)
            .await?
        {
            Some(block) => {
                block.decode_doc_ids_into(&mut self.doc_ids);
                block.decode_ordinals_into(&mut self.ordinals);
                block.decode_scored_weights_into(self.query_weight, &mut self.weights);
                self.pos = 0;
                self.block_loaded = true;
                Ok(true)
            }
            None => {
                self.exhausted = true;
                Ok(false)
            }
        }
    }

    /// Current doc id without forcing a block load: falls back to the skip
    /// entry's `first_doc` when the block is not decoded yet. This is the
    /// variant the executor's pruning loop uses.
    #[inline]
    fn doc_with_si(&self, si: &crate::segment::SparseIndex) -> DocId {
        if self.exhausted {
            return u32::MAX;
        }
        if !self.block_loaded {
            return si
                .read_skip_entry(self.skip_start + self.block_idx)
                .first_doc;
        }
        self.doc_ids.get(self.pos).copied().unwrap_or(u32::MAX)
    }

    /// Current doc id from decoded data only; returns the `u32::MAX`
    /// sentinel when the current block isn't loaded (callers that need the
    /// real value use `doc_with_si`).
    #[inline]
    fn doc(&self) -> DocId {
        if self.exhausted {
            return u32::MAX;
        }
        if self.block_loaded {
            return self.doc_ids.get(self.pos).copied().unwrap_or(u32::MAX);
        }
        u32::MAX
    }

    /// Ordinal of the current posting; 0 when no block is loaded.
    #[inline]
    fn ordinal(&self) -> u16 {
        if !self.block_loaded {
            return 0;
        }
        self.ordinals.get(self.pos).copied().unwrap_or(0)
    }

    /// Pre-weighted score of the current posting; 0.0 when no block loaded.
    #[inline]
    fn score(&self) -> f32 {
        if !self.block_loaded {
            return 0.0;
        }
        self.weights.get(self.pos).copied().unwrap_or(0.0)
    }

    /// Block-level score upper bound from the skip entry — no block load.
    #[inline]
    fn current_block_max_score(&self, si: &crate::segment::SparseIndex) -> f32 {
        if self.exhausted || self.block_idx >= self.skip_count {
            return 0.0;
        }
        self.abs_query_weight
            * si.read_skip_entry(self.skip_start + self.block_idx)
                .max_weight
    }

    /// Step to the next posting, loading the current block first if needed.
    /// When the step crosses a block boundary the next block is NOT loaded;
    /// the returned value is then `doc()`'s sentinel (see `doc`).
    async fn advance(
        &mut self,
        sparse_index: &crate::segment::SparseIndex,
    ) -> crate::Result<DocId> {
        if self.exhausted {
            return Ok(u32::MAX);
        }
        self.ensure_block_loaded(sparse_index).await?;
        if self.exhausted {
            return Ok(u32::MAX);
        }
        self.pos += 1;
        if self.pos >= self.doc_ids.len() {
            // Crossed the block boundary: move on lazily.
            self.block_idx += 1;
            self.block_loaded = false;
            if self.block_idx >= self.skip_count {
                self.exhausted = true;
                return Ok(u32::MAX);
            }
        }
        Ok(self.doc())
    }

    /// Position on the first posting with doc id >= `target`.
    async fn seek(
        &mut self,
        sparse_index: &crate::segment::SparseIndex,
        target: DocId,
    ) -> crate::Result<DocId> {
        if self.exhausted {
            return Ok(u32::MAX);
        }

        // Fast path: the target may live inside the already-decoded block.
        if self.block_loaded
            && let Some(&last) = self.doc_ids.last()
        {
            if last >= target && self.doc_ids[self.pos] < target {
                // SIMD scan of the remaining tail of the current block.
                let remaining = &self.doc_ids[self.pos..];
                let offset = crate::structures::simd::find_first_ge_u32(remaining, target);
                self.pos += offset;
                if self.pos >= self.doc_ids.len() {
                    self.block_idx += 1;
                    self.block_loaded = false;
                    if self.block_idx >= self.skip_count {
                        self.exhausted = true;
                        return Ok(u32::MAX);
                    }
                }
                return Ok(self.doc());
            }
            if self.doc_ids[self.pos] >= target {
                // Already at or past the target.
                return Ok(self.doc());
            }
        }

        // Slow path: binary search skip entries for the first block whose
        // last_doc can contain the target.
        let mut lo = self.block_idx;
        let mut hi = self.skip_count;
        while lo < hi {
            let mid = lo + (hi - lo) / 2;
            if sparse_index.read_skip_entry(self.skip_start + mid).last_doc < target {
                lo = mid + 1;
            } else {
                hi = mid;
            }
        }
        if lo >= self.skip_count {
            self.exhausted = true;
            return Ok(u32::MAX);
        }
        if lo != self.block_idx || !self.block_loaded {
            self.block_idx = lo;
            self.block_loaded = false;
        }
        self.ensure_block_loaded(sparse_index).await?;
        if self.exhausted {
            return Ok(u32::MAX);
        }
        self.pos = crate::structures::simd::find_first_ge_u32(&self.doc_ids, target);
        if self.pos >= self.doc_ids.len() {
            // Target fell in the gap after this block; load the next one.
            self.block_idx += 1;
            self.block_loaded = false;
            if self.block_idx >= self.skip_count {
                self.exhausted = true;
                return Ok(u32::MAX);
            }
            self.ensure_block_loaded(sparse_index).await?;
        }
        Ok(self.doc())
    }

    /// Jump past the current block without decoding it; returns the next
    /// block's first doc (from its skip entry) or `u32::MAX` at the end.
    fn skip_to_next_block(&mut self, si: &crate::segment::SparseIndex) -> DocId {
        if self.exhausted {
            return u32::MAX;
        }
        self.block_idx += 1;
        self.block_loaded = false;
        if self.block_idx >= self.skip_count {
            self.exhausted = true;
            return u32::MAX;
        }
        si.read_skip_entry(self.skip_start + self.block_idx)
            .first_doc
    }
}
1213
impl<'a> LazyBlockMaxScoreExecutor<'a> {
    /// Build cursors for every query dimension present in the index, sorted
    /// ascending by upper bound; `heap_factor` is clamped into [0, 1].
    pub fn new(
        sparse_index: &'a crate::segment::SparseIndex,
        query_terms: Vec<(u32, f32)>,
        k: usize,
        heap_factor: f32,
    ) -> Self {
        // Dimensions absent from the index are silently dropped.
        let mut cursors: Vec<LazyTermCursor> = query_terms
            .iter()
            .filter_map(|&(dim_id, qw)| {
                let (skip_start, skip_count, global_max, block_data_offset) =
                    sparse_index.get_skip_range_full(dim_id)?;
                Some(LazyTermCursor::new(
                    qw,
                    skip_start,
                    skip_count,
                    global_max,
                    block_data_offset,
                ))
            })
            .collect();

        // Ascending by upper bound so the non-essential set is a prefix.
        cursors.sort_by(|a, b| {
            a.max_score
                .partial_cmp(&b.max_score)
                .unwrap_or(Ordering::Equal)
        });

        // prefix_sums[i] = cumulative max_score of cursors[0..=i].
        let mut prefix_sums = Vec::with_capacity(cursors.len());
        let mut cumsum = 0.0f32;
        for c in &cursors {
            cumsum += c.max_score;
            prefix_sums.push(cumsum);
        }

        debug!(
            "Creating LazyBlockMaxScoreExecutor: num_terms={}, k={}, total_upper={:.4}, heap_factor={:.2}",
            cursors.len(),
            k,
            cumsum,
            heap_factor
        );

        Self {
            sparse_index,
            cursors,
            prefix_sums,
            collector: ScoreCollector::new(k),
            heap_factor: heap_factor.clamp(0.0, 1.0),
        }
    }

    /// First "essential" cursor index (see BlockMaxScoreExecutor).
    #[inline]
    fn find_partition(&self) -> usize {
        let threshold = self.collector.threshold() * self.heap_factor;
        self.prefix_sums.partition_point(|&sum| sum <= threshold)
    }

    /// Run lazy block-max retrieval; blocks are decoded only when a doc must
    /// actually be scored or a block is entered.
    pub async fn execute(mut self) -> crate::Result<Vec<ScoredDoc>> {
        if self.cursors.is_empty() {
            return Ok(Vec::new());
        }

        let n = self.cursors.len();
        let si = self.sparse_index;

        // Prime every cursor with its first block.
        for cursor in &mut self.cursors {
            cursor.ensure_block_loaded(si).await?;
        }

        // Telemetry counters for the completion debug line.
        let mut docs_scored = 0u64;
        let mut docs_skipped = 0u64;
        let mut blocks_skipped = 0u64;
        let mut blocks_loaded = 0u64;
        let mut conjunction_skipped = 0u64;
        // Reused (ordinal, score) scratch buffer.
        let mut ordinal_scores: Vec<(u16, f32)> = Vec::with_capacity(n * 2);

        loop {
            let partition = self.find_partition();
            // All cursors non-essential: nothing left can beat the threshold.
            if partition >= n {
                break;
            }

            // Smallest current doc among essential cursors; `doc_with_si`
            // answers from skip entries without forcing block loads.
            let mut min_doc = u32::MAX;
            for i in partition..n {
                let doc = self.cursors[i].doc_with_si(si);
                if doc < min_doc {
                    min_doc = doc;
                }
            }
            if min_doc == u32::MAX {
                break;
            }

            // Upper bound from the whole non-essential prefix.
            let non_essential_upper = if partition > 0 {
                self.prefix_sums[partition - 1]
            } else {
                0.0
            };
            let adjusted_threshold = self.collector.threshold() * self.heap_factor;

            // Global-maxima check: min_doc cannot be competitive even with
            // full upper bounds — advance the cursors sitting on it.
            if self.collector.len() >= self.collector.k {
                let present_upper: f32 = (partition..n)
                    .filter(|&i| self.cursors[i].doc_with_si(si) == min_doc)
                    .map(|i| self.cursors[i].max_score)
                    .sum();

                if present_upper + non_essential_upper <= adjusted_threshold {
                    for i in partition..n {
                        if self.cursors[i].doc_with_si(si) == min_doc {
                            self.cursors[i].ensure_block_loaded(si).await?;
                            self.cursors[i].advance(si).await?;
                            blocks_loaded += u64::from(self.cursors[i].block_loaded);
                        }
                    }
                    conjunction_skipped += 1;
                    continue;
                }
            }

            // Block-max check: current blocks cannot be competitive — jump
            // whole blocks (skip entries only), then load the new block.
            if self.collector.len() >= self.collector.k {
                let block_max_sum: f32 = (partition..n)
                    .filter(|&i| self.cursors[i].doc_with_si(si) == min_doc)
                    .map(|i| self.cursors[i].current_block_max_score(si))
                    .sum();

                if block_max_sum + non_essential_upper <= adjusted_threshold {
                    for i in partition..n {
                        if self.cursors[i].doc_with_si(si) == min_doc {
                            self.cursors[i].skip_to_next_block(si);
                            self.cursors[i].ensure_block_loaded(si).await?;
                            blocks_loaded += 1;
                        }
                    }
                    blocks_skipped += 1;
                    continue;
                }
            }

            // Score min_doc from the essential cursors; a cursor can emit
            // several postings for one doc (one per ordinal).
            ordinal_scores.clear();
            for i in partition..n {
                if self.cursors[i].doc_with_si(si) == min_doc {
                    self.cursors[i].ensure_block_loaded(si).await?;
                    while self.cursors[i].doc_with_si(si) == min_doc {
                        ordinal_scores.push((self.cursors[i].ordinal(), self.cursors[i].score()));
                        self.cursors[i].advance(si).await?;
                    }
                }
            }

            let essential_total: f32 = ordinal_scores.iter().map(|(_, s)| *s).sum();
            // Even with the full non-essential upper bound, the doc cannot
            // enter the heap — avoid seeking non-essential cursors.
            if self.collector.len() >= self.collector.k
                && essential_total + non_essential_upper <= adjusted_threshold
            {
                docs_skipped += 1;
                continue;
            }

            // Seek non-essential cursors, largest upper bound first, pruning
            // incrementally: once the running total plus the remaining
            // prefix bound cannot beat the threshold, stop seeking.
            let mut running_total = essential_total;
            for i in (0..partition).rev() {
                if self.collector.len() >= self.collector.k
                    && running_total + self.prefix_sums[i] <= adjusted_threshold
                {
                    break;
                }

                let doc = self.cursors[i].seek(si, min_doc).await?;
                if doc == min_doc {
                    while self.cursors[i].doc_with_si(si) == min_doc {
                        let s = self.cursors[i].score();
                        running_total += s;
                        ordinal_scores.push((self.cursors[i].ordinal(), s));
                        self.cursors[i].advance(si).await?;
                    }
                }
            }

            // Group by ordinal; each ordinal becomes a separate candidate.
            ordinal_scores.sort_unstable_by_key(|(ord, _)| *ord);
            let mut j = 0;
            while j < ordinal_scores.len() {
                let current_ord = ordinal_scores[j].0;
                let mut score = 0.0f32;
                while j < ordinal_scores.len() && ordinal_scores[j].0 == current_ord {
                    score += ordinal_scores[j].1;
                    j += 1;
                }
                if self
                    .collector
                    .insert_with_ordinal(min_doc, score, current_ord)
                {
                    docs_scored += 1;
                } else {
                    docs_skipped += 1;
                }
            }
        }

        let results: Vec<ScoredDoc> = self
            .collector
            .into_sorted_results()
            .into_iter()
            .map(|(doc_id, score, ordinal)| ScoredDoc {
                doc_id,
                score,
                ordinal,
            })
            .collect();

        debug!(
            "LazyBlockMaxScoreExecutor completed: scored={}, skipped={}, blocks_skipped={}, blocks_loaded={}, conjunction_skipped={}, returned={}, top_score={:.4}",
            docs_scored,
            docs_skipped,
            blocks_skipped,
            blocks_loaded,
            conjunction_skipped,
            results.len(),
            results.first().map(|r| r.score).unwrap_or(0.0)
        );

        Ok(results)
    }
}
1456
#[cfg(test)]
mod tests {
    use super::*;

    // Fill a k=3 collector past capacity and verify both the threshold
    // (k-th best score) and the final descending-score order.
    #[test]
    fn test_score_collector_basic() {
        let mut collector = ScoreCollector::new(3);

        collector.insert(1, 1.0);
        collector.insert(2, 2.0);
        collector.insert(3, 3.0);
        // Heap full: threshold is the worst kept score.
        assert_eq!(collector.threshold(), 1.0);

        collector.insert(4, 4.0);
        // Doc 1 evicted; new worst kept score is 2.0.
        assert_eq!(collector.threshold(), 2.0);

        let results = collector.into_sorted_results();
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].0, 4);
        assert_eq!(results[1].0, 3);
        assert_eq!(results[2].0, 2);
    }

    // Verify that sub-threshold inserts are rejected and that accepted
    // inserts raise the threshold.
    #[test]
    fn test_score_collector_threshold() {
        let mut collector = ScoreCollector::new(2);

        collector.insert(1, 5.0);
        collector.insert(2, 3.0);
        assert_eq!(collector.threshold(), 3.0);

        assert!(!collector.would_enter(2.0));
        assert!(!collector.insert(3, 2.0));

        assert!(collector.would_enter(4.0));
        assert!(collector.insert(4, 4.0));
        assert_eq!(collector.threshold(), 4.0);
    }

    // HeapEntry's reversed Ord makes BinaryHeap pop lowest-score first
    // (min-heap behavior).
    #[test]
    fn test_heap_entry_ordering() {
        let mut heap = BinaryHeap::new();
        heap.push(HeapEntry {
            doc_id: 1,
            score: 3.0,
            ordinal: 0,
        });
        heap.push(HeapEntry {
            doc_id: 2,
            score: 1.0,
            ordinal: 0,
        });
        heap.push(HeapEntry {
            doc_id: 3,
            score: 2.0,
            ordinal: 0,
        });

        assert_eq!(heap.pop().unwrap().score, 1.0);
        assert_eq!(heap.pop().unwrap().score, 2.0);
        assert_eq!(heap.pop().unwrap().score, 3.0);
    }
}