Skip to main content

hermes_core/query/
phrase.rs

1//! Phrase query - matches documents containing terms in consecutive positions
2
3use std::sync::Arc;
4
5use crate::dsl::Field;
6use crate::segment::SegmentReader;
7use crate::structures::{BlockPostingIterator, BlockPostingList, PositionPostingList, TERMINATED};
8use crate::{DocId, Score};
9
10use super::{CountFuture, EmptyScorer, GlobalStats, Query, Scorer, ScorerFuture};
11
/// Phrase query - matches documents containing terms in consecutive positions
///
/// Example: "quick brown fox" matches only if all three terms appear
/// consecutively in the document.
///
/// Queries with zero terms produce an empty scorer and queries with a single
/// term are delegated to a plain term query (see the `Query` impl).
#[derive(Clone)]
pub struct PhraseQuery {
    /// Field the phrase is searched in.
    pub field: Field,
    /// Terms in the phrase, in order, as raw bytes
    pub terms: Vec<Vec<u8>>,
    /// Positional slack: each term after the first may sit up to `slop`
    /// positions away from its exact expected position (0 = exact phrase)
    pub slop: u32,
    /// Optional global statistics for cross-segment IDF.
    ///
    /// NOTE(review): set via `with_global_stats`, but the scorer construction
    /// visible in this file does not read it - confirm whether IDF is meant
    /// to come from here.
    global_stats: Option<Arc<GlobalStats>>,
}
26
27impl std::fmt::Display for PhraseQuery {
28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29        let terms: Vec<_> = self
30            .terms
31            .iter()
32            .map(|t| String::from_utf8_lossy(t))
33            .collect();
34        write!(f, "Phrase({}:\"{}\"", self.field.0, terms.join(" "))?;
35        if self.slop > 0 {
36            write!(f, "~{}", self.slop)?;
37        }
38        write!(f, ")")
39    }
40}
41
42impl std::fmt::Debug for PhraseQuery {
43    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44        let terms: Vec<_> = self
45            .terms
46            .iter()
47            .map(|t| String::from_utf8_lossy(t).to_string())
48            .collect();
49        f.debug_struct("PhraseQuery")
50            .field("field", &self.field)
51            .field("terms", &terms)
52            .field("slop", &self.slop)
53            .finish()
54    }
55}
56
57impl PhraseQuery {
58    /// Create a new exact phrase query
59    pub fn new(field: Field, terms: Vec<Vec<u8>>) -> Self {
60        Self {
61            field,
62            terms,
63            slop: 0,
64            global_stats: None,
65        }
66    }
67
68    /// Create from text (splits on whitespace and lowercases)
69    pub fn text(field: Field, phrase: &str) -> Self {
70        let terms: Vec<Vec<u8>> = phrase
71            .split_whitespace()
72            .map(|w| w.to_lowercase().into_bytes())
73            .collect();
74        Self {
75            field,
76            terms,
77            slop: 0,
78            global_stats: None,
79        }
80    }
81
82    /// Set slop (max distance between terms)
83    pub fn with_slop(mut self, slop: u32) -> Self {
84        self.slop = slop;
85        self
86    }
87
88    /// Set global statistics for cross-segment IDF
89    pub fn with_global_stats(mut self, stats: Arc<GlobalStats>) -> Self {
90        self.global_stats = Some(stats);
91        self
92    }
93}
94
95/// Build a PhraseScorer from already-fetched term data.
96fn build_phrase_scorer<'a>(
97    term_data: Vec<(BlockPostingList, PositionPostingList)>,
98    slop: u32,
99    reader: &SegmentReader,
100    field: Field,
101) -> Box<dyn Scorer + 'a> {
102    let idf: f32 = term_data
103        .iter()
104        .map(|(p, _)| {
105            let num_docs = reader.num_docs() as f32;
106            let doc_freq = p.doc_count() as f32;
107            super::bm25_idf(doc_freq, num_docs)
108        })
109        .sum();
110    let avg_field_len = reader.avg_field_len(field);
111    let (postings, positions): (Vec<_>, Vec<_>) = term_data.into_iter().unzip();
112    Box::new(PhraseScorer::new(
113        postings,
114        positions,
115        slop,
116        idf,
117        avg_field_len,
118    ))
119}
120
// ── Shared early-return checks for phrase scorer ─────────────────────────
//
// Expands to three guard clauses, each of which `return`s from the enclosing
// function / async block when it fires:
//   1. no terms                              -> `EmptyScorer`
//   2. exactly one term                      -> delegate to a `TermQuery`
//   3. `$reader.has_positions($field)` false -> delegate to a `BooleanQuery`
//      with every term added via `must`
// If none fires, execution continues after the macro invocation.
//
// `$scorer_fn` is the method called on the delegate query (`scorer` or
// `scorer_sync`). Any trailing tokens `$aw` are appended to that call as
// `.<token>`, so passing `await` produces `.await` for the async path and
// passing nothing produces a plain synchronous call.
macro_rules! phrase_early_returns {
    ($field:expr, $terms:expr, $reader:expr, $limit:expr,
     $scorer_fn:ident $(, $aw:tt)*) => {
        // Nothing to match.
        if $terms.is_empty() {
            return Ok(Box::new(EmptyScorer) as Box<dyn Scorer + '_>);
        }
        // A one-term "phrase" needs no position checks.
        if $terms.len() == 1 {
            let tq = super::TermQuery::new($field, $terms[0].clone());
            return tq.$scorer_fn($reader, $limit) $(. $aw)* ;
        }
        // Without positional data, fall back to requiring all terms.
        if !$reader.has_positions($field) {
            let mut bq = super::BooleanQuery::new();
            for t in $terms.iter() {
                bq = bq.must(super::TermQuery::new($field, t.clone()));
            }
            return bq.$scorer_fn($reader, $limit) $(. $aw)* ;
        }
    };
}
144
145impl Query for PhraseQuery {
146    fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
147        let field = self.field;
148        let terms = self.terms.clone();
149        let slop = self.slop;
150
151        Box::pin(async move {
152            phrase_early_returns!(field, terms, reader, limit, scorer, await);
153
154            // Fetch postings + positions in parallel per term via futures::join!
155            let mut term_data = Vec::with_capacity(terms.len());
156            for term in &terms {
157                let (postings, positions) = futures::join!(
158                    reader.get_postings(field, term),
159                    reader.get_positions(field, term)
160                );
161                match (postings?, positions?) {
162                    (Some(p), Some(pos)) => term_data.push((p, pos)),
163                    _ => return Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>),
164                }
165            }
166
167            Ok(build_phrase_scorer(term_data, slop, reader, field))
168        })
169    }
170
171    #[cfg(feature = "sync")]
172    fn scorer_sync<'a>(
173        &self,
174        reader: &'a SegmentReader,
175        limit: usize,
176    ) -> crate::Result<Box<dyn Scorer + 'a>> {
177        phrase_early_returns!(self.field, self.terms, reader, limit, scorer_sync);
178
179        // Parallel fetch across all terms via rayon
180        use rayon::prelude::*;
181        let pairs: crate::Result<Vec<Option<(BlockPostingList, PositionPostingList)>>> = self
182            .terms
183            .par_iter()
184            .map(|term| {
185                let postings = reader.get_postings_sync(self.field, term)?;
186                let positions = reader.get_positions_sync(self.field, term)?;
187                Ok(match (postings, positions) {
188                    (Some(p), Some(pos)) => Some((p, pos)),
189                    _ => None,
190                })
191            })
192            .collect();
193        let mut term_data = Vec::with_capacity(self.terms.len());
194        for entry in pairs? {
195            match entry {
196                Some(pair) => term_data.push(pair),
197                None => return Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>),
198            }
199        }
200
201        Ok(build_phrase_scorer(
202            term_data, self.slop, reader, self.field,
203        ))
204    }
205
206    fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
207        let field = self.field;
208        let terms = self.terms.clone();
209
210        Box::pin(async move {
211            if terms.is_empty() {
212                return Ok(0);
213            }
214
215            // Estimate based on minimum posting list size
216            let mut min_count = u32::MAX;
217            for term in &terms {
218                match reader.get_postings(field, term).await? {
219                    Some(list) => min_count = min_count.min(list.doc_count()),
220                    None => return Ok(0),
221                }
222            }
223
224            // Phrase matching will typically match fewer docs than the minimum
225            // Estimate ~10% of the smallest posting list
226            Ok((min_count / 10).max(1))
227        })
228    }
229}
230
/// Scorer that checks phrase positions
///
/// Intersects the per-term posting iterators and, for each candidate
/// document, verifies the stored positions against the phrase order.
struct PhraseScorer {
    /// Posting iterators for each term, in phrase order
    posting_iters: Vec<BlockPostingIterator<'static>>,
    /// Position lists for each term, in the same order as `posting_iters`
    position_lists: Vec<PositionPostingList>,
    /// Max distance of a term from its expected position (0 = exact phrase)
    slop: u32,
    /// Current matching document, or `TERMINATED` when exhausted
    current_doc: DocId,
    /// Combined IDF (supplied by the caller at construction)
    idf: f32,
    /// Average field length
    avg_field_len: f32,
}
246
247impl PhraseScorer {
248    fn new(
249        posting_lists: Vec<BlockPostingList>,
250        position_lists: Vec<PositionPostingList>,
251        slop: u32,
252        idf: f32,
253        avg_field_len: f32,
254    ) -> Self {
255        let posting_iters: Vec<_> = posting_lists
256            .into_iter()
257            .map(|p| p.into_iterator())
258            .collect();
259
260        let mut scorer = Self {
261            posting_iters,
262            position_lists,
263            slop,
264            current_doc: 0,
265            idf,
266            avg_field_len,
267        };
268
269        scorer.find_next_phrase_match();
270        scorer
271    }
272
273    /// Find next document where all terms appear as a phrase
274    fn find_next_phrase_match(&mut self) {
275        loop {
276            // First, find a document where all terms appear (AND semantics)
277            let doc = self.find_next_and_match();
278            if doc == TERMINATED {
279                self.current_doc = TERMINATED;
280                return;
281            }
282
283            // Check if positions form a valid phrase
284            if self.check_phrase_positions(doc) {
285                self.current_doc = doc;
286                return;
287            }
288
289            // Advance and try again
290            self.posting_iters[0].advance();
291        }
292    }
293
294    /// Find next document where all terms appear
295    fn find_next_and_match(&mut self) -> DocId {
296        if self.posting_iters.is_empty() {
297            return TERMINATED;
298        }
299
300        loop {
301            let max_doc = self.posting_iters.iter().map(|it| it.doc()).max().unwrap();
302
303            if max_doc == TERMINATED {
304                return TERMINATED;
305            }
306
307            let mut all_match = true;
308            for it in &mut self.posting_iters {
309                let doc = it.seek(max_doc);
310                if doc != max_doc {
311                    all_match = false;
312                    if doc == TERMINATED {
313                        return TERMINATED;
314                    }
315                }
316            }
317
318            if all_match {
319                return max_doc;
320            }
321        }
322    }
323
324    /// Check if positions form a valid phrase for the given document
325    fn check_phrase_positions(&self, doc_id: DocId) -> bool {
326        // Get positions for each term in this document
327        let mut term_positions: Vec<Vec<u32>> = Vec::with_capacity(self.position_lists.len());
328
329        for pos_list in &self.position_lists {
330            match pos_list.get_positions(doc_id) {
331                Some(positions) => term_positions.push(positions.to_vec()),
332                None => return false,
333            }
334        }
335
336        // Check for consecutive positions
337        // For exact phrase (slop=0), position[i+1] = position[i] + 1
338        self.find_phrase_match(&term_positions)
339    }
340
341    /// Find if there's a valid phrase match among the positions
342    fn find_phrase_match(&self, term_positions: &[Vec<u32>]) -> bool {
343        if term_positions.is_empty() {
344            return false;
345        }
346
347        // For each position of the first term, check if subsequent terms
348        // have positions that form a phrase
349        for &first_pos in &term_positions[0] {
350            if self.check_phrase_from_position(first_pos, term_positions) {
351                return true;
352            }
353        }
354
355        false
356    }
357
358    /// Check if a phrase exists starting from the given position
359    fn check_phrase_from_position(&self, start_pos: u32, term_positions: &[Vec<u32>]) -> bool {
360        let mut expected_pos = start_pos;
361
362        for (i, positions) in term_positions.iter().enumerate() {
363            if i == 0 {
364                continue; // Skip first term, already matched
365            }
366
367            expected_pos += 1;
368
369            // Find a position within slop distance
370            let found = positions.iter().any(|&pos| {
371                if self.slop == 0 {
372                    pos == expected_pos
373                } else {
374                    let diff = pos.abs_diff(expected_pos);
375                    diff <= self.slop
376                }
377            });
378
379            if !found {
380                return false;
381            }
382        }
383
384        true
385    }
386}
387
388impl super::docset::DocSet for PhraseScorer {
389    fn doc(&self) -> DocId {
390        self.current_doc
391    }
392
393    fn advance(&mut self) -> DocId {
394        if self.current_doc == TERMINATED {
395            return TERMINATED;
396        }
397
398        self.posting_iters[0].advance();
399        self.find_next_phrase_match();
400        self.current_doc
401    }
402
403    fn seek(&mut self, target: DocId) -> DocId {
404        if target == TERMINATED {
405            self.current_doc = TERMINATED;
406            return TERMINATED;
407        }
408
409        self.posting_iters[0].seek(target);
410        self.find_next_phrase_match();
411        self.current_doc
412    }
413
414    fn size_hint(&self) -> u32 {
415        0
416    }
417}
418
419impl Scorer for PhraseScorer {
420    fn score(&self) -> Score {
421        if self.current_doc == TERMINATED {
422            return 0.0;
423        }
424
425        // Sum term frequencies for BM25 scoring
426        let tf: f32 = self
427            .posting_iters
428            .iter()
429            .map(|it| it.term_freq() as f32)
430            .sum();
431
432        // Phrase matches get a boost since they're more precise
433        super::bm25_score(tf, self.idf, tf, self.avg_field_len) * 1.5
434    }
435}