use thiserror::Error;
use super::bitreader::BitReader;
use super::dequant::dequantize_block;
use super::entropy::{decode_block, EntropyError};
use super::idct::{finalize_idct_output, idct_8x8};
use super::zigzag::{inverse_scan, PROGRESSIVE_ZIGZAG};
#[derive(Debug, Error, PartialEq, Eq)]
pub enum DecodeError {
#[error("slice plane sizes overrun: {0}")]
PlaneOverrun(&'static str),
#[error("destination plane too small: needed {needed}, had {available}")]
DestinationTooSmall {
needed: usize,
available: usize,
},
#[error("entropy decode failed: {0}")]
Entropy(#[from] EntropyError),
}
/// Borrowed views of the per-component compressed payloads of one slice.
///
/// When produced by [`split_slice_planes`], the fields are consecutive
/// sub-slices of a single payload buffer, in luma, Cb, Cr, alpha order.
#[derive(Clone, Copy, Debug)]
pub struct SliceData<'a> {
    /// Compressed luma (Y) data.
    pub luma: &'a [u8],
    /// Compressed blue-difference chroma (Cb) data.
    pub cb: &'a [u8],
    /// Compressed red-difference chroma (Cr) data.
    pub cr: &'a [u8],
    /// Compressed alpha data, if the slice carries an alpha plane.
    pub alpha: Option<&'a [u8]>,
}
/// Splits a slice payload into its luma, Cb, Cr and optional alpha parts.
///
/// The parts are laid out back to back in that order, with the given byte
/// sizes. `data` must be exactly as long as the sum of those sizes.
///
/// # Errors
///
/// Returns [`DecodeError::PlaneOverrun`] if the sizes add up to anything other
/// than `data.len()`, whether too long or too short.
pub fn split_slice_planes(
    data: &[u8],
    luma_size: u16,
    cb_size: u16,
    cr_size: u16,
    alpha_size: Option<u16>,
) -> Result<SliceData<'_>, DecodeError> {
    let [luma_len, cb_len, cr_len] = [luma_size, cb_size, cr_size].map(usize::from);
    let alpha_len = alpha_size.map(usize::from);
    let expected = luma_len + cb_len + cr_len + alpha_len.unwrap_or(0);
    if data.len() != expected {
        return Err(DecodeError::PlaneOverrun(
            "sum of plane sizes != slice data length",
        ));
    }
    // The length check above guarantees every split point is in range.
    let (luma, rest) = data.split_at(luma_len);
    let (cb, rest) = rest.split_at(cb_len);
    let (cr, rest) = rest.split_at(cr_len);
    // Whatever remains is exactly the alpha plane (empty when there is none).
    let alpha = alpha_len.map(|len| &rest[..len]);
    Ok(SliceData { luma, cb, cr, alpha })
}
/// Checks that a caller-supplied destination plane can receive one decoded slice.
///
/// A decoded slice covers 16 rows of `row_width` samples. Consecutive rows
/// start `stride` elements apart. The check requires
/// `available >= stride * 16` and `stride >= row_width`. A narrower stride
/// would make decoded rows overlap, and the last row would run past the end of
/// the buffer, causing a panic in the block copy.
///
/// # Errors
///
/// Returns [`DecodeError::DestinationTooSmall`] in two cases:
/// * The buffer is too short. Then `needed` is `stride * 16` (saturating) and
///   `available` is the buffer length.
/// * The stride is too narrow. Then `needed` is `row_width` and `available` is
///   `stride`.
fn check_destination_plane(
    available: usize,
    stride: usize,
    row_width: usize,
) -> Result<(), DecodeError> {
    const SLICE_ROWS: usize = 16;
    // Saturate so an enormous stride is rejected instead of wrapping past the check.
    let needed = stride.saturating_mul(SLICE_ROWS);
    if available < needed {
        return Err(DecodeError::DestinationTooSmall { needed, available });
    }
    if stride < row_width {
        return Err(DecodeError::DestinationTooSmall {
            needed: row_width,
            available: stride,
        });
    }
    Ok(())
}

