// phasm_core/stego/ghost/side_info.rs
1// Copyright (c) 2026 Christoph Gaffga
2// SPDX-License-Identifier: GPL-3.0-only
3// https://github.com/cgaffga/phasmcore
4
5//! Side information for SI-UNIWARD (Side-Informed UNIWARD).
6//!
7//! When the encoder has access to the original uncompressed pixels (e.g. PNG,
8//! HEIC, or RAW input), it can compute the quantization rounding errors — the
9//! difference between the continuous DCT coefficients and their rounded integer
10//! values. These errors reveal which coefficients are "close to the boundary"
11//! between two quantization levels and can be cheaply flipped.
12//!
13//! SI-UNIWARD uses this to:
14//! 1. **Lower embedding costs** for coefficients with large rounding errors
15//! (cheap to push across the boundary).
16//! 2. **Choose modification direction** toward the pre-quantization value
17//! (minimizing perceptual distortion).
18//!
19//! The result: ~1.5-2× capacity at the same detection risk, or equivalently
20//! the same capacity with significantly lower distortion.
21//!
22//! The decoder is completely unchanged — it reads LSBs regardless of which
23//! direction the modification went.
24
25use crate::codec::jpeg::dct::DctGrid;
26use crate::codec::jpeg::pixels::dct_block_unquantized;
27
/// Per-coefficient rounding errors from quantization.
///
/// Each error is in [-0.5, +0.5] and represents how far the continuous
/// (unquantized) DCT coefficient was from its rounded integer value.
/// Positive error means the pre-quantization value was above the integer;
/// negative means below.
///
/// Stored as i8 [-127, +127] via `error * 254`, giving ~0.004 resolution.
/// This is 8x smaller than f64 and 4x smaller than f32, saving 85 MB (12MP)
/// or 341 MB (48MP) compared to the original f64 representation.
/// Decode back to f32 via [`SideInfo::error_at`].
pub struct SideInfo {
    /// Rounding errors in DctGrid flat order (block_idx * 64 + row * 8 + col).
    /// Encoded as i8: value = (error * 254).round().clamp(-127, 127).
    /// Entries for blocks outside the raw pixel grid remain 0 (no side info).
    rounding_errors: Vec<i8>,
    /// Number of 8x8 blocks horizontally.
    pub blocks_wide: usize,
    /// Number of 8x8 blocks vertically.
    pub blocks_tall: usize,
}
47
/// Quantize a rounding error in [-0.5, +0.5] to its i8 storage form
/// in [-127, +127] (scale by 254, round, clamp).
#[inline]
fn encode_error(error: f64) -> i8 {
    let scaled = (error * 254.0).round();
    scaled.clamp(-127.0, 127.0) as i8
}
53
/// Map a stored i8 error back to its approximate f32 value
/// (inverse of `encode_error`, up to quantization).
#[inline]
fn decode_error(val: i8) -> f32 {
    f32::from(val) / 254.0
}
59
/// Minimum cost for SI-modulated coefficients.
///
/// When |rounding_error| ~ 0.5 ("1/2-coefficients"), the modulated cost
/// `rho * (1 - 2|e|)` approaches zero. Clamping to this floor prevents
/// zero-cost embedding at quantization midpoints, which is a known
/// detectable artifact.
const MIN_SI_COST: f32 = 1e-6;
66
impl SideInfo {
    /// Compute side information from raw RGB pixels and the cover JPEG.
    ///
    /// `raw_rgb` is assumed to be tightly packed 8-bit RGB, row-major,
    /// 3 bytes per pixel -- NOTE(review): confirm with callers.
    ///
    /// For each Y-channel 8x8 block:
    /// 1. Forward DCT on the original (pre-JPEG) pixels
    /// 2. Divide by quantization table (without rounding)
    /// 3. error = unquantized_value - cover_integer_coefficient
    ///
    /// Errors are clamped to [-0.5, +0.5] for robustness against minor
    /// floating-point differences between the platform's JPEG encoder and
    /// our forward DCT implementation.
    ///
    /// Luma blocks are computed in strips of 50 block-rows to limit transient
    /// memory (~12.9 MB per strip instead of ~97.5 MB for all blocks at once
    /// on a 12MP image).
    pub fn compute(
        raw_rgb: &[u8],
        pixel_width: u32,
        pixel_height: u32,
        cover_grid: &DctGrid,
        qt_values: &[u16; 64],
    ) -> Self {
        let bw = cover_grid.blocks_wide();
        let bh = cover_grid.blocks_tall();
        let total_coeffs = bw * bh * 64;
        // 0 means "no side information"; blocks skipped below keep it.
        let mut errors = vec![0i8; total_coeffs];

        // Block dimensions of the raw pixel grid. These may be smaller
        // than the cover grid's (bw, bh) -- presumably when the JPEG grid
        // carries padding; the skip below guards that case.
        let luma_bw = (pixel_width as usize).div_ceil(8);
        let luma_bh = (pixel_height as usize).div_ceil(8);

        // Process luma blocks in strips to limit transient memory.
        // Each strip holds at most STRIP_ROWS block-rows of luma data.
        const STRIP_ROWS: usize = 50;
        for strip_start in (0..bh).step_by(STRIP_ROWS) {
            let strip_end = (strip_start + STRIP_ROWS).min(bh);
            let luma_strip = rgb_to_luma_blocks_strip(
                raw_rgb, pixel_width, pixel_height, strip_start, strip_end,
            );

            for br in strip_start..strip_end {
                for bc in 0..bw {
                    let block_idx = br * bw + bc;

                    // Skip if outside the raw pixel grid
                    if br >= luma_bh || bc >= luma_bw {
                        continue; // leave errors at 0
                    }

                    // Strip layout: luma_bw blocks per block-row, rows
                    // counted from strip_start (not the grid stride bw).
                    let local_idx = (br - strip_start) * luma_bw + bc;
                    let luma_block = &luma_strip[local_idx];

                    // Forward DCT + divide by QT (no rounding)
                    let unquantized = dct_block_unquantized(luma_block, qt_values);

                    // Compute and clamp rounding errors, encode to i8.
                    // try_into panics if a cover block is not exactly 64
                    // coefficients (a DctGrid invariant).
                    let cover_block: [i16; 64] = {
                        let slice = cover_grid.block(br, bc);
                        slice.try_into().unwrap()
                    };

                    for k in 0..64 {
                        let error = (unquantized[k] - cover_block[k] as f64).clamp(-0.5, 0.5);
                        errors[block_idx * 64 + k] = encode_error(error);
                    }
                }
            }
            // luma_strip dropped here -- only one strip in memory at a time
        }

        SideInfo {
            rounding_errors: errors,
            blocks_wide: bw,
            blocks_tall: bh,
        }
    }

