1use std::sync::Arc;
4
5use crate::dsl::Field;
6use crate::segment::SegmentReader;
7use crate::structures::{BlockPostingIterator, BlockPostingList, PositionPostingList, TERMINATED};
8use crate::{DocId, Score};
9
10use super::{CountFuture, EmptyScorer, GlobalStats, Query, Scorer, ScorerFuture};
11
12#[derive(Clone)]
17pub struct PhraseQuery {
18 pub field: Field,
19 pub terms: Vec<Vec<u8>>,
21 pub slop: u32,
23 global_stats: Option<Arc<GlobalStats>>,
25}
26
27impl std::fmt::Debug for PhraseQuery {
28 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 let terms: Vec<_> = self
30 .terms
31 .iter()
32 .map(|t| String::from_utf8_lossy(t).to_string())
33 .collect();
34 f.debug_struct("PhraseQuery")
35 .field("field", &self.field)
36 .field("terms", &terms)
37 .field("slop", &self.slop)
38 .finish()
39 }
40}
41
42impl PhraseQuery {
43 pub fn new(field: Field, terms: Vec<Vec<u8>>) -> Self {
45 Self {
46 field,
47 terms,
48 slop: 0,
49 global_stats: None,
50 }
51 }
52
53 pub fn text(field: Field, phrase: &str) -> Self {
55 let terms: Vec<Vec<u8>> = phrase
56 .split_whitespace()
57 .map(|w| w.to_lowercase().into_bytes())
58 .collect();
59 Self {
60 field,
61 terms,
62 slop: 0,
63 global_stats: None,
64 }
65 }
66
67 pub fn with_slop(mut self, slop: u32) -> Self {
69 self.slop = slop;
70 self
71 }
72
73 pub fn with_global_stats(mut self, stats: Arc<GlobalStats>) -> Self {
75 self.global_stats = Some(stats);
76 self
77 }
78}
79
80impl Query for PhraseQuery {
81 fn scorer<'a>(
82 &self,
83 reader: &'a SegmentReader,
84 limit: usize,
85 predicate: Option<super::DocPredicate<'a>>,
86 ) -> ScorerFuture<'a> {
87 let field = self.field;
88 let terms = self.terms.clone();
89 let slop = self.slop;
90 let _global_stats = self.global_stats.clone();
91
92 Box::pin(async move {
93 if terms.is_empty() {
94 return Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>);
95 }
96
97 if terms.len() == 1 {
99 let term_query = super::TermQuery::new(field, terms[0].clone());
100 return term_query.scorer(reader, limit, predicate).await;
101 }
102
103 if !reader.has_positions(field) {
105 let mut bool_query = super::BooleanQuery::new();
107 for term in &terms {
108 bool_query = bool_query.must(super::TermQuery::new(field, term.clone()));
109 }
110 return bool_query.scorer(reader, limit, predicate).await;
111 }
112
113 let mut term_postings: Vec<BlockPostingList> = Vec::with_capacity(terms.len());
115 let mut term_positions: Vec<PositionPostingList> = Vec::with_capacity(terms.len());
116
117 for term in &terms {
118 let (postings, positions) = futures::join!(
120 reader.get_postings(field, term),
121 reader.get_positions(field, term)
122 );
123
124 match (postings?, positions?) {
125 (Some(p), Some(pos)) => {
126 term_postings.push(p);
127 term_positions.push(pos);
128 }
129 _ => {
130 return Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>);
132 }
133 }
134 }
135
136 let idf: f32 = term_postings
138 .iter()
139 .map(|p| {
140 let num_docs = reader.num_docs() as f32;
141 let doc_freq = p.doc_count() as f32;
142 super::bm25_idf(doc_freq, num_docs)
143 })
144 .sum();
145
146 let avg_field_len = reader.avg_field_len(field);
147
148 Ok(Box::new(PhraseScorer::new(
149 term_postings,
150 term_positions,
151 slop,
152 idf,
153 avg_field_len,
154 )) as Box<dyn Scorer + 'a>)
155 })
156 }
157
158 #[cfg(feature = "sync")]
159 fn scorer_sync<'a>(
160 &self,
161 reader: &'a SegmentReader,
162 limit: usize,
163 predicate: Option<super::DocPredicate<'a>>,
164 ) -> crate::Result<Box<dyn Scorer + 'a>> {
165 if self.terms.is_empty() {
166 return Ok(Box::new(super::EmptyScorer) as Box<dyn Scorer + 'a>);
167 }
168
169 if self.terms.len() == 1 {
170 let term_query = super::TermQuery::new(self.field, self.terms[0].clone());
171 return term_query.scorer_sync(reader, limit, predicate);
172 }
173
174 if !reader.has_positions(self.field) {
175 let mut bool_query = super::BooleanQuery::new();
176 for term in &self.terms {
177 bool_query = bool_query.must(super::TermQuery::new(self.field, term.clone()));
178 }
179 return bool_query.scorer_sync(reader, limit, predicate);
180 }
181
182 let mut term_postings: Vec<BlockPostingList> = Vec::with_capacity(self.terms.len());
183 let mut term_positions: Vec<PositionPostingList> = Vec::with_capacity(self.terms.len());
184
185 for term in &self.terms {
186 let postings = reader.get_postings_sync(self.field, term)?;
187 let positions = reader.get_positions_sync(self.field, term)?;
188
189 match (postings, positions) {
190 (Some(p), Some(pos)) => {
191 term_postings.push(p);
192 term_positions.push(pos);
193 }
194 _ => return Ok(Box::new(super::EmptyScorer) as Box<dyn Scorer + 'a>),
195 }
196 }
197
198 let idf: f32 = term_postings
199 .iter()
200 .map(|p| {
201 let num_docs = reader.num_docs() as f32;
202 let doc_freq = p.doc_count() as f32;
203 super::bm25_idf(doc_freq, num_docs)
204 })
205 .sum();
206
207 let avg_field_len = reader.avg_field_len(self.field);
208
209 Ok(Box::new(PhraseScorer::new(
210 term_postings,
211 term_positions,
212 self.slop,
213 idf,
214 avg_field_len,
215 )) as Box<dyn Scorer + 'a>)
216 }
217
218 fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
219 let field = self.field;
220 let terms = self.terms.clone();
221
222 Box::pin(async move {
223 if terms.is_empty() {
224 return Ok(0);
225 }
226
227 let mut min_count = u32::MAX;
229 for term in &terms {
230 match reader.get_postings(field, term).await? {
231 Some(list) => min_count = min_count.min(list.doc_count()),
232 None => return Ok(0),
233 }
234 }
235
236 Ok((min_count / 10).max(1))
239 })
240 }
241}
242
243struct PhraseScorer {
245 posting_iters: Vec<BlockPostingIterator<'static>>,
247 position_lists: Vec<PositionPostingList>,
249 slop: u32,
251 current_doc: DocId,
253 idf: f32,
255 avg_field_len: f32,
257}
258
259impl PhraseScorer {
260 fn new(
261 posting_lists: Vec<BlockPostingList>,
262 position_lists: Vec<PositionPostingList>,
263 slop: u32,
264 idf: f32,
265 avg_field_len: f32,
266 ) -> Self {
267 let posting_iters: Vec<_> = posting_lists
268 .into_iter()
269 .map(|p| p.into_iterator())
270 .collect();
271
272 let mut scorer = Self {
273 posting_iters,
274 position_lists,
275 slop,
276 current_doc: 0,
277 idf,
278 avg_field_len,
279 };
280
281 scorer.find_next_phrase_match();
282 scorer
283 }
284
285 fn find_next_phrase_match(&mut self) {
287 loop {
288 let doc = self.find_next_and_match();
290 if doc == TERMINATED {
291 self.current_doc = TERMINATED;
292 return;
293 }
294
295 if self.check_phrase_positions(doc) {
297 self.current_doc = doc;
298 return;
299 }
300
301 self.posting_iters[0].advance();
303 }
304 }
305
306 fn find_next_and_match(&mut self) -> DocId {
308 if self.posting_iters.is_empty() {
309 return TERMINATED;
310 }
311
312 loop {
313 let max_doc = self.posting_iters.iter().map(|it| it.doc()).max().unwrap();
314
315 if max_doc == TERMINATED {
316 return TERMINATED;
317 }
318
319 let mut all_match = true;
320 for it in &mut self.posting_iters {
321 let doc = it.seek(max_doc);
322 if doc != max_doc {
323 all_match = false;
324 if doc == TERMINATED {
325 return TERMINATED;
326 }
327 }
328 }
329
330 if all_match {
331 return max_doc;
332 }
333 }
334 }
335
336 fn check_phrase_positions(&self, doc_id: DocId) -> bool {
338 let mut term_positions: Vec<Vec<u32>> = Vec::with_capacity(self.position_lists.len());
340
341 for pos_list in &self.position_lists {
342 match pos_list.get_positions(doc_id) {
343 Some(positions) => term_positions.push(positions.to_vec()),
344 None => return false,
345 }
346 }
347
348 self.find_phrase_match(&term_positions)
351 }
352
353 fn find_phrase_match(&self, term_positions: &[Vec<u32>]) -> bool {
355 if term_positions.is_empty() {
356 return false;
357 }
358
359 for &first_pos in &term_positions[0] {
362 if self.check_phrase_from_position(first_pos, term_positions) {
363 return true;
364 }
365 }
366
367 false
368 }
369
370 fn check_phrase_from_position(&self, start_pos: u32, term_positions: &[Vec<u32>]) -> bool {
372 let mut expected_pos = start_pos;
373
374 for (i, positions) in term_positions.iter().enumerate() {
375 if i == 0 {
376 continue; }
378
379 expected_pos += 1;
380
381 let found = positions.iter().any(|&pos| {
383 if self.slop == 0 {
384 pos == expected_pos
385 } else {
386 let diff = pos.abs_diff(expected_pos);
387 diff <= self.slop
388 }
389 });
390
391 if !found {
392 return false;
393 }
394 }
395
396 true
397 }
398}
399
400impl Scorer for PhraseScorer {
401 fn doc(&self) -> DocId {
402 self.current_doc
403 }
404
405 fn score(&self) -> Score {
406 if self.current_doc == TERMINATED {
407 return 0.0;
408 }
409
410 let tf: f32 = self
412 .posting_iters
413 .iter()
414 .map(|it| it.term_freq() as f32)
415 .sum();
416
417 super::bm25_score(tf, self.idf, tf, self.avg_field_len) * 1.5
419 }
420
421 fn advance(&mut self) -> DocId {
422 if self.current_doc == TERMINATED {
423 return TERMINATED;
424 }
425
426 self.posting_iters[0].advance();
427 self.find_next_phrase_match();
428 self.current_doc
429 }
430
431 fn seek(&mut self, target: DocId) -> DocId {
432 if target == TERMINATED {
433 self.current_doc = TERMINATED;
434 return TERMINATED;
435 }
436
437 self.posting_iters[0].seek(target);
438 self.find_next_phrase_match();
439 self.current_doc
440 }
441
442 fn size_hint(&self) -> u32 {
443 0
444 }
445}