1use std::sync::Arc;
4
5use crate::dsl::Field;
6use crate::segment::SegmentReader;
7use crate::structures::BlockPostingList;
8use crate::{DocId, Score};
9
10use super::{CountFuture, EmptyScorer, GlobalStats, Query, Scorer, ScorerFuture, TermQueryInfo};
11
12#[derive(Clone)]
14pub struct TermQuery {
15 pub field: Field,
16 pub term: Vec<u8>,
17 global_stats: Option<Arc<GlobalStats>>,
19}
20
21impl std::fmt::Debug for TermQuery {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 f.debug_struct("TermQuery")
24 .field("field", &self.field)
25 .field("term", &String::from_utf8_lossy(&self.term))
26 .field("has_global_stats", &self.global_stats.is_some())
27 .finish()
28 }
29}
30
31impl TermQuery {
32 pub fn new(field: Field, term: impl Into<Vec<u8>>) -> Self {
33 Self {
34 field,
35 term: term.into(),
36 global_stats: None,
37 }
38 }
39
40 pub fn text(field: Field, text: &str) -> Self {
41 Self {
42 field,
43 term: text.to_lowercase().into_bytes(),
44 global_stats: None,
45 }
46 }
47
48 pub fn with_global_stats(field: Field, text: &str, stats: Arc<GlobalStats>) -> Self {
50 Self {
51 field,
52 term: text.to_lowercase().into_bytes(),
53 global_stats: Some(stats),
54 }
55 }
56
57 pub fn set_global_stats(&mut self, stats: Arc<GlobalStats>) {
59 self.global_stats = Some(stats);
60 }
61}
62
63impl Query for TermQuery {
64 fn scorer<'a>(
65 &self,
66 reader: &'a SegmentReader,
67 _limit: usize,
68 _predicate: Option<super::DocPredicate<'a>>,
69 ) -> ScorerFuture<'a> {
70 let field = self.field;
71 let term = self.term.clone();
72 let global_stats = self.global_stats.clone();
73 Box::pin(async move {
74 let postings = reader.get_postings(field, &term).await?;
75
76 match postings {
77 Some(posting_list) => {
78 let (idf, avg_field_len) = if let Some(ref stats) = global_stats {
80 let term_str = String::from_utf8_lossy(&term);
81 let global_idf = stats.text_idf(field, &term_str);
82
83 if global_idf > 0.0 {
86 (global_idf, stats.avg_field_len(field))
87 } else {
88 let num_docs = reader.num_docs() as f32;
90 let doc_freq = posting_list.doc_count() as f32;
91 let idf = super::bm25_idf(doc_freq, num_docs);
92 (idf, reader.avg_field_len(field))
93 }
94 } else {
95 let num_docs = reader.num_docs() as f32;
97 let doc_freq = posting_list.doc_count() as f32;
98 let idf = super::bm25_idf(doc_freq, num_docs);
99 (idf, reader.avg_field_len(field))
100 };
101
102 let positions = reader.get_positions(field, &term).await.ok().flatten();
104
105 let mut scorer = TermScorer::new(
106 posting_list,
107 idf,
108 avg_field_len,
109 1.0, );
111
112 if let Some(pos) = positions {
113 scorer = scorer.with_positions(field.0, pos);
114 }
115
116 Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
117 }
118 None => Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>),
119 }
120 })
121 }
122
123 fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
124 let field = self.field;
125 let term = self.term.clone();
126 Box::pin(async move {
127 match reader.get_postings(field, &term).await? {
128 Some(list) => Ok(list.doc_count()),
129 None => Ok(0),
130 }
131 })
132 }
133
134 #[cfg(feature = "sync")]
135 fn scorer_sync<'a>(
136 &self,
137 reader: &'a SegmentReader,
138 _limit: usize,
139 _predicate: Option<super::DocPredicate<'a>>,
140 ) -> crate::Result<Box<dyn Scorer + 'a>> {
141 let postings = reader.get_postings_sync(self.field, &self.term)?;
142
143 match postings {
144 Some(posting_list) => {
145 let (idf, avg_field_len) = if let Some(ref stats) = self.global_stats {
146 let term_str = String::from_utf8_lossy(&self.term);
147 let global_idf = stats.text_idf(self.field, &term_str);
148 if global_idf > 0.0 {
149 (global_idf, stats.avg_field_len(self.field))
150 } else {
151 let num_docs = reader.num_docs() as f32;
152 let doc_freq = posting_list.doc_count() as f32;
153 (
154 super::bm25_idf(doc_freq, num_docs),
155 reader.avg_field_len(self.field),
156 )
157 }
158 } else {
159 let num_docs = reader.num_docs() as f32;
160 let doc_freq = posting_list.doc_count() as f32;
161 (
162 super::bm25_idf(doc_freq, num_docs),
163 reader.avg_field_len(self.field),
164 )
165 };
166
167 let positions = reader
168 .get_positions_sync(self.field, &self.term)
169 .ok()
170 .flatten();
171
172 let mut scorer = TermScorer::new(posting_list, idf, avg_field_len, 1.0);
173 if let Some(pos) = positions {
174 scorer = scorer.with_positions(self.field.0, pos);
175 }
176
177 Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
178 }
179 None => Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>),
180 }
181 }
182
183 fn as_term_query_info(&self) -> Option<TermQueryInfo> {
184 Some(TermQueryInfo {
185 field: self.field,
186 term: self.term.clone(),
187 })
188 }
189}
190
191struct TermScorer {
192 iterator: crate::structures::BlockPostingIterator<'static>,
193 idf: f32,
194 avg_field_len: f32,
196 field_boost: f32,
198 field_id: u32,
200 positions: Option<crate::structures::PositionPostingList>,
202}
203
204impl TermScorer {
205 pub fn new(
206 posting_list: BlockPostingList,
207 idf: f32,
208 avg_field_len: f32,
209 field_boost: f32,
210 ) -> Self {
211 Self {
212 iterator: posting_list.into_iterator(),
213 idf,
214 avg_field_len,
215 field_boost,
216 field_id: 0,
217 positions: None,
218 }
219 }
220
221 pub fn with_positions(
222 mut self,
223 field_id: u32,
224 positions: crate::structures::PositionPostingList,
225 ) -> Self {
226 self.field_id = field_id;
227 self.positions = Some(positions);
228 self
229 }
230}
231
232impl Scorer for TermScorer {
233 fn doc(&self) -> DocId {
234 self.iterator.doc()
235 }
236
237 fn score(&self) -> Score {
238 let tf = self.iterator.term_freq() as f32;
239 super::bm25f_score(tf, self.idf, tf, self.avg_field_len, self.field_boost)
242 }
243
244 fn advance(&mut self) -> DocId {
245 self.iterator.advance()
246 }
247
248 fn seek(&mut self, target: DocId) -> DocId {
249 self.iterator.seek(target)
250 }
251
252 fn size_hint(&self) -> u32 {
253 0
254 }
255
256 fn matched_positions(&self) -> Option<super::MatchedPositions> {
257 let positions = self.positions.as_ref()?;
258 let doc_id = self.iterator.doc();
259 let pos = positions.get_positions(doc_id)?;
260 let score = self.score();
261 let per_position_score = if pos.is_empty() {
263 0.0
264 } else {
265 score / pos.len() as f32
266 };
267 let scored_positions: Vec<super::ScoredPosition> = pos
268 .iter()
269 .map(|&p| super::ScoredPosition::new(p, per_position_score))
270 .collect();
271 Some(vec![(self.field_id, scored_positions)])
272 }
273}