1use bitvec::prelude::*;
2use num::{NumCast, ToPrimitive, Unsigned};
3use std::cmp::Ordering;
4use std::fmt::{Debug, Display, Formatter};
5use std::hash::{Hash, Hasher};
6use std::ops::{AddAssign, Div, Index};
7use std::{fmt, iter};
8
9#[derive(Default)]
10pub struct BitSet {
11 cardinality: usize,
12 bit_vec: BitVec,
13}
14
15impl Clone for BitSet {
16 fn clone(&self) -> Self {
17 if self.empty() {
21 Self::new(self.len())
22 } else {
23 Self {
24 cardinality: self.cardinality,
25 bit_vec: self.bit_vec.clone(),
26 }
27 }
28 }
29}
30
31impl Ord for BitSet {
32 fn cmp(&self, other: &Self) -> Ordering {
33 self.bit_vec.cmp(&other.bit_vec)
34 }
35}
36
37impl PartialOrd for BitSet {
38 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
39 self.bit_vec.partial_cmp(&other.bit_vec)
40 }
41}
42
43impl Debug for BitSet {
44 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
45 let values: Vec<_> = self.iter().map(|i| i.to_string()).collect();
46 write!(
47 f,
48 "BitSet {{ cardinality: {}, bit_vec: [{}]}}",
49 self.cardinality,
50 values.join(", "),
51 )
52 }
53}
54
55impl PartialEq for BitSet {
56 fn eq(&self, other: &Self) -> bool {
57 self.cardinality == other.cardinality && self.bit_vec == other.bit_vec
58 }
59}
60impl Eq for BitSet {}
61
62impl Hash for BitSet {
63 fn hash<H: Hasher>(&self, state: &mut H) {
64 self.bit_vec.hash(state)
65 }
66}
67
/// Returns `true` when every bit set in the words of `a` is also set in the
/// corresponding word of `b`.
///
/// The slices may differ in length: words missing from `b` are treated as
/// zero, and words of `b` beyond the end of `a` cannot affect the answer, so
/// the scan stops at the end of `a`.
#[inline]
fn subset_helper(a: &[usize], b: &[usize]) -> bool {
    a.iter()
        .zip(b.iter().chain(iter::repeat(&0usize)))
        // A word of `a` is covered when it has no bit outside the `b` word.
        .all(|(a, b)| (*a & !*b) == 0)
}
81
/// Number of bits in one `usize` storage word.
const fn block_size() -> usize {
    usize::BITS as usize
}
85
86impl BitSet {
87 #[inline]
88 pub fn new(size: usize) -> Self {
89 let mut bit_vec: BitVec = BitVec::with_capacity(size);
90 unsafe {
91 bit_vec.set_len(size);
92 }
93 for i in bit_vec.as_raw_mut_slice() {
94 *i = 0;
95 }
96 Self {
97 cardinality: 0,
98 bit_vec,
99 }
100 }
101
102 pub fn from_bitvec(bit_vec: BitVec) -> Self {
103 let cardinality = bit_vec.iter().filter(|b| **b).count();
104 Self {
105 cardinality,
106 bit_vec,
107 }
108 }
109
110 pub fn from_slice<T: Div<Output = T> + ToPrimitive + AddAssign + Default + Copy + Display>(
111 size: usize,
112 slice: &[T],
113 ) -> Self {
114 let mut bit_vec: BitVec = BitVec::with_capacity(size);
115 unsafe {
116 bit_vec.set_len(size);
117 }
118 slice.iter().for_each(|i| {
119 bit_vec.set(NumCast::from(*i).unwrap(), true);
120 });
121 let cardinality = slice.len();
122 Self {
123 cardinality,
124 bit_vec,
125 }
126 }
127
128 #[inline]
129 pub fn empty(&self) -> bool {
130 self.cardinality == 0
131 }
132
133 #[inline]
134 pub fn full(&self) -> bool {
135 self.cardinality == self.bit_vec.len()
136 }
137
138 pub fn new_all_set(size: usize) -> Self {
139 let mut bit_vec: BitVec = BitVec::with_capacity(size);
140 unsafe {
141 bit_vec.set_len(size);
142 }
143 for i in bit_vec.as_raw_mut_slice() {
144 *i = usize::MAX;
145 }
146 Self {
147 cardinality: size,
148 bit_vec,
149 }
150 }
151
152 pub fn new_all_set_but<T, I>(size: usize, bits_unset: I) -> Self
153 where
154 I: IntoIterator<Item = T>,
155 T: Unsigned + ToPrimitive,
156 {
157 let mut bs = BitSet::new_all_set(size);
158 for i in bits_unset {
159 bs.unset_bit(i.to_usize().unwrap());
160 }
161 bs
162 }
163
164 pub fn new_all_unset_but<T, I>(size: usize, bits_set: I) -> Self
165 where
166 I: IntoIterator<Item = T>,
167 T: Unsigned + ToPrimitive,
168 {
169 let mut bs = BitSet::new(size);
170 for i in bits_set {
171 bs.set_bit(i.to_usize().unwrap());
172 }
173 bs
174 }
175
176 #[inline]
177 pub fn is_disjoint_with(&self, other: &BitSet) -> bool {
178 !self
179 .bit_vec
180 .as_raw_slice()
181 .iter()
182 .zip(other.as_slice().iter())
183 .any(|(x, y)| *x ^ *y != *x | *y)
184 }
185
186 #[inline]
187 pub fn intersects_with(&self, other: &BitSet) -> bool {
188 !self.is_disjoint_with(other)
189 }
190
191 #[inline]
192 pub fn is_subset_of(&self, other: &BitSet) -> bool {
193 self.cardinality <= other.cardinality
194 && subset_helper(self.bit_vec.as_raw_slice(), other.as_slice())
195 }
196
197 #[inline]
198 pub fn is_superset_of(&self, other: &BitSet) -> bool {
199 other.is_subset_of(self)
200 }
201
202 #[inline]
203 pub fn as_slice(&self) -> &[usize] {
204 self.bit_vec.as_raw_slice()
205 }
206
207 #[inline]
208 pub fn as_bitslice(&self) -> &BitSlice {
209 self.bit_vec.as_bitslice()
210 }
211
212 #[inline]
213 pub fn as_bit_vec(&self) -> &BitVec {
214 &self.bit_vec
215 }
216
217 #[inline]
218 pub fn set_bit(&mut self, idx: usize) -> bool {
219 if !*self.bit_vec.get(idx).unwrap() {
220 self.bit_vec.set(idx, true);
221 self.cardinality += 1;
222 false
223 } else {
224 true
225 }
226 }
227
228 #[inline]
229 pub fn unset_bit(&mut self, idx: usize) -> bool {
230 if *self.bit_vec.get(idx).unwrap() {
231 self.bit_vec.set(idx, false);
232 self.cardinality -= 1;
233 true
234 } else {
235 false
236 }
237 }
238
239 #[inline]
240 pub fn cardinality(&self) -> usize {
241 self.cardinality
242 }
243
244 #[inline]
245 pub fn len(&self) -> usize {
246 self.bit_vec.len()
247 }
248
249 #[inline]
250 pub fn is_empty(&self) -> bool {
251 self.bit_vec.is_empty()
252 }
253
254 #[inline]
255 pub fn or(&mut self, other: &BitSet) {
256 if other.len() > self.bit_vec.len() {
257 self.bit_vec.resize(other.len(), false);
258 }
259 for (x, y) in self
260 .bit_vec
261 .as_raw_mut_slice()
262 .iter_mut()
263 .zip(other.as_slice().iter())
264 {
265 *x |= y;
266 }
267 self.cardinality = self.bit_vec.count_ones();
268 }
269
270 #[inline]
271 pub fn resize(&mut self, size: usize) {
272 let old_size = self.bit_vec.len();
273 self.bit_vec.resize(size, false);
274 if size < old_size {
275 self.cardinality = self.bit_vec.count_ones();
276 }
277 }
278
279 #[inline]
280 pub fn and(&mut self, other: &BitSet) {
281 for (x, y) in self
282 .bit_vec
283 .as_raw_mut_slice()
284 .iter_mut()
285 .zip(other.as_slice().iter())
286 {
287 *x &= y;
288 }
289 self.cardinality = self.bit_vec.count_ones();
290 }
291
292 #[inline]
293 pub fn and_not(&mut self, other: &BitSet) {
294 for (x, y) in self
295 .bit_vec
296 .as_raw_mut_slice()
297 .iter_mut()
298 .zip(other.as_slice().iter())
299 {
300 *x &= !y;
301 }
302 self.cardinality = self.bit_vec.count_ones();
303 }
304
305 #[inline]
306 pub fn not(&mut self) {
307 self.bit_vec
308 .as_raw_mut_slice()
309 .iter_mut()
310 .for_each(|x| *x = !*x);
311 self.cardinality = self.bit_vec.count_ones();
312 }
313
314 #[inline]
315 pub fn unset_all(&mut self) {
316 self.bit_vec
317 .as_raw_mut_slice()
318 .iter_mut()
319 .for_each(|x| *x = 0);
320 self.cardinality = 0;
321 }
322
323 #[inline]
324 pub fn set_all(&mut self) {
325 self.bit_vec
326 .as_raw_mut_slice()
327 .iter_mut()
328 .for_each(|x| *x = std::usize::MAX);
329 self.cardinality = self.bit_vec.len();
330 }
331
332 #[inline]
333 pub fn has_smaller(&mut self, other: &BitSet) -> Option<bool> {
334 let self_idx = self.get_first_set()?;
335 let other_idx = other.get_first_set()?;
336 Some(self_idx < other_idx)
337 }
338
339 #[inline]
340 pub fn get_first_set(&self) -> Option<usize> {
341 if self.cardinality != 0 {
342 return self.get_next_set(0);
343 }
344 None
345 }
346
347 #[inline]
348 pub fn get_next_set(&self, idx: usize) -> Option<usize> {
349 if idx >= self.bit_vec.len() {
350 return None;
351 }
352 let mut block_idx = idx / block_size();
353 let word_idx = idx % block_size();
354 let mut block = self.bit_vec.as_raw_slice()[block_idx];
355 let max = self.bit_vec.as_raw_slice().len();
356 block &= usize::MAX << word_idx;
357 while block == 0usize {
358 block_idx += 1;
359 if block_idx >= max {
360 return None;
361 }
362 block = self.bit_vec.as_raw_slice()[block_idx];
363 }
364 let v = block_idx * block_size() + block.trailing_zeros() as usize;
365 if v >= self.bit_vec.len() {
366 None
367 } else {
368 Some(v)
369 }
370 }
371
372 #[inline]
373 pub fn get_first_unset(&self) -> Option<usize> {
374 if self.cardinality != self.len() {
375 return self.get_next_unset(0);
376 }
377 None
378 }
379
380 #[inline]
381 pub fn get_next_unset(&self, idx: usize) -> Option<usize> {
382 if idx >= self.bit_vec.len() {
383 return None;
384 }
385 let mut block_idx = idx / block_size();
386 let word_idx = idx % block_size();
387 let mut block = self.bit_vec.as_raw_slice()[block_idx];
388 let max = self.bit_vec.as_raw_slice().len();
389 block |= (1 << word_idx) - 1;
390 while block == usize::MAX {
391 block_idx += 1;
392 if block_idx >= max {
393 return None;
394 }
395 block = self.bit_vec.as_raw_slice()[block_idx];
396 }
397 let v = block_idx * block_size() + block.trailing_ones() as usize;
398 if v >= self.bit_vec.len() {
399 None
400 } else {
401 Some(v)
402 }
403 }
404
405 #[inline]
406 pub fn to_vec(&self) -> Vec<u32> {
407 let mut tmp = Vec::with_capacity(self.cardinality);
408 for (i, _) in self
409 .bit_vec
410 .as_bitslice()
411 .iter()
412 .enumerate()
413 .filter(|(_, x)| **x)
414 {
415 tmp.push(i as u32);
416 }
417 tmp
418 }
419
420 #[inline]
421 pub fn at(&self, idx: usize) -> bool {
422 self.bit_vec[idx]
423 }
424
425 #[inline]
426 pub fn iter(&self) -> BitSetIterator {
427 BitSetIterator {
428 iter: self.bit_vec.as_raw_slice().iter(),
429 block: 0,
430 idx: 0,
431 size: self.bit_vec.len(),
432 }
433 }
434}
435
/// Iterator over the indices of the set bits of a [`BitSet`], lowest first.
pub struct BitSetIterator<'a> {
    /// Storage words that have not been loaded yet.
    iter: std::slice::Iter<'a, usize>,
    /// Unreported bits of the current word, shifted so the next candidate
    /// sits at the low end.
    block: usize,
    /// One past the index most recently examined.
    idx: usize,
    /// Length of the originating set; bits at or past it are never yielded.
    size: usize,
}
442
443impl<'a> Iterator for BitSetIterator<'a> {
444 type Item = usize;
445
446 #[inline]
447 fn next(&mut self) -> Option<Self::Item> {
448 while self.block == 0 {
449 self.block = if let Some(&i) = self.iter.next() {
450 if i == 0 {
451 self.idx += block_size();
452 continue;
453 } else {
454 self.idx = ((self.idx + block_size() - 1) / block_size()) * block_size();
455 i
456 }
457 } else {
458 return None;
459 }
460 }
461 let offset = self.block.trailing_zeros() as usize;
462 self.block >>= offset;
463 self.block >>= 1;
464 self.idx += offset + 1;
465 if self.idx > self.size {
466 return None;
467 }
468 Some(self.idx - 1)
469 }
470}
471
472impl Index<usize> for BitSet {
473 type Output = bool;
474
475 #[inline]
476 fn index(&self, index: usize) -> &Self::Output {
477 self.bit_vec.index(index)
478 }
479}
480
#[cfg(test)]
mod tests {
    use crate::bitset::BitSet;

