Skip to main content

hermes_core/query/
scoring.rs

1//! Shared scoring abstractions for text and sparse vector search
2//!
3//! Provides common types and executors for efficient top-k retrieval:
4//! - `TermCursor`: Unified cursor for both BM25 text and sparse vector posting lists
5//! - `ScoreCollector`: Efficient min-heap for maintaining top-k results
6//! - `MaxScoreExecutor`: Unified Block-Max MaxScore with conjunction optimization
7//! - `ScoredDoc`: Result type with doc_id, score, and ordinal
8
9use std::cmp::Ordering;
10use std::collections::BinaryHeap;
11
12use log::debug;
13
14use crate::DocId;
15
16/// Entry for top-k min-heap
17#[derive(Clone, Copy)]
18pub struct HeapEntry {
19    pub doc_id: DocId,
20    pub score: f32,
21    pub ordinal: u16,
22}
23
24impl PartialEq for HeapEntry {
25    fn eq(&self, other: &Self) -> bool {
26        self.score == other.score && self.doc_id == other.doc_id
27    }
28}
29
30impl Eq for HeapEntry {}
31
32impl Ord for HeapEntry {
33    fn cmp(&self, other: &Self) -> Ordering {
34        // Min-heap: lower scores come first (to be evicted)
35        other
36            .score
37            .partial_cmp(&self.score)
38            .unwrap_or(Ordering::Equal)
39            .then_with(|| self.doc_id.cmp(&other.doc_id))
40    }
41}
42
43impl PartialOrd for HeapEntry {
44    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
45        Some(self.cmp(other))
46    }
47}
48
49/// Efficient top-k collector using min-heap
50///
51/// Maintains the k highest-scoring documents using a min-heap where the
52/// lowest score is at the top for O(1) threshold lookup and O(log k) eviction.
53/// No deduplication - caller must ensure each doc_id is inserted only once.
54pub struct ScoreCollector {
55    /// Min-heap of top-k entries (lowest score at top for eviction)
56    heap: BinaryHeap<HeapEntry>,
57    pub k: usize,
58    /// Cached threshold: avoids repeated heap.peek() in hot loops.
59    /// Updated only when the heap changes (insert/pop).
60    cached_threshold: f32,
61}
62
63impl ScoreCollector {
64    /// Create a new collector for top-k results
65    pub fn new(k: usize) -> Self {
66        // Cap capacity to avoid allocation overflow for very large k
67        let capacity = k.saturating_add(1).min(1_000_000);
68        Self {
69            heap: BinaryHeap::with_capacity(capacity),
70            k,
71            cached_threshold: 0.0,
72        }
73    }
74
75    /// Current score threshold (minimum score to enter top-k)
76    #[inline]
77    pub fn threshold(&self) -> f32 {
78        self.cached_threshold
79    }
80
81    /// Recompute cached threshold from heap state
82    #[inline]
83    fn update_threshold(&mut self) {
84        self.cached_threshold = if self.heap.len() >= self.k {
85            self.heap.peek().map(|e| e.score).unwrap_or(0.0)
86        } else {
87            0.0
88        };
89    }
90
91    /// Insert a document score. Returns true if inserted in top-k.
92    /// Caller must ensure each doc_id is inserted only once.
93    #[inline]
94    pub fn insert(&mut self, doc_id: DocId, score: f32) -> bool {
95        self.insert_with_ordinal(doc_id, score, 0)
96    }
97
98    /// Insert a document score with ordinal. Returns true if inserted in top-k.
99    /// Caller must ensure each doc_id is inserted only once.
100    #[inline]
101    pub fn insert_with_ordinal(&mut self, doc_id: DocId, score: f32, ordinal: u16) -> bool {
102        if self.heap.len() < self.k {
103            self.heap.push(HeapEntry {
104                doc_id,
105                score,
106                ordinal,
107            });
108            self.update_threshold();
109            true
110        } else if score > self.cached_threshold {
111            self.heap.push(HeapEntry {
112                doc_id,
113                score,
114                ordinal,
115            });
116            self.heap.pop(); // Remove lowest
117            self.update_threshold();
118            true
119        } else {
120            false
121        }
122    }
123
124    /// Check if a score could potentially enter top-k
125    #[inline]
126    pub fn would_enter(&self, score: f32) -> bool {
127        self.heap.len() < self.k || score > self.cached_threshold
128    }
129
130    /// Get number of documents collected so far
131    #[inline]
132    pub fn len(&self) -> usize {
133        self.heap.len()
134    }
135
136    /// Check if collector is empty
137    #[inline]
138    pub fn is_empty(&self) -> bool {
139        self.heap.is_empty()
140    }
141
142    /// Convert to sorted top-k results (descending by score)
143    pub fn into_sorted_results(self) -> Vec<(DocId, f32, u16)> {
144        let heap_vec = self.heap.into_vec();
145        let mut results: Vec<(DocId, f32, u16)> = Vec::with_capacity(heap_vec.len());
146        for e in heap_vec {
147            results.push((e.doc_id, e.score, e.ordinal));
148        }
149
150        // Sort by score descending, then doc_id ascending
151        results.sort_by(|a, b| {
152            b.1.partial_cmp(&a.1)
153                .unwrap_or(Ordering::Equal)
154                .then_with(|| a.0.cmp(&b.0))
155        });
156
157        results
158    }
159}
160
161/// Search result from MaxScore execution
162#[derive(Debug, Clone, Copy)]
163pub struct ScoredDoc {
164    pub doc_id: DocId,
165    pub score: f32,
166    /// Ordinal for multi-valued fields (which vector in the field matched)
167    pub ordinal: u16,
168}
169
170/// Unified Block-Max MaxScore executor for top-k retrieval
171///
172/// Works with both full-text (BM25) and sparse vector (dot product) queries
173/// through the polymorphic `TermCursor`. Combines three optimizations:
174/// 1. **MaxScore partitioning** (Turtle & Flood 1995): terms split into essential
175///    (must check) and non-essential (only scored if candidate is promising)
176/// 2. **Block-max pruning** (Ding & Suel 2011): skip blocks where per-block
177///    upper bounds can't beat the current threshold
178/// 3. **Conjunction optimization** (Lucene/Grand 2023): progressively intersect
179///    essential terms as threshold rises, skipping docs that lack enough terms
180pub struct MaxScoreExecutor<'a> {
181    cursors: Vec<TermCursor<'a>>,
182    prefix_sums: Vec<f32>,
183    collector: ScoreCollector,
184    heap_factor: f32,
185    predicate: Option<super::DocPredicate<'a>>,
186}
187
188/// Unified term cursor for Block-Max MaxScore execution.
189///
190/// All per-position decode buffers (`doc_ids`, `scores`, `ordinals`) live in
191/// the struct directly and are filled by `ensure_block_loaded`.
192///
193/// Skip-list metadata is **not** materialized — it is read lazily from the
194/// underlying source (`BlockPostingList` for text, `SparseIndex` for sparse),
195/// both backed by zero-copy mmap'd `OwnedBytes`.
196pub(crate) struct TermCursor<'a> {
197    pub max_score: f32,
198    num_blocks: usize,
199    // ── Per-position state (filled by ensure_block_loaded) ──────────
200    block_idx: usize,
201    doc_ids: Vec<u32>,
202    scores: Vec<f32>,
203    ordinals: Vec<u16>,
204    pos: usize,
205    block_loaded: bool,
206    exhausted: bool,
207    // ── Block decode + skip access source ───────────────────────────
208    variant: CursorVariant<'a>,
209}
210
211enum CursorVariant<'a> {
212    /// Full-text BM25 — in-memory BlockPostingList (skip list + block data)
213    Text {
214        list: crate::structures::BlockPostingList,
215        idf: f32,
216        avg_field_len: f32,
217        tfs: Vec<u32>, // temp decode buffer, converted to scores
218    },
219    /// Sparse vector — mmap'd SparseIndex (skip entries + block data)
220    Sparse {
221        si: &'a crate::segment::SparseIndex,
222        query_weight: f32,
223        skip_start: usize,
224        block_data_offset: u64,
225    },
226}
227
228impl<'a> TermCursor<'a> {
229    /// Create a full-text BM25 cursor (lazy — no blocks decoded yet).
230    pub fn text(
231        posting_list: crate::structures::BlockPostingList,
232        idf: f32,
233        avg_field_len: f32,
234    ) -> Self {
235        let max_tf = posting_list.max_tf() as f32;
236        let max_score = super::bm25_upper_bound(max_tf.max(1.0), idf);
237        let num_blocks = posting_list.num_blocks();
238        Self {
239            max_score,
240            num_blocks,
241            block_idx: 0,
242            doc_ids: Vec::with_capacity(128),
243            scores: Vec::with_capacity(128),
244            ordinals: Vec::new(),
245            pos: 0,
246            block_loaded: false,
247            exhausted: num_blocks == 0,
248            variant: CursorVariant::Text {
249                list: posting_list,
250                idf,
251                avg_field_len,
252                tfs: Vec::with_capacity(128),
253            },
254        }
255    }
256
257    /// Create a sparse vector cursor with lazy block loading.
258    /// Skip entries are **not** copied — they are read from `SparseIndex` mmap on demand.
259    pub fn sparse(
260        si: &'a crate::segment::SparseIndex,
261        query_weight: f32,
262        skip_start: usize,
263        skip_count: usize,
264        global_max_weight: f32,
265        block_data_offset: u64,
266    ) -> Self {
267        Self {
268            max_score: query_weight.abs() * global_max_weight,
269            num_blocks: skip_count,
270            block_idx: 0,
271            doc_ids: Vec::with_capacity(256),
272            scores: Vec::with_capacity(256),
273            ordinals: Vec::with_capacity(256),
274            pos: 0,
275            block_loaded: false,
276            exhausted: skip_count == 0,
277            variant: CursorVariant::Sparse {
278                si,
279                query_weight,
280                skip_start,
281                block_data_offset,
282            },
283        }
284    }
285
286    // ── Skip-entry access (lazy, zero-copy for sparse) ──────────────────
287
288    #[inline]
289    fn block_first_doc(&self, idx: usize) -> DocId {
290        match &self.variant {
291            CursorVariant::Text { list, .. } => list.block_first_doc(idx).unwrap_or(u32::MAX),
292            CursorVariant::Sparse { si, skip_start, .. } => {
293                si.read_skip_entry(*skip_start + idx).first_doc
294            }
295        }
296    }
297
298    #[inline]
299    fn block_last_doc(&self, idx: usize) -> DocId {
300        match &self.variant {
301            CursorVariant::Text { list, .. } => list.block_last_doc(idx).unwrap_or(0),
302            CursorVariant::Sparse { si, skip_start, .. } => {
303                si.read_skip_entry(*skip_start + idx).last_doc
304            }
305        }
306    }
307
308    // ── Read-only accessors ─────────────────────────────────────────────
309
310    #[inline]
311    pub fn doc(&self) -> DocId {
312        if self.exhausted {
313            return u32::MAX;
314        }
315        if self.block_loaded {
316            self.doc_ids.get(self.pos).copied().unwrap_or(u32::MAX)
317        } else {
318            self.block_first_doc(self.block_idx)
319        }
320    }
321
322    #[inline]
323    pub fn ordinal(&self) -> u16 {
324        if !self.block_loaded || self.ordinals.is_empty() {
325            return 0;
326        }
327        self.ordinals.get(self.pos).copied().unwrap_or(0)
328    }
329
330    #[inline]
331    pub fn score(&self) -> f32 {
332        if !self.block_loaded {
333            return 0.0;
334        }
335        self.scores.get(self.pos).copied().unwrap_or(0.0)
336    }
337
338    #[inline]
339    pub fn current_block_max_score(&self) -> f32 {
340        if self.exhausted {
341            return 0.0;
342        }
343        match &self.variant {
344            CursorVariant::Text { list, idf, .. } => {
345                let block_max_tf = list.block_max_tf(self.block_idx).unwrap_or(0) as f32;
346                super::bm25_upper_bound(block_max_tf.max(1.0), *idf)
347            }
348            CursorVariant::Sparse {
349                si,
350                query_weight,
351                skip_start,
352                ..
353            } => query_weight.abs() * si.read_skip_entry(*skip_start + self.block_idx).max_weight,
354        }
355    }
356
357    // ── Block navigation ────────────────────────────────────────────────
358
359    pub fn skip_to_next_block(&mut self) -> DocId {
360        if self.exhausted {
361            return u32::MAX;
362        }
363        self.block_idx += 1;
364        self.block_loaded = false;
365        if self.block_idx >= self.num_blocks {
366            self.exhausted = true;
367            return u32::MAX;
368        }
369        self.block_first_doc(self.block_idx)
370    }
371
372    #[inline]
373    fn advance_pos(&mut self) -> DocId {
374        self.pos += 1;
375        if self.pos >= self.doc_ids.len() {
376            self.block_idx += 1;
377            self.block_loaded = false;
378            if self.block_idx >= self.num_blocks {
379                self.exhausted = true;
380                return u32::MAX;
381            }
382        }
383        self.doc()
384    }
385
386    // ── Block loading (dispatch: decode format + I/O differ) ────────────
387
388    pub async fn ensure_block_loaded(&mut self) -> crate::Result<bool> {
389        if self.exhausted || self.block_loaded {
390            return Ok(!self.exhausted);
391        }
392        match &mut self.variant {
393            CursorVariant::Text {
394                list,
395                idf,
396                avg_field_len,
397                tfs,
398            } => {
399                if list.decode_block_into(self.block_idx, &mut self.doc_ids, tfs) {
400                    self.scores.clear();
401                    self.scores.reserve(tfs.len());
402                    for &tf in tfs.iter() {
403                        let tf = tf as f32;
404                        self.scores
405                            .push(super::bm25_score(tf, *idf, tf, *avg_field_len));
406                    }
407                    self.pos = 0;
408                    self.block_loaded = true;
409                    Ok(true)
410                } else {
411                    self.exhausted = true;
412                    Ok(false)
413                }
414            }
415            CursorVariant::Sparse {
416                si,
417                query_weight,
418                skip_start,
419                block_data_offset,
420                ..
421            } => {
422                let block = si
423                    .load_block_direct(*skip_start, *block_data_offset, self.block_idx)
424                    .await?;
425                match block {
426                    Some(b) => {
427                        b.decode_doc_ids_into(&mut self.doc_ids);
428                        b.decode_ordinals_into(&mut self.ordinals);
429                        b.decode_scored_weights_into(*query_weight, &mut self.scores);
430                        self.pos = 0;
431                        self.block_loaded = true;
432                        Ok(true)
433                    }
434                    None => {
435                        self.exhausted = true;
436                        Ok(false)
437                    }
438                }
439            }
440        }
441    }
442
443    pub fn ensure_block_loaded_sync(&mut self) -> crate::Result<bool> {
444        if self.exhausted || self.block_loaded {
445            return Ok(!self.exhausted);
446        }
447        match &mut self.variant {
448            CursorVariant::Text {
449                list,
450                idf,
451                avg_field_len,
452                tfs,
453            } => {
454                if list.decode_block_into(self.block_idx, &mut self.doc_ids, tfs) {
455                    self.scores.clear();
456                    self.scores.reserve(tfs.len());
457                    for &tf in tfs.iter() {
458                        let tf = tf as f32;
459                        self.scores
460                            .push(super::bm25_score(tf, *idf, tf, *avg_field_len));
461                    }
462                    self.pos = 0;
463                    self.block_loaded = true;
464                    Ok(true)
465                } else {
466                    self.exhausted = true;
467                    Ok(false)
468                }
469            }
470            CursorVariant::Sparse {
471                si,
472                query_weight,
473                skip_start,
474                block_data_offset,
475                ..
476            } => {
477                let block =
478                    si.load_block_direct_sync(*skip_start, *block_data_offset, self.block_idx)?;
479                match block {
480                    Some(b) => {
481                        b.decode_doc_ids_into(&mut self.doc_ids);
482                        b.decode_ordinals_into(&mut self.ordinals);
483                        b.decode_scored_weights_into(*query_weight, &mut self.scores);
484                        self.pos = 0;
485                        self.block_loaded = true;
486                        Ok(true)
487                    }
488                    None => {
489                        self.exhausted = true;
490                        Ok(false)
491                    }
492                }
493            }
494        }
495    }
496
497    // ── Advance / Seek ──────────────────────────────────────────────────
498
499    pub async fn advance(&mut self) -> crate::Result<DocId> {
500        if self.exhausted {
501            return Ok(u32::MAX);
502        }
503        self.ensure_block_loaded().await?;
504        if self.exhausted {
505            return Ok(u32::MAX);
506        }
507        Ok(self.advance_pos())
508    }
509
510    pub fn advance_sync(&mut self) -> crate::Result<DocId> {
511        if self.exhausted {
512            return Ok(u32::MAX);
513        }
514        self.ensure_block_loaded_sync()?;
515        if self.exhausted {
516            return Ok(u32::MAX);
517        }
518        Ok(self.advance_pos())
519    }
520
521    pub async fn seek(&mut self, target: DocId) -> crate::Result<DocId> {
522        if let Some(doc) = self.seek_prepare(target) {
523            return Ok(doc);
524        }
525        self.ensure_block_loaded().await?;
526        if self.seek_finish(target) {
527            self.ensure_block_loaded().await?;
528        }
529        Ok(self.doc())
530    }
531
532    pub fn seek_sync(&mut self, target: DocId) -> crate::Result<DocId> {
533        if let Some(doc) = self.seek_prepare(target) {
534            return Ok(doc);
535        }
536        self.ensure_block_loaded_sync()?;
537        if self.seek_finish(target) {
538            self.ensure_block_loaded_sync()?;
539        }
540        Ok(self.doc())
541    }
542
543    fn seek_prepare(&mut self, target: DocId) -> Option<DocId> {
544        if self.exhausted {
545            return Some(u32::MAX);
546        }
547
548        // Fast path: target is within the currently loaded block
549        if self.block_loaded
550            && let Some(&last) = self.doc_ids.last()
551        {
552            if last >= target && self.doc_ids[self.pos] < target {
553                let remaining = &self.doc_ids[self.pos..];
554                self.pos += crate::structures::simd::find_first_ge_u32(remaining, target);
555                if self.pos >= self.doc_ids.len() {
556                    self.block_idx += 1;
557                    self.block_loaded = false;
558                    if self.block_idx >= self.num_blocks {
559                        self.exhausted = true;
560                        return Some(u32::MAX);
561                    }
562                }
563                return Some(self.doc());
564            }
565            if self.doc_ids[self.pos] >= target {
566                return Some(self.doc());
567            }
568        }
569
570        // Seek to the block containing target
571        let lo = match &self.variant {
572            // Text: SIMD-accelerated 2-level seek (L1 + L0)
573            CursorVariant::Text { list, .. } => match list.seek_block(target, self.block_idx) {
574                Some(idx) => idx,
575                None => {
576                    self.exhausted = true;
577                    return Some(u32::MAX);
578                }
579            },
580            // Sparse: binary search on skip entries (lazy mmap reads)
581            CursorVariant::Sparse { .. } => {
582                let mut lo = self.block_idx;
583                let mut hi = self.num_blocks;
584                while lo < hi {
585                    let mid = lo + (hi - lo) / 2;
586                    if self.block_last_doc(mid) < target {
587                        lo = mid + 1;
588                    } else {
589                        hi = mid;
590                    }
591                }
592                lo
593            }
594        };
595        if lo >= self.num_blocks {
596            self.exhausted = true;
597            return Some(u32::MAX);
598        }
599        if lo != self.block_idx || !self.block_loaded {
600            self.block_idx = lo;
601            self.block_loaded = false;
602        }
603        None
604    }
605
606    #[inline]
607    fn seek_finish(&mut self, target: DocId) -> bool {
608        if self.exhausted {
609            return false;
610        }
611        self.pos = crate::structures::simd::find_first_ge_u32(&self.doc_ids, target);
612        if self.pos >= self.doc_ids.len() {
613            self.block_idx += 1;
614            self.block_loaded = false;
615            if self.block_idx >= self.num_blocks {
616                self.exhausted = true;
617                return false;
618            }
619            return true;
620        }
621        false
622    }
623}
624
/// Stamps out the Block-Max MaxScore loop for both the async and sync paths.
///
/// `$ensure`, `$advance` and `$seek` are `TermCursor` method idents: either the
/// async methods or their `_sync` variants. `$($aw:tt)*` captures `.await` for
/// the async path and nothing for the sync path. The expansion evaluates to
/// `Ok(Vec<ScoredDoc>)`, propagates cursor errors with `?`, and consumes
/// `$self.collector`.
macro_rules! bms_execute_loop {
    ($self:ident, $ensure:ident, $advance:ident, $seek:ident, $($aw:tt)*) => {{
        let n = $self.cursors.len();

        // Load the first block of every cursor so `doc()` returns real postings.
        for cursor in &mut $self.cursors {
            cursor.$ensure() $($aw)* ?;
        }

        let mut docs_scored = 0u64;
        let mut docs_skipped = 0u64;
        let mut blocks_skipped = 0u64;
        let mut conjunction_skipped = 0u64;
        let mut ordinal_scores: Vec<(u16, f32)> = Vec::with_capacity(n * 2);

        loop {
            // Cursors in [0, partition) are non-essential: even together they
            // cannot beat the threshold, so only [partition, n) drive candidates.
            let partition = $self.find_partition();
            if partition >= n {
                break;
            }

            // Next candidate: the smallest doc among the essential cursors.
            let min_doc = (partition..n)
                .map(|i| $self.cursors[i].doc())
                .min()
                .unwrap_or(u32::MAX);
            if min_doc == u32::MAX {
                break;
            }

            let non_essential_upper = if partition > 0 {
                $self.prefix_sums[partition - 1]
            } else {
                0.0
            };
            let adjusted_threshold = $self.collector.threshold() * $self.heap_factor;
            // Nothing is inserted before the final step of an iteration, so
            // this stays valid for every pruning check below.
            let heap_full = $self.collector.len() >= $self.collector.k;

            // --- Conjunction optimization ---
            // Essential cursors not positioned on `min_doc` do not contain it,
            // so only the present ones (plus the non-essentials) can score it.
            if heap_full {
                let present_upper: f32 = (partition..n)
                    .filter(|&i| $self.cursors[i].doc() == min_doc)
                    .map(|i| $self.cursors[i].max_score)
                    .sum();

                if present_upper + non_essential_upper <= adjusted_threshold {
                    for i in partition..n {
                        if $self.cursors[i].doc() == min_doc {
                            $self.cursors[i].$ensure() $($aw)* ?;
                            $self.cursors[i].$advance() $($aw)* ?;
                        }
                    }
                    conjunction_skipped += 1;
                    continue;
                }
            }

            // --- Block-max pruning ---
            // The block maxima of the cursors on `min_doc` bound their
            // contribution to any doc up to the end of their current blocks.
            // That bound only covers docs that no *other* essential cursor
            // contains. The skippable window therefore ends at the earliest of:
            // (a) one past the last doc of the current block of any cursor on
            //     `min_doc`, and
            // (b) the current doc of any other essential cursor.
            // Jumping past whole blocks unconditionally could skip a doc that
            // another essential term also matches, silently dropping this
            // cursor's contribution to it.
            if heap_full {
                let mut block_max_sum = 0.0f32;
                let mut window_end = u32::MAX;
                for i in partition..n {
                    let cursor = &$self.cursors[i];
                    let doc = cursor.doc();
                    if doc == min_doc {
                        block_max_sum += cursor.current_block_max_score();
                        let block_end = cursor.block_last_doc(cursor.block_idx).saturating_add(1);
                        window_end = window_end.min(block_end);
                    } else {
                        window_end = window_end.min(doc);
                    }
                }

                if block_max_sum + non_essential_upper <= adjusted_threshold {
                    // Always move past `min_doc`, even if skip metadata is off.
                    let target = window_end.max(min_doc.saturating_add(1));
                    for i in partition..n {
                        if $self.cursors[i].doc() == min_doc {
                            $self.cursors[i].$seek(target) $($aw)* ?;
                        }
                    }
                    blocks_skipped += 1;
                    continue;
                }
            }

            // --- Predicate filter (after block-max, before scoring) ---
            if let Some(ref pred) = $self.predicate {
                if !pred(min_doc) {
                    for i in partition..n {
                        if $self.cursors[i].doc() == min_doc {
                            $self.cursors[i].$ensure() $($aw)* ?;
                            $self.cursors[i].$advance() $($aw)* ?;
                        }
                    }
                    continue;
                }
            }

            // --- Score essential cursors ---
            // A cursor may hold several postings for one doc (one per ordinal).
            ordinal_scores.clear();
            for i in partition..n {
                if $self.cursors[i].doc() == min_doc {
                    $self.cursors[i].$ensure() $($aw)* ?;
                    while $self.cursors[i].doc() == min_doc {
                        ordinal_scores.push(($self.cursors[i].ordinal(), $self.cursors[i].score()));
                        $self.cursors[i].$advance() $($aw)* ?;
                    }
                }
            }

            let essential_total: f32 = ordinal_scores.iter().map(|(_, s)| *s).sum();
            if heap_full && essential_total + non_essential_upper <= adjusted_threshold {
                docs_skipped += 1;
                continue;
            }

            // --- Score non-essential cursors (highest max_score first for early exit) ---
            let mut running_total = essential_total;
            for i in (0..partition).rev() {
                // `prefix_sums[i]` bounds what cursors 0..=i can still add.
                if heap_full && running_total + $self.prefix_sums[i] <= adjusted_threshold {
                    break;
                }

                let doc = $self.cursors[i].$seek(min_doc) $($aw)* ?;
                if doc == min_doc {
                    while $self.cursors[i].doc() == min_doc {
                        let s = $self.cursors[i].score();
                        running_total += s;
                        ordinal_scores.push(($self.cursors[i].ordinal(), s));
                        $self.cursors[i].$advance() $($aw)* ?;
                    }
                }
            }

            // --- Group by ordinal and insert ---
            // Fast path: single entry (common for single-valued fields), which
            // needs no sort or grouping.
            if ordinal_scores.len() == 1 {
                let (ord, score) = ordinal_scores[0];
                if $self.collector.insert_with_ordinal(min_doc, score, ord) {
                    docs_scored += 1;
                } else {
                    docs_skipped += 1;
                }
            } else if !ordinal_scores.is_empty() {
                if ordinal_scores.len() > 2 {
                    ordinal_scores.sort_unstable_by_key(|(ord, _)| *ord);
                } else if ordinal_scores[0].0 > ordinal_scores[1].0 {
                    ordinal_scores.swap(0, 1);
                }
                let mut j = 0;
                while j < ordinal_scores.len() {
                    let current_ord = ordinal_scores[j].0;
                    let mut score = 0.0f32;
                    while j < ordinal_scores.len() && ordinal_scores[j].0 == current_ord {
                        score += ordinal_scores[j].1;
                        j += 1;
                    }
                    if $self
                        .collector
                        .insert_with_ordinal(min_doc, score, current_ord)
                    {
                        docs_scored += 1;
                    } else {
                        docs_skipped += 1;
                    }
                }
            }
        }

        let results: Vec<ScoredDoc> = $self
            .collector
            .into_sorted_results()
            .into_iter()
            .map(|(doc_id, score, ordinal)| ScoredDoc {
                doc_id,
                score,
                ordinal,
            })
            .collect();

        debug!(
            "MaxScoreExecutor: scored={}, skipped={}, blocks_skipped={}, conjunction_skipped={}, returned={}, top_score={:.4}",
            docs_scored,
            docs_skipped,
            blocks_skipped,
            conjunction_skipped,
            results.len(),
            results.first().map(|r| r.score).unwrap_or(0.0)
        );

        Ok(results)
    }};
}
819
820impl<'a> MaxScoreExecutor<'a> {
821    /// Create a new executor from pre-built cursors.
822    ///
823    /// Cursors are sorted by max_score ascending (non-essential first) and
824    /// prefix sums are computed for the MaxScore partitioning.
825    pub(crate) fn new(mut cursors: Vec<TermCursor<'a>>, k: usize, heap_factor: f32) -> Self {
826        // Sort by max_score ascending (non-essential first)
827        cursors.sort_by(|a, b| {
828            a.max_score
829                .partial_cmp(&b.max_score)
830                .unwrap_or(Ordering::Equal)
831        });
832
833        let mut prefix_sums = Vec::with_capacity(cursors.len());
834        let mut cumsum = 0.0f32;
835        for c in &cursors {
836            cumsum += c.max_score;
837            prefix_sums.push(cumsum);
838        }
839
840        debug!(
841            "Creating MaxScoreExecutor: num_cursors={}, k={}, total_upper={:.4}, heap_factor={:.2}",
842            cursors.len(),
843            k,
844            cumsum,
845            heap_factor
846        );
847
848        Self {
849            cursors,
850            prefix_sums,
851            collector: ScoreCollector::new(k),
852            heap_factor: heap_factor.clamp(0.0, 1.0),
853            predicate: None,
854        }
855    }
856
857    /// Create an executor for sparse vector queries.
858    ///
859    /// Builds `TermCursor::Sparse` for each matched dimension.
860    pub fn sparse(
861        sparse_index: &'a crate::segment::SparseIndex,
862        query_terms: Vec<(u32, f32)>,
863        k: usize,
864        heap_factor: f32,
865    ) -> Self {
866        let cursors: Vec<TermCursor<'a>> = query_terms
867            .iter()
868            .filter_map(|&(dim_id, qw)| {
869                let (skip_start, skip_count, global_max, block_data_offset) =
870                    sparse_index.get_skip_range_full(dim_id)?;
871                Some(TermCursor::sparse(
872                    sparse_index,
873                    qw,
874                    skip_start,
875                    skip_count,
876                    global_max,
877                    block_data_offset,
878                ))
879            })
880            .collect();
881        Self::new(cursors, k, heap_factor)
882    }
883
884    /// Create an executor for full-text BM25 queries.
885    ///
886    /// Builds `TermCursor::Text` for each posting list.
887    pub fn text(
888        posting_lists: Vec<(crate::structures::BlockPostingList, f32)>,
889        avg_field_len: f32,
890        k: usize,
891    ) -> Self {
892        let cursors: Vec<TermCursor<'a>> = posting_lists
893            .into_iter()
894            .map(|(pl, idf)| TermCursor::text(pl, idf, avg_field_len))
895            .collect();
896        Self::new(cursors, k, 1.0)
897    }
898
899    #[inline]
900    fn find_partition(&self) -> usize {
901        let threshold = self.collector.threshold() * self.heap_factor;
902        self.prefix_sums.partition_point(|&sum| sum <= threshold)
903    }
904
905    /// Attach a per-doc predicate filter to this executor.
906    ///
907    /// Docs failing the predicate are skipped after block-max pruning but
908    /// before scoring. The predicate does not affect thresholds or block-max
909    /// comparisons — the heap stores pure sparse/text scores.
910    pub fn with_predicate(mut self, predicate: super::DocPredicate<'a>) -> Self {
911        self.predicate = Some(predicate);
912        self
913    }
914
915    /// Execute Block-Max MaxScore and return top-k results (async).
916    pub async fn execute(mut self) -> crate::Result<Vec<ScoredDoc>> {
917        if self.cursors.is_empty() {
918            return Ok(Vec::new());
919        }
920        bms_execute_loop!(self, ensure_block_loaded, advance, seek, .await)
921    }
922
923    /// Synchronous execution — works when all cursors are text or mmap-backed sparse.
924    pub fn execute_sync(mut self) -> crate::Result<Vec<ScoredDoc>> {
925        if self.cursors.is_empty() {
926            return Ok(Vec::new());
927        }
928        bms_execute_loop!(self, ensure_block_loaded_sync, advance_sync, seek_sync,)
929    }
930}
931
#[cfg(test)]
mod tests {
    use super::*;

