// rlx-wgpu 0.2.1
//
// Cross-platform GPU backend for RLX via wgpu (Metal/Vulkan/DX12/WebGPU).
// Documentation
// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

// Grouped (MoE) matmul. One workgroup per (M, N) tile; each thread
// computes one output element by reading the per-token expert id, then
// streaming K against weight[expert_id]. Naive correctness-first
// implementation — segmented/grouped GEMM optimization can come later.
//
// input       : [M, K]
// weight      : [num_experts, K, N]
// expert_idx  : [M]              (f32-encoded per-token expert id)
// output      : [M, N]
//   output[m, n] = input[m, :] · weight[expert_idx[m], :, n]

// Uniform parameters for grouped_matmul. All `*_off` fields are offsets in
// f32 elements (not bytes) into the shared `arena` buffer, since they are
// added directly to indices into an array<f32>.
struct Params {
    // Number of output rows; invocations with row >= m exit early.
    m: u32,
    // Reduction length: elements per input row and rows per expert weight matrix.
    k: u32,
    // Number of output columns; invocations with col >= n exit early.
    n: u32,
    // Upper bound (exclusive) on valid expert ids.
    num_experts: u32,
    // Start of the input; row r begins at in_off + r * k.
    in_off: u32,
    // Start of the weights; expert e begins at w_off + e * k * n.
    w_off: u32,
    // Start of the per-row expert ids; row r's id is at idx_off + r.
    idx_off: u32,
    // Start of the output; element (r, c) is at out_off + r * n + c.
    out_off: u32,
};

// Single f32 storage buffer that holds the input, weights, expert ids and
// output; each region is located through the offsets in `params`.
@group(0) @binding(0) var<storage, read_write> arena: array<f32>;
// Problem dimensions and region offsets for the current dispatch.
@group(0) @binding(1) var<uniform>              params: Params;

// Computes one output element of the grouped (MoE) matmul per invocation:
//   output[row, col] = sum over kk of input[row, kk] * weight[e, kk, col]
// where e is the expert id read (as an f32) from arena[idx_off + row].
//
// gid.x selects the output column and gid.y the output row. Invocations that
// fall outside the m x n output, or whose row carries an invalid expert id
// (negative, NaN, or >= num_experts), return without writing to the arena.
@compute @workgroup_size(8, 8)
fn grouped_matmul(@builtin(global_invocation_id) gid: vec3<u32>) {
    let row = gid.y;
    let col = gid.x;
    // The dispatch is rounded up to whole 8x8 workgroups, so trailing
    // invocations may lie outside the output.
    if (row >= params.m || col >= params.n) { return; }
    let e_f = arena[params.idx_off + row];
    // Reject negative and NaN ids before converting. u32() clamps a negative
    // f32 to 0, which would silently route the row through expert 0 instead
    // of treating it as invalid. The negated comparison is false-safe for
    // NaN, since every ordered comparison with NaN is false.
    if (!(e_f >= 0.0)) { return; }
    let e   = u32(e_f);
    if (e >= params.num_experts) { return; }
    // Each expert owns a contiguous [k, n] weight matrix.
    let w_base = params.w_off + e * params.k * params.n;
    let in_base = params.in_off + row * params.k;
    var acc: f32 = 0.0;
    for (var kk: u32 = 0u; kk < params.k; kk = kk + 1u) {
        acc = acc + arena[in_base + kk]
                  * arena[w_base + kk * params.n + col];
    }
    arena[params.out_off + row * params.n + col] = acc;
}