1use super::spreading::SPREAD_LEN;
12use crate::codec::jpeg::zigzag::NATURAL_TO_ZIGZAG;
13
/// A fixed quantization step (`delta`) for the STDM routines in this module.
///
/// NOTE(review): not referenced in this part of the file; presumably used
/// before the mean-QT header has been recovered (e.g. to extract the header
/// itself) — TODO confirm against callers.
pub const BOOTSTRAP_DELTA: f64 = 100.0;

/// Highest zig-zag index (inclusive) whose quantization-table entry is
/// averaged by `compute_mean_qt`; zig-zag index 0 (the DC entry) is excluded.
pub const MAX_ARMOR_ZIGZAG: usize = 15;
25
26pub fn compute_mean_qt(qt_values: &[u16; 64]) -> f64 {
28 let mut sum = 0.0f64;
29 let mut count = 0usize;
30 for nat_idx in 0..64 {
31 let zz = NATURAL_TO_ZIGZAG[nat_idx];
32 if (1..=MAX_ARMOR_ZIGZAG).contains(&zz) {
33 sum += qt_values[nat_idx] as f64;
34 count += 1;
35 }
36 }
37 if count == 0 {
38 return 10.0; }
40 sum / count as f64
41}
42
/// Quantizes a mean QT value into a one-byte header, in steps of 0.25.
///
/// The value is rounded to the nearest quarter (ties away from zero) and
/// clamped, so the result is always in `1..=255`. The decoded range is
/// therefore `0.25..=63.75`.
///
/// A NaN input maps to `1`. Without the explicit check, `f64::clamp` would
/// return the NaN unchanged, and the saturating `as u8` cast would turn it
/// into `0`. That would break the `1..=255` guarantee and decode to a mean
/// QT of zero.
pub fn encode_mean_qt(mean_qt: f64) -> u8 {
    let quarters = (mean_qt * 4.0).round();
    if quarters.is_nan() {
        return 1;
    }
    quarters.clamp(1.0, 255.0) as u8
}
47
/// Converts a header byte back into a mean QT value.
///
/// The byte counts quarter units, so the result lies in `0.0..=63.75`. This
/// is the inverse of `encode_mean_qt` for every byte it can produce.
pub fn decode_mean_qt(header_byte: u8) -> f64 {
    let quarters = f64::from(header_byte);
    quarters / 4.0
}
52
/// Size of the header payload in bytes.
///
/// NOTE(review): presumably the single byte produced by `encode_mean_qt`.
pub const HEADER_BYTES: usize = 1;

/// Number of times the header bits are repeated.
pub const HEADER_COPIES: usize = 7;

