Skip to main content

hermes_core/query/
term.rs

1//! Term query - matches documents containing a specific term
2
3use std::sync::Arc;
4
5use crate::dsl::Field;
6use crate::segment::SegmentReader;
7use crate::structures::BlockPostingList;
8use crate::{DocId, Score};
9
10use super::{CountFuture, EmptyScorer, GlobalStats, Query, Scorer, ScorerFuture, TermQueryInfo};
11
12/// Term query - matches documents containing a specific term
13#[derive(Clone)]
14pub struct TermQuery {
15    pub field: Field,
16    pub term: Vec<u8>,
17    /// Optional global statistics for cross-segment IDF
18    global_stats: Option<Arc<GlobalStats>>,
19}
20
21impl std::fmt::Debug for TermQuery {
22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23        f.debug_struct("TermQuery")
24            .field("field", &self.field)
25            .field("term", &String::from_utf8_lossy(&self.term))
26            .field("has_global_stats", &self.global_stats.is_some())
27            .finish()
28    }
29}
30
31impl std::fmt::Display for TermQuery {
32    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33        write!(
34            f,
35            "Term({}:\"{}\")",
36            self.field.0,
37            String::from_utf8_lossy(&self.term)
38        )
39    }
40}
41
42impl TermQuery {
43    pub fn new(field: Field, term: impl Into<Vec<u8>>) -> Self {
44        Self {
45            field,
46            term: term.into(),
47            global_stats: None,
48        }
49    }
50
51    pub fn text(field: Field, text: &str) -> Self {
52        Self {
53            field,
54            term: text.to_lowercase().into_bytes(),
55            global_stats: None,
56        }
57    }
58
59    /// Create with global statistics for cross-segment IDF
60    pub fn with_global_stats(field: Field, text: &str, stats: Arc<GlobalStats>) -> Self {
61        Self {
62            field,
63            term: text.to_lowercase().into_bytes(),
64            global_stats: Some(stats),
65        }
66    }
67
68    /// Set global statistics for cross-segment IDF
69    pub fn set_global_stats(&mut self, stats: Arc<GlobalStats>) {
70        self.global_stats = Some(stats);
71    }
72}
73
74impl Query for TermQuery {
75    fn scorer<'a>(&self, reader: &'a SegmentReader, _limit: usize) -> ScorerFuture<'a> {
76        let field = self.field;
77        let term = self.term.clone();
78        let global_stats = self.global_stats.clone();
79        Box::pin(async move {
80            let postings = reader.get_postings(field, &term).await?;
81
82            match postings {
83                Some(posting_list) => {
84                    // Use global stats IDF if available, otherwise segment-local
85                    let (idf, avg_field_len) = if let Some(ref stats) = global_stats {
86                        let term_str = String::from_utf8_lossy(&term);
87                        let global_idf = stats.text_idf(field, &term_str);
88
89                        // If global stats has this term, use global IDF
90                        // Otherwise fall back to segment-local
91                        if global_idf > 0.0 {
92                            (global_idf, stats.avg_field_len(field))
93                        } else {
94                            // Fall back to segment-local IDF
95                            let num_docs = reader.num_docs() as f32;
96                            let doc_freq = posting_list.doc_count() as f32;
97                            let idf = super::bm25_idf(doc_freq, num_docs);
98                            (idf, reader.avg_field_len(field))
99                        }
100                    } else {
101                        // Compute IDF from segment statistics
102                        let num_docs = reader.num_docs() as f32;
103                        let doc_freq = posting_list.doc_count() as f32;
104                        let idf = super::bm25_idf(doc_freq, num_docs);
105                        (idf, reader.avg_field_len(field))
106                    };
107
108                    // Try to load positions if available
109                    let positions = reader.get_positions(field, &term).await.ok().flatten();
110
111                    let mut scorer = TermScorer::new(
112                        posting_list,
113                        idf,
114                        avg_field_len,
115                        1.0, // default field boost
116                    );
117
118                    if let Some(pos) = positions {
119                        scorer = scorer.with_positions(field.0, pos);
120                    }
121
122                    Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
123                }
124                None => Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>),
125            }
126        })
127    }
128
129    fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
130        let field = self.field;
131        let term = self.term.clone();
132        Box::pin(async move {
133            match reader.get_postings(field, &term).await? {
134                Some(list) => Ok(list.doc_count()),
135                None => Ok(0),
136            }
137        })
138    }
139
140    #[cfg(feature = "sync")]
141    fn scorer_sync<'a>(
142        &self,
143        reader: &'a SegmentReader,
144        _limit: usize,
145    ) -> crate::Result<Box<dyn Scorer + 'a>> {
146        let postings = reader.get_postings_sync(self.field, &self.term)?;
147
148        match postings {
149            Some(posting_list) => {
150                let (idf, avg_field_len) = if let Some(ref stats) = self.global_stats {
151                    let term_str = String::from_utf8_lossy(&self.term);
152                    let global_idf = stats.text_idf(self.field, &term_str);
153                    if global_idf > 0.0 {
154                        (global_idf, stats.avg_field_len(self.field))
155                    } else {
156                        let num_docs = reader.num_docs() as f32;
157                        let doc_freq = posting_list.doc_count() as f32;
158                        (
159                            super::bm25_idf(doc_freq, num_docs),
160                            reader.avg_field_len(self.field),
161                        )
162                    }
163                } else {
164                    let num_docs = reader.num_docs() as f32;
165                    let doc_freq = posting_list.doc_count() as f32;
166                    (
167                        super::bm25_idf(doc_freq, num_docs),
168                        reader.avg_field_len(self.field),
169                    )
170                };
171
172                let positions = reader
173                    .get_positions_sync(self.field, &self.term)
174                    .ok()
175                    .flatten();
176
177                let mut scorer = TermScorer::new(posting_list, idf, avg_field_len, 1.0);
178                if let Some(pos) = positions {
179                    scorer = scorer.with_positions(self.field.0, pos);
180                }
181
182                Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
183            }
184            None => Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>),
185        }
186    }
187
188    fn as_term_query_info(&self) -> Option<TermQueryInfo> {
189        Some(TermQueryInfo {
190            field: self.field,
191            term: self.term.clone(),
192        })
193    }
194}
195
196struct TermScorer {
197    iterator: crate::structures::BlockPostingIterator<'static>,
198    idf: f32,
199    /// Average field length for this field
200    avg_field_len: f32,
201    /// Field boost/weight for BM25F
202    field_boost: f32,
203    /// Field ID for position reporting
204    field_id: u32,
205    /// Position posting list (if positions are enabled)
206    positions: Option<crate::structures::PositionPostingList>,
207}
208
209impl TermScorer {
210    pub fn new(
211        posting_list: BlockPostingList,
212        idf: f32,
213        avg_field_len: f32,
214        field_boost: f32,
215    ) -> Self {
216        Self {
217            iterator: posting_list.into_iterator(),
218            idf,
219            avg_field_len,
220            field_boost,
221            field_id: 0,
222            positions: None,
223        }
224    }
225
226    pub fn with_positions(
227        mut self,
228        field_id: u32,
229        positions: crate::structures::PositionPostingList,
230    ) -> Self {
231        self.field_id = field_id;
232        self.positions = Some(positions);
233        self
234    }
235}
236
237impl super::docset::DocSet for TermScorer {
238    fn doc(&self) -> DocId {
239        self.iterator.doc()
240    }
241
242    fn advance(&mut self) -> DocId {
243        self.iterator.advance()
244    }
245
246    fn seek(&mut self, target: DocId) -> DocId {
247        self.iterator.seek(target)
248    }
249
250    fn size_hint(&self) -> u32 {
251        0
252    }
253}
254
255impl Scorer for TermScorer {
256    fn score(&self) -> Score {
257        let tf = self.iterator.term_freq() as f32;
258        // Note: Using tf as doc_len proxy since we don't store per-doc field lengths.
259        // This is a common approximation - longer docs tend to have higher TF.
260        super::bm25f_score(tf, self.idf, tf, self.avg_field_len, self.field_boost)
261    }
262
263    fn matched_positions(&self) -> Option<super::MatchedPositions> {
264        let positions = self.positions.as_ref()?;
265        let doc_id = self.iterator.doc();
266        let pos = positions.get_positions(doc_id)?;
267        let score = self.score();
268        // Each position contributes equally to the term score
269        let per_position_score = if pos.is_empty() {
270            0.0
271        } else {
272            score / pos.len() as f32
273        };
274        let scored_positions: Vec<super::ScoredPosition> = pos
275            .iter()
276            .map(|&p| super::ScoredPosition::new(p, per_position_score))
277            .collect();
278        Some(vec![(self.field_id, scored_positions)])
279    }
280}