use gam_linalg::faer_ndarray::fast_ata;
use super::*;
use ndarray::array;
/// Builds a `k`-atom periodic SAE manifold term seeded from the rows of `z`.
///
/// Each atom has latent dimension 2 and takes its initial coordinates from
/// `sae_pca_seed_initial_coords`. Its decoder is obtained by solving the
/// ridge-stabilised system `(fast_ata(phi) + 1e-8 * I) * decoder = fast_atb(phi, z)`
/// through a Cholesky factorisation, where `phi` is the torus-harmonic basis
/// evaluated at the seed coordinates (presumably the normal equations of a
/// least-squares fit of `z` onto `phi` — confirm against `fast_ata`/`fast_atb`).
/// Assignment logits start at zero with `AssignmentMode::softmax(1.0)`, and each
/// atom's latent manifold is a product of two unit-period circles.
///
/// * `z` — data matrix; one row per observation.
/// * `k` — number of atoms to create.
/// * `num_harmonics` — forwarded to `TorusHarmonicEvaluator::new(2, num_harmonics)`.
///
/// # Panics
/// Panics if the evaluator, the seeding routine, the basis evaluation, the
/// Cholesky factorisation, or any of the SAE constructors fails.
pub(crate) fn real_data_torus_seed_term(
    z: ArrayView2<'_, f64>,
    k: usize,
    num_harmonics: usize,
) -> SaeManifoldTerm {
    let n = z.nrows();
    let evaluator = Arc::new(TorusHarmonicEvaluator::new(2, num_harmonics).unwrap());
    let basis_kinds = vec![SaeAtomBasisKind::Periodic; k];
    let atom_dims = vec![2usize; k];
    let seed_coords = sae_pca_seed_initial_coords(z, &basis_kinds, &atom_dims).unwrap();
    let mut atoms = Vec::with_capacity(k);
    let mut coords_blocks = Vec::with_capacity(k);
    let mut manifolds = Vec::with_capacity(k);
    // The owned copy of the targets does not depend on the atom, so it is made
    // once here instead of being re-cloned on every loop iteration.
    let z_owned = z.to_owned();
    for atom_idx in 0..k {
        let coords = seed_coords.slice(s![atom_idx, .., 0..2]).to_owned();
        let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();
        let m = phi.ncols();
        let mut xtx = fast_ata(&phi);
        // Small diagonal ridge keeps the Cholesky factorisation well-posed.
        for i in 0..m {
            xtx[[i, i]] += 1.0e-8;
        }
        let xtz = fast_atb(&phi, &z_owned);
        let decoder = xtx.cholesky(Side::Lower).unwrap().solve_mat(&xtz);
        let atom = SaeManifoldAtom::new(
            "torus",
            SaeAtomBasisKind::Periodic,
            2,
            phi,
            jet,
            decoder,
            Array2::<f64>::eye(m),
        )
        .unwrap()
        .with_basis_evaluator(evaluator.clone());
        atoms.push(atom);
        coords_blocks.push(coords);
        manifolds.push(LatentManifold::Product(vec![
            LatentManifold::Circle { period: 1.0 },
            LatentManifold::Circle { period: 1.0 },
        ]));
    }
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::from_elem((n, k), 0.0),
        coords_blocks,
        manifolds,
        AssignmentMode::softmax(1.0),
    )
    .unwrap();
    SaeManifoldTerm::new(atoms, assignment).unwrap()
}
/// #1190 regression: on the first 160 rows of the mixed-layer OLMo fixture, seed
/// K=2 torus atoms (3 harmonics), assemble the arrow-Schur system and require
/// that the smallest eigenvalue over all symmetrised per-row `htt` blocks,
/// relative to the largest eigenvalue magnitude (floored at 1), is `>= -1e-8`.
#[test]
pub(crate) fn olmo_real_curvature_anchor_is_positive_definite() {
    use gam_linalg::faer_ndarray::FaerEigh;

    let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
        .join("tests/data/olmo_mixedlayer_pca64_768.npy");
    let z = read_npy_f32_2d(&path);
    assert_eq!(z.dim(), (768, 64), "real OLMo fixture shape");
    let z_train = z.slice(s![..160, ..]).to_owned();
    let k = 2usize;
    let mut term = real_data_torus_seed_term(z_train.view(), k, 3);
    let rho = SaeManifoldRho::new(0.0, 0.0, vec![array![0.0, 0.0]; k]);
    let registry = SaeManifoldOuterObjective::new(
        term.clone(),
        z_train.clone(),
        None,
        rho.clone(),
        0,
        0.04,
        1.0e-6,
        1.0e-6,
    )
    .registry;
    let system = term
        .assemble_arrow_schur(z_train.view(), &rho, registry.as_ref())
        .expect("assemble raw curvature anchor");

    // Spectrum statistics accumulated over every non-empty row block.
    let mut min_raw_eig = f64::INFINITY;
    let mut max_raw_eig = 0.0_f64;
    let mut indefinite_rows = 0usize;
    let mut total_neg_dirs = 0usize;
    for row_block in &system.rows {
        let dim = row_block.htt.nrows();
        if dim == 0 {
            continue;
        }
        // Symmetrise explicitly so the eigen-solver sees an exactly symmetric matrix.
        let symmetric = Array2::<f64>::from_shape_fn((dim, dim), |(i, j)| {
            0.5 * (row_block.htt[[i, j]] + row_block.htt[[j, i]])
        });
        let (eigenvalues, _) = symmetric.eigh(faer::Side::Lower).unwrap();
        let scale = eigenvalues
            .iter()
            .map(|v| v.abs())
            .fold(0.0_f64, f64::max)
            .max(1.0);
        // Eigenvalues below this relative threshold count as genuinely negative.
        let threshold = -1.0e-8 * scale;
        let smallest = eigenvalues.iter().copied().fold(f64::INFINITY, f64::min);
        let negatives = eigenvalues.iter().filter(|&&v| v < threshold).count();
        min_raw_eig = min_raw_eig.min(smallest);
        max_raw_eig = max_raw_eig.max(scale);
        if negatives > 0 {
            indefinite_rows += 1;
            total_neg_dirs += negatives;
        }
    }

    let rel_min = min_raw_eig / max_raw_eig.max(1.0);
    eprintln!(
        "[#1190] real-data curvature anchor (K={k}, N={}): RAW assembled H_tt \
         min_eig={min_raw_eig:.6e} (rel={rel_min:.3e}) indefinite_rows={indefinite_rows}/{} \
         total_neg_dirs={total_neg_dirs}",
        z_train.nrows(),
        system.rows.len()
    );
    assert!(
        rel_min >= -1.0e-8,
        "real-data curvature anchor is genuinely indefinite: raw assembled H_tt \
         min eigenvalue {min_raw_eig:.6e} (relative {rel_min:.3e}) is negative on \
         {indefinite_rows}/{} rows ({total_neg_dirs} negative directions) — the \
         d=2 atoms are under-identified on real OLMo activations (#1190). The \
         curvature anchor must be PD (or its negative directions must be genuine \
         closed-form gauge nulls, not data-supported directions).",
        system.rows.len()
    );
}
/// #1189 / #1026 regression: checks the linear-span anchor on the first 384 rows
/// of the mixed-layer OLMo fixture, then exercises a local model of the arrival
/// floor `max(ceiling * (K - 1) / K, SAE_FIT_DATA_COLLAPSE_EV_FLOOR)` on real,
/// synthetic, pathological and K=3 ceilings.
#[test]
pub(crate) fn olmo_real_outer_fit_does_not_pin_at_collapse_sentinel() {
    let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
        .join("tests/data/olmo_mixedlayer_pca64_768.npy");
    let z = read_npy_f32_2d(&path);
    assert_eq!(z.dim(), (768, 64), "real OLMo fixture shape");
    let z_train = z.slice(s![..384, ..]).to_owned();
    let k = 8usize;
    let term = real_data_torus_seed_term(z_train.view(), k, 2);
    let rho = SaeManifoldRho::new(0.0, 0.0, vec![array![0.0, 0.0]; k]);
    let objective = SaeManifoldOuterObjective::new(
        term,
        z_train.clone(),
        None,
        rho.clone(),
        0,
        0.04,
        1.0e-6,
        1.0e-6,
    );
    let anchor = linear_span_anchor(&objective.term, z_train.view())
        .expect("Eckart-Young anchor must be recoverable on the real fixture");

    // Total sum of squares about the per-column means, accumulated row-major.
    let sst = {
        let row_count = z_train.nrows() as f64;
        let means: Vec<f64> = (0..z_train.ncols())
            .map(|col| {
                let column_total = z_train
                    .column(col)
                    .iter()
                    .fold(0.0_f64, |acc, &value| acc + value);
                column_total / row_count
            })
            .collect();
        let mut total = 0.0_f64;
        for row in z_train.outer_iter() {
            for (value, mean) in row.iter().zip(means.iter()) {
                let centered = value - mean;
                total += centered * centered;
            }
        }
        total
    };
    let anchor_ev = 1.0 - anchor.residual_norm_sq / sst;
    assert!(
        anchor_ev.is_finite() && anchor_ev > SAE_FIT_DATA_COLLAPSE_EV_FLOOR,
        "real-data Eckart-Young anchor ceiling {anchor_ev:.5} is degenerate (#1189)."
    );
    eprintln!("[#1189] real-data anchor ceiling anchor_ev={anchor_ev:.5}");

    // Local model of the per-atom-share arrival floor, clamped from below by the
    // absolute data-collapse floor.
    let arrival_floor_k = |achievable_ceiling: f64, k_active: usize| -> f64 {
        let atoms = k_active.max(1) as f64;
        let share = (atoms - 1.0) / atoms;
        (achievable_ceiling * share).max(SAE_FIT_DATA_COLLAPSE_EV_FLOOR)
    };
    let arrival_floor = |achievable_ceiling: f64| -> f64 { arrival_floor_k(achievable_ceiling, 1) };

    let real_regime_ceiling = 0.40_f64;
    for k in [1usize, 2, 8] {
        let f = arrival_floor_k(real_regime_ceiling, k);
        eprintln!("[#1189] real regime K={k}: ceiling={real_regime_ceiling} floor={f:.5}");
        assert!(
            f < real_regime_ceiling,
            "[#1189] arrival floor {f:.5} (K={k}) is not strictly below the achievable real-data \
             ceiling {real_regime_ceiling}: a genuine fit AT the ceiling would be rejected and \
             demoted to the collapsing cascade that pins the loop at the 1e12 sentinel."
        );
    }

    let synthetic_ceiling = 0.95_f64;
    for k in [1usize, 2, 8] {
        let f = arrival_floor_k(synthetic_ceiling, k);
        assert!(
            f < synthetic_ceiling && 0.94 >= f,
            "[#1189] synthetic floor {f:.5} (K={k}) must sit below the achievable ceiling \
             {synthetic_ceiling} so a genuine planted-harmonic recovery (EV ≈ 0.94) is accepted."
        );
    }

    let pathological_floor = arrival_floor(0.0);
    assert!(
        pathological_floor >= SAE_FIT_DATA_COLLAPSE_EV_FLOOR,
        "the #1189 floor dropped below the data-collapse threshold on a pathological ceiling \
         (floor {pathological_floor:.5} < {SAE_FIT_DATA_COLLAPSE_EV_FLOOR}) (#1189)."
    );

    for k in [1usize, 2, 8] {
        let f = arrival_floor_k(anchor_ev, k);
        assert!(
            f >= SAE_FIT_DATA_COLLAPSE_EV_FLOOR
                && f < anchor_ev.max(SAE_FIT_DATA_COLLAPSE_EV_FLOOR + 1e-9),
            "real-data anchor floor {f:.5} (K={k}) fell outside [{SAE_FIT_DATA_COLLAPSE_EV_FLOOR}, \
             anchor ceiling {anchor_ev:.5}) (#1189)."
        );
    }

    let k3_linear_ceiling = 0.30_f64;
    let k3_curved_arrival = 0.2461_f64;
    let k3_floor = arrival_floor_k(k3_linear_ceiling, 3);
    eprintln!(
        "[#1026] K=3 ceiling={k3_linear_ceiling:.4} curved_arrival={k3_curved_arrival:.4} \
         floor={k3_floor:.4}"
    );
    assert!(
        k3_curved_arrival >= k3_floor,
        "[#1026] the per-atom-share floor {k3_floor:.4} still demotes a genuine curved K=3 \
         arrival at EV {k3_curved_arrival:.4} (linear ceiling {k3_linear_ceiling:.4}); the K>=2 \
         co-collapse regression is NOT fixed."
    );
    assert!(
        k3_floor < k3_linear_ceiling && k3_curved_arrival < k3_linear_ceiling,
        "[#1026] the per-atom-share floor {k3_floor:.4} must sit strictly below the FULL linear \
         ceiling {k3_linear_ceiling:.4} (else there is no forgiveness and the regression is \
         vacuous), and the curved arrival {k3_curved_arrival:.4} must lie in that forgiven band."
    );

    // The floor evaluated at increasing atom counts for the same ceiling.
    let [f1, f2, f3, f8] =
        [1usize, 2, 3, 8].map(|atoms| arrival_floor_k(k3_linear_ceiling, atoms));
    assert!(
        f1 <= f2 && f2 <= f3 && f3 <= f8,
        "[#1026] arrival floor is not monotone non-decreasing across K \
         (K=1 {f1:.4}, K=2 {f2:.4}, K=3 {f3:.4}, K=8 {f8:.4})."
    );
    assert!(
        f8 < k3_linear_ceiling,
        "[#1026] the share floor reached/exceeded the full ceiling at large K \
         (K=8 {f8:.4} >= ceiling {k3_linear_ceiling:.4})."
    );
    assert_eq!(SAE_FIT_DATA_COLLAPSE_COST, 1.0e12);
}
/// Reads a 2-D, C-contiguous, little-endian `float32` `.npy` file and widens it
/// to an `Array2<f64>`.
///
/// Format major version 1 is parsed with a 2-byte header length; any other major
/// version is parsed with a 4-byte header length. The shape is taken from the
/// first parenthesised tuple in the header.
///
/// # Panics
/// Panics (this is a test-fixture loader) when the file cannot be read, is not a
/// `.npy` file, has a truncated preamble/header/payload, is not `<f4`, contains
/// `True` in its header (Fortran order), is not 2-D, or declares a shape whose
/// byte size overflows `usize`.
pub(crate) fn read_npy_f32_2d(path: &std::path::Path) -> Array2<f64> {
    let bytes = std::fs::read(path).unwrap_or_else(|e| panic!("read {}: {e}", path.display()));
    assert!(
        bytes.len() > 10 && &bytes[0..6] == b"\x93NUMPY",
        "not a .npy"
    );
    let major = bytes[6];
    let (hdr_start, hdr_len) = if major == 1 {
        (10usize, usize::from(u16::from_le_bytes([bytes[8], bytes[9]])))
    } else {
        // The 4-byte length field ends at offset 12; the magic check above only
        // guarantees 11 bytes, so verify before indexing.
        assert!(
            bytes.len() >= 12,
            "truncated .npy preamble (format version {major})"
        );
        (
            12usize,
            u32::from_le_bytes([bytes[8], bytes[9], bytes[10], bytes[11]]) as usize,
        )
    };
    // Reject a header length that runs past the end of the file with a clear
    // message instead of an out-of-range slice panic.
    let data_off = hdr_start
        .checked_add(hdr_len)
        .filter(|&end| end <= bytes.len())
        .unwrap_or_else(|| {
            panic!(
                "truncated .npy header: declared {hdr_len} bytes at offset {hdr_start}, file has {}",
                bytes.len()
            )
        });
    let header = std::str::from_utf8(&bytes[hdr_start..data_off]).unwrap();
    assert!(
        header.contains("'<f4'") || header.contains("\"<f4\""),
        "fixture must be little-endian float32; header: {header}"
    );
    assert!(!header.contains("True"), "fixture must be C-contiguous");
    let open = header.find('(').unwrap();
    let close = header[open..].find(')').unwrap() + open;
    let dims: Vec<usize> = header[open + 1..close]
        .split(',')
        .map(str::trim)
        .filter(|s| !s.is_empty())
        .map(|s| s.parse::<usize>().unwrap())
        .collect();
    assert_eq!(dims.len(), 2, "fixture must be 2-D");
    let (n, p) = (dims[0], dims[1]);
    // Overflow-checked payload size so a corrupt shape cannot wrap around and
    // slip past the truncation check.
    let payload_len = n
        .checked_mul(p)
        .and_then(|count| count.checked_mul(4))
        .unwrap_or_else(|| panic!("fixture shape ({n}, {p}) overflows usize"));
    let payload = &bytes[data_off..];
    assert!(payload.len() >= payload_len, "truncated payload");
    // Row-major decode: each 4-byte little-endian word is one f32 element.
    let values: Vec<f64> = payload[..payload_len]
        .chunks_exact(4)
        .map(|word| f64::from(f32::from_le_bytes([word[0], word[1], word[2], word[3]])))
        .collect();
    Array2::from_shape_vec((n, p), values).expect("decoded exactly n * p elements")
}
/// Verifies that `record_fit_data_collapse_if_needed` records a terminal collapse
/// for a fit whose explained variance lies between `SAE_FIT_DATA_COLLAPSE_EV_FLOOR`
/// and the bar returned by `collapse_ev_bar`, and that the recorded floor equals
/// that bar.
#[test]
pub(crate) fn fit_data_collapse_bar_is_data_derived_not_absolute_floor_1522() {
    const TWO_PI: f64 = 2.0 * std::f64::consts::PI;

    // Four equally spaced points on a unit-period circle with a hand-built
    // {1, sin, cos} basis and its derivative with respect to the coordinate.
    let latent = array![[0.0_f64], [0.25], [0.5], [0.75]];
    let n = latent.nrows();
    let mut basis = Array2::<f64>::zeros((n, 3));
    let mut basis_jet = Array3::<f64>::zeros((n, 3, 1));
    for (row, &t) in latent.column(0).iter().enumerate() {
        let angle = TWO_PI * t;
        basis[[row, 0]] = 1.0;
        basis[[row, 1]] = angle.sin();
        basis[[row, 2]] = angle.cos();
        basis_jet[[row, 1, 0]] = TWO_PI * angle.cos();
        basis_jet[[row, 2, 0]] = -TWO_PI * angle.sin();
    }
    let atom = SaeManifoldAtom::new(
        "circle",
        SaeAtomBasisKind::Periodic,
        1,
        basis,
        basis_jet,
        Array2::<f64>::zeros((3, 2)),
        Array2::<f64>::eye(3),
    )
    .unwrap();
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((n, 1)),
        vec![latent],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::softmax(1.0),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();

    let target = array![[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]];
    let total_basis: usize = term.atoms.iter().map(|atom| atom.basis_size()).sum();
    let dictionary_rank = total_basis.min(target.nrows()).min(target.ncols());
    let ceiling = crate::manifold::outer_objective::pca_ev_ceiling(
        target.view(),
        dictionary_rank,
    );
    assert!(
        (ceiling - 1.0).abs() < 1e-9,
        "rank-{dictionary_rank} PCA ceiling of the unit-circle target must be 1.0; got {ceiling}"
    );
    let derived_bar = crate::manifold::outer_objective::collapse_ev_bar(
        target.view(),
        dictionary_rank,
    );
    assert!(
        derived_bar > SAE_FIT_DATA_COLLAPSE_EV_FLOOR,
        "data-derived bar {derived_bar} must exceed the old absolute floor \
         {SAE_FIT_DATA_COLLAPSE_EV_FLOOR}, or the test cannot distinguish them"
    );

    // Shrink the target by a constant factor to get a fit of known quality.
    let alpha = 1.0 - (0.70_f64).sqrt();
    let fitted = &target * alpha;
    let ssr: f64 = target
        .iter()
        .zip(fitted.iter())
        .map(|(t, f)| {
            let residual = t - f;
            residual * residual
        })
        .sum();
    let sst: f64 = target.iter().map(|t| t * t).sum();
    let ev = 1.0 - ssr / sst;
    assert!(
        ev > SAE_FIT_DATA_COLLAPSE_EV_FLOOR && ev < derived_bar,
        "fit EV {ev} must sit STRICTLY between the old floor \
         {SAE_FIT_DATA_COLLAPSE_EV_FLOOR} and the derived bar {derived_bar}"
    );

    let assignments = Array2::<f64>::ones((n, 1));
    let recorded = term
        .record_fit_data_collapse_if_needed(target.view(), fitted.view(), assignments.view(), 3)
        .unwrap();
    assert!(
        recorded,
        "fit EV {ev} is below the data-derived bar {derived_bar} (half the rank-K \
         PCA ceiling) and must be recorded as a collapse; the guard is still keying \
         on the absolute floor {SAE_FIT_DATA_COLLAPSE_EV_FLOOR} instead of the data"
    );
    let terminal = term
        .collapse_events()
        .iter()
        .find(|event| event.action == CollapseAction::Terminal)
        .expect("a terminal collapse event must be recorded");
    assert!(
        (terminal.floor - derived_bar).abs() < 1e-9,
        "ledger floor {} must be the data-derived bar {derived_bar}, not the absolute \
         {SAE_FIT_DATA_COLLAPSE_EV_FLOOR}",
        terminal.floor
    );
}
/// Compares `amortized_encode_batch_fast` against a per-row reference built from
/// `nearest_chart` + `amortized_warm_start`: the valid masks must be equal and
/// the coordinates must agree to `1e-12` on the OLMo L18 fixture.
#[test]
pub(crate) fn fast_encode_matches_per_row_warm_start() {
    let manifest_dir = std::path::Path::new(env!("CARGO_MANIFEST_DIR"));
    let mut path = manifest_dir.join("tests/data/olmo_l18_pca64_635.npy");
    if !path.exists() {
        path = manifest_dir.join("../../tests/data/olmo_l18_pca64_635.npy");
    }
    let z = read_npy_f32_2d(&path);
    let n = z.nrows();
    let k = 1usize;
    let term = real_data_torus_seed_term(z.view(), k, 3);
    // Largest row norm of the data, handed to the atlas builder as its bound.
    let norm_bound = z
        .outer_iter()
        .map(|row| row.dot(&row).sqrt())
        .fold(0.0_f64, f64::max);
    let atlas = crate::encode::EncodeAtlas::build(
        &term.atoms,
        &vec![1.0_f64; k],
        norm_bound,
        crate::encode::AtlasConfig::default(),
    )
    .expect("atlas builds");
    let atom = &term.atoms[0];
    let amps = ndarray::Array1::<f64>::ones(n);
    let evaluator = atom.basis_evaluator.as_ref().unwrap().clone();

    // Per-row reference: route to the nearest chart, then apply its warm start.
    let mut ref_coords = ndarray::Array2::<f64>::zeros((n, atom.latent_dim));
    let mut ref_valid = vec![false; n];
    for row in 0..n {
        let chart = match crate::encode::nearest_chart(
            &atlas.atoms[0],
            z.row(row),
            atom,
            evaluator.as_ref(),
        ) {
            Some((cidx, _)) => &atlas.atoms[0].charts[cidx],
            None => continue,
        };
        if let Some(t) = crate::encode::amortized_warm_start(chart, z.row(row), amps[row]) {
            ref_coords.row_mut(row).assign(&t);
            ref_valid[row] = true;
        }
    }

    let (fast_coords, fast_valid) = atlas
        .amortized_encode_batch_fast(atom, 0, z.view(), amps.view())
        .expect("batched fast encode runs");
    let mut max_diff = 0.0_f64;
    for row in 0..n {
        assert_eq!(
            fast_valid[row], ref_valid[row],
            "valid-mask mismatch at row {row} (routing/predictor disagreement)"
        );
        if !ref_valid[row] {
            continue;
        }
        for c in 0..atom.latent_dim {
            let gap = (fast_coords[[row, c]] - ref_coords[[row, c]]).abs();
            max_diff = max_diff.max(gap);
        }
    }
    assert!(
        max_diff < 1.0e-12,
        "batched fast-encode must match the per-row warm-start to 1e-12 (same affine \
         map, GEMM-batched); max|Δcoord| = {max_diff:.3e}"
    );
    let valid_total = ref_valid.iter().filter(|&&v| v).count();
    assert!(
        valid_total > n / 2,
        "fixture must produce valid encodes on most rows; got {}",
        valid_total
    );
}
/// Compares `amortized_reconstruct_batch_fast` against decoding each row of
/// `amortized_encode_batch_fast` individually: masks must match, rows flagged
/// invalid must be exactly zero, and valid rows must agree to `1e-10`.
#[test]
pub(crate) fn fast_reconstruct_matches_per_row_decode() {
    let manifest_dir = std::path::Path::new(env!("CARGO_MANIFEST_DIR"));
    let mut path = manifest_dir.join("tests/data/olmo_l18_pca64_635.npy");
    if !path.exists() {
        path = manifest_dir.join("../../tests/data/olmo_l18_pca64_635.npy");
    }
    let z = read_npy_f32_2d(&path);
    let n = z.nrows();
    let p = z.ncols();
    let k = 1usize;
    let term = real_data_torus_seed_term(z.view(), k, 3);
    // Largest row norm of the data, handed to the atlas builder as its bound.
    let norm_bound = z
        .outer_iter()
        .map(|row| row.dot(&row).sqrt())
        .fold(0.0_f64, f64::max);
    let atlas = crate::encode::EncodeAtlas::build(
        &term.atoms,
        &vec![1.0_f64; k],
        norm_bound,
        crate::encode::AtlasConfig::default(),
    )
    .expect("atlas builds");
    let atom = &term.atoms[0];
    let amps = ndarray::Array1::<f64>::ones(n);
    let evaluator = atom.basis_evaluator.as_ref().unwrap().clone();
    let (fast_recon, fast_valid) = atlas
        .amortized_reconstruct_batch_fast(atom, 0, z.view(), amps.view())
        .expect("batched fast reconstruct runs");
    let (coords, enc_valid) = atlas
        .amortized_encode_batch_fast(atom, 0, z.view(), amps.view())
        .expect("batched fast encode runs");

    let mut max_diff = 0.0_f64;
    let mut valid_rows = 0usize;
    for row in 0..n {
        assert_eq!(
            fast_valid[row], enc_valid[row],
            "reconstruct valid-mask must equal encode valid-mask at row {row}"
        );
        if fast_valid[row] {
            valid_rows += 1;
            // Reference decode: evaluate the basis at this row's coordinates and
            // multiply by the decoder coefficients.
            let single = coords.row(row).insert_axis(ndarray::Axis(0)).to_owned();
            let (phi_row, _jet) = evaluator.evaluate(single.view()).expect("single basis eval");
            let decoded_row = phi_row.dot(&atom.decoder_coefficients);
            for col in 0..p {
                let expect = amps[row] * decoded_row[[0, col]];
                let gap = (fast_recon[[row, col]] - expect).abs();
                max_diff = max_diff.max(gap);
            }
        } else {
            for col in 0..p {
                assert_eq!(
                    fast_recon[[row, col]],
                    0.0,
                    "uncertified row {row} must decode to zero, got {}",
                    fast_recon[[row, col]]
                );
            }
        }
    }
    assert!(
        max_diff < 1.0e-10,
        "batched fast reconstruct must match the per-row decode z·Φ(t̂)·B (same GEMM, \
         batched basis eval); max|Δrecon| = {max_diff:.3e}"
    );
    assert!(
        valid_rows > n / 2,
        "fixture must reconstruct most rows; got {valid_rows} valid of {n}"
    );
}
/// On the held-out split, compares the relative reconstruction error of
/// `amortized_reconstruct_batch_fast` with that of `certified_encode_row` on rows
/// where both are valid: the fast path must cover at least as many rows and its
/// mean relative error must be within 5% of the certified mean.
#[test]
fn fast_forward_is_accuracy_parity_with_certified() {
    let (z_tr, z) = olmo_l18_oos_split();
    let n = z.nrows();
    let p = z.ncols();
    let term = real_data_torus_seed_term(z_tr.view(), 1, 3);
    // Largest row norm across the training rows and then the held-out rows.
    let max_row_norm = |matrix: &Array2<f64>, start: f64| -> f64 {
        matrix
            .outer_iter()
            .map(|row| row.dot(&row).sqrt())
            .fold(start, f64::max)
    };
    let norm_bound = max_row_norm(&z, max_row_norm(&z_tr, 0.0_f64));
    let atlas = crate::encode::EncodeAtlas::build(
        &term.atoms,
        &vec![1.0_f64; 1],
        norm_bound,
        crate::encode::AtlasConfig::default(),
    )
    .unwrap();
    let atom = &term.atoms[0];
    let amps = ndarray::Array1::<f64>::ones(n);
    let evaluator = atom.basis_evaluator.as_ref().unwrap().clone();
    let (fast_recon, fast_valid) = atlas
        .amortized_reconstruct_batch_fast(atom, 0, z.view(), amps.view())
        .unwrap();

    // (fast, certified) relative errors on rows where both paths are valid.
    let mut both: Vec<(f64, f64)> = Vec::new();
    let mut fast_valid_count = 0usize;
    let mut cert_valid_count = 0usize;
    for row in 0..n {
        let xr = z.row(row);
        let xn = xr.dot(&xr).sqrt().max(1e-12);

        let fast_e = if fast_valid[row] {
            fast_valid_count += 1;
            let squared = (0..p).fold(0.0_f64, |acc, c| {
                let d = fast_recon[[row, c]] - xr[c];
                acc + d * d
            });
            Some(squared.sqrt() / xn)
        } else {
            None
        };

        let (coords, cert) = atlas.certified_encode_row(atom, 0, xr, amps[row]).unwrap();
        let cert_e = if cert.beta.is_finite() && cert.h.is_finite() {
            cert_valid_count += 1;
            let single = coords.insert_axis(ndarray::Axis(0));
            let (phi, _) = evaluator.evaluate(single.view()).unwrap();
            let dec = phi.dot(&atom.decoder_coefficients);
            let squared = (0..p).fold(0.0_f64, |acc, c| {
                let d = amps[row] * dec[[0, c]] - xr[c];
                acc + d * d
            });
            Some(squared.sqrt() / xn)
        } else {
            None
        };

        if let (Some(f), Some(c)) = (fast_e, cert_e) {
            both.push((f, c));
        }
    }
    assert!(
        fast_valid_count >= cert_valid_count,
        "fast path must cover >= certified rows; fast={fast_valid_count} cert={cert_valid_count}"
    );
    assert!(
        both.len() > n / 8,
        "need a non-trivial co-valid set; got {} of {n}",
        both.len()
    );
    // Means over the co-valid set; the divisor is floored at 1.
    let divisor = both.len().max(1) as f64;
    let fast_mean = both.iter().map(|pair| pair.0).sum::<f64>() / divisor;
    let cert_mean = both.iter().map(|pair| pair.1).sum::<f64>() / divisor;
    assert!(
        fast_mean <= 1.05 * cert_mean,
        "fast forward must be accuracy-parity with certified on co-valid rows; \
         fast_mean={fast_mean:.4} cert_mean={cert_mean:.4} ratio={:.3}",
        fast_mean / cert_mean
    );
}
pub(crate) fn olmo_l18_oos_split() -> (Array2<f64>, Array2<f64>) {
let mani = std::path::Path::new(env!("CARGO_MANIFEST_DIR"));
let mut path = mani.join("tests/data/olmo_l18_pca64_635.npy");
if !path.exists() {
path = mani.join("../../tests/data/olmo_l18_pca64_635.npy");
}
let z = read_npy_f32_2d(&path);
let n = z.nrows();
let n_tr = (n * 6) / 10;
(
z.slice(s![..n_tr, ..]).to_owned(),
z.slice(s![n_tr.., ..]).to_owned(),
)
}
/// Returns the sum of the squared entries of `z`, accumulated in row-major order.
fn oos_sq_sum(z: &Array2<f64>) -> f64 {
    z.iter().fold(0.0, |total, &value| total + value * value)
}
/// Fits a single periodic atom to `z_tr` by alternating encode / decoder-refit
/// steps and reports in-sample and out-of-sample explained variance.
///
/// * `z_tr`, `z_te` — training and held-out matrices (one row per observation).
/// * `d`, `h` — forwarded to `TorusHarmonicEvaluator::new(d, h)`; `d` is also the
///   atom's latent dimension.
/// * `iters` — number of refinement rounds. Each round builds an atlas, re-encodes
///   the training rows with `amortized_encode_batch_fast`, overwrites the
///   coordinates of rows flagged valid, and refits the decoder.
/// * `data_driven` — when `true` the atlas comes from
///   `EncodeAtlas::build_data_driven` (which also receives `maxc`); otherwise from
///   `EncodeAtlas::build`, and `maxc` is unused.
///
/// Returns `(ev_in, ev_oos)`, each computed as `1 - residual_ss / total_ss` with
/// `total_ss` the raw (uncentred) sum of squares from `oos_sq_sum`. `ev_in` uses
/// the atom's stored basis values; `ev_oos` uses
/// `amortized_reconstruct_batch_fast` on `z_te`.
///
/// # Panics
/// Panics if any evaluator, atlas, encode, reconstruct or Cholesky step fails.
pub(crate) fn oos_train_curved(
    z_tr: &Array2<f64>,
    z_te: &Array2<f64>,
    d: usize,
    h: usize,
    iters: usize,
    data_driven: bool,
    maxc: usize,
) -> (f64, f64) {
    let n_tr = z_tr.nrows();
    let n_te = z_te.nrows();
    let p = z_tr.ncols();
    let tot_tr = oos_sq_sum(z_tr);
    let tot_te = oos_sq_sum(z_te);
    // Largest row norm over both splits, passed to the atlas builders.
    let mut nb = 0.0_f64;
    for r in 0..n_tr {
        nb = nb.max(z_tr.row(r).dot(&z_tr.row(r)).sqrt());
    }
    for r in 0..n_te {
        nb = nb.max(z_te.row(r).dot(&z_te.row(r)).sqrt());
    }
    let ev_eval = Arc::new(TorusHarmonicEvaluator::new(d, h).unwrap());
    let seed = sae_pca_seed_initial_coords(
        z_tr.view(),
        &vec![SaeAtomBasisKind::Periodic; 1],
        &vec![d],
    )
    .unwrap();
    let mut coords = seed.slice(s![0, .., 0..d]).to_owned();
    // Builds the atom for a set of coordinates: evaluate the basis, then solve the
    // ridge-stabilised (1e-8 on the diagonal) system for the decoder.
    let build = |coords: &Array2<f64>| -> SaeManifoldAtom {
        let (phi, jet) = ev_eval.evaluate(coords.view()).unwrap();
        let m = phi.ncols();
        let mut xtx = fast_ata(&phi);
        for i in 0..m {
            xtx[[i, i]] += 1e-8;
        }
        // `z_tr` is already a borrowed owned array, so it is passed through
        // directly instead of cloning the whole training matrix on every refit.
        let xtz = fast_atb(&phi, z_tr);
        let dec = xtx.cholesky(Side::Lower).unwrap().solve_mat(&xtz);
        SaeManifoldAtom::new("t", SaeAtomBasisKind::Periodic, d, phi, jet, dec, Array2::eye(m))
            .unwrap()
            .with_basis_evaluator(ev_eval.clone())
    };
    // Builds the encode atlas for the current atom, honouring `data_driven`.
    let mk_atlas = |atom: &SaeManifoldAtom, coords: &Array2<f64>| {
        if data_driven {
            crate::encode::EncodeAtlas::build_data_driven(
                std::slice::from_ref(atom),
                std::slice::from_ref(coords),
                &[1.0],
                nb,
                maxc,
                crate::encode::AtlasConfig::default(),
            )
            .unwrap()
        } else {
            crate::encode::EncodeAtlas::build(
                std::slice::from_ref(atom),
                &[1.0],
                nb,
                crate::encode::AtlasConfig::default(),
            )
            .unwrap()
        }
    };
    let amps_tr = ndarray::Array1::<f64>::ones(n_tr);
    let mut atom = build(&coords);
    for _ in 0..iters {
        let atlas = mk_atlas(&atom, &coords);
        let (ec, v) = atlas
            .amortized_encode_batch_fast(&atom, 0, z_tr.view(), amps_tr.view())
            .unwrap();
        // Only rows the encoder flags as valid get new coordinates; the rest keep
        // their previous ones.
        for i in 0..n_tr {
            if v[i] {
                coords.row_mut(i).assign(&ec.row(i));
            }
        }
        atom = build(&coords);
    }
    // In-sample residual from the atom's stored basis values.
    let rt = atom.basis_values.dot(&atom.decoder_coefficients);
    let mut etr = 0.0;
    for r in 0..n_tr {
        for c in 0..p {
            let dd = rt[[r, c]] - z_tr[[r, c]];
            etr += dd * dd;
        }
    }
    let ev_in = 1.0 - etr / tot_tr;
    // Out-of-sample residual from the batched fast reconstruction; the returned
    // valid mask is not consulted.
    let atlas = mk_atlas(&atom, &coords);
    let amps_te = ndarray::Array1::<f64>::ones(n_te);
    let (rte, _vm) = atlas
        .amortized_reconstruct_batch_fast(&atom, 0, z_te.view(), amps_te.view())
        .unwrap();
    let mut ete = 0.0;
    for r in 0..n_te {
        for c in 0..p {
            let dd = rte[[r, c]] - z_te[[r, c]];
            ete += dd * dd;
        }
    }
    let ev_oos = 1.0 - ete / tot_te;
    (ev_in, ev_oos)
}
/// Trains a d=2, 3-harmonic atom for 5 rounds (grid atlas) and requires its
/// out-of-sample explained variance to lie strictly between 0.15 and 0.40.
#[test]
fn curved_atom_oos_competitive_with_real_topk_sae() {
    let (train, held_out) = olmo_l18_oos_split();
    let curved_oos = oos_train_curved(&train, &held_out, 2, 3, 5, false, 0).1;
    eprintln!(
        "curved d=2 OOS EV={curved_oos:.4} (real TopK SAE k=2 OOS ≈ 0.217–0.242, \
         see tests/sae/real_topk_sae_baseline.py)"
    );
    let in_band = curved_oos > 0.15 && curved_oos < 0.40;
    assert!(
        in_band,
        "curved d=2 OOS EV must sit in the real-TopK-SAE-competitive band [0.15,0.40] \
         (measured ~0.22); got {curved_oos:.4}"
    );
}
/// Compares 3 vs 4 harmonics at d=2 after 5 rounds: the in-sample explained
/// variance must rise by more than 0.03 while the out-of-sample explained
/// variance must rise by less than 0.01.
#[test]
fn more_harmonics_overfit_out_of_sample() {
    let (train, held_out) = olmo_l18_oos_split();
    let (in3, oos3) = oos_train_curved(&train, &held_out, 2, 3, 5, false, 0);
    let (in4, oos4) = oos_train_curved(&train, &held_out, 2, 4, 5, false, 0);
    eprintln!("h=3: in={in3:.4} OOS={oos3:.4} h=4: in={in4:.4} OOS={oos4:.4}");
    let in_sample_gain = in4 - in3;
    let held_out_gain = oos4 - oos3;
    assert!(
        in_sample_gain > 0.03,
        "extra harmonic must raise IN-SAMPLE EV (capacity added); in3={in3:.4} in4={in4:.4}"
    );
    assert!(
        held_out_gain < 0.01,
        "extra harmonic must NOT improve OOS EV (it overfits); oos3={oos3:.4} oos4={oos4:.4}"
    );
}
/// Compares the untrained seed (0 rounds) with 6 training rounds at d=2, h=3:
/// out-of-sample explained variance must improve, and the in-sample gain must
/// exceed three times the out-of-sample gain.
#[test]
fn manifold_training_loop_generalizes_but_overfits_out_of_sample() {
    let (train, held_out) = olmo_l18_oos_split();
    let (in0, oos0) = oos_train_curved(&train, &held_out, 2, 3, 0, false, 0);
    let (in6, oos6) = oos_train_curved(&train, &held_out, 2, 3, 6, false, 0);
    eprintln!("seed: in={in0:.4} OOS={oos0:.4} trained: in={in6:.4} OOS={oos6:.4}");
    assert!(
        oos6 > oos0,
        "training must improve OOS reconstruction over the seed; oos0={oos0:.4} oos6={oos6:.4}"
    );
    let in_sample_gain = in6 - in0;
    let held_out_gain = oos6 - oos0;
    assert!(
        in_sample_gain > 3.0 * held_out_gain,
        "in-sample gain must dwarf OOS gain (overfitting); din={:.4} doos={:.4}",
        in_sample_gain,
        held_out_gain
    );
}
/// With the data-driven atlas (`maxc = 256`, 1 harmonic, 5 rounds), the d=4 atom's
/// out-of-sample explained variance must exceed 1.3x that of the d=2 atom, which
/// itself must be positive.
#[test]
fn data_driven_higher_latent_dim_helps_out_of_sample() {
    let (train, held_out) = olmo_l18_oos_split();
    let oos_d2 = oos_train_curved(&train, &held_out, 2, 1, 5, true, 256).1;
    let oos_d4 = oos_train_curved(&train, &held_out, 4, 1, 5, true, 256).1;
    eprintln!("OOS data-driven d=2 EV={oos_d2:.4} d=4 EV={oos_d4:.4}");
    let baseline_positive = oos_d2 > 0.0;
    assert!(
        baseline_positive && oos_d4 > 1.3 * oos_d2,
        "data-driven d=4 must beat d=2 OUT-OF-SAMPLE by >30% (latent-dim unlock \
         generalises); oos_d2={oos_d2:.4} oos_d4={oos_d4:.4}"
    );
}
/// Runs `certified_encode_row` over a 41x41 grid of targets in `[-0.30, 0.30]^2`
/// against the closed-form curve `t -> (cos 2*pi*t, sin 4*pi*t)` and checks, for
/// certified rows, that the analytic gradient is below `1e-4` and that the
/// reconstruction error exceeds a 20000-point brute-force minimum by less than
/// `5e-3`. The atom's decoder is presumably set up to realise that curve
/// (entries `[2, 0]` and `[3, 1]` are 1) — confirm against the evaluator's basis
/// ordering.
#[test]
fn certified_encode_is_globally_sound_near_self_crossing() {
    use ndarray::{Array1, Array2};
    const PI: f64 = std::f64::consts::PI;

    let evaluator = Arc::new(PeriodicHarmonicEvaluator::new(5).unwrap());
    let seed_count = 64usize;
    let seed: Array2<f64> =
        Array2::from_shape_fn((seed_count, 1), |(i, _)| i as f64 / seed_count as f64);
    let (phi, jet) = evaluator.evaluate(seed.view()).unwrap();
    let m = phi.ncols();
    let mut decoder = Array2::<f64>::zeros((m, 2));
    decoder[[2, 0]] = 1.0;
    decoder[[3, 1]] = 1.0;
    let atom = SaeManifoldAtom::new(
        "fig8",
        SaeAtomBasisKind::Periodic,
        1,
        phi,
        jet,
        decoder,
        Array2::<f64>::eye(m),
    )
    .unwrap()
    .with_basis_evaluator(evaluator.clone());

    // Closed-form point on the curve at coordinate `t`.
    let recon = |t: f64| -> [f64; 2] {
        let a = 2.0 * PI * t;
        [a.cos(), (2.0 * a).sin()]
    };
    // Derivative with respect to `t` of half the squared distance from `x` to the curve.
    let grad = |t: f64, x: &[f64; 2]| -> f64 {
        let a = 2.0 * PI * t;
        let velocity = [-2.0 * PI * a.sin(), 4.0 * PI * (2.0 * a).cos()];
        let on_curve = recon(t);
        -(velocity[0] * (x[0] - on_curve[0]) + velocity[1] * (x[1] - on_curve[1]))
    };
    // Brute-force minimum distance from `x` to the curve over a uniform grid in `t`.
    let global_min_err = |x: &[f64; 2]| -> f64 {
        let grid = 20000;
        (0..grid)
            .map(|i| {
                let r = recon(i as f64 / grid as f64);
                (r[0] - x[0]).powi(2) + (r[1] - x[1]).powi(2)
            })
            .fold(f64::INFINITY, f64::min)
            .sqrt()
    };

    let atlas = crate::encode::EncodeAtlas::build(
        std::slice::from_ref(&atom),
        &[1.0],
        1.6,
        crate::encode::AtlasConfig {
            grid_resolution: 64,
            ridge: 1e-10,
            newton_steps: 8,
        },
    )
    .unwrap();

    let mut certified = 0usize;
    let mut worst_grad = 0.0_f64;
    let mut worst_global_excess = 0.0_f64;
    let steps = 41usize;
    // Maps a grid index to its coordinate in [-0.30, 0.30].
    let axis = |i: usize| -> f64 { -0.30 + 0.60 * i as f64 / (steps - 1) as f64 };
    for ix in 0..steps {
        let x0 = axis(ix);
        for iy in 0..steps {
            let x1 = axis(iy);
            let xv = Array1::from(vec![x0, x1]);
            let (coord, cert) = atlas
                .certified_encode_row(&atom, 0, xv.view(), 1.0)
                .unwrap();
            if !cert.certified() {
                continue;
            }
            certified += 1;
            let t = coord[0];
            worst_grad = worst_grad.max(grad(t, &[x0, x1]).abs());
            let r = recon(t);
            let cert_err = ((r[0] - x0).powi(2) + (r[1] - x1).powi(2)).sqrt();
            let excess = cert_err - global_min_err(&[x0, x1]);
            worst_global_excess = worst_global_excess.max(excess);
        }
    }
    eprintln!("certified={certified}/{} worst|grad|={worst_grad:.2e} worst global excess={worst_global_excess:.4}",
        steps * steps);
    assert!(
        worst_grad < 1e-4,
        "certificate's LOCAL claim must hold: certified coords must be stationary \
         points (‖∇‖≈0); worst |∇| = {worst_grad:.2e}"
    );
    assert!(certified > steps * steps / 2, "fixture must certify most targets; got {certified}");
    assert!(
        worst_global_excess < 5e-3,
        "certified encode must be GLOBALLY sound (top-K routing): worst excess over \
         the global min = {worst_global_excess:.5} (was ~0.08 with single-chart routing)"
    );
}