1use bytes::{Buf, BufMut};
14use commonware_codec::{
15 EncodeSize, Error as CodecError, FixedSize, RangeCfg, Read, ReadExt, Write,
16};
17use std::{
18 fmt::{self, Write as _},
19 ops::{BitAnd, BitOr, BitXor, Index},
20};
21
/// Storage word into which bits are packed.
type Block = u8;

/// Number of bits held by a single [`Block`].
const BITS_PER_BLOCK: usize = Block::BITS as usize;

/// A block with every bit cleared.
const EMPTY_BLOCK: Block = 0;

/// A block with every bit set.
const FULL_BLOCK: Block = !EMPTY_BLOCK;
33
34#[derive(Clone, PartialEq, Eq)]
38pub struct BitVec {
39 storage: Vec<Block>,
41 num_bits: usize,
43}
44
45impl BitVec {
46 #[inline]
48 pub fn new() -> Self {
49 BitVec {
50 storage: Vec::new(),
51 num_bits: 0,
52 }
53 }
54
55 #[inline]
57 pub fn with_capacity(size: usize) -> Self {
58 BitVec {
59 storage: Vec::with_capacity(Self::num_blocks(size)),
60 num_bits: 0,
61 }
62 }
63
64 #[inline]
66 pub fn zeroes(size: usize) -> Self {
67 BitVec {
68 storage: vec![EMPTY_BLOCK; Self::num_blocks(size)],
69 num_bits: size,
70 }
71 }
72
73 #[inline]
75 pub fn ones(size: usize) -> Self {
76 let mut result = Self {
77 storage: vec![FULL_BLOCK; Self::num_blocks(size)],
78 num_bits: size,
79 };
80 result.clear_trailing_bits();
81 result
82 }
83
84 #[inline]
86 pub fn from_bools(bools: &[bool]) -> Self {
87 let mut bv = Self::with_capacity(bools.len());
88 for &b in bools {
89 bv.push(b);
90 }
91 bv
92 }
93
94 #[inline]
96 pub fn len(&self) -> usize {
97 self.num_bits
98 }
99
100 #[inline]
102 pub fn is_empty(&self) -> bool {
103 self.num_bits == 0
104 }
105
106 #[inline]
108 pub fn push(&mut self, value: bool) {
109 let index = self.num_bits;
111 self.num_bits += 1;
112
113 if Self::block_index(index) >= self.storage.len() {
115 self.storage.push(EMPTY_BLOCK);
116 }
117
118 if value {
120 self.set_bit_unchecked(index);
121 }
122 }
123
124 #[inline]
128 pub fn pop(&mut self) -> Option<bool> {
129 if self.is_empty() {
130 return None;
131 }
132
133 self.num_bits -= 1;
135 let index = self.num_bits;
136 let value = self.get_bit_unchecked(index);
137
138 if Self::bit_offset(index) == 0 {
141 self.storage.pop().expect("Storage should not be empty");
142 } else if value {
143 self.clear_bit_unchecked(index);
144 }
145
146 Some(value)
147 }
148
149 #[inline]
153 pub fn get(&self, index: usize) -> Option<bool> {
154 if index >= self.num_bits {
155 return None;
156 }
157 Some(self.get_bit_unchecked(index))
158 }
159
160 #[inline]
166 pub unsafe fn get_unchecked(&self, index: usize) -> bool {
167 self.get_bit_unchecked(index)
168 }
169
170 #[inline]
176 pub fn set(&mut self, index: usize) {
177 self.assert_index(index);
178 self.set_bit_unchecked(index);
179 }
180
181 #[inline]
187 pub fn clear(&mut self, index: usize) {
188 self.assert_index(index);
189 self.clear_bit_unchecked(index);
190 }
191
192 #[inline]
198 pub fn toggle(&mut self, index: usize) {
199 self.assert_index(index);
200 self.toggle_bit_unchecked(index);
201 }
202
203 #[inline]
209 pub fn set_to(&mut self, index: usize, value: bool) {
210 self.assert_index(index);
211 if value {
212 self.set_bit_unchecked(index);
213 } else {
214 self.clear_bit_unchecked(index);
215 }
216 }
217
218 #[inline]
220 pub fn clear_all(&mut self) {
221 for block in &mut self.storage {
222 *block = EMPTY_BLOCK;
223 }
224 }
225
226 #[inline]
228 pub fn set_all(&mut self) {
229 for block in &mut self.storage {
230 *block = FULL_BLOCK;
231 }
232 self.clear_trailing_bits();
233 }
234
235 #[inline]
237 pub fn count_ones(&self) -> usize {
238 self.storage
239 .iter()
240 .map(|block| block.count_ones() as usize)
241 .sum()
242 }
243
244 #[inline]
246 pub fn count_zeros(&self) -> usize {
247 self.num_bits
248 .checked_sub(self.count_ones())
249 .expect("Overflow in count_zeros")
250 }
251
252 pub fn and(&mut self, other: &BitVec) {
258 self.binary_op(other, |a, b| a & b);
259 self.clear_trailing_bits();
260 }
261
262 pub fn or(&mut self, other: &BitVec) {
268 self.binary_op(other, |a, b| a | b);
269 self.clear_trailing_bits();
270 }
271
272 pub fn xor(&mut self, other: &BitVec) {
278 self.binary_op(other, |a, b| a ^ b);
279 self.clear_trailing_bits();
280 }
281
282 pub fn invert(&mut self) {
284 for block in &mut self.storage {
285 *block = !*block;
286 }
287 self.clear_trailing_bits();
288 }
289
290 pub fn iter(&self) -> BitIterator {
292 BitIterator { vec: self, pos: 0 }
293 }
294
295 #[inline(always)]
299 fn block_index(index: usize) -> usize {
300 index / BITS_PER_BLOCK
301 }
302
303 #[inline(always)]
305 fn bit_offset(index: usize) -> usize {
306 index % BITS_PER_BLOCK
307 }
308
309 #[inline(always)]
311 fn num_blocks(num_bits: usize) -> usize {
312 num_bits.div_ceil(BITS_PER_BLOCK)
313 }
314
315 #[inline(always)]
317 fn mask_over_first_n_bits(num_bits: usize) -> Block {
318 assert!(num_bits <= BITS_PER_BLOCK, "num_bits exceeds block size");
319 match num_bits {
321 BITS_PER_BLOCK => FULL_BLOCK,
322 _ => (1 << num_bits) - 1,
323 }
324 }
325
326 #[inline(always)]
327 fn get_bit_unchecked(&self, index: usize) -> bool {
328 let block_index = Self::block_index(index);
329 let bit_index = Self::bit_offset(index);
330 (self.storage[block_index] & (1 << bit_index)) != 0
331 }
332
333 #[inline(always)]
334 fn set_bit_unchecked(&mut self, index: usize) {
335 let block_index = Self::block_index(index);
336 let bit_index = Self::bit_offset(index);
337 self.storage[block_index] |= 1 << bit_index;
338 }
339
340 #[inline(always)]
341 fn clear_bit_unchecked(&mut self, index: usize) {
342 let block_index = Self::block_index(index);
343 let bit_index = Self::bit_offset(index);
344 self.storage[block_index] &= !(1 << bit_index);
345 }
346
347 #[inline(always)]
348 fn toggle_bit_unchecked(&mut self, index: usize) {
349 let block_index = Self::block_index(index);
350 let bit_index = Self::bit_offset(index);
351 self.storage[block_index] ^= 1 << bit_index;
352 }
353
354 #[inline(always)]
356 fn assert_index(&self, index: usize) {
357 assert!(index < self.num_bits, "Index out of bounds");
358 }
359
360 #[inline(always)]
362 fn assert_eq_len(&self, other: &BitVec) {
363 assert_eq!(self.num_bits, other.num_bits, "BitVec lengths don't match");
364 }
365
366 #[inline]
368 fn binary_op<F: Fn(Block, Block) -> Block>(&mut self, other: &BitVec, op: F) {
369 self.assert_eq_len(other);
370 for (a, b) in self.storage.iter_mut().zip(other.storage.iter()) {
371 *a = op(*a, *b);
372 }
373 }
374
375 #[inline]
377 fn clear_trailing_bits(&mut self) -> bool {
378 let bit_offset = Self::bit_offset(self.num_bits);
379 if bit_offset == 0 {
380 return false;
382 }
383
384 let block = self
386 .storage
387 .last_mut()
388 .expect("Storage should not be empty");
389 let old_block = *block;
390 let mask = Self::mask_over_first_n_bits(bit_offset);
391 *block &= mask;
392
393 *block != old_block
395 }
396}
397
398impl Default for BitVec {
401 fn default() -> Self {
402 Self::new()
403 }
404}
405
406impl From<Vec<bool>> for BitVec {
407 fn from(v: Vec<bool>) -> Self {
408 Self::from_bools(&v)
409 }
410}
411
412impl From<&[bool]> for BitVec {
413 fn from(s: &[bool]) -> Self {
414 Self::from_bools(s)
415 }
416}
417
418impl<const N: usize> From<[bool; N]> for BitVec {
419 fn from(arr: [bool; N]) -> Self {
420 Self::from_bools(&arr)
421 }
422}
423
424impl<const N: usize> From<&[bool; N]> for BitVec {
425 fn from(arr: &[bool; N]) -> Self {
426 Self::from_bools(arr)
427 }
428}
429
430impl From<BitVec> for Vec<bool> {
433 fn from(bv: BitVec) -> Self {
434 bv.iter().collect()
435 }
436}
437
438impl fmt::Debug for BitVec {
441 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
442 const MAX_DISPLAY: usize = 64;
444 const HALF_DISPLAY: usize = MAX_DISPLAY / 2;
445
446 let write_bit = |formatter: &mut fmt::Formatter<'_>, index: usize| -> fmt::Result {
448 formatter.write_char(if self.get_bit_unchecked(index) {
449 '1'
450 } else {
451 '0'
452 })
453 };
454
455 f.write_str("BitVec[")?;
456 if self.num_bits <= MAX_DISPLAY {
457 for i in 0..self.num_bits {
459 write_bit(f, i)?;
460 }
461 } else {
462 for i in 0..HALF_DISPLAY {
464 write_bit(f, i)?;
465 }
466
467 f.write_str("...")?;
468
469 for i in (self.num_bits - HALF_DISPLAY)..self.num_bits {
470 write_bit(f, i)?;
471 }
472 }
473 f.write_str("]")
474 }
475}
476
477impl Index<usize> for BitVec {
480 type Output = bool;
481
482 #[inline]
486 fn index(&self, index: usize) -> &Self::Output {
487 self.assert_index(index);
488 let value = self.get_bit_unchecked(index);
489 if value {
490 &true
491 } else {
492 &false
493 }
494 }
495}
496
497impl BitAnd for &BitVec {
498 type Output = BitVec;
499
500 fn bitand(self, rhs: Self) -> Self::Output {
501 self.assert_eq_len(rhs);
502 let mut result = self.clone();
503 result.and(rhs);
504 result
505 }
506}
507
508impl BitOr for &BitVec {
509 type Output = BitVec;
510
511 fn bitor(self, rhs: Self) -> Self::Output {
512 self.assert_eq_len(rhs);
513 let mut result = self.clone();
514 result.or(rhs);
515 result
516 }
517}
518
519impl BitXor for &BitVec {
520 type Output = BitVec;
521
522 fn bitxor(self, rhs: Self) -> Self::Output {
523 self.assert_eq_len(rhs);
524 let mut result = self.clone();
525 result.xor(rhs);
526 result
527 }
528}
529
530impl Write for BitVec {
533 fn write(&self, buf: &mut impl BufMut) {
534 self.num_bits.write(buf);
536
537 for &block in &self.storage {
539 block.write(buf);
540 }
541 }
542}
543
544impl Read for BitVec {
545 type Cfg = RangeCfg;
546
547 fn read_cfg(buf: &mut impl Buf, range: &Self::Cfg) -> Result<Self, CodecError> {
548 let num_bits = usize::read_cfg(buf, range)?;
550
551 let num_blocks = num_bits.div_ceil(BITS_PER_BLOCK);
553 let mut storage = Vec::with_capacity(num_blocks);
554 for _ in 0..num_blocks {
555 let block = Block::read(buf)?;
556 storage.push(block);
557 }
558
559 let mut result = BitVec { storage, num_bits };
561 if result.clear_trailing_bits() {
562 return Err(CodecError::Invalid("BitVec", "trailing bits"));
563 }
564
565 Ok(result)
566 }
567}
568
569impl EncodeSize for BitVec {
570 fn encode_size(&self) -> usize {
571 self.num_bits.encode_size() + (Block::SIZE * self.storage.len())
572 }
573}
574
575pub struct BitIterator<'a> {
579 vec: &'a BitVec,
581
582 pos: usize,
584}
585
586impl Iterator for BitIterator<'_> {
587 type Item = bool;
588
589 fn next(&mut self) -> Option<Self::Item> {
590 if self.pos >= self.vec.len() {
591 return None;
592 }
593
594 let bit = self.vec.get_bit_unchecked(self.pos);
595 self.pos += 1;
596 Some(bit)
597 }
598
599 fn size_hint(&self) -> (usize, Option<usize>) {
600 let remaining = self.vec.len() - self.pos;
601 (remaining, Some(remaining))
602 }
603}
604
605impl ExactSizeIterator for BitIterator<'_> {}
606
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;
    use commonware_codec::{Decode, Encode};

    #[test]
    fn test_constructors() {
        // `new` produces an empty vector with no storage.
        let bv = BitVec::new();
        assert_eq!(bv.len(), 0);
        assert!(bv.is_empty());
        assert_eq!(bv.storage.len(), 0);

        // `with_capacity` reserves blocks without adding bits.
        let bv = BitVec::with_capacity(100);
        assert_eq!(bv.len(), 0);
        assert!(bv.is_empty());
        assert!(bv.storage.capacity() >= BitVec::num_blocks(100));

        // `zeroes` produces `size` cleared bits.
        let bv = BitVec::zeroes(100);
        assert_eq!(bv.len(), 100);
        assert!(!bv.is_empty());
        assert_eq!(bv.count_zeros(), 100);
        for i in 0..100 {
            assert!(!bv.get(i).unwrap());
        }

        // `ones` produces `size` set bits; padding bits are not counted.
        let bv = BitVec::ones(100);
        assert_eq!(bv.len(), 100);
        assert!(!bv.is_empty());
        assert_eq!(bv.count_ones(), 100);
        for i in 0..100 {
            assert!(bv.get(i).unwrap());
        }

        // Construction from booleans: slice, Vec and array.
        let bools = [true, false, true, false, true];
        let bv = BitVec::from_bools(&bools);
        assert_eq!(bv.len(), 5);
        assert_eq!(bv.count_ones(), 3);

        let vec_bool = vec![true, false, true];
        let bv: BitVec = vec_bool.into();
        assert_eq!(bv.len(), 3);
        assert_eq!(bv.count_ones(), 2);

        let bools_array = [false, true, false];
        let bv: BitVec = bools_array.into();
        assert_eq!(bv.len(), 3);
        assert_eq!(bv.count_ones(), 1);

        // `Default` is the empty vector.
        let bv: BitVec = Default::default();
        assert_eq!(bv.len(), 0);
        assert!(bv.is_empty());
        assert_eq!(bv, BitVec::new());
    }

    #[test]
    fn test_basic_operations() {
        // Indices on either side of the first block boundary.
        let last_in_first = BITS_PER_BLOCK - 1;
        let first_in_second = BITS_PER_BLOCK;

        let mut bv = BitVec::zeroes(100);
        for i in 0..100 {
            assert_eq!(bv.get(i), Some(false));
        }

        bv.set(0);
        bv.set(50);
        bv.set(last_in_first);
        bv.set(first_in_second);
        bv.set(99);
        assert_eq!(bv.get(0), Some(true));
        assert_eq!(bv.get(50), Some(true));
        assert_eq!(bv.get(last_in_first), Some(true));
        assert_eq!(bv.get(first_in_second), Some(true));
        assert_eq!(bv.get(99), Some(true));
        assert_eq!(bv.get(30), Some(false));

        bv.clear(0);
        bv.clear(50);
        bv.clear(first_in_second);
        assert_eq!(bv.get(0), Some(false));
        assert_eq!(bv.get(50), Some(false));
        assert_eq!(bv.get(last_in_first), Some(true));
        assert_eq!(bv.get(first_in_second), Some(false));
        assert_eq!(bv.get(99), Some(true));

        bv.toggle(0);
        bv.toggle(last_in_first);
        assert_eq!(bv.get(0), Some(true));
        assert_eq!(bv.get(last_in_first), Some(false));

        bv.toggle(0);
        assert!(!bv.get(0).unwrap());
        bv.toggle(0);
        assert!(bv.get(0).unwrap());

        bv.set_to(10, true);
        bv.set_to(11, false);
        assert_eq!(bv.get(10), Some(true));
        assert_eq!(bv.get(11), Some(false));

        bv.push(true);
        assert_eq!(bv.len(), 101);
        assert!(bv.get(100).unwrap());

        bv.push(false);
        assert_eq!(bv.len(), 102);
        assert!(!bv.get(101).unwrap());

        assert_eq!(bv.pop(), Some(false));
        assert_eq!(bv.len(), 101);
        assert_eq!(bv.pop(), Some(true));
        assert_eq!(bv.len(), 100);

        // Reads past the end return `None`.
        assert_eq!(bv.get(100), None);
        assert_eq!(bv.get(1000), None);
    }

    #[test]
    fn test_push_pop_block_boundaries() {
        let mut bv = BitVec::new();
        assert_eq!(bv.pop(), None);

        // Storage must grow exactly one block at a time.
        let total = 3 * BITS_PER_BLOCK + 1;
        for i in 0..total {
            bv.push(i % 2 == 0);
            assert_eq!(bv.len(), i + 1);
            assert_eq!(bv.storage.len(), BitVec::num_blocks(i + 1));
        }

        // Popping returns bits in reverse order and releases emptied blocks.
        for i in (0..total).rev() {
            assert_eq!(bv.pop(), Some(i % 2 == 0));
            assert_eq!(bv.len(), i);
            assert_eq!(bv.storage.len(), BitVec::num_blocks(i));
        }
        assert!(bv.is_empty());
        assert_eq!(bv.pop(), None);

        // A popped `true` bit must not reappear when a `false` is pushed back.
        let mut bv = BitVec::from_bools(&[true, true]);
        assert_eq!(bv.pop(), Some(true));
        bv.push(false);
        assert_eq!(bv, BitVec::from_bools(&[true, false]));
    }

    #[test]
    fn test_conversions() {
        let original = vec![true, false, true];
        let bv: BitVec = original.clone().into();
        assert_eq!(bv.len(), 3);
        assert_eq!(bv.count_ones(), 2);

        // Every `From` source produces the same vector.
        let from_slice: BitVec = original.as_slice().into();
        let from_array_ref: BitVec = (&[true, false, true]).into();
        assert_eq!(from_slice, bv);
        assert_eq!(from_array_ref, bv);

        let converted: Vec<bool> = bv.into();
        assert_eq!(converted.len(), 3);
        assert_eq!(converted, original);
    }

    #[test]
    fn test_bitwise_operations() {
        let a = BitVec::from_bools(&[true, false, true, false, true]);
        let b = BitVec::from_bools(&[true, true, false, false, true]);

        let mut result = a.clone();
        result.and(&b);
        assert_eq!(
            result,
            BitVec::from_bools(&[true, false, false, false, true])
        );

        let mut result = a.clone();
        result.or(&b);
        assert_eq!(result, BitVec::from_bools(&[true, true, true, false, true]));

        let mut result = a.clone();
        result.xor(&b);
        assert_eq!(
            result,
            BitVec::from_bools(&[false, true, true, false, false])
        );

        let mut result = a.clone();
        result.invert();
        assert_eq!(
            result,
            BitVec::from_bools(&[false, true, false, true, false])
        );

        // Operator forms on references.
        let a_ref = &a;
        let b_ref = &b;

        let result = a_ref & b_ref;
        assert_eq!(
            result,
            BitVec::from_bools(&[true, false, false, false, true])
        );

        let result = a_ref | b_ref;
        assert_eq!(result, BitVec::from_bools(&[true, true, true, false, true]));

        let result = a_ref ^ b_ref;
        assert_eq!(
            result,
            BitVec::from_bools(&[false, true, true, false, false])
        );

        // Multi-block operands.
        let mut bv_long1 = BitVec::zeroes(70);
        bv_long1.set(0);
        bv_long1.set(65);

        let mut bv_long2 = BitVec::zeroes(70);
        bv_long2.set(1);
        bv_long2.set(65);

        let mut bv_long_and = bv_long1.clone();
        bv_long_and.and(&bv_long2);
        let mut expected_and = BitVec::zeroes(70);
        expected_and.set(65);
        assert_eq!(bv_long_and, expected_and);
    }

    #[test]
    #[should_panic(expected = "BitVec lengths don't match")]
    fn test_binary_op_length_mismatch() {
        let mut a = BitVec::zeroes(3);
        a.and(&BitVec::zeroes(4));
    }

    #[test]
    fn test_out_of_bounds_get() {
        let bv = BitVec::zeroes(10);
        assert_eq!(bv.get(10), None);
        assert_eq!(bv.get(100), None);

        let empty_bv = BitVec::new();
        assert_eq!(empty_bv.get(0), None);
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn test_set_out_of_bounds() {
        let mut bv = BitVec::zeroes(10);
        bv.set(10);
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn test_clear_out_of_bounds() {
        let mut bv = BitVec::zeroes(10);
        bv.clear(10);
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn test_toggle_out_of_bounds() {
        let mut bv = BitVec::zeroes(10);
        bv.toggle(10);
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn test_set_to_out_of_bounds() {
        let mut bv = BitVec::zeroes(10);
        bv.set_to(10, true);
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn test_index_out_of_bounds() {
        let bv = BitVec::zeroes(10);
        let _ = bv[10];
    }

    #[test]
    fn test_count_operations() {
        let bv = BitVec::from_bools(&[true, false, true, true, false, true]);
        assert_eq!(bv.count_ones(), 4);
        assert_eq!(bv.count_zeros(), 2);

        let empty = BitVec::new();
        assert_eq!(empty.count_ones(), 0);
        assert_eq!(empty.count_zeros(), 0);

        let zeroes = BitVec::zeroes(100);
        assert_eq!(zeroes.count_ones(), 0);
        assert_eq!(zeroes.count_zeros(), 100);

        let ones = BitVec::ones(100);
        assert_eq!(ones.count_ones(), 100);
        assert_eq!(ones.count_zeros(), 0);

        // Bits spread across several blocks, including both sides of a boundary.
        let mut bv_multi = BitVec::zeroes(70);
        bv_multi.set(0);
        bv_multi.set(BITS_PER_BLOCK - 1);
        bv_multi.set(BITS_PER_BLOCK);
        bv_multi.set(69);
        assert_eq!(bv_multi.count_ones(), 4);
        assert_eq!(bv_multi.count_zeros(), 66);
    }

    #[test]
    fn test_clear_set_all_invert() {
        let mut bv = BitVec::from_bools(&[true, false, true, false, true]);
        bv.set_all();
        assert_eq!(bv.len(), 5);
        assert_eq!(bv.count_ones(), 5);
        for i in 0..5 {
            assert_eq!(bv.get(i), Some(true));
        }
        // Padding bits stay clear after `set_all`.
        assert_eq!(bv.storage[0], (1 << 5) - 1);

        bv.clear_all();
        assert_eq!(bv.len(), 5);
        assert_eq!(bv.count_ones(), 0);
        assert_eq!(bv.storage[0], 0);

        bv.set(1);
        bv.set(3);
        bv.invert();
        assert_eq!(bv.count_ones(), 3);
        assert_eq!(bv.get(0), Some(true));
        assert_eq!(bv.get(1), Some(false));
        assert_eq!(bv.get(2), Some(true));
        assert_eq!(bv.get(3), Some(false));
        assert_eq!(bv.get(4), Some(true));

        // Whole blocks only.
        let full = 2 * BITS_PER_BLOCK;
        let mut bv_full = BitVec::ones(full);
        bv_full.invert();
        assert_eq!(bv_full.count_ones(), 0);

        // Partial last block.
        let partial = 2 * BITS_PER_BLOCK + 3;
        let mut bv_part = BitVec::ones(partial);
        bv_part.invert();
        assert_eq!(bv_part.count_ones(), 0);

        // Inverting zeroes must not set padding bits.
        let mut bv_zeroes = BitVec::zeroes(partial);
        bv_zeroes.invert();
        assert_eq!(bv_zeroes.count_ones(), partial);
        assert_eq!(bv_zeroes, BitVec::ones(partial));
    }

    #[test]
    fn test_mask_over_first_n_bits() {
        for i in 0..=BITS_PER_BLOCK {
            let mask = BitVec::mask_over_first_n_bits(i);
            assert_eq!(mask.count_ones() as usize, i);
            assert_eq!(mask.count_zeros() as usize, BITS_PER_BLOCK - i);
            assert_eq!(
                mask,
                ((1 as Block)
                    .checked_shl(i as u32)
                    .unwrap_or(0)
                    .wrapping_sub(1))
            );
        }
    }

    #[test]
    fn test_index_and_iterator() {
        let bools = [true, false, false, true, true];
        let bv = BitVec::from_bools(&bools);

        for (i, &expected) in bools.iter().enumerate() {
            assert_eq!(bv[i], expected);
            // SAFETY: `i < bv.len()`.
            assert_eq!(unsafe { bv.get_unchecked(i) }, expected);
        }

        let mut it = bv.iter();
        assert_eq!(it.len(), bools.len());
        assert_eq!(it.next(), Some(true));
        assert_eq!(it.size_hint(), (4, Some(4)));
        assert_eq!(it.by_ref().count(), 4);
        assert_eq!(it.next(), None);

        assert_eq!(bv.iter().collect::<Vec<_>>(), bools.to_vec());
    }

    #[test]
    fn test_debug_format() {
        assert_eq!(format!("{:?}", BitVec::new()), "BitVec[]");
        assert_eq!(
            format!("{:?}", BitVec::from_bools(&[true, false, true])),
            "BitVec[101]"
        );

        // Exactly 64 bits are printed in full.
        assert_eq!(
            format!("{:?}", BitVec::ones(64)),
            format!("BitVec[{}]", "1".repeat(64))
        );

        // Longer vectors show the first and last 32 bits around an ellipsis.
        let mut long = BitVec::zeroes(100);
        long.set(0);
        long.set(99);
        let expected = format!("BitVec[1{}...{}1]", "0".repeat(31), "0".repeat(31));
        assert_eq!(format!("{:?}", long), expected);
    }

    #[test]
    fn test_codec_roundtrip() {
        let original = BitVec::from_bools(&[true, false, true, false, true]);
        let mut buf = original.encode();
        let decoded = BitVec::decode_cfg(&mut buf, &(..).into()).unwrap();
        assert_eq!(original, decoded);

        // Lengths around block boundaries, including the empty vector.
        for len in [
            0,
            1,
            BITS_PER_BLOCK - 1,
            BITS_PER_BLOCK,
            BITS_PER_BLOCK + 1,
            3 * BITS_PER_BLOCK + 5,
        ] {
            let bools: Vec<bool> = (0..len).map(|i| i % 3 == 0).collect();
            let original = BitVec::from_bools(&bools);
            let mut buf = original.encode();
            assert_eq!(buf.len(), original.encode_size());
            let decoded = BitVec::decode_cfg(&mut buf, &(..).into()).unwrap();
            assert_eq!(original, decoded);
        }
    }

    #[test]
    fn test_codec_error_invalid_length() {
        let original = BitVec::from_bools(&[true, false, true, false, true]);
        let buf = original.encode();

        // Length above the allowed maximum.
        let mut buf_clone1 = buf.clone();
        assert!(matches!(
            BitVec::decode_cfg(&mut buf_clone1, &(..=4usize).into()),
            Err(CodecError::InvalidLength(_))
        ));

        // Length below the allowed minimum.
        let mut buf_clone2 = buf.clone();
        assert!(matches!(
            BitVec::decode_cfg(&mut buf_clone2, &(6usize..).into()),
            Err(CodecError::InvalidLength(_))
        ));
    }

    #[test]
    fn test_codec_error_trailing_bits() {
        // One bit is declared, but bit 1 (a padding bit) is set in the block.
        let mut buf = BytesMut::new();
        1usize.write(&mut buf);
        (2 as Block).write(&mut buf);
        assert!(matches!(
            BitVec::decode_cfg(&mut buf, &(..).into()),
            Err(CodecError::Invalid("BitVec", "trailing bits"))
        ));
    }

    #[test]
    fn test_codec_error_truncated_blocks() {
        // The header claims two blocks, but only one is present.
        let mut buf = BytesMut::new();
        (2 * BITS_PER_BLOCK).write(&mut buf);
        (0 as Block).write(&mut buf);
        assert!(BitVec::decode_cfg(&mut buf, &(..).into()).is_err());
    }
}