use crate::linalg::faer_ndarray::fast_ata;
use super::*;
use ndarray::array;
/// Builds a `k`-atom [`SaeManifoldTerm`] whose atoms are all 2-D periodic ("torus") atoms
/// seeded from `z`.
///
/// For each atom the function:
/// 1. takes the first two coordinate columns of that atom's slab from
///    `sae_pca_seed_initial_coords`,
/// 2. evaluates the shared `TorusHarmonicEvaluator` (dimension 2, `num_harmonics` harmonics)
///    at those coordinates to obtain the basis matrix `phi` and its jet,
/// 3. solves `(G + 1e-8·I) · decoder = B` by Cholesky, where `G = fast_ata(&phi)` and
///    `B = fast_atb(&phi, &z)` (presumably `phiᵀ·phi` and `phiᵀ·z` — confirm against the
///    helpers), and
/// 4. wraps the result in a `SaeManifoldAtom` named `"torus"` with an identity matrix as
///    the last constructor argument.
///
/// The returned term uses an assignment built from an all-zero `(n, k)` array (presumably
/// the assignment logits — confirm against `SaeAssignment`), `AssignmentMode::softmax(1.0)`,
/// and a product of two period-1 circles as the latent manifold of every atom.
///
/// # Panics
/// Panics if any of the underlying constructors, the basis evaluation, or the Cholesky
/// factorisation fails; this is a test fixture builder and treats those as fatal.
pub(crate) fn real_data_torus_seed_term(
    z: ArrayView2<'_, f64>,
    k: usize,
    num_harmonics: usize,
) -> SaeManifoldTerm {
    let n = z.nrows();
    let evaluator = Arc::new(
        TorusHarmonicEvaluator::new(2, num_harmonics).expect("construct torus harmonic evaluator"),
    );
    let basis_kinds = vec![SaeAtomBasisKind::Periodic; k];
    let atom_dims = vec![2usize; k];
    let seed_coords = sae_pca_seed_initial_coords(z, &basis_kinds, &atom_dims)
        .expect("PCA seed coordinates for the periodic atoms");

    // The owned copy of the data does not depend on the atom, so materialise it once
    // instead of re-copying the whole matrix on every loop iteration.
    let z_owned = z.to_owned();

    let mut atoms = Vec::with_capacity(k);
    let mut coords_blocks = Vec::with_capacity(k);
    let mut manifolds = Vec::with_capacity(k);
    for atom_idx in 0..k {
        let coords = seed_coords.slice(s![atom_idx, .., 0..2]).to_owned();
        let (phi, jet) = evaluator
            .evaluate(coords.view())
            .expect("evaluate torus basis at seed coordinates");
        let m = phi.ncols();

        // Tiny diagonal ridge so the Cholesky factorisation below does not fail on a
        // numerically rank-deficient Gram matrix.
        let mut xtx = fast_ata(&phi);
        for i in 0..m {
            xtx[[i, i]] += 1.0e-8;
        }
        let xtz = fast_atb(&phi, &z_owned);
        let decoder = xtx
            .cholesky(Side::Lower)
            .expect("Cholesky of the ridge-regularised Gram matrix")
            .solve_mat(&xtz);

        let atom = SaeManifoldAtom::new(
            "torus",
            SaeAtomBasisKind::Periodic,
            2,
            phi,
            jet,
            decoder,
            Array2::<f64>::eye(m),
        )
        .expect("construct torus atom")
        .with_basis_evaluator(evaluator.clone());

        atoms.push(atom);
        coords_blocks.push(coords);
        manifolds.push(LatentManifold::Product(vec![
            LatentManifold::Circle { period: 1.0 },
            LatentManifold::Circle { period: 1.0 },
        ]));
    }

    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::from_elem((n, k), 0.0),
        coords_blocks,
        manifolds,
        AssignmentMode::softmax(1.0),
    )
    .expect("construct seed assignment");
    SaeManifoldTerm::new(atoms, assignment).expect("construct seed manifold term")
}
/// Regression test for #1190: on the real OLMo fixture, every per-row `H_tt` block of the
/// assembled arrow-Schur system must be positive semi-definite up to a relative tolerance
/// of `1e-8`.
#[test]
pub(crate) fn olmo_real_curvature_anchor_is_positive_definite() {
    use crate::linalg::faer_ndarray::FaerEigh;

    // Load the fixture and train on its first 160 rows.
    let fixture = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
        .join("tests/data/olmo_mixedlayer_pca64_768.npy");
    let activations = read_npy_f32_2d(&fixture);
    assert_eq!(activations.dim(), (768, 64), "real OLMo fixture shape");
    let z_train = activations.slice(s![..160, ..]).to_owned();

    let k = 2usize;
    let mut term = real_data_torus_seed_term(z_train.view(), k, 3);
    let rho = SaeManifoldRho::new(0.0, 0.0, vec![array![0.0, 0.0]; k]);
    // Only the registry of the outer objective is needed here.
    let registry = SaeManifoldOuterObjective::new(
        term.clone(),
        z_train.clone(),
        None,
        rho.clone(),
        0,
        0.04,
        1.0e-6,
        1.0e-6,
    )
    .registry;

    let sys = term
        .assemble_arrow_schur(z_train.view(), &rho, registry.as_ref())
        .expect("assemble raw curvature anchor");

    // Spectrum statistics accumulated over every non-empty row block.
    let mut min_raw_eig = f64::INFINITY;
    let mut max_raw_eig = 0.0_f64;
    let mut indefinite_rows = 0usize;
    let mut total_neg_dirs = 0usize;
    for row_block in &sys.rows {
        let dim = row_block.htt.nrows();
        if dim != 0 {
            // Eigen-decompose the symmetric part of the block.
            let symmetric = Array2::<f64>::from_shape_fn((dim, dim), |(r, c)| {
                0.5 * (row_block.htt[[r, c]] + row_block.htt[[c, r]])
            });
            let (spectrum, _) = symmetric.eigh(faer::Side::Lower).unwrap();

            // Scale of this block's spectrum, never below 1 so the tolerance stays absolute
            // for tiny blocks.
            let spectral_scale = spectrum
                .iter()
                .map(|v| v.abs())
                .fold(0.0_f64, f64::max)
                .max(1.0);
            let neg_floor = -1.0e-8 * spectral_scale;
            let row_min = spectrum.iter().copied().fold(f64::INFINITY, f64::min);
            let row_neg = spectrum.iter().copied().filter(|&v| v < neg_floor).count();

            min_raw_eig = min_raw_eig.min(row_min);
            max_raw_eig = max_raw_eig.max(spectral_scale);
            indefinite_rows += usize::from(row_neg > 0);
            total_neg_dirs += row_neg;
        }
    }

    let rel_min = min_raw_eig / max_raw_eig.max(1.0);
    eprintln!(
        "[#1190] real-data curvature anchor (K={k}, N={}): RAW assembled H_tt \
         min_eig={min_raw_eig:.6e} (rel={rel_min:.3e}) indefinite_rows={indefinite_rows}/{} \
         total_neg_dirs={total_neg_dirs}",
        z_train.nrows(),
        sys.rows.len()
    );
    assert!(
        rel_min >= -1.0e-8,
        "real-data curvature anchor is genuinely indefinite: raw assembled H_tt \
         min eigenvalue {min_raw_eig:.6e} (relative {rel_min:.3e}) is negative on \
         {indefinite_rows}/{} rows ({total_neg_dirs} negative directions) — the \
         d=2 atoms are under-identified on real OLMo activations (#1190). The \
         curvature anchor must be PD (or its negative directions must be genuine \
         closed-form gauge nulls, not data-supported directions).",
        sys.rows.len()
    );
}
/// Regression test for #1189 / #1026: checks the Eckart-Young anchor on the real OLMo
/// fixture and pins the arithmetic of the curvature-walk arrival floor (absolute floor,
/// anchor-relative floor, and per-atom-share relaxation for `K >= 2`).
#[test]
pub(crate) fn olmo_real_outer_fit_does_not_pin_at_collapse_sentinel() {
    // Load the fixture and train on its first 384 rows.
    let fixture = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
        .join("tests/data/olmo_mixedlayer_pca64_768.npy");
    let activations = read_npy_f32_2d(&fixture);
    assert_eq!(activations.dim(), (768, 64), "real OLMo fixture shape");
    let z_train = activations.slice(s![..384, ..]).to_owned();

    let k = 8usize;
    let term = real_data_torus_seed_term(z_train.view(), k, 2);
    let rho = SaeManifoldRho::new(0.0, 0.0, vec![array![0.0, 0.0]; k]);
    let objective = SaeManifoldOuterObjective::new(
        term,
        z_train.clone(),
        None,
        rho.clone(),
        0,
        0.04,
        1.0e-6,
        1.0e-6,
    );
    let anchor = linear_span_anchor(&objective.term, z_train.view())
        .expect("Eckart-Young anchor must be recoverable on the real fixture");

    // Total sum of squares about the column means, accumulated in row-major order.
    let sst = {
        let (rows, cols) = z_train.dim();
        let means: Vec<f64> = (0..cols)
            .map(|col| (0..rows).fold(0.0_f64, |acc, row| acc + z_train[[row, col]]) / rows as f64)
            .collect();
        (0..rows).fold(0.0_f64, |total, row| {
            (0..cols).fold(total, |acc, col| {
                let centered = z_train[[row, col]] - means[col];
                acc + centered * centered
            })
        })
    };
    let anchor_ev = 1.0 - anchor.residual_norm_sq / sst;
    assert!(
        anchor_ev.is_finite() && anchor_ev > SAE_FIT_DATA_COLLAPSE_EV_FLOOR,
        "real-data Eckart-Young anchor ceiling {anchor_ev:.5} is degenerate (#1189)."
    );
    eprintln!("[#1189] real-data anchor ceiling anchor_ev={anchor_ev:.5}");

    // Arrival floor as a function of the achievable ceiling and the number of active atoms.
    // With a single atom it is the base gate; with two or more it is additionally capped by
    // the ceiling's (K-1)/K share, and never drops below the data-collapse floor.
    let arrival_floor_k = |achievable_ceiling: f64, k_active: usize| -> f64 {
        let base = CURVATURE_WALK_ARRIVAL_EV_FLOOR
            .min(CURVATURE_WALK_ARRIVAL_ANCHOR_FRACTION * achievable_ceiling)
            .max(SAE_FIT_DATA_COLLAPSE_EV_FLOOR);
        match k_active {
            0 | 1 => base,
            _ => {
                let atoms = k_active as f64;
                let per_atom_share_floor = achievable_ceiling * ((atoms - 1.0) / atoms);
                base.min(per_atom_share_floor.max(0.0))
                    .max(SAE_FIT_DATA_COLLAPSE_EV_FLOOR)
            }
        }
    };
    let arrival_floor = |achievable_ceiling: f64| -> f64 { arrival_floor_k(achievable_ceiling, 1) };

    // Real-data regime: the relative floor must relax below the absolute one.
    let real_regime_ceiling = 0.40_f64;
    let real_floor = arrival_floor(real_regime_ceiling);
    eprintln!(
        "[#1189] real regime: ceiling={real_regime_ceiling} absolute_floor={CURVATURE_WALK_ARRIVAL_EV_FLOOR} relative_floor={real_floor:.5}"
    );
    assert!(
        real_floor < CURVATURE_WALK_ARRIVAL_EV_FLOOR,
        "the #1189 relative floor did NOT relax below the absolute floor on the real-data regime \
         (ceiling {real_regime_ceiling}, relative floor {real_floor:.5} >= absolute \
         {CURVATURE_WALK_ARRIVAL_EV_FLOOR}): a genuine fit at the achievable ceiling would still be \
         rejected and demoted to the collapsing cascade (#1189)."
    );
    assert!(
        real_regime_ceiling >= real_floor,
        "a genuine fit AT the achievable real-data ceiling {real_regime_ceiling} is rejected by the \
         #1189 relative floor {real_floor:.5} (#1189)."
    );

    // Synthetic regime: a high ceiling must keep the strict absolute floor.
    let synthetic_ceiling = 0.95_f64;
    let synthetic_floor = arrival_floor(synthetic_ceiling);
    assert!(
        (synthetic_floor - CURVATURE_WALK_ARRIVAL_EV_FLOOR).abs() < 1e-12,
        "the #1189 relative floor wrongly relaxed the gate on the synthetic regime (ceiling \
         {synthetic_ceiling}, floor {synthetic_floor:.5} != absolute {CURVATURE_WALK_ARRIVAL_EV_FLOOR}); \
         planted-harmonic recovery must keep the strict absolute floor (#1189)."
    );

    // Pathological zero ceiling: the floor is still clamped at the data-collapse threshold.
    let pathological_floor = arrival_floor(0.0);
    assert!(
        pathological_floor >= SAE_FIT_DATA_COLLAPSE_EV_FLOOR,
        "the #1189 floor dropped below the data-collapse threshold on a pathological ceiling \
         (floor {pathological_floor:.5} < {SAE_FIT_DATA_COLLAPSE_EV_FLOOR}) (#1189)."
    );

    // Floor derived from the measured anchor must sit between the two thresholds.
    let real_anchor_floor = arrival_floor(anchor_ev);
    assert!(
        SAE_FIT_DATA_COLLAPSE_EV_FLOOR <= real_anchor_floor
            && real_anchor_floor <= CURVATURE_WALK_ARRIVAL_EV_FLOOR,
        "real-data anchor floor {real_anchor_floor:.5} fell outside [{SAE_FIT_DATA_COLLAPSE_EV_FLOOR}, \
         {CURVATURE_WALK_ARRIVAL_EV_FLOOR}] (#1189)."
    );

    // #1026: K=3 scenario comparing the per-atom-share floor with the old full-ceiling floor.
    let k3_linear_ceiling = 0.30_f64;
    let k3_curved_arrival = 0.2461_f64;
    let k3_floor = arrival_floor_k(k3_linear_ceiling, 3);
    let k3_floor_old = CURVATURE_WALK_ARRIVAL_EV_FLOOR
        .min(CURVATURE_WALK_ARRIVAL_ANCHOR_FRACTION * k3_linear_ceiling)
        .max(SAE_FIT_DATA_COLLAPSE_EV_FLOOR);
    eprintln!(
        "[#1026] K=3 ceiling={k3_linear_ceiling:.4} curved_arrival={k3_curved_arrival:.4} \
         new_floor={k3_floor:.4} old_floor={k3_floor_old:.4}"
    );
    assert!(
        k3_curved_arrival >= k3_floor,
        "[#1026] the per-atom-share floor {k3_floor:.4} still demotes a genuine curved K=3 \
         arrival at EV {k3_curved_arrival:.4} (linear ceiling {k3_linear_ceiling:.4}); the K>=2 \
         co-collapse regression is NOT fixed."
    );
    assert!(
        k3_curved_arrival < k3_floor_old,
        "[#1026] the OLD full-ceiling floor {k3_floor_old:.4} should have demoted the curved K=3 \
         arrival at EV {k3_curved_arrival:.4} — if it did not, this fixture no longer exercises \
         the co-collapse bug and the regression is vacuous."
    );

    // Shape of the floor as K grows at a fixed ceiling.
    let [f1, f2, f3, f8] =
        [1usize, 2, 3, 8].map(|atoms| arrival_floor_k(k3_linear_ceiling, atoms));
    assert!(
        f2 <= f1 + 1e-12,
        "[#1026] K=2 floor {f2:.4} should relax BELOW the K=1 base gate {f1:.4} \
         (the per-atom share must forgive one collapsed atom)."
    );
    assert!(
        f2 <= f3 && f3 <= f8,
        "[#1026] per-atom-share floor is not monotone non-decreasing across K>=2 \
         (K=2 {f2:.4}, K=3 {f3:.4}, K=8 {f8:.4})."
    );
    assert!(
        f8 <= f1 + 1e-12,
        "[#1026] the share floor exceeded the K=1 base gate at large K \
         (K=8 {f8:.4} > base {f1:.4})."
    );
    assert_eq!(SAE_FIT_DATA_COLLAPSE_COST, 1.0e12);
}
/// Reads a 2-D, little-endian `float32`, C-contiguous `.npy` file and widens it to `f64`.
///
/// This is a minimal test-fixture reader, not a general NumPy parser: it locates the
/// header via the format-version byte, checks the header text for the `<f4` descriptor,
/// rejects any header containing `True` (treated as Fortran order), parses the first
/// parenthesised tuple as the shape, and decodes `rows * cols` row-major `f32` values from
/// the payload. Bytes after the declared payload are ignored.
///
/// # Panics
/// Panics with a descriptive message if the file cannot be read, is not a `.npy` file,
/// uses an unsupported format version, has a truncated preamble/header/payload, is not
/// little-endian `float32`, is not C-contiguous, or is not 2-D.
pub(crate) fn read_npy_f32_2d(path: &std::path::Path) -> Array2<f64> {
    const MAGIC: &[u8] = b"\x93NUMPY";

    let bytes = std::fs::read(path).unwrap_or_else(|e| panic!("read {}: {e}", path.display()));
    assert!(bytes.len() > 10 && bytes.starts_with(MAGIC), "not a .npy");

    // Format 1.x stores the header length as a u16 at offset 8; 2.x and 3.x store a u32,
    // which pushes the header start to offset 12.
    let major = bytes[6];
    let (hdr_start, hdr_len) = match major {
        1 => (
            10usize,
            usize::from(u16::from_le_bytes([bytes[8], bytes[9]])),
        ),
        2 | 3 => {
            assert!(bytes.len() >= 12, "truncated .npy preamble");
            let len = u32::from_le_bytes([bytes[8], bytes[9], bytes[10], bytes[11]]);
            // Widening on the 32- and 64-bit targets this fixture reader runs on.
            (12usize, len as usize)
        }
        other => panic!("unsupported .npy format version {other}"),
    };

    // Validate the header extent before slicing so a short file reports what is wrong
    // instead of failing with a bare slice-index panic.
    let data_off = hdr_start
        .checked_add(hdr_len)
        .filter(|&end| end <= bytes.len())
        .expect("truncated .npy header");
    let header =
        std::str::from_utf8(&bytes[hdr_start..data_off]).expect(".npy header must be valid UTF-8");
    assert!(
        header.contains("'<f4'") || header.contains("\"<f4\""),
        "fixture must be little-endian float32; header: {header}"
    );
    assert!(!header.contains("True"), "fixture must be C-contiguous");

    // The shape is the first parenthesised tuple in the header dictionary.
    let open = header.find('(').expect(".npy header has no shape tuple");
    let close = header[open..]
        .find(')')
        .expect(".npy header shape tuple is unterminated")
        + open;
    let dims: Vec<usize> = header[open + 1..close]
        .split(',')
        .map(str::trim)
        .filter(|s| !s.is_empty())
        .map(|s| {
            s.parse::<usize>()
                .unwrap_or_else(|e| panic!("bad shape dimension {s:?}: {e}"))
        })
        .collect();
    assert_eq!(dims.len(), 2, "fixture must be 2-D");
    let (n, p) = (dims[0], dims[1]);

    // Checked arithmetic: a wrapped product would let a bogus shape slip past the
    // length check in release builds.
    let payload = &bytes[data_off..];
    let needed = n
        .checked_mul(p)
        .and_then(|count| count.checked_mul(4))
        .expect(".npy shape overflows usize");
    assert!(payload.len() >= needed, "truncated payload");

    // `iter_mut` walks the freshly allocated array in logical (row-major) order, matching
    // the C-contiguous payload; `zip` stops after `n * p` elements.
    let mut out = Array2::<f64>::zeros((n, p));
    for (dst, raw) in out.iter_mut().zip(payload.chunks_exact(4)) {
        *dst = f64::from(f32::from_le_bytes([raw[0], raw[1], raw[2], raw[3]]));
    }
    out
}