Skip to main content

rlx_wgpu/kernels/
mod.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! WGSL kernel sources + per-kernel pipeline cache.
17//!
18//! Pipelines are content-addressed: same WGSL source + same entry
19//! point yields the same pipeline. We hold them in `OnceLock`s so a
20//! single device dispatches every (graph, op) pair against a cached
21//! compilation.
22
23use std::sync::OnceLock;
24
25use bytemuck::{Pod, Zeroable};
26
// WGSL shader sources, embedded into the binary at compile time.
// `include_str!` resolves each path relative to this source file, so
// every `.wgsl` file named below must sit next to `mod.rs`.

// --- Matmul family (plain, wide, f16, cooperative-matrix, fused QKV) ---
pub const MATMUL_WGSL: &str = include_str!("matmul.wgsl");
pub const MATMUL_WIDE_WGSL: &str = include_str!("matmul_wide.wgsl");
pub const MATMUL_WIDE_NV_WGSL: &str = include_str!("matmul_wide_nv.wgsl");
pub const MATMUL_F16W_WGSL: &str = include_str!("matmul_f16w.wgsl");
pub const MATMUL_F16_COMPUTE_WGSL: &str = include_str!("matmul_f16_compute.wgsl");
pub const MATMUL_COOP16_WGSL: &str = include_str!("matmul_coop16.wgsl");
pub const MATMUL_COOP_F32_WGSL: &str = include_str!("matmul_coop_f32.wgsl");
pub const MATMUL_COOP_F32_PORTABLE_WGSL: &str = include_str!("matmul_coop_f32_portable.wgsl");
pub const MATMUL_COOP_F16_VULKAN_WGSL: &str = include_str!("matmul_coop_f16_vulkan.wgsl");
pub const MATMUL_COOP_F16_VULKAN_WIDEN_WGSL: &str =
    include_str!("matmul_coop_f16_vulkan_widen.wgsl");
pub const MATMUL_COOP_F16_VULKAN_F32ACC_WGSL: &str =
    include_str!("matmul_coop_f16_vulkan_f32acc.wgsl");
pub const MATMUL_COOP_F16_VULKAN_WIDEN_F32ACC_WGSL: &str =
    include_str!("matmul_coop_f16_vulkan_widen_f32acc.wgsl");
pub const MATMUL_QKV_COOP_F16_VK_WGSL: &str = include_str!("matmul_qkv_coop_f16_vk.wgsl");
pub const MATMUL_QKV_COOP_F16_VK_WIDEN_WGSL: &str =
    include_str!("matmul_qkv_coop_f16_vk_widen.wgsl");
pub const MATMUL_QKV_COOP_F16_VK_F32ACC_WGSL: &str =
    include_str!("matmul_qkv_coop_f16_vk_f32acc.wgsl");
pub const MATMUL_QKV_COOP_F16_VK_WIDEN_F32ACC_WGSL: &str =
    include_str!("matmul_qkv_coop_f16_vk_widen_f32acc.wgsl");