    /// `iter`, `get_next_set` and `get_next_unset` must agree on which bits
    /// are set.
    #[test]
    fn iter() {
        let mut bs = BitSet::new(256);

        let evens: Vec<usize> = (0..256).step_by(2).collect();
        for &bit in &evens {
            bs.set_bit(bit);
        }

        let via_iter: Vec<usize> = bs.iter().collect();
        assert_eq!(evens, via_iter);

        let mut via_next_set = Vec::new();
        let mut cursor = bs.get_next_set(0);
        while let Some(bit) = cursor {
            via_next_set.push(bit);
            cursor = bs.get_next_set(bit + 1);
        }
        assert_eq!(evens, via_next_set);

        let odds: Vec<usize> = (1..256).step_by(2).collect();
        let mut via_next_unset = Vec::new();
        let mut cursor = bs.get_next_unset(0);
        while let Some(bit) = cursor {
            via_next_unset.push(bit);
            cursor = bs.get_next_unset(bit + 1);
        }
        assert_eq!(odds, via_next_unset);
    }

    /// Setting and clearing single bits is visible through indexing.
    #[test]
    fn get_set() {
        let n = 257;
        let mut bs = BitSet::new(n);
        for i in 0..n {
            assert!(!bs[i]);
        }
        for i in 0..n {
            bs.set_bit(i);
            assert!(bs[i]);
        }
        for i in 0..n {
            bs.unset_bit(i);
            assert!(!bs[i]);
        }
    }

