1use std::sync::Arc;
4
5use crate::dsl::Field;
6use crate::segment::SegmentReader;
7use crate::structures::BlockPostingList;
8use crate::{DocId, Score};
9
10use super::{CountFuture, EmptyScorer, GlobalStats, Query, Scorer, ScorerFuture, TermQueryInfo};
11
12#[derive(Clone)]
14pub struct TermQuery {
15 pub field: Field,
16 pub term: Vec<u8>,
17 global_stats: Option<Arc<GlobalStats>>,
19}
20
21impl std::fmt::Debug for TermQuery {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 f.debug_struct("TermQuery")
24 .field("field", &self.field)
25 .field("term", &String::from_utf8_lossy(&self.term))
26 .field("has_global_stats", &self.global_stats.is_some())
27 .finish()
28 }
29}
30
31impl std::fmt::Display for TermQuery {
32 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33 write!(
34 f,
35 "Term({}:\"{}\")",
36 self.field.0,
37 String::from_utf8_lossy(&self.term)
38 )
39 }
40}
41
42impl TermQuery {
43 pub fn new(field: Field, term: impl Into<Vec<u8>>) -> Self {
44 Self {
45 field,
46 term: term.into(),
47 global_stats: None,
48 }
49 }
50
51 pub fn text(field: Field, text: &str) -> Self {
52 Self {
53 field,
54 term: text.to_lowercase().into_bytes(),
55 global_stats: None,
56 }
57 }
58
59 pub fn with_global_stats(field: Field, text: &str, stats: Arc<GlobalStats>) -> Self {
61 Self {
62 field,
63 term: text.to_lowercase().into_bytes(),
64 global_stats: Some(stats),
65 }
66 }
67
68 pub fn set_global_stats(&mut self, stats: Arc<GlobalStats>) {
70 self.global_stats = Some(stats);
71 }
72}
73
74impl Query for TermQuery {
75 fn scorer<'a>(&self, reader: &'a SegmentReader, _limit: usize) -> ScorerFuture<'a> {
76 let field = self.field;
77 let term = self.term.clone();
78 let global_stats = self.global_stats.clone();
79 Box::pin(async move {
80 let postings = reader.get_postings(field, &term).await?;
81
82 match postings {
83 Some(posting_list) => {
84 let (idf, avg_field_len) = if let Some(ref stats) = global_stats {
86 let term_str = String::from_utf8_lossy(&term);
87 let global_idf = stats.text_idf(field, &term_str);
88
89 if global_idf > 0.0 {
92 (global_idf, stats.avg_field_len(field))
93 } else {
94 let num_docs = reader.num_docs() as f32;
96 let doc_freq = posting_list.doc_count() as f32;
97 let idf = super::bm25_idf(doc_freq, num_docs);
98 (idf, reader.avg_field_len(field))
99 }
100 } else {
101 let num_docs = reader.num_docs() as f32;
103 let doc_freq = posting_list.doc_count() as f32;
104 let idf = super::bm25_idf(doc_freq, num_docs);
105 (idf, reader.avg_field_len(field))
106 };
107
108 let positions = reader.get_positions(field, &term).await.ok().flatten();
110
111 let mut scorer = TermScorer::new(
112 posting_list,
113 idf,
114 avg_field_len,
115 1.0, );
117
118 if let Some(pos) = positions {
119 scorer = scorer.with_positions(field.0, pos);
120 }
121
122 Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
123 }
124 None => Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>),
125 }
126 })
127 }
128
129 fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
130 let field = self.field;
131 let term = self.term.clone();
132 Box::pin(async move {
133 match reader.get_postings(field, &term).await? {
134 Some(list) => Ok(list.doc_count()),
135 None => Ok(0),
136 }
137 })
138 }
139
140 #[cfg(feature = "sync")]
141 fn scorer_sync<'a>(
142 &self,
143 reader: &'a SegmentReader,
144 _limit: usize,
145 ) -> crate::Result<Box<dyn Scorer + 'a>> {
146 let postings = reader.get_postings_sync(self.field, &self.term)?;
147
148 match postings {
149 Some(posting_list) => {
150 let (idf, avg_field_len) = if let Some(ref stats) = self.global_stats {
151 let term_str = String::from_utf8_lossy(&self.term);
152 let global_idf = stats.text_idf(self.field, &term_str);
153 if global_idf > 0.0 {
154 (global_idf, stats.avg_field_len(self.field))
155 } else {
156 let num_docs = reader.num_docs() as f32;
157 let doc_freq = posting_list.doc_count() as f32;
158 (
159 super::bm25_idf(doc_freq, num_docs),
160 reader.avg_field_len(self.field),
161 )
162 }
163 } else {
164 let num_docs = reader.num_docs() as f32;
165 let doc_freq = posting_list.doc_count() as f32;
166 (
167 super::bm25_idf(doc_freq, num_docs),
168 reader.avg_field_len(self.field),
169 )
170 };
171
172 let positions = reader
173 .get_positions_sync(self.field, &self.term)
174 .ok()
175 .flatten();
176
177 let mut scorer = TermScorer::new(posting_list, idf, avg_field_len, 1.0);
178 if let Some(pos) = positions {
179 scorer = scorer.with_positions(self.field.0, pos);
180 }
181
182 Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
183 }
184 None => Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>),
185 }
186 }
187
188 fn as_term_query_info(&self) -> Option<TermQueryInfo> {
189 Some(TermQueryInfo {
190 field: self.field,
191 term: self.term.clone(),
192 })
193 }
194}
195
196struct TermScorer {
197 iterator: crate::structures::BlockPostingIterator<'static>,
198 idf: f32,
199 avg_field_len: f32,
201 field_boost: f32,
203 field_id: u32,
205 positions: Option<crate::structures::PositionPostingList>,
207}
208
209impl TermScorer {
210 pub fn new(
211 posting_list: BlockPostingList,
212 idf: f32,
213 avg_field_len: f32,
214 field_boost: f32,
215 ) -> Self {
216 Self {
217 iterator: posting_list.into_iterator(),
218 idf,
219 avg_field_len,
220 field_boost,
221 field_id: 0,
222 positions: None,
223 }
224 }
225
226 pub fn with_positions(
227 mut self,
228 field_id: u32,
229 positions: crate::structures::PositionPostingList,
230 ) -> Self {
231 self.field_id = field_id;
232 self.positions = Some(positions);
233 self
234 }
235}
236
237impl super::docset::DocSet for TermScorer {
238 fn doc(&self) -> DocId {
239 self.iterator.doc()
240 }
241
242 fn advance(&mut self) -> DocId {
243 self.iterator.advance()
244 }
245
246 fn seek(&mut self, target: DocId) -> DocId {
247 self.iterator.seek(target)
248 }
249
250 fn size_hint(&self) -> u32 {
251 0
252 }
253}
254
255impl Scorer for TermScorer {
256 fn score(&self) -> Score {
257 let tf = self.iterator.term_freq() as f32;
258 super::bm25f_score(tf, self.idf, tf, self.avg_field_len, self.field_boost)
261 }
262
263 fn matched_positions(&self) -> Option<super::MatchedPositions> {
264 let positions = self.positions.as_ref()?;
265 let doc_id = self.iterator.doc();
266 let pos = positions.get_positions(doc_id)?;
267 let score = self.score();
268 let per_position_score = if pos.is_empty() {
270 0.0
271 } else {
272 score / pos.len() as f32
273 };
274 let scored_positions: Vec<super::ScoredPosition> = pos
275 .iter()
276 .map(|&p| super::ScoredPosition::new(p, per_position_score))
277 .collect();
278 Some(vec![(self.field_id, scored_positions)])
279 }
280}