use super::motion_estimation::MotionVector;
/// Reference-index sentinel marking a grid cell whose list carries no motion data.
/// `EncoderMvGrid::get_l0` / `get_l1` report such cells as `None`.
pub const REF_IDX_NONE: i8 = -1;
/// Saved copy of one macroblock's 4x4 cells of motion data (both lists plus the
/// decoded flags), produced by `EncoderMvGrid::snapshot_mb` and written back by
/// `EncoderMvGrid::restore_mb`.
#[derive(Debug, Clone)]
pub struct MbMvSnapshot {
// L0 vectors, row-major within the macroblock (slot = dy * 4 + dx).
mvs_l0: [MotionVector; 16],
// L0 reference indices per slot; REF_IDX_NONE when unused or outside the grid.
refs_l0: [i8; 16],
// L1 vectors, same slot layout as `mvs_l0`.
mvs_l1: [MotionVector; 16],
// L1 reference indices per slot.
refs_l1: [i8; 16],
// Decoded flag per slot; false for slots that fell outside the grid.
decs: [bool; 16],
// Grid coordinate (in 4x4 blocks) of the macroblock's top-left cell.
base_bx: usize,
base_by: usize,
}
/// Per-4x4-block motion data (L0 and L1) that the encoder consults when deriving
/// motion-vector predictors from already-coded neighbours.
#[derive(Debug, Clone)]
pub struct EncoderMvGrid {
// Grid dimensions in 4x4 blocks (4 blocks per macroblock in each direction).
width_4x4: usize,
height_4x4: usize,
// Per-cell data, row-major (index = y * width_4x4 + x). A reference index of
// REF_IDX_NONE means that list is unused for the cell.
mv_l0: Vec<MotionVector>,
ref_idx_l0: Vec<i8>,
mv_l1: Vec<MotionVector>,
ref_idx_l1: Vec<i8>,
// Set by fill/fill_lists, cleared by reset; used as the "available" test for neighbours.
decoded: Vec<bool>,
}
impl EncoderMvGrid {
    /// Creates a grid for `mb_width` x `mb_height` macroblocks, i.e. `4 * mb_width` by
    /// `4 * mb_height` cells. Every cell starts undecoded, with both lists holding a
    /// default vector and `REF_IDX_NONE`.
    pub fn new(mb_width: usize, mb_height: usize) -> Self {
        let w = mb_width * 4;
        let h = mb_height * 4;
        let cells = w * h;
        Self {
            width_4x4: w,
            height_4x4: h,
            mv_l0: vec![MotionVector::default(); cells],
            ref_idx_l0: vec![REF_IDX_NONE; cells],
            mv_l1: vec![MotionVector::default(); cells],
            ref_idx_l1: vec![REF_IDX_NONE; cells],
            decoded: vec![false; cells],
        }
    }

    /// Grid width in 4x4 blocks.
    pub fn width_4x4(&self) -> usize {
        self.width_4x4
    }

    /// Grid height in 4x4 blocks.
    pub fn height_4x4(&self) -> usize {
        self.height_4x4
    }

    /// Returns every cell to the state produced by `new`, reusing the existing buffers.
    pub fn reset(&mut self) {
        self.mv_l0.fill(MotionVector::default());
        self.ref_idx_l0.fill(REF_IDX_NONE);
        self.mv_l1.fill(MotionVector::default());
        self.ref_idx_l1.fill(REF_IDX_NONE);
        self.decoded.fill(false);
    }

    /// Writes an L0-only prediction (`mv`, `ref_idx`) into the `w` x `h` cell region
    /// whose top-left cell is `(bx, by)`, and marks those cells decoded.
    ///
    /// L1 data already stored in the region is left untouched (use `clear_l1_at` to
    /// drop it). Cells outside the grid are skipped.
    pub fn fill(
        &mut self,
        bx: usize,
        by: usize,
        w: usize,
        h: usize,
        mv: MotionVector,
        ref_idx: i8,
    ) {
        self.fill_lists(bx, by, w, h, Some((mv, ref_idx)), None);
    }

