1use crate::tree_hash::bitfield_bytes_tree_hash_root;
2use crate::Error;
3use core::marker::PhantomData;
4use eth2_serde_utils::hex::{encode as hex_encode, PrefixedHexVisitor};
5use serde::de::{Deserialize, Deserializer};
6use serde::ser::{Serialize, Serializer};
7use ssz::{Decode, Encode};
8use tree_hash::Hash256;
9use typenum::Unsigned;
10
/// Marker trait for the type-level tags that select a flavour of [`Bitfield`].
///
/// It declares no methods; its only requirement is that the implementing tag is `Clone`.
/// In this file it is implemented by [`Variable`] and [`Fixed`].
pub trait BitfieldBehaviour
where
    Self: Clone,
{
}
13
/// Type-level tag used as the behaviour parameter of [`Bitfield`] to form a [`BitList`].
///
/// `N` is carried purely at the type level; no value of `N` is ever stored.
#[derive(Debug, Clone, PartialEq)]
pub struct Variable<N> {
    /// Zero-sized marker tying the tag to `N`.
    _phantom: PhantomData<N>,
}
21
/// Type-level tag used as the behaviour parameter of [`Bitfield`] to form a [`BitVector`].
///
/// `N` is carried purely at the type level; no value of `N` is ever stored.
#[derive(Debug, Clone, PartialEq)]
pub struct Fixed<N> {
    /// Zero-sized marker tying the tag to `N`.
    _phantom: PhantomData<N>,
}
29
30impl<N: Unsigned + Clone> BitfieldBehaviour for Variable<N> {}
31impl<N: Unsigned + Clone> BitfieldBehaviour for Fixed<N> {}
32
33pub type BitList<N> = Bitfield<Variable<N>>;
35
36pub type BitVector<N> = Bitfield<Fixed<N>>;
40
/// A heap-allocated sequence of bits, packed eight to a byte.
///
/// Bit `i` lives in byte `i / 8` at bit position `i % 8` (least-significant bit first). The
/// behaviour parameter `T` ([`Variable`] or [`Fixed`]) selects which constructors and
/// encodings are available.
#[derive(Debug, Clone, PartialEq)]
pub struct Bitfield<T> {
    /// Packed bit storage.
    bytes: Vec<u8>,
    /// Number of addressable bits in `bytes`.
    len: usize,
    /// Zero-sized marker for the behaviour tag.
    _phantom: PhantomData<T>,
}
96
97impl<N: Unsigned + Clone> Bitfield<Variable<N>> {
98 pub fn with_capacity(num_bits: usize) -> Result<Self, Error> {
105 if num_bits <= N::to_usize() {
106 Ok(Self {
107 bytes: vec![0; bytes_for_bit_len(num_bits)],
108 len: num_bits,
109 _phantom: PhantomData,
110 })
111 } else {
112 Err(Error::OutOfBounds {
113 i: Self::max_len(),
114 len: Self::max_len(),
115 })
116 }
117 }
118
119 pub fn max_len() -> usize {
121 N::to_usize()
122 }
123
124 pub fn into_bytes(self) -> Vec<u8> {
140 let len = self.len();
141 let mut bytes = self.bytes;
142
143 bytes.resize(bytes_for_bit_len(len + 1), 0);
144
145 let mut bitfield: Bitfield<Variable<N>> = Bitfield::from_raw_bytes(bytes, len + 1)
146 .unwrap_or_else(|_| {
147 unreachable!(
148 "Bitfield with {} bytes must have enough capacity for {} bits.",
149 bytes_for_bit_len(len + 1),
150 len + 1
151 )
152 });
153 bitfield
154 .set(len, true)
155 .expect("len must be in bounds for bitfield.");
156
157 bitfield.bytes
158 }
159
160 pub fn from_bytes(bytes: Vec<u8>) -> Result<Self, Error> {
165 let bytes_len = bytes.len();
166 let mut initial_bitfield: Bitfield<Variable<N>> = {
167 let num_bits = bytes.len() * 8;
168 Bitfield::from_raw_bytes(bytes, num_bits)?
169 };
170
171 let len = initial_bitfield
172 .highest_set_bit()
173 .ok_or(Error::MissingLengthInformation)?;
174
175 if len / 8 + 1 != bytes_len {
177 return Err(Error::InvalidByteCount {
178 given: bytes_len,
179 expected: len / 8 + 1,
180 });
181 }
182
183 if len <= Self::max_len() {
184 initial_bitfield
185 .set(len, false)
186 .expect("Bit has been confirmed to exist");
187
188 let mut bytes = initial_bitfield.into_raw_bytes();
189
190 bytes.truncate(bytes_for_bit_len(len));
191
192 Self::from_raw_bytes(bytes, len)
193 } else {
194 Err(Error::OutOfBounds {
195 i: Self::max_len(),
196 len: Self::max_len(),
197 })
198 }
199 }
200
201 pub fn intersection(&self, other: &Self) -> Self {
205 let min_len = std::cmp::min(self.len(), other.len());
206 let mut result = Self::with_capacity(min_len).expect("min len always less than N");
207 for i in 0..result.bytes.len() {
211 result.bytes[i] = self.bytes[i] & other.bytes[i];
212 }
213 result
214 }
215
216 pub fn union(&self, other: &Self) -> Self {
220 let max_len = std::cmp::max(self.len(), other.len());
221 let mut result = Self::with_capacity(max_len).expect("max len always less than N");
222 for i in 0..result.bytes.len() {
223 result.bytes[i] =
224 self.bytes.get(i).copied().unwrap_or(0) | other.bytes.get(i).copied().unwrap_or(0);
225 }
226 result
227 }
228}
229
230impl<N: Unsigned + Clone> Bitfield<Fixed<N>> {
231 pub fn new() -> Self {
235 Self {
236 bytes: vec![0; bytes_for_bit_len(Self::capacity())],
237 len: Self::capacity(),
238 _phantom: PhantomData,
239 }
240 }
241
242 pub fn capacity() -> usize {
244 N::to_usize()
245 }
246
247 pub fn into_bytes(self) -> Vec<u8> {
260 self.into_raw_bytes()
261 }
262
263 pub fn from_bytes(bytes: Vec<u8>) -> Result<Self, Error> {
268 Self::from_raw_bytes(bytes, Self::capacity())
269 }
270
271 pub fn intersection(&self, other: &Self) -> Self {
275 let mut result = Self::new();
276 for i in 0..result.bytes.len() {
280 result.bytes[i] = self.bytes[i] & other.bytes[i];
281 }
282 result
283 }
284
285 pub fn union(&self, other: &Self) -> Self {
289 let mut result = Self::new();
290 for i in 0..result.bytes.len() {
291 result.bytes[i] =
292 self.bytes.get(i).copied().unwrap_or(0) | other.bytes.get(i).copied().unwrap_or(0);
293 }
294 result
295 }
296}
297
298impl<N: Unsigned + Clone> Default for Bitfield<Fixed<N>> {
299 fn default() -> Self {
300 Self::new()
301 }
302}
303
304impl<T: BitfieldBehaviour> Bitfield<T> {
305 pub fn set(&mut self, i: usize, value: bool) -> Result<(), Error> {
309 let len = self.len;
310
311 if i < len {
312 let byte = self
313 .bytes
314 .get_mut(i / 8)
315 .ok_or(Error::OutOfBounds { i, len })?;
316
317 if value {
318 *byte |= 1 << (i % 8)
319 } else {
320 *byte &= !(1 << (i % 8))
321 }
322
323 Ok(())
324 } else {
325 Err(Error::OutOfBounds { i, len: self.len })
326 }
327 }
328
329 pub fn get(&self, i: usize) -> Result<bool, Error> {
333 if i < self.len {
334 let byte = self
335 .bytes
336 .get(i / 8)
337 .ok_or(Error::OutOfBounds { i, len: self.len })?;
338
339 Ok(*byte & 1 << (i % 8) > 0)
340 } else {
341 Err(Error::OutOfBounds { i, len: self.len })
342 }
343 }
344
345 pub fn len(&self) -> usize {
347 self.len
348 }
349
350 pub fn is_empty(&self) -> bool {
352 self.len == 0
353 }
354
355 pub fn into_raw_bytes(self) -> Vec<u8> {
357 self.bytes
358 }
359
360 pub fn as_slice(&self) -> &[u8] {
362 &self.bytes
363 }
364
365 fn from_raw_bytes(bytes: Vec<u8>, bit_len: usize) -> Result<Self, Error> {
374 if bit_len == 0 {
375 if bytes.len() == 1 && bytes == [0] {
376 Ok(Self {
378 bytes,
379 len: 0,
380 _phantom: PhantomData,
381 })
382 } else {
383 Err(Error::ExcessBits)
384 }
385 } else if bytes.len() != bytes_for_bit_len(bit_len) {
386 Err(Error::InvalidByteCount {
388 given: bytes.len(),
389 expected: bytes_for_bit_len(bit_len),
390 })
391 } else {
392 let (mask, _) = u8::max_value().overflowing_shr(8 - (bit_len as u32 % 8));
394
395 if (bytes.last().expect("Guarded against empty bytes") & !mask) == 0 {
396 Ok(Self {
397 bytes,
398 len: bit_len,
399 _phantom: PhantomData,
400 })
401 } else {
402 Err(Error::ExcessBits)
403 }
404 }
405 }
406
407 pub fn highest_set_bit(&self) -> Option<usize> {
410 self.bytes
411 .iter()
412 .enumerate()
413 .rev()
414 .find(|(_, byte)| **byte > 0)
415 .map(|(i, byte)| i * 8 + 7 - byte.leading_zeros() as usize)
416 }
417
418 pub fn iter(&self) -> BitIter<'_, T> {
420 BitIter {
421 bitfield: self,
422 i: 0,
423 }
424 }
425
426 pub fn is_zero(&self) -> bool {
428 self.bytes.iter().all(|byte| *byte == 0)
429 }
430
431 pub fn num_set_bits(&self) -> usize {
433 self.bytes
434 .iter()
435 .map(|byte| byte.count_ones() as usize)
436 .sum()
437 }
438
439 pub fn difference(&self, other: &Self) -> Self {
441 let mut result = self.clone();
442 result.difference_inplace(other);
443 result
444 }
445
446 pub fn difference_inplace(&mut self, other: &Self) {
448 let min_byte_len = std::cmp::min(self.bytes.len(), other.bytes.len());
449
450 for i in 0..min_byte_len {
451 self.bytes[i] &= !other.bytes[i];
452 }
453 }
454
455 pub fn shift_up(&mut self, n: usize) -> Result<(), Error> {
459 if n <= self.len() {
460 for i in (n..self.len()).rev() {
462 self.set(i, self.get(i - n)?)?;
463 }
464 for i in 0..n {
466 self.set(i, false).unwrap();
467 }
468 Ok(())
469 } else {
470 Err(Error::OutOfBounds {
471 i: n,
472 len: self.len(),
473 })
474 }
475 }
476}
477
/// Returns the number of bytes needed to store `bit_len` bits, never less than one.
///
/// A zero-length bitfield is still represented by a single byte. The rounding is computed
/// without adding to `bit_len`, so it cannot overflow even for `usize::MAX`.
fn bytes_for_bit_len(bit_len: usize) -> usize {
    let full_bytes = bit_len / 8;
    // One extra byte when the bits do not fill a whole number of bytes.
    let partial_byte = usize::from(bit_len % 8 != 0);
    std::cmp::max(1, full_bytes + partial_byte)
}
484
485pub struct BitIter<'a, T> {
487 bitfield: &'a Bitfield<T>,
488 i: usize,
489}
490
491impl<'a, T: BitfieldBehaviour> Iterator for BitIter<'a, T> {
492 type Item = bool;
493
494 fn next(&mut self) -> Option<Self::Item> {
495 let res = self.bitfield.get(self.i).ok()?;
496 self.i += 1;
497 Some(res)
498 }
499}
500
501impl<N: Unsigned + Clone> Encode for Bitfield<Variable<N>> {
502 fn is_ssz_fixed_len() -> bool {
503 false
504 }
505
506 fn ssz_bytes_len(&self) -> usize {
507 self.clone().into_bytes().len()
510 }
511
512 fn ssz_append(&self, buf: &mut Vec<u8>) {
513 buf.append(&mut self.clone().into_bytes())
514 }
515}
516
517impl<N: Unsigned + Clone> Decode for Bitfield<Variable<N>> {
518 fn is_ssz_fixed_len() -> bool {
519 false
520 }
521
522 fn from_ssz_bytes(bytes: &[u8]) -> Result<Self, ssz::DecodeError> {
523 Self::from_bytes(bytes.to_vec()).map_err(|e| {
524 ssz::DecodeError::BytesInvalid(format!("BitList failed to decode: {:?}", e))
525 })
526 }
527}
528
529impl<N: Unsigned + Clone> Encode for Bitfield<Fixed<N>> {
530 fn is_ssz_fixed_len() -> bool {
531 true
532 }
533
534 fn ssz_bytes_len(&self) -> usize {
535 self.as_slice().len()
536 }
537
538 fn ssz_fixed_len() -> usize {
539 bytes_for_bit_len(N::to_usize())
540 }
541
542 fn ssz_append(&self, buf: &mut Vec<u8>) {
543 buf.append(&mut self.clone().into_bytes())
544 }
545}
546
547impl<N: Unsigned + Clone> Decode for Bitfield<Fixed<N>> {
548 fn is_ssz_fixed_len() -> bool {
549 true
550 }
551
552 fn ssz_fixed_len() -> usize {
553 bytes_for_bit_len(N::to_usize())
554 }
555
556 fn from_ssz_bytes(bytes: &[u8]) -> Result<Self, ssz::DecodeError> {
557 Self::from_bytes(bytes.to_vec()).map_err(|e| {
558 ssz::DecodeError::BytesInvalid(format!("BitVector failed to decode: {:?}", e))
559 })
560 }
561}
562
563impl<N: Unsigned + Clone> Serialize for Bitfield<Variable<N>> {
564 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
566 where
567 S: Serializer,
568 {
569 serializer.serialize_str(&hex_encode(self.as_ssz_bytes()))
570 }
571}
572
573impl<'de, N: Unsigned + Clone> Deserialize<'de> for Bitfield<Variable<N>> {
574 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
576 where
577 D: Deserializer<'de>,
578 {
579 let bytes = deserializer.deserialize_str(PrefixedHexVisitor)?;
580 Self::from_ssz_bytes(&bytes)
581 .map_err(|e| serde::de::Error::custom(format!("Bitfield {:?}", e)))
582 }
583}
584
585impl<N: Unsigned + Clone> Serialize for Bitfield<Fixed<N>> {
586 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
588 where
589 S: Serializer,
590 {
591 serializer.serialize_str(&hex_encode(self.as_ssz_bytes()))
592 }
593}
594
595impl<'de, N: Unsigned + Clone> Deserialize<'de> for Bitfield<Fixed<N>> {
596 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
598 where
599 D: Deserializer<'de>,
600 {
601 let bytes = deserializer.deserialize_str(PrefixedHexVisitor)?;
602 Self::from_ssz_bytes(&bytes)
603 .map_err(|e| serde::de::Error::custom(format!("Bitfield {:?}", e)))
604 }
605}
606
607impl<N: Unsigned + Clone> tree_hash::TreeHash for Bitfield<Variable<N>> {
608 fn tree_hash_type() -> tree_hash::TreeHashType {
609 tree_hash::TreeHashType::List
610 }
611
612 fn tree_hash_packed_encoding(&self) -> Vec<u8> {
613 unreachable!("List should never be packed.")
614 }
615
616 fn tree_hash_packing_factor() -> usize {
617 unreachable!("List should never be packed.")
618 }
619
620 fn tree_hash_root(&self) -> Hash256 {
621 let root = bitfield_bytes_tree_hash_root::<N>(self.as_slice());
624 tree_hash::mix_in_length(&root, self.len())
625 }
626}
627
628impl<N: Unsigned + Clone> tree_hash::TreeHash for Bitfield<Fixed<N>> {
629 fn tree_hash_type() -> tree_hash::TreeHashType {
630 tree_hash::TreeHashType::Vector
631 }
632
633 fn tree_hash_packed_encoding(&self) -> Vec<u8> {
634 unreachable!("Vector should never be packed.")
635 }
636
637 fn tree_hash_packing_factor() -> usize {
638 unreachable!("Vector should never be packed.")
639 }
640
641 fn tree_hash_root(&self) -> Hash256 {
642 bitfield_bytes_tree_hash_root::<N>(self.as_slice())
643 }
644}
645
#[cfg(feature = "arbitrary")]
impl<N: 'static + Unsigned> arbitrary::Arbitrary<'_> for Bitfield<Fixed<N>> {
    /// Generates a bit vector of `N` bits from unstructured fuzzer input.
    ///
    /// The buffer is sized in bytes (not bits) and any bits beyond `N` in the final byte are
    /// cleared, so the generated bytes always pass the validation in `from_bytes`.
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        let bit_len = N::to_usize();
        let mut vec: Vec<u8> = vec![0u8; bytes_for_bit_len(bit_len)];
        u.fill_buffer(&mut vec)?;