/// Decodes the luma and 4:2:2 chroma planes of one slice into the destination
/// buffers.
///
/// Each plane is decoded as `slice_mb_width` macroblocks.
/// * Luma uses `quant_matrix_luma` and occupies 16 samples per macroblock
///   horizontally in `dst_luma`.
/// * Cb and Cr use `quant_matrix_chroma` and occupy 8 samples per macroblock
///   horizontally in `dst_cb` / `dst_cr`.
///
/// Every plane covers 16 rows, spaced by its own `*_stride`. `slice.alpha` is
/// not decoded.
///
/// # Errors
///
/// * [`DecodeError::DestinationTooSmall`] if any destination is shorter than
///   `16 * stride`, or its stride is narrower than the decoded row. All three
///   planes are checked before anything is written.
/// * Errors from decoding a block of the compressed data are propagated.
///   Planes decoded before the failure have already been written.
// The eleven arguments are one (buffer, stride) pair per output plane plus the
// decode parameters; bundling them would change the public signature.
#[allow(clippy::too_many_arguments)]
pub fn decode_slice_to_yuv422(
    slice: SliceData<'_>,
    quant_matrix_luma: &[u8; 64],
    quant_matrix_chroma: &[u8; 64],
    qscale: u8,
    slice_mb_width: usize,
    dst_luma: &mut [u16],
    dst_luma_stride: usize,
    dst_cb: &mut [u16],
    dst_cb_stride: usize,
    dst_cr: &mut [u16],
    dst_cr_stride: usize,
) -> Result<(), DecodeError> {
    let luma_width = slice_mb_width.saturating_mul(16);
    // 4:2:2 chroma is horizontally subsampled: 8 samples per macroblock.
    let chroma_width = slice_mb_width.saturating_mul(8);
    // Validate every plane up front so nothing is written if any is unusable.
    // Cb and Cr are each checked against their *own* stride.
    check_destination_plane(dst_luma.len(), dst_luma_stride, luma_width)?;
    check_destination_plane(dst_cb.len(), dst_cb_stride, chroma_width)?;
    check_destination_plane(dst_cr.len(), dst_cr_stride, chroma_width)?;
    decode_plane(
        slice.luma,
        quant_matrix_luma,
        qscale,
        slice_mb_width,
        Plane::Luma422,
        dst_luma,
        dst_luma_stride,
    )?;
    decode_plane(
        slice.cb,
        quant_matrix_chroma,
        qscale,
        slice_mb_width,
        Plane::Chroma422,
        dst_cb,
        dst_cb_stride,
    )?;
    decode_plane(
        slice.cr,
        quant_matrix_chroma,
        qscale,
        slice_mb_width,
        Plane::Chroma422,
        dst_cr,
        dst_cr_stride,
    )?;
    Ok(())
}

