1use std::sync::Arc;
4
5use crate::dsl::Field;
6use crate::segment::SegmentReader;
7use crate::structures::BlockPostingList;
8use crate::structures::TERMINATED;
9use crate::{DocId, Score};
10
11use super::{CountFuture, EmptyScorer, GlobalStats, Query, Scorer, ScorerFuture, TermQueryInfo};
12
13#[derive(Clone)]
15pub struct TermQuery {
16 pub field: Field,
17 pub term: Vec<u8>,
18 global_stats: Option<Arc<GlobalStats>>,
20}
21
22impl std::fmt::Debug for TermQuery {
23 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24 f.debug_struct("TermQuery")
25 .field("field", &self.field)
26 .field("term", &String::from_utf8_lossy(&self.term))
27 .field("has_global_stats", &self.global_stats.is_some())
28 .finish()
29 }
30}
31
32impl std::fmt::Display for TermQuery {
33 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34 write!(
35 f,
36 "Term({}:\"{}\")",
37 self.field.0,
38 String::from_utf8_lossy(&self.term)
39 )
40 }
41}
42
43impl TermQuery {
44 pub fn new(field: Field, term: impl Into<Vec<u8>>) -> Self {
45 Self {
46 field,
47 term: term.into(),
48 global_stats: None,
49 }
50 }
51
52 pub fn text(field: Field, text: &str) -> Self {
53 Self {
54 field,
55 term: text.to_lowercase().into_bytes(),
56 global_stats: None,
57 }
58 }
59
60 pub fn with_global_stats(field: Field, text: &str, stats: Arc<GlobalStats>) -> Self {
62 Self {
63 field,
64 term: text.to_lowercase().into_bytes(),
65 global_stats: Some(stats),
66 }
67 }
68
69 pub fn set_global_stats(&mut self, stats: Arc<GlobalStats>) {
71 self.global_stats = Some(stats);
72 }
73}
74
75impl Query for TermQuery {
76 fn scorer<'a>(&self, reader: &'a SegmentReader, _limit: usize) -> ScorerFuture<'a> {
77 let field = self.field;
78 let term = self.term.clone();
79 let global_stats = self.global_stats.clone();
80 Box::pin(async move {
81 let postings = reader.get_postings(field, &term).await?;
82
83 match postings {
84 Some(posting_list) => {
85 let (idf, avg_field_len) = if let Some(ref stats) = global_stats {
87 let term_str = String::from_utf8_lossy(&term);
88 let global_idf = stats.text_idf(field, &term_str);
89
90 if global_idf > 0.0 {
93 (global_idf, stats.avg_field_len(field))
94 } else {
95 let num_docs = reader.num_docs() as f32;
97 let doc_freq = posting_list.doc_count() as f32;
98 let idf = super::bm25_idf(doc_freq, num_docs);
99 (idf, reader.avg_field_len(field))
100 }
101 } else {
102 let num_docs = reader.num_docs() as f32;
104 let doc_freq = posting_list.doc_count() as f32;
105 let idf = super::bm25_idf(doc_freq, num_docs);
106 (idf, reader.avg_field_len(field))
107 };
108
109 let positions = reader.get_positions(field, &term).await.ok().flatten();
111
112 let mut scorer = TermScorer::new(
113 posting_list,
114 idf,
115 avg_field_len,
116 1.0, );
118
119 if let Some(pos) = positions {
120 scorer = scorer.with_positions(field.0, pos);
121 }
122
123 Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
124 }
125 None => {
126 let term_str = String::from_utf8_lossy(&term);
128 if let Some(scorer) = FastFieldTextScorer::try_new(reader, field, &term_str) {
129 Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
130 } else {
131 Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>)
132 }
133 }
134 }
135 })
136 }
137
138 fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
139 let field = self.field;
140 let term = self.term.clone();
141 Box::pin(async move {
142 match reader.get_postings(field, &term).await? {
143 Some(list) => Ok(list.doc_count()),
144 None => Ok(0),
145 }
146 })
147 }
148
149 #[cfg(feature = "sync")]
150 fn scorer_sync<'a>(
151 &self,
152 reader: &'a SegmentReader,
153 _limit: usize,
154 ) -> crate::Result<Box<dyn Scorer + 'a>> {
155 let postings = reader.get_postings_sync(self.field, &self.term)?;
156
157 match postings {
158 Some(posting_list) => {
159 let (idf, avg_field_len) = if let Some(ref stats) = self.global_stats {
160 let term_str = String::from_utf8_lossy(&self.term);
161 let global_idf = stats.text_idf(self.field, &term_str);
162 if global_idf > 0.0 {
163 (global_idf, stats.avg_field_len(self.field))
164 } else {
165 let num_docs = reader.num_docs() as f32;
166 let doc_freq = posting_list.doc_count() as f32;
167 (
168 super::bm25_idf(doc_freq, num_docs),
169 reader.avg_field_len(self.field),
170 )
171 }
172 } else {
173 let num_docs = reader.num_docs() as f32;
174 let doc_freq = posting_list.doc_count() as f32;
175 (
176 super::bm25_idf(doc_freq, num_docs),
177 reader.avg_field_len(self.field),
178 )
179 };
180
181 let positions = reader
182 .get_positions_sync(self.field, &self.term)
183 .ok()
184 .flatten();
185
186 let mut scorer = TermScorer::new(posting_list, idf, avg_field_len, 1.0);
187 if let Some(pos) = positions {
188 scorer = scorer.with_positions(self.field.0, pos);
189 }
190
191 Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
192 }
193 None => {
194 let term_str = String::from_utf8_lossy(&self.term);
195 if let Some(scorer) = FastFieldTextScorer::try_new(reader, self.field, &term_str) {
196 Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
197 } else {
198 Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>)
199 }
200 }
201 }
202 }
203
204 fn as_doc_predicate<'a>(&self, reader: &'a SegmentReader) -> Option<super::DocPredicate<'a>> {
205 let fast_field = reader.fast_field(self.field.0)?;
206 let term_str = String::from_utf8_lossy(&self.term);
207 let target_ordinal = fast_field.text_ordinal(&term_str)?;
208 Some(Box::new(move |doc_id: DocId| -> bool {
209 fast_field.get_u64(doc_id) == target_ordinal
210 }))
211 }
212
213 fn as_term_query_info(&self) -> Option<TermQueryInfo> {
214 Some(TermQueryInfo {
215 field: self.field,
216 term: self.term.clone(),
217 })
218 }
219}
220
221struct TermScorer {
222 iterator: crate::structures::BlockPostingIterator<'static>,
223 idf: f32,
224 avg_field_len: f32,
226 field_boost: f32,
228 field_id: u32,
230 positions: Option<crate::structures::PositionPostingList>,
232}
233
234impl TermScorer {
235 pub fn new(
236 posting_list: BlockPostingList,
237 idf: f32,
238 avg_field_len: f32,
239 field_boost: f32,
240 ) -> Self {
241 Self {
242 iterator: posting_list.into_iterator(),
243 idf,
244 avg_field_len,
245 field_boost,
246 field_id: 0,
247 positions: None,
248 }
249 }
250
251 pub fn with_positions(
252 mut self,
253 field_id: u32,
254 positions: crate::structures::PositionPostingList,
255 ) -> Self {
256 self.field_id = field_id;
257 self.positions = Some(positions);
258 self
259 }
260}
261
262impl super::docset::DocSet for TermScorer {
263 fn doc(&self) -> DocId {
264 self.iterator.doc()
265 }
266
267 fn advance(&mut self) -> DocId {
268 self.iterator.advance()
269 }
270
271 fn seek(&mut self, target: DocId) -> DocId {
272 self.iterator.seek(target)
273 }
274
275 fn size_hint(&self) -> u32 {
276 0
277 }
278}
279
280struct FastFieldTextScorer<'a> {
286 fast_field: &'a crate::structures::fast_field::FastFieldReader,
287 target_ordinal: u64,
288 current: u32,
289 num_docs: u32,
290}
291
292impl<'a> FastFieldTextScorer<'a> {
293 fn try_new(reader: &'a SegmentReader, field: Field, text: &str) -> Option<Self> {
294 let fast_field = reader.fast_field(field.0)?;
295 let target_ordinal = fast_field.text_ordinal(text)?;
296 let num_docs = reader.num_docs();
297 let mut scorer = Self {
298 fast_field,
299 target_ordinal,
300 current: 0,
301 num_docs,
302 };
303 if num_docs > 0 && fast_field.get_u64(0) != target_ordinal {
305 scorer.scan_forward();
306 }
307 Some(scorer)
308 }
309
310 fn scan_forward(&mut self) {
311 loop {
312 self.current += 1;
313 if self.current >= self.num_docs {
314 self.current = self.num_docs;
315 return;
316 }
317 if self.fast_field.get_u64(self.current) == self.target_ordinal {
318 return;
319 }
320 }
321 }
322}
323
324impl super::docset::DocSet for FastFieldTextScorer<'_> {
325 fn doc(&self) -> DocId {
326 if self.current >= self.num_docs {
327 TERMINATED
328 } else {
329 self.current
330 }
331 }
332
333 fn advance(&mut self) -> DocId {
334 self.scan_forward();
335 self.doc()
336 }
337
338 fn seek(&mut self, target: DocId) -> DocId {
339 if target > self.current {
340 self.current = target;
341 if self.current < self.num_docs
342 && self.fast_field.get_u64(self.current) != self.target_ordinal
343 {
344 self.scan_forward();
345 }
346 }
347 self.doc()
348 }
349
350 fn size_hint(&self) -> u32 {
351 0
352 }
353}
354
355impl Scorer for FastFieldTextScorer<'_> {
356 fn score(&self) -> Score {
357 1.0
358 }
359}
360
361impl Scorer for TermScorer {
362 fn score(&self) -> Score {
363 let tf = self.iterator.term_freq() as f32;
364 super::bm25f_score(tf, self.idf, tf, self.avg_field_len, self.field_boost)
367 }
368
369 fn matched_positions(&self) -> Option<super::MatchedPositions> {
370 let positions = self.positions.as_ref()?;
371 let doc_id = self.iterator.doc();
372 let pos = positions.get_positions(doc_id)?;
373 let score = self.score();
374 let per_position_score = if pos.is_empty() {
376 0.0
377 } else {
378 score / pos.len() as f32
379 };
380 let scored_positions: Vec<super::ScoredPosition> = pos
381 .iter()
382 .map(|&p| super::ScoredPosition::new(p, per_position_score))
383 .collect();
384 Some(vec![(self.field_id, scored_positions)])
385 }
386}