// ssz_types/bitfield.rs

1use crate::tree_hash::bitfield_bytes_tree_hash_root;
2use crate::Error;
3use core::marker::PhantomData;
4use eth2_serde_utils::hex::{encode as hex_encode, PrefixedHexVisitor};
5use serde::de::{Deserialize, Deserializer};
6use serde::ser::{Serialize, Serializer};
7use ssz::{Decode, Encode};
8use tree_hash::Hash256;
9use typenum::Unsigned;
10
11/// A marker trait applied to `Variable` and `Fixed` that defines the behaviour of a `Bitfield`.
12pub trait BitfieldBehaviour: Clone {}
13
14/// A marker struct used to declare SSZ `Variable` behaviour on a `Bitfield`.
15///
16/// See the [`Bitfield`](struct.Bitfield.html) docs for usage.
17#[derive(Clone, PartialEq, Debug)]
18pub struct Variable<N> {
19    _phantom: PhantomData<N>,
20}
21
22/// A marker struct used to declare SSZ `Fixed` behaviour on a `Bitfield`.
23///
24/// See the [`Bitfield`](struct.Bitfield.html) docs for usage.
25#[derive(Clone, PartialEq, Debug)]
26pub struct Fixed<N> {
27    _phantom: PhantomData<N>,
28}
29
30impl<N: Unsigned + Clone> BitfieldBehaviour for Variable<N> {}
31impl<N: Unsigned + Clone> BitfieldBehaviour for Fixed<N> {}
32
33/// A heap-allocated, ordered, variable-length collection of `bool` values, limited to `N` bits.
34pub type BitList<N> = Bitfield<Variable<N>>;
35
36/// A heap-allocated, ordered, fixed-length collection of `bool` values, with `N` bits.
37///
38/// See [Bitfield](struct.Bitfield.html) documentation.
39pub type BitVector<N> = Bitfield<Fixed<N>>;
40
/// A heap-allocated, ordered collection of `bool` values whose length is set at instantiation
/// and never changes afterwards. Use of [`BitList`](type.BitList.html) or
/// [`BitVector`](type.BitVector.html) type aliases is preferred over direct use of this struct.
///
/// The `T` type parameter is used to define length behaviour with the `Variable` or `Fixed` marker
/// structs.
///
/// The length of the Bitfield is set at instantiation (i.e., runtime, not compile time). However,
/// use with a `Variable` sets a type-level (i.e., compile-time) maximum length and `Fixed`
/// provides a type-level fixed length.
///
/// ## Example
///
/// The example uses the following crate-level type aliases:
///
/// - `BitList<N>` is an alias for `Bitfield<Variable<N>>`
/// - `BitVector<N>` is an alias for `Bitfield<Fixed<N>>`
///
/// ```
/// use ssz_types::{BitVector, BitList, typenum};
///
/// // `BitList` has a type-level maximum length. The length of the list is specified at runtime
/// // and it must be less than or equal to `N`. After instantiation, `BitList` cannot grow or
/// // shrink.
/// type BitList8 = BitList<typenum::U8>;
///
/// // Creating a `BitList` with a larger-than-`N` capacity returns an `Err`.
/// assert!(BitList8::with_capacity(9).is_err());
///
/// let mut bitlist = BitList8::with_capacity(4).unwrap();  // `BitList` permits a capacity of less than the maximum.
/// assert!(bitlist.set(3, true).is_ok());  // Setting inside the instantiation capacity is permitted.
/// assert!(bitlist.set(5, true).is_err());  // Setting outside that capacity is not.
///
/// // `BitVector` has a type-level fixed length. Unlike `BitList`, it cannot be instantiated with a custom length
/// // or grow/shrink.
/// type BitVector8 = BitVector<typenum::U8>;
///
/// let mut bitvector = BitVector8::new();
/// assert_eq!(bitvector.len(), 8); // `BitVector` length is fixed at the type-level.
/// assert!(bitvector.set(7, true).is_ok());  // Setting inside the capacity is permitted.
/// assert!(bitvector.set(9, true).is_err());  // Setting outside the capacity is not.
///
/// ```
///
/// ## Note
///
/// The internal representation of the bitfield uses the SSZ bit ordering. The lowest byte (by
/// `Vec` index) stores the lowest bit-indices and the right-most bit stores the lowest
/// bit-index. E.g., `vec![0b0000_0001, 0b0000_0010]` has bits `0, 9` set.
///
/// For `BitVector` the internal bytes are exactly the SSZ encoding. For `BitList` the SSZ
/// encoding additionally carries a length-delimiting `true` bit at index `len`, which is *not*
/// stored internally (it is added by `into_bytes` and stripped by `from_bytes`).
///
/// Every constructor upholds two invariants: `bytes.len() == bytes_for_bit_len(len)`, and no
/// bit at an index `>= len` is set.
#[derive(Clone, Debug, PartialEq)]
pub struct Bitfield<T> {
    bytes: Vec<u8>,
    len: usize,
    _phantom: PhantomData<T>,
}
96
97impl<N: Unsigned + Clone> Bitfield<Variable<N>> {
98    /// Instantiate with capacity for `num_bits` boolean values. The length cannot be grown or
99    /// shrunk after instantiation.
100    ///
101    /// All bits are initialized to `false`.
102    ///
103    /// Returns `None` if `num_bits > N`.
104    pub fn with_capacity(num_bits: usize) -> Result<Self, Error> {
105        if num_bits <= N::to_usize() {
106            Ok(Self {
107                bytes: vec![0; bytes_for_bit_len(num_bits)],
108                len: num_bits,
109                _phantom: PhantomData,
110            })
111        } else {
112            Err(Error::OutOfBounds {
113                i: Self::max_len(),
114                len: Self::max_len(),
115            })
116        }
117    }
118
119    /// Equal to `N` regardless of the value supplied to `with_capacity`.
120    pub fn max_len() -> usize {
121        N::to_usize()
122    }
123
124    /// Consumes `self`, returning a serialized representation.
125    ///
126    /// The output is faithful to the SSZ encoding of `self`, such that a leading `true` bit is
127    /// used to indicate the length of the bitfield.
128    ///
129    /// ## Example
130    /// ```
131    /// use ssz_types::{BitList, typenum};
132    ///
133    /// type BitList8 = BitList<typenum::U8>;
134    ///
135    /// let b = BitList8::with_capacity(4).unwrap();
136    ///
137    /// assert_eq!(b.into_bytes(), vec![0b0001_0000]);
138    /// ```
139    pub fn into_bytes(self) -> Vec<u8> {
140        let len = self.len();
141        let mut bytes = self.bytes;
142
143        bytes.resize(bytes_for_bit_len(len + 1), 0);
144
145        let mut bitfield: Bitfield<Variable<N>> = Bitfield::from_raw_bytes(bytes, len + 1)
146            .unwrap_or_else(|_| {
147                unreachable!(
148                    "Bitfield with {} bytes must have enough capacity for {} bits.",
149                    bytes_for_bit_len(len + 1),
150                    len + 1
151                )
152            });
153        bitfield
154            .set(len, true)
155            .expect("len must be in bounds for bitfield.");
156
157        bitfield.bytes
158    }
159
160    /// Instantiates a new instance from `bytes`. Consumes the same format that `self.into_bytes()`
161    /// produces (SSZ).
162    ///
163    /// Returns `None` if `bytes` are not a valid encoding.
164    pub fn from_bytes(bytes: Vec<u8>) -> Result<Self, Error> {
165        let bytes_len = bytes.len();
166        let mut initial_bitfield: Bitfield<Variable<N>> = {
167            let num_bits = bytes.len() * 8;
168            Bitfield::from_raw_bytes(bytes, num_bits)?
169        };
170
171        let len = initial_bitfield
172            .highest_set_bit()
173            .ok_or(Error::MissingLengthInformation)?;
174
175        // The length bit should be in the last byte, or else it means we have too many bytes.
176        if len / 8 + 1 != bytes_len {
177            return Err(Error::InvalidByteCount {
178                given: bytes_len,
179                expected: len / 8 + 1,
180            });
181        }
182
183        if len <= Self::max_len() {
184            initial_bitfield
185                .set(len, false)
186                .expect("Bit has been confirmed to exist");
187
188            let mut bytes = initial_bitfield.into_raw_bytes();
189
190            bytes.truncate(bytes_for_bit_len(len));
191
192            Self::from_raw_bytes(bytes, len)
193        } else {
194            Err(Error::OutOfBounds {
195                i: Self::max_len(),
196                len: Self::max_len(),
197            })
198        }
199    }
200
201    /// Compute the intersection of two BitLists of potentially different lengths.
202    ///
203    /// Return a new BitList with length equal to the shorter of the two inputs.
204    pub fn intersection(&self, other: &Self) -> Self {
205        let min_len = std::cmp::min(self.len(), other.len());
206        let mut result = Self::with_capacity(min_len).expect("min len always less than N");
207        // Bitwise-and the bytes together, starting from the left of each vector. This takes care
208        // of masking out any entries beyond `min_len` as well, assuming the bitfield doesn't
209        // contain any set bits beyond its length.
210        for i in 0..result.bytes.len() {
211            result.bytes[i] = self.bytes[i] & other.bytes[i];
212        }
213        result
214    }
215
216    /// Compute the union of two BitLists of potentially different lengths.
217    ///
218    /// Return a new BitList with length equal to the longer of the two inputs.
219    pub fn union(&self, other: &Self) -> Self {
220        let max_len = std::cmp::max(self.len(), other.len());
221        let mut result = Self::with_capacity(max_len).expect("max len always less than N");
222        for i in 0..result.bytes.len() {
223            result.bytes[i] =
224                self.bytes.get(i).copied().unwrap_or(0) | other.bytes.get(i).copied().unwrap_or(0);
225        }
226        result
227    }
228}
229
230impl<N: Unsigned + Clone> Bitfield<Fixed<N>> {
231    /// Instantiate a new `Bitfield` with a fixed-length of `N` bits.
232    ///
233    /// All bits are initialized to `false`.
234    pub fn new() -> Self {
235        Self {
236            bytes: vec![0; bytes_for_bit_len(Self::capacity())],
237            len: Self::capacity(),
238            _phantom: PhantomData,
239        }
240    }
241
242    /// Returns `N`, the number of bits in `Self`.
243    pub fn capacity() -> usize {
244        N::to_usize()
245    }
246
247    /// Consumes `self`, returning a serialized representation.
248    ///
249    /// The output is faithful to the SSZ encoding of `self`.
250    ///
251    /// ## Example
252    /// ```
253    /// use ssz_types::{BitVector, typenum};
254    ///
255    /// type BitVector4 = BitVector<typenum::U4>;
256    ///
257    /// assert_eq!(BitVector4::new().into_bytes(), vec![0b0000_0000]);
258    /// ```
259    pub fn into_bytes(self) -> Vec<u8> {
260        self.into_raw_bytes()
261    }
262
263    /// Instantiates a new instance from `bytes`. Consumes the same format that `self.into_bytes()`
264    /// produces (SSZ).
265    ///
266    /// Returns `None` if `bytes` are not a valid encoding.
267    pub fn from_bytes(bytes: Vec<u8>) -> Result<Self, Error> {
268        Self::from_raw_bytes(bytes, Self::capacity())
269    }
270
271    /// Compute the intersection of two fixed-length `Bitfield`s.
272    ///
273    /// Return a new fixed-length `Bitfield`.
274    pub fn intersection(&self, other: &Self) -> Self {
275        let mut result = Self::new();
276        // Bitwise-and the bytes together, starting from the left of each vector. This takes care
277        // of masking out any entries beyond `min_len` as well, assuming the bitfield doesn't
278        // contain any set bits beyond its length.
279        for i in 0..result.bytes.len() {
280            result.bytes[i] = self.bytes[i] & other.bytes[i];
281        }
282        result
283    }
284
285    /// Compute the union of two fixed-length `Bitfield`s.
286    ///
287    /// Return a new fixed-length `Bitfield`.
288    pub fn union(&self, other: &Self) -> Self {
289        let mut result = Self::new();
290        for i in 0..result.bytes.len() {
291            result.bytes[i] =
292                self.bytes.get(i).copied().unwrap_or(0) | other.bytes.get(i).copied().unwrap_or(0);
293        }
294        result
295    }
296}
297
298impl<N: Unsigned + Clone> Default for Bitfield<Fixed<N>> {
299    fn default() -> Self {
300        Self::new()
301    }
302}
303
304impl<T: BitfieldBehaviour> Bitfield<T> {
305    /// Sets the `i`'th bit to `value`.
306    ///
307    /// Returns `None` if `i` is out-of-bounds of `self`.
308    pub fn set(&mut self, i: usize, value: bool) -> Result<(), Error> {
309        let len = self.len;
310
311        if i < len {
312            let byte = self
313                .bytes
314                .get_mut(i / 8)
315                .ok_or(Error::OutOfBounds { i, len })?;
316
317            if value {
318                *byte |= 1 << (i % 8)
319            } else {
320                *byte &= !(1 << (i % 8))
321            }
322
323            Ok(())
324        } else {
325            Err(Error::OutOfBounds { i, len: self.len })
326        }
327    }
328
329    /// Returns the value of the `i`'th bit.
330    ///
331    /// Returns `Error` if `i` is out-of-bounds of `self`.
332    pub fn get(&self, i: usize) -> Result<bool, Error> {
333        if i < self.len {
334            let byte = self
335                .bytes
336                .get(i / 8)
337                .ok_or(Error::OutOfBounds { i, len: self.len })?;
338
339            Ok(*byte & 1 << (i % 8) > 0)
340        } else {
341            Err(Error::OutOfBounds { i, len: self.len })
342        }
343    }
344
345    /// Returns the number of bits stored in `self`.
346    pub fn len(&self) -> usize {
347        self.len
348    }
349
350    /// Returns `true` if `self.len() == 0`.
351    pub fn is_empty(&self) -> bool {
352        self.len == 0
353    }
354
355    /// Returns the underlying bytes representation of the bitfield.
356    pub fn into_raw_bytes(self) -> Vec<u8> {
357        self.bytes
358    }
359
360    /// Returns a view into the underlying bytes representation of the bitfield.
361    pub fn as_slice(&self) -> &[u8] {
362        &self.bytes
363    }
364
365    /// Instantiates from the given `bytes`, which are the same format as output from
366    /// `self.into_raw_bytes()`.
367    ///
368    /// Returns `None` if:
369    ///
370    /// - `bytes` is not the minimal required bytes to represent a bitfield of `bit_len` bits.
371    /// - `bit_len` is not a multiple of 8 and `bytes` contains set bits that are higher than, or
372    /// equal to `bit_len`.
373    fn from_raw_bytes(bytes: Vec<u8>, bit_len: usize) -> Result<Self, Error> {
374        if bit_len == 0 {
375            if bytes.len() == 1 && bytes == [0] {
376                // A bitfield with `bit_len` 0 can only be represented by a single zero byte.
377                Ok(Self {
378                    bytes,
379                    len: 0,
380                    _phantom: PhantomData,
381                })
382            } else {
383                Err(Error::ExcessBits)
384            }
385        } else if bytes.len() != bytes_for_bit_len(bit_len) {
386            // The number of bytes must be the minimum required to represent `bit_len`.
387            Err(Error::InvalidByteCount {
388                given: bytes.len(),
389                expected: bytes_for_bit_len(bit_len),
390            })
391        } else {
392            // Ensure there are no bits higher than `bit_len` that are set to true.
393            let (mask, _) = u8::max_value().overflowing_shr(8 - (bit_len as u32 % 8));
394
395            if (bytes.last().expect("Guarded against empty bytes") & !mask) == 0 {
396                Ok(Self {
397                    bytes,
398                    len: bit_len,
399                    _phantom: PhantomData,
400                })
401            } else {
402                Err(Error::ExcessBits)
403            }
404        }
405    }
406
407    /// Returns the `Some(i)` where `i` is the highest index with a set bit. Returns `None` if
408    /// there are no set bits.
409    pub fn highest_set_bit(&self) -> Option<usize> {
410        self.bytes
411            .iter()
412            .enumerate()
413            .rev()
414            .find(|(_, byte)| **byte > 0)
415            .map(|(i, byte)| i * 8 + 7 - byte.leading_zeros() as usize)
416    }
417
418    /// Returns an iterator across bitfield `bool` values, starting at the lowest index.
419    pub fn iter(&self) -> BitIter<'_, T> {
420        BitIter {
421            bitfield: self,
422            i: 0,
423        }
424    }
425
426    /// Returns true if no bits are set.
427    pub fn is_zero(&self) -> bool {
428        self.bytes.iter().all(|byte| *byte == 0)
429    }
430
431    /// Returns the number of bits that are set to `true`.
432    pub fn num_set_bits(&self) -> usize {
433        self.bytes
434            .iter()
435            .map(|byte| byte.count_ones() as usize)
436            .sum()
437    }
438
439    /// Compute the difference of this Bitfield and another of potentially different length.
440    pub fn difference(&self, other: &Self) -> Self {
441        let mut result = self.clone();
442        result.difference_inplace(other);
443        result
444    }
445
446    /// Compute the difference of this Bitfield and another of potentially different length.
447    pub fn difference_inplace(&mut self, other: &Self) {
448        let min_byte_len = std::cmp::min(self.bytes.len(), other.bytes.len());
449
450        for i in 0..min_byte_len {
451            self.bytes[i] &= !other.bytes[i];
452        }
453    }
454
455    /// Shift the bits to higher indices, filling the lower indices with zeroes.
456    ///
457    /// The amount to shift by, `n`, must be less than or equal to `self.len()`.
458    pub fn shift_up(&mut self, n: usize) -> Result<(), Error> {
459        if n <= self.len() {
460            // Shift the bits up (starting from the high indices to avoid overwriting)
461            for i in (n..self.len()).rev() {
462                self.set(i, self.get(i - n)?)?;
463            }
464            // Zero the low bits
465            for i in 0..n {
466                self.set(i, false).unwrap();
467            }
468            Ok(())
469        } else {
470            Err(Error::OutOfBounds {
471                i: n,
472                len: self.len(),
473            })
474        }
475    }
476}
477
/// Returns the number of bytes needed to hold `bit_len` bits, i.e. `bit_len / 8` rounded up.
///
/// An empty bitfield still occupies one byte, so `bit_len == 0` yields `1`.
fn bytes_for_bit_len(bit_len: usize) -> usize {
    let rounded_up = (bit_len + 7) / 8;
    rounded_up.max(1)
}
484
485/// An iterator over the bits in a `Bitfield`.
486pub struct BitIter<'a, T> {
487    bitfield: &'a Bitfield<T>,
488    i: usize,
489}
490
491impl<'a, T: BitfieldBehaviour> Iterator for BitIter<'a, T> {
492    type Item = bool;
493
494    fn next(&mut self) -> Option<Self::Item> {
495        let res = self.bitfield.get(self.i).ok()?;
496        self.i += 1;
497        Some(res)
498    }
499}
500
501impl<N: Unsigned + Clone> Encode for Bitfield<Variable<N>> {
502    fn is_ssz_fixed_len() -> bool {
503        false
504    }
505
506    fn ssz_bytes_len(&self) -> usize {
507        // We could likely do better than turning this into bytes and reading the length, however
508        // it is kept this way for simplicity.
509        self.clone().into_bytes().len()
510    }
511
512    fn ssz_append(&self, buf: &mut Vec<u8>) {
513        buf.append(&mut self.clone().into_bytes())
514    }
515}
516
517impl<N: Unsigned + Clone> Decode for Bitfield<Variable<N>> {
518    fn is_ssz_fixed_len() -> bool {
519        false
520    }
521
522    fn from_ssz_bytes(bytes: &[u8]) -> Result<Self, ssz::DecodeError> {
523        Self::from_bytes(bytes.to_vec()).map_err(|e| {
524            ssz::DecodeError::BytesInvalid(format!("BitList failed to decode: {:?}", e))
525        })
526    }
527}
528
529impl<N: Unsigned + Clone> Encode for Bitfield<Fixed<N>> {
530    fn is_ssz_fixed_len() -> bool {
531        true
532    }
533
534    fn ssz_bytes_len(&self) -> usize {
535        self.as_slice().len()
536    }
537
538    fn ssz_fixed_len() -> usize {
539        bytes_for_bit_len(N::to_usize())
540    }
541
542    fn ssz_append(&self, buf: &mut Vec<u8>) {
543        buf.append(&mut self.clone().into_bytes())
544    }
545}
546
547impl<N: Unsigned + Clone> Decode for Bitfield<Fixed<N>> {
548    fn is_ssz_fixed_len() -> bool {
549        true
550    }
551
552    fn ssz_fixed_len() -> usize {
553        bytes_for_bit_len(N::to_usize())
554    }
555
556    fn from_ssz_bytes(bytes: &[u8]) -> Result<Self, ssz::DecodeError> {
557        Self::from_bytes(bytes.to_vec()).map_err(|e| {
558            ssz::DecodeError::BytesInvalid(format!("BitVector failed to decode: {:?}", e))
559        })
560    }
561}
562
563impl<N: Unsigned + Clone> Serialize for Bitfield<Variable<N>> {
564    /// Serde serialization is compliant with the Ethereum YAML test format.
565    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
566    where
567        S: Serializer,
568    {
569        serializer.serialize_str(&hex_encode(self.as_ssz_bytes()))
570    }
571}
572
573impl<'de, N: Unsigned + Clone> Deserialize<'de> for Bitfield<Variable<N>> {
574    /// Serde serialization is compliant with the Ethereum YAML test format.
575    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
576    where
577        D: Deserializer<'de>,
578    {
579        let bytes = deserializer.deserialize_str(PrefixedHexVisitor)?;
580        Self::from_ssz_bytes(&bytes)
581            .map_err(|e| serde::de::Error::custom(format!("Bitfield {:?}", e)))
582    }
583}
584
585impl<N: Unsigned + Clone> Serialize for Bitfield<Fixed<N>> {
586    /// Serde serialization is compliant with the Ethereum YAML test format.
587    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
588    where
589        S: Serializer,
590    {
591        serializer.serialize_str(&hex_encode(self.as_ssz_bytes()))
592    }
593}
594
595impl<'de, N: Unsigned + Clone> Deserialize<'de> for Bitfield<Fixed<N>> {
596    /// Serde serialization is compliant with the Ethereum YAML test format.
597    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
598    where
599        D: Deserializer<'de>,
600    {
601        let bytes = deserializer.deserialize_str(PrefixedHexVisitor)?;
602        Self::from_ssz_bytes(&bytes)
603            .map_err(|e| serde::de::Error::custom(format!("Bitfield {:?}", e)))
604    }
605}
606
607impl<N: Unsigned + Clone> tree_hash::TreeHash for Bitfield<Variable<N>> {
608    fn tree_hash_type() -> tree_hash::TreeHashType {
609        tree_hash::TreeHashType::List
610    }
611
612    fn tree_hash_packed_encoding(&self) -> Vec<u8> {
613        unreachable!("List should never be packed.")
614    }
615
616    fn tree_hash_packing_factor() -> usize {
617        unreachable!("List should never be packed.")
618    }
619
620    fn tree_hash_root(&self) -> Hash256 {
621        // Note: we use `as_slice` because it does _not_ have the length-delimiting bit set (or
622        // present).
623        let root = bitfield_bytes_tree_hash_root::<N>(self.as_slice());
624        tree_hash::mix_in_length(&root, self.len())
625    }
626}
627
628impl<N: Unsigned + Clone> tree_hash::TreeHash for Bitfield<Fixed<N>> {
629    fn tree_hash_type() -> tree_hash::TreeHashType {
630        tree_hash::TreeHashType::Vector
631    }
632
633    fn tree_hash_packed_encoding(&self) -> Vec<u8> {
634        unreachable!("Vector should never be packed.")
635    }
636
637    fn tree_hash_packing_factor() -> usize {
638        unreachable!("Vector should never be packed.")
639    }
640
641    fn tree_hash_root(&self) -> Hash256 {
642        bitfield_bytes_tree_hash_root::<N>(self.as_slice())
643    }
644}
645
#[cfg(feature = "arbitrary")]
impl<N: 'static + Unsigned> arbitrary::Arbitrary<'_> for Bitfield<Fixed<N>> {
    /// Generates an arbitrary `BitVector<N>`.
    ///
    /// `N` is a *bit* count, so exactly `bytes_for_bit_len(N)` bytes are drawn (not `N` bytes),
    /// and any bits at an index `>= N` are cleared. Without this, `from_bytes` rejects nearly
    /// every generated input.
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        let bit_len = N::to_usize();
        let mut bytes: Vec<u8> = vec![0u8; bytes_for_bit_len(bit_len)];
        u.fill_buffer(&mut bytes)?;

