1use std::cmp::Ordering;
9use std::collections::BinaryHeap;
10
11use crate::structures::{HorizontalBP128Iterator, HorizontalBP128PostingList};
12use crate::{DocId, Score};
13
/// BM25 `k1` parameter: used by [`TermScorer::score`] in both the numerator
/// (`k1 + 1`) and the denominator (`k1 * length_norm`) of the term-frequency term.
pub const WAND_K1: f32 = 1.2;
/// BM25 `b` parameter: used by [`TermScorer::score`] to blend the constant
/// `1 - b` with the length-dependent part of the normalisation factor.
pub const WAND_B: f32 = 0.75;
17
/// Per-term scoring cursor: wraps a posting-list iterator together with the
/// weights needed to score the posting it currently points at.
pub struct TermScorer<'a> {
    /// Cursor over the term's posting list; `doc`, `advance`, `seek` and the
    /// term-frequency / block-max lookups all delegate to it.
    pub iter: HorizontalBP128Iterator<'a>,
    /// Score ceiling for this term, taken from the posting list or recomputed
    /// via `compute_bm25f_upper_bound` when a field boost is set.
    pub max_score: f32,
    /// Inverse-document-frequency weight; multiplies the normalised term
    /// frequency in `score`.
    pub idf: f32,
    /// Caller-supplied index for the term. Not read by any code in this file;
    /// presumably the term's position in the query (confirm with callers).
    pub term_idx: usize,
    /// Multiplier applied to the term frequency in `score`; `1.0` means
    /// "no boost" and selects the precomputed posting-list bounds.
    pub field_boost: f32,
    /// Average field length used for length normalisation; `score` clamps it
    /// to at least `1.0` before dividing.
    pub avg_field_len: f32,
}
33
34impl<'a> TermScorer<'a> {
35 pub fn new(posting_list: &'a HorizontalBP128PostingList, idf: f32, term_idx: usize) -> Self {
37 Self {
38 iter: posting_list.iterator(),
39 max_score: posting_list.max_score,
40 idf,
41 term_idx,
42 field_boost: 1.0,
43 avg_field_len: 1.0, }
45 }
46
47 pub fn with_bm25f(
49 posting_list: &'a HorizontalBP128PostingList,
50 idf: f32,
51 term_idx: usize,
52 field_boost: f32,
53 avg_field_len: f32,
54 ) -> Self {
55 let max_score = if field_boost != 1.0 {
57 let max_tf = posting_list
59 .blocks
60 .iter()
61 .map(|b| b.max_tf)
62 .max()
63 .unwrap_or(1);
64 HorizontalBP128PostingList::compute_bm25f_upper_bound(max_tf, idf, field_boost)
65 } else {
66 posting_list.max_score
67 };
68
69 Self {
70 iter: posting_list.iterator(),
71 max_score,
72 idf,
73 term_idx,
74 field_boost,
75 avg_field_len,
76 }
77 }
78
79 #[inline]
81 pub fn doc(&self) -> DocId {
82 self.iter.doc()
83 }
84
85 #[inline]
87 pub fn score(&self) -> Score {
88 let tf = self.iter.term_freq() as f32;
89
90 let length_norm = 1.0 - WAND_B + WAND_B * (tf / self.avg_field_len.max(1.0));
94 let tf_norm = (tf * self.field_boost * (WAND_K1 + 1.0))
95 / (tf * self.field_boost + WAND_K1 * length_norm);
96
97 self.idf * tf_norm
98 }
99
100 #[inline]
102 pub fn current_block_max_score(&self) -> f32 {
103 if self.field_boost == 1.0 {
104 self.iter.current_block_max_score()
105 } else {
106 let block_max_tf = self.iter.current_block_max_tf();
108 HorizontalBP128PostingList::compute_bm25f_upper_bound(
109 block_max_tf,
110 self.idf,
111 self.field_boost,
112 )
113 }
114 }
115
116 #[inline]
118 pub fn advance(&mut self) -> DocId {
119 self.iter.advance()
120 }
121
122 #[inline]
124 pub fn seek(&mut self, target: DocId) -> DocId {
125 self.iter.seek(target)
126 }
127
128 #[inline]
130 pub fn is_exhausted(&self) -> bool {
131 self.doc() == u32::MAX
132 }
133}
134
135#[derive(Clone, Copy)]
137struct HeapEntry {
138 doc_id: DocId,
139 score: Score,
140}
141
142impl PartialEq for HeapEntry {
143 fn eq(&self, other: &Self) -> bool {
144 self.score == other.score && self.doc_id == other.doc_id
145 }
146}
147
148impl Eq for HeapEntry {}
149
150impl PartialOrd for HeapEntry {
151 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
152 Some(self.cmp(other))
153 }
154}
155
156impl Ord for HeapEntry {
157 fn cmp(&self, other: &Self) -> Ordering {
158 other
160 .score
161 .partial_cmp(&self.score)
162 .unwrap_or(Ordering::Equal)
163 .then_with(|| self.doc_id.cmp(&other.doc_id))
164 }
165}
166
/// One scored hit as returned by the top-k routines in this file.
#[derive(Debug, Clone, Copy)]
pub struct WandResult {
    /// Identifier of the matching document.
    pub doc_id: DocId,
    /// Sum of the scores of every term scorer positioned on `doc_id`.
    pub score: Score,
}
173
/// Top-k executor that skips documents using each term's global `max_score`
/// as a pruning bound.
pub struct MaxScoreWand<'a> {
    /// Term scorers still holding postings; exhausted ones are removed as the
    /// run progresses.
    scorers: Vec<TermScorer<'a>>,
    /// Current best hits; the top of the heap is the weakest of them.
    heap: BinaryHeap<HeapEntry>,
    /// Maximum number of hits to return.
    k: usize,
    /// Score of the weakest retained hit once `k` hits are held; `0.0` before.
    threshold: Score,
    // Computed in `new` but never read by the code in this file, hence the
    // lint suppression.
    #[allow(dead_code)]
    /// Sum of every scorer's `max_score` at construction time.
    essential_max_sum: Score,
}
193
194impl<'a> MaxScoreWand<'a> {
195 pub fn new(mut scorers: Vec<TermScorer<'a>>, k: usize) -> Self {
197 scorers.sort_by(|a, b| {
199 b.max_score
200 .partial_cmp(&a.max_score)
201 .unwrap_or(Ordering::Equal)
202 });
203
204 let essential_max_sum: Score = scorers.iter().map(|s| s.max_score).sum();
205
206 Self {
207 scorers,
208 heap: BinaryHeap::with_capacity(k + 1),
209 k,
210 threshold: 0.0,
211 essential_max_sum,
212 }
213 }
214
215 pub fn execute(mut self) -> Vec<WandResult> {
217 if self.scorers.is_empty() {
218 return Vec::new();
219 }
220
221 self.scorers.retain(|s| !s.is_exhausted());
223
224 while !self.scorers.is_empty() {
225 self.scorers.sort_by_key(|s| s.doc());
227
228 let pivot_idx = self.find_pivot();
230
231 if pivot_idx.is_none() {
232 break;
233 }
234 let pivot_idx = pivot_idx.unwrap();
235 let pivot_doc = self.scorers[pivot_idx].doc();
236
237 if pivot_doc == u32::MAX {
238 break;
239 }
240
241 let all_at_pivot = self.scorers[..=pivot_idx]
243 .iter()
244 .all(|s| s.doc() == pivot_doc);
245
246 if all_at_pivot {
247 let score = self.score_document(pivot_doc);
249 self.maybe_insert(pivot_doc, score);
250
251 for scorer in &mut self.scorers {
253 if scorer.doc() == pivot_doc {
254 scorer.advance();
255 }
256 }
257 } else {
258 for i in 0..pivot_idx {
260 if self.scorers[i].doc() < pivot_doc {
261 self.scorers[i].seek(pivot_doc);
262 }
263 }
264 }
265
266 self.scorers.retain(|s| !s.is_exhausted());
268 }
269
270 self.into_results()
271 }
272
273 fn find_pivot(&self) -> Option<usize> {
275 let mut cumsum = 0.0f32;
276
277 for (i, scorer) in self.scorers.iter().enumerate() {
278 cumsum += scorer.max_score;
279 if cumsum >= self.threshold {
280 return Some(i);
281 }
282 }
283
284 if cumsum < self.threshold {
286 None
287 } else {
288 Some(self.scorers.len() - 1)
289 }
290 }
291
292 fn score_document(&self, doc_id: DocId) -> Score {
294 let mut score = 0.0;
295 for scorer in &self.scorers {
296 if scorer.doc() == doc_id {
297 score += scorer.score();
298 }
299 }
300 score
301 }
302
303 fn maybe_insert(&mut self, doc_id: DocId, score: Score) {
305 if self.heap.len() < self.k {
306 self.heap.push(HeapEntry { doc_id, score });
307 if self.heap.len() == self.k {
308 self.threshold = self.heap.peek().map(|e| e.score).unwrap_or(0.0);
309 }
310 } else if score > self.threshold {
311 self.heap.pop();
312 self.heap.push(HeapEntry { doc_id, score });
313 self.threshold = self.heap.peek().map(|e| e.score).unwrap_or(0.0);
314 }
315 }
316
317 fn into_results(self) -> Vec<WandResult> {
319 let mut results: Vec<_> = self
320 .heap
321 .into_vec()
322 .into_iter()
323 .map(|e| WandResult {
324 doc_id: e.doc_id,
325 score: e.score,
326 })
327 .collect();
328
329 results.sort_by(|a, b| {
330 b.score
331 .partial_cmp(&a.score)
332 .unwrap_or(Ordering::Equal)
333 .then_with(|| a.doc_id.cmp(&b.doc_id))
334 });
335
336 results
337 }
338}
339
/// Top-k executor that uses each scorer's current block maximum when deciding
/// whether a candidate document is worth inserting.
pub struct BlockWand<'a> {
    /// Term scorers still holding postings; exhausted ones are removed as the
    /// run progresses.
    scorers: Vec<TermScorer<'a>>,
    /// Current best hits; the top of the heap is the weakest of them.
    heap: BinaryHeap<HeapEntry>,
    /// Maximum number of hits to return.
    k: usize,
    /// Score of the weakest retained hit once `k` hits are held; `0.0` before.
    threshold: Score,
}
349
350impl<'a> BlockWand<'a> {
351 pub fn new(scorers: Vec<TermScorer<'a>>, k: usize) -> Self {
352 Self {
353 scorers,
354 heap: BinaryHeap::with_capacity(k + 1),
355 k,
356 threshold: 0.0,
357 }
358 }
359
360 pub fn execute(mut self) -> Vec<WandResult> {
362 if self.scorers.is_empty() {
363 return Vec::new();
364 }
365
366 self.scorers.retain(|s| !s.is_exhausted());
367
368 while !self.scorers.is_empty() {
369 self.scorers.sort_by_key(|s| s.doc());
371
372 let min_doc = self.scorers[0].doc();
374 if min_doc == u32::MAX {
375 break;
376 }
377
378 let upper_bound: Score = self
380 .scorers
381 .iter()
382 .filter(|s| s.doc() <= min_doc || s.current_block_max_score() > 0.0)
383 .map(|s| {
384 if s.doc() == min_doc {
385 s.score() } else {
387 s.current_block_max_score() }
389 })
390 .sum();
391
392 if upper_bound >= self.threshold {
393 for scorer in &mut self.scorers {
396 if scorer.doc() < min_doc {
397 scorer.seek(min_doc);
398 }
399 }
400
401 let score = self.score_document(min_doc);
403 self.maybe_insert(min_doc, score);
404 }
405
406 for scorer in &mut self.scorers {
408 if scorer.doc() == min_doc {
409 scorer.advance();
410 }
411 }
412
413 self.scorers.retain(|s| !s.is_exhausted());
414 }
415
416 self.into_results()
417 }
418
419 fn score_document(&self, doc_id: DocId) -> Score {
420 self.scorers
421 .iter()
422 .filter(|s| s.doc() == doc_id)
423 .map(|s| s.score())
424 .sum()
425 }
426
427 fn maybe_insert(&mut self, doc_id: DocId, score: Score) {
428 if self.heap.len() < self.k {
429 self.heap.push(HeapEntry { doc_id, score });
430 if self.heap.len() == self.k {
431 self.threshold = self.heap.peek().map(|e| e.score).unwrap_or(0.0);
432 }
433 } else if score > self.threshold {
434 self.heap.pop();
435 self.heap.push(HeapEntry { doc_id, score });
436 self.threshold = self.heap.peek().map(|e| e.score).unwrap_or(0.0);
437 }
438 }
439
440 fn into_results(self) -> Vec<WandResult> {
441 let mut results: Vec<_> = self
442 .heap
443 .into_vec()
444 .into_iter()
445 .map(|e| WandResult {
446 doc_id: e.doc_id,
447 score: e.score,
448 })
449 .collect();
450
451 results.sort_by(|a, b| {
452 b.score
453 .partial_cmp(&a.score)
454 .unwrap_or(Ordering::Equal)
455 .then_with(|| a.doc_id.cmp(&b.doc_id))
456 });
457
458 results
459 }
460}
461
462pub fn daat_or<'a>(scorers: &mut [TermScorer<'a>], k: usize) -> Vec<WandResult> {
464 let mut heap: BinaryHeap<HeapEntry> = BinaryHeap::with_capacity(k + 1);
465 let mut threshold = 0.0f32;
466
467 loop {
468 let min_doc = scorers
470 .iter()
471 .filter(|s| !s.is_exhausted())
472 .map(|s| s.doc())
473 .min();
474
475 let min_doc = match min_doc {
476 Some(d) if d != u32::MAX => d,
477 _ => break,
478 };
479
480 let score: Score = scorers
482 .iter()
483 .filter(|s| s.doc() == min_doc)
484 .map(|s| s.score())
485 .sum();
486
487 if heap.len() < k {
489 heap.push(HeapEntry {
490 doc_id: min_doc,
491 score,
492 });
493 if heap.len() == k {
494 threshold = heap.peek().map(|e| e.score).unwrap_or(0.0);
495 }
496 } else if score > threshold {
497 heap.pop();
498 heap.push(HeapEntry {
499 doc_id: min_doc,
500 score,
501 });
502 threshold = heap.peek().map(|e| e.score).unwrap_or(0.0);
503 }
504
505 for scorer in scorers.iter_mut() {
507 if scorer.doc() == min_doc {
508 scorer.advance();
509 }
510 }
511 }
512
513 let mut results: Vec<_> = heap
514 .into_vec()
515 .into_iter()
516 .map(|e| WandResult {
517 doc_id: e.doc_id,
518 score: e.score,
519 })
520 .collect();
521
522 results.sort_by(|a, b| {
523 b.score
524 .partial_cmp(&a.score)
525 .unwrap_or(Ordering::Equal)
526 .then_with(|| a.doc_id.cmp(&b.doc_id))
527 });
528
529 results
530}
531
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a posting list from parallel slices of document ids and term
    /// frequencies, forwarding `idf` to `from_postings`.
    fn create_test_posting_list(
        doc_ids: &[u32],
        term_freqs: &[u32],
        idf: f32,
    ) -> HorizontalBP128PostingList {
        HorizontalBP128PostingList::from_postings(doc_ids, term_freqs, idf)
    }

    /// Documents 3 and 7 appear in both lists; at least one of them must be
    /// among the three returned hits.
    #[test]
    fn test_maxscore_wand_basic() {
        let first = create_test_posting_list(&[1, 3, 5, 7], &[2, 1, 3, 1], 1.0);
        let second = create_test_posting_list(&[2, 3, 6, 7], &[1, 2, 1, 2], 1.5);

        let wand = MaxScoreWand::new(
            vec![
                TermScorer::new(&first, 1.0, 0),
                TermScorer::new(&second, 1.5, 1),
            ],
            3,
        );
        let results = wand.execute();

        assert!(!results.is_empty());
        assert!(results.iter().any(|hit| hit.doc_id == 3 || hit.doc_id == 7));
    }

    /// Same shape of check for `BlockWand`: a document shared by both lists
    /// must be returned.
    #[test]
    fn test_block_wand_basic() {
        let first = create_test_posting_list(&[1, 3, 5, 7, 9], &[1, 2, 1, 3, 1], 1.0);
        let second = create_test_posting_list(&[2, 3, 7, 8], &[1, 1, 2, 1], 1.2);

        let wand = BlockWand::new(
            vec![
                TermScorer::new(&first, 1.0, 0),
                TermScorer::new(&second, 1.2, 1),
            ],
            5,
        );
        let results = wand.execute();

        assert!(!results.is_empty());
        assert!(results.iter().any(|hit| hit.doc_id == 3 || hit.doc_id == 7));
    }

    /// The union of the two lists has four documents; the two shared ones
    /// (2 and 3) must occupy the first two result slots.
    #[test]
    fn test_daat_or() {
        let first = create_test_posting_list(&[1, 2, 3], &[1, 1, 1], 1.0);
        let second = create_test_posting_list(&[2, 3, 4], &[1, 1, 1], 1.0);

        let mut scorers = vec![
            TermScorer::new(&first, 1.0, 0),
            TermScorer::new(&second, 1.0, 1),
        ];
        let results = daat_or(&mut scorers, 10);

        assert_eq!(results.len(), 4);
        for hit in &results[..2] {
            assert!(hit.doc_id == 2 || hit.doc_id == 3);
        }
    }

    /// A sparse list with large weights is paired with a dense list with
    /// small weights; at least one document from the sparse list must be
    /// returned.
    #[test]
    fn test_maxscore_threshold_pruning() {
        let sparse = create_test_posting_list(&[1, 100, 200], &[10, 10, 10], 2.0);

        let dense_ids = (0..50).collect::<Vec<_>>();
        let dense = create_test_posting_list(&dense_ids, &[1; 50], 0.1);

        let wand = MaxScoreWand::new(
            vec![
                TermScorer::new(&sparse, 2.0, 0),
                TermScorer::new(&dense, 0.1, 1),
            ],
            3,
        );
        let results = wand.execute();

        let sparse_docs = [1, 100, 200];
        assert!(results.iter().any(|hit| sparse_docs.contains(&hit.doc_id)));
    }
}