Skip to main content

ssz/
bitfield.rs

1use crate::{Decode, DecodeError, Encode};
2use core::marker::PhantomData;
3use serde::de::{Deserialize, Deserializer};
4use serde::ser::{Serialize, Serializer};
5use serde_utils::hex::{encode as hex_encode, PrefixedHexVisitor};
6use smallvec::{smallvec, SmallVec, ToSmallVec};
7use typenum::Unsigned;
8pub mod bitvector_dynamic;
9
/// Returned when an operation on a bitfield fails.
#[derive(PartialEq, Debug, Clone)]
pub enum Error {
    /// An index or requested length exceeded the permitted bound.
    OutOfBounds {
        /// The offending index (or requested length).
        i: usize,
        /// The bound that `i` was checked against.
        len: usize,
    },
    /// A `BitList` does not have a set bit, therefore its length is unknowable.
    MissingLengthInformation,
    /// A bitfield has bits set to true at or beyond its length.
    ExcessBits,
    /// A bitfield has an invalid number of bytes for a given bit length.
    InvalidByteCount {
        /// The value that was supplied.
        given: usize,
        /// The value that was required.
        expected: usize,
    },
}
27
/// Maximum number of bytes to store on the stack in a bitfield's `SmallVec`.
///
/// 128 bytes is enough to take us through to ~2M active validators, as the byte
/// length of attestation bitfields is roughly `N // 32 slots // 64 committees //
/// 8 bits`.
pub const SMALLVEC_LEN: usize = 128;
34
35/// A marker trait applied to `Variable` and `Fixed` that defines the behaviour of a `Bitfield`.
36pub trait BitfieldBehaviour: Clone {}
37
38/// A marker struct used to declare SSZ `Variable` behaviour on a `Bitfield`.
39///
40/// See the [`Bitfield`](struct.Bitfield.html) docs for usage.
41#[derive(Clone, PartialEq, Eq, Debug)]
42pub struct Variable<N> {
43    _phantom: PhantomData<N>,
44}
45
46/// A marker struct used to declare SSZ `Fixed` behaviour on a `Bitfield`.
47///
48/// See the [`Bitfield`](struct.Bitfield.html) docs for usage.
49#[derive(Clone, PartialEq, Eq, Debug)]
50pub struct Fixed<N> {
51    _phantom: PhantomData<N>,
52}
53
54impl<N: Unsigned + Clone> BitfieldBehaviour for Variable<N> {}
55impl<N: Unsigned + Clone> BitfieldBehaviour for Fixed<N> {}
56
57/// A heap-allocated, ordered, variable-length collection of `bool` values, limited to `N` bits.
58pub type BitList<N> = Bitfield<Variable<N>>;
59
60/// A heap-allocated, ordered, fixed-length collection of `bool` values, with `N` bits.
61///
62/// See [Bitfield](struct.Bitfield.html) documentation.
63pub type BitVector<N> = Bitfield<Fixed<N>>;
64
65/// A heap-allocated, ordered, fixed-length, collection of `bool` values. Use of
66/// [`BitList`](type.BitList.html) or [`BitVector`](type.BitVector.html) type aliases is preferred
67/// over direct use of this struct.
68///
69/// The `T` type parameter is used to define length behaviour with the `Variable` or `Fixed` marker
70/// structs.
71///
72/// The length of the Bitfield is set at instantiation (i.e., runtime, not compile time). However,
73/// use with a `Variable` sets a type-level (i.e., compile-time) maximum length and `Fixed`
74/// provides a type-level fixed length.
75///
76/// ## Example
77///
78/// The example uses the following crate-level type aliases:
79///
80/// - `BitList<N>` is an alias for `Bitfield<Variable<N>>`
81/// - `BitVector<N>` is an alias for `Bitfield<Fixed<N>>`
82///
83/// ```
84/// use ssz::{BitVector, BitList};
85/// use typenum;
86///
87/// // `BitList` has a type-level maximum length. The length of the list is specified at runtime
88/// // and it must be less than or equal to `N`. After instantiation, `BitList` cannot grow or
89/// // shrink.
90/// type BitList8 = BitList<typenum::U8>;
91///
92/// // Creating a `BitList` with a larger-than-`N` capacity returns `None`.
93/// assert!(BitList8::with_capacity(9).is_err());
94///
95/// let mut bitlist = BitList8::with_capacity(4).unwrap();  // `BitList` permits a capacity of less than the maximum.
96/// assert!(bitlist.set(3, true).is_ok());  // Setting inside the instantiation capacity is permitted.
97/// assert!(bitlist.set(5, true).is_err());  // Setting outside that capacity is not.
98///
99/// // `BitVector` has a type-level fixed length. Unlike `BitList`, it cannot be instantiated with a custom length
100/// // or grow/shrink.
101/// type BitVector8 = BitVector<typenum::U8>;
102///
103/// let mut bitvector = BitVector8::new();
104/// assert_eq!(bitvector.len(), 8); // `BitVector` length is fixed at the type-level.
105/// assert!(bitvector.set(7, true).is_ok());  // Setting inside the capacity is permitted.
106/// assert!(bitvector.set(9, true).is_err());  // Setting outside the capacity is not.
107///
108/// ```
109///
110/// ## Note
111///
112/// The internal representation of the bitfield is the same as that required by SSZ. The lowest
113/// byte (by `Vec` index) stores the lowest bit-indices and the right-most bit stores the lowest
114/// bit-index. E.g., `smallvec![0b0000_0001, 0b0000_0010]` has bits `0, 9` set.
115#[derive(Clone, Debug)]
116pub struct Bitfield<T> {
117    bytes: SmallVec<[u8; SMALLVEC_LEN]>,
118    len: usize,
119    _phantom: PhantomData<T>,
120}
121
122impl<N: Unsigned + Clone> Bitfield<Variable<N>> {
123    /// Instantiate with capacity for `num_bits` boolean values. The length cannot be grown or
124    /// shrunk after instantiation.
125    ///
126    /// All bits are initialized to `false`.
127    ///
128    /// Returns `Err` if `num_bits > N`.
129    pub fn with_capacity(num_bits: usize) -> Result<Self, Error> {
130        if num_bits <= N::to_usize() {
131            Ok(Self {
132                bytes: smallvec![0; bytes_for_bit_len(num_bits)],
133                len: num_bits,
134                _phantom: PhantomData,
135            })
136        } else {
137            Err(Error::OutOfBounds {
138                i: num_bits,
139                len: Self::max_len(),
140            })
141        }
142    }
143
144    /// Equal to `N` regardless of the value supplied to `with_capacity`.
145    pub fn max_len() -> usize {
146        N::to_usize()
147    }
148
149    /// Consumes `self`, returning a serialized representation.
150    ///
151    /// The output is faithful to the SSZ encoding of `self`, such that a leading `true` bit is
152    /// used to indicate the length of the bitfield.
153    ///
154    /// ## Example
155    /// ```
156    /// use ssz::BitList;
157    /// use smallvec::SmallVec;
158    /// use typenum;
159    ///
160    /// type BitList8 = BitList<typenum::U8>;
161    ///
162    /// let b = BitList8::with_capacity(4).unwrap();
163    ///
164    /// assert_eq!(b.into_bytes(), SmallVec::from_buf([0b0001_0000]));
165    /// ```
166    pub fn into_bytes(self) -> SmallVec<[u8; SMALLVEC_LEN]> {
167        let len = self.len();
168        let mut bytes = self.bytes;
169
170        bytes.resize(bytes_for_bit_len(len + 1), 0);
171
172        let mut bitfield: Bitfield<Variable<N>> = Bitfield::from_raw_bytes(bytes, len + 1)
173            .unwrap_or_else(|_| {
174                unreachable!(
175                    "Bitfield with {} bytes must have enough capacity for {} bits.",
176                    bytes_for_bit_len(len + 1),
177                    len + 1
178                )
179            });
180        bitfield
181            .set(len, true)
182            .expect("len must be in bounds for bitfield.");
183
184        bitfield.bytes
185    }
186
187    /// Instantiates a new instance from `bytes`. Consumes the same format that `self.into_bytes()`
188    /// produces (SSZ).
189    ///
190    /// Returns `None` if `bytes` are not a valid encoding.
191    pub fn from_bytes(bytes: SmallVec<[u8; SMALLVEC_LEN]>) -> Result<Self, Error> {
192        let bytes_len = bytes.len();
193        let mut initial_bitfield: Bitfield<Variable<N>> = {
194            let num_bits = bytes.len() * 8;
195            Bitfield::from_raw_bytes(bytes, num_bits)?
196        };
197
198        let len = initial_bitfield
199            .highest_set_bit()
200            .ok_or(Error::MissingLengthInformation)?;
201
202        // The length bit should be in the last byte, or else it means we have too many bytes.
203        if len / 8 + 1 != bytes_len {
204            return Err(Error::InvalidByteCount {
205                given: bytes_len,
206                expected: len / 8 + 1,
207            });
208        }
209
210        if len <= Self::max_len() {
211            initial_bitfield
212                .set(len, false)
213                .expect("Bit has been confirmed to exist");
214
215            let mut bytes = initial_bitfield.into_raw_bytes();
216
217            bytes.truncate(bytes_for_bit_len(len));
218
219            Self::from_raw_bytes(bytes, len)
220        } else {
221            Err(Error::OutOfBounds {
222                i: Self::max_len(),
223                len: Self::max_len(),
224            })
225        }
226    }
227
228    /// Compute the intersection of two BitLists of potentially different lengths.
229    ///
230    /// Return a new BitList with length equal to the shorter of the two inputs.
231    pub fn intersection(&self, other: &Self) -> Self {
232        let min_len = std::cmp::min(self.len(), other.len());
233        let mut result = Self::with_capacity(min_len).expect("min len always less than N");
234        // Bitwise-and the bytes together, starting from the left of each vector. This takes care
235        // of masking out any entries beyond `min_len` as well, assuming the bitfield doesn't
236        // contain any set bits beyond its length.
237        for i in 0..result.bytes.len() {
238            result.bytes[i] = self.bytes[i] & other.bytes[i];
239        }
240        result
241    }
242
243    /// Compute the union of two BitLists of potentially different lengths.
244    ///
245    /// Return a new BitList with length equal to the longer of the two inputs.
246    pub fn union(&self, other: &Self) -> Self {
247        let max_len = std::cmp::max(self.len(), other.len());
248        let mut result = Self::with_capacity(max_len).expect("max len always less than N");
249        for i in 0..result.bytes.len() {
250            result.bytes[i] =
251                self.bytes.get(i).copied().unwrap_or(0) | other.bytes.get(i).copied().unwrap_or(0);
252        }
253        result
254    }
255
256    /// Returns `true` if `self` is a subset of `other` and `false` otherwise.
257    pub fn is_subset(&self, other: &Self) -> bool {
258        self.difference(other).is_zero()
259    }
260
261    /// Returns a new BitList of length M, with the same bits set as `self`.
262    pub fn resize<M: Unsigned + Clone>(&self) -> Result<Bitfield<Variable<M>>, Error> {
263        if N::to_usize() > M::to_usize() {
264            return Err(Error::InvalidByteCount {
265                given: M::to_usize(),
266                expected: N::to_usize() + 1,
267            });
268        }
269
270        let mut resized = Bitfield::<Variable<M>>::with_capacity(M::to_usize())?;
271
272        for (i, bit) in self.iter().enumerate() {
273            resized.set(i, bit)?;
274        }
275
276        Ok(resized)
277    }
278
279    /// Returns a clone of the bitfield with all bits set to false.
280    ///
281    /// Compared to `with_capacity`, this is infallible.
282    pub fn clone_zeroed(&self) -> Self {
283        Self::with_capacity(self.len()).expect("`len` is guaranteed to be `N` or less")
284    }
285}
286
287impl<N: Unsigned + Clone> Bitfield<Fixed<N>> {
288    /// Instantiate a new `Bitfield` with a fixed-length of `N` bits.
289    ///
290    /// All bits are initialized to `false`.
291    pub fn new() -> Self {
292        Self {
293            bytes: smallvec![0; bytes_for_bit_len(Self::capacity())],
294            len: Self::capacity(),
295            _phantom: PhantomData,
296        }
297    }
298
299    /// Returns `N`, the number of bits in `Self`.
300    pub fn capacity() -> usize {
301        N::to_usize()
302    }
303
304    /// Consumes `self`, returning a serialized representation.
305    ///
306    /// The output is faithful to the SSZ encoding of `self`.
307    ///
308    /// ## Example
309    /// ```
310    /// use ssz::BitVector;
311    /// use smallvec::SmallVec;
312    /// use typenum;
313    ///
314    /// type BitVector4 = BitVector<typenum::U4>;
315    ///
316    /// assert_eq!(BitVector4::new().into_bytes(), SmallVec::from_buf([0b0000_0000]));
317    /// ```
318    pub fn into_bytes(self) -> SmallVec<[u8; SMALLVEC_LEN]> {
319        self.into_raw_bytes()
320    }
321
322    /// Instantiates a new instance from `bytes`. Consumes the same format that `self.into_bytes()`
323    /// produces (SSZ).
324    ///
325    /// Returns `None` if `bytes` are not a valid encoding.
326    pub fn from_bytes(bytes: SmallVec<[u8; SMALLVEC_LEN]>) -> Result<Self, Error> {
327        Self::from_raw_bytes(bytes, Self::capacity())
328    }
329
330    /// Compute the intersection of two fixed-length `Bitfield`s.
331    ///
332    /// Return a new fixed-length `Bitfield`.
333    pub fn intersection(&self, other: &Self) -> Self {
334        let mut result = Self::new();
335        // Bitwise-and the bytes together, starting from the left of each vector. This takes care
336        // of masking out any entries beyond `min_len` as well, assuming the bitfield doesn't
337        // contain any set bits beyond its length.
338        for i in 0..result.bytes.len() {
339            result.bytes[i] = self.bytes[i] & other.bytes[i];
340        }
341        result
342    }
343
344    /// Compute the union of two fixed-length `Bitfield`s.
345    ///
346    /// Return a new fixed-length `Bitfield`.
347    pub fn union(&self, other: &Self) -> Self {
348        let mut result = Self::new();
349        for i in 0..result.bytes.len() {
350            result.bytes[i] =
351                self.bytes.get(i).copied().unwrap_or(0) | other.bytes.get(i).copied().unwrap_or(0);
352        }
353        result
354    }
355
356    /// Returns `true` if `self` is a subset of `other` and `false` otherwise.
357    pub fn is_subset(&self, other: &Self) -> bool {
358        self.difference(other).is_zero()
359    }
360}
361
362impl<T: BitfieldBehaviour> std::fmt::Display for Bitfield<T> {
363    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
364        let mut field: String = "".to_string();
365        for i in self.iter() {
366            if i {
367                field.push('1')
368            } else {
369                field.push('0')
370            }
371        }
372        write!(f, "{field}")
373    }
374}
375
376impl<N: Unsigned + Clone> Default for Bitfield<Fixed<N>> {
377    fn default() -> Self {
378        Self::new()
379    }
380}
381
382impl<T: BitfieldBehaviour> Bitfield<T> {
383    /// Sets the `i`'th bit to `value`.
384    ///
385    /// Returns `None` if `i` is out-of-bounds of `self`.
386    pub fn set(&mut self, i: usize, value: bool) -> Result<(), Error> {
387        let len = self.len;
388
389        if i < len {
390            let byte = self
391                .bytes
392                .get_mut(i / 8)
393                .ok_or(Error::OutOfBounds { i, len })?;
394
395            if value {
396                *byte |= 1 << (i % 8)
397            } else {
398                *byte &= !(1 << (i % 8))
399            }
400
401            Ok(())
402        } else {
403            Err(Error::OutOfBounds { i, len: self.len })
404        }
405    }
406
407    /// Returns the value of the `i`'th bit.
408    ///
409    /// Returns `Error` if `i` is out-of-bounds of `self`.
410    pub fn get(&self, i: usize) -> Result<bool, Error> {
411        if i < self.len {
412            let byte = self
413                .bytes
414                .get(i / 8)
415                .ok_or(Error::OutOfBounds { i, len: self.len })?;
416
417            Ok(*byte & (1 << (i % 8)) > 0)
418        } else {
419            Err(Error::OutOfBounds { i, len: self.len })
420        }
421    }
422
423    /// Returns the number of bits stored in `self`.
424    pub fn len(&self) -> usize {
425        self.len
426    }
427
428    /// Returns `true` if `self.len() == 0`.
429    pub fn is_empty(&self) -> bool {
430        self.len == 0
431    }
432
433    /// Returns the underlying bytes representation of the bitfield.
434    pub fn into_raw_bytes(self) -> SmallVec<[u8; SMALLVEC_LEN]> {
435        self.bytes
436    }
437
438    /// Returns a view into the underlying bytes representation of the bitfield.
439    pub fn as_slice(&self) -> &[u8] {
440        &self.bytes
441    }
442
443    /// Instantiates from the given `bytes`, which are the same format as output from
444    /// `self.into_raw_bytes()`.
445    ///
446    /// Returns `None` if:
447    ///
448    /// - `bytes` is not the minimal required bytes to represent a bitfield of `bit_len` bits.
449    /// - `bit_len` is not a multiple of 8 and `bytes` contains set bits that are higher than, or
450    ///   equal to `bit_len`.
451    fn from_raw_bytes(bytes: SmallVec<[u8; SMALLVEC_LEN]>, bit_len: usize) -> Result<Self, Error> {
452        if bit_len == 0 {
453            if bytes.len() == 1 && bytes[0] == 0 {
454                // A bitfield with `bit_len` 0 can only be represented by a single zero byte.
455                Ok(Self {
456                    bytes,
457                    len: 0,
458                    _phantom: PhantomData,
459                })
460            } else {
461                Err(Error::ExcessBits)
462            }
463        } else if bytes.len() != bytes_for_bit_len(bit_len) {
464            // The number of bytes must be the minimum required to represent `bit_len`.
465            Err(Error::InvalidByteCount {
466                given: bytes.len(),
467                expected: bytes_for_bit_len(bit_len),
468            })
469        } else {
470            // Ensure there are no bits higher than `bit_len` that are set to true.
471            let mask = last_byte_mask(bit_len);
472
473            if (bytes.last().expect("Guarded against empty bytes") & !mask) == 0 {
474                Ok(Self {
475                    bytes,
476                    len: bit_len,
477                    _phantom: PhantomData,
478                })
479            } else {
480                Err(Error::ExcessBits)
481            }
482        }
483    }
484
485    /// Returns the `Some(i)` where `i` is the highest index with a set bit. Returns `None` if
486    /// there are no set bits.
487    pub fn highest_set_bit(&self) -> Option<usize> {
488        self.bytes
489            .iter()
490            .enumerate()
491            .rev()
492            .find(|(_, byte)| **byte > 0)
493            .map(|(i, byte)| i * 8 + 7 - byte.leading_zeros() as usize)
494    }
495
496    /// Returns an iterator across bitfield `bool` values, starting at the lowest index.
497    pub fn iter(&self) -> BitIter<'_, T> {
498        BitIter {
499            bitfield: self,
500            i: 0,
501        }
502    }
503
504    /// Returns true if no bits are set.
505    pub fn is_zero(&self) -> bool {
506        self.bytes.iter().all(|byte| *byte == 0)
507    }
508
509    /// Returns the number of bits that are set to `true`.
510    pub fn num_set_bits(&self) -> usize {
511        self.bytes
512            .iter()
513            .map(|byte| byte.count_ones() as usize)
514            .sum()
515    }
516
517    /// Compute the difference of this Bitfield and another of potentially different length.
518    pub fn difference(&self, other: &Self) -> Self {
519        let mut result = self.clone();
520        result.difference_inplace(other);
521        result
522    }
523
524    /// Compute the difference of this Bitfield and another of potentially different length.
525    pub fn difference_inplace(&mut self, other: &Self) {
526        let min_byte_len = std::cmp::min(self.bytes.len(), other.bytes.len());
527
528        for i in 0..min_byte_len {
529            self.bytes[i] &= !other.bytes[i];
530        }
531    }
532
533    /// Perform a bitwise-not operation on the bits in `self`. Creates a new Bitfield.
534    pub fn not(&self) -> Self {
535        let mut result = self.clone();
536        result.not_inplace();
537        result
538    }
539
540    /// Perform a bitwise-not operation on the bits in `self`.
541    pub fn not_inplace(&mut self) {
542        for byte in self.bytes.iter_mut() {
543            *byte = !*byte;
544        }
545        // Mask out any bits higher than `self.len`.
546        if let Some(last_byte) = self.bytes.last_mut() {
547            *last_byte &= last_byte_mask(self.len);
548        }
549    }
550
551    /// Shift the bits to higher indices, filling the lower indices with zeroes.
552    ///
553    /// The amount to shift by, `n`, must be less than or equal to `self.len()`.
554    pub fn shift_up(&mut self, n: usize) -> Result<(), Error> {
555        if n <= self.len() {
556            // Shift the bits up (starting from the high indices to avoid overwriting)
557            for i in (n..self.len()).rev() {
558                self.set(i, self.get(i - n)?)?;
559            }
560            // Zero the low bits
561            for i in 0..n {
562                self.set(i, false).unwrap();
563            }
564            Ok(())
565        } else {
566            Err(Error::OutOfBounds {
567                i: n,
568                len: self.len(),
569            })
570        }
571    }
572}
573
/// Computes the AND-mask for the final byte of a bitfield holding `len` bits.
///
/// The low `len % 8` bits are kept, or the whole byte when `len` is a non-zero multiple of 8.
/// A zero-length bitfield yields `0`, since its single byte must always be zero.
fn last_byte_mask(len: usize) -> u8 {
    match (len, len % 8) {
        (0, _) => 0,
        (_, 0) => u8::MAX,
        (_, used_bits) => (1u8 << used_bits) - 1,
    }
}
585
586impl<T> Eq for Bitfield<T> {}
587impl<T> PartialEq for Bitfield<T> {
588    #[inline]
589    fn eq(&self, other: &Bitfield<T>) -> bool {
590        self.len == other.len && self.bytes == other.bytes
591    }
592}
593
594impl<T> core::hash::Hash for Bitfield<T> {
595    #[inline]
596    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
597        core::hash::Hash::hash(&self.bytes, state);
598        core::hash::Hash::hash(&self.len, state);
599    }
600}
601
/// Number of bytes needed to hold `bit_len` bits, rounding up to a whole byte.
///
/// Even an empty (`bit_len == 0`) bitfield occupies a single byte.
fn bytes_for_bit_len(bit_len: usize) -> usize {
    match bit_len.div_ceil(8) {
        0 => 1,
        whole_bytes => whole_bytes,
    }
}
608
609/// An iterator over the bits in a `Bitfield`.
610pub struct BitIter<'a, T> {
611    bitfield: &'a Bitfield<T>,
612    i: usize,
613}
614
615impl<T: BitfieldBehaviour> Iterator for BitIter<'_, T> {
616    type Item = bool;
617
618    fn next(&mut self) -> Option<Self::Item> {
619        let res = self.bitfield.get(self.i).ok()?;
620        self.i += 1;
621        Some(res)
622    }
623}
624
625impl<N: Unsigned + Clone> Encode for Bitfield<Variable<N>> {
626    fn is_ssz_fixed_len() -> bool {
627        false
628    }
629
630    fn ssz_bytes_len(&self) -> usize {
631        // We could likely do better than turning this into bytes and reading the length, however
632        // it is kept this way for simplicity.
633        self.clone().into_bytes().len()
634    }
635
636    fn ssz_append(&self, buf: &mut Vec<u8>) {
637        buf.extend_from_slice(&self.clone().into_bytes())
638    }
639}
640
641impl<N: Unsigned + Clone> Decode for Bitfield<Variable<N>> {
642    fn is_ssz_fixed_len() -> bool {
643        false
644    }
645
646    fn from_ssz_bytes(bytes: &[u8]) -> Result<Self, DecodeError> {
647        Self::from_bytes(bytes.to_smallvec())
648            .map_err(|e| DecodeError::BytesInvalid(format!("BitList failed to decode: {:?}", e)))
649    }
650}
651
652impl<N: Unsigned + Clone> Encode for Bitfield<Fixed<N>> {
653    fn is_ssz_fixed_len() -> bool {
654        true
655    }
656
657    fn ssz_bytes_len(&self) -> usize {
658        self.as_slice().len()
659    }
660
661    fn ssz_fixed_len() -> usize {
662        bytes_for_bit_len(N::to_usize())
663    }
664
665    fn ssz_append(&self, buf: &mut Vec<u8>) {
666        buf.extend_from_slice(&self.clone().into_bytes())
667    }
668}
669
670impl<N: Unsigned + Clone> Decode for Bitfield<Fixed<N>> {
671    fn is_ssz_fixed_len() -> bool {
672        true
673    }
674
675    fn ssz_fixed_len() -> usize {
676        bytes_for_bit_len(N::to_usize())
677    }
678
679    fn from_ssz_bytes(bytes: &[u8]) -> Result<Self, DecodeError> {
680        Self::from_bytes(bytes.to_smallvec())
681            .map_err(|e| DecodeError::BytesInvalid(format!("BitVector failed to decode: {:?}", e)))
682    }
683}
684
685impl<N: Unsigned + Clone> Serialize for Bitfield<Variable<N>> {
686    /// Serde serialization is compliant with the Ethereum YAML test format.
687    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
688    where
689        S: Serializer,
690    {
691        serializer.serialize_str(&hex_encode(self.as_ssz_bytes()))
692    }
693}
694
695impl<'de, N: Unsigned + Clone> Deserialize<'de> for Bitfield<Variable<N>> {
696    /// Serde serialization is compliant with the Ethereum YAML test format.
697    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
698    where
699        D: Deserializer<'de>,
700    {
701        let bytes = deserializer.deserialize_str(PrefixedHexVisitor)?;
702        Self::from_ssz_bytes(&bytes)
703            .map_err(|e| serde::de::Error::custom(format!("Bitfield {:?}", e)))
704    }
705}
706
707impl<N: Unsigned + Clone> Serialize for Bitfield<Fixed<N>> {
708    /// Serde serialization is compliant with the Ethereum YAML test format.
709    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
710    where
711        S: Serializer,
712    {
713        serializer.serialize_str(&hex_encode(self.as_ssz_bytes()))
714    }
715}
716
717impl<'de, N: Unsigned + Clone> Deserialize<'de> for Bitfield<Fixed<N>> {
718    /// Serde serialization is compliant with the Ethereum YAML test format.
719    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
720    where
721        D: Deserializer<'de>,
722    {
723        let bytes = deserializer.deserialize_str(PrefixedHexVisitor)?;
724        Self::from_ssz_bytes(&bytes)
725            .map_err(|e| serde::de::Error::custom(format!("Bitfield {:?}", e)))
726    }
727}
728
#[cfg(feature = "arbitrary")]
impl<N: 'static + Unsigned> arbitrary::Arbitrary<'_> for Bitfield<Fixed<N>> {
    /// Fills all `N` bits with arbitrary data, clearing the padding bits of the final byte so
    /// the bytes always pass `from_bytes` validation.
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        let bit_len = N::to_usize();
        let mut bytes = smallvec![0u8; bytes_for_bit_len(bit_len)];
        u.fill_buffer(&mut bytes)?;
        if let Some(tail) = bytes.last_mut() {
            *tail &= last_byte_mask(bit_len);
        }
        Self::from_bytes(bytes).map_err(|_| arbitrary::Error::IncorrectFormat)
    }
}
743
#[cfg(feature = "arbitrary")]
impl<N: 'static + Unsigned> arbitrary::Arbitrary<'_> for Bitfield<Variable<N>> {
    /// Generates a `BitList` with an arbitrary length in `0..=N` and arbitrary data bits.
    ///
    /// The SSZ encoding is built directly: `len` data bits, then a single length bit at index
    /// `len` with every bit above it cleared. The result is decoded with `from_bytes`.
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        let max_len = N::to_usize();
        // Every length up to and including `N` is valid, including the empty list. This also
        // makes `BitList<U0>` constructible.
        let len = u.int_in_range(0..=max_len)?;
        // The encoding requires len data bits + 1 length bit.
        let num_bytes = bytes_for_bit_len(len + 1);
        let mut vec = smallvec![0u8; num_bytes];
        u.fill_buffer(&mut vec)?;
        // The length bit lives in the final byte; earlier bytes hold only data bits.
        let length_bit_byte = len / 8;
        let length_bit_pos = len % 8;
        // Keep only the data bits strictly below the length bit in that byte.
        // `last_byte_mask(len)` cannot be used here: it returns `0xFF` when `len` is a non-zero
        // multiple of 8, which would leave random bits above the length bit set.
        vec[length_bit_byte] &= (1u8 << length_bit_pos) - 1;
        // Set the length bit.
        vec[length_bit_byte] |= 1u8 << length_bit_pos;
        Self::from_bytes(vec).map_err(|_| arbitrary::Error::IncorrectFormat)
    }
}
769
770#[cfg(test)]
771mod bitvector {
772    use super::*;
773    use crate::BitVector;
774
775    pub type BitVector0 = BitVector<typenum::U0>;
776    pub type BitVector1 = BitVector<typenum::U1>;
777    pub type BitVector4 = BitVector<typenum::U4>;
778    pub type BitVector8 = BitVector<typenum::U8>;
779    pub type BitVector16 = BitVector<typenum::U16>;
780    pub type BitVector64 = BitVector<typenum::U64>;
781
782    #[test]
783    fn ssz_encode() {
784        assert_eq!(BitVector0::new().as_ssz_bytes(), vec![0b0000_0000]);
785        assert_eq!(BitVector1::new().as_ssz_bytes(), vec![0b0000_0000]);
786        assert_eq!(BitVector4::new().as_ssz_bytes(), vec![0b0000_0000]);
787        assert_eq!(BitVector8::new().as_ssz_bytes(), vec![0b0000_0000]);
788        assert_eq!(
789            BitVector16::new().as_ssz_bytes(),
790            vec![0b0000_0000, 0b0000_0000]
791        );
792
793        let mut b = BitVector8::new();
794        for i in 0..8 {
795            b.set(i, true).unwrap();
796        }
797        assert_eq!(b.as_ssz_bytes(), vec![255]);
798
799        let mut b = BitVector4::new();
800        for i in 0..4 {
801            b.set(i, true).unwrap();
802        }
803        assert_eq!(b.as_ssz_bytes(), vec![0b0000_1111]);
804    }
805
806    #[test]
807    fn ssz_decode() {
808        assert!(BitVector0::from_ssz_bytes(&[0b0000_0000]).is_ok());
809        assert!(BitVector0::from_ssz_bytes(&[0b0000_0001]).is_err());
810        assert!(BitVector0::from_ssz_bytes(&[0b0000_0010]).is_err());
811
812        assert!(BitVector1::from_ssz_bytes(&[0b0000_0001]).is_ok());
813        assert!(BitVector1::from_ssz_bytes(&[0b0000_0010]).is_err());
814        assert!(BitVector1::from_ssz_bytes(&[0b0000_0100]).is_err());
815        assert!(BitVector1::from_ssz_bytes(&[0b0000_0000, 0b0000_0000]).is_err());
816
817        assert!(BitVector8::from_ssz_bytes(&[0b0000_0000]).is_ok());
818        assert!(BitVector8::from_ssz_bytes(&[1, 0b0000_0000]).is_err());
819        assert!(BitVector8::from_ssz_bytes(&[0b0000_0000, 1]).is_err());
820        assert!(BitVector8::from_ssz_bytes(&[0b0000_0001]).is_ok());
821        assert!(BitVector8::from_ssz_bytes(&[0b0000_0010]).is_ok());
822        assert!(BitVector8::from_ssz_bytes(&[0b0000_0100, 0b0000_0001]).is_err());
823        assert!(BitVector8::from_ssz_bytes(&[0b0000_0100, 0b0000_0010]).is_err());
824        assert!(BitVector8::from_ssz_bytes(&[0b0000_0100, 0b0000_0100]).is_err());
825
826        assert!(BitVector16::from_ssz_bytes(&[0b0000_0000]).is_err());
827        assert!(BitVector16::from_ssz_bytes(&[0b0000_0000, 0b0000_0000]).is_ok());
828        assert!(BitVector16::from_ssz_bytes(&[1, 0b0000_0000, 0b0000_0000]).is_err());
829    }
830
831    #[test]
832    fn intersection() {
833        let a = BitVector16::from_raw_bytes(smallvec![0b1100, 0b0001], 16).unwrap();
834        let b = BitVector16::from_raw_bytes(smallvec![0b1011, 0b1001], 16).unwrap();
835        let c = BitVector16::from_raw_bytes(smallvec![0b1000, 0b0001], 16).unwrap();
836
837        assert_eq!(a.intersection(&b), c);
838        assert_eq!(b.intersection(&a), c);
839        assert_eq!(a.intersection(&c), c);
840        assert_eq!(b.intersection(&c), c);
841        assert_eq!(a.intersection(&a), a);
842        assert_eq!(b.intersection(&b), b);
843        assert_eq!(c.intersection(&c), c);
844    }
845
    /// Intersection of two vectors with set bits spread over both bytes.
    ///
    /// Note: `BitVector16` has a fixed length, so despite the test's name every
    /// operand (and the result) is 16 bits long, as the `len()` assertions below
    /// confirm. Only the bit patterns differ here, not the lengths.
    #[test]
    fn intersection_diff_length() {
        let a = BitVector16::from_bytes(smallvec![0b0010_1110, 0b0010_1011]).unwrap();
        let b = BitVector16::from_bytes(smallvec![0b0010_1101, 0b0000_0001]).unwrap();
        // Bitwise AND of `a` and `b`, byte by byte.
        let c = BitVector16::from_bytes(smallvec![0b0010_1100, 0b0000_0001]).unwrap();

        assert_eq!(a.len(), 16);
        assert_eq!(b.len(), 16);
        assert_eq!(c.len(), 16);
        assert_eq!(a.intersection(&b), c);
        assert_eq!(b.intersection(&a), c);
    }
858
859    #[test]
860    fn subset() {
861        let a = BitVector16::from_raw_bytes(smallvec![0b1000, 0b0001], 16).unwrap();
862        let b = BitVector16::from_raw_bytes(smallvec![0b1100, 0b0001], 16).unwrap();
863        let c = BitVector16::from_raw_bytes(smallvec![0b1100, 0b1001], 16).unwrap();
864
865        assert_eq!(a.len(), 16);
866        assert_eq!(b.len(), 16);
867        assert_eq!(c.len(), 16);
868
869        // a vector is always a subset of itself
870        assert!(a.is_subset(&a));
871        assert!(b.is_subset(&b));
872        assert!(c.is_subset(&c));
873
874        assert!(a.is_subset(&b));
875        assert!(a.is_subset(&c));
876        assert!(b.is_subset(&c));
877
878        assert!(!b.is_subset(&a));
879        assert!(!c.is_subset(&a));
880        assert!(!c.is_subset(&b));
881    }
882
883    #[test]
884    fn union() {
885        let a = BitVector16::from_raw_bytes(smallvec![0b1100, 0b0001], 16).unwrap();
886        let b = BitVector16::from_raw_bytes(smallvec![0b1011, 0b1001], 16).unwrap();
887        let c = BitVector16::from_raw_bytes(smallvec![0b1111, 0b1001], 16).unwrap();
888
889        assert_eq!(a.union(&b), c);
890        assert_eq!(b.union(&a), c);
891        assert_eq!(a.union(&a), a);
892        assert_eq!(b.union(&b), b);
893        assert_eq!(c.union(&c), c);
894    }
895
896    #[test]
897    fn union_diff_length() {
898        let a = BitVector16::from_bytes(smallvec![0b0010_1011, 0b0010_1110]).unwrap();
899        let b = BitVector16::from_bytes(smallvec![0b0000_0001, 0b0010_1101]).unwrap();
900        let c = BitVector16::from_bytes(smallvec![0b0010_1011, 0b0010_1111]).unwrap();
901
902        assert_eq!(a.len(), c.len());
903        assert_eq!(a.union(&b), c);
904        assert_eq!(b.union(&a), c);
905    }
906
907    #[test]
908    fn ssz_round_trip() {
909        assert_round_trip(BitVector0::new());
910
911        let mut b = BitVector1::new();
912        b.set(0, true).unwrap();
913        assert_round_trip(b);
914
915        let mut b = BitVector8::new();
916        for j in 0..8 {
917            if j % 2 == 0 {
918                b.set(j, true).unwrap();
919            }
920        }
921        assert_round_trip(b);
922
923        let mut b = BitVector8::new();
924        for j in 0..8 {
925            b.set(j, true).unwrap();
926        }
927        assert_round_trip(b);
928
929        let mut b = BitVector16::new();
930        for j in 0..16 {
931            if j % 2 == 0 {
932                b.set(j, true).unwrap();
933            }
934        }
935        assert_round_trip(b);
936
937        let mut b = BitVector16::new();
938        for j in 0..16 {
939            b.set(j, true).unwrap();
940        }
941        assert_round_trip(b);
942    }
943
944    fn assert_round_trip<T: Encode + Decode + PartialEq + std::fmt::Debug>(t: T) {
945        assert_eq!(T::from_ssz_bytes(&t.as_ssz_bytes()).unwrap(), t);
946    }
947
948    #[test]
949    fn ssz_bytes_len() {
950        for i in 0..64 {
951            let mut bitfield = BitVector64::new();
952            for j in 0..i {
953                bitfield.set(j, true).expect("should set bit in bounds");
954            }
955            let bytes = bitfield.as_ssz_bytes();
956            assert_eq!(bitfield.ssz_bytes_len(), bytes.len(), "i = {}", i);
957        }
958    }
959
    /// Decoding must reject a `BitVector` whose padding bits are set.
    ///
    /// A `BitVector4` is stored in a single byte, and only its low 4 bits belong
    /// to the vector. `0b0001_1111` also sets bit 4, which lies beyond the
    /// vector's length, so `from_ssz_bytes` must return an error.
    /// NOTE(review): the name suggests this case came from Nimbus; confirm
    /// provenance.
    #[test]
    fn excess_bits_nimbus() {
        let bad = vec![0b0001_1111];

        assert!(BitVector4::from_ssz_bytes(&bad).is_err());
    }
966
    // Ensure that stack size of a BitVector is manageable.
    //
    // The bitfield keeps up to `SMALLVEC_LEN` bytes inline in a `SmallVec`; the
    // extra 24 bytes are fixed overhead. Presumably that is the SmallVec's
    // bookkeeping plus the stored bit length on a 64-bit target.
    // NOTE(review): confirm the breakdown. A failure here means the struct layout
    // or `SMALLVEC_LEN` changed.
    #[test]
    fn size_of() {
        assert_eq!(std::mem::size_of::<BitVector64>(), SMALLVEC_LEN + 24);
    }
972
973    #[test]
974    fn display() {
975        let bitvec = BitVector16::from_bytes(smallvec![0b0010_1011, 0b0010_1110]).unwrap();
976        assert_eq!("1101010001110100", bitvec.to_string());
977    }
978
979    #[test]
980    fn not() {
981        // Test empty
982        let empty = BitVector0::new();
983        assert_eq!(empty.not(), empty);
984
985        // Test with all zeros
986        let a = BitVector8::new();
987        let mut expected = BitVector8::new();
988        for i in 0..8 {
989            expected.set(i, true).unwrap();
990        }
991        assert_eq!(a.not(), expected);
992
993        // Test with all ones
994        let b = expected.clone();
995        assert_eq!(b.not(), BitVector8::new());
996
997        // Test with mixed pattern
998        let c = BitVector16::from_raw_bytes(smallvec![0b1100_1010, 0b0011_0101], 16).unwrap();
999        let expected_c =
1000            BitVector16::from_raw_bytes(smallvec![0b0011_0101, 0b1100_1010], 16).unwrap();
1001        assert_eq!(c.not(), expected_c);
1002
1003        // Test with partial byte (4 bits)
1004        let d = BitVector4::from_raw_bytes(smallvec![0b0000_1010], 4).unwrap();
1005        let expected_d = BitVector4::from_raw_bytes(smallvec![0b0000_0101], 4).unwrap();
1006        assert_eq!(d.not(), expected_d);
1007
1008        // Test that masking works correctly for partial bytes
1009        let e = BitVector4::from_raw_bytes(smallvec![0b0000_1111], 4).unwrap();
1010        let expected_e = BitVector4::from_raw_bytes(smallvec![0b0000_0000], 4).unwrap();
1011        assert_eq!(e.not(), expected_e);
1012    }
1013
1014    #[test]
1015    fn not_inplace() {
1016        // Test with all zeros
1017        let mut a = BitVector8::new();
1018        a.not_inplace();
1019        let mut expected = BitVector8::new();
1020        for i in 0..8 {
1021            expected.set(i, true).unwrap();
1022        }
1023        assert_eq!(a, expected);
1024
1025        // Test with all ones
1026        let mut b = expected.clone();
1027        b.not_inplace();
1028        assert_eq!(b, BitVector8::new());
1029
1030        // Test with mixed pattern
1031        let mut c = BitVector16::from_raw_bytes(smallvec![0b1100_1010, 0b0011_0101], 16).unwrap();
1032        c.not_inplace();
1033        let expected_c =
1034            BitVector16::from_raw_bytes(smallvec![0b0011_0101, 0b1100_1010], 16).unwrap();
1035        assert_eq!(c, expected_c);
1036
1037        // Test with partial byte (4 bits)
1038        let mut d = BitVector4::from_raw_bytes(smallvec![0b0000_1010], 4).unwrap();
1039        d.not_inplace();
1040        let expected_d = BitVector4::from_raw_bytes(smallvec![0b0000_0101], 4).unwrap();
1041        assert_eq!(d, expected_d);
1042    }
1043}
1044
// Unit tests for the variable-length `BitList` type. Besides bit access and set
// operations, they pin down the SSZ encoding, which appends one sentinel bit at
// index `len` (e.g. an empty list encodes as `[0b0000_0001]`).
#[cfg(test)]
#[allow(clippy::cognitive_complexity)]
mod bitlist {
    use super::*;
    use crate::BitList;

