use byteorder::{ReadBytesExt, WriteBytesExt};
use std::io::{self, Seek};
use crate::ascii_str::ArrayList;
use crate::bit_io::{BitRead, BitWrite};
use crate::bits::Bitset;
use crate::num::*;
use crate::{Decode, Encode};
use crate::frame::BlockingStrategy;
/// Errors produced while decoding or encoding a [`CodedNumber`].
#[derive(Debug, thiserror::Error)]
pub enum CodedNumberError {
    /// The underlying reader or writer failed. The decoder in this module reports input
    /// that ends in the middle of a number this way (`io::ErrorKind::UnexpectedEof`).
    #[error("IO: {0}")]
    Io(#[from] io::Error),
    /// A byte had the wrong bit pattern.
    ///
    /// `byte_index` 0 means the first (length-carrying) byte was unusable. Larger indices
    /// point at a following byte that did not match `10xxxxxx`. `got` is the offending byte;
    /// it is printed as eight zero-padded binary digits so it lines up with the pattern.
    #[error(
        "Continuation byte missing or malformed at position {byte_index} (got `{got:08b}`; expected `10xxxxxx`)."
    )]
    InvalidContinuationByte { byte_index: u8, got: u8 },
    /// The value was encoded with more bytes than the shortest form requires.
    #[error("Overlong encoding; The value should have been expressed in fewer bytes")]
    Overlong,
    /// The decoded `value` does not fit the target width of `max` bits.
    #[error("Value `{value}` exceeds the maximum length for this mode ({max} bits).")]
    OutOfRange { value: u64, max: u8 },
    /// End of input reached prematurely.
    ///
    /// NOTE(review): the codec shown alongside this type never constructs this variant;
    /// truncated input surfaces as [`CodedNumberError::Io`] instead. Confirm before relying on it.
    #[error("Unexpected end of input.")]
    UnexpectedEof,
}
/// Returns `true` when `byte` has the `10xxxxxx` shape of a continuation byte,
/// i.e. lies in `0x80..=0xBF`.
const fn is_continuation_byte(byte: u8) -> bool {
    matches!(byte, 0x80..=0xBF)
}
/// Number of bytes the shortest coded form of `x` occupies.
///
/// A 1-byte form holds 7 payload bits. An `n`-byte form (2..=6) holds `5 * n + 1` bits,
/// and the 7-byte form holds 36. Every value at or above `0x8000_0000` maps to 7,
/// including values wider than 36 bits that the 7-byte form cannot actually hold.
const fn min_encoded_len(x: u64) -> u8 {
    if x < 0x80 {
        1
    } else if x < 0x800 {
        2
    } else if x < 0x1_0000 {
        3
    } else if x < 0x20_0000 {
        4
    } else if x < 0x400_0000 {
        5
    } else if x < 0x8000_0000 {
        6
    } else {
        7
    }
}
/// A number stored in the variable-length coded form.
///
/// The decoder chooses the variant from the [`BlockingStrategy`] it is given:
/// `Fixed` yields [`CodedNumber::FrameNumber`] (at most 31 bits), and `Variable` yields
/// [`CodedNumber::SampleNumber`] (at most 36 bits).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CodedNumber {
    /// A frame number, limited to 31 bits.
    FrameNumber(U31),
    /// A sample number, limited to 36 bits.
    SampleNumber(U36),
}
impl Decode<BlockingStrategy> for CodedNumber {
    type Error = CodedNumberError;

