use std::collections::HashSet;
use std::fmt;
use std::sync::Arc;
/// Set of document ids a query is permitted to return.
///
/// Several representations are accepted so callers can hand over whatever they
/// already have. `All` means "no restriction" and `None` means "nothing is
/// allowed". Non-trivial payloads sit behind `Arc`, so cloning an `AllowedSet`
/// only bumps a reference count and never copies the ids.
#[derive(Clone)]
pub enum AllowedSet {
    /// Every id is allowed.
    All,
    /// Membership backed by a dense bitmap.
    Bitmap(Arc<AllowedBitmap>),
    /// Membership by binary search. The vector must be sorted ascending:
    /// `from_sorted_vec` guarantees this, direct construction does not.
    SortedVec(Arc<Vec<u64>>),
    /// Membership backed by a hash set.
    HashSet(Arc<HashSet<u64>>),
    /// No id is allowed.
    None,
}

impl AllowedSet {
    /// Wraps a bitmap, normalising an empty bitmap to `None` and a bitmap
    /// flagged "all" (see `AllowedBitmap::is_all`) to `All`.
    pub fn from_bitmap(bitmap: AllowedBitmap) -> Self {
        if bitmap.is_empty() {
            Self::None
        } else if bitmap.is_all() {
            Self::All
        } else {
            Self::Bitmap(Arc::new(bitmap))
        }
    }

    /// Builds a `SortedVec` set from ids in any order.
    ///
    /// The ids are sorted and de-duplicated here, so the name describes the
    /// result rather than a precondition. Empty input yields `None`.
    pub fn from_sorted_vec(mut ids: Vec<u64>) -> Self {
        if ids.is_empty() {
            return Self::None;
        }
        ids.sort_unstable();
        ids.dedup();
        Self::SortedVec(Arc::new(ids))
    }

    /// Collects ids into a `HashSet` set. Empty input yields `None`.
    pub fn from_iter(ids: impl IntoIterator<Item = u64>) -> Self {
        let set: HashSet<u64> = ids.into_iter().collect();
        if set.is_empty() {
            Self::None
        } else {
            Self::HashSet(Arc::new(set))
        }
    }

    /// Returns whether `doc_id` is allowed.
    #[inline]
    pub fn contains(&self, doc_id: u64) -> bool {
        match self {
            Self::All => true,
            Self::Bitmap(bm) => bm.contains(doc_id),
            Self::SortedVec(vec) => vec.binary_search(&doc_id).is_ok(),
            Self::HashSet(set) => set.contains(&doc_id),
            Self::None => false,
        }
    }

    /// True only for the `None` variant.
    ///
    /// A `SortedVec`/`HashSet` constructed directly around an empty collection
    /// is not reported as empty; the constructors above never produce one.
    pub fn is_empty(&self) -> bool {
        matches!(self, Self::None)
    }

    /// True only for the `All` variant.
    pub fn is_all(&self) -> bool {
        matches!(self, Self::All)
    }

    /// Number of allowed ids, or `None` for `All` (no finite count).
    pub fn cardinality(&self) -> Option<usize> {
        match self {
            Self::All => None,
            Self::Bitmap(bm) => Some(bm.count()),
            Self::SortedVec(vec) => Some(vec.len()),
            Self::HashSet(set) => Some(set.len()),
            Self::None => Some(0),
        }
    }

    /// Fraction of a `universe_size`-document corpus that is allowed.
    ///
    /// Returns `0.0` for an empty universe. The ratio is not clamped, so a set
    /// larger than `universe_size` yields a value above `1.0`.
    pub fn selectivity(&self, universe_size: usize) -> f64 {
        if universe_size == 0 {
            return 0.0;
        }
        match self {
            Self::All => 1.0,
            Self::None => 0.0,
            other => other
                .cardinality()
                .map(|c| c as f64 / universe_size as f64)
                .unwrap_or(1.0),
        }
    }

    /// Returns the ids allowed by both `self` and `other`.
    ///
    /// Matching representations use a specialised algorithm. Mixed
    /// representations probe each id of the smaller side against the larger
    /// side's native `contains`, and the result is a `HashSet` or `None`.
    pub fn intersect(&self, other: &AllowedSet) -> AllowedSet {
        match (self, other) {
            (Self::All, x) | (x, Self::All) => x.clone(),
            (Self::None, _) | (_, Self::None) => Self::None,
            (Self::SortedVec(a), Self::SortedVec(b)) => {
                Self::from_sorted_vec(sorted_vec_intersect(a, b))
            }
            (Self::HashSet(a), Self::HashSet(b)) => Self::from_iter(a.intersection(b).copied()),
            (Self::Bitmap(a), Self::Bitmap(b)) => Self::from_bitmap(a.intersect(b)),
            (a, b) => {
                // Neither side is `All` here, so both cardinalities are `Some`.
                // Walking only the smaller side avoids materialising two hash sets.
                let (small, large) = if a.cardinality() <= b.cardinality() {
                    (a, b)
                } else {
                    (b, a)
                };
                Self::from_iter(small.iter().filter(|&id| large.contains(id)))
            }
        }
    }

