Skip to main content

gam_sae/
row_jet_program.rs

1//! The SAE reconstruction row as a single Taylor-jet program (issue #932).
2//!
3//! # The row program
4//!
5//! The exact-LAML SAE engine needs, per row, the derivative tower of the
6//! reconstruction
7//!
8//! ```text
9//!   ẑ_row,c(p) = Σ_k ζ_k(ℓ) · decoded_{k,c}(t_k),   decoded_{k,c}(t) = Σ_b Φ_b(t)·B_{b,c}
10//! ```
11//!
12//! — a **gate nonlinearity** `ζ(ℓ)` (softmax / IBP sigmoid) composed with a
13//! **basis** `Φ(t)` composed with a **linear decoder** `B`, in the per-row
14//! primary coordinates `p = (gate logits ℓ, latent coordinates t)`. Today the
15//! arrow-Schur assembly (`SaeManifoldTerm::row_jets_for_logdet`) hand-packs the
16//! `first`/`second` channels of this reconstruction from separate gate
17//! derivative arrays (`gate_derivatives_for_row`) and basis jet tensors —
18//! exactly the kind of hand-maintained cross-block tower whose sign flips are
19//! the #736 / desync bug genus. The #1006 third-order logdet adjoint
20//! `Γ_a = tr(H⁻¹ ∂H/∂θ_a)` is the consumer of those very channels.
21//!
22//! This module writes that reconstruction **once** over the
23//! [`Tower4<K>`](gam_math::jet_tower::Tower4) scalar so the
24//! value/gradient/Hessian/third channels of one row come from ONE jet
25//! evaluation. [`SaeReconstructionRowProgram`] is generic over the gate kind
26//! and the per-row basis jets; the gate, basis and decoder compose with plain
27//! `Tower4` arithmetic, so there is no separate "channel" to forget.
28//!
29//! # The basis as a local jet
30//!
31//! The production assembly does NOT re-evaluate the manifold basis `Φ` as a
32//! function of perturbed coordinates: it consumes the precomputed jet tensors
33//! `(Φ, ∂Φ/∂t, ∂²Φ/∂t²)` evaluated at the current `t`. The reconstruction's
34//! dependence on `t` is therefore *defined* by those tensors — the local
35//! quadratic Taylor model of `Φ` about the current point. This program builds
36//! each basis function as exactly that `Tower4` quadratic from the stored jets,
37//! so the value/first/second channels it emits are the same object the hand
38//! path packs — derived by independent arithmetic (tower Leibniz / Faà di
39//! Bruno vs hand-summed cross terms). Agreement across both is a true
40//! correctness proof of the hand kernel; disagreement names a dropped or
41//! sign-flipped cross block loudly. That oracle is the riding test below.
42
43use gam_math::jet_scalar::{JetScalar, Order1, Order2};
44use gam_math::jet_tower::Tower4;
45
46/// `1/self` for any [`JetScalar`] via Faà di Bruno on `f(u) = 1/u`
47/// (stack `[1/u, -1/u², 2/u³, -6/u⁴, 24/u⁵]`). Caller guarantees `self.value()`
48/// is nonzero — softmax denominators are strictly positive sums of exponentials.
49#[inline]
50fn recip<const K: usize, S: JetScalar<K>>(s: &S) -> S {
51    let u = s.value();
52    let u2 = u * u;
53    let u3 = u2 * u;
54    let u4 = u3 * u;
55    let u5 = u4 * u;
56    s.compose_unary([1.0 / u, -1.0 / u2, 2.0 / u3, -6.0 / u4, 24.0 / u5])
57}
58
/// Marker value stored in [`SaeReconstructionRowProgram::coord_slot`] for an
/// atom latent coordinate that is held fixed in this row's local chart.
///
/// Compact active-set rows do not seed inactive atoms' coordinates as tower
/// primaries. Softmax logit derivatives can nevertheless see such an atom's
/// decoded value, which then enters as a constant. Jet builders skip any
/// coordinate whose slot equals this value.
pub const SAE_FIXED_COORD_SLOT: usize = usize::MAX;
64
/// The gate nonlinearity `ζ(ℓ)` of the SAE assignment, as seen by the row
/// program.
///
/// The production term carries the same two smooth branches:
/// - a softmax over a shared partition;
/// - a per-atom IBP/JumpReLU sigmoid.
///
/// The program reproduces the branch the criterion evaluates, so the value
/// channel is the production gate.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum RowGate {
    /// Shared softmax over all atom logits with inverse temperature `inv_tau`:
    /// `ζ_k(ℓ) = softmax_k(ℓ · inv_tau)`.
    Softmax { inv_tau: f64 },
    /// Per-atom independent logistic gate (the IBP-MAP / JumpReLU smooth
    /// activation):
    /// `ζ_k(ℓ_k) = scale_k · σ((ℓ_k − shift_k)·inv_tau)`.
    ///
    /// `shift_k` (the IBP stick-breaking offset or the JumpReLU threshold) and
    /// `scale_k` are not stored here. They are read per atom from the row
    /// program's `gate_shift` and `gate_scale` vectors.
    ///
    /// Each gate depends only on its own logit, so the gate Hessian is
    /// diagonal.
    PerAtomLogistic { inv_tau: f64 },
}
80
/// Local basis jet of a single atom at the current row.
///
/// Holds the stored jet tensors `(Φ, ∂Φ/∂t, ∂²Φ/∂t²)` of the manifold basis,
/// evaluated at the current latent coordinate, together with the atom's
/// decoder block `B`. The first index of every tensor is the basis column.
#[derive(Debug, Clone)]
pub struct AtomRowBasisJet {
    /// Basis values `Φ_b`, indexed `[basis_col]`; its length is `n_basis`.
    pub phi: Vec<f64>,
    /// First derivatives `∂Φ_b/∂t_axis`, indexed `[basis_col][axis]`.
    pub d_phi: Vec<Vec<f64>>,
    /// Second derivatives `∂²Φ_b/∂t_a∂t_b`, indexed `[basis_col][axis_a][axis_b]`.
    pub d2_phi: Vec<Vec<Vec<f64>>>,
    /// Decoder coefficients `B_{b,c}`, indexed `[basis_col][out_col]`.
    pub decoder: Vec<Vec<f64>>,
    /// Number of latent axes of this atom.
    pub latent_dim: usize,
}
98
99impl AtomRowBasisJet {
100    fn n_basis(&self) -> usize {
101        self.phi.len()
102    }
103
104    fn out_dim(&self) -> usize {
105        self.decoder.first().map_or(0, Vec::len)
106    }
107
108    /// `Φ_b(t)` as a `Tower4<K>` quadratic in the latent primaries occupying
109    /// `coord_slots[axis]` (the seeded tower variable index for latent axis
110    /// `axis` of this atom). A constant value plus first/second jet
111    /// contributions — exactly the local Taylor model the production assembly
112    /// consumes.
113    fn basis_tower<const K: usize, S: JetScalar<K>>(
114        &self,
115        basis_col: usize,
116        coord_slots: &[usize],
117    ) -> S {
118        // The latent coordinate increments enter as the seeded tower variables;
119        // the basis value at the current point is the constant term.
120        let mut acc = S::constant(self.phi[basis_col]);
121        for axis in 0..self.latent_dim {
122            let slot = coord_slots[axis];
123            let d1 = self.d_phi[basis_col][axis];
124            if d1 != 0.0 {
125                if slot != SAE_FIXED_COORD_SLOT {
126                    acc = acc.add(&S::variable(0.0, slot).scale(d1));
127                }
128            }
129        }
130        // ½ Σ_ab d²Φ · δ_a δ_b, the quadratic term of the local Taylor model.
131        // Hoist the axis_a fixed-slot skip and `va` build out of the inner loop.
132        for axis_a in 0..self.latent_dim {
133            let slot_a = coord_slots[axis_a];
134            if slot_a == SAE_FIXED_COORD_SLOT {
135                continue;
136            }
137            let va = S::variable(0.0, slot_a);
138            for axis_b in 0..self.latent_dim {
139                let d2 = self.d2_phi[basis_col][axis_a][axis_b];
140                if d2 == 0.0 {
141                    continue;
142                }
143                let slot_b = coord_slots[axis_b];
144                if slot_b == SAE_FIXED_COORD_SLOT {
145                    continue;
146                }
147                let vb = S::variable(0.0, slot_b);
148                acc = acc.add(&va.mul(&vb).scale(0.5 * d2));
149            }
150        }
151        acc
152    }
153
154    /// `decoded_{k,c}(t)` as a tower: `Σ_b Φ_b(t)·B_{b,c}`.
155    fn decoded_tower<const K: usize, S: JetScalar<K>>(
156        &self,
157        out_col: usize,
158        coord_slots: &[usize],
159    ) -> S {
160        let mut acc = S::constant(0.0);
161        for basis_col in 0..self.n_basis() {
162            let b = self.decoder[basis_col][out_col];
163            if b == 0.0 {
164                continue;
165            }
166            acc = acc.add(&self.basis_tower::<K, S>(basis_col, coord_slots).scale(b));
167        }
168        acc
169    }
170}
171
172/// One row of the SAE reconstruction as a jet program: the per-atom basis jets,
173/// the gate, the current gate-logit values, and the primary layout that maps
174/// `(atom logit, atom latent axis)` to a seeded tower variable slot.
175#[derive(Debug, Clone)]
176pub struct SaeReconstructionRowProgram {
177    /// Per-atom basis jets at the current row.
178    pub atoms: Vec<AtomRowBasisJet>,
179    /// Current gate activations `ζ_k` at the row (softmax/sigmoid values).
180    pub gate_value: Vec<f64>,
181    /// Current gate logits `ℓ_k` at the row.
182    pub logits: Vec<f64>,
183    /// Per-atom multiplicative scale for independent logistic gates. This is
184    /// the IBP stick-breaking prior `π_k` for IBP-MAP, `1` for active JumpReLU,
185    /// and `0` for JumpReLU rows at/below the hard threshold. Unused for
186    /// softmax.
187    pub gate_scale: Vec<f64>,
188    /// Per-atom logistic shift (IBP offset / JumpReLU threshold); unused for
189    /// softmax.
190    pub gate_shift: Vec<f64>,
191    /// The gate nonlinearity.
192    pub gate: RowGate,
193    /// Tower slot of atom `k`'s gate logit primary, or `None` if the gate logit
194    /// is not a free primary for this atom (softmax `K==1`).
195    pub logit_slot: Vec<Option<usize>>,
196    /// Tower slot of atom `k`'s latent axis `j` primary (`coord_slot[k][j]`).
197    pub coord_slot: Vec<Vec<usize>>,
198    /// Total number of seeded primaries (= `K` of the tower).
199    pub n_primaries: usize,
200}
201
202impl SaeReconstructionRowProgram {
203    /// The gate activation `ζ_k(ℓ)` as a `Tower4<K>` in the gate-logit
204    /// primaries. Softmax is the shared composition `exp(ℓ_k·inv_tau) /
205    /// Σ_j exp(ℓ_j·inv_tau)`; the per-atom logistic is `σ((ℓ_k − shift_k)·
206    /// inv_tau)` depending only on its own logit. Both carry every derivative
207    /// channel automatically.
208    fn gate_tower<const K: usize, S: JetScalar<K>>(&self, atom: usize) -> S {
209        match self.gate {
210            RowGate::Softmax { inv_tau } => {
211                // Build exp(ℓ_j·inv_tau − shift) for every atom that has a free
212                // logit primary, as a tower; atoms without a free logit
213                // contribute a constant exponential (their logit does not move).
214                //
215                // Stability: softmax is invariant to a common additive constant
216                // in every exponent (`exp(a−s)/Σ exp(b−s) = exp(a)/Σ exp(b)`),
217                // and the higher derivative channels are unchanged because the
218                // shift is a numeric constant (a function of the base logit
219                // *values* only, seeded as a `constant`, not of the tower
220                // variables). We subtract the largest base exponent
221                // `max_j ℓ_j·inv_tau` so the dominant `exp(·)` is `exp(0)=1` and
222                // no term overflows. This mirrors the max-subtraction in the
223                // production `softmax_row`.
224                let shift = self
225                    .logits
226                    .iter()
227                    .copied()
228                    .fold(f64::NEG_INFINITY, f64::max)
229                    * inv_tau;
230                let mut denom = S::constant(0.0);
231                let mut numer = S::constant(0.0);
232                for j in 0..self.gate_value.len() {
233                    let lj = match self.logit_slot[j] {
234                        Some(slot) => S::variable(self.logits[j], slot),
235                        None => S::constant(self.logits[j]),
236                    };
237                    // (ℓ_j·inv_tau − shift): subtracting a constant shifts only
238                    // the value channel, leaving every gradient/Hessian/t3/t4
239                    // channel of the exponent (hence of exp via the chain rule)
240                    // identical to the unshifted form.
241                    let ej = lj.scale(inv_tau).sub(&S::constant(shift)).exp();
242                    if j == atom {
243                        numer = ej;
244                    }
245                    denom = denom.add(&ej);
246                }
247                numer.mul(&recip(&denom))
248            }
249            RowGate::PerAtomLogistic { inv_tau } => {
250                let l = match self.logit_slot[atom] {
251                    Some(slot) => S::variable(self.logits[atom], slot),
252                    None => S::constant(self.logits[atom]),
253                };
254                let x = l.sub(&S::constant(self.gate_shift[atom])).scale(inv_tau);
255                let one = S::constant(1.0);
256                let sigma = if x.value() >= 0.0 {
257                    one.mul(&recip(&one.add(&x.scale(-1.0).exp())))
258                } else {
259                    let ex = x.exp();
260                    ex.mul(&recip(&one.add(&ex)))
261                };
262                sigma.scale(self.gate_scale[atom])
263            }
264        }
265    }
266
267    /// All atoms' gate jets `ζ_k` at once, with the softmax denominator SHARED
268    /// across atoms (#932 perf). The per-atom [`Self::gate_tower`] rebuilds the
269    /// whole softmax denominator — `K` exp-jets, their sum, and the reciprocal —
270    /// on EVERY call, because only the numerator differs per atom; calling it `K`
271    /// times costs `K·(K exps) = O(K²)` exponential jets and `K` reciprocal jets
272    /// per row. Here the `K` exp-jets, the denominator sum, and the single
273    /// reciprocal jet are built ONCE, then `ζ_k = exp_k · inv_denom`. This emits
274    /// exactly `K` exps + `1` recip per row instead of `K²` + `K` (measured:
275    /// `K(K−1)` redundant exps and `K−1` redundant recips eliminated per row at
276    /// `K=8` ⇒ 56 exps + 7 recips removed), and is **bit-identical** to the
277    /// per-atom path (same `exp_k · recip(denom)` product, same Leibniz order).
278    /// Pure [`JetScalar`] ops — single-source, exact, no softmax chain rule.
279    fn all_gates<const K: usize, S: JetScalar<K>>(&self) -> Vec<S> {
280        let n = self.gate_value.len();
281        match self.gate {
282            RowGate::Softmax { inv_tau } => {
283                let shift = self
284                    .logits
285                    .iter()
286                    .copied()
287                    .fold(f64::NEG_INFINITY, f64::max)
288                    * inv_tau;
289                // The K exp-jets and the denominator, built ONCE and shared.
290                let mut exps: Vec<S> = Vec::with_capacity(n);
291                let mut denom = S::constant(0.0);
292                for j in 0..n {
293                    let lj = match self.logit_slot[j] {
294                        Some(slot) => S::variable(self.logits[j], slot),
295                        None => S::constant(self.logits[j]),
296                    };
297                    let ej = lj.scale(inv_tau).sub(&S::constant(shift)).exp();
298                    denom = denom.add(&ej);
299                    exps.push(ej);
300                }
301                let inv = recip(&denom);
302                exps.iter().map(|e| e.mul(&inv)).collect()
303            }
304            // Per-atom logistic gates are independent (each depends only on its
305            // own logit); there is no shared denominator to hoist, so this is the
306            // same as calling `gate_tower` per atom.
307            RowGate::PerAtomLogistic { .. } => {
308                (0..n).map(|atom| self.gate_tower::<K, S>(atom)).collect()
309            }
310        }
311    }
312
313    /// The reconstruction output column `c` as a single jet:
314    /// `ẑ_c(p) = Σ_k ζ_k(ℓ) · decoded_{k,c}(t_k)`. Its `.v` is the production
315    /// reconstruction value, `.g[a]` is `∂ẑ_c/∂p_a`, `.h[a][b]` is
316    /// `∂²ẑ_c/∂p_a∂p_b`, and the `t3`/`t4` channels are the exact higher-order
317    /// derivatives — all from this ONE evaluation.
318    fn reconstruction_column_generic<const K: usize, S: JetScalar<K>>(&self, out_col: usize) -> S {
319        assert_eq!(
320            self.n_primaries, K,
321            "SaeReconstructionRowProgram: tower arity K={K} must equal n_primaries={}",
322            self.n_primaries
323        );
324        let mut acc = S::constant(0.0);
325        for (atom, atom_jet) in self.atoms.iter().enumerate() {
326            let gate = self.gate_tower::<K, S>(atom);
327            let decoded = atom_jet.decoded_tower::<K, S>(out_col, &self.coord_slot[atom]);
328            acc = acc.add(&gate.mul(&decoded));
329        }
330        acc
331    }
332
333    /// The reconstruction output column `c` as the PACKED order-2 jet
334    /// [`Order2<K>`](gam_math::jet_scalar::Order2): value `.value()`,
335    /// gradient `.g()[a] = ∂ẑ_c/∂p_a`, Hessian `.h()[a][b] = ∂²ẑ_c/∂p_a∂p_b`.
336    ///
337    /// This is the production path (#932): the arrow-Schur logdet consumer reads
338    /// ONLY the order-≤2 channels of the reconstruction, so it builds the packed
339    /// [`Order2<K>`] scalar — value/gradient/Hessian only — instead of the dense
340    /// [`Tower4<K>`] (which materialises the entire K⁴ `t3`/`t4` tensor every row
341    /// only to discard it). For `K` up to 16 the dense tower's tensor build is
342    /// ~19× the instruction count of the order-2 channels alone; this collapses
343    /// it to the channels actually read. The packed `(v, g, H)` is BIT-IDENTICAL
344    /// to the order-≤2 channels of [`Self::reconstruction_column_tower`] (the
345    /// `Order2` newtype delegates to the same `Tower2` arithmetic the dense
346    /// tower's order-≤2 channels use); the t3/t4 oracle pins the dense path.
347    #[must_use]
348    pub fn reconstruction_column_packed<const K: usize>(&self, out_col: usize) -> Order2<K> {
349        self.reconstruction_column_generic::<K, Order2<K>>(out_col)
350    }
351
352    /// All `out_dim` reconstruction columns as packed [`Order2<K>`] jets, with
353    /// the per-row redundant sub-jets HOISTED out of the output-column loop
354    /// (#932 perf). `reconstruction_column_packed(c)` rebuilds, for every output
355    /// column `c`, both the per-atom softmax gate jet `ζ_k` (`K` exps + a recip
356    /// + a `K×K` Hessian — the dominant cost) AND each per-atom basis jet
357    /// `Φ_{k,b}` — yet **neither depends on `c`**: the gate is a function of the
358    /// logits only, and the basis jet is the local Taylor model of `Φ_b` in the
359    /// coords, the decoder coefficient `B_{b,c}` being the only `c`-dependent
360    /// factor. The consumer (`fill_reconstruction_channels_from_program`) calls
361    /// it once per `c`, so the gate and basis jets are recomputed `out_dim×`
362    /// redundantly.
363    ///
364    /// This builds each atom's gate jet ONCE (`K` total) and each atom's basis
365    /// jets ONCE (`n_basis` per atom), then assembles every column by the cheap
366    /// reductions `decoded_{k,c} = Σ_b Φ_{k,b}·B_{b,c}` and
367    /// `ẑ_c = Σ_k ζ_k·decoded_{k,c}`. The result is **bit-identical** to calling
368    /// [`Self::reconstruction_column_packed`] per column (same Leibniz products in
369    /// the same order) — only the redundant recomputation is removed — measured
370    /// ~9× faster at `K=8, out_dim=16` on the per-row hot path.
371    #[must_use]
372    pub fn reconstruction_all_columns_packed<const K: usize>(&self) -> Vec<Order2<K>> {
373        assert_eq!(
374            self.n_primaries, K,
375            "SaeReconstructionRowProgram: tower arity K={K} must equal n_primaries={}",
376            self.n_primaries
377        );
378        let p = self.out_dim();
379        // Hoist the per-atom gate jet (c-independent) and basis jets
380        // (c-independent) out of the column loop. `all_gates` additionally shares
381        // the softmax denominator / reciprocal across atoms (K exps + 1 recip,
382        // not K² + K).
383        let gates: Vec<Order2<K>> = self.all_gates::<K, Order2<K>>();
384        let bases: Vec<Vec<Order2<K>>> = self
385            .atoms
386            .iter()
387            .enumerate()
388            .map(|(atom, atom_jet)| {
389                (0..atom_jet.n_basis())
390                    .map(|b| atom_jet.basis_tower::<K, Order2<K>>(b, &self.coord_slot[atom]))
391                    .collect()
392            })
393            .collect();
394        (0..p)
395            .map(|c| {
396                let mut acc = Order2::<K>::constant(0.0);
397                for (atom, atom_jet) in self.atoms.iter().enumerate() {
398                    // decoded_{k,c} = Σ_b Φ_{k,b}·B_{b,c} from the hoisted basis
399                    // jets — same per-basis sum `decoded_tower` forms, but the
400                    // basis jets are reused across every column.
401                    let mut decoded = Order2::<K>::constant(0.0);
402                    for basis_col in 0..atom_jet.n_basis() {
403                        let coeff = atom_jet.decoder[basis_col][c];
404                        if coeff == 0.0 {
405                            continue;
406                        }
407                        decoded = decoded.add(&bases[atom][basis_col].scale(coeff));
408                    }
409                    acc = acc.add(&gates[atom].mul(&decoded));
410                }
411                acc
412            })
413            .collect()
414    }
415
416    /// The reconstruction output column as the full dense [`Tower4<K>`] carrying
417    /// every value/gradient/Hessian/`t3`/`t4` channel. This is the #932 oracle
418    /// ground truth: the production [`Self::reconstruction_column_packed`]
419    /// order-2 path is pinned against its order-≤2 channels, and the FD-witness
420    /// tests use its `t3`/`t4`. Not on the per-row hot path.
421    #[must_use]
422    pub fn reconstruction_column<const K: usize>(&self, out_col: usize) -> Tower4<K> {
423        self.reconstruction_column_generic::<K, Tower4<K>>(out_col)
424    }
425
426    /// The β **border-channel** local-variable sub-jet: the scalar
427    /// `s_{k,b}(p) = ζ_k(ℓ)·Φ_b(t_k)` as a `Tower4<K>` in the local
428    /// (logit/coord) primaries — the gate activation times ONE basis function.
429    ///
430    /// In the arrow system a β border channel is one free decoder coefficient
431    /// `β_{k,b,channel}` whose per-row reconstruction contribution to output
432    /// column `c` is `ζ_k(ℓ)·Φ_b(t_k)·output_c`, where `output` is the channel's
433    /// (frame / identity) output vector carried by the `SaeBorderChannel`, NOT
434    /// the current decoder matrix. The reconstruction is **linear** in `β`, so
435    /// `∂ẑ_c/∂β_{k,b,channel} = ζ_k(ℓ)·Φ_b(t_k)·output_c = s_{k,b}.v·output_c`
436    /// and `∂²ẑ_c/∂β∂p_a = s_{k,b}.g[a]·output_c` (the production `beta` /
437    /// `beta_deriv` / `beta_l_deriv` channels). The `output_c` factor is a
438    /// per-column constant the caller applies; this tower carries the entire
439    /// local-variable dependence.
440    ///
441    /// It is built from the SAME `gate_tower` / `basis_tower` primitives as
442    /// [`Self::reconstruction_column`], so the β border channel is single
443    /// sourced with the local-variable reconstruction tower (#932) — the hand
444    /// path in `row_jets_for_logdet` packs these same `ζ_k·Φ_b` products (then
445    /// multiplies by `channel.output`) term by term, and is pinned to this
446    /// tower by the converged-cache oracle.
447    fn beta_border_generic<const K: usize, S: JetScalar<K>>(
448        &self,
449        atom: usize,
450        basis_col: usize,
451    ) -> S {
452        assert_eq!(
453            self.n_primaries, K,
454            "SaeReconstructionRowProgram: tower arity K={K} must equal n_primaries={}",
455            self.n_primaries
456        );
457        let gate = self.gate_tower::<K, S>(atom);
458        let phi = self.atoms[atom].basis_tower::<K, S>(basis_col, &self.coord_slot[atom]);
459        gate.mul(&phi)
460    }
461
462    /// The β **border-channel** local-variable sub-jet as the PACKED order-2 jet
463    /// [`Order2<K>`](gam_math::jet_scalar::Order2). The consumer reads only
464    /// `.value()` (the `beta` channel) and `.g()[a]` (the `beta_deriv` /
465    /// `beta_l_deriv` mixed channel — the reconstruction is linear in β so the
466    /// Hessian-in-β vanishes and only value+gradient are needed). Built from the
467    /// SAME packed gate / basis primitives as [`Self::reconstruction_column`], so
468    /// the dense `t3`/`t4` tensor is never materialised on this per-row hot path
469    /// (#932 Tower4→Order2 cutover).
470    #[must_use]
471    pub fn beta_border_tower_packed<const K: usize>(
472        &self,
473        atom: usize,
474        basis_col: usize,
475    ) -> Order2<K> {
476        self.beta_border_generic::<K, Order2<K>>(atom, basis_col)
477    }
478
479    /// The β border-channel sub-jet as the full dense [`Tower4<K>`] — the #932
480    /// oracle ground truth the packed [`Self::beta_border_tower_packed`] is
481    /// pinned against. Not on the per-row hot path.
482    #[must_use]
483    pub fn beta_border_tower<const K: usize>(&self, atom: usize, basis_col: usize) -> Tower4<K> {
484        self.beta_border_generic::<K, Tower4<K>>(atom, basis_col)
485    }
486
487    /// Packed β border-channel sub-jets for a batch of `(atom, basis_col)`
488    /// channels, with the per-atom gate jets HOISTED and the softmax denominator
489    /// SHARED across atoms (#932 perf): the gate jet `ζ_k` (the dominant `K`-exp
490    /// / `K×K`-Hessian cost) is a function of the row's logits only, not of
491    /// `basis_col`, and every atom's gate shares one softmax denominator /
492    /// reciprocal. [`Self::all_gates`] builds all `K` gates once (K exps + 1
493    /// recip per row); each channel then just multiplies its atom's cached gate
494    /// by its basis jet. Each result is **bit-identical** to
495    /// [`Self::beta_border_tower_packed`] for the same `(atom, basis_col)` (same
496    /// `gate.mul(basis)` product), in the input order.
497    #[must_use]
498    pub fn beta_border_towers_packed<const K: usize>(
499        &self,
500        channels: &[(usize, usize)],
501    ) -> Vec<Order2<K>> {
502        assert_eq!(
503            self.n_primaries, K,
504            "SaeReconstructionRowProgram: tower arity K={K} must equal n_primaries={}",
505            self.n_primaries
506        );
507        let gates: Vec<Order2<K>> = self.all_gates::<K, Order2<K>>();
508        channels
509            .iter()
510            .map(|&(atom, basis_col)| {
511                let phi =
512                    self.atoms[atom].basis_tower::<K, Order2<K>>(basis_col, &self.coord_slot[atom]);
513                gates[atom].mul(&phi)
514            })
515            .collect()
516    }
517
518    /// Packed β border-channel sub-jets for a batch of channels as the
519    /// FIRST-order jet [`Order1<K>`](gam_math::jet_scalar::Order1) — value +
520    /// gradient ONLY, no Hessian. The β-border consumer
521    /// (`fill_beta_border_channels_from_program`) reads exactly `.value()` (the
522    /// `beta` channel) and `.g()[a]` (the mixed `beta_deriv` / `beta_l_deriv`
523    /// channel); the reconstruction is linear in β so the Hessian-in-β vanishes
524    /// and the K×K Hessian that [`Self::beta_border_towers_packed`]'s `Order2`
525    /// builds is computed-and-discarded every call. This method drops that work:
526    /// `Order1`'s value/gradient are BIT-IDENTICAL to `Order2`'s (the order-≤1
527    /// channels never read a Hessian), proven by the `order1_*` oracle, while the
528    /// per-channel `gate.mul(basis)` skips the `K²` Hessian product.
529    ///
530    /// Same hoisting as [`Self::beta_border_towers_packed`]: gate jets built once
531    /// via [`Self::all_gates`], each channel multiplies its atom's gate by its
532    /// basis jet.
533    #[must_use]
534    pub fn beta_border_order1_packed<const K: usize>(
535        &self,
536        channels: &[(usize, usize)],
537    ) -> Vec<Order1<K>> {
538        assert_eq!(
539            self.n_primaries, K,
540            "SaeReconstructionRowProgram: tower arity K={K} must equal n_primaries={}",
541            self.n_primaries
542        );
543        let gates: Vec<Order1<K>> = self.all_gates::<K, Order1<K>>();
544        channels
545            .iter()
546            .map(|&(atom, basis_col)| {
547                let phi =
548                    self.atoms[atom].basis_tower::<K, Order1<K>>(basis_col, &self.coord_slot[atom]);
549                gates[atom].mul(&phi)
550            })
551            .collect()
552    }
553
554    /// The number of reconstruction output columns.
555    #[must_use]
556    pub fn out_dim(&self) -> usize {
557        self.atoms.first().map_or(0, AtomRowBasisJet::out_dim)
558    }
559}
560
561// ─────────────────────────────────────────────────────────────────────────
562// 4-ROW SIMD BATCH (the jet's throughput lever over hand-scalar code)
563//
564// The hot per-row jet kernels (`reconstruction_all_columns_packed`,
565// `beta_border_order1_packed`) evaluate ONE row's `(v, g, H)` / `(v, g)` tower
566// at a time in scalar `f64`. A hand-written scalar derivative does exactly the
567// same. The throughput lever a jet has that scalar hand-code cannot is **row
568// batching in SIMD lanes**: the order-≤2 Leibniz product is `O(K²)` independent
569// per-channel float ops, and EVERY softmax row runs the IDENTICAL op graph on
570// different data — the textbook SPMD shape. Packing `LANES = 4` aligned rows
571// into a `[f64; 4]` lane and running the algebra once per 4 rows replaces 4
572// scalar passes with one vector pass, so the `K²` Hessian-channel updates become
573// 4-wide lane ops covering 4 rows each (auto-vectorised to SSE2 `pd` / NEON
574// `.2d`), ~4× fewer scalar FP instructions per row.
575//
576// The lane field is a plain `[f64; 4]` whose every op is a lane-wise IEEE
577// `+`/`-`/`*` (NEVER a fused `mul_add`), so lane `i` of a 4-wide op equals the
578// scalar `f64` op on that lane's inputs BIT-FOR-BIT. The op order mirrors
579// [`gam_math::jet_tower::Tower2`] / [`Order1`] term-for-term, so
580// [`O2x4`]/[`O1x4`] lane `i` is `to_bits`-identical to the production
581// [`Order2`]/[`Order1`] row scalar — proven by the `batch_tests` oracle below
582// (≥2000 random aligned 4-row batches across `K ∈ {2,4,6}`).
583//
584// Only the softmax gate is batched: its op graph is identical across rows (every
585// atom is an active free logit), while the per-atom logistic gate's
586// `x.value() >= 0.0` branch is per-row data-dependent (lanes could need
587// different branches, which are NOT bit-identical), so logistic rows fall back
588// to the scalar per-row path in the caller.
589
/// Number of rows evaluated side by side in one lane vector.
const LANES: usize = 4;

