hermes_core/query/
term.rs

1//! Term query - matches documents containing a specific term
2
3use std::sync::Arc;
4
5use crate::dsl::Field;
6use crate::segment::SegmentReader;
7use crate::structures::BlockPostingList;
8use crate::wand::WandData;
9use crate::{DocId, Score};
10
11use super::{Bm25Params, CountFuture, EmptyScorer, Query, Scorer, ScorerFuture};
12
13/// Term query - matches documents containing a specific term
14#[derive(Debug, Clone)]
15pub struct TermQuery {
16    pub field: Field,
17    pub term: Vec<u8>,
18    /// Optional pre-computed WAND data for collection-wide IDF
19    wand_data: Option<Arc<WandData>>,
20    /// Field name for WAND data lookup
21    field_name: Option<String>,
22}
23
24impl TermQuery {
25    pub fn new(field: Field, term: impl Into<Vec<u8>>) -> Self {
26        Self {
27            field,
28            term: term.into(),
29            wand_data: None,
30            field_name: None,
31        }
32    }
33
34    pub fn text(field: Field, text: &str) -> Self {
35        Self {
36            field,
37            term: text.to_lowercase().into_bytes(),
38            wand_data: None,
39            field_name: None,
40        }
41    }
42
43    /// Create a term query with pre-computed WAND data for collection-wide IDF
44    ///
45    /// This enables more accurate scoring when querying across multiple segments,
46    /// as the IDF values are computed from the entire collection rather than
47    /// per-segment.
48    pub fn with_wand_data(
49        field: Field,
50        field_name: &str,
51        term: &str,
52        wand_data: Arc<WandData>,
53    ) -> Self {
54        Self {
55            field,
56            term: term.to_lowercase().into_bytes(),
57            wand_data: Some(wand_data),
58            field_name: Some(field_name.to_string()),
59        }
60    }
61
62    /// Set WAND data for this query
63    pub fn set_wand_data(&mut self, field_name: &str, wand_data: Arc<WandData>) {
64        self.wand_data = Some(wand_data);
65        self.field_name = Some(field_name.to_string());
66    }
67}
68
69impl Query for TermQuery {
70    fn scorer<'a>(&'a self, reader: &'a SegmentReader) -> ScorerFuture<'a> {
71        Box::pin(async move {
72            let postings = reader.get_postings(self.field, &self.term).await?;
73
74            match postings {
75                Some(posting_list) => {
76                    // Try to get IDF from pre-computed WAND data first
77                    let idf = if let (Some(wand_data), Some(field_name)) =
78                        (&self.wand_data, &self.field_name)
79                    {
80                        let term_str = String::from_utf8_lossy(&self.term);
81                        wand_data.get_idf(field_name, &term_str).unwrap_or_else(|| {
82                            // Fall back to segment-level IDF if term not in WAND data
83                            let num_docs = reader.num_docs() as f32;
84                            let doc_freq = posting_list.doc_count() as f32;
85                            ((num_docs - doc_freq + 0.5) / (doc_freq + 0.5) + 1.0).ln()
86                        })
87                    } else {
88                        // Compute IDF from segment statistics
89                        let num_docs = reader.num_docs() as f32;
90                        let doc_freq = posting_list.doc_count() as f32;
91                        ((num_docs - doc_freq + 0.5) / (doc_freq + 0.5) + 1.0).ln()
92                    };
93
94                    // Get average field length for BM25F length normalization
95                    // Use WAND data avg_doc_len if available, otherwise segment-level
96                    let avg_field_len = self
97                        .wand_data
98                        .as_ref()
99                        .map(|w| w.avg_doc_len)
100                        .unwrap_or_else(|| reader.avg_field_len(self.field));
101
102                    Ok(Box::new(TermScorer::new(
103                        posting_list,
104                        idf,
105                        avg_field_len,
106                        Bm25Params::default(),
107                        1.0, // default field boost
108                    )) as Box<dyn Scorer + 'a>)
109                }
110                None => Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>),
111            }
112        })
113    }
114
115    fn count_estimate<'a>(&'a self, reader: &'a SegmentReader) -> CountFuture<'a> {
116        Box::pin(async move {
117            match reader.get_postings(self.field, &self.term).await? {
118                Some(list) => Ok(list.doc_count()),
119                None => Ok(0),
120            }
121        })
122    }
123}
124
125struct TermScorer {
126    iterator: crate::structures::BlockPostingIterator<'static>,
127    idf: f32,
128    /// BM25 parameters
129    params: Bm25Params,
130    /// Average field length for this field
131    avg_field_len: f32,
132    /// Field boost/weight for BM25F
133    field_boost: f32,
134}
135
136impl TermScorer {
137    pub fn new(
138        posting_list: BlockPostingList,
139        idf: f32,
140        avg_field_len: f32,
141        params: Bm25Params,
142        field_boost: f32,
143    ) -> Self {
144        Self {
145            iterator: posting_list.into_iterator(),
146            idf,
147            params,
148            avg_field_len,
149            field_boost,
150        }
151    }
152}
153
154impl Scorer for TermScorer {
155    fn doc(&self) -> DocId {
156        self.iterator.doc()
157    }
158
159    fn score(&self) -> Score {
160        let tf = self.iterator.term_freq() as f32;
161        let k1 = self.params.k1;
162        let b = self.params.b;
163
164        // BM25F: apply field boost and length normalization
165        let length_norm = 1.0 - b + b * (tf / self.avg_field_len.max(1.0));
166        let tf_norm =
167            (tf * self.field_boost * (k1 + 1.0)) / (tf * self.field_boost + k1 * length_norm);
168
169        self.idf * tf_norm
170    }
171
172    fn advance(&mut self) -> DocId {
173        self.iterator.advance()
174    }
175
176    fn seek(&mut self, target: DocId) -> DocId {
177        self.iterator.seek(target)
178    }
179
180    fn size_hint(&self) -> u32 {
181        0
182    }
183}