// hermes_core/query/wand_or.rs

//! WAND-optimized OR query for efficient multi-term full-text search.
//!
//! Uses the MaxScore WAND algorithm for efficient top-k retrieval when
//! searching for documents matching any of multiple terms.
5
6use std::sync::Arc;
7
8use crate::dsl::Field;
9use crate::segment::SegmentReader;
10use crate::{DocId, Score};
11
12use super::{
13    CountFuture, GlobalStats, Query, ScoredDoc, Scorer, ScorerFuture, TextTermScorer, WandExecutor,
14};
15
16/// WAND-optimized OR query for multiple terms
17///
18/// More efficient than `BooleanQuery` with SHOULD clauses for top-k retrieval
19/// because it uses MaxScore pruning to skip low-scoring documents.
20///
21/// # Example
22/// ```ignore
23/// let query = WandOrQuery::new(field)
24///     .term("hello")
25///     .term("world");
26/// let results = index.search(&query, 10).await?;
27/// ```
28#[derive(Clone)]
29pub struct WandOrQuery {
30    /// Field to search
31    pub field: Field,
32    /// Terms to search for (OR semantics)
33    pub terms: Vec<String>,
34    /// Optional global statistics for cross-segment IDF
35    global_stats: Option<Arc<GlobalStats>>,
36}
37
38impl std::fmt::Debug for WandOrQuery {
39    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40        f.debug_struct("WandOrQuery")
41            .field("field", &self.field)
42            .field("terms", &self.terms)
43            .field("has_global_stats", &self.global_stats.is_some())
44            .finish()
45    }
46}
47
48impl WandOrQuery {
49    /// Create a new WAND OR query for a field
50    pub fn new(field: Field) -> Self {
51        Self {
52            field,
53            terms: Vec::new(),
54            global_stats: None,
55        }
56    }
57
58    /// Add a term to the OR query
59    pub fn term(mut self, term: impl Into<String>) -> Self {
60        self.terms.push(term.into().to_lowercase());
61        self
62    }
63
64    /// Add multiple terms
65    pub fn terms(mut self, terms: impl IntoIterator<Item = impl Into<String>>) -> Self {
66        for t in terms {
67            self.terms.push(t.into().to_lowercase());
68        }
69        self
70    }
71
72    /// Set global statistics for cross-segment IDF
73    pub fn with_global_stats(mut self, stats: Arc<GlobalStats>) -> Self {
74        self.global_stats = Some(stats);
75        self
76    }
77}
78
79impl Query for WandOrQuery {
80    fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
81        let field = self.field;
82        let terms = self.terms.clone();
83        let global_stats = self.global_stats.clone();
84
85        Box::pin(async move {
86            let mut scorers: Vec<TextTermScorer> = Vec::with_capacity(terms.len());
87
88            // Get avg field length (from global stats or segment)
89            let avg_field_len = global_stats
90                .as_ref()
91                .map(|s| s.avg_field_len(field))
92                .unwrap_or_else(|| reader.avg_field_len(field));
93
94            let num_docs = reader.num_docs() as f32;
95
96            for term in &terms {
97                let term_bytes = term.as_bytes();
98
99                if let Some(posting_list) = reader.get_postings(field, term_bytes).await? {
100                    // Compute IDF
101                    let doc_freq = posting_list.doc_count() as f32;
102                    let idf = if let Some(ref stats) = global_stats {
103                        let global_idf = stats.text_idf(field, term);
104                        if global_idf > 0.0 {
105                            global_idf
106                        } else {
107                            super::bm25_idf(doc_freq, num_docs)
108                        }
109                    } else {
110                        super::bm25_idf(doc_freq, num_docs)
111                    };
112
113                    scorers.push(TextTermScorer::new(posting_list, idf, avg_field_len));
114                }
115            }
116
117            if scorers.is_empty() {
118                return Ok(Box::new(EmptyWandScorer) as Box<dyn Scorer + 'a>);
119            }
120
121            // Use WAND executor for efficient top-k
122            let results = WandExecutor::new(scorers, limit).execute();
123
124            Ok(Box::new(WandResultScorer::new(results)) as Box<dyn Scorer + 'a>)
125        })
126    }
127
128    fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
129        let field = self.field;
130        let terms = self.terms.clone();
131
132        Box::pin(async move {
133            let mut sum = 0u32;
134            for term in &terms {
135                if let Some(posting_list) = reader.get_postings(field, term.as_bytes()).await? {
136                    sum += posting_list.doc_count();
137                }
138            }
139            Ok(sum)
140        })
141    }
142}
143
144/// Scorer that iterates over pre-computed WAND results
145struct WandResultScorer {
146    results: Vec<ScoredDoc>,
147    position: usize,
148}
149
150impl WandResultScorer {
151    fn new(results: Vec<ScoredDoc>) -> Self {
152        Self {
153            results,
154            position: 0,
155        }
156    }
157}
158
159impl Scorer for WandResultScorer {
160    fn doc(&self) -> DocId {
161        if self.position < self.results.len() {
162            self.results[self.position].doc_id
163        } else {
164            crate::structures::TERMINATED
165        }
166    }
167
168    fn score(&self) -> Score {
169        if self.position < self.results.len() {
170            self.results[self.position].score
171        } else {
172            0.0
173        }
174    }
175
176    fn advance(&mut self) -> DocId {
177        self.position += 1;
178        self.doc()
179    }
180
181    fn seek(&mut self, target: DocId) -> DocId {
182        while self.position < self.results.len() && self.results[self.position].doc_id < target {
183            self.position += 1;
184        }
185        self.doc()
186    }
187
188    fn size_hint(&self) -> u32 {
189        self.results.len() as u32
190    }
191}
192
193/// Empty scorer for when no terms match
194struct EmptyWandScorer;
195
196impl Scorer for EmptyWandScorer {
197    fn doc(&self) -> DocId {
198        crate::structures::TERMINATED
199    }
200
201    fn score(&self) -> Score {
202        0.0
203    }
204
205    fn advance(&mut self) -> DocId {
206        crate::structures::TERMINATED
207    }
208
209    fn seek(&mut self, _target: DocId) -> DocId {
210        crate::structures::TERMINATED
211    }
212
213    fn size_hint(&self) -> u32 {
214        0
215    }
216}
217
#[cfg(test)]
mod tests {
    use super::*;

    /// Terms added through `term` and `terms` are kept in insertion order.
    #[test]
    fn test_wand_or_query_builder() {
        let built = WandOrQuery::new(Field(0))
            .term("hello")
            .term("world")
            .terms(vec!["foo", "bar"]);

        // One comparison checks both the length and every element.
        assert_eq!(built.terms, ["hello", "world", "foo", "bar"]);
    }
}