use std::{
fmt::Debug,
ops::{BitAnd, BitOr, BitXor, Sub},
panic::{RefUnwindSafe, UnwindSafe},
};
use num_traits::NumCast;
use crate::{integer::Integer, BackingType, BIT_WIDTH};
/// A set of integers restricted to the inclusive range `LOWER..=UPPER`, stored as a bitmap.
///
/// A value `x` maps to bit `(x - LOWER) % BIT_WIDTH` of word `(x - LOWER) / BIT_WIDTH`
/// of `data`. Operations that take a value outside the range panic (see `bounds_check!`),
/// except `take`, which returns `None`.
///
/// NOTE(review): with the `serde` feature, a deserialized value is not validated, so a
/// wrong `data` length or a `len` that disagrees with the bitmap is accepted as-is.
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct BitSet<const LOWER: usize, const UPPER: usize, T: Integer> {
// Bitmap words; sized by `Default` to cover every offset in `0..=UPPER - LOWER`.
data: Vec<BackingType>,
// Cached element count; must equal the number of set bits (checked in `len()`).
len: usize,
// `LOWER` converted to `T`, cached for bounds checks.
lower_cast: T,
// `UPPER` converted to `T`, cached for bounds checks.
upper_cast: T,
}
impl<const LOWER: usize, const UPPER: usize, T: Integer> Debug for BitSet<LOWER, UPPER, T> {
    /// Formats the set as `{a, b, c}`, listing members in ascending order.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_set();
        for member in self.iter() {
            builder.entry(&member);
        }
        builder.finish()
    }
}
impl<const LOWER: usize, const UPPER: usize, T: Integer> Default for BitSet<LOWER, UPPER, T> {
    /// Builds an empty set whose bitmap covers every value in `LOWER..=UPPER`.
    ///
    /// # Panics
    ///
    /// Panics if `LOWER` or `UPPER` cannot be converted into `T` via `NumCast`.
    fn default() -> Self {
        // Computed first so the range subtraction is evaluated before the casts.
        let word_count = (UPPER - LOWER).div_ceil(BIT_WIDTH) + 1;
        let lower_cast: T = NumCast::from(LOWER).unwrap_or_else(|| {
            panic!(
                "Unable to cast LOWER bound of BitSet<{}, {}> into \
                 associated type",
                LOWER, UPPER,
            )
        });
        let upper_cast: T = NumCast::from(UPPER).unwrap_or_else(|| {
            panic!(
                "Unable to cast UPPER bound of BitSet<{}, {}> into \
                 associated type",
                LOWER, UPPER,
            )
        });
        Self {
            data: vec![0; word_count],
            len: 0,
            lower_cast,
            upper_cast,
        }
    }
}
// Panics unless `$self.lower_cast <= $x <= $self.upper_cast`.
//
// `LOWER` and `UPPER` in the message are not macro arguments: they are resolved at the
// expansion site, so this macro only works inside items generic over const parameters
// with those names. `$x` is expanded up to three times (twice in the condition, once in
// the message), and the message formats it with `{:?}`, so `T` must be `Debug`.
macro_rules! bounds_check {
($x:expr, $self:expr) => {
assert!(
$x >= $self.lower_cast && $x <= $self.upper_cast,
"Out of bounds: BitSet<{}, {}> can never contain {:?}",
LOWER,
UPPER,
$x
);
};
}
impl<const LOWER: usize, const UPPER: usize, T: Integer> BitSet<LOWER, UPPER, T> {
    /// Creates an empty set covering `LOWER..=UPPER`.
    ///
    /// # Panics
    ///
    /// Panics if `LOWER` or `UPPER` cannot be converted into `T`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the number of elements in the set.
    pub fn len(&self) -> usize {
        // The cached counter must agree with the number of set bits. Summed in
        // `usize` so that very large ranges cannot overflow the check itself.
        debug_assert_eq!(
            self.data
                .iter()
                .map(|word| word.count_ones() as usize)
                .sum::<usize>(),
            self.len
        );
        self.len
    }

