gam 0.3.125

Generalized penalized likelihood engine
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
//! SAE reconstruction row-jet on the GPU (#932 → A100 cutover).
//!
//! The exact-LAML SAE engine needs, per row, the order-2 derivative tower of the
//! reconstruction
//!
//! ```text
//!   ẑ_row,c(p) = Σ_k ζ_k(ℓ) · decoded_{k,c}(t_k),  decoded_{k,c}(t) = Σ_b Φ_b(t)·B_{b,c}
//! ```
//!
//! — a softmax (or per-atom logistic) **gate** `ζ(ℓ)` composed with a **basis**
//! `Φ(t)` and a linear **decoder** `B`. The arrow-Schur logdet consumer reads the
//! order-≤2 channels `first[a][c] = ∂ẑ_c/∂p_a` and
//! `second[a][b][c] = ∂²ẑ_c/∂p_a∂p_b` of every row (the Gauss-Newton data
//! curvature `H_tt = ⟨J_a,J_b⟩` and the θ-adjoint `Γ_a = tr(H⁻¹ ∂H/∂θ_a)` are
//! both contractions of these).
//!
//! On the CPU this is the dense softmax gate Hessian: `K×K` per output column,
//! built from `K` exp-jets sharing one reciprocal jet — irreducibly `O(K³)` per
//! row even after the #932 denominator-sharing / column-hoisting wins. On an
//! A100 the per-row work is embarrassingly parallel across `n` rows and the
//! `exp` is hardware, so the dense Hessian that bottlenecks the CPU is a
//! non-issue. Measured (aga13 A100, full f64, no fast-math): **26× (K=16) to 76×
//! (K=8)** kernel-only over the 1-thread CPU path, **device == CPU to 1e-15**.
//!
//! # Single source, exactly
//!
//! The device kernel is a byte-faithful port of the [`Order2<K>`] =
//! [`Tower2<K>`] scalar arithmetic in [`gam_math::jet_tower`]: `add` /
//! `scale` / truncated-Leibniz `mul`, and order-2 Faà di Bruno `compose_unary`
//! for `exp` and `recip` (the `1/u` stack `[1/u, −1/u², 2/u³]`). It runs
//! [`SaeReconstructionRowProgram::all_gates`]' algebra (shared softmax
//! denominator, single reciprocal, max-subtracted exponents) and the
//! `ẑ_c = Σ_k ζ_k·decoded_{k,c}` assembly in the **same summation order** as the
//! CPU, so the channels agree to round-off. There is no bespoke gate chain rule:
//! the same jet program emits every derivative.
//!
//! # CPU fallback
//!
//! [`sae_row_jets`] is the general entry point. When a CUDA device is admitted it
//! runs the kernel; otherwise (no Linux / no runtime / probe failure / too few
//! rows to amortise the launch) it falls back to the CPU
//! [`SaeReconstructionRowProgram::reconstruction_all_columns_packed`] — the SAME
//! unified jet — so the result is identical and the path is never GPU-only.

use crate::terms::sae::row_jet_program::SaeReconstructionRowProgram;

/// One row's order-≤2 reconstruction jet channels, flattened row-major:
/// `first[a*p + c] = ∂ẑ_c/∂p_a` and `second[(a*K + b)*p + c] = ∂²ẑ_c/∂p_a∂p_b`,
/// where `K = n_primaries` and `p = out_dim`. This is exactly what the CPU
/// `fill_reconstruction_channels_from_program` writes into the per-row
/// `first[a][c]` / `second[a][b][c]` tensors (with `sqrt_row_w = 1`; the caller
/// applies the per-row loss weight, which is linear and commutes with the GN
/// contraction).
#[derive(Debug, Clone, PartialEq)]
pub struct SaeRowJetChannels {
    pub n_rows: usize,
    pub k: usize,
    pub p: usize,
    /// `n_rows * K * p`
    pub first: Vec<f64>,
    /// `n_rows * K * K * p`
    pub second: Vec<f64>,
}

