// rlx-wgpu 0.2.1 — cross-platform GPU backend for RLX via wgpu
// (Metal/Vulkan/DX12/WebGPU).
//
// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

// Tiled fp32 matmul that splits its output across 3 separate Q/K/V
// buffers. Identical compute to `matmul.wgsl` (32×32 tile, 4×4 register
// blocking, cooperative shared-memory loads); the only change is the
// epilogue, which routes each output column to one of three sinks based
// on `global_col` against `head_width = H · D`.
//
// Pattern this kernel collapses (one dispatch instead of four):
//   FusedMatMulBiasAct: input @ qkv_w (+ bias)  →  [B·S, 3·H·D]
//   Narrow Q (axis=last, start=0,        len=H·D)
//   Narrow K (axis=last, start=H·D,      len=H·D)
//   Narrow V (axis=last, start=2·H·D,    len=H·D)
//
// Detected at lowering time in `backend.rs` (`detect_split_qkv_pattern`).
// The 3 narrows' arena slots become the kernel's 3 output bases — the
// fused matmul output buffer is never written.
//
// Why split-write rather than strided reads from a single QKV buffer:
// stepping by 3·H·D between K rows defeats Apple's hardware prefetcher
// (~7-15× regression on M-series). Split-write keeps each Q/K/V buffer
// internally contiguous, so attention reads stay sequential.

// Uniform parameter block read by `matmul_qkv`.
//
// The kernel multiplies A (m × k) by B (k × n), both read row-major out of
// `arena`, and scatters the product into three output regions (Q, K, V).
// Every `*_off` field is added directly to an index into `arena`
// (`array<f32>`), so offsets are counted in f32 elements, not bytes.
//
// The struct is 16 × u32 = 64 bytes: 11 live fields plus 5 padding words.
// NOTE(review): the field order and padding presumably mirror a host-side
// struct — confirm before reordering or resizing anything here.
struct Params {
    m: u32,            // row count: bounds A rows and output rows
    k: u32,            // shared (reduction) dimension; also A's row stride
    n: u32,            // = 3 · head_width  (full QKV column count); also B's row stride
    a_off: u32,        // base of A inside arena
    b_off: u32,        // base of B inside arena
    q_off: u32,        // Q output base (per-row stride = head_width)
    k_off: u32,        // K output base
    v_off: u32,        // V output base
    head_width: u32,   // = H · D
    has_bias: u32,     // non-zero → add arena[bias_off + global_col] to each output
    bias_off: u32,     // bias is [3·H·D]; read bias[global_col]
    // Unused padding words; never read by the kernel.
    _p0: u32, _p1: u32, _p2: u32, _p3: u32, _p4: u32,
};

// Tiling configuration for `matmul_qkv`.
//
// These values are coupled to code that does not reference them by name:
// the extents of `tile_a` / `tile_b` are written as literal 32 / 16, the
// cooperative loads use a literal `2u` (= TILE_K / 8), and the workgroup
// size is a literal 8×8 (= TILE_M / RM by TILE_N / RN). Changing a constant
// here requires updating those literals as well.

// Output rows covered by one workgroup.
const TILE_M: u32 = 32u;
// Output columns covered by one workgroup.
const TILE_N: u32 = 32u;
// Depth of the k-slice staged in workgroup memory per loop iteration.
const TILE_K: u32 = 16u;
// Rows of the per-invocation accumulator block.
const RM: u32 = 4u;
// Columns of the per-invocation accumulator block.
const RN: u32 = 4u;

// Single read-write f32 storage buffer. `matmul_qkv` reads both matmul
// operands and the optional bias from it and writes the Q/K/V outputs back
// into it, all addressed through the element offsets in `params`.
@group(0) @binding(0) var<storage, read_write> arena: array<f32>;
// Per-dispatch problem description (dimensions, offsets, bias flag).
@group(0) @binding(1) var<uniform>              params: Params;

// Workgroup-shared staging tiles, refilled once per k-slice.
// tile_a is indexed [row in tile][k in slice]: 32 × 16 f32 (2 KiB).
var<workgroup> tile_a: array<array<f32, 16>, 32>;
// tile_b is indexed [k in slice][column in tile]: 16 × 32 f32 (2 KiB).
var<workgroup> tile_b: array<array<f32, 32>, 16>;

