Skip to main content

hermes_core/query/
term.rs

1//! Term query - matches documents containing a specific term
2
3use std::sync::Arc;
4
5use crate::dsl::Field;
6use crate::segment::SegmentReader;
7use crate::structures::BlockPostingList;
8use crate::structures::TERMINATED;
9use crate::{DocId, Score};
10
11use super::{CountFuture, EmptyScorer, GlobalStats, Query, Scorer, ScorerFuture, TermQueryInfo};
12
13/// Term query - matches documents containing a specific term
14#[derive(Clone)]
15pub struct TermQuery {
16    pub field: Field,
17    pub term: Vec<u8>,
18    /// Optional global statistics for cross-segment IDF
19    global_stats: Option<Arc<GlobalStats>>,
20}
21
22impl std::fmt::Debug for TermQuery {
23    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24        f.debug_struct("TermQuery")
25            .field("field", &self.field)
26            .field("term", &String::from_utf8_lossy(&self.term))
27            .field("has_global_stats", &self.global_stats.is_some())
28            .finish()
29    }
30}
31
32impl std::fmt::Display for TermQuery {
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        write!(
35            f,
36            "Term({}:\"{}\")",
37            self.field.0,
38            String::from_utf8_lossy(&self.term)
39        )
40    }
41}
42
43impl TermQuery {
44    pub fn new(field: Field, term: impl Into<Vec<u8>>) -> Self {
45        Self {
46            field,
47            term: term.into(),
48            global_stats: None,
49        }
50    }
51
52    pub fn text(field: Field, text: &str) -> Self {
53        Self {
54            field,
55            term: text.to_lowercase().into_bytes(),
56            global_stats: None,
57        }
58    }
59
60    /// Create with global statistics for cross-segment IDF
61    pub fn with_global_stats(field: Field, text: &str, stats: Arc<GlobalStats>) -> Self {
62        Self {
63            field,
64            term: text.to_lowercase().into_bytes(),
65            global_stats: Some(stats),
66        }
67    }
68
69    /// Set global statistics for cross-segment IDF
70    pub fn set_global_stats(&mut self, stats: Arc<GlobalStats>) {
71        self.global_stats = Some(stats);
72    }
73}
74
75impl Query for TermQuery {
76    fn scorer<'a>(&self, reader: &'a SegmentReader, _limit: usize) -> ScorerFuture<'a> {
77        let field = self.field;
78        let term = self.term.clone();
79        let global_stats = self.global_stats.clone();
80        Box::pin(async move {
81            let postings = reader.get_postings(field, &term).await?;
82
83            match postings {
84                Some(posting_list) => {
85                    // Use global stats IDF if available, otherwise segment-local
86                    let (idf, avg_field_len) = if let Some(ref stats) = global_stats {
87                        let term_str = String::from_utf8_lossy(&term);
88                        let global_idf = stats.text_idf(field, &term_str);
89
90                        // If global stats has this term, use global IDF
91                        // Otherwise fall back to segment-local
92                        if global_idf > 0.0 {
93                            (global_idf, stats.avg_field_len(field))
94                        } else {
95                            // Fall back to segment-local IDF
96                            let num_docs = reader.num_docs() as f32;
97                            let doc_freq = posting_list.doc_count() as f32;
98                            let idf = super::bm25_idf(doc_freq, num_docs);
99                            (idf, reader.avg_field_len(field))
100                        }
101                    } else {
102                        // Compute IDF from segment statistics
103                        let num_docs = reader.num_docs() as f32;
104                        let doc_freq = posting_list.doc_count() as f32;
105                        let idf = super::bm25_idf(doc_freq, num_docs);
106                        (idf, reader.avg_field_len(field))
107                    };
108
109                    // Try to load positions if available
110                    let positions = reader.get_positions(field, &term).await.ok().flatten();
111
112                    let mut scorer = TermScorer::new(
113                        posting_list,
114                        idf,
115                        avg_field_len,
116                        1.0, // default field boost
117                    );
118
119                    if let Some(pos) = positions {
120                        scorer = scorer.with_positions(field.0, pos);
121                    }
122
123                    Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
124                }
125                None => {
126                    // Fall back to fast field scanning for fast-only text fields
127                    let term_str = String::from_utf8_lossy(&term);
128                    if let Some(scorer) = FastFieldTextScorer::try_new(reader, field, &term_str) {
129                        Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
130                    } else {
131                        Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>)
132                    }
133                }
134            }
135        })
136    }
137
138    fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
139        let field = self.field;
140        let term = self.term.clone();
141        Box::pin(async move {
142            match reader.get_postings(field, &term).await? {
143                Some(list) => Ok(list.doc_count()),
144                None => Ok(0),
145            }
146        })
147    }
148
149    #[cfg(feature = "sync")]
150    fn scorer_sync<'a>(
151        &self,
152        reader: &'a SegmentReader,
153        _limit: usize,
154    ) -> crate::Result<Box<dyn Scorer + 'a>> {
155        let postings = reader.get_postings_sync(self.field, &self.term)?;
156
157        match postings {
158            Some(posting_list) => {
159                let (idf, avg_field_len) = if let Some(ref stats) = self.global_stats {
160                    let term_str = String::from_utf8_lossy(&self.term);
161                    let global_idf = stats.text_idf(self.field, &term_str);
162                    if global_idf > 0.0 {
163                        (global_idf, stats.avg_field_len(self.field))
164                    } else {
165                        let num_docs = reader.num_docs() as f32;
166                        let doc_freq = posting_list.doc_count() as f32;
167                        (
168                            super::bm25_idf(doc_freq, num_docs),
169                            reader.avg_field_len(self.field),
170                        )
171                    }
172                } else {
173                    let num_docs = reader.num_docs() as f32;
174                    let doc_freq = posting_list.doc_count() as f32;
175                    (
176                        super::bm25_idf(doc_freq, num_docs),
177                        reader.avg_field_len(self.field),
178                    )
179                };
180
181                let positions = reader
182                    .get_positions_sync(self.field, &self.term)
183                    .ok()
184                    .flatten();
185
186                let mut scorer = TermScorer::new(posting_list, idf, avg_field_len, 1.0);
187                if let Some(pos) = positions {
188                    scorer = scorer.with_positions(self.field.0, pos);
189                }
190
191                Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
192            }
193            None => {
194                let term_str = String::from_utf8_lossy(&self.term);
195                if let Some(scorer) = FastFieldTextScorer::try_new(reader, self.field, &term_str) {
196                    Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
197                } else {
198                    Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>)
199                }
200            }
201        }
202    }
203
204    fn as_term_query_info(&self) -> Option<TermQueryInfo> {
205        Some(TermQueryInfo {
206            field: self.field,
207            term: self.term.clone(),
208        })
209    }
210}
211
212struct TermScorer {
213    iterator: crate::structures::BlockPostingIterator<'static>,
214    idf: f32,
215    /// Average field length for this field
216    avg_field_len: f32,
217    /// Field boost/weight for BM25F
218    field_boost: f32,
219    /// Field ID for position reporting
220    field_id: u32,
221    /// Position posting list (if positions are enabled)
222    positions: Option<crate::structures::PositionPostingList>,
223}
224
225impl TermScorer {
226    pub fn new(
227        posting_list: BlockPostingList,
228        idf: f32,
229        avg_field_len: f32,
230        field_boost: f32,
231    ) -> Self {
232        Self {
233            iterator: posting_list.into_iterator(),
234            idf,
235            avg_field_len,
236            field_boost,
237            field_id: 0,
238            positions: None,
239        }
240    }
241
242    pub fn with_positions(
243        mut self,
244        field_id: u32,
245        positions: crate::structures::PositionPostingList,
246    ) -> Self {
247        self.field_id = field_id;
248        self.positions = Some(positions);
249        self
250    }
251}
252
253impl super::docset::DocSet for TermScorer {
254    fn doc(&self) -> DocId {
255        self.iterator.doc()
256    }
257
258    fn advance(&mut self) -> DocId {
259        self.iterator.advance()
260    }
261
262    fn seek(&mut self, target: DocId) -> DocId {
263        self.iterator.seek(target)
264    }
265
266    fn size_hint(&self) -> u32 {
267        0
268    }
269}
270
271// ── Fast field text equality scorer ──────────────────────────────────────
272
273/// Scorer that scans a text fast field for exact string equality.
274/// Used as fallback when a TermQuery targets a fast-only text field (no inverted index).
275/// Returns score 1.0 for matching docs (filter-style, like RangeScorer).
276struct FastFieldTextScorer<'a> {
277    fast_field: &'a crate::structures::fast_field::FastFieldReader,
278    target_ordinal: u64,
279    current: u32,
280    num_docs: u32,
281}
282
283impl<'a> FastFieldTextScorer<'a> {
284    fn try_new(reader: &'a SegmentReader, field: Field, text: &str) -> Option<Self> {
285        let fast_field = reader.fast_field(field.0)?;
286        let target_ordinal = fast_field.text_ordinal(text)?;
287        let num_docs = reader.num_docs();
288        let mut scorer = Self {
289            fast_field,
290            target_ordinal,
291            current: 0,
292            num_docs,
293        };
294        // Position on first matching doc
295        if num_docs > 0 && fast_field.get_u64(0) != target_ordinal {
296            scorer.scan_forward();
297        }
298        Some(scorer)
299    }
300
301    fn scan_forward(&mut self) {
302        loop {
303            self.current += 1;
304            if self.current >= self.num_docs {
305                self.current = self.num_docs;
306                return;
307            }
308            if self.fast_field.get_u64(self.current) == self.target_ordinal {
309                return;
310            }
311        }
312    }
313}
314
315impl super::docset::DocSet for FastFieldTextScorer<'_> {
316    fn doc(&self) -> DocId {
317        if self.current >= self.num_docs {
318            TERMINATED
319        } else {
320            self.current
321        }
322    }
323
324    fn advance(&mut self) -> DocId {
325        self.scan_forward();
326        self.doc()
327    }
328
329    fn seek(&mut self, target: DocId) -> DocId {
330        if target > self.current {
331            self.current = target;
332            if self.current < self.num_docs
333                && self.fast_field.get_u64(self.current) != self.target_ordinal
334            {
335                self.scan_forward();
336            }
337        }
338        self.doc()
339    }
340
341    fn size_hint(&self) -> u32 {
342        0
343    }
344}
345
346impl Scorer for FastFieldTextScorer<'_> {
347    fn score(&self) -> Score {
348        1.0
349    }
350}
351
352impl Scorer for TermScorer {
353    fn score(&self) -> Score {
354        let tf = self.iterator.term_freq() as f32;
355        // Note: Using tf as doc_len proxy since we don't store per-doc field lengths.
356        // This is a common approximation - longer docs tend to have higher TF.
357        super::bm25f_score(tf, self.idf, tf, self.avg_field_len, self.field_boost)
358    }
359
360    fn matched_positions(&self) -> Option<super::MatchedPositions> {
361        let positions = self.positions.as_ref()?;
362        let doc_id = self.iterator.doc();
363        let pos = positions.get_positions(doc_id)?;
364        let score = self.score();
365        // Each position contributes equally to the term score
366        let per_position_score = if pos.is_empty() {
367            0.0
368        } else {
369            score / pos.len() as f32
370        };
371        let scored_positions: Vec<super::ScoredPosition> = pos
372            .iter()
373            .map(|&p| super::ScoredPosition::new(p, per_position_score))
374            .collect();
375        Some(vec![(self.field_id, scored_positions)])
376    }
377}