ez_bitset/
bitset.rs

1use bitvec::prelude::*;
2use num::{NumCast, ToPrimitive, Unsigned};
3use std::cmp::Ordering;
4use std::fmt::{Debug, Display, Formatter};
5use std::hash::{Hash, Hasher};
6use std::ops::{AddAssign, Div, Index};
7use std::{fmt, iter};
8
9#[derive(Default)]
10pub struct BitSet {
11    cardinality: usize,
12    bit_vec: BitVec,
13}
14
15impl Clone for BitSet {
16    fn clone(&self) -> Self {
17        // it's quite common to write vec![BitSet::new(n), n] which is quite expensive
18        // if done by actually copying the BitSet. The following heuristic causes a massive
19        // speed-up in these situations.
20        if self.empty() {
21            Self::new(self.len())
22        } else {
23            Self {
24                cardinality: self.cardinality,
25                bit_vec: self.bit_vec.clone(),
26            }
27        }
28    }
29}
30
31impl Ord for BitSet {
32    fn cmp(&self, other: &Self) -> Ordering {
33        self.bit_vec.cmp(&other.bit_vec)
34    }
35}
36
37impl PartialOrd for BitSet {
38    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
39        self.bit_vec.partial_cmp(&other.bit_vec)
40    }
41}
42
43impl Debug for BitSet {
44    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
45        let values: Vec<_> = self.iter().map(|i| i.to_string()).collect();
46        write!(
47            f,
48            "BitSet {{ cardinality: {}, bit_vec: [{}]}}",
49            self.cardinality,
50            values.join(", "),
51        )
52    }
53}
54
55impl PartialEq for BitSet {
56    fn eq(&self, other: &Self) -> bool {
57        self.cardinality == other.cardinality && self.bit_vec == other.bit_vec
58    }
59}
60impl Eq for BitSet {}
61
62impl Hash for BitSet {
63    fn hash<H: Hasher>(&self, state: &mut H) {
64        self.bit_vec.hash(state)
65    }
66}
67
/// Word-wise subset test: returns `true` iff every bit set in `a` is also set
/// in `b`.
///
/// The slices may differ in length. Words missing from `b` are treated as zero,
/// so any non-zero word of `a` beyond the end of `b` makes the result `false`;
/// surplus words of `b` are irrelevant.
#[inline]
fn subset_helper(a: &[usize], b: &[usize]) -> bool {
    // Pad `b` with zero words so that zipping always walks all of `a`.
    let b_padded = b.iter().chain(iter::repeat(&0usize));
    a.iter()
        .zip(b_padded)
        .all(|(&lhs, &rhs)| (lhs & !rhs) == 0)
}
81
/// Number of bits in one backing storage word (`usize`).
const fn block_size() -> usize {
    usize::BITS as usize
}
85
86impl BitSet {
87    #[inline]
88    pub fn new(size: usize) -> Self {
89        let mut bit_vec: BitVec = BitVec::with_capacity(size);
90        unsafe {
91            bit_vec.set_len(size);
92        }
93        for i in bit_vec.as_raw_mut_slice() {
94            *i = 0;
95        }
96        Self {
97            cardinality: 0,
98            bit_vec,
99        }
100    }
101
102    pub fn from_bitvec(bit_vec: BitVec) -> Self {
103        let cardinality = bit_vec.iter().filter(|b| **b).count();
104        Self {
105            cardinality,
106            bit_vec,
107        }
108    }
109
110    pub fn from_slice<T: Div<Output = T> + ToPrimitive + AddAssign + Default + Copy + Display>(
111        size: usize,
112        slice: &[T],
113    ) -> Self {
114        let mut bit_vec: BitVec = BitVec::with_capacity(size);
115        unsafe {
116            bit_vec.set_len(size);
117        }
118        slice.iter().for_each(|i| {
119            bit_vec.set(NumCast::from(*i).unwrap(), true);
120        });
121        let cardinality = slice.len();
122        Self {
123            cardinality,
124            bit_vec,
125        }
126    }
127
128    #[inline]
129    pub fn empty(&self) -> bool {
130        self.cardinality == 0
131    }
132
133    #[inline]
134    pub fn full(&self) -> bool {
135        self.cardinality == self.bit_vec.len()
136    }
137
138    pub fn new_all_set(size: usize) -> Self {
139        let mut bit_vec: BitVec = BitVec::with_capacity(size);
140        unsafe {
141            bit_vec.set_len(size);
142        }
143        for i in bit_vec.as_raw_mut_slice() {
144            *i = usize::MAX;
145        }
146        Self {
147            cardinality: size,
148            bit_vec,
149        }
150    }
151
152    pub fn new_all_set_but<T, I>(size: usize, bits_unset: I) -> Self
153    where
154        I: IntoIterator<Item = T>,
155        T: Unsigned + ToPrimitive,
156    {
157        let mut bs = BitSet::new_all_set(size);
158        for i in bits_unset {
159            bs.unset_bit(i.to_usize().unwrap());
160        }
161        bs
162    }
163
164    pub fn new_all_unset_but<T, I>(size: usize, bits_set: I) -> Self
165    where
166        I: IntoIterator<Item = T>,
167        T: Unsigned + ToPrimitive,
168    {
169        let mut bs = BitSet::new(size);
170        for i in bits_set {
171            bs.set_bit(i.to_usize().unwrap());
172        }
173        bs
174    }
175
176    #[inline]
177    pub fn is_disjoint_with(&self, other: &BitSet) -> bool {
178        !self
179            .bit_vec
180            .as_raw_slice()
181            .iter()
182            .zip(other.as_slice().iter())
183            .any(|(x, y)| *x ^ *y != *x | *y)
184    }
185
186    #[inline]
187    pub fn intersects_with(&self, other: &BitSet) -> bool {
188        !self.is_disjoint_with(other)
189    }
190
191    #[inline]
192    pub fn is_subset_of(&self, other: &BitSet) -> bool {
193        self.cardinality <= other.cardinality
194            && subset_helper(self.bit_vec.as_raw_slice(), other.as_slice())
195    }
196
197    #[inline]
198    pub fn is_superset_of(&self, other: &BitSet) -> bool {
199        other.is_subset_of(self)
200    }
201
202    #[inline]
203    pub fn as_slice(&self) -> &[usize] {
204        self.bit_vec.as_raw_slice()
205    }
206
207    #[inline]
208    pub fn as_bitslice(&self) -> &BitSlice {
209        self.bit_vec.as_bitslice()
210    }
211
212    #[inline]
213    pub fn as_bit_vec(&self) -> &BitVec {
214        &self.bit_vec
215    }
216
217    #[inline]
218    pub fn set_bit(&mut self, idx: usize) -> bool {
219        if !*self.bit_vec.get(idx).unwrap() {
220            self.bit_vec.set(idx, true);
221            self.cardinality += 1;
222            false
223        } else {
224            true
225        }
226    }
227
228    #[inline]
229    pub fn unset_bit(&mut self, idx: usize) -> bool {
230        if *self.bit_vec.get(idx).unwrap() {
231            self.bit_vec.set(idx, false);
232            self.cardinality -= 1;
233            true
234        } else {
235            false
236        }
237    }
238
239    #[inline]
240    pub fn cardinality(&self) -> usize {
241        self.cardinality
242    }
243
244    #[inline]
245    pub fn len(&self) -> usize {
246        self.bit_vec.len()
247    }
248
249    #[inline]
250    pub fn is_empty(&self) -> bool {
251        self.bit_vec.is_empty()
252    }
253
254    #[inline]
255    pub fn or(&mut self, other: &BitSet) {
256        if other.len() > self.bit_vec.len() {
257            self.bit_vec.resize(other.len(), false);
258        }
259        for (x, y) in self
260            .bit_vec
261            .as_raw_mut_slice()
262            .iter_mut()
263            .zip(other.as_slice().iter())
264        {
265            *x |= y;
266        }
267        self.cardinality = self.bit_vec.count_ones();
268    }
269
270    #[inline]
271    pub fn resize(&mut self, size: usize) {
272        let old_size = self.bit_vec.len();
273        self.bit_vec.resize(size, false);
274        if size < old_size {
275            self.cardinality = self.bit_vec.count_ones();
276        }
277    }
278
279    #[inline]
280    pub fn and(&mut self, other: &BitSet) {
281        for (x, y) in self
282            .bit_vec
283            .as_raw_mut_slice()
284            .iter_mut()
285            .zip(other.as_slice().iter())
286        {
287            *x &= y;
288        }
289        self.cardinality = self.bit_vec.count_ones();
290    }
291
292    #[inline]
293    pub fn and_not(&mut self, other: &BitSet) {
294        for (x, y) in self
295            .bit_vec
296            .as_raw_mut_slice()
297            .iter_mut()
298            .zip(other.as_slice().iter())
299        {
300            *x &= !y;
301        }
302        self.cardinality = self.bit_vec.count_ones();
303    }
304
305    #[inline]
306    pub fn not(&mut self) {
307        self.bit_vec
308            .as_raw_mut_slice()
309            .iter_mut()
310            .for_each(|x| *x = !*x);
311        self.cardinality = self.bit_vec.count_ones();
312    }
313
314    #[inline]
315    pub fn unset_all(&mut self) {
316        self.bit_vec
317            .as_raw_mut_slice()
318            .iter_mut()
319            .for_each(|x| *x = 0);
320        self.cardinality = 0;
321    }
322
323    #[inline]
324    pub fn set_all(&mut self) {
325        self.bit_vec
326            .as_raw_mut_slice()
327            .iter_mut()
328            .for_each(|x| *x = std::usize::MAX);
329        self.cardinality = self.bit_vec.len();
330    }
331
332    #[inline]
333    pub fn has_smaller(&mut self, other: &BitSet) -> Option<bool> {
334        let self_idx = self.get_first_set()?;
335        let other_idx = other.get_first_set()?;
336        Some(self_idx < other_idx)
337    }
338
339    #[inline]
340    pub fn get_first_set(&self) -> Option<usize> {
341        if self.cardinality != 0 {
342            return self.get_next_set(0);
343        }
344        None
345    }
346
347    #[inline]
348    pub fn get_next_set(&self, idx: usize) -> Option<usize> {
349        if idx >= self.bit_vec.len() {
350            return None;
351        }
352        let mut block_idx = idx / block_size();
353        let word_idx = idx % block_size();
354        let mut block = self.bit_vec.as_raw_slice()[block_idx];
355        let max = self.bit_vec.as_raw_slice().len();
356        block &= usize::MAX << word_idx;
357        while block == 0usize {
358            block_idx += 1;
359            if block_idx >= max {
360                return None;
361            }
362            block = self.bit_vec.as_raw_slice()[block_idx];
363        }
364        let v = block_idx * block_size() + block.trailing_zeros() as usize;
365        if v >= self.bit_vec.len() {
366            None
367        } else {
368            Some(v)
369        }
370    }
371
372    #[inline]
373    pub fn get_first_unset(&self) -> Option<usize> {
374        if self.cardinality != self.len() {
375            return self.get_next_unset(0);
376        }
377        None
378    }
379
380    #[inline]
381    pub fn get_next_unset(&self, idx: usize) -> Option<usize> {
382        if idx >= self.bit_vec.len() {
383            return None;
384        }
385        let mut block_idx = idx / block_size();
386        let word_idx = idx % block_size();
387        let mut block = self.bit_vec.as_raw_slice()[block_idx];
388        let max = self.bit_vec.as_raw_slice().len();
389        block |= (1 << word_idx) - 1;
390        while block == usize::MAX {
391            block_idx += 1;
392            if block_idx >= max {
393                return None;
394            }
395            block = self.bit_vec.as_raw_slice()[block_idx];
396        }
397        let v = block_idx * block_size() + block.trailing_ones() as usize;
398        if v >= self.bit_vec.len() {
399            None
400        } else {
401            Some(v)
402        }
403    }
404
405    #[inline]
406    pub fn to_vec(&self) -> Vec<u32> {
407        let mut tmp = Vec::with_capacity(self.cardinality);
408        for (i, _) in self
409            .bit_vec
410            .as_bitslice()
411            .iter()
412            .enumerate()
413            .filter(|(_, x)| **x)
414        {
415            tmp.push(i as u32);
416        }
417        tmp
418    }
419
420    #[inline]
421    pub fn at(&self, idx: usize) -> bool {
422        self.bit_vec[idx]
423    }
424
425    #[inline]
426    pub fn iter(&self) -> BitSetIterator {
427        BitSetIterator {
428            iter: self.bit_vec.as_raw_slice().iter(),
429            block: 0,
430            idx: 0,
431            size: self.bit_vec.len(),
432        }
433    }
434}
435
/// Iterator over the indices of the set bits of a [`BitSet`], in ascending order.
///
/// It walks the backing words one at a time and peels set bits off the current word.
pub struct BitSetIterator<'a> {
    /// Backing words that have not been loaded yet.
    iter: std::slice::Iter<'a, usize>,
    /// Not-yet-reported bits of the current word, shifted right past everything
    /// that was already consumed; `0` means a new word has to be loaded.
    block: usize,
    /// Position bookkeeping: one past the last reported index, re-aligned to a
    /// word boundary whenever a new non-zero word is loaded.
    idx: usize,
    /// Number of valid bits; hits at or beyond this index end the iteration.
    size: usize,
}
442
443impl<'a> Iterator for BitSetIterator<'a> {
444    type Item = usize;
445
446    #[inline]
447    fn next(&mut self) -> Option<Self::Item> {
448        while self.block == 0 {
449            self.block = if let Some(&i) = self.iter.next() {
450                if i == 0 {
451                    self.idx += block_size();
452                    continue;
453                } else {
454                    self.idx = ((self.idx + block_size() - 1) / block_size()) * block_size();
455                    i
456                }
457            } else {
458                return None;
459            }
460        }
461        let offset = self.block.trailing_zeros() as usize;
462        self.block >>= offset;
463        self.block >>= 1;
464        self.idx += offset + 1;
465        if self.idx > self.size {
466            return None;
467        }
468        Some(self.idx - 1)
469    }
470}
471
472impl Index<usize> for BitSet {
473    type Output = bool;
474
475    #[inline]
476    fn index(&self, index: usize) -> &Self::Output {
477        self.bit_vec.index(index)
478    }
479}
480
#[cfg(test)]
mod tests {
    use crate::bitset::BitSet;

