1use crate::row_jet_program::SaeReconstructionRowProgram;
46
/// Flattened first- and second-derivative channels for a batch of SAE reconstruction rows.
///
/// All buffers are row-major and zero-based. `a` and `b` index the `k` primaries, and `c`
/// indexes the `p` reconstruction output columns.
/// - `first[(row * k + a) * p + c]` holds the gradient entry `a` of column `c`.
///   Its length is `n_rows * k * p`.
/// - `second[((row * k + a) * k + b) * p + c]` holds the Hessian entry `(a, b)` of column `c`.
///   Its length is `n_rows * k * k * p`.
#[derive(Debug, Clone, PartialEq)]
pub struct SaeRowJetChannels {
    /// Number of rows in the batch.
    pub n_rows: usize,
    /// Number of primaries (one per atom/logit) per row.
    pub k: usize,
    /// Number of reconstruction output columns per row.
    pub p: usize,
    /// First-derivative channel; indexing is described on the type.
    pub first: Vec<f64>,
    /// Second-derivative channel; indexing is described on the type.
    pub second: Vec<f64>,
}
64
/// Per-row inputs for the softmax-gated reconstruction jets.
#[derive(Debug, Clone)]
pub struct SaeSoftmaxRowInputs {
    /// Gate logits for this row, one per atom (expected length `k`).
    pub logits: Vec<f64>,
    /// Decoded atom outputs, atom-major. Entry `atom * p + c` is atom `atom`'s value in
    /// output column `c` (expected length `k * p`).
    pub decoded: Vec<f64>,
}
79
/// CUDA C source for the `sae_rowjet_softmax` kernel, compiled at runtime with NVRTC.
///
/// The text is not self-contained. It expects the macros `KK` (atoms/logits per row),
/// `PP` (output columns per row) and `INFINITY` to be defined ahead of it;
/// `softmax_kernel_source` prepends exactly those definitions.
///
/// The launch uses one thread block per row (`row = blockIdx.x`).
/// - Thread 0 builds order-2 jets (value, gradient and Hessian with respect to the row's
///   `KK` logits) of the gates `z = softmax(logits * inv_tau)` into `__shared__` memory.
/// - The exponent is shifted by `max(logits) * inv_tau`, the same shift `softmax_values`
///   uses on the CPU.
/// - After `__syncthreads()`, the block's threads stride over the output columns `c`.
/// - Each thread writes `first[(row*KK + a)*PP + c] = Σ_k ∂z_k/∂l_a · decoded[k][c]` and
///   `second[((row*KK + a)*KK + b)*PP + c] = Σ_k ∂²z_k/∂l_a∂l_b · decoded[k][c]`.
///
/// This matches the layout of [`SaeRowJetChannels`].
pub const SOFTMAX_KERNEL_SOURCE: &str = r#"
struct Jet { double v; double g[KK]; double h[KK][KK]; };