        // Number of unused high bits in the last byte: 0..=8 (8 only when `N == 0`, where the
        // single stored byte must be zero).
        let spare_bits = vec.len() * 8 - bit_len;
        if let Some(last) = vec.last_mut() {
            *last &= u8::max_value()
                .checked_shr(spare_bits as u32)
                .unwrap_or(0);
        }

        Self::from_bytes(vec).map_err(|_| arbitrary::Error::IncorrectFormat)
    }
}
655
#[cfg(feature = "arbitrary")]
impl<N: 'static + Unsigned> arbitrary::Arbitrary<'_> for Bitfield<Variable<N>> {
    /// Generates a bit list of at most `N` bits from unstructured fuzzer input.
    ///
    /// A bit length in `0..=N` is chosen first, then that many bits are filled from the
    /// input. Bits beyond the chosen length are cleared so the raw-byte validation accepts
    /// the buffer, and no length marker is required.
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        let max_len = N::to_usize();
        let bit_len = u.int_in_range(0..=max_len)?;
        let mut vec: Vec<u8> = vec![0u8; bytes_for_bit_len(bit_len)];
        u.fill_buffer(&mut vec)?;

        // Number of unused high bits in the last byte: 0..=8 (8 only when `bit_len == 0`,
        // where the single stored byte must be zero).
        let spare_bits = vec.len() * 8 - bit_len;
        if let Some(last) = vec.last_mut() {
            *last &= u8::max_value()
                .checked_shr(spare_bits as u32)
                .unwrap_or(0);
        }