/// Decodes the luma and full-resolution (4:4:4) chroma planes of one slice into
/// the destination buffers.
///
/// All three planes are decoded as `slice_mb_width` macroblocks, each
/// occupying 16 samples horizontally. Luma uses `quant_matrix_luma`; Cb and Cr
/// use `quant_matrix_chroma`. Every plane covers 16 rows, spaced by its own
/// `*_stride`. `slice.alpha` is not decoded.
///
/// # Errors
///
/// * [`DecodeError::DestinationTooSmall`] for the first destination that is
///   shorter than `16 * stride`, or whose stride is narrower than the decoded
///   row. All three planes are checked before anything is written.
/// * Errors from decoding a block of the compressed data are propagated.
///   Planes decoded before the failure have already been written.
// The eleven arguments are one (buffer, stride) pair per output plane plus the
// decode parameters; bundling them would change the public signature.
#[allow(clippy::too_many_arguments)]
pub fn decode_slice_to_yuv444(
    slice: SliceData<'_>,
    quant_matrix_luma: &[u8; 64],
    quant_matrix_chroma: &[u8; 64],
    qscale: u8,
    slice_mb_width: usize,
    dst_luma: &mut [u16],
    dst_luma_stride: usize,
    dst_cb: &mut [u16],
    dst_cb_stride: usize,
    dst_cr: &mut [u16],
    dst_cr_stride: usize,
) -> Result<(), DecodeError> {
    // In 4:4:4 every plane is full width: 16 samples per macroblock.
    let row_width = slice_mb_width.saturating_mul(16);
    check_destination_plane(dst_luma.len(), dst_luma_stride, row_width)?;
    check_destination_plane(dst_cb.len(), dst_cb_stride, row_width)?;
    check_destination_plane(dst_cr.len(), dst_cr_stride, row_width)?;
    // Luma has the same 2x2 block layout in 4:2:2 and 4:4:4.
    decode_plane(
        slice.luma,
        quant_matrix_luma,
        qscale,
        slice_mb_width,
        Plane::Luma422,
        dst_luma,
        dst_luma_stride,
    )?;
    decode_plane(
        slice.cb,
        quant_matrix_chroma,
        qscale,
        slice_mb_width,
        Plane::Chroma444,
        dst_cb,
        dst_cb_stride,
    )?;
    decode_plane(
        slice.cr,
        quant_matrix_chroma,
        qscale,
        slice_mb_width,
        Plane::Chroma444,
        dst_cr,
        dst_cr_stride,
    )?;
    Ok(())
}
/// Macroblock layout of one colour component within a slice.
///
/// Every layout spans 16 sample rows per macroblock. The layouts differ in how
/// many 8x8 blocks a macroblock holds and where each block lands.
#[derive(Clone, Copy)]
enum Plane {
    /// Luma: four 8x8 blocks in a 2x2 grid (16x16 samples per macroblock).
    Luma422,
    /// 4:2:2 chroma: two 8x8 blocks stacked vertically (8x16 samples).
    Chroma422,
    /// 4:4:4 chroma: same 2x2 grid as luma.
    Chroma444,
}

impl Plane {
    /// Number of 8x8 blocks coded per macroblock for this layout.
    fn blocks_per_mb(self) -> usize {
        match self {
            Self::Chroma422 => 2,
            Self::Luma422 | Self::Chroma444 => 4,
        }
    }

