use ndarray::ArrayView1;
/// Result of checking an analytic directional derivative against a finite-difference probe.
#[derive(Debug, Clone, Copy)]
pub struct CriterionCertificate {
    /// Gradient norm as supplied by the caller (carried through for reporting).
    pub grad_norm: f64,
    /// Finite-difference estimate of the directional derivative.
    pub fd_directional: f64,
    /// Analytic directional derivative being certified.
    pub analytic_directional: f64,
    /// Estimated error of `fd_directional`; it widens the comparison scale in `agreement_rel`.
    pub fd_error_bar: f64,
    /// Finite-difference step length used for the probe.
    pub step: f64,
    /// Whether the probe is trustworthy; a certificate that is not well-posed never passes.
    pub well_posed: bool,
}
impl CriterionCertificate {
    /// Relative gap between the finite-difference and analytic directional derivatives.
    ///
    /// The absolute difference is divided by the largest of `|analytic|`, `|fd|`,
    /// `fd_error_bar`, and a `1e-12` floor (so an all-zero comparison yields `0.0` instead of
    /// dividing by zero). `f64::max` skips NaN candidates in the scale, but a NaN difference
    /// propagates to a NaN result.
    #[must_use]
    pub fn agreement_rel(&self) -> f64 {
        const SCALE_FLOOR: f64 = 1e-12;
        let gap = (self.fd_directional - self.analytic_directional).abs();
        let candidates = [
            self.analytic_directional.abs(),
            self.fd_directional.abs(),
            self.fd_error_bar,
        ];
        let scale = candidates
            .iter()
            .fold(SCALE_FLOOR, |widest, &candidate| widest.max(candidate));
        gap / scale
    }

