Skip to main content

gam_sae/sparse_dict/
mod.rs

//! Fixed-K sparse, minibatched SAE trainer (#1026, "collapsed linear lane").
//!
//! This is an **additive, standalone** path that makes very large dictionaries
//! (`K` up to tens of thousands) tractable where the exact-REML / Arrow-Schur
//! dense joint manifold solver in [`crate::manifold`] is the wrong engine:
//! that solver carries a dense `N×K` latent state, `N×K×P` sensitivity
//! tensors, `K²N` penalty couplings, and a joint Newton step over all `K`
//! outer parameters. None of that survives `K ≈ 32_000`.
//!
//! The collapsed linear lane instead trains a dictionary by alternating
//! minimisation with **no dense `N×K` object anywhere**:
//!
//! 1. **route** — score each row against the whole dictionary in `K`-tiles
//!    ([`scoring`]) and keep only the top-`s` atoms online, so the `N×K`
//!    score matrix is produced one tile at a time and discarded;
//! 2. **codes** — solve the small `s×s` active-set least-squares system per
//!    row ([`codes`]), giving a fixed-width sparse code `(indices, codes)`;
//! 3. **decoder** — accumulate the sparse normal equations
//!    (method-of-optimal-directions / sparse GEMM) and refresh each atom
//!    ([`update`]);
//! 4. **project** — re-normalise every atom to unit norm so the code scale is
//!    identified.
//!
//! All heavy state is FP32. The only dense `K`-sized objects are the decoder
//! itself (`K×P`) and the per-atom `P×P`/scalar accumulators — never `N×K`.
//!
//! The exact manifold engine is **untouched**: it remains the certification /
//! small-`K` path. This module is reached only through its own public entry
//! point [`fit_sparse_dictionary`] (and the `gamfit` Python facade that wraps it).
28
29mod codes;
30mod scoring;
31#[cfg(target_os = "linux")]
32mod scoring_gpu;
33mod update;
34
35#[cfg(test)]
36mod tests;
37
38pub use codes::SparseCode;
39pub use scoring::{TileScorer, top_s_online};
40#[cfg(target_os = "linux")]
41pub use scoring_gpu::{
42    DEVICE_SCORE_BLOCK_MIN_ELEMS, ScoreBlockPath, score_block_cpu, score_block_required,
43};
44
45use ndarray::{Array2, ArrayView2};
46
/// Hyper-parameters of the collapsed linear lane, shared by every atom.
///
/// With `K` in the tens of thousands there is no room for per-atom smoothing
/// parameters or per-atom Newton state, so each field below is one scalar that
/// applies to the whole dictionary (it is NOT tuned per atom).
#[derive(Clone, Copy, Debug)]
pub struct SparseDictConfig {
    /// Number of atoms `K` in the dictionary.
    pub n_atoms: usize,
    /// Routing sparsity `s` (`top_s`): the maximum number of atoms allowed to
    /// fire on a single row.
    pub active: usize,
    /// Rows processed per route → code → accumulate step. The decoder is only
    /// refreshed once per epoch from the accumulated sparse normal equations,
    /// so this caps the peak working set without changing the solution.
    pub minibatch: usize,
    /// Maximum number of complete passes over the data.
    pub max_epochs: usize,
    /// Column width of the score tiles. Each `minibatch × score_tile` tile is
    /// built, consumed, and dropped, so the full `N×K` score matrix never
    /// exists in memory.
    pub score_tile: usize,
    /// Tikhonov ridge added to the per-row `s×s` active-set Gram; keeps the
    /// code solve identified when the active atoms are collinear.
    pub code_ridge: f32,
    /// Ridge on the per-atom decoder refresh (method-of-optimal-directions
    /// normal equations); keeps rarely used atoms well posed.
    pub decoder_ridge: f32,
    /// Training stops once the relative gain in explained variance falls
    /// below this value.
    pub tolerance: f64,
}
78
79impl SparseDictConfig {
80    /// Construct a config for a `K`-atom dictionary, leaving every other knob at
81    /// its shared default.
82    pub fn new(n_atoms: usize) -> Self {
83        Self {
84            n_atoms,
85            ..Self::default()
86        }
87    }
88}
89
90impl Default for SparseDictConfig {
91    fn default() -> Self {
92        Self {
93            n_atoms: 1,
94            active: 1,
95            minibatch: 512,
96            max_epochs: 30,
97            score_tile: 4096,
98            code_ridge: 1.0e-6,
99            decoder_ridge: 1.0e-6,
100            tolerance: 1.0e-6,
101        }
102    }
103}
104
105/// Result of a collapsed-linear-lane fit.
106///
107/// The routing is stored fixed-width and **sparse**: `indices[N, s]` /
108/// `codes[N, s]`. There is deliberately no dense `N×K` assignment matrix —
109/// reconstructing it would defeat the purpose of the lane.
110#[derive(Clone, Debug)]
111pub struct SparseDictFit {
112    /// Decoder, `K×P`, unit-norm rows (one atom per row).
113    pub decoder: Array2<f32>,
114    /// Active atom indices per row, `N×s` (column `j` of row `i` is the `j`-th
115    /// active atom for that row). Rows with fewer than `s` live atoms pad with
116    /// repeated indices whose matching code is zero.
117    pub indices: Array2<u32>,
118    /// Sparse codes per row, `N×s`, aligned with [`Self::indices`].
119    pub codes: Array2<f32>,
120    /// Final held-in explained variance (`1 − RSS/TSS`).
121    pub explained_variance: f64,
122    /// Number of epochs actually run.
123    pub epochs: usize,
124    /// Whether the EV-improvement tolerance was reached.
125    pub converged: bool,
126    /// Active budget `s` actually used (`min(active, K)`).
127    pub active: usize,
128}
129
130impl SparseDictFit {
131    /// Dense reconstruction `N×P` of the training rows from the sparse routing.
132    ///
133    /// This *does* allocate an `N×P` array (the data size, not `N×K`); it exists
134    /// for diagnostics / EV checks, not as part of the trainer's hot loop.
135    pub fn reconstruct(&self) -> Array2<f32> {
136        let n = self.indices.nrows();
137        let p = self.decoder.ncols();
138        let mut out = Array2::<f32>::zeros((n, p));
139        for i in 0..n {
140            for j in 0..self.active {
141                let atom = self.indices[[i, j]] as usize;
142                let code = self.codes[[i, j]];
143                if code == 0.0 {
144                    continue;
145                }
146                let row = self.decoder.row(atom);
147                for c in 0..p {
148                    out[[i, c]] += code * row[c];
149                }
150            }
151        }
152        out
153    }
154}
155
156/// Fit a fixed-`K` sparse minibatched linear dictionary to `x` (`N×P`).
157///
158/// This is the public entry of the collapsed linear lane. It never forms a
159/// dense `N×K` object: scoring is tiled, routing is fixed-width sparse, and the
160/// decoder is refreshed from accumulated sparse normal equations.
161pub fn fit_sparse_dictionary(
162    x: ArrayView2<'_, f32>,
163    config: &SparseDictConfig,
164) -> Result<SparseDictFit, String> {
165    update::run(x, config)
166}