use super::spreading::SPREAD_LEN;
use crate::codec::jpeg::zigzag::NATURAL_TO_ZIGZAG;
/// Fixed quantization step of `100.0`.
///
/// NOTE(review): not referenced in this part of the file; presumably the step
/// used before a per-image `mean_qt` is available — confirm with callers.
pub const BOOTSTRAP_DELTA: f64 = 100.0;

/// Highest zigzag position (inclusive) that [`compute_mean_qt`] averages over.
///
/// Position 0 (DC) is always excluded, so the averaged range is `1..=15`.
pub const MAX_ARMOR_ZIGZAG: usize = 15;
/// Averages the quantization-table entries at zigzag positions `1..=MAX_ARMOR_ZIGZAG`.
///
/// `qt_values` is indexed in natural order. Each natural index is mapped to its zigzag
/// position through `NATURAL_TO_ZIGZAG`. Only entries whose zigzag position lies in
/// `1..=MAX_ARMOR_ZIGZAG` contribute, so the DC entry (zigzag 0) and everything past
/// `MAX_ARMOR_ZIGZAG` are ignored.
///
/// Returns `10.0` when no entry falls in that range.
pub fn compute_mean_qt(qt_values: &[u16; 64]) -> f64 {
    // Fallback mean used when no table entry is selected.
    const FALLBACK_MEAN_QT: f64 = 10.0;

    let (total, selected) = qt_values
        .iter()
        .enumerate()
        .filter(|&(nat_idx, _)| (1..=MAX_ARMOR_ZIGZAG).contains(&NATURAL_TO_ZIGZAG[nat_idx]))
        .fold((0.0f64, 0usize), |(acc, n), (_, &q)| (acc + q as f64, n + 1));

    // Only reachable if NATURAL_TO_ZIGZAG never maps into 1..=MAX_ARMOR_ZIGZAG.
    if selected == 0 {
        return FALLBACK_MEAN_QT;
    }
    total / selected as f64
}
/// Encodes a mean quantization-table value into a single header byte.
///
/// The value is stored in fixed point with a resolution of 1/4, so the code is
/// `round(mean_qt * 4)`. Rounding is to nearest, with ties away from zero. The result is
/// clamped to `1..=255`, which makes the representable range `0.25..=63.75`. Larger
/// means saturate to 255, and zero, negative, or `-inf` inputs map to 1. The inverse is
/// `decode_mean_qt`.
///
/// A NaN input also maps to 1. `f64::clamp` returns NaN unchanged and `NaN as u8` is 0,
/// so without the explicit check a NaN would escape the `1..=255` range. It would then
/// decode to a mean of 0.0, which yields a zero quantization step.
pub fn encode_mean_qt(mean_qt: f64) -> u8 {
    let scaled = (mean_qt * 4.0).round();
    if scaled.is_nan() {
        return 1;
    }
    scaled.clamp(1.0, 255.0) as u8
}
/// Decodes a header byte back into a mean quantization-table value.
///
/// This is the inverse of `encode_mean_qt`: the byte is a fixed-point number with
/// a resolution of 1/4, so the result is `header_byte / 4` and lies in `0.0..=63.75`.
/// A byte of 0 decodes to `0.0`. `encode_mean_qt` clamps finite inputs to at least 1,
/// so it never produces that byte for them.
pub fn decode_mean_qt(header_byte: u8) -> f64 {
    // Multiplying by 0.25 is exact, like dividing by 4.0, since it is a power of two.
    f64::from(header_byte) * 0.25
}
/// Payload size of the header in bytes.
///
/// NOTE(review): presumably the single mean-QT byte handled by `encode_mean_qt` /
/// `decode_mean_qt` — confirm with the header writer.
pub const HEADER_BYTES: usize = 1;

/// Number of copies of the header that are embedded (redundancy factor).
pub const HEADER_COPIES: usize = 7;

