#![no_std]
#![cfg_attr(docsrs, feature(doc_cfg))]
use core::{
fmt,
iter::{Chain, FusedIterator},
mem,
ops::{Bound, Not, RangeBounds},
};
use num_traits::{Bounded, PrimInt};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize, de::Visitor, ser::SerializeSeq};
/// A set of the integers `0..8`, stored in a single `u8`.
pub type BitSet8 = BitSet<u8, 1>;
/// A set of the integers `0..16`, stored in a single `u16`.
pub type BitSet16 = BitSet<u16, 1>;
/// A set of the integers `0..32`, stored in a single `u32`.
pub type BitSet32 = BitSet<u32, 1>;
/// A set of the integers `0..64`, stored in a single `u64`.
pub type BitSet64 = BitSet<u64, 1>;
/// A set of the integers `0..128`, stored in two `u64` words.
pub type BitSet128 = BitSet<u64, 2>;
/// A set of the integers `0..256`, stored in four `u64` words.
pub type BitSet256 = BitSet<u64, 4>;
/// A set of the integers `0..512`, stored in eight `u64` words.
pub type BitSet512 = BitSet<u64, 8>;
/// A set of the integers `0..1024`, stored in sixteen `u64` words.
pub type BitSet1024 = BitSet<u64, 16>;
/// A fixed-capacity set of small non-negative integers, stored as an array of
/// integer words.
///
/// Element `i` is bit `i % W` of word `i / W`, where `W` is the bit width of
/// `T`. The set can therefore hold the integers `0..N * W`. `T` is expected to
/// be a primitive integer type; almost every method requires `T: PrimInt`.
// `repr(transparent)` gives `BitSet<T, N>` exactly the layout of `[T; N]`;
// `BitSet::from_ref` transmutes between the two and depends on it.
#[repr(transparent)]
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct BitSet<T, const N: usize> {
    // Backing words, least significant element in word 0.
    inner: [T; N],
}
impl<T: PrimInt + Default, const N: usize> Default for BitSet<T, N> {
    /// Returns a set whose every backing word is `T::default()`; for the
    /// primitive integer types that is zero, i.e. the empty set.
    fn default() -> Self {
        let words = [T::default(); N];
        Self { inner: words }
    }
}
impl<T: PrimInt, const N: usize> From<[T; N]> for BitSet<T, N> {
fn from(inner: [T; N]) -> Self {
Self { inner }
}
}
impl<T, const N: usize> fmt::Debug for BitSet<T, N>
where T: PrimInt
{
    /// Formats the set as `{a, b, ...}`, listing the set bit indices in
    /// ascending order.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut elements = f.debug_set();
        for bit in self {
            elements.entry(&bit);
        }
        elements.finish()
    }
}
impl<T, const N: usize> fmt::Binary for BitSet<T, N>
where T: Copy + fmt::Binary
{
    /// Formats the raw backing words as `BitSet [0b..., 0b...]`, each word
    /// prefixed with `0b` and zero-padded to the full bit width of `T`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("BitSet ")?;
        // Two extra columns account for the `0b` prefix emitted by `#`.
        let padded_width = Self::item_size() + 2;
        let mut words = f.debug_list();
        self.inner.iter().for_each(|word| {
            words.entry(&format_args!("{:#0width$b}", word, width = padded_width));
        });
        words.finish()
    }
}
/// Generates a `const fn new()` constructor on `BitSet<$t, N>` for each listed
/// primitive integer type.
///
/// The types are matched as `ident` fragments. Primitive integer names are
/// identifiers, and an `ident` may be followed directly by another fragment.
/// A `ty` fragment may not, yet the separator-less `$(...)+` repetition
/// requires each element to be able to follow itself.
macro_rules! impl_new {
    ($($t:ident)+) => {
        $(
            impl<const N: usize> BitSet<$t, N> {
                /// Creates an empty set (every backing word is zero).
                ///
                /// Unlike `Default::default`, this is a `const fn`, so it can
                /// initialise `const` and `static` items.
                pub const fn new() -> Self {
                    Self { inner: [0; N] }
                }
            }
        )+
    };
}
impl_new!(i8 i16 i32 i64 i128 isize);
impl_new!(u8 u16 u32 u64 u128 usize);
impl<T: PrimInt + Default, const N: usize> BitSet<T, N> {
    /// Creates an empty set through `Default`; unlike `new`, this works for
    /// any `T: PrimInt + Default`, but it is not a `const fn`.
    pub fn with_default() -> Self {
        <Self as Default>::default()
    }
    /// Removes every element by resetting each backing word to `T::default()`.
    pub fn clear(&mut self) {
        self.inner.fill(T::default());
    }
}
impl<T, const N: usize> BitSet<T, N> {
    /// Consumes the set and returns its backing words.
    pub fn into_inner(self) -> [T; N] {
        self.inner
    }
    /// Number of distinct elements the set can hold: `N` words of
    /// `size_of::<T>() * 8` bits each.
    pub const fn capacity() -> usize {
        N * Self::item_size()
    }
    /// Bit width of one backing word `T`.
    const fn item_size() -> usize {
        mem::size_of::<T>() * 8
    }
    /// Reinterprets a mutable reference to raw words as a mutable reference to
    /// a `BitSet`, so the words can be edited in place through the set API.
    pub const fn from_ref(inner: &mut [T; N]) -> &mut Self {
        // NOTE(review): this only checks that `T` is at most 128 *bytes*; it
        // does not verify that `T` is an integer type, as the message suggests.
        debug_assert!(
            size_of::<T>() <= 128,
            "`T` should be one of type `{{integer}}`"
        );
        // SAFETY: `BitSet<T, N>` is `#[repr(transparent)]` over `[T; N]`, so
        // both types have the same layout. The returned reference keeps the
        // lifetime and exclusivity of `inner`.
        unsafe { mem::transmute(inner) }
    }
}
impl<T: PrimInt, const N: usize> BitSet<T, N> {
    /// Splits `bit` into the index of the backing word that holds it and a
    /// mask with only that bit set inside the word.
    ///
    /// `bit & (W - 1)` equals `bit % W` because the width `W` of every
    /// primitive integer word is a power of two.
    fn location(bit: usize) -> (usize, T) {
        let index = bit / Self::item_size();
        let bitmask = T::one() << (bit & (Self::item_size() - 1));
        (index, bitmask)
    }

