ssz/bitfield/
bitvector_dynamic.rs

//! Provides `Bitfield<Dynamic>` (`BitVectorDynamic`)
//! for encoding and decoding bitvectors that have a dynamic length.
3use crate::{
4    bitfield::{bytes_for_bit_len, Bitfield, BitfieldBehaviour, Error, SMALLVEC_LEN},
5    Decode, DecodeError, Encode,
6};
7use core::marker::PhantomData;
8use serde::de::{Deserialize, Deserializer};
9use serde::ser::{Serialize, Serializer};
10use serde_utils::hex::{encode as hex_encode, PrefixedHexVisitor};
11use smallvec::{smallvec, SmallVec, ToSmallVec};
12
/// Zero-sized marker type that selects the runtime-length ("dynamic") behaviour of a `Bitfield`.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Dynamic;
16
/// An ordered collection of `bool` values whose length is chosen at runtime.
///
/// This is `Bitfield` specialised with the `Dynamic` marker. The bits live in a
/// `SmallVec`-backed byte buffer, and every constructor in this module produces a
/// bit length that is a whole number of bytes.
pub type BitVectorDynamic = Bitfield<Dynamic>;

// Empty marker impl: `Dynamic` contributes no items of its own.
// NOTE(review): presumably required by the bounds on `Bitfield<T>` — confirm in the parent module.
impl BitfieldBehaviour for Dynamic {}
21
22impl Bitfield<Dynamic> {
23    /// Create a new dynamic bitfield with the given length (all bits false).
24    /// Length must be a multiple of 8 bits and greater than 0.
25    pub fn new(len: usize) -> Result<Self, Error> {
26        if len == 0 {
27            return Err(Error::InvalidByteCount {
28                given: len,
29                expected: 8,
30            });
31        }
32        if !len.is_multiple_of(8) {
33            return Err(Error::InvalidByteCount {
34                given: len,
35                expected: (len / 8 + 1) * 8,
36            });
37        }
38        Ok(Self {
39            bytes: smallvec![0; bytes_for_bit_len(len)],
40            len,
41            _phantom: PhantomData,
42        })
43    }
44
45    pub fn into_bytes(self) -> SmallVec<[u8; SMALLVEC_LEN]> {
46        self.into_raw_bytes()
47    }
48
49    /// Create a dynamic bitfield from raw bytes and a declared logical length.
50    /// Can be used to specify a max_length if called directly instead of via Decode.
51    pub fn from_bytes_with_len(
52        bytes: SmallVec<[u8; SMALLVEC_LEN]>,
53        len: usize,
54    ) -> Result<Self, Error> {
55        if len != bytes.len() * 8 {
56            return Err(Error::InvalidByteCount {
57                given: len,
58                expected: bytes.len() * 8,
59            });
60        }
61        Self::from_raw_bytes(bytes, len)
62    }
63
64    /// Compute the intersection of two bitfields.
65    pub fn intersection(&self, other: &Self) -> Result<Self, Error> {
66        let max_len = std::cmp::max(self.len(), other.len());
67        let mut result = Self::new(max_len)?;
68
69        for (i, byte) in result.bytes.iter_mut().enumerate() {
70            *byte =
71                self.bytes.get(i).copied().unwrap_or(0) & other.bytes.get(i).copied().unwrap_or(0);
72        }
73        Ok(result)
74    }
75
76    /// Compute the union of two bitfields.
77    pub fn union(&self, other: &Self) -> Result<Self, Error> {
78        let max_len = std::cmp::max(self.len(), other.len());
79        let mut result = Self::new(max_len)?;
80
81        for (i, byte) in result.bytes.iter_mut().enumerate() {
82            *byte =
83                self.bytes.get(i).copied().unwrap_or(0) | other.bytes.get(i).copied().unwrap_or(0);
84        }
85        Ok(result)
86    }
87}
88
89impl Encode for Bitfield<Dynamic> {
90    fn is_ssz_fixed_len() -> bool {
91        false
92    }
93
94    fn ssz_bytes_len(&self) -> usize {
95        self.bytes.len()
96    }
97
98    fn ssz_append(&self, buf: &mut Vec<u8>) {
99        buf.extend_from_slice(&self.bytes)
100    }
101}
102
103impl Decode for Bitfield<Dynamic> {
104    fn is_ssz_fixed_len() -> bool {
105        false
106    }
107
108    fn from_ssz_bytes(bytes: &[u8]) -> Result<Self, DecodeError> {
109        if bytes.is_empty() {
110            return Err(DecodeError::BytesInvalid("Empty bytes".into()));
111        }
112
113        let len = bytes.len() * 8;
114        Self::from_raw_bytes(bytes.to_smallvec(), len).map_err(|e| {
115            DecodeError::BytesInvalid(format!("BitVectorDynamic failed to decode: {:?}", e))
116        })
117    }
118}
119
120impl Serialize for Bitfield<Dynamic> {
121    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
122    where
123        S: Serializer,
124    {
125        serializer.serialize_str(&hex_encode(self.as_ssz_bytes()))
126    }
127}
128
129impl<'de> Deserialize<'de> for Bitfield<Dynamic> {
130    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
131    where
132        D: Deserializer<'de>,
133    {
134        let bytes = deserializer.deserialize_str(PrefixedHexVisitor)?;
135        Self::from_ssz_bytes(&bytes)
136            .map_err(|e| serde::de::Error::custom(format!("BitVectorDynamic {:?}", e)))
137    }
138}
139
#[cfg(test)]
mod dynamic_bitfield_tests {
    use super::*;

