kona_protocol/batch/
use crate::SpanBatchError;
use alloc::{vec, vec::Vec};
use alloy_primitives::bytes;
use alloy_rlp::Buf;
use core::cmp::Ordering;
8
/// A standard span-batch bitlist.
///
/// The bits are stored as a big-endian byte buffer: bit `0` lives in the
/// least-significant bit of the *last* byte, and the buffer is zero-extended
/// on the big (left) end as higher-index bits are set.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct SpanBatchBits(Vec<u8>);
12
13impl AsRef<[u8]> for SpanBatchBits {
14 fn as_ref(&self) -> &[u8] {
15 &self.0
16 }
17}
18
19impl SpanBatchBits {
20 pub const fn new(inner: Vec<u8>) -> Self {
22 Self(inner)
23 }
24
25 pub fn decode(b: &mut &[u8], bit_length: usize) -> Result<Self, SpanBatchError> {
29 let buffer_len = bit_length / 8 + if bit_length % 8 != 0 { 1 } else { 0 };
30 let bits = if b.len() < buffer_len {
31 let mut bits = vec![0; buffer_len];
32 bits[..b.len()].copy_from_slice(b);
33 b.advance(b.len());
34 bits
35 } else {
36 let v = b[..buffer_len].to_vec();
37 b.advance(buffer_len);
38 v
39 };
40 let sb_bits = Self(bits);
41
42 if sb_bits.bit_len() > bit_length {
43 return Err(SpanBatchError::BitfieldTooLong);
44 }
45
46 Ok(sb_bits)
47 }
48
49 pub fn encode(
53 w: &mut dyn bytes::BufMut,
54 bit_length: usize,
55 bits: &Self,
56 ) -> Result<(), SpanBatchError> {
57 if bits.bit_len() > bit_length {
58 return Err(SpanBatchError::BitfieldTooLong);
59 }
60
61 let buf_len = bit_length / 8 + if bit_length % 8 != 0 { 1 } else { 0 };
64 let mut buf = vec![0; buf_len];
65 buf[buf_len - bits.0.len()..].copy_from_slice(bits.as_ref());
66 w.put_slice(&buf);
67 Ok(())
68 }
69
70 pub fn get_bit(&self, index: usize) -> Option<u8> {
72 let byte_index = index / 8;
73 let bit_index = index % 8;
74
75 if byte_index < self.0.len() {
77 let byte = self.0[self.0.len() - byte_index - 1];
79
80 Some(if byte & (1 << bit_index) != 0 { 1 } else { 0 })
84 } else {
85 None
87 }
88 }
89
90 pub fn set_bit(&mut self, index: usize, value: bool) {
92 let byte_index = index / 8;
93 let bit_index = index % 8;
94
95 if byte_index >= self.0.len() {
98 Self::resize_from_right(&mut self.0, byte_index + 1);
99 }
100
101 let len = self.0.len();
103 let byte = &mut self.0[len - byte_index - 1];
104
105 if value {
106 *byte |= 1 << bit_index;
108 } else {
109 *byte &= !(1 << bit_index);
111 }
112 }
113
114 pub fn bit_len(&self) -> usize {
116 for (i, &byte) in self.0.iter().enumerate() {
118 if byte != 0 {
119 let msb_index = 7 - byte.leading_zeros() as usize; let total_bit_length = msb_index + 1 + ((self.0.len() - i - 1) * 8);
124 return total_bit_length;
125 }
126 }
127
128 0
130 }
131
132 fn resize_from_right<T: Default + Clone>(vec: &mut Vec<T>, new_size: usize) {
134 let current_size = vec.len();
135 match new_size.cmp(¤t_size) {
136 Ordering::Less => {
137 let remove_count = current_size - new_size;
139 vec.drain(0..remove_count);
140 }
141 Ordering::Greater => {
142 let additional = new_size - current_size;
144 let mut prepend_elements = vec![T::default(); additional];
146 prepend_elements.append(vec);
147 *vec = prepend_elements;
148 }
149 Ordering::Equal => { }
150 }
151 }
152}
153
#[cfg(test)]
mod test {
    use super::*;
    use proptest::{collection::vec, prelude::any, proptest};

