sharkyflac 0.1.0

A pure Rust FLAC decoder and encoder
Documentation
use byteorder::ReadBytesExt;
use std::io::{self, Read};

use crate::ascii_str::ArrayList;
use crate::bit_reader::BitReader;
use crate::bits::Bitset;
use crate::num::*;

use crate::frame::BlockingStrategy;

/// Errors produced while decoding a [`CodedNumber`].
#[derive(Debug, thiserror::Error)]
pub enum CodedNumberError {
    /// The underlying reader failed.
    ///
    /// When decoding through `CodedNumber::decode_from_reader`, running out of input is reported
    /// through this variant with [`io::ErrorKind::UnexpectedEof`].
    #[error("IO: {0}")]
    Io(#[from] io::Error),

    /// Input ended before the coded number was complete.
    ///
    /// NOTE(review): `CodedNumber::decode_from_reader` never constructs this variant; truncation
    /// currently surfaces as [`CodedNumberError::Io`]. Confirm whether other callers use it.
    #[error("Unexpected end of input.")]
    UnexpectedEof,

    /// A byte of the encoding had the wrong shape.
    ///
    /// `byte_index` is the offset of the offending byte within the encoding. Offset 0 means the
    /// leading (length) byte itself was invalid. `got` is the byte that was actually read.
    #[error(
        "Continuation byte missing or malformed at position {byte_index} (got `{got:0b}`; expected `10xxxxxx`)."
    )]
    InvalidContinuationByte { byte_index: u8, got: u8 },

    /// The value was encoded using more bytes than its magnitude requires.
    #[error("Overlong encoding; The value should have been expressed in fewer bytes")]
    Overlong,

    /// The decoded `value` is wider than the `max` bit width allowed for the blocking strategy.
    #[error("Value `{value}` exceeds the maximum length for this mode ({max} bits).")]
    OutOfRange { value: u64, max: u8 },
}

/// Returns `true` when `byte` has the `10xxxxxx` shape of a coded-number continuation byte.
const fn is_continuation_byte(byte: u8) -> bool {
    // Only the top two bits matter: they must be exactly `1` then `0`.
    (byte >> 6) == 0b10
}

/// Minimum number of bytes needed to encode `x` as a coded number.
///
/// A 1-byte encoding carries 7 payload bits. An `n`-byte encoding (n >= 2) carries `7 - n` bits
/// in its first byte plus 6 bits in each of its `n - 1` continuation bytes. The capacities for
/// 1 through 7 bytes are therefore 7, 11, 16, 21, 26, 31 and 36 bits. Anything wider than
/// 31 bits maps to 7, the widest form.
const fn min_encoded_len(x: u64) -> u8 {
    // Number of bits actually needed to represent `x`; zero needs none.
    let significant_bits = 64 - x.leading_zeros();
    match significant_bits {
        0..=7 => 1,
        8..=11 => 2,
        12..=16 => 3,
        17..=21 => 4,
        22..=26 => 5,
        27..=31 => 6,
        _ => 7,
    }
}

/// A [coded number](https://www.ietf.org/rfc/rfc9639.html#section-9.1.5).
///
/// The variant produced by `CodedNumber::decode_from_reader` is selected by the
/// [`BlockingStrategy`] passed in. `Fixed` yields [`CodedNumber::FrameNumber`] and `Variable`
/// yields [`CodedNumber::SampleNumber`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CodedNumber {
    /// Only present with a fixed block size stream.
    ///
    /// The decoder rejects values wider than 31 bits for this variant.
    FrameNumber(U31),
    /// Only present with a variable block size stream.
    ///
    /// The decoder accepts values up to 36 bits for this variant.
    SampleNumber(U36),
}

