sharkyflac 0.1.0

A pure Rust FLAC decoder and encoder.
Documentation
use std::io::{self, Read, Seek};

use byteorder::ReadBytesExt;

use crate::bits::Bitset;

/// A bit reader over a byte stream. Starts reading at a byte boundary.
///
/// Bits are read most-significant-bit first. Bytes are pulled from `inner`
/// one at a time; the unread bits of the most recently pulled byte are kept
/// in `buf`.
#[derive(Debug)]
pub struct BitReader<R: Read> {
    /// The underlying byte source.
    inner:     R,
    /// The byte currently being consumed; only its low `bits_left` bits are
    /// still unread.
    buf:       u8,
    /// Number of unread bits remaining in `buf` (0 means byte-aligned).
    bits_left: u8,
}

impl<R: Read> BitReader<R> {
    /// Create a new bit reader. Reading starts at the current byte position of
    /// `inner`, so the reader starts out byte-aligned.
    pub const fn new(inner: R) -> Self {
        Self {
            inner,
            buf: 0,
            bits_left: 0,
        }
    }

    /// Returns whether the reader is byte-aligned (no unread bits are buffered).
    #[inline]
    pub const fn is_byte_aligned(&self) -> bool {
        self.bits_left == 0
    }

    /// Read one bit, most-significant bit first.
    ///
    /// # Errors
    ///
    /// Returns the inner reader's error if a new byte is needed and cannot be
    /// read. No buffered bits are consumed in that case.
    #[inline]
    pub fn read_bit(&mut self) -> io::Result<bool> {
        if self.bits_left == 0 {
            self.buf = self.inner.read_u8()?;
            self.bits_left = 8;
        }

        // The next unread bit sits `8 - bits_left` positions below the MSB.
        let bit = self.buf.get_bit_msb(8 - self.bits_left as u32); // MSB-first
        self.bits_left -= 1;
        Ok(bit)
    }

    /// Read `n` bits (n ≤ 64) MSB-first, returned right-aligned in a `u64`.
    /// Reading zero bits returns `0` without touching the inner reader.
    ///
    /// # Errors
    ///
    /// Returns the inner reader's error if a needed byte cannot be read.
    ///
    /// If this call had consumed at most 8 bits at that point, they are put
    /// back into the buffer. The logical bit position is then unchanged, and
    /// [`take_leftover`](Self::take_leftover) can still recover those bits.
    ///
    /// If more than 8 bits had been consumed, whole bytes were already taken
    /// from the inner reader, and they cannot be held in the one-byte buffer.
    /// Those bits are dropped and the reader is left byte-aligned.
    ///
    /// # Panics
    ///
    /// Panics if `n > 64`.
    pub fn read_bits(&mut self, mut n: u32) -> io::Result<u64> {
        if n == 0 {
            return Ok(0);
        }
        assert!(n <= 64);

        let mut value: u64 = 0;
        let mut bit_count: u32 = 0;

        while n > 0 {
            if self.bits_left == 0 {
                match self.inner.read_u8() {
                    Ok(byte) => {
                        self.buf = byte;
                        self.bits_left = 8;
                    }
                    Err(e) => {
                        if bit_count <= 8 {
                            // Everything consumed so far fits in the buffer:
                            // push it back right-aligned (where unread bits
                            // live) so unaligned trailing bits are not lost.
                            self.buf = value as u8;
                            self.bits_left = bit_count as u8;
                        } else {
                            // Restoring would require `bits_left > 8`, which
                            // breaks the buffer invariant (and makes
                            // `read_bit` underflow). Drop the bits and realign.
                            self.bits_left = 0;
                        }
                        return Err(e);
                    }
                }
            }

            // Take as many bits as are both wanted and available.
            let take = n.min(self.bits_left as u32);
            let shift = self.bits_left as u32 - take;

            // Extract the top `take` unread bits from the current buffer byte.
            let chunk = (self.buf as u64) >> shift;
            let mask = (1u64 << take) - 1;
            value = (value << take) | (chunk & mask);

            self.bits_left -= take as u8;
            bit_count += take;
            n -= take;
        }

        Ok(value)
    }