    pub type BitList0 = BitList<typenum::U0>;
    pub type BitList1 = BitList<typenum::U1>;
    pub type BitList8 = BitList<typenum::U8>;
    pub type BitList16 = BitList<typenum::U16>;
    pub type BitList1024 = BitList<typenum::U1024>;

    /// The encoding is the list's bits followed by a single set sentinel bit at
    /// index `len`, padded to whole bytes.
    #[test]
    fn ssz_encode() {
        assert_eq!(
            BitList0::with_capacity(0).unwrap().as_ssz_bytes(),
            vec![0b0000_0001],
        );

        assert_eq!(
            BitList1::with_capacity(0).unwrap().as_ssz_bytes(),
            vec![0b0000_0001],
        );

        assert_eq!(
            BitList1::with_capacity(1).unwrap().as_ssz_bytes(),
            vec![0b0000_0010],
        );

        assert_eq!(
            BitList8::with_capacity(8).unwrap().as_ssz_bytes(),
            vec![0b0000_0000, 0b0000_0001],
        );

        assert_eq!(
            BitList8::with_capacity(7).unwrap().as_ssz_bytes(),
            vec![0b1000_0000]
        );

        let mut b = BitList8::with_capacity(8).unwrap();
        for i in 0..8 {
            b.set(i, true).unwrap();
        }
        assert_eq!(b.as_ssz_bytes(), vec![255, 0b0000_0001]);

        let mut b = BitList8::with_capacity(8).unwrap();
        for i in 0..4 {
            b.set(i, true).unwrap();
        }
        assert_eq!(b.as_ssz_bytes(), vec![0b0000_1111, 0b0000_0001]);

        assert_eq!(
            BitList16::with_capacity(16).unwrap().as_ssz_bytes(),
            vec![0b0000_0000, 0b0000_0000, 0b0000_0001]
        );
    }