// --- All other kernels (one constant per `.wgsl` file) ---
pub const CAST_F32_TO_F16_WGSL: &str = include_str!("cast_f32_to_f16.wgsl");
pub const BINARY_WGSL: &str = include_str!("binary.wgsl");
pub const UNARY_WGSL: &str = include_str!("unary.wgsl");
pub const UNARY_F16_MIRROR_WGSL: &str = include_str!("unary_f16_mirror.wgsl");
pub const COMPARE_WGSL: &str = include_str!("compare.wgsl");
pub const WHERE_WGSL: &str = include_str!("where.wgsl");
pub const REDUCE_WGSL: &str = include_str!("reduce.wgsl");
pub const SOFTMAX_WGSL: &str = include_str!("softmax.wgsl");
pub const LAYERNORM_WGSL: &str = include_str!("layernorm.wgsl");
pub const RMS_NORM_BWD_WGSL: &str = include_str!("rms_norm_backward.wgsl");
pub const LAYER_NORM_BWD_WGSL: &str = include_str!("layer_norm_backward.wgsl");
pub const CUMSUM_BWD_WGSL: &str = include_str!("cumsum_backward.wgsl");
pub const ROPE_BWD_WGSL: &str = include_str!("rope_backward.wgsl");
pub const GATHER_BWD_WGSL: &str = include_str!("gather_backward.wgsl");
pub const CUMSUM_WGSL: &str = include_str!("cumsum.wgsl");
pub const FFT_GPU_WGSL: &str = include_str!("fft_gpu.wgsl");
pub const COPY_WGSL: &str = include_str!("copy.wgsl");
pub const ELEMENTWISE_REGION_WGSL: &str = include_str!("elementwise_region.wgsl");
pub const TRANSPOSE_WGSL: &str = include_str!("transpose.wgsl");
pub const NARROW_WGSL: &str = include_str!("narrow.wgsl");
pub const CONCAT_WGSL: &str = include_str!("concat.wgsl");
pub const GATHER_WGSL: &str = include_str!("gather.wgsl");
pub const GATHER_AXIS_WGSL: &str = include_str!("gather_axis.wgsl");
pub const ATTENTION_WGSL: &str = include_str!("attention.wgsl");
pub const ATTENTION_BWD_WGSL: &str = include_str!("attention_bwd.wgsl");
pub const ROPE_WGSL: &str = include_str!("rope.wgsl");
pub const EXPAND_WGSL: &str = include_str!("expand.wgsl");
pub const ARGMAX_WGSL: &str = include_str!("argmax.wgsl");
pub const POOL2D_WGSL: &str = include_str!("pool2d.wgsl");
pub const CONV2D_WGSL: &str = include_str!("conv2d.wgsl");
pub const POOL1D_WGSL: &str = include_str!("pool1d.wgsl");
pub const POOL3D_WGSL: &str = include_str!("pool3d.wgsl");
pub const CONV1D_WGSL: &str = include_str!("conv1d.wgsl");
pub const CONV3D_WGSL: &str = include_str!("conv3d.wgsl");
pub const SCATTER_ADD_WGSL: &str = include_str!("scatter_add.wgsl");
pub const TOPK_WGSL: &str = include_str!("topk.wgsl");
pub const WELCH_PEAKS_GPU_WGSL: &str = include_str!("welch_peaks_gpu.wgsl");
pub const UMAP_KNN_WGSL: &str = include_str!("umap_knn.wgsl");
pub const GROUPED_MATMUL_WGSL: &str = include_str!("grouped_matmul.wgsl");
pub const SAMPLE_WGSL: &str = include_str!("sample.wgsl");
pub const SELECTIVE_SCAN_WGSL: &str = include_str!("selective_scan.wgsl");
pub const DEQUANT_MATMUL_WGSL: &str = include_str!("dequant_matmul.wgsl");
pub const FUSED_RESIDUAL_LN_WGSL: &str = include_str!("fused_residual_ln.wgsl");
pub const FUSED_RESIDUAL_LN_TEE_WGSL: &str = include_str!("fused_residual_ln_tee.wgsl");
pub const FUSED_RESIDUAL_RMS_NORM_WGSL: &str = include_str!("fused_residual_rms_norm.wgsl");
pub const MATMUL_QKV_WGSL: &str = include_str!("matmul_qkv.wgsl");
pub const MATMUL_QKV_COOP_F32_WGSL: &str = include_str!("matmul_qkv_coop_f32.wgsl");
96
/// Parameter block for the matmul kernels. 64 bytes (16 u32s).
///
/// `#[repr(C)]` with only `u32` fields, so the in-memory layout is the
/// fields in declaration order with no implicit padding; the three
/// trailing `_pad*` words round the block up to 64 bytes.
///
/// NOTE(review): `m`/`k`/`n` presumably follow the usual GEMM naming
/// and the `*_off` / `*_batch_stride` fields presumably locate the A,
/// B, C (and bias) operands — confirm units against the WGSL `Params`.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct MatmulParams {
    pub m: u32,
    pub k: u32,
    pub n: u32,
    pub a_off: u32,
    pub b_off: u32,
    pub c_off: u32,
    pub batch: u32,
    pub a_batch_stride: u32,
    pub b_batch_stride: u32,
    pub c_batch_stride: u32,
    pub has_bias: u32,
    pub bias_off: u32,
    pub act_id: u32, // 0xFFFF = no activation
    pub _pad0: u32,
    pub _pad1: u32,
    pub _pad2: u32,
}
117
/// Shared layout for binary, compare. 32 bytes (8 u32s).
///
/// Five meaningful words (`n`, three offsets, `op`) followed by three
/// `_p*` padding words that bring the block to 32 bytes.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct BinaryParams {
    pub n: u32,
    pub a_off: u32,
    pub b_off: u32,
    pub c_off: u32,
    pub op: u32,
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
}
131
/// Layout for unary kernel. 32 bytes (8 u32s).
///
/// Four meaningful words (`n`, input/output offsets, `op`) followed by
/// four `_p*` padding words.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct UnaryParams {
    pub n: u32,
    pub in_off: u32,
    pub out_off: u32,
    pub op: u32,
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
    pub _p3: u32,
}
145
/// Layout for where (3-input select). 32 bytes (8 u32s).
///
/// Carries one offset per input (`cond`, `x`, `y`) plus the output
/// offset; the three `_p*` words are padding.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct WhereParams {
    pub n: u32,
    pub cond_off: u32,
    pub x_off: u32,
    pub y_off: u32,
    pub out_off: u32,
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
}
159
160/// Layout for reductions. 32 bytes.
161///
162/// Supports arbitrary-axis reductions. The reduce kernel walks the
163/// input as a 3D tensor `[outer, reduce_dim, inner]` where:
164///   * `outer` = product of dims BEFORE the reduce axis
165///   * `reduce_dim` = the reduce axis itself
166///   * `inner` = product of dims AFTER the reduce axis (=1 for the
167///     last-axis case, which is what the v3 dispatcher emitted).
168/// Output shape is `[outer, inner]` (or with the reduce axis kept as 1
169/// when `keep_dim`; the dispatcher handles the shape arithmetic).
170#[repr(C)]
171pub struct ReduceParams {
172    pub outer: u32,
173    pub reduce_dim: u32,
174    pub inner: u32,
175    pub in_off: u32,
176    pub out_off: u32,
177    pub op: u32,
178    pub _p0: u32,
179    pub _p1: u32,
180}
181
182// Manual impls to avoid issues with structural derives if any field
183// arrangement subtly trips bytemuck.
184unsafe impl Pod for ReduceParams {}
185unsafe impl Zeroable for ReduceParams {}
186impl Copy for ReduceParams {}
187impl Clone for ReduceParams {
188    fn clone(&self) -> Self {
189        *self
190    }
191}
192impl std::fmt::Debug for ReduceParams {
193    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194        write!(
195            f,
196            "ReduceParams {{ outer: {}, reduce_dim: {}, inner: {}, op: {} }}",
197            self.outer, self.reduce_dim, self.inner, self.op
198        )
199    }
200}
201
/// Layout for softmax. 32 bytes (8 u32s).
///
/// Four meaningful words (`outer`, `inner`, input/output offsets)
/// followed by four `_p*` padding words.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct SoftmaxParams {
    pub outer: u32,
    pub inner: u32,
    pub in_off: u32,
    pub out_off: u32,
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
    pub _p3: u32,
}
215
/// Layout for LayerNorm / RmsNorm. 32 bytes (8 u32s), no padding words.
///
/// One struct serves both normalizations; `op` selects which one the
/// kernel runs. `eps` is transported as its raw `u32` bit pattern
/// because every field of the block is a `u32`.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct LayerNormParams {
    pub outer: u32,
    pub inner: u32,
    pub in_off: u32,
    pub out_off: u32,
    pub gamma_off: u32,
    pub beta_off: u32,
    pub eps_bits: u32, // bitcast::<u32>(eps)
    pub op: u32,       // 0=LayerNorm, 1=RmsNorm
}
229
/// LayerNorm backward kernel params (f32 element offsets). Shared by
/// the three entry points; the dispatcher picks `layer_norm_bwd_input`,
/// `layer_norm_bwd_gamma_partial`, or `layer_norm_bwd_gamma_reduce`
/// based on which Step variant fired. dbeta isn't a dedicated op — it's
/// a plain `Reduce::Sum` over the batch dim of `dy`, handled by the
/// general reduce kernel.
///
/// `scratch_off` is the f32-element offset of the tail scratch zone
/// (only used by the gamma partial/reduce kernels). For the reduce
/// kernel `outer` carries the number of partial chunks emitted by the
/// partial kernel.
///
/// 32 bytes (8 u32s), no padding words.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct LayerNormBwdParams {
    pub outer: u32,
    pub inner: u32,
    pub x_off: u32,
    pub gamma_off: u32,
    pub dy_off: u32,
    pub out_off: u32,
    pub eps_bits: u32,
    pub scratch_off: u32,
}
253
/// RMSNorm backward kernel params (f32 element offsets). `wrt`: 0=dx, 1=dgamma, 2=dbeta.
///
/// 36 bytes (9 u32s); unlike most blocks in this file there are no
/// padding words, so the size is not a multiple of 16.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct RmsNormBwdParams {
    pub outer: u32,
    pub inner: u32,
    pub x_off: u32,
    pub gamma_off: u32,
    pub beta_off: u32,
    pub dy_off: u32,
    pub out_off: u32,
    pub eps_bits: u32,
    pub wrt: u32,
}
268
/// Parameter block for the cumsum backward kernel. 32 bytes (8 u32s).
///
/// Same field shape as `CumsumParams`, with `dy_off` / `dx_off` in
/// place of the forward input/output offsets; the three `_p*` words
/// are padding.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct CumsumBwdParams {
    pub outer: u32,
    pub inner: u32,
    pub dy_off: u32,
    pub dx_off: u32,
    pub exclusive: u32,
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
}
281
/// Parameter block for the RoPE backward kernel. 40 bytes (10 u32s),
/// no padding words.
///
/// NOTE(review): `n_rot` and `cos_len` are not described anywhere in
/// this file — presumably the rotated width and the length of the
/// cos/sin table; confirm against `rope_backward.wgsl`.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct RopeBwdParams {
    pub batch: u32,
    pub seq: u32,
    pub hidden: u32,
    pub head_dim: u32,
    pub n_rot: u32,
    pub dy_off: u32,
    pub cos_off: u32,
    pub sin_off: u32,
    pub dx_off: u32,
    pub cos_len: u32,
}
296
/// Parameter block for the gather backward kernel. 32 bytes (8 u32s).
///
/// Uses the same `outer` / `axis_dim` / `num_idx` / `trailing`
/// decomposition as `GatherAxisParams`; `_p0` is a single padding word.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct GatherBwdParams {
    pub outer: u32,
    pub axis_dim: u32,
    pub num_idx: u32,
    pub trailing: u32,
    pub dy_off: u32,
    pub idx_off: u32,
    pub dst_off: u32,
    pub _p0: u32,
}
309
/// Layout for cumsum. 32 bytes (8 u32s).
///
/// `exclusive` is a `u32` flag (the block has no `bool` fields); the
/// three `_p*` words are padding.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct CumsumParams {
    pub outer: u32,
    pub inner: u32,
    pub in_off: u32,
    pub out_off: u32,
    pub exclusive: u32,
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
}
323
/// Layout for FFT. 32 bytes (eight 4-byte fields). Matches `fft.wgsl::Params`.
///
/// `norm_scale` is the only `f32` field; everything else is `u32`.
/// `_p1` / `_p2` are padding words.
///
/// NOTE(review): no `fft.wgsl` source constant is declared in the
/// visible part of this file (only `fft_gpu.wgsl`) — confirm this
/// struct is still in use.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct FftParams {
    pub src_off: u32,
    pub dst_off: u32,
    pub n: u32,
    pub log2n: u32,
    pub inverse: u32,
    pub norm_scale: f32,
    pub _p1: u32,
    pub _p2: u32,
}
337
/// Uniform block for multi-kernel FFT (`fft_gpu.wgsl::Params`).
///
/// The ten 4-byte fields below occupy 40 bytes.
///
/// NOTE(review): this comment previously claimed 48 bytes, which does
/// not match the field list. Confirm whether the WGSL `Params` struct
/// carries two extra padding words (in which case this struct is
/// missing them) or whether the size note was simply stale.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct FftGpuParams {
    pub off: u32,
    pub dst_off: u32,
    pub n: u32,
    pub log2n: u32,
    pub inverse: u32,
    pub norm_scale: f32,
    pub outer: u32,
    pub tile: u32,
    pub inner_stages: u32,
    pub q_or_hs: u32,
}
353
/// PLAN L2 — interpreted N-ary element-wise region. Chain encoded
/// as 4 u32s per step (op_kind, op_sub, lhs_enc, rhs_enc). Operand
/// encoding: bit 31 = src kind (0=Input, 1=Step), bits 0..30 = index.
/// `scalar_input_mask` is the per-input scalar fast-path bitfield;
/// `input_modulus[i]` is the per-input element count for trailing-
/// shape broadcast (`0` ⇒ no broadcast, kernel reads gid; `>0` ⇒
/// kernel reads `gid % input_modulus[i]`). Fixed cap at 32 steps +
/// 16 inputs (ample for chains rlx produces).
///
/// Total size: 171 u32s = 684 bytes.
///
/// NOTE(review): an earlier revision of this comment said 12 padding
/// bytes follow `scalar_input_mask` so that the next array starts on
/// WGSL's 16-byte uniform alignment boundary. That no longer describes
/// the struct: six real `u32` fields follow `scalar_input_mask`, which
/// places `input_modulus` at byte offset 620 (not a multiple of 16).
/// Confirm that the `Params` struct in `elementwise_region.wgsl` uses
/// the same offsets.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct ElementwiseRegionParams {
    pub len: u32,
    pub num_inputs: u32,
    pub num_steps: u32,
    pub dst_off: u32,
    pub input_offs: [u32; 16],
    pub chain: [u32; 128], // 32 steps * 4 u32s
    pub scalar_input_mask: u32,
    pub prologue: u32,
    pub out_n: u32,
    pub out_c: u32,
    pub out_h: u32,
    pub out_w: u32,
    pub prologue_input: u32,
    pub input_modulus: [u32; 16],
}
382
/// FKL batch region: `batch_input_offs[slice]` + shared chain (no prologue).
///
/// Total size: 214 u32s = 856 bytes. The arrays give fixed caps of 64
/// entries for `batch_input_offs`, 128 words for `chain`, and 16
/// entries for `input_modulus`.
///
/// NOTE(review): five scalar words precede `batch_input_offs`, so the
/// arrays do not start on 16-byte boundaries — confirm the WGSL side
/// declares the same offsets.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct BatchElementwiseRegionParams {
    pub slice_len: u32,
    pub num_batch: u32,
    pub num_steps: u32,
    pub base_dst_off: u32,
    pub slice_elems: u32,
    pub batch_input_offs: [u32; 64],
    pub chain: [u32; 128],
    pub scalar_input_mask: u32,
    pub input_modulus: [u32; 16],
}
397
/// Layout shared by Reshape / Cast / generic full copy. 32 bytes (8 u32s).
///
/// Only `n` and the two offsets carry data; the five `_p*` words are
/// padding.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct CopyParams {
    pub n: u32,
    pub in_off: u32,
    pub out_off: u32,
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
    pub _p3: u32,
    pub _p4: u32,
}
411
/// Layout for transpose (uses the 3-binding bind layout).
/// 32 bytes (8 u32s); `_p2` / `_p3` are padding words.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct TransposeParams {
    pub rank: u32,
    pub out_total: u32,
    pub in_off: u32,
    pub out_off: u32,
    /// PLAN L1 — precomputed at compile time. `1` when `perm[0] == 0`
    /// (= bucket axis stays at output axis 0). Active-extent path
    /// scales `out_total` proportionally only when this is `1`.
    pub bucket_outermost: u32,
    /// PLAN L1 — `out_dims[0]` for active-extent scaling math.
    pub out_dim_0: u32,
    pub _p2: u32,
    pub _p3: u32,
}
429
/// Layout for narrow / concat (the same struct serves both).
/// 32 bytes (8 u32s), no padding words.
///
/// The meaning of `total` depends on which kernel consumes the block;
/// see the inline note on the field.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct NarrowConcatParams {
    pub total: u32, // total elements (output for narrow, input for concat)
    pub outer: u32,
    pub inner: u32,
    pub axis_in_size: u32,
    pub axis_out_size: u32,
    pub start: u32,
    pub in_off: u32,
    pub out_off: u32,
}
443
/// Layout for gather. 32 bytes (8 u32s); `_p0` is a padding word.
///
/// NOTE(review): `dim` / `vocab` presumably describe the gathered
/// table (row width and row count) — confirm against `gather.wgsl`.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct GatherParams {
    pub n_out: u32,
    pub n_idx: u32,
    pub dim: u32,
    pub vocab: u32,
    pub in_off: u32,
    pub idx_off: u32,
    pub out_off: u32,
    pub _p0: u32,
}
457
/// Layout for gather along a non-zero axis.
/// 32 bytes (8 u32s), no padding words.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct GatherAxisParams {
    pub total: u32,
    pub outer: u32,
    pub axis_dim: u32,
    pub num_idx: u32,
    pub trailing: u32,
    pub table_off: u32,
    pub idx_off: u32,
    pub out_off: u32,
}
471
/// Layout for fused SDPA.
///
/// Per-tensor (Q, K, V, output) strides are passed explicitly so the
/// kernel can read either canonical [B, H, S, D] or transposed
/// [B, S, H, D] without inserting upstream Transpose dispatches. The
/// layout-elimination saves ~24 transpose dispatches per BERT-L6
/// forward (one per Q/K/V/output × layers), each ~50µs at small batch.
///
/// The `seq_q_stride` / `seq_k_stride` fields are retained because
/// they describe the MASK layout `[B, H, S_q, S_k]` (separate from
/// Q/K/V layout), used by `MaskKind::Custom`.
///
/// 144 bytes (36 u32s); WebGPU uniform-buffer 16-byte alignment OK.
///
/// Field groups, in order: 13 scalar/offset words, 4 mask strides plus
/// 3 mask padding words (20 words so far), then four 4-word groups
/// (stride triple + one padding word) for Q, K, V, and the output.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct AttentionParams {
    pub batch: u32,
    pub heads: u32,
    pub seq_q: u32,
    pub seq_k: u32,
    pub head_dim: u32,
    pub q_off: u32,
    pub k_off: u32,
    pub v_off: u32,
    pub out_off: u32,
    pub mask_off: u32,
    pub mask_kind: u32,
    pub scale_bits: u32,
    pub window: u32,
    /// MASK address strides. Mask address math (per-element):
    ///   addr = mask_off
    ///        + b  * mask_batch_stride
    ///        + h  * mask_head_stride
    ///        + qi * seq_q_stride         (per-query stride)
    ///        + s  * seq_k_stride         (per-key   stride)
    /// Setting some strides to 0 lets the kernel read a *broadcast*
    /// mask without materializing the broadcast. e.g. BERT padding mask
    /// `[B, S]`: mask_batch_stride=S, mask_head_stride=0, seq_q_stride=0,
    /// seq_k_stride=1. Saves the Expand pre-pass that unfuse used to
    /// emit per attention block.
    pub seq_q_stride: u32,
    pub seq_k_stride: u32,
    pub mask_batch_stride: u32,
    pub mask_head_stride: u32,
    pub _pad_mask_0: u32,
    pub _pad_mask_1: u32,
    pub _pad_mask_2: u32,

    // Q stride triple (in f32 elements). For [B, H, S, D]:
    //   q_batch_stride = H·S·D, q_head_stride = S·D, q_seq_stride = D
    // For [B, S, H, D]:
    //   q_batch_stride = S·H·D, q_head_stride = D,   q_seq_stride = H·D
    pub q_batch_stride: u32,
    pub q_head_stride: u32,
    pub q_seq_stride: u32,
    pub _pad_q: u32,

    // K stride triple; same field pattern as the Q triple above.
    pub k_batch_stride: u32,
    pub k_head_stride: u32,
    pub k_seq_stride: u32,
    pub _pad_k: u32,

    // V stride triple; same field pattern as the Q triple above.
    pub v_batch_stride: u32,
    pub v_head_stride: u32,
    pub v_seq_stride: u32,
    pub _pad_v: u32,

    // Output stride triple; same field pattern as the Q triple above.
    pub o_batch_stride: u32,
    pub o_head_stride: u32,
    pub o_seq_stride: u32,
    pub _pad_o: u32,
}
544
/// Layout for `attention_bwd.wgsl` — forward strides + `dy_off` + `wrt`.
///
/// Field-for-field the same as `AttentionParams`, with two extra words:
/// `dy_off` (inserted after `v_off`) and `wrt` (inserted after
/// `window`). 152 bytes (38 u32s).
///
/// Because of the two extra words, the `_pad_*` fields no longer place
/// the Q/K/V/O stride groups on 16-byte boundaries the way they do in
/// `AttentionParams`; they are kept so the two structs stay parallel.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct AttentionBwdParams {
    pub batch: u32,
    pub heads: u32,
    pub seq_q: u32,
    pub seq_k: u32,
    pub head_dim: u32,
    pub q_off: u32,
    pub k_off: u32,
    pub v_off: u32,
    pub dy_off: u32,
    pub out_off: u32,
    pub mask_off: u32,
    pub mask_kind: u32,
    pub scale_bits: u32,
    pub window: u32,
    pub wrt: u32,
    // Mask strides + padding.
    pub seq_q_stride: u32,
    pub seq_k_stride: u32,
    pub mask_batch_stride: u32,
    pub mask_head_stride: u32,
    pub _pad_mask_0: u32,
    pub _pad_mask_1: u32,
    pub _pad_mask_2: u32,
    // Q stride triple + padding.
    pub q_batch_stride: u32,
    pub q_head_stride: u32,
    pub q_seq_stride: u32,
    pub _pad_q: u32,
    // K stride triple + padding.
    pub k_batch_stride: u32,
    pub k_head_stride: u32,
    pub k_seq_stride: u32,
    pub _pad_k: u32,
    // V stride triple + padding.
    pub v_batch_stride: u32,
    pub v_head_stride: u32,
    pub v_seq_stride: u32,
    pub _pad_v: u32,
    // Output stride triple + padding.
    pub o_batch_stride: u32,
    pub o_head_stride: u32,
    pub o_seq_stride: u32,
    pub _pad_o: u32,
}
588
/// Layout for Rope. 48 bytes (12 u32s); `_p2` is a padding word.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct RopeParams {
    pub n_total: u32,
    pub seq: u32,
    pub head_dim: u32,
    pub half: u32,
    pub in_off: u32,
    pub cos_off: u32,
    pub sin_off: u32,
    pub out_off: u32,
    pub last_dim: u32,
    /// PLAN L1 — set at compile time. Together with `seq_stride`,
    /// lets the WGSL kernel decompose iteration index into
    /// `(bi, si, d)` while indexing into the underlying full-extent
    /// buffer. `n_total` is the runtime-scaled iteration bound;
    /// `seq_stride` is the compile-time-fixed full seq for stride.
    pub batch: u32,
    pub seq_stride: u32,
    pub _p2: u32,
}
611
/// Layout for Expand. Mirrors TransposeParams (rank, total, offsets);
/// per-axis dims/strides ride in the meta storage buffer.
///
/// 32 bytes (8 u32s); `_p2` / `_p3` are padding words. The field list
/// is identical to `TransposeParams`.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct ExpandParams {
    pub rank: u32,
    pub out_total: u32,
    pub in_off: u32,
    pub out_off: u32,
    /// PLAN L1 — precomputed at compile time. `1` when the bucket
    /// axis stays at output axis 0 after the expand mapping.
    pub bucket_outermost: u32,
    /// PLAN L1 — `out_dims[0]` for active-extent scaling math.
    pub out_dim_0: u32,
    pub _p2: u32,
    pub _p3: u32,
}
629
/// Layout for argmax (matches Reduce shape). 32 bytes (8 u32s).
///
/// NOTE(review): the field list (`outer`, `inner`, two offsets, four
/// padding words) is the same as `SoftmaxParams`, not `ReduceParams`
/// (which also has `reduce_dim` and `op`) — "matches Reduce shape"
/// presumably refers to the tensor decomposition rather than the
/// struct layout.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct ArgmaxParams {
    pub outer: u32,
    pub inner: u32,
    pub in_off: u32,
    pub out_off: u32,
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
    pub _p3: u32,
}
643
/// Layout for Pool2D NCHW. 72 bytes (18 u32s, the last three padding).
///
/// NOTE(review): `k*` / `s*` / `p*` presumably are kernel size, stride
/// and padding per spatial axis, and `op` selects the pooling kind —
/// confirm against `pool2d.wgsl`.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct Pool2dParams {
    pub n: u32,
    pub c: u32,
    pub h: u32,
    pub w: u32,
    pub h_out: u32,
    pub w_out: u32,
    pub kh: u32,
    pub kw: u32,
    pub sh: u32,
    pub sw: u32,
    pub ph: u32,
    pub pw: u32,
    pub op: u32,
    pub in_off: u32,
    pub out_off: u32,
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
}
667
/// Layout for Conv2D NCHW. 76 bytes (19 u32s), no padding words.
///
/// NOTE(review): `k*` / `s*` / `p*` / `d*` presumably are kernel size,
/// stride, padding and dilation per spatial axis — confirm against
/// `conv2d.wgsl`.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct Conv2dParams {
    pub n: u32,
    pub c_in: u32,
    pub c_out: u32,
    pub h: u32,
    pub w: u32,
    pub h_out: u32,
    pub w_out: u32,
    pub kh: u32,
    pub kw: u32,
    pub sh: u32,
    pub sw: u32,
    pub ph: u32,
    pub pw: u32,
    pub dh: u32,
    pub dw: u32,
    pub groups: u32,
    pub in_off: u32,
    pub w_off: u32,
    pub out_off: u32,
}
692
/// Layout for Pool1D NCL. 64 bytes (16 u32s, the last six padding).
///
/// One-dimensional counterpart of `Pool2dParams`: a single
/// `kl` / `sl` / `pl` triple replaces the per-axis `k*` / `s*` / `p*`
/// pairs.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct Pool1dParams {
    pub n: u32,
    pub c: u32,
    pub l: u32,
    pub l_out: u32,
    pub kl: u32,
    pub sl: u32,
    pub pl: u32,
    pub op: u32,
    pub in_off: u32,
    pub out_off: u32,
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
    pub _p3: u32,
    pub _p4: u32,
    pub _p5: u32,
}
714
/// Layout for Pool3D NCDHW. 88 bytes (22 u32s, the last two padding).
///
/// Three-dimensional counterpart of `Pool2dParams`: every per-axis
/// group gains a depth (`d`) entry ahead of the `h` / `w` entries.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct Pool3dParams {
    pub n: u32,
    pub c: u32,
    pub d: u32,
    pub h: u32,
    pub w: u32,
    pub d_out: u32,
    pub h_out: u32,
    pub w_out: u32,
    pub kd: u32,
    pub kh: u32,
    pub kw: u32,
    pub sd: u32,
    pub sh: u32,
    pub sw: u32,
    pub pd: u32,
    pub ph: u32,
    pub pw: u32,
    pub op: u32,
    pub in_off: u32,
    pub out_off: u32,
    pub _p0: u32,
    pub _p1: u32,
}
742
/// Layout for Conv1D NCL. 64 bytes (16 u32s, the last three padding).
///
/// One-dimensional counterpart of `Conv2dParams`: a single
/// `kl` / `sl` / `pl` / `dl` group replaces the per-axis pairs.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct Conv1dParams {
    pub n: u32,
    pub c_in: u32,
    pub c_out: u32,
    pub l: u32,
    pub l_out: u32,
    pub kl: u32,
    pub sl: u32,
    pub pl: u32,
    pub dl: u32,
    pub groups: u32,
    pub in_off: u32,
    pub w_off: u32,
    pub out_off: u32,
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
}
764
/// Layout for DequantMatMul. 48 bytes (12 u32s, the last two padding).
///
/// Besides the `m` / `k` / `n` dimensions the block carries the
/// quantization `block_size` and `scheme_id`, plus one offset each for
/// the activations (`x`), weights (`w`), scales, zero points (`zp`)
/// and the output.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct DequantMatmulParams {
    pub m: u32,
    pub k: u32,
    pub n: u32,
    pub block_size: u32,
    pub scheme_id: u32,
    pub x_off: u32,
    pub w_off: u32,
    pub scale_off: u32,
    pub zp_off: u32,
    pub out_off: u32,
    pub _p0: u32,
    pub _p1: u32,
}
782
/// Layout for FusedResidualLN-Tee. 48 bytes (12 u32s).
///
/// Differs from `FusedResidualLnParams` by having two output offsets
/// (`sum_off` and `ln_out_off`) instead of a single `out_off`, which
/// leaves room for only one padding word (`_p0`).
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct FusedResidualLnTeeParams {
    pub outer: u32,
    pub inner: u32,
    pub in_off: u32,
    pub residual_off: u32,
    pub bias_off: u32,
    pub gamma_off: u32,
    pub beta_off: u32,
    pub sum_off: u32,
    pub ln_out_off: u32,
    pub eps_bits: u32,
    pub has_bias: u32,
    pub _p0: u32,
}
800
/// Layout for matmul_qkv (split-write QKV matmul).
/// 64 bytes (16 u32s); WebGPU uniform-buffer 16-byte alignment OK.
///
/// Instead of a single output offset there is one per destination
/// (`q_off`, `k_off`, `v_off`). Note that `k` is the matmul dimension
/// while `k_off` is the K-output offset. The five `_p*` words are
/// padding.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct MatmulQkvParams {
    pub m: u32,
    pub k: u32,
    pub n: u32,
    pub a_off: u32,
    pub b_off: u32,
    pub q_off: u32,
    pub k_off: u32,
    pub v_off: u32,
    pub head_width: u32,
    pub has_bias: u32,
    pub bias_off: u32,
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
    pub _p3: u32,
    pub _p4: u32,
}
823
/// Layout for FusedResidualRmsNorm (same bind layout as FusedResidualLN).
///
/// A plain type alias, not a distinct type: the two names are fully
/// interchangeable and share one 48-byte layout.
pub type FusedResidualRmsNormParams = FusedResidualLnParams;
826
/// Layout for FusedResidualLN. 48 bytes (12 u32s, the last two padding).
///
/// Also used, through the `FusedResidualRmsNormParams` alias, for the
/// fused residual RMS-norm kernel.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct FusedResidualLnParams {
    pub outer: u32,
    pub inner: u32,
    pub in_off: u32,
    pub residual_off: u32,
    pub bias_off: u32,
    pub gamma_off: u32,
    pub beta_off: u32,
    pub out_off: u32,
    pub eps_bits: u32,
    pub has_bias: u32,
    pub _p0: u32,
    pub _p1: u32,
}
844
/// Layout for SelectiveScan. 64 bytes (16 u32s).
///
/// Eleven meaningful words followed by five padding words, which are
/// numbered `_p1`..`_p5` (there is no `_p0`).
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct SelectiveScanParams {
    pub batch: u32,
    pub seq: u32,
    pub hidden: u32,
    pub state_size: u32,
    pub x_off: u32,
    pub delta_off: u32,
    pub a_off: u32,
    pub b_off: u32,
    pub c_off: u32,
    pub out_off: u32,
    /// PLAN L1 — full-extent seq stride for per-batch offset math.
    /// Stays at compile-time `seq` even when runtime `seq` is scaled,
    /// so per-batch arena offsets stay correct under active-extent.
    pub seq_stride: u32,
    pub _p1: u32,
    pub _p2: u32,
    pub _p3: u32,
    pub _p4: u32,
    pub _p5: u32,
}
869
/// Layout for Sample. 48 bytes (12 u32s, the last three padding).
///
/// Every field is a `u32`, so the `*_bits` fields hold raw 32-bit
/// patterns and the seed is split across `seed_lo` / `seed_hi`.
///
/// NOTE(review): `top_p_bits` / `temp_bits` presumably are f32 bit
/// patterns (as with `eps_bits` in `LayerNormParams`) — confirm
/// against `sample.wgsl`.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct SampleParams {
    pub outer: u32,
    pub inner: u32,
    pub in_off: u32,
    pub out_off: u32,
    pub top_k: u32,
    pub top_p_bits: u32,
    pub temp_bits: u32,
    pub seed_lo: u32,
    pub seed_hi: u32,
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
}
887
/// Layout for GroupedMatMul. 32 bytes (8 u32s), no padding words.
///
/// In addition to the `m` / `k` / `n` dimensions and the
/// input / weight / output offsets, the block carries `num_experts`
/// and an index-buffer offset (`idx_off`).
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct GroupedMatmulParams {
    pub m: u32,
    pub k: u32,
    pub n: u32,
    pub num_experts: u32,
    pub in_off: u32,
    pub w_off: u32,
    pub idx_off: u32,
    pub out_off: u32,
}
901
/// Layout for TopK. 32 bytes (8 u32s, the last three padding).
///
/// Same `outer` / `inner` / offset shape as `SoftmaxParams`, plus the
/// selection count `k`.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct TopKParams {
    pub outer: u32,
    pub inner: u32,
    pub k: u32,
    pub in_off: u32,
    pub out_off: u32,
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
}
915
/// Native GPU WelchPeaks dispatch parameters.
///
/// 36 bytes (9 u32s, the last two padding), so the size is not a
/// multiple of 16.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct WelchPeaksGpuParams {
    pub spec_off: u32,
    pub dst_off: u32,
    pub welch_batch: u32,
    pub n_fft: u32,
    pub n_segments: u32,
    pub k: u32,
    pub n_bins: u32,
    pub _p0: u32,
    pub _p1: u32,
}
930
/// Layout for UMAP k-NN on a pairwise `[n, n]` matrix.
///
/// The seven `u32` fields below occupy 28 bytes.
///
/// NOTE(review): this comment previously claimed 32 bytes, which does
/// not match the field list. Either a fourth padding word (`_p3`) is
/// missing or the size note was stale — confirm against the `Params`
/// struct in `umap_knn.wgsl`.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct UmapKnnParams {
    pub n: u32,
    pub k: u32,
    pub pw_off: u32,
    pub out_off: u32,
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
}
943
/// Layout for ScatterAdd. 32 bytes (8 u32s), no padding words.
///
/// The same block drives two phases, selected by `op` (see the inline
/// note on the field).
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct ScatterAddParams {
    pub op: u32, // 0 = zero phase, 1 = accumulate phase
    pub out_off: u32,
    pub upd_off: u32,
    pub idx_off: u32,
    pub out_total: u32,
    pub num_updates: u32,
    pub trailing: u32,
    pub out_dim: u32,
}
957
/// Layout for Conv3D NCDHW. 104 bytes (26 u32s, the last one padding).
///
/// Three-dimensional counterpart of `Conv2dParams`: every per-axis
/// group (`k*`, `s*`, `p*`, `d*`) gains a depth entry ahead of the
/// `h` / `w` entries. Note that `d` is the input depth while
/// `dd` / `dh` / `dw` form the last per-axis group.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct Conv3dParams {
    pub n: u32,
    pub c_in: u32,
    pub c_out: u32,
    pub d: u32,
    pub h: u32,
    pub w: u32,
    pub d_out: u32,
    pub h_out: u32,
    pub w_out: u32,
    pub kd: u32,
    pub kh: u32,
    pub kw: u32,
    pub sd: u32,
    pub sh: u32,
    pub sw: u32,
    pub pd: u32,
    pub ph: u32,
    pub pw: u32,
    pub dd: u32,
    pub dh: u32,
    pub dw: u32,
    pub groups: u32,
    pub in_off: u32,
    pub w_off: u32,
    pub out_off: u32,
    pub _p0: u32,
}
989
/// Lazy-init container for a compute pipeline + its bind-group layout.
///
/// The two are kept together because bind groups for `pipeline` are
/// created against `bgl`.
pub struct Kernel {
    /// The compiled compute pipeline.
    pub pipeline: wgpu::ComputePipeline,
    /// Bind-group layout used when creating bind groups for `pipeline`.
    pub bgl: wgpu::BindGroupLayout,
}
995
996impl Kernel {
997    pub fn bind_two(
998        &self,
999        device: &wgpu::Device,
1000        arena: &wgpu::Buffer,
1001        uniform: &wgpu::Buffer,
1002    ) -> wgpu::BindGroup {
1003        device.create_bind_group(&wgpu::BindGroupDescriptor {
1004            label: Some("rlx-wgpu fft gpu bg"),
1005            layout: &self.bgl,
1006            entries: &[
1007                wgpu::BindGroupEntry {
1008                    binding: 0,
1009                    resource: arena.as_entire_binding(),
1010                },
1011                wgpu::BindGroupEntry {
1012                    binding: 1,
1013                    resource: uniform.as_entire_binding(),
1014                },
1015            ],
1016        })
1017    }
1018}
1019
1020/// Build a 4-binding compute kernel: storage(rw) / uniform / storage(ro)
1021/// / storage(ro). Currently unused — `matmul_coop16` switched to a
1022/// 3-binding layout (A is staged from arena through workgroup memory
1023/// instead of from a separate f16 binding). Kept for future kernels
1024/// that genuinely need a 4th binding.
1025#[allow(dead_code)]
/// Historical note: this layout was originally used by the
/// cooperative-matrix matmul, which at the time needed a fourth
/// binding for the f16 activation shadow buffer.
1028fn build_kernel_4(
1029    device: &wgpu::Device,
1030    label: &'static str,
1031    wgsl: &str,
1032    entry_point: &'static str,
1033) -> Kernel {
1034    let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
1035        label: Some(label),
1036        source: wgpu::ShaderSource::Wgsl(wgsl.into()),
1037    });
1038    let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
1039        label: Some(label),
1040        entries: &[
1041            wgpu::BindGroupLayoutEntry {
1042                binding: 0,
1043                visibility: wgpu::ShaderStages::COMPUTE,
1044                ty: wgpu::BindingType::Buffer {
1045                    ty: wgpu::BufferBindingType::Storage { read_only: false },
1046                    has_dynamic_offset: false,
1047                    min_binding_size: None,
1048                },
1049                count: None,
1050            },
1051            wgpu::BindGroupLayoutEntry {
1052                binding: 1,
1053                visibility: wgpu::ShaderStages::COMPUTE,
1054                ty: wgpu::BindingType::Buffer {
1055                    ty: wgpu::BufferBindingType::Uniform,
1056                    has_dynamic_offset: false,
1057                    min_binding_size: None,
1058                },
1059                count: None,
1060            },
1061            wgpu::BindGroupLayoutEntry {
1062                binding: 2,
1063                visibility: wgpu::ShaderStages::COMPUTE,
1064                ty: wgpu::BindingType::Buffer {
1065                    ty: wgpu::BufferBindingType::Storage { read_only: true },
1066                    has_dynamic_offset: false,
1067                    min_binding_size: None,
1068                },
1069                count: None,
1070            },
1071            wgpu::BindGroupLayoutEntry {
1072                binding: 3,
1073                visibility: wgpu::ShaderStages::COMPUTE,
1074                ty: wgpu::BindingType::Buffer {
1075                    ty: wgpu::BufferBindingType::Storage { read_only: true },
1076                    has_dynamic_offset: false,
1077                    min_binding_size: None,
1078                },
1079                count: None,
1080            },
1081        ],
1082    });
1083    let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
1084        label: Some(label),
1085        bind_group_layouts: &[Some(&bgl)],
1086        immediate_size: 0,
1087    });
1088    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
1089        label: Some(label),
1090        layout: Some(&layout),
1091        module: &module,
1092        entry_point: Some(entry_point),
1093        compilation_options: Default::default(),
1094        cache: None,
1095    });
1096    Kernel { pipeline, bgl }
1097}
1098
1099fn build_kernel_3(
1100    device: &wgpu::Device,
1101    label: &'static str,
1102    wgsl: &str,
1103    entry_point: &'static str,
1104) -> Kernel {
1105    let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
1106        label: Some(label),
1107        source: wgpu::ShaderSource::Wgsl(wgsl.into()),
1108    });
1109    let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
1110        label: Some(label),
1111        entries: &[
1112            wgpu::BindGroupLayoutEntry {
1113                binding: 0,
1114                visibility: wgpu::ShaderStages::COMPUTE,
1115                ty: wgpu::BindingType::Buffer {
1116                    ty: wgpu::BufferBindingType::Storage { read_only: false },
1117                    has_dynamic_offset: false,
1118                    min_binding_size: None,
1119                },
1120                count: None,
1121            },
1122            wgpu::BindGroupLayoutEntry {
1123                binding: 1,
1124                visibility: wgpu::ShaderStages::COMPUTE,
1125                ty: wgpu::BindingType::Buffer {
1126                    ty: wgpu::BufferBindingType::Uniform,
1127                    has_dynamic_offset: false,
1128                    min_binding_size: None,
1129                },
1130                count: None,
1131            },
1132            wgpu::BindGroupLayoutEntry {
1133                binding: 2,
1134                visibility: wgpu::ShaderStages::COMPUTE,
1135                ty: wgpu::BindingType::Buffer {
1136                    ty: wgpu::BufferBindingType::Storage { read_only: true },
1137                    has_dynamic_offset: false,
1138                    min_binding_size: None,
1139                },
1140                count: None,
1141            },
1142        ],
1143    });
1144    let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
1145        label: Some(label),
1146        bind_group_layouts: &[Some(&bgl)],
1147        immediate_size: 0,
1148    });
1149    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
1150        label: Some(label),
1151        layout: Some(&layout),
1152        module: &module,
1153        entry_point: Some(entry_point),
1154        compilation_options: Default::default(),
1155        cache: None,
1156    });
1157    Kernel { pipeline, bgl }
1158}
1159
1160/// f16 shadow (rw) + uniform + f32 arena (rw) — `cast_f32_to_f16` only.
1161/// Separate from `build_kernel_3`: cast reads f32 written by a prior unary in
1162/// the same arena; other 3-binding kernels keep binding 2 read-only.
1163fn build_kernel_cast_f32_to_f16(
1164    device: &wgpu::Device,
1165    label: &'static str,
1166    wgsl: &str,
1167    entry_point: &'static str,
1168) -> Kernel {
1169    let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
1170        label: Some(label),
1171        source: wgpu::ShaderSource::Wgsl(wgsl.into()),
1172    });
1173    let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
1174        label: Some(label),
1175        entries: &[
1176            wgpu::BindGroupLayoutEntry {
1177                binding: 0,
1178                visibility: wgpu::ShaderStages::COMPUTE,
1179                ty: wgpu::BindingType::Buffer {
1180                    ty: wgpu::BufferBindingType::Storage { read_only: false },
1181                    has_dynamic_offset: false,
1182                    min_binding_size: None,
1183                },
1184                count: None,
1185            },
1186            wgpu::BindGroupLayoutEntry {
1187                binding: 1,
1188                visibility: wgpu::ShaderStages::COMPUTE,
1189                ty: wgpu::BindingType::Buffer {
1190                    ty: wgpu::BufferBindingType::Uniform,
1191                    has_dynamic_offset: false,
1192                    min_binding_size: None,
1193                },
1194                count: None,
1195            },
1196            wgpu::BindGroupLayoutEntry {
1197                binding: 2,
1198                visibility: wgpu::ShaderStages::COMPUTE,
1199                ty: wgpu::BindingType::Buffer {
1200                    ty: wgpu::BufferBindingType::Storage { read_only: false },
1201                    has_dynamic_offset: false,
1202                    min_binding_size: None,
1203                },
1204                count: None,
1205            },
1206        ],
1207    });
1208    let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
1209        label: Some(label),
1210        bind_group_layouts: &[Some(&bgl)],
1211        immediate_size: 0,
1212    });
1213    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
1214        label: Some(label),
1215        layout: Some(&layout),
1216        module: &module,
1217        entry_point: Some(entry_point),
1218        compilation_options: Default::default(),
1219        cache: None,
1220    });
1221    Kernel { pipeline, bgl }
1222}
1223
1224/// f32 arena (rw) + uniform + f16 shadow (rw) — unary with CoopF16Vk mirror.
1225fn build_kernel_f32_rw_uniform_f16_rw(
1226    device: &wgpu::Device,
1227    label: &'static str,
1228    wgsl: &str,
1229    entry_point: &'static str,
1230) -> Kernel {
1231    let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
1232        label: Some(label),
1233        source: wgpu::ShaderSource::Wgsl(wgsl.into()),
1234    });
1235    let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
1236        label: Some(label),
1237        entries: &[
1238            wgpu::BindGroupLayoutEntry {
1239                binding: 0,
1240                visibility: wgpu::ShaderStages::COMPUTE,
1241                ty: wgpu::BindingType::Buffer {
1242                    ty: wgpu::BufferBindingType::Storage { read_only: false },
1243                    has_dynamic_offset: false,
1244                    min_binding_size: None,
1245                },
1246                count: None,
1247            },
1248            wgpu::BindGroupLayoutEntry {
1249                binding: 1,
1250                visibility: wgpu::ShaderStages::COMPUTE,
1251                ty: wgpu::BindingType::Buffer {
1252                    ty: wgpu::BufferBindingType::Uniform,
1253                    has_dynamic_offset: false,
1254                    min_binding_size: None,
1255                },
1256                count: None,
1257            },
1258            wgpu::BindGroupLayoutEntry {
1259                binding: 2,
1260                visibility: wgpu::ShaderStages::COMPUTE,
1261                ty: wgpu::BindingType::Buffer {
1262                    ty: wgpu::BufferBindingType::Storage { read_only: false },
1263                    has_dynamic_offset: false,
1264                    min_binding_size: None,
1265                },
1266                count: None,
1267            },
1268        ],
1269    });
1270    let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
1271        label: Some(label),
1272        bind_group_layouts: &[Some(&bgl)],
1273        immediate_size: 0,
1274    });
1275    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
1276        label: Some(label),
1277        layout: Some(&layout),
1278        module: &module,
1279        entry_point: Some(entry_point),
1280        compilation_options: Default::default(),
1281        cache: None,
1282    });
1283    Kernel { pipeline, bgl }
1284}
1285
1286/// f16 shadow (read) + f32 arena (rw) + uniform — Vulkan/DX12 coop f16 matmul.
1287fn build_kernel_coop_f16_vk(
1288    device: &wgpu::Device,
1289    label: &'static str,
1290    wgsl: &str,
1291    entry_point: &'static str,
1292) -> Kernel {
1293    let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
1294        label: Some(label),
1295        source: wgpu::ShaderSource::Wgsl(wgsl.into()),
1296    });
1297    let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
1298        label: Some(label),
1299        entries: &[
1300            wgpu::BindGroupLayoutEntry {
1301                binding: 0,
1302                visibility: wgpu::ShaderStages::COMPUTE,
1303                ty: wgpu::BindingType::Buffer {
1304                    ty: wgpu::BufferBindingType::Storage { read_only: true },
1305                    has_dynamic_offset: false,
1306                    min_binding_size: None,
1307                },
1308                count: None,
1309            },
1310            wgpu::BindGroupLayoutEntry {
1311                binding: 1,
1312                visibility: wgpu::ShaderStages::COMPUTE,
1313                ty: wgpu::BindingType::Buffer {
1314                    ty: wgpu::BufferBindingType::Storage { read_only: false },
1315                    has_dynamic_offset: false,
1316                    min_binding_size: None,
1317                },
1318                count: None,
1319            },
1320            wgpu::BindGroupLayoutEntry {
1321                binding: 2,
1322                visibility: wgpu::ShaderStages::COMPUTE,
1323                ty: wgpu::BindingType::Buffer {
1324                    ty: wgpu::BufferBindingType::Uniform,
1325                    has_dynamic_offset: false,
1326                    min_binding_size: None,
1327                },
1328                count: None,
1329            },
1330        ],
1331    });
1332    let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
1333        label: Some(label),
1334        bind_group_layouts: &[Some(&bgl)],
1335        immediate_size: 0,
1336    });
1337    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
1338        label: Some(label),
1339        layout: Some(&layout),
1340        module: &module,
1341        entry_point: Some(entry_point),
1342        compilation_options: Default::default(),
1343        cache: None,
1344    });
1345    Kernel { pipeline, bgl }
1346}
1347
1348fn try_build_kernel_coop_f16_vk(
1349    device: &wgpu::Device,
1350    label: &'static str,
1351    wgsl: &str,
1352    entry_point: &'static str,
1353) -> Option<Kernel> {
1354    std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
1355        build_kernel_coop_f16_vk(device, label, wgsl, entry_point)
1356    }))
1357    .ok()
1358}
1359
1360fn build_kernel(
1361    device: &wgpu::Device,
1362    label: &'static str,
1363    wgsl: &str,
1364    entry_point: &'static str,
1365) -> Kernel {
1366    let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
1367        label: Some(label),
1368        source: wgpu::ShaderSource::Wgsl(wgsl.into()),
1369    });
1370    let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
1371        label: Some(label),
1372        entries: &[
1373            wgpu::BindGroupLayoutEntry {
1374                binding: 0,
1375                visibility: wgpu::ShaderStages::COMPUTE,
1376                ty: wgpu::BindingType::Buffer {
1377                    ty: wgpu::BufferBindingType::Storage { read_only: false },
1378                    has_dynamic_offset: false,
1379                    min_binding_size: None,
1380                },
1381                count: None,
1382            },
1383            wgpu::BindGroupLayoutEntry {
1384                binding: 1,
1385                visibility: wgpu::ShaderStages::COMPUTE,
1386                ty: wgpu::BindingType::Buffer {
1387                    ty: wgpu::BufferBindingType::Uniform,
1388                    has_dynamic_offset: false,
1389                    min_binding_size: None,
1390                },
1391                count: None,
1392            },
1393        ],
1394    });
1395    let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
1396        label: Some(label),
1397        bind_group_layouts: &[Some(&bgl)],
1398        immediate_size: 0,
1399    });
1400    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
1401        label: Some(label),
1402        layout: Some(&layout),
1403        module: &module,
1404        entry_point: Some(entry_point),
1405        compilation_options: Default::default(),
1406        cache: None,
1407    });
1408    Kernel { pipeline, bgl }
1409}
1410
1411static MATMUL: OnceLock<Kernel> = OnceLock::new();
1412static MATMUL_WIDE: OnceLock<Kernel> = OnceLock::new();
1413static MATMUL_WIDE_NV: OnceLock<Kernel> = OnceLock::new();
1414static MATMUL_F16W: OnceLock<Kernel> = OnceLock::new();
1415static MATMUL_F16_COMPUTE: OnceLock<Kernel> = OnceLock::new();
1416static MATMUL_COOP16: OnceLock<Kernel> = OnceLock::new();
1417static MATMUL_COOP_F32: OnceLock<Kernel> = OnceLock::new();
1418static MATMUL_COOP_F32_PORTABLE: OnceLock<Kernel> = OnceLock::new();
1419static MATMUL_COOP_F16_VULKAN: OnceLock<Kernel> = OnceLock::new();
1420static MATMUL_COOP_F16_VULKAN_WIDEN: OnceLock<Kernel> = OnceLock::new();
1421static MATMUL_COOP_F16_VULKAN_F32ACC: OnceLock<Option<Kernel>> = OnceLock::new();
1422static MATMUL_COOP_F16_VULKAN_WIDEN_F32ACC: OnceLock<Option<Kernel>> = OnceLock::new();
1423static CAST_F32_TO_F16: OnceLock<Kernel> = OnceLock::new();
1424static BINARY: OnceLock<Kernel> = OnceLock::new();
1425static UNARY: OnceLock<Kernel> = OnceLock::new();
1426static UNARY_F16_MIRROR: OnceLock<Kernel> = OnceLock::new();
1427static COMPARE: OnceLock<Kernel> = OnceLock::new();
1428static WHEREK: OnceLock<Kernel> = OnceLock::new();
1429static REDUCE: OnceLock<Kernel> = OnceLock::new();
1430static SOFTMAX: OnceLock<Kernel> = OnceLock::new();
1431static LAYERNORM: OnceLock<Kernel> = OnceLock::new();
1432static RMS_NORM_BWD: OnceLock<Kernel> = OnceLock::new();
1433static RMS_NORM_BWD_PARAM: OnceLock<Kernel> = OnceLock::new();
1434static LAYER_NORM_BWD_INPUT: OnceLock<Kernel> = OnceLock::new();
1435static LAYER_NORM_BWD_GAMMA: OnceLock<Kernel> = OnceLock::new();
1436static LAYER_NORM_BWD_GAMMA_REDUCE: OnceLock<Kernel> = OnceLock::new();
1437static CUMSUM_BWD: OnceLock<Kernel> = OnceLock::new();
1438static ROPE_BWD: OnceLock<Kernel> = OnceLock::new();
1439static GATHER_BWD_ZERO: OnceLock<Kernel> = OnceLock::new();
1440static GATHER_BWD_ACC: OnceLock<Kernel> = OnceLock::new();
1441static CUMSUM: OnceLock<Kernel> = OnceLock::new();
1442static FFT_GPU_RADIX2: OnceLock<Kernel> = OnceLock::new();
1443static FFT_GPU_BITREV: OnceLock<Kernel> = OnceLock::new();
1444static FFT_GPU_INNER: OnceLock<Kernel> = OnceLock::new();
1445static FFT_GPU_OUTER_R4: OnceLock<Kernel> = OnceLock::new();
1446static FFT_GPU_OUTER_R2: OnceLock<Kernel> = OnceLock::new();
1447static COPY: OnceLock<Kernel> = OnceLock::new();
1448static ELEMENTWISE_REGION: OnceLock<Kernel> = OnceLock::new();
1449static ELEMENTWISE_REGION_SPATIAL: OnceLock<Kernel> = OnceLock::new();
1450static TRANSPOSE: OnceLock<Kernel> = OnceLock::new();
1451static NARROW: OnceLock<Kernel> = OnceLock::new();
1452static CONCAT: OnceLock<Kernel> = OnceLock::new();
1453static GATHER: OnceLock<Kernel> = OnceLock::new();
1454static GATHER_AXIS: OnceLock<Kernel> = OnceLock::new();
1455static ATTENTION: OnceLock<Kernel> = OnceLock::new();
1456static ATTENTION_BWD: OnceLock<Kernel> = OnceLock::new();
1457static ROPE: OnceLock<Kernel> = OnceLock::new();
1458static EXPAND: OnceLock<Kernel> = OnceLock::new();
1459static ARGMAX: OnceLock<Kernel> = OnceLock::new();
1460static POOL2D: OnceLock<Kernel> = OnceLock::new();
1461static CONV2D: OnceLock<Kernel> = OnceLock::new();
1462static POOL1D: OnceLock<Kernel> = OnceLock::new();
1463static POOL3D: OnceLock<Kernel> = OnceLock::new();
1464static CONV1D: OnceLock<Kernel> = OnceLock::new();
1465static CONV3D: OnceLock<Kernel> = OnceLock::new();
1466static SCATTER_ADD: OnceLock<Kernel> = OnceLock::new();
1467static TOPK: OnceLock<Kernel> = OnceLock::new();
1468static WELCH_PEAKS_GPU: OnceLock<Kernel> = OnceLock::new();
1469static UMAP_KNN: OnceLock<Kernel> = OnceLock::new();
1470static GROUPED_MATMUL: OnceLock<Kernel> = OnceLock::new();
1471static SAMPLE: OnceLock<Kernel> = OnceLock::new();
1472static SELECTIVE_SCAN: OnceLock<Kernel> = OnceLock::new();
1473static DEQUANT_MATMUL: OnceLock<Kernel> = OnceLock::new();
1474static FUSED_RESIDUAL_LN: OnceLock<Kernel> = OnceLock::new();
1475static FUSED_RESIDUAL_LN_TEE: OnceLock<Kernel> = OnceLock::new();
1476static FUSED_RESIDUAL_RMS_NORM: OnceLock<Kernel> = OnceLock::new();
1477static MATMUL_QKV: OnceLock<Kernel> = OnceLock::new();
1478static MATMUL_QKV_COOP_F32: OnceLock<Kernel> = OnceLock::new();
1479static MATMUL_QKV_COOP_F16_VK: OnceLock<Kernel> = OnceLock::new();
1480static MATMUL_QKV_COOP_F16_VK_WIDEN: OnceLock<Kernel> = OnceLock::new();
1481static MATMUL_QKV_COOP_F16_VK_F32ACC: OnceLock<Option<Kernel>> = OnceLock::new();
1482static MATMUL_QKV_COOP_F16_VK_WIDEN_F32ACC: OnceLock<Option<Kernel>> = OnceLock::new();
1483
1484pub fn matmul_kernel(device: &wgpu::Device) -> &'static Kernel {
1485    MATMUL.get_or_init(|| build_kernel(device, "rlx-wgpu matmul", MATMUL_WGSL, "matmul"))
1486}
1487pub fn matmul_wide_kernel(device: &wgpu::Device) -> &'static Kernel {
1488    MATMUL_WIDE.get_or_init(|| {
1489        build_kernel(
1490            device,
1491            "rlx-wgpu matmul_wide",
1492            MATMUL_WIDE_WGSL,
1493            "matmul_wide",
1494        )
1495    })
1496}
1497/// 64×64 / 256-thread variant for discrete GPUs (Vulkan path).
1498pub fn matmul_wide_nv_kernel(device: &wgpu::Device) -> &'static Kernel {
1499    MATMUL_WIDE_NV.get_or_init(|| {
1500        build_kernel(
1501            device,
1502            "rlx-wgpu matmul_wide_nv",
1503            MATMUL_WIDE_NV_WGSL,
1504            "matmul_wide_nv",
1505        )
1506    })
1507}
1508/// f16-weight matmul (f32 compute). Returns Some only when the device
1509/// exposes the `SHADER_F16` feature. EXPERIMENTAL: currently slower
1510/// than the f32 baseline on Apple Silicon — kept as foundation; see
1511/// `matmul_f16w.wgsl` for the empirical analysis.
1512pub fn matmul_f16w_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
1513    if !device.features().contains(wgpu::Features::SHADER_F16) {
1514        return None;
1515    }
1516    Some(MATMUL_F16W.get_or_init(|| {
1517        build_kernel_3(
1518            device,
1519            "rlx-wgpu matmul_f16w",
1520            MATMUL_F16W_WGSL,
1521            "matmul_f16w",
1522        )
1523    }))
1524}
1525/// f16-compute matmul: f16 operands, f16 multiply, f32 accumulator.
1526/// Targets the 2× f16 ALU throughput on Apple Silicon. Returns Some
1527/// only when the device exposes `SHADER_F16`.
1528pub fn matmul_f16_compute_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
1529    if !device.features().contains(wgpu::Features::SHADER_F16) {
1530        return None;
1531    }
1532    Some(MATMUL_F16_COMPUTE.get_or_init(|| {
1533        build_kernel_3(
1534            device,
1535            "rlx-wgpu matmul_f16_compute",
1536            MATMUL_F16_COMPUTE_WGSL,
1537            "matmul_f16_compute",
1538        )
1539    }))
1540}
1541/// Cooperative-matrix matmul (8×8 tiles, hardware GEMM units).
1542/// Lowers to MSL `simdgroup_matrix` on Metal and SPIR-V's
1543/// `OpCooperativeMatrixMulAddKHR` on Vulkan. Returns Some only when
1544/// the device exposes both `SHADER_F16` and
1545/// `EXPERIMENTAL_COOPERATIVE_MATRIX`.
1546pub fn matmul_coop16_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
1547    let feats = device.features();
1548    if !feats.contains(wgpu::Features::SHADER_F16)
1549        || !feats.contains(wgpu::Features::EXPERIMENTAL_COOPERATIVE_MATRIX)
1550    {
1551        return None;
1552    }
1553    Some(MATMUL_COOP16.get_or_init(|| {
1554        build_kernel_3(
1555            device,
1556            "rlx-wgpu matmul_coop16",
1557            MATMUL_COOP16_WGSL,
1558            "matmul_coop16",
1559        )
1560    }))
1561}
1562/// Pure-f32 cooperative-matrix matmul. No SHADER_F16 needed — uses
1563/// `coop_mat8x8<f32>` throughout (lowers to `simdgroup_float8x8` on
1564/// Apple). Returns None if the cooperative-matrix feature is missing
1565/// OR if the device's WGSL→backend lowering can't compile it (some
1566/// implementations only expose half-precision coop matrices).
1567pub fn matmul_coop_f32_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
1568    let feats = device.features();
1569    if !feats.contains(wgpu::Features::EXPERIMENTAL_COOPERATIVE_MATRIX) {
1570        return None;
1571    }
1572    Some(MATMUL_COOP_F32.get_or_init(|| {
1573        build_kernel(
1574            device,
1575            "rlx-wgpu matmul_coop_f32",
1576            MATMUL_COOP_F32_WGSL,
1577            "matmul_coop_f32",
1578        )
1579    }))
1580}
1581/// Vulkan/DX12-oriented coop f32 matmul (`coopLoad`, 8×8 workgroups).
1582pub fn matmul_coop_f32_portable_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
1583    let feats = device.features();
1584    if !feats.contains(wgpu::Features::EXPERIMENTAL_COOPERATIVE_MATRIX)
1585        || !crate::device::coop_f32_8x8_supported()
1586    {
1587        return None;
1588    }
1589    Some(MATMUL_COOP_F32_PORTABLE.get_or_init(|| {
1590        build_kernel(
1591            device,
1592            "rlx-wgpu matmul_coop_f32_portable",
1593            MATMUL_COOP_F32_PORTABLE_WGSL,
1594            "matmul_coop_f32_portable",
1595        )
1596    }))
1597}
1598fn coop_f16_vk_device_ready(device: &wgpu::Device) -> bool {
1599    // Cooperative-matrix Vulkan/DX12 matmul is OFF by default — see
1600    // `coop_f16_vk_eligible` in `backend.rs` for the rationale. Opt in
1601    // with `RLX_WGPU_COOP_F16_VK_ENABLE=1`. Legacy
1602    // `RLX_WGPU_COOP_F16_VK_DISABLE=1` also fully disables.
1603    if rlx_ir::env::flag("RLX_WGPU_COOP_F16_VK_DISABLE")
1604        || !rlx_ir::env::flag("RLX_WGPU_COOP_F16_VK_ENABLE")
1605    {
1606        return false;
1607    }
1608    device.features().contains(wgpu::Features::SHADER_F16)
1609        && device
1610            .features()
1611            .contains(wgpu::Features::EXPERIMENTAL_COOPERATIVE_MATRIX)
1612        && crate::device::coop_f16_16x16_supported()
1613        && crate::device::coop_discrete_backend()
1614}
1615
1616fn coop_f16_vk_f32acc_device_ready(device: &wgpu::Device) -> bool {
1617    coop_f16_vk_device_ready(device) && crate::device::coop_f16_16x16_f32_acc_supported()
1618}
1619
1620pub fn matmul_coop_f16_vulkan_f32acc_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
1621    if !coop_f16_vk_f32acc_device_ready(device) {
1622        return None;
1623    }
1624    MATMUL_COOP_F16_VULKAN_F32ACC
1625        .get_or_init(|| {
1626            try_build_kernel_coop_f16_vk(
1627                device,
1628                "rlx-wgpu matmul_coop_f16_vulkan_f32acc",
1629                MATMUL_COOP_F16_VULKAN_F32ACC_WGSL,
1630                "matmul_coop_f16_vulkan_f32acc",
1631            )
1632        })
1633        .as_ref()
1634}
1635
1636pub fn matmul_coop_f16_vulkan_widen_f32acc_kernel(
1637    device: &wgpu::Device,
1638) -> Option<&'static Kernel> {
1639    if !coop_f16_vk_f32acc_device_ready(device) {
1640        return None;
1641    }
1642    MATMUL_COOP_F16_VULKAN_WIDEN_F32ACC
1643        .get_or_init(|| {
1644            try_build_kernel_coop_f16_vk(
1645                device,
1646                "rlx-wgpu matmul_coop_f16_vulkan_widen_f32acc",
1647                MATMUL_COOP_F16_VULKAN_WIDEN_F32ACC_WGSL,
1648                "matmul_coop_f16_vulkan_widen_f32acc",
1649            )
1650        })
1651        .as_ref()
1652}
1653
1654fn coop_f16_vk_use_f32acc(device: &wgpu::Device) -> bool {
1655    !rlx_ir::env::flag("RLX_WGPU_COOP_F16_VK_NO_F32ACC")
1656        && matmul_coop_f16_vulkan_f32acc_kernel(device).is_some()
1657}
1658
1659fn pick_coop_f16_vk_matmul(
1660    device: &wgpu::Device,
1661    n: u32,
1662    loadt: fn(&wgpu::Device) -> Option<&'static Kernel>,
1663    loadt_f32acc: fn(&wgpu::Device) -> Option<&'static Kernel>,
1664    widen: fn(&wgpu::Device) -> Option<&'static Kernel>,
1665    widen_f32acc: fn(&wgpu::Device) -> Option<&'static Kernel>,
1666) -> Option<&'static Kernel> {
1667    if coop_f16_vk_use_f32acc(device) {
1668        if coop_f16_vk_widen_b_load(n) {
1669            return widen_f32acc(device).or_else(|| loadt_f32acc(device));
1670        }
1671        return loadt_f32acc(device);
1672    }
1673    if coop_f16_vk_widen_b_load(n) {
1674        widen(device).or_else(|| loadt(device))
1675    } else {
1676        loadt(device)
1677    }
1678}
1679
1680/// Matmul CoopF16Vk kernel for column count `n`.
1681pub fn matmul_coop_f16_vulkan_active_kernel(
1682    device: &wgpu::Device,
1683    n: u32,
1684) -> Option<&'static Kernel> {
1685    pick_coop_f16_vk_matmul(
1686        device,
1687        n,
1688        matmul_coop_f16_vulkan_kernel,
1689        matmul_coop_f16_vulkan_f32acc_kernel,
1690        matmul_coop_f16_vulkan_widen_kernel,
1691        matmul_coop_f16_vulkan_widen_f32acc_kernel,
1692    )
1693}
1694
1695pub fn matmul_coop_f16_vulkan_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
1696    if !coop_f16_vk_device_ready(device) {
1697        return None;
1698    }
1699    Some(MATMUL_COOP_F16_VULKAN.get_or_init(|| {
1700        build_kernel_coop_f16_vk(
1701            device,
1702            "rlx-wgpu matmul_coop_f16_vulkan",
1703            MATMUL_COOP_F16_VULKAN_WGSL,
1704            "matmul_coop_f16_vulkan",
1705        )
1706    }))
1707}
/// N above which coop may use the row-major B-load variant (`RLX_WGPU_COOP_F16_VK_LARGE_N`).
///
/// `coop_f16_vk_widen_b_load` compares against this value with a strict `>`.
/// NOTE(review): `RLX_WGPU_COOP_F16_VK_LARGE_N` is not read by the code next
/// to this constant (that predicate consults `RLX_WGPU_COOP_F16_VK_LOAD_T`) —
/// confirm the variable name is still current.
pub const COOP_F16_VK_WIDEN_N: u32 = 768;
1710
1711/// Use `coopLoad` on B instead of `coopLoadT` when N > 768 and `RLX_WGPU_COOP_F16_VK_LOAD_T` is unset.
1712pub fn coop_f16_vk_widen_b_load(n: u32) -> bool {
1713    n > COOP_F16_VK_WIDEN_N && !rlx_ir::env::flag("RLX_WGPU_COOP_F16_VK_LOAD_T")
1714}
1715
1716pub fn matmul_coop_f16_vulkan_widen_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
1717    if !coop_f16_vk_device_ready(device) {
1718        return None;
1719    }
1720    Some(MATMUL_COOP_F16_VULKAN_WIDEN.get_or_init(|| {
1721        build_kernel_coop_f16_vk(
1722            device,
1723            "rlx-wgpu matmul_coop_f16_vulkan_widen",
1724            MATMUL_COOP_F16_VULKAN_WIDEN_WGSL,
1725            "matmul_coop_f16_vulkan_widen",
1726        )
1727    }))
1728}
1729pub fn coop_f16_vk_f32acc_available(device: &wgpu::Device) -> bool {
1730    matmul_coop_f16_vulkan_f32acc_kernel(device).is_some()
1731}
1732/// CoopF32 kernel for the active wgpu backend (Metal simdgroup vs Vulkan/DX12 portable).
1733pub fn matmul_coop_f32_active_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
1734    match crate::device::wgpu_device().map(|d| d.backend) {
1735        Some(wgpu::Backend::Metal) => matmul_coop_f32_kernel(device),
1736        Some(wgpu::Backend::Vulkan) | Some(wgpu::Backend::Dx12) => {
1737            matmul_coop_f32_portable_kernel(device)
1738        }
1739        _ => None,
1740    }
1741}
1742/// Wide f32 matmul kernel for the active backend.
1743pub fn matmul_wide_active_kernel(device: &wgpu::Device) -> &'static Kernel {
1744    match crate::device::wgpu_device().map(|d| d.backend) {
1745        Some(wgpu::Backend::Vulkan) | Some(wgpu::Backend::Dx12) => matmul_wide_nv_kernel(device),
1746        _ => matmul_wide_kernel(device),
1747    }
1748}
1749/// Mirrors a region of the f32 arena into the f16 shadow buffer.
1750/// Used before `matmul_coop16` for the matmul's activation operand
1751/// (intermediate activations don't go through `set_param` /
1752/// `write_f32`, so they aren't in the f16 buffer otherwise).
1753pub fn cast_f32_to_f16_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
1754    if !device.features().contains(wgpu::Features::SHADER_F16) {
1755        return None;
1756    }
1757    Some(CAST_F32_TO_F16.get_or_init(|| {
1758        build_kernel_cast_f32_to_f16(
1759            device,
1760            "rlx-wgpu cast_f32_to_f16",
1761            CAST_F32_TO_F16_WGSL,
1762            "cast_f32_to_f16",
1763        )
1764    }))
1765}
1766pub fn binary_kernel(device: &wgpu::Device) -> &'static Kernel {
1767    BINARY.get_or_init(|| build_kernel(device, "rlx-wgpu binary", BINARY_WGSL, "binary"))
1768}
1769pub fn unary_kernel(device: &wgpu::Device) -> &'static Kernel {
1770    UNARY.get_or_init(|| build_kernel(device, "rlx-wgpu unary", UNARY_WGSL, "unary"))
1771}
1772pub fn unary_f16_mirror_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
1773    if !device.features().contains(wgpu::Features::SHADER_F16) {
1774        return None;
1775    }
1776    Some(UNARY_F16_MIRROR.get_or_init(|| {
1777        build_kernel_f32_rw_uniform_f16_rw(
1778            device,
1779            "rlx-wgpu unary_f16_mirror",
1780            UNARY_F16_MIRROR_WGSL,
1781            "unary_f16_mirror",
1782        )
1783    }))
1784}
1785pub fn compare_kernel(device: &wgpu::Device) -> &'static Kernel {
1786    COMPARE.get_or_init(|| build_kernel(device, "rlx-wgpu compare", COMPARE_WGSL, "compare"))
1787}
1788pub fn where_kernel(device: &wgpu::Device) -> &'static Kernel {
1789    WHEREK.get_or_init(|| build_kernel(device, "rlx-wgpu where", WHERE_WGSL, "where_select"))
1790}
1791pub fn reduce_kernel(device: &wgpu::Device) -> &'static Kernel {
1792    REDUCE.get_or_init(|| build_kernel(device, "rlx-wgpu reduce", REDUCE_WGSL, "reduce"))
1793}
1794pub fn softmax_kernel(device: &wgpu::Device) -> &'static Kernel {
1795    SOFTMAX.get_or_init(|| build_kernel(device, "rlx-wgpu softmax", SOFTMAX_WGSL, "softmax"))
1796}
1797pub fn layernorm_kernel(device: &wgpu::Device) -> &'static Kernel {
1798    LAYERNORM.get_or_init(|| build_kernel(device, "rlx-wgpu layernorm", LAYERNORM_WGSL, "norm"))
1799}
1800pub fn rms_norm_backward_kernel(device: &wgpu::Device) -> &'static Kernel {
1801    RMS_NORM_BWD.get_or_init(|| {
1802        build_kernel(
1803            device,
1804            "rlx-wgpu rms_norm_bwd",
1805            RMS_NORM_BWD_WGSL,
1806            "rms_norm_bwd",
1807        )
1808    })
1809}
1810pub fn rms_norm_backward_param_kernel(device: &wgpu::Device) -> &'static Kernel {
1811    RMS_NORM_BWD_PARAM.get_or_init(|| {
1812        build_kernel(
1813            device,
1814            "rlx-wgpu rms_norm_bwd_param",
1815            RMS_NORM_BWD_WGSL,
1816            "rms_norm_bwd_param",
1817        )
1818    })
1819}
1820pub fn layer_norm_backward_input_kernel(device: &wgpu::Device) -> &'static Kernel {
1821    LAYER_NORM_BWD_INPUT.get_or_init(|| {
1822        build_kernel(
1823            device,
1824            "rlx-wgpu layer_norm_bwd_input",
1825            LAYER_NORM_BWD_WGSL,
1826            "layer_norm_bwd_input",
1827        )
1828    })
1829}
1830pub fn layer_norm_backward_gamma_partial_kernel(device: &wgpu::Device) -> &'static Kernel {
1831    LAYER_NORM_BWD_GAMMA.get_or_init(|| {
1832        build_kernel(
1833            device,
1834            "rlx-wgpu layer_norm_bwd_gamma_partial",
1835            LAYER_NORM_BWD_WGSL,
1836            "layer_norm_bwd_gamma_partial",
1837        )
1838    })
1839}
1840
1841pub fn layer_norm_backward_gamma_reduce_kernel(device: &wgpu::Device) -> &'static Kernel {
1842    LAYER_NORM_BWD_GAMMA_REDUCE.get_or_init(|| {
1843        build_kernel(
1844            device,
1845            "rlx-wgpu layer_norm_bwd_gamma_reduce",
1846            LAYER_NORM_BWD_WGSL,
1847            "layer_norm_bwd_gamma_reduce",
1848        )
1849    })
1850}
1851pub fn cumsum_backward_kernel(device: &wgpu::Device) -> &'static Kernel {
1852    CUMSUM_BWD
1853        .get_or_init(|| build_kernel(device, "rlx-wgpu cumsum_bwd", CUMSUM_BWD_WGSL, "cumsum_bwd"))
1854}
1855pub fn rope_backward_kernel(device: &wgpu::Device) -> &'static Kernel {
1856    ROPE_BWD.get_or_init(|| build_kernel(device, "rlx-wgpu rope_bwd", ROPE_BWD_WGSL, "rope_bwd"))
1857}
1858pub fn gather_backward_zero_kernel(device: &wgpu::Device) -> &'static Kernel {
1859    GATHER_BWD_ZERO.get_or_init(|| {
1860        build_kernel(
1861            device,
1862            "rlx-wgpu gather_bwd_zero",
1863            GATHER_BWD_WGSL,
1864            "gather_bwd_zero",
1865        )
1866    })
1867}
1868pub fn gather_backward_acc_kernel(device: &wgpu::Device) -> &'static Kernel {
1869    GATHER_BWD_ACC.get_or_init(|| {
1870        build_kernel(
1871            device,
1872            "rlx-wgpu gather_bwd_acc",
1873            GATHER_BWD_WGSL,
1874            "gather_bwd_acc",
1875        )
1876    })
1877}
1878pub fn cumsum_kernel(device: &wgpu::Device) -> &'static Kernel {
1879    CUMSUM.get_or_init(|| build_kernel(device, "rlx-wgpu cumsum", CUMSUM_WGSL, "cumsum"))
1880}
1881pub fn fft_gpu_radix2_full_kernel(device: &wgpu::Device) -> &'static Kernel {
1882    FFT_GPU_RADIX2.get_or_init(|| {
1883        build_kernel(
1884            device,
1885            "rlx-wgpu fft_radix2_full",
1886            FFT_GPU_WGSL,
1887            "fft_radix2_full",
1888        )
1889    })
1890}
1891pub fn fft_gpu_bit_reverse_kernel(device: &wgpu::Device) -> &'static Kernel {
1892    FFT_GPU_BITREV.get_or_init(|| {
1893        build_kernel(
1894            device,
1895            "rlx-wgpu fft_bit_reverse",
1896            FFT_GPU_WGSL,
1897            "fft_bit_reverse",
1898        )
1899    })
1900}
1901pub fn fft_gpu_inner_kernel(device: &wgpu::Device) -> &'static Kernel {
1902    FFT_GPU_INNER
1903        .get_or_init(|| build_kernel(device, "rlx-wgpu fft_inner", FFT_GPU_WGSL, "fft_inner"))
1904}
1905pub fn fft_gpu_outer_r4_kernel(device: &wgpu::Device) -> &'static Kernel {
1906    FFT_GPU_OUTER_R4.get_or_init(|| {
1907        build_kernel(
1908            device,
1909            "rlx-wgpu fft_outer_r4",
1910            FFT_GPU_WGSL,
1911            "fft_outer_r4",
1912        )
1913    })
1914}
1915pub fn fft_gpu_outer_r2_kernel(device: &wgpu::Device) -> &'static Kernel {
1916    FFT_GPU_OUTER_R2.get_or_init(|| {
1917        build_kernel(
1918            device,
1919            "rlx-wgpu fft_outer_r2",
1920            FFT_GPU_WGSL,
1921            "fft_outer_r2",
1922        )
1923    })
1924}
1925pub fn copy_kernel(device: &wgpu::Device) -> &'static Kernel {
1926    COPY.get_or_init(|| build_kernel(device, "rlx-wgpu copy", COPY_WGSL, "copy"))
1927}
1928pub fn elementwise_region_kernel(device: &wgpu::Device) -> &'static Kernel {
1929    // Region params bind as a STORAGE buffer (not uniform) — WGSL's
1930    // uniform-storage spec requires 16-byte stride for `array<T, N>`,
1931    // which our packed `array<u32, N>` chain layout doesn't satisfy.
1932    // Storage allows arbitrary stride.
1933    ELEMENTWISE_REGION.get_or_init(|| {
1934        build_kernel_region(
1935            device,
1936            "rlx-wgpu elementwise_region",
1937            ELEMENTWISE_REGION_WGSL,
1938            "elementwise_region",
1939        )
1940    })
1941}
1942
1943pub fn elementwise_region_spatial_kernel(device: &wgpu::Device) -> &'static Kernel {
1944    ELEMENTWISE_REGION_SPATIAL.get_or_init(|| {
1945        build_kernel_region(
1946            device,
1947            "rlx-wgpu elementwise_region_spatial",
1948            ELEMENTWISE_REGION_WGSL,
1949            "elementwise_region_spatial",
1950        )
1951    })
1952}
1953
1954static BATCH_ELEMENTWISE_REGION: std::sync::OnceLock<Kernel> = std::sync::OnceLock::new();
1955
1956pub fn batch_elementwise_region_kernel(device: &wgpu::Device) -> &'static Kernel {
1957    BATCH_ELEMENTWISE_REGION.get_or_init(|| {
1958        build_kernel_region(
1959            device,
1960            "rlx-wgpu batch_elementwise_region",
1961            ELEMENTWISE_REGION_WGSL,
1962            "batch_elementwise_region",
1963        )
1964    })
1965}
1966
1967fn build_kernel_region(
1968    device: &wgpu::Device,
1969    label: &'static str,
1970    wgsl: &str,
1971    entry_point: &'static str,
1972) -> Kernel {
1973    let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
1974        label: Some(label),
1975        source: wgpu::ShaderSource::Wgsl(wgsl.into()),
1976    });
1977    let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
1978        label: Some(label),
1979        entries: &[
1980            wgpu::BindGroupLayoutEntry {
1981                binding: 0,
1982                visibility: wgpu::ShaderStages::COMPUTE,
1983                ty: wgpu::BindingType::Buffer {
1984                    ty: wgpu::BufferBindingType::Storage { read_only: false },
1985                    has_dynamic_offset: false,
1986                    min_binding_size: None,
1987                },
1988                count: None,
1989            },
1990            wgpu::BindGroupLayoutEntry {
1991                binding: 1,
1992                visibility: wgpu::ShaderStages::COMPUTE,
1993                ty: wgpu::BindingType::Buffer {
1994                    // Region params: read-only storage (vs uniform).
1995                    ty: wgpu::BufferBindingType::Storage { read_only: true },
1996                    has_dynamic_offset: false,
1997                    min_binding_size: None,
1998                },
1999                count: None,
2000            },
2001        ],
2002    });
2003    let pl = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
2004        label: Some(label),
2005        bind_group_layouts: &[Some(&bgl)],
2006        immediate_size: 0,
2007    });
2008    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
2009        label: Some(label),
2010        layout: Some(&pl),
2011        module: &module,
2012        entry_point: Some(entry_point),
2013        compilation_options: Default::default(),
2014        cache: None,
2015    });
2016    Kernel { pipeline, bgl }
2017}
2018pub fn transpose_kernel(device: &wgpu::Device) -> &'static Kernel {
2019    TRANSPOSE
2020        .get_or_init(|| build_kernel_3(device, "rlx-wgpu transpose", TRANSPOSE_WGSL, "transpose"))
2021}
2022pub fn narrow_kernel(device: &wgpu::Device) -> &'static Kernel {
2023    NARROW.get_or_init(|| build_kernel(device, "rlx-wgpu narrow", NARROW_WGSL, "narrow"))
2024}
2025pub fn concat_kernel(device: &wgpu::Device) -> &'static Kernel {
2026    CONCAT.get_or_init(|| build_kernel(device, "rlx-wgpu concat", CONCAT_WGSL, "concat"))
2027}
2028pub fn gather_kernel(device: &wgpu::Device) -> &'static Kernel {
2029    GATHER.get_or_init(|| build_kernel(device, "rlx-wgpu gather", GATHER_WGSL, "gather"))
2030}
2031pub fn gather_axis_kernel(device: &wgpu::Device) -> &'static Kernel {
2032    GATHER_AXIS.get_or_init(|| {
2033        build_kernel(
2034            device,
2035            "rlx-wgpu gather_axis",
2036            GATHER_AXIS_WGSL,
2037            "gather_axis",
2038        )
2039    })
2040}
2041pub fn attention_kernel(device: &wgpu::Device) -> &'static Kernel {
2042    ATTENTION
2043        .get_or_init(|| build_kernel(device, "rlx-wgpu attention", ATTENTION_WGSL, "attention"))
2044}
2045pub fn attention_bwd_kernel(device: &wgpu::Device) -> &'static Kernel {
2046    ATTENTION_BWD.get_or_init(|| {
2047        build_kernel(
2048            device,
2049            "rlx-wgpu attention_bwd",
2050            ATTENTION_BWD_WGSL,
2051            "attention_bwd",
2052        )
2053    })
2054}
2055pub fn rope_kernel(device: &wgpu::Device) -> &'static Kernel {
2056    ROPE.get_or_init(|| build_kernel(device, "rlx-wgpu rope", ROPE_WGSL, "rope"))
2057}
2058pub fn expand_kernel(device: &wgpu::Device) -> &'static Kernel {
2059    EXPAND.get_or_init(|| build_kernel_3(device, "rlx-wgpu expand", EXPAND_WGSL, "expand"))
2060}
2061pub fn argmax_kernel(device: &wgpu::Device) -> &'static Kernel {
2062    ARGMAX.get_or_init(|| build_kernel(device, "rlx-wgpu argmax", ARGMAX_WGSL, "argmax"))
2063}
2064pub fn pool2d_kernel(device: &wgpu::Device) -> &'static Kernel {
2065    POOL2D.get_or_init(|| build_kernel(device, "rlx-wgpu pool2d", POOL2D_WGSL, "pool2d"))
2066}
2067pub fn conv2d_kernel(device: &wgpu::Device) -> &'static Kernel {
2068    CONV2D.get_or_init(|| build_kernel(device, "rlx-wgpu conv2d", CONV2D_WGSL, "conv2d"))
2069}
2070pub fn pool1d_kernel(device: &wgpu::Device) -> &'static Kernel {
2071    POOL1D.get_or_init(|| build_kernel(device, "rlx-wgpu pool1d", POOL1D_WGSL, "pool1d"))
2072}
2073pub fn pool3d_kernel(device: &wgpu::Device) -> &'static Kernel {
2074    POOL3D.get_or_init(|| build_kernel(device, "rlx-wgpu pool3d", POOL3D_WGSL, "pool3d"))
2075}
2076pub fn conv1d_kernel(device: &wgpu::Device) -> &'static Kernel {
2077    CONV1D.get_or_init(|| build_kernel(device, "rlx-wgpu conv1d", CONV1D_WGSL, "conv1d"))
2078}
2079pub fn conv3d_kernel(device: &wgpu::Device) -> &'static Kernel {
2080    CONV3D.get_or_init(|| build_kernel(device, "rlx-wgpu conv3d", CONV3D_WGSL, "conv3d"))
2081}
2082pub fn scatter_add_kernel(device: &wgpu::Device) -> &'static Kernel {
2083    SCATTER_ADD.get_or_init(|| {
2084        build_kernel(
2085            device,
2086            "rlx-wgpu scatter_add",
2087            SCATTER_ADD_WGSL,
2088            "scatter_add",
2089        )
2090    })
2091}
2092pub fn topk_kernel(device: &wgpu::Device) -> &'static Kernel {
2093    TOPK.get_or_init(|| build_kernel(device, "rlx-wgpu topk", TOPK_WGSL, "topk"))
2094}
2095pub fn welch_peaks_gpu_kernel(device: &wgpu::Device) -> &'static Kernel {
2096    WELCH_PEAKS_GPU.get_or_init(|| {
2097        build_kernel(
2098            device,
2099            "rlx-wgpu welch_peaks_gpu",
2100            WELCH_PEAKS_GPU_WGSL,
2101            "welch_peaks_gpu",
2102        )
2103    })
2104}
2105pub fn umap_knn_kernel(device: &wgpu::Device) -> &'static Kernel {
2106    UMAP_KNN.get_or_init(|| build_kernel(device, "rlx-wgpu umap_knn", UMAP_KNN_WGSL, "umap_knn"))
2107}
2108pub fn grouped_matmul_kernel(device: &wgpu::Device) -> &'static Kernel {
2109    GROUPED_MATMUL.get_or_init(|| {
2110        build_kernel(
2111            device,
2112            "rlx-wgpu grouped_matmul",
2113            GROUPED_MATMUL_WGSL,
2114            "grouped_matmul",
2115        )
2116    })
2117}
2118pub fn sample_kernel(device: &wgpu::Device) -> &'static Kernel {
2119    SAMPLE.get_or_init(|| build_kernel(device, "rlx-wgpu sample", SAMPLE_WGSL, "sample"))
2120}
2121pub fn selective_scan_kernel(device: &wgpu::Device) -> &'static Kernel {
2122    SELECTIVE_SCAN.get_or_init(|| {
2123        build_kernel(
2124            device,
2125            "rlx-wgpu selective_scan",
2126            SELECTIVE_SCAN_WGSL,
2127            "selective_scan",
2128        )
2129    })
2130}
2131pub fn dequant_matmul_kernel(device: &wgpu::Device) -> &'static Kernel {
2132    DEQUANT_MATMUL.get_or_init(|| {
2133        build_kernel(
2134            device,
2135            "rlx-wgpu dequant_matmul",
2136            DEQUANT_MATMUL_WGSL,
2137            "dequant_matmul",
2138        )
2139    })
2140}
2141pub fn fused_residual_ln_kernel(device: &wgpu::Device) -> &'static Kernel {
2142    FUSED_RESIDUAL_LN.get_or_init(|| {
2143        build_kernel(
2144            device,
2145            "rlx-wgpu fused_residual_ln",
2146            FUSED_RESIDUAL_LN_WGSL,
2147            "fused_residual_ln",
2148        )
2149    })
2150}
2151pub fn fused_residual_ln_tee_kernel(device: &wgpu::Device) -> &'static Kernel {
2152    FUSED_RESIDUAL_LN_TEE.get_or_init(|| {
2153        build_kernel(
2154            device,
2155            "rlx-wgpu fused_residual_ln_tee",
2156            FUSED_RESIDUAL_LN_TEE_WGSL,
2157            "fused_residual_ln_tee",
2158        )
2159    })
2160}
2161pub fn fused_residual_rms_norm_kernel(device: &wgpu::Device) -> &'static Kernel {
2162    FUSED_RESIDUAL_RMS_NORM.get_or_init(|| {
2163        build_kernel(
2164            device,
2165            "rlx-wgpu fused_residual_rms_norm",
2166            FUSED_RESIDUAL_RMS_NORM_WGSL,
2167            "fused_residual_rms_norm",
2168        )
2169    })
2170}
2171pub fn matmul_qkv_kernel(device: &wgpu::Device) -> &'static Kernel {
2172    MATMUL_QKV
2173        .get_or_init(|| build_kernel(device, "rlx-wgpu matmul_qkv", MATMUL_QKV_WGSL, "matmul_qkv"))
2174}
2175pub fn matmul_qkv_coop_f32_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
2176    if !device
2177        .features()
2178        .contains(wgpu::Features::EXPERIMENTAL_COOPERATIVE_MATRIX)
2179    {
2180        return None;
2181    }
2182    Some(MATMUL_QKV_COOP_F32.get_or_init(|| {
2183        build_kernel(
2184            device,
2185            "rlx-wgpu matmul_qkv_coop_f32",
2186            MATMUL_QKV_COOP_F32_WGSL,
2187            "matmul_qkv_coop_f32",
2188        )
2189    }))
2190}
2191pub fn matmul_qkv_coop_f16_vk_f32acc_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
2192    if !coop_f16_vk_f32acc_device_ready(device) {
2193        return None;
2194    }
2195    MATMUL_QKV_COOP_F16_VK_F32ACC
2196        .get_or_init(|| {
2197            try_build_kernel_coop_f16_vk(
2198                device,
2199                "rlx-wgpu matmul_qkv_coop_f16_vk_f32acc",
2200                MATMUL_QKV_COOP_F16_VK_F32ACC_WGSL,
2201                "matmul_qkv_coop_f16_vk_f32acc",
2202            )
2203        })
2204        .as_ref()
2205}
2206
2207pub fn matmul_qkv_coop_f16_vk_widen_f32acc_kernel(
2208    device: &wgpu::Device,
2209) -> Option<&'static Kernel> {
2210    if !coop_f16_vk_f32acc_device_ready(device) {
2211        return None;
2212    }
2213    MATMUL_QKV_COOP_F16_VK_WIDEN_F32ACC
2214        .get_or_init(|| {
2215            try_build_kernel_coop_f16_vk(
2216                device,
2217                "rlx-wgpu matmul_qkv_coop_f16_vk_widen_f32acc",
2218                MATMUL_QKV_COOP_F16_VK_WIDEN_F32ACC_WGSL,
2219                "matmul_qkv_coop_f16_vk_widen_f32acc",
2220            )
2221        })
2222        .as_ref()
2223}
2224
2225pub fn matmul_qkv_coop_f16_vk_active_kernel(
2226    device: &wgpu::Device,
2227    n: u32,
2228) -> Option<&'static Kernel> {
2229    pick_coop_f16_vk_matmul(
2230        device,
2231        n,
2232        matmul_qkv_coop_f16_vk_kernel,
2233        matmul_qkv_coop_f16_vk_f32acc_kernel,
2234        matmul_qkv_coop_f16_vk_widen_kernel,
2235        matmul_qkv_coop_f16_vk_widen_f32acc_kernel,
2236    )
2237}
2238
2239pub fn matmul_qkv_coop_f16_vk_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
2240    if !coop_f16_vk_device_ready(device) {
2241        return None;
2242    }
2243    Some(MATMUL_QKV_COOP_F16_VK.get_or_init(|| {
2244        build_kernel_coop_f16_vk(
2245            device,
2246            "rlx-wgpu matmul_qkv_coop_f16_vk",
2247            MATMUL_QKV_COOP_F16_VK_WGSL,
2248            "matmul_qkv_coop_f16_vk",
2249        )
2250    }))
2251}
2252pub fn matmul_qkv_coop_f16_vk_widen_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
2253    if !coop_f16_vk_device_ready(device) {
2254        return None;
2255    }
2256    Some(MATMUL_QKV_COOP_F16_VK_WIDEN.get_or_init(|| {
2257        build_kernel_coop_f16_vk(
2258            device,
2259            "rlx-wgpu matmul_qkv_coop_f16_vk_widen",
2260            MATMUL_QKV_COOP_F16_VK_WIDEN_WGSL,
2261            "matmul_qkv_coop_f16_vk_widen",
2262        )
2263    }))
2264}