    /// Returns `true` if the set contains no elements.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Removes every element while keeping the bitmap allocation.
    pub fn clear(&mut self) {
        // The literal's type follows the backing word type instead of hard-coding `u64`.
        self.data.fill(0);
        self.len = 0;
    }

    /// Maps an in-range value to its `(word index, bit index)` inside `data`.
    ///
    /// The caller must ensure `LOWER <= x <= UPPER`. A smaller value underflows the
    /// subtraction, and a larger one may produce a word index past the end of `data`.
    fn position(x: T) -> (usize, usize) {
        let offset: usize = <usize as NumCast>::from(x).unwrap() - LOWER;
        (offset / BIT_WIDTH, offset % BIT_WIDTH)
    }

    /// Returns `true` if `x` is in the set.
    ///
    /// # Panics
    ///
    /// Panics if `x` lies outside `LOWER..=UPPER`.
    pub fn contains(&self, x: &T) -> bool {
        bounds_check!(*x, self);
        // SAFETY: `bounds_check!` has just established `LOWER <= x <= UPPER`.
        unsafe { self.contains_unsafe(*x) }
    }

    /// Membership test without a bounds check.
    ///
    /// # Safety
    ///
    /// `x` must lie in `LOWER..=UPPER`.
    unsafe fn contains_unsafe(&self, x: T) -> bool {
        let (idx, bit) = Self::position(x);
        self.is_bit_set(idx, bit)
    }

    /// Tests bit `bit` of word `idx`, skipping the bounds check in release builds.
    ///
    /// `idx` must be less than `self.data.len()`. Every caller derives it from
    /// `position` on an in-range value, which guarantees this.
    fn is_bit_set(&self, idx: usize, bit: usize) -> bool {
        debug_assert!(idx < self.data.len());
        // SAFETY: callers only pass word indices of in-range values (see above).
        unsafe { *self.data.get_unchecked(idx) & (1 << bit) != 0 }
    }

    /// Adds `x` to the set. Inserting a value that is already present is a no-op.
    ///
    /// # Panics
    ///
    /// Panics if `x` lies outside `LOWER..=UPPER`.
    pub fn insert(&mut self, x: T) {
        bounds_check!(x, self);
        // SAFETY: `bounds_check!` has just established `LOWER <= x <= UPPER`.
        unsafe { self.insert_unsafe(x) }
    }

    /// Sets the bit for `x`, incrementing `len` only if the bit was previously clear.
    ///
    /// # Safety
    ///
    /// `x` must lie in `LOWER..=UPPER`.
    unsafe fn insert_unsafe(&mut self, x: T) {
        let (idx, bit) = Self::position(x);
        if !self.is_bit_set(idx, bit) {
            // SAFETY: `x` is in range (caller contract), so `idx < self.data.len()`.
            unsafe { *self.data.get_unchecked_mut(idx) |= 1 << bit };
            self.len += 1;
        }
    }

    /// Removes `x` from the set. Removing an absent value is a no-op.
    ///
    /// # Panics
    ///
    /// Panics if `x` lies outside `LOWER..=UPPER`.
    pub fn remove(&mut self, x: &T) {
        bounds_check!(*x, self);
        // SAFETY: `bounds_check!` has just established `LOWER <= x <= UPPER`.
        unsafe { self.remove_unsafe(*x) }
    }

    /// Clears the bit for `x`, decrementing `len` only if the bit was previously set.
    ///
    /// The bit is cleared with `&= !mask`. Toggling it with `^=` would *insert* `x`
    /// whenever it was absent, leaving the bitmap out of sync with `len`.
    ///
    /// # Safety
    ///
    /// `x` must lie in `LOWER..=UPPER`.
    unsafe fn remove_unsafe(&mut self, x: T) {
        let (idx, bit) = Self::position(x);
        if self.is_bit_set(idx, bit) {
            // SAFETY: `x` is in range (caller contract), so `idx < self.data.len()`.
            unsafe { *self.data.get_unchecked_mut(idx) &= !(1 << bit) };
            self.len = self.len.saturating_sub(1);
        }
    }

