// rlx-wgpu 0.2.1
// Cross-platform GPU backend for RLX via wgpu (Metal/Vulkan/DX12/WebGPU).
//
// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

// PLAN L2 — interpreted N-ary element-wise chain kernel.
//
// One thread per output element. Each thread walks the chain encoding
// (compile-time fixed-size array, runtime length via num_steps) and
// computes intermediate values into a private scratch register array.
// The final step's result is written to dst.
//
// Encoding (4 u32s per step):
//   chain[k*4 + 0] = op_kind  (0=Activation, 1=Cast, 2=Binary, 3=Compare,
//                              4=Where)
//   chain[k*4 + 1] = op_sub   (discriminant of the inner op variant; for
//                              Where, the encoded condition operand in the
//                              same format as lhs_enc)
//   chain[k*4 + 2] = lhs_enc  (bit 31 = src kind: 0=Input, 1=Step;
//                              bits 0..30 = index)
//   chain[k*4 + 3] = rhs_enc  (same; ignored for unary ops)
//
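// Per-input data is read from `arena[input_offs[idx] + row]`, where `row` is
// 0 for scalar-broadcast inputs, `i % input_modulus[idx]` when the modulus is
// non-zero, and `i` otherwise.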
// Output is written at `arena[dst_off + i]`.

// Upper bound on the number of chain steps. The scratch arrays declared as
// `array<f32, 32>` and the `chain` field (`array<u32, 128>`, i.e.
// MAX_STEPS * 4) use literal sizes that mirror this value.
const MAX_STEPS: u32 = 32u;
// Upper bound on the number of inputs; sizes `input_offs` and
// `input_modulus` in `Params`.
const MAX_INPUTS: u32 = 16u;

// Per-dispatch parameter block describing the element-wise chain to run.
// Field order and sizes form the buffer layout, so they must stay in sync
// with whatever fills the buffer on the host side.
struct Params {
    // Number of output elements; invocations whose flat index is >= len exit
    // without writing.
    len: u32,
    // Not read by this shader.
    // NOTE(review): presumably the count of valid `input_offs` entries —
    // confirm against the host code.
    num_inputs: u32,
    // Number of steps encoded in `chain`; the loop bound of the interpreter.
    num_steps: u32,
    // Element offset into `arena` where the output starts.
    dst_off: u32,
    // Element offset into `arena` of each input, indexed by input index.
    input_offs: array<u32, MAX_INPUTS>,
    // Step encoding, four u32 words per step.
    chain: array<u32, 128>,   // MAX_STEPS * 4
    // Bit `idx` set => input `idx` is a scalar broadcast and always reads its
    // element 0.
    scalar_input_mask: u32,
    // Three padding words; with 4-byte array strides they place
    // `input_modulus` at byte offset 608, a multiple of 16.
    // NOTE(review): presumably mirrors the host-side struct layout — confirm.
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
    // Per-input tiling period: non-zero => the input is read at
    // `i % input_modulus[idx]`; zero => it is read at `i` directly.
    input_modulus: array<u32, MAX_INPUTS>,
};

// f32 buffer holding both the inputs and the output; every access goes
// through an element offset taken from `params`.
@group(0) @binding(0) var<storage, read_write> arena: array<f32>;
// Bound as a read-only storage buffer rather than a uniform buffer: the WGSL
// uniform address space requires a 16-byte stride for array elements, which
// does not fit the packed `array<u32, N>` fields of `Params`. The storage
// address space has no such stride requirement.
@group(0) @binding(1) var<storage, read>        params: Params;

// Fetches the value of one encoded operand for output element `i`.
//
//   enc     - operand encoding: bit 31 selects the source (0 = input buffer,
//             1 = result of an earlier step), bits 0..30 hold the index.
//   i       - flat index of the output element this invocation computes.
//   scratch - this invocation's array of already-computed step results.
//
// Returns the operand as f32.
fn resolve_operand(enc: u32, i: u32, scratch: ptr<function, array<f32, 32>>) -> f32 {
    let index = enc & 0x7FFFFFFFu;

    // Bit 31 set: the operand is the output of a prior step.
    if ((enc >> 31u) != 0u) {
        return (*scratch)[index];
    }

    // Otherwise the operand lives in the arena at this input's base offset.
    let base = params.input_offs[index];

    // Scalar-broadcast input: every element reads slot 0 (fast path).
    let is_scalar = (params.scalar_input_mask & (1u << index)) != 0u;
    if (is_scalar) {
        return arena[base + 0u];
    }

    // Trailing-shape broadcast tiles the input with period `modulus`;
    // a modulus of zero means the input is indexed by `i` directly.
    let modulus = params.input_modulus[index];
    if (modulus != 0u) {
        return arena[base + (i % modulus)];
    }
    return arena[base + i];
}

// Evaluates the unary activation selected by `sub` on `x`.
//
// Discriminants: 0=Gelu, 1=GeluApprox, 2=Silu, 3=Relu, 4=Sigmoid, 5=Tanh,
// 6=Exp, 7=Log, 8=Sqrt, 9=Rsqrt, 10=Neg, 11=Abs, 13=Sin, 14=Cos, 15=Tan,
// 16=Atan. Any other value returns `x` unchanged.
// NOTE(review): discriminant 12 has no case here and therefore passes `x`
// through — confirm that is intended for whatever variant owns 12.
fn apply_activation(sub: u32, x: f32) -> f32 {
    // Identity is the fallback for every discriminant without a case.
    var y: f32 = x;
    switch (sub) {
        case 0u, 1u: {
            // Gelu and GeluApprox share the tanh approximation, evaluated in
            // its algebraically equal sigmoid form:
            //   0.5 · x · (1 + tanh(c·(x + 0.044715·x³)))
            //     = x · sigmoid(2·c·(x + 0.044715·x³))
            // The tanh spelling produced `0·∞ = NaN` through Apple Metal's
            // wgsl `tanh` for particular x values with huge intermediates
            // (reproducer: BERT MiniLM6 FFN1, about 1 NaN per ~5000 outputs;
            // the bisect is in the commit message). The sigmoid stays within
            // [0, 1], so the product with x cannot hit that trap.
            let two_c = 2.0f * 0.7978845608f;
            let arg = two_c * (x + 0.044715f * x * x * x);
            let gate = 1.0f / (1.0f + exp(-arg));
            y = x * gate;
        }
        case 2u: { y = x / (1.0f + exp(-x)); }       // Silu
        case 3u: { y = max(x, 0.0); }                // Relu
        case 4u: { y = 1.0f / (1.0f + exp(-x)); }    // Sigmoid
        case 5u: { y = tanh(x); }                    // Tanh
        case 6u: { y = exp(x); }                     // Exp
        case 7u: { y = log(x); }                     // Log
        case 8u: { y = sqrt(x); }                    // Sqrt
        case 9u: { y = 1.0f / sqrt(x); }             // Rsqrt
        case 10u: { y = -x; }                        // Neg
        case 11u: { y = abs(x); }                    // Abs
        case 13u: { y = sin(x); }                    // Sin
        case 14u: { y = cos(x); }                    // Cos
        case 15u: { y = tan(x); }                    // Tan
        case 16u: { y = atan(x); }                   // Atan
        default: { }
    }
    return y;
}

// Evaluates the binary arithmetic op selected by `sub` on `(a, b)`.
//
// Discriminants: 0=Add, 1=Sub, 2=Mul, 3=Div, 4=Max, 5=Min, 6=Pow.
// Any other value returns `a` unchanged.
fn apply_binary(sub: u32, a: f32, b: f32) -> f32 {
    if (sub == 0u) { return a + b; }     // Add
    if (sub == 1u) { return a - b; }     // Sub
    if (sub == 2u) { return a * b; }     // Mul
    if (sub == 3u) { return a / b; }     // Div
    if (sub == 4u) { return max(a, b); } // Max
    if (sub == 5u) { return min(a, b); } // Min
    if (sub == 6u) {                     // Pow
        // The shader-level `pow` is not defined for a negative base (its
        // accuracy is inherited from exp2(b * log2(a)), and the SPIR-V
        // GLSL.std.450 Pow result is undefined for x < 0), so e.g.
        // pow(-3, 2) can come back as NaN. When the exponent is integral
        // the result is well defined: raise |a| to b and restore the sign
        // for odd exponents. `floor(b) == b` is false for NaN, so NaN
        // exponents still take the plain path below.
        if (a < 0.0f && floor(b) == b) {
            let magnitude = pow(-a, b);
            let exponent_is_odd = (abs(b) % 2.0f) == 1.0f;
            return select(magnitude, -magnitude, exponent_is_odd);
        }
        return pow(a, b);
    }
    return a;
}

// Evaluates the comparison selected by `sub` on `(a, b)` and returns the
// outcome as a float mask: 1.0 when the relation holds, 0.0 otherwise.
//
// Discriminants: 0=Eq, 1=Ne, 2=Lt, 3=Le, 4=Gt, 5=Ge.
// Any other value yields 0.0.
fn apply_compare(sub: u32, a: f32, b: f32) -> f32 {
    var holds: bool = false;
    switch (sub) {
        case 0u: { holds = (a == b); } // Eq
        case 1u: { holds = (a != b); } // Ne
        case 2u: { holds = (a <  b); } // Lt
        case 3u: { holds = (a <= b); } // Le
        case 4u: { holds = (a >  b); } // Gt
        case 5u: { holds = (a >= b); } // Ge
        default: { holds = false; }
    }
    return select(0.0f, 1.0f, holds);
}

// Entry point: interprets the encoded op chain for one output element.
//
// Each invocation computes a flat element index `i`, evaluates every chain
// step in order (keeping each step's result in a private scratch array so
// later steps can reference it), and stores the last step's result at
// `arena[params.dst_off + i]`.
//
//   gid - global invocation id; x and y are combined into the flat index.
//   ngs - workgroup counts of the dispatch; ngs.x * 64 is the number of
//         elements covered by one row of workgroups along x.
@compute @workgroup_size(64)
fn elementwise_region(@builtin(global_invocation_id) gid: vec3<u32>,
                      @builtin(num_workgroups) ngs: vec3<u32>) {
    // Flatten the (x, y) dispatch into a single element index.
    let i = gid.x + gid.y * ngs.x * 64u;
    if (i >= params.len) { return; }

    // Never run more steps than the scratch array (32 entries) and the
    // chain encoding (32 * 4 words) can hold. Without the clamp an oversized
    // `num_steps` would index past both arrays and the final read below
    // would use an out-of-range `last_idx`. The bound is loop-invariant, so
    // it is computed once.
    let step_count = min(params.num_steps, MAX_STEPS);

    // Function-scope variables are zero-initialised, so with zero steps the
    // value written to the output is scratch[0] == 0.0.
    var scratch: array<f32, 32>;
    var last_idx: u32 = 0u;
    for (var k: u32 = 0u; k < step_count; k = k + 1u) {
        // Four words per step: op_kind, op_sub, lhs_enc, rhs_enc.
        let base = k * 4u;
        let op_kind = params.chain[base + 0u];
        let op_sub  = params.chain[base + 1u];
        let lhs_enc = params.chain[base + 2u];
        let rhs_enc = params.chain[base + 3u];

        // Every op kind consumes the lhs operand.
        let lhs = resolve_operand(lhs_enc, i, &scratch);
        var result: f32;
        if (op_kind == 0u) {
            result = apply_activation(op_sub, lhs);
        } else if (op_kind == 1u) {
            // Cast at f32-arena layer is identity.
            result = lhs;
        } else if (op_kind == 2u) {
            let rhs = resolve_operand(rhs_enc, i, &scratch);
            result = apply_binary(op_sub, lhs, rhs);
        } else if (op_kind == 3u) {
            let rhs = resolve_operand(rhs_enc, i, &scratch);
            result = apply_compare(op_sub, lhs, rhs);
        } else {
            // op_kind == 4u: Where (3-operand select). op_sub carries
            // cond_enc; lhs already resolved is on_true; rhs is on_false.
            // Any op_kind above 4 also lands here.
            let cond = resolve_operand(op_sub, i, &scratch);
            let on_false = resolve_operand(rhs_enc, i, &scratch);
            result = select(on_false, lhs, cond != 0.0f);
        }
        // Publish the result so later steps can reference it by step index.
        scratch[k] = result;
        last_idx = k;
    }
    arena[params.dst_off + i] = scratch[last_idx];
}