WGSL kernel sources + per-kernel pipeline cache.
Pipelines are content-addressed: the same WGSL source and the same entry
point yield the same pipeline. We hold them in `OnceLock`s, so on a
single device every (graph, op) pair is dispatched against a cached
compilation rather than recompiling its pipeline.
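The caching scheme described above can be sketched with `std::sync::OnceLock`. This is a minimal, backend-free sketch: `CompiledPipeline`, `compile`, `LazyKernel`, and `DEMO_WGSL` are hypothetical stand-ins (no GPU API is called), and a counter shows that the compile step runs once no matter how many dispatches ask for the pipeline.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::OnceLock;

/// Hypothetical stand-in for a compiled compute pipeline. A real backend
/// would hold a GPU pipeline object plus its bind-group layout here.
#[derive(Debug)]
struct CompiledPipeline {
    entry_point: &'static str,
    source_len: usize,
}

/// Counts how many times the mock compiler has run.
static COMPILE_COUNT: AtomicUsize = AtomicUsize::new(0);

/// Hypothetical compile step: the expensive work the cache avoids repeating.
fn compile(source: &'static str, entry_point: &'static str) -> CompiledPipeline {
    COMPILE_COUNT.fetch_add(1, Ordering::SeqCst);
    CompiledPipeline { entry_point, source_len: source.len() }
}

/// Lazy-init container: a fixed (source, entry point) key plus a slot that
/// is filled on first use and never recompiled afterwards.
struct LazyKernel {
    source: &'static str,
    entry_point: &'static str,
    pipeline: OnceLock<CompiledPipeline>,
}

impl LazyKernel {
    const fn new(source: &'static str, entry_point: &'static str) -> Self {
        LazyKernel { source, entry_point, pipeline: OnceLock::new() }
    }

    /// Returns the cached pipeline, compiling it only on the first call.
    fn get(&self) -> &CompiledPipeline {
        self.pipeline.get_or_init(|| compile(self.source, self.entry_point))
    }
}

/// Hypothetical WGSL source, standing in for a constant such as ARGMAX_WGSL.
const DEMO_WGSL: &str = "@compute @workgroup_size(64) fn main() {}";

/// One static slot per kernel: every caller shares the same compilation.
static DEMO_KERNEL: LazyKernel = LazyKernel::new(DEMO_WGSL, "main");

fn main() {
    let first = DEMO_KERNEL.get();
    let second = DEMO_KERNEL.get();
    // Both calls hand back the very same cached object.
    assert!(std::ptr::eq(first, second));
    println!(
        "entry={} source_len={} compiles={}",
        first.entry_point,
        first.source_len,
        COMPILE_COUNT.load(Ordering::SeqCst)
    );
}
```

Because the slot is a `static`, the cache is process-wide, which matches the single-device assumption stated above; a multi-device setup would need one slot per device.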
Structs
- ArgmaxParams - Layout for argmax (matches Reduce shape).
- AttentionBwdParams - Layout for `attention_bwd.wgsl`: forward strides + `dy_off` + `wrt`.
- AttentionParams - Layout for fused SDPA.
- BatchElementwiseRegionParams - FKL batch region: `batch_input_offs[slice]` + shared chain (no prologue).
- BinaryParams - Shared layout for binary and compare. 32 bytes (8 u32s).
- Conv1dParams - Layout for Conv1D NCL.
- Conv2dParams - Layout for Conv2D NCHW.
- Conv3dParams - Layout for Conv3D NCDHW.
- CopyParams - Layout shared by Reshape / Cast / generic full copy. 32 bytes.
- CumsumBwdParams
- CumsumParams - Layout for cumsum. 32 bytes.
- DequantMatmulParams - Layout for DequantMatMul. 48 bytes.
- ElementwiseRegionParams - PLAN L2: interpreted N-ary element-wise region. The chain is encoded as 4 u32s per step (`op_kind`, `op_sub`, `lhs_enc`, `rhs_enc`). Operand encoding: bit 31 = source kind (0 = Input, 1 = Step), bits 0..30 = index. `scalar_input_mask` is the per-input scalar fast-path bitfield; `input_modulus[i]` is the per-input element count for trailing-shape broadcast (0 ⇒ no broadcast, kernel reads `gid`; >0 ⇒ kernel reads `gid % input_modulus[i]`). Capped at 32 steps and 16 inputs (ample for the chains rlx produces). 12 padding bytes after `scalar_input_mask` align the next array on WGSL's 16-byte uniform alignment boundary.
- ExpandParams - Layout for Expand. Mirrors TransposeParams (rank, total, offsets); per-axis dims/strides ride in the meta storage buffer.
- FftGpuParams - Uniform block for multi-kernel FFT (`fft_gpu.wgsl::Params`). 48 bytes.
- FftParams - Layout for FFT. 32 bytes. Matches `fft.wgsl::Params`.
- FusedResidualLnParams - Layout for FusedResidualLN. 48 bytes.
- FusedResidualLnTeeParams - Layout for FusedResidualLN-Tee. 48 bytes (12 u32s).
- GatherAxisParams - Layout for gather along a non-zero axis.
- GatherBwdParams
- GatherParams - Layout for gather.
- GroupedMatmulParams - Layout for GroupedMatMul. 32 bytes.
- Kernel - Lazy-init container for a compute pipeline + its bind-group layout.
- LayerNormBwdParams - LayerNorm backward kernel params (f32 element offsets). Shared by the three entry points; the dispatcher picks `layer_norm_bwd_input`, `layer_norm_bwd_gamma_partial`, or `layer_norm_bwd_gamma_reduce` based on which Step variant fired. dbeta isn't a dedicated op: it's a plain `Reduce::Sum` over the batch dim of `dy`, handled by the general reduce kernel.
- LayerNormParams - Layout for LayerNorm / RmsNorm.
- MatmulParams
- MatmulQkvParams - Layout for matmul_qkv (split-write QKV matmul). 64 bytes (16 u32s), a multiple of WebGPU's 16-byte uniform-buffer alignment.
- NarrowConcatParams - Layout for narrow / concat (the same struct serves both).
- Pool1dParams - Layout for Pool1D NCL.
- Pool2dParams - Layout for Pool2D NCHW.
- Pool3dParams - Layout for Pool3D NCDHW.
- ReduceParams - Layout for reductions. 32 bytes.
- RmsNormBwdParams - RMSNorm backward kernel params (f32 element offsets). `wrt`: 0 = dx, 1 = dgamma, 2 = dbeta.
- RopeBwdParams
- RopeParams - Layout for Rope.
- SampleParams - Layout for Sample. 48 bytes.
- ScatterAddParams - Layout for ScatterAdd. 32 bytes (8 u32s).
- SelectiveScanParams - Layout for SelectiveScan. 64 bytes.
- SoftmaxParams - Layout for softmax. 32 bytes.
- TopKParams - Layout for TopK. 32 bytes.
- TransposeParams - Layout for transpose (uses the 3-binding bind layout).
- UmapKnnParams - Layout for UMAP k-NN on a pairwise `[n, n]` matrix. 32 bytes.
- UnaryParams - Layout for the unary kernel. 32 bytes.
- WelchPeaksGpuParams - Native GPU WelchPeaks dispatch parameters.
- WhereParams - Layout for where (3-input select). 32 bytes.
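The `ElementwiseRegionParams` operand encoding and broadcast rule can be illustrated with a CPU reference interpreter. This is a sketch under stated assumptions: the `op_kind` codes (`OP_ADD`, `OP_MUL`), the `eval_region` and `demo` helpers, and the example data are hypothetical, `op_sub` is ignored, and the `scalar_input_mask` fast path is not modeled. The bit layout (bit 31 = source kind, bits 0..30 = index) and the `gid` / `gid % input_modulus[i]` read rule follow the struct description.

```rust
/// Bit 31 selects the operand source: 0 = region input, 1 = earlier step.
const STEP_BIT: u32 = 1 << 31;
/// Bits 0..30 carry the input or step index.
const INDEX_MASK: u32 = STEP_BIT - 1;

// Hypothetical op_kind codes, chosen for this sketch only.
const OP_ADD: u32 = 0;
const OP_MUL: u32 = 1;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Operand {
    Input(u32),
    Step(u32),
}

fn encode(op: Operand) -> u32 {
    match op {
        Operand::Input(i) => {
            assert!(i <= INDEX_MASK, "index must fit in 31 bits");
            i
        }
        Operand::Step(i) => {
            assert!(i <= INDEX_MASK, "index must fit in 31 bits");
            STEP_BIT | i
        }
    }
}

fn decode(enc: u32) -> Operand {
    let index = enc & INDEX_MASK;
    if (enc & STEP_BIT) == 0 { Operand::Input(index) } else { Operand::Step(index) }
}

/// Trailing-shape broadcast: modulus 0 reads `gid`, modulus m > 0 reads `gid % m`.
fn input_index(gid: u32, modulus: u32) -> u32 {
    if modulus == 0 { gid } else { gid % modulus }
}

/// Evaluates the chain for one output element `gid`. Each step is
/// [op_kind, op_sub, lhs_enc, rhs_enc]; the last step is the region output.
fn eval_region(gid: u32, chain: &[[u32; 4]], inputs: &[&[f32]], modulus: &[u32]) -> f32 {
    let mut steps: Vec<f32> = Vec::with_capacity(chain.len());
    for &[op_kind, _op_sub, lhs_enc, rhs_enc] in chain {
        // Resolve an encoded operand against the inputs or earlier steps.
        let read = |enc: u32, done: &[f32]| -> f32 {
            match decode(enc) {
                Operand::Input(i) => {
                    let i = i as usize;
                    inputs[i][input_index(gid, modulus[i]) as usize]
                }
                Operand::Step(s) => done[s as usize],
            }
        };
        let lhs = read(lhs_enc, &steps[..]);
        let rhs = read(rhs_enc, &steps[..]);
        let value = match op_kind {
            OP_ADD => lhs + rhs,
            OP_MUL => lhs * rhs,
            other => panic!("unknown op_kind {other}"),
        };
        steps.push(value);
    }
    *steps.last().expect("chain must have at least one step")
}

/// Example region: out = (a + b) * c, with `b` broadcast over a trailing
/// shape of 2 and `c` a single broadcast element.
fn demo() -> Vec<f32> {
    let a: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
    let b: [f32; 2] = [10.0, 20.0];
    let c: [f32; 1] = [2.0];
    let inputs: [&[f32]; 3] = [&a, &b, &c];
    let modulus: [u32; 3] = [0, 2, 1];
    let chain = [
        [OP_ADD, 0, encode(Operand::Input(0)), encode(Operand::Input(1))],
        [OP_MUL, 0, encode(Operand::Step(0)), encode(Operand::Input(2))],
    ];
    (0..a.len() as u32)
        .map(|gid| eval_region(gid, &chain, &inputs, &modulus))
        .collect()
}

fn main() {
    println!("{:?}", demo()); // [22.0, 44.0, 26.0, 48.0]
}
```

Step operands can only point backwards (a step reads `done[s]` for an already-computed `s`), which is what lets the chain be evaluated in one forward pass per element.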
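The 12 padding bytes after `scalar_input_mask` follow from WGSL's uniform-buffer layout rules, which place arrays on a 16-byte boundary. The sketch below checks that arithmetic on a hypothetical `#[repr(C)]` tail struct; the field set and the packing of the 16 moduli as four `vec4<u32>` rows are assumptions, not the real `ElementwiseRegionParams` definition.

```rust
use std::mem::size_of;

/// Hypothetical tail of an ElementwiseRegionParams-style uniform block: a u32
/// bitfield, 12 bytes of padding, then 16 per-input moduli. Packing them as
/// four vec4<u32> rows (16-byte stride) is an assumption of this sketch.
#[repr(C)]
#[derive(Clone, Copy, Default)]
struct RegionTail {
    scalar_input_mask: u32,
    _pad: [u32; 3],               // 12 bytes: moves the array to offset 16
    input_modulus: [[u32; 4]; 4], // 16 u32 moduli as 4 x vec4<u32>
}

impl RegionTail {
    /// Modulus for input `i` (0..16), unpacked from the vec4 rows.
    fn modulus(&self, i: usize) -> u32 {
        self.input_modulus[i / 4][i % 4]
    }
}

/// Byte offset of `input_modulus` inside the struct, measured on an instance.
fn modulus_offset() -> usize {
    let t = RegionTail::default();
    let base = &t as *const RegionTail as usize;
    let field = &t.input_modulus as *const [[u32; 4]; 4] as usize;
    field - base
}

fn main() {
    let mut tail = RegionTail::default();
    tail.scalar_input_mask = 0b101; // hypothetical: inputs 0 and 2 flagged
    tail.input_modulus[1][2] = 64; // modulus for input 6
    println!(
        "size={} offset={} modulus(6)={} mask={:#b}",
        size_of::<RegionTail>(),
        modulus_offset(),
        tail.modulus(6),
        tail.scalar_input_mask
    );
}
```

Without `_pad`, `input_modulus` would sit at offset 4 on the Rust side while WGSL expects it at 16, so the two views of the buffer would disagree. The 32-, 48-, and 64-byte sizes quoted in this list are likewise all multiples of 16 bytes.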
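The `LayerNormBwdParams` note that dbeta needs no dedicated op can be checked with a short CPU reference: for a row-major `dy` of shape `[rows, cols]`, dbeta is just the column sum over the batch rows. `dbeta_reference` is a hypothetical helper, and treating every leading dimension as one flattened batch axis is an assumption of this sketch.

```rust
/// CPU reference for LayerNorm's dbeta as a plain sum reduction:
/// dbeta[j] = sum over rows i of dy[i][j], for a row-major [rows, cols] `dy`
/// whose last axis is the normalized feature axis (assumed here).
fn dbeta_reference(dy: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    assert!(cols > 0, "feature axis must be non-empty");
    assert_eq!(dy.len(), rows * cols, "dy must be [rows, cols]");
    let mut dbeta = vec![0.0f32; cols];
    // Accumulate each batch row into the per-feature totals.
    for row in dy.chunks_exact(cols) {
        for (acc, &g) in dbeta.iter_mut().zip(row) {
            *acc += g;
        }
    }
    dbeta
}

fn main() {
    // Two batch rows, three features.
    let dy: [f32; 6] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    println!("{:?}", dbeta_reference(&dy, 2, 3)); // [5.0, 7.0, 9.0]
}
```

Since this is exactly a sum over the batch axis, the general reduce kernel can produce it, and only dx and dgamma need the dedicated backward entry points.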
Constants
- ARGMAX_WGSL
- ATTENTION_BWD_WGSL
- ATTENTION_WGSL
- BINARY_WGSL
- CAST_F32_TO_F16_WGSL
- COMPARE_WGSL
- CONCAT_WGSL
- CONV1D_WGSL
- CONV2D_WGSL
- CONV3D_WGSL
- COOP_F16_VK_WIDEN_N - N above which coop may use the row-major B-load variant (`RLX_WGPU_COOP_F16_VK_LARGE_N`).
- COPY_WGSL
- CUMSUM_BWD_WGSL
- CUMSUM_WGSL
- DEQUANT_MATMUL_WGSL
- ELEMENTWISE_REGION_WGSL
- EXPAND_WGSL
- FFT_GPU_WGSL
- FUSED_RESIDUAL_LN_TEE_WGSL
- FUSED_RESIDUAL_LN_WGSL
- FUSED_RESIDUAL_RMS_NORM_WGSL
- GATHER_AXIS_WGSL
- GATHER_BWD_WGSL
- GATHER_WGSL
- GROUPED_MATMUL_WGSL
- LAYERNORM_WGSL
- LAYER_NORM_BWD_WGSL
- MATMUL_COOP16_WGSL
- MATMUL_COOP_F16_VULKAN_F32ACC_WGSL
- MATMUL_COOP_F16_VULKAN_WGSL
- MATMUL_COOP_F16_VULKAN_WIDEN_F32ACC_WGSL
- MATMUL_COOP_F16_VULKAN_WIDEN_WGSL
- MATMUL_COOP_F32_PORTABLE_WGSL
- MATMUL_COOP_F32_WGSL
- MATMUL_F16W_WGSL
- MATMUL_F16_COMPUTE_WGSL
- MATMUL_QKV_COOP_F16_VK_F32ACC_WGSL
- MATMUL_QKV_COOP_F16_VK_WGSL
- MATMUL_QKV_COOP_F16_VK_WIDEN_F32ACC_WGSL
- MATMUL_QKV_COOP_F16_VK_WIDEN_WGSL
- MATMUL_QKV_COOP_F32_WGSL
- MATMUL_QKV_WGSL
- MATMUL_WGSL
- MATMUL_WIDE_NV_WGSL
- MATMUL_WIDE_WGSL
- NARROW_WGSL
- POOL1D_WGSL
- POOL2D_WGSL
- POOL3D_WGSL
- REDUCE_WGSL
- RMS_NORM_BWD_WGSL
- ROPE_BWD_WGSL
- ROPE_WGSL
- SAMPLE_WGSL
- SCATTER_ADD_WGSL
- SELECTIVE_SCAN_WGSL
- SOFTMAX_WGSL
- TOPK_WGSL
- TRANSPOSE_WGSL
- UMAP_KNN_WGSL
- UNARY_F16_MIRROR_WGSL
- UNARY_WGSL
- WELCH_PEAKS_GPU_WGSL
- WHERE_WGSL
Functions
- argmax_kernel
- attention_bwd_kernel
- attention_kernel
- batch_elementwise_region_kernel
- binary_kernel
- cast_f32_to_f16_kernel - Mirrors a region of the f32 arena into the f16 shadow buffer. Used before `matmul_coop16` for the matmul's activation operand: intermediate activations don't go through `set_param`/`write_f32`, so they aren't otherwise in the f16 buffer.
- compare_kernel
- concat_kernel
- conv1d_kernel
- conv2d_kernel
- conv3d_kernel
- coop_f16_vk_f32acc_available
- coop_f16_vk_widen_b_load - Use `coopLoad` on B instead of `coopLoadT` when N > 768 and `RLX_WGPU_COOP_F16_VK_LOAD_T` is unset.
- copy_kernel
- cumsum_backward_kernel
- cumsum_kernel
- dequant_matmul_kernel
- elementwise_region_kernel
- elementwise_region_spatial_kernel
- expand_kernel
- fft_gpu_bit_reverse_kernel
- fft_gpu_inner_kernel
- fft_gpu_outer_r2_kernel
- fft_gpu_outer_r4_kernel
- fft_gpu_radix2_full_kernel
- fused_residual_ln_kernel
- fused_residual_ln_tee_kernel
- fused_residual_rms_norm_kernel
- gather_axis_kernel
- gather_backward_acc_kernel
- gather_backward_zero_kernel
- gather_kernel
- grouped_matmul_kernel
- layer_norm_backward_gamma_partial_kernel
- layer_norm_backward_gamma_reduce_kernel
- layer_norm_backward_input_kernel
- layernorm_kernel
- matmul_coop16_kernel - Cooperative-matrix matmul (8×8 tiles, hardware GEMM units). Lowers to MSL `simdgroup_matrix` on Metal and SPIR-V's `OpCooperativeMatrixMulAddKHR` on Vulkan. Returns `Some` only when the device exposes both `SHADER_F16` and `EXPERIMENTAL_COOPERATIVE_MATRIX`.
- matmul_coop_f16_vulkan_active_kernel - CoopF16Vk matmul kernel for column count `n`.
- matmul_coop_f16_vulkan_f32acc_kernel
- matmul_coop_f16_vulkan_kernel
- matmul_coop_f16_vulkan_widen_f32acc_kernel
- matmul_coop_f16_vulkan_widen_kernel
- matmul_coop_f32_active_kernel - CoopF32 kernel for the active wgpu backend (Metal simdgroup vs. Vulkan/DX12 portable).
- matmul_coop_f32_kernel - Pure-f32 cooperative-matrix matmul. No `SHADER_F16` needed: it uses `coop_mat8x8<f32>` throughout (lowered to `simdgroup_float8x8` on Apple). Returns `None` if the cooperative-matrix feature is missing, or if the device's WGSL→backend lowering can't compile it (some implementations only expose half-precision coop matrices).
- matmul_coop_f32_portable_kernel - Vulkan/DX12-oriented coop f32 matmul (`coopLoad`, 8×8 workgroups).
- matmul_f16_compute_kernel - f16-compute matmul: f16 operands, f16 multiply, f32 accumulator. Targets the 2× f16 ALU throughput on Apple Silicon. Returns `Some` only when the device exposes `SHADER_F16`.
- matmul_f16w_kernel - f16-weight matmul (f32 compute). Returns `Some` only when the device exposes the `SHADER_F16` feature. EXPERIMENTAL: currently slower than the f32 baseline on Apple Silicon, kept as a foundation; see `matmul_f16w.wgsl` for the empirical analysis.
- matmul_kernel
- matmul_qkv_coop_f16_vk_active_kernel
- matmul_qkv_coop_f16_vk_f32acc_kernel
- matmul_qkv_coop_f16_vk_kernel
- matmul_qkv_coop_f16_vk_widen_f32acc_kernel
- matmul_qkv_coop_f16_vk_widen_kernel
- matmul_qkv_coop_f32_kernel
- matmul_qkv_kernel
- matmul_wide_active_kernel - Wide f32 matmul kernel for the active backend.
- matmul_wide_kernel
- matmul_wide_nv_kernel - 64×64 / 256-thread variant for discrete GPUs (Vulkan path).
- narrow_kernel
- pool1d_kernel
- pool2d_kernel
- pool3d_kernel
- reduce_kernel
- rms_norm_backward_kernel
- rms_norm_backward_param_kernel
- rope_backward_kernel
- rope_kernel
- sample_kernel
- scatter_add_kernel
- selective_scan_kernel
- softmax_kernel
- topk_kernel
- transpose_kernel
- umap_knn_kernel
- unary_f16_mirror_kernel
- unary_kernel
- welch_peaks_gpu_kernel
- where_kernel
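The B-load choice documented for `coop_f16_vk_widen_b_load` reduces to a small predicate. In this sketch the threshold 768 comes from that function's description; that it equals `COOP_F16_VK_WIDEN_N` is an assumption, as is treating any value of `RLX_WGPU_COOP_F16_VK_LOAD_T` (even an empty one) as "set". The role of `RLX_WGPU_COOP_F16_VK_LARGE_N` is not modeled. The helper names `widen_b_load_decision` and `widen_b_load` are hypothetical.

```rust
use std::env;

/// Column-count threshold from the coop_f16_vk_widen_b_load docs (N > 768).
/// Assumed to match COOP_F16_VK_WIDEN_N; the source does not state its value.
const WIDEN_N: u32 = 768;

/// Env var named in the docs; when present, the transposed coopLoadT path is kept.
const LOAD_T_VAR: &str = "RLX_WGPU_COOP_F16_VK_LOAD_T";

/// Pure decision: load B row-major with `coopLoad` for wide outputs unless
/// the transposed `coopLoadT` path has been forced.
fn widen_b_load_decision(n: u32, load_t_forced: bool) -> bool {
    n > WIDEN_N && !load_t_forced
}

/// Env-reading wrapper. "Unset" is modeled as "absent"; an empty value
/// counts as set here, which may differ from the real check.
fn widen_b_load(n: u32) -> bool {
    let load_t_forced = env::var_os(LOAD_T_VAR).is_some();
    widen_b_load_decision(n, load_t_forced)
}

fn main() {
    for n in [512, 768, 1024, 4096] {
        println!("n={n}: coopLoad on B = {}", widen_b_load(n));
    }
}
```

Splitting the pure predicate from the environment lookup keeps the threshold logic testable without mutating process-wide state.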
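Several kernel getters return `Option` depending on device features and backend. The sketch below mirrors the documented gates for `matmul_coop16_kernel` (both `SHADER_F16` and `EXPERIMENTAL_COOPERATIVE_MATRIX` required) and the backend split for `matmul_coop_f32_active_kernel` (Metal simdgroup vs. Vulkan/DX12 portable). The bit constants and the `Backend` and `Variant` enums are hypothetical stand-ins rather than wgpu types, returning `None` on other backends is an assumption, and the documented compile-failure fallback is not modeled.

```rust
/// Hypothetical feature bits standing in for the device features the docs
/// name; these are plain u32 flags, not wgpu types.
const SHADER_F16: u32 = 1 << 0;
const EXPERIMENTAL_COOPERATIVE_MATRIX: u32 = 1 << 1;

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Backend {
    Metal,
    Vulkan,
    Dx12,
    Gl,
}

/// Which matmul variant would be chosen (labels only; no real pipelines).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Variant {
    Coop16,
    CoopF32Simdgroup,
    CoopF32Portable,
}

/// Mirrors matmul_coop16_kernel's documented gate: both features required.
fn coop16(features: u32) -> Option<Variant> {
    let needed = SHADER_F16 | EXPERIMENTAL_COOPERATIVE_MATRIX;
    ((features & needed) == needed).then_some(Variant::Coop16)
}

/// Mirrors matmul_coop_f32_active_kernel's documented backend split.
/// SHADER_F16 is not required for the pure-f32 path.
fn coop_f32_active(features: u32, backend: Backend) -> Option<Variant> {
    if (features & EXPERIMENTAL_COOPERATIVE_MATRIX) == 0 {
        return None;
    }
    match backend {
        Backend::Metal => Some(Variant::CoopF32Simdgroup),
        Backend::Vulkan | Backend::Dx12 => Some(Variant::CoopF32Portable),
        Backend::Gl => None, // assumption: no coop kernel elsewhere
    }
}

fn main() {
    let coop_only = EXPERIMENTAL_COOPERATIVE_MATRIX;
    println!("{:?}", coop16(coop_only)); // None: SHADER_F16 missing
    println!("{:?}", coop_f32_active(coop_only, Backend::Vulkan));
}
```

Returning `Option` lets a caller fall back to the plain `matmul_kernel` path whenever the device cannot run a cooperative-matrix variant.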
Type Aliases
- FusedResidualRmsNormParams - Layout for FusedResidualRmsNorm (same bind layout as FusedResidualLN).