        Self::from_raw_bytes(vec, bit_len).map_err(|_| arbitrary::Error::IncorrectFormat)
    }
}
667
#[cfg(test)]
mod bitvector {
    use super::*;
    use crate::BitVector;

    pub type BitVector0 = BitVector<typenum::U0>;
    pub type BitVector1 = BitVector<typenum::U1>;
    pub type BitVector4 = BitVector<typenum::U4>;
    pub type BitVector8 = BitVector<typenum::U8>;
    pub type BitVector16 = BitVector<typenum::U16>;
    pub type BitVector64 = BitVector<typenum::U64>;

    /// Returns `true` when `bytes` decode successfully as the SSZ type `T`.
    fn decodes<T: Decode>(bytes: &[u8]) -> bool {
        T::from_ssz_bytes(bytes).is_ok()
    }

    /// Asserts that encoding then decoding `t` yields an equal value.
    fn assert_round_trip<T: Encode + Decode + PartialEq + std::fmt::Debug>(t: T) {
        let encoded = t.as_ssz_bytes();
        assert_eq!(T::from_ssz_bytes(&encoded).unwrap(), t);
    }

    #[test]
    fn ssz_encode() {
        // Every all-zero vector of up to eight bits encodes to a single zero byte.
        let single_zero = vec![0b0000_0000];
        assert_eq!(BitVector0::new().as_ssz_bytes(), single_zero);
        assert_eq!(BitVector1::new().as_ssz_bytes(), single_zero);
        assert_eq!(BitVector4::new().as_ssz_bytes(), single_zero);
        assert_eq!(BitVector8::new().as_ssz_bytes(), single_zero);
        assert_eq!(
            BitVector16::new().as_ssz_bytes(),
            vec![0b0000_0000, 0b0000_0000]
        );

        let mut all_eight = BitVector8::new();
        (0..8).for_each(|i| all_eight.set(i, true).unwrap());
        assert_eq!(all_eight.as_ssz_bytes(), vec![255]);

        let mut all_four = BitVector4::new();
        (0..4).for_each(|i| all_four.set(i, true).unwrap());
        assert_eq!(all_four.as_ssz_bytes(), vec![0b0000_1111]);
    }

