use crate::codec::h264::cavlc::{EmbedDomain, EmbeddablePosition};
use crate::codec::h264::macroblock::BLOCK_INDEX_TO_POS;
use crate::codec::h264::tables::ZIGZAG_4X4;
/// Lookup table of `2^(q_bits - 4)` for `q_bits` in `0..=8`, indexed by `q_bits`.
///
/// `pixel_scale` derives `q_bits` from the quantization parameter (six QP steps
/// per entry) and multiplies the selected entry into the per-coefficient scale.
const POW2_QBITS_MINUS_4: [f64; 9] = [
0.0625, 0.125, 0.25, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0, ];
#[cfg(feature = "parallel")]
use rayon::prelude::*;
/// Stabilizing constant (`2^-6`) added to every cover wavelet magnitude in the
/// cost denominator, so the division in `compute_position_cost` never hits zero.
const SIGMA: f64 = 0.015625;
/// Taps of the 16-tap high-pass decomposition filter used for the wavelet
/// subbands. The matching low-pass filter is derived from it by `lpdf()`.
///
/// NOTE(review): the values appear to be the Daubechies-8 (db8) high-pass
/// decomposition coefficients rounded to 10 decimals — confirm against the
/// reference filter bank before changing any tap.
const HPDF: [f64; 16] = [
-0.0544158422,
0.3128715909,
-0.6756307363,
0.5853546837,
0.0158291053,
-0.2840155430,
-0.0004724846,
0.1287474266,
0.0173693010,
-0.0440882539,
-0.0139810279,
0.0087460940,
0.0048703530,
-0.0003917404,
-0.0006754494,
-0.0001174768,
];
/// Builds the low-pass decomposition filter that pairs with `HPDF`.
///
/// Tap `n` is the high-pass tap read from the opposite end (`HPDF[15 - n]`),
/// kept as-is for even `n` and negated for odd `n`.
fn lpdf() -> [f64; 16] {
    let mut low_pass = [0.0f64; 16];
    for (n, (slot, &tap)) in low_pass.iter_mut().zip(HPDF.iter().rev()).enumerate() {
        *slot = if n % 2 == 0 { tap } else { -tap };
    }
    low_pass
}
/// Number of taps in the wavelet filters (`HPDF` and the filter from `lpdf()`).
const FILT_LEN: usize = 16;
/// Side length of the square window of wavelet coefficients that a change to a
/// single 4x4 block can influence: 4 block samples widened by `FILT_LEN - 1`.
const IMPACT_SIZE: usize = 4 + FILT_LEN - 1;
/// The three directional detail subbands of one image plane, as produced by
/// `compute_three_subbands`.
///
/// All three bands share the same `width x height` row-major layout. Band
/// sample `(bx, by)` is stored at index `by * width + bx` and describes image
/// position `(bx + x_offset, by + y_offset)`.
#[derive(Debug, Clone, PartialEq)]
pub struct ThreeSubbands {
    /// Low-pass horizontally, high-pass vertically.
    pub lh: Vec<f32>,
    /// High-pass horizontally, low-pass vertically.
    pub hl: Vec<f32>,
    /// High-pass in both directions.
    pub hh: Vec<f32>,
    /// Number of samples per band row (including padding).
    pub width: usize,
    /// Number of band rows (including padding).
    pub height: usize,
    /// Image x coordinate of band column 0.
    pub x_offset: isize,
    /// Image y coordinate of band row 0.
    pub y_offset: isize,
}
/// Computes the three directional detail subbands (LH, HL, HH) of an 8-bit plane.
///
/// The plane is filtered separably and without decimation, first along rows and
/// then along columns, using the high-pass taps in `HPDF` and the low-pass taps
/// from `lpdf()`. Samples outside the plane are obtained by symmetric
/// reflection (`symmetric_reflect`).
///
/// Every returned band is `(width + 2 * pad) x (height + 2 * pad)` samples with
/// `pad = FILT_LEN - 1`, and `x_offset == y_offset == -pad`: band sample
/// `(bx, by)` describes image position `(bx - pad, by - pad)` and is the filter
/// response over the `FILT_LEN` image samples starting at that position along
/// each axis.
///
/// * `lh` – low-pass along rows, high-pass along columns.
/// * `hl` – high-pass along rows, low-pass along columns.
/// * `hh` – high-pass along both.
///
/// If `width` or `height` is zero there is nothing to filter, and all-zero
/// bands of the padded size are returned.
///
/// # Panics
///
/// Panics if `y_plane` holds fewer than `width * height` samples.
pub fn compute_three_subbands(y_plane: &[u8], width: usize, height: usize) -> ThreeSubbands {
    let pad = FILT_LEN - 1;
    let padded_w = width + 2 * pad;
    let padded_h = height + 2 * pad;
    let band_len = padded_w * padded_h;

    let mut lh = vec![0.0f32; band_len];
    let mut hl = vec![0.0f32; band_len];
    let mut hh = vec![0.0f32; band_len];

    // An empty plane has no sample to reflect from; leave the bands at zero
    // instead of indexing into empty buffers.
    if width > 0 && height > 0 {
        assert!(
            y_plane.len() >= width * height,
            "plane holds {} samples but {}x{} requires {}",
            y_plane.len(),
            width,
            height,
            width * height
        );

        let lp = lpdf();

        // Horizontal pass: one low-pass and one high-pass response per padded
        // column, for every original row. Results are stored as f32.
        let mut row_low = vec![0.0f32; padded_w * height];
        let mut row_high = vec![0.0f32; padded_w * height];
        for y in 0..height {
            let src_row = &y_plane[y * width..(y + 1) * width];
            for out_x in 0..padded_w {
                let mut sum_low = 0.0f64;
                let mut sum_high = 0.0f64;
                for k in 0..FILT_LEN {
                    let src_x = out_x as isize - pad as isize + k as isize;
                    let clamped = symmetric_reflect(src_x, width as isize) as usize;
                    let v = f64::from(src_row[clamped]);
                    sum_low += lp[k] * v;
                    sum_high += HPDF[k] * v;
                }
                row_low[y * padded_w + out_x] = sum_low as f32;
                row_high[y * padded_w + out_x] = sum_high as f32;
            }
        }

        // Vertical pass: combine the row responses into the three subbands.
        for out_y in 0..padded_h {
            for x in 0..padded_w {
                let mut sum_lh = 0.0f64;
                let mut sum_hl = 0.0f64;
                let mut sum_hh = 0.0f64;
                for k in 0..FILT_LEN {
                    let src_y = out_y as isize - pad as isize + k as isize;
                    let clamped = symmetric_reflect(src_y, height as isize) as usize;
                    let low_val = f64::from(row_low[clamped * padded_w + x]);
                    let high_val = f64::from(row_high[clamped * padded_w + x]);
                    sum_lh += HPDF[k] * low_val;
                    sum_hl += lp[k] * high_val;
                    sum_hh += HPDF[k] * high_val;
                }
                let idx = out_y * padded_w + x;
                lh[idx] = sum_lh as f32;
                hl[idx] = sum_hl as f32;
                hh[idx] = sum_hh as f32;
            }
        }
    }

    ThreeSubbands {
        lh,
        hl,
        hh,
        width: padded_w,
        height: padded_h,
        x_offset: -(pad as isize),
        y_offset: -(pad as isize),
    }
}
/// Maps an arbitrary index onto `0..len` by half-sample symmetric reflection.
///
/// The sequence is mirrored at both ends with the edge sample repeated
/// (`... 1 0 | 0 1 .. len-1 | len-1 len-2 ...`), which makes it periodic with
/// period `2 * len`. Returns `0` when `len` is not positive.
#[inline]
fn symmetric_reflect(i: isize, len: isize) -> isize {
    if len <= 0 {
        return 0;
    }
    let period = 2 * len;
    let phase = i.rem_euclid(period);
    if phase < len {
        // Rising half of the period: the index itself.
        phase
    } else {
        // Falling half of the period: mirrored back into range.
        period - 1 - phase
    }
}
/// Tabulates the pixel-domain response of each of the 16 coefficient positions
/// of a 4x4 block.
///
/// For every position `(u, v)` a block holding the single value 64 at `(u, v)`
/// is pushed through a two-pass integer butterfly (down the columns first, then
/// along the rows), and the result is divided by 64. `out[u][v][i][j]` is the
/// resulting value at pixel row `i`, column `j`.
fn precompute_unit_basis() -> [[[[f64; 4]; 4]; 4]; 4] {
    /// One 4-point integer butterfly stage.
    fn butterfly(x: [i32; 4]) -> [i32; 4] {
        let even_sum = x[0] + x[2];
        let even_diff = x[0] - x[2];
        let odd_low = (x[1] >> 1) - x[3];
        let odd_high = x[1] + (x[3] >> 1);
        [
            even_sum + odd_high,
            even_diff + odd_low,
            even_diff - odd_low,
            even_sum - odd_high,
        ]
    }

    let mut table = [[[[0.0f64; 4]; 4]; 4]; 4];
    for (u, table_row) in table.iter_mut().enumerate() {
        for (v, response) in table_row.iter_mut().enumerate() {
            // 64 rather than 1 so the halving shifts inside the butterfly
            // operate on integers that are still divisible.
            let mut coeffs = [[0i32; 4]; 4];
            coeffs[u][v] = 64;

            // First pass: transform each column.
            let mut after_columns = [[0i32; 4]; 4];
            for col in 0..4 {
                let column = [coeffs[0][col], coeffs[1][col], coeffs[2][col], coeffs[3][col]];
                for (row, value) in butterfly(column).iter().enumerate() {
                    after_columns[row][col] = *value;
                }
            }

            // Second pass: transform each row and normalise.
            for (row, out_row) in response.iter_mut().enumerate() {
                let transformed = butterfly(after_columns[row]);
                for (cell, value) in out_row.iter_mut().zip(transformed.iter()) {
                    *cell = f64::from(*value) / 64.0;
                }
            }
        }
    }
    table
}
/// Per-coefficient scale factors used by `pixel_scale`.
///
/// Rows are selected by `qp mod 6`; columns by `norm_adjust_class(u, v)`:
/// column 0 for positions with both indices even, column 1 for both odd,
/// column 2 for mixed parity.
///
/// NOTE(review): the values look like the H.264 4x4 dequantization
/// `normAdjust4x4` table — confirm against the transform module.
const NORM_ADJUST_4X4: [[i32; 3]; 6] = [
[10, 16, 13],
[11, 18, 14],
[13, 20, 16],
[14, 23, 18],
[16, 25, 20],
[18, 29, 23],
];
/// Selects the `NORM_ADJUST_4X4` column for coefficient position `(u, v)`.
///
/// Returns `0` when both indices are even, `1` when both are odd, and `2`
/// when their parities differ.
#[inline]
const fn norm_adjust_class(u: usize, v: usize) -> usize {
    match (u % 2, v % 2) {
        (0, 0) => 0,
        (1, 1) => 1,
        _ => 2,
    }
}
/// Returns the factor that converts a unit change of the quantized coefficient
/// at position `(u, v)` into the scale used for the pixel-domain basis.
///
/// The result is `NORM_ADJUST_4X4[qp mod 6][class(u, v)] * 2^(floor(qp / 6) - 4)`.
///
/// Quotient and remainder both use Euclidean division, so they stay consistent
/// for negative `qp`. Values of `qp` outside the range covered by
/// `POW2_QBITS_MINUS_4` (`0..=53`) are handled by computing the power of two
/// directly instead of indexing past the table.
#[inline]
fn pixel_scale(qp: i32, u: usize, v: usize) -> f64 {
    let q_mod = qp.rem_euclid(6) as usize;
    let q_bits = qp.div_euclid(6);
    let norm = f64::from(NORM_ADJUST_4X4[q_mod][norm_adjust_class(u, v)]);

    // Fast path: table lookup for the usual QP range.
    let table_hit = if q_bits >= 0 {
        POW2_QBITS_MINUS_4.get(q_bits as usize).copied()
    } else {
        None
    };
    // Fallback: same formula, evaluated directly (exact for powers of two).
    let pow2 = table_hit.unwrap_or_else(|| 2.0f64.powi(q_bits - 4));

    norm * pow2
}
/// Computes the wavelet-domain cost of changing one coefficient of a 4x4 block.
///
/// The change is modelled as `delta_magnitude` times the scaled pixel-domain
/// basis of the coefficient at zig-zag index `scan_pos`. That 4x4 pixel change
/// is run through the same separable filters as the cover plane, and for every
/// affected position inside the image the cost accumulates
/// `|change| / (|cover coefficient| + SIGMA)` over the LH, HL and HH bands.
///
/// * `unit_basis` – table from `precompute_unit_basis`.
/// * `wavelets` – cover subbands of the plane that contains the block.
/// * `img_w`, `img_h` – plane size; affected positions outside it are ignored.
/// * `block_px_x`, `block_px_y` – top-left pixel of the 4x4 block in the plane.
/// * `scan_pos` – zig-zag index of the coefficient (looked up in `ZIGZAG_4X4`).
/// * `qp` – quantization parameter forwarded to `pixel_scale`.
/// * `delta_magnitude` – size of the coefficient change in quantized units.
fn compute_position_cost(
    unit_basis: &[[[[f64; 4]; 4]; 4]; 4],
    wavelets: &ThreeSubbands,
    img_w: usize,
    img_h: usize,
    block_px_x: usize,
    block_px_y: usize,
    scan_pos: u8,
    qp: i32,
    delta_magnitude: f64,
) -> f64 {
    let raster_idx = ZIGZAG_4X4[scan_pos as usize] as usize;
    let (freq_row, freq_col) = (raster_idx / 4, raster_idx % 4);
    let gain = pixel_scale(qp, freq_row, freq_col) * delta_magnitude;

    // Pixel-domain change caused by the coefficient modification.
    let unit = &unit_basis[freq_row][freq_col];
    let mut pixel_delta = [[0.0f64; 4]; 4];
    for (dst_row, src_row) in pixel_delta.iter_mut().zip(unit.iter()) {
        for (dst, &src) in dst_row.iter_mut().zip(src_row.iter()) {
            *dst = src * gain;
        }
    }

    let low_taps = lpdf();
    let reach = (FILT_LEN - 1) as isize;

    // Horizontal pass over the four block rows; samples outside the block
    // contribute nothing because the change is zero there.
    let mut horiz_low = [[0.0f64; IMPACT_SIZE]; 4];
    let mut horiz_high = [[0.0f64; IMPACT_SIZE]; 4];
    for row in 0..4 {
        for col in 0..IMPACT_SIZE {
            let mut acc_low = 0.0f64;
            let mut acc_high = 0.0f64;
            for tap in 0..FILT_LEN {
                let offset = col as isize - reach + tap as isize;
                if offset < 0 || offset >= 4 {
                    continue;
                }
                let px = pixel_delta[row][offset as usize];
                acc_low += low_taps[tap] * px;
                acc_high += HPDF[tap] * px;
            }
            horiz_low[row][col] = acc_low;
            horiz_high[row][col] = acc_high;
        }
    }

    // Vertical pass, evaluated only for window positions that lie inside the
    // image, followed by the weighted accumulation.
    let mut total = 0.0f64;
    for win_r in 0..IMPACT_SIZE {
        let abs_y = block_px_y as isize + win_r as isize - reach;
        if abs_y < 0 || abs_y >= img_h as isize {
            continue;
        }
        let band_y = (abs_y - wavelets.y_offset) as usize;

        for win_c in 0..IMPACT_SIZE {
            let abs_x = block_px_x as isize + win_c as isize - reach;
            if abs_x < 0 || abs_x >= img_w as isize {
                continue;
            }

            let mut change_lh = 0.0f64;
            let mut change_hl = 0.0f64;
            let mut change_hh = 0.0f64;
            for tap in 0..FILT_LEN {
                let offset = win_r as isize - reach + tap as isize;
                if offset < 0 || offset >= 4 {
                    continue;
                }
                let low_val = horiz_low[offset as usize][win_c];
                let high_val = horiz_high[offset as usize][win_c];
                change_lh += HPDF[tap] * low_val;
                change_hl += low_taps[tap] * high_val;
                change_hh += HPDF[tap] * high_val;
            }

            let band_x = (abs_x - wavelets.x_offset) as usize;
            let at = band_y * wavelets.width + band_x;
            total += change_lh.abs() / (f64::from(wavelets.lh[at].abs()) + SIGMA);
            total += change_hl.abs() / (f64::from(wavelets.hl[at].abs()) + SIGMA);
            total += change_hh.abs() / (f64::from(wavelets.hh[at].abs()) + SIGMA);
        }
    }
    total
}
/// Borrowed view of the three 8-bit planes of one frame.
///
/// `width` and `height` describe the `y` plane. `compute_frame_uniward_costs`
/// treats `cb` and `cr` as `width / 2` by `height / 2` samples.
#[derive(Debug, Clone, Copy)]
pub struct FramePlanes<'a> {
    /// Luma samples, row-major.
    pub y: &'a [u8],
    /// Cb chroma samples, row-major.
    pub cb: &'a [u8],
    /// Cr chroma samples, row-major.
    pub cr: &'a [u8],
    /// Luma plane width in samples.
    pub width: usize,
    /// Luma plane height in samples.
    pub height: usize,
}
/// Computes a UNIWARD-style embedding cost for every candidate position of a frame.
///
/// Returns one cost per entry of `frame_positions`, in the same order. With the
/// `parallel` feature the positions are evaluated through rayon.
///
/// A cost of `f32::INFINITY` marks a position that is not priced:
///
/// * positions in the `EmbedDomain::MvdLsb` domain,
/// * positions with `scan_pos == 0`,
/// * block indices 16 and 17, and any index of 26 or more,
/// * positions whose computed cost is not finite or not strictly positive.
///
/// Block index layout (`within_mb_block_idx`):
///
/// * `0..16` – luma 4x4 blocks, placed via `BLOCK_INDEX_TO_POS`; the QP comes
///   from `qps[mb_idx]`, falling back to 26 when the entry is missing.
/// * `18..22` – Cb 4x4 blocks in a 2x2 raster, using `qp_cb`.
/// * `22..26` – Cr 4x4 blocks in a 2x2 raster, using `qp_cr`.
///
/// Chroma planes are treated as `width / 2` by `height / 2` samples with 8x8
/// samples per macroblock.
///
/// The number of macroblocks per row is `width / 16` rounded up and never zero,
/// so planes narrower than 16 samples no longer cause a division by zero.
/// NOTE(review): this assumes `mb_idx` counts macroblocks in raster order over
/// the coded picture; for widths that are a multiple of 16 the value is
/// unchanged.
pub fn compute_frame_uniward_costs(
    planes: &FramePlanes,
    frame_positions: &[FramePosition],
    qps: &[i32],
) -> Vec<f32> {
    // Nothing to price: skip the three full-plane wavelet transforms.
    if frame_positions.is_empty() {
        return Vec::new();
    }

    let width = planes.width;
    let height = planes.height;
    let chroma_w = width / 2;
    let chroma_h = height / 2;
    let y_wavelets = compute_three_subbands(planes.y, width, height);
    let cb_wavelets = compute_three_subbands(planes.cb, chroma_w, chroma_h);
    let cr_wavelets = compute_three_subbands(planes.cr, chroma_w, chroma_h);
    let unit_basis = precompute_unit_basis();
    let width_in_mbs = ((width + 15) / 16).max(1);

    let compute_one = |fp: &FramePosition| -> f32 {
        let pos = fp.pos;
        if pos.domain == EmbedDomain::MvdLsb {
            return f32::INFINITY;
        }
        if pos.scan_pos == 0 {
            return f32::INFINITY;
        }

        let within_mb = fp.within_mb_block_idx;
        let mb_x = fp.mb_idx % width_in_mbs;
        let mb_y = fp.mb_idx / width_in_mbs;
        // Top-left sample of chroma 4x4 block `slot` (2x2 raster) in its plane.
        let chroma_origin =
            |slot: usize| (mb_x * 8 + (slot % 2) * 4, mb_y * 8 + (slot / 2) * 4);

        let (wavelets, img_w, img_h, block_px_x, block_px_y, qp) = match within_mb {
            0..=15 => {
                let (bx, by) = BLOCK_INDEX_TO_POS[within_mb];
                (
                    &y_wavelets,
                    width,
                    height,
                    mb_x * 16 + bx as usize * 4,
                    mb_y * 16 + by as usize * 4,
                    qps.get(fp.mb_idx).copied().unwrap_or(26),
                )
            }
            16 | 17 => return f32::INFINITY,
            18..=21 => {
                let (px, py) = chroma_origin(within_mb - 18);
                (&cb_wavelets, chroma_w, chroma_h, px, py, fp.qp_cb)
            }
            22..=25 => {
                let (px, py) = chroma_origin(within_mb - 22);
                (&cr_wavelets, chroma_w, chroma_h, px, py, fp.qp_cr)
            }
            _ => return f32::INFINITY,
        };

        // Size of the coefficient change, in quantized units, per domain.
        let delta = match pos.domain {
            EmbedDomain::T1Sign => 2.0,
            EmbedDomain::LevelSuffixMag => 1.0,
            EmbedDomain::LevelSuffixSign => 2.0 * pos.coeff_value.unsigned_abs() as f64,
            EmbedDomain::MvdLsb => unreachable!("MvdLsb handled above"),
        };

        let cost = compute_position_cost(
            &unit_basis,
            wavelets,
            img_w,
            img_h,
            block_px_x,
            block_px_y,
            pos.scan_pos,
            qp,
            delta,
        );
        if cost.is_finite() && cost > 0.0 {
            cost as f32
        } else {
            f32::INFINITY
        }
    };

    #[cfg(feature = "parallel")]
    {
        frame_positions.par_iter().map(compute_one).collect()
    }
    #[cfg(not(feature = "parallel"))]
    {
        frame_positions.iter().map(compute_one).collect()
    }
}
/// One candidate embedding position together with the frame-level context
/// that `compute_frame_uniward_costs` needs to price it.
pub struct FramePosition<'a> {
// The embeddable position itself; its `domain`, `scan_pos` and `coeff_value`
// are read when computing the cost.
pub pos: &'a EmbeddablePosition,
// Index of the containing macroblock; split into column and row using the
// number of macroblocks per frame row, and used to look up the luma QP.
pub mb_idx: usize,
// Block index inside the macroblock: 0..16 are priced as luma 4x4 blocks,
// 18..22 as Cb, 22..26 as Cr; every other value yields an infinite cost.
pub within_mb_block_idx: usize,
// Quantization parameter used for Cb blocks.
pub qp_cb: i32,
// Quantization parameter used for Cr blocks.
pub qp_cr: i32,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A constant plane has no detail, so every subband must vanish in the
    /// interior of the padded band.
    #[test]
    fn wavelets_of_constant_plane_are_zero_away_from_border() {
        let w = 32;
        let h = 32;
        let img = vec![128u8; w * h];
        let bands = compute_three_subbands(&img, w, h);
        for y in (FILT_LEN + 5)..(bands.height - FILT_LEN - 5) {
            for x in (FILT_LEN + 5)..(bands.width - FILT_LEN - 5) {
                let idx = y * bands.width + x;
                assert!(
                    bands.lh[idx].abs() < 1e-3,
                    "center LH should be zero, got {}",
                    bands.lh[idx]
                );
                assert!(bands.hl[idx].abs() < 1e-3);
                assert!(bands.hh[idx].abs() < 1e-3);
            }
        }
    }

    /// Reflection repeats the edge sample and folds indices back into range.
    #[test]
    fn symmetric_reflect_mirrors_at_both_edges() {
        assert_eq!(symmetric_reflect(-1, 4), 0);
        assert_eq!(symmetric_reflect(-4, 4), 3);
        assert_eq!(symmetric_reflect(4, 4), 3);
        assert_eq!(symmetric_reflect(7, 4), 0);
        assert_eq!(symmetric_reflect(2, 4), 2);
        assert_eq!(symmetric_reflect(5, 0), 0);
    }

    /// The (0, 0) coefficient must map to a flat, positive pixel pattern.
    #[test]
    fn unit_basis_dc_position_is_flat() {
        let basis = precompute_unit_basis();
        let dc = &basis[0][0];
        let first = dc[0][0];
        assert!(first > 0.0, "DC basis should be positive");
        for i in 0..4 {
            for j in 0..4 {
                assert!(
                    (dc[i][j] - first).abs() < 1e-6,
                    "DC unit basis should be uniform, got {} vs {}",
                    dc[i][j],
                    first
                );
            }
        }
    }

    /// qp = 30 gives remainder 0 and quotient 5: 10 * 2^(5 - 4) = 20.
    #[test]
    fn pixel_scale_matches_transform_module_contract() {
        let s = pixel_scale(30, 0, 0);
        assert!((s - 20.0).abs() < 1e-9, "pixel_scale(30, 0, 0) = {s}, expected 20.0");
    }

    /// On a flat cover every wavelet magnitude is ~0, so only `SIGMA` keeps the
    /// denominator away from zero: the cost must be large but still finite.
    #[test]
    fn flat_image_gives_high_finite_cost_for_mag_lsb() {
        let w = 32;
        let h = 32;
        let img = vec![128u8; w * h];
        let wavelets = compute_three_subbands(&img, w, h);
        let unit_basis = precompute_unit_basis();
        let cost = compute_position_cost(
            &unit_basis,
            &wavelets,
            w,
            h,
            16,
            16,
            5,
            26,
            1.0,
        );
        assert!(cost.is_finite(), "flat image cost must be finite");
        assert!(cost > 1.0, "flat image should give high cost, got {cost}");
    }
}