1use core::fmt;
29
/// Errors produced by [`BitReader`] and [`BitWriter`].
///
/// The enum is `#[non_exhaustive]`, so `match`es outside this crate need a
/// wildcard arm, and new failure modes can be added without a breaking change.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum BitError {
    /// The operation needs more bits than the buffer has left.
    OutOfBounds {
        /// Number of bits the operation tried to consume or produce.
        needed_bits: usize,
        /// Number of bits still available at the current position.
        remaining_bits: usize,
    },
    /// A single read or write asked for more than 64 bits, the width of the
    /// `u64` that carries field values.
    TooManyBits {
        /// The bit count that was requested.
        requested: u32,
    },
    /// A value handed to a write has set bits at or above the requested width.
    ValueTooWide {
        /// The offending value.
        value: u64,
        /// The field width it was supposed to fit in.
        bits: u32,
    },
}

impl fmt::Display for BitError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Every payload field is `Copy`, so matching on `*self` binds plain values.
        match *self {
            Self::OutOfBounds { needed_bits, remaining_bits } => write!(
                f,
                "bit buffer out of bounds: need {needed_bits} bit(s), {remaining_bits} remaining"
            ),
            Self::TooManyBits { requested } => {
                write!(f, "requested {requested} bits exceeds the 64-bit carrier")
            }
            Self::ValueTooWide { value, bits } => {
                write!(f, "value {value:#x} does not fit in {bits} bit(s)")
            }
        }
    }
}

// No wrapped cause, so the default `source()` (returning `None`) applies.
impl std::error::Error for BitError {}
76
/// Cursor that reads MSB-first bit fields out of a borrowed byte slice.
///
/// Bit 0 of the stream is the most significant bit of `data[0]`. Cloning a
/// reader yields an independent cursor over the same bytes, e.g. to look
/// ahead without disturbing the original.
#[derive(Debug, Clone)]
pub struct BitReader<'a> {
    /// Bytes being decoded; the reader never modifies them.
    data: &'a [u8],
    /// Absolute index of the next bit to read, counted from the start of `data`.
    bit_pos: usize,
}
84
85impl<'a> BitReader<'a> {
86 #[must_use]
88 pub fn new(data: &'a [u8]) -> Self {
89 Self { data, bit_pos: 0 }
90 }
91
92 #[must_use]
94 pub fn total_bits(&self) -> usize {
95 self.data.len() * 8
96 }
97
98 #[must_use]
100 pub fn bits_read(&self) -> usize {
101 self.bit_pos
102 }
103
104 #[must_use]
106 pub fn bits_remaining(&self) -> usize {
107 self.total_bits() - self.bit_pos
108 }
109
110 #[must_use]
112 pub fn is_byte_aligned(&self) -> bool {
113 self.bit_pos % 8 == 0
114 }
115
116 pub fn read_bits(&mut self, n: u32) -> Result<u64, BitError> {
124 if n > 64 {
125 return Err(BitError::TooManyBits { requested: n });
126 }
127 if n == 0 {
128 return Ok(0);
129 }
130 let need = n as usize;
131 let remaining = self.bits_remaining();
132 if need > remaining {
133 return Err(BitError::OutOfBounds {
134 needed_bits: need,
135 remaining_bits: remaining,
136 });
137 }
138 let mut value: u64 = 0;
139 for _ in 0..n {
140 let byte = self.data[self.bit_pos / 8];
141 let bit_index = 7 - (self.bit_pos % 8); let bit = (byte >> bit_index) & 1;
143 value = (value << 1) | u64::from(bit);
144 self.bit_pos += 1;
145 }
146 Ok(value)
147 }
148
149 pub fn read_bool(&mut self) -> Result<bool, BitError> {
154 Ok(self.read_bits(1)? != 0)
155 }
156
157 pub fn skip_bits(&mut self, n: usize) -> Result<(), BitError> {
162 let remaining = self.bits_remaining();
163 if n > remaining {
164 return Err(BitError::OutOfBounds {
165 needed_bits: n,
166 remaining_bits: remaining,
167 });
168 }
169 self.bit_pos += n;
170 Ok(())
171 }
172
173 pub fn align_to_byte(&mut self) {
175 let rem = self.bit_pos % 8;
176 if rem != 0 {
177 self.bit_pos += 8 - rem;
178 }
179 }
180}
181
/// Cursor that writes MSB-first bit fields into a borrowed, fixed-size byte buffer.
///
/// The buffer does not need to be zeroed first: each written bit is
/// explicitly set or cleared. Bits beyond the cursor are left untouched.
#[derive(Debug)]
pub struct BitWriter<'a> {
    /// Destination buffer; its length fixes the writer's capacity.
    data: &'a mut [u8],
    /// Absolute index of the next bit to write, counted from the start of `data`.
    bit_pos: usize,
}
190
191impl<'a> BitWriter<'a> {
192 #[must_use]
194 pub fn new(data: &'a mut [u8]) -> Self {
195 Self { data, bit_pos: 0 }
196 }
197
198 #[must_use]
200 pub fn capacity_bits(&self) -> usize {
201 self.data.len() * 8
202 }
203
204 #[must_use]
206 pub fn bits_written(&self) -> usize {
207 self.bit_pos
208 }
209
210 #[must_use]
212 pub fn is_byte_aligned(&self) -> bool {
213 self.bit_pos % 8 == 0
214 }
215
216 pub fn write_bits(&mut self, value: u64, n: u32) -> Result<(), BitError> {
225 if n > 64 {
226 return Err(BitError::TooManyBits { requested: n });
227 }
228 if n == 0 {
229 return Ok(());
230 }
231 if n < 64 && value >= (1u64 << n) {
233 return Err(BitError::ValueTooWide { value, bits: n });
234 }
235 let need = n as usize;
236 let remaining = self.capacity_bits() - self.bit_pos;
237 if need > remaining {
238 return Err(BitError::OutOfBounds {
239 needed_bits: need,
240 remaining_bits: remaining,
241 });
242 }
243 for i in (0..n).rev() {
244 let bit = ((value >> i) & 1) as u8;
245 let byte_idx = self.bit_pos / 8;
246 let bit_index = 7 - (self.bit_pos % 8);
247 if bit == 1 {
248 self.data[byte_idx] |= 1 << bit_index;
249 } else {
250 self.data[byte_idx] &= !(1u8 << bit_index);
251 }
252 self.bit_pos += 1;
253 }
254 Ok(())
255 }
256
257 pub fn write_bool(&mut self, value: bool) -> Result<(), BitError> {
262 self.write_bits(u64::from(value), 1)
263 }
264
265 pub fn align_to_byte(&mut self) -> Result<(), BitError> {
270 let rem = self.bit_pos % 8;
271 if rem != 0 {
272 self.write_bits(0, (8 - rem) as u32)?;
273 }
274 Ok(())
275 }
276}
277
#[cfg(test)]
mod tests {
    // Binary literals below are grouped by field boundary rather than by
    // nibble (e.g. `0b1_01_10101` is a 1-, a 2- and a 5-bit field).
    #![allow(clippy::unusual_byte_groupings)]

