Skip to main content

hermes_core/query/
prefix.rs

//! Prefix query — matches all documents containing any term that starts with a
//! given prefix. Materializes the union of matching posting lists into a sorted,
//! deduplicated doc ID set, giving O(log N) seek via `SortedVecDocSet`. Score is
//! always 1.0 (filter-style, like `RangeQuery`).
5
6use std::sync::Arc;
7
8use crate::dsl::Field;
9use crate::segment::SegmentReader;
10use crate::structures::{BlockPostingList, TERMINATED};
11use crate::{DocId, Score};
12
13use super::docset::{DocSet, SortedVecDocSet};
14use super::traits::{CountFuture, EmptyScorer, Query, Scorer, ScorerFuture};
15
16/// Prefix query — matches documents containing any term starting with `prefix`.
17#[derive(Debug, Clone)]
18pub struct PrefixQuery {
19    pub field: Field,
20    pub prefix: Vec<u8>,
21}
22
23impl std::fmt::Display for PrefixQuery {
24    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25        write!(
26            f,
27            "Prefix({}:\"{}*\")",
28            self.field.0,
29            String::from_utf8_lossy(&self.prefix)
30        )
31    }
32}
33
34impl PrefixQuery {
35    /// Create from raw bytes.
36    pub fn new(field: Field, prefix: impl Into<Vec<u8>>) -> Self {
37        Self {
38            field,
39            prefix: prefix.into(),
40        }
41    }
42
43    /// Create from text — lowercased to match default tokenization.
44    pub fn text(field: Field, text: &str) -> Self {
45        Self {
46            field,
47            prefix: text.to_lowercase().into_bytes(),
48        }
49    }
50}
51
52impl Query for PrefixQuery {
53    fn scorer<'a>(&self, reader: &'a SegmentReader, _limit: usize) -> ScorerFuture<'a> {
54        let field = self.field;
55        let prefix = self.prefix.clone();
56        Box::pin(async move {
57            let postings = reader.get_prefix_postings(field, &prefix).await?;
58            if postings.is_empty() {
59                return Ok(Box::new(EmptyScorer) as Box<dyn Scorer>);
60            }
61            let docs = materialize_union(&postings);
62            if docs.is_empty() {
63                return Ok(Box::new(EmptyScorer) as Box<dyn Scorer>);
64            }
65            Ok(Box::new(PrefixScorer::new(docs)) as Box<dyn Scorer>)
66        })
67    }
68
69    #[cfg(feature = "sync")]
70    fn scorer_sync<'a>(
71        &self,
72        reader: &'a SegmentReader,
73        _limit: usize,
74    ) -> crate::Result<Box<dyn Scorer + 'a>> {
75        let postings = reader.get_prefix_postings_sync(self.field, &self.prefix)?;
76        if postings.is_empty() {
77            return Ok(Box::new(EmptyScorer) as Box<dyn Scorer>);
78        }
79        let docs = materialize_union(&postings);
80        if docs.is_empty() {
81            return Ok(Box::new(EmptyScorer) as Box<dyn Scorer>);
82        }
83        Ok(Box::new(PrefixScorer::new(docs)) as Box<dyn Scorer>)
84    }
85
86    fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
87        let field = self.field;
88        let prefix = self.prefix.clone();
89        Box::pin(async move {
90            let postings = reader.get_prefix_postings(field, &prefix).await?;
91            Ok(postings.iter().map(|p| p.doc_count()).sum())
92        })
93    }
94
95    fn is_filter(&self) -> bool {
96        true
97    }
98}
99
100// ── PrefixScorer ────────────────────────────────────────────────────────
101
102/// Scorer backed by a pre-materialized sorted doc ID set.
103struct PrefixScorer {
104    inner: SortedVecDocSet,
105}
106
107impl PrefixScorer {
108    fn new(docs: Vec<u32>) -> Self {
109        Self {
110            inner: SortedVecDocSet::new(Arc::new(docs)),
111        }
112    }
113}
114
115impl DocSet for PrefixScorer {
116    #[inline]
117    fn doc(&self) -> DocId {
118        self.inner.doc()
119    }
120
121    #[inline]
122    fn advance(&mut self) -> DocId {
123        self.inner.advance()
124    }
125
126    fn seek(&mut self, target: DocId) -> DocId {
127        self.inner.seek(target)
128    }
129
130    fn size_hint(&self) -> u32 {
131        self.inner.size_hint()
132    }
133}
134
135impl Scorer for PrefixScorer {
136    fn score(&self) -> Score {
137        1.0
138    }
139}
140
141// ── Helpers ─────────────────────────────────────────────────────────────
142
143/// Iterate all posting lists, collect doc IDs, sort, and deduplicate.
144fn materialize_union(postings: &[BlockPostingList]) -> Vec<u32> {
145    let total: usize = postings.iter().map(|p| p.doc_count() as usize).sum();
146    let mut docs = Vec::with_capacity(total);
147
148    for posting in postings {
149        let mut iter = posting.iterator();
150        loop {
151            let d = iter.doc();
152            if d == TERMINATED {
153                break;
154            }
155            docs.push(d);
156            iter.advance();
157        }
158    }
159
160    docs.sort_unstable();
161    docs.dedup();
162    docs
163}
164
#[cfg(test)]
mod tests {
    use super::*;

