Skip to main content

ssz/bitfield/
bitvector_dynamic.rs

//! Provides `Bitfield<Dynamic>` (aliased as `BitVectorDynamic`) for encoding and
//! decoding bitvectors whose length is chosen at runtime.
3use crate::{
4    bitfield::{bytes_for_bit_len, Bitfield, BitfieldBehaviour, Error, SMALLVEC_LEN},
5    Decode, DecodeError, Encode,
6};
7use core::marker::PhantomData;
8use serde::de::{Deserialize, Deserializer};
9use serde::ser::{Serialize, Serializer};
10use serde_utils::hex::{encode as hex_encode, PrefixedHexVisitor};
11use smallvec::{smallvec, SmallVec, ToSmallVec};
12
13/// A marker struct used to declare dynamic length behaviour on a Bitfield.
14#[derive(Clone, PartialEq, Eq, Debug)]
15pub struct Dynamic;
16
17/// A heap-allocated, ordered collection of `bool` values with a length set at runtime.
18pub type BitVectorDynamic = Bitfield<Dynamic>;
19
20impl BitfieldBehaviour for Dynamic {}
21
22impl Bitfield<Dynamic> {
23    /// Create a new dynamic bitfield with the given length (all bits false).
24    /// Length must be a multiple of 8 bits and greater than 0.
25    pub fn new(len: usize) -> Result<Self, Error> {
26        if len == 0 {
27            return Err(Error::InvalidByteCount {
28                given: len,
29                expected: 8,
30            });
31        }
32        if !len.is_multiple_of(8) {
33            return Err(Error::InvalidByteCount {
34                given: len,
35                expected: (len / 8 + 1) * 8,
36            });
37        }
38        Ok(Self {
39            bytes: smallvec![0; bytes_for_bit_len(len)],
40            len,
41            _phantom: PhantomData,
42        })
43    }
44
45    pub fn into_bytes(self) -> SmallVec<[u8; SMALLVEC_LEN]> {
46        self.into_raw_bytes()
47    }
48
49    /// Create a dynamic bitfield from raw bytes and a declared logical length.
50    /// Can be used to specify a max_length if called directly instead of via Decode.
51    pub fn from_bytes_with_len(
52        bytes: SmallVec<[u8; SMALLVEC_LEN]>,
53        len: usize,
54    ) -> Result<Self, Error> {
55        if len != bytes.len() * 8 {
56            return Err(Error::InvalidByteCount {
57                given: len,
58                expected: bytes.len() * 8,
59            });
60        }
61        Self::from_raw_bytes(bytes, len)
62    }
63
64    /// Compute the intersection of two bitfields.
65    pub fn intersection(&self, other: &Self) -> Result<Self, Error> {
66        let max_len = std::cmp::max(self.len(), other.len());
67        let mut result = Self::new(max_len)?;
68
69        for (i, byte) in result.bytes.iter_mut().enumerate() {
70            *byte =
71                self.bytes.get(i).copied().unwrap_or(0) & other.bytes.get(i).copied().unwrap_or(0);
72        }
73        Ok(result)
74    }
75
76    /// Compute the union of two bitfields.
77    pub fn union(&self, other: &Self) -> Result<Self, Error> {
78        let max_len = std::cmp::max(self.len(), other.len());
79        let mut result = Self::new(max_len)?;
80
81        for (i, byte) in result.bytes.iter_mut().enumerate() {
82            *byte =
83                self.bytes.get(i).copied().unwrap_or(0) | other.bytes.get(i).copied().unwrap_or(0);
84        }
85        Ok(result)
86    }
87}
88
89impl Encode for Bitfield<Dynamic> {
90    fn is_ssz_fixed_len() -> bool {
91        false
92    }
93
94    fn ssz_bytes_len(&self) -> usize {
95        self.bytes.len()
96    }
97
98    fn ssz_append(&self, buf: &mut Vec<u8>) {
99        buf.extend_from_slice(&self.bytes)
100    }
101}
102
103impl Decode for Bitfield<Dynamic> {
104    fn is_ssz_fixed_len() -> bool {
105        false
106    }
107
108    fn from_ssz_bytes(bytes: &[u8]) -> Result<Self, DecodeError> {
109        if bytes.is_empty() {
110            return Err(DecodeError::BytesInvalid("Empty bytes".into()));
111        }
112
113        let len = bytes.len() * 8;
114        Self::from_raw_bytes(bytes.to_smallvec(), len).map_err(|e| {
115            DecodeError::BytesInvalid(format!("BitVectorDynamic failed to decode: {:?}", e))
116        })
117    }
118}
119
120impl Serialize for Bitfield<Dynamic> {
121    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
122    where
123        S: Serializer,
124    {
125        serializer.serialize_str(&hex_encode(self.as_ssz_bytes()))
126    }
127}
128
129impl<'de> Deserialize<'de> for Bitfield<Dynamic> {
130    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
131    where
132        D: Deserializer<'de>,
133    {
134        let bytes = deserializer.deserialize_str(PrefixedHexVisitor)?;
135        Self::from_ssz_bytes(&bytes)
136            .map_err(|e| serde::de::Error::custom(format!("BitVectorDynamic {:?}", e)))
137    }
138}
139
#[cfg(test)]
mod dynamic_bitfield_tests {
    use super::*;