    /// Reads one coded number and interprets it according to `strategy`.
    ///
    /// The count of leading one-bits in the first byte is the total length in bytes.
    /// Zero leading ones means the single-byte form. Every following byte must look like
    /// `10xxxxxx` and contributes six payload bits, most significant group first.
    ///
    /// # Errors
    ///
    /// * [`CodedNumberError::Io`] if the reader fails, including when input ends early.
    /// * [`CodedNumberError::InvalidContinuationByte`] with `byte_index` 0 if the first byte
    ///   has exactly one or all eight leading ones. With a larger index, a following byte
    ///   is not `10xxxxxx`.
    /// * [`CodedNumberError::Overlong`] if fewer bytes could have held the value.
    /// * [`CodedNumberError::OutOfRange`] if the value exceeds 31 bits (`Fixed`) or
    ///   36 bits (`Variable`).
    fn decode<R: BitRead + Seek>(
        reader: &mut R,
        strategy: BlockingStrategy,
    ) -> Result<CodedNumber, Self::Error> {
        let head = reader.read_u8()?;
        let ones = head.leading_ones() as u8;
        // A single leading one is a stray continuation byte; eight (0xFF) encodes no length.
        if ones == 1 || ones > 7 {
            return Err(CodedNumberError::InvalidContinuationByte {
                byte_index: 0,
                got: head,
            });
        }
        // No leading ones is the one-byte form; otherwise the count is the byte length.
        let total_len = ones.max(1);

        let mut tail = ArrayList::<u8, 6>::default_with_len(usize::from(total_len) - 1);
        reader.read_exact(&mut tail)?;

        // The lead byte's payload is whatever sits below the marker ones and the zero bit
        // that terminates them (no bits at all for the 7-byte form).
        let mut value =
            u64::from(head.get_bit_range_msb(u32::from(ones + 1), u32::from(7 - ones)));
        for (pos, &byte) in tail.iter().enumerate() {
            if !is_continuation_byte(byte) {
                return Err(CodedNumberError::InvalidContinuationByte {
                    byte_index: pos as u8 + 1,
                    got: byte,
                });
            }
            value = (value << 6) | u64::from(byte & 0b0011_1111);
        }

        if min_encoded_len(value) != total_len {
            return Err(CodedNumberError::Overlong);
        }

        let (decoded, max_bits) = match strategy {
            BlockingStrategy::Fixed => (
                u32::try_from(value)
                    .ok()
                    .and_then(U31::new)
                    .map(CodedNumber::FrameNumber),
                31,
            ),
            BlockingStrategy::Variable => (U36::new(value).map(CodedNumber::SampleNumber), 36),
        };
        decoded.ok_or(CodedNumberError::OutOfRange {
            value,
            max: max_bits,
        })
    }
}
impl Encode<()> for CodedNumber {
    type Error = CodedNumberError;

