1use std::cmp::Ordering;
10use std::collections::BinaryHeap;
11use std::sync::Arc;
12
13use log::{debug, trace};
14
15use crate::DocId;
16use crate::structures::BlockSparsePostingList;
17
18pub trait ScoringIterator {
22 fn doc(&self) -> DocId;
24
25 fn advance(&mut self) -> DocId;
27
28 fn seek(&mut self, target: DocId) -> DocId;
30
31 fn is_exhausted(&self) -> bool {
33 self.doc() == u32::MAX
34 }
35
36 fn score(&self) -> f32;
38
39 fn max_score(&self) -> f32;
41
42 fn current_block_max_score(&self) -> f32;
44
45 fn skip_to_next_block(&mut self) -> DocId {
49 self.advance()
50 }
51}
52
53#[derive(Clone, Copy)]
55pub struct HeapEntry {
56 pub doc_id: DocId,
57 pub score: f32,
58}
59
60impl PartialEq for HeapEntry {
61 fn eq(&self, other: &Self) -> bool {
62 self.score == other.score && self.doc_id == other.doc_id
63 }
64}
65
66impl Eq for HeapEntry {}
67
68impl Ord for HeapEntry {
69 fn cmp(&self, other: &Self) -> Ordering {
70 other
72 .score
73 .partial_cmp(&self.score)
74 .unwrap_or(Ordering::Equal)
75 .then_with(|| self.doc_id.cmp(&other.doc_id))
76 }
77}
78
79impl PartialOrd for HeapEntry {
80 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
81 Some(self.cmp(other))
82 }
83}
84
85pub struct ScoreCollector {
91 heap: BinaryHeap<HeapEntry>,
93 pub k: usize,
94}
95
96impl ScoreCollector {
97 pub fn new(k: usize) -> Self {
99 Self {
100 heap: BinaryHeap::with_capacity(k + 1),
101 k,
102 }
103 }
104
105 #[inline]
107 pub fn threshold(&self) -> f32 {
108 if self.heap.len() >= self.k {
109 self.heap.peek().map(|e| e.score).unwrap_or(0.0)
110 } else {
111 0.0
112 }
113 }
114
115 #[inline]
118 pub fn insert(&mut self, doc_id: DocId, score: f32) -> bool {
119 if self.heap.len() < self.k {
120 self.heap.push(HeapEntry { doc_id, score });
121 true
122 } else if score > self.threshold() {
123 self.heap.push(HeapEntry { doc_id, score });
124 self.heap.pop(); true
126 } else {
127 false
128 }
129 }
130
131 #[inline]
133 pub fn would_enter(&self, score: f32) -> bool {
134 self.heap.len() < self.k || score > self.threshold()
135 }
136
137 #[inline]
139 pub fn len(&self) -> usize {
140 self.heap.len()
141 }
142
143 #[inline]
145 pub fn is_empty(&self) -> bool {
146 self.heap.is_empty()
147 }
148
149 pub fn into_sorted_results(self) -> Vec<(DocId, f32)> {
151 let mut results: Vec<_> = self
152 .heap
153 .into_vec()
154 .into_iter()
155 .map(|e| (e.doc_id, e.score))
156 .collect();
157
158 results.sort_by(|a, b| {
160 b.1.partial_cmp(&a.1)
161 .unwrap_or(Ordering::Equal)
162 .then_with(|| a.0.cmp(&b.0))
163 });
164
165 results
166 }
167}
168
169#[derive(Debug, Clone, Copy)]
171pub struct ScoredDoc {
172 pub doc_id: DocId,
173 pub score: f32,
174}
175
176pub struct WandExecutor<S: ScoringIterator> {
184 scorers: Vec<S>,
186 collector: ScoreCollector,
188 heap_factor: f32,
193}
194
195impl<S: ScoringIterator> WandExecutor<S> {
196 pub fn new(scorers: Vec<S>, k: usize) -> Self {
198 Self::with_heap_factor(scorers, k, 1.0)
199 }
200
201 pub fn with_heap_factor(scorers: Vec<S>, k: usize, heap_factor: f32) -> Self {
208 let total_upper: f32 = scorers.iter().map(|s| s.max_score()).sum();
209
210 debug!(
211 "Creating WandExecutor: num_scorers={}, k={}, total_upper={:.4}, heap_factor={:.2}",
212 scorers.len(),
213 k,
214 total_upper,
215 heap_factor
216 );
217
218 Self {
219 scorers,
220 collector: ScoreCollector::new(k),
221 heap_factor: heap_factor.clamp(0.0, 1.0),
222 }
223 }
224
225 pub fn execute(mut self) -> Vec<ScoredDoc> {
240 if self.scorers.is_empty() {
241 debug!("WandExecutor: no scorers, returning empty results");
242 return Vec::new();
243 }
244
245 let mut docs_scored = 0u64;
246 let mut docs_skipped = 0u64;
247 let num_scorers = self.scorers.len();
248
249 let mut sorted_indices: Vec<usize> = (0..num_scorers).collect();
251 sorted_indices.sort_by_key(|&i| self.scorers[i].doc());
252
253 loop {
254 let first_active = sorted_indices
256 .iter()
257 .position(|&i| self.scorers[i].doc() != u32::MAX);
258
259 let first_active = match first_active {
260 Some(pos) => pos,
261 None => break, };
263
264 let total_upper: f32 = sorted_indices[first_active..]
267 .iter()
268 .map(|&i| self.scorers[i].max_score())
269 .sum();
270
271 let adjusted_threshold = self.collector.threshold() * self.heap_factor;
272 if self.collector.len() >= self.collector.k && total_upper <= adjusted_threshold {
273 debug!(
274 "Early termination: upper_bound={:.4} <= adjusted_threshold={:.4}",
275 total_upper, adjusted_threshold
276 );
277 break;
278 }
279
280 let mut cumsum = 0.0f32;
282 let mut pivot_pos = first_active;
283
284 for (pos, &idx) in sorted_indices.iter().enumerate().skip(first_active) {
285 cumsum += self.scorers[idx].max_score();
286 if cumsum > adjusted_threshold || self.collector.len() < self.collector.k {
287 pivot_pos = pos;
288 break;
289 }
290 }
291
292 let pivot_idx = sorted_indices[pivot_pos];
293 let pivot_doc = self.scorers[pivot_idx].doc();
294
295 if pivot_doc == u32::MAX {
296 break;
297 }
298
299 let all_at_pivot = sorted_indices[first_active..=pivot_pos]
301 .iter()
302 .all(|&i| self.scorers[i].doc() == pivot_doc);
303
304 if all_at_pivot {
305 let mut score = 0.0f32;
307 let mut matching_terms = 0u32;
308
309 let mut modified_positions: Vec<usize> = Vec::new();
312
313 for (pos, &idx) in sorted_indices.iter().enumerate().skip(first_active) {
314 let doc = self.scorers[idx].doc();
315 if doc == pivot_doc {
316 score += self.scorers[idx].score();
317 matching_terms += 1;
318 self.scorers[idx].advance();
319 modified_positions.push(pos);
320 } else if doc > pivot_doc {
321 break;
322 }
323 }
324
325 trace!(
326 "Doc {}: score={:.4}, matching={}/{}, threshold={:.4}",
327 pivot_doc, score, matching_terms, num_scorers, adjusted_threshold
328 );
329
330 if self.collector.insert(pivot_doc, score) {
331 docs_scored += 1;
332 } else {
333 docs_skipped += 1;
334 }
335
336 for &pos in modified_positions.iter().rev() {
339 let idx = sorted_indices[pos];
340 let new_doc = self.scorers[idx].doc();
341 let mut curr = pos;
343 while curr + 1 < sorted_indices.len()
344 && self.scorers[sorted_indices[curr + 1]].doc() < new_doc
345 {
346 sorted_indices.swap(curr, curr + 1);
347 curr += 1;
348 }
349 }
350 } else {
351 let first_pos = first_active;
353 let first_idx = sorted_indices[first_pos];
354 self.scorers[first_idx].seek(pivot_doc);
355 docs_skipped += 1;
356
357 let new_doc = self.scorers[first_idx].doc();
359 let mut curr = first_pos;
360 while curr + 1 < sorted_indices.len()
361 && self.scorers[sorted_indices[curr + 1]].doc() < new_doc
362 {
363 sorted_indices.swap(curr, curr + 1);
364 curr += 1;
365 }
366 }
367 }
368
369 let results: Vec<ScoredDoc> = self
370 .collector
371 .into_sorted_results()
372 .into_iter()
373 .map(|(doc_id, score)| ScoredDoc { doc_id, score })
374 .collect();
375
376 debug!(
377 "WandExecutor completed: scored={}, skipped={}, returned={}, top_score={:.4}",
378 docs_scored,
379 docs_skipped,
380 results.len(),
381 results.first().map(|r| r.score).unwrap_or(0.0)
382 );
383
384 results
385 }
386}
387
388pub struct TextTermScorer {
393 iter: crate::structures::BlockPostingIterator<'static>,
395 idf: f32,
397 avg_field_len: f32,
399 max_score: f32,
401}
402
403impl TextTermScorer {
404 pub fn new(
406 posting_list: crate::structures::BlockPostingList,
407 idf: f32,
408 avg_field_len: f32,
409 ) -> Self {
410 let max_tf = posting_list.max_tf() as f32;
412 let doc_count = posting_list.doc_count();
413 let max_score = super::bm25_upper_bound(max_tf.max(1.0), idf);
414
415 debug!(
416 "Created TextTermScorer: doc_count={}, max_tf={:.0}, idf={:.4}, avg_field_len={:.2}, max_score={:.4}",
417 doc_count, max_tf, idf, avg_field_len, max_score
418 );
419
420 Self {
421 iter: posting_list.into_iterator(),
422 idf,
423 avg_field_len,
424 max_score,
425 }
426 }
427}
428
429impl ScoringIterator for TextTermScorer {
430 #[inline]
431 fn doc(&self) -> DocId {
432 self.iter.doc()
433 }
434
435 #[inline]
436 fn advance(&mut self) -> DocId {
437 self.iter.advance()
438 }
439
440 #[inline]
441 fn seek(&mut self, target: DocId) -> DocId {
442 self.iter.seek(target)
443 }
444
445 #[inline]
446 fn score(&self) -> f32 {
447 let tf = self.iter.term_freq() as f32;
448 super::bm25_score(tf, self.idf, tf, self.avg_field_len)
450 }
451
452 #[inline]
453 fn max_score(&self) -> f32 {
454 self.max_score
455 }
456
457 #[inline]
458 fn current_block_max_score(&self) -> f32 {
459 let block_max_tf = self.iter.current_block_max_tf() as f32;
461 super::bm25_upper_bound(block_max_tf.max(1.0), self.idf)
462 }
463
464 #[inline]
465 fn skip_to_next_block(&mut self) -> DocId {
466 self.iter.skip_to_next_block()
467 }
468}
469
470pub struct SparseTermScorer<'a> {
474 iter: crate::structures::BlockSparsePostingIterator<'a>,
476 query_weight: f32,
478 max_score: f32,
480}
481
482impl<'a> SparseTermScorer<'a> {
483 pub fn new(posting_list: &'a BlockSparsePostingList, query_weight: f32) -> Self {
485 let max_score = query_weight * posting_list.global_max_weight();
486 Self {
487 iter: posting_list.iterator(),
488 query_weight,
489 max_score,
490 }
491 }
492
493 pub fn from_arc(posting_list: &'a Arc<BlockSparsePostingList>, query_weight: f32) -> Self {
495 Self::new(posting_list.as_ref(), query_weight)
496 }
497}
498
499impl ScoringIterator for SparseTermScorer<'_> {
500 #[inline]
501 fn doc(&self) -> DocId {
502 self.iter.doc()
503 }
504
505 #[inline]
506 fn advance(&mut self) -> DocId {
507 self.iter.advance()
508 }
509
510 #[inline]
511 fn seek(&mut self, target: DocId) -> DocId {
512 self.iter.seek(target)
513 }
514
515 #[inline]
516 fn score(&self) -> f32 {
517 self.query_weight * self.iter.weight()
519 }
520
521 #[inline]
522 fn max_score(&self) -> f32 {
523 self.max_score
524 }
525
526 #[inline]
527 fn current_block_max_score(&self) -> f32 {
528 self.iter.current_block_max_contribution(self.query_weight)
529 }
530}
531
#[cfg(test)]
mod tests {
    use super::*;

