Skip to main content

phasm_core/stego/ghost/
side_info.rs

// Copyright (c) 2026 Christoph Gaffga
// SPDX-License-Identifier: GPL-3.0-only
// https://github.com/cgaffga/phasmcore

//! Side information for SI-UNIWARD (Side-Informed UNIWARD).
//!
//! When the encoder has access to the original pre-JPEG pixels (e.g. decoded
//! from PNG, HEIC, or RAW input), it can compute the quantization rounding
//! errors — the difference between the continuous DCT coefficients and their
//! rounded integer values. These errors reveal which coefficients are "close to
//! the boundary" between two quantization levels and can therefore be flipped
//! cheaply.
//!
//! SI-UNIWARD uses this to:
//! 1. **Lower embedding costs** for coefficients with large rounding errors
//!    (cheap to push across the boundary).
//! 2. **Choose the modification direction** toward the pre-quantization value
//!    (minimizing perceptual distortion).
//!
//! The result: ~1.5-2× capacity at the same detection risk, or equivalently
//! the same capacity with significantly lower distortion.
//!
//! The decoder is completely unchanged — it reads LSBs regardless of which
//! direction the modification went.
24
25use crate::codec::jpeg::dct::DctGrid;
26use crate::codec::jpeg::pixels::dct_block_unquantized;
27
/// Per-coefficient rounding errors from quantization.
///
/// Each error is in [-0.5, +0.5] and represents how far the continuous
/// (unquantized) DCT coefficient was from its rounded integer value.
/// Positive error means the pre-quantization value was above the integer;
/// negative means below.
///
/// Stored as i8 [-127, +127] via `error * 254`, giving ~0.004 resolution.
/// This is 8x smaller than f64 and 4x smaller than f32, saving 85 MB (12MP)
/// or 341 MB (48MP) compared to the original f64 representation.
#[derive(Debug, Clone)]
pub struct SideInfo {
    /// Rounding errors in DctGrid flat order (block_idx * 64 + row * 8 + col).
    /// Encoded as i8: value = (error * 254).round().clamp(-127, 127).
    rounding_errors: Vec<i8>,
    /// Number of 8x8 blocks horizontally (taken from the cover `DctGrid`).
    pub blocks_wide: usize,
    /// Number of 8x8 blocks vertically (taken from the cover `DctGrid`).
    pub blocks_tall: usize,
}
47
/// Packs a rounding error (nominally in [-0.5, +0.5]) into an `i8`.
///
/// The error is scaled by 254, rounded to the nearest integer and saturated
/// to [-127, +127], so one unit corresponds to 1/254 (~0.004). A NaN input
/// ends up as 0 because of Rust's saturating float-to-int `as` cast.
#[inline]
fn encode_error(error: f64) -> i8 {
    let steps = (error * 254.0).round();
    steps.clamp(-127.0, 127.0) as i8
}
53
/// Unpacks a stored `i8` error back into an approximate `f32` rounding error.
///
/// Inverse of `encode_error` up to its 1/254 quantization step.
#[inline]
fn decode_error(val: i8) -> f32 {
    f32::from(val) / 254.0
}
59
/// Lower bound applied to SI-modulated costs.
///
/// For "1/2-coefficients" (|rounding_error| close to 0.5) the modulation
/// factor `1 - 2|e|` tends to zero, and so would the cost. Keeping the cost
/// at or above this floor avoids zero-cost embedding at quantization
/// midpoints, which is a known detectable artifact.
const MIN_SI_COST: f32 = 0.000_001;
66
67impl SideInfo {
68    /// Compute side information from raw RGB pixels and the cover JPEG.
69    ///
70    /// For each Y-channel 8x8 block:
71    /// 1. Forward DCT on the original (pre-JPEG) pixels
72    /// 2. Divide by quantization table (without rounding)
73    /// 3. error = unquantized_value - cover_integer_coefficient
74    ///
75    /// Errors are clamped to [-0.5, +0.5] for robustness against minor
76    /// floating-point differences between the platform's JPEG encoder and
77    /// our forward DCT implementation.
78    ///
79    /// Luma blocks are computed in strips of 50 block-rows to limit transient
80    /// memory (~12.9 MB per strip instead of ~97.5 MB for all blocks at once
81    /// on a 12MP image).
82    pub fn compute(
83        raw_rgb: &[u8],
84        pixel_width: u32,
85        pixel_height: u32,
86        cover_grid: &DctGrid,
87        qt_values: &[u16; 64],
88    ) -> Self {
89        let bw = cover_grid.blocks_wide();
90        let bh = cover_grid.blocks_tall();
91        let total_coeffs = bw * bh * 64;
92        let mut errors = vec![0i8; total_coeffs];
93
94        let luma_bw = (pixel_width as usize).div_ceil(8);
95        let luma_bh = (pixel_height as usize).div_ceil(8);
96
97        // Process luma blocks in strips to limit transient memory.
98        // Each strip holds at most STRIP_ROWS block-rows of luma data.
99        const STRIP_ROWS: usize = 50;
100        for strip_start in (0..bh).step_by(STRIP_ROWS) {
101            let strip_end = (strip_start + STRIP_ROWS).min(bh);
102            let luma_strip = rgb_to_luma_blocks_strip(
103                raw_rgb, pixel_width, pixel_height, strip_start, strip_end,
104            );
105
106            for br in strip_start..strip_end {
107                for bc in 0..bw {
108                    let block_idx = br * bw + bc;
109
110                    // Skip if outside the raw pixel grid
111                    if br >= luma_bh || bc >= luma_bw {
112                        continue; // leave errors at 0
113                    }
114
115                    let local_idx = (br - strip_start) * luma_bw + bc;
116                    let luma_block = &luma_strip[local_idx];
117
118                    // Forward DCT + divide by QT (no rounding)
119                    let unquantized = dct_block_unquantized(luma_block, qt_values);
120
121                    // Compute and clamp rounding errors, encode to i8
122                    let cover_block: [i16; 64] = {
123                        let slice = cover_grid.block(br, bc);
124                        slice.try_into().unwrap()
125                    };
126
127                    for k in 0..64 {
128                        let error = (unquantized[k] - cover_block[k] as f64).clamp(-0.5, 0.5);
129                        errors[block_idx * 64 + k] = encode_error(error);
130                    }
131                }
132            }
133            // luma_strip dropped here -- only one strip in memory at a time
134        }
135
136        SideInfo {
137            rounding_errors: errors,
138            blocks_wide: bw,
139            blocks_tall: bh,
140        }
141    }
142
143    /// Get the rounding error at a flat index, decoded from i8 to f32.
144    #[inline]
145    pub fn error_at(&self, flat_idx: usize) -> f32 {
146        decode_error(self.rounding_errors[flat_idx])
147    }
148}
149
/// Converts the block-rows `[br_start, br_end)` of a packed RGB image into
/// 8x8 luma (Y) blocks.
///
/// Blocks are returned in raster order with `ceil(width / 8)` blocks per
/// block-row. Rows at or beyond `ceil(height / 8)` are ignored, so a strip
/// entirely past the bottom of the image is empty. Working one strip at a
/// time avoids allocating ALL luma blocks at once (97.5 MB for 12MP, 390 MB
/// for 48MP).
///
/// Luma uses BT.601: `Y = 0.299*R + 0.587*G + 0.114*B`. Partial blocks at
/// the right/bottom edge are filled by replicating the last pixel column/row.
///
/// # Panics
///
/// Panics if `rgb` is too short to hold the pixels that are read.
fn rgb_to_luma_blocks_strip(
    rgb: &[u8],
    width: u32,
    height: u32,
    br_start: usize,
    br_end: usize,
) -> Vec<[f64; 64]> {
    let (w, h) = (width as usize, height as usize);
    let blocks_across = w.div_ceil(8);
    let last_row = br_end.min(h.div_ceil(8));

    let mut out = Vec::with_capacity(last_row.saturating_sub(br_start) * blocks_across);
    for block_row in br_start..last_row {
        for block_col in 0..blocks_across {
            let mut luma = [0.0f64; 64];
            for (k, sample) in luma.iter_mut().enumerate() {
                let (dy, dx) = (k / 8, k % 8);
                // Clamp to the last valid pixel: edge replication.
                let y = (block_row * 8 + dy).min(h - 1);
                let x = (block_col * 8 + dx).min(w - 1);
                let px = &rgb[(y * w + x) * 3..][..3];
                let (r, g, b) = (f64::from(px[0]), f64::from(px[1]), f64::from(px[2]));
                *sample = 0.299 * r + 0.587 * g + 0.114 * b;
            }
            out.push(luma);
        }
    }
    out
}
195
196/// Modulate J-UNIWARD costs using SI rounding errors.
197///
198/// For each AC coefficient with finite cost:
199/// - `modulated_cost = rho * (1 - 2|e|)` where `e` is the rounding error
200/// - Larger |e| -> lower cost (closer to quantization boundary -> cheaper to flip)
201/// - |e| = 0 -> cost unchanged (exactly on the integer -> no benefit)
202/// - |e| = 0.5 -> cost clamped to `MIN_SI_COST` (avoid zero-cost artifact)
203///
204/// Special cases:
205/// - DC coefficients remain WET (infinite cost)
206/// - |coeff| = 1 positions: cost is NOT modulated (anti-shrinkage forces
207///   the direction, so the rounding error doesn't help choose direction)
208pub fn modulate_costs_si(
209    cost_map: &mut crate::stego::cost::CostMap,
210    side_info: &SideInfo,
211    cover_grid: &DctGrid,
212) {
213    let bw = cost_map.blocks_wide();
214    let bh = cost_map.blocks_tall();
215
216    for br in 0..bh {
217        for bc in 0..bw {
218            let block_idx = br * bw + bc;
219            for i in 0..8 {
220                for j in 0..8 {
221                    // Skip DC
222                    if i == 0 && j == 0 {
223                        continue;
224                    }
225
226                    let cost = cost_map.get(br, bc, i, j);
227                    if !cost.is_finite() {
228                        continue; // WET position -- leave as-is
229                    }
230
231                    // Skip |coeff| == 1: anti-shrinkage forces direction,
232                    // SI modulation doesn't help
233                    let coeff = cover_grid.get(br, bc, i, j);
234                    if coeff.abs() == 1 {
235                        continue;
236                    }
237
238                    let flat_idx = block_idx * 64 + i * 8 + j;
239                    let error = side_info.error_at(flat_idx);
240                    let abs_error = error.abs();
241
242                    // modulated = rho * (1 - 2|e|)
243                    // When |e| = 0.5: modulated = 0 -> clamp to MIN_SI_COST
244                    let factor = 1.0f32 - 2.0 * abs_error;
245                    let modulated = (cost * factor).max(MIN_SI_COST);
246                    cost_map.set(br, bc, i, j, modulated);
247                }
248            }
249        }
250    }
251}
252
/// Chooses the ±1 modification of a coefficient using its SI rounding error.
///
/// Returns the modified coefficient value.
///
/// - `coeff == ±1`: always moves away from zero (to ±2); this is
///   anti-shrinkage, so a coefficient never becomes 0.
/// - any other value: moves toward the pre-quantization value. A positive
///   error (precover above the integer) goes up (`coeff + 1`). An error that
///   is zero, negative or NaN goes down (`coeff - 1`).
///
/// Note that an error of exactly 0 (e.g. no side information) goes DOWN,
/// which is toward zero only for positive coefficients. This function does
/// not implement the nsF5 "toward zero" convention; that is
/// `nsf5_modify_coefficient`.
///
/// `coeff == 0` is not expected here (such positions are WET); if passed, it
/// yields +1 or -1 according to the sign of the error.
#[inline]
pub fn si_modify_coefficient(coeff: i16, rounding_error: f32) -> i16 {
    match coeff {
        1 => 2,   // anti-shrinkage: away from zero
        -1 => -2, // anti-shrinkage: away from zero
        _ if rounding_error > 0.0 => coeff + 1, // precover was above -> go up
        _ => coeff - 1,                         // precover at or below -> go down
    }
}
275
/// nsF5 modification direction, for use when no side information exists.
///
/// - `coeff == ±1`: moves away from zero (to ±2), so it never shrinks to 0.
/// - `|coeff| >= 2`: moves one step toward zero.
/// - `coeff == 0`: not expected; returned unchanged.
#[inline]
pub fn nsf5_modify_coefficient(coeff: i16) -> i16 {
    match coeff {
        1 => 2,
        -1 => -2,
        2..=i16::MAX => coeff - 1,
        i16::MIN..=-2 => coeff + 1,
        0 => 0, // should never happen
    }
}
294
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codec::jpeg::pixels::dct_block;

    /// Example luminance quantization table from the JPEG specification
    /// (ITU-T T.81, Annex K), listed row by row.
    fn standard_qt() -> [u16; 64] {
        [
            16, 11, 10, 16, 24, 40, 51, 61,
            12, 12, 14, 19, 26, 58, 60, 55,
            14, 13, 16, 24, 40, 57, 69, 56,
            14, 17, 22, 29, 51, 87, 80, 62,
            18, 22, 37, 56, 68, 109, 103, 77,
            24, 35, 55, 64, 81, 104, 113, 92,
            49, 64, 78, 87, 103, 121, 120, 101,
            72, 92, 95, 98, 112, 100, 103, 99,
        ]
    }

    // --- T1: dct_block_unquantized matches dct_block ---

    #[test]
    fn t1_unquantized_rounds_to_quantized() {
        // Various test patterns
        let patterns: Vec<[f64; 64]> = vec![
            // Flat gray
            [128.0; 64],
            // Gradient
            {
                let mut p = [0.0f64; 64];
                for i in 0..64 {
                    p[i] = 50.0 + (i as f64) * 3.0;
                }
                p
            },
            // High contrast
            {
                let mut p = [0.0f64; 64];
                for i in 0..64 {
                    p[i] = if i % 2 == 0 { 20.0 } else { 230.0 };
                }
                p
            },
        ];

        let qt = standard_qt();
        for pixels in &patterns {
            let quantized = dct_block(pixels, &qt);
            let unquantized = dct_block_unquantized(pixels, &qt);
            for i in 0..64 {
                assert_eq!(
                    quantized[i],
                    unquantized[i].round() as i16,
                    "Mismatch at index {i}: quantized={}, unquantized={}",
                    quantized[i],
                    unquantized[i]
                );
            }
        }
    }

    // --- T2: Rounding errors in range ---

    #[test]
    fn t2_rounding_errors_in_range() {
        let qt = standard_qt();
        // Test with multiple pixel patterns
        for seed in 0..10u8 {
            let mut pixels = [0.0f64; 64];
            for i in 0..64 {
                pixels[i] = ((seed as f64 * 37.0 + i as f64 * 13.0) % 256.0).abs();
            }
            let quantized = dct_block(&pixels, &qt);
            let unquantized = dct_block_unquantized(&pixels, &qt);
            for i in 0..64 {
                let error = unquantized[i] - quantized[i] as f64;
                assert!(
                    (-0.50001..=0.50001).contains(&error),
                    "seed={seed}, index={i}: error={error}"
                );
            }
        }
    }

    // --- T3: Half-coefficient clamping ---
    // NOTE: T3/T4 exercise the modulation formula directly, not
    // `modulate_costs_si` itself.

    #[test]
    fn t3_half_coefficient_cost_not_zero() {
        // When |error| = 0.5, modulated cost must NOT be zero
        let factor = 1.0f32 - 2.0 * 0.5_f32; // = 0.0
        let cost = 1.0f32;
        let modulated = (cost * factor).max(MIN_SI_COST);
        assert!(modulated > 0.0, "1/2-coefficient must not have zero cost");
        assert_eq!(modulated, MIN_SI_COST);
    }

    // --- T4: Cost scales with rounding error ---

    #[test]
    fn t4_si_cost_scales_with_rounding_error() {
        // Larger |error| -> lower cost
        let cost = 1.0f32;
        let small_error = 0.1f32;
        let large_error = 0.4f32;

        let small_modulated = (cost * (1.0f32 - 2.0 * small_error)).max(MIN_SI_COST);
        let large_modulated = (cost * (1.0f32 - 2.0 * large_error)).max(MIN_SI_COST);

        assert!(
            large_modulated < small_modulated,
            "larger error should give lower cost: small={small_modulated}, large={large_modulated}"
        );
    }

    #[test]
    fn t4_si_costs_never_exceed_original() {
        // Modulated costs should always be <= original
        for error_pct in 0..=50 {
            let error = error_pct as f32 / 100.0;
            let cost = 5.0f32;
            let factor = 1.0f32 - 2.0 * error;
            let modulated = (cost * factor).max(MIN_SI_COST);
            assert!(
                modulated <= cost + 1e-6,
                "modulated={modulated} > original={cost} at error={error}"
            );
        }
    }

    // --- T5: Anti-shrinkage preserved ---

    #[test]
    fn t5_anti_shrinkage_preserved() {
        // |coeff| = 1 must always go away from zero
        assert_eq!(si_modify_coefficient(1, -0.4_f32), 2);
        assert_eq!(si_modify_coefficient(1, 0.4_f32), 2);
        assert_eq!(si_modify_coefficient(1, 0.0_f32), 2);
        assert_eq!(si_modify_coefficient(-1, -0.4_f32), -2);
        assert_eq!(si_modify_coefficient(-1, 0.4_f32), -2);
        assert_eq!(si_modify_coefficient(-1, 0.0_f32), -2);
    }

    // --- T6: Direction selection ---

    #[test]
    fn t6_direction_follows_rounding_error() {
        // Positive error -> precover above -> go up
        assert_eq!(si_modify_coefficient(5, 0.3_f32), 6);
        assert_eq!(si_modify_coefficient(-5, 0.3_f32), -4); // toward zero = up

        // Negative error -> precover below -> go down
        assert_eq!(si_modify_coefficient(5, -0.3_f32), 4);
        assert_eq!(si_modify_coefficient(-5, -0.3_f32), -6); // away from zero = down

        // Zero error -> down (the else branch)
        assert_eq!(si_modify_coefficient(5, 0.0_f32), 4);
        assert_eq!(si_modify_coefficient(-5, 0.0_f32), -6);

        // |coeff| = 2 never shrinks to zero in either direction
        assert_eq!(si_modify_coefficient(2, -0.3_f32), 1);
        assert_eq!(si_modify_coefficient(-2, 0.3_f32), -1);
    }

    // --- T6b: nsF5 direction ---

    #[test]
    fn t6b_nsf5_toward_zero() {
        assert_eq!(nsf5_modify_coefficient(5), 4);
        assert_eq!(nsf5_modify_coefficient(-5), -4);
        assert_eq!(nsf5_modify_coefficient(2), 1);
        assert_eq!(nsf5_modify_coefficient(-2), -1);
        // Anti-shrinkage
        assert_eq!(nsf5_modify_coefficient(1), 2);
        assert_eq!(nsf5_modify_coefficient(-1), -2);
    }

    // --- T7: i8 encode/decode roundtrip ---

    #[test]
    fn t7_i8_encode_decode_precision() {
        // The i8 step is 1/254, so the decoded |e| is within 1/508 of the
        // original and the modulation factor (1 - 2|e|) within 2/508 (~0.004).
        for i in 0..=100 {
            let error = (i as f64 - 50.0) / 100.0; // [-0.5, +0.5]
            let encoded = encode_error(error);
            let decoded = decode_error(encoded);

            // Check the modulation factor precision
            let original_factor = 1.0 - 2.0 * error.abs();
            let decoded_factor = 1.0f32 - 2.0 * decoded.abs();
            let factor_error = (original_factor as f32 - decoded_factor).abs();
            assert!(
                factor_error < 0.005, // theoretical bound is ~0.0039
                "error={error}, encoded={encoded}, decoded={decoded}, factor_error={factor_error}"
            );
        }
    }

    #[test]
    fn t7_i8_sign_preserved() {
        // Sign must be exact for si_modify_coefficient direction
        assert!(decode_error(encode_error(0.3)) > 0.0);
        assert!(decode_error(encode_error(-0.3)) < 0.0);
        assert_eq!(decode_error(encode_error(0.0)), 0.0);
    }

    // --- T8: strip-based luma matches full computation ---

    #[test]
    fn t8_strip_luma_matches_full() {
        use crate::codec::jpeg::pixels::rgb_to_luma_blocks;

        // Create a small test image (24x16 = 3x2 blocks)
        let width = 24u32;
        let height = 16u32;
        let mut rgb = vec![0u8; (width * height * 3) as usize];
        for i in 0..rgb.len() {
            rgb[i] = ((i * 37 + 13) % 256) as u8;
        }

        let full_blocks = rgb_to_luma_blocks(&rgb, width, height);
        let luma_bw = (width as usize).div_ceil(8); // 3

        // Get strip for all rows at once
        let strip_all = rgb_to_luma_blocks_strip(&rgb, width, height, 0, 2);
        assert_eq!(strip_all.len(), full_blocks.len());
        for (i, (a, b)) in full_blocks.iter().zip(strip_all.iter()).enumerate() {
            for k in 0..64 {
                assert!(
                    (a[k] - b[k]).abs() < 1e-10,
                    "block {i}, coeff {k}: full={}, strip={}",
                    a[k], b[k]
                );
            }
        }

        // Get strips one row at a time
        let strip0 = rgb_to_luma_blocks_strip(&rgb, width, height, 0, 1);
        let strip1 = rgb_to_luma_blocks_strip(&rgb, width, height, 1, 2);
        assert_eq!(strip0.len(), luma_bw);
        assert_eq!(strip1.len(), luma_bw);
        for bc in 0..luma_bw {
            for k in 0..64 {
                assert!(
                    (full_blocks[bc][k] - strip0[bc][k]).abs() < 1e-10,
                    "row 0, block {bc}, coeff {k}"
                );
                assert!(
                    (full_blocks[luma_bw + bc][k] - strip1[bc][k]).abs() < 1e-10,
                    "row 1, block {bc}, coeff {k}"
                );
            }
        }
    }
}