kona_protocol/batch/bits.rs

1//! Module for working with span batch bits.
2
3use crate::SpanBatchError;
4use alloc::{vec, vec::Vec};
5use alloy_primitives::bytes;
6use alloy_rlp::Buf;
7use core::cmp::Ordering;
8
/// Bitlist used by span batches, stored as a big-endian unsigned integer.
///
/// Bit `i` lives in the byte at position `len - 1 - i / 8` of the inner vector, at bit
/// position `i % 8` within that byte (so the last byte holds the least-significant bits).
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct SpanBatchBits(Vec<u8>);

impl AsRef<[u8]> for SpanBatchBits {
    /// Borrows the raw big-endian bytes backing the bitlist.
    fn as_ref(&self) -> &[u8] {
        self.0.as_slice()
    }
}
18
19impl SpanBatchBits {
20    /// Creates a new span batch bits.
21    pub const fn new(inner: Vec<u8>) -> Self {
22        Self(inner)
23    }
24
25    /// Decodes a standard span-batch bitlist from a reader.
26    /// The bitlist is encoded as big-endian integer, left-padded with zeroes to a multiple of 8
27    /// bits. The encoded bitlist cannot be longer than `bit_length`.
28    pub fn decode(b: &mut &[u8], bit_length: usize) -> Result<Self, SpanBatchError> {
29        let buffer_len = bit_length / 8 + if bit_length % 8 != 0 { 1 } else { 0 };
30        let bits = if b.len() < buffer_len {
31            let mut bits = vec![0; buffer_len];
32            bits[..b.len()].copy_from_slice(b);
33            b.advance(b.len());
34            bits
35        } else {
36            let v = b[..buffer_len].to_vec();
37            b.advance(buffer_len);
38            v
39        };
40        let sb_bits = Self(bits);
41
42        if sb_bits.bit_len() > bit_length {
43            return Err(SpanBatchError::BitfieldTooLong);
44        }
45
46        Ok(sb_bits)
47    }
48
49    /// Encodes a standard span-batch bitlist.
50    /// The bitlist is encoded as big-endian integer, left-padded with zeroes to a multiple of 8
51    /// bits. The encoded bitlist cannot be longer than `bit_length`
52    pub fn encode(
53        w: &mut dyn bytes::BufMut,
54        bit_length: usize,
55        bits: &Self,
56    ) -> Result<(), SpanBatchError> {
57        if bits.bit_len() > bit_length {
58            return Err(SpanBatchError::BitfieldTooLong);
59        }
60
61        // Round up, ensure enough bytes when number of bits is not a multiple of 8.
62        // Alternative of (L+7)/8 is not overflow-safe.
63        let buf_len = bit_length / 8 + if bit_length % 8 != 0 { 1 } else { 0 };
64        let mut buf = vec![0; buf_len];
65        buf[buf_len - bits.0.len()..].copy_from_slice(bits.as_ref());
66        w.put_slice(&buf);
67        Ok(())
68    }
69
70    /// Get a bit from the [`SpanBatchBits`] bitlist.
71    pub fn get_bit(&self, index: usize) -> Option<u8> {
72        let byte_index = index / 8;
73        let bit_index = index % 8;
74
75        // Check if the byte index is within the bounds of the bitlist
76        if byte_index < self.0.len() {
77            // Retrieve the specific byte that contains the bit we're interested in
78            let byte = self.0[self.0.len() - byte_index - 1];
79
80            // Shift the bits of the byte to the right, based on the bit index, and
81            // mask it with 1 to isolate the bit we're interested in.
82            // If the result is not zero, the bit is set to 1, otherwise it's 0.
83            Some(if byte & (1 << bit_index) != 0 { 1 } else { 0 })
84        } else {
85            // Return None if the index is out of bounds
86            None
87        }
88    }
89
90    /// Sets a bit in the [`SpanBatchBits`] bitlist.
91    pub fn set_bit(&mut self, index: usize, value: bool) {
92        let byte_index = index / 8;
93        let bit_index = index % 8;
94
95        // Ensure the vector is large enough to contain the bit at 'index'.
96        // If not, resize the vector, filling with 0s.
97        if byte_index >= self.0.len() {
98            Self::resize_from_right(&mut self.0, byte_index + 1);
99        }
100
101        // Retrieve the specific byte to modify
102        let len = self.0.len();
103        let byte = &mut self.0[len - byte_index - 1];
104
105        if value {
106            // Set the bit to 1
107            *byte |= 1 << bit_index;
108        } else {
109            // Set the bit to 0
110            *byte &= !(1 << bit_index);
111        }
112    }
113
114    /// Calculates the bit length of the [`SpanBatchBits`] bitfield.
115    pub fn bit_len(&self) -> usize {
116        // Iterate over the bytes from left to right to find the first non-zero byte
117        for (i, &byte) in self.0.iter().enumerate() {
118            if byte != 0 {
119                // Calculate the index of the most significant bit in the byte
120                let msb_index = 7 - byte.leading_zeros() as usize; // 0-based index
121
122                // Calculate the total bit length
123                let total_bit_length = msb_index + 1 + ((self.0.len() - i - 1) * 8);
124                return total_bit_length;
125            }
126        }
127
128        // If all bytes are zero, the bitlist is considered to have a length of 0
129        0
130    }
131
132    /// Resizes an array from the right. Useful for big-endian zero extension.
133    fn resize_from_right<T: Default + Clone>(vec: &mut Vec<T>, new_size: usize) {
134        let current_size = vec.len();
135        match new_size.cmp(&current_size) {
136            Ordering::Less => {
137                // Remove elements from the beginning.
138                let remove_count = current_size - new_size;
139                vec.drain(0..remove_count);
140            }
141            Ordering::Greater => {
142                // Calculate how many new elements to add.
143                let additional = new_size - current_size;
144                // Prepend new elements with default values.
145                let mut prepend_elements = vec![T::default(); additional];
146                prepend_elements.append(vec);
147                *vec = prepend_elements;
148            }
149            Ordering::Equal => { /* If new_size == current_size, do nothing. */ }
150        }
151    }
152}
153
#[cfg(test)]
mod test {
    use super::*;
    use proptest::{collection::vec, prelude::any, proptest};

