// Batched forward tangent (JVP) evaluation on GPU.
//
// One thread per batch element. Each thread propagates both primal values and
// tangent derivatives through the tape using the forward-mode chain rule:
// unary f(a): tangent = f'(a) * a_tangent
// binary f(a,b): tangent = df/da * a_tangent + df/db * b_tangent
//
// Used for sparse Jacobian: dispatch C colors in parallel, each with different
// tangent seeds.
// ── OpCode constants ──
// Numeric tags stored in `opcodes`. The `switch` in `main` matches on the same
// values written as literals (`case 2u /* ADD */`), so the two lists must be
// kept in sync by hand.
//
// Leaf entries. `main` seeds the first `num_inputs` slots from the input
// buffers and starts the tape walk after them; OP_CONST entries are skipped
// during the walk because their primal was already copied from `constants`.
const OP_INPUT: u32 = 0u;
const OP_CONST: u32 = 1u;
// Binary ops (2..11): `main` reads both arg0[i] and arg1[i] as operand slots.
const OP_ADD: u32 = 2u;
const OP_SUB: u32 = 3u;
const OP_MUL: u32 = 4u;
const OP_DIV: u32 = 5u;
const OP_REM: u32 = 6u;
const OP_POWF: u32 = 7u;
const OP_ATAN2: u32 = 8u;
const OP_HYPOT: u32 = 9u;
const OP_MAX: u32 = 10u;
const OP_MIN: u32 = 11u;
// Unary ops (12..42): `main` uses only the operand slot named by arg0[i].
const OP_NEG: u32 = 12u;
const OP_RECIP: u32 = 13u;
const OP_SQRT: u32 = 14u;
const OP_CBRT: u32 = 15u;
// POWI is the one unary op that also consumes arg1[i]: `main` bit-casts it to
// i32 and uses it as the integer exponent rather than as a slot index.
const OP_POWI: u32 = 16u;
const OP_EXP: u32 = 17u;
const OP_EXP2: u32 = 18u;
const OP_EXPM1: u32 = 19u;
const OP_LN: u32 = 20u;
const OP_LOG2: u32 = 21u;
const OP_LOG10: u32 = 22u;
const OP_LN1P: u32 = 23u;
const OP_SIN: u32 = 24u;
const OP_COS: u32 = 25u;
const OP_TAN: u32 = 26u;
const OP_ASIN: u32 = 27u;
const OP_ACOS: u32 = 28u;
const OP_ATAN: u32 = 29u;
const OP_SINH: u32 = 30u;
const OP_COSH: u32 = 31u;
const OP_TANH: u32 = 32u;
const OP_ASINH: u32 = 33u;
const OP_ACOSH: u32 = 34u;
const OP_ATANH: u32 = 35u;
const OP_ABS: u32 = 36u;
// SIGNUM..TRUNC (37..41) are piecewise constant: `main` always writes a zero
// tangent for them.
const OP_SIGNUM: u32 = 37u;
const OP_FLOOR: u32 = 38u;
const OP_CEIL: u32 = 39u;
const OP_ROUND: u32 = 40u;
const OP_TRUNC: u32 = 41u;
const OP_FRACT: u32 = 42u;
// Sentinel index value. Not referenced anywhere else in this shader;
// NOTE(review): presumably the tape builder stores it in arg1 for unary ops —
// confirm against the host-side encoder.
const UNUSED: u32 = 0xFFFFFFFFu;
// Tape dimensions supplied by the host through the uniform at group 0,
// binding 4. Eight u32 fields = 32 bytes in total.
struct TapeMeta {
// Upper bound of the tape walk in `main` (entries [num_inputs, num_ops)).
num_ops: u32,
// Inputs per batch element; also the stride of `primal_inputs` and
// `tangent_seeds`, and the number of leading slots seeded from them.
num_inputs: u32,
// Slots per batch element; stride of the `primals` / `tangents` work buffers
// and the number of entries copied from `constants`.
num_variables: u32,
// Outputs per batch element; stride of `tangent_outputs` and the number of
// entries read from `output_indices`.
num_outputs: u32,
// Number of batch elements; invocations with a larger global id return early.
batch_size: u32,
// Never read by the shader.
// NOTE(review): presumably padding to keep the host-side struct layout in
// step with the uniform — confirm against the host definition.
_pad0: u32,
_pad1: u32,
_pad2: u32,
}
// ── Tape data (bind group 0) ──
// Read-only description of the tape, shared by every invocation.
// opcodes[i]: operation tag of tape entry i (see the OP_* constants).
@group(0) @binding(0) var<storage, read> opcodes: array<u32>;
// arg0[i]: slot index of the first operand of entry i.
@group(0) @binding(1) var<storage, read> arg0: array<u32>;
// arg1[i]: slot index of the second operand for binary ops; for POWI it holds
// the integer exponent as raw i32 bits.
@group(0) @binding(2) var<storage, read> arg1: array<u32>;
// constants[s]: initial primal for slot s; `main` copies entries
// [0, num_variables) into each batch element's primal slots before the walk.
@group(0) @binding(3) var<storage, read> constants: array<f32>;
// Tape dimensions (see TapeMeta).
@group(0) @binding(4) var<uniform> tape_meta: TapeMeta;
// output_indices[j]: slot whose tangent becomes output j, for j < num_outputs.
@group(0) @binding(5) var<storage, read> output_indices: array<u32>;
// ── I/O buffers (bind group 1) ──
// All buffers below are laid out batch-major: element `bid` owns the
// contiguous range starting at bid * stride.
// binding 0: primal inputs [B * num_inputs] (same x for all colors, or different per batch)
@group(1) @binding(0) var<storage, read> primal_inputs: array<f32>;
// binding 1: tangent seeds [B * num_inputs] (different per color/batch element)
@group(1) @binding(1) var<storage, read> tangent_seeds: array<f32>;
// binding 2: primals working buffer [B * num_variables]
@group(1) @binding(2) var<storage, read_write> primals: array<f32>;
// binding 3: tangents working buffer [B * num_variables]
@group(1) @binding(3) var<storage, read_write> tangents: array<f32>;
// binding 4: tangent outputs [B * num_outputs]
@group(1) @binding(4) var<storage, read_write> tangent_outputs: array<f32>;
// Hyperbolic sine from its exponential definition: (e^x - e^-x) / 2.
fn sinh_f(x: f32) -> f32 {
    let grow = exp(x);
    let decay = exp(-x);
    return (grow - decay) * 0.5;
}
// Hyperbolic cosine from its exponential definition: (e^x + e^-x) / 2.
fn cosh_f(x: f32) -> f32 {
    let grow = exp(x);
    let decay = exp(-x);
    return (grow + decay) * 0.5;
}
// Precision-preserving EXPM1 / LN1P primals for small |x|, matching
// forward.wgsl helpers. `exp(x) - 1` and `log(1 + x)` cancel
// catastrophically as x → 0; the Taylor-series shortcut avoids that.
// e^x - 1. Below |x| = 1e-4 the second-order Taylor polynomial x + x²/2 is
// returned instead of the direct difference.
fn expm1_f32(x: f32) -> f32 {
    var result: f32;
    if abs(x) < 1e-4 {
        result = x + 0.5 * x * x;
    } else {
        result = exp(x) - 1.0;
    }
    return result;
}
// ln(1 + x). Below |x| = 1e-4 the second-order Taylor polynomial x - x²/2 is
// returned instead of evaluating log(1 + x) directly.
fn ln1p_f32(x: f32) -> f32 {
    var result: f32;
    if abs(x) < 1e-4 {
        result = x - 0.5 * x * x;
    } else {
        result = log(1.0 + x);
    }
    return result;
}
// sqrt(a² + b²) without overflowing the intermediate squares, mirroring the
// helper in forward.wgsl so the primal computed here matches the forward
// kernel bit-for-bit. (The naive `sqrt(a² + b²)` previously used in the
// tangent primal arm overflowed for large |a| or |b|.)
//
// The larger magnitude is factored out so only a ratio in [0, 1] is squared.
// An infinite operand yields +Inf, and two zero operands yield 0 (this also
// avoids the 0/0 ratio).
fn hypot_f32(a: f32, b: f32) -> f32 {
    let pos_inf = bitcast<f32>(0x7f800000u);
    let mag_a = abs(a);
    let mag_b = abs(b);
    if mag_a == pos_inf || mag_b == pos_inf {
        return pos_inf;
    }
    let hi = max(mag_a, mag_b);
    if hi == 0.0 {
        return 0.0;
    }
    let lo = min(mag_a, mag_b);
    let ratio = lo / hi;
    return hi * sqrt(1.0 + ratio * ratio);
}
// ── Scalar helpers used by the tape walk ──

