// gam_sae/sparse_dict/mod.rs

//! Fixed-K sparse, minibatched SAE trainer (#1026, "collapsed linear lane").
//!
//! This is an **additive, standalone** path that makes very large dictionaries
//! (`K` up to tens of thousands) tractable, where the exact-REML / Arrow-Schur
//! dense joint manifold solver in [`crate::manifold`] is the wrong
//! engine: that solver carries a dense `N×K` latent state, `N×K×P` sensitivity
//! tensors, `K²N` penalty couplings, and a joint Newton over all `K` outer
//! parameters. None of that survives `K ≈ 32_000`.
//!
//! The collapsed linear lane instead trains a dictionary by alternating
//! minimisation with **no dense `N×K` object anywhere**:
//!
//! 1. **route** — for each row, score it against the whole dictionary in
//!    `K`-tiles ([`scoring`]) and keep only the top-`s` atoms online, so the
//!    `N×K` score matrix is produced one tile at a time and discarded;
//! 2. **codes** — solve the small `s×s` active-set least-squares system per row
//!    ([`codes`]), giving a fixed-width sparse code `(indices, codes)`;
//! 3. **decoder** — accumulate the sparse normal equations
//!    (method-of-optimal-directions / sparse GEMM) and refresh each atom
//!    ([`update`]);
//! 4. **project** — re-unit-norm every atom so the code scale is identified.
//!
//! All heavy state is FP32. The only dense `K`-sized objects are the decoder
//! itself (`K×P`) and the per-atom `P×P`/scalar accumulators — never `N×K`.
//!
//! The exact manifold engine is **untouched**: it remains the certification /
//! small-`K` path. This module is reached only through its own public entry
//! [`fit_sparse_dictionary`] (and the `gamfit` Python facade that wraps it).
28
29mod codes;
30mod scoring;
31mod update;
32
33#[cfg(test)]
34mod tests;
35
36pub use codes::SparseCode;
37pub use scoring::{TileScorer, top_s_online};
38
39use ndarray::{Array2, ArrayView2};
40
/// Shared (NOT per-atom) hyper-parameters for the collapsed linear lane.
///
/// The whole point of the sparse trainer is that `K` is too large to carry a
/// per-atom smoothing parameter / Newton state; every knob here is a single
/// scalar shared across the entire dictionary.
#[derive(Clone, Copy, Debug)]
pub struct SparseDictConfig {
    /// Dictionary width `K` (number of atoms).
    pub n_atoms: usize,
    /// Active budget `s`: how many atoms may fire per row (`top_s`). This is the
    /// shared routing-sparsity hyper-parameter.
    pub active: usize,
    /// Minibatch size (rows per route→code→accumulate step). The decoder is
    /// refreshed once per full epoch from the accumulated sparse normal
    /// equations, so this only bounds peak working set, not the solution.
    pub minibatch: usize,
    /// Number of full passes over the data.
    pub max_epochs: usize,
    /// Column tile width used when scoring rows against the dictionary. Score
    /// tiles of shape `minibatch × tile` are formed and discarded; the `N×K`
    /// score matrix is never materialised.
    pub score_tile: usize,
    /// Shared ridge on the per-row active-set code solve (Tikhonov on the
    /// `s×s` Gram). Identifies the codes when active atoms are collinear.
    pub code_ridge: f32,
    /// Shared ridge on the per-atom decoder refresh
    /// (method-of-optimal-directions normal equations). Keeps a thinly-used
    /// atom well posed.
    pub decoder_ridge: f32,
    /// Relative explained-variance improvement below which training stops.
    pub tolerance: f64,
}

impl SparseDictConfig {
    /// Construct a config for a `K`-atom dictionary, leaving every other knob at
    /// its shared default.
    ///
    /// Only `n_atoms` is taken from the argument; the remaining fields are
    /// filled from [`Default::default`] via struct-update syntax.
    pub fn new(n_atoms: usize) -> Self {
        Self {
            n_atoms,
            ..Self::default()
        }
    }
}

impl Default for SparseDictConfig {
    /// Baseline configuration: one atom, one active slot per row, 512-row
    /// minibatches, at most 30 epochs, 4096-column score tiles, and `1e-6` for
    /// both ridges and for the stopping tolerance.
    fn default() -> Self {
        Self {
            n_atoms: 1,
            active: 1,
            minibatch: 512,
            max_epochs: 30,
            score_tile: 4096,
            code_ridge: 1.0e-6,
            decoder_ridge: 1.0e-6,
            tolerance: 1.0e-6,
        }
    }
}
98
99/// Result of a collapsed-linear-lane fit.
100///
101/// The routing is stored fixed-width and **sparse**: `indices[N, s]` /
102/// `codes[N, s]`. There is deliberately no dense `N×K` assignment matrix —
103/// reconstructing it would defeat the purpose of the lane.
104#[derive(Clone, Debug)]
105pub struct SparseDictFit {
106    /// Decoder, `K×P`, unit-norm rows (one atom per row).
107    pub decoder: Array2<f32>,
108    /// Active atom indices per row, `N×s` (column `j` of row `i` is the `j`-th
109    /// active atom for that row). Rows with fewer than `s` live atoms pad with
110    /// repeated indices whose matching code is zero.
111    pub indices: Array2<u32>,
112    /// Sparse codes per row, `N×s`, aligned with [`Self::indices`].
113    pub codes: Array2<f32>,
114    /// Final held-in explained variance (`1 − RSS/TSS`).
115    pub explained_variance: f64,
116    /// Number of epochs actually run.
117    pub epochs: usize,
118    /// Whether the EV-improvement tolerance was reached.
119    pub converged: bool,
120    /// Active budget `s` actually used (`min(active, K)`).
121    pub active: usize,
122}
123
124impl SparseDictFit {
125    /// Dense reconstruction `N×P` of the training rows from the sparse routing.
126    ///
127    /// This *does* allocate an `N×P` array (the data size, not `N×K`); it exists
128    /// for diagnostics / EV checks, not as part of the trainer's hot loop.
129    pub fn reconstruct(&self) -> Array2<f32> {
130        let n = self.indices.nrows();
131        let p = self.decoder.ncols();
132        let mut out = Array2::<f32>::zeros((n, p));
133        for i in 0..n {
134            for j in 0..self.active {
135                let atom = self.indices[[i, j]] as usize;
136                let code = self.codes[[i, j]];
137                if code == 0.0 {
138                    continue;
139                }
140                let row = self.decoder.row(atom);
141                for c in 0..p {
142                    out[[i, c]] += code * row[c];
143                }
144            }
145        }
146        out
147    }
148}
149
150/// Fit a fixed-`K` sparse minibatched linear dictionary to `x` (`N×P`).
151///
152/// This is the public entry of the collapsed linear lane. It never forms a
153/// dense `N×K` object: scoring is tiled, routing is fixed-width sparse, and the
154/// decoder is refreshed from accumulated sparse normal equations.
155pub fn fit_sparse_dictionary(
156    x: ArrayView2<'_, f32>,
157    config: &SparseDictConfig,
158) -> Result<SparseDictFit, String> {
159    update::run(x, config)
160}