    /// Moves the elements of `other` into `self` in ascending order, stopping
    /// at the first element that does not fit.
    ///
    /// Each element is removed from `other` only after it has been inserted
    /// into `self`. On failure, the offending element and every larger one
    /// therefore remain in `other`, and nothing is lost.
    ///
    /// # Errors
    ///
    /// Returns [`BitSetError::BiggerThanCapacity`] for the first element of
    /// `other` that is `>= Self::capacity()`.
    pub fn try_append<U, const M: usize>(
        &mut self, other: &mut BitSet<U, M>,
    ) -> Result<(), BitSetError>
    where U: PrimInt {
        // Walk a copy so `other` can be modified while iterating.
        let snapshot = *other;
        for item in snapshot.iter() {
            self.try_insert(item)?;
            other.remove(item);
        }
        Ok(())
    }

    /// Inserts `bit`. Returns `Ok(true)` if it was not already present and
    /// `Ok(false)` if it was.
    ///
    /// # Errors
    ///
    /// Returns [`BitSetError::BiggerThanCapacity`] if `bit >= Self::capacity()`.
    #[inline]
    pub fn try_insert(&mut self, bit: usize) -> Result<bool, BitSetError> {
        if bit >= Self::capacity() {
            return Err(BitSetError::BiggerThanCapacity);
        }
        let (index, bitmask) = Self::location(bit);
        Ok(match self.inner.get_mut(index) {
            Some(v) => {
                let contains = *v & bitmask == bitmask;
                *v = (*v) | bitmask;
                !contains
            },
            // Unreachable after the capacity check above.
            None => false,
        })
    }

    /// Inserts `bit` without checking it against the capacity. Returns `true`
    /// if the bit was not already present.
    ///
    /// # Safety
    ///
    /// `bit` must be smaller than `Self::capacity()`. A larger value makes the
    /// word index out of bounds, which is undefined behavior.
    pub unsafe fn insert_unchecked(&mut self, bit: usize) -> bool {
        let (index, bitmask) = Self::location(bit);
        // SAFETY: the caller guarantees `bit < capacity`, hence `index < N`.
        let v = unsafe { self.inner.get_unchecked_mut(index) };
        let contains = *v & bitmask == bitmask;
        *v = (*v) | bitmask;
        !contains
    }

    /// Removes `bit`. Returns `Ok(true)` if it was present and `Ok(false)` if
    /// it was not.
    ///
    /// # Errors
    ///
    /// Returns [`BitSetError::BiggerThanCapacity`] if `bit >= Self::capacity()`.
    pub fn try_remove(&mut self, bit: usize) -> Result<bool, BitSetError> {
        if bit >= Self::capacity() {
            return Err(BitSetError::BiggerThanCapacity);
        }
        let (index, bitmask) = Self::location(bit);
        Ok(match self.inner.get_mut(index) {
            Some(v) => {
                let was_present = *v & bitmask == bitmask;
                *v = (*v) & !bitmask;
                was_present
            },
            // Unreachable after the capacity check above.
            None => false,
        })
    }