// NaN test on the raw IEEE-754 bits (exponent all ones, mantissa non-zero).
// `x != x` is the usual idiom but can be folded away by shader optimizers.
fn jvp_is_nan(x: f32) -> bool {
    let bits = bitcast<u32>(x);
    return ((bits >> 23u) & 0xffu) == 0xffu && (bits & 0x7fffffu) != 0u;
}

// True when the IEEE-754 sign bit of `x` is set (also true for -0.0).
fn jvp_sign_bit(x: f32) -> bool {
    return (bitcast<u32>(x) & 0x80000000u) != 0u;
}

// Integer power x^n that is defined for negative bases.
// WGSL `pow(x, y)` is specified through `exp2(y * log2(x))`, so it yields NaN
// for x < 0 and may yield NaN for pow(0, 0). Here the magnitude is computed on
// |x| and the sign restored for odd exponents; n == 0 returns exactly 1.
fn jvp_powi(x: f32, n: i32) -> f32 {
    if n == 0 {
        return 1.0;
    }
    let mag = pow(abs(x), f32(n));
    let negate = (n & 1) != 0 && jvp_sign_bit(x);
    return select(mag, -mag, negate);
}

// Rust-style `signum`: NaN propagates, otherwise ±1 chosen by the sign bit,
// so -0.0 maps to -1 (a `x >= 0.0` test would give +1).
fn jvp_signum(x: f32) -> f32 {
    if jvp_is_nan(x) {
        return x;
    }
    return select(1.0, -1.0, jvp_sign_bit(x));
}