        // Clear out-of-range bits in the final byte so the bytes form a valid encoding.
        let used_bits_in_last_byte = bit_len % 8;
        if let Some(last) = bytes.last_mut() {
            if bit_len == 0 {
                // A zero-length bitfield is represented by a single zero byte.
                *last = 0;
            } else if used_bits_in_last_byte != 0 {
                *last &= (1u8 << used_bits_in_last_byte) - 1;
            }
        }

        Self::from_bytes(bytes).map_err(|_| arbitrary::Error::IncorrectFormat)
    }
}
655
#[cfg(feature = "arbitrary")]
impl<N: 'static + Unsigned> arbitrary::Arbitrary<'_> for Bitfield<Variable<N>> {
    /// Generates an arbitrary `BitList<N>` with a length in `0..=N` bits.
    ///
    /// The chosen length is a *bit* count, so `bytes_for_bit_len(len)` raw data bytes are drawn
    /// and any bits at an index `>= len` are cleared. The raw bytes are then wrapped directly,
    /// instead of being parsed as an SSZ encoding: random bytes rarely carry a valid
    /// length-delimiting bit in their final byte, so parsing them would reject most inputs.
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        let max_len = N::to_usize();
        let rand = usize::arbitrary(u)?;
        let bit_len = std::cmp::min(rand, max_len);