    /// Removes `bit` without checking it against the capacity. Returns `true`
    /// if the bit was present.
    ///
    /// # Safety
    ///
    /// `bit` must be smaller than `Self::capacity()`. A larger value makes the
    /// word index out of bounds, which is undefined behavior.
    pub unsafe fn remove_unchecked(&mut self, bit: usize) -> bool {
        let (index, bitmask) = Self::location(bit);
        // SAFETY: the caller guarantees `bit < capacity`, hence `index < N`.
        let v = unsafe { self.inner.get_unchecked_mut(index) };
        let was_present = *v & bitmask == bitmask;
        *v = (*v) & !bitmask;
        was_present
    }

    /// Moves every element of `other` into `self`, draining `other`.
    ///
    /// # Panics
    ///
    /// Panics if an element of `other` is `>= Self::capacity()`.
    pub fn append<U, const M: usize>(&mut self, other: &mut BitSet<U, M>)
    where U: PrimInt {
        for item in other.drain() {
            self.insert(item);
        }
    }

    /// Inserts `bit`, returning `true` if it was not already present.
    ///
    /// # Panics
    ///
    /// Panics if `bit >= Self::capacity()`.
    #[inline]
    pub fn insert(&mut self, bit: usize) -> bool {
        self.try_insert(bit)
            .expect("BitSet::insert called on an integer bigger than capacity")
    }

    /// Removes `bit`, returning `true` if it was present.
    ///
    /// # Panics
    ///
    /// Panics if `bit >= Self::capacity()`.
    pub fn remove(&mut self, bit: usize) -> bool {
        self.try_remove(bit)
            .expect("BitSet::remove called on an integer bigger than capacity")
    }

    /// Keeps only the elements for which `f` returns `true`. Elements are
    /// visited in ascending order.
    pub fn retain<F>(&mut self, mut f: F)
    where F: FnMut(usize) -> bool {
        // Walk a copy so `self` can be modified while iterating.
        let snapshot = *self;
        for value in snapshot.iter() {
            if !f(value) {
                self.remove(value);
            }
        }
    }

    /// Returns `true` if `bit` is in the set. Out-of-capacity values are simply
    /// reported as absent.
    pub fn contains(&self, bit: usize) -> bool {
        if bit >= Self::capacity() {
            return false;
        }
        let (index, bitmask) = Self::location(bit);
        match self.inner.get(index) {
            Some(&v) => v & bitmask == bitmask,
            None => false,
        }
    }

    /// Returns whether `bit` is in the set.
    ///
    /// # Errors
    ///
    /// Returns [`BitSetError::BiggerThanCapacity`] if `bit >= Self::capacity()`.
    pub fn try_contains(&self, bit: usize) -> Result<bool, BitSetError> {
        if bit >= Self::capacity() {
            return Err(BitSetError::BiggerThanCapacity);
        }
        let (index, bitmask) = Self::location(bit);
        match self.inner.get(index) {
            Some(&v) => Ok(v & bitmask == bitmask),
            None => Err(BitSetError::BiggerThanCapacity),
        }
    }

    /// Number of elements in the set (the population count of all words).
    #[inline]
    pub fn len(&self) -> usize {
        self.count_ones() as usize
    }

    /// Returns `true` if the set has no elements.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns `true` if the two sets have no element in common.
    ///
    /// The smaller set is iterated and probed against the larger one.
    pub fn is_disjoint<U: PrimInt, const M: usize>(&self, other: &BitSet<U, M>) -> bool {
        if self.len() <= other.len() {
            self.iter().all(|v| !other.contains(v))
        } else {
            other.iter().all(|v| !self.contains(v))
        }
    }

    /// Returns `true` if every element of `self` is also in `other`.
    pub fn is_subset<U: PrimInt, const M: usize>(&self, other: &BitSet<U, M>) -> bool {
        // A strictly larger set can never be a subset.
        if self.len() <= other.len() {
            self.iter().all(|v| other.contains(v))
        } else {
            false
        }
    }

    /// Returns `true` if every element of `other` is also in `self`.
    #[inline]
    pub fn is_superset<U: PrimInt, const M: usize>(&self, other: &BitSet<U, M>) -> bool {
        other.is_subset(self)
    }

    /// Number of set bits across all backing words.
    pub fn count_ones(&self) -> u32 {
        let mut total = 0;
        for item in self.inner.iter() {
            total += item.count_ones();
        }
        total
    }