// Round to nearest with ties away from zero (Rust `f32::round`).
// WGSL's built-in `round` resolves ties to even instead, so 2.5 would become
// 2.0 rather than 3.0. `x - trunc(x)` is exact, so the tie test is exact too;
// NaN and ±Inf fall through unchanged because the comparison is false.
fn jvp_round(x: f32) -> f32 {
    let t = trunc(x);
    return select(t, t + sign(x), abs(x - t) >= 0.5);
}

// Entry point: one invocation evaluates the whole tape for one batch element.
//
// For batch element `bid` (= gid.x) the invocation
//   1. fills its `num_variables` primal slots from `constants` and zeroes the
//      matching tangent slots,
//   2. overwrites the first `num_inputs` slots with its primal inputs and
//      tangent seeds,
//   3. walks tape entries [num_inputs, num_ops), writing the primal `r` and
//      tangent `rt` of each entry to slot i (OP_CONST entries are skipped),
//   4. copies the tangents of the slots listed in `output_indices` to
//      `tangent_outputs`.
// Invocations with gid.x >= batch_size return without touching any buffer.
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let bid = gid.x;
    if bid >= tape_meta.batch_size {
        return;
    }
    let nv = tape_meta.num_variables;
    let ni = tape_meta.num_inputs;
    let num_ops = tape_meta.num_ops;
    let n_out = tape_meta.num_outputs;
    let p_base = bid * nv; // primals base
    let t_base = bid * nv; // tangents base

    // Initialize primals from constants, tangents to zero
    for (var i = 0u; i < nv; i = i + 1u) {
        primals[p_base + i] = constants[i];
        tangents[t_base + i] = 0.0;
    }

    // Set input primals and tangent seeds
    let in_base = bid * ni;
    for (var i = 0u; i < ni; i = i + 1u) {
        primals[p_base + i] = primal_inputs[in_base + i];
        tangents[t_base + i] = tangent_seeds[in_base + i];
    }

    // Walk the tape: compute primals and propagate tangents
    for (var i = ni; i < num_ops; i = i + 1u) {
        let op = opcodes[i];
        if op == OP_CONST {
            // Primal already holds the constant; tangent stays zero.
            continue;
        }
        let a_idx = arg0[i];
        let b_idx = arg1[i];
        let a = primals[p_base + a_idx];
        let at = tangents[t_base + a_idx];
        // Unknown opcodes fall through `default` and store r = rt = 0.
        var r = 0.0f;
        var rt = 0.0f;
        switch op {
            // Binary: the second operand is only loaded inside these arms
            // because b_idx is not a slot index for unary ops.
            case 2u /* ADD */: {
                let b = primals[p_base + b_idx];
                let bt = tangents[t_base + b_idx];
                r = a + b;
                rt = at + bt;
            }
            case 3u /* SUB */: {
                let b = primals[p_base + b_idx];
                let bt = tangents[t_base + b_idx];
                r = a - b;
                rt = at - bt;
            }
            case 4u /* MUL */: {
                let b = primals[p_base + b_idx];
                let bt = tangents[t_base + b_idx];
                r = a * b;
                rt = b * at + a * bt;
            }
            case 5u /* DIV */: {
                let b = primals[p_base + b_idx];
                let bt = tangents[t_base + b_idx];
                r = a / b;
                let inv = 1.0 / b;
                // `rt = at/b - a*bt/b² = at*inv - r*bt*inv`; avoids
                // forming `inv*inv` which overflows at small |b|.
                rt = inv * at - r * inv * bt;
            }
            case 6u /* REM */: {
                let b = primals[p_base + b_idx];
                let bt = tangents[t_base + b_idx];
                // Truncated remainder: a - trunc(a/b) * b, with the quotient
                // treated as locally constant for the tangent.
                let q = trunc(a / b);
                r = a - q * b;
                rt = at - q * bt;
            }
            case 7u /* POWF */: {
                let b = primals[p_base + b_idx];
                let bt = tangents[t_base + b_idx];
                r = pow(a, b);
                // Guard: at a=0, b/a and log(a) are undefined; split dx/dy
                let dx = select(b * r / a * at, b * pow(a, b - 1.0) * at, a == 0.0);
                let dy = select(r * log(a) * bt, 0.0, r == 0.0);
                rt = dx + dy;
            }
            case 8u /* ATAN2 */: {
                let b = primals[p_base + b_idx];
                let bt = tangents[t_base + b_idx];
                r = atan2(a, b);
                // See `reverse.wgsl` ATAN2 — normalize by max(|a|,|b|) so
                // au² + bu² is bounded and a² + b² doesn't overflow.
                let mx = max(abs(a), abs(b));
                if mx == 0.0 {
                    rt = 0.0;
                } else {
                    let au = a / mx;
                    let bu = b / mx;
                    let denom = mx * (au * au + bu * bu);
                    rt = (bu * at - au * bt) / denom;
                }
            }
            case 9u /* HYPOT */: {
                let b = primals[p_base + b_idx];
                let bt = tangents[t_base + b_idx];
                // Call the scalar helper so the primal matches
                // `forward.wgsl` bit-for-bit (rescale + IEEE Inf guard)
                // and `a*a + b*b` can't overflow for large operands.
                r = hypot_f32(a, b);
                if r == 0.0 { rt = 0.0; } else { rt = (a * at + b * bt) / r; }
            }
            case 10u /* MAX */: {
                let b = primals[p_base + b_idx];
                let bt = tangents[t_base + b_idx];
                // Pick the non-NaN operand when one is NaN (matches IEEE `max`).
                // If `a` is NaN the comparison is false and `b` is selected.
                if a >= b || jvp_is_nan(b) { r = a; rt = at; } else { r = b; rt = bt; }
            }
            case 11u /* MIN */: {
                let b = primals[p_base + b_idx];
                let bt = tangents[t_base + b_idx];
                // Same NaN policy as MAX.
                if a <= b || jvp_is_nan(b) { r = a; rt = at; } else { r = b; rt = bt; }
            }
            // Unary
            case 12u /* NEG */: { r = -a; rt = -at; }
            case 13u /* RECIP */: { r = 1.0 / a; rt = -at / (a * a); }
            case 14u /* SQRT */: { r = sqrt(a); rt = at / (2.0 * r); }
            case 15u /* CBRT */: {
                // Odd root: take the root of |a| and restore the sign.
                let s = sign(a);
                r = s * pow(abs(a), 1.0 / 3.0);
                rt = at / (3.0 * r * r);
            }
            case 16u /* POWI */: {
                // arg1 carries the integer exponent as raw i32 bits.
                let n_int = bitcast<i32>(b_idx);
                // `jvp_powi` keeps negative bases finite (WGSL `pow` would
                // return NaN) and returns exactly 1 for n == 0.
                r = jvp_powi(a, n_int);
                if n_int == 0 {
                    rt = 0.0;
                } else {
                    rt = f32(n_int) * jvp_powi(a, n_int - 1) * at;
                }
            }
            case 17u /* EXP */: { r = exp(a); rt = r * at; }
            case 18u /* EXP2 */: { r = exp2(a); rt = r * log(2.0) * at; }
            case 19u /* EXPM1 */: { r = expm1_f32(a); rt = (r + 1.0) * at; }
            case 20u /* LN */: { r = log(a); rt = at / a; }
            case 21u /* LOG2 */: { r = log2(a); rt = at / (a * log(2.0)); }
            case 22u /* LOG10 */: { r = log(a) / log(10.0); rt = at / (a * log(10.0)); }
            case 23u /* LN1P */: { r = ln1p_f32(a); rt = at / (1.0 + a); }
            case 24u /* SIN */: { r = sin(a); rt = cos(a) * at; }
            case 25u /* COS */: { r = cos(a); rt = -sin(a) * at; }
            case 26u /* TAN */: { r = tan(a); let c = cos(a); rt = at / (c * c); }
            case 27u /* ASIN */: { r = asin(a); rt = at / sqrt((1.0 - a) * (1.0 + a)); }
            case 28u /* ACOS */: { r = acos(a); rt = -at / sqrt((1.0 - a) * (1.0 + a)); }
            case 29u /* ATAN */: {
                let aa = abs(a);
                r = atan(a);
                // For huge |a|, rewrite 1/(1+a²) in terms of 1/a so that a²
                // is never formed.
                if aa > 1e8 {
                    let inv = 1.0 / a;
                    rt = at * inv * inv / (1.0 + inv * inv);
                } else {
                    rt = at / (1.0 + a * a);
                }
            }
            case 30u /* SINH */: { r = sinh_f(a); rt = cosh_f(a) * at; }
            case 31u /* COSH */: { r = cosh_f(a); rt = sinh_f(a) * at; }
            case 32u /* TANH */: { r = tanh(a); let c = cosh_f(a); rt = at / (c * c); }
            case 33u /* ASINH */: {
                let ax = abs(a);
                // Evaluate on |a| and restore the sign (asinh is odd).
                r = select(-log(ax + sqrt(ax * ax + 1.0)), log(ax + sqrt(ax * ax + 1.0)), a >= 0.0);
                // Overflow-safe derivative for |a| > 1e8.
                if abs(a) > 1e8 {
                    let inv = 1.0 / a;
                    rt = at * abs(inv) / sqrt(1.0 + inv * inv);
                } else {
                    rt = at / sqrt(a * a + 1.0);
                }
            }
            case 34u /* ACOSH */: {
                // Factored form under sqrt for both primal and derivative
                // — retains the ε² term near a=1; matches forward.wgsl
                // acosh_f32 helper and kernels::acosh_deriv.
                r = log(a + sqrt((a - 1.0) * (a + 1.0)));
                if abs(a) > 1e8 {
                    let inv = 1.0 / a;
                    rt = at * abs(inv) / sqrt(1.0 - inv * inv);
                } else {
                    rt = at / sqrt((a - 1.0) * (a + 1.0));
                }
            }
            case 35u /* ATANH */: { r = 0.5 * log((1.0 + a) / (1.0 - a)); rt = at / ((1.0 - a) * (1.0 + a)); }
            case 36u /* ABS */: {
                r = abs(a);
                // Derivative is signum(a) * at, with Rust's `signum`
                // convention (-0.0 → -1); a NaN primal gives a zero tangent.
                if jvp_is_nan(a) {
                    rt = 0.0;
                } else {
                    rt = select(at, -at, jvp_sign_bit(a));
                }
            }
            // Piecewise-constant ops: the tangent is always zero.
            case 37u /* SIGNUM */: { r = jvp_signum(a); rt = 0.0; }
            case 38u /* FLOOR */: { r = floor(a); rt = 0.0; }
            case 39u /* CEIL */: { r = ceil(a); rt = 0.0; }
            // Ties away from zero, like Rust's `f32::round`; WGSL's `round`
            // would resolve ties to even.
            case 40u /* ROUND */: { r = jvp_round(a); rt = 0.0; }
            case 41u /* TRUNC */: { r = trunc(a); rt = 0.0; }
            // WGSL's built-in `fract(x) = x - floor(x)` differs from Rust's
            // `f32::fract() = x - trunc(x)` for negative x. Use `a - trunc(a)`
            // to match the CPU/stdlib truncation convention.
            case 42u /* FRACT */: { r = a - trunc(a); rt = at; }
            default: {}
        }
        primals[p_base + i] = r;
        tangents[t_base + i] = rt;
    }

    // Write tangent outputs
    let out_base = bid * n_out;
    for (var j = 0u; j < n_out; j = j + 1u) {
        let oi = output_indices[j];
        tangent_outputs[out_base + j] = tangents[t_base + oi];
    }
}