1use std::collections::HashSet;
53use std::fmt;
54use std::sync::Arc;
55
/// Set of ids a query is permitted to touch, in one of several representations.
///
/// The payload-carrying variants keep their data behind an [`Arc`], so cloning
/// an `AllowedSet` does not copy the ids.
#[derive(Clone)]
pub enum AllowedSet {
    /// Every id is allowed; `contains` answers `true` for any id.
    All,

    /// Ids stored as a bit set.
    Bitmap(Arc<AllowedBitmap>),

    /// Ids stored in a vector. `contains` binary-searches it, so the vector
    /// must be sorted ascending (`from_sorted_vec` sorts and de-duplicates).
    SortedVec(Arc<Vec<u64>>),

    /// Ids stored in a hash set.
    HashSet(Arc<HashSet<u64>>),

    /// No id is allowed; `contains` answers `false` for any id.
    None,
}
81
82impl AllowedSet {
83 pub fn from_bitmap(bitmap: AllowedBitmap) -> Self {
85 if bitmap.is_empty() {
86 Self::None
87 } else if bitmap.is_all() {
88 Self::All
89 } else {
90 Self::Bitmap(Arc::new(bitmap))
91 }
92 }
93
94 pub fn from_sorted_vec(mut ids: Vec<u64>) -> Self {
96 if ids.is_empty() {
97 return Self::None;
98 }
99 ids.sort_unstable();
100 ids.dedup();
101 Self::SortedVec(Arc::new(ids))
102 }
103
104 pub fn from_iter(ids: impl IntoIterator<Item = u64>) -> Self {
106 let set: HashSet<u64> = ids.into_iter().collect();
107 if set.is_empty() {
108 Self::None
109 } else {
110 Self::HashSet(Arc::new(set))
111 }
112 }
113
114 #[inline]
118 pub fn contains(&self, doc_id: u64) -> bool {
119 match self {
120 Self::All => true,
121 Self::Bitmap(bm) => bm.contains(doc_id),
122 Self::SortedVec(vec) => vec.binary_search(&doc_id).is_ok(),
123 Self::HashSet(set) => set.contains(&doc_id),
124 Self::None => false,
125 }
126 }
127
128 pub fn is_empty(&self) -> bool {
130 matches!(self, Self::None)
131 }
132
133 pub fn is_all(&self) -> bool {
135 matches!(self, Self::All)
136 }
137
138 pub fn cardinality(&self) -> Option<usize> {
142 match self {
143 Self::All => None,
144 Self::Bitmap(bm) => Some(bm.count()),
145 Self::SortedVec(vec) => Some(vec.len()),
146 Self::HashSet(set) => Some(set.len()),
147 Self::None => Some(0),
148 }
149 }
150
151 pub fn selectivity(&self, universe_size: usize) -> f64 {
155 if universe_size == 0 {
156 return 0.0;
157 }
158 match self {
159 Self::All => 1.0,
160 Self::None => 0.0,
161 other => {
162 other.cardinality()
163 .map(|c| c as f64 / universe_size as f64)
164 .unwrap_or(1.0)
165 }
166 }
167 }
168
169 pub fn intersect(&self, other: &AllowedSet) -> AllowedSet {
171 match (self, other) {
172 (Self::All, x) | (x, Self::All) => x.clone(),
174 (Self::None, _) | (_, Self::None) => Self::None,
175
176 (Self::SortedVec(a), Self::SortedVec(b)) => {
178 let result = sorted_vec_intersect(a, b);
179 Self::from_sorted_vec(result)
180 }
181 (Self::HashSet(a), Self::HashSet(b)) => {
182 let result: HashSet<_> = a.intersection(b).copied().collect();
183 if result.is_empty() {
184 Self::None
185 } else {
186 Self::HashSet(Arc::new(result))
187 }
188 }
189 (Self::Bitmap(a), Self::Bitmap(b)) => {
190 let result = a.intersect(b);
191 Self::from_bitmap(result)
192 }
193
194 (a, b) => {
196 let set_a: HashSet<u64> = a.iter().collect();
197 let set_b: HashSet<u64> = b.iter().collect();
198 let result: HashSet<_> = set_a.intersection(&set_b).copied().collect();
199 if result.is_empty() {
200 Self::None
201 } else {
202 Self::HashSet(Arc::new(result))
203 }
204 }
205 }
206 }
207
208 pub fn union(&self, other: &AllowedSet) -> AllowedSet {
210 match (self, other) {
211 (Self::All, _) | (_, Self::All) => Self::All,
212 (Self::None, x) | (x, Self::None) => x.clone(),
213
214 (Self::HashSet(a), Self::HashSet(b)) => {
215 let result: HashSet<_> = a.union(b).copied().collect();
216 Self::HashSet(Arc::new(result))
217 }
218
219 (a, b) => {
221 let mut result: HashSet<u64> = a.iter().collect();
222 result.extend(b.iter());
223 Self::HashSet(Arc::new(result))
224 }
225 }
226 }
227
228 pub fn iter(&self) -> AllowedSetIter<'_> {
232 match self {
233 Self::All => AllowedSetIter::Empty,
234 Self::Bitmap(bm) => AllowedSetIter::Bitmap(bm.iter()),
235 Self::SortedVec(vec) => AllowedSetIter::SortedVec(vec.iter()),
236 Self::HashSet(set) => AllowedSetIter::HashSet(set.iter()),
237 Self::None => AllowedSetIter::Empty,
238 }
239 }
240
241 pub fn to_vec(&self) -> Vec<u64> {
243 self.iter().collect()
244 }
245}
246
247impl fmt::Debug for AllowedSet {
248 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
249 match self {
250 Self::All => write!(f, "AllowedSet::All"),
251 Self::None => write!(f, "AllowedSet::None"),
252 Self::Bitmap(bm) => write!(f, "AllowedSet::Bitmap(count={})", bm.count()),
253 Self::SortedVec(vec) => write!(f, "AllowedSet::SortedVec(len={})", vec.len()),
254 Self::HashSet(set) => write!(f, "AllowedSet::HashSet(len={})", set.len()),
255 }
256 }
257}
258
259impl Default for AllowedSet {
260 fn default() -> Self {
261 Self::All
262 }
263}
264
/// Intersects two id slices with a single merge pass.
///
/// Both slices are expected to be sorted ascending; the common ids are
/// returned in the same order. Whenever the two heads differ, the smaller one
/// is discarded, so each element is inspected at most once.
fn sorted_vec_intersect(a: &[u64], b: &[u64]) -> Vec<u64> {
    use std::cmp::Ordering;

    let mut common = Vec::with_capacity(a.len().min(b.len()));
    let (mut left, mut right) = (a, b);

    while let (Some((&x, left_tail)), Some((&y, right_tail))) =
        (left.split_first(), right.split_first())
    {
        match x.cmp(&y) {
            Ordering::Less => left = left_tail,
            Ordering::Greater => right = right_tail,
            Ordering::Equal => {
                common.push(x);
                left = left_tail;
                right = right_tail;
            }
        }
    }

    common
}
285
/// Iterator returned by `AllowedSet::iter`, with one variant per backing
/// representation.
pub enum AllowedSetIter<'a> {
    /// Yields nothing; `AllowedSet::iter` returns this for both `All` and `None`.
    Empty,
    /// Walks a bitmap through its `BitmapIter`.
    Bitmap(BitmapIter<'a>),
    /// Borrows the ids of a vector-backed set.
    SortedVec(std::slice::Iter<'a, u64>),
    /// Borrows the ids of a hash-set-backed set.
    HashSet(std::collections::hash_set::Iter<'a, u64>),
}
297
298impl<'a> Iterator for AllowedSetIter<'a> {
299 type Item = u64;
300
301 fn next(&mut self) -> Option<Self::Item> {
302 match self {
303 Self::Empty => None,
304 Self::Bitmap(iter) => iter.next(),
305 Self::SortedVec(iter) => iter.next().copied(),
306 Self::HashSet(iter) => iter.next().copied(),
307 }
308 }
309
310 fn size_hint(&self) -> (usize, Option<usize>) {
311 match self {
312 Self::Empty => (0, Some(0)),
313 Self::Bitmap(iter) => iter.size_hint(),
314 Self::SortedVec(iter) => iter.size_hint(),
315 Self::HashSet(iter) => iter.size_hint(),
316 }
317 }
318}
319
/// Bit set over ids: id `n` is a member when bit `n % 64` of `words[n / 64]`
/// is set.
pub struct AllowedBitmap {
    /// Backing storage, 64 ids per word.
    words: Vec<u64>,
    /// Cached member count; this is what `count()` returns.
    count: usize,
    /// Flag reported by `is_all()`; `true` only for bitmaps built by `all()`.
    all: bool,
}
336
337impl AllowedBitmap {
338 pub fn new() -> Self {
340 Self {
341 words: Vec::new(),
342 count: 0,
343 all: false,
344 }
345 }
346
347 pub fn all(max_id: u64) -> Self {
349 let word_count = (max_id as usize / 64) + 1;
350 Self {
351 words: vec![u64::MAX; word_count],
352 count: max_id as usize + 1,
353 all: true,
354 }
355 }
356
357 pub fn from_ids(ids: &[u64]) -> Self {
359 if ids.is_empty() {
360 return Self::new();
361 }
362
363 let max_id = *ids.iter().max().unwrap();
364 let word_count = (max_id as usize / 64) + 1;
365 let mut words = vec![0u64; word_count];
366
367 for &id in ids {
368 let word_idx = id as usize / 64;
369 let bit_idx = id % 64;
370 words[word_idx] |= 1 << bit_idx;
371 }
372
373 Self {
374 words,
375 count: ids.len(),
376 all: false,
377 }
378 }
379
380 pub fn set(&mut self, id: u64) {
382 let word_idx = id as usize / 64;
383 let bit_idx = id % 64;
384
385 if word_idx >= self.words.len() {
387 self.words.resize(word_idx + 1, 0);
388 }
389
390 let old = self.words[word_idx];
391 self.words[word_idx] |= 1 << bit_idx;
392 if old != self.words[word_idx] {
393 self.count += 1;
394 }
395 }
396
397 pub fn clear(&mut self, id: u64) {
399 let word_idx = id as usize / 64;
400 if word_idx >= self.words.len() {
401 return;
402 }
403
404 let bit_idx = id % 64;
405 let old = self.words[word_idx];
406 self.words[word_idx] &= !(1 << bit_idx);
407 if old != self.words[word_idx] {
408 self.count -= 1;
409 }
410 }
411
412 #[inline]
414 pub fn contains(&self, id: u64) -> bool {
415 let word_idx = id as usize / 64;
416 if word_idx >= self.words.len() {
417 return false;
418 }
419 let bit_idx = id % 64;
420 (self.words[word_idx] & (1 << bit_idx)) != 0
421 }
422
423 pub fn count(&self) -> usize {
425 self.count
426 }
427
428 pub fn is_empty(&self) -> bool {
430 self.count == 0
431 }
432
433 pub fn is_all(&self) -> bool {
435 self.all
436 }
437
438 pub fn intersect(&self, other: &AllowedBitmap) -> AllowedBitmap {
440 let min_len = self.words.len().min(other.words.len());
441 let mut words = Vec::with_capacity(min_len);
442 let mut count = 0;
443
444 for i in 0..min_len {
445 let word = self.words[i] & other.words[i];
446 count += word.count_ones() as usize;
447 words.push(word);
448 }
449
450 AllowedBitmap {
451 words,
452 count,
453 all: false,
454 }
455 }
456
457 pub fn union(&self, other: &AllowedBitmap) -> AllowedBitmap {
459 let max_len = self.words.len().max(other.words.len());
460 let mut words = Vec::with_capacity(max_len);
461 let mut count = 0;
462
463 for i in 0..max_len {
464 let a = self.words.get(i).copied().unwrap_or(0);
465 let b = other.words.get(i).copied().unwrap_or(0);
466 let word = a | b;
467 count += word.count_ones() as usize;
468 words.push(word);
469 }
470
471 AllowedBitmap {
472 words,
473 count,
474 all: false,
475 }
476 }
477
478 pub fn iter(&self) -> BitmapIter<'_> {
480 BitmapIter {
481 words: &self.words,
482 word_idx: 0,
483 bit_offset: 0,
484 remaining: self.count,
485 }
486 }
487}
488
489impl Default for AllowedBitmap {
490 fn default() -> Self {
491 Self::new()
492 }
493}
494
/// Cursor over the set bits of a slice of bitmap words.
pub struct BitmapIter<'a> {
    /// Words being scanned.
    words: &'a [u64],
    /// Index of the word the cursor is currently in.
    word_idx: usize,
    /// Bit position inside the current word at which scanning resumes.
    bit_offset: u64,
    /// Number of ids still to be produced; also what `size_hint` reports.
    remaining: usize,
}
502
503impl<'a> Iterator for BitmapIter<'a> {
504 type Item = u64;
505
506 fn next(&mut self) -> Option<Self::Item> {
507 if self.remaining == 0 {
508 return None;
509 }
510
511 while self.word_idx < self.words.len() {
512 let word = self.words[self.word_idx];
513 let masked = word >> self.bit_offset;
514
515 if masked != 0 {
516 let trailing = masked.trailing_zeros() as u64;
517 let bit_pos = self.bit_offset + trailing;
518 self.bit_offset = bit_pos + 1;
519
520 if self.bit_offset >= 64 {
521 self.bit_offset = 0;
522 self.word_idx += 1;
523 }
524
525 self.remaining -= 1;
526 return Some(self.word_idx as u64 * 64 + bit_pos - (if self.bit_offset == 0 { 64 } else { 0 }) + (if bit_pos >= 64 { 0 } else { bit_pos }));
527 }
528
529 self.word_idx += 1;
530 self.bit_offset = 0;
531 }
532
533 None
534 }
535
536 fn size_hint(&self) -> (usize, Option<usize>) {
537 (self.remaining, Some(self.remaining))
538 }
539}
540
541impl<'a> BitmapIter<'a> {
543 #[allow(dead_code)]
544 fn new(words: &'a [u64], count: usize) -> Self {
545 Self {
546 words,
547 word_idx: 0,
548 bit_offset: 0,
549 remaining: count,
550 }
551 }
552}
553
554impl AllowedBitmap {
556 pub fn iter_simple(&self) -> impl Iterator<Item = u64> + '_ {
558 self.words.iter().enumerate().flat_map(|(word_idx, &word)| {
559 (0..64).filter_map(move |bit| {
560 if (word & (1 << bit)) != 0 {
561 Some(word_idx as u64 * 64 + bit as u64)
562 } else {
563 None
564 }
565 })
566 })
567 }
568}
569
570pub trait CandidateGate {
578 type Query;
580
581 type Result;
583
584 type Error;
586
587 fn execute_with_gate(
595 &self,
596 query: &Self::Query,
597 allowed_set: &AllowedSet,
598 ) -> Result<Self::Result, Self::Error>;
599
600 fn strategy_for_selectivity(&self, selectivity: f64) -> ExecutionStrategy {
602 if selectivity >= 0.1 {
603 ExecutionStrategy::FilterDuringSearch
604 } else if selectivity >= 0.001 {
605 ExecutionStrategy::ScanAllowedIds
606 } else {
607 ExecutionStrategy::LinearScan
608 }
609 }
610}
611
/// Execution strategies a `CandidateGate` can choose between.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExecutionStrategy {
    /// Returned by the default `strategy_for_selectivity` for selectivity `>= 0.1`.
    FilterDuringSearch,

    /// Returned by the default `strategy_for_selectivity` for selectivity in
    /// `0.001..0.1`.
    ScanAllowedIds,

    /// Returned by the default `strategy_for_selectivity` for anything below
    /// `0.001`.
    LinearScan,

    /// Never returned by the default `strategy_for_selectivity`.
    /// NOTE(review): presumably meant for gates that refuse to run a query —
    /// confirm against the implementors.
    Reject,
}
627
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `set` contains every id in `present` and none in `absent`.
    fn assert_membership(set: &AllowedSet, present: &[u64], absent: &[u64]) {
        for &id in present {
            assert!(set.contains(id));
        }
        for &id in absent {
            assert!(!set.contains(id));
        }
    }

    /// Asserts that `bitmap` contains every id in `present` and none in `absent`.
    fn assert_bits(bitmap: &AllowedBitmap, present: &[u64], absent: &[u64]) {
        for &id in present {
            assert!(bitmap.contains(id));
        }
        for &id in absent {
            assert!(!bitmap.contains(id));
        }
    }

    #[test]
    fn test_allowed_set_contains() {
        // The unrestricted and the empty set answer without looking at the id.
        assert_membership(&AllowedSet::All, &[0, 1000000], &[]);
        assert_membership(&AllowedSet::None, &[], &[0]);

        let sorted = AllowedSet::from_sorted_vec(vec![1, 3, 5, 7, 9]);
        assert_membership(&sorted, &[1, 5], &[2, 10]);

        let hashed = AllowedSet::from_iter([1, 3, 5, 7, 9]);
        assert_membership(&hashed, &[1, 5], &[2]);
    }

    #[test]
    fn test_allowed_set_selectivity() {
        let five_ids = AllowedSet::from_sorted_vec(vec![1, 2, 3, 4, 5]);

        let cases = [
            (five_ids.clone(), 100, 0.05),
            (five_ids, 10, 0.5),
            (AllowedSet::All, 100, 1.0),
            (AllowedSet::None, 100, 0.0),
        ];
        for (set, universe, expected) in cases.iter() {
            assert_eq!(set.selectivity(*universe), *expected);
        }
    }

    #[test]
    fn test_allowed_set_intersection() {
        let left = AllowedSet::from_sorted_vec(vec![1, 2, 3, 4, 5]);
        let right = AllowedSet::from_sorted_vec(vec![3, 4, 5, 6, 7]);

        let common = left.intersect(&right);
        assert_eq!(common.cardinality(), Some(3));
        assert_membership(&common, &[3, 4, 5], &[1, 7]);
    }

    #[test]
    fn test_bitmap_basic() {
        let members = [0u64, 5, 64, 100];

        let mut bitmap = AllowedBitmap::new();
        for &id in &members {
            bitmap.set(id);
        }

        assert_bits(&bitmap, &members, &[1, 63]);
        assert_eq!(bitmap.count(), 4);
    }

    #[test]
    fn test_bitmap_from_ids() {
        let ids = vec![1, 5, 10, 100, 1000];
        let bitmap = AllowedBitmap::from_ids(&ids);

        assert_bits(&bitmap, &ids, &[0, 50]);
    }

    #[test]
    fn test_bitmap_intersection() {
        let left = AllowedBitmap::from_ids(&[1, 2, 3, 4, 5]);
        let right = AllowedBitmap::from_ids(&[3, 4, 5, 6, 7]);

        let common = left.intersect(&right);
        assert_eq!(common.count(), 3);
        assert_bits(&common, &[3, 4, 5], &[]);
    }

    #[test]
    fn test_execution_strategy() {
        struct DummyGate;
        impl CandidateGate for DummyGate {
            type Query = ();
            type Result = ();
            type Error = ();
            fn execute_with_gate(&self, _: &(), _: &AllowedSet) -> Result<(), ()> {
                Ok(())
            }
        }

        let gate = DummyGate;
        let expectations = [
            (0.5, ExecutionStrategy::FilterDuringSearch),
            (0.01, ExecutionStrategy::ScanAllowedIds),
            (0.0001, ExecutionStrategy::LinearScan),
        ];
        for &(selectivity, expected) in expectations.iter() {
            assert_eq!(gate.strategy_for_selectivity(selectivity), expected);
        }
    }
}