    #[test]
    fn test_basic_operations() -> Result<(), Error> {
        let mut bitfield = BitVectorDynamic::new(16)?;

        // Indices 0..16 are in bounds; 16 is one past the end.
        assert!(bitfield.set(0, true).is_ok());
        assert!(bitfield.set(15, true).is_ok());
        assert!(bitfield.set(16, true).is_err());

        assert!(bitfield.get(0)?);
        assert!(bitfield.get(15)?);
        assert!(bitfield.get(16).is_err());

        Ok(())
    }

    #[test]
    fn test_ssz_encode_decode() -> Result<(), Error> {
        let mut bitfield = BitVectorDynamic::new(8)?;
        bitfield.set(0, true)?;
        bitfield.set(7, true)?;

        // Convert to raw bytes and rebuild with the known length.
        let bytes = bitfield.clone().into_bytes();
        let decoded = BitVectorDynamic::from_bytes_with_len(bytes, 8)?;

        assert_eq!(bitfield, decoded);

        // The declared length must match the byte count exactly.
        assert!(BitVectorDynamic::from_bytes_with_len(smallvec![], 8).is_err());
        assert!(BitVectorDynamic::from_bytes_with_len(smallvec![0, 0, 0], 8).is_err());

        Ok(())
    }

    #[test]
    fn test_ssz_decode_errors() {
        let empty_bytes: &[u8] = &[];
        assert!(matches!(
            BitVectorDynamic::from_ssz_bytes(empty_bytes),
            Err(DecodeError::BytesInvalid(msg)) if msg == "Empty bytes"
        ));
    }

    #[test]
    fn test_intersection() -> Result<(), Error> {
        let mut a = BitVectorDynamic::new(16)?;
        let mut b = BitVectorDynamic::new(16)?;
        let mut expected = BitVectorDynamic::new(16)?;

        a.set(1, true)?;
        a.set(3, true)?;
        b.set(3, true)?;
        b.set(4, true)?;
        expected.set(3, true)?;

        assert_eq!(a.intersection(&b)?, expected);
        assert_eq!(b.intersection(&a)?, expected);

        Ok(())
    }

    #[test]
    fn test_union() -> Result<(), Error> {
        let mut a = BitVectorDynamic::new(16)?;
        let mut b = BitVectorDynamic::new(16)?;
        let mut expected = BitVectorDynamic::new(16)?;

        a.set(1, true)?;
        a.set(3, true)?;
        b.set(3, true)?;
        b.set(4, true)?;

        expected.set(1, true)?;
        expected.set(3, true)?;
        expected.set(4, true)?;

        assert_eq!(a.union(&b)?, expected);
        assert_eq!(b.union(&a)?, expected);

        Ok(())
    }

