Module kernels

WGSL kernel sources + per-kernel pipeline cache.

Pipelines are content-addressed: same WGSL source + same entry point yields the same pipeline. We hold them in OnceLocks so a single device dispatches every (graph, op) pair against a cached compilation.
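The lazy caching described above can be sketched with `std::sync::OnceLock`. This is a minimal illustration, not the module's actual code: `CompiledPipeline`, `LazyKernel`, and the `compiles` counter are hypothetical stand-ins, and the "compilation" here only records metadata instead of calling into a GPU API.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::OnceLock;

// Hypothetical stand-in for a compiled compute pipeline. The real type
// presumably wraps a GPU pipeline plus its bind-group layout.
#[derive(Debug)]
struct CompiledPipeline {
    entry_point: &'static str,
    source_len: usize,
}

// Hypothetical lazy-init container: one WGSL source + entry point,
// compiled at most once on first use.
struct LazyKernel {
    source: &'static str,
    entry_point: &'static str,
    cell: OnceLock<CompiledPipeline>,
}

impl LazyKernel {
    const fn new(source: &'static str, entry_point: &'static str) -> Self {
        Self { source, entry_point, cell: OnceLock::new() }
    }

    // Returns the cached pipeline, "compiling" it on the first call only.
    // `compiles` exists purely so the caching behaviour can be observed.
    fn get(&self, compiles: &AtomicUsize) -> &CompiledPipeline {
        self.cell.get_or_init(|| {
            compiles.fetch_add(1, Ordering::Relaxed);
            CompiledPipeline {
                entry_point: self.entry_point,
                source_len: self.source.len(),
            }
        })
    }
}

// One static per kernel: every dispatch shares the same cached compilation.
static UNARY: LazyKernel =
    LazyKernel::new("@compute @workgroup_size(64) fn main() {}", "main");
```

Calling `UNARY.get(..)` repeatedly returns a reference to the same value, and the initializer closure runs only once, even when several threads race on the first call.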

Structs§

ArgmaxParams
Layout for argmax (matches Reduce shape).
AttentionBwdParams
Layout for attention_bwd.wgsl: forward strides + dy_off + wrt.
AttentionParams
Layout for fused SDPA.
BatchElementwiseRegionParams
FKL batch region: batch_input_offs[slice] + shared chain (no prologue).
BinaryParams
Shared layout for binary, compare. 32 bytes (8 u32s).
Conv1dParams
Layout for Conv1D NCL.
Conv2dParams
Layout for Conv2D NCHW.
Conv3dParams
Layout for Conv3D NCDHW.
CopyParams
Layout shared by Reshape / Cast / generic full copy. 32 bytes.
CumsumBwdParams
CumsumParams
Layout for cumsum. 32 bytes.
DequantMatmulParams
Layout for DequantMatMul. 48 bytes.
ElementwiseRegionParams
PLAN L2: interpreted N-ary element-wise region. Chain encoded as 4 u32s per step (op_kind, op_sub, lhs_enc, rhs_enc). Operand encoding: bit 31 = src kind (0=Input, 1=Step), bits 0..30 = index. scalar_input_mask is the per-input scalar fast-path bitfield; input_modulus[i] is the per-input element count for trailing-shape broadcast (0 ⇒ no broadcast, kernel reads gid; >0 ⇒ kernel reads gid % input_modulus[i]). Fixed cap at 32 steps + 16 inputs (ample for chains rlx produces). 12 padding bytes after scalar_input_mask align the next array on WGSL’s 16-byte uniform alignment boundary.
ExpandParams
Layout for Expand. Mirrors TransposeParams (rank, total, offsets); per-axis dims/strides ride in the meta storage buffer.
FftGpuParams
Uniform block for multi-kernel FFT (fft_gpu.wgsl::Params). 48 bytes.
FftParams
Layout for FFT. 32 bytes. Matches fft.wgsl::Params.
FusedResidualLnParams
Layout for FusedResidualLN. 48 bytes.
FusedResidualLnTeeParams
Layout for FusedResidualLN-Tee. 48 bytes (12 u32s).
GatherAxisParams
Layout for gather along a non-zero axis.
GatherBwdParams
GatherParams
Layout for gather.
GroupedMatmulParams
Layout for GroupedMatMul. 32 bytes.
Kernel
Lazy-init container for a compute pipeline + its bind-group layout.
LayerNormBwdParams
LayerNorm backward kernel params (f32 element offsets). Shared by the three entry points; the dispatcher picks layer_norm_bwd_input, layer_norm_bwd_gamma_partial, or layer_norm_bwd_gamma_reduce based on which Step variant fired. dbeta isn’t a dedicated op — it’s a plain Reduce::Sum over the batch dim of dy, handled by the general reduce kernel.
LayerNormParams
Layout for LayerNorm / RmsNorm.
MatmulParams
MatmulQkvParams
Layout for matmul_qkv (split-write QKV matmul). 64 bytes (16 u32s), a multiple of WebGPU’s 16-byte uniform-buffer alignment.
NarrowConcatParams
Layout for narrow / concat (the same struct serves both).
Pool1dParams
Layout for Pool1D NCL.
Pool2dParams
Layout for Pool2D NCHW.
Pool3dParams
Layout for Pool3D NCDHW.
ReduceParams
Layout for reductions. 32 bytes.
RmsNormBwdParams
RMSNorm backward kernel params (f32 element offsets). wrt: 0=dx, 1=dgamma, 2=dbeta.
RopeBwdParams
RopeParams
Layout for Rope.
SampleParams
Layout for Sample. 48 bytes.
ScatterAddParams
Layout for ScatterAdd. 32 bytes (8 u32s).
SelectiveScanParams
Layout for SelectiveScan. 64 bytes.
SoftmaxParams
Layout for softmax. 32 bytes.
TopKParams
Layout for TopK. 32 bytes.
TransposeParams
Layout for transpose (uses the 3-binding bind layout).
UmapKnnParams
Layout for UMAP k-NN on a pairwise [n, n] matrix. 32 bytes.
UnaryParams
Layout for unary kernel. 32 bytes.
WelchPeaksGpuParams
Native GPU WelchPeaks dispatch parameters.
WhereParams
Layout for where (3-input select). 32 bytes.
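The operand encoding and broadcast rule described for ElementwiseRegionParams can be sketched on the host side as follows. The bit layout (bit 31 = source kind, bits 0..30 = index) and the modulus rule come from the description above; the names `Operand`, `encode`, `decode`, and `read_index` are hypothetical and chosen only for this illustration.

```rust
// Where a step operand comes from: one of the region's inputs, or the
// result of an earlier step in the chain.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Operand {
    Input(u32),
    Step(u32),
}

const KIND_BIT: u32 = 1 << 31; // bit 31: 0 = Input, 1 = Step
const INDEX_MASK: u32 = KIND_BIT - 1; // bits 0..30: index

// Pack an operand into a single u32 (the lhs_enc / rhs_enc slots).
fn encode(op: Operand) -> u32 {
    match op {
        Operand::Input(i) => i & INDEX_MASK,
        Operand::Step(i) => KIND_BIT | (i & INDEX_MASK),
    }
}

// Inverse of `encode`.
fn decode(enc: u32) -> Operand {
    let index = enc & INDEX_MASK;
    if enc & KIND_BIT != 0 {
        Operand::Step(index)
    } else {
        Operand::Input(index)
    }
}

// Element index an input is read at for invocation `gid`:
// modulus 0 means no broadcast, otherwise the read wraps at `modulus`.
fn read_index(gid: u32, modulus: u32) -> u32 {
    if modulus == 0 { gid } else { gid % modulus }
}
```

For example, `encode(Operand::Step(3))` sets the top bit and yields `0x8000_0003`, while an input with `input_modulus` of 4 is read at element 3 for `gid` 7.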

Constants§

ARGMAX_WGSL
ATTENTION_BWD_WGSL
ATTENTION_WGSL
BINARY_WGSL
CAST_F32_TO_F16_WGSL
COMPARE_WGSL
CONCAT_WGSL
CONV1D_WGSL
CONV2D_WGSL
CONV3D_WGSL
COOP_F16_VK_WIDEN_N
N above which coop may use the row-major B-load variant (RLX_WGPU_COOP_F16_VK_LARGE_N).
COPY_WGSL
CUMSUM_BWD_WGSL
CUMSUM_WGSL
DEQUANT_MATMUL_WGSL
ELEMENTWISE_REGION_WGSL
EXPAND_WGSL
FFT_GPU_WGSL
FUSED_RESIDUAL_LN_TEE_WGSL
FUSED_RESIDUAL_LN_WGSL
FUSED_RESIDUAL_RMS_NORM_WGSL
GATHER_AXIS_WGSL
GATHER_BWD_WGSL
GATHER_WGSL
GROUPED_MATMUL_WGSL
LAYERNORM_WGSL
LAYER_NORM_BWD_WGSL
MATMUL_COOP16_WGSL
MATMUL_COOP_F16_VULKAN_F32ACC_WGSL
MATMUL_COOP_F16_VULKAN_WGSL
MATMUL_COOP_F16_VULKAN_WIDEN_F32ACC_WGSL
MATMUL_COOP_F16_VULKAN_WIDEN_WGSL
MATMUL_COOP_F32_PORTABLE_WGSL
MATMUL_COOP_F32_WGSL
MATMUL_F16W_WGSL
MATMUL_F16_COMPUTE_WGSL
MATMUL_QKV_COOP_F16_VK_F32ACC_WGSL
MATMUL_QKV_COOP_F16_VK_WGSL
MATMUL_QKV_COOP_F16_VK_WIDEN_F32ACC_WGSL
MATMUL_QKV_COOP_F16_VK_WIDEN_WGSL
MATMUL_QKV_COOP_F32_WGSL
MATMUL_QKV_WGSL
MATMUL_WGSL
MATMUL_WIDE_NV_WGSL
MATMUL_WIDE_WGSL
NARROW_WGSL
POOL1D_WGSL
POOL2D_WGSL
POOL3D_WGSL
REDUCE_WGSL
RMS_NORM_BWD_WGSL
ROPE_BWD_WGSL
ROPE_WGSL
SAMPLE_WGSL
SCATTER_ADD_WGSL
SELECTIVE_SCAN_WGSL
SOFTMAX_WGSL
TOPK_WGSL
TRANSPOSE_WGSL
UMAP_KNN_WGSL
UNARY_F16_MIRROR_WGSL
UNARY_WGSL
WELCH_PEAKS_GPU_WGSL
WHERE_WGSL
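The threshold documented for COOP_F16_VK_WIDEN_N, together with the rule stated for coop_f16_vk_widen_b_load, suggests a selection along these lines. This is a sketch under assumptions: the value 768 is taken from the function description and may not match the real constant, and `BLoad`, `choose_b_load`, and `choose_b_load_from_env` are hypothetical names.

```rust
use std::ffi::OsStr;

// Assumed threshold; the function docs mention "N > 768".
const WIDEN_N: u32 = 768;

// Which cooperative load is used for the B operand.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum BLoad {
    CoopLoad,  // row-major B load (the "widen" variant)
    CoopLoadT, // transposed B load
}

// Pure decision function: the override is passed in so it can be tested
// without touching the process environment.
fn choose_b_load(n: u32, load_t_override: Option<&OsStr>) -> BLoad {
    if n > WIDEN_N && load_t_override.is_none() {
        BLoad::CoopLoad
    } else {
        BLoad::CoopLoadT
    }
}

// Thin wrapper that reads the override from the environment variable
// named in the docs.
fn choose_b_load_from_env(n: u32) -> BLoad {
    let value = std::env::var_os("RLX_WGPU_COOP_F16_VK_LOAD_T");
    choose_b_load(n, value.as_deref())
}
```

Keeping the environment lookup in a separate wrapper keeps the decision itself deterministic and easy to unit-test.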

Functions§

argmax_kernel
attention_bwd_kernel
attention_kernel
batch_elementwise_region_kernel
binary_kernel
cast_f32_to_f16_kernel
Mirrors a region of the f32 arena into the f16 shadow buffer. Used before matmul_coop16 for the matmul’s activation operand (intermediate activations don’t go through set_param / write_f32, so they aren’t in the f16 buffer otherwise).
compare_kernel
concat_kernel
conv1d_kernel
conv2d_kernel
conv3d_kernel
coop_f16_vk_f32acc_available
coop_f16_vk_widen_b_load
Use coopLoad on B instead of coopLoadT when N > 768 and RLX_WGPU_COOP_F16_VK_LOAD_T is unset.
copy_kernel
cumsum_backward_kernel
cumsum_kernel
dequant_matmul_kernel
elementwise_region_kernel
elementwise_region_spatial_kernel
expand_kernel
fft_gpu_bit_reverse_kernel
fft_gpu_inner_kernel
fft_gpu_outer_r2_kernel
fft_gpu_outer_r4_kernel
fft_gpu_radix2_full_kernel
fused_residual_ln_kernel
fused_residual_ln_tee_kernel
fused_residual_rms_norm_kernel
gather_axis_kernel
gather_backward_acc_kernel
gather_backward_zero_kernel
gather_kernel
grouped_matmul_kernel
layer_norm_backward_gamma_partial_kernel
layer_norm_backward_gamma_reduce_kernel
layer_norm_backward_input_kernel
layernorm_kernel
matmul_coop16_kernel
Cooperative-matrix matmul (8×8 tiles, hardware GEMM units). Lowers to MSL simdgroup_matrix on Metal and SPIR-V’s OpCooperativeMatrixMulAddKHR on Vulkan. Returns Some only when the device exposes both SHADER_F16 and EXPERIMENTAL_COOPERATIVE_MATRIX.
matmul_coop_f16_vulkan_active_kernel
Matmul CoopF16Vk kernel for column count n.
matmul_coop_f16_vulkan_f32acc_kernel
matmul_coop_f16_vulkan_kernel
matmul_coop_f16_vulkan_widen_f32acc_kernel
matmul_coop_f16_vulkan_widen_kernel
matmul_coop_f32_active_kernel
CoopF32 kernel for the active wgpu backend (Metal simdgroup vs Vulkan/DX12 portable).
matmul_coop_f32_kernel
Pure-f32 cooperative-matrix matmul. No SHADER_F16 needed — uses coop_mat8x8<f32> throughout (lowers to simdgroup_float8x8 on Apple). Returns None if the cooperative-matrix feature is missing OR if the device’s WGSL→backend lowering can’t compile it (some implementations only expose half-precision coop matrices).
matmul_coop_f32_portable_kernel
Vulkan/DX12-oriented coop f32 matmul (coopLoad, 8×8 workgroups).
matmul_f16_compute_kernel
f16-compute matmul: f16 operands, f16 multiply, f32 accumulator. Targets the 2× f16 ALU throughput on Apple Silicon. Returns Some only when the device exposes SHADER_F16.
matmul_f16w_kernel
f16-weight matmul (f32 compute). Returns Some only when the device exposes the SHADER_F16 feature. EXPERIMENTAL: currently slower than the f32 baseline on Apple Silicon — kept as foundation; see matmul_f16w.wgsl for the empirical analysis.
matmul_kernel
matmul_qkv_coop_f16_vk_active_kernel
matmul_qkv_coop_f16_vk_f32acc_kernel
matmul_qkv_coop_f16_vk_kernel
matmul_qkv_coop_f16_vk_widen_f32acc_kernel
matmul_qkv_coop_f16_vk_widen_kernel
matmul_qkv_coop_f32_kernel
matmul_qkv_kernel
matmul_wide_active_kernel
Wide f32 matmul kernel for the active backend.
matmul_wide_kernel
matmul_wide_nv_kernel
64×64 / 256-thread variant for discrete GPUs (Vulkan path).
narrow_kernel
pool1d_kernel
pool2d_kernel
pool3d_kernel
reduce_kernel
rms_norm_backward_kernel
rms_norm_backward_param_kernel
rope_backward_kernel
rope_kernel
sample_kernel
scatter_add_kernel
selective_scan_kernel
softmax_kernel
topk_kernel
transpose_kernel
umap_knn_kernel
unary_f16_mirror_kernel
unary_kernel
welch_peaks_gpu_kernel
where_kernel
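Several of the matmul constructors above return Some only when the device exposes the required features, so a caller needs a fallback chain. The following sketch shows one way to express that with `Option` combinators. The `Caps` struct, the `MatmulPath` enum, and the priority order are all assumptions made for illustration; the module's real dispatcher may order or gate these paths differently (for instance, the docs note that matmul_coop_f32_kernel can also return None when backend lowering fails, which is not modelled here).

```rust
// Hypothetical capability flags standing in for the device features
// named in the docs (SHADER_F16, EXPERIMENTAL_COOPERATIVE_MATRIX).
#[derive(Debug, Clone, Copy)]
struct Caps {
    shader_f16: bool,
    coop_matrix: bool,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MatmulPath {
    Coop16,
    CoopF32,
    F16Compute,
    PlainF32,
}

// Each constructor mirrors the documented gating: Some only when the
// required features are present.
fn coop16(c: Caps) -> Option<MatmulPath> {
    (c.shader_f16 && c.coop_matrix).then_some(MatmulPath::Coop16)
}

fn coop_f32(c: Caps) -> Option<MatmulPath> {
    c.coop_matrix.then_some(MatmulPath::CoopF32)
}

fn f16_compute(c: Caps) -> Option<MatmulPath> {
    c.shader_f16.then_some(MatmulPath::F16Compute)
}

// Assumed priority: most specialised path first, plain f32 as the
// always-available fallback.
fn pick_matmul(c: Caps) -> MatmulPath {
    coop16(c)
        .or_else(|| coop_f32(c))
        .or_else(|| f16_compute(c))
        .unwrap_or(MatmulPath::PlainF32)
}
```

With this shape, adding another gated variant is a matter of inserting one more `or_else` link at the desired priority.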

Type Aliases§

FusedResidualRmsNormParams
Layout for FusedResidualRmsNorm (same bind layout as FusedResidualLN).