Skip to main content

hermes_core/query/
docset.rs

1//! DocSet trait and concrete implementations for document iteration.
2//!
3//! `DocSet` is the base abstraction for forward-only cursors over sorted document IDs.
4//! Posting lists, filter results, and scorers all implement this trait.
5//! `IntersectionScorer` intersects a Scorer with a DocSet filter, driving from
6//! the smaller side by `size_hint`.
7
8use std::sync::Arc;
9
10use crate::DocId;
11use crate::structures::TERMINATED;
12
13// ── DocSet trait ─────────────────────────────────────────────────────────
14
15macro_rules! define_docset_trait {
16    ($($send_bounds:tt)*) => {
17        /// Forward-only cursor over sorted document IDs.
18        ///
19        /// This is the base iteration abstraction. Posting lists, filter cursors,
20        /// and scorers all implement this trait.
21        pub trait DocSet: $($send_bounds)* {
22            /// Current document ID, or [`TERMINATED`] if exhausted.
23            fn doc(&self) -> DocId;
24
25            /// Advance to the next document. Returns the new doc ID or [`TERMINATED`].
26            fn advance(&mut self) -> DocId;
27
28            /// Seek to the first document >= `target`. Returns doc ID or [`TERMINATED`].
29            fn seek(&mut self, target: DocId) -> DocId {
30                let mut doc = self.doc();
31                while doc < target {
32                    doc = self.advance();
33                }
34                doc
35            }
36
37            /// Estimated number of remaining documents.
38            fn size_hint(&self) -> u32;
39        }
40    };
41}
42
43#[cfg(not(target_arch = "wasm32"))]
44define_docset_trait!(Send + Sync);
45
46#[cfg(target_arch = "wasm32")]
47define_docset_trait!();
48
49// ── DocSet for Box<dyn DocSet> ───────────────────────────────────────────
50
51impl DocSet for Box<dyn DocSet + '_> {
52    #[inline]
53    fn doc(&self) -> DocId {
54        (**self).doc()
55    }
56    #[inline]
57    fn advance(&mut self) -> DocId {
58        (**self).advance()
59    }
60    #[inline]
61    fn seek(&mut self, target: DocId) -> DocId {
62        (**self).seek(target)
63    }
64    #[inline]
65    fn size_hint(&self) -> u32 {
66        (**self).size_hint()
67    }
68}
69
70// ── SortedVecDocSet ──────────────────────────────────────────────────────
71
/// DocSet backed by a sorted `Vec<u32>`. Binary search for seek.
///
/// The ID list is held behind an [`Arc`], so cloning the cursor copies only
/// the handle and the position, never the IDs themselves.
#[derive(Debug, Clone)]
pub struct SortedVecDocSet {
    /// Document IDs to iterate over; must be sorted ascending because `seek`
    /// binary-searches them.
    docs: Arc<Vec<u32>>,
    /// Index into `docs` of the current document.
    pos: usize,
}

