Skip to main content

hermes_core/query/
term.rs

1//! Term query - matches documents containing a specific term
2
3use std::sync::Arc;
4
5use crate::dsl::Field;
6use crate::segment::SegmentReader;
7use crate::structures::BlockPostingList;
8use crate::{DocId, Score};
9
10use super::{CountFuture, EmptyScorer, GlobalStats, Query, Scorer, ScorerFuture, TermQueryInfo};
11
/// Term query - matches documents containing a specific term
///
/// The stored `term` bytes are looked up verbatim in the field's postings.
/// `TermQuery::new` keeps the bytes as given; `TermQuery::text` and
/// `TermQuery::with_global_stats` lowercase the input text first.
#[derive(Clone)]
pub struct TermQuery {
    /// Field whose postings are searched.
    pub field: Field,
    /// Term bytes passed to `SegmentReader::get_postings` / `get_positions`.
    pub term: Vec<u8>,
    /// Optional global statistics for cross-segment IDF.
    /// When `None` (or when the global IDF for the term is not positive),
    /// scoring falls back to the segment's own document counts.
    global_stats: Option<Arc<GlobalStats>>,
}
20
21impl std::fmt::Debug for TermQuery {
22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23        f.debug_struct("TermQuery")
24            .field("field", &self.field)
25            .field("term", &String::from_utf8_lossy(&self.term))
26            .field("has_global_stats", &self.global_stats.is_some())
27            .finish()
28    }
29}
30
31impl TermQuery {
32    pub fn new(field: Field, term: impl Into<Vec<u8>>) -> Self {
33        Self {
34            field,
35            term: term.into(),
36            global_stats: None,
37        }
38    }
39
40    pub fn text(field: Field, text: &str) -> Self {
41        Self {
42            field,
43            term: text.to_lowercase().into_bytes(),
44            global_stats: None,
45        }
46    }
47
48    /// Create with global statistics for cross-segment IDF
49    pub fn with_global_stats(field: Field, text: &str, stats: Arc<GlobalStats>) -> Self {
50        Self {
51            field,
52            term: text.to_lowercase().into_bytes(),
53            global_stats: Some(stats),
54        }
55    }
56
57    /// Set global statistics for cross-segment IDF
58    pub fn set_global_stats(&mut self, stats: Arc<GlobalStats>) {
59        self.global_stats = Some(stats);
60    }
61}
62
63impl Query for TermQuery {
64    fn scorer<'a>(&self, reader: &'a SegmentReader, _limit: usize) -> ScorerFuture<'a> {
65        let field = self.field;
66        let term = self.term.clone();
67        let global_stats = self.global_stats.clone();
68        Box::pin(async move {
69            let postings = reader.get_postings(field, &term).await?;
70
71            match postings {
72                Some(posting_list) => {
73                    // Use global stats IDF if available, otherwise segment-local
74                    let (idf, avg_field_len) = if let Some(ref stats) = global_stats {
75                        let term_str = String::from_utf8_lossy(&term);
76                        let global_idf = stats.text_idf(field, &term_str);
77
78                        // If global stats has this term, use global IDF
79                        // Otherwise fall back to segment-local
80                        if global_idf > 0.0 {
81                            (global_idf, stats.avg_field_len(field))
82                        } else {
83                            // Fall back to segment-local IDF
84                            let num_docs = reader.num_docs() as f32;
85                            let doc_freq = posting_list.doc_count() as f32;
86                            let idf = super::bm25_idf(doc_freq, num_docs);
87                            (idf, reader.avg_field_len(field))
88                        }
89                    } else {
90                        // Compute IDF from segment statistics
91                        let num_docs = reader.num_docs() as f32;
92                        let doc_freq = posting_list.doc_count() as f32;
93                        let idf = super::bm25_idf(doc_freq, num_docs);
94                        (idf, reader.avg_field_len(field))
95                    };
96
97                    // Try to load positions if available
98                    let positions = reader.get_positions(field, &term).await.ok().flatten();
99
100                    let mut scorer = TermScorer::new(
101                        posting_list,
102                        idf,
103                        avg_field_len,
104                        1.0, // default field boost
105                    );
106
107                    if let Some(pos) = positions {
108                        scorer = scorer.with_positions(field.0, pos);
109                    }
110
111                    Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
112                }
113                None => Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>),
114            }
115        })
116    }
117
118    fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
119        let field = self.field;
120        let term = self.term.clone();
121        Box::pin(async move {
122            match reader.get_postings(field, &term).await? {
123                Some(list) => Ok(list.doc_count()),
124                None => Ok(0),
125            }
126        })
127    }
128
129    fn as_term_query_info(&self) -> Option<TermQueryInfo> {
130        Some(TermQueryInfo {
131            field: self.field,
132            term: self.term.clone(),
133        })
134    }
135}
136
/// Scorer over a single term's posting list; each document is scored with
/// `super::bm25f_score` using the term frequency reported by the iterator.
struct TermScorer {
    /// Posting iterator obtained from `BlockPostingList::into_iterator`.
    iterator: crate::structures::BlockPostingIterator<'static>,
    /// Inverse document frequency applied to every document of this term.
    idf: f32,
    /// Average field length for this field
    avg_field_len: f32,
    /// Field boost/weight for BM25F
    field_boost: f32,
    /// Field ID for position reporting (0 until `with_positions` sets it).
    field_id: u32,
    /// Position posting list (if positions are enabled)
    positions: Option<crate::structures::PositionPostingList>,
}
149
150impl TermScorer {
151    pub fn new(
152        posting_list: BlockPostingList,
153        idf: f32,
154        avg_field_len: f32,
155        field_boost: f32,
156    ) -> Self {
157        Self {
158            iterator: posting_list.into_iterator(),
159            idf,
160            avg_field_len,
161            field_boost,
162            field_id: 0,
163            positions: None,
164        }
165    }
166
167    pub fn with_positions(
168        mut self,
169        field_id: u32,
170        positions: crate::structures::PositionPostingList,
171    ) -> Self {
172        self.field_id = field_id;
173        self.positions = Some(positions);
174        self
175    }
176}
177
178impl Scorer for TermScorer {
179    fn doc(&self) -> DocId {
180        self.iterator.doc()
181    }
182
183    fn score(&self) -> Score {
184        let tf = self.iterator.term_freq() as f32;
185        // Note: Using tf as doc_len proxy since we don't store per-doc field lengths.
186        // This is a common approximation - longer docs tend to have higher TF.
187        super::bm25f_score(tf, self.idf, tf, self.avg_field_len, self.field_boost)
188    }
189
190    fn advance(&mut self) -> DocId {
191        self.iterator.advance()
192    }
193
194    fn seek(&mut self, target: DocId) -> DocId {
195        self.iterator.seek(target)
196    }
197
198    fn size_hint(&self) -> u32 {
199        0
200    }
201
202    fn matched_positions(&self) -> Option<super::MatchedPositions> {
203        let positions = self.positions.as_ref()?;
204        let doc_id = self.iterator.doc();
205        let pos = positions.get_positions(doc_id)?;
206        Some(vec![(self.field_id, pos.to_vec())])
207    }
208}