    /// The iterator and the `get_next_*` cursors must agree on which bits are set.
    #[test]
    fn iter() {
        let mut bs = BitSet::new(256);

        let evens: Vec<usize> = (0..256).step_by(2).collect();
        for &i in &evens {
            bs.set_bit(i);
        }

        let from_iter: Vec<usize> = bs.iter().collect();
        assert_eq!(evens, from_iter);

        // Walk the set bits with the cursor API.
        let mut from_next_set = Vec::new();
        let mut cursor = bs.get_next_set(0);
        while let Some(i) = cursor {
            from_next_set.push(i);
            cursor = bs.get_next_set(i + 1);
        }
        assert_eq!(evens, from_next_set);

        // Walk the unset bits with the cursor API.
        let odds: Vec<usize> = (1..256).step_by(2).collect();
        let mut from_next_unset = Vec::new();
        let mut cursor = bs.get_next_unset(0);
        while let Some(i) = cursor {
            from_next_unset.push(i);
            cursor = bs.get_next_unset(i + 1);
        }
        assert_eq!(odds, from_next_unset);
    }

    /// Setting and unsetting single bits is visible through indexing.
    #[test]
    fn get_set() {
        let n = 257;
        let mut bs = BitSet::new(n);
        assert!((0..n).all(|i| !bs[i]));

        for i in 0..n {
            bs.set_bit(i);
            assert!(bs[i]);
        }

        for i in 0..n {
            bs.unset_bit(i);
            assert!(!bs[i]);
        }
    }

