Skip to main content

hermes_core/query/
phrase.rs

1//! Phrase query - matches documents containing terms in consecutive positions
2
3use std::sync::Arc;
4
5use crate::dsl::Field;
6use crate::segment::SegmentReader;
7use crate::structures::{BlockPostingIterator, BlockPostingList, PositionPostingList, TERMINATED};
8use crate::{DocId, Score};
9
10use super::{CountFuture, EmptyScorer, GlobalStats, Query, Scorer, ScorerFuture};
11
/// Phrase query - matches documents containing terms in consecutive positions
///
/// Example: "quick brown fox" matches only if all three terms appear
/// consecutively in the document.
///
/// Construct one with `PhraseQuery::new` (pre-tokenized terms) or
/// `PhraseQuery::text` (whitespace-split, lowercased text), then optionally
/// chain `with_slop` and `with_global_stats`.
#[derive(Clone)]
pub struct PhraseQuery {
    /// Field the phrase is searched in
    pub field: Field,
    /// Terms in the phrase, in order, stored as raw bytes
    pub terms: Vec<Vec<u8>>,
    /// Slop: max distance between terms (0 = exact phrase, which is the
    /// value both constructors start with)
    pub slop: u32,
    /// Optional global statistics for cross-segment IDF; `None` until
    /// `with_global_stats` is called.
    // NOTE(review): in this file the value is only cloned inside `scorer` and
    // never consulted; IDF is computed from the segment reader's own counts.
    global_stats: Option<Arc<GlobalStats>>,
}
26
27impl std::fmt::Debug for PhraseQuery {
28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29        let terms: Vec<_> = self
30            .terms
31            .iter()
32            .map(|t| String::from_utf8_lossy(t).to_string())
33            .collect();
34        f.debug_struct("PhraseQuery")
35            .field("field", &self.field)
36            .field("terms", &terms)
37            .field("slop", &self.slop)
38            .finish()
39    }
40}
41
42impl PhraseQuery {
43    /// Create a new exact phrase query
44    pub fn new(field: Field, terms: Vec<Vec<u8>>) -> Self {
45        Self {
46            field,
47            terms,
48            slop: 0,
49            global_stats: None,
50        }
51    }
52
53    /// Create from text (splits on whitespace and lowercases)
54    pub fn text(field: Field, phrase: &str) -> Self {
55        let terms: Vec<Vec<u8>> = phrase
56            .split_whitespace()
57            .map(|w| w.to_lowercase().into_bytes())
58            .collect();
59        Self {
60            field,
61            terms,
62            slop: 0,
63            global_stats: None,
64        }
65    }
66
67    /// Set slop (max distance between terms)
68    pub fn with_slop(mut self, slop: u32) -> Self {
69        self.slop = slop;
70        self
71    }
72
73    /// Set global statistics for cross-segment IDF
74    pub fn with_global_stats(mut self, stats: Arc<GlobalStats>) -> Self {
75        self.global_stats = Some(stats);
76        self
77    }
78}
79
80impl Query for PhraseQuery {
81    fn scorer<'a>(
82        &self,
83        reader: &'a SegmentReader,
84        limit: usize,
85        predicate: Option<super::DocPredicate<'a>>,
86    ) -> ScorerFuture<'a> {
87        let field = self.field;
88        let terms = self.terms.clone();
89        let slop = self.slop;
90        let _global_stats = self.global_stats.clone();
91
92        Box::pin(async move {
93            if terms.is_empty() {
94                return Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>);
95            }
96
97            // Single term - delegate to TermQuery
98            if terms.len() == 1 {
99                let term_query = super::TermQuery::new(field, terms[0].clone());
100                return term_query.scorer(reader, limit, predicate).await;
101            }
102
103            // Check if positions are available
104            if !reader.has_positions(field) {
105                // Fall back to AND query (BooleanQuery with MUST clauses)
106                let mut bool_query = super::BooleanQuery::new();
107                for term in &terms {
108                    bool_query = bool_query.must(super::TermQuery::new(field, term.clone()));
109                }
110                return bool_query.scorer(reader, limit, predicate).await;
111            }
112
113            // Load postings and positions for all terms (parallel per term)
114            let mut term_postings: Vec<BlockPostingList> = Vec::with_capacity(terms.len());
115            let mut term_positions: Vec<PositionPostingList> = Vec::with_capacity(terms.len());
116
117            for term in &terms {
118                // Fetch postings and positions in parallel
119                let (postings, positions) = futures::join!(
120                    reader.get_postings(field, term),
121                    reader.get_positions(field, term)
122                );
123
124                match (postings?, positions?) {
125                    (Some(p), Some(pos)) => {
126                        term_postings.push(p);
127                        term_positions.push(pos);
128                    }
129                    _ => {
130                        // If any term is missing, no documents can match
131                        return Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>);
132                    }
133                }
134            }
135
136            // Compute combined IDF (sum of individual IDFs)
137            let idf: f32 = term_postings
138                .iter()
139                .map(|p| {
140                    let num_docs = reader.num_docs() as f32;
141                    let doc_freq = p.doc_count() as f32;
142                    super::bm25_idf(doc_freq, num_docs)
143                })
144                .sum();
145
146            let avg_field_len = reader.avg_field_len(field);
147
148            Ok(Box::new(PhraseScorer::new(
149                term_postings,
150                term_positions,
151                slop,
152                idf,
153                avg_field_len,
154            )) as Box<dyn Scorer + 'a>)
155        })
156    }
157
158    #[cfg(feature = "sync")]
159    fn scorer_sync<'a>(
160        &self,
161        reader: &'a SegmentReader,
162        limit: usize,
163        predicate: Option<super::DocPredicate<'a>>,
164    ) -> crate::Result<Box<dyn Scorer + 'a>> {
165        if self.terms.is_empty() {
166            return Ok(Box::new(super::EmptyScorer) as Box<dyn Scorer + 'a>);
167        }
168
169        if self.terms.len() == 1 {
170            let term_query = super::TermQuery::new(self.field, self.terms[0].clone());
171            return term_query.scorer_sync(reader, limit, predicate);
172        }
173
174        if !reader.has_positions(self.field) {
175            let mut bool_query = super::BooleanQuery::new();
176            for term in &self.terms {
177                bool_query = bool_query.must(super::TermQuery::new(self.field, term.clone()));
178            }
179            return bool_query.scorer_sync(reader, limit, predicate);
180        }
181
182        let mut term_postings: Vec<BlockPostingList> = Vec::with_capacity(self.terms.len());
183        let mut term_positions: Vec<PositionPostingList> = Vec::with_capacity(self.terms.len());
184
185        for term in &self.terms {
186            let postings = reader.get_postings_sync(self.field, term)?;
187            let positions = reader.get_positions_sync(self.field, term)?;
188
189            match (postings, positions) {
190                (Some(p), Some(pos)) => {
191                    term_postings.push(p);
192                    term_positions.push(pos);
193                }
194                _ => return Ok(Box::new(super::EmptyScorer) as Box<dyn Scorer + 'a>),
195            }
196        }
197
198        let idf: f32 = term_postings
199            .iter()
200            .map(|p| {
201                let num_docs = reader.num_docs() as f32;
202                let doc_freq = p.doc_count() as f32;
203                super::bm25_idf(doc_freq, num_docs)
204            })
205            .sum();
206
207        let avg_field_len = reader.avg_field_len(self.field);
208
209        Ok(Box::new(PhraseScorer::new(
210            term_postings,
211            term_positions,
212            self.slop,
213            idf,
214            avg_field_len,
215        )) as Box<dyn Scorer + 'a>)
216    }
217
218    fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
219        let field = self.field;
220        let terms = self.terms.clone();
221
222        Box::pin(async move {
223            if terms.is_empty() {
224                return Ok(0);
225            }
226
227            // Estimate based on minimum posting list size
228            let mut min_count = u32::MAX;
229            for term in &terms {
230                match reader.get_postings(field, term).await? {
231                    Some(list) => min_count = min_count.min(list.doc_count()),
232                    None => return Ok(0),
233                }
234            }
235
236            // Phrase matching will typically match fewer docs than the minimum
237            // Estimate ~10% of the smallest posting list
238            Ok((min_count / 10).max(1))
239        })
240    }
241}
242
/// Scorer that checks phrase positions
///
/// Walks the terms' posting iterators in lockstep to find documents that
/// contain every term, then consults the position lists to keep only those
/// where the terms form the phrase.
struct PhraseScorer {
    /// Posting iterators, one per phrase term, in phrase order. They are
    /// created from owned posting lists in `PhraseScorer::new`.
    posting_iters: Vec<BlockPostingIterator<'static>>,
    /// Position lists, one per phrase term, in the same order as
    /// `posting_iters`; queried per document id.
    position_lists: Vec<PositionPostingList>,
    /// Max slop between terms (0 = exact phrase)
    slop: u32,
    /// Current matching document, or `TERMINATED` once no match remains
    current_doc: DocId,
    /// Combined IDF
    idf: f32,
    /// Average field length
    avg_field_len: f32,
}
258
259impl PhraseScorer {
260    fn new(
261        posting_lists: Vec<BlockPostingList>,
262        position_lists: Vec<PositionPostingList>,
263        slop: u32,
264        idf: f32,
265        avg_field_len: f32,
266    ) -> Self {
267        let posting_iters: Vec<_> = posting_lists
268            .into_iter()
269            .map(|p| p.into_iterator())
270            .collect();
271
272        let mut scorer = Self {
273            posting_iters,
274            position_lists,
275            slop,
276            current_doc: 0,
277            idf,
278            avg_field_len,
279        };
280
281        scorer.find_next_phrase_match();
282        scorer
283    }
284
285    /// Find next document where all terms appear as a phrase
286    fn find_next_phrase_match(&mut self) {
287        loop {
288            // First, find a document where all terms appear (AND semantics)
289            let doc = self.find_next_and_match();
290            if doc == TERMINATED {
291                self.current_doc = TERMINATED;
292                return;
293            }
294
295            // Check if positions form a valid phrase
296            if self.check_phrase_positions(doc) {
297                self.current_doc = doc;
298                return;
299            }
300
301            // Advance and try again
302            self.posting_iters[0].advance();
303        }
304    }
305
306    /// Find next document where all terms appear
307    fn find_next_and_match(&mut self) -> DocId {
308        if self.posting_iters.is_empty() {
309            return TERMINATED;
310        }
311
312        loop {
313            let max_doc = self.posting_iters.iter().map(|it| it.doc()).max().unwrap();
314
315            if max_doc == TERMINATED {
316                return TERMINATED;
317            }
318
319            let mut all_match = true;
320            for it in &mut self.posting_iters {
321                let doc = it.seek(max_doc);
322                if doc != max_doc {
323                    all_match = false;
324                    if doc == TERMINATED {
325                        return TERMINATED;
326                    }
327                }
328            }
329
330            if all_match {
331                return max_doc;
332            }
333        }
334    }
335
336    /// Check if positions form a valid phrase for the given document
337    fn check_phrase_positions(&self, doc_id: DocId) -> bool {
338        // Get positions for each term in this document
339        let mut term_positions: Vec<Vec<u32>> = Vec::with_capacity(self.position_lists.len());
340
341        for pos_list in &self.position_lists {
342            match pos_list.get_positions(doc_id) {
343                Some(positions) => term_positions.push(positions.to_vec()),
344                None => return false,
345            }
346        }
347
348        // Check for consecutive positions
349        // For exact phrase (slop=0), position[i+1] = position[i] + 1
350        self.find_phrase_match(&term_positions)
351    }
352
353    /// Find if there's a valid phrase match among the positions
354    fn find_phrase_match(&self, term_positions: &[Vec<u32>]) -> bool {
355        if term_positions.is_empty() {
356            return false;
357        }
358
359        // For each position of the first term, check if subsequent terms
360        // have positions that form a phrase
361        for &first_pos in &term_positions[0] {
362            if self.check_phrase_from_position(first_pos, term_positions) {
363                return true;
364            }
365        }
366
367        false
368    }
369
370    /// Check if a phrase exists starting from the given position
371    fn check_phrase_from_position(&self, start_pos: u32, term_positions: &[Vec<u32>]) -> bool {
372        let mut expected_pos = start_pos;
373
374        for (i, positions) in term_positions.iter().enumerate() {
375            if i == 0 {
376                continue; // Skip first term, already matched
377            }
378
379            expected_pos += 1;
380
381            // Find a position within slop distance
382            let found = positions.iter().any(|&pos| {
383                if self.slop == 0 {
384                    pos == expected_pos
385                } else {
386                    let diff = pos.abs_diff(expected_pos);
387                    diff <= self.slop
388                }
389            });
390
391            if !found {
392                return false;
393            }
394        }
395
396        true
397    }
398}
399
400impl Scorer for PhraseScorer {
401    fn doc(&self) -> DocId {
402        self.current_doc
403    }
404
405    fn score(&self) -> Score {
406        if self.current_doc == TERMINATED {
407            return 0.0;
408        }
409
410        // Sum term frequencies for BM25 scoring
411        let tf: f32 = self
412            .posting_iters
413            .iter()
414            .map(|it| it.term_freq() as f32)
415            .sum();
416
417        // Phrase matches get a boost since they're more precise
418        super::bm25_score(tf, self.idf, tf, self.avg_field_len) * 1.5
419    }
420
421    fn advance(&mut self) -> DocId {
422        if self.current_doc == TERMINATED {
423            return TERMINATED;
424        }
425
426        self.posting_iters[0].advance();
427        self.find_next_phrase_match();
428        self.current_doc
429    }
430
431    fn seek(&mut self, target: DocId) -> DocId {
432        if target == TERMINATED {
433            self.current_doc = TERMINATED;
434            return TERMINATED;
435        }
436
437        self.posting_iters[0].seek(target);
438        self.find_next_phrase_match();
439        self.current_doc
440    }
441
442    fn size_hint(&self) -> u32 {
443        0
444    }
445}