abstract_bits/
lib.rs

1#![doc = include_str!("../README.md")]
2
3pub use abstract_bits_derive::abstract_bits;
4pub use arbitrary_int::{u1, u2, u3, u4, u5, u6, u7};
5pub use bitvec;
6use bitvec::order::Lsb0;
7use bitvec::slice::BitSlice;
8
9mod error;
10pub use error::{FromBytesError, ReadErrorCause, ToBytesError};
11
12pub trait AbstractBits {
13    const MIN_BITS: usize;
14    const MAX_BITS: usize;
15    /// To get the amount written use [`BitWriter::bits_written`]
16    /// or [`BitWriter::bytes_written`]
17    fn write_abstract_bits(&self, writer: &mut BitWriter) -> Result<(), ToBytesError>;
18    /// To get the amount read use [`BitReader::bits_read`]
19    /// or [`BitReader::bytes_read`]
20    fn read_abstract_bits(reader: &mut BitReader) -> Result<Self, FromBytesError>
21    where
22        Self: Sized;
23
24    fn to_abstract_bits(&self) -> Result<Vec<u8>, ToBytesError> {
25        let needed_bytes = Self::MAX_BITS.div_ceil(8);
26        let mut buffer = vec![0u8; needed_bytes];
27        let mut writer = BitWriter::from(buffer.as_mut_slice());
28        self.write_abstract_bits(&mut writer)?;
29        let bytes = writer.bytes_written();
30        buffer.truncate(bytes);
31        Ok(buffer)
32    }
33
34    fn from_abstract_bits(bytes: &[u8]) -> Result<Self, FromBytesError>
35    where
36        Self: Sized,
37    {
38        let mut reader = BitReader::from(bytes);
39        Self::read_abstract_bits(&mut reader)
40    }
41}
42
43macro_rules! impl_abstract_bits_for_UInt {
44    ($base_type:ty, $write_method:ident, $read_method: ident) => {
45        impl<const N: usize> AbstractBits for arbitrary_int::UInt<$base_type, N> {
46            const MIN_BITS: usize = Self::BITS;
47            const MAX_BITS: usize = Self::BITS;
48
49            fn write_abstract_bits(
50                &self,
51                writer: &mut BitWriter,
52            ) -> Result<(), ToBytesError> {
53                writer
54                    .$write_method(Self::BITS, self.value())
55                    .map_err(|cause| ToBytesError::BufferTooSmall {
56                        ty: std::any::type_name::<Self>(),
57                        cause,
58                    })
59            }
60
61            fn read_abstract_bits(reader: &mut BitReader) -> Result<Self, FromBytesError>
62            where
63                Self: Sized,
64            {
65                use FromBytesError::ReadPrimitive;
66                let value = reader.$read_method(Self::BITS).map_err(|cause| {
67                    ReadPrimitive(ReadErrorCause::NotEnoughInput {
68                        ty: std::any::type_name::<Self>(),
69                        cause,
70                    })
71                })?;
72                Ok(Self::new(value))
73            }
74        }
75    };
76}
77
78impl_abstract_bits_for_UInt! {u8, write_u8, read_u8}
79impl_abstract_bits_for_UInt! {u16, write_u16, read_u16}
80impl_abstract_bits_for_UInt! {u32, write_u32, read_u32}
81impl_abstract_bits_for_UInt! {u64, write_u64, read_u64}
82
83macro_rules! impl_abstract_bits_for_core_int {
84    ($type:ty, $write_method:ident, $read_method:ident, $bits:literal) => {
85        impl AbstractBits for $type {
86            const MIN_BITS: usize = core::mem::size_of::<Self>() * 8;
87            const MAX_BITS: usize = core::mem::size_of::<Self>() * 8;
88
89            fn write_abstract_bits(
90                &self,
91                writer: &mut BitWriter,
92            ) -> Result<(), ToBytesError> {
93                writer.$write_method($bits, *self).map_err(|cause| {
94                    ToBytesError::BufferTooSmall {
95                        ty: std::any::type_name::<Self>(),
96                        cause,
97                    }
98                })
99            }
100
101            fn read_abstract_bits(reader: &mut BitReader) -> Result<Self, FromBytesError>
102            where
103                Self: Sized,
104            {
105                use FromBytesError::ReadPrimitive;
106                reader.$read_method($bits).map_err(|cause| {
107                    ReadPrimitive(ReadErrorCause::NotEnoughInput {
108                        ty: std::any::type_name::<Self>(),
109                        cause,
110                    })
111                })
112            }
113        }
114    };
115}
116
117impl_abstract_bits_for_core_int! {u8, write_u8, read_u8, 8}
118impl_abstract_bits_for_core_int! {u16, write_u16, read_u16, 16}
119impl_abstract_bits_for_core_int! {u32, write_u32, read_u32, 32}
120impl_abstract_bits_for_core_int! {u64, write_u64, read_u64, 64}
121
122impl AbstractBits for bool {
123    const MIN_BITS: usize = 1;
124    const MAX_BITS: usize = 1;
125
126    fn write_abstract_bits(&self, writer: &mut BitWriter) -> Result<(), ToBytesError> {
127        writer
128            .write_bit(*self)
129            .map_err(|cause| ToBytesError::BufferTooSmall {
130                ty: core::any::type_name::<Self>(),
131                cause,
132            })
133    }
134
135    fn read_abstract_bits(reader: &mut BitReader) -> Result<Self, FromBytesError>
136    where
137        Self: Sized,
138    {
139        use FromBytesError::ReadPrimitive;
140        reader.read_bit().map_err(|cause| {
141            ReadPrimitive(ReadErrorCause::NotEnoughInput {
142                ty: core::any::type_name::<Self>(),
143                cause,
144            })
145        })
146    }
147}
148
149impl<const N: usize, T: AbstractBits + Sized> AbstractBits for [T; N] {
150    const MIN_BITS: usize = T::MIN_BITS * N;
151    const MAX_BITS: usize = T::MAX_BITS * N;
152
153    fn write_abstract_bits(&self, writer: &mut BitWriter) -> Result<(), ToBytesError> {
154        for element in self.iter() {
155            element.write_abstract_bits(writer)?;
156        }
157        Ok(())
158    }
159    fn read_abstract_bits(reader: &mut BitReader) -> Result<Self, FromBytesError>
160    where
161        Self: Sized,
162    {
163        let mut res = Vec::new();
164        for _ in 0..N {
165            res.push(T::read_abstract_bits(reader)?);
166        }
167
168        res.try_into()
169            .map_err(|_| unreachable!("for loop ensures vec length matches array's"))
170    }
171}
172
/// Cursor over a borrowed bit slice that hands out values bit by bit.
///
/// Created from a byte slice through `From<&[u8]>`.
pub struct BitReader<'a> {
    /// Index of the next unread bit in `buf`.
    pos: usize,
    /// The bits being read, viewed in `Lsb0` order over the source bytes.
    buf: &'a BitSlice<u8, Lsb0>,
}
177
/// Error produced by [`BitReader`] when a read or skip would go past the end
/// of its buffer.
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
#[error(
    "Need to read beyond end of provided buffer to read {n_bits}. \
    Buffer is missing {bits_needed} bits"
)]
pub struct UnexpectedEndOfBits {
    /// Number of bits the failed operation asked for.
    n_bits: usize,
    /// Number of bits reported as missing from the buffer.
    bits_needed: usize,
}
187
// Generates a `BitReader` method that reads the next `n_bits` bits into the
// low bits of a zero-initialised `$ty` and advances the cursor. The position
// is left untouched when the buffer does not hold enough bits.
macro_rules! read_primitive {
    ($name:ident, $ty:ty) => {
        fn $name(&mut self, n_bits: usize) -> Result<$ty, UnexpectedEndOfBits> {
            let end = self.pos + n_bits;
            let available = self.buf.len();
            if end > available {
                return Err(UnexpectedEndOfBits {
                    n_bits,
                    bits_needed: end - available,
                });
            }

            // Start from all-zero bytes so bits beyond `n_bits` stay cleared.
            let mut raw = <$ty>::default().to_le_bytes();
            let target = BitSlice::<_, Lsb0>::from_slice_mut(&mut raw);
            target[..n_bits].copy_from_bitslice(&self.buf[self.pos..end]);
            self.pos = end;
            Ok(<$ty>::from_le_bytes(raw))
        }
    };
}
207
208impl BitReader<'_> {
209    pub fn bits_read(&self) -> usize {
210        self.pos
211    }
212    /// 12 bits read corresponds to 2 bytes read
213    pub fn bytes_read(&self) -> usize {
214        self.pos.div_ceil(8)
215    }
216    pub fn skip(&mut self, n_bits: usize) -> Result<(), UnexpectedEndOfBits> {
217        if self.pos + n_bits > self.buf.len() {
218            Err(UnexpectedEndOfBits {
219                n_bits,
220                bits_needed: (self.pos + n_bits + 1) - self.buf.len(),
221            })
222        } else {
223            self.pos += n_bits;
224            Ok(())
225        }
226    }
227    fn read_bit(&mut self) -> Result<bool, UnexpectedEndOfBits> {
228        let Some(res) = self.buf.get(self.pos) else {
229            return Err(UnexpectedEndOfBits {
230                n_bits: 1,
231                bits_needed: 1,
232            });
233        };
234        self.pos += 1;
235        Ok(*res)
236    }
237
238    read_primitive! {read_u8, u8}
239    read_primitive! {read_u16, u16}
240    read_primitive! {read_u32, u32}
241    read_primitive! {read_u64, u64}
242}
243
244impl<'a> From<&'a [u8]> for BitReader<'a> {
245    fn from(bytes: &'a [u8]) -> Self {
246        Self {
247            pos: 0,
248            buf: BitSlice::from_slice(bytes),
249        }
250    }
251}
252
/// Cursor over a mutably borrowed bit slice that values are written into bit
/// by bit.
///
/// Created from a mutable byte slice through `From<&mut [u8]>`.
pub struct BitWriter<'a> {
    /// Index of the next bit in `buf` that will be written.
    pos: usize,
    /// The destination bits, viewed in `Lsb0` order over the target bytes.
    buf: &'a mut BitSlice<u8, Lsb0>,
}
257
258impl core::fmt::Debug for BitWriter<'_> {
259    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
260        f.write_str("BitWriter\n")?;
261        f.write_fmt(format_args!("\tpos: {}\n", self.pos))?;
262        f.write_fmt(format_args!("\tbuf: BitSlice of {} bits\n", self.buf.len()))
263    }
264}
265
/// Error produced by [`BitWriter`] when a write or skip would go past the end
/// of its buffer.
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
#[error(
    "Buffer is too small to serialize `{n_bits}` into. \
    Buffer needs to be at least {bits_needed} bits extra"
)]
pub struct BufferTooSmall {
    /// Number of bits the failed operation tried to write or skip.
    n_bits: usize,
    /// Number of additional bits reported as required in the buffer.
    bits_needed: usize,
}
275
// Generates a `BitWriter` method that copies the low `n_bits` bits of `val`
// into the buffer at the current position and advances the cursor.
//
// When the buffer has fewer than `n_bits` bits left the method returns
// `BufferTooSmall` without writing anything; `bits_needed` is the number of
// bits by which the buffer falls short.
macro_rules! write_primitive {
    ($name:ident, $ty:ty) => {
        fn $name(&mut self, n_bits: usize, val: $ty) -> Result<(), BufferTooSmall> {
            let end = self.pos + n_bits;
            let capacity = self.buf.len();
            if end > capacity {
                // `end` is the larger operand in this branch. Subtracting the
                // other way round underflows, which panics in debug builds and
                // wraps to a huge value in release builds.
                return Err(BufferTooSmall {
                    n_bits,
                    bits_needed: end - capacity,
                });
            }

            let val = val.to_le_bytes();
            let val = BitSlice::<_, Lsb0>::from_slice(&val);
            self.buf[self.pos..end].copy_from_bitslice(&val[..n_bits]);
            self.pos = end;
            Ok(())
        }
    };
}
294
295impl BitWriter<'_> {
296    pub fn bits_written(&self) -> usize {
297        self.pos
298    }
299    /// 12 bits read corresponds to 2 bytes read
300    pub fn bytes_written(&self) -> usize {
301        self.pos.div_ceil(8)
302    }
303    pub fn skip(&mut self, n_bits: usize) -> Result<(), BufferTooSmall> {
304        if self.pos + n_bits > self.buf.len() {
305            Err(BufferTooSmall {
306                n_bits,
307                bits_needed: (self.pos + n_bits + 1) - self.buf.len(),
308            })
309        } else {
310            self.pos += n_bits;
311            Ok(())
312        }
313    }
314    fn write_bit(&mut self, bit: bool) -> Result<(), BufferTooSmall> {
315        if self.pos >= self.buf.len() {
316            Err(BufferTooSmall {
317                n_bits: 1,
318                bits_needed: 1,
319            })
320        } else {
321            self.buf.set(self.pos, bit);
322            self.pos += 1;
323            Ok(())
324        }
325    }
326
327    write_primitive!(write_u8, u8);
328    write_primitive!(write_u16, u16);
329    write_primitive!(write_u32, u32);
330    write_primitive!(write_u64, u64);
331}
332
333impl<'a> From<&'a mut [u8]> for BitWriter<'a> {
334    fn from(buf: &'a mut [u8]) -> Self {
335        Self {
336            pos: 0,
337            buf: BitSlice::from_slice_mut(buf),
338        }
339    }
340}