use crate::assignment::{AssignmentMode, SaeAssignment};
use crate::basis::SaeBasisEvaluator;
use crate::chart_canonicalization::{CanonicalChartTopology, unit_speed_retraction};
use crate::manifold::{SaeAtomBasisKind, SaeManifoldAtom, SaeManifoldRho, SaeManifoldTerm};
use gam_terms::latent::LatentManifold;
use ndarray::{Array2, Array3, Array4, Array5, ArrayView2, array};
use std::sync::Arc;
use super::tests::{TestPeriodicEvaluator, periodic_basis};
/// Builds a single-atom SAE manifold term on a circle latent of period 1.
///
/// The atom's basis values and first jet are seeded from `periodic_basis(coords_col)`.
/// `TestPeriodicEvaluator` is attached so the basis can be re-evaluated when the coordinates
/// move. The assignment uses one constant logit of 2.0 per row and `AssignmentMode::jumprelu(1.0, 0.0)`.
///
/// # Panics
/// If `decoder` does not have one row per basis function, or if any constructor errors.
fn build_circle_term(coords_col: &Array2<f64>, decoder: &Array2<f64>) -> SaeManifoldTerm {
    let (basis_values, basis_jet) = periodic_basis(coords_col);
    let width = basis_values.ncols();
    assert_eq!(decoder.nrows(), width, "decoder rows must equal basis width");

    let circle_atom = SaeManifoldAtom::new(
        "circle",
        SaeAtomBasisKind::Periodic,
        1,
        basis_values,
        basis_jet,
        decoder.clone(),
        Array2::<f64>::eye(width),
    )
    .unwrap()
    .with_basis_evaluator(Arc::new(TestPeriodicEvaluator));

    // One gate column with the same logit for every row.
    let gate_logits = Array2::<f64>::from_elem((coords_col.nrows(), 1), 2.0);
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        gate_logits,
        vec![coords_col.clone()],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::jumprelu(1.0, 0.0),
    )
    .unwrap();

    SaeManifoldTerm::new(vec![circle_atom], assignment).unwrap()
}
/// Regression guard for the d=1 unit-speed chart hook on a periodic (circle) atom.
///
/// Checks, in order:
/// 1. The assembled coordinate gradient for one row is non-trivial and matches a central
///    finite difference of the total loss.
/// 2. Two decoder (β) gradient entries match central finite differences.
/// 3. `canonicalize_atom_unit_speed_chart` declines (returns `false`) on this chart. It must
///    leave every loss component, the coordinates and the decoder unchanged.
/// 4. `retract_unit_speed_charts_in_loop` re-gauges 0 atoms and leaves the loss unchanged.
/// 5. `unit_speed_retraction` returns `None` for a hand-built decoder the test treats as
///    already unit-speed.
#[test]
fn unit_speed_hook_gradient_consistent_and_noop_safe_2022() {
    let period = 1.0_f64;
    let coords_col = array![
        [0.02_f64],
        [0.10],
        [0.17],
        [0.31],
        [0.55],
        [0.66],
        [0.80],
        [0.95]
    ];
    let n = coords_col.nrows();
    let p = 3usize;
    // Only the basis width is needed here; `build_circle_term` evaluates the basis itself.
    let (phi0, _) = periodic_basis(&coords_col);
    let m = phi0.ncols();
    let decoder = Array2::<f64>::from_shape_fn((m, p), |(a, b)| {
        0.3 * ((a + 1) as f64) - 0.15 * (b as f64) + 0.05 * ((a * p + b) as f64)
    });
    let mut term = build_circle_term(&coords_col, &decoder);
    let target =
        Array2::<f64>::from_shape_fn((n, p), |(r, c)| 0.2 - 0.05 * (r as f64) + 0.1 * (c as f64));
    let rho = SaeManifoldRho::new(0.0, -4.0, vec![array![0.0]]);
    // Assemble once at the unperturbed state; all analytic gradients below are read from `sys`.
    let sys = term
        .assemble_arrow_schur(target.view(), &rho, None)
        .unwrap();

    // --- 1. Coordinate gradient vs. central finite difference --------------------------
    let h = 1.0e-6_f64;
    let row = 3usize;
    let base_flat = term.assignment.coords[0].as_matrix().column(0).to_owned();
    // Wrap each perturbed coordinate back into [0, period) so it stays on the circle.
    let mut fp = base_flat.clone();
    fp[row] = (fp[row] + h).rem_euclid(period);
    term.assignment.coords[0].set_flat(fp.view());
    // Re-evaluate the basis at the moved coordinates before taking the loss.
    term.refresh_basis_from_current_coords().unwrap();
    let lp = term.loss(target.view(), &rho).unwrap().total();
    let mut fm = base_flat.clone();
    fm[row] = (fm[row] - h).rem_euclid(period);
    term.assignment.coords[0].set_flat(fm.view());
    term.refresh_basis_from_current_coords().unwrap();
    let lm = term.loss(target.view(), &rho).unwrap().total();
    // Restore the unperturbed coordinates (and basis) before the remaining checks.
    term.assignment.coords[0].set_flat(base_flat.view());
    term.refresh_basis_from_current_coords().unwrap();
    let gt_fd = (lp - lm) / (2.0 * h);
    // NOTE(review): `gt[1]` is assumed to be this row's coordinate-gradient component
    // (index 0 presumably the assignment logit) — confirm against the row layout.
    let gt_analytic = sys.rows[row].gt[1];
    assert!(
        gt_fd.abs() > 1.0e-3,
        "coord gradient must be non-trivial so the ARD term is genuinely guarded (got {gt_fd})"
    );
    assert!(
        (gt_analytic - gt_fd).abs() <= 1.0e-4 * (1.0 + gt_fd.abs()),
        "assembled coord gradient {gt_analytic} must match FD {gt_fd} (ARD/coord desync guard)"
    );

    // --- 2. Decoder (β) gradient vs. central finite difference -------------------------
    // NOTE(review): `sys.gb` is indexed as `bm * p + bp`, i.e. assumed row-major over the
    // (m, p) decoder — confirm. The basis depends only on coordinates, so no refresh is
    // needed when only the decoder is perturbed.
    for &(bm, bp) in &[(0usize, 0usize), (1usize, 1usize)] {
        let beta_idx = bm * p + bp;
        let g_analytic = sys.gb[beta_idx];
        let base = term.atoms[0].decoder_coefficients[[bm, bp]];
        term.atoms[0].decoder_coefficients[[bm, bp]] = base + h;
        let lpp = term.loss(target.view(), &rho).unwrap().total();
        term.atoms[0].decoder_coefficients[[bm, bp]] = base - h;
        let lmm = term.loss(target.view(), &rho).unwrap().total();
        // Restore the exact original coefficient before comparing.
        term.atoms[0].decoder_coefficients[[bm, bp]] = base;
        let g_fd = (lpp - lmm) / (2.0 * h);
        assert!(
            (g_analytic - g_fd).abs() <= 1.0e-6 * (1.0 + g_fd.abs()),
            "β-gradient[{bm},{bp}] {g_analytic} must match FD {g_fd}"
        );
    }

    // --- 3. Honest-skip of the canonicalization must be a strict no-op -----------------
    let l0 = term.loss(target.view(), &rho).unwrap();
    let coords0 = term.assignment.coords[0].as_matrix().column(0).to_owned();
    let decoder0 = term.atoms[0].decoder_coefficients.clone();
    let topo = CanonicalChartTopology::Circle { period };
    let applied = term.canonicalize_atom_unit_speed_chart(0, &topo).unwrap();
    assert!(
        !applied,
        "a non-uniform harmonic d=1 chart cannot meet the 1e-9 recomposition gate ⇒ must honest-skip"
    );
    let l1 = term.loss(target.view(), &rho).unwrap();
    assert!(
        (l1.data_fit - l0.data_fit).abs() < 1.0e-12,
        "honest-skip must not move data_fit"
    );
    assert!(
        (l1.smoothness - l0.smoothness).abs() < 1.0e-12,
        "honest-skip must not move smoothness"
    );
    assert!(
        (l1.ard - l0.ard).abs() < 1.0e-12,
        "honest-skip must not move ARD"
    );
    assert!(
        (l1.assignment_sparsity - l0.assignment_sparsity).abs() < 1.0e-12,
        "honest-skip must not move the assignment prior"
    );
    // Exact (== 0.0) comparisons: a skipped canonicalization must not touch state at all.
    let coords1 = term.assignment.coords[0].as_matrix().column(0).to_owned();
    let cdrift = coords0
        .iter()
        .zip(coords1.iter())
        .map(|(a, b)| (a - b).abs())
        .fold(0.0_f64, f64::max);
    assert!(
        cdrift == 0.0,
        "honest-skip must leave coords byte-unchanged; drift {cdrift}"
    );
    let ddrift = decoder0
        .iter()
        .zip(term.atoms[0].decoder_coefficients.iter())
        .map(|(a, b)| (a - b).abs())
        .fold(0.0_f64, f64::max);
    assert!(
        ddrift == 0.0,
        "honest-skip must leave the decoder byte-unchanged; drift {ddrift}"
    );

    // --- 4. The in-loop hook must also be a no-op here ---------------------------------
    let n_retracted = term.retract_unit_speed_charts_in_loop().unwrap();
    assert_eq!(
        n_retracted, 0,
        "in-loop hook must re-gauge 0 atoms when nothing can reparam-close"
    );
    let l2 = term.loss(target.view(), &rho).unwrap();
    // Checked separately so a failure reports which quantity moved and by how much.
    assert!(
        (l2.data_fit - l0.data_fit).abs() < 1.0e-12,
        "0-atom in-loop hook must not move data_fit: {} vs {}",
        l2.data_fit,
        l0.data_fit
    );
    assert!(
        (l2.ard - l0.ard).abs() < 1.0e-12,
        "0-atom in-loop hook must not move ARD: {} vs {}",
        l2.ard,
        l0.ard
    );

    // --- 5. Early-out for a chart the test treats as already unit-speed ----------------
    // Row 0 is zero; rows 1 and 2 map to the two output axes.
    // NOTE(review): the 3-row shape assumes the periodic basis width is 3 — confirm.
    let uniform_decoder =
        Array2::<f64>::from_shape_vec((3, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0]).unwrap();
    let early = unit_speed_retraction(
        &TestPeriodicEvaluator,
        uniform_decoder.view(),
        coords_col.column(0),
        &topo,
    )
    .unwrap();
    assert!(
        early.is_none(),
        "an already-unit-speed chart must early-out (defect < UNIT_SPEED_INLOOP_DEFECT_TOL)"
    );
}
/// Test basis evaluator for a 1-D latent: the monomials `[1, t, t²]`.
///
/// It returns:
/// - values `[1, t, t²]`;
/// - first jet `[0, 1, 2t]`;
/// - second jet `[0, 0, 2]`;
/// - an all-zero third jet.
///
/// Every method returns an error when the coordinate matrix does not have exactly one column.
#[derive(Debug)]
struct MonomialLineEvaluator;