impl SortedVecDocSet {
    /// Creates a cursor positioned on the first element of `docs`.
    ///
    /// `docs` is expected to be sorted in ascending order; this is not
    /// checked here.
    pub fn new(docs: Arc<Vec<u32>>) -> Self {
        Self { docs, pos: 0 }
    }
}
83
84impl DocSet for SortedVecDocSet {
85    #[inline]
86    fn doc(&self) -> DocId {
87        self.docs.get(self.pos).copied().unwrap_or(TERMINATED)
88    }
89
90    #[inline]
91    fn advance(&mut self) -> DocId {
92        if self.pos < self.docs.len() {
93            self.pos += 1;
94        }
95        self.doc()
96    }
97
98    fn seek(&mut self, target: DocId) -> DocId {
99        if self.pos >= self.docs.len() {
100            return TERMINATED;
101        }
102        let remaining = &self.docs[self.pos..];
103        match remaining.binary_search(&target) {
104            Ok(offset) => {
105                self.pos += offset;
106                self.docs[self.pos]
107            }
108            Err(offset) => {
109                self.pos += offset;
110                self.doc()
111            }
112        }
113    }
114
115    fn size_hint(&self) -> u32 {
116        self.docs.len().saturating_sub(self.pos) as u32
117    }
118}
119
120// ── IntersectionDocSet ───────────────────────────────────────────────────
121
122/// DocSet that yields the intersection of two DocSets.
123pub struct IntersectionDocSet<A: DocSet, B: DocSet> {
124    a: A,
125    b: B,
126}
127
128impl<A: DocSet, B: DocSet> IntersectionDocSet<A, B> {
129    pub fn new(mut a: A, mut b: B) -> Self {
130        // Align both on the first common doc
131        let mut da = a.doc();
132        let mut db = b.doc();
133        loop {
134            if da == TERMINATED || db == TERMINATED {
135                break;
136            }
137            if da == db {
138                break;
139            }
140            if da < db {
141                da = a.seek(db);
142            } else {
143                db = b.seek(da);
144            }
145        }
146        Self { a, b }
147    }
148}
149
150impl<A: DocSet, B: DocSet> DocSet for IntersectionDocSet<A, B> {
151    fn doc(&self) -> DocId {
152        let da = self.a.doc();
153        if da == TERMINATED || self.b.doc() == TERMINATED {
154            TERMINATED
155        } else {
156            da
157        }
158    }
159
160    fn advance(&mut self) -> DocId {
161        let mut da = self.a.advance();
162        let mut db = self.b.doc();
163        loop {
164            if da == TERMINATED || db == TERMINATED {
165                return TERMINATED;
166            }
167            if da == db {
168                return da;
169            }
170            if da < db {
171                da = self.a.seek(db);
172            } else {
173                db = self.b.seek(da);
174            }
175        }
176    }
177
178    fn seek(&mut self, target: DocId) -> DocId {
179        let mut da = self.a.seek(target);
180        let mut db = self.b.seek(target);
181        loop {
182            if da == TERMINATED || db == TERMINATED {
183                return TERMINATED;
184            }
185            if da == db {
186                return da;
187            }
188            if da < db {
189                da = self.a.seek(db);
190            } else {
191                db = self.b.seek(da);
192            }
193        }
194    }
195
196    fn size_hint(&self) -> u32 {
197        self.a.size_hint().min(self.b.size_hint())
198    }
199}
200
201// ── AllDocSet ────────────────────────────────────────────────────────────
202
/// DocSet that yields all documents 0..num_docs.
#[derive(Debug, Clone)]
pub struct AllDocSet {
    /// Next document ID to report; the set is exhausted once this reaches
    /// `num_docs`.
    current: u32,
    /// Exclusive upper bound of the ID range.
    num_docs: u32,
}

