Skip to main content

bitflagset/
slice.rs

1use core::hash::{Hash, Hasher};
2use core::iter::FusedIterator;
3use core::marker::PhantomData;
4use core::ops::BitAndAssign;
5use num_traits::{AsPrimitive, PrimInt};
6
7use super::bitset::PrimBitSetIter;
8
9/// Iterator over set bit positions in a `BitSlice`.
10pub struct BitSliceIter<'a, T: PrimInt, V> {
11    words: &'a [T],
12    word_idx: usize,
13    current: PrimBitSetIter<T, usize>,
14    _marker: PhantomData<V>,
15}
16
17impl<T: PrimInt + BitAndAssign, V> BitSliceIter<'_, T, V> {
18    #[inline]
19    fn remaining_len(&self) -> usize {
20        self.current.len()
21            + self.words[self.word_idx..]
22                .iter()
23                .map(|w| w.count_ones() as usize)
24                .sum::<usize>()
25    }
26}
27
28impl<'a, T: PrimInt + BitAndAssign, V: TryFrom<usize>> Iterator for BitSliceIter<'a, T, V> {
29    type Item = V;
30
31    fn next(&mut self) -> Option<V> {
32        let bits_per = core::mem::size_of::<T>() * 8;
33        loop {
34            if let Some(pos) = self.current.next() {
35                let idx = (self.word_idx - 1) * bits_per + pos;
36                let converted = V::try_from(idx);
37                debug_assert!(converted.is_ok());
38                match converted {
39                    Ok(value) => return Some(value),
40                    Err(_) => unsafe { core::hint::unreachable_unchecked() },
41                }
42            }
43            if self.word_idx >= self.words.len() {
44                return None;
45            }
46            self.current = PrimBitSetIter::from_raw(self.words[self.word_idx]);
47            self.word_idx += 1;
48        }
49    }
50
51    #[inline]
52    fn size_hint(&self) -> (usize, Option<usize>) {
53        let len = self.remaining_len();
54        (len, Some(len))
55    }
56
57    #[inline]
58    fn count(self) -> usize
59    where
60        Self: Sized,
61    {
62        self.remaining_len()
63    }
64}
65
66impl<T: PrimInt + BitAndAssign, V: TryFrom<usize>> ExactSizeIterator for BitSliceIter<'_, T, V> {
67    #[inline]
68    fn len(&self) -> usize {
69        self.remaining_len()
70    }
71}
72
73impl<T: PrimInt + BitAndAssign, V: TryFrom<usize>> FusedIterator for BitSliceIter<'_, T, V> {}
74
/// Unsized shared base for all bitset types. Wraps a raw `[T]` primitive slice.
///
/// All operations use direct primitive bit manipulation (count_ones, bit masking, etc.),
/// not bitvec's generic algorithms.
///
/// Owned types (`BitSet`, `BoxedBitSet`) implement `Deref<Target = BitSlice<T, V>>`
/// so common methods are defined here once.
///
/// Bit `i` is stored in word `i / BITS` at bit offset `i % BITS`, counted from the
/// least-significant bit, where `BITS` is the bit width of `T`. `V` is the element
/// type that positions are converted to and from; it is only a marker and occupies
/// no space.
// Layout: field `.0` is the zero-sized `PhantomData<V>` marker and field `.1` is the
// word slice. `#[repr(transparent)]` gives `BitSlice<T, V>` the same layout as `[T]`,
// which the reference casts in `from_slice_ref` / `from_slice_mut` rely on. The
// unsized `[T]` field must stay last.
#[repr(transparent)]
pub struct BitSlice<T, V>(PhantomData<V>, [T]);
84
85impl<T, V> BitSlice<T, V> {
86    pub(crate) fn from_slice_ref(s: &[T]) -> &Self {
87        // SAFETY: BitSlice<T, V> is repr(transparent) over [T]
88        // (PhantomData<V> is ZST)
89        unsafe { &*(s as *const [T] as *const Self) }
90    }
91
92    pub(crate) fn from_slice_mut(s: &mut [T]) -> &mut Self {
93        // SAFETY: same layout guarantee
94        unsafe { &mut *(s as *mut [T] as *mut Self) }
95    }
96}
97
98impl<T: PrimInt, V> BitSlice<T, V> {
99    const BITS_PER: usize = core::mem::size_of::<T>() * 8;
100
101    #[inline]
102    fn index_of(idx: usize) -> (usize, T) {
103        (
104            idx / Self::BITS_PER,
105            T::one().unsigned_shl((idx % Self::BITS_PER) as u32),
106        )
107    }
108
109    #[inline]
110    pub fn capacity(&self) -> usize {
111        self.1.len() * Self::BITS_PER
112    }
113
114    #[inline]
115    pub fn len(&self) -> usize {
116        self.1.iter().map(|w| w.count_ones() as usize).sum()
117    }
118
119    #[inline]
120    pub fn is_empty(&self) -> bool {
121        self.1.iter().all(|w| w.is_zero())
122    }
123
124    #[inline]
125    pub fn first(&self) -> Option<V>
126    where
127        V: TryFrom<usize>,
128    {
129        for (i, &word) in self.1.iter().enumerate() {
130            if !word.is_zero() {
131                let bit = word.trailing_zeros() as usize;
132                let idx = i * Self::BITS_PER + bit;
133                let converted = V::try_from(idx);
134                debug_assert!(converted.is_ok());
135                return Some(match converted {
136                    Ok(value) => value,
137                    Err(_) => unsafe { core::hint::unreachable_unchecked() },
138                });
139            }
140        }
141        None
142    }
143
144    #[inline]
145    pub fn last(&self) -> Option<V>
146    where
147        V: TryFrom<usize>,
148    {
149        for (i, &word) in self.1.iter().enumerate().rev() {
150            if !word.is_zero() {
151                let bit = Self::BITS_PER - 1 - word.leading_zeros() as usize;
152                let idx = i * Self::BITS_PER + bit;
153                let converted = V::try_from(idx);
154                debug_assert!(converted.is_ok());
155                return Some(match converted {
156                    Ok(value) => value,
157                    Err(_) => unsafe { core::hint::unreachable_unchecked() },
158                });
159            }
160        }
161        None
162    }
163
164    #[inline]
165    pub fn pop_first(&mut self) -> Option<V>
166    where
167        V: TryFrom<usize>,
168    {
169        for (i, word) in self.1.iter_mut().enumerate() {
170            if !word.is_zero() {
171                let bit = word.trailing_zeros() as usize;
172                let mask = T::one().unsigned_shl(bit as u32);
173                *word = *word & !mask;
174                let idx = i * Self::BITS_PER + bit;
175                let converted = V::try_from(idx);
176                debug_assert!(converted.is_ok());
177                return Some(match converted {
178                    Ok(value) => value,
179                    Err(_) => unsafe { core::hint::unreachable_unchecked() },
180                });
181            }
182        }
183        None
184    }
185
186    #[inline]
187    pub fn pop_last(&mut self) -> Option<V>
188    where
189        V: TryFrom<usize>,
190    {
191        for (i, word) in self.1.iter_mut().enumerate().rev() {
192            if !word.is_zero() {
193                let bit = Self::BITS_PER - 1 - word.leading_zeros() as usize;
194                let mask = T::one().unsigned_shl(bit as u32);
195                *word = *word & !mask;
196                let idx = i * Self::BITS_PER + bit;
197                let converted = V::try_from(idx);
198                debug_assert!(converted.is_ok());
199                return Some(match converted {
200                    Ok(value) => value,
201                    Err(_) => unsafe { core::hint::unreachable_unchecked() },
202                });
203            }
204        }
205        None
206    }
207
208    #[inline]
209    pub fn contains(&self, id: &V) -> bool
210    where
211        V: Copy + AsPrimitive<usize>,
212    {
213        let idx = (*id).as_();
214        debug_assert!(
215            idx < self.capacity(),
216            "index {idx} out of range for capacity {}",
217            self.capacity()
218        );
219        let (seg, mask) = Self::index_of(idx);
220        self.1.get(seg).is_some_and(|w| *w & mask != T::zero())
221    }
222
223    #[inline]
224    pub fn set(&mut self, id: V, value: bool)
225    where
226        V: AsPrimitive<usize>,
227    {
228        let idx = id.as_();
229        debug_assert!(
230            idx < self.capacity(),
231            "index {idx} out of range for capacity {}",
232            self.capacity()
233        );
234        let (seg, mask) = Self::index_of(idx);
235        if let Some(word) = self.1.get_mut(seg) {
236            if value {
237                *word = *word | mask;
238            } else {
239                *word = *word & !mask;
240            }
241        }
242    }
243
244    #[inline]
245    pub fn insert(&mut self, id: V) -> bool
246    where
247        V: AsPrimitive<usize>,
248    {
249        let idx = id.as_();
250        debug_assert!(
251            idx < self.capacity(),
252            "index {idx} out of range for capacity {}",
253            self.capacity()
254        );
255        let (seg, mask) = Self::index_of(idx);
256        let Some(word) = self.1.get_mut(seg) else {
257            return false;
258        };
259        let was_absent = *word & mask == T::zero();
260        *word = *word | mask;
261        was_absent
262    }
263
264    #[inline]
265    pub fn remove(&mut self, id: V) -> bool
266    where
267        V: AsPrimitive<usize>,
268    {
269        let idx = id.as_();
270        debug_assert!(
271            idx < self.capacity(),
272            "index {idx} out of range for capacity {}",
273            self.capacity()
274        );
275        let (seg, mask) = Self::index_of(idx);
276        let Some(word) = self.1.get_mut(seg) else {
277            return false;
278        };
279        let was_present = *word & mask != T::zero();
280        *word = *word & !mask;
281        was_present
282    }
283
284    #[inline]
285    pub fn toggle(&mut self, id: V)
286    where
287        V: AsPrimitive<usize>,
288    {
289        let idx = id.as_();
290        debug_assert!(
291            idx < self.capacity(),
292            "index {idx} out of range for capacity {}",
293            self.capacity()
294        );
295        let (seg, mask) = Self::index_of(idx);
296        if let Some(word) = self.1.get_mut(seg) {
297            *word = *word ^ mask;
298        }
299    }
300
301    #[inline]
302    pub fn clear(&mut self) {
303        self.1.fill(T::zero());
304    }
305
306    pub fn retain(&mut self, mut f: impl FnMut(V) -> bool)
307    where
308        V: TryFrom<usize>,
309    {
310        for (i, word) in self.1.iter_mut().enumerate() {
311            let mut w = *word;
312            while !w.is_zero() {
313                let bit = w.trailing_zeros() as usize;
314                let mask = T::one().unsigned_shl(bit as u32);
315                w = w & !mask;
316                let idx = i * Self::BITS_PER + bit;
317                let converted = V::try_from(idx);
318                debug_assert!(converted.is_ok());
319                let value = match converted {
320                    Ok(v) => v,
321                    Err(_) => unsafe { core::hint::unreachable_unchecked() },
322                };
323                if !f(value) {
324                    *word = *word & !mask;
325                }
326            }
327        }
328    }
329
330    pub fn append(&mut self, other: &mut Self) {
331        let min = self.1.len().min(other.1.len());
332        for i in 0..min {
333            self.1[i] = self.1[i] | other.1[i];
334            other.1[i] = T::zero();
335        }
336    }
337
338    #[inline]
339    pub fn iter(&self) -> BitSliceIter<'_, T, V>
340    where
341        T: BitAndAssign,
342        V: TryFrom<usize>,
343    {
344        BitSliceIter {
345            words: &self.1,
346            word_idx: 0,
347            current: PrimBitSetIter::empty(),
348            _marker: PhantomData,
349        }
350    }
351
352    #[inline]
353    pub fn is_subset(&self, other: &Self) -> bool {
354        let min = self.1.len().min(other.1.len());
355        self.1[..min]
356            .iter()
357            .zip(other.1[..min].iter())
358            .all(|(a, b)| *a & *b == *a)
359            && self.1[min..].iter().all(|w| w.is_zero())
360    }
361
362    #[inline]
363    pub fn is_superset(&self, other: &Self) -> bool {
364        other.is_subset(self)
365    }
366
367    #[inline]
368    pub fn is_disjoint(&self, other: &Self) -> bool {
369        self.1
370            .iter()
371            .zip(other.1.iter())
372            .all(|(a, b)| (*a & *b).is_zero())
373    }
374
375    fn word_op_iter<'a>(
376        a: &'a [T],
377        b: &'a [T],
378        len: usize,
379        op: impl Fn(T, T) -> T + 'a,
380    ) -> impl Iterator<Item = V> + 'a
381    where
382        T: BitAndAssign,
383        V: TryFrom<usize>,
384    {
385        let bits_per = Self::BITS_PER;
386        (0..len).flat_map(move |i| {
387            let w_a = a.get(i).copied().unwrap_or(T::zero());
388            let w_b = b.get(i).copied().unwrap_or(T::zero());
389            let combined = op(w_a, w_b);
390            let offset = i * bits_per;
391            PrimBitSetIter::<T, usize>(combined, PhantomData).map(move |pos| {
392                let idx = offset + pos;
393                debug_assert!(V::try_from(idx).is_ok());
394                match V::try_from(idx) {
395                    Ok(v) => v,
396                    Err(_) => unsafe { core::hint::unreachable_unchecked() },
397                }
398            })
399        })
400    }
401
402    #[inline]
403    pub fn difference<'a>(&'a self, other: &'a Self) -> impl Iterator<Item = V> + 'a
404    where
405        T: BitAndAssign,
406        V: TryFrom<usize>,
407    {
408        Self::word_op_iter(&self.1, &other.1, self.1.len(), |a, b| a & !b)
409    }
410
411    #[inline]
412    pub fn intersection<'a>(&'a self, other: &'a Self) -> impl Iterator<Item = V> + 'a
413    where
414        T: BitAndAssign,
415        V: TryFrom<usize>,
416    {
417        Self::word_op_iter(
418            &self.1,
419            &other.1,
420            self.1.len().min(other.1.len()),
421            |a, b| a & b,
422        )
423    }
424
425    #[inline]
426    pub fn union<'a>(&'a self, other: &'a Self) -> impl Iterator<Item = V> + 'a
427    where
428        T: BitAndAssign,
429        V: TryFrom<usize>,
430    {
431        Self::word_op_iter(
432            &self.1,
433            &other.1,
434            self.1.len().max(other.1.len()),
435            |a, b| a | b,
436        )
437    }
438
439    #[inline]
440    pub fn symmetric_difference<'a>(&'a self, other: &'a Self) -> impl Iterator<Item = V> + 'a
441    where
442        T: BitAndAssign,
443        V: TryFrom<usize>,
444    {
445        Self::word_op_iter(
446            &self.1,
447            &other.1,
448            self.1.len().max(other.1.len()),
449            |a, b| a ^ b,
450        )
451    }
452
453    // bitvec interop
454
455    #[cfg(feature = "bitvec")]
456    #[inline]
457    pub fn as_bitvec_slice(&self) -> &bitvec::slice::BitSlice<T, bitvec::order::Lsb0>
458    where
459        T: bitvec::store::BitStore,
460    {
461        bitvec::slice::BitSlice::from_slice(&self.1)
462    }
463
464    #[cfg(feature = "bitvec")]
465    #[inline]
466    pub fn as_mut_bitvec_slice(&mut self) -> &mut bitvec::slice::BitSlice<T, bitvec::order::Lsb0>
467    where
468        T: bitvec::store::BitStore,
469    {
470        bitvec::slice::BitSlice::from_slice_mut(&mut self.1)
471    }
472
473    /// Raw word slice accessor.
474    #[inline]
475    pub fn raw_words(&self) -> &[T] {
476        &self.1
477    }
478}
479
480impl<'a, T: PrimInt + BitAndAssign, V: TryFrom<usize>> IntoIterator for &'a BitSlice<T, V> {
481    type Item = V;
482    type IntoIter = BitSliceIter<'a, T, V>;
483
484    #[inline]
485    fn into_iter(self) -> Self::IntoIter {
486        self.iter()
487    }
488}
489
490impl<T: PrimInt, V> PartialEq for BitSlice<T, V> {
491    fn eq(&self, other: &Self) -> bool {
492        let min = self.1.len().min(other.1.len());
493        self.1[..min] == other.1[..min]
494            && self.1[min..].iter().all(|w| w.is_zero())
495            && other.1[min..].iter().all(|w| w.is_zero())
496    }
497}
498
499impl<T: PrimInt, V> Eq for BitSlice<T, V> {}
500
501impl<T: PrimInt + Hash, V> Hash for BitSlice<T, V> {
502    fn hash<H: Hasher>(&self, state: &mut H) {
503        // Hash only up to the last non-zero word for length-independent hashing
504        let effective_len = self
505            .1
506            .iter()
507            .rposition(|w| !w.is_zero())
508            .map_or(0, |i| i + 1);
509        for w in &self.1[..effective_len] {
510            w.hash(state);
511        }
512    }
513}
514
515impl<T: PrimInt + BitAndAssign, V> core::fmt::Debug for BitSlice<T, V> {
516    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
517        let bits_per = core::mem::size_of::<T>() * 8;
518        f.write_str("{")?;
519        let mut first = true;
520        for (i, &word) in self.1.iter().enumerate() {
521            let offset = i * bits_per;
522            for pos in PrimBitSetIter::<T, usize>(word, PhantomData) {
523                if !first {
524                    f.write_str(", ")?;
525                }
526                first = false;
527                write!(f, "{}", offset + pos)?;
528            }
529        }
530        f.write_str("}")
531    }
532}