1use crate::{
4 bitfield::{bytes_for_bit_len, Bitfield, BitfieldBehaviour, Error, SMALLVEC_LEN},
5 Decode, DecodeError, Encode,
6};
7use core::marker::PhantomData;
8use serde::de::{Deserialize, Deserializer};
9use serde::ser::{Serialize, Serializer};
10use serde_utils::hex::{encode as hex_encode, PrefixedHexVisitor};
11use smallvec::{smallvec, SmallVec, ToSmallVec};
12
13#[derive(Clone, PartialEq, Eq, Debug)]
15pub struct Dynamic;
16
17pub type BitVectorDynamic = Bitfield<Dynamic>;
19
20impl BitfieldBehaviour for Dynamic {}
21
22impl Bitfield<Dynamic> {
23 pub fn new(len: usize) -> Result<Self, Error> {
26 if len == 0 {
27 return Err(Error::InvalidByteCount {
28 given: len,
29 expected: 8,
30 });
31 }
32 if !len.is_multiple_of(8) {
33 return Err(Error::InvalidByteCount {
34 given: len,
35 expected: (len / 8 + 1) * 8,
36 });
37 }
38 Ok(Self {
39 bytes: smallvec![0; bytes_for_bit_len(len)],
40 len,
41 _phantom: PhantomData,
42 })
43 }
44
45 pub fn into_bytes(self) -> SmallVec<[u8; SMALLVEC_LEN]> {
46 self.into_raw_bytes()
47 }
48
49 pub fn from_bytes_with_len(
52 bytes: SmallVec<[u8; SMALLVEC_LEN]>,
53 len: usize,
54 ) -> Result<Self, Error> {
55 if len != bytes.len() * 8 {
56 return Err(Error::InvalidByteCount {
57 given: len,
58 expected: bytes.len() * 8,
59 });
60 }
61 Self::from_raw_bytes(bytes, len)
62 }
63
64 pub fn intersection(&self, other: &Self) -> Result<Self, Error> {
66 let max_len = std::cmp::max(self.len(), other.len());
67 let mut result = Self::new(max_len)?;
68
69 for (i, byte) in result.bytes.iter_mut().enumerate() {
70 *byte =
71 self.bytes.get(i).copied().unwrap_or(0) & other.bytes.get(i).copied().unwrap_or(0);
72 }
73 Ok(result)
74 }
75
76 pub fn union(&self, other: &Self) -> Result<Self, Error> {
78 let max_len = std::cmp::max(self.len(), other.len());
79 let mut result = Self::new(max_len)?;
80
81 for (i, byte) in result.bytes.iter_mut().enumerate() {
82 *byte =
83 self.bytes.get(i).copied().unwrap_or(0) | other.bytes.get(i).copied().unwrap_or(0);
84 }
85 Ok(result)
86 }
87}
88
89impl Encode for Bitfield<Dynamic> {
90 fn is_ssz_fixed_len() -> bool {
91 false
92 }
93
94 fn ssz_bytes_len(&self) -> usize {
95 self.bytes.len()
96 }
97
98 fn ssz_append(&self, buf: &mut Vec<u8>) {
99 buf.extend_from_slice(&self.bytes)
100 }
101}
102
103impl Decode for Bitfield<Dynamic> {
104 fn is_ssz_fixed_len() -> bool {
105 false
106 }
107
108 fn from_ssz_bytes(bytes: &[u8]) -> Result<Self, DecodeError> {
109 if bytes.is_empty() {
110 return Err(DecodeError::BytesInvalid("Empty bytes".into()));
111 }
112
113 let len = bytes.len() * 8;
114 Self::from_raw_bytes(bytes.to_smallvec(), len).map_err(|e| {
115 DecodeError::BytesInvalid(format!("BitVectorDynamic failed to decode: {:?}", e))
116 })
117 }
118}
119
120impl Serialize for Bitfield<Dynamic> {
121 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
122 where
123 S: Serializer,
124 {
125 serializer.serialize_str(&hex_encode(self.as_ssz_bytes()))
126 }
127}
128
129impl<'de> Deserialize<'de> for Bitfield<Dynamic> {
130 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
131 where
132 D: Deserializer<'de>,
133 {
134 let bytes = deserializer.deserialize_str(PrefixedHexVisitor)?;
135 Self::from_ssz_bytes(&bytes)
136 .map_err(|e| serde::de::Error::custom(format!("BitVectorDynamic {:?}", e)))
137 }
138}
139
#[cfg(test)]
mod dynamic_bitfield_tests {
    use super::*;