/// Total embedding units reserved for the header: header bits
/// (`HEADER_BYTES * 8`) times the repetition count, i.e. `1 * 8 * 7 = 56`.
pub const HEADER_UNITS: usize = HEADER_BYTES * 8 * HEADER_COPIES;
62
/// Derives the STDM quantization step from a mean QT value.
///
/// The result is `mean_qt` times a multiplier that grows with `r`:
/// - 3.0 for `r <= 1`
/// - 4.0 for `r == 2`
/// - 6.0 for `r` in `3..=4`
/// - 7.0 for `r` in `5..=6`
/// - 8.0 for `r >= 7`
///
/// NOTE(review): `r` looks like a repetition/redundancy factor — confirm
/// against callers.
pub fn compute_delta_from_mean_qt(mean_qt: f64, r: usize) -> f64 {
    let multiplier = match r {
        0..=1 => 3.0,
        2 => 4.0,
        3..=4 => 6.0,
        5..=6 => 7.0,
        _ => 8.0,
    };
    multiplier * mean_qt
}
79
80pub fn stdm_embed(coeffs: &mut [f64; SPREAD_LEN], v: &[f64; SPREAD_LEN], bit: u8, delta: f64) {
87 debug_assert!(bit <= 1);
88
89 let p: f64 = coeffs.iter().zip(v.iter()).map(|(&c, &vi)| c * vi).sum();
91
92 let q = quantize_for_bit(p, delta, bit);
94
95 let dp = q - p;
97 for i in 0..SPREAD_LEN {
98 coeffs[i] += dp * v[i];
99 }
100}
101
/// Hard-decision STDM extraction: returns the bit (0 or 1) whose lattice
/// point is nearest to the projection `<coeffs, v>`.
///
/// Together, the two lattices form one lattice with step `delta / 2`. Even
/// multiples of `delta / 2` encode 0 and odd multiples encode 1, so the bit
/// is the parity of the rounded half-step index.
#[cfg(test)]
pub fn stdm_extract(coeffs: &[f64; SPREAD_LEN], v: &[f64; SPREAD_LEN], delta: f64) -> u8 {
    let projection: f64 = coeffs.iter().zip(v).map(|(c, vi)| c * vi).sum();
    let half_step_index = (projection / (delta / 2.0)).round() as i64;
    // In two's complement, `& 1` gives the same non-negative parity as
    // `rem_euclid(2)`, including for negative indices.
    (half_step_index & 1) as u8
}
118
/// Snaps `p` to the nearest point of the lattice that encodes `bit`.
///
/// Bit 0 uses the lattice `k * delta`. Any other value of `bit` uses the
/// offset lattice `(k + 0.5) * delta`. Ties follow `f64::round`, which
/// rounds half away from zero.
fn quantize_for_bit(p: f64, delta: f64, bit: u8) -> f64 {
    let scaled = p / delta;
    let lattice_index = match bit {
        0 => scaled.round(),
        _ => (scaled - 0.5).round() + 0.5,
    };
    lattice_index * delta
}
130
131pub fn stdm_extract_soft(coeffs: &[f64; SPREAD_LEN], v: &[f64; SPREAD_LEN], delta: f64) -> f64 {
136 let p: f64 = coeffs.iter().zip(v.iter()).map(|(&c, &vi)| c * vi).sum();
137
138 let q0 = (p / delta).round() * delta;
140 let d0 = (p - q0).abs();
141
142 let q1 = ((p / delta - 0.5).round() + 0.5) * delta;
144 let d1 = (p - q1).abs();
145
146 d1 - d0
148}
149
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed spreading vector scaled to unit length, shared by the STDM tests.
    fn make_spreading_vec() -> [f64; SPREAD_LEN] {
        let raw = [1.0, 0.5, -0.3, 0.7, -0.2, 0.4, 0.6, -0.1];
        let norm: f64 = raw.iter().map(|x| x * x).sum::<f64>().sqrt();
        let mut v = [0.0; SPREAD_LEN];
        for (dst, &x) in v.iter_mut().zip(raw.iter()) {
            *dst = x / norm;
        }
        v
    }

    #[test]
    fn embed_extract_roundtrip_bit0() {
        let v = make_spreading_vec();
        let delta = 10.0;
        let mut coeffs = [20.0, 15.0, -8.0, 30.0, -5.0, 10.0, 25.0, -3.0];

        stdm_embed(&mut coeffs, &v, 0, delta);
        let extracted = stdm_extract(&coeffs, &v, delta);
        assert_eq!(extracted, 0);
    }

    #[test]
    fn embed_extract_roundtrip_bit1() {
        let v = make_spreading_vec();
        let delta = 10.0;
        let mut coeffs = [20.0, 15.0, -8.0, 30.0, -5.0, 10.0, 25.0, -3.0];

        stdm_embed(&mut coeffs, &v, 1, delta);
        let extracted = stdm_extract(&coeffs, &v, delta);
        assert_eq!(extracted, 1);
    }

    #[test]
    fn embed_extract_many_bits() {
        let v = make_spreading_vec();
        let delta = 8.0;

        for bit in 0..=1 {
            for base in [-50.0, -10.0, 0.0, 10.0, 50.0] {
                let mut coeffs = [base; SPREAD_LEN];
                stdm_embed(&mut coeffs, &v, bit, delta);
                let extracted = stdm_extract(&coeffs, &v, delta);
                assert_eq!(extracted, bit, "failed for bit={bit}, base={base}");
            }
        }
    }

    #[test]
    fn survives_small_perturbation() {
        let v = make_spreading_vec();
        let delta = 16.0;
        for bit in 0..=1 {
            let mut coeffs = [20.0, -10.0, 5.0, 30.0, -15.0, 8.0, 12.0, -6.0];
            stdm_embed(&mut coeffs, &v, bit, delta);

            // A uniform +0.3 shift moves the projection by 0.3 * sum(v)
            // (about 0.5), well inside the delta / 4 decision margin.
            for c in coeffs.iter_mut() {
                *c += 0.3;
            }

            let extracted = stdm_extract(&coeffs, &v, delta);
            assert_eq!(extracted, bit, "failed for bit={bit} after perturbation");
        }
    }

    #[test]
    fn quantize_for_bit_correct() {
        let delta = 10.0;

        // Bit 0: nearest multiple of delta.
        assert!((quantize_for_bit(7.0, delta, 0) - 10.0).abs() < 1e-10);
        assert!((quantize_for_bit(3.0, delta, 0) - 0.0).abs() < 1e-10);
        assert!((quantize_for_bit(-7.0, delta, 0) - -10.0).abs() < 1e-10);

        // Bit 1: nearest multiple of delta offset by delta / 2.
        assert!((quantize_for_bit(3.0, delta, 1) - 5.0).abs() < 1e-10);
        assert!((quantize_for_bit(8.0, delta, 1) - 5.0).abs() < 1e-10);
        assert!((quantize_for_bit(12.0, delta, 1) - 15.0).abs() < 1e-10);
        assert!((quantize_for_bit(-3.0, delta, 1) - -5.0).abs() < 1e-10);
    }

    #[test]
    fn compute_mean_qt_reasonable() {
        let qt = [8, 6, 5, 8, 12, 20, 26, 31,
                  6, 6, 7, 10, 13, 29, 30, 28,
                  7, 7, 8, 12, 20, 29, 35, 28,
                  7, 9, 11, 15, 26, 44, 40, 31,
                  9, 11, 19, 28, 34, 55, 52, 39,
                  12, 18, 28, 32, 41, 52, 57, 46,
                  25, 32, 39, 44, 52, 61, 60, 51,
                  36, 46, 48, 49, 56, 50, 52, 50];
        let mean = compute_mean_qt(&qt);
        assert!(mean > 5.0 && mean < 30.0, "mean_qt={mean}");
    }

    #[test]
    fn compute_mean_qt_uniform_table_is_exact() {
        // With every entry equal, the selected positions must average exactly.
        assert_eq!(compute_mean_qt(&[7u16; 64]), 7.0);
    }

    #[test]
    fn mean_qt_encode_decode_roundtrip() {
        for qt_val in [5.0, 10.0, 15.5, 25.0, 50.0, 63.0] {
            let encoded = encode_mean_qt(qt_val);
            let decoded = decode_mean_qt(encoded);
            // Quarter-step quantization can be off by at most half a step.
            assert!(
                (decoded - qt_val).abs() <= 0.125,
                "roundtrip failed: {qt_val} -> {encoded} -> {decoded}"
            );
        }
    }

    #[test]
    fn soft_extract_sign_matches_hard_extract() {
        let v = make_spreading_vec();
        let delta = 10.0;

        for bit in 0..=1 {
            let mut coeffs = [20.0, 15.0, -8.0, 30.0, -5.0, 10.0, 25.0, -3.0];
            stdm_embed(&mut coeffs, &v, bit, delta);

            let llr = stdm_extract_soft(&coeffs, &v, delta);
            let hard_bit = stdm_extract(&coeffs, &v, delta);

            // Positive soft output favors bit 0.
            let soft_bit = if llr >= 0.0 { 0u8 } else { 1u8 };
            assert_eq!(soft_bit, hard_bit, "bit={bit}, llr={llr}");
            assert_eq!(soft_bit, bit, "bit={bit}, llr={llr}");
        }
    }

    #[test]
    fn soft_extract_confidence_decreases_with_noise() {
        let v = make_spreading_vec();
        let delta = 16.0;
        let mut coeffs = [20.0, -10.0, 5.0, 30.0, -15.0, 8.0, 12.0, -6.0];
        stdm_embed(&mut coeffs, &v, 0, delta);

        let llr_clean = stdm_extract_soft(&coeffs, &v, delta);
        assert!(llr_clean > 0.0, "should favor bit 0");

        // Shifting every coefficient by +2.0 moves the projection by
        // 2.0 * sum(v) (about 3.36). That is below delta / 4 = 4.0, so the
        // decision must stay bit 0, but with strictly lower confidence.
        let mut noisy = coeffs;
        for c in noisy.iter_mut() {
            *c += 2.0;
        }
        let llr_noisy = stdm_extract_soft(&noisy, &v, delta);
        assert!(llr_noisy > 0.0, "moderate noise should not flip the bit, llr={llr_noisy}");
        assert!(
            llr_noisy < llr_clean,
            "noise should reduce confidence: clean={llr_clean}, noisy={llr_noisy}"
        );
    }

    #[test]
    fn header_units_constant_correct() {
        assert_eq!(HEADER_UNITS, HEADER_BYTES * 8 * HEADER_COPIES);
        assert_eq!(HEADER_UNITS, 56);
    }

    #[test]
    fn delta_increases_with_r() {
        let mean_qt = 10.0;
        let d1 = compute_delta_from_mean_qt(mean_qt, 1);
        let d2 = compute_delta_from_mean_qt(mean_qt, 2);
        let d3 = compute_delta_from_mean_qt(mean_qt, 3);
        let d5 = compute_delta_from_mean_qt(mean_qt, 5);
        let d7 = compute_delta_from_mean_qt(mean_qt, 7);

        assert!(d2 > d1, "r=2 should increase delta");
        assert!(d3 > d2, "r=3 should increase delta more");
        assert!(d5 > d3, "r=5 should increase delta further");
        assert!(d7 > d5, "r=7 should increase delta even more");

        assert!((d1 - 30.0).abs() < 1e-10, "r=1: 3.0 * 10.0 = 30.0, got {d1}");
        assert!((d2 - 40.0).abs() < 1e-10, "r=2: 4.0 * 10.0 = 40.0, got {d2}");
        assert!((d3 - 60.0).abs() < 1e-10, "r=3: 6.0 * 10.0 = 60.0, got {d3}");
        assert!((d5 - 70.0).abs() < 1e-10, "r=5: 7.0 * 10.0 = 70.0, got {d5}");
        assert!((d7 - 80.0).abs() < 1e-10, "r=7: 8.0 * 10.0 = 80.0, got {d7}");
    }
}
327}