Skip to main content

bitflagset/
slice.rs

1use core::hash::{Hash, Hasher};
2use core::iter::FusedIterator;
3use core::marker::PhantomData;
4use core::ops::BitAndAssign;
5use num_traits::{AsPrimitive, PrimInt};
6
7use super::bitset::PrimBitSetIter;
8
9/// Iterator over set bit positions in a word slice.
10///
11/// The storage `S` can be `&[T]`, `[T; N]`, `Box<[T]>`, etc.
12pub struct WordSetIter<S, T: PrimInt, V> {
13    store: S,
14    word_idx: usize,
15    current: PrimBitSetIter<T, usize>,
16    _marker: PhantomData<V>,
17}
18
19impl<S: AsRef<[T]>, T: PrimInt + BitAndAssign, V> WordSetIter<S, T, V> {
20    #[inline]
21    pub(crate) fn new(store: S) -> Self {
22        Self {
23            store,
24            word_idx: 0,
25            current: PrimBitSetIter::empty(),
26            _marker: PhantomData,
27        }
28    }
29
30    #[inline]
31    fn remaining_len(&self) -> usize {
32        self.current.len()
33            + self.store.as_ref()[self.word_idx..]
34                .iter()
35                .map(|w| w.count_ones() as usize)
36                .sum::<usize>()
37    }
38}
39
40impl<S: AsRef<[T]>, T: PrimInt + BitAndAssign, V: TryFrom<usize>> Iterator
41    for WordSetIter<S, T, V>
42{
43    type Item = V;
44
45    fn next(&mut self) -> Option<V> {
46        let words = self.store.as_ref();
47        let bits_per = core::mem::size_of::<T>() * 8;
48        loop {
49            if let Some(pos) = self.current.next() {
50                let idx = (self.word_idx - 1) * bits_per + pos;
51                let converted = V::try_from(idx);
52                debug_assert!(converted.is_ok());
53                match converted {
54                    Ok(value) => return Some(value),
55                    Err(_) => unsafe { core::hint::unreachable_unchecked() },
56                }
57            }
58            if self.word_idx >= words.len() {
59                return None;
60            }
61            self.current = PrimBitSetIter::from_raw(words[self.word_idx]);
62            self.word_idx += 1;
63        }
64    }
65
66    #[inline]
67    fn size_hint(&self) -> (usize, Option<usize>) {
68        let len = self.remaining_len();
69        (len, Some(len))
70    }
71
72    #[inline]
73    fn count(self) -> usize
74    where
75        Self: Sized,
76    {
77        self.remaining_len()
78    }
79}
80
81impl<S: AsRef<[T]>, T: PrimInt + BitAndAssign, V: TryFrom<usize>> ExactSizeIterator
82    for WordSetIter<S, T, V>
83{
84    #[inline]
85    fn len(&self) -> usize {
86        self.remaining_len()
87    }
88}
89
90impl<S: AsRef<[T]>, T: PrimInt + BitAndAssign, V: TryFrom<usize>> FusedIterator
91    for WordSetIter<S, T, V>
92{
93}
94
95/// Iterator over set bit positions in a `BitSlice`.
96pub type BitSliceIter<'a, T, V> = WordSetIter<&'a [T], T, V>;
97
98/// Draining iterator over set bit positions in a `BitSlice`.
99///
100/// Each word is consumed and zeroed in-place as iteration advances.
101/// Dropping the iterator clears any remaining words.
102pub struct Drain<'a, T: PrimInt, V> {
103    words: &'a mut [T],
104    word_idx: usize,
105    current: PrimBitSetIter<T, usize>,
106    _marker: PhantomData<V>,
107}
108
109impl<T: PrimInt + BitAndAssign, V> Drain<'_, T, V> {
110    #[inline]
111    fn remaining_len(&self) -> usize {
112        self.current.len()
113            + self.words[self.word_idx..]
114                .iter()
115                .map(|w| w.count_ones() as usize)
116                .sum::<usize>()
117    }
118}
119
120impl<T: PrimInt + BitAndAssign, V: TryFrom<usize>> Iterator for Drain<'_, T, V> {
121    type Item = V;
122
123    fn next(&mut self) -> Option<V> {
124        let bits_per = core::mem::size_of::<T>() * 8;
125        loop {
126            if let Some(pos) = self.current.next() {
127                let idx = (self.word_idx - 1) * bits_per + pos;
128                let converted = V::try_from(idx);
129                debug_assert!(converted.is_ok());
130                match converted {
131                    Ok(value) => return Some(value),
132                    Err(_) => unsafe { core::hint::unreachable_unchecked() },
133                }
134            }
135            if self.word_idx >= self.words.len() {
136                return None;
137            }
138            self.current = PrimBitSetIter::from_raw(self.words[self.word_idx]);
139            self.words[self.word_idx] = T::zero();
140            self.word_idx += 1;
141        }
142    }
143
144    #[inline]
145    fn size_hint(&self) -> (usize, Option<usize>) {
146        let len = self.remaining_len();
147        (len, Some(len))
148    }
149
150    #[inline]
151    fn count(self) -> usize
152    where
153        Self: Sized,
154    {
155        self.remaining_len()
156    }
157}
158
159impl<T: PrimInt + BitAndAssign, V: TryFrom<usize>> ExactSizeIterator for Drain<'_, T, V> {
160    #[inline]
161    fn len(&self) -> usize {
162        self.remaining_len()
163    }
164}
165
166impl<T: PrimInt + BitAndAssign, V: TryFrom<usize>> FusedIterator for Drain<'_, T, V> {}
167
168impl<T: PrimInt, V> Drop for Drain<'_, T, V> {
169    fn drop(&mut self) {
170        // Clear any words not yet consumed
171        for w in &mut self.words[self.word_idx..] {
172            *w = T::zero();
173        }
174    }
175}
176
/// Unsized base type shared by every bitset flavour; a thin wrapper over a raw `[T]`.
///
/// Operations work directly on primitive words (popcount, masking, leading/trailing zero
/// counts) rather than going through bitvec's generic algorithms.
///
/// The owned types (`BitSet`, `BoxedBitSet`) deref to `BitSlice<T, V>`, so the common
/// set API only has to be written once, here.
#[repr(transparent)]
pub struct BitSlice<T, V>(PhantomData<V>, [T]);
186
187impl<T, V> BitSlice<T, V> {
188    pub(crate) fn from_slice_ref(s: &[T]) -> &Self {
189        // SAFETY: BitSlice<T, V> is repr(transparent) over [T]
190        // (PhantomData<V> is ZST)
191        unsafe { &*(s as *const [T] as *const Self) }
192    }
193
194    pub(crate) fn from_slice_mut(s: &mut [T]) -> &mut Self {
195        // SAFETY: same layout guarantee
196        unsafe { &mut *(s as *mut [T] as *mut Self) }
197    }
198}
199
200impl<T: PrimInt, V> BitSlice<T, V> {
201    const BITS_PER: usize = core::mem::size_of::<T>() * 8;
202
203    #[inline]
204    fn index_of(idx: usize) -> (usize, T) {
205        (
206            idx / Self::BITS_PER,
207            T::one().unsigned_shl((idx % Self::BITS_PER) as u32),
208        )
209    }
210
211    #[inline]
212    pub fn capacity(&self) -> usize {
213        self.1.len() * Self::BITS_PER
214    }
215
216    #[inline]
217    pub fn len(&self) -> usize {
218        self.1.iter().map(|w| w.count_ones() as usize).sum()
219    }
220
221    #[inline]
222    pub fn is_empty(&self) -> bool {
223        self.1.iter().all(|w| w.is_zero())
224    }
225
226    #[inline]
227    pub fn first(&self) -> Option<V>
228    where
229        V: TryFrom<usize>,
230    {
231        for (i, &word) in self.1.iter().enumerate() {
232            if !word.is_zero() {
233                let bit = word.trailing_zeros() as usize;
234                let idx = i * Self::BITS_PER + bit;
235                let converted = V::try_from(idx);
236                debug_assert!(converted.is_ok());
237                return Some(match converted {
238                    Ok(value) => value,
239                    Err(_) => unsafe { core::hint::unreachable_unchecked() },
240                });
241            }
242        }
243        None
244    }
245
246    #[inline]
247    pub fn last(&self) -> Option<V>
248    where
249        V: TryFrom<usize>,
250    {
251        for (i, &word) in self.1.iter().enumerate().rev() {
252            if !word.is_zero() {
253                let bit = Self::BITS_PER - 1 - word.leading_zeros() as usize;
254                let idx = i * Self::BITS_PER + bit;
255                let converted = V::try_from(idx);
256                debug_assert!(converted.is_ok());
257                return Some(match converted {
258                    Ok(value) => value,
259                    Err(_) => unsafe { core::hint::unreachable_unchecked() },
260                });
261            }
262        }
263        None
264    }
265
266    #[inline]
267    pub fn pop_first(&mut self) -> Option<V>
268    where
269        V: TryFrom<usize>,
270    {
271        for (i, word) in self.1.iter_mut().enumerate() {
272            if !word.is_zero() {
273                let bit = word.trailing_zeros() as usize;
274                let mask = T::one().unsigned_shl(bit as u32);
275                *word = *word & !mask;
276                let idx = i * Self::BITS_PER + bit;
277                let converted = V::try_from(idx);
278                debug_assert!(converted.is_ok());
279                return Some(match converted {
280                    Ok(value) => value,
281                    Err(_) => unsafe { core::hint::unreachable_unchecked() },
282                });
283            }
284        }
285        None
286    }
287
288    #[inline]
289    pub fn pop_last(&mut self) -> Option<V>
290    where
291        V: TryFrom<usize>,
292    {
293        for (i, word) in self.1.iter_mut().enumerate().rev() {
294            if !word.is_zero() {
295                let bit = Self::BITS_PER - 1 - word.leading_zeros() as usize;
296                let mask = T::one().unsigned_shl(bit as u32);
297                *word = *word & !mask;
298                let idx = i * Self::BITS_PER + bit;
299                let converted = V::try_from(idx);
300                debug_assert!(converted.is_ok());
301                return Some(match converted {
302                    Ok(value) => value,
303                    Err(_) => unsafe { core::hint::unreachable_unchecked() },
304                });
305            }
306        }
307        None
308    }
309
310    #[inline]
311    pub fn contains(&self, id: &V) -> bool
312    where
313        V: Copy + AsPrimitive<usize>,
314    {
315        let idx = (*id).as_();
316        debug_assert!(
317            idx < self.capacity(),
318            "index {idx} out of range for capacity {}",
319            self.capacity()
320        );
321        let (seg, mask) = Self::index_of(idx);
322        self.1.get(seg).is_some_and(|w| *w & mask != T::zero())
323    }
324
325    #[inline]
326    pub fn set(&mut self, id: V, value: bool)
327    where
328        V: AsPrimitive<usize>,
329    {
330        let idx = id.as_();
331        debug_assert!(
332            idx < self.capacity(),
333            "index {idx} out of range for capacity {}",
334            self.capacity()
335        );
336        let (seg, mask) = Self::index_of(idx);
337        if let Some(word) = self.1.get_mut(seg) {
338            if value {
339                *word = *word | mask;
340            } else {
341                *word = *word & !mask;
342            }
343        }
344    }
345
346    #[inline]
347    pub fn insert(&mut self, id: V) -> bool
348    where
349        V: AsPrimitive<usize>,
350    {
351        let idx = id.as_();
352        debug_assert!(
353            idx < self.capacity(),
354            "index {idx} out of range for capacity {}",
355            self.capacity()
356        );
357        let (seg, mask) = Self::index_of(idx);
358        let Some(word) = self.1.get_mut(seg) else {
359            return false;
360        };
361        let was_absent = *word & mask == T::zero();
362        *word = *word | mask;
363        was_absent
364    }
365
366    #[inline]
367    pub fn remove(&mut self, id: V) -> bool
368    where
369        V: AsPrimitive<usize>,
370    {
371        let idx = id.as_();
372        debug_assert!(
373            idx < self.capacity(),
374            "index {idx} out of range for capacity {}",
375            self.capacity()
376        );
377        let (seg, mask) = Self::index_of(idx);
378        let Some(word) = self.1.get_mut(seg) else {
379            return false;
380        };
381        let was_present = *word & mask != T::zero();
382        *word = *word & !mask;
383        was_present
384    }
385
386    #[inline]
387    pub fn toggle(&mut self, id: V)
388    where
389        V: AsPrimitive<usize>,
390    {
391        let idx = id.as_();
392        debug_assert!(
393            idx < self.capacity(),
394            "index {idx} out of range for capacity {}",
395            self.capacity()
396        );
397        let (seg, mask) = Self::index_of(idx);
398        if let Some(word) = self.1.get_mut(seg) {
399            *word = *word ^ mask;
400        }
401    }
402
403    #[inline]
404    pub fn clear(&mut self) {
405        self.1.fill(T::zero());
406    }
407
408    #[inline]
409    pub fn drain(&mut self) -> Drain<'_, T, V>
410    where
411        T: BitAndAssign,
412        V: TryFrom<usize>,
413    {
414        Drain {
415            words: &mut self.1,
416            word_idx: 0,
417            current: PrimBitSetIter::empty(),
418            _marker: PhantomData,
419        }
420    }
421
422    pub fn retain(&mut self, mut f: impl FnMut(V) -> bool)
423    where
424        V: TryFrom<usize>,
425    {
426        for (i, word) in self.1.iter_mut().enumerate() {
427            let mut w = *word;
428            while !w.is_zero() {
429                let bit = w.trailing_zeros() as usize;
430                let mask = T::one().unsigned_shl(bit as u32);
431                w = w & !mask;
432                let idx = i * Self::BITS_PER + bit;
433                let converted = V::try_from(idx);
434                debug_assert!(converted.is_ok());
435                let value = match converted {
436                    Ok(v) => v,
437                    Err(_) => unsafe { core::hint::unreachable_unchecked() },
438                };
439                if !f(value) {
440                    *word = *word & !mask;
441                }
442            }
443        }
444    }
445
446    pub fn union_from(&mut self, other: &Self) {
447        let min = self.1.len().min(other.1.len());
448        for i in 0..min {
449            self.1[i] = self.1[i] | other.1[i];
450        }
451    }
452
453    pub fn append(&mut self, other: &mut Self) {
454        let min = self.1.len().min(other.1.len());
455        for i in 0..min {
456            self.1[i] = self.1[i] | other.1[i];
457            other.1[i] = T::zero();
458        }
459    }
460
461    #[inline]
462    pub fn iter(&self) -> BitSliceIter<'_, T, V>
463    where
464        T: BitAndAssign,
465        V: TryFrom<usize>,
466    {
467        WordSetIter::new(&self.1)
468    }
469
470    #[inline]
471    pub fn is_subset(&self, other: &Self) -> bool {
472        let min = self.1.len().min(other.1.len());
473        self.1[..min]
474            .iter()
475            .zip(other.1[..min].iter())
476            .all(|(a, b)| *a & *b == *a)
477            && self.1[min..].iter().all(|w| w.is_zero())
478    }
479
480    #[inline]
481    pub fn is_superset(&self, other: &Self) -> bool {
482        other.is_subset(self)
483    }
484
485    #[inline]
486    pub fn is_disjoint(&self, other: &Self) -> bool {
487        self.1
488            .iter()
489            .zip(other.1.iter())
490            .all(|(a, b)| (*a & *b).is_zero())
491    }
492
493    fn word_op_iter<'a>(
494        a: &'a [T],
495        b: &'a [T],
496        len: usize,
497        op: impl Fn(T, T) -> T + 'a,
498    ) -> impl Iterator<Item = V> + 'a
499    where
500        T: BitAndAssign,
501        V: TryFrom<usize>,
502    {
503        let bits_per = Self::BITS_PER;
504        (0..len).flat_map(move |i| {
505            let w_a = a.get(i).copied().unwrap_or(T::zero());
506            let w_b = b.get(i).copied().unwrap_or(T::zero());
507            let combined = op(w_a, w_b);
508            let offset = i * bits_per;
509            PrimBitSetIter::<T, usize>(combined, PhantomData).map(move |pos| {
510                let idx = offset + pos;
511                debug_assert!(V::try_from(idx).is_ok());
512                match V::try_from(idx) {
513                    Ok(v) => v,
514                    Err(_) => unsafe { core::hint::unreachable_unchecked() },
515                }
516            })
517        })
518    }
519
520    #[inline]
521    pub fn difference<'a>(&'a self, other: &'a Self) -> impl Iterator<Item = V> + 'a
522    where
523        T: BitAndAssign,
524        V: TryFrom<usize>,
525    {
526        Self::word_op_iter(&self.1, &other.1, self.1.len(), |a, b| a & !b)
527    }
528
529    #[inline]
530    pub fn intersection<'a>(&'a self, other: &'a Self) -> impl Iterator<Item = V> + 'a
531    where
532        T: BitAndAssign,
533        V: TryFrom<usize>,
534    {
535        Self::word_op_iter(
536            &self.1,
537            &other.1,
538            self.1.len().min(other.1.len()),
539            |a, b| a & b,
540        )
541    }
542
543    #[inline]
544    pub fn union<'a>(&'a self, other: &'a Self) -> impl Iterator<Item = V> + 'a
545    where
546        T: BitAndAssign,
547        V: TryFrom<usize>,
548    {
549        Self::word_op_iter(
550            &self.1,
551            &other.1,
552            self.1.len().max(other.1.len()),
553            |a, b| a | b,
554        )
555    }
556
557    #[inline]
558    pub fn symmetric_difference<'a>(&'a self, other: &'a Self) -> impl Iterator<Item = V> + 'a
559    where
560        T: BitAndAssign,
561        V: TryFrom<usize>,
562    {
563        Self::word_op_iter(
564            &self.1,
565            &other.1,
566            self.1.len().max(other.1.len()),
567            |a, b| a ^ b,
568        )
569    }
570
571    // bitvec interop
572
573    #[cfg(feature = "bitvec")]
574    #[inline]
575    pub fn as_bitvec_slice(&self) -> &bitvec::slice::BitSlice<T, bitvec::order::Lsb0>
576    where
577        T: bitvec::store::BitStore,
578    {
579        bitvec::slice::BitSlice::from_slice(&self.1)
580    }
581
582    #[cfg(feature = "bitvec")]
583    #[inline]
584    pub fn as_mut_bitvec_slice(&mut self) -> &mut bitvec::slice::BitSlice<T, bitvec::order::Lsb0>
585    where
586        T: bitvec::store::BitStore,
587    {
588        bitvec::slice::BitSlice::from_slice_mut(&mut self.1)
589    }
590
591    /// Raw word slice accessor.
592    #[inline]
593    pub fn raw_words(&self) -> &[T] {
594        &self.1
595    }
596}
597
598impl<'a, T: PrimInt + BitAndAssign, V: TryFrom<usize>> IntoIterator for &'a BitSlice<T, V> {
599    type Item = V;
600    type IntoIter = BitSliceIter<'a, T, V>;
601
602    #[inline]
603    fn into_iter(self) -> Self::IntoIter {
604        self.iter()
605    }
606}
607
608impl<T: PrimInt, V> PartialEq for BitSlice<T, V> {
609    fn eq(&self, other: &Self) -> bool {
610        let min = self.1.len().min(other.1.len());
611        self.1[..min] == other.1[..min]
612            && self.1[min..].iter().all(|w| w.is_zero())
613            && other.1[min..].iter().all(|w| w.is_zero())
614    }
615}
616
617impl<T: PrimInt, V> Eq for BitSlice<T, V> {}
618
619impl<T: PrimInt + Hash, V> Hash for BitSlice<T, V> {
620    fn hash<H: Hasher>(&self, state: &mut H) {
621        // Hash only up to the last non-zero word for length-independent hashing
622        let effective_len = self
623            .1
624            .iter()
625            .rposition(|w| !w.is_zero())
626            .map_or(0, |i| i + 1);
627        for w in &self.1[..effective_len] {
628            w.hash(state);
629        }
630    }
631}
632
633impl<T: PrimInt + BitAndAssign, V> core::fmt::Debug for BitSlice<T, V> {
634    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
635        let bits_per = core::mem::size_of::<T>() * 8;
636        f.write_str("{")?;
637        let mut first = true;
638        for (i, &word) in self.1.iter().enumerate() {
639            let offset = i * bits_per;
640            for pos in PrimBitSetIter::<T, usize>(word, PhantomData) {
641                if !first {
642                    f.write_str(", ")?;
643                }
644                first = false;
645                write!(f, "{}", offset + pos)?;
646            }
647        }
648        f.write_str("}")
649    }
650}
651
652impl<T: PrimInt + BitAndAssign, V> core::fmt::Display for BitSlice<T, V> {
653    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
654        core::fmt::Debug::fmt(self, f)
655    }
656}