    /// `set`/`get` succeed inside the declared length and fail at the first index past it.
    #[test]
    fn test_basic_operations() -> Result<(), Error> {
        let mut bitfield = BitVectorDynamic::new(16)?;

        assert!(bitfield.set(0, true).is_ok());
        assert!(bitfield.set(15, true).is_ok());
        assert!(bitfield.set(16, true).is_err()); // Out of bounds

        assert!(bitfield.get(0)?);
        assert!(bitfield.get(15)?);
        assert!(bitfield.get(16).is_err());

        Ok(())
    }

    /// Round-trips through the raw-byte API and through the SSZ `Encode`/`Decode` impls.
    #[test]
    fn test_ssz_encode_decode() -> Result<(), Error> {
        let mut bitfield = BitVectorDynamic::new(8)?;
        bitfield.set(0, true)?;
        bitfield.set(7, true)?;

        // Raw bytes plus an explicitly declared length.
        let bytes = bitfield.clone().into_bytes();
        let decoded = BitVectorDynamic::from_bytes_with_len(bytes, 8)?;
        assert_eq!(bitfield, decoded);

        // SSZ encoding followed by SSZ decoding.
        let ssz_decoded = BitVectorDynamic::from_ssz_bytes(&bitfield.as_ssz_bytes())
            .expect("SSZ bytes produced by `as_ssz_bytes` must decode");
        assert_eq!(bitfield, ssz_decoded);

        // A declared length that does not match the buffer size is rejected.
        assert!(BitVectorDynamic::from_bytes_with_len(smallvec![], 8).is_err());
        assert!(BitVectorDynamic::from_bytes_with_len(smallvec![0, 0, 0], 8).is_err());

        Ok(())
    }

    /// Empty SSZ input is rejected with the dedicated error message.
    #[test]
    fn test_ssz_decode_errors() {
        let empty_bytes: &[u8] = &[];
        assert!(matches!(
            BitVectorDynamic::from_ssz_bytes(empty_bytes),
            Err(DecodeError::BytesInvalid(msg)) if msg == "Empty bytes"
        ));
    }

    /// Intersection keeps only bits set in both operands and is symmetric.
    #[test]
    fn test_intersection() -> Result<(), Error> {
        let mut a = BitVectorDynamic::new(16)?;
        let mut b = BitVectorDynamic::new(16)?;
        let mut expected = BitVectorDynamic::new(16)?;

        a.set(1, true)?;
        a.set(3, true)?;
        b.set(3, true)?;
        b.set(4, true)?;
        expected.set(3, true)?;

        assert_eq!(a.intersection(&b)?, expected);
        assert_eq!(b.intersection(&a)?, expected);

        Ok(())
    }

    /// Union keeps bits set in either operand and is symmetric.
    #[test]
    fn test_union() -> Result<(), Error> {
        let mut a = BitVectorDynamic::new(16)?;
        let mut b = BitVectorDynamic::new(16)?;
        let mut expected = BitVectorDynamic::new(16)?;

        a.set(1, true)?;
        a.set(3, true)?;
        b.set(3, true)?;
        b.set(4, true)?;

        expected.set(1, true)?;
        expected.set(3, true)?;
        expected.set(4, true)?;

        assert_eq!(a.union(&b)?, expected);
        assert_eq!(b.union(&a)?, expected);

        Ok(())
    }

