gam_sae/sparse_dict/mod.rs
//! Fixed-K sparse, minibatched SAE trainer (#1026, "collapsed linear lane").
//!
//! This is an **additive, standalone** path that makes very large dictionaries
//! (`K` up to tens of thousands) tractable, where the exact-REML / Arrow-Schur
//! dense joint manifold solver in [`crate::manifold`] is the wrong
//! engine: that solver carries a dense `N×K` latent state, `N×K×P` sensitivity
//! tensors, `K²N` penalty couplings, and a joint Newton over all `K` outer
//! parameters. None of that survives `K ≈ 32_000`.
//!
//! The collapsed linear lane instead trains a dictionary by alternating
//! minimisation with **no dense `N×K` object anywhere**:
//!
//! 1. **route** — for each row, score it against the whole dictionary in
//!    `K`-tiles ([`scoring`]) and keep only the top-`s` atoms online, so the
//!    `N×K` score matrix is produced one tile at a time and discarded;
//! 2. **codes** — solve the small `s×s` active-set least-squares system per row
//!    ([`codes`]), giving a fixed-width sparse code `(indices, codes)`;
//! 3. **decoder** — accumulate the sparse normal equations
//!    (method-of-optimal-directions / sparse GEMM) and refresh each atom
//!    ([`update`]);
//! 4. **project** — re-unit-norm every atom so the code scale is identified.
//!
//! All heavy state is FP32. The only dense `K`-sized objects are the decoder
//! itself (`K×P`) and the per-atom `P×P`/scalar accumulators — never `N×K`.
//!
//! The exact manifold engine is **untouched**: it remains the certification /
//! small-`K` path. This module is reached only through its own public entry
//! [`fit_sparse_dictionary`] (and the `gamfit` Python facade that wraps it).
28
29mod codes;
30mod scoring;
31mod update;
32
33#[cfg(test)]
34mod tests;
35
36pub use codes::SparseCode;
37pub use scoring::{TileScorer, top_s_online};
38
39use ndarray::{Array2, ArrayView2};
40
/// Dictionary-wide (never per-atom) hyper-parameters for the collapsed linear
/// lane.
///
/// At the `K` this trainer targets, a per-atom smoothing parameter or Newton
/// state is too much to carry, so each knob below is one scalar shared by the
/// whole dictionary.
#[derive(Debug, Clone, Copy)]
pub struct SparseDictConfig {
    /// Number of atoms `K`, i.e. the dictionary width.
    pub n_atoms: usize,
    /// Active budget `s` (`top_s`): how many atoms may fire on a single row.
    /// This is the shared routing-sparsity hyper-parameter.
    pub active: usize,
    /// Rows handled per route→code→accumulate step. The decoder is refreshed
    /// once per full epoch from the accumulated sparse normal equations, so
    /// this bounds the peak working set only, not the solution.
    pub minibatch: usize,
    /// Number of full passes over the data.
    pub max_epochs: usize,
    /// Width of the column tile used when scoring rows against the dictionary.
    /// Score tiles of shape `minibatch × tile` are formed and then discarded;
    /// the full `N×K` score matrix is never materialised.
    pub score_tile: usize,
    /// Shared ridge for the per-row active-set code solve (Tikhonov on the
    /// `s×s` Gram). It identifies the codes when active atoms are collinear.
    pub code_ridge: f32,
    /// Shared ridge for the per-atom decoder refresh
    /// (method-of-optimal-directions normal equations). It keeps a
    /// thinly-used atom well posed.
    pub decoder_ridge: f32,
    /// Training stops once the relative explained-variance improvement falls
    /// below this value.
    pub tolerance: f64,
}
72
73impl SparseDictConfig {
74 /// Construct a config for a `K`-atom dictionary, leaving every other knob at
75 /// its shared default.
76 pub fn new(n_atoms: usize) -> Self {
77 Self {
78 n_atoms,
79 ..Self::default()
80 }
81 }
82}
83
84impl Default for SparseDictConfig {
85 fn default() -> Self {
86 Self {
87 n_atoms: 1,
88 active: 1,
89 minibatch: 512,
90 max_epochs: 30,
91 score_tile: 4096,
92 code_ridge: 1.0e-6,
93 decoder_ridge: 1.0e-6,
94 tolerance: 1.0e-6,
95 }
96 }
97}
98
99/// Result of a collapsed-linear-lane fit.
100///
101/// The routing is stored fixed-width and **sparse**: `indices[N, s]` /
102/// `codes[N, s]`. There is deliberately no dense `N×K` assignment matrix —
103/// reconstructing it would defeat the purpose of the lane.
104#[derive(Clone, Debug)]
105pub struct SparseDictFit {
106 /// Decoder, `K×P`, unit-norm rows (one atom per row).
107 pub decoder: Array2<f32>,
108 /// Active atom indices per row, `N×s` (column `j` of row `i` is the `j`-th
109 /// active atom for that row). Rows with fewer than `s` live atoms pad with
110 /// repeated indices whose matching code is zero.
111 pub indices: Array2<u32>,
112 /// Sparse codes per row, `N×s`, aligned with [`Self::indices`].
113 pub codes: Array2<f32>,
114 /// Final held-in explained variance (`1 − RSS/TSS`).
115 pub explained_variance: f64,
116 /// Number of epochs actually run.
117 pub epochs: usize,
118 /// Whether the EV-improvement tolerance was reached.
119 pub converged: bool,
120 /// Active budget `s` actually used (`min(active, K)`).
121 pub active: usize,
122}
123
124impl SparseDictFit {
125 /// Dense reconstruction `N×P` of the training rows from the sparse routing.
126 ///
127 /// This *does* allocate an `N×P` array (the data size, not `N×K`); it exists
128 /// for diagnostics / EV checks, not as part of the trainer's hot loop.
129 pub fn reconstruct(&self) -> Array2<f32> {
130 let n = self.indices.nrows();
131 let p = self.decoder.ncols();
132 let mut out = Array2::<f32>::zeros((n, p));
133 for i in 0..n {
134 for j in 0..self.active {
135 let atom = self.indices[[i, j]] as usize;
136 let code = self.codes[[i, j]];
137 if code == 0.0 {
138 continue;
139 }
140 let row = self.decoder.row(atom);
141 for c in 0..p {
142 out[[i, c]] += code * row[c];
143 }
144 }
145 }
146 out
147 }
148}
149
150/// Fit a fixed-`K` sparse minibatched linear dictionary to `x` (`N×P`).
151///
152/// This is the public entry of the collapsed linear lane. It never forms a
153/// dense `N×K` object: scoring is tiled, routing is fixed-width sparse, and the
154/// decoder is refreshed from accumulated sparse normal equations.
155pub fn fit_sparse_dictionary(
156 x: ArrayView2<'_, f32>,
157 config: &SparseDictConfig,
158) -> Result<SparseDictFit, String> {
159 update::run(x, config)
160}