    /// Decoding rejects empty input, input without a sentinel bit, and input
    /// whose sentinel implies a length above the type's maximum.
    #[test]
    fn ssz_decode() {
        assert!(BitList0::from_ssz_bytes(&[]).is_err());
        assert!(BitList1::from_ssz_bytes(&[]).is_err());
        assert!(BitList8::from_ssz_bytes(&[]).is_err());
        assert!(BitList16::from_ssz_bytes(&[]).is_err());

        assert!(BitList0::from_ssz_bytes(&[0b0000_0000]).is_err());
        assert!(BitList1::from_ssz_bytes(&[0b0000_0000, 0b0000_0000]).is_err());
        assert!(BitList8::from_ssz_bytes(&[0b0000_0000]).is_err());
        assert!(BitList16::from_ssz_bytes(&[0b0000_0000]).is_err());

        assert!(BitList0::from_ssz_bytes(&[0b0000_0001]).is_ok());
        assert!(BitList0::from_ssz_bytes(&[0b0000_0010]).is_err());

        assert!(BitList1::from_ssz_bytes(&[0b0000_0001]).is_ok());
        assert!(BitList1::from_ssz_bytes(&[0b0000_0010]).is_ok());
        assert!(BitList1::from_ssz_bytes(&[0b0000_0100]).is_err());

        assert!(BitList8::from_ssz_bytes(&[0b0000_0001]).is_ok());
        assert!(BitList8::from_ssz_bytes(&[0b0000_0010]).is_ok());
        assert!(BitList8::from_ssz_bytes(&[0b0000_0001, 0b0000_0001]).is_ok());
        assert!(BitList8::from_ssz_bytes(&[0b0000_0001, 0b0000_0010]).is_err());
        assert!(BitList8::from_ssz_bytes(&[0b0000_0001, 0b0000_0100]).is_err());
    }

