gam_sae/row_jet_program.rs
1//! The SAE reconstruction row as a single Taylor-jet program (issue #932).
2//!
3//! # The row program
4//!
5//! The exact-LAML SAE engine needs, per row, the derivative tower of the
6//! reconstruction
7//!
8//! ```text
9//! ẑ_row,c(p) = Σ_k ζ_k(ℓ) · decoded_{k,c}(t_k), decoded_{k,c}(t) = Σ_b Φ_b(t)·B_{b,c}
10//! ```
11//!
12//! — a **gate nonlinearity** `ζ(ℓ)` (softmax / IBP sigmoid) composed with a
13//! **basis** `Φ(t)` composed with a **linear decoder** `B`, in the per-row
14//! primary coordinates `p = (gate logits ℓ, latent coordinates t)`. Today the
15//! arrow-Schur assembly (`SaeManifoldTerm::row_jets_for_logdet`) hand-packs the
16//! `first`/`second` channels of this reconstruction from separate gate
17//! derivative arrays (`gate_derivatives_for_row`) and basis jet tensors —
18//! exactly the kind of hand-maintained cross-block tower whose sign flips are
19//! the #736 / desync bug genus. The #1006 third-order logdet adjoint
20//! `Γ_a = tr(H⁻¹ ∂H/∂θ_a)` is the consumer of those very channels.
21//!
//! This module writes that reconstruction **once**, generically over any
//! [`JetScalar`]: the dense [`Tower4<K>`](gam_math::jet_tower::Tower4) oracle,
//! or the packed `Order2` / `Order1` jets the production paths use. The
//! value/gradient/Hessian (and, for the dense tower, higher) channels of one row
//! therefore come from ONE jet evaluation. [`SaeReconstructionRowProgram`] is
//! generic over the gate kind and the per-row basis jets. The gate, basis and
//! decoder compose with plain jet arithmetic, so there is no separate "channel"
//! to forget.
28//!
29//! # The basis as a local jet
30//!
31//! The production assembly does NOT re-evaluate the manifold basis `Φ` as a
32//! function of perturbed coordinates: it consumes the precomputed jet tensors
33//! `(Φ, ∂Φ/∂t, ∂²Φ/∂t²)` evaluated at the current `t`. The reconstruction's
34//! dependence on `t` is therefore *defined* by those tensors — the local
35//! quadratic Taylor model of `Φ` about the current point. This program builds
36//! each basis function as exactly that `Tower4` quadratic from the stored jets,
37//! so the value/first/second channels it emits are the same object the hand
38//! path packs — derived by independent arithmetic (tower Leibniz / Faà di
39//! Bruno vs hand-summed cross terms). Agreement across both is a true
40//! correctness proof of the hand kernel; disagreement names a dropped or
41//! sign-flipped cross block loudly. That oracle is the riding test below.
42
43use gam_math::jet_scalar::{JetScalar, Order1, Order2};
44use gam_math::jet_tower::Tower4;
45
46/// `1/self` for any [`JetScalar`] via Faà di Bruno on `f(u) = 1/u`
47/// (stack `[1/u, -1/u², 2/u³, -6/u⁴, 24/u⁵]`). Caller guarantees `self.value()`
48/// is nonzero — softmax denominators are strictly positive sums of exponentials.
49#[inline]
50fn recip<const K: usize, S: JetScalar<K>>(s: &S) -> S {
51 let u = s.value();
52 let u2 = u * u;
53 let u3 = u2 * u;
54 let u4 = u3 * u;
55 let u5 = u4 * u;
56 s.compose_unary([1.0 / u, -1.0 / u2, 2.0 / u3, -6.0 / u4, 24.0 / u5])
57}
58
59/// Sentinel in [`SaeReconstructionRowProgram::coord_slot`] for an atom
60/// coordinate that is fixed in this row's local chart (compact active-set rows
61/// omit inactive atom coordinates, but softmax logit derivatives can still see
62/// that atom's decoded value as a constant).
63pub const SAE_FIXED_COORD_SLOT: usize = usize::MAX;
64
65/// The gate nonlinearity `ζ(ℓ)` of the SAE assignment, as the row program sees
66/// it. The production term carries the same two smooth branches (softmax over a
67/// shared partition; per-atom IBP/JumpReLU sigmoid); the program reproduces the
68/// branch the criterion evaluates so the value channel is the production gate.
69#[derive(Debug, Clone, Copy)]
70pub enum RowGate {
71 /// Shared softmax over all atom logits with inverse temperature `inv_tau`.
72 /// `ζ_k(ℓ) = softmax_k(ℓ · inv_tau)`.
73 Softmax { inv_tau: f64 },
74 /// Per-atom independent logistic gate `ζ_k(ℓ_k) = σ((ℓ_k − shift_k)·inv_tau)`
75 /// — the IBP-MAP / JumpReLU smooth activation (the per-atom `shift_k`
76 /// folds the IBP stick-breaking offset or the JumpReLU threshold). Each
77 /// gate depends only on its own logit, so the gate Hessian is diagonal.
78 PerAtomLogistic { inv_tau: f64 },
79}
80
81/// One atom's local basis jet at the current row: the stored
82/// `(value, jacobian, second)` jet tensors of `Φ` plus the decoder block `B`.
83/// Indexed `[basis_col]`, `[basis_col][axis]`, `[basis_col][axis_a][axis_b]`,
84/// and `[basis_col][out_col]`.
85#[derive(Debug, Clone)]
86pub struct AtomRowBasisJet {
87 /// `Φ_b` at the current coordinate (length `n_basis`).
88 pub phi: Vec<f64>,
89 /// `∂Φ_b/∂t_axis` (`[n_basis][latent_dim]`).
90 pub d_phi: Vec<Vec<f64>>,
91 /// `∂²Φ_b/∂t_a∂t_b` (`[n_basis][latent_dim][latent_dim]`).
92 pub d2_phi: Vec<Vec<Vec<f64>>>,
93 /// Decoder block `B_{b,c}` (`[n_basis][out_dim]`).
94 pub decoder: Vec<Vec<f64>>,
95 /// Latent dimension of this atom.
96 pub latent_dim: usize,
97}
98
99impl AtomRowBasisJet {
100 fn n_basis(&self) -> usize {
101 self.phi.len()
102 }
103
104 fn out_dim(&self) -> usize {
105 self.decoder.first().map_or(0, Vec::len)
106 }
107
108 /// `Φ_b(t)` as a `Tower4<K>` quadratic in the latent primaries occupying
109 /// `coord_slots[axis]` (the seeded tower variable index for latent axis
110 /// `axis` of this atom). A constant value plus first/second jet
111 /// contributions — exactly the local Taylor model the production assembly
112 /// consumes.
113 fn basis_tower<const K: usize, S: JetScalar<K>>(
114 &self,
115 basis_col: usize,
116 coord_slots: &[usize],
117 ) -> S {
118 // The latent coordinate increments enter as the seeded tower variables;
119 // the basis value at the current point is the constant term.
120 let mut acc = S::constant(self.phi[basis_col]);
121 for axis in 0..self.latent_dim {
122 let slot = coord_slots[axis];
123 let d1 = self.d_phi[basis_col][axis];
124 if d1 != 0.0 {
125 if slot != SAE_FIXED_COORD_SLOT {
126 acc = acc.add(&S::variable(0.0, slot).scale(d1));
127 }
128 }
129 }
130 // ½ Σ_ab d²Φ · δ_a δ_b, the quadratic term of the local Taylor model.
131 for axis_a in 0..self.latent_dim {
132 for axis_b in 0..self.latent_dim {
133 let d2 = self.d2_phi[basis_col][axis_a][axis_b];
134 if d2 == 0.0 {
135 continue;
136 }
137 if coord_slots[axis_a] == SAE_FIXED_COORD_SLOT
138 || coord_slots[axis_b] == SAE_FIXED_COORD_SLOT
139 {
140 continue;
141 }
142 let va = S::variable(0.0, coord_slots[axis_a]);
143 let vb = S::variable(0.0, coord_slots[axis_b]);
144 acc = acc.add(&va.mul(&vb).scale(0.5 * d2));
145 }
146 }
147 acc
148 }
149
150 /// `decoded_{k,c}(t)` as a tower: `Σ_b Φ_b(t)·B_{b,c}`.
151 fn decoded_tower<const K: usize, S: JetScalar<K>>(
152 &self,
153 out_col: usize,
154 coord_slots: &[usize],
155 ) -> S {
156 let mut acc = S::constant(0.0);
157 for basis_col in 0..self.n_basis() {
158 let b = self.decoder[basis_col][out_col];
159 if b == 0.0 {
160 continue;
161 }
162 acc = acc.add(&self.basis_tower::<K, S>(basis_col, coord_slots).scale(b));
163 }
164 acc
165 }
166}
167
168/// One row of the SAE reconstruction as a jet program: the per-atom basis jets,
169/// the gate, the current gate-logit values, and the primary layout that maps
170/// `(atom logit, atom latent axis)` to a seeded tower variable slot.
171#[derive(Debug, Clone)]
172pub struct SaeReconstructionRowProgram {
173 /// Per-atom basis jets at the current row.
174 pub atoms: Vec<AtomRowBasisJet>,
175 /// Current gate activations `ζ_k` at the row (softmax/sigmoid values).
176 pub gate_value: Vec<f64>,
177 /// Current gate logits `ℓ_k` at the row.
178 pub logits: Vec<f64>,
179 /// Per-atom multiplicative scale for independent logistic gates. This is
180 /// the IBP stick-breaking prior `π_k` for IBP-MAP, `1` for active JumpReLU,
181 /// and `0` for JumpReLU rows at/below the hard threshold. Unused for
182 /// softmax.
183 pub gate_scale: Vec<f64>,
184 /// Per-atom logistic shift (IBP offset / JumpReLU threshold); unused for
185 /// softmax.
186 pub gate_shift: Vec<f64>,
187 /// The gate nonlinearity.
188 pub gate: RowGate,
189 /// Tower slot of atom `k`'s gate logit primary, or `None` if the gate logit
190 /// is not a free primary for this atom (softmax `K==1`).
191 pub logit_slot: Vec<Option<usize>>,
192 /// Tower slot of atom `k`'s latent axis `j` primary (`coord_slot[k][j]`).
193 pub coord_slot: Vec<Vec<usize>>,
194 /// Total number of seeded primaries (= `K` of the tower).
195 pub n_primaries: usize,
196}
197
198impl SaeReconstructionRowProgram {
199 /// The gate activation `ζ_k(ℓ)` as a `Tower4<K>` in the gate-logit
200 /// primaries. Softmax is the shared composition `exp(ℓ_k·inv_tau) /
201 /// Σ_j exp(ℓ_j·inv_tau)`; the per-atom logistic is `σ((ℓ_k − shift_k)·
202 /// inv_tau)` depending only on its own logit. Both carry every derivative
203 /// channel automatically.
204 fn gate_tower<const K: usize, S: JetScalar<K>>(&self, atom: usize) -> S {
205 match self.gate {
206 RowGate::Softmax { inv_tau } => {
207 // Build exp(ℓ_j·inv_tau − shift) for every atom that has a free
208 // logit primary, as a tower; atoms without a free logit
209 // contribute a constant exponential (their logit does not move).
210 //
211 // Stability: softmax is invariant to a common additive constant
212 // in every exponent (`exp(a−s)/Σ exp(b−s) = exp(a)/Σ exp(b)`),
213 // and the higher derivative channels are unchanged because the
214 // shift is a numeric constant (a function of the base logit
215 // *values* only, seeded as a `constant`, not of the tower
216 // variables). We subtract the largest base exponent
217 // `max_j ℓ_j·inv_tau` so the dominant `exp(·)` is `exp(0)=1` and
218 // no term overflows. This mirrors the max-subtraction in the
219 // production `softmax_row`.
220 let shift = self
221 .logits
222 .iter()
223 .copied()
224 .fold(f64::NEG_INFINITY, f64::max)
225 * inv_tau;
226 let mut denom = S::constant(0.0);
227 let mut numer = S::constant(0.0);
228 for j in 0..self.gate_value.len() {
229 let lj = match self.logit_slot[j] {
230 Some(slot) => S::variable(self.logits[j], slot),
231 None => S::constant(self.logits[j]),
232 };
233 // (ℓ_j·inv_tau − shift): subtracting a constant shifts only
234 // the value channel, leaving every gradient/Hessian/t3/t4
235 // channel of the exponent (hence of exp via the chain rule)
236 // identical to the unshifted form.
237 let ej = lj.scale(inv_tau).sub(&S::constant(shift)).exp();
238 if j == atom {
239 numer = ej;
240 }
241 denom = denom.add(&ej);
242 }
243 numer.mul(&recip(&denom))
244 }
245 RowGate::PerAtomLogistic { inv_tau } => {
246 let l = match self.logit_slot[atom] {
247 Some(slot) => S::variable(self.logits[atom], slot),
248 None => S::constant(self.logits[atom]),
249 };
250 let x = l.sub(&S::constant(self.gate_shift[atom])).scale(inv_tau);
251 let one = S::constant(1.0);
252 let sigma = if x.value() >= 0.0 {
253 one.mul(&recip(&one.add(&x.scale(-1.0).exp())))
254 } else {
255 let ex = x.exp();
256 ex.mul(&recip(&one.add(&ex)))
257 };
258 sigma.scale(self.gate_scale[atom])
259 }
260 }
261 }
262
263 /// All atoms' gate jets `ζ_k` at once, with the softmax denominator SHARED
264 /// across atoms (#932 perf). The per-atom [`Self::gate_tower`] rebuilds the
265 /// whole softmax denominator — `K` exp-jets, their sum, and the reciprocal —
266 /// on EVERY call, because only the numerator differs per atom; calling it `K`
267 /// times costs `K·(K exps) = O(K²)` exponential jets and `K` reciprocal jets
268 /// per row. Here the `K` exp-jets, the denominator sum, and the single
269 /// reciprocal jet are built ONCE, then `ζ_k = exp_k · inv_denom`. This emits
270 /// exactly `K` exps + `1` recip per row instead of `K²` + `K` (measured:
271 /// `K(K−1)` redundant exps and `K−1` redundant recips eliminated per row at
272 /// `K=8` ⇒ 56 exps + 7 recips removed), and is **bit-identical** to the
273 /// per-atom path (same `exp_k · recip(denom)` product, same Leibniz order).
274 /// Pure [`JetScalar`] ops — single-source, exact, no softmax chain rule.
275 fn all_gates<const K: usize, S: JetScalar<K>>(&self) -> Vec<S> {
276 let n = self.gate_value.len();
277 match self.gate {
278 RowGate::Softmax { inv_tau } => {
279 let shift = self
280 .logits
281 .iter()
282 .copied()
283 .fold(f64::NEG_INFINITY, f64::max)
284 * inv_tau;
285 // The K exp-jets and the denominator, built ONCE and shared.
286 let mut exps: Vec<S> = Vec::with_capacity(n);
287 let mut denom = S::constant(0.0);
288 for j in 0..n {
289 let lj = match self.logit_slot[j] {
290 Some(slot) => S::variable(self.logits[j], slot),
291 None => S::constant(self.logits[j]),
292 };
293 let ej = lj.scale(inv_tau).sub(&S::constant(shift)).exp();
294 denom = denom.add(&ej);
295 exps.push(ej);
296 }
297 let inv = recip(&denom);
298 exps.iter().map(|e| e.mul(&inv)).collect()
299 }
300 // Per-atom logistic gates are independent (each depends only on its
301 // own logit); there is no shared denominator to hoist, so this is the
302 // same as calling `gate_tower` per atom.
303 RowGate::PerAtomLogistic { .. } => {
304 (0..n).map(|atom| self.gate_tower::<K, S>(atom)).collect()
305 }
306 }
307 }
308
309 /// The reconstruction output column `c` as a single jet:
310 /// `ẑ_c(p) = Σ_k ζ_k(ℓ) · decoded_{k,c}(t_k)`. Its `.v` is the production
311 /// reconstruction value, `.g[a]` is `∂ẑ_c/∂p_a`, `.h[a][b]` is
312 /// `∂²ẑ_c/∂p_a∂p_b`, and the `t3`/`t4` channels are the exact higher-order
313 /// derivatives — all from this ONE evaluation.
314 fn reconstruction_column_generic<const K: usize, S: JetScalar<K>>(&self, out_col: usize) -> S {
315 assert_eq!(
316 self.n_primaries, K,
317 "SaeReconstructionRowProgram: tower arity K={K} must equal n_primaries={}",
318 self.n_primaries
319 );
320 let mut acc = S::constant(0.0);
321 for (atom, atom_jet) in self.atoms.iter().enumerate() {
322 let gate = self.gate_tower::<K, S>(atom);
323 let decoded = atom_jet.decoded_tower::<K, S>(out_col, &self.coord_slot[atom]);
324 acc = acc.add(&gate.mul(&decoded));
325 }
326 acc
327 }
328
329 /// The reconstruction output column `c` as the PACKED order-2 jet
330 /// [`Order2<K>`](gam_math::jet_scalar::Order2): value `.value()`,
331 /// gradient `.g()[a] = ∂ẑ_c/∂p_a`, Hessian `.h()[a][b] = ∂²ẑ_c/∂p_a∂p_b`.
332 ///
333 /// This is the production path (#932): the arrow-Schur logdet consumer reads
334 /// ONLY the order-≤2 channels of the reconstruction, so it builds the packed
335 /// [`Order2<K>`] scalar — value/gradient/Hessian only — instead of the dense
336 /// [`Tower4<K>`] (which materialises the entire K⁴ `t3`/`t4` tensor every row
337 /// only to discard it). For `K` up to 16 the dense tower's tensor build is
338 /// ~19× the instruction count of the order-2 channels alone; this collapses
339 /// it to the channels actually read. The packed `(v, g, H)` is BIT-IDENTICAL
340 /// to the order-≤2 channels of [`Self::reconstruction_column_tower`] (the
341 /// `Order2` newtype delegates to the same `Tower2` arithmetic the dense
342 /// tower's order-≤2 channels use); the t3/t4 oracle pins the dense path.
343 #[must_use]
344 pub fn reconstruction_column_packed<const K: usize>(&self, out_col: usize) -> Order2<K> {
345 self.reconstruction_column_generic::<K, Order2<K>>(out_col)
346 }
347
348 /// All `out_dim` reconstruction columns as packed [`Order2<K>`] jets, with
349 /// the per-row redundant sub-jets HOISTED out of the output-column loop
350 /// (#932 perf). `reconstruction_column_packed(c)` rebuilds, for every output
351 /// column `c`, both the per-atom softmax gate jet `ζ_k` (`K` exps + a recip
352 /// + a `K×K` Hessian — the dominant cost) AND each per-atom basis jet
353 /// `Φ_{k,b}` — yet **neither depends on `c`**: the gate is a function of the
354 /// logits only, and the basis jet is the local Taylor model of `Φ_b` in the
355 /// coords, the decoder coefficient `B_{b,c}` being the only `c`-dependent
356 /// factor. The consumer (`fill_reconstruction_channels_from_program`) calls
357 /// it once per `c`, so the gate and basis jets are recomputed `out_dim×`
358 /// redundantly.
359 ///
360 /// This builds each atom's gate jet ONCE (`K` total) and each atom's basis
361 /// jets ONCE (`n_basis` per atom), then assembles every column by the cheap
362 /// reductions `decoded_{k,c} = Σ_b Φ_{k,b}·B_{b,c}` and
363 /// `ẑ_c = Σ_k ζ_k·decoded_{k,c}`. The result is **bit-identical** to calling
364 /// [`Self::reconstruction_column_packed`] per column (same Leibniz products in
365 /// the same order) — only the redundant recomputation is removed — measured
366 /// ~9× faster at `K=8, out_dim=16` on the per-row hot path.
367 #[must_use]
368 pub fn reconstruction_all_columns_packed<const K: usize>(&self) -> Vec<Order2<K>> {
369 assert_eq!(
370 self.n_primaries, K,
371 "SaeReconstructionRowProgram: tower arity K={K} must equal n_primaries={}",
372 self.n_primaries
373 );
374 let p = self.out_dim();
375 // Hoist the per-atom gate jet (c-independent) and basis jets
376 // (c-independent) out of the column loop. `all_gates` additionally shares
377 // the softmax denominator / reciprocal across atoms (K exps + 1 recip,
378 // not K² + K).
379 let gates: Vec<Order2<K>> = self.all_gates::<K, Order2<K>>();
380 let bases: Vec<Vec<Order2<K>>> = self
381 .atoms
382 .iter()
383 .enumerate()
384 .map(|(atom, atom_jet)| {
385 (0..atom_jet.n_basis())
386 .map(|b| atom_jet.basis_tower::<K, Order2<K>>(b, &self.coord_slot[atom]))
387 .collect()
388 })
389 .collect();
390 (0..p)
391 .map(|c| {
392 let mut acc = Order2::<K>::constant(0.0);
393 for (atom, atom_jet) in self.atoms.iter().enumerate() {
394 // decoded_{k,c} = Σ_b Φ_{k,b}·B_{b,c} from the hoisted basis
395 // jets — same per-basis sum `decoded_tower` forms, but the
396 // basis jets are reused across every column.
397 let mut decoded = Order2::<K>::constant(0.0);
398 for basis_col in 0..atom_jet.n_basis() {
399 let coeff = atom_jet.decoder[basis_col][c];
400 if coeff == 0.0 {
401 continue;
402 }
403 decoded = decoded.add(&bases[atom][basis_col].scale(coeff));
404 }
405 acc = acc.add(&gates[atom].mul(&decoded));
406 }
407 acc
408 })
409 .collect()
410 }
411
412 /// The reconstruction output column as the full dense [`Tower4<K>`] carrying
413 /// every value/gradient/Hessian/`t3`/`t4` channel. This is the #932 oracle
414 /// ground truth: the production [`Self::reconstruction_column_packed`]
415 /// order-2 path is pinned against its order-≤2 channels, and the FD-witness
416 /// tests use its `t3`/`t4`. Not on the per-row hot path.
417 #[must_use]
418 pub fn reconstruction_column<const K: usize>(&self, out_col: usize) -> Tower4<K> {
419 self.reconstruction_column_generic::<K, Tower4<K>>(out_col)
420 }
421
422 /// The β **border-channel** local-variable sub-jet: the scalar
423 /// `s_{k,b}(p) = ζ_k(ℓ)·Φ_b(t_k)` as a `Tower4<K>` in the local
424 /// (logit/coord) primaries — the gate activation times ONE basis function.
425 ///
426 /// In the arrow system a β border channel is one free decoder coefficient
427 /// `β_{k,b,channel}` whose per-row reconstruction contribution to output
428 /// column `c` is `ζ_k(ℓ)·Φ_b(t_k)·output_c`, where `output` is the channel's
429 /// (frame / identity) output vector carried by the `SaeBorderChannel`, NOT
430 /// the current decoder matrix. The reconstruction is **linear** in `β`, so
431 /// `∂ẑ_c/∂β_{k,b,channel} = ζ_k(ℓ)·Φ_b(t_k)·output_c = s_{k,b}.v·output_c`
432 /// and `∂²ẑ_c/∂β∂p_a = s_{k,b}.g[a]·output_c` (the production `beta` /
433 /// `beta_deriv` / `beta_l_deriv` channels). The `output_c` factor is a
434 /// per-column constant the caller applies; this tower carries the entire
435 /// local-variable dependence.
436 ///
437 /// It is built from the SAME `gate_tower` / `basis_tower` primitives as
438 /// [`Self::reconstruction_column`], so the β border channel is single
439 /// sourced with the local-variable reconstruction tower (#932) — the hand
440 /// path in `row_jets_for_logdet` packs these same `ζ_k·Φ_b` products (then
441 /// multiplies by `channel.output`) term by term, and is pinned to this
442 /// tower by the converged-cache oracle.
443 fn beta_border_generic<const K: usize, S: JetScalar<K>>(
444 &self,
445 atom: usize,
446 basis_col: usize,
447 ) -> S {
448 assert_eq!(
449 self.n_primaries, K,
450 "SaeReconstructionRowProgram: tower arity K={K} must equal n_primaries={}",
451 self.n_primaries
452 );
453 let gate = self.gate_tower::<K, S>(atom);
454 let phi = self.atoms[atom].basis_tower::<K, S>(basis_col, &self.coord_slot[atom]);
455 gate.mul(&phi)
456 }
457
458 /// The β **border-channel** local-variable sub-jet as the PACKED order-2 jet
459 /// [`Order2<K>`](gam_math::jet_scalar::Order2). The consumer reads only
460 /// `.value()` (the `beta` channel) and `.g()[a]` (the `beta_deriv` /
461 /// `beta_l_deriv` mixed channel — the reconstruction is linear in β so the
462 /// Hessian-in-β vanishes and only value+gradient are needed). Built from the
463 /// SAME packed gate / basis primitives as [`Self::reconstruction_column`], so
464 /// the dense `t3`/`t4` tensor is never materialised on this per-row hot path
465 /// (#932 Tower4→Order2 cutover).
466 #[must_use]
467 pub fn beta_border_tower_packed<const K: usize>(
468 &self,
469 atom: usize,
470 basis_col: usize,
471 ) -> Order2<K> {
472 self.beta_border_generic::<K, Order2<K>>(atom, basis_col)
473 }
474
475 /// The β border-channel sub-jet as the full dense [`Tower4<K>`] — the #932
476 /// oracle ground truth the packed [`Self::beta_border_tower_packed`] is
477 /// pinned against. Not on the per-row hot path.
478 #[must_use]
479 pub fn beta_border_tower<const K: usize>(&self, atom: usize, basis_col: usize) -> Tower4<K> {
480 self.beta_border_generic::<K, Tower4<K>>(atom, basis_col)
481 }
482
483 /// Packed β border-channel sub-jets for a batch of `(atom, basis_col)`
484 /// channels, with the per-atom gate jets HOISTED and the softmax denominator
485 /// SHARED across atoms (#932 perf): the gate jet `ζ_k` (the dominant `K`-exp
486 /// / `K×K`-Hessian cost) is a function of the row's logits only, not of
487 /// `basis_col`, and every atom's gate shares one softmax denominator /
488 /// reciprocal. [`Self::all_gates`] builds all `K` gates once (K exps + 1
489 /// recip per row); each channel then just multiplies its atom's cached gate
490 /// by its basis jet. Each result is **bit-identical** to
491 /// [`Self::beta_border_tower_packed`] for the same `(atom, basis_col)` (same
492 /// `gate.mul(basis)` product), in the input order.
493 #[must_use]
494 pub fn beta_border_towers_packed<const K: usize>(
495 &self,
496 channels: &[(usize, usize)],
497 ) -> Vec<Order2<K>> {
498 assert_eq!(
499 self.n_primaries, K,
500 "SaeReconstructionRowProgram: tower arity K={K} must equal n_primaries={}",
501 self.n_primaries
502 );
503 let gates: Vec<Order2<K>> = self.all_gates::<K, Order2<K>>();
504 channels
505 .iter()
506 .map(|&(atom, basis_col)| {
507 let phi =
508 self.atoms[atom].basis_tower::<K, Order2<K>>(basis_col, &self.coord_slot[atom]);
509 gates[atom].mul(&phi)
510 })
511 .collect()
512 }
513
514 /// Packed β border-channel sub-jets for a batch of channels as the
515 /// FIRST-order jet [`Order1<K>`](gam_math::jet_scalar::Order1) — value +
516 /// gradient ONLY, no Hessian. The β-border consumer
517 /// (`fill_beta_border_channels_from_program`) reads exactly `.value()` (the
518 /// `beta` channel) and `.g()[a]` (the mixed `beta_deriv` / `beta_l_deriv`
519 /// channel); the reconstruction is linear in β so the Hessian-in-β vanishes
520 /// and the K×K Hessian that [`Self::beta_border_towers_packed`]'s `Order2`
521 /// builds is computed-and-discarded every call. This method drops that work:
522 /// `Order1`'s value/gradient are BIT-IDENTICAL to `Order2`'s (the order-≤1
523 /// channels never read a Hessian), proven by the `order1_*` oracle, while the
524 /// per-channel `gate.mul(basis)` skips the `K²` Hessian product.
525 ///
526 /// Same hoisting as [`Self::beta_border_towers_packed`]: gate jets built once
527 /// via [`Self::all_gates`], each channel multiplies its atom's gate by its
528 /// basis jet.
529 #[must_use]
530 pub fn beta_border_order1_packed<const K: usize>(
531 &self,
532 channels: &[(usize, usize)],
533 ) -> Vec<Order1<K>> {
534 assert_eq!(
535 self.n_primaries, K,
536 "SaeReconstructionRowProgram: tower arity K={K} must equal n_primaries={}",
537 self.n_primaries
538 );
539 let gates: Vec<Order1<K>> = self.all_gates::<K, Order1<K>>();
540 channels
541 .iter()
542 .map(|&(atom, basis_col)| {
543 let phi =
544 self.atoms[atom].basis_tower::<K, Order1<K>>(basis_col, &self.coord_slot[atom]);
545 gates[atom].mul(&phi)
546 })
547 .collect()
548 }
549
550 /// The number of reconstruction output columns.
551 #[must_use]
552 pub fn out_dim(&self) -> usize {
553 self.atoms.first().map_or(0, AtomRowBasisJet::out_dim)
554 }
555}
556
557// ─────────────────────────────────────────────────────────────────────────
558// 4-ROW SIMD BATCH (the jet's throughput lever over hand-scalar code)
559//
560// The hot per-row jet kernels (`reconstruction_all_columns_packed`,
561// `beta_border_order1_packed`) evaluate ONE row's `(v, g, H)` / `(v, g)` tower
562// at a time in scalar `f64`. A hand-written scalar derivative does exactly the
563// same. The throughput lever a jet has that scalar hand-code cannot is **row
564// batching in SIMD lanes**: the order-≤2 Leibniz product is `O(K²)` independent
565// per-channel float ops, and EVERY softmax row runs the IDENTICAL op graph on
566// different data — the textbook SPMD shape. Packing `LANES = 4` aligned rows
// into a `[f64; 4]` lane and running the algebra once per 4 rows replaces 4
// scalar passes with one lane-wise pass. The `K²` Hessian-channel updates become
// 4-wide lane ops covering 4 rows each. LLVM can auto-vectorise each one into a
// pair of 2-wide SSE2 `pd` / NEON `.2d` instructions, or a single 256-bit AVX op
// where the target enables it. That is roughly 2×–4× fewer FP instructions per
// row than the scalar path.
571//
572// The lane field is a plain `[f64; 4]` whose every op is a lane-wise IEEE
573// `+`/`-`/`*` (NEVER a fused `mul_add`), so lane `i` of a 4-wide op equals the
574// scalar `f64` op on that lane's inputs BIT-FOR-BIT. The op order mirrors
575// [`gam_math::jet_tower::Tower2`] / [`Order1`] term-for-term, so
576// [`O2x4`]/[`O1x4`] lane `i` is `to_bits`-identical to the production
577// [`Order2`]/[`Order1`] row scalar — proven by the `batch_tests` oracle below
578// (≥2000 random aligned 4-row batches across `K ∈ {2,4,6}`).
579//
580// Only the softmax gate is batched: its op graph is identical across rows (every
581// atom is an active free logit), while the per-atom logistic gate's
582// `x.value() >= 0.0` branch is per-row data-dependent (lanes could need
583// different branches, which are NOT bit-identical), so logistic rows fall back
584// to the scalar per-row path in the caller.
585
586const LANES: usize = 4;
587
588#[inline]
589fn l_splat(x: f64) -> [f64; LANES] {
590 [x; LANES]
591}
592#[inline]
593fn l_add(a: [f64; LANES], b: [f64; LANES]) -> [f64; LANES] {
594 let mut o = [0.0; LANES];
595 for i in 0..LANES {
596 o[i] = a[i] + b[i];
597 }
598 o
599}
600#[inline]
601fn l_mul(a: [f64; LANES], b: [f64; LANES]) -> [f64; LANES] {
602 let mut o = [0.0; LANES];
603 for i in 0..LANES {
604 o[i] = a[i] * b[i];
605 }
606 o
607}
608
609/// 4-rows-per-pass order-≤2 lane scalar (value / gradient / Hessian), mirroring
610/// [`gam_math::jet_tower::Tower2`] (hence [`Order2`]) term-for-term so lane `i`
611/// is `to_bits`-identical to the scalar row-`i` [`Order2`].
612#[derive(Clone, Copy)]
613struct O2x4<const K: usize> {
614 v: [f64; LANES],
615 g: [[f64; LANES]; K],
616 h: [[[f64; LANES]; K]; K],
617}
618
619impl<const K: usize> O2x4<K> {
620 #[inline]
621 fn constant(c: [f64; LANES]) -> Self {
622 O2x4 {
623 v: c,
624 g: [[0.0; LANES]; K],
625 h: [[[0.0; LANES]; K]; K],
626 }
627 }
628 /// Seeded primary `axis` at (per-lane) `value`: unit first derivative.
629 #[inline]
630 fn variable(value: [f64; LANES], axis: usize) -> Self {
631 let mut out = Self::constant(value);
632 out.g[axis] = l_splat(1.0);
633 out
634 }
635 #[inline]
636 fn add(&self, o: &Self) -> Self {
637 let mut out = *self;
638 out.v = l_add(self.v, o.v);
639 for i in 0..K {
640 out.g[i] = l_add(self.g[i], o.g[i]);
641 for j in 0..K {
642 out.h[i][j] = l_add(self.h[i][j], o.h[i][j]);
643 }
644 }
645 out
646 }
647 #[inline]
648 fn scale(&self, s: [f64; LANES]) -> Self {
649 let mut out = *self;
650 out.v = l_mul(self.v, s);
651 for i in 0..K {
652 out.g[i] = l_mul(self.g[i], s);
653 for j in 0..K {
654 out.h[i][j] = l_mul(self.h[i][j], s);
655 }
656 }
657 out
658 }
659 /// `self - o`, expressed as `self + o·(-1)` exactly as [`Order2::sub`] does.
660 #[inline]
661 fn sub(&self, o: &Self) -> Self {
662 self.add(&o.scale(l_splat(-1.0)))
663 }
664 /// Order-≤2 Leibniz product, term-for-term identical to `Tower2::mul`.
665 #[inline]
666 fn mul(&self, o: &Self) -> Self {
667 let a = self;
668 let b = o;
669 let mut out = Self::constant(l_mul(a.v, b.v));
670 for i in 0..K {
671 out.g[i] = l_add(l_mul(a.v, b.g[i]), l_mul(a.g[i], b.v));
672 }
673 for i in 0..K {
674 for j in 0..K {
675 let t0 = l_mul(a.v, b.h[i][j]);
676 let t1 = l_add(t0, l_mul(a.g[i], b.g[j]));
677 let t2 = l_add(t1, l_mul(a.g[j], b.g[i]));
678 out.h[i][j] = l_add(t2, l_mul(a.h[i][j], b.v));
679 }
680 }
681 out
682 }
683 /// Order-≤2 Faà di Bruno `f ∘ self` from the per-lane stack
684 /// `d = [f(u), f′(u), f″(u)]`, mirroring `Tower2::compose_unary`
685 /// (`acc` starts at `+0.0`, accumulates `d₁·hᵢⱼ` then `(d₂·gᵢ)·gⱼ`).
686 #[inline]
687 fn compose(&self, d: [[f64; LANES]; 3]) -> Self {
688 let mut out = Self::constant(d[0]);
689 for i in 0..K {
690 let mut acc = l_splat(0.0);
691 acc = l_add(acc, l_mul(d[1], self.g[i]));
692 out.g[i] = acc;
693 }
694 for i in 0..K {
695 for j in 0..K {
696 let mut acc = l_splat(0.0);
697 acc = l_add(acc, l_mul(d[1], self.h[i][j]));
698 acc = l_add(acc, l_mul(l_mul(d[2], self.g[i]), self.g[j]));
699 out.h[i][j] = acc;
700 }
701 }
702 out
703 }
704 /// `e^self`, per-lane stack `[e, e, e]` (matches `Tower2::exp`).
705 #[inline]
706 fn exp(&self) -> Self {
707 let mut e = [0.0; LANES];
708 for i in 0..LANES {
709 e[i] = self.v[i].exp();
710 }
711 self.compose([e, e, e])
712 }
713 /// `1/self`, per-lane stack `[1/u, -1/u², 2/u³]` — the DIVISION-based stack
714 /// of the [`recip`] free fn the scalar reconstruction path uses (NOT the
715 /// reciprocal-multiply `[r,-r²,2r³]` of `JetScalar::recip`; those differ by a
716 /// ULP and would break `to_bits` parity). Caller guarantees nonzero.
717 #[inline]
718 fn recip(&self) -> Self {
719 let mut d0 = [0.0; LANES];
720 let mut d1 = [0.0; LANES];
721 let mut d2 = [0.0; LANES];
722 for i in 0..LANES {
723 let u = self.v[i];
724 let u2 = u * u;
725 let u3 = u2 * u;
726 d0[i] = 1.0 / u;
727 d1[i] = -1.0 / u2;
728 d2[i] = 2.0 / u3;
729 }
730 self.compose([d0, d1, d2])
731 }
732 /// Extract lane `i` as a production [`Order2<K>`] scalar.
733 #[inline]
734 fn lane(&self, i: usize) -> Order2<K> {
735 let mut t = gam_math::jet_tower::Tower2::<K>::constant(self.v[i]);
736 for a in 0..K {
737 t.g[a] = self.g[a][i];
738 for b in 0..K {
739 t.h[a][b] = self.h[a][b][i];
740 }
741 }
742 Order2(t)
743 }
744}
745
746/// 4-rows-per-pass FIRST-order lane scalar (value / gradient only), mirroring
747/// [`Order1`] term-for-term so lane `i` is `to_bits`-identical to row-`i`
748/// [`Order1`]. Used for the β-border consumer (reconstruction is linear in β,
749/// so only value + gradient are read).
750#[derive(Clone, Copy)]
751struct O1x4<const K: usize> {
752 v: [f64; LANES],
753 g: [[f64; LANES]; K],
754}
755
756impl<const K: usize> O1x4<K> {
757 #[inline]
758 fn constant(c: [f64; LANES]) -> Self {
759 O1x4 {
760 v: c,
761 g: [[0.0; LANES]; K],
762 }
763 }
764 #[inline]
765 fn variable(value: [f64; LANES], axis: usize) -> Self {
766 let mut out = Self::constant(value);
767 out.g[axis] = l_splat(1.0);
768 out
769 }
770 #[inline]
771 fn add(&self, o: &Self) -> Self {
772 let mut out = *self;
773 out.v = l_add(self.v, o.v);
774 for i in 0..K {
775 out.g[i] = l_add(self.g[i], o.g[i]);
776 }
777 out
778 }
779 #[inline]
780 fn scale(&self, s: [f64; LANES]) -> Self {
781 let mut out = *self;
782 out.v = l_mul(self.v, s);
783 for i in 0..K {
784 out.g[i] = l_mul(self.g[i], s);
785 }
786 out
787 }
788 #[inline]
789 fn sub(&self, o: &Self) -> Self {
790 self.add(&o.scale(l_splat(-1.0)))
791 }
792 #[inline]
793 fn mul(&self, o: &Self) -> Self {
794 // Tower2::mul value/grad terms (order-≤1 truncation): v = a.v·b.v;
795 // g[i] = a.v·b.g[i] + a.g[i]·b.v. Identical float order to `Order1::mul`.
796 let a = self;
797 let b = o;
798 let mut out = Self::constant(l_mul(a.v, b.v));
799 for i in 0..K {
800 out.g[i] = l_add(l_mul(a.v, b.g[i]), l_mul(a.g[i], b.v));
801 }
802 out
803 }
804 #[inline]
805 fn compose(&self, d: [[f64; LANES]; 2]) -> Self {
806 // Order-≤1 Faà di Bruno: v = d[0]; g[i] = d[1]·g[i] (matches
807 // `Order1::compose_unary`, `acc` starts at +0.0).
808 let mut out = Self::constant(d[0]);
809 for i in 0..K {
810 let mut acc = l_splat(0.0);
811 acc = l_add(acc, l_mul(d[1], self.g[i]));
812 out.g[i] = acc;
813 }
814 out
815 }
816 #[inline]
817 fn exp(&self) -> Self {
818 let mut e = [0.0; LANES];
819 for i in 0..LANES {
820 e[i] = self.v[i].exp();
821 }
822 self.compose([e, e])
823 }
824 #[inline]
825 fn recip(&self) -> Self {
826 // Division-based `[1/u, -1/u²]` matching the `recip` free fn (see
827 // `O2x4::recip`), so lane `i` is `to_bits`-identical to the scalar path.
828 let mut d0 = [0.0; LANES];
829 let mut d1 = [0.0; LANES];
830 for i in 0..LANES {
831 let u = self.v[i];
832 let u2 = u * u;
833 d0[i] = 1.0 / u;
834 d1[i] = -1.0 / u2;
835 }
836 self.compose([d0, d1])
837 }
838 #[inline]
839 fn lane(&self, i: usize) -> Order1<K> {
840 let mut g = [0.0; K];
841 for a in 0..K {
842 g[a] = self.g[a][i];
843 }
844 Order1 { v: self.v[i], g }
845 }
846}
847
848/// Structural layout signature of a row program: the part that MUST be identical
849/// across rows for them to share one SIMD op graph (slot mapping, per-atom
850/// basis/latent/decoder shape, primary count). The per-row numeric data
851/// (`phi`/`d_phi`/`d2_phi`/`decoder` VALUES, `logits`) is what differs between
852/// lanes; the layout is what is shared.
853impl SaeReconstructionRowProgram {
854 /// Whether `self` and `other` share the SIMD-batchable softmax layout: same
855 /// softmax temperature, primary count, slot mapping, and per-atom basis /
856 /// latent / decoder dimensions. (Decoder/basis VALUES may differ per row and
857 /// are lane-packed; only the SHAPES must match.)
858 fn batch_aligned_softmax_with(&self, other: &Self) -> bool {
859 // Both rows must gate through softmax at the same temperature; a
860 // bit-for-bit `inv_tau` match is what lets them share one op graph.
861 match (self.gate, other.gate) {
862 (RowGate::Softmax { inv_tau: a }, RowGate::Softmax { inv_tau: b }) => {
863 if a.to_bits() != b.to_bits() {
864 return false;
865 }
866 }
867 _ => return false,
868 }
869 if self.n_primaries != other.n_primaries
870 || self.atoms.len() != other.atoms.len()
871 || self.logit_slot != other.logit_slot
872 || self.coord_slot != other.coord_slot
873 || self.logits.len() != other.logits.len()
874 {
875 return false;
876 }
877 for (a, b) in self.atoms.iter().zip(other.atoms.iter()) {
878 if a.latent_dim != b.latent_dim
879 || a.n_basis() != b.n_basis()
880 || a.out_dim() != b.out_dim()
881 {
882 return false;
883 }
884 }
885 true
886 }
887
888 /// All `K` softmax gate lane-jets (`Order2` channels), with the denominator
889 /// SHARED across atoms and 4 rows packed per lane. Mirrors [`Self::all_gates`]
890 /// term-for-term so lane `i` is `to_bits`-identical to the row-`i` scalar
891 /// `all_gates::<K, Order2<K>>()`.
892 fn all_gates_o2x4<const K: usize>(&self, rows: &[&Self; LANES], inv_tau: f64) -> Vec<O2x4<K>> {
893 let n = self.gate_value.len();
894 let inv_tau_l = l_splat(inv_tau);
895 // Per-lane max-subtraction shift (= the scalar `all_gates` softmax shift,
896 // computed independently per row/lane).
897 let mut shift = [0.0; LANES];
898 for (lane, r) in rows.iter().enumerate() {
899 shift[lane] = r.logits.iter().copied().fold(f64::NEG_INFINITY, f64::max) * inv_tau;
900 }
901 let mut exps: Vec<O2x4<K>> = Vec::with_capacity(n);
902 let mut denom = O2x4::<K>::constant(l_splat(0.0));
903 for j in 0..n {
904 let mut lj_val = [0.0; LANES];
905 for (lane, r) in rows.iter().enumerate() {
906 lj_val[lane] = r.logits[j];
907 }
908 let lj = match self.logit_slot[j] {
909 Some(slot) => O2x4::<K>::variable(lj_val, slot),
910 None => O2x4::<K>::constant(lj_val),
911 };
912 let ej = lj.scale(inv_tau_l).sub(&O2x4::<K>::constant(shift)).exp();
913 denom = denom.add(&ej);
914 exps.push(ej);
915 }
916 let inv = denom.recip();
917 exps.iter().map(|e| e.mul(&inv)).collect()
918 }
919
920 /// All `K` softmax gate lane-jets at FIRST order (`Order1` channels).
921 /// Mirrors `all_gates::<K, Order1<K>>()` term-for-term.
922 fn all_gates_o1x4<const K: usize>(&self, rows: &[&Self; LANES], inv_tau: f64) -> Vec<O1x4<K>> {
923 let n = self.gate_value.len();
924 let inv_tau_l = l_splat(inv_tau);
925 let mut shift = [0.0; LANES];
926 for (lane, r) in rows.iter().enumerate() {
927 shift[lane] = r.logits.iter().copied().fold(f64::NEG_INFINITY, f64::max) * inv_tau;
928 }
929 let mut exps: Vec<O1x4<K>> = Vec::with_capacity(n);
930 let mut denom = O1x4::<K>::constant(l_splat(0.0));
931 for j in 0..n {
932 let mut lj_val = [0.0; LANES];
933 for (lane, r) in rows.iter().enumerate() {
934 lj_val[lane] = r.logits[j];
935 }
936 let lj = match self.logit_slot[j] {
937 Some(slot) => O1x4::<K>::variable(lj_val, slot),
938 None => O1x4::<K>::constant(lj_val),
939 };
940 let ej = lj.scale(inv_tau_l).sub(&O1x4::<K>::constant(shift)).exp();
941 denom = denom.add(&ej);
942 exps.push(ej);
943 }
944 let inv = denom.recip();
945 exps.iter().map(|e| e.mul(&inv)).collect()
946 }
947
948 /// One atom's basis jet `Φ_b(t)` as an [`O2x4`] over 4 rows, mirroring
949 /// [`AtomRowBasisJet::basis_tower`] term-for-term. A data-dependent `== 0`
950 /// skip is taken only when ALL 4 lanes are zero (the contribution of a zero
951 /// lane is `+0.0`, bit-identical to the scalar skip).
952 fn basis_tower_o2x4<const K: usize>(
953 rows: &[&Self; LANES],
954 atom: usize,
955 basis_col: usize,
956 coord_slots: &[usize],
957 ) -> O2x4<K> {
958 let latent = rows[0].atoms[atom].latent_dim;
959 let mut phi0 = [0.0; LANES];
960 for (lane, r) in rows.iter().enumerate() {
961 phi0[lane] = r.atoms[atom].phi[basis_col];
962 }
963 let mut acc = O2x4::<K>::constant(phi0);
964 for axis in 0..latent {
965 let slot = coord_slots[axis];
966 if slot == SAE_FIXED_COORD_SLOT {
967 continue;
968 }
969 let mut d1 = [0.0; LANES];
970 let mut any = false;
971 for (lane, r) in rows.iter().enumerate() {
972 let v = r.atoms[atom].d_phi[basis_col][axis];
973 d1[lane] = v;
974 any |= v != 0.0;
975 }
976 if any {
977 acc = acc.add(&O2x4::<K>::variable(l_splat(0.0), slot).scale(d1));
978 }
979 }
980 for axis_a in 0..latent {
981 for axis_b in 0..latent {
982 if coord_slots[axis_a] == SAE_FIXED_COORD_SLOT
983 || coord_slots[axis_b] == SAE_FIXED_COORD_SLOT
984 {
985 continue;
986 }
987 let mut d2 = [0.0; LANES];
988 let mut any = false;
989 for (lane, r) in rows.iter().enumerate() {
990 let v = r.atoms[atom].d2_phi[basis_col][axis_a][axis_b];
991 d2[lane] = v;
992 any |= v != 0.0;
993 }
994 if !any {
995 continue;
996 }
997 let mut half_d2 = [0.0; LANES];
998 for lane in 0..LANES {
999 half_d2[lane] = 0.5 * d2[lane];
1000 }
1001 let va = O2x4::<K>::variable(l_splat(0.0), coord_slots[axis_a]);
1002 let vb = O2x4::<K>::variable(l_splat(0.0), coord_slots[axis_b]);
1003 acc = acc.add(&va.mul(&vb).scale(half_d2));
1004 }
1005 }
1006 acc
1007 }
1008
1009 /// One atom's basis jet as an [`O1x4`] (value + gradient), mirroring
1010 /// `basis_tower::<Order1>` term-for-term.
1011 fn basis_tower_o1x4<const K: usize>(
1012 rows: &[&Self; LANES],
1013 atom: usize,
1014 basis_col: usize,
1015 coord_slots: &[usize],
1016 ) -> O1x4<K> {
1017 let latent = rows[0].atoms[atom].latent_dim;
1018 let mut phi0 = [0.0; LANES];
1019 for (lane, r) in rows.iter().enumerate() {
1020 phi0[lane] = r.atoms[atom].phi[basis_col];
1021 }
1022 let mut acc = O1x4::<K>::constant(phi0);
1023 for axis in 0..latent {
1024 let slot = coord_slots[axis];
1025 if slot == SAE_FIXED_COORD_SLOT {
1026 continue;
1027 }
1028 let mut d1 = [0.0; LANES];
1029 let mut any = false;
1030 for (lane, r) in rows.iter().enumerate() {
1031 let v = r.atoms[atom].d_phi[basis_col][axis];
1032 d1[lane] = v;
1033 any |= v != 0.0;
1034 }
1035 if any {
1036 acc = acc.add(&O1x4::<K>::variable(l_splat(0.0), slot).scale(d1));
1037 }
1038 }
1039 for axis_a in 0..latent {
1040 for axis_b in 0..latent {
1041 if coord_slots[axis_a] == SAE_FIXED_COORD_SLOT
1042 || coord_slots[axis_b] == SAE_FIXED_COORD_SLOT
1043 {
1044 continue;
1045 }
1046 let mut d2 = [0.0; LANES];
1047 let mut any = false;
1048 for (lane, r) in rows.iter().enumerate() {
1049 let v = r.atoms[atom].d2_phi[basis_col][axis_a][axis_b];
1050 d2[lane] = v;
1051 any |= v != 0.0;
1052 }
1053 if !any {
1054 continue;
1055 }
1056 let mut half_d2 = [0.0; LANES];
1057 for lane in 0..LANES {
1058 half_d2[lane] = 0.5 * d2[lane];
1059 }
1060 let va = O1x4::<K>::variable(l_splat(0.0), coord_slots[axis_a]);
1061 let vb = O1x4::<K>::variable(l_splat(0.0), coord_slots[axis_b]);
1062 acc = acc.add(&va.mul(&vb).scale(half_d2));
1063 }
1064 }
1065 acc
1066 }
1067
1068 /// All `out_dim` reconstruction columns for FOUR softmax-aligned rows at once,
1069 /// returned per row. Each row's column vector is BIT-IDENTICAL to
1070 /// [`Self::reconstruction_all_columns_packed`] on that row (same hoisting,
1071 /// same Leibniz products in the same order — lane `i` mirrors the scalar
1072 /// row-`i` path). Returns `None` if the four rows are not softmax-aligned, so
1073 /// the caller can fall back to the scalar per-row path.
1074 #[must_use]
1075 pub fn reconstruction_all_columns_batch4<const K: usize>(
1076 rows: [&Self; 4],
1077 ) -> Option<[Vec<Order2<K>>; 4]> {
1078 let head = rows[0];
1079 if head.n_primaries != K {
1080 return None;
1081 }
1082 let inv_tau = match head.gate {
1083 RowGate::Softmax { inv_tau } => inv_tau,
1084 RowGate::PerAtomLogistic { .. } => return None,
1085 };
1086 for r in &rows[1..] {
1087 if !head.batch_aligned_softmax_with(r) {
1088 return None;
1089 }
1090 }
1091 let p = head.out_dim();
1092 let gates: Vec<O2x4<K>> = head.all_gates_o2x4::<K>(&rows, inv_tau);
1093 let bases: Vec<Vec<O2x4<K>>> = head
1094 .atoms
1095 .iter()
1096 .enumerate()
1097 .map(|(atom, atom_jet)| {
1098 (0..atom_jet.n_basis())
1099 .map(|b| Self::basis_tower_o2x4::<K>(&rows, atom, b, &head.coord_slot[atom]))
1100 .collect()
1101 })
1102 .collect();
1103 let mut cols: [Vec<Order2<K>>; LANES] =
1104 [Vec::new(), Vec::new(), Vec::new(), Vec::new()];
1105 for c in 0..p {
1106 let mut acc = O2x4::<K>::constant(l_splat(0.0));
1107 for (atom, atom_jet) in head.atoms.iter().enumerate() {
1108 let mut decoded = O2x4::<K>::constant(l_splat(0.0));
1109 for basis_col in 0..atom_jet.n_basis() {
1110 let mut coeff = [0.0; LANES];
1111 let mut any = false;
1112 for (lane, r) in rows.iter().enumerate() {
1113 let v = r.atoms[atom].decoder[basis_col][c];
1114 coeff[lane] = v;
1115 any |= v != 0.0;
1116 }
1117 if any {
1118 decoded = decoded.add(&bases[atom][basis_col].scale(coeff));
1119 }
1120 }
1121 acc = acc.add(&gates[atom].mul(&decoded));
1122 }
1123 for lane in 0..LANES {
1124 cols[lane].push(acc.lane(lane));
1125 }
1126 }
1127 Some(cols)
1128 }
1129
1130 /// Packed β-border FIRST-order jets for a batch of `(atom, basis_col)`
1131 /// channels, for FOUR softmax-aligned rows at once, returned per row. Each
1132 /// row's channel vector is BIT-IDENTICAL to
1133 /// [`Self::beta_border_order1_packed`] on that row. Returns `None` if the
1134 /// rows are not softmax-aligned.
1135 #[must_use]
1136 pub fn beta_border_order1_batch4<const K: usize>(
1137 rows: [&Self; 4],
1138 channels: &[(usize, usize)],
1139 ) -> Option<[Vec<Order1<K>>; 4]> {
1140 let head = rows[0];
1141 if head.n_primaries != K {
1142 return None;
1143 }
1144 let inv_tau = match head.gate {
1145 RowGate::Softmax { inv_tau } => inv_tau,
1146 RowGate::PerAtomLogistic { .. } => return None,
1147 };
1148 for r in &rows[1..] {
1149 if !head.batch_aligned_softmax_with(r) {
1150 return None;
1151 }
1152 }
1153 let gates: Vec<O1x4<K>> = head.all_gates_o1x4::<K>(&rows, inv_tau);
1154 let mut out: [Vec<Order1<K>>; LANES] =
1155 [Vec::new(), Vec::new(), Vec::new(), Vec::new()];
1156 for &(atom, basis_col) in channels {
1157 let phi = Self::basis_tower_o1x4::<K>(&rows, atom, basis_col, &head.coord_slot[atom]);
1158 let s = gates[atom].mul(&phi);
1159 for lane in 0..LANES {
1160 out[lane].push(s.lane(lane));
1161 }
1162 }
1163 Some(out)
1164 }
1165}
1166
1167#[cfg(test)]
1168mod tests {
1169 use super::*;
1170
1171 /// Replicate the production hand path (`row_jets_for_logdet`) arithmetic for
1172 /// the reconstruction `first`/`second` channels of ONE output column, from
1173 /// the same atom jets and softmax gate derivatives — independent code from
1174 /// the tower. The two must agree to machine precision; this is the #932
1175 /// universal oracle for the SAE row program (the analog of the survival
1176 /// `rigid_row_kernel_agrees_with_jet_tower_program` oracle).
1177 struct HandChannels {
1178 first: Vec<f64>, // [primary]
1179 second: Vec<Vec<f64>>, // [primary][primary]
1180 value: f64,
1181 }
1182
1183 /// Softmax gate first/second derivatives wrt logit primaries, term-for-term
1184 /// the production `gate_derivatives_for_row` softmax branch.
1185 fn softmax_gate_derivs(gate: &[f64], inv_tau: f64) -> (Vec<Vec<f64>>, Vec<Vec<Vec<f64>>>) {
1186 let k = gate.len();
1187 // dz[j][kk] = ∂ζ_kk/∂ℓ_j ; d2z[j][l][kk] = ∂²ζ_kk/∂ℓ_j∂ℓ_l.
1188 let mut dz = vec![vec![0.0_f64; k]; k];
1189 let mut d2z = vec![vec![vec![0.0_f64; k]; k]; k];
1190 for j in 0..k {
1191 for kk in 0..k {
1192 let ind = if kk == j { 1.0 } else { 0.0 };
1193 dz[j][kk] = gate[kk] * (ind - gate[j]) * inv_tau;
1194 }
1195 }
1196 for j in 0..k {
1197 for l in 0..k {
1198 for kk in 0..k {
1199 let ikl = if kk == l { 1.0 } else { 0.0 };
1200 let ikj = if kk == j { 1.0 } else { 0.0 };
1201 let ijl = if j == l { 1.0 } else { 0.0 };
1202 d2z[j][l][kk] = gate[kk]
1203 * ((ikl - gate[l]) * (ikj - gate[j]) - gate[j] * (ijl - gate[l]))
1204 * inv_tau
1205 * inv_tau;
1206 }
1207 }
1208 }
1209 (dz, d2z)
1210 }
1211
1212 /// Hand-pack the reconstruction column channels exactly as the production
1213 /// `row_jets_for_logdet` does for a softmax gate: gate-logit primaries first
1214 /// (one per atom), then each atom's latent coords.
1215 fn hand_softmax_column(
1216 prog: &SaeReconstructionRowProgram,
1217 out_col: usize,
1218 inv_tau: f64,
1219 ) -> HandChannels {
1220 let k = prog.atoms.len();
1221 let n = prog.n_primaries;
1222 // decoded[k] value, d1[k][axis], d2[k][a][b] for this out_col.
1223 let decoded: Vec<f64> = (0..k)
1224 .map(|kk| {
1225 (0..prog.atoms[kk].n_basis())
1226 .map(|b| prog.atoms[kk].phi[b] * prog.atoms[kk].decoder[b][out_col])
1227 .sum()
1228 })
1229 .collect();
1230 let d1: Vec<Vec<f64>> = (0..k)
1231 .map(|kk| {
1232 (0..prog.atoms[kk].latent_dim)
1233 .map(|axis| {
1234 (0..prog.atoms[kk].n_basis())
1235 .map(|b| {
1236 prog.atoms[kk].d_phi[b][axis] * prog.atoms[kk].decoder[b][out_col]
1237 })
1238 .sum()
1239 })
1240 .collect()
1241 })
1242 .collect();
1243 let d2: Vec<Vec<Vec<f64>>> = (0..k)
1244 .map(|kk| {
1245 (0..prog.atoms[kk].latent_dim)
1246 .map(|a| {
1247 (0..prog.atoms[kk].latent_dim)
1248 .map(|b| {
1249 (0..prog.atoms[kk].n_basis())
1250 .map(|col| {
1251 prog.atoms[kk].d2_phi[col][a][b]
1252 * prog.atoms[kk].decoder[col][out_col]
1253 })
1254 .sum()
1255 })
1256 .collect()
1257 })
1258 .collect()
1259 })
1260 .collect();
1261
1262 let (dz, d2z) = softmax_gate_derivs(&prog.gate_value, inv_tau);
1263
1264 // Primary index of atom logit / coord, matching the program layout.
1265 let logit_idx = |kk: usize| prog.logit_slot[kk];
1266 let coord_idx = |kk: usize, axis: usize| prog.coord_slot[kk][axis];
1267
1268 let value: f64 = (0..k).map(|kk| prog.gate_value[kk] * decoded[kk]).sum();
1269
1270 let mut first = vec![0.0_f64; n];
1271 // Logit primaries: ∂ẑ/∂ℓ_j = Σ_kk dz[j][kk]·decoded[kk].
1272 for j in 0..k {
1273 if let Some(p) = logit_idx(j) {
1274 first[p] = (0..k).map(|kk| dz[j][kk] * decoded[kk]).sum();
1275 }
1276 }
1277 // Coord primaries: ∂ẑ/∂t_{kk,axis} = ζ_kk · d1[kk][axis].
1278 for kk in 0..k {
1279 for axis in 0..prog.atoms[kk].latent_dim {
1280 first[coord_idx(kk, axis)] = prog.gate_value[kk] * d1[kk][axis];
1281 }
1282 }
1283
1284 let mut second = vec![vec![0.0_f64; n]; n];
1285 // Logit×Logit: Σ_kk d2z[j][l][kk]·decoded[kk].
1286 for j in 0..k {
1287 for l in 0..k {
1288 if let (Some(pj), Some(pl)) = (logit_idx(j), logit_idx(l)) {
1289 second[pj][pl] = (0..k).map(|kk| d2z[j][l][kk] * decoded[kk]).sum();
1290 }
1291 }
1292 }
1293 // Logit×Coord (and symmetric): dz[j][kk]·d1[kk][axis].
1294 for j in 0..k {
1295 for kk in 0..k {
1296 for axis in 0..prog.atoms[kk].latent_dim {
1297 if let Some(pj) = logit_idx(j) {
1298 let pc = coord_idx(kk, axis);
1299 let val = dz[j][kk] * d1[kk][axis];
1300 second[pj][pc] = val;
1301 second[pc][pj] = val;
1302 }
1303 }
1304 }
1305 }
1306 // Coord×Coord same atom: ζ_kk · d2[kk][a][b].
1307 for kk in 0..k {
1308 for a in 0..prog.atoms[kk].latent_dim {
1309 for b in 0..prog.atoms[kk].latent_dim {
1310 let pa = coord_idx(kk, a);
1311 let pb = coord_idx(kk, b);
1312 second[pa][pb] = prog.gate_value[kk] * d2[kk][a][b];
1313 }
1314 }
1315 }
1316
1317 HandChannels {
1318 first,
1319 second,
1320 value,
1321 }
1322 }
1323
1324 /// Build a two-atom softmax fixture with `latent_dim = 2` per atom and a
1325 /// dense decoder so every primary is exercised. Layout: logit slots
1326 /// 0,1; atom-0 coords 2,3; atom-1 coords 4,5 → K = 6 primaries.
1327 fn softmax_fixture(inv_tau: f64) -> (SaeReconstructionRowProgram, f64) {
1328 let n_basis = 3;
1329 let out_dim = 4;
1330 let mk_atom = |seed: f64| {
1331 let phi: Vec<f64> = (0..n_basis)
1332 .map(|b| 0.3 + 0.2 * (b as f64 + seed))
1333 .collect();
1334 let d_phi: Vec<Vec<f64>> = (0..n_basis)
1335 .map(|b| {
1336 (0..2)
1337 .map(|axis| 0.1 * (b as f64 + 1.0) - 0.05 * axis as f64 + 0.03 * seed)
1338 .collect()
1339 })
1340 .collect();
1341 let d2_phi: Vec<Vec<Vec<f64>>> = (0..n_basis)
1342 .map(|b| {
1343 (0..2)
1344 .map(|a| {
1345 (0..2)
1346 .map(|bb| {
1347 // Symmetric in (a, bb).
1348 0.02 * (b as f64 + 1.0)
1349 + 0.01 * (a as f64)
1350 + 0.01 * (bb as f64)
1351 + 0.004 * seed
1352 })
1353 .collect()
1354 })
1355 .collect()
1356 })
1357 .collect();
1358 let decoder: Vec<Vec<f64>> = (0..n_basis)
1359 .map(|b| {
1360 (0..out_dim)
1361 .map(|c| 0.5 - 0.1 * (b as f64) + 0.07 * (c as f64) + 0.02 * seed)
1362 .collect()
1363 })
1364 .collect();
1365 AtomRowBasisJet {
1366 phi,
1367 d_phi,
1368 d2_phi,
1369 decoder,
1370 latent_dim: 2,
1371 }
1372 };
1373 let logits = vec![0.4_f64, -0.7];
1374 // Softmax gate values at these logits.
1375 let e: Vec<f64> = logits.iter().map(|&l| (l * inv_tau).exp()).collect();
1376 let s: f64 = e.iter().sum();
1377 let gate_value: Vec<f64> = e.iter().map(|&v| v / s).collect();
1378 let prog = SaeReconstructionRowProgram {
1379 atoms: vec![mk_atom(0.0), mk_atom(1.0)],
1380 gate_value,
1381 logits,
1382 gate_scale: vec![1.0, 1.0],
1383 gate_shift: vec![0.0, 0.0],
1384 gate: RowGate::Softmax { inv_tau },
1385 logit_slot: vec![Some(0), Some(1)],
1386 coord_slot: vec![vec![2, 3], vec![4, 5]],
1387 n_primaries: 6,
1388 };
1389 (prog, inv_tau)
1390 }
1391
1392 /// Parametrized softmax fixture with `n_atoms` softmax atoms, each carrying a
1393 /// free logit primary and `latent_dim` free coord primaries, so
1394 /// `n_primaries = n_atoms·(1 + latent_dim)`. Layout: logit slots
1395 /// `0..n_atoms`, then atom `k`'s coord axis `j` at `n_atoms + k·latent_dim +
1396 /// j`. Used by the #932 ns/row microbench to instantiate the tower at
1397 /// `K = n_primaries` for `K ∈ {8, 16}` (the softmax gate Hessian is `n_atoms³`,
1398 /// the cost driver the hand path pays per output column).
1399 fn softmax_fixture_k(
1400 n_atoms: usize,
1401 latent_dim: usize,
1402 n_basis: usize,
1403 out_dim: usize,
1404 inv_tau: f64,
1405 ) -> SaeReconstructionRowProgram {
1406 let mk_atom = |seed: f64| {
1407 let phi: Vec<f64> = (0..n_basis)
1408 .map(|b| 0.3 + 0.2 * (b as f64 + seed))
1409 .collect();
1410 let d_phi: Vec<Vec<f64>> = (0..n_basis)
1411 .map(|b| {
1412 (0..latent_dim)
1413 .map(|axis| 0.1 * (b as f64 + 1.0) - 0.05 * axis as f64 + 0.03 * seed)
1414 .collect()
1415 })
1416 .collect();
1417 let d2_phi: Vec<Vec<Vec<f64>>> = (0..n_basis)
1418 .map(|b| {
1419 (0..latent_dim)
1420 .map(|a| {
1421 (0..latent_dim)
1422 .map(|bb| {
1423 0.02 * (b as f64 + 1.0)
1424 + 0.01 * (a as f64)
1425 + 0.01 * (bb as f64)
1426 + 0.004 * seed
1427 })
1428 .collect()
1429 })
1430 .collect()
1431 })
1432 .collect();
1433 let decoder: Vec<Vec<f64>> = (0..n_basis)
1434 .map(|b| {
1435 (0..out_dim)
1436 .map(|c| 0.5 - 0.1 * (b as f64) + 0.07 * (c as f64) + 0.02 * seed)
1437 .collect()
1438 })
1439 .collect();
1440 AtomRowBasisJet {
1441 phi,
1442 d_phi,
1443 d2_phi,
1444 decoder,
1445 latent_dim,
1446 }
1447 };
1448 let logits: Vec<f64> = (0..n_atoms)
1449 .map(|k| 0.4 - 0.13 * k as f64 + 0.05 * (k as f64).sin())
1450 .collect();
1451 let e: Vec<f64> = logits.iter().map(|&l| (l * inv_tau).exp()).collect();
1452 let s: f64 = e.iter().sum();
1453 let gate_value: Vec<f64> = e.iter().map(|&v| v / s).collect();
1454 let atoms: Vec<AtomRowBasisJet> = (0..n_atoms).map(|k| mk_atom(k as f64)).collect();
1455 let logit_slot: Vec<Option<usize>> = (0..n_atoms).map(Some).collect();
1456 let coord_slot: Vec<Vec<usize>> = (0..n_atoms)
1457 .map(|k| (0..latent_dim).map(|j| n_atoms + k * latent_dim + j).collect())
1458 .collect();
1459 SaeReconstructionRowProgram {
1460 atoms,
1461 gate_value,
1462 logits,
1463 gate_scale: vec![1.0; n_atoms],
1464 gate_shift: vec![0.0; n_atoms],
1465 gate: RowGate::Softmax { inv_tau },
1466 logit_slot,
1467 coord_slot,
1468 n_primaries: n_atoms * (1 + latent_dim),
1469 }
1470 }
1471
1472 /// #932 correctness gate: the production packed jet recon
1473 /// ([`SaeReconstructionRowProgram::reconstruction_all_columns_packed`], gate +
1474 /// basis jets HOISTED out of the column loop, softmax denom/recip SHARED) and
1475 /// the per-column packed call must each reproduce the hand path
1476 /// ([`hand_softmax_column`], the old `row_jets_for_logdet` closed-form softmax
1477 /// gate Jacobian/Hessian × decoded basis, re-derived per output column) on
1478 /// value/grad/Hessian — the #932 bit-identity bar. (The ns/row timing
1479 /// comparison this gate used to precede lives in `bench/`, not in a `#[test]`:
1480 /// `#[ignore]`d timing benches are banned by `build.rs`.)
1481 #[test]
1482 fn recon_jet_matches_hand_path_value_grad_hess() {
1483 let out_dim = 16;
1484 let n_basis = 4;
1485 let inv_tau = 1.3;
1486 // K=8: 4 atoms × (1 logit + 1 coord) = 8 primaries.
1487 check_recon_vs_hand::<8>(softmax_fixture_k(4, 1, n_basis, out_dim, inv_tau), inv_tau);
1488 // K=16: 8 atoms × (1 logit + 1 coord) = 16 primaries.
1489 check_recon_vs_hand::<16>(softmax_fixture_k(8, 1, n_basis, out_dim, inv_tau), inv_tau);
1490 }
1491
1492 fn check_recon_vs_hand<const K: usize>(prog: SaeReconstructionRowProgram, inv_tau: f64) {
1493 let out_dim = prog.out_dim();
1494 let cols = prog.reconstruction_all_columns_packed::<K>();
1495 for c in 0..out_dim {
1496 let hand = hand_softmax_column(&prog, c, inv_tau);
1497 let h_floor = hand
1498 .second
1499 .iter()
1500 .flatten()
1501 .fold(0.0_f64, |m, x| m.max(x.abs()));
1502 // The all-columns (hoisted) path matches hand value + Hessian.
1503 assert!((cols[c].value() - hand.value).abs() <= 1e-9 * hand.value.abs().max(1.0));
1504 // The per-column path matches the all-columns path (same kernel, no hoist).
1505 let percol = prog.reconstruction_column_packed::<K>(c);
1506 assert!((percol.value() - cols[c].value()).abs() <= 1e-12 * cols[c].value().abs().max(1.0));
1507 for a in 0..K {
1508 for b in 0..K {
1509 assert!(
1510 (cols[c].h()[a][b] - hand.second[a][b]).abs()
1511 <= 1e-8 * h_floor.max(1e-12)
1512 );
1513 assert!(
1514 (percol.h()[a][b] - cols[c].h()[a][b]).abs()
1515 <= 1e-12 * h_floor.max(1e-12)
1516 );
1517 }
1518 }
1519 }
1520 }
1521
1522 /// INDEPENDENT scalar witness for the reconstruction column `ẑ_c(δ)` as a
1523 /// function of the primary-increment vector `δ` (the displacement of each
1524 /// tower primary from its seed value: a coord primary seeds at value 0, a
1525 /// logit primary at its current logit, so `δ` is the same offset the tower's
1526 /// seeded variables carry). This evaluator touches NONE of the `Tower4`
1527 /// arithmetic — no Leibniz product, no Faà di Bruno compose, no
1528 /// `basis_tower`/`decoded_tower`/`gate_tower` — it re-derives the closed-form
1529 /// reconstruction from the raw jet tensors and the softmax definition. It is
1530 /// the witness the t3/t4 FD oracle differences below.
1531 ///
1532 /// `ẑ_c(δ) = Σ_k softmax_k((ℓ + δ_logit)·inv_tau) · Σ_b Φ̃_{k,b}(δ_coord)·B_{k,b,c}`
1533 /// with the SAME local quadratic basis model the program consumes:
1534 /// `Φ̃_b(u) = phi[b] + Σ_a d_phi[b][a]·u_a + ½ Σ_{a,a'} d2_phi[b][a][a']·u_a·u_{a'}`.
1535 fn recon_scalar_softmax(
1536 prog: &SaeReconstructionRowProgram,
1537 out_col: usize,
1538 inv_tau: f64,
1539 delta: &[f64],
1540 ) -> f64 {
1541 let k = prog.atoms.len();
1542 // Softmax over (logit + δ_logit) for atoms with a free logit primary;
1543 // atoms without one keep their base logit (no δ).
1544 let exps: Vec<f64> = (0..k)
1545 .map(|kk| {
1546 let dl = match prog.logit_slot[kk] {
1547 Some(slot) => delta[slot],
1548 None => 0.0,
1549 };
1550 ((prog.logits[kk] + dl) * inv_tau).exp()
1551 })
1552 .collect();
1553 let denom: f64 = exps.iter().sum();
1554 let mut acc = 0.0;
1555 for kk in 0..k {
1556 let gate = exps[kk] / denom;
1557 let atom = &prog.atoms[kk];
1558 // decoded_{kk,c}(δ_coord) via the local quadratic basis model.
1559 let mut decoded = 0.0;
1560 for b in 0..atom.n_basis() {
1561 let mut phi = atom.phi[b];
1562 for a in 0..atom.latent_dim {
1563 let ua = delta[prog.coord_slot[kk][a]];
1564 phi += atom.d_phi[b][a] * ua;
1565 }
1566 for a in 0..atom.latent_dim {
1567 let ua = delta[prog.coord_slot[kk][a]];
1568 for a2 in 0..atom.latent_dim {
1569 let ub = delta[prog.coord_slot[kk][a2]];
1570 phi += 0.5 * atom.d2_phi[b][a][a2] * ua * ub;
1571 }
1572 }
1573 decoded += phi * atom.decoder[b][out_col];
1574 }
1575 acc += gate * decoded;
1576 }
1577 acc
1578 }
1579
1580 /// Fourth-order central FD of `recon_scalar_softmax` along axes (a,b,c,d) at
1581 /// the origin (δ = 0, the tower seed point). Uses the standard mixed
1582 /// fourth-difference stencil with sign vector ±h on each of the four axes
1583 /// (axes may coincide). 2⁴ = 16 evaluations.
1584 fn fd_fourth(
1585 prog: &SaeReconstructionRowProgram,
1586 out_col: usize,
1587 inv_tau: f64,
1588 axes: [usize; 4],
1589 h: f64,
1590 ) -> f64 {
1591 let n = prog.n_primaries;
1592 let mut acc = 0.0;
1593 for mask in 0..16u32 {
1594 let mut delta = vec![0.0_f64; n];
1595 let mut sign = 1.0;
1596 for (slot, &ax) in axes.iter().enumerate() {
1597 if (mask >> slot) & 1 == 1 {
1598 delta[ax] += h;
1599 } else {
1600 delta[ax] -= h;
1601 sign = -sign;
1602 }
1603 }
1604 acc += sign * recon_scalar_softmax(prog, out_col, inv_tau, &delta);
1605 }
1606 acc / (16.0 * h * h * h * h)
1607 }
1608
1609 /// Third-order central FD of `recon_scalar_softmax` along axes (a,b,c) at the
1610 /// origin: 2³ = 8 evaluations with the mixed third-difference stencil.
1611 fn fd_third(
1612 prog: &SaeReconstructionRowProgram,
1613 out_col: usize,
1614 inv_tau: f64,
1615 axes: [usize; 3],
1616 h: f64,
1617 ) -> f64 {
1618 let n = prog.n_primaries;
1619 let mut acc = 0.0;
1620 for mask in 0..8u32 {
1621 let mut delta = vec![0.0_f64; n];
1622 let mut sign = 1.0;
1623 for (slot, &ax) in axes.iter().enumerate() {
1624 if (mask >> slot) & 1 == 1 {
1625 delta[ax] += h;
1626 } else {
1627 delta[ax] -= h;
1628 sign = -sign;
1629 }
1630 }
1631 acc += sign * recon_scalar_softmax(prog, out_col, inv_tau, &delta);
1632 }
1633 acc / (8.0 * h * h * h)
1634 }
1635
1636 /// The #932 follow-up the issue flagged as missing: the SAE reconstruction
1637 /// program's THIRD- and FOURTH-order channels (`t3`/`t4`) validated against an
1638 /// INDEPENDENT witness (`recon_scalar_softmax`, finite-differenced), not just
1639 /// the value/first/second channels the hand-path oracle covers. Both the
1640 /// witness and the differencing are independent of the `Tower4` Leibniz /
1641 /// Faà-di-Bruno arithmetic that produces `t3`/`t4`, so agreement is a real
1642 /// cross-check of those higher-order channels — the analog of the survival
1643 /// kernel's `row_third_contracted` oracle, extended to fourth order.
1644 #[test]
1645 fn softmax_reconstruction_t3_t4_match_independent_fd_witness() {
1646 let (prog, inv_tau) = softmax_fixture(1.1);
1647 // Mixed fifth-derivative magnitude bounds the central-FD truncation; a
1648 // moderate step keeps both truncation and roundoff well under tol.
1649 let h3 = 2e-3;
1650 let h4 = 1e-2;
1651 for out_col in 0..prog.out_dim() {
1652 let tower = prog.reconstruction_column::<6>(out_col);
1653
1654 let t3_floor = tower
1655 .t3
1656 .iter()
1657 .flatten()
1658 .flatten()
1659 .fold(0.0_f64, |m, x| m.max(x.abs()))
1660 .max(1e-9);
1661 let t4_floor = tower
1662 .t4
1663 .iter()
1664 .flatten()
1665 .flatten()
1666 .flatten()
1667 .fold(0.0_f64, |m, x| m.max(x.abs()))
1668 .max(1e-9);
1669
1670 for a in 0..6 {
1671 for b in 0..6 {
1672 for c in 0..6 {
1673 let fd = fd_third(&prog, out_col, inv_tau, [a, b, c], h3);
1674 assert!(
1675 (tower.t3[a][b][c] - fd).abs() <= 5e-5 * t3_floor,
1676 "col {out_col} t3[{a}][{b}][{c}]: tower {:+.10e} vs fd {:+.10e}",
1677 tower.t3[a][b][c],
1678 fd
1679 );
1680 for d in 0..6 {
1681 let fd4 = fd_fourth(&prog, out_col, inv_tau, [a, b, c, d], h4);
1682 assert!(
1683 (tower.t4[a][b][c][d] - fd4).abs() <= 5e-4 * t4_floor,
1684 "col {out_col} t4[{a}][{b}][{c}][{d}]: tower {:+.10e} vs fd {:+.10e}",
1685 tower.t4[a][b][c][d],
1686 fd4
1687 );
1688 }
1689 }
1690 }
1691 }
1692 }
1693 }
1694
1695 /// A planted #736-style corruption in a t3 OR t4 channel is caught by the
1696 /// independent FD witness (loud at introduction). We perturb a copy of the
1697 /// tower's higher-order channel and assert the witness disagrees.
1698 #[test]
1699 fn planted_t3_t4_corruption_is_caught_by_fd_witness() {
1700 let (prog, inv_tau) = softmax_fixture(1.1);
1701 let out_col = 2;
1702 let tower = prog.reconstruction_column::<6>(out_col);
1703 // A real logit×coord×coord third block (atom-0 logit slot 0, atom-0
1704 // coords 2,3): the witness's third FD must match it...
1705 let axes3 = [0usize, 2, 3];
1706 let fd3 = fd_third(&prog, out_col, inv_tau, axes3, 2e-3);
1707 let t3_floor = tower
1708 .t3
1709 .iter()
1710 .flatten()
1711 .flatten()
1712 .fold(0.0_f64, |m, x| m.max(x.abs()))
1713 .max(1e-9);
1714 assert!(
1715 (tower.t3[0][2][3] - fd3).abs() <= 5e-5 * t3_floor,
1716 "honest t3 must match witness"
1717 );
1718 // ...and a sign-flipped copy must NOT.
1719 let corrupt = -tower.t3[0][2][3];
1720 assert!(
1721 (corrupt - fd3).abs() > 5e-5 * t3_floor,
1722 "a sign-flipped t3 block must disagree with the FD witness"
1723 );
1724
1725 let axes4 = [0usize, 0, 2, 3];
1726 let fd4 = fd_fourth(&prog, out_col, inv_tau, axes4, 1e-2);
1727 let t4_floor = tower
1728 .t4
1729 .iter()
1730 .flatten()
1731 .flatten()
1732 .flatten()
1733 .fold(0.0_f64, |m, x| m.max(x.abs()))
1734 .max(1e-9);
1735 let corrupt4 = tower.t4[0][0][2][3] + 10.0 * t4_floor;
1736 assert!(
1737 (corrupt4 - fd4).abs() > 5e-4 * t4_floor,
1738 "a corrupted t4 block must disagree with the FD witness"
1739 );
1740 }
1741
/// Cross-checks the dense `Tower4` reconstruction against the hand-packed
/// softmax channels from `hand_softmax_column`, column by column. The value must
/// agree to 1e-9 relative (floored at 1.0). Gradient and Hessian entries must
/// agree to 1e-9 and 1e-8, respectively, of the column's largest tower entry in
/// that channel.
#[test]
fn softmax_reconstruction_tower_matches_hand_channels_all_columns() {
    let (prog, inv_tau) = softmax_fixture(1.3);
    let abs_max = |m: f64, x: &f64| m.max(x.abs());
    for out_col in 0..prog.out_dim() {
        let tower = prog.reconstruction_column::<6>(out_col);
        let hand = hand_softmax_column(&prog, out_col, inv_tau);

        // Tolerances scale with the channel's largest magnitude so entries
        // that are structurally zero are not held to absolute equality
        // (the verify_kernel_channels convention).
        let g_tol = 1e-9 * tower.g.iter().fold(0.0_f64, abs_max).max(1e-12);
        let h_tol = 1e-8 * tower.h.iter().flatten().fold(0.0_f64, abs_max).max(1e-12);

        let (tv, hv) = (tower.v, hand.value);
        assert!(
            (tv - hv).abs() <= 1e-9 * hv.abs().max(1.0),
            "col {out_col} value: tower {} vs hand {}",
            tv,
            hv
        );
        for a in 0..6 {
            let (tg, hg) = (tower.g[a], hand.first[a]);
            assert!(
                (tg - hg).abs() <= g_tol,
                "col {out_col} first[{a}]: tower {} vs hand {}",
                tg,
                hg
            );
            for b in 0..6 {
                let (th, hh) = (tower.h[a][b], hand.second[a][b]);
                assert!(
                    (th - hh).abs() <= h_tol,
                    "col {out_col} second[{a}][{b}]: tower {} vs hand {}",
                    th,
                    hh
                );
            }
        }
    }
}
1782
/// A planted sign flip in the hand cross-block (logit×coord) is caught by the
/// oracle — the same failure that #736 was, made loud at introduction.
///
/// The honest hand entries are first checked against the tower at the same
/// tolerance the all-columns test uses. That way the later disagreement can be
/// attributed to the flip, not to a pre-existing mismatch. Both mirrored entries
/// `[0][4]` and `[4][0]` are then flipped, and each must disagree on its own.
#[test]
fn planted_cross_block_sign_flip_is_caught() {
    let (prog, inv_tau) = softmax_fixture(1.3);
    let out_col = 1;
    let tower = prog.reconstruction_column::<6>(out_col);
    let mut hand = hand_softmax_column(&prog, out_col, inv_tau);
    let h_floor = tower
        .h
        .iter()
        .flatten()
        .fold(0.0_f64, |m, x| m.max(x.abs()));
    let tol = 1e-8 * h_floor.max(1e-12);

    // Atom-0 logit slot 0 × atom-1 coord slot 4, and its symmetric mirror.
    let cross = [(0usize, 4usize), (4, 0)];

    // Baseline: the uncorrupted hand channel agrees with the tower.
    for &(a, b) in &cross {
        assert!(
            (tower.h[a][b] - hand.second[a][b]).abs() <= tol,
            "honest cross block second[{a}][{b}] must match the tower truth"
        );
    }

    // Corrupt both mirrored entries: flip the sign, the #736 disease.
    for &(a, b) in &cross {
        hand.second[a][b] = -hand.second[a][b];
    }

    // Each flipped entry must now disagree with the tower.
    for &(a, b) in &cross {
        assert!(
            (tower.h[a][b] - hand.second[a][b]).abs() > tol,
            "a flipped cross block must disagree with the tower truth (second[{a}][{b}])"
        );
    }
}
1806
/// Isolates the gate nonlinearity from the basis/decoder. For each atom, the
/// `Tower4` softmax gate jet must reproduce the stored gate value. It must also
/// reproduce the first and second logit derivatives produced by
/// `softmax_gate_derivs`. A regression in this test therefore points at the gate,
/// not at the basis or decoder.
#[test]
fn softmax_gate_tower_matches_hand_gate_derivatives() {
    let (prog, inv_tau) = softmax_fixture(0.9);
    let (dz, d2z) = softmax_gate_derivs(&prog.gate_value, inv_tau);
    let n_atoms = prog.atoms.len();
    // Primary-slot index of each atom's logit (every atom has one here).
    let slots: Vec<usize> = (0..n_atoms)
        .map(|j| prog.logit_slot[j].unwrap())
        .collect();

    for atom in 0..n_atoms {
        let gate = prog.gate_tower::<6, Tower4<6>>(atom);

        // Value channel: ζ_atom.
        assert!((gate.v - prog.gate_value[atom]).abs() < 1e-12);

        // First order: ∂ζ_atom/∂ℓ_j against dz[j][atom].
        for (j, &sj) in slots.iter().enumerate() {
            let (tower_d1, hand_d1) = (gate.g[sj], dz[j][atom]);
            assert!(
                (tower_d1 - hand_d1).abs() < 1e-9,
                "gate {atom} d/dlogit {j}: tower {} vs hand {}",
                tower_d1,
                hand_d1
            );
        }

        // Second order: ∂²ζ_atom/∂ℓ_j∂ℓ_l against d2z[j][l][atom].
        for (j, &sj) in slots.iter().enumerate() {
            for (l, &sl) in slots.iter().enumerate() {
                let (tower_d2, hand_d2) = (gate.h[sj][sl], d2z[j][l][atom]);
                assert!(
                    (tower_d2 - hand_d2).abs() < 1e-8,
                    "gate {atom} d2/dlogit {j}{l}: tower {} vs hand {}",
                    tower_d2,
                    hand_d2
                );
            }
        }
    }
}
1843
/// The per-atom logistic gate (IBP/JumpReLU branch) reproduces the closed-form
/// jet of `σ(x)` with `x = (ℓ − shift)·inv_tau` through third order:
/// `σ' = σ(1−σ)·inv_tau`, `σ'' = σ(1−σ)(1−2σ)·inv_tau²`,
/// `σ''' = σ(1−σ)(1−6σ+6σ²)·inv_tau³`.
///
/// The gate is a function of its logit alone. Every gradient/Hessian entry
/// involving the latent-coordinate primary (slot 1) must therefore vanish.
#[test]
fn per_atom_logistic_gate_matches_closed_form() {
    let inv_tau = 1.4;
    let logit = 0.6;
    let shift = 0.2;
    let x: f64 = (logit - shift) * inv_tau;
    let sigma = 1.0 / (1.0 + (-x).exp());
    let prog = SaeReconstructionRowProgram {
        atoms: vec![AtomRowBasisJet {
            phi: vec![1.0],
            d_phi: vec![vec![0.0]],
            d2_phi: vec![vec![vec![0.0]]],
            decoder: vec![vec![1.0]],
            latent_dim: 1,
        }],
        gate_value: vec![sigma],
        logits: vec![logit],
        gate_scale: vec![1.0],
        gate_shift: vec![shift],
        gate: RowGate::PerAtomLogistic { inv_tau },
        logit_slot: vec![Some(0)],
        coord_slot: vec![vec![1]],
        n_primaries: 2,
    };
    let gate = prog.gate_tower::<2, Tower4<2>>(0);
    assert!((gate.v - sigma).abs() < 1e-12);

    // Closed-form logit derivatives; the chain rule contributes one inv_tau
    // per differentiation.
    let d1 = sigma * (1.0 - sigma) * inv_tau;
    let d2 = sigma * (1.0 - sigma) * (1.0 - 2.0 * sigma) * inv_tau * inv_tau;
    let d3 = sigma
        * (1.0 - sigma)
        * (1.0 - 6.0 * sigma + 6.0 * sigma * sigma)
        * inv_tau
        * inv_tau
        * inv_tau;
    assert!((gate.g[0] - d1).abs() < 1e-9, "σ': {} vs {}", gate.g[0], d1);
    assert!(
        (gate.h[0][0] - d2).abs() < 1e-9,
        "σ'': {} vs {}",
        gate.h[0][0],
        d2
    );
    assert!(
        (gate.t3[0][0][0] - d3).abs() < 1e-9,
        "σ''': {} vs {}",
        gate.t3[0][0][0],
        d3
    );

    // The gate carries no dependence on the latent coordinate (slot 1).
    assert!(
        gate.g[1].abs() < 1e-15,
        "∂σ/∂t must vanish: {}",
        gate.g[1]
    );
    for (a, b) in [(0usize, 1usize), (1, 0), (1, 1)] {
        assert!(
            gate.h[a][b].abs() < 1e-15,
            "h[{a}][{b}] must vanish: {}",
            gate.h[a][b]
        );
    }
}
1882
/// #932 cutover pin: the PRODUCTION packed [`Order2`] reconstruction path
/// (`reconstruction_column_packed`) agrees with the dense [`Tower4`] oracle
/// (`reconstruction_column`) on the value/gradient/Hessian channels, which are
/// the channels the arrow-Schur logdet consumer reads. This is checked for
/// every output column at several temperatures. The Order2 path never
/// materialises `t3`/`t4`. Its `(v, g, H)` must match the dense tower's
/// order-≤2 channels within the band `1e-12 + 1e-12·|x|` (an absolute floor
/// plus a term relative to the tower entry `x`), since they share the
/// `Tower2` arithmetic. The cutover therefore changes only cost, not result.
#[test]
fn order2_reconstruction_matches_tower_value_grad_hessian() {
    // Loop-invariant tolerance: absolute 1e-12 floor plus 1e-12 relative.
    let band = |x: f64| 1e-12 + 1e-12 * x.abs();
    for tau in [0.9_f64, 1.3, 2.1] {
        let (prog, _inv_tau) = softmax_fixture(tau);
        for out_col in 0..prog.out_dim() {
            let packed = prog.reconstruction_column_packed::<6>(out_col);
            let tower = prog.reconstruction_column::<6>(out_col);
            let (g, h) = (packed.g(), packed.h());
            assert!(
                (packed.value() - tower.v).abs() <= band(tower.v),
                "col {out_col} value: order2 {} vs tower {}",
                packed.value(),
                tower.v
            );
            for a in 0..6 {
                assert!(
                    (g[a] - tower.g[a]).abs() <= band(tower.g[a]),
                    "col {out_col} g[{a}]: order2 {} vs tower {}",
                    g[a],
                    tower.g[a]
                );
                for b in 0..6 {
                    assert!(
                        (h[a][b] - tower.h[a][b]).abs() <= band(tower.h[a][b]),
                        "col {out_col} h[{a}][{b}]: order2 {} vs tower {}",
                        h[a][b],
                        tower.h[a][b]
                    );
                }
            }
        }
    }
}
1926
/// #932 cutover pin for the β border channel. For every (atom, basis column)
/// pair, the packed [`Order2`] `beta_border_tower_packed` must agree with the
/// dense [`Tower4`] `beta_border_tower` on the channels the consumer reads:
/// the value (`beta`) and the gradient (`beta_deriv` / `beta_l_deriv`). The
/// tolerance is `1e-12 + 1e-12·|x|` relative to the tower entry `x`.
#[test]
fn order2_beta_border_matches_tower_value_grad() {
    let (prog, _inv_tau) = softmax_fixture(1.1);
    let within_band = |got: f64, want: f64| (got - want).abs() <= 1e-12 + 1e-12 * want.abs();
    for (atom, atom_jet) in prog.atoms.iter().enumerate() {
        for basis_col in 0..atom_jet.n_basis() {
            let packed = prog.beta_border_tower_packed::<6>(atom, basis_col);
            let tower = prog.beta_border_tower::<6>(atom, basis_col);
            let packed_v = packed.value();
            assert!(
                within_band(packed_v, tower.v),
                "atom {atom} b {basis_col} value: order2 {} vs tower {}",
                packed_v,
                tower.v
            );
            let g = packed.g();
            for a in 0..6 {
                let (got, want) = (g[a], tower.g[a]);
                assert!(
                    within_band(got, want),
                    "atom {atom} b {basis_col} g[{a}]: order2 {} vs tower {}",
                    got,
                    want
                );
            }
        }
    }
}
1957
/// #932 perf pin: the gate-shared `all_gates` produces gate jets
/// BIT-IDENTICAL to the per-atom `gate_tower`. Sharing the softmax denominator
/// and reciprocal across atoms (K exps + 1 recip instead of K² + K) changes
/// only which redundant work is elided, not the result: `ζ_k = exp_k ·
/// recip(denom)` is the same product, in the same Leibniz order.
///
/// Channels are compared with `f64::to_bits`, matching the batch4 oracles.
/// Float `==` would accept a `+0.0`/`-0.0` divergence and reject identical
/// NaNs, so it does not actually pin bit identity.
#[test]
fn shared_all_gates_bit_identical_to_per_atom_gate_tower() {
    for tau in [0.9_f64, 1.3, 2.1] {
        let (prog, _inv_tau) = softmax_fixture(tau);
        let all = prog.all_gates::<6, Order2<6>>();
        assert_eq!(all.len(), prog.gate_value.len());
        for (atom, shared) in all.iter().enumerate() {
            let per = prog.gate_tower::<6, Order2<6>>(atom);
            assert_eq!(
                shared.value().to_bits(),
                per.value().to_bits(),
                "atom {atom} value"
            );
            let (sg, pg) = (shared.g(), per.g());
            let (sh, ph) = (shared.h(), per.h());
            for a in 0..6 {
                assert_eq!(sg[a].to_bits(), pg[a].to_bits(), "atom {atom} g[{a}]");
                for b in 0..6 {
                    assert_eq!(
                        sh[a][b].to_bits(),
                        ph[a][b].to_bits(),
                        "atom {atom} h[{a}][{b}]"
                    );
                }
            }
        }
    }
}
1985
/// #932 perf pin: the gate/basis-HOISTED + denominator-SHARED all-columns
/// reconstruction (`reconstruction_all_columns_packed`) is BIT-IDENTICAL to
/// calling `reconstruction_column_packed(c)` per column. The hoist + share
/// removes only redundant gate/basis/denominator recomputation, not any
/// arithmetic, because the Leibniz products are the same in the same order.
///
/// Every value/grad/Hessian channel is compared with `f64::to_bits`, matching
/// the batch4 oracles. Float `==` would accept a `+0.0`/`-0.0` divergence and
/// reject identical NaNs, so it does not actually pin bit identity.
#[test]
fn hoisted_all_columns_bit_identical_to_per_column() {
    for tau in [0.9_f64, 1.3, 2.1] {
        let (prog, _inv_tau) = softmax_fixture(tau);
        let all = prog.reconstruction_all_columns_packed::<6>();
        assert_eq!(all.len(), prog.out_dim());
        for (out_col, hoisted) in all.iter().enumerate() {
            let per = prog.reconstruction_column_packed::<6>(out_col);
            assert_eq!(
                hoisted.value().to_bits(),
                per.value().to_bits(),
                "col {out_col} value"
            );
            let (hg, pg) = (hoisted.g(), per.g());
            let (hh, ph) = (hoisted.h(), per.h());
            for a in 0..6 {
                assert_eq!(hg[a].to_bits(), pg[a].to_bits(), "col {out_col} g[{a}]");
                for b in 0..6 {
                    assert_eq!(
                        hh[a][b].to_bits(),
                        ph[a][b].to_bits(),
                        "col {out_col} h[{a}][{b}]"
                    );
                }
            }
        }
    }
}
2011
2012 /// Build four softmax-aligned row programs that differ ONLY in their per-row
2013 /// numeric data (logits, basis values, decoder), keeping the layout
2014 /// (slots / dims / temperature) identical so they are 4-row SIMD-batchable.
2015 fn softmax_batch_fixture(inv_tau: f64) -> [SaeReconstructionRowProgram; LANES] {
2016 let n_basis = 3;
2017 let out_dim = 4;
2018 let mk = |row_seed: f64| {
2019 let mk_atom = |seed: f64| {
2020 let phi: Vec<f64> = (0..n_basis)
2021 .map(|b| 0.3 + 0.2 * (b as f64 + seed) + 0.11 * row_seed)
2022 .collect();
2023 let d_phi: Vec<Vec<f64>> = (0..n_basis)
2024 .map(|b| {
2025 (0..2)
2026 .map(|axis| {
2027 0.1 * (b as f64 + 1.0) - 0.05 * axis as f64 + 0.03 * seed
2028 + 0.017 * row_seed
2029 })
2030 .collect()
2031 })
2032 .collect();
2033 let d2_phi: Vec<Vec<Vec<f64>>> = (0..n_basis)
2034 .map(|b| {
2035 (0..2)
2036 .map(|a| {
2037 (0..2)
2038 .map(|bb| {
2039 0.02 * (b as f64 + 1.0)
2040 + 0.01 * (a as f64)
2041 + 0.01 * (bb as f64)
2042 + 0.004 * seed
2043 + 0.003 * row_seed
2044 })
2045 .collect()
2046 })
2047 .collect()
2048 })
2049 .collect();
2050 let decoder: Vec<Vec<f64>> = (0..n_basis)
2051 .map(|b| {
2052 (0..out_dim)
2053 .map(|c| {
2054 0.5 - 0.1 * (b as f64) + 0.07 * (c as f64) + 0.02 * seed
2055 + 0.009 * row_seed
2056 })
2057 .collect()
2058 })
2059 .collect();
2060 AtomRowBasisJet {
2061 phi,
2062 d_phi,
2063 d2_phi,
2064 decoder,
2065 latent_dim: 2,
2066 }
2067 };
2068 let logits = vec![0.4 + 0.21 * row_seed, -0.7 + 0.13 * row_seed];
2069 let e: Vec<f64> = logits.iter().map(|&l| (l * inv_tau).exp()).collect();
2070 let s: f64 = e.iter().sum();
2071 let gate_value: Vec<f64> = e.iter().map(|&v| v / s).collect();
2072 SaeReconstructionRowProgram {
2073 atoms: vec![mk_atom(0.0), mk_atom(1.0)],
2074 gate_value,
2075 logits,
2076 gate_scale: vec![1.0, 1.0],
2077 gate_shift: vec![0.0, 0.0],
2078 gate: RowGate::Softmax { inv_tau },
2079 logit_slot: vec![Some(0), Some(1)],
2080 coord_slot: vec![vec![2, 3], vec![4, 5]],
2081 n_primaries: 6,
2082 }
2083 };
2084 [mk(0.0), mk(1.0), mk(2.0), mk(3.0)]
2085 }
2086
/// SIMD-batch bit-identity oracle for the reconstruction. The test sweeps seven
/// temperatures and all four lanes of seeded per-row data. Lane `i` of
/// `reconstruction_all_columns_batch4` must equal the scalar
/// `reconstruction_all_columns_packed` on row `i`, bit-for-bit (`to_bits`), on
/// every value/gradient/Hessian channel of every output column. The 4-row pass
/// changes only how many rows share one instruction stream, never the
/// arithmetic. At least 2000 Hessian-entry comparisons must run, so the oracle
/// cannot pass vacuously.
#[test]
fn batch4_reconstruction_bit_identical_to_per_row() {
    const TAUS: [f64; 7] = [0.7, 0.9, 1.1, 1.3, 1.7, 2.1, 2.9];
    let mut comparisons = 0usize;
    for &tau in &TAUS {
        let rows = softmax_batch_fixture(tau);
        let batch = SaeReconstructionRowProgram::reconstruction_all_columns_batch4::<6>([
            &rows[0], &rows[1], &rows[2], &rows[3],
        ])
        .expect("softmax-aligned rows must batch");
        for (lane, row) in rows.iter().enumerate() {
            let per = row.reconstruction_all_columns_packed::<6>();
            let lane_batch = &batch[lane];
            assert_eq!(per.len(), lane_batch.len());
            for (c, (b, p)) in lane_batch.iter().zip(per.iter()).enumerate() {
                assert_eq!(
                    b.value().to_bits(),
                    p.value().to_bits(),
                    "tau {tau} lane {lane} col {c} value"
                );
                let (bg, pg) = (b.g(), p.g());
                let (bh, ph) = (b.h(), p.h());
                for a in 0..6 {
                    assert_eq!(bg[a].to_bits(), pg[a].to_bits(), "lane {lane} col {c} g[{a}]");
                    for d in 0..6 {
                        assert_eq!(
                            bh[a][d].to_bits(),
                            ph[a][d].to_bits(),
                            "lane {lane} col {c} h[{a}][{d}]"
                        );
                        comparisons += 1;
                    }
                }
            }
        }
    }
    assert!(comparisons >= 2000, "oracle ran {comparisons} comparisons");
}
2127
/// SIMD-batch bit-identity oracle for the first-order β-border path. For each
/// temperature and lane, `beta_border_order1_batch4` must equal
/// `beta_border_order1_packed` on that lane's row, bit-for-bit (`to_bits`). The
/// check covers every (atom, basis column) channel plus one repeated channel
/// that exercises gate-cache reuse.
#[test]
fn batch4_beta_border_bit_identical_to_per_row() {
    const TAUS: [f64; 7] = [0.7, 0.9, 1.1, 1.3, 1.7, 2.1, 2.9];
    let mut comparisons = 0usize;
    for &tau in &TAUS {
        let rows = softmax_batch_fixture(tau);
        let mut chans: Vec<(usize, usize)> = rows[0]
            .atoms
            .iter()
            .enumerate()
            .flat_map(|(atom, jet)| (0..jet.n_basis()).map(move |b| (atom, b)))
            .collect();
        chans.push(chans[0]); // repeat to exercise gate-cache reuse
        let batch = SaeReconstructionRowProgram::beta_border_order1_batch4::<6>(
            [&rows[0], &rows[1], &rows[2], &rows[3]],
            &chans,
        )
        .expect("softmax-aligned rows must batch");
        for (lane, row) in rows.iter().enumerate() {
            let per = row.beta_border_order1_packed::<6>(&chans);
            let lane_batch = &batch[lane];
            assert_eq!(per.len(), lane_batch.len());
            for (i, (b, p)) in lane_batch.iter().zip(per.iter()).enumerate() {
                assert_eq!(b.value().to_bits(), p.value().to_bits(), "lane {lane} chan {i} v");
                let (bg, pg) = (b.g(), p.g());
                for a in 0..6 {
                    assert_eq!(
                        bg[a].to_bits(),
                        pg[a].to_bits(),
                        "lane {lane} chan {i} g[{a}]"
                    );
                    comparisons += 1;
                }
            }
        }
    }
    assert!(comparisons >= 1000, "oracle ran {comparisons} comparisons");
}
2166
/// Lane-uniformity guard. Four rows using the per-atom logistic gate
/// (IBP/JumpReLU) must make `reconstruction_all_columns_batch4` return `None`,
/// so the caller falls back to the scalar per-row path. The logistic branch is
/// per-row data-dependent and therefore not lane-uniform.
#[test]
fn batch4_declines_non_softmax() {
    // Minimal one-atom, two-primary row gated by a per-atom logistic.
    fn logistic_row(inv_tau: f64) -> SaeReconstructionRowProgram {
        SaeReconstructionRowProgram {
            atoms: vec![AtomRowBasisJet {
                phi: vec![1.0],
                d_phi: vec![vec![0.0]],
                d2_phi: vec![vec![vec![0.0]]],
                decoder: vec![vec![1.0]],
                latent_dim: 1,
            }],
            gate_value: vec![0.6],
            logits: vec![0.6],
            gate_scale: vec![1.0],
            gate_shift: vec![0.2],
            gate: RowGate::PerAtomLogistic { inv_tau },
            logit_slot: vec![Some(0)],
            coord_slot: vec![vec![1]],
            n_primaries: 2,
        }
    }

    let inv_tau = 1.1;
    let rows = [
        logistic_row(inv_tau),
        logistic_row(inv_tau),
        logistic_row(inv_tau),
        logistic_row(inv_tau),
    ];
    let declined = SaeReconstructionRowProgram::reconstruction_all_columns_batch4::<2>([
        &rows[0], &rows[1], &rows[2], &rows[3],
    ])
    .is_none();
    assert!(declined);
}
2196
/// #932 perf pin: the gate-HOISTED batched β border jets
/// (`beta_border_towers_packed`) are BIT-IDENTICAL to per-channel
/// `beta_border_tower_packed`, including when several channels share an atom
/// (the gate-cache reuse path).
///
/// Values and gradients are compared with `f64::to_bits`, matching the batch4
/// oracles. Float `==` would accept a `+0.0`/`-0.0` divergence and reject
/// identical NaNs, so it does not actually pin bit identity.
#[test]
fn hoisted_beta_border_bit_identical_to_per_channel() {
    let (prog, _inv_tau) = softmax_fixture(1.1);
    // Every (atom, basis column) channel, in order...
    let mut chans: Vec<(usize, usize)> = Vec::new();
    for atom in 0..prog.atoms.len() {
        for basis_col in 0..prog.atoms[atom].n_basis() {
            chans.push((atom, basis_col));
        }
    }
    // ...then re-append the first channel, so its atom's cached gate is hit
    // again after other atoms have been visited.
    if let Some(&first) = chans.first() {
        chans.push(first);
    }
    let batched = prog.beta_border_towers_packed::<6>(&chans);
    assert_eq!(batched.len(), chans.len());
    for (i, (b, &(atom, basis_col))) in batched.iter().zip(chans.iter()).enumerate() {
        let per = prog.beta_border_tower_packed::<6>(atom, basis_col);
        assert_eq!(b.value().to_bits(), per.value().to_bits(), "chan {i} value");
        let (bg, pg) = (b.g(), per.g());
        for a in 0..6 {
            assert_eq!(bg[a].to_bits(), pg[a].to_bits(), "chan {i} g[{a}]");
        }
    }
}
2226}