gam_sae/sparse_dict/codes.rs
1//! Per-row sparse codes via a small active-set least-squares solve.
2//!
3//! Given a row `x` and the `s` atoms the router selected for it, the optimal
4//! codes minimise `‖x − Σ_j c_j d_{a_j}‖² + ρ‖c‖²`. That is the tiny
5//! `s×s` normal-equation system `(Gᵃ + ρI) c = Dᵃ x` where `Gᵃ` is the Gram of
6//! the active atoms and `Dᵃ x` are their projections. `s` is the shared active
7//! budget (a handful), so this is a cheap dense solve regardless of `K`.
8
9use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
10
/// One row's fixed-width sparse code.
///
/// `indices` and `codes` are parallel vectors: slot `k` says "atom
/// `indices[k]` contributes with coefficient `codes[k]`".
#[derive(Clone, Debug, PartialEq)]
pub struct SparseCode {
    /// Active atom indices, length `s`. When the row had fewer than `s` usable
    /// atoms the tail is padded with the first active index (atom `0` if there
    /// was no active atom at all); padded entries carry a zero code.
    pub indices: Vec<u32>,
    /// Codes aligned with [`Self::indices`], length `s`.
    pub codes: Vec<f32>,
}
20
21/// Solve the active-set least-squares codes for one row.
22///
23/// `active` is the router's `(atom, score)` shortlist; only the atom indices are
24/// used (the score chose the set, the LS solve sets the magnitudes). `s` is the
25/// fixed output width: shorter shortlists are padded so every row stores exactly
26/// `s` slots.
27pub fn solve_row_codes(
28 row: ArrayView1<'_, f32>,
29 decoder: ArrayView2<'_, f32>,
30 active: &[(u32, f32)],
31 s: usize,
32 ridge: f32,
33) -> SparseCode {
34 let m = active.len();
35 if m == 0 {
36 // No live atom — emit zero code on atom 0 (padding contract).
37 return SparseCode {
38 indices: vec![0u32; s],
39 codes: vec![0.0f32; s],
40 };
41 }
42 let p = row.len();
43 // Active Gram (m×m) and rhs (m) in f64 for a well-conditioned solve.
44 let mut gram = Array2::<f64>::zeros((m, m));
45 let mut rhs = Array1::<f64>::zeros(m);
46 for i in 0..m {
47 let ai = active[i].0 as usize;
48 let di = decoder.row(ai);
49 let mut proj = 0.0f64;
50 for c in 0..p {
51 proj += di[c] as f64 * row[c] as f64;
52 }
53 rhs[i] = proj;
54 for j in i..m {
55 let aj = active[j].0 as usize;
56 let dj = decoder.row(aj);
57 let mut g = 0.0f64;
58 for c in 0..p {
59 g += di[c] as f64 * dj[c] as f64;
60 }
61 gram[[i, j]] = g;
62 gram[[j, i]] = g;
63 }
64 gram[[i, i]] += ridge as f64;
65 }
66 let solution = solve_spd(&gram, &rhs);
67
68 let mut indices = Vec::with_capacity(s);
69 let mut codes = Vec::with_capacity(s);
70 for i in 0..m.min(s) {
71 indices.push(active[i].0);
72 codes.push(solution[i] as f32);
73 }
74 // Pad to fixed width with the first active index, zero code.
75 while indices.len() < s {
76 indices.push(active[0].0);
77 codes.push(0.0f32);
78 }
79 SparseCode { indices, codes }
80}
81
82/// SPD solve via Cholesky with a Tikhonov-bumped fallback. The system is `s×s`
83/// with `s` tiny, so an in-place dense factorisation is appropriate.
84fn solve_spd(gram: &Array2<f64>, rhs: &Array1<f64>) -> Array1<f64> {
85 use gam_linalg::faer_ndarray::FaerCholesky;
86 use faer::Side;
87
88 let m = rhs.len();
89 let mut a = gram.clone();
90 let mut bump = 0.0f64;
91 for _attempt in 0..6 {
92 if let Ok(factor) = a.cholesky(Side::Lower) {
93 return factor.solvevec(rhs);
94 }
95 // Indefinite (e.g. exactly collinear atoms): bump the diagonal and retry.
96 bump = if bump == 0.0 { 1.0e-8 } else { bump * 16.0 };
97 a = gram.clone();
98 for i in 0..m {
99 a[[i, i]] += bump;
100 }
101 }
102 // Degenerate beyond recovery: fall back to the diagonal (independent atoms).
103 let mut out = Array1::<f64>::zeros(m);
104 for i in 0..m {
105 let d = gram[[i, i]].max(1.0e-12);
106 out[i] = rhs[i] / d;
107 }
108 out
109}