impl MonomialLineEvaluator {
    /// Number of basis functions (`1`, `t`, `t²`).
    const WIDTH: usize = 3;

    /// Checks that the latent dimension (`ncols`) is 1.
    ///
    /// On failure the error text is `"{context}: expected latent_dim 1, got {ncols}"`.
    fn require_line(ncols: usize, context: &str) -> Result<(), String> {
        if ncols == 1 {
            Ok(())
        } else {
            Err(format!("{context}: expected latent_dim 1, got {ncols}"))
        }
    }
}

impl SaeBasisEvaluator for MonomialLineEvaluator {
    fn second_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array4<f64>, String>> {
        Some(
            Self::require_line(coords.ncols(), "MonomialLineEvaluator::second_jet_dyn").map(
                |()| {
                    // Only d²(t²)/dt² = 2 is non-zero.
                    Array4::<f64>::from_shape_fn(
                        (coords.nrows(), Self::WIDTH, 1, 1),
                        |(_, k, _, _)| if k == 2 { 2.0 } else { 0.0 },
                    )
                },
            ),
        )
    }

    fn third_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array5<f64>, String>> {
        // Every third derivative of a polynomial of degree <= 2 vanishes.
        Some(
            Self::require_line(coords.ncols(), "MonomialLineEvaluator::third_jet_dyn")
                .map(|()| Array5::<f64>::zeros((coords.nrows(), Self::WIDTH, 1, 1, 1))),
        )
    }

