use crate::codec::h264::cavlc::{EmbedDomain, EmbeddablePosition};
use crate::codec::h264::slice::SliceType;
use crate::det_math::det_exp;
use crate::stego::cost::h264_mvd_cost::{
compute_mvd_cost, mb_luma_residual_energy, MvdCostParams,
};
/// Per-position weights for a 4x4 block, indexed by `scan_pos` (clamped to 15 by callers).
///
/// Presumably contrast-sensitivity-function (CSF) weights. Index 0 (DC) is
/// `INFINITY`, so DC positions are always wet. Weights never increase with
/// the index, so higher scan positions are cheaper to modify.
const CSF_WEIGHT_4X4: [f32; 16] = [
    // DC: never modified.
    f32::INFINITY,
    // AC positions, grouped by equal weight.
    8.0, 8.0,
    4.0, 4.0, 4.0,
    2.0, 2.0, 2.0, 2.0,
    1.0, 1.0, 1.0,
    0.5, 0.5,
    0.25,
];
/// Multiplier applied to every cost when the slice is intra-coded.
const I_FRAME_PENALTY: f32 = 2.0;
/// Blocks whose AC energy falls below this threshold are treated as wet (flat blocks).
const MIN_AC_ENERGY: f32 = 1.0;
/// Computes one embedding cost per entry of `positions`, in the same order.
///
/// A cost of `f32::INFINITY` marks a "wet" position that must not be modified.
/// Finite costs are relative: lower means cheaper to modify.
///
/// * `MvdLsb` positions are delegated to `compute_mvd_cost`. It is fed the value
///   returned by `mb_luma_residual_energy(block_ac_energies, mb_idx)` together
///   with the temporal and intra multipliers below.
/// * Every other position costs
///   `CSF_WEIGHT_4X4[scan_pos] * domain_mult * 1 / (1 + ac_energy) * temporal_weight * i_frame_mult`,
///   where `ac_energy = block_ac_energies[block_idx]` (0.0 when out of range).
///   Such a position is wet when any of these hold:
///   - `block_idx == u32::MAX`;
///   - `scan_pos == 0` (DC);
///   - its CSF weight is infinite;
///   - `ac_energy` is below `MIN_AC_ENERGY`;
///   - `ac_energy` is NaN.
///
/// # Parameters
/// * `positions` — candidate positions to cost.
/// * `block_ac_energies` — per-block AC energy indexed by `block_idx`
///   (presumably produced by [`block_ac_energy`] — confirm with the caller).
/// * `slice_type` — intra slices scale every cost by `I_FRAME_PENALTY`.
/// * `gop_position` — frame offset within the GOP; costs are scaled by
///   `exp(-0.3 * gop_position)`, so frames later in the GOP are cheaper.
/// * `gop_length` — currently unused; kept for signature compatibility.
pub fn compute_h264_costs(
    positions: &[EmbeddablePosition],
    block_ac_energies: &[f32],
    slice_type: SliceType,
    gop_position: u32,
    gop_length: u32,
) -> Vec<f32> {
    const TEMPORAL_DECAY_ALPHA: f64 = 0.3;
    // `det_exp` rather than `f64::exp`, presumably so costs are reproducible
    // across platforms (embedder and extractor must agree).
    let temporal_weight = det_exp(-TEMPORAL_DECAY_ALPHA * gop_position as f64) as f32;
    // Not used by the current cost model; accepted so the public signature stays stable.
    let _ = gop_length;
    let i_frame_mult = if slice_type.is_intra() {
        I_FRAME_PENALTY
    } else {
        1.0
    };
    // Loop-invariant: build the MVD parameters once instead of per position.
    let mvd_params = MvdCostParams::default();

    positions
        .iter()
        .map(|pos| {
            let domain_mult = match pos.domain {
                // Motion-vector positions use their own cost model.
                EmbedDomain::MvdLsb => {
                    let mb_residual = mb_luma_residual_energy(block_ac_energies, pos.mb_idx);
                    return compute_mvd_cost(
                        pos.coeff_value,
                        mb_residual,
                        temporal_weight,
                        i_frame_mult,
                        &mvd_params,
                    );
                }
                EmbedDomain::T1Sign => 2.0,
                EmbedDomain::LevelSuffixMag => 1.0,
                // Flipping a sign costs more the larger the coefficient magnitude.
                EmbedDomain::LevelSuffixSign => 2.0 * pos.coeff_value.unsigned_abs() as f32,
            };

            // Sentinel block index: presumably "not tied to a 4x4 block" — wet.
            if pos.block_idx == u32::MAX {
                return f32::INFINITY;
            }
            // DC coefficients are never modified.
            if pos.scan_pos == 0 {
                return f32::INFINITY;
            }
            let csf = CSF_WEIGHT_4X4[pos.scan_pos.min(15) as usize];
            if csf.is_infinite() {
                return f32::INFINITY;
            }

            let ac_energy = block_ac_energies
                .get(pos.block_idx as usize)
                .copied()
                .unwrap_or(0.0);
            // Flat blocks are wet. NaN is rejected explicitly: it would slip past
            // a plain `<` comparison and turn the cost into NaN.
            if ac_energy.is_nan() || ac_energy < MIN_AC_ENERGY {
                return f32::INFINITY;
            }

            let texture_mask = 1.0 / (1.0 + ac_energy);
            csf * domain_mult * texture_mask * temporal_weight * i_frame_mult
        })
        .collect()
}
/// Returns the AC energy of a 4x4 block: the L2 norm of `coeffs[1..]`.
///
/// `coeffs[0]` (the DC term) is excluded. The sum of squares is accumulated in
/// `i128`; an `i64` accumulator overflows for extreme inputs, since
/// `2 * i32::MIN^2 = 2^63 > i64::MAX`. Such an overflow panics in debug builds
/// and yields NaN in release builds. For every input that fits in `i64`, the
/// result is bit-identical to an `i64` accumulation.
pub fn block_ac_energy(coeffs: &[i32; 16]) -> f32 {
    let sum_sq: i128 = coeffs[1..]
        .iter()
        .map(|&c| i128::from(c) * i128::from(c))
        .sum();
    (sum_sq as f64).sqrt() as f32
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline fixture: a T1-sign position at the highest scan index of block 0,
    /// with no emulation-prevention conflict. Tests override individual fields
    /// via struct-update syntax.
    fn t1_sign_position() -> EmbeddablePosition {
        EmbeddablePosition {
            raw_byte_offset: 0,
            bit_offset: 0,
            domain: EmbedDomain::T1Sign,
            scan_pos: 15,
            coeff_value: 1,
            ep_conflict: false,
            block_idx: 0,
            frame_idx: 0,
            mb_idx: 0,
        }
    }

    #[test]
    fn csf_weights_monotonic_high_freq() {
        // Compare each AC weight with its successor (indices 1..=15).
        for (offset, pair) in CSF_WEIGHT_4X4[1..].windows(2).enumerate() {
            let i = offset + 1;
            let (here, next) = (pair[0], pair[1]);
            assert!(
                here >= next,
                "CSF_WEIGHT_4X4[{i}]={} < CSF_WEIGHT_4X4[{}]={}",
                here,
                i + 1,
                next
            );
        }
    }

    #[test]
    fn csf_dc_is_wet() {
        assert!(CSF_WEIGHT_4X4[0].is_infinite());
    }

    #[test]
    fn temporal_decay_is_monotonically_decreasing() {
        let positions = vec![EmbeddablePosition {
            scan_pos: 5,
            ..t1_sign_position()
        }];
        let ac = vec![10.0];
        let cost_at =
            |gop_pos: u32| compute_h264_costs(&positions, &ac, SliceType::P, gop_pos, 30)[0];

        let mut prev = f32::INFINITY;
        for gop_pos in 0..=10u32 {
            let c = cost_at(gop_pos);
            assert!(c.is_finite(), "cost must be finite at gop_pos={gop_pos}");
            assert!(c < prev, "cost must strictly decrease, gop_pos={gop_pos} ({c} >= {prev})");
            prev = c;
        }

        // Five steps of exp(-0.3 * gop_pos) decay should scale cost by e^(-1.5).
        let ratio = cost_at(5) / cost_at(0);
        assert!(
            (ratio - 0.2231).abs() < 0.002,
            "gop_pos=5/0 ratio = {ratio}, expected ~0.223 (e^(-1.5))"
        );
    }

    #[test]
    fn cost_ep_conflict_no_longer_forces_wet() {
        let conflicted = EmbeddablePosition {
            ep_conflict: true,
            ..t1_sign_position()
        };
        let costs = compute_h264_costs(&[conflicted], &[10.0], SliceType::I, 0, 30);
        assert!(costs[0].is_finite(), "ep_conflict should no longer force infinite cost");
    }

    #[test]
    fn cost_dc_position_is_wet() {
        let dc = EmbeddablePosition {
            scan_pos: 0,
            ..t1_sign_position()
        };
        let costs = compute_h264_costs(&[dc], &[10.0], SliceType::I, 0, 30);
        assert!(costs[0].is_infinite());
    }

    #[test]
    fn cost_flat_block_is_wet() {
        // AC energy 0.5 is below MIN_AC_ENERGY.
        let costs = compute_h264_costs(&[t1_sign_position()], &[0.5], SliceType::I, 0, 30);
        assert!(costs[0].is_infinite());
    }

    #[test]
    fn cost_temporal_weight() {
        let ac = [10.0];
        let cost_i = compute_h264_costs(&[t1_sign_position()], &ac, SliceType::I, 0, 30)[0];
        let cost_p = compute_h264_costs(&[t1_sign_position()], &ac, SliceType::P, 29, 30)[0];
        assert!(cost_i > cost_p, "I-frame cost {cost_i} should be > P-frame cost {cost_p}");
    }

    #[test]
    fn cost_texture_masking() {
        let cost_low = compute_h264_costs(&[t1_sign_position()], &[2.0], SliceType::P, 1, 30)[0];
        let cost_high = compute_h264_costs(&[t1_sign_position()], &[50.0], SliceType::P, 1, 30)[0];
        assert!(cost_low > cost_high, "low-texture cost {cost_low} should be > high-texture cost {cost_high}");
    }

    #[test]
    fn block_ac_energy_computation() {
        // DC (10) is ignored: sqrt(3^2 + 2^2 + 1^2) = sqrt(14) ~= 3.742.
        let energy = block_ac_energy(&[10, 3, -2, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]);
        assert!((energy - 3.74).abs() < 0.1);
    }

    #[test]
    fn block_ac_energy_zero_block() {
        assert_eq!(block_ac_energy(&[0; 16]), 0.0);
    }

    #[test]
    fn block_ac_energy_dc_only() {
        let mut coeffs = [0; 16];
        coeffs[0] = 42;
        assert_eq!(block_ac_energy(&coeffs), 0.0);
    }
}