use crate::codec::jpeg::dct::DctGrid;
use crate::codec::jpeg::pixels::dct_block_unquantized;
/// Per-coefficient rounding errors of a cover image, stored compactly as one
/// signed byte per DCT coefficient.
///
/// Built by [`SideInfo::compute`] and read back through [`SideInfo::error_at`].
#[derive(Debug, Clone)]
pub struct SideInfo {
    /// Encoded rounding errors, 64 per block, blocks in row-major order:
    /// the entry for coefficient `k` of block `(br, bc)` lives at
    /// `(br * blocks_wide + bc) * 64 + k`.
    rounding_errors: Vec<i8>,
    /// Number of blocks per row in the cover grid this was computed from.
    pub blocks_wide: usize,
    /// Number of block rows in the cover grid this was computed from.
    pub blocks_tall: usize,
}
/// Packs a rounding error into a signed byte.
///
/// The error is scaled by 254, rounded to the nearest integer and saturated
/// to `-127..=127`, so an error of ±0.5 maps to ±127.
#[inline]
fn encode_error(error: f64) -> i8 {
    let scaled = (error * 254.0).round();
    scaled.clamp(-127.0, 127.0) as i8
}
/// Unpacks a byte produced by `encode_error` back into a fractional error
/// by dividing it by 254 (so ±127 becomes ±0.5).
#[inline]
fn decode_error(val: i8) -> f32 {
    f32::from(val) / 254.0
}
/// Lower bound applied to every side-informed cost, so that a modulated cost
/// is always strictly positive even when the scaling factor reaches zero.
const MIN_SI_COST: f32 = 1e-6;
impl SideInfo {
    /// Computes the rounding error of every cover coefficient.
    ///
    /// The RGB image is converted to luma blocks in strips of block rows.
    /// Each luma block is transformed with `dct_block_unquantized` and compared
    /// with the matching block of `cover_grid`. The difference is clamped to
    /// `[-0.5, 0.5]` and stored in encoded form.
    ///
    /// Blocks of the cover grid that lie outside the luma grid derived from
    /// `pixel_width` / `pixel_height` keep an error of zero.
    ///
    /// # Panics
    ///
    /// Panics if a block returned by `cover_grid.block` cannot be converted
    /// into a 64-element array, or if `raw_rgb` is too short for the given
    /// dimensions.
    pub fn compute(
        raw_rgb: &[u8],
        pixel_width: u32,
        pixel_height: u32,
        cover_grid: &DctGrid,
        qt_values: &[u16; 64],
    ) -> Self {
        // Block rows converted per pass, which bounds the luma buffer size.
        const STRIP_ROWS: usize = 50;

        let blocks_wide = cover_grid.blocks_wide();
        let blocks_tall = cover_grid.blocks_tall();
        let luma_cols = (pixel_width as usize).div_ceil(8);
        let luma_rows = (pixel_height as usize).div_ceil(8);

        let mut rounding_errors = vec![0i8; blocks_wide * blocks_tall * 64];

        for strip_start in (0..blocks_tall).step_by(STRIP_ROWS) {
            let strip_end = blocks_tall.min(strip_start + STRIP_ROWS);
            let strip = rgb_to_luma_blocks_strip(
                raw_rgb,
                pixel_width,
                pixel_height,
                strip_start,
                strip_end,
            );

            // Only blocks present in both the cover grid and the luma grid
            // are visited; everything else stays at its initial zero.
            for br in strip_start..strip_end.min(luma_rows) {
                for bc in 0..blocks_wide.min(luma_cols) {
                    let base = (br * blocks_wide + bc) * 64;
                    let luma_block = &strip[(br - strip_start) * luma_cols + bc];
                    let unquantized = dct_block_unquantized(luma_block, qt_values);
                    let cover_block: [i16; 64] = cover_grid.block(br, bc).try_into().unwrap();

                    let slots = &mut rounding_errors[base..base + 64];
                    for (k, slot) in slots.iter_mut().enumerate() {
                        let diff = unquantized[k] - f64::from(cover_block[k]);
                        *slot = encode_error(diff.clamp(-0.5, 0.5));
                    }
                }
            }
        }

        SideInfo {
            rounding_errors,
            blocks_wide,
            blocks_tall,
        }
    }

