cbitset/
lib.rs

1#![no_std]
2
3extern crate num_traits;
4
5use core::{
6    iter::{FromIterator, FusedIterator, ExactSizeIterator},
7    mem,
8    ops::{Bound, RangeBounds},
9};
10use num_traits::{Bounded, PrimInt, Zero, One};
11
12/// An internal trait used to bypass the fact that rust does not yet
13/// have const generics
14pub trait BitArray: Default + Clone + Copy {
15    /// The item type this array holds
16    type Item: Default + PrimInt;
17    /// Returns how many elements this array can hold
18    fn len() -> usize;
19    /// Access the element at a specified index.
20    ///
21    /// # Panics
22    /// Panics if the index is bigger than the length
23    fn get(&self, index: usize) -> Self::Item;
24    /// Access a mutable reference to the element at a specified index
25    ///
26    /// # Panics
27    /// Panics if the index is bigger than the length
28    fn get_mut(&mut self, index: usize) -> &mut Self::Item;
29}
30
31macro_rules! impl_arrays {
32    ($($len:expr),*) => {
33        $(
34            impl<T: Default + PrimInt> BitArray for [T; $len] {
35                type Item = T;
36
37                fn len() -> usize { $len }
38                fn get(&self, index: usize) -> Self::Item {
39                    self[index]
40                }
41                fn get_mut(&mut self, index: usize) -> &mut Self::Item {
42                    &mut self[index]
43                }
44            }
45        )*
46    }
47}
48
49impl_arrays!(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16);
50
51/// A bit set able to hold up to 8 elements
52pub type BitSet8 = BitSet<[u8; 1]>;
53/// A bit set able to hold up to 16 elements
54pub type BitSet16 = BitSet<[u16; 1]>;
55/// A bit set able to hold up to 32 elements
56pub type BitSet32 = BitSet<[u32; 1]>;
57/// A bit set able to hold up to 64 elements
58pub type BitSet64 = BitSet<[u64; 1]>;
59/// A bit set able to hold up to 128 elements
60pub type BitSet128 = BitSet<[u64; 2]>;
61/// A bit set able to hold up to 256 elements
62pub type BitSet256 = BitSet<[u64; 4]>;
63/// A bit set able to hold up to 512 elements
64pub type BitSet512 = BitSet<[u64; 8]>;
65/// A bit set able to hold up to 1024 elements
66pub type BitSet1024 = BitSet<[u64; 16]>;
67
68/// The bit set itself
69///
70/// This wrapper is `#![repr(transparent)]` and guaranteed to have the same memory
71/// representation as the inner bit array
72///
73/// # Panics
74/// All non-try functions taking a bit parameter panics if the bit is bigger
75/// than the capacity of the set. For non-panicking versions, use `try_`.
76#[repr(transparent)]
77#[derive(Default, Clone, Copy)]
78pub struct BitSet<T: BitArray> {
79    inner: T
80}
81impl<T: BitArray> From<T> for BitSet<T> {
82    fn from(inner: T) -> Self {
83        Self { inner }
84    }
85}
86impl<T: BitArray> BitSet<T> {
87    /// Create an empty instance
88    pub fn new() -> Self {
89        Self::default()
90    }
91    /// Transmutes a reference to a borrowed bit array to a borrowed BitSet
92    /// with the same lifetime
93    pub fn from_ref(inner: &mut T) -> &mut Self {
94        // This should be completely safe as the memory representation is the
95        // same
96        unsafe { mem::transmute(inner) }
97    }
98    /// Return the inner integer array
99    pub fn into_inner(self) -> T {
100        self.inner
101    }
102    /// Returns the capacity of the set, in other words how many bits it can
103    /// hold. This function may very well overflow if the size or length is too
104    /// big, but if you're making that big allocations you probably got bigger
105    /// things to worry about.
106    pub fn capacity() -> usize {
107        T::len() * Self::item_size()
108    }
109
110    /// Returns the bit size of each item
111    fn item_size() -> usize {
112        mem::size_of::<T::Item>() * 8
113    }
114    /// Returns slot index along with the bitmask for the bit
115    /// index to the slot this item was in
116    fn location(bit: usize) -> (usize, T::Item) {
117        let index = bit / Self::item_size();
118        let bitmask = T::Item::one() << (bit & Self::item_size() - 1);
119        (index, bitmask)
120    }
121
122    /// Enable the specified bit in the set. If the bit is already
123    /// enabled this is a no-op.
124    pub fn insert(&mut self, bit: usize) {
125        assert!(self.try_insert(bit), "BitSet::insert called on an integer bigger than capacity");
126    }
127    /// Like `insert`, but does not panic if the bit is too large. See
128    /// the struct level documentation for notes on panicking.
129    pub fn try_insert(&mut self, bit: usize) -> bool {
130        if bit >= Self::capacity() {
131            return false;
132        }
133        let (index, bitmask) = Self::location(bit);
134        *self.inner.get_mut(index) = self.inner.get(index) | bitmask;
135        true
136    }
137    /// Disable the specified bit in the set. If the bit is already
138    /// disabled this is a no-op.
139    pub fn remove(&mut self, bit: usize) {
140        assert!(self.try_remove(bit), "BitSet::remove called on an integer bigger than capacity");
141    }
142    /// Like `remove`, but does not panic if the bit is too large.
143    /// See the struct level documentation for notes on panicking.
144    pub fn try_remove(&mut self, bit: usize) -> bool {
145        if bit >= Self::capacity() {
146            return false;
147        }
148        let (index, bitmask) = Self::location(bit);
149        *self.inner.get_mut(index) = self.inner.get(index) & !bitmask;
150        true
151    }
152    /// Returns true if the specified bit is enabled. If the bit is
153    /// out of bounds this silently returns false.
154    pub fn contains(&self, bit: usize) -> bool {
155        self.try_contains(bit).unwrap_or(false)
156    }
157    /// Returns true if the specified bit is enabled
158    pub fn try_contains(&self, bit: usize) -> Option<bool> {
159        if bit >= Self::capacity() {
160            return None;
161        }
162
163        let (index, bitmask) = Self::location(bit);
164        Some(self.inner.get(index) & bitmask == bitmask)
165    }
166
167    /// Returns the total number of enabled bits
168    pub fn count_ones(&self) -> u32 {
169        let mut total = 0;
170        for i in 0..T::len() {
171            total += self.inner.get(i).count_ones();
172        }
173        total
174    }
175    /// Returns the total number of disabled bits
176    pub fn count_zeros(&self) -> u32 {
177        let mut total = 0;
178        for i in 0..T::len() {
179            total += self.inner.get(i).count_zeros();
180        }
181        total
182    }
183
184    /// Disable all bits
185    pub fn clear(&mut self) {
186        for i in 0..T::len() {
187            *self.inner.get_mut(i) = Default::default();
188        }
189    }
190    /// Set all bits in a range.
191    /// `fill(.., false)` is effectively the same as `clear()`.
192    ///
193    /// # Panics
194    /// Panics if the start or end bounds are more than the capacity.
195    pub fn fill<R: RangeBounds<usize>>(&mut self, range: R, on: bool) {
196        let mut start = match range.start_bound() {
197            Bound::Unbounded => 0,
198            Bound::Included(&i) => {
199                assert!(i <= Self::capacity(), "start bound is too big for capacity");
200                i
201            },
202            Bound::Excluded(&i) => {
203                assert!(i + 1 <= Self::capacity(), "start bound is too big for capacity");
204                i + 1
205            }
206        };
207        let end = match range.end_bound() {
208            Bound::Unbounded => Self::capacity(),
209            Bound::Included(0) => return,
210            Bound::Included(&i) => {
211                assert!(i - 1 <= Self::capacity(), "end bound is too big for capacity");
212                i - 1
213            },
214            Bound::Excluded(&i) => {
215                assert!(i <= Self::capacity(), "end bound is too big for capacity");
216                i
217            }
218        };
219
220        if start >= end {
221            return;
222        }
223
224        let end_first = start - (start % Self::item_size()) + Self::item_size();
225        if start % Self::item_size() != 0 || end < end_first {
226            // Unaligned write to either the end or the start of next integer
227            let end_first = end_first.min(end);
228            // println!("Doing initial unaligned from {} to {}", start, end_first);
229            for bit in start..end_first {
230                if on { self.insert(bit); } else { self.remove(bit); }
231            }
232
233            if end == end_first {
234                return;
235            }
236
237            start = end_first + 1;
238        }
239
240        // Fast way to fill all bits in all integers: Just set them to the min/max value.
241        let start_last = end - (end % Self::item_size());
242        // println!("Doing aligned from {} to {}", start, start_last);
243        for i in start / Self::item_size()..start_last / Self::item_size() {
244            *self.inner.get_mut(i) = if on { Bounded::max_value() } else { Default::default() };
245        }
246
247        // Unaligned write to the end
248        // println!("Doing unaligned from {} to {}", start_last, end);
249        for bit in start_last..end {
250            if on { self.insert(bit); } else { self.remove(bit); }
251        }
252    }
253}
254impl<T: BitArray, N: Into<usize>> FromIterator<N> for BitSet<T> {
255    fn from_iter<I>(iter: I) -> Self
256        where I: IntoIterator<Item = N>
257    {
258        let mut set = BitSet::new();
259        for bit in iter.into_iter() {
260            set.insert(bit.into());
261        }
262        set
263    }
264}
265impl<T: BitArray> Iterator for BitSet<T> {
266    type Item = usize;
267
268    /// Iterator implementation for BitSet, guaranteed to remove and
269    /// return the items in ascending order
270    fn next(&mut self) -> Option<Self::Item> {
271        for index in 0..T::len() {
272            let item = self.inner.get_mut(index);
273            if !item.is_zero() {
274                let bitindex = item.trailing_zeros() as usize;
275
276                // E.g. 1010 & 1001 = 1000
277                *item = *item & *item - T::Item::one();
278
279                // Safe from overflows because one couldn't possibly add an item with this index if it did overflow
280                return Some(index * Self::item_size() + bitindex);
281            }
282        }
283        None
284    }
285    fn size_hint(&self) -> (usize, Option<usize>) {
286        let len = self.count_ones() as usize;
287        (len, Some(len))
288    }
289}
290impl<T: BitArray> DoubleEndedIterator for BitSet<T> {
291    /// Reversed iterator implementation for BitSet, guaranteed to
292    /// remove and return the items in descending order
293    fn next_back(&mut self) -> Option<Self::Item> {
294        for index in (0..T::len()).rev() {
295            let item = self.inner.get_mut(index);
296            if !item.is_zero() {
297                let bitindex = Self::item_size() - 1 - item.leading_zeros() as usize;
298
299                // E.g. 00101 & 11011 = 00001, same as remove procedure but using relative index
300                *item = *item & !(T::Item::one() << bitindex);
301
302                // Safe from overflows because one couldn't possibly add an item with this index if it did overflow
303                return Some(index * Self::item_size() + bitindex);
304            }
305        }
306        None
307    }
308}
309impl<T: BitArray> FusedIterator for BitSet<T> {}
310impl<T: BitArray> ExactSizeIterator for BitSet<T> {}
311
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn repr() {
        assert_eq!(mem::size_of::<BitSet8>(), 1);
        assert_eq!(mem::size_of::<BitSet16>(), 2);
        assert_eq!(mem::size_of::<BitSet32>(), 4);
        assert_eq!(mem::size_of::<BitSet64>(), 8);
        assert_eq!(mem::size_of::<BitSet128>(), 16);
        assert_eq!(mem::size_of::<BitSet256>(), 32);
        assert_eq!(mem::size_of::<BitSet512>(), 64);
        assert_eq!(mem::size_of::<BitSet1024>(), 128);
    }
    #[test]
    fn capacity() {
        assert_eq!(BitSet8::capacity(), 8);
        assert_eq!(BitSet16::capacity(), 16);
        assert_eq!(BitSet32::capacity(), 32);
        assert_eq!(BitSet64::capacity(), 64);
        assert_eq!(BitSet128::capacity(), 128);
        assert_eq!(BitSet256::capacity(), 256);
        assert_eq!(BitSet512::capacity(), 512);
        assert_eq!(BitSet1024::capacity(), 1024);
    }
    #[test]
    fn try_too_big() {
        let mut set = BitSet8::new();
        assert!(!set.try_insert(8));
    }
    #[test]
    #[should_panic]
    fn panic_too_big() {
        let mut set = BitSet128::new();
        set.insert(128);
    }
    #[test]
    fn insert() {
        let mut set = BitSet128::new();
        set.insert(0);
        set.insert(12);
        set.insert(67);
        set.insert(82);
        set.insert(127);
        assert!(set.contains(0));
        assert!(set.contains(12));
        assert!(!set.contains(51));
        assert!(!set.contains(63));
        assert!(set.contains(67));
        assert!(!set.contains(73));
        assert!(set.contains(82));
        assert!(set.contains(127));
    }
    #[test]
    fn remove() {
        let mut set = BitSet32::new();
        set.insert(12);
        set.insert(17);
        assert!(set.contains(12));
        assert!(set.contains(17));
        set.remove(17);
        assert!(set.contains(12));
        assert!(!set.contains(17));
    }
    #[test]
    fn clear() {
        let mut set = BitSet64::new();
        set.insert(35);
        set.insert(42);
        assert!(set.contains(35));
        assert!(set.contains(42));
        set.clear();
        assert!(!set.contains(35));
        assert!(!set.contains(42));
    }
    #[test]
    fn count_ones_and_zeros() {
        let mut set = BitSet8::new();
        set.insert(5);
        set.insert(7);
        assert_eq!(set.count_ones(), 2);
        assert_eq!(set.count_zeros(), 8 - 2);
    }
    #[test]
    fn fill() {
        // Care must be taken when changing the `fill` function, as this test
        // won't detect if it actually does as many aligned writes as it can.

        let mut set = BitSet::<[u8; 2]>::new();

        // Aligned
        set.fill(.., true);
        for i in 0..16 {
            assert!(set.contains(i));
        }

        // Within the same int
        set.clear();
        set.fill(1..3, true);
        assert!(!set.contains(0));
        assert!(set.contains(1));
        assert!(set.contains(2));
        assert!(!set.contains(3));

        // Unaligned end
        set.clear();
        set.fill(8..10, true);
        assert!(!set.contains(7));
        assert!(set.contains(8));
        assert!(set.contains(9));
        assert!(!set.contains(10));

        // Unaligned start
        set.clear();
        set.fill(3..16, true);
        assert!(!set.contains(2));
        for i in 3..16 {
            assert!(set.contains(i));
        }
    }
    #[test]
    fn fill_across_items() {
        // Clearing a range with a ragged head, one whole element and a ragged tail.
        let mut set = BitSet::<[u8; 3]>::new();
        set.fill(.., true);
        set.fill(5..19, false);
        for i in 0..24 {
            assert_eq!(set.contains(i), i < 5 || i >= 19, "bit {}", i);
        }
    }
    #[test]
    fn fill_empty_range() {
        let mut set = BitSet16::new();
        set.fill(4..4, true);
        assert_eq!(set.count_ones(), 0);
    }
    #[test]
    fn from_ref_aliases_array() {
        let mut raw = [0u8; 2];
        BitSet::from_ref(&mut raw).insert(9);
        assert_eq!(raw, [0, 2]);
    }
    #[test]
    fn iter() {
        let mut set: BitSet<[u8; 4]> = [30u8, 0, 4, 2, 12, 22, 23, 29].iter().cloned().collect();
        assert_eq!(set.len(), 8); assert_eq!(set.next(), Some(0));
        assert_eq!(set.len(), 7); assert_eq!(set.next_back(), Some(30));
        assert_eq!(set.len(), 6); assert_eq!(set.next(), Some(2));
        assert_eq!(set.len(), 5); assert_eq!(set.next_back(), Some(29));
        assert_eq!(set.len(), 4); assert_eq!(set.next(), Some(4));
        assert_eq!(set.len(), 3); assert_eq!(set.next_back(), Some(23));
        assert_eq!(set.len(), 2); assert_eq!(set.next(), Some(12));
        assert_eq!(set.len(), 1); assert_eq!(set.next_back(), Some(22));
        assert_eq!(set.len(), 0); assert_eq!(set.next_back(), None);
        assert_eq!(set.len(), 0); assert_eq!(set.next(), None);
    }
}