rlx-wgpu 0.2.6

Cross-platform GPU backend for RLX via wgpu (Metal/Vulkan/DX12/WebGPU)

rlx-wgpu

Cross-platform GPU backend via the wgpu crate. Single backend serves Metal (macOS), Vulkan (Linux), DirectX 12 (Windows), and WebGPU (browsers). WGSL kernels, pure Rust deps — no FFI, no submodules.

What's here

  • WGSL kernels: fp32 matmul (8×8 tile), cooperative-matrix matmul (32×32 tile, simdgroup_matrix / KHR_cooperative_matrix), f16-storage matmul.
  • device.rs: wgpu instance/adapter/device singleton. Platform defaults: DX12 (+ Vulkan fallback) on Windows, Vulkan on Linux/WSL, Metal (+ MoltenVK) on macOS. Override with WGPU_BACKEND. A sync wrapper via pollster::block_on gives the rest of the backend the same synchronous shape as rlx-cpu / rlx-metal / rlx-mlx.
  • buffer.rs / Arena: a single contiguous storage buffer, with per-node offsets from rlx-opt::memory::plan_memory_aligned. f32 host I/O goes through queue.write_buffer for uploads and pooled MAP_READ staging buffers for readback (ReadbackStaging, read_f32_many_pooled).
  • kernels/matmul.wgsl: fp32 matmul, one workgroup per 8×8 output tile. Functional, not optimized.
  • kernels/mod.rs: OnceLock-cached pipeline + bind-group layout. The first dispatch pays the WGSL → SPIR-V/MSL/HLSL translation cost (~ms); subsequent dispatches reuse the compiled pipeline.
  • backend.rs: WgpuExecutable. Any op outside the supported set panics at graph compile time with a clear "fall back to CPU/Metal/MLX" diagnostic.
  • FFT: fft_gpu.wgsl handles power-of-two sizes with a multi-kernel dispatch (in-pass, with per-op uniforms). Non-power-of-two sizes and f64 / C64 data go through fft_host.rs with a partial host sync. RLX_BENCH_DISPATCH_ONLY=1 skips output readback for micro-benchmarks; RLX_DISPATCH_REPORT=1 logs fft_gpu vs fft_host step counts.
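buffer.rs places every node in one contiguous storage buffer. If a node is bound at its own offset, that offset has to be a multiple of wgpu's min_storage_buffer_offset_alignment limit (256 bytes in the default Limits). The sketch below shows only that alignment arithmetic, using a naive bump allocator. align_up and bump_offsets are hypothetical names, and unlike a real memory planner the sketch never reuses space between nodes, so it is not what rlx-opt::memory::plan_memory_aligned computes.

```rust
/// Round `x` up to the next multiple of `align` (align must be a power of two).
fn align_up(x: u64, align: u64) -> u64 {
    debug_assert!(align.is_power_of_two());
    (x + align - 1) & !(align - 1)
}

/// Hypothetical bump allocator: gives each node an aligned byte offset in one
/// contiguous buffer, with no reuse between nodes. Returns (offsets, total_bytes).
fn bump_offsets(node_bytes: &[u64], align: u64) -> (Vec<u64>, u64) {
    let mut offsets = Vec::with_capacity(node_bytes.len());
    let mut cursor = 0u64;
    for &size in node_bytes {
        let off = align_up(cursor, align);
        offsets.push(off);
        cursor = off + size;
    }
    (offsets, cursor)
}

fn main() {
    // Three f32 tensors with 25, 75 and 1 elements (4 bytes each).
    let sizes: Vec<u64> = [25u64, 75, 1].iter().map(|n| n * 4).collect();
    // 256 is the default min_storage_buffer_offset_alignment in wgpu::Limits.
    let (offsets, total) = bump_offsets(&sizes, 256);
    println!("{offsets:?} {total}"); // [0, 256, 768] 772
}
```

The padding between nodes is the price of binding each one at its own offset; a planner that knows node lifetimes can recover that space by overlapping nodes that are never live at the same time.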

Op coverage

Today: MatMul (2D), Op::Input, Op::Param, and Op::Constant. Anything else fails at graph compile time with a clear "fall back to CPU/Metal/MLX" diagnostic.

The roadmap lands ops in BERT-shaped order: element-wise binary, layer norm, softmax, attention, gather, transpose. Adding an op takes four pieces: the WGSL source, a MatmulPipeline-style cache entry, a Step variant, and a dispatch arm in run. PRs welcome.
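The four pieces can be sketched in std-only Rust. Everything here is hypothetical: Pipeline stands in for a compiled wgpu compute pipeline, Step::Relu is an invented new op, and its 64-thread workgroup size is an assumption. The caching follows the OnceLock pattern described for kernels/mod.rs (the first call compiles, later calls return the cached value), and the MatMul dispatch size follows the one-workgroup-per-8×8-tile layout of kernels/matmul.wgsl.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::OnceLock;

/// Stand-in for a compiled compute pipeline plus bind-group layout (hypothetical).
struct Pipeline {
    label: &'static str,
}

/// Counts how many times "shader compilation" actually ran.
static COMPILES: AtomicUsize = AtomicUsize::new(0);

fn compile(label: &'static str) -> Pipeline {
    // In a real backend, this is where the WGSL translation cost is paid.
    COMPILES.fetch_add(1, Ordering::SeqCst);
    Pipeline { label }
}

/// Cache entry for the existing op.
fn matmul_pipeline() -> &'static Pipeline {
    static P: OnceLock<Pipeline> = OnceLock::new();
    P.get_or_init(|| compile("matmul"))
}

/// Cache entry for the hypothetical new op.
fn relu_pipeline() -> &'static Pipeline {
    static P: OnceLock<Pipeline> = OnceLock::new();
    P.get_or_init(|| compile("relu"))
}

/// One variant per op; fields carry what the dispatch needs.
enum Step {
    MatMul { m: u32, n: u32 },
    Relu { len: u32 },
}

/// One dispatch arm per Step; returns "<pipeline> <workgroup counts>" for illustration.
fn run(steps: &[Step]) -> Vec<String> {
    steps
        .iter()
        .map(|s| match *s {
            Step::MatMul { m, n } => {
                let p = matmul_pipeline();
                // One workgroup per 8×8 output tile: ceil(n/8) × ceil(m/8).
                format!("{} {}x{}", p.label, n.div_ceil(8), m.div_ceil(8))
            }
            Step::Relu { len } => {
                let p = relu_pipeline();
                // Assumed 64 invocations per workgroup.
                format!("{} {}", p.label, len.div_ceil(64))
            }
        })
        .collect()
}

fn main() {
    let steps = [
        Step::MatMul { m: 16, n: 20 },
        Step::Relu { len: 100 },
        Step::MatMul { m: 8, n: 8 },
    ];
    println!("{:?}", run(&steps)); // ["matmul 3x2", "relu 2", "matmul 1x1"]
    run(&steps);
    println!("compiles: {}", COMPILES.load(Ordering::SeqCst)); // compiles: 2
}
```

The second MatMul step and the second run reuse the cached pipelines, so compilation runs once per op type, not once per dispatch.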

Install

[dependencies]
rlx-wgpu = "0.2"

Or via rlx's gpu feature.

Cost-model calibration

When a wgpu adapter is available, rlx_wgpu::calibrate::Calibration::load_or_measure() benchmarks a 512³ matmul and caches the results at ~/.cache/rlx/wgpu-calib-<adapter>.json. The results feed WgpuCostModel in rlx-runtime, which uses them to rank backends.
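As a worked example of the calibration arithmetic, assuming "512³" means M = N = K = 512: each of the 512² outputs needs 512 multiplies and 512 adds, so one run is 2·512³ = 268,435,456 FLOPs. The sketch converts a measured time into GFLOP/s and builds a cache path of the documented shape. calib_path and its character sanitization are hypothetical; the crate's real file-naming rule is not documented here.

```rust
use std::path::PathBuf;

/// FLOPs in an n×n×n matmul: n² outputs, each needing n multiplies and n adds.
fn matmul_flops(n: u64) -> u64 {
    2 * n * n * n
}

/// Throughput in GFLOP/s for one n³ matmul that took `seconds` of wall-clock time.
fn gflops(n: u64, seconds: f64) -> f64 {
    matmul_flops(n) as f64 / seconds / 1e9
}

/// Hypothetical: build a path shaped like ~/.cache/rlx/wgpu-calib-<adapter>.json.
/// Replacing non-alphanumeric characters with '_' is an assumption.
fn calib_path(home: &str, adapter: &str) -> PathBuf {
    let safe: String = adapter
        .chars()
        .map(|c| if c.is_ascii_alphanumeric() || c == '-' { c } else { '_' })
        .collect();
    PathBuf::from(home)
        .join(".cache")
        .join("rlx")
        .join(format!("wgpu-calib-{safe}.json"))
}

fn main() {
    println!("{}", matmul_flops(512)); // 268435456
    println!("{:.1}", gflops(512, 0.002)); // 134.2
    println!("{}", calib_path("/home/me", "NVIDIA GeForce RTX 4090").display());
}
```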

Cooperative-matrix matmul (Vulkan / DX12 / Metal)

On discrete GPUs (Vulkan, DX12), aligned f32 matmul with Param weights auto-selects CoopF16Vk when the adapter reports 16×16 f16 cooperative matrix support:

  • matmul_coop_f16_vulkan.wgsl: 16×16 f16 tensor-core tiles, with coopLoad on A and coopLoadT on B when N ≤ 768. An optional widen variant uses coopLoad on B for N > 768 when RLX_WGPU_COOP_F16_VK_LARGE_N=1; by default, N > 768 dispatches matmul_wide_nv instead (see Auto wide fallback).
  • matmul_qkv_coop_f16_vk.wgsl: split-write Q/K/V variant (same numerics, one dispatch instead of a matmul plus three Narrow ops).
  • Auto wide fallback: at set_param, an oscillation score computed on Param B selects the wide f32 matmul for oscillating ("stress") weights. N > 768 also defaults to wide f32 (matmul_wide_nv), because on RTX GPUs the coop B-load at large N stays around 4e-3 error against the f16 reference. Opt back into coop for large N with RLX_WGPU_COOP_F16_VK_LARGE_N=1.
  • Fallback wide path: matmul_wide_nv.wgsl (64×64 tiles).

When a matmul operand is an Activation, the lowering mirrors it on the host on every run() (apply_activation + write_f32, with the same f16 shadow as Params) and routes the matmul through the f32 wide kernel (matmul_wide_nv). CoopF16Vk stays limited to Input/Param operands (BERT QKV weights, etc.). When CoopF16Vk does apply, a GPU cast_f32_to_f16 pre-pass still handles non-unary computed tensors. Pass flushes between the unary, cast, and coop-matmul passes keep writes visible across passes on Vulkan/DX12.
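The selection rules in this section can be summarized as a pure function. This is a sketch, not the crate's code: Kernel, Flags and select are hypothetical names, it assumes the adapter supports 16×16 f16 cooperative matrices and the shape is aligned (other cases, such as unaligned shapes, use different kernels and are out of scope), and both the order of the checks and the strict ">" comparison against the threshold are assumptions. The flag fields mirror the environment variables listed in the table below.

```rust
/// Kernel choice for one matmul (kernel names from this README; the enum is hypothetical).
#[derive(Debug, PartialEq)]
enum Kernel {
    WideF32,   // matmul_wide_nv
    CoopLoadT, // coopLoad on A + coopLoadT on B
    CoopLoad,  // large-N variant: coopLoad on B
}

#[derive(Clone, Copy)]
struct Flags {
    no_coop: bool,      // RLX_WGPU_NO_COOP_F16_VK=1
    force_wide: bool,   // RLX_WGPU_COOP_F16_VK_FORCE_WIDE=1
    no_auto_wide: bool, // RLX_WGPU_COOP_F16_VK_NO_AUTO_WIDE=1
    large_n: bool,      // RLX_WGPU_COOP_F16_VK_LARGE_N=1
    load_t: bool,       // RLX_WGPU_COOP_F16_VK_LOAD_T=1
    osc_thresh: f32,    // RLX_WGPU_COOP_F16_VK_OSC_THRESH
}

impl Default for Flags {
    fn default() -> Self {
        Flags {
            no_coop: false,
            force_wide: false,
            no_auto_wide: false,
            large_n: false,
            load_t: false,
            osc_thresh: 0.35, // documented default
        }
    }
}

/// `activation_operand`: an operand is an Activation (routed to wide f32).
/// `osc_score`: oscillation score of Param B, if one was computed at set_param.
fn select(f: &Flags, activation_operand: bool, n: u32, osc_score: Option<f32>) -> Kernel {
    if activation_operand || f.no_coop || f.force_wide {
        return Kernel::WideF32;
    }
    // Auto wide fallback for oscillating weights (strict ">" is an assumption).
    if !f.no_auto_wide && osc_score.is_some_and(|s| s > f.osc_thresh) {
        return Kernel::WideF32;
    }
    if n <= 768 {
        Kernel::CoopLoadT
    } else if !f.large_n {
        Kernel::WideF32 // default for N > 768: accuracy over speed
    } else if f.load_t {
        Kernel::CoopLoadT
    } else {
        Kernel::CoopLoad
    }
}

fn main() {
    let f = Flags::default();
    println!("{:?}", select(&f, false, 768, None)); // CoopLoadT
    println!("{:?}", select(&f, false, 1024, None)); // WideF32
}
```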

Environment flags

Variable                             Effect
WGPU_BACKEND                         vulkan (default Linux/Windows), dx12 (Windows), metal (macOS)
RLX_WGPU_NO_COOP_F16_VK=1            Force matmul_wide_nv instead of CoopF16Vk
RLX_WGPU_COOP_F16_VK_OSC_THRESH      Oscillation-score threshold for the auto wide fallback (default 0.35)
RLX_WGPU_COOP_F16_VK_NO_AUTO_WIDE=1  Disable the oscillation-based wide fallback
RLX_WGPU_COOP_F16_VK_NO_F32ACC=1     Disable the f16×f16→f32 coop accumulator kernels
RLX_WGPU_COOP_F16_VK_FORCE_WIDE=1    Always take the wide fallback for CoopF16Vk matmuls at dispatch
RLX_WGPU_COOP_F16_VK_LARGE_N=1       Keep CoopF16Vk for N > 768 (default: wide f32, for accuracy)
RLX_WGPU_COOP_F16_VK_LOAD_T=1        Force coopLoadT on B when RLX_WGPU_COOP_F16_VK_LARGE_N=1 and N > 768
RLX_WGPU_NO_COOP_F32=1               Disable CoopF32 on Metal
RLX_WGPU_FORCE_COOP_F32=1            Opt in to CoopF32 on Vulkan (portable 8×8; correctness not validated)
RLX_WGPU_FORCE_INPUT_UPLOAD=1        Always upload inputs (disables the hash-based upload skip)
RLX_BENCH_DISPATCH_ONLY=1            Skip output readback (micro-benchmarks)
RLX_WGPU_F16_WEIGHTS=1               Enable the legacy f16-storage matmul experiment
RLX_WGPU_SCHEDULE=1                  Log the compiled dispatch schedule
RLX_DISPATCH_REPORT=1                Per-step dispatch report
RLX_WGPU_NO_PACKED_BSHD_ATTN=1       Disable the packed-QKV attention stride path
RLX_WGPU_DUMP_NODES=1                Dump per-node max-abs values after run (debug)
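A test harness that wants to mirror these switches could read them as below. This is a sketch under loud assumptions: treating a boolean flag as set only when its value is exactly "1", and falling back to 0.35 when the threshold is unset or unparsable, are choices made here; the crate's actual parsing rules are not documented in this README. parse_flag, parse_osc_thresh and flag are hypothetical helpers.

```rust
use std::env;

/// Assumption: a boolean flag counts as set only when its value is exactly "1".
fn parse_flag(value: Option<&str>) -> bool {
    value == Some("1")
}

/// Threshold for the auto wide fallback; uses the documented default 0.35
/// when the variable is unset or not a valid float.
fn parse_osc_thresh(value: Option<&str>) -> f32 {
    value
        .and_then(|v| v.trim().parse::<f32>().ok())
        .unwrap_or(0.35)
}

/// Read a boolean flag from the process environment.
fn flag(name: &str) -> bool {
    parse_flag(env::var(name).ok().as_deref())
}

fn main() {
    let no_coop = flag("RLX_WGPU_NO_COOP_F16_VK");
    let thresh = parse_osc_thresh(
        env::var("RLX_WGPU_COOP_F16_VK_OSC_THRESH").ok().as_deref(),
    );
    println!("no_coop={no_coop} osc_thresh={thresh}");
}
```

Keeping the parsing in functions that take Option<&str> lets the rules be tested without mutating the real process environment.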

EEG-DINO parity notes: compile runs LegalizeBroadcast before fusion, because a mid-axis broadcast such as [1,C,1,D] + [1,C,P,D] needs an explicit Expand. On wgpu it also unfuses ElementwiseRegion, because the region kernel only supports trailing broadcast. Activation::Gelu uses the exact erf form in unary.wgsl, not the tanh approximation.
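The two GELU forms differ by enough to matter for parity checks. The sketch compares GELU(x) = 0.5·x·(1 + erf(x/√2)) with the common tanh approximation 0.5·x·(1 + tanh(√(2/π)·(x + 0.044715·x³))). Rust's standard library has no stable erf, so the sketch swaps in the Abramowitz & Stegun 7.1.26 approximation (absolute error about 1.5e-7), which is far below the gap between the two GELU forms. This illustrates the numerics only; it is not how unary.wgsl evaluates erf.

```rust
use std::f64::consts::{PI, SQRT_2};

/// erf via Abramowitz & Stegun 7.1.26 (max abs error about 1.5e-7),
/// standing in for the exact function, which std does not provide.
fn erf(x: f64) -> f64 {
    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let x = x.abs();
    let t = 1.0 / (1.0 + 0.3275911 * x);
    let poly = t
        * (0.254829592
            + t * (-0.284496736 + t * (1.421413741 + t * (-1.453152027 + t * 1.061405429))));
    sign * (1.0 - poly * (-x * x).exp())
}

/// Exact GELU: x * Phi(x) = 0.5 * x * (1 + erf(x / sqrt(2))).
fn gelu_erf(x: f64) -> f64 {
    0.5 * x * (1.0 + erf(x / SQRT_2))
}

/// Tanh approximation of GELU.
fn gelu_tanh(x: f64) -> f64 {
    let c = (2.0 / PI).sqrt();
    0.5 * x * (1.0 + (c * (x + 0.044715 * x * x * x)).tanh())
}

fn main() {
    // Largest gap between the two forms on a grid over [-6, 6].
    let max_gap = (-600..=600)
        .map(|i| i as f64 / 100.0)
        .map(|x| (gelu_erf(x) - gelu_tanh(x)).abs())
        .fold(0.0_f64, f64::max);
    println!("gelu_erf(1.0) = {:.6}", gelu_erf(1.0));
    println!("max |erf - tanh| on [-6, 6] = {max_gap:e}");
}
```

At x = 1 the two forms already differ by roughly 1.5e-4, which is why a parity comparison against an erf-based reference needs the exact form.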

Build / test

cargo build -p rlx-wgpu --release
cargo test  -p rlx-wgpu --release
cargo test  -p rlx-wgpu --test coop_f16_vk_correctness --release -- --nocapture
cargo test  -p rlx-wgpu --test coop_mat_probe --release -- --nocapture
# DX12-only (set WGPU_BACKEND=dx12 on Windows):
cargo test  -p rlx-wgpu --test coop_f16_vk_dx12 --release -- --nocapture

Through rlx-runtime:

cargo build -p rlx-runtime --features gpu --release

Status

Functional, but less battle-tested than rlx-metal / rlx-mlx on Apple Silicon. CoopF16Vk on Vulkan/DX12 is validated on RTX-class GPUs by coop_f16_vk_correctness and coop_f16_vk_bert_baseline: gentle BERT weights are checked against an f16 reference, and oscillating weights automatically fall back to wide f32. The legacy fp32 tile matmul kernel remains the correctness-first path for unaligned shapes.

Gotchas

  • Scalar-output latency: end-to-end run() on tiny graphs is often readback-bound (~1.3 ms on MoltenVK; ~100–150 µs on DX12/Vulkan), while the kernels themselves may take under a microsecond. Use RLX_BENCH_DISPATCH_ONLY=1 to time dispatch without readback. See docs/benchmarks/higher-order-ad.md.
  • WSL: on the rig, Linux-side cargo builds run from ~/rlx-workspace-mirror/rlx on ext4 rather than directly from the virtio-mounted D: drive. rig.sh bench-nth-order syncs new examples there.
  • wgpu is asynchronous; the backend wraps it with pollster::block_on for synchronous semantics. run() records the output readback copies into the final compute encoder so they go out in the same submit, schedules host mapping via CommandEncoder::map_buffer_on_submit (wgpu 29+), and reuses a pooled MAP_READ staging buffer across runs.
  • The fp32 tile matmul kernel is correctness-first. It loops over K per thread with no register blocking or shared-memory tiling, which leaves it roughly an order of magnitude slower than what is possible. Optimization comes after the op set is broad enough to run a real model.
  • Shader compilation is lazy + cached via OnceLock. First dispatch pays the WGSL → SPIR-V/MSL/HLSL translation cost (~ms); subsequent dispatches reuse the compiled pipeline.
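wgpu's adapter and device requests are futures, and pollster::block_on drives a future to completion on the calling thread. The sketch below is a std-only block_on in the same spirit, adapted from the thread-park waker pattern in the std::task::Wake documentation; it is not pollster's implementation. YieldOnce is a hypothetical future that stands in for a GPU operation that is pending on its first poll.

```rust
use std::future::Future;
use std::pin::{pin, Pin};
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};
use std::thread::{self, Thread};

/// Waker that unparks the thread blocked in `block_on`.
struct ThreadWaker(Thread);

impl Wake for ThreadWaker {
    fn wake(self: Arc<Self>) {
        self.0.unpark();
    }
}

/// Poll `fut` on the current thread, parking between polls until woken.
fn block_on<F: Future>(fut: F) -> F::Output {
    let mut fut = pin!(fut);
    let waker: Waker = Arc::new(ThreadWaker(thread::current())).into();
    let mut cx = Context::from_waker(&waker);
    loop {
        match fut.as_mut().poll(&mut cx) {
            Poll::Ready(out) => return out,
            // Spurious wakeups are harmless: the loop just polls again.
            Poll::Pending => thread::park(),
        }
    }
}

/// Hypothetical stand-in for a GPU operation: pending on the first poll, then ready.
struct YieldOnce {
    polled: bool,
}

impl Future for YieldOnce {
    type Output = u32;
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<u32> {
        if self.polled {
            Poll::Ready(42)
        } else {
            self.polled = true;
            // Wake immediately, so the park() in block_on returns right away.
            cx.waker().wake_by_ref();
            Poll::Pending
        }
    }
}

fn main() {
    let v = block_on(YieldOnce { polled: false });
    println!("{v}"); // 42
}
```

Blocking like this keeps the public API synchronous, at the cost of tying up the calling thread while the GPU work completes.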

License

GPL-3.0-only.