use byteorder::ReadBytesExt;
use std::io::{self, Read};
use crate::ascii_str::ArrayList;
use crate::bit_reader::BitReader;
use crate::bits::Bitset;
use crate::num::*;
use crate::frame::BlockingStrategy;
/// Errors produced while decoding a UTF-8-style variable-length coded number.
#[derive(Debug, thiserror::Error)]
pub enum CodedNumberError {
    /// The underlying reader failed. `CodedNumber::decode_from_reader` also reports
    /// truncated input this way (as an `io::ErrorKind::UnexpectedEof` error).
    #[error("IO: {0}")]
    Io(#[from] io::Error),
    /// A byte did not have the shape required at its position.
    ///
    /// `byte_index` 0 is the lead byte, which must be `0xxxxxxx` or a `11xxxxxx`
    /// length prefix other than `11111111`. Every later byte must be `10xxxxxx`.
    #[error(
        "Malformed byte at position {byte_index} of coded number (got `{got:08b}`; expected {expected}).",
        expected = expected_byte_pattern(.byte_index)
    )]
    InvalidContinuationByte { byte_index: u8, got: u8 },
    /// The value was encoded with more bytes than its magnitude requires.
    #[error("Overlong encoding; The value should have been expressed in fewer bytes")]
    Overlong,
    /// The decoded value does not fit in `max` bits for the active mode.
    #[error("Value `{value}` exceeds the maximum length for this mode ({max} bits).")]
    OutOfRange { value: u64, max: u8 },
    /// NOTE(review): not constructed anywhere in this file. Truncated input currently
    /// surfaces as `Io`. Kept for API compatibility; confirm whether callers use it.
    #[error("Unexpected end of input.")]
    UnexpectedEof,
}

/// Describes the bit pattern a coded-number byte at `byte_index` must match.
///
/// Used only to build the `InvalidContinuationByte` message. The lead byte and
/// continuation bytes have different valid shapes, so one fixed pattern would
/// contradict the reported byte for lead-byte failures.
fn expected_byte_pattern(byte_index: &u8) -> &'static str {
    if *byte_index == 0 {
        "`0xxxxxxx` or a `11xxxxxx` length prefix other than `11111111`"
    } else {
        "`10xxxxxx`"
    }
}
/// Reports whether `byte` has the `10xxxxxx` shape of a continuation byte.
const fn is_continuation_byte(byte: u8) -> bool {
    // Only the two most significant bits matter, and they must read `10`.
    (byte >> 6) == 0b10
}
/// Returns the byte length of the shortest coded-number encoding of `x`.
///
/// A single byte carries 7 payload bits. An `n`-byte form (`n >= 2`) carries
/// `7 - n` bits in the lead byte plus 6 bits per continuation byte, so the
/// capacities are 7, 11, 16, 21, 26, 31 and 36 bits. Anything wider than 31 bits
/// maps to the 7-byte form.
const fn min_encoded_len(x: u64) -> u8 {
    // Count of significant bits in `x`; zero for `x == 0`.
    let width = u64::BITS - x.leading_zeros();
    match width {
        0..=7 => 1,
        8..=11 => 2,
        12..=16 => 3,
        17..=21 => 4,
        22..=26 => 5,
        27..=31 => 6,
        _ => 7,
    }
}
/// A decoded coded number. Which variant appears depends on the
/// [`BlockingStrategy`] passed to the decoder.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CodedNumber {
    /// Produced for [`BlockingStrategy::Fixed`]; the value fits in 31 bits.
    FrameNumber(U31),
    /// Produced for [`BlockingStrategy::Variable`]; the value fits in 36 bits.
    SampleNumber(U36),
}
impl CodedNumber {
    /// Reads one coded number from `reader` and interprets it according to `strategy`.
    ///
    /// The count of leading one bits in the lead byte gives the total encoded length;
    /// zero leading ones means a single byte. Every following byte must be `10xxxxxx`
    /// and contributes six payload bits, most significant first.
    ///
    /// Returns the decoded number together with the number of bytes consumed.
    ///
    /// # Errors
    ///
    /// - [`CodedNumberError::Io`] if the reader fails or runs out of bytes.
    /// - [`CodedNumberError::InvalidContinuationByte`] if the lead byte has exactly one
    ///   or all eight leading ones (`byte_index` 0), or a later byte is not `10xxxxxx`.
    /// - [`CodedNumberError::Overlong`] if the value fits in a shorter encoding.
    /// - [`CodedNumberError::OutOfRange`] if the value exceeds 31 bits under
    ///   [`BlockingStrategy::Fixed`] or 36 bits under [`BlockingStrategy::Variable`].
    pub fn decode_from_reader<R: Read>(
        strategy: BlockingStrategy,
        reader: &mut BitReader<R>,
    ) -> Result<(CodedNumber, u8), CodedNumberError> {
        let lead = reader.read_u8()?;
        let ones = lead.leading_ones() as u8;

        // Zero leading ones selects the single-byte form. A single leading one
        // (`10xxxxxx`) or eight (`0xFF`) cannot start a number and is rejected.
        let total_len = if ones == 0 {
            1
        } else if (2..=7).contains(&ones) {
            ones
        } else {
            return Err(CodedNumberError::InvalidContinuationByte {
                byte_index: 0,
                got: lead,
            });
        };

        // Read every continuation byte up front, then validate them in order.
        let mut tail = ArrayList::<u8, 6>::new_with_length(usize::from(total_len) - 1);
        reader.read_exact(&mut tail)?;

        // Lead-byte payload: the `7 - ones` bits following the length prefix and its
        // terminating zero. Arguments are presumably an MSB-relative start index and
        // a bit count; see `Bitset::get_bit_range_msb`.
        let mut value = u64::from(
            lead.get_bit_range_msb(u32::from(ones + 1), u32::from(7 - ones)),
        );
        for (pos, &byte) in tail.iter().enumerate() {
            if !is_continuation_byte(byte) {
                return Err(CodedNumberError::InvalidContinuationByte {
                    byte_index: pos as u8 + 1,
                    got: byte,
                });
            }
            value = (value << 6) | u64::from(byte & 0b0011_1111);
        }

        // Only the shortest possible encoding of a value is accepted.
        if min_encoded_len(value) != total_len {
            return Err(CodedNumberError::Overlong);
        }

        let number = match strategy {
            BlockingStrategy::Fixed => match u32::try_from(value).ok().and_then(U31::new) {
                Some(frame) => CodedNumber::FrameNumber(frame),
                None => return Err(CodedNumberError::OutOfRange { value, max: 31 }),
            },
            BlockingStrategy::Variable => match U36::new(value) {
                Some(sample) => CodedNumber::SampleNumber(sample),
                None => return Err(CodedNumberError::OutOfRange { value, max: 36 }),
            },
        };
        Ok((number, total_len))
    }