    #[test]
    fn test_union_and_intersection_different_lengths() -> Result<(), Error> {
        // The shorter operand is treated as zero-padded up to the longer length.
        let mut short = BitVectorDynamic::new(8)?;
        short.set(1, true)?;

        let mut long = BitVectorDynamic::new(16)?;
        long.set(1, true)?;
        long.set(9, true)?;

        let mut expected_intersection = BitVectorDynamic::new(16)?;
        expected_intersection.set(1, true)?;

        assert_eq!(short.union(&long)?, long);
        assert_eq!(long.union(&short)?, long);
        assert_eq!(short.intersection(&long)?, expected_intersection);
        assert_eq!(long.intersection(&short)?, expected_intersection);

        Ok(())
    }

    #[test]
    fn test_highest_set_bit() -> Result<(), Error> {
        let mut bitfield = BitVectorDynamic::new(16)?;
        assert_eq!(bitfield.highest_set_bit(), None);

        bitfield.set(3, true)?;
        assert_eq!(bitfield.highest_set_bit(), Some(3));

        bitfield.set(15, true)?;
        assert_eq!(bitfield.highest_set_bit(), Some(15));

        Ok(())
    }

    #[test]
    fn test_is_zero() -> Result<(), Error> {
        let mut bitfield = BitVectorDynamic::new(16)?;
        assert!(bitfield.is_zero());

        bitfield.set(0, true)?;
        assert!(!bitfield.is_zero());

        bitfield.set(0, false)?;
        assert!(bitfield.is_zero());

        Ok(())
    }

    #[test]
    fn test_num_set_bits() -> Result<(), Error> {
        let mut bitfield = BitVectorDynamic::new(16)?;
        assert_eq!(bitfield.num_set_bits(), 0);

        bitfield.set(1, true)?;
        bitfield.set(3, true)?;
        bitfield.set(7, true)?;
        assert_eq!(bitfield.num_set_bits(), 3);

        Ok(())
    }

    #[test]
    fn test_difference() -> Result<(), Error> {
        let mut a = BitVectorDynamic::new(16)?;
        let mut b = BitVectorDynamic::new(16)?;

        a.set(1, true)?;
        a.set(3, true)?;
        b.set(3, true)?;
        b.set(4, true)?;

        let diff = a.difference(&b);
        assert!(diff.get(1)?);
        assert!(!diff.get(3)?);
        assert!(!diff.get(4)?);

        Ok(())
    }

    #[test]
    fn test_shift_up() -> Result<(), Error> {
        let mut bitfield = BitVectorDynamic::new(16)?;
        bitfield.set(0, true)?;
        bitfield.set(1, true)?;

        bitfield.shift_up(1)?;
        assert!(!bitfield.get(0)?);
        assert!(bitfield.get(1)?);
        assert!(bitfield.get(2)?);

        // Shifting past the end of the bitfield is an error.
        assert!(bitfield.shift_up(17).is_err());

        Ok(())
    }

    #[test]
    fn test_iter() -> Result<(), Error> {
        let mut bitfield = BitVectorDynamic::new(8)?;
        bitfield.set(1, true)?;
        bitfield.set(4, true)?;
        bitfield.set(7, true)?;

        let bits: Vec<bool> = bitfield.iter().collect();
        assert_eq!(
            bits,
            vec![false, true, false, false, true, false, false, true]
        );

        Ok(())
    }

    #[test]
    fn test_non_byte_aligned_lengths() {
        // Zero bits is rejected.
        assert!(BitVectorDynamic::new(0).is_err());

        // Lengths that are not a multiple of 8 are rejected.
        assert!(BitVectorDynamic::new(1).is_err());
        assert!(BitVectorDynamic::new(7).is_err());
        assert!(BitVectorDynamic::new(9).is_err());

        // A whole byte is accepted.
        assert!(BitVectorDynamic::new(8).is_ok());
    }

