lac 0.1.0

Lo Audio Codec — lossless audio codec with LPC + partitioned Rice coding.
Documentation
//! MSB-first bit writer and reader for the Rice entropy coder.
//!
//! # Wire convention
//!
//! LAC's entropy-coded payload is big-endian at the bit level: the first bit
//! written lands at bit 7 (MSB) of the first output byte, and the reader
//! recovers bits in the same order. This matches FLAC's convention and keeps
//! Rice codewords human-inspectable in a hex dump.

use alloc::vec::Vec;

/// MSB-first bit writer that appends completed bytes to a caller-owned
/// `Vec<u8>`.
///
/// Bits are staged in a one-byte accumulator and pushed to the buffer each
/// time eight of them are ready. A trailing partial byte is emitted only by
/// [`BitWriter::finish`], which zero-pads it on the LSB side. Nothing else
/// flushes the accumulator, so a writer that is dropped without `finish`
/// leaves up to seven staged bits unwritten.
///
/// The writer borrows its destination buffer for its entire lifetime —
/// this matches the `_into` idiom used elsewhere in the crate (`encode_frame_into`,
/// `decode_frame_into`, `rice_encode_zigzag_into`, `rice_decode_into`) and lets
/// encoder paths keep appending to a single caller-provided buffer instead of
/// allocating a fresh one per call.
#[derive(Debug)]
pub struct BitWriter<'a> {
    /// Output buffer. Bytes are appended as the accumulator fills. The
    /// caller retains ownership; `finish` returns nothing.
    buf: &'a mut Vec<u8>,
    /// Partial byte under construction. Bits are appended at the LSB of this
    /// accumulator; `finish` left-aligns the partial byte so the first bit
    /// written always occupies the MSB of its output byte.
    current: u8,
    /// Count of valid bits in `current`. Invariant between calls:
    /// `0 ≤ bits ≤ 7`. When `bits` reaches 8 inside `write_bit`, `current`
    /// is pushed to `buf` and both are reset.
    bits: u8,
}

impl<'a> BitWriter<'a> {
    /// Create a bit writer that appends into `buf`. Existing bytes in
    /// `buf` are preserved — the writer only extends. Callers that want
    /// a fresh bitstream start with a cleared `Vec<u8>`.
    pub fn new(buf: &'a mut Vec<u8>) -> Self {
        Self {
            buf,
            current: 0,
            bits: 0,
        }
    }

    /// Append one bit, MSB-first.
    ///
    /// The bit is shifted into the LSB of the accumulator; once eight bits
    /// have been collected the byte is pushed and the accumulator restarts
    /// empty, so the oldest bit of each group ends up at bit 7.
    #[inline]
    pub fn write_bit(&mut self, bit: bool) {
        self.current = (self.current << 1) | u8::from(bit);
        self.bits += 1;
        if self.bits == 8 {
            self.buf.push(self.current);
            self.current = 0;
            self.bits = 0;
        }
    }

    /// Append the low `count` bits of `value`, MSB-first. Bit `count-1` is
    /// written first, down to bit 0 last. `count == 0` writes nothing.
    ///
    /// `count` must not exceed 32, which is checked with `debug_assert!`.
    /// Shifting a `u32` by 32 or more is an arithmetic-overflow error in
    /// Rust, not undefined behaviour: it panics when overflow checks are
    /// enabled and uses a wrapped shift amount otherwise. A build without
    /// debug assertions would therefore emit meaningless leading bits for an
    /// oversized `count` rather than crash.
    #[inline]
    pub fn write_bits(&mut self, value: u32, count: u8) {
        debug_assert!(count <= 32, "count={count} exceeds u32 width");
        for i in (0..count).rev() {
            self.write_bit((value >> i) & 1 != 0);
        }
    }