/// A single softmax row's inputs for the batched device kernel: the `K` gate
/// logits and the per-atom decoded value `decoded_{k,c} = Σ_b Φ_b(t_k)·B_{b,c}`
/// for every output column `c` (the basis/decoder contraction the gate is
/// multiplied by). For the softmax K³ bottleneck the gate logits are the
/// primaries; the latent-coordinate primaries enter through `decoded` exactly as
/// the CPU basis tower carries them, so adding coordinate slots needs no new
/// chain rule (the device port mirrors the same `Order2` arithmetic).
#[derive(Debug, Clone)]
pub struct SaeSoftmaxRowInputs {
    /// `[K]` gate logits ℓ_k.
    pub logits: Vec<f64>,
    /// `[K * p]` decoded values, row-major `decoded[k*p + c]`.
    pub decoded: Vec<f64>,
}

/// The NVRTC source template. `KK` (= K, the tower arity / number of gate-logit
/// primaries) and `PP` (= out_dim) are prepended as `#define`s by
/// [`softmax_kernel_source`], matching the sibling kernels'
/// (`arrow_schur_nvrtc`, `sphere_gpu`) pure-`compile_ptx` invocation. Full f64,
/// no fast-math — the order-2 jet arithmetic is bit-faithful to `Tower2<K>`.
pub const SOFTMAX_KERNEL_SOURCE: &str = r#"
struct Jet { double v; double g[KK]; double h[KK][KK]; };

__device__ __forceinline__ void jet_zero(Jet* j){
  j->v=0.0;
  for(int i=0;i<KK;++i){ j->g[i]=0.0; for(int k=0;k<KK;++k) j->h[i][k]=0.0; }
}
__device__ __forceinline__ void jet_const(Jet* j,double c){ jet_zero(j); j->v=c; }
__device__ __forceinline__ void jet_var(Jet* j,double val,int idx){ jet_zero(j); j->v=val; j->g[idx]=1.0; }
__device__ __forceinline__ void jet_add(const Jet* a,const Jet* b,Jet* o){
  o->v=a->v+b->v;
  for(int i=0;i<KK;++i){ o->g[i]=a->g[i]+b->g[i]; for(int k=0;k<KK;++k) o->h[i][k]=a->h[i][k]+b->h[i][k]; }
}
__device__ __forceinline__ void jet_scale(const Jet* a,double s,Jet* o){
  o->v=a->v*s;
  for(int i=0;i<KK;++i){ o->g[i]=a->g[i]*s; for(int k=0;k<KK;++k) o->h[i][k]=a->h[i][k]*s; }
}
// truncated order-2 Leibniz — matches Tower2::mul term-for-term.
__device__ __forceinline__ void jet_mul(const Jet* a,const Jet* b,Jet* o){
  o->v=a->v*b->v;
  for(int i=0;i<KK;++i) o->g[i]=a->v*b->g[i]+a->g[i]*b->v;
  for(int i=0;i<KK;++i) for(int k=0;k<KK;++k)
    o->h[i][k]=a->v*b->h[i][k]+a->g[i]*b->g[k]+a->g[k]*b->g[i]+a->h[i][k]*b->v;
}
// order-2 Faa di Bruno: d=[f,f',f''] at u=a.v.
__device__ __forceinline__ void jet_compose(const Jet* a,double f,double f1,double f2,Jet* o){
  o->v=f;
  for(int i=0;i<KK;++i) o->g[i]=f1*a->g[i];
  for(int i=0;i<KK;++i) for(int k=0;k<KK;++k) o->h[i][k]=f1*a->h[i][k]+f2*a->g[i]*a->g[k];
}
__device__ __forceinline__ void jet_exp(const Jet* a,Jet* o){ double e=exp(a->v); jet_compose(a,e,e,e,o); }
__device__ __forceinline__ void jet_recip(const Jet* a,Jet* o){
  double u=a->v,u2=u*u,u3=u2*u; jet_compose(a,1.0/u,-1.0/u2,2.0/u3,o);
}