    /// Returns the ids allowed by either `self` or `other`.
    ///
    /// Two bitmaps are merged word-wise into a bitmap. Other combinations
    /// produce a `HashSet`.
    pub fn union(&self, other: &AllowedSet) -> AllowedSet {
        match (self, other) {
            (Self::All, _) | (_, Self::All) => Self::All,
            (Self::None, x) | (x, Self::None) => x.clone(),
            // Both inputs are non-empty, so `from_bitmap` yields a `Bitmap`.
            (Self::Bitmap(a), Self::Bitmap(b)) => Self::from_bitmap(a.union(b)),
            (Self::HashSet(a), Self::HashSet(b)) => {
                Self::HashSet(Arc::new(a.union(b).copied().collect()))
            }
            (a, b) => Self::HashSet(Arc::new(a.iter().chain(b.iter()).collect())),
        }
    }

    /// Iterates the allowed ids.
    ///
    /// `All` has no finite id list and yields nothing, exactly like `None`.
    /// Check `is_all()` first if that distinction matters. Order is ascending
    /// for `SortedVec` and unspecified for `HashSet`.
    pub fn iter(&self) -> AllowedSetIter<'_> {
        match self {
            Self::All => AllowedSetIter::Empty,
            Self::Bitmap(bm) => AllowedSetIter::Bitmap(bm.iter()),
            Self::SortedVec(vec) => AllowedSetIter::SortedVec(vec.iter()),
            Self::HashSet(set) => AllowedSetIter::HashSet(set.iter()),
            Self::None => AllowedSetIter::Empty,
        }
    }

    /// Collects `iter()` into a vector. This is empty for `All`; see `iter`.
    pub fn to_vec(&self) -> Vec<u64> {
        self.iter().collect()
    }
}
impl fmt::Debug for AllowedSet {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::All => write!(f, "AllowedSet::All"),
Self::None => write!(f, "AllowedSet::None"),
Self::Bitmap(bm) => write!(f, "AllowedSet::Bitmap(count={})", bm.count()),
Self::SortedVec(vec) => write!(f, "AllowedSet::SortedVec(len={})", vec.len()),
Self::HashSet(set) => write!(f, "AllowedSet::HashSet(len={})", set.len()),
}
}
}
impl Default for AllowedSet {
fn default() -> Self {
Self::All
}
}
/// Intersects two ascending slices with a single linear merge pass.
///
/// Both inputs must be sorted ascending. The output is ascending. A value
/// repeated in both inputs is emitted once per matched pair. Runs in
/// O(a.len() + b.len()).
fn sorted_vec_intersect(a: &[u64], b: &[u64]) -> Vec<u64> {
    use std::cmp::Ordering;

    let mut common = Vec::with_capacity(a.len().min(b.len()));
    let (mut left, mut right) = (0usize, 0usize);
    while let (Some(&x), Some(&y)) = (a.get(left), b.get(right)) {
        match x.cmp(&y) {
            // Advance whichever cursor points at the smaller value.
            Ordering::Less => left += 1,
            Ordering::Greater => right += 1,
            Ordering::Equal => {
                common.push(x);
                left += 1;
                right += 1;
            }
        }
    }
    common
}
/// Iterator returned by `AllowedSet::iter`, yielding ids by value.
///
/// `Empty` is used for both `AllowedSet::All` and `AllowedSet::None`.
pub enum AllowedSetIter<'a> {
    /// Yields nothing.
    Empty,
    /// Walks the set bits of an `AllowedBitmap`.
    Bitmap(BitmapIter<'a>),
    /// Walks a sorted id vector in order.
    SortedVec(std::slice::Iter<'a, u64>),
    /// Walks a hash set in its internal (unspecified) order.
    HashSet(std::collections::hash_set::Iter<'a, u64>),
}

impl<'a> Iterator for AllowedSetIter<'a> {
    type Item = u64;

    fn next(&mut self) -> Option<u64> {
        match self {
            AllowedSetIter::Empty => None,
            AllowedSetIter::Bitmap(bits) => bits.next(),
            AllowedSetIter::SortedVec(ids) => ids.next().copied(),
            AllowedSetIter::HashSet(ids) => ids.next().copied(),
        }
    }

    /// Delegates to the wrapped iterator. `Empty` is exactly zero.
    fn size_hint(&self) -> (usize, Option<usize>) {
        match self {
            AllowedSetIter::Empty => (0, Some(0)),
            AllowedSetIter::Bitmap(bits) => bits.size_hint(),
            AllowedSetIter::SortedVec(ids) => ids.size_hint(),
            AllowedSetIter::HashSet(ids) => ids.size_hint(),
        }
    }
}
/// Dense membership bitmap over `u64` document ids.
///
/// Id `n` is stored as bit `n % 64` of `words[n / 64]`, so memory grows with
/// the largest id stored rather than with the number of ids. `count` caches
/// the number of set bits so `count()` is O(1). `all` is set only by
/// `AllowedBitmap::all` and is dropped as soon as any id is cleared.
pub struct AllowedBitmap {
    /// Packed membership bits, least-significant bit first within each word.
    words: Vec<u64>,
    /// Number of set bits across `words`; every mutator keeps it in sync.
    count: usize,
    /// True while the bitmap still holds every id of the `0..=max_id` range it
    /// was created with by `AllowedBitmap::all`.
    all: bool,
}

impl AllowedBitmap {
    /// Creates an empty bitmap that allocates nothing until the first `set`.
    pub fn new() -> Self {
        Self {
            words: Vec::new(),
            count: 0,
            all: false,
        }
    }

    /// Creates a bitmap containing exactly the ids `0..=max_id`, flagged "all".
    ///
    /// Bits above `max_id` in the final word are left clear, so `contains`,
    /// iteration and word-wise set operations never report ids past `max_id`.
    pub fn all(max_id: u64) -> Self {
        let word_count = (max_id as usize / 64) + 1;
        let mut words = vec![u64::MAX; word_count];
        // Live bits in the last word, in 1..=64; a full word needs no mask.
        let tail_bits = max_id % 64 + 1;
        if tail_bits < 64 {
            words[word_count - 1] = (1u64 << tail_bits) - 1;
        }
        Self {
            words,
            count: max_id as usize + 1,
            all: true,
        }
    }

    /// Builds a bitmap from ids in any order, possibly with duplicates.
    pub fn from_ids(ids: &[u64]) -> Self {
        let max_id = match ids.iter().max() {
            Some(&max_id) => max_id,
            None => return Self::new(),
        };
        let mut words = vec![0u64; (max_id as usize / 64) + 1];
        for &id in ids {
            words[id as usize / 64] |= 1u64 << (id % 64);
        }
        // Count from the bits, not `ids.len()`: duplicates collapse to one bit.
        let count = Self::count_bits(&words);
        Self {
            words,
            count,
            all: false,
        }
    }

    /// Adds `id`, growing the word vector if needed. No-op if already present.
    pub fn set(&mut self, id: u64) {
        let word_idx = id as usize / 64;
        if word_idx >= self.words.len() {
            self.words.resize(word_idx + 1, 0);
        }
        let mask = 1u64 << (id % 64);
        let word = &mut self.words[word_idx];
        if (*word & mask) == 0 {
            *word |= mask;
            self.count += 1;
        }
    }

    /// Removes `id`. No-op if absent.
    ///
    /// Removing a present id also clears the `all` flag, because the bitmap no
    /// longer covers its whole range.
    pub fn clear(&mut self, id: u64) {
        let mask = 1u64 << (id % 64);
        if let Some(word) = self.words.get_mut(id as usize / 64) {
            if (*word & mask) != 0 {
                *word &= !mask;
                self.count -= 1;
                self.all = false;
            }
        }
    }

    /// Returns whether `id` is a member. Ids past the last word are absent.
    #[inline]
    pub fn contains(&self, id: u64) -> bool {
        self.words
            .get(id as usize / 64)
            .map_or(false, |word| (word & (1u64 << (id % 64))) != 0)
    }

    /// Number of member ids.
    pub fn count(&self) -> usize {
        self.count
    }

    /// True when no id is a member.
    pub fn is_empty(&self) -> bool {
        self.count == 0
    }

    /// True for a bitmap built by `AllowedBitmap::all` that has not had an id cleared.
    pub fn is_all(&self) -> bool {
        self.all
    }

    /// Returns the ids present in both bitmaps. The result is never flagged "all".
    pub fn intersect(&self, other: &AllowedBitmap) -> AllowedBitmap {
        // Words past the shorter bitmap would be zero, so they are simply omitted.
        let words: Vec<u64> = self
            .words
            .iter()
            .zip(&other.words)
            .map(|(a, b)| a & b)
            .collect();
        let count = Self::count_bits(&words);
        AllowedBitmap {
            words,
            count,
            all: false,
        }
    }

    /// Returns the ids present in either bitmap. The result is never flagged "all".
    pub fn union(&self, other: &AllowedBitmap) -> AllowedBitmap {
        let (longer, shorter) = if self.words.len() >= other.words.len() {
            (&self.words, &other.words)
        } else {
            (&other.words, &self.words)
        };
        let words: Vec<u64> = longer
            .iter()
            .enumerate()
            .map(|(i, &word)| word | shorter.get(i).copied().unwrap_or(0))
            .collect();
        let count = Self::count_bits(&words);
        AllowedBitmap {
            words,
            count,
            all: false,
        }
    }

    /// Iterates member ids in ascending order, skipping zero runs via `trailing_zeros`.
    pub fn iter(&self) -> BitmapIter<'_> {
        BitmapIter::new(&self.words, self.count)
    }

    /// Simple bit-by-bit iterator over member ids in ascending order.
    ///
    /// Unlike `iter` it does not rely on the cached count; it tests all 64 bits
    /// of every word.
    pub fn iter_simple(&self) -> impl Iterator<Item = u64> + '_ {
        self.words.iter().enumerate().flat_map(|(word_idx, &word)| {
            (0..64u64)
                .filter(move |&bit| (word & (1u64 << bit)) != 0)
                .map(move |bit| word_idx as u64 * 64 + bit)
        })
    }

    /// Total number of set bits across `words`.
    fn count_bits(words: &[u64]) -> usize {
        words.iter().map(|word| word.count_ones() as usize).sum()
    }
}