    /// Flush any partial byte (zero-padded on the LSB side) into the
    /// backing buffer. Consumes the writer so the borrow on `buf` ends
    /// and the caller regains access. When the stream already ends on a
    /// byte boundary nothing is pushed.
    ///
    /// The `<< (8 - bits)` shift moves the accumulated bits to the
    /// most-significant positions of the final byte. This preserves MSB-first
    /// ordering: the first bit written is still at bit 7 of the flushed byte.
    pub fn finish(self) {
        if self.bits > 0 {
            // `bits` is in 1..=7 here, so the shift amount is in 1..=7.
            self.buf.push(self.current << (8 - self.bits));
        }
    }

    /// Bit length of the backing buffer plus the staged partial byte.
    ///
    /// This counts every byte currently in `buf`, including any that were
    /// already present when the writer was created; it equals "bits written
    /// by this writer" only when the writer started on an empty buffer.
    /// Only referenced by unit tests; gated accordingly.
    #[cfg(test)]
    pub(crate) fn bit_count(&self) -> usize {
        self.buf.len() * 8 + self.bits as usize
    }
}

/// Bit reader matching `BitWriter`'s MSB-first ordering. Consumes from the MSB
/// of each byte down to the LSB, then advances to the next byte.
///
/// Cursor invariant: whenever `bit_pos < 7` the reader is part-way through
/// `buf[byte_pos]`, so `byte_pos` is in range. `byte_pos == buf.len()` only
/// occurs together with `bit_pos == 7` (the exhausted state).
#[derive(Debug)]
pub struct BitReader<'a> {
    /// Input bytes; never modified.
    buf: &'a [u8],
    /// Index of the byte that holds the next unread bit.
    byte_pos: usize,
    /// Bit offset within `buf[byte_pos]`, counting from MSB (7) down to LSB (0).
    /// Starts at 7 so the first read extracts bit 7 — the bit that `BitWriter`
    /// wrote first.
    bit_pos: u8,
}

impl<'a> BitReader<'a> {
    /// Create a reader positioned at the MSB of the first byte.
    pub fn new(buf: &'a [u8]) -> Self {
        Self {
            buf,
            byte_pos: 0,
            bit_pos: 7,
        }
    }

    /// Move the cursor to the bit that follows bit `consumed` of the current
    /// byte: the next-lower bit, or bit 7 of the next byte when `consumed`
    /// was the LSB.
    #[inline]
    fn advance_past(&mut self, consumed: u8) {
        if consumed == 0 {
            self.byte_pos += 1;
            self.bit_pos = 7;
        } else {
            self.bit_pos = consumed - 1;
        }
    }

    /// Read the next bit. Returns `None` when the buffer is exhausted, in
    /// which case the cursor is left unchanged.
    #[inline]
    pub fn read_bit(&mut self) -> Option<bool> {
        let byte = *self.buf.get(self.byte_pos)?;
        let pos = self.bit_pos;
        self.advance_past(pos);
        Some((byte >> pos) & 1 != 0)
    }

    /// Read `count` bits, MSB-first, into a u32. Each successive bit is
    /// shifted into the LSB of the accumulator, so the first bit read becomes
    /// the most significant bit of the result.
    ///
    /// Returns `None` if fewer than `count` bits remain. The bits that were
    /// still available have been consumed by then, so after a `None` the
    /// cursor sits at the end of the buffer.
    ///
    /// `count` above 32 is not rejected: all `count` bits are consumed, the
    /// earliest ones are shifted out of the accumulator, and only the last
    /// 32 are returned.
    #[inline]
    pub fn read_bits(&mut self, count: u8) -> Option<u32> {
        let mut v = 0u32;
        for _ in 0..count {
            v = (v << 1) | u32::from(self.read_bit()?);
        }
        Some(v)
    }

