use super::bitstream::{EpByteMap, RbspReader};
use super::cavlc::{check_ep_conflict, EmbedDomain, EmbeddablePosition};
use super::macroblock::MbType;
use super::H264Error;
/// Value written to `EmbeddablePosition::block_idx` for positions captured from
/// motion-vector-difference codewords (see `capture_mvd_position`).
// NOTE(review): presumably chosen so it cannot collide with a real block index — confirm
// against the code that consumes `block_idx`.
pub const MVD_BLOCK_IDX_SENTINEL: u32 = u32::MAX - 1;
/// Describes the last bit of one motion-vector-difference codeword as an embeddable
/// position.
///
/// * `bits_before` / `bits_after` — the reader's RBSP bit counter sampled immediately
///   before and after the codeword was read.
/// * `mvd_value` — the decoded difference; stored (widened) in `coeff_value`.
/// * `ep_map` — translates an RBSP byte index into a raw byte offset.
/// * `raw_data` — raw bytes handed to `check_ep_conflict`.
///
/// Returns `None` when the difference is zero, when the codeword is shorter than three
/// bits or has an even number of bits, or when the codeword's last bit lies in an RBSP
/// byte that `ep_map` does not cover.
fn capture_mvd_position(
    bits_before: usize,
    bits_after: usize,
    mvd_value: i16,
    ep_map: &EpByteMap,
    raw_data: &[u8],
) -> Option<EmbeddablePosition> {
    let codeword_bits = bits_after.saturating_sub(bits_before);
    let usable = mvd_value != 0 && codeword_bits >= 3 && codeword_bits % 2 == 1;
    if !usable {
        return None;
    }

    // The candidate is the final bit of the codeword; `codeword_bits >= 3` guarantees
    // `bits_after` is non-zero here.
    let last_bit = bits_after - 1;
    let bit_offset = (last_bit % 8) as u8;
    let raw_byte_offset = *ep_map.rbsp_to_raw.get(last_bit / 8)?;
    let ep_conflict = check_ep_conflict(raw_data, raw_byte_offset, bit_offset);

    Some(EmbeddablePosition {
        raw_byte_offset,
        bit_offset,
        domain: EmbedDomain::MvdLsb,
        scan_pos: 0,
        coeff_value: i32::from(mvd_value),
        ep_conflict,
        block_idx: MVD_BLOCK_IDX_SENTINEL,
        frame_idx: 0,
        mb_idx: 0,
    })
}
/// A two-component motion vector with 16-bit signed components.
///
/// The default value is the zero vector.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct MotionVector {
    /// Horizontal component.
    pub mv_x: i16,
    /// Vertical component.
    pub mv_y: i16,
}

impl MotionVector {
    /// Creates a motion vector from its horizontal (`x`) and vertical (`y`) components.
    pub const fn new(x: i16, y: i16) -> Self {
        MotionVector { mv_x: x, mv_y: y }
    }
}
/// How one 8x8 sub-macroblock of a P macroblock is split into motion partitions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PSubPartition {
    /// One 8x8 partition.
    P8x8,
    /// Two 8x4 partitions.
    P8x4,
    /// Two 4x8 partitions.
    P4x8,
    /// Four 4x4 partitions.
    P4x4,
}

impl PSubPartition {
    /// Maps a decoded sub-macroblock type code (0..=3) to its partitioning.
    ///
    /// Returns `None` for any other code.
    pub fn from_code(code: u32) -> Option<Self> {
        // Indexed by code value.
        const BY_CODE: [PSubPartition; 4] = [
            PSubPartition::P8x8,
            PSubPartition::P8x4,
            PSubPartition::P4x8,
            PSubPartition::P4x4,
        ];
        if code < 4 {
            BY_CODE.get(code as usize).copied()
        } else {
            None
        }
    }