        let mut bytes: Vec<u8> = vec![0u8; bytes_for_bit_len(bit_len)];
        u.fill_buffer(&mut bytes)?;

        // Clear out-of-range bits in the final byte so the bytes form a valid raw bitfield.
        let used_bits_in_last_byte = bit_len % 8;
        if let Some(last) = bytes.last_mut() {
            if bit_len == 0 {
                // A zero-length bitfield is represented by a single zero byte.
                *last = 0;
            } else if used_bits_in_last_byte != 0 {
                *last &= (1u8 << used_bits_in_last_byte) - 1;
            }
        }

        Self::from_raw_bytes(bytes, bit_len).map_err(|_| arbitrary::Error::IncorrectFormat)
    }
}
667
#[cfg(test)]
mod bitvector {
    use super::*;
    use crate::BitVector;

    pub type BitVector0 = BitVector<typenum::U0>;
    pub type BitVector1 = BitVector<typenum::U1>;
    pub type BitVector4 = BitVector<typenum::U4>;
    pub type BitVector8 = BitVector<typenum::U8>;
    pub type BitVector16 = BitVector<typenum::U16>;
    pub type BitVector64 = BitVector<typenum::U64>;

    /// Returns `true` if `bytes` decodes as a valid SSZ value of type `T`.
    fn decodes<T: Decode>(bytes: &[u8]) -> bool {
        T::from_ssz_bytes(bytes).is_ok()
    }

