1use std::cmp::Ordering;
10use std::collections::BinaryHeap;
11use std::sync::Arc;
12
13use log::{debug, trace};
14
15use crate::DocId;
16use crate::structures::BlockSparsePostingList;
17
18pub trait ScoringIterator {
22 fn doc(&self) -> DocId;
24
25 fn ordinal(&self) -> u16 {
27 0
28 }
29
30 fn advance(&mut self) -> DocId;
32
33 fn seek(&mut self, target: DocId) -> DocId;
35
36 fn is_exhausted(&self) -> bool {
38 self.doc() == u32::MAX
39 }
40
41 fn score(&self) -> f32;
43
44 fn max_score(&self) -> f32;
46
47 fn current_block_max_score(&self) -> f32;
49
50 fn skip_to_next_block(&mut self) -> DocId {
54 self.advance()
55 }
56}
57
58#[derive(Clone, Copy)]
60pub struct HeapEntry {
61 pub doc_id: DocId,
62 pub score: f32,
63 pub ordinal: u16,
64}
65
66impl PartialEq for HeapEntry {
67 fn eq(&self, other: &Self) -> bool {
68 self.score == other.score && self.doc_id == other.doc_id
69 }
70}
71
72impl Eq for HeapEntry {}
73
74impl Ord for HeapEntry {
75 fn cmp(&self, other: &Self) -> Ordering {
76 other
78 .score
79 .partial_cmp(&self.score)
80 .unwrap_or(Ordering::Equal)
81 .then_with(|| self.doc_id.cmp(&other.doc_id))
82 }
83}
84
85impl PartialOrd for HeapEntry {
86 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
87 Some(self.cmp(other))
88 }
89}
90
91pub struct ScoreCollector {
97 heap: BinaryHeap<HeapEntry>,
99 pub k: usize,
100}
101
102impl ScoreCollector {
103 pub fn new(k: usize) -> Self {
105 let capacity = k.saturating_add(1).min(1_000_000);
107 Self {
108 heap: BinaryHeap::with_capacity(capacity),
109 k,
110 }
111 }
112
113 #[inline]
115 pub fn threshold(&self) -> f32 {
116 if self.heap.len() >= self.k {
117 self.heap.peek().map(|e| e.score).unwrap_or(0.0)
118 } else {
119 0.0
120 }
121 }
122
123 #[inline]
126 pub fn insert(&mut self, doc_id: DocId, score: f32) -> bool {
127 self.insert_with_ordinal(doc_id, score, 0)
128 }
129
130 #[inline]
133 pub fn insert_with_ordinal(&mut self, doc_id: DocId, score: f32, ordinal: u16) -> bool {
134 if self.heap.len() < self.k {
135 self.heap.push(HeapEntry {
136 doc_id,
137 score,
138 ordinal,
139 });
140 true
141 } else if score > self.threshold() {
142 self.heap.push(HeapEntry {
143 doc_id,
144 score,
145 ordinal,
146 });
147 self.heap.pop(); true
149 } else {
150 false
151 }
152 }
153
154 #[inline]
156 pub fn would_enter(&self, score: f32) -> bool {
157 self.heap.len() < self.k || score > self.threshold()
158 }
159
160 #[inline]
162 pub fn len(&self) -> usize {
163 self.heap.len()
164 }
165
166 #[inline]
168 pub fn is_empty(&self) -> bool {
169 self.heap.is_empty()
170 }
171
172 pub fn into_sorted_results(self) -> Vec<(DocId, f32, u16)> {
174 let mut results: Vec<_> = self
175 .heap
176 .into_vec()
177 .into_iter()
178 .map(|e| (e.doc_id, e.score, e.ordinal))
179 .collect();
180
181 results.sort_by(|a, b| {
183 b.1.partial_cmp(&a.1)
184 .unwrap_or(Ordering::Equal)
185 .then_with(|| a.0.cmp(&b.0))
186 });
187
188 results
189 }
190}
191
/// A single hit returned by [`WandExecutor::execute`].
#[derive(Debug, Clone, Copy)]
pub struct ScoredDoc {
    /// Identifier of the matching document.
    pub doc_id: DocId,
    /// Sum of the scores of all scorers that were positioned on the document.
    pub score: f32,
    /// Ordinal reported by the first scorer that matched the document.
    pub ordinal: u16,
}
200
201pub struct WandExecutor<S: ScoringIterator> {
209 scorers: Vec<S>,
211 collector: ScoreCollector,
213 heap_factor: f32,
218}
219
220impl<S: ScoringIterator> WandExecutor<S> {
221 pub fn new(scorers: Vec<S>, k: usize) -> Self {
223 Self::with_heap_factor(scorers, k, 1.0)
224 }
225
226 pub fn with_heap_factor(scorers: Vec<S>, k: usize, heap_factor: f32) -> Self {
233 let total_upper: f32 = scorers.iter().map(|s| s.max_score()).sum();
234
235 debug!(
236 "Creating WandExecutor: num_scorers={}, k={}, total_upper={:.4}, heap_factor={:.2}",
237 scorers.len(),
238 k,
239 total_upper,
240 heap_factor
241 );
242
243 Self {
244 scorers,
245 collector: ScoreCollector::new(k),
246 heap_factor: heap_factor.clamp(0.0, 1.0),
247 }
248 }
249
250 pub fn execute(mut self) -> Vec<ScoredDoc> {
265 if self.scorers.is_empty() {
266 debug!("WandExecutor: no scorers, returning empty results");
267 return Vec::new();
268 }
269
270 let mut docs_scored = 0u64;
271 let mut docs_skipped = 0u64;
272 let mut blocks_skipped = 0u64;
273 let num_scorers = self.scorers.len();
274
275 let mut sorted_indices: Vec<usize> = (0..num_scorers).collect();
277 sorted_indices.sort_by_key(|&i| self.scorers[i].doc());
278
279 loop {
280 let first_active = sorted_indices
282 .iter()
283 .position(|&i| self.scorers[i].doc() != u32::MAX);
284
285 let first_active = match first_active {
286 Some(pos) => pos,
287 None => break, };
289
290 let total_upper: f32 = sorted_indices[first_active..]
293 .iter()
294 .map(|&i| self.scorers[i].max_score())
295 .sum();
296
297 let adjusted_threshold = self.collector.threshold() * self.heap_factor;
298 if self.collector.len() >= self.collector.k && total_upper <= adjusted_threshold {
299 debug!(
300 "Early termination: upper_bound={:.4} <= adjusted_threshold={:.4}",
301 total_upper, adjusted_threshold
302 );
303 break;
304 }
305
306 let mut cumsum = 0.0f32;
308 let mut pivot_pos = first_active;
309
310 for (pos, &idx) in sorted_indices.iter().enumerate().skip(first_active) {
311 cumsum += self.scorers[idx].max_score();
312 if cumsum > adjusted_threshold || self.collector.len() < self.collector.k {
313 pivot_pos = pos;
314 break;
315 }
316 }
317
318 let pivot_idx = sorted_indices[pivot_pos];
319 let pivot_doc = self.scorers[pivot_idx].doc();
320
321 if pivot_doc == u32::MAX {
322 break;
323 }
324
325 let all_at_pivot = sorted_indices[first_active..=pivot_pos]
327 .iter()
328 .all(|&i| self.scorers[i].doc() == pivot_doc);
329
330 if all_at_pivot {
331 let block_max_sum: f32 = sorted_indices[first_active..=pivot_pos]
334 .iter()
335 .filter(|&&i| self.scorers[i].doc() == pivot_doc)
336 .map(|&i| self.scorers[i].current_block_max_score())
337 .sum();
338
339 if self.collector.len() >= self.collector.k && block_max_sum <= adjusted_threshold {
340 debug!(
342 "Block skip at doc {}: block_max={:.4} <= threshold={:.4}",
343 pivot_doc, block_max_sum, adjusted_threshold
344 );
345
346 for (_pos, &idx) in sorted_indices.iter().enumerate().skip(first_active) {
347 if self.scorers[idx].doc() == pivot_doc {
348 self.scorers[idx].skip_to_next_block();
349 } else if self.scorers[idx].doc() > pivot_doc {
350 break;
351 }
352 }
353
354 sorted_indices[first_active..].sort_by_key(|&i| self.scorers[i].doc());
356 blocks_skipped += 1;
357 continue;
358 }
359
360 let mut score = 0.0f32;
362 let mut matching_terms = 0u32;
363 let mut ordinal: u16 = 0;
364
365 let mut modified_positions: Vec<usize> = Vec::new();
368
369 for (pos, &idx) in sorted_indices.iter().enumerate().skip(first_active) {
370 let doc = self.scorers[idx].doc();
371 if doc == pivot_doc {
372 score += self.scorers[idx].score();
373 if matching_terms == 0 {
375 ordinal = self.scorers[idx].ordinal();
376 }
377 matching_terms += 1;
378 self.scorers[idx].advance();
379 modified_positions.push(pos);
380 } else if doc > pivot_doc {
381 break;
382 }
383 }
384
385 trace!(
386 "Doc {}: score={:.4}, matching={}/{}, threshold={:.4}",
387 pivot_doc, score, matching_terms, num_scorers, adjusted_threshold
388 );
389
390 if self
391 .collector
392 .insert_with_ordinal(pivot_doc, score, ordinal)
393 {
394 docs_scored += 1;
395 } else {
396 docs_skipped += 1;
397 }
398
399 for &pos in modified_positions.iter().rev() {
402 let idx = sorted_indices[pos];
403 let new_doc = self.scorers[idx].doc();
404 let mut curr = pos;
406 while curr + 1 < sorted_indices.len()
407 && self.scorers[sorted_indices[curr + 1]].doc() < new_doc
408 {
409 sorted_indices.swap(curr, curr + 1);
410 curr += 1;
411 }
412 }
413 } else {
414 let first_pos = first_active;
416 let first_idx = sorted_indices[first_pos];
417 self.scorers[first_idx].seek(pivot_doc);
418 docs_skipped += 1;
419
420 let new_doc = self.scorers[first_idx].doc();
422 let mut curr = first_pos;
423 while curr + 1 < sorted_indices.len()
424 && self.scorers[sorted_indices[curr + 1]].doc() < new_doc
425 {
426 sorted_indices.swap(curr, curr + 1);
427 curr += 1;
428 }
429 }
430 }
431
432 let results: Vec<ScoredDoc> = self
433 .collector
434 .into_sorted_results()
435 .into_iter()
436 .map(|(doc_id, score, ordinal)| ScoredDoc {
437 doc_id,
438 score,
439 ordinal,
440 })
441 .collect();
442
443 debug!(
444 "WandExecutor completed: scored={}, skipped={}, blocks_skipped={}, returned={}, top_score={:.4}",
445 docs_scored,
446 docs_skipped,
447 blocks_skipped,
448 results.len(),
449 results.first().map(|r| r.score).unwrap_or(0.0)
450 );
451
452 results
453 }
454}
455
456pub struct TextTermScorer {
461 iter: crate::structures::BlockPostingIterator<'static>,
463 idf: f32,
465 avg_field_len: f32,
467 max_score: f32,
469}
470
471impl TextTermScorer {
472 pub fn new(
474 posting_list: crate::structures::BlockPostingList,
475 idf: f32,
476 avg_field_len: f32,
477 ) -> Self {
478 let max_tf = posting_list.max_tf() as f32;
480 let doc_count = posting_list.doc_count();
481 let max_score = super::bm25_upper_bound(max_tf.max(1.0), idf);
482
483 debug!(
484 "Created TextTermScorer: doc_count={}, max_tf={:.0}, idf={:.4}, avg_field_len={:.2}, max_score={:.4}",
485 doc_count, max_tf, idf, avg_field_len, max_score
486 );
487
488 Self {
489 iter: posting_list.into_iterator(),
490 idf,
491 avg_field_len,
492 max_score,
493 }
494 }
495}
496
497impl ScoringIterator for TextTermScorer {
498 #[inline]
499 fn doc(&self) -> DocId {
500 self.iter.doc()
501 }
502
503 #[inline]
504 fn advance(&mut self) -> DocId {
505 self.iter.advance()
506 }
507
508 #[inline]
509 fn seek(&mut self, target: DocId) -> DocId {
510 self.iter.seek(target)
511 }
512
513 #[inline]
514 fn score(&self) -> f32 {
515 let tf = self.iter.term_freq() as f32;
516 super::bm25_score(tf, self.idf, tf, self.avg_field_len)
518 }
519
520 #[inline]
521 fn max_score(&self) -> f32 {
522 self.max_score
523 }
524
525 #[inline]
526 fn current_block_max_score(&self) -> f32 {
527 let block_max_tf = self.iter.current_block_max_tf() as f32;
529 super::bm25_upper_bound(block_max_tf.max(1.0), self.idf)
530 }
531
532 #[inline]
533 fn skip_to_next_block(&mut self) -> DocId {
534 self.iter.skip_to_next_block()
535 }
536}
537
538pub struct SparseTermScorer<'a> {
542 iter: crate::structures::BlockSparsePostingIterator<'a>,
544 query_weight: f32,
546 max_score: f32,
548}
549
550impl<'a> SparseTermScorer<'a> {
551 pub fn new(posting_list: &'a BlockSparsePostingList, query_weight: f32) -> Self {
556 let max_score = query_weight.abs() * posting_list.global_max_weight();
559 Self {
560 iter: posting_list.iterator(),
561 query_weight,
562 max_score,
563 }
564 }
565
566 pub fn from_arc(posting_list: &'a Arc<BlockSparsePostingList>, query_weight: f32) -> Self {
568 Self::new(posting_list.as_ref(), query_weight)
569 }
570}
571
572impl ScoringIterator for SparseTermScorer<'_> {
573 #[inline]
574 fn doc(&self) -> DocId {
575 self.iter.doc()
576 }
577
578 #[inline]
579 fn ordinal(&self) -> u16 {
580 self.iter.ordinal()
581 }
582
583 #[inline]
584 fn advance(&mut self) -> DocId {
585 self.iter.advance()
586 }
587
588 #[inline]
589 fn seek(&mut self, target: DocId) -> DocId {
590 self.iter.seek(target)
591 }
592
593 #[inline]
594 fn score(&self) -> f32 {
595 self.query_weight * self.iter.weight()
597 }
598
599 #[inline]
600 fn max_score(&self) -> f32 {
601 self.max_score
602 }
603
604 #[inline]
605 fn current_block_max_score(&self) -> f32 {
606 self.iter
608 .current_block_max_contribution(self.query_weight.abs())
609 }
610}
611
#[cfg(test)]
mod tests {
    use super::*;

