Skip to main content

hermes_core/query/
term.rs

1//! Term query - matches documents containing a specific term
2
3use std::sync::Arc;
4
5use crate::dsl::Field;
6use crate::segment::SegmentReader;
7use crate::structures::BlockPostingList;
8use crate::structures::TERMINATED;
9use crate::{DocId, Score};
10
11use super::{CountFuture, EmptyScorer, GlobalStats, Query, Scorer, ScorerFuture, TermQueryInfo};
12
13/// Term query - matches documents containing a specific term
14#[derive(Clone)]
15pub struct TermQuery {
16    pub field: Field,
17    pub term: Vec<u8>,
18    /// Optional global statistics for cross-segment IDF
19    global_stats: Option<Arc<GlobalStats>>,
20}
21
22impl std::fmt::Debug for TermQuery {
23    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24        f.debug_struct("TermQuery")
25            .field("field", &self.field)
26            .field("term", &String::from_utf8_lossy(&self.term))
27            .field("has_global_stats", &self.global_stats.is_some())
28            .finish()
29    }
30}
31
32impl std::fmt::Display for TermQuery {
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        write!(
35            f,
36            "Term({}:\"{}\")",
37            self.field.0,
38            String::from_utf8_lossy(&self.term)
39        )
40    }
41}
42
43impl TermQuery {
44    pub fn new(field: Field, term: impl Into<Vec<u8>>) -> Self {
45        Self {
46            field,
47            term: term.into(),
48            global_stats: None,
49        }
50    }
51
52    pub fn text(field: Field, text: &str) -> Self {
53        Self {
54            field,
55            term: text.to_lowercase().into_bytes(),
56            global_stats: None,
57        }
58    }
59
60    /// Create with global statistics for cross-segment IDF
61    pub fn with_global_stats(field: Field, text: &str, stats: Arc<GlobalStats>) -> Self {
62        Self {
63            field,
64            term: text.to_lowercase().into_bytes(),
65            global_stats: Some(stats),
66        }
67    }
68
69    /// Set global statistics for cross-segment IDF
70    pub fn set_global_stats(&mut self, stats: Arc<GlobalStats>) {
71        self.global_stats = Some(stats);
72    }
73}
74
75impl Query for TermQuery {
76    fn scorer<'a>(&self, reader: &'a SegmentReader, _limit: usize) -> ScorerFuture<'a> {
77        let field = self.field;
78        let term = self.term.clone();
79        let global_stats = self.global_stats.clone();
80        let is_indexed = reader
81            .schema()
82            .get_field_entry(field)
83            .is_none_or(|e| e.indexed);
84        Box::pin(async move {
85            // For non-indexed fields (fast-field-only), skip SSTable entirely
86            if !is_indexed {
87                let term_str = String::from_utf8_lossy(&term);
88                if let Some(scorer) = FastFieldTextScorer::try_new(reader, field, &term_str) {
89                    return Ok(Box::new(scorer) as Box<dyn Scorer + 'a>);
90                }
91                return Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>);
92            }
93
94            let postings = reader.get_postings(field, &term).await?;
95
96            match postings {
97                Some(posting_list) => {
98                    // Use global stats IDF if available, otherwise segment-local
99                    let (idf, avg_field_len) = if let Some(ref stats) = global_stats {
100                        let term_str = String::from_utf8_lossy(&term);
101                        let global_idf = stats.text_idf(field, &term_str);
102
103                        // If global stats has this term, use global IDF
104                        // Otherwise fall back to segment-local
105                        if global_idf > 0.0 {
106                            (global_idf, stats.avg_field_len(field))
107                        } else {
108                            // Fall back to segment-local IDF
109                            let num_docs = reader.num_docs() as f32;
110                            let doc_freq = posting_list.doc_count() as f32;
111                            let idf = super::bm25_idf(doc_freq, num_docs);
112                            (idf, reader.avg_field_len(field))
113                        }
114                    } else {
115                        // Compute IDF from segment statistics
116                        let num_docs = reader.num_docs() as f32;
117                        let doc_freq = posting_list.doc_count() as f32;
118                        let idf = super::bm25_idf(doc_freq, num_docs);
119                        (idf, reader.avg_field_len(field))
120                    };
121
122                    // Try to load positions if available
123                    let positions = reader.get_positions(field, &term).await.ok().flatten();
124
125                    let mut scorer = TermScorer::new(
126                        posting_list,
127                        idf,
128                        avg_field_len,
129                        1.0, // default field boost
130                    );
131
132                    if let Some(pos) = positions {
133                        scorer = scorer.with_positions(field.0, pos);
134                    }
135
136                    Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
137                }
138                None => {
139                    // Fall back to fast field scanning for fast-only text fields
140                    let term_str = String::from_utf8_lossy(&term);
141                    if let Some(scorer) = FastFieldTextScorer::try_new(reader, field, &term_str) {
142                        Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
143                    } else {
144                        Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>)
145                    }
146                }
147            }
148        })
149    }
150
151    fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
152        let field = self.field;
153        let term = self.term.clone();
154        Box::pin(async move {
155            match reader.get_postings(field, &term).await? {
156                Some(list) => Ok(list.doc_count()),
157                None => Ok(0),
158            }
159        })
160    }
161
162    #[cfg(feature = "sync")]
163    fn scorer_sync<'a>(
164        &self,
165        reader: &'a SegmentReader,
166        _limit: usize,
167    ) -> crate::Result<Box<dyn Scorer + 'a>> {
168        // For non-indexed fields (fast-field-only), skip SSTable entirely
169        let is_indexed = reader
170            .schema()
171            .get_field_entry(self.field)
172            .is_none_or(|e| e.indexed);
173        if !is_indexed {
174            let term_str = String::from_utf8_lossy(&self.term);
175            if let Some(scorer) = FastFieldTextScorer::try_new(reader, self.field, &term_str) {
176                return Ok(Box::new(scorer) as Box<dyn Scorer + 'a>);
177            }
178            return Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>);
179        }
180
181        let postings = reader.get_postings_sync(self.field, &self.term)?;
182
183        match postings {
184            Some(posting_list) => {
185                let (idf, avg_field_len) = if let Some(ref stats) = self.global_stats {
186                    let term_str = String::from_utf8_lossy(&self.term);
187                    let global_idf = stats.text_idf(self.field, &term_str);
188                    if global_idf > 0.0 {
189                        (global_idf, stats.avg_field_len(self.field))
190                    } else {
191                        let num_docs = reader.num_docs() as f32;
192                        let doc_freq = posting_list.doc_count() as f32;
193                        (
194                            super::bm25_idf(doc_freq, num_docs),
195                            reader.avg_field_len(self.field),
196                        )
197                    }
198                } else {
199                    let num_docs = reader.num_docs() as f32;
200                    let doc_freq = posting_list.doc_count() as f32;
201                    (
202                        super::bm25_idf(doc_freq, num_docs),
203                        reader.avg_field_len(self.field),
204                    )
205                };
206
207                let positions = reader
208                    .get_positions_sync(self.field, &self.term)
209                    .ok()
210                    .flatten();
211
212                let mut scorer = TermScorer::new(posting_list, idf, avg_field_len, 1.0);
213                if let Some(pos) = positions {
214                    scorer = scorer.with_positions(self.field.0, pos);
215                }
216
217                Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
218            }
219            None => {
220                let term_str = String::from_utf8_lossy(&self.term);
221                if let Some(scorer) = FastFieldTextScorer::try_new(reader, self.field, &term_str) {
222                    Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
223                } else {
224                    Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>)
225                }
226            }
227        }
228    }
229
230    fn as_doc_predicate<'a>(&self, reader: &'a SegmentReader) -> Option<super::DocPredicate<'a>> {
231        let fast_field = reader.fast_field(self.field.0)?;
232        let term_str = String::from_utf8_lossy(&self.term);
233        let target_ordinal = fast_field.text_ordinal(&term_str)?;
234        Some(Box::new(move |doc_id: DocId| -> bool {
235            fast_field.get_u64(doc_id) == target_ordinal
236        }))
237    }
238
239    fn as_term_query_info(&self) -> Option<TermQueryInfo> {
240        Some(TermQueryInfo {
241            field: self.field,
242            term: self.term.clone(),
243        })
244    }
245}
246
247struct TermScorer {
248    iterator: crate::structures::BlockPostingIterator<'static>,
249    idf: f32,
250    /// Average field length for this field
251    avg_field_len: f32,
252    /// Field boost/weight for BM25F
253    field_boost: f32,
254    /// Field ID for position reporting
255    field_id: u32,
256    /// Position posting list (if positions are enabled)
257    positions: Option<crate::structures::PositionPostingList>,
258}
259
260impl TermScorer {
261    pub fn new(
262        posting_list: BlockPostingList,
263        idf: f32,
264        avg_field_len: f32,
265        field_boost: f32,
266    ) -> Self {
267        Self {
268            iterator: posting_list.into_iterator(),
269            idf,
270            avg_field_len,
271            field_boost,
272            field_id: 0,
273            positions: None,
274        }
275    }
276
277    pub fn with_positions(
278        mut self,
279        field_id: u32,
280        positions: crate::structures::PositionPostingList,
281    ) -> Self {
282        self.field_id = field_id;
283        self.positions = Some(positions);
284        self
285    }
286}
287
288impl super::docset::DocSet for TermScorer {
289    fn doc(&self) -> DocId {
290        self.iterator.doc()
291    }
292
293    fn advance(&mut self) -> DocId {
294        self.iterator.advance()
295    }
296
297    fn seek(&mut self, target: DocId) -> DocId {
298        self.iterator.seek(target)
299    }
300
301    fn size_hint(&self) -> u32 {
302        0
303    }
304}
305
306// ── Fast field text equality scorer ──────────────────────────────────────
307
308/// Scorer that scans a text fast field for exact string equality.
309/// Used as fallback when a TermQuery targets a fast-only text field (no inverted index).
310/// Returns score 1.0 for matching docs (filter-style, like RangeScorer).
311struct FastFieldTextScorer<'a> {
312    fast_field: &'a crate::structures::fast_field::FastFieldReader,
313    target_ordinal: u64,
314    current: u32,
315    num_docs: u32,
316}
317
318impl<'a> FastFieldTextScorer<'a> {
319    fn try_new(reader: &'a SegmentReader, field: Field, text: &str) -> Option<Self> {
320        let fast_field = reader.fast_field(field.0)?;
321        let target_ordinal = fast_field.text_ordinal(text)?;
322        let num_docs = reader.num_docs();
323        let mut scorer = Self {
324            fast_field,
325            target_ordinal,
326            current: 0,
327            num_docs,
328        };
329        // Position on first matching doc
330        if num_docs > 0 && fast_field.get_u64(0) != target_ordinal {
331            scorer.scan_forward();
332        }
333        Some(scorer)
334    }
335
336    fn scan_forward(&mut self) {
337        loop {
338            self.current += 1;
339            if self.current >= self.num_docs {
340                self.current = self.num_docs;
341                return;
342            }
343            if self.fast_field.get_u64(self.current) == self.target_ordinal {
344                return;
345            }
346        }
347    }
348}
349
350impl super::docset::DocSet for FastFieldTextScorer<'_> {
351    fn doc(&self) -> DocId {
352        if self.current >= self.num_docs {
353            TERMINATED
354        } else {
355            self.current
356        }
357    }
358
359    fn advance(&mut self) -> DocId {
360        self.scan_forward();
361        self.doc()
362    }
363
364    fn seek(&mut self, target: DocId) -> DocId {
365        if target > self.current {
366            self.current = target;
367            if self.current < self.num_docs
368                && self.fast_field.get_u64(self.current) != self.target_ordinal
369            {
370                self.scan_forward();
371            }
372        }
373        self.doc()
374    }
375
376    fn size_hint(&self) -> u32 {
377        0
378    }
379}
380
381impl Scorer for FastFieldTextScorer<'_> {
382    fn score(&self) -> Score {
383        1.0
384    }
385}
386
387impl Scorer for TermScorer {
388    fn score(&self) -> Score {
389        let tf = self.iterator.term_freq() as f32;
390        // Note: Using tf as doc_len proxy since we don't store per-doc field lengths.
391        // This is a common approximation - longer docs tend to have higher TF.
392        super::bm25f_score(tf, self.idf, tf, self.avg_field_len, self.field_boost)
393    }
394
395    fn matched_positions(&self) -> Option<super::MatchedPositions> {
396        let positions = self.positions.as_ref()?;
397        let doc_id = self.iterator.doc();
398        let pos = positions.get_positions(doc_id)?;
399        let score = self.score();
400        // Each position contributes equally to the term score
401        let per_position_score = if pos.is_empty() {
402            0.0
403        } else {
404            score / pos.len() as f32
405        };
406        let scored_positions: Vec<super::ScoredPosition> = pos
407            .iter()
408            .map(|&p| super::ScoredPosition::new(p, per_position_score))
409            .collect();
410        Some(vec![(self.field_id, scored_positions)])
411    }
412}