    /// `and`, `or` and `and_not` on two complementary sets.
    #[test]
    fn logic() {
        let n = 257;

        let mut bs1 = BitSet::new_all_set(n);
        assert!((0..n).all(|i| bs1[i]));

        let mut bs2 = BitSet::new(n);
        assert!((0..n).all(|i| !bs2[i]));

        // bs2 holds the even indices, bs1 the odd ones.
        for i in (0..n).step_by(2) {
            bs2.set_bit(i);
            bs1.unset_bit(i);
        }

        let mut intersection = bs1.clone();
        intersection.and(&bs2);
        assert!((0..n).all(|i| !intersection[i]));

        let mut union = bs1.clone();
        union.or(&bs2);
        assert!((0..n).all(|i| union[i]));

        let mut difference = bs1.clone();
        difference.and_not(&bs2);
        assert!((0..n).step_by(2).all(|i| !difference[i]));
    }

    #[test]
    fn test_new_all_set_but() {
        // 0123456789
        //  ++ ++ ++
        let bs = BitSet::new_all_set_but(10, (0usize..10).filter(|x| x % 3 == 0));
        assert_eq!(bs.cardinality(), 6);
        assert_eq!(bs.iter().collect::<Vec<usize>>(), vec![1, 2, 4, 5, 7, 8]);
    }

    #[test]
    fn test_new_all_unset_but() {
        // 0123456789
        // +  +  +  +
        let into: Vec<usize> = (0..10).filter(|x| x % 3 == 0).collect();
        let bs = BitSet::new_all_unset_but(10, into.iter().copied());
        assert_eq!(bs.cardinality(), 4);
        assert_eq!(bs.iter().collect::<Vec<usize>>(), into);
    }

    /// Cloning must preserve length, cardinality and contents, for both the
    /// empty-set shortcut and the regular path.
    #[test]
    fn test_clone() {
        for n in [0, 1, 100] {
            let copied = BitSet::new(n).clone();
            assert_eq!(copied.len(), n);
            assert_eq!(copied.cardinality(), 0);
        }

        for n in [10, 50, 100] {
            let mut orig = BitSet::new(n);
            for i in 0..n / 5 {
                orig.set_bit(i % 3);
            }

            let copied = orig.clone();
            assert_eq!(copied, orig);
            assert_eq!(copied.len(), orig.len());
            assert_eq!(copied.cardinality(), orig.cardinality());
            assert!((0..n).all(|i| copied[i] == orig[i]));
        }
    }
}