Skip to main content

idmap/enums/
set.rs

1//! Implements an [`EnumSet`] using a bitset.
2
3use crate::direct::macros::impl_direct_set_iter;
4use crate::utils::bitsets::ones::OnesIter;
5use crate::utils::bitsets::retain_word;
6use alloc::boxed::Box;
7use core::cmp::Ordering;
8use core::fmt;
9use core::fmt::{Debug, Formatter};
10use core::hash::{Hash, Hasher};
11use core::iter::FusedIterator;
12use core::marker::PhantomData;
13use core::ops::Index;
14use intid::array::{Array, BitsetLimb};
15use intid::{EnumId, EquivalentId};
16
17/// A set whose members implement [`EnumId`].
18///
19/// This is implemented as a bitset,
20/// so memory is proportional to [`EnumId::COUNT`].
21#[derive(Clone)]
22pub struct EnumSet<T: EnumId> {
23    limbs: T::BitSet,
24    /// It is possible to avoid storing this field by using a [popcount] instruction
25    /// like [`u64::count_ones`]
26    ///
27    /// On older architectures, popcount can be very slow.
28    /// Even on recent Intel architectures, the instruction has a 3-cycle latency.
29    /// We don't want the `len()` call to any slower than [`crate::DirectIdSet`],
30    /// so we unconditionally store the length even when [`T::COUNT`] is small.
31    ///
32    /// Intel AVX2 has instructions to accelerate popcount computation,
33    /// as discussed in [this paper] and implemented in [this library].
34    /// We could consider implementing this behind a cfg-flag
35    /// if the space savings become significant enough.
36    ///
37    /// It is safe to use a `u32` because `EnumId::MAX_ID + 1` is guaranteed to always fit in it.
38    /// Restricting [`EnumId`] to 16-bits does not give any space advantage here.
39    /// As long as the limbs in the array are at least 4-byte aligned,
40    /// a 16-bit length requires 2 bytes of padding and so is effectively the same size as a 32-bit length.
41    /// See [issue #4](https://github.com/DuckLogic/intid.rs/issues/14) for history.
42    ///
43    /// [popcount]: https://en.wikipedia.org/wiki/Hamming_weight
44    /// [this paper]: https://arxiv.org/pdf/1611.07a612
45    /// [this library]: https://github.com/kimwalisch/libpopcnt
46    len: u32,
47    marker: PhantomData<T>,
48}
49#[inline]
50fn divmod_index(index: u32) -> (usize, u32) {
51    (
52        (index / BitsetLimb::BITS) as usize,
53        index % BitsetLimb::BITS,
54    )
55}
56#[inline]
57fn bitmask_for(bit_index: u32) -> BitsetLimb {
58    let one: BitsetLimb = 1;
59    one << bit_index
60}
61impl<T: EnumId> EnumSet<T> {
62    /// Create a new set with no entries.
63    #[inline]
64    pub fn new() -> Self {
65        assert_eq!(
66            crate::enums::verify_enum_type::<T, ()>().bitset_len,
67            Self::BITSET_LEN
68        );
69        // We could just zero initialize the whole map
70        let _assert_can_zero_init = <Self as crate::utils::Zeroable>::zeroed;
71        // However, we initialize field-by-field in case that is somehow faster (skips padding?)
72        EnumSet {
73            // SAFETY: We know that that limbs is an array of integers, so can be zero-initialized
74            limbs: unsafe { core::mem::zeroed() },
75            len: 0,
76            marker: PhantomData,
77        }
78    }
79
80    const BITSET_LEN: usize = <T::BitSet as intid::array::Array<BitsetLimb>>::LEN;
81
82    /// Create a new set with no entries, allocating memory on the heap instead of the stack.
83    ///
84    /// Using `Box::new(EnumSet::new())` could require moving the underlying table
85    /// from the stack to the heap, as LLVM can struggle at eliminating copies.
86    /// This method avoids that copy by always allocating in-place.
87    #[inline]
88    pub fn new_boxed() -> Box<Self> {
89        assert_eq!(
90            crate::enums::verify_enum_type::<T, ()>().bitset_len,
91            Self::BITSET_LEN
92        );
93        crate::utils::Zeroable::zeroed_boxed()
94    }
95
96    #[inline]
97    fn limbs(&self) -> &[BitsetLimb] {
98        self.limbs.as_ref()
99    }
100
101    #[inline]
102    fn limbs_mut(&mut self) -> &mut [BitsetLimb] {
103        self.limbs.as_mut()
104    }
105
106    #[cold]
107    fn index_overflow() -> ! {
108        panic!(
109            "An index for `{}` overflowed its claimed maximum",
110            core::any::type_name::<T>()
111        )
112    }
113
114    /// Break apart a key into its word index and bit index.
115    ///
116    /// Guarantees that the resulting word index will be in-bounds for the bitset.
117    ///
118    /// # Safety
119    /// Relies on the unsafe guarantees of [`IntegerId::TRUSTED_RANGE`] if present.
120    /// If this token is missing, this function makes no unsafe assumptions.
121    #[inline]
122    fn verified_index(key: &T) -> (usize, u32) {
123        let index = intid::uint::checked_cast::<_, u32>(key.to_int()).unwrap_or_else(|| {
124            if T::TRUSTED_RANGE.is_some() {
125                // SAFETY: We have a TRUSTED_RANGE, so cannot overflow a u32
126                unsafe { core::hint::unreachable_unchecked() }
127            } else {
128                Self::index_overflow()
129            }
130        });
131        let (word_index, bit_index) = divmod_index(index);
132        // if we don't have a TRUSTED_RANGE, we have to do a length check
133        if T::TRUSTED_RANGE.is_none() && word_index >= Self::BITSET_LEN {
134            Self::index_overflow();
135        }
136        (word_index, bit_index)
137    }
138
139    /// Inserts the specified element into the set,
140    /// returning `true` if it was newly added and `false` if it was already present.
141    ///
142    /// Return value is consistent with [`HashSet::insert`].
143    ///
144    /// [`HashSet::insert`]: std::collections::HashSet::insert
145    #[inline]
146    pub fn insert(&mut self, value: T) -> bool {
147        let (word_index, bit_index) = Self::verified_index(&value);
148        // SAFETY: Validity of word index checked by verified_index
149        let word = unsafe { self.limbs_mut().get_unchecked_mut(word_index) };
150        let mask = bitmask_for(bit_index);
151        let was_present = (mask & *word) != 0;
152        *word |= mask;
153        !was_present
154    }
155
156    /// Remove the specified value from the set,
157    /// returning whether it was previously present.
158    ///
159    /// Return value is consistent with [`HashSet::remove`].
160    ///
161    /// [`HashSet::remove`]: std::collections::HashSet::insert
162    #[inline]
163    pub fn remove(&mut self, value: impl EquivalentId<T>) -> bool {
164        let value = value.as_id();
165        let (word_index, bit_index) = Self::verified_index(&value);
166        // SAFETY: Validity of word index checked by verified_index
167        let word = unsafe { self.limbs_mut().get_unchecked_mut(word_index) };
168        let mask = bitmask_for(bit_index);
169        let was_present = (mask & *word) != 0;
170        *word &= !mask;
171        was_present
172    }
173
174    /// Check if this set contains the specified value
175    #[inline]
176    pub fn contains(&self, value: impl EquivalentId<T>) -> bool {
177        let (word_index, bit_index) = Self::verified_index(&value.as_id());
178        // SAFETY: Validity of word index checked by verified_index
179        let word = unsafe { self.limbs().get_unchecked(word_index) };
180        (word & bitmask_for(bit_index)) != 0
181    }
182
183    /// Iterate over the values in this set.
184    ///
185    /// Guaranteed to be ordered by the integer value of the key.
186    #[inline]
187    pub fn iter(&self) -> Iter<'_, T> {
188        Iter {
189            len: self.len as usize,
190            handle: OnesIter::new(self.limbs().iter().copied()),
191            marker: PhantomData,
192        }
193    }
194
195    /// Clear the values in this set
196    #[inline]
197    pub fn clear(&mut self) {
198        // SAFETY: Since the limbs are an array of integers,
199        // they are safe to zero initialize
200        unsafe {
201            core::ptr::write_bytes(&mut self.limbs, 0, 1);
202        }
203        self.len = 0;
204    }
205
206    /// The number of entries in this set
207    #[inline]
208    pub fn len(&self) -> usize {
209        self.len as usize
210    }
211
212    /// If this set is empty
213    #[inline]
214    pub fn is_empty(&self) -> bool {
215        self.len == 0
216    }
217
218    /// Retain values in the set if the specified closure returns true
219    ///
220    /// Otherwise, they are removed
221    pub fn retain<F: FnMut(T) -> bool>(&mut self, mut func: F) {
222        for (word_index, word) in self.limbs.as_mut().iter_mut().enumerate() {
223            let (updated_word, word_removed) = retain_word(*word, |bit| {
224                let id = (word_index * 32) + (bit as usize);
225                // Safety: If present in the map, it is known to be valid
226                let key = unsafe { T::from_int_unchecked(intid::uint::from_usize_wrapping(id)) };
227                func(key)
228            });
229            *word = updated_word;
230            self.len -= word_removed;
231        }
232    }
233}
// SAFETY: We know that the bitset can be zero-initialized because it is an array of integers.
// The other fields are the length (a `u32`, where zero matches the empty bitset)
// and a zero-sized `PhantomData` marker, so an all-zero `EnumSet` is a valid empty set.
unsafe impl<T: EnumId> crate::utils::Zeroable for EnumSet<T> {}
237
238impl<T: EnumId> Default for EnumSet<T> {
239    #[inline]
240    fn default() -> Self {
241        EnumSet::new()
242    }
243}
244impl<T: EnumId> PartialEq for EnumSet<T> {
245    #[inline]
246    fn eq(&self, other: &Self) -> bool {
247        self.len == other.len && self.limbs() == other.limbs()
248    }
249}
250impl<T: EnumId> Eq for EnumSet<T> {}
251impl<T: EnumId> Debug for EnumSet<T> {
252    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
253        f.debug_set().entries(self.iter()).finish()
254    }
255}
256impl<T: EnumId> Extend<T> for EnumSet<T> {
257    #[inline]
258    fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {
259        for value in iter {
260            self.insert(value);
261        }
262    }
263}
264impl<'a, T: EnumId> Extend<&'a T> for EnumSet<T> {
265    #[inline]
266    fn extend<I: IntoIterator<Item = &'a T>>(&mut self, iter: I) {
267        self.extend(iter.into_iter().copied());
268    }
269}
270impl<T: EnumId> FromIterator<T> for EnumSet<T> {
271    #[inline]
272    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
273        let iter = iter.into_iter();
274        let mut set = Self::new();
275        set.extend(iter);
276        set
277    }
278}
279
280impl<'a, T: EnumId> FromIterator<&'a T> for EnumSet<T> {
281    #[inline]
282    fn from_iter<I: IntoIterator<Item = &'a T>>(iter: I) -> Self {
283        iter.into_iter().copied().collect()
284    }
285}
286
287impl<'a, T: EnumId + 'a> IntoIterator for &'a EnumSet<T> {
288    type Item = T;
289    type IntoIter = Iter<'a, T>;
290
291    #[inline]
292    fn into_iter(self) -> Self::IntoIter {
293        self.iter()
294    }
295}
296impl<T: EnumId> IntoIterator for EnumSet<T> {
297    type Item = T;
298    type IntoIter = IntoIter<T>;
299
300    #[inline]
301    fn into_iter(self) -> Self::IntoIter {
302        IntoIter {
303            len: self.len as usize,
304            marker: PhantomData,
305            handle: OnesIter::new(Array::into_iter(self.limbs)),
306        }
307    }
308}
309
310impl<'a, T: EnumId + 'a> Index<&'a T> for EnumSet<T> {
311    type Output = bool;
312
313    #[inline]
314    fn index(&self, index: &'a T) -> &Self::Output {
315        &self[*index]
316    }
317}
318impl<T: EnumId> Index<T> for EnumSet<T> {
319    type Output = bool;
320
321    #[inline]
322    fn index(&self, index: T) -> &Self::Output {
323        const TRUE_REF: &bool = &true;
324        const FALSE_REF: &bool = &false;
325        if self.contains(index) {
326            TRUE_REF
327        } else {
328            FALSE_REF
329        }
330    }
331}
332impl<T: EnumId + Hash> Hash for EnumSet<T> {
333    fn hash<H: Hasher>(&self, state: &mut H) {
334        state.write_usize(self.len());
335        // guaranteed to be ordered by key
336        for value in self {
337            value.hash(state);
338        }
339    }
340}
341impl<T: EnumId + PartialOrd> PartialOrd for EnumSet<T> {
342    #[inline]
343    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
344        self.iter().partial_cmp(other.iter())
345    }
346}
347impl<T: EnumId + Ord> Ord for EnumSet<T> {
348    #[inline]
349    fn cmp(&self, other: &Self) -> Ordering {
350        self.iter().cmp(other.iter())
351    }
352}
353
354/// An iterator over the values in an [`EnumSet`].
355///
356/// [PR #130]: https://github.com/petgraph/fixedbitset/pull/130
357pub struct Iter<'a, T: EnumId> {
358    len: usize,
359    handle: OnesIter<BitsetLimb, core::iter::Copied<core::slice::Iter<'a, BitsetLimb>>>,
360    marker: PhantomData<fn() -> T>,
361}
362impl_direct_set_iter!(Iter<'a, K: EnumId>);
363
364/// An iterator over the values in an [`EnumSet`],
365/// consuming ownership the set.
366pub struct IntoIter<T: EnumId> {
367    handle: OnesIter<BitsetLimb, <T::BitSet as Array<BitsetLimb>>::Iter>,
368    len: usize,
369    marker: PhantomData<T>,
370}
371impl_direct_set_iter!(IntoIter<K: EnumId>);
372
#[cfg(feature = "petgraph_0_8")]
impl<T: EnumId> petgraph_0_8::visit::VisitMap<T> for EnumSet<T> {
    /// Marks `a` as visited; returns `true` if it was not visited before.
    #[inline]
    fn visit(&mut self, a: T) -> bool {
        EnumSet::insert(self, a)
    }
    /// Whether `value` has been visited.
    #[inline]
    fn is_visited(&self, value: &T) -> bool {
        EnumSet::contains(self, *value)
    }
    /// Clears the visited mark for `a`; returns whether it had been visited.
    #[inline]
    fn unvisit(&mut self, a: T) -> bool {
        EnumSet::remove(self, a)
    }
}
388
/// Creates an [`EnumSet`] from a list of values.
///
/// Despite the `map` in its name, this macro builds an [`EnumSet`].
/// With no arguments it expands to `EnumSet::new()`; otherwise a new set is
/// created and each value is inserted in the order given.
#[macro_export]
macro_rules! direct_enum_map {
    () => {
        $crate::enums::EnumSet::new()
    };
    ($($value:expr),+ $(,)?) => {{
        let mut set = $crate::enums::EnumSet::new();
        $(
            set.insert($value);
        )*
        set
    }};
}