    /// Writes motion data for the `w` x `h` cell region at `(bx, by)` and marks those
    /// cells decoded.
    ///
    /// A `None` list argument leaves that list's existing data in place. Cells outside
    /// the grid are skipped.
    pub fn fill_lists(
        &mut self,
        bx: usize,
        by: usize,
        w: usize,
        h: usize,
        l0: Option<(MotionVector, i8)>,
        l1: Option<(MotionVector, i8)>,
    ) {
        for dy in 0..h {
            for dx in 0..w {
                if let Some(idx) = self.index_of(bx + dx, by + dy) {
                    if let Some((mv, ref_idx)) = l0 {
                        self.mv_l0[idx] = mv;
                        self.ref_idx_l0[idx] = ref_idx;
                    }
                    if let Some((mv, ref_idx)) = l1 {
                        self.mv_l1[idx] = mv;
                        self.ref_idx_l1[idx] = ref_idx;
                    }
                    self.decoded[idx] = true;
                }
            }
        }
    }

    /// Resets L0 data (default vector, `REF_IDX_NONE`) in the `w` x `h` region at
    /// `(bx, by)`. L1 data and decoded flags are untouched.
    pub fn clear_l0_at(&mut self, bx: usize, by: usize, w: usize, h: usize) {
        for dy in 0..h {
            for dx in 0..w {
                if let Some(idx) = self.index_of(bx + dx, by + dy) {
                    self.mv_l0[idx] = MotionVector::default();
                    self.ref_idx_l0[idx] = REF_IDX_NONE;
                }
            }
        }
    }

    /// Resets L1 data (default vector, `REF_IDX_NONE`) in the `w` x `h` region at
    /// `(bx, by)`. L0 data and decoded flags are untouched.
    pub fn clear_l1_at(&mut self, bx: usize, by: usize, w: usize, h: usize) {
        for dy in 0..h {
            for dx in 0..w {
                if let Some(idx) = self.index_of(bx + dx, by + dy) {
                    self.mv_l1[idx] = MotionVector::default();
                    self.ref_idx_l1[idx] = REF_IDX_NONE;
                }
            }
        }
    }

    /// Copies the 16 cells of macroblock `(mb_x, mb_y)` into a snapshot.
    ///
    /// Slots whose cell lies outside the grid hold a default vector, `REF_IDX_NONE`
    /// and `decoded == false`.
    pub fn snapshot_mb(&self, mb_x: usize, mb_y: usize) -> MbMvSnapshot {
        let base_bx = mb_x * 4;
        let base_by = mb_y * 4;
        let mut snap = MbMvSnapshot {
            mvs_l0: [MotionVector::default(); 16],
            refs_l0: [REF_IDX_NONE; 16],
            mvs_l1: [MotionVector::default(); 16],
            refs_l1: [REF_IDX_NONE; 16],
            decs: [false; 16],
            base_bx,
            base_by,
        };
        for dy in 0..4 {
            for dx in 0..4 {
                if let Some(idx) = self.index_of(base_bx + dx, base_by + dy) {
                    let slot = dy * 4 + dx;
                    snap.mvs_l0[slot] = self.mv_l0[idx];
                    snap.refs_l0[slot] = self.ref_idx_l0[idx];
                    snap.mvs_l1[slot] = self.mv_l1[idx];
                    snap.refs_l1[slot] = self.ref_idx_l1[idx];
                    snap.decs[slot] = self.decoded[idx];
                }
            }
        }
        snap
    }

    /// Writes a snapshot back to the cells it was taken from (both lists and decoded
    /// flags). Slots whose cell lies outside the grid are ignored.
    pub fn restore_mb(&mut self, snap: &MbMvSnapshot) {
        for dy in 0..4 {
            for dx in 0..4 {
                if let Some(idx) = self.index_of(snap.base_bx + dx, snap.base_by + dy) {
                    let slot = dy * 4 + dx;
                    self.mv_l0[idx] = snap.mvs_l0[slot];
                    self.ref_idx_l0[idx] = snap.refs_l0[slot];
                    self.mv_l1[idx] = snap.mvs_l1[slot];
                    self.ref_idx_l1[idx] = snap.refs_l1[slot];
                    self.decoded[idx] = snap.decs[slot];
                }
            }
        }
    }

