1use crate::{
4 bitfield::{bytes_for_bit_len, Bitfield, BitfieldBehaviour, Error, SMALLVEC_LEN},
5 Decode, DecodeError, Encode,
6};
7use core::marker::PhantomData;
8use serde::de::{Deserialize, Deserializer};
9use serde::ser::{Serialize, Serializer};
10use serde_utils::hex::{encode as hex_encode, PrefixedHexVisitor};
11use smallvec::{smallvec, SmallVec, ToSmallVec};
12
/// Zero-sized marker type selecting the runtime-sized, byte-aligned `Bitfield` flavour.
///
/// It carries no data. It exists only so that `Bitfield<Dynamic>` can have its own
/// constructors and its own SSZ / serde implementations.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Dynamic;
16
/// A `Bitfield` whose bit length is chosen at runtime rather than fixed by the type.
pub type BitVectorDynamic = Bitfield<Dynamic>;
19
// No methods are overridden here; whatever defaults `BitfieldBehaviour` provides apply.
impl BitfieldBehaviour for Dynamic {}
21
22impl Bitfield<Dynamic> {
23 pub fn new(len: usize) -> Result<Self, Error> {
26 if len == 0 {
27 return Err(Error::InvalidByteCount {
28 given: len,
29 expected: 8,
30 });
31 }
32 if !len.is_multiple_of(8) {
33 return Err(Error::InvalidByteCount {
34 given: len,
35 expected: (len / 8 + 1) * 8,
36 });
37 }
38 Ok(Self {
39 bytes: smallvec![0; bytes_for_bit_len(len)],
40 len,
41 _phantom: PhantomData,
42 })
43 }
44
45 pub fn into_bytes(self) -> SmallVec<[u8; SMALLVEC_LEN]> {
46 self.into_raw_bytes()
47 }
48
49 pub fn from_bytes_with_len(
52 bytes: SmallVec<[u8; SMALLVEC_LEN]>,
53 len: usize,
54 ) -> Result<Self, Error> {
55 if len != bytes.len() * 8 {
56 return Err(Error::InvalidByteCount {
57 given: len,
58 expected: bytes.len() * 8,
59 });
60 }
61 Self::from_raw_bytes(bytes, len)
62 }
63
64 pub fn intersection(&self, other: &Self) -> Result<Self, Error> {
66 let max_len = std::cmp::max(self.len(), other.len());
67 let mut result = Self::new(max_len)?;
68
69 for (i, byte) in result.bytes.iter_mut().enumerate() {
70 *byte =
71 self.bytes.get(i).copied().unwrap_or(0) & other.bytes.get(i).copied().unwrap_or(0);
72 }
73 Ok(result)
74 }
75
76 pub fn union(&self, other: &Self) -> Result<Self, Error> {
78 let max_len = std::cmp::max(self.len(), other.len());
79 let mut result = Self::new(max_len)?;
80
81 for (i, byte) in result.bytes.iter_mut().enumerate() {
82 *byte =
83 self.bytes.get(i).copied().unwrap_or(0) | other.bytes.get(i).copied().unwrap_or(0);
84 }
85 Ok(result)
86 }
87}
88
89impl Encode for Bitfield<Dynamic> {
90 fn is_ssz_fixed_len() -> bool {
91 false
92 }
93
94 fn ssz_bytes_len(&self) -> usize {
95 self.bytes.len()
96 }
97
98 fn ssz_append(&self, buf: &mut Vec<u8>) {
99 buf.extend_from_slice(&self.bytes)
100 }
101}
102
103impl Decode for Bitfield<Dynamic> {
104 fn is_ssz_fixed_len() -> bool {
105 false
106 }
107
108 fn from_ssz_bytes(bytes: &[u8]) -> Result<Self, DecodeError> {
109 if bytes.is_empty() {
110 return Err(DecodeError::BytesInvalid("Empty bytes".into()));
111 }
112
113 let len = bytes.len() * 8;
114 Self::from_raw_bytes(bytes.to_smallvec(), len).map_err(|e| {
115 DecodeError::BytesInvalid(format!("BitVectorDynamic failed to decode: {:?}", e))
116 })
117 }
118}
119
120impl Serialize for Bitfield<Dynamic> {
121 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
122 where
123 S: Serializer,
124 {
125 serializer.serialize_str(&hex_encode(self.as_ssz_bytes()))
126 }
127}
128
129impl<'de> Deserialize<'de> for Bitfield<Dynamic> {
130 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
131 where
132 D: Deserializer<'de>,
133 {
134 let bytes = deserializer.deserialize_str(PrefixedHexVisitor)?;
135 Self::from_ssz_bytes(&bytes)
136 .map_err(|e| serde::de::Error::custom(format!("BitVectorDynamic {:?}", e)))
137 }
138}
139
/// Unit tests for `BitVectorDynamic` construction, bit access, set operations and SSZ.
#[cfg(test)]
mod dynamic_bitfield_tests {
    use super::*;

