use std::{fmt, hash::Hash, marker::PhantomData};
/// Integer types that can be stored in a [`BitSet`].
///
/// An element is turned into a bit index with `TryInto<usize>` when it is
/// inserted, removed, or looked up, and is rebuilt from that index with
/// `TryFrom<usize>` while iterating.
pub(crate) trait BitSetElement:
    Copy + Default + Eq + Hash + fmt::Debug + TryFrom<usize> + TryInto<usize>
{
}

impl BitSetElement for usize {}
impl BitSetElement for u64 {}
impl BitSetElement for u32 {}
/// A growable set of integers backed by a bitmap.
///
/// Element `n` lives at bit `n % BUCKET_SIZE` of bucket `n / BUCKET_SIZE`.
/// Buckets beyond the end of the vector behave as if they were all zero.
#[derive(Default, Clone)]
pub(crate) struct BitSet<Elem: BitSetElement = usize>(
    /// One word per bucket; bit `i` of word `b` marks element `b * BUCKET_SIZE + i`.
    Vec<usize>,
    /// Records the element type without storing any values of it.
    PhantomData<Elem>,
);
/// Number of bits in one bucket, i.e. in one `usize` word.
const BUCKET_SIZE: usize = usize::BITS as usize;

/// Splits bit index `n` into its bucket index and a single-bit mask that
/// selects it inside that bucket.
const fn bucket_and_bit(n: usize) -> (usize, usize) {
    let bucket = n / BUCKET_SIZE;
    let offset = n % BUCKET_SIZE;
    (bucket, 1 << offset)
}

/// Inverse of [`bucket_and_bit`]: rebuilds a bit index from a bucket index and
/// a single-bit mask such as the one `bucket_and_bit` returns.
const fn from_bucket_and_bit(bucket: usize, bit: usize) -> usize {
    let offset = bit.trailing_zeros() as usize;
    (bucket * BUCKET_SIZE) | offset
}
impl<Elem: BitSetElement> BitSet<Elem> {
    /// Creates an empty set; no buckets are allocated until the first insert.
    pub(crate) fn new() -> Self {
        Self(Vec::new(), PhantomData)
    }

    /// Creates an empty set that reserves room for the buckets covering
    /// elements `0..capacity`, without creating any bucket yet.
    pub(crate) fn with_capacity(capacity: usize) -> Self {
        Self(Vec::with_capacity(Self::buckets_for(capacity)), PhantomData)
    }

    /// Creates an empty set whose buckets covering `0..capacity` already exist
    /// and are all zero, so they are reachable through `as_mut_slice`.
    pub(crate) fn with_zeroed(capacity: usize) -> Self {
        Self(vec![0; Self::buckets_for(capacity)], PhantomData)
    }

    /// Adds `value` to the set, growing the bucket vector if needed.
    ///
    /// Returns `true` if `value` was not present before. If `value` cannot be
    /// converted to `usize`, debug builds panic and release builds return
    /// `false` without touching the set.
    pub(crate) fn insert(&mut self, value: Elem) -> bool {
        let (bucket, bit) = match Self::locate(value) {
            Some(position) => position,
            #[cfg(debug_assertions)]
            None => unreachable!("error converting {:?} to usize", value),
            #[cfg(not(debug_assertions))]
            None => return false,
        };
        if self.0.len() <= bucket {
            self.0.resize(bucket + 1, 0);
        }
        let word = &mut self.0[bucket];
        let was_absent = (*word & bit) == 0;
        *word |= bit;
        was_absent
    }

    /// Removes `value` from the set and returns whether it was present.
    ///
    /// The bucket vector is never shrunk, so emptied buckets stay allocated.
    pub(crate) fn remove(&mut self, value: Elem) -> bool {
        let (bucket, bit) = match Self::locate(value) {
            Some(position) => position,
            None => return false,
        };
        if let Some(word) = self.0.get_mut(bucket) {
            let was_present = (*word & bit) != 0;
            *word &= !bit;
            was_present
        } else {
            false
        }
    }

    /// Removes every element by dropping all buckets.
    pub(crate) fn clear(&mut self) {
        self.0.clear();
    }

    /// Returns whether `value` is in the set; values that cannot be converted
    /// to `usize` are never contained.
    pub(crate) fn contains(&self, value: Elem) -> bool {
        match Self::locate(value) {
            // A bucket past the end of the vector counts as all-zero.
            Some((bucket, bit)) => (self.0.get(bucket).copied().unwrap_or(0) & bit) != 0,
            None => false,
        }
    }

    /// Returns `true` when no bit is set; allocated all-zero buckets are ignored.
    pub(crate) fn is_empty(&self) -> bool {
        self.0.iter().all(|&word| word == 0)
    }

