Skip to main content

phasm_core/stego/armor/
embedding.rs

// Copyright (c) 2026 Christoph Gaffga
// SPDX-License-Identifier: GPL-3.0-only
// https://github.com/cgaffga/phasmcore

//! STDM (Spread Transform Dither Modulation) embedding and extraction.
//!
//! Embeds one message bit per embedding unit of `SPREAD_LEN` (L) coefficients
//! using dither-quantized projections onto spreading vectors. The quantization
//! step `delta` trades robustness against distortion.
10
11use super::spreading::SPREAD_LEN;
12use crate::codec::jpeg::zigzag::NATURAL_TO_ZIGZAG;
13
/// Quantization step used for the header region, independent of the image.
///
/// The header (the mean-QT byte at the start of the embedding stream) is
/// embedded and extracted with this fixed, deliberately large step. It has to
/// survive recompression, which it does together with its redundancy
/// (56 units, 7 copies of every bit).
pub const BOOTSTRAP_DELTA: f64 = 100.0;
20
/// Highest zigzag index (inclusive) that frequency-restricted embedding uses.
///
/// Only zigzag positions `1..=MAX_ARMOR_ZIGZAG` carry payload (position 0 is
/// skipped). Higher frequencies are left untouched to prevent pixel clamping
/// issues during recompression.
pub const MAX_ARMOR_ZIGZAG: usize = 15;
25
26/// Mean of actual QT values at zigzag positions 1..=MAX_ARMOR_ZIGZAG.
27pub fn compute_mean_qt(qt_values: &[u16; 64]) -> f64 {
28    let mut sum = 0.0f64;
29    let mut count = 0usize;
30    for nat_idx in 0..64 {
31        let zz = NATURAL_TO_ZIGZAG[nat_idx];
32        if (1..=MAX_ARMOR_ZIGZAG).contains(&zz) {
33            sum += qt_values[nat_idx] as f64;
34            count += 1;
35        }
36    }
37    if count == 0 {
38        return 10.0; // fallback
39    }
40    sum / count as f64
41}
42
/// Encode a mean QT value as the one-byte header: `round(mean_qt * 4)`
/// clamped to `1..=255`.
///
/// The byte stores the mean in quarter steps, so the representable range is
/// `0.25..=63.75` with a resolution of 0.25. Out-of-range and non-finite
/// inputs saturate: `+inf` maps to 255, and `-inf` and NaN map to 1.
///
/// The result is never 0. A zero byte would decode to a mean QT of 0.0,
/// which `compute_delta_from_mean_qt` turns into a zero quantization step.
pub fn encode_mean_qt(mean_qt: f64) -> u8 {
    let scaled = (mean_qt * 4.0).round();
    // `f64::clamp` passes NaN through unchanged and `NaN as u8` is 0, which
    // would violate the documented lower bound, so map NaN to the minimum.
    if scaled.is_nan() {
        return 1;
    }
    scaled.clamp(1.0, 255.0) as u8
}
47
/// Recover the mean QT from a header byte produced by `encode_mean_qt`.
///
/// Inverts the quarter-step encoding, i.e. returns `header_byte / 4`.
pub fn decode_mean_qt(header_byte: u8) -> f64 {
    // Multiplying by 0.25 is exact, so this equals `header_byte / 4.0`.
    f64::from(header_byte) * 0.25
}
52
/// Size of the header payload in bytes (the single mean-QT byte).
pub const HEADER_BYTES: usize = 1;

/// How many times each header bit is repeated; the copies are combined by
/// majority voting.
pub const HEADER_COPIES: usize = 7;