    /// Decodes a coded number from the start of `bytes`. Any bytes after the
    /// encoded number are left unread and do not affect the result.
    ///
    /// This is a convenience wrapper around [`CodedNumber::decode_from_reader`] and
    /// has the same return value and errors.
    pub fn decode(
        strategy: BlockingStrategy,
        bytes: &[u8],
    ) -> Result<(CodedNumber, u8), CodedNumberError> {
        let mut reader = BitReader::new(io::Cursor::new(bytes));
        Self::decode_from_reader(strategy, &mut reader)
    }
}
#[cfg(test)]
mod tests {
    //! Decoding tests: valid encodings at each notable length, consumed-byte
    //! accounting, and every error path with its exact payload pinned.
    use super::*;

    const FIXED: BlockingStrategy = BlockingStrategy::Fixed;
    const VAR: BlockingStrategy = BlockingStrategy::Variable;

    #[test]
    fn single_byte_zero() {
        let (n, consumed) = CodedNumber::decode(FIXED, &[0x00]).unwrap();
        assert_eq!(n, CodedNumber::FrameNumber(U31!(0)));
        assert_eq!(consumed, 1);
    }

    #[test]
    fn single_byte_max() {
        let (n, consumed) = CodedNumber::decode(FIXED, &[0x7F]).unwrap();
        assert_eq!(n, CodedNumber::FrameNumber(U31!(127)));
        assert_eq!(consumed, 1);
    }

    #[test]
    fn single_byte_variable() {
        let (n, consumed) = CodedNumber::decode(VAR, &[0x05]).unwrap();
        assert_eq!(n, CodedNumber::SampleNumber(U36!(5)));
        assert_eq!(consumed, 1);
    }

    #[test]
    fn two_byte() {
        let (n, consumed) = CodedNumber::decode(FIXED, &[0xC2, 0x80]).unwrap();
        assert_eq!(n, CodedNumber::FrameNumber(U31!(128)));
        assert_eq!(consumed, 2);
    }

    #[test]
    fn frame_number_max() {
        let (n, consumed) =
            CodedNumber::decode(FIXED, &[0xFD, 0xBF, 0xBF, 0xBF, 0xBF, 0xBF]).unwrap();
        assert_eq!(n, CodedNumber::FrameNumber(U31!(0x7FFF_FFFF)));
        assert_eq!(consumed, 6);
    }

