use super::codes::solve_row_codes;
use super::scoring::{TileScorer, top_s_online};
use super::{SparseDictConfig, fit_sparse_dictionary};
use ndarray::{Array2, ArrayView2};
/// Builds a deterministic planted-dictionary fixture.
///
/// Returns `(x, atoms)`:
/// * `atoms` is `k × p`; row `a` is eigenvector `a % p` of a fixed symmetric
///   seed matrix, so atoms repeat cyclically when `k > p`.
/// * `x` is `n × p`; row `r` mixes atom `r % k` with its cyclic successor,
///   the latter weighted by `second_share`, both scaled by
///   `0.7 + 0.01 * (r / k)`.
fn planted(k: usize, p: usize, n: usize, second_share: f32) -> (Array2<f32>, Array2<f32>) {
    use gam_linalg::faer_ndarray::FaerEigh;

    // Integer-pattern seed, symmetrised so that `eigh` applies.
    let seed =
        Array2::<f64>::from_shape_fn((p, p), |(i, j)| ((i * 7 + j * 3 + 1) % 11) as f64 - 5.0);
    let sym = &seed + &seed.t();
    let (_, basis) = sym.eigh(faer::Side::Lower).expect("orthonormal seed");

    // Each atom is one eigenvector column, narrowed to f32.
    let mut atoms = Array2::<f32>::zeros((k, p));
    for (atom, mut dst) in atoms.outer_iter_mut().enumerate() {
        let src = basis.column(atom % p);
        for (slot, &v) in dst.iter_mut().zip(src.iter()) {
            *slot = v as f32;
        }
    }

    let mut x = Array2::<f32>::zeros((n, p));
    for (row, mut out) in x.outer_iter_mut().enumerate() {
        let primary = row % k;
        let secondary = (primary + 1) % k;
        let scale = 0.7 + 0.01 * (row / k) as f32;
        let lead = atoms.row(primary);
        let tail = atoms.row(secondary);
        for ((slot, &a), &b) in out.iter_mut().zip(lead.iter()).zip(tail.iter()) {
            *slot = scale * a + second_share * scale * b;
        }
    }
    (x, atoms)
}
/// In-sample PCA explained-variance ratio.
///
/// Centres `x` by its column means, forms the (unnormalised) scatter matrix
/// in f64, and returns the sum of its `rank` largest eigenvalues divided by
/// the sum of all eigenvalues. Returns `1.0` when the total is at most
/// `1e-24`.
fn pca_ev(x: ArrayView2<'_, f32>, rank: usize) -> f64 {
    use gam_linalg::faer_ndarray::FaerEigh;

    let (n, p) = x.dim();

    // Column means, accumulated in f64.
    let mut means = vec![0.0f64; p];
    for sample in x.outer_iter() {
        for (m, &v) in means.iter_mut().zip(sample.iter()) {
            *m += v as f64;
        }
    }
    for m in means.iter_mut() {
        *m /= n as f64;
    }

    // Scatter matrix: sum over rows of the centred outer product.
    let mut cov = Array2::<f64>::zeros((p, p));
    let mut centred = vec![0.0f64; p];
    for sample in x.outer_iter() {
        for ((d, &v), &m) in centred.iter_mut().zip(sample.iter()).zip(means.iter()) {
            *d = v as f64 - m;
        }
        for (a, &da) in centred.iter().enumerate() {
            for (b, &db) in centred.iter().enumerate() {
                cov[[a, b]] += da * db;
            }
        }
    }

    let (evals, _) = cov.eigh(faer::Side::Lower).expect("pca eig");
    let total: f64 = evals.iter().sum();
    let mut descending: Vec<f64> = evals.to_vec();
    descending.sort_by(|lhs, rhs| rhs.partial_cmp(lhs).unwrap());
    let top: f64 = descending.iter().take(rank).sum();
    if total <= 1.0e-24 {
        return 1.0;
    }
    top / total
}
/// Explained variance of a fixed `decoder` on held-out rows.
///
/// Every row of `x_test` is routed with a `TileScorer::new(s, tile)`, coded
/// with `solve_row_codes`, and reconstructed as the code-weighted sum of the
/// selected decoder rows. The result is `1 - RSS / TSS`, where TSS is taken
/// about the column means of `x_test`. When TSS is at most `1e-24` the
/// function returns `1.0` if RSS is also at most `1e-24` and `0.0` otherwise.
fn held_out_ev(
    decoder: ArrayView2<'_, f32>,
    x_test: ArrayView2<'_, f32>,
    s: usize,
    tile: usize,
    code_ridge: f32,
) -> f64 {
    let (n, p) = x_test.dim();
    let scorer = TileScorer::new(s, tile);

    // Column means of the held-out block (baseline for TSS).
    let mut means = vec![0.0f64; p];
    for sample in x_test.outer_iter() {
        for (m, &v) in means.iter_mut().zip(sample.iter()) {
            *m += v as f64;
        }
    }
    for m in means.iter_mut() {
        *m /= n as f64;
    }

    let mut rss = 0.0f64;
    let mut tss = 0.0f64;
    let mut recon = vec![0.0f64; p];
    for sample in x_test.outer_iter() {
        let active = scorer.route_row(sample, decoder);
        let code = solve_row_codes(sample, decoder, &active, s, code_ridge);

        // Rebuild the row from its non-zero codes only.
        recon.fill(0.0);
        for j in 0..code.indices.len() {
            let weight = code.codes[j] as f64;
            if weight == 0.0 {
                continue;
            }
            let atom = decoder.row(code.indices[j] as usize);
            for (c, slot) in recon.iter_mut().enumerate() {
                *slot += weight * atom[c] as f64;
            }
        }

        for ((&v, &fitted), &mean) in sample.iter().zip(recon.iter()).zip(means.iter()) {
            let value = v as f64;
            let r = value - fitted;
            rss += r * r;
            let t = value - mean;
            tss += t * t;
        }
    }

    if tss <= 1.0e-24 {
        return if rss <= 1.0e-24 { 1.0 } else { 0.0 };
    }
    1.0 - rss / tss
}
/// Held-out explained variance of a rank-`rank` PCA fitted on `x_train`.
///
/// The principal axes and the centring vector come from `x_train` only.
/// Each test row is centred with the *training* means, projected onto the
/// top `min(rank, p)` eigenvectors, and reconstructed. RSS is the error of
/// that reconstruction, while TSS is measured about the *test* column means.
/// Returns `1 - RSS / TSS`, with the same degenerate-TSS convention as
/// `held_out_ev` (`1.0` if both are at most `1e-24`, else `0.0`).
fn pca_ev_held_out(x_train: ArrayView2<'_, f32>, x_test: ArrayView2<'_, f32>, rank: usize) -> f64 {
    use gam_linalg::faer_ndarray::FaerEigh;

    /// f64 means of the first `width` columns of `data`.
    fn column_means(data: ArrayView2<'_, f32>, width: usize) -> Vec<f64> {
        let rows = data.nrows();
        let mut acc = vec![0.0f64; width];
        for i in 0..rows {
            for (c, slot) in acc.iter_mut().enumerate() {
                *slot += data[[i, c]] as f64;
            }
        }
        for slot in acc.iter_mut() {
            *slot /= rows as f64;
        }
        acc
    }

    let p = x_train.ncols();
    let train_means = column_means(x_train, p);

    // Training scatter matrix about the training means.
    let mut cov = Array2::<f64>::zeros((p, p));
    let mut delta = vec![0.0f64; p];
    for i in 0..x_train.nrows() {
        for (c, slot) in delta.iter_mut().enumerate() {
            *slot = x_train[[i, c]] as f64 - train_means[c];
        }
        for (a, &da) in delta.iter().enumerate() {
            for (b, &db) in delta.iter().enumerate() {
                cov[[a, b]] += da * db;
            }
        }
    }

    let (evals, evecs) = cov.eigh(faer::Side::Lower).expect("pca eig");
    // Eigenvector columns ordered by descending eigenvalue, truncated to the rank.
    let mut order: Vec<usize> = (0..p).collect();
    order.sort_by(|&lhs, &rhs| evals[rhs].partial_cmp(&evals[lhs]).unwrap());
    order.truncate(rank.min(p));
    let keep = order;

    let test_means = column_means(x_test, p);

    let mut rss = 0.0f64;
    let mut tss = 0.0f64;
    let mut centred = vec![0.0f64; p];
    let mut recon = vec![0.0f64; p];
    for i in 0..x_test.nrows() {
        for (c, slot) in centred.iter_mut().enumerate() {
            *slot = x_test[[i, c]] as f64 - train_means[c];
        }
        recon.fill(0.0);
        for &comp in &keep {
            let axis = evecs.column(comp);
            let coord = centred
                .iter()
                .zip(axis.iter())
                .fold(0.0f64, |acc, (&d, &e)| acc + d * e);
            for (slot, &e) in recon.iter_mut().zip(axis.iter()) {
                *slot += coord * e;
            }
        }
        for c in 0..p {
            let r = centred[c] - recon[c];
            rss += r * r;
            let t = x_test[[i, c]] as f64 - test_means[c];
            tss += t * t;
        }
    }

    if tss <= 1.0e-24 {
        return if rss <= 1.0e-24 { 1.0 } else { 0.0 };
    }
    1.0 - rss / tss
}
/// With an identity decoder, the score of atom `a` is just `row[a]`, so the
/// top-3 selection must return the three planted coordinates in descending
/// order of magnitude.
#[test]
fn online_top_s_recovers_planted_largest_scores() {
    let (p, k) = (50usize, 50usize);
    let decoder =
        Array2::<f32>::from_shape_fn((k, p), |(atom, c)| if atom == c { 1.0 } else { 0.0 });

    // (coordinate, value), listed from largest to smallest value.
    let spikes: [(usize, f32); 3] = [(17, 9.0), (4, 5.0), (31, 3.0)];
    let mut row = ndarray::Array1::<f32>::zeros(p);
    for &(slot, value) in spikes.iter() {
        row[slot] = value;
    }

    let picked = top_s_online(row.view(), decoder.view(), 3, 8);
    let want_atoms = [17u32, 4u32, 31u32];
    assert_eq!(picked.len(), 3);
    for (rank, (&(atom, score), &want)) in picked.iter().zip(want_atoms.iter()).enumerate() {
        assert_eq!(
            atom, want,
            "rank {rank}: expected atom {}, got atom {atom} (score {score})",
            want
        );
    }
}
/// `TileScorer::route_row` with `s = 4` and tile width 7 over 37 atoms must
/// pick the same atoms, in the same order, as a plain scan that ranks every
/// atom by `|row · atom|` (ties broken by the lower atom index).
#[test]
fn tile_scorer_matches_untiled_brute_force() {
    let (p, k) = (5usize, 37usize);
    let decoder = Array2::<f32>::from_shape_fn((k, p), |(atom, c)| {
        (((atom * 3 + c * 5 + 1) % 7) as f32 - 3.0) / 3.0
    });
    let row = ndarray::Array1::<f32>::from_shape_fn(p, |c| (c as f32) - 2.0);

    // Reference ranking: one f32 dot product per atom, no tiling.
    let mut brute: Vec<(u32, f32)> = Vec::with_capacity(k);
    for a in 0..k {
        let score = row
            .iter()
            .zip(decoder.row(a).iter())
            .fold(0.0f32, |acc, (&r, &d)| acc + r * d);
        brute.push((a as u32, score));
    }
    brute.sort_by(|lhs, rhs| {
        rhs.1
            .abs()
            .partial_cmp(&lhs.1.abs())
            .unwrap()
            .then_with(|| lhs.0.cmp(&rhs.0))
    });

    let tiled = TileScorer::new(4, 7).route_row(row.view(), decoder.view());
    assert_eq!(tiled.len(), 4);
    for (j, (got, want)) in tiled.iter().zip(brute.iter()).take(4).enumerate() {
        assert_eq!(got.0, want.0, "tiled top-{j} disagrees with brute force");
    }
}
/// In-sample check on a small planted problem (8 atoms, 12 dims, 480 rows):
/// the reported explained variance must exceed 0.95 and must not fall below
/// the in-sample rank-`k` PCA ratio by more than `1e-6`.
#[test]
fn sparse_trainer_recovers_planted_dictionary_beats_pca_baseline() {
    let (k, p, n) = (8usize, 12usize, 480usize);
    let (x, _) = planted(k, p, n, 0.2);

    let fit = fit_sparse_dictionary(
        x.view(),
        &SparseDictConfig {
            n_atoms: k,
            active: 2,
            minibatch: 128,
            max_epochs: 40,
            score_tile: 16,
            code_ridge: 1.0e-6,
            decoder_ridge: 1.0e-6,
            tolerance: 1.0e-9,
        },
    )
    .expect("sparse dictionary fit");
    let baseline = pca_ev(x.view(), k);
    let ev = fit.explained_variance;

    assert!(ev > 0.95, "expected EV > 0.95, got {}", ev);
    assert!(
        ev + 1.0e-6 >= baseline,
        "sparse trainer EV {} must match-or-beat rank-{k} PCA baseline {}",
        ev,
        baseline
    );
}
/// Held-out comparison against PCA on a planted 2-sparse mixture.
///
/// Every fifth row is held out. The dictionary is trained on the rest, and
/// its held-out EV must exceed 0.9 and come within `1e-4` of the held-out PCA
/// baseline. Note that `k = 64` exceeds `p = 16`, and `pca_ev_held_out`
/// clamps the rank to `p`, so the baseline here uses all 16 components.
#[test]
fn sparse_trainer_beats_rank_k_pca_on_held_out_reconstruction() {
    let (k, p, n) = (64usize, 16usize, 1600usize);
    let (x, _) = planted(k, p, n, 0.35);
    let n_test = n / 5;

    // Rows with index divisible by 5 form the test block; order is preserved.
    let (test_rows, train_rows): (Vec<usize>, Vec<usize>) = (0..n).partition(|i| i % 5 == 0);
    let gather = |rows: &[usize]| -> Array2<f32> {
        let mut out = Array2::<f32>::zeros((rows.len(), p));
        for (mut dst, &src) in out.outer_iter_mut().zip(rows.iter()) {
            dst.assign(&x.row(src));
        }
        out
    };
    let x_train = gather(&train_rows[..]);
    let x_test = gather(&test_rows[..]);
    assert_eq!(x_test.nrows(), n_test);

    let (s, tile, code_ridge) = (2usize, 16usize, 1.0e-6f32);
    let fit = fit_sparse_dictionary(
        x_train.view(),
        &SparseDictConfig {
            n_atoms: k,
            active: s,
            minibatch: 256,
            max_epochs: 60,
            score_tile: tile,
            code_ridge,
            decoder_ridge: 1.0e-6,
            tolerance: 1.0e-9,
        },
    )
    .expect("held-out trainer fit");

    let sparse_out = held_out_ev(fit.decoder.view(), x_test.view(), s, tile, code_ridge);
    let pca_out = pca_ev_held_out(x_train.view(), x_test.view(), k);
    assert!(
        sparse_out > 0.9,
        "held-out sparse-dictionary EV {sparse_out} should explain the planted held-out block"
    );
    assert!(
        sparse_out + 1.0e-4 >= pca_out,
        "held-out sparse EV {sparse_out} must match-or-beat held-out rank-{k} PCA baseline {pca_out}"
    );
}
/// Fraction of decoder rows that no sample uses with a non-zero code.
///
/// An atom counts as alive when at least one `(row, slot)` entry of
/// `fit.indices` names it and the matching entry of `fit.codes` is non-zero.
fn dead_atom_fraction(fit: &super::SparseDictFit) -> f64 {
    let k = fit.decoder.nrows();
    let mut alive = vec![false; k];
    for ((i, j), &idx) in fit.indices.indexed_iter() {
        if fit.codes[[i, j]] != 0.0 {
            alive[idx as usize] = true;
        }
    }
    let dead = alive.into_iter().filter(|live| !live).count();
    dead as f64 / k as f64
}
/// Held-out EV as a function of dictionary size on a planted 2-sparse mixture.
///
/// Trains with K = 16, 64 and 256 atoms on the same 4/5 train split and
/// requires that:
/// * held-out EV does not drop by more than `5e-3` at each increase of K;
/// * the K = 256 model beats a rank-`s` held-out PCA by more than 0.05;
/// * the K = 256 model reaches a held-out EV above 0.85.
#[test]
fn dead_atom_revival_keeps_ev_monotone_in_k_and_beats_linear_subspace() {
    let (planted_k, p, n) = (64usize, 16usize, 2000usize);
    let (x, _) = planted(planted_k, p, n, 0.35);

    // Rows with index divisible by 5 are held out; order is preserved.
    let (test_rows, train_rows): (Vec<usize>, Vec<usize>) = (0..n).partition(|i| i % 5 == 0);
    let gather = |rows: &[usize]| -> Array2<f32> {
        let mut out = Array2::<f32>::zeros((rows.len(), p));
        for (mut dst, &src) in out.outer_iter_mut().zip(rows.iter()) {
            dst.assign(&x.row(src));
        }
        out
    };
    let x_train = gather(&train_rows[..]);
    let x_test = gather(&test_rows[..]);

    let (s, tile, code_ridge) = (2usize, 16usize, 1.0e-6f32);
    let fit_with = |k: usize, label: &str| {
        let config = SparseDictConfig {
            n_atoms: k,
            active: s,
            minibatch: 256,
            max_epochs: 60,
            score_tile: tile,
            code_ridge,
            decoder_ridge: 1.0e-6,
            tolerance: 1.0e-9,
        };
        fit_sparse_dictionary(x_train.view(), &config).expect(label)
    };
    let score = |decoder: ArrayView2<'_, f32>| {
        held_out_ev(decoder, x_test.view(), s, tile, code_ridge)
    };

    // All three fits run before any scoring.
    let fit_small = fit_with(16, "K=16 fit");
    let fit_mid = fit_with(64, "K=64 fit");
    let fit_large = fit_with(256, "K=256 fit");
    let ev_small = score(fit_small.decoder.view());
    let ev_mid = score(fit_mid.decoder.view());
    let ev_large = score(fit_large.decoder.view());

    assert!(
        ev_mid + 5.0e-3 >= ev_small,
        "[#1026] held-out EV must not drop from K=16 ({ev_small:.4}) to K=64 ({ev_mid:.4})"
    );
    assert!(
        ev_large + 5.0e-3 >= ev_mid,
        "[#1026] held-out EV must not drop from K=64 ({ev_mid:.4}) to K=256 ({ev_large:.4})"
    );

    let pca_rank_s = pca_ev_held_out(x_train.view(), x_test.view(), s);
    assert!(
        ev_large > pca_rank_s + 0.05,
        "[#1026] K=256 held-out EV ({ev_large:.4}) must beat fixed rank-{s} PCA \
         ({pca_rank_s:.4}) — adaptive over-complete sparse coding must dominate a \
         single s-dim linear subspace at matched active budget"
    );
    assert!(
        ev_large > 0.85,
        "[#1026] K=256 held-out EV ({ev_large:.4}) should resolve the 2-sparse \
         planted mixture (reconstruction parity at scale)"
    );
}
/// Minimal `.npy` reader for 2-D, C-order, little-endian `f32` arrays.
///
/// Returns `(n, p, data)` where `data` holds the `n * p` values in row-major
/// order.
///
/// Supports NPY format versions 1.x (u16 header length at bytes 8..10) and
/// 2.x / 3.x (u32 header length at bytes 8..12).
///
/// # Panics
///
/// Panics with a message naming `path` when the file cannot be read, is not
/// an `.npy` file, uses an unsupported format version, is not `<f4` /
/// C-order / 2-D, has an unparseable shape entry, or when the payload size
/// does not match the declared shape.
fn read_npy_f32_2d(path: &str) -> (usize, usize, Vec<f32>) {
    let bytes = std::fs::read(path).unwrap_or_else(|e| panic!("read {path}: {e}"));
    assert!(
        bytes.len() > 10 && &bytes[0..6] == b"\x93NUMPY",
        "{path}: not a .npy file"
    );

    // The width of the HEADER_LEN field depends on the major format version,
    // which also shifts where the header text (and hence the data) begins.
    let (header_off, header_len) = match bytes[6] {
        1 => (10usize, usize::from(u16::from_le_bytes([bytes[8], bytes[9]]))),
        2 | 3 => {
            assert!(bytes.len() >= 12, "{path}: truncated .npy preamble");
            let len = u32::from_le_bytes([bytes[8], bytes[9], bytes[10], bytes[11]]);
            (12usize, len as usize)
        }
        major => panic!(
            "{path}: unsupported .npy format version {major}.{}",
            bytes[7]
        ),
    };
    // Reject a header length that runs past the end of the file up front,
    // rather than failing later with an opaque slice-index panic.
    let data_off = header_off
        .checked_add(header_len)
        .filter(|&end| end <= bytes.len())
        .unwrap_or_else(|| panic!("{path}: header length {header_len} exceeds file size"));
    let header = std::str::from_utf8(&bytes[header_off..data_off]).expect("utf8 header");

    assert!(
        header.contains("'<f4'") || header.contains("\"<f4\""),
        "{path}: expected little-endian float32 (<f4); header: {header}"
    );
    assert!(
        header.contains("'fortran_order': False") || header.contains("\"fortran_order\": false"),
        "{path}: expected C-order; header: {header}"
    );

    // The shape is the parenthesised tuple following the `'shape':` key.
    let shape_start = header.find("'shape':").expect("shape key") + "'shape':".len();
    let paren_open = header[shape_start..].find('(').expect("shape (") + shape_start + 1;
    let paren_close = header[paren_open..].find(')').expect("shape )") + paren_open;
    let dims: Vec<usize> = header[paren_open..paren_close]
        .split(',')
        .map(str::trim)
        // A trailing comma (as in `(5,)`) yields an empty token; skip it.
        .filter(|tok| !tok.is_empty())
        .map(|tok| {
            tok.parse::<usize>()
                .unwrap_or_else(|e| panic!("{path}: bad shape entry {tok:?}: {e}"))
        })
        .collect();
    assert_eq!(dims.len(), 2, "{path}: expected a 2-D array, got {dims:?}");
    let (n, p) = (dims[0], dims[1]);

    let expect = n
        .checked_mul(p)
        .and_then(|count| count.checked_mul(4))
        .unwrap_or_else(|| panic!("{path}: shape ({n}, {p}) overflows usize"));
    assert_eq!(
        bytes.len() - data_off,
        expect,
        "{path}: data length mismatch (n={n}, p={p})"
    );

    let data: Vec<f32> = bytes[data_off..]
        .chunks_exact(4)
        .map(|w| f32::from_le_bytes([w[0], w[1], w[2], w[3]]))
        .collect();
    (n, p, data)
}
/// Diagnostic sweep over two `.npy` fixtures under `tests/data`.
///
/// For each file and each active budget `s`, prints the held-out rank-`s`
/// PCA EV and then, for a ladder of dictionary sizes K, the train EV,
/// held-out EV, dead-atom fraction and epoch count. A held-out EV that falls
/// more than `5e-3` below the previous K is tagged ` <-- DROP` in the output.
/// The function makes no assertions of its own; it fails only if reading a
/// fixture or fitting panics.
#[test]
fn real_olmo_sparse_dict_ev_vs_k_parity() {
    let files = [
        concat!(env!("CARGO_MANIFEST_DIR"), "/../../tests/data/olmo_l18_pca64_635.npy"),
        concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../tests/data/olmo_mixedlayer_pca64_768.npy"
        ),
    ];
    for path in files {
        let (n, p, data) = read_npy_f32_2d(path);
        let x = Array2::from_shape_vec((n, p), data).expect("shape");

        // Every fifth row is held out; order is preserved within each split.
        let (te, tr): (Vec<usize>, Vec<usize>) = (0..n).partition(|i| i % 5 == 0);
        let gather = |rows: &[usize]| -> Array2<f32> {
            let mut out = Array2::<f32>::zeros((rows.len(), p));
            for (mut dst, &src) in out.outer_iter_mut().zip(rows.iter()) {
                dst.assign(&x.row(src));
            }
            out
        };
        let x_tr = gather(&tr[..]);
        let x_te = gather(&te[..]);
        println!("\n=== {path} (N={n}, P={p}, train={}, test={}) ===", tr.len(), te.len());

        // The tile width depends only on `p`, so it is the same for every `s`.
        let tile = p.max(1);
        for s in [8usize, 32usize] {
            let pca = pca_ev_held_out(x_tr.view(), x_te.view(), s);
            println!(" active s={s} rank-{s} held-out PCA EV = {pca:.4}");

            let mut prev = f64::NEG_INFINITY;
            for k in [s, 32usize, 128, 512, 1024] {
                // A dictionary smaller than the active budget is skipped.
                if k < s {
                    continue;
                }
                let fit = fit_sparse_dictionary(
                    x_tr.view(),
                    &SparseDictConfig {
                        n_atoms: k,
                        active: s,
                        minibatch: 256,
                        max_epochs: 40,
                        score_tile: tile,
                        code_ridge: 1.0e-6,
                        decoder_ridge: 1.0e-6,
                        tolerance: 1.0e-7,
                    },
                )
                .expect("fit");
                let ev_te = held_out_ev(fit.decoder.view(), x_te.view(), s, tile, 1.0e-6);
                let dead = dead_atom_fraction(&fit);
                let mono = if ev_te + 5.0e-3 >= prev { "" } else { " <-- DROP" };
                println!(
                    " K={k:5} train_EV={:.4} test_EV={ev_te:.4} dead={dead:.3} epochs={}{mono}",
                    fit.explained_variance, fit.epochs
                );
                prev = ev_te;
            }
        }
    }
}
/// With `active = 1`, the fit must store exactly one index and one code per
/// row (`n × 1`), keep a `k × p` decoder, and `reconstruct()` must yield an
/// explained variance within `1e-4` of the value the fit reports.
#[test]
fn fixed_width_sparse_storage_never_dense_and_reconstructs() {
    let (k, p, n) = (6usize, 8usize, 240usize);
    let (x, _) = planted(k, p, n, 0.0);
    let fit = fit_sparse_dictionary(
        x.view(),
        &SparseDictConfig {
            n_atoms: k,
            active: 1,
            max_epochs: 30,
            score_tile: 4,
            ..SparseDictConfig::new(k)
        },
    )
    .expect("fit");

    // Storage is fixed-width in the active budget, not dense in K.
    assert_eq!(fit.indices.dim(), (n, 1));
    assert_eq!(fit.codes.dim(), (n, 1));
    assert_eq!(fit.decoder.dim(), (k, p));

    let recon = fit.reconstruct();

    // Column means of the input, accumulated in f64.
    let mut means = vec![0.0f64; p];
    for sample in x.outer_iter() {
        for (m, &v) in means.iter_mut().zip(sample.iter()) {
            *m += v as f64;
        }
    }
    for m in means.iter_mut() {
        *m /= n as f64;
    }

    // Residual and total sums of squares, visited in row-major order.
    let mut rss = 0.0f64;
    let mut tss = 0.0f64;
    for ((i, c), &value) in x.indexed_iter() {
        let r = value as f64 - recon[[i, c]] as f64;
        rss += r * r;
        let t = value as f64 - means[c];
        tss += t * t;
    }

    let recon_ev = 1.0 - rss / tss;
    assert!(
        (recon_ev - fit.explained_variance).abs() < 1.0e-4,
        "packed-code reconstruction EV {recon_ev} disagrees with reported {}",
        fit.explained_variance
    );
}
/// Property checks on `TileScorer::route_minibatch` with unit-norm atoms.
///
/// For every row, the shortlist must (1) have exactly `s` entries, (2) name
/// each atom at most once, (3) report scores whose magnitude matches an f64
/// recomputation within `TOL`, (4) contain only atoms whose exact magnitude
/// reaches the `s`-th largest over all atoms (up to `TOL`), and (5) be
/// ordered by non-increasing `|score|` (up to `TOL`).
#[test]
fn route_minibatch_returns_a_valid_top_s() {
    const TOL: f64 = 1.0e-5;
    let (k, p, n) = (40usize, 11usize, 137usize);
    let s = 4usize;

    let mut decoder = Array2::<f32>::from_shape_fn((k, p), |(atom, c)| {
        (((atom * 5 + c * 3 + 1) % 13) as f32 - 6.0) / 6.0
    });
    // Normalise each atom unless its norm is effectively zero.
    for mut atom_row in decoder.outer_iter_mut() {
        let nrm: f32 = atom_row.iter().map(|v| v * v).sum::<f32>().sqrt();
        if nrm > 1.0e-12 {
            atom_row.mapv_inplace(|v| v / nrm);
        }
    }
    let x = Array2::<f32>::from_shape_fn((n, p), |(row, c)| {
        (((row * 7 + c * 2 + 3) % 17) as f32 - 8.0) / 4.0
    });

    let scorer = TileScorer::new(s, 7);
    let batched = scorer.route_minibatch(x.view(), decoder.view());
    assert_eq!(batched.len(), n);

    // Reference |score| computed in f64.
    let exact_mag = |row: usize, atom: usize| -> f64 {
        (0..p)
            .fold(0.0f64, |acc, c| acc + x[[row, c]] as f64 * decoder[[atom, c]] as f64)
            .abs()
    };

    for (i, shortlist) in batched.iter().enumerate() {
        assert_eq!(shortlist.len(), s, "row {i}: shortlist must have width s");

        let mut seen = std::collections::HashSet::new();
        for &(atom, _) in shortlist {
            assert!(seen.insert(atom), "row {i}: atom {atom} selected twice");
        }

        for &(atom, score) in shortlist {
            assert!(
                (score.abs() as f64 - exact_mag(i, atom as usize)).abs() <= TOL,
                "row {i}: reported |score| {} for atom {atom} != exact {}",
                score.abs(),
                exact_mag(i, atom as usize)
            );
        }

        // The s-th largest exact magnitude is the admission threshold.
        let mut all: Vec<f64> = (0..k).map(|a| exact_mag(i, a)).collect();
        all.sort_by(|lhs, rhs| rhs.partial_cmp(lhs).unwrap());
        let cutoff = all[s - 1];
        for &(atom, _) in shortlist {
            assert!(
                exact_mag(i, atom as usize) + TOL >= cutoff,
                "row {i}: selected atom {atom} (|score| {}) is below the top-{s} cutoff {cutoff}",
                exact_mag(i, atom as usize)
            );
        }

        for pair in shortlist.windows(2) {
            let (ahead, behind) = (pair[0].1.abs(), pair[1].1.abs());
            assert!(
                ahead + (TOL as f32) >= behind,
                "row {i}: shortlist not sorted by descending |score|"
            );
        }
    }
}
/// The reported explained variance must agree to within `1e-4` whether the
/// trainer runs with minibatch size 1, 64, or the full data set.
#[test]
fn fit_is_minibatch_size_invariant() {
    let (k, p, n) = (8usize, 12usize, 480usize);
    let (x, _) = planted(k, p, n, 0.2);
    let base = SparseDictConfig {
        n_atoms: k,
        active: 2,
        minibatch: 1,
        max_epochs: 40,
        score_tile: 16,
        code_ridge: 1.0e-6,
        decoder_ridge: 1.0e-6,
        tolerance: 1.0e-9,
    };

    // Same configuration throughout; only the minibatch size varies.
    let ev_at = |minibatch: usize, label: &str| {
        let config = SparseDictConfig { minibatch, ..base };
        fit_sparse_dictionary(x.view(), &config)
            .expect(label)
            .explained_variance
    };
    let ev_one = ev_at(1, "minibatch=1 fit");
    let ev_full = ev_at(n, "minibatch=N fit");
    let ev_mid = ev_at(64, "minibatch=64 fit");

    assert!(
        (ev_one - ev_full).abs() < 1.0e-4,
        "minibatch=1 EV {} vs minibatch=N EV {} must agree",
        ev_one,
        ev_full
    );
    assert!(
        (ev_one - ev_mid).abs() < 1.0e-4,
        "minibatch=1 EV {} vs minibatch=64 EV {} must agree",
        ev_one,
        ev_mid
    );
}
/// Fits 2000 atoms to data planted from 8 and checks that the index storage
/// stays `n × 1` (one active atom per row) while the reported explained
/// variance remains above 0.9.
#[test]
fn scales_to_large_k_without_dense_n_by_k() {
    let (planted_k, p, n) = (8usize, 10usize, 240usize);
    let (x, _) = planted(planted_k, p, n, 0.1);

    let k = 2000usize;
    let fit = fit_sparse_dictionary(
        x.view(),
        &SparseDictConfig {
            n_atoms: k,
            active: 1,
            max_epochs: 6,
            score_tile: 256,
            ..SparseDictConfig::new(k)
        },
    )
    .expect("large-K fit");

    assert_eq!(fit.indices.dim(), (n, 1));
    let ev = fit.explained_variance;
    assert!(
        ev > 0.9,
        "large-K trainer should still explain the low-rank signal; got {}",
        ev
    );
}
/// Fits the same 4096-atom configuration twice and requires identical
/// decoder, indices and codes, plus a reported explained variance above 0.9.
///
/// NOTE(review): the test name and assertion message suggest this problem
/// size is meant to exercise a GPU routing path; nothing in this function
/// selects a backend explicitly — confirm against the trainer.
#[test]
fn large_k_fit_routes_on_gpu_above_breakeven_and_is_reproducible() {
    let (planted_k, p, n) = (8usize, 48usize, 1536usize);
    let (x, _) = planted(planted_k, p, n, 0.1);

    let k = 4096usize;
    let config = SparseDictConfig {
        n_atoms: k,
        active: 2,
        minibatch: 512,
        max_epochs: 4,
        score_tile: 1024,
        ..SparseDictConfig::new(k)
    };
    let first = fit_sparse_dictionary(x.view(), &config).expect("large-K fit");
    let rerun = fit_sparse_dictionary(x.view(), &config).expect("large-K fit (rerun)");

    // Exact equality, not a tolerance: the two runs must match bit for bit.
    assert_eq!(
        first.decoder, rerun.decoder,
        "[#1026] sparse-dict fit is non-deterministic across runs (GPU route must \
         be bit-reproducible)"
    );
    assert_eq!(first.indices, rerun.indices);
    assert_eq!(first.codes, rerun.codes);

    let ev = first.explained_variance;
    assert!(
        ev > 0.9,
        "[#1026] large-K fit should explain the low-rank signal; got {}",
        ev
    );
}