Skip to main content

hermes_core/query/
phrase.rs

1//! Phrase query - matches documents containing terms in consecutive positions
2
3use std::sync::Arc;
4
5use crate::dsl::Field;
6use crate::segment::SegmentReader;
7use crate::structures::{BlockPostingIterator, BlockPostingList, PositionPostingList, TERMINATED};
8use crate::{DocId, Score};
9
10use super::{CountFuture, EmptyScorer, GlobalStats, Query, Scorer, ScorerFuture};
11
12/// Phrase query - matches documents containing terms in consecutive positions
13///
14/// Example: "quick brown fox" matches only if all three terms appear
15/// consecutively in the document.
16#[derive(Clone)]
17pub struct PhraseQuery {
18    pub field: Field,
19    /// Terms in the phrase, in order
20    pub terms: Vec<Vec<u8>>,
21    /// Optional slop (max distance between terms, 0 = exact phrase)
22    pub slop: u32,
23    /// Optional global statistics for cross-segment IDF
24    global_stats: Option<Arc<GlobalStats>>,
25}
26
27impl std::fmt::Debug for PhraseQuery {
28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29        let terms: Vec<_> = self
30            .terms
31            .iter()
32            .map(|t| String::from_utf8_lossy(t).to_string())
33            .collect();
34        f.debug_struct("PhraseQuery")
35            .field("field", &self.field)
36            .field("terms", &terms)
37            .field("slop", &self.slop)
38            .finish()
39    }
40}
41
42impl PhraseQuery {
43    /// Create a new exact phrase query
44    pub fn new(field: Field, terms: Vec<Vec<u8>>) -> Self {
45        Self {
46            field,
47            terms,
48            slop: 0,
49            global_stats: None,
50        }
51    }
52
53    /// Create from text (splits on whitespace and lowercases)
54    pub fn text(field: Field, phrase: &str) -> Self {
55        let terms: Vec<Vec<u8>> = phrase
56            .split_whitespace()
57            .map(|w| w.to_lowercase().into_bytes())
58            .collect();
59        Self {
60            field,
61            terms,
62            slop: 0,
63            global_stats: None,
64        }
65    }
66
67    /// Set slop (max distance between terms)
68    pub fn with_slop(mut self, slop: u32) -> Self {
69        self.slop = slop;
70        self
71    }
72
73    /// Set global statistics for cross-segment IDF
74    pub fn with_global_stats(mut self, stats: Arc<GlobalStats>) -> Self {
75        self.global_stats = Some(stats);
76        self
77    }
78}
79
80impl Query for PhraseQuery {
81    fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
82        let field = self.field;
83        let terms = self.terms.clone();
84        let slop = self.slop;
85        let _global_stats = self.global_stats.clone();
86
87        Box::pin(async move {
88            if terms.is_empty() {
89                return Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>);
90            }
91
92            // Single term - delegate to TermQuery
93            if terms.len() == 1 {
94                let term_query = super::TermQuery::new(field, terms[0].clone());
95                return term_query.scorer(reader, limit).await;
96            }
97
98            // Check if positions are available
99            if !reader.has_positions(field) {
100                // Fall back to AND query (BooleanQuery with MUST clauses)
101                let mut bool_query = super::BooleanQuery::new();
102                for term in &terms {
103                    bool_query = bool_query.must(super::TermQuery::new(field, term.clone()));
104                }
105                return bool_query.scorer(reader, limit).await;
106            }
107
108            // Load postings and positions for all terms (parallel per term)
109            let mut term_postings: Vec<BlockPostingList> = Vec::with_capacity(terms.len());
110            let mut term_positions: Vec<PositionPostingList> = Vec::with_capacity(terms.len());
111
112            for term in &terms {
113                // Fetch postings and positions in parallel
114                let (postings, positions) = futures::join!(
115                    reader.get_postings(field, term),
116                    reader.get_positions(field, term)
117                );
118
119                match (postings?, positions?) {
120                    (Some(p), Some(pos)) => {
121                        term_postings.push(p);
122                        term_positions.push(pos);
123                    }
124                    _ => {
125                        // If any term is missing, no documents can match
126                        return Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>);
127                    }
128                }
129            }
130
131            // Compute combined IDF (sum of individual IDFs)
132            let idf: f32 = term_postings
133                .iter()
134                .map(|p| {
135                    let num_docs = reader.num_docs() as f32;
136                    let doc_freq = p.doc_count() as f32;
137                    super::bm25_idf(doc_freq, num_docs)
138                })
139                .sum();
140
141            let avg_field_len = reader.avg_field_len(field);
142
143            Ok(Box::new(PhraseScorer::new(
144                term_postings,
145                term_positions,
146                slop,
147                idf,
148                avg_field_len,
149            )) as Box<dyn Scorer + 'a>)
150        })
151    }
152
153    fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
154        let field = self.field;
155        let terms = self.terms.clone();
156
157        Box::pin(async move {
158            if terms.is_empty() {
159                return Ok(0);
160            }
161
162            // Estimate based on minimum posting list size
163            let mut min_count = u32::MAX;
164            for term in &terms {
165                match reader.get_postings(field, term).await? {
166                    Some(list) => min_count = min_count.min(list.doc_count()),
167                    None => return Ok(0),
168                }
169            }
170
171            // Phrase matching will typically match fewer docs than the minimum
172            // Estimate ~10% of the smallest posting list
173            Ok((min_count / 10).max(1))
174        })
175    }
176}
177
178/// Scorer that checks phrase positions
179struct PhraseScorer {
180    /// Posting iterators for each term
181    posting_iters: Vec<BlockPostingIterator<'static>>,
182    /// Position iterators for each term
183    position_lists: Vec<PositionPostingList>,
184    /// Max slop between terms
185    slop: u32,
186    /// Current matching document
187    current_doc: DocId,
188    /// Combined IDF
189    idf: f32,
190    /// Average field length
191    avg_field_len: f32,
192}
193
194impl PhraseScorer {
195    fn new(
196        posting_lists: Vec<BlockPostingList>,
197        position_lists: Vec<PositionPostingList>,
198        slop: u32,
199        idf: f32,
200        avg_field_len: f32,
201    ) -> Self {
202        let posting_iters: Vec<_> = posting_lists
203            .into_iter()
204            .map(|p| p.into_iterator())
205            .collect();
206
207        let mut scorer = Self {
208            posting_iters,
209            position_lists,
210            slop,
211            current_doc: 0,
212            idf,
213            avg_field_len,
214        };
215
216        scorer.find_next_phrase_match();
217        scorer
218    }
219
220    /// Find next document where all terms appear as a phrase
221    fn find_next_phrase_match(&mut self) {
222        loop {
223            // First, find a document where all terms appear (AND semantics)
224            let doc = self.find_next_and_match();
225            if doc == TERMINATED {
226                self.current_doc = TERMINATED;
227                return;
228            }
229
230            // Check if positions form a valid phrase
231            if self.check_phrase_positions(doc) {
232                self.current_doc = doc;
233                return;
234            }
235
236            // Advance and try again
237            self.posting_iters[0].advance();
238        }
239    }
240
241    /// Find next document where all terms appear
242    fn find_next_and_match(&mut self) -> DocId {
243        if self.posting_iters.is_empty() {
244            return TERMINATED;
245        }
246
247        loop {
248            let max_doc = self.posting_iters.iter().map(|it| it.doc()).max().unwrap();
249
250            if max_doc == TERMINATED {
251                return TERMINATED;
252            }
253
254            let mut all_match = true;
255            for it in &mut self.posting_iters {
256                let doc = it.seek(max_doc);
257                if doc != max_doc {
258                    all_match = false;
259                    if doc == TERMINATED {
260                        return TERMINATED;
261                    }
262                }
263            }
264
265            if all_match {
266                return max_doc;
267            }
268        }
269    }
270
271    /// Check if positions form a valid phrase for the given document
272    fn check_phrase_positions(&self, doc_id: DocId) -> bool {
273        // Get positions for each term in this document
274        let mut term_positions: Vec<Vec<u32>> = Vec::with_capacity(self.position_lists.len());
275
276        for pos_list in &self.position_lists {
277            match pos_list.get_positions(doc_id) {
278                Some(positions) => term_positions.push(positions.to_vec()),
279                None => return false,
280            }
281        }
282
283        // Check for consecutive positions
284        // For exact phrase (slop=0), position[i+1] = position[i] + 1
285        self.find_phrase_match(&term_positions)
286    }
287
288    /// Find if there's a valid phrase match among the positions
289    fn find_phrase_match(&self, term_positions: &[Vec<u32>]) -> bool {
290        if term_positions.is_empty() {
291            return false;
292        }
293
294        // For each position of the first term, check if subsequent terms
295        // have positions that form a phrase
296        for &first_pos in &term_positions[0] {
297            if self.check_phrase_from_position(first_pos, term_positions) {
298                return true;
299            }
300        }
301
302        false
303    }
304
305    /// Check if a phrase exists starting from the given position
306    fn check_phrase_from_position(&self, start_pos: u32, term_positions: &[Vec<u32>]) -> bool {
307        let mut expected_pos = start_pos;
308
309        for (i, positions) in term_positions.iter().enumerate() {
310            if i == 0 {
311                continue; // Skip first term, already matched
312            }
313
314            expected_pos += 1;
315
316            // Find a position within slop distance
317            let found = positions.iter().any(|&pos| {
318                if self.slop == 0 {
319                    pos == expected_pos
320                } else {
321                    let diff = pos.abs_diff(expected_pos);
322                    diff <= self.slop
323                }
324            });
325
326            if !found {
327                return false;
328            }
329        }
330
331        true
332    }
333}
334
335impl Scorer for PhraseScorer {
336    fn doc(&self) -> DocId {
337        self.current_doc
338    }
339
340    fn score(&self) -> Score {
341        if self.current_doc == TERMINATED {
342            return 0.0;
343        }
344
345        // Sum term frequencies for BM25 scoring
346        let tf: f32 = self
347            .posting_iters
348            .iter()
349            .map(|it| it.term_freq() as f32)
350            .sum();
351
352        // Phrase matches get a boost since they're more precise
353        super::bm25_score(tf, self.idf, tf, self.avg_field_len) * 1.5
354    }
355
356    fn advance(&mut self) -> DocId {
357        if self.current_doc == TERMINATED {
358            return TERMINATED;
359        }
360
361        self.posting_iters[0].advance();
362        self.find_next_phrase_match();
363        self.current_doc
364    }
365
366    fn seek(&mut self, target: DocId) -> DocId {
367        if target == TERMINATED {
368            self.current_doc = TERMINATED;
369            return TERMINATED;
370        }
371
372        self.posting_iters[0].seek(target);
373        self.find_next_phrase_match();
374        self.current_doc
375    }
376
377    fn size_hint(&self) -> u32 {
378        0
379    }
380}