    /// Keeps only the elements for which `f` returns `true`.
    ///
    /// `f` is called exactly once per element, in ascending order. It is never
    /// called for values that are not in the set.
    pub fn retain(&mut self, mut f: impl FnMut(T) -> bool) {
        for raw in LOWER..=UPPER {
            let x: T = NumCast::from(raw).unwrap();
            // SAFETY: `raw` ranges over `LOWER..=UPPER`, so `x` is in range.
            let present = unsafe { self.contains_unsafe(x) };
            if present && !f(x) {
                // SAFETY: as above.
                unsafe { self.remove_unsafe(x) };
            }
        }
    }

    /// Removes `v` and returns it if it was present.
    ///
    /// Unlike `remove`, an out-of-range `v` is not an error: it simply yields `None`.
    pub fn take(&mut self, v: &T) -> Option<T> {
        let v = *v;
        let in_range = v >= self.lower_cast && v <= self.upper_cast;
        if !in_range {
            return None;
        }
        // SAFETY: the range check above guarantees `LOWER <= v <= UPPER`.
        unsafe {
            if self.contains_unsafe(v) {
                self.remove_unsafe(v);
                Some(v)
            } else {
                None
            }
        }
    }

    /// Returns an iterator over the members in ascending order.
    pub fn iter(&self) -> Iter<'_, LOWER, UPPER, T> {
        Iter::new(self)
    }

    /// Returns an iterator that removes each member as it yields it, in ascending order.
    ///
    /// NOTE(review): members not yet reached when the iterator is dropped stay in the
    /// set unless `Drain` clears them on drop. No such `Drop` impl is visible here.
    pub fn drain(&mut self) -> Drain<'_, LOWER, UPPER, T> {
        Drain::new(self)
    }

    /// Lazily yields the values in `self` that are not in `rhs`, in ascending order.
    pub fn difference<'a>(&'a self, rhs: &'a Self) -> Difference<'a, LOWER, UPPER, T> {
        Difference::new(self, rhs)
    }

    /// Lazily yields the values in exactly one of `self` and `rhs`, in ascending order.
    pub fn symmetric_difference<'a>(
        &'a self,
        rhs: &'a Self,
    ) -> SymmetricDifference<'a, LOWER, UPPER, T> {
        SymmetricDifference::new(self, rhs)
    }

    /// Lazily yields the values in both `self` and `rhs`, in ascending order.
    pub fn intersection<'a>(&'a self, rhs: &'a Self) -> Intersection<'a, LOWER, UPPER, T> {
        Intersection::new(self, rhs)
    }

    /// Lazily yields the values in `self`, `rhs`, or both, in ascending order.
    pub fn union<'a>(&'a self, rhs: &'a Self) -> Union<'a, LOWER, UPPER, T> {
        Union::new(self, rhs)
    }

    /// Returns `true` if `self` and `other` share no elements.
    pub fn is_disjoint(&self, other: &Self) -> bool {
        // Stop at the first common element instead of counting all of them.
        self.intersection(other).next().is_none()
    }

    /// Returns `true` if every element of `self` is also in `other`.
    pub fn is_subset(&self, other: &Self) -> bool {
        // A larger set can never be a subset; this check is cheap, so do it first.
        self.len() <= other.len()
            // SAFETY: every value yielded by `iter` lies in `LOWER..=UPPER`.
            && self.iter().all(|v| unsafe { other.contains_unsafe(v) })
    }

    /// Returns `true` if every element of `other` is also in `self`.
    pub fn is_superset(&self, other: &Self) -> bool {
        other.is_subset(self)
    }
}
impl<const LOWER: usize, const UPPER: usize, T: Integer> BitAnd<&BitSet<LOWER, UPPER, T>>
for &BitSet<LOWER, UPPER, T>
{
type Output = BitSet<LOWER, UPPER, T>;
fn bitand(self, rhs: &BitSet<LOWER, UPPER, T>) -> Self::Output {
self.intersection(rhs).collect()
}
}
impl<const LOWER: usize, const UPPER: usize, T: Integer> BitOr<&BitSet<LOWER, UPPER, T>>
for &BitSet<LOWER, UPPER, T>
{
type Output = BitSet<LOWER, UPPER, T>;
fn bitor(self, rhs: &BitSet<LOWER, UPPER, T>) -> Self::Output {
self.union(rhs).collect()
}
}
impl<const LOWER: usize, const UPPER: usize, T: Integer> BitXor<&BitSet<LOWER, UPPER, T>>
for &BitSet<LOWER, UPPER, T>
{
type Output = BitSet<LOWER, UPPER, T>;
fn bitxor(self, rhs: &BitSet<LOWER, UPPER, T>) -> Self::Output {
self.symmetric_difference(rhs).collect()
}
}
impl<const LOWER: usize, const UPPER: usize, T: Integer> Extend<T> for BitSet<LOWER, UPPER, T> {
    /// Inserts every value yielded by `iter`, in order.
    ///
    /// # Panics
    ///
    /// Panics at the first value outside `LOWER..=UPPER`; earlier values stay inserted.
    fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {
        iter.into_iter().for_each(|value| self.insert(value));
    }
}
impl<'a, const LOWER: usize, const UPPER: usize, T: Integer> Extend<&'a T>
    for BitSet<LOWER, UPPER, T>
{
    /// Inserts a copy of every value referenced by `iter`, in order.
    ///
    /// # Panics
    ///
    /// Panics at the first value outside `LOWER..=UPPER`; earlier values stay inserted.
    fn extend<I: IntoIterator<Item = &'a T>>(&mut self, iter: I) {
        iter.into_iter().for_each(|&value| self.insert(value));
    }
}
impl<const LOWER: usize, const UPPER: usize, const M: usize, T: Integer> From<[T; M]>
    for BitSet<LOWER, UPPER, T>
{
    /// Builds a set from the elements of an array; duplicates collapse.
    ///
    /// # Panics
    ///
    /// Panics if any element lies outside `LOWER..=UPPER`.
    fn from(value: [T; M]) -> Self {
        let mut set = Self::new();
        for element in value {
            set.insert(element);
        }
        set
    }
}
impl<const LOWER: usize, const UPPER: usize, T: Integer> FromIterator<T>
    for BitSet<LOWER, UPPER, T>
{
    /// Collects values into a new set; duplicates collapse.
    ///
    /// # Panics
    ///
    /// Panics if any value lies outside `LOWER..=UPPER`.
    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
        // Build the empty set before touching the iterator, as `new` may panic.
        let empty = Self::new();
        iter.into_iter().fold(empty, |mut set, value| {
            set.insert(value);
            set
        })
    }
}
impl<'a, const LOWER: usize, const UPPER: usize, T: Integer> FromIterator<&'a T>
    for BitSet<LOWER, UPPER, T>
{
    /// Collects copies of the referenced values into a new set; duplicates collapse.
    ///
    /// # Panics
    ///
    /// Panics if any value lies outside `LOWER..=UPPER`.
    fn from_iter<I: IntoIterator<Item = &'a T>>(iter: I) -> Self {
        let mut set = Self::new();
        for &value in iter {
            set.insert(value);
        }
        set
    }
}
impl<const LOWER: usize, const UPPER: usize, T: Integer> IntoIterator for BitSet<LOWER, UPPER, T> {
    type Item = T;
    type IntoIter = IntoIter<LOWER, UPPER, T>;

