use crate::linalg::faer_ndarray::fast_ata;
use super::*;
use crate::solver::arrow_schur::{
ArrowFactorSlab, ArrowHtbetaCache, ArrowSolverMode, ArrowUndampedFactors, PcgDiagnostics,
};
use crate::terms::analytic_penalties::ARDPenalty;
use crate::terms::analytic_penalties::IsometryReference;
use approx::assert_abs_diff_eq;
use ndarray::{Array5, array};
/// Compares `bessel_i0_log_and_ratio` with the naive `bessel_i0` / `bessel_i1`
/// formulas on arguments where the naive forms are asserted finite, then
/// checks that the stable form stays finite (with a ratio in `(0, 1]`) on
/// arguments where the naive `bessel_i0` is asserted to be non-finite.
#[test]
pub(crate) fn bessel_log_and_ratio_is_finite_and_matches_naive() {
    // Regime 1: naive evaluation is required to be finite and to agree.
    let moderate_etas = [0.0_f64, 0.5, 1.0, 3.0, 3.75, 5.0, 20.0, 100.0, 300.0];
    for eta in moderate_etas.iter().copied() {
        let stable = bessel_i0_log_and_ratio(eta);
        let naive_log = bessel_i0(eta).ln();
        let naive_ratio = bessel_i1(eta) / bessel_i0(eta);
        assert!(naive_log.is_finite(), "naive log finite at η={eta}");
        assert!(naive_ratio.is_finite(), "naive ratio finite at η={eta}");
        let (log_i0, ratio) = stable;
        assert_abs_diff_eq!(log_i0, naive_log, epsilon = 1e-9);
        assert_abs_diff_eq!(ratio, naive_ratio, epsilon = 1e-9);
    }

    // Regime 2: the naive I0 must be non-finite here, otherwise the test no
    // longer exercises the case the stable form exists for.
    let overflow_etas = [710.0_f64, 1.0e3, 1.0e6, 1.0e12, 1.0e300];
    for eta in overflow_etas.iter().copied() {
        assert!(
            !bessel_i0(eta).is_finite(),
            "naive I0 expected to overflow at η={eta} (guards the regression)"
        );
        let (log_i0, ratio) = bessel_i0_log_and_ratio(eta);
        assert!(log_i0.is_finite(), "stable log I0 finite at η={eta}");
        assert!(ratio.is_finite(), "stable I1/I0 finite at η={eta}");
        assert!(0.0 < ratio && ratio <= 1.0, "ratio in (0,1] at η={eta}");
    }
}
/// Asserts that `left` and `right` have the same dimensions and that every
/// pair of corresponding entries has an identical `f64` bit pattern.
///
/// The comparison goes through `to_bits()`, so it is stricter than `==`
/// (for example `0.0` and `-0.0` are reported as different).
pub(crate) fn assert_matrix_same_bits(left: &Array2<f64>, right: &Array2<f64>) {
    assert_eq!(left.dim(), right.dim());
    left.indexed_iter().for_each(|((row, col), lhs)| {
        let rhs = right[[row, col]];
        assert_eq!(
            lhs.to_bits(),
            rhs.to_bits(),
            "matrix bits differ at ({row}, {col})"
        );
    });
}
/// Asserts that two rank-3 tensors have the same dimensions and that every
/// pair of corresponding entries has an identical `f64` bit pattern.
///
/// Like its matrix counterpart, the check uses `to_bits()` rather than `==`.
pub(crate) fn assert_tensor3_same_bits(left: &Array3<f64>, right: &Array3<f64>) {
    assert_eq!(left.dim(), right.dim());
    left.indexed_iter().for_each(|((row, col, axis), lhs)| {
        let rhs = right[[row, col, axis]];
        assert_eq!(
            lhs.to_bits(),
            rhs.to_bits(),
            "tensor bits differ at ({row}, {col}, {axis})"
        );
    });
}
/// Checks that `evaluate_phi_eta(coords, 1.0)` reproduces `evaluate(coords)`.
///
/// Asserted, in order:
/// * `phi` and `jet` from the eta evaluation are bit-identical to the base
///   evaluation;
/// * the number of curved columns in the returned split equals
///   `expected_curved`;
/// * for every linear column, `dphi_deta` and `djet_deta` compare equal to
///   `0.0`;
/// * for every curved column, `dphi_deta` / `djet_deta` are bit-identical to
///   the base `phi` / `jet` entries.
pub(crate) fn assert_eta_one_parity(
    evaluator: &dyn SaeBasisEvaluator,
    coords: ArrayView2<'_, f64>,
    expected_curved: usize,
) {
    let (phi, jet) = evaluator.evaluate(coords).expect("base evaluate");
    let at_one = evaluator
        .evaluate_phi_eta(coords, 1.0)
        .expect("eta evaluate");
    assert_matrix_same_bits(&at_one.phi, &phi);
    assert_tensor3_same_bits(&at_one.jet, &jet);
    assert_eq!(at_one.split.curved_cols.len(), expected_curved);

    let n_rows = phi.nrows();
    let n_axes = jet.shape()[2];

    // Linear columns: both eta-derivatives must vanish.
    for &col in &at_one.split.linear_cols {
        for row in 0..n_rows {
            assert_eq!(at_one.dphi_deta[[row, col]], 0.0);
            for axis in 0..n_axes {
                assert_eq!(at_one.djet_deta[[row, col, axis]], 0.0);
            }
        }
    }

    // Curved columns: the eta-derivatives must equal the base values exactly.
    for &col in &at_one.split.curved_cols {
        for row in 0..n_rows {
            let dphi_bits = at_one.dphi_deta[[row, col]].to_bits();
            assert_eq!(dphi_bits, phi[[row, col]].to_bits());
            for axis in 0..n_axes {
                let djet_bits = at_one.djet_deta[[row, col, axis]].to_bits();
                assert_eq!(djet_bits, jet[[row, col, axis]].to_bits());
            }
        }
    }
}
/// Runs `assert_eta_one_parity` over each basis evaluator exercised here
/// (periodic harmonic, raw periodic circle, torus harmonic, sphere chart,
/// Duchon, Euclidean patch) with the curved-column count expected for each.
#[test]
pub(crate) fn phi_eta_one_reproduces_current_atom_bases_bit_for_bit() {
    {
        let coords = array![[0.0_f64], [0.125], [0.4]];
        let evaluator = PeriodicHarmonicEvaluator::new(7).unwrap();
        assert_eta_one_parity(&evaluator, coords.view(), 4);
    }
    {
        let coords = array![[0.0_f64], [0.3], [1.1]];
        let evaluator = RawPeriodicCircleEvaluator::new(1).unwrap();
        assert_eta_one_parity(&evaluator, coords.view(), 0);
    }
    {
        let coords = array![[0.0_f64, 0.2], [0.25, 0.5], [0.7, 0.9]];
        let evaluator = TorusHarmonicEvaluator::new(2, 2).unwrap();
        assert_eta_one_parity(&evaluator, coords.view(), 20);
    }
    {
        let coords = array![[0.0_f64, 0.0], [0.3, 0.4], [-0.2, 1.1]];
        let evaluator = SphereChartEvaluator;
        assert_eta_one_parity(&evaluator, coords.view(), 3);
    }

    // The Duchon and Euclidean-patch cases share the same planar coordinates.
    let centers = array![
        [-1.0_f64, -1.0],
        [1.0, -1.0],
        [-1.0, 1.0],
        [1.0, 1.0],
        [0.0, 0.0],
        [0.5, -0.25]
    ];
    let planar_coords = array![[0.1_f64, 0.2], [0.4, -0.3], [-0.2, 0.7]];

    let duchon = DuchonCoordinateEvaluator::new(centers, 2).unwrap();
    let duchon_total_cols = duchon.evaluate(planar_coords.view()).unwrap().0.ncols();
    // Three columns are excluded from the curved count for the Duchon basis.
    let duchon_poly = 3usize;
    assert_eta_one_parity(
        &duchon,
        planar_coords.view(),
        duchon_total_cols - duchon_poly,
    );

    let euclidean = EuclideanPatchEvaluator::new(2, 3).unwrap();
    let exponents = || crate::basis::monomial_exponents(2, 3);
    let total_cols = exponents().len();
    // Monomials of total degree <= 1 are the ones counted as linear.
    let linear_cols = exponents()
        .iter()
        .filter(|alpha| alpha.iter().sum::<usize>() <= 1)
        .count();
    assert_eta_one_parity(&euclidean, planar_coords.view(), total_cols - linear_cols);
}
/// Builds a minimal one-atom term for tests.
///
/// The single `EuclideanPatch` atom ("atom0") is constructed from an all-ones
/// `4 x 2` array, an all-zero `4 x 2 x 1` array, an all-zero `2 x 3` array and
/// a `2 x 2` identity. The assignment has zero logits, a single zero
/// coordinate block on `LatentManifold::Euclidean`, and `softmax(1.0)` mode.
pub(crate) fn trivial_k1_euclidean_term() -> SaeManifoldTerm {
    let (n, p) = (4usize, 3usize);

    let phi = Array2::<f64>::ones((n, 2));
    let jet = Array3::<f64>::zeros((n, 2, 1));
    let decoder = Array2::<f64>::zeros((2, p));
    let smooth = Array2::<f64>::eye(2);
    let atom = SaeManifoldAtom::new(
        "atom0",
        SaeAtomBasisKind::EuclideanPatch,
        1,
        phi,
        jet,
        decoder,
        smooth,
    )
    .unwrap();

    let logits = Array2::<f64>::zeros((n, 1));
    let coord_blocks = vec![Array2::<f64>::zeros((n, 1))];
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        logits,
        coord_blocks,
        vec![LatentManifold::Euclidean],
        AssignmentMode::softmax(1.0),
    )
    .unwrap();

    SaeManifoldTerm::new(vec![atom], assignment).unwrap()
}
/// Exercises `record_evidence_gauge_deflation_count`.
///
/// Starting from an unset expectation, the counts 5, 5, 6, 9, 7 are each
/// accepted and leave the stored expectation equal to the count just
/// recorded. The following count (8) must be rejected with an error that
/// mentions both "not stabilizing" and "re-anchored", and the stored
/// expectation must remain at 7.
#[test]
pub(crate) fn evidence_gauge_deflation_count_guard_reanchors_then_rejects_runaway() {
    let mut term = trivial_k1_euclidean_term();
    assert!(term.expected_evidence_gauge_deflated_directions.is_none());

    // Every one of these must succeed and become the stored expectation.
    for count in [5, 5, 6, 9, 7] {
        term.record_evidence_gauge_deflation_count(count).unwrap();
        assert_eq!(
            term.expected_evidence_gauge_deflated_directions,
            Some(count)
        );
    }

    let err = term
        .record_evidence_gauge_deflation_count(8)
        .expect_err("a runaway quotient dimension must error past the re-anchor budget");
    let reports_runaway = err.contains("not stabilizing") && err.contains("re-anchored");
    assert!(
        reports_runaway,
        "guard must report the runaway re-anchoring explicitly; got: {err}"
    );
    // A rejected count must not overwrite the previous expectation.
    assert_eq!(term.expected_evidence_gauge_deflated_directions, Some(7));
}
/// Checks `curvature_homotopy_eta_is_inert` on three terms: it must return
/// `true` for the trivial K=1 Euclidean term and for the small two-atom
/// periodic fixture, and `false` for a single periodic atom that carries a
/// `PeriodicHarmonicEvaluator::new(7)` basis evaluator.
#[test]
pub(crate) fn curvature_homotopy_eta_inertness_probe_tracks_curved_columns() {
    {
        let term = trivial_k1_euclidean_term();
        assert!(term.curvature_homotopy_eta_is_inert().unwrap());
    }
    {
        let (term, _target, _rho) = small_two_atom_periodic_term();
        assert!(term.curvature_homotopy_eta_is_inert().unwrap());
    }

    // Single periodic atom whose basis comes from the 7-function evaluator.
    let coords = array![[0.05], [0.20], [0.55], [0.80], [0.35]];
    let n_obs = coords.nrows();
    let evaluator = Arc::new(PeriodicHarmonicEvaluator::new(7).unwrap());
    let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();
    let decoder = Array2::<f64>::zeros((7, 1));
    let smooth = Array2::<f64>::eye(7);
    let atom = SaeManifoldAtom::new(
        "periodic7",
        SaeAtomBasisKind::Periodic,
        1,
        phi,
        jet,
        decoder,
        smooth,
    )
    .unwrap()
    .with_basis_evaluator(evaluator);
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((n_obs, 1)),
        vec![coords],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::softmax(1.0),
    )
    .unwrap();
    let term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();
    assert!(!term.curvature_homotopy_eta_is_inert().unwrap());
}
/// Feeds `linear_span_anchor` a two-atom Euclidean term and a diagonal
/// `4 x 4` target, then asserts that the anchor has two atoms, a residual
/// norm of (numerically) zero, and that each atom's frame has zero maximal
/// principal angle to the coordinate plane spanned by axes (0, 1) and (2, 3)
/// respectively.
#[test]
pub(crate) fn linear_span_anchor_recovers_planted_two_plane_configuration() {
    let (n, p) = (4usize, 4usize);

    // Both atoms are built from identical ingredients; only the name differs.
    let phi = Array2::<f64>::ones((n, 2));
    let jet = Array3::<f64>::zeros((n, 2, 1));
    let decoder = Array2::<f64>::zeros((2, p));
    let smooth = Array2::<f64>::eye(2);
    let atoms: Vec<_> = ["plane0", "plane1"]
        .iter()
        .map(|&name| {
            SaeManifoldAtom::new(
                name,
                SaeAtomBasisKind::EuclideanPatch,
                1,
                phi.clone(),
                jet.clone(),
                decoder.clone(),
                smooth.clone(),
            )
            .unwrap()
        })
        .collect();

    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((n, 2)),
        vec![Array2::<f64>::zeros((n, 1)), Array2::<f64>::zeros((n, 1))],
        vec![LatentManifold::Euclidean, LatentManifold::Euclidean],
        AssignmentMode::softmax(1.0),
    )
    .unwrap();
    let term = SaeManifoldTerm::new(atoms, assignment).unwrap();

    let target = array![
        [3.0_f64, 0.0, 0.0, 0.0],
        [0.0, 2.0, 0.0, 0.0],
        [0.0, 0.0, 1.5, 0.0],
        [0.0, 0.0, 0.0, 1.0]
    ];
    let anchor = linear_span_anchor(&term, target.view()).unwrap();
    assert_eq!(anchor.atoms.len(), 2);
    assert_abs_diff_eq!(anchor.residual_norm_sq, 0.0, epsilon = 1.0e-18);

    // Reference planes: span{e0, e1} and span{e2, e3}.
    let plane0 = array![[1.0_f64, 0.0], [0.0, 1.0], [0.0, 0.0], [0.0, 0.0]];
    let plane1 = array![[0.0_f64, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 1.0]];
    let frame0 = &anchor.atoms[0].frame;
    let frame1 = &anchor.atoms[1].frame;
    let angle0 = frame0.max_principal_angle(plane0.view()).unwrap();
    let angle1 = frame1.max_principal_angle(plane1.view()).unwrap();
    assert_abs_diff_eq!(angle0, 0.0, epsilon = 1.0e-12);
    assert_abs_diff_eq!(angle1, 0.0, epsilon = 1.0e-12);
}
/// Builds a term with one periodic atom per entry of `planes`.
///
/// All atoms share 16 coordinates `row / 16` on a period-1 circle and the
/// basis produced by `PeriodicHarmonicEvaluator::new(3)`. For a plane
/// `(axis_sin, axis_cos)` the `3 x 4` decoder is zero except for `radius` at
/// `[1, axis_sin]` and `[2, axis_cos]`.
/// NOTE(review): the parameter names suggest basis rows 1 and 2 are the sine
/// and cosine harmonics of that evaluator — confirm against its definition.
///
/// The returned term has its certificate dispersion set to `1.0`.
pub(crate) fn circle_certificate_fixture(
    radius: f64,
    planes: &[(usize, usize)],
) -> SaeManifoldTerm {
    let n = 16usize;
    let p = 4usize;
    let n_atoms = planes.len();

    let evaluator = Arc::new(PeriodicHarmonicEvaluator::new(3).unwrap());
    let coords = Array2::<f64>::from_shape_fn((n, 1), |(row, _)| row as f64 / n as f64);
    let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();

    let atoms: Vec<_> = planes
        .iter()
        .enumerate()
        .map(|(atom_idx, &(axis_sin, axis_cos))| {
            let mut decoder = Array2::<f64>::zeros((3, p));
            decoder[[1, axis_sin]] = radius;
            decoder[[2, axis_cos]] = radius;
            SaeManifoldAtom::new(
                format!("circle_{atom_idx}"),
                SaeAtomBasisKind::Periodic,
                1,
                phi.clone(),
                jet.clone(),
                decoder,
                Array2::<f64>::eye(3),
            )
            .unwrap()
            .with_basis_second_jet(evaluator.clone())
        })
        .collect();
    // Every atom gets its own copy of the same coordinate block.
    let coord_blocks: Vec<_> = planes.iter().map(|_| coords.clone()).collect();

    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((n, n_atoms)),
        coord_blocks,
        vec![LatentManifold::Circle { period: 1.0 }; n_atoms],
        AssignmentMode::softmax(1.0),
    )
    .unwrap();

    let mut term = SaeManifoldTerm::new(atoms, assignment).unwrap();
    term.set_certificate_dispersion(1.0).unwrap();
    term
}
/// Two radius-2 circles in the disjoint planes (0, 1) and (2, 3): the report
/// must give `mu_hat == 0`, two per-atom kappa values, and a
/// `global_optimality` verdict equal to what
/// `curved_dictionary_global_optimality_verdict` returns for the reported
/// quantities. When the reported SNR proxy exceeds 1 the verdict must also be
/// certified.
#[test]
pub(crate) fn dictionary_incoherence_report_orthogonal_frames_has_zero_mu_hat() {
    let disjoint_planes = [(0, 1), (2, 3)];
    let term = circle_certificate_fixture(2.0, &disjoint_planes);
    let report = dictionary_incoherence_report(&term).unwrap();
    assert_abs_diff_eq!(report.mu_hat, 0.0, epsilon = 1.0e-12);
    assert_eq!(report.per_atom_kappa_hat.len(), 2);

    // Largest per-atom kappa, floored at zero.
    let mut kappa_max = 0.0_f64;
    for &kappa in report.per_atom_kappa_hat.iter() {
        kappa_max = f64::max(kappa_max, kappa);
    }

    let recomputed = curved_dictionary_global_optimality_verdict(
        report.mu_hat,
        kappa_max,
        report.peak_activity_floor,
        report.snr_proxy,
        report.per_atom_kappa_hat.len(),
    );
    assert_eq!(report.global_optimality, recomputed);

    if report.snr_proxy > 1.0 {
        let certified = report.global_optimality.is_certified();
        assert!(
            certified,
            "μ̂=0, κ̂=0.5<1, SNR>1 ⇒ must certify; got {}",
            report.note
        );
    }
}
/// Two radius-2 circles placed in the same plane (0, 1): the report must give
/// `mu_hat == 1`.
#[test]
pub(crate) fn dictionary_incoherence_report_coherent_frames_has_unit_mu_hat() {
    let coincident_planes = [(0, 1), (0, 1)];
    let term = circle_certificate_fixture(2.0, &coincident_planes);
    let report = dictionary_incoherence_report(&term).unwrap();
    assert_abs_diff_eq!(report.mu_hat, 1.0, epsilon = 1.0e-12);
}
/// A single circle of radius 2.5 (certificate dispersion overridden to 0.25):
/// the reported per-atom kappa must equal `1 / radius`, the SNR proxy must be
/// finite and positive, and both activity floors must equal 1.
#[test]
pub(crate) fn dictionary_incoherence_report_circle_kappa_matches_inverse_radius() {
    let radius = 2.5_f64;
    let single_plane = [(0, 1)];
    let mut term = circle_certificate_fixture(radius, &single_plane);
    // Replace the fixture's default dispersion before computing the report.
    term.set_certificate_dispersion(0.25).unwrap();

    let report = dictionary_incoherence_report(&term).unwrap();
    assert_abs_diff_eq!(
        report.per_atom_kappa_hat[0],
        1.0 / radius,
        epsilon = 1.0e-10
    );
    assert!(report.snr_proxy.is_finite() && report.snr_proxy > 0.0);
    assert_abs_diff_eq!(report.mean_activity_floor, 1.0, epsilon = 1.0e-12);
    assert_abs_diff_eq!(report.peak_activity_floor, 1.0, epsilon = 1.0e-12);
}
/// `SearchStrategy::Fixed` must report `is_fixed()`, while an
/// `ExponentialSweep` must not and must expose its values through
/// `sweep_values()`.
#[test]
pub(crate) fn search_strategy_exposes_fixed_and_sweep_values() {
    assert!(SearchStrategy::Fixed.is_fixed());

    let sweep = vec![0.1, 1.0, 10.0];
    let strategy = SearchStrategy::ExponentialSweep {
        values: sweep.clone(),
    };
    assert!(!strategy.is_fixed());
    assert_eq!(strategy.sweep_values(), Some(sweep.as_slice()));
}
/// For a single-row, single-atom assignment, checks the first assignment
/// value under three modes: `ibp_map(1.0, 1.0, false)` with logit 0 gives
/// 0.5, `jumprelu(1.0, 0.0)` with logit -1 gives 0, and `softmax(1.0)` with
/// logit 0 gives 1.
#[test]
pub(crate) fn k1_gate_modes_do_not_pin_assignment_to_one() {
    // All three cases share the same 1 x 1 zero coordinate block.
    let single_row = |logits: Array2<f64>, mode: AssignmentMode| {
        SaeAssignment::from_blocks_with_mode(logits, vec![array![[0.0]]], mode).unwrap()
    };

    let ibp = single_row(array![[0.0]], AssignmentMode::ibp_map(1.0, 1.0, false));
    assert_abs_diff_eq!(ibp.try_assignments_row(0).unwrap()[0], 0.5, epsilon = 1e-9);

    let jr = single_row(array![[-1.0]], AssignmentMode::jumprelu(1.0, 0.0));
    assert_abs_diff_eq!(jr.try_assignments_row(0).unwrap()[0], 0.0, epsilon = 1e-12);

    let sm = single_row(Array2::<f64>::zeros((1, 1)), AssignmentMode::softmax(1.0));
    assert_abs_diff_eq!(sm.try_assignments_row(0).unwrap()[0], 1.0, epsilon = 1e-12);
}
/// With threshold 2 and temperature 1, a logit `1e-6` above the threshold
/// must gate to about 0.5 (and strictly below 0.6), and a logit of 1 (below
/// the threshold) must gate to 0.
#[test]
pub(crate) fn jumprelu_surrogate_is_centered_at_threshold() {
    let (temperature, threshold) = (1.0, 2.0);
    // First logit sits just above the threshold, second one well below it.
    let logits = array![threshold + 1e-6, 1.0];
    let gates = jumprelu_row(logits.view(), temperature, threshold);

    assert_abs_diff_eq!(gates[0], 0.5, epsilon = 1e-3);
    let near_half = gates[0] < 0.6;
    assert!(
        near_half,
        "surrogate not centered at threshold: {}",
        gates[0]
    );
    assert_abs_diff_eq!(gates[1], 0.0, epsilon = 1e-12);
}
/// Evaluates a three-function periodic basis and its first derivative at the
/// first column of `coords`.
///
/// For each row, with `x = coords[row, 0] mod 1` and `angle = 2π·x`:
/// * `phi[row, ..] = [1, sin(angle), cos(angle)]`
/// * `jet[row, .., 0] = [0, 2π·cos(angle), -2π·sin(angle)]`
///
/// Returns `(phi, jet)` with shapes `(n, 3)` and `(n, 3, 1)`.
pub(crate) fn periodic_basis(coords: &Array2<f64>) -> (Array2<f64>, Array3<f64>) {
    // TAU is exactly 2.0 * PI in f64, so the products below are unchanged.
    let two_pi = std::f64::consts::TAU;
    let n = coords.nrows();
    let mut phi = Array2::<f64>::zeros((n, 3));
    let mut jet = Array3::<f64>::zeros((n, 3, 1));
    for row in 0..n {
        // Reduce into [0, 1) before converting to an angle.
        let angle = two_pi * coords[[row, 0]].rem_euclid(1.0);
        let (sin_a, cos_a) = (angle.sin(), angle.cos());
        phi[[row, 0]] = 1.0;
        phi[[row, 1]] = sin_a;
        phi[[row, 2]] = cos_a;
        // The constant column's derivative stays at its zero initialisation.
        jet[[row, 1, 0]] = two_pi * cos_a;
        jet[[row, 2, 0]] = -two_pi * sin_a;
    }
    (phi, jet)
}
/// Checks `ArdAxisPrior::eval` on a period-1 axis:
/// * value, gradient and Hessian just below and just above `t = period`
///   agree within `10·alpha·eps`, and both gradients are that small;
/// * the value just below the cut matches the value at `t = 0`;
/// * at `t = 0` the value and gradient are 0 and the Hessian equals `alpha`;
/// * `sq_equiv` at `t = 0.3` is the same for `alpha = 1` and `alpha = 5`;
/// * `0.5 · alpha · sq_equiv` equals `value` at `t = 0.3`.
#[test]
pub(crate) fn ard_axis_prior_periodic_is_continuous_across_cut() {
    let alpha = 2.3_f64;
    let period = 1.0_f64;
    let eps = 1.0e-6;
    let prior_at = |t: f64| ArdAxisPrior::eval(alpha, t, Some(period));

    let below = prior_at(period - eps);
    let above = prior_at(period + eps);
    let at_zero = prior_at(0.0);

    // Continuity across the cut, to first order in eps.
    let cont_tol = 10.0 * alpha * eps;
    assert!((below.value - above.value).abs() < cont_tol);
    assert!((below.grad - above.grad).abs() < cont_tol);
    assert!((below.hess - above.hess).abs() < cont_tol);
    assert!(below.grad.abs() < cont_tol);
    assert!(above.grad.abs() < cont_tol);

    assert_abs_diff_eq!(below.value, at_zero.value, epsilon = 1.0e-9);
    assert_abs_diff_eq!(at_zero.value, 0.0, epsilon = 1.0e-12);
    assert_abs_diff_eq!(at_zero.grad, 0.0, epsilon = 1.0e-12);
    assert_abs_diff_eq!(at_zero.hess, alpha, epsilon = 1.0e-12);

    // sq_equiv must not depend on the precision alpha.
    let sq_equiv_for = |a: f64| ArdAxisPrior::eval(a, 0.3, Some(period)).sq_equiv;
    let sq_a = sq_equiv_for(1.0);
    let sq_b = sq_equiv_for(5.0);
    assert_abs_diff_eq!(sq_a, sq_b, epsilon = 1.0e-12);

    let p = prior_at(0.3);
    assert_abs_diff_eq!(0.5 * alpha * p.sq_equiv, p.value, epsilon = 1.0e-12);
}
/// Central finite-difference check of `ArdAxisPrior::eval`: for no period,
/// period 1 and period `TAU`, and for several `t`, the analytic `grad` must
/// match the difference quotient of `value`, and the analytic `hess` must
/// match the difference quotient of `grad` (step `1e-6`, tolerance `1e-5`).
#[test]
pub(crate) fn ard_axis_prior_value_grad_fd_consistent() {
    let alpha = 1.7_f64;
    let h = 1.0e-6;
    let periods = [None, Some(1.0_f64), Some(std::f64::consts::TAU)];
    let sample_points = [-0.37_f64, 0.02, 0.49, 0.83, 0.999, 1.4];

    for period in periods.iter().copied() {
        let eval = |t: f64| ArdAxisPrior::eval(alpha, t, period);
        for t in sample_points.iter().copied() {
            let p = eval(t);

            // d(value)/dt against the analytic gradient.
            let vp = eval(t + h).value;
            let vm = eval(t - h).value;
            let fd_grad = (vp - vm) / (2.0 * h);
            assert_abs_diff_eq!(p.grad, fd_grad, epsilon = 1.0e-5);

            // d(grad)/dt against the analytic Hessian.
            let gp = eval(t + h).grad;
            let gm = eval(t - h).grad;
            let fd_hess = (gp - gm) / (2.0 * h);
            assert_abs_diff_eq!(p.hess, fd_hess, epsilon = 1.0e-5);
        }
    }
}
/// Table-driven check of `LatentManifold::axis_periods`: Euclidean gives one
/// `None`, a circle gives its period, products concatenate their factors
/// (an interval contributes `None`), and `Sphere { dim: 3 }` gives three
/// `None`s.
#[test]
pub(crate) fn axis_periods_map_each_topology() {
    let unit_circle = || LatentManifold::Circle { period: 1.0 };
    let tau = std::f64::consts::TAU;

    let cases = vec![
        (LatentManifold::Euclidean, vec![None]),
        (unit_circle(), vec![Some(1.0)]),
        // Torus: product of two unit circles.
        (
            LatentManifold::Product(vec![unit_circle(), unit_circle()]),
            vec![Some(1.0), Some(1.0)],
        ),
        // Sphere chart: interval times a circle of period TAU.
        (
            LatentManifold::Product(vec![
                LatentManifold::Interval { lo: -1.0, hi: 1.0 },
                LatentManifold::Circle { period: tau },
            ]),
            vec![None, Some(tau)],
        ),
        (LatentManifold::Sphere { dim: 3 }, vec![None, None, None]),
    ];

    for (manifold, expected) in cases {
        assert_eq!(manifold.axis_periods(), expected);
    }
}
/// One observation at coordinate 0.999 on a period-1 circle is pushed by a
/// Newton step of +0.002 in the last slot of its row block. The test asserts
/// that the stored coordinate ends up below 0.01 (it wrapped) and that the
/// `ard` component of the loss changes by less than `1e-2` across the wrap.
#[test]
pub(crate) fn ard_value_continuous_across_periodic_cut_d1() {
    // Start just below the cut at t = 1.
    let start_coords = array![[0.999_f64]];
    let (phi, jet) = periodic_basis(&start_coords);
    let decoder = array![[0.2], [-0.3], [0.4]];
    let atom = SaeManifoldAtom::new(
        "periodic",
        SaeAtomBasisKind::Periodic,
        1,
        phi,
        jet,
        decoder,
        Array2::<f64>::eye(3),
    )
    .unwrap()
    .with_basis_evaluator(Arc::new(TestPeriodicEvaluator));
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((1, 1)),
        vec![start_coords],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::ibp_map(0.7, 1.0, true),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();

    let target = array![[0.1_f64]];
    let alpha = 50.0_f64;
    let rho = SaeManifoldRho::new(0.0, 0.0, vec![array![alpha.ln()]]);
    let ard_before = term.loss(target.view(), &rho).unwrap().ard;

    // Step only the last entry of the row block; leave beta untouched.
    let q = term.assignment.row_block_dim();
    let last_slot = q - 1;
    let mut delta_ext = Array1::<f64>::zeros(q);
    delta_ext[last_slot] = 0.002;
    let delta_beta = Array1::<f64>::zeros(term.beta_dim());
    term.apply_newton_step(delta_ext.view(), delta_beta.view(), 1.0)
        .unwrap();

    let wrapped = term.assignment.coords[0].row(0)[0];
    assert!(
        wrapped < 0.01,
        "coordinate should have wrapped across the cut, got {wrapped}"
    );

    let ard_after = term.loss(target.view(), &rho).unwrap().ard;
    let jump = (ard_after - ard_before).abs();
    assert!(
        jump < 1.0e-2,
        "periodic ARD jumped across the cut: before={ard_before}, after={ard_after}"
    );
}
/// Same wrap-across-the-cut scenario as the plain ARD continuity test, but
/// evaluated through `penalized_objective_total` with an `ARDPenalty`
/// registered in an `AnalyticPenaltyRegistry`.
///
/// One observation at coordinate 0.999 on a period-1 circle is stepped by
/// +0.002 in the last slot of its row block. The test asserts that the stored
/// coordinate wrapped below 0.01 and that the penalized objective (with the
/// registry attached) changes by less than `1e-2`.
#[test]
pub(crate) fn penalized_objective_continuous_across_periodic_cut_with_registry_ard() {
    let coords0 = array![[0.999_f64]];
    let (phi0, jet0) = periodic_basis(&coords0);
    let atom = SaeManifoldAtom::new(
        "periodic",
        SaeAtomBasisKind::Periodic,
        1,
        phi0,
        jet0,
        array![[0.2], [-0.3], [0.4]],
        Array2::<f64>::eye(3),
    )
    .unwrap()
    .with_basis_evaluator(Arc::new(TestPeriodicEvaluator));
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((1, 1)),
        vec![coords0],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::ibp_map(0.7, 1.0, true),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();
    let target = array![[0.1_f64]];
    let alpha = 50.0_f64;
    let rho = SaeManifoldRho::new(0.0, 0.0, vec![array![alpha.ln()]]);

    // Register an ARD penalty sized from the (single) coordinate block.
    let coord = &term.assignment.coords[0];
    let mut registry = AnalyticPenaltyRegistry::new();
    let ard_pen = ARDPenalty::new(
        PsiSlice::full(coord.len(), Some(coord.latent_dim())),
        coord.latent_dim(),
    );
    registry.push(AnalyticPenaltyKind::Ard(Arc::new(ard_pen)));

    // The registry is passed by shared reference to both evaluations.
    let obj_before = term
        .penalized_objective_total(target.view(), &rho, Some(&registry), 1.0)
        .unwrap();

    // Step only the last entry of the row block; leave beta untouched.
    let q = term.assignment.row_block_dim();
    let beta_dim = term.beta_dim();
    let mut delta_ext = Array1::<f64>::zeros(q);
    delta_ext[q - 1] = 0.002;
    let delta_beta = Array1::<f64>::zeros(beta_dim);
    term.apply_newton_step(delta_ext.view(), delta_beta.view(), 1.0)
        .unwrap();
    let wrapped = term.assignment.coords[0].row(0)[0];
    assert!(
        wrapped < 0.01,
        "coordinate should have wrapped across the cut, got {wrapped}"
    );

    let obj_after = term
        .penalized_objective_total(target.view(), &rho, Some(&registry), 1.0)
        .unwrap();
    assert!(
        (obj_after - obj_before).abs() < 1.0e-2,
        "line-search objective jumped across the cut: before={obj_before}, after={obj_after}"
    );
}
/// Registers a SCAD coordinate penalty on a term whose only coordinate axis
/// is a period-1 circle and asserts two things:
///
/// 1. the penalized objective with the registry equals the objective without
///    any registry (to `1e-12`), i.e. the penalty contributes nothing here;
/// 2. after a Newton step of +0.002 that wraps the coordinate from 0.999 to
///    below 0.01, the penalized objective changes by less than `1e-2`.
#[test]
pub(crate) fn scad_coord_penalty_inert_and_continuous_on_periodic_axis() {
    use crate::terms::analytic_penalties::{PenaltyConcavity, ScadMcpPenalty};
    let coords0 = array![[0.999_f64]];
    let (phi0, jet0) = periodic_basis(&coords0);
    let atom = SaeManifoldAtom::new(
        "periodic",
        SaeAtomBasisKind::Periodic,
        1,
        phi0,
        jet0,
        array![[0.2], [-0.3], [0.4]],
        Array2::<f64>::eye(3),
    )
    .unwrap()
    .with_basis_evaluator(Arc::new(TestPeriodicEvaluator));
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((1, 1)),
        vec![coords0],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::ibp_map(0.7, 1.0, true),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();
    let target = array![[0.1_f64]];
    let rho = SaeManifoldRho::new(0.0, 0.0, vec![array![0.0_f64]]);

    // Register a SCAD penalty sized from the (single) coordinate block.
    let coord = &term.assignment.coords[0];
    let mut registry = AnalyticPenaltyRegistry::new();
    let scad = ScadMcpPenalty::new(
        PsiSlice::full(coord.len(), Some(coord.latent_dim())),
        5.0,
        coord.n_obs(),
        3.7,
        1.0e-3,
        PenaltyConcavity::Scad,
        false,
    )
    .unwrap();
    registry.push(AnalyticPenaltyKind::ScadMcp(Arc::new(scad)));

    // Part 1: with and without the registry must agree.
    let with_scad = term
        .penalized_objective_total(target.view(), &rho, Some(&registry), 1.0)
        .unwrap();
    let without = term
        .penalized_objective_total(target.view(), &rho, None, 1.0)
        .unwrap();
    assert!(
        (with_scad - without).abs() < 1.0e-12,
        "SCAD coord penalty must be inert on a pure periodic axis: \
         with={with_scad}, without={without}"
    );

    // Part 2: step across the cut and compare the penalized objective.
    let obj_before = with_scad;
    let q = term.assignment.row_block_dim();
    let beta_dim = term.beta_dim();
    let mut delta_ext = Array1::<f64>::zeros(q);
    delta_ext[q - 1] = 0.002;
    let delta_beta = Array1::<f64>::zeros(beta_dim);
    term.apply_newton_step(delta_ext.view(), delta_beta.view(), 1.0)
        .unwrap();
    let wrapped = term.assignment.coords[0].row(0)[0];
    assert!(
        wrapped < 0.01,
        "coordinate should have wrapped across the cut, got {wrapped}"
    );
    let obj_after = term
        .penalized_objective_total(target.view(), &rho, Some(&registry), 1.0)
        .unwrap();
    assert!(
        (obj_after - obj_before).abs() < 1.0e-2,
        "SCAD line-search objective jumped across the periodic cut: \
         before={obj_before}, after={obj_after}"
    );
}
/// Checks `sae_coord_penalty_euclidean_restriction` on two three-row
/// coordinate blocks: for a Euclidean block it must return `None`, and for a
/// period-1 circle block it must return `Some((axes, compacted))` with no
/// axes and an empty compacted target.
#[test]
pub(crate) fn scad_coord_penalty_active_on_euclidean_axis() {
    let zero_logits = || Array2::<f64>::zeros((3, 1));
    let mode = || AssignmentMode::softmax(0.7);

    let euclid = SaeAssignment::from_blocks_with_mode_and_manifolds(
        zero_logits(),
        vec![array![[0.5_f64], [-0.7], [1.3]]],
        vec![LatentManifold::Euclidean],
        mode(),
    )
    .unwrap();
    let euclid_restriction = sae_coord_penalty_euclidean_restriction(&euclid.coords[0]);
    assert!(
        euclid_restriction.is_none(),
        "Euclidean coord must not be restricted"
    );

    let circle = SaeAssignment::from_blocks_with_mode_and_manifolds(
        zero_logits(),
        vec![array![[0.1_f64], [0.4], [0.9]]],
        vec![LatentManifold::Circle { period: 1.0 }],
        mode(),
    )
    .unwrap();
    let (axes, compacted) = sae_coord_penalty_euclidean_restriction(&circle.coords[0])
        .expect("periodic coord must be restricted");
    assert!(
        axes.is_empty(),
        "circle has no Euclidean axes, got {axes:?}"
    );
    assert_eq!(compacted.len(), 0, "compacted target must be empty");
}
/// Assembles the arrow-Schur system for a two-row periodic term with
/// `alpha = 100` and asserts that every diagonal entry of every row's `htt`
/// block is non-negative.
///
/// Only the diagonals are inspected, which is a necessary condition for the
/// blocks being positive semi-definite rather than a full PSD check.
#[test]
pub(crate) fn periodic_ard_curvature_is_psd_in_assembled_htt() {
    let start_coords = array![[0.40_f64], [0.60_f64]];
    let (phi, jet) = periodic_basis(&start_coords);
    let decoder = array![[0.2], [-0.3], [0.4]];
    let atom = SaeManifoldAtom::new(
        "periodic",
        SaeAtomBasisKind::Periodic,
        1,
        phi,
        jet,
        decoder,
        Array2::<f64>::eye(3),
    )
    .unwrap()
    .with_basis_evaluator(Arc::new(TestPeriodicEvaluator));
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((2, 1)),
        vec![start_coords],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::softmax(0.7),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();

    let target = array![[0.1_f64], [0.2_f64]];
    let alpha = 100.0_f64;
    let rho = SaeManifoldRho::new(0.0, 0.0, vec![array![alpha.ln()]]);
    let sys = term
        .assemble_arrow_schur(target.view(), &rho, None)
        .unwrap();

    for (row_idx, row) in sys.rows.iter().enumerate() {
        for a in 0..row.htt.nrows() {
            let diagonal_ok = row.htt[[a, a]] >= 0.0;
            assert!(
                diagonal_ok,
                "row {row_idx} htt diagonal[{a}]={} must be PSD (von-Mises \
                 curvature clamped to its positive part)",
                row.htt[[a, a]]
            );
        }
    }
}
/// Takes a snapshot of a four-row periodic term, applies a Newton step that
/// must visibly change either the atom's basis values or its decoder
/// coefficients, restores the snapshot, and asserts that basis values, basis
/// Jacobian, decoder coefficients, assignment logits and coordinates all
/// equal their pre-step copies.
#[test]
pub(crate) fn snapshot_restore_round_trips_mutated_state() {
    let n_rows = 4usize;
    let start_coords = array![[0.05], [0.20], [0.55], [0.80]];
    let (phi, jet) = periodic_basis(&start_coords);
    let decoder = array![[0.2], [-0.3], [0.4]];
    let atom = SaeManifoldAtom::new(
        "periodic",
        SaeAtomBasisKind::Periodic,
        1,
        phi,
        jet,
        decoder,
        Array2::<f64>::eye(3),
    )
    .unwrap()
    .with_basis_evaluator(Arc::new(TestPeriodicEvaluator));
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((n_rows, 1)),
        vec![start_coords],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::ibp_map(0.7, 1.0, true),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();

    // Capture both the snapshot and independent copies to compare against.
    let snapshot = term.snapshot_mutable_state();
    let pre_basis = term.atoms[0].basis_values.clone();
    let pre_jet = term.atoms[0].basis_jacobian.clone();
    let pre_decoder = term.atoms[0].decoder_coefficients.clone();
    let pre_logits = term.assignment.logits.clone();
    let pre_coords = term.assignment.coords[0].as_matrix();

    // Uniform, non-zero step over every row block and every beta entry.
    let q = term.assignment.row_block_dim();
    let delta_ext = Array1::<f64>::from_elem(n_rows * q, 0.3);
    let delta_beta = Array1::<f64>::from_elem(term.beta_dim(), -0.4);
    term.apply_newton_step(delta_ext.view(), delta_beta.view(), 1.0)
        .unwrap();

    // Guard: the step must have moved something, or the restore is untested.
    let basis_shift = (&term.atoms[0].basis_values - &pre_basis)
        .mapv(f64::abs)
        .sum();
    let perturbed = basis_shift > 1e-9
        || (&term.atoms[0].decoder_coefficients - &pre_decoder)
            .mapv(f64::abs)
            .sum()
            > 1e-9;
    assert!(
        perturbed,
        "apply_newton_step did not perturb the snapshotted state"
    );

    term.restore_mutable_state(&snapshot);
    let restored_atom = &term.atoms[0];
    assert_eq!(restored_atom.basis_values, pre_basis);
    assert_eq!(restored_atom.basis_jacobian, pre_jet);
    assert_eq!(restored_atom.decoder_coefficients, pre_decoder);
    assert_eq!(term.assignment.logits, pre_logits);
    assert_eq!(term.assignment.coords[0].as_matrix(), pre_coords);
}
/// Runs `run_joint_fit_arrow_schur` on a four-row periodic term in
/// `ibp_map` mode and asserts that the returned loss is finite and not worse
/// than the starting loss (within `1e-8`), that all coordinates and
/// assignments are finite, and that the atom's basis values changed.
/// NOTE(review): per the test name, the `2` passed to the fit is presumably
/// the Newton iteration count — confirm against the method signature.
#[test]
pub(crate) fn ibp_path_refreshes_periodic_basis_for_two_newton_iterations() {
    let start_coords = array![[0.05], [0.20], [0.55], [0.80]];
    let (phi, jet) = periodic_basis(&start_coords);
    let decoder = array![[0.2], [-0.3], [0.4]];
    let atom = SaeManifoldAtom::new(
        "periodic",
        SaeAtomBasisKind::Periodic,
        1,
        phi,
        jet,
        decoder,
        Array2::<f64>::eye(3),
    )
    .unwrap()
    .with_basis_evaluator(Arc::new(TestPeriodicEvaluator));
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((4, 1)),
        vec![start_coords],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::ibp_map(0.7, 1.0, true),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();

    let target = array![[0.10], [0.05], [-0.15], [0.20]];
    let mut rho = SaeManifoldRho::new(0.0, -6.0, vec![Array1::<f64>::zeros(1)]);

    // Baselines taken before the fit mutates the term.
    let basis0 = term.atoms[0].basis_values.clone();
    let loss0 = term.loss(target.view(), &rho).unwrap().total();

    let loss = term
        .run_joint_fit_arrow_schur(target.view(), &mut rho, None, 2, 0.05, 1.0e-3, 1.0e-3)
        .unwrap();

    assert!(loss.total().is_finite());
    assert!(loss.total() <= loss0 + 1.0e-8);
    assert!(
        term.assignment.coords[0]
            .as_flat()
            .iter()
            .all(|v| v.is_finite())
    );
    assert!(term.assignment.assignments().iter().all(|v| v.is_finite()));

    // The basis must have been re-evaluated at the moved coordinates.
    let basis_delta = (&term.atoms[0].basis_values - &basis0).mapv(f64::abs).sum();
    assert!(basis_delta > 1.0e-10);
}
/// Builds a small two-atom periodic fixture and returns `(term, target, rho)`.
///
/// Each atom uses `periodic_basis` evaluated at its own five coordinates on a
/// period-1 circle, a fixed `3 x 1` decoder, an identity `3 x 3` matrix and a
/// `TestPeriodicEvaluator`. The assignment uses the fixed `5 x 2` logits
/// below with `softmax(0.8)`. `target` is `5 x 1`.
pub(crate) fn small_two_atom_periodic_term() -> (SaeManifoldTerm, Array2<f64>, SaeManifoldRho) {
    // Shared recipe for both atoms; only name, coordinates and decoder vary.
    let build_atom = |name: &str, coords: &Array2<f64>, decoder: Array2<f64>| {
        let (phi, jet) = periodic_basis(coords);
        SaeManifoldAtom::new(
            name,
            SaeAtomBasisKind::Periodic,
            1,
            phi,
            jet,
            decoder,
            Array2::<f64>::eye(3),
        )
        .unwrap()
        .with_basis_evaluator(Arc::new(TestPeriodicEvaluator))
    };

    let coords0 = array![[0.05], [0.20], [0.55], [0.80], [0.35]];
    let coords1 = array![[0.15], [0.30], [0.65], [0.90], [0.45]];
    let atom0 = build_atom("periodic0", &coords0, array![[0.25], [-0.35], [0.15]]);
    let atom1 = build_atom("periodic1", &coords1, array![[-0.10], [0.20], [0.30]]);

    let logits = array![
        [0.7, -0.2],
        [0.1, 0.4],
        [-0.3, 0.5],
        [0.6, -0.1],
        [0.2, 0.3]
    ];
    let manifolds = vec![
        LatentManifold::Circle { period: 1.0 },
        LatentManifold::Circle { period: 1.0 },
    ];
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        logits,
        vec![coords0, coords1],
        manifolds,
        AssignmentMode::softmax(0.8),
    )
    .unwrap();
    let term = SaeManifoldTerm::new(vec![atom0, atom1], assignment).unwrap();

    let target = array![[0.12], [-0.03], [0.08], [0.20], [-0.11]];
    // The exp/ln round trip is kept as-is so the stored value is unchanged.
    let first_log = (-0.3_f64).exp().ln();
    let second_log = 0.7_f64.ln();
    let per_atom_logs = vec![array![0.9_f64.ln()], array![1.1_f64.ln()]];
    let rho = SaeManifoldRho::new(first_log, second_log, per_atom_logs);

    (term, target, rho)
}
/// Uses the term's own fitted reconstruction as the target (so the full
/// explained variance is 1) and asserts that
/// `per_atom_loao_explained_variance` returns one defined value per atom,
/// each in `(1e-9, 1 + 1e-9]`.
///
/// A second pass zeroes atom 1's decoder on a clone of the term and asserts
/// that this atom's value is then (numerically) zero.
#[test]
pub(crate) fn per_atom_loao_ev_attributes_each_load_bearing_atom() {
    let (term, _target, rho) = small_two_atom_periodic_term();

    // Target := full reconstruction, so EV(full) must be exactly 1.
    let target = term
        .try_fitted_for_rho(&rho)
        .expect("full reconstruction must assemble");
    let ev_full = reconstruction_explained_variance(target.view(), target.view())
        .expect("self-reconstruction EV defined");
    let ev_gap = (ev_full - 1.0).abs();
    assert!(
        ev_gap < 1e-12,
        "target = full reconstruction ⇒ EV(full) = 1; got {ev_full}"
    );

    let dev = term
        .per_atom_loao_explained_variance(target.view(), &rho)
        .expect("LOAO EV must evaluate");
    assert_eq!(dev.len(), term.k_atoms(), "one ΔEV per atom");
    for (atom_idx, entry) in dev.iter().enumerate() {
        let d = match *entry {
            Some(value) => value,
            None => panic!("atom {atom_idx} ΔEV must be defined"),
        };
        assert!(
            d > 1e-9,
            "load-bearing atom {atom_idx} must earn positive held-out ΔEV; got {d:.3e}"
        );
        assert!(
            d <= 1.0 + 1e-9,
            "ΔEV for atom {atom_idx} cannot exceed EV(full)=1; got {d:.6e}"
        );
    }

    // Dead-atom pass: atom 1's decoder is zeroed on a clone, while the target
    // still comes from the original (live) term.
    let mut dead_term = term.clone();
    dead_term.atoms[1].decoder_coefficients.fill(0.0);
    let dead_target = term
        .try_fitted_for_rho(&rho)
        .expect("reconstruction with the live atom-1 decoder");
    let dead_dev = dead_term
        .per_atom_loao_explained_variance(dead_target.view(), &rho)
        .expect("LOAO EV must evaluate for the dead-atom term");
    let d_dead = dead_dev[1].expect("dead atom ΔEV defined");
    assert!(
        d_dead.abs() < 1e-9,
        "a zero-decoder atom carries no reconstruction ⇒ ΔEV ≈ 0; got {d_dead:.3e}"
    );
}
/// Zeroes atom 1's decoder in the two-atom fixture, runs
/// `enforce_decoder_norm_guard`, and asserts that:
/// * a `CollapseAction::Reseeded` event is recorded for atom 1;
/// * atom 1's decoder norm exceeds `SAE_ATOM_DECODER_NORM_COLLAPSE_RATIO`
///   times atom 0's (which must itself be positive);
/// * atom 1's coordinates span a range wider than `1e-6`;
/// * the absolute cosine between the two decoders is below 0.999.
#[test]
pub(crate) fn decoder_norm_guard_reseeds_collapsed_atom_to_distinct_nonzero() {
    /// Frobenius norm of an atom's decoder coefficients.
    fn decoder_norm(atom: &SaeManifoldAtom) -> f64 {
        atom.decoder_coefficients
            .iter()
            .map(|&v| v * v)
            .sum::<f64>()
            .sqrt()
    }

    let (term0, target, rho) = small_two_atom_periodic_term();
    let mut term = term0.clone();

    // Collapse atom 1 and let it refresh its derived penalty state.
    term.atoms[1].decoder_coefficients.fill(0.0);
    term.atoms[1].refresh_intrinsic_smooth_penalty();
    assert!(decoder_norm(&term.atoms[1]) < 1e-12, "atom 1 starts collapsed");

    term.enforce_decoder_norm_guard(target.view(), 0, &rho)
        .expect("decoder-norm guard must not error on a recoverable collapse");

    let reseeded = term
        .collapse_events()
        .iter()
        .any(|event| event.atom == 1 && event.action == CollapseAction::Reseeded);
    assert!(
        reseeded,
        "collapsed atom 1 must be recorded as Reseeded; events: {:?}",
        term.collapse_events()
    );

    let n1 = decoder_norm(&term.atoms[1]);
    let n0 = decoder_norm(&term.atoms[0]);
    assert!(
        n0 > 0.0 && n1 > SAE_ATOM_DECODER_NORM_COLLAPSE_RATIO * n0,
        "reseeded atom 1 decoder must be non-degenerate: ‖B0‖={n0:.3e} ‖B1‖={n1:.3e}"
    );

    // Range of the reseeded coordinates.
    let c1 = term.assignment.coords[1].as_matrix();
    let mut lo = f64::INFINITY;
    let mut hi = f64::NEG_INFINITY;
    for &v in c1.iter() {
        lo = lo.min(v);
        hi = hi.max(v);
    }
    assert!(
        hi - lo > 1e-6,
        "reseeded atom 1 coordinates must span a non-trivial range; got [{lo}, {hi}]"
    );

    // Absolute cosine between the flattened decoders.
    let b0 = &term.atoms[0].decoder_coefficients;
    let b1 = &term.atoms[1].decoder_coefficients;
    let dot: f64 = b0.iter().zip(b1.iter()).map(|(x, y)| x * y).sum();
    let cos = dot.abs() / (n0 * n1);
    assert!(
        cos < 0.999,
        "reseeded atom 1 decoder must be distinct from atom 0 (|cos|={cos:.4})"
    );
}
/// On the trivial one-atom term, `enforce_decoder_norm_guard` must succeed,
/// record no collapse events, and leave the atom's decoder coefficients
/// unchanged.
#[test]
pub(crate) fn decoder_norm_guard_is_noop_for_k1() {
    let mut term = trivial_k1_euclidean_term();
    let target = Array2::<f64>::zeros((term.n_obs(), term.output_dim()));
    let rho = SaeManifoldRho::new(0.0, 0.0, vec![array![0.0_f64]]);
    let before = term.atoms[0].decoder_coefficients.clone();

    term.enforce_decoder_norm_guard(target.view(), 0, &rho)
        .expect("K=1 decoder-norm guard must be a no-op, never error");

    let no_events = term.collapse_events().is_empty();
    assert!(no_events, "K=1 must record no decoder-collapse events");
    assert_eq!(
        term.atoms[0].decoder_coefficients, before,
        "K=1 decoder must be untouched by the guard"
    );
}
/// End-to-end check of the hybrid-split collapse on the two-atom fixture:
///
/// 1. with no hybrid-split report attached, `hybrid_collapsed_reconstruction`
///    equals the curved fit;
/// 2. after zeroing every decoder row of atom 0 except row 0, the computed
///    report has at least one verdict with a linear image;
/// 3. with the target set to the post-edit curved fit (EV = 1), the collapsed
///    explained variance is within `1e-6` of, or above, the curved one;
/// 4. the collapsed verdict's `num_parameters` is smaller than the number of
///    decoder coefficients of the atom it replaces.
#[test]
pub(crate) fn hybrid_collapse_is_load_bearing_and_dominates() {
    let (mut term, _t, rho) = small_two_atom_periodic_term();

    // Step 1: without a report, the collapse is the identity on the fit.
    let curved = term
        .try_fitted_for_rho(&rho)
        .expect("curved reconstruction assembles");
    let pre = term
        .hybrid_collapsed_reconstruction(&rho)
        .expect("collapse with no report returns the curved fit");
    let residual = &curved - &pre;
    assert!(
        residual.iter().all(|d| d.abs() < 1e-15),
        "with no hybrid-split report the collapse must equal the curved fit"
    );

    // Step 2: keep only basis row 0 of atom 0's decoder.
    let n_basis_rows = term.atoms[0].decoder_coefficients.nrows();
    for basis_row in 1..n_basis_rows {
        term.atoms[0]
            .decoder_coefficients
            .row_mut(basis_row)
            .fill(0.0);
    }
    let report = term
        .compute_hybrid_split_report(&rho)
        .expect("hybrid split report computes")
        .expect("eligible d=1 atoms present a report");
    let collapsed_any = report.verdicts.iter().any(|v| v.linear_image.is_some());
    term.hybrid_split_report = Some(report);
    assert!(
        collapsed_any,
        "a straight atom must collapse at least one slot to the linear tail"
    );

    // Step 3: the collapsed fit must not lose explained variance.
    let target = term
        .try_fitted_for_rho(&rho)
        .expect("post-straighten curved reconstruction assembles");
    let ev_curved = reconstruction_explained_variance(target.view(), target.view())
        .expect("self-reconstruction EV defined");
    assert!(
        (ev_curved - 1.0).abs() < 1e-12,
        "target = curved fit ⇒ EV(curved) = 1; got {ev_curved}"
    );
    let ev_collapsed = term
        .hybrid_collapsed_explained_variance(target.view(), &rho)
        .expect("collapsed EV evaluates")
        .expect("collapsed EV defined");
    assert!(
        ev_collapsed >= ev_curved - 1e-6,
        "collapsing a straight atom must preserve EV (match-or-beat dominance \
         floor): curved {ev_curved:.9}, collapsed {ev_collapsed:.9}"
    );

    // Step 4: the collapsed slot must be cheaper than the curved atom.
    let attached_report = term.hybrid_split_report.as_ref().unwrap();
    let verdict = attached_report
        .verdicts
        .iter()
        .find(|v| v.linear_image.is_some())
        .expect("a collapsed slot exists");
    let collapsed_idx = verdict.linear_image.as_ref().unwrap().atom_idx;
    let curved_params = term.atoms[collapsed_idx].decoder_coefficients.len();
    assert!(
        verdict.choice.num_parameters < curved_params,
        "the linear-collapsed slot must shed curved coefficients: linear \
         {} < curved {}",
        verdict.choice.num_parameters,
        curved_params
    );
}
/// Applies a Newton step whose first assignment entry is 1e6 (everything else zero)
/// and checks that the row-0 logit gap between atoms 0 and 1 moves by exactly
/// `SAE_ASSIGNMENT_LOGIT_STEP_CAP_TAUS * temperature` (to 1e-9).
#[test]
pub(crate) fn assignment_logit_step_cap_bounds_single_iteration_gate_motion() {
    let (mut term, _target, _rho) = small_two_atom_periodic_term();

    // Row-0 difference between the atom-0 and atom-1 logits.
    let row0_logit_gap =
        |t: &SaeManifoldTerm| t.assignment.logits[[0, 0]] - t.assignment.logits[[0, 1]];
    let diff_before = row0_logit_gap(&term);

    let flat_len = term.assignment.n_obs() * term.assignment.row_block_dim();
    let mut delta = Array1::<f64>::zeros(flat_len);
    delta[0] = 1.0e6;
    let delta_beta = Array1::<f64>::zeros(term.beta_dim());
    term.apply_newton_step(delta.view(), delta_beta.view(), 1.0)
        .expect("step applies");

    let cap = SAE_ASSIGNMENT_LOGIT_STEP_CAP_TAUS * term.assignment.mode.temperature();
    let diff_after = row0_logit_gap(&term);
    let realised_motion = diff_after - diff_before;
    assert!(
        (realised_motion - cap).abs() < 1.0e-9,
        "a 1e6 raw logit delta must realise exactly the {cap}-cap, moved {}",
        realised_motion
    );
}
/// Forces atom 1's logits to -1e3 on every row and walks the active-mass guard
/// through its lifecycle: the first guard call must log a single `Reseeded` event
/// for atom 1 and lift its mass above the floor; a follow-up call must add nothing;
/// after starving the atom again, two more calls must leave exactly one `Terminal`
/// event, and no event may ever name another atom.
#[test]
pub(crate) fn active_mass_guard_reseeds_once_then_records_terminal_collapse() {
    let (mut term, _target, _rho) = small_two_atom_periodic_term();
    let n = term.assignment.n_obs();
    let starve_atom_one = |victim: &mut SaeManifoldTerm| {
        for row in 0..n {
            victim.assignment.logits[[row, 0]] = 0.0;
            victim.assignment.logits[[row, 1]] = -1.0e3;
        }
    };

    // First starvation: expect one reseed of atom 1.
    starve_atom_one(&mut term);
    term.enforce_active_mass_guard(0, None).expect("guard runs");
    assert_eq!(term.collapse_events().len(), 1);
    let first_event = term.collapse_events()[0];
    assert_eq!(first_event.atom, 1);
    assert_eq!(first_event.action, CollapseAction::Reseeded);
    assert!(first_event.max_active_mass < first_event.floor);

    // The reseed must have restored mass above the floor on at least one row.
    let masses = term.assignment.assignments();
    let max1 = (0..n)
        .map(|row| masses[[row, 1]])
        .fold(0.0_f64, f64::max);
    assert!(max1 > SAE_ATOM_ACTIVE_MASS_FLOOR);

    // A healthy atom triggers no further event.
    term.enforce_active_mass_guard(1, None).expect("guard runs");
    assert_eq!(term.collapse_events().len(), 1);

    // Second starvation: the guard is invoked twice and must log one terminal event.
    starve_atom_one(&mut term);
    for iteration in [2, 3] {
        term.enforce_active_mass_guard(iteration, None)
            .expect("guard runs");
    }
    let terminals: Vec<_> = term
        .collapse_events()
        .iter()
        .filter(|event| event.action == CollapseAction::Terminal)
        .collect();
    assert_eq!(terminals.len(), 1);
    assert_eq!(terminals[0].atom, 1);
    let only_atom_one_flagged = term.collapse_events().iter().all(|event| event.atom == 1);
    assert!(
        only_atom_one_flagged,
        "the healthy atom must never be flagged"
    );
}
/// Checks `seed_scaled_by_dispersion_for_assignment` against `ln(dispersion)` shifts.
///
/// For `AssignmentMode::softmax(1.0)` the sparse, smooth and both ARD entries must
/// each move by `ln(dispersion)`. For `AssignmentMode::ibp_map(1.0, 1.0, true)` the
/// sparse entry must stay put while the smooth entry and ARD entry `[0][0]` shift.
#[test]
pub(crate) fn sae_rho_seed_dispersion_scaling_shifts_every_scale_coupled_axis() {
    let rho = SaeManifoldRho::new(0.7_f64.ln(), 1.3_f64.ln(), vec![array![0.2, -0.4]]);
    let dispersion = 0.05_f64 * 0.05;
    let shift = dispersion.ln();

    // Softmax: every visible axis is shifted.
    let scaled = rho
        .seed_scaled_by_dispersion_for_assignment(dispersion, AssignmentMode::softmax(1.0))
        .unwrap();
    assert_abs_diff_eq!(
        scaled.log_lambda_sparse,
        rho.log_lambda_sparse + shift,
        epsilon = 1.0e-14
    );
    assert_abs_diff_eq!(
        scaled.log_lambda_smooth,
        rho.log_lambda_smooth + shift,
        epsilon = 1.0e-14
    );
    for axis in 0..2usize {
        assert_abs_diff_eq!(
            scaled.log_ard[0][axis],
            rho.log_ard[0][axis] + shift,
            epsilon = 1.0e-14
        );
    }

    // IBP-MAP with the third flag set: the sparse axis is left unshifted.
    let ibp_mode = AssignmentMode::ibp_map(1.0, 1.0, true);
    let learnable_ibp = rho
        .seed_scaled_by_dispersion_for_assignment(dispersion, ibp_mode)
        .unwrap();
    assert_abs_diff_eq!(
        learnable_ibp.log_lambda_sparse,
        rho.log_lambda_sparse,
        epsilon = 1.0e-14
    );
    assert_abs_diff_eq!(
        learnable_ibp.log_lambda_smooth,
        rho.log_lambda_smooth + shift,
        epsilon = 1.0e-14
    );
    assert_abs_diff_eq!(
        learnable_ibp.log_ard[0][0],
        rho.log_ard[0][0] + shift,
        epsilon = 1.0e-14
    );
}
/// Builds a one-atom periodic term with an all-zero decoder, hands
/// `record_fit_data_collapse_if_needed` an all-zero fit against a non-zero target
/// with unit assignments, and checks that it reports `true` and logs exactly one
/// `Terminal` event for atom 0 at iteration 7 whose `max_active_mass` does not exceed
/// `SAE_FIT_DATA_COLLAPSE_EV_FLOOR`.
#[test]
pub(crate) fn fit_data_collapse_records_terminal_event_for_active_atom() {
    let n_rows = 4usize;
    let coords = array![[0.0], [0.25], [0.5], [0.75]];
    let (phi, jet) = periodic_basis(&coords);
    let zero_decoder = Array2::<f64>::zeros((3, 2));
    let atom = SaeManifoldAtom::new(
        "circle",
        SaeAtomBasisKind::Periodic,
        1,
        phi,
        jet,
        zero_decoder,
        Array2::<f64>::eye(3),
    )
    .unwrap()
    .with_basis_evaluator(Arc::new(TestPeriodicEvaluator));
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((n_rows, 1)),
        vec![coords],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::softmax(1.0),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();

    let target = array![[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]];
    let fitted = Array2::<f64>::zeros(target.dim());
    let assignments = Array2::<f64>::ones((n_rows, 1));
    let iteration = 7;
    let recorded = term
        .record_fit_data_collapse_if_needed(
            target.view(),
            fitted.view(),
            assignments.view(),
            iteration,
        )
        .unwrap();
    assert!(recorded);

    let terminals: Vec<_> = term
        .collapse_events()
        .iter()
        .filter(|event| event.action == CollapseAction::Terminal)
        .collect();
    assert_eq!(terminals.len(), 1);
    let terminal = terminals[0];
    assert_eq!(terminal.atom, 0);
    assert_eq!(terminal.iteration, 7);
    assert!(terminal.max_active_mass <= SAE_FIT_DATA_COLLAPSE_EV_FLOOR);
}
/// Deterministic pseudo-noise for the planted-circle fixtures.
///
/// Combines the 1-based `row` and `col` into a single phase, scrambles it with a
/// sine scaled by a large constant, and returns the sine of that — so the result
/// always lies in `[-1, 1]` and depends only on `(row, col)`.
pub(crate) fn deterministic_circle_noise(row: usize, col: usize) -> f64 {
    let row_phase = (row as f64 + 1.0) * 12.9898;
    let col_phase = (col as f64 + 1.0) * 78.233;
    let scrambled = (row_phase + col_phase).sin() * 43758.5453;
    scrambled.sin()
}
/// Generates `n` points on the unit circle with additive deterministic noise.
///
/// Row `i` sits at angle `TAU * i / n`; column 0 holds the cosine and column 1 the
/// sine, each perturbed by `sigma * deterministic_circle_noise(i, column)`.
pub(crate) fn planted_circle_data(n: usize, sigma: f64) -> Array2<f64> {
    Array2::from_shape_fn((n, 2), |(row, col)| {
        let theta = std::f64::consts::TAU * row as f64 / n as f64;
        let on_circle = if col == 0 { theta.cos() } else { theta.sin() };
        on_circle + sigma * deterministic_circle_noise(row, col)
    })
}
/// Explained variance of `fitted` against `target`, pooled over all entries.
///
/// Returns `1 - SSR / SST`, where SSR is the summed squared residual and SST is the
/// summed squared deviation of `target` from its per-column means. SST is floored at
/// `1e-300` so a constant target does not divide by zero.
pub(crate) fn global_ev(target: ArrayView2<'_, f64>, fitted: ArrayView2<'_, f64>) -> f64 {
    let (n, p) = target.dim();

    // Per-column means, accumulated top-to-bottom from an explicit 0.0 start.
    let column_means: Vec<f64> = (0..p)
        .map(|col| {
            let column_total = target
                .column(col)
                .iter()
                .fold(0.0_f64, |acc, &value| acc + value);
            column_total / n as f64
        })
        .collect();

    // Single row-major sweep accumulating both sums of squares.
    let (ssr, sst) = target.indexed_iter().fold(
        (0.0_f64, 0.0_f64),
        |(ssr, sst), ((row, col), &observed)| {
            let r = observed - fitted[[row, col]];
            let centered = observed - column_means[col];
            (ssr + r * r, sst + centered * centered)
        },
    );
    1.0 - ssr / sst.max(1.0e-300)
}
/// Selects which assignment mode the planted-circle fixtures are seeded with.
#[derive(Clone, Copy)]
pub(crate) enum PlantedCircleAssignmentMode {
    /// Seeds with `AssignmentMode::softmax`.
    Softmax,
    /// Seeds with `AssignmentMode::ibp_map`.
    IbpMap,
}

impl PlantedCircleAssignmentMode {
    /// Short tag used in assertion messages.
    pub(crate) fn label(self) -> &'static str {
        match self {
            Self::IbpMap => "ibp_map",
            Self::Softmax => "softmax",
        }
    }

    /// The concrete `AssignmentMode`, built with unit temperature (and unit alpha
    /// with the third flag off for the IBP-MAP variant).
    pub(crate) fn mode(self) -> AssignmentMode {
        const TAU: f64 = 1.0;
        const ALPHA: f64 = 1.0;
        match self {
            Self::IbpMap => AssignmentMode::ibp_map(TAU, ALPHA, false),
            Self::Softmax => AssignmentMode::softmax(TAU),
        }
    }

    /// Logit every observation starts from: zero for softmax, `6 * TAU` for IBP-MAP.
    pub(crate) fn seed_logit(self) -> f64 {
        const TAU: f64 = 1.0;
        match self {
            Self::IbpMap => 6.0 * TAU,
            Self::Softmax => 0.0,
        }
    }

    /// Gate value used to scale the basis when fitting the seed decoder: one for
    /// softmax, the logistic function evaluated at 6 for IBP-MAP.
    pub(crate) fn seed_gate(self) -> f64 {
        match self {
            Self::IbpMap => {
                let decay = (-6.0_f64).exp();
                1.0 / (1.0 + decay)
            }
            Self::Softmax => 1.0,
        }
    }
}
/// Builds a one-atom periodic seed term for planted-circle data `z`.
///
/// Coordinates come from `sae_pca_seed_initial_coords`; the decoder is a least-squares
/// fit of `z` on the gate-scaled 3-column periodic basis, with `1e-10` added to the
/// normal-equation diagonal before the Cholesky solve. Returns the term together with
/// the seed dispersion: the mean squared seed residual, floored at `1e-12`.
pub(crate) fn planted_circle_seed_term(
    z: ArrayView2<'_, f64>,
    assignment_mode: PlantedCircleAssignmentMode,
) -> (SaeManifoldTerm, f64) {
    let (n, p_out) = z.dim();
    let evaluator = Arc::new(PeriodicHarmonicEvaluator::new(3).unwrap());
    let seed_coords = sae_pca_seed_initial_coords(z, &[SaeAtomBasisKind::Periodic], &[1]).unwrap();
    let coords = seed_coords.slice(s![0, .., 0..1]).to_owned();
    let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();

    // Least-squares decoder on the gated design, with a tiny diagonal jitter.
    let gated_phi = &phi * assignment_mode.seed_gate();
    let mut xtx = fast_ata(&gated_phi);
    for diagonal_entry in xtx.diag_mut().iter_mut() {
        *diagonal_entry += 1.0e-10;
    }
    let xtz = fast_atb(&gated_phi, &z.to_owned());
    let decoder = xtx.cholesky(Side::Lower).unwrap().solve_mat(&xtz);

    // Residual sum of squares of the seed fit, accumulated in row-major order.
    let seed_fitted = gated_phi.dot(&decoder);
    let rss = z
        .indexed_iter()
        .fold(0.0_f64, |acc, ((row, col), &observed)| {
            let r = observed - seed_fitted[[row, col]];
            acc + r * r
        });
    let seed_dispersion = (rss / (n * p_out) as f64).max(1.0e-12);

    let atom = SaeManifoldAtom::new(
        "circle",
        SaeAtomBasisKind::Periodic,
        1,
        phi,
        jet,
        decoder,
        Array2::<f64>::eye(3),
    )
    .unwrap()
    .with_basis_evaluator(evaluator);
    let seed_logits = Array2::<f64>::from_elem((n, 1), assignment_mode.seed_logit());
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        seed_logits,
        vec![coords],
        vec![LatentManifold::Circle { period: 1.0 }],
        assignment_mode.mode(),
    )
    .unwrap();
    let term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();
    (term, seed_dispersion)
}
/// Sweeps assignment mode × sample count × noise level on planted-circle data and,
/// for every combination, runs the outer optimisation from a dispersion-scaled rho
/// seed. Each fit must reach a global explained variance above 0.95 and must not
/// record any collapse event.
#[test]
pub(crate) fn planted_circle_noise_scale_sweep_reaches_high_ev_with_dimensionless_rho_seed() {
for assignment_mode in [
PlantedCircleAssignmentMode::Softmax,
PlantedCircleAssignmentMode::IbpMap,
] {
let assignment_label = assignment_mode.label();
for &n in &[40usize, 250usize] {
for &sigma in &[0.02_f64, 0.05, 0.18] {
let z = planted_circle_data(n, sigma);
let (term, seed_dispersion) = planted_circle_seed_term(z.view(), assignment_mode);
// Seed EV is only reported in the failure message below.
let seed_ev = global_ev(z.view(), term.fitted().view());
// Fixed base rho, rescaled by the seed dispersion for this assignment mode.
let init_rho = SaeManifoldRho::new(0.02_f64.ln(), 1.0_f64.ln(), vec![array![0.0]])
.seed_scaled_by_dispersion_for_assignment(
seed_dispersion,
assignment_mode.mode(),
)
.unwrap();
// Flatten before `init_rho` is moved into the objective.
let init_rho_flat = init_rho.to_flat();
let n_params = init_rho_flat.len();
let mut objective = SaeManifoldOuterObjective::new(
term,
z.clone(),
None,
init_rho,
50,
0.04,
1.0e-6,
1.0e-6,
);
crate::solver::outer_strategy::OuterProblem::new(n_params)
.with_initial_rho(init_rho_flat)
.run(&mut objective, "SAE planted circle dimensionless seed")
.unwrap();
// Recover the fitted term and final rho from the consumed objective.
let (fitted_term, rho, _loss) = objective.into_fitted();
let fitted = fitted_term.fitted();
let ev = global_ev(z.view(), fitted.view());
assert!(
ev > 0.95,
"planted circle assignment={assignment_label} n={n} sigma={sigma} seed_ev={seed_ev:.4} seed_phi={seed_dispersion:.3e} \
final_rho=({:.3}, {:.3}, {:?}) EV={ev:.4} should exceed 0.95",
rho.log_lambda_sparse,
rho.log_lambda_smooth,
rho.log_ard
);
assert!(
fitted_term.collapse_events().is_empty(),
"healthy planted circle assignment={assignment_label} fit should not record collapse events: {:?}",
fitted_term.collapse_events()
);
}
}
}
}
/// Pins `is_recoverable_value_probe_refusal` on three messages: the inner
/// non-convergence and non-PD row-block refusals must classify as recoverable,
/// while the row-gauge re-anchoring refusal must not.
#[test]
pub(crate) fn sae_value_probe_refusal_classification_is_inner_only() {
    let recoverable = |message: &str| {
        SaeManifoldOuterObjective::is_recoverable_value_probe_refusal(message)
    };

    let inner_non_convergence =
        "SaeManifoldTerm::reml_criterion: inner solve did not converge at fixed ρ";
    let non_pd_row_block = "SaeManifoldTerm::reml_criterion: undamped evidence factorization hit a non-PD per-row H_tt block before KKT stationarity";
    let gauge_reanchoring =
        "SaeManifoldTerm::reml_criterion: row-gauge evidence deflation count re-anchored \
         4 times within one optimization; the quotient dimension is not stabilizing";

    assert!(recoverable(inner_non_convergence));
    assert!(recoverable(non_pd_row_block));
    assert!(!recoverable(gauge_reanchoring));
}
/// Evaluates the REML criterion on two copies of the small two-atom fixture, once
/// through `reml_criterion_with_cache` and once through
/// `reml_criterion_streaming_exact`, and requires cost and total loss to agree to 1e-8.
#[test]
pub(crate) fn streaming_exact_reml_matches_full_batch_reml_small_sae() {
    let (seed_term, target, rho) = small_two_atom_periodic_term();
    let mut dense_term = seed_term.clone();
    let mut streaming_term = seed_term;

    let dense_eval = dense_term
        .reml_criterion_with_cache(target.view(), &rho, None, 2, 0.25, 1.0e-4, 1.0e-4)
        .unwrap();
    let (dense_cost, dense_loss, _cache) = dense_eval;

    let streaming_eval = streaming_term
        .reml_criterion_streaming_exact(target.view(), &rho, None, 2, 0.25, 1.0e-4, 1.0e-4)
        .unwrap();
    let (stream_cost, stream_loss) = streaming_eval;

    assert_abs_diff_eq!(stream_cost, dense_cost, epsilon = 1.0e-8);
    assert_abs_diff_eq!(stream_loss.total(), dense_loss.total(), epsilon = 1.0e-8);
}
/// Calls `reml_criterion_with_refine_policy` on two copies of the small fixture,
/// differing only in the final boolean flag (`true` vs `false`), and requires the
/// resulting cost and total loss to agree to 1e-8.
#[test]
pub(crate) fn value_probe_refine_policy_ranks_same_criterion_as_full_policy() {
    let (seed_term, target, rho) = small_two_atom_periodic_term();
    let mut full_policy_term = seed_term.clone();
    let mut probe_policy_term = seed_term;

    let (full_cost, full_loss) = full_policy_term
        .reml_criterion_with_refine_policy(target.view(), &rho, None, 2, 0.25, 1.0e-4, 1.0e-4, true)
        .expect("full-budget criterion must converge on the small fixture");
    let (probe_cost, probe_loss) = probe_policy_term
        .reml_criterion_with_refine_policy(
            target.view(),
            &rho,
            None,
            2,
            0.25,
            1.0e-4,
            1.0e-4,
            false,
        )
        .expect("probe-budget criterion must converge on the small fixture");

    assert_abs_diff_eq!(probe_cost, full_cost, epsilon = 1.0e-8);
    assert_abs_diff_eq!(probe_loss.total(), full_loss.total(), epsilon = 1.0e-8);
}
/// Pins four return values of `SaeManifoldTerm::refine_iteration_limit`:
/// the call with the final flag `true` returns the base budget; with the flag
/// `false` and fifth argument 0.5 it returns the larger progress budget; with fifth
/// argument 1.0 it stays at the base; and with `None` as the fourth argument and a
/// first argument one below the base it also returns the base.
#[test]
pub(crate) fn refine_iteration_limit_probe_budget_never_extends() {
    let probe_base = 16usize;
    let probe_limit = SaeManifoldTerm::refine_iteration_limit(
        probe_base,
        probe_base,
        probe_base,
        Some(1.0),
        0.5,
        true,
    );
    assert_eq!(probe_limit, probe_base);

    let accepted_base = 64usize;
    let accepted_progress = 256usize;

    let extended_limit = SaeManifoldTerm::refine_iteration_limit(
        accepted_base,
        accepted_base,
        accepted_progress,
        Some(1.0),
        0.5,
        false,
    );
    assert_eq!(extended_limit, accepted_progress);

    let unextended_limit = SaeManifoldTerm::refine_iteration_limit(
        accepted_base,
        accepted_base,
        accepted_progress,
        Some(1.0),
        1.0,
        false,
    );
    assert_eq!(unextended_limit, accepted_base);

    let no_history_limit = SaeManifoldTerm::refine_iteration_limit(
        accepted_base - 1,
        accepted_base,
        accepted_progress,
        None,
        1.0e9,
        false,
    );
    assert_eq!(no_history_limit, accepted_base);
}
/// First confirms that an undamped direct solve on the cold fixture fails with an
/// error that `is_undamped_evidence_row_non_pd` recognises, then requires both the
/// cached and the streaming REML criteria to succeed anyway, to produce finite
/// cost / loss / log-determinant, and to agree with each other to 1e-8.
#[test]
pub(crate) fn reml_retries_refinement_after_non_pd_undamped_evidence_factor() {
    let (mut term0, target, rho) = small_two_atom_periodic_term();
    let options = ArrowSolveOptions::direct().with_ill_conditioning_tolerated();

    // Precondition: the cold system must refuse to factor without damping.
    let cold_sys = term0
        .assemble_arrow_schur(target.view(), &rho, None)
        .unwrap();
    let cold_err = solve_arrow_newton_step_with_options(&cold_sys, 0.0, 0.0, &options)
        .err()
        .unwrap_or_else(|| {
            panic!("fixture must start with a non-PD undamped evidence row factor")
        });
    assert!(
        SaeManifoldTerm::is_undamped_evidence_row_non_pd(&cold_err),
        "fixture must start with a genuine evidence-mode non-PD row factor; got {cold_err}",
    );

    let mut dense_term = term0.clone();
    let mut streaming_term = term0;

    let (full_cost, full_loss, cache) = dense_term
        .reml_criterion_with_cache(target.view(), &rho, None, 1, 0.25, 1.0e-4, 1.0e-4)
        .expect("dense REML must refine through the cold non-PD evidence factor");
    let log_det = arrow_log_det_from_cache(&cache).expect("refined cache must carry log-det");
    assert!(full_cost.is_finite());
    assert!(full_loss.total().is_finite());
    assert!(log_det.is_finite());

    let (stream_cost, stream_loss) = streaming_term
        .reml_criterion_streaming_exact(target.view(), &rho, None, 1, 0.25, 1.0e-4, 1.0e-4)
        .expect("streaming REML must share the dense refinement retry");
    assert_abs_diff_eq!(stream_cost, full_cost, epsilon = 1.0e-8);
    assert_abs_diff_eq!(stream_loss.total(), full_loss.total(), epsilon = 1.0e-8);
}
/// Recomputes the reconstruction dispersion by hand and compares it with
/// `reconstruction_dispersion`.
///
/// The expected value is `rss / max(n*p - beta_edf - coord_edf, 1)` where
/// `coord_edf = clamp(n - alpha * trace, 0, n)` uses the ARD inverse trace. The test
/// also checks that the fixture really is in a shrunk regime (`coord_edf < n/4`) and
/// that the result is well below what a full coordinate count of `n` would give.
#[test]
pub(crate) fn reconstruction_dispersion_uses_ard_shrunk_coordinate_edf() {
    let n = 24usize;
    let p = 2usize;
    let n_f64 = n as f64;

    // One periodic atom with a fixed 3x2 decoder.
    let coords = Array2::from_shape_fn((n, 1), |(row, _)| (row as f64 + 0.25) / n as f64);
    let (phi, jet) = periodic_basis(&coords);
    let atom = SaeManifoldAtom::new(
        "periodic",
        SaeAtomBasisKind::Periodic,
        1,
        phi,
        jet,
        array![[0.30, -0.10], [0.20, 0.40], [-0.35, 0.15]],
        Array2::<f64>::eye(3),
    )
    .unwrap()
    .with_basis_evaluator(Arc::new(TestPeriodicEvaluator));
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((n, 1)),
        vec![coords],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::softmax(1.0),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();

    let target = Array2::from_shape_fn((n, p), |(row, col)| {
        let x = (row as f64 + 0.5) / n as f64;
        match col {
            0 => 0.45 * (std::f64::consts::TAU * x).sin() + 0.07,
            _ => -0.20 * (std::f64::consts::TAU * x).cos() + 0.03 * row as f64,
        }
    });

    // A large ARD precision on the single coordinate axis.
    let alpha = 250.0_f64;
    let rho = SaeManifoldRho::new(0.0, 0.8_f64.ln(), vec![array![alpha.ln()]]);
    let loss = term.loss(target.view(), &rho).unwrap();
    let sys = term
        .assemble_arrow_schur(target.view(), &rho, None)
        .unwrap();
    let options = ArrowSolveOptions::direct().with_ill_conditioning_tolerated();
    let (_delta_t, _delta_beta, cache) =
        solve_arrow_newton_step_with_options(&sys, 0.0, 0.0, &options).unwrap();
    let dispersion = term.reconstruction_dispersion(&loss, &cache, &rho).unwrap();

    // Hand-computed reference.
    let smooth_edf = term
        .decoder_smoothness_effective_dof(&cache, rho.lambda_smooth())
        .unwrap();
    let beta_edf = (term.beta_dim() as f64 - smooth_edf).max(0.0);
    let traces = term.ard_inverse_traces(&cache).unwrap();
    let coord_edf = (n_f64 - alpha * traces[0][0]).clamp(0.0, n_f64);
    let rss = 2.0 * loss.data_fit;
    let total_cells = (n * p) as f64;
    let expected = rss / (total_cells - beta_edf - coord_edf).max(1.0);
    assert_abs_diff_eq!(dispersion, expected, epsilon = 1.0e-10);

    // Contrast with the formula that charges the full coordinate count.
    let old_full_coordinate_edf = n_f64;
    let old_full_coordinate_dispersion =
        rss / (total_cells - beta_edf - old_full_coordinate_edf).max(1.0);
    assert!(
        coord_edf < 0.25 * old_full_coordinate_edf,
        "test setup must put the coordinate axis in an ARD-shrunk regime; \
         coord_edf={coord_edf}, old_full_coordinate_edf={old_full_coordinate_edf}"
    );
    assert!(
        dispersion < 0.75 * old_full_coordinate_dispersion,
        "φ̂ must use the ARD-shrunk coordinate edf, not the old full \
         coordinate count: got {dispersion}, old formula {old_full_coordinate_dispersion}"
    );
}
/// Checks `sae_streaming_plan_from_budget` under a huge and a tiny budget (dense and
/// direct-admitted vs streaming and not direct-admitted), then verifies that the
/// streaming exact log-determinant matches the dense cache's log-determinant to 1e-8
/// on the same assembled system.
#[test]
pub(crate) fn streaming_plan_routes_by_memory_budget_with_identical_logdet() {
    let (term0, target, rho) = small_two_atom_periodic_term();
    let total_basis: usize = term0.atoms.iter().map(|atom| atom.basis_size()).sum();
    let d_max = term0
        .atoms
        .iter()
        .map(|atom| atom.latent_dim)
        .max()
        .unwrap();

    // Generous budget: expect the dense, direct route.
    let roomy_plan = sae_streaming_plan_from_budget(
        term0.n_obs(),
        total_basis,
        term0.k_atoms(),
        d_max,
        term0.beta_dim(),
        usize::MAX / 4,
        1024 * 1024,
        usize::MAX / 2,
    );
    assert!(!roomy_plan.streaming);
    assert!(roomy_plan.direct_admitted);

    // Starved budget: expect the streaming route.
    let starved_plan = sae_streaming_plan_from_budget(
        term0.n_obs(),
        total_basis,
        term0.k_atoms(),
        d_max,
        term0.beta_dim(),
        1,
        512,
        2,
    );
    assert!(starved_plan.streaming);
    assert!(!starved_plan.direct_admitted);

    // Both routes must report the same log-determinant after the inner fit.
    let mut fitted_term = term0.clone();
    fitted_term
        .reml_criterion_with_cache(target.view(), &rho, None, 2, 0.25, 1.0e-4, 1.0e-4)
        .unwrap();
    let sys = fitted_term
        .assemble_arrow_schur(target.view(), &rho, None)
        .unwrap();
    let options = ArrowSolveOptions::direct().with_ill_conditioning_tolerated();
    let (_delta_t, _delta_beta, dense_cache) =
        solve_arrow_newton_step_with_options(&sys, 0.0, 0.0, &options).unwrap();
    let full_logdet = arrow_log_det_from_cache(&dense_cache).unwrap();

    let mut streaming = StreamingArrowSchur::from_system(&sys, starved_plan.chunk_size);
    let streaming_logdet = streaming.exact_arrow_log_det(0.0, 0.0, &options).unwrap();
    assert_abs_diff_eq!(streaming_logdet, full_logdet, epsilon = 1.0e-8);
}
/// Plans a problem with a 98 304-wide border under a 60 GiB budget and checks that
/// the plan estimates a dense Schur footprint above the budget and a matrix-free peak
/// below it, chooses streaming with matrix-free admitted and direct refused, and
/// selects `ArrowSolverMode::InexactPCG` for that border width.
#[test]
pub(crate) fn giant_host_working_set_plan_flips_to_matrix_free_before_dense_allocation() {
    const GIB: usize = 1024 * 1024 * 1024;
    const MIB: usize = 1024 * 1024;

    let n_obs = 128usize;
    let total_basis = 48usize;
    let k_atoms = 8usize;
    let d_max = 2usize;
    let p_out = 2048usize;
    let border_dim = total_basis * p_out;
    let budget = 60usize * GIB;

    let plan = sae_streaming_plan_from_budget(
        n_obs,
        total_basis,
        k_atoms,
        d_max,
        border_dim,
        budget,
        8 * MIB,
        120usize * GIB,
    );

    assert_eq!(border_dim, 98_304);
    let expected_row_cross_bytes =
        n_obs * k_atoms * (1 + d_max) * border_dim * SAE_BYTES_PER_F64;
    assert_eq!(plan.estimated_row_cross_bytes, expected_row_cross_bytes);
    assert!(plan.estimated_dense_schur_bytes > budget);
    assert!(plan.estimated_matrix_free_peak_bytes < budget);
    assert!(plan.streaming);
    assert!(!plan.direct_admitted);
    assert!(plan.matrix_free_admitted);
    let chosen_mode = plan.solve_options_for_border_dim(border_dim).mode;
    assert_eq!(chosen_mode, ArrowSolverMode::InexactPCG);
}
/// Builds a layout with 100 000 atoms of which only three are active per row and
/// checks that each row's active block has width 6, so the summed `q²` work is
/// `n * 36` — more than nine orders of magnitude below the dense `n * (2K)²`.
#[test]
pub(crate) fn sparse_active_layout_work_scales_with_active_atoms_not_total_k() {
    let n = 3usize;
    let k_atoms = 100_000usize;

    // Three widely separated active atoms per row.
    let active_rows: Vec<Vec<usize>> = (0..n)
        .map(|row| vec![row, 10_000 + row, 90_000 + row])
        .collect();
    let coord_dims = vec![1usize; k_atoms];
    let coord_offsets_full: Vec<usize> = (0..k_atoms).map(|k| k_atoms + k).collect();
    let layout = SaeRowLayout::from_active_atoms(active_rows, coord_dims, coord_offsets_full);

    for row in 0..n {
        assert_eq!(layout.active_atoms[row].len(), 3);
        assert_eq!(layout.row_q_active(row), 6);
    }

    let mut compact_work = 0usize;
    for row in 0..n {
        let q = layout.row_q_active(row);
        compact_work += q * q;
    }
    let dense_q = 2 * k_atoms;
    let dense_work = n * dense_q * dense_q;
    assert!(compact_work < dense_work / 1_000_000_000);
    assert_eq!(compact_work, n * 36);
}
/// Runs `run_joint_fit_arrow_schur` on a three-row, one-atom periodic fixture with
/// an all-zero smoothness penalty matrix and a log-smoothness of -20, and requires
/// the call to return `Ok`.
#[test]
pub(crate) fn run_joint_fit_arrow_schur_escalates_ridge_on_non_pd_row_block() {
    let coords = array![[0.1], [0.4], [0.7]];
    let (phi, jet) = periodic_basis(&coords);
    let decoder = array![[0.05], [-0.05], [0.05]];
    let zero_penalty = Array2::<f64>::zeros((3, 3));
    let atom = SaeManifoldAtom::new(
        "periodic",
        SaeAtomBasisKind::Periodic,
        1,
        phi,
        jet,
        decoder,
        zero_penalty,
    )
    .unwrap();
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((3, 1)),
        vec![coords],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::softmax(0.7),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();

    let target = array![[0.20], [-0.10], [0.45]];
    let mut rho = SaeManifoldRho::new(0.0, -20.0, vec![Array1::<f64>::zeros(1)]);
    let result =
        term.run_joint_fit_arrow_schur(target.view(), &mut rho, None, 1, 1.0, 1.0e-6, 1.0e-6);
    assert!(
        result.is_ok(),
        "run_joint_fit_arrow_schur should recover from degenerate H_tt via LM ridge escalation; got: {result:?}",
    );
}
/// Feeds a 5-column periodic design evaluated at only three distinct coordinates
/// (each repeated twice) through `reduce_atoms_to_data_supported_rank` and checks:
/// the basis shrinks to 3 columns with matching decoder / jacobian / penalty shapes,
/// the reduced Gram matrix has no eigenvalue below `1e-9 * max`, the reconstruction
/// is unchanged to 1e-9, and re-evaluating through the atom's basis evaluator
/// reproduces the reduced design to 1e-12.
#[test]
pub(crate) fn rank_revealing_reduction_collapses_unexcited_circle_harmonic_to_full_rank() {
    use crate::linalg::faer_ndarray::FaerEigh;

    let evaluator = Arc::new(PeriodicHarmonicEvaluator::new(5).unwrap());
    let coords = array![[0.1], [0.45], [0.8], [0.1], [0.45], [0.8]];
    let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();
    assert_eq!(
        phi.ncols(),
        5,
        "fixed-depth circle basis emits M = 5 columns"
    );

    let decoder = array![[0.05], [-0.05], [0.05], [0.02], [-0.02]];
    let atom = SaeManifoldAtom::new(
        "periodic",
        SaeAtomBasisKind::Periodic,
        1,
        phi.clone(),
        jet,
        decoder.clone(),
        Array2::<f64>::eye(5),
    )
    .unwrap()
    .with_basis_second_jet(evaluator.clone());
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::ones((6, 1)),
        vec![coords.clone()],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::softmax(0.7),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();

    let recon_before = phi.dot(&decoder);
    term.reduce_atoms_to_data_supported_rank().unwrap();

    // Shapes after the reduction.
    let r = term.atoms[0].basis_size();
    assert_eq!(
        r, 3,
        "rank-revealing reduction must drop the unexcited harmonic (r = 3 < M = 5)",
    );
    assert_eq!(term.atoms[0].decoder_coefficients.nrows(), 3);
    assert_eq!(term.atoms[0].basis_jacobian.dim(), (6, 3, 1));
    assert_eq!(term.atoms[0].smooth_penalty.dim(), (3, 3));

    // The reduced design must be numerically full rank.
    let reduced_design = term.atoms[0].basis_values.clone();
    let gram = reduced_design.t().dot(&reduced_design);
    let (evals, _) = gram.eigh(faer::Side::Lower).unwrap();
    let max_eig = evals.iter().copied().fold(0.0_f64, f64::max);
    for &lam in evals.iter() {
        assert!(
            lam > 1e-9 * max_eig,
            "reduced design Gram must be full rank; got eigenvalue {lam} (max {max_eig})",
        );
    }

    // The reconstruction must survive the change of basis.
    let recon_after = reduced_design.dot(&term.atoms[0].decoder_coefficients);
    for i in 0..recon_before.nrows() {
        let before = recon_before[[i, 0]];
        let after = recon_after[[i, 0]];
        assert!(
            (before - after).abs() < 1e-9,
            "reduction must not change the data-fit reconstruction at row {i}: \
             before={} after={}",
            before,
            after,
        );
    }

    // Re-evaluating through the installed evaluator must give the reduced design.
    let installed_evaluator = term.atoms[0].basis_evaluator.as_ref().unwrap();
    let (refreshed, _) = installed_evaluator.evaluate(coords.view()).unwrap();
    assert_eq!(
        refreshed.ncols(),
        3,
        "the SubspaceReducedEvaluator must re-emit the reduced width on refresh",
    );
    for ((row, col), &value) in refreshed.indexed_iter() {
        assert!(
            (value - reduced_design[[row, col]]).abs() < 1e-12,
            "refresh must reproduce the reduced design bit-for-bit",
        );
    }
}
/// Wraps a 7-harmonic periodic evaluator in a `SubspaceReducedEvaluator` with a
/// 7x4 projection `q` (four eigenvector columns of a fixed symmetric matrix) and
/// checks that the reduced values, first jet, second jet and third jet each equal the
/// inner quantity right-multiplied by `q`, slice by slice, to 1e-12.
#[test]
pub(crate) fn subspace_reduced_evaluator_composes_all_jets_by_q() {
    let inner = Arc::new(PeriodicHarmonicEvaluator::new(7).unwrap());
    let coords = array![[-0.3_f64], [0.0], [0.15], [0.42], [0.88]];
    let m = inner.num_basis;

    // Symmetric matrix with entries 1 / (1 + |i - j|); its eigenvectors supply q.
    let a = Array2::<f64>::from_shape_fn((m, m), |(i, j)| {
        1.0 / (1.0 + (i as f64 - j as f64).abs())
    });
    let (_evals, evecs) = a.eigh(Side::Lower).unwrap();
    let r = 4usize;
    let q = Array2::<f64>::from_shape_fn((m, r), |(row, col)| evecs[[row, col]]);

    let reduced = SubspaceReducedEvaluator::new(inner.clone(), q.clone()).unwrap();
    assert_eq!(reduced.inner_width(), m);
    assert_eq!(reduced.reduced_width(), r);

    // Values and first jet.
    let (phi_in, jet_in) = inner.evaluate(coords.view()).unwrap();
    let (phi_red, jet_red) = reduced.evaluate(coords.view()).unwrap();
    let phi_expect = phi_in.dot(&q);
    assert_eq!(phi_red.dim(), phi_expect.dim());
    for ((i, j), &wanted) in phi_expect.indexed_iter() {
        assert_abs_diff_eq!(phi_red[[i, j]], wanted, epsilon = 1e-12);
    }
    for axis in 0..jet_in.shape()[2] {
        let expect = jet_in.slice(s![.., .., axis]).to_owned().dot(&q);
        for i in 0..jet_red.shape()[0] {
            for j in 0..r {
                assert_abs_diff_eq!(jet_red[[i, j, axis]], expect[[i, j]], epsilon = 1e-12);
            }
        }
    }

    // Second jet: one (rows x basis) slice per axis pair.
    let h_in = inner.second_jet(coords.view()).unwrap();
    let h_red = reduced.second_jet(coords.view()).unwrap();
    let d = h_in.shape()[2];
    for a_ax in 0..d {
        for c_ax in 0..d {
            let expect = h_in.slice(s![.., .., a_ax, c_ax]).to_owned().dot(&q);
            for i in 0..h_red.shape()[0] {
                for j in 0..r {
                    assert_abs_diff_eq!(
                        h_red[[i, j, a_ax, c_ax]],
                        expect[[i, j]],
                        epsilon = 1e-12
                    );
                }
            }
        }
    }

    // Third jet: one slice per axis triple.
    let t_in = inner.third_jet(coords.view()).unwrap();
    let t_red = reduced.third_jet_dyn(coords.view()).unwrap().unwrap();
    for a_ax in 0..d {
        for c_ax in 0..d {
            for e_ax in 0..d {
                let expect = t_in
                    .slice(s![.., .., a_ax, c_ax, e_ax])
                    .to_owned()
                    .dot(&q);
                for i in 0..t_red.shape()[0] {
                    for j in 0..r {
                        assert_abs_diff_eq!(
                            t_red[[i, j, a_ax, c_ax, e_ax]],
                            expect[[i, j]],
                            epsilon = 1e-12
                        );
                    }
                }
            }
        }
    }
}
/// Applies `reduce_atoms_to_data_supported_rank` twice to the rank-deficient
/// 5-harmonic fixture. The first pass must reduce the basis to 3 columns; the second
/// must keep it at 3 and leave the design and the decoder exactly equal, entry by
/// entry, to their values after the first pass.
#[test]
pub(crate) fn rank_reduction_is_idempotent_on_already_reduced_atom() {
    let evaluator = Arc::new(PeriodicHarmonicEvaluator::new(5).unwrap());
    let coords = array![[0.1], [0.45], [0.8], [0.1], [0.45], [0.8]];
    let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();
    let atom = SaeManifoldAtom::new(
        "periodic",
        SaeAtomBasisKind::Periodic,
        1,
        phi,
        jet,
        array![[0.05], [-0.05], [0.05], [0.02], [-0.02]],
        Array2::<f64>::eye(5),
    )
    .unwrap()
    .with_basis_second_jet(evaluator.clone());
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::ones((6, 1)),
        vec![coords],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::softmax(0.7),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();

    // First pass performs the actual reduction.
    term.reduce_atoms_to_data_supported_rank().unwrap();
    assert_eq!(term.atoms[0].basis_size(), 3);
    let design_after_first = term.atoms[0].basis_values.clone();
    let decoder_after_first = term.atoms[0].decoder_coefficients.clone();

    // Second pass must change nothing.
    term.reduce_atoms_to_data_supported_rank().unwrap();
    assert_eq!(
        term.atoms[0].basis_size(),
        3,
        "a second reduction pass on a full-rank reduced atom must be a no-op",
    );
    let reduced_atom = &term.atoms[0];
    for row in 0..design_after_first.nrows() {
        for col in 0..3 {
            assert_eq!(
                reduced_atom.basis_values[[row, col]],
                design_after_first[[row, col]],
                "idempotent reduction must leave the reduced design byte-identical",
            );
        }
    }
    for row in 0..3 {
        assert_eq!(
            reduced_atom.decoder_coefficients[[row, 0]],
            decoder_after_first[[row, 0]],
            "idempotent reduction must leave the reduced decoder byte-identical",
        );
    }
}
/// Runs `reduce_atoms_to_data_supported_rank` on a 5-harmonic design evaluated at
/// five distinct coordinates and checks that the basis keeps all 5 columns and that
/// every design entry and decoder entry equals its pre-reduction value exactly.
#[test]
pub(crate) fn full_rank_circle_design_keeps_full_harmonic_depth_unchanged() {
    let evaluator = Arc::new(PeriodicHarmonicEvaluator::new(5).unwrap());
    let coords = array![[0.05], [0.27], [0.46], [0.68], [0.91]];
    let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();
    let decoder = array![[0.05], [-0.05], [0.05], [0.02], [-0.02]];
    let atom = SaeManifoldAtom::new(
        "periodic",
        SaeAtomBasisKind::Periodic,
        1,
        phi.clone(),
        jet,
        decoder.clone(),
        Array2::<f64>::eye(5),
    )
    .unwrap()
    .with_basis_second_jet(evaluator.clone());
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::ones((5, 1)),
        vec![coords],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::softmax(0.7),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();

    term.reduce_atoms_to_data_supported_rank().unwrap();

    let untouched_atom = &term.atoms[0];
    assert_eq!(
        untouched_atom.basis_size(),
        5,
        "a full-rank circle design must keep all 5 harmonic columns",
    );
    for row in 0..5 {
        for col in 0..5 {
            assert_eq!(
                untouched_atom.basis_values[[row, col]],
                phi[[row, col]],
                "full-rank basis must be unchanged by the (no-op) reduction",
            );
        }
    }
    for row in 0..5 {
        assert_eq!(
            untouched_atom.decoder_coefficients[[row, 0]],
            decoder[[row, 0]],
            "full-rank decoder must be unchanged by the (no-op) reduction",
        );
    }
}
/// Calls `solve_newton_step` on a three-row, one-atom periodic fixture with an
/// all-zero smoothness penalty matrix and a log-smoothness of -20, and requires the
/// call to return `Ok`.
#[test]
pub(crate) fn solve_newton_step_escalates_ridge_on_non_pd_row_block() {
    let coords = array![[0.1], [0.4], [0.7]];
    let (phi, jet) = periodic_basis(&coords);
    let decoder = array![[0.05], [-0.05], [0.05]];
    let zero_penalty = Array2::<f64>::zeros((3, 3));
    let atom = SaeManifoldAtom::new(
        "periodic",
        SaeAtomBasisKind::Periodic,
        1,
        phi,
        jet,
        decoder,
        zero_penalty,
    )
    .unwrap();
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((3, 1)),
        vec![coords],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::softmax(0.7),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();

    let target = array![[0.20], [-0.10], [0.45]];
    let rho = SaeManifoldRho::new(0.0, -20.0, vec![Array1::<f64>::zeros(1)]);
    let result = term.solve_newton_step(target.view(), &rho, None, 1.0e-6, 1.0e-6);
    assert!(
        result.is_ok(),
        "solve_newton_step should recover from degenerate H_tt via LM ridge escalation; got: {result:?}",
    );
}
/// Compares the actual change in total loss under a small decoder step with the
/// quadratic model built from the assembled system.
///
/// The step is `epsilon = 1e-3` along the normalised negative `gb` direction. The
/// prediction is `gb · delta + 0.5 * delta · (H delta)`, where `H delta` comes from
/// `effective_penalty_op().matvec`. The two must agree to within 1e-4.
#[test]
pub(crate) fn sae_arrow_schur_beta_quadratic_model_matches_penalized_loss_change() {
    let coords = array![[0.10], [0.35], [0.80]];
    let (phi, jet) = periodic_basis(&coords);
    let atom = SaeManifoldAtom::new(
        "periodic",
        SaeAtomBasisKind::Periodic,
        1,
        phi,
        jet,
        array![[0.65], [-0.45], [0.25]],
        array![[3.0, 0.4, -0.2], [0.1, 2.5, 0.3], [-0.5, 0.2, 1.8]],
    )
    .unwrap();
    let assignment = SaeAssignment::from_blocks_with_mode(
        Array2::<f64>::zeros((3, 1)),
        vec![coords],
        AssignmentMode::softmax(0.7),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();
    let target = array![[0.20], [-0.10], [0.45]];
    let rho = SaeManifoldRho::new(0.0, 1.3_f64.ln(), vec![array![0.9_f64.ln()]]);

    // Assemble at the starting point and record the baseline.
    let sys = term
        .assemble_arrow_schur(target.view(), &rho, None)
        .unwrap();
    let beta0 = term.flatten_beta();
    let loss0 = term.loss(target.view(), &rho).unwrap().total();

    // Unit-length step direction opposite to the decoder gradient block.
    let mut direction = sys.gb.mapv(|v| -v);
    let direction_norm = direction.iter().map(|v| v * v).sum::<f64>().sqrt();
    assert!(direction_norm > 1.0e-12);
    for value in direction.iter_mut() {
        *value /= direction_norm;
    }
    let epsilon = 1.0e-3;
    let delta = direction.mapv(|v| epsilon * v);

    // Move the decoder and measure the realised loss change. `delta` is borrowed
    // because it is reused below for the linear and quadratic model terms.
    let beta_trial = beta0 + &delta;
    term.set_flat_beta(beta_trial.view()).unwrap();
    let actual = term.loss(target.view(), &rho).unwrap().total() - loss0;

    // Quadratic model: gradient term plus half the curvature along the step.
    let linear = sys.gb.dot(&delta);
    let mut hbb_delta = Array1::<f64>::zeros(delta.len());
    {
        let op = sys.effective_penalty_op();
        let d_slice = delta.as_slice().expect("delta is contiguous");
        let hd_slice = hbb_delta.as_slice_mut().expect("hbb_delta is contiguous");
        op.matvec(d_slice, hd_slice);
    }
    let quadratic = 0.5 * delta.dot(&hbb_delta);
    let predicted = linear + quadratic;
    let error = (actual - predicted).abs();
    assert!(
        error <= 1.0e-4,
        "actual={actual:.12e}, predicted={predicted:.12e}, error={error:.12e}"
    );
}
/// Builds a row layout from dense assignment weights (top-k = 2, cutoff = 0.05)
/// and checks the selected atoms, the active row widths, and how a compact row
/// vector is scattered back into the full-width row.
#[test]
pub(crate) fn sae_row_layout_from_dense_weights_top_k_and_cutoff() {
    let assignments = vec![
        Array1::from_vec(vec![0.7, 0.01, 0.29]),
        Array1::from_vec(vec![0.001, 0.002, 0.0005]),
    ];
    let coord_dims = vec![2usize, 1, 2];
    let coord_offsets_full = vec![3usize, 5, 6];
    let layout =
        SaeRowLayout::from_dense_weights(&assignments, 2, 0.05, coord_dims, coord_offsets_full);

    // Which atoms survive in each row, and how wide the compact rows are.
    assert_eq!(layout.active_atoms[0], vec![0, 2]);
    assert_eq!(layout.active_atoms[1], vec![1]);
    assert_eq!(layout.row_q_active(0), 6);
    assert_eq!(layout.row_q_active(1), 2);

    // Expanding row 0 must leave the slots of the inactive atom at zero.
    let compact = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0];
    let mut full = vec![0.0_f64; 8];
    layout.expand_row(0, &compact, &mut full);
    let expected_full = [1.0_f64, 0.0, 2.0, 3.0, 4.0, 0.0, 5.0, 6.0];
    for (slot, want) in expected_full.iter().enumerate() {
        assert_eq!(full[slot], *want);
    }
}
/// Checks that a `MechanismSparsityPenalty` registered on the decoder block
/// contributes to the β gradient `gb` of the assembled arrow-Schur system.
///
/// The penalty-only contribution is isolated by subtracting the `gb` of an
/// assembly without the registry, then compared at (basis=1, feat=0) against a
/// hand-computed group-norm gradient for the feature group `[0, 1]`.
#[test]
pub(crate) fn sae_mechsparsity_beta_block_routes_through_arrow_schur_gb() {
    let coords = array![[0.10], [0.35], [0.80]];
    let (phi, jet) = periodic_basis(&coords);
    let decoder = array![
        [0.7, -0.2, 0.05, 0.4],
        [-0.5, 0.6, -0.1, 0.3],
        [0.2, 0.0, -0.4, -0.6],
    ];
    let atom = SaeManifoldAtom::new(
        "periodic",
        SaeAtomBasisKind::Periodic,
        1,
        phi,
        jet,
        decoder.clone(),
        Array2::<f64>::eye(3),
    )
    .unwrap();
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((3, 1)),
        vec![coords],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::softmax(0.7),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();

    // Decoder is m basis rows by p output features, flattened to length m * p.
    let m = 3usize;
    let p = 4usize;
    let slice = PsiSlice::full(m * p, Some(m));
    let penalty = MechanismSparsityPenalty::new(
        slice,
        vec![vec![0, 1], vec![2, 3]],
        1.0,
        1.0e-6,
        term.n_obs() as f64,
        false,
    )
    .unwrap();
    let mut registry = AnalyticPenaltyRegistry::new();
    registry.push(AnalyticPenaltyKind::MechanismSparsity(Arc::new(penalty)));

    let target = array![
        [0.20, 0.10, -0.05, 0.25],
        [-0.10, 0.30, 0.15, -0.20],
        [0.45, -0.05, 0.10, 0.30],
    ];
    let rho = SaeManifoldRho::new(0.0, -6.0, vec![Array1::<f64>::zeros(1)]);
    let sys = term
        .assemble_arrow_schur(target.view(), &rho, Some(&registry))
        .unwrap();
    assert_eq!(sys.gb.len(), m * p, "gb should match flatten_beta length");

    // Every gradient entry must be finite; track the largest magnitude.
    let mut absmax = 0.0_f64;
    for v in sys.gb.iter().copied() {
        assert!(v.is_finite());
        if v.abs() > absmax {
            absmax = v.abs();
        }
    }
    assert!(
        absmax > 1.0e-6,
        "MechSparsity must inject a non-trivial gradient into the SAE arrow-Schur gb; absmax={absmax:.3e}"
    );

    let sys_no_penalty = term
        .assemble_arrow_schur(target.view(), &rho, None)
        .unwrap();
    let beta = term.flatten_beta();

    // Hand-computed reference for decoder row 1, feature group [0, 1]:
    // sqrt(group size) * entry / sqrt(sum of squares + eps^2), entries -0.5 and 0.6.
    let expected = {
        let s = (0.5_f64.powi(2) + 0.6_f64.powi(2) + 1.0e-12).sqrt();
        (2.0_f64).sqrt() * (-0.5_f64) / s
    };

    // Flat index of the probed entry (basis = 1, feat = 0) in row-major order.
    let (basis, feat) = (1usize, 0usize);
    let probe = basis * p + feat;
    let delta = sys.gb[probe] - sys_no_penalty.gb[probe];
    assert!(
        (delta - expected).abs() <= 1.0e-6,
        "expected MechSparsity gb contribution at (basis=1, feat=0) ≈ {expected:.6e}, \
         got Δgb={delta:.6e} (gb_with={:.6e}, gb_without={:.6e}, beta entry = {})",
        sys.gb[probe],
        sys_no_penalty.gb[probe],
        beta[probe]
    );
}
/// Smoothed nuclear norm of `decoder`: the sum over its singular values `σ` of
/// `sqrt(σ² + eps²) - eps`.
///
/// Panics if the SVD fails.
pub(crate) fn smoothed_nuclear_norm(decoder: &Array2<f64>, eps: f64) -> f64 {
    // Only the singular values are requested; both factor flags are off.
    let (_, singular_values, _) = decoder.clone().svd(false, false).unwrap();
    let eps_squared = eps * eps;
    singular_values
        .iter()
        .map(|&sigma| (sigma * sigma + eps_squared).sqrt() - eps)
        .sum()
}
/// Checks that a `NuclearNormPenalty` registered against the SAE term is
/// accepted by validation and contributes to the β gradient `gb`.
///
/// Steps:
/// 1. the penalty-only gradient (`gb` with registry minus `gb` without) is
///    finite and non-trivial;
/// 2. it matches `grad_target` of an independently built reference penalty;
/// 3. a small step against it lowers the smoothed nuclear norm of the decoder;
/// 4. `sys.hbb` is empty and the diagonal of the effective penalty operator
///    is non-negative.
#[test]
pub(crate) fn sae_nuclear_norm_beta_block_routes_through_gb_and_shrinks_spectrum() {
    let coords = array![[0.10], [0.35], [0.80]];
    let (phi, jet) = periodic_basis(&coords);
    let decoder = array![
        [0.9, -0.2, 0.05, 0.4],
        [-0.5, 0.7, -0.1, 0.3],
        [0.2, 0.1, -0.8, -0.6],
    ];
    // Decoder is m basis rows by p output features.
    let m = 3usize;
    let p = 4usize;
    let atom = SaeManifoldAtom::new(
        "periodic",
        SaeAtomBasisKind::Periodic,
        1,
        phi,
        jet,
        decoder.clone(),
        Array2::<f64>::eye(3),
    )
    .unwrap();
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((3, 1)),
        vec![coords],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::softmax(0.7),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();

    let eps = 1.0e-6;
    let slice = PsiSlice::full(m * p, Some(m));
    let penalty = NuclearNormPenalty::new(slice, 1.0, p, eps, None, false).unwrap();
    let mut registry = AnalyticPenaltyRegistry::new();
    registry.push(AnalyticPenaltyKind::NuclearNorm(Arc::new(penalty)));
    term.validate_analytic_penalty_registry(&registry)
        .expect("NuclearNorm must be accepted (redirected to the β block)");

    let target = array![
        [0.20, 0.10, -0.05, 0.25],
        [-0.10, 0.30, 0.15, -0.20],
        [0.45, -0.05, 0.10, 0.30],
    ];
    let rho = SaeManifoldRho::new(0.0, -6.0, vec![Array1::<f64>::zeros(1)]);
    let baseline = term
        .assemble_arrow_schur(target.view(), &rho, None)
        .unwrap();
    let sys = term
        .assemble_arrow_schur(target.view(), &rho, Some(&registry))
        .unwrap();
    assert_eq!(sys.gb.len(), m * p, "gb should match flatten_beta length");
    assert_eq!(
        baseline.gb.len(),
        m * p,
        "baseline gb should match flatten_beta length"
    );

    // Isolate the penalty's own gradient by differencing the two assemblies.
    let mut absmax = 0.0_f64;
    let mut penalty_grad = Array1::<f64>::zeros(m * p);
    for ((dst, sys_g), baseline_g) in penalty_grad
        .iter_mut()
        .zip(sys.gb.iter())
        .zip(baseline.gb.iter())
    {
        let v = *sys_g - *baseline_g;
        assert!(v.is_finite());
        *dst = v;
        absmax = absmax.max(v.abs());
    }
    assert!(
        absmax > 1.0e-6,
        "NuclearNorm must inject a non-trivial gradient into the SAE \
         arrow-Schur gb; absmax={absmax:.3e}"
    );

    // Reference penalty over the same flat β range, evaluated directly.
    let per_atom = NuclearNormPenalty::new(
        PsiSlice {
            range: 0..m * p,
            latent_dim: Some(p),
        },
        1.0,
        m,
        eps,
        None,
        false,
    )
    .unwrap();
    let beta = term.flatten_beta();
    let ref_grad = per_atom.grad_target(beta.view(), Array1::<f64>::zeros(0).view());
    for j in 0..m * p {
        assert!(
            (penalty_grad[j] - ref_grad[j]).abs() <= 1.0e-9,
            "penalty gb[{j}]={:.12e} must equal analytic spectral grad {:.12e}",
            penalty_grad[j],
            ref_grad[j]
        );
    }

    // A descent step along the penalty gradient must reduce the smoothed norm.
    let base_norm = smoothed_nuclear_norm(&decoder, eps);
    let step = 1.0e-2;
    let mut shrunk = decoder.clone();
    for ((row, feat), value) in shrunk.indexed_iter_mut() {
        *value -= step * penalty_grad[row * p + feat];
    }
    let shrunk_norm = smoothed_nuclear_norm(&shrunk, eps);
    assert!(
        shrunk_norm < base_norm,
        "a step along gb must shrink the decoder spectrum: \
         before={base_norm:.9e}, after={shrunk_norm:.9e}"
    );

    // `hbb` itself is empty here; curvature is read from the operator diagonal.
    assert!(sys.hbb.is_empty());
    let mut hbb_diag = vec![0.0_f64; m * p];
    sys.effective_penalty_op().diagonal(&mut hbb_diag);
    for i in 0..m * p {
        assert!(
            hbb_diag[i] >= -1.0e-9,
            "hbb diagonal must be non-negative (PSD majorizer); hbb[{i},{i}]={:.3e}",
            hbb_diag[i]
        );
    }
}
/// Test-only basis evaluator that delegates `evaluate` to `periodic_basis`.
///
/// Its second- and third-jet hooks return `None` for one-column coordinates
/// and `Some(Err(..))` for any other column count.
#[derive(Debug)]
pub(crate) struct TestPeriodicEvaluator;

impl SaeBasisEvaluator for TestPeriodicEvaluator {
    /// Evaluates the basis and its first jet on an owned copy of `coords`.
    fn evaluate(&self, coords: ArrayView2<'_, f64>) -> Result<(Array2<f64>, Array3<f64>), String> {
        let owned = coords.to_owned();
        Ok(periodic_basis(&owned))
    }

    /// Returns an error for non-1-D coordinates, otherwise `None`.
    fn second_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array4<f64>, String>> {
        match coords.ncols() {
            1 => None,
            _ => Some(Err(format!(
                "TestPeriodicEvaluator::second_jet_dyn: expected latent_dim 1, got {}",
                coords.ncols()
            ))),
        }
    }

    /// Returns an error for non-1-D coordinates, otherwise `None`.
    fn third_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array5<f64>, String>> {
        match coords.ncols() {
            1 => None,
            _ => Some(Err(format!(
                "TestPeriodicEvaluator::third_jet_dyn: expected latent_dim 1, got {}",
                coords.ncols()
            ))),
        }
    }
}
/// Running record of the worst analytic-vs-finite-difference mismatch seen so
/// far, ranked by relative error.
#[derive(Debug, Clone)]
pub(crate) struct SaeFdWorst {
    /// Index passed to `observe` for the worst entry.
    pub(crate) index: usize,
    /// Analytic derivative value at the worst entry.
    pub(crate) analytic: f64,
    /// Finite-difference estimate at the worst entry.
    pub(crate) finite_difference: f64,
    /// `|analytic - finite_difference|` at the worst entry.
    pub(crate) absolute_error: f64,
    /// Absolute error divided by `max(|analytic|, |finite_difference|, 1e-9)`;
    /// `+∞` when the mismatch itself is not finite.
    pub(crate) relative_error: f64,
}

impl SaeFdWorst {
    /// Creates an empty record with all errors at zero.
    pub(crate) fn new() -> Self {
        Self {
            index: 0,
            analytic: 0.0,
            finite_difference: 0.0,
            absolute_error: 0.0,
            relative_error: 0.0,
        }
    }

    /// Folds one comparison into the record, keeping the entry with the
    /// largest relative error.
    ///
    /// A non-finite mismatch (NaN or infinite `analytic` / `finite_difference`)
    /// is recorded with `relative_error = +∞`. Without this, `NaN > x` is
    /// always false and `∞ / ∞` is NaN, so such entries would be dropped and a
    /// tolerance check on the record would pass vacuously.
    pub(crate) fn observe(&mut self, index: usize, analytic: f64, finite_difference: f64) {
        let absolute_error = (analytic - finite_difference).abs();
        // The floor keeps the ratio meaningful when both values are near zero.
        let scale = analytic.abs().max(finite_difference.abs()).max(1.0e-9);
        let relative_error = if absolute_error.is_finite() {
            absolute_error / scale
        } else {
            f64::INFINITY
        };
        if relative_error > self.relative_error {
            self.index = index;
            self.analytic = analytic;
            self.finite_difference = finite_difference;
            self.absolute_error = absolute_error;
            self.relative_error = relative_error;
        }
    }
}
/// Result of one finite-difference gradient check: the worst mismatch found in
/// the coordinate block and in the decoder (β) block for a labelled case.
#[derive(Debug, Clone)]
pub(crate) struct SaeFdBlockReport {
/// Case label the check was run under.
pub(crate) label: String,
/// Loss (or penalized objective) evaluated at the unperturbed term.
pub(crate) base_loss: f64,
/// Worst mismatch over the latent-coordinate gradient entries.
pub(crate) coord: SaeFdWorst,
/// Worst mismatch over the flattened decoder (β) gradient entries.
pub(crate) decoder: SaeFdWorst,
}
/// Deterministic `(n_basis, p_out)` decoder fixture whose entries are a fixed
/// trigonometric function of the (1-based) basis and output indices.
pub(crate) fn sae_fd_decoder(n_basis: usize, p_out: usize) -> Array2<f64> {
    Array2::<f64>::from_shape_fn((n_basis, p_out), |(basis, out_col)| {
        let phase = 0.73 * ((basis + 1) as f64) + 1.17 * ((out_col + 1) as f64);
        0.16 * phase.sin() + 0.05 * (1.9 * phase).cos()
    })
}
/// Deterministic `(n_obs, p_out)` target fixture whose entries are a fixed
/// trigonometric function of the (1-based) row and column indices.
pub(crate) fn sae_fd_target(n_obs: usize, p_out: usize) -> Array2<f64> {
    Array2::<f64>::from_shape_fn((n_obs, p_out), |(row, out_col)| {
        let x = (row as f64) + 1.0;
        let y = (out_col as f64) + 1.0;
        0.21 * (0.31 * x + 0.47 * y).sin() - 0.13 * (0.19 * x * y).cos()
    })
}
/// Deterministic `(n_obs, 1)` latent-coordinate fixture for the named FD case.
///
/// Panics on an unrecognised `label` as soon as a row is generated.
pub(crate) fn sae_fd_coords(label: &str, n_obs: usize) -> Array2<f64> {
    Array2::<f64>::from_shape_fn((n_obs, 1), |(row, _)| {
        let x = row as f64;
        // The label is matched per row, so an empty grid never reaches the panic.
        match label {
            "periodic_d1" => 0.07 + 0.043 * x + 0.004 * (1.3 * x).sin(),
            "euclidean_d1" => -0.46 + 0.048 * x + 0.006 * (1.7 * x).cos(),
            other => panic!("unknown SAE FD case label {other}"),
        }
    })
}
/// Builds the single-atom, 1-D latent SAE fixture for the named FD case and
/// returns it with its target matrix and regularisation parameters.
///
/// Supported labels are `"periodic_d1"` and `"euclidean_d1"`; any other label
/// panics. The fixture uses 20 observations and 3 output columns.
pub(crate) fn sae_fd_term(label: &str) -> (SaeManifoldTerm, Array2<f64>, SaeManifoldRho) {
    let p_out = 3usize;
    let n_obs = 20usize;
    let coords = sae_fd_coords(label, n_obs);

    let (basis_kind, phi, jet, n_basis, atom) = match label {
        "periodic_d1" => {
            let evaluator = Arc::new(PeriodicHarmonicEvaluator::new(3).unwrap());
            let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();
            let n_basis = phi.ncols();
            let built = SaeManifoldAtom::new(
                "periodic_d1",
                SaeAtomBasisKind::Periodic,
                1,
                phi.clone(),
                jet.clone(),
                sae_fd_decoder(n_basis, p_out),
                Array2::<f64>::eye(n_basis),
            )
            .unwrap();
            let atom = built.with_basis_second_jet(evaluator);
            (SaeAtomBasisKind::Periodic, phi, jet, n_basis, atom)
        }
        "euclidean_d1" => {
            let evaluator = Arc::new(EuclideanPatchEvaluator::new(1, 2).unwrap());
            let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();
            let n_basis = phi.ncols();
            let built = SaeManifoldAtom::new(
                "euclidean_d1",
                SaeAtomBasisKind::EuclideanPatch,
                1,
                phi.clone(),
                jet.clone(),
                sae_fd_decoder(n_basis, p_out),
                Array2::<f64>::eye(n_basis),
            )
            .unwrap();
            let atom = built.with_basis_second_jet(evaluator);
            (SaeAtomBasisKind::EuclideanPatch, phi, jet, n_basis, atom)
        }
        other => panic!("unknown SAE FD case label {other}"),
    };

    // Sanity checks on the fixture itself before it is handed to callers.
    assert_eq!(
        basis_kind.latent_manifold(1),
        atom.basis_kind.latent_manifold(1)
    );
    assert_eq!(phi.dim(), (n_obs, n_basis));
    assert_eq!(jet.dim(), (n_obs, n_basis, 1));

    let manifold = atom.basis_kind.latent_manifold(1);
    let logits = Array2::<f64>::zeros((n_obs, 1));
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        logits,
        vec![coords],
        vec![manifold],
        AssignmentMode::softmax(1.0),
    )
    .unwrap();
    let term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();
    let target = sae_fd_target(n_obs, p_out);
    let rho = SaeManifoldRho::new(0.0, 1.0e-4_f64.ln(), vec![array![-30.0]]);
    (term, target, rho)
}
/// Re-evaluates the basis of atom 0 at the term's current atom-0 coordinates.
///
/// Panics if the refresh fails.
pub(crate) fn sae_fd_refresh(term: &mut SaeManifoldTerm) {
    let latent = term.assignment.coords[0].as_matrix();
    let first_atom = &mut term.atoms[0];
    first_atom.refresh_basis(latent.view()).unwrap();
}
/// Overwrites entry `row` of atom 0's flat coordinate vector with `value`, then
/// refreshes atom 0's basis so it matches the new coordinates.
pub(crate) fn sae_fd_set_coord(term: &mut SaeManifoldTerm, row: usize, value: f64) {
    let mut updated = term.assignment.coords[0].as_flat().clone();
    updated[row] = value;
    term.assignment.coords[0].set_flat(updated.view());
    sae_fd_refresh(term);
}
/// Evaluates the term's loss against `target` under `rho` and returns its
/// total. Panics if the loss evaluation fails.
pub(crate) fn sae_fd_total_loss(
    term: &SaeManifoldTerm,
    target: &Array2<f64>,
    rho: &SaeManifoldRho,
) -> f64 {
    let breakdown = term.loss(target.view(), rho).unwrap();
    breakdown.total()
}
/// Compares the assembled (unpenalized) gradients of the named FD fixture with
/// central finite differences of the total loss, and reports the worst
/// mismatch for the coordinate block and the decoder block.
pub(crate) fn sae_fd_check_case(label: &str) -> SaeFdBlockReport {
    let epsilon = 1.0e-6;
    let (term, target, rho) = sae_fd_term(label);
    let base_loss = sae_fd_total_loss(&term, &target, &rho);
    assert!(base_loss.is_finite(), "{label}: base loss is not finite");

    // Assemble on a refreshed copy so `term` stays the pristine base point.
    let mut assembled = term.clone();
    sae_fd_refresh(&mut assembled);
    let sys = assembled
        .assemble_arrow_schur(target.view(), &rho, None)
        .unwrap();
    assert_eq!(sys.rows.len(), term.n_obs());
    assert_eq!(sys.gb.len(), term.beta_dim());
    for row in 0..term.n_obs() {
        assert_eq!(
            sys.rows[row].gt.len(),
            1,
            "{label}: K=1 softmax d=1 should expose exactly one row coordinate gradient"
        );
    }

    let central = |upper: f64, lower: f64| (upper - lower) / (2.0 * epsilon);

    // Coordinate block: perturb one row's coordinate at a time.
    let mut coord = SaeFdWorst::new();
    let base_coords = term.assignment.coords[0].as_flat().clone();
    for row in 0..term.n_obs() {
        let loss_with_coord = |value: f64| {
            let mut probe = term.clone();
            sae_fd_set_coord(&mut probe, row, value);
            sae_fd_total_loss(&probe, &target, &rho)
        };
        let upper = loss_with_coord(base_coords[row] + epsilon);
        let lower = loss_with_coord(base_coords[row] - epsilon);
        coord.observe(row, sys.rows[row].gt[0], central(upper, lower));
    }

    // Decoder block: perturb one flattened β entry at a time.
    let mut decoder = SaeFdWorst::new();
    let beta = term.flatten_beta();
    for beta_idx in 0..beta.len() {
        let mut beta_up = beta.clone();
        beta_up[beta_idx] += epsilon;
        let mut probe_up = term.clone();
        probe_up.set_flat_beta(beta_up.view()).unwrap();
        sae_fd_refresh(&mut probe_up);
        let upper = sae_fd_total_loss(&probe_up, &target, &rho);

        let mut beta_down = beta.clone();
        beta_down[beta_idx] -= epsilon;
        let mut probe_down = term.clone();
        probe_down.set_flat_beta(beta_down.view()).unwrap();
        sae_fd_refresh(&mut probe_down);
        let lower = sae_fd_total_loss(&probe_down, &target, &rho);

        decoder.observe(beta_idx, sys.gb[beta_idx], central(upper, lower));
    }

    SaeFdBlockReport {
        label: label.to_string(),
        base_loss,
        coord,
        decoder,
    }
}
/// Selects which single-atom fixture `sae_pen_term` builds.
#[derive(Clone, Copy)]
pub(crate) enum SaePenCaseKind {
/// Euclidean-patch atom with a 1-D latent coordinate.
EuclideanD1,
/// Periodic-harmonic atom with a 1-D latent coordinate.
PeriodicD1,
/// Euclidean-patch atom with a 2-D latent coordinate.
EuclideanD2,
}
/// Selects which analytic penalty `sae_pen_registry` registers.
#[derive(Clone, Copy)]
pub(crate) enum SaePenKind {
/// `IsometryPenalty` on the coordinate slice.
Isometry,
/// `ARDPenalty` on the coordinate slice.
Ard,
/// `ScadMcpPenalty` (MCP concavity) on the coordinate slice.
ScadMcp,
/// `NuclearNormPenalty` over the flattened decoder range.
NuclearNorm,
/// `DecoderIncoherencePenalty` over the flattened decoder range (two atoms).
DecoderIncoherence,
}
/// Builds a single-atom SAE fixture for penalized finite-difference checks.
///
/// Returns the term, its target matrix, the regularisation parameters, and a
/// `PsiSlice` covering all `n_obs * latent_dim` latent coordinates. The
/// fixture uses 12 observations and 3 output columns.
pub(crate) fn sae_pen_term(
    kind: SaePenCaseKind,
) -> (SaeManifoldTerm, Array2<f64>, SaeManifoldRho, PsiSlice) {
    let p_out = 3usize;
    let n_obs = 12usize;

    let (coords, latent_dim, atom): (Array2<f64>, usize, SaeManifoldAtom) = match kind {
        SaePenCaseKind::PeriodicD1 => {
            let coords = Array2::<f64>::from_shape_fn((n_obs, 1), |(row, _)| {
                let x = row as f64;
                0.11 + 0.037 * x + 0.004 * (1.3 * x).sin()
            });
            let evaluator = Arc::new(PeriodicHarmonicEvaluator::new(3).unwrap());
            let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();
            let n_basis = phi.ncols();
            let built = SaeManifoldAtom::new(
                "periodic_d1",
                SaeAtomBasisKind::Periodic,
                1,
                phi,
                jet,
                sae_fd_decoder(n_basis, p_out),
                Array2::<f64>::eye(n_basis),
            )
            .unwrap();
            (coords, 1, built.with_basis_second_jet(evaluator))
        }
        SaePenCaseKind::EuclideanD1 => {
            let coords = Array2::<f64>::from_shape_fn((n_obs, 1), |(row, _)| {
                let x = row as f64;
                -0.41 + 0.052 * x + 0.006 * (1.7 * x).cos()
            });
            let evaluator = Arc::new(EuclideanPatchEvaluator::new(1, 2).unwrap());
            let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();
            let n_basis = phi.ncols();
            let built = SaeManifoldAtom::new(
                "euclidean_d1",
                SaeAtomBasisKind::EuclideanPatch,
                1,
                phi,
                jet,
                sae_fd_decoder(n_basis, p_out),
                Array2::<f64>::eye(n_basis),
            )
            .unwrap();
            (coords, 1, built.with_basis_second_jet(evaluator))
        }
        SaePenCaseKind::EuclideanD2 => {
            let coords = Array2::<f64>::from_shape_fn((n_obs, 2), |(row, axis)| {
                let x = row as f64;
                if axis == 0 {
                    -0.33 + 0.041 * x + 0.005 * (1.1 * x).cos()
                } else {
                    0.27 - 0.036 * x + 0.004 * (0.9 * x).sin()
                }
            });
            let evaluator = Arc::new(EuclideanPatchEvaluator::new(2, 2).unwrap());
            let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();
            let n_basis = phi.ncols();
            let built = SaeManifoldAtom::new(
                "euclidean_d2",
                SaeAtomBasisKind::EuclideanPatch,
                2,
                phi,
                jet,
                sae_fd_decoder(n_basis, p_out),
                Array2::<f64>::eye(n_basis),
            )
            .unwrap();
            (coords, 2, built.with_basis_second_jet(evaluator))
        }
    };

    let manifold = atom.basis_kind.latent_manifold(latent_dim);
    let logits = Array2::<f64>::zeros((n_obs, 1));
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        logits,
        vec![coords],
        vec![manifold],
        AssignmentMode::softmax(1.0),
    )
    .unwrap();
    let term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();
    let target = sae_fd_target(n_obs, p_out);
    let rho = SaeManifoldRho::new(
        0.0,
        1.0e-4_f64.ln(),
        vec![Array1::from_elem(latent_dim, -30.0_f64)],
    );
    // Slice spanning every latent coordinate of the single atom.
    let slice = PsiSlice {
        range: 0..n_obs * latent_dim,
        latent_dim: Some(latent_dim),
    };
    (term, target, rho, slice)
}
/// Builds a two-atom (K = 2) SAE fixture of 1-D Euclidean-patch atoms with
/// distinct coordinates and decoders, plus its target and regularisation
/// parameters. The fixture uses 12 observations and 3 output columns.
pub(crate) fn sae_pen_term_k2() -> (SaeManifoldTerm, Array2<f64>, SaeManifoldRho) {
    let p_out = 3usize;
    let n_obs = 12usize;
    let mut atoms = Vec::with_capacity(2);
    let mut coord_blocks = Vec::with_capacity(2);

    for atom_idx in 0..2usize {
        let is_second = atom_idx == 1;
        let coords = Array2::<f64>::from_shape_fn((n_obs, 1), |(row, _)| {
            let x = row as f64;
            if is_second {
                0.18 + 0.039 * x + 0.005 * (1.1 * x).sin()
            } else {
                -0.41 + 0.052 * x + 0.006 * (1.7 * x).cos()
            }
        });
        let evaluator = Arc::new(EuclideanPatchEvaluator::new(1, 2).unwrap());
        let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();
        let n_basis = phi.ncols();

        // The second atom's decoder is offset so the two atoms differ.
        let mut decoder = sae_fd_decoder(n_basis, p_out);
        if is_second {
            for ((basis, out_col), entry) in decoder.indexed_iter_mut() {
                *entry += 0.07 * ((basis + out_col) as f64 + 1.0).cos();
            }
        }

        let built = SaeManifoldAtom::new(
            "euclidean_d1",
            SaeAtomBasisKind::EuclideanPatch,
            1,
            phi,
            jet,
            decoder,
            Array2::<f64>::eye(n_basis),
        )
        .unwrap();
        atoms.push(built.with_basis_second_jet(evaluator));
        coord_blocks.push(coords);
    }

    let manifold = LatentManifold::Euclidean;
    let logits = Array2::<f64>::from_elem((n_obs, 2), 0.2);
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        logits,
        coord_blocks,
        vec![manifold.clone(), manifold],
        AssignmentMode::softmax(1.0),
    )
    .unwrap();
    let term = SaeManifoldTerm::new(atoms, assignment).unwrap();
    let target = sae_fd_target(n_obs, p_out);
    let log_ard = vec![
        Array1::from_elem(1, -30.0_f64),
        Array1::from_elem(1, -30.0_f64),
    ];
    let rho = SaeManifoldRho::new(0.0, 1.0e-4_f64.ln(), log_ard);
    (term, target, rho)
}
/// Builds a registry holding exactly one analytic penalty of the requested
/// kind.
///
/// Coordinate penalties (`Isometry`, `Ard`, `ScadMcp`) use `coord_slice`;
/// decoder penalties (`NuclearNorm`, `DecoderIncoherence`) build their own
/// slice over `0..beta_len`. `DecoderIncoherence` splits the decoder into two
/// equally sized atoms.
pub(crate) fn sae_pen_registry(
    pen: SaePenKind,
    coord_slice: &PsiSlice,
    n_obs: usize,
    latent_dim: usize,
    beta_len: usize,
    p_out: usize,
) -> AnalyticPenaltyRegistry {
    use crate::terms::analytic_penalties::PenaltyConcavity;
    use crate::terms::analytic_penalties::ScadMcpPenalty;

    let mut registry = AnalyticPenaltyRegistry::new();
    let entry = match pen {
        SaePenKind::Isometry => AnalyticPenaltyKind::Isometry(Arc::new(
            IsometryPenalty::new_euclidean(coord_slice.clone(), latent_dim),
        )),
        SaePenKind::Ard => {
            AnalyticPenaltyKind::Ard(Arc::new(ARDPenalty::new(coord_slice.clone(), latent_dim)))
        }
        SaePenKind::ScadMcp => {
            let scad_mcp = ScadMcpPenalty::new(
                coord_slice.clone(),
                0.5,
                n_obs,
                3.0,
                1.0e-4,
                PenaltyConcavity::Mcp,
                false,
            )
            .unwrap();
            AnalyticPenaltyKind::ScadMcp(Arc::new(scad_mcp))
        }
        SaePenKind::NuclearNorm => {
            let decoder_slice = PsiSlice {
                range: 0..beta_len,
                latent_dim: Some(beta_len / p_out),
            };
            let nuclear =
                NuclearNormPenalty::new(decoder_slice, 0.7, p_out, 1.0e-4, None, false).unwrap();
            AnalyticPenaltyKind::NuclearNorm(Arc::new(nuclear))
        }
        SaePenKind::DecoderIncoherence => {
            // Basis rows per atom, assuming two atoms of equal size.
            let m_per = beta_len / (2 * p_out);
            let decoder_slice = PsiSlice {
                range: 0..beta_len,
                latent_dim: Some(beta_len / p_out),
            };
            let incoherence = DecoderIncoherencePenalty::new(
                decoder_slice,
                vec![m_per, m_per],
                p_out,
                Array2::<f64>::from_elem((2, 2), 0.5),
                0.6,
                false,
            )
            .unwrap();
            AnalyticPenaltyKind::DecoderIncoherence(Arc::new(incoherence))
        }
    };
    registry.push(entry);
    registry
}
/// Compares the gradients assembled with `registry` against central finite
/// differences of `penalized_objective_total`, and reports the worst mismatch
/// for the coordinate block and the decoder (β) block.
pub(crate) fn sae_pen_fd_check(
    label: &str,
    term: &SaeManifoldTerm,
    target: &Array2<f64>,
    rho: &SaeManifoldRho,
    registry: &AnalyticPenaltyRegistry,
) -> SaeFdBlockReport {
    let epsilon = 1.0e-6;
    // Penalized objective of a candidate term under the shared settings.
    let objective = |candidate: &SaeManifoldTerm| {
        candidate
            .penalized_objective_total(target.view(), rho, Some(registry), 1.0)
            .unwrap()
    };
    let central = |upper: f64, lower: f64| (upper - lower) / (2.0 * epsilon);

    let base_obj = objective(term);
    assert!(base_obj.is_finite(), "{label}: base objective not finite");
    let mut assembled = term.clone();
    let sys = assembled
        .assemble_arrow_schur(target.view(), rho, Some(registry))
        .unwrap();

    // Coordinate block: every (atom, row, axis) coordinate in turn.
    let mut coord = SaeFdWorst::new();
    let coord_offsets = term.assignment.coord_offsets();
    for atom_idx in 0..term.k_atoms() {
        let off = coord_offsets[atom_idx];
        let d = term.assignment.coords[atom_idx].latent_dim();
        let base_flat = term.assignment.coords[atom_idx].as_flat().clone();
        let n_atom = base_flat.len() / d;
        for row in 0..n_atom {
            for axis in 0..d {
                let lin = row * d + axis;
                // Objective after shifting one coordinate and refreshing the basis.
                let shifted_objective = |shift: f64| {
                    let mut probe = term.clone();
                    let mut flat = base_flat.clone();
                    flat[lin] += shift;
                    probe.assignment.coords[atom_idx].set_flat(flat.view());
                    let matrix = probe.assignment.coords[atom_idx].as_matrix();
                    probe.atoms[atom_idx].refresh_basis(matrix.view()).unwrap();
                    objective(&probe)
                };
                let upper = shifted_objective(epsilon);
                let lower = shifted_objective(-epsilon);
                coord.observe(lin, sys.rows[row].gt[off + axis], central(upper, lower));
            }
        }
    }

    // Decoder block: every flattened β entry in turn.
    let mut decoder = SaeFdWorst::new();
    let beta = term.flatten_beta();
    for beta_idx in 0..beta.len() {
        let mut beta_up = beta.clone();
        beta_up[beta_idx] += epsilon;
        let mut probe_up = term.clone();
        probe_up.set_flat_beta(beta_up.view()).unwrap();
        let upper = objective(&probe_up);

        let mut beta_down = beta.clone();
        beta_down[beta_idx] -= epsilon;
        let mut probe_down = term.clone();
        probe_down.set_flat_beta(beta_down.view()).unwrap();
        let lower = objective(&probe_down);

        decoder.observe(beta_idx, sys.gb[beta_idx], central(upper, lower));
    }

    SaeFdBlockReport {
        label: label.to_string(),
        base_loss: base_obj,
        coord,
        decoder,
    }
}
/// Runs the penalized finite-difference gradient check over a table of
/// single-atom (fixture, penalty) cases plus the two-atom
/// `DecoderIncoherence` case.
///
/// Each case's registry must first pass `validate_analytic_penalty_registry`.
/// A block passes when its worst entry is within the relative tolerance
/// (`1e-5`) or the absolute tolerance (`1e-7`).
#[test]
pub(crate) fn sae_assembled_gradient_matches_penalized_objective_central_fd() {
    let p_out = 3usize;
    let mut reports: Vec<SaeFdBlockReport> = Vec::new();
    let single_cases: &[(&str, SaePenCaseKind, SaePenKind)] = &[
        (
            "isometry_circle_d1",
            SaePenCaseKind::PeriodicD1,
            SaePenKind::Isometry,
        ),
        (
            "isometry_euclid_d2",
            SaePenCaseKind::EuclideanD2,
            SaePenKind::Isometry,
        ),
        ("ard_circle_d1", SaePenCaseKind::PeriodicD1, SaePenKind::Ard),
        (
            "scadmcp_euclid_d1",
            SaePenCaseKind::EuclideanD1,
            SaePenKind::ScadMcp,
        ),
        (
            "nuclearnorm_euclid_d1",
            SaePenCaseKind::EuclideanD1,
            SaePenKind::NuclearNorm,
        ),
    ];
    for (label, case_kind, pen_kind) in single_cases {
        let (term, target, rho, slice) = sae_pen_term(*case_kind);
        let n_obs = term.n_obs();
        let latent_dim = term.assignment.coords[0].latent_dim();
        let beta_len = term.beta_dim();
        let registry = sae_pen_registry(*pen_kind, &slice, n_obs, latent_dim, beta_len, p_out);
        term.validate_analytic_penalty_registry(&registry)
            .expect("penalty registry must validate for the SAE term");
        reports.push(sae_pen_fd_check(label, &term, &target, &rho, &registry));
    }
    {
        // Two-atom case: DecoderIncoherence needs more than one atom.
        let (term, target, rho) = sae_pen_term_k2();
        let beta_len = term.beta_dim();
        let slice = PsiSlice {
            range: 0..beta_len,
            latent_dim: Some(beta_len / p_out),
        };
        let registry = sae_pen_registry(
            SaePenKind::DecoderIncoherence,
            &slice,
            term.n_obs(),
            1,
            beta_len,
            p_out,
        );
        term.validate_analytic_penalty_registry(&registry)
            .expect("DecoderIncoherence registry must validate for the K=2 SAE term");
        reports.push(sae_pen_fd_check(
            "decoder_incoherence_k2",
            &term,
            &target,
            &rho,
            &registry,
        ));
    }

    let relative_tolerance = 1.0e-5;
    let absolute_tolerance = 1.0e-7;
    let mut all_blocks_match = true;
    for report in &reports {
        let coord_ok = report.coord.relative_error <= relative_tolerance
            || report.coord.absolute_error <= absolute_tolerance;
        let decoder_ok = report.decoder.relative_error <= relative_tolerance
            || report.decoder.absolute_error <= absolute_tolerance;
        let metadata_ok = !report.label.is_empty() && report.base_loss.is_finite();
        all_blocks_match = all_blocks_match && metadata_ok && coord_ok && decoder_ok;
    }
    // All reports are printed on failure so the offending block is visible.
    assert!(
        all_blocks_match,
        "SAE assembled gradient does not match central FD of the penalized objective: {reports:#?}"
    );
}
/// With only an isometry penalty registered, the REML extra-penalty energy
/// must equal the live isometry energy (counted once), and the analytic
/// decoder penalty energy must be zero.
#[test]
pub(crate) fn sae_reml_extra_penalty_energy_counts_live_isometry_once() {
    let p_out = 3usize;
    let (term, _target, _rho, slice) = sae_pen_term(SaePenCaseKind::PeriodicD1);
    let registry = sae_pen_registry(
        SaePenKind::Isometry,
        &slice,
        term.n_obs(),
        term.assignment.coords[0].latent_dim(),
        term.beta_dim(),
        p_out,
    );

    let isometry_energy = term
        .isometry_penalty_value_total(&registry)
        .expect("live isometry value");
    // A zero energy would make the equality below vacuous.
    assert!(
        isometry_energy > 0.0,
        "fixture must carry nonzero isometry energy"
    );

    let decoder_energy = term
        .analytic_decoder_penalty_value_total(&registry)
        .expect("decoder penalty value");
    assert_abs_diff_eq!(decoder_energy, 0.0, epsilon = 1.0e-12);

    let extra_energy = term
        .reml_extra_penalty_value_total(&registry)
        .expect("REML extra penalty value");
    assert_abs_diff_eq!(extra_energy, isometry_energy, epsilon = 1.0e-12);
}
/// Runs the unpenalized finite-difference gradient check for both 1-D
/// fixtures. A block passes when its worst entry is within the relative
/// (`3e-5`) or absolute (`3e-7`) tolerance.
#[test]
pub(crate) fn sae_d1_assembled_gradient_matches_loss_central_fd() {
    let reports: Vec<SaeFdBlockReport> = ["euclidean_d1", "periodic_d1"]
        .iter()
        .copied()
        .map(sae_fd_check_case)
        .collect();

    let relative_tolerance = 3.0e-5;
    let absolute_tolerance = 3.0e-7;
    let within_tolerance = |worst: &SaeFdWorst| {
        worst.relative_error <= relative_tolerance || worst.absolute_error <= absolute_tolerance
    };
    let all_blocks_match = reports.iter().all(|report| {
        let metadata_ok = !report.label.is_empty() && report.base_loss.is_finite();
        metadata_ok && within_tolerance(&report.coord) && within_tolerance(&report.decoder)
    });
    assert!(
        all_blocks_match,
        "SAE d=1 assembled gradient does not match central finite difference: {reports:#?}"
    );
}
/// Asserts that the first jet returned by `evaluator.evaluate` matches a
/// central finite difference of the basis values, for every row, basis
/// function and latent axis, to within `tolerance`.
///
/// Each coordinate is perturbed on its own by `±1e-6`. Also checks that the
/// jet has shape `(rows, basis, latent_dim)` at the base and perturbed points.
pub(crate) fn assert_jacobian_matches_central_difference<E: SaeBasisEvaluator>(
    evaluator: &E,
    coords: Array2<f64>,
    tolerance: f64,
) {
    let epsilon = 1.0e-6;
    let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();
    let (n_rows, n_basis) = phi.dim();
    let latent_dim = coords.ncols();
    assert_eq!(jet.dim(), (n_rows, n_basis, latent_dim));

    let perturbations = (0..n_rows).flat_map(|r| (0..latent_dim).map(move |a| (r, a)));
    for (row, axis) in perturbations {
        let mut shifted_up = coords.clone();
        let mut shifted_down = coords.clone();
        shifted_up[[row, axis]] += epsilon;
        shifted_down[[row, axis]] -= epsilon;
        let (phi_up, jet_up) = evaluator.evaluate(shifted_up.view()).unwrap();
        let (phi_down, jet_down) = evaluator.evaluate(shifted_down.view()).unwrap();
        assert_eq!(jet_up.dim(), jet.dim());
        assert_eq!(jet_down.dim(), jet.dim());

        for basis in 0..n_basis {
            let analytic = jet[[row, basis, axis]];
            let finite_difference =
                (phi_up[[row, basis]] - phi_down[[row, basis]]) / (2.0 * epsilon);
            let error = (analytic - finite_difference).abs();
            assert!(
                error <= tolerance,
                "row={row} basis={basis} axis={axis}: analytic={analytic:.12e}, \
                 finite_difference={finite_difference:.12e}, error={error:.12e}, \
                 tolerance={tolerance:.12e}"
            );
        }
    }
}
/// Finite-difference Jacobian checks for the periodic, raw periodic-circle,
/// sphere-chart, affine and torus basis evaluators, plus shape and closed-form
/// spot checks for the sphere and torus bases.
#[test]
pub(crate) fn sae_basis_evaluator_jacobians_match_central_differences() {
    let fd_tolerance = 1.0e-6;

    assert_jacobian_matches_central_difference(
        &PeriodicHarmonicEvaluator::new(7).unwrap(),
        array![[-0.37], [0.0], [0.125], [0.41]],
        fd_tolerance,
    );
    assert_jacobian_matches_central_difference(
        &RawPeriodicCircleEvaluator::new(3).unwrap(),
        array![[-1.2, 0.3, 2.0], [0.0, -0.4, 0.8], [2.4, 1.1, -0.7]],
        fd_tolerance,
    );

    // Sphere chart: generic FD check, then shapes and selected longitude derivatives.
    let sphere_coords = array![[-0.7, -1.2], [-0.25, 0.0], [0.35, 0.9], [0.8, 2.1]];
    assert_jacobian_matches_central_difference(
        &SphereChartEvaluator,
        sphere_coords.clone(),
        fd_tolerance,
    );
    let n_sphere = sphere_coords.nrows();
    let (sphere_phi, sphere_jet) = SphereChartEvaluator.evaluate(sphere_coords.view()).unwrap();
    assert_eq!(sphere_phi.dim(), (n_sphere, 7));
    assert_eq!(sphere_jet.dim(), (n_sphere, 7, 2));
    for row in 0..n_sphere {
        let (sin_lat, cos_lat) = sphere_coords[[row, 0]].sin_cos();
        let (sin_lon, cos_lon) = sphere_coords[[row, 1]].sin_cos();
        let z = sin_lat;
        let dx_dlon = -cos_lat * sin_lon;
        let dy_dlon = cos_lat * cos_lon;
        assert_eq!(sphere_jet[[row, 3, 1]], 0.0);
        assert!((sphere_jet[[row, 5, 1]] - dy_dlon * z).abs() <= 1.0e-12);
        assert!((sphere_jet[[row, 6, 1]] - dx_dlon * z).abs() <= 1.0e-12);
    }

    assert_jacobian_matches_central_difference(
        &AffineCoordinateEvaluator::new(3),
        array![[0.0, -1.0, 2.0], [3.5, 0.25, -0.75]],
        fd_tolerance,
    );

    // Torus: generic FD check, then shapes and the constant first basis column.
    let torus_coords = array![[0.1, 0.7], [0.42, 0.0], [0.95, 0.33], [0.5, 0.5]];
    assert_jacobian_matches_central_difference(
        &TorusHarmonicEvaluator::new(2, 3).unwrap(),
        torus_coords.clone(),
        fd_tolerance,
    );
    let torus_evaluator = TorusHarmonicEvaluator::new(2, 3).unwrap();
    let (torus_phi, torus_jet) = torus_evaluator.evaluate(torus_coords.view()).unwrap();
    let n_torus = torus_coords.nrows();
    assert_eq!(torus_phi.dim(), (n_torus, 49));
    assert_eq!(torus_jet.dim(), (n_torus, 49, 2));
    for row in 0..n_torus {
        assert!((torus_phi[[row, 0]] - 1.0).abs() <= 1.0e-12);
        for axis in 0..2 {
            assert!(torus_jet[[row, 0, axis]].abs() <= 1.0e-12);
        }
    }
}
/// Checks the projection seed grids: the periodic grid is `i / 16` on
/// `[0, 1)`, sphere seeds stay strictly inside the latitude chart with
/// longitude in `[-π, π)`, and the Euclidean patch exposes no grid.
#[test]
pub(crate) fn projection_seed_grid_spans_each_compact_manifold() {
    use std::f64::consts::PI;

    let periodic = SaeAtomBasisKind::Periodic
        .projection_seed_grid(1, 16)
        .unwrap();
    assert_eq!(periodic.dim(), (16, 1));
    for (i, &phase) in periodic.column(0).iter().enumerate() {
        assert_abs_diff_eq!(phase, i as f64 / 16.0, epsilon = 1e-12);
    }
    assert!(periodic.iter().all(|&t| (0.0..1.0).contains(&t)));

    let r = 6usize;
    let n_seeds = r * r;
    let sphere = SaeAtomBasisKind::Sphere.projection_seed_grid(2, r).unwrap();
    assert_eq!(sphere.dim(), (n_seeds, 2));
    let half_pi = PI / 2.0;
    for row in 0..n_seeds {
        let (lat, lon) = (sphere[[row, 0]], sphere[[row, 1]]);
        assert!(
            lat > -half_pi && lat < half_pi,
            "sphere seed latitude {lat} is not strictly interior to the chart"
        );
        assert!(
            (-PI..PI).contains(&lon),
            "sphere seed longitude {lon} is outside [-π, π)"
        );
    }

    let euclidean_grid = SaeAtomBasisKind::EuclideanPatch.projection_seed_grid(2, 64);
    assert!(
        euclidean_grid.is_none(),
        "Euclidean-patch (unbounded) atoms must not expose a projection seed grid"
    );
}
/// Checks the torus seed grid sizes for a requested resolution of 256:
/// d=1 keeps 256 points, d=3 has 16 phases per axis (4096 points), d=12 has
/// 2 per axis (4096 points), and d=13 returns `None`.
#[test]
pub(crate) fn torus_projection_seed_grid_caps_total_points() {
    let torus_grid =
        |latent_dim: usize| SaeAtomBasisKind::Torus.projection_seed_grid(latent_dim, 256);

    let one_dim = torus_grid(1).unwrap();
    assert_eq!(one_dim.dim(), (256, 1));

    let three_dim = torus_grid(3).unwrap();
    assert_eq!(three_dim.ncols(), 3);
    assert_eq!(three_dim.nrows(), 16 * 16 * 16);
    assert!(
        three_dim.nrows() <= 4096,
        "torus d=3 seed grid has {} points, over the 4096 cap",
        three_dim.nrows()
    );
    assert!(
        three_dim.iter().all(|&t| (0.0..1.0).contains(&t)),
        "every torus seed coordinate must be a phase on [0, 1)"
    );
    for axis in 0..3 {
        // Count the distinct phases along this axis.
        let mut phases: Vec<f64> = three_dim.column(axis).iter().copied().collect();
        phases.sort_by(|a, b| a.partial_cmp(b).unwrap());
        phases.dedup();
        assert_eq!(
            phases.len(),
            16,
            "torus seed axis {axis} should take 16 distinct phases"
        );
    }

    let twelve_dim = torus_grid(12).unwrap();
    assert_eq!(twelve_dim.nrows(), 1usize << 12);
    assert!(twelve_dim.nrows() <= 4096);

    let thirteen_dim = torus_grid(13);
    assert!(
        thirteen_dim.is_none(),
        "torus d=13 seed grid (2^13 > 4096) must fall back to None, not blow up the cap"
    );
}
/// Seeds latent coordinates by decoder projection on an 8-point periodic grid
/// and checks that each row lands on the grid phase used to build its target,
/// and that the atom's stored basis values match that phase.
#[test]
pub(crate) fn seed_coords_by_decoder_projection_lands_on_grid_minimiser() {
    use std::f64::consts::PI;

    let resolution = 8usize;
    let init_coords = array![[0.05], [0.05]];
    let evaluator = Arc::new(PeriodicHarmonicEvaluator::new(3).unwrap());
    let (phi0, jet0) = evaluator.evaluate(init_coords.view()).unwrap();
    let decoder = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]];
    let atom = SaeManifoldAtom::new(
        "periodic",
        SaeAtomBasisKind::Periodic,
        1,
        phi0,
        jet0,
        decoder,
        Array2::<f64>::eye(3),
    )
    .unwrap()
    .with_basis_evaluator(evaluator.clone());
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((2, 1)),
        vec![init_coords],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::softmax(1.0),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();

    // Targets are (sin, cos) of the grid phases 3/8 and 6/8.
    let phases = [3usize, 6usize];
    let phase_of = |k: usize| k as f64 / resolution as f64;
    let mut target = Array2::<f64>::zeros((2, 2));
    for (row, &k) in phases.iter().enumerate() {
        let angle = 2.0 * PI * phase_of(k);
        target[[row, 0]] = angle.sin();
        target[[row, 1]] = angle.cos();
    }

    term.seed_coords_by_decoder_projection(target.view(), resolution)
        .unwrap();

    let seeded = term.assignment.coords[0].as_matrix();
    let mut expected_coords = Array2::<f64>::zeros((2, 1));
    for (row, &k) in phases.iter().enumerate() {
        let expected = phase_of(k);
        assert_abs_diff_eq!(seeded[[row, 0]], expected, epsilon = 1e-12);
        expected_coords[[row, 0]] = expected;
    }

    // The stored basis values must agree with the basis at the seeded phases.
    let (phi_expected, _) = evaluator.evaluate(expected_coords.view()).unwrap();
    let basis_mismatch = (&term.atoms[0].basis_values - &phi_expected)
        .mapv(f64::abs)
        .sum();
    assert_abs_diff_eq!(basis_mismatch, 0.0, epsilon = 1e-12);
}
/// Passing a target whose column count (3) disagrees with the decoder's output
/// width (2) must produce an error mentioning "target shape".
#[test]
pub(crate) fn seed_coords_by_decoder_projection_rejects_shape_mismatch() {
    let evaluator = Arc::new(PeriodicHarmonicEvaluator::new(3).unwrap());
    let start_coords = array![[0.05], [0.05]];
    let (start_phi, start_jet) = evaluator.evaluate(start_coords.view()).unwrap();
    let decoder = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]];
    let atom = SaeManifoldAtom::new(
        "periodic",
        SaeAtomBasisKind::Periodic,
        1,
        start_phi,
        start_jet,
        decoder,
        Array2::<f64>::eye(3),
    )
    .unwrap()
    .with_basis_evaluator(evaluator);
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((2, 1)),
        vec![start_coords],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::softmax(1.0),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();

    // Two rows but three columns: one column too many.
    let wrong_width_target = Array2::<f64>::zeros((2, 3));
    let outcome = term.seed_coords_by_decoder_projection(wrong_width_target.view(), 8);
    let err = outcome.unwrap_err();
    assert!(
        err.contains("target shape"),
        "expected a target-shape error, got: {err}"
    );
}
/// Verifies that `sphere_chart_basis_jet` and `SphereChartEvaluator` return
/// identical results, and that both match the closed-form basis
/// `[1, x, y, z, xy, yz, xz]` with `x = cos(lat)cos(lon)`,
/// `y = cos(lat)sin(lon)`, `z = sin(lat)` and its first derivatives.
#[test]
pub(crate) fn sphere_chart_basis_jet_is_single_source_of_truth() {
    use std::f64::consts::FRAC_PI_2;

    let tol = 1.0e-12;
    let coords = array![
        [-1.2, -2.4],
        [0.35, 0.9],
        [FRAC_PI_2, 0.4],
        [-FRAC_PI_2, -1.1],
        [2.3, 0.7],
        [-3.0, 1.9],
    ];
    let (engine_phi, engine_jet) = sphere_chart_basis_jet(coords.view()).unwrap();
    let (adapter_phi, adapter_jet) = SphereChartEvaluator.evaluate(coords.view()).unwrap();
    assert_eq!(engine_phi, adapter_phi);
    assert_eq!(engine_jet, adapter_jet);

    for row in 0..coords.nrows() {
        let (lat, lon) = (coords[[row, 0]], coords[[row, 1]]);
        let (slat, clat) = (lat.sin(), lat.cos());
        let (slon, clon) = (lon.sin(), lon.cos());
        let (x, y, z) = (clat * clon, clat * slon, slat);

        // Basis values, column by column.
        let expected_phi = [1.0, x, y, z, x * y, y * z, x * z];
        for (col, &want) in expected_phi.iter().enumerate() {
            assert!((engine_phi[[row, col]] - want).abs() <= tol);
        }

        // Longitude derivatives (jet axis 1); z does not depend on lon, so its
        // derivative is required to be exactly zero.
        let dx_dlon = -clat * slon;
        let dy_dlon = clat * clon;
        for (col, want) in [(1, dx_dlon), (2, dy_dlon)] {
            assert!((engine_jet[[row, col, 1]] - want).abs() <= tol);
        }
        assert_eq!(engine_jet[[row, 3, 1]], 0.0);
        for (col, want) in [
            (4, dx_dlon * y + x * dy_dlon),
            (5, dy_dlon * z),
            (6, dx_dlon * z),
        ] {
            assert!((engine_jet[[row, col, 1]] - want).abs() <= tol);
        }

        // Latitude derivatives (jet axis 0).
        let dx_dlat = -slat * clon;
        let dy_dlat = -slat * slon;
        let dz_dlat = clat;
        for (col, want) in [
            (1, dx_dlat),
            (2, dy_dlat),
            (3, dz_dlat),
            (4, dx_dlat * y + x * dy_dlat),
            (5, dy_dlat * z + y * dz_dlat),
            (6, dx_dlat * z + x * dz_dlat),
        ] {
            assert!((engine_jet[[row, col, 0]] - want).abs() <= tol);
        }
    }

    assert_eq!(
        SPHERE_CHART_PENALTY_DIAGONAL,
        [1e-8, 1.0, 1.0, 1.0, 4.0, 4.0, 4.0]
    );
}
/// Compares the analytic sphere-chart jet against central finite differences
/// of the basis at latitudes on and beyond `±π/2`, then checks that the basis
/// values just below and just above `lat = π/2` agree to 1e-6.
#[test]
pub(crate) fn sphere_chart_jet_matches_fd_at_clamp_boundary() {
    use std::f64::consts::FRAC_PI_2;

    let coords = array![
        [FRAC_PI_2, 0.4],
        [-FRAC_PI_2, -1.1],
        [1.45, 2.0],
        [1.69, -0.3],
        [2.3, 0.7],
        [0.35, 0.9],
    ];
    let (_, jet) = sphere_chart_basis_jet(coords.view()).unwrap();

    // Central difference in each latent axis, one row at a time.
    let h = 1.0e-6;
    for row in 0..coords.nrows() {
        for axis in 0..2 {
            let mut forward = coords.clone();
            forward[[row, axis]] += h;
            let mut backward = coords.clone();
            backward[[row, axis]] -= h;
            let (phi_forward, _) = sphere_chart_basis_jet(forward.view()).unwrap();
            let (phi_backward, _) = sphere_chart_basis_jet(backward.view()).unwrap();
            for col in 0..7 {
                let an = jet[[row, col, axis]];
                let fd = (phi_forward[[row, col]] - phi_backward[[row, col]]) / (2.0 * h);
                assert!(
                    (fd - an).abs() <= 1.0e-7,
                    "row {row} col {col} axis {axis}: analytic {an} vs FD {fd}"
                );
            }
        }
    }

    // Continuity of the basis across lat = π/2.
    let eps = 1.0e-8;
    let lon = 0.4;
    let just_below = array![[FRAC_PI_2 - eps, lon]];
    let just_above = array![[FRAC_PI_2 + eps, lon]];
    let (phi_below, _) = sphere_chart_basis_jet(just_below.view()).unwrap();
    let (phi_above, _) = sphere_chart_basis_jet(just_above.view()).unwrap();
    for col in 0..7 {
        let (lower, upper) = (phi_below[[0, col]], phi_above[[0, col]]);
        assert!(
            (lower - upper).abs() <= 1.0e-6,
            "basis discontinuous across lat = π/2 at col {col}: \
             {} vs {}",
            lower,
            upper
        );
    }
}
/// Asserts that an evaluator's analytic second jet agrees with a central
/// finite difference of its first jet, and that it is symmetric in its two
/// latent axes.
///
/// * `evaluator` - basis evaluator providing `evaluate` and `second_jet`.
/// * `coords` - evaluation points, one row per observation.
/// * `abs_tol`, `rel_tol` - the per-entry threshold is
///   `abs_tol + rel_tol * max(|analytic|, |fd|)`.
///
/// # Errors
/// Returns the evaluator's error if any `evaluate` / `second_jet` call fails,
/// including the calls at the perturbed coordinates.
///
/// # Panics
/// Panics when shapes disagree, when an entry exceeds its threshold, or when
/// the second jet is asymmetric by more than 1e-12.
pub(crate) fn assert_second_jet_matches_central_difference<E: SaeBasisSecondJet>(
    evaluator: &E,
    coords: Array2<f64>,
    abs_tol: f64,
    rel_tol: f64,
) -> Result<(), String> {
    let epsilon = 1.0e-4;
    let second = evaluator.second_jet(coords.view())?;
    let (_phi, jet) = evaluator.evaluate(coords.view())?;
    let (n_rows, n_basis, latent_dim, latent_dim_b) = second.dim();
    assert_eq!(latent_dim, latent_dim_b);
    assert_eq!((n_rows, n_basis, latent_dim), jet.dim());

    // d/dt_c of the first jet, by central difference, versus second[.., a, c].
    for row in 0..n_rows {
        for axis_c in 0..latent_dim {
            let mut plus = coords.clone();
            let mut minus = coords.clone();
            plus[[row, axis_c]] += epsilon;
            minus[[row, axis_c]] -= epsilon;
            // Propagate evaluator failures like every other call in this
            // helper instead of panicking inside a Result-returning function.
            let (_, jet_plus) = evaluator.evaluate(plus.view())?;
            let (_, jet_minus) = evaluator.evaluate(minus.view())?;
            for basis in 0..n_basis {
                for axis_a in 0..latent_dim {
                    let fd = (jet_plus[[row, basis, axis_a]] - jet_minus[[row, basis, axis_a]])
                        / (2.0 * epsilon);
                    let analytic = second[[row, basis, axis_a, axis_c]];
                    let error = (analytic - fd).abs();
                    let threshold = abs_tol + rel_tol * analytic.abs().max(fd.abs());
                    assert!(
                        error <= threshold,
                        "row={row} basis={basis} axis_a={axis_a} axis_c={axis_c}: \
                         analytic={analytic:.12e}, fd={fd:.12e}, error={error:.12e}, \
                         threshold={threshold:.12e}"
                    );
                }
            }
        }
    }

    // Symmetry of the second jet in its two latent axes.
    for row in 0..n_rows {
        for basis in 0..n_basis {
            for axis_a in 0..latent_dim {
                for axis_b in 0..latent_dim {
                    let h_ab = second[[row, basis, axis_a, axis_b]];
                    let h_ba = second[[row, basis, axis_b, axis_a]];
                    assert!(
                        (h_ab - h_ba).abs() <= 1.0e-12,
                        "second_jet not symmetric: row={row} basis={basis} \
                         ({axis_a},{axis_b})={h_ab:.6e} vs ({axis_b},{axis_a})={h_ba:.6e}"
                    );
                }
            }
        }
    }
    Ok(())
}
/// Asserts that an evaluator's analytic third jet agrees with a central finite
/// difference of its second jet, and that it is invariant under every
/// permutation of its three latent axes.
///
/// The per-entry threshold is `abs_tol + rel_tol * max(|analytic|, |fd|)`;
/// the permutation check uses a fixed 1e-10 tolerance. Evaluator errors are
/// returned to the caller; mismatches panic.
pub(crate) fn assert_third_jet_matches_central_difference<E: SaeBasisThirdJet>(
    evaluator: &E,
    coords: Array2<f64>,
    abs_tol: f64,
    rel_tol: f64,
) -> Result<(), String> {
    let epsilon = 1.0e-4;
    let two_eps = 2.0 * epsilon;
    let third = evaluator.third_jet(coords.view())?;
    let second = evaluator.second_jet(coords.view())?;
    let (n_rows, n_basis, latent_dim, ld_b, ld_c) = third.dim();
    assert_eq!(latent_dim, ld_b);
    assert_eq!(latent_dim, ld_c);
    assert_eq!((n_rows, n_basis, latent_dim, latent_dim), second.dim());

    // d/dt_e of the second jet, by central difference, versus third[.., a, c, e].
    for row in 0..n_rows {
        for axis_e in 0..latent_dim {
            let mut forward = coords.clone();
            forward[[row, axis_e]] += epsilon;
            let mut backward = coords.clone();
            backward[[row, axis_e]] -= epsilon;
            let second_forward = evaluator.second_jet(forward.view())?;
            let second_backward = evaluator.second_jet(backward.view())?;
            for basis in 0..n_basis {
                for axis_a in 0..latent_dim {
                    for axis_c in 0..latent_dim {
                        let analytic = third[[row, basis, axis_a, axis_c, axis_e]];
                        let fd = (second_forward[[row, basis, axis_a, axis_c]]
                            - second_backward[[row, basis, axis_a, axis_c]])
                            / two_eps;
                        let error = (analytic - fd).abs();
                        let threshold = abs_tol + rel_tol * analytic.abs().max(fd.abs());
                        assert!(
                            error <= threshold,
                            "row={row} basis={basis} a={axis_a} c={axis_c} e={axis_e}: \
                             analytic={analytic:.12e}, fd={fd:.12e}, error={error:.6e}, \
                             threshold={threshold:.6e}"
                        );
                    }
                }
            }
        }
    }

    // Compare each entry against the five non-identity permutations of (a, b, c).
    for row in 0..n_rows {
        for basis in 0..n_basis {
            for a in 0..latent_dim {
                for b in 0..latent_dim {
                    for c in 0..latent_dim {
                        let reference = third[[row, basis, a, b, c]];
                        let shuffles = [[a, c, b], [b, a, c], [b, c, a], [c, a, b], [c, b, a]];
                        for [p, q, r] in shuffles {
                            let permuted = third[[row, basis, p, q, r]];
                            assert!(
                                (reference - permuted).abs() <= 1.0e-10,
                                "third_jet not symmetric: row={row} basis={basis} \
                                 ({a},{b},{c})={reference:.6e} vs ({},{},{})={permuted:.6e}",
                                p,
                                q,
                                r
                            );
                        }
                    }
                }
            }
        }
    }
    Ok(())
}
/// Second jet of the 7-function periodic harmonic evaluator versus finite
/// differences at four phases.
#[test]
pub(crate) fn isometry_periodic_second_jet_matches_fd() -> Result<(), String> {
    let evaluator = PeriodicHarmonicEvaluator::new(7).unwrap();
    let phases = array![[-0.37], [0.0], [0.125], [0.41]];
    assert_second_jet_matches_central_difference(&evaluator, phases, 1.0e-6, 1.0e-5)
}
/// Second jet of the sphere chart evaluator versus finite differences at four
/// (lat, lon) points.
#[test]
pub(crate) fn isometry_sphere_second_jet_matches_fd() -> Result<(), String> {
    let lat_lon = array![[-0.7, -1.2], [-0.25, 0.0], [0.35, 0.9], [0.8, 2.1]];
    assert_second_jet_matches_central_difference(&SphereChartEvaluator, lat_lon, 1.0e-6, 1.0e-5)
}
/// Second jet of `TorusHarmonicEvaluator::new(2, 3)` versus finite differences
/// at four phase pairs; the basis must be non-empty.
#[test]
pub(crate) fn isometry_torus_second_jet_matches_fd() -> Result<(), String> {
    let evaluator = TorusHarmonicEvaluator::new(2, 3).unwrap();
    assert!(evaluator.basis_size() > 0);
    let phase_pairs = array![[0.1, 0.7], [0.42, 0.0], [0.95, 0.33], [0.5, 0.5]];
    assert_second_jet_matches_central_difference(&evaluator, phase_pairs, 1.0e-6, 1.0e-5)
}
/// Third jet of the 7-function periodic harmonic evaluator versus finite
/// differences at four phases.
#[test]
pub(crate) fn isometry_periodic_third_jet_matches_fd() -> Result<(), String> {
    let evaluator = PeriodicHarmonicEvaluator::new(7).unwrap();
    let phases = array![[-0.37], [0.0], [0.125], [0.41]];
    assert_third_jet_matches_central_difference(&evaluator, phases, 1.0e-6, 1.0e-5)
}
/// Third jet of the sphere chart evaluator versus finite differences at four
/// (lat, lon) points.
#[test]
pub(crate) fn isometry_sphere_third_jet_matches_fd() -> Result<(), String> {
    let lat_lon = array![[-0.7, -1.2], [-0.25, 0.0], [0.35, 0.9], [0.8, 2.1]];
    assert_third_jet_matches_central_difference(&SphereChartEvaluator, lat_lon, 1.0e-6, 1.0e-5)
}
/// Third jet of `TorusHarmonicEvaluator::new(2, 3)` versus finite differences
/// at four phase pairs; the basis must be non-empty.
#[test]
pub(crate) fn isometry_torus_third_jet_matches_fd() -> Result<(), String> {
    let evaluator = TorusHarmonicEvaluator::new(2, 3).unwrap();
    assert!(evaluator.basis_size() > 0);
    let phase_pairs = array![[0.1, 0.7], [0.42, 0.0], [0.95, 0.33], [0.5, 0.5]];
    assert_third_jet_matches_central_difference(&evaluator, phase_pairs, 1.0e-6, 1.0e-5)
}
/// The third jet of the 3-dimensional affine coordinate evaluator must have
/// shape `(rows, 4, 3, 3, 3)` and contain only exact zeros.
#[test]
pub(crate) fn isometry_affine_third_jet_is_trivial_zero() -> Result<(), String> {
    let coords = array![[0.2, -0.3, 0.7], [1.1, 0.0, -0.4]];
    let evaluator = AffineCoordinateEvaluator { latent_dim: 3 };
    let third = evaluator.third_jet(coords.view())?;

    let expected_shape = (coords.nrows(), 4, 3, 3, 3);
    assert_eq!(third.dim(), expected_shape);

    let has_nonzero = third.iter().any(|&entry| entry != 0.0);
    assert!(
        !has_nonzero,
        "affine third jet must vanish identically, got {third:?}"
    );
    Ok(())
}
/// Third jet of `EuclideanPatchEvaluator::new(2, 4)` versus finite differences
/// at three planar points.
#[test]
pub(crate) fn isometry_euclidean_patch_third_jet_matches_fd() -> Result<(), String> {
    let patch = EuclideanPatchEvaluator::new(2, 4)?;
    let points = array![[0.2, -0.3], [0.7, 0.4], [-0.5, 0.9]];
    assert_third_jet_matches_central_difference(&patch, points, 1.0e-6, 1.0e-5)
}
/// Shared five-row, two-column coordinate fixture for the cylinder evaluator
/// tests in this file.
fn cylinder_test_coords() -> Array2<f64> {
    let fixture: Array2<f64> = array![
        [0.0, -1.3],
        [0.125, 0.0],
        [0.4, 0.7],
        [0.91, 2.2],
        [0.6, -0.45],
    ];
    fixture
}
/// Checks that the cylinder basis is the tensor product of a circle harmonic
/// basis `[1, sin(2πk t0), cos(2πk t0), ...]` and a monomial line basis
/// `[1, t1, t1^2, ...]`, laid out as `col = circle_index * line_size + line_index`.
///
/// Also pins the reported basis sizes and that column 0 is the constant 1 with
/// a zero first jet.
#[test]
pub(crate) fn cylinder_phi_is_circle_tensor_line_product() -> Result<(), String> {
    let (h, degree) = (2usize, 2usize);
    let evaluator = CylinderHarmonicEvaluator::new(h, degree)?;
    let mc = 2 * h + 1;
    let ml = degree + 1;
    assert_eq!(evaluator.circle_basis_size(), mc);
    assert_eq!(evaluator.line_basis_size(), ml);
    assert_eq!(evaluator.basis_size(), mc * ml);

    let coords = cylinder_test_coords();
    let (phi, jet) = evaluator.evaluate(coords.view())?;
    assert_eq!(phi.dim(), (coords.nrows(), mc * ml));
    assert_eq!(jet.dim(), (coords.nrows(), mc * ml, 2));

    let tau = std::f64::consts::TAU;
    for row in 0..coords.nrows() {
        let (t0, t1) = (coords[[row, 0]], coords[[row, 1]]);

        // Circle factor: constant, then (sin, cos) per harmonic.
        let mut circ = Vec::with_capacity(mc);
        circ.push(1.0_f64);
        for k in 1..=h {
            let phase = tau * k as f64 * t0;
            circ.push(phase.sin());
            circ.push(phase.cos());
        }
        // Line factor: monomials of t1.
        let line: Vec<f64> = (0..ml).map(|j| t1.powi(j as i32)).collect();

        for (c, &circ_val) in circ.iter().enumerate() {
            for (l, &line_val) in line.iter().enumerate() {
                let product = circ_val * line_val;
                assert_abs_diff_eq!(phi[[row, c * ml + l]], product, epsilon = 1e-12);
            }
        }

        assert_abs_diff_eq!(phi[[row, 0]], 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(jet[[row, 0, 0]], 0.0, epsilon = 1e-12);
        assert_abs_diff_eq!(jet[[row, 0, 1]], 0.0, epsilon = 1e-12);
    }
    Ok(())
}
/// First jet of `CylinderHarmonicEvaluator::new(3, 3)` versus central finite
/// differences on the shared cylinder fixture.
#[test]
pub(crate) fn cylinder_jacobian_matches_central_difference() {
    let cylinder = CylinderHarmonicEvaluator::new(3, 3).unwrap();
    let fixture = cylinder_test_coords();
    assert_jacobian_matches_central_difference(&cylinder, fixture, 1.0e-6);
}
/// Second jet of `CylinderHarmonicEvaluator::new(3, 3)` versus finite
/// differences on the shared cylinder fixture.
#[test]
pub(crate) fn cylinder_second_jet_matches_fd() -> Result<(), String> {
    let cylinder = CylinderHarmonicEvaluator::new(3, 3)?;
    let fixture = cylinder_test_coords();
    assert_second_jet_matches_central_difference(&cylinder, fixture, 1.0e-6, 1.0e-5)
}
/// Third jet of `CylinderHarmonicEvaluator::new(3, 3)` versus finite
/// differences on the shared cylinder fixture.
#[test]
pub(crate) fn cylinder_third_jet_matches_fd() -> Result<(), String> {
    let cylinder = CylinderHarmonicEvaluator::new(3, 3)?;
    let fixture = cylinder_test_coords();
    assert_third_jet_matches_central_difference(&cylinder, fixture, 1.0e-6, 1.0e-5)
}
/// Structural checks on the cylinder roughness Gram matrix: it is square and
/// symmetric, its first row and column vanish, selected diagonal entries take
/// the pinned values, and its eigenvalues are non-negative up to 1e-9.
#[test]
pub(crate) fn cylinder_roughness_gram_is_psd_with_constant_nullspace() {
    let (h, degree) = (2usize, 2usize);
    let evaluator = CylinderHarmonicEvaluator::new(h, degree).unwrap();
    let mc = 2 * h + 1;
    let ml = degree + 1;
    let m = mc * ml;
    let s = evaluator.roughness_gram();
    assert_eq!(s.dim(), (m, m));

    // Symmetry.
    for r in 0..m {
        for c in 0..m {
            assert_abs_diff_eq!(s[[r, c]], s[[c, r]], epsilon = 1e-12);
        }
    }

    // The constant column (index 0) carries no roughness.
    for idx in 0..m {
        assert_abs_diff_eq!(s[[0, idx]], 0.0, epsilon = 1e-12);
        assert_abs_diff_eq!(s[[idx, 0]], 0.0, epsilon = 1e-12);
    }

    // Diagonal of each (sin, cos) harmonic paired with the constant line
    // function is expected to be (2πk)^4 / 2.
    let two_pi = std::f64::consts::TAU;
    for k in 1..=h {
        let half_omega4 = (two_pi * k as f64).powi(4) * 0.5;
        for circle_idx in [2 * k - 1, 2 * k] {
            let diag = circle_idx * ml;
            assert_abs_diff_eq!(s[[diag, diag]], half_omega4, epsilon = 1e-6);
        }
    }

    // Column 2 is the constant circle function times t1^2; its diagonal is pinned to 4.
    if degree >= 2 {
        let col = 2;
        assert_abs_diff_eq!(s[[col, col]], 4.0, epsilon = 1e-12);
    }

    let (evals, _) = s.eigh(Side::Lower).unwrap();
    for &lam in evals.iter() {
        assert!(
            lam >= -1.0e-9,
            "cylinder roughness Gram must be PSD; got eigenvalue {lam:.3e}"
        );
    }
}
/// Constructing a cylinder evaluator with first argument 0 must fail, while
/// `(1, 0)` must succeed.
#[test]
pub(crate) fn cylinder_rejects_zero_harmonics() {
    let zero_harmonics = CylinderHarmonicEvaluator::new(0, 2);
    assert!(zero_harmonics.is_err());

    let zero_degree = CylinderHarmonicEvaluator::new(1, 0);
    assert!(zero_degree.is_ok());
}
/// The cylinder atom kind must report a two-part product manifold: a circle of
/// period 1 followed by a Euclidean factor.
#[test]
pub(crate) fn cylinder_latent_manifold_is_circle_times_line() {
    let parts = match SaeAtomBasisKind::Cylinder.latent_manifold(2) {
        LatentManifold::Product(parts) => parts,
        other => panic!("expected Product[Circle, Euclidean], got {other:?}"),
    };
    assert_eq!(parts.len(), 2);
    assert!(matches!(parts[0], LatentManifold::Circle { period } if period == 1.0));
    assert!(matches!(parts[1], LatentManifold::Euclidean));
}
/// The cylinder seed grid at resolution 12 must have 12 rows whose first
/// column is `i / 12` and whose second column is 0.
#[test]
pub(crate) fn cylinder_projection_seed_grid_sweeps_circle_only() {
    let resolution = 12usize;
    let grid = SaeAtomBasisKind::Cylinder
        .projection_seed_grid(2, resolution)
        .unwrap();
    assert_eq!(grid.dim(), (resolution, 2));

    for node in 0..resolution {
        let expected_phase = node as f64 / resolution as f64;
        assert_abs_diff_eq!(grid[[node, 0]], expected_phase, epsilon = 1e-12);
        assert_abs_diff_eq!(grid[[node, 1]], 0.0, epsilon = 1e-12);
    }

    let circle_axis = grid.column(0);
    assert!(circle_axis.iter().all(|&phase| phase >= 0.0 && phase < 1.0));
}
/// For 1-D and 2-D Duchon evaluators, the basis matrix and the jet must agree
/// on the number of basis columns, and the jet must have one slice per query
/// row and one derivative axis per latent dimension.
#[test]
pub(crate) fn duchon_coordinate_evaluator_phi_and_jet_share_column_count() {
    let cases = [
        (1usize, array![[-1.0], [-0.4], [0.1], [0.6], [1.2], [1.9]]),
        (
            2usize,
            array![
                [-1.0, -0.8],
                [-0.3, 0.4],
                [0.2, -0.5],
                [0.7, 0.9],
                [1.1, -0.2],
                [1.6, 0.6],
            ],
        ),
    ];
    for (d, centers) in cases {
        let evaluator = DuchonCoordinateEvaluator::new(centers, 2).unwrap();
        // Query points with the same latent dimension as the centers.
        let coords = if d == 1 {
            array![[-0.5], [0.0], [0.3], [0.8]]
        } else {
            array![[-0.5, 0.2], [0.0, -0.3], [0.3, 0.7], [0.8, -0.1]]
        };
        let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();
        let jet_shape = jet.shape();
        assert_eq!(
            phi.ncols(),
            jet_shape[1],
            "Duchon d={d}: Phi has {} columns but jet has {}",
            phi.ncols(),
            jet_shape[1]
        );
        assert_eq!(jet_shape[0], coords.nrows());
        assert_eq!(jet_shape[2], d);
    }
}
/// First jet of a six-center 2-D Duchon evaluator versus central finite
/// differences at four query points.
#[test]
pub(crate) fn duchon_coordinate_evaluator_jacobian_matches_fd() {
    let knots = array![
        [-1.0, -0.8],
        [-0.3, 0.4],
        [0.2, -0.5],
        [0.7, 0.9],
        [1.1, -0.2],
        [1.6, 0.6],
    ];
    let queries = array![[-0.5, 0.2], [0.05, -0.35], [0.45, 0.75], [1.3, 0.1]];
    let duchon = DuchonCoordinateEvaluator::new(knots, 2).unwrap();
    assert_jacobian_matches_central_difference(&duchon, queries, 1.0e-4);
}
/// Second jet of a six-center 2-D Duchon evaluator versus finite differences
/// at four query points.
#[test]
pub(crate) fn duchon_coordinate_evaluator_second_jet_matches_fd() -> Result<(), String> {
    let knots = array![
        [-1.0, -0.8],
        [-0.3, 0.4],
        [0.2, -0.5],
        [0.7, 0.9],
        [1.1, -0.2],
        [1.6, 0.6],
    ];
    let queries = array![[-0.5, 0.2], [0.05, -0.35], [0.45, 0.75], [1.3, 0.1]];
    let duchon = DuchonCoordinateEvaluator::new(knots, 2).unwrap();
    assert_second_jet_matches_central_difference(&duchon, queries, 1.0e-4, 1.0e-4)
}
/// Third jet of a six-center 2-D Duchon evaluator versus finite differences
/// at four query points.
#[test]
pub(crate) fn duchon_coordinate_evaluator_third_jet_matches_fd() -> Result<(), String> {
    let knots = array![
        [-1.0, -0.8],
        [-0.3, 0.4],
        [0.2, -0.5],
        [0.7, 0.9],
        [1.1, -0.2],
        [1.6, 0.6],
    ];
    let queries = array![[-0.5, 0.2], [0.05, -0.35], [0.45, 0.75], [1.3, 0.1]];
    let duchon = DuchonCoordinateEvaluator::new(knots, 2).unwrap();
    assert_third_jet_matches_central_difference(&duchon, queries, 1.0e-4, 1.0e-4)
}
/// First and second jets of `EuclideanPatchEvaluator::new(2, 2)` versus finite
/// differences, plus a check that the basis has 6 columns.
///
/// # Errors
/// Returns the evaluator's error if construction or evaluation fails.
#[test]
pub(crate) fn euclidean_patch_evaluator_jets_match_fd() -> Result<(), String> {
    // Propagate construction failure through the test's Result, matching how
    // the other Result-returning tests in this file build this evaluator.
    let evaluator = EuclideanPatchEvaluator::new(2, 2)?;
    let coords = array![[0.0, -1.0], [3.5, 0.25], [-0.75, 1.2], [0.4, 0.9]];
    assert_jacobian_matches_central_difference(&evaluator, coords.clone(), 1.0e-6);
    assert_second_jet_matches_central_difference(&evaluator, coords, 1.0e-5, 1.0e-5)?;

    // Pinned basis width for these constructor arguments.
    let (phi, _jet) = evaluator.evaluate(array![[0.0, 0.0]].view())?;
    assert_eq!(phi.ncols(), 6);
    Ok(())
}
/// Starts a 1-D Euclidean-patch atom from affinely shifted and scaled
/// coordinates (`2.75 + 4 t`), canonicalizes the affine gauge, and checks that
/// the fitted reconstruction is unchanged (to 1e-10) while the live
/// coordinates end up with zero mean and unit RMS (to 1e-12).
#[test]
pub(crate) fn euclidean_affine_gauge_canonicalization_preserves_reconstruction()
-> Result<(), String> {
    let evaluator = Arc::new(EuclideanPatchEvaluator::new(1, 2)?);
    let basis_size = evaluator.basis_size();

    let canonical = array![[-1.0_f64], [-0.35], [0.1], [0.65], [1.2]];
    // Same points under an affine reparameterisation of the latent line.
    let coords = canonical.mapv(|t| 2.75 + 4.0 * t);
    let n_rows = coords.nrows();

    let (phi, jet) = evaluator.evaluate(coords.view())?;
    let decoder = array![[0.25, -0.4], [1.2, 0.3], [-0.15, 0.5]];
    let atom = SaeManifoldAtom::new(
        "euclidean_patch",
        SaeAtomBasisKind::EuclideanPatch,
        1,
        phi,
        jet,
        decoder,
        Array2::<f64>::eye(basis_size),
    )?
    .with_basis_evaluator(evaluator);
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((n_rows, 1)),
        vec![coords],
        vec![LatentManifold::Euclidean],
        AssignmentMode::softmax(1.0),
    )?;
    let mut term = SaeManifoldTerm::new(vec![atom], assignment)?;

    let before = term.fitted();
    term.canonicalize_affine_gauge_after_accept(None)?;
    let after = term.fitted();

    // Largest entrywise change in the reconstruction.
    let max_abs = before
        .iter()
        .zip(after.iter())
        .map(|(&old, &new)| (old - new).abs())
        .fold(0.0_f64, f64::max);
    assert!(
        max_abs <= 1.0e-10,
        "canonicalization changed reconstruction by {max_abs:.3e}"
    );

    let live = term.assignment.coords[0].as_matrix();
    let count = live.nrows() as f64;
    let mean = live.column(0).sum() / count;
    let rms = (live.column(0).iter().map(|v| v * v).sum::<f64>() / count).sqrt();
    assert_abs_diff_eq!(mean, 0.0, epsilon = 1.0e-12);
    assert_abs_diff_eq!(rms, 1.0, epsilon = 1.0e-12);
    Ok(())
}
/// Takes the second dense gauge vector of a 1-D Euclidean-patch term as a
/// Newton step and checks that its quotient norm squared is negligible
/// relative to its raw squared norm.
#[test]
pub(crate) fn quotient_step_norm_removes_pure_euclidean_affine_gauge() -> Result<(), String> {
    let evaluator = Arc::new(EuclideanPatchEvaluator::new(1, 2)?);
    let basis_size = evaluator.basis_size();
    let coords = array![[-1.0_f64], [-0.4], [0.2], [0.8], [1.3]];
    let n_rows = coords.nrows();
    let (phi, jet) = evaluator.evaluate(coords.view())?;
    let decoder = array![[0.1, -0.2], [1.0, 0.4], [0.25, -0.3]];
    let atom = SaeManifoldAtom::new(
        "euclidean_patch",
        SaeAtomBasisKind::EuclideanPatch,
        1,
        phi,
        jet,
        decoder,
        Array2::<f64>::eye(basis_size),
    )?
    .with_basis_evaluator(evaluator);
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((n_rows, 1)),
        vec![coords],
        vec![LatentManifold::Euclidean],
        AssignmentMode::softmax(1.0),
    )?;
    let term = SaeManifoldTerm::new(vec![atom], assignment)?;

    let gauges = term.dense_step_gauge_vectors()?;
    assert!(
        gauges.len() >= 2,
        "expected translation and scale gauge generators"
    );

    // Split the chosen gauge vector into its coordinate part and the remainder.
    let n_coord = term.n_obs() * term.assignment.row_block_dim();
    let gauge = &gauges[1];
    let (delta_t, delta_beta) = (gauge.slice(s![..n_coord]), gauge.slice(s![n_coord..]));
    let raw = gauge.iter().map(|v| v * v).sum::<f64>();
    let quotient = term.quotient_newton_step_norm_sq(delta_t, delta_beta, raw, 0.0)?;
    assert!(
        quotient <= raw.max(1.0) * 1.0e-20,
        "pure affine gauge step left quotient norm squared {quotient:.3e} from raw {raw:.3e}"
    );
    Ok(())
}
/// Fits a single torus atom, starting from the true latent phases and a zero
/// decoder, to a four-channel synthetic signal built from two phases, and
/// requires an in-sample R² of at least 0.5 after up to ten joint steps.
#[test]
pub(crate) fn sae_torus_atom_recovers_two_frequency_synthetic() {
    use std::f64::consts::PI;

    let (n, p, h, d) = (96usize, 4usize, 3usize, 2usize);
    let evaluator = TorusHarmonicEvaluator::new(d, h).unwrap();
    let m = evaluator.basis_size();

    // Latent phases and the signal they generate, built row by row.
    let mut true_coords = Array2::<f64>::zeros((n, d));
    let mut z = Array2::<f64>::zeros((n, p));
    for i in 0..n {
        let phase_a = ((i as f64) * 0.137).rem_euclid(1.0);
        let phase_b = ((i as f64) * 0.241 + 0.13).rem_euclid(1.0);
        true_coords[[i, 0]] = phase_a;
        true_coords[[i, 1]] = phase_b;

        let t1 = 2.0 * PI * phase_a;
        let t2 = 2.0 * PI * phase_b;
        z[[i, 0]] = t1.sin() + 0.3 * t2.cos();
        z[[i, 1]] = t1.cos() + 0.2 * (t1 + t2).sin();
        z[[i, 2]] = t2.sin();
        z[[i, 3]] = 0.5 * (t1 - t2).cos();
    }
    let sst: f64 = z.iter().map(|v| v * v).sum::<f64>();

    let (phi0, jet0) = evaluator.evaluate(true_coords.view()).unwrap();
    let mut penalty = Array2::<f64>::eye(m);
    penalty *= 1.0e-4;
    let atom = SaeManifoldAtom::new(
        "torus_atom",
        SaeAtomBasisKind::Torus,
        d,
        phi0,
        jet0,
        Array2::<f64>::zeros((m, p)),
        penalty,
    )
    .unwrap()
    .with_basis_evaluator(Arc::new(TorusHarmonicEvaluator::new(d, h).unwrap()));
    let unit_circle = LatentManifold::Circle { period: 1.0 };
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((n, 1)),
        vec![true_coords],
        vec![LatentManifold::Product(vec![
            unit_circle,
            LatentManifold::Circle { period: 1.0 },
        ])],
        AssignmentMode::softmax(0.5),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();
    let mut rho = SaeManifoldRho::new(0.0, -4.0, vec![Array1::<f64>::zeros(d)]);

    // Up to ten single-iteration fits; stop early if the loss goes non-finite.
    let ridge = 1.0e-6;
    for _ in 0..10 {
        let loss = term
            .run_joint_fit_arrow_schur(z.view(), &mut rho, None, 1, 1.0, ridge, ridge)
            .unwrap();
        if !loss.total().is_finite() {
            break;
        }
    }

    let fitted = term.fitted();
    assert_eq!(fitted.dim(), (n, p));
    let sse = fitted
        .indexed_iter()
        .fold(0.0_f64, |acc, ((row, col), v)| {
            let resid = v - z[[row, col]];
            acc + resid * resid
        });
    let r2 = 1.0 - sse / sst.max(1.0e-12);
    assert!(
        r2 >= 0.5,
        "torus atom R² too low: {r2:.4} (sst={sst:.4}, sse={sse:.4})"
    );
}
/// Fits a single sphere-chart atom, starting from the true (lat, lon)
/// coordinates and a zero decoder, to the unit-sphere embedding of those
/// coordinates, and requires an in-sample R² of at least 0.5 after up to ten
/// joint steps.
#[test]
pub(crate) fn sae_sphere_atom_recovers_synthetic_signal() {
    use std::f64::consts::{FRAC_PI_2, PI, TAU};

    let (n, p, d) = (96usize, 3usize, 2usize);

    // Coordinates sweep lat over [-0.5, 0.5) and lon over [-π, π); the target
    // is the corresponding point on the unit sphere.
    let mut true_coords = Array2::<f64>::zeros((n, d));
    let mut z = Array2::<f64>::zeros((n, p));
    for i in 0..n {
        let t = (i as f64) / (n as f64);
        let lat = -0.5 + 1.0 * t;
        let lon = -PI + 2.0 * PI * t;
        true_coords[[i, 0]] = lat;
        true_coords[[i, 1]] = lon;

        z[[i, 0]] = lat.cos() * lon.cos();
        z[[i, 1]] = lat.cos() * lon.sin();
        z[[i, 2]] = lat.sin();
    }
    let sst: f64 = z.iter().map(|v| v * v).sum::<f64>();

    let (phi0, jet0) = SphereChartEvaluator.evaluate(true_coords.view()).unwrap();
    let m = phi0.ncols();
    let mut penalty = Array2::<f64>::eye(m);
    penalty *= 1.0e-4;
    let atom = SaeManifoldAtom::new(
        "sphere_atom",
        SaeAtomBasisKind::Sphere,
        d,
        phi0,
        jet0,
        Array2::<f64>::zeros((m, p)),
        penalty,
    )
    .unwrap()
    .with_basis_evaluator(Arc::new(SphereChartEvaluator));
    let latitude = LatentManifold::Interval {
        lo: -FRAC_PI_2,
        hi: FRAC_PI_2,
    };
    let longitude = LatentManifold::Circle { period: TAU };
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((n, 1)),
        vec![true_coords],
        vec![LatentManifold::Product(vec![latitude, longitude])],
        AssignmentMode::softmax(0.5),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();
    let mut rho = SaeManifoldRho::new(0.0, -4.0, vec![Array1::<f64>::zeros(2)]);

    // Up to ten single-iteration fits; stop early if the loss goes non-finite.
    let ridge = 1.0e-6;
    for _ in 0..10 {
        let loss = term
            .run_joint_fit_arrow_schur(z.view(), &mut rho, None, 1, 1.0, ridge, ridge)
            .unwrap();
        if !loss.total().is_finite() {
            break;
        }
    }

    let fitted = term.fitted();
    assert_eq!(fitted.dim(), (n, p));
    let sse = fitted
        .indexed_iter()
        .fold(0.0_f64, |acc, ((row, col), v)| {
            let resid = v - z[[row, col]];
            acc + resid * resid
        });
    let r2 = 1.0 - sse / sst.max(1.0e-12);
    assert!(
        r2 >= 0.5,
        "sphere atom R² too low: {r2:.4} (sst={sst:.4}, sse={sse:.4})"
    );
}
/// Fits a single periodic atom to a one-harmonic scalar signal, starting from
/// coordinates shifted by a quarter period and a zero decoder, and requires an
/// in-sample R² of at least 0.95 within ten joint steps.
///
/// The loop stops early when the loss is non-finite or its relative change
/// drops below 1e-6.
#[test]
pub(crate) fn sae_manifold_fit_10_steps_one_harmonic_reaches_high_r2() {
    use std::f64::consts::PI;

    let (n, m, p) = (64usize, 3usize, 1usize);

    // Target signal and the (deliberately offset) starting coordinates.
    let mut z = Array2::<f64>::zeros((n, p));
    let mut coords0_data = Array2::<f64>::zeros((n, 1));
    for i in 0..n {
        let t = (i as f64) / (n as f64);
        let angle = 2.0 * PI * t;
        z[[i, 0]] = 0.7 * angle.sin() + 0.3 * angle.cos();
        coords0_data[[i, 0]] = (t + 0.25).rem_euclid(1.0);
    }
    let sst: f64 = z.iter().map(|v| v * v).sum::<f64>();

    let evaluator = PeriodicHarmonicEvaluator::new(m).unwrap();
    let (phi0, jet0) = evaluator.evaluate(coords0_data.view()).unwrap();
    let atom = SaeManifoldAtom::new(
        "periodic_atom",
        SaeAtomBasisKind::Periodic,
        1,
        phi0,
        jet0,
        Array2::<f64>::zeros((m, p)),
        Array2::<f64>::eye(m),
    )
    .unwrap()
    .with_basis_evaluator(Arc::new(PeriodicHarmonicEvaluator::new(m).unwrap()));
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((n, 1)),
        vec![coords0_data],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::softmax(0.5),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();
    let mut rho = SaeManifoldRho::new(0.0, -6.0, vec![Array1::<f64>::zeros(1)]);

    let max_iter = 10usize;
    let learning_rate = 1.0;
    let ridge = 1.0e-6;
    let mut last_total = f64::INFINITY;
    for _ in 0..max_iter {
        let loss = term
            .run_joint_fit_arrow_schur(z.view(), &mut rho, None, 1, learning_rate, ridge, ridge)
            .unwrap();
        let total = loss.total();
        if !total.is_finite() {
            break;
        }
        // Relative change against the previous total (floored to avoid /0).
        let rel = (last_total - total).abs() / last_total.abs().max(1.0e-12);
        last_total = total;
        if rel < 1.0e-6 {
            break;
        }
    }

    let fitted = term.fitted();
    assert_eq!(fitted.dim(), (n, p));
    let ssr = (0..n).fold(0.0, |acc, i| {
        let resid = z[[i, 0]] - fitted[[i, 0]];
        acc + resid * resid
    });
    let r2 = 1.0 - ssr / sst.max(1.0e-12);
    assert!(
        r2 >= 0.95,
        "10-step in-sample R² = {r2:.4} (ssr={ssr:.6}, sst={sst:.6}) should be >= 0.95"
    );
}
/// For a two-atom softmax assignment, `assignment_prior_grad_hdiag` must
/// succeed and return finite gradient and diagonal vectors of length `n * k`.
#[test]
pub(crate) fn softmax_assignment_hessian_diag_is_available_for_k2() {
    let (n, k) = (4usize, 2usize);
    let logits =
        Array2::<f64>::from_shape_fn((n, k), |(row, col)| 0.1 * (row as f64) - 0.2 * (col as f64));
    let coords: Vec<Array2<f64>> = (0..k).map(|_| Array2::<f64>::zeros((n, 1))).collect();
    let manifolds = vec![LatentManifold::Circle { period: 1.0 }; k];
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        logits,
        coords,
        manifolds,
        AssignmentMode::softmax(0.7),
    )
    .unwrap();
    let rho = SaeManifoldRho::new(0.0, -6.0, vec![Array1::<f64>::zeros(1); k]);

    let (grad, diag) = assignment_prior_grad_hdiag(&assignment, &rho)
        .expect("softmax assignment Hessian diagonal must be available");

    let expected_len = n * k;
    assert_eq!(grad.len(), expected_len);
    assert_eq!(diag.len(), expected_len);
    assert!(grad.iter().all(|g| g.is_finite()));
    assert!(diag.iter().all(|h| h.is_finite()));
}
/// An SAE term must reject analytic-penalty registries that contain either a
/// softmax assignment sparsity penalty or an IBP assignment penalty, with an
/// error mentioning "assignment sparsity" in both cases.
#[test]
pub(crate) fn sae_registry_refuses_assignment_sparsity_penalties() {
    let (n, k) = (3usize, 2usize);
    let logits = Array2::<f64>::zeros((n, k));
    let coords: Vec<Array2<f64>> = (0..k).map(|_| Array2::<f64>::zeros((n, 1))).collect();
    let manifolds = vec![LatentManifold::Circle { period: 1.0 }; k];
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        logits,
        coords,
        manifolds,
        AssignmentMode::softmax(0.7),
    )
    .expect("valid assignment");

    // Minimal one-basis-function atoms; only the registry validation matters here.
    let mut atoms: Vec<SaeManifoldAtom> = Vec::with_capacity(k);
    for atom_idx in 0..k {
        let atom = SaeManifoldAtom::new(
            format!("periodic_{atom_idx}"),
            SaeAtomBasisKind::Periodic,
            1,
            Array2::<f64>::ones((n, 1)),
            Array3::<f64>::zeros((n, 1, 1)),
            Array2::<f64>::zeros((1, 1)),
            Array2::<f64>::eye(1),
        )
        .expect("valid atom");
        atoms.push(atom);
    }
    let term = SaeManifoldTerm::new(atoms, assignment).expect("valid SAE term");

    // Softmax assignment sparsity.
    let softmax_penalty =
        crate::terms::analytic_penalties::SoftmaxAssignmentSparsityPenalty::new(k, 0.7);
    let mut softmax_registry = AnalyticPenaltyRegistry::new();
    softmax_registry.push(AnalyticPenaltyKind::SoftmaxAssignmentSparsity(Arc::new(
        softmax_penalty,
    )));
    let softmax_err = term
        .validate_analytic_penalty_registry(&softmax_registry)
        .expect_err("SAE registry must reject softmax assignment sparsity");
    assert!(softmax_err.contains("assignment sparsity"));

    // IBP assignment penalty.
    let ibp_penalty =
        crate::terms::analytic_penalties::IBPAssignmentPenalty::new(k, 1.2, 0.7, false);
    let mut ibp_registry = AnalyticPenaltyRegistry::new();
    ibp_registry.push(AnalyticPenaltyKind::IBPAssignment(Arc::new(ibp_penalty)));
    let ibp_err = term
        .validate_analytic_penalty_registry(&ibp_registry)
        .expect_err("SAE registry must reject IBP assignment sparsity");
    assert!(ibp_err.contains("assignment sparsity"));
}
/// For an IBP-MAP assignment, the analytic gradient entry at flat logit index 5
/// must match a central finite difference of `assignment_prior_value` in that
/// logit to within 2e-7.
#[test]
pub(crate) fn ibp_fixed_alpha_assignment_value_matches_logit_gradient_fd() {
    let (n, k) = (4usize, 3usize);
    let logit_values = vec![
        -0.4, 0.2, 0.7, 0.1, -0.3, 0.5, 0.8, -0.1, -0.6, 0.3, 0.6, -0.2,
    ];
    let logits =
        Array2::<f64>::from_shape_vec((n, k), logit_values).expect("valid IBP logit grid");
    let coords: Vec<Array2<f64>> = (0..k).map(|_| Array2::<f64>::zeros((n, 1))).collect();
    let manifolds = vec![LatentManifold::Circle { period: 1.0 }; k];
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        logits,
        coords,
        manifolds,
        AssignmentMode::ibp_map(0.9, 1.4, false),
    )
    .expect("valid IBP assignment");
    let rho = SaeManifoldRho::new(0.23_f64.ln(), -6.0, vec![Array1::<f64>::zeros(1); k]);
    let (grad, _) =
        assignment_prior_grad_hdiag(&assignment, &rho).expect("IBP assignment gradient");

    // Flat index -> (row, atom) in the row-major logit grid.
    let idx = 5usize;
    let (row, col) = (idx / k, idx % k);
    let step = 1.0e-6_f64;
    // Prior value with the chosen logit shifted by `delta`.
    let shifted_value = |delta: f64| {
        let mut shifted = assignment.clone();
        shifted.logits[[row, col]] += delta;
        assignment_prior_value(&shifted, &rho)
    };
    let fd = (shifted_value(step) - shifted_value(-step)) / (2.0 * step);
    assert_abs_diff_eq!(grad[idx], fd, epsilon = 2.0e-7);
}
/// Assembles the arrow-Schur system for a two-atom IBP-MAP term and checks the
/// emitted cross-row source (`sys.ibp_cross_row`) against finite differences
/// of `assignment_prior_value`.
///
/// Three properties are asserted:
/// 1. the source exists, has rank `K`, and carries an entry in every
///    (row, atom) logit slot;
/// 2. for each atom, `d[k] * u[gi, k] * u[gj, k]` matches the mixed second
///    difference of the prior in the logits of rows `i != j` for that atom;
/// 3. the mixed second difference between an atom-0 logit and an atom-1 logit
///    in different rows is zero (to 5e-6).
#[test]
pub(crate) fn ibp_assembly_emits_cross_row_woodbury_source_matching_fd_hessian() {
// Two periodic atoms with distinct coordinates and fixed decoders.
let coords0 = array![[0.05], [0.20], [0.55], [0.80]];
let coords1 = array![[0.15], [0.30], [0.65], [0.90]];
let (phi0, jet0) = periodic_basis(&coords0);
let (phi1, jet1) = periodic_basis(&coords1);
let atom0 = SaeManifoldAtom::new(
"periodic0",
SaeAtomBasisKind::Periodic,
1,
phi0,
jet0,
array![[0.25], [-0.35], [0.15]],
Array2::<f64>::eye(3),
)
.unwrap()
.with_basis_evaluator(Arc::new(TestPeriodicEvaluator));
let atom1 = SaeManifoldAtom::new(
"periodic1",
SaeAtomBasisKind::Periodic,
1,
phi1,
jet1,
array![[-0.10], [0.20], [0.30]],
Array2::<f64>::eye(3),
)
.unwrap()
.with_basis_evaluator(Arc::new(TestPeriodicEvaluator));
let logits = array![[1.2, 0.4], [0.6, 1.0], [0.9, 0.3], [1.4, 0.7]];
let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
logits,
vec![coords0, coords1],
vec![
LatentManifold::Circle { period: 1.0 },
LatentManifold::Circle { period: 1.0 },
],
AssignmentMode::ibp_map(0.8, 1.0, false),
)
.unwrap();
let mut term = SaeManifoldTerm::new(vec![atom0, atom1], assignment).unwrap();
let target = array![[0.12], [-0.03], [0.08], [0.20]];
let rho = SaeManifoldRho::new(
0.3_f64.ln(),
0.7_f64.ln(),
vec![array![0.9_f64.ln()], array![1.1_f64.ln()]],
);
let n = term.assignment.n_obs();
let k = term.assignment.k_atoms();
let sys = term
.assemble_arrow_schur(target.view(), &rho, None)
.expect("IBP arrow assembly");
let source = sys
.ibp_cross_row
.as_ref()
.expect("an IBP-active assembly must emit the cross-row Woodbury source");
assert_eq!(source.r, k, "the rank must be the atom count K");
// Scatter the sparse (global index, atom, value) entries into a dense
// (total_t x K) matrix so individual slots can be inspected.
let total_t = sys.row_offsets[n];
let mut u = Array2::<f64>::zeros((total_t, k));
for &(g, atom_k, z_prime) in &source.entries {
u[[g, atom_k]] += z_prime;
}
// Every (row, atom) slot, addressed as row_offsets[i] + atom_k, must be
// non-zero; the large-|logit| escape hatch is not triggered by this fixture.
for i in 0..n {
for atom_k in 0..k {
let g = sys.row_offsets[i] + atom_k;
assert!(
u[[g, atom_k]].abs() > 0.0 || term.assignment.logits[[i, atom_k]].abs() > 1.0e3,
"row {i} atom {atom_k} logit slot must carry a z' entry"
);
}
}
let d = source.d.clone();
let step = 1.0e-5_f64;
// Four-point central estimate of the mixed second derivative of the prior
// with respect to logits (i, atom_k) and (j, atom_k).
let fd_cross = |i: usize, j: usize, atom_k: usize| -> f64 {
let bump = |si: f64, sj: f64| -> f64 {
let mut a = term.assignment.clone();
a.logits[[i, atom_k]] += si * step;
a.logits[[j, atom_k]] += sj * step;
assignment_prior_value(&a, &rho)
};
(bump(1.0, 1.0) - bump(1.0, -1.0) - bump(-1.0, 1.0) + bump(-1.0, -1.0))
/ (4.0 * step * step)
};
// Same-atom, different-row coupling: rank-one formula versus finite difference.
for atom_k in 0..k {
for i in 0..n {
for j in 0..n {
if i == j {
continue;
}
let gi = sys.row_offsets[i] + atom_k;
let gj = sys.row_offsets[j] + atom_k;
let analytic = d[atom_k] * u[[gi, atom_k]] * u[[gj, atom_k]];
let fd = fd_cross(i, j, atom_k);
assert_abs_diff_eq!(analytic, fd, epsilon = 5.0e-6);
}
}
}
// Different-atom, different-row coupling (atom 0 in row i, atom 1 in row j)
// must vanish.
for i in 0..n {
for j in 0..n {
if i == j {
continue;
}
let mut a = term.assignment.clone();
let cross = {
let s = 1.0e-5_f64;
// Each call overwrites both logits with absolute values derived from the
// unmodified assignment, so the single clone `a` can be reused.
let mut bump = |si: f64, sj: f64| -> f64 {
a.logits[[i, 0]] = term.assignment.logits[[i, 0]] + si * s;
a.logits[[j, 1]] = term.assignment.logits[[j, 1]] + sj * s;
assignment_prior_value(&a, &rho)
};
(bump(1.0, 1.0) - bump(1.0, -1.0) - bump(-1.0, 1.0) + bump(-1.0, -1.0))
/ (4.0 * s * s)
};
assert_abs_diff_eq!(cross, 0.0, epsilon = 5.0e-6);
}
}
}
/// Verifies the analytic JumpReLU assignment-prior gradient at one probed
/// logit against a central finite difference of `assignment_prior_value`.
#[test]
pub(crate) fn jumprelu_assignment_value_matches_logit_gradient_fd() {
    let (n_rows, n_atoms) = (4usize, 2usize);
    let (temperature, threshold) = (0.35_f64, 0.1_f64);
    let logit_grid = Array2::<f64>::from_shape_vec(
        (n_rows, n_atoms),
        vec![-13.0, -0.2, 0.0, 0.05, 0.15, 0.4, 0.9, 1.5],
    )
    .expect("valid JumpReLU logit grid");
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        logit_grid,
        vec![Array2::<f64>::zeros((n_rows, 1)); n_atoms],
        vec![LatentManifold::Circle { period: 1.0 }; n_atoms],
        AssignmentMode::jumprelu(temperature, threshold),
    )
    .expect("valid JumpReLU assignment");
    let rho = SaeManifoldRho::new(
        0.7_f64.ln(),
        -6.0,
        vec![Array1::<f64>::zeros(1); n_atoms],
    );
    let (analytic_grad, _) =
        assignment_prior_grad_hdiag(&assignment, &rho).expect("JumpReLU assignment gradient");

    // The flat gradient index maps to (row, atom) in row-major order.
    let flat_idx = 4usize;
    let (row, col) = (flat_idx / n_atoms, flat_idx % n_atoms);
    let half_width = 1.0e-6_f64;
    // Prior value after shifting only the probed logit by `offset`.
    let prior_at_offset = |offset: f64| -> f64 {
        let mut shifted = assignment.clone();
        shifted.logits[[row, col]] += offset;
        assignment_prior_value(&shifted, &rho)
    };
    let upper = prior_at_offset(half_width);
    let lower = prior_at_offset(-half_width);
    let central_fd = (upper - lower) / (2.0 * half_width);
    assert_abs_diff_eq!(analytic_grad[flat_idx], central_fd, epsilon = 2.0e-8);
}
/// Sweeps a grid of logits and checks that every JumpReLU Hessian-diagonal
/// entry is finite and equals the closed form evaluated in this test:
/// `lambda * s * (1 - 2a) / tau^2` inside the optimization band and `0`
/// outside it, where `a` is the logistic activation and `s = a (1 - a)`.
/// The sweep must also produce at least one negative entry.
#[test]
pub(crate) fn jumprelu_assignment_prior_hessian_diag_is_exact_over_logit_sweep() {
    let (n, k) = (6usize, 2usize);
    let (temperature, threshold) = (0.35_f64, 0.1_f64);
    let sweep = vec![
        -2.0, -0.2, 0.0, 0.05, 0.1, 0.15, 0.4, 0.9, 1.5, 2.5, 4.0, 6.0,
    ];
    let logits = Array2::<f64>::from_shape_vec((n, k), sweep).expect("valid logit grid");
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        logits.clone(),
        vec![Array2::<f64>::zeros((n, 1)); k],
        vec![LatentManifold::Circle { period: 1.0 }; k],
        AssignmentMode::jumprelu(temperature, threshold),
    )
    .expect("valid JumpReLU assignment");
    let rho = SaeManifoldRho::new(0.7_f64.ln(), -6.0, vec![Array1::<f64>::zeros(1); k]);
    let (grad, diag) = assignment_prior_grad_hdiag(&assignment, &rho)
        .expect("JumpReLU assignment prior hessian diag");
    assert_eq!(grad.len(), n * k);
    assert_eq!(diag.len(), n * k);

    let inv_tau = 1.0 / temperature;
    let inv_tau2 = inv_tau * inv_tau;
    let sparsity_strength = rho.log_lambda_sparse.exp();
    // Reference value for a single logit, mirroring the expected formula.
    let closed_form = |logit: f64| -> f64 {
        if !jumprelu_in_optimization_band(logit, threshold, temperature) {
            return 0.0;
        }
        let activation = crate::linalg::utils::stable_logistic((logit - threshold) * inv_tau);
        let slope = activation * (1.0 - activation);
        sparsity_strength * slope * (1.0 - 2.0 * activation) * inv_tau2
    };

    let mut saw_negative = false;
    for (idx, &entry) in diag.iter().enumerate() {
        let expected = closed_form(logits[[idx / k, idx % k]]);
        assert!(
            entry.is_finite(),
            "JumpReLU hessian_diag must be finite at index {idx}"
        );
        if entry < 0.0 {
            saw_negative = true;
        }
        assert_abs_diff_eq!(entry, expected, epsilon = 1e-12);
    }
    assert!(
        saw_negative,
        "exact JumpReLU hessian_diag must go negative above the threshold"
    );
}
/// End-to-end check that a K=2 periodic IBP-MAP model, with decoders seeded
/// from a joint least-squares fit, reconstructs a synthetic two-angle signal.
///
/// Pipeline: synthesize two angles per row, embed them as cos/sin features,
/// mix into `p` output columns, center the columns, seed the latent
/// coordinates with small offsets from the true angles, solve a jittered
/// normal-equation system for the decoders, then call
/// `run_joint_fit_arrow_schur` repeatedly until the relative change in total
/// loss drops below `1e-6` (at most 30 calls). The test asserts `R² > 0.5` and
/// that the mean assignment mass per row exceeds `0.2`.
#[test]
pub(crate) fn ibp_map_k2_periodic_torus_recovers_signal_with_lsq_init() {
use crate::linalg::faer_ndarray::{FaerCholesky, fast_ata, fast_atb};
use faer::Side as FaerSide;
let n = 200usize;
let p = 8usize;
let k = 2usize;
let m = 5usize;
// Deterministic pseudo-angles in [0, 1) for the two latent circles.
let mut theta = Array2::<f64>::zeros((n, 2));
for i in 0..n {
theta[[i, 0]] = ((i as f64) * 0.07) % 1.0;
theta[[i, 1]] = ((i as f64) * 0.13 + 0.31) % 1.0;
}
// cos/sin embedding of each angle: four raw features per row.
let mut raw = Array2::<f64>::zeros((n, 4));
for i in 0..n {
let a1 = 2.0 * std::f64::consts::PI * theta[[i, 0]];
let a2 = 2.0 * std::f64::consts::PI * theta[[i, 1]];
raw[[i, 0]] = a1.cos();
raw[[i, 1]] = a1.sin();
raw[[i, 2]] = a2.cos();
raw[[i, 3]] = a2.sin();
}
// Fixed 4 x p mixing matrix; z = raw * mix.
let mix = Array2::<f64>::from_shape_fn((4, p), |(i, j)| {
((i as f64 + 1.0) * 0.37 + (j as f64) * 0.21).sin()
});
let mut z = Array2::<f64>::zeros((n, p));
for i in 0..n {
for j in 0..p {
let mut acc = 0.0;
for r in 0..4 {
acc += raw[[i, r]] * mix[[r, j]];
}
z[[i, j]] = acc;
}
}
// Center every output column of z.
let mut col_mean = Array1::<f64>::zeros(p);
for j in 0..p {
let mut acc = 0.0;
for i in 0..n {
acc += z[[i, j]];
}
col_mean[j] = acc / n as f64;
}
for i in 0..n {
for j in 0..p {
z[[i, j]] -= col_mean[j];
}
}
// Initial latent coordinates: the true angles shifted by 0.05 / 0.07 and
// wrapped back into [0, 1).
let mut coords_k = vec![Array2::<f64>::zeros((n, 1)); k];
for i in 0..n {
coords_k[0][[i, 0]] = (theta[[i, 0]] + 0.05).rem_euclid(1.0);
coords_k[1][[i, 0]] = (theta[[i, 1]] + 0.07).rem_euclid(1.0);
}
let evaluator = PeriodicHarmonicEvaluator::new(m).unwrap();
let mut phi_k = Vec::with_capacity(k);
let mut jet_k = Vec::with_capacity(k);
for atom_idx in 0..k {
let (phi, jet) = evaluator.evaluate(coords_k[atom_idx].view()).unwrap();
phi_k.push(phi);
jet_k.push(jet);
}
// Joint design matrix: the atoms' bases side by side, each scaled by 0.5.
// NOTE(review): 0.5 presumably matches the gate value at the zero logits
// used below — confirm against the IBP-MAP assignment definition.
let m_total = k * m;
let mut x = Array2::<f64>::zeros((n, m_total));
for atom_idx in 0..k {
for i in 0..n {
for col in 0..m {
x[[i, atom_idx * m + col]] = 0.5 * phi_k[atom_idx][[i, col]];
}
}
}
// Normal equations with a diagonal jitter of 1e-8 times the larger of the
// mean diagonal entry and 1.0, solved through a Cholesky factorization.
let mut xtx = fast_ata(&x);
let mut trace = 0.0_f64;
for i in 0..m_total {
trace += xtx[[i, i]];
}
let jitter = (trace / m_total as f64).max(1.0) * 1.0e-8;
for i in 0..m_total {
xtx[[i, i]] += jitter;
}
let xtz = fast_atb(&x, &z);
let b_joint = xtx
.cholesky(FaerSide::Lower)
.expect("LSQ Cholesky")
.solve_mat(&xtz);
// Slice the joint solution into one m x p decoder block per atom.
let mut atoms = Vec::with_capacity(k);
for atom_idx in 0..k {
let mut b = Array2::<f64>::zeros((m, p));
for col in 0..m {
for j in 0..p {
b[[col, j]] = b_joint[[atom_idx * m + col, j]];
}
}
let atom = SaeManifoldAtom::new(
format!("torus_atom_{atom_idx}"),
SaeAtomBasisKind::Periodic,
1,
phi_k[atom_idx].clone(),
jet_k[atom_idx].clone(),
b,
Array2::<f64>::eye(m),
)
.unwrap()
.with_basis_evaluator(Arc::new(PeriodicHarmonicEvaluator::new(m).unwrap()));
atoms.push(atom);
}
let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
Array2::<f64>::zeros((n, k)),
coords_k,
vec![LatentManifold::Circle { period: 1.0 }; k],
AssignmentMode::ibp_map(0.7, 1.0, false),
)
.unwrap();
let mut term = SaeManifoldTerm::new(atoms, assignment).unwrap();
let mut rho = SaeManifoldRho::new((0.02_f64).ln(), -6.0, vec![Array1::<f64>::zeros(1); k]);
// Repeated fit calls with an external convergence check. On the first pass
// `prev_total` is infinite, so `rel` is NaN and the `rel < 1e-6` test is
// false; the loop therefore always attempts a second call unless the loss
// is non-finite.
let mut prev_total = f64::INFINITY;
for _ in 0..30 {
let loss = term
.run_joint_fit_arrow_schur(z.view(), &mut rho, None, 1, 1.0, 1.0e-6, 1.0e-6)
.unwrap();
let total = loss.total();
// A non-finite loss stops iterating; the assertions below still run on
// whatever state `term` holds at that point.
if !total.is_finite() {
break;
}
let denom = prev_total.abs().max(1.0e-12);
let rel = (prev_total - total).abs() / denom;
prev_total = total;
if rel < 1.0e-6 {
break;
}
}
// Coefficient of determination of the fitted reconstruction against z
// (z is already centered, so the total sum of squares is the sum of z²).
let fitted = term.fitted();
let mut ssr = 0.0;
let mut sst = 0.0;
for i in 0..n {
for j in 0..p {
let r = z[[i, j]] - fitted[[i, j]];
ssr += r * r;
sst += z[[i, j]] * z[[i, j]];
}
}
let r2 = 1.0 - ssr / sst.max(1.0e-12);
assert!(
r2 > 0.5,
"K=2 periodic torus IBP-MAP R² = {r2:.4} (ssr={ssr:.4}, sst={sst:.4}) should be > 0.5 with LSQ-seeded decoder"
);
// Sum of all assignment entries divided by the number of rows.
let assignments = term.assignment.assignments();
let mean_active: f64 = assignments.iter().copied().sum::<f64>() / (n as f64);
assert!(
mean_active > 0.2,
"mean active mass across rows = {mean_active:.4} should exceed 0.2; assignment did not collapse"
);
}
/// Smoke test: a K=2 periodic model in softmax assignment mode must finish
/// two consecutive `run_joint_fit_arrow_schur` calls with a finite total loss.
#[test]
pub(crate) fn softmax_k2_periodic_completes_joint_fit_step() {
    let (n, p, k, m) = (64usize, 4usize, 2usize, 3usize);

    // Target columns: first and second harmonics of one sweep around the circle.
    let z = Array2::<f64>::from_shape_fn((n, p), |(i, j)| {
        let a = 2.0 * std::f64::consts::PI * (i as f64) / (n as f64);
        match j {
            0 => a.sin(),
            1 => a.cos(),
            2 => (2.0 * a).sin(),
            _ => (2.0 * a).cos(),
        }
    });

    // Atom 0 sweeps the circle once, atom 1 sweeps it twice (wrapped to [0, 1)).
    let coords_k = vec![
        Array2::<f64>::from_shape_fn((n, 1), |(i, _)| (i as f64) / (n as f64)),
        Array2::<f64>::from_shape_fn((n, 1), |(i, _)| {
            ((i as f64) * 2.0 / (n as f64)).rem_euclid(1.0)
        }),
    ];

    let evaluator = PeriodicHarmonicEvaluator::new(m).unwrap();
    let atoms: Vec<_> = coords_k
        .iter()
        .enumerate()
        .map(|(atom_idx, coords)| {
            let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();
            let b = Array2::<f64>::from_shape_fn((m, p), |(i, j)| {
                0.1 * ((i as f64 + 1.0) * (j as f64 + 1.0)).sin()
            });
            SaeManifoldAtom::new(
                format!("a_{atom_idx}"),
                SaeAtomBasisKind::Periodic,
                1,
                phi,
                jet,
                b,
                Array2::<f64>::eye(m),
            )
            .unwrap()
            .with_basis_evaluator(Arc::new(PeriodicHarmonicEvaluator::new(m).unwrap()))
        })
        .collect();

    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((n, k)),
        coords_k,
        vec![LatentManifold::Circle { period: 1.0 }; k],
        AssignmentMode::softmax(0.7),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(atoms, assignment).unwrap();
    let mut rho = SaeManifoldRho::new(0.0, -6.0, vec![Array1::<f64>::zeros(1); k]);

    // One fit call with the fixed settings used by both steps.
    let mut fit_once = |failure_note: &str| {
        term.run_joint_fit_arrow_schur(z.view(), &mut rho, None, 1, 1.0, 1.0e-6, 1.0e-6)
            .expect(failure_note)
    };
    let loss0 = fit_once("softmax K=2 must complete first joint-fit step");
    assert!(loss0.total().is_finite());
    let loss1 = fit_once("softmax K=2 must complete second joint-fit step");
    assert!(loss1.total().is_finite());
}
/// Shared oracle for the isometry-penalty cache wiring.
///
/// Builds a single atom from `evaluator` at `coords`, then checks in order:
/// * without a cache, `value` returns
///   `IsometryPenalty::DEFAULT_VALUE_ON_MISSING_CACHE` and `grad_target` is
///   all zeros;
/// * after `refresh_isometry_caches_from_atom`, value and gradient are
///   non-trivial;
/// * the first gradient component agrees with a central finite difference of
///   the value with respect to `coords[[0, 0]]`, with the caches refreshed at
///   each perturbed point.
pub(crate) fn assert_isometry_wiring_matches_fd(
    evaluator: Arc<dyn SaeBasisSecondJet>,
    coords: Array2<f64>,
) {
    let (n_obs, latent_dim) = coords.dim();
    let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();
    let m = phi.ncols();
    let p: usize = 3;
    // Deterministic, non-degenerate decoder coefficients.
    let decoder = Array2::<f64>::from_shape_fn((m, p), |(i, j)| {
        let x = (i as f64) * 0.371 + (j as f64) * 0.193 + 0.5;
        (x.sin() * 0.9) + 0.1 * ((i + j) as f64).cos()
    });
    let atom = SaeManifoldAtom::new(
        "iso_wire_test",
        SaeAtomBasisKind::Periodic,
        latent_dim,
        phi,
        jet,
        decoder,
        Array2::<f64>::eye(m),
    )
    .unwrap()
    .with_basis_second_jet(evaluator);
    let penalty = IsometryPenalty::new_euclidean(
        PsiSlice::full(n_obs * latent_dim, Some(latent_dim)),
        p,
    );
    let rho = Array1::<f64>::zeros(1);
    // Flattens a coordinate grid in its iteration order.
    let flatten = |grid: &Array2<f64>| -> Array1<f64> { grid.iter().copied().collect() };
    let target_flat = flatten(&coords);

    // Before any cache is installed.
    let v0 = penalty.value(target_flat.view(), rho.view());
    assert_eq!(v0, IsometryPenalty::DEFAULT_VALUE_ON_MISSING_CACHE);
    let g0 = penalty.grad_target(target_flat.view(), rho.view());
    assert!(
        g0.iter().all(|&x| x == 0.0),
        "grad_target without cache must be all zeros, got {g0:?}"
    );

    // After refreshing the caches from the atom.
    let installed_second =
        refresh_isometry_caches_from_atom(&penalty, &atom, coords.view()).unwrap();
    assert!(
        installed_second,
        "evaluator must implement second_jet for this oracle to run"
    );
    let value = penalty.value(target_flat.view(), rho.view());
    assert!(
        value > 1.0e-6,
        "expected non-trivial isometry loss after cache refresh, got {value}"
    );
    let grad = penalty.grad_target(target_flat.view(), rho.view());
    assert_eq!(grad.len(), target_flat.len());
    let max_abs = grad.iter().map(|x| x.abs()).fold(0.0_f64, f64::max);
    assert!(
        max_abs > 1.0e-6,
        "expected non-zero isometry gradient on at least one component, max |grad|={max_abs}"
    );

    // Central finite difference in the first coordinate of the first row.
    let h_fd = 1.0e-5;
    let probe_idx = 0usize;
    let refreshed_value_at = |shift: f64| -> f64 {
        let mut moved = coords.clone();
        moved[[0, 0]] += shift;
        refresh_isometry_caches_from_atom(&penalty, &atom, moved.view()).unwrap();
        let flat = flatten(&moved);
        penalty.value(flat.view(), rho.view())
    };
    let v_plus = refreshed_value_at(h_fd);
    let v_minus = refreshed_value_at(-h_fd);

    // Restore the caches at the base point before reading the gradient.
    refresh_isometry_caches_from_atom(&penalty, &atom, coords.view()).unwrap();
    let grad_base = penalty.grad_target(target_flat.view(), rho.view());
    let fd = (v_plus - v_minus) / (2.0 * h_fd);
    let analytic = grad_base[probe_idx];
    assert!(
        (analytic - fd).abs() <= 1.0e-3 + 1.0e-4 * analytic.abs().max(fd.abs()),
        "isometry grad/FD mismatch at coord 0: analytic={analytic:.6e}, fd={fd:.6e}"
    );
}
/// Runs the isometry wiring oracle on a 5-harmonic periodic basis with four
/// one-dimensional sample points.
#[test]
pub(crate) fn isometry_wiring_periodic_matches_fd() {
    let evaluator = Arc::new(PeriodicHarmonicEvaluator::new(5).unwrap());
    let coords: Array2<f64> = array![[0.12], [0.37], [0.58], [0.81]];
    assert_isometry_wiring_matches_fd(evaluator, coords);
}
/// Runs the isometry wiring oracle on the sphere chart evaluator with three
/// two-dimensional sample points.
#[test]
pub(crate) fn isometry_wiring_sphere_matches_fd() {
    let evaluator = Arc::new(SphereChartEvaluator);
    let coords: Array2<f64> = array![[-0.5, 0.3], [0.2, -1.1], [0.7, 0.9]];
    assert_isometry_wiring_matches_fd(evaluator, coords);
}
/// Runs the isometry wiring oracle on a `(2, 2)` torus harmonic evaluator
/// with three two-dimensional sample points.
#[test]
pub(crate) fn isometry_wiring_torus_matches_fd() {
    let evaluator = Arc::new(TorusHarmonicEvaluator::new(2, 2).unwrap());
    let coords: Array2<f64> = array![[0.13, 0.42], [0.66, 0.19], [0.88, 0.55]];
    assert_isometry_wiring_matches_fd(evaluator, coords);
}
/// Returns an `n_basis x p_out` decoder matrix whose entries are a fixed
/// trigonometric function of the row index, column index and `seed`, so
/// fixtures are reproducible without a random number generator.
pub(crate) fn deterministic_decoder(n_basis: usize, p_out: usize, seed: f64) -> Array2<f64> {
    let mut decoder = Array2::<f64>::zeros((n_basis, p_out));
    for i in 0..n_basis {
        for j in 0..p_out {
            let row = i as f64;
            let col = j as f64;
            let cross = (i * j + 1) as f64;
            // Keep the left-to-right summation order of the phase terms.
            let phase = seed + 0.371 * row - 0.193 * col + 0.047 * cross;
            decoder[[i, j]] = 0.8 * phase.sin() + 0.35 * (1.7 * phase).cos();
        }
    }
    decoder
}
/// Builds the fixture triple shared by the exact-HVP and PSD-majorizer
/// oracles.
///
/// Returns, in order: an atom whose basis and jet come from `evaluator` at
/// `coords` (decoder from `deterministic_decoder`, identity smoothing matrix,
/// second-jet evaluator attached); a Euclidean-reference isometry penalty
/// covering the whole flattened coordinate vector; and that flattened vector.
pub(crate) fn build_isometry_atom_for_evaluator(
    evaluator: Arc<dyn SaeBasisSecondJet>,
    kind: SaeAtomBasisKind,
    coords: &Array2<f64>,
    p_out: usize,
    seed: f64,
) -> (SaeManifoldAtom, IsometryPenalty, Array1<f64>) {
    let latent_dim = coords.ncols();
    let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();
    let n_basis = phi.ncols();
    let atom = SaeManifoldAtom::new(
        "exact_hvp_atom",
        kind,
        latent_dim,
        phi,
        jet,
        deterministic_decoder(n_basis, p_out, seed),
        Array2::<f64>::eye(n_basis),
    )
    .unwrap()
    .with_basis_second_jet(evaluator);

    let target_flat: Array1<f64> = coords.iter().copied().collect();
    let slice = PsiSlice::full(target_flat.len(), Some(latent_dim));
    (
        atom,
        IsometryPenalty::new_euclidean(slice, p_out),
        target_flat,
    )
}
/// Oracle: the exact isometry Hessian-vector product along `direction` must
/// match a central finite difference of `grad_target`, with the caches
/// refreshed at each perturbed coordinate set.
///
/// Also asserts that the second-jet cache installs, that a third-decoder
/// derivative cache is present, and that the HVP is not identically zero.
pub(crate) fn assert_exact_isometry_hvp_matches_grad_fd(
    evaluator: Arc<dyn SaeBasisSecondJet>,
    kind: SaeAtomBasisKind,
    coords: Array2<f64>,
    p_out: usize,
    direction: Array2<f64>,
) {
    let (atom, penalty, target_flat) =
        build_isometry_atom_for_evaluator(evaluator, kind, &coords, p_out, 0.91);
    let rho = array![0.0_f64];

    let installed = refresh_isometry_caches_from_atom(&penalty, &atom, coords.view()).unwrap();
    assert!(
        installed,
        "second-jet cache must be installed for exact HVP test"
    );
    assert!(
        penalty.third_decoder_derivative().is_some(),
        "non-Duchon exact HVP requires a live refreshed third-decoder-jet cache"
    );

    let v: Array1<f64> = direction.iter().copied().collect();
    let exact = penalty.hvp(target_flat.view(), rho.view(), v.view());
    assert!(
        exact.iter().any(|x| x.abs() > 1.0e-7),
        "exact isometry HVP should be nonzero after K refresh; got {exact:?}"
    );

    let eps = 1.0e-6;
    let offset = direction.mapv(|x| eps * x);
    // Gradient at a displaced coordinate set, with the caches refreshed there.
    let grad_at = |moved: Array2<f64>| {
        refresh_isometry_caches_from_atom(&penalty, &atom, moved.view()).unwrap();
        let flat: Array1<f64> = moved.iter().copied().collect();
        penalty.grad_target(flat.view(), rho.view())
    };
    let grad_plus = grad_at(&coords + &offset);
    let grad_minus = grad_at(&coords - &offset);
    // Put the caches back at the base point.
    refresh_isometry_caches_from_atom(&penalty, &atom, coords.view()).unwrap();

    let fd = (&grad_plus - &grad_minus).mapv(|x| x / (2.0 * eps));
    for i in 0..exact.len() {
        let (lhs, rhs) = (exact[i], fd[i]);
        let err = (lhs - rhs).abs();
        let tol = 2.0e-4 + 3.0e-5 * lhs.abs().max(rhs.abs());
        assert!(
            err <= tol,
            "exact isometry HVP/grad-FD mismatch at flat index {i}: exact={:.12e}, fd={:.12e}, err={:.6e}, tol={:.6e}",
            lhs,
            rhs,
            err,
            tol
        );
    }
}
/// Oracle: when the reference metric is the penalty's own pullback metric
/// (divided by its mean diagonal entry), `hvp` and `psd_majorizer_hvp` must
/// agree to `1e-10` along `direction`.
///
/// Also asserts that the third-decoder derivative cache is present and that
/// the majorizer product is not identically zero.
pub(crate) fn assert_exact_isometry_hvp_collapses_to_gn_at_zero_residual(
    evaluator: Arc<dyn SaeBasisSecondJet>,
    kind: SaeAtomBasisKind,
    coords: Array2<f64>,
    p_out: usize,
    direction: Array2<f64>,
) {
    let (atom, penalty, target_flat) =
        build_isometry_atom_for_evaluator(evaluator, kind, &coords, p_out, 1.37);
    let rho = array![0.0_f64];
    let d = coords.ncols();
    refresh_isometry_caches_from_atom(&penalty, &atom, coords.view()).unwrap();

    let mut g_ref = penalty
        .pullback_metric(d)
        .expect("pullback metric is available after the cache refresh");
    let n_rows = g_ref.nrows();
    // Sum of the per-row metric diagonals, read at columns `axis * d + axis`.
    let trace_sum = (0..n_rows)
        .flat_map(|row| (0..d).map(move |axis| (row, axis * d + axis)))
        .fold(0.0_f64, |acc, (row, col)| acc + g_ref[[row, col]]);
    let normalizer = trace_sum / (n_rows * d) as f64;
    g_ref.mapv_inplace(|entry| entry / normalizer);

    let penalty = penalty.with_reference(IsometryReference::UserSupplied(Arc::new(g_ref)));
    assert!(
        penalty.third_decoder_derivative().is_some(),
        "zero-residual exact/GN test must still carry the real refreshed K cache"
    );

    let v: Array1<f64> = direction.iter().copied().collect();
    let exact = penalty.hvp(target_flat.view(), rho.view(), v.view());
    let gn = penalty.psd_majorizer_hvp(target_flat.view(), rho.view(), v.view());
    assert!(
        gn.iter().any(|x| x.abs() > 1.0e-8),
        "GN block should be nonzero so exact/GN equality is not vacuous"
    );
    for i in 0..exact.len() {
        assert_abs_diff_eq!(exact[i], gn[i], epsilon = 1.0e-10);
    }
}
/// Exact-HVP versus gradient finite-difference oracle on the sphere chart
/// (four points, four output columns).
#[test]
pub(crate) fn isometry_exact_hvp_sphere_matches_grad_fd_and_uses_refreshed_k() {
    let coords: Array2<f64> = array![[-0.61, 0.23], [-0.18, -1.07], [0.42, 0.81], [0.73, -0.39]];
    let direction: Array2<f64> =
        array![[0.31, -0.27], [-0.18, 0.22], [0.14, 0.19], [-0.25, -0.11]];
    assert_exact_isometry_hvp_matches_grad_fd(
        Arc::new(SphereChartEvaluator),
        SaeAtomBasisKind::Sphere,
        coords,
        4,
        direction,
    );
}
/// Exact-HVP versus gradient finite-difference oracle on the `(2, 2)` torus
/// harmonic basis (three points, three output columns).
#[test]
pub(crate) fn isometry_exact_hvp_torus_matches_grad_fd_and_uses_refreshed_k() {
    let coords: Array2<f64> = array![[0.13, 0.42], [0.66, 0.19], [0.88, 0.55]];
    let direction: Array2<f64> = array![[0.21, -0.16], [-0.24, 0.18], [0.13, 0.27]];
    assert_exact_isometry_hvp_matches_grad_fd(
        Arc::new(TorusHarmonicEvaluator::new(2, 2).unwrap()),
        SaeAtomBasisKind::Torus,
        coords,
        3,
        direction,
    );
}
/// Runs the zero-residual exact/GN agreement oracle first on the sphere chart
/// and then on the `(2, 2)` torus harmonic basis.
#[test]
pub(crate) fn isometry_exact_hvp_sphere_and_torus_collapse_to_gn_at_zero_residual() {
    {
        let coords: Array2<f64> = array![[-0.52, 0.17], [-0.11, -0.93], [0.39, 0.74]];
        let direction: Array2<f64> = array![[0.17, -0.21], [-0.13, 0.08], [0.22, 0.19]];
        assert_exact_isometry_hvp_collapses_to_gn_at_zero_residual(
            Arc::new(SphereChartEvaluator),
            SaeAtomBasisKind::Sphere,
            coords,
            4,
            direction,
        );
    }
    {
        let coords: Array2<f64> = array![[0.19, 0.31], [0.57, 0.73], [0.84, 0.12]];
        let direction: Array2<f64> = array![[0.11, -0.14], [-0.20, 0.07], [0.16, 0.23]];
        assert_exact_isometry_hvp_collapses_to_gn_at_zero_residual(
            Arc::new(TorusHarmonicEvaluator::new(2, 2).unwrap()),
            SaeAtomBasisKind::Torus,
            coords,
            3,
            direction,
        );
    }
}
/// Oracle for the isometry PSD majorizer.
///
/// Checks, in order:
/// * before any cache refresh, `psd_majorizer_hvp` on the first unit vector
///   is exactly zero;
/// * after the refresh, the normalized pullback metric differs measurably
///   from the identity, so the fixture carries a real residual;
/// * the dense matrix assembled column by column from unit vectors is
///   nonzero and symmetric to `1e-10`;
/// * for every probe `v` in `probes`, `vᵀBv >= -1e-9`.
pub(crate) fn assert_isometry_psd_majorizer_live_after_atom_refresh(
    evaluator: Arc<dyn SaeBasisSecondJet>,
    kind: SaeAtomBasisKind,
    coords: Array2<f64>,
    p_out: usize,
    probes: &[Array2<f64>],
) {
    let (atom, penalty, target_flat) =
        build_isometry_atom_for_evaluator(evaluator, kind, &coords, p_out, 0.53);
    let rho = array![0.0_f64];
    let n = target_flat.len();
    // Standard basis vector of length `n` with a one at `slot`.
    let basis_vector = |slot: usize| -> Array1<f64> {
        let mut e = Array1::<f64>::zeros(n);
        e[slot] = 1.0;
        e
    };

    let unit0 = basis_vector(0);
    let pre = penalty.psd_majorizer_hvp(target_flat.view(), rho.view(), unit0.view());
    assert!(
        pre.iter().all(|&x| x == 0.0),
        "psd_majorizer_hvp without a cache must be the zero block; got {pre:?}"
    );

    let installed = refresh_isometry_caches_from_atom(&penalty, &atom, coords.view()).unwrap();
    assert!(
        installed,
        "second-jet cache must install for the PSD-majorizer liveness test"
    );

    let d = coords.ncols();
    let g = penalty
        .pullback_metric(d)
        .expect("pullback metric available after refresh");
    let n_rows = g.nrows();
    let mut trace_sum = 0.0_f64;
    for row in 0..n_rows {
        for axis in 0..d {
            trace_sum += g[[row, axis * d + axis]];
        }
    }
    let normalizer = trace_sum / (n_rows * d) as f64;
    // L1 distance between the normalized metric and the identity reference.
    let mut residual_mass = 0.0_f64;
    for row in 0..n_rows {
        for a in 0..d {
            for b in 0..d {
                let identity_entry = if a == b { 1.0 } else { 0.0 };
                residual_mass += (g[[row, a * d + b]] / normalizer - identity_entry).abs();
            }
        }
    }
    assert!(
        residual_mass > 1.0e-3,
        "Euclidean-reference residual must be nonzero for a real curvature test; \
         got residual mass {residual_mass:.3e}"
    );

    // Assemble the dense majorizer one column per unit vector.
    let mut bmat = Array2::<f64>::zeros((n, n));
    for col_idx in 0..n {
        let e = basis_vector(col_idx);
        let column = penalty.psd_majorizer_hvp(target_flat.view(), rho.view(), e.view());
        for row_idx in 0..n {
            bmat[[row_idx, col_idx]] = column[row_idx];
        }
    }
    let max_abs = bmat.iter().map(|x| x.abs()).fold(0.0_f64, f64::max);
    assert!(
        max_abs > 1.0e-6,
        "isometry GN majorizer must be nonzero for a non-Duchon basis after refresh; \
         max |B| = {max_abs:.3e}"
    );
    for ((r, c), &entry) in bmat.indexed_iter() {
        assert_abs_diff_eq!(entry, bmat[[c, r]], epsilon = 1.0e-10);
    }

    for probe in probes {
        let v: Array1<f64> = probe.iter().copied().collect();
        assert_eq!(v.len(), n, "probe must match the flattened target length");
        let bv = penalty.psd_majorizer_hvp(target_flat.view(), rho.view(), v.view());
        let quad = v.dot(&bv);
        assert!(
            quad >= -1.0e-9,
            "isometry GN majorizer must be PSD; got vᵀBv = {quad:.3e}"
        );
    }
}
/// PSD-majorizer liveness oracle on the sphere chart with three probe
/// directions.
#[test]
pub(crate) fn isometry_psd_majorizer_live_after_sphere_refresh() {
    let coords: Array2<f64> = array![[-0.61, 0.23], [-0.18, -1.07], [0.42, 0.81]];
    let probes: [Array2<f64>; 3] = [
        array![[0.31, -0.27], [-0.18, 0.22], [0.14, 0.19]],
        array![[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]],
        array![[-2.3, 0.6], [-0.1, 1.4], [0.8, -1.7]],
    ];
    assert_isometry_psd_majorizer_live_after_atom_refresh(
        Arc::new(SphereChartEvaluator),
        SaeAtomBasisKind::Sphere,
        coords,
        4,
        &probes,
    );
}
/// PSD-majorizer liveness oracle on the 5-harmonic periodic basis with three
/// probe directions.
#[test]
pub(crate) fn isometry_psd_majorizer_live_after_circle_refresh() {
    let coords: Array2<f64> = array![[0.12], [0.37], [0.58], [0.81]];
    let probes: [Array2<f64>; 3] = [
        array![[0.4], [-1.1], [0.7], [0.3]],
        array![[1.0], [1.0], [1.0], [1.0]],
        array![[-2.3], [0.6], [-0.1], [1.4]],
    ];
    assert_isometry_psd_majorizer_live_after_atom_refresh(
        Arc::new(PeriodicHarmonicEvaluator::new(5).unwrap()),
        SaeAtomBasisKind::Periodic,
        coords,
        3,
        &probes,
    );
}
/// PSD-majorizer liveness oracle on the `(2, 2)` torus harmonic basis with
/// three probe directions.
#[test]
pub(crate) fn isometry_psd_majorizer_live_after_torus_refresh() {
    let coords: Array2<f64> = array![[0.13, 0.42], [0.66, 0.19], [0.88, 0.55]];
    let probes: [Array2<f64>; 3] = [
        array![[0.21, -0.16], [-0.24, 0.18], [0.13, 0.27]],
        array![[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]],
        array![[-1.2, 0.5], [0.3, -0.9], [0.7, 0.2]],
    ];
    assert_isometry_psd_majorizer_live_after_atom_refresh(
        Arc::new(TorusHarmonicEvaluator::new(2, 2).unwrap()),
        SaeAtomBasisKind::Torus,
        coords,
        3,
        &probes,
    );
}
/// Regression test: `refresh_isometry_caches_from_term` must refresh the
/// i-th isometry penalty in the registry against the i-th atom of the term.
///
/// Two atoms with different coordinates and decoder seeds are built. For each
/// one a standalone "control" penalty is refreshed directly against its atom
/// to record the expected Jacobian cache. Two fresh penalties are then put in
/// a registry and refreshed through the term-level entry point; each must end
/// up with its own atom's cache, and the two caches must differ.
#[test]
pub(crate) fn refresh_isometry_caches_pairs_each_penalty_to_its_own_atom() {
    let latent_dim = 1usize;
    let p_out = 3usize;
    let evaluator = Arc::new(PeriodicHarmonicEvaluator::new(5).unwrap());
    let coords0 = array![[0.05], [0.20], [0.55], [0.80]];
    let coords1 = array![[0.13], [0.41], [0.62], [0.91]];

    // Builds a periodic atom whose decoder depends on `seed`, so the two
    // atoms produce distinguishable Jacobian caches.
    let build_atom = |name: &str, coords: &Array2<f64>, seed: f64| {
        let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();
        let m = phi.ncols();
        let mut decoder = Array2::<f64>::zeros((m, p_out));
        for i in 0..m {
            for j in 0..p_out {
                let x = (i as f64) * 0.371 + (j as f64) * 0.193 + seed;
                decoder[[i, j]] = (x.sin() * 0.9) + 0.1 * ((i + j) as f64).cos();
            }
        }
        let smooth = Array2::<f64>::eye(m);
        SaeManifoldAtom::new(
            name,
            SaeAtomBasisKind::Periodic,
            latent_dim,
            phi,
            jet,
            decoder,
            smooth,
        )
        .unwrap()
        .with_basis_second_jet(evaluator.clone() as Arc<dyn SaeBasisSecondJet>)
    };
    let atom0 = build_atom("atom0", &coords0, 0.5);
    let atom1 = build_atom("atom1", &coords1, 1.7);

    // Control penalties refreshed directly against their own atom.
    let slice0 = PsiSlice::full(coords0.nrows() * latent_dim, Some(latent_dim));
    let control0 = IsometryPenalty::new_euclidean(slice0, p_out);
    refresh_isometry_caches_from_atom(&control0, &atom0, coords0.view()).unwrap();
    let expected0 = control0
        .jacobian_cache()
        .expect("control penalty 0 must have a Jacobian cache");
    let slice1 = PsiSlice::full(coords1.nrows() * latent_dim, Some(latent_dim));
    let control1 = IsometryPenalty::new_euclidean(slice1, p_out);
    refresh_isometry_caches_from_atom(&control1, &atom1, coords1.view()).unwrap();
    let expected1 = control1
        .jacobian_cache()
        .expect("control penalty 1 must have a Jacobian cache");
    assert_ne!(
        *expected0, *expected1,
        "atom 0 and atom 1 must produce distinct Jacobian caches"
    );

    let logits = Array2::<f64>::zeros((coords0.nrows(), 2));
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        logits,
        vec![coords0.clone(), coords1.clone()],
        vec![
            LatentManifold::Circle { period: 1.0 },
            LatentManifold::Circle { period: 1.0 },
        ],
        AssignmentMode::ibp_map(0.7, 1.0, true),
    )
    .unwrap();
    let term = SaeManifoldTerm::new(vec![atom0, atom1], assignment).unwrap();

    // Registry under test: one fresh penalty per atom, in atom order.
    let mut registry = AnalyticPenaltyRegistry::new();
    let pslice0 = PsiSlice::full(coords0.nrows() * latent_dim, Some(latent_dim));
    let pslice1 = PsiSlice::full(coords1.nrows() * latent_dim, Some(latent_dim));
    registry.push(AnalyticPenaltyKind::Isometry(Arc::new(
        IsometryPenalty::new_euclidean(pslice0, p_out),
    )));
    registry.push(AnalyticPenaltyKind::Isometry(Arc::new(
        IsometryPenalty::new_euclidean(pslice1, p_out),
    )));
    let coords_per_atom = vec![coords0.clone(), coords1.clone()];
    // The registry is passed by shared reference; the penalties are read back
    // from the same registry below.
    let refreshed = refresh_isometry_caches_from_term(&registry, &term, &coords_per_atom).unwrap();
    assert_eq!(refreshed, 2, "both penalties should install second caches");

    let cache0 = match &registry.penalties[0] {
        AnalyticPenaltyKind::Isometry(p) => p
            .jacobian_cache()
            .expect("penalty 0 cache must be populated"),
        _ => panic!("expected isometry penalty at index 0"),
    };
    let cache1 = match &registry.penalties[1] {
        AnalyticPenaltyKind::Isometry(p) => p
            .jacobian_cache()
            .expect("penalty 1 cache must be populated"),
        _ => panic!("expected isometry penalty at index 1"),
    };
    assert_eq!(
        *cache0, *expected0,
        "penalty 0 must be refreshed against atom 0"
    );
    assert_eq!(
        *cache1, *expected1,
        "penalty 1 must be refreshed against atom 1 (regression: old find() paired it to atom 0)"
    );
    assert_ne!(
        *cache0, *cache1,
        "the two penalties must not collapse onto the same atom"
    );
}
/// Builds a small single-atom outer objective: one periodic atom over four
/// rows, softmax assignment, one output column, and zero log-hyperparameters.
pub(crate) fn warmstart_test_objective() -> SaeManifoldOuterObjective {
    let latent = array![[0.10], [0.35], [0.62], [0.88]];
    let (basis, basis_jet) = periodic_basis(&latent);
    let periodic_atom = SaeManifoldAtom::new(
        "periodic",
        SaeAtomBasisKind::Periodic,
        1,
        basis,
        basis_jet,
        array![[0.30], [-0.20], [0.15]],
        Array2::<f64>::eye(3),
    )
    .unwrap();
    let softmax_assignment = SaeAssignment::from_blocks_with_mode(
        array![[0.9_f64], [0.8], [0.7], [0.6]],
        vec![latent],
        AssignmentMode::softmax(0.7),
    )
    .unwrap();
    let term = SaeManifoldTerm::new(vec![periodic_atom], softmax_assignment).unwrap();
    let observations = array![[0.20_f64], [-0.10], [0.30], [0.05]];
    let initial_rho = SaeManifoldRho::new(0.0, 0.0, vec![Array1::<f64>::zeros(1)]);
    SaeManifoldOuterObjective::new(
        term,
        observations,
        None,
        initial_rho,
        8,
        1.0,
        1.0e-6,
        1.0e-6,
    )
}
/// Hand-built factor cache with one 2x2 latent block whose diagonal entries
/// are `1.0` and `1e-7`, plus a 1x1 Schur factor. No ridge, no gauge
/// deflation, and no cross-row Woodbury data are recorded.
pub(crate) fn near_singular_outer_gradient_cache() -> ArrowFactorCache {
    let latent_block = array![[1.0_f64, 0.0], [0.0, 1.0e-7]];
    let block_dims = vec![2usize];
    let block_offsets = vec![0usize, 2usize];
    ArrowFactorCache {
        htt_factors: ArrowFactorSlab::from_blocks(vec![latent_block]),
        htt_factors_undamped: ArrowUndampedFactors::SameAsDamped,
        schur_factor: Some(array![[1.0_f64]]),
        solver_mode: ArrowSolverMode::Direct,
        ridge_t: 0.0,
        ridge_beta: 0.0,
        htbeta: ArrowHtbetaCache::Disabled { estimated_bytes: 0 },
        d: 2,
        row_dims: Arc::from(block_dims.into_boxed_slice()),
        row_offsets: Arc::from(block_offsets.into_boxed_slice()),
        k: 1,
        manifold_mode_fingerprint: 0,
        row_hessian_fingerprint: 0,
        pcg_diagnostics: PcgDiagnostics::default(),
        gauge_deflated_directions: 0,
        cross_row_woodbury: None,
    }
}
/// Builds a latent-only factor cache (no Schur factor, `k = 0`) with a single
/// block whose factor is diagonal and holds the square roots of `diagonal`.
pub(crate) fn diagonal_latent_cache(diagonal: &[f64]) -> ArrowFactorCache {
    let dim = diagonal.len();
    let mut factor = Array2::<f64>::zeros((dim, dim));
    for (slot, &entry) in diagonal.iter().enumerate() {
        factor[[slot, slot]] = entry.sqrt();
    }
    let block_dims = vec![dim];
    let block_offsets = vec![0usize, dim];
    ArrowFactorCache {
        htt_factors: ArrowFactorSlab::from_blocks(vec![factor]),
        htt_factors_undamped: ArrowUndampedFactors::SameAsDamped,
        schur_factor: None,
        solver_mode: ArrowSolverMode::Direct,
        ridge_t: 0.0,
        ridge_beta: 0.0,
        htbeta: ArrowHtbetaCache::Disabled { estimated_bytes: 0 },
        d: dim,
        row_dims: Arc::from(block_dims.into_boxed_slice()),
        row_offsets: Arc::from(block_offsets.into_boxed_slice()),
        k: 0,
        manifold_mode_fingerprint: 0,
        row_hessian_fingerprint: 0,
        pcg_diagnostics: PcgDiagnostics::default(),
        gauge_deflated_directions: 0,
        cross_row_woodbury: None,
    }
}
/// The outer-gradient solver must refuse the near-singular fixture cache, and
/// its error text must name the undefined-gradient condition together with
/// the pivot ratio and the floor.
#[test]
pub(crate) fn outer_gradient_solver_rejects_near_singular_cache_without_matching_gauge() {
    let cache = near_singular_outer_gradient_cache();
    let obj = warmstart_test_objective();
    let err = obj
        .term
        .outer_gradient_arrow_solver(&cache)
        .err()
        .unwrap_or_else(|| {
            panic!("near-singular evidence factor without a matching gauge must reject")
        });
    let names_condition = err.contains("analytic outer gradient undefined at this rho");
    assert!(
        names_condition,
        "guard error must name the undefined analytic-gradient condition; got: {err}"
    );
    let reports_pivot = err.contains("min/max pivot ratio") && err.contains("floor");
    assert!(
        reports_pivot,
        "guard error must report the pivot ratio and floor; got: {err}"
    );
}
/// Builds an outer objective around one Euclidean-patch atom with an affine
/// basis `[1, t]`. The two decoder columns are proportional to each other
/// (second = 2 x first), and so are the two target columns.
pub(crate) fn rank_deficient_euclidean_outer_gradient_objective() -> SaeManifoldOuterObjective {
    let coords = array![[-0.7_f64], [-0.2], [0.3], [0.8]];
    let n = coords.nrows();
    // Basis column 0 is the constant, column 1 is the coordinate itself.
    let phi = Array2::<f64>::from_shape_fn((n, 2), |(row, basis)| {
        if basis == 0 { 1.0 } else { coords[[row, 0]] }
    });
    // Only the linear basis function has a nonzero first derivative.
    let jet = Array3::<f64>::from_shape_fn((n, 2, 1), |(_, basis, _)| {
        if basis == 1 { 1.0 } else { 0.0 }
    });
    let atom = SaeManifoldAtom::new(
        "euclidean_line",
        SaeAtomBasisKind::EuclideanPatch,
        1,
        phi,
        jet,
        array![[1.0_f64, 2.0], [0.5, 1.0]],
        Array2::<f64>::eye(2),
    )
    .unwrap();
    let assignment = SaeAssignment::from_blocks_with_mode(
        array![[0.9_f64], [0.8], [0.7], [0.6]],
        vec![coords],
        AssignmentMode::softmax(0.7),
    )
    .unwrap();
    let term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();
    let target = array![[-1.0_f64, -2.0], [-0.3, -0.6], [0.4, 0.8], [1.1, 2.2]];
    let rho = SaeManifoldRho::new(0.0, 0.0, vec![Array1::<f64>::zeros(1)]);
    SaeManifoldOuterObjective::new(term, target, None, rho, 8, 1.0, 1.0e-6, 1.0e-6)
}
/// Hand-built factor cache with a 1x1 latent block and a 4x4 diagonal Schur
/// factor whose entries alternate between `1.0` and `1e-7`. The latent/beta
/// coupling block is stored densely as a `1 x 4` zero matrix.
pub(crate) fn rank_deficient_beta_outer_gradient_cache() -> ArrowFactorCache {
    let mut schur = Array2::<f64>::zeros((4, 4));
    for (slot, &pivot) in [1.0_f64, 1.0e-7, 1.0, 1.0e-7].iter().enumerate() {
        schur[[slot, slot]] = pivot;
    }
    let coupling_blocks = vec![Array2::<f64>::zeros((1, 4))];
    let block_dims = vec![1usize];
    let block_offsets = vec![0usize, 1usize];
    ArrowFactorCache {
        htt_factors: ArrowFactorSlab::from_blocks(vec![array![[1.0_f64]]]),
        htt_factors_undamped: ArrowUndampedFactors::SameAsDamped,
        schur_factor: Some(schur),
        solver_mode: ArrowSolverMode::Direct,
        ridge_t: 0.0,
        ridge_beta: 0.0,
        htbeta: ArrowHtbetaCache::Dense {
            blocks: Arc::from(coupling_blocks.into_boxed_slice()),
            estimated_bytes: 0,
        },
        d: 1,
        row_dims: Arc::from(block_dims.into_boxed_slice()),
        row_offsets: Arc::from(block_offsets.into_boxed_slice()),
        k: 4,
        manifold_mode_fingerprint: 0,
        row_hessian_fingerprint: 0,
        pcg_diagnostics: PcgDiagnostics::default(),
        gauge_deflated_directions: 0,
        cross_row_woodbury: None,
    }
}
/// With the rank-deficient decoder fixture, the outer-gradient solver must be
/// constructed rather than rejected. A right-hand side on the last beta slot
/// then yields a huge entry from the plain cache solve but a bounded, finite
/// result from the returned solver.
#[test]
pub(crate) fn outer_gradient_solver_deflates_rank_deficient_decoder_beta_null() {
    let obj = rank_deficient_euclidean_outer_gradient_objective();
    let cache = rank_deficient_beta_outer_gradient_cache();
    assert!(
        SaeManifoldTerm::outer_gradient_conditioning_error(&cache).is_err(),
        "fixture must be sub-floor singular so the conditioning path engages"
    );
    let solver = obj
        .term
        .outer_gradient_arrow_solver(&cache)
        .expect("rank-deficient decoder β-null must be deflated, not rejected (#1051)");

    // Unit right-hand side on the last beta coordinate, zero on the latent part.
    let beta_null_rhs = array![0.0_f64, 0.0, 0.0, 1.0];
    let rhs_t = Array1::<f64>::zeros(1);

    let (_, plain) = cache
        .full_inverse_apply(rhs_t.view(), beta_null_rhs.view())
        .expect("plain solve");
    let deflated = solver
        .solve(rhs_t.view(), beta_null_rhs.view())
        .expect("deflated solve")
        .beta;

    assert!(
        plain[3].abs() > 1.0e13,
        "plain near-null β solve must explode; got {}",
        plain[3]
    );
    assert!(
        deflated.iter().all(|v| v.is_finite()) && deflated[3].abs() < 10.0,
        "deflated near-null β solve must be bounded at the Hessian scale; got {deflated:?}"
    );
}
/// With no gauge installed, `DeflatedArrowSolver::plain` must reproduce the cache's own
/// `full_inverse_apply` exactly (zero tolerance) on both the latent and β parts.
#[test]
pub(crate) fn deflated_solver_matches_plain_solve_when_no_gauge_is_installed() {
    let cache = diagonal_latent_cache(&[2.0_f64, 5.0, 7.0]);
    let rhs_t = array![4.0_f64, 10.0, -14.0];
    // The fixture is solved with an empty β right-hand side.
    let rhs_beta = Array1::<f64>::zeros(0);

    let adapter = DeflatedArrowSolver::plain(&cache);
    let (plain_t, plain_beta) = cache
        .full_inverse_apply(rhs_t.view(), rhs_beta.view())
        .expect("plain cache solve");
    let solved = adapter
        .solve(rhs_t.view(), rhs_beta.view())
        .expect("adapter solve");

    assert_eq!(solved.t.len(), plain_t.len());
    for (idx, &reference) in plain_t.iter().enumerate() {
        assert_abs_diff_eq!(solved.t[idx], reference, epsilon = 0.0);
    }
    assert_eq!(solved.beta.len(), plain_beta.len());
    for (idx, &reference) in plain_beta.iter().enumerate() {
        assert_abs_diff_eq!(solved.beta[idx], reference, epsilon = 0.0);
    }
}
/// On a two-row diagonal cache whose second entry is `1e-14`, a solver built with the gauge
/// `e_1` and the value `2.0` must (a) leave the physical direction untouched and
/// (b) answer a pure-gauge right-hand side with `0.5` instead of the huge plain solution.
#[test]
pub(crate) fn deflated_solver_matches_dense_quotient_pseudoinverse_on_near_null_fixture() {
    let cache = diagonal_latent_cache(&[2.0_f64, 1.0e-14]);
    let solver =
        DeflatedArrowSolver::from_orthonormal_gauges(&cache, vec![array![0.0_f64, 1.0]], 2.0)
            .expect("deflated solver");
    let rhs_beta = Array1::<f64>::zeros(0);

    // Physical direction: 4 / 2 = 2 on the first coordinate, nothing on the gauge one.
    let physical_rhs = array![4.0_f64, 0.0];
    let oracle = array![2.0_f64, 0.0];
    let solved = solver
        .solve(physical_rhs.view(), rhs_beta.view())
        .expect("physical solve");
    for (idx, &expected) in oracle.iter().enumerate() {
        assert_abs_diff_eq!(solved.t[idx], expected, epsilon = 1.0e-12);
    }

    // Gauge direction: the plain inverse explodes, the stiffened one does not.
    let gauge_rhs = array![0.0_f64, 1.0];
    let (plain, _) = cache
        .full_inverse_apply(gauge_rhs.view(), rhs_beta.view())
        .expect("plain gauge solve");
    let stiffened_solution = solver
        .solve(gauge_rhs.view(), rhs_beta.view())
        .expect("stiffened gauge solve");
    let stiffened = stiffened_solution.t;
    assert!(plain[1] > 1.0e13, "plain near-null solve must be huge");
    assert_abs_diff_eq!(stiffened[1], 0.5, epsilon = 1.0e-12);
}
/// Seeding with a zero-length β must succeed and report `SeedOutcome::NoSlot`.
#[test]
pub(crate) fn seed_inner_state_accepts_empty_beta_as_noslot() {
    let empty = Array1::<f64>::zeros(0);
    let mut obj = warmstart_test_objective();
    let outcome = obj
        .seed_inner_state(&empty)
        .expect("empty-β seed must be accepted as a no-op, not rejected (gam#577/#579)");
    let is_noslot = matches!(outcome, SeedOutcome::NoSlot);
    assert!(
        is_noslot,
        "empty-β seed must report NoSlot (proceed cold); got {outcome:?}"
    );
}
/// A seed of the decoder dimension must be reported as `Installed`, and with
/// `inner_max_iter` set to 0 the published `inner_beta_hint` must equal the seed
/// coordinate-by-coordinate.
#[test]
pub(crate) fn seed_inner_state_installs_and_reuses_matching_beta() {
    let mut obj = warmstart_test_objective();
    let dim = obj.term.beta_dim();
    let pristine = obj.term.flatten_beta();
    // Shift every coordinate so the seed is distinguishable from the untouched β.
    let seed: Array1<f64> = Array1::from_shape_fn(dim, |i| pristine[i] + 0.5 + 0.01 * (i as f64));
    let offset = &seed - &pristine;
    assert!(
        offset.iter().any(|d| d.abs() > 1e-6),
        "seed must differ from the pristine β for the reuse check to be meaningful"
    );

    let outcome = obj
        .seed_inner_state(&seed)
        .expect("a length-matching β must install");
    let installed = matches!(outcome, SeedOutcome::Installed);
    assert!(
        installed,
        "matching β must report Installed; got {outcome:?}"
    );

    obj.inner_max_iter = 0;
    let rho_flat = obj.baseline_rho.to_flat();
    let eval =
        OuterObjective::eval(&mut obj, &rho_flat).expect("eval at the warm-started β must succeed");
    let hint = eval
        .inner_beta_hint
        .expect("the SAE objective must publish inner_beta_hint for continuation reuse");
    assert_eq!(
        hint.len(),
        dim,
        "published hint must have decoder dimension"
    );
    for (i, (&h, &s)) in hint.iter().zip(seed.iter()).enumerate() {
        let gap = (h - s).abs();
        assert!(
            gap < 1e-12,
            "warm-started β must be reused verbatim by the inner solve at coord {i}: \
             hint {h} != seed {s} (gam#577/#579)"
        );
    }
}
/// A non-empty seed whose length is `beta_dim() + 1` must be rejected with a
/// `RemlOptimizationFailed` error whose message mentions "decoder dim".
#[test]
pub(crate) fn seed_inner_state_rejects_wrong_length_populated_beta() {
    let mut obj = warmstart_test_objective();
    let oversized = Array1::<f64>::zeros(obj.term.beta_dim() + 1);
    let err = obj
        .seed_inner_state(&oversized)
        .expect_err("a populated β of the wrong length must be rejected");
    match err {
        EstimationError::RemlOptimizationFailed(msg) => {
            let names_mismatch = msg.contains("decoder dim");
            assert!(
                names_mismatch,
                "error must name the decoder-dim mismatch; got: {msg}"
            );
        }
        other => panic!("expected RemlOptimizationFailed, got {other:?}"),
    }
}
/// Builds a 1-D `EuclideanPatch` atom with 5 basis functions on 5 observations.
///
/// The basis values are the identity, the decoder is a column of ones, and the basis
/// derivative on the diagonal is `jacobian_scale * (1 + index)`, i.e. a speed that varies
/// along the curve and scales globally with `jacobian_scale`. The raw penalty is the
/// order-2 difference penalty.
pub(crate) fn intrinsic_test_atom(jacobian_scale: f64) -> SaeManifoldAtom {
    let basis_size = 5usize;
    let n_rows = basis_size;
    let out_dim = 1usize;
    let phi = Array2::<f64>::eye(n_rows);
    let jet = Array3::<f64>::from_shape_fn((n_rows, basis_size, 1), |(row, col, _)| {
        if row == col {
            jacobian_scale * (1.0 + row as f64)
        } else {
            0.0
        }
    });
    let decoder = Array2::<f64>::from_elem((basis_size, out_dim), 1.0);
    let raw_penalty = crate::basis::create_difference_penalty_matrix(basis_size, 2, None).unwrap();
    SaeManifoldAtom::new(
        "intrinsic-1d",
        SaeAtomBasisKind::EuclideanPatch,
        1,
        phi,
        jet,
        decoder,
        raw_penalty,
    )
    .unwrap()
}
/// The atom built from an order-2 difference penalty must report `smooth_penalty_order == 2`.
#[test]
pub(crate) fn intrinsic_penalty_recovers_order_two_from_nullity() {
    let recovered_order = intrinsic_test_atom(1.0).smooth_penalty_order;
    assert_eq!(recovered_order, 2);
}
/// Snapshot/restore must bring back the intrinsic smoothness Gram: after perturbing the
/// decoder and refreshing the penalty, `restore_mutable_state` has to return
/// `smooth_penalty` to its pre-perturbation value.
#[test]
pub(crate) fn line_search_snapshot_restores_intrinsic_smooth_penalty() {
    let atom = intrinsic_test_atom(1.0);
    let n = atom.n_obs();
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((n, 1)),
        vec![Array2::<f64>::zeros((n, 1))],
        vec![LatentManifold::Euclidean],
        AssignmentMode::softmax(1.0),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();

    let reference_gram = term.atoms[0].smooth_penalty.clone();
    let snapshot = term.snapshot_mutable_state();

    // Perturb the live state so the refreshed Gram differs from the reference.
    term.atoms[0].decoder_coefficients[[0, 0]] *= 3.0;
    term.atoms[0].refresh_intrinsic_smooth_penalty();
    let perturbation = &term.atoms[0].smooth_penalty - &reference_gram;
    let changed = perturbation.mapv(f64::abs).sum();
    assert!(
        changed > 1e-6,
        "test setup must perturb the live intrinsic smoothness Gram"
    );

    term.restore_mutable_state(&snapshot);
    let residual = &term.atoms[0].smooth_penalty - &reference_gram;
    let restored = residual.mapv(f64::abs).sum();
    assert!(
        restored < 1e-12,
        "line-search restore left a stale intrinsic smoothness Gram: {restored}"
    );
}
/// Scaling the basis Jacobian globally (1.0 vs 7.5) must leave both the raw penalty and
/// the intrinsic smoothness Gram unchanged.
#[test]
pub(crate) fn intrinsic_penalty_is_invariant_to_speed_rescaling() {
    let unit_speed = intrinsic_test_atom(1.0);
    let fast_speed = intrinsic_test_atom(7.5);

    let raw_gap = (&unit_speed.smooth_penalty_raw - &fast_speed.smooth_penalty_raw)
        .mapv(f64::abs)
        .sum();
    assert_abs_diff_eq!(raw_gap, 0.0, epsilon = 1e-12);

    let gram_delta = &unit_speed.smooth_penalty - &fast_speed.smooth_penalty;
    let diff = gram_delta.mapv(f64::abs).sum();
    assert!(
        diff < 1e-9,
        "intrinsic Gram changed under a global speed rescale (gauge leak): {diff}"
    );
}
pub(crate) fn affine_canonicalization_test_term() -> SaeManifoldTerm {
let n = 80usize;
let p = 2usize;
let evaluator = EuclideanPatchEvaluator::new(1, 2).unwrap();
let mut coords = Array2::<f64>::zeros((n, 1));
for row in 0..n {
coords[[row, 0]] = -4.0 + 12.0 * row as f64 / (n as f64 - 1.0);
}
let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();
let mut decoder = Array2::<f64>::zeros((3, p));
decoder[[0, 0]] = 0.8;
decoder[[1, 0]] = -0.4;
decoder[[2, 0]] = 0.15;
decoder[[0, 1]] = -0.2;
decoder[[1, 1]] = 0.9;
decoder[[2, 1]] = -0.08;
let smooth_penalty = crate::basis::create_difference_penalty_matrix(3, 2, None).unwrap();
let atom = SaeManifoldAtom::new(
"affine-canonicalization",
SaeAtomBasisKind::EuclideanPatch,
1,
phi,
jet,
decoder,
smooth_penalty,
)
.unwrap()
.with_basis_second_jet(Arc::new(evaluator));
let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
Array2::<f64>::zeros((n, 1)),
vec![coords],
vec![LatentManifold::Euclidean],
AssignmentMode::softmax(1.0),
)
.unwrap();
SaeManifoldTerm::new(vec![atom], assignment).unwrap()
}
/// Canonicalizing the affine gauge must keep the fixed-rho smoothness energy (relative
/// gap below 1e-9), whereas recomputing the intrinsic penalty afterwards would move it by
/// more than 1e-2. The test also checks that two ill-shaped transport requests are
/// rejected rather than silently accepted.
#[test]
pub(crate) fn affine_canonicalization_transports_live_penalty_instead_of_recomputing() {
    let mut term = affine_canonicalization_test_term();
    let before = term.decoder_smoothness_quadratic_form();
    let old_smooth_penalty = term.atoms[0].smooth_penalty.clone();
    let old_decoder = term.atoms[0].decoder_coefficients.clone();

    term.canonicalize_atom_affine_gauge(0, None).unwrap();
    let after = term.decoder_smoothness_quadratic_form();
    let scale = before.abs().max(1.0);
    let invariant_gap = (after - before).abs() / scale;
    assert!(
        invariant_gap < 1.0e-9,
        "canonicalization changed fixed-rho smoothness energy: before={before:.12e}, after={after:.12e}"
    );

    // Counterfactual: refresh the intrinsic penalty on a copy of the canonicalized atom.
    let mut recomputed_atom = term.atoms[0].clone();
    recomputed_atom.refresh_intrinsic_smooth_penalty();
    let recomputed_assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((term.n_obs(), 1)),
        vec![term.assignment.coords[0].as_matrix()],
        vec![LatentManifold::Euclidean],
        AssignmentMode::softmax(1.0),
    )
    .unwrap();
    let recomputed_term =
        SaeManifoldTerm::new(vec![recomputed_atom], recomputed_assignment).unwrap();
    let recomputed = recomputed_term.decoder_smoothness_quadratic_form();
    let recompute_jump = (recomputed - before).abs() / scale;
    assert!(
        recompute_jump > 1.0e-2,
        "test fixture failed to expose the intrinsic recompute energy jump: before={before:.12e}, recomputed={recomputed:.12e}"
    );

    // Basis values against the old penalty have mismatched shapes: must be an error.
    let transport =
        solve_basis_transport(term.atoms[0].basis_values.view(), old_smooth_penalty.view())
            .expect_err("shape mismatch must reject invalid transport solve");
    let known_diagnostic = transport.contains("row mismatch") || transport.contains("SVD failed");
    assert!(
        known_diagnostic,
        "unexpected transport-shape diagnostic: {transport}"
    );

    // A decoder-to-decoder least-squares map is not a square basis transport.
    let decoder_map = solve_design_least_squares(
        term.atoms[0].decoder_coefficients.view(),
        old_decoder.view(),
    )
    .unwrap_or_else(|err| panic!("decoder transport fixture became singular: {err}"));
    let roundtrip =
        transport_smooth_penalty_for_decoder(decoder_map.view(), old_smooth_penalty.view());
    assert!(
        roundtrip.is_err(),
        "non-square decoder transport must not be accepted as a penalty congruence"
    );
}
/// On the varying-speed fixture the intrinsic Gram must differ from the raw penalty, and
/// it must stay symmetric to 1e-12.
#[test]
pub(crate) fn intrinsic_penalty_differs_from_raw_under_varying_speed() {
    let atom = intrinsic_test_atom(1.0);
    let reweighting = &atom.smooth_penalty - &atom.smooth_penalty_raw;
    let diff = reweighting.mapv(f64::abs).sum();
    assert!(
        diff > 1e-6,
        "intrinsic reweighting was a no-op on a non-constant-speed curve: {diff}"
    );
    // Symmetry check over every (i, j) pair.
    let size = atom.basis_size();
    for i in 0..size {
        for j in 0..size {
            let upper = atom.smooth_penalty[[i, j]];
            let lower = atom.smooth_penalty[[j, i]];
            assert_abs_diff_eq!(upper, lower, epsilon = 1e-12);
        }
    }
}
/// When every diagonal basis derivative equals 2.0 (constant speed), the intrinsic Gram
/// must coincide with the raw penalty up to 1e-9.
#[test]
pub(crate) fn intrinsic_penalty_leaves_constant_speed_atom_unchanged() {
    let basis_size = 6usize;
    let n_rows = basis_size;
    let phi = Array2::<f64>::eye(n_rows);
    let jet = Array3::<f64>::from_shape_fn((n_rows, basis_size, 1), |(row, col, _)| {
        if row == col { 2.0 } else { 0.0 }
    });
    let decoder = Array2::<f64>::from_elem((basis_size, 1), 1.0);
    let raw_penalty = crate::basis::create_difference_penalty_matrix(basis_size, 2, None).unwrap();
    let atom = SaeManifoldAtom::new(
        "constant-speed",
        SaeAtomBasisKind::EuclideanPatch,
        1,
        phi,
        jet,
        decoder,
        raw_penalty,
    )
    .unwrap();
    let reweighting = &atom.smooth_penalty - &atom.smooth_penalty_raw;
    let diff = reweighting.mapv(f64::abs).sum();
    assert!(
        diff < 1e-9,
        "constant-speed atom's penalty was reweighted (should be identity): {diff}"
    );
}
/// Two identical rows of `1e308` must still yield finite PCA seed coordinates of shape
/// `(1, 2, 1)`.
#[test]
pub(crate) fn pca_seed_handles_huge_equal_finite_columns_without_mean_overflow() {
    let huge = 1.0e308_f64;
    let z = array![[huge, huge], [huge, huge]];
    let coords =
        sae_pca_seed_initial_coords(z.view(), &[SaeAtomBasisKind::Periodic], &[1]).unwrap();
    assert_eq!(coords.dim(), (1, 2, 1));
    let all_finite = coords.iter().all(|value| value.is_finite());
    assert!(
        all_finite,
        "huge finite equal columns must not overflow the PCA seed mean: {coords:?}"
    );
}
/// Opposite-signed `1e308` entries in one column must make the PCA seed fail with one of
/// the two recognised diagnostics.
#[test]
pub(crate) fn pca_seed_rejects_huge_finite_span_that_overflows_centering() {
    let z = array![[1.0e308_f64, 0.0], [-1.0e308, 0.0]];
    let seeded = sae_pca_seed_initial_coords(z.view(), &[SaeAtomBasisKind::Periodic], &[1]);
    let err = seeded.expect_err("opposite huge finite values exceed f64 centering range");
    let recognised = err.contains("centered Z is non-finite") || err.contains("SVD failed");
    assert!(recognised, "unexpected PCA seed error: {err}");
}
/// A planted rank-3 frame in 12 dimensions must be recovered with zero principal angle,
/// both from noiseless targets and by a polar update of the frame itself, and the polar
/// frame must have an identity Gram matrix.
#[test]
pub(crate) fn planted_low_rank_frame_recovered_by_polar() {
    let p = 12usize;
    let r = 3usize;
    let n = 200usize;
    // First r coordinate axes of R^p.
    let planted =
        Array2::<f64>::from_shape_fn((p, r), |(row, col)| if row == col { 1.0 } else { 0.0 });
    // Deterministic pseudo-random coordinates in [-0.5, 0.5).
    let coords = Array2::<f64>::from_shape_fn((n, r), |(i, j)| {
        ((i * 7 + j * 13 + 1) % 97) as f64 / 97.0 - 0.5
    });
    let targets = fast_abt(&coords, &planted);

    let angle = grassmann_recover_planted_span_angle(targets.view(), coords.view(), planted.view())
        .expect("span recovery");
    assert_abs_diff_eq!(angle, 0.0, epsilon = 1.0e-9);

    let frame = GrassmannFrame::polar_update(planted.view()).expect("polar");
    let recovered_angle = frame
        .max_principal_angle(planted.view())
        .expect("principal angle");
    assert_abs_diff_eq!(recovered_angle, 0.0, epsilon = 1.0e-9);

    // Orthonormality: frameᵀ·frame must be the identity.
    let left = frame.frame().to_owned();
    let right = frame.frame().to_owned();
    let gram = fast_atb(&left, &right);
    for i in 0..r {
        for j in 0..r {
            let expect = if i == j { 1.0 } else { 0.0 };
            assert_abs_diff_eq!(gram[[i, j]], expect, epsilon = 1.0e-9);
        }
    }
}
/// For a rank-2 decoder in 16 output dimensions, frame activation must profile to rank 2,
/// the factored coordinates must reconstruct the decoder, the border dimension must be
/// `m * r`, and a flatten/scatter round trip of the factored border must leave the
/// decoder coefficients unchanged.
#[test]
pub(crate) fn factored_border_dim_invariant_and_reconstruction() {
    let m = 6usize;
    let p = 16usize;
    let r = 2usize;
    // Frame spans the first two coordinate axes of R^p.
    let frame =
        Array2::<f64>::from_shape_fn((p, r), |(row, col)| if row == col { 1.0 } else { 0.0 });
    let c0 = Array2::<f64>::from_shape_fn((m, r), |(mu, col)| {
        if col == 0 {
            1.0 + mu as f64
        } else {
            0.5 * mu as f64 - 1.0
        }
    });
    let decoder = fast_abt(&c0, &frame);
    let phi = Array2::<f64>::eye(m);
    let jet = Array3::<f64>::from_shape_fn((m, m, 1), |(row, col, _)| {
        if row == col { 1.0 } else { 0.0 }
    });
    let raw_penalty = crate::basis::create_difference_penalty_matrix(m, 2, None).unwrap();
    let mut atom = SaeManifoldAtom::new(
        "lowrank",
        SaeAtomBasisKind::EuclideanPatch,
        1,
        phi,
        jet,
        decoder.clone(),
        raw_penalty,
    )
    .unwrap();

    let activated = atom.maybe_activate_decoder_frame().expect("activate");
    assert_eq!(
        activated,
        Some(r),
        "rank-{r} decoder should profile to r={r}"
    );
    assert_eq!(atom.border_frame_rank(), r);
    assert_eq!(atom.frame_manifold_dimension(), r * (p - r));

    let coords = atom.factored_coordinates().unwrap().expect("coords");
    assert_eq!(coords.dim(), (m, r));
    let reconstructed = atom
        .reconstruct_decoder_coefficients(coords.view())
        .unwrap();
    for ((mu, j), &expected) in decoder.indexed_iter() {
        assert_abs_diff_eq!(reconstructed[[mu, j]], expected, epsilon = 1.0e-9);
    }

    let assignment = SaeAssignment::from_blocks_with_mode(
        Array2::<f64>::zeros((m, 1)),
        vec![Array2::<f64>::zeros((m, 1))],
        AssignmentMode::softmax(0.7),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();
    grassmann_assert_border_dim_invariant(&term).expect("border invariant");
    assert_eq!(term.factored_border_dim(), m * r);
    assert_eq!(term.grassmann_evidence_dimension(), r * (p - r));

    // Flatten then scatter the same border: the decoder must come back unchanged.
    let border = term.flatten_factored_border().unwrap();
    assert_eq!(border.len(), m * r);
    let saved = term.atoms[0].decoder_coefficients.clone();
    term.scatter_factored_border(border.view()).unwrap();
    for mu in 0..m {
        for j in 0..p {
            let live = term.atoms[0].decoder_coefficients[[mu, j]];
            assert_abs_diff_eq!(live, saved[[mu, j]], epsilon = 1.0e-9);
        }
    }
}
/// The factored β-penalty curvature built directly must equal the dense full-B penalty
/// Hessian projected onto the factored border.
///
/// Two rank-2 atoms (m = 4, p = 24) get their frames activated. A nuclear-norm and a
/// decoder-incoherence penalty are registered, the dense penalty block is assembled into
/// a scratch `ArrowSchurSystem`, projected with `FrameProjection`, and compared entry by
/// entry (1e-10) with `build_factored_beta_penalty_curvature`. Finally the assembly entry
/// point with a probe threshold of 1 must produce a system of border width `k` whose
/// `hbb` is empty.
#[test]
pub(crate) fn factored_beta_penalty_probing_matches_projected_dense_curvature() {
    let k_atoms = 2usize;
    let m = 4usize;
    let p = 24usize;
    let r = 2usize;
    let n_obs = 5usize;
    let mut atoms = Vec::with_capacity(k_atoms);
    let mut coord_blocks = Vec::with_capacity(k_atoms);
    for atom_idx in 0..k_atoms {
        // Each atom owns its own pair of output axes, so the frames are disjoint.
        let mut frame = Array2::<f64>::zeros((p, r));
        frame[[atom_idx * r, 0]] = 1.0;
        frame[[atom_idx * r + 1, 1]] = 1.0;
        let mut coords = Array2::<f64>::zeros((n_obs, 1));
        for row in 0..n_obs {
            coords[[row, 0]] = row as f64;
        }
        let mut phi = Array2::<f64>::zeros((n_obs, m));
        let mut jet = Array3::<f64>::zeros((n_obs, m, 1));
        for row in 0..n_obs {
            for basis_col in 0..m {
                let x = (row + 1) as f64 * (basis_col + 1) as f64;
                phi[[row, basis_col]] = 0.05 * x + if row == basis_col { 1.0 } else { 0.0 };
                jet[[row, basis_col, 0]] = 0.01 * x;
            }
        }
        let mut c = Array2::<f64>::zeros((m, r));
        for basis_col in 0..m {
            c[[basis_col, 0]] = 0.3 + 0.07 * (basis_col + atom_idx) as f64;
            c[[basis_col, 1]] = -0.2 + 0.05 * (basis_col * 2 + atom_idx) as f64;
        }
        // decoder = c · frameᵀ has rank r by construction.
        let decoder = fast_abt(&c, &frame);
        let mut atom = SaeManifoldAtom::new(
            "factored_probe",
            SaeAtomBasisKind::EuclideanPatch,
            1,
            phi,
            jet,
            decoder,
            Array2::<f64>::eye(m),
        )
        .unwrap();
        atom.maybe_activate_decoder_frame()
            .expect("frame activation")
            .expect("rank-2 atom should activate a frame");
        atoms.push(atom);
        coord_blocks.push(coords);
    }
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::from_elem((n_obs, k_atoms), 0.25),
        coord_blocks,
        vec![LatentManifold::Euclidean, LatentManifold::Euclidean],
        AssignmentMode::softmax(1.0),
    )
    .unwrap();
    let term = SaeManifoldTerm::new(atoms, assignment).unwrap();
    assert!(term.frames_active());
    assert_eq!(term.factored_border_dim(), k_atoms * m * r);

    // Register the two analytic penalties over the whole β range.
    let beta_len = term.beta_dim();
    let mut registry = AnalyticPenaltyRegistry::new();
    let nuclear = NuclearNormPenalty::new(
        PsiSlice {
            range: 0..beta_len,
            latent_dim: Some(beta_len / p),
        },
        0.7,
        p,
        1.0e-4,
        None,
        false,
    )
    .unwrap();
    registry.push(AnalyticPenaltyKind::NuclearNorm(Arc::new(nuclear)));
    let incoherence = DecoderIncoherencePenalty::new(
        PsiSlice {
            range: 0..beta_len,
            latent_dim: Some(beta_len / p),
        },
        vec![m, m],
        p,
        Array2::<f64>::from_elem((k_atoms, k_atoms), 0.5),
        0.6,
        false,
    )
    .unwrap();
    registry.push(AnalyticPenaltyKind::DecoderIncoherence(Arc::new(
        incoherence,
    )));

    // Reference path: dense full-B penalty Hessian, then project to the border.
    let mut dense_sys = ArrowSchurSystem::new(0, 0, beta_len);
    let dense_assembly = term
        .add_sae_analytic_penalty_contributions(&mut dense_sys, &registry, 1.0, None, true, None)
        .unwrap();
    assert!(dense_assembly.dense_written);
    assert!(!dense_assembly.deferred_factored);
    let projection = FrameProjection::new(&term);
    let border_dim = term.factored_border_dim();
    let projected = term.project_dense_penalty_to_factored(dense_sys.hbb.view(), &projection);
    // Path under test: curvature built directly in factored coordinates.
    let direct = term.build_factored_beta_penalty_curvature(&registry, 1.0, &projection);
    for row in 0..border_dim {
        for col in 0..border_dim {
            assert_abs_diff_eq!(direct[[row, col]], projected[[row, col]], epsilon = 1.0e-10);
        }
    }

    // Assembly entry point with a probe threshold of 1.
    let mut deferred_term = term.clone();
    let rho = SaeManifoldRho::new(
        0.0,
        -20.0,
        vec![Array1::<f64>::zeros(1), Array1::<f64>::zeros(1)],
    );
    let target = Array2::<f64>::zeros((n_obs, p));
    let sys = deferred_term
        .assemble_arrow_schur_scaled_with_beta_penalty_probe_threshold(
            target.view(),
            &rho,
            Some(&registry),
            1.0,
            1,
        )
        .unwrap();
    assert_eq!(sys.k, border_dim);
    assert!(sys.hbb.is_empty());
}
/// Materializes the dense `(row_dim, k)` cross block `H_tβ` for one row of `sys`.
///
/// The stored dense block is used as the starting value when it has the expected shape
/// and either the dense-supplement flag is set or no matvec operator is installed;
/// otherwise the start is zero. If a matvec operator exists, its action on every unit
/// β vector is added column by column.
pub(crate) fn materialize_row_htbeta_for_test(
    sys: &ArrowSchurSystem,
    row_idx: usize,
) -> Array2<f64> {
    let di = sys.row_dims[row_idx];
    let k = sys.k;
    let stored = &sys.rows[row_idx].htbeta;
    let dense_counts = sys.htbeta_dense_supplement || sys.htbeta_matvec.is_none();
    let mut out = if dense_counts && stored.dim() == (di, k) {
        stored.clone()
    } else {
        Array2::<f64>::zeros((di, k))
    };
    if let Some(op) = sys.htbeta_matvec.as_ref() {
        let mut unit = Array1::<f64>::zeros(k);
        let mut column = Array1::<f64>::zeros(di);
        for beta_col in 0..k {
            // Probe with the beta_col-th unit vector, then clear it for the next pass.
            unit[beta_col] = 1.0;
            column.fill(0.0);
            op(row_idx, unit.view(), &mut column);
            unit[beta_col] = 0.0;
            for (row_col, &value) in column.iter().enumerate() {
                out[[row_col, beta_col]] += value;
            }
        }
    }
    out
}
/// Projects a full-B row cross block onto the factored border of `term` using a fresh
/// `FrameProjection`.
pub(crate) fn project_row_htbeta_to_factored_for_test(
    term: &SaeManifoldTerm,
    htbeta_b: ArrayView2<'_, f64>,
) -> Array2<f64> {
    let projection = FrameProjection::new(term);
    projection.project_rows(htbeta_b)
}
/// Builds a deterministic multi-atom term whose decoders are exactly rank `frame_rank`
/// and whose decoder frames are activated.
///
/// * `k_atoms` – number of atoms.
/// * `m` – basis functions per atom.
/// * `p` – output dimension.
/// * `frame_rank` – rank of each decoder (`coords_c · frameᵀ`).
/// * `latent_dim` – latent coordinates per atom; each atom lives on a product of
///   `latent_dim` Euclidean factors.
/// * `n_obs` – number of observations.
///
/// Panics if any constructor fails or an atom does not activate a frame.
pub(crate) fn low_rank_factored_htbeta_term(
    k_atoms: usize,
    m: usize,
    p: usize,
    frame_rank: usize,
    latent_dim: usize,
    n_obs: usize,
) -> SaeManifoldTerm {
    let mut atoms = Vec::with_capacity(k_atoms);
    let mut coord_blocks = Vec::with_capacity(k_atoms);
    for atom_idx in 0..k_atoms {
        let coords = Array2::from_shape_fn((n_obs, latent_dim), |(row, axis)| {
            let phase = (row + 1) as f64 * (axis + 2) as f64 + 0.37 * (atom_idx + 1) as f64;
            0.2 * phase.sin() + 0.1 * (0.17 * phase).cos()
        });
        // Basis values: a constant first column plus a small oscillating perturbation.
        let phi = Array2::<f64>::from_shape_fn((n_obs, m), |(row, basis_col)| {
            let base = (row + 1) as f64 * (basis_col + 1) as f64;
            let indicator = if basis_col == 0 { 1.0 } else { 0.0 };
            indicator + 0.01 * (base + 3.0 * atom_idx as f64).sin()
        });
        let jet = Array3::<f64>::from_shape_fn((n_obs, m, latent_dim), |(row, basis_col, axis)| {
            let base = (row + 1) as f64 * (basis_col + 1) as f64;
            0.005 * ((base * (axis + 1) as f64) + atom_idx as f64).cos()
        });
        // One unit entry per frame column, on an output axis that depends on the atom.
        let mut frame = Array2::<f64>::zeros((p, frame_rank));
        for frame_col in 0..frame_rank {
            let out_axis = (atom_idx * frame_rank + frame_col) % p;
            frame[[out_axis, frame_col]] = 1.0;
        }
        let coords_c = Array2::from_shape_fn((m, frame_rank), |(basis_col, frame_col)| {
            0.2 + 0.03 * (basis_col + 2 * frame_col + atom_idx) as f64
        });
        let decoder = coords_c.dot(&frame.t());
        let mut atom = SaeManifoldAtom::new(
            "factored_htbeta_shape",
            SaeAtomBasisKind::EuclideanPatch,
            latent_dim,
            phi,
            jet,
            decoder,
            Array2::<f64>::eye(m),
        )
        .unwrap();
        atom.maybe_activate_decoder_frame()
            .expect("frame activation")
            .expect("low-rank atom should activate a frame");
        atoms.push(atom);
        coord_blocks.push(coords);
    }
    let logits = Array2::<f64>::from_shape_fn((n_obs, k_atoms), |(row, atom)| {
        0.03 * ((row + 1) as f64 * (atom + 2) as f64).sin()
    });
    let per_atom_manifold = LatentManifold::Product(vec![LatentManifold::Euclidean; latent_dim]);
    let manifolds = vec![per_atom_manifold; k_atoms];
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        logits,
        coord_blocks,
        manifolds,
        AssignmentMode::softmax(0.9),
    )
    .unwrap();
    SaeManifoldTerm::new(atoms, assignment).unwrap()
}
/// Builds the `SaeManifoldRho` used by the factored `H_tβ` tests: scalar arguments `0.0`
/// and `-0.2`, plus one zero vector of length `latent_dim` per atom.
pub(crate) fn factored_htbeta_rho(k_atoms: usize, latent_dim: usize) -> SaeManifoldRho {
    let per_atom = vec![Array1::<f64>::zeros(latent_dim); k_atoms];
    SaeManifoldRho::new(0.0, -0.2, per_atom)
}
/// With 1-D latents, the natively factored assembly must give the same Newton step as
/// assembling in full-B coordinates and projecting each row's `H_tβ` onto the factored
/// border afterwards.
#[test]
pub(crate) fn factored_row_htbeta_native_solve_matches_full_b_then_project() {
    let k_atoms = 2usize;
    let m = 4usize;
    let p = 24usize;
    let r = 2usize;
    let n_obs = 5usize;
    let mut atoms = Vec::with_capacity(k_atoms);
    let mut coord_blocks = Vec::with_capacity(k_atoms);
    for atom_idx in 0..k_atoms {
        // Disjoint coordinate-axis frames, one pair of axes per atom.
        let mut frame = Array2::<f64>::zeros((p, r));
        frame[[atom_idx * r, 0]] = 1.0;
        frame[[atom_idx * r + 1, 1]] = 1.0;
        let coords = Array2::from_shape_fn((n_obs, 1), |(row, _)| 0.1 * (row + 1) as f64);
        let phi = Array2::<f64>::from_shape_fn((n_obs, m), |(row, basis_col)| {
            let x = (row + 1) as f64 * (basis_col + 1) as f64;
            let bump = if row % m == basis_col { 1.0 } else { 0.0 };
            0.03 * x + bump
        });
        let jet = Array3::<f64>::from_shape_fn((n_obs, m, 1), |(row, basis_col, _)| {
            let x = (row + 1) as f64 * (basis_col + 1) as f64;
            0.02 * x
        });
        let c = Array2::from_shape_fn((m, r), |(basis_col, frame_col)| {
            0.2 + 0.04 * (basis_col + 2 * frame_col + atom_idx) as f64
        });
        let decoder = fast_abt(&c, &frame);
        let mut atom = SaeManifoldAtom::new(
            "factored_row_native",
            SaeAtomBasisKind::EuclideanPatch,
            1,
            phi,
            jet,
            decoder,
            Array2::<f64>::eye(m),
        )
        .unwrap();
        atom.maybe_activate_decoder_frame()
            .expect("frame activation")
            .expect("rank-2 atom should activate a frame");
        atoms.push(atom);
        coord_blocks.push(coords);
    }
    let logits = Array2::<f64>::from_shape_fn((n_obs, k_atoms), |(row, atom)| {
        0.15 * (row + 1) as f64 - 0.07 * atom as f64
    });
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        logits,
        coord_blocks,
        vec![LatentManifold::Euclidean, LatentManifold::Euclidean],
        AssignmentMode::softmax(0.9),
    )
    .unwrap();
    let mut factored_term = SaeManifoldTerm::new(atoms, assignment).unwrap();
    assert!(factored_term.frames_active());
    let border_dim = factored_term.factored_border_dim();
    assert!(border_dim < factored_term.beta_dim());

    // Same term with every frame switched off: the full-B reference.
    let mut full_term = factored_term.clone();
    full_term
        .atoms
        .iter_mut()
        .for_each(|atom| atom.deactivate_decoder_frame());

    let rho = SaeManifoldRho::new(
        0.0,
        -0.2,
        vec![Array1::<f64>::zeros(1), Array1::<f64>::zeros(1)],
    );
    let target = Array2::<f64>::from_shape_fn((n_obs, p), |(row, col)| {
        0.01 * (row + 1) as f64 - 0.002 * (col + 1) as f64
    });

    let native_sys = factored_term
        .assemble_arrow_schur(target.view(), &rho, None)
        .unwrap();
    assert_eq!(native_sys.k, border_dim);
    assert!(native_sys.htbeta_matvec.is_none());
    assert!(native_sys.htbeta_transpose_matvec.is_none());
    for row in &native_sys.rows {
        assert_eq!(row.htbeta.ncols(), border_dim);
    }

    let full_sys = full_term
        .assemble_arrow_schur(target.view(), &rho, None)
        .unwrap();
    // Second factored assembly whose cross blocks are overwritten by projected full-B ones.
    let mut projected_sys = factored_term
        .assemble_arrow_schur(target.view(), &rho, None)
        .unwrap();
    projected_sys.htbeta_matvec = None;
    projected_sys.htbeta_transpose_matvec = None;
    projected_sys.htbeta_dense_supplement = false;
    for row_idx in 0..n_obs {
        let htbeta_b = materialize_row_htbeta_for_test(&full_sys, row_idx);
        let projected_block =
            project_row_htbeta_to_factored_for_test(&factored_term, htbeta_b.view());
        projected_sys.rows[row_idx].htbeta = projected_block;
    }
    projected_sys.refresh_row_hessian_fingerprint();

    let ridge_t = 5.0e-1;
    let (native_dt, native_db, _) = native_sys.solve(ridge_t, 1.0e-8).unwrap();
    let (projected_dt, projected_db, _) = projected_sys.solve(ridge_t, 1.0e-8).unwrap();
    assert_eq!(native_dt.len(), projected_dt.len());
    assert_eq!(native_db.len(), projected_db.len());
    for (idx, &reference) in projected_dt.iter().enumerate() {
        assert_abs_diff_eq!(native_dt[idx], reference, epsilon = 1.0e-10);
    }
    for (idx, &reference) in projected_db.iter().enumerate() {
        assert_abs_diff_eq!(native_db[idx], reference, epsilon = 1.0e-10);
    }
}
/// Same native-versus-projected comparison as the 1-D test, but on the shared fixture
/// with three atoms and two latent coordinates per atom.
#[test]
pub(crate) fn factored_row_htbeta_d2_matches_dense_full_b_then_project() {
    let k_atoms = 3usize;
    let m = 5usize;
    let p = 32usize;
    let frame_rank = 2usize;
    let latent_dim = 2usize;
    let n_obs = 6usize;
    let mut factored_term =
        low_rank_factored_htbeta_term(k_atoms, m, p, frame_rank, latent_dim, n_obs);
    assert!(factored_term.frames_active());
    assert_eq!(
        factored_term.factored_border_dim(),
        k_atoms * m * frame_rank
    );
    assert!(factored_term.factored_border_dim() < factored_term.beta_dim());

    // Full-B reference: identical term with all decoder frames deactivated.
    let mut full_term = factored_term.clone();
    full_term
        .atoms
        .iter_mut()
        .for_each(|atom| atom.deactivate_decoder_frame());

    let rho = factored_htbeta_rho(k_atoms, latent_dim);
    let target = Array2::<f64>::from_shape_fn((n_obs, p), |(row, col)| {
        0.01 * (row + 1) as f64 - 0.002 * (col + 1) as f64
    });
    let native_sys = factored_term
        .assemble_arrow_schur(target.view(), &rho, None)
        .unwrap();
    let full_sys = full_term
        .assemble_arrow_schur(target.view(), &rho, None)
        .unwrap();
    let mut projected_sys = factored_term
        .assemble_arrow_schur(target.view(), &rho, None)
        .unwrap();
    // Force the projected system onto its dense row blocks only.
    projected_sys.htbeta_matvec = None;
    projected_sys.htbeta_transpose_matvec = None;
    projected_sys.htbeta_dense_supplement = false;
    for row_idx in 0..n_obs {
        let htbeta_b = materialize_row_htbeta_for_test(&full_sys, row_idx);
        let projected_block =
            project_row_htbeta_to_factored_for_test(&factored_term, htbeta_b.view());
        projected_sys.rows[row_idx].htbeta = projected_block;
    }
    projected_sys.refresh_row_hessian_fingerprint();

    let ridge_t = 5.0e-1;
    let (native_dt, native_db, _) = native_sys.solve(ridge_t, 1.0e-8).unwrap();
    let (projected_dt, projected_db, _) = projected_sys.solve(ridge_t, 1.0e-8).unwrap();
    assert_eq!(native_dt.len(), projected_dt.len());
    assert_eq!(native_db.len(), projected_db.len());
    for (idx, &reference) in projected_dt.iter().enumerate() {
        assert_abs_diff_eq!(native_dt[idx], reference, epsilon = 1.0e-10);
    }
    for (idx, &reference) in projected_db.iter().enumerate() {
        assert_abs_diff_eq!(native_db[idx], reference, epsilon = 1.0e-10);
    }
}
/// Memory gate: at a large shape (8 atoms, p = 2048, 2000 observations, 2-D latents) the
/// factored assembly must store less than 8 GiB of dense data, while a p-wide `H_tβ` of
/// the same shape would exceed 8 GiB.
#[test]
pub(crate) fn qwen_shape_d2_factored_htbeta_assembly_stays_below_8gib() {
    const K_ATOMS: usize = 8;
    const M: usize = 10;
    const P: usize = 2048;
    const FRAME_RANK: usize = 2;
    const LATENT_DIM: usize = 2;
    const N_OBS: usize = 2000;
    const EIGHT_GIB: usize = 8 * 1024 * 1024 * 1024;
    let f64_bytes = std::mem::size_of::<f64>();

    let mut term = low_rank_factored_htbeta_term(K_ATOMS, M, P, FRAME_RANK, LATENT_DIM, N_OBS);
    assert!(term.frames_active());
    assert_eq!(term.beta_dim(), K_ATOMS * M * P);
    assert_eq!(term.factored_border_dim(), K_ATOMS * M * FRAME_RANK);
    assert!(term.factored_border_dim() < term.beta_dim());

    let rho = factored_htbeta_rho(K_ATOMS, LATENT_DIM);
    let target = Array2::<f64>::from_shape_fn((N_OBS, P), |(row, col)| {
        1.0e-4 * ((row + 1) as f64 * (col + 3) as f64).sin()
    });
    let sys = term
        .assemble_arrow_schur(target.view(), &rho, None)
        .unwrap();
    assert_eq!(sys.k, term.factored_border_dim());
    assert!(sys.htbeta_matvec.is_none());
    assert!(sys.htbeta_transpose_matvec.is_none());

    // Every row must share one latent block size and carry a border-wide cross block.
    let actual_row_dim = sys.row_dims[0];
    assert!(actual_row_dim > 0);
    assert!(sys.row_dims.iter().all(|&dim| dim == actual_row_dim));
    for row in &sys.rows {
        assert_eq!(row.htbeta.ncols(), term.factored_border_dim());
        assert_eq!(row.htbeta.nrows(), actual_row_dim);
    }

    let htbeta_bytes: usize = sys
        .rows
        .iter()
        .map(|row| row.htbeta.len() * f64_bytes)
        .sum();
    let assembled_dense_bytes = htbeta_bytes + sys.hbb.len() * f64_bytes + sys.gb.len() * f64_bytes;
    // Footprint a full-B (p-wide) cross block would have needed at this shape.
    let old_full_b_htbeta_bytes = N_OBS
        .saturating_mul(actual_row_dim)
        .saturating_mul(term.beta_dim())
        .saturating_mul(f64_bytes);
    assert!(
        old_full_b_htbeta_bytes > EIGHT_GIB,
        "test shape must reproduce the old p-wide H_tbeta memory wall"
    );
    assert!(
        assembled_dense_bytes < EIGHT_GIB,
        "qwen-shaped factored assembly stored {assembled_dense_bytes} bytes, \
         exceeding the 8 GiB gate"
    );
}
/// A full-rank decoder with p = 2 must not activate a frame, so all factored quantities
/// collapse to their full-B values and the REML Occam term equals
/// `0.5 * p * rank(S) * log_lambda_smooth`.
#[test]
pub(crate) fn factored_evidence_matches_full_b_at_small_p() {
    let m = 5usize;
    let p = 2usize;
    let decoder = Array2::<f64>::from_shape_fn((m, p), |(mu, col)| {
        if col == 0 {
            1.0 + mu as f64
        } else {
            (mu as f64) - 2.0
        }
    });
    let phi = Array2::<f64>::eye(m);
    let jet = Array3::<f64>::from_shape_fn((m, m, 1), |(row, col, _)| {
        if row == col { 1.0 } else { 0.0 }
    });
    let raw_penalty = crate::basis::create_difference_penalty_matrix(m, 2, None).unwrap();
    let mut atom = SaeManifoldAtom::new(
        "fullrank",
        SaeAtomBasisKind::EuclideanPatch,
        1,
        phi,
        jet,
        decoder,
        raw_penalty,
    )
    .unwrap();

    let activated = atom.maybe_activate_decoder_frame().expect("activate");
    assert_eq!(
        activated, None,
        "full-rank small-p must stay on full-B path"
    );
    assert!(atom.decoder_frame.is_none());
    assert_eq!(atom.border_frame_rank(), p);
    assert_eq!(atom.frame_manifold_dimension(), 0);

    let assignment = SaeAssignment::from_blocks_with_mode(
        Array2::<f64>::zeros((m, 1)),
        vec![Array2::<f64>::zeros((m, 1))],
        AssignmentMode::softmax(0.7),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();
    assert!(!term.frames_active());
    assert_eq!(term.factored_border_dim(), term.beta_dim());
    assert_eq!(term.grassmann_evidence_dimension(), 0);
    let activated_n = term.auto_activate_decoder_frames().expect("auto");
    assert_eq!(activated_n, 0, "small-p auto-activation must be a no-op");

    // Closed-form Occam term for the full-B path.
    let rho = SaeManifoldRho::new(0.0, 0.37, vec![array![0.0_f64]]);
    let occam = term.reml_occam_term(&rho).expect("occam");
    let penalty_rank = SaeManifoldTerm::symmetric_rank(&term.atoms[0].smooth_penalty).unwrap();
    let expected = 0.5 * (p as f64) * (penalty_rank as f64) * rho.log_lambda_smooth;
    assert_abs_diff_eq!(occam, expected, epsilon = 1.0e-12);
}
/// Refreshing the frame from a cross moment supported on output axes 2 and 3 must yield
/// an orthonormal frame spanning exactly those axes, even though the atom started on
/// axes 0 and 1.
#[test]
pub(crate) fn streaming_polar_refresh_reorients_frame() {
    let m = 4usize;
    let p = 12usize;
    let r = 2usize;
    // Initial frame: output axes 0 and 1.
    let frame0 =
        Array2::<f64>::from_shape_fn((p, r), |(row, col)| if row == col { 1.0 } else { 0.0 });
    let c0 = Array2::<f64>::from_shape_fn((m, r), |(mu, col)| {
        if col == 0 {
            1.0 + mu as f64
        } else {
            0.5 - mu as f64
        }
    });
    let decoder = fast_abt(&c0, &frame0);
    let phi = Array2::<f64>::eye(m);
    let jet = Array3::<f64>::from_shape_fn((m, m, 1), |(row, col, _)| {
        if row == col { 1.0 } else { 0.0 }
    });
    let raw_penalty = crate::basis::create_difference_penalty_matrix(m, 2, None).unwrap();
    let mut atom = SaeManifoldAtom::new(
        "stream",
        SaeAtomBasisKind::EuclideanPatch,
        1,
        phi,
        jet,
        decoder,
        raw_penalty,
    )
    .unwrap();
    atom.maybe_activate_decoder_frame().expect("activate");

    // Cross moment with unequal weights on axes 2 and 3.
    let mut cross = Array2::<f64>::zeros((p, r));
    cross[[2, 0]] = 3.0;
    cross[[3, 1]] = 2.0;
    atom.refresh_frame_from_cross_moment(cross.view())
        .expect("refresh");

    let frame = atom.decoder_frame.as_ref().expect("frame");
    let left = frame.frame().to_owned();
    let right = frame.frame().to_owned();
    let gram = fast_atb(&left, &right);
    for i in 0..r {
        for j in 0..r {
            let expect = if i == j { 1.0 } else { 0.0 };
            assert_abs_diff_eq!(gram[[i, j]], expect, epsilon = 1.0e-9);
        }
    }

    let mut target_span = Array2::<f64>::zeros((p, r));
    target_span[[2, 0]] = 1.0;
    target_span[[3, 1]] = 1.0;
    let angle = frame
        .max_principal_angle(target_span.view())
        .expect("angle");
    assert_abs_diff_eq!(angle, 0.0, epsilon = 1.0e-9);
}
/// An all-zero decoder (m = 3, p = 8) must not activate a decoder frame: the activation
/// rank is `None` and the border frame rank stays at `p`.
#[test]
pub(crate) fn small_p_zero_decoder_stays_full_b() {
    let m = 3usize;
    let p = 8usize;
    let phi = Array2::<f64>::eye(m);
    let jet = Array3::<f64>::from_shape_fn((m, m, 1), |(row, col, _)| {
        if row == col { 1.0 } else { 0.0 }
    });
    let zero_decoder = Array2::<f64>::zeros((m, p));
    let raw_penalty = crate::basis::create_difference_penalty_matrix(m, 2, None).unwrap();
    let mut atom = SaeManifoldAtom::new(
        "small-p-zero",
        SaeAtomBasisKind::EuclideanPatch,
        1,
        phi,
        jet,
        zero_decoder,
        raw_penalty,
    )
    .unwrap();
    assert_eq!(atom.decoder_frame_activation_rank().unwrap(), None);
    assert_eq!(atom.maybe_activate_decoder_frame().unwrap(), None);
    assert_eq!(atom.border_frame_rank(), p);
}
pub(crate) fn gamma_fd_tiny_fixture() -> (SaeManifoldTerm, Array2<f64>, SaeManifoldRho) {
let n = 10usize;
let p = 3usize;
let k_atoms = 2usize;
let m = 3usize;
let evaluator = Arc::new(PeriodicHarmonicEvaluator::new(m).unwrap());
let mut logits = Array2::<f64>::zeros((n, k_atoms));
let mut coords = vec![Array2::<f64>::zeros((n, 1)), Array2::<f64>::zeros((n, 1))];
let weights = [
[
[0.10, -0.05, 0.03],
[0.35, -0.20, 0.12],
[-0.16, 0.18, 0.08],
],
[
[-0.08, 0.04, 0.06],
[0.22, 0.10, -0.18],
[0.11, -0.24, 0.15],
],
];
let mut target = Array2::<f64>::zeros((n, p));
for row in 0..n {
let phase = (row as f64 + 0.35) / n as f64;
coords[0][[row, 0]] = phase;
coords[1][[row, 0]] = (phase + 0.21).fract();
logits[[row, 0]] = if row % 2 == 0 { 0.8 } else { -0.6 };
let assignments = softmax_row(logits.row(row), 0.9);
for atom in 0..k_atoms {
let theta = std::f64::consts::TAU * coords[atom][[row, 0]];
let basis = [1.0, theta.sin(), theta.cos()];
for out_col in 0..p {
for basis_col in 0..m {
target[[row, out_col]] +=
assignments[atom] * basis[basis_col] * weights[atom][basis_col][out_col];
}
}
}
}
let mut atoms = Vec::with_capacity(k_atoms);
for atom in 0..k_atoms {
let (phi, jet) = evaluator.evaluate(coords[atom].view()).unwrap();
let decoder = Array2::from_shape_fn((m, p), |(basis_col, out_col)| {
weights[atom][basis_col][out_col]
});
atoms.push(
SaeManifoldAtom::new(
format!("gamma_{atom}"),
SaeAtomBasisKind::Periodic,
1,
phi,
jet,
decoder,
Array2::<f64>::eye(m),
)
.unwrap()
.with_basis_second_jet(evaluator.clone()),
);
}
let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
logits,
coords,
vec![LatentManifold::Circle { period: 1.0 }; k_atoms],
AssignmentMode::softmax(0.9),
)
.unwrap();
let term = SaeManifoldTerm::new(atoms, assignment).unwrap();
let rho = SaeManifoldRho::new(
-6.0,
-6.0,
vec![Array1::from_vec(vec![-6.0]), Array1::from_vec(vec![-6.0])],
);
(term, target, rho)
}
/// Returns the total arrow log-determinant of `term` evaluated at its current
/// state, i.e. the sum of the two parts reported by `arrow_log_det`.
///
/// The criterion is requested with a `0` in the fourth argument slot
/// (presumably the inner-iteration budget, so the state is not moved; confirm
/// against `reml_criterion_with_cache`). Panics if the cache cannot be built or
/// if the second log-det part is absent.
pub(crate) fn fixed_state_logdet(
    mut term: SaeManifoldTerm,
    target: &Array2<f64>,
    rho: &SaeManifoldRho,
) -> f64 {
    let (_criterion, _fit, cache) = term
        .reml_criterion_with_cache(target.view(), rho, None, 0, 0.4, 1.0e-6, 1.0e-6)
        .expect("fixed-state cache");
    let (row_part, schur_part) = cache.arrow_log_det();
    let schur_part = schur_part.expect("dense Schur logdet");
    row_part + schur_part
}
/// Checks the analytic log-det adjoint ("Gamma") from `logdet_theta_adjoint`
/// against central finite differences of `fixed_state_logdet` on the tiny
/// softmax fixture.
///
/// Each probe is `(row, local_pos, var)`: `var` says which latent entry of
/// `row` to perturb, and `local_pos` is that variable's offset inside the row's
/// slice of `gamma.t` (starting at `cache.row_offsets[row]`).
#[test]
pub(crate) fn sae_logdet_theta_adjoint_matches_dense_fd_on_tiny_fixture() {
    let (mut term, target, rho) = gamma_fd_tiny_fixture();
    let (_criterion, _fit, cache) = term
        .reml_criterion_with_cache(target.view(), &rho, None, 5, 0.4, 1.0e-6, 1.0e-6)
        .expect("converged cache");
    let solver = DeflatedArrowSolver::plain(&cache);
    let gamma = term
        .logdet_theta_adjoint(&rho, &cache, &solver)
        .expect("Gamma");

    let h = 1.0e-5;
    let probes = [
        (0usize, 0usize, SaeLocalRowVar::Logit { atom: 0 }),
        (3usize, 1usize, SaeLocalRowVar::Coord { atom: 0, axis: 0 }),
    ];
    for (row, local_pos, var) in probes {
        let mut plus = term.clone();
        let mut minus = term.clone();
        // Apply +h to `plus` and -h to `minus`; adding -h is the same IEEE
        // operation as subtracting h.
        for (perturbed, delta) in [(&mut plus, h), (&mut minus, -h)] {
            match var {
                SaeLocalRowVar::Logit { atom } => {
                    perturbed.assignment.logits[[row, atom]] += delta;
                }
                SaeLocalRowVar::Coord { atom, axis } => {
                    let mut flat = perturbed.assignment.coords[atom].as_flat().clone();
                    let idx = row * perturbed.assignment.coords[atom].latent_dim() + axis;
                    flat[idx] += delta;
                    perturbed.assignment.coords[atom].set_flat(flat.view());
                }
            }
        }

        let logdet_plus = fixed_state_logdet(plus, &target, &rho);
        let logdet_minus = fixed_state_logdet(minus, &target, &rho);
        let fd = (logdet_plus - logdet_minus) / (2.0 * h);
        let analytic = gamma.t[cache.row_offsets[row] + local_pos];

        // Mixed absolute/relative tolerance.
        let magnitude = fd.abs().max(analytic.abs());
        let tol = 2.0e-3 * (1.0 + magnitude);
        assert!(
            (fd - analytic).abs() <= tol,
            "Gamma row={row} local_pos={local_pos}: fd={fd:.8e}, analytic={analytic:.8e}"
        );
    }
}
/// Same finite-difference check of `logdet_theta_adjoint` as the softmax test,
/// but with the fixture switched to the IBP-MAP assignment mode and
/// `log_lambda_sparse` set to -1.
///
/// Only logit variables are probed here; the coordinate arm is kept so the
/// match over `SaeLocalRowVar` stays exhaustive.
#[test]
pub(crate) fn sae_logdet_theta_adjoint_matches_dense_fd_ibp_map() {
    let (mut term, target, mut rho) = gamma_fd_tiny_fixture();
    term.assignment.mode = AssignmentMode::ibp_map(0.7, 0.9, false);
    rho.log_lambda_sparse = -1.0;

    let (_criterion, _fit, cache) = term
        .reml_criterion_with_cache(target.view(), &rho, None, 5, 0.4, 1.0e-6, 1.0e-6)
        .expect("converged cache");
    let solver = DeflatedArrowSolver::plain(&cache);
    let gamma = term
        .logdet_theta_adjoint(&rho, &cache, &solver)
        .expect("Gamma");

    let h = 1.0e-5;
    let probes = [
        (0usize, 0usize, SaeLocalRowVar::Logit { atom: 0 }),
        (4usize, 1usize, SaeLocalRowVar::Logit { atom: 1 }),
        (7usize, 0usize, SaeLocalRowVar::Logit { atom: 0 }),
    ];
    for (row, local_pos, var) in probes {
        let mut plus = term.clone();
        let mut minus = term.clone();
        // Apply +h to `plus` and -h to `minus`; adding -h is the same IEEE
        // operation as subtracting h.
        for (perturbed, delta) in [(&mut plus, h), (&mut minus, -h)] {
            match var {
                SaeLocalRowVar::Logit { atom } => {
                    perturbed.assignment.logits[[row, atom]] += delta;
                }
                SaeLocalRowVar::Coord { atom, axis } => {
                    let mut flat = perturbed.assignment.coords[atom].as_flat().clone();
                    let idx = row * perturbed.assignment.coords[atom].latent_dim() + axis;
                    flat[idx] += delta;
                    perturbed.assignment.coords[atom].set_flat(flat.view());
                }
            }
        }

        let logdet_plus = fixed_state_logdet(plus, &target, &rho);
        let logdet_minus = fixed_state_logdet(minus, &target, &rho);
        let fd = (logdet_plus - logdet_minus) / (2.0 * h);
        let analytic = gamma.t[cache.row_offsets[row] + local_pos];

        // Mixed absolute/relative tolerance, slightly looser than the softmax test.
        let magnitude = fd.abs().max(analytic.abs());
        let tol = 3.0e-3 * (1.0 + magnitude);
        assert!(
            (fd - analytic).abs() <= tol,
            "IBP Gamma row={row} local_pos={local_pos}: fd={fd:.8e}, analytic={analytic:.8e}"
        );
    }
}
/// Cross-checks the production per-row jets from `row_jets_for_logdet` against
/// `SaeReconstructionRowProgram::reconstruction_column` on the tiny gamma
/// fixture, once without and once with per-row loss weights.
///
/// For every row and output column, the production `first`/`second` entries
/// must equal `sqrt(row weight)` times the program's `g`/`h` entries, to a
/// relative tolerance of 1e-9 scaled by the largest production entry.
#[test]
pub(crate) fn sae_row_jet_program_matches_production_row_jets_on_converged_cache() {
use crate::terms::sae_row_jet_program::{
AtomRowBasisJet, RowGate, SaeReconstructionRowProgram,
};
// Number of local variables per row; asserted against `vars.len()` below.
const K: usize = 3;
for weighted in [false, true] {
let (mut term, target, rho) = gamma_fd_tiny_fixture();
if weighted {
// Row weights 0.5, 0.67, 0.84, ... — never exactly 1 for an integer row,
// which the non-unit sqrt(w) assertion further down relies on.
let weights: Vec<f64> = (0..term.n_obs())
.map(|row| 0.5 + 0.17 * row as f64)
.collect();
term.set_row_loss_weights(weights)
.expect("set row loss weights");
}
let (_value, _loss, cache) = term
.reml_criterion_with_cache(target.view(), &rho, None, 5, 0.4, 1.0e-6, 1.0e-6)
.expect("converged cache");
let second_jets = term.atom_second_jets().expect("second jets");
let border = term
.border_channels_for_cache(&cache)
.expect("border channels");
// The fixture is built with a softmax gate; any other mode is a test bug.
let AssignmentMode::Softmax { temperature, .. } = term.assignment.mode else {
panic!("gamma fixture is softmax-gated");
};
let inv_tau = 1.0 / temperature;
let p = term.output_dim();
let k_atoms = term.k_atoms();
for row in 0..term.n_obs() {
let vars = term.row_vars_for_cache_row(row, &cache).expect("row vars");
assert_eq!(
vars.len(),
K,
"tiny fixture rows carry 1 free softmax logit + 2 coords"
);
let assignments = term
.assignment
.try_assignments_row(row)
.expect("assignments row");
// Production side: the jets the log-det code path computes for this row.
let jets = term
.row_jets_for_logdet(
&rho,
row,
vars.clone(),
assignments.view(),
&second_jets,
&border,
)
.expect("production row jets");
// Map each local variable to its position in `vars`: `logit_slot[atom]` is
// `None` when the atom has no logit variable in this row, and
// `coord_slot[atom][axis]` keeps the `usize::MAX` placeholder when absent.
let mut logit_slot = vec![None; k_atoms];
let mut coord_slot: Vec<Vec<usize>> = term
.atoms
.iter()
.map(|atom| vec![usize::MAX; atom.latent_dim])
.collect();
for (pos, var) in vars.iter().enumerate() {
match *var {
SaeLocalRowVar::Logit { atom } => logit_slot[atom] = Some(pos),
SaeLocalRowVar::Coord { atom, axis } => coord_slot[atom][axis] = pos,
}
}
// Reference side: copy this row's basis values, basis Jacobian, second
// jets and decoder coefficients out of each atom into the program's input.
let atoms: Vec<AtomRowBasisJet> = term
.atoms
.iter()
.enumerate()
.map(|(k, atom)| {
let m = atom.basis_size();
let d = atom.latent_dim;
AtomRowBasisJet {
phi: (0..m).map(|b| atom.basis_values[[row, b]]).collect(),
d_phi: (0..m)
.map(|b| {
(0..d)
.map(|axis| atom.basis_jacobian[[row, b, axis]])
.collect()
})
.collect(),
d2_phi: (0..m)
.map(|b| {
(0..d)
.map(|aa| {
(0..d).map(|bb| second_jets[k][[row, b, aa, bb]]).collect()
})
.collect()
})
.collect(),
decoder: (0..m)
.map(|b| (0..p).map(|c| atom.decoder_coefficients[[b, c]]).collect())
.collect(),
latent_dim: d,
}
})
.collect();
let prog = SaeReconstructionRowProgram {
atoms,
gate_value: assignments.to_vec(),
logits: term.assignment.logits.row(row).to_vec(),
gate_shift: vec![0.0; k_atoms],
gate: RowGate::Softmax { inv_tau },
logit_slot,
coord_slot,
n_primaries: K,
};
// sqrt of this row's loss weight; 1.0 when no row weights are set.
let sqrt_row_w = term
.row_loss_weights
.as_deref()
.map_or(1.0, |w| w[row].sqrt());
if weighted {
assert!(
(sqrt_row_w - 1.0).abs() > 1e-6,
"weighted arm must exercise a non-unit √w (row {row}, √w={sqrt_row_w})"
);
}
for out_col in 0..p {
let tower = prog.reconstruction_column::<K>(out_col);
// Tolerance scales: largest absolute production first-order (resp.
// second-order) entry for this column, floored at 1e-12.
let g_floor = (0..K)
.map(|a| jets.first[a][out_col].abs())
.fold(1e-12_f64, f64::max);
let h_floor = (0..K)
.flat_map(|a| (0..K).map(move |b| (a, b)))
.map(|(a, b)| jets.second[a][b][out_col].abs())
.fold(1e-12_f64, f64::max);
for a in 0..K {
// Production entries are compared against sqrt(w) times the program's.
let want = sqrt_row_w * tower.g[a];
assert!(
(jets.first[a][out_col] - want).abs() <= 1e-9 * g_floor,
"weighted={weighted} row {row} col {out_col} first[{a}]: \
production {} vs tower {}",
jets.first[a][out_col],
want
);
for b in 0..K {
let want2 = sqrt_row_w * tower.h[a][b];
assert!(
(jets.second[a][b][out_col] - want2).abs() <= 1e-9 * h_floor,
"weighted={weighted} row {row} col {out_col} \
second[{a}][{b}]: production {} vs tower {}",
jets.second[a][b][out_col],
want2
);
}
}
}
}
}
}
/// With the gamma fixture switched to the IBP-MAP assignment mode, the outer
/// objective must report an analytic gradient in its capability descriptor.
#[test]
pub(crate) fn ibp_map_outer_objective_advertises_analytic_gradient() {
    let (mut term, target, rho) = gamma_fd_tiny_fixture();
    term.assignment.mode = AssignmentMode::ibp_map(0.9, 1.0, false);

    let objective =
        SaeManifoldOuterObjective::new(term, target, None, rho, 5, 0.4, 1.0e-6, 1.0e-6);
    let capability = objective.capability();
    assert_eq!(capability.gradient, Derivative::Analytic);
}
#[cfg(test)]
mod inner_contract_probe_tests {
    use super::*;
    use crate::terms::{AssignmentMode, LatentManifold, SaeAssignment};
    use std::sync::Arc;

    /// Builds a single-atom Euclidean-patch fixture on a noisy line.
    ///
    /// Returns `(term, z, rho)` with 150 rows and 8 output columns. Each row has
    /// a position `u` evenly spaced on [-1, 1]; its latent coordinate is
    /// `2.5 + 3u`, and each output is a column-specific offset plus a linear
    /// loading on `u` plus a small deterministic sin/cos perturbation. The
    /// decoder starts at zero and the assignment is a one-atom softmax with
    /// zero logits.
    pub(crate) fn euclidean_line_contract_fixture() -> (SaeManifoldTerm, Array2<f64>, SaeManifoldRho)
    {
        let n = 150usize;
        let p = 8usize;

        // Row position on [-1, 1].
        let unit_position = |row: usize| -1.0 + 2.0 * row as f64 / (n as f64 - 1.0);

        let coords =
            Array2::<f64>::from_shape_fn((n, 1), |(row, _)| 2.5 + 3.0 * unit_position(row));
        let z = Array2::<f64>::from_shape_fn((n, p), |(row, col)| {
            let u = unit_position(row);
            let linear_loading = 0.35 + 0.07 * col as f64;
            let offset = 0.08 * ((col % 3) as f64 - 1.0);
            let phase = (row * (col + 3)) as f64;
            let noise = 0.04 * (phase.sin() + 0.5 * (0.37 * phase).cos());
            offset + linear_loading * u + noise
        });

        let evaluator = Arc::new(EuclideanPatchEvaluator::new(1, 2).expect("evaluator"));
        let (phi, jet) = evaluator.evaluate(coords.view()).expect("basis");
        let m = phi.ncols();
        let smooth_penalty =
            crate::basis::create_difference_penalty_matrix(m, 2, None).expect("penalty");
        let zero_decoder = Array2::<f64>::zeros((m, p));

        let atom = SaeManifoldAtom::new(
            "contract-line",
            SaeAtomBasisKind::EuclideanPatch,
            1,
            phi,
            jet,
            zero_decoder,
            smooth_penalty,
        )
        .expect("atom")
        .with_basis_second_jet(evaluator);

        let zero_logits = Array2::<f64>::zeros((n, 1));
        let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
            zero_logits,
            vec![coords],
            vec![LatentManifold::Euclidean],
            AssignmentMode::softmax(1.0),
        )
        .expect("assignment");

        let term = SaeManifoldTerm::new(vec![atom], assignment).expect("term");
        let rho = SaeManifoldRho::new(0.0, (0.01_f64).ln(), vec![Array1::<f64>::zeros(1)]);
        (term, z, rho)
    }

    /// Panics unless `analytic` and `finite_difference` agree to a relative
    /// error below 1e-5, measured against the larger magnitude of the two
    /// (floored at 1e-12 so two zeros compare equal).
    pub(crate) fn assert_contract_close(label: &str, analytic: f64, finite_difference: f64) {
        let scale = finite_difference
            .abs()
            .max(analytic.abs())
            .max(1.0e-12);
        let rel = (analytic - finite_difference).abs() / scale;
        assert!(
            rel < 1.0e-5,
            "{label}: analytic={analytic:.12e} fd={finite_difference:.12e} rel={rel:.3e}"
        );
    }

    /// After six single-step warm-up fits, the gradient entries of the
    /// assembled arrow-Schur system (`rows[row].gt[0]` for coordinates, `gb`
    /// for decoder coefficients) must match central finite differences of
    /// `penalized_objective_total`.
    #[test]
    pub(crate) fn euclidean_line_decoder_gradient_matches_penalized_objective_fd() {
        // Writes `value` into coordinate (row, 0) of atom 0, pushes the flattened
        // matrix back into the assignment, and refreshes the basis;
        // `refresh_msg` is the panic message used if the refresh fails.
        fn overwrite_coord(
            term: &mut SaeManifoldTerm,
            row: usize,
            value: f64,
            refresh_msg: &str,
        ) {
            let mut matrix = term.assignment.coords[0].as_matrix();
            matrix[[row, 0]] = value;
            let flat = Array1::from_iter(matrix.iter().copied());
            term.assignment.coords[0].set_flat(flat.view());
            term.refresh_basis_from_current_coords().expect(refresh_msg);
        }

        let (mut term, z, mut rho) = euclidean_line_contract_fixture();
        let ridge = 1.0e-6;

        // Warm up away from the zero-decoder start.
        for step in 0..6 {
            let loss =
                match term.run_joint_fit_arrow_schur(z.view(), &mut rho, None, 1, 1.0, ridge, ridge)
                {
                    Ok(loss) => loss,
                    Err(err) => panic!("warm step {step} failed: {err}"),
                };
            assert!(
                loss.total().is_finite(),
                "warm step {step} loss is non-finite"
            );
        }

        let sys_coord = term
            .assemble_arrow_schur(z.view(), &rho, None)
            .expect("coord assemble");
        assert_eq!(
            sys_coord.k,
            term.beta_dim(),
            "p=8 contract fixture must stay on full-B coordinates"
        );
        assert!(
            !term.frames_active(),
            "p=8 contract fixture must not activate a frame"
        );

        let h = 1.0e-6;
        let central_span = 2.0 * h;

        // Coordinate gradient: perturb one row's coordinate, evaluate, restore.
        for row in [3usize, 75, 140] {
            let analytic = sys_coord.rows[row].gt[0];
            let base_coord = term.assignment.coords[0].as_matrix()[[row, 0]];

            overwrite_coord(&mut term, row, base_coord + h, "plus refresh");
            let f_plus = term
                .penalized_objective_total(z.view(), &rho, None, 1.0)
                .expect("coord f+");

            overwrite_coord(&mut term, row, base_coord - h, "minus refresh");
            let f_minus = term
                .penalized_objective_total(z.view(), &rho, None, 1.0)
                .expect("coord f-");

            overwrite_coord(&mut term, row, base_coord, "restore refresh");

            let fd = (f_plus - f_minus) / central_span;
            assert_contract_close(&format!("CONTRACT coord row {row}"), analytic, fd);
        }

        // Decoder gradient: perturb one coefficient, evaluate, restore.
        let sys_decoder = term
            .assemble_arrow_schur(z.view(), &rho, None)
            .expect("decoder assemble");
        assert_eq!(sys_decoder.k, term.beta_dim());
        let p = term.output_dim();
        for (basis_col, out_col) in [(0usize, 0usize), (1, 3), (2, 7)] {
            // `gb` is read at basis_col * p + out_col.
            let analytic = sys_decoder.gb[basis_col * p + out_col];
            let entry = [basis_col, out_col];
            let base = term.atoms[0].decoder_coefficients[entry];

            term.atoms[0].decoder_coefficients[entry] = base + h;
            let f_plus = term
                .penalized_objective_total(z.view(), &rho, None, 1.0)
                .expect("decoder f+");

            term.atoms[0].decoder_coefficients[entry] = base - h;
            let f_minus = term
                .penalized_objective_total(z.view(), &rho, None, 1.0)
                .expect("decoder f-");

            term.atoms[0].decoder_coefficients[entry] = base;

            let fd = (f_plus - f_minus) / central_span;
            assert_contract_close(
                &format!("CONTRACT decoder ({basis_col},{out_col})"),
                analytic,
                fd,
            );
        }
    }
}