    /// Number of motion-vector differences carried by a sub-macroblock of this shape
    /// (one per partition).
    pub fn num_mvds(self) -> usize {
        match self {
            Self::P4x4 => 4,
            Self::P4x8 | Self::P8x4 => 2,
            Self::P8x8 => 1,
        }
    }
}
/// Motion data for one macroblock, as filled in by `parse_mv_field`.
///
/// Both arrays hold one entry per 4x4 block of the macroblock, indexed in row-major
/// order (`y * 4 + x`). The default value is all-zero vectors with reference index 0.
#[derive(Debug, Clone, Default)]
pub struct MvField {
/// Motion vector covering each 4x4 block.
pub mvs: [MotionVector; 16],
/// Reference index paired with the corresponding entry of `mvs`.
pub ref_idx: [i8; 16],
}
/// Picture-wide grid of decoded motion vectors and reference indices, one cell per 4x4
/// block, used to look up spatial neighbours during motion-vector prediction.
pub struct MvPredictorContext {
    /// Grid width in 4x4 blocks.
    width_in_4x4: usize,
    /// Grid height in 4x4 blocks.
    height_in_4x4: usize,
    /// Row-major motion vectors (`width_in_4x4 * height_in_4x4` entries).
    mv_grid: Vec<MotionVector>,
    /// Row-major reference indices; a negative value marks a cell with no usable vector.
    ref_idx_grid: Vec<i8>,
}

impl MvPredictorContext {
    /// Creates an empty grid for a picture of `width_in_mbs` x `height_in_mbs`
    /// macroblocks (four 4x4 blocks per macroblock in each direction).
    ///
    /// Every cell starts with a zero vector and reference index `-1`, so `get` reports
    /// it as absent until `set` is called for it.
    ///
    /// # Panics
    ///
    /// Panics if the total number of cells does not fit in `usize`.
    pub fn new(width_in_mbs: u32, height_in_mbs: u32) -> Self {
        // Widen before scaling: the dimensions come from the bitstream, and `dim * 4`
        // evaluated in `u32` overflows for values >= 2^30 (panic in debug builds, a
        // silently too-small grid in release builds).
        let width_in_4x4 = (width_in_mbs as usize).saturating_mul(4);
        let height_in_4x4 = (height_in_mbs as usize).saturating_mul(4);
        let total = width_in_4x4
            .checked_mul(height_in_4x4)
            .expect("MV predictor grid size overflows usize");
        Self {
            width_in_4x4,
            height_in_4x4,
            mv_grid: vec![MotionVector::default(); total],
            ref_idx_grid: vec![-1; total],
        }
    }

    /// Grid width in 4x4 blocks.
    pub fn width_in_4x4(&self) -> usize {
        self.width_in_4x4
    }

    /// Stores `mv` and `ref_idx` for the 4x4 block at (`block_x`, `block_y`).
    ///
    /// Coordinates outside the grid are ignored.
    pub fn set(&mut self, block_x: usize, block_y: usize, mv: MotionVector, ref_idx: i8) {
        let inside = block_x < self.width_in_4x4 && block_y < self.height_in_4x4;
        if inside {
            let idx = block_y * self.width_in_4x4 + block_x;
            self.mv_grid[idx] = mv;
            self.ref_idx_grid[idx] = ref_idx;
        }
    }

