#[cfg(test)]
mod tests {
    use crate::manifold::tests::small_two_atom_periodic_term;

    /// Merging two compatible fixture terms must place the secondary tier after the
    /// primary for atoms, assignment logit columns and per-atom rho entries. It must
    /// preserve `n_obs`/`output_dim` and carry the global `log_lambda_sparse` over
    /// from the primary.
    #[test]
    fn merge_tiers_concatenates_atoms_assignment_and_rho() {
        let (a_term, _a_target, a_rho) = small_two_atom_periodic_term();
        let (b_term, _b_target, b_rho) = small_two_atom_periodic_term();
        let n = a_term.n_obs();
        let p = a_term.output_dim();
        let k1 = a_term.k_atoms();
        let k2 = b_term.k_atoms();
        // Snapshot the per-tier state up front: `merge_tiers` takes both terms by value.
        let a_names: Vec<String> = a_term.atoms.iter().map(|at| at.name.clone()).collect();
        let b_names: Vec<String> = b_term.atoms.iter().map(|at| at.name.clone()).collect();
        let a_logits = a_term.assignment.logits.clone();
        let b_logits = b_term.assignment.logits.clone();
        let a_smooth = a_rho.log_lambda_smooth.clone();
        let b_smooth = b_rho.log_lambda_smooth.clone();
        let a_sparse = a_rho.log_lambda_sparse;

        let (merged, merged_rho) =
            crate::manifold::SaeManifoldTerm::merge_tiers(a_term, &a_rho, b_term, &b_rho)
                .expect("merge_tiers on two compatible K=2 terms");

        assert_eq!(merged.k_atoms(), k1 + k2, "K must be K1+K2");
        assert_eq!(merged.n_obs(), n, "n_obs preserved");
        assert_eq!(merged.output_dim(), p, "output_dim preserved");

        let merged_names: Vec<String> = merged.atoms.iter().map(|at| at.name.clone()).collect();
        assert_eq!(
            merged_names,
            [a_names, b_names].concat(),
            "atom order must be primary ++ secondary"
        );

        assert_eq!(merged.assignment.logits.dim(), (n, k1 + k2), "merged logits shape");
        for j in 0..k1 {
            assert_eq!(
                merged.assignment.logits.column(j),
                a_logits.column(j),
                "primary logits column {j} preserved"
            );
        }
        for j in 0..k2 {
            assert_eq!(
                merged.assignment.logits.column(k1 + j),
                b_logits.column(j),
                "secondary logits column {j} placed at {}",
                k1 + j
            );
        }
        assert_eq!(merged.assignment.coords.len(), k1 + k2, "coords length");
        assert_eq!(merged.assignment.ungated.len(), k1 + k2, "ungated length");

        assert_eq!(
            merged_rho.log_lambda_smooth,
            [a_smooth, b_smooth].concat(),
            "log_lambda_smooth = primary ++ secondary"
        );
        assert_eq!(merged_rho.log_ard.len(), k1 + k2, "log_ard length");
        assert_eq!(
            merged_rho.log_lambda_sparse, a_sparse,
            "global log_lambda_sparse carried from primary"
        );
    }

