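//! sharkyflac 0.2.0
//!
//! A pure Rust FLAC decoder and encoder.
//!
//! This module implements FLAC coded numbers: the UTF-8-like variable-length integers that
//! carry the frame number or sample number in a frame header.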
use byteorder::{ReadBytesExt, WriteBytesExt};
use std::io::{self, Seek};

use crate::ascii_str::ArrayList;
use crate::bit_io::{BitRead, BitWrite};
use crate::bits::Bitset;
use crate::num::*;
use crate::{Decode, Encode};

use crate::frame::BlockingStrategy;

/// Errors produced while decoding or encoding a [`CodedNumber`].
#[derive(Debug, thiserror::Error)]
pub enum CodedNumberError {
    /// The underlying reader or writer failed. Input that ends before the coded number is
    /// complete also surfaces here, with [`io::ErrorKind::UnexpectedEof`].
    #[error("IO: {0}")]
    Io(#[from] io::Error),

    /// A byte did not have the expected bit pattern.
    ///
    /// `byte_index` 0 refers to the header byte, which is rejected when it has exactly one
    /// leading one bit (`10xxxxxx`) or eight (`0xFF`). Indices 1 and above refer to the
    /// continuation bytes, each of which must match `10xxxxxx`.
    #[error(
        "Continuation byte missing or malformed at position {byte_index} (got `{got:0b}`; expected `10xxxxxx`)."
    )]
    InvalidContinuationByte { byte_index: u8, got: u8 },

    /// The value was encoded with more bytes than its shortest form requires.
    #[error("Overlong encoding; The value should have been expressed in fewer bytes")]
    Overlong,

    /// The decoded value does not fit the target width: `max` is 31 for frame numbers
    /// (fixed blocking) and 36 for sample numbers (variable blocking).
    #[error("Value `{value}` exceeds the maximum length for this mode ({max} bits).")]
    OutOfRange { value: u64, max: u8 },

    /// End of input was reached before the coded number was complete.
    ///
    /// NOTE(review): nothing in this file constructs this variant; the decoder reports
    /// truncation through `Io` instead (see the `err_unexpected_eof_*` tests). Confirm whether
    /// other code relies on it before removing it.
    #[error("Unexpected end of input.")]
    UnexpectedEof,
}

/// Returns `true` when `byte` has the `10xxxxxx` shape of a coded-number continuation byte.
const fn is_continuation_byte(byte: u8) -> bool {
    // Only the top two bits matter: they must be `1` then `0`.
    (byte >> 6) == 0b10
}

/// Returns the length in bytes (1..=7) of the shortest coded-number encoding of `x`.
///
/// Each additional byte adds six payload bits but takes one payload bit away from the
/// header. That gives capacities of 7, 11, 16, 21, 26, 31 and 36 bits for lengths 1 through 7.
/// Values wider than 36 bits also map to 7, so callers must range-check them separately.
const fn min_encoded_len(x: u64) -> u8 {
    let significant_bits = u64::BITS - x.leading_zeros();
    match significant_bits {
        0..=7 => 1,
        8..=11 => 2,
        12..=16 => 3,
        17..=21 => 4,
        22..=26 => 5,
        27..=31 => 6,
        _ => 7,
    }
}

/// A [coded number](https://www.ietf.org/rfc/rfc9639.html#section-9.1.5).
///
/// When decoding, the variant is chosen by the stream's [`BlockingStrategy`]:
/// `Fixed` yields [`CodedNumber::FrameNumber`] and `Variable` yields
/// [`CodedNumber::SampleNumber`]. Encoding always emits the shortest (1 to 7 byte) form of
/// the contained value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CodedNumber {
    /// Only present with a fixed block size stream. Holds a value of at most 31 bits.
    FrameNumber(U31),
    /// Only present with a variable block size stream. Holds a value of at most 36 bits.
    SampleNumber(U36),
}

impl Decode<BlockingStrategy> for CodedNumber {
    type Error = CodedNumberError;

    /// Reads one coded number from `reader`.
    ///
    /// The number of leading one bits in the header byte selects the total length:
    /// - zero ones means a single byte with seven payload bits;
    /// - 2..=7 ones means that many bytes in total.
    ///
    /// The remaining bytes are `10xxxxxx` continuation bytes that each contribute six bits,
    /// most significant first.
    ///
    /// # Errors
    ///
    /// - [`CodedNumberError::Io`] if the reader fails or runs out of bytes.
    /// - [`CodedNumberError::InvalidContinuationByte`] if the header has exactly one or eight
    ///   leading ones (`byte_index` 0), or a later byte is not `10xxxxxx`.
    /// - [`CodedNumberError::Overlong`] if the value fits in fewer bytes than were used.
    /// - [`CodedNumberError::OutOfRange`] if the value exceeds 31 bits (fixed blocking) or
    ///   36 bits (variable blocking).
    fn decode<R: BitRead + Seek>(
        reader: &mut R,
        strategy: BlockingStrategy,
    ) -> Result<CodedNumber, Self::Error> {
        let header = reader.read_u8()?;
        let ones = header.leading_ones() as u8;

        // `10xxxxxx` (one leading one) is a stray continuation byte, and `0xFF` (eight) has no
        // terminating zero bit; neither can start a coded number.
        if ones == 1 || ones > 7 {
            return Err(CodedNumberError::InvalidContinuationByte {
                byte_index: 0,
                got:        header,
            });
        }
        let total_len = if ones == 0 { 1 } else { ones };

        let mut tail = ArrayList::<u8, 6>::default_with_len(usize::from(total_len) - 1);
        reader.read_exact(&mut tail)?;

        // Header payload: every bit after the leading ones and their terminating zero.
        let mut value =
            u64::from(header.get_bit_range_msb(u32::from(ones + 1), u32::from(7 - ones)));

        for (offset, &byte) in tail.iter().enumerate() {
            if !is_continuation_byte(byte) {
                return Err(CodedNumberError::InvalidContinuationByte {
                    byte_index: offset as u8 + 1,
                    got:        byte,
                });
            }
            // Each continuation byte appends its low six bits.
            value = (value << 6) | u64::from(byte & 0b0011_1111);
        }

        // Only the shortest encoding of a value is accepted.
        if min_encoded_len(value) != total_len {
            return Err(CodedNumberError::Overlong);
        }

        match strategy {
            BlockingStrategy::Fixed => {
                let frame = u32::try_from(value).ok().and_then(U31::new);
                frame
                    .map(CodedNumber::FrameNumber)
                    .ok_or(CodedNumberError::OutOfRange { value, max: 31 })
            }
            BlockingStrategy::Variable => {
                let sample = U36::new(value);
                sample
                    .map(CodedNumber::SampleNumber)
                    .ok_or(CodedNumberError::OutOfRange { value, max: 36 })
            }
        }
    }
}

impl Encode<()> for CodedNumber {
    type Error = CodedNumberError;

    /// Writes `self` in its shortest coded-number form.
    ///
    /// Values below 128 are written as a single byte. Larger values are written as:
    /// - a header byte with one leading one bit per output byte, followed by the
    ///   high-order payload bits;
    /// - `10xxxxxx` continuation bytes carrying six bits each, most significant first.
    ///
    /// # Errors
    ///
    /// Returns [`CodedNumberError::Io`] if the underlying writer fails.
    fn encode<W: BitWrite>(&self, writer: &mut W, _opt: ()) -> Result<(), Self::Error> {
        let value = match *self {
            CodedNumber::FrameNumber(frame) => u64::from(frame.inner()),
            CodedNumber::SampleNumber(sample) => sample.inner(),
        };
        let len = min_encoded_len(value);

        if len == 1 {
            // Seven bits or fewer: the byte is the value itself, with the top bit clear.
            writer.write_u8(value as u8)?;
            return Ok(());
        }

        let tail_count = len - 1;

        // The low `8 - len` bits of the header hold payload; the high `len` bits are ones.
        let payload_mask = 0xFF_u8 >> len;
        let high_bits = (value >> (6 * u32::from(tail_count))) as u8;
        writer.write_u8(!payload_mask | (high_bits & payload_mask))?;

        for index in (0..tail_count).rev() {
            let six_bits = ((value >> (6 * u32::from(index))) as u8) & 0b0011_1111;
            writer.write_u8(0b1000_0000 | six_bits)?;
        }

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    const FIXED: BlockingStrategy = BlockingStrategy::Fixed;
    const VAR: BlockingStrategy = BlockingStrategy::Variable;

    #[test]
    fn single_byte_zero() {
        let n = CodedNumber::decode_bytes(&[0x00], FIXED).unwrap();
        assert_eq!(n, CodedNumber::FrameNumber(U31!(0)));
    }

    #[test]
    fn single_byte_max() {
        let n = CodedNumber::decode_bytes(&[0x7F], FIXED).unwrap();
        assert_eq!(n, CodedNumber::FrameNumber(U31!(127)));
    }

    #[test]
    fn two_byte() {
        let n = CodedNumber::decode_bytes(&[0xC2, 0x80], FIXED).unwrap();
        assert_eq!(n, CodedNumber::FrameNumber(U31!(128)));
    }

    #[test]
    fn frame_number_max() {
        let n = CodedNumber::decode_bytes(&[0xFD, 0xBF, 0xBF, 0xBF, 0xBF, 0xBF], FIXED).unwrap();
        assert_eq!(n, CodedNumber::FrameNumber(U31!(0x7FFF_FFFF)));
    }

    #[test]
    fn sample_number_max() {
        // Seven bytes: header carries no payload, six continuation bytes carry 36 bits.
        let bytes = [0xFE, 0xBF, 0xBF, 0xBF, 0xBF, 0xBF, 0xBF];
        let n = CodedNumber::decode_bytes(&bytes, VAR).unwrap();
        assert_eq!(n, CodedNumber::SampleNumber(U36::new(0xF_FFFF_FFFF).unwrap()));
    }

    #[test]
    fn spec_example_51_billion_samples() {
        let bytes = [0xFE, 0xAF, 0x9F, 0xB5, 0xA3, 0xB8, 0x80];
        let n = CodedNumber::decode_bytes(&bytes, VAR).unwrap();
        assert_eq!(n, CodedNumber::SampleNumber(U36!(51_000_000_000)));
    }

    #[test]
    fn err_unexpected_eof_empty() {
        assert!(matches!(
            CodedNumber::decode_bytes(&[], FIXED).unwrap_err(),
            CodedNumberError::Io(e) if e.kind() == io::ErrorKind::UnexpectedEof,
        ));
    }

    #[test]
    fn err_unexpected_eof_truncated_multibyte() {
        // Header announces three bytes but only two are present.
        assert!(matches!(
            CodedNumber::decode_bytes(&[0xE0, 0x80], FIXED).unwrap_err(),
            CodedNumberError::Io(e) if e.kind() == io::ErrorKind::UnexpectedEof,
        ));
    }

    #[test]
    fn err_invalid_first_byte_continuation_pattern() {
        // 0b10xxxxxx looks like a continuation byte, not a header.
        assert!(matches!(
            CodedNumber::decode_bytes(&[0x80], FIXED).unwrap_err(),
            CodedNumberError::InvalidContinuationByte {
                byte_index: 0,
                got:        0x80,
            },
        ));
    }

    #[test]
    fn err_invalid_first_byte_all_ones() {
        assert!(matches!(
            CodedNumber::decode_bytes(&[0xFF], FIXED).unwrap_err(),
            CodedNumberError::InvalidContinuationByte {
                byte_index: 0,
                got:        0xFF,
            }
        ));
    }

    #[test]
    fn err_invalid_continuation_byte() {
        // Second byte is 0x00, not 0b10xxxxxx.
        assert!(matches!(
            CodedNumber::decode_bytes(&[0xC2, 0x00], FIXED).unwrap_err(),
            CodedNumberError::InvalidContinuationByte {
                byte_index: 1,
                got:        0x00,
            }
        ));
    }

    #[test]
    fn err_invalid_continuation_byte_later_position() {
        // Three-byte form whose first continuation byte is fine but whose second is not.
        assert!(matches!(
            CodedNumber::decode_bytes(&[0xE0, 0xA0, 0x40], FIXED).unwrap_err(),
            CodedNumberError::InvalidContinuationByte {
                byte_index: 2,
                got:        0x40,
            }
        ));
    }

    #[test]
    fn err_overlong() {
        // 0x7F fits in one byte but was spelled with two.
        assert!(matches!(
            CodedNumber::decode_bytes(&[0xC1, 0xBF], FIXED).unwrap_err(),
            CodedNumberError::Overlong
        ));
    }

    #[test]
    fn err_overlong_zero_in_two_bytes() {
        assert!(matches!(
            CodedNumber::decode_bytes(&[0xC0, 0x80], FIXED).unwrap_err(),
            CodedNumberError::Overlong
        ));
    }

    #[test]
    fn err_overlong_seven_byte() {
        // 0x7FFF_FFFF is the largest six-byte value; spelling it with seven bytes is overlong.
        let bytes = [0xFE, 0x81, 0xBF, 0xBF, 0xBF, 0xBF, 0xBF];
        assert!(matches!(
            CodedNumber::decode_bytes(&bytes, VAR).unwrap_err(),
            CodedNumberError::Overlong
        ));
    }

    #[test]
    fn err_frame_number_out_of_range() {
        assert!(matches!(
            CodedNumber::decode_bytes(&[0xFE, 0x82, 0x80, 0x80, 0x80, 0x80, 0x80], FIXED)
                .unwrap_err(),
            CodedNumberError::OutOfRange {
                max:   31,
                value: 0x8000_0000,
            },
        ));
    }

    #[test]
    fn err_seven_byte_fixed_is_always_out_of_range() {
        // The minimum 7-byte value (0x8000_0000) exceeds the 31-bit frame number limit,
        // so the 7-byte form is implicitly rejected via OutOfRange, not a special case.
        assert!(matches!(
            CodedNumber::decode_bytes(&[0xFE, 0xA0, 0x80, 0x80, 0x80, 0x80, 0x80], FIXED)
                .unwrap_err(),
            CodedNumberError::OutOfRange { .. },
        ));
    }
}

#[cfg(test)]
mod encode_tests {
    use super::*;
    use crate::bit_io::BitWriter;

    /// Values on both sides of every encoded-length threshold, paired with their expected
    /// byte length. Off-by-one errors in the length table would show up here first.
    const LENGTH_BOUNDARIES: [(u64, usize); 14] = [
        (0, 1),
        (0x7F, 1),
        (0x80, 2),
        (0x7FF, 2),
        (0x800, 3),
        (0xFFFF, 3),
        (0x1_0000, 4),
        (0x1F_FFFF, 4),
        (0x20_0000, 5),
        (0x3FF_FFFF, 5),
        (0x400_0000, 6),
        (0x7FFF_FFFF, 6),
        (0x8000_0000, 7),
        (0xF_FFFF_FFFF, 7),
    ];

    /// Encodes `n` into a fresh in-memory writer and returns the raw bytes.
    fn encode(n: CodedNumber) -> Vec<u8> {
        let mut w = BitWriter::new(Vec::new());
        n.encode(&mut w, ()).unwrap();
        w.into_inner()
    }

    #[test]
    fn frame_number_byte_count() {
        assert_eq!(encode(CodedNumber::FrameNumber(U31!(0))), [0x00]);
        assert_eq!(encode(CodedNumber::FrameNumber(U31!(127))), [0x7F]);
        assert_eq!(encode(CodedNumber::FrameNumber(U31!(128))), [0xC2, 0x80]);
        assert_eq!(
            encode(CodedNumber::FrameNumber(U31!(0x7FFF_FFFF))),
            [0xFD, 0xBF, 0xBF, 0xBF, 0xBF, 0xBF]
        );
    }

    #[test]
    fn sample_number_max_len() {
        assert_eq!(
            encode(CodedNumber::SampleNumber(U36!(51_000_000_000))),
            [0xFE, 0xAF, 0x9F, 0xB5, 0xA3, 0xB8, 0x80]
        );
    }

    #[test]
    fn encoded_shape_at_length_boundaries() {
        for &(n, len) in &LENGTH_BOUNDARIES {
            let cn = CodedNumber::SampleNumber(U36::new(n).unwrap());
            let bytes = encode(cn);
            assert_eq!(bytes.len(), len, "wrong length for n={n:#x}");
            if len > 1 {
                // Multi-byte forms: one leading one per byte, then `10xxxxxx` bytes.
                assert_eq!(bytes[0].leading_ones() as usize, len, "bad header for n={n:#x}");
                assert!(
                    bytes[1..].iter().all(|&b| is_continuation_byte(b)),
                    "bad continuation byte for n={n:#x}"
                );
            }
            let decoded = CodedNumber::decode_bytes(&bytes, BlockingStrategy::Variable).unwrap();
            assert_eq!(decoded, cn, "round-trip failed for n={n:#x}");
        }
    }

    #[test]
    fn round_trip_fixed() {
        let values = [
            0u32,
            1,
            127,
            128,
            1000,
            0x7FF,
            0x800,
            0xFFFF,
            0x1_0000,
            0x1F_FFFF,
            0x20_0000,
            0x3FF_FFFF,
            0x400_0000,
            0x7FFF_FFFF,
        ];
        for &n in &values {
            let cn = CodedNumber::FrameNumber(U31::new(n).unwrap());
            let bytes = encode(cn);
            let decoded = CodedNumber::decode_bytes(&bytes, BlockingStrategy::Fixed).unwrap();
            assert_eq!(decoded, cn, "round-trip failed for n={n}");
        }
    }

    #[test]
    fn round_trip_variable() {
        for &n in &[0u64, 1, 127, 128, 1000, 51_000_000_000, 0xF_FFFF_FFFF] {
            let cn = CodedNumber::SampleNumber(U36::new(n).unwrap());
            let bytes = encode(cn);
            let decoded = CodedNumber::decode_bytes(&bytes, BlockingStrategy::Variable).unwrap();
            assert_eq!(decoded, cn, "round-trip failed for n={n}");
        }
    }
}