use ndarray::{Array1, Array2};
use std::sync::Arc;
use crate::manifold::{
AssignmentMode, BehaviorBlock, LatentManifold, PeriodicHarmonicEvaluator, SaeAssignment,
SaeAtomBasisKind, SaeBasisEvaluator, SaeManifoldAtom, SaeManifoldRho, SaeManifoldTerm,
reconstruction_explained_variance,
};
/// Logit value that `build_k1` writes into every row of its single-column
/// assignment-logit matrix.
const ON: f64 = 6.0;
/// Converts a slice of logits into a probability vector.
///
/// The largest logit is subtracted from every entry before exponentiating so
/// that `exp` never sees a large positive argument; the shifted exponentials
/// are then divided by their sum. An empty slice yields an empty vector.
fn softmax(logits: &[f64]) -> Vec<f64> {
    // Running maximum, seeded with -inf so any finite logit replaces it.
    let peak = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let mut weights: Vec<f64> = logits.iter().map(|&logit| (logit - peak).exp()).collect();
    let total: f64 = weights.iter().sum();
    // Normalize in place rather than building a second vector.
    for weight in &mut weights {
        *weight /= total;
    }
    weights
}
/// Explained variance of `fitted` against `target`, restricted to the column
/// range `[c0, c1)`.
///
/// Both matrices are sliced to the same columns, copied into owned arrays, and
/// handed to `reconstruction_explained_variance`.
///
/// # Panics
///
/// Panics when `reconstruction_explained_variance` returns `None` for the
/// selected block.
fn block_ev(target: &Array2<f64>, fitted: &Array2<f64>, c0: usize, c1: usize) -> f64 {
    // One extractor for both inputs keeps the column window in a single place.
    let columns = |matrix: &Array2<f64>| matrix.slice(ndarray::s![.., c0..c1]).to_owned();
    let target_block = columns(target);
    let fitted_block = columns(fitted);
    match reconstruction_explained_variance(target_block.view(), fitted_block.view()) {
        Some(ev) => ev,
        None => panic!("EV undefined for block [{c0},{c1})"),
    }
}
/// Builds the periodic atom named `"b0"` used by the tests in this file,
/// together with a copy of the coordinate block it was evaluated on.
///
/// The basis and its jet come from evaluating `evaluator` at `coords`. With
/// `m` the number of basis columns, the atom is constructed with an all-zero
/// `m x p_tot` matrix and an `m x m` identity, then given `evaluator` again
/// through `with_basis_second_jet`.
///
/// NOTE(review): the zero matrix is presumably the initial decoder and the
/// identity a penalty/Gram matrix -- confirm against `SaeManifoldAtom::new`.
///
/// # Panics
///
/// Panics if basis evaluation or atom construction returns an error.
fn augmented_circle_atom(
    evaluator: &Arc<PeriodicHarmonicEvaluator>,
    coords: &Array2<f64>,
    p_tot: usize,
) -> (SaeManifoldAtom, Array2<f64>) {
    let (basis, basis_jet) = evaluator.evaluate(coords.view()).unwrap();
    let basis_width = basis.ncols();

    // Name the two matrix arguments so the constructor call reads positionally.
    let zero_block = Array2::<f64>::zeros((basis_width, p_tot));
    let identity = Array2::<f64>::eye(basis_width);

    let bare_atom = SaeManifoldAtom::new(
        "b0",
        SaeAtomBasisKind::Periodic,
        1,
        basis,
        basis_jet,
        zero_block,
        identity,
    )
    .unwrap();
    let atom = bare_atom.with_basis_second_jet(evaluator.clone());

    (atom, coords.clone())
}
/// Wraps a single atom and its coordinate block into a one-atom term plus a
/// matching `SaeManifoldRho`.
///
/// The assignment uses an `n x 1` logit matrix filled with `ON` (where `n` is
/// the row count of `coord_block`), a unit-period circle manifold, and
/// `AssignmentMode::softmax(1.0)`. The rho is created from two `0.0` scalars
/// and one length-1 zero vector.
///
/// # Panics
///
/// Panics if the assignment or the term cannot be constructed.
fn build_k1(atom: SaeManifoldAtom, coord_block: Array2<f64>) -> (SaeManifoldTerm, SaeManifoldRho) {
    // Read the row count before `coord_block` is moved into the assignment.
    let sample_count = coord_block.nrows();
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::from_elem((sample_count, 1), ON),
        vec![coord_block],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::softmax(1.0),
    )
    .unwrap();

    (
        SaeManifoldTerm::new(vec![atom], assignment).unwrap(),
        SaeManifoldRho::new(0.0, 0.0, vec![Array1::<f64>::zeros(1)]),
    )
}
/// Joint fit over an activation block and a behavior block that both vary
/// around the circle.
///
/// Fixture: 60 samples at evenly spaced circle coordinates. Three of the four
/// activation columns carry `cos t`, `sin t`, `cos 2t` (the fourth stays
/// zero), and each row of the planted distribution is a softmax over logits
/// built from the same harmonics.
///
/// Checks, after the fit:
/// * explained variance above 0.9 on both the activation and behavior columns;
/// * the behavior part of the decoder has Frobenius norm above 0.1;
/// * at four probe rows, the decoded distribution is within 0.02 nats (KL) of
///   the planted one.
#[test]
fn two_block_joint_fit_reconstructs_activation_and_behavior() {
    let n_samples = 60usize;
    let act_dim = 4usize;
    let vocab_size = 4usize;
    let harmonics = Arc::new(PeriodicHarmonicEvaluator::new(5).unwrap());
    let coords =
        Array2::<f64>::from_shape_fn((n_samples, 1), |(row, _)| row as f64 / n_samples as f64);

    // Plant the activations and the per-row categorical law.
    let mut activations = Array2::<f64>::zeros((n_samples, act_dim));
    let mut planted = Array2::<f64>::zeros((n_samples, vocab_size));
    for row in 0..n_samples {
        let angle = std::f64::consts::TAU * (row as f64 / n_samples as f64);
        let cos1 = angle.cos();
        let sin1 = angle.sin();
        let cos2 = (2.0 * angle).cos();

        activations[[row, 0]] = cos1;
        activations[[row, 1]] = sin1;
        activations[[row, 2]] = cos2;

        let law = softmax(&[1.5 * cos1, 1.5 * sin1, 0.3 * cos2, 0.0]);
        for (col, &mass) in law.iter().enumerate().take(vocab_size) {
            planted[[row, col]] = mass;
        }
    }

    // Behavior block and the augmented (activation + behavior) target.
    let block = BehaviorBlock::fit(planted.view(), act_dim, 0.0).unwrap();
    let beh_dim = block.behavior_dim();
    assert_eq!(beh_dim, vocab_size - 1);
    let total_dim = act_dim + beh_dim;
    let target = block.augmented_target(activations.view()).unwrap();
    assert_eq!(target.dim(), (n_samples, total_dim));

    // Single-atom term wired to the behavior block.
    let (atom, coord_block) = augmented_circle_atom(&harmonics, &coords, total_dim);
    let (mut term, mut rho) = build_k1(atom, coord_block);
    term.set_behavior_block(block.clone()).unwrap();
    assert_eq!(term.activation_output_dim(), act_dim);
    assert_eq!(term.behavior_output_range(), Some(act_dim..total_dim));
    term.set_guards_enabled(false);
    // NOTE(review): the trailing positional arguments are opaque from this
    // file; see `run_joint_fit_arrow_schur` for their meaning.
    term.run_joint_fit_arrow_schur(target.view(), &mut rho, None, 48, 1.0, 1e-6, 1e-6)
        .expect("two-block joint fit must complete");

    // Per-block reconstruction quality.
    let fitted = term.try_fitted_for_rho(&rho).unwrap();
    let ev_act = block_ev(&target, &fitted, 0, act_dim);
    let ev_beh = block_ev(&target, &fitted, act_dim, total_dim);
    assert!(ev_act > 0.9, "activation block EV too low: {ev_act}");
    assert!(ev_beh > 0.9, "behavior block EV too low: {ev_beh}");

    // The decoder must split into activation / behavior parts of the right
    // widths, and the behavior part must not have collapsed.
    let decoder = term.atoms[0].decoder_coefficients.view();
    let (b_k, c_k) = block.split_decoder(decoder).unwrap();
    assert_eq!(b_k.dim().1, act_dim);
    assert_eq!(c_k.dim().1, beh_dim);
    let c_norm = c_k.iter().map(|v| v * v).sum::<f64>().sqrt();
    assert!(
        c_norm > 0.1,
        "behavior decoder collapsed to ~0 despite real behavior: {c_norm}"
    );

    // Decode the fitted behavior columns at four probe rows and keep the
    // largest KL divergence from the planted distribution.
    let probe_rows = [0usize, n_samples / 4, n_samples / 2, 3 * n_samples / 4];
    let worst_kl = probe_rows
        .iter()
        .map(|&row| {
            // Fitted behavior columns are divided by sqrt_lambda_y before decoding.
            let inv = 1.0 / block.sqrt_lambda_y();
            let y_hat = Array1::from_shape_fn(beh_dim, |j| fitted[[row, act_dim + j]] * inv);
            let p_hat = block.embedding.decode(y_hat.view()).unwrap();
            crate::manifold::SphereTangentEmbedding::exact_kl(planted.row(row), p_hat.view())
                .unwrap()
        })
        .fold(0.0_f64, f64::max);
    assert!(
        worst_kl < 0.02,
        "decoded behavior diverges from planted (worst KL {worst_kl} nats)"
    );
}
/// Joint fit where the activations trace a circle but the behavior
/// distribution is the same in every row.
///
/// Fixture: 48 samples; activation columns 0 and 1 carry `cos t` and `sin t`
/// (column 2 stays zero), and every row of the planted distribution equals one
/// fixed softmax vector.
///
/// Checks:
/// * the fitted behavior block's `target` is numerically zero (all entries
///   below 1e-10 in magnitude);
/// * after the fit, the activation part of the decoder has norm above 0.1;
/// * the behavior part of the decoder has norm below 1e-6.
#[test]
fn constant_behavior_yields_zero_behavior_decoder() {
    let n_samples = 48usize;
    let act_dim = 3usize;
    let vocab_size = 4usize;
    let harmonics = Arc::new(PeriodicHarmonicEvaluator::new(3).unwrap());
    let coords =
        Array2::<f64>::from_shape_fn((n_samples, 1), |(row, _)| row as f64 / n_samples as f64);

    // Activations: first harmonic in columns 0/1, zeros elsewhere.
    let activations = Array2::<f64>::from_shape_fn((n_samples, act_dim), |(row, col)| {
        let angle = std::f64::consts::TAU * (row as f64 / n_samples as f64);
        match col {
            0 => angle.cos(),
            1 => angle.sin(),
            _ => 0.0,
        }
    });

    // Behavior: one distribution repeated on every row.
    let flat = softmax(&[0.2, 0.1, -0.1, 0.0]);
    let planted = Array2::<f64>::from_shape_fn((n_samples, vocab_size), |(_, col)| flat[col]);

    let block = BehaviorBlock::fit(planted.view(), act_dim, 0.0).unwrap();
    assert!(block.target.iter().all(|v| v.abs() < 1e-10));
    let beh_dim = block.behavior_dim();
    let total_dim = act_dim + beh_dim;
    let target = block.augmented_target(activations.view()).unwrap();

    let (atom, coord_block) = augmented_circle_atom(&harmonics, &coords, total_dim);
    let (mut term, mut rho) = build_k1(atom, coord_block);
    term.set_behavior_block(block.clone()).unwrap();
    term.set_guards_enabled(false);
    // NOTE(review): the trailing positional arguments are opaque from this
    // file; see `run_joint_fit_arrow_schur` for their meaning.
    term.run_joint_fit_arrow_schur(target.view(), &mut rho, None, 48, 1.0, 1e-6, 1e-6)
        .expect("fit must complete");

    // Compare the Frobenius norms of the two decoder parts.
    let decoder = term.atoms[0].decoder_coefficients.view();
    let (b_k, c_k) = block.split_decoder(decoder).unwrap();
    let b_norm = b_k.iter().map(|v| v * v).sum::<f64>().sqrt();
    let c_norm = c_k.iter().map(|v| v * v).sum::<f64>().sqrt();
    assert!(
        b_norm > 0.1,
        "activation decoder should still fit the circle: {b_norm}"
    );
    assert!(
        c_norm < 1e-6,
        "constant behavior must earn a ~0 behavior decoder; got {c_norm}"
    );
}