use std::fmt::Debug;
use std::io::{self, Seek};
use std::ops::{Deref, DerefMut};
use byteorder::WriteBytesExt;
use crate::bit_io::{BitRead, BitWrite};
use crate::frame::Channels;
use crate::{Decode, Encode};
/// Errors produced while decoding or encoding a subframe.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// An I/O error from the underlying bit reader or writer.
    #[error("IO: {0}")]
    Io(#[from] io::Error),
    /// The wasted bits left no sample bits at all (effective bit depth of 0).
    #[error("Effective bit depth cannot be 0")]
    BitDepthZero,
    /// A unary-coded value did not fit the integer type it is decoded into.
    #[error("A unary encoded value was too long")]
    OverlongUnary,
    /// The leading bit of a subframe header was set; it must be 0.
    #[error("Reserved leading bit was set in subframe header")]
    ReservedBit,
    /// The 6-bit subframe type code is in a reserved range.
    #[error("Reserved subframe type")]
    ReservedType,
    /// The 2-bit residual coding method field held a reserved value (`0b10` or `0b11`).
    #[error("Reserved Rice coding method")]
    ReservedRiceCoding,
    /// The 4-bit LPC coefficient precision field was `0b1111`.
    #[error("Reserved QLP precision (0b1111)")]
    ReservedQlpPrecision,
    /// A predictor order was out of range, or larger than the samples available to it.
    #[error("Invalid predictor order")]
    InvalidPredictorOrder,
}
/// Coefficients of the fixed polynomial predictors, indexed by predictor order (0..=4).
///
/// For order `k`, `FIXED_COEFFS[k][j]` multiplies the sample `j + 1` positions
/// before the one being predicted (see `reconstruct_fixed`).
#[rustfmt::skip]
const FIXED_COEFFS: [&[i32]; 5] = [
    &[],
    &[1],
    &[2, -1],
    &[3, -3, 1],
    &[4, -6, 4, -1],
];
fn read_unary<R: BitRead>(r: &mut R) -> Result<u64, Error> {
let mut n = 0u64;
while !r.read_bit()? {
n = n.checked_add(1).ok_or(Error::OverlongUnary)?;
}
Ok(n)
}
/// Writes `n` in unary as `n` zero bits followed by a single 1 bit.
///
/// This is the exact inverse of [`read_unary`], which counts 0 bits up to and
/// including a terminating 1 bit. Full runs of eight zeros go out as one
/// 8-bit write; the remaining `n % 8` zeros are written bit by bit.
fn write_unary<W: BitWrite>(w: &mut W, n: u64) -> Result<(), Error> {
    for _ in 0..n / 8 {
        w.write_bits(0, 8)?;
    }
    for _ in 0..n % 8 {
        w.write_bit(false)?;
    }
    // Terminator bit that `read_unary` stops on.
    w.write_bit(true)?;
    Ok(())
}
/// Maps a zigzag-encoded unsigned value back to a signed one:
/// `0, 1, 2, 3, 4, …` → `0, -1, 1, -2, 2, …`.
///
/// Even inputs decode to `n / 2`. Odd inputs decode to `-(n / 2) - 1`, which in
/// two's complement is the bitwise NOT of `n >> 1`.
#[inline(always)]
fn decode_zigzag(n: u64) -> i64 {
    // `n >> 1` is at most `i64::MAX`, so the cast is lossless. `-(n & 1)` is all
    // ones for odd `n` (flipping every bit) and zero for even `n` (no-op).
    ((n >> 1) as i64) ^ -((n & 1) as i64)
}
/// Residual coding method: how many bits each partition's Rice parameter uses.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RiceMethod {
    /// 4-bit Rice parameters.
    P4,
    /// 5-bit Rice parameters.
    P5,
}
impl RiceMethod {
    /// Width in bits of the Rice parameter field.
    const fn bits(self) -> u8 {
        if matches!(self, Self::P5) {
            5
        } else {
            4
        }
    }
    /// The all-ones parameter value, which marks an escaped (unencoded) partition.
    const fn sentinel(self) -> u8 {
        (1 << self.bits()) - 1
    }
}
impl Decode<()> for RiceMethod {
type Error = Error;
fn decode<R: BitRead + io::Seek>(reader: &mut R, _opt: ()) -> Result<Self, Self::Error> {
Ok(match reader.read_bits(2)? {
0b00 => Self::P4,
0b01 => Self::P5,
0b10..=0b11 => return Err(Error::ReservedRiceCoding),
_ => unreachable!("Two bits can't encode any more information"),
})
}
}
/// The kind of a subframe, as carried in the 6-bit type field of its header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Type {
    /// A single sample value repeated for the whole block.
    Constant,
    /// Every sample stored unencoded at the effective bit depth.
    Verbatim,
    /// Fixed polynomial predictor of the given order (0..=4).
    FixedPredictor(u8),
    /// Linear predictor with stored coefficients, of the given order (1..=32).
    LinearPredictor(u8),
}
impl Decode for Type {
type Error = Error;
fn decode<R: BitRead + Seek>(reader: &mut R, _opt: ()) -> Result<Self, Self::Error> {
match reader.read_bits(6)? as u8 {
0b00_0000 => Ok(Self::Constant),
0b00_0001 => Ok(Self::Verbatim),
v @ 0b00_1000..=0b00_1100 => {
let v = v - 8;
if v > 4 {
Err(Error::InvalidPredictorOrder)
} else {
Ok(Self::FixedPredictor(v))
}
}
v @ 0b10_0000..=0b11_1111 => {
let v = v - 31;
if (1..=32).contains(&v) {
Ok(Self::LinearPredictor(v))
} else {
Err(Error::InvalidPredictorOrder)
}
}
0b00_0010..=0b00_0111 | 0b00_1101..=0b01_1111 => Err(Error::ReservedType),
_ => {
unreachable!("Six bits can not encode any more information")
}
}
}
}
impl Encode for Type {
    type Error = Error;
    /// Writes the 6-bit subframe type code.
    ///
    /// Nothing is written if the predictor order is out of range; in that case
    /// [`Error::InvalidPredictorOrder`] is returned.
    fn encode<W: BitWrite>(&self, writer: &mut W, _opt: ()) -> Result<(), Self::Error> {
        let code = match *self {
            Self::Constant => 0b0,
            Self::Verbatim => 0b1,
            Self::FixedPredictor(order) if order <= 4 => u64::from(order) + 8,
            Self::LinearPredictor(order) if (1..=32).contains(&order) => u64::from(order) + 31,
            Self::FixedPredictor(_) | Self::LinearPredictor(_) => {
                return Err(Error::InvalidPredictorOrder);
            }
        };
        writer.write_bits(code, 6)?;
        Ok(())
    }
}
/// A decoded subframe header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Header {
    /// The subframe type.
    pub kind: Type,
    /// Number of low-order bits not stored in the stream; decoded samples are
    /// shifted left by this amount.
    pub wasted_bits: u8,
}
impl Header {
    /// Bits actually stored per sample in this subframe.
    ///
    /// This is the frame bit depth, plus one for a side channel, minus the wasted
    /// bits. It saturates at 0, which callers treat as an error.
    #[inline]
    pub const fn effective_bit_depth(self, frame_bit_depth: u8, is_side_channel: bool) -> u8 {
        let side_extra: u8 = if is_side_channel { 1 } else { 0 };
        let stored = frame_bit_depth + side_extra;
        stored.saturating_sub(self.wasted_bits)
    }
}
impl Decode for Header {
    type Error = Error;
    /// Reads a subframe header.
    ///
    /// The layout is one leading bit (must be 0), the 6-bit type code, and the
    /// wasted-bits flag. When the flag is set, `wasted_bits - 1` follows in unary.
    ///
    /// # Errors
    /// - [`Error::ReservedBit`] if the leading bit is set.
    /// - [`Error::OverlongUnary`] if the wasted-bit count does not fit a `u8`.
    /// - Type-code errors and I/O errors are propagated unchanged.
    fn decode<R: BitRead + Seek>(reader: &mut R, _opt: ()) -> Result<Self, Self::Error> {
        if reader.read_bit()? {
            return Err(Error::ReservedBit);
        }
        let kind = Type::decode(reader, ())?;
        let wasted_bits = if reader.read_bit()? {
            // Propagate reader failures (e.g. EOF) as-is; only a count that is
            // too large for `u8` is reported as `OverlongUnary`.
            read_unary(reader)?
                .checked_add(1)
                .and_then(|k| u8::try_from(k).ok())
                .ok_or(Error::OverlongUnary)?
        } else {
            0
        };
        Ok(Self { kind, wasted_bits })
    }
}
impl Encode for Header {
    type Error = Error;
    /// Writes a subframe header in the layout [`Header`]'s decoder reads.
    ///
    /// The layout is a 0 leading bit, the 6-bit type code, then the wasted-bits
    /// flag. A 0 flag means no wasted bits. A 1 flag is followed by
    /// `wasted_bits - 1` in unary.
    fn encode<W: BitWrite>(&self, writer: &mut W, _opt: ()) -> Result<(), Self::Error> {
        writer.write_bit(false)?;
        self.kind.encode(writer, ())?;
        if self.wasted_bits == 0 {
            writer.write_bit(false)?;
        } else {
            // The flag must precede the unary count, or the decoder reads the
            // first count bit as the flag.
            writer.write_bit(true)?;
            write_unary(writer, u64::from(self.wasted_bits) - 1)?;
        }
        Ok(())
    }
}
/// Turns residuals back into samples in place using the fixed predictor of `order`.
///
/// The first `order` entries of `buf` are warm-up samples and are left untouched.
/// Every later entry holds a residual, which gets the prediction from the
/// preceding (already reconstructed) samples added to it with wrapping arithmetic.
///
/// Panics if `order > 4`, since there is no coefficient row for it.
pub(super) fn reconstruct_fixed(buf: &mut [i32], order: usize) {
    let coeffs = FIXED_COEFFS[order];
    for idx in order..buf.len() {
        let mut prediction: i64 = 0;
        for (back, &c) in coeffs.iter().enumerate() {
            prediction += i64::from(c) * i64::from(buf[idx - 1 - back]);
        }
        buf[idx] = buf[idx].wrapping_add(prediction as i32);
    }
}
/// Turns residuals back into samples in place using linear prediction.
///
/// The first `coefficients.len()` entries of `buf` are warm-up samples. For every
/// later entry, the dot product of the coefficients with the preceding samples is
/// shifted right by `shift`. A zero or negative `shift` shifts left by `-shift`
/// instead. The result is added to the stored residual with wrapping arithmetic.
pub(super) fn reconstruct_lpc(buf: &mut [i32], coefficients: &[i32], shift: i8) {
    let order = coefficients.len();
    for end in order..buf.len() {
        let raw = lpc_dot_product(coefficients, &buf[end - order..end]);
        let prediction = match shift {
            s if s > 0 => raw >> s,
            s => raw << -s,
        };
        buf[end] = buf[end].wrapping_add(prediction as i32);
    }
}
/// Dot product of `coefficients` with `history` read newest-first.
///
/// `coefficients[0]` pairs with the last element of `history`,
/// `coefficients[1]` with the one before it, and so on. Products are accumulated
/// in `i64`. Both slices are expected to have equal length (checked in debug builds).
#[inline(always)]
fn lpc_dot_product(coefficients: &[i32], history: &[i32]) -> i64 {
    debug_assert_eq!(history.len(), coefficients.len());
    let mut acc = 0i64;
    for (&coeff, &sample) in coefficients.iter().zip(history.iter().rev()) {
        acc += i64::from(coeff) * i64::from(sample);
    }
    acc
}
fn decode_residuals<R: BitRead + Seek>(
reader: &mut R,
buf: &mut Vec<i32>,
predictor_order: usize,
block_size: usize,
) -> Result<(), Error> {
let rice_method = RiceMethod::decode(reader, ())?;
let partition_order = reader.read_bits(4)? as u8;
let num_partitions = 1usize << partition_order;
let parameter_bits = rice_method.bits();
let sentinel = rice_method.sentinel();
for partition in 0..num_partitions {
let n = if partition_order == 0 {
block_size
.checked_sub(predictor_order)
.ok_or(Error::InvalidPredictorOrder)?
} else if partition == 0 {
(block_size >> partition_order)
.checked_sub(predictor_order)
.ok_or(Error::InvalidPredictorOrder)?
} else {
block_size >> partition_order
};
let rice_param = reader.read_bits(parameter_bits.into())? as u8;
if rice_param == sentinel {
let bit_count = reader.read_bits(5)? as u8;
for _ in 0..n {
buf.push(if bit_count == 0 {
0
} else {
reader.read_signed(bit_count)? as i32
});
}
} else {
let k = rice_param;
for _ in 0..n {
let q = read_unary(reader)?;
let lo = reader.read_bits(k.into())?;
buf.push(decode_zigzag(q << k | lo) as i32);
}
}
}
Ok(())
}
/// The decoded samples of one subframe, i.e. one channel of one frame.
#[derive(Clone)]
pub struct Block(pub Box<[i32]>);

