1use std::sync::Arc;
8
9use crate::DocId;
10use crate::structures::TERMINATED;
11
12macro_rules! define_docset_trait {
15 ($($send_bounds:tt)*) => {
16 pub trait DocSet: $($send_bounds)* {
21 fn doc(&self) -> DocId;
23
24 fn advance(&mut self) -> DocId;
26
27 fn seek(&mut self, target: DocId) -> DocId {
29 let mut doc = self.doc();
30 while doc < target {
31 doc = self.advance();
32 }
33 doc
34 }
35
36 fn size_hint(&self) -> u32;
38 }
39 };
40}
41
42#[cfg(not(target_arch = "wasm32"))]
43define_docset_trait!(Send + Sync);
44
45#[cfg(target_arch = "wasm32")]
46define_docset_trait!();
47
48impl DocSet for Box<dyn DocSet + '_> {
51 #[inline]
52 fn doc(&self) -> DocId {
53 (**self).doc()
54 }
55 #[inline]
56 fn advance(&mut self) -> DocId {
57 (**self).advance()
58 }
59 #[inline]
60 fn seek(&mut self, target: DocId) -> DocId {
61 (**self).seek(target)
62 }
63 #[inline]
64 fn size_hint(&self) -> u32 {
65 (**self).size_hint()
66 }
67}
68
69pub struct SortedVecDocSet {
73 docs: Arc<Vec<u32>>,
74 pos: usize,
75}
76
77impl SortedVecDocSet {
78 pub fn new(docs: Arc<Vec<u32>>) -> Self {
79 Self { docs, pos: 0 }
80 }
81}
82
83impl DocSet for SortedVecDocSet {
84 #[inline]
85 fn doc(&self) -> DocId {
86 self.docs.get(self.pos).copied().unwrap_or(TERMINATED)
87 }
88
89 #[inline]
90 fn advance(&mut self) -> DocId {
91 if self.pos < self.docs.len() {
92 self.pos += 1;
93 }
94 self.doc()
95 }
96
97 fn seek(&mut self, target: DocId) -> DocId {
98 if self.pos >= self.docs.len() {
99 return TERMINATED;
100 }
101 let remaining = &self.docs[self.pos..];
102 match remaining.binary_search(&target) {
103 Ok(offset) => {
104 self.pos += offset;
105 self.docs[self.pos]
106 }
107 Err(offset) => {
108 self.pos += offset;
109 self.doc()
110 }
111 }
112 }
113
114 fn size_hint(&self) -> u32 {
115 self.docs.len().saturating_sub(self.pos) as u32
116 }
117}
118
119pub struct IntersectionDocSet<A: DocSet, B: DocSet> {
123 a: A,
124 b: B,
125}
126
127impl<A: DocSet, B: DocSet> IntersectionDocSet<A, B> {
128 pub fn new(mut a: A, mut b: B) -> Self {
129 let mut da = a.doc();
131 let mut db = b.doc();
132 loop {
133 if da == TERMINATED || db == TERMINATED {
134 break;
135 }
136 if da == db {
137 break;
138 }
139 if da < db {
140 da = a.seek(db);
141 } else {
142 db = b.seek(da);
143 }
144 }
145 Self { a, b }
146 }
147}
148
149impl<A: DocSet, B: DocSet> DocSet for IntersectionDocSet<A, B> {
150 fn doc(&self) -> DocId {
151 let da = self.a.doc();
152 if da == TERMINATED || self.b.doc() == TERMINATED {
153 TERMINATED
154 } else {
155 da
156 }
157 }
158
159 fn advance(&mut self) -> DocId {
160 let mut da = self.a.advance();
161 let mut db = self.b.doc();
162 loop {
163 if da == TERMINATED || db == TERMINATED {
164 return TERMINATED;
165 }
166 if da == db {
167 return da;
168 }
169 if da < db {
170 da = self.a.seek(db);
171 } else {
172 db = self.b.seek(da);
173 }
174 }
175 }
176
177 fn seek(&mut self, target: DocId) -> DocId {
178 let mut da = self.a.seek(target);
179 let mut db = self.b.seek(target);
180 loop {
181 if da == TERMINATED || db == TERMINATED {
182 return TERMINATED;
183 }
184 if da == db {
185 return da;
186 }
187 if da < db {
188 da = self.a.seek(db);
189 } else {
190 db = self.b.seek(da);
191 }
192 }
193 }
194
195 fn size_hint(&self) -> u32 {
196 self.a.size_hint().min(self.b.size_hint())
197 }
198}
199
200pub struct AllDocSet {
204 current: u32,
205 num_docs: u32,
206}
207
208impl AllDocSet {
209 pub fn new(num_docs: u32) -> Self {
210 Self {
211 current: 0,
212 num_docs,
213 }
214 }
215}
216
217impl DocSet for AllDocSet {
218 #[inline]
219 fn doc(&self) -> DocId {
220 if self.current >= self.num_docs {
221 TERMINATED
222 } else {
223 self.current
224 }
225 }
226
227 #[inline]
228 fn advance(&mut self) -> DocId {
229 self.current += 1;
230 self.doc()
231 }
232
233 #[inline]
234 fn seek(&mut self, target: DocId) -> DocId {
235 self.current = target;
236 self.doc()
237 }
238
239 fn size_hint(&self) -> u32 {
240 self.num_docs.saturating_sub(self.current)
241 }
242}
243
244pub struct EmptyDocSet;
248
249impl DocSet for EmptyDocSet {
250 #[inline]
251 fn doc(&self) -> DocId {
252 TERMINATED
253 }
254 #[inline]
255 fn advance(&mut self) -> DocId {
256 TERMINATED
257 }
258 #[inline]
259 fn seek(&mut self, _target: DocId) -> DocId {
260 TERMINATED
261 }
262 fn size_hint(&self) -> u32 {
263 0
264 }
265}
266
267pub struct PredicatedScorer<'a> {
278 driver: Box<dyn super::Scorer + 'a>,
280 predicates: Vec<super::DocPredicate<'a>>,
282 must_verifiers: Vec<Box<dyn super::Scorer + 'a>>,
284 must_not_verifiers: Vec<Box<dyn super::Scorer + 'a>>,
286}
287
288impl<'a> PredicatedScorer<'a> {
289 pub fn new(
290 driver: Box<dyn super::Scorer + 'a>,
291 predicates: Vec<super::DocPredicate<'a>>,
292 must_verifiers: Vec<Box<dyn super::Scorer + 'a>>,
293 must_not_verifiers: Vec<Box<dyn super::Scorer + 'a>>,
294 ) -> Self {
295 let mut s = Self {
296 driver,
297 predicates,
298 must_verifiers,
299 must_not_verifiers,
300 };
301 s.skip_non_matching();
303 s
304 }
305
306 #[inline]
308 fn check_filters(&mut self, doc: DocId) -> bool {
309 if !self.predicates.iter().all(|p| p(doc)) {
311 return false;
312 }
313 if !self.must_verifiers.iter_mut().all(|s| s.seek(doc) == doc) {
315 return false;
316 }
317 self.must_not_verifiers
319 .iter_mut()
320 .all(|s| s.seek(doc) != doc)
321 }
322
323 fn skip_non_matching(&mut self) -> DocId {
325 let mut doc = self.driver.doc();
326 while doc != TERMINATED && !self.check_filters(doc) {
327 doc = self.driver.advance();
328 }
329 doc
330 }
331}
332
333impl DocSet for PredicatedScorer<'_> {
334 fn doc(&self) -> DocId {
335 self.driver.doc()
336 }
337
338 fn advance(&mut self) -> DocId {
339 self.driver.advance();
340 self.skip_non_matching()
341 }
342
343 fn seek(&mut self, target: DocId) -> DocId {
344 self.driver.seek(target);
345 self.skip_non_matching()
346 }
347
348 fn size_hint(&self) -> u32 {
349 self.driver.size_hint()
350 }
351}
352
353impl super::Scorer for PredicatedScorer<'_> {
354 fn score(&self) -> crate::Score {
355 let mut total = self.driver.score();
356 for v in &self.must_verifiers {
357 total += v.score();
358 }
359 total
360 }
361
362 fn matched_positions(&self) -> Option<super::MatchedPositions> {
363 let mut all: super::MatchedPositions = Vec::new();
364 if let Some(p) = self.driver.matched_positions() {
365 all.extend(p);
366 }
367 for v in &self.must_verifiers {
368 if let Some(p) = v.matched_positions() {
369 all.extend(p);
370 }
371 }
372 if all.is_empty() { None } else { Some(all) }
373 }
374}
375
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a `SortedVecDocSet` over a copy of `ids`.
    fn sorted(ids: &[u32]) -> SortedVecDocSet {
        SortedVecDocSet::new(Arc::new(ids.to_vec()))
    }

    #[test]
    fn test_sorted_vec_docset_basic() {
        let mut set = sorted(&[1, 3, 5, 7, 9]);

        assert_eq!(set.doc(), 1);
        assert_eq!(set.advance(), 3);
        assert_eq!(set.advance(), 5);
        assert_eq!(set.seek(7), 7);
        assert_eq!(set.advance(), 9);
        // Exhaustion is reported by both `advance` and `doc`.
        assert_eq!(set.advance(), TERMINATED);
        assert_eq!(set.doc(), TERMINATED);
    }

    #[test]
    fn test_sorted_vec_docset_seek_past() {
        let mut set = sorted(&[1, 5, 10, 20]);
        // Each seek lands on the first id >= target.
        for &(target, expected) in &[(3, 5), (15, 20), (21, TERMINATED)] {
            assert_eq!(set.seek(target), expected);
        }
    }

    #[test]
    fn test_sorted_vec_docset_empty() {
        let set = sorted(&[]);
        assert_eq!(set.doc(), TERMINATED);
    }

    #[test]
    fn test_all_docset() {
        let mut set = AllDocSet::new(3);
        assert_eq!(set.doc(), 0);
        for &expected in &[1, 2, TERMINATED] {
            assert_eq!(set.advance(), expected);
        }
    }

    #[test]
    fn test_all_docset_seek() {
        let mut set = AllDocSet::new(10);
        for &(target, expected) in &[(5, 5), (9, 9), (10, TERMINATED)] {
            assert_eq!(set.seek(target), expected);
        }
    }

    #[test]
    fn test_empty_docset() {
        let mut set = EmptyDocSet;
        assert_eq!(set.doc(), TERMINATED);
        assert_eq!(set.advance(), TERMINATED);
        assert_eq!(set.seek(5), TERMINATED);
        assert_eq!(set.size_hint(), 0);
    }

    #[test]
    fn test_intersection_docset() {
        let mut both = IntersectionDocSet::new(
            sorted(&[1, 3, 5, 7, 9]),
            sorted(&[2, 3, 5, 8, 9, 10]),
        );

        assert_eq!(both.doc(), 3);
        for &expected in &[5, 9, TERMINATED] {
            assert_eq!(both.advance(), expected);
        }
    }

    #[test]
    fn test_intersection_docset_empty() {
        let both = IntersectionDocSet::new(sorted(&[1, 3, 5]), sorted(&[2, 4, 6]));
        assert_eq!(both.doc(), TERMINATED);
    }

    #[test]
    fn test_intersection_docset_seek() {
        let mut both = IntersectionDocSet::new(
            sorted(&[1, 5, 10, 20, 30]),
            sorted(&[5, 10, 15, 20, 25, 30]),
        );

        assert_eq!(both.doc(), 5);
        assert_eq!(both.seek(15), 20);
        assert_eq!(both.advance(), 30);
        assert_eq!(both.advance(), TERMINATED);
    }

    #[test]
    fn test_size_hint() {
        let mut set = sorted(&[1, 2, 3, 4, 5]);
        assert_eq!(set.size_hint(), 5);
        set.advance();
        assert_eq!(set.size_hint(), 4);
        set.seek(4);
        assert_eq!(set.size_hint(), 2);
    }
}