use core::mem::size_of;
use core::ops::{Add, BitAnd, BitOr, Not, Shl, Shr, Sub};
/// A small bitset stored inline in a single scalar value of type `T`.
///
/// Index `i` is a member of the set exactly when bit `i` (counting from the
/// least significant bit) of the wrapped value is set. The number of indices
/// the set can hold is therefore the bit width of `T`.
///
/// The wrapped scalar is public, so the raw bits can be read or written
/// directly via `.0`.
#[derive(Clone, Copy, PartialEq, Eq)]
#[cfg_attr(
feature = "enable-serde",
derive(serde_derive::Serialize, serde_derive::Deserialize)
)]
pub struct ScalarBitSet<T>(pub T);
impl<T> core::fmt::Debug for ScalarBitSet<T>
where
    T: ScalarBitSetStorage,
{
    /// Renders the set as a struct named after the concrete type, with one
    /// boolean field per bit index (`0`, `1`, ...) telling whether that index
    /// is a member.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        use alloc::string::ToString;

        let mut out = f.debug_struct(core::any::type_name::<Self>());
        for bit in 0..Self::capacity() {
            // Field names must be strings, so each index is rendered in
            // decimal before being handed to the builder.
            let label = bit.to_string();
            let present = self.contains(bit);
            out.field(&label, &present);
        }
        out.finish()
    }
}
impl<T> Default for ScalarBitSet<T>
where
T: ScalarBitSetStorage,
{
#[inline]
fn default() -> Self {
Self::new()
}
}
impl<T> ScalarBitSet<T>
where
    T: ScalarBitSetStorage,
{
    /// Creates an empty bitset (all bits clear).
    #[inline]
    pub fn new() -> Self {
        Self(T::from(0))
    }

    /// Creates a bitset containing every index in the half-open range
    /// `lo..hi`.
    ///
    /// `from_range(n, n)` is the empty set for every valid `n`, including
    /// `n == Self::capacity()`.
    ///
    /// # Panics
    ///
    /// Panics if `lo > hi` or if `hi > Self::capacity()`.
    #[inline]
    pub fn from_range(lo: u8, hi: u8) -> Self {
        assert!(lo <= hi);
        assert!(hi <= Self::capacity());

        let one = T::from(1);

        // Mask with the low `n` bits set, for `0 <= n <= capacity`.
        //
        // The obvious `(one << n) - one` shifts by the full bit width when
        // `n == capacity`, which overflows. Instead, set bit `n - 1` and add
        // all the bits below it; no shift ever exceeds `capacity - 1`.
        let low_bits = |n: u8| -> T {
            if n == 0 {
                T::from(0)
            } else {
                let top = one << (n - 1);
                top + (top - one)
            }
        };

        // `lo <= hi`, so the low mask is a subset of the high mask and the
        // subtraction cannot underflow.
        Self(low_bits(hi) - low_bits(lo))
    }

    /// Returns the number of indices this set can hold, i.e. the bit width
    /// of `T`.
    #[inline]
    pub fn capacity() -> u8 {
        u8::try_from(size_of::<T>()).unwrap() * 8
    }

    /// Returns the number of indices currently in the set.
    #[inline]
    pub fn len(&self) -> u8 {
        self.0.count_ones()
    }

    /// Returns `true` if the set contains no indices.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.0 == T::from(0)
    }

    /// Returns `true` if index `i` is in the set.
    ///
    /// # Panics
    ///
    /// Panics if `i >= Self::capacity()`.
    #[inline]
    pub fn contains(&self, i: u8) -> bool {
        assert!(i < Self::capacity());
        self.0 & (T::from(1) << i) != T::from(0)
    }

    /// Adds index `i` to the set.
    ///
    /// Returns `true` if `i` was newly inserted, `false` if it was already
    /// present.
    ///
    /// # Panics
    ///
    /// Panics if `i >= Self::capacity()`.
    #[inline]
    pub fn insert(&mut self, i: u8) -> bool {
        // `contains` performs the bounds check, so the shift below is safe.
        let is_new = !self.contains(i);
        self.0 = self.0 | (T::from(1) << i);
        is_new
    }

    /// Removes index `i` from the set.
    ///
    /// Returns `true` if `i` was present before the call.
    ///
    /// # Panics
    ///
    /// Panics if `i >= Self::capacity()`.
    #[inline]
    pub fn remove(&mut self, i: u8) -> bool {
        // `contains` performs the bounds check, so the shift below is safe.
        let was_present = self.contains(i);
        self.0 = self.0 & !(T::from(1) << i);
        was_present
    }

    /// Removes every index from the set.
    #[inline]
    pub fn clear(&mut self) {
        self.0 = T::from(0);
    }

    /// Removes and returns the largest index in the set, or `None` if the
    /// set is empty.
    #[inline]
    pub fn pop(&mut self) -> Option<u8> {
        let max = self.max()?;
        self.remove(max);
        Some(max)
    }

    /// Returns the smallest index in the set, or `None` if the set is empty.
    #[inline]
    pub fn min(&self) -> Option<u8> {
        if self.0 == T::from(0) {
            None
        } else {
            Some(self.0.trailing_zeros())
        }
    }

    /// Returns the largest index in the set, or `None` if the set is empty.
    #[inline]
    pub fn max(&self) -> Option<u8> {
        if self.0 == T::from(0) {
            None
        } else {
            // The value is non-zero, so `leading_zeros < capacity` and the
            // subtraction cannot underflow.
            let leading_zeroes = self.0.leading_zeros();
            Some(Self::capacity() - leading_zeroes - 1)
        }
    }

    /// Returns an iterator over the indices in the set.
    ///
    /// The iterator works on a copy of the bits, so it does not borrow the
    /// set.
    #[inline]
    pub fn iter(&self) -> Iter<T> {
        Iter {
            value: self.0,
            index: 0,
        }
    }
}
impl<T> IntoIterator for ScalarBitSet<T>
where
T: ScalarBitSetStorage,
{
type Item = u8;
type IntoIter = Iter<T>;
#[inline]
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
impl<'a, T> IntoIterator for &'a ScalarBitSet<T>
where
T: ScalarBitSetStorage,
{
type Item = u8;
type IntoIter = Iter<T>;
#[inline]
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
/// The operations a scalar type must provide to serve as the backing storage
/// of a `ScalarBitSet`.
///
/// Besides the usual bitwise and arithmetic operators (with shifts taking a
/// `u8` amount), a storage type must be constructible from a `u8` and expose
/// the three bit-counting queries below with results narrowed to `u8`.
pub trait ScalarBitSetStorage:
Default
+ From<u8>
+ Shl<u8, Output = Self>
+ Shr<u8, Output = Self>
+ BitAnd<Output = Self>
+ BitOr<Output = Self>
+ Not<Output = Self>
+ Sub<Output = Self>
+ Add<Output = Self>
+ PartialEq
+ Copy
{
/// Number of zero bits above the most significant set bit of `self`.
fn leading_zeros(self) -> u8;
/// Number of zero bits below the least significant set bit of `self`.
fn trailing_zeros(self) -> u8;
/// Number of set bits in `self`.
fn count_ones(self) -> u8;
}
macro_rules! impl_storage {
( $int:ty ) => {
impl ScalarBitSetStorage for $int {
fn leading_zeros(self) -> u8 {
u8::try_from(self.leading_zeros()).unwrap()
}
fn trailing_zeros(self) -> u8 {
u8::try_from(self.trailing_zeros()).unwrap()
}
fn count_ones(self) -> u8 {
u8::try_from(self.count_ones()).unwrap()
}
}
};
}
impl_storage!(u8);
impl_storage!(u16);
impl_storage!(u32);
impl_storage!(u64);
impl_storage!(u128);
impl_storage!(usize);
/// Iterator over the indices of the set bits of a scalar value, yielded from
/// the lowest index to the highest.
pub struct Iter<T> {
/// The bits not yet visited, shifted right so that bit 0 of `value`
/// corresponds to index `index` of the original value.
value: T,
/// How many low bits have already been shifted out of `value`.
index: u8,
}
impl<T> Iterator for Iter<T>
where
    T: ScalarBitSetStorage,
{
    type Item = u8;

    /// Returns the next set index in ascending order, or `None` once every
    /// set bit has been yielded.
    #[inline]
    fn next(&mut self) -> Option<u8> {
        if self.value == T::from(0) {
            return None;
        }

        let trailing_zeros = self.value.trailing_zeros();
        let elem = self.index + trailing_zeros;

        // Discard the bit just found together with the zeros below it.
        //
        // This is done as two shifts on purpose: when the only remaining bit
        // is the most significant one, `trailing_zeros + 1` equals the bit
        // width of `T`, and a single shift by that amount overflows (a panic
        // with overflow checks enabled, and a no-op shift that never
        // terminates without them).
        self.value = (self.value >> trailing_zeros) >> 1;

        // `elem` is at most `bit width - 1`, so this stays within `u8` for
        // every supported storage type.
        self.index = elem + 1;

        Some(elem)
    }
}