impl AllDocSet {
    /// Creates a cursor over `0..num_docs`, positioned on document 0.
    pub fn new(num_docs: u32) -> Self {
        Self {
            current: 0,
            num_docs,
        }
    }
}
217
218impl DocSet for AllDocSet {
219    #[inline]
220    fn doc(&self) -> DocId {
221        if self.current >= self.num_docs {
222            TERMINATED
223        } else {
224            self.current
225        }
226    }
227
228    #[inline]
229    fn advance(&mut self) -> DocId {
230        self.current += 1;
231        self.doc()
232    }
233
234    #[inline]
235    fn seek(&mut self, target: DocId) -> DocId {
236        self.current = target;
237        self.doc()
238    }
239
240    fn size_hint(&self) -> u32 {
241        self.num_docs.saturating_sub(self.current)
242    }
243}
244
245// ── EmptyDocSet ──────────────────────────────────────────────────────────
246
247/// DocSet that is always empty.
248pub struct EmptyDocSet;
249
250impl DocSet for EmptyDocSet {
251    #[inline]
252    fn doc(&self) -> DocId {
253        TERMINATED
254    }
255    #[inline]
256    fn advance(&mut self) -> DocId {
257        TERMINATED
258    }
259    #[inline]
260    fn seek(&mut self, _target: DocId) -> DocId {
261        TERMINATED
262    }
263    fn size_hint(&self) -> u32 {
264        0
265    }
266}
267
268// ── IntersectionScorer ───────────────────────────────────────────────────
269
270/// Intersects a Scorer with a filter DocSet, driving from the smaller side.
271///
272/// This is the core composition primitive: filter queries create a filter DocSet,
273/// get the inner Scorer, and return `IntersectionScorer(scorer, filter)`.
274pub struct IntersectionScorer<'a> {
275    scorer: Box<dyn super::Scorer + 'a>,
276    filter: Box<dyn DocSet + 'a>,
277}
278
279impl<'a> IntersectionScorer<'a> {
280    pub fn new(mut scorer: Box<dyn super::Scorer + 'a>, mut filter: Box<dyn DocSet + 'a>) -> Self {
281        // Align both on first common doc
282        let mut ds = scorer.doc();
283        let mut df = filter.doc();
284        loop {
285            if ds == TERMINATED || df == TERMINATED {
286                break;
287            }
288            if ds == df {
289                break;
290            }
291            if ds < df {
292                ds = scorer.seek(df);
293            } else {
294                df = filter.seek(ds);
295            }
296        }
297        Self { scorer, filter }
298    }
299}
300
301impl DocSet for IntersectionScorer<'_> {
302    fn doc(&self) -> DocId {
303        let ds = self.scorer.doc();
304        if ds == TERMINATED || self.filter.doc() == TERMINATED {
305            TERMINATED
306        } else {
307            ds
308        }
309    }
310
311    fn advance(&mut self) -> DocId {
312        // Drive from the smaller side
313        let filter_smaller = self.filter.size_hint() < self.scorer.size_hint();
314
315        if filter_smaller {
316            // Filter drives
317            let mut df = self.filter.advance();
318            let mut ds = self.scorer.doc();
319            loop {
320                if df == TERMINATED || ds == TERMINATED {
321                    return TERMINATED;
322                }
323                if df == ds {
324                    return df;
325                }
326                if df < ds {
327                    df = self.filter.seek(ds);
328                } else {
329                    ds = self.scorer.seek(df);
330                }
331            }
332        } else {
333            // Scorer drives
334            let mut ds = self.scorer.advance();
335            let mut df = self.filter.doc();
336            loop {
337                if ds == TERMINATED || df == TERMINATED {
338                    return TERMINATED;
339                }
340                if ds == df {
341                    return ds;
342                }
343                if ds < df {
344                    ds = self.scorer.seek(df);
345                } else {
346                    df = self.filter.seek(ds);
347                }
348            }
349        }
350    }
351
352    fn seek(&mut self, target: DocId) -> DocId {
353        let mut ds = self.scorer.seek(target);
354        let mut df = self.filter.seek(target);
355        loop {
356            if ds == TERMINATED || df == TERMINATED {
357                return TERMINATED;
358            }
359            if ds == df {
360                return ds;
361            }
362            if ds < df {
363                ds = self.scorer.seek(df);
364            } else {
365                df = self.filter.seek(ds);
366            }
367        }
368    }
369
370    fn size_hint(&self) -> u32 {
371        self.scorer.size_hint().min(self.filter.size_hint())
372    }
373}
374
375impl super::Scorer for IntersectionScorer<'_> {
376    fn score(&self) -> crate::Score {
377        self.scorer.score()
378    }
379
380    fn matched_positions(&self) -> Option<super::MatchedPositions> {
381        self.scorer.matched_positions()
382    }
383}
384
385// ── PredicatedScorer ─────────────────────────────────────────────────────
386
387/// Wraps a driving Scorer with filter conditions pushed down.
388///
389/// Used by the query planner to flip iteration order: the SHOULD scorer
390/// drives and MUST/MUST_NOT clauses are checked per-doc via:
391/// - O(1) predicate closures (e.g., fast-field range checks)
392/// - seek()-based verifier scorers (e.g., TermQuery posting list lookups)
393///
394/// `filter_score` is a constant added for predicate-converted filters.
395/// Verifier scorers contribute their actual per-doc score (e.g., BM25).
396pub struct PredicatedScorer<'a> {
397    /// Driving scorer (typically SHOULD clauses)
398    driver: Box<dyn super::Scorer + 'a>,
399    /// O(1) predicate checks (from filter queries like RangeQuery, or negated MUST_NOT)
400    predicates: Vec<super::DocPredicate<'a>>,
401    /// MUST scorers verified via seek() — preserves per-doc scoring
402    must_verifiers: Vec<Box<dyn super::Scorer + 'a>>,
403    /// MUST_NOT scorers — docs are excluded if these land on them
404    must_not_verifiers: Vec<Box<dyn super::Scorer + 'a>>,
405    /// Constant score from predicate-converted filters (1.0 per filter)
406    filter_score: f32,
407}
408
409impl<'a> PredicatedScorer<'a> {
410    pub fn new(
411        driver: Box<dyn super::Scorer + 'a>,
412        predicates: Vec<super::DocPredicate<'a>>,
413        must_verifiers: Vec<Box<dyn super::Scorer + 'a>>,
414        must_not_verifiers: Vec<Box<dyn super::Scorer + 'a>>,
415        filter_score: f32,
416    ) -> Self {
417        let mut s = Self {
418            driver,
419            predicates,
420            must_verifiers,
421            must_not_verifiers,
422            filter_score,
423        };
424        // Position on first matching doc
425        s.skip_non_matching();
426        s
427    }
428
429    /// Check whether `doc` passes all filter conditions.
430    #[inline]
431    fn check_filters(&mut self, doc: DocId) -> bool {
432        // O(1) predicate checks first (cheapest)
433        if !self.predicates.iter().all(|p| p(doc)) {
434            return false;
435        }
436        // MUST verifiers: seek to doc, must land exactly on it
437        if !self.must_verifiers.iter_mut().all(|s| s.seek(doc) == doc) {
438            return false;
439        }
440        // MUST_NOT verifiers: seek to doc, must NOT land on it
441        self.must_not_verifiers
442            .iter_mut()
443            .all(|s| s.seek(doc) != doc)
444    }
445
446    /// Advance driver past non-matching docs.
447    fn skip_non_matching(&mut self) -> DocId {
448        let mut doc = self.driver.doc();
449        while doc != TERMINATED && !self.check_filters(doc) {
450            doc = self.driver.advance();
451        }
452        doc
453    }
454}
455
456impl DocSet for PredicatedScorer<'_> {
457    fn doc(&self) -> DocId {
458        self.driver.doc()
459    }
460
461    fn advance(&mut self) -> DocId {
462        self.driver.advance();
463        self.skip_non_matching()
464    }
465
466    fn seek(&mut self, target: DocId) -> DocId {
467        self.driver.seek(target);
468        self.skip_non_matching()
469    }
470
471    fn size_hint(&self) -> u32 {
472        self.driver.size_hint()
473    }
474}
475
476impl super::Scorer for PredicatedScorer<'_> {
477    fn score(&self) -> crate::Score {
478        let mut total = self.driver.score();
479        for v in &self.must_verifiers {
480            total += v.score();
481        }
482        total + self.filter_score
483    }
484
485    fn matched_positions(&self) -> Option<super::MatchedPositions> {
486        let mut all: super::MatchedPositions = Vec::new();
487        if let Some(p) = self.driver.matched_positions() {
488            all.extend(p);
489        }
490        for v in &self.must_verifiers {
491            if let Some(p) = v.matched_positions() {
492                all.extend(p);
493            }
494        }
495        if all.is_empty() { None } else { Some(all) }
496    }
497}
498
499// ── Tests ────────────────────────────────────────────────────────────────
500
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `SortedVecDocSet` over a copy of `ids`.
    fn sorted(ids: &[u32]) -> SortedVecDocSet {
        SortedVecDocSet::new(Arc::new(ids.to_vec()))
    }

    /// Mixed advance/seek walk over a small sorted list, through exhaustion.
    #[test]
    fn test_sorted_vec_docset_basic() {
        let mut cursor = sorted(&[1, 3, 5, 7, 9]);

        assert_eq!(cursor.doc(), 1);
        assert_eq!(cursor.advance(), 3);
        assert_eq!(cursor.advance(), 5);
        assert_eq!(cursor.seek(7), 7);
        assert_eq!(cursor.advance(), 9);
        assert_eq!(cursor.advance(), TERMINATED);
        assert_eq!(cursor.doc(), TERMINATED);
    }

    /// Seeking to absent IDs lands on the next larger one, or terminates.
    #[test]
    fn test_sorted_vec_docset_seek_past() {
        let mut cursor = sorted(&[1, 5, 10, 20]);

        assert_eq!(cursor.seek(3), 5);
        assert_eq!(cursor.seek(15), 20);
        assert_eq!(cursor.seek(21), TERMINATED);
    }

    /// An empty list is terminated from the start.
    #[test]
    fn test_sorted_vec_docset_empty() {
        let cursor = sorted(&[]);
        assert_eq!(cursor.doc(), TERMINATED);
    }

    /// `AllDocSet` walks 0..num_docs and then terminates.
    #[test]
    fn test_all_docset() {
        let mut all = AllDocSet::new(3);
        assert_eq!(all.doc(), 0);
        assert_eq!(all.advance(), 1);
        assert_eq!(all.advance(), 2);
        assert_eq!(all.advance(), TERMINATED);
    }

    /// Seeking inside the range lands exactly on the target; at the bound it
    /// terminates.
    #[test]
    fn test_all_docset_seek() {
        let mut all = AllDocSet::new(10);
        assert_eq!(all.seek(5), 5);
        assert_eq!(all.seek(9), 9);
        assert_eq!(all.seek(10), TERMINATED);
    }

    /// Every operation on `EmptyDocSet` reports termination.
    #[test]
    fn test_empty_docset() {
        let mut empty = EmptyDocSet;
        assert_eq!(empty.doc(), TERMINATED);
        assert_eq!(empty.advance(), TERMINATED);
        assert_eq!(empty.seek(5), TERMINATED);
        assert_eq!(empty.size_hint(), 0);
    }

    /// The intersection yields exactly the IDs present in both inputs.
    #[test]
    fn test_intersection_docset() {
        let left = sorted(&[1, 3, 5, 7, 9]);
        let right = sorted(&[2, 3, 5, 8, 9, 10]);
        let mut both = IntersectionDocSet::new(left, right);

        assert_eq!(both.doc(), 3);
        assert_eq!(both.advance(), 5);
        assert_eq!(both.advance(), 9);
        assert_eq!(both.advance(), TERMINATED);
    }

    /// Disjoint inputs produce an intersection that is terminated at once.
    #[test]
    fn test_intersection_docset_empty() {
        let left = sorted(&[1, 3, 5]);
        let right = sorted(&[2, 4, 6]);
        let both = IntersectionDocSet::new(left, right);
        assert_eq!(both.doc(), TERMINATED);
    }

    /// Seeking the intersection lands on the next common ID at or after the
    /// target.
    #[test]
    fn test_intersection_docset_seek() {
        let left = sorted(&[1, 5, 10, 20, 30]);
        let right = sorted(&[5, 10, 15, 20, 25, 30]);
        let mut both = IntersectionDocSet::new(left, right);

        assert_eq!(both.doc(), 5);
        assert_eq!(both.seek(15), 20);
        assert_eq!(both.advance(), 30);
        assert_eq!(both.advance(), TERMINATED);
    }

    /// `size_hint` tracks the number of IDs left from the cursor onwards.
    #[test]
    fn test_size_hint() {
        let mut cursor = sorted(&[1, 2, 3, 4, 5]);
        assert_eq!(cursor.size_hint(), 5);
        cursor.advance();
        assert_eq!(cursor.size_hint(), 4);
        cursor.seek(4);
        assert_eq!(cursor.size_hint(), 2); // pos=3, remaining: [4, 5]
    }
}