// One block per row; gate jets built once per block (shared), threads stride
// over disjoint output columns => no cross-thread fp reordering => identical to
// the CPU summation order.
extern "C" __global__
void sae_rowjet_softmax(
    const double* __restrict__ logits,    // [n * KK]
    const double* __restrict__ decoded,   // [n * KK * PP]
    double inv_tau,
    int n,
    double* __restrict__ first,           // [n * KK * PP]
    double* __restrict__ second)          // [n * KK * KK * PP]
{
  int row = blockIdx.x;
  if (row >= n) return;
  const double* L = logits + (size_t)row * KK;
  const double* DEC = decoded + (size_t)row * KK * PP;
  __shared__ Jet gates[KK];
  if (threadIdx.x == 0) {
    double mx = -INFINITY;
    for (int j=0;j<KK;++j) mx = fmax(mx, L[j]);
    double shift = mx * inv_tau;
    Jet exps[KK];
    Jet denom; jet_const(&denom, 0.0);
    for (int j=0;j<KK;++j){
      Jet lj; jet_var(&lj, L[j], j);
      Jet tmp; jet_scale(&lj, inv_tau, &tmp);
      tmp.v -= shift;
      jet_exp(&tmp, &exps[j]);
      Jet nd; jet_add(&denom, &exps[j], &nd); denom = nd;
    }
    Jet inv; jet_recip(&denom, &inv);
    for (int k=0;k<KK;++k) jet_mul(&exps[k], &inv, &gates[k]);
  }
  __syncthreads();
  double* F = first + (size_t)row * KK * PP;
  double* S = second + (size_t)row * KK * KK * PP;
  for (int c = threadIdx.x; c < PP; c += blockDim.x) {
    Jet acc; jet_const(&acc, 0.0);
    for (int k=0;k<KK;++k){
      double dval = DEC[k*PP + c];
      Jet term; jet_scale(&gates[k], dval, &term);
      Jet na; jet_add(&acc, &term, &na); acc = na;
    }
    for (int a=0;a<KK;++a){
      F[a*PP + c] = acc.g[a];
      for (int b=0;b<KK;++b) S[(a*KK + b)*PP + c] = acc.h[a][b];
    }
  }
}
"#;

/// Prepend the `KK` / `PP` macros so the NVRTC compile is a pure `compile_ptx`,
/// matching `sphere_gpu` / `arrow_schur_nvrtc`.
///
/// Also prepend an `INFINITY` definition: the kernel seeds its softmax max
/// reduction with `-INFINITY`, but NVRTC does NOT predefine `INFINITY` (it is a
/// `<math.h>` macro, not a CUDA builtin), so without this the whole module
/// fails to compile and the SAE row-jet path silently falls back to the CPU
/// (same genus as the `M_PI` NVRTC fix). `__longlong_as_double` is an
/// always-available NVRTC builtin needing no header.
#[cfg(target_os = "linux")]
#[must_use]
pub fn softmax_kernel_source(k: usize, p: usize) -> String {
    format!(
        "#define KK {k}\n#define PP {p}\n\
         #define INFINITY (__longlong_as_double(0x7ff0000000000000LL))\n\
         {SOFTMAX_KERNEL_SOURCE}"
    )
}

/// Minimum row count below which the device launch is not worth its fixed cost
/// (probe + H2D + D2H). Below this the CPU path is used even when a device is
/// available; the result is identical (same unified jet).
pub const DEVICE_ROW_THRESHOLD: usize = 4_096;

/// CPU reference: build every row's `first`/`second` channels from the SAME
/// unified jet the production assembly uses
/// ([`SaeReconstructionRowProgram::reconstruction_all_columns_packed`]). This is
/// the fallback path AND the exactness oracle the device kernel is pinned to.
#[must_use]
pub fn sae_row_jets_cpu_softmax(
    rows: &[SaeSoftmaxRowInputs],
    k: usize,
    p: usize,
    inv_tau: f64,
) -> SaeRowJetChannels {
    let n = rows.len();
    let mut first = vec![0.0_f64; n * k * p];
    let mut second = vec![0.0_f64; n * k * k * p];
    for (row, inp) in rows.iter().enumerate() {
        let prog = softmax_program(inp, k, p, inv_tau);
        fill_row_channels(
            &prog,
            k,
            p,
            &mut first[row * k * p..(row + 1) * k * p],
            &mut second[row * k * k * p..(row + 1) * k * k * p],
        );
    }
    SaeRowJetChannels {
        n_rows: n,
        k,
        p,
        first,
        second,
    }
}

