use crate::decode::idct::inverse_dct_8x8;
use crate::foundation::consts::JPEG_ZIGZAG_ORDER;
use wide::f32x8;
// Per-lane DCT normalization factors: lane 0 (the DC term) is weighted 1.0,
// every AC lane carries an extra sqrt(2).
const ALPHA_SQRT2: [f32; 8] = {
    let s = core::f32::consts::SQRT_2;
    [1.0, s, s, s, s, s, s, s]
};
// Alternating signs (-1)^k for k = 0..8, used in `compute_delta_hf` to
// mirror one block's coefficients when comparing across a boundary.
const SIGN_ALT: [f32; 8] = [1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0];
// Squared frequency index k^2 for k = 0..8; weights the high-frequency
// energy estimate `hf` in `compute_delta_hf`.
const IDX_SQ: [f32; 8] = [0.0, 1.0, 4.0, 9.0, 16.0, 25.0, 36.0, 49.0];
// Boundaries whose weighted energy exceeds this are treated as genuine
// image detail and get the attenuated `*_HF` gradients (see `select_gradient`).
const HF_THRESHOLD: f32 = 400.0;
// Scale applied to the accumulated offsets in `finalize_row`: 1 / (2 * sqrt(2)).
const OFFSET_SCALE: f32 = 1.0 / (2.0 * core::f32::consts::SQRT_2);
// Correction gradients (in units of 1/1024) applied to the first four
// frequency coefficients of the left/top block at a boundary.
const GRAD_LEFT: [f32; 4] = [
    318.0 / 1024.0,
    -285.0 / 1024.0,
    81.0 / 1024.0,
    -32.0 / 1024.0,
];
// Mirror of GRAD_LEFT for the right/bottom block:
// GRAD_RIGHT[k] == GRAD_LEFT[k] * -SIGN_ALT[k] (checked by test_gradient_precomputation).
const GRAD_RIGHT: [f32; 4] = [
    -318.0 / 1024.0,
    -285.0 / 1024.0,
    -81.0 / 1024.0,
    -32.0 / 1024.0,
];
// Attenuated variants for high-frequency boundaries: entry k is the plain
// gradient scaled by 2^-(k+1).
const GRAD_LEFT_HF: [f32; 4] = [
    318.0 / 1024.0 * 0.5,
    -285.0 / 1024.0 * 0.25,
    81.0 / 1024.0 * 0.125,
    -32.0 / 1024.0 * 0.0625,
];
const GRAD_RIGHT_HF: [f32; 4] = [
    -318.0 / 1024.0 * 0.5,
    -285.0 / 1024.0 * 0.25,
    -81.0 / 1024.0 * 0.125,
    -32.0 / 1024.0 * 0.0625,
];
/// Dequantizes, deblocks, and inverse-transforms one image component.
///
/// Returns a row-major plane of `pw * ph` pixel values level-shifted by
/// +128 (see `finalize_row`), where `pw = blocks_wide * 8` and
/// `ph = blocks_high * 8`. `zigzag_coeffs` must hold
/// `blocks_wide * blocks_high * 64` coefficients in zigzag order.
///
/// Block rows are processed with a two-row sliding window so only two rows
/// of dequantized coefficients are resident at once: each new row is
/// smoothed against its horizontal neighbours, then against the row above,
/// after which the upper row is finalized into `plane`.
pub fn process_component(
    zigzag_coeffs: &[i16],
    blocks_wide: usize,
    blocks_high: usize,
    quant_table: &[u16; 64],
) -> alloc::vec::Vec<f32> {
    debug_assert_eq!(zigzag_coeffs.len(), blocks_wide * blocks_high * 64);
    if blocks_wide == 0 || blocks_high == 0 {
        return alloc::vec::Vec::new();
    }
    // Per-row scratch length and output dimensions in pixels.
    let row_len = blocks_wide * 64;
    let pw = blocks_wide * 8;
    let ph = blocks_high * 8;
    let quant_f32: [f32; 64] = core::array::from_fn(|i| quant_table[i] as f32);
    // Sliding window: `prev_*` is the row about to be finalized, `curr_*`
    // the row just decoded; `*_offsets` accumulate smoothing corrections.
    let mut prev_blocks = alloc::vec![0.0f32; row_len];
    let mut prev_offsets = alloc::vec![0.0f32; row_len];
    let mut curr_blocks = alloc::vec![0.0f32; row_len];
    let mut curr_offsets = alloc::vec![0.0f32; row_len];
    let mut plane = alloc::vec![0.0f32; pw * ph];
    // Seed the window with block row 0.
    dequantize_row(zigzag_coeffs, 0, blocks_wide, quant_table, &mut prev_blocks);
    correct_h_row(&prev_blocks, &mut prev_offsets, blocks_wide);
    for by in 1..blocks_high {
        dequantize_row(
            zigzag_coeffs,
            by,
            blocks_wide,
            quant_table,
            &mut curr_blocks,
        );
        // After the swap below, `curr_offsets` still holds the already
        // consumed corrections of an earlier row — clear before reuse.
        curr_offsets.fill(0.0);
        correct_h_row(&curr_blocks, &mut curr_offsets, blocks_wide);
        correct_v_between(
            &prev_blocks,
            &mut prev_offsets,
            &curr_blocks,
            &mut curr_offsets,
            blocks_wide,
        );
        // The upper row has now seen all of its neighbours; emit its pixels.
        finalize_row(
            &prev_blocks,
            &prev_offsets,
            &quant_f32,
            by - 1,
            blocks_wide,
            pw,
            &mut plane,
        );
        // Slide the window down one block row.
        core::mem::swap(&mut prev_blocks, &mut curr_blocks);
        core::mem::swap(&mut prev_offsets, &mut curr_offsets);
    }
    // Flush the bottom row (or the only row when blocks_high == 1).
    finalize_row(
        &prev_blocks,
        &prev_offsets,
        &quant_f32,
        blocks_high - 1,
        blocks_wide,
        pw,
        &mut plane,
    );
    plane
}
/// Dequantizes one horizontal row of 8x8 blocks from `zigzag_coeffs` into
/// `row_blocks`, converting zigzag storage order to natural (row-major)
/// frequency order.
///
/// `row_blocks` must hold at least `blocks_wide * 64` values; entry
/// `bx * 64 + nat` receives coefficient `nat` of block `bx`, scaled by
/// `quant_table[nat]`.
#[inline]
fn dequantize_row(
    zigzag_coeffs: &[i16],
    by: usize,
    blocks_wide: usize,
    quant_table: &[u16; 64],
    row_blocks: &mut [f32],
) {
    // Hoist the u16 -> f32 quantizer conversion out of the per-block loop:
    // the previous version re-converted all 64 entries for every block.
    let quant_f32: [f32; 64] = core::array::from_fn(|i| quant_table[i] as f32);
    let row_start = by * blocks_wide * 64;
    for bx in 0..blocks_wide {
        // Fixed-size subslices also let the compiler hoist bounds checks.
        let src_off = row_start + bx * 64;
        let src = &zigzag_coeffs[src_off..src_off + 64];
        let dst = &mut row_blocks[bx * 64..bx * 64 + 64];
        for nat in 0..64 {
            let zi = JPEG_ZIGZAG_ORDER[nat] as usize;
            dst[nat] = src[zi] as f32 * quant_f32[nat];
        }
    }
}
/// Accumulates smoothing offsets for every vertical boundary between
/// horizontally adjacent blocks in one block row.
///
/// For each boundary, each of the first four coefficient rows contributes a
/// mismatch `delta` that is spread over the first four horizontal
/// frequencies of both neighbouring blocks via the gradient tables.
#[inline(never)]
fn correct_h_row(blocks: &[f32], offsets: &mut [f32], blocks_wide: usize) {
    if blocks_wide < 2 {
        return;
    }
    let alpha = f32x8::new(ALPHA_SQRT2);
    let sign = f32x8::new(SIGN_ALT);
    let idx_sq = f32x8::new(IDX_SQ);
    for left in 0..blocks_wide - 1 {
        let li = left * 64;
        let ri = li + 64;
        // Only coefficient rows 0..4 take part in the correction.
        for row in 0..4 {
            let base = row * 8;
            let lg = f32x8::new(blocks[li + base..li + base + 8].try_into().unwrap());
            let rg = f32x8::new(blocks[ri + base..ri + base + 8].try_into().unwrap());
            let (delta, hf) = compute_delta_hf(lg, rg, alpha, sign, idx_sq);
            let (grad_l, grad_r) = select_gradient(hf);
            for (k, (&gl, &gr)) in grad_l.iter().zip(grad_r.iter()).enumerate() {
                offsets[li + base + k] += delta * gl;
                offsets[ri + base + k] += delta * gr;
            }
        }
    }
}
/// Accumulates smoothing offsets across the horizontal boundary between two
/// vertically adjacent block rows (`top` above `bot`).
///
/// The transpose of `correct_h_row`: for each block and each of the first
/// four horizontal frequencies `u`, the eight vertical frequencies form the
/// coefficient vector, and the correction lands on vertical frequencies 0..4.
#[inline(never)]
fn correct_v_between(
    top: &[f32],
    top_off: &mut [f32],
    bot: &[f32],
    bot_off: &mut [f32],
    blocks_wide: usize,
) {
    let alpha = f32x8::new(ALPHA_SQRT2);
    let sign = f32x8::new(SIGN_ALT);
    let idx_sq = f32x8::new(IDX_SQ);
    for bx in 0..blocks_wide {
        let base = bx * 64;
        for u in 0..4 {
            // Gather column `u` (one horizontal frequency, all eight
            // vertical frequencies) from each block.
            let gi = f32x8::new(core::array::from_fn(|v| top[base + v * 8 + u]));
            let gj = f32x8::new(core::array::from_fn(|v| bot[base + v * 8 + u]));
            let (delta, hf) = compute_delta_hf(gi, gj, alpha, sign, idx_sq);
            let (grad_top, grad_bot) = select_gradient(hf);
            for v in 0..4 {
                top_off[base + v * 8 + u] += delta * grad_top[v];
                bot_off[base + v * 8 + u] += delta * grad_bot[v];
            }
        }
    }
}
/// For one pair of coefficient vectors facing each other across a block
/// boundary, returns `(delta, hf)`: the alpha-weighted boundary mismatch
/// and a frequency-weighted energy used to classify the boundary.
#[inline(always)]
fn compute_delta_hf(gi: f32x8, gj: f32x8, alpha: f32x8, sign: f32x8, idx_sq: f32x8) -> (f32, f32) {
    let mismatch = alpha * (gj - sign * gi);
    let energy = idx_sq * (gi * gi + gj * gj);
    (sum_f32x8(mismatch), sum_f32x8(energy))
}
/// Chooses the gradient table pair for a boundary: the attenuated `*_HF`
/// tables when the weighted energy exceeds `HF_THRESHOLD` (likely real
/// detail), the full-strength tables otherwise.
#[inline(always)]
fn select_gradient(hf: f32) -> (&'static [f32; 4], &'static [f32; 4]) {
    match hf > HF_THRESHOLD {
        true => (&GRAD_LEFT_HF, &GRAD_RIGHT_HF),
        false => (&GRAD_LEFT, &GRAD_RIGHT),
    }
}
/// Horizontal sum of all eight lanes. Pairs are summed first, then combined
/// left to right, fixing the float evaluation order deterministically.
#[inline(always)]
fn sum_f32x8(v: f32x8) -> f32 {
    let [a, b, c, d, e, f, g, h]: [f32; 8] = v.into();
    (a + b) + (c + d) + (e + f) + (g + h)
}
/// Applies the accumulated smoothing corrections (clamped to half a
/// quantizer step per coefficient), runs the inverse DCT on each block of
/// one row, level-shifts the pixels by +128, and writes them into `plane`
/// at block row `by`.
#[inline(never)]
fn finalize_row(
    blocks: &[f32],
    offsets: &[f32],
    quant_f32: &[f32; 64],
    by: usize,
    blocks_wide: usize,
    pw: usize,
    plane: &mut [f32],
) {
    let scale = f32x8::splat(OFFSET_SCALE);
    let half = f32x8::splat(0.5);
    let shift = f32x8::splat(128.0);
    let mut corrected = [0.0f32; 64];
    for bx in 0..blocks_wide {
        let base = bx * 64;
        let lanes = blocks[base..base + 64]
            .chunks_exact(8)
            .zip(offsets[base..base + 64].chunks_exact(8))
            .zip(quant_f32.chunks_exact(8));
        for (i, ((coef, corr), quant)) in lanes.enumerate() {
            let mid = f32x8::new(coef.try_into().unwrap());
            let correction = f32x8::new(corr.try_into().unwrap());
            let q = f32x8::new(quant.try_into().unwrap());
            // Keep each corrected coefficient within half a quantizer step
            // of its dequantized value so the output stays consistent with
            // the coded data.
            let half_q = q * half;
            let adjusted = mid + correction * scale;
            let clamped = adjusted.max(mid - half_q).min(mid + half_q);
            let out: [f32; 8] = clamped.into();
            corrected[i * 8..i * 8 + 8].copy_from_slice(&out);
        }
        let pixels = inverse_dct_8x8(&corrected);
        for (row, px) in pixels.chunks_exact(8).enumerate() {
            let lifted: [f32; 8] = (f32x8::new(px.try_into().unwrap()) + shift).into();
            let dst = (by * 8 + row) * pw + bx * 8;
            plane[dst..dst + 8].copy_from_slice(&lifted);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Two identical DC-only blocks side by side: there is no boundary
    // mismatch, so the deblocking pass must not create any pixel difference
    // between them.
    #[test]
    fn test_uniform_blocks_no_correction() {
        let blocks_wide = 2;
        let blocks_high = 1;
        let num_blocks = 2;
        let mut zigzag = vec![0i16; num_blocks * 64];
        // Equal DC values in both blocks; all AC coefficients zero.
        zigzag[0] = 10;
        zigzag[64] = 10;
        let mut quant = [1u16; 64];
        quant[0] = 8;
        let plane = process_component(&zigzag, blocks_wide, blocks_high, &quant);
        assert_eq!(plane.len(), 16 * 8);
        // First pixel of the left block vs first pixel of the right block.
        let p0 = plane[0];
        let p1 = plane[8];
        assert!(
            (p0 - p1).abs() < 0.01,
            "Uniform blocks should produce identical pixels: {p0} vs {p1}"
        );
    }

    // Differing DC values create a step at the shared vertical edge
    // (uncorrected gap of 15.0); the horizontal correction must shrink it.
    #[test]
    fn test_discontinuity_reduced() {
        let blocks_wide = 2;
        let blocks_high = 1;
        let mut zigzag = vec![0i16; 2 * 64];
        zigzag[0] = 5;
        zigzag[64] = 20;
        let quant = [8u16; 64];
        let plane = process_component(&zigzag, blocks_wide, blocks_high, &quant);
        // Last pixel of the left block and first pixel of the right block
        // on the top pixel row.
        let left_edge = plane[7];
        let right_edge = plane[8];
        let gap = (right_edge - left_edge).abs();
        assert!(
            gap < 15.0,
            "Boundary gap should be reduced from 15.0, got {gap}"
        );
    }

    // Same scenario rotated 90 degrees: the step lies across the horizontal
    // boundary between two vertically stacked blocks.
    #[test]
    fn test_vertical_boundary() {
        let blocks_wide = 1;
        let blocks_high = 2;
        let mut zigzag = vec![0i16; 2 * 64];
        zigzag[0] = 5;
        zigzag[64] = 20;
        let quant = [8u16; 64];
        let plane = process_component(&zigzag, blocks_wide, blocks_high, &quant);
        let pw = 8;
        // Bottom pixel row of the top block and top pixel row of the bottom block.
        let top_edge = plane[7 * pw];
        let bot_edge = plane[8 * pw];
        let gap = (bot_edge - top_edge).abs();
        assert!(
            gap < 15.0,
            "Vertical boundary gap should be reduced from 15.0, got {gap}"
        );
    }

    // The gradient tables are hand-written constants; verify they satisfy
    // the relationships the algorithm assumes: mirror symmetry
    // (GRAD_RIGHT[k] == GRAD_LEFT[k] * -SIGN_ALT[k]) and the 2^-(k+1)
    // high-frequency attenuation of the `*_HF` variants.
    #[test]
    fn test_gradient_precomputation() {
        for k in 0..4 {
            let expected = GRAD_LEFT[k] * -SIGN_ALT[k];
            assert!(
                (GRAD_RIGHT[k] - expected).abs() < 1e-7,
                "GRAD_RIGHT[{k}] = {}, expected {expected}",
                GRAD_RIGHT[k]
            );
        }
        for k in 0..4 {
            let scale = 0.5f32.powi(k as i32 + 1);
            assert!(
                (GRAD_LEFT_HF[k] - GRAD_LEFT[k] * scale).abs() < 1e-7,
                "GRAD_LEFT_HF[{k}] mismatch"
            );
            assert!(
                (GRAD_RIGHT_HF[k] - GRAD_RIGHT[k] * scale).abs() < 1e-7,
                "GRAD_RIGHT_HF[{k}] mismatch"
            );
        }
    }
}