// rlx-wgpu 0.2.4
//
// Cross-platform GPU backend for RLX via wgpu (Metal/Vulkan/DX12/WebGPU).
// Documentation
// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

// LayerNorm backward over the last axis ("row" = last-axis slice).
//
// Three entry points share the same Params layout:
//   * `layer_norm_bwd_input`         — dx (one workgroup per row).
//   * `layer_norm_bwd_gamma_partial` — partial dgamma per row chunk,
//     written to scratch[wg_idx, d].
//   * `layer_norm_bwd_gamma_reduce`  — sum scratch[*, d] → dgamma[d].
//
// The input kernel and `layer_norm_bwd_gamma_partial` recompute per-row
// mean/inv_std from x (the forward pass does not save them), matching the
// CPU thunk in `rlx-cpu/src/thunk.rs::Thunk::LayerNormBackward*`. The
// reduce kernel only sums the per-chunk partials.

// Uniform parameters shared by all three entry points (8 x u32).
// The host-side uniform buffer must use exactly this field order.
// Every *_off field is an element (f32-word) offset into `arena`, not a
// byte offset: the kernels index `arena[x_off + row * inner + i]` directly.
struct Params {
    outer: u32,       // number of rows; `layer_norm_bwd_gamma_reduce`
                      // reuses it as the partial-chunk count
    inner: u32,       // size of last axis (H)
    x_off: u32,       // start of x, laid out row-major as [outer, inner]
    gamma_off: u32,   // start of gamma, `inner` entries
    dy_off: u32,      // start of dy, laid out row-major as [outer, inner]
    out_off: u32,     // dx or dgamma destination
    eps_bits: u32,    // f32 epsilon passed as raw bits (bitcast<f32> in kernels)
    scratch_off: u32, // word-offset (f32 index) into the tail scratch
                      // zone (gamma path only; ignored for input).
};

// One f32 storage buffer holds every tensor touched by these kernels
// (x, gamma, dy, the dx/dgamma output and the gamma scratch zone); each
// region is located by the word offsets carried in `params`.
@group(0) @binding(0) var<storage, read_write> arena: array<f32>;
@group(0) @binding(1) var<uniform>              params: Params;

// Staging buffer for `workgroup_sum`. Its length must equal the kernels'
// @workgroup_size(64); the reduction's first stride (32) is half of it.
var<workgroup> scratch_wg: array<f32, 64>;

// Tree-reduce `val` across all 64 lanes and return the total to every lane.
//
// `tid` is the calling lane's local index in [0, 64). Every lane of the
// workgroup must call this from workgroup-uniform control flow, because it
// executes `workgroupBarrier()`.
//
// The trailing barrier orders every lane's read of `scratch_wg[0]` before
// any lane's first write in a subsequent call. Without it, back-to-back
// calls (e.g. `m_sy` then `m_sxh` in the dx kernel) let a fast lane
// overwrite slot 0 while slower lanes have not yet read the previous total.
fn workgroup_sum(tid: u32, val: f32) -> f32 {
    scratch_wg[tid] = val;
    workgroupBarrier();
    var stride: u32 = 32u;
    loop {
        if (stride == 0u) { break; }
        if (tid < stride) {
            scratch_wg[tid] = scratch_wg[tid] + scratch_wg[tid + stride];
        }
        workgroupBarrier();
        stride = stride >> 1u;
    }
    let total = scratch_wg[0];
    // Write-after-read guard: see the function comment.
    workgroupBarrier();
    return total;
}

// dx kernel. Workgroup `wid.x` owns one row of length H = params.inner, and
// its 64 lanes walk that row with a stride of 64. The row is scanned four
// times:
//   1. mu       = mean(x)
//   2. rstd     = inverseSqrt(mean((x - mu)^2) + eps)   (two-pass variance)
//   3. mean_g   = mean(dy * gamma), mean_gxh = mean(dy * gamma * x_hat)
//   4. dx       = rstd * (dy * gamma - mean_g - x_hat * mean_gxh)
// where x_hat = (x - mu) * rstd.
@compute @workgroup_size(64)
fn layer_norm_bwd_input(
    @builtin(workgroup_id) wid: vec3<u32>,
    @builtin(local_invocation_id) lid: vec3<u32>,
) {
    let width = params.inner;
    let r = wid.x;
    // The exit depends only on workgroup_id and uniforms, so the whole
    // workgroup leaves together and the barriers in workgroup_sum stay uniform.
    if (width == 0u || r >= params.outer) { return; }

    let lane = lid.x;
    let row_off = r * width;
    let x_row = params.x_off + row_off;
    let dy_row = params.dy_off + row_off;
    let dx_row = params.out_off + row_off;
    let recip_n = 1.0 / f32(width);
    let epsilon = bitcast<f32>(params.eps_bits);

    // Pass 1: row mean of x.
    var part_x: f32 = 0.0;
    for (var j: u32 = lane; j < width; j += 64u) {
        part_x += arena[x_row + j];
    }
    let mu = workgroup_sum(lane, part_x) * recip_n;

    // Pass 2: reciprocal standard deviation, variance taken about mu.
    var part_sq: f32 = 0.0;
    for (var j: u32 = lane; j < width; j += 64u) {
        let centered = arena[x_row + j] - mu;
        part_sq += centered * centered;
    }
    let rstd = inverseSqrt(workgroup_sum(lane, part_sq) * recip_n + epsilon);

    // Pass 3: row means of g = dy * gamma and of g * x_hat.
    var part_g: f32 = 0.0;
    var part_gxh: f32 = 0.0;
    for (var j: u32 = lane; j < width; j += 64u) {
        let g = arena[dy_row + j] * arena[params.gamma_off + j];
        let x_hat = (arena[x_row + j] - mu) * rstd;
        part_g += g;
        part_gxh += g * x_hat;
    }
    let mean_g = workgroup_sum(lane, part_g) * recip_n;
    let mean_gxh = workgroup_sum(lane, part_gxh) * recip_n;

    // Pass 4: write dx.
    for (var j: u32 = lane; j < width; j += 64u) {
        let g = arena[dy_row + j] * arena[params.gamma_off + j];
        let x_hat = (arena[x_row + j] - mu) * rstd;
        arena[dx_row + j] = rstd * (g - mean_g - x_hat * mean_gxh);
    }
}

