1use crate::direct::macros::impl_direct_set_iter;
4use crate::utils::bitsets::ones::OnesIter;
5use crate::utils::bitsets::retain_word;
6use alloc::boxed::Box;
7use core::cmp::Ordering;
8use core::fmt;
9use core::fmt::{Debug, Formatter};
10use core::hash::{Hash, Hasher};
11use core::iter::FusedIterator;
12use core::marker::PhantomData;
13use core::ops::Index;
14use intid::array::{Array, BitsetLimb};
15use intid::{EnumId, EquivalentId};
16
/// A set of `EnumId` values, stored as a fixed-size bitset (`T::BitSet`).
///
/// The element count is cached in `len` rather than recomputed from the bits.
#[derive(Clone)]
pub struct EnumSet<T: EnumId> {
    /// One bit per possible id, packed into `BitsetLimb` words
    /// (see `divmod_index` for the id -> (limb, bit) mapping).
    limbs: T::BitSet,
    /// Cached number of set bits; must be kept in sync with `limbs`.
    len: u32,
    marker: PhantomData<T>,
}
49#[inline]
50fn divmod_index(index: u32) -> (usize, u32) {
51 (
52 (index / BitsetLimb::BITS) as usize,
53 index % BitsetLimb::BITS,
54 )
55}
56#[inline]
57fn bitmask_for(bit_index: u32) -> BitsetLimb {
58 let one: BitsetLimb = 1;
59 one << bit_index
60}
61impl<T: EnumId> EnumSet<T> {
62 #[inline]
64 pub fn new() -> Self {
65 assert_eq!(
66 crate::enums::verify_enum_type::<T, ()>().bitset_len,
67 Self::BITSET_LEN
68 );
69 let _assert_can_zero_init = <Self as crate::utils::Zeroable>::zeroed;
71 EnumSet {
73 limbs: unsafe { core::mem::zeroed() },
75 len: 0,
76 marker: PhantomData,
77 }
78 }
79
80 const BITSET_LEN: usize = <T::BitSet as intid::array::Array<BitsetLimb>>::LEN;
81
82 #[inline]
88 pub fn new_boxed() -> Box<Self> {
89 assert_eq!(
90 crate::enums::verify_enum_type::<T, ()>().bitset_len,
91 Self::BITSET_LEN
92 );
93 crate::utils::Zeroable::zeroed_boxed()
94 }
95
96 #[inline]
97 fn limbs(&self) -> &[BitsetLimb] {
98 self.limbs.as_ref()
99 }
100
101 #[inline]
102 fn limbs_mut(&mut self) -> &mut [BitsetLimb] {
103 self.limbs.as_mut()
104 }
105
106 #[cold]
107 fn index_overflow() -> ! {
108 panic!(
109 "An index for `{}` overflowed its claimed maximum",
110 core::any::type_name::<T>()
111 )
112 }
113
114 #[inline]
122 fn verified_index(key: &T) -> (usize, u32) {
123 let index = intid::uint::checked_cast::<_, u32>(key.to_int()).unwrap_or_else(|| {
124 if T::TRUSTED_RANGE.is_some() {
125 unsafe { core::hint::unreachable_unchecked() }
127 } else {
128 Self::index_overflow()
129 }
130 });
131 let (word_index, bit_index) = divmod_index(index);
132 if T::TRUSTED_RANGE.is_none() && word_index >= Self::BITSET_LEN {
134 Self::index_overflow();
135 }
136 (word_index, bit_index)
137 }
138
139 #[inline]
146 pub fn insert(&mut self, value: T) -> bool {
147 let (word_index, bit_index) = Self::verified_index(&value);
148 let word = unsafe { self.limbs_mut().get_unchecked_mut(word_index) };
150 let mask = bitmask_for(bit_index);
151 let was_present = (mask & *word) != 0;
152 *word |= mask;
153 !was_present
154 }
155
156 #[inline]
163 pub fn remove(&mut self, value: impl EquivalentId<T>) -> bool {
164 let value = value.as_id();
165 let (word_index, bit_index) = Self::verified_index(&value);
166 let word = unsafe { self.limbs_mut().get_unchecked_mut(word_index) };
168 let mask = bitmask_for(bit_index);
169 let was_present = (mask & *word) != 0;
170 *word &= !mask;
171 was_present
172 }
173
174 #[inline]
176 pub fn contains(&self, value: impl EquivalentId<T>) -> bool {
177 let (word_index, bit_index) = Self::verified_index(&value.as_id());
178 let word = unsafe { self.limbs().get_unchecked(word_index) };
180 (word & bitmask_for(bit_index)) != 0
181 }
182
183 #[inline]
187 pub fn iter(&self) -> Iter<'_, T> {
188 Iter {
189 len: self.len as usize,
190 handle: OnesIter::new(self.limbs().iter().copied()),
191 marker: PhantomData,
192 }
193 }
194
195 #[inline]
197 pub fn clear(&mut self) {
198 unsafe {
201 core::ptr::write_bytes(&mut self.limbs, 0, 1);
202 }
203 self.len = 0;
204 }
205
206 #[inline]
208 pub fn len(&self) -> usize {
209 self.len as usize
210 }
211
212 #[inline]
214 pub fn is_empty(&self) -> bool {
215 self.len == 0
216 }
217
218 pub fn retain<F: FnMut(T) -> bool>(&mut self, mut func: F) {
222 for (word_index, word) in self.limbs.as_mut().iter_mut().enumerate() {
223 let (updated_word, word_removed) = retain_word(*word, |bit| {
224 let id = (word_index * 32) + (bit as usize);
225 let key = unsafe { T::from_int_unchecked(intid::uint::from_usize_wrapping(id)) };
227 func(key)
228 });
229 *word = updated_word;
230 self.len -= word_removed;
231 }
232 }
233}
// SAFETY: every field accepts the all-zero bit pattern. `len` is a `u32` and
// `marker` is a zero-sized `PhantomData`. `limbs` is assumed to be a plain array
// of `BitsetLimb` integers (TODO confirm against `intid::array::Array`).
// The all-zero value is also a consistent *empty* set: `len == 0` and no bits are set.
unsafe impl<T: EnumId> crate::utils::Zeroable for EnumSet<T> {}
237
238impl<T: EnumId> Default for EnumSet<T> {
239 #[inline]
240 fn default() -> Self {
241 EnumSet::new()
242 }
243}
244impl<T: EnumId> PartialEq for EnumSet<T> {
245 #[inline]
246 fn eq(&self, other: &Self) -> bool {
247 self.len == other.len && self.limbs() == other.limbs()
248 }
249}
250impl<T: EnumId> Eq for EnumSet<T> {}
251impl<T: EnumId> Debug for EnumSet<T> {
252 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
253 f.debug_set().entries(self.iter()).finish()
254 }
255}
256impl<T: EnumId> Extend<T> for EnumSet<T> {
257 #[inline]
258 fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {
259 for value in iter {
260 self.insert(value);
261 }
262 }
263}
264impl<'a, T: EnumId> Extend<&'a T> for EnumSet<T> {
265 #[inline]
266 fn extend<I: IntoIterator<Item = &'a T>>(&mut self, iter: I) {
267 self.extend(iter.into_iter().copied());
268 }
269}
270impl<T: EnumId> FromIterator<T> for EnumSet<T> {
271 #[inline]
272 fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
273 let iter = iter.into_iter();
274 let mut set = Self::new();
275 set.extend(iter);
276 set
277 }
278}
279
280impl<'a, T: EnumId> FromIterator<&'a T> for EnumSet<T> {
281 #[inline]
282 fn from_iter<I: IntoIterator<Item = &'a T>>(iter: I) -> Self {
283 iter.into_iter().copied().collect()
284 }
285}
286
287impl<'a, T: EnumId + 'a> IntoIterator for &'a EnumSet<T> {
288 type Item = T;
289 type IntoIter = Iter<'a, T>;
290
291 #[inline]
292 fn into_iter(self) -> Self::IntoIter {
293 self.iter()
294 }
295}
296impl<T: EnumId> IntoIterator for EnumSet<T> {
297 type Item = T;
298 type IntoIter = IntoIter<T>;
299
300 #[inline]
301 fn into_iter(self) -> Self::IntoIter {
302 IntoIter {
303 len: self.len as usize,
304 marker: PhantomData,
305 handle: OnesIter::new(Array::into_iter(self.limbs)),
306 }
307 }
308}
309
310impl<'a, T: EnumId + 'a> Index<&'a T> for EnumSet<T> {
311 type Output = bool;
312
313 #[inline]
314 fn index(&self, index: &'a T) -> &Self::Output {
315 &self[*index]
316 }
317}
318impl<T: EnumId> Index<T> for EnumSet<T> {
319 type Output = bool;
320
321 #[inline]
322 fn index(&self, index: T) -> &Self::Output {
323 const TRUE_REF: &bool = &true;
324 const FALSE_REF: &bool = &false;
325 if self.contains(index) {
326 TRUE_REF
327 } else {
328 FALSE_REF
329 }
330 }
331}
332impl<T: EnumId + Hash> Hash for EnumSet<T> {
333 fn hash<H: Hasher>(&self, state: &mut H) {
334 state.write_usize(self.len());
335 for value in self {
337 value.hash(state);
338 }
339 }
340}
341impl<T: EnumId + PartialOrd> PartialOrd for EnumSet<T> {
342 #[inline]
343 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
344 self.iter().partial_cmp(other.iter())
345 }
346}
347impl<T: EnumId + Ord> Ord for EnumSet<T> {
348 #[inline]
349 fn cmp(&self, other: &Self) -> Ordering {
350 self.iter().cmp(other.iter())
351 }
352}
353
/// Borrowing iterator over the values of an [`EnumSet`], returned by [`EnumSet::iter`].
pub struct Iter<'a, T: EnumId> {
    /// Element count, initialized from the set's cached `len`
    /// (presumably consumed by `impl_direct_set_iter!` for size hints — confirm).
    len: usize,
    /// Walks the set bits of the borrowed limbs.
    handle: OnesIter<BitsetLimb, core::iter::Copied<core::slice::Iter<'a, BitsetLimb>>>,
    marker: PhantomData<fn() -> T>,
}
impl_direct_set_iter!(Iter<'a, K: EnumId>);
363
/// Owning iterator over the values of an [`EnumSet`], returned by `EnumSet::into_iter`.
pub struct IntoIter<T: EnumId> {
    /// Walks the set bits of the owned limbs.
    handle: OnesIter<BitsetLimb, <T::BitSet as Array<BitsetLimb>>::Iter>,
    /// Element count, initialized from the set's cached `len`
    /// (presumably consumed by `impl_direct_set_iter!` for size hints — confirm).
    len: usize,
    marker: PhantomData<T>,
}
impl_direct_set_iter!(IntoIter<K: EnumId>);
372
#[cfg(feature = "petgraph_0_8")]
impl<T: EnumId> petgraph_0_8::visit::VisitMap<T> for EnumSet<T> {
    /// Marks `a` as visited; returns `true` if it had not been visited before.
    #[inline]
    fn visit(&mut self, a: T) -> bool {
        EnumSet::insert(self, a)
    }

    /// Returns `true` if `value` has been visited.
    #[inline]
    fn is_visited(&self, value: &T) -> bool {
        EnumSet::contains(self, *value)
    }

    /// Clears the visited mark on `a`; returns `true` if it was set.
    #[inline]
    fn unvisit(&mut self, a: T) -> bool {
        EnumSet::remove(self, a)
    }
}
388
/// Builds an `EnumSet` from a comma-separated list of values (trailing comma allowed).
///
/// `direct_enum_map!()` yields an empty set.
///
/// NOTE(review): despite the name, this macro produces an `EnumSet`, not a map.
#[macro_export]
macro_rules! direct_enum_map {
    () => {
        $crate::enums::EnumSet::new()
    };
    ($($value:expr),+ $(,)?) => {{
        let mut set = $crate::enums::EnumSet::new();
        $(
            set.insert($value);
        )*
        set
    }};
}