Skip to main content

gam_sae/
row_jet_program.rs

1//! The SAE reconstruction row as a single Taylor-jet program (issue #932).
2//!
3//! # The row program
4//!
5//! The exact-LAML SAE engine needs, per row, the derivative tower of the
6//! reconstruction
7//!
8//! ```text
9//!   ẑ_row,c(p) = Σ_k ζ_k(ℓ) · decoded_{k,c}(t_k),   decoded_{k,c}(t) = Σ_b Φ_b(t)·B_{b,c}
10//! ```
11//!
12//! — a **gate nonlinearity** `ζ(ℓ)` (softmax / IBP sigmoid) composed with a
13//! **basis** `Φ(t)` composed with a **linear decoder** `B`, in the per-row
14//! primary coordinates `p = (gate logits ℓ, latent coordinates t)`. Today the
15//! arrow-Schur assembly (`SaeManifoldTerm::row_jets_for_logdet`) hand-packs the
16//! `first`/`second` channels of this reconstruction from separate gate
17//! derivative arrays (`gate_derivatives_for_row`) and basis jet tensors —
18//! exactly the kind of hand-maintained cross-block tower whose sign flips are
19//! the #736 / desync bug genus. The #1006 third-order logdet adjoint
20//! `Γ_a = tr(H⁻¹ ∂H/∂θ_a)` is the consumer of those very channels.
21//!
22//! This module writes that reconstruction **once** over the
23//! [`Tower4<K>`](gam_math::jet_tower::Tower4) scalar so the
24//! value/gradient/Hessian/third channels of one row come from ONE jet
25//! evaluation. [`SaeReconstructionRowProgram`] is generic over the gate kind
26//! and the per-row basis jets; the gate, basis and decoder compose with plain
27//! `Tower4` arithmetic, so there is no separate "channel" to forget.
28//!
29//! # The basis as a local jet
30//!
31//! The production assembly does NOT re-evaluate the manifold basis `Φ` as a
32//! function of perturbed coordinates: it consumes the precomputed jet tensors
33//! `(Φ, ∂Φ/∂t, ∂²Φ/∂t²)` evaluated at the current `t`. The reconstruction's
34//! dependence on `t` is therefore *defined* by those tensors — the local
35//! quadratic Taylor model of `Φ` about the current point. This program builds
36//! each basis function as exactly that `Tower4` quadratic from the stored jets,
37//! so the value/first/second channels it emits are the same object the hand
38//! path packs — derived by independent arithmetic (tower Leibniz / Faà di
39//! Bruno vs hand-summed cross terms). Agreement across both is a true
40//! correctness proof of the hand kernel; disagreement names a dropped or
41//! sign-flipped cross block loudly. That oracle is the riding test below.
42
43use gam_math::jet_scalar::{JetScalar, Order1, Order2};
44use gam_math::jet_tower::Tower4;
45
46/// `1/self` for any [`JetScalar`] via Faà di Bruno on `f(u) = 1/u`
47/// (stack `[1/u, -1/u², 2/u³, -6/u⁴, 24/u⁵]`). Caller guarantees `self.value()`
48/// is nonzero — softmax denominators are strictly positive sums of exponentials.
49#[inline]
50fn recip<const K: usize, S: JetScalar<K>>(s: &S) -> S {
51    let u = s.value();
52    let u2 = u * u;
53    let u3 = u2 * u;
54    let u4 = u3 * u;
55    let u5 = u4 * u;
56    s.compose_unary([1.0 / u, -1.0 / u2, 2.0 / u3, -6.0 / u4, 24.0 / u5])
57}
58
/// Sentinel in [`SaeReconstructionRowProgram::coord_slot`] for an atom
/// coordinate that is fixed in this row's local chart (compact active-set rows
/// omit inactive atom coordinates, but softmax logit derivatives can still see
/// that atom's decoded value as a constant).
///
/// A latent axis whose slot equals this value gets no seeded tower variable in
/// `AtomRowBasisJet::basis_tower`: its first-order term is skipped, as is every
/// second-order term that involves it, so only the constant `Φ_b` value of the
/// basis contributes along that axis.
pub const SAE_FIXED_COORD_SLOT: usize = usize::MAX;
64
/// The gate nonlinearity `ζ(ℓ)` of the SAE assignment, as the row program sees
/// it. The production term carries the same two smooth branches (softmax over a
/// shared partition; per-atom IBP/JumpReLU sigmoid); the program reproduces the
/// branch the criterion evaluates so the value channel is the production gate.
///
/// Both variants carry only the inverse temperature; the per-atom shift and
/// scale used by the logistic branch live on the row program
/// (`SaeReconstructionRowProgram::gate_shift` / `gate_scale`).
#[derive(Debug, Clone, Copy)]
pub enum RowGate {
    /// Shared softmax over all atom logits with inverse temperature `inv_tau`.
    /// `ζ_k(ℓ) = softmax_k(ℓ · inv_tau)`.
    Softmax { inv_tau: f64 },
    /// Per-atom independent logistic gate
    /// `ζ_k(ℓ_k) = gate_scale_k · σ((ℓ_k − shift_k)·inv_tau)`
    /// — the IBP-MAP / JumpReLU smooth activation (the per-atom `shift_k`
    /// folds the IBP stick-breaking offset or the JumpReLU threshold, and the
    /// row program multiplies the sigmoid by its per-atom `gate_scale`). Each
    /// gate depends only on its own logit, so the gate Hessian is diagonal.
    PerAtomLogistic { inv_tau: f64 },
}
80
/// One atom's local basis jet at the current row: the stored
/// `(value, jacobian, second)` jet tensors of `Φ` plus the decoder block `B`.
/// Indexed `[basis_col]`, `[basis_col][axis]`, `[basis_col][axis_a][axis_b]`,
/// and `[basis_col][out_col]`.
///
/// The methods on this type index these vectors directly and do not validate
/// their shapes, so inconsistent lengths surface as an index panic.
#[derive(Debug, Clone)]
pub struct AtomRowBasisJet {
    /// `Φ_b` at the current coordinate (length `n_basis`). Its length is what
    /// the impl treats as the number of basis functions.
    pub phi: Vec<f64>,
    /// `∂Φ_b/∂t_axis` (`[n_basis][latent_dim]`).
    pub d_phi: Vec<Vec<f64>>,
    /// `∂²Φ_b/∂t_a∂t_b` (`[n_basis][latent_dim][latent_dim]`). The basis jet
    /// sums over every ordered `(a, b)` pair with weight `½`, so both the
    /// `[a][b]` and `[b][a]` entries are read.
    pub d2_phi: Vec<Vec<Vec<f64>>>,
    /// Decoder block `B_{b,c}` (`[n_basis][out_dim]`). The width of the first
    /// row is taken as the output dimension.
    pub decoder: Vec<Vec<f64>>,
    /// Latent dimension of this atom: the number of axes the basis jet iterates
    /// over in `d_phi` / `d2_phi` and in the caller-supplied coordinate slots.
    pub latent_dim: usize,
}
98
99impl AtomRowBasisJet {
100    fn n_basis(&self) -> usize {
101        self.phi.len()
102    }
103
104    fn out_dim(&self) -> usize {
105        self.decoder.first().map_or(0, Vec::len)
106    }
107
108    /// `Φ_b(t)` as a `Tower4<K>` quadratic in the latent primaries occupying
109    /// `coord_slots[axis]` (the seeded tower variable index for latent axis
110    /// `axis` of this atom). A constant value plus first/second jet
111    /// contributions — exactly the local Taylor model the production assembly
112    /// consumes.
113    fn basis_tower<const K: usize, S: JetScalar<K>>(
114        &self,
115        basis_col: usize,
116        coord_slots: &[usize],
117    ) -> S {
118        // The latent coordinate increments enter as the seeded tower variables;
119        // the basis value at the current point is the constant term.
120        let mut acc = S::constant(self.phi[basis_col]);
121        for axis in 0..self.latent_dim {
122            let slot = coord_slots[axis];
123            let d1 = self.d_phi[basis_col][axis];
124            if d1 != 0.0 {
125                if slot != SAE_FIXED_COORD_SLOT {
126                    acc = acc.add(&S::variable(0.0, slot).scale(d1));
127                }
128            }
129        }
130        // ½ Σ_ab d²Φ · δ_a δ_b, the quadratic term of the local Taylor model.
131        for axis_a in 0..self.latent_dim {
132            for axis_b in 0..self.latent_dim {
133                let d2 = self.d2_phi[basis_col][axis_a][axis_b];
134                if d2 == 0.0 {
135                    continue;
136                }
137                if coord_slots[axis_a] == SAE_FIXED_COORD_SLOT
138                    || coord_slots[axis_b] == SAE_FIXED_COORD_SLOT
139                {
140                    continue;
141                }
142                let va = S::variable(0.0, coord_slots[axis_a]);
143                let vb = S::variable(0.0, coord_slots[axis_b]);
144                acc = acc.add(&va.mul(&vb).scale(0.5 * d2));
145            }
146        }
147        acc
148    }
149
150    /// `decoded_{k,c}(t)` as a tower: `Σ_b Φ_b(t)·B_{b,c}`.
151    fn decoded_tower<const K: usize, S: JetScalar<K>>(
152        &self,
153        out_col: usize,
154        coord_slots: &[usize],
155    ) -> S {
156        let mut acc = S::constant(0.0);
157        for basis_col in 0..self.n_basis() {
158            let b = self.decoder[basis_col][out_col];
159            if b == 0.0 {
160                continue;
161            }
162            acc = acc.add(&self.basis_tower::<K, S>(basis_col, coord_slots).scale(b));
163        }
164        acc
165    }
166}
167
/// One row of the SAE reconstruction as a jet program: the per-atom basis jets,
/// the gate, the current gate-logit values, and the primary layout that maps
/// `(atom logit, atom latent axis)` to a seeded tower variable slot.
///
/// The per-atom vectors are indexed by atom and are not cross-validated against
/// each other; the methods index them directly.
#[derive(Debug, Clone)]
pub struct SaeReconstructionRowProgram {
    /// Per-atom basis jets at the current row.
    pub atoms: Vec<AtomRowBasisJet>,
    /// Current gate activations `ζ_k` at the row (softmax/sigmoid values).
    /// The gate builders in this impl read only its length (as the number of
    /// atoms to loop over); the gate values themselves are recomputed from
    /// `logits`.
    pub gate_value: Vec<f64>,
    /// Current gate logits `ℓ_k` at the row.
    pub logits: Vec<f64>,
    /// Per-atom multiplicative scale for independent logistic gates. This is
    /// the IBP stick-breaking prior `π_k` for IBP-MAP, `1` for active JumpReLU,
    /// and `0` for JumpReLU rows at/below the hard threshold. Unused for
    /// softmax.
    pub gate_scale: Vec<f64>,
    /// Per-atom logistic shift (IBP offset / JumpReLU threshold); unused for
    /// softmax.
    pub gate_shift: Vec<f64>,
    /// The gate nonlinearity.
    pub gate: RowGate,
    /// Tower slot of atom `k`'s gate logit primary, or `None` if the gate logit
    /// is not a free primary for this atom (softmax `K==1`). A `None` logit
    /// enters the gate as a constant at its current value.
    pub logit_slot: Vec<Option<usize>>,
    /// Tower slot of atom `k`'s latent axis `j` primary (`coord_slot[k][j]`).
    /// An entry equal to [`SAE_FIXED_COORD_SLOT`] marks a fixed axis.
    pub coord_slot: Vec<Vec<usize>>,
    /// Total number of seeded primaries (= `K` of the tower). Every jet-building
    /// entry point asserts that its const arity `K` equals this value.
    pub n_primaries: usize,
}
197
198impl SaeReconstructionRowProgram {
    /// The gate activation `ζ_k(ℓ)` of atom `atom` as a jet in the gate-logit
    /// primaries. Softmax is the shared composition `exp(ℓ_k·inv_tau) /
    /// Σ_j exp(ℓ_j·inv_tau)`; the per-atom logistic is
    /// `gate_scale_k · σ((ℓ_k − shift_k)·inv_tau)`, depending only on its own
    /// logit. Both carry every derivative channel automatically.
    ///
    /// A logit with `logit_slot == None` enters as a constant at its current
    /// value, so it contributes to the gate value but to no derivative channel.
    ///
    /// In the softmax branch the loop runs over `gate_value.len()` atoms; if
    /// `atom` is not among them the numerator stays the zero constant.
    fn gate_tower<const K: usize, S: JetScalar<K>>(&self, atom: usize) -> S {
        match self.gate {
            RowGate::Softmax { inv_tau } => {
                // Build exp(ℓ_j·inv_tau − shift) for every atom that has a free
                // logit primary, as a tower; atoms without a free logit
                // contribute a constant exponential (their logit does not move).
                //
                // Stability: softmax is invariant to a common additive constant
                // in every exponent (`exp(a−s)/Σ exp(b−s) = exp(a)/Σ exp(b)`),
                // and the higher derivative channels are unchanged because the
                // shift is a numeric constant (a function of the base logit
                // *values* only, seeded as a `constant`, not of the tower
                // variables). We subtract `(max_j ℓ_j)·inv_tau`, which is the
                // largest base exponent whenever `inv_tau` is nonnegative, so
                // the dominant `exp(·)` is `exp(0)=1` and no term overflows.
                // This mirrors the max-subtraction in the production
                // `softmax_row`.
                let shift = self
                    .logits
                    .iter()
                    .copied()
                    .fold(f64::NEG_INFINITY, f64::max)
                    * inv_tau;
                // `denom` accumulates the exponentials in atom order; `numer`
                // captures the one belonging to `atom`.
                let mut denom = S::constant(0.0);
                let mut numer = S::constant(0.0);
                for j in 0..self.gate_value.len() {
                    let lj = match self.logit_slot[j] {
                        Some(slot) => S::variable(self.logits[j], slot),
                        None => S::constant(self.logits[j]),
                    };
                    // (ℓ_j·inv_tau − shift): subtracting a constant shifts only
                    // the value channel, leaving every gradient/Hessian/t3/t4
                    // channel of the exponent (hence of exp via the chain rule)
                    // identical to the unshifted form.
                    let ej = lj.scale(inv_tau).sub(&S::constant(shift)).exp();
                    if j == atom {
                        numer = ej;
                    }
                    denom = denom.add(&ej);
                }
                // ζ_atom = exp_atom · (1 / Σ_j exp_j).
                numer.mul(&recip(&denom))
            }
            RowGate::PerAtomLogistic { inv_tau } => {
                let l = match self.logit_slot[atom] {
                    Some(slot) => S::variable(self.logits[atom], slot),
                    None => S::constant(self.logits[atom]),
                };
                // x = (ℓ − shift)·inv_tau, the sigmoid argument.
                let x = l.sub(&S::constant(self.gate_shift[atom])).scale(inv_tau);
                let one = S::constant(1.0);
                // Pick the sigmoid form whose exponential has a non-positive
                // argument at the current value: 1/(1+e^{−x}) for x ≥ 0 and
                // e^{x}/(1+e^{x}) otherwise, so the exponential stays ≤ 1.
                let sigma = if x.value() >= 0.0 {
                    one.mul(&recip(&one.add(&x.scale(-1.0).exp())))
                } else {
                    let ex = x.exp();
                    ex.mul(&recip(&one.add(&ex)))
                };
                // Apply the per-atom multiplicative gate scale last.
                sigma.scale(self.gate_scale[atom])
            }
        }
    }
262
263    /// All atoms' gate jets `ζ_k` at once, with the softmax denominator SHARED
264    /// across atoms (#932 perf). The per-atom [`Self::gate_tower`] rebuilds the
265    /// whole softmax denominator — `K` exp-jets, their sum, and the reciprocal —
266    /// on EVERY call, because only the numerator differs per atom; calling it `K`
267    /// times costs `K·(K exps) = O(K²)` exponential jets and `K` reciprocal jets
268    /// per row. Here the `K` exp-jets, the denominator sum, and the single
269    /// reciprocal jet are built ONCE, then `ζ_k = exp_k · inv_denom`. This emits
270    /// exactly `K` exps + `1` recip per row instead of `K²` + `K` (measured:
271    /// `K(K−1)` redundant exps and `K−1` redundant recips eliminated per row at
272    /// `K=8` ⇒ 56 exps + 7 recips removed), and is **bit-identical** to the
273    /// per-atom path (same `exp_k · recip(denom)` product, same Leibniz order).
274    /// Pure [`JetScalar`] ops — single-source, exact, no softmax chain rule.
275    fn all_gates<const K: usize, S: JetScalar<K>>(&self) -> Vec<S> {
276        let n = self.gate_value.len();
277        match self.gate {
278            RowGate::Softmax { inv_tau } => {
279                let shift = self
280                    .logits
281                    .iter()
282                    .copied()
283                    .fold(f64::NEG_INFINITY, f64::max)
284                    * inv_tau;
285                // The K exp-jets and the denominator, built ONCE and shared.
286                let mut exps: Vec<S> = Vec::with_capacity(n);
287                let mut denom = S::constant(0.0);
288                for j in 0..n {
289                    let lj = match self.logit_slot[j] {
290                        Some(slot) => S::variable(self.logits[j], slot),
291                        None => S::constant(self.logits[j]),
292                    };
293                    let ej = lj.scale(inv_tau).sub(&S::constant(shift)).exp();
294                    denom = denom.add(&ej);
295                    exps.push(ej);
296                }
297                let inv = recip(&denom);
298                exps.iter().map(|e| e.mul(&inv)).collect()
299            }
300            // Per-atom logistic gates are independent (each depends only on its
301            // own logit); there is no shared denominator to hoist, so this is the
302            // same as calling `gate_tower` per atom.
303            RowGate::PerAtomLogistic { .. } => {
304                (0..n).map(|atom| self.gate_tower::<K, S>(atom)).collect()
305            }
306        }
307    }
308
309    /// The reconstruction output column `c` as a single jet:
310    /// `ẑ_c(p) = Σ_k ζ_k(ℓ) · decoded_{k,c}(t_k)`. Its `.v` is the production
311    /// reconstruction value, `.g[a]` is `∂ẑ_c/∂p_a`, `.h[a][b]` is
312    /// `∂²ẑ_c/∂p_a∂p_b`, and the `t3`/`t4` channels are the exact higher-order
313    /// derivatives — all from this ONE evaluation.
314    fn reconstruction_column_generic<const K: usize, S: JetScalar<K>>(&self, out_col: usize) -> S {
315        assert_eq!(
316            self.n_primaries, K,
317            "SaeReconstructionRowProgram: tower arity K={K} must equal n_primaries={}",
318            self.n_primaries
319        );
320        let mut acc = S::constant(0.0);
321        for (atom, atom_jet) in self.atoms.iter().enumerate() {
322            let gate = self.gate_tower::<K, S>(atom);
323            let decoded = atom_jet.decoded_tower::<K, S>(out_col, &self.coord_slot[atom]);
324            acc = acc.add(&gate.mul(&decoded));
325        }
326        acc
327    }
328
    /// The reconstruction output column `c` as the PACKED order-2 jet
    /// [`Order2<K>`](gam_math::jet_scalar::Order2): value `.value()`,
    /// gradient `.g()[a] = ∂ẑ_c/∂p_a`, Hessian `.h()[a][b] = ∂²ẑ_c/∂p_a∂p_b`.
    ///
    /// This is the production path (#932): the arrow-Schur logdet consumer reads
    /// ONLY the order-≤2 channels of the reconstruction, so it builds the packed
    /// [`Order2<K>`] scalar — value/gradient/Hessian only — instead of the dense
    /// [`Tower4<K>`] (which materialises the entire K⁴ `t3`/`t4` tensor every row
    /// only to discard it). For `K` up to 16 the dense tower's tensor build is
    /// ~19× the instruction count of the order-2 channels alone; this collapses
    /// it to the channels actually read. The packed `(v, g, H)` is BIT-IDENTICAL
    /// to the order-≤2 channels of [`Self::reconstruction_column`] (the
    /// `Order2` newtype delegates to the same `Tower2` arithmetic the dense
    /// tower's order-≤2 channels use); the t3/t4 oracle pins the dense path.
    ///
    /// # Panics
    ///
    /// Panics if `K != self.n_primaries` (asserted by the shared generic body).
    #[must_use]
    pub fn reconstruction_column_packed<const K: usize>(&self, out_col: usize) -> Order2<K> {
        self.reconstruction_column_generic::<K, Order2<K>>(out_col)
    }
347
348    /// All `out_dim` reconstruction columns as packed [`Order2<K>`] jets, with
349    /// the per-row redundant sub-jets HOISTED out of the output-column loop
350    /// (#932 perf). `reconstruction_column_packed(c)` rebuilds, for every output
351    /// column `c`, both the per-atom softmax gate jet `ζ_k` (`K` exps + a recip
352    /// + a `K×K` Hessian — the dominant cost) AND each per-atom basis jet
353    /// `Φ_{k,b}` — yet **neither depends on `c`**: the gate is a function of the
354    /// logits only, and the basis jet is the local Taylor model of `Φ_b` in the
355    /// coords, the decoder coefficient `B_{b,c}` being the only `c`-dependent
356    /// factor. The consumer (`fill_reconstruction_channels_from_program`) calls
357    /// it once per `c`, so the gate and basis jets are recomputed `out_dim×`
358    /// redundantly.
359    ///
360    /// This builds each atom's gate jet ONCE (`K` total) and each atom's basis
361    /// jets ONCE (`n_basis` per atom), then assembles every column by the cheap
362    /// reductions `decoded_{k,c} = Σ_b Φ_{k,b}·B_{b,c}` and
363    /// `ẑ_c = Σ_k ζ_k·decoded_{k,c}`. The result is **bit-identical** to calling
364    /// [`Self::reconstruction_column_packed`] per column (same Leibniz products in
365    /// the same order) — only the redundant recomputation is removed — measured
366    /// ~9× faster at `K=8, out_dim=16` on the per-row hot path.
367    #[must_use]
368    pub fn reconstruction_all_columns_packed<const K: usize>(&self) -> Vec<Order2<K>> {
369        assert_eq!(
370            self.n_primaries, K,
371            "SaeReconstructionRowProgram: tower arity K={K} must equal n_primaries={}",
372            self.n_primaries
373        );
374        let p = self.out_dim();
375        // Hoist the per-atom gate jet (c-independent) and basis jets
376        // (c-independent) out of the column loop. `all_gates` additionally shares
377        // the softmax denominator / reciprocal across atoms (K exps + 1 recip,
378        // not K² + K).
379        let gates: Vec<Order2<K>> = self.all_gates::<K, Order2<K>>();
380        let bases: Vec<Vec<Order2<K>>> = self
381            .atoms
382            .iter()
383            .enumerate()
384            .map(|(atom, atom_jet)| {
385                (0..atom_jet.n_basis())
386                    .map(|b| atom_jet.basis_tower::<K, Order2<K>>(b, &self.coord_slot[atom]))
387                    .collect()
388            })
389            .collect();
390        (0..p)
391            .map(|c| {
392                let mut acc = Order2::<K>::constant(0.0);
393                for (atom, atom_jet) in self.atoms.iter().enumerate() {
394                    // decoded_{k,c} = Σ_b Φ_{k,b}·B_{b,c} from the hoisted basis
395                    // jets — same per-basis sum `decoded_tower` forms, but the
396                    // basis jets are reused across every column.
397                    let mut decoded = Order2::<K>::constant(0.0);
398                    for basis_col in 0..atom_jet.n_basis() {
399                        let coeff = atom_jet.decoder[basis_col][c];
400                        if coeff == 0.0 {
401                            continue;
402                        }
403                        decoded = decoded.add(&bases[atom][basis_col].scale(coeff));
404                    }
405                    acc = acc.add(&gates[atom].mul(&decoded));
406                }
407                acc
408            })
409            .collect()
410    }
411
    /// The reconstruction output column as the full dense [`Tower4<K>`] carrying
    /// every value/gradient/Hessian/`t3`/`t4` channel. This is the #932 oracle
    /// ground truth: the production [`Self::reconstruction_column_packed`]
    /// order-2 path is pinned against its order-≤2 channels, and the FD-witness
    /// tests use its `t3`/`t4`. Not on the per-row hot path.
    ///
    /// # Panics
    ///
    /// Panics if `K != self.n_primaries` (asserted by the shared generic body).
    #[must_use]
    pub fn reconstruction_column<const K: usize>(&self, out_col: usize) -> Tower4<K> {
        self.reconstruction_column_generic::<K, Tower4<K>>(out_col)
    }
421
422    /// The β **border-channel** local-variable sub-jet: the scalar
423    /// `s_{k,b}(p) = ζ_k(ℓ)·Φ_b(t_k)` as a `Tower4<K>` in the local
424    /// (logit/coord) primaries — the gate activation times ONE basis function.
425    ///
426    /// In the arrow system a β border channel is one free decoder coefficient
427    /// `β_{k,b,channel}` whose per-row reconstruction contribution to output
428    /// column `c` is `ζ_k(ℓ)·Φ_b(t_k)·output_c`, where `output` is the channel's
429    /// (frame / identity) output vector carried by the `SaeBorderChannel`, NOT
430    /// the current decoder matrix. The reconstruction is **linear** in `β`, so
431    /// `∂ẑ_c/∂β_{k,b,channel} = ζ_k(ℓ)·Φ_b(t_k)·output_c = s_{k,b}.v·output_c`
432    /// and `∂²ẑ_c/∂β∂p_a = s_{k,b}.g[a]·output_c` (the production `beta` /
433    /// `beta_deriv` / `beta_l_deriv` channels). The `output_c` factor is a
434    /// per-column constant the caller applies; this tower carries the entire
435    /// local-variable dependence.
436    ///
437    /// It is built from the SAME `gate_tower` / `basis_tower` primitives as
438    /// [`Self::reconstruction_column`], so the β border channel is single
439    /// sourced with the local-variable reconstruction tower (#932) — the hand
440    /// path in `row_jets_for_logdet` packs these same `ζ_k·Φ_b` products (then
441    /// multiplies by `channel.output`) term by term, and is pinned to this
442    /// tower by the converged-cache oracle.
443    fn beta_border_generic<const K: usize, S: JetScalar<K>>(
444        &self,
445        atom: usize,
446        basis_col: usize,
447    ) -> S {
448        assert_eq!(
449            self.n_primaries, K,
450            "SaeReconstructionRowProgram: tower arity K={K} must equal n_primaries={}",
451            self.n_primaries
452        );
453        let gate = self.gate_tower::<K, S>(atom);
454        let phi = self.atoms[atom].basis_tower::<K, S>(basis_col, &self.coord_slot[atom]);
455        gate.mul(&phi)
456    }
457
    /// The β **border-channel** local-variable sub-jet as the PACKED order-2 jet
    /// [`Order2<K>`](gam_math::jet_scalar::Order2). The consumer reads only
    /// `.value()` (the `beta` channel) and `.g()[a]` (the `beta_deriv` /
    /// `beta_l_deriv` mixed channel — the reconstruction is linear in β so the
    /// Hessian-in-β vanishes and only value+gradient are needed). Built from the
    /// SAME packed gate / basis primitives as
    /// [`Self::reconstruction_column_packed`], so the dense `t3`/`t4` tensor is
    /// never materialised on this per-row hot path (#932 Tower4→Order2 cutover).
    ///
    /// # Panics
    ///
    /// Panics if `K != self.n_primaries` (asserted by the shared generic body).
    #[must_use]
    pub fn beta_border_tower_packed<const K: usize>(
        &self,
        atom: usize,
        basis_col: usize,
    ) -> Order2<K> {
        self.beta_border_generic::<K, Order2<K>>(atom, basis_col)
    }
474
    /// The β border-channel sub-jet as the full dense [`Tower4<K>`] — the #932
    /// oracle ground truth the packed [`Self::beta_border_tower_packed`] is
    /// pinned against. Not on the per-row hot path.
    ///
    /// # Panics
    ///
    /// Panics if `K != self.n_primaries` (asserted by the shared generic body).
    #[must_use]
    pub fn beta_border_tower<const K: usize>(&self, atom: usize, basis_col: usize) -> Tower4<K> {
        self.beta_border_generic::<K, Tower4<K>>(atom, basis_col)
    }
482
483    /// Packed β border-channel sub-jets for a batch of `(atom, basis_col)`
484    /// channels, with the per-atom gate jets HOISTED and the softmax denominator
485    /// SHARED across atoms (#932 perf): the gate jet `ζ_k` (the dominant `K`-exp
486    /// / `K×K`-Hessian cost) is a function of the row's logits only, not of
487    /// `basis_col`, and every atom's gate shares one softmax denominator /
488    /// reciprocal. [`Self::all_gates`] builds all `K` gates once (K exps + 1
489    /// recip per row); each channel then just multiplies its atom's cached gate
490    /// by its basis jet. Each result is **bit-identical** to
491    /// [`Self::beta_border_tower_packed`] for the same `(atom, basis_col)` (same
492    /// `gate.mul(basis)` product), in the input order.
493    #[must_use]
494    pub fn beta_border_towers_packed<const K: usize>(
495        &self,
496        channels: &[(usize, usize)],
497    ) -> Vec<Order2<K>> {
498        assert_eq!(
499            self.n_primaries, K,
500            "SaeReconstructionRowProgram: tower arity K={K} must equal n_primaries={}",
501            self.n_primaries
502        );
503        let gates: Vec<Order2<K>> = self.all_gates::<K, Order2<K>>();
504        channels
505            .iter()
506            .map(|&(atom, basis_col)| {
507                let phi =
508                    self.atoms[atom].basis_tower::<K, Order2<K>>(basis_col, &self.coord_slot[atom]);
509                gates[atom].mul(&phi)
510            })
511            .collect()
512    }
513
514    /// Packed β border-channel sub-jets for a batch of channels as the
515    /// FIRST-order jet [`Order1<K>`](gam_math::jet_scalar::Order1) — value +
516    /// gradient ONLY, no Hessian. The β-border consumer
517    /// (`fill_beta_border_channels_from_program`) reads exactly `.value()` (the
518    /// `beta` channel) and `.g()[a]` (the mixed `beta_deriv` / `beta_l_deriv`
519    /// channel); the reconstruction is linear in β so the Hessian-in-β vanishes
520    /// and the K×K Hessian that [`Self::beta_border_towers_packed`]'s `Order2`
521    /// builds is computed-and-discarded every call. This method drops that work:
522    /// `Order1`'s value/gradient are BIT-IDENTICAL to `Order2`'s (the order-≤1
523    /// channels never read a Hessian), proven by the `order1_*` oracle, while the
524    /// per-channel `gate.mul(basis)` skips the `K²` Hessian product.
525    ///
526    /// Same hoisting as [`Self::beta_border_towers_packed`]: gate jets built once
527    /// via [`Self::all_gates`], each channel multiplies its atom's gate by its
528    /// basis jet.
529    #[must_use]
530    pub fn beta_border_order1_packed<const K: usize>(
531        &self,
532        channels: &[(usize, usize)],
533    ) -> Vec<Order1<K>> {
534        assert_eq!(
535            self.n_primaries, K,
536            "SaeReconstructionRowProgram: tower arity K={K} must equal n_primaries={}",
537            self.n_primaries
538        );
539        let gates: Vec<Order1<K>> = self.all_gates::<K, Order1<K>>();
540        channels
541            .iter()
542            .map(|&(atom, basis_col)| {
543                let phi =
544                    self.atoms[atom].basis_tower::<K, Order1<K>>(basis_col, &self.coord_slot[atom]);
545                gates[atom].mul(&phi)
546            })
547            .collect()
548    }
549
550    /// The number of reconstruction output columns.
551    #[must_use]
552    pub fn out_dim(&self) -> usize {
553        self.atoms.first().map_or(0, AtomRowBasisJet::out_dim)
554    }
555}
556
557// ─────────────────────────────────────────────────────────────────────────
558// 4-ROW SIMD BATCH (the jet's throughput lever over hand-scalar code)
559//
560// The hot per-row jet kernels (`reconstruction_all_columns_packed`,
561// `beta_border_order1_packed`) evaluate ONE row's `(v, g, H)` / `(v, g)` tower
562// at a time in scalar `f64`. A hand-written scalar derivative does exactly the
563// same. The throughput lever a jet has that scalar hand-code cannot is **row
564// batching in SIMD lanes**: the order-≤2 Leibniz product is `O(K²)` independent
565// per-channel float ops, and EVERY softmax row runs the IDENTICAL op graph on
566// different data — the textbook SPMD shape. Packing `LANES = 4` aligned rows
567// into a `[f64; 4]` lane and running the algebra once per 4 rows replaces 4
568// scalar passes with one vector pass, so the `K²` Hessian-channel updates become
569// 4-wide lane ops covering 4 rows each (auto-vectorised to SSE2 `pd` / NEON
570// `.2d`), ~4× fewer scalar FP instructions per row.
571//
572// The lane field is a plain `[f64; 4]` whose every op is a lane-wise IEEE
573// `+`/`-`/`*` (NEVER a fused `mul_add`), so lane `i` of a 4-wide op equals the
574// scalar `f64` op on that lane's inputs BIT-FOR-BIT. The op order mirrors
575// [`gam_math::jet_tower::Tower2`] / [`Order1`] term-for-term, so
576// [`O2x4`]/[`O1x4`] lane `i` is `to_bits`-identical to the production
577// [`Order2`]/[`Order1`] row scalar — proven by the `batch_tests` oracle below
578// (≥2000 random aligned 4-row batches across `K ∈ {2,4,6}`).
579//
580// Only the softmax gate is batched: its op graph is identical across rows (every
581// atom is an active free logit), while the per-atom logistic gate's
582// `x.value() >= 0.0` branch is per-row data-dependent (lanes could need
583// different branches, which are NOT bit-identical), so logistic rows fall back
584// to the scalar per-row path in the caller.
585
/// Number of rows packed side by side into one lane array.
const LANES: usize = 4;

