use crate::codec::h264::cavlc::EmbedDomain;
use crate::codec::h264::macroblock::{Macroblock, MbType, BLOCK_INDEX_TO_POS};
use crate::codec::h264::mv::MvField;
use super::h264_uniward::FramePosition;
/// Tuning knobs for the drift-aware cost adjustment in this module.
///
/// Intra-frame drift is applied by `apply_drift_multipliers` as
/// `cost * (1 + w_drift * drift)`; inter-frame drift is applied by
/// `apply_inter_frame_drift` as `cost * (1 + w_inter_drift * refs)`.
#[derive(Debug, Clone, Copy)]
pub struct DdcaParams {
    /// Weight of the intra-prediction drift score in the cost multiplier.
    pub w_drift: f32,
    /// Multiplicative decay applied to the drift signal on every hop of the
    /// forward walk (combined with the per-mode `PROPAGATION_ATTENUATION`).
    pub decay_per_hop: f32,
    /// Hop count after which the forward walk stops expanding a path.
    pub max_hops: u8,
    /// Weight of the accumulated inter-frame reference count in the cost multiplier.
    pub w_inter_drift: f32,
    /// NOTE(review): not read anywhere in this file; presumably callers pass it as
    /// the `decay` argument of `InterFrameRefMap::accumulate_mv_field` — confirm.
    pub inter_frame_decay: f32,
}

impl Default for DdcaParams {
    fn default() -> Self {
        DdcaParams {
            max_hops: 3,
            w_drift: 0.35,
            decay_per_hop: 0.7,
            inter_frame_decay: 0.75,
            w_inter_drift: 0.25,
        }
    }
}
/// Dependency weight of an Intra_4x4 block on one already-coded neighbour,
/// read as `DRIFT_WEIGHTS[mode][relation]` by `forward_drift`.
///
/// Rows are Intra_4x4 prediction modes 0..=8 (assumed H.264 numbering: 0 Vertical,
/// 1 Horizontal, 2 DC, 3 Diagonal-Down-Left, 4 Diagonal-Down-Right, 5 Vertical-Right,
/// 6 Horizontal-Down, 7 Vertical-Left, 8 Horizontal-Up).
///
/// Columns give where the *source* block sits relative to the predicted block, using
/// the tags emitted by `forward_successors`: 0 = left, 1 = top, 3 = top-left.
/// NOTE(review): column 2 (presumably top-right) is never emitted by
/// `forward_successors`, so its weights are currently unused.
static DRIFT_WEIGHTS: [[f32; 4]; 9] = [
    //  left  top  (2)  top-left
    [0.0, 1.0, 0.0, 0.0], // mode 0
    [1.0, 0.0, 0.0, 0.0], // mode 1
    [0.5, 0.5, 0.0, 0.0], // mode 2
    [0.0, 0.7, 0.5, 0.0], // mode 3
    [0.6, 0.6, 0.0, 0.4], // mode 4
    [0.3, 0.8, 0.0, 0.3], // mode 5
    [0.8, 0.3, 0.0, 0.3], // mode 6
    [0.0, 0.8, 0.5, 0.0], // mode 7
    [1.0, 0.0, 0.0, 0.0], // mode 8
];

/// Per-mode factor applied, together with `DdcaParams::decay_per_hop`, to the drift
/// signal a block forwards to its own successors in `forward_drift`.
static PROPAGATION_ATTENUATION: [f32; 9] = [
    1.00, // mode 0
    1.00, // mode 1
    0.50, // mode 2
    0.80, // mode 3
    0.75, // mode 4
    0.85, // mode 5
    0.85, // mode 6
    0.80, // mode 7
    0.95, // mode 8
];
/// Lists the 4x4 blocks that may predict from block `(bx, by)`.
///
/// Each entry is `(x, y, tag)`, where `tag` is the `DRIFT_WEIGHTS` column describing
/// where `(bx, by)` sits relative to that successor:
///
/// * slot 0: `(bx + 1, by)`, tag 0 (source is its left neighbour);
/// * slot 1: `(bx, by + 1)`, tag 1 (source is its top neighbour);
/// * slot 2: `(bx + 1, by + 1)`, tag 3 (source is its top-left neighbour).
///
/// A slot is `None` when its block falls outside the `width_in_4x4 x height_in_4x4`
/// grid. The slot order is fixed, and `forward_drift` pushes successors in that order.
///
/// NOTE(review): `(bx - 1, by + 1)` is not listed. For that block, `(bx, by)` would be
/// the top-right neighbour (column 2).
fn forward_successors(
    bx: usize,
    by: usize,
    width_in_4x4: usize,
    height_in_4x4: usize,
) -> [Option<(usize, usize, usize)>; 3] {
    let right_exists = bx + 1 < width_in_4x4;
    let below_exists = by + 1 < height_in_4x4;
    [
        right_exists.then_some((bx + 1, by, 0)),
        below_exists.then_some((bx, by + 1, 1)),
        (right_exists && below_exists).then_some((bx + 1, by + 1, 3)),
    ]
}
/// Per-4x4-block grid of luma Intra_4x4 prediction modes for one frame.
///
/// A block holds `None` when its macroblock is not `MbType::I4x4` or carries no
/// reconstruction data.
pub struct IntraModeMap {
    /// Row-major grid with `width_in_4x4 * height_in_4x4` entries.
    modes: Vec<Option<u8>>,
    width_in_4x4: usize,
    height_in_4x4: usize,
}