    /// The collector keeps the three best documents and reports them best-first.
    #[test]
    fn test_score_collector_basic() {
        let mut collector = ScoreCollector::new(3);

        for (doc_id, score) in [(1, 1.0), (2, 2.0), (3, 3.0)] {
            collector.insert(doc_id, score);
        }
        // Full: the weakest retained score is the threshold.
        assert_eq!(collector.threshold(), 1.0);

        // A better document evicts the weakest one.
        collector.insert(4, 4.0);
        assert_eq!(collector.threshold(), 2.0);

        let results = collector.into_sorted_results();
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].0, 4);
        assert_eq!(results[1].0, 3);
        assert_eq!(results[2].0, 2);
    }

    /// `would_enter` and `insert` agree on what beats the threshold.
    #[test]
    fn test_score_collector_threshold() {
        let mut collector = ScoreCollector::new(2);

        collector.insert(1, 5.0);
        collector.insert(2, 3.0);
        assert_eq!(collector.threshold(), 3.0);

        // Below the threshold: rejected.
        assert!(!collector.would_enter(2.0));
        assert!(!collector.insert(3, 2.0));

        // Above the threshold: accepted, and the threshold rises.
        assert!(collector.would_enter(4.0));
        assert!(collector.insert(4, 4.0));
        assert_eq!(collector.threshold(), 4.0);
    }

    /// A `BinaryHeap` of `HeapEntry` pops the lowest score first.
    #[test]
    fn test_heap_entry_ordering() {
        let mut heap = BinaryHeap::new();
        for (doc_id, score) in [(1, 3.0), (2, 1.0), (3, 2.0)] {
            heap.push(HeapEntry {
                doc_id,
                score,
                ordinal: 0,
            });
        }

        for expected in [1.0, 2.0, 3.0] {
            assert_eq!(heap.pop().unwrap().score, expected);
        }
    }
}