/// Embedding units occupied by the header: one unit per bit per copy,
/// i.e. 1 byte × 8 bits × 7 copies = 56 units.
pub const HEADER_UNITS: usize = HEADER_BYTES * 8 * HEADER_COPIES;
62
/// Derive the STDM quantization step from the mean QT value and the
/// repetition factor `r`.
///
/// The step is `mean_qt` times a multiplier that grows with `r`, giving
/// larger decision regions at higher repetition:
///
/// | `r`    | multiplier |
/// |--------|-----------:|
/// | 0 or 1 | 3.0        |
/// | 2      | 4.0        |
/// | 3–4    | 6.0        |
/// | 5–6    | 7.0        |
/// | ≥ 7    | 8.0        |
pub fn compute_delta_from_mean_qt(mean_qt: f64, r: usize) -> f64 {
    let multiplier = match r {
        // Phase 1 base multiplier (previously 2.0).
        0 | 1 => 3.0,
        2 => 4.0,
        3 | 4 => 6.0,
        5 | 6 => 7.0,
        _ => 8.0,
    };
    multiplier * mean_qt
}
79
80/// Embed a single bit into a group of coefficients using STDM.
81///
82/// - `coeffs`: L coefficient values (will be modified in place)
83/// - `v`: unit-norm spreading vector of length L
84/// - `bit`: the message bit to embed (0 or 1)
85/// - `delta`: quantization step size
86pub fn stdm_embed(coeffs: &mut [f64; SPREAD_LEN], v: &[f64; SPREAD_LEN], bit: u8, delta: f64) {
87    debug_assert!(bit <= 1);
88
89    // Project onto spreading vector
90    let p: f64 = coeffs.iter().zip(v.iter()).map(|(&c, &vi)| c * vi).sum();
91
92    // Dither-quantize to encode the bit
93    let q = quantize_for_bit(p, delta, bit);
94
95    // Distribute the change along the spreading vector
96    let dp = q - p;
97    for i in 0..SPREAD_LEN {
98        coeffs[i] += dp * v[i];
99    }
100}
101
/// Hard-decision STDM extraction of a single bit (compiled for tests only).
///
/// Projects `coeffs` onto `v`, then finds the nearest point on the combined
/// half-step lattice `k * delta / 2`. Even `k` belongs to `Q_0` (bit 0) and
/// odd `k` to `Q_1` (bit 1). `v` and `delta` must be the same as those used
/// for embedding.
///
/// Returns the extracted bit (0 or 1).
#[cfg(test)]
pub fn stdm_extract(coeffs: &[f64; SPREAD_LEN], v: &[f64; SPREAD_LEN], delta: f64) -> u8 {
    let projection: f64 = coeffs.iter().zip(v).map(|(&c, &vi)| c * vi).sum();

    let half_step = delta / 2.0;
    let nearest_half_step = (projection / half_step).round() as i64;

    // `% 2` is 0 for even indices (positive or negative), ±1 for odd ones.
    if nearest_half_step % 2 == 0 {
        0
    } else {
        1
    }
}
118
/// Snap `p` to the nearest point of the lattice selected by `bit`.
///
/// - `bit == 0`: `Q_0 = { n * delta }`
/// - any other `bit`: `Q_1 = { (n + 0.5) * delta }`
///
/// Ties are resolved by `f64::round` (half away from zero) in units of `delta`.
fn quantize_for_bit(p: f64, delta: f64, bit: u8) -> f64 {
    let in_steps = p / delta;
    match bit {
        0 => in_steps.round() * delta,
        _ => ((in_steps - 0.5).round() + 0.5) * delta,
    }
}
130
131/// Extract a bit with soft confidence (log-likelihood ratio).
132///
133/// Positive LLR → bit 0 more likely, negative → bit 1 more likely.
134/// |LLR| magnitude indicates confidence.
135pub fn stdm_extract_soft(coeffs: &[f64; SPREAD_LEN], v: &[f64; SPREAD_LEN], delta: f64) -> f64 {
136    let p: f64 = coeffs.iter().zip(v.iter()).map(|(&c, &vi)| c * vi).sum();
137
138    // Distance to nearest Q_0 lattice point
139    let q0 = (p / delta).round() * delta;
140    let d0 = (p - q0).abs();
141
142    // Distance to nearest Q_1 lattice point
143    let q1 = ((p / delta - 0.5).round() + 0.5) * delta;
144    let d1 = (p - q1).abs();
145
146    // LLR: positive favors bit 0, negative favors bit 1
147    d1 - d0
148}
149
#[cfg(test)]
mod tests {
    use super::*;