    /// With both tiers in `threshold_gate(1.0, 0.0)` mode and identical global sparsity,
    /// the merged reconstruction must equal the sum of the two tier reconstructions
    /// to within 1e-12. Both tiers must contribute a nonzero reconstruction.
    #[test]
    fn merge_tiers_is_exactly_additive_under_jumprelu_gates() {
        let (mut a_term, _at, a_rho) = small_two_atom_periodic_term();
        let (mut b_term, _bt, b_rho) = small_two_atom_periodic_term();
        let gate = crate::assignment::AssignmentMode::threshold_gate(1.0, 0.0);
        a_term.assignment.mode = gate;
        b_term.assignment.mode = gate;
        assert_eq!(
            a_rho.log_lambda_sparse, b_rho.log_lambda_sparse,
            "fixture precondition: identical global sparsity across tiers"
        );

        let fa = a_term
            .try_fitted_for_rho(&a_rho)
            .expect("tier-1 reconstruction");
        let fb = b_term
            .try_fitted_for_rho(&b_rho)
            .expect("tier-2 reconstruction");
        let (merged, merged_rho) =
            crate::manifold::SaeManifoldTerm::merge_tiers(a_term, &a_rho, b_term, &b_rho)
                .expect("merge two JumpReLU tiers");
        let fm = merged
            .try_fitted_for_rho(&merged_rho)
            .expect("merged reconstruction");

        assert_eq!(fm.dim(), fa.dim(), "merged reconstruction shape preserved");
        // `fb` is indexed with `fm`'s indices below, so pin its shape too.
        assert_eq!(fb.dim(), fa.dim(), "tier reconstructions share a shape");

        let mut max_abs = 0.0_f64;
        for ((i, j), &v) in fm.indexed_iter() {
            let diff = (v - (fa[[i, j]] + fb[[i, j]])).abs();
            // `f64::max` silently drops NaN, which would let a NaN reconstruction pass.
            // Propagate NaN explicitly so the tolerance check below fails on it.
            if diff.is_nan() || diff > max_abs {
                max_abs = diff;
            }
        }
        assert!(
            max_abs < 1e-12,
            "JumpReLU merge must be EXACTLY additive; max |merged - (a+b)| = {max_abs}"
        );

        let fa_norm: f64 = fa.iter().map(|x| x * x).sum::<f64>().sqrt();
        let fb_norm: f64 = fb.iter().map(|x| x * x).sum::<f64>().sqrt();
        assert!(
            fa_norm > 1e-9 && fb_norm > 1e-9,
            "both tiers must contribute a nonzero reconstruction (a={fa_norm}, b={fb_norm})"
        );
    }

    /// Merging must zero the primary's gauge-deflation re-anchor count, its last-delta
    /// sign and its co-collapse reseed count, rather than carrying them into the
    /// merged term.
    #[test]
    fn merge_tiers_resets_stale_gauge_deflation_state() {
        let (mut a_term, _at, a_rho) = small_two_atom_periodic_term();
        let (b_term, _bt, b_rho) = small_two_atom_periodic_term();
        a_term.evidence_gauge_deflation_reanchors = 3;
        a_term.evidence_gauge_deflation_last_delta_sign = -1;
        a_term.dictionary_cocollapse_reseeds = 2;

        let (merged, _merged_rho) =
            crate::manifold::SaeManifoldTerm::merge_tiers(a_term, &a_rho, b_term, &b_rho)
                .expect("merge two compatible tiers");

        assert_eq!(
            merged.evidence_gauge_deflation_reanchors, 0,
            "reanchor count must reset to 0 on merge"
        );
        assert_eq!(
            merged.evidence_gauge_deflation_last_delta_sign, 0,
            "last-delta sign must reset to 0 on merge"
        );
        assert_eq!(
            merged.dictionary_cocollapse_reseeds, 0,
            "co-collapse reseed count must reset to 0 on merge"
        );
    }