    #[test]
    fn test_basic_operations() -> Result<(), Error> {
        let mut bitfield = BitVectorDynamic::new(16)?;

        assert!(bitfield.set(0, true).is_ok());
        assert!(bitfield.set(15, true).is_ok());
        // Index 16 is one past the end of a 16-bit field.
        assert!(bitfield.set(16, true).is_err());

        assert!(bitfield.get(0)?);
        assert!(bitfield.get(15)?);
        assert!(bitfield.get(16).is_err());

        Ok(())
    }

    #[test]
    fn test_ssz_encode_decode() -> Result<(), Error> {
        let mut bitfield = BitVectorDynamic::new(8)?;
        bitfield.set(0, true)?;
        bitfield.set(7, true)?;

        let bytes = bitfield.clone().into_bytes();
        let decoded = BitVectorDynamic::from_bytes_with_len(bytes, 8)?;

        assert_eq!(bitfield, decoded);

        // Byte count must match `len / 8` exactly.
        assert!(BitVectorDynamic::from_bytes_with_len(smallvec![], 8).is_err());
        assert!(BitVectorDynamic::from_bytes_with_len(smallvec![0, 0, 0], 8).is_err());

        Ok(())
    }

    #[test]
    fn test_ssz_decode_errors() {
        let empty_bytes: &[u8] = &[];
        assert!(matches!(
            BitVectorDynamic::from_ssz_bytes(empty_bytes),
            Err(DecodeError::BytesInvalid(msg)) if msg == "Empty bytes"
        ));
    }

    #[test]
    fn test_intersection() -> Result<(), Error> {
        let mut a = BitVectorDynamic::new(16)?;
        let mut b = BitVectorDynamic::new(16)?;
        let mut expected = BitVectorDynamic::new(16)?;

        a.set(1, true)?;
        a.set(3, true)?;
        b.set(3, true)?;
        b.set(4, true)?;
        expected.set(3, true)?;

        assert_eq!(a.intersection(&b)?, expected);
        assert_eq!(b.intersection(&a)?, expected);

        Ok(())
    }

    #[test]
    fn test_union() -> Result<(), Error> {
        let mut a = BitVectorDynamic::new(16)?;
        let mut b = BitVectorDynamic::new(16)?;
        let mut expected = BitVectorDynamic::new(16)?;

        a.set(1, true)?;
        a.set(3, true)?;
        b.set(3, true)?;
        b.set(4, true)?;

        expected.set(1, true)?;
        expected.set(3, true)?;
        expected.set(4, true)?;

        assert_eq!(a.union(&b)?, expected);
        assert_eq!(b.union(&a)?, expected);

        Ok(())
    }

    #[test]
    fn test_highest_set_bit() -> Result<(), Error> {
        let mut bitfield = BitVectorDynamic::new(16)?;
        assert_eq!(bitfield.highest_set_bit(), None);

        bitfield.set(3, true)?;
        assert_eq!(bitfield.highest_set_bit(), Some(3));

        bitfield.set(15, true)?;
        assert_eq!(bitfield.highest_set_bit(), Some(15));

        Ok(())
    }

    #[test]
    fn test_is_zero() -> Result<(), Error> {
        let mut bitfield = BitVectorDynamic::new(16)?;
        assert!(bitfield.is_zero());

        bitfield.set(0, true)?;
        assert!(!bitfield.is_zero());

        bitfield.set(0, false)?;
        assert!(bitfield.is_zero());

        Ok(())
    }

    #[test]
    fn test_num_set_bits() -> Result<(), Error> {
        let mut bitfield = BitVectorDynamic::new(16)?;
        assert_eq!(bitfield.num_set_bits(), 0);

        bitfield.set(1, true)?;
        bitfield.set(3, true)?;
        bitfield.set(7, true)?;
        assert_eq!(bitfield.num_set_bits(), 3);

        Ok(())
    }