    /// Trailing zero bytes after the byte holding the sentinel bit are rejected.
    #[test]
    fn ssz_decode_extra_bytes() {
        assert!(BitList0::from_ssz_bytes(&[0b0000_0001, 0b0000_0000]).is_err());
        assert!(BitList1::from_ssz_bytes(&[0b0000_0001, 0b0000_0000]).is_err());
        assert!(BitList8::from_ssz_bytes(&[0b0000_0001, 0b0000_0000]).is_err());
        assert!(BitList16::from_ssz_bytes(&[0b0000_0001, 0b0000_0000]).is_err());
        assert!(BitList1024::from_ssz_bytes(&[0b1000_0000, 0]).is_err());
        assert!(BitList1024::from_ssz_bytes(&[0b1000_0000, 0, 0]).is_err());
        assert!(BitList1024::from_ssz_bytes(&[0b1000_0000, 0, 0, 0, 0]).is_err());
    }

    /// Every length up to and including each type's maximum must survive an SSZ
    /// round trip, for empty, alternating and fully-set patterns.
    #[test]
    fn ssz_round_trip() {
        assert_round_trip(BitList0::with_capacity(0).unwrap());

        for i in 0..=1 {
            assert_round_trip(BitList1::with_capacity(i).unwrap());
        }
        for i in 0..=8 {
            assert_round_trip(BitList8::with_capacity(i).unwrap());
        }
        for i in 0..=16 {
            assert_round_trip(BitList16::with_capacity(i).unwrap());
        }

        let mut b = BitList1::with_capacity(1).unwrap();
        b.set(0, true).unwrap();
        assert_round_trip(b);

        // Inclusive bound so the full-capacity list is covered too.
        for i in 0..=8 {
            let mut b = BitList8::with_capacity(i).unwrap();
            for j in 0..i {
                if j % 2 == 0 {
                    b.set(j, true).unwrap();
                }
            }
            assert_round_trip(b);

            let mut b = BitList8::with_capacity(i).unwrap();
            for j in 0..i {
                b.set(j, true).unwrap();
            }
            assert_round_trip(b);
        }

        // Inclusive bound so the full-capacity list is covered too.
        for i in 0..=16 {
            let mut b = BitList16::with_capacity(i).unwrap();
            for j in 0..i {
                if j % 2 == 0 {
                    b.set(j, true).unwrap();
                }
            }
            assert_round_trip(b);

            let mut b = BitList16::with_capacity(i).unwrap();
            for j in 0..i {
                b.set(j, true).unwrap();
            }
            assert_round_trip(b);
        }
    }