__device__ __forceinline__ void jet_zero(Jet* j){
    j->v=0.0;
    for(int i=0;i<KK;++i){ j->g[i]=0.0; for(int k=0;k<KK;++k) j->h[i][k]=0.0; }
}
__device__ __forceinline__ void jet_const(Jet* j,double c){ jet_zero(j); j->v=c; }
__device__ __forceinline__ void jet_var(Jet* j,double val,int idx){ jet_zero(j); j->v=val; j->g[idx]=1.0; }
__device__ __forceinline__ void jet_add(const Jet* a,const Jet* b,Jet* o){
    o->v=a->v+b->v;
    for(int i=0;i<KK;++i){ o->g[i]=a->g[i]+b->g[i]; for(int k=0;k<KK;++k) o->h[i][k]=a->h[i][k]+b->h[i][k]; }
}
__device__ __forceinline__ void jet_scale(const Jet* a,double s,Jet* o){
    o->v=a->v*s;
    for(int i=0;i<KK;++i){ o->g[i]=a->g[i]*s; for(int k=0;k<KK;++k) o->h[i][k]=a->h[i][k]*s; }
}
// truncated order-2 Leibniz — matches Tower2::mul term-for-term.
__device__ __forceinline__ void jet_mul(const Jet* a,const Jet* b,Jet* o){
    o->v=a->v*b->v;
    for(int i=0;i<KK;++i) o->g[i]=a->v*b->g[i]+a->g[i]*b->v;
    for(int i=0;i<KK;++i) for(int k=0;k<KK;++k)
        o->h[i][k]=a->v*b->h[i][k]+a->g[i]*b->g[k]+a->g[k]*b->g[i]+a->h[i][k]*b->v;
}
// order-2 Faa di Bruno: d=[f,f',f''] at u=a.v.
__device__ __forceinline__ void jet_compose(const Jet* a,double f,double f1,double f2,Jet* o){
    o->v=f;
    for(int i=0;i<KK;++i) o->g[i]=f1*a->g[i];
    for(int i=0;i<KK;++i) for(int k=0;k<KK;++k) o->h[i][k]=f1*a->h[i][k]+f2*a->g[i]*a->g[k];
}
__device__ __forceinline__ void jet_exp(const Jet* a,Jet* o){ double e=exp(a->v); jet_compose(a,e,e,e,o); }
__device__ __forceinline__ void jet_recip(const Jet* a,Jet* o){
    double u=a->v,u2=u*u,u3=u2*u; jet_compose(a,1.0/u,-1.0/u2,2.0/u3,o);
}