    proptest! {
        // Round-trip: decoding a buffer and re-encoding it at the matching
        // bit length must reproduce the original bytes exactly.
        #[test]
        fn test_encode_decode_roundtrip_span_bitlist(vec in vec(any::<u8>(), 0..5096)) {
            let bits = SpanBatchBits(vec);
            assert_eq!(SpanBatchBits::decode(&mut bits.as_ref(), bits.0.len() * 8).unwrap(), bits);
            let mut encoded = Vec::new();
            SpanBatchBits::encode(&mut encoded, bits.0.len() * 8, &bits).unwrap();
            assert_eq!(encoded, bits.0);
        }

        // Setting a single bit grows the buffer to exactly the number of
        // bytes needed and reports a bit length of `index + 1`.
        #[test]
        fn test_span_bitlist_bitlen(index in 0usize..65536) {
            let mut bits = SpanBatchBits::default();
            bits.set_bit(index, true);
            assert_eq!(bits.0.len(), (index / 8) + 1);
            assert_eq!(bits.bit_len(), index + 1);
        }

        // Clearing the highest set bit drops `bit_len` to zero but never
        // shrinks the backing buffer; setting a lower-index bit afterwards
        // keeps the buffer at its grown size.
        #[test]
        fn test_span_bitlist_bitlen_shrink(first_index in 8usize..65536) {
            // `second_index` is one byte (8 bits) below `first_index`.
            let second_index = first_index.clamp(0, first_index - 8);
            let mut bits = SpanBatchBits::default();

            bits.set_bit(first_index, true);
            assert_eq!(bits.0.len(), (first_index / 8) + 1);
            assert_eq!(bits.bit_len(), first_index + 1);
            bits.set_bit(first_index, false);
            assert_eq!(bits.0.len(), (first_index / 8) + 1);
            assert_eq!(bits.bit_len(), 0);

            bits.set_bit(second_index, true);
            assert_eq!(bits.0.len(), (first_index / 8) + 1);
            assert_eq!(bits.bit_len(), second_index + 1);
        }
    }

    // Bits are indexed from the little end of a big-endian buffer; growing
    // past a byte boundary zero-extends on the big (left) end.
    #[test]
    fn bitlist_big_endian_zero_extended() {
        let mut bits = SpanBatchBits::default();

        bits.set_bit(1, true);
        bits.set_bit(6, true);
        bits.set_bit(8, true);
        bits.set_bit(15, true);
        assert_eq!(bits.0[0], 0b1000_0001);
        assert_eq!(bits.0[1], 0b0100_0010);
        assert_eq!(bits.0.len(), 2);
        assert_eq!(bits.bit_len(), 16);
    }

    // Basic set/get behavior, including out-of-range reads returning `None`
    // and buffer growth on out-of-range writes.
    #[test]
    fn test_static_set_get_bits_span_bitlist() {
        let mut bits = SpanBatchBits::default();
        assert!(bits.0.is_empty());

        bits.set_bit(0, true);
        bits.set_bit(1, true);
        bits.set_bit(2, true);
        bits.set_bit(4, true);
        bits.set_bit(7, true);
        assert_eq!(bits.0.len(), 1);
        assert_eq!(bits.get_bit(0), Some(1));
        assert_eq!(bits.get_bit(1), Some(1));
        assert_eq!(bits.get_bit(2), Some(1));
        assert_eq!(bits.get_bit(3), Some(0));
        assert_eq!(bits.get_bit(4), Some(1));

        bits.set_bit(17, true);
        assert_eq!(bits.get_bit(17), Some(1));
        assert_eq!(bits.get_bit(32), None);
        assert_eq!(bits.0.len(), 3);
    }
}
234}