    #[test]
    fn ssz_decode() {
        assert!(decodes::<BitVector0>(&[0b0000_0000]));
        assert!(!decodes::<BitVector0>(&[0b0000_0001]));
        assert!(!decodes::<BitVector0>(&[0b0000_0010]));

        assert!(decodes::<BitVector1>(&[0b0000_0001]));
        assert!(!decodes::<BitVector1>(&[0b0000_0010]));
        assert!(!decodes::<BitVector1>(&[0b0000_0100]));
        assert!(!decodes::<BitVector1>(&[0b0000_0000, 0b0000_0000]));

        assert!(decodes::<BitVector8>(&[0b0000_0000]));
        assert!(!decodes::<BitVector8>(&[1, 0b0000_0000]));
        assert!(!decodes::<BitVector8>(&[0b0000_0000, 1]));
        assert!(decodes::<BitVector8>(&[0b0000_0001]));
        assert!(decodes::<BitVector8>(&[0b0000_0010]));
        assert!(!decodes::<BitVector8>(&[0b0000_0100, 0b0000_0001]));
        assert!(!decodes::<BitVector8>(&[0b0000_0100, 0b0000_0010]));
        assert!(!decodes::<BitVector8>(&[0b0000_0100, 0b0000_0100]));

        assert!(!decodes::<BitVector16>(&[0b0000_0000]));
        assert!(decodes::<BitVector16>(&[0b0000_0000, 0b0000_0000]));
        assert!(!decodes::<BitVector16>(&[1, 0b0000_0000, 0b0000_0000]));
    }

