use crate::codec::jpeg::dct::{DctGrid, QuantTable};
use crate::codec::jpeg::pixels::idct_block;
use crate::stego::error::StegoError;
use crate::stego::progress;
use super::CostMap;
#[cfg(feature = "parallel")]
use rayon::prelude::*;
/// Total number of `progress::advance` ticks emitted by one J-UNIWARD cost computation.
pub const UNIWARD_PROGRESS_STEPS: u32 = 100;

/// Stabilising constant added to `|cover wavelet coefficient|` in every cost denominator (2^-6).
const SIGMA: f64 = 1.0 / 64.0;

/// 16-tap high-pass decomposition filter used for all wavelet sub-bands.
/// Its taps sum to (numerically) zero.
const HPDF: [f64; 16] = [
    -0.0544158422, 0.3128715909, -0.6756307363, 0.5853546837,
    0.0158291053, -0.2840155430, -0.0004724846, 0.1287474266,
    0.0173693010, -0.0440882539, -0.0139810279, 0.0087460940,
    0.0048703530, -0.0003917404, -0.0006754494, -0.0001174768,
];

/// Low-pass decomposition filter paired with [`HPDF`].
///
/// It is the time-reversed high-pass filter with every odd tap negated:
/// `lp[n] = (-1)^n * HPDF[15 - n]`. Its taps sum to roughly `sqrt(2)`.
fn lpdf() -> [f64; 16] {
    std::array::from_fn(|n| {
        let tap = HPDF[FILT_LEN - 1 - n];
        if n % 2 == 0 {
            tap
        } else {
            -tap
        }
    })
}

/// Number of taps in each wavelet filter.
const FILT_LEN: usize = 16;

/// Side length of the window of wavelet coefficients an 8x8 block can influence
/// (length of a full correlation of 8 samples with a `FILT_LEN`-tap filter).
const IMPACT_SIZE: usize = 8 + FILT_LEN - 1;
/// Raw pointer into a `CostMap`'s cost buffer that can be shared across threads.
///
/// The per-block closures (run through rayon when the `parallel` feature is on)
/// use it to store results without locking. Soundness depends on callers writing
/// only in-bounds indices and never writing the same index from two threads.
struct CostMapPtr {
    /// Start of the cost buffer.
    ptr: *mut f32,
    /// Number of `f32` slots behind `ptr`; only checked in debug builds.
    total_len: usize,
}

// SAFETY: the pointer is only used for writes to indices that each closure
// invocation owns exclusively (one block's 64 slots), while the owning `CostMap`
// outlives the loop that uses it, so sharing it between threads cannot race.
unsafe impl Send for CostMapPtr {}
unsafe impl Sync for CostMapPtr {}

