use super::{EmbedDomain, MvdSlot};
/// Flat cost (before drift scaling) of flipping one coefficient sign in the
/// `CoeffSignBypass` domain.
pub const COEFF_SIGN_BASE_COST: f32 = 1.0;
/// Flat cost (before drift scaling) of flipping one MVD sign in the
/// `MvdSignBypass` domain.
pub const MVD_SIGN_BASE_COST: f32 = 4.0;
/// Multiplier applied to `|coeff|²` when costing a coefficient suffix-LSB flip.
pub const COEFF_SUFFIX_LSB_PER_MAG2: f32 = 2.0;
/// Multiplier applied to `|mvd|²` when costing an MVD suffix-LSB flip.
pub const MVD_SUFFIX_LSB_PER_MAG2: f32 = 2.0;
/// Infinite cost; presumably used to mark positions that must never be
/// modified ("wet" positions). Not referenced in this chunk, so confirm with callers.
pub const WET_COST: f32 = f32::INFINITY;
/// Per-position context that scales every embedding cost in this module.
///
/// `frame_idx` and `mb_addr` identify the position. The two drift factors are
/// multiplicative penalties. Both start at the neutral value `1.0`, and their
/// product (see [`PositionCostCtx::drift_factor`]) multiplies each cost.
#[derive(Debug, Clone, Copy)]
pub struct PositionCostCtx {
    /// Frame index of the position.
    pub frame_idx: u32,
    /// Macroblock address of the position.
    pub mb_addr: u32,
    /// Intra drift multiplier (neutral value `1.0`).
    pub intra_drift_factor: f32,
    /// Inter drift multiplier (neutral value `1.0`).
    pub inter_drift_factor: f32,
}

impl PositionCostCtx {
    /// Builds a context for `(frame_idx, mb_addr)` with both drift factors
    /// set to the neutral `1.0`.
    pub fn new(frame_idx: u32, mb_addr: u32) -> Self {
        const NEUTRAL: f32 = 1.0;
        PositionCostCtx {
            frame_idx,
            mb_addr,
            intra_drift_factor: NEUTRAL,
            inter_drift_factor: NEUTRAL,
        }
    }

    /// Combined drift multiplier: the product of the intra and inter factors.
    #[inline]
    pub fn drift_factor(&self) -> f32 {
        let intra = self.intra_drift_factor;
        let inter = self.inter_drift_factor;
        intra * inter
    }
}