impl Default for AllowedBitmap {
    fn default() -> Self {
        Self::new()
    }
}

/// Ascending iterator over the set bits of an `AllowedBitmap`.
pub struct BitmapIter<'a> {
    /// Words being scanned.
    words: &'a [u64],
    /// Index of the word currently being scanned.
    word_idx: usize,
    /// Next bit position to examine within `words[word_idx]`; always `< 64`.
    bit_offset: u64,
    /// Ids still to be yielded, taken from the bitmap's cached count.
    remaining: usize,
}

impl<'a> Iterator for BitmapIter<'a> {
    type Item = u64;

    fn next(&mut self) -> Option<Self::Item> {
        if self.remaining == 0 {
            return None;
        }
        while let Some(&word) = self.words.get(self.word_idx) {
            // `bit_offset < 64` is maintained below, so the shift cannot overflow.
            let pending = word >> self.bit_offset;
            if pending != 0 {
                let bit_pos = self.bit_offset + u64::from(pending.trailing_zeros());
                // Compute the id before advancing the cursor.
                let id = self.word_idx as u64 * 64 + bit_pos;
                if bit_pos == 63 {
                    self.word_idx += 1;
                    self.bit_offset = 0;
                } else {
                    self.bit_offset = bit_pos + 1;
                }
                self.remaining -= 1;
                return Some(id);
            }
            self.word_idx += 1;
            self.bit_offset = 0;
        }
        // Words exhausted: keep `size_hint` honest even if the count was stale.
        self.remaining = 0;
        None
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.remaining, Some(self.remaining))
    }
}