    #[test]
    fn intersection() {
        let a = BitVector16::from_raw_bytes(vec![0b1100, 0b0001], 16).unwrap();
        let b = BitVector16::from_raw_bytes(vec![0b1011, 0b1001], 16).unwrap();
        let c = BitVector16::from_raw_bytes(vec![0b1000, 0b0001], 16).unwrap();

        // (left operand, right operand, expected intersection)
        let cases = vec![
            (&a, &b, &c),
            (&b, &a, &c),
            (&a, &c, &c),
            (&b, &c, &c),
            (&a, &a, &a),
            (&b, &b, &b),
            (&c, &c, &c),
        ];
        for (lhs, rhs, expected) in cases {
            assert_eq!(&lhs.intersection(rhs), expected);
        }
    }

    #[test]
    fn intersection_diff_length() {
        let a = BitVector16::from_bytes(vec![0b0010_1110, 0b0010_1011]).unwrap();
        let b = BitVector16::from_bytes(vec![0b0010_1101, 0b0000_0001]).unwrap();
        let c = BitVector16::from_bytes(vec![0b0010_1100, 0b0000_0001]).unwrap();

        for bitfield in vec![&a, &b, &c] {
            assert_eq!(bitfield.len(), 16);
        }
        assert_eq!(a.intersection(&b), c);
        assert_eq!(b.intersection(&a), c);
    }

    #[test]
    fn union() {
        let a = BitVector16::from_raw_bytes(vec![0b1100, 0b0001], 16).unwrap();
        let b = BitVector16::from_raw_bytes(vec![0b1011, 0b1001], 16).unwrap();
        let c = BitVector16::from_raw_bytes(vec![0b1111, 0b1001], 16).unwrap();

        // (left operand, right operand, expected union)
        let cases = vec![
            (&a, &b, &c),
            (&b, &a, &c),
            (&a, &a, &a),
            (&b, &b, &b),
            (&c, &c, &c),
        ];
        for (lhs, rhs, expected) in cases {
            assert_eq!(&lhs.union(rhs), expected);
        }
    }

    #[test]
    fn union_diff_length() {
        let a = BitVector16::from_bytes(vec![0b0010_1011, 0b0010_1110]).unwrap();
        let b = BitVector16::from_bytes(vec![0b0000_0001, 0b0010_1101]).unwrap();
        let c = BitVector16::from_bytes(vec![0b0010_1011, 0b0010_1111]).unwrap();

        assert_eq!(a.len(), c.len());
        assert_eq!(a.union(&b), c);
        assert_eq!(b.union(&a), c);
    }

    #[test]
    fn ssz_round_trip() {
        assert_round_trip(BitVector0::new());

        let mut single = BitVector1::new();
        single.set(0, true).unwrap();
        assert_round_trip(single);

        // Even bits only, then every bit, for the 8-bit vector.
        let mut even8 = BitVector8::new();
        for j in (0..8).step_by(2) {
            even8.set(j, true).unwrap();
        }
        assert_round_trip(even8);

        let mut full8 = BitVector8::new();
        for j in 0..8 {
            full8.set(j, true).unwrap();
        }
        assert_round_trip(full8);

        // Same two patterns for the 16-bit vector.
        let mut even16 = BitVector16::new();
        for j in (0..16).step_by(2) {
            even16.set(j, true).unwrap();
        }
        assert_round_trip(even16);

        let mut full16 = BitVector16::new();
        for j in 0..16 {
            full16.set(j, true).unwrap();
        }
        assert_round_trip(full16);
    }

    #[test]
    fn ssz_bytes_len() {
        for i in 0..64 {
            let mut bitfield = BitVector64::new();
            for j in 0..i {
                bitfield.set(j, true).expect("should set bit in bounds");
            }
            let encoded_len = bitfield.as_ssz_bytes().len();
            assert_eq!(bitfield.ssz_bytes_len(), encoded_len, "i = {}", i);
        }
    }

    #[test]
    fn excess_bits_nimbus() {
        // Bit 4 is set, which is outside a four-bit vector.
        let bad = vec![0b0001_1111];

        assert!(!decodes::<BitVector4>(&bad));
    }
}
841
#[cfg(test)]
#[allow(clippy::cognitive_complexity)]
mod bitlist {
    use super::*;
    use crate::BitList;

    pub type BitList0 = BitList<typenum::U0>;
    pub type BitList1 = BitList<typenum::U1>;
    pub type BitList8 = BitList<typenum::U8>;
    pub type BitList16 = BitList<typenum::U16>;
    pub type BitList1024 = BitList<typenum::U1024>;