    /// Read `n` bits MSB-first and sign-extend them to an `i64`, treating the
    /// first bit read as the sign bit. `n` must be in the range `0..=64`;
    /// reading zero bits returns `0`.
    ///
    /// # Errors
    ///
    /// Same as [`read_bits`](Self::read_bits).
    ///
    /// # Panics
    ///
    /// Panics if `n > 64`.
    #[inline]
    pub fn read_signed(&mut self, n: u8) -> io::Result<i64> {
        debug_assert!(n <= 64);
        if n == 0 {
            return Ok(0);
        }

        let raw = self.read_bits(n as u32)? as i64;
        // Move the sign bit up to bit 63, then arithmetic-shift back down.
        let sh = 64 - n;
        Ok((raw << sh) >> sh)
    }

    /// Retrieve the unread bits of the partially consumed byte, if any. This
    /// byte-aligns the reader.
    ///
    /// Returns `Some((bits, count))`, where `count` is the number of unread
    /// bits and `bits` holds them in its low `count` bits. Returns `None` if
    /// the reader is already byte-aligned.
    #[inline]
    pub fn take_leftover(&mut self) -> Option<(u8, u8)> {
        if self.bits_left != 0 {
            Some((
                std::mem::take(&mut self.buf).get_bit_range_lsb(0, self.bits_left.into()),
                std::mem::take(&mut self.bits_left),
            ))
        } else {
            None
        }
    }

    /// Discard the leftover bits. This makes future reads byte aligned.
    #[inline]
    pub fn discard_leftover(&mut self) {
        self.bits_left = 0;
    }

    /// Unwrap this reader, returning the underlying reader. Any buffered
    /// unread bits are discarded.
    #[inline]
    pub fn into_inner(self) -> R {
        self.inner
    }
}

impl<R: Read> Read for BitReader<R> {
    /// Read whole bytes, transparently handling a misaligned bit position.
    ///
    /// When the reader is byte-aligned this delegates directly to the inner
    /// reader. Otherwise each output byte is stitched together from the
    /// buffered leftover bits and the next input byte, so the reader stays
    /// misaligned by the same amount afterwards.
    ///
    /// # Errors
    ///
    /// An error is returned only if no byte could be produced. This follows
    /// the `Read` contract that an `Err` means nothing was read.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }

        // Fast Path: If we are on a byte boundary, delegate straight to the inner
        // reader.
        if self.is_byte_aligned() {
            return self.inner.read(buf);
        }

        // Slow Path: We are misaligned. Stitch together bits to construct bytes.
        let mut bytes_read = 0;
        for byte in buf.iter_mut() {
            match self.read_bits(8) {
                Ok(val) => {
                    *byte = val as u8;
                    bytes_read += 1;
                }
                // Bytes were already produced (and consumed from the stream),
                // so report them rather than lose them. From a misaligned
                // position, `read_bits(8)` consumes at most the buffered bits
                // before a failing byte read and pushes them back, so the same
                // error surfaces on the next call.
                Err(_) if bytes_read > 0 => break,
                Err(e) => return Err(e),
            }
        }

        Ok(bytes_read)
    }
}

impl<R: Read + Seek> Seek for BitReader<R> {
    /// Seek the inner reader. On success, any buffered leftover bits are
    /// discarded and the reader is byte-aligned. On failure, the bit buffer is
    /// kept so no data is lost.
    ///
    /// `SeekFrom::Current` offsets are relative to the inner reader's
    /// position, which is just past any partially consumed byte.
    fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> {
        let new_pos = self.inner.seek(pos)?;
        self.bits_left = 0;
        Ok(new_pos)
    }

    /// Seek the inner reader relative to its current position. Leftover bits
    /// are discarded only if the seek succeeds.
    fn seek_relative(&mut self, offset: i64) -> io::Result<()> {
        self.inner.seek_relative(offset)?;
        self.bits_left = 0;
        Ok(())
    }

    /// Return the inner reader's position without discarding leftover bits.
    ///
    /// The default `stream_position` is implemented as
    /// `seek(SeekFrom::Current(0))`, which would clear the bit buffer just to
    /// query the position.
    fn stream_position(&mut self) -> io::Result<u64> {
        self.inner.stream_position()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Cursor, ErrorKind, Read};

    /// Wrap a byte slice in a fresh, byte-aligned `BitReader`.
    fn create_reader(data: &[u8]) -> BitReader<Cursor<&[u8]>> {
        BitReader::new(Cursor::new(data))
    }