    /// Writes `self` in the shortest coded form that holds its value.
    ///
    /// Values below `0x80` are written as one byte. Otherwise the lead byte carries `len`
    /// one-bits followed by the highest payload bits. It is followed by `len - 1` bytes of
    /// the form `10xxxxxx`, each holding six payload bits, most significant group first.
    ///
    /// # Errors
    ///
    /// Returns [`CodedNumberError::Io`] if the writer fails.
    fn encode<W: BitWrite>(&self, writer: &mut W, _opt: ()) -> Result<(), Self::Error> {
        let value: u64 = match *self {
            CodedNumber::SampleNumber(sample) => sample.inner(),
            CodedNumber::FrameNumber(frame) => u64::from(frame.inner()),
        };
        let len = min_encoded_len(value);
        if len == 1 {
            writer.write_u8(value as u8)?;
            return Ok(());
        }

        let groups = u32::from(len) - 1;
        // `len` marker ones in the top bits; the payload field below is `8 - len` bits wide
        // and includes the terminating zero bit, which is zero because `len` is minimal.
        let marker = 0xFF_u8 << (8 - len);
        let head_payload = ((value >> (6 * groups)) as u8) & (0xFF_u8 >> len);
        writer.write_u8(marker | head_payload)?;

        for group in (0..groups).rev() {
            let payload = ((value >> (6 * group)) & 0x3F) as u8;
            writer.write_u8(0x80 | payload)?;
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    const FIXED: BlockingStrategy = BlockingStrategy::Fixed;
    const VAR: BlockingStrategy = BlockingStrategy::Variable;

    #[test]
    fn single_byte_zero() {
        let n = CodedNumber::decode_bytes(&[0x00], FIXED).unwrap();
        assert_eq!(n, CodedNumber::FrameNumber(U31!(0)));
    }

    #[test]
    fn single_byte_max() {
        let n = CodedNumber::decode_bytes(&[0x7F], FIXED).unwrap();
        assert_eq!(n, CodedNumber::FrameNumber(U31!(127)));
    }

    #[test]
    fn two_byte() {
        let n = CodedNumber::decode_bytes(&[0xC2, 0x80], FIXED).unwrap();
        assert_eq!(n, CodedNumber::FrameNumber(U31!(128)));
    }

    #[test]
    fn frame_number_max() {
        let n = CodedNumber::decode_bytes(&[0xFD, 0xBF, 0xBF, 0xBF, 0xBF, 0xBF], FIXED).unwrap();
        assert_eq!(n, CodedNumber::FrameNumber(U31!(0x7FFF_FFFF)));
    }

    #[test]
    fn sample_number_max() {
        // The 7-byte form carries no payload in its lead byte: 6 * 6 = 36 bits, all set.
        let bytes = [0xFE, 0xBF, 0xBF, 0xBF, 0xBF, 0xBF, 0xBF];
        let n = CodedNumber::decode_bytes(&bytes, VAR).unwrap();
        assert_eq!(n, CodedNumber::SampleNumber(U36::new(0xF_FFFF_FFFF).unwrap()));
    }

    #[test]
    fn spec_example_51_billion_samples() {
        let bytes = [0xFE, 0xAF, 0x9F, 0xB5, 0xA3, 0xB8, 0x80];
        let n = CodedNumber::decode_bytes(&bytes, VAR).unwrap();
        assert_eq!(n, CodedNumber::SampleNumber(U36!(51_000_000_000)));
    }

    #[test]
    fn err_unexpected_eof_empty() {
        assert!(matches!(
            CodedNumber::decode_bytes(&[], FIXED).unwrap_err(),
            CodedNumberError::Io(io) if io.kind() == io::ErrorKind::UnexpectedEof,
        ));
    }

    #[test]
    fn err_unexpected_eof_truncated_multibyte() {
        assert!(matches!(
            CodedNumber::decode_bytes(&[0xE0, 0x80], FIXED).unwrap_err(),
            CodedNumberError::Io(io) if io.kind() == io::ErrorKind::UnexpectedEof,
        ));
    }

    #[test]
    fn err_invalid_first_byte_continuation_pattern() {
        assert!(matches!(
            CodedNumber::decode_bytes(&[0x80], FIXED).unwrap_err(),
            CodedNumberError::InvalidContinuationByte {
                byte_index: 0,
                got: 0x80,
            },
        ));
    }

    #[test]
    fn err_invalid_first_byte_all_ones() {
        assert!(matches!(
            CodedNumber::decode_bytes(&[0xFF], FIXED).unwrap_err(),
            CodedNumberError::InvalidContinuationByte {
                byte_index: 0,
                got: 0xFF,
            }
        ));
    }

    #[test]
    fn err_invalid_continuation_byte() {
        assert!(matches!(
            CodedNumber::decode_bytes(&[0xC2, 0x00], FIXED).unwrap_err(),
            CodedNumberError::InvalidContinuationByte {
                byte_index: 1,
                got: 0x00,
            }
        ));
    }

    #[test]
    fn err_invalid_later_continuation_byte() {
        // The first continuation byte is fine; the second one is reported by its position.
        assert!(matches!(
            CodedNumber::decode_bytes(&[0xE0, 0x80, 0xC0], FIXED).unwrap_err(),
            CodedNumberError::InvalidContinuationByte {
                byte_index: 2,
                got: 0xC0,
            }
        ));
    }

    #[test]
    fn err_overlong() {
        assert!(matches!(
            CodedNumber::decode_bytes(&[0xC1, 0xBF], FIXED).unwrap_err(),
            CodedNumberError::Overlong
        ));
    }

    #[test]
    fn err_overlong_seven_byte() {
        // Decodes to 64, which fits in a single byte.
        assert!(matches!(
            CodedNumber::decode_bytes(&[0xFE, 0x80, 0x80, 0x80, 0x80, 0x81, 0x80], VAR)
                .unwrap_err(),
            CodedNumberError::Overlong
        ));
    }

    #[test]
    fn err_frame_number_out_of_range() {
        assert!(matches!(
            CodedNumber::decode_bytes(&[0xFE, 0x82, 0x80, 0x80, 0x80, 0x80, 0x80], FIXED)
                .unwrap_err(),
            CodedNumberError::OutOfRange {
                max: 31,
                value: 0x8000_0000,
            },
        ));
    }

    #[test]
    fn err_seven_byte_fixed_is_always_out_of_range() {
        assert!(matches!(
            CodedNumber::decode_bytes(&[0xFE, 0xA0, 0x80, 0x80, 0x80, 0x80, 0x80], FIXED)
                .unwrap_err(),
            CodedNumberError::OutOfRange {
                max: 31,
                value: 0x8_0000_0000,
            },
        ));
    }
}
#[cfg(test)]
mod encode_tests {
    // `super::*` already brings `BlockingStrategy` into scope.
    use super::*;
    use crate::bit_io::BitWriter;

    /// Encodes `n` into a fresh in-memory writer and returns the bytes produced.
    fn encode(n: CodedNumber) -> Vec<u8> {
        let mut w = BitWriter::new(Vec::new());
        n.encode(&mut w, ()).unwrap();
        w.into_inner()
    }

    #[test]
    fn frame_number_byte_count() {
        assert_eq!(encode(CodedNumber::FrameNumber(U31!(0))), [0x00]);
        assert_eq!(encode(CodedNumber::FrameNumber(U31!(127))), [0x7F]);
        assert_eq!(encode(CodedNumber::FrameNumber(U31!(128))), [0xC2, 0x80]);
        assert_eq!(
            encode(CodedNumber::FrameNumber(U31!(0x7FFF_FFFF))),
            [0xFD, 0xBF, 0xBF, 0xBF, 0xBF, 0xBF]
        );
    }

    #[test]
    fn sample_number_max_len() {
        assert_eq!(
            encode(CodedNumber::SampleNumber(U36!(51_000_000_000))),
            [0xFE, 0xAF, 0x9F, 0xB5, 0xA3, 0xB8, 0x80]
        );
    }

    #[test]
    fn round_trip_fixed() {
        for &n in &[0u32, 1, 127, 128, 1000, 0x7FFF_FFFF] {
            let cn = CodedNumber::FrameNumber(U31::new(n).unwrap());
            let bytes = encode(cn);
            let decoded = CodedNumber::decode_bytes(&bytes, BlockingStrategy::Fixed).unwrap();
            assert_eq!(decoded, cn, "round-trip failed for n={n}");
        }
    }

    #[test]
    fn round_trip_variable() {
        for &n in &[0u64, 1, 127, 128, 1000, 51_000_000_000, 0xF_FFFF_FFFF] {
            let cn = CodedNumber::SampleNumber(U36::new(n).unwrap());
            let bytes = encode(cn);
            let decoded = CodedNumber::decode_bytes(&bytes, BlockingStrategy::Variable).unwrap();
            assert_eq!(decoded, cn, "round-trip failed for n={n}");
        }
    }

    #[test]
    fn round_trip_every_length_boundary() {
        // (value, expected encoded length): the last value of each length and the first of
        // the next, so every lead-byte layout from 1 to 7 bytes is exercised.
        let cases = [
            (0u64, 1usize),
            (0x7F, 1),
            (0x80, 2),
            (0x7FF, 2),
            (0x800, 3),
            (0xFFFF, 3),
            (0x1_0000, 4),
            (0x1F_FFFF, 4),
            (0x20_0000, 5),
            (0x3FF_FFFF, 5),
            (0x400_0000, 6),
            (0x7FFF_FFFF, 6),
            (0x8000_0000, 7),
            (0xF_FFFF_FFFF, 7),
        ];
        for &(n, len) in &cases {
            let cn = CodedNumber::SampleNumber(U36::new(n).unwrap());
            let bytes = encode(cn);
            assert_eq!(bytes.len(), len, "wrong encoded length for n={n:#x}");
            assert!(
                bytes[1..].iter().all(|&b| is_continuation_byte(b)),
                "malformed continuation byte for n={n:#x}: {bytes:02x?}"
            );
            let decoded = CodedNumber::decode_bytes(&bytes, BlockingStrategy::Variable).unwrap();
            assert_eq!(decoded, cn, "round-trip failed for n={n:#x}");
        }
    }
}