rlx-wgpu 0.2.6

# rlx-wgpu

Cross-platform GPU backend via the [wgpu](https://wgpu.rs/) crate.
Single backend serves Metal (macOS), Vulkan (Linux), DirectX 12
(Windows), and WebGPU (browsers). WGSL kernels, pure Rust deps —
no FFI, no submodules.

## What's here

- **WGSL kernels** — fp32 matmul (8×8 tile), cooperative-matrix
  matmul (32×32 tile, `simdgroup_matrix` / `KHR_cooperative_matrix`),
  f16-storage matmul.
- **`device.rs`** — wgpu instance/adapter/device singleton. Platform
  defaults: DX12 (+ Vulkan fallback) on Windows, Vulkan on Linux/WSL,
  Metal (+ MoltenVK) on macOS. Override with `WGPU_BACKEND`. Sync wrapper
  via `pollster::block_on` so the rest of the backend matches the
  rlx-cpu / rlx-metal / rlx-mlx synchronous shape.
- **`buffer.rs` / `Arena`** — single contiguous storage buffer; per-
  node offsets from `rlx-opt::memory::plan_memory_aligned`. f32 host
  I/O via `queue.write_buffer` / pooled MAP_READ staging readback
  (`ReadbackStaging`, `read_f32_many_pooled`).
- **`kernels/matmul.wgsl`** — fp32 matmul, one workgroup per 8×8 output
  tile. Functional, not optimized.
- **`kernels/mod.rs`**: `OnceLock`-cached pipeline + bind-group layout.
  First dispatch pays the WGSL → SPIR-V/MSL/HLSL translation cost
  (~ms); subsequent dispatches reuse the compiled pipeline.
- **`backend.rs`**: `WgpuExecutable`. Anything not in the supported op
  set panics at compile time with a clear "fall back to CPU/Metal/MLX"
  diagnostic.
- **FFT**: `fft_gpu.wgsl` multi-kernel pow-2 dispatch (in-pass with
  per-op uniforms). Non-pow2 / f64 / C64 use `fft_host.rs` partial sync.
  `RLX_BENCH_DISPATCH_ONLY=1` skips output readback for micro-benchmarks.
  `RLX_DISPATCH_REPORT=1` logs `fft_gpu` vs `fft_host` step counts.
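
To make the `Arena` layout concrete, here is a minimal sketch of aligned offset assignment inside one contiguous buffer. It is not the `rlx-opt::memory::plan_memory_aligned` implementation: the function names, the 256-byte alignment, and the absence of any buffer reuse between nodes are assumptions made for illustration.

```rust
/// Round `n` up to the next multiple of `align` (`align` must be non-zero).
fn align_up(n: u64, align: u64) -> u64 {
    n.div_ceil(align) * align
}

/// Assign each node a byte offset inside one contiguous buffer.
/// Returns the per-node offsets and the total arena size in bytes.
/// Hypothetical helper: a plain bump allocator with no lifetime-based reuse.
fn plan_offsets(sizes_bytes: &[u64], align: u64) -> (Vec<u64>, u64) {
    let mut offsets = Vec::with_capacity(sizes_bytes.len());
    let mut cursor = 0u64;
    for &size in sizes_bytes {
        offsets.push(cursor);
        // The next node starts at the first aligned byte past this one.
        cursor = align_up(cursor + size, align);
    }
    (offsets, cursor)
}

fn main() {
    // Three f32 tensors with 10, 64 and 3 elements (4 bytes each).
    // The 256-byte alignment is an assumed value.
    let sizes = [10 * 4, 64 * 4, 3 * 4];
    let (offsets, total) = plan_offsets(&sizes, 256);
    println!("offsets = {offsets:?}, total = {total} bytes");
}
```

Each node then reads and writes its slice of the single storage buffer at its own offset, which is what lets one buffer back the whole graph.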

## Op coverage

Today: `MatMul` (2D), `Op::Input`, `Op::Param`, `Op::Constant`.
Anything else fails at compile time with a clear "fall back to
CPU/Metal/MLX" diagnostic.

The roadmap is to land ops in BERT-shaped order: element-wise binary,
layer norm, softmax, attention, gather, transpose. Adding an op means:
WGSL source, a `MatmulPipeline`-style cache entry, a `Step` variant, a
dispatch in `run`. PRs welcome.
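
The "add an op" steps can be pictured with a small stand-alone sketch. The `Op`, `Step`, `compile`, and `run` definitions below are hypothetical stand-ins, not the crate's types, and the sketch returns an `Err` where the backend is described as panicking.

```rust
/// Hypothetical graph ops; only `MatMul` has a kernel in this sketch.
#[derive(Debug, Clone, PartialEq)]
enum Op {
    Input,
    Param,
    Constant,
    MatMul { m: u32, n: u32, k: u32 },
    Softmax, // stands in for any op without a kernel yet
}

/// One entry in the compiled dispatch schedule.
/// Supporting a new op would add a variant here.
#[derive(Debug, PartialEq)]
enum Step {
    MatMul { m: u32, n: u32, k: u32 },
}

/// Lower ops to steps, rejecting anything unsupported up front.
fn compile(ops: &[Op]) -> Result<Vec<Step>, String> {
    let mut steps = Vec::new();
    for op in ops {
        match op {
            // Data nodes occupy buffer space but need no dispatch.
            Op::Input | Op::Param | Op::Constant => {}
            Op::MatMul { m, n, k } => steps.push(Step::MatMul { m: *m, n: *n, k: *k }),
            other => {
                return Err(format!(
                    "rlx-wgpu: unsupported op {other:?}; fall back to CPU/Metal/MLX"
                ))
            }
        }
    }
    Ok(steps)
}

/// Walk the schedule; a real backend would record one GPU dispatch per step.
fn run(steps: &[Step]) -> Vec<String> {
    steps
        .iter()
        .map(|step| match step {
            Step::MatMul { m, n, k } => format!("dispatch matmul {m}x{k} * {k}x{n}"),
        })
        .collect()
}

fn main() {
    let graph = [Op::Input, Op::Param, Op::MatMul { m: 4, n: 8, k: 16 }];
    let steps = compile(&graph).expect("all ops supported");
    println!("{:?}", run(&steps));
    println!("{:?}", compile(&[Op::Softmax]));
}
```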

## Install

```toml
[dependencies]
rlx-wgpu = "0.2"
```

Or via [`rlx`](https://crates.io/crates/rlx)'s `gpu` feature.

## Cost-model calibration

When a wgpu adapter is available, `rlx_wgpu::calibrate::Calibration::load_or_measure()`
benchmarks a 512³ matmul and caches results at `~/.cache/rlx/wgpu-calib-<adapter>.json`.
The measurements feed `WgpuCostModel` in `rlx-runtime` for backend ranking.
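
As a rough illustration of the arithmetic behind such a calibration, the sketch below converts a timed 512³ matmul into GFLOP/s and builds a cache file name. The helper names, the 1 ms timing, and the adapter-name sanitization rule are all assumptions; the actual `Calibration` fields and file format are not shown here.

```rust
use std::path::PathBuf;

/// Floating-point operations in an (m x k) * (k x n) matmul:
/// one multiply and one add per inner-product term.
fn matmul_flops(m: u64, n: u64, k: u64) -> u64 {
    2 * m * n * k
}

/// Throughput in GFLOP/s for a measured wall time in seconds.
fn gflops(flops: u64, seconds: f64) -> f64 {
    flops as f64 / seconds / 1e9
}

/// Build a cache path of the form `<root>/rlx/wgpu-calib-<adapter>.json`.
/// Replacing non-alphanumeric characters with '-' is an assumed rule.
fn calib_path(cache_root: &str, adapter: &str) -> PathBuf {
    let safe: String = adapter
        .chars()
        .map(|c| if c.is_ascii_alphanumeric() { c } else { '-' })
        .collect();
    PathBuf::from(cache_root)
        .join("rlx")
        .join(format!("wgpu-calib-{safe}.json"))
}

fn main() {
    let flops = matmul_flops(512, 512, 512);
    // 1 ms is a made-up timing, used only to show the conversion.
    println!("{flops} FLOPs at 1 ms = {:.1} GFLOP/s", gflops(flops, 1e-3));
    println!("{}", calib_path("/home/user/.cache", "Example Adapter").display());
}
```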

## Cooperative-matrix matmul (Vulkan / DX12 / Metal)

On discrete GPUs (Vulkan, DX12), aligned f32 matmul with Param weights
auto-selects **CoopF16Vk** when the adapter reports 16×16 f16 cooperative
matrix support:

- **`matmul_coop_f16_vulkan.wgsl`** — 16×16 f16 tensor-core tiles,
  `coopLoad` on A + `coopLoadT` on B when N ≤ 768; optional widen variant uses
  `coopLoad` on B for N > 768 when `RLX_WGPU_COOP_F16_VK_LARGE_N=1`. By default
  N > 768 dispatches `matmul_wide_nv` instead (see auto wide fallback).
- **`matmul_qkv_coop_f16_vk.wgsl`** — split-write Q/K/V variant (same
  numerics, one dispatch instead of matmul + Narrow×3).
- **Auto wide fallback** — at `set_param`, an oscillation score on Param B
  selects wide f32 matmul for stress weights; **N > 768** also defaults to wide
  f32 (`matmul_wide_nv`) because the coop B-load path on RTX GPUs stays at ~4e-3 error vs f16-ref at large N.
  Opt back into coop for large N with `RLX_WGPU_COOP_F16_VK_LARGE_N=1`.
- Fallback wide path: **`matmul_wide_nv.wgsl`** (64×64 tiles).

When a matmul operand is an `Activation`, the lowering **host-mirrors** it
each `run()` (`apply_activation` + `write_f32`, same f16 shadow as Params)
and routes the matmul through **F32 wide** (`matmul_wide_nv`). CoopF16Vk
remains on Input/Param-only operands (BERT QKV weights, etc.). A GPU
`cast_f32_to_f16` pre-pass still handles non-unary computed tensors when
CoopF16Vk applies. Pass flushes between unary → cast → coop matmul keep
Vulkan/DX12 visibility correct.
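
The selection rules in this section can be summarized as a small decision function. This is a sketch only: it assumes an adapter that already reports 16×16 f16 cooperative-matrix support and an aligned shape, it takes the oscillation score as an input rather than computing it, and the order of the checks and the strict `>` comparison against the threshold are assumptions.

```rust
#[derive(Debug, Clone, Copy, PartialEq)]
enum Operand {
    Input,
    Param,
    Activation,
}

#[derive(Debug, PartialEq)]
enum Kernel {
    CoopF16Vk, // matmul_coop_f16_vulkan.wgsl
    WideF32,   // matmul_wide_nv.wgsl
}

/// Subset of the environment flags that affect the choice.
struct Flags {
    no_coop: bool,      // RLX_WGPU_NO_COOP_F16_VK=1
    no_auto_wide: bool, // RLX_WGPU_COOP_F16_VK_NO_AUTO_WIDE=1
    large_n: bool,      // RLX_WGPU_COOP_F16_VK_LARGE_N=1
    osc_thresh: f32,    // RLX_WGPU_COOP_F16_VK_OSC_THRESH
}

impl Default for Flags {
    fn default() -> Self {
        Flags { no_coop: false, no_auto_wide: false, large_n: false, osc_thresh: 0.35 }
    }
}

/// Hypothetical selector for an aligned f32 matmul on a coop-capable adapter.
fn select_kernel(a: Operand, b: Operand, n: u32, osc_score: f32, flags: &Flags) -> Kernel {
    if flags.no_coop {
        return Kernel::WideF32;
    }
    // Activation operands are host-mirrored and routed through wide f32.
    if a == Operand::Activation || b == Operand::Activation {
        return Kernel::WideF32;
    }
    // Large N defaults to wide f32 for accuracy unless explicitly opted in.
    if n > 768 && !flags.large_n {
        return Kernel::WideF32;
    }
    // Oscillating weights fall back to wide f32 unless the check is disabled.
    if !flags.no_auto_wide && osc_score > flags.osc_thresh {
        return Kernel::WideF32;
    }
    Kernel::CoopF16Vk
}

fn main() {
    let flags = Flags::default();
    println!("{:?}", select_kernel(Operand::Input, Operand::Param, 768, 0.1, &flags));
    println!("{:?}", select_kernel(Operand::Input, Operand::Param, 3072, 0.1, &flags));
}
```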

### Environment flags

| Variable | Effect |
|----------|--------|
| `WGPU_BACKEND` | `vulkan` (default Linux/Windows), `dx12` (Windows), `metal` (macOS) |
| `RLX_WGPU_NO_COOP_F16_VK=1` | Force `matmul_wide_nv` instead of CoopF16Vk |
| `RLX_WGPU_COOP_F16_VK_OSC_THRESH` | Oscillation score threshold for auto wide fallback (default `0.35`) |
| `RLX_WGPU_COOP_F16_VK_NO_AUTO_WIDE=1` | Disable oscillation-based wide fallback |
| `RLX_WGPU_COOP_F16_VK_NO_F32ACC=1` | Disable f16×f16→f32 coop accumulator kernels |
| `RLX_WGPU_COOP_F16_VK_FORCE_WIDE=1` | Always wide-fallback CoopF16Vk matmul at dispatch |
| `RLX_WGPU_COOP_F16_VK_LARGE_N=1` | Keep CoopF16Vk for N > 768 (default: wide f32 for accuracy) |
| `RLX_WGPU_COOP_F16_VK_LOAD_T=1` | Force `coopLoadT` on B when `RLX_WGPU_COOP_F16_VK_LARGE_N=1` and N > 768 |
| `RLX_WGPU_NO_COOP_F32=1` | Disable CoopF32 on Metal |
| `RLX_WGPU_FORCE_COOP_F32=1` | Opt-in CoopF32 on Vulkan (portable 8×8; correctness not validated) |
| `RLX_WGPU_FORCE_INPUT_UPLOAD=1` | Always upload inputs (disable hash-based skip) |
| `RLX_BENCH_DISPATCH_ONLY=1` | Skip output readback (micro-benchmarks) |
| `RLX_WGPU_F16_WEIGHTS=1` | Legacy f16-storage matmul experiment |
| `RLX_WGPU_SCHEDULE=1` | Log compiled dispatch schedule |
| `RLX_DISPATCH_REPORT=1` | Per-step dispatch report |
| `RLX_WGPU_NO_PACKED_BSHD_ATTN=1` | Disable packed QKV attention stride path |
| `RLX_WGPU_DUMP_NODES=1` | Per-node max-abs dump after run (debug) |
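
For readers wiring these flags into their own tooling, here is a minimal sketch of reading a boolean flag and the oscillation threshold from the environment. The helper names are hypothetical, and treating only the exact value `1` as "enabled" is an assumption; the crate's own parsing may be more lenient.

```rust
use std::env;

/// Treat a flag as enabled only when its value is exactly "1" (assumed rule).
fn flag_enabled(value: Option<&str>) -> bool {
    value == Some("1")
}

/// Parse the oscillation threshold, falling back to the documented
/// default of 0.35 when the variable is unset or not a number.
fn osc_threshold(value: Option<&str>) -> f32 {
    value
        .and_then(|v| v.trim().parse::<f32>().ok())
        .unwrap_or(0.35)
}

fn main() {
    let no_coop = flag_enabled(env::var("RLX_WGPU_NO_COOP_F16_VK").ok().as_deref());
    let thresh = osc_threshold(
        env::var("RLX_WGPU_COOP_F16_VK_OSC_THRESH").ok().as_deref(),
    );
    println!("no_coop = {no_coop}, osc_thresh = {thresh}");
}
```

Keeping the parsing in pure functions over `Option<&str>` makes the defaults easy to test without mutating the process environment.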

**EEG-DINO parity notes:** compile runs `LegalizeBroadcast` before fusion
(mid-axis `[1,C,1,D]+[1,C,P,D]` needs `Expand`) and unfuses
`ElementwiseRegion` on wgpu (region kernel only supports trailing broadcast).
`Activation::Gelu` uses exact `erf` in `unary.wgsl` (not the tanh approx).
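
To show why the choice of GELU formula matters for parity checks, the sketch below compares the erf form with the tanh approximation on the CPU. It does not reproduce `unary.wgsl`: the `erf` here is the Abramowitz-Stegun 7.1.26 polynomial (absolute error about 1.5e-7), used as a stand-in because Rust's standard library has no `erf`.

```rust
/// erf via Abramowitz-Stegun 7.1.26; a stand-in, not the shader's code.
fn erf(x: f64) -> f64 {
    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let x = x.abs();
    let t = 1.0 / (1.0 + 0.3275911 * x);
    let poly = t
        * (0.254829592
            + t * (-0.284496736
                + t * (1.421413741 + t * (-1.453152027 + t * 1.061405429))));
    sign * (1.0 - poly * (-x * x).exp())
}

/// GELU in its erf form: 0.5 * x * (1 + erf(x / sqrt(2))).
fn gelu_erf(x: f64) -> f64 {
    0.5 * x * (1.0 + erf(x / std::f64::consts::SQRT_2))
}

/// GELU tanh approximation:
/// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
fn gelu_tanh(x: f64) -> f64 {
    let c = (2.0 / std::f64::consts::PI).sqrt();
    0.5 * x * (1.0 + (c * (x + 0.044715 * x * x * x)).tanh())
}

fn main() {
    for x in [-2.0, -0.5, 0.5, 1.0, 2.0] {
        let (e, t) = (gelu_erf(x), gelu_tanh(x));
        println!("x = {x:>4}: erf form = {e:.6}, tanh form = {t:.6}, |diff| = {:.2e}", (e - t).abs());
    }
}
```

The two forms differ by a small but nonzero amount, which is enough to show up when comparing per-node outputs against a reference that uses the erf form.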

## Build / test

```sh
cargo build -p rlx-wgpu --release
cargo test  -p rlx-wgpu --release
cargo test  -p rlx-wgpu --test coop_f16_vk_correctness --release -- --nocapture
cargo test  -p rlx-wgpu --test coop_mat_probe --release -- --nocapture
# DX12-only (set WGPU_BACKEND=dx12 on Windows):
cargo test  -p rlx-wgpu --test coop_f16_vk_dx12 --release -- --nocapture
```

Through `rlx-runtime`:

```sh
cargo build -p rlx-runtime --features gpu --release
```

## Status

Functional, but less battle-tested than `rlx-metal` / `rlx-mlx` on Apple
Silicon. CoopF16Vk on Vulkan/DX12 is validated on RTX-class GPUs via
`coop_f16_vk_correctness` and `coop_f16_vk_bert_baseline` (gentle BERT
weights vs f16-ref, oscillating weights auto-wide to f32). The legacy fp32
tile matmul kernel remains correctness-first for unaligned shapes.
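
For unaligned shapes, the tile kernel's dispatch arithmetic can be mirrored on the CPU. This is a sketch under assumptions: row-major storage, workgroup x mapped to output columns, and a bounds guard for partial tiles. It mimics the one-thread-per-output, loop-over-K structure and is not a port of `matmul.wgsl`.

```rust
const TILE: usize = 8;

/// Workgroup counts (x, y): one workgroup per 8x8 output tile, rounding up
/// so partial tiles at the right and bottom edges are still covered.
fn workgroups(m: usize, n: usize) -> (usize, usize) {
    (n.div_ceil(TILE), m.div_ceil(TILE))
}

/// CPU reference for C = A * B with A (m x k), B (k x n), all row-major.
fn matmul_ref(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Vec<f32> {
    let mut c = vec![0.0f32; m * n];
    let (gx, gy) = workgroups(m, n);
    for wy in 0..gy {
        for wx in 0..gx {
            // Each (ly, lx) pair plays the role of one thread in the workgroup.
            for ly in 0..TILE {
                for lx in 0..TILE {
                    let (row, col) = (wy * TILE + ly, wx * TILE + lx);
                    if row >= m || col >= n {
                        continue; // thread falls outside an unaligned shape
                    }
                    let mut acc = 0.0f32;
                    for i in 0..k {
                        acc += a[row * k + i] * b[i * n + col];
                    }
                    c[row * n + col] = acc;
                }
            }
        }
    }
    c
}

fn main() {
    println!("workgroups for 10x20 output: {:?}", workgroups(10, 20));
    let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2 x 3
    let b = [7.0, 8.0, 9.0, 10.0, 11.0, 12.0]; // 3 x 2
    println!("{:?}", matmul_ref(&a, &b, 2, 2, 3));
}
```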

## Gotchas

- **Scalar-output latency:** End-to-end `run()` on tiny graphs is often
  readback-bound (~1.3 ms on MoltenVK; ~100–150 µs on DX12/Vulkan).
  Kernels may be sub-µs; use `RLX_BENCH_DISPATCH_ONLY=1` to time dispatch
  without readback. See [`docs/benchmarks/higher-order-ad.md`](../docs/benchmarks/higher-order-ad.md).
- **WSL:** Linux cargo on the rig uses `~/rlx-workspace-mirror/rlx` (ext4),
  not virtio `D:` directly. `rig.sh bench-nth-order` syncs new examples there.
- wgpu is async; we wrap it with `pollster::block_on` for sync semantics.
  `run()` fuses output readback copies into the final compute encoder
  submit and schedules host mapping via `CommandEncoder::map_buffer_on_submit`
  (wgpu 29+), then reuses a pooled MAP_READ staging buffer across runs.
- The matmul kernel is correctness-first. It loops over K per thread
  with no register blocking or shared-memory tiling, so it is roughly an
  order of magnitude slower than what's possible. Optimization comes after
  the op set is broad enough to run a real model.
- Shader compilation is lazy + cached via `OnceLock`. First dispatch
  pays the WGSL → SPIR-V/MSL/HLSL translation cost (~ms); subsequent
  dispatches reuse the compiled pipeline.
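
The lazy-and-cached pattern is easy to demonstrate with `std::sync::OnceLock` alone. In this sketch, `Pipeline` and `compile_matmul` are hypothetical stand-ins for the compiled wgpu pipeline and the shader translation step; the counter exists only to show that initialization runs once.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::OnceLock;

/// Stand-in for a compiled pipeline plus its bind-group layout.
struct Pipeline {
    label: &'static str,
}

/// Counts how many times the expensive "compile" step actually runs.
static COMPILES: AtomicUsize = AtomicUsize::new(0);

/// Process-wide cache slot, empty until the first dispatch needs it.
static MATMUL: OnceLock<Pipeline> = OnceLock::new();

/// Hypothetical expensive step (shader translation in the real backend).
fn compile_matmul() -> Pipeline {
    COMPILES.fetch_add(1, Ordering::SeqCst);
    Pipeline { label: "matmul" }
}

/// First call compiles; later calls return the cached pipeline.
fn matmul_pipeline() -> &'static Pipeline {
    MATMUL.get_or_init(compile_matmul)
}

fn main() {
    for i in 0..3 {
        let p = matmul_pipeline();
        println!("dispatch {i} uses pipeline '{}'", p.label);
    }
    println!("compiled {} time(s)", COMPILES.load(Ordering::SeqCst));
}
```

`get_or_init` also guarantees that concurrent first callers see a single initialized value, so the cache needs no extra locking.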

## License

GPL-3.0-only.