    fn make_spreading_vec() -> [f64; SPREAD_LEN] {
        // Simple normalized vector
        let raw = [1.0, 0.5, -0.3, 0.7, -0.2, 0.4, 0.6, -0.1];
        let norm: f64 = raw.iter().map(|x| x * x).sum::<f64>().sqrt();
        let mut v = [0.0; SPREAD_LEN];
        for i in 0..SPREAD_LEN {
            v[i] = raw[i] / norm;
        }
        v
    }

    /// Projection of `coeffs` onto `v` — the scalar that STDM quantizes.
    fn project(coeffs: &[f64; SPREAD_LEN], v: &[f64; SPREAD_LEN]) -> f64 {
        coeffs.iter().zip(v).map(|(&c, &vi)| c * vi).sum()
    }

    #[test]
    fn embed_extract_roundtrip_bit0() {
        let v = make_spreading_vec();
        let delta = 10.0;
        let mut coeffs = [20.0, 15.0, -8.0, 30.0, -5.0, 10.0, 25.0, -3.0];

        stdm_embed(&mut coeffs, &v, 0, delta);
        let extracted = stdm_extract(&coeffs, &v, delta);
        assert_eq!(extracted, 0);
    }

    #[test]
    fn embed_extract_roundtrip_bit1() {
        let v = make_spreading_vec();
        let delta = 10.0;
        let mut coeffs = [20.0, 15.0, -8.0, 30.0, -5.0, 10.0, 25.0, -3.0];

        stdm_embed(&mut coeffs, &v, 1, delta);
        let extracted = stdm_extract(&coeffs, &v, delta);
        assert_eq!(extracted, 1);
    }

    #[test]
    fn embed_extract_many_bits() {
        let v = make_spreading_vec();
        let delta = 8.0;

        for bit in 0..=1 {
            for base in [-50.0, -10.0, 0.0, 10.0, 50.0] {
                let mut coeffs = [base; SPREAD_LEN];
                stdm_embed(&mut coeffs, &v, bit, delta);
                let extracted = stdm_extract(&coeffs, &v, delta);
                assert_eq!(extracted, bit, "failed for bit={bit}, base={base}");
            }
        }
    }

    #[test]
    fn survives_small_perturbation() {
        let v = make_spreading_vec();
        let delta = 16.0; // Large delta for more robustness

        for bit in 0..=1 {
            let mut coeffs = [20.0, -10.0, 5.0, 30.0, -15.0, 8.0, 12.0, -6.0];
            stdm_embed(&mut coeffs, &v, bit, delta);

            // Add small perturbation (simulating quantization noise)
            for c in coeffs.iter_mut() {
                *c += 0.3;
            }

            let extracted = stdm_extract(&coeffs, &v, delta);
            assert_eq!(extracted, bit, "failed for bit={bit} after perturbation");
        }
    }

    #[test]
    fn quantize_for_bit_correct() {
        let delta = 10.0;

        // For bit=0, should quantize to nearest multiple of delta
        assert!((quantize_for_bit(7.0, delta, 0) - 10.0).abs() < 1e-10);
        assert!((quantize_for_bit(3.0, delta, 0) - 0.0).abs() < 1e-10);
        assert!((quantize_for_bit(-7.0, delta, 0) - -10.0).abs() < 1e-10);

        // For bit=1, should quantize to nearest half-multiple of delta
        assert!((quantize_for_bit(3.0, delta, 1) - 5.0).abs() < 1e-10);
        assert!((quantize_for_bit(8.0, delta, 1) - 5.0).abs() < 1e-10);
        assert!((quantize_for_bit(12.0, delta, 1) - 15.0).abs() < 1e-10);
    }

