rlx-wgpu 0.2.3

Cross-platform GPU backend for RLX via wgpu (Metal/Vulkan/DX12/WebGPU)
Documentation
// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// Per-row k-NN on a pairwise distance matrix (UMAP). One thread per row;
// insertion-sort into k slots. Matches `rlx_cpu::umap_knn::knn_forward_packed`.

// Uniform parameters for the `umap_knn` kernel. The offsets are used directly
// as indices into `arena`, so they count f32 elements, not bytes.
struct Params {
    // Number of rows (and columns) of the pairwise distance matrix.
    n: u32,
    // Number of neighbour slots kept per row.
    k: u32,
    // Index in `arena` where the row-major n*n distance matrix starts.
    pw_off: u32,
    // Index in `arena` where the output starts; each row owns 2*k elements.
    out_off: u32,
    // Not read by this shader — presumably padding for the uniform layout;
    // confirm against the host-side struct.
    _p0: u32,
    _p1: u32,
    _p2: u32,
};

// One shared f32 buffer holding both the input distance matrix and the k-NN
// output; the two regions are located via `params.pw_off` / `params.out_off`.
@group(0) @binding(0) var<storage, read_write> arena: array<f32>;
// Kernel parameters (matrix size, neighbour count, arena offsets).
@group(0) @binding(1) var<uniform>              params: Params;

// Largest finite f32; used as the distance of a slot with no neighbour yet.
const INF: f32 = 3.4028235e+38;

// Per-row k-nearest-neighbour selection over a pairwise distance matrix.
//
// Each invocation handles one row of the n*n matrix that starts at
// `params.pw_off` in `arena`, and writes 2*k floats starting at
// `params.out_off + row * 2 * k`:
//   [0, k)   neighbour column indices stored as f32, nearest first
//   [k, 2k)  the matching distances, ascending
// Slots that never receive a neighbour keep the sentinels `f32(n)` / `INF`.
// The diagonal entry (col == row) is never considered. On equal distances the
// lower column index stays first, because insertion uses a strict `<`.
//
// The invocation does nothing when `row >= n`, `k >= n`, or `k == 0`.
@compute @workgroup_size(64)
fn umap_knn(
    @builtin(global_invocation_id) gid: vec3<u32>,
    @builtin(num_workgroups) nwg: vec3<u32>,
) {
    // Flatten the (x, y) dispatch grid into a row index; 64 is the
    // workgroup width declared above.
    let row = gid.x + gid.y * nwg.x * 64u;
    let n = params.n;
    let k = params.k;
    // k == 0 must bail out as well: there are no output slots, and the
    // `k - 1u` index arithmetic below would wrap around and address
    // elements outside this row's output region.
    if (row >= n || k >= n || k == 0u) {
        return;
    }

    let pw_base = params.pw_off + row * n;
    let out_base = params.out_off + row * (2u * k);
    // Start of the distance half of this row's output, and its last slot
    // (the current k-th best distance).
    let dist_base = out_base + k;
    let last = dist_base + k - 1u;

    // Mark every slot empty: index sentinel n, distance sentinel INF.
    for (var s: u32 = 0u; s < k; s = s + 1u) {
        arena[out_base + s] = f32(n);
        arena[dist_base + s] = INF;
    }

    // Local copy of the k-th best distance, used for a cheap early reject.
    var worst: f32 = INF;
    for (var col: u32 = 0u; col < n; col = col + 1u) {
        if (col == row) {
            continue;
        }
        let dist = arena[pw_base + col];
        if (dist >= worst) {
            continue;
        }
        // Re-check against the stored k-th distance before inserting.
        if (dist < arena[last]) {
            // Insertion sort step: walk down from the last slot, shifting
            // larger entries one place right until `dist` fits.
            var slot: u32 = k - 1u;
            loop {
                if (slot == 0u) {
                    break;
                }
                if (dist < arena[dist_base + slot - 1u]) {
                    arena[dist_base + slot] = arena[dist_base + slot - 1u];
                    arena[out_base + slot] = arena[out_base + slot - 1u];
                    slot = slot - 1u;
                } else {
                    break;
                }
            }
            arena[dist_base + slot] = dist;
            arena[out_base + slot] = f32(col);
            worst = arena[last];
        }
    }
}