hermes_core/query/
phrase.rs1use std::sync::Arc;
4
5use crate::dsl::Field;
6use crate::segment::SegmentReader;
7use crate::structures::{BlockPostingIterator, BlockPostingList, PositionPostingList, TERMINATED};
8use crate::{DocId, Score};
9
10use super::{CountFuture, EmptyScorer, GlobalStats, Query, Scorer, ScorerFuture};
11
12#[derive(Clone)]
17pub struct PhraseQuery {
18 pub field: Field,
19 pub terms: Vec<Vec<u8>>,
21 pub slop: u32,
23 global_stats: Option<Arc<GlobalStats>>,
25}
26
27impl std::fmt::Debug for PhraseQuery {
28 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 let terms: Vec<_> = self
30 .terms
31 .iter()
32 .map(|t| String::from_utf8_lossy(t).to_string())
33 .collect();
34 f.debug_struct("PhraseQuery")
35 .field("field", &self.field)
36 .field("terms", &terms)
37 .field("slop", &self.slop)
38 .finish()
39 }
40}
41
42impl PhraseQuery {
43 pub fn new(field: Field, terms: Vec<Vec<u8>>) -> Self {
45 Self {
46 field,
47 terms,
48 slop: 0,
49 global_stats: None,
50 }
51 }
52
53 pub fn text(field: Field, phrase: &str) -> Self {
55 let terms: Vec<Vec<u8>> = phrase
56 .split_whitespace()
57 .map(|w| w.to_lowercase().into_bytes())
58 .collect();
59 Self {
60 field,
61 terms,
62 slop: 0,
63 global_stats: None,
64 }
65 }
66
67 pub fn with_slop(mut self, slop: u32) -> Self {
69 self.slop = slop;
70 self
71 }
72
73 pub fn with_global_stats(mut self, stats: Arc<GlobalStats>) -> Self {
75 self.global_stats = Some(stats);
76 self
77 }
78}
79
80impl Query for PhraseQuery {
81 fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
82 let field = self.field;
83 let terms = self.terms.clone();
84 let slop = self.slop;
85 let _global_stats = self.global_stats.clone();
86
87 Box::pin(async move {
88 if terms.is_empty() {
89 return Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>);
90 }
91
92 if terms.len() == 1 {
94 let term_query = super::TermQuery::new(field, terms[0].clone());
95 return term_query.scorer(reader, limit).await;
96 }
97
98 if !reader.has_positions(field) {
100 let mut bool_query = super::BooleanQuery::new();
102 for term in &terms {
103 bool_query = bool_query.must(super::TermQuery::new(field, term.clone()));
104 }
105 return bool_query.scorer(reader, limit).await;
106 }
107
108 let mut term_postings: Vec<BlockPostingList> = Vec::with_capacity(terms.len());
110 let mut term_positions: Vec<PositionPostingList> = Vec::with_capacity(terms.len());
111
112 for term in &terms {
113 let (postings, positions) = futures::join!(
115 reader.get_postings(field, term),
116 reader.get_positions(field, term)
117 );
118
119 match (postings?, positions?) {
120 (Some(p), Some(pos)) => {
121 term_postings.push(p);
122 term_positions.push(pos);
123 }
124 _ => {
125 return Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>);
127 }
128 }
129 }
130
131 let idf: f32 = term_postings
133 .iter()
134 .map(|p| {
135 let num_docs = reader.num_docs() as f32;
136 let doc_freq = p.doc_count() as f32;
137 super::bm25_idf(doc_freq, num_docs)
138 })
139 .sum();
140
141 let avg_field_len = reader.avg_field_len(field);
142
143 Ok(Box::new(PhraseScorer::new(
144 term_postings,
145 term_positions,
146 slop,
147 idf,
148 avg_field_len,
149 )) as Box<dyn Scorer + 'a>)
150 })
151 }
152
153 fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
154 let field = self.field;
155 let terms = self.terms.clone();
156
157 Box::pin(async move {
158 if terms.is_empty() {
159 return Ok(0);
160 }
161
162 let mut min_count = u32::MAX;
164 for term in &terms {
165 match reader.get_postings(field, term).await? {
166 Some(list) => min_count = min_count.min(list.doc_count()),
167 None => return Ok(0),
168 }
169 }
170
171 Ok((min_count / 10).max(1))
174 })
175 }
176}
177
178struct PhraseScorer {
180 posting_iters: Vec<BlockPostingIterator<'static>>,
182 position_lists: Vec<PositionPostingList>,
184 slop: u32,
186 current_doc: DocId,
188 idf: f32,
190 avg_field_len: f32,
192}
193
194impl PhraseScorer {
195 fn new(
196 posting_lists: Vec<BlockPostingList>,
197 position_lists: Vec<PositionPostingList>,
198 slop: u32,
199 idf: f32,
200 avg_field_len: f32,
201 ) -> Self {
202 let posting_iters: Vec<_> = posting_lists
203 .into_iter()
204 .map(|p| p.into_iterator())
205 .collect();
206
207 let mut scorer = Self {
208 posting_iters,
209 position_lists,
210 slop,
211 current_doc: 0,
212 idf,
213 avg_field_len,
214 };
215
216 scorer.find_next_phrase_match();
217 scorer
218 }
219
220 fn find_next_phrase_match(&mut self) {
222 loop {
223 let doc = self.find_next_and_match();
225 if doc == TERMINATED {
226 self.current_doc = TERMINATED;
227 return;
228 }
229
230 if self.check_phrase_positions(doc) {
232 self.current_doc = doc;
233 return;
234 }
235
236 self.posting_iters[0].advance();
238 }
239 }
240
241 fn find_next_and_match(&mut self) -> DocId {
243 if self.posting_iters.is_empty() {
244 return TERMINATED;
245 }
246
247 loop {
248 let max_doc = self.posting_iters.iter().map(|it| it.doc()).max().unwrap();
249
250 if max_doc == TERMINATED {
251 return TERMINATED;
252 }
253
254 let mut all_match = true;
255 for it in &mut self.posting_iters {
256 let doc = it.seek(max_doc);
257 if doc != max_doc {
258 all_match = false;
259 if doc == TERMINATED {
260 return TERMINATED;
261 }
262 }
263 }
264
265 if all_match {
266 return max_doc;
267 }
268 }
269 }
270
271 fn check_phrase_positions(&self, doc_id: DocId) -> bool {
273 let mut term_positions: Vec<Vec<u32>> = Vec::with_capacity(self.position_lists.len());
275
276 for pos_list in &self.position_lists {
277 match pos_list.get_positions(doc_id) {
278 Some(positions) => term_positions.push(positions.to_vec()),
279 None => return false,
280 }
281 }
282
283 self.find_phrase_match(&term_positions)
286 }
287
288 fn find_phrase_match(&self, term_positions: &[Vec<u32>]) -> bool {
290 if term_positions.is_empty() {
291 return false;
292 }
293
294 for &first_pos in &term_positions[0] {
297 if self.check_phrase_from_position(first_pos, term_positions) {
298 return true;
299 }
300 }
301
302 false
303 }
304
305 fn check_phrase_from_position(&self, start_pos: u32, term_positions: &[Vec<u32>]) -> bool {
307 let mut expected_pos = start_pos;
308
309 for (i, positions) in term_positions.iter().enumerate() {
310 if i == 0 {
311 continue; }
313
314 expected_pos += 1;
315
316 let found = positions.iter().any(|&pos| {
318 if self.slop == 0 {
319 pos == expected_pos
320 } else {
321 let diff = pos.abs_diff(expected_pos);
322 diff <= self.slop
323 }
324 });
325
326 if !found {
327 return false;
328 }
329 }
330
331 true
332 }
333}
334
335impl Scorer for PhraseScorer {
336 fn doc(&self) -> DocId {
337 self.current_doc
338 }
339
340 fn score(&self) -> Score {
341 if self.current_doc == TERMINATED {
342 return 0.0;
343 }
344
345 let tf: f32 = self
347 .posting_iters
348 .iter()
349 .map(|it| it.term_freq() as f32)
350 .sum();
351
352 super::bm25_score(tf, self.idf, tf, self.avg_field_len) * 1.5
354 }
355
356 fn advance(&mut self) -> DocId {
357 if self.current_doc == TERMINATED {
358 return TERMINATED;
359 }
360
361 self.posting_iters[0].advance();
362 self.find_next_phrase_match();
363 self.current_doc
364 }
365
366 fn seek(&mut self, target: DocId) -> DocId {
367 if target == TERMINATED {
368 self.current_doc = TERMINATED;
369 return TERMINATED;
370 }
371
372 self.posting_iters[0].seek(target);
373 self.find_next_phrase_match();
374 self.current_doc
375 }
376
377 fn size_hint(&self) -> u32 {
378 0
379 }
380}