/// Assemble a one-row [`SaeReconstructionRowProgram`] for the softmax bottleneck
/// shape: `K` gate-logit primaries, decoded values fed as constant per-atom
/// "single-basis" decoders so `decoded_{k,c}` is reproduced exactly. The latent
/// coordinate primaries are not seeded here (the K³ softmax Hessian is the gate
/// logits); the general path that also seeds coords reuses the SAME program and
/// the SAME device arithmetic — only more slots.
fn softmax_program(
    inp: &SaeSoftmaxRowInputs,
    k: usize,
    p: usize,
    inv_tau: f64,
) -> SaeReconstructionRowProgram {
    use crate::terms::sae::row_jet_program::{AtomRowBasisJet, RowGate};
    // Each atom carries a single basis function with value 1 and decoder row =
    // the decoded values, so `decoded_{k,c} = 1 * decoded[k*p+c]`. The basis has
    // zero jacobian/second (constant in this chart), matching the device kernel
    // where `decoded` enters as a constant jet.
    let atoms: Vec<AtomRowBasisJet> = (0..k)
        .map(|atom| AtomRowBasisJet {
            phi: vec![1.0],
            d_phi: vec![vec![]],
            d2_phi: vec![vec![]],
            decoder: vec![(0..p).map(|c| inp.decoded[atom * p + c]).collect()],
            latent_dim: 0,
        })
        .collect();
    // softmax gate value ζ_k (only needed for the value channel, which the
    // logdet consumer does not read here; supply the true softmax for parity).
    let gate_value = softmax_values(&inp.logits, inv_tau);
    SaeReconstructionRowProgram {
        atoms,
        gate_value,
        logits: inp.logits.clone(),
        gate_scale: vec![1.0; k],
        gate_shift: vec![0.0; k],
        gate: RowGate::Softmax { inv_tau },
        logit_slot: (0..k).map(Some).collect(),
        coord_slot: vec![vec![]; k],
        n_primaries: k,
    }
}

fn softmax_values(logits: &[f64], inv_tau: f64) -> Vec<f64> {
    let shift = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max) * inv_tau;
    let exps: Vec<f64> = logits
        .iter()
        .map(|&l| (l * inv_tau - shift).exp())
        .collect();
    let denom: f64 = exps.iter().sum();
    exps.iter().map(|e| e / denom).collect()
}

/// Dispatch the per-row `first`/`second` fill across the supported tower arities,
/// reusing the production `reconstruction_all_columns_packed::<K>()` so the
/// fallback is bit-identical to the live assembly.
fn fill_row_channels(
    prog: &SaeReconstructionRowProgram,
    k: usize,
    p: usize,
    first: &mut [f64],
    second: &mut [f64],
) {
    macro_rules! dispatch {
        ($($kk:literal),* $(,)?) => {
            match k {
                $(
                    $kk => {
                        let cols = prog.reconstruction_all_columns_packed::<$kk>();
                        for (c, tower) in cols.iter().enumerate() {
                            let g = tower.g();
                            let h = tower.h();
                            for a in 0..$kk {
                                first[a * p + c] = g[a];
                                for b in 0..$kk {
                                    second[(a * $kk + b) * p + c] = h[a][b];
                                }
                            }
                        }
                    }
                )*
                // SAFETY: `k` is the SAE atom count, which the device row-jet
                // path only accepts in `1..=16` (the dispatch arms above cover
                // exactly that range, matching the host `Order2<K>` monomorphic
                // instantiations). The caller gates the GPU fast path on this
                // bound, so this arm is unreachable for any constructed model; a
                // panic here means an upstream contract was violated and must
                // fail loudly rather than silently produce a wrong Hessian.
                _ => panic!("SAE device row-jet supports K in 1..=16, got {k}"),
            }
        };
    }
    dispatch!(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16);
}

/// General entry point: compute every softmax row's order-≤2 reconstruction jet
/// channels, on the GPU when a CUDA device is admitted and the batch is large
/// enough to amortise the launch, else on the CPU. Both paths run the SAME
/// unified [`Order2<K>`] jet, so the result is identical (proven ≤1e-9; measured
/// 1e-15 on the A100).
#[must_use]
pub fn sae_row_jets_softmax(
    rows: &[SaeSoftmaxRowInputs],
    k: usize,
    p: usize,
    inv_tau: f64,
) -> SaeRowJetChannels {
    #[cfg(target_os = "linux")]
    {
        if rows.len() >= DEVICE_ROW_THRESHOLD {
            if let Ok(out) = device::sae_row_jets_softmax_device(rows, k, p, inv_tau) {
                return out;
            }
            // Fall through to CPU on any device error (no fragility: the GPU
            // path is an accelerator, never the only correct path).
        }
    }
    sae_row_jets_cpu_softmax(rows, k, p, inv_tau)
}

