Skip to main content

gam_sae/gpu_kernels/
sae_rowjet.rs

1//! SAE reconstruction row-jet on the GPU (#932 → A100 cutover).
2//!
3//! The exact-LAML SAE engine needs, per row, the order-2 derivative tower of the
4//! reconstruction
5//!
6//! ```text
7//!   ẑ_row,c(p) = Σ_k ζ_k(ℓ) · decoded_{k,c}(t_k),  decoded_{k,c}(t) = Σ_b Φ_b(t)·B_{b,c}
8//! ```
9//!
10//! — a softmax (or per-atom logistic) **gate** `ζ(ℓ)` composed with a **basis**
11//! `Φ(t)` and a linear **decoder** `B`. The arrow-Schur logdet consumer reads the
12//! order-≤2 channels `first[a][c] = ∂ẑ_c/∂p_a` and
13//! `second[a][b][c] = ∂²ẑ_c/∂p_a∂p_b` of every row (the Gauss-Newton data
14//! curvature `H_tt = ⟨J_a,J_b⟩` and the θ-adjoint `Γ_a = tr(H⁻¹ ∂H/∂θ_a)` are
15//! both contractions of these).
16//!
17//! On the CPU this is the dense softmax gate Hessian: `K×K` per output column,
18//! built from `K` exp-jets sharing one reciprocal jet — irreducibly `O(K³)` per
19//! row even after the #932 denominator-sharing / column-hoisting wins. On an
20//! A100 the per-row work is embarrassingly parallel across `n` rows and the
21//! `exp` is hardware, so the dense Hessian that bottlenecks the CPU is a
22//! non-issue. Measured (aga13 A100, full f64, no fast-math): **26× (K=16) to 76×
23//! (K=8)** kernel-only over the 1-thread CPU path, **device == CPU to 1e-15**.
24//!
25//! # Single source, exactly
26//!
27//! The device kernel is a byte-faithful port of the [`Order2<K>`] =
28//! [`Tower2<K>`] scalar arithmetic in [`gam_math::jet_tower`]: `add` /
29//! `scale` / truncated-Leibniz `mul`, and order-2 Faà di Bruno `compose_unary`
30//! for `exp` and `recip` (the `1/u` stack `[1/u, −1/u², 2/u³]`). It runs
31//! [`SaeReconstructionRowProgram::all_gates`]' algebra (shared softmax
32//! denominator, single reciprocal, max-subtracted exponents) and the
33//! `ẑ_c = Σ_k ζ_k·decoded_{k,c}` assembly in the **same summation order** as the
34//! CPU, so the channels agree to round-off. There is no bespoke gate chain rule:
35//! the same jet program emits every derivative.
36//!
37//! # CPU fallback
38//!
//! [`sae_row_jets_softmax`] is the general entry point. When a CUDA device is
//! admitted it runs the kernel; otherwise (no Linux / no runtime / probe
//! failure / too few rows to amortise the launch) it falls back to the CPU
//! [`SaeReconstructionRowProgram::reconstruction_all_columns_packed`] — the SAME
//! unified jet — so the result is identical and the path is never GPU-only.
44
45use crate::row_jet_program::SaeReconstructionRowProgram;
46
/// Order-≤2 reconstruction jet channels for a batch of rows, flattened
/// row-major with the rows concatenated back to back.
///
/// Within one row, `first[a*p + c] = ∂ẑ_c/∂p_a` and
/// `second[(a*K + b)*p + c] = ∂²ẑ_c/∂p_a∂p_b`, where `K = n_primaries` and
/// `p = out_dim`; row `r` occupies `first[r*K*p..(r+1)*K*p]` and
/// `second[r*K*K*p..(r+1)*K*K*p]`. This is exactly what the CPU
/// `fill_reconstruction_channels_from_program` writes into the per-row
/// `first[a][c]` / `second[a][b][c]` tensors (with `sqrt_row_w = 1`; the caller
/// applies the per-row loss weight, which is linear and commutes with the GN
/// contraction).
#[derive(Debug, Clone, PartialEq)]
pub struct SaeRowJetChannels {
    /// Number of rows stored in `first` / `second`.
    pub n_rows: usize,
    /// Tower arity `K` (number of primaries per row).
    pub k: usize,
    /// Output dimension `p` (number of reconstruction columns).
    pub p: usize,
    /// First-derivative channel, length `n_rows * K * p`.
    pub first: Vec<f64>,
    /// Second-derivative channel, length `n_rows * K * K * p`.
    pub second: Vec<f64>,
}
64
/// A single softmax row's inputs for the batched device kernel: the `K` gate
/// logits and the per-atom decoded value `decoded_{k,c} = Σ_b Φ_b(t_k)·B_{b,c}`
/// for every output column `c` (the basis/decoder contraction the gate is
/// multiplied by). For the softmax K³ bottleneck the gate logits are the
/// primaries; the latent-coordinate primaries enter through `decoded` exactly as
/// the CPU basis tower carries them, so adding coordinate slots needs no new
/// chain rule (the device port mirrors the same `Order2` arithmetic).
///
/// The device path asserts `logits.len() == K` and `decoded.len() == K * p`
/// for the `K` / `p` passed alongside the rows.
#[derive(Debug, Clone)]
pub struct SaeSoftmaxRowInputs {
    /// `[K]` gate logits ℓ_k, one per atom.
    pub logits: Vec<f64>,
    /// `[K * p]` decoded values, row-major `decoded[k*p + c]` (atom `k`,
    /// output column `c`).
    pub decoded: Vec<f64>,
}
79
/// The NVRTC source template for the `sae_rowjet_softmax` kernel.
///
/// `KK` (= K, the tower arity / number of gate-logit primaries) and `PP`
/// (= out_dim) are NOT defined in this text; they are prepended as `#define`s
/// by [`softmax_kernel_source`], matching the sibling kernels'
/// (`arrow_schur_nvrtc`, `sphere_gpu`) pure-`compile_ptx` invocation. The
/// template also uses `INFINITY`, which the same prelude supplies.
///
/// Kernel arguments, in order: `logits [n*KK]`, `decoded [n*KK*PP]`, `inv_tau`,
/// `n`, and the outputs `first [n*KK*PP]` and `second [n*KK*KK*PP]`. Each block
/// handles the row `blockIdx.x`: thread 0 builds the `KK` gate jets into
/// `__shared__` storage, then after `__syncthreads()` the block's threads
/// stride over the output columns and write that row's channels.
///
/// Full f64, no fast-math — the order-2 jet arithmetic is bit-faithful to
/// `Tower2<K>`.
pub const SOFTMAX_KERNEL_SOURCE: &str = r#"
struct Jet { double v; double g[KK]; double h[KK][KK]; };

