hermes_core/query/
wand_or.rs1use std::sync::Arc;
7
8use crate::dsl::Field;
9use crate::segment::SegmentReader;
10use crate::{DocId, Score};
11
12use super::{
13 CountFuture, GlobalStats, Query, ScoredDoc, Scorer, ScorerFuture, TextTermScorer, WandExecutor,
14};
15
16#[derive(Clone)]
29pub struct WandOrQuery {
30 pub field: Field,
32 pub terms: Vec<String>,
34 global_stats: Option<Arc<GlobalStats>>,
36}
37
38impl std::fmt::Debug for WandOrQuery {
39 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40 f.debug_struct("WandOrQuery")
41 .field("field", &self.field)
42 .field("terms", &self.terms)
43 .field("has_global_stats", &self.global_stats.is_some())
44 .finish()
45 }
46}
47
48impl WandOrQuery {
49 pub fn new(field: Field) -> Self {
51 Self {
52 field,
53 terms: Vec::new(),
54 global_stats: None,
55 }
56 }
57
58 pub fn term(mut self, term: impl Into<String>) -> Self {
60 self.terms.push(term.into().to_lowercase());
61 self
62 }
63
64 pub fn terms(mut self, terms: impl IntoIterator<Item = impl Into<String>>) -> Self {
66 for t in terms {
67 self.terms.push(t.into().to_lowercase());
68 }
69 self
70 }
71
72 pub fn with_global_stats(mut self, stats: Arc<GlobalStats>) -> Self {
74 self.global_stats = Some(stats);
75 self
76 }
77}
78
79impl Query for WandOrQuery {
80 fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
81 let field = self.field;
82 let terms = self.terms.clone();
83 let global_stats = self.global_stats.clone();
84
85 Box::pin(async move {
86 let mut scorers: Vec<TextTermScorer> = Vec::with_capacity(terms.len());
87
88 let avg_field_len = global_stats
90 .as_ref()
91 .map(|s| s.avg_field_len(field))
92 .unwrap_or_else(|| reader.avg_field_len(field));
93
94 let num_docs = reader.num_docs() as f32;
95
96 for term in &terms {
97 let term_bytes = term.as_bytes();
98
99 if let Some(posting_list) = reader.get_postings(field, term_bytes).await? {
100 let doc_freq = posting_list.doc_count() as f32;
102 let idf = if let Some(ref stats) = global_stats {
103 let global_idf = stats.text_idf(field, term);
104 if global_idf > 0.0 {
105 global_idf
106 } else {
107 super::bm25_idf(doc_freq, num_docs)
108 }
109 } else {
110 super::bm25_idf(doc_freq, num_docs)
111 };
112
113 scorers.push(TextTermScorer::new(posting_list, idf, avg_field_len));
114 }
115 }
116
117 if scorers.is_empty() {
118 return Ok(Box::new(EmptyWandScorer) as Box<dyn Scorer + 'a>);
119 }
120
121 let results = WandExecutor::new(scorers, limit).execute();
123
124 Ok(Box::new(WandResultScorer::new(results)) as Box<dyn Scorer + 'a>)
125 })
126 }
127
128 fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
129 let field = self.field;
130 let terms = self.terms.clone();
131
132 Box::pin(async move {
133 let mut sum = 0u32;
134 for term in &terms {
135 if let Some(posting_list) = reader.get_postings(field, term.as_bytes()).await? {
136 sum += posting_list.doc_count();
137 }
138 }
139 Ok(sum)
140 })
141 }
142}
143
144struct WandResultScorer {
146 results: Vec<ScoredDoc>,
147 position: usize,
148}
149
150impl WandResultScorer {
151 fn new(results: Vec<ScoredDoc>) -> Self {
152 Self {
153 results,
154 position: 0,
155 }
156 }
157}
158
159impl Scorer for WandResultScorer {
160 fn doc(&self) -> DocId {
161 if self.position < self.results.len() {
162 self.results[self.position].doc_id
163 } else {
164 crate::structures::TERMINATED
165 }
166 }
167
168 fn score(&self) -> Score {
169 if self.position < self.results.len() {
170 self.results[self.position].score
171 } else {
172 0.0
173 }
174 }
175
176 fn advance(&mut self) -> DocId {
177 self.position += 1;
178 self.doc()
179 }
180
181 fn seek(&mut self, target: DocId) -> DocId {
182 while self.position < self.results.len() && self.results[self.position].doc_id < target {
183 self.position += 1;
184 }
185 self.doc()
186 }
187
188 fn size_hint(&self) -> u32 {
189 self.results.len() as u32
190 }
191}
192
/// Scorer that matches nothing; returned when none of the query terms has a
/// posting list in the segment.
struct EmptyWandScorer;
195
196impl Scorer for EmptyWandScorer {
197 fn doc(&self) -> DocId {
198 crate::structures::TERMINATED
199 }
200
201 fn score(&self) -> Score {
202 0.0
203 }
204
205 fn advance(&mut self) -> DocId {
206 crate::structures::TERMINATED
207 }
208
209 fn seek(&mut self, _target: DocId) -> DocId {
210 crate::structures::TERMINATED
211 }
212
213 fn size_hint(&self) -> u32 {
214 0
215 }
216}
217
#[cfg(test)]
mod tests {
    use super::*;

    /// Terms added through `term` and `terms` are stored in call order.
    #[test]
    fn test_wand_or_query_builder() {
        let singles = WandOrQuery::new(Field(0)).term("hello").term("world");
        let query = singles.terms(vec!["foo", "bar"]);

        // Comparing against a fixed-size array checks both length and content.
        assert_eq!(query.terms, ["hello", "world", "foo", "bar"]);
    }
}