use oxideav_core::bits::BitReader;
use oxideav_core::{Error, Result};
use crate::block::decode_inter_block;
use crate::tables::{
decode_vlc, MbaSym, MtypeInfo, MvdSym, Prediction, CBP_TABLE, MBA_TABLE, MTYPE_TABLE, MVD_TABLE,
};
/// A decoded frame stored as three planar sample buffers.
///
/// The chroma planes are half the luma size in each direction, and every
/// plane is padded up to a whole number of 16x16 macroblocks. As a result,
/// `y_stride` / `c_stride` (and the row counts) can exceed the visible
/// `width` / `height`.
#[derive(Clone)]
pub struct Picture {
    /// Visible luma width in pixels.
    pub width: usize,
    /// Visible luma height in pixels.
    pub height: usize,
    /// Macroblock columns: `width` rounded up to a multiple of 16, over 16.
    pub mb_width: usize,
    /// Macroblock rows: `height` rounded up to a multiple of 16, over 16.
    pub mb_height: usize,
    /// Luma samples, `y_stride` bytes per row.
    pub y: Vec<u8>,
    /// Cb samples, `c_stride` bytes per row.
    pub cb: Vec<u8>,
    /// Cr samples, `c_stride` bytes per row.
    pub cr: Vec<u8>,
    /// Row pitch of `y` in bytes (`mb_width * 16`).
    pub y_stride: usize,
    /// Row pitch of `cb` and `cr` in bytes (`mb_width * 8`).
    pub c_stride: usize,
}
impl Picture {
    /// Allocates a zero-filled picture of `width` x `height` luma pixels.
    ///
    /// The planes are padded to whole macroblocks.
    pub fn new(width: usize, height: usize) -> Self {
        let mb_width = width.div_ceil(16);
        let mb_height = height.div_ceil(16);
        let (y_stride, c_stride) = (mb_width * 16, mb_width * 8);
        let luma_len = y_stride * (mb_height * 16);
        let chroma_len = c_stride * (mb_height * 8);
        Self {
            width,
            height,
            mb_width,
            mb_height,
            y: vec![0; luma_len],
            cb: vec![0; chroma_len],
            cr: vec![0; chroma_len],
            y_stride,
            c_stride,
        }
    }
}
/// State carried from one decoded macroblock to the next.
///
/// It is used when predicting the next macroblock's motion vector.
#[derive(Clone, Copy, Debug)]
pub struct MbContext {
    /// Motion vector of the previously decoded macroblock (luma units).
    pub mv: (i32, i32),
    /// Whether the previous macroblock was motion compensated.
    pub prev_was_mc: bool,
    /// Address of the previously decoded macroblock; 0 means "none".
    pub prev_mba: u8,
}
impl MbContext {
    /// Returns a cleared context.
    ///
    /// The context has a zero vector, no motion-compensated predecessor, and
    /// `prev_mba == 0`.
    pub fn reset() -> Self {
        let mv = (0, 0);
        Self {
            mv,
            prev_was_mc: false,
            prev_mba: 0,
        }
    }
}
/// Builds a picture of the given size with every sample set to 128.
///
/// `decode_macroblock` substitutes it as the inter-prediction reference when
/// no reference picture is supplied.
fn grey_reference(width: usize, height: usize) -> Picture {
    let mut pic = Picture::new(width, height);
    for plane in [&mut pic.y, &mut pic.cb, &mut pic.cr] {
        plane.fill(128);
    }
    pic
}
/// Derives a chroma motion-vector component from a luma component.
///
/// The value is halved and its magnitude truncated toward zero, so
/// `3 -> 1` and `-3 -> -1`.
pub fn luma_to_chroma_mv(v: i32) -> i32 {
    // Halve the magnitude, then restore the sign. `unsigned_abs` avoids
    // overflow for `i32::MIN`.
    let half = (v.unsigned_abs() / 2) as i32;
    if v < 0 {
        -half
    } else {
        half
    }
}
/// Rebuilds one motion-vector component from its predictor and an MVD symbol.
///
/// An MVD codeword stands for two candidate differences (`sym.a`, `sym.b`).
/// The first is chosen when `predictor + a` lies in `-15..=15`. Otherwise the
/// second is chosen if `predictor + b` does. If neither is in range, the
/// result falls back to `predictor + a`.
pub(crate) fn reconstruct_mv(predictor: i32, sym: MvdSym) -> i32 {
    let first = predictor + sym.a as i32;
    let second = predictor + sym.b as i32;
    let in_range = |v: i32| (-15..=15).contains(&v);
    if !in_range(first) && in_range(second) {
        second
    } else {
        first
    }
}
/// Chooses the motion-vector predictor for macroblock `mba`.
///
/// The previous macroblock's vector is reused only when all three conditions
/// hold:
/// - `mba` does not start a row (1, 12 or 23);
/// - `mba` directly follows a real previous macroblock (`prev_mba != 0`);
/// - that macroblock was motion compensated.
///
/// In every other case the predictor is zero.
pub(crate) fn mvd_predictor(
    mba: u8,
    prev_mba: u8,
    prev_was_mc: bool,
    prev_mv: (i32, i32),
) -> (i32, i32) {
    let starts_row = matches!(mba, 1 | 12 | 23);
    let follows_prev = prev_mba != 0 && prev_mba + 1 == mba;
    match (starts_row, follows_prev, prev_was_mc) {
        (false, true, true) => prev_mv,
        _ => (0, 0),
    }
}
#[allow(clippy::too_many_arguments)]
pub fn decode_macroblock(
br: &mut BitReader<'_>,
mba: u8,
gob_x: usize,
gob_y: usize,
quant: &mut u32,
ctx: &mut MbContext,
pic: &mut Picture,
reference: Option<&Picture>,
) -> Result<()> {
let idx = (mba - 1) as usize;
let mb_col = idx % 11;
let mb_row = idx / 11;
let luma_x = gob_x + mb_col * 16;
let luma_y = gob_y + mb_row * 16;
let mtype: MtypeInfo = decode_vlc(br, MTYPE_TABLE)?;
if mtype.mquant {
let q = br.read_u32(5)?;
if q == 0 {
return Err(Error::invalid("h261 MB: MQUANT == 0"));
}
*quant = q;
}
let mut mv = (0i32, 0i32);
if mtype.mvd {
let pred = mvd_predictor(mba, ctx.prev_mba, ctx.prev_was_mc, ctx.mv);
let sym_x: MvdSym = decode_vlc(br, MVD_TABLE)?;
let sym_y: MvdSym = decode_vlc(br, MVD_TABLE)?;
let mx = reconstruct_mv(pred.0, sym_x);
let my = reconstruct_mv(pred.1, sym_y);
mv = (mx, my);
}
let cbp: u8 = if mtype.cbp {
decode_vlc(br, CBP_TABLE)?
} else if mtype.prediction == Prediction::Intra {
0b111111
} else {
0
};
let block_coded = [
(cbp >> 5) & 1 != 0, (cbp >> 4) & 1 != 0, (cbp >> 3) & 1 != 0, (cbp >> 2) & 1 != 0, (cbp >> 1) & 1 != 0, cbp & 1 != 0, ];
if mtype.prediction == Prediction::Intra {
for i in 0..6 {
let _ = block_coded[i];
let mut out = [0u8; 64];
crate::block::decode_intra_block(br, *quant, &mut out)?;
write_block(pic, i, luma_x, luma_y, &out);
}
ctx.mv = (0, 0);
ctx.prev_was_mc = false;
ctx.prev_mba = mba;
return Ok(());
}
let fallback_ref;
let reference_ref: &Picture = match reference {
Some(r) => r,
None => {
fallback_ref = grey_reference(pic.width, pic.height);
&fallback_ref
}
};
let reference = reference_ref;
let (mvx, mvy) = mv;
let cmvx = luma_to_chroma_mv(mvx);
let cmvy = luma_to_chroma_mv(mvy);
for i in 0..4usize {
let (sub_x, sub_y) = match i {
0 => (0, 0),
1 => (8, 0),
2 => (0, 8),
3 => (8, 8),
_ => unreachable!(),
};
let bx = (luma_x + sub_x) as i32;
let by = (luma_y + sub_y) as i32;
let mut pred = [0u8; 64];
copy_block_integer(
&reference.y,
reference.y_stride,
reference.y_stride as i32,
(reference.y.len() / reference.y_stride) as i32,
bx,
by,
mvx,
mvy,
&mut pred,
);
if mtype.filter {
pred = apply_loop_filter(&pred);
}
if block_coded[i] && mtype.tcoeff {
let mut resid = [0i32; 64];
decode_inter_block(br, *quant, &mut resid)?;
let mut out = [0u8; 64];
for j in 0..64 {
out[j] = (pred[j] as i32 + resid[j]).clamp(0, 255) as u8;
}
write_block(pic, i, luma_x, luma_y, &out);
} else {
write_block(pic, i, luma_x, luma_y, &pred);
}
}
for ci in 0..2usize {
let (ref_plane, ref_stride) = if ci == 0 {
(&reference.cb, reference.c_stride)
} else {
(&reference.cr, reference.c_stride)
};
let ref_h = (ref_plane.len() / ref_stride) as i32;
let cx = (luma_x / 2) as i32;
let cy = (luma_y / 2) as i32;
let mut pred = [0u8; 64];
copy_block_integer(
ref_plane,
ref_stride,
ref_stride as i32,
ref_h,
cx,
cy,
cmvx,
cmvy,
&mut pred,
);
if mtype.filter {
pred = apply_loop_filter(&pred);
}
let block_i = 4 + ci;
if block_coded[block_i] && mtype.tcoeff {
let mut resid = [0i32; 64];
decode_inter_block(br, *quant, &mut resid)?;
let mut out = [0u8; 64];
for j in 0..64 {
out[j] = (pred[j] as i32 + resid[j]).clamp(0, 255) as u8;
}
write_block(pic, block_i, luma_x, luma_y, &out);
} else {
write_block(pic, block_i, luma_x, luma_y, &pred);
}
}
ctx.mv = mv;
ctx.prev_was_mc = matches!(
mtype.prediction,
Prediction::InterMc | Prediction::InterMcFil
);
ctx.prev_mba = mba;
Ok(())
}
/// Fetches an 8x8 integer-pel prediction block from `ref_plane`.
///
/// The block's top-left sample sits at (`bx + mvx`, `by + mvy`). Each
/// coordinate is clamped into `0..ref_w` horizontally and `0..ref_h`
/// vertically, so vectors reaching past the plane edges replicate the border
/// samples. `ref_stride` is the distance, in elements, between rows of
/// `ref_plane`.
///
/// # Panics
/// Panics if `ref_w` or `ref_h` is not positive, or if a clamped coordinate
/// indexes past the end of `ref_plane`.
// Position, vector and plane geometry are all independent inputs.
#[allow(clippy::too_many_arguments)]
pub fn copy_block_integer(
    ref_plane: &[u8],
    ref_stride: usize,
    ref_w: i32,
    ref_h: i32,
    bx: i32,
    by: i32,
    mvx: i32,
    mvy: i32,
    out: &mut [u8; 64],
) {
    let (origin_x, origin_y) = (bx + mvx, by + mvy);
    for (row, dst) in (0i32..).zip(out.chunks_exact_mut(8)) {
        let y = (origin_y + row).clamp(0, ref_h - 1) as usize;
        let line = y * ref_stride;
        for (col, px) in (0i32..).zip(dst.iter_mut()) {
            let x = (origin_x + col).clamp(0, ref_w - 1) as usize;
            *px = ref_plane[line + x];
        }
    }
}
/// Applies the H.261 in-loop filter to an 8x8 prediction block.
///
/// The filter is separable, with taps 1/4, 1/2, 1/4 in each direction. On the
/// outer rows and columns, where a neighbour would fall outside the block, the
/// taps become 0, 1, 0, so the four corner samples pass through unchanged.
///
/// H.261 requires full arithmetic precision through both passes and a single
/// rounding at the 2-D output, with halves rounded up. Rounding after the
/// first pass as well can be off by one versus the encoder's reconstruction,
/// which causes drift.
pub fn apply_loop_filter(src: &[u8; 64]) -> [u8; 64] {
    // Horizontal pass, kept at 4x scale with no rounding.
    let mut h = [0i32; 64];
    for (dst, row) in h.chunks_exact_mut(8).zip(src.chunks_exact(8)) {
        for i in 0..8 {
            let centre = i32::from(row[i]);
            dst[i] = if i == 0 || i == 7 {
                4 * centre
            } else {
                i32::from(row[i - 1]) + 2 * centre + i32::from(row[i + 1])
            };
        }
    }
    // Vertical pass brings the total scale to 16x. Round once: +8 then >>4
    // rounds halves up.
    let mut out = [0u8; 64];
    for j in 0..8 {
        for i in 0..8 {
            let centre = h[j * 8 + i];
            let acc = if j == 0 || j == 7 {
                4 * centre
            } else {
                h[(j - 1) * 8 + i] + 2 * centre + h[(j + 1) * 8 + i]
            };
            out[j * 8 + i] = ((acc + 8) >> 4).clamp(0, 255) as u8;
        }
    }
    out
}
/// Stores an 8x8 block into `pic` at the slot of block `block_idx`.
///
/// The slot belongs to the macroblock whose top-left luma sample is
/// (`luma_x`, `luma_y`). Indices 0..=3 are the luma quadrants in raster
/// order, 4 is Cb and 5 is Cr; the chroma position is the luma position
/// halved.
///
/// # Panics
/// Panics if `block_idx > 5` or the block extends past the end of the plane.
fn write_block(pic: &mut Picture, block_idx: usize, luma_x: usize, luma_y: usize, out: &[u8; 64]) {
    let (plane, stride, px, py): (&mut [u8], usize, usize, usize) = match block_idx {
        0..=3 => (
            pic.y.as_mut_slice(),
            pic.y_stride,
            luma_x + 8 * (block_idx % 2),
            luma_y + 8 * (block_idx / 2),
        ),
        4 => (pic.cb.as_mut_slice(), pic.c_stride, luma_x / 2, luma_y / 2),
        5 => (pic.cr.as_mut_slice(), pic.c_stride, luma_x / 2, luma_y / 2),
        _ => unreachable!(),
    };
    for (row, src) in out.chunks_exact(8).enumerate() {
        let start = (py + row) * stride + px;
        plane[start..start + 8].copy_from_slice(src);
    }
}
pub fn decode_mba_diff(br: &mut BitReader<'_>) -> Result<Option<u8>> {
loop {
let remaining = br.bits_remaining();
if remaining == 0 {
return Ok(None);
}
if remaining >= 16 {
let peek = br.peek_u32(16)?;
if peek == 0x0001 {
return Ok(None);
}
}
match decode_vlc(br, MBA_TABLE) {
Ok(MbaSym::Diff(d)) => return Ok(Some(d)),
Ok(MbaSym::Stuffing) => continue,
Err(_) => return Ok(None),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn chroma_mv_truncate_toward_zero() {
        let cases = [
            (0, 0),
            (1, 0),
            (2, 1),
            (3, 1),
            (-1, 0),
            (-2, -1),
            (-3, -1),
            (15, 7),
            (-15, -7),
        ];
        for (luma, chroma) in cases {
            assert_eq!(luma_to_chroma_mv(luma), chroma, "luma mv {luma}");
        }
    }

    #[test]
    fn mv_paired_selection_in_range() {
        // Only the first candidate (10 - 2 = 8) is in range.
        assert_eq!(reconstruct_mv(10, MvdSym { a: -2, b: 30 }), 8);
        // Only the second candidate (10 - 22 = -12) is in range.
        assert_eq!(reconstruct_mv(10, MvdSym { a: 10, b: -22 }), -12);
    }

    #[test]
    fn mvd_predictor_uses_prev_mv_mid_row() {
        assert_eq!(mvd_predictor(6, 5, true, (3, -4)), (3, -4));
    }

    #[test]
    fn mvd_predictor_resets_at_row_starts() {
        for (mba, prev, mc, prev_mv, want) in [
            (1, 0, false, (9, 9), (0, 0)),
            (12, 11, true, (9, 9), (0, 0)),
            (23, 22, true, (-7, 5), (0, 0)),
            (13, 12, true, (9, 9), (9, 9)),
        ] {
            assert_eq!(mvd_predictor(mba, prev, mc, prev_mv), want, "mba {mba}");
        }
    }

    #[test]
    fn mvd_predictor_resets_on_discontinuity_and_non_mc() {
        for (mba, prev, mc) in [(7, 5, true), (6, 5, false), (5, 0, true)] {
            assert_eq!(mvd_predictor(mba, prev, mc, (2, 2)), (0, 0), "mba {mba}");
        }
    }

    #[test]
    fn loop_filter_flat_block_identity() {
        let out = apply_loop_filter(&[128u8; 64]);
        assert!(out.iter().all(|&v| v == 128));
    }
}