impl Debug for Block {
    /// Prints `Block(..)`; the sample data is deliberately left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut tuple = f.debug_tuple("Block");
        tuple.finish_non_exhaustive()
    }
}

impl Deref for Block {
    type Target = [i32];
    fn deref(&self) -> &Self::Target {
        &*self.0
    }
}

impl DerefMut for Block {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut *self.0
    }
}
impl Block {
    /// Shifts every sample left by `32 - bit_depth`, so a `bit_depth`-bit sample
    /// occupies the top bits of the `i32`.
    ///
    /// Both the subtraction and the shift wrap. For `bit_depth` outside 1..=32,
    /// the effective shift is taken modulo 32 (so `bit_depth == 0` leaves
    /// samples unchanged).
    #[inline]
    pub fn scale(&mut self, bit_depth: u8) {
        let shift = 32u32.wrapping_sub(u32::from(bit_depth));
        self.0
            .iter_mut()
            .for_each(|sample| *sample = sample.wrapping_shl(shift));
    }

    /// Undoes inter-channel decorrelation in place for a stereo pair.
    ///
    /// - `LeftSideStereo`: `ch1 = ch0 - ch1` (left minus side).
    /// - `SideRightStereo`: `ch0 = ch0 + ch1` (side plus right).
    /// - `MidSideStereo`: rebuilds both channels from mid (`ch0`) and side (`ch1`),
    ///   using the low bit of side to restore the bit dropped from mid.
    ///
    /// Any other channel layout is left unchanged.
    pub fn decorrelate(ch0: &mut Self, ch1: &mut Self, channels: Channels) {
        match channels {
            Channels::LeftSideStereo => {
                for (left, side) in ch0.0.iter().zip(ch1.0.iter_mut()) {
                    *side = left.wrapping_sub(*side);
                }
            }
            Channels::SideRightStereo => {
                for (side, right) in ch0.0.iter_mut().zip(ch1.0.iter()) {
                    *side = side.wrapping_add(*right);
                }
            }
            Channels::MidSideStereo => {
                for (m, s) in ch0.0.iter_mut().zip(ch1.0.iter_mut()) {
                    let side = i64::from(*s);
                    let mid = (i64::from(*m) << 1) | (side & 1);
                    *m = ((mid + side) >> 1) as i32;
                    *s = ((mid - side) >> 1) as i32;
                }
            }
            _ => {}
        }
    }
}
impl Decode<(usize, u8, bool)> for Block {
    type Error = Error;
    /// Decodes one subframe (header plus sample data) into a block of samples.
    ///
    /// The option tuple is `(block_size, frame_bit_depth, is_side_channel)`.
    /// Samples are read at the width given by [`Header::effective_bit_depth`], then
    /// shifted left by the header's wasted bits before being returned.
    ///
    /// # Errors
    /// - [`Error::BitDepthZero`] if the effective bit depth is 0.
    /// - [`Error::ReservedQlpPrecision`] for an LPC precision field of `0b1111`.
    /// - Header, residual and reader errors are propagated.
    fn decode<R: BitRead + Seek>(
        reader: &mut R,
        (block_size, frame_bit_depth, is_side_channel): (usize, u8, bool),
    ) -> Result<Self, Self::Error> {
        let header = Header::decode(reader, ())?;
        let effective_depth = header.effective_bit_depth(frame_bit_depth, is_side_channel);
        if effective_depth == 0 {
            return Err(Error::BitDepthZero);
        }
        let mut buf: Vec<i32> = Vec::with_capacity(block_size);
        match header.kind {
            // One stored sample, repeated for the whole block.
            Type::Constant => {
                let sample = reader.read_signed(effective_depth)?;
                buf.resize(block_size, sample as i32);
            }
            // `block_size` samples stored unencoded.
            Type::Verbatim => {
                for _ in 0..block_size {
                    buf.push(reader.read_signed(effective_depth)? as i32);
                }
            }
            // `order` unencoded warm-up samples, then residuals; the fixed
            // polynomial prediction is then added back in place.
            Type::FixedPredictor(order) => {
                let order = order as usize;
                for _ in 0..order {
                    buf.push(reader.read_signed(effective_depth)? as i32);
                }
                decode_residuals(reader, &mut buf, order, block_size)?;
                reconstruct_fixed(&mut buf, order);
            }
            // The fields arrive in this bitstream order: warm-up samples,
            // coefficient precision, shift, `order` quantized coefficients,
            // then residuals.
            Type::LinearPredictor(order) => {
                let order = order as usize;
                for _ in 0..order {
                    buf.push(reader.read_signed(effective_depth)? as i32);
                }
                // 4-bit field storing (precision - 1); 0b1111 is reserved.
                let coeff_precision = {
                    let prec = reader.read_bits(4)? as u8;
                    if prec == 0b1111 {
                        return Err(Error::ReservedQlpPrecision);
                    }
                    prec + 1
                };
                // 5-bit signed shift applied to each prediction.
                let shift = reader.read_signed(5)? as i8;
                // Fixed-size scratch array; `Type::decode` only yields LPC orders 1..=32.
                let coefficients = {
                    let mut coeffs = [0i32; 32];
                    for c in &mut coeffs[..order] {
                        *c = reader.read_signed(coeff_precision)? as i32;
                    }
                    coeffs
                };
                decode_residuals(reader, &mut buf, order, block_size)?;
                reconstruct_lpc(&mut buf, &coefficients[0..order], shift);
            }
        }
        // Warm-up samples plus residuals are expected to fill the block exactly
        // (verified in debug builds only).
        debug_assert_eq!(buf.len(), block_size);
        // Restore the low-order bits that were not stored in the stream.
        if header.wasted_bits > 0 {
            for s in &mut buf {
                *s <<= header.wasted_bits;
            }
        }
        Ok(Self(buf.into_boxed_slice()))
    }
}
/// Options for encoding a [`Block`] as a subframe.
pub struct BlockEncodingParameters {
    /// Subframe header to write (type and wasted bits).
    pub header: Header,
    /// Sample bit depth of the frame, before the side-channel and wasted-bits adjustments.
    pub frame_bit_depth: u8,
    /// Whether this channel is a side channel, which is stored with one extra bit.
    pub is_side_channel: bool,
}
impl Encode<BlockEncodingParameters> for Block {
type Error = Error;
fn encode<W: BitWrite>(
&self,
writer: &mut W,
BlockEncodingParameters {
header,
frame_bit_depth,
is_side_channel,
}: BlockEncodingParameters,
) -> Result<(), Self::Error> {
let effective_bit_depth = header.effective_bit_depth(frame_bit_depth, is_side_channel);
header.encode(writer, ())?;
match header.kind {
Type::Verbatim => {
for &sample in self.iter() {
writer.write_signed(i64::from(sample), effective_bit_depth)?;
}
}
_ => unimplemented!(),
}
Ok(())
}
}
/// Iterator that decodes the subframes of one frame in order, yielding one
/// [`Block`] per channel.
pub struct BlockIter<'a, R: BitRead + Seek> {
    /// Reader the subframes are decoded from, consumed sequentially.
    reader: &'a mut R,
    /// Number of samples in each block.
    block_size: usize,
    /// Frame bit depth passed to every block decode.
    frame_bit_depth: u8,
    /// Per-channel side-channel flags; only the first `num_channels` entries are used.
    side_flags: [bool; 8],
    /// Index of the next channel to decode.
    channel_idx: usize,
    /// Number of channels to decode (at most 8).
    num_channels: usize,
}
impl<'a, R: BitRead + Seek> BlockIter<'a, R> {
    /// Creates an iterator that decodes one block per entry of `side_flags`.
    ///
    /// `side_flags[i]` marks channel `i` as a side channel. At most eight
    /// channels are tracked; any flags past the eighth are ignored.
    pub fn new(
        reader: &'a mut R,
        block_size: usize,
        frame_bit_depth: u8,
        side_flags: &[bool],
    ) -> Self {
        let mut flags = [false; 8];
        let mut num_channels = 0;
        for (slot, &flag) in flags.iter_mut().zip(side_flags) {
            *slot = flag;
            num_channels += 1;
        }
        Self {
            reader,
            block_size,
            frame_bit_depth,
            side_flags: flags,
            channel_idx: 0,
            num_channels,
        }
    }
}
impl<R: BitRead + Seek> Iterator for BlockIter<'_, R> {
    type Item = Result<Block, Error>;

    /// Decodes the next channel's block, or returns `None` once every channel
    /// has been visited.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        let idx = self.channel_idx;
        if idx < self.num_channels {
            self.channel_idx = idx + 1;
            let opts = (self.block_size, self.frame_bit_depth, self.side_flags[idx]);
            Some(Block::decode(self.reader, opts))
        } else {
            None
        }
    }

    /// Exactly the number of channels not yet visited.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let left = self.num_channels - self.channel_idx;
        (left, Some(left))
    }
}
impl<R: BitRead + Seek> ExactSizeIterator for BlockIter<'_, R> {}