// Tiled fp32 matmul with a three-way split-write epilogue.
//
// One 8×8 workgroup (64 invocations) owns a TILE_M × TILE_N patch of the
// product A · B; `wid.y` selects the row tile and `wid.x` the column tile.
// Each invocation accumulates an RM × RN block of that patch in registers.
// Per k-slice, the workgroup cooperatively stages a TILE_M × TILE_K piece of
// A and a TILE_K × TILE_N piece of B in workgroup memory (each invocation
// loads 4×2 elements of A and 2×4 elements of B), zero-filling anything that
// falls outside the m/k/n bounds so partial tiles need no special casing in
// the inner product.
//
// The epilogue optionally adds bias[global column] and then writes each
// element to the Q, K or V region depending on which third of the column
// range it falls in. Each region is addressed with a row stride of
// `head_width`, not `n`.
@compute @workgroup_size(8, 8)
fn matmul_qkv(
    @builtin(local_invocation_id)    lid: vec3<u32>,
    @builtin(workgroup_id)           wid: vec3<u32>,
) {
    let ty = lid.y;
    let tx = lid.x;

    // Origin of this workgroup's output tile, and of this invocation's
    // register block inside that tile.
    let tile_row0 = wid.y * TILE_M;
    let tile_col0 = wid.x * TILE_N;
    let block_row0 = ty * RM;
    let block_col0 = tx * RN;

    // RM × RN accumulator block, cleared explicitly.
    var acc: array<array<f32, 4>, 4>;
    for (var r: u32 = 0u; r < RM; r += 1u) {
        for (var c: u32 = 0u; c < RN; c += 1u) {
            acc[r][c] = 0.0;
        }
    }

    // Number of TILE_K-deep slices needed to cover k (rounded up).
    let k_slices = (params.k + TILE_K - 1u) / TILE_K;

    for (var slice: u32 = 0u; slice < k_slices; slice += 1u) {
        let slice_k0 = slice * TILE_K;

        // Stage A: this invocation fills rows block_row0..+RM, two k-columns
        // each (tx * 2 + {0, 1}). Out-of-range elements are staged as zero.
        for (var r: u32 = 0u; r < RM; r += 1u) {
            let a_row = block_row0 + r;
            let src_row = tile_row0 + a_row;
            for (var s: u32 = 0u; s < 2u; s += 1u) {
                let a_k = tx * 2u + s;
                let src_k = slice_k0 + a_k;
                var a_val: f32 = 0.0;
                if (src_row < params.m && src_k < params.k) {
                    a_val = arena[params.a_off + src_row * params.k + src_k];
                }
                tile_a[a_row][a_k] = a_val;
            }
        }

        // Stage B: two k-rows (ty * 2 + {0, 1}), columns block_col0..+RN.
        for (var s: u32 = 0u; s < 2u; s += 1u) {
            let b_k = ty * 2u + s;
            let src_k = slice_k0 + b_k;
            for (var c: u32 = 0u; c < RN; c += 1u) {
                let b_col = block_col0 + c;
                let src_col = tile_col0 + b_col;
                var b_val: f32 = 0.0;
                if (src_k < params.k && src_col < params.n) {
                    b_val = arena[params.b_off + src_k * params.n + src_col];
                }
                tile_b[b_k][b_col] = b_val;
            }
        }

        // Both tiles must be fully written before any invocation reads them.
        workgroupBarrier();

        // Rank-1 update of the register block for every k in the slice.
        for (var kk: u32 = 0u; kk < TILE_K; kk += 1u) {
            var a_frag: array<f32, 4>;
            var b_frag: array<f32, 4>;
            for (var r: u32 = 0u; r < RM; r += 1u) {
                a_frag[r] = tile_a[block_row0 + r][kk];
            }
            for (var c: u32 = 0u; c < RN; c += 1u) {
                b_frag[c] = tile_b[kk][block_col0 + c];
            }
            for (var r: u32 = 0u; r < RM; r += 1u) {
                for (var c: u32 = 0u; c < RN; c += 1u) {
                    acc[r][c] += a_frag[r] * b_frag[c];
                }
            }
        }

        // Nobody may overwrite the tiles for the next slice until every
        // invocation has finished reading this one.
        workgroupBarrier();
    }

    // Split-write epilogue. Column ranges map to sinks as
    //   [0, H·D)        → Q
    //   [H·D, 2·H·D)    → K
    //   everything else → V   (i.e. [2·H·D, 3·H·D) when n = 3·H·D)
    // Rows in every sink are `head_width` apart, independent of the
    // matmul's own row stride n.
    let hw = params.head_width;
    let two_hw = 2u * hw;
    let add_bias = params.has_bias != 0u;

    for (var r: u32 = 0u; r < RM; r += 1u) {
        let out_row = tile_row0 + block_row0 + r;
        if (out_row < params.m) {
            for (var c: u32 = 0u; c < RN; c += 1u) {
                let out_col = tile_col0 + block_col0 + c;
                if (out_col < params.n) {
                    var value = acc[r][c];
                    if (add_bias) {
                        value += arena[params.bias_off + out_col];
                    }

                    // Start from the Q mapping; redirect to K or V when the
                    // column lies past the first head_width columns.
                    var dst_base: u32 = params.q_off;
                    var dst_col: u32 = out_col;
                    if (out_col >= hw) {
                        if (out_col < two_hw) {
                            dst_base = params.k_off;
                            dst_col = out_col - hw;
                        } else {
                            dst_base = params.v_off;
                            dst_col = out_col - two_hw;
                        }
                    }
                    arena[dst_base + out_row * hw + dst_col] = value;
                }
            }
        }
    }
}