// rlx_wgpu/kernels/mod.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! WGSL kernel sources + per-kernel pipeline cache.
//!
//! Pipelines are content-addressed: the same WGSL source and the same
//! entry point yield the same pipeline. We hold them in `OnceLock`s so a
//! single device dispatches every (graph, op) pair against a cached
//! compilation.
22
23use std::sync::OnceLock;
24
25use bytemuck::{Pod, Zeroable};
26
27pub const MATMUL_WGSL: &str = include_str!("matmul.wgsl");
28pub const MATMUL_WIDE_WGSL: &str = include_str!("matmul_wide.wgsl");
29pub const MATMUL_F16W_WGSL: &str = include_str!("matmul_f16w.wgsl");
30pub const MATMUL_F16_COMPUTE_WGSL: &str = include_str!("matmul_f16_compute.wgsl");
31pub const MATMUL_COOP16_WGSL: &str = include_str!("matmul_coop16.wgsl");
32pub const MATMUL_COOP_F32_WGSL: &str = include_str!("matmul_coop_f32.wgsl");
33pub const CAST_F32_TO_F16_WGSL: &str = include_str!("cast_f32_to_f16.wgsl");
34pub const BINARY_WGSL: &str = include_str!("binary.wgsl");
35pub const UNARY_WGSL: &str = include_str!("unary.wgsl");
36pub const COMPARE_WGSL: &str = include_str!("compare.wgsl");
37pub const WHERE_WGSL: &str = include_str!("where.wgsl");
38pub const REDUCE_WGSL: &str = include_str!("reduce.wgsl");
39pub const SOFTMAX_WGSL: &str = include_str!("softmax.wgsl");
40pub const LAYERNORM_WGSL: &str = include_str!("layernorm.wgsl");
41pub const RMS_NORM_BWD_WGSL: &str = include_str!("rms_norm_backward.wgsl");
42pub const CUMSUM_BWD_WGSL: &str = include_str!("cumsum_backward.wgsl");
43pub const ROPE_BWD_WGSL: &str = include_str!("rope_backward.wgsl");
44pub const GATHER_BWD_WGSL: &str = include_str!("gather_backward.wgsl");
45pub const CUMSUM_WGSL: &str = include_str!("cumsum.wgsl");
46pub const FFT_GPU_WGSL: &str = include_str!("fft_gpu.wgsl");
47pub const COPY_WGSL: &str = include_str!("copy.wgsl");
48pub const ELEMENTWISE_REGION_WGSL: &str = include_str!("elementwise_region.wgsl");
49pub const TRANSPOSE_WGSL: &str = include_str!("transpose.wgsl");
50pub const NARROW_WGSL: &str = include_str!("narrow.wgsl");
51pub const CONCAT_WGSL: &str = include_str!("concat.wgsl");
52pub const GATHER_WGSL: &str = include_str!("gather.wgsl");
53pub const GATHER_AXIS_WGSL: &str = include_str!("gather_axis.wgsl");
54pub const ATTENTION_WGSL: &str = include_str!("attention.wgsl");
55pub const ATTENTION_BWD_WGSL: &str = include_str!("attention_bwd.wgsl");
56pub const ROPE_WGSL: &str = include_str!("rope.wgsl");
57pub const EXPAND_WGSL: &str = include_str!("expand.wgsl");
58pub const ARGMAX_WGSL: &str = include_str!("argmax.wgsl");
59pub const POOL2D_WGSL: &str = include_str!("pool2d.wgsl");
60pub const CONV2D_WGSL: &str = include_str!("conv2d.wgsl");
61pub const POOL1D_WGSL: &str = include_str!("pool1d.wgsl");
62pub const POOL3D_WGSL: &str = include_str!("pool3d.wgsl");
63pub const CONV1D_WGSL: &str = include_str!("conv1d.wgsl");
64pub const CONV3D_WGSL: &str = include_str!("conv3d.wgsl");
65pub const SCATTER_ADD_WGSL: &str = include_str!("scatter_add.wgsl");
66pub const TOPK_WGSL: &str = include_str!("topk.wgsl");
67pub const UMAP_KNN_WGSL: &str = include_str!("umap_knn.wgsl");
68pub const GROUPED_MATMUL_WGSL: &str = include_str!("grouped_matmul.wgsl");
69pub const SAMPLE_WGSL: &str = include_str!("sample.wgsl");
70pub const SELECTIVE_SCAN_WGSL: &str = include_str!("selective_scan.wgsl");
71pub const DEQUANT_MATMUL_WGSL: &str = include_str!("dequant_matmul.wgsl");
72pub const FUSED_RESIDUAL_LN_WGSL: &str = include_str!("fused_residual_ln.wgsl");
73pub const FUSED_RESIDUAL_LN_TEE_WGSL: &str = include_str!("fused_residual_ln_tee.wgsl");
74pub const FUSED_RESIDUAL_RMS_NORM_WGSL: &str = include_str!("fused_residual_rms_norm.wgsl");
75pub const MATMUL_QKV_WGSL: &str = include_str!("matmul_qkv.wgsl");
76pub const MATMUL_QKV_COOP_F32_WGSL: &str = include_str!("matmul_qkv_coop_f32.wgsl");
77
78#[repr(C)]
79#[derive(Debug, Clone, Copy, Pod, Zeroable)]
80pub struct MatmulParams {
81    pub m: u32,
82    pub k: u32,
83    pub n: u32,
84    pub a_off: u32,
85    pub b_off: u32,
86    pub c_off: u32,
87    pub batch: u32,
88    pub a_batch_stride: u32,
89    pub b_batch_stride: u32,
90    pub c_batch_stride: u32,
91    pub has_bias: u32,
92    pub bias_off: u32,
93    pub act_id: u32, // 0xFFFF = no activation
94    pub _pad0: u32,
95    pub _pad1: u32,
96    pub _pad2: u32,
97}
98
99/// Shared layout for binary, compare. 32 bytes (8 u32s).
100#[repr(C)]
101#[derive(Debug, Clone, Copy, Pod, Zeroable)]
102pub struct BinaryParams {
103    pub n: u32,
104    pub a_off: u32,
105    pub b_off: u32,
106    pub c_off: u32,
107    pub op: u32,
108    pub _p0: u32,
109    pub _p1: u32,
110    pub _p2: u32,
111}
112
113/// Layout for unary kernel. 32 bytes.
114#[repr(C)]
115#[derive(Debug, Clone, Copy, Pod, Zeroable)]
116pub struct UnaryParams {
117    pub n: u32,
118    pub in_off: u32,
119    pub out_off: u32,
120    pub op: u32,
121    pub _p0: u32,
122    pub _p1: u32,
123    pub _p2: u32,
124    pub _p3: u32,
125}
126
127/// Layout for where (3-input select). 32 bytes.
128#[repr(C)]
129#[derive(Debug, Clone, Copy, Pod, Zeroable)]
130pub struct WhereParams {
131    pub n: u32,
132    pub cond_off: u32,
133    pub x_off: u32,
134    pub y_off: u32,
135    pub out_off: u32,
136    pub _p0: u32,
137    pub _p1: u32,
138    pub _p2: u32,
139}
140
141/// Layout for reductions. 32 bytes.
142#[repr(C)]
143#[derive(Debug, Clone, Copy, Pod, Zeroable)]
144pub struct ReduceParams {
145    pub outer: u32,
146    pub inner: u32,
147    pub in_off: u32,
148    pub out_off: u32,
149    pub op: u32,
150    pub _p0: u32,
151    pub _p1: u32,
152    pub _p2: u32,
153}
154
155/// Layout for softmax. 32 bytes.
156#[repr(C)]
157#[derive(Debug, Clone, Copy, Pod, Zeroable)]
158pub struct SoftmaxParams {
159    pub outer: u32,
160    pub inner: u32,
161    pub in_off: u32,
162    pub out_off: u32,
163    pub _p0: u32,
164    pub _p1: u32,
165    pub _p2: u32,
166    pub _p3: u32,
167}
168
169/// Layout for LayerNorm / RmsNorm.
170#[repr(C)]
171#[derive(Debug, Clone, Copy, Pod, Zeroable)]
172pub struct LayerNormParams {
173    pub outer: u32,
174    pub inner: u32,
175    pub in_off: u32,
176    pub out_off: u32,
177    pub gamma_off: u32,
178    pub beta_off: u32,
179    pub eps_bits: u32, // bitcast::<u32>(eps)
180    pub op: u32,       // 0=LayerNorm, 1=RmsNorm
181}
182
183/// RMSNorm backward kernel params (f32 element offsets). `wrt`: 0=dx, 1=dgamma, 2=dbeta.
184#[repr(C)]
185#[derive(Debug, Clone, Copy, Pod, Zeroable)]
186pub struct RmsNormBwdParams {
187    pub outer: u32,
188    pub inner: u32,
189    pub x_off: u32,
190    pub gamma_off: u32,
191    pub beta_off: u32,
192    pub dy_off: u32,
193    pub out_off: u32,
194    pub eps_bits: u32,
195    pub wrt: u32,
196}
197
198#[repr(C)]
199#[derive(Debug, Clone, Copy, Pod, Zeroable)]
200pub struct CumsumBwdParams {
201    pub outer: u32,
202    pub inner: u32,
203    pub dy_off: u32,
204    pub dx_off: u32,
205    pub exclusive: u32,
206    pub _p0: u32,
207    pub _p1: u32,
208    pub _p2: u32,
209}
210
211#[repr(C)]
212#[derive(Debug, Clone, Copy, Pod, Zeroable)]
213pub struct RopeBwdParams {
214    pub batch: u32,
215    pub seq: u32,
216    pub hidden: u32,
217    pub head_dim: u32,
218    pub n_rot: u32,
219    pub dy_off: u32,
220    pub cos_off: u32,
221    pub sin_off: u32,
222    pub dx_off: u32,
223    pub cos_len: u32,
224}
225
226#[repr(C)]
227#[derive(Debug, Clone, Copy, Pod, Zeroable)]
228pub struct GatherBwdParams {
229    pub outer: u32,
230    pub axis_dim: u32,
231    pub num_idx: u32,
232    pub trailing: u32,
233    pub dy_off: u32,
234    pub idx_off: u32,
235    pub dst_off: u32,
236    pub _p0: u32,
237}
238
239/// Layout for cumsum. 32 bytes.
240#[repr(C)]
241#[derive(Debug, Clone, Copy, Pod, Zeroable)]
242pub struct CumsumParams {
243    pub outer: u32,
244    pub inner: u32,
245    pub in_off: u32,
246    pub out_off: u32,
247    pub exclusive: u32,
248    pub _p0: u32,
249    pub _p1: u32,
250    pub _p2: u32,
251}
252
253/// Layout for FFT. 32 bytes. Matches `fft.wgsl::Params`.
254#[repr(C)]
255#[derive(Debug, Clone, Copy, Pod, Zeroable)]
256pub struct FftParams {
257    pub src_off: u32,
258    pub dst_off: u32,
259    pub n: u32,
260    pub log2n: u32,
261    pub inverse: u32,
262    pub norm_scale: f32,
263    pub _p1: u32,
264    pub _p2: u32,
265}
266
267/// Uniform block for multi-kernel FFT (`fft_gpu.wgsl::Params`). 48 bytes.
268#[repr(C)]
269#[derive(Debug, Clone, Copy, Pod, Zeroable)]
270pub struct FftGpuParams {
271    pub off: u32,
272    pub dst_off: u32,
273    pub n: u32,
274    pub log2n: u32,
275    pub inverse: u32,
276    pub norm_scale: f32,
277    pub outer: u32,
278    pub tile: u32,
279    pub inner_stages: u32,
280    pub q_or_hs: u32,
281}
282
283/// PLAN L2 — interpreted N-ary element-wise region. Chain encoded
284/// as 4 u32s per step (op_kind, op_sub, lhs_enc, rhs_enc). Operand
285/// encoding: bit 31 = src kind (0=Input, 1=Step), bits 0..30 = index.
286/// `scalar_input_mask` is the per-input scalar fast-path bitfield;
287/// `input_modulus[i]` is the per-input element count for trailing-
288/// shape broadcast (`0` ⇒ no broadcast, kernel reads gid; `>0` ⇒
289/// kernel reads `gid % input_modulus[i]`). Fixed cap at 32 steps +
290/// 16 inputs (ample for chains rlx produces). 12 padding bytes
291/// after `scalar_input_mask` align the next array on WGSL's
292/// 16-byte uniform alignment boundary.
293#[repr(C)]
294#[derive(Debug, Clone, Copy, Pod, Zeroable)]
295pub struct ElementwiseRegionParams {
296    pub len: u32,
297    pub num_inputs: u32,
298    pub num_steps: u32,
299    pub dst_off: u32,
300    pub input_offs: [u32; 16],
301    pub chain: [u32; 128], // 32 steps * 4 u32s
302    pub scalar_input_mask: u32,
303    pub _pad0: u32,
304    pub _pad1: u32,
305    pub _pad2: u32,
306    pub input_modulus: [u32; 16],
307}
308
309/// Layout shared by Reshape / Cast / generic full copy. 32 bytes.
310#[repr(C)]
311#[derive(Debug, Clone, Copy, Pod, Zeroable)]
312pub struct CopyParams {
313    pub n: u32,
314    pub in_off: u32,
315    pub out_off: u32,
316    pub _p0: u32,
317    pub _p1: u32,
318    pub _p2: u32,
319    pub _p3: u32,
320    pub _p4: u32,
321}
322
323/// Layout for transpose (uses the 3-binding bind layout).
324#[repr(C)]
325#[derive(Debug, Clone, Copy, Pod, Zeroable)]
326pub struct TransposeParams {
327    pub rank: u32,
328    pub out_total: u32,
329    pub in_off: u32,
330    pub out_off: u32,
331    /// PLAN L1 — precomputed at compile time. `1` when `perm[0] == 0`
332    /// (= bucket axis stays at output axis 0). Active-extent path
333    /// scales `out_total` proportionally only when this is `1`.
334    pub bucket_outermost: u32,
335    /// PLAN L1 — `out_dims[0]` for active-extent scaling math.
336    pub out_dim_0: u32,
337    pub _p2: u32,
338    pub _p3: u32,
339}
340
341/// Layout for narrow / concat (the same struct serves both).
342#[repr(C)]
343#[derive(Debug, Clone, Copy, Pod, Zeroable)]
344pub struct NarrowConcatParams {
345    pub total: u32, // total elements (output for narrow, input for concat)
346    pub outer: u32,
347    pub inner: u32,
348    pub axis_in_size: u32,
349    pub axis_out_size: u32,
350    pub start: u32,
351    pub in_off: u32,
352    pub out_off: u32,
353}
354
355/// Layout for gather.
356#[repr(C)]
357#[derive(Debug, Clone, Copy, Pod, Zeroable)]
358pub struct GatherParams {
359    pub n_out: u32,
360    pub n_idx: u32,
361    pub dim: u32,
362    pub vocab: u32,
363    pub in_off: u32,
364    pub idx_off: u32,
365    pub out_off: u32,
366    pub _p0: u32,
367}
368
369/// Layout for gather along a non-zero axis.
370#[repr(C)]
371#[derive(Debug, Clone, Copy, Pod, Zeroable)]
372pub struct GatherAxisParams {
373    pub total: u32,
374    pub outer: u32,
375    pub axis_dim: u32,
376    pub num_idx: u32,
377    pub trailing: u32,
378    pub table_off: u32,
379    pub idx_off: u32,
380    pub out_off: u32,
381}
382
383/// Layout for fused SDPA.
384///
385/// Per-tensor (Q, K, V, output) strides are passed explicitly so the
386/// kernel can read either canonical [B, H, S, D] or transposed
387/// [B, S, H, D] without inserting upstream Transpose dispatches. The
388/// layout-elimination saves ~24 transpose dispatches per BERT-L6
389/// forward (one per Q/K/V/output × layers), each ~50µs at small batch.
390///
391/// The `seq_q_stride` / `seq_k_stride` fields are retained because
392/// they describe the MASK layout `[B, H, S_q, S_k]` (separate from
393/// Q/K/V layout), used by `MaskKind::Custom`.
394///
395/// 144 bytes (36 u32s); WebGPU uniform-buffer 16-byte alignment OK.
396#[repr(C)]
397#[derive(Debug, Clone, Copy, Pod, Zeroable)]
398pub struct AttentionParams {
399    pub batch: u32,
400    pub heads: u32,
401    pub seq_q: u32,
402    pub seq_k: u32,
403    pub head_dim: u32,
404    pub q_off: u32,
405    pub k_off: u32,
406    pub v_off: u32,
407    pub out_off: u32,
408    pub mask_off: u32,
409    pub mask_kind: u32,
410    pub scale_bits: u32,
411    pub window: u32,
412    /// MASK address strides. Mask address math (per-element):
413    ///   addr = mask_off
414    ///        + b  * mask_batch_stride
415    ///        + h  * mask_head_stride
416    ///        + qi * seq_q_stride         (per-query stride)
417    ///        + s  * seq_k_stride         (per-key   stride)
418    /// Setting some strides to 0 lets the kernel read a *broadcast*
419    /// mask without materializing the broadcast. e.g. BERT padding mask
420    /// `[B, S]`: mask_batch_stride=S, mask_head_stride=0, seq_q_stride=0,
421    /// seq_k_stride=1. Saves the Expand pre-pass that unfuse used to
422    /// emit per attention block.
423    pub seq_q_stride: u32,
424    pub seq_k_stride: u32,
425    pub mask_batch_stride: u32,
426    pub mask_head_stride: u32,
427    pub _pad_mask_0: u32,
428    pub _pad_mask_1: u32,
429    pub _pad_mask_2: u32,
430
431    // Q stride triple (in f32 elements). For [B, H, S, D]:
432    //   q_batch_stride = H·S·D, q_head_stride = S·D, q_seq_stride = D
433    // For [B, S, H, D]:
434    //   q_batch_stride = S·H·D, q_head_stride = D,   q_seq_stride = H·D
435    pub q_batch_stride: u32,
436    pub q_head_stride: u32,
437    pub q_seq_stride: u32,
438    pub _pad_q: u32,
439
440    pub k_batch_stride: u32,
441    pub k_head_stride: u32,
442    pub k_seq_stride: u32,
443    pub _pad_k: u32,
444
445    pub v_batch_stride: u32,
446    pub v_head_stride: u32,
447    pub v_seq_stride: u32,
448    pub _pad_v: u32,
449
450    pub o_batch_stride: u32,
451    pub o_head_stride: u32,
452    pub o_seq_stride: u32,
453    pub _pad_o: u32,
454}
455
456/// Layout for [`attention_bwd.wgsl`] — forward strides + `dy_off` + `wrt`.
457#[repr(C)]
458#[derive(Debug, Clone, Copy, Pod, Zeroable)]
459pub struct AttentionBwdParams {
460    pub batch: u32,
461    pub heads: u32,
462    pub seq_q: u32,
463    pub seq_k: u32,
464    pub head_dim: u32,
465    pub q_off: u32,
466    pub k_off: u32,
467    pub v_off: u32,
468    pub dy_off: u32,
469    pub out_off: u32,
470    pub mask_off: u32,
471    pub mask_kind: u32,
472    pub scale_bits: u32,
473    pub window: u32,
474    pub wrt: u32,
475    pub seq_q_stride: u32,
476    pub seq_k_stride: u32,
477    pub mask_batch_stride: u32,
478    pub mask_head_stride: u32,
479    pub _pad_mask_0: u32,
480    pub _pad_mask_1: u32,
481    pub _pad_mask_2: u32,
482    pub q_batch_stride: u32,
483    pub q_head_stride: u32,
484    pub q_seq_stride: u32,
485    pub _pad_q: u32,
486    pub k_batch_stride: u32,
487    pub k_head_stride: u32,
488    pub k_seq_stride: u32,
489    pub _pad_k: u32,
490    pub v_batch_stride: u32,
491    pub v_head_stride: u32,
492    pub v_seq_stride: u32,
493    pub _pad_v: u32,
494    pub o_batch_stride: u32,
495    pub o_head_stride: u32,
496    pub o_seq_stride: u32,
497    pub _pad_o: u32,
498}
499
500/// Layout for Rope.
501#[repr(C)]
502#[derive(Debug, Clone, Copy, Pod, Zeroable)]
503pub struct RopeParams {
504    pub n_total: u32,
505    pub seq: u32,
506    pub head_dim: u32,
507    pub half: u32,
508    pub in_off: u32,
509    pub cos_off: u32,
510    pub sin_off: u32,
511    pub out_off: u32,
512    pub last_dim: u32,
513    /// PLAN L1 — set at compile time. Together with `seq_stride`,
514    /// lets the WGSL kernel decompose iteration index into
515    /// `(bi, si, d)` while indexing into the underlying full-extent
516    /// buffer. `n_total` is the runtime-scaled iteration bound;
517    /// `seq_stride` is the compile-time-fixed full seq for stride.
518    pub batch: u32,
519    pub seq_stride: u32,
520    pub _p2: u32,
521}
522
523/// Layout for Expand. Mirrors TransposeParams (rank, total, offsets);
524/// per-axis dims/strides ride in the meta storage buffer.
525#[repr(C)]
526#[derive(Debug, Clone, Copy, Pod, Zeroable)]
527pub struct ExpandParams {
528    pub rank: u32,
529    pub out_total: u32,
530    pub in_off: u32,
531    pub out_off: u32,
532    /// PLAN L1 — precomputed at compile time. `1` when the bucket
533    /// axis stays at output axis 0 after the expand mapping.
534    pub bucket_outermost: u32,
535    /// PLAN L1 — `out_dims[0]` for active-extent scaling math.
536    pub out_dim_0: u32,
537    pub _p2: u32,
538    pub _p3: u32,
539}
540
541/// Layout for argmax (matches Reduce shape).
542#[repr(C)]
543#[derive(Debug, Clone, Copy, Pod, Zeroable)]
544pub struct ArgmaxParams {
545    pub outer: u32,
546    pub inner: u32,
547    pub in_off: u32,
548    pub out_off: u32,
549    pub _p0: u32,
550    pub _p1: u32,
551    pub _p2: u32,
552    pub _p3: u32,
553}
554
555/// Layout for Pool2D NCHW.
556#[repr(C)]
557#[derive(Debug, Clone, Copy, Pod, Zeroable)]
558pub struct Pool2dParams {
559    pub n: u32,
560    pub c: u32,
561    pub h: u32,
562    pub w: u32,
563    pub h_out: u32,
564    pub w_out: u32,
565    pub kh: u32,
566    pub kw: u32,
567    pub sh: u32,
568    pub sw: u32,
569    pub ph: u32,
570    pub pw: u32,
571    pub op: u32,
572    pub in_off: u32,
573    pub out_off: u32,
574    pub _p0: u32,
575    pub _p1: u32,
576    pub _p2: u32,
577}
578
579/// Layout for Conv2D NCHW.
580#[repr(C)]
581#[derive(Debug, Clone, Copy, Pod, Zeroable)]
582pub struct Conv2dParams {
583    pub n: u32,
584    pub c_in: u32,
585    pub c_out: u32,
586    pub h: u32,
587    pub w: u32,
588    pub h_out: u32,
589    pub w_out: u32,
590    pub kh: u32,
591    pub kw: u32,
592    pub sh: u32,
593    pub sw: u32,
594    pub ph: u32,
595    pub pw: u32,
596    pub dh: u32,
597    pub dw: u32,
598    pub groups: u32,
599    pub in_off: u32,
600    pub w_off: u32,
601    pub out_off: u32,
602}
603
604/// Layout for Pool1D NCL.
605#[repr(C)]
606#[derive(Debug, Clone, Copy, Pod, Zeroable)]
607pub struct Pool1dParams {
608    pub n: u32,
609    pub c: u32,
610    pub l: u32,
611    pub l_out: u32,
612    pub kl: u32,
613    pub sl: u32,
614    pub pl: u32,
615    pub op: u32,
616    pub in_off: u32,
617    pub out_off: u32,
618    pub _p0: u32,
619    pub _p1: u32,
620    pub _p2: u32,
621    pub _p3: u32,
622    pub _p4: u32,
623    pub _p5: u32,
624}
625
626/// Layout for Pool3D NCDHW.
627#[repr(C)]
628#[derive(Debug, Clone, Copy, Pod, Zeroable)]
629pub struct Pool3dParams {
630    pub n: u32,
631    pub c: u32,
632    pub d: u32,
633    pub h: u32,
634    pub w: u32,
635    pub d_out: u32,
636    pub h_out: u32,
637    pub w_out: u32,
638    pub kd: u32,
639    pub kh: u32,
640    pub kw: u32,
641    pub sd: u32,
642    pub sh: u32,
643    pub sw: u32,
644    pub pd: u32,
645    pub ph: u32,
646    pub pw: u32,
647    pub op: u32,
648    pub in_off: u32,
649    pub out_off: u32,
650    pub _p0: u32,
651    pub _p1: u32,
652}
653
654/// Layout for Conv1D NCL.
655#[repr(C)]
656#[derive(Debug, Clone, Copy, Pod, Zeroable)]
657pub struct Conv1dParams {
658    pub n: u32,
659    pub c_in: u32,
660    pub c_out: u32,
661    pub l: u32,
662    pub l_out: u32,
663    pub kl: u32,
664    pub sl: u32,
665    pub pl: u32,
666    pub dl: u32,
667    pub groups: u32,
668    pub in_off: u32,
669    pub w_off: u32,
670    pub out_off: u32,
671    pub _p0: u32,
672    pub _p1: u32,
673    pub _p2: u32,
674}
675
676/// Layout for DequantMatMul. 48 bytes.
677#[repr(C)]
678#[derive(Debug, Clone, Copy, Pod, Zeroable)]
679pub struct DequantMatmulParams {
680    pub m: u32,
681    pub k: u32,
682    pub n: u32,
683    pub block_size: u32,
684    pub scheme_id: u32,
685    pub x_off: u32,
686    pub w_off: u32,
687    pub scale_off: u32,
688    pub zp_off: u32,
689    pub out_off: u32,
690    pub _p0: u32,
691    pub _p1: u32,
692}
693
694/// Layout for FusedResidualLN-Tee. 48 bytes (12 u32s).
695#[repr(C)]
696#[derive(Debug, Clone, Copy, Pod, Zeroable)]
697pub struct FusedResidualLnTeeParams {
698    pub outer: u32,
699    pub inner: u32,
700    pub in_off: u32,
701    pub residual_off: u32,
702    pub bias_off: u32,
703    pub gamma_off: u32,
704    pub beta_off: u32,
705    pub sum_off: u32,
706    pub ln_out_off: u32,
707    pub eps_bits: u32,
708    pub has_bias: u32,
709    pub _p0: u32,
710}
711
712/// Layout for matmul_qkv (split-write QKV matmul).
713/// 64 bytes (16 u32s); WebGPU uniform-buffer 16-byte alignment OK.
714#[repr(C)]
715#[derive(Debug, Clone, Copy, Pod, Zeroable)]
716pub struct MatmulQkvParams {
717    pub m: u32,
718    pub k: u32,
719    pub n: u32,
720    pub a_off: u32,
721    pub b_off: u32,
722    pub q_off: u32,
723    pub k_off: u32,
724    pub v_off: u32,
725    pub head_width: u32,
726    pub has_bias: u32,
727    pub bias_off: u32,
728    pub _p0: u32,
729    pub _p1: u32,
730    pub _p2: u32,
731    pub _p3: u32,
732    pub _p4: u32,
733}
734
735/// Layout for FusedResidualRmsNorm (same bind layout as FusedResidualLN).
736pub type FusedResidualRmsNormParams = FusedResidualLnParams;
737
738/// Layout for FusedResidualLN. 48 bytes.
739#[repr(C)]
740#[derive(Debug, Clone, Copy, Pod, Zeroable)]
741pub struct FusedResidualLnParams {
742    pub outer: u32,
743    pub inner: u32,
744    pub in_off: u32,
745    pub residual_off: u32,
746    pub bias_off: u32,
747    pub gamma_off: u32,
748    pub beta_off: u32,
749    pub out_off: u32,
750    pub eps_bits: u32,
751    pub has_bias: u32,
752    pub _p0: u32,
753    pub _p1: u32,
754}
755
756/// Layout for SelectiveScan. 64 bytes.
757#[repr(C)]
758#[derive(Debug, Clone, Copy, Pod, Zeroable)]
759pub struct SelectiveScanParams {
760    pub batch: u32,
761    pub seq: u32,
762    pub hidden: u32,
763    pub state_size: u32,
764    pub x_off: u32,
765    pub delta_off: u32,
766    pub a_off: u32,
767    pub b_off: u32,
768    pub c_off: u32,
769    pub out_off: u32,
770    /// PLAN L1 — full-extent seq stride for per-batch offset math.
771    /// Stays at compile-time `seq` even when runtime `seq` is scaled,
772    /// so per-batch arena offsets stay correct under active-extent.
773    pub seq_stride: u32,
774    pub _p1: u32,
775    pub _p2: u32,
776    pub _p3: u32,
777    pub _p4: u32,
778    pub _p5: u32,
779}
780
781/// Layout for Sample. 48 bytes.
782#[repr(C)]
783#[derive(Debug, Clone, Copy, Pod, Zeroable)]
784pub struct SampleParams {
785    pub outer: u32,
786    pub inner: u32,
787    pub in_off: u32,
788    pub out_off: u32,
789    pub top_k: u32,
790    pub top_p_bits: u32,
791    pub temp_bits: u32,
792    pub seed_lo: u32,
793    pub seed_hi: u32,
794    pub _p0: u32,
795    pub _p1: u32,
796    pub _p2: u32,
797}
798
799/// Layout for GroupedMatMul. 32 bytes.
800#[repr(C)]
801#[derive(Debug, Clone, Copy, Pod, Zeroable)]
802pub struct GroupedMatmulParams {
803    pub m: u32,
804    pub k: u32,
805    pub n: u32,
806    pub num_experts: u32,
807    pub in_off: u32,
808    pub w_off: u32,
809    pub idx_off: u32,
810    pub out_off: u32,
811}
812
813/// Layout for TopK. 32 bytes.
814#[repr(C)]
815#[derive(Debug, Clone, Copy, Pod, Zeroable)]
816pub struct TopKParams {
817    pub outer: u32,
818    pub inner: u32,
819    pub k: u32,
820    pub in_off: u32,
821    pub out_off: u32,
822    pub _p0: u32,
823    pub _p1: u32,
824    pub _p2: u32,
825}
826
827/// Layout for UMAP k-NN on a pairwise `[n, n]` matrix. 32 bytes.
828#[repr(C)]
829#[derive(Debug, Clone, Copy, Pod, Zeroable)]
830pub struct UmapKnnParams {
831    pub n: u32,
832    pub k: u32,
833    pub pw_off: u32,
834    pub out_off: u32,
835    pub _p0: u32,
836    pub _p1: u32,
837    pub _p2: u32,
838}
839
840/// Layout for ScatterAdd. 32 bytes (8 u32s).
841#[repr(C)]
842#[derive(Debug, Clone, Copy, Pod, Zeroable)]
843pub struct ScatterAddParams {
844    pub op: u32, // 0 = zero phase, 1 = accumulate phase
845    pub out_off: u32,
846    pub upd_off: u32,
847    pub idx_off: u32,
848    pub out_total: u32,
849    pub num_updates: u32,
850    pub trailing: u32,
851    pub out_dim: u32,
852}
853
854/// Layout for Conv3D NCDHW.
855#[repr(C)]
856#[derive(Debug, Clone, Copy, Pod, Zeroable)]
857pub struct Conv3dParams {
858    pub n: u32,
859    pub c_in: u32,
860    pub c_out: u32,
861    pub d: u32,
862    pub h: u32,
863    pub w: u32,
864    pub d_out: u32,
865    pub h_out: u32,
866    pub w_out: u32,
867    pub kd: u32,
868    pub kh: u32,
869    pub kw: u32,
870    pub sd: u32,
871    pub sh: u32,
872    pub sw: u32,
873    pub pd: u32,
874    pub ph: u32,
875    pub pw: u32,
876    pub dd: u32,
877    pub dh: u32,
878    pub dw: u32,
879    pub groups: u32,
880    pub in_off: u32,
881    pub w_off: u32,
882    pub out_off: u32,
883    pub _p0: u32,
884}
885
886/// Lazy-init container for a compute pipeline + its bind-group layout.
887pub struct Kernel {
888    pub pipeline: wgpu::ComputePipeline,
889    pub bgl: wgpu::BindGroupLayout,
890}
891
892impl Kernel {
893    pub fn bind_two(
894        &self,
895        device: &wgpu::Device,
896        arena: &wgpu::Buffer,
897        uniform: &wgpu::Buffer,
898    ) -> wgpu::BindGroup {
899        device.create_bind_group(&wgpu::BindGroupDescriptor {
900            label: Some("rlx-wgpu fft gpu bg"),
901            layout: &self.bgl,
902            entries: &[
903                wgpu::BindGroupEntry {
904                    binding: 0,
905                    resource: arena.as_entire_binding(),
906                },
907                wgpu::BindGroupEntry {
908                    binding: 1,
909                    resource: uniform.as_entire_binding(),
910                },
911            ],
912        })
913    }
914}
915
916/// Build a 4-binding compute kernel: storage(rw) / uniform / storage(ro)
917/// / storage(ro). Currently unused — `matmul_coop16` switched to a
918/// 3-binding layout (A is staged from arena through workgroup memory
919/// instead of from a separate f16 binding). Kept for future kernels
920/// that genuinely need a 4th binding.
921#[allow(dead_code)]
922/// Used by the cooperative-matrix matmul which needs a
923/// fourth binding for the f16 activation shadow buffer.
924fn build_kernel_4(
925    device: &wgpu::Device,
926    label: &'static str,
927    wgsl: &str,
928    entry_point: &'static str,
929) -> Kernel {
930    let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
931        label: Some(label),
932        source: wgpu::ShaderSource::Wgsl(wgsl.into()),
933    });
934    let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
935        label: Some(label),
936        entries: &[
937            wgpu::BindGroupLayoutEntry {
938                binding: 0,
939                visibility: wgpu::ShaderStages::COMPUTE,
940                ty: wgpu::BindingType::Buffer {
941                    ty: wgpu::BufferBindingType::Storage { read_only: false },
942                    has_dynamic_offset: false,
943                    min_binding_size: None,
944                },
945                count: None,
946            },
947            wgpu::BindGroupLayoutEntry {
948                binding: 1,
949                visibility: wgpu::ShaderStages::COMPUTE,
950                ty: wgpu::BindingType::Buffer {
951                    ty: wgpu::BufferBindingType::Uniform,
952                    has_dynamic_offset: false,
953                    min_binding_size: None,
954                },
955                count: None,
956            },
957            wgpu::BindGroupLayoutEntry {
958                binding: 2,
959                visibility: wgpu::ShaderStages::COMPUTE,
960                ty: wgpu::BindingType::Buffer {
961                    ty: wgpu::BufferBindingType::Storage { read_only: true },
962                    has_dynamic_offset: false,
963                    min_binding_size: None,
964                },
965                count: None,
966            },
967            wgpu::BindGroupLayoutEntry {
968                binding: 3,
969                visibility: wgpu::ShaderStages::COMPUTE,
970                ty: wgpu::BindingType::Buffer {
971                    ty: wgpu::BufferBindingType::Storage { read_only: true },
972                    has_dynamic_offset: false,
973                    min_binding_size: None,
974                },
975                count: None,
976            },
977        ],
978    });
979    let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
980        label: Some(label),
981        bind_group_layouts: &[Some(&bgl)],
982        immediate_size: 0,
983    });
984    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
985        label: Some(label),
986        layout: Some(&layout),
987        module: &module,
988        entry_point: Some(entry_point),
989        compilation_options: Default::default(),
990        cache: None,
991    });
992    Kernel { pipeline, bgl }
993}
994
995fn build_kernel_3(
996    device: &wgpu::Device,
997    label: &'static str,
998    wgsl: &str,
999    entry_point: &'static str,
1000) -> Kernel {
1001    let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
1002        label: Some(label),
1003        source: wgpu::ShaderSource::Wgsl(wgsl.into()),
1004    });
1005    let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
1006        label: Some(label),
1007        entries: &[
1008            wgpu::BindGroupLayoutEntry {
1009                binding: 0,
1010                visibility: wgpu::ShaderStages::COMPUTE,
1011                ty: wgpu::BindingType::Buffer {
1012                    ty: wgpu::BufferBindingType::Storage { read_only: false },
1013                    has_dynamic_offset: false,
1014                    min_binding_size: None,
1015                },
1016                count: None,
1017            },
1018            wgpu::BindGroupLayoutEntry {
1019                binding: 1,
1020                visibility: wgpu::ShaderStages::COMPUTE,
1021                ty: wgpu::BindingType::Buffer {
1022                    ty: wgpu::BufferBindingType::Uniform,
1023                    has_dynamic_offset: false,
1024                    min_binding_size: None,
1025                },
1026                count: None,
1027            },
1028            wgpu::BindGroupLayoutEntry {
1029                binding: 2,
1030                visibility: wgpu::ShaderStages::COMPUTE,
1031                ty: wgpu::BindingType::Buffer {
1032                    ty: wgpu::BufferBindingType::Storage { read_only: true },
1033                    has_dynamic_offset: false,
1034                    min_binding_size: None,
1035                },
1036                count: None,
1037            },
1038        ],
1039    });
1040    let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
1041        label: Some(label),
1042        bind_group_layouts: &[Some(&bgl)],
1043        immediate_size: 0,
1044    });
1045    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
1046        label: Some(label),
1047        layout: Some(&layout),
1048        module: &module,
1049        entry_point: Some(entry_point),
1050        compilation_options: Default::default(),
1051        cache: None,
1052    });
1053    Kernel { pipeline, bgl }
1054}
1055
1056fn build_kernel(
1057    device: &wgpu::Device,
1058    label: &'static str,
1059    wgsl: &str,
1060    entry_point: &'static str,
1061) -> Kernel {
1062    let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
1063        label: Some(label),
1064        source: wgpu::ShaderSource::Wgsl(wgsl.into()),
1065    });
1066    let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
1067        label: Some(label),
1068        entries: &[
1069            wgpu::BindGroupLayoutEntry {
1070                binding: 0,
1071                visibility: wgpu::ShaderStages::COMPUTE,
1072                ty: wgpu::BindingType::Buffer {
1073                    ty: wgpu::BufferBindingType::Storage { read_only: false },
1074                    has_dynamic_offset: false,
1075                    min_binding_size: None,
1076                },
1077                count: None,
1078            },
1079            wgpu::BindGroupLayoutEntry {
1080                binding: 1,
1081                visibility: wgpu::ShaderStages::COMPUTE,
1082                ty: wgpu::BindingType::Buffer {
1083                    ty: wgpu::BufferBindingType::Uniform,
1084                    has_dynamic_offset: false,
1085                    min_binding_size: None,
1086                },
1087                count: None,
1088            },
1089        ],
1090    });
1091    let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
1092        label: Some(label),
1093        bind_group_layouts: &[Some(&bgl)],
1094        immediate_size: 0,
1095    });
1096    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
1097        label: Some(label),
1098        layout: Some(&layout),
1099        module: &module,
1100        entry_point: Some(entry_point),
1101        compilation_options: Default::default(),
1102        cache: None,
1103    });
1104    Kernel { pipeline, bgl }
1105}
1106
// Process-wide pipeline cache: one `OnceLock` slot per kernel entry point.
// Each slot is filled on first use by the matching `*_kernel` accessor below
// and then handed out as `&'static Kernel`.
// NOTE(review): the slots are not keyed by device, so whichever
// `wgpu::Device` reaches an accessor first is the one the cached pipeline is
// built on — confirm that only a single device is ever used per process.
static MATMUL: OnceLock<Kernel> = OnceLock::new();
static MATMUL_WIDE: OnceLock<Kernel> = OnceLock::new();
static MATMUL_F16W: OnceLock<Kernel> = OnceLock::new();
static MATMUL_F16_COMPUTE: OnceLock<Kernel> = OnceLock::new();
static MATMUL_COOP16: OnceLock<Kernel> = OnceLock::new();
static MATMUL_COOP_F32: OnceLock<Kernel> = OnceLock::new();
static CAST_F32_TO_F16: OnceLock<Kernel> = OnceLock::new();
static BINARY: OnceLock<Kernel> = OnceLock::new();
static UNARY: OnceLock<Kernel> = OnceLock::new();
static COMPARE: OnceLock<Kernel> = OnceLock::new();
static WHEREK: OnceLock<Kernel> = OnceLock::new();
static REDUCE: OnceLock<Kernel> = OnceLock::new();
static SOFTMAX: OnceLock<Kernel> = OnceLock::new();
static LAYERNORM: OnceLock<Kernel> = OnceLock::new();
static RMS_NORM_BWD: OnceLock<Kernel> = OnceLock::new();
static RMS_NORM_BWD_PARAM: OnceLock<Kernel> = OnceLock::new();
static CUMSUM_BWD: OnceLock<Kernel> = OnceLock::new();
static ROPE_BWD: OnceLock<Kernel> = OnceLock::new();
static GATHER_BWD_ZERO: OnceLock<Kernel> = OnceLock::new();
static GATHER_BWD_ACC: OnceLock<Kernel> = OnceLock::new();
static CUMSUM: OnceLock<Kernel> = OnceLock::new();
static FFT_GPU_RADIX2: OnceLock<Kernel> = OnceLock::new();
static FFT_GPU_BITREV: OnceLock<Kernel> = OnceLock::new();
static FFT_GPU_INNER: OnceLock<Kernel> = OnceLock::new();
static FFT_GPU_OUTER_R4: OnceLock<Kernel> = OnceLock::new();
static FFT_GPU_OUTER_R2: OnceLock<Kernel> = OnceLock::new();
static COPY: OnceLock<Kernel> = OnceLock::new();
static ELEMENTWISE_REGION: OnceLock<Kernel> = OnceLock::new();
static TRANSPOSE: OnceLock<Kernel> = OnceLock::new();
static NARROW: OnceLock<Kernel> = OnceLock::new();
static CONCAT: OnceLock<Kernel> = OnceLock::new();
static GATHER: OnceLock<Kernel> = OnceLock::new();
static GATHER_AXIS: OnceLock<Kernel> = OnceLock::new();
static ATTENTION: OnceLock<Kernel> = OnceLock::new();
static ATTENTION_BWD: OnceLock<Kernel> = OnceLock::new();
static ROPE: OnceLock<Kernel> = OnceLock::new();
static EXPAND: OnceLock<Kernel> = OnceLock::new();
static ARGMAX: OnceLock<Kernel> = OnceLock::new();
static POOL2D: OnceLock<Kernel> = OnceLock::new();
static CONV2D: OnceLock<Kernel> = OnceLock::new();
static POOL1D: OnceLock<Kernel> = OnceLock::new();
static POOL3D: OnceLock<Kernel> = OnceLock::new();
static CONV1D: OnceLock<Kernel> = OnceLock::new();
static CONV3D: OnceLock<Kernel> = OnceLock::new();
static SCATTER_ADD: OnceLock<Kernel> = OnceLock::new();
static TOPK: OnceLock<Kernel> = OnceLock::new();
static UMAP_KNN: OnceLock<Kernel> = OnceLock::new();
static GROUPED_MATMUL: OnceLock<Kernel> = OnceLock::new();
static SAMPLE: OnceLock<Kernel> = OnceLock::new();
static SELECTIVE_SCAN: OnceLock<Kernel> = OnceLock::new();
static DEQUANT_MATMUL: OnceLock<Kernel> = OnceLock::new();
static FUSED_RESIDUAL_LN: OnceLock<Kernel> = OnceLock::new();
static FUSED_RESIDUAL_LN_TEE: OnceLock<Kernel> = OnceLock::new();
static FUSED_RESIDUAL_RMS_NORM: OnceLock<Kernel> = OnceLock::new();
static MATMUL_QKV: OnceLock<Kernel> = OnceLock::new();
static MATMUL_QKV_COOP_F32: OnceLock<Kernel> = OnceLock::new();
1163
1164pub fn matmul_kernel(device: &wgpu::Device) -> &'static Kernel {
1165    MATMUL.get_or_init(|| build_kernel(device, "rlx-wgpu matmul", MATMUL_WGSL, "matmul"))
1166}
1167pub fn matmul_wide_kernel(device: &wgpu::Device) -> &'static Kernel {
1168    MATMUL_WIDE.get_or_init(|| {
1169        build_kernel(
1170            device,
1171            "rlx-wgpu matmul_wide",
1172            MATMUL_WIDE_WGSL,
1173            "matmul_wide",
1174        )
1175    })
1176}
1177/// f16-weight matmul (f32 compute). Returns Some only when the device
1178/// exposes the `SHADER_F16` feature. EXPERIMENTAL: currently slower
1179/// than the f32 baseline on Apple Silicon — kept as foundation; see
1180/// `matmul_f16w.wgsl` for the empirical analysis.
1181pub fn matmul_f16w_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
1182    if !device.features().contains(wgpu::Features::SHADER_F16) {
1183        return None;
1184    }
1185    Some(MATMUL_F16W.get_or_init(|| {
1186        build_kernel_3(
1187            device,
1188            "rlx-wgpu matmul_f16w",
1189            MATMUL_F16W_WGSL,
1190            "matmul_f16w",
1191        )
1192    }))
1193}
1194/// f16-compute matmul: f16 operands, f16 multiply, f32 accumulator.
1195/// Targets the 2× f16 ALU throughput on Apple Silicon. Returns Some
1196/// only when the device exposes `SHADER_F16`.
1197pub fn matmul_f16_compute_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
1198    if !device.features().contains(wgpu::Features::SHADER_F16) {
1199        return None;
1200    }
1201    Some(MATMUL_F16_COMPUTE.get_or_init(|| {
1202        build_kernel_3(
1203            device,
1204            "rlx-wgpu matmul_f16_compute",
1205            MATMUL_F16_COMPUTE_WGSL,
1206            "matmul_f16_compute",
1207        )
1208    }))
1209}
1210/// Cooperative-matrix matmul (8×8 tiles, hardware GEMM units).
1211/// Lowers to MSL `simdgroup_matrix` on Metal and SPIR-V's
1212/// `OpCooperativeMatrixMulAddKHR` on Vulkan. Returns Some only when
1213/// the device exposes both `SHADER_F16` and
1214/// `EXPERIMENTAL_COOPERATIVE_MATRIX`.
1215pub fn matmul_coop16_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
1216    let feats = device.features();
1217    if !feats.contains(wgpu::Features::SHADER_F16)
1218        || !feats.contains(wgpu::Features::EXPERIMENTAL_COOPERATIVE_MATRIX)
1219    {
1220        return None;
1221    }
1222    Some(MATMUL_COOP16.get_or_init(|| {
1223        build_kernel_3(
1224            device,
1225            "rlx-wgpu matmul_coop16",
1226            MATMUL_COOP16_WGSL,
1227            "matmul_coop16",
1228        )
1229    }))
1230}
1231/// Pure-f32 cooperative-matrix matmul. No SHADER_F16 needed — uses
1232/// `coop_mat8x8<f32>` throughout (lowers to `simdgroup_float8x8` on
1233/// Apple). Returns None if the cooperative-matrix feature is missing
1234/// OR if the device's WGSL→backend lowering can't compile it (some
1235/// implementations only expose half-precision coop matrices).
1236pub fn matmul_coop_f32_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
1237    let feats = device.features();
1238    if !feats.contains(wgpu::Features::EXPERIMENTAL_COOPERATIVE_MATRIX) {
1239        return None;
1240    }
1241    Some(MATMUL_COOP_F32.get_or_init(|| {
1242        build_kernel(
1243            device,
1244            "rlx-wgpu matmul_coop_f32",
1245            MATMUL_COOP_F32_WGSL,
1246            "matmul_coop_f32",
1247        )
1248    }))
1249}
1250/// Mirrors a region of the f32 arena into the f16 shadow buffer.
1251/// Used before `matmul_coop16` for the matmul's activation operand
1252/// (intermediate activations don't go through `set_param` /
1253/// `write_f32`, so they aren't in the f16 buffer otherwise).
1254pub fn cast_f32_to_f16_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
1255    if !device.features().contains(wgpu::Features::SHADER_F16) {
1256        return None;
1257    }
1258    Some(CAST_F32_TO_F16.get_or_init(|| {
1259        build_kernel_3(
1260            device,
1261            "rlx-wgpu cast_f32_to_f16",
1262            CAST_F32_TO_F16_WGSL,
1263            "cast_f32_to_f16",
1264        )
1265    }))
1266}
1267pub fn binary_kernel(device: &wgpu::Device) -> &'static Kernel {
1268    BINARY.get_or_init(|| build_kernel(device, "rlx-wgpu binary", BINARY_WGSL, "binary"))
1269}
1270pub fn unary_kernel(device: &wgpu::Device) -> &'static Kernel {
1271    UNARY.get_or_init(|| build_kernel(device, "rlx-wgpu unary", UNARY_WGSL, "unary"))
1272}
1273pub fn compare_kernel(device: &wgpu::Device) -> &'static Kernel {
1274    COMPARE.get_or_init(|| build_kernel(device, "rlx-wgpu compare", COMPARE_WGSL, "compare"))
1275}
1276pub fn where_kernel(device: &wgpu::Device) -> &'static Kernel {
1277    WHEREK.get_or_init(|| build_kernel(device, "rlx-wgpu where", WHERE_WGSL, "where_select"))
1278}
1279pub fn reduce_kernel(device: &wgpu::Device) -> &'static Kernel {
1280    REDUCE.get_or_init(|| build_kernel(device, "rlx-wgpu reduce", REDUCE_WGSL, "reduce"))
1281}
1282pub fn softmax_kernel(device: &wgpu::Device) -> &'static Kernel {
1283    SOFTMAX.get_or_init(|| build_kernel(device, "rlx-wgpu softmax", SOFTMAX_WGSL, "softmax"))
1284}
1285pub fn layernorm_kernel(device: &wgpu::Device) -> &'static Kernel {
1286    LAYERNORM.get_or_init(|| build_kernel(device, "rlx-wgpu layernorm", LAYERNORM_WGSL, "norm"))
1287}
1288pub fn rms_norm_backward_kernel(device: &wgpu::Device) -> &'static Kernel {
1289    RMS_NORM_BWD.get_or_init(|| {
1290        build_kernel(
1291            device,
1292            "rlx-wgpu rms_norm_bwd",
1293            RMS_NORM_BWD_WGSL,
1294            "rms_norm_bwd",
1295        )
1296    })
1297}
1298pub fn rms_norm_backward_param_kernel(device: &wgpu::Device) -> &'static Kernel {
1299    RMS_NORM_BWD_PARAM.get_or_init(|| {
1300        build_kernel(
1301            device,
1302            "rlx-wgpu rms_norm_bwd_param",
1303            RMS_NORM_BWD_WGSL,
1304            "rms_norm_bwd_param",
1305        )
1306    })
1307}
1308pub fn cumsum_backward_kernel(device: &wgpu::Device) -> &'static Kernel {
1309    CUMSUM_BWD
1310        .get_or_init(|| build_kernel(device, "rlx-wgpu cumsum_bwd", CUMSUM_BWD_WGSL, "cumsum_bwd"))
1311}
1312pub fn rope_backward_kernel(device: &wgpu::Device) -> &'static Kernel {
1313    ROPE_BWD.get_or_init(|| build_kernel(device, "rlx-wgpu rope_bwd", ROPE_BWD_WGSL, "rope_bwd"))
1314}
1315pub fn gather_backward_zero_kernel(device: &wgpu::Device) -> &'static Kernel {
1316    GATHER_BWD_ZERO.get_or_init(|| {
1317        build_kernel(
1318            device,
1319            "rlx-wgpu gather_bwd_zero",
1320            GATHER_BWD_WGSL,
1321            "gather_bwd_zero",
1322        )
1323    })
1324}
1325pub fn gather_backward_acc_kernel(device: &wgpu::Device) -> &'static Kernel {
1326    GATHER_BWD_ACC.get_or_init(|| {
1327        build_kernel(
1328            device,
1329            "rlx-wgpu gather_bwd_acc",
1330            GATHER_BWD_WGSL,
1331            "gather_bwd_acc",
1332        )
1333    })
1334}
1335pub fn cumsum_kernel(device: &wgpu::Device) -> &'static Kernel {
1336    CUMSUM.get_or_init(|| build_kernel(device, "rlx-wgpu cumsum", CUMSUM_WGSL, "cumsum"))
1337}
1338pub fn fft_gpu_radix2_full_kernel(device: &wgpu::Device) -> &'static Kernel {
1339    FFT_GPU_RADIX2.get_or_init(|| {
1340        build_kernel(
1341            device,
1342            "rlx-wgpu fft_radix2_full",
1343            FFT_GPU_WGSL,
1344            "fft_radix2_full",
1345        )
1346    })
1347}
1348pub fn fft_gpu_bit_reverse_kernel(device: &wgpu::Device) -> &'static Kernel {
1349    FFT_GPU_BITREV.get_or_init(|| {
1350        build_kernel(
1351            device,
1352            "rlx-wgpu fft_bit_reverse",
1353            FFT_GPU_WGSL,
1354            "fft_bit_reverse",
1355        )
1356    })
1357}
1358pub fn fft_gpu_inner_kernel(device: &wgpu::Device) -> &'static Kernel {
1359    FFT_GPU_INNER
1360        .get_or_init(|| build_kernel(device, "rlx-wgpu fft_inner", FFT_GPU_WGSL, "fft_inner"))
1361}
1362pub fn fft_gpu_outer_r4_kernel(device: &wgpu::Device) -> &'static Kernel {
1363    FFT_GPU_OUTER_R4.get_or_init(|| {
1364        build_kernel(
1365            device,
1366            "rlx-wgpu fft_outer_r4",
1367            FFT_GPU_WGSL,
1368            "fft_outer_r4",
1369        )
1370    })
1371}
1372pub fn fft_gpu_outer_r2_kernel(device: &wgpu::Device) -> &'static Kernel {
1373    FFT_GPU_OUTER_R2.get_or_init(|| {
1374        build_kernel(
1375            device,
1376            "rlx-wgpu fft_outer_r2",
1377            FFT_GPU_WGSL,
1378            "fft_outer_r2",
1379        )
1380    })
1381}
1382pub fn copy_kernel(device: &wgpu::Device) -> &'static Kernel {
1383    COPY.get_or_init(|| build_kernel(device, "rlx-wgpu copy", COPY_WGSL, "copy"))
1384}
1385pub fn elementwise_region_kernel(device: &wgpu::Device) -> &'static Kernel {
1386    // Region params bind as a STORAGE buffer (not uniform) — WGSL's
1387    // uniform-storage spec requires 16-byte stride for `array<T, N>`,
1388    // which our packed `array<u32, N>` chain layout doesn't satisfy.
1389    // Storage allows arbitrary stride.
1390    ELEMENTWISE_REGION.get_or_init(|| {
1391        build_kernel_region(
1392            device,
1393            "rlx-wgpu elementwise_region",
1394            ELEMENTWISE_REGION_WGSL,
1395            "elementwise_region",
1396        )
1397    })
1398}
1399
1400fn build_kernel_region(
1401    device: &wgpu::Device,
1402    label: &'static str,
1403    wgsl: &str,
1404    entry_point: &'static str,
1405) -> Kernel {
1406    let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
1407        label: Some(label),
1408        source: wgpu::ShaderSource::Wgsl(wgsl.into()),
1409    });
1410    let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
1411        label: Some(label),
1412        entries: &[
1413            wgpu::BindGroupLayoutEntry {
1414                binding: 0,
1415                visibility: wgpu::ShaderStages::COMPUTE,
1416                ty: wgpu::BindingType::Buffer {
1417                    ty: wgpu::BufferBindingType::Storage { read_only: false },
1418                    has_dynamic_offset: false,
1419                    min_binding_size: None,
1420                },
1421                count: None,
1422            },
1423            wgpu::BindGroupLayoutEntry {
1424                binding: 1,
1425                visibility: wgpu::ShaderStages::COMPUTE,
1426                ty: wgpu::BindingType::Buffer {
1427                    // Region params: read-only storage (vs uniform).
1428                    ty: wgpu::BufferBindingType::Storage { read_only: true },
1429                    has_dynamic_offset: false,
1430                    min_binding_size: None,
1431                },
1432                count: None,
1433            },
1434        ],
1435    });
1436    let pl = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
1437        label: Some(label),
1438        bind_group_layouts: &[Some(&bgl)],
1439        immediate_size: 0,
1440    });
1441    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
1442        label: Some(label),
1443        layout: Some(&pl),
1444        module: &module,
1445        entry_point: Some(entry_point),
1446        compilation_options: Default::default(),
1447        cache: None,
1448    });
1449    Kernel { pipeline, bgl }
1450}
1451pub fn transpose_kernel(device: &wgpu::Device) -> &'static Kernel {
1452    TRANSPOSE
1453        .get_or_init(|| build_kernel_3(device, "rlx-wgpu transpose", TRANSPOSE_WGSL, "transpose"))
1454}
1455pub fn narrow_kernel(device: &wgpu::Device) -> &'static Kernel {
1456    NARROW.get_or_init(|| build_kernel(device, "rlx-wgpu narrow", NARROW_WGSL, "narrow"))
1457}
1458pub fn concat_kernel(device: &wgpu::Device) -> &'static Kernel {
1459    CONCAT.get_or_init(|| build_kernel(device, "rlx-wgpu concat", CONCAT_WGSL, "concat"))
1460}
1461pub fn gather_kernel(device: &wgpu::Device) -> &'static Kernel {
1462    GATHER.get_or_init(|| build_kernel(device, "rlx-wgpu gather", GATHER_WGSL, "gather"))
1463}
1464pub fn gather_axis_kernel(device: &wgpu::Device) -> &'static Kernel {
1465    GATHER_AXIS.get_or_init(|| {
1466        build_kernel(
1467            device,
1468            "rlx-wgpu gather_axis",
1469            GATHER_AXIS_WGSL,
1470            "gather_axis",
1471        )
1472    })
1473}
1474pub fn attention_kernel(device: &wgpu::Device) -> &'static Kernel {
1475    ATTENTION
1476        .get_or_init(|| build_kernel(device, "rlx-wgpu attention", ATTENTION_WGSL, "attention"))
1477}
1478pub fn attention_bwd_kernel(device: &wgpu::Device) -> &'static Kernel {
1479    ATTENTION_BWD.get_or_init(|| {
1480        build_kernel(
1481            device,
1482            "rlx-wgpu attention_bwd",
1483            ATTENTION_BWD_WGSL,
1484            "attention_bwd",
1485        )
1486    })
1487}
1488pub fn rope_kernel(device: &wgpu::Device) -> &'static Kernel {
1489    ROPE.get_or_init(|| build_kernel(device, "rlx-wgpu rope", ROPE_WGSL, "rope"))
1490}
1491pub fn expand_kernel(device: &wgpu::Device) -> &'static Kernel {
1492    EXPAND.get_or_init(|| build_kernel_3(device, "rlx-wgpu expand", EXPAND_WGSL, "expand"))
1493}
1494pub fn argmax_kernel(device: &wgpu::Device) -> &'static Kernel {
1495    ARGMAX.get_or_init(|| build_kernel(device, "rlx-wgpu argmax", ARGMAX_WGSL, "argmax"))
1496}
1497pub fn pool2d_kernel(device: &wgpu::Device) -> &'static Kernel {
1498    POOL2D.get_or_init(|| build_kernel(device, "rlx-wgpu pool2d", POOL2D_WGSL, "pool2d"))
1499}
1500pub fn conv2d_kernel(device: &wgpu::Device) -> &'static Kernel {
1501    CONV2D.get_or_init(|| build_kernel(device, "rlx-wgpu conv2d", CONV2D_WGSL, "conv2d"))
1502}
1503pub fn pool1d_kernel(device: &wgpu::Device) -> &'static Kernel {
1504    POOL1D.get_or_init(|| build_kernel(device, "rlx-wgpu pool1d", POOL1D_WGSL, "pool1d"))
1505}
1506pub fn pool3d_kernel(device: &wgpu::Device) -> &'static Kernel {
1507    POOL3D.get_or_init(|| build_kernel(device, "rlx-wgpu pool3d", POOL3D_WGSL, "pool3d"))
1508}
1509pub fn conv1d_kernel(device: &wgpu::Device) -> &'static Kernel {
1510    CONV1D.get_or_init(|| build_kernel(device, "rlx-wgpu conv1d", CONV1D_WGSL, "conv1d"))
1511}
1512pub fn conv3d_kernel(device: &wgpu::Device) -> &'static Kernel {
1513    CONV3D.get_or_init(|| build_kernel(device, "rlx-wgpu conv3d", CONV3D_WGSL, "conv3d"))
1514}
1515pub fn scatter_add_kernel(device: &wgpu::Device) -> &'static Kernel {
1516    SCATTER_ADD.get_or_init(|| {
1517        build_kernel(
1518            device,
1519            "rlx-wgpu scatter_add",
1520            SCATTER_ADD_WGSL,
1521            "scatter_add",
1522        )
1523    })
1524}
1525pub fn topk_kernel(device: &wgpu::Device) -> &'static Kernel {
1526    TOPK.get_or_init(|| build_kernel(device, "rlx-wgpu topk", TOPK_WGSL, "topk"))
1527}
1528pub fn umap_knn_kernel(device: &wgpu::Device) -> &'static Kernel {
1529    UMAP_KNN.get_or_init(|| build_kernel(device, "rlx-wgpu umap_knn", UMAP_KNN_WGSL, "umap_knn"))
1530}
1531pub fn grouped_matmul_kernel(device: &wgpu::Device) -> &'static Kernel {
1532    GROUPED_MATMUL.get_or_init(|| {
1533        build_kernel(
1534            device,
1535            "rlx-wgpu grouped_matmul",
1536            GROUPED_MATMUL_WGSL,
1537            "grouped_matmul",
1538        )
1539    })
1540}
1541pub fn sample_kernel(device: &wgpu::Device) -> &'static Kernel {
1542    SAMPLE.get_or_init(|| build_kernel(device, "rlx-wgpu sample", SAMPLE_WGSL, "sample"))
1543}
1544pub fn selective_scan_kernel(device: &wgpu::Device) -> &'static Kernel {
1545    SELECTIVE_SCAN.get_or_init(|| {
1546        build_kernel(
1547            device,
1548            "rlx-wgpu selective_scan",
1549            SELECTIVE_SCAN_WGSL,
1550            "selective_scan",
1551        )
1552    })
1553}
1554pub fn dequant_matmul_kernel(device: &wgpu::Device) -> &'static Kernel {
1555    DEQUANT_MATMUL.get_or_init(|| {
1556        build_kernel(
1557            device,
1558            "rlx-wgpu dequant_matmul",
1559            DEQUANT_MATMUL_WGSL,
1560            "dequant_matmul",
1561        )
1562    })
1563}
1564pub fn fused_residual_ln_kernel(device: &wgpu::Device) -> &'static Kernel {
1565    FUSED_RESIDUAL_LN.get_or_init(|| {
1566        build_kernel(
1567            device,
1568            "rlx-wgpu fused_residual_ln",
1569            FUSED_RESIDUAL_LN_WGSL,
1570            "fused_residual_ln",
1571        )
1572    })
1573}
1574pub fn fused_residual_ln_tee_kernel(device: &wgpu::Device) -> &'static Kernel {
1575    FUSED_RESIDUAL_LN_TEE.get_or_init(|| {
1576        build_kernel(
1577            device,
1578            "rlx-wgpu fused_residual_ln_tee",
1579            FUSED_RESIDUAL_LN_TEE_WGSL,
1580            "fused_residual_ln_tee",
1581        )
1582    })
1583}
1584pub fn fused_residual_rms_norm_kernel(device: &wgpu::Device) -> &'static Kernel {
1585    FUSED_RESIDUAL_RMS_NORM.get_or_init(|| {
1586        build_kernel(
1587            device,
1588            "rlx-wgpu fused_residual_rms_norm",
1589            FUSED_RESIDUAL_RMS_NORM_WGSL,
1590            "fused_residual_rms_norm",
1591        )
1592    })
1593}
1594pub fn matmul_qkv_kernel(device: &wgpu::Device) -> &'static Kernel {
1595    MATMUL_QKV
1596        .get_or_init(|| build_kernel(device, "rlx-wgpu matmul_qkv", MATMUL_QKV_WGSL, "matmul_qkv"))
1597}
1598pub fn matmul_qkv_coop_f32_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
1599    if !device
1600        .features()
1601        .contains(wgpu::Features::EXPERIMENTAL_COOPERATIVE_MATRIX)
1602    {
1603        return None;
1604    }
1605    Some(MATMUL_QKV_COOP_F32.get_or_init(|| {
1606        build_kernel(
1607            device,
1608            "rlx-wgpu matmul_qkv_coop_f32",
1609            MATMUL_QKV_COOP_F32_WGSL,
1610            "matmul_qkv_coop_f32",
1611        )
1612    }))
1613}