1use std::sync::Arc;
9
10use crate::DocId;
11use crate::structures::TERMINATED;
12
13macro_rules! define_docset_trait {
16 ($($send_bounds:tt)*) => {
17 pub trait DocSet: $($send_bounds)* {
22 fn doc(&self) -> DocId;
24
25 fn advance(&mut self) -> DocId;
27
28 fn seek(&mut self, target: DocId) -> DocId {
30 let mut doc = self.doc();
31 while doc < target {
32 doc = self.advance();
33 }
34 doc
35 }
36
37 fn size_hint(&self) -> u32;
39 }
40 };
41}
42
43#[cfg(not(target_arch = "wasm32"))]
44define_docset_trait!(Send + Sync);
45
46#[cfg(target_arch = "wasm32")]
47define_docset_trait!();
48
49impl DocSet for Box<dyn DocSet + '_> {
52 #[inline]
53 fn doc(&self) -> DocId {
54 (**self).doc()
55 }
56 #[inline]
57 fn advance(&mut self) -> DocId {
58 (**self).advance()
59 }
60 #[inline]
61 fn seek(&mut self, target: DocId) -> DocId {
62 (**self).seek(target)
63 }
64 #[inline]
65 fn size_hint(&self) -> u32 {
66 (**self).size_hint()
67 }
68}
69
70pub struct SortedVecDocSet {
74 docs: Arc<Vec<u32>>,
75 pos: usize,
76}
77
78impl SortedVecDocSet {
79 pub fn new(docs: Arc<Vec<u32>>) -> Self {
80 Self { docs, pos: 0 }
81 }
82}
83
84impl DocSet for SortedVecDocSet {
85 #[inline]
86 fn doc(&self) -> DocId {
87 self.docs.get(self.pos).copied().unwrap_or(TERMINATED)
88 }
89
90 #[inline]
91 fn advance(&mut self) -> DocId {
92 if self.pos < self.docs.len() {
93 self.pos += 1;
94 }
95 self.doc()
96 }
97
98 fn seek(&mut self, target: DocId) -> DocId {
99 if self.pos >= self.docs.len() {
100 return TERMINATED;
101 }
102 let remaining = &self.docs[self.pos..];
103 match remaining.binary_search(&target) {
104 Ok(offset) => {
105 self.pos += offset;
106 self.docs[self.pos]
107 }
108 Err(offset) => {
109 self.pos += offset;
110 self.doc()
111 }
112 }
113 }
114
115 fn size_hint(&self) -> u32 {
116 self.docs.len().saturating_sub(self.pos) as u32
117 }
118}
119
120pub struct IntersectionDocSet<A: DocSet, B: DocSet> {
124 a: A,
125 b: B,
126}
127
128impl<A: DocSet, B: DocSet> IntersectionDocSet<A, B> {
129 pub fn new(mut a: A, mut b: B) -> Self {
130 let mut da = a.doc();
132 let mut db = b.doc();
133 loop {
134 if da == TERMINATED || db == TERMINATED {
135 break;
136 }
137 if da == db {
138 break;
139 }
140 if da < db {
141 da = a.seek(db);
142 } else {
143 db = b.seek(da);
144 }
145 }
146 Self { a, b }
147 }
148}
149
150impl<A: DocSet, B: DocSet> DocSet for IntersectionDocSet<A, B> {
151 fn doc(&self) -> DocId {
152 let da = self.a.doc();
153 if da == TERMINATED || self.b.doc() == TERMINATED {
154 TERMINATED
155 } else {
156 da
157 }
158 }
159
160 fn advance(&mut self) -> DocId {
161 let mut da = self.a.advance();
162 let mut db = self.b.doc();
163 loop {
164 if da == TERMINATED || db == TERMINATED {
165 return TERMINATED;
166 }
167 if da == db {
168 return da;
169 }
170 if da < db {
171 da = self.a.seek(db);
172 } else {
173 db = self.b.seek(da);
174 }
175 }
176 }
177
178 fn seek(&mut self, target: DocId) -> DocId {
179 let mut da = self.a.seek(target);
180 let mut db = self.b.seek(target);
181 loop {
182 if da == TERMINATED || db == TERMINATED {
183 return TERMINATED;
184 }
185 if da == db {
186 return da;
187 }
188 if da < db {
189 da = self.a.seek(db);
190 } else {
191 db = self.b.seek(da);
192 }
193 }
194 }
195
196 fn size_hint(&self) -> u32 {
197 self.a.size_hint().min(self.b.size_hint())
198 }
199}
200
201pub struct AllDocSet {
205 current: u32,
206 num_docs: u32,
207}
208
209impl AllDocSet {
210 pub fn new(num_docs: u32) -> Self {
211 Self {
212 current: 0,
213 num_docs,
214 }
215 }
216}
217
218impl DocSet for AllDocSet {
219 #[inline]
220 fn doc(&self) -> DocId {
221 if self.current >= self.num_docs {
222 TERMINATED
223 } else {
224 self.current
225 }
226 }
227
228 #[inline]
229 fn advance(&mut self) -> DocId {
230 self.current += 1;
231 self.doc()
232 }
233
234 #[inline]
235 fn seek(&mut self, target: DocId) -> DocId {
236 self.current = target;
237 self.doc()
238 }
239
240 fn size_hint(&self) -> u32 {
241 self.num_docs.saturating_sub(self.current)
242 }
243}
244
245pub struct EmptyDocSet;
249
250impl DocSet for EmptyDocSet {
251 #[inline]
252 fn doc(&self) -> DocId {
253 TERMINATED
254 }
255 #[inline]
256 fn advance(&mut self) -> DocId {
257 TERMINATED
258 }
259 #[inline]
260 fn seek(&mut self, _target: DocId) -> DocId {
261 TERMINATED
262 }
263 fn size_hint(&self) -> u32 {
264 0
265 }
266}
267
268pub struct IntersectionScorer<'a> {
275 scorer: Box<dyn super::Scorer + 'a>,
276 filter: Box<dyn DocSet + 'a>,
277}
278
279impl<'a> IntersectionScorer<'a> {
280 pub fn new(mut scorer: Box<dyn super::Scorer + 'a>, mut filter: Box<dyn DocSet + 'a>) -> Self {
281 let mut ds = scorer.doc();
283 let mut df = filter.doc();
284 loop {
285 if ds == TERMINATED || df == TERMINATED {
286 break;
287 }
288 if ds == df {
289 break;
290 }
291 if ds < df {
292 ds = scorer.seek(df);
293 } else {
294 df = filter.seek(ds);
295 }
296 }
297 Self { scorer, filter }
298 }
299}
300
301impl DocSet for IntersectionScorer<'_> {
302 fn doc(&self) -> DocId {
303 let ds = self.scorer.doc();
304 if ds == TERMINATED || self.filter.doc() == TERMINATED {
305 TERMINATED
306 } else {
307 ds
308 }
309 }
310
311 fn advance(&mut self) -> DocId {
312 let filter_smaller = self.filter.size_hint() < self.scorer.size_hint();
314
315 if filter_smaller {
316 let mut df = self.filter.advance();
318 let mut ds = self.scorer.doc();
319 loop {
320 if df == TERMINATED || ds == TERMINATED {
321 return TERMINATED;
322 }
323 if df == ds {
324 return df;
325 }
326 if df < ds {
327 df = self.filter.seek(ds);
328 } else {
329 ds = self.scorer.seek(df);
330 }
331 }
332 } else {
333 let mut ds = self.scorer.advance();
335 let mut df = self.filter.doc();
336 loop {
337 if ds == TERMINATED || df == TERMINATED {
338 return TERMINATED;
339 }
340 if ds == df {
341 return ds;
342 }
343 if ds < df {
344 ds = self.scorer.seek(df);
345 } else {
346 df = self.filter.seek(ds);
347 }
348 }
349 }
350 }
351
352 fn seek(&mut self, target: DocId) -> DocId {
353 let mut ds = self.scorer.seek(target);
354 let mut df = self.filter.seek(target);
355 loop {
356 if ds == TERMINATED || df == TERMINATED {
357 return TERMINATED;
358 }
359 if ds == df {
360 return ds;
361 }
362 if ds < df {
363 ds = self.scorer.seek(df);
364 } else {
365 df = self.filter.seek(ds);
366 }
367 }
368 }
369
370 fn size_hint(&self) -> u32 {
371 self.scorer.size_hint().min(self.filter.size_hint())
372 }
373}
374
375impl super::Scorer for IntersectionScorer<'_> {
376 fn score(&self) -> crate::Score {
377 self.scorer.score()
378 }
379
380 fn matched_positions(&self) -> Option<super::MatchedPositions> {
381 self.scorer.matched_positions()
382 }
383}
384
385pub struct PredicatedScorer<'a> {
397 driver: Box<dyn super::Scorer + 'a>,
399 predicates: Vec<super::DocPredicate<'a>>,
401 must_verifiers: Vec<Box<dyn super::Scorer + 'a>>,
403 must_not_verifiers: Vec<Box<dyn super::Scorer + 'a>>,
405 filter_score: f32,
407}
408
409impl<'a> PredicatedScorer<'a> {
410 pub fn new(
411 driver: Box<dyn super::Scorer + 'a>,
412 predicates: Vec<super::DocPredicate<'a>>,
413 must_verifiers: Vec<Box<dyn super::Scorer + 'a>>,
414 must_not_verifiers: Vec<Box<dyn super::Scorer + 'a>>,
415 filter_score: f32,
416 ) -> Self {
417 let mut s = Self {
418 driver,
419 predicates,
420 must_verifiers,
421 must_not_verifiers,
422 filter_score,
423 };
424 s.skip_non_matching();
426 s
427 }
428
429 #[inline]
431 fn check_filters(&mut self, doc: DocId) -> bool {
432 if !self.predicates.iter().all(|p| p(doc)) {
434 return false;
435 }
436 if !self.must_verifiers.iter_mut().all(|s| s.seek(doc) == doc) {
438 return false;
439 }
440 self.must_not_verifiers
442 .iter_mut()
443 .all(|s| s.seek(doc) != doc)
444 }
445
446 fn skip_non_matching(&mut self) -> DocId {
448 let mut doc = self.driver.doc();
449 while doc != TERMINATED && !self.check_filters(doc) {
450 doc = self.driver.advance();
451 }
452 doc
453 }
454}
455
456impl DocSet for PredicatedScorer<'_> {
457 fn doc(&self) -> DocId {
458 self.driver.doc()
459 }
460
461 fn advance(&mut self) -> DocId {
462 self.driver.advance();
463 self.skip_non_matching()
464 }
465
466 fn seek(&mut self, target: DocId) -> DocId {
467 self.driver.seek(target);
468 self.skip_non_matching()
469 }
470
471 fn size_hint(&self) -> u32 {
472 self.driver.size_hint()
473 }
474}
475
476impl super::Scorer for PredicatedScorer<'_> {
477 fn score(&self) -> crate::Score {
478 let mut total = self.driver.score();
479 for v in &self.must_verifiers {
480 total += v.score();
481 }
482 total + self.filter_score
483 }
484
485 fn matched_positions(&self) -> Option<super::MatchedPositions> {
486 let mut all: super::MatchedPositions = Vec::new();
487 if let Some(p) = self.driver.matched_positions() {
488 all.extend(p);
489 }
490 for v in &self.must_verifiers {
491 if let Some(p) = v.matched_positions() {
492 all.extend(p);
493 }
494 }
495 if all.is_empty() { None } else { Some(all) }
496 }
497}
498
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `SortedVecDocSet` over a copy of `docs`.
    fn sorted(docs: &[u32]) -> SortedVecDocSet {
        SortedVecDocSet::new(Arc::new(docs.to_vec()))
    }

    // Walks a sorted set with a mix of `advance` and `seek` until exhaustion.
    #[test]
    fn test_sorted_vec_docset_basic() {
        let mut cursor = sorted(&[1, 3, 5, 7, 9]);

        assert_eq!(cursor.doc(), 1);
        assert_eq!(cursor.advance(), 3);
        assert_eq!(cursor.advance(), 5);
        assert_eq!(cursor.seek(7), 7);
        assert_eq!(cursor.advance(), 9);
        assert_eq!(cursor.advance(), TERMINATED);
        assert_eq!(cursor.doc(), TERMINATED);
    }

    // Seeking to an absent id lands on the next larger one, or terminates.
    #[test]
    fn test_sorted_vec_docset_seek_past() {
        let mut cursor = sorted(&[1, 5, 10, 20]);

        assert_eq!(cursor.seek(3), 5);
        assert_eq!(cursor.seek(15), 20);
        assert_eq!(cursor.seek(21), TERMINATED);
    }

    // A set built over no ids starts out exhausted.
    #[test]
    fn test_sorted_vec_docset_empty() {
        let cursor = sorted(&[]);
        assert_eq!(cursor.doc(), TERMINATED);
    }

    // `AllDocSet` yields 0..num_docs and then terminates.
    #[test]
    fn test_all_docset() {
        let mut every = AllDocSet::new(3);
        assert_eq!(every.doc(), 0);
        assert_eq!(every.advance(), 1);
        assert_eq!(every.advance(), 2);
        assert_eq!(every.advance(), TERMINATED);
    }

    // Seeking inside the range lands on the target; seeking to the bound terminates.
    #[test]
    fn test_all_docset_seek() {
        let mut every = AllDocSet::new(10);
        assert_eq!(every.seek(5), 5);
        assert_eq!(every.seek(9), 9);
        assert_eq!(every.seek(10), TERMINATED);
    }

    // Every operation on the empty set reports exhaustion.
    #[test]
    fn test_empty_docset() {
        let mut nothing = EmptyDocSet;
        assert_eq!(nothing.doc(), TERMINATED);
        assert_eq!(nothing.advance(), TERMINATED);
        assert_eq!(nothing.seek(5), TERMINATED);
        assert_eq!(nothing.size_hint(), 0);
    }

    // Only ids present in both inputs are yielded, in ascending order.
    #[test]
    fn test_intersection_docset() {
        let left = sorted(&[1, 3, 5, 7, 9]);
        let right = sorted(&[2, 3, 5, 8, 9, 10]);
        let mut both = IntersectionDocSet::new(left, right);

        assert_eq!(both.doc(), 3);
        assert_eq!(both.advance(), 5);
        assert_eq!(both.advance(), 9);
        assert_eq!(both.advance(), TERMINATED);
    }

    // Disjoint inputs produce an intersection that starts out exhausted.
    #[test]
    fn test_intersection_docset_empty() {
        let odds = sorted(&[1, 3, 5]);
        let evens = sorted(&[2, 4, 6]);
        let both = IntersectionDocSet::new(odds, evens);
        assert_eq!(both.doc(), TERMINATED);
    }

    // Seeking to an id missing from one side lands on the next shared id.
    #[test]
    fn test_intersection_docset_seek() {
        let left = sorted(&[1, 5, 10, 20, 30]);
        let right = sorted(&[5, 10, 15, 20, 25, 30]);
        let mut both = IntersectionDocSet::new(left, right);

        assert_eq!(both.doc(), 5);
        assert_eq!(both.seek(15), 20);
        assert_eq!(both.advance(), 30);
        assert_eq!(both.advance(), TERMINATED);
    }

    // The hint shrinks as the cursor moves, whether by `advance` or by `seek`.
    #[test]
    fn test_size_hint() {
        let mut cursor = sorted(&[1, 2, 3, 4, 5]);
        assert_eq!(cursor.size_hint(), 5);
        cursor.advance();
        assert_eq!(cursor.size_hint(), 4);
        cursor.seek(4);
        assert_eq!(cursor.size_hint(), 2);
    }
}