    /// Alias for [`EncoderMvGrid::get_l0`].
    pub fn get(&self, bx: isize, by: isize) -> Option<(MotionVector, i8)> {
        self.get_l0(bx, by)
    }

    /// L0 `(mv, ref_idx)` at cell `(bx, by)`. Returns `None` when the cell is off the
    /// grid (including negative coordinates) or its L0 reference is `REF_IDX_NONE`.
    pub fn get_l0(&self, bx: isize, by: isize) -> Option<(MotionVector, i8)> {
        let idx = self.cell_idx(bx, by)?;
        let r = self.ref_idx_l0[idx];
        if r == REF_IDX_NONE {
            None
        } else {
            Some((self.mv_l0[idx], r))
        }
    }

    /// L1 `(mv, ref_idx)` at cell `(bx, by)`. Returns `None` when the cell is off the
    /// grid (including negative coordinates) or its L1 reference is `REF_IDX_NONE`.
    pub fn get_l1(&self, bx: isize, by: isize) -> Option<(MotionVector, i8)> {
        let idx = self.cell_idx(bx, by)?;
        let r = self.ref_idx_l1[idx];
        if r == REF_IDX_NONE {
            None
        } else {
            Some((self.mv_l1[idx], r))
        }
    }

    /// Flat index for a signed cell coordinate; `None` if negative or off the grid.
    fn cell_idx(&self, bx: isize, by: isize) -> Option<usize> {
        if bx < 0 || by < 0 {
            return None;
        }
        self.index_of(bx as usize, by as usize)
    }

    /// Flat row-major index of cell `(x, y)`, or `None` when it lies outside the grid.
    /// This is the single bounds check that every accessor funnels through.
    fn index_of(&self, x: usize, y: usize) -> Option<usize> {
        if x < self.width_4x4 && y < self.height_4x4 {
            Some(y * self.width_4x4 + x)
        } else {
            None
        }
    }