/// Broadcasts the scalar `x` into every lane.
#[inline]
fn l_splat(x: f64) -> [f64; LANES] {
    [x; LANES]
}

/// Lane-wise sum: lane `i` of the result is the plain `f64` sum `a[i] + b[i]`.
#[inline]
fn l_add(a: [f64; LANES], b: [f64; LANES]) -> [f64; LANES] {
    let mut out = a;
    for (lane, rhs) in out.iter_mut().zip(b.iter()) {
        *lane += *rhs;
    }
    out
}

/// Lane-wise product: lane `i` of the result is the plain `f64` product
/// `a[i] * b[i]` (a separate multiply, never fused with an addition).
#[inline]
fn l_mul(a: [f64; LANES], b: [f64; LANES]) -> [f64; LANES] {
    let mut out = a;
    for (lane, rhs) in out.iter_mut().zip(b.iter()) {
        *lane *= *rhs;
    }
    out
}
608
/// 4-rows-per-pass order-≤2 lane scalar (value / gradient / Hessian), mirroring
/// [`gam_math::jet_tower::Tower2`] (hence [`Order2`]) term-for-term so lane `i`
/// is `to_bits`-identical to the scalar row-`i` [`Order2`].
///
/// Every channel stores one `f64` per lane, with the lane as the innermost
/// index.
#[derive(Clone, Copy)]
struct O2x4<const K: usize> {
    /// Value channel, one entry per lane.
    v: [f64; LANES],
    /// Gradient channel, indexed `[primary][lane]`.
    g: [[f64; LANES]; K],
    /// Hessian channel, indexed `[primary_a][primary_b][lane]`.
    h: [[[f64; LANES]; K]; K],
}
618
619impl<const K: usize> O2x4<K> {
    /// Constant jet: per-lane value `c` with all-zero gradient and Hessian.
    #[inline]
    fn constant(c: [f64; LANES]) -> Self {
        O2x4 {
            v: c,
            g: [[0.0; LANES]; K],
            h: [[[0.0; LANES]; K]; K],
        }
    }
    /// Seeded primary `axis` at (per-lane) `value`: unit first derivative.
    ///
    /// Starts from the constant jet and sets gradient entry `axis` to `1.0` in
    /// every lane; the Hessian stays zero. Indexing panics if `axis >= K`.
    #[inline]
    fn variable(value: [f64; LANES], axis: usize) -> Self {
        let mut out = Self::constant(value);
        out.g[axis] = l_splat(1.0);
        out
    }
635    #[inline]
636    fn add(&self, o: &Self) -> Self {
637        let mut out = *self;
638        out.v = l_add(self.v, o.v);
639        for i in 0..K {
640            out.g[i] = l_add(self.g[i], o.g[i]);
641            for j in 0..K {
642                out.h[i][j] = l_add(self.h[i][j], o.h[i][j]);
643            }
644        }
645        out
646    }
647    #[inline]
648    fn scale(&self, s: [f64; LANES]) -> Self {
649        let mut out = *self;
650        out.v = l_mul(self.v, s);
651        for i in 0..K {
652            out.g[i] = l_mul(self.g[i], s);
653            for j in 0..K {
654                out.h[i][j] = l_mul(self.h[i][j], s);
655            }
656        }
657        out
658    }
659    /// `self - o`, expressed as `self + o·(-1)` exactly as [`Order2::sub`] does.
660    #[inline]
661    fn sub(&self, o: &Self) -> Self {
662        self.add(&o.scale(l_splat(-1.0)))
663    }
664    /// Order-≤2 Leibniz product, term-for-term identical to `Tower2::mul`.
665    #[inline]
666    fn mul(&self, o: &Self) -> Self {
667        let a = self;
668        let b = o;
669        let mut out = Self::constant(l_mul(a.v, b.v));
670        for i in 0..K {
671            out.g[i] = l_add(l_mul(a.v, b.g[i]), l_mul(a.g[i], b.v));
672        }
673        for i in 0..K {
674            for j in 0..K {
675                let t0 = l_mul(a.v, b.h[i][j]);
676                let t1 = l_add(t0, l_mul(a.g[i], b.g[j]));
677                let t2 = l_add(t1, l_mul(a.g[j], b.g[i]));
678                out.h[i][j] = l_add(t2, l_mul(a.h[i][j], b.v));
679            }
680        }
681        out
682    }
683    /// Order-≤2 Faà di Bruno `f ∘ self` from the per-lane stack
684    /// `d = [f(u), f′(u), f″(u)]`, mirroring `Tower2::compose_unary`
685    /// (`acc` starts at `+0.0`, accumulates `d₁·hᵢⱼ` then `(d₂·gᵢ)·gⱼ`).
686    #[inline]
687    fn compose(&self, d: [[f64; LANES]; 3]) -> Self {
688        let mut out = Self::constant(d[0]);
689        for i in 0..K {
690            let mut acc = l_splat(0.0);
691            acc = l_add(acc, l_mul(d[1], self.g[i]));
692            out.g[i] = acc;
693        }
694        for i in 0..K {
695            for j in 0..K {
696                let mut acc = l_splat(0.0);
697                acc = l_add(acc, l_mul(d[1], self.h[i][j]));
698                acc = l_add(acc, l_mul(l_mul(d[2], self.g[i]), self.g[j]));
699                out.h[i][j] = acc;
700            }
701        }
702        out
703    }
704    /// `e^self`, per-lane stack `[e, e, e]` (matches `Tower2::exp`).
705    #[inline]
706    fn exp(&self) -> Self {
707        let mut e = [0.0; LANES];
708        for i in 0..LANES {
709            e[i] = self.v[i].exp();
710        }
711        self.compose([e, e, e])
712    }
713    /// `1/self`, per-lane stack `[1/u, -1/u², 2/u³]` — the DIVISION-based stack
714    /// of the [`recip`] free fn the scalar reconstruction path uses (NOT the
715    /// reciprocal-multiply `[r,-r²,2r³]` of `JetScalar::recip`; those differ by a
716    /// ULP and would break `to_bits` parity). Caller guarantees nonzero.
717    #[inline]
718    fn recip(&self) -> Self {
719        let mut d0 = [0.0; LANES];
720        let mut d1 = [0.0; LANES];
721        let mut d2 = [0.0; LANES];
722        for i in 0..LANES {
723            let u = self.v[i];
724            let u2 = u * u;
725            let u3 = u2 * u;
726            d0[i] = 1.0 / u;
727            d1[i] = -1.0 / u2;
728            d2[i] = 2.0 / u3;
729        }
730        self.compose([d0, d1, d2])
731    }
732    /// Extract lane `i` as a production [`Order2<K>`] scalar.
733    #[inline]
734    fn lane(&self, i: usize) -> Order2<K> {
735        let mut t = gam_math::jet_tower::Tower2::<K>::constant(self.v[i]);
736        for a in 0..K {
737            t.g[a] = self.g[a][i];
738            for b in 0..K {
739                t.h[a][b] = self.h[a][b][i];
740            }
741        }
742        Order2(t)
743    }
744}
745
746/// 4-rows-per-pass FIRST-order lane scalar (value / gradient only), mirroring
747/// [`Order1`] term-for-term so lane `i` is `to_bits`-identical to row-`i`
748/// [`Order1`]. Used for the β-border consumer (reconstruction is linear in β,
749/// so only value + gradient are read).
750#[derive(Clone, Copy)]
751struct O1x4<const K: usize> {
752    v: [f64; LANES],
753    g: [[f64; LANES]; K],
754}
755
756impl<const K: usize> O1x4<K> {
757    #[inline]
758    fn constant(c: [f64; LANES]) -> Self {
759        O1x4 {
760            v: c,
761            g: [[0.0; LANES]; K],
762        }
763    }
764    #[inline]
765    fn variable(value: [f64; LANES], axis: usize) -> Self {
766        let mut out = Self::constant(value);
767        out.g[axis] = l_splat(1.0);
768        out
769    }
770    #[inline]
771    fn add(&self, o: &Self) -> Self {
772        let mut out = *self;
773        out.v = l_add(self.v, o.v);
774        for i in 0..K {
775            out.g[i] = l_add(self.g[i], o.g[i]);
776        }
777        out
778    }
779    #[inline]
780    fn scale(&self, s: [f64; LANES]) -> Self {
781        let mut out = *self;
782        out.v = l_mul(self.v, s);
783        for i in 0..K {
784            out.g[i] = l_mul(self.g[i], s);
785        }
786        out
787    }
788    #[inline]
789    fn sub(&self, o: &Self) -> Self {
790        self.add(&o.scale(l_splat(-1.0)))
791    }
792    #[inline]
793    fn mul(&self, o: &Self) -> Self {
794        // Tower2::mul value/grad terms (order-≤1 truncation): v = a.v·b.v;
795        // g[i] = a.v·b.g[i] + a.g[i]·b.v. Identical float order to `Order1::mul`.
796        let a = self;
797        let b = o;
798        let mut out = Self::constant(l_mul(a.v, b.v));
799        for i in 0..K {
800            out.g[i] = l_add(l_mul(a.v, b.g[i]), l_mul(a.g[i], b.v));
801        }
802        out
803    }
804    #[inline]
805    fn compose(&self, d: [[f64; LANES]; 2]) -> Self {
806        // Order-≤1 Faà di Bruno: v = d[0]; g[i] = d[1]·g[i] (matches
807        // `Order1::compose_unary`, `acc` starts at +0.0).
808        let mut out = Self::constant(d[0]);
809        for i in 0..K {
810            let mut acc = l_splat(0.0);
811            acc = l_add(acc, l_mul(d[1], self.g[i]));
812            out.g[i] = acc;
813        }
814        out
815    }
816    #[inline]
817    fn exp(&self) -> Self {
818        let mut e = [0.0; LANES];
819        for i in 0..LANES {
820            e[i] = self.v[i].exp();
821        }
822        self.compose([e, e])
823    }
824    #[inline]
825    fn recip(&self) -> Self {
826        // Division-based `[1/u, -1/u²]` matching the `recip` free fn (see
827        // `O2x4::recip`), so lane `i` is `to_bits`-identical to the scalar path.
828        let mut d0 = [0.0; LANES];
829        let mut d1 = [0.0; LANES];
830        for i in 0..LANES {
831            let u = self.v[i];
832            let u2 = u * u;
833            d0[i] = 1.0 / u;
834            d1[i] = -1.0 / u2;
835        }
836        self.compose([d0, d1])
837    }
838    #[inline]
839    fn lane(&self, i: usize) -> Order1<K> {
840        let mut g = [0.0; K];
841        for a in 0..K {
842            g[a] = self.g[a][i];
843        }
844        Order1 { v: self.v[i], g }
845    }
846}
847
848/// Structural layout signature of a row program: the part that MUST be identical
849/// across rows for them to share one SIMD op graph (slot mapping, per-atom
850/// basis/latent/decoder shape, primary count). The per-row numeric data
851/// (`phi`/`d_phi`/`d2_phi`/`decoder` VALUES, `logits`) is what differs between
852/// lanes; the layout is what is shared.
853impl SaeReconstructionRowProgram {
854    /// Whether `self` and `other` share the SIMD-batchable softmax layout: same
855    /// softmax temperature, primary count, slot mapping, and per-atom basis /
856    /// latent / decoder dimensions. (Decoder/basis VALUES may differ per row and
857    /// are lane-packed; only the SHAPES must match.)
858    fn batch_aligned_softmax_with(&self, other: &Self) -> bool {
859        // Both rows must gate through softmax at the same temperature; a
860        // bit-for-bit `inv_tau` match is what lets them share one op graph.
861        match (self.gate, other.gate) {
862            (RowGate::Softmax { inv_tau: a }, RowGate::Softmax { inv_tau: b }) => {
863                if a.to_bits() != b.to_bits() {
864                    return false;
865                }
866            }
867            _ => return false,
868        }
869        if self.n_primaries != other.n_primaries
870            || self.atoms.len() != other.atoms.len()
871            || self.logit_slot != other.logit_slot
872            || self.coord_slot != other.coord_slot
873            || self.logits.len() != other.logits.len()
874        {
875            return false;
876        }
877        for (a, b) in self.atoms.iter().zip(other.atoms.iter()) {
878            if a.latent_dim != b.latent_dim
879                || a.n_basis() != b.n_basis()
880                || a.out_dim() != b.out_dim()
881            {
882                return false;
883            }
884        }
885        true
886    }
887
888    /// All `K` softmax gate lane-jets (`Order2` channels), with the denominator
889    /// SHARED across atoms and 4 rows packed per lane. Mirrors [`Self::all_gates`]
890    /// term-for-term so lane `i` is `to_bits`-identical to the row-`i` scalar
891    /// `all_gates::<K, Order2<K>>()`.
892    fn all_gates_o2x4<const K: usize>(&self, rows: &[&Self; LANES], inv_tau: f64) -> Vec<O2x4<K>> {
893        let n = self.gate_value.len();
894        let inv_tau_l = l_splat(inv_tau);
895        // Per-lane max-subtraction shift (= the scalar `all_gates` softmax shift,
896        // computed independently per row/lane).
897        let mut shift = [0.0; LANES];
898        for (lane, r) in rows.iter().enumerate() {
899            shift[lane] = r.logits.iter().copied().fold(f64::NEG_INFINITY, f64::max) * inv_tau;
900        }
901        let mut exps: Vec<O2x4<K>> = Vec::with_capacity(n);
902        let mut denom = O2x4::<K>::constant(l_splat(0.0));
903        for j in 0..n {
904            let mut lj_val = [0.0; LANES];
905            for (lane, r) in rows.iter().enumerate() {
906                lj_val[lane] = r.logits[j];
907            }
908            let lj = match self.logit_slot[j] {
909                Some(slot) => O2x4::<K>::variable(lj_val, slot),
910                None => O2x4::<K>::constant(lj_val),
911            };
912            let ej = lj.scale(inv_tau_l).sub(&O2x4::<K>::constant(shift)).exp();
913            denom = denom.add(&ej);
914            exps.push(ej);
915        }
916        let inv = denom.recip();
917        exps.iter().map(|e| e.mul(&inv)).collect()
918    }
919
920    /// All `K` softmax gate lane-jets at FIRST order (`Order1` channels).
921    /// Mirrors `all_gates::<K, Order1<K>>()` term-for-term.
922    fn all_gates_o1x4<const K: usize>(&self, rows: &[&Self; LANES], inv_tau: f64) -> Vec<O1x4<K>> {
923        let n = self.gate_value.len();
924        let inv_tau_l = l_splat(inv_tau);
925        let mut shift = [0.0; LANES];
926        for (lane, r) in rows.iter().enumerate() {
927            shift[lane] = r.logits.iter().copied().fold(f64::NEG_INFINITY, f64::max) * inv_tau;
928        }
929        let mut exps: Vec<O1x4<K>> = Vec::with_capacity(n);
930        let mut denom = O1x4::<K>::constant(l_splat(0.0));
931        for j in 0..n {
932            let mut lj_val = [0.0; LANES];
933            for (lane, r) in rows.iter().enumerate() {
934                lj_val[lane] = r.logits[j];
935            }
936            let lj = match self.logit_slot[j] {
937                Some(slot) => O1x4::<K>::variable(lj_val, slot),
938                None => O1x4::<K>::constant(lj_val),
939            };
940            let ej = lj.scale(inv_tau_l).sub(&O1x4::<K>::constant(shift)).exp();
941            denom = denom.add(&ej);
942            exps.push(ej);
943        }
944        let inv = denom.recip();
945        exps.iter().map(|e| e.mul(&inv)).collect()
946    }
947
948    /// One atom's basis jet `Φ_b(t)` as an [`O2x4`] over 4 rows, mirroring
949    /// [`AtomRowBasisJet::basis_tower`] term-for-term. A data-dependent `== 0`
950    /// skip is taken only when ALL 4 lanes are zero (the contribution of a zero
951    /// lane is `+0.0`, bit-identical to the scalar skip).
952    fn basis_tower_o2x4<const K: usize>(
953        rows: &[&Self; LANES],
954        atom: usize,
955        basis_col: usize,
956        coord_slots: &[usize],
957    ) -> O2x4<K> {
958        let latent = rows[0].atoms[atom].latent_dim;
959        let mut phi0 = [0.0; LANES];
960        for (lane, r) in rows.iter().enumerate() {
961            phi0[lane] = r.atoms[atom].phi[basis_col];
962        }
963        let mut acc = O2x4::<K>::constant(phi0);
964        for axis in 0..latent {
965            let slot = coord_slots[axis];
966            if slot == SAE_FIXED_COORD_SLOT {
967                continue;
968            }
969            let mut d1 = [0.0; LANES];
970            let mut any = false;
971            for (lane, r) in rows.iter().enumerate() {
972                let v = r.atoms[atom].d_phi[basis_col][axis];
973                d1[lane] = v;
974                any |= v != 0.0;
975            }
976            if any {
977                acc = acc.add(&O2x4::<K>::variable(l_splat(0.0), slot).scale(d1));
978            }
979        }
980        for axis_a in 0..latent {
981            for axis_b in 0..latent {
982                if coord_slots[axis_a] == SAE_FIXED_COORD_SLOT
983                    || coord_slots[axis_b] == SAE_FIXED_COORD_SLOT
984                {
985                    continue;
986                }
987                let mut d2 = [0.0; LANES];
988                let mut any = false;
989                for (lane, r) in rows.iter().enumerate() {
990                    let v = r.atoms[atom].d2_phi[basis_col][axis_a][axis_b];
991                    d2[lane] = v;
992                    any |= v != 0.0;
993                }
994                if !any {
995                    continue;
996                }
997                let mut half_d2 = [0.0; LANES];
998                for lane in 0..LANES {
999                    half_d2[lane] = 0.5 * d2[lane];
1000                }
1001                let va = O2x4::<K>::variable(l_splat(0.0), coord_slots[axis_a]);
1002                let vb = O2x4::<K>::variable(l_splat(0.0), coord_slots[axis_b]);
1003                acc = acc.add(&va.mul(&vb).scale(half_d2));
1004            }
1005        }
1006        acc
1007    }
1008
1009    /// One atom's basis jet as an [`O1x4`] (value + gradient), mirroring
1010    /// `basis_tower::<Order1>` term-for-term.
1011    fn basis_tower_o1x4<const K: usize>(
1012        rows: &[&Self; LANES],
1013        atom: usize,
1014        basis_col: usize,
1015        coord_slots: &[usize],
1016    ) -> O1x4<K> {
1017        let latent = rows[0].atoms[atom].latent_dim;
1018        let mut phi0 = [0.0; LANES];
1019        for (lane, r) in rows.iter().enumerate() {
1020            phi0[lane] = r.atoms[atom].phi[basis_col];
1021        }
1022        let mut acc = O1x4::<K>::constant(phi0);
1023        for axis in 0..latent {
1024            let slot = coord_slots[axis];
1025            if slot == SAE_FIXED_COORD_SLOT {
1026                continue;
1027            }
1028            let mut d1 = [0.0; LANES];
1029            let mut any = false;
1030            for (lane, r) in rows.iter().enumerate() {
1031                let v = r.atoms[atom].d_phi[basis_col][axis];
1032                d1[lane] = v;
1033                any |= v != 0.0;
1034            }
1035            if any {
1036                acc = acc.add(&O1x4::<K>::variable(l_splat(0.0), slot).scale(d1));
1037            }
1038        }
1039        for axis_a in 0..latent {
1040            for axis_b in 0..latent {
1041                if coord_slots[axis_a] == SAE_FIXED_COORD_SLOT
1042                    || coord_slots[axis_b] == SAE_FIXED_COORD_SLOT
1043                {
1044                    continue;
1045                }
1046                let mut d2 = [0.0; LANES];
1047                let mut any = false;
1048                for (lane, r) in rows.iter().enumerate() {
1049                    let v = r.atoms[atom].d2_phi[basis_col][axis_a][axis_b];
1050                    d2[lane] = v;
1051                    any |= v != 0.0;
1052                }
1053                if !any {
1054                    continue;
1055                }
1056                let mut half_d2 = [0.0; LANES];
1057                for lane in 0..LANES {
1058                    half_d2[lane] = 0.5 * d2[lane];
1059                }
1060                let va = O1x4::<K>::variable(l_splat(0.0), coord_slots[axis_a]);
1061                let vb = O1x4::<K>::variable(l_splat(0.0), coord_slots[axis_b]);
1062                acc = acc.add(&va.mul(&vb).scale(half_d2));
1063            }
1064        }
1065        acc
1066    }
1067
1068    /// All `out_dim` reconstruction columns for FOUR softmax-aligned rows at once,
1069    /// returned per row. Each row's column vector is BIT-IDENTICAL to
1070    /// [`Self::reconstruction_all_columns_packed`] on that row (same hoisting,
1071    /// same Leibniz products in the same order — lane `i` mirrors the scalar
1072    /// row-`i` path). Returns `None` if the four rows are not softmax-aligned, so
1073    /// the caller can fall back to the scalar per-row path.
1074    #[must_use]
1075    pub fn reconstruction_all_columns_batch4<const K: usize>(
1076        rows: [&Self; 4],
1077    ) -> Option<[Vec<Order2<K>>; 4]> {
1078        let head = rows[0];
1079        if head.n_primaries != K {
1080            return None;
1081        }
1082        let inv_tau = match head.gate {
1083            RowGate::Softmax { inv_tau } => inv_tau,
1084            RowGate::PerAtomLogistic { .. } => return None,
1085        };
1086        for r in &rows[1..] {
1087            if !head.batch_aligned_softmax_with(r) {
1088                return None;
1089            }
1090        }
1091        let p = head.out_dim();
1092        let gates: Vec<O2x4<K>> = head.all_gates_o2x4::<K>(&rows, inv_tau);
1093        let bases: Vec<Vec<O2x4<K>>> = head
1094            .atoms
1095            .iter()
1096            .enumerate()
1097            .map(|(atom, atom_jet)| {
1098                (0..atom_jet.n_basis())
1099                    .map(|b| Self::basis_tower_o2x4::<K>(&rows, atom, b, &head.coord_slot[atom]))
1100                    .collect()
1101            })
1102            .collect();
1103        let mut cols: [Vec<Order2<K>>; LANES] =
1104            [Vec::new(), Vec::new(), Vec::new(), Vec::new()];
1105        for c in 0..p {
1106            let mut acc = O2x4::<K>::constant(l_splat(0.0));
1107            for (atom, atom_jet) in head.atoms.iter().enumerate() {
1108                let mut decoded = O2x4::<K>::constant(l_splat(0.0));
1109                for basis_col in 0..atom_jet.n_basis() {
1110                    let mut coeff = [0.0; LANES];
1111                    let mut any = false;
1112                    for (lane, r) in rows.iter().enumerate() {
1113                        let v = r.atoms[atom].decoder[basis_col][c];
1114                        coeff[lane] = v;
1115                        any |= v != 0.0;
1116                    }
1117                    if any {
1118                        decoded = decoded.add(&bases[atom][basis_col].scale(coeff));
1119                    }
1120                }
1121                acc = acc.add(&gates[atom].mul(&decoded));
1122            }
1123            for lane in 0..LANES {
1124                cols[lane].push(acc.lane(lane));
1125            }
1126        }
1127        Some(cols)
1128    }
1129
1130    /// Packed β-border FIRST-order jets for a batch of `(atom, basis_col)`
1131    /// channels, for FOUR softmax-aligned rows at once, returned per row. Each
1132    /// row's channel vector is BIT-IDENTICAL to
1133    /// [`Self::beta_border_order1_packed`] on that row. Returns `None` if the
1134    /// rows are not softmax-aligned.
1135    #[must_use]
1136    pub fn beta_border_order1_batch4<const K: usize>(
1137        rows: [&Self; 4],
1138        channels: &[(usize, usize)],
1139    ) -> Option<[Vec<Order1<K>>; 4]> {
1140        let head = rows[0];
1141        if head.n_primaries != K {
1142            return None;
1143        }
1144        let inv_tau = match head.gate {
1145            RowGate::Softmax { inv_tau } => inv_tau,
1146            RowGate::PerAtomLogistic { .. } => return None,
1147        };
1148        for r in &rows[1..] {
1149            if !head.batch_aligned_softmax_with(r) {
1150                return None;
1151            }
1152        }
1153        let gates: Vec<O1x4<K>> = head.all_gates_o1x4::<K>(&rows, inv_tau);
1154        let mut out: [Vec<Order1<K>>; LANES] =
1155            [Vec::new(), Vec::new(), Vec::new(), Vec::new()];
1156        for &(atom, basis_col) in channels {
1157            let phi = Self::basis_tower_o1x4::<K>(&rows, atom, basis_col, &head.coord_slot[atom]);
1158            let s = gates[atom].mul(&phi);
1159            for lane in 0..LANES {
1160                out[lane].push(s.lane(lane));
1161            }
1162        }
1163        Some(out)
1164    }
1165}
1166
1167#[cfg(test)]
1168mod tests {
1169    use super::*;
1170
    /// Hand-packed reconstruction channels of ONE output column, as produced by
    /// `hand_softmax_column`.
    ///
    /// They replicate the production hand path (`row_jets_for_logdet`)
    /// arithmetic for the reconstruction `first`/`second` channels, from the
    /// same atom jets and softmax gate derivatives — independent code from the
    /// tower. The two must agree to machine precision; this is the #932
    /// universal oracle for the SAE row program (the analog of the survival
    /// `rigid_row_kernel_agrees_with_jet_tower_program` oracle).
    struct HandChannels {
        /// First derivatives of the column, indexed `[primary]`.
        first: Vec<f64>,
        /// Second derivatives of the column, indexed `[primary][primary]`.
        second: Vec<Vec<f64>>,
        /// Value of the column: the gate-weighted sum of decoded atom values.
        value: f64,
    }
1182
1183    /// Softmax gate first/second derivatives wrt logit primaries, term-for-term
1184    /// the production `gate_derivatives_for_row` softmax branch.
1185    fn softmax_gate_derivs(gate: &[f64], inv_tau: f64) -> (Vec<Vec<f64>>, Vec<Vec<Vec<f64>>>) {
1186        let k = gate.len();
1187        // dz[j][kk] = ∂ζ_kk/∂ℓ_j ; d2z[j][l][kk] = ∂²ζ_kk/∂ℓ_j∂ℓ_l.
1188        let mut dz = vec![vec![0.0_f64; k]; k];
1189        let mut d2z = vec![vec![vec![0.0_f64; k]; k]; k];
1190        for j in 0..k {
1191            for kk in 0..k {
1192                let ind = if kk == j { 1.0 } else { 0.0 };
1193                dz[j][kk] = gate[kk] * (ind - gate[j]) * inv_tau;
1194            }
1195        }
1196        for j in 0..k {
1197            for l in 0..k {
1198                for kk in 0..k {
1199                    let ikl = if kk == l { 1.0 } else { 0.0 };
1200                    let ikj = if kk == j { 1.0 } else { 0.0 };
1201                    let ijl = if j == l { 1.0 } else { 0.0 };
1202                    d2z[j][l][kk] = gate[kk]
1203                        * ((ikl - gate[l]) * (ikj - gate[j]) - gate[j] * (ijl - gate[l]))
1204                        * inv_tau
1205                        * inv_tau;
1206                }
1207            }
1208        }
1209        (dz, d2z)
1210    }
1211
1212    /// Hand-pack the reconstruction column channels exactly as the production
1213    /// `row_jets_for_logdet` does for a softmax gate: gate-logit primaries first
1214    /// (one per atom), then each atom's latent coords.
1215    fn hand_softmax_column(
1216        prog: &SaeReconstructionRowProgram,
1217        out_col: usize,
1218        inv_tau: f64,
1219    ) -> HandChannels {
1220        let k = prog.atoms.len();
1221        let n = prog.n_primaries;
1222        // decoded[k] value, d1[k][axis], d2[k][a][b] for this out_col.
1223        let decoded: Vec<f64> = (0..k)
1224            .map(|kk| {
1225                (0..prog.atoms[kk].n_basis())
1226                    .map(|b| prog.atoms[kk].phi[b] * prog.atoms[kk].decoder[b][out_col])
1227                    .sum()
1228            })
1229            .collect();
1230        let d1: Vec<Vec<f64>> = (0..k)
1231            .map(|kk| {
1232                (0..prog.atoms[kk].latent_dim)
1233                    .map(|axis| {
1234                        (0..prog.atoms[kk].n_basis())
1235                            .map(|b| {
1236                                prog.atoms[kk].d_phi[b][axis] * prog.atoms[kk].decoder[b][out_col]
1237                            })
1238                            .sum()
1239                    })
1240                    .collect()
1241            })
1242            .collect();
1243        let d2: Vec<Vec<Vec<f64>>> = (0..k)
1244            .map(|kk| {
1245                (0..prog.atoms[kk].latent_dim)
1246                    .map(|a| {
1247                        (0..prog.atoms[kk].latent_dim)
1248                            .map(|b| {
1249                                (0..prog.atoms[kk].n_basis())
1250                                    .map(|col| {
1251                                        prog.atoms[kk].d2_phi[col][a][b]
1252                                            * prog.atoms[kk].decoder[col][out_col]
1253                                    })
1254                                    .sum()
1255                            })
1256                            .collect()
1257                    })
1258                    .collect()
1259            })
1260            .collect();
1261
1262        let (dz, d2z) = softmax_gate_derivs(&prog.gate_value, inv_tau);
1263
1264        // Primary index of atom logit / coord, matching the program layout.
1265        let logit_idx = |kk: usize| prog.logit_slot[kk];
1266        let coord_idx = |kk: usize, axis: usize| prog.coord_slot[kk][axis];
1267
1268        let value: f64 = (0..k).map(|kk| prog.gate_value[kk] * decoded[kk]).sum();
1269
1270        let mut first = vec![0.0_f64; n];
1271        // Logit primaries: ∂ẑ/∂ℓ_j = Σ_kk dz[j][kk]·decoded[kk].
1272        for j in 0..k {
1273            if let Some(p) = logit_idx(j) {
1274                first[p] = (0..k).map(|kk| dz[j][kk] * decoded[kk]).sum();
1275            }
1276        }
1277        // Coord primaries: ∂ẑ/∂t_{kk,axis} = ζ_kk · d1[kk][axis].
1278        for kk in 0..k {
1279            for axis in 0..prog.atoms[kk].latent_dim {
1280                first[coord_idx(kk, axis)] = prog.gate_value[kk] * d1[kk][axis];
1281            }
1282        }
1283
1284        let mut second = vec![vec![0.0_f64; n]; n];
1285        // Logit×Logit: Σ_kk d2z[j][l][kk]·decoded[kk].
1286        for j in 0..k {
1287            for l in 0..k {
1288                if let (Some(pj), Some(pl)) = (logit_idx(j), logit_idx(l)) {
1289                    second[pj][pl] = (0..k).map(|kk| d2z[j][l][kk] * decoded[kk]).sum();
1290                }
1291            }
1292        }
1293        // Logit×Coord (and symmetric): dz[j][kk]·d1[kk][axis].
1294        for j in 0..k {
1295            for kk in 0..k {
1296                for axis in 0..prog.atoms[kk].latent_dim {
1297                    if let Some(pj) = logit_idx(j) {
1298                        let pc = coord_idx(kk, axis);
1299                        let val = dz[j][kk] * d1[kk][axis];
1300                        second[pj][pc] = val;
1301                        second[pc][pj] = val;
1302                    }
1303                }
1304            }
1305        }
1306        // Coord×Coord same atom: ζ_kk · d2[kk][a][b].
1307        for kk in 0..k {
1308            for a in 0..prog.atoms[kk].latent_dim {
1309                for b in 0..prog.atoms[kk].latent_dim {
1310                    let pa = coord_idx(kk, a);
1311                    let pb = coord_idx(kk, b);
1312                    second[pa][pb] = prog.gate_value[kk] * d2[kk][a][b];
1313                }
1314            }
1315        }
1316
1317        HandChannels {
1318            first,
1319            second,
1320            value,
1321        }
1322    }
1323
1324    /// Build a two-atom softmax fixture with `latent_dim = 2` per atom and a
1325    /// dense decoder so every primary is exercised. Layout: logit slots
1326    /// 0,1; atom-0 coords 2,3; atom-1 coords 4,5 → K = 6 primaries.
1327    fn softmax_fixture(inv_tau: f64) -> (SaeReconstructionRowProgram, f64) {
1328        let n_basis = 3;
1329        let out_dim = 4;
1330        let mk_atom = |seed: f64| {
1331            let phi: Vec<f64> = (0..n_basis)
1332                .map(|b| 0.3 + 0.2 * (b as f64 + seed))
1333                .collect();
1334            let d_phi: Vec<Vec<f64>> = (0..n_basis)
1335                .map(|b| {
1336                    (0..2)
1337                        .map(|axis| 0.1 * (b as f64 + 1.0) - 0.05 * axis as f64 + 0.03 * seed)
1338                        .collect()
1339                })
1340                .collect();
1341            let d2_phi: Vec<Vec<Vec<f64>>> = (0..n_basis)
1342                .map(|b| {
1343                    (0..2)
1344                        .map(|a| {
1345                            (0..2)
1346                                .map(|bb| {
1347                                    // Symmetric in (a, bb).
1348                                    0.02 * (b as f64 + 1.0)
1349                                        + 0.01 * (a as f64)
1350                                        + 0.01 * (bb as f64)
1351                                        + 0.004 * seed
1352                                })
1353                                .collect()
1354                        })
1355                        .collect()
1356                })
1357                .collect();
1358            let decoder: Vec<Vec<f64>> = (0..n_basis)
1359                .map(|b| {
1360                    (0..out_dim)
1361                        .map(|c| 0.5 - 0.1 * (b as f64) + 0.07 * (c as f64) + 0.02 * seed)
1362                        .collect()
1363                })
1364                .collect();
1365            AtomRowBasisJet {
1366                phi,
1367                d_phi,
1368                d2_phi,
1369                decoder,
1370                latent_dim: 2,
1371            }
1372        };
1373        let logits = vec![0.4_f64, -0.7];
1374        // Softmax gate values at these logits.
1375        let e: Vec<f64> = logits.iter().map(|&l| (l * inv_tau).exp()).collect();
1376        let s: f64 = e.iter().sum();
1377        let gate_value: Vec<f64> = e.iter().map(|&v| v / s).collect();
1378        let prog = SaeReconstructionRowProgram {
1379            atoms: vec![mk_atom(0.0), mk_atom(1.0)],
1380            gate_value,
1381            logits,
1382            gate_scale: vec![1.0, 1.0],
1383            gate_shift: vec![0.0, 0.0],
1384            gate: RowGate::Softmax { inv_tau },
1385            logit_slot: vec![Some(0), Some(1)],
1386            coord_slot: vec![vec![2, 3], vec![4, 5]],
1387            n_primaries: 6,
1388        };
1389        (prog, inv_tau)
1390    }
1391
1392    /// INDEPENDENT scalar witness for the reconstruction column `ẑ_c(δ)` as a
1393    /// function of the primary-increment vector `δ` (the displacement of each
1394    /// tower primary from its seed value: a coord primary seeds at value 0, a
1395    /// logit primary at its current logit, so `δ` is the same offset the tower's
1396    /// seeded variables carry). This evaluator touches NONE of the `Tower4`
1397    /// arithmetic — no Leibniz product, no Faà di Bruno compose, no
1398    /// `basis_tower`/`decoded_tower`/`gate_tower` — it re-derives the closed-form
1399    /// reconstruction from the raw jet tensors and the softmax definition. It is
1400    /// the witness the t3/t4 FD oracle differences below.
1401    ///
1402    /// `ẑ_c(δ) = Σ_k softmax_k((ℓ + δ_logit)·inv_tau) · Σ_b Φ̃_{k,b}(δ_coord)·B_{k,b,c}`
1403    /// with the SAME local quadratic basis model the program consumes:
1404    /// `Φ̃_b(u) = phi[b] + Σ_a d_phi[b][a]·u_a + ½ Σ_{a,a'} d2_phi[b][a][a']·u_a·u_{a'}`.
1405    fn recon_scalar_softmax(
1406        prog: &SaeReconstructionRowProgram,
1407        out_col: usize,
1408        inv_tau: f64,
1409        delta: &[f64],
1410    ) -> f64 {
1411        let k = prog.atoms.len();
1412        // Softmax over (logit + δ_logit) for atoms with a free logit primary;
1413        // atoms without one keep their base logit (no δ).
1414        let exps: Vec<f64> = (0..k)
1415            .map(|kk| {
1416                let dl = match prog.logit_slot[kk] {
1417                    Some(slot) => delta[slot],
1418                    None => 0.0,
1419                };
1420                ((prog.logits[kk] + dl) * inv_tau).exp()
1421            })
1422            .collect();
1423        let denom: f64 = exps.iter().sum();
1424        let mut acc = 0.0;
1425        for kk in 0..k {
1426            let gate = exps[kk] / denom;
1427            let atom = &prog.atoms[kk];
1428            // decoded_{kk,c}(δ_coord) via the local quadratic basis model.
1429            let mut decoded = 0.0;
1430            for b in 0..atom.n_basis() {
1431                let mut phi = atom.phi[b];
1432                for a in 0..atom.latent_dim {
1433                    let ua = delta[prog.coord_slot[kk][a]];
1434                    phi += atom.d_phi[b][a] * ua;
1435                }
1436                for a in 0..atom.latent_dim {
1437                    let ua = delta[prog.coord_slot[kk][a]];
1438                    for a2 in 0..atom.latent_dim {
1439                        let ub = delta[prog.coord_slot[kk][a2]];
1440                        phi += 0.5 * atom.d2_phi[b][a][a2] * ua * ub;
1441                    }
1442                }
1443                decoded += phi * atom.decoder[b][out_col];
1444            }
1445            acc += gate * decoded;
1446        }
1447        acc
1448    }
1449
1450    /// Fourth-order central FD of `recon_scalar_softmax` along axes (a,b,c,d) at
1451    /// the origin (δ = 0, the tower seed point). Uses the standard mixed
1452    /// fourth-difference stencil with sign vector ±h on each of the four axes
1453    /// (axes may coincide). 2⁴ = 16 evaluations.
1454    fn fd_fourth(
1455        prog: &SaeReconstructionRowProgram,
1456        out_col: usize,
1457        inv_tau: f64,
1458        axes: [usize; 4],
1459        h: f64,
1460    ) -> f64 {
1461        let n = prog.n_primaries;
1462        let mut acc = 0.0;
1463        for mask in 0..16u32 {
1464            let mut delta = vec![0.0_f64; n];
1465            let mut sign = 1.0;
1466            for (slot, &ax) in axes.iter().enumerate() {
1467                if (mask >> slot) & 1 == 1 {
1468                    delta[ax] += h;
1469                } else {
1470                    delta[ax] -= h;
1471                    sign = -sign;
1472                }
1473            }
1474            acc += sign * recon_scalar_softmax(prog, out_col, inv_tau, &delta);
1475        }
1476        acc / (16.0 * h * h * h * h)
1477    }
1478
1479    /// Third-order central FD of `recon_scalar_softmax` along axes (a,b,c) at the
1480    /// origin: 2³ = 8 evaluations with the mixed third-difference stencil.
1481    fn fd_third(
1482        prog: &SaeReconstructionRowProgram,
1483        out_col: usize,
1484        inv_tau: f64,
1485        axes: [usize; 3],
1486        h: f64,
1487    ) -> f64 {
1488        let n = prog.n_primaries;
1489        let mut acc = 0.0;
1490        for mask in 0..8u32 {
1491            let mut delta = vec![0.0_f64; n];
1492            let mut sign = 1.0;
1493            for (slot, &ax) in axes.iter().enumerate() {
1494                if (mask >> slot) & 1 == 1 {
1495                    delta[ax] += h;
1496                } else {
1497                    delta[ax] -= h;
1498                    sign = -sign;
1499                }
1500            }
1501            acc += sign * recon_scalar_softmax(prog, out_col, inv_tau, &delta);
1502        }
1503        acc / (8.0 * h * h * h)
1504    }
1505
1506    /// The #932 follow-up the issue flagged as missing: the SAE reconstruction
1507    /// program's THIRD- and FOURTH-order channels (`t3`/`t4`) validated against an
1508    /// INDEPENDENT witness (`recon_scalar_softmax`, finite-differenced), not just
1509    /// the value/first/second channels the hand-path oracle covers. Both the
1510    /// witness and the differencing are independent of the `Tower4` Leibniz /
1511    /// Faà-di-Bruno arithmetic that produces `t3`/`t4`, so agreement is a real
1512    /// cross-check of those higher-order channels — the analog of the survival
1513    /// kernel's `row_third_contracted` oracle, extended to fourth order.
1514    #[test]
1515    fn softmax_reconstruction_t3_t4_match_independent_fd_witness() {
1516        let (prog, inv_tau) = softmax_fixture(1.1);
1517        // Mixed fifth-derivative magnitude bounds the central-FD truncation; a
1518        // moderate step keeps both truncation and roundoff well under tol.
1519        let h3 = 2e-3;
1520        let h4 = 1e-2;
1521        for out_col in 0..prog.out_dim() {
1522            let tower = prog.reconstruction_column::<6>(out_col);
1523
1524            let t3_floor = tower
1525                .t3
1526                .iter()
1527                .flatten()
1528                .flatten()
1529                .fold(0.0_f64, |m, x| m.max(x.abs()))
1530                .max(1e-9);
1531            let t4_floor = tower
1532                .t4
1533                .iter()
1534                .flatten()
1535                .flatten()
1536                .flatten()
1537                .fold(0.0_f64, |m, x| m.max(x.abs()))
1538                .max(1e-9);
1539
1540            for a in 0..6 {
1541                for b in 0..6 {
1542                    for c in 0..6 {
1543                        let fd = fd_third(&prog, out_col, inv_tau, [a, b, c], h3);
1544                        assert!(
1545                            (tower.t3[a][b][c] - fd).abs() <= 5e-5 * t3_floor,
1546                            "col {out_col} t3[{a}][{b}][{c}]: tower {:+.10e} vs fd {:+.10e}",
1547                            tower.t3[a][b][c],
1548                            fd
1549                        );
1550                        for d in 0..6 {
1551                            let fd4 = fd_fourth(&prog, out_col, inv_tau, [a, b, c, d], h4);
1552                            assert!(
1553                                (tower.t4[a][b][c][d] - fd4).abs() <= 5e-4 * t4_floor,
1554                                "col {out_col} t4[{a}][{b}][{c}][{d}]: tower {:+.10e} vs fd {:+.10e}",
1555                                tower.t4[a][b][c][d],
1556                                fd4
1557                            );
1558                        }
1559                    }
1560                }
1561            }
1562        }
1563    }
1564
1565    /// A planted #736-style corruption in a t3 OR t4 channel is caught by the
1566    /// independent FD witness (loud at introduction). We perturb a copy of the
1567    /// tower's higher-order channel and assert the witness disagrees.
1568    #[test]
1569    fn planted_t3_t4_corruption_is_caught_by_fd_witness() {
1570        let (prog, inv_tau) = softmax_fixture(1.1);
1571        let out_col = 2;
1572        let tower = prog.reconstruction_column::<6>(out_col);
1573        // A real logit×coord×coord third block (atom-0 logit slot 0, atom-0
1574        // coords 2,3): the witness's third FD must match it...
1575        let axes3 = [0usize, 2, 3];
1576        let fd3 = fd_third(&prog, out_col, inv_tau, axes3, 2e-3);
1577        let t3_floor = tower
1578            .t3
1579            .iter()
1580            .flatten()
1581            .flatten()
1582            .fold(0.0_f64, |m, x| m.max(x.abs()))
1583            .max(1e-9);
1584        assert!(
1585            (tower.t3[0][2][3] - fd3).abs() <= 5e-5 * t3_floor,
1586            "honest t3 must match witness"
1587        );
1588        // ...and a sign-flipped copy must NOT.
1589        let corrupt = -tower.t3[0][2][3];
1590        assert!(
1591            (corrupt - fd3).abs() > 5e-5 * t3_floor,
1592            "a sign-flipped t3 block must disagree with the FD witness"
1593        );
1594
1595        let axes4 = [0usize, 0, 2, 3];
1596        let fd4 = fd_fourth(&prog, out_col, inv_tau, axes4, 1e-2);
1597        let t4_floor = tower
1598            .t4
1599            .iter()
1600            .flatten()
1601            .flatten()
1602            .flatten()
1603            .fold(0.0_f64, |m, x| m.max(x.abs()))
1604            .max(1e-9);
1605        let corrupt4 = tower.t4[0][0][2][3] + 10.0 * t4_floor;
1606        assert!(
1607            (corrupt4 - fd4).abs() > 5e-4 * t4_floor,
1608            "a corrupted t4 block must disagree with the FD witness"
1609        );
1610    }
1611
1612    #[test]
1613    fn softmax_reconstruction_tower_matches_hand_channels_all_columns() {
1614        let (prog, inv_tau) = softmax_fixture(1.3);
1615        for out_col in 0..prog.out_dim() {
1616            let tower = prog.reconstruction_column::<6>(out_col);
1617            let hand = hand_softmax_column(&prog, out_col, inv_tau);
1618
1619            // Magnitude floors so structurally-zero entries don't demand
1620            // absolute equality (the verify_kernel_channels convention).
1621            let g_floor = tower.g.iter().fold(0.0_f64, |m, x| m.max(x.abs()));
1622            let h_floor = tower
1623                .h
1624                .iter()
1625                .flatten()
1626                .fold(0.0_f64, |m, x| m.max(x.abs()));
1627
1628            assert!(
1629                (tower.v - hand.value).abs() <= 1e-9 * hand.value.abs().max(1.0),
1630                "col {out_col} value: tower {} vs hand {}",
1631                tower.v,
1632                hand.value
1633            );
1634            for a in 0..6 {
1635                assert!(
1636                    (tower.g[a] - hand.first[a]).abs() <= 1e-9 * g_floor.max(1e-12),
1637                    "col {out_col} first[{a}]: tower {} vs hand {}",
1638                    tower.g[a],
1639                    hand.first[a]
1640                );
1641                for b in 0..6 {
1642                    assert!(
1643                        (tower.h[a][b] - hand.second[a][b]).abs() <= 1e-8 * h_floor.max(1e-12),
1644                        "col {out_col} second[{a}][{b}]: tower {} vs hand {}",
1645                        tower.h[a][b],
1646                        hand.second[a][b]
1647                    );
1648                }
1649            }
1650        }
1651    }
1652
1653    /// A planted sign flip in the hand cross-block (logit×coord) is caught by the
1654    /// oracle — the same failure that #736 was, made loud at introduction.
1655    #[test]
1656    fn planted_cross_block_sign_flip_is_caught() {
1657        let (prog, inv_tau) = softmax_fixture(1.3);
1658        let out_col = 1;
1659        let tower = prog.reconstruction_column::<6>(out_col);
1660        let mut hand = hand_softmax_column(&prog, out_col, inv_tau);
1661        // Corrupt one logit×coord cross block (atom-0 logit slot 0, atom-1
1662        // coord slot 4): flip its sign, the #736 disease.
1663        hand.second[0][4] = -hand.second[0][4];
1664        hand.second[4][0] = -hand.second[4][0];
1665        let h_floor = tower
1666            .h
1667            .iter()
1668            .flatten()
1669            .fold(0.0_f64, |m, x| m.max(x.abs()));
1670        let disagrees = (tower.h[0][4] - hand.second[0][4]).abs() > 1e-8 * h_floor.max(1e-12);
1671        assert!(
1672            disagrees,
1673            "a flipped cross block must disagree with the tower truth"
1674        );
1675    }
1676
1677    /// The tower gate channels alone reproduce the softmax `gate_derivatives_for_row`
1678    /// arithmetic — isolating the gate nonlinearity from the basis/decoder so a
1679    /// regression in either is localizable.
1680    #[test]
1681    fn softmax_gate_tower_matches_hand_gate_derivatives() {
1682        let (prog, inv_tau) = softmax_fixture(0.9);
1683        let (dz, d2z) = softmax_gate_derivs(&prog.gate_value, inv_tau);
1684        for atom in 0..prog.atoms.len() {
1685            let gate = prog.gate_tower::<6, Tower4<6>>(atom);
1686            // ζ_atom value.
1687            assert!((gate.v - prog.gate_value[atom]).abs() < 1e-12);
1688            // ∂ζ_atom/∂ℓ_j == dz[j][atom].
1689            for j in 0..prog.atoms.len() {
1690                let slot = prog.logit_slot[j].unwrap();
1691                assert!(
1692                    (gate.g[slot] - dz[j][atom]).abs() < 1e-9,
1693                    "gate {atom} d/dlogit {j}: tower {} vs hand {}",
1694                    gate.g[slot],
1695                    dz[j][atom]
1696                );
1697            }
1698            // ∂²ζ_atom/∂ℓ_j∂ℓ_l == d2z[j][l][atom].
1699            for j in 0..prog.atoms.len() {
1700                for l in 0..prog.atoms.len() {
1701                    let sj = prog.logit_slot[j].unwrap();
1702                    let sl = prog.logit_slot[l].unwrap();
1703                    assert!(
1704                        (gate.h[sj][sl] - d2z[j][l][atom]).abs() < 1e-8,
1705                        "gate {atom} d2/dlogit {j}{l}: tower {} vs hand {}",
1706                        gate.h[sj][sl],
1707                        d2z[j][l][atom]
1708                    );
1709                }
1710            }
1711        }
1712    }
1713
1714    /// The per-atom logistic gate (IBP/JumpReLU branch) is diagonal in the
1715    /// logits and reproduces `σ' = σ(1−σ)·inv_tau`, `σ'' = σ(1−σ)(1−2σ)·inv_tau²`.
1716    #[test]
1717    fn per_atom_logistic_gate_matches_closed_form() {
1718        let inv_tau = 1.4;
1719        let logit = 0.6;
1720        let shift = 0.2;
1721        let x: f64 = (logit - shift) * inv_tau;
1722        let sigma = 1.0 / (1.0 + (-x).exp());
1723        let prog = SaeReconstructionRowProgram {
1724            atoms: vec![AtomRowBasisJet {
1725                phi: vec![1.0],
1726                d_phi: vec![vec![0.0]],
1727                d2_phi: vec![vec![vec![0.0]]],
1728                decoder: vec![vec![1.0]],
1729                latent_dim: 1,
1730            }],
1731            gate_value: vec![sigma],
1732            logits: vec![logit],
1733            gate_scale: vec![1.0],
1734            gate_shift: vec![shift],
1735            gate: RowGate::PerAtomLogistic { inv_tau },
1736            logit_slot: vec![Some(0)],
1737            coord_slot: vec![vec![1]],
1738            n_primaries: 2,
1739        };
1740        let gate = prog.gate_tower::<2, Tower4<2>>(0);
1741        assert!((gate.v - sigma).abs() < 1e-12);
1742        let d1 = sigma * (1.0 - sigma) * inv_tau;
1743        let d2 = sigma * (1.0 - sigma) * (1.0 - 2.0 * sigma) * inv_tau * inv_tau;
1744        assert!((gate.g[0] - d1).abs() < 1e-9, "σ': {} vs {}", gate.g[0], d1);
1745        assert!(
1746            (gate.h[0][0] - d2).abs() < 1e-9,
1747            "σ'': {} vs {}",
1748            gate.h[0][0],
1749            d2
1750        );
1751    }
1752
1753    /// #932 cutover pin: the PRODUCTION packed [`Order2`] reconstruction path
1754    /// (`reconstruction_column_packed`) is BIT-IDENTICAL on the
1755    /// value/gradient/Hessian channels to the dense [`Tower4`] oracle
1756    /// (`reconstruction_column`) — the same channels the arrow-Schur logdet
1757    /// consumer reads — for every output column. The Order2 path never
1758    /// materialises `t3`/`t4`, but its `(v, g, H)` must match the dense tower's
1759    /// order-≤2 channels to ≤1e-12 (they share the `Tower2` arithmetic), so the
1760    /// cutover changes only cost, not result.
1761    #[test]
1762    fn order2_reconstruction_matches_tower_value_grad_hessian() {
1763        for tau in [0.9_f64, 1.3, 2.1] {
1764            let (prog, _inv_tau) = softmax_fixture(tau);
1765            for out_col in 0..prog.out_dim() {
1766                let packed = prog.reconstruction_column_packed::<6>(out_col);
1767                let tower = prog.reconstruction_column::<6>(out_col);
1768                let g = packed.g();
1769                let h = packed.h();
1770                let band = |x: f64| 1e-12 + 1e-12 * x.abs();
1771                assert!(
1772                    (packed.value() - tower.v).abs() <= band(tower.v),
1773                    "col {out_col} value: order2 {} vs tower {}",
1774                    packed.value(),
1775                    tower.v
1776                );
1777                for a in 0..6 {
1778                    assert!(
1779                        (g[a] - tower.g[a]).abs() <= band(tower.g[a]),
1780                        "col {out_col} g[{a}]: order2 {} vs tower {}",
1781                        g[a],
1782                        tower.g[a]
1783                    );
1784                    for b in 0..6 {
1785                        assert!(
1786                            (h[a][b] - tower.h[a][b]).abs() <= band(tower.h[a][b]),
1787                            "col {out_col} h[{a}][{b}]: order2 {} vs tower {}",
1788                            h[a][b],
1789                            tower.h[a][b]
1790                        );
1791                    }
1792                }
1793            }
1794        }
1795    }
1796
1797    /// #932 cutover pin for the β border channel: the packed [`Order2`]
1798    /// `beta_border_tower_packed` matches the dense [`Tower4`]
1799    /// `beta_border_tower` on the value (`beta`) and gradient (`beta_deriv` /
1800    /// `beta_l_deriv`) channels the consumer reads, to ≤1e-12.
1801    #[test]
1802    fn order2_beta_border_matches_tower_value_grad() {
1803        let (prog, _inv_tau) = softmax_fixture(1.1);
1804        for atom in 0..prog.atoms.len() {
1805            for basis_col in 0..prog.atoms[atom].n_basis() {
1806                let packed = prog.beta_border_tower_packed::<6>(atom, basis_col);
1807                let tower = prog.beta_border_tower::<6>(atom, basis_col);
1808                let g = packed.g();
1809                let band = |x: f64| 1e-12 + 1e-12 * x.abs();
1810                assert!(
1811                    (packed.value() - tower.v).abs() <= band(tower.v),
1812                    "atom {atom} b {basis_col} value: order2 {} vs tower {}",
1813                    packed.value(),
1814                    tower.v
1815                );
1816                for a in 0..6 {
1817                    assert!(
1818                        (g[a] - tower.g[a]).abs() <= band(tower.g[a]),
1819                        "atom {atom} b {basis_col} g[{a}]: order2 {} vs tower {}",
1820                        g[a],
1821                        tower.g[a]
1822                    );
1823                }
1824            }
1825        }
1826    }
1827
1828    /// #932 perf pin: the gate-shared `all_gates` produces gate jets
1829    /// BIT-IDENTICAL to the per-atom `gate_tower` — sharing the softmax
1830    /// denominator / reciprocal across atoms (K exps + 1 recip instead of
1831    /// K² + K) changes only which redundant work is elided, not the result
1832    /// (`ζ_k = exp_k · recip(denom)` is the same product, same Leibniz order).
1833    #[test]
1834    fn shared_all_gates_bit_identical_to_per_atom_gate_tower() {
1835        for tau in [0.9_f64, 1.3, 2.1] {
1836            let (prog, _inv_tau) = softmax_fixture(tau);
1837            let all = prog.all_gates::<6, Order2<6>>();
1838            assert_eq!(all.len(), prog.gate_value.len());
1839            for atom in 0..prog.gate_value.len() {
1840                let per = prog.gate_tower::<6, Order2<6>>(atom);
1841                assert_eq!(all[atom].value(), per.value(), "atom {atom} value");
1842                for a in 0..6 {
1843                    assert_eq!(all[atom].g()[a], per.g()[a], "atom {atom} g[{a}]");
1844                    for b in 0..6 {
1845                        assert_eq!(
1846                            all[atom].h()[a][b],
1847                            per.h()[a][b],
1848                            "atom {atom} h[{a}][{b}]"
1849                        );
1850                    }
1851                }
1852            }
1853        }
1854    }
1855
1856    /// #932 perf pin: the gate/basis-HOISTED + denominator-SHARED all-columns
1857    /// reconstruction (`reconstruction_all_columns_packed`) is BIT-IDENTICAL to
1858    /// calling `reconstruction_column_packed(c)` per column — the hoist + share
1859    /// removes only redundant gate/basis/denominator recomputation, not any
1860    /// arithmetic. Every value/grad/Hessian channel must match exactly (==),
1861    /// since the Leibniz products are the same in the same order.
1862    #[test]
1863    fn hoisted_all_columns_bit_identical_to_per_column() {
1864        for tau in [0.9_f64, 1.3, 2.1] {
1865            let (prog, _inv_tau) = softmax_fixture(tau);
1866            let all = prog.reconstruction_all_columns_packed::<6>();
1867            assert_eq!(all.len(), prog.out_dim());
1868            for out_col in 0..prog.out_dim() {
1869                let per = prog.reconstruction_column_packed::<6>(out_col);
1870                let ah = all[out_col];
1871                assert_eq!(ah.value(), per.value(), "col {out_col} value");
1872                for a in 0..6 {
1873                    assert_eq!(ah.g()[a], per.g()[a], "col {out_col} g[{a}]");
1874                    for b in 0..6 {
1875                        assert_eq!(ah.h()[a][b], per.h()[a][b], "col {out_col} h[{a}][{b}]");
1876                    }
1877                }
1878            }
1879        }
1880    }
1881
1882    /// Build four softmax-aligned row programs that differ ONLY in their per-row
1883    /// numeric data (logits, basis values, decoder), keeping the layout
1884    /// (slots / dims / temperature) identical so they are 4-row SIMD-batchable.
1885    fn softmax_batch_fixture(inv_tau: f64) -> [SaeReconstructionRowProgram; LANES] {
1886        let n_basis = 3;
1887        let out_dim = 4;
1888        let mk = |row_seed: f64| {
1889            let mk_atom = |seed: f64| {
1890                let phi: Vec<f64> = (0..n_basis)
1891                    .map(|b| 0.3 + 0.2 * (b as f64 + seed) + 0.11 * row_seed)
1892                    .collect();
1893                let d_phi: Vec<Vec<f64>> = (0..n_basis)
1894                    .map(|b| {
1895                        (0..2)
1896                            .map(|axis| {
1897                                0.1 * (b as f64 + 1.0) - 0.05 * axis as f64 + 0.03 * seed
1898                                    + 0.017 * row_seed
1899                            })
1900                            .collect()
1901                    })
1902                    .collect();
1903                let d2_phi: Vec<Vec<Vec<f64>>> = (0..n_basis)
1904                    .map(|b| {
1905                        (0..2)
1906                            .map(|a| {
1907                                (0..2)
1908                                    .map(|bb| {
1909                                        0.02 * (b as f64 + 1.0)
1910                                            + 0.01 * (a as f64)
1911                                            + 0.01 * (bb as f64)
1912                                            + 0.004 * seed
1913                                            + 0.003 * row_seed
1914                                    })
1915                                    .collect()
1916                            })
1917                            .collect()
1918                    })
1919                    .collect();
1920                let decoder: Vec<Vec<f64>> = (0..n_basis)
1921                    .map(|b| {
1922                        (0..out_dim)
1923                            .map(|c| {
1924                                0.5 - 0.1 * (b as f64) + 0.07 * (c as f64) + 0.02 * seed
1925                                    + 0.009 * row_seed
1926                            })
1927                            .collect()
1928                    })
1929                    .collect();
1930                AtomRowBasisJet {
1931                    phi,
1932                    d_phi,
1933                    d2_phi,
1934                    decoder,
1935                    latent_dim: 2,
1936                }
1937            };
1938            let logits = vec![0.4 + 0.21 * row_seed, -0.7 + 0.13 * row_seed];
1939            let e: Vec<f64> = logits.iter().map(|&l| (l * inv_tau).exp()).collect();
1940            let s: f64 = e.iter().sum();
1941            let gate_value: Vec<f64> = e.iter().map(|&v| v / s).collect();
1942            SaeReconstructionRowProgram {
1943                atoms: vec![mk_atom(0.0), mk_atom(1.0)],
1944                gate_value,
1945                logits,
1946                gate_scale: vec![1.0, 1.0],
1947                gate_shift: vec![0.0, 0.0],
1948                gate: RowGate::Softmax { inv_tau },
1949                logit_slot: vec![Some(0), Some(1)],
1950                coord_slot: vec![vec![2, 3], vec![4, 5]],
1951                n_primaries: 6,
1952            }
1953        };
1954        [mk(0.0), mk(1.0), mk(2.0), mk(3.0)]
1955    }
1956
1957    /// SIMD-batch bit-identity oracle: `reconstruction_all_columns_batch4` lane
1958    /// `i` is `to_bits`-identical to the scalar `reconstruction_all_columns_packed`
1959    /// on row `i`, across many temperatures and randomized per-row data
1960    /// (≥2000 channel comparisons). The 4-row SIMD pass changes only how many
1961    /// rows share one instruction stream, never the arithmetic.
1962    #[test]
1963    fn batch4_reconstruction_bit_identical_to_per_row() {
1964        let mut comparisons = 0usize;
1965        for tau in [0.7_f64, 0.9, 1.1, 1.3, 1.7, 2.1, 2.9] {
1966            let rows = softmax_batch_fixture(tau);
1967            let refs = [&rows[0], &rows[1], &rows[2], &rows[3]];
1968            let batch = SaeReconstructionRowProgram::reconstruction_all_columns_batch4::<6>(refs)
1969                .expect("softmax-aligned rows must batch");
1970            for lane in 0..LANES {
1971                let per = rows[lane].reconstruction_all_columns_packed::<6>();
1972                assert_eq!(per.len(), batch[lane].len());
1973                for (c, (b, p)) in batch[lane].iter().zip(per.iter()).enumerate() {
1974                    assert_eq!(
1975                        b.value().to_bits(),
1976                        p.value().to_bits(),
1977                        "tau {tau} lane {lane} col {c} value"
1978                    );
1979                    let (bg, pg) = (b.g(), p.g());
1980                    let (bh, ph) = (b.h(), p.h());
1981                    for a in 0..6 {
1982                        assert_eq!(bg[a].to_bits(), pg[a].to_bits(), "lane {lane} col {c} g[{a}]");
1983                        for d in 0..6 {
1984                            assert_eq!(
1985                                bh[a][d].to_bits(),
1986                                ph[a][d].to_bits(),
1987                                "lane {lane} col {c} h[{a}][{d}]"
1988                            );
1989                            comparisons += 1;
1990                        }
1991                    }
1992                }
1993            }
1994        }
1995        assert!(comparisons >= 2000, "oracle ran {comparisons} comparisons");
1996    }
1997
1998    /// SIMD-batch bit-identity oracle for the β-border first-order path:
1999    /// `beta_border_order1_batch4` lane `i` is `to_bits`-identical to
2000    /// `beta_border_order1_packed` on row `i`.
2001    #[test]
2002    fn batch4_beta_border_bit_identical_to_per_row() {
2003        let mut comparisons = 0usize;
2004        for tau in [0.7_f64, 0.9, 1.1, 1.3, 1.7, 2.1, 2.9] {
2005            let rows = softmax_batch_fixture(tau);
2006            let refs = [&rows[0], &rows[1], &rows[2], &rows[3]];
2007            let mut chans: Vec<(usize, usize)> = Vec::new();
2008            for atom in 0..rows[0].atoms.len() {
2009                for b in 0..rows[0].atoms[atom].n_basis() {
2010                    chans.push((atom, b));
2011                }
2012            }
2013            chans.push(chans[0]); // repeat to exercise gate-cache reuse
2014            let batch =
2015                SaeReconstructionRowProgram::beta_border_order1_batch4::<6>(refs, &chans)
2016                    .expect("softmax-aligned rows must batch");
2017            for lane in 0..LANES {
2018                let per = rows[lane].beta_border_order1_packed::<6>(&chans);
2019                assert_eq!(per.len(), batch[lane].len());
2020                for (i, (b, p)) in batch[lane].iter().zip(per.iter()).enumerate() {
2021                    assert_eq!(b.value().to_bits(), p.value().to_bits(), "lane {lane} chan {i} v");
2022                    let (bg, pg) = (b.g(), p.g());
2023                    for a in 0..6 {
2024                        assert_eq!(
2025                            bg[a].to_bits(),
2026                            pg[a].to_bits(),
2027                            "lane {lane} chan {i} g[{a}]"
2028                        );
2029                        comparisons += 1;
2030                    }
2031                }
2032            }
2033        }
2034        assert!(comparisons >= 1000, "oracle ran {comparisons} comparisons");
2035    }
2036
2037    /// A non-softmax (per-atom logistic) batch must DECLINE (return `None`) so the
2038    /// caller falls back to the scalar per-row path — the logistic branch is
2039    /// per-row data-dependent and not lane-uniform.
2040    #[test]
2041    fn batch4_declines_non_softmax() {
2042        let inv_tau = 1.1;
2043        let mk = || SaeReconstructionRowProgram {
2044            atoms: vec![AtomRowBasisJet {
2045                phi: vec![1.0],
2046                d_phi: vec![vec![0.0]],
2047                d2_phi: vec![vec![vec![0.0]]],
2048                decoder: vec![vec![1.0]],
2049                latent_dim: 1,
2050            }],
2051            gate_value: vec![0.6],
2052            logits: vec![0.6],
2053            gate_scale: vec![1.0],
2054            gate_shift: vec![0.2],
2055            gate: RowGate::PerAtomLogistic { inv_tau },
2056            logit_slot: vec![Some(0)],
2057            coord_slot: vec![vec![1]],
2058            n_primaries: 2,
2059        };
2060        let rows = [mk(), mk(), mk(), mk()];
2061        let refs = [&rows[0], &rows[1], &rows[2], &rows[3]];
2062        assert!(
2063            SaeReconstructionRowProgram::reconstruction_all_columns_batch4::<2>(refs).is_none()
2064        );
2065    }
2066
2067    /// #932 perf pin: the gate-HOISTED batched β border jets
2068    /// (`beta_border_towers_packed`) are BIT-IDENTICAL to per-channel
2069    /// `beta_border_tower_packed`, including when several channels share an atom
2070    /// (the gate-cache reuse path).
2071    #[test]
2072    fn hoisted_beta_border_bit_identical_to_per_channel() {
2073        let (prog, _inv_tau) = softmax_fixture(1.1);
2074        // Build a channel list that repeats atoms (exercises the gate cache).
2075        let mut chans: Vec<(usize, usize)> = Vec::new();
2076        for atom in 0..prog.atoms.len() {
2077            for basis_col in 0..prog.atoms[atom].n_basis() {
2078                chans.push((atom, basis_col));
2079            }
2080        }
2081        // Duplicate the first atom's channels at the end to force cache reuse.
2082        if let Some(&first) = chans.first() {
2083            chans.push(first);
2084        }
2085        let batched = prog.beta_border_towers_packed::<6>(&chans);
2086        assert_eq!(batched.len(), chans.len());
2087        for (i, &(atom, basis_col)) in chans.iter().enumerate() {
2088            let per = prog.beta_border_tower_packed::<6>(atom, basis_col);
2089            let b = batched[i];
2090            assert_eq!(b.value(), per.value(), "chan {i} value");
2091            for a in 0..6 {
2092                assert_eq!(b.g()[a], per.g()[a], "chan {i} g[{a}]");
2093            }
2094        }
2095    }
2096}