    #[test]
    fn test_highest_set_bit() -> Result<(), Error> {
        let mut bitfield = BitVectorDynamic::new(16)?;
        assert_eq!(bitfield.highest_set_bit(), None);

        bitfield.set(3, true)?;
        assert_eq!(bitfield.highest_set_bit(), Some(3));

        bitfield.set(15, true)?;
        assert_eq!(bitfield.highest_set_bit(), Some(15));

        Ok(())
    }

    #[test]
    fn test_is_zero() -> Result<(), Error> {
        let mut bitfield = BitVectorDynamic::new(16)?;
        assert!(bitfield.is_zero());

        bitfield.set(0, true)?;
        assert!(!bitfield.is_zero());

        bitfield.set(0, false)?;
        assert!(bitfield.is_zero());

        Ok(())
    }

    #[test]
    fn test_num_set_bits() -> Result<(), Error> {
        let mut bitfield = BitVectorDynamic::new(16)?;
        assert_eq!(bitfield.num_set_bits(), 0);

        bitfield.set(1, true)?;
        bitfield.set(3, true)?;
        bitfield.set(7, true)?;
        assert_eq!(bitfield.num_set_bits(), 3);

        Ok(())
    }

    /// `a.difference(&b)` keeps bits of `a` that are not set in `b`.
    #[test]
    fn test_difference() -> Result<(), Error> {
        let mut a = BitVectorDynamic::new(16)?;
        let mut b = BitVectorDynamic::new(16)?;

        a.set(1, true)?;
        a.set(3, true)?;
        b.set(3, true)?;
        b.set(4, true)?;

        let diff = a.difference(&b);
        assert!(diff.get(1)?);
        assert!(!diff.get(3)?);
        assert!(!diff.get(4)?);

        Ok(())
    }

    #[test]
    fn test_shift_up() -> Result<(), Error> {
        let mut bitfield = BitVectorDynamic::new(16)?;
        bitfield.set(0, true)?;
        bitfield.set(1, true)?;

        bitfield.shift_up(1)?;
        assert!(!bitfield.get(0)?);
        assert!(bitfield.get(1)?);
        assert!(bitfield.get(2)?);

        // Shifting by more than the bitfield length is an error.
        assert!(bitfield.shift_up(17).is_err());

        Ok(())
    }

    #[test]
    fn test_iter() -> Result<(), Error> {
        let mut bitfield = BitVectorDynamic::new(8)?;
        bitfield.set(1, true)?;
        bitfield.set(4, true)?;
        bitfield.set(7, true)?;

        let bits: Vec<bool> = bitfield.iter().collect();
        assert_eq!(
            bits,
            vec![false, true, false, false, true, false, false, true]
        );

        Ok(())
    }

    /// Only non-zero multiples of 8 are accepted as lengths.
    #[test]
    fn test_non_byte_aligned_lengths() {
        // Zero bits should error
        assert!(BitVectorDynamic::new(0).is_err());

        // 1 bit should error
        assert!(BitVectorDynamic::new(1).is_err());

        // 7 bits should error
        assert!(BitVectorDynamic::new(7).is_err());

        // 8 bits should succeed
        assert!(BitVectorDynamic::new(8).is_ok());

        // 9 bits should error
        assert!(BitVectorDynamic::new(9).is_err());
    }

    /// Pins the byte layout (bit `i` lives in byte `i / 8`, least-significant bit first)
    /// and checks that SSZ decoding reproduces the same bytes.
    #[test]
    fn test_encode_decode() -> Result<(), Error> {
        let mut bitfield = BitVectorDynamic::new(32)?;
        bitfield.set(0, true)?;
        bitfield.set(16, true)?;
        bitfield.set(31, true)?;

        let expected: SmallVec<[u8; 4]> =
            smallvec![0b0000_0001, 0b0000_0000, 0b0000_0001, 0b1000_0000];
        let bytes = bitfield.clone().into_bytes();
        assert_eq!(bytes, expected);

        let encoded = bitfield.as_ssz_bytes();
        // Surface the real decode error on failure instead of masking it with an
        // unrelated `Error` value.
        let decoded = BitVectorDynamic::from_ssz_bytes(&encoded)
            .expect("SSZ bytes produced by `as_ssz_bytes` must decode")
            .into_bytes();

        assert_eq!(bytes, decoded);

        Ok(())
    }