    /// Returns `true` when `bytes` decode successfully as the SSZ type `T`.
    fn decodes<T: Decode>(bytes: &[u8]) -> bool {
        T::from_ssz_bytes(bytes).is_ok()
    }

    /// Returns `true` when `bytes` are accepted as a raw `bit_len`-bit `BitList1024`.
    fn raw_ok(bytes: Vec<u8>, bit_len: usize) -> bool {
        BitList1024::from_raw_bytes(bytes, bit_len).is_ok()
    }

    /// Asserts that encoding then decoding `t` yields an equal value.
    fn assert_round_trip<T: Encode + Decode + PartialEq + std::fmt::Debug>(t: T) {
        let encoded = t.as_ssz_bytes();
        assert_eq!(T::from_ssz_bytes(&encoded).unwrap(), t);
    }

    #[test]
    fn ssz_encode() {
        assert_eq!(
            BitList0::with_capacity(0).unwrap().as_ssz_bytes(),
            vec![0b0000_0001],
        );

        assert_eq!(
            BitList1::with_capacity(0).unwrap().as_ssz_bytes(),
            vec![0b0000_0001],
        );

        assert_eq!(
            BitList1::with_capacity(1).unwrap().as_ssz_bytes(),
            vec![0b0000_0010],
        );

        assert_eq!(
            BitList8::with_capacity(8).unwrap().as_ssz_bytes(),
            vec![0b0000_0000, 0b0000_0001],
        );

        assert_eq!(
            BitList8::with_capacity(7).unwrap().as_ssz_bytes(),
            vec![0b1000_0000]
        );

        let mut all_eight = BitList8::with_capacity(8).unwrap();
        (0..8).for_each(|i| all_eight.set(i, true).unwrap());
        assert_eq!(all_eight.as_ssz_bytes(), vec![255, 0b0000_0001]);

        let mut low_four = BitList8::with_capacity(8).unwrap();
        (0..4).for_each(|i| low_four.set(i, true).unwrap());
        assert_eq!(low_four.as_ssz_bytes(), vec![0b0000_1111, 0b0000_0001]);

        assert_eq!(
            BitList16::with_capacity(16).unwrap().as_ssz_bytes(),
            vec![0b0000_0000, 0b0000_0000, 0b0000_0001]
        );
    }

    #[test]
    fn ssz_decode() {
        // Empty input carries no length marker.
        assert!(!decodes::<BitList0>(&[]));
        assert!(!decodes::<BitList1>(&[]));
        assert!(!decodes::<BitList8>(&[]));
        assert!(!decodes::<BitList16>(&[]));

        // All-zero input carries no length marker either.
        assert!(!decodes::<BitList0>(&[0b0000_0000]));
        assert!(!decodes::<BitList1>(&[0b0000_0000, 0b0000_0000]));
        assert!(!decodes::<BitList8>(&[0b0000_0000]));
        assert!(!decodes::<BitList16>(&[0b0000_0000]));

        assert!(decodes::<BitList0>(&[0b0000_0001]));
        assert!(!decodes::<BitList0>(&[0b0000_0010]));

        assert!(decodes::<BitList1>(&[0b0000_0001]));
        assert!(decodes::<BitList1>(&[0b0000_0010]));
        assert!(!decodes::<BitList1>(&[0b0000_0100]));

        assert!(decodes::<BitList8>(&[0b0000_0001]));
        assert!(decodes::<BitList8>(&[0b0000_0010]));
        assert!(decodes::<BitList8>(&[0b0000_0001, 0b0000_0001]));
        assert!(!decodes::<BitList8>(&[0b0000_0001, 0b0000_0010]));
        assert!(!decodes::<BitList8>(&[0b0000_0001, 0b0000_0100]));
    }

    #[test]
    fn ssz_decode_extra_bytes() {
        assert!(!decodes::<BitList0>(&[0b0000_0001, 0b0000_0000]));
        assert!(!decodes::<BitList1>(&[0b0000_0001, 0b0000_0000]));
        assert!(!decodes::<BitList8>(&[0b0000_0001, 0b0000_0000]));
        assert!(!decodes::<BitList16>(&[0b0000_0001, 0b0000_0000]));
        assert!(!decodes::<BitList1024>(&[0b1000_0000, 0]));
        assert!(!decodes::<BitList1024>(&[0b1000_0000, 0, 0]));
        assert!(!decodes::<BitList1024>(&[0b1000_0000, 0, 0, 0, 0]));
    }

    #[test]
    fn ssz_round_trip() {
        assert_round_trip(BitList0::with_capacity(0).unwrap());

        // Empty lists of every permitted length.
        (0..2).for_each(|i| assert_round_trip(BitList1::with_capacity(i).unwrap()));
        (0..9).for_each(|i| assert_round_trip(BitList8::with_capacity(i).unwrap()));
        (0..17).for_each(|i| assert_round_trip(BitList16::with_capacity(i).unwrap()));

        let mut single = BitList1::with_capacity(1).unwrap();
        single.set(0, true).unwrap();
        assert_round_trip(single);

        // Even bits only, then every bit, for each length below the maximum.
        for i in 0..8 {
            let mut even = BitList8::with_capacity(i).unwrap();
            for j in (0..i).step_by(2) {
                even.set(j, true).unwrap();
            }
            assert_round_trip(even);

            let mut full = BitList8::with_capacity(i).unwrap();
            for j in 0..i {
                full.set(j, true).unwrap();
            }
            assert_round_trip(full);
        }

        for i in 0..16 {
            let mut even = BitList16::with_capacity(i).unwrap();
            for j in (0..i).step_by(2) {
                even.set(j, true).unwrap();
            }
            assert_round_trip(even);

            let mut full = BitList16::with_capacity(i).unwrap();
            for j in 0..i {
                full.set(j, true).unwrap();
            }
            assert_round_trip(full);
        }
    }

