1use std::sync::Arc;
4
5use crate::dsl::Field;
6use crate::segment::SegmentReader;
7use crate::structures::{BlockPostingIterator, BlockPostingList, PositionPostingList, TERMINATED};
8use crate::{DocId, Score};
9
10use super::{CountFuture, EmptyScorer, GlobalStats, Query, Scorer, ScorerFuture};
11
12#[derive(Clone)]
17pub struct PhraseQuery {
18 pub field: Field,
19 pub terms: Vec<Vec<u8>>,
21 pub slop: u32,
23 global_stats: Option<Arc<GlobalStats>>,
25}
26
27impl std::fmt::Display for PhraseQuery {
28 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 let terms: Vec<_> = self
30 .terms
31 .iter()
32 .map(|t| String::from_utf8_lossy(t))
33 .collect();
34 write!(f, "Phrase({}:\"{}\"", self.field.0, terms.join(" "))?;
35 if self.slop > 0 {
36 write!(f, "~{}", self.slop)?;
37 }
38 write!(f, ")")
39 }
40}
41
42impl std::fmt::Debug for PhraseQuery {
43 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44 let terms: Vec<_> = self
45 .terms
46 .iter()
47 .map(|t| String::from_utf8_lossy(t).to_string())
48 .collect();
49 f.debug_struct("PhraseQuery")
50 .field("field", &self.field)
51 .field("terms", &terms)
52 .field("slop", &self.slop)
53 .finish()
54 }
55}
56
57impl PhraseQuery {
58 pub fn new(field: Field, terms: Vec<Vec<u8>>) -> Self {
60 Self {
61 field,
62 terms,
63 slop: 0,
64 global_stats: None,
65 }
66 }
67
68 pub fn text(field: Field, phrase: &str) -> Self {
70 let terms: Vec<Vec<u8>> = phrase
71 .split_whitespace()
72 .map(|w| w.to_lowercase().into_bytes())
73 .collect();
74 Self {
75 field,
76 terms,
77 slop: 0,
78 global_stats: None,
79 }
80 }
81
82 pub fn with_slop(mut self, slop: u32) -> Self {
84 self.slop = slop;
85 self
86 }
87
88 pub fn with_global_stats(mut self, stats: Arc<GlobalStats>) -> Self {
90 self.global_stats = Some(stats);
91 self
92 }
93}
94
95fn build_phrase_scorer<'a>(
97 term_data: Vec<(BlockPostingList, PositionPostingList)>,
98 slop: u32,
99 reader: &SegmentReader,
100 field: Field,
101) -> Box<dyn Scorer + 'a> {
102 let idf: f32 = term_data
103 .iter()
104 .map(|(p, _)| {
105 let num_docs = reader.num_docs() as f32;
106 let doc_freq = p.doc_count() as f32;
107 super::bm25_idf(doc_freq, num_docs)
108 })
109 .sum();
110 let avg_field_len = reader.avg_field_len(field);
111 let (postings, positions): (Vec<_>, Vec<_>) = term_data.into_iter().unzip();
112 Box::new(PhraseScorer::new(
113 postings,
114 positions,
115 slop,
116 idf,
117 avg_field_len,
118 ))
119}
120
/// Shared prologue of the async and sync scorer constructors. Expands to early
/// `return`s from the enclosing function/async block for the degenerate cases:
///
/// * no terms: an `EmptyScorer`;
/// * exactly one term: delegate to a `TermQuery` on that term;
/// * field without positions: delegate to a `BooleanQuery` requiring every term.
///
/// `$scorer_fn` names the method invoked on the delegate query. Any trailing
/// tokens (`$aw`) are appended to that call as `.token`; the async caller
/// passes `await`, the sync caller passes nothing.
macro_rules! phrase_early_returns {
    ($field:expr, $terms:expr, $reader:expr, $limit:expr,
     $scorer_fn:ident $(, $aw:tt)*) => {
        match $terms.len() {
            0 => return Ok(Box::new(EmptyScorer) as Box<dyn Scorer + '_>),
            1 => {
                let single = super::TermQuery::new($field, $terms[0].clone());
                return single.$scorer_fn($reader, $limit) $(. $aw)* ;
            }
            _ => {}
        }
        if !$reader.has_positions($field) {
            // Without positions the best available approximation is a pure
            // conjunction of the terms.
            let conjunction = $terms
                .iter()
                .fold(super::BooleanQuery::new(), |acc, term| {
                    acc.must(super::TermQuery::new($field, term.clone()))
                });
            return conjunction.$scorer_fn($reader, $limit) $(. $aw)* ;
        }
    };
}
144
145impl Query for PhraseQuery {
146 fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
147 let field = self.field;
148 let terms = self.terms.clone();
149 let slop = self.slop;
150
151 Box::pin(async move {
152 phrase_early_returns!(field, terms, reader, limit, scorer, await);
153
154 let mut term_data = Vec::with_capacity(terms.len());
156 for term in &terms {
157 let (postings, positions) = futures::join!(
158 reader.get_postings(field, term),
159 reader.get_positions(field, term)
160 );
161 match (postings?, positions?) {
162 (Some(p), Some(pos)) => term_data.push((p, pos)),
163 _ => return Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>),
164 }
165 }
166
167 Ok(build_phrase_scorer(term_data, slop, reader, field))
168 })
169 }
170
171 #[cfg(feature = "sync")]
172 fn scorer_sync<'a>(
173 &self,
174 reader: &'a SegmentReader,
175 limit: usize,
176 ) -> crate::Result<Box<dyn Scorer + 'a>> {
177 phrase_early_returns!(self.field, self.terms, reader, limit, scorer_sync);
178
179 use rayon::prelude::*;
181 let pairs: crate::Result<Vec<Option<(BlockPostingList, PositionPostingList)>>> = self
182 .terms
183 .par_iter()
184 .map(|term| {
185 let postings = reader.get_postings_sync(self.field, term)?;
186 let positions = reader.get_positions_sync(self.field, term)?;
187 Ok(match (postings, positions) {
188 (Some(p), Some(pos)) => Some((p, pos)),
189 _ => None,
190 })
191 })
192 .collect();
193 let mut term_data = Vec::with_capacity(self.terms.len());
194 for entry in pairs? {
195 match entry {
196 Some(pair) => term_data.push(pair),
197 None => return Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>),
198 }
199 }
200
201 Ok(build_phrase_scorer(
202 term_data, self.slop, reader, self.field,
203 ))
204 }
205
206 fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
207 let field = self.field;
208 let terms = self.terms.clone();
209
210 Box::pin(async move {
211 if terms.is_empty() {
212 return Ok(0);
213 }
214
215 let mut min_count = u32::MAX;
217 for term in &terms {
218 match reader.get_postings(field, term).await? {
219 Some(list) => min_count = min_count.min(list.doc_count()),
220 None => return Ok(0),
221 }
222 }
223
224 Ok((min_count / 10).max(1))
227 })
228 }
229}
230
231struct PhraseScorer {
233 posting_iters: Vec<BlockPostingIterator<'static>>,
235 position_lists: Vec<PositionPostingList>,
237 slop: u32,
239 current_doc: DocId,
241 idf: f32,
243 avg_field_len: f32,
245}
246
247impl PhraseScorer {
248 fn new(
249 posting_lists: Vec<BlockPostingList>,
250 position_lists: Vec<PositionPostingList>,
251 slop: u32,
252 idf: f32,
253 avg_field_len: f32,
254 ) -> Self {
255 let posting_iters: Vec<_> = posting_lists
256 .into_iter()
257 .map(|p| p.into_iterator())
258 .collect();
259
260 let mut scorer = Self {
261 posting_iters,
262 position_lists,
263 slop,
264 current_doc: 0,
265 idf,
266 avg_field_len,
267 };
268
269 scorer.find_next_phrase_match();
270 scorer
271 }
272
273 fn find_next_phrase_match(&mut self) {
275 loop {
276 let doc = self.find_next_and_match();
278 if doc == TERMINATED {
279 self.current_doc = TERMINATED;
280 return;
281 }
282
283 if self.check_phrase_positions(doc) {
285 self.current_doc = doc;
286 return;
287 }
288
289 self.posting_iters[0].advance();
291 }
292 }
293
294 fn find_next_and_match(&mut self) -> DocId {
296 if self.posting_iters.is_empty() {
297 return TERMINATED;
298 }
299
300 loop {
301 let max_doc = self.posting_iters.iter().map(|it| it.doc()).max().unwrap();
302
303 if max_doc == TERMINATED {
304 return TERMINATED;
305 }
306
307 let mut all_match = true;
308 for it in &mut self.posting_iters {
309 let doc = it.seek(max_doc);
310 if doc != max_doc {
311 all_match = false;
312 if doc == TERMINATED {
313 return TERMINATED;
314 }
315 }
316 }
317
318 if all_match {
319 return max_doc;
320 }
321 }
322 }
323
324 fn check_phrase_positions(&self, doc_id: DocId) -> bool {
326 let mut term_positions: Vec<Vec<u32>> = Vec::with_capacity(self.position_lists.len());
328
329 for pos_list in &self.position_lists {
330 match pos_list.get_positions(doc_id) {
331 Some(positions) => term_positions.push(positions.to_vec()),
332 None => return false,
333 }
334 }
335
336 self.find_phrase_match(&term_positions)
339 }
340
341 fn find_phrase_match(&self, term_positions: &[Vec<u32>]) -> bool {
343 if term_positions.is_empty() {
344 return false;
345 }
346
347 for &first_pos in &term_positions[0] {
350 if self.check_phrase_from_position(first_pos, term_positions) {
351 return true;
352 }
353 }
354
355 false
356 }
357
358 fn check_phrase_from_position(&self, start_pos: u32, term_positions: &[Vec<u32>]) -> bool {
360 let mut expected_pos = start_pos;
361
362 for (i, positions) in term_positions.iter().enumerate() {
363 if i == 0 {
364 continue; }
366
367 expected_pos += 1;
368
369 let found = positions.iter().any(|&pos| {
371 if self.slop == 0 {
372 pos == expected_pos
373 } else {
374 let diff = pos.abs_diff(expected_pos);
375 diff <= self.slop
376 }
377 });
378
379 if !found {
380 return false;
381 }
382 }
383
384 true
385 }
386}
387
388impl super::docset::DocSet for PhraseScorer {
389 fn doc(&self) -> DocId {
390 self.current_doc
391 }
392
393 fn advance(&mut self) -> DocId {
394 if self.current_doc == TERMINATED {
395 return TERMINATED;
396 }
397
398 self.posting_iters[0].advance();
399 self.find_next_phrase_match();
400 self.current_doc
401 }
402
403 fn seek(&mut self, target: DocId) -> DocId {
404 if target == TERMINATED {
405 self.current_doc = TERMINATED;
406 return TERMINATED;
407 }
408
409 self.posting_iters[0].seek(target);
410 self.find_next_phrase_match();
411 self.current_doc
412 }
413
414 fn size_hint(&self) -> u32 {
415 0
416 }
417}
418
419impl Scorer for PhraseScorer {
420 fn score(&self) -> Score {
421 if self.current_doc == TERMINATED {
422 return 0.0;
423 }
424
425 let tf: f32 = self
427 .posting_iters
428 .iter()
429 .map(|it| it.term_freq() as f32)
430 .sum();
431
432 super::bm25_score(tf, self.idf, tf, self.avg_field_len) * 1.5
434 }
435}