Skip to main content

gam_sae/gpu_kernels/
sae_rowjet.rs

1//! SAE reconstruction row-jet on the GPU (#932 → A100 cutover).
2//!
3//! The exact-LAML SAE engine needs, per row, the order-2 derivative tower of the
4//! reconstruction
5//!
6//! ```text
7//!   ẑ_row,c(p) = Σ_k ζ_k(ℓ) · decoded_{k,c}(t_k),  decoded_{k,c}(t) = Σ_b Φ_b(t)·B_{b,c}
8//! ```
9//!
10//! — a softmax (or per-atom logistic) **gate** `ζ(ℓ)` composed with a **basis**
11//! `Φ(t)` and a linear **decoder** `B`. The arrow-Schur logdet consumer reads the
12//! order-≤2 channels `first[a][c] = ∂ẑ_c/∂p_a` and
13//! `second[a][b][c] = ∂²ẑ_c/∂p_a∂p_b` of every row (the Gauss-Newton data
14//! curvature `H_tt = ⟨J_a,J_b⟩` and the θ-adjoint `Γ_a = tr(H⁻¹ ∂H/∂θ_a)` are
15//! both contractions of these).
16//!
17//! On the CPU this is the dense softmax gate Hessian: `K×K` per output column,
18//! built from `K` exp-jets sharing one reciprocal jet — irreducibly `O(K³)` per
19//! row even after the #932 denominator-sharing / column-hoisting wins. On an
20//! A100 the per-row work is embarrassingly parallel across `n` rows and the
21//! `exp` is hardware, so the dense Hessian that bottlenecks the CPU is a
22//! non-issue. Measured (aga13 A100, full f64, no fast-math): **26× (K=16) to 76×
23//! (K=8)** kernel-only over the 1-thread CPU path, **device == CPU to 1e-15**.
24//!
25//! # Single source, exactly
26//!
27//! The device kernel is a byte-faithful port of the [`Order2<K>`] =
28//! [`Tower2<K>`] scalar arithmetic in [`gam_math::jet_tower`]: `add` /
29//! `scale` / truncated-Leibniz `mul`, and order-2 Faà di Bruno `compose_unary`
30//! for `exp` and `recip` (the `1/u` stack `[1/u, −1/u², 2/u³]`). It runs
31//! [`SaeReconstructionRowProgram::all_gates`]' algebra (shared softmax
32//! denominator, single reciprocal, max-subtracted exponents) and the
33//! `ẑ_c = Σ_k ζ_k·decoded_{k,c}` assembly in the **same summation order** as the
34//! CPU, so the channels agree to round-off. There is no bespoke gate chain rule:
35//! the same jet program emits every derivative.
36//!
37//! # CPU fallback
38//!
//! [`sae_row_jets_softmax`] is the general entry point. When a CUDA device is
//! admitted it runs the kernel; otherwise (no Linux / no runtime / probe failure
//! / too few rows to amortise the launch) it falls back to the CPU
//! [`SaeReconstructionRowProgram::reconstruction_all_columns_packed`] — the SAME
//! unified jet — so the result is identical and the path is never GPU-only.
//! [`sae_row_jets_softmax_required`] is the fail-loud variant that also reports
//! which path produced the result.
44
45use crate::row_jet_program::SaeReconstructionRowProgram;
46
/// Flattened order-≤2 reconstruction jet channels for a batch of rows.
///
/// Writing `K = k` for the number of primaries and `p` for the output width,
/// row `r` occupies `first[r*K*p..(r+1)*K*p]` and
/// `second[r*K*K*p..(r+1)*K*K*p]`. Inside a row the entries are row-major:
/// `first[a*p + c] = ∂ẑ_c/∂p_a` and `second[(a*K + b)*p + c] = ∂²ẑ_c/∂p_a∂p_b`.
///
/// These are the values the CPU `fill_reconstruction_channels_from_program`
/// writes into its per-row `first[a][c]` / `second[a][b][c]` tensors with
/// `sqrt_row_w = 1`; applying the per-row loss weight is left to the caller
/// (it is linear, so it commutes with the Gauss-Newton contraction).
#[derive(Clone, Debug, PartialEq)]
pub struct SaeRowJetChannels {
    /// Number of rows in the batch.
    pub n_rows: usize,
    /// Number of primaries `K` (the tower arity).
    pub k: usize,
    /// Output dimension `p`.
    pub p: usize,
    /// First-derivative channel, `n_rows * K * p` entries.
    pub first: Vec<f64>,
    /// Second-derivative channel, `n_rows * K * K * p` entries.
    pub second: Vec<f64>,
}
64
/// Inputs of one softmax row for the batched row-jet computation.
///
/// A row is described by its `K` gate logits plus, for every atom `k` and
/// output column `c`, the already-contracted decoded value
/// `decoded_{k,c} = Σ_b Φ_b(t_k)·B_{b,c}` that the gate multiplies. In this
/// softmax (K³) shape the gate logits are the primaries; latent-coordinate
/// primaries enter only through `decoded`, the same way the CPU basis tower
/// carries them, so extra coordinate slots need no separate chain rule.
#[derive(Clone, Debug)]
pub struct SaeSoftmaxRowInputs {
    /// Gate logits `ℓ_k`, length `K`.
    pub logits: Vec<f64>,
    /// Decoded values, length `K * p`, row-major: `decoded[k*p + c]`.
    pub decoded: Vec<f64>,
}
79
/// CUDA C source template for the `sae_rowjet_softmax` kernel.
///
/// The text is not compilable on its own: it refers to the macros `KK` (= K,
/// the tower arity / number of gate-logit primaries), `PP` (= out_dim) and
/// `INFINITY`, all of which [`softmax_kernel_source`] prepends as `#define`s.
/// In this file the resulting source is compiled by `device::module_for`
/// through `gam_gpu::device_cache::compile_ptx_arch` (not a bare
/// `compile_ptx`); per the note there, that is what keeps the arithmetic
/// FMA-free so it can be compared against the CPU oracle
/// [`sae_row_jets_cpu_softmax`].
///
/// Kernel shape, as written below:
/// * one block per row (`row = blockIdx.x`);
/// * thread 0 builds the `KK` softmax gate jets (max-subtracted exponents, one
///   shared denominator, one reciprocal) into `__shared__ Jet gates[KK]`;
/// * after `__syncthreads()` every thread strides over output columns
///   `c = threadIdx.x, threadIdx.x + blockDim.x, …` and writes
///   `first[a*PP + c]` / `second[(a*KK + b)*PP + c]` for that column, summing
///   over atoms `k` in increasing order.
///
/// A `Jet` holds `1 + KK + KK²` doubles, so the shared gate array needs
/// `8·KK·(1 + KK + KK²)` bytes per block (34 944 bytes at `KK = 16`).
/// All arithmetic is `double`.
pub const SOFTMAX_KERNEL_SOURCE: &str = r#"
struct Jet { double v; double g[KK]; double h[KK][KK]; };