    /// Asserts that `value` survives an SSZ encode/decode round trip unchanged.
    fn assert_round_trip<T: Encode + Decode + PartialEq + std::fmt::Debug>(value: T) {
        let encoded = value.as_ssz_bytes();
        assert_eq!(T::from_ssz_bytes(&encoded).unwrap(), value);
    }

    #[test]
    fn ssz_encode() {
        // Fresh vectors are all-zero, padded up to whole bytes with a one-byte minimum.
        assert_eq!(BitVector0::new().as_ssz_bytes(), vec![0b0000_0000]);
        assert_eq!(BitVector1::new().as_ssz_bytes(), vec![0b0000_0000]);
        assert_eq!(BitVector4::new().as_ssz_bytes(), vec![0b0000_0000]);
        assert_eq!(BitVector8::new().as_ssz_bytes(), vec![0b0000_0000]);
        assert_eq!(
            BitVector16::new().as_ssz_bytes(),
            vec![0b0000_0000, 0b0000_0000]
        );

        let mut all_eight = BitVector8::new();
        (0..8).for_each(|bit| all_eight.set(bit, true).unwrap());
        assert_eq!(all_eight.as_ssz_bytes(), vec![255]);

        let mut all_four = BitVector4::new();
        (0..4).for_each(|bit| all_four.set(bit, true).unwrap());
        assert_eq!(all_four.as_ssz_bytes(), vec![0b0000_1111]);
    }

