use super::tests::{
TestPeriodicEvaluator, diagonal_latent_cache, periodic_basis,
warmstart_test_objective_with_evaluator,
};
use super::*;
use crate::assignment::{AssignmentMode, SaeAssignment};
use gam_terms::latent::LatentManifold;
use approx::assert_abs_diff_eq;
use ndarray::array;
use std::sync::Arc;
/// Builds the deterministic single-atom Euclidean-line fixture used by the
/// contract tests.
///
/// Returns `(term, z, rho)`:
/// * `term` holds one `EuclideanPatch` atom named `"contract-line"` whose
///   decoder coefficients start at zero, paired with a softmax assignment over
///   a single Euclidean coordinate block.
/// * `z` is a `150 x 8` synthetic target: a per-column offset, plus a linear
///   trend in the row position, plus a small deterministic sin/cos wobble.
/// * `rho` is the initial hyper-parameter state handed to the fit.
pub(crate) fn euclidean_line_contract_fixture() -> (SaeManifoldTerm, Array2<f64>, SaeManifoldRho) {
    let sample_count = 150usize;
    let output_dim = 8usize;

    // Row index mapped linearly onto [-1, 1] (row 0 -> -1, last row -> +1).
    let unit_position = |row: usize| -1.0 + 2.0 * row as f64 / (sample_count as f64 - 1.0);

    // Latent coordinate per row: 2.5 + 3u, i.e. the interval [-0.5, 5.5].
    let coords = Array2::<f64>::from_shape_fn((sample_count, 1), |(row, _)| {
        2.5 + 3.0 * unit_position(row)
    });

    // Target matrix: offset + loading * u + deterministic pseudo-noise.
    let z = Array2::<f64>::from_shape_fn((sample_count, output_dim), |(row, col)| {
        let u = unit_position(row);
        let linear_loading = 0.35 + 0.07 * col as f64;
        let offset = 0.08 * ((col % 3) as f64 - 1.0);
        let phase = (row * (col + 3)) as f64;
        let noise = 0.04 * (phase.sin() + 0.5 * (0.37 * phase).cos());
        offset + linear_loading * u + noise
    });

    // NOTE(review): the `(1, 2)` arguments are presumably (latent dim, degree) — confirm.
    let evaluator = Arc::new(EuclideanPatchEvaluator::new(1, 2).expect("evaluator"));
    let (phi, jet) = evaluator.evaluate(coords.view()).expect("basis");
    let basis_width = phi.ncols();
    let smooth_penalty =
        gam_terms::basis::create_difference_penalty_matrix(basis_width, 2, None).expect("penalty");

    let zero_decoder = Array2::<f64>::zeros((basis_width, output_dim));
    let atom = SaeManifoldAtom::new(
        "contract-line",
        SaeAtomBasisKind::EuclideanPatch,
        1,
        phi,
        jet,
        zero_decoder,
        smooth_penalty,
    )
    .expect("atom")
    .with_basis_second_jet(evaluator);

    let zero_logits = Array2::<f64>::zeros((sample_count, 1));
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        zero_logits,
        vec![coords],
        vec![LatentManifold::Euclidean],
        AssignmentMode::softmax(1.0),
    )
    .expect("assignment");

    let term = SaeManifoldTerm::new(vec![atom], assignment).expect("term");
    let rho = SaeManifoldRho::new(0.0, (0.01_f64).ln(), vec![Array1::<f64>::zeros(1)]);
    (term, z, rho)
}
/// Panics unless `analytic` and `finite_difference` agree to a relative
/// tolerance of `1e-5`.
///
/// The relative error is the absolute difference divided by the larger of the
/// two magnitudes, floored at `1e-12` so that two exact zeros compare equal
/// instead of dividing by zero. `label` prefixes the failure message. A NaN on
/// either side makes the comparison false and therefore fails the assertion.
pub(crate) fn assert_contract_close(label: &str, analytic: f64, finite_difference: f64) {
    let scale = finite_difference
        .abs()
        .max(analytic.abs())
        .max(1.0e-12);
    let abs_gap = (analytic - finite_difference).abs();
    let rel = abs_gap / scale;
    assert!(
        rel < 1.0e-5,
        "{label}: analytic={analytic:.12e} fd={finite_difference:.12e} rel={rel:.3e}"
    );
}
/// Finite-difference contract check for the gradient that
/// `assemble_arrow_schur` reports on the Euclidean line fixture.
///
/// After six warm-up fit calls, the analytic per-row coordinate gradient
/// (`rows[row].gt[0]`) and the analytic decoder gradient (`gb`) are each
/// compared against a central difference of `penalized_objective_total`,
/// using `assert_contract_close` for the comparison.
#[test]
pub(crate) fn euclidean_line_decoder_gradient_matches_penalized_objective_fd() {
let (mut term, z, mut rho) = euclidean_line_contract_fixture();
let ridge = 1.0e-6;
// Warm-up: six separate fit calls so the gradient is probed away from the
// all-zero decoder start. Each call must succeed and report a finite loss.
// NOTE(review): the literal `1` is presumably the per-call iteration cap — confirm.
for step in 0..6 {
let loss = term
.run_joint_fit_arrow_schur(z.view(), &mut rho, None, 1, 1.0, ridge, ridge)
.unwrap_or_else(|err| panic!("warm step {step} failed: {err}"));
assert!(
loss.total().is_finite(),
"warm step {step} loss is non-finite"
);
}
// Analytic system at the warmed state; the coordinate gradients are read from it.
let sys_coord = term
.assemble_arrow_schur(z.view(), &rho, None)
.expect("coord assemble");
// Guard the fixture's parameterisation: the system width must equal the full
// beta dimension and no frame may be active.
assert_eq!(
sys_coord.k,
term.beta_dim(),
"p=8 contract fixture must stay on full-B coordinates"
);
assert!(
!term.frames_active(),
"p=8 contract fixture must not activate a frame"
);
// Central-difference half-step shared by the coordinate and decoder probes.
let h = 1.0e-6;
// Coordinate gradient: probe three rows (near the start, middle, and end).
for row in [3usize, 75, 140] {
let analytic = sys_coord.rows[row].gt[0];
let base_coord = term.assignment.coords[0].as_matrix()[[row, 0]];
// f(+h): overwrite the single coordinate, refresh the basis, evaluate.
let mut plus_coords = term.assignment.coords[0].as_matrix();
plus_coords[[row, 0]] = base_coord + h;
let plus_flat = Array1::from_iter(plus_coords.iter().copied());
term.assignment.coords[0].set_flat(plus_flat.view());
term.refresh_basis_from_current_coords()
.expect("plus refresh");
let f_plus = term
.penalized_objective_total(z.view(), &rho, None, 1.0)
.expect("coord f+");
// f(-h): the matrix is re-read from the (currently +h) state, then the
// probed entry is overwritten with base - h.
let mut minus_coords = term.assignment.coords[0].as_matrix();
minus_coords[[row, 0]] = base_coord - h;
let minus_flat = Array1::from_iter(minus_coords.iter().copied());
term.assignment.coords[0].set_flat(minus_flat.view());
term.refresh_basis_from_current_coords()
.expect("minus refresh");
let f_minus = term
.penalized_objective_total(z.view(), &rho, None, 1.0)
.expect("coord f-");
// Restore the original coordinate and basis so the next probe (and the
// decoder section below) starts from the unperturbed state.
let mut restored_coords = term.assignment.coords[0].as_matrix();
restored_coords[[row, 0]] = base_coord;
let restored_flat = Array1::from_iter(restored_coords.iter().copied());
term.assignment.coords[0].set_flat(restored_flat.view());
term.refresh_basis_from_current_coords()
.expect("restore refresh");
let fd = (f_plus - f_minus) / (2.0 * h);
assert_contract_close(&format!("CONTRACT coord row {row}"), analytic, fd);
}
// Decoder gradient: re-assemble after the restores and read `gb`.
let sys_decoder = term
.assemble_arrow_schur(z.view(), &rho, None)
.expect("decoder assemble");
assert_eq!(sys_decoder.k, term.beta_dim());
let p = term.output_dim();
// Probe three (basis column, output column) decoder entries.
for (basis_col, out_col) in [(0usize, 0usize), (1, 3), (2, 7)] {
// Flat index into `gb`: basis-major, output-minor.
let beta_idx = basis_col * p + out_col;
let analytic = sys_decoder.gb[beta_idx];
let base = term.atoms[0].decoder_coefficients[[basis_col, out_col]];
term.atoms[0].decoder_coefficients[[basis_col, out_col]] = base + h;
let f_plus = term
.penalized_objective_total(z.view(), &rho, None, 1.0)
.expect("decoder f+");
term.atoms[0].decoder_coefficients[[basis_col, out_col]] = base - h;
let f_minus = term
.penalized_objective_total(z.view(), &rho, None, 1.0)
.expect("decoder f-");
// Put the coefficient back before probing the next entry.
term.atoms[0].decoder_coefficients[[basis_col, out_col]] = base;
let fd = (f_plus - f_minus) / (2.0 * h);
assert_contract_close(
&format!("CONTRACT decoder ({basis_col},{out_col})"),
analytic,
fd,
);
}
}
/// End-to-end check of the co-trained criterion and the amortized encoder on a
/// noise-free planted periodic target (`target = phi * decoder`).
///
/// Three things are asserted in sequence:
/// 1. the co-trained criterion is finite and not below the REML criterion, and
///    its consistency report is well-formed;
/// 2. on in-sample rows that both the amortized and the per-row encode certify,
///    the two reconstructions differ by less than `1e-2`;
/// 3. on held-out coordinates, the amortized encode agrees with the per-row
///    encode and is no further from the true coordinate (within `1e-2`).
#[test]
fn cotrained_criterion_folds_faithful_amortized_encoder_on_known_manifold() {
// Planted problem: 24 evenly spaced coordinates on the unit-period circle and
// a fixed cosine-pattern decoder with 4 output columns.
let n = 24usize;
let p = 4usize;
let coords = Array2::from_shape_fn((n, 1), |(row, _)| (row as f64 + 0.5) / n as f64);
let (phi, jet) = periodic_basis(&coords);
let m = phi.ncols();
let decoder = Array2::from_shape_fn((m, p), |(b, c)| {
let scale = 1.0 / (1.0 + b as f64);
scale * ((b as f64 + 1.0) * (c as f64 + 1.0)).cos()
});
let atom = SaeManifoldAtom::new(
"periodic_truth",
SaeAtomBasisKind::Periodic,
1,
phi.clone(),
jet,
decoder.clone(),
Array2::<f64>::eye(m),
)
.unwrap()
.with_basis_evaluator(Arc::new(TestPeriodicEvaluator));
// The target lies exactly on the planted manifold (no noise added).
let target = phi.dot(&decoder);
let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
Array2::<f64>::zeros((n, 1)),
vec![coords],
vec![LatentManifold::Circle { period: 1.0 }],
AssignmentMode::softmax(1.0),
)
.unwrap();
let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();
let rho = SaeManifoldRho::new(0.0, 0.8_f64.ln(), vec![array![1.0_f64.ln()]]);
// Fit on a clone so `rho` itself keeps the initial values.
let mut rho_fit = rho.clone();
term.run_joint_fit_arrow_schur(target.view(), &mut rho_fit, None, 12, 1.0, 1.0e-4, 1.0e-4)
.expect("inner solve converges on the known periodic manifold");
// Reference criterion, evaluated at the fitted rho.
// NOTE(review): the trailing `true` presumably selects the refine policy — confirm.
let (reml, _loss) = term
.reml_criterion_with_refine_policy(
target.view(),
&rho_fit,
None,
25,
1.0,
1.0e-4,
1.0e-4,
true,
)
.expect("REML criterion evaluates");
// Co-trained criterion at the same rho; also returns a consistency report.
let (cotrained, _loss2, consistency) = term
.reml_criterion_cotrained(target.view(), &rho_fit, None, 64, 1.0, 1.0e-4, 1.0e-4)
.expect("co-trained criterion evaluates");
assert!(
cotrained.is_finite() && reml.is_finite(),
"both criteria must be finite: cotrained={cotrained}, reml={reml}"
);
// The co-trained value may exceed REML but must not fall below it (1e-9 slack).
assert!(
cotrained >= reml - 1.0e-9,
"co-trained criterion must add a NON-NEGATIVE consistency penalty: \
cotrained={cotrained} < reml={reml}"
);
assert!(
consistency.recon_consistency >= 0.0 && consistency.recon_consistency.is_finite(),
"recon consistency must be a finite non-negative gap, got {}",
consistency.recon_consistency
);
assert!(
(0.0..=1.0).contains(&consistency.uncertified_fraction),
"uncertified fraction must be a probability, got {}",
consistency.uncertified_fraction
);
assert!(
consistency.uncertified_fraction < 1.0,
"the amortized encoder must certify at least some rows of a \
well-conditioned periodic dictionary; uncertified_fraction={}",
consistency.uncertified_fraction
);
// --- In-sample faithfulness: amortized encode vs per-row certified encode ---
let amplitudes = term.fitted_assignment_amplitudes(&rho_fit).unwrap();
let encodes = term
.amortized_encode_target(target.view(), amplitudes.view())
.expect("amortized encode runs");
let atom0 = &term.atoms[0];
let evaluator = atom0.basis_evaluator.as_ref().unwrap();
// Decode the amortized coordinates of atom 0 back through the fitted decoder.
let (phi_hat, _j) = evaluator.evaluate(encodes[0].coords.view()).unwrap();
let decoded_hat = phi_hat.dot(&atom0.decoder_coefficients);
// Largest Euclidean row norm of the target, passed to the atlas builder.
let mut in_sample_norm_bound = 0.0_f64;
for row in 0..n {
in_sample_norm_bound =
in_sample_norm_bound.max(target.row(row).dot(&target.row(row)).sqrt());
}
let in_sample_atlas = crate::encode::EncodeAtlas::build(
&term.atoms,
&[1.0],
in_sample_norm_bound,
crate::encode::AtlasConfig::default(),
)
.expect("in-sample encode atlas builds");
let mut certified_rows = 0usize;
let mut max_certified_gap = 0.0_f64;
for row in 0..n {
// Only rows certified by BOTH encoders contribute to the gap.
if !encodes[0].certified[row] {
continue;
}
let z = amplitudes[[row, 0]];
let (exact_t, exact_cert) = in_sample_atlas
.certified_encode_row(atom0, 0, target.row(row), z)
.expect("exact per-row encode runs");
if !exact_cert.certified() {
continue;
}
certified_rows += 1;
let exact_phi = evaluator
.evaluate(exact_t.view().insert_axis(ndarray::Axis(0)))
.unwrap()
.0;
// Compare amplitude-scaled reconstructions column by column.
let exact_decoded = exact_phi.dot(&atom0.decoder_coefficients); for col in 0..p {
let amortized = z * decoded_hat[[row, col]];
let exact = z * exact_decoded[[0, col]];
let gap = (amortized - exact).abs();
if gap > max_certified_gap {
max_certified_gap = gap;
}
}
}
assert!(
certified_rows > 0,
"the certificate must accept at least one row to measure faithfulness"
);
assert!(
max_certified_gap < 1.0e-2,
"amortized encode must reconstruct certified rows within the encode \
tolerance of the exact per-row encode-by-inner-solve; max gap={max_certified_gap}"
);
// --- Held-out check: 12 coordinates offset from the training grid ---
let n_holdout = 12usize;
let heldout_coords = Array2::from_shape_fn((n_holdout, 1), |(row, _)| {
(row as f64 + 0.25) / n_holdout as f64
});
let (heldout_phi, _heldout_jet) = periodic_basis(&heldout_coords);
// Held-out targets are generated from the FITTED decoder, at unit amplitude.
let heldout = heldout_phi.dot(&atom0.decoder_coefficients);
let heldout_amplitudes = Array1::<f64>::ones(n_holdout);
let mut target_norm_bound = 0.0_f64;
for row in 0..n_holdout {
target_norm_bound = target_norm_bound.max(heldout.row(row).dot(&heldout.row(row)).sqrt());
}
let atlas = crate::encode::EncodeAtlas::build(
&term.atoms,
&[1.0],
target_norm_bound,
crate::encode::AtlasConfig::default(),
)
.expect("held-out encode atlas builds");
let fast_heldout = atlas
.amortized_encode_batch(atom0, 0, heldout.view(), heldout_amplitudes.view())
.expect("held-out amortized encode runs");
let mut max_fast_vs_exact = 0.0_f64;
let mut max_fast_truth = 0.0_f64;
let mut max_exact_truth = 0.0_f64;
let mut heldout_certified = 0usize;
for row in 0..n_holdout {
// Rows the amortized encoder did not certify are skipped entirely.
if !fast_heldout.certified[row] {
continue;
}
heldout_certified += 1;
let (exact_t, exact_cert) = atlas
.certified_encode_row(atom0, 0, heldout.row(row), 1.0)
.expect("held-out exact certified row encode runs");
assert!(
exact_cert.certified(),
"sequential exact #1010 teacher must certify held-out row {row}"
);
// Track the worst circular distance between each pair of coordinates.
let truth = heldout_coords[[row, 0]];
let fast = fast_heldout.coords[[row, 0]];
let exact = exact_t[0];
let fast_vs_exact = circle_phase_gap(fast, exact);
let fast_truth = circle_phase_gap(fast, truth);
let exact_truth = circle_phase_gap(exact, truth);
max_fast_vs_exact = max_fast_vs_exact.max(fast_vs_exact);
max_fast_truth = max_fast_truth.max(fast_truth);
max_exact_truth = max_exact_truth.max(exact_truth);
}
eprintln!(
"#1154 AMORTIZED-VS-EXACT: held-out certified={heldout_certified} \
| max fast-vs-exact #1010 phase gap={max_fast_vs_exact:.6e} \
| max fast-vs-truth={max_fast_truth:.6e} | max exact-vs-truth={max_exact_truth:.6e}"
);
assert!(
heldout_certified > 0,
"fast amortized encode must certify held-out rows on the known manifold"
);
assert!(
max_fast_vs_exact < 1.0e-2,
"fast amortized held-out encode must match exact #1010 encode within \
certified tolerance; max phase gap={max_fast_vs_exact}"
);
assert!(
max_fast_truth <= max_exact_truth + 1.0e-2,
"co-trained fast encoder must recover the known held-out manifold at \
least as well as the sequential exact-teacher path within tolerance; \
fast={max_fast_truth}, sequential={max_exact_truth}"
);
}
/// Shortest distance between two phases on a circle of period `1.0`.
///
/// For finite inputs the result lies in `[0.0, 0.5]`: the fractional part of
/// `|a - b|` or its complement, whichever is smaller.
///
/// If either input is NaN or infinite the gap is reported as
/// `f64::INFINITY`. Callers accumulate the worst gap with `f64::max`, which
/// silently discards NaN operands; returning `+inf` instead makes a broken
/// coordinate show up in the maximum and fail the tolerance assertions rather
/// than pass them vacuously.
fn circle_phase_gap(a: f64, b: f64) -> f64 {
    let raw = (a - b).abs();
    if !raw.is_finite() {
        return f64::INFINITY;
    }
    // `raw` is non-negative, so `fract()` is already the wrapped distance
    // measured in one direction; the complement is the other direction.
    let wrapped = raw.fract();
    wrapped.min(1.0 - wrapped)
}
/// Checks that the co-training fold shows up on the value-probe lane
/// (`eval_cost`) but not on the gradient lane (`eval`).
///
/// Each lane is compared with a "bare" reference rebuilt on a fresh objective:
/// amortized warm start, a REML criterion call, then the fit-data collapse
/// penalty. The value lane must exceed its reference by a positive margin; the
/// gradient lane must equal its reference to within `1e-9`.
#[test]
fn cotrain_fold_is_value_lane_only_so_gradient_lane_pair_is_consistent() {
    // Value-probe lane, evaluated at the objective's current rho.
    let mut value_objective = warmstart_test_objective_with_evaluator();
    let rho_flat = value_objective.current_rho.to_flat();
    let value_lane = value_objective
        .eval_cost(&rho_flat)
        .expect("value-probe lane evaluates the co-trained cost");

    // Gradient lane on an independent objective at the same rho.
    let mut gradient_objective = warmstart_test_objective_with_evaluator();
    let gradient_lane = gradient_objective
        .eval(&rho_flat)
        .expect("gradient lane evaluates")
        .cost;

    assert!(
        value_lane.is_finite() && gradient_lane.is_finite(),
        "both lanes must be finite: value={value_lane}, gradient={gradient_lane}"
    );

    // Bare reference for the value lane (refine-policy criterion, flag `false`).
    let bare_value = {
        let mut reference = warmstart_test_objective_with_evaluator();
        let target = reference.target.clone();
        let rho_state = reference.baseline_rho.from_flat(rho_flat.view());
        let max_iter = reference.inner_max_iter;
        let step_size = reference.learning_rate;
        let coord_ridge = reference.ridge_ext_coord;
        let beta_ridge = reference.ridge_beta;
        // Best-effort warm start: a failure here is deliberately ignored.
        let _ = reference
            .term
            .warm_start_latents_from_amortized_encoder(target.view(), &rho_state);
        let (reml, _loss) = reference
            .term
            .reml_criterion_with_refine_policy(
                target.view(),
                &rho_state,
                None,
                max_iter,
                step_size,
                coord_ridge,
                beta_ridge,
                false,
            )
            .expect("bare value-lane REML criterion evaluates");
        reference
            .add_fit_data_collapse_penalty(reml, &rho_state)
            .expect("collapse penalty evaluates")
    };
    let value_fold = value_lane - bare_value;
    assert!(
        value_fold > 1.0e-12,
        "the value-probe lane carries the co-training fold (positive penalty \
         over bare REML): value_lane={value_lane}, bare={bare_value}, \
         fold={value_fold}"
    );

    // Bare reference for the gradient lane (cache-returning criterion).
    let bare_grad = {
        let mut reference = warmstart_test_objective_with_evaluator();
        let target = reference.target.clone();
        let rho_state = reference.baseline_rho.from_flat(rho_flat.view());
        let max_iter = reference.inner_max_iter;
        let step_size = reference.learning_rate;
        let coord_ridge = reference.ridge_ext_coord;
        let beta_ridge = reference.ridge_beta;
        // Best-effort warm start: a failure here is deliberately ignored.
        let _ = reference
            .term
            .warm_start_latents_from_amortized_encoder(target.view(), &rho_state);
        let (reml, _loss, _cache) = reference
            .term
            .reml_criterion_with_cache(
                target.view(),
                &rho_state,
                None,
                max_iter,
                step_size,
                coord_ridge,
                beta_ridge,
            )
            .expect("bare gradient-lane REML criterion evaluates");
        reference
            .add_fit_data_collapse_penalty(reml, &rho_state)
            .expect("collapse penalty evaluates")
    };
    let gradient_vs_bare = (gradient_lane - bare_grad).abs();
    assert!(
        gradient_vs_bare < 1.0e-9,
        "the gradient lane must report bare REML (no consistency fold), so its \
         (cost, ∇f) pair is self-consistent for BFGS Armijo: \
         gradient_lane={gradient_lane}, bare_grad={bare_grad}, \
         diff={gradient_vs_bare}"
    );
}
/// Compares a cold inner solve with one re-run after an amortized-encoder warm
/// start, on a noise-free planted periodic target.
///
/// The cold fit must reach an explained variance above `0.9`, and the
/// warm-started fit must not fall below the cold fit by more than `1e-6`.
#[test]
fn amortized_warm_start_matches_or_beats_cold_inner_solve_on_known_manifold() {
    let n = 24usize;
    let p = 4usize;

    // Planted circle: evenly spaced coordinates and a fixed cosine-pattern decoder.
    let planted_coords =
        Array2::from_shape_fn((n, 1), |(row, _)| (row as f64 + 0.5) / n as f64);
    let (phi, jet) = periodic_basis(&planted_coords);
    let m = phi.ncols();
    let decoder = Array2::from_shape_fn((m, p), |(b, c)| {
        let scale = 1.0 / (1.0 + b as f64);
        scale * ((b as f64 + 1.0) * (c as f64 + 1.0)).cos()
    });

    let atom = SaeManifoldAtom::new(
        "periodic_truth",
        SaeAtomBasisKind::Periodic,
        1,
        phi.clone(),
        jet,
        decoder.clone(),
        Array2::<f64>::eye(m),
    )
    .unwrap()
    .with_basis_evaluator(Arc::new(TestPeriodicEvaluator));
    let target = phi.dot(&decoder);

    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((n, 1)),
        vec![planted_coords],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::softmax(1.0),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();
    let initial_rho = SaeManifoldRho::new(0.0, 0.8_f64.ln(), vec![array![1.0_f64.ln()]]);

    // Cold solve from the initial rho.
    let mut rho_cold = initial_rho.clone();
    term.run_joint_fit_arrow_schur(target.view(), &mut rho_cold, None, 12, 0.1, 1.0e-4, 1.0e-4)
        .expect("cold inner solve converges on the known periodic manifold");
    let cold_fitted = term.try_fitted_for_rho(&rho_cold).unwrap();
    let cold_ev = reconstruction_explained_variance(target.view(), cold_fitted.view())
        .expect("explained variance is defined for the planted target");
    assert!(
        cold_ev > 0.9,
        "cold fit must recover the planted periodic manifold (EV={cold_ev})"
    );

    // Warm start the latents of the already-fitted term.
    let warm_started = term
        .warm_start_latents_from_amortized_encoder(target.view(), &rho_cold)
        .expect("amortized warm-start runs on the fitted dictionary");
    eprintln!("#1154 WARM-START: certified warm-started rows={warm_started}/{n}");
    assert!(
        warm_started <= n,
        "the amortized encoder cannot warm-start more rows than the fitted \
         batch size; warm_started={warm_started}, n={n}"
    );

    // Second solve on the same term, again starting from the initial rho.
    let mut rho_warm = initial_rho.clone();
    term.run_joint_fit_arrow_schur(target.view(), &mut rho_warm, None, 12, 0.1, 1.0e-4, 1.0e-4)
        .expect("warm-started inner solve converges");
    let warm_fitted = term.try_fitted_for_rho(&rho_warm).unwrap();
    let warm_ev = reconstruction_explained_variance(target.view(), warm_fitted.view())
        .expect("explained variance is defined for the planted target");
    assert!(
        warm_ev >= cold_ev - 1.0e-6,
        "amortized warm-start (co-trained inner solve) must recover the manifold \
         at least as well as the cold/sequential solve: warm_ev={warm_ev}, \
         cold_ev={cold_ev}"
    );
}
/// Compares two ways of picking `rho` from a three-point grid on a planted
/// periodic target, then checks held-out coordinate recovery for each.
///
/// * Sequential: choose the grid point with the smallest `reml_criterion`,
///   then run a cold inner solve.
/// * Co-trained: choose the grid point with the smallest
///   `reml_criterion_cotrained` (after a best-effort amortized warm start),
///   then run a warm-started inner solve.
///
/// The co-trained dictionary's worst held-out phase gap must not exceed the
/// sequential one by more than `1e-3`, and the amortized encoder must certify
/// at least one held-out row for both dictionaries.
#[test]
fn cotrained_encoder_recovers_planted_manifold_at_least_as_well_as_sequential() {
// Planted problem: 32 evenly spaced coordinates, 4 outputs, noise-free target.
let n = 32usize;
let p = 4usize;
let coords = Array2::from_shape_fn((n, 1), |(row, _)| (row as f64 + 0.5) / n as f64);
let (phi, jet) = periodic_basis(&coords);
let m = phi.ncols();
let decoder = Array2::from_shape_fn((m, p), |(b, c)| {
let scale = 1.0 / (1.0 + b as f64);
scale * ((b as f64 + 1.0) * (c as f64 + 1.0)).cos()
});
let target = phi.dot(&decoder);
// Candidate rho grid. The first tuple entry is passed through as-is; the
// second is passed as its natural log.
let rho_grid: Vec<SaeManifoldRho> = [(-0.5_f64, 0.4_f64), (0.0, 0.8), (0.3, 1.2)]
.iter()
.map(|&(ls, lsm)| SaeManifoldRho::new(ls, lsm.ln(), vec![array![1.0_f64.ln()]]))
.collect();
// Fresh term factory: every probe and every final fit starts from the same
// planted atom and assignment.
let build_term = || {
let atom = SaeManifoldAtom::new(
"periodic_truth",
SaeAtomBasisKind::Periodic,
1,
phi.clone(),
jet.clone(),
decoder.clone(),
Array2::<f64>::eye(m),
)
.unwrap()
.with_basis_evaluator(Arc::new(TestPeriodicEvaluator));
let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
Array2::<f64>::zeros((n, 1)),
vec![coords.clone()],
vec![LatentManifold::Circle { period: 1.0 }],
AssignmentMode::softmax(1.0),
)
.unwrap();
SaeManifoldTerm::new(vec![atom], assignment).unwrap()
};
// Held-out truth: 16 coordinates offset from the training grid.
let n_holdout = 16usize;
let heldout_truth = Array2::from_shape_fn((n_holdout, 1), |(row, _)| {
(row as f64 + 0.25) / n_holdout as f64
});
let (heldout_phi, _hjet) = periodic_basis(&heldout_truth);
// For a fitted term, returns (worst per-row-encode phase gap to the truth,
// number of rows the amortized batch encode certified).
let heldout_recovery_gap = |term: &SaeManifoldTerm| -> (f64, usize) {
let atom0 = &term.atoms[0];
// Held-out targets come from the term's own fitted decoder, unit amplitude.
let heldout = heldout_phi.dot(&atom0.decoder_coefficients);
let amps = Array1::<f64>::ones(n_holdout);
let mut norm_bound = 0.0_f64;
for row in 0..n_holdout {
norm_bound = norm_bound.max(heldout.row(row).dot(&heldout.row(row)).sqrt());
}
let atlas = crate::encode::EncodeAtlas::build(
&term.atoms,
&[1.0],
norm_bound,
crate::encode::AtlasConfig::default(),
)
.expect("held-out encode atlas builds");
let encoded = atlas
.amortized_encode_batch(atom0, 0, heldout.view(), amps.view())
.expect("held-out amortized encode runs");
let mut max_gap = 0.0_f64;
let mut amortized_certified = 0usize;
for row in 0..n_holdout {
if encoded.certified[row] {
amortized_certified += 1;
}
// The gap is measured on the per-row encode for EVERY row, whether or not
// its certificate passed (the certificate is discarded here).
let (coord, _cert) = atlas
.certified_encode_row(atom0, 0, heldout.row(row), amps[row])
.expect("held-out exact encode converges");
let gap = circle_phase_gap(coord[0], heldout_truth[[row, 0]]);
max_gap = max_gap.max(gap);
}
(max_gap, amortized_certified)
};
// --- Sequential path: pick rho by the REML criterion ---
let mut best_seq_rho = rho_grid[0].clone();
let mut best_seq_cost = f64::INFINITY;
for rho in &rho_grid {
let mut probe = build_term();
// Grid points whose criterion call errors are skipped, not fatal.
let Ok((reml, _loss)) =
probe.reml_criterion(target.view(), rho, None, 12, 1.0, 1.0e-4, 1.0e-4)
else {
continue;
};
if reml < best_seq_cost {
best_seq_cost = reml;
best_seq_rho = rho.clone();
}
}
// Still infinite means no grid point produced a usable criterion value.
assert!(
best_seq_cost.is_finite(),
"the sequential grid must contain at least one converged bare-REML candidate"
);
let mut seq_term = build_term();
let mut seq_rho = best_seq_rho.clone();
seq_term
.run_joint_fit_arrow_schur(target.view(), &mut seq_rho, None, 12, 1.0, 1.0e-4, 1.0e-4)
.expect("sequential cold inner solve converges");
let (seq_gap, seq_certified) = heldout_recovery_gap(&seq_term);
// --- Co-trained path: pick rho by the co-trained criterion ---
let mut best_cot_rho = rho_grid[0].clone();
let mut best_cot_cost = f64::INFINITY;
for rho in &rho_grid {
let mut probe = build_term();
// Best-effort warm start; its result is intentionally dropped.
probe
.warm_start_latents_from_amortized_encoder(target.view(), rho)
.ok();
let Ok((cotrained, _loss, _consistency)) =
probe.reml_criterion_cotrained(target.view(), rho, None, 64, 1.0, 1.0e-4, 1.0e-4)
else {
continue;
};
if cotrained < best_cot_cost {
best_cot_cost = cotrained;
best_cot_rho = rho.clone();
}
}
assert!(
best_cot_cost.is_finite(),
"the co-trained grid must contain at least one converged candidate"
);
let mut cot_term = build_term();
let mut cot_rho = best_cot_rho.clone();
cot_term
.warm_start_latents_from_amortized_encoder(target.view(), &cot_rho)
.ok();
cot_term
.run_joint_fit_arrow_schur(target.view(), &mut cot_rho, None, 64, 1.0, 1.0e-4, 1.0e-4)
.expect("co-trained warm-started inner solve converges");
let (cot_gap, cot_certified) = heldout_recovery_gap(&cot_term);
eprintln!(
"#1154 RECOVERY: sequential max-phase-gap={seq_gap:.6e} (certified={seq_certified}) \
| co-trained max-phase-gap={cot_gap:.6e} (certified={cot_certified}) \
| delta(cot-seq)={:.6e}",
cot_gap - seq_gap
);
assert!(
seq_gap.is_finite() && cot_gap.is_finite(),
"both dictionaries' exact held-out recovery gaps must be finite: \
sequential={seq_gap}, co-trained={cot_gap}"
);
assert!(
cot_gap <= seq_gap + 1.0e-3,
"co-trained dictionary must recover the planted held-out manifold at \
least as well as the sequential REML-then-distill path: \
co-trained max phase gap={cot_gap}, sequential={seq_gap} \
(amortized-certified rows: co-trained={cot_certified}, \
sequential={seq_certified})"
);
assert!(
seq_certified > 0 && cot_certified > 0,
"the amortized encoder must certify at least one unit-amplitude held-out row \
on this clean planted circle (basin warm-up closed the #1154 certifies-zero \
gap); got sequential={seq_certified}, co-trained={cot_certified} of {n_holdout}"
);
}
/// Fits the same planted-circle target with a periodic atom and with a
/// Euclidean-patch atom, and requires the periodic fit to win clearly.
///
/// The periodic fit must reach explained variance above `0.9` and beat the
/// Euclidean-patch fit by more than `0.2`.
#[test]
fn sae_1026_curved_beats_linear_reconstruction_through_solver() {
    let n = 48usize;
    let p = 4usize;

    // Planted circle target, generated from the periodic basis.
    let coords = Array2::from_shape_fn((n, 1), |(row, _)| (row as f64 + 0.5) / n as f64);
    let (phi_c, jet_c) = periodic_basis(&coords);
    let mc = phi_c.ncols();
    let decoder_c = Array2::from_shape_fn((mc, p), |(b, c)| {
        (1.0 / (1.0 + b as f64)) * ((b as f64 + 1.0) * (c as f64 + 1.0)).cos()
    });
    let target = phi_c.dot(&decoder_c);

    // Periodic atom, seeded with the planted decoder, on a circle manifold.
    let curved_ev = {
        let circle_atom = SaeManifoldAtom::new(
            "circle",
            SaeAtomBasisKind::Periodic,
            1,
            phi_c.clone(),
            jet_c,
            decoder_c.clone(),
            Array2::<f64>::eye(mc),
        )
        .unwrap()
        .with_basis_evaluator(Arc::new(TestPeriodicEvaluator));
        let circle_assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
            Array2::<f64>::zeros((n, 1)),
            vec![coords.clone()],
            vec![LatentManifold::Circle { period: 1.0 }],
            AssignmentMode::softmax(1.0),
        )
        .unwrap();
        let mut circle_term = SaeManifoldTerm::new(vec![circle_atom], circle_assignment).unwrap();
        let mut circle_rho =
            SaeManifoldRho::new(0.0, 0.8_f64.ln(), vec![array![1.0_f64.ln()]]);
        circle_term
            .run_joint_fit_arrow_schur(
                target.view(),
                &mut circle_rho,
                None,
                12,
                0.1,
                1.0e-4,
                1.0e-4,
            )
            .expect("curved inner solve converges on the planted circle");
        let circle_fit = circle_term.try_fitted_for_rho(&circle_rho).unwrap();
        reconstruction_explained_variance(target.view(), circle_fit.view()).unwrap()
    };

    // Euclidean-patch atom with a zero decoder, on a Euclidean manifold.
    let linear_ev = {
        let patch_evaluator = Arc::new(EuclideanPatchEvaluator::new(1, 1).unwrap());
        let (phi_l, jet_l) = patch_evaluator.evaluate(coords.view()).unwrap();
        let ml = phi_l.ncols();
        let patch_atom = SaeManifoldAtom::new(
            "linear",
            SaeAtomBasisKind::EuclideanPatch,
            1,
            phi_l,
            jet_l,
            Array2::<f64>::zeros((ml, p)),
            Array2::<f64>::eye(ml),
        )
        .unwrap()
        .with_basis_second_jet(patch_evaluator);
        let patch_assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
            Array2::<f64>::zeros((n, 1)),
            vec![coords.clone()],
            vec![LatentManifold::Euclidean],
            AssignmentMode::softmax(1.0),
        )
        .unwrap();
        let mut patch_term = SaeManifoldTerm::new(vec![patch_atom], patch_assignment).unwrap();
        let mut patch_rho =
            SaeManifoldRho::new(0.0, 0.8_f64.ln(), vec![Array1::<f64>::zeros(1)]);
        patch_term
            .run_joint_fit_arrow_schur(
                target.view(),
                &mut patch_rho,
                None,
                12,
                0.1,
                1.0e-4,
                1.0e-4,
            )
            .expect("linear inner solve converges");
        let patch_fit = patch_term.try_fitted_for_rho(&patch_rho).unwrap();
        reconstruction_explained_variance(target.view(), patch_fit.view()).unwrap()
    };

    eprintln!("#1026 solver reconstruction: curved EV={curved_ev:.4}, linear EV={linear_ev:.4}");
    assert!(
        curved_ev > 0.9,
        "the periodic atom must recover the planted circle through the solver (EV={curved_ev})"
    );
    assert!(
        curved_ev > linear_ev + 0.2,
        "curved must beat the matched-K linear baseline by a wide margin (the shatter \
         penalty: a degree-1 secant cannot follow a closed circle): \
         curved={curved_ev}, linear={linear_ev}"
    );
}
/// Encodes held-out on-manifold points row by row, decodes them again through
/// the planted decoder, and checks the round trip.
///
/// The reconstruction must reach explained variance above `0.95`, and at least
/// one held-out row must come back with a passing certificate.
#[test]
fn sae_1026_full_encode_decode_heldout_curved_certifies() {
    let n = 48usize;
    let p = 4usize;

    // Planted periodic dictionary; no fit is run, the term keeps the planted decoder.
    let train_coords = Array2::from_shape_fn((n, 1), |(r, _)| (r as f64 + 0.5) / n as f64);
    let (phi, jet) = periodic_basis(&train_coords);
    let m = phi.ncols();
    let decoder = Array2::from_shape_fn((m, p), |(b, c)| {
        (1.0 / (1.0 + b as f64)) * ((b as f64 + 1.0) * (c as f64 + 1.0)).cos()
    });
    let atom = SaeManifoldAtom::new(
        "circle",
        SaeAtomBasisKind::Periodic,
        1,
        phi,
        jet,
        decoder.clone(),
        Array2::<f64>::eye(m),
    )
    .unwrap()
    .with_basis_evaluator(Arc::new(TestPeriodicEvaluator));
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((n, 1)),
        vec![train_coords],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::softmax(1.0),
    )
    .unwrap();
    let term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();

    // Held-out points at unit amplitude, offset from the training grid.
    let n_test = 32usize;
    let theta_test =
        Array2::from_shape_fn((n_test, 1), |(r, _)| (r as f64 + 0.25) / n_test as f64);
    let (phi_test, _) = periodic_basis(&theta_test);
    let z_test = phi_test.dot(&decoder);
    let amps = Array1::<f64>::ones(n_test);

    // Largest Euclidean row norm among the held-out targets.
    let norm_bound = z_test
        .outer_iter()
        .map(|row| row.dot(&row).sqrt())
        .fold(0.0_f64, f64::max);
    let atlas = crate::encode::EncodeAtlas::build(
        &term.atoms,
        &[1.0],
        norm_bound,
        crate::encode::AtlasConfig::default(),
    )
    .expect("atlas builds");

    let mut recon = Array2::<f64>::zeros((n_test, p));
    let mut certified = 0usize;
    for r in 0..n_test {
        let (coord, cert) = atlas
            .certified_encode_row(&term.atoms[0], 0, z_test.row(r), amps[r])
            .expect("held-out encode runs");
        certified += usize::from(cert.certified());
        // Decode the recovered coordinate through the planted decoder.
        let encoded_coord = Array2::from_elem((1, 1), coord[0]);
        let (phi_enc, _) = periodic_basis(&encoded_coord);
        let decoded_row = phi_enc.dot(&decoder);
        recon.row_mut(r).assign(&decoded_row.row(0));
    }

    let ev = reconstruction_explained_variance(z_test.view(), recon.view()).unwrap();
    eprintln!("FULL_ENCODE_DECODE heldout EV={ev:.4} certified={certified}/{n_test}");
    assert!(
        ev > 0.95,
        "full encode+decode must recover on-manifold held-out curved points (EV={ev})"
    );
    assert!(
        certified > 0,
        "basin-warmup fix must certify held-out curved encodes at unit amplitude (got {certified})"
    );
}
/// Fits two superposed circles with a two-atom term in two settings and
/// compares the explained variance.
///
/// * Separable (`p = 4`): each circle occupies its own pair of output columns.
/// * Overlapping (`p = 3`): the two circles share output column 1.
///
/// The separable case must exceed `0.95` explained variance and beat the
/// overlapping case by more than `0.2`.
#[test]
fn sae_1026_solver_recovers_separable_superposition_but_not_below_2k() {
    /// Builds the two-circle target with `p` output columns, fits a two-atom
    /// periodic term to it, and returns the explained variance of the fit.
    fn recover(p: usize, overlap: bool) -> f64 {
        let n = 80usize;

        // True phases of the two circles, each wrapped into [0, 1).
        let theta_a =
            Array2::from_shape_fn((n, 1), |(r, _)| ((r as f64) * 0.043).rem_euclid(1.0));
        let theta_b = Array2::from_shape_fn((n, 1), |(r, _)| {
            ((r as f64) * 0.071 + 0.13).rem_euclid(1.0)
        });

        let mut target = Array2::<f64>::zeros((n, p));
        for r in 0..n {
            let a = std::f64::consts::TAU * theta_a[[r, 0]];
            let b = std::f64::consts::TAU * theta_b[[r, 0]];
            if overlap {
                // Circle A in columns (0, 1), circle B in columns (1, 2).
                target[[r, 0]] += a.cos();
                target[[r, 1]] += a.sin();
                target[[r, 1]] += b.cos();
                target[[r, 2]] += b.sin();
            } else {
                // Circle A in columns (0, 1), circle B in columns (2, 3).
                target[[r, 0]] = a.cos();
                target[[r, 1]] = a.sin();
                target[[r, 2]] = b.cos();
                target[[r, 3]] = b.sin();
            }
        }

        // Seeds: the true phases shifted by 0.03 and re-wrapped.
        let seed_a =
            Array2::from_shape_fn((n, 1), |(r, _)| (theta_a[[r, 0]] + 0.03).rem_euclid(1.0));
        let seed_b =
            Array2::from_shape_fn((n, 1), |(r, _)| (theta_b[[r, 0]] + 0.03).rem_euclid(1.0));
        let (pa, ja) = periodic_basis(&seed_a);
        let (pb, jb) = periodic_basis(&seed_b);
        let m = pa.ncols();

        let first_atom = SaeManifoldAtom::new(
            "cA",
            SaeAtomBasisKind::Periodic,
            1,
            pa,
            ja,
            Array2::<f64>::zeros((m, p)),
            Array2::<f64>::eye(m),
        )
        .unwrap()
        .with_basis_evaluator(Arc::new(TestPeriodicEvaluator));
        let second_atom = SaeManifoldAtom::new(
            "cB",
            SaeAtomBasisKind::Periodic,
            1,
            pb,
            jb,
            Array2::<f64>::zeros((m, p)),
            Array2::<f64>::eye(m),
        )
        .unwrap()
        .with_basis_evaluator(Arc::new(TestPeriodicEvaluator));

        // Every row starts with the same logit for both atoms.
        let logits = Array2::<f64>::from_elem((n, 2), 6.0 * 0.5);
        let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
            logits,
            vec![seed_a, seed_b],
            vec![
                LatentManifold::Circle { period: 1.0 },
                LatentManifold::Circle { period: 1.0 },
            ],
            AssignmentMode::ibp_map(0.5, 1.0, false),
        )
        .unwrap();

        let mut term = SaeManifoldTerm::new(vec![first_atom, second_atom], assignment).unwrap();
        let mut rho = SaeManifoldRho::new(
            0.0,
            0.01_f64.ln(),
            vec![array![1.0_f64.ln()], array![1.0_f64.ln()]],
        );
        term.run_joint_fit_arrow_schur(target.view(), &mut rho, None, 24, 0.1, 1.0e-4, 1.0e-4)
            .expect("K=2 inner solve converges");
        let fitted = term.try_fitted_for_rho(&rho).unwrap();
        reconstruction_explained_variance(target.view(), fitted.view()).unwrap()
    }

    let separable = recover(4, false);
    let under_determined = recover(3, true);
    eprintln!("#1026 K=2 superposition: separable(p=4)={separable:.4}, overlap(p=3)={under_determined:.4}");
    assert!(
        separable > 0.95,
        "the joint solver must recover two superposed circles when p >= 2K (EV={separable})"
    );
    assert!(
        separable > under_determined + 0.2,
        "p >= 2K must matter: separable superposition recovers but p < 2K (overlapping planes) is under-determined and collapses — separable={separable}, overlap={under_determined}"
    );
}
/// A `DeflatedArrowSolver::plain` adapter must reproduce the cache's own
/// `full_inverse_apply` result exactly (zero tolerance) on a diagonal fixture
/// with an empty beta block.
#[test]
pub(crate) fn deflated_solver_matches_plain_solve_when_no_gauge_is_installed() {
    let cache = diagonal_latent_cache(&[2.0_f64, 5.0, 7.0]);
    let solver = DeflatedArrowSolver::plain(&cache);

    let rhs_t = array![4.0_f64, 10.0, -14.0];
    let rhs_beta = Array1::<f64>::zeros(0);

    // Reference solve straight from the cache, then the adapter solve.
    let (plain_t, plain_beta) = cache
        .full_inverse_apply(rhs_t.view(), rhs_beta.view())
        .expect("plain cache solve");
    let solved = solver
        .solve(rhs_t.view(), rhs_beta.view())
        .expect("adapter solve");

    // Lengths are checked first, so the zipped comparison covers every entry.
    assert_eq!(solved.t.len(), plain_t.len());
    for (got, want) in solved.t.iter().zip(plain_t.iter()) {
        assert_abs_diff_eq!(*got, *want, epsilon = 0.0);
    }
    assert_eq!(solved.beta.len(), plain_beta.len());
    for (got, want) in solved.beta.iter().zip(plain_beta.iter()) {
        assert_abs_diff_eq!(*got, *want, epsilon = 0.0);
    }
}
/// Checks the gauge-deflated solver on a diagonal fixture whose second entry
/// is nearly zero (`1e-14`).
///
/// A right-hand side along the well-conditioned direction must solve to the
/// expected `[2, 0]`. A right-hand side along the gauge direction must blow up
/// in the plain cache solve but come back as `0.5` from the deflated solver.
#[test]
pub(crate) fn deflated_solver_matches_dense_quotient_pseudoinverse_on_near_null_fixture() {
    let cache = diagonal_latent_cache(&[2.0_f64, 1.0e-14]);
    let gauge_direction = array![0.0_f64, 1.0];
    let solver = DeflatedArrowSolver::from_orthonormal_gauges(&cache, vec![gauge_direction], 2.0)
        .expect("deflated solver");
    let empty_beta = Array1::<f64>::zeros(0);

    // Right-hand side orthogonal to the gauge: 4 / 2 = 2 in the first slot.
    let physical_rhs = array![4.0_f64, 0.0];
    let physical = solver
        .solve(physical_rhs.view(), empty_beta.view())
        .expect("physical solve");
    let oracle = array![2.0_f64, 0.0];
    for (idx, expected) in oracle.iter().enumerate() {
        assert_abs_diff_eq!(physical.t[idx], *expected, epsilon = 1.0e-12);
    }

    // Right-hand side along the gauge direction.
    let gauge_rhs = array![0.0_f64, 1.0];
    let (plain, _) = cache
        .full_inverse_apply(gauge_rhs.view(), empty_beta.view())
        .expect("plain gauge solve");
    let stiffened = solver
        .solve(gauge_rhs.view(), empty_beta.view())
        .expect("stiffened gauge solve")
        .t;
    assert!(plain[1] > 1.0e13, "plain near-null solve must be huge");
    assert_abs_diff_eq!(stiffened[1], 0.5, epsilon = 1.0e-12);
}
/// The PCA seed must stay finite when every entry of `z` is the same huge
/// finite value (`1e308`), and must return coordinates of shape `(1, 2, 1)`
/// for one periodic atom of latent dimension 1 over two rows.
#[test]
pub(crate) fn pca_seed_handles_huge_equal_finite_columns_without_mean_overflow() {
    let huge = 1.0e308_f64;
    let z = array![[huge, huge], [huge, huge]];
    let basis_kinds = [SaeAtomBasisKind::Periodic];
    let latent_dims = [1];
    let coords = sae_pca_seed_initial_coords(z.view(), &basis_kinds, &latent_dims).unwrap();
    assert_eq!(coords.dim(), (1, 2, 1));
    let all_finite = coords.iter().all(|value| value.is_finite());
    assert!(
        all_finite,
        "huge finite equal columns must not overflow the PCA seed mean: {coords:?}"
    );
}
/// The PCA seed must return an error, not coordinates, when a column holds
/// `+1e308` and `-1e308`, and the error must be one of the two recognised
/// messages.
#[test]
pub(crate) fn pca_seed_rejects_huge_finite_span_that_overflows_centering() {
    let z = array![[1.0e308_f64, 0.0], [-1.0e308, 0.0]];
    let outcome = sae_pca_seed_initial_coords(z.view(), &[SaeAtomBasisKind::Periodic], &[1]);
    let err = outcome.expect_err("opposite huge finite values exceed f64 centering range");
    // Either failure point is acceptable: the centering check or the SVD.
    let recognised = ["centered Z is non-finite", "SVD failed"]
        .iter()
        .any(|needle| err.contains(*needle));
    assert!(recognised, "unexpected PCA seed error: {err}");
}
/// Recovers a planted rank-3 frame in 12 dimensions.
///
/// The planted frame is the first three standard basis columns. The test
/// checks that the span-recovery angle is zero, that `polar_update` of the
/// planted frame has zero principal angle to it, and that the resulting frame
/// has an identity Gram matrix (all to `1e-9`).
#[test]
pub(crate) fn planted_low_rank_frame_recovered_by_polar() {
    let p = 12usize;
    let r = 3usize;
    let n = 200usize;

    // Planted frame: ones on the leading diagonal, zeros elsewhere.
    let planted =
        Array2::<f64>::from_shape_fn((p, r), |(row, col)| if row == col { 1.0 } else { 0.0 });

    // Deterministic latent coordinates in [-0.5, 0.5).
    let coords = Array2::<f64>::from_shape_fn((n, r), |(i, j)| {
        ((i * 7 + j * 13 + 1) % 97) as f64 / 97.0 - 0.5
    });

    let targets = fast_abt(&coords, &planted);
    let angle = grassmann_recover_planted_span_angle(targets.view(), coords.view(), planted.view())
        .expect("span recovery");
    assert_abs_diff_eq!(angle, 0.0, epsilon = 1.0e-9);

    let frame = GrassmannFrame::polar_update(planted.view()).expect("polar");
    let recovered_angle = frame
        .max_principal_angle(planted.view())
        .expect("principal angle");
    assert_abs_diff_eq!(recovered_angle, 0.0, epsilon = 1.0e-9);

    // Orthonormality: the frame's Gram matrix must be the r x r identity.
    let left = frame.frame().to_owned();
    let right = frame.frame().to_owned();
    let gram = fast_atb(&left, &right);
    for i in 0..r {
        for j in 0..r {
            let expect = match i == j {
                true => 1.0,
                false => 0.0,
            };
            assert_abs_diff_eq!(gram[[i, j]], expect, epsilon = 1.0e-9);
        }
    }
}
/// Compares `assignment_prior_hdiag_derivative_entry` for a JumpReLU
/// assignment against a five-point central difference of the closed-form
/// second derivative `p2` defined below.
///
/// Logits that `jumprelu_in_optimization_band` reports as out of band are
/// skipped. The fixture places one logit exactly on the threshold, where the
/// entry is additionally pinned to `-λ / (8 τ³)` and required to be strictly
/// negative.
#[test]
fn jumprelu_hdiag_third_derivative_matches_central_difference_1415() {
    use ndarray::{Array1, Array2, Array3};

    let (n, k, p) = (6usize, 2usize, 3usize);
    let temperature = 0.35_f64;
    let threshold = 0.1_f64;

    // Row-major (n, k) grid; entry [0, 0] equals the threshold exactly.
    let logit_values = vec![
        0.1, 0.0, 0.2, -0.05, 0.05, 0.15, 0.25, 0.3, -0.1, 0.12, 0.18, 0.08,
    ];
    let logits = Array2::<f64>::from_shape_vec((n, k), logit_values).expect("valid logit grid");

    // k identical Euclidean-patch atoms with a constant basis and zero decoder.
    let mut atoms: Vec<SaeManifoldAtom> = Vec::with_capacity(k);
    for i in 0..k {
        let built = SaeManifoldAtom::new(
            &format!("atom{i}"),
            SaeAtomBasisKind::EuclideanPatch,
            1,
            Array2::<f64>::ones((n, 2)),
            Array3::<f64>::zeros((n, 2, 1)),
            Array2::<f64>::zeros((2, p)),
            Array2::<f64>::eye(2),
        )
        .unwrap();
        atoms.push(built);
    }

    let coords: Vec<Array2<f64>> = vec![Array2::<f64>::zeros((n, 1)); k];
    let manifolds = vec![LatentManifold::Euclidean; k];
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        logits.clone(),
        coords,
        manifolds,
        AssignmentMode::jumprelu(temperature, threshold),
    )
    .expect("valid JumpReLU assignment");
    let term = SaeManifoldTerm::new(atoms, assignment).unwrap();
    let rho = SaeManifoldRho::new(0.7_f64.ln(), -6.0, vec![Array1::<f64>::zeros(1); k]);

    let inv_tau = 1.0 / temperature;
    let sparsity = rho.log_lambda_sparse.exp();

    let in_band = |logit: f64| {
        crate::assignment::jumprelu_in_optimization_band(logit, threshold, temperature)
    };
    // Closed-form second derivative used as the finite-difference integrand;
    // zero outside the optimization band.
    let p2 = |logit: f64| -> f64 {
        if in_band(logit) {
            let gate = gam_linalg::utils::stable_logistic((logit - threshold) * inv_tau);
            let slope = gate * (1.0 - gate);
            sparsity * slope * (1.0 - 2.0 * gate) * inv_tau * inv_tau
        } else {
            0.0
        }
    };

    let h = 1.0e-3_f64;
    let mut saw_threshold = false;
    for ((row, atom), &logit) in logits.indexed_iter() {
        if !in_band(logit) {
            continue;
        }
        let entry = term.assignment_prior_hdiag_derivative_entry(
            &rho,
            row,
            atom,
            SaeLocalRowVar::Logit { atom },
            None,
        );

        // Five-point stencil for the first derivative of `p2`.
        let ahead_two = p2(logit + 2.0 * h);
        let ahead_one = p2(logit + h);
        let behind_one = p2(logit - h);
        let behind_two = p2(logit - 2.0 * h);
        let fd = (-ahead_two + 8.0 * ahead_one - 8.0 * behind_one + behind_two) / (12.0 * h);

        let scale = entry.abs().max(fd.abs()).max(1.0e-8);
        assert!(
            (entry - fd).abs() <= 1.0e-5 * scale,
            "row {row} atom {atom}: P''' entry {entry:e} vs FD {fd:e}"
        );

        let at_threshold = (logit - threshold).abs() < 1e-12;
        if at_threshold {
            saw_threshold = true;
            let expected = -sparsity / 8.0 * inv_tau * inv_tau * inv_tau;
            assert_abs_diff_eq!(entry, expected, epsilon = 1e-9);
            assert!(
                entry < -1e-6,
                "threshold third derivative must be strictly negative (old buggy \
                 formula returned 0): entry={entry:e}"
            );
        }
    }

    assert!(
        saw_threshold,
        "fixture must include a logit exactly at the threshold to pin −λ/(8τ³)"
    );
}
/// Checks `encode_grad_hess` on a one-dimensional periodic atom against
/// central differences of `f(t) = 0.5 * ||decode(t) - x||²`, then checks the
/// scalar outputs of `beta_eta_newton` against their closed forms for a 1×1
/// Hessian.
///
/// If the analytic Hessian entry is not strictly positive it is replaced by
/// `1.5` before the Newton quantities are computed, so that call always
/// receives a positive 1×1 matrix.
#[test]
fn encode_grad_hess_and_beta_eta_match_finite_differences() {
    use crate::encode::{beta_eta_newton, encode_grad_hess};
    use ndarray::Array2;

    // 24 training coordinates at the midpoints of a uniform grid on [0, 1).
    let train = Array2::from_shape_fn((24, 1), |(row, _)| (row as f64 + 0.5) / 24.0);
    let (phi, jet) = periodic_basis(&train);
    let m = phi.ncols();
    let p = 4usize;
    let decoder = Array2::from_shape_fn((m, p), |(basis, channel)| {
        (1.0 / (1.0 + basis as f64)) * ((basis as f64 + 1.0) * (channel as f64 + 1.0)).cos()
    });
    let atom = SaeManifoldAtom::new(
        "circle",
        SaeAtomBasisKind::Periodic,
        1,
        phi,
        jet,
        decoder.clone(),
        Array2::<f64>::eye(m),
    )
    .unwrap();
    let eval = TestPeriodicEvaluator;
    let amplitude = 0.8_f64;

    // Reference decoder: amplitude times the basis row at `t` times the decoder.
    let decode = |t: f64| -> ndarray::Array1<f64> {
        let point = Array2::from_elem((1, 1), t);
        let (basis_row, _) = periodic_basis(&point);
        let decoded = basis_row.dot(&decoder);
        amplitude * decoded.row(0).to_owned()
    };

    let t0 = 0.137_f64;
    let perturbation = ndarray::Array1::<f64>::from_vec(vec![0.3, -0.2, 0.15, -0.25]);
    let x = &decode(0.42) + &perturbation;
    let f = |t: f64| -> f64 {
        let residual = &decode(t) - &x;
        0.5 * residual.dot(&residual)
    };

    let t_view = ndarray::Array1::from_vec(vec![t0]);
    // NOTE(review): the meaning of the trailing `0.0` argument is not visible here.
    let (g, h) = encode_grad_hess(&atom, &eval, t_view.view(), x.view(), amplitude, 0.0)
        .expect("encode_grad_hess runs")
        .expect("second jet present ⇒ Some");

    let eps = 1e-6;
    let (f_ahead, f_here, f_behind) = (f(t0 + eps), f(t0), f(t0 - eps));
    let g_fd = (f_ahead - f_behind) / (2.0 * eps);
    assert_abs_diff_eq!(g[0], g_fd, epsilon = 1e-6);
    // The second difference uses the same tiny step, hence the looser tolerance.
    let h_fd = (f_ahead - 2.0 * f_here + f_behind) / (eps * eps);
    assert_abs_diff_eq!(h[[0, 0]], h_fd, epsilon = 5e-3);

    let mut hpd = h.clone();
    if hpd[[0, 0]] <= 0.0 {
        hpd[[0, 0]] = 1.5;
    }
    let pivot = hpd[[0, 0]];
    let (beta, eta, delta) = beta_eta_newton(hpd.view(), g.view())
        .expect("beta_eta_newton runs")
        .expect("SPD ⇒ Some");
    assert_abs_diff_eq!(beta * pivot, 1.0, epsilon = 1e-12);
    assert_abs_diff_eq!(delta[0], -g[0] / pivot, epsilon = 1e-12);
    assert_abs_diff_eq!(eta, (g[0] / pivot).abs(), epsilon = 1e-12);
}
/// Builds 95 small-norm rows plus 5 large-norm rows and checks that
/// `SaeManifoldTerm::robust_norm_row_weights`:
///
/// * returns weights with mean 1,
/// * assigns every large-norm row a weight below 0.5,
/// * shrinks the large-norm rows' share of the weighted squared-norm total to
///   under 60% of their unweighted share,
/// * returns uniform weights of 1 when all rows have the same norm.
#[test]
fn robust_norm_row_weights_rebalances_heavy_tailed_objective() {
    use ndarray::Array2;

    let (n, p) = (100usize, 4usize);
    // Rows below this index form the bulk; the remainder form the heavy tail.
    let bulk = 95usize;
    let target = Array2::<f64>::from_shape_fn((n, p), |(i, c)| {
        if i < bulk {
            ((i * 7 + c * 13) % 11) as f64 / 11.0 - 0.5
        } else {
            5.0 * if c == 0 { 1.0 } else { 0.3 }
        }
    });

    let norms: Vec<f64> = target
        .outer_iter()
        .map(|r| r.dot(&r).sqrt())
        .collect();
    let hi: Vec<usize> = (bulk..n).collect();

    // Unweighted share of the squared-norm total held by the heavy tail.
    let squared = |i: usize| norms[i] * norms[i];
    let total_sq: f64 = (0..n).map(|i| squared(i)).sum();
    let hi_sq: f64 = hi.iter().map(|&i| squared(i)).sum();
    let unweighted_share = hi_sq / total_sq;

    let w = SaeManifoldTerm::robust_norm_row_weights(target.view(), 1.0).unwrap();
    let mean: f64 = w.iter().sum::<f64>() / n as f64;
    assert_abs_diff_eq!(mean, 1.0, epsilon = 1e-9);
    for &i in &hi {
        assert!(w[i] < 0.5, "high-norm token {i} should be downweighted, w={}", w[i]);
    }

    // The same share after applying the robust weights.
    let weighted = |i: usize| w[i] * norms[i] * norms[i];
    let total_w: f64 = (0..n).map(|i| weighted(i)).sum();
    let hi_w: f64 = hi.iter().map(|&i| weighted(i)).sum();
    let weighted_share = hi_w / total_w;
    assert!(
        weighted_share < unweighted_share * 0.6,
        "robust weighting must materially cut the high-norm cluster's objective \
         share: unweighted={unweighted_share:.3}, weighted={weighted_share:.3}"
    );

    // Degenerate case: identical rows.
    let flat = Array2::<f64>::from_elem((4, p), 2.0);
    let wf = SaeManifoldTerm::robust_norm_row_weights(flat.view(), 1.0).unwrap();
    let uniform = wf.iter().all(|&x| (x - 1.0).abs() < 1e-12);
    assert!(uniform, "flat norms → uniform weights");
}