echidna 0.9.0

A high-performance automatic differentiation library for Rust
Documentation
// Batched reverse (adjoint) sweep of a BytecodeTape on GPU.
//
// One compute thread per batch element. Each thread walks the tape in reverse,
// accumulating adjoints. The values buffer must already be populated by the
// forward shader.

// ── OpCode constants (must match OpCode repr(u8) discriminants) ──
//
// Only OP_INPUT, OP_CONST and OP_POWI are referenced by name in this kernel;
// the `switch` in `main` selects on the raw numeric literals (annotated with
// the opcode name in a comment), so any renumbering here must be mirrored in
// those `case` selectors as well.

// Structural ops: no operands, skipped by the reverse sweep.
const OP_INPUT:  u32 = 0u;
const OP_CONST:  u32 = 1u;
// Binary ops: operands in arg0 / arg1.
const OP_ADD:    u32 = 2u;
const OP_SUB:    u32 = 3u;
const OP_MUL:    u32 = 4u;
const OP_DIV:    u32 = 5u;
const OP_REM:    u32 = 6u;
const OP_POWF:   u32 = 7u;
const OP_ATAN2:  u32 = 8u;
const OP_HYPOT:  u32 = 9u;
const OP_MAX:    u32 = 10u;
const OP_MIN:    u32 = 11u;
// Unary ops: operand in arg0. OP_POWI additionally stores its integer
// exponent (as i32 bits) in arg1.
const OP_NEG:    u32 = 12u;
const OP_RECIP:  u32 = 13u;
const OP_SQRT:   u32 = 14u;
const OP_CBRT:   u32 = 15u;
const OP_POWI:   u32 = 16u;
const OP_EXP:    u32 = 17u;
const OP_EXP2:   u32 = 18u;
const OP_EXPM1:  u32 = 19u;
const OP_LN:     u32 = 20u;
const OP_LOG2:   u32 = 21u;
const OP_LOG10:  u32 = 22u;
const OP_LN1P:   u32 = 23u;
const OP_SIN:    u32 = 24u;
const OP_COS:    u32 = 25u;
const OP_TAN:    u32 = 26u;
const OP_ASIN:   u32 = 27u;
const OP_ACOS:   u32 = 28u;
const OP_ATAN:   u32 = 29u;
const OP_SINH:   u32 = 30u;
const OP_COSH:   u32 = 31u;
const OP_TANH:   u32 = 32u;
const OP_ASINH:  u32 = 33u;
const OP_ACOSH:  u32 = 34u;
const OP_ATANH:  u32 = 35u;
const OP_ABS:    u32 = 36u;
const OP_SIGNUM: u32 = 37u;
const OP_FLOOR:  u32 = 38u;
const OP_CEIL:   u32 = 39u;
const OP_ROUND:  u32 = 40u;
const OP_TRUNC:  u32 = 41u;
const OP_FRACT:  u32 = 42u;

// Sentinel stored in arg1 when an op has no second operand; the accumulate
// step in `main` skips the second-operand update when it sees this value.
const UNUSED: u32 = 0xFFFFFFFFu;

// ── Tape data (bind group 0) ──
// Uniform block describing the tape and the batch being processed.
struct TapeMeta {
    // Number of tape entries; the reverse sweep visits indices num_ops-1 .. 0.
    num_ops: u32,
    // Number of gradient components copied to grad_out per batch element.
    num_inputs: u32,
    // Per-batch-element stride of the values and adjoints buffers.
    num_variables: u32,
    // Not used by this kernel's computation: only output_indices[0] is seeded.
    num_outputs: u32,
    // Number of batch elements; invocations with gid.x >= batch_size exit early.
    batch_size: u32,
    // NOTE(review): presumably padding so the struct size matches the
    // host-side uniform layout — confirm against the Rust-side definition.
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
}

// Opcode of each tape entry, one u32 per op (compared against the OP_* values).
@group(0) @binding(0) var<storage, read> opcodes: array<u32>;
// First-operand variable index of each tape entry.
@group(0) @binding(1) var<storage, read> arg0: array<u32>;
// Second-operand variable index of each tape entry; UNUSED for ops without a
// second operand, and the i32 exponent bit pattern for POWI.
@group(0) @binding(2) var<storage, read> arg1: array<u32>;
// Declared as part of the bind group layout; not read by this kernel.
@group(0) @binding(3) var<storage, read> constants: array<f32>;
// Tape / batch dimensions (see TapeMeta).
@group(0) @binding(4) var<uniform> tape_meta: TapeMeta;
// Variable indices of the tape outputs; only element 0 is used as the seed.
@group(0) @binding(5) var<storage, read> output_indices: array<u32>;