    #[test]
    fn test_encode_decode() -> Result<(), Error> {
        let mut bitfield = BitVectorDynamic::new(32)?;
        bitfield.set(0, true)?;
        bitfield.set(16, true)?;
        bitfield.set(31, true)?;

        // Bits are packed least-significant-bit first within each byte.
        let expected: SmallVec<[u8; 4]> =
            smallvec![0b0000_0001, 0b0000_0000, 0b0000_0001, 0b1000_0000];
        let bytes = bitfield.clone().into_bytes();
        assert_eq!(bytes, expected);

        let encoded = bitfield.as_ssz_bytes();
        let decoded = BitVectorDynamic::from_ssz_bytes(&encoded)
            .expect("decoding freshly encoded bytes must succeed")
            .into_bytes();

        assert_eq!(bytes, decoded);

        Ok(())
    }

    #[test]
    fn test_from_bytes_equivalence() -> Result<(), Error> {
        let mut original = BitVectorDynamic::new(16)?;
        original.set(0, true)?;
        original.set(8, true)?;
        original.set(15, true)?;

        let bytes = original.clone().into_bytes();

        // Both constructors must produce the same bitfield.
        let from_bytes = BitVectorDynamic::from_bytes_with_len(bytes.clone(), 16)?;
        let from_ssz = BitVectorDynamic::from_ssz_bytes(&bytes)
            .expect("decoding valid non-empty bytes must succeed");

        assert_eq!(from_bytes, from_ssz);
        assert_eq!(from_bytes, original);

        Ok(())
    }

    #[test]
    fn test_from_bytes_length_validation() {
        // Four bytes describe a 32-bit field with every bit set.
        let bytes = smallvec![0b1111_1111; 4];

        // A mismatched declared length is rejected.
        assert!(matches!(
            BitVectorDynamic::from_bytes_with_len(bytes.clone(), 16),
            Err(Error::InvalidByteCount { .. })
        ));

        // The matching length is accepted.
        assert!(BitVectorDynamic::from_bytes_with_len(bytes, 32).is_ok());
    }

    #[test]
    fn test_not() -> Result<(), Error> {
        // All zeros -> all ones.
        let a = BitVectorDynamic::new(8)?;
        let mut expected = BitVectorDynamic::new(8)?;
        for i in 0..8 {
            expected.set(i, true)?;
        }
        assert_eq!(a.not(), expected);

        // All ones -> all zeros.
        let b = expected.clone();
        assert_eq!(b.not(), BitVectorDynamic::new(8)?);

        // Mixed 16-bit pattern.
        let c = BitVectorDynamic::from_bytes_with_len(smallvec![0b1100_1010, 0b0011_0101], 16)?;
        let expected_c =
            BitVectorDynamic::from_bytes_with_len(smallvec![0b0011_0101, 0b1100_1010], 16)?;
        assert_eq!(c.not(), expected_c);

        // Double negation is the identity.
        let mut d = BitVectorDynamic::new(16)?;
        d.set(0, true)?;
        d.set(5, true)?;
        d.set(15, true)?;
        assert_eq!(d.not().not(), d);

        Ok(())
    }

    #[test]
    fn test_not_inplace() -> Result<(), Error> {
        // All zeros -> all ones.
        let mut a = BitVectorDynamic::new(8)?;
        a.not_inplace();
        let mut expected = BitVectorDynamic::new(8)?;
        for i in 0..8 {
            expected.set(i, true)?;
        }
        assert_eq!(a, expected);

        // All ones -> all zeros.
        let mut b = expected.clone();
        b.not_inplace();
        assert_eq!(b, BitVectorDynamic::new(8)?);

        // Mixed 16-bit pattern.
        let mut c = BitVectorDynamic::from_bytes_with_len(smallvec![0b1100_1010, 0b0011_0101], 16)?;
        c.not_inplace();
        let expected_c =
            BitVectorDynamic::from_bytes_with_len(smallvec![0b0011_0101, 0b1100_1010], 16)?;
        assert_eq!(c, expected_c);

        // Applying it twice is the identity.
        let mut d = BitVectorDynamic::new(16)?;
        d.set(0, true)?;
        d.set(5, true)?;
        d.set(15, true)?;
        let original = d.clone();
        d.not_inplace();
        d.not_inplace();
        assert_eq!(d, original);

        Ok(())
    }