    /// Number of cleared bits across all backing words, i.e.
    /// `capacity() - len()`.
    pub fn count_zeros(&self) -> u32 {
        let mut total = 0;
        for item in self.inner.iter() {
            total += item.count_zeros();
        }
        total
    }

    /// Returns an iterator that removes and yields the elements in ascending
    /// order. Elements not yet yielded when the iterator is dropped stay in
    /// the set.
    pub fn drain(&mut self) -> Drain<'_, T, N> {
        Drain { inner: self }
    }

    /// Lazily yields the elements of `self` that are not in `other`, in
    /// ascending order.
    pub fn difference<'a, U: PrimInt, const M: usize>(
        &'a self, other: &'a BitSet<U, M>,
    ) -> Difference<'a, T, U, N, M> {
        Difference {
            iter: self.iter(),
            other,
        }
    }

    /// Lazily yields the elements present in both sets, in ascending order.
    pub fn intersection<'a, U: PrimInt, const M: usize>(
        &'a self, other: &'a BitSet<U, M>,
    ) -> Intersection<'a, T, U, N, M> {
        Intersection {
            iter: self.iter(),
            other,
        }
    }

    /// Lazily yields the elements in exactly one of the two sets: first those
    /// only in `self`, then those only in `other`, each run ascending.
    pub fn symmetric_difference<'a, U: PrimInt, const M: usize>(
        &'a self, other: &'a BitSet<U, M>,
    ) -> SymmetricDifference<'a, T, U, N, M> {
        SymmetricDifference {
            iter: self.difference(other).chain(other.difference(self)),
        }
    }

    /// Lazily yields every element of either set exactly once.
    ///
    /// All elements of the set with more elements come first (ties favour
    /// `self`). They are followed by the elements only the other set has.
    /// Each run is ascending.
    pub fn union<'a, U: PrimInt, const M: usize>(
        &'a self, other: &'a BitSet<U, M>,
    ) -> Union<'a, T, U, N, M> {
        if self.len() >= other.len() {
            Union {
                iter: UnionChoose::SelfBiggerThanOther(self.iter().chain(other.difference(self))),
            }
        } else {
            Union {
                iter: UnionChoose::SelfSmallerThanOther(other.iter().chain(self.difference(other))),
            }
        }
    }

    /// Iterates over the elements in ascending order (double-ended).
    pub fn iter(&self) -> Iter<'_, T, N> {
        Iter::new(self)
    }
}
impl<T: Default + PrimInt, const N: usize> BitSet<T, N> {
    /// Sets (`on == true`) or clears (`on == false`) every bit in `range`.
    ///
    /// Bits at word-aligned positions that span a whole word are written one
    /// word at a time. The remaining bits at either edge are toggled
    /// individually.
    ///
    /// # Panics
    ///
    /// Panics if the range reaches outside the set:
    /// - an inclusive start greater than `Self::capacity()`;
    /// - an exclusive start at or beyond `Self::capacity()`;
    /// - an inclusive end at or beyond `Self::capacity()`;
    /// - an exclusive end greater than `Self::capacity()`.
    pub fn fill<R: RangeBounds<usize>>(&mut self, range: R, on: bool) {
        let start = match range.start_bound() {
            Bound::Unbounded => 0,
            Bound::Included(&i) => {
                assert!(i <= Self::capacity(), "start bound is too big for capacity");
                i
            },
            Bound::Excluded(&i) => {
                assert!(i < Self::capacity(), "start bound is too big for capacity");
                i + 1
            },
        };
        // `end` is exclusive. An inclusive bound `i` therefore becomes `i + 1`,
        // and `..=0` still covers bit 0.
        let end = match range.end_bound() {
            Bound::Unbounded => Self::capacity(),
            Bound::Included(&i) => {
                assert!(i < Self::capacity(), "end bound is too big for capacity");
                i + 1
            },
            Bound::Excluded(&i) => {
                assert!(i <= Self::capacity(), "end bound is too big for capacity");
                i
            },
        };

        let word_bits = Self::item_size();
        // `!0` is all ones for signed and unsigned `T` alike. `max_value()` is
        // not: for a signed type it leaves the sign bit clear.
        let whole_word = if on { !T::zero() } else { T::default() };

        let mut bit = start;
        while bit < end {
            if bit % word_bits == 0 && end - bit >= word_bits {
                self.inner[bit / word_bits] = whole_word;
                bit += word_bits;
            } else {
                if on {
                    self.insert(bit);
                } else {
                    self.remove(bit);
                }
                bit += 1;
            }
        }
    }
}
impl<T: PrimInt + Default, U: Into<usize>, const N: usize> FromIterator<U> for BitSet<T, N> {
    /// Builds a set containing every index yielded by `iter`.
    ///
    /// # Panics
    ///
    /// Panics if an index is `>= Self::capacity()`.
    fn from_iter<I>(iter: I) -> Self
    where I: IntoIterator<Item = U> {
        let mut collected = Self::with_default();
        collected.extend(iter);
        collected
    }
}
impl<T: PrimInt, U: Into<usize>, const N: usize> Extend<U> for BitSet<T, N> {
    /// Inserts every index yielded by `iter`.
    ///
    /// # Panics
    ///
    /// Panics if an index is `>= Self::capacity()`; indices inserted before
    /// it stay in the set.
    fn extend<I: IntoIterator<Item = U>>(&mut self, iter: I) {
        iter.into_iter().for_each(|index| {
            self.insert(index.into());
        });
    }
}
impl<T: PrimInt, const N: usize> IntoIterator for BitSet<T, N> {
    type Item = usize;
    type IntoIter = IntoIter<T, N>;
    /// Consumes the set, yielding its elements in ascending order.
    fn into_iter(self) -> IntoIter<T, N> {
        IntoIter(self)
    }
}
impl<'a, T: PrimInt, const N: usize> IntoIterator for &'a BitSet<T, N> {
    type Item = usize;
    type IntoIter = Iter<'a, T, N>;
    /// Iterates over the elements of the borrowed set in ascending order.
    fn into_iter(self) -> Iter<'a, T, N> {
        self.iter()
    }
}
impl<T: PrimInt, const N: usize> Not for BitSet<T, N> {
    type Output = Self;
    /// Returns the complement: every one of the `capacity()` bits is flipped.
    fn not(self) -> Self::Output {
        let mut complement = self;
        complement.inner.iter_mut().for_each(|word| *word = !*word);
        complement
    }
}
#[cfg(feature = "serde")]
impl<T: PrimInt, const N: usize> Serialize for BitSet<T, N> {
    /// Serializes the set as a sequence of its element indices, ascending.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where S: serde::Serializer {
        let mut elements = serializer.serialize_seq(Some(self.len()))?;
        for bit in self.iter() {
            elements.serialize_element(&bit)?;
        }
        elements.end()
    }
}
#[cfg(feature = "serde")]
impl<'de, T: PrimInt + Default, const N: usize> Deserialize<'de> for BitSet<T, N> {
    /// Deserializes a set from a sequence of element indices.
    ///
    /// An index that does not fit in the set is reported as an
    /// `invalid_value` deserialization error rather than a panic. Malformed
    /// or hostile input therefore cannot abort the program.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where D: serde::Deserializer<'de> {
        use core::marker::PhantomData;
        struct BitSetVisitor<T: PrimInt, const N: usize>(PhantomData<BitSet<T, N>>);
        impl<'de, T: PrimInt + Default, const N: usize> Visitor<'de> for BitSetVisitor<T, N> {
            type Value = BitSet<T, N>;
            fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
                f.write_str("a sequence")
            }
            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
            where A: serde::de::SeqAccess<'de> {
                use serde::de::{Error, Unexpected};
                let mut set = BitSet::<T, N>::with_default();
                while let Some(value) = seq.next_element::<usize>()? {
                    if set.try_insert(value).is_err() {
                        return Err(A::Error::invalid_value(
                            Unexpected::Unsigned(value as u64),
                            &"a bit index smaller than the set's capacity",
                        ));
                    }
                }
                Ok(set)
            }
        }
        let visitor = BitSetVisitor(PhantomData);
        deserializer.deserialize_seq(visitor)
    }
}
/// Error returned by the fallible `BitSet` operations (`try_insert`,
/// `try_remove`, `try_contains`, `try_append`).
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[non_exhaustive]
pub enum BitSetError {
    /// The requested bit index is not smaller than the set's capacity.
    BiggerThanCapacity,
}
impl fmt::Display for BitSetError {
    /// Writes a human-readable message, honouring width/fill/precision flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            BitSetError::BiggerThanCapacity => "tried to insert value bigger than capacity",
        };
        f.pad(message)
    }
}
/// Draining iterator created by `BitSet::drain`.
///
/// Each call to `next` removes the smallest remaining element from the
/// underlying set and yields it. There is no `Drop` impl, so elements not
/// yet yielded when the iterator is dropped stay in the set.
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct Drain<'a, T: PrimInt + 'a, const N: usize> {
    inner: &'a mut BitSet<T, N>,
}
impl<T: PrimInt, const N: usize> fmt::Debug for Drain<'_, T, N> {
    /// Shows the elements still left to drain, formatted like the set itself.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(&self.inner, f)
    }
}
impl<T: PrimInt, const N: usize> Iterator for Drain<'_, T, N> {
    type Item = usize;
    fn next(&mut self) -> Option<Self::Item> {
        for (index, item) in self.inner.inner.iter_mut().enumerate() {
            if !item.is_zero() {
                let bitindex = item.trailing_zeros() as usize;
                // Clear exactly the yielded bit. The shorter `x & (x - 1)`
                // idiom overflows (a panic in debug builds) for a signed word
                // whose only set bit is the sign bit, e.g. `i8::MIN`.
                *item = *item & !(T::one() << bitindex);
                return Some(index * BitSet::<T, N>::item_size() + bitindex);
            }
        }
        None
    }
    fn size_hint(&self) -> (usize, Option<usize>) {
        let len = self.inner.len();
        (len, Some(len))
    }
    /// Returns the number of remaining elements and removes them all.
    ///
    /// The default `count` would drain by calling `next` repeatedly. A
    /// shortcut that only reads `len()` would leave the elements in the set,
    /// so this override clears them as well.
    fn count(mut self) -> usize
    where Self: Sized {
        let drained = self.len();
        for word in self.inner.inner.iter_mut() {
            *word = T::zero();
        }
        drained
    }
}
impl<T: PrimInt, const N: usize> ExactSizeIterator for Drain<'_, T, N> {
    #[inline]
    fn len(&self) -> usize {
        self.inner.len()
    }
}
impl<T: PrimInt, const N: usize> FusedIterator for Drain<'_, T, N> {}
/// Owning iterator over the elements of a `BitSet`, created by
/// `BitSet::into_iter`.
///
/// `next` yields elements in ascending order; `next_back` yields them in
/// descending order.
#[derive(Clone)]
#[repr(transparent)]
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct IntoIter<T, const N: usize>(BitSet<T, N>);
impl<T: PrimInt, const N: usize> fmt::Debug for IntoIter<T, N> {
    /// Lists the elements not yet yielded, without consuming `self`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.debug_set().entries(self.clone()).finish()
    }
}
impl<T: PrimInt, const N: usize> Iterator for IntoIter<T, N> {
    type Item = usize;
    fn next(&mut self) -> Option<Self::Item> {
        for (index, item) in self.0.inner.iter_mut().enumerate() {
            if !item.is_zero() {
                let bitindex = item.trailing_zeros() as usize;
                // Clear exactly the yielded bit. The shorter `x & (x - 1)`
                // idiom overflows (a panic in debug builds) for a signed word
                // whose only set bit is the sign bit, e.g. `i8::MIN`.
                *item = *item & !(T::one() << bitindex);
                return Some(index * BitSet::<T, N>::item_size() + bitindex);
            }
        }
        None
    }
    fn size_hint(&self) -> (usize, Option<usize>) {
        let len = self.0.len();
        (len, Some(len))
    }
    fn count(self) -> usize
    where Self: Sized {
        self.len()
    }
}
impl<T: PrimInt, const N: usize> DoubleEndedIterator for IntoIter<T, N> {
    fn next_back(&mut self) -> Option<Self::Item> {
        for (index, item) in self.0.inner.iter_mut().enumerate().rev() {
            if !item.is_zero() {
                // Position of the highest set bit in this word.
                let bitindex = BitSet::<T, N>::item_size() - 1 - item.leading_zeros() as usize;
                *item = *item & !(T::one() << bitindex);
                return Some(index * BitSet::<T, N>::item_size() + bitindex);
            }
        }
        None
    }
}
impl<T: PrimInt, const N: usize> FusedIterator for IntoIter<T, N> {}
impl<T: PrimInt, const N: usize> ExactSizeIterator for IntoIter<T, N> {
    #[inline]
    fn len(&self) -> usize {
        self.0.len()
    }
}
#[derive(Clone)]
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct Iter<'a, T, const N: usize> {
borrow: &'a BitSet<T, N>,
bit: usize,
passed_count: usize,
}
impl<'a, T: PrimInt, const N: usize> Iter<'a, T, N> {
fn new(bitset: &'a BitSet<T, N>) -> Self {
Self {
borrow: bitset,
bit: 0,
passed_count: 0,
}
}
}
impl<T: PrimInt, const N: usize> fmt::Debug for Iter<'_, T, N> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_set().entries(self.clone()).finish()
}
}
impl<T: PrimInt, const N: usize> Iterator for Iter<'_, T, N> {
type Item = usize;
fn next(&mut self) -> Option<Self::Item> {
while !self.borrow.try_contains(self.bit).ok()? {
self.bit = self.bit.saturating_add(1);
}
let res = self.bit;
self.bit = self.bit.saturating_add(1);
self.passed_count = self.passed_count.saturating_add(1);
Some(res)
}
fn size_hint(&self) -> (usize, Option<usize>) {
let len = self.borrow.len() - self.passed_count;
(len, Some(len))
}
fn count(self) -> usize
where Self: Sized {
self.len()
}
}
impl<T: PrimInt, const N: usize> DoubleEndedIterator for Iter<'_, T, N> {
fn next_back(&mut self) -> Option<Self::Item> {
self.bit = self.bit.saturating_sub(2);
while !self.borrow.try_contains(self.bit).ok()? {
self.bit = self.bit.saturating_sub(1);
}
let res = self.bit;
self.bit = self.bit.saturating_sub(1);
self.passed_count = self.passed_count.saturating_sub(1);
Some(res)
}
}
impl<T: PrimInt, const N: usize> FusedIterator for Iter<'_, T, N> {}
impl<T: PrimInt, const N: usize> ExactSizeIterator for Iter<'_, T, N> {
#[inline]
fn len(&self) -> usize {
self.borrow.len() - self.passed_count
}
}
/// Lazy iterator over the elements of one set that are absent from another,
/// created by `BitSet::difference`. Elements come out in ascending order.
#[derive(Clone)]
#[must_use = "this returns the difference as an iterator, without modifying either input set"]
pub struct Difference<'a, T: PrimInt + 'a, U: PrimInt + 'a, const N: usize, const M: usize> {
    iter: Iter<'a, T, N>,
    other: &'a BitSet<U, M>,
}
impl<T, U, const N: usize, const M: usize> fmt::Debug for Difference<'_, T, U, N, M>
where
    T: PrimInt,
    U: PrimInt,
{
    /// Lists the elements still to be yielded, without advancing `self`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let pending = self.clone();
        f.debug_set().entries(pending).finish()
    }
}
impl<'a, T, U, const N: usize, const M: usize> Iterator for Difference<'a, T, U, N, M>
where
    T: PrimInt + 'a,
    U: PrimInt + 'a,
{
    type Item = usize;
    fn next(&mut self) -> Option<usize> {
        let excluded = self.other;
        self.iter.find(|&bit| !excluded.contains(bit))
    }
    fn size_hint(&self) -> (usize, Option<usize>) {
        // Any remaining element may be filtered out, so the lower bound is 0.
        let upper = self.iter.size_hint().1;
        (0, upper)
    }
}
impl<T, U, const N: usize, const M: usize> FusedIterator for Difference<'_, T, U, N, M>
where
    T: PrimInt,
    U: PrimInt,
{
}
/// Lazy iterator over the elements present in both sets, created by
/// `BitSet::intersection`. Elements come out in ascending order.
#[derive(Clone)]
#[must_use = "this returns the intersection as an iterator, without modifying either input set"]
pub struct Intersection<'a, T, U, const N: usize, const M: usize>
where
    T: PrimInt + 'a,
    U: PrimInt + 'a,
{
    iter: Iter<'a, T, N>,
    other: &'a BitSet<U, M>,
}
impl<T, U, const N: usize, const M: usize> fmt::Debug for Intersection<'_, T, U, N, M>
where
    T: PrimInt,
    U: PrimInt,
{
    /// Lists the elements still to be yielded, without advancing `self`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let pending = self.clone();
        f.debug_set().entries(pending).finish()
    }
}
impl<'a, T, U, const N: usize, const M: usize> Iterator for Intersection<'a, T, U, N, M>
where
    T: PrimInt + 'a,
    U: PrimInt + 'a,
{
    type Item = usize;
    fn next(&mut self) -> Option<usize> {
        let required = self.other;
        self.iter.find(|&bit| required.contains(bit))
    }
    fn size_hint(&self) -> (usize, Option<usize>) {
        // Any remaining element may be filtered out, so the lower bound is 0.
        let upper = self.iter.size_hint().1;
        (0, upper)
    }
}
impl<T, U, const N: usize, const M: usize> FusedIterator for Intersection<'_, T, U, N, M>
where
    T: PrimInt,
    U: PrimInt,
{
}
/// Lazy iterator over the union of two sets, created by `BitSet::union`.
///
/// Every element of either set is yielded exactly once. All elements of the
/// set with more elements come first, in ascending order. They are followed,
/// also ascending, by the elements only the other set has.
#[derive(Clone)]
#[must_use = "this returns the union as an iterator, without modifying either input set"]
pub struct Union<'a, T: PrimInt + 'a, U: PrimInt + 'a, const N: usize, const M: usize> {
    iter: UnionChoose<'a, T, U, N, M>,
}
impl<T, U, const N: usize, const M: usize> fmt::Debug for Union<'_, T, U, N, M>
where
    T: PrimInt,
    U: PrimInt,
{
    /// Lists the elements still to be yielded, without advancing `self`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let pending = self.clone();
        f.debug_set().entries(pending).finish()
    }
}
impl<'a, T, U, const N: usize, const M: usize> Iterator for Union<'a, T, U, N, M>
where
    T: PrimInt + 'a,
    U: PrimInt + 'a,
{
    type Item = usize;
    fn next(&mut self) -> Option<usize> {
        let chosen = &mut self.iter;
        chosen.next()
    }
    fn size_hint(&self) -> (usize, Option<usize>) {
        let chosen = &self.iter;
        chosen.size_hint()
    }
}
impl<T, U, const N: usize, const M: usize> FusedIterator for Union<'_, T, U, N, M>
where
    T: PrimInt,
    U: PrimInt,
{
}
// Backing iterator for `Union`: the set with more elements is walked in full
// first, followed by the difference "other minus that set". The two variants
// exist only because the chain types differ depending on which side is larger.
#[derive(Clone)]
enum UnionChoose<'a, T: PrimInt, U: PrimInt, const N: usize, const M: usize> {
    SelfBiggerThanOther(Chain<Iter<'a, T, N>, Difference<'a, U, T, M, N>>),
    SelfSmallerThanOther(Chain<Iter<'a, U, M>, Difference<'a, T, U, N, M>>),
}
impl<'a, T, U, const N: usize, const M: usize> Iterator for UnionChoose<'a, T, U, N, M>
where
    T: PrimInt + 'a,
    U: PrimInt + 'a,
{
    type Item = usize;
    fn next(&mut self) -> Option<usize> {
        match *self {
            UnionChoose::SelfBiggerThanOther(ref mut chain) => chain.next(),
            UnionChoose::SelfSmallerThanOther(ref mut chain) => chain.next(),
        }
    }
    fn size_hint(&self) -> (usize, Option<usize>) {
        match *self {
            UnionChoose::SelfBiggerThanOther(ref chain) => chain.size_hint(),
            UnionChoose::SelfSmallerThanOther(ref chain) => chain.size_hint(),
        }
    }
}
impl<T, U, const N: usize, const M: usize> FusedIterator for UnionChoose<'_, T, U, N, M>
where
    T: PrimInt,
    U: PrimInt,
{
}
/// Lazy iterator over the elements in exactly one of two sets, created by
/// `BitSet::symmetric_difference`.
///
/// Elements only in the first set come first, then those only in the second;
/// each run is ascending.
#[derive(Clone)]
#[must_use = "this returns the difference as an iterator, without modifying either input set"]
pub struct SymmetricDifference<'a, T, U, const N: usize, const M: usize>
where
    T: PrimInt + 'a,
    U: PrimInt + 'a,
{
    iter: Chain<Difference<'a, T, U, N, M>, Difference<'a, U, T, M, N>>,
}
impl<T, U, const N: usize, const M: usize> fmt::Debug for SymmetricDifference<'_, T, U, N, M>
where
    T: PrimInt,
    U: PrimInt,
{
    /// Lists the elements still to be yielded, without advancing `self`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let pending = self.clone();
        f.debug_set().entries(pending).finish()
    }
}
impl<'a, T, U, const N: usize, const M: usize> Iterator for SymmetricDifference<'a, T, U, N, M>
where
    T: PrimInt + 'a,
    U: PrimInt + 'a,
{
    type Item = usize;
    fn next(&mut self) -> Option<usize> {
        let both_sides = &mut self.iter;
        both_sides.next()
    }
    fn size_hint(&self) -> (usize, Option<usize>) {
        let both_sides = &self.iter;
        both_sides.size_hint()
    }
}
impl<T, U, const N: usize, const M: usize> FusedIterator for SymmetricDifference<'_, T, U, N, M>
where
    T: PrimInt,
    U: PrimInt,
{
}
// Unit tests are declared as an out-of-line `tests` module, compiled only for test builds.
#[cfg(test)]
mod tests;
// When building documentation, the contents of `../CHANGELOG.md` become the docs
// of this empty `changelog` module. The `allow` presumably silences rustdoc
// warnings about link-like text in the changelog (TODO confirm).
#[cfg(doc)]
#[doc = include_str!("../CHANGELOG.md")]
#[allow(rustdoc::broken_intra_doc_links)]
pub mod changelog {}