    #[test]
    fn test_difference() -> Result<(), Error> {
        let mut a = BitVectorDynamic::new(16)?;
        let mut b = BitVectorDynamic::new(16)?;

        a.set(1, true)?;
        a.set(3, true)?;
        b.set(3, true)?;
        b.set(4, true)?;

        let diff = a.difference(&b);
        assert!(diff.get(1)?);
        assert!(!diff.get(3)?);
        assert!(!diff.get(4)?);

        Ok(())
    }

    #[test]
    fn test_shift_up() -> Result<(), Error> {
        let mut bitfield = BitVectorDynamic::new(16)?;
        bitfield.set(0, true)?;
        bitfield.set(1, true)?;

        bitfield.shift_up(1)?;
        assert!(!bitfield.get(0)?);
        assert!(bitfield.get(1)?);
        assert!(bitfield.get(2)?);

        // Shifting by more than the bit length is rejected.
        assert!(bitfield.shift_up(17).is_err());

        Ok(())
    }

    #[test]
    fn test_iter() -> Result<(), Error> {
        let mut bitfield = BitVectorDynamic::new(8)?;
        bitfield.set(1, true)?;
        bitfield.set(4, true)?;
        bitfield.set(7, true)?;

        let bits: Vec<bool> = bitfield.iter().collect();
        assert_eq!(
            bits,
            vec![false, true, false, false, true, false, false, true]
        );

        Ok(())
    }

    #[test]
    fn test_non_byte_aligned_lengths() {
        assert!(BitVectorDynamic::new(0).is_err());
        assert!(BitVectorDynamic::new(1).is_err());
        assert!(BitVectorDynamic::new(7).is_err());
        assert!(BitVectorDynamic::new(8).is_ok());
        assert!(BitVectorDynamic::new(9).is_err());
    }

    #[test]
    fn test_encode_decode() -> Result<(), Error> {
        let mut bitfield = BitVectorDynamic::new(32)?;
        bitfield.set(0, true)?;
        bitfield.set(16, true)?;
        bitfield.set(31, true)?;

        // Bit `i` lives in byte `i / 8`, at position `i % 8` counted from the LSB.
        let expected: SmallVec<[u8; 4]> =
            smallvec![0b0000_0001, 0b0000_0000, 0b0000_0001, 0b1000_0000];
        let bytes = bitfield.clone().into_bytes();
        assert_eq!(bytes, expected);

        let encoded = bitfield.as_ssz_bytes();
        // Surface the real `DecodeError` on failure instead of a fabricated `Error`.
        let decoded = BitVectorDynamic::from_ssz_bytes(&encoded)
            .expect("valid SSZ bytes must decode")
            .into_bytes();

        assert_eq!(bytes, decoded);

        Ok(())
    }

    #[test]
    fn test_from_bytes_equivalence() -> Result<(), Error> {
        let mut original = BitVectorDynamic::new(16)?;
        original.set(0, true)?;
        original.set(8, true)?;
        original.set(15, true)?;

        let bytes = original.clone().into_bytes();

        let from_bytes = BitVectorDynamic::from_bytes_with_len(bytes.clone(), 16)?;
        let from_ssz =
            BitVectorDynamic::from_ssz_bytes(&bytes).expect("valid SSZ bytes must decode");

        assert_eq!(from_bytes, from_ssz);
        assert_eq!(from_bytes, original);

        Ok(())
    }

    #[test]
    fn test_from_bytes_length_validation() {
        let bytes = smallvec![0b1111_1111; 4];

        // 4 bytes hold 32 bits, so claiming 16 is a mismatch.
        assert!(matches!(
            BitVectorDynamic::from_bytes_with_len(bytes.clone(), 16),
            Err(Error::InvalidByteCount { .. })
        ));

        assert!(BitVectorDynamic::from_bytes_with_len(bytes, 32).is_ok());
    }

    #[test]
    fn test_not() -> Result<(), Error> {
        let a = BitVectorDynamic::new(8)?;
        let mut expected = BitVectorDynamic::new(8)?;
        for i in 0..8 {
            expected.set(i, true)?;
        }
        assert_eq!(a.not(), expected);

        let b = expected.clone();
        assert_eq!(b.not(), BitVectorDynamic::new(8)?);

        let c = BitVectorDynamic::from_bytes_with_len(smallvec![0b1100_1010, 0b0011_0101], 16)?;
        let expected_c =
            BitVectorDynamic::from_bytes_with_len(smallvec![0b0011_0101, 0b1100_1010], 16)?;
        assert_eq!(c.not(), expected_c);

        let mut d = BitVectorDynamic::new(16)?;
        d.set(0, true)?;
        d.set(5, true)?;
        d.set(15, true)?;
        assert_eq!(d.not().not(), d);

        Ok(())
    }