    /// Returns `true` only when the certificate is well-posed and the relative gap is at most
    /// `rel_tol`. A NaN gap compares false and therefore never passes.
    #[must_use]
    pub fn passes(&self, rel_tol: f64) -> bool {
        if !self.well_posed {
            return false;
        }
        self.agreement_rel() <= rel_tol
    }
}
/// Returns a unit-length probe direction derived deterministically from `rho_hat`.
///
/// The bit pattern and index of every coordinate are folded into a SplitMix64 seed, so the
/// same input always produces the same direction (and differing inputs generally produce
/// differing directions). Each coordinate is drawn from `[-1, 1)` and the vector is then
/// normalised. If every draw is exactly zero, the first axis is returned instead. An empty
/// input yields an empty vector.
#[must_use]
pub fn deterministic_probe_direction(rho_hat: ArrayView1<'_, f64>) -> Vec<f64> {
    const SEED_INIT: u64 = 0x9E37_79B9_7F4A_7C15;
    const INDEX_MIX: u64 = 0x2545_F491_4F6C_DD1D;
    let n = rho_hat.len();
    if n == 0 {
        return Vec::new();
    }
    let seed = rho_hat
        .iter()
        .enumerate()
        .fold(SEED_INIT, |acc, (idx, &value)| {
            splitmix64(acc ^ value.to_bits() ^ (idx as u64).wrapping_mul(INDEX_MIX))
        });
    // The top 53 bits of each draw form a uniform double in [0, 1), which is mapped to [-1, 1).
    let mut state = seed;
    let mut dir: Vec<f64> = (0..n)
        .map(|_| {
            state = splitmix64(state);
            let unit = (state >> 11) as f64 / ((1u64 << 53) as f64);
            2.0 * unit - 1.0
        })
        .collect();
    let norm = dir.iter().fold(0.0_f64, |acc, &c| acc + c * c).sqrt();
    if norm > 0.0 {
        dir.iter_mut().for_each(|c| *c /= norm);
    } else {
        dir[0] = 1.0;
    }
    dir
}
/// One SplitMix64 step: advances `state` by the golden-ratio increment and returns the
/// scrambled 64-bit output. All arithmetic wraps modulo 2^64.
fn splitmix64(state: u64) -> u64 {
    const GOLDEN_GAMMA: u64 = 0x9E37_79B9_7F4A_7C15;
    const MIX_A: u64 = 0xBF58_476D_1CE4_E5B9;
    const MIX_B: u64 = 0x94D0_49BB_1331_11EB;
    let x = state.wrapping_add(GOLDEN_GAMMA);
    let x = (x ^ (x >> 30)).wrapping_mul(MIX_A);
    let x = (x ^ (x >> 27)).wrapping_mul(MIX_B);
    x ^ (x >> 31)
}
/// Finite-difference step for probing at `rho_hat`.
///
/// The step is `1e-4` times the largest absolute coordinate, with that magnitude floored at
/// `1.0`, so the step is never below `1e-4`. NaN coordinates are skipped by `f64::max`.
#[must_use]
pub fn probe_step(rho_hat: ArrayView1<'_, f64>) -> f64 {
    const RELATIVE_STEP: f64 = 1e-4;
    let magnitude = rho_hat
        .iter()
        .map(|x| x.abs())
        .fold(1.0_f64, f64::max);
    RELATIVE_STEP * magnitude
}
/// Objective values sampled at four symmetric points along a probe direction, together with
/// the analytic quantities to certify against them.
///
/// NOTE(review): the sampling points below follow how this file's tests build the struct
/// (`rho ± h·dir`, `rho ± 2h·dir`); confirm that other callers use the same layout.
#[derive(Debug, Clone, Copy)]
pub struct DirectionalSamples {
    /// Objective value at `+h` along the probe direction.
    pub plus_h: f64,
    /// Objective value at `-h` along the probe direction.
    pub minus_h: f64,
    /// Objective value at `+2h` along the probe direction.
    pub plus_2h: f64,
    /// Objective value at `-2h` along the probe direction.
    pub minus_2h: f64,
    /// The step `h` the samples were taken with.
    pub step: f64,
    /// Gradient norm supplied by the caller; copied into the certificate unchanged.
    pub grad_norm: f64,
    /// Analytic directional derivative to compare against the finite-difference estimate.
    pub analytic_directional: f64,
    /// Caller's own assessment of whether the probe is meaningful.
    pub well_posed: bool,
}
/// Builds a [`CriterionCertificate`] from four symmetric directional samples.
///
/// `d_h` is the central difference at step `h`, and `d_2h` is the central difference at `2h`.
/// For a smooth objective the leading central-difference error scales as `h^2`, so the
/// Richardson estimate `|d_h - d_2h| / 3` is reported as the error bar on `d_h`.
///
/// The certificate is marked not well-posed when any of the following holds:
/// - the caller said so;
/// - any of the four samples is non-finite;
/// - the resulting central difference or error bar is non-finite, for example a zero step
///   or an overflowing sample difference.
#[must_use]
pub fn certificate_from_samples(s: &DirectionalSamples) -> CriterionCertificate {
    let d_h = (s.plus_h - s.minus_h) / (2.0 * s.step);
    let d_2h = (s.plus_2h - s.minus_2h) / (4.0 * s.step);
    let fd_error_bar = (d_h - d_2h).abs() / 3.0;
    // The far samples matter too. An infinite error bar would inflate the scale used by
    // `agreement_rel` to infinity, which makes any analytic value appear to agree.
    // A NaN error bar would be reported alongside a "well-posed" flag.
    let samples_finite = [s.plus_h, s.minus_h, s.plus_2h, s.minus_2h]
        .iter()
        .all(|v| v.is_finite());
    let well_posed =
        s.well_posed && samples_finite && d_h.is_finite() && fd_error_bar.is_finite();
    CriterionCertificate {
        grad_norm: s.grad_norm,
        fd_directional: d_h,
        analytic_directional: s.analytic_directional,
        fd_error_bar,
        step: s.step,
        well_posed,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array1;

    /// Evaluates `objective` at `rho ± h·dir` and `rho ± 2h·dir`. Here `dir` is the
    /// deterministic probe direction for `rho` and `h` is its probe step. The samples are
    /// paired with `analytic_grad`, and the resulting certificate is returned.
    fn certify_along_probe(
        objective: impl Fn(&[f64]) -> f64,
        rho: &Array1<f64>,
        analytic_grad: &[f64],
    ) -> CriterionCertificate {
        let dir = deterministic_probe_direction(rho.view());
        let h = probe_step(rho.view());
        let eval_at = |sign: f64, mult: f64| {
            let point: Vec<f64> = rho
                .iter()
                .zip(&dir)
                .map(|(&r, &d)| r + sign * mult * h * d)
                .collect();
            objective(&point)
        };
        let grad_norm = analytic_grad.iter().map(|g| g * g).sum::<f64>().sqrt();
        let analytic_directional = analytic_grad
            .iter()
            .zip(&dir)
            .map(|(g, d)| g * d)
            .sum::<f64>();
        certificate_from_samples(&DirectionalSamples {
            plus_h: eval_at(1.0, 1.0),
            minus_h: eval_at(-1.0, 1.0),
            plus_2h: eval_at(1.0, 2.0),
            minus_2h: eval_at(-1.0, 2.0),
            step: h,
            grad_norm,
            analytic_directional,
            well_posed: true,
        })
    }

    #[test]
    fn quadratic_certificate_agrees_exactly() {
        let rho = Array1::from(vec![0.7_f64, -1.3]);
        let grad = [2.0 * rho[0] + 1.0, 3.0 * rho[1] - 2.0];
        let cert = certify_along_probe(
            |r: &[f64]| 0.5 * (2.0 * r[0] * r[0] + 3.0 * r[1] * r[1]) + (r[0] - 2.0 * r[1]),
            &rho,
            &grad,
        );
        assert!(
            cert.agreement_rel() < 1e-6,
            "quadratic FD must match analytic: rel {}, fd {}, analytic {}",
            cert.agreement_rel(),
            cert.fd_directional,
            cert.analytic_directional
        );
        assert!(
            cert.fd_error_bar < 1e-6,
            "quadratic has zero third derivative, error bar must be tiny: {}",
            cert.fd_error_bar
        );
        assert!(cert.passes(1e-4), "well-posed quadratic must certify");
    }

    #[test]
    fn planted_desync_is_caught() {
        let rho = Array1::from(vec![0.4_f64, 0.9]);
        let true_grad = [rho[0].cos(), rho[1]];
        // Scale one component by 30% so the analytic directional derivative is wrong.
        let bad_grad = [1.3 * true_grad[0], true_grad[1]];
        let cert = certify_along_probe(
            |r: &[f64]| r[0].sin() + 0.5 * r[1] * r[1],
            &rho,
            &bad_grad,
        );
        assert!(
            !cert.passes(1e-3),
            "30% desync must fail the certificate: rel {}, fd {}, analytic {}",
            cert.agreement_rel(),
            cert.fd_directional,
            cert.analytic_directional
        );
    }

    #[test]
    fn probe_direction_is_deterministic_unit() {
        let rho = Array1::from(vec![1.0_f64, -2.0, 0.5, 3.3]);
        let first = deterministic_probe_direction(rho.view());
        let second = deterministic_probe_direction(rho.view());
        assert_eq!(first, second, "same fingerprint must give same direction");
        let norm: f64 = first.iter().map(|x| x * x).sum::<f64>().sqrt();
        assert!(
            (norm - 1.0).abs() < 1e-12,
            "direction must be unit, got {norm}"
        );
        let perturbed = Array1::from(vec![1.0_f64, -2.0, 0.5, 3.4]);
        let other = deterministic_probe_direction(perturbed.view());
        assert_ne!(first, other, "different fingerprint must give different direction");
    }

    #[test]
    fn nonfinite_sample_marks_not_well_posed() {
        let cert = certificate_from_samples(&DirectionalSamples {
            plus_h: f64::NAN,
            minus_h: 1.0,
            plus_2h: 2.0,
            minus_2h: 0.0,
            step: 1e-4,
            grad_norm: 1.0,
            analytic_directional: 0.0,
            well_posed: true,
        });
        assert!(
            !cert.well_posed,
            "NaN value sample must flag not-well-posed"
        );
        assert!(!cert.passes(1.0), "not-well-posed never certifies");
    }
}