    #[test]
    fn test_basic_operations() -> Result<(), Error> {
        let mut bitfield = BitVectorDynamic::new(16)?;

        assert!(bitfield.set(0, true).is_ok());
        assert!(bitfield.set(15, true).is_ok());
        // Index 16 is one past the last bit of a 16-bit field.
        assert!(bitfield.set(16, true).is_err());

        assert!(bitfield.get(0)?);
        assert!(bitfield.get(15)?);
        assert!(bitfield.get(16).is_err());

        Ok(())
    }

    #[test]
    fn test_ssz_encode_decode() -> Result<(), Error> {
        let mut bitfield = BitVectorDynamic::new(8)?;
        bitfield.set(0, true)?;
        bitfield.set(7, true)?;

        let bytes = bitfield.clone().into_bytes();
        let decoded = BitVectorDynamic::from_bytes_with_len(bytes, 8)?;
        assert_eq!(bitfield, decoded);

        // `len` must be exactly 8 bits per supplied byte.
        assert!(BitVectorDynamic::from_bytes_with_len(smallvec![], 8).is_err());
        assert!(BitVectorDynamic::from_bytes_with_len(smallvec![0, 0, 0], 8).is_err());

        Ok(())
    }

    #[test]
    fn test_ssz_decode_errors() {
        let empty_bytes: &[u8] = &[];
        assert!(matches!(
            BitVectorDynamic::from_ssz_bytes(empty_bytes),
            Err(DecodeError::BytesInvalid(msg)) if msg == "Empty bytes"
        ));
    }

    #[test]
    fn test_intersection() -> Result<(), Error> {
        let mut a = BitVectorDynamic::new(16)?;
        let mut b = BitVectorDynamic::new(16)?;
        let mut expected = BitVectorDynamic::new(16)?;

        a.set(1, true)?;
        a.set(3, true)?;
        b.set(3, true)?;
        b.set(4, true)?;
        expected.set(3, true)?;

        assert_eq!(a.intersection(&b)?, expected);
        assert_eq!(b.intersection(&a)?, expected);

        Ok(())
    }

    #[test]
    fn test_union() -> Result<(), Error> {
        let mut a = BitVectorDynamic::new(16)?;
        let mut b = BitVectorDynamic::new(16)?;
        let mut expected = BitVectorDynamic::new(16)?;

        a.set(1, true)?;
        a.set(3, true)?;
        b.set(3, true)?;
        b.set(4, true)?;

        expected.set(1, true)?;
        expected.set(3, true)?;
        expected.set(4, true)?;

        assert_eq!(a.union(&b)?, expected);
        assert_eq!(b.union(&a)?, expected);

        Ok(())
    }

    #[test]
    fn test_highest_set_bit() -> Result<(), Error> {
        let mut bitfield = BitVectorDynamic::new(16)?;
        assert_eq!(bitfield.highest_set_bit(), None);

        bitfield.set(3, true)?;
        assert_eq!(bitfield.highest_set_bit(), Some(3));

        bitfield.set(15, true)?;
        assert_eq!(bitfield.highest_set_bit(), Some(15));

        Ok(())
    }

