Skip to main content

hermes_core/query/
phrase.rs

1//! Phrase query - matches documents containing terms in consecutive positions
2
3use std::sync::Arc;
4
5use crate::dsl::Field;
6use crate::segment::SegmentReader;
7use crate::structures::{BlockPostingIterator, BlockPostingList, PositionPostingList, TERMINATED};
8use crate::{DocId, Score};
9
10use super::{CountFuture, EmptyScorer, GlobalStats, Query, Scorer, ScorerFuture};
11
12/// Phrase query - matches documents containing terms in consecutive positions
13///
14/// Example: "quick brown fox" matches only if all three terms appear
15/// consecutively in the document.
16#[derive(Clone)]
17pub struct PhraseQuery {
18    pub field: Field,
19    /// Terms in the phrase, in order
20    pub terms: Vec<Vec<u8>>,
21    /// Optional slop (max distance between terms, 0 = exact phrase)
22    pub slop: u32,
23    /// Optional global statistics for cross-segment IDF
24    global_stats: Option<Arc<GlobalStats>>,
25}
26
27impl std::fmt::Display for PhraseQuery {
28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29        let terms: Vec<_> = self
30            .terms
31            .iter()
32            .map(|t| String::from_utf8_lossy(t))
33            .collect();
34        write!(f, "Phrase({}:\"{}\"", self.field.0, terms.join(" "))?;
35        if self.slop > 0 {
36            write!(f, "~{}", self.slop)?;
37        }
38        write!(f, ")")
39    }
40}
41
42impl std::fmt::Debug for PhraseQuery {
43    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44        let terms: Vec<_> = self
45            .terms
46            .iter()
47            .map(|t| String::from_utf8_lossy(t).to_string())
48            .collect();
49        f.debug_struct("PhraseQuery")
50            .field("field", &self.field)
51            .field("terms", &terms)
52            .field("slop", &self.slop)
53            .finish()
54    }
55}
56
57impl PhraseQuery {
58    /// Create a new exact phrase query
59    pub fn new(field: Field, terms: Vec<Vec<u8>>) -> Self {
60        Self {
61            field,
62            terms,
63            slop: 0,
64            global_stats: None,
65        }
66    }
67
68    /// Create from text (splits on whitespace and lowercases)
69    pub fn text(field: Field, phrase: &str) -> Self {
70        let terms: Vec<Vec<u8>> = phrase
71            .split_whitespace()
72            .map(|w| w.to_lowercase().into_bytes())
73            .collect();
74        Self {
75            field,
76            terms,
77            slop: 0,
78            global_stats: None,
79        }
80    }
81
82    /// Set slop (max distance between terms)
83    pub fn with_slop(mut self, slop: u32) -> Self {
84        self.slop = slop;
85        self
86    }
87
88    /// Set global statistics for cross-segment IDF
89    pub fn with_global_stats(mut self, stats: Arc<GlobalStats>) -> Self {
90        self.global_stats = Some(stats);
91        self
92    }
93}
94
95impl Query for PhraseQuery {
96    fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
97        let field = self.field;
98        let terms = self.terms.clone();
99        let slop = self.slop;
100        let _global_stats = self.global_stats.clone();
101
102        Box::pin(async move {
103            if terms.is_empty() {
104                return Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>);
105            }
106
107            // Single term - delegate to TermQuery
108            if terms.len() == 1 {
109                let term_query = super::TermQuery::new(field, terms[0].clone());
110                return term_query.scorer(reader, limit).await;
111            }
112
113            // Check if positions are available
114            if !reader.has_positions(field) {
115                // Fall back to AND query (BooleanQuery with MUST clauses)
116                let mut bool_query = super::BooleanQuery::new();
117                for term in &terms {
118                    bool_query = bool_query.must(super::TermQuery::new(field, term.clone()));
119                }
120                return bool_query.scorer(reader, limit).await;
121            }
122
123            // Load postings and positions for all terms (parallel per term)
124            let mut term_postings: Vec<BlockPostingList> = Vec::with_capacity(terms.len());
125            let mut term_positions: Vec<PositionPostingList> = Vec::with_capacity(terms.len());
126
127            for term in &terms {
128                // Fetch postings and positions in parallel
129                let (postings, positions) = futures::join!(
130                    reader.get_postings(field, term),
131                    reader.get_positions(field, term)
132                );
133
134                match (postings?, positions?) {
135                    (Some(p), Some(pos)) => {
136                        term_postings.push(p);
137                        term_positions.push(pos);
138                    }
139                    _ => {
140                        // If any term is missing, no documents can match
141                        return Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>);
142                    }
143                }
144            }
145
146            // Compute combined IDF (sum of individual IDFs)
147            let idf: f32 = term_postings
148                .iter()
149                .map(|p| {
150                    let num_docs = reader.num_docs() as f32;
151                    let doc_freq = p.doc_count() as f32;
152                    super::bm25_idf(doc_freq, num_docs)
153                })
154                .sum();
155
156            let avg_field_len = reader.avg_field_len(field);
157
158            Ok(Box::new(PhraseScorer::new(
159                term_postings,
160                term_positions,
161                slop,
162                idf,
163                avg_field_len,
164            )) as Box<dyn Scorer + 'a>)
165        })
166    }
167
168    #[cfg(feature = "sync")]
169    fn scorer_sync<'a>(
170        &self,
171        reader: &'a SegmentReader,
172        limit: usize,
173    ) -> crate::Result<Box<dyn Scorer + 'a>> {
174        if self.terms.is_empty() {
175            return Ok(Box::new(super::EmptyScorer) as Box<dyn Scorer + 'a>);
176        }
177
178        if self.terms.len() == 1 {
179            let term_query = super::TermQuery::new(self.field, self.terms[0].clone());
180            return term_query.scorer_sync(reader, limit);
181        }
182
183        if !reader.has_positions(self.field) {
184            let mut bool_query = super::BooleanQuery::new();
185            for term in &self.terms {
186                bool_query = bool_query.must(super::TermQuery::new(self.field, term.clone()));
187            }
188            return bool_query.scorer_sync(reader, limit);
189        }
190
191        let mut term_postings: Vec<BlockPostingList> = Vec::with_capacity(self.terms.len());
192        let mut term_positions: Vec<PositionPostingList> = Vec::with_capacity(self.terms.len());
193
194        for term in &self.terms {
195            let postings = reader.get_postings_sync(self.field, term)?;
196            let positions = reader.get_positions_sync(self.field, term)?;
197
198            match (postings, positions) {
199                (Some(p), Some(pos)) => {
200                    term_postings.push(p);
201                    term_positions.push(pos);
202                }
203                _ => return Ok(Box::new(super::EmptyScorer) as Box<dyn Scorer + 'a>),
204            }
205        }
206
207        let idf: f32 = term_postings
208            .iter()
209            .map(|p| {
210                let num_docs = reader.num_docs() as f32;
211                let doc_freq = p.doc_count() as f32;
212                super::bm25_idf(doc_freq, num_docs)
213            })
214            .sum();
215
216        let avg_field_len = reader.avg_field_len(self.field);
217
218        Ok(Box::new(PhraseScorer::new(
219            term_postings,
220            term_positions,
221            self.slop,
222            idf,
223            avg_field_len,
224        )) as Box<dyn Scorer + 'a>)
225    }
226
227    fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
228        let field = self.field;
229        let terms = self.terms.clone();
230
231        Box::pin(async move {
232            if terms.is_empty() {
233                return Ok(0);
234            }
235
236            // Estimate based on minimum posting list size
237            let mut min_count = u32::MAX;
238            for term in &terms {
239                match reader.get_postings(field, term).await? {
240                    Some(list) => min_count = min_count.min(list.doc_count()),
241                    None => return Ok(0),
242                }
243            }
244
245            // Phrase matching will typically match fewer docs than the minimum
246            // Estimate ~10% of the smallest posting list
247            Ok((min_count / 10).max(1))
248        })
249    }
250}
251
252/// Scorer that checks phrase positions
253struct PhraseScorer {
254    /// Posting iterators for each term
255    posting_iters: Vec<BlockPostingIterator<'static>>,
256    /// Position iterators for each term
257    position_lists: Vec<PositionPostingList>,
258    /// Max slop between terms
259    slop: u32,
260    /// Current matching document
261    current_doc: DocId,
262    /// Combined IDF
263    idf: f32,
264    /// Average field length
265    avg_field_len: f32,
266}
267
268impl PhraseScorer {
269    fn new(
270        posting_lists: Vec<BlockPostingList>,
271        position_lists: Vec<PositionPostingList>,
272        slop: u32,
273        idf: f32,
274        avg_field_len: f32,
275    ) -> Self {
276        let posting_iters: Vec<_> = posting_lists
277            .into_iter()
278            .map(|p| p.into_iterator())
279            .collect();
280
281        let mut scorer = Self {
282            posting_iters,
283            position_lists,
284            slop,
285            current_doc: 0,
286            idf,
287            avg_field_len,
288        };
289
290        scorer.find_next_phrase_match();
291        scorer
292    }
293
294    /// Find next document where all terms appear as a phrase
295    fn find_next_phrase_match(&mut self) {
296        loop {
297            // First, find a document where all terms appear (AND semantics)
298            let doc = self.find_next_and_match();
299            if doc == TERMINATED {
300                self.current_doc = TERMINATED;
301                return;
302            }
303
304            // Check if positions form a valid phrase
305            if self.check_phrase_positions(doc) {
306                self.current_doc = doc;
307                return;
308            }
309
310            // Advance and try again
311            self.posting_iters[0].advance();
312        }
313    }
314
315    /// Find next document where all terms appear
316    fn find_next_and_match(&mut self) -> DocId {
317        if self.posting_iters.is_empty() {
318            return TERMINATED;
319        }
320
321        loop {
322            let max_doc = self.posting_iters.iter().map(|it| it.doc()).max().unwrap();
323
324            if max_doc == TERMINATED {
325                return TERMINATED;
326            }
327
328            let mut all_match = true;
329            for it in &mut self.posting_iters {
330                let doc = it.seek(max_doc);
331                if doc != max_doc {
332                    all_match = false;
333                    if doc == TERMINATED {
334                        return TERMINATED;
335                    }
336                }
337            }
338
339            if all_match {
340                return max_doc;
341            }
342        }
343    }
344
345    /// Check if positions form a valid phrase for the given document
346    fn check_phrase_positions(&self, doc_id: DocId) -> bool {
347        // Get positions for each term in this document
348        let mut term_positions: Vec<Vec<u32>> = Vec::with_capacity(self.position_lists.len());
349
350        for pos_list in &self.position_lists {
351            match pos_list.get_positions(doc_id) {
352                Some(positions) => term_positions.push(positions.to_vec()),
353                None => return false,
354            }
355        }
356
357        // Check for consecutive positions
358        // For exact phrase (slop=0), position[i+1] = position[i] + 1
359        self.find_phrase_match(&term_positions)
360    }
361
362    /// Find if there's a valid phrase match among the positions
363    fn find_phrase_match(&self, term_positions: &[Vec<u32>]) -> bool {
364        if term_positions.is_empty() {
365            return false;
366        }
367
368        // For each position of the first term, check if subsequent terms
369        // have positions that form a phrase
370        for &first_pos in &term_positions[0] {
371            if self.check_phrase_from_position(first_pos, term_positions) {
372                return true;
373            }
374        }
375
376        false
377    }
378
379    /// Check if a phrase exists starting from the given position
380    fn check_phrase_from_position(&self, start_pos: u32, term_positions: &[Vec<u32>]) -> bool {
381        let mut expected_pos = start_pos;
382
383        for (i, positions) in term_positions.iter().enumerate() {
384            if i == 0 {
385                continue; // Skip first term, already matched
386            }
387
388            expected_pos += 1;
389
390            // Find a position within slop distance
391            let found = positions.iter().any(|&pos| {
392                if self.slop == 0 {
393                    pos == expected_pos
394                } else {
395                    let diff = pos.abs_diff(expected_pos);
396                    diff <= self.slop
397                }
398            });
399
400            if !found {
401                return false;
402            }
403        }
404
405        true
406    }
407}
408
409impl super::docset::DocSet for PhraseScorer {
410    fn doc(&self) -> DocId {
411        self.current_doc
412    }
413
414    fn advance(&mut self) -> DocId {
415        if self.current_doc == TERMINATED {
416            return TERMINATED;
417        }
418
419        self.posting_iters[0].advance();
420        self.find_next_phrase_match();
421        self.current_doc
422    }
423
424    fn seek(&mut self, target: DocId) -> DocId {
425        if target == TERMINATED {
426            self.current_doc = TERMINATED;
427            return TERMINATED;
428        }
429
430        self.posting_iters[0].seek(target);
431        self.find_next_phrase_match();
432        self.current_doc
433    }
434
435    fn size_hint(&self) -> u32 {
436        0
437    }
438}
439
440impl Scorer for PhraseScorer {
441    fn score(&self) -> Score {
442        if self.current_doc == TERMINATED {
443            return 0.0;
444        }
445
446        // Sum term frequencies for BM25 scoring
447        let tf: f32 = self
448            .posting_iters
449            .iter()
450            .map(|it| it.term_freq() as f32)
451            .sum();
452
453        // Phrase matches get a boost since they're more precise
454        super::bm25_score(tf, self.idf, tf, self.avg_field_len) * 1.5
455    }
456}