Skip to main content

gam_sae/sparse_dict/
codes.rs

//! Per-row sparse codes via a small active-set least-squares solve.
//!
//! Given a row `x` and the `s` atoms the router selected for it, the optimal
//! codes minimise `‖x − Σ_j c_j d_{a_j}‖² + ρ‖c‖²`, where `ρ` is the ridge
//! strength. Setting the gradient to zero gives the tiny `s×s` normal-equation
//! system `(Gᵃ + ρI) c = Dᵃ x`, where `Gᵃ` is the Gram matrix of the active
//! atoms and `Dᵃ x` is the vector of their projections onto `x`. `s` is the
//! shared active budget (a handful), so this is a cheap dense solve regardless
//! of `K`.
8
9use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
10
/// One row's fixed-width sparse code.
///
/// `indices` and `codes` always have the same length `s`, so every row stores
/// exactly `s` slots regardless of how many atoms the router shortlisted.
#[derive(Clone, Debug, PartialEq)]
pub struct SparseCode {
    /// Active atom indices, length `s`. When the router shortlisted fewer than
    /// `s` atoms, the remaining slots repeat the *first* shortlisted atom (or
    /// atom 0 when the shortlist was empty); padded slots always carry a zero
    /// code, so they contribute nothing to a reconstruction.
    pub indices: Vec<u32>,
    /// Codes aligned with [`Self::indices`], length `s`.
    pub codes: Vec<f32>,
}
20
21/// Solve the active-set least-squares codes for one row.
22///
23/// `active` is the router's `(atom, score)` shortlist; only the atom indices are
24/// used (the score chose the set, the LS solve sets the magnitudes). `s` is the
25/// fixed output width: shorter shortlists are padded so every row stores exactly
26/// `s` slots.
27pub fn solve_row_codes(
28    row: ArrayView1<'_, f32>,
29    decoder: ArrayView2<'_, f32>,
30    active: &[(u32, f32)],
31    s: usize,
32    ridge: f32,
33) -> SparseCode {
34    let m = active.len();
35    if m == 0 {
36        // No live atom — emit zero code on atom 0 (padding contract).
37        return SparseCode {
38            indices: vec![0u32; s],
39            codes: vec![0.0f32; s],
40        };
41    }
42    let p = row.len();
43    // Active Gram (m×m) and rhs (m) in f64 for a well-conditioned solve.
44    let mut gram = Array2::<f64>::zeros((m, m));
45    let mut rhs = Array1::<f64>::zeros(m);
46    for i in 0..m {
47        let ai = active[i].0 as usize;
48        let di = decoder.row(ai);
49        let mut proj = 0.0f64;
50        for c in 0..p {
51            proj += di[c] as f64 * row[c] as f64;
52        }
53        rhs[i] = proj;
54        for j in i..m {
55            let aj = active[j].0 as usize;
56            let dj = decoder.row(aj);
57            let mut g = 0.0f64;
58            for c in 0..p {
59                g += di[c] as f64 * dj[c] as f64;
60            }
61            gram[[i, j]] = g;
62            gram[[j, i]] = g;
63        }
64        gram[[i, i]] += ridge as f64;
65    }
66    let solution = solve_spd(&gram, &rhs);
67
68    let mut indices = Vec::with_capacity(s);
69    let mut codes = Vec::with_capacity(s);
70    for i in 0..m.min(s) {
71        indices.push(active[i].0);
72        codes.push(solution[i] as f32);
73    }
74    // Pad to fixed width with the first active index, zero code.
75    while indices.len() < s {
76        indices.push(active[0].0);
77        codes.push(0.0f32);
78    }
79    SparseCode { indices, codes }
80}
81
82/// SPD solve via Cholesky with a Tikhonov-bumped fallback. The system is `s×s`
83/// with `s` tiny, so an in-place dense factorisation is appropriate.
84fn solve_spd(gram: &Array2<f64>, rhs: &Array1<f64>) -> Array1<f64> {
85    use gam_linalg::faer_ndarray::FaerCholesky;
86    use faer::Side;
87
88    let m = rhs.len();
89    let mut a = gram.clone();
90    let mut bump = 0.0f64;
91    for _attempt in 0..6 {
92        if let Ok(factor) = a.cholesky(Side::Lower) {
93            return factor.solvevec(rhs);
94        }
95        // Indefinite (e.g. exactly collinear atoms): bump the diagonal and retry.
96        bump = if bump == 0.0 { 1.0e-8 } else { bump * 16.0 };
97        a = gram.clone();
98        for i in 0..m {
99            a[[i, i]] += bump;
100        }
101    }
102    // Degenerate beyond recovery: fall back to the diagonal (independent atoms).
103    let mut out = Array1::<f64>::zeros(m);
104    for i in 0..m {
105        let d = gram[[i, i]].max(1.0e-12);
106        out[i] = rhs[i] / d;
107    }
108    out
109}