1use crate::error::Error;
8
/// Accumulates a bit stream, packing values most-significant-bit first.
///
/// Each byte is filled from its high bit downwards. A partially filled final
/// byte keeps zeros in its unused low bits.
#[derive(Debug, Clone, Default)]
pub struct BitWriter {
    /// Packed output; the last byte may be only partially filled.
    bytes: Vec<u8>,
    /// Number of bits already used in the last byte (always `0..8`).
    /// `0` means the next write starts a fresh byte.
    bit_position: usize,
}

impl BitWriter {
    /// Creates an empty writer (same as [`Default::default`]).
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends the low `count` bits of `value`, most significant of those bits first.
    ///
    /// Bits of `value` above position `count` are ignored. `count == 0` is a no-op.
    ///
    /// # Panics
    ///
    /// Panics if `count > 64`. This was previously only checked in debug builds;
    /// release builds silently wrote corrupted bits.
    pub fn write_bits(&mut self, value: u64, count: usize) {
        if count == 0 {
            return;
        }
        assert!(count <= 64, "write_bits count must be ≤ 64");

        // `1u64 << 64` would overflow, so the full-width case skips masking.
        let masked = if count == 64 {
            value
        } else {
            value & ((1u64 << count) - 1)
        };

        let mut remaining = count;
        while remaining > 0 {
            if self.bit_position == 0 {
                // Previous byte is full (or nothing written yet): start a zeroed byte.
                self.bytes.push(0);
            }
            let free_in_byte = 8 - self.bit_position;
            let chunk = remaining.min(free_in_byte);

            // Next `chunk` bits, taken from the top of the not-yet-written part of `masked`.
            let bits = ((masked >> (remaining - chunk)) & ((1u64 << chunk) - 1)) as u8;

            // Place them directly below the bits already used in the current byte.
            let last = self
                .bytes
                .last_mut()
                .expect("current byte exists: pushed above or partially filled");
            *last |= bits << (free_in_byte - chunk);

            // `bit_position + chunk` never exceeds 8, so `% 8` wraps a full byte to 0.
            self.bit_position = (self.bit_position + chunk) % 8;
            remaining -= chunk;
        }
    }

    /// Returns the number of bits written so far (excluding trailing padding).
    pub fn bit_len(&self) -> usize {
        match self.bit_position {
            0 => self.bytes.len() * 8,
            used => (self.bytes.len() - 1) * 8 + used,
        }
    }