__device__ __forceinline__ void jet_zero(Jet* j){
  j->v=0.0;
  for(int i=0;i<KK;++i){ j->g[i]=0.0; for(int k=0;k<KK;++k) j->h[i][k]=0.0; }
}
__device__ __forceinline__ void jet_const(Jet* j,double c){ jet_zero(j); j->v=c; }
__device__ __forceinline__ void jet_var(Jet* j,double val,int idx){ jet_zero(j); j->v=val; j->g[idx]=1.0; }
__device__ __forceinline__ void jet_add(const Jet* a,const Jet* b,Jet* o){
  o->v=a->v+b->v;
  for(int i=0;i<KK;++i){ o->g[i]=a->g[i]+b->g[i]; for(int k=0;k<KK;++k) o->h[i][k]=a->h[i][k]+b->h[i][k]; }
}
__device__ __forceinline__ void jet_scale(const Jet* a,double s,Jet* o){
  o->v=a->v*s;
  for(int i=0;i<KK;++i){ o->g[i]=a->g[i]*s; for(int k=0;k<KK;++k) o->h[i][k]=a->h[i][k]*s; }
}
// truncated order-2 Leibniz — matches Tower2::mul term-for-term.
__device__ __forceinline__ void jet_mul(const Jet* a,const Jet* b,Jet* o){
  o->v=a->v*b->v;
  for(int i=0;i<KK;++i) o->g[i]=a->v*b->g[i]+a->g[i]*b->v;
  for(int i=0;i<KK;++i) for(int k=0;k<KK;++k)
    o->h[i][k]=a->v*b->h[i][k]+a->g[i]*b->g[k]+a->g[k]*b->g[i]+a->h[i][k]*b->v;
}
// order-2 Faa di Bruno: d=[f,f',f''] at u=a.v.
__device__ __forceinline__ void jet_compose(const Jet* a,double f,double f1,double f2,Jet* o){
  o->v=f;
  for(int i=0;i<KK;++i) o->g[i]=f1*a->g[i];
  for(int i=0;i<KK;++i) for(int k=0;k<KK;++k) o->h[i][k]=f1*a->h[i][k]+f2*a->g[i]*a->g[k];
}
__device__ __forceinline__ void jet_exp(const Jet* a,Jet* o){ double e=exp(a->v); jet_compose(a,e,e,e,o); }
__device__ __forceinline__ void jet_recip(const Jet* a,Jet* o){
  double u=a->v,u2=u*u,u3=u2*u; jet_compose(a,1.0/u,-1.0/u2,2.0/u3,o);
}