    /// Filling a k=3 collector and then inserting a higher score evicts the
    /// current minimum, raises the threshold, and keeps results ranked.
    #[test]
    fn test_score_collector_basic() {
        let mut top3 = ScoreCollector::new(3);

        for (doc, score) in [(1, 1.0), (2, 2.0), (3, 3.0)] {
            top3.insert(doc, score);
        }
        assert_eq!(top3.threshold(), 1.0);

        top3.insert(4, 4.0);
        assert_eq!(top3.threshold(), 2.0);

        let ranked = top3.into_sorted_results();
        let ids: Vec<_> = ranked.iter().map(|r| r.0).collect();
        // Highest score first; doc 1 was evicted.
        assert_eq!(ids, vec![4, 3, 2]);
    }

    /// `would_enter` and `insert` agree on rejecting scores below the floor
    /// and accepting scores above it.
    #[test]
    fn test_score_collector_threshold() {
        let mut top2 = ScoreCollector::new(2);
        top2.insert(1, 5.0);
        top2.insert(2, 3.0);
        assert_eq!(top2.threshold(), 3.0);

        // Below the current floor of 3.0: rejected.
        let low = 2.0;
        assert!(!top2.would_enter(low));
        assert!(!top2.insert(3, low));

        // Above the floor: accepted, displacing the 3.0 entry.
        let high = 4.0;
        assert!(top2.would_enter(high));
        assert!(top2.insert(4, high));
        assert_eq!(top2.threshold(), 4.0);
    }

    /// `HeapEntry`'s reversed ordering turns `BinaryHeap` into a min-heap on score.
    #[test]
    fn test_heap_entry_ordering() {
        let entry = |doc_id, score| HeapEntry {
            doc_id,
            score,
            ordinal: 0,
        };
        let mut heap = BinaryHeap::from(vec![entry(1, 3.0), entry(2, 1.0), entry(3, 2.0)]);

        let mut drained: Vec<f32> = Vec::new();
        while let Some(top) = heap.pop() {
            drained.push(top.score);
        }

        // Lowest score surfaces first.
        assert_eq!(drained, vec![1.0, 2.0, 3.0]);
    }
}