// `dgamma` partial sums: one workgroup per chunk of ROWS_PER_WG rows.
// `scratch[wg_idx, d] = sum over rows r in chunk of dy[r,d] * x_hat[r,d]`,
// with x_hat[r,d] = (x[r,d] - mean_r) * inv_std_r recomputed from x.
//
// Phase 1: the 64 lanes cooperatively reduce mean and inv_std for every
// row of the chunk and stash them in `row_stats` (workgroup memory).
// Phase 2: lanes stride over the columns (d = tid, tid + 64, ...), so any
// H is covered, not just H <= 64. Each lane accumulates its columns over
// the chunk's rows in ascending order and writes one partial per column.
//
// Caller dispatches `num_workgroups = ceil(rows / ROWS_PER_WG)` and
// allocates `scratch_off` for `num_workgroups * H` f32 entries.
const ROWS_PER_WG: u32 = 16u;

// Per-row (mean, inv_std) for the rows of the current chunk.
var<workgroup> row_stats: array<vec2<f32>, ROWS_PER_WG>;

@compute @workgroup_size(64)
fn layer_norm_bwd_gamma_partial(
    @builtin(workgroup_id) wid: vec3<u32>,
    @builtin(local_invocation_id) lid: vec3<u32>,
) {
    // Both early exits depend only on uniforms / workgroup_id, so the whole
    // workgroup leaves together and every barrier below stays uniform.
    if (params.inner == 0u) { return; }
    let tid = lid.x;
    let h = params.inner;
    let rows = params.outer;
    let n_inv = 1.0 / f32(h);
    let eps = bitcast<f32>(params.eps_bits);
    let wg = wid.x;
    let row_start = wg * ROWS_PER_WG;
    if (row_start >= rows) { return; }
    let n_rows = min(ROWS_PER_WG, rows - row_start);

    // Phase 1: per-row statistics, broadcast through `row_stats`.
    for (var k: u32 = 0u; k < n_rows; k = k + 1u) {
        let x_base = params.x_off + (row_start + k) * h;

        // Cooperative mean(x_r).
        var s: f32 = 0.0;
        for (var i: u32 = tid; i < h; i = i + 64u) {
            s = s + arena[x_base + i];
        }
        let sum_x = workgroup_sum(tid, s);
        if (tid == 0u) { row_stats[k].x = sum_x * n_inv; }
        // Publishes the mean and also separates this reduction's reads of
        // the shared sum buffer from the next reduction's writes.
        workgroupBarrier();
        let mean = row_stats[k].x;

        // Cooperative var(x_r) about the mean (two-pass).
        var s2: f32 = 0.0;
        for (var i: u32 = tid; i < h; i = i + 64u) {
            let dv = arena[x_base + i] - mean;
            s2 = s2 + dv * dv;
        }
        let var_sum = workgroup_sum(tid, s2);
        if (tid == 0u) { row_stats[k].y = inverseSqrt(var_sum * n_inv + eps); }
        workgroupBarrier();
    }

    // Phase 2: column-strided accumulation of dy * x_hat over the chunk.
    for (var d: u32 = tid; d < h; d = d + 64u) {
        var acc: f32 = 0.0;
        for (var k: u32 = 0u; k < n_rows; k = k + 1u) {
            let r = row_start + k;
            let stats = row_stats[k];
            let xv = arena[params.x_off + r * h + d];
            let dyv = arena[params.dy_off + r * h + d];
            acc = acc + dyv * (xv - stats.x) * stats.y;
        }
        arena[params.scratch_off + wg * h + d] = acc;
    }
}

// Final reduce: a single workgroup sums `scratch[*, d]` across all chunks
// into `dgamma[d]`. `outer` in this dispatch carries the partial chunk
// count (= the original `num_workgroups`). Lanes stride over the columns
// (d = tid, tid + 64, ...), so H larger than the workgroup size is fully
// covered. Chunks are summed in ascending order.
@compute @workgroup_size(64)
fn layer_norm_bwd_gamma_reduce(@builtin(local_invocation_id) lid: vec3<u32>) {
    let h = params.inner;
    let num_chunks = params.outer;
    // h == 0 makes this loop empty, so no explicit guard is needed.
    for (var d: u32 = lid.x; d < h; d = d + 64u) {
        var acc: f32 = 0.0;
        for (var c: u32 = 0u; c < num_chunks; c = c + 1u) {
            acc = acc + arena[params.scratch_off + c * h + d];
        }
        arena[params.out_off + d] = acc;
    }
}