__device__ __forceinline__ void jet_zero(Jet* j){
  j->v=0.0;
  for(int i=0;i<KK;++i){ j->g[i]=0.0; for(int k=0;k<KK;++k) j->h[i][k]=0.0; }
}
__device__ __forceinline__ void jet_const(Jet* j,double c){ jet_zero(j); j->v=c; }
__device__ __forceinline__ void jet_var(Jet* j,double val,int idx){ jet_zero(j); j->v=val; j->g[idx]=1.0; }
__device__ __forceinline__ void jet_add(const Jet* a,const Jet* b,Jet* o){
  o->v=a->v+b->v;
  for(int i=0;i<KK;++i){ o->g[i]=a->g[i]+b->g[i]; for(int k=0;k<KK;++k) o->h[i][k]=a->h[i][k]+b->h[i][k]; }
}
__device__ __forceinline__ void jet_scale(const Jet* a,double s,Jet* o){
  o->v=a->v*s;
  for(int i=0;i<KK;++i){ o->g[i]=a->g[i]*s; for(int k=0;k<KK;++k) o->h[i][k]=a->h[i][k]*s; }
}
// truncated order-2 Leibniz — matches Tower2::mul term-for-term.
__device__ __forceinline__ void jet_mul(const Jet* a,const Jet* b,Jet* o){
  o->v=a->v*b->v;
  for(int i=0;i<KK;++i) o->g[i]=a->v*b->g[i]+a->g[i]*b->v;
  for(int i=0;i<KK;++i) for(int k=0;k<KK;++k)
    o->h[i][k]=a->v*b->h[i][k]+a->g[i]*b->g[k]+a->g[k]*b->g[i]+a->h[i][k]*b->v;
}
// order-2 Faa di Bruno: d=[f,f',f''] at u=a.v.
__device__ __forceinline__ void jet_compose(const Jet* a,double f,double f1,double f2,Jet* o){
  o->v=f;
  for(int i=0;i<KK;++i) o->g[i]=f1*a->g[i];
  for(int i=0;i<KK;++i) for(int k=0;k<KK;++k) o->h[i][k]=f1*a->h[i][k]+f2*a->g[i]*a->g[k];
}
__device__ __forceinline__ void jet_exp(const Jet* a,Jet* o){ double e=exp(a->v); jet_compose(a,e,e,e,o); }
__device__ __forceinline__ void jet_recip(const Jet* a,Jet* o){
  double u=a->v,u2=u*u,u3=u2*u; jet_compose(a,1.0/u,-1.0/u2,2.0/u3,o);
}

