Skip to main content

hermes_core/query/
term.rs

//! Term query - matches documents containing a specific term
2
3use std::sync::Arc;
4
5use crate::dsl::Field;
6use crate::segment::SegmentReader;
7use crate::structures::BlockPostingList;
8use crate::{DocId, Score};
9
10use super::{CountFuture, EmptyScorer, GlobalStats, Query, Scorer, ScorerFuture, TermQueryInfo};
11
12/// Term query - matches documents containing a specific term
13#[derive(Clone)]
14pub struct TermQuery {
15    pub field: Field,
16    pub term: Vec<u8>,
17    /// Optional global statistics for cross-segment IDF
18    global_stats: Option<Arc<GlobalStats>>,
19}
20
21impl std::fmt::Debug for TermQuery {
22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23        f.debug_struct("TermQuery")
24            .field("field", &self.field)
25            .field("term", &String::from_utf8_lossy(&self.term))
26            .field("has_global_stats", &self.global_stats.is_some())
27            .finish()
28    }
29}
30
31impl TermQuery {
32    pub fn new(field: Field, term: impl Into<Vec<u8>>) -> Self {
33        Self {
34            field,
35            term: term.into(),
36            global_stats: None,
37        }
38    }
39
40    pub fn text(field: Field, text: &str) -> Self {
41        Self {
42            field,
43            term: text.to_lowercase().into_bytes(),
44            global_stats: None,
45        }
46    }
47
48    /// Create with global statistics for cross-segment IDF
49    pub fn with_global_stats(field: Field, text: &str, stats: Arc<GlobalStats>) -> Self {
50        Self {
51            field,
52            term: text.to_lowercase().into_bytes(),
53            global_stats: Some(stats),
54        }
55    }
56
57    /// Set global statistics for cross-segment IDF
58    pub fn set_global_stats(&mut self, stats: Arc<GlobalStats>) {
59        self.global_stats = Some(stats);
60    }
61}
62
63impl Query for TermQuery {
64    fn scorer<'a>(
65        &self,
66        reader: &'a SegmentReader,
67        _limit: usize,
68        _predicate: Option<super::DocPredicate<'a>>,
69    ) -> ScorerFuture<'a> {
70        let field = self.field;
71        let term = self.term.clone();
72        let global_stats = self.global_stats.clone();
73        Box::pin(async move {
74            let postings = reader.get_postings(field, &term).await?;
75
76            match postings {
77                Some(posting_list) => {
78                    // Use global stats IDF if available, otherwise segment-local
79                    let (idf, avg_field_len) = if let Some(ref stats) = global_stats {
80                        let term_str = String::from_utf8_lossy(&term);
81                        let global_idf = stats.text_idf(field, &term_str);
82
83                        // If global stats has this term, use global IDF
84                        // Otherwise fall back to segment-local
85                        if global_idf > 0.0 {
86                            (global_idf, stats.avg_field_len(field))
87                        } else {
88                            // Fall back to segment-local IDF
89                            let num_docs = reader.num_docs() as f32;
90                            let doc_freq = posting_list.doc_count() as f32;
91                            let idf = super::bm25_idf(doc_freq, num_docs);
92                            (idf, reader.avg_field_len(field))
93                        }
94                    } else {
95                        // Compute IDF from segment statistics
96                        let num_docs = reader.num_docs() as f32;
97                        let doc_freq = posting_list.doc_count() as f32;
98                        let idf = super::bm25_idf(doc_freq, num_docs);
99                        (idf, reader.avg_field_len(field))
100                    };
101
102                    // Try to load positions if available
103                    let positions = reader.get_positions(field, &term).await.ok().flatten();
104
105                    let mut scorer = TermScorer::new(
106                        posting_list,
107                        idf,
108                        avg_field_len,
109                        1.0, // default field boost
110                    );
111
112                    if let Some(pos) = positions {
113                        scorer = scorer.with_positions(field.0, pos);
114                    }
115
116                    Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
117                }
118                None => Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>),
119            }
120        })
121    }
122
123    fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
124        let field = self.field;
125        let term = self.term.clone();
126        Box::pin(async move {
127            match reader.get_postings(field, &term).await? {
128                Some(list) => Ok(list.doc_count()),
129                None => Ok(0),
130            }
131        })
132    }
133
134    #[cfg(feature = "sync")]
135    fn scorer_sync<'a>(
136        &self,
137        reader: &'a SegmentReader,
138        _limit: usize,
139        _predicate: Option<super::DocPredicate<'a>>,
140    ) -> crate::Result<Box<dyn Scorer + 'a>> {
141        let postings = reader.get_postings_sync(self.field, &self.term)?;
142
143        match postings {
144            Some(posting_list) => {
145                let (idf, avg_field_len) = if let Some(ref stats) = self.global_stats {
146                    let term_str = String::from_utf8_lossy(&self.term);
147                    let global_idf = stats.text_idf(self.field, &term_str);
148                    if global_idf > 0.0 {
149                        (global_idf, stats.avg_field_len(self.field))
150                    } else {
151                        let num_docs = reader.num_docs() as f32;
152                        let doc_freq = posting_list.doc_count() as f32;
153                        (
154                            super::bm25_idf(doc_freq, num_docs),
155                            reader.avg_field_len(self.field),
156                        )
157                    }
158                } else {
159                    let num_docs = reader.num_docs() as f32;
160                    let doc_freq = posting_list.doc_count() as f32;
161                    (
162                        super::bm25_idf(doc_freq, num_docs),
163                        reader.avg_field_len(self.field),
164                    )
165                };
166
167                let positions = reader
168                    .get_positions_sync(self.field, &self.term)
169                    .ok()
170                    .flatten();
171
172                let mut scorer = TermScorer::new(posting_list, idf, avg_field_len, 1.0);
173                if let Some(pos) = positions {
174                    scorer = scorer.with_positions(self.field.0, pos);
175                }
176
177                Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
178            }
179            None => Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>),
180        }
181    }
182
183    fn as_term_query_info(&self) -> Option<TermQueryInfo> {
184        Some(TermQueryInfo {
185            field: self.field,
186            term: self.term.clone(),
187        })
188    }
189}
190
191struct TermScorer {
192    iterator: crate::structures::BlockPostingIterator<'static>,
193    idf: f32,
194    /// Average field length for this field
195    avg_field_len: f32,
196    /// Field boost/weight for BM25F
197    field_boost: f32,
198    /// Field ID for position reporting
199    field_id: u32,
200    /// Position posting list (if positions are enabled)
201    positions: Option<crate::structures::PositionPostingList>,
202}
203
204impl TermScorer {
205    pub fn new(
206        posting_list: BlockPostingList,
207        idf: f32,
208        avg_field_len: f32,
209        field_boost: f32,
210    ) -> Self {
211        Self {
212            iterator: posting_list.into_iterator(),
213            idf,
214            avg_field_len,
215            field_boost,
216            field_id: 0,
217            positions: None,
218        }
219    }
220
221    pub fn with_positions(
222        mut self,
223        field_id: u32,
224        positions: crate::structures::PositionPostingList,
225    ) -> Self {
226        self.field_id = field_id;
227        self.positions = Some(positions);
228        self
229    }
230}
231
232impl Scorer for TermScorer {
233    fn doc(&self) -> DocId {
234        self.iterator.doc()
235    }
236
237    fn score(&self) -> Score {
238        let tf = self.iterator.term_freq() as f32;
239        // Note: Using tf as doc_len proxy since we don't store per-doc field lengths.
240        // This is a common approximation - longer docs tend to have higher TF.
241        super::bm25f_score(tf, self.idf, tf, self.avg_field_len, self.field_boost)
242    }
243
244    fn advance(&mut self) -> DocId {
245        self.iterator.advance()
246    }
247
248    fn seek(&mut self, target: DocId) -> DocId {
249        self.iterator.seek(target)
250    }
251
252    fn size_hint(&self) -> u32 {
253        0
254    }
255
256    fn matched_positions(&self) -> Option<super::MatchedPositions> {
257        let positions = self.positions.as_ref()?;
258        let doc_id = self.iterator.doc();
259        let pos = positions.get_positions(doc_id)?;
260        let score = self.score();
261        // Each position contributes equally to the term score
262        let per_position_score = if pos.is_empty() {
263            0.0
264        } else {
265            score / pos.len() as f32
266        };
267        let scored_positions: Vec<super::ScoredPosition> = pos
268            .iter()
269            .map(|&p| super::ScoredPosition::new(p, per_position_score))
270            .collect();
271        Some(vec![(self.field_id, scored_positions)])
272    }
273}