    /// Returns the decoded rounding error stored at `flat_idx`
    /// (`block_index * 64 + coefficient_index`).
    ///
    /// # Panics
    ///
    /// Panics if `flat_idx` is out of bounds.
    #[inline]
    pub fn error_at(&self, flat_idx: usize) -> f32 {
        let stored = self.rounding_errors[flat_idx];
        decode_error(stored)
    }
}
/// Converts block rows `br_start..br_end` of a packed RGB image into 8×8 luma
/// blocks.
///
/// `rgb` is read as three bytes per pixel in row-major order. Each sample is
/// `0.299 * r + 0.587 * g + 0.114 * b`. Blocks that stick out past the right
/// or bottom edge repeat the last pixel column / row. `br_end` is capped at
/// the number of block rows in the image, so the result may be shorter than
/// requested or empty.
///
/// The returned blocks are in row-major order, `width.div_ceil(8)` per block
/// row, each holding its 64 samples row by row.
fn rgb_to_luma_blocks_strip(
    rgb: &[u8],
    width: u32,
    height: u32,
    br_start: usize,
    br_end: usize,
) -> Vec<[f64; 64]> {
    let w = width as usize;
    let h = height as usize;
    let cols = w.div_ceil(8);
    let last_row = br_end.min(h.div_ceil(8));

    let mut blocks = Vec::with_capacity(last_row.saturating_sub(br_start) * cols);
    for br in br_start..last_row {
        for bc in 0..cols {
            // `from_fn` visits indices 0..64 in order, i.e. row by row.
            let block: [f64; 64] = std::array::from_fn(|i| {
                let py = (br * 8 + i / 8).min(h - 1);
                let px = (bc * 8 + i % 8).min(w - 1);
                let base = (py * w + px) * 3;
                let r = f64::from(rgb[base]);
                let g = f64::from(rgb[base + 1]);
                let b = f64::from(rgb[base + 2]);
                0.299 * r + 0.587 * g + 0.114 * b
            });
            blocks.push(block);
        }
    }
    blocks
}
/// Scales embedding costs in place using the stored rounding errors.
///
/// For every coefficient that is eligible, the cost is multiplied by
/// `1 - 2 * |error|` and floored at `MIN_SI_COST`, so a larger rounding error
/// gives a cheaper change. A coefficient is left untouched when:
///
/// * it sits at position `(0, 0)` of its block,
/// * its current cost is not finite, or
/// * its magnitude is exactly 1 (`si_modify_coefficient` always moves these
///   to ±2, whatever the rounding error is).
///
/// The flat index into `side_info` is derived from the dimensions of
/// `cost_map`, so `side_info` must have been computed for a grid of the same
/// size.
///
/// # Panics
///
/// Panics if `side_info` holds fewer entries than `cost_map` has coefficients.
pub fn modulate_costs_si(
    cost_map: &mut crate::stego::cost::CostMap,
    side_info: &SideInfo,
    cover_grid: &DctGrid,
) {
    let bw = cost_map.blocks_wide();
    let bh = cost_map.blocks_tall();
    for br in 0..bh {
        for bc in 0..bw {
            let block_base = (br * bw + bc) * 64;
            for i in 0..8 {
                for j in 0..8 {
                    if i == 0 && j == 0 {
                        continue;
                    }
                    let cost = cost_map.get(br, bc, i, j);
                    if !cost.is_finite() {
                        continue;
                    }
                    let coeff = cover_grid.get(br, bc, i, j);
                    // `unsigned_abs` cannot overflow, unlike `abs`, which
                    // panics in debug builds on the minimum signed value.
                    if coeff.unsigned_abs() == 1 {
                        continue;
                    }
                    let abs_error = side_info.error_at(block_base + i * 8 + j).abs();
                    let factor = 1.0f32 - 2.0 * abs_error;
                    let modulated = (cost * factor).max(MIN_SI_COST);
                    cost_map.set(br, bc, i, j, modulated);
                }
            }
        }
    }
}
/// Applies a side-informed ±1 change to a coefficient.
///
/// Coefficients equal to `1` / `-1` always become `2` / `-2`, regardless of
/// the rounding error. Any other coefficient is incremented when
/// `rounding_error` is strictly positive and decremented otherwise (including
/// for a zero or NaN error).
#[inline]
pub fn si_modify_coefficient(coeff: i16, rounding_error: f32) -> i16 {
    match coeff {
        1 => 2,
        -1 => -2,
        other if rounding_error > 0.0 => other + 1,
        other => other - 1,
    }
}
/// Applies a change that moves a coefficient one step toward zero, without
/// ever producing a new zero.
///
/// `1` / `-1` are pushed away from zero to `2` / `-2`. Larger magnitudes are
/// reduced by one. A coefficient of `0` is returned unchanged.
#[inline]
pub fn nsf5_modify_coefficient(coeff: i16) -> i16 {
    match coeff {
        1 => 2,
        -1 => -2,
        positive if positive > 1 => positive - 1,
        negative if negative < -1 => negative + 1,
        zero => zero,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codec::jpeg::pixels::dct_block;

    /// Quantisation table shared by the DCT tests below.
    fn standard_qt() -> [u16; 64] {
        const TABLE: [u16; 64] = [
            16, 11, 10, 16, 24, 40, 51, 61, //
            12, 12, 14, 19, 26, 58, 60, 55, //
            14, 13, 16, 24, 40, 57, 69, 56, //
            14, 17, 22, 29, 51, 87, 80, 62, //
            18, 22, 37, 56, 68, 109, 103, 77, //
            24, 35, 55, 64, 81, 104, 113, 92, //
            49, 64, 78, 87, 103, 121, 120, 101, //
            72, 92, 95, 98, 112, 100, 103, 99,
        ];
        TABLE
    }

    /// Rounding the unquantized DCT output must reproduce the quantized one.
    #[test]
    fn t1_unquantized_rounds_to_quantized() {
        let flat = [128.0f64; 64];
        let ramp: [f64; 64] = std::array::from_fn(|i| 50.0 + (i as f64) * 3.0);
        let alternating: [f64; 64] =
            std::array::from_fn(|i| if i % 2 == 0 { 20.0 } else { 230.0 });
        let patterns = [flat, ramp, alternating];

        let qt = standard_qt();
        for pixels in &patterns {
            let quantized = dct_block(pixels, &qt);
            let unquantized = dct_block_unquantized(pixels, &qt);
            for i in 0..64 {
                assert_eq!(
                    quantized[i],
                    unquantized[i].round() as i16,
                    "Mismatch at index {i}: quantized={}, unquantized={}",
                    quantized[i],
                    unquantized[i]
                );
            }
        }
    }

    /// The gap between unquantized and quantized values stays within ±0.5.
    #[test]
    fn t2_rounding_errors_in_range() {
        let qt = standard_qt();
        for seed in 0..10u8 {
            let pixels: [f64; 64] = std::array::from_fn(|i| {
                ((seed as f64 * 37.0 + i as f64 * 13.0) % 256.0).abs()
            });
            let quantized = dct_block(&pixels, &qt);
            let unquantized = dct_block_unquantized(&pixels, &qt);
            for i in 0..64 {
                let error = unquantized[i] - quantized[i] as f64;
                assert!(
                    error.abs() <= 0.50001,
                    "seed={seed}, index={i}: error={error}"
                );
            }
        }
    }

    /// An error of exactly 0.5 yields the cost floor, not zero.
    #[test]
    fn t3_half_coefficient_cost_not_zero() {
        let cost = 1.0f32;
        let factor = 1.0f32 - 2.0 * 0.5_f32;
        let modulated = (cost * factor).max(MIN_SI_COST);
        assert!(modulated > 0.0, "1/2-coefficient must not have zero cost");
        assert_eq!(modulated, MIN_SI_COST);
    }

    /// A larger rounding error must produce a lower modulated cost.
    #[test]
    fn t4_si_cost_scales_with_rounding_error() {
        let cost = 1.0f32;
        let modulate = |error: f32| (cost * (1.0f32 - 2.0 * error)).max(MIN_SI_COST);
        let small_modulated = modulate(0.1f32);
        let large_modulated = modulate(0.4f32);
        assert!(
            large_modulated < small_modulated,
            "larger error should give lower cost: small={small_modulated}, large={large_modulated}"
        );
    }

    /// Modulation never raises a cost above its original value.
    #[test]
    fn t4_si_costs_never_exceed_original() {
        let cost = 5.0f32;
        for error in (0..=50).map(|pct| pct as f32 / 100.0) {
            let modulated = (cost * (1.0f32 - 2.0 * error)).max(MIN_SI_COST);
            assert!(
                modulated <= cost + 1e-6,
                "modulated={modulated} > original={cost} at error={error}"
            );
        }
    }

    /// ±1 always becomes ±2, whatever the rounding error.
    #[test]
    fn t5_anti_shrinkage_preserved() {
        let cases: [(i16, f32, i16); 6] = [
            (1, -0.4, 2),
            (1, 0.4, 2),
            (1, 0.0, 2),
            (-1, -0.4, -2),
            (-1, 0.4, -2),
            (-1, 0.0, -2),
        ];
        for (coeff, error, expected) in cases {
            assert_eq!(si_modify_coefficient(coeff, error), expected);
        }
    }

    /// Other coefficients move in the direction of the rounding error.
    #[test]
    fn t6_direction_follows_rounding_error() {
        let cases: [(i16, f32, i16); 6] = [
            (5, 0.3, 6),
            (-5, 0.3, -4),
            (5, -0.3, 4),
            (-5, -0.3, -6),
            (5, 0.0, 4),
            (-5, 0.0, -6),
        ];
        for (coeff, error, expected) in cases {
            assert_eq!(si_modify_coefficient(coeff, error), expected);
        }
    }

    /// The nsF5 rule moves toward zero, except at magnitude 1.
    #[test]
    fn t6b_nsf5_toward_zero() {
        let cases: [(i16, i16); 6] = [(5, 4), (-5, -4), (2, 1), (-2, -1), (1, 2), (-1, -2)];
        for (coeff, expected) in cases {
            assert_eq!(nsf5_modify_coefficient(coeff), expected);
        }
    }

    /// The byte encoding keeps the modulation factor accurate to within 0.02.
    #[test]
    fn t7_i8_encode_decode_precision() {
        for error in (0..=100).map(|i| (i as f64 - 50.0) / 100.0) {
            let encoded = encode_error(error);
            let decoded = decode_error(encoded);
            let original_factor = 1.0 - 2.0 * error.abs();
            let decoded_factor = 1.0f32 - 2.0 * decoded.abs();
            let factor_error = (original_factor as f32 - decoded_factor).abs();
            assert!(
                factor_error < 0.02,
                "error={error}, encoded={encoded}, decoded={decoded}, factor_error={factor_error}"
            );
        }
    }

    /// The encode/decode round trip keeps the sign of the error.
    #[test]
    fn t7_i8_sign_preserved() {
        let round_trip = |error: f64| decode_error(encode_error(error));
        assert!(round_trip(0.3) > 0.0);
        assert!(round_trip(-0.3) < 0.0);
        assert_eq!(round_trip(0.0), 0.0);
    }

    /// Strip-wise luma conversion must match the whole-image conversion.
    #[test]
    fn t8_strip_luma_matches_full() {
        use crate::codec::jpeg::pixels::rgb_to_luma_blocks;

        let width = 24u32;
        let height = 16u32;
        let rgb: Vec<u8> = (0..(width * height * 3) as usize)
            .map(|i| ((i * 37 + 13) % 256) as u8)
            .collect();

        let full_blocks = rgb_to_luma_blocks(&rgb, width, height);
        let luma_bw = (width as usize).div_ceil(8);

        // Both block rows requested as a single strip.
        let strip_all = rgb_to_luma_blocks_strip(&rgb, width, height, 0, 2);
        assert_eq!(strip_all.len(), full_blocks.len());
        for (i, (a, b)) in full_blocks.iter().zip(strip_all.iter()).enumerate() {
            for k in 0..64 {
                assert!(
                    (a[k] - b[k]).abs() < 1e-10,
                    "block {i}, coeff {k}: full={}, strip={}",
                    a[k],
                    b[k]
                );
            }
        }

        // Each block row requested as its own strip.
        let strip0 = rgb_to_luma_blocks_strip(&rgb, width, height, 0, 1);
        let strip1 = rgb_to_luma_blocks_strip(&rgb, width, height, 1, 2);
        assert_eq!(strip0.len(), luma_bw);
        assert_eq!(strip1.len(), luma_bw);
        for bc in 0..luma_bw {
            for k in 0..64 {
                let row0_diff = (full_blocks[bc][k] - strip0[bc][k]).abs();
                assert!(row0_diff < 1e-10, "row 0, block {bc}, coeff {k}");
                let row1_diff = (full_blocks[luma_bw + bc][k] - strip1[bc][k]).abs();
                assert!(row1_diff < 1e-10, "row 1, block {bc}, coeff {k}");
            }
        }
    }
}