1use std::cmp::Ordering;
10use std::collections::BinaryHeap;
11use std::sync::Arc;
12
13use log::{debug, trace};
14
15use crate::DocId;
16use crate::structures::BlockSparsePostingList;
17
18pub trait ScoringIterator {
22 fn doc(&self) -> DocId;
24
25 fn advance(&mut self) -> DocId;
27
28 fn seek(&mut self, target: DocId) -> DocId;
30
31 fn is_exhausted(&self) -> bool {
33 self.doc() == u32::MAX
34 }
35
36 fn score(&self) -> f32;
38
39 fn max_score(&self) -> f32;
41
42 fn current_block_max_score(&self) -> f32;
44
45 fn skip_to_next_block(&mut self) -> DocId {
49 self.advance()
50 }
51}
52
53#[derive(Clone, Copy)]
55pub struct HeapEntry {
56 pub doc_id: DocId,
57 pub score: f32,
58}
59
60impl PartialEq for HeapEntry {
61 fn eq(&self, other: &Self) -> bool {
62 self.score == other.score && self.doc_id == other.doc_id
63 }
64}
65
66impl Eq for HeapEntry {}
67
68impl Ord for HeapEntry {
69 fn cmp(&self, other: &Self) -> Ordering {
70 other
72 .score
73 .partial_cmp(&self.score)
74 .unwrap_or(Ordering::Equal)
75 .then_with(|| self.doc_id.cmp(&other.doc_id))
76 }
77}
78
79impl PartialOrd for HeapEntry {
80 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
81 Some(self.cmp(other))
82 }
83}
84
85pub struct ScoreCollector {
91 heap: BinaryHeap<HeapEntry>,
93 pub k: usize,
94}
95
96impl ScoreCollector {
97 pub fn new(k: usize) -> Self {
99 let capacity = k.saturating_add(1).min(1_000_000);
101 Self {
102 heap: BinaryHeap::with_capacity(capacity),
103 k,
104 }
105 }
106
107 #[inline]
109 pub fn threshold(&self) -> f32 {
110 if self.heap.len() >= self.k {
111 self.heap.peek().map(|e| e.score).unwrap_or(0.0)
112 } else {
113 0.0
114 }
115 }
116
117 #[inline]
120 pub fn insert(&mut self, doc_id: DocId, score: f32) -> bool {
121 if self.heap.len() < self.k {
122 self.heap.push(HeapEntry { doc_id, score });
123 true
124 } else if score > self.threshold() {
125 self.heap.push(HeapEntry { doc_id, score });
126 self.heap.pop(); true
128 } else {
129 false
130 }
131 }
132
133 #[inline]
135 pub fn would_enter(&self, score: f32) -> bool {
136 self.heap.len() < self.k || score > self.threshold()
137 }
138
139 #[inline]
141 pub fn len(&self) -> usize {
142 self.heap.len()
143 }
144
145 #[inline]
147 pub fn is_empty(&self) -> bool {
148 self.heap.is_empty()
149 }
150
151 pub fn into_sorted_results(self) -> Vec<(DocId, f32)> {
153 let mut results: Vec<_> = self
154 .heap
155 .into_vec()
156 .into_iter()
157 .map(|e| (e.doc_id, e.score))
158 .collect();
159
160 results.sort_by(|a, b| {
162 b.1.partial_cmp(&a.1)
163 .unwrap_or(Ordering::Equal)
164 .then_with(|| a.0.cmp(&b.0))
165 });
166
167 results
168 }
169}
170
171#[derive(Debug, Clone, Copy)]
173pub struct ScoredDoc {
174 pub doc_id: DocId,
175 pub score: f32,
176}
177
178pub struct WandExecutor<S: ScoringIterator> {
186 scorers: Vec<S>,
188 collector: ScoreCollector,
190 heap_factor: f32,
195}
196
197impl<S: ScoringIterator> WandExecutor<S> {
198 pub fn new(scorers: Vec<S>, k: usize) -> Self {
200 Self::with_heap_factor(scorers, k, 1.0)
201 }
202
203 pub fn with_heap_factor(scorers: Vec<S>, k: usize, heap_factor: f32) -> Self {
210 let total_upper: f32 = scorers.iter().map(|s| s.max_score()).sum();
211
212 debug!(
213 "Creating WandExecutor: num_scorers={}, k={}, total_upper={:.4}, heap_factor={:.2}",
214 scorers.len(),
215 k,
216 total_upper,
217 heap_factor
218 );
219
220 Self {
221 scorers,
222 collector: ScoreCollector::new(k),
223 heap_factor: heap_factor.clamp(0.0, 1.0),
224 }
225 }
226
227 pub fn execute(mut self) -> Vec<ScoredDoc> {
242 if self.scorers.is_empty() {
243 debug!("WandExecutor: no scorers, returning empty results");
244 return Vec::new();
245 }
246
247 let mut docs_scored = 0u64;
248 let mut docs_skipped = 0u64;
249 let num_scorers = self.scorers.len();
250
251 let mut sorted_indices: Vec<usize> = (0..num_scorers).collect();
253 sorted_indices.sort_by_key(|&i| self.scorers[i].doc());
254
255 loop {
256 let first_active = sorted_indices
258 .iter()
259 .position(|&i| self.scorers[i].doc() != u32::MAX);
260
261 let first_active = match first_active {
262 Some(pos) => pos,
263 None => break, };
265
266 let total_upper: f32 = sorted_indices[first_active..]
269 .iter()
270 .map(|&i| self.scorers[i].max_score())
271 .sum();
272
273 let adjusted_threshold = self.collector.threshold() * self.heap_factor;
274 if self.collector.len() >= self.collector.k && total_upper <= adjusted_threshold {
275 debug!(
276 "Early termination: upper_bound={:.4} <= adjusted_threshold={:.4}",
277 total_upper, adjusted_threshold
278 );
279 break;
280 }
281
282 let mut cumsum = 0.0f32;
284 let mut pivot_pos = first_active;
285
286 for (pos, &idx) in sorted_indices.iter().enumerate().skip(first_active) {
287 cumsum += self.scorers[idx].max_score();
288 if cumsum > adjusted_threshold || self.collector.len() < self.collector.k {
289 pivot_pos = pos;
290 break;
291 }
292 }
293
294 let pivot_idx = sorted_indices[pivot_pos];
295 let pivot_doc = self.scorers[pivot_idx].doc();
296
297 if pivot_doc == u32::MAX {
298 break;
299 }
300
301 let all_at_pivot = sorted_indices[first_active..=pivot_pos]
303 .iter()
304 .all(|&i| self.scorers[i].doc() == pivot_doc);
305
306 if all_at_pivot {
307 let mut score = 0.0f32;
309 let mut matching_terms = 0u32;
310
311 let mut modified_positions: Vec<usize> = Vec::new();
314
315 for (pos, &idx) in sorted_indices.iter().enumerate().skip(first_active) {
316 let doc = self.scorers[idx].doc();
317 if doc == pivot_doc {
318 score += self.scorers[idx].score();
319 matching_terms += 1;
320 self.scorers[idx].advance();
321 modified_positions.push(pos);
322 } else if doc > pivot_doc {
323 break;
324 }
325 }
326
327 trace!(
328 "Doc {}: score={:.4}, matching={}/{}, threshold={:.4}",
329 pivot_doc, score, matching_terms, num_scorers, adjusted_threshold
330 );
331
332 if self.collector.insert(pivot_doc, score) {
333 docs_scored += 1;
334 } else {
335 docs_skipped += 1;
336 }
337
338 for &pos in modified_positions.iter().rev() {
341 let idx = sorted_indices[pos];
342 let new_doc = self.scorers[idx].doc();
343 let mut curr = pos;
345 while curr + 1 < sorted_indices.len()
346 && self.scorers[sorted_indices[curr + 1]].doc() < new_doc
347 {
348 sorted_indices.swap(curr, curr + 1);
349 curr += 1;
350 }
351 }
352 } else {
353 let first_pos = first_active;
355 let first_idx = sorted_indices[first_pos];
356 self.scorers[first_idx].seek(pivot_doc);
357 docs_skipped += 1;
358
359 let new_doc = self.scorers[first_idx].doc();
361 let mut curr = first_pos;
362 while curr + 1 < sorted_indices.len()
363 && self.scorers[sorted_indices[curr + 1]].doc() < new_doc
364 {
365 sorted_indices.swap(curr, curr + 1);
366 curr += 1;
367 }
368 }
369 }
370
371 let results: Vec<ScoredDoc> = self
372 .collector
373 .into_sorted_results()
374 .into_iter()
375 .map(|(doc_id, score)| ScoredDoc { doc_id, score })
376 .collect();
377
378 debug!(
379 "WandExecutor completed: scored={}, skipped={}, returned={}, top_score={:.4}",
380 docs_scored,
381 docs_skipped,
382 results.len(),
383 results.first().map(|r| r.score).unwrap_or(0.0)
384 );
385
386 results
387 }
388}
389
390pub struct TextTermScorer {
395 iter: crate::structures::BlockPostingIterator<'static>,
397 idf: f32,
399 avg_field_len: f32,
401 max_score: f32,
403}
404
405impl TextTermScorer {
406 pub fn new(
408 posting_list: crate::structures::BlockPostingList,
409 idf: f32,
410 avg_field_len: f32,
411 ) -> Self {
412 let max_tf = posting_list.max_tf() as f32;
414 let doc_count = posting_list.doc_count();
415 let max_score = super::bm25_upper_bound(max_tf.max(1.0), idf);
416
417 debug!(
418 "Created TextTermScorer: doc_count={}, max_tf={:.0}, idf={:.4}, avg_field_len={:.2}, max_score={:.4}",
419 doc_count, max_tf, idf, avg_field_len, max_score
420 );
421
422 Self {
423 iter: posting_list.into_iterator(),
424 idf,
425 avg_field_len,
426 max_score,
427 }
428 }
429}
430
431impl ScoringIterator for TextTermScorer {
432 #[inline]
433 fn doc(&self) -> DocId {
434 self.iter.doc()
435 }
436
437 #[inline]
438 fn advance(&mut self) -> DocId {
439 self.iter.advance()
440 }
441
442 #[inline]
443 fn seek(&mut self, target: DocId) -> DocId {
444 self.iter.seek(target)
445 }
446
447 #[inline]
448 fn score(&self) -> f32 {
449 let tf = self.iter.term_freq() as f32;
450 super::bm25_score(tf, self.idf, tf, self.avg_field_len)
452 }
453
454 #[inline]
455 fn max_score(&self) -> f32 {
456 self.max_score
457 }
458
459 #[inline]
460 fn current_block_max_score(&self) -> f32 {
461 let block_max_tf = self.iter.current_block_max_tf() as f32;
463 super::bm25_upper_bound(block_max_tf.max(1.0), self.idf)
464 }
465
466 #[inline]
467 fn skip_to_next_block(&mut self) -> DocId {
468 self.iter.skip_to_next_block()
469 }
470}
471
472pub struct SparseTermScorer<'a> {
476 iter: crate::structures::BlockSparsePostingIterator<'a>,
478 query_weight: f32,
480 max_score: f32,
482}
483
484impl<'a> SparseTermScorer<'a> {
485 pub fn new(posting_list: &'a BlockSparsePostingList, query_weight: f32) -> Self {
487 let max_score = query_weight * posting_list.global_max_weight();
488 Self {
489 iter: posting_list.iterator(),
490 query_weight,
491 max_score,
492 }
493 }
494
495 pub fn from_arc(posting_list: &'a Arc<BlockSparsePostingList>, query_weight: f32) -> Self {
497 Self::new(posting_list.as_ref(), query_weight)
498 }
499}
500
501impl ScoringIterator for SparseTermScorer<'_> {
502 #[inline]
503 fn doc(&self) -> DocId {
504 self.iter.doc()
505 }
506
507 #[inline]
508 fn advance(&mut self) -> DocId {
509 self.iter.advance()
510 }
511
512 #[inline]
513 fn seek(&mut self, target: DocId) -> DocId {
514 self.iter.seek(target)
515 }
516
517 #[inline]
518 fn score(&self) -> f32 {
519 self.query_weight * self.iter.weight()
521 }
522
523 #[inline]
524 fn max_score(&self) -> f32 {
525 self.max_score
526 }
527
528 #[inline]
529 fn current_block_max_score(&self) -> f32 {
530 self.iter.current_block_max_contribution(self.query_weight)
531 }
532}
533
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_score_collector_basic() {
        let mut collector = ScoreCollector::new(3);
        for (doc, score) in [(1, 1.0), (2, 2.0), (3, 3.0)] {
            collector.insert(doc, score);
        }
        assert_eq!(collector.threshold(), 1.0);

        // A better document evicts the current minimum and raises the threshold.
        collector.insert(4, 4.0);
        assert_eq!(collector.threshold(), 2.0);

        let ids: Vec<DocId> = collector
            .into_sorted_results()
            .into_iter()
            .map(|(doc, _)| doc)
            .collect();
        assert_eq!(ids, vec![4, 3, 2]);
    }

    #[test]
    fn test_score_collector_threshold() {
        let mut collector = ScoreCollector::new(2);
        collector.insert(1, 5.0);
        collector.insert(2, 3.0);
        assert_eq!(collector.threshold(), 3.0);

        // Below the threshold: rejected.
        assert!(!collector.would_enter(2.0));
        assert!(!collector.insert(3, 2.0));

        // Above the threshold: admitted, and the threshold rises.
        assert!(collector.would_enter(4.0));
        assert!(collector.insert(4, 4.0));
        assert_eq!(collector.threshold(), 4.0);
    }

    #[test]
    fn test_heap_entry_ordering() {
        let mut heap = BinaryHeap::new();
        for (doc_id, score) in [(1, 3.0), (2, 1.0), (3, 2.0)] {
            heap.push(HeapEntry { doc_id, score });
        }

        // The heap behaves as a min-heap on score: the lowest score surfaces first.
        let popped: Vec<f32> = std::iter::from_fn(|| heap.pop())
            .map(|e| e.score)
            .collect();
        assert_eq!(popped, vec![1.0, 2.0, 3.0]);
    }
}