hermes_core/query/
wand_or.rs1use std::sync::Arc;
7
8use crate::dsl::Field;
9use crate::segment::SegmentReader;
10use crate::{DocId, Score};
11
12use super::{
13 CountFuture, GlobalStats, Query, ScoredDoc, Scorer, ScorerFuture, TextTermScorer, WandExecutor,
14};
15
16#[derive(Clone)]
29pub struct WandOrQuery {
30 pub field: Field,
32 pub terms: Vec<String>,
34 global_stats: Option<Arc<GlobalStats>>,
36}
37
38impl std::fmt::Debug for WandOrQuery {
39 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40 f.debug_struct("WandOrQuery")
41 .field("field", &self.field)
42 .field("terms", &self.terms)
43 .field("has_global_stats", &self.global_stats.is_some())
44 .finish()
45 }
46}
47
48impl WandOrQuery {
49 pub fn new(field: Field) -> Self {
51 Self {
52 field,
53 terms: Vec::new(),
54 global_stats: None,
55 }
56 }
57
58 pub fn term(mut self, term: impl Into<String>) -> Self {
60 self.terms.push(term.into().to_lowercase());
61 self
62 }
63
64 pub fn terms(mut self, terms: impl IntoIterator<Item = impl Into<String>>) -> Self {
66 for t in terms {
67 self.terms.push(t.into().to_lowercase());
68 }
69 self
70 }
71
72 pub fn with_global_stats(mut self, stats: Arc<GlobalStats>) -> Self {
74 self.global_stats = Some(stats);
75 self
76 }
77}
78
79impl Query for WandOrQuery {
80 fn scorer<'a>(&'a self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
81 Box::pin(async move {
82 let mut scorers: Vec<TextTermScorer> = Vec::with_capacity(self.terms.len());
83
84 let avg_field_len = self
86 .global_stats
87 .as_ref()
88 .map(|s| s.avg_field_len(self.field))
89 .unwrap_or_else(|| reader.avg_field_len(self.field));
90
91 let num_docs = reader.num_docs() as f32;
92
93 for term in &self.terms {
94 let term_bytes = term.as_bytes();
95
96 if let Some(posting_list) = reader.get_postings(self.field, term_bytes).await? {
97 let doc_freq = posting_list.doc_count() as f32;
99 let idf = if let Some(ref stats) = self.global_stats {
100 let global_idf = stats.text_idf(self.field, term);
101 if global_idf > 0.0 {
102 global_idf
103 } else {
104 super::bm25_idf(doc_freq, num_docs)
105 }
106 } else {
107 super::bm25_idf(doc_freq, num_docs)
108 };
109
110 scorers.push(TextTermScorer::new(posting_list, idf, avg_field_len));
111 }
112 }
113
114 if scorers.is_empty() {
115 return Ok(Box::new(EmptyWandScorer) as Box<dyn Scorer + 'a>);
116 }
117
118 let results = WandExecutor::new(scorers, limit).execute();
120
121 Ok(Box::new(WandResultScorer::new(results)) as Box<dyn Scorer + 'a>)
122 })
123 }
124
125 fn count_estimate<'a>(&'a self, reader: &'a SegmentReader) -> CountFuture<'a> {
126 Box::pin(async move {
127 let mut sum = 0u32;
128 for term in &self.terms {
129 if let Some(posting_list) = reader.get_postings(self.field, term.as_bytes()).await?
130 {
131 sum += posting_list.doc_count();
132 }
133 }
134 Ok(sum)
135 })
136 }
137}
138
139struct WandResultScorer {
141 results: Vec<ScoredDoc>,
142 position: usize,
143}
144
145impl WandResultScorer {
146 fn new(results: Vec<ScoredDoc>) -> Self {
147 Self {
148 results,
149 position: 0,
150 }
151 }
152}
153
154impl Scorer for WandResultScorer {
155 fn doc(&self) -> DocId {
156 if self.position < self.results.len() {
157 self.results[self.position].doc_id
158 } else {
159 crate::structures::TERMINATED
160 }
161 }
162
163 fn score(&self) -> Score {
164 if self.position < self.results.len() {
165 self.results[self.position].score
166 } else {
167 0.0
168 }
169 }
170
171 fn advance(&mut self) -> DocId {
172 self.position += 1;
173 self.doc()
174 }
175
176 fn seek(&mut self, target: DocId) -> DocId {
177 while self.position < self.results.len() && self.results[self.position].doc_id < target {
178 self.position += 1;
179 }
180 self.doc()
181 }
182
183 fn size_hint(&self) -> u32 {
184 self.results.len() as u32
185 }
186}
187
188struct EmptyWandScorer;
190
191impl Scorer for EmptyWandScorer {
192 fn doc(&self) -> DocId {
193 crate::structures::TERMINATED
194 }
195
196 fn score(&self) -> Score {
197 0.0
198 }
199
200 fn advance(&mut self) -> DocId {
201 crate::structures::TERMINATED
202 }
203
204 fn seek(&mut self, _target: DocId) -> DocId {
205 crate::structures::TERMINATED
206 }
207
208 fn size_hint(&self) -> u32 {
209 0
210 }
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder keeps terms in insertion order, whether they are added one
    /// at a time or in bulk.
    #[test]
    fn test_wand_or_query_builder() {
        let single_terms = WandOrQuery::new(Field(0)).term("hello").term("world");
        let query = single_terms.terms(vec!["foo", "bar"]);

        // One whole-vector comparison checks both the length and every slot.
        assert_eq!(query.terms, ["hello", "world", "foo", "bar"]);
    }
}