    /// Doc IDs shared by the scorer tests.
    const SAMPLE_DOCS: [u32; 4] = [1, 5, 10, 20];

    /// Fresh scorer positioned on the first entry of `SAMPLE_DOCS`.
    fn sample_scorer() -> PrefixScorer {
        PrefixScorer::new(SAMPLE_DOCS.to_vec())
    }

    #[test]
    fn test_materialize_union_empty() {
        let docs = materialize_union(&[]);
        assert!(docs.is_empty());
    }

    #[test]
    fn test_prefix_scorer_basic() {
        let mut scorer = sample_scorer();
        assert_eq!(scorer.doc(), 1);
        assert_eq!(scorer.score(), 1.0);
        assert_eq!(scorer.advance(), 5);
        assert_eq!(scorer.seek(10), 10);
        assert_eq!(scorer.advance(), 20);
        assert_eq!(scorer.advance(), TERMINATED);
    }

    #[test]
    fn test_prefix_scorer_seek_past() {
        let mut scorer = sample_scorer();
        // A target between two entries lands on the next larger one.
        assert_eq!(scorer.seek(7), 10);
        // A target beyond the last entry exhausts the scorer.
        assert_eq!(scorer.seek(100), TERMINATED);
    }

    #[test]
    fn test_prefix_query_display() {
        let q = PrefixQuery::text(Field(0), "abc");
        assert_eq!(format!("{}", q), "Prefix(0:\"abc*\")");
    }

    #[test]
    fn test_prefix_query_display_lossy_utf8() {
        // 0xFF is never valid UTF-8, so it is rendered as U+FFFD.
        let q = PrefixQuery::new(Field(0), vec![b'a', 0xff]);
        assert_eq!(format!("{}", q), "Prefix(0:\"a\u{FFFD}*\")");
    }

    #[test]
    fn test_prefix_query_text_lowercases() {
        let q = PrefixQuery::text(Field(0), "AbC");
        assert_eq!(q.prefix, b"abc".to_vec());
    }

    #[test]
    fn test_prefix_query_new_keeps_raw_bytes() {
        // Unlike `text`, `new` must not fold case.
        let q = PrefixQuery::new(Field(0), "AbC");
        assert_eq!(q.prefix, b"AbC".to_vec());
    }

    #[test]
    fn test_prefix_query_is_filter() {
        let q = PrefixQuery::text(Field(0), "test");
        assert!(q.is_filter());
    }
}