1use bytes::{Buf, BufMut};
14use commonware_codec::{
15 EncodeSize, Error as CodecError, FixedSize, RangeConfig, Read, ReadExt, Write,
16};
17use std::{
18 fmt::{self, Write as _},
19 ops::{BitAnd, BitOr, BitXor, Index},
20};
21
/// Unit of packed storage; every block carries [`BITS_PER_BLOCK`] bits.
type Block = u8;

/// Number of bits that fit in a single [`Block`].
const BITS_PER_BLOCK: usize = Block::BITS as usize;

/// A block with every bit cleared.
const EMPTY_BLOCK: Block = 0;

/// A block with every bit set.
const FULL_BLOCK: Block = !EMPTY_BLOCK;
33
/// A growable sequence of bits packed into [`Block`]s.
///
/// Equality is derived field-by-field, so two vectors compare equal only when
/// both their length and their packed blocks match. The methods that could set
/// bits past `num_bits` in the last block (`ones`, `set_all`, `invert`, the
/// bitwise combinators) clear that padding again, and decoding rejects input
/// that has it set.
#[derive(Clone, PartialEq, Eq)]
pub struct BitVec {
    /// Packed bits. Bit `i` lives in block `i / BITS_PER_BLOCK`, at position
    /// `i % BITS_PER_BLOCK` counted from the least-significant bit.
    storage: Vec<Block>,
    /// Number of logical bits held; may be less than the bit capacity of
    /// `storage` when the last block is only partially used.
    num_bits: usize,
}
44
45impl BitVec {
46 #[inline]
48 pub fn new() -> Self {
49 BitVec {
50 storage: Vec::new(),
51 num_bits: 0,
52 }
53 }
54
55 #[inline]
57 pub fn with_capacity(size: usize) -> Self {
58 BitVec {
59 storage: Vec::with_capacity(Self::num_blocks(size)),
60 num_bits: 0,
61 }
62 }
63
64 #[inline]
66 pub fn zeroes(size: usize) -> Self {
67 BitVec {
68 storage: vec![EMPTY_BLOCK; Self::num_blocks(size)],
69 num_bits: size,
70 }
71 }
72
73 #[inline]
75 pub fn ones(size: usize) -> Self {
76 let mut result = Self {
77 storage: vec![FULL_BLOCK; Self::num_blocks(size)],
78 num_bits: size,
79 };
80 result.clear_trailing_bits();
81 result
82 }
83
84 #[inline]
86 pub fn from_bools(bools: &[bool]) -> Self {
87 let mut bv = Self::with_capacity(bools.len());
88 for &b in bools {
89 bv.push(b);
90 }
91 bv
92 }
93
94 #[inline]
96 pub fn len(&self) -> usize {
97 self.num_bits
98 }
99
100 #[inline]
102 pub fn is_empty(&self) -> bool {
103 self.num_bits == 0
104 }
105
106 #[inline]
108 pub fn push(&mut self, value: bool) {
109 let index = self.num_bits;
111 self.num_bits += 1;
112
113 if Self::block_index(index) >= self.storage.len() {
115 self.storage.push(EMPTY_BLOCK);
116 }
117
118 if value {
120 self.set_bit_unchecked(index);
121 }
122 }
123
124 #[inline]
128 pub fn pop(&mut self) -> Option<bool> {
129 if self.is_empty() {
130 return None;
131 }
132
133 self.num_bits -= 1;
135 let index = self.num_bits;
136 let value = self.get_bit_unchecked(index);
137
138 if Self::bit_offset(index) == 0 {
141 self.storage.pop().expect("Storage should not be empty");
142 } else if value {
143 self.clear_bit_unchecked(index);
144 }
145
146 Some(value)
147 }
148
149 #[inline]
153 pub fn get(&self, index: usize) -> Option<bool> {
154 if index >= self.num_bits {
155 return None;
156 }
157 Some(self.get_bit_unchecked(index))
158 }
159
160 #[inline]
166 pub unsafe fn get_unchecked(&self, index: usize) -> bool {
167 self.get_bit_unchecked(index)
168 }
169
170 #[inline]
176 pub fn set(&mut self, index: usize) {
177 self.assert_index(index);
178 self.set_bit_unchecked(index);
179 }
180
181 #[inline]
187 pub fn clear(&mut self, index: usize) {
188 self.assert_index(index);
189 self.clear_bit_unchecked(index);
190 }
191
192 #[inline]
198 pub fn toggle(&mut self, index: usize) {
199 self.assert_index(index);
200 self.toggle_bit_unchecked(index);
201 }
202
203 #[inline]
209 pub fn set_to(&mut self, index: usize, value: bool) {
210 self.assert_index(index);
211 if value {
212 self.set_bit_unchecked(index);
213 } else {
214 self.clear_bit_unchecked(index);
215 }
216 }
217
218 #[inline]
220 pub fn clear_all(&mut self) {
221 for block in &mut self.storage {
222 *block = EMPTY_BLOCK;
223 }
224 }
225
226 #[inline]
228 pub fn set_all(&mut self) {
229 for block in &mut self.storage {
230 *block = FULL_BLOCK;
231 }
232 self.clear_trailing_bits();
233 }
234
235 #[inline]
237 pub fn count_ones(&self) -> usize {
238 self.storage
239 .iter()
240 .map(|block| block.count_ones() as usize)
241 .sum()
242 }
243
244 #[inline]
246 pub fn count_zeros(&self) -> usize {
247 self.num_bits
248 .checked_sub(self.count_ones())
249 .expect("Overflow in count_zeros")
250 }
251
252 pub fn and(&mut self, other: &BitVec) {
258 self.binary_op(other, |a, b| a & b);
259 self.clear_trailing_bits();
260 }
261
262 pub fn or(&mut self, other: &BitVec) {
268 self.binary_op(other, |a, b| a | b);
269 self.clear_trailing_bits();
270 }
271
272 pub fn xor(&mut self, other: &BitVec) {
278 self.binary_op(other, |a, b| a ^ b);
279 self.clear_trailing_bits();
280 }
281
282 pub fn invert(&mut self) {
284 for block in &mut self.storage {
285 *block = !*block;
286 }
287 self.clear_trailing_bits();
288 }
289
290 pub fn iter(&self) -> BitIterator {
292 BitIterator { vec: self, pos: 0 }
293 }
294
295 #[inline(always)]
299 fn block_index(index: usize) -> usize {
300 index / BITS_PER_BLOCK
301 }
302
303 #[inline(always)]
305 fn bit_offset(index: usize) -> usize {
306 index % BITS_PER_BLOCK
307 }
308
309 #[inline(always)]
311 fn num_blocks(num_bits: usize) -> usize {
312 num_bits.div_ceil(BITS_PER_BLOCK)
313 }
314
315 #[inline(always)]
317 fn mask_over_first_n_bits(num_bits: usize) -> Block {
318 assert!(num_bits <= BITS_PER_BLOCK, "num_bits exceeds block size");
319 match num_bits {
321 BITS_PER_BLOCK => FULL_BLOCK,
322 _ => (1 << num_bits) - 1,
323 }
324 }
325
326 #[inline(always)]
327 fn get_bit_unchecked(&self, index: usize) -> bool {
328 let block_index = Self::block_index(index);
329 let bit_index = Self::bit_offset(index);
330 (self.storage[block_index] & (1 << bit_index)) != 0
331 }
332
333 #[inline(always)]
334 fn set_bit_unchecked(&mut self, index: usize) {
335 let block_index = Self::block_index(index);
336 let bit_index = Self::bit_offset(index);
337 self.storage[block_index] |= 1 << bit_index;
338 }
339
340 #[inline(always)]
341 fn clear_bit_unchecked(&mut self, index: usize) {
342 let block_index = Self::block_index(index);
343 let bit_index = Self::bit_offset(index);
344 self.storage[block_index] &= !(1 << bit_index);
345 }
346
347 #[inline(always)]
348 fn toggle_bit_unchecked(&mut self, index: usize) {
349 let block_index = Self::block_index(index);
350 let bit_index = Self::bit_offset(index);
351 self.storage[block_index] ^= 1 << bit_index;
352 }
353
354 #[inline(always)]
356 fn assert_index(&self, index: usize) {
357 assert!(index < self.num_bits, "Index out of bounds");
358 }
359
360 #[inline(always)]
362 fn assert_eq_len(&self, other: &BitVec) {
363 assert_eq!(self.num_bits, other.num_bits, "BitVec lengths don't match");
364 }
365
366 #[inline]
368 fn binary_op<F: Fn(Block, Block) -> Block>(&mut self, other: &BitVec, op: F) {
369 self.assert_eq_len(other);
370 for (a, b) in self.storage.iter_mut().zip(other.storage.iter()) {
371 *a = op(*a, *b);
372 }
373 }
374
375 #[inline]
377 fn clear_trailing_bits(&mut self) -> bool {
378 let bit_offset = Self::bit_offset(self.num_bits);
379 if bit_offset == 0 {
380 return false;
382 }
383
384 let block = self
386 .storage
387 .last_mut()
388 .expect("Storage should not be empty");
389 let old_block = *block;
390 let mask = Self::mask_over_first_n_bits(bit_offset);
391 *block &= mask;
392
393 *block != old_block
395 }
396}
397
398impl Default for BitVec {
401 fn default() -> Self {
402 Self::new()
403 }
404}
405
406impl From<Vec<bool>> for BitVec {
407 fn from(v: Vec<bool>) -> Self {
408 Self::from_bools(&v)
409 }
410}
411
412impl From<&[bool]> for BitVec {
413 fn from(s: &[bool]) -> Self {
414 Self::from_bools(s)
415 }
416}
417
418impl<const N: usize> From<[bool; N]> for BitVec {
419 fn from(arr: [bool; N]) -> Self {
420 Self::from_bools(&arr)
421 }
422}
423
424impl<const N: usize> From<&[bool; N]> for BitVec {
425 fn from(arr: &[bool; N]) -> Self {
426 Self::from_bools(arr)
427 }
428}
429
430impl From<BitVec> for Vec<bool> {
433 fn from(bv: BitVec) -> Self {
434 bv.iter().collect()
435 }
436}
437
438impl fmt::Debug for BitVec {
441 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
442 const MAX_DISPLAY: usize = 64;
444 const HALF_DISPLAY: usize = MAX_DISPLAY / 2;
445
446 let write_bit = |formatter: &mut fmt::Formatter<'_>, index: usize| -> fmt::Result {
448 formatter.write_char(if self.get_bit_unchecked(index) {
449 '1'
450 } else {
451 '0'
452 })
453 };
454
455 f.write_str("BitVec[")?;
456 if self.num_bits <= MAX_DISPLAY {
457 for i in 0..self.num_bits {
459 write_bit(f, i)?;
460 }
461 } else {
462 for i in 0..HALF_DISPLAY {
464 write_bit(f, i)?;
465 }
466
467 f.write_str("...")?;
468
469 for i in (self.num_bits - HALF_DISPLAY)..self.num_bits {
470 write_bit(f, i)?;
471 }
472 }
473 f.write_str("]")
474 }
475}
476
477impl Index<usize> for BitVec {
480 type Output = bool;
481
482 #[inline]
486 fn index(&self, index: usize) -> &Self::Output {
487 self.assert_index(index);
488 let value = self.get_bit_unchecked(index);
489 if value {
490 &true
491 } else {
492 &false
493 }
494 }
495}
496
497impl BitAnd for &BitVec {
498 type Output = BitVec;
499
500 fn bitand(self, rhs: Self) -> Self::Output {
501 self.assert_eq_len(rhs);
502 let mut result = self.clone();
503 result.and(rhs);
504 result
505 }
506}
507
508impl BitOr for &BitVec {
509 type Output = BitVec;
510
511 fn bitor(self, rhs: Self) -> Self::Output {
512 self.assert_eq_len(rhs);
513 let mut result = self.clone();
514 result.or(rhs);
515 result
516 }
517}
518
519impl BitXor for &BitVec {
520 type Output = BitVec;
521
522 fn bitxor(self, rhs: Self) -> Self::Output {
523 self.assert_eq_len(rhs);
524 let mut result = self.clone();
525 result.xor(rhs);
526 result
527 }
528}
529
530impl Write for BitVec {
533 fn write(&self, buf: &mut impl BufMut) {
534 self.num_bits.write(buf);
536
537 for &block in &self.storage {
539 block.write(buf);
540 }
541 }
542}
543
544impl<R: RangeConfig> Read<R> for BitVec {
545 fn read_cfg(buf: &mut impl Buf, range: &R) -> Result<Self, CodecError> {
546 let num_bits = usize::read_cfg(buf, range)?;
548
549 let num_blocks = num_bits.div_ceil(BITS_PER_BLOCK);
551 let mut storage = Vec::with_capacity(num_blocks);
552 for _ in 0..num_blocks {
553 let block = Block::read(buf)?;
554 storage.push(block);
555 }
556
557 let mut result = BitVec { storage, num_bits };
559 if result.clear_trailing_bits() {
560 return Err(CodecError::Invalid("BitVec", "trailing bits"));
561 }
562
563 Ok(result)
564 }
565}
566
567impl EncodeSize for BitVec {
568 fn encode_size(&self) -> usize {
569 self.num_bits.encode_size() + (Block::SIZE * self.storage.len())
570 }
571}
572
/// Iterator over the bits of a [`BitVec`], yielding each bit as a `bool` in
/// index order.
pub struct BitIterator<'a> {
    /// The vector being iterated.
    vec: &'a BitVec,

    /// Index of the next bit to yield.
    pos: usize,
}
583
584impl Iterator for BitIterator<'_> {
585 type Item = bool;
586
587 fn next(&mut self) -> Option<Self::Item> {
588 if self.pos >= self.vec.len() {
589 return None;
590 }
591
592 let bit = self.vec.get_bit_unchecked(self.pos);
593 self.pos += 1;
594 Some(bit)
595 }
596
597 fn size_hint(&self) -> (usize, Option<usize>) {
598 let remaining = self.vec.len() - self.pos;
599 (remaining, Some(remaining))
600 }
601}
602
// `size_hint` returns identical lower and upper bounds, which is what
// `ExactSizeIterator` uses to answer `len()`.
impl ExactSizeIterator for BitIterator<'_> {}
604
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;
    use commonware_codec::{Decode, Encode};

    /// Each constructor and `From` conversion yields the expected length and
    /// contents.
    #[test]
    fn test_constructors() {
        // `new`: no bits and no blocks.
        let empty = BitVec::new();
        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);
        assert_eq!(empty.storage.len(), 0);

        // `with_capacity`: still no bits, but enough blocks reserved.
        let reserved = BitVec::with_capacity(100);
        assert!(reserved.is_empty());
        assert_eq!(reserved.len(), 0);
        assert!(BitVec::num_blocks(100) <= reserved.storage.capacity());

        // `zeroes`: every bit reads back as false.
        let zeroed = BitVec::zeroes(100);
        assert!(!zeroed.is_empty());
        assert_eq!(zeroed.len(), 100);
        assert_eq!(zeroed.count_zeros(), 100);
        assert!((0..100).all(|i| zeroed.get(i) == Some(false)));

        // `ones`: every bit reads back as true.
        let filled = BitVec::ones(100);
        assert!(!filled.is_empty());
        assert_eq!(filled.len(), 100);
        assert_eq!(filled.count_ones(), 100);
        assert!((0..100).all(|i| filled.get(i) == Some(true)));

        // `from_bools` on a borrowed array.
        let pattern = [true, false, true, false, true];
        let packed = BitVec::from_bools(&pattern);
        assert_eq!(packed.len(), 5);
        assert_eq!(packed.count_ones(), 3);

        // `From<Vec<bool>>`.
        let from_vec = BitVec::from(vec![true, false, true]);
        assert_eq!(from_vec.len(), 3);
        assert_eq!(from_vec.count_ones(), 2);

        // `From<[bool; N]>`.
        let from_array = BitVec::from([false, true, false]);
        assert_eq!(from_array.len(), 3);
        assert_eq!(from_array.count_ones(), 1);

        // `Default`.
        let defaulted = BitVec::default();
        assert!(defaulted.is_empty());
        assert_eq!(defaulted.len(), 0);
    }

    /// Walks one vector through set / clear / toggle / set_to / push / pop and
    /// checks the observable bits after every step.
    #[test]
    fn test_basic_operations() {
        let mut bits = BitVec::zeroes(100);

        // Starts out all false.
        assert!((0..100).all(|i| bits.get(i) == Some(false)));

        // Set a spread of positions, including the first and last index.
        let raised = [0, 50, 63, 64, 99];
        for i in raised {
            bits.set(i);
        }
        for i in raised {
            assert_eq!(bits.get(i), Some(true));
        }
        assert_eq!(bits.get(30), Some(false));

        // Clear some of them again; the rest must be untouched.
        for i in [0, 50, 64] {
            bits.clear(i);
        }
        for (i, expected) in [(0, false), (50, false), (63, true), (64, false), (99, true)] {
            assert_eq!(bits.get(i), Some(expected));
        }

        // Toggle flips in both directions.
        bits.toggle(0);
        bits.toggle(63);
        assert_eq!(bits.get(0), Some(true));
        assert_eq!(bits.get(63), Some(false));

        // Toggling twice restores the value.
        bits.toggle(0);
        assert_eq!(bits.get(0), Some(false));
        bits.toggle(0);
        assert_eq!(bits.get(0), Some(true));

        // `set_to` writes the requested value.
        bits.set_to(10, true);
        bits.set_to(11, false);
        assert_eq!(bits.get(10), Some(true));
        assert_eq!(bits.get(11), Some(false));

        // Push grows the vector by one bit each time.
        bits.push(true);
        assert_eq!(bits.len(), 101);
        assert_eq!(bits.get(100), Some(true));

        bits.push(false);
        assert_eq!(bits.len(), 102);
        assert_eq!(bits.get(101), Some(false));

        // Pop returns the bits in reverse push order.
        assert_eq!(bits.pop(), Some(false));
        assert_eq!(bits.len(), 101);
        assert_eq!(bits.pop(), Some(true));
        assert_eq!(bits.len(), 100);

        // Anything at or past the length is `None`.
        for i in [100, 1000] {
            assert_eq!(bits.get(i), None);
        }
    }

    /// `Vec<bool>` -> `BitVec` -> `Vec<bool>` is lossless.
    #[test]
    fn test_conversions() {
        let source = vec![true, false, true];
        let packed = BitVec::from(source.clone());
        assert_eq!(packed.len(), 3);
        assert_eq!(packed.count_ones(), 2);

        let unpacked = Vec::<bool>::from(packed);
        assert_eq!(unpacked.len(), 3);
        assert_eq!(unpacked, source);
    }

    /// In-place combinators and the `&`, `|`, `^` operators agree with
    /// hand-computed results.
    #[test]
    fn test_bitwise_operations() {
        let lhs = BitVec::from_bools(&[true, false, true, false, true]);
        let rhs = BitVec::from_bools(&[true, true, false, false, true]);

        let expect_and = BitVec::from_bools(&[true, false, false, false, true]);
        let expect_or = BitVec::from_bools(&[true, true, true, false, true]);
        let expect_xor = BitVec::from_bools(&[false, true, true, false, false]);
        let expect_not = BitVec::from_bools(&[false, true, false, true, false]);

        // In-place AND.
        let mut scratch = lhs.clone();
        scratch.and(&rhs);
        assert_eq!(scratch, expect_and);

        // In-place OR.
        let mut scratch = lhs.clone();
        scratch.or(&rhs);
        assert_eq!(scratch, expect_or);

        // In-place XOR.
        let mut scratch = lhs.clone();
        scratch.xor(&rhs);
        assert_eq!(scratch, expect_xor);

        // In-place inversion.
        let mut scratch = lhs.clone();
        scratch.invert();
        assert_eq!(scratch, expect_not);

        // Operator forms on references leave the operands untouched.
        assert_eq!(&lhs & &rhs, expect_and);
        assert_eq!(&lhs | &rhs, expect_or);
        assert_eq!(&lhs ^ &rhs, expect_xor);

        // AND across a vector that spans several blocks.
        let mut wide_a = BitVec::zeroes(70);
        wide_a.set(0);
        wide_a.set(65);

        let mut wide_b = BitVec::zeroes(70);
        wide_b.set(1);
        wide_b.set(65);

        let mut wide_and = wide_a.clone();
        wide_and.and(&wide_b);

        let mut wide_expected = BitVec::zeroes(70);
        wide_expected.set(65);
        assert_eq!(wide_and, wide_expected);
    }

    /// `get` answers `None` for any index at or beyond the length.
    #[test]
    fn test_out_of_bounds_get() {
        let ten = BitVec::zeroes(10);
        for i in [10, 100] {
            assert_eq!(ten.get(i), None);
        }

        let nothing = BitVec::new();
        assert_eq!(nothing.get(0), None);
    }

    /// `set` panics on an index equal to the length.
    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn test_set_out_of_bounds() {
        let mut ten = BitVec::zeroes(10);
        ten.set(10);
    }

    /// `clear` panics on an index equal to the length.
    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn test_clear_out_of_bounds() {
        let mut ten = BitVec::zeroes(10);
        ten.clear(10);
    }

    /// `toggle` panics on an index equal to the length.
    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn test_toggle_out_of_bounds() {
        let mut ten = BitVec::zeroes(10);
        ten.toggle(10);
    }

    /// `set_to` panics on an index equal to the length.
    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn test_set_to_out_of_bounds() {
        let mut ten = BitVec::zeroes(10);
        ten.set_to(10, true);
    }

    /// Indexing with `[]` panics on an index equal to the length.
    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn test_index_out_of_bounds() {
        let ten = BitVec::zeroes(10);
        let _ = ten[10];
    }

    /// `count_ones` / `count_zeros` on mixed, empty, uniform and multi-block
    /// vectors.
    #[test]
    fn test_count_operations() {
        // Mixed contents.
        let mixed = BitVec::from_bools(&[true, false, true, true, false, true]);
        assert_eq!(mixed.count_ones(), 4);
        assert_eq!(mixed.count_zeros(), 2);

        // Empty vector.
        let nothing = BitVec::new();
        assert_eq!(nothing.count_ones(), 0);
        assert_eq!(nothing.count_zeros(), 0);

        // Uniform vectors.
        let all_clear = BitVec::zeroes(100);
        assert_eq!(all_clear.count_ones(), 0);
        assert_eq!(all_clear.count_zeros(), 100);

        let all_set = BitVec::ones(100);
        assert_eq!(all_set.count_ones(), 100);
        assert_eq!(all_set.count_zeros(), 0);

        // A few bits scattered over several blocks.
        let mut sparse = BitVec::zeroes(70);
        for i in [0, 63, 64, 69] {
            sparse.set(i);
        }
        assert_eq!(sparse.count_ones(), 4);
        assert_eq!(sparse.count_zeros(), 66);
    }

    /// `set_all`, `clear_all` and `invert` keep the length and never leave
    /// padding bits set.
    #[test]
    fn test_clear_set_all_invert() {
        let mut bits = BitVec::from_bools(&[true, false, true, false, true]);

        // set_all: five ones, and only the low five bits of the block are set.
        bits.set_all();
        assert_eq!(bits.len(), 5);
        assert_eq!(bits.count_ones(), 5);
        assert!((0..5).all(|i| bits.get(i) == Some(true)));
        assert_eq!(bits.storage[0], 0b1_1111);

        // clear_all: everything zero, length unchanged.
        bits.clear_all();
        assert_eq!(bits.len(), 5);
        assert_eq!(bits.count_ones(), 0);
        assert_eq!(bits.storage[0], 0);

        // invert: 01010 becomes 10101.
        bits.set(1);
        bits.set(3);
        bits.invert();
        assert_eq!(bits.count_ones(), 3);
        for (i, &expected) in [true, false, true, false, true].iter().enumerate() {
            assert_eq!(bits.get(i), Some(expected));
        }

        // Inverting all-ones gives all-zeros when the last block is full...
        let mut whole_blocks = BitVec::ones(64);
        whole_blocks.invert();
        assert_eq!(whole_blocks.count_ones(), 0);

        // ...and when it is only partially used.
        let mut partial_block = BitVec::ones(67);
        partial_block.invert();
        assert_eq!(partial_block.count_ones(), 0);
    }

    /// The mask helper sets exactly the requested number of low bits for every
    /// legal width.
    #[test]
    fn test_mask_over_first_n_bits() {
        for width in 0..=BITS_PER_BLOCK {
            let mask = BitVec::mask_over_first_n_bits(width);
            assert_eq!(mask.count_ones() as usize, width);
            assert_eq!(mask.count_zeros() as usize, BITS_PER_BLOCK - width);

            // Reference value: (1 << width) - 1, with the full-width shift
            // wrapping to an all-ones block.
            let reference = (1 as Block)
                .checked_shl(width as u32)
                .unwrap_or(0)
                .wrapping_sub(1);
            assert_eq!(mask, reference);
        }
    }

    /// Encoding then decoding returns an equal vector.
    #[test]
    fn test_codec_roundtrip() {
        let original = BitVec::from_bools(&[true, false, true, false, true]);
        let mut encoded = original.encode();
        let decoded = BitVec::decode_cfg(&mut encoded, &..).unwrap();
        assert_eq!(original, decoded);
    }

    /// Decoding fails with `InvalidLength` when the five-bit length falls
    /// outside the permitted range.
    #[test]
    fn test_codec_error_invalid_length() {
        let original = BitVec::from_bools(&[true, false, true, false, true]);
        let encoded = original.encode();

        // Upper bound below the actual length.
        let mut capped = encoded.clone();
        assert!(matches!(
            BitVec::decode_cfg(&mut capped, &..=4),
            Err(CodecError::InvalidLength(_))
        ));

        // Lower bound above the actual length.
        let mut floored = encoded.clone();
        assert!(matches!(
            BitVec::decode_cfg(&mut floored, &(6..)),
            Err(CodecError::InvalidLength(_))
        ));
    }

    /// Decoding rejects a block with bits set past the declared length.
    #[test]
    fn test_codec_error_trailing_bits() {
        let mut encoded = BytesMut::new();
        // Declare a single bit...
        1usize.write(&mut encoded);
        // ...but supply a block whose second bit is set.
        (2 as Block).write(&mut encoded);

        assert!(matches!(
            BitVec::decode_cfg(&mut encoded, &..),
            Err(CodecError::Invalid("BitVec", "trailing bits"))
        ));
    }
}