    #[test]
    fn sample_number_max() {
        let bytes = [0xFE, 0xBF, 0xBF, 0xBF, 0xBF, 0xBF, 0xBF];
        let (n, consumed) = CodedNumber::decode(VAR, &bytes).unwrap();
        assert_eq!(n, CodedNumber::SampleNumber(U36!(0xF_FFFF_FFFF)));
        assert_eq!(consumed, 7);
    }

    #[test]
    fn spec_example_51_billion_samples() {
        let bytes = [0xFE, 0xAF, 0x9F, 0xB5, 0xA3, 0xB8, 0x80];
        let (n, consumed) = CodedNumber::decode(VAR, &bytes).unwrap();
        assert_eq!(n, CodedNumber::SampleNumber(U36!(51_000_000_000)));
        assert_eq!(consumed, 7);
    }

    #[test]
    fn consumed_ignores_trailing_bytes() {
        let (_, consumed) = CodedNumber::decode(FIXED, &[0x01, 0xFF, 0xFF]).unwrap();
        assert_eq!(consumed, 1);
    }

    #[test]
    fn err_unexpected_eof_empty() {
        assert!(matches!(
            CodedNumber::decode(FIXED, &[]).unwrap_err(),
            CodedNumberError::Io(io) if io.kind() == io::ErrorKind::UnexpectedEof,
        ));
    }

    #[test]
    fn err_unexpected_eof_truncated_multibyte() {
        assert!(matches!(
            CodedNumber::decode(FIXED, &[0xE0, 0x80]).unwrap_err(),
            CodedNumberError::Io(io) if io.kind() == io::ErrorKind::UnexpectedEof,
        ));
    }

    #[test]
    fn err_invalid_first_byte_continuation_pattern() {
        assert!(matches!(
            CodedNumber::decode(FIXED, &[0x80]).unwrap_err(),
            CodedNumberError::InvalidContinuationByte {
                byte_index: 0,
                got: 0x80,
            },
        ));
    }

    #[test]
    fn err_invalid_first_byte_all_ones() {
        assert!(matches!(
            CodedNumber::decode(FIXED, &[0xFF]).unwrap_err(),
            CodedNumberError::InvalidContinuationByte {
                byte_index: 0,
                got: 0xFF,
            }
        ));
    }

    #[test]
    fn err_invalid_continuation_byte() {
        assert!(matches!(
            CodedNumber::decode(FIXED, &[0xC2, 0x00]).unwrap_err(),
            CodedNumberError::InvalidContinuationByte {
                byte_index: 1,
                got: 0x00,
            }
        ));
    }

    #[test]
    fn err_invalid_later_continuation_byte() {
        // The first continuation byte is valid; the second is another lead byte.
        assert!(matches!(
            CodedNumber::decode(FIXED, &[0xE0, 0xA0, 0xC0]).unwrap_err(),
            CodedNumberError::InvalidContinuationByte {
                byte_index: 2,
                got: 0xC0,
            }
        ));
    }

    #[test]
    fn err_overlong() {
        assert!(matches!(
            CodedNumber::decode(FIXED, &[0xC1, 0xBF]).unwrap_err(),
            CodedNumberError::Overlong
        ));
    }

    #[test]
    fn err_overlong_three_byte_zero() {
        assert!(matches!(
            CodedNumber::decode(FIXED, &[0xE0, 0x80, 0x80]).unwrap_err(),
            CodedNumberError::Overlong
        ));
    }

    #[test]
    fn err_overlong_seven_byte() {
        // 0x7FFF_FFFF fits in the six-byte form, so seven bytes is overlong.
        let bytes = [0xFE, 0x81, 0xBF, 0xBF, 0xBF, 0xBF, 0xBF];
        assert!(matches!(
            CodedNumber::decode(VAR, &bytes).unwrap_err(),
            CodedNumberError::Overlong
        ));
    }

    #[test]
    fn err_frame_number_out_of_range() {
        assert!(matches!(
            CodedNumber::decode(FIXED, &[0xFE, 0x82, 0x80, 0x80, 0x80, 0x80, 0x80]).unwrap_err(),
            CodedNumberError::OutOfRange {
                max: 31,
                value: 0x8000_0000,
            },
        ));
    }

    #[test]
    fn err_seven_byte_fixed_is_always_out_of_range() {
        assert!(matches!(
            CodedNumber::decode(FIXED, &[0xFE, 0xA0, 0x80, 0x80, 0x80, 0x80, 0x80]).unwrap_err(),
            CodedNumberError::OutOfRange {
                max: 31,
                value: 0x8_0000_0000,
            },
        ));
    }
}