    /// Scan a unary prefix: count leading zero bits up to (and consume)
    /// the terminating `1`. Returns the zero count; returns `None` if
    /// the buffer is exhausted before a `1` is seen, leaving the cursor at
    /// the end of the buffer.
    ///
    /// Byte-at-a-time via `u8::leading_zeros` on the packed data —
    /// materially faster than `read_bit()`-per-zero in the Rice decoder
    /// when residuals have large quotients (wide content + small `k`).
    /// Typical audio has quotients in the single digits per codeword,
    /// where the bit-by-bit and byte-at-a-time paths are indistinguishable.
    ///
    /// Saturating accumulation keeps the returned count well-defined on
    /// hypothetical inputs with > 2³² consecutive zero bits; the caller
    /// applies any Rice-parameter-specific quotient cap after the fact.
    #[inline]
    pub fn read_unary(&mut self) -> Option<u32> {
        let mut zeros: u32 = 0;

        // Phase 1: drain the remaining bits of the current partial byte,
        // if any. `bit_pos < 7` implies we're mid-byte, and the cursor
        // invariant guarantees `byte_pos` is in range for the index below.
        if self.bit_pos < 7 {
            let byte = self.buf[self.byte_pos];
            // The live region is bits `bit_pos..=0`. Mask off everything
            // above so `leading_zeros` counts only within it. `valid` is at
            // most 7 here, so `1u8 << valid` cannot overflow.
            let valid = self.bit_pos + 1;
            let live = byte & ((1u8 << valid) - 1);
            if live != 0 {
                // The highest set bit of `live` is the terminator; it lies in
                // `[0, bit_pos]`, so the zero run inside the live region is
                // `bit_pos - terminator_bit`.
                let terminator_bit = 7u8 - live.leading_zeros() as u8;
                let run = u32::from(self.bit_pos - terminator_bit);
                self.advance_past(terminator_bit);
                return Some(run);
            }
            // All remaining bits in this byte are zero: count them and move
            // to the start of the next byte.
            zeros = u32::from(valid);
            self.advance_past(0);
        }

        // Phase 2: byte-aligned scan over whole bytes (`bit_pos == 7` here).
        while let Some(&byte) = self.buf.get(self.byte_pos) {
            if byte == 0 {
                zeros = zeros.saturating_add(8);
                self.byte_pos += 1;
                continue;
            }
            // Non-zero byte: its leading zeros finish the run and the first
            // set bit is the terminator.
            let lz = byte.leading_zeros();
            zeros = zeros.saturating_add(lz);
            self.advance_past(7u8 - lz as u8);
            return Some(zeros);
        }

        None
    }
}