    /// Filling the collector and then offering a better document evicts the
    /// weakest entry; results come back best-first.
    #[test]
    fn test_score_collector_basic() {
        let mut collector = ScoreCollector::new(3);

        for (doc_id, score) in [(1, 1.0), (2, 2.0), (3, 3.0)] {
            collector.insert(doc_id, score);
        }
        assert_eq!(collector.threshold(), 1.0);

        collector.insert(4, 4.0);
        assert_eq!(collector.threshold(), 2.0);

        let results = collector.into_sorted_results();
        assert_eq!(results.len(), 3);
        for (rank, expected_doc) in [4, 3, 2].into_iter().enumerate() {
            assert_eq!(results[rank].0, expected_doc);
        }
    }

    /// `would_enter` and `insert` agree on which scores beat the threshold.
    #[test]
    fn test_score_collector_threshold() {
        let mut collector = ScoreCollector::new(2);

        collector.insert(1, 5.0);
        collector.insert(2, 3.0);
        assert_eq!(collector.threshold(), 3.0);

        // Below the threshold: rejected.
        assert!(!collector.would_enter(2.0));
        assert!(!collector.insert(3, 2.0));

        // Above the threshold: accepted, and the threshold rises.
        assert!(collector.would_enter(4.0));
        assert!(collector.insert(4, 4.0));
        assert_eq!(collector.threshold(), 4.0);
    }

    /// `HeapEntry`'s inverted ordering makes `BinaryHeap` pop the lowest score
    /// first.
    #[test]
    fn test_heap_entry_ordering() {
        let mut heap = BinaryHeap::new();
        for (doc_id, score) in [(1, 3.0), (2, 1.0), (3, 2.0)] {
            heap.push(HeapEntry { doc_id, score });
        }

        for expected_score in [1.0, 2.0, 3.0] {
            assert_eq!(heap.pop().unwrap().score, expected_score);
        }
    }
}