    #[test]
    fn test_ssz_bytes_len() -> Result<(), Error> {
        // 16 bits = 2 bytes.
        let bitfield = BitVectorDynamic::new(16)?;
        assert_eq!(bitfield.ssz_bytes_len(), 2);

        // 24 bits = 3 bytes.
        let bitfield = BitVectorDynamic::new(24)?;
        assert_eq!(bitfield.ssz_bytes_len(), 3);

        Ok(())
    }
}
#[cfg(test)]
mod roundtrip_tests {
    use super::*;

    /// Asserts that `value` survives an SSZ encode/decode round trip unchanged.
    fn assert_ssz_round_trip<T>(value: T)
    where
        T: Encode + Decode + PartialEq + std::fmt::Debug,
    {
        let bytes = value.as_ssz_bytes();
        let decoded = T::from_ssz_bytes(&bytes).expect("decode failed in test");
        assert_eq!(decoded, value, "BitDyn must SSZ-roundtrip correctly.");
    }

    #[test]
    fn bitdyn_ssz_round_trip() -> Result<(), Error> {
        // Length 8, even bits set.
        let mut b = BitVectorDynamic::new(8)?;
        for j in (0..8).step_by(2) {
            b.set(j, true)?;
        }
        assert_ssz_round_trip(b);

        // Length 16, every bit set.
        let mut b = BitVectorDynamic::new(16)?;
        for j in 0..16 {
            b.set(j, true)?;
        }
        assert_ssz_round_trip(b);

        Ok(())
    }

    #[test]
    fn test_ssz_roundtrip_various_sizes() -> Result<(), Error> {
        // All-zero vector (8 bits).
        let zeros = BitVectorDynamic::new(8)?;
        assert_ssz_round_trip(zeros);

        // Partially filled vector (16 bits).
        let mut partial = BitVectorDynamic::new(16)?;
        partial.set(0, true)?;
        partial.set(8, true)?;
        partial.set(15, true)?;
        assert_ssz_round_trip(partial);

        // Fully filled vector (24 bits).
        let mut full = BitVectorDynamic::new(24)?;
        for i in 0..24 {
            full.set(i, true)?;
        }
        assert_ssz_round_trip(full);

        // Alternating pattern (32 bits).
        let mut alternating = BitVectorDynamic::new(32)?;
        for i in 0..32 {
            alternating.set(i, i % 2 == 0)?;
        }
        assert_ssz_round_trip(alternating);

        // Fully filled vector (128 bits).
        let mut full = BitVectorDynamic::new(128)?;
        for i in 0..128 {
            full.set(i, true)?;
        }
        assert_ssz_round_trip(full);

        Ok(())
    }

    #[test]
    fn test_serde_roundtrip() -> Result<(), Error> {
        let mut b = BitVectorDynamic::new(16)?;
        b.set(0, true)?;
        b.set(7, true)?;

        // Bits 0 and 7 share the first byte, so the SSZ bytes are [0x81, 0x00].
        let json = serde_json::to_string(&b).expect("Serialization failed");
        assert_eq!(json, "\"0x8100\"");

        let deserialized: BitVectorDynamic =
            serde_json::from_str(&json).expect("Deserialization failed");
        assert_eq!(b, deserialized);

        // An empty hex payload cannot describe a dynamic bitfield.
        assert!(serde_json::from_str::<BitVectorDynamic>("\"0x\"").is_err());

        Ok(())
    }
}