use super::*;
use approx::assert_abs_diff_eq;
use gam_terms::analytic_penalties::IsometryReference;
use ndarray::array;
/// Builds a deterministic, seed-dependent `n_basis x p_out` decoder matrix.
///
/// Entry `(i, j)` is `0.8 * sin(x) + 0.35 * cos(1.7 * x)` with
/// `x = seed + 0.371 * i - 0.193 * j + 0.047 * (i * j + 1)`, so tests get smooth,
/// reproducible coefficients from plain arithmetic (no random-number generator).
pub(crate) fn deterministic_decoder(n_basis: usize, p_out: usize, seed: f64) -> Array2<f64> {
    Array2::<f64>::from_shape_fn((n_basis, p_out), |(row, col)| {
        let (r, c) = (row as f64, col as f64);
        let coupling = (row * col + 1) as f64;
        let phase = seed + 0.371 * r - 0.193 * c + 0.047 * coupling;
        let sine_part = 0.8 * phase.sin();
        let cosine_part = 0.35 * (1.7 * phase).cos();
        sine_part + cosine_part
    })
}
/// Builds the single-atom fixture shared by the exact-isometry-HVP tests.
///
/// The evaluator is run at `coords` to obtain the basis values `phi` and their jet. The
/// atom combines these with a `deterministic_decoder(phi.ncols(), p_out, seed)` decoder
/// and an identity `smooth` matrix. The same evaluator is attached as the atom's basis
/// second jet.
///
/// Alongside the atom, a Euclidean-reference `IsometryPenalty` is built. It covers the
/// whole flattened coordinate vector, with one latent dimension per column of `coords`.
///
/// Returns `(atom, penalty, target_flat)`, where `target_flat` is `coords` flattened in
/// logical (row-major) order.
///
/// # Panics
/// Panics if basis evaluation or atom construction returns an error.
pub(crate) fn build_isometry_atom_for_evaluator(
    evaluator: Arc<dyn SaeBasisSecondJet>,
    kind: SaeAtomBasisKind,
    coords: &Array2<f64>,
    p_out: usize,
    seed: f64,
) -> (SaeManifoldAtom, IsometryPenalty, Array1<f64>) {
    let latent_dim = coords.ncols();
    let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();
    let n_basis = phi.ncols();
    let decoder = deterministic_decoder(n_basis, p_out, seed);
    let smooth = Array2::<f64>::eye(n_basis);
    let atom = SaeManifoldAtom::new(
        "exact_hvp_atom",
        kind,
        latent_dim,
        phi,
        jet,
        decoder,
        smooth,
    )
    .unwrap()
    .with_basis_second_jet(evaluator);

    let target_flat = coords.iter().copied().collect::<Array1<f64>>();
    let psi_slice = PsiSlice::full(target_flat.len(), Some(latent_dim));
    let penalty = IsometryPenalty::new_euclidean(psi_slice, p_out);
    (atom, penalty, target_flat)
}
/// Checks the exact isometry-penalty HVP against a central finite difference of the
/// penalty's target gradient.
///
/// The fixture is built at `coords` with decoder seed `0.91`. The penalty caches are
/// refreshed from the atom, and `penalty.hvp` is evaluated along `direction`, flattened
/// in row-major order.
///
/// The reference value is
/// `(grad(coords + eps * direction) - grad(coords - eps * direction)) / (2 * eps)` with
/// `eps = 1e-6`. The caches are re-refreshed at each perturbed point and restored at
/// `coords` afterwards.
///
/// # Panics
/// Fails the calling test if any of the following holds:
/// - `direction` and `coords` differ in shape;
/// - any cache refresh reports that it did not install;
/// - the penalty has no third-decoder-derivative cache;
/// - the exact HVP is numerically zero;
/// - the HVP and FD vectors differ in length;
/// - any entry differs by more than `2e-4 + 3e-5 * max(|exact|, |fd|)`.
pub(crate) fn assert_exact_isometry_hvp_matches_grad_fd(
    evaluator: Arc<dyn SaeBasisSecondJet>,
    kind: SaeAtomBasisKind,
    coords: Array2<f64>,
    p_out: usize,
    direction: Array2<f64>,
) {
    assert_eq!(
        direction.dim(),
        coords.dim(),
        "direction must have the same shape as coords"
    );
    let (atom, penalty, target_flat) =
        build_isometry_atom_for_evaluator(evaluator, kind, &coords, p_out, 0.91);
    let rho = array![0.0_f64];
    let installed = refresh_isometry_caches_from_atom(&penalty, &atom, coords.view()).unwrap();
    assert!(
        installed,
        "second-jet cache must be installed for exact HVP test"
    );
    assert!(
        penalty.third_decoder_derivative().is_some(),
        "non-Duchon exact HVP requires a live refreshed third-decoder-jet cache"
    );
    let v: Array1<f64> = direction.iter().copied().collect();
    let exact = penalty.hvp(target_flat.view(), rho.view(), v.view());
    assert!(
        exact.iter().any(|x| x.abs() > 1.0e-7),
        "exact isometry HVP should be nonzero after K refresh; got {exact:?}"
    );

    // Central difference of the gradient. Every evaluation point gets its own cache
    // refresh. The install flag is checked so the FD reference is never computed
    // against caches that a refresh reported as not installed.
    let eps = 1.0e-6;
    let step = direction.mapv(|x| eps * x);
    let coords_plus = &coords + &step;
    let coords_minus = &coords - &step;
    let target_plus: Array1<f64> = coords_plus.iter().copied().collect();
    let target_minus: Array1<f64> = coords_minus.iter().copied().collect();

    assert!(
        refresh_isometry_caches_from_atom(&penalty, &atom, coords_plus.view()).unwrap(),
        "second-jet cache must be installed at coords + eps * direction"
    );
    let grad_plus = penalty.grad_target(target_plus.view(), rho.view());

    assert!(
        refresh_isometry_caches_from_atom(&penalty, &atom, coords_minus.view()).unwrap(),
        "second-jet cache must be installed at coords - eps * direction"
    );
    let grad_minus = penalty.grad_target(target_minus.view(), rho.view());

    // Restore the caches at the unperturbed point so the penalty is left as we found it.
    assert!(
        refresh_isometry_caches_from_atom(&penalty, &atom, coords.view()).unwrap(),
        "second-jet cache must be re-installed at the unperturbed coords"
    );

    let fd = (&grad_plus - &grad_minus).mapv(|x| x / (2.0 * eps));
    // Without this check a shorter `exact` would only compare a prefix of `fd`.
    assert_eq!(
        exact.len(),
        fd.len(),
        "exact HVP and finite-difference vectors must have the same length"
    );
    for (i, (&exact_i, &fd_i)) in exact.iter().zip(fd.iter()).enumerate() {
        let err = (exact_i - fd_i).abs();
        // Mixed absolute/relative tolerance: FD truncation error scales with magnitude.
        let tol = 2.0e-4 + 3.0e-5 * exact_i.abs().max(fd_i.abs());
        assert!(
            err <= tol,
            "exact isometry HVP/grad-FD mismatch at flat index {i}: exact={:.12e}, fd={:.12e}, err={:.6e}, tol={:.6e}",
            exact_i,
            fd_i,
            err,
            tol
        );
    }
}
/// Checks that the exact isometry HVP equals the PSD (Gauss–Newton) majorizer HVP when
/// the reference metric is constructed to match the current pullback metric.
///
/// Procedure:
/// 1. Build the fixture with decoder seed `1.37` and refresh its caches at `coords`.
/// 2. Read the pullback metric back from the penalty.
/// 3. Divide it by its mean diagonal entry.
/// 4. Install the result as a user-supplied reference.
/// 5. Compare `hvp` and `psd_majorizer_hvp` along `direction` entrywise to `1e-10`.
///
/// NOTE(review): "zero residual" presumes the penalty compares the metric after the same
/// mean-diagonal normalization; confirm against the penalty implementation.
///
/// # Panics
/// Fails the calling test if any of the following holds:
/// - `direction` and `coords` differ in shape;
/// - the refresh does not install;
/// - the mean diagonal is not positive and finite;
/// - the K cache is missing after `with_reference`;
/// - the GN block is numerically zero;
/// - the two HVPs differ in length or value.
pub(crate) fn assert_exact_isometry_hvp_collapses_to_gn_at_zero_residual(
    evaluator: Arc<dyn SaeBasisSecondJet>,
    kind: SaeAtomBasisKind,
    coords: Array2<f64>,
    p_out: usize,
    direction: Array2<f64>,
) {
    assert_eq!(
        direction.dim(),
        coords.dim(),
        "direction must have the same shape as coords"
    );
    let (atom, penalty, target_flat) =
        build_isometry_atom_for_evaluator(evaluator, kind, &coords, p_out, 1.37);
    let rho = array![0.0_f64];
    let d = coords.ncols();
    let installed = refresh_isometry_caches_from_atom(&penalty, &atom, coords.view()).unwrap();
    assert!(
        installed,
        "second-jet cache must be installed for the zero-residual exact/GN test"
    );
    let mut g_ref = penalty
        .pullback_metric(d)
        .expect("pullback metric is available after the cache refresh");
    // Each row is indexed as a flattened d x d metric, entry (a, b) at column a * d + b;
    // sum the diagonal entries over all rows.
    let mut trace_sum = 0.0_f64;
    for row in 0..g_ref.nrows() {
        for axis in 0..d {
            trace_sum += g_ref[[row, axis * d + axis]];
        }
    }
    let normalizer = trace_sum / (g_ref.nrows() * d) as f64;
    // A zero/NaN normalizer would turn the reference into NaN/inf and make the
    // later comparison fail for an unrelated-looking reason.
    assert!(
        normalizer.is_finite() && normalizer > 0.0,
        "pullback-metric mean diagonal must be positive and finite; got {normalizer:.3e}"
    );
    g_ref.mapv_inplace(|value| value / normalizer);
    let penalty = penalty.with_reference(IsometryReference::UserSupplied(Arc::new(g_ref)));
    assert!(
        penalty.third_decoder_derivative().is_some(),
        "zero-residual exact/GN test must still carry the real refreshed K cache"
    );
    let v: Array1<f64> = direction.iter().copied().collect();
    let exact = penalty.hvp(target_flat.view(), rho.view(), v.view());
    let gn = penalty.psd_majorizer_hvp(target_flat.view(), rho.view(), v.view());
    assert!(
        gn.iter().any(|x| x.abs() > 1.0e-8),
        "GN block should be nonzero so exact/GN equality is not vacuous"
    );
    // Without this check a shorter `exact` would only compare a prefix of `gn`.
    assert_eq!(
        exact.len(),
        gn.len(),
        "exact HVP and GN HVP must have the same length"
    );
    for (exact_i, gn_i) in exact.iter().zip(gn.iter()) {
        assert_abs_diff_eq!(*exact_i, *gn_i, epsilon = 1.0e-10);
    }
}
/// Exact isometry HVP vs. gradient finite differences on the sphere chart
/// (four 2-D points, `p_out = 4`).
#[test]
pub(crate) fn isometry_exact_hvp_sphere_matches_grad_fd_and_uses_refreshed_k() {
    let coords = array![[-0.61, 0.23], [-0.18, -1.07], [0.42, 0.81], [0.73, -0.39]];
    let direction = array![[0.31, -0.27], [-0.18, 0.22], [0.14, 0.19], [-0.25, -0.11]];
    let p_out = 4;
    assert_exact_isometry_hvp_matches_grad_fd(
        Arc::new(SphereChartEvaluator),
        SaeAtomBasisKind::Sphere,
        coords,
        p_out,
        direction,
    );
}
/// Exact isometry HVP vs. gradient finite differences on the torus harmonic basis
/// built by `TorusHarmonicEvaluator::new(2, 2)` (three 2-D points, `p_out = 3`).
#[test]
pub(crate) fn isometry_exact_hvp_torus_matches_grad_fd_and_uses_refreshed_k() {
    let torus = TorusHarmonicEvaluator::new(2, 2).unwrap();
    let coords = array![[0.13, 0.42], [0.66, 0.19], [0.88, 0.55]];
    let direction = array![[0.21, -0.16], [-0.24, 0.18], [0.13, 0.27]];
    let p_out = 3;
    assert_exact_isometry_hvp_matches_grad_fd(
        Arc::new(torus),
        SaeAtomBasisKind::Torus,
        coords,
        p_out,
        direction,
    );
}
/// Zero-residual exact/GN collapse on two bases, each with three 2-D points:
/// - once on the sphere chart (`p_out = 4`);
/// - once on the torus harmonic basis (`p_out = 3`).
#[test]
pub(crate) fn isometry_exact_hvp_sphere_and_torus_collapse_to_gn_at_zero_residual() {
    // Sphere chart case.
    let sphere_coords = array![[-0.52, 0.17], [-0.11, -0.93], [0.39, 0.74]];
    let sphere_direction = array![[0.17, -0.21], [-0.13, 0.08], [0.22, 0.19]];
    assert_exact_isometry_hvp_collapses_to_gn_at_zero_residual(
        Arc::new(SphereChartEvaluator),
        SaeAtomBasisKind::Sphere,
        sphere_coords,
        4,
        sphere_direction,
    );

    // Torus harmonic case.
    let torus = TorusHarmonicEvaluator::new(2, 2).unwrap();
    let torus_coords = array![[0.19, 0.31], [0.57, 0.73], [0.84, 0.12]];
    let torus_direction = array![[0.11, -0.14], [-0.20, 0.07], [0.16, 0.23]];
    assert_exact_isometry_hvp_collapses_to_gn_at_zero_residual(
        Arc::new(torus),
        SaeAtomBasisKind::Torus,
        torus_coords,
        3,
        torus_direction,
    );
}
/// Checks that the isometry penalty's PSD (Gauss–Newton) majorizer HVP is inert before
/// any cache refresh and is live, symmetric and positive semidefinite after one.
///
/// Steps (fixture built with decoder seed `0.53`):
/// 1. Before refreshing, `psd_majorizer_hvp` on the first unit vector must be all zeros.
/// 2. Refresh from the atom at `coords`. Then divide the pullback metric by its mean
///    diagonal entry. The result must differ from the identity by an entrywise L1 mass
///    above `1e-3`, so the Euclidean-reference residual is genuinely nonzero.
/// 3. Assemble the dense majorizer `B` column by column from unit-vector HVPs. It must:
///    - be finite;
///    - have `max |B| > 1e-6`;
///    - be symmetric to `1e-10`.
/// 4. For every probe `v` (flattened row-major), `vᵀBv >= -1e-9`.
///
/// # Panics
/// Fails the calling test on any of the checks above. It also fails if the mean
/// diagonal is not positive and finite, or if a probe or HVP column length does not
/// match the flattened target length.
pub(crate) fn assert_isometry_psd_majorizer_live_after_atom_refresh(
    evaluator: Arc<dyn SaeBasisSecondJet>,
    kind: SaeAtomBasisKind,
    coords: Array2<f64>,
    p_out: usize,
    probes: &[Array2<f64>],
) {
    let (atom, penalty, target_flat) =
        build_isometry_atom_for_evaluator(evaluator, kind, &coords, p_out, 0.53);
    let rho = array![0.0_f64];
    let n = target_flat.len();

    // 1. Without a refreshed cache the majorizer must contribute nothing.
    let mut unit0 = Array1::<f64>::zeros(n);
    unit0[0] = 1.0;
    let pre = penalty.psd_majorizer_hvp(target_flat.view(), rho.view(), unit0.view());
    assert!(
        pre.iter().all(|x| *x == 0.0),
        "psd_majorizer_hvp without a cache must be the zero block; got {pre:?}"
    );

    // 2. Refresh and confirm the residual against the Euclidean reference is nonzero.
    let installed = refresh_isometry_caches_from_atom(&penalty, &atom, coords.view()).unwrap();
    assert!(
        installed,
        "second-jet cache must install for the PSD-majorizer liveness test"
    );
    let d = coords.ncols();
    let g = penalty
        .pullback_metric(d)
        .expect("pullback metric available after refresh");
    // Each row is indexed as a flattened d x d metric, entry (a, b) at column a * d + b.
    let mut trace_sum = 0.0_f64;
    for row in 0..g.nrows() {
        for axis in 0..d {
            trace_sum += g[[row, axis * d + axis]];
        }
    }
    let normalizer = trace_sum / (g.nrows() * d) as f64;
    // A zero/NaN normalizer would make `residual_mass` NaN/inf and the check below
    // meaningless.
    assert!(
        normalizer.is_finite() && normalizer > 0.0,
        "pullback-metric mean diagonal must be positive and finite; got {normalizer:.3e}"
    );
    let mut residual_mass = 0.0_f64;
    for row in 0..g.nrows() {
        for a in 0..d {
            for b in 0..d {
                let identity_ab = if a == b { 1.0 } else { 0.0 };
                residual_mass += (g[[row, a * d + b]] / normalizer - identity_ab).abs();
            }
        }
    }
    assert!(
        residual_mass > 1.0e-3,
        "Euclidean-reference residual must be nonzero for a real curvature test; \
         got residual mass {residual_mass:.3e}"
    );

    // 3. Materialize B one column at a time and check it is live and symmetric.
    let mut bmat = Array2::<f64>::zeros((n, n));
    for k in 0..n {
        let mut e = Array1::<f64>::zeros(n);
        e[k] = 1.0;
        let col = penalty.psd_majorizer_hvp(target_flat.view(), rho.view(), e.view());
        assert_eq!(
            col.len(),
            n,
            "psd_majorizer_hvp column {k} must match the flattened target length"
        );
        bmat.column_mut(k).assign(&col);
    }
    // `f64::max` skips NaN, so non-finite entries must be rejected explicitly.
    assert!(
        bmat.iter().all(|x| x.is_finite()),
        "isometry GN majorizer must be finite; got {bmat:?}"
    );
    let max_abs = bmat.iter().fold(0.0_f64, |acc, x| acc.max(x.abs()));
    assert!(
        max_abs > 1.0e-6,
        "isometry GN majorizer must be nonzero for a non-Duchon basis after refresh; \
         max |B| = {max_abs:.3e}"
    );
    // Upper triangle (including the diagonal) covers every symmetric pair once.
    for r in 0..n {
        for c in r..n {
            assert_abs_diff_eq!(bmat[[r, c]], bmat[[c, r]], epsilon = 1.0e-10);
        }
    }

    // 4. Positive semidefiniteness along each probe direction.
    for probe in probes {
        let v: Array1<f64> = probe.iter().copied().collect();
        assert_eq!(v.len(), n, "probe must match the flattened target length");
        let bv = penalty.psd_majorizer_hvp(target_flat.view(), rho.view(), v.view());
        let quad = v.dot(&bv);
        assert!(
            quad >= -1.0e-9,
            "isometry GN majorizer must be PSD; got vᵀBv = {quad:.3e}"
        );
    }
}
/// PSD-majorizer liveness on the sphere chart: three 2-D points, `p_out = 4`, and three
/// probe directions (a generic one, the all-ones vector, and a larger mixed-sign one).
#[test]
pub(crate) fn isometry_psd_majorizer_live_after_sphere_refresh() {
    let coords = array![[-0.61, 0.23], [-0.18, -1.07], [0.42, 0.81]];
    let probes = [
        array![[0.31, -0.27], [-0.18, 0.22], [0.14, 0.19]],
        array![[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]],
        array![[-2.3, 0.6], [-0.1, 1.4], [0.8, -1.7]],
    ];
    assert_isometry_psd_majorizer_live_after_atom_refresh(
        Arc::new(SphereChartEvaluator),
        SaeAtomBasisKind::Sphere,
        coords,
        4,
        &probes,
    );
}
/// PSD-majorizer liveness on the periodic (circle) harmonic basis built by
/// `PeriodicHarmonicEvaluator::new(5)`: four 1-D points, `p_out = 3`, three probes.
#[test]
pub(crate) fn isometry_psd_majorizer_live_after_circle_refresh() {
    let circle = PeriodicHarmonicEvaluator::new(5).unwrap();
    let coords = array![[0.12], [0.37], [0.58], [0.81]];
    let probes = [
        array![[0.4], [-1.1], [0.7], [0.3]],
        array![[1.0], [1.0], [1.0], [1.0]],
        array![[-2.3], [0.6], [-0.1], [1.4]],
    ];
    assert_isometry_psd_majorizer_live_after_atom_refresh(
        Arc::new(circle),
        SaeAtomBasisKind::Periodic,
        coords,
        3,
        &probes,
    );
}
/// PSD-majorizer liveness on the torus harmonic basis built by
/// `TorusHarmonicEvaluator::new(2, 2)`: three 2-D points, `p_out = 3`, three probes.
#[test]
pub(crate) fn isometry_psd_majorizer_live_after_torus_refresh() {
    let torus = TorusHarmonicEvaluator::new(2, 2).unwrap();
    let coords = array![[0.13, 0.42], [0.66, 0.19], [0.88, 0.55]];
    let probes = [
        array![[0.21, -0.16], [-0.24, 0.18], [0.13, 0.27]],
        array![[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]],
        array![[-1.2, 0.5], [0.3, -0.9], [0.7, 0.2]],
    ];
    assert_isometry_psd_majorizer_live_after_atom_refresh(
        Arc::new(torus),
        SaeAtomBasisKind::Torus,
        coords,
        3,
        &probes,
    );
}
/// Regression test: `refresh_isometry_caches_from_term` must refresh the k-th isometry
/// penalty in the registry against the k-th atom of the term.
///
/// Setup:
/// - Two periodic atoms are built with different coordinates and decoder seeds.
/// - For each atom, a standalone "control" penalty is refreshed directly from it. This
///   gives the expected Jacobian cache; the two expected caches must differ.
///
/// Check: after refreshing a two-penalty registry via the term, each registry penalty's
/// cache must equal its own atom's control cache, and the two caches must differ.
#[test]
pub(crate) fn refresh_isometry_caches_pairs_each_penalty_to_its_own_atom() {
    let latent_dim = 1usize;
    let p_out = 3usize;
    let evaluator = Arc::new(PeriodicHarmonicEvaluator::new(5).unwrap());
    let coords0 = array![[0.05], [0.20], [0.55], [0.80]];
    let coords1 = array![[0.13], [0.41], [0.62], [0.91]];
    // Builds a periodic atom at `coords` with a seed-dependent deterministic decoder.
    let build_atom = |name: &str, coords: &Array2<f64>, seed: f64| {
        let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();
        let m = phi.ncols();
        let mut decoder = Array2::<f64>::zeros((m, p_out));
        for i in 0..m {
            for j in 0..p_out {
                let x = (i as f64) * 0.371 + (j as f64) * 0.193 + seed;
                decoder[[i, j]] = (x.sin() * 0.9) + 0.1 * ((i + j) as f64).cos();
            }
        }
        let smooth = Array2::<f64>::eye(m);
        SaeManifoldAtom::new(
            name,
            SaeAtomBasisKind::Periodic,
            latent_dim,
            phi,
            jet,
            decoder,
            smooth,
        )
        .unwrap()
        .with_basis_second_jet(evaluator.clone() as Arc<dyn SaeBasisSecondJet>)
    };
    let atom0 = build_atom("atom0", &coords0, 0.5);
    let atom1 = build_atom("atom1", &coords1, 1.7);

    // Expected caches: refresh standalone control penalties directly from each atom.
    let slice0 = PsiSlice::full(coords0.nrows() * latent_dim, Some(latent_dim));
    let control0 = IsometryPenalty::new_euclidean(slice0, p_out);
    refresh_isometry_caches_from_atom(&control0, &atom0, coords0.view()).unwrap();
    let expected0 = control0
        .jacobian_cache()
        .expect("control penalty 0 must have a Jacobian cache");
    let slice1 = PsiSlice::full(coords1.nrows() * latent_dim, Some(latent_dim));
    let control1 = IsometryPenalty::new_euclidean(slice1, p_out);
    refresh_isometry_caches_from_atom(&control1, &atom1, coords1.view()).unwrap();
    let expected1 = control1
        .jacobian_cache()
        .expect("control penalty 1 must have a Jacobian cache");
    // If the atoms produced equal caches, a mis-pairing would be undetectable.
    assert_ne!(
        *expected0, *expected1,
        "atom 0 and atom 1 must produce distinct Jacobian caches"
    );

    let logits = Array2::<f64>::zeros((coords0.nrows(), 2));
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        logits,
        vec![coords0.clone(), coords1.clone()],
        vec![
            LatentManifold::Circle { period: 1.0 },
            LatentManifold::Circle { period: 1.0 },
        ],
        AssignmentMode::ibp_map(0.7, 1.0, true),
    )
    .unwrap();
    let term = SaeManifoldTerm::new(vec![atom0, atom1], assignment).unwrap();
    let mut registry = AnalyticPenaltyRegistry::new();
    let pslice0 = PsiSlice::full(coords0.nrows() * latent_dim, Some(latent_dim));
    let pslice1 = PsiSlice::full(coords1.nrows() * latent_dim, Some(latent_dim));
    registry.push(AnalyticPenaltyKind::Isometry(Arc::new(
        IsometryPenalty::new_euclidean(pslice0, p_out),
    )));
    registry.push(AnalyticPenaltyKind::Isometry(Arc::new(
        IsometryPenalty::new_euclidean(pslice1, p_out),
    )));
    let coords_per_atom = vec![coords0.clone(), coords1.clone()];
    let refreshed =
        refresh_isometry_caches_from_term(&registry, &term, &coords_per_atom).unwrap();
    assert_eq!(refreshed, 2, "both penalties should install second caches");
    let cache0 = match &registry.penalties[0] {
        AnalyticPenaltyKind::Isometry(p) => p
            .jacobian_cache()
            .expect("penalty 0 cache must be populated"),
        _ => panic!("expected isometry penalty at index 0"),
    };
    let cache1 = match &registry.penalties[1] {
        AnalyticPenaltyKind::Isometry(p) => p
            .jacobian_cache()
            .expect("penalty 1 cache must be populated"),
        _ => panic!("expected isometry penalty at index 1"),
    };
    assert_eq!(
        *cache0, *expected0,
        "penalty 0 must be refreshed against atom 0"
    );
    assert_eq!(
        *cache1, *expected1,
        "penalty 1 must be refreshed against atom 1 (regression: old find() paired it to atom 0)"
    );
    assert_ne!(
        *cache0, *cache1,
        "the two penalties must not collapse onto the same atom"
    );
}