    /// Encodes `t` to SSZ, decodes it back as `T`, and asserts equality.
    fn assert_round_trip<T: Encode + Decode + PartialEq + std::fmt::Debug>(t: T) {
        assert_eq!(T::from_ssz_bytes(&t.as_ssz_bytes()).unwrap(), t);
    }

    /// `from_raw_bytes` accepts exactly-sized byte buffers whose set bits all lie
    /// below `len`. It rejects a wrong byte count or any bit set at or past `len`.
    #[test]
    fn from_raw_bytes() {
        // Valid: the pattern sets exactly bits `0..len`.
        assert!(BitList1024::from_raw_bytes(smallvec![0b0000_0000], 0).is_ok());
        assert!(BitList1024::from_raw_bytes(smallvec![0b0000_0001], 1).is_ok());
        assert!(BitList1024::from_raw_bytes(smallvec![0b0000_0011], 2).is_ok());
        assert!(BitList1024::from_raw_bytes(smallvec![0b0000_0111], 3).is_ok());
        assert!(BitList1024::from_raw_bytes(smallvec![0b0000_1111], 4).is_ok());
        assert!(BitList1024::from_raw_bytes(smallvec![0b0001_1111], 5).is_ok());
        assert!(BitList1024::from_raw_bytes(smallvec![0b0011_1111], 6).is_ok());
        assert!(BitList1024::from_raw_bytes(smallvec![0b0111_1111], 7).is_ok());
        assert!(BitList1024::from_raw_bytes(smallvec![0b1111_1111], 8).is_ok());

        assert!(BitList1024::from_raw_bytes(smallvec![0b1111_1111, 0b0000_0001], 9).is_ok());
        assert!(BitList1024::from_raw_bytes(smallvec![0b1111_1111, 0b0000_0011], 10).is_ok());
        assert!(BitList1024::from_raw_bytes(smallvec![0b1111_1111, 0b0000_0111], 11).is_ok());
        assert!(BitList1024::from_raw_bytes(smallvec![0b1111_1111, 0b0000_1111], 12).is_ok());
        assert!(BitList1024::from_raw_bytes(smallvec![0b1111_1111, 0b0001_1111], 13).is_ok());
        assert!(BitList1024::from_raw_bytes(smallvec![0b1111_1111, 0b0011_1111], 14).is_ok());
        assert!(BitList1024::from_raw_bytes(smallvec![0b1111_1111, 0b0111_1111], 15).is_ok());
        assert!(BitList1024::from_raw_bytes(smallvec![0b1111_1111, 0b1111_1111], 16).is_ok());

        // Invalid: wrong number of bytes, or bits set beyond `len`.
        for i in 0..8 {
            assert!(BitList1024::from_raw_bytes(smallvec![], i).is_err());
            assert!(BitList1024::from_raw_bytes(smallvec![0b1111_1111], i).is_err());
            assert!(BitList1024::from_raw_bytes(smallvec![0b0000_0000, 0b1111_1110], i).is_err());
        }

        // Invalid: the highest set bit sits exactly at index `len`.
        assert!(BitList1024::from_raw_bytes(smallvec![0b0000_0001], 0).is_err());
        assert!(BitList1024::from_raw_bytes(smallvec![0b0000_0011], 1).is_err());
        assert!(BitList1024::from_raw_bytes(smallvec![0b0000_0111], 2).is_err());
        assert!(BitList1024::from_raw_bytes(smallvec![0b0000_1111], 3).is_err());
        assert!(BitList1024::from_raw_bytes(smallvec![0b0001_1111], 4).is_err());
        assert!(BitList1024::from_raw_bytes(smallvec![0b0011_1111], 5).is_err());
        assert!(BitList1024::from_raw_bytes(smallvec![0b0111_1111], 6).is_err());
        assert!(BitList1024::from_raw_bytes(smallvec![0b1111_1111], 7).is_err());

        assert!(BitList1024::from_raw_bytes(smallvec![0b1111_1111, 0b0000_0001], 8).is_err());
        assert!(BitList1024::from_raw_bytes(smallvec![0b1111_1111, 0b0000_0011], 9).is_err());
        assert!(BitList1024::from_raw_bytes(smallvec![0b1111_1111, 0b0000_0111], 10).is_err());
        assert!(BitList1024::from_raw_bytes(smallvec![0b1111_1111, 0b0000_1111], 11).is_err());
        assert!(BitList1024::from_raw_bytes(smallvec![0b1111_1111, 0b0001_1111], 12).is_err());
        assert!(BitList1024::from_raw_bytes(smallvec![0b1111_1111, 0b0011_1111], 13).is_err());
        assert!(BitList1024::from_raw_bytes(smallvec![0b1111_1111, 0b0111_1111], 14).is_err());
        assert!(BitList1024::from_raw_bytes(smallvec![0b1111_1111, 0b1111_1111], 15).is_err());
    }

