use super::*;
use crate::assignment::{default_ibp_concentration_for_k_atoms, ordered_geometric_shrinkage_prior};
use crate::basis::PeriodicHarmonicEvaluator;
use gam_linalg::faer_ndarray::{fast_atb, FaerCholesky};
use gam_terms::dictionary::{fit_linear_dictionary, LinearDictionaryConfig};
use ndarray::{s, Array2, ArrayView2};
use std::sync::Arc;
/// Deterministic synthetic "activation" matrix of shape `(n, p)`.
///
/// The matrix is `scores · loadingsᵀ + noise`, where:
/// * `loadings` is `p × rank` with standard-normal entries;
/// * `scores` is `n × rank` with standard-normal entries, column `r` divided by `r + 1`,
///   which gives a decaying spectrum;
/// * `noise` is i.i.d. standard normal scaled by `0.02`.
///
/// Randomness comes from a SplitMix64 generator seeded from `seed`, fed through the
/// cosine branch of Box–Muller. Identical arguments always yield an identical matrix.
fn real_like_activations(n: usize, p: usize, rank: usize, seed: u64) -> Array2<f64> {
    // SplitMix64 increment (64-bit golden-ratio constant).
    const GAMMA: u64 = 0x9E37_79B9_7F4A_7C15;
    const NOISE_SCALE: f64 = 0.02;
    let two_pow_53 = (1u64 << 53) as f64;

    let mut counter = seed.wrapping_add(GAMMA);
    let mut splitmix = move || {
        counter = counter.wrapping_add(GAMMA);
        let mut bits = counter;
        bits = (bits ^ (bits >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        bits = (bits ^ (bits >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        bits ^ (bits >> 31)
    };

    // Box–Muller, cosine branch. `u1` lies strictly inside (0, 1), so `ln(u1)` is finite.
    // `u2` lies in [0, 1).
    let mut gaussian = move || {
        let u1 = ((splitmix() >> 11) as f64 + 1.0) / (two_pow_53 + 1.0);
        let u2 = ((splitmix() >> 11) as f64) / two_pow_53;
        let radius = (-2.0 * u1.ln()).sqrt();
        radius * (std::f64::consts::TAU * u2).cos()
    };

    // Draw order is part of the fixture's reproducibility: loadings, then scores, then
    // the noise (visited in logical row-major order).
    let loadings = Array2::from_shape_fn((p, rank), |_| gaussian());
    let scores = Array2::from_shape_fn((n, rank), |(_, col)| gaussian() / ((col + 1) as f64));
    let mut activations = scores.dot(&loadings.t());
    activations
        .iter_mut()
        .for_each(|cell| *cell += NOISE_SCALE * gaussian());
    activations
}
/// Builds a `k`-atom dictionary term in which every atom is a one-dimensional circle
/// (`LatentManifold::Circle { period: 1.0 }`) with a `PeriodicHarmonicEvaluator` basis
/// of size `num_basis`.
///
/// Each atom is set up as follows:
/// * Its latent coordinates come from `sae_pca_seed_initial_coords`. Only the first
///   coordinate column is used, because every atom is 1-D.
/// * Its decoder is the ridge least-squares fit of the full `z` on that atom's basis
///   matrix `Φ`, i.e. the solution of `(ΦᵀΦ + εI) D = Φᵀ z`.
/// * The last matrix passed to `SaeManifoldAtom::new` is the `m × m` identity
///   (presumably the basis penalty — confirm against `SaeManifoldAtom::new`).
///
/// The assignment starts all-zero (`n × k`) and uses
/// `AssignmentMode::ibp_map(1.0, alpha, false)`. `alpha` is presumably the IBP
/// concentration; confirm against `AssignmentMode`.
///
/// Panics (via `unwrap`) if any construction step fails, which is acceptable in a
/// test helper.
fn circle_dictionary_term(
    z: ArrayView2<'_, f64>,
    k: usize,
    num_basis: usize,
    alpha: f64,
) -> SaeManifoldTerm {
    // Diagonal jitter added to ΦᵀΦ. It keeps the Cholesky factorisation well-posed when
    // the seeded basis matrix is (numerically) rank deficient.
    const DECODER_RIDGE: f64 = 1.0e-8;

    let n = z.nrows();
    let evaluator = Arc::new(PeriodicHarmonicEvaluator::new(num_basis).unwrap());
    let basis_kinds = vec![SaeAtomBasisKind::Periodic; k];
    let atom_dims = vec![1usize; k];
    let seed_coords = sae_pca_seed_initial_coords(z, &basis_kinds, &atom_dims).unwrap();
    // The normal-equation right-hand side needs an owned `z`. It is identical for every
    // atom, so copy the whole matrix once here rather than once per atom.
    let z_owned = z.to_owned();

    let mut atoms = Vec::with_capacity(k);
    let mut coords_blocks = Vec::with_capacity(k);
    let mut manifolds = Vec::with_capacity(k);
    for atom_idx in 0..k {
        // Every atom is 1-D, so only the first seeded coordinate column is used.
        let coords = seed_coords.slice(s![atom_idx, .., 0..1]).to_owned();
        let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();
        let m = phi.ncols();
        // Decoder = argmin ‖Φ D − z‖² + ε‖D‖², solved via (ΦᵀΦ + εI) D = Φᵀ z.
        let mut xtx = fast_atb(&phi, &phi);
        for i in 0..m {
            xtx[[i, i]] += DECODER_RIDGE;
        }
        let xtz = fast_atb(&phi, &z_owned);
        let decoder = xtx.cholesky(Side::Lower).unwrap().solve_mat(&xtz);
        let atom = SaeManifoldAtom::new(
            "circle",
            SaeAtomBasisKind::Periodic,
            1,
            phi,
            jet,
            decoder,
            Array2::<f64>::eye(m),
        )
        .unwrap()
        .with_basis_evaluator(evaluator.clone());
        atoms.push(atom);
        coords_blocks.push(coords);
        manifolds.push(LatentManifold::Circle { period: 1.0 });
    }

    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((n, k)),
        coords_blocks,
        manifolds,
        AssignmentMode::ibp_map(1.0, alpha, false),
    )
    .unwrap();
    SaeManifoldTerm::new(atoms, assignment).unwrap()
}
/// Fits a `k`-atom circle dictionary to `z` and returns its reconstruction explained
/// variance.
///
/// The term is built by `circle_dictionary_term` with IBP parameter `alpha`. Every
/// entry of `SaeManifoldRho` starts at `ln(1) = 0`: both leading scalars and each atom's
/// single-element vector (presumably log smoothing parameters — confirm against
/// `SaeManifoldRho::new`). The fit then runs `run_joint_fit_arrow_schur` for at most
/// `max_iter` iterations.
///
/// # Errors
/// * Propagates any error from the joint fit or from `try_fitted_for_rho`.
/// * Returns an error string if `reconstruction_explained_variance` yields `None`.
fn fit_ev(
    z: ArrayView2<'_, f64>,
    k: usize,
    alpha: f64,
    num_basis: usize,
    max_iter: usize,
) -> Result<f64, String> {
    let mut term = circle_dictionary_term(z, k, num_basis, alpha);

    let log_one = 1.0_f64.ln();
    let per_atom = vec![ndarray::array![log_one]; k];
    let mut rho = SaeManifoldRho::new(log_one, log_one, per_atom);

    term.run_joint_fit_arrow_schur(z, &mut rho, None, max_iter, 1.0, 1.0e-6, 1.0e-6)?;
    let fitted = term.try_fitted_for_rho(&rho)?;

    match reconstruction_explained_variance(z, fitted.view()) {
        Some(explained) => Ok(explained),
        None => Err(
            "reconstruction_explained_variance undefined (shape mismatch or degenerate total variance)"
                .to_string(),
        ),
    }
}
/// Explained variance of the equal-capacity linear baseline.
///
/// The baseline is a `k`-atom linear dictionary fitted with `top_k = 1` (presumably one
/// active atom per sample) and `max_iter = 30`. All other settings come from
/// `LinearDictionaryConfig::default()`. Panics if the fit fails.
fn linear_ev(z: ArrayView2<'_, f64>, k: usize) -> f64 {
    let baseline = fit_linear_dictionary(
        z,
        &LinearDictionaryConfig {
            n_atoms: k,
            top_k: 1,
            max_iter: 30,
            ..LinearDictionaryConfig::default()
        },
    )
    .unwrap();
    baseline.explained_variance
}
/// Regression test for #1784, using `K = 8` atoms on a 64 × 10 synthetic problem.
///
/// It checks three things:
/// 1. The circle dictionary fitted under the `alpha = 1` IBP prior trails the equal-`K`
///    linear dictionary by more than 0.015 EV.
/// 2. The fit under the `K`-aware concentration (`default_ibp_concentration_for_k_atoms`)
///    trails the linear dictionary by at most 0.02 EV.
/// 3. The `K`-aware fit beats the `alpha = 1` fit by more than 0.02 EV.
#[test]
fn ibp_default_alpha_underfits_but_k_aware_matches_linear_1784() {
    // Explained-variance margins for the three checks below.
    const UNDERFIT_GAP: f64 = 0.015;
    const PARITY_SLACK: f64 = 0.02;
    const RECOVERY_GAIN: f64 = 0.02;

    let z = real_like_activations(64, 10, 6, 7);
    let k = 8usize;
    let (num_basis, max_iter) = (3usize, 8usize);

    let lin = linear_ev(z.view(), k);
    let ev_alpha1 = fit_ev(z.view(), k, 1.0, num_basis, max_iter).expect("alpha=1 fit runs");
    let k_aware_alpha = default_ibp_concentration_for_k_atoms(k);
    let ev_kaware =
        fit_ev(z.view(), k, k_aware_alpha, num_basis, max_iter).expect("K-aware fit runs");

    eprintln!(
        "#1784 K={k}: linear EV={lin:.4} manifold(alpha=1) EV={ev_alpha1:.4} manifold(K-aware) EV={ev_kaware:.4}"
    );

    assert!(
        ev_alpha1 + UNDERFIT_GAP < lin,
        "alpha=1 IBP prior should structurally underfit the equal-K linear dictionary \
         (manifold {ev_alpha1:.4} vs linear {lin:.4}) at K={k}"
    );
    assert!(
        ev_kaware + PARITY_SLACK >= lin,
        "K-aware IBP prior must reconstruct at least as well as the equal-K linear \
         dictionary (manifold {ev_kaware:.4} vs linear {lin:.4}) at K={k}"
    );
    assert!(
        ev_kaware > ev_alpha1 + RECOVERY_GAIN,
        "K-aware concentration must recover capacity the alpha=1 mask threw away \
         (K-aware {ev_kaware:.4} vs alpha=1 {ev_alpha1:.4})"
    );
}
/// Regression test for #1784 at `K = 128` atoms.
///
/// It checks three things:
/// 1. With `alpha = 1`, the ordered geometric shrinkage prior drives the last atom's
///    weight below `1e-30`, which effectively masks it.
/// 2. With the `K`-aware concentration, the last atom's weight stays above 0.3.
/// 3. With the `K`-aware concentration, the first-to-last weight ratio stays below 3.
#[test]
fn ibp_k_aware_prior_keeps_all_128_atoms_alive_1784() {
    let k = 128usize;
    let last = k - 1;

    let prior_alpha1 = ordered_geometric_shrinkage_prior(k, 1.0);
    let alpha1_tail = prior_alpha1[last];
    assert!(
        alpha1_tail < 1.0e-30,
        "alpha=1 must mask the last atom (pi_127={:e}) — the dead-atom rank deficiency \
         behind the K=128 throw",
        alpha1_tail
    );

    let alpha = default_ibp_concentration_for_k_atoms(k);
    let prior = ordered_geometric_shrinkage_prior(k, alpha);
    let tail = prior[last];
    assert!(
        tail > 0.3,
        "K-aware prior must keep the last of {k} atoms alive (pi_127={:.4}, alpha={alpha:.2})",
        tail
    );

    let head_to_tail = prior[0] / tail;
    assert!(
        head_to_tail < 3.0,
        "K-aware prior head/tail span must be <= ~e (got {:.3})",
        head_to_tail
    );
}