gam_sae/manifold/row_layout.rs
1use super::*;
2
3// The JumpReLU optimization-inclusion band is the single canonical predicate
4// `crate::assignment::jumprelu_in_optimization_band`, whose support
5// is the machine-precision cutoff `(logit − threshold)/τ > −36` (`σ(−36) ≈
6// 2e-16`). The compact Newton active set below MUST use exactly that band so
7// every coordinate carrying nonzero sparsity-prior value/gradient/Hessian
8// (assembled over the same −36 support in `assignment.rs`, and in the logdet
9// third-derivative adjoint in `construction.rs`) has a Newton row to receive
10// it. A former module-local copy used a tighter `−4·τ` band, which silently
11// dropped coordinates in `(−36τ, −4τ]` from the solve while the prior still put
12// gradient on them — an objective↔gradient desync that stalls the inner Newton
13// fit. That copy and its `JUMPRELU_REACTIVATION_MARGIN` constant are deleted;
14// there is now one band, one source of truth.
15
/// Per-row active-set layout for sparse SAE assignment, shared by every
/// assignment mode.
///
/// A sparse assignment activates only a subset of the `K` atoms for each
/// observation. The sparsity is either structural (JumpReLU gate) or
/// effective (softmax / IBP-MAP at large `K`, where the assignment mass
/// concentrates on a small support). The Arrow-Schur row block for
/// observation `i` then has dimension
/// `q_active_i = |active_i| + Σ_{k ∈ active_i} d_k` instead of the dense
/// `q = assignment_dim + Σ_k d_k`. This struct records which atoms are active
/// in each row and how compressed block positions map back to full-`q`
/// positions, so that `apply_newton_step` can unpack the compact `delta_t`
/// produced by the solve.
///
/// # Compact row ordering
///
/// Within a compact row the first `|active_i|` slots hold one assignment
/// entry per active atom, in `active_atoms[i]` order. The coordinate blocks
/// follow in the same order; the block of the `j`-th active atom begins at
/// `coord_starts[i][j]` and spans `coord_dims[k]` slots.
///
/// # Mode-specific active sets
///
/// * **JumpReLU** — the active set is built by `from_jumprelu` from the
///   canonical optimization band
///   `crate::assignment::jumprelu_in_optimization_band`. That constructor
///   documents the band as intentionally wider than the hard forward gate
///   `logit > threshold`, so gated-off atoms inside the band keep a Newton
///   row for their prior terms.
///   NOTE(review): an earlier revision of this comment described the JumpReLU
///   active set as exactly the gated support (`a_{n,k} ≠ 0`) and the compact
///   solve as identical to the dense solve; that wording no longer matches
///   `from_jumprelu` — confirm the intended equivalence statement.
/// * **IBP-MAP** — the active set is the union of a top-`k_active_cap`
///   truncation and a magnitude cutoff on `a_{n,k}`. This is only enabled
///   when `K` is large enough that the dense `(m_total · p)²` data Gram would
///   not fit the host / device working-set budget; the dropped atoms carry
///   `O(a_{n,k}²)` curvature that is negligible by construction of the
///   cutoff.
/// * **Softmax (#1408)** — the compact layout engages when an explicit
///   `top_k` (`softmax_active_cap`) and/or the in-core memory budget bounds
///   the active set. The `AssignmentMode::Softmax` arm of
///   `assemble_arrow_schur` consults
///   [`crate::manifold::SaeManifoldTerm::softmax_active_plan`] and, on
///   `Some((cap, cutoff))`, builds the active set via
///   [`Self::from_dense_weights`]. The full-`K` dense softmax layout is
///   retained only when neither lever engages (no `top_k`, in-budget `K`).
///
///   Folding softmax `top_k` into the compact solve required two coordinated
///   pieces. First, the active×active Gershgorin Loewner majorizer sub-block
///   (#1419): the softmax entropy curvature is indefinite, so its raw
///   diagonal cannot be used. Second, contracting that SAME majorizer over
///   the compact logit slots in the logdet ρ-trace
///   (`assignment_log_strength_hessian_trace`) and the θ-adjoint, so value,
///   `log|H|`, and Γ differentiate one operator on the compact support. That
///   change is landed and FD-certified; the FFI's after-the-fit top-`k`
///   projection is then a no-op at the optimum.
#[derive(Clone, Debug)]
pub struct SaeRowLayout {
    /// `active_atoms[row]` lists the indices of the atoms active in that
    /// row, in ascending order.
    pub active_atoms: Vec<Vec<usize>>,
    /// `coord_starts[row][j]` is the compressed position at which the
    /// coordinate block of atom `active_atoms[row][j]` begins.
    pub coord_starts: Vec<Vec<usize>>,
    /// Offset of atom `k`'s coordinate block in the full-`q` layout
    /// (length `k_atoms`).
    pub coord_offsets_full: Vec<usize>,
    /// Coordinate dimension `d_k` of each atom, indexed by atom index.
    pub coord_dims: Vec<usize>,
}
62
63impl SaeRowLayout {
64 /// JumpReLU optimization active set: atoms inside the smooth prior's
65 /// machine-precision support `(logit - threshold)/tau > -36` (see
66 /// [`crate::assignment::jumprelu_in_optimization_band`], the one
67 /// canonical band). This is intentionally wider than the hard forward gate
68 /// `logit > threshold` so gated-off atoms can remain in the Newton system for
69 /// value-consistent prior terms. Their forward reconstruction contribution
70 /// and data-fit logit JVP remain hard-zero while `a_k = 0`.
71 pub(crate) fn from_jumprelu(
72 n: usize,
73 k_atoms: usize,
74 threshold: f64,
75 temperature: f64,
76 logits: &Array2<f64>,
77 coord_dims: Vec<usize>,
78 coord_offsets_full: Vec<usize>,
79 ) -> Self {
80 let mut per_row = Vec::with_capacity(n);
81 for row in 0..n {
82 let row_logits = logits.row(row);
83 let active: Vec<usize> = (0..k_atoms)
84 .filter(|&k| {
85 crate::assignment::jumprelu_in_optimization_band(
86 row_logits[k],
87 threshold,
88 temperature,
89 )
90 })
91 .collect();
92 per_row.push(active);
93 }
94 Self::from_active_atoms(per_row, coord_dims, coord_offsets_full)
95 }
96
97 /// Mode-agnostic effective active set for dense-weight modes (softmax /
98 /// IBP-MAP) at large `K`: keep, per row, the top-`k_active_cap` atoms by
99 /// `|a_{n,k}|` whose magnitude also exceeds `relative_cutoff · rowpeak`.
100 ///
101 /// #1414: the cutoff is RELATIVE TO EACH ROW'S OWN PEAK `max_k |a_{n,k}|`,
102 /// matching the documented `sparse_active_plan` contract
103 /// (`construction.rs:1763-1766`). A global cutoff (one threshold from the
104 /// whole-dataset peak) would wrongly drop both atoms of a uniformly-small row
105 /// `[0.0009, 0.0008]` just because another row peaks at `1.0`, changing the
106 /// high-`K` compact model.
107 ///
108 /// `assignments[row]` is the dense length-`K` assignment vector `a_{n,·}`.
109 /// The active set is always non-empty (the single largest-magnitude atom is
110 /// retained even if below cutoff) so every row keeps a valid block.
111 pub(crate) fn from_dense_weights(
112 assignments: &[Array1<f64>],
113 k_active_cap: usize,
114 relative_cutoff: f64,
115 coord_dims: Vec<usize>,
116 coord_offsets_full: Vec<usize>,
117 ) -> Self {
118 let cap = k_active_cap.max(1);
119 let mut per_row = Vec::with_capacity(assignments.len());
120 for a in assignments {
121 let k = a.len();
122 // #1411: select the top-`cap` atoms by |a_k| in O(K) with a PARTIAL
123 // select (`select_nth_unstable_by`), not a full O(K log K) sort. Only
124 // the cap-sized active prefix matters; its internal order is
125 // irrelevant (sorted at the end). The row peak is a separate O(K) max
126 // scan. End-to-end this keeps support proposal O(K) (single pass +
127 // partial select), the contracted per-token cost the high-K plan
128 // claims, instead of sorting all K per row.
129 let row_peak = a.iter().fold(0.0_f64, |m, &v| m.max(v.abs()));
130 let cutoff = relative_cutoff * row_peak;
131 let mut idx: Vec<usize> = (0..k).collect();
132 // Partition so the `cap` largest-|a| indices occupy `idx[..cap]`
133 // (unordered within); cheaper than a full sort when `cap << k`.
134 if cap < k {
135 idx.select_nth_unstable_by(cap - 1, |&i, &j| {
136 a[j].abs()
137 .partial_cmp(&a[i].abs())
138 .unwrap_or(std::cmp::Ordering::Equal)
139 });
140 idx.truncate(cap);
141 }
142 let mut active: Vec<usize> = idx
143 .into_iter()
144 .filter(|&k_idx| a[k_idx].abs() > cutoff)
145 .collect();
146 if active.is_empty() {
147 // Retain the single largest-magnitude atom so the row block is
148 // never empty (a degenerate empty block would zero the row).
149 let top = (0..k).fold(None::<usize>, |best, i| match best {
150 Some(b) if a[b].abs() >= a[i].abs() => Some(b),
151 _ => Some(i),
152 });
153 if let Some(top) = top {
154 active.push(top);
155 }
156 }
157 active.sort_unstable();
158 per_row.push(active);
159 }
160 Self::from_active_atoms(per_row, coord_dims, coord_offsets_full)
161 }
162
163 /// Build from explicit per-row active-atom index lists.
164 pub(crate) fn from_active_atoms(
165 active_atoms: Vec<Vec<usize>>,
166 coord_dims: Vec<usize>,
167 coord_offsets_full: Vec<usize>,
168 ) -> Self {
169 let mut coord_starts_all = Vec::with_capacity(active_atoms.len());
170 for active in &active_atoms {
171 let mut starts = Vec::with_capacity(active.len());
172 let mut cursor = active.len();
173 for &k in active {
174 starts.push(cursor);
175 cursor += coord_dims[k];
176 }
177 coord_starts_all.push(starts);
178 }
179 Self {
180 active_atoms,
181 coord_starts: coord_starts_all,
182 coord_offsets_full,
183 coord_dims,
184 }
185 }
186
187 /// Per-row compressed dim.
188 pub fn row_q_active(&self, row: usize) -> usize {
189 let active = &self.active_atoms[row];
190 let coord_sum: usize = active.iter().map(|&k| self.coord_dims[k]).sum();
191 active.len() + coord_sum
192 }
193
194 /// Expand a compact `delta_t` row slice back into full-q, zeros for inactive.
195 pub fn expand_row(&self, row: usize, delta_t_row: &[f64], out: &mut [f64]) {
196 for v in out.iter_mut() {
197 *v = 0.0;
198 }
199 let active = &self.active_atoms[row];
200 let starts = &self.coord_starts[row];
201 for (j, &k) in active.iter().enumerate() {
202 out[k] = delta_t_row[j];
203 let d = self.coord_dims[k];
204 let full_off = self.coord_offsets_full[k];
205 for axis in 0..d {
206 out[full_off + axis] = delta_t_row[starts[j] + axis];
207 }
208 }
209 }
210}