use super::tests::{global_ev, planted_circle_embedded};
use super::*;
use crate::sparse_dict::{fit_sparse_dictionary, SparseDictConfig};
use crate::basis::{EuclideanPatchEvaluator, PeriodicHarmonicEvaluator, SaeBasisSecondJet};
use gam_linalg::faer_ndarray::{fast_atb, FaerCholesky};
use gam_solve::rho_optimizer::OuterObjective;
use ndarray::{array, s, Array1, Array2, ArrayView2};
use std::sync::Arc;
/// Latent topology shared by every atom of a test term.
///
/// `build_term` uses it to pick the atom basis kind, the basis evaluator and the
/// latent manifold attached to each atom's coordinates.
#[derive(Copy, Clone)]
enum Topo {
    /// `SaeAtomBasisKind::Periodic` atoms with a `PeriodicHarmonicEvaluator::new(3)`
    /// basis on a `LatentManifold::Circle { period: 1.0 }` latent.
    Circle,
    /// `SaeAtomBasisKind::EuclideanPatch` atoms with an `EuclideanPatchEvaluator`
    /// built as `(dim, 2)` on a Euclidean latent.
    Euclidean,
    /// `SaeAtomBasisKind::Linear` atoms with an `EuclideanPatchEvaluator` built as
    /// `(dim, 1)` on a Euclidean latent.
    Linear,
}
/// Builds a `k`-atom [`SaeManifoldTerm`] seeded from PCA coordinates, with every atom
/// using the basis family selected by `topo`, and returns it together with a seed
/// dispersion estimate.
///
/// Each atom's decoder is the least-squares fit of `z` onto that atom's basis
/// (evaluated at its PCA seed coordinates), with `1e-8` added to the normal-equation
/// diagonal. The returned dispersion is `rss / (k * n * p)`. Here `rss` sums the squared
/// residuals of all `k` atoms, and `n` × `p` is the shape of `z`. The result is floored
/// at `1e-12`.
///
/// Initial assignment logits depend on `mode`:
/// - `IBPMap`: every atom gets logit `6.0`.
/// - `Softmax`: atom `row % k` gets `3.0`, all others `0.0`.
/// - `ThresholdGate`: atom `row % k` gets `3.0`, all others `-3.0`.
///
/// # Panics
/// Panics if PCA seeding, basis evaluation, the Cholesky solve, or any of the
/// atom / assignment / term constructors fails.
fn build_term(
    z: ArrayView2<'_, f64>,
    k: usize,
    topo: Topo,
    mode: AssignmentMode,
) -> (SaeManifoldTerm, f64) {
    let n = z.nrows();
    let p = z.ncols();
    let (basis_kind, dim, topo_name): (SaeAtomBasisKind, usize, &str) = match topo {
        Topo::Circle => (SaeAtomBasisKind::Periodic, 1, "circle"),
        Topo::Euclidean => (SaeAtomBasisKind::EuclideanPatch, 1, "euclidean"),
        Topo::Linear => (SaeAtomBasisKind::Linear, 1, "linear"),
    };
    let evaluator: Arc<dyn SaeBasisSecondJet> = match topo {
        Topo::Circle => Arc::new(PeriodicHarmonicEvaluator::new(3).unwrap()),
        Topo::Euclidean => Arc::new(EuclideanPatchEvaluator::new(dim, 2).unwrap()),
        Topo::Linear => Arc::new(EuclideanPatchEvaluator::new(dim, 1).unwrap()),
    };
    let basis_kinds = vec![basis_kind.clone(); k];
    let atom_dims = vec![dim; k];
    let seed_coords = sae_pca_seed_initial_coords(z, &basis_kinds, &atom_dims).unwrap();

    // `z` is loop-invariant: materialise the owned copy used by the `fast_atb`
    // projection once, instead of once per atom.
    let z_owned = z.to_owned();

    let mut atoms = Vec::with_capacity(k);
    let mut coords_blocks = Vec::with_capacity(k);
    let mut manifolds = Vec::with_capacity(k);
    let mut rss = 0.0_f64;
    for atom_idx in 0..k {
        let coords = seed_coords.slice(s![atom_idx, .., 0..dim]).to_owned();
        let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();
        let mm = phi.ncols();

        // Normal equations with a tiny ridge so the Cholesky factorisation succeeds
        // even when the basis columns are (nearly) collinear.
        let mut xtx = fast_atb(&phi, &phi);
        for i in 0..mm {
            xtx[[i, i]] += 1.0e-8;
        }
        let xtz = fast_atb(&phi, &z_owned);
        let decoder = xtx.cholesky(Side::Lower).unwrap().solve_mat(&xtz);

        // Accumulate squared residuals in row-major order (both `iter()`s visit
        // elements in logical order), matching a row/column index loop exactly.
        let fitted = phi.dot(&decoder);
        debug_assert_eq!(fitted.dim(), z.dim());
        for (observed, predicted) in z.iter().zip(fitted.iter()) {
            let r = observed - predicted;
            rss += r * r;
        }

        let atom = SaeManifoldAtom::new(
            topo_name,
            basis_kind.clone(),
            dim,
            phi,
            jet,
            decoder,
            Array2::<f64>::eye(mm),
        )
        .unwrap()
        .with_basis_evaluator(evaluator.clone());
        atoms.push(atom);
        coords_blocks.push(coords);
        // Exhaustive on purpose: adding a `Topo` variant must force a decision here.
        manifolds.push(match topo {
            Topo::Circle => LatentManifold::Circle { period: 1.0 },
            Topo::Euclidean | Topo::Linear => LatentManifold::Euclidean,
        });
    }
    // An empty `z` or `k == 0` gives 0/0 = NaN, and `f64::max` then yields the floor.
    let seed_dispersion = (rss / (k * n * p) as f64).max(1.0e-12);

    let logits = Array2::<f64>::from_shape_fn((n, k), |(row, atom)| {
        // Round-robin "home" atom for this row.
        let preferred = atom == row % k;
        match mode {
            AssignmentMode::IBPMap { .. } => 6.0,
            AssignmentMode::Softmax { .. } => {
                if preferred {
                    3.0
                } else {
                    0.0
                }
            }
            AssignmentMode::ThresholdGate { .. } => {
                if preferred {
                    3.0
                } else {
                    -3.0
                }
            }
        }
    });
    let assignment =
        SaeAssignment::from_blocks_with_mode_and_manifolds(logits, coords_blocks, manifolds, mode)
            .unwrap();
    (
        SaeManifoldTerm::new(atoms, assignment).unwrap(),
        seed_dispersion,
    )
}
/// Seeds a term with [`build_term`] and wraps it in a [`SaeManifoldOuterObjective`].
///
/// Returns the objective together with the flattened initial rho vector.
///
/// The initial rho is built from the natural-log values `ln(0.02)`, `ln(1.0)` and one
/// `[0.0]` block per atom. It is then rescaled by the term's seed dispersion for `mode`.
///
/// NOTE(review): the meaning of the trailing `8, 0.04, 1e-6, 1e-6` constructor
/// arguments is defined by `SaeManifoldOuterObjective::new`, which is not visible here.
///
/// # Panics
/// Panics if term construction or the dispersion-based rho rescaling fails.
fn objective_and_seed(
    z: ArrayView2<'_, f64>,
    k: usize,
    topo: Topo,
    mode: AssignmentMode,
) -> (SaeManifoldOuterObjective, Array1<f64>) {
    let (seeded_term, dispersion) = build_term(z, k, topo, mode);
    let per_atom_blocks = vec![array![0.0]; k];
    let start_rho = SaeManifoldRho::new(0.02_f64.ln(), 1.0_f64.ln(), per_atom_blocks)
        .seed_scaled_by_dispersion_for_assignment(dispersion, mode)
        .unwrap();
    let start_flat = start_rho.to_flat();
    let outer = SaeManifoldOuterObjective::new(
        seeded_term,
        z.to_owned(),
        None,
        start_rho,
        8,
        0.04,
        1.0e-6,
        1.0e-6,
    );
    (outer, start_flat)
}
/// Evaluates the EFS outer objective once at the seed rho and reports whether it is
/// usable.
///
/// Returns `Ok(cost)` when the seed cost and every EFS step component are finite. It
/// returns `Err` with a description when evaluation fails or produces a non-finite value.
///
/// # Panics
/// Panics if the seed has 8 or fewer parameters, because then the test would not
/// exercise the EFS lane at all.
fn seed_passes_startup_validation(
    z: ArrayView2<'_, f64>,
    k: usize,
    topo: Topo,
    mode: AssignmentMode,
) -> Result<f64, String> {
    let (mut objective, seed) = objective_and_seed(z, k, topo, mode);
    let n_params = seed.len();
    assert!(
        n_params > 8,
        "test must exercise the EFS lane (n_params={} must exceed 8)",
        n_params
    );
    let eval = match objective.eval_efs(&seed) {
        Ok(eval) => eval,
        Err(err) => return Err(err.to_string()),
    };
    if !eval.cost.is_finite() {
        return Err(format!("EFS seed cost is non-finite ({})", eval.cost));
    }
    // Report the first non-finite step component, if any.
    for (idx, v) in eval.steps.iter().enumerate() {
        if !v.is_finite() {
            return Err(format!("EFS seed step[{idx}] is non-finite ({v})"));
        }
    }
    Ok(eval.cost)
}
/// Regression for #1782: each assignment-mode / topology pairing listed below must
/// yield a finite EFS evaluation at its seed.
///
/// Every case is logged to stderr. The test panics on the first case that fails.
#[test]
fn all_assignment_topology_combinations_pass_startup_validation_1782() {
    let data = planted_circle_embedded(48, 6, 0.03);
    let atoms = 4usize;
    let cases: [(&str, Topo, AssignmentMode); 5] = [
        (
            "circle/ibp_map",
            Topo::Circle,
            AssignmentMode::ibp_map(1.0, 1.0, false),
        ),
        ("circle/softmax", Topo::Circle, AssignmentMode::softmax(1.0)),
        (
            "circle/threshold_gate",
            Topo::Circle,
            AssignmentMode::threshold_gate(1.0, 0.0),
        ),
        (
            "euclidean/ibp_map",
            Topo::Euclidean,
            AssignmentMode::ibp_map(1.0, 1.0, false),
        ),
        (
            "linear/ibp_map",
            Topo::Linear,
            AssignmentMode::ibp_map(1.0, 1.0, false),
        ),
    ];
    for (label, topo, mode) in cases {
        match seed_passes_startup_validation(data.view(), atoms, topo, mode) {
            Ok(cost) => eprintln!("REPRO1782 {label}: startup OK (cost={cost:.4e})"),
            Err(err) => {
                eprintln!("REPRO1782 {label}: startup ERR={err}");
                panic!("#1782 {label} must pass outer startup validation, got: {err}");
            }
        }
    }
}
/// Runs a short outer optimisation (max 4 iterations, a single seed) for the given
/// topology and assignment mode. Returns the global explained variance of the fitted
/// reconstruction of `z`.
///
/// # Panics
/// Panics if the outer solver returns an error or the resulting EV is non-finite.
fn run_full_fit(
    z: ArrayView2<'_, f64>,
    k: usize,
    topo: Topo,
    mode: AssignmentMode,
    label: &str,
) -> f64 {
    let (mut objective, seed) = objective_and_seed(z, k, topo, mode);
    let seed_config = gam_problem::SeedConfig {
        max_seeds: 1,
        seed_budget: 1,
        ..Default::default()
    };
    if let Err(err) = gam_solve::rho_optimizer::OuterProblem::new(seed.len())
        .with_initial_rho(seed)
        .with_max_iter(4)
        .with_seed_config(seed_config)
        .run(&mut objective, "SAE manifold")
    {
        panic!("#1782 {label} fit must not abort at startup / in the outer solver, got: {err}");
    }
    let fitted = objective.into_fitted();
    let ev = global_ev(z, fitted.term.fitted().view());
    eprintln!("REPRO1782 {label} fit: ev={ev:.4}");
    if !ev.is_finite() {
        panic!("#1782 {label} produced a non-finite reconstruction EV ({ev})");
    }
    ev
}
/// #1782: each assignment mode must complete a short end-to-end fit on circle atoms
/// over planted circle data.
#[test]
fn assignment_kinds_fit_on_circle_1782() {
    let data = planted_circle_embedded(48, 6, 0.03);
    let atoms = 4usize;
    let modes = [
        ("circle/ibp_map", AssignmentMode::ibp_map(1.0, 1.0, false)),
        ("circle/softmax", AssignmentMode::softmax(1.0)),
        (
            "circle/threshold_gate",
            AssignmentMode::threshold_gate(1.0, 0.0),
        ),
    ];
    for (label, mode) in modes {
        // Only successful completion matters here; the EV itself is not checked.
        let _ev = run_full_fit(data.view(), atoms, Topo::Circle, mode, label);
    }
}
/// #1782: Euclidean-patch and linear atoms under IBP-MAP assignment must complete a
/// short end-to-end fit on planted circle data.
#[test]
fn topologies_fit_on_circle_data_1782() {
    let data = planted_circle_embedded(48, 6, 0.03);
    let atoms = 4usize;
    let topologies = [
        ("euclidean/ibp_map", Topo::Euclidean),
        ("linear/ibp_map", Topo::Linear),
    ];
    for (label, topo) in topologies {
        let mode = AssignmentMode::ibp_map(1.0, 1.0, false);
        // Only successful completion matters here; the EV itself is not checked.
        let _ev = run_full_fit(data.view(), atoms, topo, mode, label);
    }
}
/// #1026 probe: for each assignment mode, finds the largest `K` in `ks` for which the
/// circle seed still passes EFS startup validation, and logs it.
///
/// `K` values are tried in increasing order and the sweep stops at the first failure.
/// The reported frontier is therefore the largest *contiguous* passing `K` (0 if none
/// pass). Only the IBP-MAP frontier is asserted (it must reach `K = 4`). The other
/// modes are diagnostic output only.
#[test]
fn cocollapse_startup_frontier_1026() {
    let z = planted_circle_embedded(96, 10, 0.03);
    let ks = [4usize, 8];
    let modes: [(&str, fn() -> AssignmentMode); 3] = [
        ("ibp_map", || AssignmentMode::ibp_map(1.0, 1.0, false)),
        ("thresh_gate", || AssignmentMode::threshold_gate(1.0, 0.5)),
        ("softmax", || AssignmentMode::softmax(1.0)),
    ];
    // Pad labels to the longest one so the log lines actually line up.
    let label_width = modes
        .iter()
        .map(|(label, _)| label.len())
        .max()
        .unwrap_or(0);
    let mut ibp_frontier = 0usize;
    for (label, make_mode) in modes {
        let mut frontier = 0usize;
        for &k in &ks {
            match seed_passes_startup_validation(z.view(), k, Topo::Circle, make_mode()) {
                Ok(cost) => {
                    eprintln!(
                        "FRONTIER1026 {label:<label_width$} K={k:>3}: startup PASS (cost={cost:.4e})"
                    );
                    frontier = k;
                }
                Err(e) => {
                    eprintln!("FRONTIER1026 {label:<label_width$} K={k:>3}: startup FAIL ({e})");
                    break;
                }
            }
        }
        eprintln!("FRONTIER1026 {label:<label_width$}: largest passing K = {frontier}");
        // Identify the IBP-MAP row by the mode it builds, not by its display label.
        if matches!(make_mode(), AssignmentMode::IBPMap { .. }) {
            ibp_frontier = frontier;
        }
    }
    assert!(
        ibp_frontier >= 4,
        "startup validation must hold at least to K=4 (got frontier {ibp_frontier})"
    );
}
/// #1026: after a joint Arrow–Schur inner fit, the circle-manifold SAE must reach the
/// explained variance of a linear sparse-dictionary baseline, within a 0.05 slack.
///
/// Currently a single `K = 8` point. Both EVs must also be finite.
#[test]
fn manifold_beats_linear_joint_streaming_1026() {
    let z = planted_circle_embedded(120, 10, 0.03);
    // The linear baseline works on f32 data; the copy does not depend on K.
    let z32 = z.mapv(|v| v as f32);
    for k in [8usize] {
        let baseline = fit_sparse_dictionary(z32.view(), &SparseDictConfig::new(k))
            .expect("linear SAE baseline fits");
        let ev_linear = baseline.explained_variance;

        let gate = AssignmentMode::threshold_gate(1.0, 0.0);
        let (mut term, _seed_dispersion) = build_term(z.view(), k, Topo::Circle, gate);
        let log_small = 1.0e-3_f64.ln();
        let mut rho = SaeManifoldRho::new(log_small, log_small, vec![array![log_small]; k]);
        if let Err(e) =
            term.run_joint_fit_arrow_schur(z.view(), &mut rho, None, 24, 1.0, 1.0e-6, 1.0e-6)
        {
            panic!("#1026 manifold K={k} joint inner fit must run e2e, got: {e}");
        }

        let reconstruction = term.try_fitted().expect("manifold fitted");
        let ev_manifold = global_ev(z.view(), reconstruction.view());
        let margin = ev_manifold - ev_linear;
        eprintln!(
            "WIN1026 K={k:>3}: manifold EV={ev_manifold:.4} linear EV={ev_linear:.4} \
             margin={:+.4}",
            margin
        );
        assert!(
            ev_manifold.is_finite() && ev_linear.is_finite(),
            "#1026 K={k}: both EVs must be finite (manifold={ev_manifold}, linear={ev_linear})"
        );
        let slack = 5.0e-2;
        assert!(
            ev_manifold + slack >= ev_linear,
            "#1026 K={k}: principled manifold SAE must match-or-beat linear \
             (manifold={ev_manifold:.4} vs linear={ev_linear:.4})"
        );
    }
}