    /// Returns the element offset of the top-left sample of block `block_in_mb`
    /// of macroblock `mb_x`, in a plane whose rows are `stride` elements apart.
    fn block_offset(self, mb_x: usize, block_in_mb: usize, stride: usize) -> usize {
        let (row, col) = match self {
            // Blocks 0 and 1 on the top row, 2 and 3 on the bottom row.
            Self::Luma422 | Self::Chroma444 => (
                8 * (block_in_mb / 2),
                16 * mb_x + 8 * (block_in_mb % 2),
            ),
            // A single 8-wide column with blocks stacked top to bottom.
            Self::Chroma422 => (8 * block_in_mb, 8 * mb_x),
        };
        stride * row + col
    }
}
/// Reconstructs every 8x8 block of one plane of a slice and writes the
/// resulting samples into `dst`.
///
/// Blocks are read from `compressed` in macroblock order: all blocks of
/// macroblock 0, then macroblock 1, and so on. Each block goes through
/// entropy decoding, inverse scan, dequantization with `matrix`/`qscale`, and
/// an 8x8 inverse DCT. A running DC value starts at zero for the plane and is
/// threaded through every `decode_block` call.
///
/// # Errors
///
/// Propagates any failure from `decode_block`. Blocks decoded before the
/// failure have already been written.
///
/// # Panics
///
/// Panics (slice indexing in `blit_8x8_to_plane`) if `dst`/`stride` cannot
/// hold the decoded blocks.
fn decode_plane(
    compressed: &[u8],
    matrix: &[u8; 64],
    qscale: u8,
    slice_mb_width: usize,
    plane: Plane,
    dst: &mut [u16],
    stride: usize,
) -> Result<(), DecodeError> {
    let mut bits = BitReader::new(compressed);
    let mut dc_state: i32 = 0;
    let per_mb = plane.blocks_per_mb();
    // (macroblock, block-within-macroblock) pairs in bitstream order.
    let block_positions =
        (0..slice_mb_width).flat_map(|mb| (0..per_mb).map(move |blk| (mb, blk)));
    for (mb, blk) in block_positions {
        let (coeffs, next_dc) = decode_block(&mut bits, dc_state)?;
        dc_state = next_dc;
        let dequantized =
            dequantize_block(&inverse_scan(&coeffs, &PROGRESSIVE_ZIGZAG), matrix, qscale);
        let pixels = finalize_idct_output(&idct_8x8(&dequantized));
        let offset = plane.block_offset(mb, blk, stride);
        blit_8x8_to_plane(&pixels, dst, stride, offset);
    }
    Ok(())
}
/// Copies a row-major 8x8 block of samples into `dst`.
///
/// The block's top-left sample lands at index `dst_offset`, and consecutive
/// block rows are written `stride` elements apart.
///
/// # Panics
///
/// Panics if any destination row `dst_offset + r * stride .. + 8` (for `r` in
/// `0..8`) is out of bounds for `dst`.
fn blit_8x8_to_plane(samples: &[u16; 64], dst: &mut [u16], stride: usize, dst_offset: usize) {
    for (line, src) in samples.chunks_exact(8).enumerate() {
        let start = dst_offset + line * stride;
        dst[start..start + 8].copy_from_slice(src);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A `SliceData` whose planes all alias one dummy payload. The
    /// destination-size tests fail validation before any data is decoded, so
    /// the payload contents never matter.
    fn dummy_slice(payload: &[u8]) -> SliceData<'_> {
        SliceData {
            luma: payload,
            cb: payload,
            cr: payload,
            alpha: None,
        }
    }

    #[test]
    fn split_planes_no_alpha() {
        let data: Vec<u8> = (0u8..30).collect();
        let s = split_slice_planes(&data, 10, 8, 12, None).unwrap();
        // Planes are consecutive, in luma/Cb/Cr order.
        assert_eq!(s.luma, &data[..10]);
        assert_eq!(s.cb, &data[10..18]);
        assert_eq!(s.cr, &data[18..30]);
        assert!(s.alpha.is_none());
    }

    #[test]
    fn split_planes_with_alpha() {
        let data: Vec<u8> = (0u8..40).collect();
        let s = split_slice_planes(&data, 10, 8, 12, Some(10)).unwrap();
        assert_eq!(s.cr, &data[18..30]);
        assert_eq!(s.alpha, Some(&data[30..40]));
    }

    #[test]
    fn split_planes_size_mismatch_errors() {
        let data = [0u8; 30];
        // Sizes summing past the payload length are rejected...
        assert!(matches!(
            split_slice_planes(&data, 10, 8, 13, None).unwrap_err(),
            DecodeError::PlaneOverrun(_)
        ));
        // ...as are sizes falling short of it...
        assert!(matches!(
            split_slice_planes(&data, 10, 8, 11, None).unwrap_err(),
            DecodeError::PlaneOverrun(_)
        ));
        // ...and the alpha size counts toward the total.
        assert!(matches!(
            split_slice_planes(&data, 10, 8, 12, Some(1)).unwrap_err(),
            DecodeError::PlaneOverrun(_)
        ));
    }

    // The two "end_to_end" tests below exercise dequantize -> IDCT -> finalize
    // on a single block; entropy decoding and plane placement are not involved.
    #[test]
    fn end_to_end_zero_input_produces_midgrey_frame() {
        let zero_coeffs = [0i32; 64];
        let m = [4u8; 64];
        let dequantized = dequantize_block(&zero_coeffs, &m, 4);
        assert!(dequantized.iter().all(|&v| v == 0));
        let spatial = idct_8x8(&dequantized);
        let samples = finalize_idct_output(&spatial);
        assert!(samples.iter().all(|&v| v == 512));
    }

    #[test]
    fn end_to_end_dc_only_produces_uniform_block() {
        let mut coeffs = [0i32; 64];
        coeffs[0] = 100;
        let m = [4u8; 64];
        let dequantized = dequantize_block(&coeffs, &m, 4);
        let spatial = idct_8x8(&dequantized);
        let samples = finalize_idct_output(&spatial);
        let first = samples[0];
        assert!(
            samples
                .iter()
                .all(|&v| (v as i32 - first as i32).abs() <= 1),
            "DC-only block should be uniform; got {samples:?}"
        );
        assert!(first > 512);
    }

    #[test]
    fn blit_writes_correct_8x8_region() {
        let samples: [u16; 64] = std::array::from_fn(|i| i as u16);
        let mut dst = vec![0u16; 16 * 16];
        let stride = 16;
        blit_8x8_to_plane(&samples, &mut dst, stride, 0);
        for row in 0..8 {
            for col in 0..8 {
                assert_eq!(dst[row * stride + col], (row * 8 + col) as u16);
            }
        }
        // The neighbouring 8 columns are untouched.
        for row in 0..8 {
            for col in 8..16 {
                assert_eq!(dst[row * stride + col], 0);
            }
        }
    }

    #[test]
    fn block_offset_luma_2x2_arrangement() {
        assert_eq!(Plane::Luma422.blocks_per_mb(), 4);
        assert_eq!(Plane::Luma422.block_offset(0, 0, 16), 0);
        assert_eq!(Plane::Luma422.block_offset(0, 1, 16), 8);
        assert_eq!(Plane::Luma422.block_offset(0, 2, 16), 128);
        assert_eq!(Plane::Luma422.block_offset(0, 3, 16), 136);
        assert_eq!(Plane::Luma422.block_offset(1, 0, 16), 16);
    }

    #[test]
    fn block_offset_chroma_vertical_arrangement() {
        let stride = 64;
        assert_eq!(Plane::Chroma422.blocks_per_mb(), 2);
        assert_eq!(Plane::Chroma422.block_offset(0, 0, stride), 0);
        assert_eq!(Plane::Chroma422.block_offset(0, 1, stride), 8 * stride);
        assert_eq!(Plane::Chroma422.block_offset(1, 0, stride), 8);
    }

    #[test]
    fn block_offset_chroma444_matches_luma_layout() {
        let stride = 32;
        assert_eq!(Plane::Chroma444.block_offset(0, 0, stride), 0);
        assert_eq!(Plane::Chroma444.block_offset(0, 1, stride), 8);
        assert_eq!(Plane::Chroma444.block_offset(0, 2, stride), 8 * stride);
        assert_eq!(Plane::Chroma444.block_offset(0, 3, stride), 8 * stride + 8);
        assert_eq!(Plane::Chroma444.block_offset(1, 0, stride), 16);
        assert_eq!(Plane::Chroma444.blocks_per_mb(), 4);
    }

    #[test]
    fn decode_slice_to_yuv444_destination_too_small_errors() {
        let dummy = [0u8; 8];
        let mut y = vec![0u16; 8 * 16 * 16];
        let mut cb = vec![0u16; 100];
        let mut cr = vec![0u16; 8 * 16 * 16];
        let m = [4u8; 64];
        let err = decode_slice_to_yuv444(
            dummy_slice(&dummy),
            &m,
            &m,
            4,
            8,
            &mut y,
            128,
            &mut cb,
            128,
            &mut cr,
            128,
        )
        .unwrap_err();
        assert!(matches!(err, DecodeError::DestinationTooSmall { .. }));
    }

    #[test]
    fn decode_slice_to_yuv422_destination_too_small_errors() {
        let dummy = [0u8; 8];
        let mut y = vec![0u16; 100];
        let mut cb = vec![0u16; 64 * 16];
        let mut cr = vec![0u16; 64 * 16];
        let m = [4u8; 64];
        let err = decode_slice_to_yuv422(
            dummy_slice(&dummy),
            &m,
            &m,
            4,
            8,
            &mut y,
            128,
            &mut cb,
            64,
            &mut cr,
            64,
        )
        .unwrap_err();
        assert!(matches!(err, DecodeError::DestinationTooSmall { .. }));
    }

    #[test]
    fn decode_slice_to_yuv422_short_cr_plane_errors() {
        let dummy = [0u8; 8];
        let mut y = vec![0u16; 128 * 16];
        let mut cb = vec![0u16; 64 * 16];
        // Cb is large enough; only Cr is short.
        let mut cr = vec![0u16; 100];
        let m = [4u8; 64];
        let err = decode_slice_to_yuv422(
            dummy_slice(&dummy),
            &m,
            &m,
            4,
            8,
            &mut y,
            128,
            &mut cb,
            64,
            &mut cr,
            64,
        )
        .unwrap_err();
        assert!(matches!(err, DecodeError::DestinationTooSmall { .. }));
    }
}