    /// Whether cell `(bx, by)` has been written since the last reset. Negative or
    /// off-grid coordinates report `false`.
    pub fn is_decoded(&self, bx: isize, by: isize) -> bool {
        self.cell_idx(bx, by).is_some_and(|idx| self.decoded[idx])
    }
}
/// Median of three values: `c` clamped into the closed interval spanned by `a` and `b`.
#[inline]
fn median3(a: i16, b: i16, c: i16) -> i16 {
    let (lo, hi) = if a <= b { (a, b) } else { (b, a) };
    c.clamp(lo, hi)
}
/// L0 motion-vector predictor for a macroblock partition, applying the directional
/// shortcuts for 16x8 and 8x16 partitions before the median rule.
///
/// * 16x8 (`part_w_4x4 == 4`, `part_h_4x4 == 2`): partition 0 uses the block above
///   (B), any other partition uses the block to the left (A).
/// * 8x16 (`part_w_4x4 == 2`, `part_h_4x4 == 4`): partition 0 uses A, any other
///   partition uses the above-right block C, or the above-left block D when C has not
///   been decoded.
///
/// The chosen neighbour's vector is returned only when its reference index equals
/// `current_ref_idx`. Every other case falls back to `predict_mv_for_partition`.
pub fn predict_mv_for_mb_partition(
    grid: &EncoderMvGrid,
    tl_bx: usize,
    tl_by: usize,
    part_w_4x4: usize,
    part_h_4x4: usize,
    mb_part_idx: u8,
    current_ref_idx: i8,
) -> MotionVector {
    let col = tl_bx as isize;
    let row = tl_by as isize;
    let left = (col - 1, row);
    let above = (col, row - 1);
    let above_right = (col + part_w_4x4 as isize, row - 1);
    let above_left = (col - 1, row - 1);
    let first_part = mb_part_idx == 0;

    let directional = match (part_w_4x4, part_h_4x4) {
        // 16x8: upper half looks up, lower half looks left.
        (4, 2) => {
            let (nx, ny) = if first_part { above } else { left };
            grid.get(nx, ny)
        }
        // 8x16: left half looks left, right half looks above-right (or above-left).
        (2, 4) => {
            let (nx, ny) = if first_part {
                left
            } else if grid.is_decoded(above_right.0, above_right.1) {
                above_right
            } else {
                above_left
            };
            grid.get(nx, ny)
        }
        _ => None,
    };

    match directional {
        Some((mv, r)) if r == current_ref_idx => mv,
        _ => predict_mv_for_partition(grid, tl_bx, tl_by, part_w_4x4, current_ref_idx),
    }
}
/// Median L0 motion-vector predictor for a partition whose top-left 4x4 block is
/// `(tl_bx, tl_by)` and whose width is `part_w_4x4` blocks (H.264 §8.4.1.3.1).
///
/// Neighbours:
/// * A: the block left of the top-left block.
/// * B: the block above it.
/// * C: the block above-right, at column `tl_bx + part_w_4x4`. When C has not been
///   decoded, D (the above-left block) stands in for it.
///
/// A neighbour that is off the grid, undecoded, or has no L0 reference contributes a
/// zero vector and matches no reference index.
///
/// Rules, applied in order:
/// 1. If B and C (after the D substitution) are both unavailable, meaning off the grid
///    or not yet decoded, while A is decoded, B and C take A's data.
/// 2. If exactly one neighbour's reference index equals `current_ref_idx`, that
///    neighbour's vector is returned.
/// 3. Otherwise the component-wise median of A, B and C is returned.
pub fn predict_mv_for_partition(
    grid: &EncoderMvGrid,
    tl_bx: usize,
    tl_by: usize,
    part_w_4x4: usize,
    current_ref_idx: i8,
) -> MotionVector {
    let x = tl_bx as isize;
    let y = tl_by as isize;
    let a_pos = (x - 1, y);
    let b_pos = (x, y - 1);
    let c_candidate = (x + part_w_4x4 as isize, y - 1);
    let c_pos = if grid.is_decoded(c_candidate.0, c_candidate.1) {
        c_candidate
    } else {
        (x - 1, y - 1)
    };

    let a = grid.get(a_pos.0, a_pos.1);
    let mut b = grid.get(b_pos.0, b_pos.1);
    let mut c = grid.get(c_pos.0, c_pos.1);

    // Availability means "exists and has been decoded". It is NOT "has a reference":
    // an intra neighbour is available with refIdx -1. The copy-from-A substitution is
    // therefore keyed on decoded state, and only B/C can be replaced. A looser
    // "exactly one neighbour has a reference" shortcut would diverge from a
    // conforming decoder whenever that neighbour uses a different reference index.
    let available = |(px, py): (isize, isize)| grid.is_decoded(px, py);
    if !available(b_pos) && !available(c_pos) && available(a_pos) {
        b = a;
        c = a;
    }

    // Rule 2: a unique neighbour on the same reference wins outright.
    let same_ref =
        |n: Option<(MotionVector, i8)>| n.and_then(|(mv, r)| (r == current_ref_idx).then_some(mv));
    let matches = [same_ref(a), same_ref(b), same_ref(c)];
    let mut hits = matches.iter().flatten();
    if let (Some(&only), None) = (hits.next(), hits.next()) {
        return only;
    }

    // Rule 3: component-wise median; missing or intra neighbours count as (0, 0).
    let mv_or_zero = |n: Option<(MotionVector, i8)>| n.map(|(mv, _)| mv).unwrap_or_default();
    let la = mv_or_zero(a);
    let tb = mv_or_zero(b);
    let tr = mv_or_zero(c);
    MotionVector {
        mv_x: median3(la.mv_x, tb.mv_x, tr.mv_x),
        mv_y: median3(la.mv_y, tb.mv_y, tr.mv_y),
    }
}
/// P_Skip L0 motion vector for the macroblock whose top-left 4x4 block is
/// `(tl_bx, tl_by)`.
///
/// Returns (0, 0) in any of these cases:
/// * the macroblock touches the left or top picture edge;
/// * the left neighbour (A) uses L0 reference 0 with a zero vector;
/// * the upper neighbour (B) uses L0 reference 0 with a zero vector.
///
/// Otherwise it returns the 16x16 median predictor for reference 0.
pub fn predict_p_skip_mv(grid: &EncoderMvGrid, tl_bx: usize, tl_by: usize) -> MotionVector {
    let zero = MotionVector { mv_x: 0, mv_y: 0 };
    let col = tl_bx as isize;
    let row = tl_by as isize;

    // A or B lies outside the picture.
    let left_missing = col - 1 < 0;
    let top_missing = row - 1 < 0;
    if left_missing || top_missing {
        return zero;
    }

    // A neighbour that is "static": L0 reference 0 with a zero vector.
    let is_static_ref0 = |n: Option<(MotionVector, i8)>| match n {
        Some((mv, 0)) => mv.mv_x == 0 && mv.mv_y == 0,
        _ => false,
    };
    if is_static_ref0(grid.get(col - 1, row)) || is_static_ref0(grid.get(col, row - 1)) {
        return zero;
    }

    predict_mv_for_partition(grid, tl_bx, tl_by, 4, 0)
}
// Unit tests for the grid storage and the predictor helpers. Coordinates are in 4x4-block
// units, so `EncoderMvGrid::new(2, 2)` is an 8x8-cell grid and macroblock (1, 0) starts
// at cell (4, 0).
#[cfg(test)]
mod tests {
use super::*;
// An empty grid, a negative coordinate and an off-grid coordinate all yield no L0 data.
#[test]
fn grid_get_none_when_unset() {
let g = EncoderMvGrid::new(2, 2);
assert_eq!(g.get(0, 0), None);
assert_eq!(g.get(-1, 0), None);
assert_eq!(g.get(100, 100), None);
}
// Filling macroblock (0, 0) sets all 16 of its cells; cell (4, 4) in macroblock (1, 1) stays empty.
#[test]
fn grid_fill_and_get_roundtrip() {
let mut g = EncoderMvGrid::new(2, 2);
let mv = MotionVector { mv_x: 42, mv_y: -7 };
g.fill(0, 0, 4, 4, mv, 0);
for bx in 0..4 {
for by in 0..4 {
assert_eq!(g.get(bx, by), Some((mv, 0)));
}
}
assert_eq!(g.get(4, 4), None);
}
// reset() drops previously filled L0 data.
#[test]
fn grid_reset_clears() {
let mut g = EncoderMvGrid::new(1, 1);
g.fill(0, 0, 4, 4, MotionVector { mv_x: 1, mv_y: 2 }, 0);
g.reset();
assert_eq!(g.get(0, 0), None);
}
// Top-left partition of an empty grid: no neighbours, so the predictor is zero.
#[test]
fn predictor_no_neighbors_returns_zero() {
let g = EncoderMvGrid::new(2, 2);
let mv = predict_mv_for_partition(&g, 0, 0, 4, 0);
assert_eq!(mv, MotionVector::ZERO);
}
// Macroblock (1, 0) on the top row: only the left neighbour exists, and its vector is used.
#[test]
fn predictor_single_left_neighbor_returns_left() {
let mut g = EncoderMvGrid::new(2, 2);
g.fill(0, 0, 4, 4, MotionVector { mv_x: 10, mv_y: -5 }, 0);
let mv = predict_mv_for_partition(&g, 4, 0, 4, 0);
assert_eq!(mv, MotionVector { mv_x: 10, mv_y: -5 });
}
// Macroblock (1, 1) has three neighbours, all on ref 0: A = (100, 0), B = (2, 20), C = (3, 30).
// The result is the component-wise median (3, 20).
#[test]
fn predictor_three_neighbors_componentwise_median() {
let mut g = EncoderMvGrid::new(3, 2);
g.fill(0, 0, 4, 4, MotionVector { mv_x: 1, mv_y: 10 }, 0); g.fill(4, 0, 4, 4, MotionVector { mv_x: 2, mv_y: 20 }, 0);
g.fill(8, 0, 4, 4, MotionVector { mv_x: 3, mv_y: 30 }, 0); g.fill(0, 4, 4, 4, MotionVector { mv_x: 100, mv_y: 0 }, 0); let mv = predict_mv_for_partition(&g, 4, 4, 4, 0);
assert_eq!(mv, MotionVector { mv_x: 3, mv_y: 20 });
}
// P_Skip at the picture's top-left corner is zero.
#[test]
fn p_skip_mv_no_neighbors_is_zero() {
let g = EncoderMvGrid::new(2, 2);
let mv = predict_p_skip_mv(&g, 0, 0);
assert_eq!(mv, MotionVector::ZERO);
}
// P_Skip on the left picture edge is zero even though the upper neighbour has motion.
#[test]
fn p_skip_mv_left_unavailable_is_zero() {
let mut g = EncoderMvGrid::new(2, 2);
g.fill(0, 0, 4, 4, MotionVector { mv_x: 7, mv_y: 3 }, 0);
let mv = predict_p_skip_mv(&g, 0, 4); assert_eq!(mv, MotionVector::ZERO);
}
// P_Skip on the top picture edge is zero even though the left neighbour has motion.
#[test]
fn p_skip_mv_top_unavailable_is_zero() {
let mut g = EncoderMvGrid::new(2, 2);
g.fill(0, 0, 4, 4, MotionVector { mv_x: 7, mv_y: 3 }, 0);
let mv = predict_p_skip_mv(&g, 4, 0); assert_eq!(mv, MotionVector::ZERO);
}
// Left neighbour uses ref 0 with a zero vector, so P_Skip is zero regardless of the upper neighbour.
#[test]
fn p_skip_mv_trivial_a_shortcut_to_zero() {
let mut g = EncoderMvGrid::new(3, 2);
g.fill(0, 4, 4, 4, MotionVector { mv_x: 0, mv_y: 0 }, 0);
g.fill(4, 0, 4, 4, MotionVector { mv_x: 5, mv_y: 5 }, 0);
let mv = predict_p_skip_mv(&g, 4, 4);
assert_eq!(mv, MotionVector::ZERO, "trivial A should shortcut to zero");
}
// Upper neighbour uses ref 0 with a zero vector, so P_Skip is zero regardless of the left neighbour.
#[test]
fn p_skip_mv_trivial_b_shortcut_to_zero() {
let mut g = EncoderMvGrid::new(3, 2);
g.fill(0, 4, 4, 4, MotionVector { mv_x: 5, mv_y: 5 }, 0); g.fill(4, 0, 4, 4, MotionVector { mv_x: 0, mv_y: 0 }, 0); let mv = predict_p_skip_mv(&g, 4, 4);
assert_eq!(mv, MotionVector::ZERO);
}
// fill() writes only L0; a later L1-only fill_lists leaves L0 intact.
#[test]
fn dual_list_l1_independent_from_l0() {
let mut g = EncoderMvGrid::new(2, 2);
let mv_l0 = MotionVector { mv_x: 10, mv_y: 20 };
let mv_l1 = MotionVector { mv_x: -30, mv_y: -40 };
g.fill(0, 0, 4, 4, mv_l0, 0);
assert_eq!(g.get_l0(0, 0), Some((mv_l0, 0)));
assert_eq!(g.get_l1(0, 0), None);
g.fill_lists(0, 0, 4, 4, None, Some((mv_l1, 0)));
assert_eq!(g.get_l0(0, 0), Some((mv_l0, 0)));
assert_eq!(g.get_l1(0, 0), Some((mv_l1, 0)));
}
// clear_l1_at removes L1 data across the region and leaves L0 alone.
#[test]
fn dual_list_bipred_fill_then_clear_l1() {
let mut g = EncoderMvGrid::new(2, 2);
let mv_l0 = MotionVector { mv_x: 1, mv_y: 2 };
let mv_l1 = MotionVector { mv_x: 3, mv_y: 4 };
g.fill_lists(0, 0, 4, 4, Some((mv_l0, 0)), Some((mv_l1, 0)));
assert_eq!(g.get_l0(0, 0), Some((mv_l0, 0)));
assert_eq!(g.get_l1(0, 0), Some((mv_l1, 0)));
g.clear_l1_at(0, 0, 4, 4);
for bx in 0..4 {
for by in 0..4 {
assert_eq!(g.get_l1(bx, by), None,
"L1 should be absent at ({bx},{by}) after clear_l1_at");
assert_eq!(g.get_l0(bx, by), Some((mv_l0, 0)),
"L0 should be unchanged at ({bx},{by})");
}
}
}
// reset() clears both lists and the decoded flag.
#[test]
fn dual_list_reset_clears_both() {
let mut g = EncoderMvGrid::new(1, 1);
g.fill_lists(
0, 0, 4, 4,
Some((MotionVector { mv_x: 1, mv_y: 2 }, 0)),
Some((MotionVector { mv_x: 3, mv_y: 4 }, 0)),
);
g.reset();
assert_eq!(g.get_l0(0, 0), None);
assert_eq!(g.get_l1(0, 0), None);
assert!(!g.is_decoded(0, 0));
}
// snapshot_mb/restore_mb round-trips both lists across an intervening overwrite.
#[test]
fn dual_list_snapshot_restore_preserves_both_lists() {
let mut g = EncoderMvGrid::new(2, 2);
let mv_l0_before = MotionVector { mv_x: 10, mv_y: 20 };
let mv_l1_before = MotionVector { mv_x: 30, mv_y: 40 };
g.fill_lists(0, 0, 4, 4, Some((mv_l0_before, 0)), Some((mv_l1_before, 0)));
let snap = g.snapshot_mb(0, 0);
g.fill_lists(
0, 0, 4, 4,
Some((MotionVector { mv_x: 99, mv_y: 99 }, 0)),
Some((MotionVector { mv_x: -99, mv_y: -99 }, 0)),
);
g.restore_mb(&snap);
for bx in 0..4 {
for by in 0..4 {
assert_eq!(g.get_l0(bx, by), Some((mv_l0_before, 0)));
assert_eq!(g.get_l1(bx, by), Some((mv_l1_before, 0)));
}
}
}
// An L0-only fill on a fresh grid leaves L1 absent everywhere in the region.
#[test]
fn p_style_fill_leaves_l1_absent_for_predictor() {
let mut g = EncoderMvGrid::new(2, 2);
g.fill(0, 0, 4, 4, MotionVector { mv_x: 7, mv_y: 9 }, 0);
for bx in 0..4 {
for by in 0..4 {
assert_eq!(g.get_l1(bx, by), None);
}
}
}
// No trivial-zero shortcut applies here, so P_Skip uses the ref-0 median:
// A = (100, 1), B = (2, 20), C = (3, 30) gives (3, 20).
#[test]
fn p_skip_mv_nonzero_neighbors_use_median() {
let mut g = EncoderMvGrid::new(3, 2);
g.fill(0, 0, 4, 4, MotionVector { mv_x: 1, mv_y: 10 }, 0);
g.fill(4, 0, 4, 4, MotionVector { mv_x: 2, mv_y: 20 }, 0);
g.fill(8, 0, 4, 4, MotionVector { mv_x: 3, mv_y: 30 }, 0);
g.fill(0, 4, 4, 4, MotionVector { mv_x: 100, mv_y: 1 }, 0);
let mv = predict_p_skip_mv(&g, 4, 4);
assert_eq!(mv, MotionVector { mv_x: 3, mv_y: 20 });
}
}