impl IntraModeMap {
    /// Builds the grid from macroblocks listed in raster order
    /// (`mb_idx = mb_y * width_in_mbs + mb_x`).
    ///
    /// For every I4x4 macroblock with reconstruction data, `recon.intra4x4_modes[blk_idx]`
    /// is stored at offset `BLOCK_INDEX_TO_POS[blk_idx]` inside that macroblock.
    ///
    /// # Panics
    ///
    /// Panics if `width_in_mbs == 0` while `mbs` is non-empty (division by zero).
    /// Also panics if an I4x4 macroblock with reconstruction data lies beyond
    /// `width_in_mbs * height_in_mbs`, because the write index is then out of range.
    pub fn build(mbs: &[Macroblock], width_in_mbs: usize, height_in_mbs: usize) -> Self {
        let width_in_4x4 = width_in_mbs * 4;
        let height_in_4x4 = height_in_mbs * 4;
        let mut grid = vec![None; width_in_4x4 * height_in_4x4];

        for (mb_idx, mb) in mbs.iter().enumerate() {
            // Top-left 4x4 coordinate of this macroblock in the frame grid.
            let origin_x = (mb_idx % width_in_mbs) * 4;
            let origin_y = (mb_idx / width_in_mbs) * 4;

            let recon = match (&mb.mb_type, mb.recon.as_ref()) {
                (MbType::I4x4, Some(recon)) => recon,
                _ => continue,
            };

            for (blk_idx, &(dx, dy)) in BLOCK_INDEX_TO_POS.iter().enumerate().take(16) {
                let x = origin_x + dx as usize;
                let y = origin_y + dy as usize;
                grid[y * width_in_4x4 + x] = Some(recon.intra4x4_modes[blk_idx]);
            }
        }

        IntraModeMap {
            modes: grid,
            width_in_4x4,
            height_in_4x4,
        }
    }

    /// Mode of 4x4 block `(bx, by)`. Returns `None` outside the grid or where no mode was recorded.
    pub fn mode_at(&self, bx: usize, by: usize) -> Option<u8> {
        let in_bounds = bx < self.width_in_4x4 && by < self.height_in_4x4;
        if in_bounds {
            self.modes[by * self.width_in_4x4 + bx]
        } else {
            None
        }
    }

    /// Grid width in 4x4 blocks.
    pub fn width_in_4x4(&self) -> usize {
        self.width_in_4x4
    }