// One block per row; gate jets built once per block (shared), threads stride
// over disjoint output columns => no cross-thread fp reordering => identical to
// the CPU summation order.
extern "C" __global__
void sae_rowjet_softmax(
    const double* __restrict__ logits,   // [n * KK]
    const double* __restrict__ decoded,  // [n * KK * PP]
    double inv_tau,
    int n,
    double* __restrict__ first,          // [n * KK * PP]
    double* __restrict__ second)         // [n * KK * KK * PP]
{
    int row = blockIdx.x;
    if (row >= n) return;
    const double* L = logits + (size_t)row * KK;
    const double* DEC = decoded + (size_t)row * KK * PP;
    __shared__ Jet gates[KK];
    if (threadIdx.x == 0) {
        double mx = -INFINITY;
        for (int j=0;j<KK;++j) mx = fmax(mx, L[j]);
        double shift = mx * inv_tau;
        Jet exps[KK];
        Jet denom; jet_const(&denom, 0.0);
        for (int j=0;j<KK;++j){
            Jet lj; jet_var(&lj, L[j], j);
            Jet tmp; jet_scale(&lj, inv_tau, &tmp);
            tmp.v -= shift;
            jet_exp(&tmp, &exps[j]);
            Jet nd; jet_add(&denom, &exps[j], &nd); denom = nd;
        }
        Jet inv; jet_recip(&denom, &inv);
        for (int k=0;k<KK;++k) jet_mul(&exps[k], &inv, &gates[k]);
    }
    __syncthreads();
    double* F = first + (size_t)row * KK * PP;
    double* S = second + (size_t)row * KK * KK * PP;
    for (int c = threadIdx.x; c < PP; c += blockDim.x) {
        for (int a=0;a<KK;++a){
            double fg = 0.0;
            double sh[KK];
            for (int b=0;b<KK;++b) sh[b] = 0.0;
            for (int k=0;k<KK;++k) {
                double dval = DEC[k*PP + c];
                fg += gates[k].g[a] * dval;
                for (int b=0;b<KK;++b) sh[b] += gates[k].h[a][b] * dval;
            }
            F[a*PP + c] = fg;
            for (int b=0;b<KK;++b) {
                S[(a*KK + b)*PP + c] = sh[b];
            }
        }
    }
}
"#;
174
175#[cfg(target_os = "linux")]
185#[must_use]
186pub fn softmax_kernel_source(k: usize, p: usize) -> String {
187 format!(
188 "#define KK {k}\n#define PP {p}\n\
189 #define INFINITY (__longlong_as_double(0x7ff0000000000000LL))\n\
190 {SOFTMAX_KERNEL_SOURCE}"
191 )
192}
193
/// Minimum batch size, in rows, at which `sae_row_jets_softmax` tries the CUDA path.
///
/// Smaller batches always use the CPU implementation. The value is 4096 rows.
/// NOTE(review): presumably chosen so that NVRTC compilation and host/device transfer
/// overhead is amortised; confirm against benchmarks before changing it.
pub const DEVICE_ROW_THRESHOLD: usize = 1 << 12;
198
199#[must_use]
204pub fn sae_row_jets_cpu_softmax(
205 rows: &[SaeSoftmaxRowInputs],
206 k: usize,
207 p: usize,
208 inv_tau: f64,
209) -> SaeRowJetChannels {
210 let n = rows.len();
211 let mut first = vec![0.0_f64; n * k * p];
212 let mut second = vec![0.0_f64; n * k * k * p];
213 for (row, inp) in rows.iter().enumerate() {
214 let prog = softmax_program(inp, k, p, inv_tau);
215 fill_row_channels(
216 &prog,
217 k,
218 p,
219 &mut first[row * k * p..(row + 1) * k * p],
220 &mut second[row * k * k * p..(row + 1) * k * k * p],
221 );
222 }
223 SaeRowJetChannels {
224 n_rows: n,
225 k,
226 p,
227 first,
228 second,
229 }
230}
231
232fn softmax_program(
239 inp: &SaeSoftmaxRowInputs,
240 k: usize,
241 p: usize,
242 inv_tau: f64,
243) -> SaeReconstructionRowProgram {
244 use crate::row_jet_program::{AtomRowBasisJet, RowGate};
245 let atoms: Vec<AtomRowBasisJet> = (0..k)
250 .map(|atom| AtomRowBasisJet {
251 phi: vec![1.0],
252 d_phi: vec![vec![]],
253 d2_phi: vec![vec![]],
254 decoder: vec![(0..p).map(|c| inp.decoded[atom * p + c]).collect()],
255 latent_dim: 0,
256 })
257 .collect();
258 let gate_value = softmax_values(&inp.logits, inv_tau);
261 SaeReconstructionRowProgram {
262 atoms,
263 gate_value,
264 logits: inp.logits.clone(),
265 gate_scale: vec![1.0; k],
266 gate_shift: vec![0.0; k],
267 gate: RowGate::Softmax { inv_tau },
268 logit_slot: (0..k).map(Some).collect(),
269 coord_slot: vec![vec![]; k],
270 n_primaries: k,
271 }
272}
273
/// Returns `softmax(logits * inv_tau)`.
///
/// Before exponentiating, the exponent is shifted by `max(logits) * inv_tau`. The shift
/// cancels in the normalisation but keeps `exp` from overflowing for large logits. An
/// empty input yields an empty vector.
fn softmax_values(logits: &[f64], inv_tau: f64) -> Vec<f64> {
    let largest = logits
        .iter()
        .fold(f64::NEG_INFINITY, |best, &value| best.max(value));
    let shift = largest * inv_tau;

    let mut weights = Vec::with_capacity(logits.len());
    for &logit in logits {
        weights.push((logit * inv_tau - shift).exp());
    }

    let total: f64 = weights.iter().sum();
    for weight in &mut weights {
        *weight /= total;
    }
    weights
}
283
284fn fill_row_channels(
288 prog: &SaeReconstructionRowProgram,
289 k: usize,
290 p: usize,
291 first: &mut [f64],
292 second: &mut [f64],
293) {
294 macro_rules! dispatch {
295 ($($kk:literal),* $(,)?) => {
296 match k {
297 $(
298 $kk => {
299 let cols = prog.reconstruction_all_columns_packed::<$kk>();
300 for (c, tower) in cols.iter().enumerate() {
301 let g = tower.g();
302 let h = tower.h();
303 for a in 0..$kk {
304 first[a * p + c] = g[a];
305 for b in 0..$kk {
306 second[(a * $kk + b) * p + c] = h[a][b];
307 }
308 }
309 }
310 }
311 )*
312 _ => panic!("SAE device row-jet supports K in 1..=16, got {k}"),
320 }
321 };
322 }
323 dispatch!(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16);
324}
325
326#[must_use]
332pub fn sae_row_jets_softmax(
333 rows: &[SaeSoftmaxRowInputs],
334 k: usize,
335 p: usize,
336 inv_tau: f64,
337) -> SaeRowJetChannels {
338 #[cfg(target_os = "linux")]
339 {
340 if rows.len() >= DEVICE_ROW_THRESHOLD {
341 if let Ok(out) = device::sae_row_jets_softmax_device(rows, k, p, inv_tau) {
342 return out;
343 }
344 }
347 }
348 sae_row_jets_cpu_softmax(rows, k, p, inv_tau)
349}
350
351#[must_use]
364pub fn gauss_newton_row_hessian_slabs(channels: &SaeRowJetChannels) -> Vec<f64> {
365 let (n, k, p) = (channels.n_rows, channels.k, channels.p);
366 let mut slabs = vec![0.0_f64; n * k * k];
367 for row in 0..n {
368 let f = &channels.first[row * k * p..(row + 1) * k * p];
369 let s = &mut slabs[row * k * k..(row + 1) * k * k];
370 for a in 0..k {
371 for b in 0..k {
372 let mut acc = 0.0_f64;
373 for c in 0..p {
374 acc += f[a * p + c] * f[b * p + c];
375 }
376 s[a * k + b] = acc;
377 }
378 }
379 }
380 slabs
381}
382
/// CUDA implementation of the softmax row jets (Linux only).
///
/// Holds a lazily probed, process-wide CUDA context and stream, plus a per-`(k, p)` cache of
/// NVRTC-compiled modules. The kernel text is `softmax_kernel_source(k, p)`.
#[cfg(target_os = "linux")]
mod device {
    use super::{SaeRowJetChannels, SaeSoftmaxRowInputs, softmax_kernel_source};
    use gam_gpu::gpu_error::{GpuError, GpuResultExt};
    use std::collections::HashMap;
    use std::sync::{Arc, Mutex, OnceLock};

