Skip to main content

md_codec/
bitstream.rs

//! Bit-aligned reader and writer.
//!
//! Per spec §4.6: bits are packed MSB-first into bytes. The first bit of the
//! payload occupies the most-significant bit of the first byte. The final byte
//! is zero-padded if needed.
6
7use crate::error::Error;
8
/// MSB-first bit packer.
///
/// Bits are appended most-significant-bit first: the first bit written lands
/// in bit 7 of the first byte. Unused low bits of the final byte stay zero,
/// because every new byte is pushed as `0` and bits are only ever OR-ed in.
#[derive(Debug, Clone, Default)]
pub struct BitWriter {
    /// Backing byte buffer; the last byte is the in-progress byte.
    bytes: Vec<u8>,
    /// Bit offset within the last byte, in `0..8`. Zero means no in-progress byte.
    bit_position: usize,
}

impl BitWriter {
    /// Create an empty `BitWriter`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Write `count` bits from `value` (LSB-aligned in `value`) into the
    /// stream MSB-first. Bits beyond `count` in `value` are ignored.
    /// A `count` of zero is a no-op.
    ///
    /// # Panics
    ///
    /// Panics if `count > 64`. This is a hard assertion (not debug-only):
    /// with overflow checks disabled, an oversized shift does not trap, so
    /// an out-of-range `count` would otherwise write garbage bits silently.
    pub fn write_bits(&mut self, value: u64, count: usize) {
        assert!(count <= 64, "write_bits count must be ≤ 64");
        if count == 0 {
            return;
        }

        // Mask `value` to the requested bit count. `1u64 << 64` would
        // overflow, so the full-width case is handled separately.
        let masked = if count == 64 {
            value
        } else {
            value & ((1u64 << count) - 1)
        };

        // Emit from the most-significant requested bit down to the least,
        // filling whatever room is left in the in-progress byte each pass.
        let mut remaining = count;
        while remaining > 0 {
            // `bit_position == 0` means the previous byte is full (or there
            // is none yet), so start a fresh zeroed byte.
            if self.bit_position == 0 {
                self.bytes.push(0);
            }

            // Free bits in the current byte, counted from the MSB side.
            let free_in_byte = 8 - self.bit_position;
            let chunk = remaining.min(free_in_byte);

            // Take the top `chunk` not-yet-written bits of `masked`.
            // `chunk <= 8`, so the result always fits in a `u8`.
            let shift = remaining - chunk;
            let bits = ((masked >> shift) & ((1u64 << chunk) - 1)) as u8;

            // Left-align the chunk within the free region of the byte.
            let last = self
                .bytes
                .last_mut()
                .expect("an in-progress byte exists: one is pushed whenever bit_position is 0");
            *last |= bits << (free_in_byte - chunk);

            // `bit_position + chunk <= 8`; wrapping to 0 marks the byte full.
            self.bit_position = (self.bit_position + chunk) % 8;
            remaining -= chunk;
        }
    }

    /// Total number of bits written.
    pub fn bit_len(&self) -> usize {
        match self.bit_position {
            // Every byte in the buffer is complete.
            0 => self.bytes.len() * 8,
            // The last byte holds only `partial` valid bits.
            partial => (self.bytes.len() - 1) * 8 + partial,
        }
    }