// ── I/O buffers (bind group 1) ──
// Each batch element owns a contiguous slice of every buffer below, starting
// at batch_id * num_variables (values, adjoints) or batch_id * num_inputs
// (grad_out).
// binding 0: values [B * num_variables] (from forward pass, read-only here)
@group(1) @binding(0) var<storage, read> values: array<f32>;
// binding 1: adjoints [B * num_variables] (working memory, read-write).
// The kernel zeroes its own slice before seeding, so the buffer contents on
// entry do not matter.
@group(1) @binding(1) var<storage, read_write> adjoints: array<f32>;
// binding 2: grad_out [B * num_inputs] (output gradients); receives the first
// num_inputs adjoints of each batch element after the sweep.
@group(1) @binding(2) var<storage, read_write> grad_out: array<f32>;

// ── Main kernel ──

// Reverse (adjoint) sweep for one batch element per invocation.
//
// Steps, all confined to this invocation's slice of the buffers:
//   1. zero the adjoint slice,
//   2. seed the adjoint of the first tape output with 1.0,
//   3. walk the tape from the last op to the first, turning each non-zero
//      adjoint into contributions to its operands' adjoints,
//   4. copy the first `num_inputs` adjoints into `grad_out`.
//
// The `values` buffer must already hold the forward-pass results. Only
// single-output differentiation is performed: `output_indices[0]` is the seed.
//
// The `case` selectors below are raw literals that mirror the OP_* constants;
// keep the two in sync.
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let batch_id = gid.x;
    if batch_id >= tape_meta.batch_size {
        // Surplus invocation in the last workgroup.
        return;
    }

    let num_vars = tape_meta.num_variables;
    let num_ops = tape_meta.num_ops;
    let num_in = tape_meta.num_inputs;

    // Base offsets into per-thread buffers.
    let v_base = batch_id * num_vars; // values base
    let a_base = batch_id * num_vars; // adjoints base

    // Initialize adjoints to zero.
    for (var i = 0u; i < num_vars; i = i + 1u) {
        adjoints[a_base + i] = 0.0;
    }

    // Seed: set adjoint of the first (only) output to 1.0.
    // An out-of-range index would land in another batch element's slice (or
    // out of bounds), so it is ignored; the gradient then stays zero.
    let seed_idx = output_indices[0];
    if seed_idx < num_vars {
        adjoints[a_base + seed_idx] = 1.0;
    }

    // Reverse sweep.
    for (var ii = 0u; ii < num_ops; ii = ii + 1u) {
        let i = num_ops - 1u - ii;
        let adj = adjoints[a_base + i];

        // Nothing to propagate from a zero adjoint.
        if adj == 0.0 {
            continue;
        }

        // Structural ops have no operands; their adjoints are final and must
        // be left in place (input adjoints are the gradient).
        let op = opcodes[i];
        if op == OP_INPUT || op == OP_CONST {
            continue;
        }

        // Clear this adjoint (it's been consumed).
        adjoints[a_base + i] = 0.0;

        let a_idx = arg0[i];
        let b_idx = arg1[i];
        let a = values[v_base + a_idx];
        let r = values[v_base + i];

        // Reverse partials d(r)/d(a) and d(r)/d(b).
        var da = 0.0f;
        var db = 0.0f;

        switch op {
            // Binary
            case 2u /* ADD */: { da = 1.0; db = 1.0; }
            case 3u /* SUB */: { da = 1.0; db = -1.0; }
            case 4u /* MUL */: {
                let b = values[v_base + b_idx];
                da = b; db = a;
            }
            case 5u /* DIV */: {
                let b = values[v_base + b_idx];
                let inv = 1.0 / b;
                // `db = -a/b² = -r/b = -r * inv` — factoring through the
                // primal `r = a/b` avoids forming `inv*inv`, which
                // overflows when |b| is small.
                da = inv;
                db = -r * inv;
            }
            case 6u /* REM */: {
                let b = values[v_base + b_idx];
                da = 1.0;
                db = -trunc(a / b);
            }
            case 7u /* POWF */: {
                let b = values[v_base + b_idx];
                da = b * pow(a, b - 1.0);
                // For a <= 0, `log(a)` is NaN. A finite `r` at a < 0 means b
                // was integer; the classical derivative w.r.t. b at integer b
                // is undefined, and the convention here is 0 — matches the
                // CPU `OpCode::Powf` safety net.
                db = select(r * log(a), 0.0, r == 0.0 || a <= 0.0);
            }
            case 8u /* ATAN2 */: {
                let b = values[v_base + b_idx];
                // WGSL lacks `hypot`, so we can't form `h = sqrt(a² + b²)`
                // directly — `a*a` overflows in f32 for |a| > ~1.8e19 even
                // when the partial `b/(a²+b²)` is representable. Normalize by
                // max(|a|, |b|) so the sum-of-squares is bounded by 2:
                //   a² + b² = mx² · (au² + bu²)   with  au = a/mx, bu = b/mx
                //   b/(a²+b²) = bu / (mx · (au² + bu²))
                let mx = max(abs(a), abs(b));
                if mx == 0.0 {
                    da = 0.0; db = 0.0;
                } else {
                    let au = a / mx;
                    let bu = b / mx;
                    let denom = mx * (au * au + bu * bu);
                    da = bu / denom;
                    db = -au / denom;
                }
            }
            case 9u /* HYPOT */: {
                let b = values[v_base + b_idx];
                // At the origin the partials are 0/0; use 0 by convention.
                if r == 0.0 {
                    da = 0.0; db = 0.0;
                } else {
                    da = a / r;
                    db = b / r;
                }
            }
            case 10u /* MAX */: {
                let b = values[v_base + b_idx];
                // Route adjoint to the operand whose value equals the forward
                // output `r = max(a,b)`. Direct bit-equality handles NaN
                // correctly and survives optimizers that fold `b != b` away.
                if bitcast<u32>(r) == bitcast<u32>(a) {
                    da = 1.0; db = 0.0;
                } else {
                    da = 0.0; db = 1.0;
                }
            }
            case 11u /* MIN */: {
                let b = values[v_base + b_idx];
                // Same bit-equality routing as MAX, against `r = min(a,b)`.
                if bitcast<u32>(r) == bitcast<u32>(a) {
                    da = 1.0; db = 0.0;
                } else {
                    da = 0.0; db = 1.0;
                }
            }

            // Unary
            case 12u /* NEG */: { da = -1.0; }
            case 13u /* RECIP */: { let inv = 1.0 / a; da = -inv * inv; }
            case 14u /* SQRT */: { da = 0.5 / r; }
            case 15u /* CBRT */: { da = 1.0 / (3.0 * r * r); }
            case 16u /* POWI */: {
                // arg1 carries the i32 exponent, not a variable index.
                let n_int = bitcast<i32>(b_idx);
                if n_int == 0 {
                    // d/da a⁰ = 0.
                    da = 0.0;
                } else if n_int == 1 {
                    // d/da a¹ = 1 for every a; avoids evaluating pow(a, 0),
                    // which is not reliably 1 at a = 0 or a < 0.
                    da = 1.0;
                } else {
                    // `pow` is only dependable for a non-negative base, so
                    // evaluate n·|a|^(n-1) and restore the sign by hand:
                    // a^(n-1) is negative exactly when a < 0 and n-1 is odd,
                    // i.e. when n is even.
                    let n = f32(n_int);
                    let mag = n * pow(abs(a), n - 1.0);
                    da = select(mag, -mag, (a < 0.0) && ((n_int & 1) == 0));
                }
            }

            // Exp/Log
            case 17u /* EXP */: { da = r; }
            case 18u /* EXP2 */: { da = r * log(2.0); }
            case 19u /* EXPM1 */: { da = r + 1.0; }
            case 20u /* LN */: { da = 1.0 / a; }
            case 21u /* LOG2 */: { da = 1.0 / (a * log(2.0)); }
            case 22u /* LOG10 */: { da = 1.0 / (a * log(10.0)); }
            case 23u /* LN1P */: { da = 1.0 / (1.0 + a); }

            // Trig
            case 24u /* SIN */: { da = cos(a); }
            case 25u /* COS */: { da = -sin(a); }
            case 26u /* TAN */: { let c = cos(a); da = 1.0 / (c * c); }
            case 27u /* ASIN */: { da = 1.0 / sqrt((1.0 - a) * (1.0 + a)); }
            case 28u /* ACOS */: { da = -1.0 / sqrt((1.0 - a) * (1.0 + a)); }
            case 29u /* ATAN */: {
                // `a*a` overflows f32 once |a| exceeds ~1.8e19. The switch to
                // the reciprocal form `inv²/(1+inv²)` happens at a
                // conservative 1e8, where the `1 +` is already lost to
                // rounding and both forms agree.
                let aa = abs(a);
                if aa > 1e8 {
                    let inv = 1.0 / a;
                    da = inv * inv / (1.0 + inv * inv);
                } else {
                    da = 1.0 / (1.0 + a * a);
                }
            }

            // Hyperbolic
            case 30u /* SINH */: { da = cosh(a); }
            case 31u /* COSH */: { da = sinh(a); }
            case 32u /* TANH */: { let c = cosh(a); da = 1.0 / (c * c); }
            // ASINH / ACOSH: `a*a` overflows f32 for |a| > ~1.8e19, so large
            // arguments use |1/a|/sqrt(1 ± 1/a²), which stays representable.
            // The 1e8 cut-over is conservative. Mirrors the CPU
            // `OpCode::Asinh`/`OpCode::Acosh` overflow guard and the
            // `OpCode::Atan` pattern above.
            case 33u /* ASINH */: {
                if abs(a) > 1e8 {
                    let inv = 1.0 / a;
                    da = abs(inv) / sqrt(1.0 + inv * inv);
                } else {
                    da = 1.0 / sqrt(a * a + 1.0);
                }
            }
            case 34u /* ACOSH */: {
                if abs(a) > 1e8 {
                    let inv = 1.0 / a;
                    da = abs(inv) / sqrt(1.0 - inv * inv);
                } else {
                    // Factored form (a-1)(a+1) avoids catastrophic
                    // cancellation near a=1; matches kernels::acosh_deriv.
                    da = 1.0 / sqrt((a - 1.0) * (a + 1.0));
                }
            }
            case 35u /* ATANH */: { da = 1.0 / ((1.0 - a) * (1.0 + a)); }

            // Misc
            case 36u /* ABS */: {
                // Work on the bit pattern throughout:
                //  * NaN is detected as "exponent all ones, mantissa non-zero"
                //    rather than `a != a`, which optimizers may fold away
                //    (same concern as in MAX/MIN);
                //  * the sign comes from bit 31, so -0.0 maps to -1 like
                //    Rust's `f32::signum` (`a >= 0.0` would map it to +1).
                let bits = bitcast<u32>(a);
                if (bits & 0x7FFFFFFFu) > 0x7F800000u {
                    da = a; // NaN propagates
                } else {
                    da = select(1.0, -1.0, (bits & 0x80000000u) != 0u);
                }
            }
            // Piecewise-constant functions: zero derivative.
            case 37u, 38u, 39u, 40u, 41u /* SIGNUM..TRUNC */: { da = 0.0; }
            case 42u /* FRACT */: { da = 1.0; }

            // Unknown opcode: da = db = 0, so the adjoint is dropped.
            default: {}
        }

        // Accumulate into the operands' adjoints. POWI's arg1 is an exponent,
        // not a variable index, so it never receives a contribution.
        adjoints[a_base + a_idx] = adjoints[a_base + a_idx] + da * adj;
        if b_idx != UNUSED && op != OP_POWI {
            adjoints[a_base + b_idx] = adjoints[a_base + b_idx] + db * adj;
        }
    }

    // Write gradients: input adjoints → grad_out.
    let g_base = batch_id * num_in;
    for (var i = 0u; i < num_in; i = i + 1u) {
        grad_out[g_base + i] = adjoints[a_base + i];
    }
}

// Manual implementations for WGSL builtins not available.
// Hyperbolic sine of `x`.
//
// The textbook form (eˣ − e⁻ˣ)/2 subtracts two nearly equal numbers when |x|
// is small and loses most of its significant digits (it returns exactly 0 once
// eˣ rounds to 1). For |x| < 0.25 the odd Taylor polynomial
// x + x³/6 + x⁵/120 is used instead; its truncation error (x⁷/5040) is below
// f32 rounding error on that interval. NaN fails the comparison and falls
// through to the exponential form, which propagates it.
fn sinh(x: f32) -> f32 {
    let x2 = x * x;
    if x2 < 0.0625 {
        return x * (1.0 + (x2 / 6.0) * (1.0 + x2 / 20.0));
    }
    return (exp(x) - exp(-x)) * 0.5;
}

// Hyperbolic cosine of `x`, evaluated as the mean of eˣ and e⁻ˣ.
fn cosh(x: f32) -> f32 {
    let grow = exp(x);
    let decay = exp(-x);
    return 0.5 * (grow + decay);
}