    /// Get the rounding error at a flat index, decoded from i8 to f32.
    ///
    /// `flat_idx` follows the storage layout of `rounding_errors`:
    /// block_idx * 64 + row * 8 + col. Panics if out of range.
    #[inline]
    pub fn error_at(&self, flat_idx: usize) -> f32 {
        decode_error(self.rounding_errors[flat_idx])
    }
}
149
/// Convert a horizontal strip of RGB pixels to Y (luminance) 8x8 blocks.
///
/// Only block rows in `[br_start, br_end)` are converted, returned in
/// raster order with `luma_bw` blocks per row. Working strip-by-strip
/// avoids materializing every luma block at once (97.5 MB for 12MP,
/// 390 MB for 48MP).
///
/// Uses BT.601: `Y = 0.299*R + 0.587*G + 0.114*B`.
/// Non-multiple-of-8 dimensions are handled by edge replication.
fn rgb_to_luma_blocks_strip(
    rgb: &[u8],
    width: u32,
    height: u32,
    br_start: usize,
    br_end: usize,
) -> Vec<[f64; 64]> {
    let w = width as usize;
    let h = height as usize;
    let blocks_per_row = w.div_ceil(8);
    let total_block_rows = h.div_ceil(8);

    // Clamp the requested range to the actual block grid.
    let last_row = br_end.min(total_block_rows);
    let row_count = last_row.saturating_sub(br_start);

    // BT.601 luminance at (x, y), replicating the last row/column for
    // coordinates that fall past the image border.
    let luma_at = |x: usize, y: usize| -> f64 {
        let base = (y.min(h - 1) * w + x.min(w - 1)) * 3;
        0.299 * rgb[base] as f64 + 0.587 * rgb[base + 1] as f64 + 0.114 * rgb[base + 2] as f64
    };

    let mut out = Vec::with_capacity(row_count * blocks_per_row);
    for block_row in br_start..last_row {
        for block_col in 0..blocks_per_row {
            let mut block = [0.0f64; 64];
            for (k, slot) in block.iter_mut().enumerate() {
                let (row, col) = (k / 8, k % 8);
                *slot = luma_at(block_col * 8 + col, block_row * 8 + row);
            }
            out.push(block);
        }
    }

    out
}
195
196/// Modulate J-UNIWARD costs using SI rounding errors.
197///
198/// For each AC coefficient with finite cost:
199/// - `modulated_cost = rho * (1 - 2|e|)` where `e` is the rounding error
200/// - Larger |e| -> lower cost (closer to quantization boundary -> cheaper to flip)
201/// - |e| = 0 -> cost unchanged (exactly on the integer -> no benefit)
202/// - |e| = 0.5 -> cost clamped to `MIN_SI_COST` (avoid zero-cost artifact)
203///
204/// Special cases:
205/// - DC coefficients remain WET (infinite cost)
206/// - |coeff| = 1 positions: cost is NOT modulated (anti-shrinkage forces
207/// the direction, so the rounding error doesn't help choose direction)
208pub fn modulate_costs_si(
209 cost_map: &mut crate::stego::cost::CostMap,
210 side_info: &SideInfo,
211 cover_grid: &DctGrid,
212) {
213 let bw = cost_map.blocks_wide();
214 let bh = cost_map.blocks_tall();
215
216 for br in 0..bh {
217 for bc in 0..bw {
218 let block_idx = br * bw + bc;
219 for i in 0..8 {
220 for j in 0..8 {
221 // Skip DC
222 if i == 0 && j == 0 {
223 continue;
224 }
225
226 let cost = cost_map.get(br, bc, i, j);
227 if !cost.is_finite() {
228 continue; // WET position -- leave as-is
229 }
230
231 // Skip |coeff| == 1: anti-shrinkage forces direction,
232 // SI modulation doesn't help
233 let coeff = cover_grid.get(br, bc, i, j);
234 if coeff.abs() == 1 {
235 continue;
236 }
237
238 let flat_idx = block_idx * 64 + i * 8 + j;
239 let error = side_info.error_at(flat_idx);
240 let abs_error = error.abs();
241
242 // modulated = rho * (1 - 2|e|)
243 // When |e| = 0.5: modulated = 0 -> clamp to MIN_SI_COST
244 let factor = 1.0f32 - 2.0 * abs_error;
245 let modulated = (cost * factor).max(MIN_SI_COST);
246 cost_map.set(br, bc, i, j, modulated);
247 }
248 }
249 }
250 }
251}
252
/// Determine the modification direction for a coefficient using the SI
/// rounding error, returning the modified value (coeff +/- 1).
///
/// Rules:
/// - |coeff| == 1: ALWAYS away from zero (anti-shrinkage; a step to 0
///   would erase the coefficient)
/// - |coeff| > 1: toward the pre-quantization value -- a positive error
///   means the precover sat above the integer (go up), otherwise go down
/// - coeff == 0: never reaches here (filtered out as WET)
#[inline]
pub fn si_modify_coefficient(coeff: i16, rounding_error: f32) -> i16 {
    match coeff {
        // Anti-shrinkage: magnitude-1 values step away from zero.
        1 => 2,
        -1 => -2,
        // Precover above the integer: step up toward it.
        _ if rounding_error > 0.0 => coeff + 1,
        // Precover at or below the integer: step down.
        _ => coeff - 1,
    }
}
275
/// Standard nsF5 modification direction (no side info).
///
/// Magnitude-1 coefficients step away from zero (anti-shrinkage);
/// larger magnitudes shrink toward zero.
#[inline]
pub fn nsf5_modify_coefficient(coeff: i16) -> i16 {
    match coeff {
        1 => 2,
        -1 => -2,
        c if c > 1 => c - 1,
        c if c < -1 => c + 1,
        // Zero: should never happen (filtered as WET); pass through.
        c => c,
    }
}
294
295#[cfg(test)]
296mod tests {
297 use super::*;
298 use crate::codec::jpeg::pixels::dct_block;
299
300 // --- T1: dct_block_unquantized matches dct_block ---
301
    /// Baseline JPEG luminance quantization table (Annex K of the JPEG
    /// standard), in row-major order.
    fn standard_qt() -> [u16; 64] {
        [
            16, 11, 10, 16, 24, 40, 51, 61,
            12, 12, 14, 19, 26, 58, 60, 55,
            14, 13, 16, 24, 40, 57, 69, 56,
            14, 17, 22, 29, 51, 87, 80, 62,
            18, 22, 37, 56, 68, 109, 103, 77,
            24, 35, 55, 64, 81, 104, 113, 92,
            49, 64, 78, 87, 103, 121, 120, 101,
            72, 92, 95, 98, 112, 100, 103, 99,
        ]
    }
314
315 #[test]
316 fn t1_unquantized_rounds_to_quantized() {
317 // Various test patterns
318 let patterns: Vec<[f64; 64]> = vec![
319 // Flat gray
320 [128.0; 64],
321 // Gradient
322 {
323 let mut p = [0.0f64; 64];
324 for i in 0..64 {
325 p[i] = 50.0 + (i as f64) * 3.0;
326 }
327 p
328 },
329 // High contrast
330 {
331 let mut p = [0.0f64; 64];
332 for i in 0..64 {
333 p[i] = if i % 2 == 0 { 20.0 } else { 230.0 };
334 }
335 p
336 },
337 ];
338
339 let qt = standard_qt();
340 for pixels in &patterns {
341 let quantized = dct_block(pixels, &qt);
342 let unquantized = dct_block_unquantized(pixels, &qt);
343 for i in 0..64 {
344 assert_eq!(
345 quantized[i],
346 unquantized[i].round() as i16,
347 "Mismatch at index {i}: quantized={}, unquantized={}",
348 quantized[i],
349 unquantized[i]
350 );
351 }
352 }
353 }
354
355 // --- T2: Rounding errors in range ---
356
357 #[test]
358 fn t2_rounding_errors_in_range() {
359 let qt = standard_qt();
360 // Test with multiple pixel patterns
361 for seed in 0..10u8 {
362 let mut pixels = [0.0f64; 64];
363 for i in 0..64 {
364 pixels[i] = ((seed as f64 * 37.0 + i as f64 * 13.0) % 256.0).abs();
365 }
366 let quantized = dct_block(&pixels, &qt);
367 let unquantized = dct_block_unquantized(&pixels, &qt);
368 for i in 0..64 {
369 let error = unquantized[i] - quantized[i] as f64;
370 assert!(
371 (-0.50001..=0.50001).contains(&error),
372 "seed={seed}, index={i}: error={error}"
373 );
374 }
375 }
376 }
377
378 // --- T3: Half-coefficient clamping ---
379
380 #[test]
381 fn t3_half_coefficient_cost_not_zero() {
382 // When |error| = 0.5, modulated cost must NOT be zero
383 let factor = 1.0f32 - 2.0 * 0.5_f32; // = 0.0
384 let cost = 1.0f32;
385 let modulated = (cost * factor).max(MIN_SI_COST);
386 assert!(modulated > 0.0, "1/2-coefficient must not have zero cost");
387 assert_eq!(modulated, MIN_SI_COST);
388 }
389
390 // --- T4: Asymmetric cost modulation ---
391
392 #[test]
393 fn t4_si_cost_scales_with_rounding_error() {
394 // Larger |error| -> lower cost
395 let cost = 1.0f32;
396 let small_error = 0.1f32;
397 let large_error = 0.4f32;
398
399 let small_modulated = (cost * (1.0f32 - 2.0 * small_error)).max(MIN_SI_COST);
400 let large_modulated = (cost * (1.0f32 - 2.0 * large_error)).max(MIN_SI_COST);
401
402 assert!(
403 large_modulated < small_modulated,
404 "larger error should give lower cost: small={small_modulated}, large={large_modulated}"
405 );
406 }
407
408 #[test]
409 fn t4_si_costs_never_exceed_original() {
410 // Modulated costs should always be <= original
411 for error_pct in 0..=50 {
412 let error = error_pct as f32 / 100.0;
413 let cost = 5.0f32;
414 let factor = 1.0f32 - 2.0 * error;
415 let modulated = (cost * factor).max(MIN_SI_COST);
416 assert!(
417 modulated <= cost + 1e-6,
418 "modulated={modulated} > original={cost} at error={error}"
419 );
420 }
421 }
422
423 // --- T5: Anti-shrinkage preserved ---
424
425 #[test]
426 fn t5_anti_shrinkage_preserved() {
427 // |coeff| = 1 must always go away from zero
428 assert_eq!(si_modify_coefficient(1, -0.4_f32), 2);
429 assert_eq!(si_modify_coefficient(1, 0.4_f32), 2);
430 assert_eq!(si_modify_coefficient(1, 0.0_f32), 2);
431 assert_eq!(si_modify_coefficient(-1, -0.4_f32), -2);
432 assert_eq!(si_modify_coefficient(-1, 0.4_f32), -2);
433 assert_eq!(si_modify_coefficient(-1, 0.0_f32), -2);
434 }
435
436 // --- T6: Direction selection ---
437
438 #[test]
439 fn t6_direction_follows_rounding_error() {
440 // Positive error -> precover above -> go up
441 assert_eq!(si_modify_coefficient(5, 0.3_f32), 6);
442 assert_eq!(si_modify_coefficient(-5, 0.3_f32), -4); // toward zero = up
443
444 // Negative error -> precover below -> go down
445 assert_eq!(si_modify_coefficient(5, -0.3_f32), 4);
446 assert_eq!(si_modify_coefficient(-5, -0.3_f32), -6); // away from zero = down
447
448 // Zero error -> down (the else branch)
449 assert_eq!(si_modify_coefficient(5, 0.0_f32), 4);
450 assert_eq!(si_modify_coefficient(-5, 0.0_f32), -6);
451 }
452
453 // --- T6b: nsF5 direction ---
454
455 #[test]
456 fn t6b_nsf5_toward_zero() {
457 assert_eq!(nsf5_modify_coefficient(5), 4);
458 assert_eq!(nsf5_modify_coefficient(-5), -4);
459 assert_eq!(nsf5_modify_coefficient(2), 1);
460 assert_eq!(nsf5_modify_coefficient(-2), -1);
461 // Anti-shrinkage
462 assert_eq!(nsf5_modify_coefficient(1), 2);
463 assert_eq!(nsf5_modify_coefficient(-1), -2);
464 }
465
466 // --- T7: i8 encode/decode roundtrip ---
467
468 #[test]
469 fn t7_i8_encode_decode_precision() {
470 // Test that encode_error/decode_error roundtrip has <1% error
471 // for the cost modulation factor (1 - 2|e|).
472 for i in 0..=100 {
473 let error = (i as f64 - 50.0) / 100.0; // [-0.5, +0.5]
474 let encoded = encode_error(error);
475 let decoded = decode_error(encoded);
476
477 // Check the modulation factor precision
478 let original_factor = 1.0 - 2.0 * error.abs();
479 let decoded_factor = 1.0f32 - 2.0 * decoded.abs();
480 let factor_error = (original_factor as f32 - decoded_factor).abs();
481 assert!(
482 factor_error < 0.02, // <2% error on factor
483 "error={error}, encoded={encoded}, decoded={decoded}, factor_error={factor_error}"
484 );
485 }
486 }
487
488 #[test]
489 fn t7_i8_sign_preserved() {
490 // Sign must be exact for si_modify_coefficient direction
491 assert!(decode_error(encode_error(0.3)) > 0.0);
492 assert!(decode_error(encode_error(-0.3)) < 0.0);
493 assert_eq!(decode_error(encode_error(0.0)), 0.0);
494 }
495
496 // --- T8: strip-based luma matches full computation ---
497
    #[test]
    fn t8_strip_luma_matches_full() {
        use crate::codec::jpeg::pixels::rgb_to_luma_blocks;

        // Create a small test image (24x16 = 3x2 blocks)
        let width = 24u32;
        let height = 16u32;
        let mut rgb = vec![0u8; (width * height * 3) as usize];
        // Deterministic pseudo-random pixel data
        for i in 0..rgb.len() {
            rgb[i] = ((i * 37 + 13) % 256) as u8;
        }

        let full_blocks = rgb_to_luma_blocks(&rgb, width, height);
        let luma_bw = (width as usize).div_ceil(8); // 3

        // Get strip for all rows at once -- must equal the full conversion
        let strip_all = rgb_to_luma_blocks_strip(&rgb, width, height, 0, 2);
        assert_eq!(strip_all.len(), full_blocks.len());
        for (i, (a, b)) in full_blocks.iter().zip(strip_all.iter()).enumerate() {
            for k in 0..64 {
                assert!(
                    (a[k] - b[k]).abs() < 1e-10,
                    "block {i}, coeff {k}: full={}, strip={}",
                    a[k], b[k]
                );
            }
        }

        // Get strips one row at a time: strip row r, block bc must match
        // full block index luma_bw * r + bc.
        let strip0 = rgb_to_luma_blocks_strip(&rgb, width, height, 0, 1);
        let strip1 = rgb_to_luma_blocks_strip(&rgb, width, height, 1, 2);
        assert_eq!(strip0.len(), luma_bw);
        assert_eq!(strip1.len(), luma_bw);
        for bc in 0..luma_bw {
            for k in 0..64 {
                assert!(
                    (full_blocks[bc][k] - strip0[bc][k]).abs() < 1e-10,
                    "row 0, block {bc}, coeff {k}"
                );
                assert!(
                    (full_blocks[luma_bw + bc][k] - strip1[bc][k]).abs() < 1e-10,
                    "row 1, block {bc}, coeff {k}"
                );
            }
        }
    }
544}