    #[test]
    fn quantize_for_bit_negative_projections() {
        let delta = 10.0;

        assert!((quantize_for_bit(-14.0, delta, 0) - -10.0).abs() < 1e-10);
        assert!((quantize_for_bit(-3.0, delta, 1) - -5.0).abs() < 1e-10);
        assert!((quantize_for_bit(-12.0, delta, 1) - -15.0).abs() < 1e-10);
    }

    #[test]
    fn compute_mean_qt_reasonable() {
        // Standard luma QT at QF 75 (approximate)
        let qt = [8, 6, 5, 8, 12, 20, 26, 31,
                  6, 6, 7, 10, 13, 29, 30, 28,
                  7, 7, 8, 12, 20, 29, 35, 28,
                  7, 9, 11, 15, 26, 44, 40, 31,
                  9, 11, 19, 28, 34, 55, 52, 39,
                  12, 18, 28, 32, 41, 52, 57, 46,
                  25, 32, 39, 44, 52, 61, 60, 51,
                  36, 46, 48, 49, 56, 50, 52, 50];
        let mean = compute_mean_qt(&qt);
        // Mean of low-freq AC positions should be reasonable
        assert!(mean > 5.0 && mean < 30.0, "mean_qt={mean}");
    }

    #[test]
    fn mean_qt_encode_decode_roundtrip() {
        for qt_val in [5.0, 10.0, 15.5, 25.0, 50.0, 63.0] {
            let encoded = encode_mean_qt(qt_val);
            let decoded = decode_mean_qt(encoded);
            assert!((decoded - qt_val).abs() < 0.5, "roundtrip failed: {qt_val} -> {encoded} -> {decoded}");
        }
    }

    #[test]
    fn encode_mean_qt_clamps_to_byte_range() {
        // Lower bound: the header byte is never 0 (0 would decode to a zero delta).
        assert_eq!(encode_mean_qt(0.0), 1, "zero must clamp up to 1");
        assert_eq!(encode_mean_qt(-5.0), 1, "negative must clamp up to 1");
        assert_eq!(encode_mean_qt(0.25), 1);
        assert_eq!(encode_mean_qt(0.5), 2);
        // Upper bound: 63.75 is the largest representable mean.
        assert_eq!(encode_mean_qt(63.75), 255);
        assert_eq!(encode_mean_qt(100.0), 255, "large values must saturate at 255");
    }

    #[test]
    fn embed_moves_projection_by_at_most_half_delta() {
        let v = make_spreading_vec();
        let delta = 12.0;

        for bit in 0..=1u8 {
            for base in [-37.0, -4.5, 0.0, 3.3, 41.0] {
                let mut coeffs = [base; SPREAD_LEN];
                let before = project(&coeffs, &v);
                stdm_embed(&mut coeffs, &v, bit, delta);
                let moved = project(&coeffs, &v) - before;
                assert!(
                    moved.abs() <= delta / 2.0 + 1e-9,
                    "bit={bit}, base={base}: projection moved by {moved}"
                );
            }
        }
    }

    #[test]
    fn soft_extract_sign_matches_hard_extract() {
        let v = make_spreading_vec();
        let delta = 10.0;

        for bit in 0..=1 {
            let mut coeffs = [20.0, 15.0, -8.0, 30.0, -5.0, 10.0, 25.0, -3.0];
            stdm_embed(&mut coeffs, &v, bit, delta);

            let llr = stdm_extract_soft(&coeffs, &v, delta);
            let hard_bit = stdm_extract(&coeffs, &v, delta);

            // Positive LLR → bit 0, negative → bit 1
            let soft_bit = if llr >= 0.0 { 0u8 } else { 1u8 };
            assert_eq!(soft_bit, hard_bit, "bit={bit}, llr={llr}");
            assert_eq!(soft_bit, bit, "bit={bit}, llr={llr}");
        }
    }