impl CodedNumber {
    /// Reads one coded number from `reader` and interprets it according to `strategy`.
    ///
    /// The encoding is UTF-8-like (RFC 9639 §9.1.5):
    ///
    /// - A first byte with no leading one bits is a complete single-byte number.
    /// - Otherwise its count of leading ones (2 to 7) is the total length in bytes.
    /// - Each following byte must look like `10xxxxxx` and contributes its low six bits.
    ///
    /// All `length` bytes are read from `reader` before the continuation bytes are validated.
    ///
    /// On success, returns the number together with the count of bytes consumed (1 to 7).
    ///
    /// # Errors
    ///
    /// - [`CodedNumberError::Io`] if reading fails, including running out of input.
    /// - [`CodedNumberError::InvalidContinuationByte`] in two cases:
    ///   - `byte_index: 0` when the first byte has one or eight leading ones;
    ///   - the offending byte's offset when a later byte is not `10xxxxxx`.
    /// - [`CodedNumberError::Overlong`] if the value would fit in a shorter encoding.
    /// - [`CodedNumberError::OutOfRange`] if the value is wider than 31 bits
    ///   ([`BlockingStrategy::Fixed`]) or 36 bits ([`BlockingStrategy::Variable`]).
    pub fn decode_from_reader<R: Read>(
        strategy: BlockingStrategy,
        reader: &mut BitReader<R>,
    ) -> Result<(CodedNumber, u8), CodedNumberError> {
        let head = reader.read_u8()?;
        let prefix_ones = head.leading_ones() as u8;

        // `leading_ones` on a `u8` is 0..=8. A single leading one is a bare continuation byte.
        // Eight (0xFF) leaves no terminating zero. Neither can start a coded number.
        let length: u8 = match prefix_ones {
            0 => 1,
            1 | 8 => {
                return Err(CodedNumberError::InvalidContinuationByte {
                    byte_index: 0,
                    got:        head,
                });
            }
            n => n,
        };

        let mut tail = ArrayList::<u8, 6>::new_with_length(usize::from(length) - 1);
        reader.read_exact(&mut tail)?;

        // Payload of the head byte: the bits after the length prefix and its terminating zero.
        let mut value = u64::from(
            head.get_bit_range_msb(u32::from(prefix_ones + 1), u32::from(7 - prefix_ones)),
        );

        for (offset, &byte) in tail.iter().enumerate() {
            if !is_continuation_byte(byte) {
                return Err(CodedNumberError::InvalidContinuationByte {
                    // `offset` < 6, so the cast cannot truncate; +1 skips past the head byte.
                    byte_index: offset as u8 + 1,
                    got:        byte,
                });
            }
            value = (value << 6) | u64::from(byte & 0b0011_1111);
        }

        if length != min_encoded_len(value) {
            return Err(CodedNumberError::Overlong);
        }

        let coded = match strategy {
            BlockingStrategy::Fixed => match u32::try_from(value).ok().and_then(U31::new) {
                Some(frame) => CodedNumber::FrameNumber(frame),
                None => return Err(CodedNumberError::OutOfRange { value, max: 31 }),
            },
            BlockingStrategy::Variable => match U36::new(value) {
                Some(sample) => CodedNumber::SampleNumber(sample),
                None => return Err(CodedNumberError::OutOfRange { value, max: 36 }),
            },
        };

        Ok((coded, length))
    }

