Skip to main content

hermes_core/query/
term.rs

1//! Term query - matches documents containing a specific term
2
3use std::sync::Arc;
4
5use crate::dsl::Field;
6use crate::segment::SegmentReader;
7use crate::structures::BlockPostingList;
8use crate::structures::TERMINATED;
9use crate::{DocId, Score};
10
11use super::{CountFuture, EmptyScorer, GlobalStats, Query, Scorer, ScorerFuture, TermQueryInfo};
12
13/// Term query - matches documents containing a specific term
14#[derive(Clone)]
15pub struct TermQuery {
16    pub field: Field,
17    pub term: Vec<u8>,
18    /// Optional global statistics for cross-segment IDF
19    global_stats: Option<Arc<GlobalStats>>,
20}
21
22impl std::fmt::Debug for TermQuery {
23    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24        f.debug_struct("TermQuery")
25            .field("field", &self.field)
26            .field("term", &String::from_utf8_lossy(&self.term))
27            .field("has_global_stats", &self.global_stats.is_some())
28            .finish()
29    }
30}
31
32impl std::fmt::Display for TermQuery {
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        write!(
35            f,
36            "Term({}:\"{}\")",
37            self.field.0,
38            String::from_utf8_lossy(&self.term)
39        )
40    }
41}
42
43impl TermQuery {
44    pub fn new(field: Field, term: impl Into<Vec<u8>>) -> Self {
45        Self {
46            field,
47            term: term.into(),
48            global_stats: None,
49        }
50    }
51
52    pub fn text(field: Field, text: &str) -> Self {
53        Self {
54            field,
55            term: text.to_lowercase().into_bytes(),
56            global_stats: None,
57        }
58    }
59
60    /// Create with global statistics for cross-segment IDF
61    pub fn with_global_stats(field: Field, text: &str, stats: Arc<GlobalStats>) -> Self {
62        Self {
63            field,
64            term: text.to_lowercase().into_bytes(),
65            global_stats: Some(stats),
66        }
67    }
68
69    /// Set global statistics for cross-segment IDF
70    pub fn set_global_stats(&mut self, stats: Arc<GlobalStats>) {
71        self.global_stats = Some(stats);
72    }
73}
74
75/// Compute (idf, avg_field_len) from a posting list, using global stats when available.
76fn compute_term_idf(
77    posting_list: &BlockPostingList,
78    field: Field,
79    reader: &SegmentReader,
80    global_stats: Option<&Arc<GlobalStats>>,
81    term: &[u8],
82) -> (f32, f32) {
83    if let Some(stats) = global_stats {
84        let term_str = String::from_utf8_lossy(term);
85        let global_idf = stats.text_idf(field, &term_str);
86        if global_idf > 0.0 {
87            return (global_idf, stats.avg_field_len(field));
88        }
89    }
90    let num_docs = reader.num_docs() as f32;
91    let doc_freq = posting_list.doc_count() as f32;
92    (
93        super::bm25_idf(doc_freq, num_docs),
94        reader.avg_field_len(field),
95    )
96}
97
// ── Unified term scorer macro ────────────────────────────────────────────
//
// Expands to the body of a scorer constructor shared by the async and sync
// entry points. Parameterised on:
//   $get_postings_fn  – get_postings | get_postings_sync
//   $get_positions_fn – get_positions | get_positions_sync
//   $($aw)*           – .await  (present for async, absent for sync)
//
// The expansion evaluates to `Ok(Box<dyn Scorer>)`:
//   * indexed field with postings for the term → `TermScorer`, plus positions
//     when they can be loaded (position lookup errors are ignored);
//   * otherwise (non-indexed field, or no postings) → `FastFieldTextScorer` when
//     the field's fast field knows the term, else `EmptyScorer`.
// Errors from the postings lookup are propagated with `?`.
macro_rules! term_plan {
    ($field:expr, $term:expr, $global_stats:expr, $reader:expr,
     $get_postings_fn:ident, $get_positions_fn:ident
     $(, $aw:tt)*) => {{
        let field: Field = $field;
        let term: &[u8] = $term;
        let global_stats: Option<&Arc<GlobalStats>> = $global_stats;
        let reader: &SegmentReader = $reader;

        // Fields without a schema entry are treated as indexed.
        let indexed = reader
            .schema()
            .get_field_entry(field)
            .is_none_or(|entry| entry.indexed);

        // Fast-field-only fields skip the postings lookup entirely and take the
        // fast-field fallback below.
        let found_postings = if indexed {
            reader.$get_postings_fn(field, term) $(. $aw)* ?
        } else {
            None
        };

        if let Some(posting_list) = found_postings {
            let (idf, avg_field_len) =
                compute_term_idf(&posting_list, field, reader, global_stats, term);

            // Positions are optional: a failed lookup simply means "no positions".
            let positions = reader.$get_positions_fn(field, term)
                $(. $aw)* .ok().flatten();

            let base = TermScorer::new(posting_list, idf, avg_field_len, 1.0);
            let scorer = match positions {
                Some(pos) => base.with_positions(field.0, pos),
                None => base,
            };
            Ok(Box::new(scorer) as Box<dyn Scorer + '_>)
        } else {
            let term_str = String::from_utf8_lossy(term);
            match FastFieldTextScorer::try_new(reader, field, &term_str) {
                Some(scorer) => Ok(Box::new(scorer) as Box<dyn Scorer + '_>),
                None => Ok(Box::new(EmptyScorer) as Box<dyn Scorer + '_>),
            }
        }
    }};
}
150
151impl Query for TermQuery {
152    fn scorer<'a>(&self, reader: &'a SegmentReader, _limit: usize) -> ScorerFuture<'a> {
153        let field = self.field;
154        let term = self.term.clone();
155        let global_stats = self.global_stats.clone();
156        Box::pin(async move {
157            term_plan!(
158                field,
159                &term,
160                global_stats.as_ref(),
161                reader,
162                get_postings,
163                get_positions,
164                await
165            )
166        })
167    }
168
169    fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
170        let field = self.field;
171        let term = self.term.clone();
172        Box::pin(async move {
173            match reader.get_postings(field, &term).await? {
174                Some(list) => Ok(list.doc_count()),
175                None => Ok(0),
176            }
177        })
178    }
179
180    #[cfg(feature = "sync")]
181    fn scorer_sync<'a>(
182        &self,
183        reader: &'a SegmentReader,
184        _limit: usize,
185    ) -> crate::Result<Box<dyn Scorer + 'a>> {
186        term_plan!(
187            self.field,
188            &self.term,
189            self.global_stats.as_ref(),
190            reader,
191            get_postings_sync,
192            get_positions_sync
193        )
194    }
195
196    fn as_doc_predicate<'a>(&self, reader: &'a SegmentReader) -> Option<super::DocPredicate<'a>> {
197        let fast_field = reader.fast_field(self.field.0)?;
198        let term_str = String::from_utf8_lossy(&self.term);
199        match fast_field.text_ordinal(&term_str) {
200            Some(target_ordinal) => Some(Box::new(move |doc_id: DocId| -> bool {
201                fast_field.get_u64(doc_id) == target_ordinal
202            })),
203            // Term doesn't exist in this segment — no doc can match.
204            None => Some(Box::new(|_| false)),
205        }
206    }
207
208    fn decompose(&self) -> super::QueryDecomposition {
209        super::QueryDecomposition::TextTerm(TermQueryInfo {
210            field: self.field,
211            term: self.term.clone(),
212        })
213    }
214}
215
216struct TermScorer {
217    iterator: crate::structures::BlockPostingIterator<'static>,
218    idf: f32,
219    /// Average field length for this field
220    avg_field_len: f32,
221    /// Field boost/weight for BM25F
222    field_boost: f32,
223    /// Field ID for position reporting
224    field_id: u32,
225    /// Position posting list (if positions are enabled)
226    positions: Option<crate::structures::PositionPostingList>,
227}
228
229impl TermScorer {
230    pub fn new(
231        posting_list: BlockPostingList,
232        idf: f32,
233        avg_field_len: f32,
234        field_boost: f32,
235    ) -> Self {
236        Self {
237            iterator: posting_list.into_iterator(),
238            idf,
239            avg_field_len,
240            field_boost,
241            field_id: 0,
242            positions: None,
243        }
244    }
245
246    pub fn with_positions(
247        mut self,
248        field_id: u32,
249        positions: crate::structures::PositionPostingList,
250    ) -> Self {
251        self.field_id = field_id;
252        self.positions = Some(positions);
253        self
254    }
255}
256
257impl super::docset::DocSet for TermScorer {
258    fn doc(&self) -> DocId {
259        self.iterator.doc()
260    }
261
262    fn advance(&mut self) -> DocId {
263        self.iterator.advance()
264    }
265
266    fn seek(&mut self, target: DocId) -> DocId {
267        self.iterator.seek(target)
268    }
269
270    fn size_hint(&self) -> u32 {
271        0
272    }
273}
274
275// ── Fast field text equality scorer ──────────────────────────────────────
276
277/// Scorer that scans a text fast field for exact string equality.
278/// Used as fallback when a TermQuery targets a fast-only text field (no inverted index).
279/// Returns score 1.0 for matching docs (filter-style, like RangeScorer).
280struct FastFieldTextScorer<'a> {
281    fast_field: &'a crate::structures::fast_field::FastFieldReader,
282    target_ordinal: u64,
283    current: u32,
284    num_docs: u32,
285}
286
287impl<'a> FastFieldTextScorer<'a> {
288    fn try_new(reader: &'a SegmentReader, field: Field, text: &str) -> Option<Self> {
289        let fast_field = reader.fast_field(field.0)?;
290        let target_ordinal = fast_field.text_ordinal(text)?;
291        let num_docs = reader.num_docs();
292        let mut scorer = Self {
293            fast_field,
294            target_ordinal,
295            current: 0,
296            num_docs,
297        };
298        // Position on first matching doc
299        if num_docs > 0 && fast_field.get_u64(0) != target_ordinal {
300            scorer.scan_forward();
301        }
302        Some(scorer)
303    }
304
305    fn scan_forward(&mut self) {
306        loop {
307            self.current += 1;
308            if self.current >= self.num_docs {
309                self.current = self.num_docs;
310                return;
311            }
312            if self.fast_field.get_u64(self.current) == self.target_ordinal {
313                return;
314            }
315        }
316    }
317}
318
319impl super::docset::DocSet for FastFieldTextScorer<'_> {
320    fn doc(&self) -> DocId {
321        if self.current >= self.num_docs {
322            TERMINATED
323        } else {
324            self.current
325        }
326    }
327
328    fn advance(&mut self) -> DocId {
329        self.scan_forward();
330        self.doc()
331    }
332
333    fn seek(&mut self, target: DocId) -> DocId {
334        if target > self.current {
335            self.current = target;
336            if self.current < self.num_docs
337                && self.fast_field.get_u64(self.current) != self.target_ordinal
338            {
339                self.scan_forward();
340            }
341        }
342        self.doc()
343    }
344
345    fn size_hint(&self) -> u32 {
346        0
347    }
348}
349
350impl Scorer for FastFieldTextScorer<'_> {
351    fn score(&self) -> Score {
352        1.0
353    }
354}
355
356impl Scorer for TermScorer {
357    fn score(&self) -> Score {
358        let tf = self.iterator.term_freq() as f32;
359        // Note: Using tf as doc_len proxy since we don't store per-doc field lengths.
360        // This is a common approximation - longer docs tend to have higher TF.
361        super::bm25f_score(tf, self.idf, tf, self.avg_field_len, self.field_boost)
362    }
363
364    fn matched_positions(&self) -> Option<super::MatchedPositions> {
365        let positions = self.positions.as_ref()?;
366        let doc_id = self.iterator.doc();
367        let pos = positions.get_positions(doc_id)?;
368        let score = self.score();
369        // Each position contributes equally to the term score
370        let per_position_score = if pos.is_empty() {
371            0.0
372        } else {
373            score / pos.len() as f32
374        };
375        let scored_positions: Vec<super::ScoredPosition> = pos
376            .iter()
377            .map(|&p| super::ScoredPosition::new(p, per_position_score))
378            .collect();
379        Some(vec![(self.field_id, scored_positions)])
380    }
381}