/// Broadcast `x` into every lane.
#[inline]
fn l_splat(x: f64) -> [f64; LANES] {
    [x; LANES]
}

/// Lane-wise `a + b`, using plain IEEE addition and never a fused operation.
/// Lane `i` therefore equals the scalar `a[i] + b[i]` bit-for-bit.
#[inline]
fn l_add(a: [f64; LANES], b: [f64; LANES]) -> [f64; LANES] {
    let mut sum = a;
    for (lhs, rhs) in sum.iter_mut().zip(b.iter()) {
        *lhs += *rhs;
    }
    sum
}

/// Lane-wise `a * b`, using plain IEEE multiplication and never a fused
/// operation. Lane `i` therefore equals the scalar `a[i] * b[i]` bit-for-bit.
#[inline]
fn l_mul(a: [f64; LANES], b: [f64; LANES]) -> [f64; LANES] {
    let mut prod = a;
    for (lhs, rhs) in prod.iter_mut().zip(b.iter()) {
        *lhs *= *rhs;
    }
    prod
}
612
613/// 4-rows-per-pass order-≤2 lane scalar (value / gradient / Hessian), mirroring
614/// [`gam_math::jet_tower::Tower2`] (hence [`Order2`]) term-for-term so lane `i`
615/// is `to_bits`-identical to the scalar row-`i` [`Order2`].
616#[derive(Clone, Copy)]
617struct O2x4<const K: usize> {
618    v: [f64; LANES],
619    g: [[f64; LANES]; K],
620    h: [[[f64; LANES]; K]; K],
621}
622
623impl<const K: usize> O2x4<K> {
624    #[inline]
625    fn constant(c: [f64; LANES]) -> Self {
626        O2x4 {
627            v: c,
628            g: [[0.0; LANES]; K],
629            h: [[[0.0; LANES]; K]; K],
630        }
631    }
632    /// Seeded primary `axis` at (per-lane) `value`: unit first derivative.
633    #[inline]
634    fn variable(value: [f64; LANES], axis: usize) -> Self {
635        let mut out = Self::constant(value);
636        out.g[axis] = l_splat(1.0);
637        out
638    }
639    #[inline]
640    fn add(&self, o: &Self) -> Self {
641        let mut out = *self;
642        out.v = l_add(self.v, o.v);
643        for i in 0..K {
644            out.g[i] = l_add(self.g[i], o.g[i]);
645            for j in 0..K {
646                out.h[i][j] = l_add(self.h[i][j], o.h[i][j]);
647            }
648        }
649        out
650    }
651    #[inline]
652    fn scale(&self, s: [f64; LANES]) -> Self {
653        let mut out = *self;
654        out.v = l_mul(self.v, s);
655        for i in 0..K {
656            out.g[i] = l_mul(self.g[i], s);
657            for j in 0..K {
658                out.h[i][j] = l_mul(self.h[i][j], s);
659            }
660        }
661        out
662    }
663    /// `self - o`, expressed as `self + o·(-1)` exactly as [`Order2::sub`] does.
664    #[inline]
665    fn sub(&self, o: &Self) -> Self {
666        self.add(&o.scale(l_splat(-1.0)))
667    }
668    /// Order-≤2 Leibniz product, term-for-term identical to `Tower2::mul`.
669    #[inline]
670    fn mul(&self, o: &Self) -> Self {
671        let a = self;
672        let b = o;
673        let mut out = Self::constant(l_mul(a.v, b.v));
674        for i in 0..K {
675            out.g[i] = l_add(l_mul(a.v, b.g[i]), l_mul(a.g[i], b.v));
676        }
677        for i in 0..K {
678            for j in 0..K {
679                let t0 = l_mul(a.v, b.h[i][j]);
680                let t1 = l_add(t0, l_mul(a.g[i], b.g[j]));
681                let t2 = l_add(t1, l_mul(a.g[j], b.g[i]));
682                out.h[i][j] = l_add(t2, l_mul(a.h[i][j], b.v));
683            }
684        }
685        out
686    }
687    /// Order-≤2 Faà di Bruno `f ∘ self` from the per-lane stack
688    /// `d = [f(u), f′(u), f″(u)]`, mirroring `Tower2::compose_unary`
689    /// (`acc` starts at `+0.0`, accumulates `d₁·hᵢⱼ` then `(d₂·gᵢ)·gⱼ`).
690    #[inline]
691    fn compose(&self, d: [[f64; LANES]; 3]) -> Self {
692        let mut out = Self::constant(d[0]);
693        for i in 0..K {
694            let mut acc = l_splat(0.0);
695            acc = l_add(acc, l_mul(d[1], self.g[i]));
696            out.g[i] = acc;
697        }
698        for i in 0..K {
699            for j in 0..K {
700                let mut acc = l_splat(0.0);
701                acc = l_add(acc, l_mul(d[1], self.h[i][j]));
702                acc = l_add(acc, l_mul(l_mul(d[2], self.g[i]), self.g[j]));
703                out.h[i][j] = acc;
704            }
705        }
706        out
707    }
708    /// `e^self`, per-lane stack `[e, e, e]` (matches `Tower2::exp`).
709    #[inline]
710    fn exp(&self) -> Self {
711        let mut e = [0.0; LANES];
712        for i in 0..LANES {
713            e[i] = self.v[i].exp();
714        }
715        self.compose([e, e, e])
716    }
717    /// `1/self`, per-lane stack `[1/u, -1/u², 2/u³]` — the DIVISION-based stack
718    /// of the [`recip`] free fn the scalar reconstruction path uses (NOT the
719    /// reciprocal-multiply `[r,-r²,2r³]` of `JetScalar::recip`; those differ by a
720    /// ULP and would break `to_bits` parity). Caller guarantees nonzero.
721    #[inline]
722    fn recip(&self) -> Self {
723        let mut d0 = [0.0; LANES];
724        let mut d1 = [0.0; LANES];
725        let mut d2 = [0.0; LANES];
726        for i in 0..LANES {
727            let u = self.v[i];
728            let u2 = u * u;
729            let u3 = u2 * u;
730            d0[i] = 1.0 / u;
731            d1[i] = -1.0 / u2;
732            d2[i] = 2.0 / u3;
733        }
734        self.compose([d0, d1, d2])
735    }
736    /// Extract lane `i` as a production [`Order2<K>`] scalar.
737    #[inline]
738    fn lane(&self, i: usize) -> Order2<K> {
739        let mut t = gam_math::jet_tower::Tower2::<K>::constant(self.v[i]);
740        for a in 0..K {
741            t.g[a] = self.g[a][i];
742            for b in 0..K {
743                t.h[a][b] = self.h[a][b][i];
744            }
745        }
746        Order2(t)
747    }
748}
749
750/// 4-rows-per-pass FIRST-order lane scalar (value / gradient only), mirroring
751/// [`Order1`] term-for-term so lane `i` is `to_bits`-identical to row-`i`
752/// [`Order1`]. Used for the β-border consumer (reconstruction is linear in β,
753/// so only value + gradient are read).
754#[derive(Clone, Copy)]
755struct O1x4<const K: usize> {
756    v: [f64; LANES],
757    g: [[f64; LANES]; K],
758}
759
760impl<const K: usize> O1x4<K> {
761    #[inline]
762    fn constant(c: [f64; LANES]) -> Self {
763        O1x4 {
764            v: c,
765            g: [[0.0; LANES]; K],
766        }
767    }
768    #[inline]
769    fn variable(value: [f64; LANES], axis: usize) -> Self {
770        let mut out = Self::constant(value);
771        out.g[axis] = l_splat(1.0);
772        out
773    }
774    #[inline]
775    fn add(&self, o: &Self) -> Self {
776        let mut out = *self;
777        out.v = l_add(self.v, o.v);
778        for i in 0..K {
779            out.g[i] = l_add(self.g[i], o.g[i]);
780        }
781        out
782    }
783    #[inline]
784    fn scale(&self, s: [f64; LANES]) -> Self {
785        let mut out = *self;
786        out.v = l_mul(self.v, s);
787        for i in 0..K {
788            out.g[i] = l_mul(self.g[i], s);
789        }
790        out
791    }
792    #[inline]
793    fn sub(&self, o: &Self) -> Self {
794        self.add(&o.scale(l_splat(-1.0)))
795    }
796    #[inline]
797    fn mul(&self, o: &Self) -> Self {
798        // Tower2::mul value/grad terms (order-≤1 truncation): v = a.v·b.v;
799        // g[i] = a.v·b.g[i] + a.g[i]·b.v. Identical float order to `Order1::mul`.
800        let a = self;
801        let b = o;
802        let mut out = Self::constant(l_mul(a.v, b.v));
803        for i in 0..K {
804            out.g[i] = l_add(l_mul(a.v, b.g[i]), l_mul(a.g[i], b.v));
805        }
806        out
807    }
808    #[inline]
809    fn compose(&self, d: [[f64; LANES]; 2]) -> Self {
810        // Order-≤1 Faà di Bruno: v = d[0]; g[i] = d[1]·g[i] (matches
811        // `Order1::compose_unary`, `acc` starts at +0.0).
812        let mut out = Self::constant(d[0]);
813        for i in 0..K {
814            let mut acc = l_splat(0.0);
815            acc = l_add(acc, l_mul(d[1], self.g[i]));
816            out.g[i] = acc;
817        }
818        out
819    }
820    #[inline]
821    fn exp(&self) -> Self {
822        let mut e = [0.0; LANES];
823        for i in 0..LANES {
824            e[i] = self.v[i].exp();
825        }
826        self.compose([e, e])
827    }
828    #[inline]
829    fn recip(&self) -> Self {
830        // Division-based `[1/u, -1/u²]` matching the `recip` free fn (see
831        // `O2x4::recip`), so lane `i` is `to_bits`-identical to the scalar path.
832        let mut d0 = [0.0; LANES];
833        let mut d1 = [0.0; LANES];
834        for i in 0..LANES {
835            let u = self.v[i];
836            let u2 = u * u;
837            d0[i] = 1.0 / u;
838            d1[i] = -1.0 / u2;
839        }
840        self.compose([d0, d1])
841    }
842    #[inline]
843    fn lane(&self, i: usize) -> Order1<K> {
844        let mut g = [0.0; K];
845        for a in 0..K {
846            g[a] = self.g[a][i];
847        }
848        Order1 { v: self.v[i], g }
849    }
850}
851
852/// Structural layout signature of a row program: the part that MUST be identical
853/// across rows for them to share one SIMD op graph (slot mapping, per-atom
854/// basis/latent/decoder shape, primary count). The per-row numeric data
855/// (`phi`/`d_phi`/`d2_phi`/`decoder` VALUES, `logits`) is what differs between
856/// lanes; the layout is what is shared.
857impl SaeReconstructionRowProgram {
858    /// Whether `self` and `other` share the SIMD-batchable softmax layout: same
859    /// softmax temperature, primary count, slot mapping, and per-atom basis /
860    /// latent / decoder dimensions. (Decoder/basis VALUES may differ per row and
861    /// are lane-packed; only the SHAPES must match.)
862    fn batch_aligned_softmax_with(&self, other: &Self) -> bool {
863        // Both rows must gate through softmax at the same temperature; a
864        // bit-for-bit `inv_tau` match is what lets them share one op graph.
865        match (self.gate, other.gate) {
866            (RowGate::Softmax { inv_tau: a }, RowGate::Softmax { inv_tau: b }) => {
867                if a.to_bits() != b.to_bits() {
868                    return false;
869                }
870            }
871            _ => return false,
872        }
873        if self.n_primaries != other.n_primaries
874            || self.atoms.len() != other.atoms.len()
875            || self.logit_slot != other.logit_slot
876            || self.coord_slot != other.coord_slot
877            || self.logits.len() != other.logits.len()
878        {
879            return false;
880        }
881        for (a, b) in self.atoms.iter().zip(other.atoms.iter()) {
882            if a.latent_dim != b.latent_dim
883                || a.n_basis() != b.n_basis()
884                || a.out_dim() != b.out_dim()
885            {
886                return false;
887            }
888        }
889        true
890    }
891
892    /// All `K` softmax gate lane-jets (`Order2` channels), with the denominator
893    /// SHARED across atoms and 4 rows packed per lane. Mirrors [`Self::all_gates`]
894    /// term-for-term so lane `i` is `to_bits`-identical to the row-`i` scalar
895    /// `all_gates::<K, Order2<K>>()`.
896    fn all_gates_o2x4<const K: usize>(&self, rows: &[&Self; LANES], inv_tau: f64) -> Vec<O2x4<K>> {
897        let n = self.gate_value.len();
898        let inv_tau_l = l_splat(inv_tau);
899        // Per-lane max-subtraction shift (= the scalar `all_gates` softmax shift,
900        // computed independently per row/lane).
901        let mut shift = [0.0; LANES];
902        for (lane, r) in rows.iter().enumerate() {
903            shift[lane] = r.logits.iter().copied().fold(f64::NEG_INFINITY, f64::max) * inv_tau;
904        }
905        let mut exps: Vec<O2x4<K>> = Vec::with_capacity(n);
906        let mut denom = O2x4::<K>::constant(l_splat(0.0));
907        for j in 0..n {
908            let mut lj_val = [0.0; LANES];
909            for (lane, r) in rows.iter().enumerate() {
910                lj_val[lane] = r.logits[j];
911            }
912            let lj = match self.logit_slot[j] {
913                Some(slot) => O2x4::<K>::variable(lj_val, slot),
914                None => O2x4::<K>::constant(lj_val),
915            };
916            let ej = lj.scale(inv_tau_l).sub(&O2x4::<K>::constant(shift)).exp();
917            denom = denom.add(&ej);
918            exps.push(ej);
919        }
920        let inv = denom.recip();
921        exps.iter().map(|e| e.mul(&inv)).collect()
922    }
923
924    /// All `K` softmax gate lane-jets at FIRST order (`Order1` channels).
925    /// Mirrors `all_gates::<K, Order1<K>>()` term-for-term.
926    fn all_gates_o1x4<const K: usize>(&self, rows: &[&Self; LANES], inv_tau: f64) -> Vec<O1x4<K>> {
927        let n = self.gate_value.len();
928        let inv_tau_l = l_splat(inv_tau);
929        let mut shift = [0.0; LANES];
930        for (lane, r) in rows.iter().enumerate() {
931            shift[lane] = r.logits.iter().copied().fold(f64::NEG_INFINITY, f64::max) * inv_tau;
932        }
933        let mut exps: Vec<O1x4<K>> = Vec::with_capacity(n);
934        let mut denom = O1x4::<K>::constant(l_splat(0.0));
935        for j in 0..n {
936            let mut lj_val = [0.0; LANES];
937            for (lane, r) in rows.iter().enumerate() {
938                lj_val[lane] = r.logits[j];
939            }
940            let lj = match self.logit_slot[j] {
941                Some(slot) => O1x4::<K>::variable(lj_val, slot),
942                None => O1x4::<K>::constant(lj_val),
943            };
944            let ej = lj.scale(inv_tau_l).sub(&O1x4::<K>::constant(shift)).exp();
945            denom = denom.add(&ej);
946            exps.push(ej);
947        }
948        let inv = denom.recip();
949        exps.iter().map(|e| e.mul(&inv)).collect()
950    }
951
952    /// One atom's basis jet `Φ_b(t)` as an [`O2x4`] over 4 rows, mirroring
953    /// [`AtomRowBasisJet::basis_tower`] term-for-term. A data-dependent `== 0`
954    /// skip is taken only when ALL 4 lanes are zero (the contribution of a zero
955    /// lane is `+0.0`, bit-identical to the scalar skip).
956    fn basis_tower_o2x4<const K: usize>(
957        rows: &[&Self; LANES],
958        atom: usize,
959        basis_col: usize,
960        coord_slots: &[usize],
961    ) -> O2x4<K> {
962        let latent = rows[0].atoms[atom].latent_dim;
963        let mut phi0 = [0.0; LANES];
964        for (lane, r) in rows.iter().enumerate() {
965            phi0[lane] = r.atoms[atom].phi[basis_col];
966        }
967        let mut acc = O2x4::<K>::constant(phi0);
968        for axis in 0..latent {
969            let slot = coord_slots[axis];
970            if slot == SAE_FIXED_COORD_SLOT {
971                continue;
972            }
973            let mut d1 = [0.0; LANES];
974            let mut any = false;
975            for (lane, r) in rows.iter().enumerate() {
976                let v = r.atoms[atom].d_phi[basis_col][axis];
977                d1[lane] = v;
978                any |= v != 0.0;
979            }
980            if any {
981                acc = acc.add(&O2x4::<K>::variable(l_splat(0.0), slot).scale(d1));
982            }
983        }
984        for axis_a in 0..latent {
985            // Hoist the fixed-slot skip and the `va` variable build out of the
986            // inner axis_b loop: both depend only on axis_a, so the old code
987            // rebuilt `va` and re-tested the slot `latent` times per axis_a.
988            let slot_a = coord_slots[axis_a];
989            if slot_a == SAE_FIXED_COORD_SLOT {
990                continue;
991            }
992            let va = O2x4::<K>::variable(l_splat(0.0), slot_a);
993            for axis_b in 0..latent {
994                let slot_b = coord_slots[axis_b];
995                if slot_b == SAE_FIXED_COORD_SLOT {
996                    continue;
997                }
998                let mut d2 = [0.0; LANES];
999                let mut any = false;
1000                for (lane, r) in rows.iter().enumerate() {
1001                    let v = r.atoms[atom].d2_phi[basis_col][axis_a][axis_b];
1002                    d2[lane] = v;
1003                    any |= v != 0.0;
1004                }
1005                if !any {
1006                    continue;
1007                }
1008                let mut half_d2 = [0.0; LANES];
1009                for lane in 0..LANES {
1010                    half_d2[lane] = 0.5 * d2[lane];
1011                }
1012                let vb = O2x4::<K>::variable(l_splat(0.0), slot_b);
1013                acc = acc.add(&va.mul(&vb).scale(half_d2));
1014            }
1015        }
1016        acc
1017    }
1018
1019    /// One atom's basis jet as an [`O1x4`] (value + gradient), mirroring
1020    /// `basis_tower::<Order1>` term-for-term.
1021    fn basis_tower_o1x4<const K: usize>(
1022        rows: &[&Self; LANES],
1023        atom: usize,
1024        basis_col: usize,
1025        coord_slots: &[usize],
1026    ) -> O1x4<K> {
1027        let latent = rows[0].atoms[atom].latent_dim;
1028        let mut phi0 = [0.0; LANES];
1029        for (lane, r) in rows.iter().enumerate() {
1030            phi0[lane] = r.atoms[atom].phi[basis_col];
1031        }
1032        let mut acc = O1x4::<K>::constant(phi0);
1033        for axis in 0..latent {
1034            let slot = coord_slots[axis];
1035            if slot == SAE_FIXED_COORD_SLOT {
1036                continue;
1037            }
1038            let mut d1 = [0.0; LANES];
1039            let mut any = false;
1040            for (lane, r) in rows.iter().enumerate() {
1041                let v = r.atoms[atom].d_phi[basis_col][axis];
1042                d1[lane] = v;
1043                any |= v != 0.0;
1044            }
1045            if any {
1046                acc = acc.add(&O1x4::<K>::variable(l_splat(0.0), slot).scale(d1));
1047            }
1048        }
1049        for axis_a in 0..latent {
1050            for axis_b in 0..latent {
1051                if coord_slots[axis_a] == SAE_FIXED_COORD_SLOT
1052                    || coord_slots[axis_b] == SAE_FIXED_COORD_SLOT
1053                {
1054                    continue;
1055                }
1056                let mut d2 = [0.0; LANES];
1057                let mut any = false;
1058                for (lane, r) in rows.iter().enumerate() {
1059                    let v = r.atoms[atom].d2_phi[basis_col][axis_a][axis_b];
1060                    d2[lane] = v;
1061                    any |= v != 0.0;
1062                }
1063                if !any {
1064                    continue;
1065                }
1066                let mut half_d2 = [0.0; LANES];
1067                for lane in 0..LANES {
1068                    half_d2[lane] = 0.5 * d2[lane];
1069                }
1070                let va = O1x4::<K>::variable(l_splat(0.0), coord_slots[axis_a]);
1071                let vb = O1x4::<K>::variable(l_splat(0.0), coord_slots[axis_b]);
1072                acc = acc.add(&va.mul(&vb).scale(half_d2));
1073            }
1074        }
1075        acc
1076    }
1077
1078    /// All `out_dim` reconstruction columns for FOUR softmax-aligned rows at once,
1079    /// returned per row. Each row's column vector is BIT-IDENTICAL to
1080    /// [`Self::reconstruction_all_columns_packed`] on that row (same hoisting,
1081    /// same Leibniz products in the same order — lane `i` mirrors the scalar
1082    /// row-`i` path). Returns `None` if the four rows are not softmax-aligned, so
1083    /// the caller can fall back to the scalar per-row path.
1084    #[must_use]
1085    pub fn reconstruction_all_columns_batch4<const K: usize>(
1086        rows: [&Self; 4],
1087    ) -> Option<[Vec<Order2<K>>; 4]> {
1088        let head = rows[0];
1089        if head.n_primaries != K {
1090            return None;
1091        }
1092        let inv_tau = match head.gate {
1093            RowGate::Softmax { inv_tau } => inv_tau,
1094            RowGate::PerAtomLogistic { .. } => return None,
1095        };
1096        for r in &rows[1..] {
1097            if !head.batch_aligned_softmax_with(r) {
1098                return None;
1099            }
1100        }
1101        let p = head.out_dim();
1102        let gates: Vec<O2x4<K>> = head.all_gates_o2x4::<K>(&rows, inv_tau);
1103        // Build a jet tower ONLY for the basis columns that actually decode to
1104        // something: a column whose decoder row is identically zero across every
1105        // output channel AND every lane contributes exactly zero to all `p` output
1106        // sums, so both its (expensive) O2x4 tower build and its per-output gather
1107        // are pure waste. Skipping it is bit-identical — the old inner `any` guard
1108        // already dropped the scaled add, this just also drops the dead tower build
1109        // and the dead re-gather across all `p` columns. Each atom keeps a compact
1110        // `(basis_col, tower)` list of its live columns.
1111        let bases: Vec<Vec<(usize, O2x4<K>)>> = head
1112            .atoms
1113            .iter()
1114            .enumerate()
1115            .map(|(atom, atom_jet)| {
1116                (0..atom_jet.n_basis())
1117                    .filter(|&b| {
1118                        rows.iter()
1119                            .any(|r| (0..p).any(|c| r.atoms[atom].decoder[b][c] != 0.0))
1120                    })
1121                    .map(|b| (b, Self::basis_tower_o2x4::<K>(&rows, atom, b, &head.coord_slot[atom])))
1122                    .collect()
1123            })
1124            .collect();
1125        let mut cols: [Vec<Order2<K>>; LANES] =
1126            [Vec::new(), Vec::new(), Vec::new(), Vec::new()];
1127        for c in 0..p {
1128            let mut acc = O2x4::<K>::constant(l_splat(0.0));
1129            for atom in 0..head.atoms.len() {
1130                let mut decoded = O2x4::<K>::constant(l_splat(0.0));
1131                for (basis_col, tower) in &bases[atom] {
1132                    let mut coeff = [0.0; LANES];
1133                    let mut any = false;
1134                    for (lane, r) in rows.iter().enumerate() {
1135                        let v = r.atoms[atom].decoder[*basis_col][c];
1136                        coeff[lane] = v;
1137                        any |= v != 0.0;
1138                    }
1139                    if any {
1140                        decoded = decoded.add(&tower.scale(coeff));
1141                    }
1142                }
1143                acc = acc.add(&gates[atom].mul(&decoded));
1144            }
1145            for lane in 0..LANES {
1146                cols[lane].push(acc.lane(lane));
1147            }
1148        }
1149        Some(cols)
1150    }
1151
1152    /// Packed β-border FIRST-order jets for a batch of `(atom, basis_col)`
1153    /// channels, for FOUR softmax-aligned rows at once, returned per row. Each
1154    /// row's channel vector is BIT-IDENTICAL to
1155    /// [`Self::beta_border_order1_packed`] on that row. Returns `None` if the
1156    /// rows are not softmax-aligned.
1157    #[must_use]
1158    pub fn beta_border_order1_batch4<const K: usize>(
1159        rows: [&Self; 4],
1160        channels: &[(usize, usize)],
1161    ) -> Option<[Vec<Order1<K>>; 4]> {
1162        let head = rows[0];
1163        if head.n_primaries != K {
1164            return None;
1165        }
1166        let inv_tau = match head.gate {
1167            RowGate::Softmax { inv_tau } => inv_tau,
1168            RowGate::PerAtomLogistic { .. } => return None,
1169        };
1170        for r in &rows[1..] {
1171            if !head.batch_aligned_softmax_with(r) {
1172                return None;
1173            }
1174        }
1175        let gates: Vec<O1x4<K>> = head.all_gates_o1x4::<K>(&rows, inv_tau);
1176        let mut out: [Vec<Order1<K>>; LANES] =
1177            [Vec::new(), Vec::new(), Vec::new(), Vec::new()];
1178        for &(atom, basis_col) in channels {
1179            let phi = Self::basis_tower_o1x4::<K>(&rows, atom, basis_col, &head.coord_slot[atom]);
1180            let s = gates[atom].mul(&phi);
1181            for lane in 0..LANES {
1182                out[lane].push(s.lane(lane));
1183            }
1184        }
1185        Some(out)
1186    }
1187}
1188
1189#[cfg(test)]
1190mod tests {
1191    use super::*;
1192
1193    /// Replicate the production hand path (`row_jets_for_logdet`) arithmetic for
1194    /// the reconstruction `first`/`second` channels of ONE output column, from
1195    /// the same atom jets and softmax gate derivatives — independent code from
1196    /// the tower. The two must agree to machine precision; this is the #932
1197    /// universal oracle for the SAE row program (the analog of the survival
1198    /// `rigid_row_kernel_agrees_with_jet_tower_program` oracle).
1199    struct HandChannels {
1200        first: Vec<f64>,       // [primary]
1201        second: Vec<Vec<f64>>, // [primary][primary]
1202        value: f64,
1203    }
1204
1205    /// Softmax gate first/second derivatives wrt logit primaries, term-for-term
1206    /// the production `gate_derivatives_for_row` softmax branch.
1207    fn softmax_gate_derivs(gate: &[f64], inv_tau: f64) -> (Vec<Vec<f64>>, Vec<Vec<Vec<f64>>>) {
1208        let k = gate.len();
1209        // dz[j][kk] = ∂ζ_kk/∂ℓ_j ; d2z[j][l][kk] = ∂²ζ_kk/∂ℓ_j∂ℓ_l.
1210        let mut dz = vec![vec![0.0_f64; k]; k];
1211        let mut d2z = vec![vec![vec![0.0_f64; k]; k]; k];
1212        for j in 0..k {
1213            for kk in 0..k {
1214                let ind = if kk == j { 1.0 } else { 0.0 };
1215                dz[j][kk] = gate[kk] * (ind - gate[j]) * inv_tau;
1216            }
1217        }
1218        for j in 0..k {
1219            for l in 0..k {
1220                for kk in 0..k {
1221                    let ikl = if kk == l { 1.0 } else { 0.0 };
1222                    let ikj = if kk == j { 1.0 } else { 0.0 };
1223                    let ijl = if j == l { 1.0 } else { 0.0 };
1224                    d2z[j][l][kk] = gate[kk]
1225                        * ((ikl - gate[l]) * (ikj - gate[j]) - gate[j] * (ijl - gate[l]))
1226                        * inv_tau
1227                        * inv_tau;
1228                }
1229            }
1230        }
1231        (dz, d2z)
1232    }
1233
1234    /// Hand-pack the reconstruction column channels exactly as the production
1235    /// `row_jets_for_logdet` does for a softmax gate: gate-logit primaries first
1236    /// (one per atom), then each atom's latent coords.
1237    fn hand_softmax_column(
1238        prog: &SaeReconstructionRowProgram,
1239        out_col: usize,
1240        inv_tau: f64,
1241    ) -> HandChannels {
1242        let k = prog.atoms.len();
1243        let n = prog.n_primaries;
1244        // decoded[k] value, d1[k][axis], d2[k][a][b] for this out_col.
1245        let decoded: Vec<f64> = (0..k)
1246            .map(|kk| {
1247                (0..prog.atoms[kk].n_basis())
1248                    .map(|b| prog.atoms[kk].phi[b] * prog.atoms[kk].decoder[b][out_col])
1249                    .sum()
1250            })
1251            .collect();
1252        let d1: Vec<Vec<f64>> = (0..k)
1253            .map(|kk| {
1254                (0..prog.atoms[kk].latent_dim)
1255                    .map(|axis| {
1256                        (0..prog.atoms[kk].n_basis())
1257                            .map(|b| {
1258                                prog.atoms[kk].d_phi[b][axis] * prog.atoms[kk].decoder[b][out_col]
1259                            })
1260                            .sum()
1261                    })
1262                    .collect()
1263            })
1264            .collect();
1265        let d2: Vec<Vec<Vec<f64>>> = (0..k)
1266            .map(|kk| {
1267                (0..prog.atoms[kk].latent_dim)
1268                    .map(|a| {
1269                        (0..prog.atoms[kk].latent_dim)
1270                            .map(|b| {
1271                                (0..prog.atoms[kk].n_basis())
1272                                    .map(|col| {
1273                                        prog.atoms[kk].d2_phi[col][a][b]
1274                                            * prog.atoms[kk].decoder[col][out_col]
1275                                    })
1276                                    .sum()
1277                            })
1278                            .collect()
1279                    })
1280                    .collect()
1281            })
1282            .collect();
1283
1284        let (dz, d2z) = softmax_gate_derivs(&prog.gate_value, inv_tau);
1285
1286        // Primary index of atom logit / coord, matching the program layout.
1287        let logit_idx = |kk: usize| prog.logit_slot[kk];
1288        let coord_idx = |kk: usize, axis: usize| prog.coord_slot[kk][axis];
1289
1290        let value: f64 = (0..k).map(|kk| prog.gate_value[kk] * decoded[kk]).sum();
1291
1292        let mut first = vec![0.0_f64; n];
1293        // Logit primaries: ∂ẑ/∂ℓ_j = Σ_kk dz[j][kk]·decoded[kk].
1294        for j in 0..k {
1295            if let Some(p) = logit_idx(j) {
1296                first[p] = (0..k).map(|kk| dz[j][kk] * decoded[kk]).sum();
1297            }
1298        }
1299        // Coord primaries: ∂ẑ/∂t_{kk,axis} = ζ_kk · d1[kk][axis].
1300        for kk in 0..k {
1301            for axis in 0..prog.atoms[kk].latent_dim {
1302                first[coord_idx(kk, axis)] = prog.gate_value[kk] * d1[kk][axis];
1303            }
1304        }
1305
1306        let mut second = vec![vec![0.0_f64; n]; n];
1307        // Logit×Logit: Σ_kk d2z[j][l][kk]·decoded[kk].
1308        for j in 0..k {
1309            for l in 0..k {
1310                if let (Some(pj), Some(pl)) = (logit_idx(j), logit_idx(l)) {
1311                    second[pj][pl] = (0..k).map(|kk| d2z[j][l][kk] * decoded[kk]).sum();
1312                }
1313            }
1314        }
1315        // Logit×Coord (and symmetric): dz[j][kk]·d1[kk][axis].
1316        for j in 0..k {
1317            for kk in 0..k {
1318                for axis in 0..prog.atoms[kk].latent_dim {
1319                    if let Some(pj) = logit_idx(j) {
1320                        let pc = coord_idx(kk, axis);
1321                        let val = dz[j][kk] * d1[kk][axis];
1322                        second[pj][pc] = val;
1323                        second[pc][pj] = val;
1324                    }
1325                }
1326            }
1327        }
1328        // Coord×Coord same atom: ζ_kk · d2[kk][a][b].
1329        for kk in 0..k {
1330            for a in 0..prog.atoms[kk].latent_dim {
1331                for b in 0..prog.atoms[kk].latent_dim {
1332                    let pa = coord_idx(kk, a);
1333                    let pb = coord_idx(kk, b);
1334                    second[pa][pb] = prog.gate_value[kk] * d2[kk][a][b];
1335                }
1336            }
1337        }
1338
1339        HandChannels {
1340            first,
1341            second,
1342            value,
1343        }
1344    }
1345
1346    /// Build a two-atom softmax fixture with `latent_dim = 2` per atom and a
1347    /// dense decoder so every primary is exercised. Layout: logit slots
1348    /// 0,1; atom-0 coords 2,3; atom-1 coords 4,5 → K = 6 primaries.
1349    fn softmax_fixture(inv_tau: f64) -> (SaeReconstructionRowProgram, f64) {
1350        let n_basis = 3;
1351        let out_dim = 4;
1352        let mk_atom = |seed: f64| {
1353            let phi: Vec<f64> = (0..n_basis)
1354                .map(|b| 0.3 + 0.2 * (b as f64 + seed))
1355                .collect();
1356            let d_phi: Vec<Vec<f64>> = (0..n_basis)
1357                .map(|b| {
1358                    (0..2)
1359                        .map(|axis| 0.1 * (b as f64 + 1.0) - 0.05 * axis as f64 + 0.03 * seed)
1360                        .collect()
1361                })
1362                .collect();
1363            let d2_phi: Vec<Vec<Vec<f64>>> = (0..n_basis)
1364                .map(|b| {
1365                    (0..2)
1366                        .map(|a| {
1367                            (0..2)
1368                                .map(|bb| {
1369                                    // Symmetric in (a, bb).
1370                                    0.02 * (b as f64 + 1.0)
1371                                        + 0.01 * (a as f64)
1372                                        + 0.01 * (bb as f64)
1373                                        + 0.004 * seed
1374                                })
1375                                .collect()
1376                        })
1377                        .collect()
1378                })
1379                .collect();
1380            let decoder: Vec<Vec<f64>> = (0..n_basis)
1381                .map(|b| {
1382                    (0..out_dim)
1383                        .map(|c| 0.5 - 0.1 * (b as f64) + 0.07 * (c as f64) + 0.02 * seed)
1384                        .collect()
1385                })
1386                .collect();
1387            AtomRowBasisJet {
1388                phi,
1389                d_phi,
1390                d2_phi,
1391                decoder,
1392                latent_dim: 2,
1393            }
1394        };
1395        let logits = vec![0.4_f64, -0.7];
1396        // Softmax gate values at these logits.
1397        let e: Vec<f64> = logits.iter().map(|&l| (l * inv_tau).exp()).collect();
1398        let s: f64 = e.iter().sum();
1399        let gate_value: Vec<f64> = e.iter().map(|&v| v / s).collect();
1400        let prog = SaeReconstructionRowProgram {
1401            atoms: vec![mk_atom(0.0), mk_atom(1.0)],
1402            gate_value,
1403            logits,
1404            gate_scale: vec![1.0, 1.0],
1405            gate_shift: vec![0.0, 0.0],
1406            gate: RowGate::Softmax { inv_tau },
1407            logit_slot: vec![Some(0), Some(1)],
1408            coord_slot: vec![vec![2, 3], vec![4, 5]],
1409            n_primaries: 6,
1410        };
1411        (prog, inv_tau)
1412    }
1413
1414    /// Parametrized softmax fixture with `n_atoms` softmax atoms, each carrying a
1415    /// free logit primary and `latent_dim` free coord primaries, so
1416    /// `n_primaries = n_atoms·(1 + latent_dim)`. Layout: logit slots
1417    /// `0..n_atoms`, then atom `k`'s coord axis `j` at `n_atoms + k·latent_dim +
1418    /// j`. Used by the #932 ns/row microbench to instantiate the tower at
1419    /// `K = n_primaries` for `K ∈ {8, 16}` (the softmax gate Hessian is `n_atoms³`,
1420    /// the cost driver the hand path pays per output column).
1421    fn softmax_fixture_k(
1422        n_atoms: usize,
1423        latent_dim: usize,
1424        n_basis: usize,
1425        out_dim: usize,
1426        inv_tau: f64,
1427    ) -> SaeReconstructionRowProgram {
1428        let mk_atom = |seed: f64| {
1429            let phi: Vec<f64> = (0..n_basis)
1430                .map(|b| 0.3 + 0.2 * (b as f64 + seed))
1431                .collect();
1432            let d_phi: Vec<Vec<f64>> = (0..n_basis)
1433                .map(|b| {
1434                    (0..latent_dim)
1435                        .map(|axis| 0.1 * (b as f64 + 1.0) - 0.05 * axis as f64 + 0.03 * seed)
1436                        .collect()
1437                })
1438                .collect();
1439            let d2_phi: Vec<Vec<Vec<f64>>> = (0..n_basis)
1440                .map(|b| {
1441                    (0..latent_dim)
1442                        .map(|a| {
1443                            (0..latent_dim)
1444                                .map(|bb| {
1445                                    0.02 * (b as f64 + 1.0)
1446                                        + 0.01 * (a as f64)
1447                                        + 0.01 * (bb as f64)
1448                                        + 0.004 * seed
1449                                })
1450                                .collect()
1451                        })
1452                        .collect()
1453                })
1454                .collect();
1455            let decoder: Vec<Vec<f64>> = (0..n_basis)
1456                .map(|b| {
1457                    (0..out_dim)
1458                        .map(|c| 0.5 - 0.1 * (b as f64) + 0.07 * (c as f64) + 0.02 * seed)
1459                        .collect()
1460                })
1461                .collect();
1462            AtomRowBasisJet {
1463                phi,
1464                d_phi,
1465                d2_phi,
1466                decoder,
1467                latent_dim,
1468            }
1469        };
1470        let logits: Vec<f64> = (0..n_atoms)
1471            .map(|k| 0.4 - 0.13 * k as f64 + 0.05 * (k as f64).sin())
1472            .collect();
1473        let e: Vec<f64> = logits.iter().map(|&l| (l * inv_tau).exp()).collect();
1474        let s: f64 = e.iter().sum();
1475        let gate_value: Vec<f64> = e.iter().map(|&v| v / s).collect();
1476        let atoms: Vec<AtomRowBasisJet> = (0..n_atoms).map(|k| mk_atom(k as f64)).collect();
1477        let logit_slot: Vec<Option<usize>> = (0..n_atoms).map(Some).collect();
1478        let coord_slot: Vec<Vec<usize>> = (0..n_atoms)
1479            .map(|k| (0..latent_dim).map(|j| n_atoms + k * latent_dim + j).collect())
1480            .collect();
1481        SaeReconstructionRowProgram {
1482            atoms,
1483            gate_value,
1484            logits,
1485            gate_scale: vec![1.0; n_atoms],
1486            gate_shift: vec![0.0; n_atoms],
1487            gate: RowGate::Softmax { inv_tau },
1488            logit_slot,
1489            coord_slot,
1490            n_primaries: n_atoms * (1 + latent_dim),
1491        }
1492    }
1493
1494    /// #932 correctness gate: the production packed jet recon
1495    /// ([`SaeReconstructionRowProgram::reconstruction_all_columns_packed`], gate +
1496    /// basis jets HOISTED out of the column loop, softmax denom/recip SHARED) and
1497    /// the per-column packed call must each reproduce the hand path
1498    /// ([`hand_softmax_column`], the old `row_jets_for_logdet` closed-form softmax
1499    /// gate Jacobian/Hessian × decoded basis, re-derived per output column) on
1500    /// value/grad/Hessian — the #932 bit-identity bar. (The ns/row timing
1501    /// comparison this gate used to precede lives in `bench/`, not in a `#[test]`:
1502    /// `#[ignore]`d timing benches are banned by `build.rs`.)
1503    #[test]
1504    fn recon_jet_matches_hand_path_value_grad_hess() {
1505        let out_dim = 16;
1506        let n_basis = 4;
1507        let inv_tau = 1.3;
1508        // K=8: 4 atoms × (1 logit + 1 coord) = 8 primaries.
1509        check_recon_vs_hand::<8>(softmax_fixture_k(4, 1, n_basis, out_dim, inv_tau), inv_tau);
1510        // K=16: 8 atoms × (1 logit + 1 coord) = 16 primaries.
1511        check_recon_vs_hand::<16>(softmax_fixture_k(8, 1, n_basis, out_dim, inv_tau), inv_tau);
1512    }
1513
1514    fn check_recon_vs_hand<const K: usize>(prog: SaeReconstructionRowProgram, inv_tau: f64) {
1515        let out_dim = prog.out_dim();
1516        let cols = prog.reconstruction_all_columns_packed::<K>();
1517        for c in 0..out_dim {
1518            let hand = hand_softmax_column(&prog, c, inv_tau);
1519            let h_floor = hand
1520                .second
1521                .iter()
1522                .flatten()
1523                .fold(0.0_f64, |m, x| m.max(x.abs()));
1524            // The all-columns (hoisted) path matches hand value + Hessian.
1525            assert!((cols[c].value() - hand.value).abs() <= 1e-9 * hand.value.abs().max(1.0));
1526            // The per-column path matches the all-columns path (same kernel, no hoist).
1527            let percol = prog.reconstruction_column_packed::<K>(c);
1528            assert!((percol.value() - cols[c].value()).abs() <= 1e-12 * cols[c].value().abs().max(1.0));
1529            for a in 0..K {
1530                for b in 0..K {
1531                    assert!(
1532                        (cols[c].h()[a][b] - hand.second[a][b]).abs()
1533                            <= 1e-8 * h_floor.max(1e-12)
1534                    );
1535                    assert!(
1536                        (percol.h()[a][b] - cols[c].h()[a][b]).abs()
1537                            <= 1e-12 * h_floor.max(1e-12)
1538                    );
1539                }
1540            }
1541        }
1542    }
1543
1544    /// INDEPENDENT scalar witness for the reconstruction column `ẑ_c(δ)` as a
1545    /// function of the primary-increment vector `δ` (the displacement of each
1546    /// tower primary from its seed value: a coord primary seeds at value 0, a
1547    /// logit primary at its current logit, so `δ` is the same offset the tower's
1548    /// seeded variables carry). This evaluator touches NONE of the `Tower4`
1549    /// arithmetic — no Leibniz product, no Faà di Bruno compose, no
1550    /// `basis_tower`/`decoded_tower`/`gate_tower` — it re-derives the closed-form
1551    /// reconstruction from the raw jet tensors and the softmax definition. It is
1552    /// the witness the t3/t4 FD oracle differences below.
1553    ///
1554    /// `ẑ_c(δ) = Σ_k softmax_k((ℓ + δ_logit)·inv_tau) · Σ_b Φ̃_{k,b}(δ_coord)·B_{k,b,c}`
1555    /// with the SAME local quadratic basis model the program consumes:
1556    /// `Φ̃_b(u) = phi[b] + Σ_a d_phi[b][a]·u_a + ½ Σ_{a,a'} d2_phi[b][a][a']·u_a·u_{a'}`.
1557    fn recon_scalar_softmax(
1558        prog: &SaeReconstructionRowProgram,
1559        out_col: usize,
1560        inv_tau: f64,
1561        delta: &[f64],
1562    ) -> f64 {
1563        let k = prog.atoms.len();
1564        // Softmax over (logit + δ_logit) for atoms with a free logit primary;
1565        // atoms without one keep their base logit (no δ).
1566        let exps: Vec<f64> = (0..k)
1567            .map(|kk| {
1568                let dl = match prog.logit_slot[kk] {
1569                    Some(slot) => delta[slot],
1570                    None => 0.0,
1571                };
1572                ((prog.logits[kk] + dl) * inv_tau).exp()
1573            })
1574            .collect();
1575        let denom: f64 = exps.iter().sum();
1576        let mut acc = 0.0;
1577        for kk in 0..k {
1578            let gate = exps[kk] / denom;
1579            let atom = &prog.atoms[kk];
1580            // decoded_{kk,c}(δ_coord) via the local quadratic basis model.
1581            let mut decoded = 0.0;
1582            for b in 0..atom.n_basis() {
1583                let mut phi = atom.phi[b];
1584                for a in 0..atom.latent_dim {
1585                    let ua = delta[prog.coord_slot[kk][a]];
1586                    phi += atom.d_phi[b][a] * ua;
1587                }
1588                for a in 0..atom.latent_dim {
1589                    let ua = delta[prog.coord_slot[kk][a]];
1590                    for a2 in 0..atom.latent_dim {
1591                        let ub = delta[prog.coord_slot[kk][a2]];
1592                        phi += 0.5 * atom.d2_phi[b][a][a2] * ua * ub;
1593                    }
1594                }
1595                decoded += phi * atom.decoder[b][out_col];
1596            }
1597            acc += gate * decoded;
1598        }
1599        acc
1600    }
1601
1602    /// Fourth-order central FD of `recon_scalar_softmax` along axes (a,b,c,d) at
1603    /// the origin (δ = 0, the tower seed point). Uses the standard mixed
1604    /// fourth-difference stencil with sign vector ±h on each of the four axes
1605    /// (axes may coincide). 2⁴ = 16 evaluations.
1606    fn fd_fourth(
1607        prog: &SaeReconstructionRowProgram,
1608        out_col: usize,
1609        inv_tau: f64,
1610        axes: [usize; 4],
1611        h: f64,
1612    ) -> f64 {
1613        let n = prog.n_primaries;
1614        let mut acc = 0.0;
1615        for mask in 0..16u32 {
1616            let mut delta = vec![0.0_f64; n];
1617            let mut sign = 1.0;
1618            for (slot, &ax) in axes.iter().enumerate() {
1619                if (mask >> slot) & 1 == 1 {
1620                    delta[ax] += h;
1621                } else {
1622                    delta[ax] -= h;
1623                    sign = -sign;
1624                }
1625            }
1626            acc += sign * recon_scalar_softmax(prog, out_col, inv_tau, &delta);
1627        }
1628        acc / (16.0 * h * h * h * h)
1629    }
1630
1631    /// Third-order central FD of `recon_scalar_softmax` along axes (a,b,c) at the
1632    /// origin: 2³ = 8 evaluations with the mixed third-difference stencil.
1633    fn fd_third(
1634        prog: &SaeReconstructionRowProgram,
1635        out_col: usize,
1636        inv_tau: f64,
1637        axes: [usize; 3],
1638        h: f64,
1639    ) -> f64 {
1640        let n = prog.n_primaries;
1641        let mut acc = 0.0;
1642        for mask in 0..8u32 {
1643            let mut delta = vec![0.0_f64; n];
1644            let mut sign = 1.0;
1645            for (slot, &ax) in axes.iter().enumerate() {
1646                if (mask >> slot) & 1 == 1 {
1647                    delta[ax] += h;
1648                } else {
1649                    delta[ax] -= h;
1650                    sign = -sign;
1651                }
1652            }
1653            acc += sign * recon_scalar_softmax(prog, out_col, inv_tau, &delta);
1654        }
1655        acc / (8.0 * h * h * h)
1656    }
1657
1658    /// The #932 follow-up the issue flagged as missing: the SAE reconstruction
1659    /// program's THIRD- and FOURTH-order channels (`t3`/`t4`) validated against an
1660    /// INDEPENDENT witness (`recon_scalar_softmax`, finite-differenced), not just
1661    /// the value/first/second channels the hand-path oracle covers. Both the
1662    /// witness and the differencing are independent of the `Tower4` Leibniz /
1663    /// Faà-di-Bruno arithmetic that produces `t3`/`t4`, so agreement is a real
1664    /// cross-check of those higher-order channels — the analog of the survival
1665    /// kernel's `row_third_contracted` oracle, extended to fourth order.
1666    #[test]
1667    fn softmax_reconstruction_t3_t4_match_independent_fd_witness() {
1668        let (prog, inv_tau) = softmax_fixture(1.1);
1669        // Mixed fifth-derivative magnitude bounds the central-FD truncation; a
1670        // moderate step keeps both truncation and roundoff well under tol.
1671        let h3 = 2e-3;
1672        let h4 = 1e-2;
1673        for out_col in 0..prog.out_dim() {
1674            let tower = prog.reconstruction_column::<6>(out_col);
1675
1676            let t3_floor = tower
1677                .t3
1678                .iter()
1679                .flatten()
1680                .flatten()
1681                .fold(0.0_f64, |m, x| m.max(x.abs()))
1682                .max(1e-9);
1683            let t4_floor = tower
1684                .t4
1685                .iter()
1686                .flatten()
1687                .flatten()
1688                .flatten()
1689                .fold(0.0_f64, |m, x| m.max(x.abs()))
1690                .max(1e-9);
1691
1692            for a in 0..6 {
1693                for b in 0..6 {
1694                    for c in 0..6 {
1695                        let fd = fd_third(&prog, out_col, inv_tau, [a, b, c], h3);
1696                        assert!(
1697                            (tower.t3[a][b][c] - fd).abs() <= 5e-5 * t3_floor,
1698                            "col {out_col} t3[{a}][{b}][{c}]: tower {:+.10e} vs fd {:+.10e}",
1699                            tower.t3[a][b][c],
1700                            fd
1701                        );
1702                        for d in 0..6 {
1703                            let fd4 = fd_fourth(&prog, out_col, inv_tau, [a, b, c, d], h4);
1704                            assert!(
1705                                (tower.t4[a][b][c][d] - fd4).abs() <= 5e-4 * t4_floor,
1706                                "col {out_col} t4[{a}][{b}][{c}][{d}]: tower {:+.10e} vs fd {:+.10e}",
1707                                tower.t4[a][b][c][d],
1708                                fd4
1709                            );
1710                        }
1711                    }
1712                }
1713            }
1714        }
1715    }
1716
1717    /// A planted #736-style corruption in a t3 OR t4 channel is caught by the
1718    /// independent FD witness (loud at introduction). We perturb a copy of the
1719    /// tower's higher-order channel and assert the witness disagrees.
1720    #[test]
1721    fn planted_t3_t4_corruption_is_caught_by_fd_witness() {
1722        let (prog, inv_tau) = softmax_fixture(1.1);
1723        let out_col = 2;
1724        let tower = prog.reconstruction_column::<6>(out_col);
1725        // A real logit×coord×coord third block (atom-0 logit slot 0, atom-0
1726        // coords 2,3): the witness's third FD must match it...
1727        let axes3 = [0usize, 2, 3];
1728        let fd3 = fd_third(&prog, out_col, inv_tau, axes3, 2e-3);
1729        let t3_floor = tower
1730            .t3
1731            .iter()
1732            .flatten()
1733            .flatten()
1734            .fold(0.0_f64, |m, x| m.max(x.abs()))
1735            .max(1e-9);
1736        assert!(
1737            (tower.t3[0][2][3] - fd3).abs() <= 5e-5 * t3_floor,
1738            "honest t3 must match witness"
1739        );
1740        // ...and a sign-flipped copy must NOT.
1741        let corrupt = -tower.t3[0][2][3];
1742        assert!(
1743            (corrupt - fd3).abs() > 5e-5 * t3_floor,
1744            "a sign-flipped t3 block must disagree with the FD witness"
1745        );
1746
1747        let axes4 = [0usize, 0, 2, 3];
1748        let fd4 = fd_fourth(&prog, out_col, inv_tau, axes4, 1e-2);
1749        let t4_floor = tower
1750            .t4
1751            .iter()
1752            .flatten()
1753            .flatten()
1754            .flatten()
1755            .fold(0.0_f64, |m, x| m.max(x.abs()))
1756            .max(1e-9);
1757        let corrupt4 = tower.t4[0][0][2][3] + 10.0 * t4_floor;
1758        assert!(
1759            (corrupt4 - fd4).abs() > 5e-4 * t4_floor,
1760            "a corrupted t4 block must disagree with the FD witness"
1761        );
1762    }
1763
1764    #[test]
1765    fn softmax_reconstruction_tower_matches_hand_channels_all_columns() {
1766        let (prog, inv_tau) = softmax_fixture(1.3);
1767        for out_col in 0..prog.out_dim() {
1768            let tower = prog.reconstruction_column::<6>(out_col);
1769            let hand = hand_softmax_column(&prog, out_col, inv_tau);
1770
1771            // Magnitude floors so structurally-zero entries don't demand
1772            // absolute equality (the verify_kernel_channels convention).
1773            let g_floor = tower.g.iter().fold(0.0_f64, |m, x| m.max(x.abs()));
1774            let h_floor = tower
1775                .h
1776                .iter()
1777                .flatten()
1778                .fold(0.0_f64, |m, x| m.max(x.abs()));
1779
1780            assert!(
1781                (tower.v - hand.value).abs() <= 1e-9 * hand.value.abs().max(1.0),
1782                "col {out_col} value: tower {} vs hand {}",
1783                tower.v,
1784                hand.value
1785            );
1786            for a in 0..6 {
1787                assert!(
1788                    (tower.g[a] - hand.first[a]).abs() <= 1e-9 * g_floor.max(1e-12),
1789                    "col {out_col} first[{a}]: tower {} vs hand {}",
1790                    tower.g[a],
1791                    hand.first[a]
1792                );
1793                for b in 0..6 {
1794                    assert!(
1795                        (tower.h[a][b] - hand.second[a][b]).abs() <= 1e-8 * h_floor.max(1e-12),
1796                        "col {out_col} second[{a}][{b}]: tower {} vs hand {}",
1797                        tower.h[a][b],
1798                        hand.second[a][b]
1799                    );
1800                }
1801            }
1802        }
1803    }
1804
1805    /// A planted sign flip in the hand cross-block (logit×coord) is caught by the
1806    /// oracle — the same failure that #736 was, made loud at introduction.
1807    #[test]
1808    fn planted_cross_block_sign_flip_is_caught() {
1809        let (prog, inv_tau) = softmax_fixture(1.3);
1810        let out_col = 1;
1811        let tower = prog.reconstruction_column::<6>(out_col);
1812        let mut hand = hand_softmax_column(&prog, out_col, inv_tau);
1813        // Corrupt one logit×coord cross block (atom-0 logit slot 0, atom-1
1814        // coord slot 4): flip its sign, the #736 disease.
1815        hand.second[0][4] = -hand.second[0][4];
1816        hand.second[4][0] = -hand.second[4][0];
1817        let h_floor = tower
1818            .h
1819            .iter()
1820            .flatten()
1821            .fold(0.0_f64, |m, x| m.max(x.abs()));
1822        let disagrees = (tower.h[0][4] - hand.second[0][4]).abs() > 1e-8 * h_floor.max(1e-12);
1823        assert!(
1824            disagrees,
1825            "a flipped cross block must disagree with the tower truth"
1826        );
1827    }
1828
1829    /// The tower gate channels alone reproduce the softmax `gate_derivatives_for_row`
1830    /// arithmetic — isolating the gate nonlinearity from the basis/decoder so a
1831    /// regression in either is localizable.
1832    #[test]
1833    fn softmax_gate_tower_matches_hand_gate_derivatives() {
1834        let (prog, inv_tau) = softmax_fixture(0.9);
1835        let (dz, d2z) = softmax_gate_derivs(&prog.gate_value, inv_tau);
1836        for atom in 0..prog.atoms.len() {
1837            let gate = prog.gate_tower::<6, Tower4<6>>(atom);
1838            // ζ_atom value.
1839            assert!((gate.v - prog.gate_value[atom]).abs() < 1e-12);
1840            // ∂ζ_atom/∂ℓ_j == dz[j][atom].
1841            for j in 0..prog.atoms.len() {
1842                let slot = prog.logit_slot[j].unwrap();
1843                assert!(
1844                    (gate.g[slot] - dz[j][atom]).abs() < 1e-9,
1845                    "gate {atom} d/dlogit {j}: tower {} vs hand {}",
1846                    gate.g[slot],
1847                    dz[j][atom]
1848                );
1849            }
1850            // ∂²ζ_atom/∂ℓ_j∂ℓ_l == d2z[j][l][atom].
1851            for j in 0..prog.atoms.len() {
1852                for l in 0..prog.atoms.len() {
1853                    let sj = prog.logit_slot[j].unwrap();
1854                    let sl = prog.logit_slot[l].unwrap();
1855                    assert!(
1856                        (gate.h[sj][sl] - d2z[j][l][atom]).abs() < 1e-8,
1857                        "gate {atom} d2/dlogit {j}{l}: tower {} vs hand {}",
1858                        gate.h[sj][sl],
1859                        d2z[j][l][atom]
1860                    );
1861                }
1862            }
1863        }
1864    }
1865
1866    /// The per-atom logistic gate (IBP/JumpReLU branch) is diagonal in the
1867    /// logits and reproduces `σ' = σ(1−σ)·inv_tau`, `σ'' = σ(1−σ)(1−2σ)·inv_tau²`.
1868    #[test]
1869    fn per_atom_logistic_gate_matches_closed_form() {
1870        let inv_tau = 1.4;
1871        let logit = 0.6;
1872        let shift = 0.2;
1873        let x: f64 = (logit - shift) * inv_tau;
1874        let sigma = 1.0 / (1.0 + (-x).exp());
1875        let prog = SaeReconstructionRowProgram {
1876            atoms: vec![AtomRowBasisJet {
1877                phi: vec![1.0],
1878                d_phi: vec![vec![0.0]],
1879                d2_phi: vec![vec![vec![0.0]]],
1880                decoder: vec![vec![1.0]],
1881                latent_dim: 1,
1882            }],
1883            gate_value: vec![sigma],
1884            logits: vec![logit],
1885            gate_scale: vec![1.0],
1886            gate_shift: vec![shift],
1887            gate: RowGate::PerAtomLogistic { inv_tau },
1888            logit_slot: vec![Some(0)],
1889            coord_slot: vec![vec![1]],
1890            n_primaries: 2,
1891        };
1892        let gate = prog.gate_tower::<2, Tower4<2>>(0);
1893        assert!((gate.v - sigma).abs() < 1e-12);
1894        let d1 = sigma * (1.0 - sigma) * inv_tau;
1895        let d2 = sigma * (1.0 - sigma) * (1.0 - 2.0 * sigma) * inv_tau * inv_tau;
1896        assert!((gate.g[0] - d1).abs() < 1e-9, "σ': {} vs {}", gate.g[0], d1);
1897        assert!(
1898            (gate.h[0][0] - d2).abs() < 1e-9,
1899            "σ'': {} vs {}",
1900            gate.h[0][0],
1901            d2
1902        );
1903    }
1904
1905    /// #932 cutover pin: the PRODUCTION packed [`Order2`] reconstruction path
1906    /// (`reconstruction_column_packed`) is BIT-IDENTICAL on the
1907    /// value/gradient/Hessian channels to the dense [`Tower4`] oracle
1908    /// (`reconstruction_column`) — the same channels the arrow-Schur logdet
1909    /// consumer reads — for every output column. The Order2 path never
1910    /// materialises `t3`/`t4`, but its `(v, g, H)` must match the dense tower's
1911    /// order-≤2 channels to ≤1e-12 (they share the `Tower2` arithmetic), so the
1912    /// cutover changes only cost, not result.
1913    #[test]
1914    fn order2_reconstruction_matches_tower_value_grad_hessian() {
1915        for tau in [0.9_f64, 1.3, 2.1] {
1916            let (prog, _inv_tau) = softmax_fixture(tau);
1917            for out_col in 0..prog.out_dim() {
1918                let packed = prog.reconstruction_column_packed::<6>(out_col);
1919                let tower = prog.reconstruction_column::<6>(out_col);
1920                let g = packed.g();
1921                let h = packed.h();
1922                let band = |x: f64| 1e-12 + 1e-12 * x.abs();
1923                assert!(
1924                    (packed.value() - tower.v).abs() <= band(tower.v),
1925                    "col {out_col} value: order2 {} vs tower {}",
1926                    packed.value(),
1927                    tower.v
1928                );
1929                for a in 0..6 {
1930                    assert!(
1931                        (g[a] - tower.g[a]).abs() <= band(tower.g[a]),
1932                        "col {out_col} g[{a}]: order2 {} vs tower {}",
1933                        g[a],
1934                        tower.g[a]
1935                    );
1936                    for b in 0..6 {
1937                        assert!(
1938                            (h[a][b] - tower.h[a][b]).abs() <= band(tower.h[a][b]),
1939                            "col {out_col} h[{a}][{b}]: order2 {} vs tower {}",
1940                            h[a][b],
1941                            tower.h[a][b]
1942                        );
1943                    }
1944                }
1945            }
1946        }
1947    }
1948
1949    /// #932 cutover pin for the β border channel: the packed [`Order2`]
1950    /// `beta_border_tower_packed` matches the dense [`Tower4`]
1951    /// `beta_border_tower` on the value (`beta`) and gradient (`beta_deriv` /
1952    /// `beta_l_deriv`) channels the consumer reads, to ≤1e-12.
1953    #[test]
1954    fn order2_beta_border_matches_tower_value_grad() {
1955        let (prog, _inv_tau) = softmax_fixture(1.1);
1956        for atom in 0..prog.atoms.len() {
1957            for basis_col in 0..prog.atoms[atom].n_basis() {
1958                let packed = prog.beta_border_tower_packed::<6>(atom, basis_col);
1959                let tower = prog.beta_border_tower::<6>(atom, basis_col);
1960                let g = packed.g();
1961                let band = |x: f64| 1e-12 + 1e-12 * x.abs();
1962                assert!(
1963                    (packed.value() - tower.v).abs() <= band(tower.v),
1964                    "atom {atom} b {basis_col} value: order2 {} vs tower {}",
1965                    packed.value(),
1966                    tower.v
1967                );
1968                for a in 0..6 {
1969                    assert!(
1970                        (g[a] - tower.g[a]).abs() <= band(tower.g[a]),
1971                        "atom {atom} b {basis_col} g[{a}]: order2 {} vs tower {}",
1972                        g[a],
1973                        tower.g[a]
1974                    );
1975                }
1976            }
1977        }
1978    }
1979
1980    /// #932 perf pin: the gate-shared `all_gates` produces gate jets
1981    /// BIT-IDENTICAL to the per-atom `gate_tower` — sharing the softmax
1982    /// denominator / reciprocal across atoms (K exps + 1 recip instead of
1983    /// K² + K) changes only which redundant work is elided, not the result
1984    /// (`ζ_k = exp_k · recip(denom)` is the same product, same Leibniz order).
1985    #[test]
1986    fn shared_all_gates_bit_identical_to_per_atom_gate_tower() {
1987        for tau in [0.9_f64, 1.3, 2.1] {
1988            let (prog, _inv_tau) = softmax_fixture(tau);
1989            let all = prog.all_gates::<6, Order2<6>>();
1990            assert_eq!(all.len(), prog.gate_value.len());
1991            for atom in 0..prog.gate_value.len() {
1992                let per = prog.gate_tower::<6, Order2<6>>(atom);
1993                assert_eq!(all[atom].value(), per.value(), "atom {atom} value");
1994                for a in 0..6 {
1995                    assert_eq!(all[atom].g()[a], per.g()[a], "atom {atom} g[{a}]");
1996                    for b in 0..6 {
1997                        assert_eq!(
1998                            all[atom].h()[a][b],
1999                            per.h()[a][b],
2000                            "atom {atom} h[{a}][{b}]"
2001                        );
2002                    }
2003                }
2004            }
2005        }
2006    }
2007
2008    /// #932 perf pin: the gate/basis-HOISTED + denominator-SHARED all-columns
2009    /// reconstruction (`reconstruction_all_columns_packed`) is BIT-IDENTICAL to
2010    /// calling `reconstruction_column_packed(c)` per column — the hoist + share
2011    /// removes only redundant gate/basis/denominator recomputation, not any
2012    /// arithmetic. Every value/grad/Hessian channel must match exactly (==),
2013    /// since the Leibniz products are the same in the same order.
2014    #[test]
2015    fn hoisted_all_columns_bit_identical_to_per_column() {
2016        for tau in [0.9_f64, 1.3, 2.1] {
2017            let (prog, _inv_tau) = softmax_fixture(tau);
2018            let all = prog.reconstruction_all_columns_packed::<6>();
2019            assert_eq!(all.len(), prog.out_dim());
2020            for out_col in 0..prog.out_dim() {
2021                let per = prog.reconstruction_column_packed::<6>(out_col);
2022                let ah = all[out_col];
2023                assert_eq!(ah.value(), per.value(), "col {out_col} value");
2024                for a in 0..6 {
2025                    assert_eq!(ah.g()[a], per.g()[a], "col {out_col} g[{a}]");
2026                    for b in 0..6 {
2027                        assert_eq!(ah.h()[a][b], per.h()[a][b], "col {out_col} h[{a}][{b}]");
2028                    }
2029                }
2030            }
2031        }
2032    }
2033
2034    /// Build four softmax-aligned row programs that differ ONLY in their per-row
2035    /// numeric data (logits, basis values, decoder), keeping the layout
2036    /// (slots / dims / temperature) identical so they are 4-row SIMD-batchable.
2037    fn softmax_batch_fixture(inv_tau: f64) -> [SaeReconstructionRowProgram; LANES] {
2038        let n_basis = 3;
2039        let out_dim = 4;
2040        let mk = |row_seed: f64| {
2041            let mk_atom = |seed: f64| {
2042                let phi: Vec<f64> = (0..n_basis)
2043                    .map(|b| 0.3 + 0.2 * (b as f64 + seed) + 0.11 * row_seed)
2044                    .collect();
2045                let d_phi: Vec<Vec<f64>> = (0..n_basis)
2046                    .map(|b| {
2047                        (0..2)
2048                            .map(|axis| {
2049                                0.1 * (b as f64 + 1.0) - 0.05 * axis as f64 + 0.03 * seed
2050                                    + 0.017 * row_seed
2051                            })
2052                            .collect()
2053                    })
2054                    .collect();
2055                let d2_phi: Vec<Vec<Vec<f64>>> = (0..n_basis)
2056                    .map(|b| {
2057                        (0..2)
2058                            .map(|a| {
2059                                (0..2)
2060                                    .map(|bb| {
2061                                        0.02 * (b as f64 + 1.0)
2062                                            + 0.01 * (a as f64)
2063                                            + 0.01 * (bb as f64)
2064                                            + 0.004 * seed
2065                                            + 0.003 * row_seed
2066                                    })
2067                                    .collect()
2068                            })
2069                            .collect()
2070                    })
2071                    .collect();
2072                let decoder: Vec<Vec<f64>> = (0..n_basis)
2073                    .map(|b| {
2074                        (0..out_dim)
2075                            .map(|c| {
2076                                0.5 - 0.1 * (b as f64) + 0.07 * (c as f64) + 0.02 * seed
2077                                    + 0.009 * row_seed
2078                            })
2079                            .collect()
2080                    })
2081                    .collect();
2082                AtomRowBasisJet {
2083                    phi,
2084                    d_phi,
2085                    d2_phi,
2086                    decoder,
2087                    latent_dim: 2,
2088                }
2089            };
2090            let logits = vec![0.4 + 0.21 * row_seed, -0.7 + 0.13 * row_seed];
2091            let e: Vec<f64> = logits.iter().map(|&l| (l * inv_tau).exp()).collect();
2092            let s: f64 = e.iter().sum();
2093            let gate_value: Vec<f64> = e.iter().map(|&v| v / s).collect();
2094            SaeReconstructionRowProgram {
2095                atoms: vec![mk_atom(0.0), mk_atom(1.0)],
2096                gate_value,
2097                logits,
2098                gate_scale: vec![1.0, 1.0],
2099                gate_shift: vec![0.0, 0.0],
2100                gate: RowGate::Softmax { inv_tau },
2101                logit_slot: vec![Some(0), Some(1)],
2102                coord_slot: vec![vec![2, 3], vec![4, 5]],
2103                n_primaries: 6,
2104            }
2105        };
2106        [mk(0.0), mk(1.0), mk(2.0), mk(3.0)]
2107    }
2108
2109    /// SIMD-batch bit-identity oracle: `reconstruction_all_columns_batch4` lane
2110    /// `i` is `to_bits`-identical to the scalar `reconstruction_all_columns_packed`
2111    /// on row `i`, across many temperatures and randomized per-row data
2112    /// (≥2000 channel comparisons). The 4-row SIMD pass changes only how many
2113    /// rows share one instruction stream, never the arithmetic.
2114    #[test]
2115    fn batch4_reconstruction_bit_identical_to_per_row() {
2116        let mut comparisons = 0usize;
2117        for tau in [0.7_f64, 0.9, 1.1, 1.3, 1.7, 2.1, 2.9] {
2118            let rows = softmax_batch_fixture(tau);
2119            let refs = [&rows[0], &rows[1], &rows[2], &rows[3]];
2120            let batch = SaeReconstructionRowProgram::reconstruction_all_columns_batch4::<6>(refs)
2121                .expect("softmax-aligned rows must batch");
2122            for lane in 0..LANES {
2123                let per = rows[lane].reconstruction_all_columns_packed::<6>();
2124                assert_eq!(per.len(), batch[lane].len());
2125                for (c, (b, p)) in batch[lane].iter().zip(per.iter()).enumerate() {
2126                    assert_eq!(
2127                        b.value().to_bits(),
2128                        p.value().to_bits(),
2129                        "tau {tau} lane {lane} col {c} value"
2130                    );
2131                    let (bg, pg) = (b.g(), p.g());
2132                    let (bh, ph) = (b.h(), p.h());
2133                    for a in 0..6 {
2134                        assert_eq!(bg[a].to_bits(), pg[a].to_bits(), "lane {lane} col {c} g[{a}]");
2135                        for d in 0..6 {
2136                            assert_eq!(
2137                                bh[a][d].to_bits(),
2138                                ph[a][d].to_bits(),
2139                                "lane {lane} col {c} h[{a}][{d}]"
2140                            );
2141                            comparisons += 1;
2142                        }
2143                    }
2144                }
2145            }
2146        }
2147        assert!(comparisons >= 2000, "oracle ran {comparisons} comparisons");
2148    }
2149
2150    /// SIMD-batch bit-identity oracle for the β-border first-order path:
2151    /// `beta_border_order1_batch4` lane `i` is `to_bits`-identical to
2152    /// `beta_border_order1_packed` on row `i`.
2153    #[test]
2154    fn batch4_beta_border_bit_identical_to_per_row() {
2155        let mut comparisons = 0usize;
2156        for tau in [0.7_f64, 0.9, 1.1, 1.3, 1.7, 2.1, 2.9] {
2157            let rows = softmax_batch_fixture(tau);
2158            let refs = [&rows[0], &rows[1], &rows[2], &rows[3]];
2159            let mut chans: Vec<(usize, usize)> = Vec::new();
2160            for atom in 0..rows[0].atoms.len() {
2161                for b in 0..rows[0].atoms[atom].n_basis() {
2162                    chans.push((atom, b));
2163                }
2164            }
2165            chans.push(chans[0]); // repeat to exercise gate-cache reuse
2166            let batch =
2167                SaeReconstructionRowProgram::beta_border_order1_batch4::<6>(refs, &chans)
2168                    .expect("softmax-aligned rows must batch");
2169            for lane in 0..LANES {
2170                let per = rows[lane].beta_border_order1_packed::<6>(&chans);
2171                assert_eq!(per.len(), batch[lane].len());
2172                for (i, (b, p)) in batch[lane].iter().zip(per.iter()).enumerate() {
2173                    assert_eq!(b.value().to_bits(), p.value().to_bits(), "lane {lane} chan {i} v");
2174                    let (bg, pg) = (b.g(), p.g());
2175                    for a in 0..6 {
2176                        assert_eq!(
2177                            bg[a].to_bits(),
2178                            pg[a].to_bits(),
2179                            "lane {lane} chan {i} g[{a}]"
2180                        );
2181                        comparisons += 1;
2182                    }
2183                }
2184            }
2185        }
2186        assert!(comparisons >= 1000, "oracle ran {comparisons} comparisons");
2187    }
2188
2189    /// A non-softmax (per-atom logistic) batch must DECLINE (return `None`) so the
2190    /// caller falls back to the scalar per-row path — the logistic branch is
2191    /// per-row data-dependent and not lane-uniform.
2192    #[test]
2193    fn batch4_declines_non_softmax() {
2194        let inv_tau = 1.1;
2195        let mk = || SaeReconstructionRowProgram {
2196            atoms: vec![AtomRowBasisJet {
2197                phi: vec![1.0],
2198                d_phi: vec![vec![0.0]],
2199                d2_phi: vec![vec![vec![0.0]]],
2200                decoder: vec![vec![1.0]],
2201                latent_dim: 1,
2202            }],
2203            gate_value: vec![0.6],
2204            logits: vec![0.6],
2205            gate_scale: vec![1.0],
2206            gate_shift: vec![0.2],
2207            gate: RowGate::PerAtomLogistic { inv_tau },
2208            logit_slot: vec![Some(0)],
2209            coord_slot: vec![vec![1]],
2210            n_primaries: 2,
2211        };
2212        let rows = [mk(), mk(), mk(), mk()];
2213        let refs = [&rows[0], &rows[1], &rows[2], &rows[3]];
2214        assert!(
2215            SaeReconstructionRowProgram::reconstruction_all_columns_batch4::<2>(refs).is_none()
2216        );
2217    }
2218
2219    /// #932 perf pin: the gate-HOISTED batched β border jets
2220    /// (`beta_border_towers_packed`) are BIT-IDENTICAL to per-channel
2221    /// `beta_border_tower_packed`, including when several channels share an atom
2222    /// (the gate-cache reuse path).
2223    #[test]
2224    fn hoisted_beta_border_bit_identical_to_per_channel() {
2225        let (prog, _inv_tau) = softmax_fixture(1.1);
2226        // Build a channel list that repeats atoms (exercises the gate cache).
2227        let mut chans: Vec<(usize, usize)> = Vec::new();
2228        for atom in 0..prog.atoms.len() {
2229            for basis_col in 0..prog.atoms[atom].n_basis() {
2230                chans.push((atom, basis_col));
2231            }
2232        }
2233        // Duplicate the first atom's channels at the end to force cache reuse.
2234        if let Some(&first) = chans.first() {
2235            chans.push(first);
2236        }
2237        let batched = prog.beta_border_towers_packed::<6>(&chans);
2238        assert_eq!(batched.len(), chans.len());
2239        for (i, &(atom, basis_col)) in chans.iter().enumerate() {
2240            let per = prog.beta_border_tower_packed::<6>(atom, basis_col);
2241            let b = batched[i];
2242            assert_eq!(b.value(), per.value(), "chan {i} value");
2243            for a in 0..6 {
2244                assert_eq!(b.g()[a], per.g()[a], "chan {i} g[{a}]");
2245            }
2246        }
2247    }
2248}