// hermes_core/query/docset.rs

//! DocSet trait and concrete implementations for document iteration.
//!
//! `DocSet` is the base abstraction for forward-only cursors over sorted document IDs.
//! Posting lists, filter results, and scorers all implement this trait.
//! `PredicatedScorer` wraps a driving Scorer with pushed-down filter predicates.

7use std::sync::Arc;
8
9use crate::DocId;
10use crate::structures::TERMINATED;
11
12// ── DocSet trait ─────────────────────────────────────────────────────────
13
14macro_rules! define_docset_trait {
15    ($($send_bounds:tt)*) => {
16        /// Forward-only cursor over sorted document IDs.
17        ///
18        /// This is the base iteration abstraction. Posting lists, filter cursors,
19        /// and scorers all implement this trait.
20        pub trait DocSet: $($send_bounds)* {
21            /// Current document ID, or [`TERMINATED`] if exhausted.
22            fn doc(&self) -> DocId;
23
24            /// Advance to the next document. Returns the new doc ID or [`TERMINATED`].
25            fn advance(&mut self) -> DocId;
26
27            /// Seek to the first document >= `target`. Returns doc ID or [`TERMINATED`].
28            fn seek(&mut self, target: DocId) -> DocId {
29                let mut doc = self.doc();
30                while doc < target {
31                    doc = self.advance();
32                }
33                doc
34            }
35
36            /// Estimated number of remaining documents.
37            fn size_hint(&self) -> u32;
38        }
39    };
40}
41
42#[cfg(not(target_arch = "wasm32"))]
43define_docset_trait!(Send + Sync);
44
45#[cfg(target_arch = "wasm32")]
46define_docset_trait!();
47
48// ── DocSet for Box<dyn DocSet> ───────────────────────────────────────────
49
50impl DocSet for Box<dyn DocSet + '_> {
51    #[inline]
52    fn doc(&self) -> DocId {
53        (**self).doc()
54    }
55    #[inline]
56    fn advance(&mut self) -> DocId {
57        (**self).advance()
58    }
59    #[inline]
60    fn seek(&mut self, target: DocId) -> DocId {
61        (**self).seek(target)
62    }
63    #[inline]
64    fn size_hint(&self) -> u32 {
65        (**self).size_hint()
66    }
67}
68
69// ── SortedVecDocSet ──────────────────────────────────────────────────────
70
71/// DocSet backed by a sorted `Vec<u32>`. Binary search for seek.
72pub struct SortedVecDocSet {
73    docs: Arc<Vec<u32>>,
74    pos: usize,
75}
76
77impl SortedVecDocSet {
78    pub fn new(docs: Arc<Vec<u32>>) -> Self {
79        Self { docs, pos: 0 }
80    }
81}
82
83impl DocSet for SortedVecDocSet {
84    #[inline]
85    fn doc(&self) -> DocId {
86        self.docs.get(self.pos).copied().unwrap_or(TERMINATED)
87    }
88
89    #[inline]
90    fn advance(&mut self) -> DocId {
91        if self.pos < self.docs.len() {
92            self.pos += 1;
93        }
94        self.doc()
95    }
96
97    fn seek(&mut self, target: DocId) -> DocId {
98        if self.pos >= self.docs.len() {
99            return TERMINATED;
100        }
101        let remaining = &self.docs[self.pos..];
102        match remaining.binary_search(&target) {
103            Ok(offset) => {
104                self.pos += offset;
105                self.docs[self.pos]
106            }
107            Err(offset) => {
108                self.pos += offset;
109                self.doc()
110            }
111        }
112    }
113
114    fn size_hint(&self) -> u32 {
115        self.docs.len().saturating_sub(self.pos) as u32
116    }
117}
118
119// ── IntersectionDocSet ───────────────────────────────────────────────────
120
121/// DocSet that yields the intersection of two DocSets.
122pub struct IntersectionDocSet<A: DocSet, B: DocSet> {
123    a: A,
124    b: B,
125}
126
127impl<A: DocSet, B: DocSet> IntersectionDocSet<A, B> {
128    pub fn new(mut a: A, mut b: B) -> Self {
129        // Align both on the first common doc
130        let mut da = a.doc();
131        let mut db = b.doc();
132        loop {
133            if da == TERMINATED || db == TERMINATED {
134                break;
135            }
136            if da == db {
137                break;
138            }
139            if da < db {
140                da = a.seek(db);
141            } else {
142                db = b.seek(da);
143            }
144        }
145        Self { a, b }
146    }
147}
148
149impl<A: DocSet, B: DocSet> DocSet for IntersectionDocSet<A, B> {
150    fn doc(&self) -> DocId {
151        let da = self.a.doc();
152        if da == TERMINATED || self.b.doc() == TERMINATED {
153            TERMINATED
154        } else {
155            da
156        }
157    }
158
159    fn advance(&mut self) -> DocId {
160        let mut da = self.a.advance();
161        let mut db = self.b.doc();
162        loop {
163            if da == TERMINATED || db == TERMINATED {
164                return TERMINATED;
165            }
166            if da == db {
167                return da;
168            }
169            if da < db {
170                da = self.a.seek(db);
171            } else {
172                db = self.b.seek(da);
173            }
174        }
175    }
176
177    fn seek(&mut self, target: DocId) -> DocId {
178        let mut da = self.a.seek(target);
179        let mut db = self.b.seek(target);
180        loop {
181            if da == TERMINATED || db == TERMINATED {
182                return TERMINATED;
183            }
184            if da == db {
185                return da;
186            }
187            if da < db {
188                da = self.a.seek(db);
189            } else {
190                db = self.b.seek(da);
191            }
192        }
193    }
194
195    fn size_hint(&self) -> u32 {
196        self.a.size_hint().min(self.b.size_hint())
197    }
198}
199
200// ── AllDocSet ────────────────────────────────────────────────────────────
201
202/// DocSet that yields all documents 0..num_docs.
203pub struct AllDocSet {
204    current: u32,
205    num_docs: u32,
206}
207
208impl AllDocSet {
209    pub fn new(num_docs: u32) -> Self {
210        Self {
211            current: 0,
212            num_docs,
213        }
214    }
215}
216
217impl DocSet for AllDocSet {
218    #[inline]
219    fn doc(&self) -> DocId {
220        if self.current >= self.num_docs {
221            TERMINATED
222        } else {
223            self.current
224        }
225    }
226
227    #[inline]
228    fn advance(&mut self) -> DocId {
229        self.current += 1;
230        self.doc()
231    }
232
233    #[inline]
234    fn seek(&mut self, target: DocId) -> DocId {
235        self.current = target;
236        self.doc()
237    }
238
239    fn size_hint(&self) -> u32 {
240        self.num_docs.saturating_sub(self.current)
241    }
242}
243
244// ── EmptyDocSet ──────────────────────────────────────────────────────────
245
246/// DocSet that is always empty.
247pub struct EmptyDocSet;
248
249impl DocSet for EmptyDocSet {
250    #[inline]
251    fn doc(&self) -> DocId {
252        TERMINATED
253    }
254    #[inline]
255    fn advance(&mut self) -> DocId {
256        TERMINATED
257    }
258    #[inline]
259    fn seek(&mut self, _target: DocId) -> DocId {
260        TERMINATED
261    }
262    fn size_hint(&self) -> u32 {
263        0
264    }
265}
266
267// ── PredicatedScorer ─────────────────────────────────────────────────────
268
269/// Wraps a driving Scorer with filter conditions pushed down.
270///
271/// Used by the query planner to flip iteration order: the SHOULD scorer
272/// drives and MUST/MUST_NOT clauses are checked per-doc via:
273/// - O(1) predicate closures (e.g., fast-field range checks)
274/// - seek()-based verifier scorers (e.g., TermQuery posting list lookups)
275///
276/// Verifier scorers contribute their actual per-doc score (e.g., BM25).
277pub struct PredicatedScorer<'a> {
278    /// Driving scorer (typically SHOULD clauses)
279    driver: Box<dyn super::Scorer + 'a>,
280    /// O(1) predicate checks (from filter queries like RangeQuery, or negated MUST_NOT)
281    predicates: Vec<super::DocPredicate<'a>>,
282    /// MUST scorers verified via seek() — preserves per-doc scoring
283    must_verifiers: Vec<Box<dyn super::Scorer + 'a>>,
284    /// MUST_NOT scorers — docs are excluded if these land on them
285    must_not_verifiers: Vec<Box<dyn super::Scorer + 'a>>,
286}
287
288impl<'a> PredicatedScorer<'a> {
289    pub fn new(
290        driver: Box<dyn super::Scorer + 'a>,
291        predicates: Vec<super::DocPredicate<'a>>,
292        must_verifiers: Vec<Box<dyn super::Scorer + 'a>>,
293        must_not_verifiers: Vec<Box<dyn super::Scorer + 'a>>,
294    ) -> Self {
295        let mut s = Self {
296            driver,
297            predicates,
298            must_verifiers,
299            must_not_verifiers,
300        };
301        // Position on first matching doc
302        s.skip_non_matching();
303        s
304    }
305
306    /// Check whether `doc` passes all filter conditions.
307    #[inline]
308    fn check_filters(&mut self, doc: DocId) -> bool {
309        // O(1) predicate checks first (cheapest)
310        if !self.predicates.iter().all(|p| p(doc)) {
311            return false;
312        }
313        // MUST verifiers: seek to doc, must land exactly on it
314        if !self.must_verifiers.iter_mut().all(|s| s.seek(doc) == doc) {
315            return false;
316        }
317        // MUST_NOT verifiers: seek to doc, must NOT land on it
318        self.must_not_verifiers
319            .iter_mut()
320            .all(|s| s.seek(doc) != doc)
321    }
322
323    /// Advance driver past non-matching docs.
324    fn skip_non_matching(&mut self) -> DocId {
325        let mut doc = self.driver.doc();
326        while doc != TERMINATED && !self.check_filters(doc) {
327            doc = self.driver.advance();
328        }
329        doc
330    }
331}
332
333impl DocSet for PredicatedScorer<'_> {
334    fn doc(&self) -> DocId {
335        self.driver.doc()
336    }
337
338    fn advance(&mut self) -> DocId {
339        self.driver.advance();
340        self.skip_non_matching()
341    }
342
343    fn seek(&mut self, target: DocId) -> DocId {
344        self.driver.seek(target);
345        self.skip_non_matching()
346    }
347
348    fn size_hint(&self) -> u32 {
349        self.driver.size_hint()
350    }
351}
352
353impl super::Scorer for PredicatedScorer<'_> {
354    fn score(&self) -> crate::Score {
355        let mut total = self.driver.score();
356        for v in &self.must_verifiers {
357            total += v.score();
358        }
359        total
360    }
361
362    fn matched_positions(&self) -> Option<super::MatchedPositions> {
363        let mut all: super::MatchedPositions = Vec::new();
364        if let Some(p) = self.driver.matched_positions() {
365            all.extend(p);
366        }
367        for v in &self.must_verifiers {
368            if let Some(p) = v.matched_positions() {
369                all.extend(p);
370            }
371        }
372        if all.is_empty() { None } else { Some(all) }
373    }
374}
375
// ── Tests ────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    /// Convenience constructor: a `SortedVecDocSet` over a copy of `docs`.
    fn sorted(docs: &[u32]) -> SortedVecDocSet {
        SortedVecDocSet::new(Arc::new(docs.to_vec()))
    }

    #[test]
    fn test_sorted_vec_docset_basic() {
        let mut ds = sorted(&[1, 3, 5, 7, 9]);

        assert_eq!(ds.doc(), 1);
        assert_eq!(ds.advance(), 3);
        assert_eq!(ds.advance(), 5);
        assert_eq!(ds.seek(7), 7);
        assert_eq!(ds.advance(), 9);
        assert_eq!(ds.advance(), TERMINATED);
        assert_eq!(ds.doc(), TERMINATED);
    }

    #[test]
    fn test_sorted_vec_docset_seek_past() {
        let mut ds = sorted(&[1, 5, 10, 20]);
        // Each seek continues from the previous position and lands on the first doc >= target.
        for (target, expected) in [(3, 5), (15, 20), (21, TERMINATED)] {
            assert_eq!(ds.seek(target), expected, "seek({})", target);
        }
    }

    #[test]
    fn test_sorted_vec_docset_empty() {
        let ds = sorted(&[]);
        assert_eq!(ds.doc(), TERMINATED);
    }

    #[test]
    fn test_all_docset() {
        let mut ds = AllDocSet::new(3);
        assert_eq!(ds.doc(), 0);
        for expected in [1, 2, TERMINATED] {
            assert_eq!(ds.advance(), expected);
        }
    }

    #[test]
    fn test_all_docset_seek() {
        let mut ds = AllDocSet::new(10);
        for (target, expected) in [(5, 5), (9, 9), (10, TERMINATED)] {
            assert_eq!(ds.seek(target), expected, "seek({})", target);
        }
    }

    #[test]
    fn test_empty_docset() {
        let mut ds = EmptyDocSet;
        assert_eq!(ds.doc(), TERMINATED);
        assert_eq!(ds.advance(), TERMINATED);
        assert_eq!(ds.seek(5), TERMINATED);
        assert_eq!(ds.size_hint(), 0);
    }

    #[test]
    fn test_intersection_docset() {
        let mut isect = IntersectionDocSet::new(
            sorted(&[1, 3, 5, 7, 9]),
            sorted(&[2, 3, 5, 8, 9, 10]),
        );

        assert_eq!(isect.doc(), 3);
        for expected in [5, 9, TERMINATED] {
            assert_eq!(isect.advance(), expected);
        }
    }

    #[test]
    fn test_intersection_docset_empty() {
        let isect = IntersectionDocSet::new(sorted(&[1, 3, 5]), sorted(&[2, 4, 6]));
        assert_eq!(isect.doc(), TERMINATED);
    }

    #[test]
    fn test_intersection_docset_seek() {
        let mut isect = IntersectionDocSet::new(
            sorted(&[1, 5, 10, 20, 30]),
            sorted(&[5, 10, 15, 20, 25, 30]),
        );

        assert_eq!(isect.doc(), 5);
        assert_eq!(isect.seek(15), 20);
        assert_eq!(isect.advance(), 30);
        assert_eq!(isect.advance(), TERMINATED);
    }

    #[test]
    fn test_size_hint() {
        let mut ds = sorted(&[1, 2, 3, 4, 5]);
        assert_eq!(ds.size_hint(), 5);
        ds.advance();
        assert_eq!(ds.size_hint(), 4);
        ds.seek(4);
        assert_eq!(ds.size_hint(), 2); // pos=3, remaining: [4, 5]
    }
}