    /// Consumes the set, yielding its members in ascending order.
    fn into_iter(self) -> Self::IntoIter {
        IntoIter::<LOWER, UPPER, T>::new(self)
    }
}
impl<'a, const LOWER: usize, const UPPER: usize, T: Integer> IntoIterator
    for &'a BitSet<LOWER, UPPER, T>
{
    type Item = T;
    type IntoIter = Iter<'a, LOWER, UPPER, T>;

    /// Borrows the set, yielding copies of its members in ascending order.
    fn into_iter(self) -> Self::IntoIter {
        self.iter()
    }
}
impl<const LOWER: usize, const UPPER: usize, T: Integer + PartialEq> PartialEq
    for BitSet<LOWER, UPPER, T>
{
    /// Two sets are equal when they hold the same members.
    ///
    /// The cached lengths are compared first as a cheap early exit. The members are
    /// then compared pairwise in ascending order.
    fn eq(&self, other: &Self) -> bool {
        if self.len != other.len {
            return false;
        }
        self.iter()
            .zip(other.iter())
            .all(|(mine, theirs)| mine == theirs)
    }
}
// `Eq` belongs on `BitSet` itself, mirroring its `PartialEq` impl. The standard library's
// blanket `impl<A: Eq + ?Sized> Eq for &A` then makes `&BitSet` `Eq` as well.
impl<const LOWER: usize, const UPPER: usize, T: Integer + Eq> Eq for BitSet<LOWER, UPPER, T> {}
impl<const LOWER: usize, const UPPER: usize, T: Integer> Sub for &BitSet<LOWER, UPPER, T> {
type Output = BitSet<LOWER, UPPER, T>;
fn sub(self, rhs: &BitSet<LOWER, UPPER, T>) -> Self::Output {
self.difference(rhs).collect()
}
}
// `Send` and `Sync` are left to the compiler's auto-trait rules. Every field
// (`Vec<BackingType>`, `usize`, and two `T`s) is handled structurally, so a `BitSet`
// is `Send`/`Sync` exactly when `T` is. A blanket `unsafe impl` over every
// `T: Integer` would also assert thread safety for a non-`Send`/non-`Sync` `T`,
// which is unsound.
impl<const LOWER: usize, const UPPER: usize, T: Integer> RefUnwindSafe for BitSet<LOWER, UPPER, T> {}
impl<const LOWER: usize, const UPPER: usize, T: Integer> UnwindSafe for BitSet<LOWER, UPPER, T> {}
/// Lazy iterator over `lhs \ rhs`: values in `lhs` that are not in `rhs`, in ascending order.
#[derive(Debug, Clone, Copy)]
pub struct Difference<'a, const LOWER: usize, const UPPER: usize, T: Integer> {
// Next candidate value; starts at `LOWER` and scans upward through `UPPER` once.
index: usize,
// Left operand.
lhs: &'a BitSet<LOWER, UPPER, T>,
// Right operand.
rhs: &'a BitSet<LOWER, UPPER, T>,
}
impl<'a, const LOWER: usize, const UPPER: usize, T: Integer> Difference<'a, LOWER, UPPER, T> {
    /// Starts a scan of `lhs \ rhs` at the smallest representable value.
    fn new(lhs: &'a BitSet<LOWER, UPPER, T>, rhs: &'a BitSet<LOWER, UPPER, T>) -> Self {
        let index = LOWER;
        Self { lhs, rhs, index }
    }
}
/// Advances `index` through `LOWER..=UPPER` and returns the first value `v` for which
/// `op(lhs contains v, rhs contains v)` holds, leaving `index` just past `v`.
///
/// Returns `None` once `index` has moved beyond `UPPER`. The set-operation iterators
/// share this scan and differ only in `op`.
fn generic_next_binop<
    const LOWER: usize,
    const UPPER: usize,
    T: Integer,
    F: Fn(bool, bool) -> bool,