    #[test]
    fn test_is_zero() -> Result<(), Error> {
        let mut bitfield = BitVectorDynamic::new(16)?;
        assert!(bitfield.is_zero());

        bitfield.set(0, true)?;
        assert!(!bitfield.is_zero());

        bitfield.set(0, false)?;
        assert!(bitfield.is_zero());

        Ok(())
    }

    #[test]
    fn test_num_set_bits() -> Result<(), Error> {
        let mut bitfield = BitVectorDynamic::new(16)?;
        assert_eq!(bitfield.num_set_bits(), 0);

        bitfield.set(1, true)?;
        bitfield.set(3, true)?;
        bitfield.set(7, true)?;
        assert_eq!(bitfield.num_set_bits(), 3);

        Ok(())
    }

    #[test]
    fn test_difference() -> Result<(), Error> {
        let mut a = BitVectorDynamic::new(16)?;
        let mut b = BitVectorDynamic::new(16)?;

        a.set(1, true)?;
        a.set(3, true)?;
        b.set(3, true)?;
        b.set(4, true)?;

        let diff = a.difference(&b);
        assert!(diff.get(1)?);
        assert!(!diff.get(3)?);
        assert!(!diff.get(4)?);

        Ok(())
    }

    #[test]
    fn test_shift_up() -> Result<(), Error> {
        let mut bitfield = BitVectorDynamic::new(16)?;
        bitfield.set(0, true)?;
        bitfield.set(1, true)?;

        bitfield.shift_up(1)?;
        assert!(!bitfield.get(0)?);
        assert!(bitfield.get(1)?);
        assert!(bitfield.get(2)?);

        // Shifting by more than the bit length is rejected.
        assert!(bitfield.shift_up(17).is_err());

        Ok(())
    }

    #[test]
    fn test_iter() -> Result<(), Error> {
        let mut bitfield = BitVectorDynamic::new(8)?;
        bitfield.set(1, true)?;
        bitfield.set(4, true)?;
        bitfield.set(7, true)?;

        let bits: Vec<bool> = bitfield.iter().collect();
        assert_eq!(
            bits,
            vec![false, true, false, false, true, false, false, true]
        );

        Ok(())
    }

    #[test]
    fn test_non_byte_aligned_lengths() {
        assert!(BitVectorDynamic::new(0).is_err());
        assert!(BitVectorDynamic::new(1).is_err());
        assert!(BitVectorDynamic::new(7).is_err());
        assert!(BitVectorDynamic::new(8).is_ok());
        assert!(BitVectorDynamic::new(9).is_err());
    }

    #[test]
    fn test_encode_decode() -> Result<(), Error> {
        let mut bitfield = BitVectorDynamic::new(32)?;
        bitfield.set(0, true)?;
        bitfield.set(16, true)?;
        bitfield.set(31, true)?;

        // Bit `i` lives in byte `i / 8` at position `i % 8` (least significant first).
        let expected: SmallVec<[u8; 4]> =
            smallvec![0b0000_0001, 0b0000_0000, 0b0000_0001, 0b1000_0000];
        let bytes = bitfield.clone().into_bytes();
        assert_eq!(bytes, expected);

        let encoded = bitfield.as_ssz_bytes();
        let decoded = BitVectorDynamic::from_ssz_bytes(&encoded)
            .expect("decoding freshly encoded bytes must succeed")
            .into_bytes();
        assert_eq!(bytes, decoded);

        Ok(())
    }

    #[test]
    fn test_from_bytes_equivalence() -> Result<(), Error> {
        let mut original = BitVectorDynamic::new(16)?;
        original.set(0, true)?;
        original.set(8, true)?;
        original.set(15, true)?;

        let bytes = original.clone().into_bytes();

        let from_bytes = BitVectorDynamic::from_bytes_with_len(bytes.clone(), 16)?;
        let from_ssz = BitVectorDynamic::from_ssz_bytes(&bytes)
            .expect("decoding raw bitfield bytes must succeed");

        assert_eq!(from_bytes, from_ssz);
        assert_eq!(from_bytes, original);

        Ok(())
    }

