hermes_core/query/
term.rs

1//! Term query - matches documents containing a specific term
2
3use std::sync::Arc;
4
5use crate::dsl::Field;
6use crate::segment::SegmentReader;
7use crate::structures::BlockPostingList;
8use crate::{DocId, Score};
9
10use super::{
11    Bm25Params, CountFuture, EmptyScorer, GlobalStats, Query, Scorer, ScorerFuture, TermQueryInfo,
12};
13
14/// Term query - matches documents containing a specific term
15#[derive(Clone)]
16pub struct TermQuery {
17    pub field: Field,
18    pub term: Vec<u8>,
19    /// Optional global statistics for cross-segment IDF
20    global_stats: Option<Arc<GlobalStats>>,
21}
22
23impl std::fmt::Debug for TermQuery {
24    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25        f.debug_struct("TermQuery")
26            .field("field", &self.field)
27            .field("term", &String::from_utf8_lossy(&self.term))
28            .field("has_global_stats", &self.global_stats.is_some())
29            .finish()
30    }
31}
32
33impl TermQuery {
34    pub fn new(field: Field, term: impl Into<Vec<u8>>) -> Self {
35        Self {
36            field,
37            term: term.into(),
38            global_stats: None,
39        }
40    }
41
42    pub fn text(field: Field, text: &str) -> Self {
43        Self {
44            field,
45            term: text.to_lowercase().into_bytes(),
46            global_stats: None,
47        }
48    }
49
50    /// Create with global statistics for cross-segment IDF
51    pub fn with_global_stats(field: Field, text: &str, stats: Arc<GlobalStats>) -> Self {
52        Self {
53            field,
54            term: text.to_lowercase().into_bytes(),
55            global_stats: Some(stats),
56        }
57    }
58
59    /// Set global statistics for cross-segment IDF
60    pub fn set_global_stats(&mut self, stats: Arc<GlobalStats>) {
61        self.global_stats = Some(stats);
62    }
63}
64
65impl Query for TermQuery {
66    fn scorer<'a>(&'a self, reader: &'a SegmentReader, _limit: usize) -> ScorerFuture<'a> {
67        Box::pin(async move {
68            let postings = reader.get_postings(self.field, &self.term).await?;
69
70            match postings {
71                Some(posting_list) => {
72                    // Use global stats IDF if available, otherwise segment-local
73                    let (idf, avg_field_len) = if let Some(ref stats) = self.global_stats {
74                        let term_str = String::from_utf8_lossy(&self.term);
75                        let global_idf = stats.text_idf(self.field, &term_str);
76
77                        // If global stats has this term, use global IDF
78                        // Otherwise fall back to segment-local
79                        if global_idf > 0.0 {
80                            (global_idf, stats.avg_field_len(self.field))
81                        } else {
82                            // Fall back to segment-local IDF
83                            let num_docs = reader.num_docs() as f32;
84                            let doc_freq = posting_list.doc_count() as f32;
85                            let idf = super::bm25_idf(doc_freq, num_docs);
86                            (idf, reader.avg_field_len(self.field))
87                        }
88                    } else {
89                        // Compute IDF from segment statistics
90                        let num_docs = reader.num_docs() as f32;
91                        let doc_freq = posting_list.doc_count() as f32;
92                        let idf = super::bm25_idf(doc_freq, num_docs);
93                        (idf, reader.avg_field_len(self.field))
94                    };
95
96                    Ok(Box::new(TermScorer::new(
97                        posting_list,
98                        idf,
99                        avg_field_len,
100                        Bm25Params::default(),
101                        1.0, // default field boost
102                    )) as Box<dyn Scorer + 'a>)
103                }
104                None => Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>),
105            }
106        })
107    }
108
109    fn count_estimate<'a>(&'a self, reader: &'a SegmentReader) -> CountFuture<'a> {
110        Box::pin(async move {
111            match reader.get_postings(self.field, &self.term).await? {
112                Some(list) => Ok(list.doc_count()),
113                None => Ok(0),
114            }
115        })
116    }
117
118    fn as_term_query_info(&self) -> Option<TermQueryInfo> {
119        Some(TermQueryInfo {
120            field: self.field,
121            term: self.term.clone(),
122        })
123    }
124}
125
126struct TermScorer {
127    iterator: crate::structures::BlockPostingIterator<'static>,
128    idf: f32,
129    /// BM25 parameters
130    params: Bm25Params,
131    /// Average field length for this field
132    avg_field_len: f32,
133    /// Field boost/weight for BM25F
134    field_boost: f32,
135}
136
137impl TermScorer {
138    pub fn new(
139        posting_list: BlockPostingList,
140        idf: f32,
141        avg_field_len: f32,
142        params: Bm25Params,
143        field_boost: f32,
144    ) -> Self {
145        Self {
146            iterator: posting_list.into_iterator(),
147            idf,
148            params,
149            avg_field_len,
150            field_boost,
151        }
152    }
153}
154
155impl Scorer for TermScorer {
156    fn doc(&self) -> DocId {
157        self.iterator.doc()
158    }
159
160    fn score(&self) -> Score {
161        let tf = self.iterator.term_freq() as f32;
162        let k1 = self.params.k1;
163        let b = self.params.b;
164
165        // BM25F: apply field boost and length normalization
166        let length_norm = 1.0 - b + b * (tf / self.avg_field_len.max(1.0));
167        let tf_norm =
168            (tf * self.field_boost * (k1 + 1.0)) / (tf * self.field_boost + k1 * length_norm);
169
170        self.idf * tf_norm
171    }
172
173    fn advance(&mut self) -> DocId {
174        self.iterator.advance()
175    }
176
177    fn seek(&mut self, target: DocId) -> DocId {
178        self.iterator.seek(target)
179    }
180
181    fn size_hint(&self) -> u32 {
182        0
183    }
184}