    use super::*;

    #[test]
    fn single_byte_fields_round_trip() {
        let mut buf = [0u8; 1];
        let mut w = BitWriter::new(&mut buf);
        w.write_bits(0b1, 1).unwrap();
        w.write_bits(0b01, 2).unwrap();
        w.write_bits(0b10101, 5).unwrap();
        assert_eq!(w.bits_written(), 8);
        assert_eq!(buf[0], 0b1_01_10101);

        let mut r = BitReader::new(&buf);
        assert_eq!(r.read_bits(1).unwrap(), 0b1);
        assert_eq!(r.read_bits(2).unwrap(), 0b01);
        assert_eq!(r.read_bits(5).unwrap(), 0b10101);
        assert_eq!(r.bits_remaining(), 0);
    }

    #[test]
    fn field_crossing_byte_boundary_round_trips() {
        // A 3-bit field followed by an 18-bit field that spans three bytes.
        let val18 = 0b10_1010_1010_1010_1011u64;
        let mut buf = [0u8; 4];
        let mut w = BitWriter::new(&mut buf);
        w.write_bits(0b101, 3).unwrap();
        w.write_bits(val18, 18).unwrap();
        assert_eq!(w.bits_written(), 21);
        // 101 + 10_1010_1010_1010_1011, then three untouched zero bits.
        assert_eq!(buf, [0b1011_0101, 0b0101_0101, 0b0101_1000, 0]);

        let mut r = BitReader::new(&buf);
        assert_eq!(r.read_bits(3).unwrap(), 0b101);
        assert_eq!(r.read_bits(18).unwrap(), val18);
        assert_eq!(r.bits_read(), 21);
    }

    #[test]
    fn read_zero_bits_is_noop() {
        let buf = [0xFFu8];
        let mut r = BitReader::new(&buf);
        assert_eq!(r.read_bits(0).unwrap(), 0);
        assert_eq!(r.bits_read(), 0);
    }

    #[test]
    fn full_64_bit_field() {
        let mut buf = [0u8; 8];
        let value = 0xDEAD_BEEF_CAFE_F00Du64;
        let mut w = BitWriter::new(&mut buf);
        w.write_bits(value, 64).unwrap();
        let mut r = BitReader::new(&buf);
        assert_eq!(r.read_bits(64).unwrap(), value);
    }

