use crate::terms::sae::row_jet_program::SaeReconstructionRowProgram;
/// Flattened first- and second-order jet channels for a batch of rows.
///
/// Both buffers are row-major with the output column `c` as the fastest axis.
#[derive(Debug, Clone, PartialEq)]
pub struct SaeRowJetChannels {
/// Number of rows in the batch.
pub n_rows: usize,
/// Number of primaries (softmax logits) per row.
pub k: usize,
/// Number of output columns per row.
pub p: usize,
/// First-order channel, length `n_rows * k * p`; the entry for
/// `(row, a, c)` lives at `row * k * p + a * p + c`.
pub first: Vec<f64>,
/// Second-order channel, length `n_rows * k * k * p`; the entry for
/// `(row, a, b, c)` lives at `row * k * k * p + (a * k + b) * p + c`.
pub second: Vec<f64>,
}
/// Per-row inputs to the softmax row-jet computation.
#[derive(Debug, Clone)]
pub struct SaeSoftmaxRowInputs {
/// Softmax logits for this row; the device path asserts a length of `k`.
pub logits: Vec<f64>,
/// Per-atom decoder values, read as `decoded[atom * p + c]`; the device
/// path asserts a length of `k * p`.
pub decoded: Vec<f64>,
}
/// CUDA C source text of the `sae_rowjet_softmax` kernel and its `Jet` helpers.
///
/// The text is not self-contained: it relies on the macros `KK`, `PP` and
/// `INFINITY`, which `softmax_kernel_source` prepends as `#define`s before the
/// source is handed to NVRTC.
///
/// A `Jet` carries a value `v`, a first-order array `g[KK]` and a second-order
/// array `h[KK][KK]`. The kernel maps `blockIdx.x` to a row: thread 0 builds
/// the `KK` softmax gate jets into `__shared__` storage, and after
/// `__syncthreads()` every thread walks columns `c = threadIdx.x,
/// threadIdx.x + blockDim.x, ...` below `PP`, accumulating
/// `sum_k gates[k] * decoded[k*PP + c]` and writing the `g` / `h` parts of the
/// result into `first` / `second`.
pub const SOFTMAX_KERNEL_SOURCE: &str = r#"
struct Jet { double v; double g[KK]; double h[KK][KK]; };
__device__ __forceinline__ void jet_zero(Jet* j){
j->v=0.0;
for(int i=0;i<KK;++i){ j->g[i]=0.0; for(int k=0;k<KK;++k) j->h[i][k]=0.0; }
}
__device__ __forceinline__ void jet_const(Jet* j,double c){ jet_zero(j); j->v=c; }
__device__ __forceinline__ void jet_var(Jet* j,double val,int idx){ jet_zero(j); j->v=val; j->g[idx]=1.0; }
__device__ __forceinline__ void jet_add(const Jet* a,const Jet* b,Jet* o){
o->v=a->v+b->v;
for(int i=0;i<KK;++i){ o->g[i]=a->g[i]+b->g[i]; for(int k=0;k<KK;++k) o->h[i][k]=a->h[i][k]+b->h[i][k]; }
}
__device__ __forceinline__ void jet_scale(const Jet* a,double s,Jet* o){
o->v=a->v*s;
for(int i=0;i<KK;++i){ o->g[i]=a->g[i]*s; for(int k=0;k<KK;++k) o->h[i][k]=a->h[i][k]*s; }
}
// truncated order-2 Leibniz — matches Tower2::mul term-for-term.
__device__ __forceinline__ void jet_mul(const Jet* a,const Jet* b,Jet* o){
o->v=a->v*b->v;
for(int i=0;i<KK;++i) o->g[i]=a->v*b->g[i]+a->g[i]*b->v;
for(int i=0;i<KK;++i) for(int k=0;k<KK;++k)
o->h[i][k]=a->v*b->h[i][k]+a->g[i]*b->g[k]+a->g[k]*b->g[i]+a->h[i][k]*b->v;
}
// order-2 Faa di Bruno: d=[f,f',f''] at u=a.v.
__device__ __forceinline__ void jet_compose(const Jet* a,double f,double f1,double f2,Jet* o){
o->v=f;
for(int i=0;i<KK;++i) o->g[i]=f1*a->g[i];
for(int i=0;i<KK;++i) for(int k=0;k<KK;++k) o->h[i][k]=f1*a->h[i][k]+f2*a->g[i]*a->g[k];
}
__device__ __forceinline__ void jet_exp(const Jet* a,Jet* o){ double e=exp(a->v); jet_compose(a,e,e,e,o); }
__device__ __forceinline__ void jet_recip(const Jet* a,Jet* o){
double u=a->v,u2=u*u,u3=u2*u; jet_compose(a,1.0/u,-1.0/u2,2.0/u3,o);
}
// One block per row; gate jets built once per block (shared), threads stride
// over disjoint output columns => no cross-thread fp reordering => identical to
// the CPU summation order.
extern "C" __global__
void sae_rowjet_softmax(
const double* __restrict__ logits, // [n * KK]
const double* __restrict__ decoded, // [n * KK * PP]
double inv_tau,
int n,
double* __restrict__ first, // [n * KK * PP]
double* __restrict__ second) // [n * KK * KK * PP]
{
int row = blockIdx.x;
if (row >= n) return;
const double* L = logits + (size_t)row * KK;
const double* DEC = decoded + (size_t)row * KK * PP;
__shared__ Jet gates[KK];
if (threadIdx.x == 0) {
double mx = -INFINITY;
for (int j=0;j<KK;++j) mx = fmax(mx, L[j]);
double shift = mx * inv_tau;
Jet exps[KK];
Jet denom; jet_const(&denom, 0.0);
for (int j=0;j<KK;++j){
Jet lj; jet_var(&lj, L[j], j);
Jet tmp; jet_scale(&lj, inv_tau, &tmp);
tmp.v -= shift;
jet_exp(&tmp, &exps[j]);
Jet nd; jet_add(&denom, &exps[j], &nd); denom = nd;
}
Jet inv; jet_recip(&denom, &inv);
for (int k=0;k<KK;++k) jet_mul(&exps[k], &inv, &gates[k]);
}
__syncthreads();
double* F = first + (size_t)row * KK * PP;
double* S = second + (size_t)row * KK * KK * PP;
for (int c = threadIdx.x; c < PP; c += blockDim.x) {
Jet acc; jet_const(&acc, 0.0);
for (int k=0;k<KK;++k){
double dval = DEC[k*PP + c];
Jet term; jet_scale(&gates[k], dval, &term);
Jet na; jet_add(&acc, &term, &na); acc = na;
}
for (int a=0;a<KK;++a){
F[a*PP + c] = acc.g[a];
for (int b=0;b<KK;++b) S[(a*KK + b)*PP + c] = acc.h[a][b];
}
}
}
"#;
/// Assembles the complete kernel translation unit for one `(k, p)` pair.
///
/// The result is three `#define` lines (`KK`, `PP`, and an `INFINITY`
/// bit-pattern macro) followed by [`SOFTMAX_KERNEL_SOURCE`] verbatim.
#[cfg(target_os = "linux")]
#[must_use]
pub fn softmax_kernel_source(k: usize, p: usize) -> String {
    // Only the two dimension macros depend on the arguments.
    let dims = format!("#define KK {k}\n#define PP {p}\n");
    [
        dims.as_str(),
        "#define INFINITY (__longlong_as_double(0x7ff0000000000000LL))\n",
        SOFTMAX_KERNEL_SOURCE,
    ]
    .concat()
}
/// Minimum batch size (row count) at which [`sae_row_jets_softmax`] tries the
/// device path on Linux builds; smaller batches go straight to the CPU path.
pub const DEVICE_ROW_THRESHOLD: usize = 4_096;
/// CPU implementation of the softmax row-jet channels.
///
/// For every row a softmax row program is built and its per-column first- and
/// second-order components are written into that row's window of the output
/// buffers (`k * p` and `k * k * p` entries respectively).
///
/// # Panics
///
/// Panics if `k` is outside the range handled by `fill_row_channels`, or if a
/// row's `decoded` is too short to be indexed up to `k * p`.
#[must_use]
pub fn sae_row_jets_cpu_softmax(
    rows: &[SaeSoftmaxRowInputs],
    k: usize,
    p: usize,
    inv_tau: f64,
) -> SaeRowJetChannels {
    let n_rows = rows.len();
    let first_stride = k * p;
    let second_stride = k * k * p;

    let mut first = vec![0.0_f64; n_rows * first_stride];
    let mut second = vec![0.0_f64; n_rows * second_stride];

    for (idx, row_inputs) in rows.iter().enumerate() {
        let program = softmax_program(row_inputs, k, p, inv_tau);
        // Each row owns a disjoint, contiguous window in both buffers.
        let first_window = &mut first[idx * first_stride..][..first_stride];
        let second_window = &mut second[idx * second_stride..][..second_stride];
        fill_row_channels(&program, k, p, first_window, second_window);
    }

    SaeRowJetChannels {
        n_rows,
        k,
        p,
        first,
        second,
    }
}
/// Builds the row program for one row of a `k`-way softmax gate.
///
/// Each of the `k` atoms has a constant basis (`phi = [1.0]`, no latent
/// coordinates) and a single decoder row copied from
/// `inp.decoded[atom * p .. atom * p + p]`. Logit `i` is mapped to primary
/// slot `i`, giving `k` primaries in total.
fn softmax_program(
    inp: &SaeSoftmaxRowInputs,
    k: usize,
    p: usize,
    inv_tau: f64,
) -> SaeReconstructionRowProgram {
    use crate::terms::sae::row_jet_program::{AtomRowBasisJet, RowGate};

    let mut atoms: Vec<AtomRowBasisJet> = Vec::with_capacity(k);
    for atom in 0..k {
        let base = atom * p;
        atoms.push(AtomRowBasisJet {
            phi: vec![1.0],
            d_phi: vec![vec![]],
            d2_phi: vec![vec![]],
            decoder: vec![(0..p).map(|col| inp.decoded[base + col]).collect()],
            latent_dim: 0,
        });
    }

    SaeReconstructionRowProgram {
        atoms,
        gate_value: softmax_values(&inp.logits, inv_tau),
        logits: inp.logits.clone(),
        gate_scale: vec![1.0; k],
        gate_shift: vec![0.0; k],
        gate: RowGate::Softmax { inv_tau },
        logit_slot: (0..k).map(Some).collect(),
        coord_slot: vec![vec![]; k],
        n_primaries: k,
    }
}
/// Evaluates `softmax(logits * inv_tau)`.
///
/// Every exponent is offset by `max(logits) * inv_tau` before
/// exponentiation; the exponentials are then divided by their sum.
/// An empty input yields an empty output.
fn softmax_values(logits: &[f64], inv_tau: f64) -> Vec<f64> {
    let mut peak = f64::NEG_INFINITY;
    for &logit in logits {
        peak = f64::max(peak, logit);
    }
    let shift = peak * inv_tau;

    let mut weights: Vec<f64> = Vec::with_capacity(logits.len());
    for &logit in logits {
        weights.push((logit * inv_tau - shift).exp());
    }

    let denom: f64 = weights.iter().sum();
    for weight in &mut weights {
        *weight /= denom;
    }
    weights
}
/// Scatters one row program's per-column jets into the flat row windows.
///
/// For the tower at column `c`, component `g[a]` goes to `first[a * p + c]`
/// and `h[a][b]` goes to `second[(a * k + b) * p + c]`. The primary count must
/// be a compile-time constant for `reconstruction_all_columns_packed`, so the
/// runtime `k` is matched against the literals `1..=16`.
///
/// # Panics
///
/// Panics if `k` is not in `1..=16`.
fn fill_row_channels(
    prog: &SaeReconstructionRowProgram,
    k: usize,
    p: usize,
    first: &mut [f64],
    second: &mut [f64],
) {
    macro_rules! scatter_for_width {
        ($($width:literal),+ $(,)?) => {
            match k {
                $(
                    $width => {
                        let towers = prog.reconstruction_all_columns_packed::<$width>();
                        for (col, tower) in towers.iter().enumerate() {
                            let grad = tower.g();
                            let hess = tower.h();
                            for i in 0..$width {
                                first[i * p + col] = grad[i];
                                // Row `i` of the second-order block starts at `i * width`.
                                let hess_base = i * $width;
                                for j in 0..$width {
                                    second[(hess_base + j) * p + col] = hess[i][j];
                                }
                            }
                        }
                    }
                )+
                _ => panic!("SAE device row-jet supports K in 1..=16, got {k}"),
            }
        };
    }
    scatter_for_width!(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16);
}
/// Computes the softmax row-jet channels, preferring the device when it pays.
///
/// On Linux builds, batches of at least [`DEVICE_ROW_THRESHOLD`] rows are
/// first offered to the device implementation. If that attempt returns an
/// error, the error is discarded and the CPU implementation runs instead.
/// All other cases go directly to [`sae_row_jets_cpu_softmax`].
#[must_use]
pub fn sae_row_jets_softmax(
    rows: &[SaeSoftmaxRowInputs],
    k: usize,
    p: usize,
    inv_tau: f64,
) -> SaeRowJetChannels {
    #[cfg(target_os = "linux")]
    {
        // The device is only consulted when the batch is large enough.
        let device_attempt = (rows.len() >= DEVICE_ROW_THRESHOLD)
            .then(|| device::sae_row_jets_softmax_device(rows, k, p, inv_tau));
        if let Some(Ok(channels)) = device_attempt {
            return channels;
        }
    }
    sae_row_jets_cpu_softmax(rows, k, p, inv_tau)
}
/// Forms, for every row, the `k x k` Gram matrix of the first-order channel.
///
/// Entry `(a, b)` of row `r` is `sum_c first[r, a, c] * first[r, b, c]`, stored
/// at `r * k * k + a * k + b`. The returned vector has `n_rows * k * k`
/// entries.
///
/// # Panics
///
/// Panics if `channels.first` is shorter than `n_rows * k * p`.
#[must_use]
pub fn gauss_newton_row_hessian_slabs(channels: &SaeRowJetChannels) -> Vec<f64> {
    let n_rows = channels.n_rows;
    let k = channels.k;
    let p = channels.p;
    let jac_len = k * p;
    let slab_len = k * k;

    let mut slabs = vec![0.0_f64; n_rows * slab_len];
    for row in 0..n_rows {
        let jac = &channels.first[row * jac_len..(row + 1) * jac_len];
        let slab = &mut slabs[row * slab_len..(row + 1) * slab_len];
        for a in 0..k {
            let lhs = &jac[a * p..(a + 1) * p];
            for b in 0..k {
                let rhs = &jac[b * p..(b + 1) * p];
                // Explicit 0.0 seed and left-to-right accumulation over columns.
                slab[a * k + b] = lhs
                    .iter()
                    .zip(rhs)
                    .fold(0.0_f64, |acc, (x, y)| acc + x * y);
            }
        }
    }
    slabs
}
/// CUDA implementation of the softmax row-jet channels (Linux builds only).
///
/// The kernel text from `softmax_kernel_source` is compiled with NVRTC once per
/// `(k, p)` pair and cached; each call uploads the flattened inputs, launches
/// one block per row, and downloads both output channels.
#[cfg(target_os = "linux")]
mod device {
use super::{SaeRowJetChannels, SaeSoftmaxRowInputs, softmax_kernel_source};
use crate::gpu::gpu_error::{GpuError, GpuResultExt};
use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock};
use cudarc::driver::{CudaContext, CudaModule, CudaStream, LaunchConfig, PushKernelArg};
/// Process-wide CUDA state shared by every device call.
struct Backend {
/// CUDA context used to load compiled modules.
ctx: Arc<CudaContext>,
/// Stream on which all transfers and launches are issued.
stream: Arc<CudaStream>,
/// Loaded modules keyed by the `(k, p)` pair they were compiled for.
modules: Mutex<HashMap<(usize, usize), Arc<CudaModule>>>,
}
/// Returns the lazily initialised backend.
///
/// The probe runs at most once; its outcome, success or failure, is stored in
/// a `OnceLock`, so a failed probe is not retried and its error is cloned for
/// every later caller.
fn backend() -> Result<&'static Backend, GpuError> {
static BACKEND: OnceLock<Result<Backend, GpuError>> = OnceLock::new();
BACKEND
.get_or_init(|| {
let parts = crate::gpu::backend_probe::probe_cuda_backend("sae_rowjet")?;
Ok(Backend {
ctx: parts.ctx,
stream: parts.stream,
modules: Mutex::new(HashMap::new()),
})
})
.as_ref()
.map_err(GpuError::clone)
}
/// Returns the module specialised for `(k, p)`, compiling it on a cache miss.
///
/// The cache lock is released while NVRTC runs, so concurrent callers may each
/// compile the same specialisation; the first entry inserted is the one kept,
/// while each compiling caller receives the module it just loaded. A poisoned
/// lock is treated as a cache miss on lookup and skips the insert.
fn module_for(b: &Backend, k: usize, p: usize) -> Result<Arc<CudaModule>, GpuError> {
if let Ok(guard) = b.modules.lock() {
if let Some(m) = guard.get(&(k, p)) {
return Ok(m.clone());
}
}
let src = softmax_kernel_source(k, p);
let ptx = cudarc::nvrtc::compile_ptx(&src)
.gpu_ctx_with(|err| format!("sae_rowjet NVRTC compile (K={k}, P={p}): {err}"))?;
let module = b.ctx.load_module(ptx).gpu_ctx("sae_rowjet module load")?;
if let Ok(mut guard) = b.modules.lock() {
guard.entry((k, p)).or_insert_with(|| module.clone());
}
Ok(module)
}
/// Computes the softmax row-jet channels on the GPU.
///
/// # Errors
///
/// Returns a `GpuError` if the backend probe, NVRTC compilation, module or
/// function loading, any allocation or transfer, the launch, or the final
/// synchronisation fails, or if the row count does not fit in an `i32`.
///
/// # Panics
///
/// Panics if any row's `logits` length differs from `k` or its `decoded`
/// length differs from `k * p`.
pub(super) fn sae_row_jets_softmax_device(
rows: &[SaeSoftmaxRowInputs],
k: usize,
p: usize,
inv_tau: f64,
) -> Result<SaeRowJetChannels, GpuError> {
let n = rows.len();
// An empty batch needs no GPU work (and would otherwise request a zero-sized grid).
if n == 0 {
return Ok(SaeRowJetChannels {
n_rows: 0,
k,
p,
first: Vec::new(),
second: Vec::new(),
});
}
let b = backend()?;
let module = module_for(b, k, p)?;
let func = module
.load_function("sae_rowjet_softmax")
.gpu_ctx("sae_rowjet load_function")?;
let stream = b.stream.clone();
// Flatten the per-row inputs into the contiguous layouts the kernel indexes:
// logits as [n * k], decoded as [n * k * p].
let mut logits = vec![0.0_f64; n * k];
let mut decoded = vec![0.0_f64; n * k * p];
for (row, inp) in rows.iter().enumerate() {
assert_eq!(inp.logits.len(), k, "SAE device row-jet logits length");
assert_eq!(
inp.decoded.len(),
k * p,
"SAE device row-jet decoded length"
);
logits[row * k..(row + 1) * k].copy_from_slice(&inp.logits);
decoded[row * k * p..(row + 1) * k * p].copy_from_slice(&inp.decoded);
}
let logits_dev = stream
.clone_htod(&logits)
.gpu_ctx("sae_rowjet htod logits")?;
let decoded_dev = stream
.clone_htod(&decoded)
.gpu_ctx("sae_rowjet htod decoded")?;
let mut first_dev = stream
.alloc_zeros::<f64>(n * k * p)
.gpu_ctx("sae_rowjet alloc first")?;
let mut second_dev = stream
.alloc_zeros::<f64>(n * k * k * p)
.gpu_ctx("sae_rowjet alloc second")?;
// The kernel takes the row count as a C `int`.
let n_i32 =
i32::try_from(n).map_err(|_| crate::gpu_err!("sae_rowjet n={n} overflows i32"))?;
// Block width is `p` clamped to [1, 256]; the kernel's column loop steps by
// `blockDim.x`, so columns beyond the block width are still covered.
let block: u32 = u32::try_from(p.max(1).min(256))
.map_err(|_| crate::gpu_err!("sae_rowjet block size overflow"))?;
// One block per row. `n_i32` is non-negative (converted from a usize), so the
// cast to u32 is lossless. The kernel's shared array is statically sized,
// hence no dynamic shared memory is requested.
let cfg = LaunchConfig {
grid_dim: (n_i32 as u32, 1, 1),
block_dim: (block, 1, 1),
shared_mem_bytes: 0,
};
// Arguments are pushed in the kernel's parameter order:
// (logits, decoded, inv_tau, n, first, second).
let mut builder = stream.launch_builder(&func);
builder
.arg(&logits_dev)
.arg(&decoded_dev)
.arg(&inv_tau)
.arg(&n_i32)
.arg(&mut first_dev)
.arg(&mut second_dev);
// SAFETY: the argument list matches the `sae_rowjet_softmax` signature in
// order and type (f64 buffers for `double*`, `f64` for `double`, `i32` for
// `int`); the module was compiled with KK = k and PP = p, and the device
// buffers hold n*k, n*k*p, n*k*p and n*k*k*p elements, which are the extents
// the kernel indexes for rows `0..n`.
unsafe { builder.launch(cfg) }.gpu_ctx("sae_rowjet kernel launch")?;
let mut first = vec![0.0_f64; n * k * p];
let mut second = vec![0.0_f64; n * k * k * p];
stream
.memcpy_dtoh(&first_dev, &mut first)
.gpu_ctx("sae_rowjet dtoh first")?;
stream
.memcpy_dtoh(&second_dev, &mut second)
.gpu_ctx("sae_rowjet dtoh second")?;
// Wait for the stream to drain before handing the host buffers back.
stream.synchronize().gpu_ctx("sae_rowjet synchronize")?;
Ok(SaeRowJetChannels {
n_rows: n,
k,
p,
first,
second,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic inputs: `n` rows with `k` sinusoidal logits and `k * p`
    /// cosine decoder values each.
    fn fixture(n: usize, k: usize, p: usize) -> Vec<SaeSoftmaxRowInputs> {
        (0..n)
            .map(|i| SaeSoftmaxRowInputs {
                logits: (0..k)
                    .map(|j| 0.7 * ((i * 31 + j * 17) as f64 * 0.013).sin())
                    .collect(),
                decoded: (0..k * p)
                    .map(|t| ((i * 7 + t * 5) as f64 * 0.011).cos())
                    .collect(),
            })
            .collect()
    }

    /// The first channel of row 0 must equal the closed-form softmax gradient
    /// `inv_tau * z[a] * (D[a, c] - sum_m z[m] * D[m, c])`.
    #[test]
    fn cpu_softmax_matches_unified_program_k8() {
        let k = 8;
        let p = 4;
        let inv_tau = 1.0 / 0.7;
        let rows = fixture(3, k, p);
        let out = sae_row_jets_cpu_softmax(&rows, k, p, inv_tau);
        assert_eq!(out.first.len(), 3 * k * p);
        assert_eq!(out.second.len(), 3 * k * k * p);
        let inp = &rows[0];
        let z = softmax_values(&inp.logits, inv_tau);
        for c in 0..p {
            let mean: f64 = (0..k).map(|m| z[m] * inp.decoded[m * p + c]).sum();
            for a in 0..k {
                let analytic = inv_tau * z[a] * (inp.decoded[a * p + c] - mean);
                let got = out.first[a * p + c];
                assert!(
                    (analytic - got).abs() <= 1e-12,
                    "softmax grad mismatch a={a} c={c}: analytic={analytic} got={got}"
                );
            }
        }
    }

    /// The second channel must be symmetric in its two primary indices.
    #[test]
    fn second_channel_is_symmetric() {
        let k = 6;
        let p = 3;
        let inv_tau = 1.3;
        let rows = fixture(2, k, p);
        let out = sae_row_jets_cpu_softmax(&rows, k, p, inv_tau);
        for row in 0..2 {
            for c in 0..p {
                for a in 0..k {
                    for b in 0..k {
                        let ab = out.second[((row * k + a) * k + b) * p + c];
                        let ba = out.second[((row * k + b) * k + a) * p + c];
                        assert!(
                            (ab - ba).abs() <= 1e-12,
                            "asymmetry row={row} c={c} {a},{b}"
                        );
                    }
                }
            }
        }
    }

    /// Each Gauss-Newton slab must be the Gram matrix of the row's first
    /// channel: symmetric, and non-negative on a probe vector.
    #[test]
    fn gauss_newton_slab_is_symmetric_psd_gram() {
        let k = 5;
        let p = 7;
        let inv_tau = 1.0 / 0.9;
        let rows = fixture(4, k, p);
        let ch = sae_row_jets_cpu_softmax(&rows, k, p, inv_tau);
        let slabs = gauss_newton_row_hessian_slabs(&ch);
        assert_eq!(slabs.len(), 4 * k * k);
        for row in 0..4 {
            let s = &slabs[row * k * k..(row + 1) * k * k];
            let f = &ch.first[row * k * p..(row + 1) * k * p];
            for a in 0..k {
                for b in 0..k {
                    let expect: f64 = (0..p).map(|c| f[a * p + c] * f[b * p + c]).sum();
                    assert!((s[a * k + b] - expect).abs() <= 1e-12);
                    assert!((s[a * k + b] - s[b * k + a]).abs() <= 1e-12);
                }
            }
            let v: Vec<f64> = (0..k).map(|a| ((a * 13 + 1) as f64 * 0.3).sin()).collect();
            let mut quad = 0.0;
            for a in 0..k {
                for b in 0..k {
                    quad += v[a] * s[a * k + b] * v[b];
                }
            }
            assert!(quad >= -1e-12, "GN slab not PSD: vᵀHv={quad}");
        }
    }

    /// The dispatching entry point must agree with the CPU reference on a
    /// batch large enough to take the device path.
    ///
    /// When the device path fails, `sae_row_jets_softmax` falls back to the
    /// CPU implementation, so on machines without a usable GPU this compares
    /// the CPU result with itself.
    #[cfg(target_os = "linux")]
    #[test]
    fn device_matches_cpu_when_available() {
        let k = 8;
        let p = 16;
        let inv_tau = 1.0 / 0.7;
        let rows = fixture(DEVICE_ROW_THRESHOLD + 64, k, p);
        let cpu = sae_row_jets_cpu_softmax(&rows, k, p, inv_tau);
        let got = sae_row_jets_softmax(&rows, k, p, inv_tau);

        // `zip` stops at the shorter side, so pin the lengths first.
        assert_eq!(cpu.first.len(), got.first.len(), "first channel length");
        assert_eq!(cpu.second.len(), got.second.len(), "second channel length");

        let pairs = cpu
            .first
            .iter()
            .zip(&got.first)
            .chain(cpu.second.iter().zip(&got.second));
        let mut maxabs = 0.0_f64;
        for (x, y) in pairs {
            let diff = (x - y).abs();
            // `f64::max` ignores NaN, so a NaN output would otherwise slip
            // through the tolerance check below.
            assert!(
                !diff.is_nan(),
                "device vs CPU row-jet difference is NaN (cpu={x}, got={y})"
            );
            maxabs = maxabs.max(diff);
        }
        assert!(
            maxabs <= 1e-9,
            "device vs CPU row-jet max abs diff {maxabs} > 1e-9"
        );
    }
}