    #[test]
    fn from_raw_bytes() {
        // Accepted: exactly `bit_len` low bits set, nothing above them.
        assert!(raw_ok(vec![0b0000_0000], 0));
        for bit_len in 1..=8usize {
            assert!(raw_ok(vec![0b1111_1111u8 >> (8 - bit_len)], bit_len));
        }
        for bit_len in 9..=16usize {
            assert!(raw_ok(
                vec![0b1111_1111, 0b1111_1111u8 >> (16 - bit_len)],
                bit_len
            ));
        }

        // Rejected: wrong byte count or bits beyond the length.
        for i in 0..8 {
            assert!(!raw_ok(vec![], i));
            assert!(!raw_ok(vec![0b1111_1111], i));
            assert!(!raw_ok(vec![0b0000_0000, 0b1111_1110], i));
        }

        assert!(!raw_ok(vec![0b0000_0001], 0));

        // Rejected: exactly one bit more than `bit_len` allows.
        for bit_len in 0..=7usize {
            assert!(!raw_ok(vec![0b1111_1111u8 >> (7 - bit_len)], bit_len));
        }
        for bit_len in 8..=15usize {
            assert!(!raw_ok(
                vec![0b1111_1111, 0b1111_1111u8 >> (15 - bit_len)],
                bit_len
            ));
        }
    }

    /// Checks that every in-bounds bit can be toggled and that index `num_bits` is rejected.
    fn test_set_unset(num_bits: usize) {
        let mut bitfield = BitList1024::with_capacity(num_bits).unwrap();

        for i in 0..num_bits {
            assert_eq!(bitfield.get(i), Ok(false));
            assert!(bitfield.set(i, true).is_ok());
            assert_eq!(bitfield.get(i), Ok(true));
            assert!(bitfield.set(i, false).is_ok());
            assert_eq!(bitfield.get(i), Ok(false));
        }

        // One past the end.
        assert!(bitfield.get(num_bits).is_err());
        assert!(bitfield.set(num_bits, true).is_err());
        assert!(bitfield.get(num_bits).is_err());
    }

    /// Checks that a list with a single set bit survives a raw-bytes round trip.
    fn test_bytes_round_trip(num_bits: usize) {
        for i in 0..num_bits {
            let mut bitfield = BitList1024::with_capacity(num_bits).unwrap();
            bitfield.set(i, true).unwrap();

            let raw = bitfield.clone().into_raw_bytes();
            let rebuilt = Bitfield::from_raw_bytes(raw, num_bits).unwrap();
            assert_eq!(bitfield, rebuilt);
        }
    }

    #[test]
    fn set_unset() {
        (0..8 * 5).for_each(test_set_unset);
    }

    #[test]
    fn bytes_round_trip() {
        (0..8 * 5).for_each(test_bytes_round_trip);
    }

    #[test]
    fn into_raw_bytes() {
        let mut bitfield = BitList1024::with_capacity(9).unwrap();

        // Setting bits 0..=7 one by one fills the first byte from the low end.
        for i in 0..8usize {
            bitfield.set(i, true).unwrap();
            assert_eq!(
                bitfield.clone().into_raw_bytes(),
                vec![0b1111_1111u8 >> (7 - i), 0b0000_0000]
            );
        }

        bitfield.set(8, true).unwrap();
        assert_eq!(bitfield.into_raw_bytes(), vec![0b1111_1111, 0b0000_0001]);
    }

    #[test]
    fn highest_set_bit() {
        assert_eq!(
            BitList1024::with_capacity(16).unwrap().highest_set_bit(),
            None
        );

        // (raw bytes, bit length, expected highest set bit)
        let cases = vec![
            (vec![0b0000_0001, 0b0000_0000], 16, Some(0)),
            (vec![0b0000_0010, 0b0000_0000], 16, Some(1)),
            (vec![0b0000_1000], 8, Some(3)),
            (vec![0b0000_0000, 0b1000_0000], 16, Some(15)),
        ];
        for (bytes, bit_len, expected) in cases {
            let bitfield = BitList1024::from_raw_bytes(bytes, bit_len).unwrap();
            assert_eq!(bitfield.highest_set_bit(), expected);
        }
    }

    #[test]
    fn intersection() {
        let a = BitList1024::from_raw_bytes(vec![0b1100, 0b0001], 16).unwrap();
        let b = BitList1024::from_raw_bytes(vec![0b1011, 0b1001], 16).unwrap();
        let c = BitList1024::from_raw_bytes(vec![0b1000, 0b0001], 16).unwrap();

        // (left operand, right operand, expected intersection)
        let cases = vec![
            (&a, &b, &c),
            (&b, &a, &c),
            (&a, &c, &c),
            (&b, &c, &c),
            (&a, &a, &a),
            (&b, &b, &b),
            (&c, &c, &c),
        ];
        for (lhs, rhs, expected) in cases {
            assert_eq!(&lhs.intersection(rhs), expected);
        }
    }