/// Total header units: one per header bit per copy (`1 * 8 * 7 = 56`).
pub const HEADER_UNITS: usize = HEADER_BYTES * 8 * HEADER_COPIES;
/// Derives the STDM quantization step from the mean quantization-table value.
///
/// The step is `mean_qt` scaled by a multiplier that grows in steps with `r`:
///
/// | `r`   | multiplier |
/// |-------|------------|
/// | 0, 1  | 3          |
/// | 2     | 4          |
/// | 3, 4  | 6          |
/// | 5, 6  | 7          |
/// | >= 7  | 8          |
///
/// NOTE(review): `r` looks like a repetition/redundancy factor chosen by the caller —
/// confirm with callers.
pub fn compute_delta_from_mean_qt(mean_qt: f64, r: usize) -> f64 {
    let multiplier = match r {
        0 | 1 => 3.0,
        2 => 4.0,
        3 | 4 => 6.0,
        5 | 6 => 7.0,
        _ => 8.0,
    };
    multiplier * mean_qt
}
/// Embeds one bit into `coeffs` using spread-transform dither modulation.
///
/// The coefficients are projected onto `v`. The projection is snapped to the nearest
/// point of the lattice coset for `bit`: multiples of `delta` for 0, multiples offset
/// by `delta / 2` for 1. The coefficients are then shifted along `v` by the difference,
/// so only the component along `v` changes.
///
/// The resulting projection equals the target lattice point only when `v` has unit L2
/// norm. The tests normalize `v`; callers are assumed to do the same.
///
/// `bit` must be 0 or 1, which is checked in debug builds only. In release builds any
/// nonzero value behaves like 1.
pub fn stdm_embed(coeffs: &mut [f64; SPREAD_LEN], v: &[f64; SPREAD_LEN], bit: u8, delta: f64) {
    debug_assert!(bit <= 1);

    let projection: f64 = coeffs.iter().zip(v).map(|(&c, &vi)| c * vi).sum();
    let shift = quantize_for_bit(projection, delta, bit) - projection;

    // Move the coefficients along the spreading direction only.
    for (c, &vi) in coeffs.iter_mut().zip(v) {
        *c += shift * vi;
    }
}
/// Hard-decision extraction of one STDM bit (test builds only).
///
/// The two cosets used by `stdm_embed` interleave into a single lattice with spacing
/// `delta / 2`. Even lattice indices are bit-0 points (multiples of `delta`) and odd
/// indices are bit-1 points. The projection onto `v` is rounded to the nearest lattice
/// index, and that index's parity is returned.
#[cfg(test)]
pub fn stdm_extract(coeffs: &[f64; SPREAD_LEN], v: &[f64; SPREAD_LEN], delta: f64) -> u8 {
    let projection: f64 = coeffs.iter().zip(v).map(|(&c, &vi)| c * vi).sum();
    let step = delta / 2.0;
    let lattice_index = (projection / step).round() as i64;
    // rem_euclid keeps the parity in {0, 1} for negative indices too.
    lattice_index.rem_euclid(2) as u8
}
/// Snaps `p` to the nearest point of the lattice coset selected by `bit`.
///
/// For bit 0 the coset is the integer multiples of `delta`. For any other value it is
/// those multiples offset by `delta / 2`. Rounding uses `f64::round`, which breaks ties
/// away from zero.
fn quantize_for_bit(p: f64, delta: f64, bit: u8) -> f64 {
    let cells = p / delta;
    match bit {
        0 => cells.round() * delta,
        // Shift by half a cell, round, then shift back to land on the offset coset.
        _ => ((cells - 0.5).round() + 0.5) * delta,
    }
}
pub fn stdm_extract_soft(coeffs: &[f64; SPREAD_LEN], v: &[f64; SPREAD_LEN], delta: f64) -> f64 {
let p: f64 = coeffs.iter().zip(v.iter()).map(|(&c, &vi)| c * vi).sum();
let q0 = (p / delta).round() * delta;
let d0 = (p - q0).abs();
let q1 = ((p / delta - 0.5).round() + 0.5) * delta;
let d1 = (p - q1).abs();
d1 - d0
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fixed spreading vector normalized to unit L2 norm.
    fn make_spreading_vec() -> [f64; SPREAD_LEN] {
        let raw = [1.0, 0.5, -0.3, 0.7, -0.2, 0.4, 0.6, -0.1];
        let norm: f64 = raw.iter().map(|x| x * x).sum::<f64>().sqrt();
        let mut v = [0.0; SPREAD_LEN];
        for i in 0..SPREAD_LEN {
            v[i] = raw[i] / norm;
        }
        v
    }

    #[test]
    fn embed_extract_roundtrip_bit0() {
        let v = make_spreading_vec();
        let delta = 10.0;
        let mut coeffs = [20.0, 15.0, -8.0, 30.0, -5.0, 10.0, 25.0, -3.0];
        stdm_embed(&mut coeffs, &v, 0, delta);
        let extracted = stdm_extract(&coeffs, &v, delta);
        assert_eq!(extracted, 0);
    }

    #[test]
    fn embed_extract_roundtrip_bit1() {
        let v = make_spreading_vec();
        let delta = 10.0;
        let mut coeffs = [20.0, 15.0, -8.0, 30.0, -5.0, 10.0, 25.0, -3.0];
        stdm_embed(&mut coeffs, &v, 1, delta);
        let extracted = stdm_extract(&coeffs, &v, delta);
        assert_eq!(extracted, 1);
    }

    #[test]
    fn embed_extract_many_bits() {
        let v = make_spreading_vec();
        let delta = 8.0;
        for bit in 0..=1 {
            for base in [-50.0, -10.0, 0.0, 10.0, 50.0] {
                let mut coeffs = [base; SPREAD_LEN];
                stdm_embed(&mut coeffs, &v, bit, delta);
                let extracted = stdm_extract(&coeffs, &v, delta);
                assert_eq!(extracted, bit, "failed for bit={bit}, base={base}");
            }
        }
    }

    #[test]
    fn survives_small_perturbation() {
        let v = make_spreading_vec();
        let delta = 16.0;
        for bit in 0..=1 {
            let mut coeffs = [20.0, -10.0, 5.0, 30.0, -15.0, 8.0, 12.0, -6.0];
            stdm_embed(&mut coeffs, &v, bit, delta);
            for c in coeffs.iter_mut() {
                *c += 0.3;
            }
            let extracted = stdm_extract(&coeffs, &v, delta);
            assert_eq!(extracted, bit, "failed for bit={bit} after perturbation");
        }
    }

    #[test]
    fn quantize_for_bit_correct() {
        let delta = 10.0;
        assert!((quantize_for_bit(7.0, delta, 0) - 10.0).abs() < 1e-10);
        assert!((quantize_for_bit(3.0, delta, 0) - 0.0).abs() < 1e-10);
        assert!((quantize_for_bit(-7.0, delta, 0) - -10.0).abs() < 1e-10);
        assert!((quantize_for_bit(3.0, delta, 1) - 5.0).abs() < 1e-10);
        assert!((quantize_for_bit(8.0, delta, 1) - 5.0).abs() < 1e-10);
        assert!((quantize_for_bit(12.0, delta, 1) - 15.0).abs() < 1e-10);
    }

    #[test]
    fn compute_mean_qt_reasonable() {
        let qt = [8, 6, 5, 8, 12, 20, 26, 31,
                  6, 6, 7, 10, 13, 29, 30, 28,
                  7, 7, 8, 12, 20, 29, 35, 28,
                  7, 9, 11, 15, 26, 44, 40, 31,
                  9, 11, 19, 28, 34, 55, 52, 39,
                  12, 18, 28, 32, 41, 52, 57, 46,
                  25, 32, 39, 44, 52, 61, 60, 51,
                  36, 46, 48, 49, 56, 50, 52, 50];
        let mean = compute_mean_qt(&qt);
        assert!(mean > 5.0 && mean < 30.0, "mean_qt={mean}");
    }

    #[test]
    fn mean_qt_encode_decode_roundtrip() {
        for qt_val in [5.0, 10.0, 15.5, 25.0, 50.0, 63.0] {
            let encoded = encode_mean_qt(qt_val);
            let decoded = decode_mean_qt(encoded);
            assert!((decoded - qt_val).abs() < 0.5, "roundtrip failed: {qt_val} -> {encoded} -> {decoded}");
        }
    }

    #[test]
    fn mean_qt_encode_clamps_to_byte_range() {
        // Non-positive means clamp to the smallest code, large means saturate at 255.
        assert_eq!(encode_mean_qt(0.0), 1);
        assert_eq!(encode_mean_qt(-3.0), 1);
        assert_eq!(encode_mean_qt(1000.0), 255);
        assert!((decode_mean_qt(255) - 63.75).abs() < 1e-12);
    }

    #[test]
    fn soft_extract_sign_matches_hard_extract() {
        let v = make_spreading_vec();
        let delta = 10.0;
        for bit in 0..=1 {
            let mut coeffs = [20.0, 15.0, -8.0, 30.0, -5.0, 10.0, 25.0, -3.0];
            stdm_embed(&mut coeffs, &v, bit, delta);
            let llr = stdm_extract_soft(&coeffs, &v, delta);
            let hard_bit = stdm_extract(&coeffs, &v, delta);
            let soft_bit = if llr >= 0.0 { 0u8 } else { 1u8 };
            assert_eq!(soft_bit, hard_bit, "bit={bit}, llr={llr}");
            assert_eq!(soft_bit, bit, "bit={bit}, llr={llr}");
        }
    }

    #[test]
    fn soft_extract_confidence_decreases_with_noise() {
        let v = make_spreading_vec();
        let delta = 16.0;
        let mut coeffs = [20.0, -10.0, 5.0, 30.0, -15.0, 8.0, 12.0, -6.0];
        stdm_embed(&mut coeffs, &v, 0, delta);
        let llr_clean = stdm_extract_soft(&coeffs, &v, delta);
        assert!(llr_clean > 0.0, "should favor bit 0");
        // A freshly embedded projection sits on a bit-0 lattice point, half a step
        // (delta / 2) from the nearest bit-1 point, so confidence is maximal.
        assert!(
            (llr_clean - delta / 2.0).abs() < 1e-9,
            "clean llr should be delta/2, got {llr_clean}"
        );
        let mut noisy = coeffs;
        for c in noisy.iter_mut() {
            *c += 2.0;
        }
        // Shifting every coefficient by 2.0 moves the projection by 2.0 * sum(v) ≈ 3.36.
        // That is below the delta / 4 = 4.0 decision margin, so the sign must survive
        // while the confidence strictly drops.
        let llr_noisy = stdm_extract_soft(&noisy, &v, delta);
        assert!(llr_noisy > 0.0, "small noise should not flip the decision, llr={llr_noisy}");
        assert!(
            llr_noisy < llr_clean,
            "noise should reduce confidence: clean={llr_clean}, noisy={llr_noisy}"
        );
    }

    #[test]
    fn header_units_constant_correct() {
        assert_eq!(HEADER_UNITS, HEADER_BYTES * 8 * HEADER_COPIES);
        assert_eq!(HEADER_UNITS, 56);
    }

    #[test]
    fn delta_increases_with_r() {
        let mean_qt = 10.0;
        let d1 = compute_delta_from_mean_qt(mean_qt, 1);
        let d2 = compute_delta_from_mean_qt(mean_qt, 2);
        let d3 = compute_delta_from_mean_qt(mean_qt, 3);
        let d5 = compute_delta_from_mean_qt(mean_qt, 5);
        let d7 = compute_delta_from_mean_qt(mean_qt, 7);
        assert!(d2 > d1, "r=2 should increase delta");
        assert!(d3 > d2, "r=3 should increase delta more");
        assert!(d5 > d3, "r=5 should increase delta further");
        assert!(d7 > d5, "r=7 should increase delta even more");
        assert!((d1 - 30.0).abs() < 1e-10, "r=1: 3.0 * 10.0 = 30.0, got {d1}");
        assert!((d2 - 40.0).abs() < 1e-10, "r=2: 4.0 * 10.0 = 40.0, got {d2}");
        assert!((d3 - 60.0).abs() < 1e-10, "r=3: 6.0 * 10.0 = 60.0, got {d3}");
        assert!((d5 - 70.0).abs() < 1e-10, "r=5: 7.0 * 10.0 = 70.0, got {d5}");
        assert!((d7 - 80.0).abs() < 1e-10, "r=7: 8.0 * 10.0 = 80.0, got {d7}");
    }
}