Skip to main content

hermes_core/query/
term.rs

1//! Term query - matches documents containing a specific term
2
3use std::sync::Arc;
4
5use crate::dsl::Field;
6use crate::segment::SegmentReader;
7use crate::structures::BlockPostingList;
8use crate::structures::TERMINATED;
9use crate::{DocId, Score};
10
11use super::{CountFuture, EmptyScorer, GlobalStats, Query, Scorer, ScorerFuture, TermQueryInfo};
12
13/// Term query - matches documents containing a specific term
14#[derive(Clone)]
15pub struct TermQuery {
16    pub field: Field,
17    pub term: Vec<u8>,
18    /// Optional global statistics for cross-segment IDF
19    global_stats: Option<Arc<GlobalStats>>,
20}
21
22impl std::fmt::Debug for TermQuery {
23    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24        f.debug_struct("TermQuery")
25            .field("field", &self.field)
26            .field("term", &String::from_utf8_lossy(&self.term))
27            .field("has_global_stats", &self.global_stats.is_some())
28            .finish()
29    }
30}
31
32impl std::fmt::Display for TermQuery {
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        write!(
35            f,
36            "Term({}:\"{}\")",
37            self.field.0,
38            String::from_utf8_lossy(&self.term)
39        )
40    }
41}
42
43impl TermQuery {
44    pub fn new(field: Field, term: impl Into<Vec<u8>>) -> Self {
45        Self {
46            field,
47            term: term.into(),
48            global_stats: None,
49        }
50    }
51
52    pub fn text(field: Field, text: &str) -> Self {
53        Self {
54            field,
55            term: text.to_lowercase().into_bytes(),
56            global_stats: None,
57        }
58    }
59
60    /// Create with global statistics for cross-segment IDF
61    pub fn with_global_stats(field: Field, text: &str, stats: Arc<GlobalStats>) -> Self {
62        Self {
63            field,
64            term: text.to_lowercase().into_bytes(),
65            global_stats: Some(stats),
66        }
67    }
68
69    /// Set global statistics for cross-segment IDF
70    pub fn set_global_stats(&mut self, stats: Arc<GlobalStats>) {
71        self.global_stats = Some(stats);
72    }
73}
74
75impl Query for TermQuery {
76    fn scorer<'a>(&self, reader: &'a SegmentReader, _limit: usize) -> ScorerFuture<'a> {
77        let field = self.field;
78        let term = self.term.clone();
79        let global_stats = self.global_stats.clone();
80        Box::pin(async move {
81            let postings = reader.get_postings(field, &term).await?;
82
83            match postings {
84                Some(posting_list) => {
85                    // Use global stats IDF if available, otherwise segment-local
86                    let (idf, avg_field_len) = if let Some(ref stats) = global_stats {
87                        let term_str = String::from_utf8_lossy(&term);
88                        let global_idf = stats.text_idf(field, &term_str);
89
90                        // If global stats has this term, use global IDF
91                        // Otherwise fall back to segment-local
92                        if global_idf > 0.0 {
93                            (global_idf, stats.avg_field_len(field))
94                        } else {
95                            // Fall back to segment-local IDF
96                            let num_docs = reader.num_docs() as f32;
97                            let doc_freq = posting_list.doc_count() as f32;
98                            let idf = super::bm25_idf(doc_freq, num_docs);
99                            (idf, reader.avg_field_len(field))
100                        }
101                    } else {
102                        // Compute IDF from segment statistics
103                        let num_docs = reader.num_docs() as f32;
104                        let doc_freq = posting_list.doc_count() as f32;
105                        let idf = super::bm25_idf(doc_freq, num_docs);
106                        (idf, reader.avg_field_len(field))
107                    };
108
109                    // Try to load positions if available
110                    let positions = reader.get_positions(field, &term).await.ok().flatten();
111
112                    let mut scorer = TermScorer::new(
113                        posting_list,
114                        idf,
115                        avg_field_len,
116                        1.0, // default field boost
117                    );
118
119                    if let Some(pos) = positions {
120                        scorer = scorer.with_positions(field.0, pos);
121                    }
122
123                    Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
124                }
125                None => {
126                    // Fall back to fast field scanning for fast-only text fields
127                    let term_str = String::from_utf8_lossy(&term);
128                    if let Some(scorer) = FastFieldTextScorer::try_new(reader, field, &term_str) {
129                        Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
130                    } else {
131                        Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>)
132                    }
133                }
134            }
135        })
136    }
137
138    fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
139        let field = self.field;
140        let term = self.term.clone();
141        Box::pin(async move {
142            match reader.get_postings(field, &term).await? {
143                Some(list) => Ok(list.doc_count()),
144                None => Ok(0),
145            }
146        })
147    }
148
149    #[cfg(feature = "sync")]
150    fn scorer_sync<'a>(
151        &self,
152        reader: &'a SegmentReader,
153        _limit: usize,
154    ) -> crate::Result<Box<dyn Scorer + 'a>> {
155        let postings = reader.get_postings_sync(self.field, &self.term)?;
156
157        match postings {
158            Some(posting_list) => {
159                let (idf, avg_field_len) = if let Some(ref stats) = self.global_stats {
160                    let term_str = String::from_utf8_lossy(&self.term);
161                    let global_idf = stats.text_idf(self.field, &term_str);
162                    if global_idf > 0.0 {
163                        (global_idf, stats.avg_field_len(self.field))
164                    } else {
165                        let num_docs = reader.num_docs() as f32;
166                        let doc_freq = posting_list.doc_count() as f32;
167                        (
168                            super::bm25_idf(doc_freq, num_docs),
169                            reader.avg_field_len(self.field),
170                        )
171                    }
172                } else {
173                    let num_docs = reader.num_docs() as f32;
174                    let doc_freq = posting_list.doc_count() as f32;
175                    (
176                        super::bm25_idf(doc_freq, num_docs),
177                        reader.avg_field_len(self.field),
178                    )
179                };
180
181                let positions = reader
182                    .get_positions_sync(self.field, &self.term)
183                    .ok()
184                    .flatten();
185
186                let mut scorer = TermScorer::new(posting_list, idf, avg_field_len, 1.0);
187                if let Some(pos) = positions {
188                    scorer = scorer.with_positions(self.field.0, pos);
189                }
190
191                Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
192            }
193            None => {
194                let term_str = String::from_utf8_lossy(&self.term);
195                if let Some(scorer) = FastFieldTextScorer::try_new(reader, self.field, &term_str) {
196                    Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
197                } else {
198                    Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>)
199                }
200            }
201        }
202    }
203
204    fn as_doc_predicate<'a>(&self, reader: &'a SegmentReader) -> Option<super::DocPredicate<'a>> {
205        let fast_field = reader.fast_field(self.field.0)?;
206        let term_str = String::from_utf8_lossy(&self.term);
207        let target_ordinal = fast_field.text_ordinal(&term_str)?;
208        Some(Box::new(move |doc_id: DocId| -> bool {
209            fast_field.get_u64(doc_id) == target_ordinal
210        }))
211    }
212
213    fn as_term_query_info(&self) -> Option<TermQueryInfo> {
214        Some(TermQueryInfo {
215            field: self.field,
216            term: self.term.clone(),
217        })
218    }
219}
220
221struct TermScorer {
222    iterator: crate::structures::BlockPostingIterator<'static>,
223    idf: f32,
224    /// Average field length for this field
225    avg_field_len: f32,
226    /// Field boost/weight for BM25F
227    field_boost: f32,
228    /// Field ID for position reporting
229    field_id: u32,
230    /// Position posting list (if positions are enabled)
231    positions: Option<crate::structures::PositionPostingList>,
232}
233
234impl TermScorer {
235    pub fn new(
236        posting_list: BlockPostingList,
237        idf: f32,
238        avg_field_len: f32,
239        field_boost: f32,
240    ) -> Self {
241        Self {
242            iterator: posting_list.into_iterator(),
243            idf,
244            avg_field_len,
245            field_boost,
246            field_id: 0,
247            positions: None,
248        }
249    }
250
251    pub fn with_positions(
252        mut self,
253        field_id: u32,
254        positions: crate::structures::PositionPostingList,
255    ) -> Self {
256        self.field_id = field_id;
257        self.positions = Some(positions);
258        self
259    }
260}
261
262impl super::docset::DocSet for TermScorer {
263    fn doc(&self) -> DocId {
264        self.iterator.doc()
265    }
266
267    fn advance(&mut self) -> DocId {
268        self.iterator.advance()
269    }
270
271    fn seek(&mut self, target: DocId) -> DocId {
272        self.iterator.seek(target)
273    }
274
275    fn size_hint(&self) -> u32 {
276        0
277    }
278}
279
280// ── Fast field text equality scorer ──────────────────────────────────────
281
282/// Scorer that scans a text fast field for exact string equality.
283/// Used as fallback when a TermQuery targets a fast-only text field (no inverted index).
284/// Returns score 1.0 for matching docs (filter-style, like RangeScorer).
285struct FastFieldTextScorer<'a> {
286    fast_field: &'a crate::structures::fast_field::FastFieldReader,
287    target_ordinal: u64,
288    current: u32,
289    num_docs: u32,
290}
291
292impl<'a> FastFieldTextScorer<'a> {
293    fn try_new(reader: &'a SegmentReader, field: Field, text: &str) -> Option<Self> {
294        let fast_field = reader.fast_field(field.0)?;
295        let target_ordinal = fast_field.text_ordinal(text)?;
296        let num_docs = reader.num_docs();
297        let mut scorer = Self {
298            fast_field,
299            target_ordinal,
300            current: 0,
301            num_docs,
302        };
303        // Position on first matching doc
304        if num_docs > 0 && fast_field.get_u64(0) != target_ordinal {
305            scorer.scan_forward();
306        }
307        Some(scorer)
308    }
309
310    fn scan_forward(&mut self) {
311        loop {
312            self.current += 1;
313            if self.current >= self.num_docs {
314                self.current = self.num_docs;
315                return;
316            }
317            if self.fast_field.get_u64(self.current) == self.target_ordinal {
318                return;
319            }
320        }
321    }
322}
323
324impl super::docset::DocSet for FastFieldTextScorer<'_> {
325    fn doc(&self) -> DocId {
326        if self.current >= self.num_docs {
327            TERMINATED
328        } else {
329            self.current
330        }
331    }
332
333    fn advance(&mut self) -> DocId {
334        self.scan_forward();
335        self.doc()
336    }
337
338    fn seek(&mut self, target: DocId) -> DocId {
339        if target > self.current {
340            self.current = target;
341            if self.current < self.num_docs
342                && self.fast_field.get_u64(self.current) != self.target_ordinal
343            {
344                self.scan_forward();
345            }
346        }
347        self.doc()
348    }
349
350    fn size_hint(&self) -> u32 {
351        0
352    }
353}
354
355impl Scorer for FastFieldTextScorer<'_> {
356    fn score(&self) -> Score {
357        1.0
358    }
359}
360
361impl Scorer for TermScorer {
362    fn score(&self) -> Score {
363        let tf = self.iterator.term_freq() as f32;
364        // Note: Using tf as doc_len proxy since we don't store per-doc field lengths.
365        // This is a common approximation - longer docs tend to have higher TF.
366        super::bm25f_score(tf, self.idf, tf, self.avg_field_len, self.field_boost)
367    }
368
369    fn matched_positions(&self) -> Option<super::MatchedPositions> {
370        let positions = self.positions.as_ref()?;
371        let doc_id = self.iterator.doc();
372        let pos = positions.get_positions(doc_id)?;
373        let score = self.score();
374        // Each position contributes equally to the term score
375        let per_position_score = if pos.is_empty() {
376            0.0
377        } else {
378            score / pos.len() as f32
379        };
380        let scored_positions: Vec<super::ScoredPosition> = pos
381            .iter()
382            .map(|&p| super::ScoredPosition::new(p, per_position_score))
383            .collect();
384        Some(vec![(self.field_id, scored_positions)])
385    }
386}