    #[test]
    fn intersection_diff_length() {
        let a = BitList1024::from_bytes(vec![0b0010_1110, 0b0010_1011]).unwrap();
        let b = BitList1024::from_bytes(vec![0b0010_1101, 0b0000_0001]).unwrap();
        let c = BitList1024::from_bytes(vec![0b0010_1100, 0b0000_0001]).unwrap();
        let d = BitList1024::from_bytes(vec![0b0010_1110, 0b1111_1111, 0b1111_1111]).unwrap();

        for (bitfield, expected_len) in vec![(&a, 13), (&b, 8), (&c, 8), (&d, 23)] {
            assert_eq!(bitfield.len(), expected_len);
        }
        assert_eq!(a.intersection(&b), c);
        assert_eq!(b.intersection(&a), c);
        assert_eq!(a.intersection(&d), a);
        assert_eq!(d.intersection(&a), a);
    }

    #[test]
    fn union() {
        let a = BitList1024::from_raw_bytes(vec![0b1100, 0b0001], 16).unwrap();
        let b = BitList1024::from_raw_bytes(vec![0b1011, 0b1001], 16).unwrap();
        let c = BitList1024::from_raw_bytes(vec![0b1111, 0b1001], 16).unwrap();

        // (left operand, right operand, expected union)
        let cases = vec![
            (&a, &b, &c),
            (&b, &a, &c),
            (&a, &a, &a),
            (&b, &b, &b),
            (&c, &c, &c),
        ];
        for (lhs, rhs, expected) in cases {
            assert_eq!(&lhs.union(rhs), expected);
        }
    }

    #[test]
    fn union_diff_length() {
        let a = BitList1024::from_bytes(vec![0b0010_1011, 0b0010_1110]).unwrap();
        let b = BitList1024::from_bytes(vec![0b0000_0001, 0b0010_1101]).unwrap();
        let c = BitList1024::from_bytes(vec![0b0010_1011, 0b0010_1111]).unwrap();
        let d = BitList1024::from_bytes(vec![0b0010_1011, 0b1011_1110, 0b1000_1101]).unwrap();

        assert_eq!(a.len(), c.len());
        assert_eq!(a.union(&b), c);
        assert_eq!(b.union(&a), c);
        assert_eq!(a.union(&d), d);
        assert_eq!(d.union(&a), d);
    }

    #[test]
    fn difference() {
        let a = BitList1024::from_raw_bytes(vec![0b1100, 0b0001], 16).unwrap();
        let b = BitList1024::from_raw_bytes(vec![0b1011, 0b1001], 16).unwrap();
        let a_minus_b = BitList1024::from_raw_bytes(vec![0b0100, 0b0000], 16).unwrap();
        let b_minus_a = BitList1024::from_raw_bytes(vec![0b0011, 0b1000], 16).unwrap();

        assert_eq!(a.difference(&b), a_minus_b);
        assert_eq!(b.difference(&a), b_minus_a);
        assert!(a.difference(&a).is_zero());
    }

    #[test]
    fn difference_diff_length() {
        let a = BitList1024::from_raw_bytes(vec![0b0110, 0b1100, 0b0011], 24).unwrap();
        let b = BitList1024::from_raw_bytes(vec![0b1011, 0b1001], 16).unwrap();
        let a_minus_b = BitList1024::from_raw_bytes(vec![0b0100, 0b0100, 0b0011], 24).unwrap();
        let b_minus_a = BitList1024::from_raw_bytes(vec![0b1001, 0b0001], 16).unwrap();

        assert_eq!(a.difference(&b), a_minus_b);
        assert_eq!(b.difference(&a), b_minus_a);
    }

    #[test]
    fn shift_up() {
        let mut a = BitList1024::from_raw_bytes(vec![0b1100_1111, 0b1101_0110], 16).unwrap();
        let mut b = BitList1024::from_raw_bytes(vec![0b1001_1110, 0b1010_1101], 16).unwrap();

        a.shift_up(1).unwrap();
        assert_eq!(a, b);
        a.shift_up(15).unwrap();
        assert!(a.is_zero());

        // Shifting by the full length clears everything; one more is out of bounds.
        b.shift_up(16).unwrap();
        assert!(b.is_zero());
        assert!(b.shift_up(17).is_err());
    }

    #[test]
    fn num_set_bits() {
        let cases = vec![(vec![0b1100, 0b0001], 3), (vec![0b1011, 0b1001], 5)];
        for (bytes, expected) in cases {
            let bitfield = BitList1024::from_raw_bytes(bytes, 16).unwrap();
            assert_eq!(bitfield.num_set_bits(), expected);
        }
    }

    #[test]
    fn iter() {
        let mut bitfield = BitList1024::with_capacity(9).unwrap();
        for i in vec![2, 8] {
            bitfield.set(i, true).unwrap();
        }

        let bits: Vec<bool> = bitfield.iter().collect();
        assert_eq!(
            bits,
            vec![false, false, true, false, false, false, false, false, true]
        );
    }

    #[test]
    fn ssz_bytes_len() {
        for i in 1..64 {
            let mut bitfield = BitList1024::with_capacity(i).unwrap();
            for j in 0..i {
                bitfield.set(j, true).expect("should set bit in bounds");
            }
            let encoded_len = bitfield.as_ssz_bytes().len();
            assert_eq!(bitfield.ssz_bytes_len(), encoded_len, "i = {}", i);
        }
    }
}