1use std::sync::Arc;
4
5use crate::dsl::Field;
6use crate::segment::SegmentReader;
7use crate::structures::BlockPostingList;
8use crate::structures::TERMINATED;
9use crate::{DocId, Score};
10
11use super::{CountFuture, EmptyScorer, GlobalStats, Query, Scorer, ScorerFuture, TermQueryInfo};
12
13#[derive(Clone)]
15pub struct TermQuery {
16 pub field: Field,
17 pub term: Vec<u8>,
18 global_stats: Option<Arc<GlobalStats>>,
20}
21
22impl std::fmt::Debug for TermQuery {
23 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24 f.debug_struct("TermQuery")
25 .field("field", &self.field)
26 .field("term", &String::from_utf8_lossy(&self.term))
27 .field("has_global_stats", &self.global_stats.is_some())
28 .finish()
29 }
30}
31
32impl std::fmt::Display for TermQuery {
33 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34 write!(
35 f,
36 "Term({}:\"{}\")",
37 self.field.0,
38 String::from_utf8_lossy(&self.term)
39 )
40 }
41}
42
43impl TermQuery {
44 pub fn new(field: Field, term: impl Into<Vec<u8>>) -> Self {
45 Self {
46 field,
47 term: term.into(),
48 global_stats: None,
49 }
50 }
51
52 pub fn text(field: Field, text: &str) -> Self {
53 Self {
54 field,
55 term: text.to_lowercase().into_bytes(),
56 global_stats: None,
57 }
58 }
59
60 pub fn with_global_stats(field: Field, text: &str, stats: Arc<GlobalStats>) -> Self {
62 Self {
63 field,
64 term: text.to_lowercase().into_bytes(),
65 global_stats: Some(stats),
66 }
67 }
68
69 pub fn set_global_stats(&mut self, stats: Arc<GlobalStats>) {
71 self.global_stats = Some(stats);
72 }
73}
74
75impl Query for TermQuery {
76 fn scorer<'a>(&self, reader: &'a SegmentReader, _limit: usize) -> ScorerFuture<'a> {
77 let field = self.field;
78 let term = self.term.clone();
79 let global_stats = self.global_stats.clone();
80 Box::pin(async move {
81 let postings = reader.get_postings(field, &term).await?;
82
83 match postings {
84 Some(posting_list) => {
85 let (idf, avg_field_len) = if let Some(ref stats) = global_stats {
87 let term_str = String::from_utf8_lossy(&term);
88 let global_idf = stats.text_idf(field, &term_str);
89
90 if global_idf > 0.0 {
93 (global_idf, stats.avg_field_len(field))
94 } else {
95 let num_docs = reader.num_docs() as f32;
97 let doc_freq = posting_list.doc_count() as f32;
98 let idf = super::bm25_idf(doc_freq, num_docs);
99 (idf, reader.avg_field_len(field))
100 }
101 } else {
102 let num_docs = reader.num_docs() as f32;
104 let doc_freq = posting_list.doc_count() as f32;
105 let idf = super::bm25_idf(doc_freq, num_docs);
106 (idf, reader.avg_field_len(field))
107 };
108
109 let positions = reader.get_positions(field, &term).await.ok().flatten();
111
112 let mut scorer = TermScorer::new(
113 posting_list,
114 idf,
115 avg_field_len,
116 1.0, );
118
119 if let Some(pos) = positions {
120 scorer = scorer.with_positions(field.0, pos);
121 }
122
123 Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
124 }
125 None => {
126 let term_str = String::from_utf8_lossy(&term);
128 if let Some(scorer) = FastFieldTextScorer::try_new(reader, field, &term_str) {
129 Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
130 } else {
131 Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>)
132 }
133 }
134 }
135 })
136 }
137
138 fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
139 let field = self.field;
140 let term = self.term.clone();
141 Box::pin(async move {
142 match reader.get_postings(field, &term).await? {
143 Some(list) => Ok(list.doc_count()),
144 None => Ok(0),
145 }
146 })
147 }
148
149 #[cfg(feature = "sync")]
150 fn scorer_sync<'a>(
151 &self,
152 reader: &'a SegmentReader,
153 _limit: usize,
154 ) -> crate::Result<Box<dyn Scorer + 'a>> {
155 let postings = reader.get_postings_sync(self.field, &self.term)?;
156
157 match postings {
158 Some(posting_list) => {
159 let (idf, avg_field_len) = if let Some(ref stats) = self.global_stats {
160 let term_str = String::from_utf8_lossy(&self.term);
161 let global_idf = stats.text_idf(self.field, &term_str);
162 if global_idf > 0.0 {
163 (global_idf, stats.avg_field_len(self.field))
164 } else {
165 let num_docs = reader.num_docs() as f32;
166 let doc_freq = posting_list.doc_count() as f32;
167 (
168 super::bm25_idf(doc_freq, num_docs),
169 reader.avg_field_len(self.field),
170 )
171 }
172 } else {
173 let num_docs = reader.num_docs() as f32;
174 let doc_freq = posting_list.doc_count() as f32;
175 (
176 super::bm25_idf(doc_freq, num_docs),
177 reader.avg_field_len(self.field),
178 )
179 };
180
181 let positions = reader
182 .get_positions_sync(self.field, &self.term)
183 .ok()
184 .flatten();
185
186 let mut scorer = TermScorer::new(posting_list, idf, avg_field_len, 1.0);
187 if let Some(pos) = positions {
188 scorer = scorer.with_positions(self.field.0, pos);
189 }
190
191 Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
192 }
193 None => {
194 let term_str = String::from_utf8_lossy(&self.term);
195 if let Some(scorer) = FastFieldTextScorer::try_new(reader, self.field, &term_str) {
196 Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
197 } else {
198 Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>)
199 }
200 }
201 }
202 }
203
204 fn as_term_query_info(&self) -> Option<TermQueryInfo> {
205 Some(TermQueryInfo {
206 field: self.field,
207 term: self.term.clone(),
208 })
209 }
210}
211
212struct TermScorer {
213 iterator: crate::structures::BlockPostingIterator<'static>,
214 idf: f32,
215 avg_field_len: f32,
217 field_boost: f32,
219 field_id: u32,
221 positions: Option<crate::structures::PositionPostingList>,
223}
224
225impl TermScorer {
226 pub fn new(
227 posting_list: BlockPostingList,
228 idf: f32,
229 avg_field_len: f32,
230 field_boost: f32,
231 ) -> Self {
232 Self {
233 iterator: posting_list.into_iterator(),
234 idf,
235 avg_field_len,
236 field_boost,
237 field_id: 0,
238 positions: None,
239 }
240 }
241
242 pub fn with_positions(
243 mut self,
244 field_id: u32,
245 positions: crate::structures::PositionPostingList,
246 ) -> Self {
247 self.field_id = field_id;
248 self.positions = Some(positions);
249 self
250 }
251}
252
253impl super::docset::DocSet for TermScorer {
254 fn doc(&self) -> DocId {
255 self.iterator.doc()
256 }
257
258 fn advance(&mut self) -> DocId {
259 self.iterator.advance()
260 }
261
262 fn seek(&mut self, target: DocId) -> DocId {
263 self.iterator.seek(target)
264 }
265
266 fn size_hint(&self) -> u32 {
267 0
268 }
269}
270
271struct FastFieldTextScorer<'a> {
277 fast_field: &'a crate::structures::fast_field::FastFieldReader,
278 target_ordinal: u64,
279 current: u32,
280 num_docs: u32,
281}
282
283impl<'a> FastFieldTextScorer<'a> {
284 fn try_new(reader: &'a SegmentReader, field: Field, text: &str) -> Option<Self> {
285 let fast_field = reader.fast_field(field.0)?;
286 let target_ordinal = fast_field.text_ordinal(text)?;
287 let num_docs = reader.num_docs();
288 let mut scorer = Self {
289 fast_field,
290 target_ordinal,
291 current: 0,
292 num_docs,
293 };
294 if num_docs > 0 && fast_field.get_u64(0) != target_ordinal {
296 scorer.scan_forward();
297 }
298 Some(scorer)
299 }
300
301 fn scan_forward(&mut self) {
302 loop {
303 self.current += 1;
304 if self.current >= self.num_docs {
305 self.current = self.num_docs;
306 return;
307 }
308 if self.fast_field.get_u64(self.current) == self.target_ordinal {
309 return;
310 }
311 }
312 }
313}
314
315impl super::docset::DocSet for FastFieldTextScorer<'_> {
316 fn doc(&self) -> DocId {
317 if self.current >= self.num_docs {
318 TERMINATED
319 } else {
320 self.current
321 }
322 }
323
324 fn advance(&mut self) -> DocId {
325 self.scan_forward();
326 self.doc()
327 }
328
329 fn seek(&mut self, target: DocId) -> DocId {
330 if target > self.current {
331 self.current = target;
332 if self.current < self.num_docs
333 && self.fast_field.get_u64(self.current) != self.target_ordinal
334 {
335 self.scan_forward();
336 }
337 }
338 self.doc()
339 }
340
341 fn size_hint(&self) -> u32 {
342 0
343 }
344}
345
346impl Scorer for FastFieldTextScorer<'_> {
347 fn score(&self) -> Score {
348 1.0
349 }
350}
351
352impl Scorer for TermScorer {
353 fn score(&self) -> Score {
354 let tf = self.iterator.term_freq() as f32;
355 super::bm25f_score(tf, self.idf, tf, self.avg_field_len, self.field_boost)
358 }
359
360 fn matched_positions(&self) -> Option<super::MatchedPositions> {
361 let positions = self.positions.as_ref()?;
362 let doc_id = self.iterator.doc();
363 let pos = positions.get_positions(doc_id)?;
364 let score = self.score();
365 let per_position_score = if pos.is_empty() {
367 0.0
368 } else {
369 score / pos.len() as f32
370 };
371 let scored_positions: Vec<super::ScoredPosition> = pos
372 .iter()
373 .map(|&p| super::ScoredPosition::new(p, per_position_score))
374 .collect();
375 Some(vec![(self.field_id, scored_positions)])
376 }
377}