use super::family::*;
use super::hessian_paths::*;
use super::*;
use crate::probability::normal_cdf;
use gam_linalg::matrix::DesignMatrix;
use gam_problem::{InverseLink, StandardLink};
use ndarray::Array1;
use std::sync::{Arc, Mutex};
/// Single-row test fixture for the #932 hand-path verification.
///
/// Exactly one deviation block is active: score-warp when `is_score_warp`
/// is `true`, link deviation otherwise (see `vfixture`).
struct VFixture {
    /// Family configured with one observation and the empirical latent grid.
    family: BernoulliMarginalSlopeFamily,
    /// Deviation basis runtime, shared by the family and the independent oracle.
    runtime: DeviationRuntime,
    /// Empirical latent `z` grid used for intercept calibration.
    grid: EmpiricalZGrid,
    /// Parameter layout: `q`, `logslope`, then the active deviation block.
    primary: PrimarySlices,
    /// Deviation coefficients used as the evaluation point.
    beta_dev: Array1<f64>,
    /// `true` for the score-warp variant, `false` for link deviation.
    is_score_warp: bool,
}
/// Five-node empirical latent grid whose raw masses are normalised to sum to one.
fn vgrid() -> EmpiricalZGrid {
    const NODES: [f64; 5] = [-1.4, -0.6, 0.1, 0.8, 1.5];
    const RAW_MASS: [f64; 5] = [0.14, 0.24, 0.28, 0.20, 0.14];
    let mass_total = RAW_MASS.iter().sum::<f64>();
    let normalized = RAW_MASS
        .iter()
        .map(|&mass| mass / mass_total)
        .collect::<Vec<f64>>();
    EmpiricalZGrid::new(NODES.to_vec(), normalized, "flex_verify_932 grid").expect("valid grid")
}
/// Deviation runtime on 11 equally spaced knots starting at -2.45 with total
/// width 5.0 (passing `0.0` and `3` as the remaining `try_new` arguments).
fn vruntime() -> DeviationRuntime {
    const KNOT_COUNT: usize = 11;
    let left = -2.45_f64;
    let width = 5.0_f64;
    let gaps = (KNOT_COUNT - 1) as f64;
    let knots = Array1::from_iter((0..KNOT_COUNT).map(|k| left + width * k as f64 / gaps));
    DeviationRuntime::try_new(knots, 0.0, 3).expect("deviation runtime")
}
/// Builds the one-observation #932 fixture.
///
/// The observation is `y = 1`, `weight = 1`, `z = 0.45`, with a probit base link.
/// When `is_score_warp` is `true` the deviation runtime is installed as the
/// score warp (parameter slice `h`); otherwise it is installed as the link
/// deviation (parameter slice `w`). In both cases the deviation block occupies
/// parameter indices `2..2 + basis_dim`, after `q` (0) and `logslope` (1).
/// `beta_dev` is a small linear ramp centred on the middle basis function.
fn vfixture(is_score_warp: bool) -> VFixture {
    let grid = vgrid();
    let runtime = vruntime();
    let basis_dim = runtime.basis_dim();
    let policy = gam_runtime::resource::ResourcePolicy::default_library();
    // Designs are unused by the single-row paths exercised here; 1x1 zeros suffice.
    let placeholder_design = || {
        DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
            ndarray::Array2::zeros((1, 1)),
        ))
    };

    // Exactly one deviation flavour is active.
    let (score_warp, link_dev) = if is_score_warp {
        (Some(runtime.clone()), None)
    } else {
        (None, Some(runtime.clone()))
    };
    let dev_cols = 2..2 + basis_dim;
    let (h, w) = if is_score_warp {
        (Some(dev_cols), None)
    } else {
        (None, Some(dev_cols))
    };

    let family = BernoulliMarginalSlopeFamily {
        y: Arc::new(Array1::from_vec(vec![1.0])),
        weights: Arc::new(Array1::from_vec(vec![1.0])),
        z: Arc::new(Array1::from_vec(vec![0.45])),
        latent_measure: LatentMeasureKind::GlobalEmpirical { grid: grid.clone() },
        gaussian_frailty_sd: None,
        base_link: InverseLink::Standard(StandardLink::Probit),
        marginal_design: placeholder_design(),
        logslope_design: placeholder_design(),
        score_warp,
        link_dev,
        policy: policy.clone(),
        cell_moment_lru: new_cell_moment_lru_cache(&policy),
        cell_moment_cache_stats: new_cell_moment_cache_stats(),
        intercept_warm_starts: None,
        auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
        auto_subsample_last_rho: Arc::new(Mutex::new(None)),
    };
    let primary = PrimarySlices {
        q: 0,
        logslope: 1,
        h,
        w,
        total: 2 + basis_dim,
    };

    // Coefficients ramp linearly from about -0.06 to +0.06 across the basis.
    let center = 0.5 * (basis_dim.saturating_sub(1) as f64);
    let radius = center.max(1.0);
    let beta_dev = Array1::from_shape_fn(basis_dim, |k| 0.06 * ((k as f64) - center) / radius);

    VFixture {
        family,
        primary,
        runtime,
        is_score_warp,
        grid,
        beta_dev,
    }
}
/// Independent scalar linear predictor for latent score `z`.
///
/// With `u = a + b·z`, this returns:
/// - `scale · (u + b · B(z)·beta)` for score-warp, where the basis is evaluated at `z`;
/// - `scale · (u + B(u)·beta)` for link deviation, where the basis is evaluated at `u`.
fn veta(fx: &VFixture, a: f64, b: f64, beta: &Array1<f64>, z: f64, scale: f64) -> f64 {
    let u = a + b * z;
    // (point at which the basis is evaluated, optional slope multiplier, panic message)
    let (eval_at, gain, what) = if fx.is_score_warp {
        (z, Some(b), "score-warp basis")
    } else {
        (u, None, "link-dev basis")
    };
    let basis = fx
        .runtime
        .design(&Array1::from_vec(vec![eval_at]))
        .expect(what);
    let offset: f64 = basis
        .row(0)
        .iter()
        .zip(beta.iter())
        .map(|(v, c)| v * c)
        .sum();
    let shift = match gain {
        Some(g) => g * offset,
        None => offset,
    };
    scale * (u + shift)
}
/// Independent scalar solve for the calibrated intercept `a`.
///
/// Finds `a` such that `Σ_k w_k · Φ(veta(a; z_k)) = mu` over the fixture's
/// empirical `z` grid. It uses only `veta` and `normal_cdf`, so it can serve
/// as an oracle for the production `solve_row_intercept_base`.
///
/// The search starts from `[-1, 1]` and doubles outward (at most 100 times)
/// until the calibration residual changes sign. It then bisects down to
/// machine precision.
///
/// # Panics
/// Panics if no sign-changing bracket is found. The bracket logic assumes the
/// residual is increasing in `a`. NOTE(review): this is expected for a probit
/// link with positive scale and small deviation coefficients; confirm for new
/// fixtures.
fn vintercept(fx: &VFixture, mu: f64, b: f64, beta: &Array1<f64>, scale: f64) -> f64 {
    // Calibration residual: grid-averaged probability minus the target mean.
    let calib = |a: f64| -> f64 {
        let mut acc = -mu;
        for (node, weight) in fx.grid.pairs() {
            acc += weight * normal_cdf(veta(fx, a, b, beta, node, scale));
        }
        acc
    };
    let mut lo = -1.0_f64;
    let mut hi = 1.0_f64;
    let mut flo = calib(lo);
    let mut fhi = calib(hi);
    for _ in 0..100 {
        if flo <= 0.0 && fhi >= 0.0 {
            break;
        }
        if flo > 0.0 {
            // Root lies below `lo`: slide the bracket down and widen it.
            hi = lo;
            fhi = flo;
            lo *= 2.0;
            flo = calib(lo);
        } else {
            // Root lies above `hi`: slide the bracket up and widen it.
            lo = hi;
            flo = fhi;
            hi *= 2.0;
            fhi = calib(hi);
        }
    }
    assert!(
        flo <= 0.0 && fhi >= 0.0,
        "failed to bracket flex calibration root F({lo})={flo} F({hi})={fhi}"
    );
    for _ in 0..200 {
        let mid = 0.5 * (lo + hi);
        // Once `lo` and `hi` are adjacent floats, the midpoint rounds onto one
        // of them and the bracket can no longer shrink. The relative `1e-16`
        // test below is finer than f64 precision, so it generally cannot fire
        // once |mid| >= 0.5. Without this exit the loop would keep
        // re-evaluating the same point until the iteration cap.
        if mid <= lo || mid >= hi {
            return mid;
        }
        let fmid = calib(mid);
        if fmid == 0.0 || (hi - lo).abs() <= 1e-16 * mid.abs().max(1.0) {
            return mid;
        }
        if fmid < 0.0 {
            lo = mid;
        } else {
            hi = mid;
        }
    }
    0.5 * (lo + hi)
}
/// Independent scalar negative log-likelihood for the fixture's single row.
///
/// It is built only from `vintercept`, `veta` and `normal_cdf`, not from the
/// production row evaluator. The result is `-w · ln(max(Φ(s·eta), 1e-300))`
/// with `s = 2y - 1`.
fn vnll(fx: &VFixture, p: &[f64]) -> f64 {
    let fam = &fx.family;
    let dev_slot = if fx.is_score_warp {
        &fx.primary.h
    } else {
        &fx.primary.w
    };
    let q = p[fx.primary.q];
    let b = p[fx.primary.logslope];
    let beta = Array1::from_iter(dev_slot.clone().unwrap().map(|i| p[i]));
    let scale = fam.probit_frailty_scale();
    let mu = bernoulli_marginal_link_map(&InverseLink::Standard(StandardLink::Probit), q)
        .expect("link map")
        .mu;
    let intercept = vintercept(fx, mu, b, &beta, scale);
    let eta = veta(fx, intercept, b, &beta, fam.z[0], scale);
    let sign = 2.0 * fam.y[0] - 1.0;
    // Floor the probability so that ln() stays finite deep in the tail.
    let log_lik = normal_cdf(sign * eta).max(1e-300).ln();
    -fam.weights[0] * log_lik
}
/// Copies the active deviation block (`h` for score-warp, `w` for link deviation)
/// out of the flat parameter vector `p`.
fn beta_vec(fx: &VFixture, p: &[f64]) -> Array1<f64> {
    let slot = if fx.is_score_warp {
        &fx.primary.h
    } else {
        &fx.primary.w
    };
    let cols = slot.clone().unwrap();
    Array1::from_iter(p[cols].iter().copied())
}
/// Objective value from the production hand-written flex row evaluator.
///
/// Calls `compute_row_analytic_flex_from_parts_into` with derivatives disabled.
/// The intercept comes from the independent bisection `vintercept`, with
/// `m_a = 1.0`.
fn hand_value(fx: &VFixture, p: &[f64]) -> f64 {
    let slices = &fx.primary;
    let (q, b) = (p[slices.q], p[slices.logslope]);
    let beta = beta_vec(fx, p);
    let mu = bernoulli_marginal_link_map(&InverseLink::Standard(StandardLink::Probit), q)
        .expect("link map")
        .mu;
    let row_ctx = BernoulliMarginalSlopeRowExactContext {
        intercept: vintercept(fx, mu, b, &beta, fx.family.probit_frailty_scale()),
        m_a: 1.0,
        intercept_fast_path: false,
        degree9_cells: None,
    };
    // Route the coefficients to whichever deviation slot the fixture uses.
    let (beta_h, beta_w) = if fx.is_score_warp {
        (Some(&beta), None)
    } else {
        (None, Some(&beta))
    };
    let mut scratch = BernoulliMarginalSlopeFlexRowScratch::new(slices.total);
    fx.family
        .compute_row_analytic_flex_from_parts_into(
            0, slices, q, b, beta_h, beta_w, &row_ctx, None, None, false, &mut scratch,
        )
        .expect("hand path value")
}
/// Value, gradient and Hessian from the production hand-written flex row evaluator.
///
/// The intercept and `m_a` come from the production `solve_row_intercept_base`.
/// The returned Hessian is flattened row-major as `total × total`.
fn hand_grad_hess(fx: &VFixture, p: &[f64]) -> (f64, Vec<f64>, Vec<f64>) {
    let slices = &fx.primary;
    let dim = slices.total;
    let (q, b) = (p[slices.q], p[slices.logslope]);
    let beta = beta_vec(fx, p);
    let (beta_h, beta_w) = if fx.is_score_warp {
        (Some(&beta), None)
    } else {
        (None, Some(&beta))
    };
    let (intercept, m_a, _) = fx
        .family
        .solve_row_intercept_base(0, q, b, beta_h, beta_w, None)
        .expect("intercept solve");
    let row_ctx = BernoulliMarginalSlopeRowExactContext {
        intercept,
        m_a,
        intercept_fast_path: false,
        degree9_cells: None,
    };
    let mut scratch = BernoulliMarginalSlopeFlexRowScratch::new(dim);
    let value = fx
        .family
        .compute_row_analytic_flex_from_parts_into(
            0, slices, q, b, beta_h, beta_w, &row_ctx, None, None, true, &mut scratch,
        )
        .expect("hand path grad/hess");
    let grad: Vec<f64> = scratch.grad.iter().copied().collect();
    let hess_mat = &scratch.hess;
    let hess: Vec<f64> = (0..dim)
        .flat_map(|row| (0..dim).map(move |col| hess_mat[[row, col]]))
        .collect();
    (value, grad, hess)
}
/// Richardson-extrapolated central finite difference of `hand_value` along
/// coordinate `i`.
///
/// It combines steps `h` and `h/2` as `(4·D(h/2) - D(h)) / 3`.
fn fd_grad(fx: &VFixture, p0: &[f64], i: usize, h: f64) -> f64 {
    let slope = |step: f64| -> f64 {
        let mut forward = p0.to_vec();
        forward[i] += step;
        let mut backward = p0.to_vec();
        backward[i] -= step;
        (hand_value(fx, &forward) - hand_value(fx, &backward)) / (2.0 * step)
    };
    let coarse = slope(h);
    let fine = slope(0.5 * h);
    (4.0 * fine - coarse) / 3.0
}
/// Richardson-extrapolated finite-difference Hessian entry `(i, j)` of `hand_value`.
///
/// On the diagonal it uses the three-point second difference. Off the diagonal
/// it uses the four-point cross stencil. Steps `h` and `h/2` are combined as
/// `(4·D(h/2) - D(h)) / 3`.
fn fd_hess(fx: &VFixture, p0: &[f64], i: usize, j: usize, h: f64) -> f64 {
    // Value path at p0 with each listed (index, offset) applied in turn.
    let displaced = |moves: &[(usize, f64)]| -> f64 {
        let mut point = p0.to_vec();
        for &(idx, offset) in moves {
            point[idx] += offset;
        }
        hand_value(fx, &point)
    };
    let stencil = |s: f64| -> f64 {
        if i != j {
            let f_pp = displaced(&[(i, s), (j, s)]);
            let f_pm = displaced(&[(i, s), (j, -s)]);
            let f_mp = displaced(&[(i, -s), (j, s)]);
            let f_mm = displaced(&[(i, -s), (j, -s)]);
            return (f_pp - f_pm - f_mp + f_mm) / (4.0 * s * s);
        }
        let f_0 = displaced(&[]);
        let f_p = displaced(&[(i, s)]);
        let f_m = displaced(&[(i, -s)]);
        (f_p - 2.0 * f_0 + f_m) / (s * s)
    };
    let coarse = stencil(h);
    let fine = stencil(0.5 * h);
    (4.0 * fine - coarse) / 3.0
}
fn run_hand_gate(is_score_warp: bool) {
let fx = vfixture(is_score_warp);
let r = fx.primary.total;
let label = if is_score_warp { "score-warp" } else { "link-dev" };
let q0 = 0.2_f64;
let b0 = 0.35_f64;
let mut p0 = vec![0.0; r];
p0[fx.primary.q] = q0;
p0[fx.primary.logslope] = b0;
let dev = if is_score_warp {
fx.primary.h.clone().unwrap()
} else {
fx.primary.w.clone().unwrap()
};
for (k, i) in dev.clone().enumerate() {
p0[i] = fx.beta_dev[k];
}
let v_hand = hand_value(&fx, &p0);
let v_ind = vnll(&fx, &p0);
assert!(
(v_hand - v_ind).abs() <= 1e-9 * v_ind.abs().max(1.0),
"{label} hand value {v_hand:+.12e} != independent scalar {v_ind:+.12e}"
);
let marginal =
bernoulli_marginal_link_map(&InverseLink::Standard(StandardLink::Probit), q0)
.expect("link map");
let scale = fx.family.probit_frailty_scale();
let beta = Array1::from_iter(dev.clone().map(|i| p0[i]));
let (beta_h, beta_w) = if is_score_warp {
(Some(&beta), None)
} else {
(None, Some(&beta))
};
let a_ind = vintercept(&fx, marginal.mu, b0, &beta, scale);
let (a_prod, _, _) = fx
.family
.solve_row_intercept_base(0, q0, b0, beta_h, beta_w, None)
.expect("prod intercept");
assert!(
(a_ind - a_prod).abs() <= 1e-9 * a_prod.abs().max(1.0),
"{label} intercept independent {a_ind:+.12e} != production {a_prod:+.12e}"
);
let (v_gh, grad, hess) = hand_grad_hess(&fx, &p0);
assert!(
(v_gh - v_hand).abs() <= 1e-9 * v_hand.abs().max(1.0),
"{label} analytic-call value {v_gh:+.12e} != value-call {v_hand:+.12e}"
);
let h = 1.0e-3_f64;
let mut max_g = 0.0_f64;
let mut max_hd = 0.0_f64;
for i in 0..r {
let fdg = fd_grad(&fx, &p0, i, h);
let e = (grad[i] - fdg).abs();
max_g = max_g.max(e);
assert!(
e <= 1e-7 * fdg.abs().max(1.0) + 1e-9,
"{label} grad[{i}] analytic {:+.12e} != fd {fdg:+.12e} (err {e:.2e})",
grad[i]
);
for j in i..r {
let fdh = fd_hess(&fx, &p0, i, j, h);
let e = (hess[i * r + j] - fdh).abs();
max_hd = max_hd.max(e);
assert!(
e <= 1e-5 * fdh.abs().max(1.0) + 1e-7,
"{label} hess[{i},{j}] analytic {:+.12e} != fd {fdh:+.12e} (err {e:.2e})",
hess[i * r + j]
);
assert!((hess[i * r + j] - hess[j * r + i]).abs() <= 1e-12);
}
}
eprintln!(
"#932 verify {label}: r={r} max|grad−fd|={max_g:.2e} max|hess−fd|={max_hd:.2e}"
);
}
/// #932 hand-path gate, score-warp flavour (deviation applied to the score `z`).
#[test]
fn hand_flex_grad_hess_matches_independent_fd_score_warp_932() {
    let score_warp = true;
    run_hand_gate(score_warp);
}
/// #932 hand-path gate, link-deviation flavour (deviation applied to `a + b·z`).
#[test]
fn hand_flex_grad_hess_matches_independent_fd_link_dev_932() {
    let score_warp = false;
    run_hand_gate(score_warp);
}
/// Cubic cell predictor `c0 + c1·z + c2·z² + c3·z³`, written term by term.
fn veta_cell(c: [f64; 4], z: f64) -> f64 {
    let [c0, c1, c2, c3] = c;
    c0 + c1 * z + c2 * z * z + c3 * z * z * z
}
/// Cell exponent `q(z) = ½ (z² + eta(z)²)`, with `eta = veta_cell(c, z)`.
fn vq_cell(c: [f64; 4], z: f64) -> f64 {
    let eta = veta_cell(c, z);
    (z * z + eta * eta) * 0.5
}
/// Composite Simpson approximation of `∫_{zl}^{zr} z^n · exp(-q(z)) dz`.
///
/// Uses `2 · panels` equal sub-intervals.
fn vmoment(c: [f64; 4], zl: f64, zr: f64, n: usize, panels: usize) -> f64 {
    let intervals = 2 * panels;
    let step = (zr - zl) / intervals as f64;
    let power = n as i32;
    let integrand = |z: f64| z.powi(power) * (-vq_cell(c, z)).exp();
    let ends = integrand(zl) + integrand(zr);
    // Simpson weights: 4 on odd interior nodes, 2 on even interior nodes.
    let total = (1..intervals).fold(ends, |sum, k| {
        let weight = if k % 2 == 1 { 4.0 } else { 2.0 };
        sum + weight * integrand(zl + k as f64 * step)
    });
    total * step / 3.0
}
/// Moving-edge Leibniz check for issue #932.
///
/// The cell's right edge moves with the parameters as `zr = (tau - a) / b`.
/// Derivatives with respect to `(a, b)` of `M_n = ∫_{zl}^{zr} z^n exp(-q(z)) dz`
/// therefore pick up boundary-flux terms.
///
/// For every degree `n <= max_n`, this compares the jets returned by
/// `cell_base_moment_jets_moving` with central finite differences of the
/// Simpson-integrated moment `vmoment`. The compared quantities are the value,
/// both first derivatives, and the `(b, b)` and `(a, b)` second derivatives.
#[test]
fn moving_edge_leibniz_tracks_boundary_flux_932() {
    use super::test_support::{Jet2, RuntimeJet, cell_base_moment_jets_moving};
    let a0 = 0.30_f64;
    let b0 = 0.80_f64;
    let tau = 1.10_f64;
    let zl0 = -0.90_f64;
    let zr0 = (tau - a0) / b0;
    let base_c = [0.10_f64, 0.45, -0.18, 0.08];
    let max_n = 4usize;
    // NOTE(review): the extra degrees beyond max_n are presumably consumed by
    // moment recurrences inside `cell_base_moment_jets_moving`; confirm.
    let scalar_deg = max_n + 12;
    let panels = 6000usize;
    // Jet directions: 0 = a, 1 = b.
    let p = 2usize;
    let a_jet = Jet2::primary(a0, 0, p);
    let b_jet = Jet2::primary(b0, 1, p);
    let tau_c = Jet2::constant(tau, p);
    let inv_b = {
        let v = b0;
        // Value, first and second derivative of 1/b at v. The trailing entries
        // are left at 0.0. NOTE(review): assumed unused by a second-order `Jet2`.
        b_jet.compose_unary([1.0 / v, -1.0 / (v * v), 2.0 / (v * v * v), 0.0, 0.0])
    };
    let zr_jet = tau_c.sub(&a_jet).mul(&inv_b);
    // The left edge and the cell coefficients do not depend on (a, b).
    let zl_jet = Jet2::constant(zl0, p);
    let c_jets: [Jet2; 4] = std::array::from_fn(|k| Jet2::constant(base_c[k], p));
    let scalar_moments: Vec<f64> =
        (0..=scalar_deg).map(|n| vmoment(base_c, zl0, zr0, n, panels)).collect();
    let m_jets = cell_base_moment_jets_moving(
        &c_jets,
        base_c,
        &scalar_moments,
        max_n,
        &zl_jet,
        zl0,
        &zr_jet,
        zr0,
    );
    // Reference moment with the right edge recomputed from perturbed (a, b).
    let m_of = |a: f64, b: f64, n: usize| -> f64 {
        let zr = (tau - a) / b;
        vmoment(base_c, zl0, zr, n, panels)
    };
    let h = 1e-4_f64;
    // Sanity check: perturbing b by ±h must actually move the right edge.
    let sweep = ((tau - a0) / (b0 + h) - (tau - a0) / (b0 - h)).abs();
    assert!(sweep > 1e-4, "knot crossing did not sweep ({sweep:.2e})");
    let mut max_g = 0.0_f64;
    let mut max_h = 0.0_f64;
    for n in 0..=max_n {
        assert!(
            (m_jets[n].v - scalar_moments[n]).abs() <= 1e-9 * scalar_moments[n].abs().max(1.0),
            "moving M[{n}] value mismatch"
        );
        let ga = (m_of(a0 + h, b0, n) - m_of(a0 - h, b0, n)) / (2.0 * h);
        let gb = (m_of(a0, b0 + h, n) - m_of(a0, b0 - h, n)) / (2.0 * h);
        max_g = max_g.max((m_jets[n].g[0] - ga).abs()).max((m_jets[n].g[1] - gb).abs());
        assert!(
            (m_jets[n].g[0] - ga).abs() <= 1e-6 * ga.abs().max(1.0) + 1e-8,
            "moving dM[{n}]/da jet {:+.12e} != fd {ga:+.12e}",
            m_jets[n].g[0]
        );
        assert!(
            (m_jets[n].g[1] - gb).abs() <= 1e-6 * gb.abs().max(1.0) + 1e-8,
            "moving dM[{n}]/db jet {:+.12e} != fd {gb:+.12e}",
            m_jets[n].g[1]
        );
        let hbb = (m_of(a0, b0 + h, n) - 2.0 * m_of(a0, b0, n) + m_of(a0, b0 - h, n)) / (h * h);
        let hab = (m_of(a0 + h, b0 + h, n) - m_of(a0 + h, b0 - h, n)
            - m_of(a0 - h, b0 + h, n)
            + m_of(a0 - h, b0 - h, n))
            / (4.0 * h * h);
        // `h` is read here as a row-major p×p Hessian: index 1 = (a, b),
        // index p + 1 = (b, b). NOTE(review): assumed layout; confirm against Jet2.
        max_h = max_h
            .max((m_jets[n].h[p + 1] - hbb).abs())
            .max((m_jets[n].h[1] - hab).abs());
        assert!(
            (m_jets[n].h[p + 1] - hbb).abs() <= 1e-3 * hbb.abs().max(1.0) + 1e-5,
            "moving d2M[{n}]/db2 jet {:+.12e} != fd {hbb:+.12e}",
            m_jets[n].h[p + 1]
        );
        assert!(
            (m_jets[n].h[1] - hab).abs() <= 1e-3 * hab.abs().max(1.0) + 1e-5,
            "moving d2M[{n}]/dadb jet {:+.12e} != fd {hab:+.12e}",
            m_jets[n].h[1]
        );
    }
    eprintln!(
        "#932 verify leibniz: sweep={sweep:.3e} max|grad−fd|={max_g:.2e} max|hess−fd|={max_h:.2e}"
    );
}
fn corrupt_moving_moment(
c0: [f64; 4],
scalar_moments: &[f64],
n: usize,
zr_jet: &super::test_support::Jet2,
zr0: f64,
) -> super::test_support::Jet2 {
use super::test_support::{Jet2, RuntimeJet};
let p = zr_jet.p();
let cst = |x: f64| Jet2::constant(x, p);
let interior = cst(scalar_moments[n]);
let eta0 = veta_cell(c0, zr0);
let q0 = 0.5 * (zr0 * zr0 + eta0 * eta0);
let g0 = zr0.powi(n as i32) * (-q0).exp();
let delta = zr_jet.sub(&cst(zr0));
let sliver_r = delta.scale(g0); interior.add(&sliver_r)
}
/// Tripwire: proves the moving-edge finite-difference oracle can detect a
/// missing second-order boundary term.
///
/// It builds a deliberately corrupted moment jet with `corrupt_moving_moment`,
/// which keeps only the first-order sliver. It then asserts that the jet's
/// `(b, b)` Hessian entry falls outside the tolerance used by the moving-edge
/// Leibniz test. If this assertion fails, that oracle would also have accepted
/// the broken jet.
#[test]
fn planted_corruption_tripwire_fails_932() {
    use super::test_support::{Jet2, RuntimeJet};
    let (a0, b0, tau) = (0.30_f64, 0.80_f64, 1.10_f64);
    let zl0 = -0.90_f64;
    let zr0 = (tau - a0) / b0;
    let base_c = [0.10_f64, 0.45, -0.18, 0.08];
    let n = 2usize;
    let panels = 4000usize;
    let scalar_deg = n + 12;
    // Jet directions: 0 = a, 1 = b.
    let p = 2usize;
    // Right edge zr = (tau - a) / b as a second-order jet in (a, b).
    let right_edge = {
        let a_jet = Jet2::primary(a0, 0, p);
        let b_jet = Jet2::primary(b0, 1, p);
        let recip_b =
            b_jet.compose_unary([1.0 / b0, -1.0 / (b0 * b0), 2.0 / (b0 * b0 * b0), 0.0, 0.0]);
        Jet2::constant(tau, p).sub(&a_jet).mul(&recip_b)
    };
    let scalar_moments = (0..=scalar_deg)
        .map(|k| vmoment(base_c, zl0, zr0, k, panels))
        .collect::<Vec<f64>>();
    let corrupt = corrupt_moving_moment(base_c, &scalar_moments, n, &right_edge, zr0);
    let moment_at = |a: f64, b: f64| -> f64 { vmoment(base_c, zl0, (tau - a) / b, n, panels) };
    let h = 5e-4_f64;
    let hbb =
        (moment_at(a0, b0 + h) - 2.0 * moment_at(a0, b0) + moment_at(a0, b0 - h)) / (h * h);
    let err = (corrupt.h[p + 1] - hbb).abs();
    let bound = 1e-3 * hbb.abs().max(1.0) + 1e-5;
    assert!(
        err > bound,
        "TRIPWIRE TOOTHLESS: corrupt sliver Hessian-bb err {err:.3e} <= bound {bound:.3e} \
         (the dropped ½·g_z·δ² term went undetected — the moving-edge oracle has no teeth)"
    );
    eprintln!(
        "#932 verify tripwire: corrupt err={err:.3e} > bound={bound:.3e} (oracle has teeth)"
    );
}