hermes_core/query/
boolean.rs1use std::sync::Arc;
4
5use crate::segment::SegmentReader;
6use crate::structures::TERMINATED;
7use crate::{DocId, Score};
8
9use super::{
10 CountFuture, GlobalStats, MaxScoreExecutor, Query, ScoredDoc, Scorer, ScorerFuture,
11 TextTermScorer,
12};
13
14#[derive(Default, Clone)]
19pub struct BooleanQuery {
20 pub must: Vec<Arc<dyn Query>>,
21 pub should: Vec<Arc<dyn Query>>,
22 pub must_not: Vec<Arc<dyn Query>>,
23 global_stats: Option<Arc<GlobalStats>>,
25}
26
27impl std::fmt::Debug for BooleanQuery {
28 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 f.debug_struct("BooleanQuery")
30 .field("must_count", &self.must.len())
31 .field("should_count", &self.should.len())
32 .field("must_not_count", &self.must_not.len())
33 .field("has_global_stats", &self.global_stats.is_some())
34 .finish()
35 }
36}
37
38impl BooleanQuery {
39 pub fn new() -> Self {
40 Self::default()
41 }
42
43 pub fn must(mut self, query: impl Query + 'static) -> Self {
44 self.must.push(Arc::new(query));
45 self
46 }
47
48 pub fn should(mut self, query: impl Query + 'static) -> Self {
49 self.should.push(Arc::new(query));
50 self
51 }
52
53 pub fn must_not(mut self, query: impl Query + 'static) -> Self {
54 self.must_not.push(Arc::new(query));
55 self
56 }
57
58 pub fn with_global_stats(mut self, stats: Arc<GlobalStats>) -> Self {
60 self.global_stats = Some(stats);
61 self
62 }
63}
64
65async fn try_maxscore_scorer<'a>(
67 should: &[Arc<dyn Query>],
68 reader: &'a SegmentReader,
69 limit: usize,
70 global_stats: Option<&Arc<GlobalStats>>,
71) -> crate::Result<Option<Box<dyn Scorer + 'a>>> {
72 let mut term_infos: Vec<_> = should
74 .iter()
75 .filter_map(|q| q.as_term_query_info())
76 .collect();
77
78 if term_infos.len() != should.len() {
80 return Ok(None);
81 }
82
83 let first_field = term_infos[0].field;
85 if !term_infos.iter().all(|t| t.field == first_field) {
86 return Ok(None);
87 }
88
89 let mut scorers: Vec<TextTermScorer> = Vec::with_capacity(term_infos.len());
91 let avg_field_len = global_stats
92 .map(|s| s.avg_field_len(first_field))
93 .unwrap_or_else(|| reader.avg_field_len(first_field));
94 let num_docs = reader.num_docs() as f32;
95
96 for info in term_infos.drain(..) {
97 if let Some(posting_list) = reader.get_postings(info.field, &info.term).await? {
98 let doc_freq = posting_list.doc_count() as f32;
99 let idf = if let Some(stats) = global_stats {
100 let global_idf = stats.text_idf(info.field, &String::from_utf8_lossy(&info.term));
101 if global_idf > 0.0 {
102 global_idf
103 } else {
104 super::bm25_idf(doc_freq, num_docs)
105 }
106 } else {
107 super::bm25_idf(doc_freq, num_docs)
108 };
109 scorers.push(TextTermScorer::new(posting_list, idf, avg_field_len));
110 }
111 }
112
113 if scorers.is_empty() {
114 return Ok(Some(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>));
115 }
116
117 let results = MaxScoreExecutor::new(scorers, limit).execute();
119 Ok(Some(
120 Box::new(TopKResultScorer::new(results)) as Box<dyn Scorer + 'a>
121 ))
122}
123
124impl Query for BooleanQuery {
125 fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
126 let must = self.must.clone();
128 let should = self.should.clone();
129 let must_not = self.must_not.clone();
130 let global_stats = self.global_stats.clone();
131
132 Box::pin(async move {
133 if must.is_empty()
136 && must_not.is_empty()
137 && should.len() >= 2
138 && let Some(scorer) =
139 try_maxscore_scorer(&should, reader, limit, global_stats.as_ref()).await?
140 {
141 return Ok(scorer);
142 }
143
144 let mut must_scorers = Vec::with_capacity(must.len());
146 for q in &must {
147 must_scorers.push(q.scorer(reader, limit).await?);
148 }
149
150 let mut should_scorers = Vec::with_capacity(should.len());
151 for q in &should {
152 should_scorers.push(q.scorer(reader, limit).await?);
153 }
154
155 let mut must_not_scorers = Vec::with_capacity(must_not.len());
156 for q in &must_not {
157 must_not_scorers.push(q.scorer(reader, limit).await?);
158 }
159
160 let mut scorer = BooleanScorer {
161 must: must_scorers,
162 should: should_scorers,
163 must_not: must_not_scorers,
164 current_doc: 0,
165 };
166 scorer.current_doc = scorer.find_next_match();
168 Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
169 })
170 }
171
172 fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
173 let must = self.must.clone();
174 let should = self.should.clone();
175
176 Box::pin(async move {
177 if !must.is_empty() {
178 let mut estimates = Vec::with_capacity(must.len());
179 for q in &must {
180 estimates.push(q.count_estimate(reader).await?);
181 }
182 estimates
183 .into_iter()
184 .min()
185 .ok_or_else(|| crate::Error::Corruption("Empty must clause".to_string()))
186 } else if !should.is_empty() {
187 let mut sum = 0u32;
188 for q in &should {
189 sum += q.count_estimate(reader).await?;
190 }
191 Ok(sum)
192 } else {
193 Ok(0)
194 }
195 })
196 }
197}
198
199struct BooleanScorer<'a> {
200 must: Vec<Box<dyn Scorer + 'a>>,
201 should: Vec<Box<dyn Scorer + 'a>>,
202 must_not: Vec<Box<dyn Scorer + 'a>>,
203 current_doc: DocId,
204}
205
206impl BooleanScorer<'_> {
207 fn find_next_match(&mut self) -> DocId {
208 if self.must.is_empty() && self.should.is_empty() {
209 return TERMINATED;
210 }
211
212 loop {
213 let candidate = if !self.must.is_empty() {
214 let mut max_doc = self
215 .must
216 .iter()
217 .map(|s| s.doc())
218 .max()
219 .unwrap_or(TERMINATED);
220
221 if max_doc == TERMINATED {
222 return TERMINATED;
223 }
224
225 loop {
226 let mut all_match = true;
227 for scorer in &mut self.must {
228 let doc = scorer.seek(max_doc);
229 if doc == TERMINATED {
230 return TERMINATED;
231 }
232 if doc > max_doc {
233 max_doc = doc;
234 all_match = false;
235 break;
236 }
237 }
238 if all_match {
239 break;
240 }
241 }
242 max_doc
243 } else {
244 self.should
245 .iter()
246 .map(|s| s.doc())
247 .filter(|&d| d != TERMINATED)
248 .min()
249 .unwrap_or(TERMINATED)
250 };
251
252 if candidate == TERMINATED {
253 return TERMINATED;
254 }
255
256 let excluded = self.must_not.iter_mut().any(|scorer| {
257 let doc = scorer.seek(candidate);
258 doc == candidate
259 });
260
261 if !excluded {
262 self.current_doc = candidate;
263 return candidate;
264 }
265
266 if !self.must.is_empty() {
268 for scorer in &mut self.must {
269 scorer.advance();
270 }
271 } else {
272 for scorer in &mut self.should {
274 if scorer.doc() <= candidate && scorer.doc() != TERMINATED {
275 scorer.seek(candidate + 1);
276 }
277 }
278 }
279 }
280 }
281}
282
283impl Scorer for BooleanScorer<'_> {
284 fn doc(&self) -> DocId {
285 self.current_doc
286 }
287
288 fn score(&self) -> Score {
289 let mut total = 0.0;
290
291 for scorer in &self.must {
292 if scorer.doc() == self.current_doc {
293 total += scorer.score();
294 }
295 }
296
297 for scorer in &self.should {
298 if scorer.doc() == self.current_doc {
299 total += scorer.score();
300 }
301 }
302
303 total
304 }
305
306 fn advance(&mut self) -> DocId {
307 if !self.must.is_empty() {
308 for scorer in &mut self.must {
309 scorer.advance();
310 }
311 } else {
312 for scorer in &mut self.should {
313 if scorer.doc() == self.current_doc {
314 scorer.advance();
315 }
316 }
317 }
318
319 self.find_next_match()
320 }
321
322 fn seek(&mut self, target: DocId) -> DocId {
323 for scorer in &mut self.must {
324 scorer.seek(target);
325 }
326
327 for scorer in &mut self.should {
328 scorer.seek(target);
329 }
330
331 self.find_next_match()
332 }
333
334 fn size_hint(&self) -> u32 {
335 if !self.must.is_empty() {
336 self.must.iter().map(|s| s.size_hint()).min().unwrap_or(0)
337 } else {
338 self.should.iter().map(|s| s.size_hint()).sum()
339 }
340 }
341}
342
343struct TopKResultScorer {
345 results: Vec<ScoredDoc>,
346 position: usize,
347}
348
349impl TopKResultScorer {
350 fn new(results: Vec<ScoredDoc>) -> Self {
351 Self {
352 results,
353 position: 0,
354 }
355 }
356}
357
358impl Scorer for TopKResultScorer {
359 fn doc(&self) -> DocId {
360 if self.position < self.results.len() {
361 self.results[self.position].doc_id
362 } else {
363 TERMINATED
364 }
365 }
366
367 fn score(&self) -> Score {
368 if self.position < self.results.len() {
369 self.results[self.position].score
370 } else {
371 0.0
372 }
373 }
374
375 fn advance(&mut self) -> DocId {
376 self.position += 1;
377 self.doc()
378 }
379
380 fn seek(&mut self, target: DocId) -> DocId {
381 while self.position < self.results.len() && self.results[self.position].doc_id < target {
382 self.position += 1;
383 }
384 self.doc()
385 }
386
387 fn size_hint(&self) -> u32 {
388 self.results.len() as u32
389 }
390}
391
392struct EmptyScorer;
394
395impl Scorer for EmptyScorer {
396 fn doc(&self) -> DocId {
397 TERMINATED
398 }
399
400 fn score(&self) -> Score {
401 0.0
402 }
403
404 fn advance(&mut self) -> DocId {
405 TERMINATED
406 }
407
408 fn seek(&mut self, _target: DocId) -> DocId {
409 TERMINATED
410 }
411
412 fn size_hint(&self) -> u32 {
413 0
414 }
415}
416
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::Field;
    use crate::query::TermQuery;

    /// A pure OR of term queries on one field exposes term info for every
    /// clause, which is what the MaxScore path requires.
    #[test]
    fn test_maxscore_eligible_pure_or_same_field() {
        let query = BooleanQuery::new()
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(0), "world"))
            .should(TermQuery::text(Field(0), "foo"));

        for clause in &query.should {
            assert!(clause.as_term_query_info().is_some());
        }

        let infos: Vec<_> = query
            .should
            .iter()
            .filter_map(|clause| clause.as_term_query_info())
            .collect();
        assert_eq!(infos.len(), 3);
        for info in &infos {
            assert!(info.field == Field(0));
        }
    }

    /// Terms on different fields still yield term info, but with differing
    /// fields.
    #[test]
    fn test_maxscore_not_eligible_different_fields() {
        let query = BooleanQuery::new()
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(1), "world"));

        let infos: Vec<_> = query
            .should
            .iter()
            .filter_map(|clause| clause.as_term_query_info())
            .collect();
        assert_eq!(infos.len(), 2);
        assert!(infos[0].field != infos[1].field);
    }

    /// Adding a `must` clause leaves the `must` list non-empty.
    #[test]
    fn test_maxscore_not_eligible_with_must() {
        let with_must = BooleanQuery::new()
            .must(TermQuery::text(Field(0), "required"))
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(0), "world"));

        assert!(!with_must.must.is_empty());
    }

    /// Adding a `must_not` clause leaves the `must_not` list non-empty.
    #[test]
    fn test_maxscore_not_eligible_with_must_not() {
        let with_exclusion = BooleanQuery::new()
            .should(TermQuery::text(Field(0), "hello"))
            .should(TermQuery::text(Field(0), "world"))
            .must_not(TermQuery::text(Field(0), "excluded"));

        assert!(!with_exclusion.must_not.is_empty());
    }

    /// A single `should` clause is below the two-clause MaxScore threshold.
    #[test]
    fn test_maxscore_not_eligible_single_term() {
        let single = BooleanQuery::new().should(TermQuery::text(Field(0), "hello"));

        assert_eq!(single.should.len(), 1);
    }

    /// A term query reports its own field and term bytes.
    #[test]
    fn test_term_query_info_extraction() {
        let term_query = TermQuery::text(Field(42), "test");
        let extracted = term_query.as_term_query_info();

        assert!(extracted.is_some());
        let info = extracted.unwrap();
        assert_eq!(info.field, Field(42));
        assert_eq!(info.term, b"test");
    }

    /// A boolean query is not itself a term query.
    #[test]
    fn test_boolean_query_no_term_info() {
        let boolean = BooleanQuery::new().should(TermQuery::text(Field(0), "hello"));

        assert!(boolean.as_term_query_info().is_none());
    }
}