    use cudarc::driver::{CudaContext, CudaModule, CudaStream, LaunchConfig, PushKernelArg};

    /// CUDA state shared by every call in the process.
    struct Backend {
        ctx: Arc<CudaContext>,
        stream: Arc<CudaStream>,
        /// Compiled modules keyed by `(k, p)`: both sizes are baked into the source as
        /// `KK`/`PP`, so each shape needs its own compilation.
        modules: Mutex<HashMap<(usize, usize), Arc<CudaModule>>>,
    }

    /// Returns the shared backend, probing CUDA on first use.
    ///
    /// The probe result, including a failure, is stored in a `OnceLock`. A machine without a
    /// usable CUDA backend therefore probes once, and later calls return a clone of the same
    /// error.
    fn backend() -> Result<&'static Backend, GpuError> {
        static BACKEND: OnceLock<Result<Backend, GpuError>> = OnceLock::new();
        BACKEND
            .get_or_init(|| {
                let parts = gam_gpu::backend_probe::probe_cuda_backend("sae_rowjet")?;
                Ok(Backend {
                    ctx: parts.ctx,
                    stream: parts.stream,
                    modules: Mutex::new(HashMap::new()),
                })
            })
            .as_ref()
            .map_err(GpuError::clone)
    }

    /// Returns the compiled module for `(k, p)`, compiling it with NVRTC on a cache miss.
    ///
    /// The cache lock is not held while compiling. Concurrent misses may therefore compile
    /// the same source more than once. The first module inserted is the one cached, and
    /// each caller still gets back the module it compiled. A poisoned lock is treated as a
    /// cache miss, and the insert is skipped, rather than reported as an error.
    fn module_for(b: &Backend, k: usize, p: usize) -> Result<Arc<CudaModule>, GpuError> {
        if let Ok(guard) = b.modules.lock() {
            if let Some(m) = guard.get(&(k, p)) {
                return Ok(m.clone());
            }
        }
        let src = softmax_kernel_source(k, p);
        let ptx = cudarc::nvrtc::compile_ptx(&src)
            .gpu_ctx_with(|err| format!("sae_rowjet NVRTC compile (K={k}, P={p}): {err}"))?;
        let module = b.ctx.load_module(ptx).gpu_ctx("sae_rowjet module load")?;
        if let Ok(mut guard) = b.modules.lock() {
            guard.entry((k, p)).or_insert_with(|| module.clone());
        }
        Ok(module)
    }

    /// Device counterpart of `sae_row_jets_cpu_softmax`, returning the same channel layout.
    ///
    /// # Errors
    ///
    /// Returns a `GpuError` if any of the following fails:
    /// - the CUDA backend is unavailable;
    /// - NVRTC compilation, module loading or function lookup;
    /// - any upload, allocation, launch, download or synchronisation step.
    ///
    /// # Panics
    ///
    /// Panics if a row's `logits` length is not `k`, or its `decoded` length is not `k * p`.
    /// This check runs after the backend and module have been obtained.
    pub(super) fn sae_row_jets_softmax_device(
        rows: &[SaeSoftmaxRowInputs],
        k: usize,
        p: usize,
        inv_tau: f64,
    ) -> Result<SaeRowJetChannels, GpuError> {
        let n = rows.len();
        if n == 0 {
            return Ok(SaeRowJetChannels {
                n_rows: 0,
                k,
                p,
                first: Vec::new(),
                second: Vec::new(),
            });
        }
        let b = backend()?;
        let module = module_for(b, k, p)?;
        let func = module
            .load_function("sae_rowjet_softmax")
            .gpu_ctx("sae_rowjet load_function")?;
        let stream = b.stream.clone();

        // Stage every row into two contiguous host buffers so each input is uploaded with a
        // single host-to-device copy.
        let mut logits = vec![0.0_f64; n * k];
        let mut decoded = vec![0.0_f64; n * k * p];
        for (row, inp) in rows.iter().enumerate() {
            assert_eq!(inp.logits.len(), k, "SAE device row-jet logits length");
            assert_eq!(
                inp.decoded.len(),
                k * p,
                "SAE device row-jet decoded length"
            );
            logits[row * k..(row + 1) * k].copy_from_slice(&inp.logits);
            decoded[row * k * p..(row + 1) * k * p].copy_from_slice(&inp.decoded);
        }

        let logits_dev = stream
            .clone_htod(&logits)
            .gpu_ctx("sae_rowjet htod logits")?;
        let decoded_dev = stream
            .clone_htod(&decoded)
            .gpu_ctx("sae_rowjet htod decoded")?;
        let mut first_dev = stream
            .alloc_zeros::<f64>(n * k * p)
            .gpu_ctx("sae_rowjet alloc first")?;
        let mut second_dev = stream
            .alloc_zeros::<f64>(n * k * k * p)
            .gpu_ctx("sae_rowjet alloc second")?;

        let n_i32 =
            i32::try_from(n).map_err(|_| gam_gpu::gpu_err!("sae_rowjet n={n} overflows i32"))?;
        // One block per row. Each block has between 1 and 256 threads, which stride over
        // the `p` output columns; there is at least one thread even when `p == 0`.
        let block: u32 = u32::try_from(p.max(1).min(256))
            .map_err(|_| gam_gpu::gpu_err!("sae_rowjet block size overflow"))?;
        let cfg = LaunchConfig {
            grid_dim: (n_i32 as u32, 1, 1),
            block_dim: (block, 1, 1),
            // The kernel's `gates` array is statically sized `__shared__`, so no dynamic
            // shared memory is requested.
            shared_mem_bytes: 0,
        };
        let mut builder = stream.launch_builder(&func);
        builder
            .arg(&logits_dev)
            .arg(&decoded_dev)
            .arg(&inv_tau)
            .arg(&n_i32)
            .arg(&mut first_dev)
            .arg(&mut second_dev);
        // SAFETY: the argument order and types match the kernel signature
        // `(const double*, const double*, double, int, double*, double*)`.
        // The module was compiled with KK = k and PP = p, and the grid has exactly `n` blocks.
        // The kernel therefore writes only indices below `n*k*p` in `first` and below
        // `n*k*k*p` in `second`, which are the allocated lengths.
        unsafe { builder.launch(cfg) }.gpu_ctx("sae_rowjet kernel launch")?;

        let mut first = vec![0.0_f64; n * k * p];
        let mut second = vec![0.0_f64; n * k * k * p];
        stream
            .memcpy_dtoh(&first_dev, &mut first)
            .gpu_ctx("sae_rowjet dtoh first")?;
        stream
            .memcpy_dtoh(&second_dev, &mut second)
            .gpu_ctx("sae_rowjet dtoh second")?;
        // Explicitly wait on the stream before returning, so that any pending stream error
        // is reported through this call's `Result`.
        stream.synchronize().gpu_ctx("sae_rowjet synchronize")?;

        Ok(SaeRowJetChannels {
            n_rows: n,
            k,
            p,
            first,
            second,
        })
    }
}
522
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic, smoothly varying inputs. There are `n` rows; each has `k` logits in
    /// `[-0.7, 0.7]` and a `k * p` decoded block with values in `[-1, 1]`.
    fn fixture(n: usize, k: usize, p: usize) -> Vec<SaeSoftmaxRowInputs> {
        let mut rows = Vec::with_capacity(n);
        for i in 0..n {
            let logits = (0..k)
                .map(|j| 0.7 * ((i * 31 + j * 17) as f64 * 0.013).sin())
                .collect();
            let decoded = (0..k * p)
                .map(|t| ((i * 7 + t * 5) as f64 * 0.011).cos())
                .collect();
            rows.push(SaeSoftmaxRowInputs { logits, decoded });
        }
        rows
    }

    /// Gate-weighted mean of the decoded atoms in column `c`: `Σ_m z_m · decoded[m][c]`.
    fn weighted_mean(z: &[f64], decoded: &[f64], p: usize, c: usize) -> f64 {
        z.iter()
            .enumerate()
            .map(|(m, zm)| zm * decoded[m * p + c])
            .sum()
    }

    #[test]
    fn cpu_softmax_matches_unified_program_k8() {
        let (n, k, p) = (3, 8, 4);
        let inv_tau = 1.0 / 0.7;
        let rows = fixture(n, k, p);
        let out = sae_row_jets_cpu_softmax(&rows, k, p, inv_tau);
        assert_eq!((out.n_rows, out.k, out.p), (n, k, p));
        assert_eq!(out.first.len(), n * k * p);
        assert_eq!(out.second.len(), n * k * k * p);
        // Analytic gradient: ∂y_c/∂l_a = inv_tau · z_a · (d_ac − ȳ_c).
        // Checked on every row so that a per-row offset bug cannot hide behind row 0.
        for (row, inp) in rows.iter().enumerate() {
            let z = softmax_values(&inp.logits, inv_tau);
            let first = &out.first[row * k * p..(row + 1) * k * p];
            for c in 0..p {
                let mean = weighted_mean(&z, &inp.decoded, p, c);
                for a in 0..k {
                    let analytic = inv_tau * z[a] * (inp.decoded[a * p + c] - mean);
                    let got = first[a * p + c];
                    assert!(
                        (analytic - got).abs() <= 1e-12,
                        "softmax grad mismatch row={row} a={a} c={c}: analytic={analytic} got={got}"
                    );
                }
            }
        }
    }

    #[test]
    fn cpu_softmax_second_channel_matches_analytic_hessian() {
        let (n, k, p) = (2, 5, 3);
        let inv_tau = 1.0 / 0.8;
        let rows = fixture(n, k, p);
        let out = sae_row_jets_cpu_softmax(&rows, k, p, inv_tau);
        // Analytic Hessian, with s = inv_tau and ȳ_c the gate-weighted mean:
        // ∂²y_c/∂l_a∂l_b = s² · z_a · [(δ_ab − z_b)(d_ac − ȳ_c) − z_b (d_bc − ȳ_c)].
        for (row, inp) in rows.iter().enumerate() {
            let z = softmax_values(&inp.logits, inv_tau);
            let second = &out.second[row * k * k * p..(row + 1) * k * k * p];
            for c in 0..p {
                let mean = weighted_mean(&z, &inp.decoded, p, c);
                for a in 0..k {
                    let da = inp.decoded[a * p + c] - mean;
                    for b in 0..k {
                        let db = inp.decoded[b * p + c] - mean;
                        let delta = if a == b { 1.0 } else { 0.0 };
                        let analytic =
                            inv_tau * inv_tau * z[a] * ((delta - z[b]) * da - z[b] * db);
                        let got = second[(a * k + b) * p + c];
                        assert!(
                            (analytic - got).abs() <= 1e-12,
                            "softmax hessian mismatch row={row} a={a} b={b} c={c}: analytic={analytic} got={got}"
                        );
                    }
                }
            }
        }
    }

    #[test]
    fn second_channel_is_symmetric() {
        let k = 6;
        let p = 3;
        let inv_tau = 1.3;
        let rows = fixture(2, k, p);
        let out = sae_row_jets_cpu_softmax(&rows, k, p, inv_tau);
        for row in 0..2 {
            for c in 0..p {
                for a in 0..k {
                    for b in 0..k {
                        let ab = out.second[((row * k + a) * k + b) * p + c];
                        let ba = out.second[((row * k + b) * k + a) * p + c];
                        assert!(
                            (ab - ba).abs() <= 1e-12,
                            "asymmetry row={row} c={c} {a},{b}"
                        );
                    }
                }
            }
        }
    }

    #[test]
    fn gauss_newton_slab_is_symmetric_psd_gram() {
        let k = 5;
        let p = 7;
        let inv_tau = 1.0 / 0.9;
        let rows = fixture(4, k, p);
        let ch = sae_row_jets_cpu_softmax(&rows, k, p, inv_tau);
        let slabs = gauss_newton_row_hessian_slabs(&ch);
        assert_eq!(slabs.len(), 4 * k * k);
        for row in 0..4 {
            let s = &slabs[row * k * k..(row + 1) * k * k];
            let f = &ch.first[row * k * p..(row + 1) * k * p];
            for a in 0..k {
                for b in 0..k {
                    let expect: f64 = (0..p).map(|c| f[a * p + c] * f[b * p + c]).sum();
                    assert!((s[a * k + b] - expect).abs() <= 1e-12);
                    assert!((s[a * k + b] - s[b * k + a]).abs() <= 1e-12);
                }
            }
            let v: Vec<f64> = (0..k).map(|a| ((a * 13 + 1) as f64 * 0.3).sin()).collect();
            let mut quad = 0.0;
            for a in 0..k {
                for b in 0..k {
                    quad += v[a] * s[a * k + b] * v[b];
                }
            }
            assert!(quad >= -1e-12, "GN slab not PSD: vᵀHv={quad}");
        }
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn device_matches_cpu_when_available() {
        let (k, p) = (8, 16);
        let inv_tau = 1.0 / 0.7;
        // Enough rows to cross the device threshold. Without a usable CUDA backend the
        // dispatcher falls back to the CPU, and this becomes a CPU-vs-CPU check.
        let rows = fixture(DEVICE_ROW_THRESHOLD + 64, k, p);
        let cpu = sae_row_jets_cpu_softmax(&rows, k, p, inv_tau);
        let got = sae_row_jets_softmax(&rows, k, p, inv_tau);
        // Check shapes first: `zip` stops at the shorter side, so a truncated device result
        // would otherwise pass unnoticed.
        assert_eq!((got.n_rows, got.k, got.p), (cpu.n_rows, cpu.k, cpu.p));
        assert_eq!(got.first.len(), cpu.first.len());
        assert_eq!(got.second.len(), cpu.second.len());
        // NaN-propagating max. `f64::max` ignores NaN, which would let NaN outputs pass.
        let maxabs = cpu
            .first
            .iter()
            .zip(&got.first)
            .chain(cpu.second.iter().zip(&got.second))
            .map(|(x, y)| (x - y).abs())
            .fold(0.0_f64, |worst, d| if d.is_nan() || d > worst { d } else { worst });
        assert!(
            maxabs <= 1e-9,
            "device vs CPU row-jet max abs diff {maxabs} > 1e-9"
        );
    }
}