impl<'a> BitmapIter<'a> {
    /// Starts iteration at id 0, expecting `count` set bits in `words`.
    fn new(words: &'a [u64], count: usize) -> Self {
        Self {
            words,
            word_idx: 0,
            bit_offset: 0,
            remaining: count,
        }
    }
}
/// A search component that can execute a query under an `AllowedSet` restriction.
pub trait CandidateGate {
    /// Query type accepted by `execute_with_gate`.
    type Query;
    /// Successful output of `execute_with_gate`.
    type Result;
    /// Failure type of `execute_with_gate`.
    type Error;

    /// Executes `query` with `allowed_set` as the candidate restriction.
    ///
    /// How the restriction is applied is left to the implementor.
    fn execute_with_gate(
        &self,
        query: &Self::Query,
        allowed_set: &AllowedSet,
    ) -> Result<Self::Result, Self::Error>;

    /// Chooses an execution strategy from the fraction of documents allowed.
    ///
    /// - `selectivity >= 0.1` gives `FilterDuringSearch`
    /// - `selectivity >= 0.001` gives `ScanAllowedIds`
    /// - anything lower, including NaN, gives `LinearScan`
    ///
    /// The default never returns `ExecutionStrategy::Reject`; override this
    /// method to change the thresholds.
    fn strategy_for_selectivity(&self, selectivity: f64) -> ExecutionStrategy {
        match selectivity {
            s if s >= 0.1 => ExecutionStrategy::FilterDuringSearch,
            s if s >= 0.001 => ExecutionStrategy::ScanAllowedIds,
            // NaN fails both comparisons above and lands here.
            _ => ExecutionStrategy::LinearScan,
        }
    }
}
/// How a gated query should be executed, as chosen by
/// `CandidateGate::strategy_for_selectivity`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExecutionStrategy {
    /// Chosen by the default strategy when selectivity is at least 0.1.
    FilterDuringSearch,
    /// Chosen by the default strategy when selectivity is in `[0.001, 0.1)`.
    ScanAllowedIds,
    /// Chosen by the default strategy for selectivity below 0.001 (or NaN).
    LinearScan,
    /// Never returned by the default strategy; presumably for implementors
    /// that refuse a query outright. NOTE(review): confirm intended use.
    Reject,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_allowed_set_contains() {
        let all = AllowedSet::All;
        for id in [0, 1_000_000] {
            assert!(all.contains(id));
        }
        assert!(!AllowedSet::None.contains(0));

        let odds = [1, 3, 5, 7, 9];
        let sorted = AllowedSet::from_sorted_vec(odds.to_vec());
        for (id, expected) in [(1, true), (5, true), (2, false), (10, false)] {
            assert_eq!(sorted.contains(id), expected, "sorted vec, id {}", id);
        }

        let hashed = AllowedSet::from_iter(odds);
        for (id, expected) in [(1, true), (5, true), (2, false)] {
            assert_eq!(hashed.contains(id), expected, "hash set, id {}", id);
        }
    }

    #[test]
    fn test_allowed_set_selectivity() {
        let five = AllowedSet::from_sorted_vec(vec![1, 2, 3, 4, 5]);
        assert_eq!(five.selectivity(100), 0.05);
        assert_eq!(five.selectivity(10), 0.5);
        assert_eq!(AllowedSet::All.selectivity(100), 1.0);
        assert_eq!(AllowedSet::None.selectivity(100), 0.0);
    }

    #[test]
    fn test_allowed_set_intersection() {
        let low = AllowedSet::from_sorted_vec(vec![1, 2, 3, 4, 5]);
        let high = AllowedSet::from_sorted_vec(vec![3, 4, 5, 6, 7]);
        let overlap = low.intersect(&high);
        assert_eq!(overlap.cardinality(), Some(3));
        for id in [3, 4, 5] {
            assert!(overlap.contains(id));
        }
        for id in [1, 7] {
            assert!(!overlap.contains(id));
        }
    }

    #[test]
    fn test_bitmap_basic() {
        let members = [0, 5, 64, 100];
        let mut bm = AllowedBitmap::new();
        for id in members {
            bm.set(id);
        }
        for id in members {
            assert!(bm.contains(id));
        }
        for id in [1, 63] {
            assert!(!bm.contains(id));
        }
        assert_eq!(bm.count(), members.len());
    }

    #[test]
    fn test_bitmap_from_ids() {
        let ids = [1u64, 5, 10, 100, 1000];
        let bm = AllowedBitmap::from_ids(&ids);
        assert!(ids.iter().all(|&id| bm.contains(id)));
        for id in [0, 50] {
            assert!(!bm.contains(id));
        }
    }

    #[test]
    fn test_bitmap_intersection() {
        let left = AllowedBitmap::from_ids(&[1, 2, 3, 4, 5]);
        let right = AllowedBitmap::from_ids(&[3, 4, 5, 6, 7]);
        let shared = left.intersect(&right);
        assert_eq!(shared.count(), 3);
        for id in [3, 4, 5] {
            assert!(shared.contains(id));
        }
    }

    #[test]
    fn test_execution_strategy() {
        struct DummyGate;

        impl CandidateGate for DummyGate {
            type Query = ();
            type Result = ();
            type Error = ();

            fn execute_with_gate(&self, _: &(), _: &AllowedSet) -> Result<(), ()> {
                Ok(())
            }
        }

        let gate = DummyGate;
        let cases = [
            (0.5, ExecutionStrategy::FilterDuringSearch),
            (0.01, ExecutionStrategy::ScanAllowedIds),
            (0.0001, ExecutionStrategy::LinearScan),
        ];
        for (selectivity, expected) in cases {
            assert_eq!(gate.strategy_for_selectivity(selectivity), expected);
        }
    }
}