    /// Consumes the writer and returns the packed bytes.
    ///
    /// Unused low bits of a partial final byte are zero.
    pub fn into_bytes(self) -> Vec<u8> {
        self.bytes
    }
}
85
86pub struct BitReader<'a> {
90 bytes: &'a [u8],
92 bit_position: usize,
94 bit_limit: usize,
96}
97
98impl<'a> BitReader<'a> {
99 pub fn new(bytes: &'a [u8]) -> Self {
102 Self {
103 bytes,
104 bit_position: 0,
105 bit_limit: bytes.len() * 8,
106 }
107 }
108
109 pub fn with_bit_limit(bytes: &'a [u8], bit_limit: usize) -> Self {
114 debug_assert!(bit_limit <= bytes.len() * 8);
115 Self {
116 bytes,
117 bit_position: 0,
118 bit_limit,
119 }
120 }
121
122 pub fn read_bits(&mut self, count: usize) -> Result<u64, Error> {
124 if count == 0 {
125 return Ok(0);
126 }
127 debug_assert!(count <= 64, "read_bits count must be ≤ 64");
128 if self.remaining_bits() < count {
129 return Err(Error::BitStreamTruncated {
130 requested: count,
131 available: self.remaining_bits(),
132 });
133 }
134
135 let mut result: u64 = 0;
136 let mut remaining = count;
137 while remaining > 0 {
138 let byte_idx = self.bit_position / 8;
139 let bit_in_byte = self.bit_position % 8; let free_in_byte = 8 - bit_in_byte;
141 let chunk = remaining.min(free_in_byte);
142
143 let byte = self.bytes[byte_idx];
145 let shift = (free_in_byte - chunk) as u32;
146 let mask: u8 = if chunk == 8 { 0xff } else { (1u8 << chunk) - 1 };
148 let bits = (byte >> shift) & mask;
149
150 result = (result << chunk) | bits as u64;
151 self.bit_position += chunk;
152 remaining -= chunk;
153 }
154 Ok(result)
155 }
156
157 pub(crate) fn bit_position(&self) -> usize {
160 self.bit_position
161 }
162
163 pub fn remaining_bits(&self) -> usize {
165 self.bit_limit.saturating_sub(self.bit_position)
166 }
167
168 pub fn is_exhausted(&self) -> bool {
170 self.remaining_bits() == 0
171 }
172
173 pub fn save_position(&self) -> usize {
177 self.bit_position
178 }
179
180 pub fn restore_position(&mut self, saved: usize) {
182 debug_assert!(saved <= self.bit_limit);
183 self.bit_position = saved;
184 }
185
186 pub(crate) fn save_bit_limit(&self) -> usize {
190 self.bit_limit
191 }
192
193 pub(crate) fn set_bit_limit_for_scope(&mut self, new_limit: usize) {
199 debug_assert!(new_limit >= self.bit_position);
200 debug_assert!(new_limit <= self.bit_limit);
201 self.bit_limit = new_limit;
202 }
203
204 pub(crate) fn restore_bit_limit(&mut self, saved: usize) {
206 debug_assert!(self.bit_position <= saved);
207 self.bit_limit = saved;
208 }
209}
210
211pub fn re_emit_bits(dst: &mut BitWriter, src_bytes: &[u8], bit_len: usize) -> Result<(), Error> {
221 let mut src_reader = BitReader::with_bit_limit(src_bytes, bit_len);
222 let mut remaining = bit_len;
223 while remaining > 0 {
224 let chunk = remaining.min(8);
225 let bits = src_reader.read_bits(chunk)?;
226 dst.write_bits(bits, chunk);
227 remaining -= chunk;
228 }
229 Ok(())
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn write_5_bits_msb_first() {
        let mut w = BitWriter::new();
        w.write_bits(0b10110, 5);
        // Five bits land in the high end of the byte; the low three are padding.
        assert_eq!(w.into_bytes(), vec![0b1011_0000]);
    }

    #[test]
    fn write_two_5_bit_values_packs_into_one_and_a_bit() {
        let mut w = BitWriter::new();
        w.write_bits(0b11111, 5);
        w.write_bits(0b00001, 5);
        assert_eq!(w.into_bytes(), vec![0b1111_1000, 0b0100_0000]);
    }

    #[test]
    fn write_8_bits_is_one_byte() {
        let mut w = BitWriter::new();
        w.write_bits(0xab, 8);
        assert_eq!(w.into_bytes(), vec![0xab]);
    }

    #[test]
    fn write_zero_bits_is_noop() {
        let mut w = BitWriter::new();
        w.write_bits(0xff, 0);
        assert_eq!(w.bit_len(), 0);
        assert_eq!(w.into_bytes(), Vec::<u8>::new());
    }

    #[test]
    fn write_bits_ignores_bits_above_count() {
        let mut w = BitWriter::new();
        w.write_bits(0xff, 3);
        assert_eq!(w.bit_len(), 3);
        assert_eq!(w.into_bytes(), vec![0b1110_0000]);
    }

    #[test]
    fn round_trip_5_bit_values() {
        let mut w = BitWriter::new();
        w.write_bits(0b10110, 5);
        w.write_bits(0b00001, 5);
        let bytes = w.into_bytes();

        let mut r = BitReader::new(&bytes);
        assert_eq!(r.read_bits(5).unwrap(), 0b10110);
        assert_eq!(r.read_bits(5).unwrap(), 0b00001);
    }

    #[test]
    fn round_trip_64_bit_value_unaligned() {
        let mut w = BitWriter::new();
        w.write_bits(0b1, 1);
        w.write_bits(0x0123_4567_89ab_cdef, 64);
        assert_eq!(w.bit_len(), 65);
        let bytes = w.into_bytes();

        let mut r = BitReader::with_bit_limit(&bytes, 65);
        assert_eq!(r.read_bits(1).unwrap(), 1);
        assert_eq!(r.read_bits(64).unwrap(), 0x0123_4567_89ab_cdef);
        assert!(r.is_exhausted());
    }

    #[test]
    fn read_past_end_errors() {
        let bytes = vec![0xff];
        let mut r = BitReader::new(&bytes);
        assert!(r.read_bits(9).is_err());
        // A failed read must not consume anything.
        assert_eq!(r.remaining_bits(), 8);
    }

    #[test]
    fn read_full_byte_aligned() {
        let bytes = vec![0xab, 0xcd];
        let mut r = BitReader::new(&bytes);
        assert_eq!(r.read_bits(8).unwrap(), 0xab);
        assert_eq!(r.read_bits(8).unwrap(), 0xcd);
    }

    #[test]
    fn save_and_restore_position() {
        let bytes = vec![0b1011_0010, 0b0100_0000];
        let mut r = BitReader::new(&bytes);
        let saved = r.save_position();
        let _ = r.read_bits(5).unwrap();
        assert_eq!(r.save_position(), 5);
        r.restore_position(saved);
        assert_eq!(r.read_bits(5).unwrap(), 0b10110);
    }

    #[test]
    fn with_bit_limit_excludes_padding() {
        let mut w = BitWriter::new();
        w.write_bits(0b10110, 5);
        let bytes = w.into_bytes();

        // The byte holds 3 padding bits; the limit must hide them.
        let mut r = BitReader::with_bit_limit(&bytes, 5);
        assert_eq!(r.read_bits(5).unwrap(), 0b10110);
        assert!(r.is_exhausted());
        assert!(r.read_bits(1).is_err());
    }

    #[test]
    fn re_emit_bits_round_trip_byte_aligned() {
        let mut src = BitWriter::new();
        src.write_bits(0xab, 8);
        let src_bit_len = src.bit_len();
        let src_bytes = src.into_bytes();

        let mut dst = BitWriter::new();
        re_emit_bits(&mut dst, &src_bytes, src_bit_len).unwrap();

        assert_eq!(dst.bit_len(), 8);
        assert_eq!(dst.into_bytes(), vec![0xab]);
    }

    #[test]
    fn re_emit_bits_round_trip_all_widths_1_through_64() {
        // Covers every width the reader/writer API accepts, including the
        // full-width case that needs special masking.
        for width in 1..=64usize {
            // All-ones mask of `width` bits; the 64 case avoids an overflowing shift.
            let mask = if width == 64 {
                u64::MAX
            } else {
                (1u64 << width) - 1
            };
            let pattern = mask & 0xa5a5_a5a5_a5a5_a5a5;

            let mut src = BitWriter::new();
            src.write_bits(pattern, width);
            let src_bit_len = src.bit_len();
            let src_bytes = src.into_bytes();
            assert_eq!(src_bit_len, width);

            let mut dst = BitWriter::new();
            re_emit_bits(&mut dst, &src_bytes, width).unwrap();
            assert_eq!(dst.bit_len(), width);

            let dst_bytes = dst.into_bytes();
            let mut r = BitReader::with_bit_limit(&dst_bytes, width);
            assert_eq!(r.read_bits(width).unwrap(), pattern, "width={width}");
        }
    }

    #[test]
    fn re_emit_bits_non_byte_aligned_source() {
        let mut src = BitWriter::new();
        src.write_bits(0b10110, 5);
        src.write_bits(0b1010101, 7);
        let src_bit_len = src.bit_len();
        assert_eq!(src_bit_len, 12);
        let src_bytes = src.into_bytes();

        let mut dst = BitWriter::new();
        re_emit_bits(&mut dst, &src_bytes, src_bit_len).unwrap();
        assert_eq!(dst.bit_len(), 12);

        let dst_bytes = dst.into_bytes();
        let mut r = BitReader::with_bit_limit(&dst_bytes, 12);
        assert_eq!(r.read_bits(5).unwrap(), 0b10110);
        assert_eq!(r.read_bits(7).unwrap(), 0b1010101);
    }

    #[test]
    fn re_emit_bits_appends_to_existing_dst() {
        // Start dst mid-byte so the copied bits must straddle byte boundaries.
        let mut dst = BitWriter::new();
        dst.write_bits(0b101, 3);

        let mut src = BitWriter::new();
        src.write_bits(0b1_1110_0001, 9);
        let src_bit_len = src.bit_len();
        let src_bytes = src.into_bytes();

        re_emit_bits(&mut dst, &src_bytes, src_bit_len).unwrap();
        assert_eq!(dst.bit_len(), 12);

        let dst_bytes = dst.into_bytes();
        let mut r = BitReader::with_bit_limit(&dst_bytes, 12);
        assert_eq!(r.read_bits(3).unwrap(), 0b101);
        assert_eq!(r.read_bits(9).unwrap(), 0b1_1110_0001);
    }
}