use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::sync::Arc;

use log::{debug, trace};

use crate::DocId;
use crate::structures::BlockSparsePostingList;

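/// Common interface for posting-list scorers driven by the block-max
/// executors in this module.
///
/// Implementations expose the current document, its score, a global score
/// upper bound, and a per-block upper bound, plus cursor movement
/// (`advance`, `seek`, `skip_to_next_block`). Exhaustion is signalled by
/// reporting `u32::MAX` as the current doc id.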
pub trait ScoringIterator {
    fn doc(&self) -> DocId;

    fn ordinal(&self) -> u16 {
        0
    }

    fn advance(&mut self) -> DocId;

    fn seek(&mut self, target: DocId) -> DocId;

    fn is_exhausted(&self) -> bool {
        self.doc() == u32::MAX
    }

    fn score(&self) -> f32;

    fn max_score(&self) -> f32;

    fn current_block_max_score(&self) -> f32;

    fn skip_to_next_block(&mut self) -> DocId {
        self.advance()
    }
}

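/// Entry in the top-k heap: a `(doc_id, score, ordinal)` triple.
///
/// The `Ord` implementation is reversed on `score` so that `BinaryHeap`,
/// which is a max-heap, behaves as a min-heap: the weakest entry sits at the
/// top and is the first to be evicted.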
#[derive(Clone, Copy)]
pub struct HeapEntry {
    pub doc_id: DocId,
    pub score: f32,
    pub ordinal: u16,
}

impl PartialEq for HeapEntry {
    fn eq(&self, other: &Self) -> bool {
        self.score == other.score && self.doc_id == other.doc_id
    }
}

impl Eq for HeapEntry {}

impl Ord for HeapEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        other
            .score
            .partial_cmp(&self.score)
            .unwrap_or(Ordering::Equal)
            .then_with(|| self.doc_id.cmp(&other.doc_id))
    }
}

impl PartialOrd for HeapEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

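/// Bounded top-k collector backed by a min-heap of [`HeapEntry`] values.
///
/// `threshold()` returns the score of the current k-th best entry (or `0.0`
/// while the heap is not yet full), which the executors use as the bar a
/// candidate document must clear to enter the result set.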
pub struct ScoreCollector {
    heap: BinaryHeap<HeapEntry>,
    pub k: usize,
}

impl ScoreCollector {
    pub fn new(k: usize) -> Self {
        let capacity = k.saturating_add(1).min(1_000_000);
        Self {
            heap: BinaryHeap::with_capacity(capacity),
            k,
        }
    }

    #[inline]
    pub fn threshold(&self) -> f32 {
        if self.heap.len() >= self.k {
            self.heap.peek().map(|e| e.score).unwrap_or(0.0)
        } else {
            0.0
        }
    }

    #[inline]
    pub fn insert(&mut self, doc_id: DocId, score: f32) -> bool {
        self.insert_with_ordinal(doc_id, score, 0)
    }

    #[inline]
    pub fn insert_with_ordinal(&mut self, doc_id: DocId, score: f32, ordinal: u16) -> bool {
        if self.heap.len() < self.k {
            self.heap.push(HeapEntry {
                doc_id,
                score,
                ordinal,
            });
            true
        } else if score > self.threshold() {
            self.heap.push(HeapEntry {
                doc_id,
                score,
                ordinal,
            });
            self.heap.pop();
            true
        } else {
            false
        }
    }

    #[inline]
    pub fn would_enter(&self, score: f32) -> bool {
        self.heap.len() < self.k || score > self.threshold()
    }

    #[inline]
    pub fn len(&self) -> usize {
        self.heap.len()
    }

    #[inline]
    pub fn is_empty(&self) -> bool {
        self.heap.is_empty()
    }

    pub fn into_sorted_results(self) -> Vec<(DocId, f32, u16)> {
        let heap_vec = self.heap.into_vec();
        let mut results: Vec<(DocId, f32, u16)> = Vec::with_capacity(heap_vec.len());
        for e in heap_vec {
            results.push((e.doc_id, e.score, e.ordinal));
        }

        results.sort_by(|a, b| {
            b.1.partial_cmp(&a.1)
                .unwrap_or(Ordering::Equal)
                .then_with(|| a.0.cmp(&b.0))
        });

        results
    }
}

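/// A scored search result: document id, aggregated score, and the ordinal
/// reported by the matching scorer, which the executors use to keep scores
/// for different values within the same document separate.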
#[derive(Debug, Clone, Copy)]
pub struct ScoredDoc {
    pub doc_id: DocId,
    pub score: f32,
    pub ordinal: u16,
}

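/// Top-k query executor using a MaxScore/block-max strategy over a set of
/// [`ScoringIterator`]s.
///
/// Scorers are kept sorted by ascending `max_score`; `prefix_sums[i]` is the
/// summed upper bound of scorers `0..=i`, which splits the terms into a
/// non-essential prefix (unable to lift a document past the current
/// threshold on their own) and an essential tail that drives iteration.
/// `heap_factor` (clamped to `0.0..=1.0`) relaxes the threshold for
/// approximate, faster search.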
pub struct BlockMaxScoreExecutor<S: ScoringIterator> {
    scorers: Vec<S>,
    prefix_sums: Vec<f32>,
    collector: ScoreCollector,
    heap_factor: f32,
}

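/// Alias for [`BlockMaxScoreExecutor`] under its WAND-style name.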
pub type WandExecutor<S> = BlockMaxScoreExecutor<S>;

impl<S: ScoringIterator> BlockMaxScoreExecutor<S> {
    pub fn new(scorers: Vec<S>, k: usize) -> Self {
        Self::with_heap_factor(scorers, k, 1.0)
    }

    pub fn with_heap_factor(mut scorers: Vec<S>, k: usize, heap_factor: f32) -> Self {
        scorers.sort_by(|a, b| {
            a.max_score()
                .partial_cmp(&b.max_score())
                .unwrap_or(Ordering::Equal)
        });

        let mut prefix_sums = Vec::with_capacity(scorers.len());
        let mut cumsum = 0.0f32;
        for s in &scorers {
            cumsum += s.max_score();
            prefix_sums.push(cumsum);
        }

        debug!(
            "Creating BlockMaxScoreExecutor: num_scorers={}, k={}, total_upper={:.4}, heap_factor={:.2}",
            scorers.len(),
            k,
            cumsum,
            heap_factor
        );

        Self {
            scorers,
            prefix_sums,
            collector: ScoreCollector::new(k),
            heap_factor: heap_factor.clamp(0.0, 1.0),
        }
    }

    #[inline]
    fn find_partition(&self) -> usize {
        let threshold = self.collector.threshold() * self.heap_factor;
        self.prefix_sums
            .iter()
            .position(|&sum| sum > threshold)
            .unwrap_or(self.scorers.len())
    }

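    /// Runs the query to completion and returns the top-k documents, best first.
    ///
    /// Each iteration picks the pivot (the smallest current doc id among the
    /// essential scorers) and tries three increasingly precise checks before
    /// paying full scoring cost: a skip based on the global upper bounds of
    /// the terms present on the pivot, a skip based on their current block
    /// maxima (advancing whole blocks when possible), and finally a check of
    /// the exact essential score before the non-essential scorers are seeked
    /// in. Scores are accumulated per ordinal and offered to the collector.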
    pub fn execute(mut self) -> Vec<ScoredDoc> {
        if self.scorers.is_empty() {
            debug!("BlockMaxScoreExecutor: no scorers, returning empty results");
            return Vec::new();
        }

        let n = self.scorers.len();
        let mut docs_scored = 0u64;
        let mut docs_skipped = 0u64;
        let mut blocks_skipped = 0u64;
        let mut conjunction_skipped = 0u64;

        let mut ordinal_scores: Vec<(u16, f32)> = Vec::with_capacity(n * 2);

        loop {
            let partition = self.find_partition();

            if partition >= n {
                debug!("BlockMaxScore: all terms non-essential, early termination");
                break;
            }

            let mut min_doc = u32::MAX;
            for i in partition..n {
                let doc = self.scorers[i].doc();
                if doc < min_doc {
                    min_doc = doc;
                }
            }

            if min_doc == u32::MAX {
                break;
            }

            let non_essential_upper = if partition > 0 {
                self.prefix_sums[partition - 1]
            } else {
                0.0
            };
            let adjusted_threshold = self.collector.threshold() * self.heap_factor;

            if self.collector.len() >= self.collector.k {
                let present_upper: f32 = (partition..n)
                    .filter(|&i| self.scorers[i].doc() == min_doc)
                    .map(|i| self.scorers[i].max_score())
                    .sum();

                if present_upper + non_essential_upper <= adjusted_threshold {
                    for i in partition..n {
                        if self.scorers[i].doc() == min_doc {
                            self.scorers[i].advance();
                        }
                    }
                    conjunction_skipped += 1;
                    continue;
                }
            }

            if self.collector.len() >= self.collector.k {
                let block_max_sum: f32 = (partition..n)
                    .filter(|&i| self.scorers[i].doc() == min_doc)
                    .map(|i| self.scorers[i].current_block_max_score())
                    .sum();

                if block_max_sum + non_essential_upper <= adjusted_threshold {
                    for i in partition..n {
                        if self.scorers[i].doc() == min_doc {
                            self.scorers[i].skip_to_next_block();
                        }
                    }
                    blocks_skipped += 1;
                    continue;
                }
            }

            ordinal_scores.clear();

            for i in partition..n {
                if self.scorers[i].doc() == min_doc {
                    while self.scorers[i].doc() == min_doc {
                        ordinal_scores.push((self.scorers[i].ordinal(), self.scorers[i].score()));
                        self.scorers[i].advance();
                    }
                }
            }

            let essential_total: f32 = ordinal_scores.iter().map(|(_, s)| *s).sum();

            if self.collector.len() >= self.collector.k
                && essential_total + non_essential_upper <= adjusted_threshold
            {
                docs_skipped += 1;
                continue;
            }

            for i in 0..partition {
                let doc = self.scorers[i].seek(min_doc);
                if doc == min_doc {
                    while self.scorers[i].doc() == min_doc {
                        ordinal_scores.push((self.scorers[i].ordinal(), self.scorers[i].score()));
                        self.scorers[i].advance();
                    }
                }
            }

            ordinal_scores.sort_unstable_by_key(|(ord, _)| *ord);
            let mut j = 0;
            while j < ordinal_scores.len() {
                let current_ord = ordinal_scores[j].0;
                let mut score = 0.0f32;
                while j < ordinal_scores.len() && ordinal_scores[j].0 == current_ord {
                    score += ordinal_scores[j].1;
                    j += 1;
                }

                trace!(
                    "Doc {}: ordinal={}, score={:.4}, threshold={:.4}",
                    min_doc, current_ord, score, adjusted_threshold
                );

                if self
                    .collector
                    .insert_with_ordinal(min_doc, score, current_ord)
                {
                    docs_scored += 1;
                } else {
                    docs_skipped += 1;
                }
            }
        }

        let results: Vec<ScoredDoc> = self
            .collector
            .into_sorted_results()
            .into_iter()
            .map(|(doc_id, score, ordinal)| ScoredDoc {
                doc_id,
                score,
                ordinal,
            })
            .collect();

        debug!(
            "BlockMaxScoreExecutor completed: scored={}, skipped={}, blocks_skipped={}, conjunction_skipped={}, returned={}, top_score={:.4}",
            docs_scored,
            docs_skipped,
            blocks_skipped,
            conjunction_skipped,
            results.len(),
            results.first().map(|r| r.score).unwrap_or(0.0)
        );

        results
    }
}

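/// BM25 scorer over a single text term's block posting list.
///
/// `max_score` is a global BM25 upper bound derived from the list's maximum
/// term frequency; per-block bounds come from each block's maximum term
/// frequency.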
pub struct TextTermScorer {
    iter: crate::structures::BlockPostingIterator<'static>,
    idf: f32,
    avg_field_len: f32,
    max_score: f32,
}

impl TextTermScorer {
    pub fn new(
        posting_list: crate::structures::BlockPostingList,
        idf: f32,
        avg_field_len: f32,
    ) -> Self {
        let max_tf = posting_list.max_tf() as f32;
        let doc_count = posting_list.doc_count();
        let max_score = super::bm25_upper_bound(max_tf.max(1.0), idf);

        debug!(
            "Created TextTermScorer: doc_count={}, max_tf={:.0}, idf={:.4}, avg_field_len={:.2}, max_score={:.4}",
            doc_count, max_tf, idf, avg_field_len, max_score
        );

        Self {
            iter: posting_list.into_iterator(),
            idf,
            avg_field_len,
            max_score,
        }
    }
}

impl ScoringIterator for TextTermScorer {
    #[inline]
    fn doc(&self) -> DocId {
        self.iter.doc()
    }

    #[inline]
    fn advance(&mut self) -> DocId {
        self.iter.advance()
    }

    #[inline]
    fn seek(&mut self, target: DocId) -> DocId {
        self.iter.seek(target)
    }

    #[inline]
    fn score(&self) -> f32 {
        let tf = self.iter.term_freq() as f32;
        super::bm25_score(tf, self.idf, tf, self.avg_field_len)
    }

    #[inline]
    fn max_score(&self) -> f32 {
        self.max_score
    }

    #[inline]
    fn current_block_max_score(&self) -> f32 {
        let block_max_tf = self.iter.current_block_max_tf() as f32;
        super::bm25_upper_bound(block_max_tf.max(1.0), self.idf)
    }

    #[inline]
    fn skip_to_next_block(&mut self) -> DocId {
        self.iter.skip_to_next_block()
    }
}

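/// Scorer for one dimension of a sparse vector query.
///
/// Each posting's contribution is `query_weight * stored_weight`; the global
/// and per-block upper bounds are the posting list's maximum weights scaled
/// by `|query_weight|`.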
pub struct SparseTermScorer<'a> {
    iter: crate::structures::BlockSparsePostingIterator<'a>,
    query_weight: f32,
    max_score: f32,
}

impl<'a> SparseTermScorer<'a> {
    pub fn new(posting_list: &'a BlockSparsePostingList, query_weight: f32) -> Self {
        let max_score = query_weight.abs() * posting_list.global_max_weight();
        Self {
            iter: posting_list.iterator(),
            query_weight,
            max_score,
        }
    }

    pub fn from_arc(posting_list: &'a Arc<BlockSparsePostingList>, query_weight: f32) -> Self {
        Self::new(posting_list.as_ref(), query_weight)
    }
}

impl ScoringIterator for SparseTermScorer<'_> {
    #[inline]
    fn doc(&self) -> DocId {
        self.iter.doc()
    }

    #[inline]
    fn ordinal(&self) -> u16 {
        self.iter.ordinal()
    }

    #[inline]
    fn advance(&mut self) -> DocId {
        self.iter.advance()
    }

    #[inline]
    fn seek(&mut self, target: DocId) -> DocId {
        self.iter.seek(target)
    }

    #[inline]
    fn score(&self) -> f32 {
        self.query_weight * self.iter.weight()
    }

    #[inline]
    fn max_score(&self) -> f32 {
        self.max_score
    }

    #[inline]
    fn current_block_max_score(&self) -> f32 {
        self.iter
            .current_block_max_contribution(self.query_weight.abs())
    }

    #[inline]
    fn skip_to_next_block(&mut self) -> DocId {
        self.iter.skip_to_next_block()
    }
}

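/// Block-at-a-time executor for sparse queries in the style of BMP
/// (block-max pruning): candidate blocks are visited in order of their
/// best-case contribution and accumulated into per-(doc, ordinal) scores,
/// stopping once the remaining blocks cannot change the top-k.
///
/// `heap_factor` (clamped to `0.0..=1.0`) relaxes the stopping threshold for
/// approximate search.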
pub struct BmpExecutor<'a> {
    sparse_index: &'a crate::segment::SparseIndex,
    query_terms: Vec<(u32, f32)>,
    k: usize,
    heap_factor: f32,
}

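/// Priority-queue entry for [`BmpExecutor`]: one posting-list block and its
/// best-case score contribution. Ordered by `contribution` so the standard
/// max-heap pops the most promising block first.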
struct BmpBlockEntry {
    contribution: f32,
    term_idx: usize,
    block_idx: usize,
}

impl PartialEq for BmpBlockEntry {
    fn eq(&self, other: &Self) -> bool {
        self.contribution == other.contribution
    }
}

impl Eq for BmpBlockEntry {}

impl Ord for BmpBlockEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        self.contribution
            .partial_cmp(&other.contribution)
            .unwrap_or(Ordering::Equal)
    }
}

impl PartialOrd for BmpBlockEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl<'a> BmpExecutor<'a> {
    pub fn new(
        sparse_index: &'a crate::segment::SparseIndex,
        query_terms: Vec<(u32, f32)>,
        k: usize,
        heap_factor: f32,
    ) -> Self {
        Self {
            sparse_index,
            query_terms,
            k,
            heap_factor: heap_factor.clamp(0.0, 1.0),
        }
    }

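    /// Runs the query and returns the top-k documents, best first.
    ///
    /// The skip lists seed a max-heap of blocks keyed by best-case
    /// contribution. Blocks are decoded in that order and added to
    /// per-(doc, ordinal) accumulators; once the collector is full and the
    /// summed remaining upper bounds drop to or below the (heap_factor
    /// adjusted) threshold, the rest of the queue is skipped. The final
    /// ranking is taken from the accumulators, not from the intermediate
    /// collector.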
    pub async fn execute(self) -> crate::Result<Vec<ScoredDoc>> {
        use rustc_hash::FxHashMap;

        if self.query_terms.is_empty() {
            return Ok(Vec::new());
        }

        let num_terms = self.query_terms.len();

        let mut block_queue: BinaryHeap<BmpBlockEntry> = BinaryHeap::new();
        let mut remaining_max: Vec<f32> = Vec::with_capacity(num_terms);

        for (term_idx, &(dim_id, qw)) in self.query_terms.iter().enumerate() {
            let mut term_remaining = 0.0f32;

            if let Some((skip_entries, _global_max)) = self.sparse_index.get_skip_list(dim_id) {
                for (block_idx, skip) in skip_entries.iter().enumerate() {
                    let contribution = qw.abs() * skip.max_weight;
                    term_remaining += contribution;
                    block_queue.push(BmpBlockEntry {
                        contribution,
                        term_idx,
                        block_idx,
                    });
                }
            }
            remaining_max.push(term_remaining);
        }

        let mut accumulators: FxHashMap<u64, f32> = FxHashMap::default();
        let mut blocks_processed = 0u64;
        let mut blocks_skipped = 0u64;

        let mut top_k = ScoreCollector::new(self.k);

        let mut doc_ids_buf: Vec<u32> = Vec::with_capacity(128);
        let mut weights_buf: Vec<f32> = Vec::with_capacity(128);
        let mut ordinals_buf: Vec<u16> = Vec::with_capacity(128);

        while let Some(entry) = block_queue.pop() {
            remaining_max[entry.term_idx] -= entry.contribution;

            let total_remaining: f32 = remaining_max.iter().sum();
            let adjusted_threshold = top_k.threshold() * self.heap_factor;
            if top_k.len() >= self.k && total_remaining <= adjusted_threshold {
                blocks_skipped += block_queue.len() as u64;
                debug!(
                    "BMP early termination after {} blocks: remaining={:.4} <= threshold={:.4}",
                    blocks_processed, total_remaining, adjusted_threshold
                );
                break;
            }

            let dim_id = self.query_terms[entry.term_idx].0;
            let block = match self.sparse_index.get_block(dim_id, entry.block_idx).await? {
                Some(b) => b,
                None => continue,
            };

            let qw = self.query_terms[entry.term_idx].1;
            block.decode_doc_ids_into(&mut doc_ids_buf);
            block.decode_scored_weights_into(qw, &mut weights_buf);
            block.decode_ordinals_into(&mut ordinals_buf);

            for i in 0..block.header.count as usize {
                let score_contribution = weights_buf[i];
                let key = (doc_ids_buf[i] as u64) << 16 | ordinals_buf[i] as u64;
                let acc = accumulators.entry(key).or_insert(0.0);
                *acc += score_contribution;
                top_k.insert_with_ordinal(doc_ids_buf[i], *acc, ordinals_buf[i]);
            }

            blocks_processed += 1;
        }

        let num_accumulators = accumulators.len();
        let mut scored: Vec<ScoredDoc> = accumulators
            .into_iter()
            .map(|(key, score)| ScoredDoc {
                doc_id: (key >> 16) as DocId,
                score,
                ordinal: (key & 0xFFFF) as u16,
            })
            .collect();
        scored.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
        scored.truncate(self.k);
        let results = scored;

        debug!(
            "BmpExecutor completed: blocks_processed={}, blocks_skipped={}, accumulators={}, returned={}, top_score={:.4}",
            blocks_processed,
            blocks_skipped,
            num_accumulators,
            results.len(),
            results.first().map(|r| r.score).unwrap_or(0.0)
        );

        Ok(results)
    }
}

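/// Block-max MaxScore executor that decodes sparse posting blocks lazily.
///
/// Works like [`BlockMaxScoreExecutor`], but each term is driven by a
/// [`LazyTermCursor`] that only fetches and decodes a block from the sparse
/// index when the document-level checks actually need its contents, so
/// skipped blocks are never loaded.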
pub struct LazyBlockMaxScoreExecutor<'a> {
    sparse_index: &'a crate::segment::SparseIndex,
    cursors: Vec<LazyTermCursor>,
    prefix_sums: Vec<f32>,
    collector: ScoreCollector,
    heap_factor: f32,
}

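/// Cursor over one dimension's posting blocks with lazy block decoding.
///
/// While `block_loaded` is false the cursor positions itself using only the
/// skip entries (`first_doc`, `last_doc`, `max_weight`); the block body is
/// decoded on demand by `ensure_block_loaded`. `exhausted` is set once the
/// last block has been consumed.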
struct LazyTermCursor {
    dim_id: u32,
    query_weight: f32,
    max_score: f32,
    skip_entries: Vec<crate::structures::SparseSkipEntry>,
    block_idx: usize,
    doc_ids: Vec<u32>,
    ordinals: Vec<u16>,
    weights: Vec<f32>,
    pos: usize,
    block_loaded: bool,
    exhausted: bool,
}

impl LazyTermCursor {
    fn new(
        dim_id: u32,
        query_weight: f32,
        skip_entries: Vec<crate::structures::SparseSkipEntry>,
        global_max_weight: f32,
    ) -> Self {
        let exhausted = skip_entries.is_empty();
        Self {
            dim_id,
            query_weight,
            max_score: query_weight.abs() * global_max_weight,
            skip_entries,
            block_idx: 0,
            doc_ids: Vec::new(),
            ordinals: Vec::new(),
            weights: Vec::new(),
            pos: 0,
            block_loaded: false,
            exhausted,
        }
    }

    async fn ensure_block_loaded(
        &mut self,
        sparse_index: &crate::segment::SparseIndex,
    ) -> crate::Result<bool> {
        if self.exhausted || self.block_loaded {
            return Ok(!self.exhausted);
        }
        match sparse_index.get_block(self.dim_id, self.block_idx).await? {
            Some(block) => {
                block.decode_doc_ids_into(&mut self.doc_ids);
                block.decode_ordinals_into(&mut self.ordinals);
                block.decode_scored_weights_into(self.query_weight, &mut self.weights);
                self.pos = 0;
                self.block_loaded = true;
                Ok(true)
            }
            None => {
                self.exhausted = true;
                Ok(false)
            }
        }
    }

    #[inline]
    fn doc(&self) -> DocId {
        if self.exhausted {
            return u32::MAX;
        }
        if !self.block_loaded {
            return self.skip_entries[self.block_idx].first_doc;
        }
        self.doc_ids.get(self.pos).copied().unwrap_or(u32::MAX)
    }

    #[inline]
    fn ordinal(&self) -> u16 {
        if !self.block_loaded {
            return 0;
        }
        self.ordinals.get(self.pos).copied().unwrap_or(0)
    }

    #[inline]
    fn score(&self) -> f32 {
        if !self.block_loaded {
            return 0.0;
        }
        self.weights.get(self.pos).copied().unwrap_or(0.0)
    }

    #[inline]
    fn current_block_max_score(&self) -> f32 {
        if self.exhausted {
            return 0.0;
        }
        self.query_weight.abs()
            * self
                .skip_entries
                .get(self.block_idx)
                .map(|e| e.max_weight)
                .unwrap_or(0.0)
    }

    async fn advance(
        &mut self,
        sparse_index: &crate::segment::SparseIndex,
    ) -> crate::Result<DocId> {
        if self.exhausted {
            return Ok(u32::MAX);
        }
        self.ensure_block_loaded(sparse_index).await?;
        if self.exhausted {
            return Ok(u32::MAX);
        }
        self.pos += 1;
        if self.pos >= self.doc_ids.len() {
            self.block_idx += 1;
            self.block_loaded = false;
            if self.block_idx >= self.skip_entries.len() {
                self.exhausted = true;
                return Ok(u32::MAX);
            }
        }
        Ok(self.doc())
    }

    async fn seek(
        &mut self,
        sparse_index: &crate::segment::SparseIndex,
        target: DocId,
    ) -> crate::Result<DocId> {
        if self.exhausted {
            return Ok(u32::MAX);
        }

        if self.block_loaded
            && let Some(&last) = self.doc_ids.last()
        {
            if last >= target && self.doc_ids[self.pos] < target {
                let remaining = &self.doc_ids[self.pos..];
                let offset = crate::structures::simd::find_first_ge_u32(remaining, target);
                self.pos += offset;
                if self.pos >= self.doc_ids.len() {
                    self.block_idx += 1;
                    self.block_loaded = false;
                    if self.block_idx >= self.skip_entries.len() {
                        self.exhausted = true;
                        return Ok(u32::MAX);
                    }
                }
                return Ok(self.doc());
            }
            if self.doc_ids[self.pos] >= target {
                return Ok(self.doc());
            }
        }

        let bi = self.skip_entries.iter().position(|e| e.last_doc >= target);
        match bi {
            Some(idx) => {
                if idx != self.block_idx || !self.block_loaded {
                    self.block_idx = idx;
                    self.block_loaded = false;
                }
                self.ensure_block_loaded(sparse_index).await?;
                if self.exhausted {
                    return Ok(u32::MAX);
                }
                let offset = crate::structures::simd::find_first_ge_u32(&self.doc_ids, target);
                self.pos = offset;
                if self.pos >= self.doc_ids.len() {
                    self.block_idx += 1;
                    self.block_loaded = false;
                    if self.block_idx >= self.skip_entries.len() {
                        self.exhausted = true;
                        return Ok(u32::MAX);
                    }
                    self.ensure_block_loaded(sparse_index).await?;
                }
                Ok(self.doc())
            }
            None => {
                self.exhausted = true;
                Ok(u32::MAX)
            }
        }
    }

    fn skip_to_next_block(&mut self) -> DocId {
        if self.exhausted {
            return u32::MAX;
        }
        self.block_idx += 1;
        self.block_loaded = false;
        if self.block_idx >= self.skip_entries.len() {
            self.exhausted = true;
            return u32::MAX;
        }
        self.skip_entries[self.block_idx].first_doc
    }
}

impl<'a> LazyBlockMaxScoreExecutor<'a> {
    pub fn new(
        sparse_index: &'a crate::segment::SparseIndex,
        query_terms: Vec<(u32, f32)>,
        k: usize,
        heap_factor: f32,
    ) -> Self {
        let mut cursors: Vec<LazyTermCursor> = query_terms
            .iter()
            .filter_map(|&(dim_id, qw)| {
                let (skip_entries, global_max) = sparse_index.get_skip_list(dim_id)?;
                Some(LazyTermCursor::new(dim_id, qw, skip_entries, global_max))
            })
            .collect();

        cursors.sort_by(|a, b| {
            a.max_score
                .partial_cmp(&b.max_score)
                .unwrap_or(Ordering::Equal)
        });

        let mut prefix_sums = Vec::with_capacity(cursors.len());
        let mut cumsum = 0.0f32;
        for c in &cursors {
            cumsum += c.max_score;
            prefix_sums.push(cumsum);
        }

        debug!(
            "Creating LazyBlockMaxScoreExecutor: num_terms={}, k={}, total_upper={:.4}, heap_factor={:.2}",
            cursors.len(),
            k,
            cumsum,
            heap_factor
        );

        Self {
            sparse_index,
            cursors,
            prefix_sums,
            collector: ScoreCollector::new(k),
            heap_factor: heap_factor.clamp(0.0, 1.0),
        }
    }

    #[inline]
    fn find_partition(&self) -> usize {
        let threshold = self.collector.threshold() * self.heap_factor;
        self.prefix_sums
            .iter()
            .position(|&sum| sum > threshold)
            .unwrap_or(self.cursors.len())
    }

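    /// Runs the query to completion and returns the top-k documents, best first.
    ///
    /// Mirrors [`BlockMaxScoreExecutor::execute`]: pick the pivot among the
    /// essential cursors, then try the global-upper-bound skip, the block-max
    /// skip, and the exact essential-score check before seeking the
    /// non-essential cursors. Because the cursors are lazy, block bodies are
    /// only decoded when a check fails to prune the pivot.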
    pub async fn execute(mut self) -> crate::Result<Vec<ScoredDoc>> {
        if self.cursors.is_empty() {
            return Ok(Vec::new());
        }

        let n = self.cursors.len();
        let si = self.sparse_index;

        for cursor in &mut self.cursors {
            cursor.ensure_block_loaded(si).await?;
        }

        let mut docs_scored = 0u64;
        let mut docs_skipped = 0u64;
        let mut blocks_skipped = 0u64;
        let mut blocks_loaded = 0u64;
        let mut conjunction_skipped = 0u64;
        let mut ordinal_scores: Vec<(u16, f32)> = Vec::with_capacity(n * 2);

        loop {
            let partition = self.find_partition();
            if partition >= n {
                break;
            }

            let mut min_doc = u32::MAX;
            for i in partition..n {
                let doc = self.cursors[i].doc();
                if doc < min_doc {
                    min_doc = doc;
                }
            }
            if min_doc == u32::MAX {
                break;
            }

            let non_essential_upper = if partition > 0 {
                self.prefix_sums[partition - 1]
            } else {
                0.0
            };
            let adjusted_threshold = self.collector.threshold() * self.heap_factor;

            if self.collector.len() >= self.collector.k {
                let present_upper: f32 = (partition..n)
                    .filter(|&i| self.cursors[i].doc() == min_doc)
                    .map(|i| self.cursors[i].max_score)
                    .sum();

                if present_upper + non_essential_upper <= adjusted_threshold {
                    for i in partition..n {
                        if self.cursors[i].doc() == min_doc {
                            self.cursors[i].advance(si).await?;
                            blocks_loaded += u64::from(self.cursors[i].block_loaded);
                        }
                    }
                    conjunction_skipped += 1;
                    continue;
                }
            }

            if self.collector.len() >= self.collector.k {
                let block_max_sum: f32 = (partition..n)
                    .filter(|&i| self.cursors[i].doc() == min_doc)
                    .map(|i| self.cursors[i].current_block_max_score())
                    .sum();

                if block_max_sum + non_essential_upper <= adjusted_threshold {
                    for i in partition..n {
                        if self.cursors[i].doc() == min_doc {
                            self.cursors[i].skip_to_next_block();
                            self.cursors[i].ensure_block_loaded(si).await?;
                            blocks_loaded += 1;
                        }
                    }
                    blocks_skipped += 1;
                    continue;
                }
            }

            ordinal_scores.clear();
            for i in partition..n {
                if self.cursors[i].doc() == min_doc {
                    while self.cursors[i].doc() == min_doc {
                        ordinal_scores.push((self.cursors[i].ordinal(), self.cursors[i].score()));
                        self.cursors[i].advance(si).await?;
                    }
                }
            }

            let essential_total: f32 = ordinal_scores.iter().map(|(_, s)| *s).sum();
            if self.collector.len() >= self.collector.k
                && essential_total + non_essential_upper <= adjusted_threshold
            {
                docs_skipped += 1;
                continue;
            }

            for i in 0..partition {
                let doc = self.cursors[i].seek(si, min_doc).await?;
                if doc == min_doc {
                    while self.cursors[i].doc() == min_doc {
                        ordinal_scores.push((self.cursors[i].ordinal(), self.cursors[i].score()));
                        self.cursors[i].advance(si).await?;
                    }
                }
            }

            ordinal_scores.sort_unstable_by_key(|(ord, _)| *ord);
            let mut j = 0;
            while j < ordinal_scores.len() {
                let current_ord = ordinal_scores[j].0;
                let mut score = 0.0f32;
                while j < ordinal_scores.len() && ordinal_scores[j].0 == current_ord {
                    score += ordinal_scores[j].1;
                    j += 1;
                }
                if self
                    .collector
                    .insert_with_ordinal(min_doc, score, current_ord)
                {
                    docs_scored += 1;
                } else {
                    docs_skipped += 1;
                }
            }
        }

        let results: Vec<ScoredDoc> = self
            .collector
            .into_sorted_results()
            .into_iter()
            .map(|(doc_id, score, ordinal)| ScoredDoc {
                doc_id,
                score,
                ordinal,
            })
            .collect();

        debug!(
            "LazyBlockMaxScoreExecutor completed: scored={}, skipped={}, blocks_skipped={}, blocks_loaded={}, conjunction_skipped={}, returned={}, top_score={:.4}",
            docs_scored,
            docs_skipped,
            blocks_skipped,
            blocks_loaded,
            conjunction_skipped,
            results.len(),
            results.first().map(|r| r.score).unwrap_or(0.0)
        );

        Ok(results)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_score_collector_basic() {
        let mut collector = ScoreCollector::new(3);

        collector.insert(1, 1.0);
        collector.insert(2, 2.0);
        collector.insert(3, 3.0);
        assert_eq!(collector.threshold(), 1.0);

        collector.insert(4, 4.0);
        assert_eq!(collector.threshold(), 2.0);

        let results = collector.into_sorted_results();
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].0, 4);
        assert_eq!(results[1].0, 3);
        assert_eq!(results[2].0, 2);
    }

    #[test]
    fn test_score_collector_threshold() {
        let mut collector = ScoreCollector::new(2);

        collector.insert(1, 5.0);
        collector.insert(2, 3.0);
        assert_eq!(collector.threshold(), 3.0);

        assert!(!collector.would_enter(2.0));
        assert!(!collector.insert(3, 2.0));

        assert!(collector.would_enter(4.0));
        assert!(collector.insert(4, 4.0));
        assert_eq!(collector.threshold(), 4.0);
    }

    #[test]
    fn test_heap_entry_ordering() {
        let mut heap = BinaryHeap::new();
        heap.push(HeapEntry {
            doc_id: 1,
            score: 3.0,
            ordinal: 0,
        });
        heap.push(HeapEntry {
            doc_id: 2,
            score: 1.0,
            ordinal: 0,
        });
        heap.push(HeapEntry {
            doc_id: 3,
            score: 2.0,
            ordinal: 0,
        });

        assert_eq!(heap.pop().unwrap().score, 1.0);
        assert_eq!(heap.pop().unwrap().score, 2.0);
        assert_eq!(heap.pop().unwrap().score, 3.0);
    }
}