1#![doc = include_str!("../README.md")]
2
3pub use abstract_bits_derive::abstract_bits;
4pub use arbitrary_int::{u1, u2, u3, u4, u5, u6, u7};
5pub use bitvec;
6use bitvec::order::Lsb0;
7use bitvec::slice::BitSlice;
8
9mod error;
10pub use error::{FromBytesError, ReadErrorCause, ToBytesError};
11
12pub trait AbstractBits {
13 const MIN_BITS: usize;
14 const MAX_BITS: usize;
15 fn write_abstract_bits(&self, writer: &mut BitWriter) -> Result<(), ToBytesError>;
18 fn read_abstract_bits(reader: &mut BitReader) -> Result<Self, FromBytesError>
21 where
22 Self: Sized;
23
24 fn to_abstract_bits(&self) -> Result<Vec<u8>, ToBytesError> {
25 let needed_bytes = Self::MAX_BITS.div_ceil(8);
26 let mut buffer = vec![0u8; needed_bytes];
27 let mut writer = BitWriter::from(buffer.as_mut_slice());
28 self.write_abstract_bits(&mut writer)?;
29 let bytes = writer.bytes_written();
30 buffer.truncate(bytes);
31 Ok(buffer)
32 }
33
34 fn from_abstract_bits(bytes: &[u8]) -> Result<Self, FromBytesError>
35 where
36 Self: Sized,
37 {
38 let mut reader = BitReader::from(bytes);
39 Self::read_abstract_bits(&mut reader)
40 }
41}
42
43macro_rules! impl_abstract_bits_for_UInt {
44 ($base_type:ty, $write_method:ident, $read_method: ident) => {
45 impl<const N: usize> AbstractBits for arbitrary_int::UInt<$base_type, N> {
46 const MIN_BITS: usize = Self::BITS;
47 const MAX_BITS: usize = Self::BITS;
48
49 fn write_abstract_bits(
50 &self,
51 writer: &mut BitWriter,
52 ) -> Result<(), ToBytesError> {
53 writer
54 .$write_method(Self::BITS, self.value())
55 .map_err(|cause| ToBytesError::BufferTooSmall {
56 ty: std::any::type_name::<Self>(),
57 cause,
58 })
59 }
60
61 fn read_abstract_bits(reader: &mut BitReader) -> Result<Self, FromBytesError>
62 where
63 Self: Sized,
64 {
65 use FromBytesError::ReadPrimitive;
66 let value = reader.$read_method(Self::BITS).map_err(|cause| {
67 ReadPrimitive(ReadErrorCause::NotEnoughInput {
68 ty: std::any::type_name::<Self>(),
69 cause,
70 })
71 })?;
72 Ok(Self::new(value))
73 }
74 }
75 };
76}
77
78impl_abstract_bits_for_UInt! {u8, write_u8, read_u8}
79impl_abstract_bits_for_UInt! {u16, write_u16, read_u16}
80impl_abstract_bits_for_UInt! {u32, write_u32, read_u32}
81impl_abstract_bits_for_UInt! {u64, write_u64, read_u64}
82
83macro_rules! impl_abstract_bits_for_core_int {
84 ($type:ty, $write_method:ident, $read_method:ident, $bits:literal) => {
85 impl AbstractBits for $type {
86 const MIN_BITS: usize = core::mem::size_of::<Self>() * 8;
87 const MAX_BITS: usize = core::mem::size_of::<Self>() * 8;
88
89 fn write_abstract_bits(
90 &self,
91 writer: &mut BitWriter,
92 ) -> Result<(), ToBytesError> {
93 writer.$write_method($bits, *self).map_err(|cause| {
94 ToBytesError::BufferTooSmall {
95 ty: std::any::type_name::<Self>(),
96 cause,
97 }
98 })
99 }
100
101 fn read_abstract_bits(reader: &mut BitReader) -> Result<Self, FromBytesError>
102 where
103 Self: Sized,
104 {
105 use FromBytesError::ReadPrimitive;
106 reader.$read_method($bits).map_err(|cause| {
107 ReadPrimitive(ReadErrorCause::NotEnoughInput {
108 ty: std::any::type_name::<Self>(),
109 cause,
110 })
111 })
112 }
113 }
114 };
115}
116
117impl_abstract_bits_for_core_int! {u8, write_u8, read_u8, 8}
118impl_abstract_bits_for_core_int! {u16, write_u16, read_u16, 16}
119impl_abstract_bits_for_core_int! {u32, write_u32, read_u32, 32}
120impl_abstract_bits_for_core_int! {u64, write_u64, read_u64, 64}
121
122impl AbstractBits for bool {
123 const MIN_BITS: usize = 1;
124 const MAX_BITS: usize = 1;
125
126 fn write_abstract_bits(&self, writer: &mut BitWriter) -> Result<(), ToBytesError> {
127 writer
128 .write_bit(*self)
129 .map_err(|cause| ToBytesError::BufferTooSmall {
130 ty: core::any::type_name::<Self>(),
131 cause,
132 })
133 }
134
135 fn read_abstract_bits(reader: &mut BitReader) -> Result<Self, FromBytesError>
136 where
137 Self: Sized,
138 {
139 use FromBytesError::ReadPrimitive;
140 reader.read_bit().map_err(|cause| {
141 ReadPrimitive(ReadErrorCause::NotEnoughInput {
142 ty: core::any::type_name::<Self>(),
143 cause,
144 })
145 })
146 }
147}
148
149impl<const N: usize, T: AbstractBits + Sized> AbstractBits for [T; N] {
150 const MIN_BITS: usize = T::MIN_BITS * N;
151 const MAX_BITS: usize = T::MAX_BITS * N;
152
153 fn write_abstract_bits(&self, writer: &mut BitWriter) -> Result<(), ToBytesError> {
154 for element in self.iter() {
155 element.write_abstract_bits(writer)?;
156 }
157 Ok(())
158 }
159 fn read_abstract_bits(reader: &mut BitReader) -> Result<Self, FromBytesError>
160 where
161 Self: Sized,
162 {
163 let mut res = Vec::new();
164 for _ in 0..N {
165 res.push(T::read_abstract_bits(reader)?);
166 }
167
168 res.try_into()
169 .map_err(|_| unreachable!("for loop ensures vec length matches array's"))
170 }
171}
172
173pub struct BitReader<'a> {
174 pos: usize,
175 buf: &'a BitSlice<u8, Lsb0>,
176}
177
178#[derive(Debug, thiserror::Error, PartialEq, Eq)]
179#[error(
180 "Need to read beyond end of provided buffer to read {n_bits}. \
181 Buffer is missing {bits_needed} bits"
182)]
183pub struct UnexpectedEndOfBits {
184 n_bits: usize,
185 bits_needed: usize,
186}
187
// Generates `fn $name(&mut self, n_bits) -> Result<$ty, UnexpectedEndOfBits>` on
// `BitReader`. The generated function copies the next `n_bits` bits into the low
// bits of a zeroed integer (little-endian byte order, `Lsb0` bit order) and
// advances the cursor. If the buffer is too short, it leaves the cursor in place
// and reports the shortfall.
macro_rules! read_primitive {
    ($name:ident, $ty:ty) => {
        fn $name(&mut self, n_bits: usize) -> Result<$ty, UnexpectedEndOfBits> {
            let end = self.pos + n_bits;
            if end > self.buf.len() {
                return Err(UnexpectedEndOfBits {
                    n_bits,
                    bits_needed: end - self.buf.len(),
                });
            }

            let mut bytes = <$ty>::default().to_le_bytes();
            let dest = BitSlice::<_, Lsb0>::from_slice_mut(&mut bytes);
            dest[..n_bits].copy_from_bitslice(&self.buf[self.pos..end]);
            self.pos = end;
            Ok(<$ty>::from_le_bytes(bytes))
        }
    };
}
207
208impl BitReader<'_> {
209 pub fn bits_read(&self) -> usize {
210 self.pos
211 }
212 pub fn bytes_read(&self) -> usize {
214 self.pos.div_ceil(8)
215 }
216 pub fn skip(&mut self, n_bits: usize) -> Result<(), UnexpectedEndOfBits> {
217 if self.pos + n_bits > self.buf.len() {
218 Err(UnexpectedEndOfBits {
219 n_bits,
220 bits_needed: (self.pos + n_bits + 1) - self.buf.len(),
221 })
222 } else {
223 self.pos += n_bits;
224 Ok(())
225 }
226 }
227 fn read_bit(&mut self) -> Result<bool, UnexpectedEndOfBits> {
228 let Some(res) = self.buf.get(self.pos) else {
229 return Err(UnexpectedEndOfBits {
230 n_bits: 1,
231 bits_needed: 1,
232 });
233 };
234 self.pos += 1;
235 Ok(*res)
236 }
237
238 read_primitive! {read_u8, u8}
239 read_primitive! {read_u16, u16}
240 read_primitive! {read_u32, u32}
241 read_primitive! {read_u64, u64}
242}
243
244impl<'a> From<&'a [u8]> for BitReader<'a> {
245 fn from(bytes: &'a [u8]) -> Self {
246 Self {
247 pos: 0,
248 buf: BitSlice::from_slice(bytes),
249 }
250 }
251}
252
253pub struct BitWriter<'a> {
254 pos: usize,
255 buf: &'a mut BitSlice<u8, Lsb0>,
256}
257
258impl core::fmt::Debug for BitWriter<'_> {
259 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
260 f.write_str("BitWriter\n")?;
261 f.write_fmt(format_args!("\tpos: {}\n", self.pos))?;
262 f.write_fmt(format_args!("\tbuf: BitSlice of {} bits\n", self.buf.len()))
263 }
264}
265
266#[derive(Debug, thiserror::Error, PartialEq, Eq)]
267#[error(
268 "Buffer is too small to serialize `{n_bits}` into. \
269 Buffer needs to be at least {bits_needed} bits extra"
270)]
271pub struct BufferTooSmall {
272 n_bits: usize,
273 bits_needed: usize,
274}
275
// Generates `fn $name(&mut self, n_bits, val: $ty) -> Result<(), BufferTooSmall>`
// on `BitWriter`. The generated function writes the low `n_bits` bits of `val`
// (little-endian byte order, `Lsb0` bit order) at the cursor and advances it. If
// the buffer is too short, it leaves both the buffer and the cursor untouched
// and reports the shortfall.
macro_rules! write_primitive {
    ($name:ident, $ty:ty) => {
        fn $name(&mut self, n_bits: usize, val: $ty) -> Result<(), BufferTooSmall> {
            let end = self.pos + n_bits;
            if end > self.buf.len() {
                return Err(BufferTooSmall {
                    n_bits,
                    // Previously `len - end`, which always underflowed on this
                    // branch (panic in debug, garbage value in release).
                    bits_needed: end - self.buf.len(),
                });
            }

            let val = val.to_le_bytes();
            let val = BitSlice::<_, Lsb0>::from_slice(&val);
            self.buf[self.pos..end].copy_from_bitslice(&val[..n_bits]);
            self.pos = end;
            Ok(())
        }
    };
}
294
295impl BitWriter<'_> {
296 pub fn bits_written(&self) -> usize {
297 self.pos
298 }
299 pub fn bytes_written(&self) -> usize {
301 self.pos.div_ceil(8)
302 }
303 pub fn skip(&mut self, n_bits: usize) -> Result<(), BufferTooSmall> {
304 if self.pos + n_bits > self.buf.len() {
305 Err(BufferTooSmall {
306 n_bits,
307 bits_needed: (self.pos + n_bits + 1) - self.buf.len(),
308 })
309 } else {
310 self.pos += n_bits;
311 Ok(())
312 }
313 }
314 fn write_bit(&mut self, bit: bool) -> Result<(), BufferTooSmall> {
315 if self.pos >= self.buf.len() {
316 Err(BufferTooSmall {
317 n_bits: 1,
318 bits_needed: 1,
319 })
320 } else {
321 self.buf.set(self.pos, bit);
322 self.pos += 1;
323 Ok(())
324 }
325 }
326
327 write_primitive!(write_u8, u8);
328 write_primitive!(write_u16, u16);
329 write_primitive!(write_u32, u32);
330 write_primitive!(write_u64, u64);
331}
332
333impl<'a> From<&'a mut [u8]> for BitWriter<'a> {
334 fn from(buf: &'a mut [u8]) -> Self {
335 Self {
336 pos: 0,
337 buf: BitSlice::from_slice_mut(buf),
338 }
339 }
340}