// Batched forward evaluation of a BytecodeTape on GPU.
//
// One compute thread per batch element. Each thread walks the tape sequentially,
// maintaining a private section of the values buffer.
// ── OpCode constants (must match OpCode repr(u8) discriminants) ──
// NOTE: the kernel's `switch` matches on the numeric literals (with the name in a
// comment) rather than on these identifiers, so a renumbering must be applied in
// both places.
//
// Leaf entries. The kernel fills their slots up front (inputs from the inputs
// buffer, constants from the constants buffer); OP_CONST entries are skipped
// during the tape walk.
const OP_INPUT: u32 = 0u;
const OP_CONST: u32 = 1u;
// Binary operations: both operands are read from the values buffer.
const OP_ADD: u32 = 2u;
const OP_SUB: u32 = 3u;
const OP_MUL: u32 = 4u;
const OP_DIV: u32 = 5u;
const OP_REM: u32 = 6u;
const OP_POWF: u32 = 7u;
const OP_ATAN2: u32 = 8u;
const OP_HYPOT: u32 = 9u;
const OP_MAX: u32 = 10u;
const OP_MIN: u32 = 11u;
// Unary operations: only the first operand is used by the kernel.
const OP_NEG: u32 = 12u;
const OP_RECIP: u32 = 13u;
const OP_SQRT: u32 = 14u;
const OP_CBRT: u32 = 15u;
// POWI is special: its second argument slot carries the raw bits of the i32
// exponent instead of an index into the values buffer.
const OP_POWI: u32 = 16u;
const OP_EXP: u32 = 17u;
const OP_EXP2: u32 = 18u;
const OP_EXPM1: u32 = 19u;
const OP_LN: u32 = 20u;
const OP_LOG2: u32 = 21u;
const OP_LOG10: u32 = 22u;
const OP_LN1P: u32 = 23u;
const OP_SIN: u32 = 24u;
const OP_COS: u32 = 25u;
const OP_TAN: u32 = 26u;
const OP_ASIN: u32 = 27u;
const OP_ACOS: u32 = 28u;
const OP_ATAN: u32 = 29u;
const OP_SINH: u32 = 30u;
const OP_COSH: u32 = 31u;
const OP_TANH: u32 = 32u;
const OP_ASINH: u32 = 33u;
const OP_ACOSH: u32 = 34u;
const OP_ATANH: u32 = 35u;
const OP_ABS: u32 = 36u;
const OP_SIGNUM: u32 = 37u;
const OP_FLOOR: u32 = 38u;
const OP_CEIL: u32 = 39u;
const OP_ROUND: u32 = 40u;
const OP_TRUNC: u32 = 41u;
const OP_FRACT: u32 = 42u;
// Sentinel stored in the second argument slot when an entry has no second operand;
// the kernel then leaves `b` at 0.0 instead of reading the values buffer.
const UNUSED: u32 = 0xFFFFFFFFu;
// ── Tape data (bind group 0) ──
// Uniform block describing the tape and the batch; the kernel reads it as `tape_meta`.
struct TapeMeta {
// Number of tape entries; exclusive upper bound of the kernel's tape walk.
num_ops: u32,
// Inputs per batch element. They occupy the first `num_inputs` slots of each
// per-thread values section, and the tape walk starts at this index.
num_inputs: u32,
// Slots per batch element in the values buffer (the per-thread stride).
num_variables: u32,
// Number of values written to the outputs buffer per batch element.
num_outputs: u32,
// Number of batch elements; invocations with an id at or beyond this exit immediately.
batch_size: u32,
// Padding words, never read by the shader. Presumably they round the struct up to
// 32 bytes to match the host-side layout — TODO confirm against the Rust definition.
_pad0: u32,
_pad1: u32,
_pad2: u32,
}
// Opcode of each tape entry, indexed by entry position.
@group(0) @binding(0) var<storage, read> opcodes: array<u32>;
// First operand of each entry: a slot index relative to the thread's values section.
@group(0) @binding(1) var<storage, read> arg0: array<u32>;
// Second operand of each entry: a slot index, UNUSED, or (for POWI) the i32 exponent bits.
@group(0) @binding(2) var<storage, read> arg1: array<u32>;
// Initial value for every slot; copied into each thread's values section before the walk.
@group(0) @binding(3) var<storage, read> constants: array<f32>;
// Tape and batch dimensions.
@group(0) @binding(4) var<uniform> tape_meta: TapeMeta;
// For each output j, the slot index whose final value is written to the outputs buffer.
@group(0) @binding(5) var<storage, read> output_indices: array<u32>;
// ── I/O buffers (bind group 1) ──
// Batch-major inputs: element `batch_id * num_inputs + i`.
@group(1) @binding(0) var<storage, read> inputs: array<f32>;
// Scratch space: one section of `num_variables` floats per batch element,
// starting at `batch_id * num_variables`.
@group(1) @binding(1) var<storage, read_write> values: array<f32>;
// Batch-major outputs: element `batch_id * num_outputs + j`.
@group(1) @binding(2) var<storage, read_write> outputs: array<f32>;
// ── Helper math functions: operations WGSL lacks, plus variants written for precision or CPU parity ──
// Real cube root for any sign of `x`.
//
// `pow` is only evaluated on the magnitude; the sign of the input is
// re-applied afterwards: cbrt(x) = sign(x) * |x|^(1/3).
fn cbrt_f32(x: f32) -> f32 {
    let magnitude_root = pow(abs(x), 1.0 / 3.0);
    return sign(x) * magnitude_root;
}
// exp(x) - 1 without the cancellation that the direct formula suffers near zero.
//
// `exp(x) - 1.0` subtracts two nearly equal numbers when |x| is small: the
// rounding error of exp(x) (about 1e-7 absolute) is then large relative to the
// result, e.g. only ~3 correct digits around x = 2e-4. The series branch
// therefore covers the whole range |x| < 0.25, not just the immediate
// neighbourhood of zero.
fn expm1_f32(x: f32) -> f32 {
    if abs(x) < 0.25 {
        // Maclaurin series x + x²/2! + … + x⁶/6! in Horner form. The first
        // dropped term is x⁷/7!, i.e. a relative error below 0.25⁶/5040 ≈ 4.9e-8,
        // which is under f32 rounding precision.
        var tail = 1.0 + x / 6.0;
        tail = 1.0 + (x / 5.0) * tail;
        tail = 1.0 + (x / 4.0) * tail;
        tail = 1.0 + (x / 3.0) * tail;
        tail = 1.0 + (x / 2.0) * tail;
        return x * tail;
    }
    // Away from zero the result is not small compared with exp(x), so the
    // subtraction is benign.
    return exp(x) - 1.0;
}
// ln(1 + x) without the precision loss of forming `1.0 + x` for small |x|.
//
// Adding a small x to 1.0 discards the low bits of x (absolute error ~6e-8),
// which is a large relative error in the result when |x| is small — about 3
// correct digits around x = 2e-4. For |x| < 0.25 the function instead uses
//     ln(1 + x) = 2 * atanh(s),  s = x / (2 + x),
// whose odd power series in s converges quickly and never forms 1 + x.
fn ln1p_f32(x: f32) -> f32 {
    if abs(x) < 0.25 {
        let s = x / (2.0 + x);
        let s2 = s * s;
        // atanh(s) = s * (1 + s²/3 + s⁴/5 + s⁶/7 + …). Here |s| <= 1/7, so the
        // first dropped term s⁸/9 is below 2e-8 relative.
        let series = 1.0 + s2 * (1.0 / 3.0 + s2 * (1.0 / 5.0 + s2 * (1.0 / 7.0)));
        return 2.0 * s * series;
    }
    // Outside the small range the direct formula is accurate; x <= -1 yields the
    // builtin's usual out-of-domain result.
    return log(1.0 + x);
}
// Hyperbolic sine.
//
// The textbook form (exp(x) - exp(-x)) / 2 subtracts two values close to 1 when
// |x| is small: the difference keeps only a few significant digits and collapses
// to exactly 0 once |x| drops below f32 epsilon. Small arguments therefore use
// the power series, matching the small-|x| handling of expm1_f32 and ln1p_f32.
fn sinh_f32(x: f32) -> f32 {
    if abs(x) < 0.25 {
        let x2 = x * x;
        // x + x³/3! + x⁵/5! + x⁷/7!; the next term (x⁹/9!) is far below f32
        // precision on this range.
        return x * (1.0 + x2 * (1.0 / 6.0 + x2 * (1.0 / 120.0 + x2 * (1.0 / 5040.0))));
    }
    return (exp(x) - exp(-x)) * 0.5;
}
// Hyperbolic cosine as the mean of the growing and decaying exponentials.
fn cosh_f32(x: f32) -> f32 {
    let growing = exp(x);
    let decaying = exp(-x);
    return (growing + decaying) * 0.5;
}
// Inverse hyperbolic sine, evaluated on |x| and mirrored for negative inputs
// (asinh is odd). Working on the magnitude avoids the cancellation in
// log(x + sqrt(x²+1)) for large negative x, where the argument approaches 0.
//
// Three ranges are distinguished:
//   * |x| < 0.125: power series. The direct formula forms 1 + |x| inside the log
//     and loses the low bits of |x| (it returns exactly 0 for |x| below f32
//     epsilon).
//   * |x| > 1e8: asinh(a) = ln(2a). In f32, sqrt(a² + 1) == a here, and a * a
//     overflows to infinity for a above ~1.8e19 although asinh stays finite.
//   * otherwise: the direct formula.
fn asinh_f32(x: f32) -> f32 {
    let a = abs(x);
    var r: f32;
    if a < 0.125 {
        let a2 = a * a;
        // a - a³/6 + 3a⁵/40 - 5a⁷/112; the next term (35a⁹/1152) is below 2e-9
        // relative on this range.
        r = a * (1.0 + a2 * (-1.0 / 6.0 + a2 * (3.0 / 40.0 + a2 * (-5.0 / 112.0))));
    } else if a > 1.0e8 {
        // ln(2a) = ln(a) + ln(2)
        r = log(a) + 0.6931472;
    } else {
        r = log(a + sqrt(a * a + 1.0));
    }
    // NaN fails every comparison above, reaches the direct formula and propagates.
    return select(-r, r, x >= 0.0);
}
// Inverse hyperbolic cosine: acosh(x) = ln(x + sqrt(x² - 1)).
//
// The radicand is computed in the factored form (x - 1)(x + 1) rather than
// x*x - 1: for x = 1 + ε the squared form rounds away the ε² contribution in
// f32, while the product keeps it. Matches the kernels::acosh_deriv convention.
fn acosh_f32(x: f32) -> f32 {
    let radicand = (x - 1.0) * (x + 1.0);
    return log(x + sqrt(radicand));
}
// Inverse hyperbolic tangent: atanh(x) = 0.5 * ln((1 + x) / (1 - x)).
//
// For small |x| the quotient (1 + x) / (1 - x) is close to 1, so forming it
// discards the low bits of x and the logarithm amplifies that loss; below f32
// epsilon the direct formula returns exactly 0. Small arguments use the power
// series instead.
fn atanh_f32(x: f32) -> f32 {
    if abs(x) < 0.125 {
        let x2 = x * x;
        // x + x³/3 + x⁵/5 + x⁷/7; the first dropped term (x⁹/9) is below 7e-9
        // relative on this range.
        return x * (1.0 + x2 * (1.0 / 3.0 + x2 * (1.0 / 5.0 + x2 * (1.0 / 7.0))));
    }
    return 0.5 * log((1.0 + x) / (1.0 - x));
}
// sqrt(a² + b²) computed without squaring the larger operand, so large inputs
// do not overflow prematurely.
fn hypot_f32(a: f32, b: f32) -> f32 {
    let pos_inf = bitcast<f32>(0x7f800000u);
    let mag_a = abs(a);
    let mag_b = abs(b);
    // IEEE: hypot(±Inf, anything) is +Inf, even when the other operand is NaN.
    // Without this early exit the scaled form below would evaluate Inf / Inf = NaN
    // for two infinite operands, diverging from CPU f64::hypot and the CUDA
    // `hypot` builtin.
    if mag_a == pos_inf || mag_b == pos_inf {
        return pos_inf;
    }
    let larger = max(mag_a, mag_b);
    let smaller = min(mag_a, mag_b);
    // Both operands are zero: avoid 0 / 0 in the ratio.
    if larger == 0.0 {
        return 0.0;
    }
    // Scale by the larger magnitude: the ratio is at most 1, so squaring it is safe.
    let ratio = smaller / larger;
    return larger * sqrt(1.0 + ratio * ratio);
}
// Truncated remainder, mirroring Rust's `%` on floats: the quotient is rounded
// toward zero, so a non-zero result carries the sign of the dividend `a`.
fn rem_f32(a: f32, b: f32) -> f32 {
    let whole_quotient = trunc(a / b);
    return a - whole_quotient * b;
}
// Multiplicative inverse, 1 / x.
fn recip_f32(x: f32) -> f32 {
    let reciprocal = 1.0 / x;
    return reciprocal;
}
// Base-10 logarithm via change of base: log10(x) = ln(x) / ln(10).
fn log10_f32(x: f32) -> f32 {
    let natural_log = log(x);
    return natural_log / log(10.0);
}
// Sign of `x` with the semantics of Rust's f32::signum: NaN stays NaN, and every
// other value (both zeros included) maps to +1.0 or -1.0 according to its sign bit.
fn signum_f32(x: f32) -> f32 {
    // NaN is the only value that compares unequal to itself; pass it through.
    let is_nan = x != x;
    if is_nan {
        return x;
    }
    // Inspect the raw sign bit so that -0.0 is reported as negative.
    let sign_bit_set = (bitcast<u32>(x) >> 31u) == 1u;
    return select(1.0, -1.0, sign_bit_set);
}
// Integer power base^n, where `exp_bits` carries the i32 exponent reinterpreted
// as u32.
//
// WGSL's `pow(x, y)` is specified in terms of exp2(y * log2(x)) and is therefore
// not usable for a negative base, even though an integer power of a negative
// number is well defined. The magnitude is computed on |base| and the sign is
// restored for odd exponents, so e.g. (-2)^3 = -8 and (-2)^2 = 4 as on the CPU.
fn powi_f32(base: f32, exp_bits: u32) -> f32 {
    let n = bitcast<i32>(exp_bits);
    // x^0 is 1 for every x; pow(0, 0) would go through 0 * log2(0) and is not
    // reliable.
    if n == 0 {
        return 1.0;
    }
    let magnitude = pow(abs(base), f32(n));
    // In two's complement the lowest bit gives the parity of negative exponents too.
    let exponent_is_odd = (exp_bits & 1u) == 1u;
    // Test the raw sign bit so that -0.0 keeps its sign (e.g. (-0)^-1 = -inf).
    let base_is_negative = (bitcast<u32>(base) & 0x80000000u) != 0u;
    if exponent_is_odd && base_is_negative {
        return -magnitude;
    }
    return magnitude;
}
// ── Main kernel ──
// Entry point: evaluates the tape once per batch element.
//
// Invocation `gid.x` owns the values section starting at `gid.x * num_variables`.
// It seeds that section from `constants`, overwrites the first `num_inputs` slots
// with its own inputs, walks tape entries [num_inputs, num_ops) in order writing
// each result to the slot with the entry's index, and finally copies the slots
// named by `output_indices` into its section of `outputs`.
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let batch_id = gid.x;
    // Guard: skip threads beyond the batch size (last workgroup padding).
    if batch_id >= tape_meta.batch_size {
        return;
    }

    let num_vars = tape_meta.num_variables;
    let num_in = tape_meta.num_inputs;
    let num_ops = tape_meta.num_ops;
    let n_out = tape_meta.num_outputs;

    // Base offset into the per-thread values section.
    let base = batch_id * num_vars;

    // Initialize values: copy constants, then overwrite input slots from inputs buffer.
    for (var i = 0u; i < num_vars; i = i + 1u) {
        values[base + i] = constants[i];
    }

    // Overwrite input slots with this batch element's inputs.
    let input_base = batch_id * num_in;
    for (var i = 0u; i < num_in; i = i + 1u) {
        values[base + i] = inputs[input_base + i];
    }

    // Walk the tape. Entries before `num_in` are the input slots filled above.
    for (var i = num_in; i < num_ops; i = i + 1u) {
        let op = opcodes[i];
        // Skip Const entries — already initialized from constants buffer.
        if op == OP_CONST {
            continue;
        }

        let a_idx = arg0[i];
        let b_idx = arg1[i];
        let a = values[base + a_idx];
        var b = 0.0f;
        // Load the second operand only when the slot holds a real index. For POWI
        // the slot carries the exponent bits, which must not be used to address
        // the values buffer (a negative exponent would wrap to a huge index).
        if b_idx != UNUSED && op != OP_POWI {
            b = values[base + b_idx];
        }

        var r = 0.0f;
        // Case selectors are the OP_* discriminants written as literals.
        switch op {
            case 2u /* ADD */: { r = a + b; }
            case 3u /* SUB */: { r = a - b; }
            case 4u /* MUL */: { r = a * b; }
            case 5u /* DIV */: { r = a / b; }
            case 6u /* REM */: { r = rem_f32(a, b); }
            // NOTE(review): WGSL `pow` is specified via exp2(y * log2(x)), so a
            // negative base likely yields NaN even for integral exponents —
            // confirm whether CPU parity is required for POWF.
            case 7u /* POWF */: { r = pow(a, b); }
            case 8u /* ATAN2 */: { r = atan2(a, b); }
            case 9u /* HYPOT */: { r = hypot_f32(a, b); }
            case 10u /* MAX */: { r = max(a, b); }
            case 11u /* MIN */: { r = min(a, b); }
            case 12u /* NEG */: { r = -a; }
            case 13u /* RECIP */: { r = recip_f32(a); }
            case 14u /* SQRT */: { r = sqrt(a); }
            case 15u /* CBRT */: { r = cbrt_f32(a); }
            // POWI passes the raw second slot: it holds the exponent bits.
            case 16u /* POWI */: { r = powi_f32(a, b_idx); }
            case 17u /* EXP */: { r = exp(a); }
            case 18u /* EXP2 */: { r = exp2(a); }
            case 19u /* EXPM1 */: { r = expm1_f32(a); }
            case 20u /* LN */: { r = log(a); }
            case 21u /* LOG2 */: { r = log2(a); }
            case 22u /* LOG10 */: { r = log10_f32(a); }
            case 23u /* LN1P */: { r = ln1p_f32(a); }
            case 24u /* SIN */: { r = sin(a); }
            case 25u /* COS */: { r = cos(a); }
            case 26u /* TAN */: { r = tan(a); }
            case 27u /* ASIN */: { r = asin(a); }
            case 28u /* ACOS */: { r = acos(a); }
            case 29u /* ATAN */: { r = atan(a); }
            case 30u /* SINH */: { r = sinh_f32(a); }
            case 31u /* COSH */: { r = cosh_f32(a); }
            case 32u /* TANH */: { r = tanh(a); }
            case 33u /* ASINH */: { r = asinh_f32(a); }
            case 34u /* ACOSH */: { r = acosh_f32(a); }
            case 35u /* ATANH */: { r = atanh_f32(a); }
            case 36u /* ABS */: { r = abs(a); }
            case 37u /* SIGNUM */: { r = signum_f32(a); }
            case 38u /* FLOOR */: { r = floor(a); }
            case 39u /* CEIL */: { r = ceil(a); }
            // WGSL `round` resolves ties to the even integer; CPU `f32::round()`
            // rounds ties away from zero (0.5 -> 1, 2.5 -> 3). `a - trunc(a)` is
            // exact, so comparing it with 0.5 detects ties without rounding error.
            // NaN and ±Inf fail the comparison and pass through as trunc(a).
            case 40u /* ROUND */: {
                let whole = trunc(a);
                r = select(whole, whole + sign(a), abs(a - whole) >= 0.5);
            }
            case 41u /* TRUNC */: { r = trunc(a); }
            // WGSL `fract` is floor-based; CPU `f32::fract()` is truncation-based.
            case 42u /* FRACT */: { r = a - trunc(a); }
            // Unknown opcodes produce 0.0.
            default: { r = 0.0; }
        }
        values[base + i] = r;
    }

    // Write outputs.
    let out_base = batch_id * n_out;
    for (var j = 0u; j < n_out; j = j + 1u) {
        let oi = output_indices[j];
        outputs[out_base + j] = values[base + oi];
    }
}