    /// Consume self and produce the byte stream (final byte zero-padded).
    pub fn into_bytes(self) -> Vec<u8> {
        self.bytes
    }
}
85
// --- BitReader ---
87
88/// MSB-first bit unpacker over a borrowed byte slice.
89pub struct BitReader<'a> {
90    /// Backing byte slice.
91    bytes: &'a [u8],
92    /// Total bits consumed so far (counted from the MSB of `bytes[0]`).
93    bit_position: usize,
94    /// Total bits available; defaults to `bytes.len() * 8`.
95    bit_limit: usize,
96}
97
98impl<'a> BitReader<'a> {
99    /// Reader that consumes exactly `bytes.len() * 8` bits (used by tests
100    /// where the bit count is byte-aligned).
101    pub fn new(bytes: &'a [u8]) -> Self {
102        Self {
103            bytes,
104            bit_position: 0,
105            bit_limit: bytes.len() * 8,
106        }
107    }
108
109    /// Reader that consumes at most `bit_limit` bits — required when the
110    /// payload's exact bit length is shorter than the byte buffer (zero-padding).
111    /// Per spec §3.7, the TLV section ends when total bits are exhausted; the
112    /// decoder must know `bit_limit` to avoid reading padding bits as TLV data.
113    pub fn with_bit_limit(bytes: &'a [u8], bit_limit: usize) -> Self {
114        debug_assert!(bit_limit <= bytes.len() * 8);
115        Self {
116            bytes,
117            bit_position: 0,
118            bit_limit,
119        }
120    }
121
122    /// Read `count` bits MSB-first; returns the value LSB-aligned.
123    pub fn read_bits(&mut self, count: usize) -> Result<u64, Error> {
124        if count == 0 {
125            return Ok(0);
126        }
127        debug_assert!(count <= 64, "read_bits count must be ≤ 64");
128        if self.remaining_bits() < count {
129            return Err(Error::BitStreamTruncated {
130                requested: count,
131                available: self.remaining_bits(),
132            });
133        }
134
135        let mut result: u64 = 0;
136        let mut remaining = count;
137        while remaining > 0 {
138            let byte_idx = self.bit_position / 8;
139            let bit_in_byte = self.bit_position % 8; // 0 = MSB
140            let free_in_byte = 8 - bit_in_byte;
141            let chunk = remaining.min(free_in_byte);
142
143            // Extract `chunk` bits starting at bit_in_byte from the MSB side.
144            let byte = self.bytes[byte_idx];
145            let shift = (free_in_byte - chunk) as u32;
146            // Note: `1u8 << 8` overflows; guard explicitly.
147            let mask: u8 = if chunk == 8 { 0xff } else { (1u8 << chunk) - 1 };
148            let bits = (byte >> shift) & mask;
149
150            result = (result << chunk) | bits as u64;
151            self.bit_position += chunk;
152            remaining -= chunk;
153        }
154        Ok(result)
155    }
156
157    /// Returns the current bit position within the stream. Used by the TLV
158    /// decoder to measure consumed bits within a length-delimited region.
159    pub(crate) fn bit_position(&self) -> usize {
160        self.bit_position
161    }
162
163    /// Bits remaining unread (within the configured bit limit).
164    pub fn remaining_bits(&self) -> usize {
165        self.bit_limit.saturating_sub(self.bit_position)
166    }
167
168    /// Whether the stream is exhausted.
169    pub fn is_exhausted(&self) -> bool {
170        self.remaining_bits() == 0
171    }
172
173    /// Snapshot the current bit position for rollback. Used by the TLV
174    /// decoder loop to handle graceful end-of-stream when trailing
175    /// codex32-padding bits look like a partial TLV.
176    pub fn save_position(&self) -> usize {
177        self.bit_position
178    }
179
180    /// Restore a previously saved bit position.
181    pub fn restore_position(&mut self, saved: usize) {
182        debug_assert!(saved <= self.bit_limit);
183        self.bit_position = saved;
184    }
185
186    /// Snapshot the current bit_limit for later restoration. Paired with
187    /// [`Self::set_bit_limit_for_scope`] when reading a length-delimited
188    /// sub-region (e.g., a TLV body).
189    pub(crate) fn save_bit_limit(&self) -> usize {
190        self.bit_limit
191    }
192
193    /// Tighten the bit_limit to bound the next read operations. The new
194    /// limit MUST be ≥ `bit_position` (callers already past the new
195    /// limit would see truncation immediately) and ≤ the previous
196    /// limit (cannot widen). Use [`Self::save_bit_limit`] to capture the
197    /// prior limit and [`Self::restore_bit_limit`] to restore.
198    pub(crate) fn set_bit_limit_for_scope(&mut self, new_limit: usize) {
199        debug_assert!(new_limit >= self.bit_position);
200        debug_assert!(new_limit <= self.bit_limit);
201        self.bit_limit = new_limit;
202    }
203
204    /// Restore a previously saved bit_limit.
205    pub(crate) fn restore_bit_limit(&mut self, saved: usize) {
206        debug_assert!(self.bit_position <= saved);
207        self.bit_limit = saved;
208    }
209}
210
211/// Reads exactly `bit_len` MSB-first bits from `src_bytes` and appends them
212/// to `dst`. Bits are sourced as if `src_bytes` were the output of a
213/// `BitWriter` finalized with `into_bytes()` (so the trailing partial byte
214/// is in the high bits of the last source byte). The destination is
215/// extended in-place — no padding inserted.
216///
217/// Generalizes the read-bits-then-write-bits pattern used by the TLV
218/// encoder when re-emitting a sub-bitstream's bits into the outer wire
219/// without 1-bit drift on non-byte-aligned boundaries.
220pub fn re_emit_bits(dst: &mut BitWriter, src_bytes: &[u8], bit_len: usize) -> Result<(), Error> {
221    let mut src_reader = BitReader::with_bit_limit(src_bytes, bit_len);
222    let mut remaining = bit_len;
223    while remaining > 0 {
224        let chunk = remaining.min(8);
225        let bits = src_reader.read_bits(chunk)?;
226        dst.write_bits(bits, chunk);
227        remaining -= chunk;
228    }
229    Ok(())
230}
231
#[cfg(test)]
mod tests {
    //! Unit tests for the MSB-first writer, the reader, and `re_emit_bits`.