impl Default for PositionCostCtx {
    /// Same as `PositionCostCtx::new(0, 0)`: position zero, neutral drift.
    fn default() -> Self {
        PositionCostCtx::new(0, 0)
    }
}
/// Cost of flipping the sign of one coefficient.
///
/// The coefficient value is currently unused. The cost is the flat
/// [`COEFF_SIGN_BASE_COST`] scaled by the position's combined drift factor.
#[inline]
pub fn coeff_sign_cost(_coeff: i32, ctx: &PositionCostCtx) -> f32 {
    let drift = ctx.drift_factor();
    COEFF_SIGN_BASE_COST * drift
}
/// Cost of flipping a coefficient's suffix LSB.
///
/// The cost grows quadratically with magnitude:
/// `COEFF_SUFFIX_LSB_PER_MAG2 * |coeff|² * drift`.
#[inline]
pub fn coeff_suffix_lsb_cost(coeff: i32, ctx: &PositionCostCtx) -> f32 {
    let magnitude = coeff.unsigned_abs() as f32;
    // Keep the left-to-right association ((k * m) * m) * drift.
    let weighted = COEFF_SUFFIX_LSB_PER_MAG2 * magnitude * magnitude;
    weighted * ctx.drift_factor()
}
/// Cost of flipping the sign of one MVD component.
///
/// The slot's value is currently unused. The cost is the flat
/// [`MVD_SIGN_BASE_COST`] scaled by the position's combined drift factor.
#[inline]
pub fn mvd_sign_cost(_slot: &MvdSlot, ctx: &PositionCostCtx) -> f32 {
    let drift = ctx.drift_factor();
    MVD_SIGN_BASE_COST * drift
}
/// Cost of flipping an MVD component's suffix LSB.
///
/// The cost grows quadratically with magnitude:
/// `MVD_SUFFIX_LSB_PER_MAG2 * |value|² * drift`.
#[inline]
pub fn mvd_suffix_lsb_cost(slot: &MvdSlot, ctx: &PositionCostCtx) -> f32 {
    let magnitude = slot.value.unsigned_abs() as f32;
    // Keep the left-to-right association ((k * m) * m) * drift.
    let weighted = MVD_SUFFIX_LSB_PER_MAG2 * magnitude * magnitude;
    weighted * ctx.drift_factor()
}
/// Sign-flip costs for the nonzero coefficients in
/// `scan_coeffs[start_idx..=end_idx]`, ordered from the highest scan index
/// down to `start_idx`.
///
/// Zero coefficients have no sign and contribute no entry. The descending
/// order presumably mirrors the order in which the embedder enumerates sign
/// positions (H.264 CABAC codes levels in reverse scan order). Confirm against
/// the enumerator before changing it. An empty vector is returned when
/// `start_idx > end_idx`.
///
/// # Panics
///
/// Panics if `start_idx <= end_idx` and `end_idx >= scan_coeffs.len()`.
pub fn coeff_sign_cost_vec(
    scan_coeffs: &[i32],
    start_idx: usize,
    end_idx: usize,
    ctx: &PositionCostCtx,
) -> Vec<f32> {
    // Walk the range backwards directly. This avoids collecting ascending
    // indices into a temporary Vec only to reverse it.
    (start_idx..=end_idx)
        .rev()
        .map(|i| scan_coeffs[i])
        .filter(|&coeff| coeff != 0)
        .map(|coeff| coeff_sign_cost(coeff, ctx))
        .collect()
}
/// Suffix-LSB flip costs for coefficients in `scan_coeffs[start_idx..=end_idx]`
/// with `|coeff| >= 16`, ordered from the highest scan index down to `start_idx`.
///
/// The threshold presumably matches the first magnitude whose H.264 CABAC
/// `coeff_abs_level_minus1` UEG0 suffix carries an information bit
/// (`|coeff| - 15 >= 1`). It must agree with the position enumerator, so
/// confirm there before changing it. An empty vector is returned when
/// `start_idx > end_idx`.
///
/// # Panics
///
/// Panics if `start_idx <= end_idx` and `end_idx >= scan_coeffs.len()`.
pub fn coeff_suffix_lsb_cost_vec(
    scan_coeffs: &[i32],
    start_idx: usize,
    end_idx: usize,
    ctx: &PositionCostCtx,
) -> Vec<f32> {
    /// Smallest `|coeff|` that has a suffix LSB eligible for embedding.
    const MIN_SUFFIX_MAG: u32 = 16;

    // Walk the range backwards directly instead of collect-then-reverse.
    (start_idx..=end_idx)
        .rev()
        .map(|i| scan_coeffs[i])
        .filter(|coeff| coeff.unsigned_abs() >= MIN_SUFFIX_MAG)
        .map(|coeff| coeff_suffix_lsb_cost(coeff, ctx))
        .collect()
}
/// Sign-flip costs for every MVD slot with a nonzero value, in slot order.
///
/// A zero MVD has no sign to flip, so it contributes no entry.
pub fn mvd_sign_cost_vec(slots: &[MvdSlot], ctx: &PositionCostCtx) -> Vec<f32> {
    let mut costs = Vec::new();
    for slot in slots {
        if slot.value == 0 {
            continue;
        }
        costs.push(mvd_sign_cost(slot, ctx));
    }
    costs
}
/// Suffix-LSB flip costs for MVD slots with `|value| >= 9`, in slot order.
///
/// The threshold presumably matches the start of the H.264 CABAC UEG3 suffix
/// for MVDs (`uCoff = 9`). It must agree with the position enumerator.
pub fn mvd_suffix_lsb_cost_vec(slots: &[MvdSlot], ctx: &PositionCostCtx) -> Vec<f32> {
    slots
        .iter()
        .filter_map(|slot| {
            (slot.value.unsigned_abs() >= 9).then(|| mvd_suffix_lsb_cost(slot, ctx))
        })
        .collect()
}
/// Position-independent cost constant for an embedding domain.
///
/// For the sign-bypass domains this is the flat per-position cost. For the
/// suffix-LSB domains it is only the per-magnitude² multiplier. The full
/// per-position cost also scales with the value's squared magnitude.
/// Every per-position cost in this module is further scaled by drift.
pub fn domain_base_cost(domain: EmbedDomain) -> f32 {
    match domain {
        EmbedDomain::CoeffSignBypass => COEFF_SIGN_BASE_COST,
        EmbedDomain::MvdSignBypass => MVD_SIGN_BASE_COST,
        EmbedDomain::CoeffSuffixLsb => COEFF_SUFFIX_LSB_PER_MAG2,
        EmbedDomain::MvdSuffixLsb => MVD_SUFFIX_LSB_PER_MAG2,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codec::h264::stego::Axis;

    /// Neutral-drift context at position (0, 0).
    fn ctx() -> PositionCostCtx {
        PositionCostCtx::new(0, 0)
    }

    #[test]
    fn coeff_sign_cost_independent_of_magnitude() {
        let c = ctx();
        let small = coeff_sign_cost(1, &c);
        let medium = coeff_sign_cost(50, &c);
        let large = coeff_sign_cost(1000, &c);
        assert_eq!(small, medium);
        assert_eq!(medium, large);
        assert!(small > 0.0);
    }

    #[test]
    fn coeff_suffix_lsb_cost_grows_with_magnitude_squared() {
        let c = ctx();
        let c16 = coeff_suffix_lsb_cost(16, &c);
        let c32 = coeff_suffix_lsb_cost(32, &c);
        let c64 = coeff_suffix_lsb_cost(64, &c);
        assert!((c32 - 4.0 * c16).abs() < 0.01);
        assert!((c64 - 16.0 * c16).abs() < 0.1);
    }

    #[test]
    fn coeff_suffix_lsb_cost_higher_than_sign_for_large_coeffs() {
        let c = ctx();
        let sign = coeff_sign_cost(20, &c);
        let suffix = coeff_suffix_lsb_cost(20, &c);
        assert!(suffix > sign, "suffix LSB cost must exceed sign cost for large coeffs");
        assert!(suffix / sign > 100.0);
    }

    #[test]
    fn drift_factor_multiplies_costs() {
        let mut c = ctx();
        let baseline = coeff_sign_cost(5, &c);
        c.intra_drift_factor = 2.5;
        let drifted = coeff_sign_cost(5, &c);
        assert!((drifted - 2.5 * baseline).abs() < 0.01);
    }

    #[test]
    fn drift_factor_composes_intra_and_inter() {
        let mut c = ctx();
        c.intra_drift_factor = 2.0;
        c.inter_drift_factor = 1.5;
        assert!((c.drift_factor() - 3.0).abs() < 0.01);
    }

    #[test]
    fn drift_factor_scales_mvd_costs() {
        let slot = MvdSlot { list: 0, partition: 0, axis: Axis::X, value: 12 };
        let mut c = ctx();
        let base_sign = mvd_sign_cost(&slot, &c);
        let base_suffix = mvd_suffix_lsb_cost(&slot, &c);
        c.inter_drift_factor = 3.0;
        assert!((mvd_sign_cost(&slot, &c) - 3.0 * base_sign).abs() < 0.01);
        assert!((mvd_suffix_lsb_cost(&slot, &c) - 3.0 * base_suffix).abs() < 0.1);
    }

    #[test]
    fn cost_vec_matches_enumerate_order() {
        let mut scan = vec![0i32; 16];
        scan[0] = 5;
        scan[3] = -8;
        scan[7] = 12;
        let costs = coeff_sign_cost_vec(&scan, 0, 15, &ctx());
        assert_eq!(costs.len(), 3);
        for &c in &costs {
            assert_eq!(c, COEFF_SIGN_BASE_COST);
        }
    }

    #[test]
    fn coeff_cost_vecs_honour_subrange() {
        let mut scan = vec![0i32; 16];
        scan[0] = 20;
        scan[3] = -16;
        scan[7] = 32;
        // Only index 3 lies inside 1..=6.
        assert_eq!(coeff_sign_cost_vec(&scan, 1, 6, &ctx()).len(), 1);
        let suffix = coeff_suffix_lsb_cost_vec(&scan, 1, 6, &ctx());
        assert_eq!(suffix.len(), 1);
        assert!((suffix[0] - 512.0).abs() < 0.1);
    }

    #[test]
    fn coeff_cost_vecs_empty_for_empty_range() {
        let scan = vec![20i32; 4];
        assert!(coeff_sign_cost_vec(&scan, 3, 2, &ctx()).is_empty());
        assert!(coeff_suffix_lsb_cost_vec(&scan, 3, 2, &ctx()).is_empty());
    }

    #[test]
    fn coeff_suffix_lsb_cost_vec_filters_threshold() {
        let mut scan = vec![0i32; 16];
        scan[0] = 5;
        scan[3] = 16;
        scan[5] = -15; // just below the threshold: excluded
        scan[7] = -32;
        let costs = coeff_suffix_lsb_cost_vec(&scan, 0, 15, &ctx());
        assert_eq!(costs.len(), 2, "only |coeff|>=16 positions");
        // Descending scan order: index 7 (-32) first, then index 3 (16).
        assert!((costs[0] - 2048.0).abs() < 0.1);
        assert!((costs[1] - 512.0).abs() < 0.1);
    }

    #[test]
    fn mvd_costs_filter_by_threshold() {
        let slots = vec![
            MvdSlot { list: 0, partition: 0, axis: Axis::X, value: 0 },
            MvdSlot { list: 0, partition: 0, axis: Axis::Y, value: 5 },
            MvdSlot { list: 0, partition: 1, axis: Axis::X, value: -10 },
        ];
        let sign_costs = mvd_sign_cost_vec(&slots, &ctx());
        assert_eq!(sign_costs.len(), 2, "value=0 not eligible for sign");
        for &c in &sign_costs {
            assert_eq!(c, MVD_SIGN_BASE_COST);
        }
        let suffix_costs = mvd_suffix_lsb_cost_vec(&slots, &ctx());
        assert_eq!(suffix_costs.len(), 1, "only |mvd|>=9 eligible for suffix");
        assert!((suffix_costs[0] - 200.0).abs() < 0.1);
    }

    #[test]
    fn domain_base_cost_dispatch() {
        assert_eq!(
            domain_base_cost(EmbedDomain::CoeffSignBypass),
            COEFF_SIGN_BASE_COST,
        );
        assert_eq!(
            domain_base_cost(EmbedDomain::CoeffSuffixLsb),
            COEFF_SUFFIX_LSB_PER_MAG2,
        );
        assert_eq!(
            domain_base_cost(EmbedDomain::MvdSignBypass),
            MVD_SIGN_BASE_COST,
        );
        assert_eq!(
            domain_base_cost(EmbedDomain::MvdSuffixLsb),
            MVD_SUFFIX_LSB_PER_MAG2,
        );
    }
}