    proptest! {
        // Decoding bytes at their full bit width and re-encoding must reproduce them exactly.
        #[test]
        fn test_encode_decode_roundtrip_span_bitlist(vec in vec(any::<u8>(), 0..5096)) {
            let bits = SpanBatchBits(vec);
            assert_eq!(SpanBatchBits::decode(&mut bits.as_ref(), bits.0.len() * 8).unwrap(), bits);
            let mut encoded = Vec::new();
            SpanBatchBits::encode(&mut encoded, bits.0.len() * 8, &bits).unwrap();
            assert_eq!(encoded, bits.0);
        }

        // Setting a single bit grows the vector to the minimal byte count and sets the bit length.
        #[test]
        fn test_span_bitlist_bitlen(index in 0usize..65536) {
            let mut bits = SpanBatchBits::default();
            bits.set_bit(index, true);
            assert_eq!(bits.0.len(), (index / 8) + 1);
            assert_eq!(bits.bit_len(), index + 1);
        }

        // Clearing a high bit keeps the allocated bytes but drops the bit length.
        #[test]
        fn test_span_bitlist_bitlen_shrink(first_index in 8usize..65536) {
            // At least one full byte below `first_index`; cannot underflow since first_index >= 8.
            let second_index = first_index - 8;
            let mut bits = SpanBatchBits::default();

            // Set and clear first index.
            bits.set_bit(first_index, true);
            assert_eq!(bits.0.len(), (first_index / 8) + 1);
            assert_eq!(bits.bit_len(), first_index + 1);
            bits.set_bit(first_index, false);
            assert_eq!(bits.0.len(), (first_index / 8) + 1);
            assert_eq!(bits.bit_len(), 0);

            // Set second bit. Even though the vector is larger, as it was originally allocated
            // with more bytes, the bit length should still be lowered as the higher-order bytes
            // are zeroed out.
            bits.set_bit(second_index, true);
            assert_eq!(bits.0.len(), (first_index / 8) + 1);
            assert_eq!(bits.bit_len(), second_index + 1);
        }
    }

    /// Bits are laid out big-endian: bit 0..7 live in the last byte, 8..15 in the one before.
    #[test]
    fn bitlist_big_endian_zero_extended() {
        let mut bits = SpanBatchBits::default();

        bits.set_bit(1, true);
        bits.set_bit(6, true);
        bits.set_bit(8, true);
        bits.set_bit(15, true);
        assert_eq!(bits.0[0], 0b1000_0001);
        assert_eq!(bits.0[1], 0b0100_0010);
        assert_eq!(bits.0.len(), 2);
        assert_eq!(bits.bit_len(), 16);
    }

    /// `get_bit` reflects `set_bit`, and returns `None` past the stored bytes.
    #[test]
    fn test_static_set_get_bits_span_bitlist() {
        let mut bits = SpanBatchBits::default();
        assert!(bits.0.is_empty());

        bits.set_bit(0, true);
        bits.set_bit(1, true);
        bits.set_bit(2, true);
        bits.set_bit(4, true);
        bits.set_bit(7, true);
        assert_eq!(bits.0.len(), 1);
        assert_eq!(bits.get_bit(0), Some(1));
        assert_eq!(bits.get_bit(1), Some(1));
        assert_eq!(bits.get_bit(2), Some(1));
        assert_eq!(bits.get_bit(3), Some(0));
        assert_eq!(bits.get_bit(4), Some(1));

        bits.set_bit(17, true);
        assert_eq!(bits.get_bit(17), Some(1));
        assert_eq!(bits.get_bit(32), None);
        assert_eq!(bits.0.len(), 3);
    }
}
234}