rlx-wgpu 0.2.6

Cross-platform GPU backend for RLX via wgpu (Metal/Vulkan/DX12/WebGPU)
Documentation
// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

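// Multi-kernel f32 FFT (gpu-fft strategy), RLX 2N real-block layout:
// each row occupies 2*n consecutive f32 values in `arena`, the n real parts
// first, followed by the n imaginary parts.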

// Per-dispatch parameters read by every FFT entry point in this file.
struct Params {
    // Index into `arena` of row 0 that the kernels read (the in-place kernels
    // also write their results back here).
    off: u32,
    // Index into `arena` of row 0 of the output; only `fft_radix2_full` uses it.
    dst_off: u32,
    // Number of complex points per row; rows are 2*n f32 values apart.
    n: u32,
    // Bit width used for bit reversal: indices are reversed over 32 bits and
    // shifted right by (32 - log2n). Presumably n == 1 << log2n — set by host.
    log2n: u32,
    // Non-zero selects a twiddle-angle sign of +1.0, zero selects -1.0.
    inverse: u32,
    // Factor multiplied into every value a kernel writes back to `arena`.
    norm_scale: f32,
    // Number of rows; workgroups whose y index is >= outer exit immediately.
    outer: u32,
    // Points handled per workgroup by `fft_inner`.
    tile: u32,
    // Number of butterfly stages `fft_inner` runs on its tile.
    inner_stages: u32,
    // Read as `q` by `fft_outer_r4` and as the half-stride by `fft_outer_r2`.
    q_or_hs: u32,
};

// Flat f32 buffer that holds every row; all kernels read and write through it.
@group(0) @binding(0) var<storage, read_write> arena: array<f32>;
// Uniform block with the parameters of the current dispatch (see `Params`).
@group(0) @binding(1) var<uniform>              params: Params;

// Returns the real part of element `k` of the row starting at `base`.
// `n` is not used here; it is accepted so the signature mirrors `im_at`.
fn re_at(base: u32, k: u32, n: u32) -> f32 {
    let slot = base + k;
    return arena[slot];
}
// Returns the imaginary part of element `k` of the row starting at `base`.
// Imaginary parts are stored `n` slots after the corresponding real parts.
fn im_at(base: u32, k: u32, n: u32) -> f32 {
    let slot = base + n + k;
    return arena[slot];
}
// Stores `v` as the real part of element `k` of the row starting at `base`.
// `n` is not used here; it is accepted so the signature mirrors `set_im`.
fn set_re(base: u32, k: u32, n: u32, v: f32) {
    let slot = base + k;
    arena[slot] = v;
}
// Stores `v` as the imaginary part of element `k` of the row starting at `base`.
// Imaginary parts are stored `n` slots after the corresponding real parts.
fn set_im(base: u32, k: u32, n: u32, v: f32) {
    let slot = base + n + k;
    arena[slot] = v;
}

// Single-kernel path (N <= 1024): bit-reverse load + all stages in TG mem.
// Workgroup scratch used by `fft_radix2_full`: `sre_full` holds the real parts
// and `sim_full` the imaginary parts of the row being transformed. Both are
// indexed by point number, so the fixed length of 1024 is the N limit above.
var<workgroup> sre_full: array<f32, 1024>;
var<workgroup> sim_full: array<f32, 1024>;

// Whole-row radix-2 FFT: one workgroup transforms one row entirely in
// workgroup memory.
//
// Workgroup `wgid.y` handles row `wgid.y`; workgroups past `params.outer`
// exit. The row at `params.off` is gathered in bit-reversed order into
// `sre_full` / `sim_full`, the butterfly stages len = 2, 4, ..., n are applied
// there, and the result is multiplied by `params.norm_scale` and written to the
// row at `params.dst_off`. The scratch arrays hold 1024 points, so n must not
// exceed 1024.
@compute @workgroup_size(256)
fn fft_radix2_full(
    @builtin(workgroup_id) wgid: vec3<u32>,
    @builtin(local_invocation_id) lid: vec3<u32>,
) {
    let row = wgid.y;
    // `row` is the same for the whole workgroup, so this exit is uniform.
    if (row >= params.outer) { return; }

    let n = params.n;
    let lane = lid.x;
    let lanes = 256u; // must match @workgroup_size above
    let src_base = params.off + row * 2u * n;
    let dst_base = params.dst_off + row * 2u * n;

    // Gather: element k lands at its bit-reversed position.
    let rev_shift = 32u - params.log2n;
    for (var k: u32 = lane; k < n; k = k + lanes) {
        let rev = reverseBits(k) >> rev_shift;
        sre_full[rev] = re_at(src_base, k, n);
        sim_full[rev] = im_at(src_base, k, n);
    }
    workgroupBarrier();

    // Butterfly stages; each lane processes butterflies lane, lane+256, ...
    let sign = select(-1.0, 1.0, params.inverse != 0u);
    let two_pi = 6.28318530717958647692;
    let butterflies = n / 2u;
    for (var len: u32 = 2u; len <= n; len = len << 1u) {
        let half_len = len >> 1u;
        let theta_base = sign * two_pi / f32(len);
        for (var b: u32 = lane; b < butterflies; b = b + lanes) {
            let j = b % half_len;
            let lo = (b / half_len) * len + j;
            let hi = lo + half_len;

            let theta = theta_base * f32(j);
            let w_re = cos(theta);
            let w_im = sin(theta);

            // t = w * x[hi]
            let x_re = sre_full[hi];
            let x_im = sim_full[hi];
            let t_re = w_re * x_re - w_im * x_im;
            let t_im = w_re * x_im + w_im * x_re;

            let u_re = sre_full[lo];
            let u_im = sim_full[lo];
            sre_full[lo] = u_re + t_re;
            sim_full[lo] = u_im + t_im;
            sre_full[hi] = u_re - t_re;
            sim_full[hi] = u_im - t_im;
        }
        // The next stage reads values written by other lanes in this one.
        workgroupBarrier();
    }

    // Scatter the scaled result to the destination row.
    for (var k: u32 = lane; k < n; k = k + lanes) {
        set_re(dst_base, k, n, sre_full[k] * params.norm_scale);
        set_im(dst_base, k, n, sim_full[k] * params.norm_scale);
    }
}

// In-place bit-reversal permutation of one row, run before the multi-kernel
// outer stages.
//
// Invocation `gid.x` owns element k = gid.x of row `wgid.y`. Elements k and
// rev(k) are exchanged by the invocation with the smaller index only, so each
// pair is swapped exactly once and fixed points (k == rev) are left alone.
@compute @workgroup_size(256)
fn fft_bit_reverse(
    @builtin(global_invocation_id) gid: vec3<u32>,
    @builtin(workgroup_id) wgid: vec3<u32>,
) {
    let row = wgid.y;
    let n = params.n;
    let k = gid.x;
    if (row >= params.outer || k >= n) { return; }

    let rev = reverseBits(k) >> (32u - params.log2n);
    if (k < rev) {
        let base = params.off + row * 2u * n;

        // Read both elements before overwriting either of them.
        let k_re = re_at(base, k, n);
        let k_im = im_at(base, k, n);
        let rev_re = re_at(base, rev, n);
        let rev_im = im_at(base, rev, n);

        set_re(base, k, n, rev_re);
        set_im(base, k, n, rev_im);
        set_re(base, rev, n, k_re);
        set_im(base, rev, n, k_im);
    }
}

// Inner shared-memory tile (tile <= 1024).
// Workgroup scratch used by `fft_inner`: `sre_in` holds the real parts and
// `sim_in` the imaginary parts of the tile, indexed by position within the tile.
var<workgroup> sre_in: array<f32, 1024>;
var<workgroup> sim_in: array<f32, 1024>;

// Runs the first `params.inner_stages` radix-2 butterfly stages on one tile of
// one row, entirely in workgroup memory.
//
// Workgroup (x, y) handles tile x of row y. Each working invocation loads two
// points of the tile (positions `local` and `local + tile/2`), performs one
// butterfly per stage with half-stride 1 << s, and writes both points back in
// place multiplied by `params.norm_scale`. Points at or beyond `n` are neither
// loaded nor stored.
//
// Only the first tile/2 invocations have work. The others are masked with
// `has_work` instead of returning early: WGSL requires `workgroupBarrier()` to
// be executed in uniform control flow, so every invocation of the workgroup
// must reach every barrier below. The two early returns that remain depend
// only on the workgroup id and on uniform parameters, so they are uniform.
//
// NOTE(review): the stage indexing assumes `tile` is a power of two with
// tile >= 2^inner_stages and tile <= 1024 — confirm against the host code.
@compute @workgroup_size(512)
fn fft_inner(
    @builtin(workgroup_id) wgid: vec3<u32>,
    @builtin(local_invocation_id) lid: vec3<u32>,
) {
    let row = wgid.y;
    if (row >= params.outer) { return; }
    let n = params.n;
    let tile = params.tile;
    let half_tile = tile / 2u;
    let tile_id = wgid.x;
    let num_tiles = (n + tile - 1u) / tile;
    if (tile_id >= num_tiles) { return; }
    let local = lid.x;
    // Per-invocation mask; must not gate any barrier.
    let has_work = local < half_tile;

    let row_base = params.off + row * 2u * n;
    let tile_base = tile_id * tile;

    if (has_work) {
        if (tile_base + local < n) {
            sre_in[local] = re_at(row_base, tile_base + local, n);
            sim_in[local] = im_at(row_base, tile_base + local, n);
        }
        if (tile_base + local + half_tile < n) {
            sre_in[local + half_tile] = re_at(row_base, tile_base + local + half_tile, n);
            sim_in[local + half_tile] = im_at(row_base, tile_base + local + half_tile, n);
        }
    }
    workgroupBarrier();

    let sign = select(-1.0, 1.0, params.inverse != 0u);
    let pi = 3.14159265358979323846;
    for (var s: u32 = 0u; s < params.inner_stages; s = s + 1u) {
        if (has_work) {
            let hs = 1u << s;
            let k = local % hs;
            let i = (local / hs) * (hs * 2u) + k;
            let j = i + hs;
            let angle = sign * pi * f32(k) / f32(hs);
            let cos_a = cos(angle);
            let sin_a = sin(angle);
            let ur = sre_in[i];
            let ui = sim_in[i];
            let vr = cos_a * sre_in[j] - sin_a * sim_in[j];
            let vi = sin_a * sre_in[j] + cos_a * sim_in[j];
            sre_in[i] = ur + vr;
            sim_in[i] = ui + vi;
            sre_in[j] = ur - vr;
            sim_in[j] = ui - vi;
        }
        // Reached by all 512 invocations: the next stage reads values that
        // other invocations wrote in this one.
        workgroupBarrier();
    }

    if (has_work) {
        let scale = params.norm_scale;
        if (tile_base + local < n) {
            set_re(row_base, tile_base + local, n, sre_in[local] * scale);
            set_im(row_base, tile_base + local, n, sim_in[local] * scale);
        }
        if (tile_base + local + half_tile < n) {
            set_re(row_base, tile_base + local + half_tile, n, sre_in[local + half_tile] * scale);
            set_im(row_base, tile_base + local + half_tile, n, sim_in[local + half_tile] * scale);
        }
    }
}

// Global-memory pass that applies two consecutive radix-2 butterfly stages
// (half-strides q and 2q, with q = `params.q_or_hs`) in one dispatch.
//
// Invocation `gid.x` of row `wgid.y` owns the four points p, p+q, p+2q, p+3q.
// Stage one pairs (p, p+q) and (p+2q, p+3q) with twiddle angle sign*pi*k/q.
// Stage two pairs (p, p+2q) with angle sign*pi*k/(2q) and (p+q, p+3q) with
// that same twiddle rotated by a quarter turn in the direction of `sign`.
// All four results are written back in place, multiplied by
// `params.norm_scale`.
@compute @workgroup_size(256)
fn fft_outer_r4(
    @builtin(global_invocation_id) gid: vec3<u32>,
    @builtin(workgroup_id) wgid: vec3<u32>,
) {
    let row = wgid.y;
    let n = params.n;
    let quad = gid.x;
    if (row >= params.outer || quad >= n / 4u) { return; }

    let q = params.q_or_hs;
    let base = params.off + row * 2u * n;
    let k = quad % q;
    let p0 = (quad / q) * (q * 4u) + k;
    let p1 = p0 + q;
    let p2 = p0 + q * 2u;
    let p3 = p0 + q * 3u;

    let a_re = re_at(base, p0, n);
    let a_im = im_at(base, p0, n);
    let b_re = re_at(base, p1, n);
    let b_im = im_at(base, p1, n);
    let c_re = re_at(base, p2, n);
    let c_im = im_at(base, p2, n);
    let d_re = re_at(base, p3, n);
    let d_im = im_at(base, p3, n);

    let is_inverse = params.inverse != 0u;
    let sign = select(-1.0, 1.0, is_inverse);
    let neg_sign = select(1.0, -1.0, is_inverse);
    let pi = 3.14159265358979323846;

    // Stage one: the same twiddle w1 serves both (a, b) and (c, d).
    let angle1 = sign * pi * f32(k) / f32(q);
    let w1_re = cos(angle1);
    let w1_im = sin(angle1);
    let wb_re = w1_re * b_re - w1_im * b_im;
    let wb_im = w1_im * b_re + w1_re * b_im;
    let wd_re = w1_re * d_re - w1_im * d_im;
    let wd_im = w1_im * d_re + w1_re * d_im;

    let s0_re = a_re + wb_re;
    let s0_im = a_im + wb_im;
    let s1_re = a_re - wb_re;
    let s1_im = a_im - wb_im;
    let s2_re = c_re + wd_re;
    let s2_im = c_im + wd_im;
    let s3_re = c_re - wd_re;
    let s3_im = c_im - wd_im;

    // Stage two: w2a for (s0, s2); w2b = w2a * (sign * i) for (s1, s3).
    let angle2 = sign * pi * f32(k) / f32(q * 2u);
    let w2a_re = cos(angle2);
    let w2a_im = sin(angle2);
    let w2b_re = neg_sign * w2a_im;
    let w2b_im = sign * w2a_re;

    let t2_re = w2a_re * s2_re - w2a_im * s2_im;
    let t2_im = w2a_im * s2_re + w2a_re * s2_im;
    let t3_re = w2b_re * s3_re - w2b_im * s3_im;
    let t3_im = w2b_im * s3_re + w2b_re * s3_im;

    let scale = params.norm_scale;
    set_re(base, p0, n, (s0_re + t2_re) * scale);
    set_im(base, p0, n, (s0_im + t2_im) * scale);
    set_re(base, p2, n, (s0_re - t2_re) * scale);
    set_im(base, p2, n, (s0_im - t2_im) * scale);
    set_re(base, p1, n, (s1_re + t3_re) * scale);
    set_im(base, p1, n, (s1_im + t3_im) * scale);
    set_re(base, p3, n, (s1_re - t3_re) * scale);
    set_im(base, p3, n, (s1_im - t3_im) * scale);
}

// Global-memory pass that applies a single radix-2 butterfly stage with
// half-stride `params.q_or_hs`.
//
// Invocation `gid.x` of row `wgid.y` owns one butterfly: it combines the
// points at `lo` and `lo + half-stride` using the twiddle angle
// sign*pi*k/half-stride and writes both results back in place, multiplied by
// `params.norm_scale`.
@compute @workgroup_size(256)
fn fft_outer_r2(
    @builtin(global_invocation_id) gid: vec3<u32>,
    @builtin(workgroup_id) wgid: vec3<u32>,
) {
    let row = wgid.y;
    let n = params.n;
    let pair = gid.x;
    if (row >= params.outer || pair >= n / 2u) { return; }

    let hs = params.q_or_hs;
    let base = params.off + row * 2u * n;
    let k = pair % hs;
    let lo = (pair / hs) * (hs * 2u) + k;
    let hi = lo + hs;

    let sign = select(-1.0, 1.0, params.inverse != 0u);
    let angle = sign * 3.14159265358979323846 * f32(k) / f32(hs);
    let w_re = cos(angle);
    let w_im = sin(angle);

    let a_re = re_at(base, lo, n);
    let a_im = im_at(base, lo, n);
    let b_re = re_at(base, hi, n);
    let b_im = im_at(base, hi, n);

    // t = w * b
    let t_re = w_re * b_re - w_im * b_im;
    let t_im = w_im * b_re + w_re * b_im;

    let scale = params.norm_scale;
    set_re(base, lo, n, (a_re + t_re) * scale);
    set_im(base, lo, n, (a_im + t_im) * scale);
    set_re(base, hi, n, (a_re - t_re) * scale);
    set_im(base, hi, n, (a_im - t_im) * scale);
}