1use std::sync::Arc;
4
5use crate::dsl::Field;
6use crate::segment::SegmentReader;
7use crate::structures::BlockPostingList;
8use crate::structures::TERMINATED;
9use crate::{DocId, Score};
10
11use super::{CountFuture, EmptyScorer, GlobalStats, Query, Scorer, ScorerFuture, TermQueryInfo};
12
13#[derive(Clone)]
15pub struct TermQuery {
16 pub field: Field,
17 pub term: Vec<u8>,
18 global_stats: Option<Arc<GlobalStats>>,
20}
21
22impl std::fmt::Debug for TermQuery {
23 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24 f.debug_struct("TermQuery")
25 .field("field", &self.field)
26 .field("term", &String::from_utf8_lossy(&self.term))
27 .field("has_global_stats", &self.global_stats.is_some())
28 .finish()
29 }
30}
31
32impl std::fmt::Display for TermQuery {
33 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34 write!(
35 f,
36 "Term({}:\"{}\")",
37 self.field.0,
38 String::from_utf8_lossy(&self.term)
39 )
40 }
41}
42
43impl TermQuery {
44 pub fn new(field: Field, term: impl Into<Vec<u8>>) -> Self {
45 Self {
46 field,
47 term: term.into(),
48 global_stats: None,
49 }
50 }
51
52 pub fn text(field: Field, text: &str) -> Self {
53 Self {
54 field,
55 term: text.to_lowercase().into_bytes(),
56 global_stats: None,
57 }
58 }
59
60 pub fn with_global_stats(field: Field, text: &str, stats: Arc<GlobalStats>) -> Self {
62 Self {
63 field,
64 term: text.to_lowercase().into_bytes(),
65 global_stats: Some(stats),
66 }
67 }
68
69 pub fn set_global_stats(&mut self, stats: Arc<GlobalStats>) {
71 self.global_stats = Some(stats);
72 }
73}
74
75fn compute_term_idf(
77 posting_list: &BlockPostingList,
78 field: Field,
79 reader: &SegmentReader,
80 global_stats: Option<&Arc<GlobalStats>>,
81 term: &[u8],
82) -> (f32, f32) {
83 if let Some(stats) = global_stats {
84 let term_str = String::from_utf8_lossy(term);
85 let global_idf = stats.text_idf(field, &term_str);
86 if global_idf > 0.0 {
87 return (global_idf, stats.avg_field_len(field));
88 }
89 }
90 let num_docs = reader.num_docs() as f32;
91 let doc_freq = posting_list.doc_count() as f32;
92 (
93 super::bm25_idf(doc_freq, num_docs),
94 reader.avg_field_len(field),
95 )
96}
97
98macro_rules! term_plan {
105 ($field:expr, $term:expr, $global_stats:expr, $reader:expr,
106 $get_postings_fn:ident, $get_positions_fn:ident
107 $(, $aw:tt)*) => {{
108 let field: Field = $field;
109 let term: &[u8] = $term;
110 let global_stats: Option<&Arc<GlobalStats>> = $global_stats;
111 let reader: &SegmentReader = $reader;
112
113 let is_indexed = reader.schema().get_field_entry(field).is_none_or(|e| e.indexed);
115 if !is_indexed {
116 let term_str = String::from_utf8_lossy(term);
117 if let Some(scorer) = FastFieldTextScorer::try_new(reader, field, &term_str) {
118 return Ok(Box::new(scorer) as Box<dyn Scorer + '_>);
119 }
120 return Ok(Box::new(EmptyScorer) as Box<dyn Scorer + '_>);
121 }
122
123 let postings = reader.$get_postings_fn(field, term) $(. $aw)* ?;
124
125 match postings {
126 Some(posting_list) => {
127 let (idf, avg_field_len) =
128 compute_term_idf(&posting_list, field, reader, global_stats, term);
129
130 let positions = reader.$get_positions_fn(field, term)
131 $(. $aw)* .ok().flatten();
132
133 let mut scorer = TermScorer::new(posting_list, idf, avg_field_len, 1.0);
134 if let Some(pos) = positions {
135 scorer = scorer.with_positions(field.0, pos);
136 }
137 Ok(Box::new(scorer) as Box<dyn Scorer + '_>)
138 }
139 None => {
140 let term_str = String::from_utf8_lossy(term);
141 if let Some(scorer) = FastFieldTextScorer::try_new(reader, field, &term_str) {
142 Ok(Box::new(scorer) as Box<dyn Scorer + '_>)
143 } else {
144 Ok(Box::new(EmptyScorer) as Box<dyn Scorer + '_>)
145 }
146 }
147 }
148 }};
149}
150
151impl Query for TermQuery {
152 fn scorer<'a>(&self, reader: &'a SegmentReader, _limit: usize) -> ScorerFuture<'a> {
153 let field = self.field;
154 let term = self.term.clone();
155 let global_stats = self.global_stats.clone();
156 Box::pin(async move {
157 term_plan!(
158 field,
159 &term,
160 global_stats.as_ref(),
161 reader,
162 get_postings,
163 get_positions,
164 await
165 )
166 })
167 }
168
169 fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
170 let field = self.field;
171 let term = self.term.clone();
172 Box::pin(async move {
173 match reader.get_postings(field, &term).await? {
174 Some(list) => Ok(list.doc_count()),
175 None => Ok(0),
176 }
177 })
178 }
179
180 #[cfg(feature = "sync")]
181 fn scorer_sync<'a>(
182 &self,
183 reader: &'a SegmentReader,
184 _limit: usize,
185 ) -> crate::Result<Box<dyn Scorer + 'a>> {
186 term_plan!(
187 self.field,
188 &self.term,
189 self.global_stats.as_ref(),
190 reader,
191 get_postings_sync,
192 get_positions_sync
193 )
194 }
195
196 fn as_doc_predicate<'a>(&self, reader: &'a SegmentReader) -> Option<super::DocPredicate<'a>> {
197 let fast_field = reader.fast_field(self.field.0)?;
198 let term_str = String::from_utf8_lossy(&self.term);
199 match fast_field.text_ordinal(&term_str) {
200 Some(target_ordinal) => Some(Box::new(move |doc_id: DocId| -> bool {
201 fast_field.get_u64(doc_id) == target_ordinal
202 })),
203 None => Some(Box::new(|_| false)),
205 }
206 }
207
208 fn decompose(&self) -> super::QueryDecomposition {
209 super::QueryDecomposition::TextTerm(TermQueryInfo {
210 field: self.field,
211 term: self.term.clone(),
212 })
213 }
214}
215
216struct TermScorer {
217 iterator: crate::structures::BlockPostingIterator<'static>,
218 idf: f32,
219 avg_field_len: f32,
221 field_boost: f32,
223 field_id: u32,
225 positions: Option<crate::structures::PositionPostingList>,
227}
228
229impl TermScorer {
230 pub fn new(
231 posting_list: BlockPostingList,
232 idf: f32,
233 avg_field_len: f32,
234 field_boost: f32,
235 ) -> Self {
236 Self {
237 iterator: posting_list.into_iterator(),
238 idf,
239 avg_field_len,
240 field_boost,
241 field_id: 0,
242 positions: None,
243 }
244 }
245
246 pub fn with_positions(
247 mut self,
248 field_id: u32,
249 positions: crate::structures::PositionPostingList,
250 ) -> Self {
251 self.field_id = field_id;
252 self.positions = Some(positions);
253 self
254 }
255}
256
257impl super::docset::DocSet for TermScorer {
258 fn doc(&self) -> DocId {
259 self.iterator.doc()
260 }
261
262 fn advance(&mut self) -> DocId {
263 self.iterator.advance()
264 }
265
266 fn seek(&mut self, target: DocId) -> DocId {
267 self.iterator.seek(target)
268 }
269
270 fn size_hint(&self) -> u32 {
271 0
272 }
273}
274
275struct FastFieldTextScorer<'a> {
281 fast_field: &'a crate::structures::fast_field::FastFieldReader,
282 target_ordinal: u64,
283 current: u32,
284 num_docs: u32,
285}
286
287impl<'a> FastFieldTextScorer<'a> {
288 fn try_new(reader: &'a SegmentReader, field: Field, text: &str) -> Option<Self> {
289 let fast_field = reader.fast_field(field.0)?;
290 let target_ordinal = fast_field.text_ordinal(text)?;
291 let num_docs = reader.num_docs();
292 let mut scorer = Self {
293 fast_field,
294 target_ordinal,
295 current: 0,
296 num_docs,
297 };
298 if num_docs > 0 && fast_field.get_u64(0) != target_ordinal {
300 scorer.scan_forward();
301 }
302 Some(scorer)
303 }
304
305 fn scan_forward(&mut self) {
306 loop {
307 self.current += 1;
308 if self.current >= self.num_docs {
309 self.current = self.num_docs;
310 return;
311 }
312 if self.fast_field.get_u64(self.current) == self.target_ordinal {
313 return;
314 }
315 }
316 }
317}
318
319impl super::docset::DocSet for FastFieldTextScorer<'_> {
320 fn doc(&self) -> DocId {
321 if self.current >= self.num_docs {
322 TERMINATED
323 } else {
324 self.current
325 }
326 }
327
328 fn advance(&mut self) -> DocId {
329 self.scan_forward();
330 self.doc()
331 }
332
333 fn seek(&mut self, target: DocId) -> DocId {
334 if target > self.current {
335 self.current = target;
336 if self.current < self.num_docs
337 && self.fast_field.get_u64(self.current) != self.target_ordinal
338 {
339 self.scan_forward();
340 }
341 }
342 self.doc()
343 }
344
345 fn size_hint(&self) -> u32 {
346 0
347 }
348}
349
350impl Scorer for FastFieldTextScorer<'_> {
351 fn score(&self) -> Score {
352 1.0
353 }
354}
355
356impl Scorer for TermScorer {
357 fn score(&self) -> Score {
358 let tf = self.iterator.term_freq() as f32;
359 super::bm25f_score(tf, self.idf, tf, self.avg_field_len, self.field_boost)
362 }
363
364 fn matched_positions(&self) -> Option<super::MatchedPositions> {
365 let positions = self.positions.as_ref()?;
366 let doc_id = self.iterator.doc();
367 let pos = positions.get_positions(doc_id)?;
368 let score = self.score();
369 let per_position_score = if pos.is_empty() {
371 0.0
372 } else {
373 score / pos.len() as f32
374 };
375 let scored_positions: Vec<super::ScoredPosition> = pos
376 .iter()
377 .map(|&p| super::ScoredPosition::new(p, per_position_score))
378 .collect();
379 Some(vec![(self.field_id, scored_positions)])
380 }
381}