    /// Returns `true` when the two sets share at least one element.
    pub(crate) fn intersects_with(&self, other: &Self) -> bool {
        self.0
            .iter()
            .zip(other.0.iter())
            .any(|(&mine, &theirs)| (mine & theirs) != 0)
    }

    /// Iterates over the elements in ascending order.
    pub(crate) fn iter(&self) -> impl Iterator<Item = Elem> + '_ {
        BitSetIter {
            bitset: self,
            current_bucket: 0,
            next_index: 0,
        }
    }

    /// Borrows the raw bucket words.
    pub(crate) fn as_slice(&self) -> &[usize] {
        self.0.as_slice()
    }

    /// Mutably borrows the raw bucket words; bits can be flipped directly but
    /// no buckets can be added this way.
    pub(crate) fn as_mut_slice(&mut self) -> &mut [usize] {
        self.0.as_mut_slice()
    }

    /// Number of buckets needed to cover bit indices `0..capacity` (rounded up).
    fn buckets_for(capacity: usize) -> usize {
        (capacity + BUCKET_SIZE - 1) / BUCKET_SIZE
    }

    /// Maps an element to its `(bucket, bit mask)` position, or `None` when the
    /// element cannot be converted to a `usize` index.
    fn locate(value: Elem) -> Option<(usize, usize)> {
        let index: usize = value.try_into().ok()?;
        Some(bucket_and_bit(index))
    }
}
impl<Elem: BitSetElement> PartialEq for BitSet<Elem> {
    /// Two sets are equal when they contain the same elements, no matter how
    /// many trailing all-zero buckets either of them has allocated.
    fn eq(&self, other: &Self) -> bool {
        let (short, long) = if self.0.len() <= other.0.len() {
            (&self.0, &other.0)
        } else {
            (&other.0, &self.0)
        };
        // The overlapping prefix must match word for word; whatever the longer
        // vector has beyond that must be empty.
        let (shared, excess) = long.split_at(short.len());
        short.as_slice() == shared && excess.iter().all(|&word| word == 0)
    }
}

impl<Elem: BitSetElement> Eq for BitSet<Elem> {}
impl<Elem: BitSetElement> Hash for BitSet<Elem> {
    /// Feeds the set's significant buckets to `state`.
    ///
    /// `PartialEq` treats trailing all-zero buckets as absent (they are left
    /// behind by `remove`, `with_zeroed`, or writes through `as_mut_slice`),
    /// so they must not influence the hash either: `a == b` has to imply
    /// `hash(a) == hash(b)`. Leading zero buckets, on the other hand, are
    /// significant (`[0, 1]` and `[1]` are different sets) and are kept.
    ///
    /// Hashing the trimmed slice also hashes its length, so words that only
    /// differ by bucket position still give the hasher different input.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        // Length of the prefix ending at the last non-zero bucket.
        let significant = self
            .0
            .iter()
            .rposition(|&bucket| bucket != 0)
            .map_or(0, |last| last + 1);
        self.0[..significant].hash(state);
    }
}
impl<Elem: BitSetElement> fmt::Debug for BitSet<Elem> {
    /// Formats the set as `{a, b, c}`, listing elements in ascending order.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut builder = f.debug_set();
        for element in self.iter() {
            builder.entry(&element);
        }
        builder.finish()
    }
}
/// Iterator over the elements of a `BitSet`, in ascending order.
struct BitSetIter<'a, Elem: BitSetElement> {
    /// The set being traversed.
    bitset: &'a BitSet<Elem>,
    /// Index of the next bucket to load; the bucket being drained is `next_index - 1`.
    next_index: usize,
    /// Bits of the bucket being drained that have not been yielded yet.
    current_bucket: usize,
}