// Unit tests for the bit writer/reader pair: round-trips through both
// types, the exact byte layout produced by `finish`, and the byte-at-a-time
// `read_unary` scanner at aligned and unaligned offsets.
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    // Nine single bits (one more than a byte) written and read back in order,
    // so the sequence crosses a byte boundary and ends in a padded byte.
    #[test]
    fn roundtrip_single_bits() {
        let bits = [true, false, true, true, false, false, true, false, true];
        let mut buf = Vec::new();
        {
            let mut w = BitWriter::new(&mut buf);
            for &b in &bits {
                w.write_bit(b);
            }
            w.finish();
        }

        let mut r = BitReader::new(&buf);
        for &expected in &bits {
            assert_eq!(r.read_bit(), Some(expected));
        }
    }

    // Mixed-width fields written back to back must come out of `read_bits`
    // with the same values when read with the same widths.
    #[test]
    fn roundtrip_multi_bit_values() {
        let vals: &[(u32, u8)] = &[(0b10110, 5), (0, 3), (0xAB, 8), (1, 1), (0x1F, 5)];
        let mut buf = Vec::new();
        {
            let mut w = BitWriter::new(&mut buf);
            for &(v, n) in vals {
                w.write_bits(v, n);
            }
            w.finish();
        }

        let mut r = BitReader::new(&buf);
        for &(expected, n) in vals {
            assert_eq!(r.read_bits(n), Some(expected));
        }
    }

    // `finish` must left-align a partial byte and fill the rest with zeros.
    #[test]
    fn byte_padding_is_zero() {
        let mut buf = Vec::new();
        {
            let mut w = BitWriter::new(&mut buf);
            w.write_bits(0b101, 3);
            w.finish();
        }
        assert_eq!(buf.len(), 1);
        // Written 3 bits; the remaining 5 bits of the output byte are zero.
        assert_eq!(buf[0], 0b101_00000);
    }

    #[test]
    fn writer_appends_to_existing_bytes() {
        // Constructing with a non-empty buffer must preserve existing
        // bytes — this is how frame::encode_frame_into threads a single
        // `out` through the header serialisation plus the Rice bitstream
        // writer without intermediate copies.
        let mut buf = vec![0xAA, 0xBB];
        {
            let mut w = BitWriter::new(&mut buf);
            w.write_bits(0b1111_0000, 8);
            w.finish();
        }
        assert_eq!(buf, vec![0xAA, 0xBB, 0xF0]);
    }

    // After the last bit of the buffer has been consumed, `read_bit`
    // reports exhaustion with `None`.
    #[test]
    fn read_past_end_returns_none() {
        let buf = [0xFFu8];
        let mut r = BitReader::new(&buf);
        for _ in 0..8 {
            assert_eq!(r.read_bit(), Some(true));
        }
        assert_eq!(r.read_bit(), None);
    }

    #[test]
    fn read_unary_matches_bit_loop() {
        // For each prefix length in [0, 4] and a spread of zero-counts q
        // between 0 and 23, write the unary code by hand (q zeros, then a
        // one) and check that `read_unary` returns q. The prefix starts the
        // reader at both byte-aligned and mid-byte offsets so the Phase 1 /
        // Phase 2 split is covered.
        for prefix_bits in 0u8..5 {
            for q in [0u32, 1, 2, 5, 7, 8, 15, 16, 17, 23] {
                let mut buf = Vec::new();
                {
                    let mut w = BitWriter::new(&mut buf);
                    // Prefix to shift the unary payload to a non-aligned start.
                    if prefix_bits > 0 {
                        w.write_bits(0, prefix_bits);
                    }
                    // q zero-bits then a single terminator.
                    for _ in 0..q {
                        w.write_bit(false);
                    }
                    w.write_bit(true);
                    w.finish();
                }

                let mut r = BitReader::new(&buf);
                // Skip the prefix so the unary scan starts where the payload does.
                if prefix_bits > 0 {
                    let _ = r.read_bits(prefix_bits);
                }
                assert_eq!(r.read_unary(), Some(q), "q={q} prefix_bits={prefix_bits}");
            }
        }
    }

    #[test]
    fn read_unary_truncated_returns_none() {
        // All-zero buffer with no terminator: reader scans the entire
        // buffer and reports exhaustion.
        let buf = [0u8; 4];
        let mut r = BitReader::new(&buf);
        assert_eq!(r.read_unary(), None);
    }

    #[test]
    fn read_unary_terminator_at_last_bit_of_byte() {
        // q = 7, then a 1-bit at bit 0 of the first byte: value 0x01.
        // Tests the terminator-at-bit-0 branch that advances the byte
        // cursor instead of decrementing bit_pos.
        let buf = [0x01u8, 0x80];
        let mut r = BitReader::new(&buf);
        assert_eq!(r.read_unary(), Some(7));
        // Reader should now be at byte 1, bit 7. A second unary read
        // consumes the 1-bit at bit 7 → q = 0.
        assert_eq!(r.read_unary(), Some(0));
    }

    // `bit_count` must include staged bits that have not been pushed yet.
    // The last write completes the byte, so the final count of 8 comes from
    // the pushed byte rather than the accumulator. The writer is dropped
    // without `finish`; only the running count is under test.
    #[test]
    fn bit_count_tracks_partial_byte() {
        let mut buf = Vec::new();
        let mut w = BitWriter::new(&mut buf);
        assert_eq!(w.bit_count(), 0);
        w.write_bits(0b1011, 4);
        assert_eq!(w.bit_count(), 4);
        w.write_bits(0b11, 2);
        assert_eq!(w.bit_count(), 6);
        w.write_bits(0b10, 2);
        assert_eq!(w.bit_count(), 8);
    }
}