#![no_std]

extern crate alloc;

use core::{
    alloc::Layout,
    hash::{Hash, Hasher},
    mem::transmute,
    ptr::{addr_of, addr_of_mut},
};

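/// A bit vector with a small-size optimization: vectors of fewer than
/// `usize::BITS` bits are stored directly in the pointer-sized field, with the
/// top bit used as an "is inline" tag; anything larger spills into a heap
/// allocation (`OutOfLineBits`). `len()` reports the capacity in bits, not the
/// number of set bits.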
pub struct BitVector {
    bits_or_pointer: *mut (),
}

impl core::fmt::Debug for BitVector {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_list().entries(self.iter()).finish()
    }
}

impl Default for BitVector {
    fn default() -> Self {
        Self::new()
    }
}

impl Clone for BitVector {
    fn clone(&self) -> Self {
        if self.is_inline() {
            Self {
                bits_or_pointer: self.bits_or_pointer,
            }
        } else {
            unsafe {
                let my_out_of_line_bits = self.out_of_line_bits();
                // `with_capacity` already allocates out-of-line storage with the same
                // word count, so the backing words can simply be copied over.
                let mut result = Self::with_capacity((*my_out_of_line_bits).num_bits());
                OutOfLineBits::bits_mut(result.out_of_line_bits_mut())
                    .copy_from_slice(OutOfLineBits::bits(my_out_of_line_bits));
                result
            }
        }
    }
}

impl BitVector {
    pub fn new() -> Self {
        Self {
            bits_or_pointer: Self::make_inline_bits(0),
        }
    }

    pub fn with_capacity(num_bits: usize) -> Self {
        let mut result = Self::new();
        result.ensure_size(num_bits);
        result
    }

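    /// Bitwise OR (set union) with `other`, growing `self` as needed.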
    pub fn merge(&mut self, other: &Self) {
        if !self.is_inline() || !other.is_inline() {
            self.merge_slow(other);
            return;
        }

        self.bits_or_pointer = with_addr(self.bits_or_pointer, |addr| {
            (addr as usize | other.bits_or_pointer as usize) as isize
        });
    }

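    /// Bitwise AND (set intersection) with `other`; bits that `other` does not
    /// have are cleared.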
    pub fn filter(&mut self, other: &Self) {
        if !self.is_inline() || !other.is_inline() {
            self.filter_slow(other);
            return;
        }

        self.bits_or_pointer = with_addr(self.bits_or_pointer, |addr| {
            (addr as usize & other.bits_or_pointer as usize) as isize
        });
    }
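
    /// Bitwise AND-NOT (set difference): clears every bit of `self` that is set
    /// in `other`.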
    pub fn exclude(&mut self, other: &Self) {
        if !self.is_inline() || !other.is_inline() {
            self.exclude_slow(other);
            return;
        }

        // Mask with the *cleansed* bits of `other` so that the inline tag bit of
        // `self` survives the AND-NOT.
        self.bits_or_pointer = with_addr(self.bits_or_pointer, |addr| {
            (addr as usize & !Self::cleanse_inline_bits(other.bits_or_pointer as usize)) as isize
        });
        debug_assert!(self.is_inline());
    }

    fn exclude_slow(&mut self, other: &Self) {
        unsafe {
            if other.is_inline() {
                debug_assert!(!self.is_inline());
                let other_bits = Self::cleanse_inline_bits(other.bits_or_pointer as usize);
                let my_bits = self.out_of_line_bits_mut();
                OutOfLineBits::bits_mut(my_bits)[0] &= !other_bits;
                return;
            }

            if self.is_inline() {
                self.bits_or_pointer = with_addr(self.bits_or_pointer, |addr| {
                    (addr as usize & !OutOfLineBits::bits(other.out_of_line_bits())[0]) as isize
                });

                // Restore the inline tag bit, which the AND-NOT above may have cleared.
                self.bits_or_pointer = with_addr(self.bits_or_pointer, |addr| {
                    (addr as usize | (1 << Self::max_inline_bits())) as isize
                });
                debug_assert!(self.is_inline());
                return;
            }

            self.ensure_size(other.len());

            debug_assert!(!other.is_inline());
            debug_assert!(!self.is_inline());

            let a = self.out_of_line_bits_mut();
            let b = other.out_of_line_bits();

            for i in (0..(*a).num_words().min((*b).num_words())).rev() {
                OutOfLineBits::bits_mut(a)[i] &= !OutOfLineBits::bits(b)[i];
            }
        }
    }

    fn merge_slow(&mut self, other: &Self) {
        unsafe {
            if other.is_inline() {
                debug_assert!(!self.is_inline());
                let other_bits = Self::cleanse_inline_bits(other.bits_or_pointer as usize);
                let my_bits = self.out_of_line_bits_mut();
                OutOfLineBits::bits_mut(my_bits)[0] |= other_bits;
                return;
            }

            self.ensure_size(other.len());

            debug_assert!(!other.is_inline());
            debug_assert!(!self.is_inline());

            let a = self.out_of_line_bits_mut();
            let b = other.out_of_line_bits();

            // Only walk the words `other` actually has; `ensure_size` guarantees `self`
            // has at least as many, and its extra words are unaffected by an OR.
            for i in (0..(*a).num_words().min((*b).num_words())).rev() {
                OutOfLineBits::bits_mut(a)[i] |= OutOfLineBits::bits(b)[i];
            }
        }
    }

    fn filter_slow(&mut self, other: &Self) {
        unsafe {
            if other.is_inline() {
                debug_assert!(!self.is_inline());
                let other_bits = Self::cleanse_inline_bits(other.bits_or_pointer as usize);
                let my_bits = self.out_of_line_bits_mut();
                let words = OutOfLineBits::bits_mut(my_bits);
                words[0] &= other_bits;
                // `other` has no bits past its first word, so everything above it is
                // filtered out.
                for word in &mut words[1..] {
                    *word = 0;
                }
                return;
            }

            if self.is_inline() {
                self.bits_or_pointer = with_addr(self.bits_or_pointer, |addr| {
                    (addr as usize & OutOfLineBits::bits(other.out_of_line_bits())[0]) as isize
                });
                // Restore the inline tag bit, which the AND above may have cleared.
                self.bits_or_pointer = with_addr(self.bits_or_pointer, |addr| {
                    (addr as usize | 1 << Self::max_inline_bits()) as isize
                });
                debug_assert!(self.is_inline());
                return;
            }

            self.ensure_size(other.len());

            debug_assert!(!other.is_inline());
            debug_assert!(!self.is_inline());

            let a = self.out_of_line_bits_mut();
            let b = other.out_of_line_bits();

            for i in (0..(*a).num_words().min((*b).num_words())).rev() {
                OutOfLineBits::bits_mut(a)[i] &= OutOfLineBits::bits(b)[i];
            }

            for i in (*b).num_words()..(*a).num_words() {
                OutOfLineBits::bits_mut(a)[i] = 0;
            }
        }
    }

    pub fn is_empty(&self) -> bool {
        if self.is_inline() {
            Self::cleanse_inline_bits(self.bits_or_pointer as _) == 0
        } else {
            unsafe {
                OutOfLineBits::bits(self.out_of_line_bits())
                    .iter()
                    .all(|&x| x == 0)
            }
        }
    }

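    /// Returns the number of set bits.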
    pub fn bit_count(&self) -> usize {
        if self.is_inline() {
            Self::cleanse_inline_bits(self.bits_or_pointer as _).count_ones() as usize
        } else {
            unsafe { OutOfLineBits::bits(self.out_of_line_bits()) }
                .iter()
                .map(|&x| x.count_ones() as usize)
                .sum()
        }
    }

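    /// Returns the index of the first bit equal to `value` at or after `index`,
    /// or `self.len()` if there is no such bit.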
    pub fn find_bit(&self, index: usize, value: bool) -> usize {
        let result = self.find_bit_fast(index, value);

        debug_assert!(
            result == self.find_bit_simple(index, value),
            "find_bit_fast failed"
        );

        result
    }

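    /// Returns the capacity in bits (not the number of set bits; see
    /// [`Self::bit_count`]). Inline vectors always report the maximum inline size.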
    pub fn len(&self) -> usize {
        if self.is_inline() {
            Self::max_inline_bits()
        } else {
            unsafe { (*self.out_of_line_bits()).num_bits() }
        }
    }

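    /// Clears `bit` and returns its previous value. The `quick_*` accessors never
    /// grow the vector: `bit` must already be below `self.len()`.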
    pub fn quick_clear(&mut self, bit: usize) -> bool {
        assert!(bit < self.len());

        unsafe {
            let word = &mut *self.bits_mut().add(bit / Self::bits_in_pointer());
            let mask = 1 << (bit & (Self::bits_in_pointer() - 1));
            let result = (*word & mask) != 0;
            *word &= !mask;
            result
        }
    }

    pub fn quick_set(&mut self, bit: usize, value: bool) -> bool {
        assert!(bit < self.len());
        if !value {
            return self.quick_clear(bit);
        }
        unsafe {
            let word = &mut *self.bits_mut().add(bit / Self::bits_in_pointer());
            let mask = 1 << (bit & (Self::bits_in_pointer() - 1));
            let result = (*word & mask) != 0;
            *word |= mask;
            result
        }
    }

    pub fn quick_get(&self, bit: usize) -> bool {
        assert!(bit < self.len());
        unsafe {
            (self.bits().add(bit / Self::bits_in_pointer()).read()
                & (1 << (bit & (Self::bits_in_pointer() - 1))))
                != 0
        }
    }

    pub fn get(&self, index: usize) -> bool {
        if index >= self.len() {
            return false;
        }

        self.quick_get(index)
    }

    pub fn contains(&self, index: usize) -> bool {
        self.get(index)
    }

    pub fn clear(&mut self, index: usize) -> bool {
        if index >= self.len() {
            return false;
        }

        self.quick_clear(index)
    }

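    /// Sets `index` to `value`, growing the vector if necessary, and returns the
    /// bit's previous value.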
    pub fn set(&mut self, index: usize, value: bool) -> bool {
        if !value {
            return self.clear(index);
        }

        self.ensure_size(index + 1);
        self.quick_set(index, value)
    }

    pub fn ensure_size(&mut self, num_bits: usize) {
        if num_bits <= self.len() {
            return;
        }

        self.resize_out_of_line(num_bits, 0);
    }

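    /// Resizes the vector to hold at least `num_bits` bits, switching back to the
    /// inline representation when everything fits in a single word. Bits beyond the
    /// new size are not guaranteed to be cleared.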
    pub fn resize(&mut self, num_bits: usize) {
        if num_bits <= Self::max_inline_bits() {
            if self.is_inline() {
                return;
            }

            let my_out_of_line_bits = self.out_of_line_bits_mut();
            unsafe {
                let bits_or_pointer =
                    Self::make_inline_bits(OutOfLineBits::bits(my_out_of_line_bits)[0] as usize);

                OutOfLineBits::destroy(my_out_of_line_bits);

                self.bits_or_pointer = bits_or_pointer;
            }

            return;
        }

        self.resize_out_of_line(num_bits, 0);
    }

    pub fn clear_all(&mut self) {
        if self.is_inline() {
            self.bits_or_pointer = Self::make_inline_bits(0);
        } else {
            unsafe {
                core::ptr::write_bytes(
                    self.bits_mut().cast::<u8>(),
                    0,
                    (*self.out_of_line_bits()).num_words() * core::mem::size_of::<usize>(),
                );
            }
        }
    }

    pub fn shift_right_by_multiple_of_64(&mut self, shift_in_bits: usize) {
        debug_assert!(shift_in_bits % 64 == 0);
        // A shift of 64 bits must correspond to a whole number of storage words.
        debug_assert!(64 % Self::bits_in_pointer() == 0);
        let shift_in_words = shift_in_bits / Self::bits_in_pointer();
        let num_bits = self.len() + shift_in_bits;
        self.resize_out_of_line(num_bits, shift_in_words);
    }

    pub fn iter(&self) -> BitVectorIter<'_> {
        BitVectorIter {
            index: self.find_bit(0, true),
            bit_vector: self,
        }
    }

    fn resize_out_of_line(&mut self, num_bits: usize, shift_in_words: usize) {
        debug_assert!(num_bits > Self::max_inline_bits());

        unsafe {
            let new_out_of_line_bits = OutOfLineBits::create(num_bits);
            let new_num_words = (*new_out_of_line_bits).num_words();

            if self.is_inline() {
                // Inline -> out-of-line: zero the shifted-in words, move the inline
                // payload (minus the tag bit) into place, and zero the rest.
                core::ptr::write_bytes(
                    OutOfLineBits::bits_mut(new_out_of_line_bits)
                        .as_mut_ptr()
                        .cast::<u8>(),
                    0,
                    shift_in_words * core::mem::size_of::<usize>(),
                );

                let addr = OutOfLineBits::bits_mut(new_out_of_line_bits)
                    .as_mut_ptr()
                    .add(shift_in_words);

                addr.write(Self::cleanse_inline_bits(self.bits_or_pointer as usize));
                debug_assert!(shift_in_words + 1 <= new_num_words);

                core::ptr::write_bytes(
                    OutOfLineBits::bits_mut(new_out_of_line_bits)
                        .as_mut_ptr()
                        .add(shift_in_words + 1)
                        .cast::<u8>(),
                    0,
                    (new_num_words - 1 - shift_in_words) * core::mem::size_of::<usize>(),
                );
            } else {
                if num_bits > self.len() {
                    // Growing: zero the shifted-in words, copy the old words after
                    // them, and zero whatever is left at the top.
                    let old_num_words = (*self.out_of_line_bits()).num_words();

                    core::ptr::write_bytes(
                        OutOfLineBits::bits_mut(new_out_of_line_bits)
                            .as_mut_ptr()
                            .cast::<u8>(),
                        0,
                        shift_in_words * core::mem::size_of::<usize>(),
                    );

                    core::ptr::copy_nonoverlapping(
                        OutOfLineBits::bits(self.out_of_line_bits())
                            .as_ptr()
                            .cast::<u8>(),
                        OutOfLineBits::bits_mut(new_out_of_line_bits)
                            .as_mut_ptr()
                            .add(shift_in_words)
                            .cast::<u8>(),
                        old_num_words * core::mem::size_of::<usize>(),
                    );

                    debug_assert!(shift_in_words + old_num_words <= new_num_words);

                    core::ptr::write_bytes(
                        OutOfLineBits::bits_mut(new_out_of_line_bits)
                            .as_mut_ptr()
                            .add(shift_in_words + old_num_words)
                            .cast::<u8>(),
                        0,
                        (new_num_words - old_num_words - shift_in_words)
                            * core::mem::size_of::<usize>(),
                    );
                } else {
                    // Shrinking (or same size): keep as many words as still fit.
                    core::ptr::copy_nonoverlapping(
                        OutOfLineBits::bits(self.out_of_line_bits())
                            .as_ptr()
                            .cast::<u8>(),
                        OutOfLineBits::bits_mut(new_out_of_line_bits)
                            .as_mut_ptr()
                            .cast::<u8>(),
                        new_num_words * core::mem::size_of::<usize>(),
                    );
                }

                OutOfLineBits::destroy(self.out_of_line_bits_mut());
            }

            // Store the pointer shifted down by one so the top (inline tag) bit reads
            // as zero; `out_of_line_bits` shifts it back up.
            self.bits_or_pointer = with_addr(new_out_of_line_bits.cast(), |a| a >> 1).cast();
        }
    }

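    // Representation helpers: an inline vector stores its payload in `bits_or_pointer`
    // with the top bit set as a tag; an out-of-line vector stores the `OutOfLineBits`
    // pointer shifted right by one so that the same top bit reads as zero.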
    const fn bits_in_pointer() -> usize {
        core::mem::size_of::<usize>() << 3
    }

    const fn max_inline_bits() -> usize {
        Self::bits_in_pointer() - 1
    }

    #[allow(dead_code)]
    const fn byte_count(bits: usize) -> usize {
        (bits + 7) >> 3
    }

    const fn make_inline_bits(bits: usize) -> *mut () {
        unsafe { transmute(bits | (1 << Self::max_inline_bits())) }
    }

    const fn cleanse_inline_bits(bits: usize) -> usize {
        bits & !(1 << Self::max_inline_bits())
    }

    const fn is_inline(&self) -> bool {
        unsafe { (transmute::<_, usize>(self.bits_or_pointer) >> Self::max_inline_bits()) != 0 }
    }

    fn out_of_line_bits(&self) -> *const OutOfLineBits {
        with_addr(self.bits_or_pointer, |a| a << 1).cast()
    }

    fn out_of_line_bits_mut(&mut self) -> *mut OutOfLineBits {
        with_addr(self.bits_or_pointer, |a| a << 1).cast()
    }

    fn bits(&self) -> *const usize {
        if self.is_inline() {
            &self.bits_or_pointer as *const _ as *const usize
        } else {
            unsafe { OutOfLineBits::bits(self.out_of_line_bits()).as_ptr() }
        }
    }

    fn bits_mut(&mut self) -> *mut usize {
        if self.is_inline() {
            &mut self.bits_or_pointer as *mut _ as *mut usize
        } else {
            unsafe { OutOfLineBits::bits_mut(self.out_of_line_bits_mut()).as_mut_ptr() }
        }
    }

    fn find_bit_fast(&self, start_index: usize, value: bool) -> usize {
        if self.is_inline() {
            let mut index = start_index;
            find_bit_in_word(
                self.bits_or_pointer as usize,
                &mut index,
                Self::max_inline_bits(),
                value,
            );
            return index;
        }

        let bits = self.out_of_line_bits();
        unsafe {
            // A word equal to `skip_value` (all zeros when looking for a set bit,
            // all ones when looking for a clear bit) cannot contain a match.
            let skip_value: usize = (value as usize ^ 1).wrapping_neg();

            let num_words = (*bits).num_words();

            let mut word_index = start_index / Self::bits_in_pointer();
            let mut start_index_in_word = start_index - word_index * Self::bits_in_pointer();

            while word_index < num_words {
                let word = OutOfLineBits::bits(bits)[word_index];
                if word != skip_value {
                    let mut index = start_index_in_word;
                    if find_bit_in_word(word, &mut index, Self::bits_in_pointer(), value) {
                        return word_index * Self::bits_in_pointer() + index;
                    }
                }

                word_index += 1;
                start_index_in_word = 0;
            }

            (*bits).num_bits()
        }
    }

    fn find_bit_simple(&self, start_index: usize, value: bool) -> usize {
        let mut index = start_index;
        while index < self.len() {
            if self.get(index) == value {
                return index;
            }
            index += 1;
        }
        self.len()
    }
}

impl Drop for BitVector {
    fn drop(&mut self) {
        if !self.is_inline() {
            unsafe { OutOfLineBits::destroy(self.out_of_line_bits_mut()) }
        }
    }
}

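/// Heap-allocated storage for vectors that no longer fit inline: a bit count
/// followed by a variable number of words (allocated past the end of the struct).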
#[repr(C)]
struct OutOfLineBits {
    num_bits: usize,
    bits: [usize; 1],
}

impl OutOfLineBits {
    const fn num_bits(&self) -> usize {
        self.num_bits
    }

    const fn num_words(&self) -> usize {
        (self.num_bits + BitVector::bits_in_pointer() - 1) / BitVector::bits_in_pointer()
    }

    const unsafe fn bits<'a>(this: *const Self) -> &'a [usize] {
        let words = (*this).num_words();

        core::slice::from_raw_parts(addr_of!((*this).bits).cast::<usize>(), words)
    }

    unsafe fn bits_mut<'a>(this: *mut Self) -> &'a mut [usize] {
        let words = (*this).num_words();

        unsafe {
            core::slice::from_raw_parts_mut(addr_of_mut!((*this).bits).cast::<usize>(), words)
        }
    }

    unsafe fn create(num_bits: usize) -> *mut Self {
        // Round up to a whole number of bytes; `destroy` recomputes the same layout
        // from the stored `num_bits`.
        let num_bits = (num_bits + 7) & !7;
        let layout = Self::layout_for(num_bits);

        let ptr = alloc::alloc::alloc(layout) as *mut Self;
        if ptr.is_null() {
            alloc::alloc::handle_alloc_error(layout);
        }

        ptr.write(Self {
            num_bits,
            bits: [0; 1],
        });

        ptr
    }

    unsafe fn destroy(this: *mut Self) {
        alloc::alloc::dealloc(this as *mut u8, Self::layout_for((*this).num_bits));
    }

    fn layout_for(num_bits: usize) -> Layout {
        // One word for `num_bits`, plus enough whole words to hold the bits themselves.
        let num_words = (num_bits + BitVector::bits_in_pointer() - 1) / BitVector::bits_in_pointer();
        let size = core::mem::size_of::<usize>() * (1 + num_words);
        unsafe { Layout::from_size_align_unchecked(size, core::mem::align_of::<usize>()) }
    }
}

pub fn find_bit_in_word(
    mut word: usize,
    start_or_result_index: &mut usize,
    end_index: usize,
    value: bool,
) -> bool {
    let bits_in_word = core::mem::size_of::<usize>() << 3;
    debug_assert!(*start_or_result_index <= bits_in_word && end_index <= bits_in_word);

    let mut index = *start_or_result_index;
    word >>= index;

    // Searching for a zero bit is the same as searching for a one bit in the
    // complemented word.
    word ^= (value as usize).wrapping_sub(1);
    index += word.trailing_zeros() as usize;

    if index < end_index {
        *start_or_result_index = index;
        true
    } else {
        *start_or_result_index = end_index;
        false
    }
}

impl Hash for BitVector {
    fn hash<H: Hasher>(&self, state: &mut H) {
        // Hash the logical words (ignoring trailing zeros) so equal vectors hash alike.
        if self.is_inline() {
            Self::cleanse_inline_bits(self.bits_or_pointer as usize).hash(state);
        } else {
            unsafe {
                let words = OutOfLineBits::bits(self.out_of_line_bits());
                let used = words.iter().rposition(|&w| w != 0).map_or(1, |i| i + 1);
                for &word in &words[..used] {
                    word.hash(state);
                }
            }
        }
    }
}

impl PartialEq for BitVector {
    fn eq(&self, other: &Self) -> bool {
        if self.is_inline() && other.is_inline() {
            return self.bits_or_pointer == other.bits_or_pointer;
        }

        // Compare logical words, treating words that only one side has as zero and
        // stripping the inline tag bit from an inline operand.
        unsafe {
            let self_words = if self.is_inline() {
                1
            } else {
                (*self.out_of_line_bits()).num_words()
            };
            let other_words = if other.is_inline() {
                1
            } else {
                (*other.out_of_line_bits()).num_words()
            };

            for i in 0..self_words.max(other_words) {
                let a = if i < self_words {
                    if self.is_inline() {
                        Self::cleanse_inline_bits(self.bits_or_pointer as usize)
                    } else {
                        OutOfLineBits::bits(self.out_of_line_bits())[i]
                    }
                } else {
                    0
                };
                let b = if i < other_words {
                    if other.is_inline() {
                        Self::cleanse_inline_bits(other.bits_or_pointer as usize)
                    } else {
                        OutOfLineBits::bits(other.out_of_line_bits())[i]
                    }
                } else {
                    0
                };
                if a != b {
                    return false;
                }
            }

            true
        }
    }
}

impl Eq for BitVector {}

pub struct BitVectorIter<'a> {
    bit_vector: &'a BitVector,
    index: usize,
}

impl<'a> Iterator for BitVectorIter<'a> {
    type Item = usize;

    fn next(&mut self) -> Option<Self::Item> {
        if self.index >= self.bit_vector.len() {
            return None;
        }
        let old = self.index;
        // Advance to the next set bit, or to the end if there is none.
        self.index = self
            .bit_vector
            .find_bit_fast(self.index + 1, true)
            .min(self.bit_vector.len());
        Some(old)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let len = self.len();
        (len, Some(len))
    }
}

impl<'a> ExactSizeIterator for BitVectorIter<'a> {
    fn len(&self) -> usize {
        // Count only the set bits that have not been yielded yet.
        (self.index..self.bit_vector.len())
            .filter(|&i| self.bit_vector.quick_get(i))
            .count()
    }
}

#[cfg(test)]
mod tests {
    use crate::BitVector;

    #[test]
    fn test_bvec() {
        let mut bv = BitVector::new();

        bv.set(0, true);
        bv.set(3, true);
        bv.set(17, true);

        let mut iter = bv.iter();

        assert_eq!(iter.next(), Some(0));
        assert_eq!(iter.next(), Some(3));
        assert_eq!(iter.next(), Some(17));
        assert_eq!(iter.next(), None);

        bv.set(640, true);

        let mut iter = bv.iter();

        assert_eq!(iter.next(), Some(0));
        assert_eq!(iter.next(), Some(3));
        assert_eq!(iter.next(), Some(17));
        assert_eq!(iter.next(), Some(640));
        assert_eq!(iter.next(), None);

        assert_eq!(bv.find_bit(19, true), 640);

        let mut bv1 = BitVector::new();
        let mut bv2 = BitVector::new();

        bv1.set(0, true);
        bv1.set(3, true);
        bv1.set(17, true);

        bv2.set(1, true);
        bv2.set(4, true);

        bv1.merge(&bv2);

        assert!(bv1.get(0));
        assert!(bv1.get(1));
        assert!(bv1.get(3));
        assert!(bv1.get(4));
        assert!(bv1.get(17));
    }
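
    // Added example (not part of the original test suite): a quick sketch exercising
    // the inline fast paths of `merge`, `filter`, and `exclude` on small vectors.
    #[test]
    fn test_inline_set_ops() {
        let mut a = BitVector::new();
        let mut b = BitVector::new();

        a.set(1, true);
        a.set(5, true);

        b.set(5, true);
        b.set(9, true);

        let mut merged = a.clone();
        merged.merge(&b);
        assert!(merged.get(1) && merged.get(5) && merged.get(9));

        let mut filtered = a.clone();
        filtered.filter(&b);
        assert!(!filtered.get(1) && filtered.get(5) && !filtered.get(9));

        let mut excluded = a.clone();
        excluded.exclude(&b);
        assert!(excluded.get(1) && !excluded.get(5) && !excluded.get(9));
    }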
}

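// Stand-in for a strict-provenance style `with_addr`: maps the address of `this`
// through `addr` while keeping the original pointer's provenance via a byte offset.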
fn with_addr(this: *mut (), addr: impl FnOnce(isize) -> isize) -> *mut () {
    let self_addr = this as isize;
    let dest_addr = addr(self_addr);
    let offset = dest_addr.wrapping_sub(self_addr);

    this.cast::<u8>().wrapping_offset(offset).cast()
}