    #[test]
    fn ssz_decode() {
        assert!(decodes::<BitVector0>(&[0b0000_0000]));
        assert!(!decodes::<BitVector0>(&[0b0000_0001]));
        assert!(!decodes::<BitVector0>(&[0b0000_0010]));

        assert!(decodes::<BitVector1>(&[0b0000_0001]));
        assert!(!decodes::<BitVector1>(&[0b0000_0010]));
        assert!(!decodes::<BitVector1>(&[0b0000_0100]));
        assert!(!decodes::<BitVector1>(&[0b0000_0000, 0b0000_0000]));

        assert!(decodes::<BitVector8>(&[0b0000_0000]));
        assert!(!decodes::<BitVector8>(&[1, 0b0000_0000]));
        assert!(!decodes::<BitVector8>(&[0b0000_0000, 1]));
        assert!(decodes::<BitVector8>(&[0b0000_0001]));
        assert!(decodes::<BitVector8>(&[0b0000_0010]));
        assert!(!decodes::<BitVector8>(&[0b0000_0100, 0b0000_0001]));
        assert!(!decodes::<BitVector8>(&[0b0000_0100, 0b0000_0010]));
        assert!(!decodes::<BitVector8>(&[0b0000_0100, 0b0000_0100]));

        assert!(!decodes::<BitVector16>(&[0b0000_0000]));
        assert!(decodes::<BitVector16>(&[0b0000_0000, 0b0000_0000]));
        assert!(!decodes::<BitVector16>(&[1, 0b0000_0000, 0b0000_0000]));
    }

    #[test]
    fn intersection() {
        let lhs = BitVector16::from_raw_bytes(vec![0b1100, 0b0001], 16).unwrap();
        let rhs = BitVector16::from_raw_bytes(vec![0b1011, 0b1001], 16).unwrap();
        let both = BitVector16::from_raw_bytes(vec![0b1000, 0b0001], 16).unwrap();

        // Commutative, absorbs its own result, and idempotent.
        assert_eq!(lhs.intersection(&rhs), both);
        assert_eq!(rhs.intersection(&lhs), both);
        assert_eq!(lhs.intersection(&both), both);
        assert_eq!(rhs.intersection(&both), both);
        for operand in [&lhs, &rhs, &both].iter() {
            assert_eq!(operand.intersection(operand), **operand);
        }
    }

    #[test]
    fn intersection_diff_length() {
        // Despite the name, every `BitVector16` shares one length; this checks byte-wise `and`.
        let lhs = BitVector16::from_bytes(vec![0b0010_1110, 0b0010_1011]).unwrap();
        let rhs = BitVector16::from_bytes(vec![0b0010_1101, 0b0000_0001]).unwrap();
        let expected = BitVector16::from_bytes(vec![0b0010_1100, 0b0000_0001]).unwrap();

        for bitvector in [&lhs, &rhs, &expected].iter() {
            assert_eq!(bitvector.len(), 16);
        }
        assert_eq!(lhs.intersection(&rhs), expected);
        assert_eq!(rhs.intersection(&lhs), expected);
    }

    #[test]
    fn union() {
        let lhs = BitVector16::from_raw_bytes(vec![0b1100, 0b0001], 16).unwrap();
        let rhs = BitVector16::from_raw_bytes(vec![0b1011, 0b1001], 16).unwrap();
        let either = BitVector16::from_raw_bytes(vec![0b1111, 0b1001], 16).unwrap();

        // Commutative and idempotent.
        assert_eq!(lhs.union(&rhs), either);
        assert_eq!(rhs.union(&lhs), either);
        for operand in [&lhs, &rhs, &either].iter() {
            assert_eq!(operand.union(operand), **operand);
        }
    }

    #[test]
    fn union_diff_length() {
        // Despite the name, every `BitVector16` shares one length; this checks byte-wise `or`.
        let lhs = BitVector16::from_bytes(vec![0b0010_1011, 0b0010_1110]).unwrap();
        let rhs = BitVector16::from_bytes(vec![0b0000_0001, 0b0010_1101]).unwrap();
        let expected = BitVector16::from_bytes(vec![0b0010_1011, 0b0010_1111]).unwrap();

        assert_eq!(lhs.len(), expected.len());
        assert_eq!(lhs.union(&rhs), expected);
        assert_eq!(rhs.union(&lhs), expected);
    }

    #[test]
    fn ssz_round_trip() {
        assert_round_trip(BitVector0::new());

        let mut only_bit = BitVector1::new();
        only_bit.set(0, true).unwrap();
        assert_round_trip(only_bit);

        let mut evens8 = BitVector8::new();
        for bit in (0..8).step_by(2) {
            evens8.set(bit, true).unwrap();
        }
        assert_round_trip(evens8);

        let mut full8 = BitVector8::new();
        for bit in 0..8 {
            full8.set(bit, true).unwrap();
        }
        assert_round_trip(full8);

        let mut evens16 = BitVector16::new();
        for bit in (0..16).step_by(2) {
            evens16.set(bit, true).unwrap();
        }
        assert_round_trip(evens16);

        let mut full16 = BitVector16::new();
        for bit in 0..16 {
            full16.set(bit, true).unwrap();
        }
        assert_round_trip(full16);
    }

    #[test]
    fn ssz_bytes_len() {
        for set_count in 0..64 {
            let mut bitfield = BitVector64::new();
            (0..set_count)
                .for_each(|bit| bitfield.set(bit, true).expect("should set bit in bounds"));
            let encoded = bitfield.as_ssz_bytes();
            assert_eq!(bitfield.ssz_bytes_len(), encoded.len(), "i = {}", set_count);
        }
    }

    #[test]
    fn excess_bits_nimbus() {
        // Bit 4 lies beyond the 4-bit length, so decoding must fail.
        let bad = vec![0b0001_1111];

        assert!(BitVector4::from_ssz_bytes(&bad).is_err());
    }
}
841
/// Unit tests for `BitList`, the variable-length (`Variable<N>`) flavour of `Bitfield`.
#[cfg(test)]
// Several tests below are long, flat runs of independent assertions, which the
// cognitive-complexity lint would otherwise flag.
#[allow(clippy::cognitive_complexity)]
mod bitlist {
    use super::*;
    use crate::BitList;

    // Lists with small, fixed type-level maximum lengths used throughout these tests.
    pub type BitList0 = BitList<typenum::U0>;
    pub type BitList1 = BitList<typenum::U1>;
    pub type BitList8 = BitList<typenum::U8>;
    pub type BitList16 = BitList<typenum::U16>;
    pub type BitList1024 = BitList<typenum::U1024>;

