rlx-wgpu 0.2.6

Cross-platform GPU backend for RLX via wgpu (Metal/Vulkan/DX12/WebGPU)
Documentation
// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

// Dequant-on-the-fly matmul. Reads packed quantized weights (int8, int4,
// FP8 or FP4, selected by scheme_id) from the f32-typed arena via
// bitcast<u32>, then dequantizes each element before accumulating it
// into the output.
//
// The arena binding is `array<f32>` like every other kernel — for the
// w_q tensor we bitcast each f32 word to u32 and extract one of four
// bytes via shift+mask. The user uploads packed int8 bytes through
// `set_param_bytes`, which writes them tight-packed at the slot's
// f32 offset.
//
// Layout:
//   x  [m, k]                            f32
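//   w_q[k, n]   tightly packed i8         (k*n bytes inside k*n*4-byte slot;
//                                         4-bit schemes pack two values per byte, k*n/2 bytes)
//   scale[k/block, n]                    f32 (packed FP8 E4M3 bytes for scheme 5; unused by schemes 3 and 4)
//   zp   [k/block, n]                    f32 (read only when scheme_id == 1; for scheme 5
//                                         zp_off instead points at a single f32 global scale)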
// Output:
//   out[m, n]                            f32

// scheme_id selects the unpack path:
//   0 = Int8Block      (bits=8, signed, per-block scale)
//   1 = Int8BlockAsym  (bits=8, signed, per-block scale + zero point)
//   2 = Int4Block      (bits=4, signed, per-block scale)
//   3 = Fp8E4m3        (bits=8, no scale, OCP E4M3 bit decode)
//   4 = Fp8E5m2        (bits=8, no scale, OCP E5M2 bit decode)
//   5 = Nvfp4Block     (E2M1 nibbles + FP8 E4M3 block scales + f32 global_scale @ zp_off)

// Uniform parameter block for the dequant matmul kernel. Every `*_off`
// field is an element index into `arena` (an array<f32>).
struct Params {
    m: u32,            // rows of x and out
    k: u32,            // reduction length: columns of x, rows of w_q
    n: u32,            // columns of w_q and out
    block_size: u32,   // number of consecutive k rows sharing one scale / zero-point entry
    scheme_id: u32,    // unpack path selector (0..5); any other value takes the Nvfp4Block path
    x_off: u32,        // start of x
    w_off: u32,        // f32-element offset of w_q's slot
    scale_off: u32,    // start of the per-block scales (f32 entries; packed FP8 bytes for Nvfp4Block)
    zp_off: u32,       // start of the zero points (Int8BlockAsym); one f32 global scale for Nvfp4Block
    out_off: u32,      // start of out
    _p0: u32,          // not read by this shader; presumably padding — confirm against the host struct
    _p1: u32,          // not read by this shader; presumably padding — confirm against the host struct
};

// Single storage buffer holding every tensor this kernel touches: inputs are
// read from it and the result is written back into it at `params.out_off`.
@group(0) @binding(0) var<storage, read_write> arena: array<f32>;
// Shapes, scheme selector and arena offsets for this kernel.
@group(0) @binding(1) var<uniform>              params: Params;

// Fetch the 8-bit value at flat byte position `elem_idx` of the weight
// stream. The stream starts at arena word `params.w_off`; each f32 word is
// reinterpreted as a u32 holding four bytes, least-significant byte first.
fn read_byte(elem_idx: u32) -> u32 {
    let packed = bitcast<u32>(arena[params.w_off + (elem_idx >> 2u)]);
    let shift = (elem_idx & 3u) << 3u;
    return (packed >> shift) & 0xffu;
}

// Fetch the 4-bit value at flat nibble position `elem_idx` of the weight
// stream and sign-extend it to the range -8..7. Eight nibbles share one
// arena word, lowest nibble first (so the low nibble of each byte comes
// before the high nibble).
fn read_nibble_signed(elem_idx: u32) -> i32 {
    let packed = bitcast<u32>(arena[params.w_off + (elem_idx >> 3u)]);
    let shift = (elem_idx & 7u) << 2u;
    let code = (packed >> shift) & 0xfu;
    // Codes 8..15 are the two's-complement negatives -8..-1.
    return select(i32(code), i32(code) - 16, code >= 8u);
}

// Fetch the 8-bit value at flat byte position `byte_idx` of the scale
// stream. The stream starts at arena word `params.scale_off`; each f32 word
// is reinterpreted as a u32 holding four bytes, least-significant byte first.
fn read_scale_byte(byte_idx: u32) -> u32 {
    let packed = bitcast<u32>(arena[params.scale_off + (byte_idx >> 2u)]);
    let shift = (byte_idx & 3u) << 3u;
    return (packed >> shift) & 0xffu;
}

// Fetch the raw 4-bit code (0..15) at flat nibble position `elem_idx` of the
// weight stream, without sign extension. Eight nibbles share one arena word,
// lowest nibble first. Used for FP4 E2M1 codes.
fn read_nibble_u4(elem_idx: u32) -> u32 {
    let packed = bitcast<u32>(arena[params.w_off + (elem_idx >> 3u)]);
    let shift = (elem_idx & 7u) << 2u;
    return (packed >> shift) & 0xfu;
}

// Lookup table from a 4-bit FP4 E2M1 code to its f32 value. Bit 3 of the
// code is the sign: entries 8..15 are the negations of entries 0..7.
const FP4_E2M1 = array<f32, 16>(
    0.0, 0.5, 1.0, 1.5,
    2.0, 3.0, 4.0, 6.0,
    -0.0, -0.5, -1.0, -1.5,
    -2.0, -3.0, -4.0, -6.0,
);

// Decode one OCP E4M3 byte (bit 7 sign, bits 6..3 exponent with bias 7,
// bits 2..0 mantissa) to f32. The format has no infinity; exponent 15 with
// mantissa 7 (0x7F / 0xFF) is the NaN encoding and is returned as zero so
// the surrounding matmul stays finite.
fn decode_e4m3(byte: u32) -> f32 {
    let exp_bits = (byte >> 3u) & 0xfu;
    let mant_bits = byte & 0x7u;
    let frac = f32(mant_bits) / 8.0;

    // Stays 0.0 for the NaN encoding.
    var magnitude: f32 = 0.0;
    if (exp_bits == 0u) {
        // Subnormal: no implicit leading one, fixed exponent of -6
        // (smallest non-zero magnitude is 2^-9).
        magnitude = frac * exp2(-6.0);
    } else if (!(exp_bits == 15u && mant_bits == 7u)) {
        // Normal: implicit leading one, unbiased exponent.
        magnitude = (1.0 + frac) * exp2(f32(i32(exp_bits) - 7));
    }

    if ((byte & 0x80u) != 0u) { return -magnitude; }
    return magnitude;
}

// Decode one OCP E5M2 byte (bit 7 sign, bits 6..2 exponent with bias 15,
// bits 1..0 mantissa) to f32. Exponent 31 encodes ±inf (mantissa 0) or NaN
// (mantissa > 0); both are returned as zero so the surrounding matmul stays
// finite.
fn decode_e5m2(byte: u32) -> f32 {
    let exp_bits = (byte >> 2u) & 0x1fu;
    let frac = f32(byte & 0x3u) / 4.0;

    // Stays 0.0 for the inf / NaN encodings.
    var magnitude: f32 = 0.0;
    if (exp_bits == 0u) {
        // Subnormal: no implicit leading one, fixed exponent of -14.
        magnitude = frac * exp2(-14.0);
    } else if (exp_bits != 31u) {
        // Normal: implicit leading one, unbiased exponent.
        magnitude = (1.0 + frac) * exp2(f32(i32(exp_bits) - 15));
    }

    if ((byte & 0x80u) != 0u) { return -magnitude; }
    return magnitude;
}

// Entry point. Each invocation computes one element out[row, col] with
// row = gid.y and col = gid.x: it walks the k dimension, dequantizes
// w_q[k, col] according to `params.scheme_id`, and accumulates
// x[row, k] * w. Any scheme_id other than 0..4 takes the Nvfp4Block path.
@compute @workgroup_size(8, 8)
fn dequant_matmul(@builtin(global_invocation_id) gid: vec3<u32>) {
    let out_row = gid.y;
    let out_col = gid.x;
    // Invocations that fall outside the m x n output do nothing.
    if (out_row >= params.m || out_col >= params.n) { return; }

    var sum: f32 = 0.0;
    for (var kk: u32 = 0u; kk < params.k; kk = kk + 1u) {
        // Flat position of w_q[kk, out_col] in the packed weight stream.
        let w_idx = kk * params.n + out_col;
        var weight: f32 = 0.0;
        switch (params.scheme_id) {
            case 0u: {
                // Int8Block: sign-extended byte times the per-block scale.
                let raw = read_byte(w_idx);
                let q = select(i32(raw), i32(raw) - 256, raw >= 128u);
                let tbl = (kk / params.block_size) * params.n + out_col;
                let s = arena[params.scale_off + tbl];
                weight = f32(q) * s;
            }
            case 1u: {
                // Int8BlockAsym: subtract the per-block zero point, then scale.
                let raw = read_byte(w_idx);
                let q = select(i32(raw), i32(raw) - 256, raw >= 128u);
                let tbl = (kk / params.block_size) * params.n + out_col;
                let s = arena[params.scale_off + tbl];
                let zero_point = arena[params.zp_off + tbl];
                weight = (f32(q) - zero_point) * s;
            }
            case 2u: {
                // Int4Block: signed nibble times the per-block scale.
                let q = read_nibble_signed(w_idx);
                let tbl = (kk / params.block_size) * params.n + out_col;
                let s = arena[params.scale_off + tbl];
                weight = f32(q) * s;
            }
            case 3u: {
                // Fp8E4m3: the byte is the value; no scale is applied.
                weight = decode_e4m3(read_byte(w_idx));
            }
            case 4u: {
                // Fp8E5m2: the byte is the value; no scale is applied.
                weight = decode_e5m2(read_byte(w_idx));
            }
            default: {
                // Nvfp4Block: E2M1 code x FP8 E4M3 block scale x f32 global
                // scale stored at zp_off.
                let code = read_nibble_u4(w_idx);
                let tbl = (kk / params.block_size) * params.n + out_col;
                let s = decode_e4m3(read_scale_byte(tbl));
                let global_scale = arena[params.zp_off];
                weight = FP4_E2M1[code] * s * global_scale;
            }
        }
        let x_val = arena[params.x_off + out_row * params.k + kk];
        sum = sum + x_val * weight;
    }
    arena[params.out_off + out_row * params.n + out_col] = sum;
}