1use std::sync::Arc;
4
5use crate::dsl::Field;
6use crate::segment::SegmentReader;
7use crate::structures::{BlockPostingIterator, BlockPostingList, PositionPostingList, TERMINATED};
8use crate::{DocId, Score};
9
10use super::{CountFuture, EmptyScorer, GlobalStats, Query, Scorer, ScorerFuture};
11
12#[derive(Clone)]
17pub struct PhraseQuery {
18 pub field: Field,
19 pub terms: Vec<Vec<u8>>,
21 pub slop: u32,
23 global_stats: Option<Arc<GlobalStats>>,
25}
26
27impl std::fmt::Display for PhraseQuery {
28 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 let terms: Vec<_> = self
30 .terms
31 .iter()
32 .map(|t| String::from_utf8_lossy(t))
33 .collect();
34 write!(f, "Phrase({}:\"{}\"", self.field.0, terms.join(" "))?;
35 if self.slop > 0 {
36 write!(f, "~{}", self.slop)?;
37 }
38 write!(f, ")")
39 }
40}
41
42impl std::fmt::Debug for PhraseQuery {
43 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44 let terms: Vec<_> = self
45 .terms
46 .iter()
47 .map(|t| String::from_utf8_lossy(t).to_string())
48 .collect();
49 f.debug_struct("PhraseQuery")
50 .field("field", &self.field)
51 .field("terms", &terms)
52 .field("slop", &self.slop)
53 .finish()
54 }
55}
56
57impl PhraseQuery {
58 pub fn new(field: Field, terms: Vec<Vec<u8>>) -> Self {
60 Self {
61 field,
62 terms,
63 slop: 0,
64 global_stats: None,
65 }
66 }
67
68 pub fn text(field: Field, phrase: &str) -> Self {
70 let terms: Vec<Vec<u8>> = phrase
71 .split_whitespace()
72 .map(|w| w.to_lowercase().into_bytes())
73 .collect();
74 Self {
75 field,
76 terms,
77 slop: 0,
78 global_stats: None,
79 }
80 }
81
82 pub fn with_slop(mut self, slop: u32) -> Self {
84 self.slop = slop;
85 self
86 }
87
88 pub fn with_global_stats(mut self, stats: Arc<GlobalStats>) -> Self {
90 self.global_stats = Some(stats);
91 self
92 }
93}
94
95fn build_phrase_scorer<'a>(
97 term_data: Vec<(BlockPostingList, PositionPostingList)>,
98 slop: u32,
99 reader: &SegmentReader,
100 field: Field,
101) -> Box<dyn Scorer + 'a> {
102 let idf: f32 = term_data
103 .iter()
104 .map(|(p, _)| {
105 let num_docs = reader.num_docs() as f32;
106 let doc_freq = p.doc_count() as f32;
107 super::bm25_idf(doc_freq, num_docs)
108 })
109 .sum();
110 let avg_field_len = reader.avg_field_len(field);
111 let (postings, positions): (Vec<_>, Vec<_>) = term_data.into_iter().unzip();
112 Box::new(PhraseScorer::new(
113 postings,
114 positions,
115 slop,
116 idf,
117 avg_field_len,
118 ))
119}
120
/// Emits the shortcut paths shared by the scorer methods of `PhraseQuery`.
///
/// Expanded inside a function body, it `return`s early when:
/// * there are no terms - an `EmptyScorer` is returned;
/// * there is exactly one term - the call is delegated to a `TermQuery`;
/// * the field has no positions - the call is delegated to a `BooleanQuery`
///   that requires (`must`) every term.
///
/// `$scorer_fn` names the method invoked on the delegate query. Each extra
/// token after it is appended as `.token` to that call (e.g. `await`).
macro_rules! phrase_early_returns {
    ($field:expr, $terms:expr, $reader:expr, $limit:expr,
     $scorer_fn:ident $(, $suffix:tt)*) => {
        match $terms.len() {
            0 => return Ok(Box::new(EmptyScorer) as Box<dyn Scorer + '_>),
            1 => {
                let single = super::TermQuery::new($field, $terms[0].clone());
                return single.$scorer_fn($reader, $limit) $(. $suffix)* ;
            }
            _ => {}
        }
        if !$reader.has_positions($field) {
            // Without positions the best available answer is "all terms present".
            let conjunction = $terms
                .iter()
                .fold(super::BooleanQuery::new(), |acc, term| {
                    acc.must(super::TermQuery::new($field, term.clone()))
                });
            return conjunction.$scorer_fn($reader, $limit) $(. $suffix)* ;
        }
    };
}
144
145impl Query for PhraseQuery {
146 fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
147 let field = self.field;
148 let terms = self.terms.clone();
149 let slop = self.slop;
150
151 Box::pin(async move {
152 phrase_early_returns!(field, terms, reader, limit, scorer, await);
153
154 let mut term_data = Vec::with_capacity(terms.len());
156 for term in &terms {
157 let (postings, positions) = futures::join!(
158 reader.get_postings(field, term),
159 reader.get_positions(field, term)
160 );
161 match (postings?, positions?) {
162 (Some(p), Some(pos)) => term_data.push((p, pos)),
163 _ => return Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>),
164 }
165 }
166
167 Ok(build_phrase_scorer(term_data, slop, reader, field))
168 })
169 }
170
171 #[cfg(feature = "sync")]
172 fn scorer_sync<'a>(
173 &self,
174 reader: &'a SegmentReader,
175 limit: usize,
176 ) -> crate::Result<Box<dyn Scorer + 'a>> {
177 phrase_early_returns!(self.field, self.terms, reader, limit, scorer_sync);
178
179 use rayon::prelude::*;
181 let pairs: crate::Result<Vec<Option<(BlockPostingList, PositionPostingList)>>> = self
182 .terms
183 .par_iter()
184 .map(|term| {
185 let postings = reader.get_postings_sync(self.field, term)?;
186 let positions = reader.get_positions_sync(self.field, term)?;
187 Ok(match (postings, positions) {
188 (Some(p), Some(pos)) => Some((p, pos)),
189 _ => None,
190 })
191 })
192 .collect();
193 let mut term_data = Vec::with_capacity(self.terms.len());
194 for entry in pairs? {
195 match entry {
196 Some(pair) => term_data.push(pair),
197 None => return Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>),
198 }
199 }
200
201 Ok(build_phrase_scorer(
202 term_data, self.slop, reader, self.field,
203 ))
204 }
205
206 fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
207 let field = self.field;
208 let terms = self.terms.clone();
209
210 Box::pin(async move {
211 if terms.is_empty() {
212 return Ok(0);
213 }
214
215 let mut min_count = u32::MAX;
217 for term in &terms {
218 match reader.get_postings(field, term).await? {
219 Some(list) => min_count = min_count.min(list.doc_count()),
220 None => return Ok(0),
221 }
222 }
223
224 Ok((min_count / 10).max(1))
227 })
228 }
229}
230
231struct PhraseScorer {
233 posting_iters: Vec<BlockPostingIterator<'static>>,
235 position_lists: Vec<PositionPostingList>,
237 slop: u32,
239 current_doc: DocId,
241 idf: f32,
243 avg_field_len: f32,
245 position_bufs: Vec<Vec<u32>>,
247}
248
249impl PhraseScorer {
250 fn new(
251 posting_lists: Vec<BlockPostingList>,
252 position_lists: Vec<PositionPostingList>,
253 slop: u32,
254 idf: f32,
255 avg_field_len: f32,
256 ) -> Self {
257 let posting_iters: Vec<_> = posting_lists
258 .into_iter()
259 .map(|p| p.into_iterator())
260 .collect();
261
262 let num_terms = position_lists.len();
263 let mut scorer = Self {
264 posting_iters,
265 position_lists,
266 slop,
267 current_doc: 0,
268 idf,
269 avg_field_len,
270 position_bufs: (0..num_terms).map(|_| Vec::new()).collect(),
271 };
272
273 scorer.find_next_phrase_match();
274 scorer
275 }
276
277 fn find_next_phrase_match(&mut self) {
279 loop {
280 let doc = self.find_next_and_match();
282 if doc == TERMINATED {
283 self.current_doc = TERMINATED;
284 return;
285 }
286
287 if self.check_phrase_positions(doc) {
289 self.current_doc = doc;
290 return;
291 }
292
293 self.posting_iters[0].advance();
295 }
296 }
297
298 fn find_next_and_match(&mut self) -> DocId {
300 if self.posting_iters.is_empty() {
301 return TERMINATED;
302 }
303
304 loop {
305 let max_doc = self.posting_iters.iter().map(|it| it.doc()).max().unwrap();
306
307 if max_doc == TERMINATED {
308 return TERMINATED;
309 }
310
311 let mut all_match = true;
312 for it in &mut self.posting_iters {
313 let doc = it.seek(max_doc);
314 if doc != max_doc {
315 all_match = false;
316 if doc == TERMINATED {
317 return TERMINATED;
318 }
319 }
320 }
321
322 if all_match {
323 return max_doc;
324 }
325 }
326 }
327
328 fn check_phrase_positions(&mut self, doc_id: DocId) -> bool {
330 for (i, pos_list) in self.position_lists.iter().enumerate() {
332 if !pos_list.get_positions_into(doc_id, &mut self.position_bufs[i]) {
333 return false;
334 }
335 }
336
337 self.find_phrase_match_from_bufs()
340 }
341
342 fn find_phrase_match_from_bufs(&self) -> bool {
344 if self.position_bufs.is_empty() || self.position_bufs[0].is_empty() {
345 return false;
346 }
347
348 for &first_pos in &self.position_bufs[0] {
349 if self.check_phrase_from_position(first_pos, &self.position_bufs) {
350 return true;
351 }
352 }
353
354 false
355 }
356
357 fn check_phrase_from_position(&self, start_pos: u32, term_positions: &[Vec<u32>]) -> bool {
359 let mut expected_pos = start_pos;
360
361 for (i, positions) in term_positions.iter().enumerate() {
362 if i == 0 {
363 continue; }
365
366 expected_pos += 1;
367
368 let found = positions.iter().any(|&pos| {
370 if self.slop == 0 {
371 pos == expected_pos
372 } else {
373 let diff = pos.abs_diff(expected_pos);
374 diff <= self.slop
375 }
376 });
377
378 if !found {
379 return false;
380 }
381 }
382
383 true
384 }
385}
386
387impl super::docset::DocSet for PhraseScorer {
388 fn doc(&self) -> DocId {
389 self.current_doc
390 }
391
392 fn advance(&mut self) -> DocId {
393 if self.current_doc == TERMINATED {
394 return TERMINATED;
395 }
396
397 self.posting_iters[0].advance();
398 self.find_next_phrase_match();
399 self.current_doc
400 }
401
402 fn seek(&mut self, target: DocId) -> DocId {
403 if target == TERMINATED {
404 self.current_doc = TERMINATED;
405 return TERMINATED;
406 }
407
408 self.posting_iters[0].seek(target);
409 self.find_next_phrase_match();
410 self.current_doc
411 }
412
413 fn size_hint(&self) -> u32 {
414 0
415 }
416}
417
418impl Scorer for PhraseScorer {
419 fn score(&self) -> Score {
420 if self.current_doc == TERMINATED {
421 return 0.0;
422 }
423
424 let tf: f32 = self
426 .posting_iters
427 .iter()
428 .map(|it| it.term_freq() as f32)
429 .sum();
430
431 super::bm25_score(tf, self.idf, tf, self.avg_field_len) * 1.5
433 }
434}