impl<Elem: BitSetElement> Iterator for BitSetIter<'_, Elem> {
    type Item = Elem;

    fn next(&mut self) -> Option<Self::Item> {
        // Skip empty buckets; `?` ends iteration once we run past the last one.
        while self.current_bucket == 0 {
            let &word = self.bitset.0.get(self.next_index)?;
            self.current_bucket = word;
            self.next_index += 1;
        }
        // Isolate the lowest pending bit, then clear it from the pending bits.
        let lowest_bit = self.current_bucket & self.current_bucket.wrapping_neg();
        self.current_bucket &= self.current_bucket - 1;
        let index = from_bucket_and_bit(self.next_index - 1, lowest_bit);
        match Elem::try_from(index) {
            Ok(element) => Some(element),
            #[cfg(debug_assertions)]
            Err(_) => unreachable!(""),
            #[cfg(not(debug_assertions))]
            Err(_) => Some(Elem::default()),
        }
    }
}
impl<Elem: BitSetElement> FromIterator<Elem> for BitSet<Elem> {
    /// Builds a set holding every element yielded by `iter`; repeated
    /// elements collapse into a single entry.
    fn from_iter<T: IntoIterator<Item = Elem>>(iter: T) -> Self {
        let mut set = Self::new();
        iter.into_iter().for_each(|element| {
            set.insert(element);
        });
        set
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_common::collect;

    #[test]
    fn bucket_and_bit_conversions() {
        for n in [1, 2, 3, 5, 512, 513, 1000] {
            let (bucket, bit) = bucket_and_bit(n);
            let n2 = from_bucket_and_bit(bucket, bit);
            assert_eq!(n, n2, "bucket = {bucket}, bit = {bit:#b}");
        }
        assert_eq!(bucket_and_bit(0), (0, 0b1));
        assert_eq!(bucket_and_bit(2), (0, 0b100));
        assert_eq!(bucket_and_bit(5), (0, 0b100000));
        // Written in terms of BUCKET_SIZE so it also holds for 32-bit `usize`;
        // on 64-bit targets this is `bucket_and_bit(513) == (8, 0b10)`.
        assert_eq!(bucket_and_bit(8 * BUCKET_SIZE + 1), (8, 0b10));
        // Last bit of the first bucket.
        assert_eq!(
            bucket_and_bit(BUCKET_SIZE - 1),
            (0, 1usize << (BUCKET_SIZE - 1))
        );
    }

    #[test]
    fn basic_contain_elements() {
        let mut bitset = BitSet::<u64>::new();
        assert!(bitset.is_empty());
        assert!(bitset.insert(5));
        assert!(!bitset.is_empty());
        assert!(!bitset.insert(5));
        assert!(bitset.insert(1));
        assert!(bitset.insert(2));
        assert!(bitset.insert(3));
        assert!(!bitset.insert(1));
        assert!(bitset.insert(512));
        assert!(bitset.insert(513));
        assert!(bitset.insert(514));
        assert!(bitset.insert(1000));
        assert!(!bitset.is_empty());
        assert!(bitset.contains(1));
        assert!(bitset.contains(2));
        assert!(bitset.contains(3));
        assert!(bitset.contains(5));
        assert!(bitset.contains(512));
        assert!(bitset.contains(513));
        assert!(bitset.contains(514));
        assert!(bitset.contains(1000));
        assert!(!bitset.contains(4));
        assert!(!bitset.contains(6));
        assert!(!bitset.contains(511));
        assert!(!bitset.contains(515));
        assert!(!bitset.contains(998));
        assert!(!bitset.contains(999));
        assert!(!bitset.contains(1001));
    }

    #[test]
    fn remove_elements() {
        let mut bitset: BitSet = [1, 5, 700].into_iter().collect();
        assert!(bitset.remove(5));
        assert!(!bitset.remove(5));
        assert!(!bitset.contains(5));
        assert!(bitset.contains(1));
        assert!(bitset.contains(700));
        // Removing an element whose bucket was never allocated is a no-op.
        assert!(!bitset.remove(100_000));
        assert!(bitset.remove(1));
        assert!(bitset.remove(700));
        // Emptied buckets stay allocated but must not affect emptiness/equality.
        assert!(bitset.is_empty());
        assert_eq!(bitset, BitSet::new());
    }

    #[test]
    fn intersection() {
        let small: BitSet = [1, 2].into_iter().collect();
        let overlapping: BitSet = [2, 900].into_iter().collect();
        let disjoint: BitSet = [3, 900].into_iter().collect();
        assert!(small.intersects_with(&overlapping));
        assert!(overlapping.intersects_with(&small));
        assert!(!small.intersects_with(&disjoint));
        assert!(!BitSet::<usize>::new().intersects_with(&small));
    }

    #[test]
    fn equality() {
        let bitset1: BitSet = collect(&[1, 2, 3, 5]);
        let bitset2: BitSet = BitSet(vec![0b101110], PhantomData);
        let bitset3: BitSet = BitSet(vec![0b101110, 0], PhantomData);
        let bitset4: BitSet = BitSet(vec![0b11110], PhantomData);
        assert_eq!(bitset1, bitset2);
        assert_eq!(bitset1, bitset3);
        assert_eq!(bitset2, bitset3);
        assert_ne!(bitset1, bitset4);
        assert_ne!(bitset2, bitset4);
        assert_ne!(bitset3, bitset4);
    }

    #[test]
    fn clear_and_empty() {
        let empty_bitset = BitSet::new();
        let mut bitset = BitSet::new();
        for i in 0u64..50 {
            bitset.insert(i * 300);
        }
        assert_ne!(bitset, BitSet::new());
        bitset.clear();
        assert!(bitset.is_empty());
        assert_eq!(bitset, empty_bitset);
    }

    #[test]
    fn iterator() {
        let bitset: BitSet = collect(&[1, 2, 3, 5, 512, 513, 514, 1000]);
        assert_eq!(
            bitset.iter().collect::<Vec<_>>(),
            vec![1, 2, 3, 5, 512, 513, 514, 1000]
        );
    }
}