Skip to main content

hermes_core/query/
phrase.rs

1//! Phrase query - matches documents containing terms in consecutive positions
2
3use std::sync::Arc;
4
5use crate::dsl::Field;
6use crate::segment::SegmentReader;
7use crate::structures::{BlockPostingIterator, BlockPostingList, PositionPostingList, TERMINATED};
8use crate::{DocId, Score};
9
10use super::{CountFuture, EmptyScorer, GlobalStats, Query, Scorer, ScorerFuture};
11
12/// Phrase query - matches documents containing terms in consecutive positions
13///
14/// Example: "quick brown fox" matches only if all three terms appear
15/// consecutively in the document.
16#[derive(Clone)]
17pub struct PhraseQuery {
18    pub field: Field,
19    /// Terms in the phrase, in order
20    pub terms: Vec<Vec<u8>>,
21    /// Optional slop (max distance between terms, 0 = exact phrase)
22    pub slop: u32,
23    /// Optional global statistics for cross-segment IDF
24    global_stats: Option<Arc<GlobalStats>>,
25}
26
27impl std::fmt::Display for PhraseQuery {
28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29        let terms: Vec<_> = self
30            .terms
31            .iter()
32            .map(|t| String::from_utf8_lossy(t))
33            .collect();
34        write!(f, "Phrase({}:\"{}\"", self.field.0, terms.join(" "))?;
35        if self.slop > 0 {
36            write!(f, "~{}", self.slop)?;
37        }
38        write!(f, ")")
39    }
40}
41
42impl std::fmt::Debug for PhraseQuery {
43    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44        let terms: Vec<_> = self
45            .terms
46            .iter()
47            .map(|t| String::from_utf8_lossy(t).to_string())
48            .collect();
49        f.debug_struct("PhraseQuery")
50            .field("field", &self.field)
51            .field("terms", &terms)
52            .field("slop", &self.slop)
53            .finish()
54    }
55}
56
57impl PhraseQuery {
58    /// Create a new exact phrase query
59    pub fn new(field: Field, terms: Vec<Vec<u8>>) -> Self {
60        Self {
61            field,
62            terms,
63            slop: 0,
64            global_stats: None,
65        }
66    }
67
68    /// Create from text (splits on whitespace and lowercases)
69    pub fn text(field: Field, phrase: &str) -> Self {
70        let terms: Vec<Vec<u8>> = phrase
71            .split_whitespace()
72            .map(|w| w.to_lowercase().into_bytes())
73            .collect();
74        Self {
75            field,
76            terms,
77            slop: 0,
78            global_stats: None,
79        }
80    }
81
82    /// Set slop (max distance between terms)
83    pub fn with_slop(mut self, slop: u32) -> Self {
84        self.slop = slop;
85        self
86    }
87
88    /// Set global statistics for cross-segment IDF
89    pub fn with_global_stats(mut self, stats: Arc<GlobalStats>) -> Self {
90        self.global_stats = Some(stats);
91        self
92    }
93}
94
95/// Build a PhraseScorer from already-fetched term data.
96fn build_phrase_scorer<'a>(
97    term_data: Vec<(BlockPostingList, PositionPostingList)>,
98    slop: u32,
99    reader: &SegmentReader,
100    field: Field,
101) -> Box<dyn Scorer + 'a> {
102    let idf: f32 = term_data
103        .iter()
104        .map(|(p, _)| {
105            let num_docs = reader.num_docs() as f32;
106            let doc_freq = p.doc_count() as f32;
107            super::bm25_idf(doc_freq, num_docs)
108        })
109        .sum();
110    let avg_field_len = reader.avg_field_len(field);
111    let (postings, positions): (Vec<_>, Vec<_>) = term_data.into_iter().unzip();
112    Box::new(PhraseScorer::new(
113        postings,
114        positions,
115        slop,
116        idf,
117        avg_field_len,
118    ))
119}
120
// ── Early exits shared by the async and sync phrase scorers ──────────────
//
// Runs before any position data is fetched and covers three cases:
//   * no terms                → `EmptyScorer`
//   * exactly one term        → delegate to a `TermQuery`
//   * field has no positions  → fall back to an all-`must` `BooleanQuery`
//
// `$method` names the scorer method invoked on the delegate query; the
// optional trailing tokens (e.g. `await`) are appended to that call so the
// same body serves both the async and the sync entry point.
macro_rules! phrase_early_returns {
    ($fld:expr, $phrase:expr, $rdr:expr, $lim:expr,
     $method:ident $(, $suffix:tt)*) => {
        if $phrase.is_empty() {
            return Ok(Box::new(EmptyScorer) as Box<dyn Scorer + '_>);
        }
        if $phrase.len() == 1 {
            let single = super::TermQuery::new($fld, $phrase[0].clone());
            return single.$method($rdr, $lim) $(. $suffix)* ;
        }
        if !$rdr.has_positions($fld) {
            let conjunction = $phrase
                .iter()
                .fold(super::BooleanQuery::new(), |acc, term| {
                    acc.must(super::TermQuery::new($fld, term.clone()))
                });
            return conjunction.$method($rdr, $lim) $(. $suffix)* ;
        }
    };
}
144
145impl Query for PhraseQuery {
146    fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
147        let field = self.field;
148        let terms = self.terms.clone();
149        let slop = self.slop;
150
151        Box::pin(async move {
152            phrase_early_returns!(field, terms, reader, limit, scorer, await);
153
154            // Fetch postings + positions in parallel per term via futures::join!
155            let mut term_data = Vec::with_capacity(terms.len());
156            for term in &terms {
157                let (postings, positions) = futures::join!(
158                    reader.get_postings(field, term),
159                    reader.get_positions(field, term)
160                );
161                match (postings?, positions?) {
162                    (Some(p), Some(pos)) => term_data.push((p, pos)),
163                    _ => return Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>),
164                }
165            }
166
167            Ok(build_phrase_scorer(term_data, slop, reader, field))
168        })
169    }
170
171    #[cfg(feature = "sync")]
172    fn scorer_sync<'a>(
173        &self,
174        reader: &'a SegmentReader,
175        limit: usize,
176    ) -> crate::Result<Box<dyn Scorer + 'a>> {
177        phrase_early_returns!(self.field, self.terms, reader, limit, scorer_sync);
178
179        // Parallel fetch across all terms via rayon
180        use rayon::prelude::*;
181        let pairs: crate::Result<Vec<Option<(BlockPostingList, PositionPostingList)>>> = self
182            .terms
183            .par_iter()
184            .map(|term| {
185                let postings = reader.get_postings_sync(self.field, term)?;
186                let positions = reader.get_positions_sync(self.field, term)?;
187                Ok(match (postings, positions) {
188                    (Some(p), Some(pos)) => Some((p, pos)),
189                    _ => None,
190                })
191            })
192            .collect();
193        let mut term_data = Vec::with_capacity(self.terms.len());
194        for entry in pairs? {
195            match entry {
196                Some(pair) => term_data.push(pair),
197                None => return Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>),
198            }
199        }
200
201        Ok(build_phrase_scorer(
202            term_data, self.slop, reader, self.field,
203        ))
204    }
205
206    fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
207        let field = self.field;
208        let terms = self.terms.clone();
209
210        Box::pin(async move {
211            if terms.is_empty() {
212                return Ok(0);
213            }
214
215            // Estimate based on minimum posting list size
216            let mut min_count = u32::MAX;
217            for term in &terms {
218                match reader.get_postings(field, term).await? {
219                    Some(list) => min_count = min_count.min(list.doc_count()),
220                    None => return Ok(0),
221                }
222            }
223
224            // Phrase matching will typically match fewer docs than the minimum
225            // Estimate ~10% of the smallest posting list
226            Ok((min_count / 10).max(1))
227        })
228    }
229}
230
231/// Scorer that checks phrase positions
232struct PhraseScorer {
233    /// Posting iterators for each term
234    posting_iters: Vec<BlockPostingIterator<'static>>,
235    /// Position iterators for each term
236    position_lists: Vec<PositionPostingList>,
237    /// Max slop between terms
238    slop: u32,
239    /// Current matching document
240    current_doc: DocId,
241    /// Combined IDF
242    idf: f32,
243    /// Average field length
244    avg_field_len: f32,
245    /// Reusable position buffers (one per term, avoids per-document allocation)
246    position_bufs: Vec<Vec<u32>>,
247}
248
249impl PhraseScorer {
250    fn new(
251        posting_lists: Vec<BlockPostingList>,
252        position_lists: Vec<PositionPostingList>,
253        slop: u32,
254        idf: f32,
255        avg_field_len: f32,
256    ) -> Self {
257        let posting_iters: Vec<_> = posting_lists
258            .into_iter()
259            .map(|p| p.into_iterator())
260            .collect();
261
262        let num_terms = position_lists.len();
263        let mut scorer = Self {
264            posting_iters,
265            position_lists,
266            slop,
267            current_doc: 0,
268            idf,
269            avg_field_len,
270            position_bufs: (0..num_terms).map(|_| Vec::new()).collect(),
271        };
272
273        scorer.find_next_phrase_match();
274        scorer
275    }
276
277    /// Find next document where all terms appear as a phrase
278    fn find_next_phrase_match(&mut self) {
279        loop {
280            // First, find a document where all terms appear (AND semantics)
281            let doc = self.find_next_and_match();
282            if doc == TERMINATED {
283                self.current_doc = TERMINATED;
284                return;
285            }
286
287            // Check if positions form a valid phrase
288            if self.check_phrase_positions(doc) {
289                self.current_doc = doc;
290                return;
291            }
292
293            // Advance and try again
294            self.posting_iters[0].advance();
295        }
296    }
297
298    /// Find next document where all terms appear
299    fn find_next_and_match(&mut self) -> DocId {
300        if self.posting_iters.is_empty() {
301            return TERMINATED;
302        }
303
304        loop {
305            let max_doc = self.posting_iters.iter().map(|it| it.doc()).max().unwrap();
306
307            if max_doc == TERMINATED {
308                return TERMINATED;
309            }
310
311            let mut all_match = true;
312            for it in &mut self.posting_iters {
313                let doc = it.seek(max_doc);
314                if doc != max_doc {
315                    all_match = false;
316                    if doc == TERMINATED {
317                        return TERMINATED;
318                    }
319                }
320            }
321
322            if all_match {
323                return max_doc;
324            }
325        }
326    }
327
328    /// Check if positions form a valid phrase for the given document
329    fn check_phrase_positions(&mut self, doc_id: DocId) -> bool {
330        // Get positions for each term into reusable buffers (zero allocation)
331        for (i, pos_list) in self.position_lists.iter().enumerate() {
332            if !pos_list.get_positions_into(doc_id, &mut self.position_bufs[i]) {
333                return false;
334            }
335        }
336
337        // Check for consecutive positions
338        // For exact phrase (slop=0), position[i+1] = position[i] + 1
339        self.find_phrase_match_from_bufs()
340    }
341
342    /// Find phrase match using the internal reusable buffers
343    fn find_phrase_match_from_bufs(&self) -> bool {
344        if self.position_bufs.is_empty() || self.position_bufs[0].is_empty() {
345            return false;
346        }
347
348        for &first_pos in &self.position_bufs[0] {
349            if self.check_phrase_from_position(first_pos, &self.position_bufs) {
350                return true;
351            }
352        }
353
354        false
355    }
356
357    /// Check if a phrase exists starting from the given position
358    fn check_phrase_from_position(&self, start_pos: u32, term_positions: &[Vec<u32>]) -> bool {
359        let mut expected_pos = start_pos;
360
361        for (i, positions) in term_positions.iter().enumerate() {
362            if i == 0 {
363                continue; // Skip first term, already matched
364            }
365
366            expected_pos += 1;
367
368            // Find a position within slop distance
369            let found = positions.iter().any(|&pos| {
370                if self.slop == 0 {
371                    pos == expected_pos
372                } else {
373                    let diff = pos.abs_diff(expected_pos);
374                    diff <= self.slop
375                }
376            });
377
378            if !found {
379                return false;
380            }
381        }
382
383        true
384    }
385}
386
387impl super::docset::DocSet for PhraseScorer {
388    fn doc(&self) -> DocId {
389        self.current_doc
390    }
391
392    fn advance(&mut self) -> DocId {
393        if self.current_doc == TERMINATED {
394            return TERMINATED;
395        }
396
397        self.posting_iters[0].advance();
398        self.find_next_phrase_match();
399        self.current_doc
400    }
401
402    fn seek(&mut self, target: DocId) -> DocId {
403        if target == TERMINATED {
404            self.current_doc = TERMINATED;
405            return TERMINATED;
406        }
407
408        self.posting_iters[0].seek(target);
409        self.find_next_phrase_match();
410        self.current_doc
411    }
412
413    fn size_hint(&self) -> u32 {
414        0
415    }
416}
417
418impl Scorer for PhraseScorer {
419    fn score(&self) -> Score {
420        if self.current_doc == TERMINATED {
421            return 0.0;
422        }
423
424        // Sum term frequencies for BM25 scoring
425        let tf: f32 = self
426            .posting_iters
427            .iter()
428            .map(|it| it.term_freq() as f32)
429            .sum();
430
431        // Phrase matches get a boost since they're more precise
432        super::bm25_score(tf, self.idf, tf, self.avg_field_len) * 1.5
433    }
434}