    /// `and`, `or` and `and_not` on two complementary sets.
    #[test]
    fn logic() {
        let n = 257;

        let mut bs1 = BitSet::new_all_set(n);
        for i in 0..n {
            assert!(bs1[i]);
        }

        let mut bs2 = BitSet::new(n);
        for i in 0..n {
            assert!(!bs2[i]);
        }

        // bs2 ends up with the even bits, bs1 with the odd ones.
        for i in (0..n).step_by(2) {
            bs2.set_bit(i);
            bs1.unset_bit(i);
        }

        let mut intersection = bs1.clone();
        intersection.and(&bs2);
        for i in 0..n {
            assert!(!intersection[i]);
        }

        let mut union = bs1.clone();
        union.or(&bs2);
        for i in 0..n {
            assert!(union[i]);
        }

        let mut difference = bs1.clone();
        difference.and_not(&bs2);
        for i in (0..n).step_by(2) {
            assert!(!difference[i]);
        }
    }

    #[test]
    fn test_new_all_set_but() {
        let bs = BitSet::new_all_set_but(10, (0usize..10).filter(|x| x % 3 == 0));
        assert_eq!(bs.cardinality(), 6);
        let out: Vec<usize> = bs.iter().collect();
        assert_eq!(out, vec![1, 2, 4, 5, 7, 8]);
    }

    #[test]
    fn test_new_all_unset_but() {
        let into: Vec<usize> = (0..10).filter(|x| x % 3 == 0).collect();
        let bs = BitSet::new_all_unset_but(10, into.clone().into_iter());
        assert_eq!(bs.cardinality(), 4);
        let out: Vec<usize> = bs.iter().collect();
        assert_eq!(out, into);
    }

    /// Cloning preserves length, cardinality and every bit, for both empty
    /// and populated sets.
    #[test]
    fn test_clone() {
        for n in [0, 1, 100] {
            let copied = BitSet::new(n).clone();
            assert_eq!(copied.len(), n);
            assert_eq!(copied.cardinality(), 0);
        }

        for n in [10, 50, 100] {
            let mut orig = BitSet::new(n);
            for i in 0..n / 5 {
                orig.set_bit(i % 3);
            }

            let copied = orig.clone();
            assert_eq!(copied, orig);
            assert_eq!(copied.len(), orig.len());
            assert_eq!(copied.cardinality(), orig.cardinality());
            for i in 0..n {
                assert_eq!(copied[i], orig[i]);
            }
        }
    }
}