    /// Checks `get`/`set` on a list of `num_bits` bits. Every in-range index
    /// starts false and can be toggled. The first out-of-range index
    /// (`num_bits`) errors for both `get` and `set`.
    fn test_set_unset(num_bits: usize) {
        let mut bitfield = BitList1024::with_capacity(num_bits).unwrap();

        for i in 0..=num_bits {
            if i < num_bits {
                // Starts as false
                assert_eq!(bitfield.get(i), Ok(false));
                // Can be set true.
                assert!(bitfield.set(i, true).is_ok());
                assert_eq!(bitfield.get(i), Ok(true));
                // Can be set false
                assert!(bitfield.set(i, false).is_ok());
                assert_eq!(bitfield.get(i), Ok(false));
            } else {
                assert!(bitfield.get(i).is_err());
                assert!(bitfield.set(i, true).is_err());
                assert!(bitfield.get(i).is_err());
            }
        }
    }

    /// For each single-bit pattern in a list of `num_bits` bits, converting to
    /// raw bytes and back with `from_raw_bytes` yields the same list.
    fn test_bytes_round_trip(num_bits: usize) {
        for i in 0..num_bits {
            let mut bitfield = BitList1024::with_capacity(num_bits).unwrap();
            bitfield.set(i, true).unwrap();

            let bytes = bitfield.clone().into_raw_bytes();
            assert_eq!(bitfield, Bitfield::from_raw_bytes(bytes, num_bits).unwrap());
        }
    }