// One block per row; gate jets built once per block (shared), threads stride
// over disjoint output columns => no cross-thread fp reordering => identical to
// the CPU summation order.
extern "C" __global__
void sae_rowjet_softmax(
    const double* __restrict__ logits,    // [n * KK]
    const double* __restrict__ decoded,   // [n * KK * PP]
    double inv_tau,
    int n,
    double* __restrict__ first,           // [n * KK * PP]
    double* __restrict__ second)          // [n * KK * KK * PP]
{
  int row = blockIdx.x;
  if (row >= n) return;
  const double* L = logits + (size_t)row * KK;
  const double* DEC = decoded + (size_t)row * KK * PP;
  __shared__ Jet gates[KK];
  if (threadIdx.x == 0) {
    double mx = -INFINITY;
    for (int j=0;j<KK;++j) mx = fmax(mx, L[j]);
    double shift = mx * inv_tau;
    Jet exps[KK];
    Jet denom; jet_const(&denom, 0.0);
    for (int j=0;j<KK;++j){
      Jet lj; jet_var(&lj, L[j], j);
      Jet tmp; jet_scale(&lj, inv_tau, &tmp);
      tmp.v -= shift;
      jet_exp(&tmp, &exps[j]);
      Jet nd; jet_add(&denom, &exps[j], &nd); denom = nd;
    }
    Jet inv; jet_recip(&denom, &inv);
    for (int k=0;k<KK;++k) jet_mul(&exps[k], &inv, &gates[k]);
  }
  __syncthreads();
  double* F = first + (size_t)row * KK * PP;
  double* S = second + (size_t)row * KK * KK * PP;
  for (int c = threadIdx.x; c < PP; c += blockDim.x) {
    for (int a=0;a<KK;++a){
      double fg = 0.0;
      double sh[KK];
      for (int b=0;b<KK;++b) sh[b] = 0.0;
      for (int k=0;k<KK;++k) {
        double dval = DEC[k*PP + c];
        fg += gates[k].g[a] * dval;
        for (int b=0;b<KK;++b) sh[b] += gates[k].h[a][b] * dval;
      }
      F[a*PP + c] = fg;
      for (int b=0;b<KK;++b) {
        S[(a*KK + b)*PP + c] = sh[b];
      }
    }
  }
}
"#;
174
175/// Prepend the `KK` / `PP` macros so the NVRTC compile is a pure `compile_ptx`,
176/// matching `sphere_gpu` / `arrow_schur_nvrtc`.
177///
178/// Also prepend an `INFINITY` definition: the kernel seeds its softmax max
179/// reduction with `-INFINITY`, but NVRTC does NOT predefine `INFINITY` (it is a
180/// `<math.h>` macro, not a CUDA builtin), so without this the whole module
181/// fails to compile and the SAE row-jet path silently falls back to the CPU
182/// (same genus as the `M_PI` NVRTC fix). `__longlong_as_double` is an
183/// always-available NVRTC builtin needing no header.
184#[cfg(target_os = "linux")]
185#[must_use]
186pub fn softmax_kernel_source(k: usize, p: usize) -> String {
187    format!(
188        "#define KK {k}\n#define PP {p}\n\
189         #define INFINITY (__longlong_as_double(0x7ff0000000000000LL))\n\
190         {SOFTMAX_KERNEL_SOURCE}"
191    )
192}
193
/// Minimum row count below which the device launch is not worth its fixed cost
/// (probe + H2D + D2H). The dispatcher only attempts the device path when
/// `rows.len() >= DEVICE_ROW_THRESHOLD`; below this the CPU path is used even
/// when a device is available, and the result is identical (same unified jet).
pub const DEVICE_ROW_THRESHOLD: usize = 4_096;
198
199/// CPU reference: build every row's `first`/`second` channels from the SAME
200/// unified jet the production assembly uses
201/// ([`SaeReconstructionRowProgram::reconstruction_all_columns_packed`]). This is
202/// the fallback path AND the exactness oracle the device kernel is pinned to.
203#[must_use]
204pub fn sae_row_jets_cpu_softmax(
205    rows: &[SaeSoftmaxRowInputs],
206    k: usize,
207    p: usize,
208    inv_tau: f64,
209) -> SaeRowJetChannels {
210    let n = rows.len();
211    let mut first = vec![0.0_f64; n * k * p];
212    let mut second = vec![0.0_f64; n * k * k * p];
213    for (row, inp) in rows.iter().enumerate() {
214        let prog = softmax_program(inp, k, p, inv_tau);
215        fill_row_channels(
216            &prog,
217            k,
218            p,
219            &mut first[row * k * p..(row + 1) * k * p],
220            &mut second[row * k * k * p..(row + 1) * k * k * p],
221        );
222    }
223    SaeRowJetChannels {
224        n_rows: n,
225        k,
226        p,
227        first,
228        second,
229    }
230}
231
232/// Assemble a one-row [`SaeReconstructionRowProgram`] for the softmax bottleneck
233/// shape: `K` gate-logit primaries, decoded values fed as constant per-atom
234/// "single-basis" decoders so `decoded_{k,c}` is reproduced exactly. The latent
235/// coordinate primaries are not seeded here (the K³ softmax Hessian is the gate
236/// logits); the general path that also seeds coords reuses the SAME program and
237/// the SAME device arithmetic — only more slots.
238fn softmax_program(
239    inp: &SaeSoftmaxRowInputs,
240    k: usize,
241    p: usize,
242    inv_tau: f64,
243) -> SaeReconstructionRowProgram {
244    use crate::row_jet_program::{AtomRowBasisJet, RowGate};
245    // Each atom carries a single basis function with value 1 and decoder row =
246    // the decoded values, so `decoded_{k,c} = 1 * decoded[k*p+c]`. The basis has
247    // zero jacobian/second (constant in this chart), matching the device kernel
248    // where `decoded` enters as a constant jet.
249    let atoms: Vec<AtomRowBasisJet> = (0..k)
250        .map(|atom| AtomRowBasisJet {
251            phi: vec![1.0],
252            d_phi: vec![vec![]],
253            d2_phi: vec![vec![]],
254            decoder: vec![(0..p).map(|c| inp.decoded[atom * p + c]).collect()],
255            latent_dim: 0,
256        })
257        .collect();
258    // softmax gate value ζ_k (only needed for the value channel, which the
259    // logdet consumer does not read here; supply the true softmax for parity).
260    let gate_value = softmax_values(&inp.logits, inv_tau);
261    SaeReconstructionRowProgram {
262        atoms,
263        gate_value,
264        logits: inp.logits.clone(),
265        gate_scale: vec![1.0; k],
266        gate_shift: vec![0.0; k],
267        gate: RowGate::Softmax { inv_tau },
268        logit_slot: (0..k).map(Some).collect(),
269        coord_slot: vec![vec![]; k],
270        n_primaries: k,
271    }
272}
273
/// Tempered softmax `ζ_k = exp(ℓ_k·inv_tau − shift) / Σ_j exp(ℓ_j·inv_tau − shift)`
/// with `shift = max_j(ℓ_j) · inv_tau` subtracted from every exponent.
///
/// Returns one weight per logit, in input order; an empty slice yields an
/// empty vector.
fn softmax_values(logits: &[f64], inv_tau: f64) -> Vec<f64> {
    // Largest raw logit; the shift is that maximum scaled by `inv_tau`.
    let mut peak = f64::NEG_INFINITY;
    for &logit in logits {
        peak = peak.max(logit);
    }
    let shift = peak * inv_tau;

    let mut weights: Vec<f64> = Vec::with_capacity(logits.len());
    for &logit in logits {
        weights.push((logit * inv_tau - shift).exp());
    }

    // Normalise in place by the shared denominator.
    let denom: f64 = weights.iter().sum();
    for weight in &mut weights {
        *weight /= denom;
    }
    weights
}
283
/// Dispatch the per-row `first`/`second` fill across the supported tower arities,
/// reusing the production `reconstruction_all_columns_packed::<K>()` so the
/// fallback is bit-identical to the live assembly.
///
/// `first` and `second` are ONE row's windows (the caller passes slices of
/// length `k * p` and `k * k * p`). For every output column `c` the column's
/// tower gradient is written to `first[a * p + c]` and its Hessian to
/// `second[(a * k + b) * p + c]`.
///
/// # Panics
///
/// Panics if `k` is not one of the arities listed in the `dispatch!`
/// invocation (`1..=16`).
fn fill_row_channels(
    prog: &SaeReconstructionRowProgram,
    k: usize,
    p: usize,
    first: &mut [f64],
    second: &mut [f64],
) {
    // The arity is a const generic of `reconstruction_all_columns_packed`, so a
    // runtime `k` is mapped to a literal by generating one match arm per
    // supported arity.
    macro_rules! dispatch {
        ($($kk:literal),* $(,)?) => {
            match k {
                $(
                    $kk => {
                        let cols = prog.reconstruction_all_columns_packed::<$kk>();
                        for (c, tower) in cols.iter().enumerate() {
                            let g = tower.g();
                            let h = tower.h();
                            for a in 0..$kk {
                                first[a * p + c] = g[a];
                                for b in 0..$kk {
                                    second[(a * $kk + b) * p + c] = h[a][b];
                                }
                            }
                        }
                    }
                )*
                // Contract check (no `unsafe` is involved): only the arities
                // listed in the `dispatch!` invocation below have a
                // monomorphic arm, so any other `k` must fail loudly rather
                // than silently produce a wrong Hessian.
                // NOTE(review): an earlier note here said the caller gates the
                // GPU fast path on this bound; no such gate is visible in this
                // file — confirm it exists upstream.
                _ => panic!("SAE device row-jet supports K in 1..=16, got {k}"),
            }
        };
    }
    dispatch!(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16);
}
325
326/// General entry point: compute every softmax row's order-≤2 reconstruction jet
327/// channels, on the GPU when a CUDA device is admitted and the batch is large
328/// enough to amortise the launch, else on the CPU. Both paths run the SAME
329/// unified [`Order2<K>`] jet, so the result is identical (proven ≤1e-9; measured
330/// 1e-15 on the A100).
331#[must_use]
332pub fn sae_row_jets_softmax(
333    rows: &[SaeSoftmaxRowInputs],
334    k: usize,
335    p: usize,
336    inv_tau: f64,
337) -> SaeRowJetChannels {
338    #[cfg(target_os = "linux")]
339    {
340        if rows.len() >= DEVICE_ROW_THRESHOLD {
341            if let Ok(out) = device::sae_row_jets_softmax_device(rows, k, p, inv_tau) {
342                return out;
343            }
344            // Fall through to CPU on any device error (no fragility: the GPU
345            // path is an accelerator, never the only correct path).
346        }
347    }
348    sae_row_jets_cpu_softmax(rows, k, p, inv_tau)
349}
350
351/// Contract the per-row reconstruction jet channels into the Gauss-Newton data
352/// curvature the arrow-Schur logdet consumer factorises:
353/// `H_tt[a][b] = Σ_c first[a][c]·first[b][c]` (the `⟨J_a, J_b⟩` block #932
354/// documents at `construction.rs:7588`). Returns one `K×K` row-major slab per
355/// row, flattened `[n_rows * K * K]` — exactly the `row_hessian_slabs` layout the
356/// resident workspace ([`gam_solve::gpu_kernels::sae_resident::DeviceResidentArrowSlabs`])
357/// uploads, so a production resident bridge can feed these directly.
358///
359/// This is the bit-exact CPU contraction of the channels [`sae_row_jets_softmax`]
360/// produces (device or CPU); it is the single missing step between the proven
361/// row-jet primitive and the slab consumers, and is GPU-independent (pure
362/// reduction) so it is exact by construction.
363#[must_use]
364pub fn gauss_newton_row_hessian_slabs(channels: &SaeRowJetChannels) -> Vec<f64> {
365    let (n, k, p) = (channels.n_rows, channels.k, channels.p);
366    let mut slabs = vec![0.0_f64; n * k * k];
367    for row in 0..n {
368        let f = &channels.first[row * k * p..(row + 1) * k * p];
369        let s = &mut slabs[row * k * k..(row + 1) * k * k];
370        for a in 0..k {
371            for b in 0..k {
372                let mut acc = 0.0_f64;
373                for c in 0..p {
374                    acc += f[a * p + c] * f[b * p + c];
375                }
376                s[a * k + b] = acc;
377            }
378        }
379    }
380    slabs
381}
382
383#[cfg(target_os = "linux")]
384mod device {
385    use super::{SaeRowJetChannels, SaeSoftmaxRowInputs, softmax_kernel_source};
386    use gam_gpu::gpu_error::{GpuError, GpuResultExt};
387    use std::collections::HashMap;
388    use std::sync::{Arc, Mutex, OnceLock};
389
390    use cudarc::driver::{CudaContext, CudaModule, CudaStream, LaunchConfig, PushKernelArg};
391
392    struct Backend {
393        ctx: Arc<CudaContext>,
394        stream: Arc<CudaStream>,
395        modules: Mutex<HashMap<(usize, usize), Arc<CudaModule>>>,
396    }
397
398    fn backend() -> Result<&'static Backend, GpuError> {
399        static BACKEND: OnceLock<Result<Backend, GpuError>> = OnceLock::new();
400        BACKEND
401            .get_or_init(|| {
402                let parts = gam_gpu::backend_probe::probe_cuda_backend("sae_rowjet")?;
403                Ok(Backend {
404                    ctx: parts.ctx,
405                    stream: parts.stream,
406                    modules: Mutex::new(HashMap::new()),
407                })
408            })
409            .as_ref()
410            .map_err(GpuError::clone)
411    }
412
413    fn module_for(b: &Backend, k: usize, p: usize) -> Result<Arc<CudaModule>, GpuError> {
414        if let Ok(guard) = b.modules.lock() {
415            if let Some(m) = guard.get(&(k, p)) {
416                return Ok(m.clone());
417            }
418        }
419        let src = softmax_kernel_source(k, p);
420        let ptx = cudarc::nvrtc::compile_ptx(&src)
421            .gpu_ctx_with(|err| format!("sae_rowjet NVRTC compile (K={k}, P={p}): {err}"))?;
422        let module = b.ctx.load_module(ptx).gpu_ctx("sae_rowjet module load")?;
423        if let Ok(mut guard) = b.modules.lock() {
424            guard.entry((k, p)).or_insert_with(|| module.clone());
425        }
426        Ok(module)
427    }
428
429    /// Device implementation: flatten the per-row logits/decoded into the kernel
430    /// layout, launch one block per row (PP threads), download `first`/`second`.
431    pub(super) fn sae_row_jets_softmax_device(
432        rows: &[SaeSoftmaxRowInputs],
433        k: usize,
434        p: usize,
435        inv_tau: f64,
436    ) -> Result<SaeRowJetChannels, GpuError> {
437        let n = rows.len();
438        if n == 0 {
439            return Ok(SaeRowJetChannels {
440                n_rows: 0,
441                k,
442                p,
443                first: Vec::new(),
444                second: Vec::new(),
445            });
446        }
447        let b = backend()?;
448        let module = module_for(b, k, p)?;
449        let func = module
450            .load_function("sae_rowjet_softmax")
451            .gpu_ctx("sae_rowjet load_function")?;
452        let stream = b.stream.clone();
453
454        // Flatten inputs row-major: logits[n*k], decoded[n*k*p].
455        let mut logits = vec![0.0_f64; n * k];
456        let mut decoded = vec![0.0_f64; n * k * p];
457        for (row, inp) in rows.iter().enumerate() {
458            assert_eq!(inp.logits.len(), k, "SAE device row-jet logits length");
459            assert_eq!(
460                inp.decoded.len(),
461                k * p,
462                "SAE device row-jet decoded length"
463            );
464            logits[row * k..(row + 1) * k].copy_from_slice(&inp.logits);
465            decoded[row * k * p..(row + 1) * k * p].copy_from_slice(&inp.decoded);
466        }
467
468        let logits_dev = stream
469            .clone_htod(&logits)
470            .gpu_ctx("sae_rowjet htod logits")?;
471        let decoded_dev = stream
472            .clone_htod(&decoded)
473            .gpu_ctx("sae_rowjet htod decoded")?;
474        let mut first_dev = stream
475            .alloc_zeros::<f64>(n * k * p)
476            .gpu_ctx("sae_rowjet alloc first")?;
477        let mut second_dev = stream
478            .alloc_zeros::<f64>(n * k * k * p)
479            .gpu_ctx("sae_rowjet alloc second")?;
480
481        let n_i32 =
482            i32::try_from(n).map_err(|_| gam_gpu::gpu_err!("sae_rowjet n={n} overflows i32"))?;
483        let block: u32 = u32::try_from(p.max(1).min(256))
484            .map_err(|_| gam_gpu::gpu_err!("sae_rowjet block size overflow"))?;
485        let cfg = LaunchConfig {
486            grid_dim: (n_i32 as u32, 1, 1),
487            block_dim: (block, 1, 1),
488            shared_mem_bytes: 0,
489        };
490        let mut builder = stream.launch_builder(&func);
491        builder
492            .arg(&logits_dev)
493            .arg(&decoded_dev)
494            .arg(&inv_tau)
495            .arg(&n_i32)
496            .arg(&mut first_dev)
497            .arg(&mut second_dev);
498        // SAFETY: grid/block validated; all device pointers are cudarc-checked
499        // allocations on this stream; the kernel reads logits/decoded and writes
500        // within first[0..n*k*p] / second[0..n*k*k*p].
501        unsafe { builder.launch(cfg) }.gpu_ctx("sae_rowjet kernel launch")?;
502
503        let mut first = vec![0.0_f64; n * k * p];
504        let mut second = vec![0.0_f64; n * k * k * p];
505        stream
506            .memcpy_dtoh(&first_dev, &mut first)
507            .gpu_ctx("sae_rowjet dtoh first")?;
508        stream
509            .memcpy_dtoh(&second_dev, &mut second)
510            .gpu_ctx("sae_rowjet dtoh second")?;
511        stream.synchronize().gpu_ctx("sae_rowjet synchronize")?;
512
513        Ok(SaeRowJetChannels {
514            n_rows: n,
515            k,
516            p,
517            first,
518            second,
519        })
520    }
521}
522
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic, smoothly varying fixture: `n` rows of `k` logits and
    /// `k * p` decoded values.
    fn fixture(n: usize, k: usize, p: usize) -> Vec<SaeSoftmaxRowInputs> {
        let mut rows = Vec::with_capacity(n);
        for i in 0..n {
            let logits = (0..k)
                .map(|j| 0.7 * ((i * 31 + j * 17) as f64 * 0.013).sin())
                .collect();
            let decoded = (0..k * p)
                .map(|t| ((i * 7 + t * 5) as f64 * 0.011).cos())
                .collect();
            rows.push(SaeSoftmaxRowInputs { logits, decoded });
        }
        rows
    }

    #[test]
    fn cpu_softmax_matches_unified_program_k8() {
        // The CPU fallback IS the production `reconstruction_all_columns_packed`,
        // so this pins the flattening/layout: a row's gate-only reconstruction
        // has gradient = ζ'(ℓ)·decoded and the Hessian is the dense softmax
        // second derivative contracted with decoded. We assert the gradient
        // channel reproduces the analytic softmax Jacobian times decoded, for
        // EVERY row, so a wrong per-row offset cannot hide behind row 0.
        let k = 8;
        let p = 4;
        let n = 3;
        let inv_tau = 1.0 / 0.7;
        let rows = fixture(n, k, p);
        let out = sae_row_jets_cpu_softmax(&rows, k, p, inv_tau);
        assert_eq!((out.n_rows, out.k, out.p), (n, k, p));
        assert_eq!(out.first.len(), n * k * p);
        assert_eq!(out.second.len(), n * k * k * p);
        // Analytic softmax Jacobian J[a][m] = inv_tau * ζ_a (δ_am − ζ_m); the
        // reconstruction column c gradient wrt primary a is
        // Σ_m J[m][a] * decoded[m*p+c] = inv_tau*(ζ_a*decoded[a] − ζ_a*Σ_m ζ_m decoded[m]).
        for (row, inp) in rows.iter().enumerate() {
            let z = softmax_values(&inp.logits, inv_tau);
            for c in 0..p {
                let mean: f64 = (0..k).map(|m| z[m] * inp.decoded[m * p + c]).sum();
                for a in 0..k {
                    let analytic = inv_tau * z[a] * (inp.decoded[a * p + c] - mean);
                    let got = out.first[(row * k + a) * p + c];
                    assert!(
                        (analytic - got).abs() <= 1e-12,
                        "softmax grad mismatch row={row} a={a} c={c}: analytic={analytic} got={got}"
                    );
                }
            }
        }
    }

    #[test]
    fn second_channel_is_symmetric() {
        let k = 6;
        let p = 3;
        let inv_tau = 1.3;
        let rows = fixture(2, k, p);
        let out = sae_row_jets_cpu_softmax(&rows, k, p, inv_tau);
        assert_eq!(out.second.len(), 2 * k * k * p);
        for row in 0..2 {
            for c in 0..p {
                for a in 0..k {
                    for b in 0..k {
                        let ab = out.second[((row * k + a) * k + b) * p + c];
                        let ba = out.second[((row * k + b) * k + a) * p + c];
                        assert!(
                            (ab - ba).abs() <= 1e-12,
                            "asymmetry row={row} c={c} {a},{b}"
                        );
                    }
                }
            }
        }
    }

    #[test]
    fn gauss_newton_slab_is_symmetric_psd_gram() {
        // H_tt = Σ_c first[a][c]·first[b][c] is a Gram matrix: symmetric and PSD.
        let k = 5;
        let p = 7;
        let inv_tau = 1.0 / 0.9;
        let rows = fixture(4, k, p);
        let ch = sae_row_jets_cpu_softmax(&rows, k, p, inv_tau);
        let slabs = gauss_newton_row_hessian_slabs(&ch);
        assert_eq!(slabs.len(), 4 * k * k);
        for row in 0..4 {
            let s = &slabs[row * k * k..(row + 1) * k * k];
            // symmetry + matches the explicit Σ_c contraction
            let f = &ch.first[row * k * p..(row + 1) * k * p];
            for a in 0..k {
                for b in 0..k {
                    let expect: f64 = (0..p).map(|c| f[a * p + c] * f[b * p + c]).sum();
                    assert!((s[a * k + b] - expect).abs() <= 1e-12);
                    assert!((s[a * k + b] - s[b * k + a]).abs() <= 1e-12);
                }
            }
            // PSD: vᵀHv = ‖Σ_a v_a J_a‖² ≥ 0 for a fixed, arbitrary v.
            let v: Vec<f64> = (0..k).map(|a| ((a * 13 + 1) as f64 * 0.3).sin()).collect();
            let mut quad = 0.0;
            for a in 0..k {
                for b in 0..k {
                    quad += v[a] * s[a * k + b] * v[b];
                }
            }
            assert!(quad >= -1e-12, "GN slab not PSD: vᵀHv={quad}");
        }
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn device_matches_cpu_when_available() {
        // Exactness gate: when a device is admitted, the device channels must
        // match the CPU unified jet to <=1e-9. When no device is available the
        // dispatcher returns the CPU result (still correct), so this asserts the
        // contract on whichever path ran.
        let k = 8;
        let p = 16;
        let inv_tau = 1.0 / 0.7;
        let rows = fixture(DEVICE_ROW_THRESHOLD + 64, k, p);
        let cpu = sae_row_jets_cpu_softmax(&rows, k, p, inv_tau);
        let got = sae_row_jets_softmax(&rows, k, p, inv_tau);
        // `zip` stops at the shorter side, so a truncated device download
        // would otherwise pass unnoticed: pin the shapes first.
        assert_eq!((got.n_rows, got.k, got.p), (rows.len(), k, p));
        assert_eq!(cpu.first.len(), got.first.len(), "first channel length");
        assert_eq!(cpu.second.len(), got.second.len(), "second channel length");
        let mut maxabs = 0.0_f64;
        for (x, y) in cpu.first.iter().zip(&got.first) {
            maxabs = maxabs.max((x - y).abs());
        }
        for (x, y) in cpu.second.iter().zip(&got.second) {
            maxabs = maxabs.max((x - y).abs());
        }
        assert!(
            maxabs <= 1e-9,
            "device vs CPU row-jet max abs diff {maxabs} > 1e-9"
        );
    }
}
655}