    use super::*;

    #[test]
    fn write_5_bits_msb_first() {
        let mut w = BitWriter::new();
        w.write_bits(0b10110, 5);
        // 0b10110_000 = 0xb0 in MSB-first packing of just 5 bits with zero
        // padding on the low 3 bits.
        assert_eq!(w.into_bytes(), vec![0b1011_0000]);
    }

    #[test]
    fn write_two_5_bit_values_packs_into_one_and_a_bit() {
        let mut w = BitWriter::new();
        w.write_bits(0b11111, 5);
        w.write_bits(0b00001, 5);
        // Bit sequence: 11111 00001. The first byte takes 11111 plus the
        // top three bits (000) of the second value; the remaining 01 lands
        // in the two high bits of the second byte, followed by zero padding:
        // 11111000 01000000 = 0xf8 0x40.
        assert_eq!(w.into_bytes(), vec![0b1111_1000, 0b0100_0000]);
    }

    #[test]
    fn write_8_bits_is_one_byte() {
        let mut w = BitWriter::new();
        w.write_bits(0xab, 8);
        assert_eq!(w.into_bytes(), vec![0xab]);
    }

    #[test]
    fn write_zero_bits_is_noop() {
        let mut w = BitWriter::new();
        w.write_bits(0xff, 0);
        assert_eq!(w.bit_len(), 0);
        assert_eq!(w.into_bytes(), Vec::<u8>::new());
    }

    #[test]
    fn bit_len_counts_partial_final_byte() {
        let mut w = BitWriter::new();
        w.write_bits(0b101, 3);
        assert_eq!(w.bit_len(), 3);
        w.write_bits(0xff, 8);
        assert_eq!(w.bit_len(), 11);
        w.write_bits(0, 5);
        assert_eq!(w.bit_len(), 16);
    }

    #[test]
    fn round_trip_5_bit_values() {
        let mut w = BitWriter::new();
        w.write_bits(0b10110, 5);
        w.write_bits(0b00001, 5);
        let bytes = w.into_bytes();

        let mut r = BitReader::new(&bytes);
        assert_eq!(r.read_bits(5).unwrap(), 0b10110);
        assert_eq!(r.read_bits(5).unwrap(), 0b00001);
    }

    #[test]
    fn round_trip_64_bits_unaligned() {
        // A full-width value written at a non-byte-aligned offset exercises
        // the `count == 64` masking path and the widest reader accumulation.
        let value: u64 = 0x0123_4567_89ab_cdef;
        let mut w = BitWriter::new();
        w.write_bits(0b101, 3);
        w.write_bits(value, 64);
        assert_eq!(w.bit_len(), 67);
        let bytes = w.into_bytes();

        let mut r = BitReader::with_bit_limit(&bytes, 67);
        assert_eq!(r.read_bits(3).unwrap(), 0b101);
        assert_eq!(r.read_bits(64).unwrap(), value);
        assert!(r.is_exhausted());
    }

    #[test]
    fn read_past_end_errors() {
        let bytes = vec![0xff];
        let mut r = BitReader::new(&bytes);
        assert!(r.read_bits(9).is_err());
        // State must be preserved on truncation error.
        assert_eq!(r.remaining_bits(), 8);
    }

    #[test]
    fn read_full_byte_aligned() {
        let bytes = vec![0xab, 0xcd];
        let mut r = BitReader::new(&bytes);
        assert_eq!(r.read_bits(8).unwrap(), 0xab);
        assert_eq!(r.read_bits(8).unwrap(), 0xcd);
    }

    #[test]
    fn save_and_restore_position() {
        let bytes = vec![0b1011_0010, 0b0100_0000];
        let mut r = BitReader::new(&bytes);
        let saved = r.save_position();
        let _ = r.read_bits(5).unwrap();
        assert_eq!(r.save_position(), 5);
        r.restore_position(saved);
        assert_eq!(r.read_bits(5).unwrap(), 0b10110);
    }