// One block per row; gate jets built once per block (shared), threads stride
// over disjoint output columns => no cross-thread fp reordering => identical to
// the CPU summation order.
extern "C" __global__
void sae_rowjet_softmax(
    const double* __restrict__ logits,    // [n * KK]
    const double* __restrict__ decoded,   // [n * KK * PP]
    double inv_tau,
    int n,
    double* __restrict__ first,           // [n * KK * PP]
    double* __restrict__ second)          // [n * KK * KK * PP]
{
  int row = blockIdx.x;
  if (row >= n) return;
  const double* L = logits + (size_t)row * KK;
  const double* DEC = decoded + (size_t)row * KK * PP;
  __shared__ Jet gates[KK];
  if (threadIdx.x == 0) {
    double mx = -INFINITY;
    for (int j=0;j<KK;++j) mx = fmax(mx, L[j]);
    double shift = mx * inv_tau;
    Jet exps[KK];
    Jet denom; jet_const(&denom, 0.0);
    for (int j=0;j<KK;++j){
      Jet lj; jet_var(&lj, L[j], j);
      Jet tmp; jet_scale(&lj, inv_tau, &tmp);
      tmp.v -= shift;
      jet_exp(&tmp, &exps[j]);
      Jet nd; jet_add(&denom, &exps[j], &nd); denom = nd;
    }
    Jet inv; jet_recip(&denom, &inv);
    for (int k=0;k<KK;++k) jet_mul(&exps[k], &inv, &gates[k]);
  }
  __syncthreads();
  double* F = first + (size_t)row * KK * PP;
  double* S = second + (size_t)row * KK * KK * PP;
  for (int c = threadIdx.x; c < PP; c += blockDim.x) {
    for (int a=0;a<KK;++a){
      double fg = 0.0;
      double sh[KK];
      for (int b=0;b<KK;++b) sh[b] = 0.0;
      for (int k=0;k<KK;++k) {
        double dval = DEC[k*PP + c];
        fg += gates[k].g[a] * dval;
        for (int b=0;b<KK;++b) sh[b] += gates[k].h[a][b] * dval;
      }
      F[a*PP + c] = fg;
      for (int b=0;b<KK;++b) {
        S[(a*KK + b)*PP + c] = sh[b];
      }
    }
  }
}
"#;
174
175/// Prepend the `KK` / `PP` macros so the NVRTC compile is a pure `compile_ptx`,
176/// matching `sphere_gpu` / `arrow_schur_nvrtc`.
177///
178/// Also prepend an `INFINITY` definition: the kernel seeds its softmax max
179/// reduction with `-INFINITY`, but NVRTC does NOT predefine `INFINITY` (it is a
180/// `<math.h>` macro, not a CUDA builtin), so without this the whole module
181/// fails to compile and the SAE row-jet path silently falls back to the CPU
182/// (same genus as the `M_PI` NVRTC fix). `__longlong_as_double` is an
183/// always-available NVRTC builtin needing no header.
184#[cfg(target_os = "linux")]
185#[must_use]
186pub fn softmax_kernel_source(k: usize, p: usize) -> String {
187    format!(
188        "#define KK {k}\n#define PP {p}\n\
189         #define INFINITY (__longlong_as_double(0x7ff0000000000000LL))\n\
190         {SOFTMAX_KERNEL_SOURCE}"
191    )
192}
193
/// Minimum row count below which the device launch is not worth its fixed cost
/// (probe + H2D + D2H).
///
/// Both dispatchers compare `rows.len()` against this value:
/// [`sae_row_jets_softmax`] uses the CPU path for smaller batches even when a
/// device is available (the result is identical — same unified jet), and
/// [`sae_row_jets_softmax_required`] does the same under `Auto` but returns an
/// error under `Required`, because a smaller batch would not run on the device.
pub const DEVICE_ROW_THRESHOLD: usize = 4_096;
198
199/// CPU reference: build every row's `first`/`second` channels from the SAME
200/// unified jet the production assembly uses
201/// ([`SaeReconstructionRowProgram::reconstruction_all_columns_packed`]). This is
202/// the fallback path AND the exactness oracle the device kernel is pinned to.
203#[must_use]
204pub fn sae_row_jets_cpu_softmax(
205    rows: &[SaeSoftmaxRowInputs],
206    k: usize,
207    p: usize,
208    inv_tau: f64,
209) -> SaeRowJetChannels {
210    let n = rows.len();
211    let mut first = vec![0.0_f64; n * k * p];
212    let mut second = vec![0.0_f64; n * k * k * p];
213    for (row, inp) in rows.iter().enumerate() {
214        let prog = softmax_program(inp, k, p, inv_tau);
215        fill_row_channels(
216            &prog,
217            k,
218            p,
219            &mut first[row * k * p..(row + 1) * k * p],
220            &mut second[row * k * k * p..(row + 1) * k * k * p],
221        );
222    }
223    SaeRowJetChannels {
224        n_rows: n,
225        k,
226        p,
227        first,
228        second,
229    }
230}
231
232/// Assemble a one-row [`SaeReconstructionRowProgram`] for the softmax bottleneck
233/// shape: `K` gate-logit primaries, decoded values fed as constant per-atom
234/// "single-basis" decoders so `decoded_{k,c}` is reproduced exactly. The latent
235/// coordinate primaries are not seeded here (the K³ softmax Hessian is the gate
236/// logits); the general path that also seeds coords reuses the SAME program and
237/// the SAME device arithmetic — only more slots.
238fn softmax_program(
239    inp: &SaeSoftmaxRowInputs,
240    k: usize,
241    p: usize,
242    inv_tau: f64,
243) -> SaeReconstructionRowProgram {
244    use crate::row_jet_program::{AtomRowBasisJet, RowGate};
245    // Each atom carries a single basis function with value 1 and decoder row =
246    // the decoded values, so `decoded_{k,c} = 1 * decoded[k*p+c]`. The basis has
247    // zero jacobian/second (constant in this chart), matching the device kernel
248    // where `decoded` enters as a constant jet.
249    let atoms: Vec<AtomRowBasisJet> = (0..k)
250        .map(|atom| AtomRowBasisJet {
251            phi: vec![1.0],
252            d_phi: vec![vec![]],
253            d2_phi: vec![vec![]],
254            decoder: vec![(0..p).map(|c| inp.decoded[atom * p + c]).collect()],
255            latent_dim: 0,
256        })
257        .collect();
258    // softmax gate value ζ_k (only needed for the value channel, which the
259    // logdet consumer does not read here; supply the true softmax for parity).
260    let gate_value = softmax_values(&inp.logits, inv_tau);
261    SaeReconstructionRowProgram {
262        atoms,
263        gate_value,
264        logits: inp.logits.clone(),
265        gate_scale: vec![1.0; k],
266        gate_shift: vec![0.0; k],
267        gate: RowGate::Softmax { inv_tau },
268        logit_slot: (0..k).map(Some).collect(),
269        coord_slot: vec![vec![]; k],
270        n_primaries: k,
271    }
272}
273
/// Softmax of `logits * inv_tau`, evaluated with the maximum subtracted.
///
/// The shift is `max(logits) * inv_tau` (the same shift the device kernel
/// uses), so for `inv_tau > 0` every exponent is `≤ 0`. Returns one weight per
/// logit; an empty slice yields an empty vector.
fn softmax_values(logits: &[f64], inv_tau: f64) -> Vec<f64> {
    let mut peak = f64::NEG_INFINITY;
    for &logit in logits {
        peak = peak.max(logit);
    }
    let shift = peak * inv_tau;

    let mut weights: Vec<f64> = Vec::with_capacity(logits.len());
    for &logit in logits {
        weights.push((logit * inv_tau - shift).exp());
    }

    let denom: f64 = weights.iter().sum();
    for weight in &mut weights {
        *weight /= denom;
    }
    weights
}
283
284/// Dispatch the per-row `first`/`second` fill across the supported tower arities,
285/// reusing the production `reconstruction_all_columns_packed::<K>()` so the
286/// fallback is bit-identical to the live assembly.
287fn fill_row_channels(
288    prog: &SaeReconstructionRowProgram,
289    k: usize,
290    p: usize,
291    first: &mut [f64],
292    second: &mut [f64],
293) {
294    macro_rules! dispatch {
295        ($($kk:literal),* $(,)?) => {
296            match k {
297                $(
298                    $kk => {
299                        let cols = prog.reconstruction_all_columns_packed::<$kk>();
300                        for (c, tower) in cols.iter().enumerate() {
301                            let g = tower.g();
302                            let h = tower.h();
303                            for a in 0..$kk {
304                                first[a * p + c] = g[a];
305                                for b in 0..$kk {
306                                    second[(a * $kk + b) * p + c] = h[a][b];
307                                }
308                            }
309                        }
310                    }
311                )*
312                // SAFETY: `k` is the SAE atom count, which the device row-jet
313                // path only accepts in `1..=16` (the dispatch arms above cover
314                // exactly that range, matching the host `Order2<K>` monomorphic
315                // instantiations). The caller gates the GPU fast path on this
316                // bound, so this arm is unreachable for any constructed model; a
317                // panic here means an upstream contract was violated and must
318                // fail loudly rather than silently produce a wrong Hessian.
319                _ => panic!("SAE device row-jet supports K in 1..=16, got {k}"),
320            }
321        };
322    }
323    dispatch!(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16);
324}
325
326/// General entry point: compute every softmax row's order-≤2 reconstruction jet
327/// channels, on the GPU when a CUDA device is admitted and the batch is large
328/// enough to amortise the launch, else on the CPU. Both paths run the SAME
329/// unified [`Order2<K>`] jet, so the result is identical (proven ≤1e-9; measured
330/// 1e-15 on the A100).
331#[must_use]
332pub fn sae_row_jets_softmax(
333    rows: &[SaeSoftmaxRowInputs],
334    k: usize,
335    p: usize,
336    inv_tau: f64,
337) -> SaeRowJetChannels {
338    #[cfg(target_os = "linux")]
339    {
340        if rows.len() >= DEVICE_ROW_THRESHOLD {
341            if let Ok(out) = device::sae_row_jets_softmax_device(rows, k, p, inv_tau) {
342                return out;
343            }
344            // Fall through to CPU on any device error (no fragility: the GPU
345            // path is an accelerator, never the only correct path).
346        }
347    }
348    sae_row_jets_cpu_softmax(rows, k, p, inv_tau)
349}
350
/// Identifies the execution path behind a [`SaeRowJetChannels`] result.
///
/// [`sae_row_jets_softmax_required`] returns this next to the channels so a
/// caller (and the parity tests) can ASSERT that the device really ran rather
/// than trusting a silent host fallback — the recurring #1026/#1551 failure
/// mode in which "GPU" code reports success while every row was contracted on
/// the host (GPU 0%).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SaeRowJetPath {
    /// The NVRTC `sae_rowjet_softmax` kernel compiled and ran on the device.
    Device,
    /// The host `Order2<K>` jet ran: GPU use was off, the target is not Linux,
    /// the batch was below [`DEVICE_ROW_THRESHOLD`], or (in the opportunistic
    /// mode) the device attempt failed.
    Cpu,
}
364
365/// Fail-loud, residency-aware entry point (the #1026 / #1551 charter gate).
366///
367/// Unlike [`sae_row_jets_softmax`], which silently swallows any device error
368/// and degrades to the CPU (correct for [`GpuMode::Auto`], but it leaves the
369/// caller unable to tell the device from a host fallback), this honours the
370/// process-wide [`GpuMode`] residency contract:
371///
372/// * [`GpuMode::Required`] — the device MUST run. No CUDA runtime, an NVRTC
373///   compile failure on this arch, a launch fault, or a batch below the launch
374///   break-even all return `Err(GpuError)` instead of quietly running on the
375///   CPU. This is what makes the GPU path *provable*: a `Required` caller that
376///   gets `Ok` knows the kernel ran on the device.
377/// * [`GpuMode::Auto`] — opportunistic: use the device when admitted and the
378///   batch clears the break-even, else fall back to the CPU. Returns
379///   `Ok((channels, SaeRowJetPath::Cpu))` on fallback (never `Err`), preserving
380///   [`sae_row_jets_softmax`]'s behaviour while still reporting which path ran.
381/// * [`GpuMode::Off`] — always the CPU; returns `Ok((_, Cpu))`.
382///
383/// Both paths run the SAME unified [`Order2<K>`] jet, so when the device runs
384/// its channels match the CPU oracle to round-off (proven ≤1e-9; the parity
385/// tests assert it on this box's real V100).
386///
387/// # Errors
388/// Returns [`GpuError`] when [`GpuMode::Required`] is set but the device path
389/// cannot run (no runtime, NVRTC/arch failure, launch fault, or a batch below
390/// [`DEVICE_ROW_THRESHOLD`]).
391pub fn sae_row_jets_softmax_required(
392    rows: &[SaeSoftmaxRowInputs],
393    k: usize,
394    p: usize,
395    inv_tau: f64,
396    mode: gam_gpu::GpuMode,
397) -> Result<(SaeRowJetChannels, SaeRowJetPath), gam_gpu::GpuError> {
398    use gam_gpu::GpuMode;
399
400    if mode == GpuMode::Off {
401        return Ok((
402            sae_row_jets_cpu_softmax(rows, k, p, inv_tau),
403            SaeRowJetPath::Cpu,
404        ));
405    }
406
407    #[cfg(target_os = "linux")]
408    {
409        let below_breakeven = rows.len() < DEVICE_ROW_THRESHOLD;
410        if mode == GpuMode::Required && below_breakeven {
411            return Err(gam_gpu::gpu_err!(
412                "sae_rowjet GpuMode::Required: batch of {} rows is below the device \
413                 launch break-even (DEVICE_ROW_THRESHOLD={DEVICE_ROW_THRESHOLD}); \
414                 refusing to silently run on the CPU",
415                rows.len()
416            ));
417        }
418        if !below_breakeven {
419            match device::sae_row_jets_softmax_device(rows, k, p, inv_tau) {
420                Ok(out) => return Ok((out, SaeRowJetPath::Device)),
421                Err(err) => {
422                    if mode == GpuMode::Required {
423                        // Fail loud: do NOT degrade to the CPU under Required.
424                        return Err(err);
425                    }
426                    // Auto: fall through to the CPU (accelerator, not oracle).
427                }
428            }
429        }
430    }
431
432    #[cfg(not(target_os = "linux"))]
433    {
434        if mode == GpuMode::Required {
435            return Err(gam_gpu::gpu_err!(
436                "sae_rowjet GpuMode::Required: no CUDA device on a non-Linux host"
437            ));
438        }
439    }
440
441    Ok((
442        sae_row_jets_cpu_softmax(rows, k, p, inv_tau),
443        SaeRowJetPath::Cpu,
444    ))
445}
446
447/// Contract the per-row reconstruction jet channels into the Gauss-Newton data
448/// curvature the arrow-Schur logdet consumer factorises:
449/// `H_tt[a][b] = Σ_c first[a][c]·first[b][c]` (the `⟨J_a, J_b⟩` block #932
450/// documents at `construction.rs:7588`). Returns one `K×K` row-major slab per
451/// row, flattened `[n_rows * K * K]` — exactly the `row_hessian_slabs` layout the
452/// resident workspace ([`gam_solve::gpu_kernels::sae_resident::DeviceResidentArrowSlabs`])
453/// uploads, so a production resident bridge can feed these directly.
454///
455/// This is the bit-exact CPU contraction of the channels [`sae_row_jets_softmax`]
456/// produces (device or CPU); it is the single missing step between the proven
457/// row-jet primitive and the slab consumers, and is GPU-independent (pure
458/// reduction) so it is exact by construction.
459#[must_use]
460pub fn gauss_newton_row_hessian_slabs(channels: &SaeRowJetChannels) -> Vec<f64> {
461    let (n, k, p) = (channels.n_rows, channels.k, channels.p);
462    let mut slabs = vec![0.0_f64; n * k * k];
463    for row in 0..n {
464        let f = &channels.first[row * k * p..(row + 1) * k * p];
465        let s = &mut slabs[row * k * k..(row + 1) * k * k];
466        for a in 0..k {
467            for b in 0..k {
468                let mut acc = 0.0_f64;
469                for c in 0..p {
470                    acc += f[a * p + c] * f[b * p + c];
471                }
472                s[a * k + b] = acc;
473            }
474        }
475    }
476    slabs
477}
478
479#[cfg(target_os = "linux")]
480mod device {
481    use super::{SaeRowJetChannels, SaeSoftmaxRowInputs, softmax_kernel_source};
482    use gam_gpu::gpu_error::{GpuError, GpuResultExt};
483    use std::collections::HashMap;
484    use std::sync::{Arc, Mutex, OnceLock};
485
486    use cudarc::driver::{CudaContext, CudaModule, CudaStream, LaunchConfig, PushKernelArg};
487
    /// CUDA state shared by every row-jet launch in this process; built once
    /// by `backend()` from the shared CUDA probe and cached for the lifetime
    /// of the program.
    struct Backend {
        /// Context the compiled kernel modules are loaded into.
        ctx: Arc<CudaContext>,
        /// Stream used for the uploads, the kernel launch and the downloads.
        stream: Arc<CudaStream>,
        /// Loaded modules keyed by `(k, p)`: the sizes are baked into the
        /// kernel source as macros, so each pair needs its own compile.
        modules: Mutex<HashMap<(usize, usize), Arc<CudaModule>>>,
    }
493
494    fn backend() -> Result<&'static Backend, GpuError> {
495        static BACKEND: OnceLock<Result<Backend, GpuError>> = OnceLock::new();
496        BACKEND
497            .get_or_init(|| {
498                let parts = gam_gpu::backend_probe::probe_cuda_backend("sae_rowjet")?;
499                Ok(Backend {
500                    ctx: parts.ctx,
501                    stream: parts.stream,
502                    modules: Mutex::new(HashMap::new()),
503                })
504            })
505            .as_ref()
506            .map_err(GpuError::clone)
507    }
508
509    fn module_for(b: &Backend, k: usize, p: usize) -> Result<Arc<CudaModule>, GpuError> {
510        if let Ok(guard) = b.modules.lock() {
511            if let Some(m) = guard.get(&(k, p)) {
512                return Ok(m.clone());
513            }
514        }
515        let src = softmax_kernel_source(k, p);
516        // Compile through the shared arch+fmad options (NOT bare `compile_ptx`).
517        // #1686 set `--fmad=false` there so this softmax seeded-jet tower is
518        // FMA-free and bit-comparable to the separately-rounded CPU oracle
519        // `sae_row_jets_cpu_softmax`; bare `compile_ptx` leaves NVRTC at
520        // `--fmad=true` (fuses `a*b+c` into one rounding) and omits the #1551
521        // `--gpu-architecture` pin. Same parity-correctness fix #1686 applied to
522        // survival_rowjet; this is the SAE sibling of that derivative tower.
523        let ptx = gam_gpu::device_cache::compile_ptx_arch(&src)
524            .gpu_ctx_with(|err| format!("sae_rowjet NVRTC compile (K={k}, P={p}): {err}"))?;
525        let module = b.ctx.load_module(ptx).gpu_ctx("sae_rowjet module load")?;
526        if let Ok(mut guard) = b.modules.lock() {
527            guard.entry((k, p)).or_insert_with(|| module.clone());
528        }
529        Ok(module)
530    }
531
532    /// Device implementation: flatten the per-row logits/decoded into the kernel
533    /// layout, launch one block per row (PP threads), download `first`/`second`.
534    pub(super) fn sae_row_jets_softmax_device(
535        rows: &[SaeSoftmaxRowInputs],
536        k: usize,
537        p: usize,
538        inv_tau: f64,
539    ) -> Result<SaeRowJetChannels, GpuError> {
540        let n = rows.len();
541        if n == 0 {
542            return Ok(SaeRowJetChannels {
543                n_rows: 0,
544                k,
545                p,
546                first: Vec::new(),
547                second: Vec::new(),
548            });
549        }
550        let b = backend()?;
551        let module = module_for(b, k, p)?;
552        let func = module
553            .load_function("sae_rowjet_softmax")
554            .gpu_ctx("sae_rowjet load_function")?;
555        let stream = b.stream.clone();
556
557        // Flatten inputs row-major: logits[n*k], decoded[n*k*p].
558        let mut logits = vec![0.0_f64; n * k];
559        let mut decoded = vec![0.0_f64; n * k * p];
560        for (row, inp) in rows.iter().enumerate() {
561            assert_eq!(inp.logits.len(), k, "SAE device row-jet logits length");
562            assert_eq!(
563                inp.decoded.len(),
564                k * p,
565                "SAE device row-jet decoded length"
566            );
567            logits[row * k..(row + 1) * k].copy_from_slice(&inp.logits);
568            decoded[row * k * p..(row + 1) * k * p].copy_from_slice(&inp.decoded);
569        }
570
571        let logits_dev = stream
572            .clone_htod(&logits)
573            .gpu_ctx("sae_rowjet htod logits")?;
574        let decoded_dev = stream
575            .clone_htod(&decoded)
576            .gpu_ctx("sae_rowjet htod decoded")?;
577        let mut first_dev = stream
578            .alloc_zeros::<f64>(n * k * p)
579            .gpu_ctx("sae_rowjet alloc first")?;
580        let mut second_dev = stream
581            .alloc_zeros::<f64>(n * k * k * p)
582            .gpu_ctx("sae_rowjet alloc second")?;
583
584        let n_i32 =
585            i32::try_from(n).map_err(|_| gam_gpu::gpu_err!("sae_rowjet n={n} overflows i32"))?;
586        let block: u32 = u32::try_from(p.max(1).min(256))
587            .map_err(|_| gam_gpu::gpu_err!("sae_rowjet block size overflow"))?;
588        let cfg = LaunchConfig {
589            grid_dim: (n_i32 as u32, 1, 1),
590            block_dim: (block, 1, 1),
591            shared_mem_bytes: 0,
592        };
593        let mut builder = stream.launch_builder(&func);
594        builder
595            .arg(&logits_dev)
596            .arg(&decoded_dev)
597            .arg(&inv_tau)
598            .arg(&n_i32)
599            .arg(&mut first_dev)
600            .arg(&mut second_dev);
601        // SAFETY: grid/block validated; all device pointers are cudarc-checked
602        // allocations on this stream; the kernel reads logits/decoded and writes
603        // within first[0..n*k*p] / second[0..n*k*k*p].
604        unsafe { builder.launch(cfg) }.gpu_ctx("sae_rowjet kernel launch")?;
605
606        let mut first = vec![0.0_f64; n * k * p];
607        let mut second = vec![0.0_f64; n * k * k * p];
608        stream
609            .memcpy_dtoh(&first_dev, &mut first)
610            .gpu_ctx("sae_rowjet dtoh first")?;
611        stream
612            .memcpy_dtoh(&second_dev, &mut second)
613            .gpu_ctx("sae_rowjet dtoh second")?;
614        stream.synchronize().gpu_ctx("sae_rowjet synchronize")?;
615
616        Ok(SaeRowJetChannels {
617            n_rows: n,
618            k,
619            p,
620            first,
621            second,
622        })
623    }
624}
625
626#[cfg(test)]
627mod tests {
628    use super::*;
629
630    fn fixture(n: usize, k: usize, p: usize) -> Vec<SaeSoftmaxRowInputs> {
631        let mut rows = Vec::with_capacity(n);
632        for i in 0..n {
633            let logits = (0..k)
634                .map(|j| 0.7 * ((i * 31 + j * 17) as f64 * 0.013).sin())
635                .collect();
636            let decoded = (0..k * p)
637                .map(|t| ((i * 7 + t * 5) as f64 * 0.011).cos())
638                .collect();
639            rows.push(SaeSoftmaxRowInputs { logits, decoded });
640        }
641        rows
642    }
643
644    #[test]
645    fn cpu_softmax_matches_unified_program_k8() {
646        // The CPU fallback IS the production `reconstruction_all_columns_packed`,
647        // so this pins the flattening/layout: a single row's gate-only
648        // reconstruction has gradient = ζ'(ℓ)·decoded and the Hessian is the
649        // dense softmax second derivative contracted with decoded. We assert the
650        // gradient channel reproduces the analytic softmax Jacobian times
651        // decoded for a sanity check on the layout.
652        let k = 8;
653        let p = 4;
654        let inv_tau = 1.0 / 0.7;
655        let rows = fixture(3, k, p);
656        let out = sae_row_jets_cpu_softmax(&rows, k, p, inv_tau);
657        assert_eq!(out.first.len(), 3 * k * p);
658        assert_eq!(out.second.len(), 3 * k * k * p);
659        // Analytic softmax Jacobian J[a][m] = inv_tau * ζ_a (δ_am − ζ_m); the
660        // reconstruction column c gradient wrt primary a is
661        // Σ_m J[m][a] * decoded[m*p+c] = inv_tau*(ζ_a*decoded[a] − ζ_a*Σ_m ζ_m decoded[m]).
662        let inp = &rows[0];
663        let z = softmax_values(&inp.logits, inv_tau);
664        for c in 0..p {
665            let mean: f64 = (0..k).map(|m| z[m] * inp.decoded[m * p + c]).sum();
666            for a in 0..k {
667                let analytic = inv_tau * z[a] * (inp.decoded[a * p + c] - mean);
668                let got = out.first[(a) * p + c];
669                assert!(
670                    (analytic - got).abs() <= 1e-12,
671                    "softmax grad mismatch a={a} c={c}: analytic={analytic} got={got}"
672                );
673            }
674        }
675    }
676
677    #[test]
678    fn second_channel_is_symmetric() {
679        let k = 6;
680        let p = 3;
681        let inv_tau = 1.3;
682        let rows = fixture(2, k, p);
683        let out = sae_row_jets_cpu_softmax(&rows, k, p, inv_tau);
684        for row in 0..2 {
685            for c in 0..p {
686                for a in 0..k {
687                    for b in 0..k {
688                        let ab = out.second[((row * k + a) * k + b) * p + c];
689                        let ba = out.second[((row * k + b) * k + a) * p + c];
690                        assert!(
691                            (ab - ba).abs() <= 1e-12,
692                            "asymmetry row={row} c={c} {a},{b}"
693                        );
694                    }
695                }
696            }
697        }
698    }
699
700    #[test]
701    fn gauss_newton_slab_is_symmetric_psd_gram() {
702        // H_tt = Σ_c first[a][c]·first[b][c] is a Gram matrix: symmetric and PSD.
703        let k = 5;
704        let p = 7;
705        let inv_tau = 1.0 / 0.9;
706        let rows = fixture(4, k, p);
707        let ch = sae_row_jets_cpu_softmax(&rows, k, p, inv_tau);
708        let slabs = gauss_newton_row_hessian_slabs(&ch);
709        assert_eq!(slabs.len(), 4 * k * k);
710        for row in 0..4 {
711            let s = &slabs[row * k * k..(row + 1) * k * k];
712            // symmetry + matches the explicit Σ_c contraction
713            let f = &ch.first[row * k * p..(row + 1) * k * p];
714            for a in 0..k {
715                for b in 0..k {
716                    let expect: f64 = (0..p).map(|c| f[a * p + c] * f[b * p + c]).sum();
717                    assert!((s[a * k + b] - expect).abs() <= 1e-12);
718                    assert!((s[a * k + b] - s[b * k + a]).abs() <= 1e-12);
719                }
720            }
721            // PSD: vᵀHv = ‖Σ_a v_a J_a‖² ≥ 0 for a random v.
722            let v: Vec<f64> = (0..k).map(|a| ((a * 13 + 1) as f64 * 0.3).sin()).collect();
723            let mut quad = 0.0;
724            for a in 0..k {
725                for b in 0..k {
726                    quad += v[a] * s[a * k + b] * v[b];
727                }
728            }
729            assert!(quad >= -1e-12, "GN slab not PSD: vᵀHv={quad}");
730        }
731    }
732
733    #[cfg(target_os = "linux")]
734    #[test]
735    fn device_matches_cpu_when_available() {
736        // Exactness gate: when a device is admitted, the device channels must
737        // match the CPU unified jet to <=1e-9. When no device is available the
738        // dispatcher returns the CPU result (still correct), so this asserts the
739        // contract on whichever path ran.
740        let k = 8;
741        let p = 16;
742        let inv_tau = 1.0 / 0.7;
743        let rows = fixture(DEVICE_ROW_THRESHOLD + 64, k, p);
744        let cpu = sae_row_jets_cpu_softmax(&rows, k, p, inv_tau);
745        let got = sae_row_jets_softmax(&rows, k, p, inv_tau);
746        let max_diff = |a: &SaeRowJetChannels, b: &SaeRowJetChannels| {
747            let mut m = 0.0_f64;
748            for (x, y) in a.first.iter().zip(&b.first) {
749                m = m.max((x - y).abs());
750            }
751            for (x, y) in a.second.iter().zip(&b.second) {
752                m = m.max((x - y).abs());
753            }
754            m
755        };
756        let maxabs = max_diff(&cpu, &got);
757        assert!(
758            maxabs <= 1e-9,
759            "device vs CPU row-jet max abs diff {maxabs} > 1e-9"
760        );
761
762        // ANTI-FALSE-GREEN (#415/#1175): the assert above passes trivially as
763        // CPU==CPU when no GPU is present — a dead/declined kernel would never
764        // be caught. So when a runtime IS admitted, call the device entry
765        // DIRECTLY (no silent fall-through to CPU) and require it to actually
766        // run and match the oracle. With #1686's --fmad=false now applied to
767        // this kernel (it compiles through `compile_ptx_arch`), the measured
768        // device-vs-CPU drift on a V100 is ~1.7e-16 — the FMA-free softmax
769        // seeded-jet is round-off-floor tight, far inside the 1e-9 gate.
770        if gam_gpu::device_runtime::GpuRuntime::global().is_some() {
771            let dev = device::sae_row_jets_softmax_device(&rows, k, p, inv_tau)
772                .expect("admitted GPU runtime must run the sae_rowjet device kernel, not fall back");
773            let dev_diff = max_diff(&cpu, &dev);
774            assert!(
775                dev_diff <= 1e-9,
776                "device-only sae row-jet vs CPU max abs diff {dev_diff} > 1e-9 \
777                 (kernel ran but drifted — check the softmax jet recurrence)"
778            );
779        }
780    }
781}