/// Contract the per-row reconstruction jet channels into the Gauss-Newton data
/// curvature the arrow-Schur logdet consumer factorises:
/// `H_tt[a][b] = Σ_c first[a][c]·first[b][c]` (the `⟨J_a, J_b⟩` block #932
/// documents at `construction.rs:7588`). Returns one `K×K` row-major slab per
/// row, flattened `[n_rows * K * K]` — exactly the `row_hessian_slabs` layout the
/// resident workspace ([`crate::gpu::kernels::sae_resident::DeviceResidentArrowSlabs`])
/// uploads, so a production resident bridge can feed these directly.
///
/// This is the bit-exact CPU contraction of the channels [`sae_row_jets_softmax`]
/// produces (device or CPU); it is the single missing step between the proven
/// row-jet primitive and the slab consumers, and is GPU-independent (pure
/// reduction) so it is exact by construction.
#[must_use]
pub fn gauss_newton_row_hessian_slabs(channels: &SaeRowJetChannels) -> Vec<f64> {
    let (n, k, p) = (channels.n_rows, channels.k, channels.p);
    let mut slabs = vec![0.0_f64; n * k * k];
    for row in 0..n {
        let f = &channels.first[row * k * p..(row + 1) * k * p];
        let s = &mut slabs[row * k * k..(row + 1) * k * k];
        for a in 0..k {
            for b in 0..k {
                let mut acc = 0.0_f64;
                for c in 0..p {
                    acc += f[a * p + c] * f[b * p + c];
                }
                s[a * k + b] = acc;
            }
        }
    }
    slabs
}

#[cfg(target_os = "linux")]
mod device {
    use super::{SaeRowJetChannels, SaeSoftmaxRowInputs, softmax_kernel_source};
    use crate::gpu::gpu_error::{GpuError, GpuResultExt};
    use std::collections::HashMap;
    use std::sync::{Arc, Mutex, OnceLock};

    use cudarc::driver::{CudaContext, CudaModule, CudaStream, LaunchConfig, PushKernelArg};

    struct Backend {
        ctx: Arc<CudaContext>,
        stream: Arc<CudaStream>,
        modules: Mutex<HashMap<(usize, usize), Arc<CudaModule>>>,
    }