    #[test]
    fn set_unset() {
        for i in 0..8 * 5 {
            test_set_unset(i)
        }
    }

    #[test]
    fn bytes_round_trip() {
        for i in 0..8 * 5 {
            test_bytes_round_trip(i)
        }
    }

    /// Type-specialised `smallvec` macro for testing: always yields a
    /// `SmallVec<[u8; SMALLVEC_LEN]>` so it compares directly with
    /// `into_raw_bytes` output. Accepts at most one trailing comma.
    macro_rules! bytevec {
        ($($x : expr),* $(,)?) => {
            {
                let __smallvec: SmallVec<[u8; SMALLVEC_LEN]> = smallvec!($($x),*);
                __smallvec
            }
        };
    }

    /// `into_raw_bytes` returns only the list's bits (no sentinel bit), padded to
    /// whole bytes.
    #[test]
    fn into_raw_bytes() {
        let mut bitfield = BitList1024::with_capacity(9).unwrap();
        bitfield.set(0, true).unwrap();
        assert_eq!(
            bitfield.clone().into_raw_bytes(),
            bytevec![0b0000_0001, 0b0000_0000]
        );
        bitfield.set(1, true).unwrap();
        assert_eq!(
            bitfield.clone().into_raw_bytes(),
            bytevec![0b0000_0011, 0b0000_0000]
        );
        bitfield.set(2, true).unwrap();
        assert_eq!(
            bitfield.clone().into_raw_bytes(),
            bytevec![0b0000_0111, 0b0000_0000]
        );
        bitfield.set(3, true).unwrap();
        assert_eq!(
            bitfield.clone().into_raw_bytes(),
            bytevec![0b0000_1111, 0b0000_0000]
        );
        bitfield.set(4, true).unwrap();
        assert_eq!(
            bitfield.clone().into_raw_bytes(),
            bytevec![0b0001_1111, 0b0000_0000]
        );
        bitfield.set(5, true).unwrap();
        assert_eq!(
            bitfield.clone().into_raw_bytes(),
            bytevec![0b0011_1111, 0b0000_0000]
        );
        bitfield.set(6, true).unwrap();
        assert_eq!(
            bitfield.clone().into_raw_bytes(),
            bytevec![0b0111_1111, 0b0000_0000]
        );
        bitfield.set(7, true).unwrap();
        assert_eq!(
            bitfield.clone().into_raw_bytes(),
            bytevec![0b1111_1111, 0b0000_0000]
        );
        bitfield.set(8, true).unwrap();
        assert_eq!(
            bitfield.into_raw_bytes(),
            bytevec![0b1111_1111, 0b0000_0001]
        );
    }

    /// `highest_set_bit` is `None` for an all-zero list and otherwise the index
    /// of the most significant set bit.
    #[test]
    fn highest_set_bit() {
        assert_eq!(
            BitList1024::with_capacity(16).unwrap().highest_set_bit(),
            None
        );

        assert_eq!(
            BitList1024::from_raw_bytes(smallvec![0b0000_0001, 0b0000_0000], 16)
                .unwrap()
                .highest_set_bit(),
            Some(0)
        );

        assert_eq!(
            BitList1024::from_raw_bytes(smallvec![0b0000_0010, 0b0000_0000], 16)
                .unwrap()
                .highest_set_bit(),
            Some(1)
        );

        assert_eq!(
            BitList1024::from_raw_bytes(smallvec![0b0000_1000], 8)
                .unwrap()
                .highest_set_bit(),
            Some(3)
        );

        assert_eq!(
            BitList1024::from_raw_bytes(smallvec![0b0000_0000, 0b1000_0000], 16)
                .unwrap()
                .highest_set_bit(),
            Some(15)
        );
    }

    #[test]
    fn intersection() {
        let a = BitList1024::from_raw_bytes(smallvec![0b1100, 0b0001], 16).unwrap();
        let b = BitList1024::from_raw_bytes(smallvec![0b1011, 0b1001], 16).unwrap();
        let c = BitList1024::from_raw_bytes(smallvec![0b1000, 0b0001], 16).unwrap();

        assert_eq!(a.intersection(&b), c);
        assert_eq!(b.intersection(&a), c);
        assert_eq!(a.intersection(&c), c);
        assert_eq!(b.intersection(&c), c);
        assert_eq!(a.intersection(&a), a);
        assert_eq!(b.intersection(&b), b);
        assert_eq!(c.intersection(&c), c);
    }

    #[test]
    fn subset() {
        let a = BitList1024::from_raw_bytes(smallvec![0b1000, 0b0001], 16).unwrap();
        let b = BitList1024::from_raw_bytes(smallvec![0b1100, 0b0001], 16).unwrap();
        let c = BitList1024::from_raw_bytes(smallvec![0b1100, 0b1001], 16).unwrap();

        assert_eq!(a.len(), 16);
        assert_eq!(b.len(), 16);
        assert_eq!(c.len(), 16);

        // a vector is always a subset of itself
        assert!(a.is_subset(&a));
        assert!(b.is_subset(&b));
        assert!(c.is_subset(&c));

        assert!(a.is_subset(&b));
        assert!(a.is_subset(&c));
        assert!(b.is_subset(&c));

        assert!(!b.is_subset(&a));
        assert!(!c.is_subset(&a));
        assert!(!c.is_subset(&b));

        let d = BitList1024::from_raw_bytes(smallvec![0b1100, 0b1001, 0b1010], 24).unwrap();
        assert!(d.is_subset(&d));

        assert!(a.is_subset(&d));
        assert!(b.is_subset(&d));
        assert!(c.is_subset(&d));

        // A longer bitlist with set bits beyond a shorter list's length is not a
        // subset of that shorter list.
        assert!(!d.is_subset(&a));
        assert!(!d.is_subset(&b));
        assert!(!d.is_subset(&c));

        // If the extra high bits are all zero, the subset relation holds both ways.
        let e = BitList1024::from_raw_bytes(smallvec![0b1100, 0b1001, 0b0000], 24).unwrap();
        assert!(e.is_subset(&c));
        assert!(c.is_subset(&e));
    }

    /// Intersecting lists of different lengths yields a list with the shorter
    /// operand's length.
    #[test]
    fn intersection_diff_length() {
        let a = BitList1024::from_bytes(smallvec![0b0010_1110, 0b0010_1011]).unwrap();
        let b = BitList1024::from_bytes(smallvec![0b0010_1101, 0b0000_0001]).unwrap();
        let c = BitList1024::from_bytes(smallvec![0b0010_1100, 0b0000_0001]).unwrap();
        let d = BitList1024::from_bytes(smallvec![0b0010_1110, 0b1111_1111, 0b1111_1111]).unwrap();

        assert_eq!(a.len(), 13);
        assert_eq!(b.len(), 8);
        assert_eq!(c.len(), 8);
        assert_eq!(d.len(), 23);
        assert_eq!(a.intersection(&b), c);
        assert_eq!(b.intersection(&a), c);
        assert_eq!(a.intersection(&d), a);
        assert_eq!(d.intersection(&a), a);
    }

    #[test]
    fn union() {
        let a = BitList1024::from_raw_bytes(smallvec![0b1100, 0b0001], 16).unwrap();
        let b = BitList1024::from_raw_bytes(smallvec![0b1011, 0b1001], 16).unwrap();
        let c = BitList1024::from_raw_bytes(smallvec![0b1111, 0b1001], 16).unwrap();

        assert_eq!(a.union(&b), c);
        assert_eq!(b.union(&a), c);
        assert_eq!(a.union(&a), a);
        assert_eq!(b.union(&b), b);
        assert_eq!(c.union(&c), c);
    }

    /// Union of lists with different lengths yields a list with the longer
    /// operand's length.
    #[test]
    fn union_diff_length() {
        let a = BitList1024::from_bytes(smallvec![0b0010_1011, 0b0010_1110]).unwrap();
        let b = BitList1024::from_bytes(smallvec![0b0000_0001, 0b0010_1101]).unwrap();
        let c = BitList1024::from_bytes(smallvec![0b0010_1011, 0b0010_1111]).unwrap();
        let d = BitList1024::from_bytes(smallvec![0b0010_1011, 0b1011_1110, 0b1000_1101]).unwrap();

        assert_eq!(a.len(), c.len());
        assert_eq!(a.union(&b), c);
        assert_eq!(b.union(&a), c);
        assert_eq!(a.union(&d), d);
        assert_eq!(d.union(&a), d);
    }

