1use std::sync::Arc;
4
5use crate::dsl::Field;
6use crate::segment::SegmentReader;
7use crate::structures::{BlockPostingIterator, BlockPostingList, PositionPostingList, TERMINATED};
8use crate::{DocId, Score};
9
10use super::{CountFuture, EmptyScorer, GlobalStats, Query, Scorer, ScorerFuture};
11
12#[derive(Clone)]
17pub struct PhraseQuery {
18 pub field: Field,
19 pub terms: Vec<Vec<u8>>,
21 pub slop: u32,
23 global_stats: Option<Arc<GlobalStats>>,
25}
26
27impl std::fmt::Display for PhraseQuery {
28 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 let terms: Vec<_> = self
30 .terms
31 .iter()
32 .map(|t| String::from_utf8_lossy(t))
33 .collect();
34 write!(f, "Phrase({}:\"{}\"", self.field.0, terms.join(" "))?;
35 if self.slop > 0 {
36 write!(f, "~{}", self.slop)?;
37 }
38 write!(f, ")")
39 }
40}
41
42impl std::fmt::Debug for PhraseQuery {
43 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44 let terms: Vec<_> = self
45 .terms
46 .iter()
47 .map(|t| String::from_utf8_lossy(t).to_string())
48 .collect();
49 f.debug_struct("PhraseQuery")
50 .field("field", &self.field)
51 .field("terms", &terms)
52 .field("slop", &self.slop)
53 .finish()
54 }
55}
56
57impl PhraseQuery {
58 pub fn new(field: Field, terms: Vec<Vec<u8>>) -> Self {
60 Self {
61 field,
62 terms,
63 slop: 0,
64 global_stats: None,
65 }
66 }
67
68 pub fn text(field: Field, phrase: &str) -> Self {
70 let terms: Vec<Vec<u8>> = phrase
71 .split_whitespace()
72 .map(|w| w.to_lowercase().into_bytes())
73 .collect();
74 Self {
75 field,
76 terms,
77 slop: 0,
78 global_stats: None,
79 }
80 }
81
82 pub fn with_slop(mut self, slop: u32) -> Self {
84 self.slop = slop;
85 self
86 }
87
88 pub fn with_global_stats(mut self, stats: Arc<GlobalStats>) -> Self {
90 self.global_stats = Some(stats);
91 self
92 }
93}
94
95impl Query for PhraseQuery {
96 fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
97 let field = self.field;
98 let terms = self.terms.clone();
99 let slop = self.slop;
100 let _global_stats = self.global_stats.clone();
101
102 Box::pin(async move {
103 if terms.is_empty() {
104 return Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>);
105 }
106
107 if terms.len() == 1 {
109 let term_query = super::TermQuery::new(field, terms[0].clone());
110 return term_query.scorer(reader, limit).await;
111 }
112
113 if !reader.has_positions(field) {
115 let mut bool_query = super::BooleanQuery::new();
117 for term in &terms {
118 bool_query = bool_query.must(super::TermQuery::new(field, term.clone()));
119 }
120 return bool_query.scorer(reader, limit).await;
121 }
122
123 let mut term_postings: Vec<BlockPostingList> = Vec::with_capacity(terms.len());
125 let mut term_positions: Vec<PositionPostingList> = Vec::with_capacity(terms.len());
126
127 for term in &terms {
128 let (postings, positions) = futures::join!(
130 reader.get_postings(field, term),
131 reader.get_positions(field, term)
132 );
133
134 match (postings?, positions?) {
135 (Some(p), Some(pos)) => {
136 term_postings.push(p);
137 term_positions.push(pos);
138 }
139 _ => {
140 return Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>);
142 }
143 }
144 }
145
146 let idf: f32 = term_postings
148 .iter()
149 .map(|p| {
150 let num_docs = reader.num_docs() as f32;
151 let doc_freq = p.doc_count() as f32;
152 super::bm25_idf(doc_freq, num_docs)
153 })
154 .sum();
155
156 let avg_field_len = reader.avg_field_len(field);
157
158 Ok(Box::new(PhraseScorer::new(
159 term_postings,
160 term_positions,
161 slop,
162 idf,
163 avg_field_len,
164 )) as Box<dyn Scorer + 'a>)
165 })
166 }
167
168 #[cfg(feature = "sync")]
169 fn scorer_sync<'a>(
170 &self,
171 reader: &'a SegmentReader,
172 limit: usize,
173 ) -> crate::Result<Box<dyn Scorer + 'a>> {
174 if self.terms.is_empty() {
175 return Ok(Box::new(super::EmptyScorer) as Box<dyn Scorer + 'a>);
176 }
177
178 if self.terms.len() == 1 {
179 let term_query = super::TermQuery::new(self.field, self.terms[0].clone());
180 return term_query.scorer_sync(reader, limit);
181 }
182
183 if !reader.has_positions(self.field) {
184 let mut bool_query = super::BooleanQuery::new();
185 for term in &self.terms {
186 bool_query = bool_query.must(super::TermQuery::new(self.field, term.clone()));
187 }
188 return bool_query.scorer_sync(reader, limit);
189 }
190
191 let mut term_postings: Vec<BlockPostingList> = Vec::with_capacity(self.terms.len());
192 let mut term_positions: Vec<PositionPostingList> = Vec::with_capacity(self.terms.len());
193
194 for term in &self.terms {
195 let postings = reader.get_postings_sync(self.field, term)?;
196 let positions = reader.get_positions_sync(self.field, term)?;
197
198 match (postings, positions) {
199 (Some(p), Some(pos)) => {
200 term_postings.push(p);
201 term_positions.push(pos);
202 }
203 _ => return Ok(Box::new(super::EmptyScorer) as Box<dyn Scorer + 'a>),
204 }
205 }
206
207 let idf: f32 = term_postings
208 .iter()
209 .map(|p| {
210 let num_docs = reader.num_docs() as f32;
211 let doc_freq = p.doc_count() as f32;
212 super::bm25_idf(doc_freq, num_docs)
213 })
214 .sum();
215
216 let avg_field_len = reader.avg_field_len(self.field);
217
218 Ok(Box::new(PhraseScorer::new(
219 term_postings,
220 term_positions,
221 self.slop,
222 idf,
223 avg_field_len,
224 )) as Box<dyn Scorer + 'a>)
225 }
226
227 fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
228 let field = self.field;
229 let terms = self.terms.clone();
230
231 Box::pin(async move {
232 if terms.is_empty() {
233 return Ok(0);
234 }
235
236 let mut min_count = u32::MAX;
238 for term in &terms {
239 match reader.get_postings(field, term).await? {
240 Some(list) => min_count = min_count.min(list.doc_count()),
241 None => return Ok(0),
242 }
243 }
244
245 Ok((min_count / 10).max(1))
248 })
249 }
250}
251
252struct PhraseScorer {
254 posting_iters: Vec<BlockPostingIterator<'static>>,
256 position_lists: Vec<PositionPostingList>,
258 slop: u32,
260 current_doc: DocId,
262 idf: f32,
264 avg_field_len: f32,
266}
267
268impl PhraseScorer {
269 fn new(
270 posting_lists: Vec<BlockPostingList>,
271 position_lists: Vec<PositionPostingList>,
272 slop: u32,
273 idf: f32,
274 avg_field_len: f32,
275 ) -> Self {
276 let posting_iters: Vec<_> = posting_lists
277 .into_iter()
278 .map(|p| p.into_iterator())
279 .collect();
280
281 let mut scorer = Self {
282 posting_iters,
283 position_lists,
284 slop,
285 current_doc: 0,
286 idf,
287 avg_field_len,
288 };
289
290 scorer.find_next_phrase_match();
291 scorer
292 }
293
294 fn find_next_phrase_match(&mut self) {
296 loop {
297 let doc = self.find_next_and_match();
299 if doc == TERMINATED {
300 self.current_doc = TERMINATED;
301 return;
302 }
303
304 if self.check_phrase_positions(doc) {
306 self.current_doc = doc;
307 return;
308 }
309
310 self.posting_iters[0].advance();
312 }
313 }
314
315 fn find_next_and_match(&mut self) -> DocId {
317 if self.posting_iters.is_empty() {
318 return TERMINATED;
319 }
320
321 loop {
322 let max_doc = self.posting_iters.iter().map(|it| it.doc()).max().unwrap();
323
324 if max_doc == TERMINATED {
325 return TERMINATED;
326 }
327
328 let mut all_match = true;
329 for it in &mut self.posting_iters {
330 let doc = it.seek(max_doc);
331 if doc != max_doc {
332 all_match = false;
333 if doc == TERMINATED {
334 return TERMINATED;
335 }
336 }
337 }
338
339 if all_match {
340 return max_doc;
341 }
342 }
343 }
344
345 fn check_phrase_positions(&self, doc_id: DocId) -> bool {
347 let mut term_positions: Vec<Vec<u32>> = Vec::with_capacity(self.position_lists.len());
349
350 for pos_list in &self.position_lists {
351 match pos_list.get_positions(doc_id) {
352 Some(positions) => term_positions.push(positions.to_vec()),
353 None => return false,
354 }
355 }
356
357 self.find_phrase_match(&term_positions)
360 }
361
362 fn find_phrase_match(&self, term_positions: &[Vec<u32>]) -> bool {
364 if term_positions.is_empty() {
365 return false;
366 }
367
368 for &first_pos in &term_positions[0] {
371 if self.check_phrase_from_position(first_pos, term_positions) {
372 return true;
373 }
374 }
375
376 false
377 }
378
379 fn check_phrase_from_position(&self, start_pos: u32, term_positions: &[Vec<u32>]) -> bool {
381 let mut expected_pos = start_pos;
382
383 for (i, positions) in term_positions.iter().enumerate() {
384 if i == 0 {
385 continue; }
387
388 expected_pos += 1;
389
390 let found = positions.iter().any(|&pos| {
392 if self.slop == 0 {
393 pos == expected_pos
394 } else {
395 let diff = pos.abs_diff(expected_pos);
396 diff <= self.slop
397 }
398 });
399
400 if !found {
401 return false;
402 }
403 }
404
405 true
406 }
407}
408
409impl super::docset::DocSet for PhraseScorer {
410 fn doc(&self) -> DocId {
411 self.current_doc
412 }
413
414 fn advance(&mut self) -> DocId {
415 if self.current_doc == TERMINATED {
416 return TERMINATED;
417 }
418
419 self.posting_iters[0].advance();
420 self.find_next_phrase_match();
421 self.current_doc
422 }
423
424 fn seek(&mut self, target: DocId) -> DocId {
425 if target == TERMINATED {
426 self.current_doc = TERMINATED;
427 return TERMINATED;
428 }
429
430 self.posting_iters[0].seek(target);
431 self.find_next_phrase_match();
432 self.current_doc
433 }
434
435 fn size_hint(&self) -> u32 {
436 0
437 }
438}
439
440impl Scorer for PhraseScorer {
441 fn score(&self) -> Score {
442 if self.current_doc == TERMINATED {
443 return 0.0;
444 }
445
446 let tf: f32 = self
448 .posting_iters
449 .iter()
450 .map(|it| it.term_freq() as f32)
451 .sum();
452
453 super::bm25_score(tf, self.idf, tf, self.avg_field_len) * 1.5
455 }
456}