use crate::assignment::{AssignmentMode, SaeAssignment};
use crate::chart_canonicalization::{unit_speed_retraction, CanonicalChartTopology};
use crate::manifold::{SaeAtomBasisKind, SaeManifoldAtom, SaeManifoldRho, SaeManifoldTerm};
use gam_terms::latent::LatentManifold;
use ndarray::{array, Array2};
use std::sync::Arc;
use super::tests::{periodic_basis, TestPeriodicEvaluator};
/// Builds a one-atom circle `SaeManifoldTerm` fixture.
///
/// The atom is named `"circle"` and uses the periodic test basis evaluated at `coords_col`.
/// It takes a copy of `decoder`, an identity matrix sized to the basis width, and
/// `TestPeriodicEvaluator` as its basis evaluator. The assignment gives every row the constant
/// logit 2.0, places `coords_col` on a period-1 circle latent, and uses
/// `AssignmentMode::jumprelu(1.0, 0.0)`.
///
/// # Panics
/// Panics if `decoder` does not have exactly one row per basis column, or if any of the
/// atom / assignment / term constructors returns an error.
fn build_circle_term(coords_col: &Array2<f64>, decoder: &Array2<f64>) -> SaeManifoldTerm {
    let row_count = coords_col.nrows();
    let (basis_values, basis_jet) = periodic_basis(coords_col);
    let width = basis_values.ncols();
    assert_eq!(decoder.nrows(), width, "decoder rows must equal basis width");

    let circle_atom = SaeManifoldAtom::new(
        "circle",
        SaeAtomBasisKind::Periodic,
        1,
        basis_values,
        basis_jet,
        decoder.clone(),
        Array2::<f64>::eye(width),
    )
    .unwrap()
    .with_basis_evaluator(Arc::new(TestPeriodicEvaluator));

    // A single logit column, held at 2.0 for every observation row.
    let gate_logits = Array2::<f64>::from_elem((row_count, 1), 2.0);
    let circle_assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        gate_logits,
        vec![coords_col.clone()],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::jumprelu(1.0, 0.0),
    )
    .unwrap();

    SaeManifoldTerm::new(vec![circle_atom], circle_assignment).unwrap()
}
/// Exercises a single-atom circle term (built by `build_circle_term`) end to end.
///
/// - The assembled coordinate gradient for one row, and two β-gradient entries, must agree with
///   central finite differences of `loss(..).total()`.
/// - `canonicalize_atom_unit_speed_chart` must decline (return `false`) for this non-uniform
///   chart. It must leave every loss component, the coordinates and the decoder bit-identical.
/// - `retract_unit_speed_charts_in_loop` must re-gauge zero atoms and be a no-op in the same sense.
/// - `unit_speed_retraction` must return `None` for a decoder the test treats as already
///   unit-speed.
#[test]
fn unit_speed_hook_gradient_consistent_and_noop_safe_2022() {
    /// Panics unless `before` and `after` have the same length and are bit-identical element for
    /// element. The panic message is `"{context} byte-unchanged; drift {max |a - b|}"`.
    ///
    /// A `max |a - b| == 0.0` check is not enough. A value corrupted to NaN produces a NaN
    /// difference, which `f64::max` silently discards, so it would report zero drift. `zip` would
    /// also silently ignore a change in length.
    fn assert_bitwise_unchanged<'a>(
        before: impl IntoIterator<Item = &'a f64>,
        after: impl IntoIterator<Item = &'a f64>,
        context: &str,
    ) {
        let before: Vec<f64> = before.into_iter().copied().collect();
        let after: Vec<f64> = after.into_iter().copied().collect();
        assert_eq!(before.len(), after.len(), "{context}: element count changed");
        let unchanged = before
            .iter()
            .zip(&after)
            .all(|(a, b)| a.to_bits() == b.to_bits());
        let drift = before
            .iter()
            .zip(&after)
            .map(|(a, b)| (a - b).abs())
            .fold(0.0_f64, f64::max);
        assert!(unchanged, "{context} byte-unchanged; drift {drift}");
    }

    // --- Fixture: 8 points on a period-1 circle, a deterministic decoder and target. ---
    let period = 1.0_f64;
    let coords_col = array![
        [0.02_f64],
        [0.10],
        [0.17],
        [0.31],
        [0.55],
        [0.66],
        [0.80],
        [0.95]
    ];
    let n = coords_col.nrows();
    let p = 3usize;
    let (phi0, _) = periodic_basis(&coords_col);
    let m = phi0.ncols();
    let decoder = Array2::<f64>::from_shape_fn((m, p), |(a, b)| {
        0.3 * ((a + 1) as f64) - 0.15 * (b as f64) + 0.05 * ((a * p + b) as f64)
    });
    let mut term = build_circle_term(&coords_col, &decoder);
    let target =
        Array2::<f64>::from_shape_fn((n, p), |(r, c)| 0.2 - 0.05 * (r as f64) + 0.1 * (c as f64));
    let rho = SaeManifoldRho::new(0.0, -4.0, vec![array![0.0]]);

    // Analytic gradients, assembled once at the unperturbed state.
    let sys = term.assemble_arrow_schur(target.view(), &rho, None).unwrap();
    let h = 1.0e-6_f64;

    // --- Coordinate gradient vs. central FD for one row. ---
    // Perturbed coordinates are wrapped back into [0, period). The basis is refreshed after each
    // coordinate change, and the original coordinates are restored afterwards.
    let row = 3usize;
    let base_flat = term.assignment.coords[0].as_matrix().column(0).to_owned();
    let mut fp = base_flat.clone();
    fp[row] = (fp[row] + h).rem_euclid(period);
    term.assignment.coords[0].set_flat(fp.view());
    term.refresh_basis_from_current_coords().unwrap();
    let lp = term.loss(target.view(), &rho).unwrap().total();
    let mut fm = base_flat.clone();
    fm[row] = (fm[row] - h).rem_euclid(period);
    term.assignment.coords[0].set_flat(fm.view());
    term.refresh_basis_from_current_coords().unwrap();
    let lm = term.loss(target.view(), &rho).unwrap().total();
    term.assignment.coords[0].set_flat(base_flat.view());
    term.refresh_basis_from_current_coords().unwrap();
    let gt_fd = (lp - lm) / (2.0 * h);
    // The test reads `gt[1]` of the row block as that row's coordinate component.
    let gt_analytic = sys.rows[row].gt[1];
    assert!(
        gt_fd.abs() > 1.0e-3,
        "coord gradient must be non-trivial so the ARD term is genuinely guarded (got {gt_fd})"
    );
    assert!(
        (gt_analytic - gt_fd).abs() <= 1.0e-4 * (1.0 + gt_fd.abs()),
        "assembled coord gradient {gt_analytic} must match FD {gt_fd} (ARD/coord desync guard)"
    );

    // --- Decoder (β) gradient vs. central FD for two entries. ---
    // The test indexes `gb` row-major over the (m, p) decoder: `bm * p + bp`.
    for &(bm, bp) in &[(0usize, 0usize), (1usize, 1usize)] {
        let beta_idx = bm * p + bp;
        let g_analytic = sys.gb[beta_idx];
        let base = term.atoms[0].decoder_coefficients[[bm, bp]];
        term.atoms[0].decoder_coefficients[[bm, bp]] = base + h;
        let lpp = term.loss(target.view(), &rho).unwrap().total();
        term.atoms[0].decoder_coefficients[[bm, bp]] = base - h;
        let lmm = term.loss(target.view(), &rho).unwrap().total();
        term.atoms[0].decoder_coefficients[[bm, bp]] = base;
        let g_fd = (lpp - lmm) / (2.0 * h);
        assert!(
            (g_analytic - g_fd).abs() <= 1.0e-6 * (1.0 + g_fd.abs()),
            "β-gradient[{bm},{bp}] {g_analytic} must match FD {g_fd}"
        );
    }

    // --- Direct canonicalization must honest-skip and leave the term untouched. ---
    let l0 = term.loss(target.view(), &rho).unwrap();
    let coords0 = term.assignment.coords[0].as_matrix().column(0).to_owned();
    let decoder0 = term.atoms[0].decoder_coefficients.clone();
    let topo = CanonicalChartTopology::Circle { period };
    let applied = term.canonicalize_atom_unit_speed_chart(0, &topo).unwrap();
    assert!(
        !applied,
        "a non-uniform harmonic d=1 chart cannot meet the 1e-9 recomposition gate ⇒ must honest-skip"
    );
    let l1 = term.loss(target.view(), &rho).unwrap();
    assert!((l1.data_fit - l0.data_fit).abs() < 1.0e-12, "honest-skip must not move data_fit");
    assert!((l1.smoothness - l0.smoothness).abs() < 1.0e-12, "honest-skip must not move smoothness");
    assert!((l1.ard - l0.ard).abs() < 1.0e-12, "honest-skip must not move ARD");
    assert!(
        (l1.assignment_sparsity - l0.assignment_sparsity).abs() < 1.0e-12,
        "honest-skip must not move the assignment prior"
    );
    let coords1 = term.assignment.coords[0].as_matrix().column(0).to_owned();
    assert_bitwise_unchanged(coords0.iter(), coords1.iter(), "honest-skip must leave coords");
    assert_bitwise_unchanged(
        decoder0.iter(),
        term.atoms[0].decoder_coefficients.iter(),
        "honest-skip must leave the decoder",
    );

    // --- In-loop hook: nothing can reparam-close, so it must be a full no-op as well. ---
    let n_retracted = term.retract_unit_speed_charts_in_loop().unwrap();
    assert_eq!(n_retracted, 0, "in-loop hook must re-gauge 0 atoms when nothing can reparam-close");
    let l2 = term.loss(target.view(), &rho).unwrap();
    assert!(
        (l2.data_fit - l0.data_fit).abs() < 1.0e-12,
        "in-loop no-op must not move data_fit"
    );
    assert!(
        (l2.smoothness - l0.smoothness).abs() < 1.0e-12,
        "in-loop no-op must not move smoothness"
    );
    assert!((l2.ard - l0.ard).abs() < 1.0e-12, "in-loop no-op must not move ARD");
    assert!(
        (l2.assignment_sparsity - l0.assignment_sparsity).abs() < 1.0e-12,
        "in-loop no-op must not move the assignment prior"
    );
    let coords2 = term.assignment.coords[0].as_matrix().column(0).to_owned();
    assert_bitwise_unchanged(coords0.iter(), coords2.iter(), "in-loop no-op must leave coords");
    assert_bitwise_unchanged(
        decoder0.iter(),
        term.atoms[0].decoder_coefficients.iter(),
        "in-loop no-op must leave the decoder",
    );

    // --- A decoder treated as already unit-speed must take the early-out path. ---
    // NOTE(review): the hard-coded 3 rows assume the periodic test basis has width 3
    // (i.e. m == 3); confirm against `periodic_basis` / `TestPeriodicEvaluator`.
    let uniform_decoder =
        Array2::<f64>::from_shape_vec((3, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0]).unwrap();
    let early = unit_speed_retraction(
        &TestPeriodicEvaluator,
        uniform_decoder.view(),
        coords_col.column(0),
        &topo,
    )
    .unwrap();
    assert!(
        early.is_none(),
        "an already-unit-speed chart must early-out (defect < UNIT_SPEED_INLOOP_DEFECT_TOL)"
    );
}