    /// `from_bytes_with_len` and `from_ssz_bytes` agree when given the same bytes.
    #[test]
    fn test_from_bytes_equivalence() -> Result<(), Error> {
        let mut original = BitVectorDynamic::new(16)?;
        original.set(0, true)?;
        original.set(8, true)?;
        original.set(15, true)?;

        let bytes = original.clone().into_bytes();

        let from_bytes = BitVectorDynamic::from_bytes_with_len(bytes.clone(), 16)?;
        let from_ssz = BitVectorDynamic::from_ssz_bytes(&bytes)
            .expect("bytes taken from a valid bitfield must decode");

        assert_eq!(from_bytes, from_ssz);
        assert_eq!(from_bytes, original);

        Ok(())
    }

    #[test]
    fn test_from_bytes_length_validation() {
        // Four bytes, i.e. a 32-bit field with every bit set.
        let bytes = smallvec![0b1111_1111; 4];

        // Declaring 16 bits for a 32-bit buffer must fail.
        assert!(matches!(
            BitVectorDynamic::from_bytes_with_len(bytes.clone(), 16),
            Err(Error::InvalidByteCount { .. })
        ));

        // Declaring the matching length succeeds.
        assert!(BitVectorDynamic::from_bytes_with_len(bytes, 32).is_ok());
    }

    #[test]
    fn test_ssz_bytes_len() -> Result<(), Error> {
        let bitfield = BitVectorDynamic::new(16)?; // 16 bits = 2 bytes
        assert_eq!(bitfield.ssz_bytes_len(), 2);

        let bitfield = BitVectorDynamic::new(24)?; // 24 bits = 3 bytes
        assert_eq!(bitfield.ssz_bytes_len(), 3);

        Ok(())
    }
}
#[cfg(test)]
mod roundtrip_tests {
    use super::*;

    /// Encodes `t` to SSZ, decodes the bytes again and asserts the value is unchanged.
    ///
    /// Panics (failing the calling test) on a decode error or a mismatch.
    fn assert_round_trip_bitdyn<T>(t: T)
    where
        T: Encode + Decode + PartialEq + std::fmt::Debug,
    {
        let bytes = t.as_ssz_bytes();
        let decoded = T::from_ssz_bytes(&bytes).expect("decode failed in test");
        assert_eq!(decoded, t, "BitDyn must SSZ-roundtrip correctly.");
    }

    #[test]
    fn bitdyn_ssz_round_trip() -> Result<(), Error> {
        // length = 8, even bits set
        let mut b = BitVectorDynamic::new(8)?;
        for j in (0..8).step_by(2) {
            b.set(j, true)?;
        }
        assert_round_trip_bitdyn(b);

        // length = 16, all bits set
        let mut b = BitVectorDynamic::new(16)?;
        for j in 0..16 {
            b.set(j, true)?;
        }
        assert_round_trip_bitdyn(b);

        Ok(())
    }

    #[test]
    fn test_ssz_roundtrip_various_sizes() -> Result<(), Error> {
        // All-zero vector (8 bits)
        let zeroed = BitVectorDynamic::new(8)?;
        assert_round_trip_bitdyn(zeroed);

        // Partially filled vector (16 bits)
        let mut partial = BitVectorDynamic::new(16)?;
        partial.set(0, true)?;
        partial.set(8, true)?;
        partial.set(15, true)?;
        assert_round_trip_bitdyn(partial);

        // Fully filled vector (24 bits)
        let mut full = BitVectorDynamic::new(24)?;
        for i in 0..24 {
            full.set(i, true)?;
        }
        assert_round_trip_bitdyn(full);

        // Alternating pattern (32 bits)
        let mut alternating = BitVectorDynamic::new(32)?;
        for i in 0..32 {
            alternating.set(i, i % 2 == 0)?;
        }
        assert_round_trip_bitdyn(alternating);

        // Fully filled vector (128 bits)
        let mut full = BitVectorDynamic::new(128)?;
        for i in 0..128 {
            full.set(i, true)?;
        }
        assert_round_trip_bitdyn(full);

        Ok(())
    }

    /// Serializing to JSON and parsing the result back yields an equal bitvector.
    #[test]
    fn test_serde_roundtrip() -> Result<(), Error> {
        let mut b = BitVectorDynamic::new(16)?;
        b.set(0, true)?;
        b.set(7, true)?;

        let json = serde_json::to_string(&b).expect("Serialization failed");
        // `from_str` also rejects trailing input, so the whole document must be consumed.
        let deserialized: BitVectorDynamic =
            serde_json::from_str(&json).expect("Deserialization failed");

        assert_eq!(b, deserialized);

        Ok(())
    }
}