    #[test]
    fn difference() {
        let a = BitList1024::from_raw_bytes(smallvec![0b1100, 0b0001], 16).unwrap();
        let b = BitList1024::from_raw_bytes(smallvec![0b1011, 0b1001], 16).unwrap();
        let a_b = BitList1024::from_raw_bytes(smallvec![0b0100, 0b0000], 16).unwrap();
        let b_a = BitList1024::from_raw_bytes(smallvec![0b0011, 0b1000], 16).unwrap();

        assert_eq!(a.difference(&b), a_b);
        assert_eq!(b.difference(&a), b_a);
        assert!(a.difference(&a).is_zero());
    }

    /// `a.difference(&b)` keeps `a`'s length; bits of `a` beyond `b`'s length
    /// are kept unchanged.
    #[test]
    fn difference_diff_length() {
        let a = BitList1024::from_raw_bytes(smallvec![0b0110, 0b1100, 0b0011], 24).unwrap();
        let b = BitList1024::from_raw_bytes(smallvec![0b1011, 0b1001], 16).unwrap();
        let a_b = BitList1024::from_raw_bytes(smallvec![0b0100, 0b0100, 0b0011], 24).unwrap();
        let b_a = BitList1024::from_raw_bytes(smallvec![0b1001, 0b0001], 16).unwrap();

        assert_eq!(a.difference(&b), a_b);
        assert_eq!(b.difference(&a), b_a);
    }

    /// `shift_up(n)` moves every bit to a higher index, dropping bits shifted past
    /// the length. Shifting by exactly `len` zeroes the list; shifting by more
    /// than `len` is an error.
    #[test]
    fn shift_up() {
        let mut a = BitList1024::from_raw_bytes(smallvec![0b1100_1111, 0b1101_0110], 16).unwrap();
        let mut b = BitList1024::from_raw_bytes(smallvec![0b1001_1110, 0b1010_1101], 16).unwrap();

        a.shift_up(1).unwrap();
        assert_eq!(a, b);
        a.shift_up(15).unwrap();
        assert!(a.is_zero());

        b.shift_up(16).unwrap();
        assert!(b.is_zero());
        assert!(b.shift_up(17).is_err());
    }

    #[test]
    fn num_set_bits() {
        let a = BitList1024::from_raw_bytes(smallvec![0b1100, 0b0001], 16).unwrap();
        let b = BitList1024::from_raw_bytes(smallvec![0b1011, 0b1001], 16).unwrap();

        assert_eq!(a.num_set_bits(), 3);
        assert_eq!(b.num_set_bits(), 5);
    }

    /// `iter` yields exactly `len` booleans in index order.
    #[test]
    fn iter() {
        let mut bitfield = BitList1024::with_capacity(9).unwrap();
        bitfield.set(2, true).unwrap();
        bitfield.set(8, true).unwrap();

        assert_eq!(
            bitfield.iter().collect::<Vec<bool>>(),
            vec![false, false, true, false, false, false, false, false, true]
        );
    }

    #[test]
    fn ssz_bytes_len() {
        for i in 1..64 {
            let mut bitfield = BitList1024::with_capacity(i).unwrap();
            for j in 0..i {
                bitfield.set(j, true).expect("should set bit in bounds");
            }
            let bytes = bitfield.as_ssz_bytes();
            assert_eq!(bitfield.ssz_bytes_len(), bytes.len(), "i = {}", i);
        }
    }

    // Ensure that the stack size of a BitList is manageable: `SMALLVEC_LEN` inline
    // bytes plus 24 bytes of fixed overhead.
    #[test]
    fn size_of() {
        assert_eq!(std::mem::size_of::<BitList1024>(), SMALLVEC_LEN + 24);
    }

    #[test]
    fn resize() {
        let mut bit_list = BitList1::with_capacity(1).unwrap();
        bit_list.set(0, true).unwrap();
        assert_eq!(bit_list.len(), 1);
        assert_eq!(bit_list.num_set_bits(), 1);
        assert_eq!(bit_list.highest_set_bit().unwrap(), 0);

        let resized_bit_list = bit_list.resize::<typenum::U1024>().unwrap();
        assert_eq!(resized_bit_list.len(), 1024);
        assert_eq!(resized_bit_list.num_set_bits(), 1);
        assert_eq!(resized_bit_list.highest_set_bit().unwrap(), 0);

        // Can't resize a BitList to a type with a smaller maximum length
        resized_bit_list.resize::<typenum::U16>().unwrap_err();
    }

    #[test]
    fn over_capacity_err() {
        let e = BitList8::with_capacity(9).expect_err("over-sized bit list");
        assert_eq!(e, Error::OutOfBounds { i: 9, len: 8 });
    }

    /// `clone_zeroed` keeps the length but clears every bit, including for the
    /// one-bit and zero-length cases.
    #[test]
    fn clone_zeroed() {
        let mut bitfield = BitList1024::with_capacity(16).unwrap();
        bitfield.set(0, true).unwrap();
        bitfield.set(5, true).unwrap();
        bitfield.set(15, true).unwrap();

        let zeroed = bitfield.clone_zeroed();
        assert_eq!(zeroed.len(), 16);
        assert!(zeroed.is_zero());

        let mut bitfield = BitList1::with_capacity(1).unwrap();
        bitfield.set(0, true).unwrap();

        let zeroed = bitfield.clone_zeroed();
        assert_eq!(zeroed.len(), 1);
        assert!(zeroed.is_zero());

        let empty = BitList0::with_capacity(0).unwrap();
        let zeroed_empty = empty.clone_zeroed();
        assert_eq!(zeroed_empty.len(), 0);
        assert!(zeroed_empty.is_zero());
    }

    /// `Display` renders exactly `len` bits, index 0 first.
    #[test]
    fn display() {
        let bitlist = BitList1024::from_raw_bytes(smallvec![0b0011_1111, 0b0001_0101], 15).unwrap();
        assert_eq!("111111001010100", bitlist.to_string());
    }

    #[test]
    fn not() {
        // Test with all zeros
        let a = BitList8::with_capacity(8).unwrap();
        let mut expected = BitList8::with_capacity(8).unwrap();
        for i in 0..8 {
            expected.set(i, true).unwrap();
        }
        assert_eq!(a.not(), expected);

        // Test with all ones
        let b = expected.clone();
        assert_eq!(b.not(), BitList8::with_capacity(8).unwrap());

        // Test with mixed pattern
        let c = BitList16::from_raw_bytes(smallvec![0b1100_1010, 0b0011_0101], 16).unwrap();
        let expected_c =
            BitList16::from_raw_bytes(smallvec![0b0011_0101, 0b1100_1010], 16).unwrap();
        assert_eq!(c.not(), expected_c);

        // Test with partial byte (5 bits)
        let d = BitList8::from_raw_bytes(smallvec![0b0001_1010], 5).unwrap();
        let expected_d = BitList8::from_raw_bytes(smallvec![0b0000_0101], 5).unwrap();
        assert_eq!(d.not(), expected_d);

        // Test that masking works correctly for partial bytes
        let e = BitList8::from_raw_bytes(smallvec![0b0001_1111], 5).unwrap();
        let expected_e = BitList8::from_raw_bytes(smallvec![0b0000_0000], 5).unwrap();
        assert_eq!(e.not(), expected_e);
        let f = BitList8::from_raw_bytes(smallvec![0b0000_0001], 5).unwrap();
        let expected_f = BitList8::from_raw_bytes(smallvec![0b0001_1110], 5).unwrap();
        assert_eq!(f.not(), expected_f);

        // Test with zero-length bitlist
        let g = BitList0::with_capacity(0).unwrap();
        assert_eq!(g.not(), g);
    }

    #[test]
    fn not_inplace() {
        // Test with all zeros
        let mut a = BitList8::with_capacity(8).unwrap();
        a.not_inplace();
        let mut expected = BitList8::with_capacity(8).unwrap();
        for i in 0..8 {
            expected.set(i, true).unwrap();
        }
        assert_eq!(a, expected);

        // Test with all ones
        let mut b = expected.clone();
        b.not_inplace();
        assert_eq!(b, BitList8::with_capacity(8).unwrap());

        // Test with mixed pattern
        let mut c = BitList16::from_raw_bytes(smallvec![0b1100_1010, 0b0011_0101], 16).unwrap();
        c.not_inplace();
        let expected_c =
            BitList16::from_raw_bytes(smallvec![0b0011_0101, 0b1100_1010], 16).unwrap();
        assert_eq!(c, expected_c);

        // Test with partial byte (5 bits)
        let mut d = BitList8::from_raw_bytes(smallvec![0b0001_1010], 5).unwrap();
        d.not_inplace();
        let expected_d = BitList8::from_raw_bytes(smallvec![0b0000_0101], 5).unwrap();
        assert_eq!(d, expected_d);

        // Test with zero-length bitlist
        let mut f = BitList0::with_capacity(0).unwrap();
        let expected_f = f.clone();
        f.not_inplace();
        assert_eq!(f, expected_f);
    }
}