    /// Looks up the 4x4 block at (`block_x`, `block_y`).
    ///
    /// Returns `None` when the coordinates fall outside the grid (including negative
    /// coordinates) or when the stored reference index is negative.
    pub fn get(&self, block_x: isize, block_y: isize) -> Option<(MotionVector, i8)> {
        if block_x < 0 || block_y < 0 {
            return None;
        }
        let (x, y) = (block_x as usize, block_y as usize);
        if x >= self.width_in_4x4 || y >= self.height_in_4x4 {
            return None;
        }
        let idx = y * self.width_in_4x4 + x;
        let ref_idx = self.ref_idx_grid[idx];
        if ref_idx < 0 {
            None
        } else {
            Some((self.mv_grid[idx], ref_idx))
        }
    }
}
/// Predicts the motion vector of a partition from its spatial neighbours.
///
/// For the two rectangular macroblock partitionings a single "directional" neighbour is
/// tried first and used when its reference index equals `current_ref_idx`:
///
/// | partition size (4x4 units) | `mb_part_idx == 0` | otherwise   |
/// |----------------------------|--------------------|-------------|
/// | 4 wide x 2 high            | `top`              | `left`      |
/// | 2 wide x 4 high            | `left`             | `top_right` |
///
/// In every other case (other shapes, missing directional neighbour, or a different
/// reference index) the result is `median_mv(left, top, top_right, current_ref_idx)`.
///
/// Each neighbour is `Some((mv, ref_idx))` when available and `None` otherwise.
pub fn amvp_predict(
    left: Option<(MotionVector, i8)>,
    top: Option<(MotionVector, i8)>,
    top_right: Option<(MotionVector, i8)>,
    current_ref_idx: i8,
    part_w_4x4: usize,
    part_h_4x4: usize,
    mb_part_idx: u8,
) -> MotionVector {
    let first_partition = mb_part_idx == 0;
    let directional = match (part_w_4x4, part_h_4x4) {
        (4, 2) if first_partition => top,
        (4, 2) => left,
        (2, 4) if first_partition => left,
        (2, 4) => top_right,
        _ => None,
    };

    match directional {
        Some((mv, ref_idx)) if ref_idx == current_ref_idx => mv,
        _ => median_mv(left, top, top_right, current_ref_idx),
    }
}
/// Median motion-vector prediction from the left (A), top (B) and top-right (C)
/// neighbours.
///
/// Each neighbour is `Some((mv, ref_idx))` when available and `None` otherwise. The
/// rules, in order:
///
/// 1. If `top` and `top_right` are both missing and `left` is present, `left` stands in
///    for all three neighbours, so its vector is the prediction whatever its reference
///    index.
/// 2. If exactly one neighbour has a reference index equal to `current_ref_idx`, its
///    vector is the prediction.
/// 3. Otherwise the prediction is the component-wise median of the three vectors, with a
///    missing neighbour contributing the zero vector.
///
/// Rule 1 applies only to a lone *left* neighbour. A lone top or top-right neighbour is
/// handled by rules 2 and 3, so it is used directly only when its reference index
/// matches; otherwise the median with two zero vectors yields zero.
pub fn median_mv(
    left: Option<(MotionVector, i8)>,
    top: Option<(MotionVector, i8)>,
    top_right: Option<(MotionVector, i8)>,
    current_ref_idx: i8,
) -> MotionVector {
    if let (Some((mv, _)), None, None) = (left, top, top_right) {
        return mv;
    }

    let neighbours = [left, top, top_right];

    // Walk the neighbours that share the current reference index; the prediction is
    // taken from one of them only when there is exactly one.
    let mut same_ref = neighbours
        .iter()
        .flatten()
        .filter(|(_, ref_idx)| *ref_idx == current_ref_idx)
        .map(|(mv, _)| *mv);
    if let (Some(only), None) = (same_ref.next(), same_ref.next()) {
        return only;
    }

    let [l, t, tr] = neighbours.map(|n| n.map(|(mv, _)| mv).unwrap_or_default());
    MotionVector {
        mv_x: median3(l.mv_x, t.mv_x, tr.mv_x),
        mv_y: median3(l.mv_y, t.mv_y, tr.mv_y),
    }
}
/// Returns the median (middle value) of three integers.
#[inline]
fn median3(a: i16, b: i16, c: i16) -> i16 {
    // Order the first two, then pull the third into their closed range.
    let (low, high) = if a <= b { (a, b) } else { (b, a) };
    c.clamp(low, high)
}
/// Reads one signed motion-vector-difference component from `reader` and, when
/// `capture_mvd_position` accepts the codeword, appends its position to `mvd_positions`.
///
/// The reader's bit counter is sampled before and after the read so the codeword's
/// extent can be handed to `capture_mvd_position` together with `ep_map` and `raw_data`.
///
/// # Errors
///
/// Propagates any error from `reader.read_se()`; nothing is appended in that case.
// NOTE(review): the decoded value is narrowed with `as i16`; values outside the `i16`
// range would wrap silently — confirm the reader or caller bounds them.
#[inline]
fn read_mvd_capturing(
    reader: &mut RbspReader<'_>,
    ep_map: &EpByteMap,
    raw_data: &[u8],
    mvd_positions: &mut Vec<EmbeddablePosition>,
) -> Result<i16, H264Error> {
    let start_bit = reader.bits_read();
    let mvd = reader.read_se()? as i16;
    let end_bit = reader.bits_read();

    // `Option` iterates over zero or one item, so this pushes only on a capture.
    mvd_positions.extend(capture_mvd_position(start_bit, end_bit, mvd, ep_map, raw_data));
    Ok(mvd)
}
/// Parses the motion data of one inter-predicted P macroblock and resolves it into a
/// per-4x4-block [`MvField`].
///
/// Reading order follows the code below: for `P16x16`, `P16x8` and `P8x16` all
/// reference indices are read first, then one (x, y) difference pair per partition. For
/// `P8x8` / `P8x8ref0` the four sub-macroblock types are read first, then the four
/// reference indices, then the difference pairs of every sub-partition. A reference
/// index is read only when more than one reference is active (and never for
/// `P8x8ref0`); otherwise it is 0.
///
/// Each partition is resolved immediately after its differences are read, so later
/// partitions of the same macroblock see the earlier ones as neighbours in `ctx`.
///
/// * `mb_x`, `mb_y` — macroblock coordinates; the 4x4-block origin is four times these.
/// * `num_ref_idx_l0_active` — number of active list-0 references.
/// * `ctx` — neighbour grid, read for prediction and updated with the resolved vectors.
/// * `ep_map`, `raw_data`, `mvd_positions` — forwarded to `read_mvd_capturing`, which
///   appends the embeddable position of each captured difference codeword.
///
/// Returns `Ok(None)` for any other macroblock type, without consuming any bits.
///
/// # Errors
///
/// Propagates reader errors, and returns `H264Error::CavlcError` for a sub-macroblock
/// type code outside `0..=3`.
// The parameter list mirrors the parsing state threaded through by the caller; bundling
// it would change the public signature.
#[allow(clippy::too_many_arguments)]
pub fn parse_mv_field(
    reader: &mut RbspReader<'_>,
    mb_type: MbType,
    mb_x: u32,
    mb_y: u32,
    num_ref_idx_l0_active: u8,
    ctx: &mut MvPredictorContext,
    ep_map: &EpByteMap,
    raw_data: &[u8],
    mvd_positions: &mut Vec<EmbeddablePosition>,
) -> Result<Option<MvField>, H264Error> {
    /// Reads one reference index, or yields 0 without reading when only a single
    /// reference is active.
    // NOTE(review): the decoded value is narrowed with `as i8`; confirm that `read_te`
    // bounds it by `max_ref`.
    fn read_ref_idx(reader: &mut RbspReader<'_>, max_ref: u32) -> Result<i8, H264Error> {
        if max_ref > 0 {
            Ok(reader.read_te(max_ref)? as i8)
        } else {
            Ok(0)
        }
    }

    let max_ref = u32::from(num_ref_idx_l0_active.saturating_sub(1));
    // Widen before scaling so that `coordinate * 4` cannot overflow `u32`.
    let base_x = (mb_x as usize).saturating_mul(4);
    let base_y = (mb_y as usize).saturating_mul(4);
    let mut field = MvField::default();

    match mb_type {
        MbType::P16x16 => {
            let ref_idx = read_ref_idx(reader, max_ref)?;
            let mvd_x = read_mvd_capturing(reader, ep_map, raw_data, mvd_positions)?;
            let mvd_y = read_mvd_capturing(reader, ep_map, raw_data, mvd_positions)?;
            resolve_partition(
                &mut field,
                ctx,
                base_x,
                base_y,
                0,
                0,
                4,
                4,
                4,
                4,
                0,
                ref_idx,
                (mvd_x, mvd_y),
            );
        }
        MbType::P16x8 | MbType::P8x16 => {
            let stacked = matches!(mb_type, MbType::P16x8);
            // Array elements are evaluated left to right: both indices precede any MVD.
            let ref_idxs = [
                read_ref_idx(reader, max_ref)?,
                read_ref_idx(reader, max_ref)?,
            ];
            for (i, &ref_idx) in ref_idxs.iter().enumerate() {
                let mvd_x = read_mvd_capturing(reader, ep_map, raw_data, mvd_positions)?;
                let mvd_y = read_mvd_capturing(reader, ep_map, raw_data, mvd_positions)?;
                // Two 4x2 halves stacked vertically, or two 2x4 halves side by side.
                let (off_x, off_y, w, h) = if stacked {
                    (0, i * 2, 4, 2)
                } else {
                    (i * 2, 0, 2, 4)
                };
                resolve_partition(
                    &mut field,
                    ctx,
                    base_x,
                    base_y,
                    off_x,
                    off_y,
                    w,
                    h,
                    w,
                    h,
                    i as u8,
                    ref_idx,
                    (mvd_x, mvd_y),
                );
            }
        }
        MbType::P8x8 | MbType::P8x8ref0 => {
            // Top-left 4x4 block of each 8x8 quadrant.
            const SUB_ORIGINS: [(usize, usize); 4] = [(0, 0), (2, 0), (0, 2), (2, 2)];

            let mut subs = [PSubPartition::P8x8; 4];
            for sub in &mut subs {
                let code = reader.read_ue()?;
                *sub = PSubPartition::from_code(code).ok_or_else(|| {
                    H264Error::CavlcError(format!("invalid P-slice sub_mb_type: {code}"))
                })?;
            }

            let mut ref_idxs = [0i8; 4];
            if !matches!(mb_type, MbType::P8x8ref0) && max_ref > 0 {
                for ref_idx in &mut ref_idxs {
                    *ref_idx = reader.read_te(max_ref)? as i8;
                }
            }

            for (i, (&sub, &(off_x, off_y))) in subs.iter().zip(&SUB_ORIGINS).enumerate() {
                // (dx, dy, width, height) of each sub-partition inside the quadrant.
                let parts: &[(usize, usize, usize, usize)] = match sub {
                    PSubPartition::P8x8 => &[(0, 0, 2, 2)],
                    PSubPartition::P8x4 => &[(0, 0, 2, 1), (0, 1, 2, 1)],
                    PSubPartition::P4x8 => &[(0, 0, 1, 2), (1, 0, 1, 2)],
                    PSubPartition::P4x4 => {
                        &[(0, 0, 1, 1), (1, 0, 1, 1), (0, 1, 1, 1), (1, 1, 1, 1)]
                    }
                };
                for &(dx, dy, pw, ph) in parts {
                    let mvd_x = read_mvd_capturing(reader, ep_map, raw_data, mvd_positions)?;
                    let mvd_y = read_mvd_capturing(reader, ep_map, raw_data, mvd_positions)?;
                    resolve_partition(
                        &mut field,
                        ctx,
                        base_x,
                        base_y,
                        off_x + dx,
                        off_y + dy,
                        pw,
                        ph,
                        2,
                        2,
                        i as u8,
                        ref_idxs[i],
                        (mvd_x, mvd_y),
                    );
                }
            }
        }
        _ => return Ok(None),
    }
    Ok(Some(field))
}
/// Predicts and stores the motion vector of one partition.
///
/// The partition's top-left 4x4 block is at (`base_x + off_x`, `base_y + off_y`) in
/// picture coordinates and it spans `width` x `height` 4x4 blocks. Neighbours are
/// fetched from `ctx`: the block to the left, the block above, and the block above and
/// just right of the partition — falling back to the above-left block when that one is
/// absent. `amvp_predict` turns them into a predictor (using `mb_part_w_4x4`,
/// `mb_part_h_4x4` and `mb_part_idx`), to which `mvd` is added with wrapping
/// arithmetic.
///
/// The resulting vector and `ref_idx` are written to every covered block of `ctx`, and
/// to `field` for blocks whose macroblock-local index `y * 4 + x` is below 16.
#[allow(clippy::too_many_arguments)]
fn resolve_partition(
    field: &mut MvField,
    ctx: &mut MvPredictorContext,
    base_x: usize,
    base_y: usize,
    off_x: usize,
    off_y: usize,
    width: usize,
    height: usize,
    mb_part_w_4x4: usize,
    mb_part_h_4x4: usize,
    mb_part_idx: u8,
    ref_idx: i8,
    mvd: (i16, i16),
) {
    let x0 = (base_x + off_x) as isize;
    let y0 = (base_y + off_y) as isize;

    let left = ctx.get(x0 - 1, y0);
    let above = ctx.get(x0, y0 - 1);
    let above_right = match ctx.get(x0 + width as isize, y0 - 1) {
        found @ Some(_) => found,
        None => ctx.get(x0 - 1, y0 - 1),
    };

    let predicted = amvp_predict(
        left,
        above,
        above_right,
        ref_idx,
        mb_part_w_4x4,
        mb_part_h_4x4,
        mb_part_idx,
    );
    let (mvd_x, mvd_y) = mvd;
    let mv = MotionVector {
        mv_x: predicted.mv_x.wrapping_add(mvd_x),
        mv_y: predicted.mv_y.wrapping_add(mvd_y),
    };

    for dy in 0..height {
        for dx in 0..width {
            let col = off_x + dx;
            let row = off_y + dy;
            let slot = row * 4 + col;
            // Both arrays have 16 entries, so `get_mut` succeeds exactly for slot < 16.
            if let (Some(cell_mv), Some(cell_ref)) =
                (field.mvs.get_mut(slot), field.ref_idx.get_mut(slot))
            {
                *cell_mv = mv;
                *cell_ref = ref_idx;
            }
            ctx.set(base_x + col, base_y + row, mv, ref_idx);
        }
    }
}
// Unit tests for the prediction helpers, MVD position capture and `parse_mv_field`.
#[cfg(test)]
mod tests {
use super::*;
// `median3` returns the middle value for ascending, descending, duplicated and
// mixed-sign inputs.
#[test]
fn median3_picks_the_middle_of_three() {
assert_eq!(median3(1, 2, 3), 2);
assert_eq!(median3(3, 2, 1), 2);
assert_eq!(median3(5, 5, 1), 5);
assert_eq!(median3(-10, 0, 5), 0);
}
// With a single available neighbour (whose reference index matches), that neighbour's
// vector is the prediction.
#[test]
fn median_mv_single_neighbour_uses_it_directly() {
let left = Some((MotionVector::new(8, -4), 0));
let pred = median_mv(left, None, None, 0);
assert_eq!(pred, MotionVector::new(8, -4));
let top = Some((MotionVector::new(2, 2), 0));
let pred = median_mv(None, top, None, 0);
assert_eq!(pred, MotionVector::new(2, 2));
}
// All three neighbours are present but only `top` carries reference index 1, so its
// vector wins over the median.
#[test]
fn median_mv_single_match_on_refidx() {
let left = Some((MotionVector::new(100, 100), 0));
let top = Some((MotionVector::new(50, 50), 1));
let top_right = Some((MotionVector::new(0, 0), 0));
let pred = median_mv(left, top, top_right, 1);
assert_eq!(pred, MotionVector::new(50, 50));
}
// All three neighbours share the reference index: the median is taken per component.
#[test]
fn median_mv_three_neighbours_componentwise_median() {
let left = Some((MotionVector::new(1, 10), 0));
let top = Some((MotionVector::new(2, 20), 0));
let top_right = Some((MotionVector::new(3, 30), 0));
let pred = median_mv(left, top, top_right, 0);
assert_eq!(pred, MotionVector::new(2, 20));
}
// A written cell reads back; untouched cells, negative coordinates and coordinates
// beyond the 16x16 grid all read as `None`.
#[test]
fn mv_predictor_context_roundtrips_writes() {
let mut ctx = MvPredictorContext::new(4, 4); ctx.set(3, 7, MotionVector::new(42, -7), 1);
assert_eq!(ctx.get(3, 7), Some((MotionVector::new(42, -7), 1)));
assert_eq!(ctx.get(0, 0), None);
assert_eq!(ctx.get(-1, 0), None);
assert_eq!(ctx.get(0, -1), None);
assert_eq!(ctx.get(100, 100), None);
}
// Builds a map in which RBSP byte `i` corresponds to raw byte `i` for `i < n`.
fn identity_ep_map(n: usize) -> EpByteMap {
EpByteMap {
rbsp_to_raw: (0..n).collect(),
}
}
// A zero difference is never captured.
#[test]
fn capture_mvd_position_skips_zero_codeword() {
let ep_map = identity_ep_map(4);
let raw = [0u8; 4];
let p = capture_mvd_position(0, 1, 0, &ep_map, &raw);
assert!(p.is_none(), "mvd=0 must produce no embeddable position");
}
// A 3-bit codeword covering bits 0..3 ends at bit index 2 of byte 0.
#[test]
fn capture_mvd_position_marks_suffix_lsb_for_nonzero_mvd() {
let ep_map = identity_ep_map(8);
let raw = [0u8; 8];
let p =
capture_mvd_position(0, 3, 1, &ep_map, &raw).expect("non-zero mvd must capture");
assert_eq!(p.domain, EmbedDomain::MvdLsb);
assert_eq!(p.block_idx, MVD_BLOCK_IDX_SENTINEL);
assert_eq!(p.coeff_value, 1);
assert_eq!(p.raw_byte_offset, 0);
assert_eq!(p.bit_offset, 2);
}
// A 5-bit codeword covering bits 5..10 ends at bit index 9, i.e. byte 1, bit 1.
#[test]
fn capture_mvd_position_marks_large_mvd_at_long_suffix() {
let ep_map = identity_ep_map(8);
let raw = [0u8; 8];
let p =
capture_mvd_position(5, 10, 3, &ep_map, &raw).expect("non-zero mvd must capture");
assert_eq!(p.raw_byte_offset, 1);
assert_eq!(p.bit_offset, 1);
assert_eq!(p.coeff_value, 3);
}
// The sign of the difference survives the widening into `coeff_value`.
#[test]
fn capture_mvd_position_preserves_negative_mvd_in_coeff_value() {
let ep_map = identity_ep_map(4);
let raw = [0u8; 4];
let p = capture_mvd_position(0, 5, -2, &ep_map, &raw)
.expect("non-zero mvd must capture");
assert_eq!(p.coeff_value, -2);
assert_eq!(p.domain, EmbedDomain::MvdLsb);
}
// Even-length codewords (4 and 2 bits here) are rejected.
#[test]
fn capture_mvd_position_rejects_malformed_length() {
let ep_map = identity_ep_map(4);
let raw = [0u8; 4];
assert!(capture_mvd_position(0, 4, 2, &ep_map, &raw).is_none());
assert!(capture_mvd_position(0, 2, 2, &ep_map, &raw).is_none());
}
// One active reference means no ref_idx is read, so the stream is just two difference
// codewords. 0x31 0x40 is `00110 00101 000000` in binary — assuming `read_se` decodes
// standard signed Exp-Golomb, that is +3 followed by -2. With no neighbours in a fresh
// 1x1-macroblock grid the predictor is zero, so the resolved vector equals the
// differences.
#[test]
fn parse_mv_field_captures_mvd_positions_from_synthetic_p16x16() {
use super::super::bitstream::{EpByteMap, RbspReader};
let bytes = [0x31u8, 0x40];
let mut reader = RbspReader::new(&bytes);
let ep_map = EpByteMap {
rbsp_to_raw: vec![0, 1],
};
let mut ctx = MvPredictorContext::new(1, 1);
let mut positions = Vec::new();
let field = parse_mv_field(
&mut reader,
MbType::P16x16,
0,
0,
1,
&mut ctx,
&ep_map,
&bytes,
&mut positions,
)
.expect("parse")
.expect("p16x16 returns Some");
assert_eq!(positions.len(), 2);
assert_eq!(positions[0].coeff_value, 3);
assert_eq!(positions[1].coeff_value, -2);
for p in &positions {
assert_eq!(p.domain, EmbedDomain::MvdLsb);
assert_eq!(p.block_idx, MVD_BLOCK_IDX_SENTINEL);
}
assert_eq!(field.mvs[0], MotionVector::new(3, -2));
}
// 0xC0 starts with two `1` bits — presumably two single-bit codewords for zero — so
// neither difference yields an embeddable position.
#[test]
fn parse_mv_field_skips_zero_mvds_in_synthetic_p16x16() {
use super::super::bitstream::{EpByteMap, RbspReader};
let bytes = [0xC0u8];
let mut reader = RbspReader::new(&bytes);
let ep_map = EpByteMap {
rbsp_to_raw: vec![0],
};
let mut ctx = MvPredictorContext::new(1, 1);
let mut positions = Vec::new();
parse_mv_field(
&mut reader,
MbType::P16x16,
0,
0,
1,
&mut ctx,
&ep_map,
&bytes,
&mut positions,
)
.expect("parse");
assert!(
positions.is_empty(),
"mvd=0 pairs must produce no embeddable positions"
);
}
}