    /// `reorder_atoms(order, rho)` must gather every per-atom field so that new slot
    /// `i` holds old slot `order[i]`. This covers atom name, logits column, `ungated`,
    /// `log_lambda_smooth` and `log_ard`. Applying the inverse permutation must then
    /// restore every one of those fields.
    #[test]
    fn reorder_atoms_gathers_every_per_atom_field_and_round_trips() {
        let (a_term, _at, a_rho) = small_two_atom_periodic_term();
        let (b_term, _bt, b_rho) = small_two_atom_periodic_term();
        let (mut merged, mut rho) =
            crate::manifold::SaeManifoldTerm::merge_tiers(a_term, &a_rho, b_term, &b_rho)
                .expect("merge to K=4");
        let k = merged.k_atoms();
        assert_eq!(k, 4, "fixture merge gives K=4");

        // Tag every per-atom field with a value derived from its original slot `j`.
        for j in 0..k {
            merged.atoms[j].name = format!("orig{j}");
            merged.assignment.logits.column_mut(j).fill((j as f64 + 1.0) * 10.0);
            merged.assignment.ungated[j] = j % 2 == 0;
            rho.log_lambda_smooth[j] = j as f64;
            rho.log_ard[j][0] = 100.0 + j as f64;
        }

        let order = vec![3usize, 1, 2, 0];
        merged
            .reorder_atoms(&order, &mut rho)
            .expect("valid permutation");
        for (i, &o) in order.iter().enumerate() {
            assert_eq!(merged.atoms[i].name, format!("orig{o}"), "atom name at {i}");
            for row in 0..merged.n_obs() {
                assert_eq!(
                    merged.assignment.logits[[row, i]],
                    (o as f64 + 1.0) * 10.0,
                    "logit column at {i} must be old column {o}"
                );
            }
            assert_eq!(merged.assignment.ungated[i], o % 2 == 0, "ungated at {i}");
            assert_eq!(rho.log_lambda_smooth[i], o as f64, "log_lambda_smooth at {i}");
            assert_eq!(rho.log_ard[i][0], 100.0 + o as f64, "log_ard at {i}");
        }

        // inv[order[i]] = i, so order[inv[m]] = m: gathering by `inv` undoes `order`.
        let mut inv = vec![0usize; k];
        for (i, &o) in order.iter().enumerate() {
            inv[o] = i;
        }
        merged.reorder_atoms(&inv, &mut rho).expect("inverse permutation");
        for j in 0..k {
            assert_eq!(merged.atoms[j].name, format!("orig{j}"), "round-trip name {j}");
            for row in 0..merged.n_obs() {
                assert_eq!(
                    merged.assignment.logits[[row, j]],
                    (j as f64 + 1.0) * 10.0,
                    "round-trip logit column {j}"
                );
            }
            assert_eq!(merged.assignment.ungated[j], j % 2 == 0, "round-trip ungated {j}");
            assert_eq!(rho.log_lambda_smooth[j], j as f64, "round-trip smooth {j}");
            assert_eq!(rho.log_ard[j][0], 100.0 + j as f64, "round-trip log_ard {j}");
        }
    }

    /// `reorder_atoms` must return an error for an order with a repeated index, and
    /// for an order whose length differs from K.
    #[test]
    fn reorder_atoms_rejects_non_permutation() {
        let (a_term, _at, a_rho) = small_two_atom_periodic_term();
        let (b_term, _bt, b_rho) = small_two_atom_periodic_term();
        let (mut merged, mut rho) =
            crate::manifold::SaeManifoldTerm::merge_tiers(a_term, &a_rho, b_term, &b_rho)
                .expect("merge to K=4");

        let bad = vec![0usize, 0, 1, 2];
        assert!(
            merged.reorder_atoms(&bad, &mut rho).is_err(),
            "reorder_atoms must reject a non-permutation"
        );

        let short = vec![0usize, 1, 2];
        assert!(
            merged.reorder_atoms(&short, &mut rho).is_err(),
            "reorder_atoms must reject an order of wrong length"
        );
    }

    /// `merge_tiers` must reject a primary rho whose `log_lambda_smooth` has one more
    /// entry than the primary term has atoms.
    #[test]
    fn merge_tiers_rejects_shape_mismatch() {
        let (a_term, _t, a_rho) = small_two_atom_periodic_term();
        let (b_term, _t2, b_rho) = small_two_atom_periodic_term();
        let mut bad_rho = a_rho.clone();
        bad_rho.log_lambda_smooth.push(0.0);

        let result =
            crate::manifold::SaeManifoldTerm::merge_tiers(a_term, &bad_rho, b_term, &b_rho);
        assert!(result.is_err(), "merge_tiers must reject a rho whose length != K");
    }
}