// hermes_core/query/wand_or.rs

//! WAND-optimized OR query for efficient multi-term full-text search.
//!
//! Uses MaxScore-style dynamic pruning (via `WandExecutor`) for efficient
//! top-k retrieval when searching for documents that match any of several
//! terms.

6use std::sync::Arc;
7
8use crate::dsl::Field;
9use crate::segment::SegmentReader;
10use crate::{DocId, Score};
11
12use super::{
13    CountFuture, GlobalStats, Query, ScoredDoc, Scorer, ScorerFuture, TextTermScorer, WandExecutor,
14};
15
16/// WAND-optimized OR query for multiple terms
17///
18/// More efficient than `BooleanQuery` with SHOULD clauses for top-k retrieval
19/// because it uses MaxScore pruning to skip low-scoring documents.
20///
21/// # Example
22/// ```ignore
23/// let query = WandOrQuery::new(field)
24///     .term("hello")
25///     .term("world");
26/// let results = index.search(&query, 10).await?;
27/// ```
28#[derive(Clone)]
29pub struct WandOrQuery {
30    /// Field to search
31    pub field: Field,
32    /// Terms to search for (OR semantics)
33    pub terms: Vec<String>,
34    /// Optional global statistics for cross-segment IDF
35    global_stats: Option<Arc<GlobalStats>>,
36}
37
38impl std::fmt::Debug for WandOrQuery {
39    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40        f.debug_struct("WandOrQuery")
41            .field("field", &self.field)
42            .field("terms", &self.terms)
43            .field("has_global_stats", &self.global_stats.is_some())
44            .finish()
45    }
46}
47
48impl WandOrQuery {
49    /// Create a new WAND OR query for a field
50    pub fn new(field: Field) -> Self {
51        Self {
52            field,
53            terms: Vec::new(),
54            global_stats: None,
55        }
56    }
57
58    /// Add a term to the OR query
59    pub fn term(mut self, term: impl Into<String>) -> Self {
60        self.terms.push(term.into().to_lowercase());
61        self
62    }
63
64    /// Add multiple terms
65    pub fn terms(mut self, terms: impl IntoIterator<Item = impl Into<String>>) -> Self {
66        for t in terms {
67            self.terms.push(t.into().to_lowercase());
68        }
69        self
70    }
71
72    /// Set global statistics for cross-segment IDF
73    pub fn with_global_stats(mut self, stats: Arc<GlobalStats>) -> Self {
74        self.global_stats = Some(stats);
75        self
76    }
77}
78
79impl Query for WandOrQuery {
80    fn scorer<'a>(&'a self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
81        Box::pin(async move {
82            let mut scorers: Vec<TextTermScorer> = Vec::with_capacity(self.terms.len());
83
84            // Get avg field length (from global stats or segment)
85            let avg_field_len = self
86                .global_stats
87                .as_ref()
88                .map(|s| s.avg_field_len(self.field))
89                .unwrap_or_else(|| reader.avg_field_len(self.field));
90
91            let num_docs = reader.num_docs() as f32;
92
93            for term in &self.terms {
94                let term_bytes = term.as_bytes();
95
96                if let Some(posting_list) = reader.get_postings(self.field, term_bytes).await? {
97                    // Compute IDF
98                    let doc_freq = posting_list.doc_count() as f32;
99                    let idf = if let Some(ref stats) = self.global_stats {
100                        let global_idf = stats.text_idf(self.field, term);
101                        if global_idf > 0.0 {
102                            global_idf
103                        } else {
104                            super::bm25_idf(doc_freq, num_docs)
105                        }
106                    } else {
107                        super::bm25_idf(doc_freq, num_docs)
108                    };
109
110                    scorers.push(TextTermScorer::new(posting_list, idf, avg_field_len));
111                }
112            }
113
114            if scorers.is_empty() {
115                return Ok(Box::new(EmptyWandScorer) as Box<dyn Scorer + 'a>);
116            }
117
118            // Use WAND executor for efficient top-k
119            let results = WandExecutor::new(scorers, limit).execute();
120
121            Ok(Box::new(WandResultScorer::new(results)) as Box<dyn Scorer + 'a>)
122        })
123    }
124
125    fn count_estimate<'a>(&'a self, reader: &'a SegmentReader) -> CountFuture<'a> {
126        Box::pin(async move {
127            let mut sum = 0u32;
128            for term in &self.terms {
129                if let Some(posting_list) = reader.get_postings(self.field, term.as_bytes()).await?
130                {
131                    sum += posting_list.doc_count();
132                }
133            }
134            Ok(sum)
135        })
136    }
137}
138
139/// Scorer that iterates over pre-computed WAND results
140struct WandResultScorer {
141    results: Vec<ScoredDoc>,
142    position: usize,
143}
144
145impl WandResultScorer {
146    fn new(results: Vec<ScoredDoc>) -> Self {
147        Self {
148            results,
149            position: 0,
150        }
151    }
152}
153
154impl Scorer for WandResultScorer {
155    fn doc(&self) -> DocId {
156        if self.position < self.results.len() {
157            self.results[self.position].doc_id
158        } else {
159            crate::structures::TERMINATED
160        }
161    }
162
163    fn score(&self) -> Score {
164        if self.position < self.results.len() {
165            self.results[self.position].score
166        } else {
167            0.0
168        }
169    }
170
171    fn advance(&mut self) -> DocId {
172        self.position += 1;
173        self.doc()
174    }
175
176    fn seek(&mut self, target: DocId) -> DocId {
177        while self.position < self.results.len() && self.results[self.position].doc_id < target {
178            self.position += 1;
179        }
180        self.doc()
181    }
182
183    fn size_hint(&self) -> u32 {
184        self.results.len() as u32
185    }
186}
187
188/// Empty scorer for when no terms match
189struct EmptyWandScorer;
190
191impl Scorer for EmptyWandScorer {
192    fn doc(&self) -> DocId {
193        crate::structures::TERMINATED
194    }
195
196    fn score(&self) -> Score {
197        0.0
198    }
199
200    fn advance(&mut self) -> DocId {
201        crate::structures::TERMINATED
202    }
203
204    fn seek(&mut self, _target: DocId) -> DocId {
205        crate::structures::TERMINATED
206    }
207
208    fn size_hint(&self) -> u32 {
209        0
210    }
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_wand_or_query_builder() {
        let query = WandOrQuery::new(Field(0))
            .term("hello")
            .term("world")
            .terms(vec!["foo", "bar"]);

        // Insertion order is preserved across `term` and `terms`.
        assert_eq!(query.terms, ["hello", "world", "foo", "bar"]);
    }

    #[test]
    fn test_wand_or_query_lowercases_terms() {
        let query = WandOrQuery::new(Field(0))
            .term("Hello")
            .terms(vec!["WORLD", "MiXeD"]);

        // Both builder entry points lowercase before storing.
        assert_eq!(query.terms, ["hello", "world", "mixed"]);
    }

    #[test]
    fn test_wand_or_query_defaults() {
        let query = WandOrQuery::new(Field(0));

        assert!(query.terms.is_empty());
        assert!(query.global_stats.is_none());
    }
}
230}