    #[test]
    fn test_from_bytes_length_validation() {
        let bytes = smallvec![0b1111_1111; 4];

        // 4 bytes hold 32 bits, so a requested length of 16 is rejected.
        assert!(matches!(
            BitVectorDynamic::from_bytes_with_len(bytes.clone(), 16),
            Err(Error::InvalidByteCount { .. })
        ));

        assert!(BitVectorDynamic::from_bytes_with_len(bytes, 32).is_ok());
    }

    #[test]
    fn test_ssz_bytes_len() -> Result<(), Error> {
        let bitfield = BitVectorDynamic::new(16)?;
        assert_eq!(bitfield.ssz_bytes_len(), 2);

        let bitfield = BitVectorDynamic::new(24)?;
        assert_eq!(bitfield.ssz_bytes_len(), 3);

        Ok(())
    }
}
/// Round-trip tests: SSZ encode → decode and serde JSON serialize → deserialize.
#[cfg(test)]
mod roundtrip_tests {
    use super::*;

    /// Encodes `t` to SSZ, decodes it back and asserts the copy equals the original.
    ///
    /// Panics on decode failure or mismatch. It always returns `Ok(())`; the `Result`
    /// only lets callers chain it with `?`.
    fn assert_round_trip_bitdyn<T>(t: T) -> Result<(), Error>
    where
        T: Encode + Decode + PartialEq + std::fmt::Debug,
    {
        let decoded = T::from_ssz_bytes(&t.as_ssz_bytes()).expect("decode failed in test");
        assert_eq!(decoded, t, "BitDyn must SSZ-roundtrip correctly.");
        Ok(())
    }

    #[test]
    fn bitdyn_ssz_round_trip() -> Result<(), Error> {
        // 8 bits with every even index set.
        let mut evens = BitVectorDynamic::new(8)?;
        for idx in (0..8).step_by(2) {
            evens.set(idx, true)?;
        }
        assert_round_trip_bitdyn(evens)?;

        // 16 bits, all set.
        let mut saturated = BitVectorDynamic::new(16)?;
        for idx in 0..16 {
            saturated.set(idx, true)?;
        }
        assert_round_trip_bitdyn(saturated)?;

        Ok(())
    }

    #[test]
    fn test_ssz_roundtrip_various_sizes() -> Result<(), Error> {
        assert_round_trip_bitdyn(BitVectorDynamic::new(8)?)?;

        let mut sparse = BitVectorDynamic::new(16)?;
        for idx in [0, 8, 15] {
            sparse.set(idx, true)?;
        }
        assert_round_trip_bitdyn(sparse)?;

        let mut all_24 = BitVectorDynamic::new(24)?;
        for idx in 0..24 {
            all_24.set(idx, true)?;
        }
        assert_round_trip_bitdyn(all_24)?;

        let mut checkerboard = BitVectorDynamic::new(32)?;
        for idx in 0..32 {
            checkerboard.set(idx, idx % 2 == 0)?;
        }
        assert_round_trip_bitdyn(checkerboard)?;

        let mut all_128 = BitVectorDynamic::new(128)?;
        for idx in 0..128 {
            all_128.set(idx, true)?;
        }
        assert_round_trip_bitdyn(all_128)?;

        Ok(())
    }

    #[test]
    fn test_serde_roundtrip() -> Result<(), Error> {
        let mut original = BitVectorDynamic::new(16)?;
        original.set(0, true)?;
        original.set(7, true)?;

        let mut buffer = Vec::new();
        original
            .serialize(&mut serde_json::ser::Serializer::new(&mut buffer))
            .expect("Serialization failed");

        let json = String::from_utf8(buffer).expect("UTF-8 error");
        let mut reader = serde_json::de::Deserializer::from_str(&json);
        let restored =
            BitVectorDynamic::deserialize(&mut reader).expect("Deserialization failed");

        assert_eq!(original, restored);

        Ok(())
    }
}