    fn evaluate(&self, coords: ArrayView2<'_, f64>) -> Result<(Array2<f64>, Array3<f64>), String> {
        Self::require_line(coords.ncols(), "MonomialLineEvaluator")?;
        let rows = coords.nrows();
        let values = Array2::<f64>::from_shape_fn((rows, Self::WIDTH), |(r, k)| {
            let t = coords[[r, 0]];
            match k {
                0 => 1.0,
                1 => t,
                _ => t * t,
            }
        });
        let slopes = Array3::<f64>::from_shape_fn((rows, Self::WIDTH, 1), |(r, k, _)| match k {
            0 => 0.0,
            1 => 1.0,
            _ => 2.0 * coords[[r, 0]],
        });
        Ok((values, slopes))
    }
}
/// Builds a single-atom SAE manifold term on a Euclidean 1-D latent with the monomial basis.
///
/// The atom's basis values and first jet are seeded by evaluating `MonomialLineEvaluator` at
/// `coords_col`. The same evaluator is attached for later re-evaluation. The assignment uses
/// one constant logit of 2.0 per row and `AssignmentMode::jumprelu(1.0, 0.0)`.
///
/// # Panics
/// If the basis evaluation fails, if `decoder` does not have one row per basis function,
/// or if any constructor errors.
fn build_line_term(coords_col: &Array2<f64>, decoder: &Array2<f64>) -> SaeManifoldTerm {
    let evaluator = MonomialLineEvaluator;
    let (values, first_jet) = evaluator
        .evaluate(coords_col.view())
        .expect("monomial basis evaluates");
    let basis_width = values.ncols();
    assert_eq!(decoder.nrows(), basis_width, "decoder rows must equal basis width");

    let line_atom = SaeManifoldAtom::new(
        "line",
        SaeAtomBasisKind::Linear,
        1,
        values,
        first_jet,
        decoder.clone(),
        Array2::<f64>::eye(basis_width),
    )
    .unwrap()
    .with_basis_evaluator(Arc::new(evaluator));

    // One gate column with the same logit for every row.
    let gate_logits = Array2::<f64>::from_elem((coords_col.nrows(), 1), 2.0);
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        gate_logits,
        vec![coords_col.clone()],
        vec![LatentManifold::Euclidean],
        AssignmentMode::jumprelu(1.0, 0.0),
    )
    .unwrap();

    SaeManifoldTerm::new(vec![line_atom], assignment).unwrap()
}
/// Active-path counterpart to the circle honest-skip test.
///
/// Setup: a monomial line chart (basis `[1, t, t²]`) whose decoder `[0, 0, 1]ᵀ` selects the
/// `t²` function. On that chart `canonicalize_atom_unit_speed_chart` must fire (return
/// `true`). The test then asserts:
/// - the coordinates move by more than 1e-3;
/// - data_fit, smoothness and the assignment prior are unchanged;
/// - the total loss changes only through the ARD term;
/// - the ARD delta equals `0.5 · α · Σᵢ (c1ᵢ² − c0ᵢ²)`, with `α = stable_exp_strength(0.0)`;
/// - the in-loop hook then re-gauges 0 atoms on the freshly retracted chart.
#[test]
fn unit_speed_active_retraction_moves_only_ard_2070() {
    let coords_col = array![
        [0.15_f64],
        [0.30],
        [0.42],
        [0.58],
        [0.71],
        [0.88]
    ];
    let n = coords_col.nrows();
    let p = 1usize;
    // (3, 1) decoder: only the t² basis function has a non-zero coefficient.
    let decoder = Array2::<f64>::from_shape_vec((3, p), vec![0.0, 0.0, 1.0]).unwrap();
    let mut term = build_line_term(&coords_col, &decoder);
    // `as` binds tighter than `*`, so this is 0.10 + 0.05 * (r as f64).
    let target = Array2::<f64>::from_shape_fn((n, p), |(r, _)| 0.10 + 0.05 * r as f64);
    let rho = SaeManifoldRho::new(0.0, -4.0, vec![array![0.0]]);
    // Snapshot loss and coordinates before the retraction.
    let l0 = term.loss(target.view(), &rho).unwrap();
    let coords0 = term.assignment.coords[0].as_matrix().column(0).to_owned();
    let topo = CanonicalChartTopology::Interval;
    let applied = term.canonicalize_atom_unit_speed_chart(0, &topo).unwrap();
    assert!(
        applied,
        "the monomial line chart's image is affine in arc-length ⇒ the ACTIVE arc-length \
retraction MUST fire (this is the reachable faithful d=1 active path)"
    );
    let l1 = term.loss(target.view(), &rho).unwrap();
    let coords1 = term.assignment.coords[0].as_matrix().column(0).to_owned();
    // Largest absolute per-row coordinate change caused by the retraction.
    let cdrift = coords0
        .iter()
        .zip(coords1.iter())
        .map(|(a, b)| (a - b).abs())
        .fold(0.0_f64, f64::max);
    assert!(
        cdrift > 1.0e-3,
        "the active retraction must MOVE the coordinates; max drift {cdrift}"
    );
    // Components that must be invariant under the retraction (relative tolerances).
    assert!(
        (l1.data_fit - l0.data_fit).abs() <= 1.0e-6 * (1.0 + l0.data_fit.abs()),
        "data_fit must be invariant under the image-frozen retraction: {} vs {}",
        l1.data_fit,
        l0.data_fit
    );
    assert!(
        (l1.smoothness - l0.smoothness).abs() <= 1.0e-6 * (1.0 + l0.smoothness.abs()),
        "smoothness must be invariant (transport preserves BᵀSB): {} vs {}",
        l1.smoothness,
        l0.smoothness
    );
    assert!(
        (l1.assignment_sparsity - l0.assignment_sparsity).abs()
            <= 1.0e-9 * (1.0 + l0.assignment_sparsity.abs()),
        "the assignment prior must be invariant under the retraction"
    );
    // The ARD term is the one component expected to change.
    let ard_delta = l1.ard - l0.ard;
    assert!(
        ard_delta.abs() > 1.0e-6,
        "the ARD coordinate prior MUST move under the active retraction (it pins the \
residual gauge); delta {ard_delta}"
    );
    let total_delta = l1.total() - l0.total();
    assert!(
        (total_delta - ard_delta).abs() <= 1.0e-6 * (1.0 + ard_delta.abs()),
        "the retraction must move ONLY the ARD prior; total Δ {total_delta} vs ARD Δ {ard_delta}"
    );
    // Independent recomputation of the ARD delta as a Euclidean quadratic energy
    // 0.5 · α · Σ c², with α taken from the same rho value (0.0) used above.
    let alpha = SaeManifoldRho::stable_exp_strength(0.0);
    let expected: f64 = (0..n)
        .map(|i| 0.5 * alpha * (coords1[i] * coords1[i] - coords0[i] * coords0[i]))
        .sum();
    assert!(
        (ard_delta - expected).abs() <= 1.0e-9 * (1.0 + expected.abs()),
        "ARD delta {ard_delta} must equal the Euclidean coordinate-energy delta {expected} \
at the reparam'd coords (confirms it is the reparam effect, not a bookkeeping bug)"
    );
    // Running the in-loop hook right after the retraction must re-gauge nothing.
    let n_retracted = term.retract_unit_speed_charts_in_loop().unwrap();
    assert_eq!(
        n_retracted, 0,
        "a freshly-retracted chart is already unit-speed ⇒ the hook is idempotent (0 re-gauges)"
    );
}