    #[test]
    fn with_bit_limit_excludes_padding() {
        // 5-bit payload + 3-bit zero padding = 1 byte
        let mut w = BitWriter::new();
        w.write_bits(0b10110, 5);
        let bytes = w.into_bytes(); // [0b1011_0000]; padding is the trailing 000

        let mut r = BitReader::with_bit_limit(&bytes, 5);
        assert_eq!(r.read_bits(5).unwrap(), 0b10110);
        assert!(r.is_exhausted());
        // Attempting to read further (into the padding) errors.
        assert!(r.read_bits(1).is_err());
    }

    #[test]
    fn re_emit_bits_zero_len_is_noop() {
        let mut dst = BitWriter::new();
        re_emit_bits(&mut dst, &[], 0).unwrap();
        assert_eq!(dst.bit_len(), 0);
        assert_eq!(dst.into_bytes(), Vec::<u8>::new());
    }

    #[test]
    fn re_emit_bits_round_trip_byte_aligned() {
        // Source bitstream: a single full byte 0xab.
        let mut src = BitWriter::new();
        src.write_bits(0xab, 8);
        let src_bit_len = src.bit_len();
        let src_bytes = src.into_bytes();

        let mut dst = BitWriter::new();
        re_emit_bits(&mut dst, &src_bytes, src_bit_len).unwrap();

        assert_eq!(dst.bit_len(), 8);
        let dst_bytes = dst.into_bytes();
        assert_eq!(dst_bytes, vec![0xab]);
    }

    #[test]
    fn re_emit_bits_round_trip_all_widths_1_through_23() {
        // Sweep every bit-width in 1..=23. For each width, write a
        // checkerboard pattern masked to that width as the source, re-emit
        // it into a destination, then read it back and assert equality.
        for width in 1..=23usize {
            // `width` never reaches 64 here, so the shift cannot overflow.
            let pattern: u64 = ((1u64 << width) - 1) & 0xa5a5_a5a5_a5a5_a5a5;

            let mut src = BitWriter::new();
            src.write_bits(pattern, width);
            let src_bit_len = src.bit_len();
            let src_bytes = src.into_bytes();
            assert_eq!(src_bit_len, width);

            let mut dst = BitWriter::new();
            re_emit_bits(&mut dst, &src_bytes, width).unwrap();
            assert_eq!(dst.bit_len(), width);

            let dst_bytes = dst.into_bytes();
            let mut r = BitReader::with_bit_limit(&dst_bytes, width);
            assert_eq!(r.read_bits(width).unwrap(), pattern, "width={width}");
        }
    }

    #[test]
    fn re_emit_bits_non_byte_aligned_source() {
        // Source: 5 bits then 7 bits = 12-bit non-byte-aligned bitstream.
        let mut src = BitWriter::new();
        src.write_bits(0b10110, 5);
        src.write_bits(0b1010101, 7);
        let src_bit_len = src.bit_len();
        assert_eq!(src_bit_len, 12);
        let src_bytes = src.into_bytes();

        let mut dst = BitWriter::new();
        re_emit_bits(&mut dst, &src_bytes, src_bit_len).unwrap();
        assert_eq!(dst.bit_len(), 12);

        let dst_bytes = dst.into_bytes();
        let mut r = BitReader::with_bit_limit(&dst_bytes, 12);
        assert_eq!(r.read_bits(5).unwrap(), 0b10110);
        assert_eq!(r.read_bits(7).unwrap(), 0b1010101);
    }

    #[test]
    fn re_emit_bits_appends_to_existing_dst() {
        // Pre-fill destination with 3 bits, then re-emit 9 bits from source.
        // Verify total length is 12 and the bits are positioned correctly.
        let mut dst = BitWriter::new();
        dst.write_bits(0b101, 3);

        let mut src = BitWriter::new();
        src.write_bits(0b1_1110_0001, 9);
        let src_bit_len = src.bit_len();
        let src_bytes = src.into_bytes();

        re_emit_bits(&mut dst, &src_bytes, src_bit_len).unwrap();
        assert_eq!(dst.bit_len(), 12);

        let dst_bytes = dst.into_bytes();
        let mut r = BitReader::with_bit_limit(&dst_bytes, 12);
        assert_eq!(r.read_bits(3).unwrap(), 0b101);
        assert_eq!(r.read_bits(9).unwrap(), 0b1_1110_0001);
    }
}