use super::codes::{SparseCode, solve_row_codes};
use super::scoring::TileScorer;
use super::{SparseDictConfig, SparseDictFit};
use ndarray::{Array2, ArrayView2, Axis};
use rayon::prelude::*;
use std::collections::HashMap;
/// Routes and sparse-codes every row of `x` against `decoder`, one minibatch at a time.
///
/// Rows are taken in consecutive blocks of `minibatch` rows; a `minibatch` of 0 is treated
/// as 1. Each block is routed once by [`route_block`]. The per-row code solves
/// (`solve_row_codes`) then run in parallel on rayon. The returned codes are in the same
/// order as the rows of `x`.
fn route_and_code_all(
    x: ArrayView2<'_, f32>,
    decoder: ArrayView2<'_, f32>,
    scorer: &TileScorer,
    s: usize,
    code_ridge: f32,
    minibatch: usize,
) -> Vec<SparseCode> {
    let rows_per_block = minibatch.max(1);
    let mut codes: Vec<SparseCode> = Vec::with_capacity(x.nrows());
    for block in x.axis_chunks_iter(Axis(0), rows_per_block) {
        let candidates = route_block(block, decoder, scorer);
        // `par_extend` on a Vec appends in iterator order, so row order is preserved.
        codes.par_extend(
            block
                .axis_iter(Axis(0))
                .into_par_iter()
                .zip(candidates.into_par_iter())
                .map(|(row, active)| solve_row_codes(row, decoder, &active, s, code_ridge)),
        );
    }
    codes
}
/// Produces the per-row candidate lists for one minibatch `block`.
///
/// On Linux, when the global GPU mode is not `Off`, the device router is tried first. Its
/// output is used only when it reports `ScoreBlockPath::Device`. Any error or non-device
/// result falls back to the CPU `scorer.route_minibatch`. The `(u32, f32)` entries are
/// what `solve_row_codes` later receives as `active` — presumably (atom index, score);
/// confirm against `TileScorer::route_minibatch`.
fn route_block(
    block: ArrayView2<'_, f32>,
    decoder: ArrayView2<'_, f32>,
    scorer: &TileScorer,
) -> Vec<Vec<(u32, f32)>> {
    #[cfg(target_os = "linux")]
    {
        use super::scoring_gpu::{ScoreBlockPath, route_minibatch_required};
        let gpu_enabled = gam_gpu::gpu_mode() != gam_gpu::GpuMode::Off;
        if gpu_enabled {
            let attempt = route_minibatch_required(
                block,
                decoder,
                scorer.active,
                scorer.tile,
                gam_gpu::GpuMode::Auto,
            );
            if let Ok((routed, ScoreBlockPath::Device)) = attempt {
                return routed;
            }
        }
    }
    scorer.route_minibatch(block, decoder)
}
/// Fits a sparse dictionary to the rows of `x` by alternating minimization.
///
/// 1. Validate inputs, seed `k = config.n_atoms` atoms from data rows, unit-normalize them.
/// 2. Route and sparse-code every row with `s = min(config.active, k)` (at least 1) atoms.
/// 3. Each epoch: refit the decoder for the fixed codes, unit-normalize, revive dead atoms
///    (renormalizing if any were revived), re-code every row, and evaluate explained
///    variance (EV).
/// 4. Stop early (`converged = true`) once an epoch after the first revives nothing and
///    changes EV by at most `config.tolerance`. Otherwise stop after `config.max_epochs`.
///
/// Codes are always recomputed after the last decoder change, so the returned EV, indices
/// and codes all belong to the returned decoder.
///
/// # Errors
/// Returns the message produced by `validate` when `x` or `config` is rejected.
pub(super) fn run(
    x: ArrayView2<'_, f32>,
    config: &SparseDictConfig,
) -> Result<SparseDictFit, String> {
    validate(x, config)?;
    let n = x.nrows();
    let p = x.ncols();
    let k = config.n_atoms;
    let s = config.active.min(k).max(1);
    let mut decoder = seed_decoder(x, k);
    unit_norm_rows(&mut decoder);
    let scorer = TileScorer::new(s, config.score_tile);
    let mut prev_ev = f64::NEG_INFINITY;
    // EV of the current (codes, decoder) pair, computed by the most recent epoch. It is
    // reused as the returned EV so that work is not repeated after the loop.
    let mut last_ev: Option<f64> = None;
    let mut converged = false;
    let mut epochs_run = 0usize;
    let mut codes = route_and_code_all(
        x,
        decoder.view(),
        &scorer,
        s,
        config.code_ridge,
        config.minibatch,
    );
    for epoch in 0..config.max_epochs {
        epochs_run = epoch + 1;
        refresh_decoder(x, &codes, &mut decoder, k, p, config);
        unit_norm_rows(&mut decoder);
        // NOTE(review): revival residuals use the re-normalized decoder with codes solved
        // for the previous decoder, so they are an approximation of the true residual.
        let revived = revive_dead_atoms(x, &codes, &mut decoder);
        if revived > 0 {
            unit_norm_rows(&mut decoder);
        }
        codes = route_and_code_all(
            x,
            decoder.view(),
            &scorer,
            s,
            config.code_ridge,
            config.minibatch,
        );
        let ev = explained_variance(x, &codes, decoder.view());
        last_ev = Some(ev);
        let improve = ev - prev_ev;
        // `epoch > 0` ensures at least one real EV comparison before declaring convergence.
        if revived == 0 && improve.abs() <= config.tolerance && epoch > 0 {
            converged = true;
            break;
        }
        prev_ev = ev;
    }
    // `validate` guarantees at least one epoch, but stay correct if that ever changes.
    let final_ev = match last_ev {
        Some(ev) => ev,
        None => explained_variance(x, &codes, decoder.view()),
    };
    let (indices, code_mat) = pack_codes(&codes, n, s);
    Ok(SparseDictFit {
        decoder,
        indices,
        codes: code_mat,
        explained_variance: final_ev,
        epochs: epochs_run,
        converged,
        active: s,
    })
}
/// Rejects inputs `run` cannot fit, returning the first failing check as an error message.
///
/// Checks, in order: `x` is non-empty and all-finite; `n_atoms`, `active` and `max_epochs`
/// are non-zero; both ridges are finite and non-negative; `tolerance` is finite.
fn validate(x: ArrayView2<'_, f32>, config: &SparseDictConfig) -> Result<(), String> {
    let ensure = |ok: bool, msg: &str| -> Result<(), String> {
        if ok {
            Ok(())
        } else {
            Err(msg.to_string())
        }
    };
    ensure(!x.is_empty(), "fit_sparse_dictionary requires a non-empty N×P matrix")?;
    ensure(
        x.iter().all(|v| v.is_finite()),
        "fit_sparse_dictionary input must be finite",
    )?;
    ensure(config.n_atoms != 0, "fit_sparse_dictionary requires K >= 1")?;
    ensure(
        config.active != 0,
        "fit_sparse_dictionary requires active (top_s) >= 1",
    )?;
    ensure(
        config.max_epochs != 0,
        "fit_sparse_dictionary requires max_epochs >= 1",
    )?;
    ensure(
        config.code_ridge.is_finite() && config.code_ridge >= 0.0,
        "fit_sparse_dictionary code_ridge must be finite and non-negative",
    )?;
    ensure(
        config.decoder_ridge.is_finite() && config.decoder_ridge >= 0.0,
        "fit_sparse_dictionary decoder_ridge must be finite and non-negative",
    )?;
    ensure(
        config.tolerance.is_finite(),
        "fit_sparse_dictionary tolerance must be finite",
    )
}
/// Picks `k` initial decoder atoms directly from rows of `x` (greedy farthest-point seeding).
///
/// - Atom 0 is the row with the largest squared norm.
/// - Atom `t`, for `1 <= t < n`, is the row whose squared Euclidean distance to its
///   nearest already-chosen atom is largest.
/// - When `k > n`, atoms `t >= n` reuse row `t % n`.
///
/// Ties go to the lowest row index. Rows are copied as-is, without normalization.
fn seed_decoder(x: ArrayView2<'_, f32>, k: usize) -> Array2<f32> {
    let (n, p) = x.dim();
    let mut decoder = Array2::<f32>::zeros((k, p));
    let mut seed = 0usize;
    let mut seed_norm2 = f32::NEG_INFINITY;
    for (i, row) in x.outer_iter().enumerate() {
        let norm2: f32 = row.iter().map(|v| v * v).sum();
        if norm2 > seed_norm2 {
            seed_norm2 = norm2;
            seed = i;
        }
    }
    decoder.row_mut(0).assign(&x.row(seed));
    // nearest[i] = squared distance from row i to its closest chosen atom so far.
    let mut nearest = vec![f32::INFINITY; n];
    for atom in 1..k {
        let latest = decoder.row(atom - 1);
        for (slot, xi) in nearest.iter_mut().zip(x.outer_iter()) {
            let d2 = xi
                .iter()
                .zip(latest.iter())
                .fold(0.0f32, |acc, (&a, &b)| {
                    let d = a - b;
                    acc + d * d
                });
            if d2 < *slot {
                *slot = d2;
            }
        }
        let chosen = if atom >= n {
            atom % n
        } else {
            let mut best_i = 0usize;
            let mut best_v = f32::NEG_INFINITY;
            for (i, &v) in nearest.iter().enumerate() {
                if v > best_v {
                    best_v = v;
                    best_i = i;
                }
            }
            best_i
        };
        decoder.row_mut(atom).assign(&x.row(chosen));
    }
    decoder
}
/// Decoder normal equations `(A + ridge·I) · D = B`, accumulated from the sparse codes.
///
/// `A` (one row/column per atom) is stored split into its diagonal and its strictly
/// off-diagonal entries.
struct DecoderNormalEq {
    /// Diagonal `A[a, a]` for every atom `a`; its length is the atom count `k`.
    diag: Vec<f64>,
    /// Off-diagonal `A[a, b]` keyed by `(min(a, b), max(a, b))`; each unordered pair once.
    off: HashMap<(u32, u32), f64>,
    /// Right-hand side, one row per atom and one column per input feature.
    b: Array2<f64>,
}
/// Accumulates the decoder normal equations from the current sparse codes.
///
/// For row `i` and each non-zero code entry `(a, ca)`:
/// - add `ca²` to `diag[a]`;
/// - add `ca · x[i, :]` to `b[a, :]`;
/// - for every later non-zero entry `(b, cb)` in the same row, add `ca · cb` to the
///   off-diagonal entry for `{a, b}`.
///
/// A repeated atom index within one row is folded into the diagonal as the cross term
/// `2 · ca · cb`. The result therefore equals `CᵀC` / `CᵀX` for the dense code matrix `C`.
/// Row `i` of `x` is read for `codes[i]`, so `x` must have at least `codes.len()` rows
/// (ndarray panics otherwise).
fn assemble_normal_eq(
    x: ArrayView2<'_, f32>,
    codes: &[SparseCode],
    k: usize,
    p: usize,
) -> DecoderNormalEq {
    let mut eq = DecoderNormalEq {
        diag: vec![0.0f64; k],
        b: Array2::<f64>::zeros((k, p)),
        off: HashMap::new(),
    };
    for (row_idx, code) in codes.iter().enumerate() {
        let xi = x.row(row_idx);
        let terms = code.indices.len();
        for a in 0..terms {
            let ca = code.codes[a] as f64;
            if ca == 0.0 {
                continue;
            }
            let ka = code.indices[a];
            let row_a = ka as usize;
            eq.diag[row_a] += ca * ca;
            // Each element is independent, so one zip works for contiguous and strided rows.
            for (acc, &xv) in eq.b.row_mut(row_a).iter_mut().zip(xi.iter()) {
                *acc += ca * xv as f64;
            }
            for other in (a + 1)..terms {
                let cb = code.codes[other] as f64;
                if cb == 0.0 {
                    continue;
                }
                let kb = code.indices[other];
                if kb == ka {
                    eq.diag[row_a] += 2.0 * ca * cb;
                } else {
                    *eq.off.entry((ka.min(kb), ka.max(kb))).or_insert(0.0) += ca * cb;
                }
            }
        }
    }
    eq
}
/// Magnitude at or below which a denominator, squared residual norm, or right-hand-side
/// norm is treated as zero.
const DEAD_DENOM: f64 = 1e-12;
/// Largest connected component, in atoms, solved with a dense Cholesky factorization.
/// Larger components use conjugate gradients.
const MAX_DIRECT_BLOCK: usize = 512;
/// Relative residual `‖r‖ / ‖b‖` at which conjugate gradients stops.
const CG_REL_TOL: f64 = 1e-10;
/// Refits all decoder rows for fixed `codes`.
///
/// It assembles the normal equations and solves them with ridge `config.decoder_ridge`.
/// Rows are written into `decoder` in place, unnormalized.
fn refresh_decoder(
    x: ArrayView2<'_, f32>,
    codes: &[SparseCode],
    decoder: &mut Array2<f32>,
    k: usize,
    p: usize,
    config: &SparseDictConfig,
) {
    let system = assemble_normal_eq(x, codes, k, p);
    solve_decoder(decoder, &system, config.decoder_ridge as f64);
}
/// Re-seeds dead atoms from the rows that are currently reconstructed worst.
///
/// An atom is dead when no row uses it with a non-zero code.
///
/// The residual for row `i` is `x[i, :] - Σ_j codes[i].codes[j] · decoder[codes[i].indices[j], :]`,
/// computed with `decoder` as passed in. Rows are ranked by squared residual norm (largest
/// first, ties by lower row index). The t-th dead atom (in index order) is overwritten with
/// the t-th ranked residual row. Revival stops early when rows run out or the next
/// residual's squared norm is `<= DEAD_DENOM`.
///
/// Revived rows are not normalized here. Returns the number of atoms revived.
fn revive_dead_atoms(
    x: ArrayView2<'_, f32>,
    codes: &[SparseCode],
    decoder: &mut Array2<f32>,
) -> usize {
    let n = x.nrows();
    let k = decoder.nrows();
    let mut in_use = vec![false; k];
    for code in codes {
        for (j, &idx) in code.indices.iter().enumerate() {
            if code.codes[j] != 0.0 {
                in_use[idx as usize] = true;
            }
        }
    }
    let dead: Vec<usize> = in_use
        .iter()
        .enumerate()
        .filter(|&(_, &used)| !used)
        .map(|(atom, _)| atom)
        .collect();
    if dead.is_empty() {
        return 0;
    }
    let mut resid = x.to_owned();
    let mut resid_norm2 = vec![0.0f64; n];
    for (i, mut ri) in resid.outer_iter_mut().enumerate() {
        let code = &codes[i];
        for (j, &idx) in code.indices.iter().enumerate() {
            let cj = code.codes[j];
            if cj == 0.0 {
                continue;
            }
            for (r, &d) in ri.iter_mut().zip(decoder.row(idx as usize).iter()) {
                *r -= cj * d;
            }
        }
        resid_norm2[i] = ri.iter().fold(0.0f64, |acc, &r| acc + r as f64 * r as f64);
    }
    let mut ranked: Vec<usize> = (0..n).collect();
    ranked.sort_by(|&a, &b| {
        resid_norm2[b]
            .partial_cmp(&resid_norm2[a])
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| a.cmp(&b))
    });
    let mut revived = 0usize;
    // `zip` stops after `n` dead atoms, since each revival consumes one distinct row.
    for (&atom, &row) in dead.iter().zip(&ranked) {
        if resid_norm2[row] <= DEAD_DENOM {
            break;
        }
        decoder.row_mut(atom).assign(&resid.row(row));
        revived += 1;
    }
    revived
}
/// Solves `(A + ridge·I) · D = B` for the decoder, one connected block at a time.
///
/// Atoms are graph nodes and every entry of `eq.off` is an edge. Each connected component
/// is an independent linear system and is handed to `solve_component`, with its atom list
/// sorted ascending.
///
/// An isolated atom reduces to the scalar update `D[a, :] = B[a, :] / (diag[a] + ridge)`.
/// When that denominator is `<= DEAD_DENOM`, the atom's existing decoder row is left
/// untouched.
fn solve_decoder(decoder: &mut Array2<f32>, eq: &DecoderNormalEq, ridge: f64) {
    let k = eq.diag.len();
    let p = eq.b.ncols();
    let mut adjacency: Vec<Vec<(u32, f64)>> = vec![Vec::new(); k];
    for (&(lo, hi), &weight) in &eq.off {
        adjacency[lo as usize].push((hi, weight));
        adjacency[hi as usize].push((lo, weight));
    }
    // Neighbors are unique per node, so sorting gives a HashMap-order-independent layout.
    for list in &mut adjacency {
        list.sort_by_key(|&(nb, _)| nb);
    }
    let mut seen = vec![false; k];
    for root in 0..k {
        if seen[root] {
            continue;
        }
        seen[root] = true;
        if adjacency[root].is_empty() {
            let denom = eq.diag[root] + ridge;
            if denom <= DEAD_DENOM {
                continue;
            }
            for c in 0..p {
                decoder[[root, c]] = (eq.b[[root, c]] / denom) as f32;
            }
            continue;
        }
        // Gather the component with an explicit stack. The sort below makes the member
        // list independent of traversal order.
        let mut members = Vec::new();
        let mut pending = vec![root];
        while let Some(node) = pending.pop() {
            members.push(node);
            for &(nb, _) in &adjacency[node] {
                let nb = nb as usize;
                if !seen[nb] {
                    seen[nb] = true;
                    pending.push(nb);
                }
            }
        }
        members.sort_unstable();
        solve_component(decoder, eq, ridge, &members, &adjacency, p);
    }
}
/// Solves one connected component of the decoder system and writes the rows into `decoder`.
///
/// The system solved is `(A + ridge·I)[comp, comp] · D[comp, :] = B[comp, :]`.
///
/// * `comp` — global atom indices of the component (the caller passes them sorted).
/// * `neigh` — global symmetric adjacency lists of `(neighbor, A[a, neighbor])`.
/// * `p` — number of decoder columns to solve.
///
/// Components of at most `MAX_DIRECT_BLOCK` atoms are assembled densely and solved by
/// `cholesky_solve_block`. Larger ones run conjugate gradients column by column: up to
/// `20·m + 100` iterations, stopping when `‖r‖/‖b‖ <= CG_REL_TOL` or when the curvature
/// `pᵀAp` is not positive. A column whose right-hand side has norm `<= DEAD_DENOM` is set
/// to zero.
fn solve_component(
    decoder: &mut Array2<f32>,
    eq: &DecoderNormalEq,
    ridge: f64,
    comp: &[usize],
    neigh: &[Vec<(u32, f64)>],
    p: usize,
) {
    let m = comp.len();
    let mut local: HashMap<usize, usize> = HashMap::with_capacity(m);
    for (i, &a) in comp.iter().enumerate() {
        local.insert(a, i);
    }
    // Translate global adjacency into component-local indices once. Previously every CG
    // matvec, on every iteration and every column, did a HashMap lookup per edge. Neighbor
    // order is preserved, so the accumulation order (and thus the result) is unchanged.
    let local_neigh: Vec<Vec<(usize, f64)>> = comp
        .iter()
        .map(|&a| {
            neigh[a]
                .iter()
                .filter_map(|&(nb, val)| local.get(&(nb as usize)).map(|&j| (j, val)))
                .collect()
        })
        .collect();
    let shifted_diag: Vec<f64> = comp.iter().map(|&a| eq.diag[a] + ridge).collect();
    if m <= MAX_DIRECT_BLOCK {
        let mut mat = Array2::<f64>::zeros((m, m));
        let mut rhs = Array2::<f64>::zeros((m, p));
        for (i, &a) in comp.iter().enumerate() {
            mat[[i, i]] = shifted_diag[i];
            for &(j, val) in &local_neigh[i] {
                mat[[i, j]] = val;
            }
            for c in 0..p {
                rhs[[i, c]] = eq.b[[a, c]];
            }
        }
        let sol = cholesky_solve_block(&mat, &rhs);
        for (i, &a) in comp.iter().enumerate() {
            for c in 0..p {
                decoder[[a, c]] = sol[[i, c]] as f32;
            }
        }
        return;
    }
    // y = (A + ridge·I)[comp, comp] · xloc, written into a caller-provided buffer.
    let matvec = |xloc: &[f64], out: &mut [f64]| {
        for (i, slot) in out.iter_mut().enumerate() {
            let mut acc = shifted_diag[i] * xloc[i];
            for &(j, val) in &local_neigh[i] {
                acc += val * xloc[j];
            }
            *slot = acc;
        }
    };
    let cap = m.saturating_mul(20).saturating_add(100);
    // Reused across all CG iterations and columns instead of allocating per matvec.
    let mut ap = vec![0.0f64; m];
    for c in 0..p {
        let mut bvec = vec![0.0f64; m];
        let mut bnorm2 = 0.0f64;
        for (i, &a) in comp.iter().enumerate() {
            bvec[i] = eq.b[[a, c]];
            bnorm2 += bvec[i] * bvec[i];
        }
        let bnorm = bnorm2.sqrt();
        if bnorm <= DEAD_DENOM {
            for &a in comp {
                decoder[[a, c]] = 0.0;
            }
            continue;
        }
        let mut xvec = vec![0.0f64; m];
        let mut r = bvec;
        let mut pdir = r.clone();
        let mut rs_old: f64 = r.iter().map(|v| v * v).sum();
        for _ in 0..cap {
            matvec(&pdir, &mut ap);
            let pap: f64 = pdir
                .iter()
                .zip(&ap)
                .fold(0.0, |acc, (pi, api)| acc + pi * api);
            // Non-positive curvature: the system is not numerically SPD along pdir.
            if pap <= 0.0 {
                break;
            }
            let alpha = rs_old / pap;
            for i in 0..m {
                xvec[i] += alpha * pdir[i];
                r[i] -= alpha * ap[i];
            }
            // One pass gives both ‖r‖² (for beta) and ‖r‖ (for the stopping test).
            let rs_new: f64 = r.iter().map(|v| v * v).sum();
            if rs_new.sqrt() / bnorm <= CG_REL_TOL {
                break;
            }
            let beta = rs_new / rs_old;
            for i in 0..m {
                pdir[i] = r[i] + beta * pdir[i];
            }
            rs_old = rs_new;
        }
        for (i, &a) in comp.iter().enumerate() {
            decoder[[a, c]] = xvec[i] as f32;
        }
    }
}
/// Solves `mat · X = rhs` by Cholesky factorization, adding diagonal jitter when it fails.
///
/// Up to six factorizations are attempted, with diagonal shifts
/// `0, 1e-8, 1.6e-7, 2.56e-6, 4.096e-5, 6.5536e-4` (each 16× the previous). If all of
/// them fail, it falls back to the diagonal-only solve `rhs[i, :] / max(mat[i, i], DEAD_DENOM)`.
fn cholesky_solve_block(mat: &Array2<f64>, rhs: &Array2<f64>) -> Array2<f64> {
    use faer::Side;
    use gam_linalg::faer_ndarray::FaerCholesky;
    let m = mat.nrows();
    let mut shift = 0.0f64;
    for _ in 0..6 {
        let mut trial = mat.clone();
        if shift != 0.0 {
            for i in 0..m {
                trial[[i, i]] += shift;
            }
        }
        if let Ok(factor) = trial.cholesky(Side::Lower) {
            return factor.solve_mat(rhs);
        }
        shift = if shift == 0.0 { 1.0e-8 } else { shift * 16.0 };
    }
    Array2::from_shape_fn((m, rhs.ncols()), |(i, c)| {
        rhs[[i, c]] / mat[[i, i]].max(DEAD_DENOM)
    })
}
fn unit_norm_rows(decoder: &mut Array2<f32>) {
for mut row in decoder.outer_iter_mut() {
let nrm: f32 = row.iter().map(|v| v * v).sum::<f32>().sqrt();
if nrm > 1.0e-12 {
row.mapv_inplace(|v| v / nrm);
let mut sign = 1.0f32;
for &v in row.iter() {
if v.abs() > 1.0e-9 {
sign = v.signum();
break;
}
}
if sign < 0.0 {
row.mapv_inplace(|v| -v);
}
}
}
}
/// Returns the explained variance `1 - RSS/TSS` of reconstructing `x` as `codes · decoder`.
///
/// RSS is the summed squared reconstruction error. TSS is the summed squared deviation of
/// `x` from its column means. When TSS is at most `1e-24` (constant data), the result is
/// `1.0` for a perfect reconstruction (RSS `<= 1e-24`) and `0.0` otherwise.
fn explained_variance(
    x: ArrayView2<'_, f32>,
    codes: &[SparseCode],
    decoder: ArrayView2<'_, f32>,
) -> f64 {
    let (n, p) = x.dim();
    let mut means = vec![0.0f64; p];
    for xi in x.outer_iter() {
        for (mu, &v) in means.iter_mut().zip(xi.iter()) {
            *mu += v as f64;
        }
    }
    means.iter_mut().for_each(|mu| *mu /= n as f64);
    let (mut rss, mut tss) = (0.0f64, 0.0f64);
    let mut recon = vec![0.0f64; p];
    for i in 0..n {
        recon.fill(0.0);
        let code = &codes[i];
        for (j, &idx) in code.indices.iter().enumerate() {
            let cj = code.codes[j] as f64;
            if cj == 0.0 {
                continue;
            }
            for (acc, &d) in recon.iter_mut().zip(decoder.row(idx as usize).iter()) {
                *acc += cj * d as f64;
            }
        }
        for ((&xv, &rv), &mu) in x.row(i).iter().zip(&recon).zip(&means) {
            let xv = xv as f64;
            let resid = xv - rv;
            rss += resid * resid;
            let centered = xv - mu;
            tss += centered * centered;
        }
    }
    match tss <= 1.0e-24 {
        true if rss <= 1.0e-24 => 1.0,
        true => 0.0,
        false => 1.0 - rss / tss,
    }
}
/// Packs per-row sparse codes into dense `(n, s)` atom-index and coefficient matrices.
///
/// Row `i` receives the first `s` entries of `codes[i]`; rows with no code stay zero.
/// Panics if a code has fewer than `s` entries, or (for `s > 0`) if `codes.len() > n`.
fn pack_codes(codes: &[SparseCode], n: usize, s: usize) -> (Array2<u32>, Array2<f32>) {
    let mut indices = Array2::<u32>::zeros((n, s));
    let mut values = Array2::<f32>::zeros((n, s));
    for (i, code) in codes.iter().enumerate() {
        let picked = code.indices[..s].iter().zip(&code.codes[..s]);
        for (j, (&atom, &coef)) in picked.enumerate() {
            indices[[i, j]] = atom;
            values[[i, j]] = coef;
        }
    }
    (indices, values)
}
// Unit tests for the coupled decoder solve (`assemble_normal_eq` + `solve_decoder`) and for
// the explained variance reported by `fit_sparse_dictionary`.
#[cfg(test)]
mod exact_solve_tests {
    use super::{
        DecoderNormalEq, assemble_normal_eq, explained_variance, route_and_code_all, solve_decoder,
    };
    use crate::sparse_dict::codes::SparseCode;
    use crate::sparse_dict::scoring::TileScorer;
    use crate::sparse_dict::{SparseDictConfig, fit_sparse_dictionary};
    use ndarray::Array2;
    impl DecoderNormalEq {
        /// Test-only reference matvec: returns `(A + ridge·I) · x` for one decoder column.
        /// `A` is rebuilt from `diag`, with each `off` entry applied symmetrically.
        fn matvec_col(&self, ridge: f64, x: &[f64]) -> Vec<f64> {
            let k = self.diag.len();
            let mut y = vec![0.0f64; k];
            for i in 0..k {
                y[i] = (self.diag[i] + ridge) * x[i];
            }
            for (&(a, b), &val) in self.off.iter() {
                y[a as usize] += val * x[b as usize];
                y[b as usize] += val * x[a as usize];
            }
            y
        }
    }
    /// Five rows with cyclic, overlapping 3-atom supports over `k = 5` atoms and `p = 4`
    /// columns, so the normal equations have off-diagonal coupling.
    /// `x` is a deterministic pattern with values in [-1.5, 1.5].
    /// Returns `(x, codes, k, p)`.
    fn overlapping_problem() -> (Array2<f32>, Vec<SparseCode>, usize, usize) {
        let k = 5usize;
        let p = 4usize;
        let supports: [[u32; 3]; 5] = [
            [0, 1, 2],
            [1, 2, 3],
            [2, 3, 4],
            [3, 4, 0],
            [4, 0, 1],
        ];
        let codevals: [[f32; 3]; 5] = [
            [1.0, 0.5, -0.3],
            [0.7, -0.2, 0.4],
            [-0.6, 0.9, 0.1],
            [0.3, -0.5, 0.8],
            [0.2, 0.6, -0.4],
        ];
        let codes: Vec<SparseCode> = supports
            .iter()
            .zip(codevals.iter())
            .map(|(idx, cv)| SparseCode {
                indices: idx.to_vec(),
                codes: cv.to_vec(),
            })
            .collect();
        let n = codes.len();
        let mut x = Array2::<f32>::zeros((n, p));
        for i in 0..n {
            for c in 0..p {
                // ((7i + 3c + 1) mod 13 - 6) / 4 lies in [-1.5, 1.5].
                x[[i, c]] = (((i * 7 + c * 3 + 1) % 13) as f32 - 6.0) / 4.0;
            }
        }
        (x, codes, k, p)
    }
    /// Relative Frobenius residual `‖(A + ridge·I)·D − B‖ / ‖B‖` of a decoder solution.
    /// Returns 0 when `B` is identically zero.
    fn normal_eq_residual(eq: &DecoderNormalEq, decoder: &Array2<f32>, ridge: f64) -> f64 {
        let k = eq.diag.len();
        let p = eq.b.ncols();
        let mut rss = 0.0f64;
        let mut bss = 0.0f64;
        for c in 0..p {
            let dcol: Vec<f64> = (0..k).map(|i| decoder[[i, c]] as f64).collect();
            let y = eq.matvec_col(ridge, &dcol);
            for i in 0..k {
                let r = y[i] - eq.b[[i, c]];
                rss += r * r;
                bss += eq.b[[i, c]] * eq.b[[i, c]];
            }
        }
        if bss <= 0.0 { 0.0 } else { (rss / bss).sqrt() }
    }
    /// The coupled solve must satisfy the normal equations to roughly f32 precision.
    #[test]
    fn exact_solver_drives_normal_eq_residual_below_tolerance() {
        let (x, codes, k, p) = overlapping_problem();
        let ridge = 1.0e-6f64;
        let eq = assemble_normal_eq(x.view(), &codes, k, p);
        assert!(
            !eq.off.is_empty(),
            "test problem must have off-diagonal coupling (overlapping supports)"
        );
        let mut decoder = Array2::<f32>::zeros((k, p));
        solve_decoder(&mut decoder, &eq, ridge);
        let rel = normal_eq_residual(&eq, &decoder, ridge);
        assert!(
            rel < 1.0e-6,
            "coupled decoder solve must drive ‖(A+ρI)D−B‖/‖B‖ to the f32 floor \
            (< 1e-6), got {rel}"
        );
    }
    /// The component-wise solve must agree with one dense Cholesky solve of the full
    /// `k × k` system.
    #[test]
    fn block_solve_matches_independent_dense_solve() {
        use faer::Side;
        use gam_linalg::faer_ndarray::FaerCholesky;
        let (x, codes, k, p) = overlapping_problem();
        let ridge = 1.0e-6f64;
        let eq = assemble_normal_eq(x.view(), &codes, k, p);
        let mut decoder = Array2::<f32>::zeros((k, p));
        solve_decoder(&mut decoder, &eq, ridge);
        // Rebuild the full symmetric matrix independently of the production solver.
        let mut mat = Array2::<f64>::zeros((k, k));
        for i in 0..k {
            mat[[i, i]] = eq.diag[i] + ridge;
        }
        for (&(a, b), &val) in eq.off.iter() {
            mat[[a as usize, b as usize]] = val;
            mat[[b as usize, a as usize]] = val;
        }
        let factor = mat.cholesky(Side::Lower).expect("dense SPD system");
        let dense = factor.solve_mat(&eq.b);
        for i in 0..k {
            for c in 0..p {
                let got = decoder[[i, c]] as f64;
                let want = dense[[i, c]];
                assert!(
                    (got - want).abs() <= 1.0e-5 + 1.0e-5 * want.abs(),
                    "block solve [{i},{c}] = {got} disagrees with dense solve {want}"
                );
            }
        }
    }
    /// The EV returned by `fit_sparse_dictionary` must equal the EV obtained by re-coding
    /// the data against the returned decoder, i.e. no stale codes are reported.
    #[test]
    fn returned_ev_is_fresh_code_ev_no_stale_gap() {
        let (n, p, k) = (60usize, 6usize, 8usize);
        let mut x = Array2::<f32>::zeros((n, p));
        for i in 0..n {
            for c in 0..p {
                x[[i, c]] = (((i * 3 + c * 7 + 1) % 11) as f32 - 5.0) / 5.0;
            }
        }
        let config = SparseDictConfig {
            n_atoms: k,
            active: 2, minibatch: 16,
            max_epochs: 25,
            score_tile: 8,
            code_ridge: 1.0e-6,
            decoder_ridge: 1.0e-6,
            tolerance: 1.0e-9,
        };
        let fit = fit_sparse_dictionary(x.view(), &config).expect("fit");
        let s = fit.active;
        assert!(s > 1, "test must run the coupled s>1 lane");
        // Re-route and re-code from scratch with the returned decoder.
        let scorer = TileScorer::new(s, config.score_tile);
        let codes = route_and_code_all(
            x.view(),
            fit.decoder.view(),
            &scorer,
            s,
            config.code_ridge,
            config.minibatch,
        );
        let fresh_ev = explained_variance(x.view(), &codes, fit.decoder.view());
        assert!(
            (fresh_ev - fit.explained_variance).abs() < 1.0e-6,
            "returned EV {} must equal fresh-code EV {fresh_ev} (no stale-code gap)",
            fit.explained_variance
        );
    }
}