    /// Pull `count` single bits from `reader`, panicking on any I/O error.
    fn collect_bits<R: Read>(reader: &mut BitReader<R>, count: usize) -> Vec<bool> {
        (0..count).map(|_| reader.read_bit().unwrap()).collect()
    }

    #[test]
    fn read_single_bits() {
        // 0b1011_0010 = 178, consumed most-significant bit first.
        let mut reader = create_reader(&[0b1011_0010]);
        let expected = [true, false, true, true, false, false, true, false];
        assert_eq!(collect_bits(&mut reader, 8), expected);
    }

    #[test]
    fn read_bits_within_byte() {
        let mut reader = create_reader(&[0b1101_0101]);
        let steps: [(u32, u64); 3] = [(4, 0b1101), (3, 0b010), (1, 0b1)];
        for (width, want) in steps {
            assert_eq!(reader.read_bits(width).unwrap(), want);
        }
    }

    #[test]
    fn read_bits_across_byte_boundary() {
        // Bytes: 0b1111_0000, 0b1010_1100. The middle 8-bit read takes the low
        // nibble (0000) of the first byte and the high nibble (1010) of the
        // second.
        let mut reader = create_reader(&[0xF0, 0xAC]);
        let steps: [(u32, u64); 3] = [(4, 0b1111), (8, 0b0000_1010), (4, 0b1100)];
        for (width, want) in steps {
            assert_eq!(reader.read_bits(width).unwrap(), want);
        }
    }

    #[test]
    fn read_bits_64_bits() {
        // Eight 0xFF bytes followed by a single 0x00 byte.
        let mut data = [0xFFu8; 9];
        data[8] = 0x00;
        let mut reader = create_reader(&data);

        let full = reader.read_bits(64).unwrap();
        let tail = reader.read_bits(8).unwrap();
        assert_eq!((full, tail), (u64::MAX, 0x00));
    }

    #[test]
    fn discard_leftover() {
        let mut reader = create_reader(&[0b1111_1111, 0b0101_0101]);

        let head = reader.read_bits(3).unwrap();
        assert_eq!(head, 0b111);

        // Skip the remaining 5 bits of the first byte.
        reader.discard_leftover();
        let next = reader.read_bits(8).unwrap();
        assert_eq!(next, 0b0101_0101);
    }

    #[test]
    fn read_trait_passthrough() {
        let data = [0x01, 0x02, 0x03, 0x04];
        let mut reader = create_reader(&data);
        let mut buf = [0u8; 3];

        let count = reader.read(&mut buf).unwrap();
        assert_eq!((count, buf), (3, [0x01, 0x02, 0x03]));
        assert_eq!(reader.into_inner().position(), 3);
    }

    #[test]
    fn read_trait_misaligned() {
        let data = [0b1111_0000, 0b1010_1111, 0b0000_1111];
        let mut reader = create_reader(&data);
        assert_eq!(reader.read_bits(4).unwrap(), 0b1111);

        let mut buf = [0u8; 2];
        let count = reader.read(&mut buf).unwrap();
        assert_eq!(count, 2);
        assert_eq!(buf, [0b0000_1010, 0b1111_0000]);
    }

    #[test]
    fn read_bit_eof() {
        let mut reader = create_reader(&[0xAA]);
        assert_eq!(reader.read_bits(8).unwrap(), 0xAA);

        let err = reader.read_bit().unwrap_err();
        assert_eq!(err.kind(), ErrorKind::UnexpectedEof);
        assert!(reader.take_leftover().is_none());
    }

    #[test]
    fn read_trait_partial_eof() {
        let data = [0b1111_0000, 0b1010_1111];
        let mut reader = create_reader(&data);

        // Knock the reader 4 bits off alignment.
        reader.read_bits(4).unwrap();

        let mut buf = [0u8; 2];
        let count = reader.read(&mut buf).unwrap();
        assert_eq!(count, 1);
        assert_eq!(buf[0], 0b0000_1010);

        let leftover = reader.take_leftover().unwrap();
        assert_eq!(leftover, (0b1111, 4));
    }

    #[test]
    fn seek() {
        let mut reader = create_reader(&[0b1011_0010, 0b0101_0101]);
        assert_eq!(collect_bits(&mut reader, 4), [true, false, true, true]);

        reader.seek(io::SeekFrom::Start(1)).unwrap();
        let after_seek = collect_bits(&mut reader, 6);
        assert_eq!(after_seek, [false, true, false, true, false, true]);
    }
}