    /// SSZ encoding appends a length-marker bit immediately after the last bit of the list, so
    /// a length that is a multiple of 8 needs an extra byte just for the marker.
    #[test]
    fn ssz_encode() {
        assert_eq!(
            BitList0::with_capacity(0).unwrap().as_ssz_bytes(),
            vec![0b0000_0001],
        );

        assert_eq!(
            BitList1::with_capacity(0).unwrap().as_ssz_bytes(),
            vec![0b0000_0001],
        );

        assert_eq!(
            BitList1::with_capacity(1).unwrap().as_ssz_bytes(),
            vec![0b0000_0010],
        );

        assert_eq!(
            BitList8::with_capacity(8).unwrap().as_ssz_bytes(),
            vec![0b0000_0000, 0b0000_0001],
        );

        assert_eq!(
            BitList8::with_capacity(7).unwrap().as_ssz_bytes(),
            vec![0b1000_0000]
        );

        let mut b = BitList8::with_capacity(8).unwrap();
        for i in 0..8 {
            b.set(i, true).unwrap();
        }
        assert_eq!(b.as_ssz_bytes(), vec![255, 0b0000_0001]);

        let mut b = BitList8::with_capacity(8).unwrap();
        for i in 0..4 {
            b.set(i, true).unwrap();
        }
        assert_eq!(b.as_ssz_bytes(), vec![0b0000_1111, 0b0000_0001]);

        assert_eq!(
            BitList16::with_capacity(16).unwrap().as_ssz_bytes(),
            vec![0b0000_0000, 0b0000_0000, 0b0000_0001]
        );
    }

    /// Decoding rejects:
    /// - empty input;
    /// - input with no marker bit;
    /// - input whose marker implies a length greater than the type-level maximum.
    #[test]
    fn ssz_decode() {
        assert!(BitList0::from_ssz_bytes(&[]).is_err());
        assert!(BitList1::from_ssz_bytes(&[]).is_err());
        assert!(BitList8::from_ssz_bytes(&[]).is_err());
        assert!(BitList16::from_ssz_bytes(&[]).is_err());

        // No marker bit anywhere.
        assert!(BitList0::from_ssz_bytes(&[0b0000_0000]).is_err());
        assert!(BitList1::from_ssz_bytes(&[0b0000_0000, 0b0000_0000]).is_err());
        assert!(BitList8::from_ssz_bytes(&[0b0000_0000]).is_err());
        assert!(BitList16::from_ssz_bytes(&[0b0000_0000]).is_err());

        assert!(BitList0::from_ssz_bytes(&[0b0000_0001]).is_ok());
        assert!(BitList0::from_ssz_bytes(&[0b0000_0010]).is_err());

        assert!(BitList1::from_ssz_bytes(&[0b0000_0001]).is_ok());
        assert!(BitList1::from_ssz_bytes(&[0b0000_0010]).is_ok());
        assert!(BitList1::from_ssz_bytes(&[0b0000_0100]).is_err());

        assert!(BitList8::from_ssz_bytes(&[0b0000_0001]).is_ok());
        assert!(BitList8::from_ssz_bytes(&[0b0000_0010]).is_ok());
        assert!(BitList8::from_ssz_bytes(&[0b0000_0001, 0b0000_0001]).is_ok());
        assert!(BitList8::from_ssz_bytes(&[0b0000_0001, 0b0000_0010]).is_err());
        assert!(BitList8::from_ssz_bytes(&[0b0000_0001, 0b0000_0100]).is_err());
    }

    /// Trailing all-zero bytes after the byte holding the marker bit are rejected.
    #[test]
    fn ssz_decode_extra_bytes() {
        assert!(BitList0::from_ssz_bytes(&[0b0000_0001, 0b0000_0000]).is_err());
        assert!(BitList1::from_ssz_bytes(&[0b0000_0001, 0b0000_0000]).is_err());
        assert!(BitList8::from_ssz_bytes(&[0b0000_0001, 0b0000_0000]).is_err());
        assert!(BitList16::from_ssz_bytes(&[0b0000_0001, 0b0000_0000]).is_err());
        assert!(BitList1024::from_ssz_bytes(&[0b1000_0000, 0]).is_err());
        assert!(BitList1024::from_ssz_bytes(&[0b1000_0000, 0, 0]).is_err());
        assert!(BitList1024::from_ssz_bytes(&[0b1000_0000, 0, 0, 0, 0]).is_err());
    }

    /// Encoding then decoding returns the original list for every length from zero up to and
    /// including the type-level maximum, with no bits, alternating bits, or all bits set.
    #[test]
    fn ssz_round_trip() {
        assert_round_trip(BitList0::with_capacity(0).unwrap());

        for i in 0..=1 {
            assert_round_trip(BitList1::with_capacity(i).unwrap());
        }
        for i in 0..=8 {
            assert_round_trip(BitList8::with_capacity(i).unwrap());
        }
        for i in 0..=16 {
            assert_round_trip(BitList16::with_capacity(i).unwrap());
        }

        let mut b = BitList1::with_capacity(1).unwrap();
        b.set(0, true).unwrap();
        assert_round_trip(b);

        // Inclusive bounds: at full capacity (8 and 16, both multiples of 8) the length marker
        // needs an extra byte, so that boundary must be round-tripped with bits set too.
        for i in 0..=8 {
            let mut b = BitList8::with_capacity(i).unwrap();
            for j in 0..i {
                if j % 2 == 0 {
                    b.set(j, true).unwrap();
                }
            }
            assert_round_trip(b);

            let mut b = BitList8::with_capacity(i).unwrap();
            for j in 0..i {
                b.set(j, true).unwrap();
            }
            assert_round_trip(b);
        }

        for i in 0..=16 {
            let mut b = BitList16::with_capacity(i).unwrap();
            for j in 0..i {
                if j % 2 == 0 {
                    b.set(j, true).unwrap();
                }
            }
            assert_round_trip(b);

            let mut b = BitList16::with_capacity(i).unwrap();
            for j in 0..i {
                b.set(j, true).unwrap();
            }
            assert_round_trip(b);
        }
    }

    /// Serializes `t` to SSZ, deserializes it back, and asserts the result equals `t`.
    fn assert_round_trip<T: Encode + Decode + PartialEq + std::fmt::Debug>(t: T) {
        assert_eq!(T::from_ssz_bytes(&t.as_ssz_bytes()).unwrap(), t);
    }