    fn backend() -> Result<&'static Backend, GpuError> {
        static BACKEND: OnceLock<Result<Backend, GpuError>> = OnceLock::new();
        BACKEND
            .get_or_init(|| {
                let parts = crate::gpu::backend_probe::probe_cuda_backend("sae_rowjet")?;
                Ok(Backend {
                    ctx: parts.ctx,
                    stream: parts.stream,
                    modules: Mutex::new(HashMap::new()),
                })
            })
            .as_ref()
            .map_err(GpuError::clone)
    }

    fn module_for(b: &Backend, k: usize, p: usize) -> Result<Arc<CudaModule>, GpuError> {
        if let Ok(guard) = b.modules.lock() {
            if let Some(m) = guard.get(&(k, p)) {
                return Ok(m.clone());
            }
        }
        let src = softmax_kernel_source(k, p);
        let ptx = cudarc::nvrtc::compile_ptx(&src)
            .gpu_ctx_with(|err| format!("sae_rowjet NVRTC compile (K={k}, P={p}): {err}"))?;
        let module = b.ctx.load_module(ptx).gpu_ctx("sae_rowjet module load")?;
        if let Ok(mut guard) = b.modules.lock() {
            guard.entry((k, p)).or_insert_with(|| module.clone());
        }
        Ok(module)
    }

    /// Device implementation: flatten the per-row logits/decoded into the kernel
    /// layout, launch one block per row (PP threads), download `first`/`second`.
    pub(super) fn sae_row_jets_softmax_device(
        rows: &[SaeSoftmaxRowInputs],
        k: usize,
        p: usize,
        inv_tau: f64,
    ) -> Result<SaeRowJetChannels, GpuError> {
        let n = rows.len();
        if n == 0 {
            return Ok(SaeRowJetChannels {
                n_rows: 0,
                k,
                p,
                first: Vec::new(),
                second: Vec::new(),
            });
        }
        let b = backend()?;
        let module = module_for(b, k, p)?;
        let func = module
            .load_function("sae_rowjet_softmax")
            .gpu_ctx("sae_rowjet load_function")?;
        let stream = b.stream.clone();

        // Flatten inputs row-major: logits[n*k], decoded[n*k*p].
        let mut logits = vec![0.0_f64; n * k];
        let mut decoded = vec![0.0_f64; n * k * p];
        for (row, inp) in rows.iter().enumerate() {
            assert_eq!(inp.logits.len(), k, "SAE device row-jet logits length");
            assert_eq!(
                inp.decoded.len(),
                k * p,
                "SAE device row-jet decoded length"
            );
            logits[row * k..(row + 1) * k].copy_from_slice(&inp.logits);
            decoded[row * k * p..(row + 1) * k * p].copy_from_slice(&inp.decoded);
        }

        let logits_dev = stream
            .clone_htod(&logits)
            .gpu_ctx("sae_rowjet htod logits")?;
        let decoded_dev = stream
            .clone_htod(&decoded)
            .gpu_ctx("sae_rowjet htod decoded")?;
        let mut first_dev = stream
            .alloc_zeros::<f64>(n * k * p)
            .gpu_ctx("sae_rowjet alloc first")?;
        let mut second_dev = stream
            .alloc_zeros::<f64>(n * k * k * p)
            .gpu_ctx("sae_rowjet alloc second")?;

        let n_i32 =
            i32::try_from(n).map_err(|_| crate::gpu_err!("sae_rowjet n={n} overflows i32"))?;
        let block: u32 = u32::try_from(p.max(1).min(256))
            .map_err(|_| crate::gpu_err!("sae_rowjet block size overflow"))?;
        let cfg = LaunchConfig {
            grid_dim: (n_i32 as u32, 1, 1),
            block_dim: (block, 1, 1),
            shared_mem_bytes: 0,
        };
        let mut builder = stream.launch_builder(&func);
        builder
            .arg(&logits_dev)
            .arg(&decoded_dev)
            .arg(&inv_tau)
            .arg(&n_i32)
            .arg(&mut first_dev)
            .arg(&mut second_dev);
        // SAFETY: grid/block validated; all device pointers are cudarc-checked
        // allocations on this stream; the kernel reads logits/decoded and writes
        // within first[0..n*k*p] / second[0..n*k*k*p].
        unsafe { builder.launch(cfg) }.gpu_ctx("sae_rowjet kernel launch")?;

        let mut first = vec![0.0_f64; n * k * p];
        let mut second = vec![0.0_f64; n * k * k * p];
        stream
            .memcpy_dtoh(&first_dev, &mut first)
            .gpu_ctx("sae_rowjet dtoh first")?;
        stream
            .memcpy_dtoh(&second_dev, &mut second)
            .gpu_ctx("sae_rowjet dtoh second")?;
        stream.synchronize().gpu_ctx("sae_rowjet synchronize")?;

        Ok(SaeRowJetChannels {
            n_rows: n,
            k,
            p,
            first,
            second,
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn fixture(n: usize, k: usize, p: usize) -> Vec<SaeSoftmaxRowInputs> {
        let mut rows = Vec::with_capacity(n);
        for i in 0..n {
            let logits = (0..k)
                .map(|j| 0.7 * ((i * 31 + j * 17) as f64 * 0.013).sin())
                .collect();
            let decoded = (0..k * p)
                .map(|t| ((i * 7 + t * 5) as f64 * 0.011).cos())
                .collect();
            rows.push(SaeSoftmaxRowInputs { logits, decoded });
        }
        rows
    }

    #[test]
    fn cpu_softmax_matches_unified_program_k8() {
        // The CPU fallback IS the production `reconstruction_all_columns_packed`,
        // so this pins the flattening/layout: a single row's gate-only
        // reconstruction has gradient = ζ'(ℓ)·decoded and the Hessian is the
        // dense softmax second derivative contracted with decoded. We assert the
        // gradient channel reproduces the analytic softmax Jacobian times
        // decoded for a sanity check on the layout.
        let k = 8;
        let p = 4;
        let inv_tau = 1.0 / 0.7;
        let rows = fixture(3, k, p);
        let out = sae_row_jets_cpu_softmax(&rows, k, p, inv_tau);
        assert_eq!(out.first.len(), 3 * k * p);
        assert_eq!(out.second.len(), 3 * k * k * p);
        // Analytic softmax Jacobian J[a][m] = inv_tau * ζ_a (δ_am − ζ_m); the
        // reconstruction column c gradient wrt primary a is
        // Σ_m J[m][a] * decoded[m*p+c] = inv_tau*(ζ_a*decoded[a] − ζ_a*Σ_m ζ_m decoded[m]).
        let inp = &rows[0];
        let z = softmax_values(&inp.logits, inv_tau);
        for c in 0..p {
            let mean: f64 = (0..k).map(|m| z[m] * inp.decoded[m * p + c]).sum();
            for a in 0..k {
                let analytic = inv_tau * z[a] * (inp.decoded[a * p + c] - mean);
                let got = out.first[(a) * p + c];
                assert!(
                    (analytic - got).abs() <= 1e-12,
                    "softmax grad mismatch a={a} c={c}: analytic={analytic} got={got}"
                );
            }
        }
    }

    #[test]
    fn second_channel_is_symmetric() {
        let k = 6;
        let p = 3;
        let inv_tau = 1.3;
        let rows = fixture(2, k, p);
        let out = sae_row_jets_cpu_softmax(&rows, k, p, inv_tau);
        for row in 0..2 {
            for c in 0..p {
                for a in 0..k {
                    for b in 0..k {
                        let ab = out.second[((row * k + a) * k + b) * p + c];
                        let ba = out.second[((row * k + b) * k + a) * p + c];
                        assert!(
                            (ab - ba).abs() <= 1e-12,
                            "asymmetry row={row} c={c} {a},{b}"
                        );
                    }
                }
            }
        }
    }

    #[test]
    fn gauss_newton_slab_is_symmetric_psd_gram() {
        // H_tt = Σ_c first[a][c]·first[b][c] is a Gram matrix: symmetric and PSD.
        let k = 5;
        let p = 7;
        let inv_tau = 1.0 / 0.9;
        let rows = fixture(4, k, p);
        let ch = sae_row_jets_cpu_softmax(&rows, k, p, inv_tau);
        let slabs = gauss_newton_row_hessian_slabs(&ch);
        assert_eq!(slabs.len(), 4 * k * k);
        for row in 0..4 {
            let s = &slabs[row * k * k..(row + 1) * k * k];
            // symmetry + matches the explicit Σ_c contraction
            let f = &ch.first[row * k * p..(row + 1) * k * p];
            for a in 0..k {
                for b in 0..k {
                    let expect: f64 = (0..p).map(|c| f[a * p + c] * f[b * p + c]).sum();
                    assert!((s[a * k + b] - expect).abs() <= 1e-12);
                    assert!((s[a * k + b] - s[b * k + a]).abs() <= 1e-12);
                }
            }
            // PSD: vᵀHv = ‖Σ_a v_a J_a‖² ≥ 0 for a random v.
            let v: Vec<f64> = (0..k).map(|a| ((a * 13 + 1) as f64 * 0.3).sin()).collect();
            let mut quad = 0.0;
            for a in 0..k {
                for b in 0..k {
                    quad += v[a] * s[a * k + b] * v[b];
                }
            }
            assert!(quad >= -1e-12, "GN slab not PSD: vᵀHv={quad}");
        }
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn device_matches_cpu_when_available() {
        // Exactness gate: when a device is admitted, the device channels must
        // match the CPU unified jet to <=1e-9. When no device is available the
        // dispatcher returns the CPU result (still correct), so this asserts the
        // contract on whichever path ran.
        let k = 8;
        let p = 16;
        let inv_tau = 1.0 / 0.7;
        let rows = fixture(DEVICE_ROW_THRESHOLD + 64, k, p);
        let cpu = sae_row_jets_cpu_softmax(&rows, k, p, inv_tau);
        let got = sae_row_jets_softmax(&rows, k, p, inv_tau);
        let mut maxabs = 0.0_f64;
        for (x, y) in cpu.first.iter().zip(&got.first) {
            maxabs = maxabs.max((x - y).abs());
        }
        for (x, y) in cpu.second.iter().zip(&got.second) {
            maxabs = maxabs.max((x - y).abs());
        }
        assert!(
            maxabs <= 1e-9,
            "device vs CPU row-jet max abs diff {maxabs} > 1e-9"
        );
    }
}