use core::mem::size_of;
use core::ops::{Add, BitAnd, BitOr, Not, Shl, Shr, Sub};
/// A fixed-capacity set of `u8` indices packed into the bits of one scalar
/// integer `T`: index `i` is a member exactly when bit `i` of the wrapped
/// value is set.
#[derive(Copy, Clone, Eq, PartialEq)]
#[cfg_attr(
    feature = "enable-serde",
    derive(serde_derive::Serialize, serde_derive::Deserialize)
)]
pub struct ScalarBitSet<T>(pub T);
impl<T> core::fmt::Debug for ScalarBitSet<T>
where
    T: ScalarBitSetStorage,
{
    /// Renders the set as a struct named after its full type name, with one
    /// field per bit position: the field name is the index and the value is
    /// whether that index is a member.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        use alloc::string::ToString;

        let mut out = f.debug_struct(core::any::type_name::<Self>());
        let bits = Self::capacity();
        let mut index: u8 = 0;
        while index < bits {
            // `debug_struct` wants a `&str` field name, so render the index.
            let name = index.to_string();
            out.field(&name, &self.contains(index));
            index += 1;
        }
        out.finish()
    }
}
impl<T> Default for ScalarBitSet<T>
where
T: ScalarBitSetStorage,
{
#[inline]
fn default() -> Self {
Self::new()
}
}
impl<T> ScalarBitSet<T>
where
T: ScalarBitSetStorage,
{
#[inline]
pub fn new() -> Self {
Self(T::from(0))
}
#[inline]
pub fn from_range(lo: u8, hi: u8) -> Self {
assert!(lo <= hi);
assert!(hi <= Self::capacity());
let one = T::from(1);
let hi_rng = if hi >= 1 {
(one << (hi - 1)) + ((one << (hi - 1)) - one)
} else {
T::from(0)
};
let lo_rng = (one << lo) - one;
Self(hi_rng - lo_rng)
}
#[inline]
pub fn capacity() -> u8 {
u8::try_from(size_of::<T>()).unwrap() * 8
}
#[inline]
pub fn len(&self) -> u8 {
self.0.count_ones()
}
#[inline]
pub fn is_empty(&self) -> bool {
self.0 == T::from(0)
}
#[inline]
pub fn contains(&self, i: u8) -> bool {
assert!(i < Self::capacity());
self.0 & (T::from(1) << i) != T::from(0)
}
#[inline]
pub fn insert(&mut self, i: u8) -> bool {
let is_new = !self.contains(i);
self.0 = self.0 | (T::from(1) << i);
is_new
}
#[inline]
pub fn remove(&mut self, i: u8) -> bool {
let was_present = self.contains(i);
self.0 = self.0 & !(T::from(1) << i);
was_present
}
#[inline]
pub fn clear(&mut self) {
self.0 = T::from(0);
}
#[inline]
pub fn pop_min(&mut self) -> Option<u8> {
let min = self.min()?;
self.remove(min);
Some(min)
}
#[inline]
pub fn pop_max(&mut self) -> Option<u8> {
let max = self.max()?;
self.remove(max);
Some(max)
}
#[inline]
pub fn min(&self) -> Option<u8> {
if self.0 == T::from(0) {
None
} else {
Some(self.0.trailing_zeros())
}
}
#[inline]
pub fn max(&self) -> Option<u8> {
if self.0 == T::from(0) {
None
} else {
let leading_zeroes = self.0.leading_zeros();
Some(Self::capacity() - leading_zeroes - 1)
}
}
#[inline]
pub fn iter(self) -> Iter<T> {
Iter { bitset: self }
}
}
impl<T> IntoIterator for ScalarBitSet<T>
where
T: ScalarBitSetStorage,
{
type Item = u8;
type IntoIter = Iter<T>;
#[inline]
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
impl<'a, T> IntoIterator for &'a ScalarBitSet<T>
where
T: ScalarBitSetStorage,
{
type Item = u8;
type IntoIter = Iter<T>;
#[inline]
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
impl<T: ScalarBitSetStorage> From<T> for ScalarBitSet<T> {
fn from(bits: T) -> Self {
Self(bits)
}
}
/// Integer types that can back a [`ScalarBitSet`].
///
/// Implementors supply the arithmetic, bitwise, and shift operators the set's
/// methods are built from, plus bit-counting queries whose results are
/// narrowed to `u8`.
pub trait ScalarBitSetStorage:
    Copy
    + Default
    + PartialEq
    + From<u8>
    + Add<Output = Self>
    + Sub<Output = Self>
    + BitAnd<Output = Self>
    + BitOr<Output = Self>
    + Not<Output = Self>
    + Shl<u8, Output = Self>
    + Shr<u8, Output = Self>
{
    /// Number of zero bits above the most significant set bit.
    fn leading_zeros(self) -> u8;
    /// Number of zero bits below the least significant set bit.
    fn trailing_zeros(self) -> u8;
    /// Number of set bits.
    fn count_ones(self) -> u8;
}
macro_rules! impl_storage {
( $int:ty ) => {
impl ScalarBitSetStorage for $int {
#[inline]
fn leading_zeros(self) -> u8 {
u8::try_from(self.leading_zeros()).unwrap()
}
#[inline]
fn trailing_zeros(self) -> u8 {
u8::try_from(self.trailing_zeros()).unwrap()
}
#[inline]
fn count_ones(self) -> u8 {
u8::try_from(self.count_ones()).unwrap()
}
}
};
}
impl_storage!(u8);
impl_storage!(u16);
impl_storage!(u32);
impl_storage!(u64);
impl_storage!(u128);
impl_storage!(usize);
/// Iterator over the member indices of a [`ScalarBitSet`].
///
/// It owns a copy of the set and removes each index as it is yielded. It is
/// `Clone`, like standard-library iterators, so a partially consumed iterator
/// can be forked cheaply.
#[derive(Clone)]
pub struct Iter<T> {
    bitset: ScalarBitSet<T>,
}
impl<T> Iterator for Iter<T>
where
    T: ScalarBitSetStorage,
{
    type Item = u8;

    /// Removes and returns the smallest remaining index.
    #[inline]
    fn next(&mut self) -> Option<u8> {
        self.bitset.pop_min()
    }

    /// Reports the exact number of remaining indices.
    ///
    /// `Iter` implements `ExactSizeIterator`, whose contract requires
    /// `size_hint` to be exact. An exact hint also lets `collect` and
    /// `extend` allocate once.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = usize::from(self.bitset.len());
        (remaining, Some(remaining))
    }

    /// Counts the remaining indices with a single popcount instead of
    /// draining the iterator.
    #[inline]
    fn count(self) -> usize {
        usize::from(self.bitset.len())
    }
}
impl<T: ScalarBitSetStorage> DoubleEndedIterator for Iter<T> {
    /// Removes and returns the largest index still held by the iterator.
    #[inline]
    fn next_back(&mut self) -> Option<u8> {
        self.bitset.pop_max()
    }
}
impl<T: ScalarBitSetStorage> ExactSizeIterator for Iter<T> {
    /// Number of indices not yet yielded from either end, i.e. the
    /// population count of the remaining bits.
    #[inline]
    fn len(&self) -> usize {
        let remaining: usize = self.bitset.len().into();
        remaining
    }
}
#[cfg(feature = "arbitrary")]
impl<'a, T> arbitrary::Arbitrary<'a> for ScalarBitSet<T>
where
    T: ScalarBitSetStorage + arbitrary::Arbitrary<'a>,
{
    /// Generates an arbitrary backing integer and wraps it, so any bit
    /// pattern of `T` can be produced.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        let bits = T::arbitrary(u)?;
        Ok(ScalarBitSet(bits))
    }
}