>(
    index: &mut usize,
    lhs: &BitSet<LOWER, UPPER, T>,
    rhs: &BitSet<LOWER, UPPER, T>,
    op: F,
) -> Option<T> {
    let start = *index;
    for candidate in start..=UPPER {
        let value: T = NumCast::from(candidate).unwrap();
        *index = candidate + 1;
        // SAFETY: `candidate` lies in `LOWER..=UPPER` (the cursor starts at `LOWER`).
        let in_lhs = unsafe { lhs.contains_unsafe(value) };
        let in_rhs = unsafe { rhs.contains_unsafe(value) };
        if op(in_lhs, in_rhs) {
            return Some(value);
        }
    }
    None
}
/// Conservative `size_hint` shared by the set-operation iterators.
///
/// Only the operands are known here, not the operation being applied or how far the
/// scan has progressed. The only safe lower bound is therefore 0: e.g. the difference
/// of a set with itself, or the intersection of disjoint sets, is empty even when both
/// operands are non-empty. The upper bound is the number of values in `LOWER..=UPPER`,
/// i.e. `UPPER - LOWER + 1`, or `None` if that count does not fit in `usize`.
fn generic_size_hint_binop<const LOWER: usize, const UPPER: usize, T: Integer>(
    _lhs: &BitSet<LOWER, UPPER, T>,
    _rhs: &BitSet<LOWER, UPPER, T>,
) -> (usize, Option<usize>) {
    (0, (UPPER - LOWER).checked_add(1))
}
impl<const LOWER: usize, const UPPER: usize, T: Integer> Iterator
    for Difference<'_, LOWER, UPPER, T>
{
    type Item = T;

    /// Yields the next value that is in `lhs` but not in `rhs`.
    fn next(&mut self) -> Option<Self::Item> {
        let only_in_lhs = |in_lhs: bool, in_rhs: bool| in_lhs && !in_rhs;
        generic_next_binop(&mut self.index, self.lhs, self.rhs, only_in_lhs)
    }

    /// Delegates to the bound shared by all set-operation iterators.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let (left, right) = (self.lhs, self.rhs);
        generic_size_hint_binop(left, right)
    }
}
/// Lazy iterator over the values in exactly one of `lhs` and `rhs`, in ascending order.
#[derive(Debug, Clone, Copy)]
pub struct SymmetricDifference<'a, const LOWER: usize, const UPPER: usize, T: Integer> {
// Next candidate value; starts at `LOWER` and scans upward through `UPPER` once.
index: usize,
// Left operand.
lhs: &'a BitSet<LOWER, UPPER, T>,
// Right operand.
rhs: &'a BitSet<LOWER, UPPER, T>,
}
impl<'a, const LOWER: usize, const UPPER: usize, T: Integer>
    SymmetricDifference<'a, LOWER, UPPER, T>
{
    /// Starts a scan of `lhs ^ rhs` at the smallest representable value.
    fn new(lhs: &'a BitSet<LOWER, UPPER, T>, rhs: &'a BitSet<LOWER, UPPER, T>) -> Self {
        let index = LOWER;
        Self { lhs, rhs, index }
    }
}
impl<const LOWER: usize, const UPPER: usize, T: Integer> Iterator
    for SymmetricDifference<'_, LOWER, UPPER, T>
{
    type Item = T;

    /// Yields the next value present in exactly one of the operands.
    fn next(&mut self) -> Option<Self::Item> {
        let in_exactly_one = |in_lhs: bool, in_rhs: bool| in_lhs != in_rhs;
        generic_next_binop(&mut self.index, self.lhs, self.rhs, in_exactly_one)
    }

    /// Delegates to the bound shared by all set-operation iterators.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let (left, right) = (self.lhs, self.rhs);
        generic_size_hint_binop(left, right)
    }
}
/// Lazy iterator over the values in both `lhs` and `rhs`, in ascending order.
#[derive(Debug, Clone, Copy)]
pub struct Intersection<'a, const LOWER: usize, const UPPER: usize, T: Integer> {
// Next candidate value; starts at `LOWER` and scans upward through `UPPER` once.
index: usize,
// Left operand.
lhs: &'a BitSet<LOWER, UPPER, T>,
// Right operand.
rhs: &'a BitSet<LOWER, UPPER, T>,
}
impl<'a, const LOWER: usize, const UPPER: usize, T: Integer> Intersection<'a, LOWER, UPPER, T> {
    /// Starts a scan of `lhs & rhs` at the smallest representable value.
    fn new(lhs: &'a BitSet<LOWER, UPPER, T>, rhs: &'a BitSet<LOWER, UPPER, T>) -> Self {
        let index = LOWER;
        Self { lhs, rhs, index }
    }
}
impl<const LOWER: usize, const UPPER: usize, T: Integer> Iterator
    for Intersection<'_, LOWER, UPPER, T>
{
    type Item = T;

    /// Yields the next value present in both operands.
    fn next(&mut self) -> Option<Self::Item> {
        let in_both = |in_lhs: bool, in_rhs: bool| in_lhs & in_rhs;
        generic_next_binop(&mut self.index, self.lhs, self.rhs, in_both)
    }

    /// Delegates to the bound shared by all set-operation iterators.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let (left, right) = (self.lhs, self.rhs);
        generic_size_hint_binop(left, right)
    }
}
/// Lazy iterator over the values in `lhs`, `rhs`, or both, in ascending order.
#[derive(Debug, Clone, Copy)]
pub struct Union<'a, const LOWER: usize, const UPPER: usize, T: Integer> {
// Next candidate value; starts at `LOWER` and scans upward through `UPPER` once.
index: usize,
// Left operand.
lhs: &'a BitSet<LOWER, UPPER, T>,
// Right operand.
rhs: &'a BitSet<LOWER, UPPER, T>,
}
impl<'a, const LOWER: usize, const UPPER: usize, T: Integer> Union<'a, LOWER, UPPER, T> {
    /// Starts a scan of `lhs | rhs` at the smallest representable value.
    fn new(lhs: &'a BitSet<LOWER, UPPER, T>, rhs: &'a BitSet<LOWER, UPPER, T>) -> Self {
        let index = LOWER;
        Self { lhs, rhs, index }
    }
}
impl<const LOWER: usize, const UPPER: usize, T: Integer> Iterator for Union<'_, LOWER, UPPER, T> {
    type Item = T;

    /// Yields the next value present in at least one operand.
    fn next(&mut self) -> Option<Self::Item> {
        let in_either = |in_lhs: bool, in_rhs: bool| in_lhs | in_rhs;
        generic_next_binop(&mut self.index, self.lhs, self.rhs, in_either)
    }

    /// Delegates to the bound shared by all set-operation iterators.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let (left, right) = (self.lhs, self.rhs);
        generic_size_hint_binop(left, right)
    }
}
/// Borrowing iterator over the members of a `BitSet`, in ascending order.
#[derive(Debug, Clone, Copy)]
pub struct Iter<'a, const LOWER: usize, const UPPER: usize, T: Integer> {
// Next value to examine; starts at `LOWER` and ends just past `UPPER`.
index: usize,
// The set being iterated.
collection: &'a BitSet<LOWER, UPPER, T>,
}
impl<'a, const LOWER: usize, const UPPER: usize, T: Integer> Iter<'a, LOWER, UPPER, T> {
    /// Places the cursor at `LOWER`, the smallest value `collection` can hold.
    fn new(collection: &'a BitSet<LOWER, UPPER, T>) -> Self {
        let index = LOWER;
        Self { collection, index }
    }
}
impl<const LOWER: usize, const UPPER: usize, T: Integer> Iterator for Iter<'_, LOWER, UPPER, T> {
type Item = T;
fn next(&mut self) -> Option<Self::Item> {
while self.index <= UPPER {
let v: T = NumCast::from(self.index).unwrap();
self.index += 1;
unsafe {
if self.collection.contains_unsafe(v) {
return Some(v);
}
}
}
None
}
fn size_hint(&self) -> (usize, Option<usize>) {
(self.collection.len, Some(UPPER - LOWER))
}
}
/// Iterator that removes and yields the members of a `BitSet`, in ascending order.
///
/// NOTE(review): no `Drop` impl is visible here, so dropping a `Drain` before it is
/// exhausted leaves the unvisited members in the set. `std`'s `HashSet::drain` clears them.
#[derive(Debug)]
pub struct Drain<'a, const LOWER: usize, const UPPER: usize, T: Integer> {
// Next value to examine; starts at `LOWER` and ends just past `UPPER`.
index: usize,
// The set being drained; members are removed from it as they are yielded.
collection: &'a mut BitSet<LOWER, UPPER, T>,
}
impl<'a, const LOWER: usize, const UPPER: usize, T: Integer> Drain<'a, LOWER, UPPER, T> {
    /// Places the cursor at `LOWER`, the smallest value `collection` can hold.
    fn new(collection: &'a mut BitSet<LOWER, UPPER, T>) -> Self {
        let index = LOWER;
        Self { collection, index }
    }
}
impl<const LOWER: usize, const UPPER: usize, T: Integer> Iterator for Drain<'_, LOWER, UPPER, T> {
    type Item = T;

    /// Removes and returns the next member at or above the cursor.
    fn next(&mut self) -> Option<Self::Item> {
        while self.index <= UPPER {
            let v: T = NumCast::from(self.index).unwrap();
            self.index += 1;
            // SAFETY: `LOWER <= v <= UPPER`, so `v` maps to a word inside the bitmap.
            unsafe {
                if self.collection.contains_unsafe(v) {
                    self.collection.remove_unsafe(v);
                    return Some(v);
                }
            }
        }
        None
    }

    /// Returns the exact number of members not yet yielded.
    ///
    /// Yielded members are removed immediately and the cursor only moves upward, so
    /// every bit still set in the bitmap is a member that has not been yielded yet.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self
            .collection
            .data
            .iter()
            .map(|w| w.count_ones() as usize)
            .sum::<usize>();
        (remaining, Some(remaining))
    }
}
/// Owning iterator over the members of a `BitSet`, in ascending order.
#[derive(Debug)]
pub struct IntoIter<const LOWER: usize, const UPPER: usize, T: Integer> {
// Next value to examine; starts at `LOWER` and ends just past `UPPER`.
index: usize,
// The consumed set; members are removed from it as they are yielded.
collection: BitSet<LOWER, UPPER, T>,
}
impl<const LOWER: usize, const UPPER: usize, T: Integer> IntoIter<LOWER, UPPER, T> {
    /// Takes ownership of `collection` and places the cursor at `LOWER`.
    fn new(collection: BitSet<LOWER, UPPER, T>) -> Self {
        let index = LOWER;
        Self { collection, index }
    }
}
impl<const LOWER: usize, const UPPER: usize, T: Integer> Iterator for IntoIter<LOWER, UPPER, T> {
    type Item = T;

    /// Removes and returns the next member at or above the cursor.
    fn next(&mut self) -> Option<Self::Item> {
        while self.index <= UPPER {
            let v: T = NumCast::from(self.index).unwrap();
            self.index += 1;
            // SAFETY: `LOWER <= v <= UPPER`, so `v` maps to a word inside the bitmap.
            unsafe {
                if self.collection.contains_unsafe(v) {
                    self.collection.remove_unsafe(v);
                    return Some(v);
                }
            }
        }
        None
    }

    /// Returns the exact number of members not yet yielded.
    ///
    /// Yielded members are removed immediately and the cursor only moves upward, so
    /// every bit still set in the owned bitmap is a member that has not been yielded yet.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self
            .collection
            .data
            .iter()
            .map(|w| w.count_ones() as usize)
            .sum::<usize>();
        (remaining, Some(remaining))
    }
}