    /// Decodes a coded number from the beginning of `bytes`.
    ///
    /// Any bytes after the coded number are ignored. The returned `u8` is the number of bytes
    /// the coded number occupied. Errors are the same as for
    /// [`CodedNumber::decode_from_reader`].
    pub fn decode(
        strategy: BlockingStrategy,
        bytes: &[u8],
    ) -> Result<(CodedNumber, u8), CodedNumberError> {
        let mut reader = BitReader::new(io::Cursor::new(bytes));
        Self::decode_from_reader(strategy, &mut reader)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    const FIXED: BlockingStrategy = BlockingStrategy::Fixed;
    const VAR: BlockingStrategy = BlockingStrategy::Variable;

    #[test]
    fn single_byte_zero() {
        let (n, consumed) = CodedNumber::decode(FIXED, &[0x00]).unwrap();
        assert_eq!(n, CodedNumber::FrameNumber(U31!(0)));
        assert_eq!(consumed, 1);
    }

    #[test]
    fn single_byte_max() {
        let (n, consumed) = CodedNumber::decode(FIXED, &[0x7F]).unwrap();
        assert_eq!(n, CodedNumber::FrameNumber(U31!(127)));
        assert_eq!(consumed, 1);
    }

    #[test]
    fn two_byte() {
        let (n, consumed) = CodedNumber::decode(FIXED, &[0xC2, 0x80]).unwrap();
        assert_eq!(n, CodedNumber::FrameNumber(U31!(128)));
        assert_eq!(consumed, 2);
    }

    #[test]
    fn three_byte_min() {
        // 0x800 is the smallest value that needs the 3-byte form.
        let (n, consumed) = CodedNumber::decode(FIXED, &[0xE0, 0xA0, 0x80]).unwrap();
        assert_eq!(n, CodedNumber::FrameNumber(U31!(0x800)));
        assert_eq!(consumed, 3);
    }

    #[test]
    fn frame_number_max() {
        let (n, consumed) =
            CodedNumber::decode(FIXED, &[0xFD, 0xBF, 0xBF, 0xBF, 0xBF, 0xBF]).unwrap();
        assert_eq!(n, CodedNumber::FrameNumber(U31!(0x7FFF_FFFF)));
        assert_eq!(consumed, 6);
    }

    #[test]
    fn variable_single_byte() {
        let (n, consumed) = CodedNumber::decode(VAR, &[0x05]).unwrap();
        assert_eq!(n, CodedNumber::SampleNumber(U36!(5)));
        assert_eq!(consumed, 1);
    }

    #[test]
    fn spec_example_51_billion_samples() {
        let bytes = [0xFE, 0xAF, 0x9F, 0xB5, 0xA3, 0xB8, 0x80];
        let (n, consumed) = CodedNumber::decode(VAR, &bytes).unwrap();
        assert_eq!(n, CodedNumber::SampleNumber(U36!(51_000_000_000)));
        assert_eq!(consumed, 7);
    }

    #[test]
    fn sample_number_max() {
        // All 36 payload bits set.
        let bytes = [0xFE, 0xBF, 0xBF, 0xBF, 0xBF, 0xBF, 0xBF];
        let (n, consumed) = CodedNumber::decode(VAR, &bytes).unwrap();
        assert_eq!(n, CodedNumber::SampleNumber(U36!(0xF_FFFF_FFFF)));
        assert_eq!(consumed, 7);
    }

    #[test]
    fn consumed_ignores_trailing_bytes() {
        let (_, consumed) = CodedNumber::decode(FIXED, &[0x01, 0xFF, 0xFF]).unwrap();
        assert_eq!(consumed, 1);
    }

    #[test]
    fn err_unexpected_eof_empty() {
        assert!(matches!(
            CodedNumber::decode(FIXED, &[]).unwrap_err(),
            CodedNumberError::Io(io) if io.kind() == io::ErrorKind::UnexpectedEof,
        ));
    }

    #[test]
    fn err_unexpected_eof_truncated_multibyte() {
        assert!(matches!(
            CodedNumber::decode(FIXED, &[0xE0, 0x80]).unwrap_err(),
            CodedNumberError::Io(io) if io.kind() == io::ErrorKind::UnexpectedEof,
        ));
    }

    #[test]
    fn err_invalid_first_byte_continuation_pattern() {
        // 0b10xxxxxx — looks like a continuation byte, not a header
        assert!(matches!(
            CodedNumber::decode(FIXED, &[0x80]).unwrap_err(),
            CodedNumberError::InvalidContinuationByte {
                byte_index: 0,
                got:        0x80,
            },
        ));
    }

    #[test]
    fn err_invalid_first_byte_all_ones() {
        assert!(matches!(
            CodedNumber::decode(FIXED, &[0xFF]).unwrap_err(),
            CodedNumberError::InvalidContinuationByte {
                byte_index: 0,
                got:        0xFF,
            },
        ));
    }

    #[test]
    fn err_invalid_continuation_byte() {
        // Second byte is 0x00, not 0b10xxxxxx
        assert!(matches!(
            CodedNumber::decode(FIXED, &[0xC2, 0x00]).unwrap_err(),
            CodedNumberError::InvalidContinuationByte {
                byte_index: 1,
                got:        0x00,
            },
        ));
    }

    #[test]
    fn err_invalid_continuation_byte_reports_position() {
        // The first continuation byte is valid; the second (offset 2) is not.
        assert!(matches!(
            CodedNumber::decode(FIXED, &[0xE0, 0xA0, 0x40]).unwrap_err(),
            CodedNumberError::InvalidContinuationByte {
                byte_index: 2,
                got:        0x40,
            },
        ));
    }

    #[test]
    fn err_overlong() {
        assert!(matches!(
            CodedNumber::decode(FIXED, &[0xC1, 0xBF]).unwrap_err(),
            CodedNumberError::Overlong
        ));
    }

    #[test]
    fn err_overlong_seven_byte_variable() {
        // Zero padded out to the widest form.
        assert!(matches!(
            CodedNumber::decode(VAR, &[0xFE, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80]).unwrap_err(),
            CodedNumberError::Overlong
        ));
    }

    #[test]
    fn err_seven_byte_fixed_small_value_is_overlong() {
        // 0x7FFF_FFFF fits in 6 bytes; a 7-byte encoding of it is overlong.
        assert!(matches!(
            CodedNumber::decode(FIXED, &[0xFE, 0x81, 0xBF, 0xBF, 0xBF, 0xBF, 0xBF]).unwrap_err(),
            CodedNumberError::Overlong
        ));
    }

    #[test]
    fn err_frame_number_out_of_range() {
        assert!(matches!(
            CodedNumber::decode(FIXED, &[0xFE, 0x82, 0x80, 0x80, 0x80, 0x80, 0x80]).unwrap_err(),
            CodedNumberError::OutOfRange {
                max:   31,
                value: 0x8000_0000,
            },
        ));
    }

    #[test]
    fn err_seven_byte_fixed_is_always_out_of_range() {
        // The minimum 7-byte value (0x8000_0000) exceeds the 31-bit frame number limit,
        // so a well-formed (non-overlong) 7-byte form is rejected via OutOfRange, not a
        // special case.
        assert!(matches!(
            CodedNumber::decode(FIXED, &[0xFE, 0xA0, 0x80, 0x80, 0x80, 0x80, 0x80]).unwrap_err(),
            CodedNumberError::OutOfRange { max: 31, .. },
        ));
    }
}