impl CostMapPtr {
    /// Stores `val` at element `idx` of the cost buffer.
    ///
    /// # Safety
    ///
    /// `idx` must be below `total_len`, the buffer behind `ptr` must still be
    /// alive, and no other thread may access element `idx` concurrently.
    unsafe fn write(&self, idx: usize, val: f32) {
        debug_assert!(idx < self.total_len, "CostMapPtr write out of bounds: {idx} >= {}", self.total_len);
        // SAFETY: in-bounds and exclusive access are guaranteed by the caller (see `# Safety`).
        unsafe { self.ptr.add(idx).write(val) };
    }
}
/// Computes the J-UNIWARD cost of every DCT coefficient of `grid` in one pass.
///
/// DC coefficients and zero-valued AC coefficients are never written, so they keep
/// the value `CostMap::new` starts with (the tests expect `WET_COST`). Every other
/// coefficient gets the value of `compute_coefficient_cost`, but only if that value is
/// positive and finite. Blocks are processed in parallel when the `parallel` feature
/// is enabled.
pub fn compute_uniward(grid: &DctGrid, qt: &QuantTable) -> CostMap {
    let (bw, bt) = (grid.blocks_wide(), grid.blocks_tall());
    let mut map = CostMap::new(bw, bt);
    let (img_w, img_h) = (bw * 8, bt * 8);
    let lpdf = lpdf();

    // The decoded pixels are only needed to build the cover sub-bands.
    let cover_wavelets = {
        let cover_pixels = decompress_to_pixels(grid, qt, bw, bt);
        compute_three_subbands(&cover_pixels, img_w, img_h, &lpdf)
    };

    let basis = precompute_basis_functions(qt);
    let pad = FILT_LEN - 1;
    let n_blocks = bt * bw;
    let costs_ptr = CostMapPtr { ptr: map.costs_ptr(), total_len: n_blocks * 64 };

    let compute_block = |bi: usize| {
        let (br, bc) = (bi / bw, bi % bw);
        let blk = grid.block(br, bc);
        // In-block positions 1..64 are the AC coefficients in row-major order (0 is DC).
        for pos in 1..64 {
            // A zero coefficient keeps the map's initial cost.
            if blk[pos] == 0 {
                continue;
            }
            let cost = compute_coefficient_cost(
                &basis[pos / 8][pos % 8],
                &cover_wavelets,
                br,
                bc,
                img_w,
                img_h,
                pad,
                &lpdf,
            );
            if cost > 0.0 && cost.is_finite() {
                // SAFETY: block `bi` is visited exactly once, so slot `bi * 64 + pos` is
                // written by this invocation only, and it is below `n_blocks * 64`
                // (assuming `CostMap::new(bw, bt)` holds that many costs; debug-checked).
                unsafe { costs_ptr.write(bi * 64 + pos, cost as f32) };
            }
        }
    };

    #[cfg(feature = "parallel")]
    (0..n_blocks).into_par_iter().for_each(compute_block);
    #[cfg(not(feature = "parallel"))]
    (0..n_blocks).for_each(compute_block);

    map
}
/// Computes the same cost map as `compute_uniward`, reporting progress and honouring cancellation.
///
/// Emits exactly `UNIWARD_PROGRESS_STEPS` calls to `progress::advance`:
/// - one tenth after the cover has been decoded to pixels,
/// - three tenths after the wavelet decomposition,
/// - the remainder once every per-coefficient cost is done.
///
/// Cancellation is checked after each of the first two phases.
///
/// # Errors
///
/// Returns the error produced by `progress::check_cancelled` when the operation was cancelled.
pub fn compute_uniward_with_progress(grid: &DctGrid, qt: &QuantTable) -> Result<CostMap, StegoError> {
    let advance_by = |steps: u32| {
        for _ in 0..steps {
            progress::advance();
        }
    };
    let phase1_steps = UNIWARD_PROGRESS_STEPS / 10;
    let phase2_steps = (UNIWARD_PROGRESS_STEPS * 3) / 10;
    let phase3_steps = UNIWARD_PROGRESS_STEPS - phase1_steps - phase2_steps;

    let (bw, bt) = (grid.blocks_wide(), grid.blocks_tall());
    let mut map = CostMap::new(bw, bt);
    let (img_w, img_h) = (bw * 8, bt * 8);

    // Phase 1: decode the cover to the spatial domain.
    let cover_pixels = decompress_to_pixels(grid, qt, bw, bt);
    advance_by(phase1_steps);
    progress::check_cancelled()?;

    // Phase 2: wavelet sub-bands of the cover; the pixels are not needed afterwards.
    let lpdf = lpdf();
    let cover_wavelets = compute_three_subbands(&cover_pixels, img_w, img_h, &lpdf);
    drop(cover_pixels);
    advance_by(phase2_steps);
    progress::check_cancelled()?;

    // Phase 3: per-coefficient costs.
    let basis = precompute_basis_functions(qt);
    let pad = FILT_LEN - 1;
    let n_blocks = bt * bw;
    let costs_ptr = CostMapPtr { ptr: map.costs_ptr(), total_len: n_blocks * 64 };

    let compute_block = |bi: usize| {
        let (br, bc) = (bi / bw, bi % bw);
        let blk = grid.block(br, bc);
        // In-block positions 1..64 are the AC coefficients in row-major order (0 is DC).
        for pos in 1..64 {
            // A zero coefficient keeps the map's initial cost.
            if blk[pos] == 0 {
                continue;
            }
            let cost = compute_coefficient_cost(
                &basis[pos / 8][pos % 8],
                &cover_wavelets,
                br,
                bc,
                img_w,
                img_h,
                pad,
                &lpdf,
            );
            if cost > 0.0 && cost.is_finite() {
                // SAFETY: block `bi` is visited exactly once, so slot `bi * 64 + pos` is
                // written by this invocation only, and it is below `n_blocks * 64`
                // (assuming `CostMap::new(bw, bt)` holds that many costs; debug-checked).
                unsafe { costs_ptr.write(bi * 64 + pos, cost as f32) };
            }
        }
    };

    #[cfg(feature = "parallel")]
    (0..n_blocks).into_par_iter().for_each(compute_block);
    #[cfg(not(feature = "parallel"))]
    (0..n_blocks).for_each(compute_block);

    advance_by(phase3_steps);
    Ok(map)
}
fn decompress_to_pixels(grid: &DctGrid, qt: &QuantTable, bw: usize, bt: usize) -> Vec<f32> {
let img_w = bw * 8;
let img_h = bt * 8;
let mut pixels = vec![0.0f32; img_w * img_h];
for br in 0..bt {
for bc in 0..bw {
let block = grid.block(br, bc);
let quantized: [i16; 64] = block.try_into().unwrap();
let block_pixels = idct_block(&quantized, &qt.values);
for row in 0..8 {
for col in 0..8 {
let py = br * 8 + row;
let px = bc * 8 + col;
pixels[py * img_w + px] = block_pixels[row * 8 + col] as f32;
}
}
}
}
pixels
}
/// Precomputes the spatial pattern caused by a unit change of each DCT coefficient.
///
/// `basis[fi][fj][k]` is pixel `k` (row-major within the 8x8 block) of `idct_block`
/// applied to a lone coefficient of value 1 at (`fi`, `fj`) with unit dequantisation.
/// The constant 128.0 is subtracted (presumably `idct_block`'s level shift), and the
/// result is scaled by that coefficient's quantisation step from `qt`.
fn precompute_basis_functions(qt: &QuantTable) -> [[[f64; 64]; 8]; 8] {
    // A step of 1 everywhere, so the real quantisation step is applied explicitly below.
    let unity_qt = [1u16; 64];
    let mut basis = [[[0.0f64; 64]; 8]; 8];
    // Flattening visits (fi, fj) in row-major order, so `pos == fi * 8 + fj`.
    for (pos, pattern) in basis.iter_mut().flatten().enumerate() {
        let mut impulse = [0i16; 64];
        impulse[pos] = 1;
        let spatial = idct_block(&impulse, &unity_qt);
        let step = qt.values[pos] as f64;
        for (k, out) in pattern.iter_mut().enumerate() {
            *out = (spatial[k] - 128.0) * step;
        }
    }
    basis
}
/// LH, HL and HH wavelet sub-bands of the cover image, or of a horizontal strip of it.
///
/// Every band is row-major with `width` samples per row. Stored row 0 corresponds to
/// image row `y_offset`.
struct ThreeSubbands {
    /// Samples per row in each band.
    width: usize,
    /// Absolute image row of the first stored row (0 when the whole image is stored).
    y_offset: usize,
    /// Rows low-pass filtered, then columns high-pass filtered.
    lh: Vec<f32>,
    /// Rows high-pass filtered, then columns low-pass filtered.
    hl: Vec<f32>,
    /// Rows and columns both high-pass filtered.
    hh: Vec<f32>,
}
/// Computes the full-image LH, HL and HH sub-bands of a row-major `width` x `height` image.
///
/// Rows are filtered first (`lpdf` for low-pass, `HPDF` for high-pass), then columns,
/// both with mirrored borders. The result has `y_offset == 0`.
fn compute_three_subbands(
    pixels: &[f32],
    width: usize,
    height: usize,
    lpdf: &[f64; 16],
) -> ThreeSubbands {
    // With rayon: both row passes run concurrently, then all three column passes.
    #[cfg(feature = "parallel")]
    let (lh, hl, hh) = {
        let (row_low, row_high) = rayon::join(
            || filter_rows(pixels, width, height, lpdf),
            || filter_rows(pixels, width, height, &HPDF),
        );
        let (lh, (hl, hh)) = rayon::join(
            || filter_cols(&row_low, width, height, &HPDF),
            || {
                rayon::join(
                    || filter_cols(&row_high, width, height, lpdf),
                    || filter_cols(&row_high, width, height, &HPDF),
                )
            },
        );
        (lh, hl, hh)
    };

    // Sequentially: the low-pass row buffer is released before the high-pass one is built,
    // keeping at most one intermediate buffer alive.
    #[cfg(not(feature = "parallel"))]
    let (lh, hl, hh) = {
        let row_low = filter_rows(pixels, width, height, lpdf);
        let lh = filter_cols(&row_low, width, height, &HPDF);
        drop(row_low);
        let row_high = filter_rows(pixels, width, height, &HPDF);
        let hl = filter_cols(&row_high, width, height, lpdf);
        let hh = filter_cols(&row_high, width, height, &HPDF);
        (lh, hl, hh)
    };

    ThreeSubbands { lh, hl, hh, width, y_offset: 0 }
}
/// Correlates every row of a row-major `width` x `height` image with `filter`.
///
/// Output sample `x` of a row is `Σ_k filter[k] * row[x + k - half]`, where
/// `half = (FILT_LEN - 1) / 2`. Out-of-range columns are reflected by `mirror_index`.
/// Accumulation is done in `f64` in tap order, and the result is stored as `f32`.
fn filter_rows(
    pixels: &[f32],
    width: usize,
    height: usize,
    filter: &[f64; 16],
) -> Vec<f32> {
    let half = ((FILT_LEN - 1) / 2) as isize;
    let mut output = Vec::with_capacity(width * height);
    for y in 0..height {
        let row = &pixels[y * width..(y + 1) * width];
        for x in 0..width {
            let acc = filter.iter().enumerate().fold(0.0f64, |acc, (k, &tap)| {
                let sx = mirror_index(x as isize + k as isize - half, width);
                acc + row[sx] as f64 * tap
            });
            output.push(acc as f32);
        }
    }
    output
}
/// Correlates every column of a row-major `width` x `height` image with `filter`.
///
/// Output row `y` is `Σ_k filter[k] * rows[y + k - half]`, where `half = (FILT_LEN - 1) / 2`.
/// Out-of-range rows are reflected by `mirror_index`. Accumulation is done in `f64`
/// in tap order, and the result is stored as `f32`.
fn filter_cols(
    pixels: &[f32],
    width: usize,
    height: usize,
    filter: &[f64; 16],
) -> Vec<f32> {
    let half = ((FILT_LEN - 1) / 2) as isize;
    let mut output = vec![0.0f32; width * height];
    for y in 0..height {
        // Mirrored source rows feeding output row `y`, one per tap.
        let src_rows: [usize; FILT_LEN] =
            std::array::from_fn(|k| mirror_index(y as isize + k as isize - half, height));
        for x in 0..width {
            let acc = src_rows
                .iter()
                .zip(filter)
                .fold(0.0f64, |acc, (&sy, &tap)| acc + pixels[sy * width + x] as f64 * tap);
            output[y * width + x] = acc as f32;
        }
    }
    output
}
/// Reflects `idx` back into `0..size` without repeating the edge sample.
///
/// Examples: `-1 -> 1` and `size -> size - 2`. Indices that would reflect past the
/// opposite edge are clamped: to `size - 1` when negative, to `0` when too large.
/// The result is therefore always a valid index whenever `size >= 1`.
#[inline]
fn mirror_index(idx: isize, size: usize) -> usize {
    let s = size as isize;
    let folded = match idx {
        i if i < 0 => (-i).min(s - 1),
        i if i >= s => (2 * s - 2 - i).max(0),
        i => i,
    };
    folded as usize
}
fn compute_coefficient_cost(
basis_block: &[f64; 64],
cover_wavelets: &ThreeSubbands,
br: usize,
bc: usize,
img_w: usize,
img_h: usize,
pad: usize,
lpdf: &[f64; 16],
) -> f64 {
let mut cost = 0.0;
let mut row_low = [[0.0f64; IMPACT_SIZE]; 8]; let mut row_high = [[0.0f64; IMPACT_SIZE]; 8];
for r in 0..8 {
for out_c in 0..IMPACT_SIZE {
let mut sum_low = 0.0;
let mut sum_high = 0.0;
for k in 0..FILT_LEN {
let src_col = out_c as isize - 14 + k as isize;
if (0..8).contains(&src_col) {
let val = basis_block[r * 8 + src_col as usize];
sum_low += lpdf[k] * val;
sum_high += HPDF[k] * val;
}
}
row_low[r][out_c] = sum_low;
row_high[r][out_c] = sum_high;
}
}
for out_r in 0..IMPACT_SIZE {
for out_c in 0..IMPACT_SIZE {
let mut delta_lh = 0.0;
let mut delta_hl = 0.0;
let mut delta_hh = 0.0;
for k in 0..FILT_LEN {
let src_row = out_r as isize - 14 + k as isize;
if (0..8).contains(&src_row) {
let r = src_row as usize;
let low_val = row_low[r][out_c];
let high_val = row_high[r][out_c];
delta_lh += HPDF[k] * low_val;
delta_hl += lpdf[k] * high_val;
delta_hh += HPDF[k] * high_val;
}
}
let abs_y = (br * 8) as isize + out_r as isize - (pad as isize);
let abs_x = (bc * 8) as isize + out_c as isize - (pad as isize);
if abs_y >= 0 && abs_y < img_h as isize && abs_x >= 0 && abs_x < img_w as isize {
let wy = abs_y as usize;
let wx = abs_x as usize;
let idx = (wy - cover_wavelets.y_offset) * cover_wavelets.width + wx;
let w_lh = cover_wavelets.lh[idx] as f64;
let w_hl = cover_wavelets.hl[idx] as f64;
let w_hh = cover_wavelets.hh[idx] as f64;
cost += delta_lh.abs() / (w_lh.abs() + SIGMA);
cost += delta_hl.abs() / (w_hl.abs() + SIGMA);
cost += delta_hh.abs() / (w_hh.abs() + SIGMA);
}
}
}
cost
}
use crate::stego::permute::CoeffPos;
use crate::stego::side_info::SideInfo;
const STRIP_BLOCK_ROWS: usize = 50;
const MIN_SI_COST_F32: f32 = 1e-6;
pub fn compute_positions_streaming(
grid: &DctGrid,
qt: &QuantTable,
si: Option<(&SideInfo, &DctGrid)>,
) -> Result<Vec<CoeffPos>, StegoError> {
let bw = grid.blocks_wide();
let bt = grid.blocks_tall();
let img_w = bw * 8;
let img_h = bt * 8;
let lpdf = lpdf();
let basis = precompute_basis_functions(qt);
let pad = FILT_LEN - 1;
let est_positions = bt * bw * 32;
let mut positions: Vec<CoeffPos> = Vec::with_capacity(est_positions);
let num_strips = bt.div_ceil(STRIP_BLOCK_ROWS);
let mut strip_idx = 0usize;
let mut steps_sent = 0u32;
for strip_start in (0..bt).step_by(STRIP_BLOCK_ROWS) {
let strip_end = (strip_start + STRIP_BLOCK_ROWS).min(bt);
let wav_y_start = (strip_start * 8).saturating_sub(pad);
let wav_y_end = (strip_end * 8).min(img_h);
let pix_y_start = wav_y_start.saturating_sub(7);
let pix_y_end = (wav_y_end + 8).min(img_h);
let pix_br_start = pix_y_start / 8;
let pix_br_end = pix_y_end.div_ceil(8).min(bt);
let pix_strip_h = (pix_br_end - pix_br_start) * 8;
let pix_strip_y0 = pix_br_start * 8; let pixels = decompress_strip_pixels(grid, qt, bw, pix_br_start, pix_br_end);
{
let sub_target = (strip_idx as u32 * 3 + 1) * UNIWARD_PROGRESS_STEPS / (num_strips as u32 * 3);
while steps_sent < sub_target {
progress::advance();
steps_sent += 1;
}
}
let strip_wavelets = compute_strip_subbands(
&pixels, img_w, pix_strip_h, pix_strip_y0,
wav_y_start, wav_y_end, img_h, &lpdf,
);
drop(pixels);
{
let sub_target = (strip_idx as u32 * 3 + 2) * UNIWARD_PROGRESS_STEPS / (num_strips as u32 * 3);
while steps_sent < sub_target {
progress::advance();
steps_sent += 1;
}
}
let strip_bt = strip_end - strip_start;
let n_strip_blocks = strip_bt * bw;
let mut strip_map = CostMap::new(bw, strip_bt);
let total_len = n_strip_blocks * 64;
let costs_ptr = CostMapPtr { ptr: strip_map.costs_ptr(), total_len };
let compute_block = |bi: usize| {
let br_local = bi / bw;
let bc = bi % bw;
let br = strip_start + br_local;
let blk = grid.block(br, bc);
for fi in 0..8 {
for fj in 0..8 {
if fi == 0 && fj == 0 { continue; }
let coeff = blk[fi * 8 + fj];
if coeff == 0 { continue; }
let basis_block = &basis[fi][fj];
let cost = compute_coefficient_cost(
basis_block, &strip_wavelets,
br, bc, img_w, img_h, pad, &lpdf,
);
if cost > 0.0 && cost.is_finite() {
let idx = br_local * bw * 64 + bc * 64 + fi * 8 + fj;
unsafe { costs_ptr.write(idx, cost as f32); }
}
}
}
};
#[cfg(feature = "parallel")]
(0..n_strip_blocks).into_par_iter().for_each(compute_block);
#[cfg(not(feature = "parallel"))]
(0..n_strip_blocks).for_each(|bi| compute_block(bi));
for br_local in 0..strip_bt {
let br = strip_start + br_local;
for bc in 0..bw {
for i in 0..8 {
for j in 0..8 {
if i == 0 && j == 0 { continue; }
let cost_f32 = strip_map.get(br_local, bc, i, j);
if !cost_f32.is_finite() { continue; }
let flat_idx = ((br * bw + bc) * 64 + i * 8 + j) as u32;
let final_cost = if let Some((side_info, cover_grid)) = si {
let coeff = cover_grid.get(br, bc, i, j);
if coeff.abs() == 1 {
cost_f32
} else {
let error = side_info.error_at(flat_idx as usize);
let factor = 1.0f32 - 2.0 * error.abs();
(cost_f32 * factor).max(MIN_SI_COST_F32)
}
} else {
cost_f32
};
positions.push(CoeffPos { flat_idx, cost: final_cost });
}
}
}
}
strip_idx += 1;
let target_steps = (strip_idx as u32 * UNIWARD_PROGRESS_STEPS) / num_strips as u32;
while steps_sent < target_steps {
progress::advance();
steps_sent += 1;
}
if strip_idx.is_multiple_of(2) {
progress::check_cancelled()?;
}
}
while steps_sent < UNIWARD_PROGRESS_STEPS {
progress::advance();
steps_sent += 1;
}
Ok(positions)
}
/// Decodes block rows `br_start..br_end` of `grid` into a row-major `f32` pixel strip.
///
/// The strip is `bw * 8` pixels wide and `(br_end - br_start) * 8` rows tall. Its row 0 is
/// pixel row `br_start * 8` of the image. Each block is reconstructed with `idct_block`,
/// passing `qt.values` as the dequantisation table.
fn decompress_strip_pixels(
    grid: &DctGrid,
    qt: &QuantTable,
    bw: usize,
    br_start: usize,
    br_end: usize,
) -> Vec<f32> {
    let stride = bw * 8;
    let rows = (br_end - br_start) * 8;
    let mut strip = vec![0.0f32; stride * rows];
    for (local_br, br) in (br_start..br_end).enumerate() {
        for bc in 0..bw {
            let coeffs: [i16; 64] = grid.block(br, bc).try_into().unwrap();
            let spatial = idct_block(&coeffs, &qt.values);
            for row in 0..8 {
                // First strip pixel of this block row.
                let dst = (local_br * 8 + row) * stride + bc * 8;
                for col in 0..8 {
                    strip[dst + col] = spatial[row * 8 + col] as f32;
                }
            }
        }
    }
    strip
}
/// Computes the LH, HL and HH sub-bands for image rows `[wav_y_start, wav_y_end)` from a pixel strip.
///
/// `pixels` holds `pix_h` full-width rows starting at absolute image row `pix_y0`. It must
/// contain every row the column filter reads for the requested output rows, after
/// mirroring at the edges of the full image height `img_h`. The returned bands have
/// `y_offset == wav_y_start`.
fn compute_strip_subbands(
    pixels: &[f32],
    width: usize,
    pix_h: usize,
    pix_y0: usize,
    wav_y_start: usize,
    wav_y_end: usize,
    img_h: usize,
    lpdf: &[f64; 16],
) -> ThreeSubbands {
    // Column pass over the row-filtered strip with a given filter.
    let cols = |rows: &[f32], filter: &[f64; 16]| {
        filter_cols_strip(rows, width, pix_h, pix_y0, wav_y_start, wav_y_end, img_h, filter)
    };

    // With rayon: both row passes run concurrently, then all three column passes.
    #[cfg(feature = "parallel")]
    let (lh, hl, hh) = {
        let (row_low, row_high) = rayon::join(
            || filter_rows(pixels, width, pix_h, lpdf),
            || filter_rows(pixels, width, pix_h, &HPDF),
        );
        let (lh, (hl, hh)) = rayon::join(
            || cols(row_low.as_slice(), &HPDF),
            || {
                rayon::join(
                    || cols(row_high.as_slice(), lpdf),
                    || cols(row_high.as_slice(), &HPDF),
                )
            },
        );
        (lh, hl, hh)
    };

    // Sequentially: release the low-pass row buffer before building the high-pass one.
    #[cfg(not(feature = "parallel"))]
    let (lh, hl, hh) = {
        let row_low = filter_rows(pixels, width, pix_h, lpdf);
        let lh = cols(row_low.as_slice(), &HPDF);
        drop(row_low);
        let row_high = filter_rows(pixels, width, pix_h, &HPDF);
        let hl = cols(row_high.as_slice(), lpdf);
        let hh = cols(row_high.as_slice(), &HPDF);
        (lh, hl, hh)
    };

    ThreeSubbands { lh, hl, hh, width, y_offset: wav_y_start }
}
/// Column-filters image rows `[out_y_start, out_y_end)` using a partial row buffer.
///
/// `input` holds `input_h` rows of `width` samples, starting at absolute image row
/// `input_y0`. Output row `i` (image row `out_y_start + i`) is the centred correlation of
/// `filter` along each column. Source rows are mirrored at the borders of the full image
/// height `img_h`. Debug builds assert that every mirrored source row lies inside `input`.
fn filter_cols_strip(
    input: &[f32],
    width: usize,
    input_h: usize,
    input_y0: usize,
    out_y_start: usize,
    out_y_end: usize,
    img_h: usize,
    filter: &[f64; 16],
) -> Vec<f32> {
    let half = (FILT_LEN - 1) / 2;
    let out_h = out_y_end - out_y_start;
    let mut output = vec![0.0f32; width * out_h];
    for (out_idx, abs_y) in (out_y_start..out_y_end).enumerate() {
        let out_row = out_idx * width;
        for x in 0..width {
            let mut sum = 0.0f64;
            for (k, &tap) in filter.iter().enumerate() {
                let sy_abs = abs_y as isize + k as isize - half as isize;
                let sy_mirrored = mirror_index(sy_abs, img_h);
                // Translate the absolute (mirrored) row into the buffer's local row.
                let sy_local = sy_mirrored - input_y0;
                debug_assert!(
                    sy_local < input_h,
                    "strip col filter OOB: abs_y={abs_y} k={k} sy_abs={sy_abs} sy_mirrored={sy_mirrored} input_y0={input_y0} input_h={input_h}"
                );
                sum += input[sy_local * width + x] as f64 * tap;
            }
            output[out_row + x] = sum as f32;
        }
    }
    output
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stego::cost::WET_COST;

    /// Quantisation table with every step equal to `val`.
    fn make_qt_uniform(val: u16) -> QuantTable {
        QuantTable::new([val; 64])
    }

    /// Quantisation table whose values match the example luminance table of
    /// ITU-T T.81 Annex K (Table K.1).
    fn standard_qt() -> QuantTable {
        QuantTable::new([
            16, 11, 10, 16, 24, 40, 51, 61,
            12, 12, 14, 19, 26, 58, 60, 55,
            14, 13, 16, 24, 40, 57, 69, 56,
            14, 17, 22, 29, 51, 87, 80, 62,
            18, 22, 37, 56, 68, 109, 103, 77,
            24, 35, 55, 64, 81, 104, 113, 92,
            49, 64, 78, 87, 103, 121, 120, 101,
            72, 92, 95, 98, 112, 100, 103, 99,
        ])
    }

    /// DC coefficients must keep `WET_COST` even when they are non-zero.
    #[test]
    fn dc_is_wet() {
        let mut grid = DctGrid::new(4, 4);
        for br in 0..4 {
            for bc in 0..4 {
                grid.set(br, bc, 0, 0, 100);
                grid.set(br, bc, 1, 0, 10);
                grid.set(br, bc, 0, 1, -5);
            }
        }
        let map = compute_uniward(&grid, &make_qt_uniform(16));
        for br in 0..4 {
            for bc in 0..4 {
                assert_eq!(map.get(br, bc, 0, 0), WET_COST);
            }
        }
    }

    /// Zero-valued AC coefficients (here (0,1)) must keep `WET_COST`.
    #[test]
    fn zero_ac_is_wet() {
        let mut grid = DctGrid::new(4, 4);
        for br in 0..4 {
            for bc in 0..4 {
                grid.set(br, bc, 0, 0, 100);
                grid.set(br, bc, 1, 0, 10);
            }
        }
        let map = compute_uniward(&grid, &make_qt_uniform(16));
        for br in 0..4 {
            for bc in 0..4 {
                assert_eq!(map.get(br, bc, 0, 1), WET_COST);
            }
        }
    }

    /// A non-zero AC coefficient gets a finite, strictly positive cost.
    #[test]
    fn non_zero_ac_has_finite_cost() {
        let mut grid = DctGrid::new(4, 4);
        for br in 0..4 {
            for bc in 0..4 {
                grid.set(br, bc, 0, 0, 100);
                grid.set(br, bc, 1, 0, 10);
                grid.set(br, bc, 0, 1, -5);
                grid.set(br, bc, 1, 1, 3);
            }
        }
        let map = compute_uniward(&grid, &standard_qt());
        let cost = map.get(2, 2, 1, 0);
        assert!(
            cost.is_finite() && cost > 0.0,
            "expected finite positive cost, got {cost}"
        );
    }

    /// The same coefficient change must be cheaper inside a textured area (blocks 0..4 x 0..4
    /// carry a dense AC pattern) than inside a flat one (block (6,6), DC plus one AC only).
    #[test]
    fn textured_region_cheaper_than_smooth_region() {
        let qt = make_qt_uniform(16);
        let mut grid = DctGrid::new(8, 8);
        for br in 0..8 {
            for bc in 0..8 {
                grid.set(br, bc, 0, 0, 100);
            }
        }
        // Dense pseudo-texture in the top-left 4x4 blocks, values in -7..=7.
        for br in 0..4 {
            for bc in 0..4 {
                for i in 0..8 {
                    for j in 0..8 {
                        if i == 0 && j == 0 { continue; }
                        grid.set(br, bc, i, j, (((i * 7 + j * 3) % 15) as i16) - 7);
                    }
                }
            }
        }
        // Identical probe coefficient in both regions.
        grid.set(2, 2, 1, 0, 5);
        grid.set(6, 6, 1, 0, 5);
        let map = compute_uniward(&grid, &qt);
        let cost_textured = map.get(2, 2, 1, 0);
        let cost_smooth = map.get(6, 6, 1, 0);
        assert!(
            cost_textured.is_finite(),
            "textured cost should be finite: {cost_textured}"
        );
        assert!(
            cost_smooth.is_finite(),
            "smooth cost should be finite: {cost_smooth}"
        );
        assert!(
            cost_textured < cost_smooth,
            "textured region {cost_textured} should be < smooth region {cost_smooth}"
        );
    }

    /// No entry of the cost map (wet or not) may be negative.
    #[test]
    fn costs_are_positive() {
        let mut grid = DctGrid::new(4, 4);
        for br in 0..4 {
            for bc in 0..4 {
                grid.set(br, bc, 0, 0, 80);
                for i in 0..8 {
                    for j in 0..8 {
                        if i == 0 && j == 0 { continue; }
                        // Sparse pattern: roughly a third of the AC positions are set.
                        if (i + j) % 3 == 0 {
                            grid.set(br, bc, i, j, ((i * 3 + j * 7) % 20) as i16 - 10);
                        }
                    }
                }
            }
        }
        let map = compute_uniward(&grid, &standard_qt());
        for br in 0..4 {
            for bc in 0..4 {
                for i in 0..8 {
                    for j in 0..8 {
                        let cost = map.get(br, bc, i, j);
                        assert!(
                            cost >= 0.0,
                            "negative cost {cost} at ({br},{bc},{i},{j})"
                        );
                    }
                }
            }
        }
    }

    /// Whole-sample reflection: -1 maps to 1 and `size` maps to `size - 2`.
    #[test]
    fn mirror_index_works() {
        assert_eq!(mirror_index(-1, 10), 1);
        assert_eq!(mirror_index(-3, 10), 3);
        assert_eq!(mirror_index(0, 10), 0);
        assert_eq!(mirror_index(5, 10), 5);
        assert_eq!(mirror_index(9, 10), 9);
        assert_eq!(mirror_index(10, 10), 8);
        assert_eq!(mirror_index(11, 10), 7);
    }

    /// The derived low-pass filter has 16 taps whose sum is close to sqrt(2).
    #[test]
    fn lpdf_is_correct_length() {
        let lp = lpdf();
        assert_eq!(lp.len(), 16);
        let sum: f64 = lp.iter().sum();
        assert!(
            (sum - std::f64::consts::SQRT_2).abs() < 0.01,
            "low-pass filter sum {sum} should be ~sqrt(2)"
        );
    }

    /// The high-pass filter removes the DC component: its taps sum to ~0.
    #[test]
    fn hpdf_sums_to_zero() {
        let sum: f64 = HPDF.iter().sum();
        assert!(
            sum.abs() < 1e-10,
            "high-pass filter sum {sum} should be ~0"
        );
    }

    /// Smoke test on a real JPEG; silently skipped when the test vector is not present.
    #[test]
    fn cost_with_real_photo() {
        let data = match std::fs::read("test-vectors/image/photo_320x240_q75_420.jpg") {
            Ok(d) => d,
            // Test vector not available: skip.
            Err(_) => return,
        };
        let img = crate::codec::jpeg::JpegImage::from_bytes(&data).unwrap();
        // Component 0 and its quantisation table.
        let grid = img.dct_grid(0);
        let qt_id = img.frame_info().components[0].quant_table_id as usize;
        let qt = img.quant_table(qt_id).unwrap();
        let map = compute_uniward(grid, qt);
        let bw = grid.blocks_wide();
        let bt = grid.blocks_tall();
        let mut finite_count = 0;
        let mut total_cost = 0.0f64;
        for br in 0..bt {
            for bc in 0..bw {
                for i in 0..8 {
                    for j in 0..8 {
                        let c = map.get(br, bc, i, j);
                        if c.is_finite() {
                            finite_count += 1;
                            total_cost += c as f64;
                        }
                    }
                }
            }
        }
        assert!(
            finite_count > 1000,
            "expected >1000 finite costs, got {finite_count}"
        );
        let avg = total_cost / finite_count as f64;
        assert!(avg > 0.0, "average cost should be positive: {avg}");
    }

    /// Two runs over the same input must give bit-identical costs, including when the
    /// `parallel` feature changes the order in which blocks are processed.
    #[test]
    fn determinism_repeated_runs() {
        let mut grid = DctGrid::new(6, 6);
        for br in 0..6 {
            for bc in 0..6 {
                grid.set(br, bc, 0, 0, 100);
                for i in 0..8 {
                    for j in 0..8 {
                        if i == 0 && j == 0 { continue; }
                        let val = (((br * 7 + bc * 13 + i * 3 + j * 11) % 21) as i16) - 10;
                        if val != 0 {
                            grid.set(br, bc, i, j, val);
                        }
                    }
                }
            }
        }
        let qt = standard_qt();
        let map1 = compute_uniward(&grid, &qt);
        let map2 = compute_uniward(&grid, &qt);
        for br in 0..6 {
            for bc in 0..6 {
                for i in 0..8 {
                    for j in 0..8 {
                        let c1 = map1.get(br, bc, i, j);
                        let c2 = map2.get(br, bc, i, j);
                        assert_eq!(
                            c1.to_bits(), c2.to_bits(),
                            "cost mismatch at ({br},{bc},{i},{j}): {c1} vs {c2}"
                        );
                    }
                }
            }
        }
    }

    /// End to end: `ghost_encode` followed by `ghost_decode` must recover the message under
    /// the currently enabled features. Skipped when the test vector is not present.
    #[test]
    fn ghost_roundtrip_with_current_feature_set() {
        let data = match std::fs::read("test-vectors/image/photo_320x240_q75_420.jpg") {
            Ok(d) => d,
            // Test vector not available: skip.
            Err(_) => return,
        };
        let message = "Parallel cost test";
        let passphrase = "test-pass-42";
        let stego = crate::stego::ghost_encode(&data, message, passphrase)
            .expect("ghost_encode should succeed");
        let decoded = crate::stego::ghost_decode(&stego, passphrase)
            .expect("ghost_decode should succeed");
        assert_eq!(decoded.text, message, "round-trip mismatch");
        assert!(decoded.files.is_empty(), "no files expected");
    }

    /// Manual timing of `compute_uniward` (run with `--ignored`). It performs one warm-up run,
    /// then reports the mean of `iterations` runs.
    #[test]
    #[ignore]
    fn cost_computation_benchmark() {
        let data = if let Ok(d) = std::fs::read("test-vectors/image/photo_320x240_q75_420.jpg") { d } else {
            eprintln!("Skipping benchmark: test vector not found");
            return;
        };
        let img = crate::codec::jpeg::JpegImage::from_bytes(&data).unwrap();
        let grid = img.dct_grid(0);
        let qt_id = img.frame_info().components[0].quant_table_id as usize;
        let qt = img.quant_table(qt_id).unwrap();
        // Warm-up run, excluded from the timing.
        let _ = compute_uniward(grid, qt);
        let iterations = 10;
        let start = std::time::Instant::now();
        for _ in 0..iterations {
            let _ = compute_uniward(grid, qt);
        }
        let elapsed = start.elapsed();
        let bw = grid.blocks_wide();
        let bt = grid.blocks_tall();
        eprintln!(
            "J-UNIWARD cost ({bw}x{bt} blocks, {}x{} pixels): {:.1} ms avg over {iterations} runs [feature=parallel: {}]",
            bw * 8, bt * 8,
            elapsed.as_secs_f64() * 1000.0 / iterations as f64,
            cfg!(feature = "parallel"),
        );
    }
}