    /// `from_raw_bytes` requires the exact number of bytes for `num_bits` (at least one) and no
    /// bit set at or above index `num_bits`.
    #[test]
    fn from_raw_bytes() {
        assert!(BitList1024::from_raw_bytes(vec![0b0000_0000], 0).is_ok());
        assert!(BitList1024::from_raw_bytes(vec![0b0000_0001], 1).is_ok());
        assert!(BitList1024::from_raw_bytes(vec![0b0000_0011], 2).is_ok());
        assert!(BitList1024::from_raw_bytes(vec![0b0000_0111], 3).is_ok());
        assert!(BitList1024::from_raw_bytes(vec![0b0000_1111], 4).is_ok());
        assert!(BitList1024::from_raw_bytes(vec![0b0001_1111], 5).is_ok());
        assert!(BitList1024::from_raw_bytes(vec![0b0011_1111], 6).is_ok());
        assert!(BitList1024::from_raw_bytes(vec![0b0111_1111], 7).is_ok());
        assert!(BitList1024::from_raw_bytes(vec![0b1111_1111], 8).is_ok());

        assert!(BitList1024::from_raw_bytes(vec![0b1111_1111, 0b0000_0001], 9).is_ok());
        assert!(BitList1024::from_raw_bytes(vec![0b1111_1111, 0b0000_0011], 10).is_ok());
        assert!(BitList1024::from_raw_bytes(vec![0b1111_1111, 0b0000_0111], 11).is_ok());
        assert!(BitList1024::from_raw_bytes(vec![0b1111_1111, 0b0000_1111], 12).is_ok());
        assert!(BitList1024::from_raw_bytes(vec![0b1111_1111, 0b0001_1111], 13).is_ok());
        assert!(BitList1024::from_raw_bytes(vec![0b1111_1111, 0b0011_1111], 14).is_ok());
        assert!(BitList1024::from_raw_bytes(vec![0b1111_1111, 0b0111_1111], 15).is_ok());
        assert!(BitList1024::from_raw_bytes(vec![0b1111_1111, 0b1111_1111], 16).is_ok());

        // Too few bytes, bits set beyond `num_bits`, or too many bytes.
        for i in 0..8 {
            assert!(BitList1024::from_raw_bytes(vec![], i).is_err());
            assert!(BitList1024::from_raw_bytes(vec![0b1111_1111], i).is_err());
            assert!(BitList1024::from_raw_bytes(vec![0b0000_0000, 0b1111_1110], i).is_err());
        }

        // Exactly one bit set beyond `num_bits`.
        assert!(BitList1024::from_raw_bytes(vec![0b0000_0001], 0).is_err());
        assert!(BitList1024::from_raw_bytes(vec![0b0000_0011], 1).is_err());
        assert!(BitList1024::from_raw_bytes(vec![0b0000_0111], 2).is_err());
        assert!(BitList1024::from_raw_bytes(vec![0b0000_1111], 3).is_err());
        assert!(BitList1024::from_raw_bytes(vec![0b0001_1111], 4).is_err());
        assert!(BitList1024::from_raw_bytes(vec![0b0011_1111], 5).is_err());
        assert!(BitList1024::from_raw_bytes(vec![0b0111_1111], 6).is_err());
        assert!(BitList1024::from_raw_bytes(vec![0b1111_1111], 7).is_err());

        assert!(BitList1024::from_raw_bytes(vec![0b1111_1111, 0b0000_0001], 8).is_err());
        assert!(BitList1024::from_raw_bytes(vec![0b1111_1111, 0b0000_0011], 9).is_err());
        assert!(BitList1024::from_raw_bytes(vec![0b1111_1111, 0b0000_0111], 10).is_err());
        assert!(BitList1024::from_raw_bytes(vec![0b1111_1111, 0b0000_1111], 11).is_err());
        assert!(BitList1024::from_raw_bytes(vec![0b1111_1111, 0b0001_1111], 12).is_err());
        assert!(BitList1024::from_raw_bytes(vec![0b1111_1111, 0b0011_1111], 13).is_err());
        assert!(BitList1024::from_raw_bytes(vec![0b1111_1111, 0b0111_1111], 14).is_err());
        assert!(BitList1024::from_raw_bytes(vec![0b1111_1111, 0b1111_1111], 15).is_err());
    }

    /// Each in-range bit of a `num_bits`-long list checks three things:
    /// - it starts `false`;
    /// - it can be set to `true`;
    /// - it can be cleared back to `false`.
    ///
    /// Index `num_bits` is out of bounds for both `get` and `set`.
    fn test_set_unset(num_bits: usize) {
        let mut bitfield = BitList1024::with_capacity(num_bits).unwrap();

        for i in 0..=num_bits {
            if i < num_bits {
                // Starts as false
                assert_eq!(bitfield.get(i), Ok(false));
                // Can be set true.
                assert!(bitfield.set(i, true).is_ok());
                assert_eq!(bitfield.get(i), Ok(true));
                // Can be set false
                assert!(bitfield.set(i, false).is_ok());
                assert_eq!(bitfield.get(i), Ok(false));
            } else {
                assert!(bitfield.get(i).is_err());
                assert!(bitfield.set(i, true).is_err());
                assert!(bitfield.get(i).is_err());
            }
        }
    }

    /// For each single set bit, `into_raw_bytes` followed by `from_raw_bytes` reproduces the
    /// original list.
    fn test_bytes_round_trip(num_bits: usize) {
        for i in 0..num_bits {
            let mut bitfield = BitList1024::with_capacity(num_bits).unwrap();
            bitfield.set(i, true).unwrap();

            let bytes = bitfield.clone().into_raw_bytes();
            assert_eq!(bitfield, Bitfield::from_raw_bytes(bytes, num_bits).unwrap());
        }
    }

    #[test]
    fn set_unset() {
        for i in 0..8 * 5 {
            test_set_unset(i)
        }
    }

    #[test]
    fn bytes_round_trip() {
        for i in 0..8 * 5 {
            test_bytes_round_trip(i)
        }
    }

    /// Raw bytes carry no length marker.
    ///
    /// Bit `i` is stored in bit `i % 8` (least-significant first) of byte `i / 8`.
    #[test]
    fn into_raw_bytes() {
        let mut bitfield = BitList1024::with_capacity(9).unwrap();
        bitfield.set(0, true).unwrap();
        assert_eq!(
            bitfield.clone().into_raw_bytes(),
            vec![0b0000_0001, 0b0000_0000]
        );
        bitfield.set(1, true).unwrap();
        assert_eq!(
            bitfield.clone().into_raw_bytes(),
            vec![0b0000_0011, 0b0000_0000]
        );
        bitfield.set(2, true).unwrap();
        assert_eq!(
            bitfield.clone().into_raw_bytes(),
            vec![0b0000_0111, 0b0000_0000]
        );
        bitfield.set(3, true).unwrap();
        assert_eq!(
            bitfield.clone().into_raw_bytes(),
            vec![0b0000_1111, 0b0000_0000]
        );
        bitfield.set(4, true).unwrap();
        assert_eq!(
            bitfield.clone().into_raw_bytes(),
            vec![0b0001_1111, 0b0000_0000]
        );
        bitfield.set(5, true).unwrap();
        assert_eq!(
            bitfield.clone().into_raw_bytes(),
            vec![0b0011_1111, 0b0000_0000]
        );
        bitfield.set(6, true).unwrap();
        assert_eq!(
            bitfield.clone().into_raw_bytes(),
            vec![0b0111_1111, 0b0000_0000]
        );
        bitfield.set(7, true).unwrap();
        assert_eq!(
            bitfield.clone().into_raw_bytes(),
            vec![0b1111_1111, 0b0000_0000]
        );
        bitfield.set(8, true).unwrap();
        assert_eq!(bitfield.into_raw_bytes(), vec![0b1111_1111, 0b0000_0001]);
    }

    /// `highest_set_bit` is `None` for an all-zero list, otherwise the index of the top set bit.
    #[test]
    fn highest_set_bit() {
        assert_eq!(
            BitList1024::with_capacity(16).unwrap().highest_set_bit(),
            None
        );

        assert_eq!(
            BitList1024::from_raw_bytes(vec![0b0000_0001, 0b0000_0000], 16)
                .unwrap()
                .highest_set_bit(),
            Some(0)
        );

        assert_eq!(
            BitList1024::from_raw_bytes(vec![0b0000_0010, 0b0000_0000], 16)
                .unwrap()
                .highest_set_bit(),
            Some(1)
        );

        assert_eq!(
            BitList1024::from_raw_bytes(vec![0b0000_1000], 8)
                .unwrap()
                .highest_set_bit(),
            Some(3)
        );

        assert_eq!(
            BitList1024::from_raw_bytes(vec![0b0000_0000, 0b1000_0000], 16)
                .unwrap()
                .highest_set_bit(),
            Some(15)
        );
    }