    #[test]
    fn soft_extract_magnitude_bounded_by_half_delta() {
        let v = make_spreading_vec();
        let delta = 10.0;

        // Right after embedding the projection sits on a lattice point, so the
        // confidence is maximal: |LLR| == delta / 2.
        for bit in 0..=1u8 {
            let mut coeffs = [20.0, 15.0, -8.0, 30.0, -5.0, 10.0, 25.0, -3.0];
            stdm_embed(&mut coeffs, &v, bit, delta);
            let llr = stdm_extract_soft(&coeffs, &v, delta);
            assert!((llr.abs() - delta / 2.0).abs() < 1e-9, "bit={bit}, llr={llr}");
        }

        // For arbitrary coefficients the magnitude never exceeds delta / 2.
        for base in [-33.0, -7.25, 0.0, 1.5, 6.0, 48.9] {
            let coeffs = [base; SPREAD_LEN];
            let llr = stdm_extract_soft(&coeffs, &v, delta);
            assert!(llr.abs() <= delta / 2.0 + 1e-9, "base={base}, llr={llr}");
        }
    }

    #[test]
    fn soft_extract_confidence_decreases_with_noise() {
        let v = make_spreading_vec();
        let delta = 16.0;
        let mut coeffs = [20.0, -10.0, 5.0, 30.0, -15.0, 8.0, 12.0, -6.0];
        stdm_embed(&mut coeffs, &v, 0, delta);

        let llr_clean = stdm_extract_soft(&coeffs, &v, delta);
        assert!(llr_clean > 0.0, "should favor bit 0");

        // Add noise: shifts the projection by 2 * sum(v) ≈ 3.36, which stays
        // inside the bit-0 decision region (< delta / 4 = 4).
        let mut noisy = coeffs;
        for c in noisy.iter_mut() {
            *c += 2.0;
        }
        let llr_noisy = stdm_extract_soft(&noisy, &v, delta);
        assert!(llr_noisy > 0.0, "moderate noise should still favor bit 0, llr={llr_noisy}");
        // The clean score is already at the delta / 2 maximum, so noise cannot raise it.
        assert!(
            llr_noisy.abs() <= llr_clean.abs() + 1e-9,
            "noise should not increase confidence: clean={llr_clean}, noisy={llr_noisy}"
        );
    }

    #[test]
    fn header_units_constant_correct() {
        assert_eq!(HEADER_UNITS, HEADER_BYTES * 8 * HEADER_COPIES);
        assert_eq!(HEADER_UNITS, 56);
    }

    #[test]
    fn delta_increases_with_r() {
        let mean_qt = 10.0;
        let d1 = compute_delta_from_mean_qt(mean_qt, 1);
        let d2 = compute_delta_from_mean_qt(mean_qt, 2);
        let d3 = compute_delta_from_mean_qt(mean_qt, 3);
        let d5 = compute_delta_from_mean_qt(mean_qt, 5);
        let d7 = compute_delta_from_mean_qt(mean_qt, 7);

        assert!(d2 > d1, "r=2 should increase delta");
        assert!(d3 > d2, "r=3 should increase delta more");
        assert!(d5 > d3, "r=5 should increase delta further");
        assert!(d7 > d5, "r=7 should increase delta even more");

        // Verify exact multipliers
        assert!((d1 - 30.0).abs() < 1e-10, "r=1: 3.0 * 10.0 = 30.0, got {d1}");
        assert!((d2 - 40.0).abs() < 1e-10, "r=2: 4.0 * 10.0 = 40.0, got {d2}");
        assert!((d3 - 60.0).abs() < 1e-10, "r=3: 6.0 * 10.0 = 60.0, got {d3}");
        assert!((d5 - 70.0).abs() < 1e-10, "r=5: 7.0 * 10.0 = 70.0, got {d5}");
        assert!((d7 - 80.0).abs() < 1e-10, "r=7: 8.0 * 10.0 = 80.0, got {d7}");
    }
}