    /// Grid height in 4x4 blocks.
    pub fn height_in_4x4(&self) -> usize {
        self.height_in_4x4
    }
}
/// Scales each position's embedding cost by the intra-prediction drift its 4x4 block
/// can cause: `cost * (1 + w_drift * forward_drift(..))`.
///
/// A position's cost is returned unchanged in any of these cases:
/// * the cost is not finite;
/// * `within_mb_block_idx >= 16`, i.e. not one of the 16 luma 4x4 blocks (the tests use
///   index 18 for a chroma position);
/// * the position is in the `LevelSuffixSign` domain with `|coeff_value| > 4`.
///
/// `frame_positions` and `base_costs` are paired element-wise. Their lengths are only
/// checked in debug builds; otherwise the output is as long as the shorter input.
pub fn apply_drift_multipliers(
    frame_positions: &[FramePosition<'_>],
    base_costs: &[f32],
    modes: &IntraModeMap,
    width_in_mbs: usize,
    params: &DdcaParams,
) -> Vec<f32> {
    debug_assert_eq!(frame_positions.len(), base_costs.len());
    let mut adjusted = Vec::with_capacity(frame_positions.len().min(base_costs.len()));
    for (fp, &cost) in frame_positions.iter().zip(base_costs) {
        let passthrough = !cost.is_finite()
            || fp.within_mb_block_idx >= 16
            || (matches!(fp.pos.domain, EmbedDomain::LevelSuffixSign)
                && fp.pos.coeff_value.unsigned_abs() > 4);
        if passthrough {
            adjusted.push(cost);
            continue;
        }
        // Map (macroblock, block-in-MB) to absolute 4x4 grid coordinates.
        let (dx, dy) = BLOCK_INDEX_TO_POS[fp.within_mb_block_idx];
        let bx = (fp.mb_idx % width_in_mbs) * 4 + dx as usize;
        let by = (fp.mb_idx / width_in_mbs) * 4 + dy as usize;
        let drift = forward_drift(bx, by, modes, params);
        adjusted.push(cost * (1.0 + params.w_drift * drift));
    }
    adjusted
}
/// Scores how strongly an error in 4x4 luma block `(bx, by)` can leak into later blocks
/// through Intra_4x4 prediction.
///
/// The walk starts from the successors returned by `forward_successors` and expands
/// depth-first, using a LIFO stack. Each block is scored at most once, on its first
/// visit, as `incoming * DRIFT_WEIGHTS[mode][relation]`, where `relation` is the edge it
/// was reached by.
///
/// The signal forwarded to the block's own successors is
/// `incoming * decay_per_hop * PROPAGATION_ATTENUATION[mode]`. Expansion stops once
/// `max_hops` is reached or that signal drops below `0.01`. Blocks without a mode
/// (`None` from `mode_at`) end the path. Because only the first visit counts, the
/// result depends on the traversal order.
///
/// Returns `0.0` when `(bx, by)` lies outside `modes`. Mode values outside the tables
/// (above 8) are treated like blocks without a mode instead of panicking.
fn forward_drift(bx: usize, by: usize, modes: &IntraModeMap, params: &DdcaParams) -> f32 {
    let width = modes.width_in_4x4();
    let height = modes.height_in_4x4();
    // An origin outside the map (e.g. the caller's `width_in_mbs` disagrees with `modes`)
    // has nothing to propagate into.
    if bx >= width || by >= height {
        return 0.0;
    }

    // Successors only step +0/+1 along each axis. Nodes are pushed with at most
    // `max(max_hops, 1)` hops, since the initial successors always carry hop 1.
    // So every block the walk can touch lies in a small square window anchored at the
    // origin. Tracking visits there avoids allocating a frame-sized bitmap per call.
    let reach = usize::from(params.max_hops.max(1));
    let side = reach + 1;
    let window_index = |x: usize, y: usize| (y - by) * side + (x - bx);
    let mut visited = vec![false; side * side];
    visited[window_index(bx, by)] = true;

    // (x, y, hops, incoming signal, relation tag); popped LIFO.
    let mut stack: Vec<(usize, usize, u8, f32, usize)> = Vec::with_capacity(8);
    for &(sx, sy, relation) in forward_successors(bx, by, width, height).iter().flatten() {
        stack.push((sx, sy, 1, 1.0, relation));
    }

    let mut drift = 0.0f32;
    while let Some((nx, ny, hops, incoming, relation)) = stack.pop() {
        let slot = window_index(nx, ny);
        if visited[slot] {
            continue;
        }
        visited[slot] = true;

        let Some(mode) = modes.mode_at(nx, ny) else {
            continue;
        };
        let mode = usize::from(mode);
        let (Some(weights), Some(&attenuation)) =
            (DRIFT_WEIGHTS.get(mode), PROPAGATION_ATTENUATION.get(mode))
        else {
            continue;
        };

        drift += incoming * weights[relation];
        if hops >= params.max_hops {
            continue;
        }
        let next_incoming = incoming * params.decay_per_hop * attenuation;
        if next_incoming < 0.01 {
            continue;
        }
        for &(sx, sy, next_relation) in forward_successors(nx, ny, width, height).iter().flatten() {
            stack.push((sx, sy, hops + 1, next_incoming, next_relation));
        }
    }
    drift
}
/// Per-4x4-block accumulator of how much each block of a frame is referenced by the
/// motion vectors of another (presumably P) frame's macroblocks.
pub struct InterFrameRefMap {
    /// Row-major weights, `width_in_4x4 * height_in_4x4` entries.
    ref_counts: Vec<f32>,
    width_in_4x4: usize,
    height_in_4x4: usize,
}

impl InterFrameRefMap {
    /// Creates an all-zero map of `width_in_4x4 x height_in_4x4` blocks.
    pub fn new(width_in_4x4: usize, height_in_4x4: usize) -> Self {
        let ref_counts = vec![0.0; width_in_4x4 * height_in_4x4];
        InterFrameRefMap {
            ref_counts,
            width_in_4x4,
            height_in_4x4,
        }
    }

    /// Map width in 4x4 blocks.
    pub fn width_in_4x4(&self) -> usize {
        self.width_in_4x4
    }

    /// Map height in 4x4 blocks.
    pub fn height_in_4x4(&self) -> usize {
        self.height_in_4x4
    }

    /// Accumulated reference weight of block `(bx, by)`; `0.0` outside the map.
    pub fn ref_count(&self, bx: usize, by: usize) -> f32 {
        let inside = bx < self.width_in_4x4 && by < self.height_in_4x4;
        if inside {
            self.ref_counts[by * self.width_in_4x4 + bx]
        } else {
            0.0
        }
    }

    /// Adds the footprint of one macroblock's 16 motion vectors to the map.
    ///
    /// Partitions with a negative `ref_idx` are skipped. For each remaining partition,
    /// the 4x4 pixel square it points at is split over the (up to) four grid blocks it
    /// overlaps. Each overlapped block receives `overlap_area / 16 * decay`. Parts that
    /// fall outside the map are dropped.
    ///
    /// MV components are shifted right by 2 (treated as quarter-pel). The arithmetic
    /// shift floors toward negative infinity and discards the fractional part.
    ///
    /// NOTE(review): `blk_idx` is mapped to `(blk_idx % 4, blk_idx / 4)`, i.e. raster
    /// order, whereas `IntraModeMap::build` and the `apply_*` functions use
    /// `BLOCK_INDEX_TO_POS`. Confirm that `MvField` stores its entries in raster order.
    pub fn accumulate_mv_field(
        &mut self,
        mv_field: &MvField,
        p_mb_x: usize,
        p_mb_y: usize,
        decay: f32,
    ) {
        for blk_idx in 0..16usize {
            if mv_field.ref_idx[blk_idx] < 0 {
                continue;
            }
            let mv = mv_field.mvs[blk_idx];

            // Top-left pixel of this partition, then of the square it references.
            let src_x = (p_mb_x * 16 + (blk_idx % 4) * 4) as i32;
            let src_y = (p_mb_y * 16 + (blk_idx / 4) * 4) as i32;
            let dst_x = src_x + ((mv.mv_x as i32) >> 2);
            let dst_y = src_y + ((mv.mv_y as i32) >> 2);

            // The referenced square straddles grid column `col` / `col + 1` and row
            // `row` / `row + 1`; each span is (grid index, overlap length in pixels).
            let (col, frac_x) = (dst_x.div_euclid(4), dst_x.rem_euclid(4));
            let (row, frac_y) = (dst_y.div_euclid(4), dst_y.rem_euclid(4));
            let col_spans = [(col, 4 - frac_x), (col + 1, frac_x)];
            let row_spans = [(row, 4 - frac_y), (row + 1, frac_y)];

            for &(by2, h) in &row_spans {
                for &(bx2, w) in &col_spans {
                    let outside = bx2 < 0
                        || by2 < 0
                        || bx2 as usize >= self.width_in_4x4
                        || by2 as usize >= self.height_in_4x4;
                    if outside || w <= 0 || h <= 0 {
                        continue;
                    }
                    let share = (w * h) as f32 / 16.0 * decay;
                    self.ref_counts[by2 as usize * self.width_in_4x4 + bx2 as usize] += share;
                }
            }
        }
    }
}
/// Scales each position's embedding cost by how often its 4x4 block is referenced by
/// motion vectors: `cost * (1 + w_inter_drift * ref_map.ref_count(bx, by))`.
///
/// A position's cost is returned unchanged when it is not finite, or when
/// `within_mb_block_idx >= 16` (not one of the 16 luma 4x4 blocks).
///
/// `frame_positions` and `base_costs` are paired element-wise. Their lengths are only
/// checked in debug builds; otherwise the output is as long as the shorter input.
pub fn apply_inter_frame_drift(
    frame_positions: &[FramePosition<'_>],
    base_costs: &[f32],
    ref_map: &InterFrameRefMap,
    width_in_mbs: usize,
    params: &DdcaParams,
) -> Vec<f32> {
    debug_assert_eq!(frame_positions.len(), base_costs.len());
    let mut adjusted = Vec::with_capacity(frame_positions.len().min(base_costs.len()));
    for (fp, &cost) in frame_positions.iter().zip(base_costs) {
        let scaled = if cost.is_finite() && fp.within_mb_block_idx < 16 {
            // Map (macroblock, block-in-MB) to absolute 4x4 grid coordinates.
            let (dx, dy) = BLOCK_INDEX_TO_POS[fp.within_mb_block_idx];
            let bx = (fp.mb_idx % width_in_mbs) * 4 + dx as usize;
            let by = (fp.mb_idx / width_in_mbs) * 4 + dy as usize;
            cost * (1.0 + params.w_inter_drift * ref_map.ref_count(bx, by))
        } else {
            cost
        };
        adjusted.push(scaled);
    }
    adjusted
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codec::h264::cavlc::EmbeddablePosition;

    /// Builds a `w_mbs x h_mbs`-macroblock mode map with every 4x4 block set to `fill`.
    fn make_modes(w_mbs: usize, h_mbs: usize, fill: u8) -> IntraModeMap {
        let w4 = w_mbs * 4;
        let h4 = h_mbs * 4;
        IntraModeMap {
            modes: vec![Some(fill); w4 * h4],
            width_in_4x4: w4,
            height_in_4x4: h4,
        }
    }

    #[test]
    fn drift_positive_when_all_modes_are_dc() {
        let modes = make_modes(2, 2, 2);
        let d = forward_drift(0, 0, &modes, &DdcaParams::default());
        assert!(d > 0.0, "DC modes should propagate some drift");
    }

    #[test]
    fn drift_zero_at_frame_corner_with_no_successors() {
        let w4 = 4;
        let h4 = 4;
        let modes = IntraModeMap {
            modes: vec![None; w4 * h4],
            width_in_4x4: w4,
            height_in_4x4: h4,
        };
        let d = forward_drift(3, 3, &modes, &DdcaParams::default());
        assert_eq!(d, 0.0);
    }

    #[test]
    fn vertical_mode_dominates_top_relation() {
        assert_eq!(DRIFT_WEIGHTS[0][1], 1.0);
        assert_eq!(DRIFT_WEIGHTS[0][0], 0.0);
    }

    #[test]
    fn horizontal_mode_dominates_left_relation() {
        assert_eq!(DRIFT_WEIGHTS[1][0], 1.0);
        assert_eq!(DRIFT_WEIGHTS[1][1], 0.0);
    }

    #[test]
    fn dc_chain_attenuates_faster_than_vertical_chain() {
        // Mode 2 (DC) everywhere vs mode 0 (Vertical) everywhere.
        let dc_modes = make_modes(8, 8, 2);
        let vertical_modes = make_modes(8, 8, 0);
        let params = DdcaParams::default();
        let dc_drift = forward_drift(0, 0, &dc_modes, &params);
        let v_drift = forward_drift(0, 0, &vertical_modes, &params);
        assert!(
            v_drift > dc_drift,
            "Vertical-chain drift ({v_drift}) should exceed DC-chain drift ({dc_drift})"
        );
        assert_eq!(PROPAGATION_ATTENUATION[0], 1.00);
        assert_eq!(PROPAGATION_ATTENUATION[1], 1.00);
        assert_eq!(PROPAGATION_ATTENUATION[2], 0.50);
        assert!(PROPAGATION_ATTENUATION[2] < PROPAGATION_ATTENUATION[0] * 0.7);
    }

    #[test]
    fn apply_drift_multipliers_leaves_chroma_unchanged() {
        let pos = EmbeddablePosition {
            raw_byte_offset: 0,
            bit_offset: 0,
            domain: EmbedDomain::T1Sign,
            scan_pos: 5,
            coeff_value: 1,
            ep_conflict: false,
            block_idx: 18,
            frame_idx: 0,
            mb_idx: 0,
        };
        let positions = vec![FramePosition {
            pos: &pos,
            mb_idx: 0,
            within_mb_block_idx: 18,
            qp_cb: 26,
            qp_cr: 26,
        }];
        let base_costs = vec![100.0f32];
        let modes = make_modes(1, 1, 0);
        let out = apply_drift_multipliers(&positions, &base_costs, &modes, 1, &DdcaParams::default());
        assert_eq!(out.len(), 1);
        assert_eq!(out[0], 100.0, "chroma position cost must pass through unchanged");
    }

    #[test]
    fn apply_drift_multipliers_boosts_luma_cost() {
        let pos = EmbeddablePosition {
            raw_byte_offset: 0,
            bit_offset: 0,
            domain: EmbedDomain::T1Sign,
            scan_pos: 5,
            coeff_value: 1,
            ep_conflict: false,
            block_idx: 0,
            frame_idx: 0,
            mb_idx: 0,
        };
        let positions = vec![FramePosition {
            pos: &pos,
            mb_idx: 0,
            within_mb_block_idx: 0,
            qp_cb: 26,
            qp_cr: 26,
        }];
        let base_costs = vec![100.0f32];
        let modes = make_modes(2, 2, 2);
        let out = apply_drift_multipliers(&positions, &base_costs, &modes, 2, &DdcaParams::default());
        assert!(out[0] > 100.0, "luma cost should be boosted by drift; got {}", out[0]);
    }

    #[test]
    fn inter_frame_ref_map_accumulates_zero_mv_at_current_block() {
        let mut map = InterFrameRefMap::new(8, 8);
        let mut mv_field = MvField::default();
        for i in 0..16 {
            mv_field.ref_idx[i] = 0;
        }
        map.accumulate_mv_field(&mv_field, 1, 1, 1.0);
        for by in 4..8 {
            for bx in 4..8 {
                assert!(
                    (map.ref_count(bx, by) - 1.0).abs() < 1e-5,
                    "zero MV should produce unit self-reference, got {} at ({bx},{by})",
                    map.ref_count(bx, by)
                );
            }
        }
        assert_eq!(map.ref_count(0, 0), 0.0);
        assert_eq!(map.ref_count(2, 3), 0.0);
    }

    #[test]
    fn inter_frame_ref_map_decay_applies() {
        let mut map = InterFrameRefMap::new(4, 4);
        let mut mv_field = MvField::default();
        for i in 0..16 {
            mv_field.ref_idx[i] = 0;
        }
        map.accumulate_mv_field(&mv_field, 0, 0, 0.5);
        for by in 0..4 {
            for bx in 0..4 {
                assert!((map.ref_count(bx, by) - 0.5).abs() < 1e-5);
            }
        }
    }

    #[test]
    fn inter_frame_drift_boosts_cost_where_references_exist() {
        let pos = EmbeddablePosition {
            raw_byte_offset: 0,
            bit_offset: 0,
            domain: EmbedDomain::T1Sign,
            scan_pos: 5,
            coeff_value: 1,
            ep_conflict: false,
            block_idx: 0,
            frame_idx: 0,
            mb_idx: 0,
        };
        let positions = vec![FramePosition {
            pos: &pos,
            mb_idx: 0,
            within_mb_block_idx: 0,
            qp_cb: 26,
            qp_cr: 26,
        }];
        let base_costs = vec![100.0f32];
        let mut ref_map = InterFrameRefMap::new(4, 4);
        ref_map.ref_counts[0] = 4.0;
        let out = apply_inter_frame_drift(&positions, &base_costs, &ref_map, 1, &DdcaParams::default());
        assert!(
            (out[0] - 200.0).abs() < 1e-3,
            "inter-frame drift should boost ~2x at 4 refs; got {}",
            out[0]
        );
    }

    #[test]
    fn infinite_base_cost_remains_infinite() {
        let pos = EmbeddablePosition {
            raw_byte_offset: 0,
            bit_offset: 0,
            domain: EmbedDomain::T1Sign,
            scan_pos: 0,
            coeff_value: 1,
            ep_conflict: false,
            block_idx: 0,
            frame_idx: 0,
            mb_idx: 0,
        };
        let positions = vec![FramePosition {
            pos: &pos,
            mb_idx: 0,
            within_mb_block_idx: 0,
            qp_cb: 26,
            qp_cr: 26,
        }];
        let base_costs = vec![f32::INFINITY];
        let modes = make_modes(1, 1, 0);
        let out = apply_drift_multipliers(&positions, &base_costs, &modes, 1, &DdcaParams::default());
        assert!(out[0].is_infinite());
    }
}