    #[test]
    fn test_not_inplace() -> Result<(), Error> {
        let mut a = BitVectorDynamic::new(8)?;
        a.not_inplace();
        let mut expected = BitVectorDynamic::new(8)?;
        for i in 0..8 {
            expected.set(i, true)?;
        }
        assert_eq!(a, expected);

        let mut b = expected.clone();
        b.not_inplace();
        assert_eq!(b, BitVectorDynamic::new(8)?);

        let mut c =
            BitVectorDynamic::from_bytes_with_len(smallvec![0b1100_1010, 0b0011_0101], 16)?;
        c.not_inplace();
        let expected_c =
            BitVectorDynamic::from_bytes_with_len(smallvec![0b0011_0101, 0b1100_1010], 16)?;
        assert_eq!(c, expected_c);

        let mut d = BitVectorDynamic::new(16)?;
        d.set(0, true)?;
        d.set(5, true)?;
        d.set(15, true)?;
        let original = d.clone();
        d.not_inplace();
        d.not_inplace();
        assert_eq!(d, original);

        Ok(())
    }

    #[test]
    fn test_ssz_bytes_len() -> Result<(), Error> {
        let bitfield = BitVectorDynamic::new(16)?;
        assert_eq!(bitfield.ssz_bytes_len(), 2);

        let bitfield = BitVectorDynamic::new(24)?;
        assert_eq!(bitfield.ssz_bytes_len(), 3);

        Ok(())
    }
}
#[cfg(test)]
mod roundtrip_tests {
    use super::*;

    /// Asserts that `value` survives an SSZ encode -> decode round trip unchanged.
    ///
    /// Panics (failing the calling test) if decoding fails or the decoded value differs.
    fn assert_round_trip_bitdyn<T>(value: T)
    where
        T: Encode + Decode + PartialEq + std::fmt::Debug,
    {
        let bytes = value.as_ssz_bytes();
        let decoded = T::from_ssz_bytes(&bytes).expect("decode failed in test");
        assert_eq!(decoded, value, "BitDyn must SSZ-roundtrip correctly.");
    }

    #[test]
    fn bitdyn_ssz_round_trip() -> Result<(), Error> {
        let mut even_bits = BitVectorDynamic::new(8)?;
        for j in (0..8).step_by(2) {
            even_bits.set(j, true)?;
        }
        assert_round_trip_bitdyn(even_bits);

        let mut all_set = BitVectorDynamic::new(16)?;
        for j in 0..16 {
            all_set.set(j, true)?;
        }
        assert_round_trip_bitdyn(all_set);

        Ok(())
    }

    #[test]
    fn test_ssz_roundtrip_various_sizes() -> Result<(), Error> {
        let all_zero = BitVectorDynamic::new(8)?;
        assert_round_trip_bitdyn(all_zero);

        let mut partial = BitVectorDynamic::new(16)?;
        partial.set(0, true)?;
        partial.set(8, true)?;
        partial.set(15, true)?;
        assert_round_trip_bitdyn(partial);

        let mut full_24 = BitVectorDynamic::new(24)?;
        for i in 0..24 {
            full_24.set(i, true)?;
        }
        assert_round_trip_bitdyn(full_24);

        let mut alternating = BitVectorDynamic::new(32)?;
        for i in 0..32 {
            alternating.set(i, i % 2 == 0)?;
        }
        assert_round_trip_bitdyn(alternating);

        let mut full_128 = BitVectorDynamic::new(128)?;
        for i in 0..128 {
            full_128.set(i, true)?;
        }
        assert_round_trip_bitdyn(full_128);

        Ok(())
    }

    #[test]
    fn test_serde_roundtrip() -> Result<(), Error> {
        let mut b = BitVectorDynamic::new(16)?;
        b.set(0, true)?;
        b.set(7, true)?;

        let json = serde_json::to_string(&b).expect("Serialization failed");
        let deserialized: BitVectorDynamic =
            serde_json::from_str(&json).expect("Deserialization failed");

        assert_eq!(b, deserialized);

        Ok(())
    }
}