use crate::direct::macros::impl_direct_set_iter;
use crate::utils::bitsets::ones::OnesIter;
use crate::utils::bitsets::retain_word;
use alloc::boxed::Box;
use core::cmp::Ordering;
use core::fmt;
use core::fmt::{Debug, Formatter};
use core::hash::{Hash, Hasher};
use core::iter::FusedIterator;
use core::marker::PhantomData;
use core::ops::Index;
use intid::array::{Array, BitsetLimb};
use intid::{EnumId, EquivalentId};
/// A set of `T` values stored as a fixed-size bitset, one bit per integer id.
///
/// The value whose id is `i` lives in limb `i / BitsetLimb::BITS` at bit
/// `i % BitsetLimb::BITS`.
#[derive(Clone)]
pub struct EnumSet<T: EnumId> {
    // Backing storage: a fixed number of `BitsetLimb` words determined by `T::BitSet`.
    limbs: T::BitSet,
    // Element count reported by `len()`; intended to always equal the number of
    // set bits in `limbs`.
    len: u32,
    marker: PhantomData<T>,
}
#[inline]
fn divmod_index(index: u32) -> (usize, u32) {
(
(index / BitsetLimb::BITS) as usize,
index % BitsetLimb::BITS,
)
}
/// Returns a limb with only bit `bit_index` set.
///
/// `bit_index` is expected to be smaller than `BitsetLimb::BITS`, as produced
/// by `divmod_index`.
#[inline]
fn bitmask_for(bit_index: u32) -> BitsetLimb {
    const LOWEST_BIT: BitsetLimb = 1;
    LOWEST_BIT << bit_index
}
impl<T: EnumId> EnumSet<T> {
#[inline]
pub fn new() -> Self {
assert_eq!(
crate::enums::verify_enum_type::<T, ()>().bitset_len,
Self::BITSET_LEN
);
let _assert_can_zero_init = <Self as crate::utils::Zeroable>::zeroed;
EnumSet {
limbs: unsafe { core::mem::zeroed() },
len: 0,
marker: PhantomData,
}
}
const BITSET_LEN: usize = <T::BitSet as intid::array::Array<BitsetLimb>>::LEN;
#[inline]
pub fn new_boxed() -> Box<Self> {
assert_eq!(
crate::enums::verify_enum_type::<T, ()>().bitset_len,
Self::BITSET_LEN
);
crate::utils::Zeroable::zeroed_boxed()
}
#[inline]
fn limbs(&self) -> &[BitsetLimb] {
self.limbs.as_ref()
}
#[inline]
fn limbs_mut(&mut self) -> &mut [BitsetLimb] {
self.limbs.as_mut()
}
#[cold]
fn index_overflow() -> ! {
panic!(
"An index for `{}` overflowed its claimed maximum",
core::any::type_name::<T>()
)
}
#[inline]
fn verified_index(key: &T) -> (usize, u32) {
let index = intid::uint::checked_cast::<_, u32>(key.to_int()).unwrap_or_else(|| {
if T::TRUSTED_RANGE.is_some() {
unsafe { core::hint::unreachable_unchecked() }
} else {
Self::index_overflow()
}
});
let (word_index, bit_index) = divmod_index(index);
if T::TRUSTED_RANGE.is_none() && word_index >= Self::BITSET_LEN {
Self::index_overflow();
}
(word_index, bit_index)
}
#[inline]
pub fn insert(&mut self, value: T) -> bool {
let (word_index, bit_index) = Self::verified_index(&value);
let word = unsafe { self.limbs_mut().get_unchecked_mut(word_index) };
let mask = bitmask_for(bit_index);
let was_present = (mask & *word) != 0;
*word |= mask;
!was_present
}
#[inline]
pub fn remove(&mut self, value: impl EquivalentId<T>) -> bool {
let value = value.as_id();
let (word_index, bit_index) = Self::verified_index(&value);
let word = unsafe { self.limbs_mut().get_unchecked_mut(word_index) };
let mask = bitmask_for(bit_index);
let was_present = (mask & *word) != 0;
*word &= !mask;
was_present
}
#[inline]
pub fn contains(&self, value: impl EquivalentId<T>) -> bool {
let (word_index, bit_index) = Self::verified_index(&value.as_id());
let word = unsafe { self.limbs().get_unchecked(word_index) };
(word & bitmask_for(bit_index)) != 0
}
#[inline]
pub fn iter(&self) -> Iter<'_, T> {
Iter {
len: self.len as usize,
handle: OnesIter::new(self.limbs().iter().copied()),
marker: PhantomData,
}
}
#[inline]
pub fn clear(&mut self) {
unsafe {
core::ptr::write_bytes(&mut self.limbs, 0, 1);
}
self.len = 0;
}
#[inline]
pub fn len(&self) -> usize {
self.len as usize
}
#[inline]
pub fn is_empty(&self) -> bool {
self.len == 0
}
pub fn retain<F: FnMut(T) -> bool>(&mut self, mut func: F) {
for (word_index, word) in self.limbs.as_mut().iter_mut().enumerate() {
let (updated_word, word_removed) = retain_word(*word, |bit| {
let id = (word_index * 32) + (bit as usize);
let key = unsafe { T::from_int_unchecked(intid::uint::from_usize_wrapping(id)) };
func(key)
});
*word = updated_word;
self.len -= word_removed;
}
}
}
// SAFETY: an all-zero `EnumSet` is the empty set. `len` is 0, `marker` is a
// zero-sized `PhantomData`, and `limbs` becomes all-zero limbs.
// NOTE(review): this assumes every `T::BitSet` is a plain array of integer limbs
// with no invalid bit patterns; confirm this for all `EnumId` implementations.
unsafe impl<T: EnumId> crate::utils::Zeroable for EnumSet<T> {}
impl<T: EnumId> Default for EnumSet<T> {
    /// Returns an empty set; equivalent to [`EnumSet::new`].
    #[inline]
    fn default() -> Self {
        Self::new()
    }
}
impl<T: EnumId> PartialEq for EnumSet<T> {
    /// Two sets are equal when their bitsets are identical.
    ///
    /// The cached lengths are compared first as a cheap early exit.
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        if self.len != other.len {
            return false;
        }
        self.limbs() == other.limbs()
    }
}
impl<T: EnumId> Eq for EnumSet<T> {}
impl<T: EnumId> Debug for EnumSet<T> {
    /// Formats the set with `debug_set`, listing elements in iteration order.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let mut set_builder = f.debug_set();
        for element in self.iter() {
            set_builder.entry(&element);
        }
        set_builder.finish()
    }
}
impl<T: EnumId> Extend<T> for EnumSet<T> {
    /// Inserts every value yielded by `iter`; values already present stay present.
    #[inline]
    fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {
        iter.into_iter().for_each(|value| {
            self.insert(value);
        });
    }
}
impl<'a, T: EnumId> Extend<&'a T> for EnumSet<T> {
    /// Inserts a copy of every value referenced by `iter`.
    #[inline]
    fn extend<I: IntoIterator<Item = &'a T>>(&mut self, iter: I) {
        for &value in iter {
            self.insert(value);
        }
    }
}
impl<T: EnumId> FromIterator<T> for EnumSet<T> {
    /// Builds a set containing every value yielded by `iter`.
    #[inline]
    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
        let values = iter.into_iter();
        let mut collected = Self::new();
        for value in values {
            collected.insert(value);
        }
        collected
    }
}
impl<'a, T: EnumId> FromIterator<&'a T> for EnumSet<T> {
    /// Builds a set containing a copy of every value referenced by `iter`.
    #[inline]
    fn from_iter<I: IntoIterator<Item = &'a T>>(iter: I) -> Self {
        let copies = iter.into_iter().copied();
        <Self as FromIterator<T>>::from_iter(copies)
    }
}
impl<'a, T: EnumId + 'a> IntoIterator for &'a EnumSet<T> {
type Item = T;
type IntoIter = Iter<'a, T>;
#[inline]
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
impl<T: EnumId> IntoIterator for EnumSet<T> {
type Item = T;
type IntoIter = IntoIter<T>;
#[inline]
fn into_iter(self) -> Self::IntoIter {
IntoIter {
len: self.len as usize,
marker: PhantomData,
handle: OnesIter::new(Array::into_iter(self.limbs)),
}
}
}
impl<'a, T: EnumId + 'a> Index<&'a T> for EnumSet<T> {
    type Output = bool;
    /// Same as indexing by value with `*index`.
    #[inline]
    fn index(&self, index: &'a T) -> &Self::Output {
        <Self as Index<T>>::index(self, *index)
    }
}
impl<T: EnumId> Index<T> for EnumSet<T> {
    type Output = bool;
    /// Returns a reference to a constant `true` if `index` is in the set,
    /// otherwise to a constant `false`.
    #[inline]
    fn index(&self, index: T) -> &Self::Output {
        let present = self.contains(index);
        if present {
            &true
        } else {
            &false
        }
    }
}
impl<T: EnumId + Hash> Hash for EnumSet<T> {
    /// Hashes the element count, then every element in iteration order.
    fn hash<H: Hasher>(&self, state: &mut H) {
        state.write_usize(self.len());
        self.iter().for_each(|element| element.hash(state));
    }
}
impl<T: EnumId + PartialOrd> PartialOrd for EnumSet<T> {
    /// Lexicographically compares the two sets' element sequences in iteration order.
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Iterator::partial_cmp(self.iter(), other.iter())
    }
}
impl<T: EnumId + Ord> Ord for EnumSet<T> {
    /// Lexicographically compares the two sets' element sequences in iteration order.
    #[inline]
    fn cmp(&self, other: &Self) -> Ordering {
        Iterator::cmp(self.iter(), other.iter())
    }
}
/// Borrowing iterator over the elements of an [`EnumSet`], created by [`EnumSet::iter`].
pub struct Iter<'a, T: EnumId> {
    // Initialized from the set's cached `len`; presumably the remaining-element
    // count maintained by the impls that `impl_direct_set_iter!` generates.
    len: usize,
    // Walks the borrowed limbs; presumably yields the positions of set bits.
    handle: OnesIter<BitsetLimb, core::iter::Copied<core::slice::Iter<'a, BitsetLimb>>>,
    marker: PhantomData<fn() -> T>,
}
// Iterator trait impls for `Iter` come from the crate-internal `impl_direct_set_iter!` macro.
impl_direct_set_iter!(Iter<'a, K: EnumId>);
/// Owning iterator over the elements of an [`EnumSet`], created by its `IntoIterator` impl.
pub struct IntoIter<T: EnumId> {
    // Walks the owned limbs; presumably yields the positions of set bits.
    handle: OnesIter<BitsetLimb, <T::BitSet as Array<BitsetLimb>>::Iter>,
    // Initialized from the set's cached `len`; see the note on `Iter::len`.
    len: usize,
    marker: PhantomData<T>,
}
// Iterator trait impls for `IntoIter` come from the crate-internal `impl_direct_set_iter!` macro.
impl_direct_set_iter!(IntoIter<K: EnumId>);
#[cfg(feature = "petgraph_0_8")]
impl<T: EnumId> petgraph_0_8::visit::VisitMap<T> for EnumSet<T> {
    /// Marks `a` as visited; returns `true` if it had not been visited before.
    #[inline]
    fn visit(&mut self, a: T) -> bool {
        EnumSet::insert(self, a)
    }
    /// Reports whether `value` has been marked as visited.
    #[inline]
    fn is_visited(&self, value: &T) -> bool {
        let node = *value;
        EnumSet::contains(self, node)
    }
    /// Clears the visited mark on `a`; returns `true` if it was previously visited.
    #[inline]
    fn unvisit(&mut self, a: T) -> bool {
        EnumSet::remove(self, a)
    }
}
/// Builds an `EnumSet` from a comma-separated list of values.
///
/// With no arguments this expands to `EnumSet::new()`. Otherwise it creates an
/// empty set and `insert`s each expression in order. A trailing comma is accepted.
///
/// NOTE(review): despite its name, this macro builds an `EnumSet`, not a map;
/// confirm whether the name is intentional.
#[macro_export]
macro_rules! direct_enum_map {
    () => ($crate::enums::EnumSet::new());
    ($($value:expr),+ $(,)?) => ({
        let mut set = $crate::enums::EnumSet::new();
        $(set.insert($value);)*
        set
    });
}