    #[test]
    fn read_past_end_errs() {
        let buf = [0xFFu8];
        let mut r = BitReader::new(&buf);
        r.read_bits(7).unwrap();
        let err = r.read_bits(2).unwrap_err();
        assert_eq!(
            err,
            BitError::OutOfBounds {
                needed_bits: 2,
                remaining_bits: 1,
            }
        );
        // A failed read must not move the cursor.
        assert_eq!(r.bits_read(), 7);
    }

    #[test]
    fn read_too_many_bits_errs() {
        let buf = [0u8; 16];
        let mut r = BitReader::new(&buf);
        assert_eq!(
            r.read_bits(65).unwrap_err(),
            BitError::TooManyBits { requested: 65 }
        );
    }

    #[test]
    fn write_too_many_bits_errs() {
        let mut buf = [0u8; 16];
        let mut w = BitWriter::new(&mut buf);
        assert_eq!(
            w.write_bits(0, 65).unwrap_err(),
            BitError::TooManyBits { requested: 65 }
        );
        assert_eq!(w.bits_written(), 0);
    }

    #[test]
    fn write_value_too_wide_errs() {
        let mut buf = [0u8; 4];
        let mut w = BitWriter::new(&mut buf);
        // 0b100 needs 3 bits; a 2-bit field cannot hold it.
        assert_eq!(
            w.write_bits(0b100, 2).unwrap_err(),
            BitError::ValueTooWide {
                value: 0b100,
                bits: 2
            }
        );
    }

    #[test]
    fn write_past_end_errs() {
        let mut buf = [0u8; 1];
        let mut w = BitWriter::new(&mut buf);
        w.write_bits(0, 7).unwrap();
        assert_eq!(
            w.write_bits(0b11, 2).unwrap_err(),
            BitError::OutOfBounds {
                needed_bits: 2,
                remaining_bits: 1,
            }
        );
    }

    #[test]
    fn writer_does_not_require_zeroed_buffer() {
        // Zero bits must clear stale ones, not just leave them alone.
        let mut buf = [0xFFu8; 1];
        let mut w = BitWriter::new(&mut buf);
        w.write_bits(0b0000_0000, 8).unwrap();
        assert_eq!(buf[0], 0x00);
    }

    #[test]
    fn bool_round_trips() {
        let mut buf = [0u8; 1];
        let mut w = BitWriter::new(&mut buf);
        w.write_bool(true).unwrap();
        w.write_bool(false).unwrap();
        w.write_bool(true).unwrap();
        assert_eq!(buf[0], 0b1010_0000);

        let mut r = BitReader::new(&buf);
        assert!(r.read_bool().unwrap());
        assert!(!r.read_bool().unwrap());
        assert!(r.read_bool().unwrap());
    }

    #[test]
    fn skip_and_align() {
        let buf = [0b1010_1100u8, 0b1111_0000];
        let mut r = BitReader::new(&buf);
        r.read_bits(2).unwrap();
        r.skip_bits(3).unwrap();
        assert!(!r.is_byte_aligned());
        r.align_to_byte();
        assert!(r.is_byte_aligned());
        assert_eq!(r.bits_read(), 8);
        assert_eq!(r.read_bits(4).unwrap(), 0b1111);
    }

    #[test]
    fn skip_past_end_errs_without_moving() {
        let buf = [0u8; 1];
        let mut r = BitReader::new(&buf);
        r.skip_bits(5).unwrap();
        assert_eq!(
            r.skip_bits(4).unwrap_err(),
            BitError::OutOfBounds {
                needed_bits: 4,
                remaining_bits: 3,
            }
        );
        assert_eq!(r.bits_read(), 5);
    }

    #[test]
    fn writer_align_pads_with_zero() {
        let mut buf = [0xFFu8; 1];
        let mut w = BitWriter::new(&mut buf);
        w.write_bits(0b101, 3).unwrap();
        w.align_to_byte().unwrap();
        assert_eq!(w.bits_written(), 8);
        assert_eq!(buf[0], 0b1010_0000);
    }

    #[test]
    fn exhaustive_width_round_trip() {
        // Every width the API accepts; the 8-byte buffer holds a full u64,
        // which also exercises the `bits == 64` branch below.
        for bits in 1u32..=64 {
            let max = if bits == 64 {
                u64::MAX
            } else {
                (1u64 << bits) - 1
            };
            for value in [0u64, 1, max, max / 2] {
                let mut buf = [0u8; 8];
                let mut w = BitWriter::new(&mut buf);
                w.write_bits(value, bits).unwrap();
                let mut r = BitReader::new(&buf);
                assert_eq!(
                    r.read_bits(bits).unwrap(),
                    value,
                    "round-trip failed: value={value:#x} bits={bits}"
                );
            }
        }
    }
}