    /// Intersection of equal-length lists is the bitwise AND; it is commutative and idempotent.
    #[test]
    fn intersection() {
        let a = BitList1024::from_raw_bytes(vec![0b1100, 0b0001], 16).unwrap();
        let b = BitList1024::from_raw_bytes(vec![0b1011, 0b1001], 16).unwrap();
        let c = BitList1024::from_raw_bytes(vec![0b1000, 0b0001], 16).unwrap();

        assert_eq!(a.intersection(&b), c);
        assert_eq!(b.intersection(&a), c);
        assert_eq!(a.intersection(&c), c);
        assert_eq!(b.intersection(&c), c);
        assert_eq!(a.intersection(&a), a);
        assert_eq!(b.intersection(&b), b);
        assert_eq!(c.intersection(&c), c);
    }

    /// Intersecting lists of different lengths yields a list with the shorter length.
    #[test]
    fn intersection_diff_length() {
        let a = BitList1024::from_bytes(vec![0b0010_1110, 0b0010_1011]).unwrap();
        let b = BitList1024::from_bytes(vec![0b0010_1101, 0b0000_0001]).unwrap();
        let c = BitList1024::from_bytes(vec![0b0010_1100, 0b0000_0001]).unwrap();
        let d = BitList1024::from_bytes(vec![0b0010_1110, 0b1111_1111, 0b1111_1111]).unwrap();

        assert_eq!(a.len(), 13);
        assert_eq!(b.len(), 8);
        assert_eq!(c.len(), 8);
        assert_eq!(d.len(), 23);
        assert_eq!(a.intersection(&b), c);
        assert_eq!(b.intersection(&a), c);
        assert_eq!(a.intersection(&d), a);
        assert_eq!(d.intersection(&a), a);
    }

    /// Union of equal-length lists is the bitwise OR; it is commutative and idempotent.
    #[test]
    fn union() {
        let a = BitList1024::from_raw_bytes(vec![0b1100, 0b0001], 16).unwrap();
        let b = BitList1024::from_raw_bytes(vec![0b1011, 0b1001], 16).unwrap();
        let c = BitList1024::from_raw_bytes(vec![0b1111, 0b1001], 16).unwrap();

        assert_eq!(a.union(&b), c);
        assert_eq!(b.union(&a), c);
        assert_eq!(a.union(&a), a);
        assert_eq!(b.union(&b), b);
        assert_eq!(c.union(&c), c);
    }

    /// Union of lists of different lengths yields a list with the longer length.
    #[test]
    fn union_diff_length() {
        let a = BitList1024::from_bytes(vec![0b0010_1011, 0b0010_1110]).unwrap();
        let b = BitList1024::from_bytes(vec![0b0000_0001, 0b0010_1101]).unwrap();
        let c = BitList1024::from_bytes(vec![0b0010_1011, 0b0010_1111]).unwrap();
        let d = BitList1024::from_bytes(vec![0b0010_1011, 0b1011_1110, 0b1000_1101]).unwrap();

        assert_eq!(a.len(), c.len());
        assert_eq!(a.union(&b), c);
        assert_eq!(b.union(&a), c);
        assert_eq!(a.union(&d), d);
        assert_eq!(d.union(&a), d);
    }

    /// `a.difference(&b)` keeps the bits of `a` that are not set in `b`.
    #[test]
    fn difference() {
        let a = BitList1024::from_raw_bytes(vec![0b1100, 0b0001], 16).unwrap();
        let b = BitList1024::from_raw_bytes(vec![0b1011, 0b1001], 16).unwrap();
        let a_b = BitList1024::from_raw_bytes(vec![0b0100, 0b0000], 16).unwrap();
        let b_a = BitList1024::from_raw_bytes(vec![0b0011, 0b1000], 16).unwrap();

        assert_eq!(a.difference(&b), a_b);
        assert_eq!(b.difference(&a), b_a);
        assert!(a.difference(&a).is_zero());
    }

    /// The difference keeps the receiver's length.
    ///
    /// Bits of the receiver beyond the other list's length are carried through unchanged.
    #[test]
    fn difference_diff_length() {
        let a = BitList1024::from_raw_bytes(vec![0b0110, 0b1100, 0b0011], 24).unwrap();
        let b = BitList1024::from_raw_bytes(vec![0b1011, 0b1001], 16).unwrap();
        let a_b = BitList1024::from_raw_bytes(vec![0b0100, 0b0100, 0b0011], 24).unwrap();
        let b_a = BitList1024::from_raw_bytes(vec![0b1001, 0b0001], 16).unwrap();

        assert_eq!(a.difference(&b), a_b);
        assert_eq!(b.difference(&a), b_a);
    }

    /// `shift_up(n)` moves every bit `n` places to a higher index, dropping bits that fall off.
    ///
    /// Shifting by the full length clears the list; shifting by more than the length is an
    /// error.
    #[test]
    fn shift_up() {
        let mut a = BitList1024::from_raw_bytes(vec![0b1100_1111, 0b1101_0110], 16).unwrap();
        let mut b = BitList1024::from_raw_bytes(vec![0b1001_1110, 0b1010_1101], 16).unwrap();

        a.shift_up(1).unwrap();
        assert_eq!(a, b);
        a.shift_up(15).unwrap();
        assert!(a.is_zero());

        b.shift_up(16).unwrap();
        assert!(b.is_zero());
        assert!(b.shift_up(17).is_err());
    }

    /// `num_set_bits` counts the bits that are `true`.
    #[test]
    fn num_set_bits() {
        let a = BitList1024::from_raw_bytes(vec![0b1100, 0b0001], 16).unwrap();
        let b = BitList1024::from_raw_bytes(vec![0b1011, 0b1001], 16).unwrap();

        assert_eq!(a.num_set_bits(), 3);
        assert_eq!(b.num_set_bits(), 5);
    }

    /// `iter` yields every bit of the list in index order.
    #[test]
    fn iter() {
        let mut bitfield = BitList1024::with_capacity(9).unwrap();
        bitfield.set(2, true).unwrap();
        bitfield.set(8, true).unwrap();

        assert_eq!(
            bitfield.iter().collect::<Vec<bool>>(),
            vec![false, false, true, false, false, false, false, false, true]
        );
    }

    /// `ssz_bytes_len` agrees with the length of the actual SSZ encoding for every list length,
    /// including the empty list.
    #[test]
    fn ssz_bytes_len() {
        for i in 0..64 {
            let mut bitfield = BitList1024::with_capacity(i).unwrap();
            for j in 0..i {
                bitfield.set(j, true).expect("should set bit in bounds");
            }
            let bytes = bitfield.as_ssz_bytes();
            assert_eq!(bitfield.ssz_bytes_len(), bytes.len(), "i = {}", i);
        }
    }
}