1use std::sync::OnceLock;
24
25use bytemuck::{Pod, Zeroable};
26
// WGSL shader sources, embedded at compile time with `include_str!`. Paths are
// resolved relative to this source file, so a missing `.wgsl` file is a build error
// rather than a runtime failure. The grouping comments below follow the file names only.

// Matmul family, including the cooperative-matrix ("coop") and QKV variants.
pub const MATMUL_WGSL: &str = include_str!("matmul.wgsl");
pub const MATMUL_WIDE_WGSL: &str = include_str!("matmul_wide.wgsl");
pub const MATMUL_WIDE_NV_WGSL: &str = include_str!("matmul_wide_nv.wgsl");
pub const MATMUL_F16W_WGSL: &str = include_str!("matmul_f16w.wgsl");
pub const MATMUL_F16_COMPUTE_WGSL: &str = include_str!("matmul_f16_compute.wgsl");
pub const MATMUL_COOP16_WGSL: &str = include_str!("matmul_coop16.wgsl");
pub const MATMUL_COOP_F32_WGSL: &str = include_str!("matmul_coop_f32.wgsl");
pub const MATMUL_COOP_F32_PORTABLE_WGSL: &str = include_str!("matmul_coop_f32_portable.wgsl");
pub const MATMUL_COOP_F16_VULKAN_WGSL: &str = include_str!("matmul_coop_f16_vulkan.wgsl");
pub const MATMUL_COOP_F16_VULKAN_WIDEN_WGSL: &str =
    include_str!("matmul_coop_f16_vulkan_widen.wgsl");
pub const MATMUL_COOP_F16_VULKAN_F32ACC_WGSL: &str =
    include_str!("matmul_coop_f16_vulkan_f32acc.wgsl");
pub const MATMUL_COOP_F16_VULKAN_WIDEN_F32ACC_WGSL: &str =
    include_str!("matmul_coop_f16_vulkan_widen_f32acc.wgsl");
pub const MATMUL_QKV_COOP_F16_VK_WGSL: &str = include_str!("matmul_qkv_coop_f16_vk.wgsl");
pub const MATMUL_QKV_COOP_F16_VK_WIDEN_WGSL: &str =
    include_str!("matmul_qkv_coop_f16_vk_widen.wgsl");
pub const MATMUL_QKV_COOP_F16_VK_F32ACC_WGSL: &str =
    include_str!("matmul_qkv_coop_f16_vk_f32acc.wgsl");
pub const MATMUL_QKV_COOP_F16_VK_WIDEN_F32ACC_WGSL: &str =
    include_str!("matmul_qkv_coop_f16_vk_widen_f32acc.wgsl");
// Precision conversion.
pub const CAST_F32_TO_F16_WGSL: &str = include_str!("cast_f32_to_f16.wgsl");
// Element-wise, comparison, select and reduction kernels.
pub const BINARY_WGSL: &str = include_str!("binary.wgsl");
pub const UNARY_WGSL: &str = include_str!("unary.wgsl");
pub const UNARY_F16_MIRROR_WGSL: &str = include_str!("unary_f16_mirror.wgsl");
pub const COMPARE_WGSL: &str = include_str!("compare.wgsl");
pub const WHERE_WGSL: &str = include_str!("where.wgsl");
pub const REDUCE_WGSL: &str = include_str!("reduce.wgsl");
// Softmax, normalization, and backward-pass kernels.
pub const SOFTMAX_WGSL: &str = include_str!("softmax.wgsl");
pub const LAYERNORM_WGSL: &str = include_str!("layernorm.wgsl");
pub const RMS_NORM_BWD_WGSL: &str = include_str!("rms_norm_backward.wgsl");
pub const LAYER_NORM_BWD_WGSL: &str = include_str!("layer_norm_backward.wgsl");
pub const CUMSUM_BWD_WGSL: &str = include_str!("cumsum_backward.wgsl");
pub const ROPE_BWD_WGSL: &str = include_str!("rope_backward.wgsl");
pub const GATHER_BWD_WGSL: &str = include_str!("gather_backward.wgsl");
// Scan, FFT, copy and layout-manipulation kernels.
pub const CUMSUM_WGSL: &str = include_str!("cumsum.wgsl");
pub const FFT_GPU_WGSL: &str = include_str!("fft_gpu.wgsl");
pub const COPY_WGSL: &str = include_str!("copy.wgsl");
pub const ELEMENTWISE_REGION_WGSL: &str = include_str!("elementwise_region.wgsl");
pub const TRANSPOSE_WGSL: &str = include_str!("transpose.wgsl");
pub const NARROW_WGSL: &str = include_str!("narrow.wgsl");
pub const CONCAT_WGSL: &str = include_str!("concat.wgsl");
pub const GATHER_WGSL: &str = include_str!("gather.wgsl");
pub const GATHER_AXIS_WGSL: &str = include_str!("gather_axis.wgsl");
// Attention and rotary embedding.
pub const ATTENTION_WGSL: &str = include_str!("attention.wgsl");
pub const ATTENTION_BWD_WGSL: &str = include_str!("attention_bwd.wgsl");
pub const ROPE_WGSL: &str = include_str!("rope.wgsl");
// Broadcast, argmax, pooling/convolution, scatter and top-k.
pub const EXPAND_WGSL: &str = include_str!("expand.wgsl");
pub const ARGMAX_WGSL: &str = include_str!("argmax.wgsl");
pub const POOL2D_WGSL: &str = include_str!("pool2d.wgsl");
pub const CONV2D_WGSL: &str = include_str!("conv2d.wgsl");
pub const POOL1D_WGSL: &str = include_str!("pool1d.wgsl");
pub const POOL3D_WGSL: &str = include_str!("pool3d.wgsl");
pub const CONV1D_WGSL: &str = include_str!("conv1d.wgsl");
pub const CONV3D_WGSL: &str = include_str!("conv3d.wgsl");
pub const SCATTER_ADD_WGSL: &str = include_str!("scatter_add.wgsl");
pub const TOPK_WGSL: &str = include_str!("topk.wgsl");
// Signal-processing and UMAP helpers.
pub const WELCH_PEAKS_GPU_WGSL: &str = include_str!("welch_peaks_gpu.wgsl");
pub const UMAP_KNN_WGSL: &str = include_str!("umap_knn.wgsl");
// Grouped matmul, sampling, selective scan, dequantized matmul and fused kernels.
pub const GROUPED_MATMUL_WGSL: &str = include_str!("grouped_matmul.wgsl");
pub const SAMPLE_WGSL: &str = include_str!("sample.wgsl");
pub const SELECTIVE_SCAN_WGSL: &str = include_str!("selective_scan.wgsl");
pub const DEQUANT_MATMUL_WGSL: &str = include_str!("dequant_matmul.wgsl");
pub const FUSED_RESIDUAL_LN_WGSL: &str = include_str!("fused_residual_ln.wgsl");
pub const FUSED_RESIDUAL_LN_TEE_WGSL: &str = include_str!("fused_residual_ln_tee.wgsl");
pub const FUSED_RESIDUAL_RMS_NORM_WGSL: &str = include_str!("fused_residual_rms_norm.wgsl");
pub const MATMUL_QKV_WGSL: &str = include_str!("matmul_qkv.wgsl");
pub const MATMUL_QKV_COOP_F32_WGSL: &str = include_str!("matmul_qkv_coop_f32.wgsl");
96
/// Parameter block for the matmul shaders (association inferred from the name).
///
/// `#[repr(C)]` + `Pod`: the struct is byte-cast for upload, so field order is
/// layout-significant and presumably mirrors the WGSL-side struct — do not reorder.
/// 16 × `u32` = 64 bytes.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct MatmulParams {
    // Problem sizes (presumably A is m×k, B is k×n, C is m×n — TODO confirm).
    pub m: u32,
    pub k: u32,
    pub n: u32,
    // Offsets of A, B, C into the bound storage buffer (units not visible here).
    pub a_off: u32,
    pub b_off: u32,
    pub c_off: u32,
    // Presumably the batch count and per-batch stride of each operand.
    pub batch: u32,
    pub a_batch_stride: u32,
    pub b_batch_stride: u32,
    pub c_batch_stride: u32,
    // Presumably a 0/1 flag enabling the bias read at `bias_off`.
    pub has_bias: u32,
    pub bias_off: u32,
    // Activation selector; the id -> activation mapping is defined outside this view.
    pub act_id: u32,
    // Presumably unused padding.
    pub _pad0: u32,
    pub _pad1: u32,
    pub _pad2: u32,
}
117
/// Parameter block for the binary element-wise shader (association inferred from the name).
///
/// `#[repr(C)]` + `Pod`: byte-cast for upload, so field order is layout-significant and
/// presumably mirrors the WGSL struct. 8 × `u32` = 32 bytes.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct BinaryParams {
    // Presumably the element count.
    pub n: u32,
    // Offsets of the two inputs and the output (units not visible here).
    pub a_off: u32,
    pub b_off: u32,
    pub c_off: u32,
    // Operation selector; the value -> op mapping is defined outside this view.
    pub op: u32,
    // Presumably unused padding.
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
}
131
/// Parameter block for the unary element-wise shader (association inferred from the name).
///
/// `#[repr(C)]` + `Pod`: byte-cast for upload, so field order is layout-significant and
/// presumably mirrors the WGSL struct. 8 × `u32` = 32 bytes.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct UnaryParams {
    // Presumably the element count.
    pub n: u32,
    pub in_off: u32,
    pub out_off: u32,
    // Operation selector; the value -> op mapping is defined outside this view.
    pub op: u32,
    // Presumably unused padding.
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
    pub _p3: u32,
}
145
/// Parameter block for the `where` (select) shader (association inferred from the name).
///
/// `#[repr(C)]` + `Pod`: byte-cast for upload, so field order is layout-significant and
/// presumably mirrors the WGSL struct. 8 × `u32` = 32 bytes.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct WhereParams {
    // Presumably the element count.
    pub n: u32,
    // Offsets of the condition, the two value inputs, and the output.
    pub cond_off: u32,
    pub x_off: u32,
    pub y_off: u32,
    pub out_off: u32,
    // Presumably unused padding.
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
}
159
160#[repr(C)]
171pub struct ReduceParams {
172 pub outer: u32,
173 pub reduce_dim: u32,
174 pub inner: u32,
175 pub in_off: u32,
176 pub out_off: u32,
177 pub op: u32,
178 pub _p0: u32,
179 pub _p1: u32,
180}
181
182unsafe impl Pod for ReduceParams {}
185unsafe impl Zeroable for ReduceParams {}
186impl Copy for ReduceParams {}
187impl Clone for ReduceParams {
188 fn clone(&self) -> Self {
189 *self
190 }
191}
192impl std::fmt::Debug for ReduceParams {
193 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194 write!(
195 f,
196 "ReduceParams {{ outer: {}, reduce_dim: {}, inner: {}, op: {} }}",
197 self.outer, self.reduce_dim, self.inner, self.op
198 )
199 }
200}
201
/// Parameter block for the softmax shader (association inferred from the name).
///
/// `#[repr(C)]` + `Pod`: byte-cast for upload, so field order is layout-significant and
/// presumably mirrors the WGSL struct. 8 × `u32` = 32 bytes.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct SoftmaxParams {
    // Presumably `outer` rows of `inner` elements each — TODO confirm.
    pub outer: u32,
    pub inner: u32,
    pub in_off: u32,
    pub out_off: u32,
    // Presumably unused padding.
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
    pub _p3: u32,
}
215
/// Parameter block for the layer-norm shader (association inferred from the name).
///
/// `#[repr(C)]` + `Pod`: byte-cast for upload, so field order is layout-significant and
/// presumably mirrors the WGSL struct. 8 × `u32` = 32 bytes.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct LayerNormParams {
    pub outer: u32,
    pub inner: u32,
    pub in_off: u32,
    pub out_off: u32,
    pub gamma_off: u32,
    pub beta_off: u32,
    // Epsilon stored as raw bits in a u32 (presumably `f32::to_bits`).
    pub eps_bits: u32,
    // Mode selector; the value -> variant mapping is defined outside this view.
    pub op: u32,
}
229
/// Parameter block for the layer-norm backward shader (association inferred from the name).
///
/// `#[repr(C)]` + `Pod`: byte-cast for upload, so field order is layout-significant and
/// presumably mirrors the WGSL struct. 8 × `u32` = 32 bytes.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct LayerNormBwdParams {
    pub outer: u32,
    pub inner: u32,
    // Offsets of the forward input, gamma, and upstream gradient.
    pub x_off: u32,
    pub gamma_off: u32,
    pub dy_off: u32,
    pub out_off: u32,
    // Epsilon stored as raw bits in a u32 (presumably `f32::to_bits`).
    pub eps_bits: u32,
    // Offset of a scratch region; its size/use is defined by the shader.
    pub scratch_off: u32,
}
253
/// Parameter block for the RMS-norm backward shader (association inferred from the name).
///
/// `#[repr(C)]` + `Pod`: byte-cast for upload, so field order is layout-significant and
/// presumably mirrors the WGSL struct. 9 × `u32` = 36 bytes.
///
/// NOTE(review): this struct carries `beta_off` and a `wrt` selector, while
/// `LayerNormBwdParams` carries neither. Confirm the two layouts are not swapped
/// relative to their shaders.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct RmsNormBwdParams {
    pub outer: u32,
    pub inner: u32,
    pub x_off: u32,
    pub gamma_off: u32,
    pub beta_off: u32,
    pub dy_off: u32,
    pub out_off: u32,
    // Epsilon stored as raw bits in a u32 (presumably `f32::to_bits`).
    pub eps_bits: u32,
    // Presumably selects which gradient ("with respect to") the dispatch produces.
    pub wrt: u32,
}
268
/// Parameter block for the cumulative-sum backward shader (association inferred from the name).
///
/// `#[repr(C)]` + `Pod`: byte-cast for upload, so field order is layout-significant and
/// presumably mirrors the WGSL struct. 8 × `u32` = 32 bytes.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct CumsumBwdParams {
    pub outer: u32,
    pub inner: u32,
    pub dy_off: u32,
    pub dx_off: u32,
    // Presumably a 0/1 flag: exclusive vs. inclusive scan.
    pub exclusive: u32,
    // Presumably unused padding.
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
}
281
/// Parameter block for the rotary-embedding backward shader (association inferred from the name).
///
/// `#[repr(C)]` + `Pod`: byte-cast for upload, so field order is layout-significant and
/// presumably mirrors the WGSL struct. 10 × `u32` = 40 bytes.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct RopeBwdParams {
    pub batch: u32,
    pub seq: u32,
    pub hidden: u32,
    pub head_dim: u32,
    // Presumably the number of rotated dimensions per head.
    pub n_rot: u32,
    pub dy_off: u32,
    // Offsets of the cos/sin tables.
    pub cos_off: u32,
    pub sin_off: u32,
    pub dx_off: u32,
    // Presumably the length of the cos/sin tables.
    pub cos_len: u32,
}
296
/// Parameter block for the gather backward shader (association inferred from the name).
///
/// `#[repr(C)]` + `Pod`: byte-cast for upload, so field order is layout-significant and
/// presumably mirrors the WGSL struct. 8 × `u32` = 32 bytes.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct GatherBwdParams {
    // Presumably the shape around the gathered axis: [outer, axis_dim, trailing].
    pub outer: u32,
    pub axis_dim: u32,
    pub num_idx: u32,
    pub trailing: u32,
    pub dy_off: u32,
    pub idx_off: u32,
    pub dst_off: u32,
    // Presumably unused padding.
    pub _p0: u32,
}
309
/// Parameter block for the cumulative-sum shader (association inferred from the name).
///
/// `#[repr(C)]` + `Pod`: byte-cast for upload, so field order is layout-significant and
/// presumably mirrors the WGSL struct. 8 × `u32` = 32 bytes.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct CumsumParams {
    pub outer: u32,
    pub inner: u32,
    pub in_off: u32,
    pub out_off: u32,
    // Presumably a 0/1 flag: exclusive vs. inclusive scan.
    pub exclusive: u32,
    // Presumably unused padding.
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
}
323
/// Parameter block for an FFT shader (association inferred from the name).
///
/// `#[repr(C)]` + `Pod`: byte-cast for upload, so field order is layout-significant and
/// presumably mirrors the WGSL struct. 8 four-byte fields = 32 bytes. `norm_scale` is
/// stored as a real `f32`, unlike the `*_bits` convention used by other param structs.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct FftParams {
    pub src_off: u32,
    pub dst_off: u32,
    // Transform length and (presumably) its base-2 logarithm.
    pub n: u32,
    pub log2n: u32,
    // Presumably a 0/1 flag selecting the inverse transform.
    pub inverse: u32,
    pub norm_scale: f32,
    // Presumably unused padding (there is no `_p0`; `norm_scale` appears to occupy that slot).
    pub _p1: u32,
    pub _p2: u32,
}
337
/// Parameter block for the GPU FFT shader (association inferred from the name).
///
/// `#[repr(C)]` + `Pod`: byte-cast for upload, so field order is layout-significant and
/// presumably mirrors the WGSL struct. 10 four-byte fields = 40 bytes.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct FftGpuParams {
    // Source offset (named `off`, not `src_off`).
    pub off: u32,
    pub dst_off: u32,
    pub n: u32,
    pub log2n: u32,
    // Presumably a 0/1 flag selecting the inverse transform.
    pub inverse: u32,
    pub norm_scale: f32,
    pub outer: u32,
    // Tiling / staging controls; their meaning is defined by the shader.
    pub tile: u32,
    pub inner_stages: u32,
    // Multi-purpose slot whose meaning depends on the dispatch (name suggests "q or hs").
    pub q_or_hs: u32,
}
353
/// Parameter block for the fused element-wise region shader (association inferred from the name).
///
/// `#[repr(C)]` + `Pod`: byte-cast for upload, so field order is layout-significant and
/// presumably mirrors the WGSL struct. 171 × `u32` = 684 bytes.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct ElementwiseRegionParams {
    pub len: u32,
    // Presumably how many entries of `input_offs` / `chain` are in use.
    pub num_inputs: u32,
    pub num_steps: u32,
    pub dst_off: u32,
    // Room for up to 16 input offsets.
    pub input_offs: [u32; 16],
    // Encoded op chain (128 words); the encoding is defined outside this view.
    pub chain: [u32; 128],
    // Presumably a bitmask marking inputs that are broadcast scalars.
    pub scalar_input_mask: u32,
    // Optional prologue selector and the output shape it uses.
    pub prologue: u32,
    pub out_n: u32,
    pub out_c: u32,
    pub out_h: u32,
    pub out_w: u32,
    pub prologue_input: u32,
    // Per-input modulus (16 slots, one per `input_offs` entry).
    pub input_modulus: [u32; 16],
}
382
/// Parameter block for a batched fused element-wise region dispatch (association inferred from the name).
///
/// `#[repr(C)]` + `Pod`: byte-cast for upload, so field order is layout-significant and
/// presumably mirrors the WGSL struct. 214 × `u32` = 856 bytes.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct BatchElementwiseRegionParams {
    pub slice_len: u32,
    pub num_batch: u32,
    pub num_steps: u32,
    pub base_dst_off: u32,
    pub slice_elems: u32,
    // Room for 64 input offsets across the batch.
    pub batch_input_offs: [u32; 64],
    // Encoded op chain (128 words); the encoding is defined outside this view.
    pub chain: [u32; 128],
    // Presumably a bitmask marking inputs that are broadcast scalars.
    pub scalar_input_mask: u32,
    pub input_modulus: [u32; 16],
}
397
/// Parameter block for the copy shader (association inferred from the name).
///
/// `#[repr(C)]` + `Pod`: byte-cast for upload, so field order is layout-significant and
/// presumably mirrors the WGSL struct. 8 × `u32` = 32 bytes.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct CopyParams {
    // Presumably the element count.
    pub n: u32,
    pub in_off: u32,
    pub out_off: u32,
    // Presumably unused padding.
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
    pub _p3: u32,
    pub _p4: u32,
}
411
/// Parameter block for the transpose shader (association inferred from the name).
///
/// `#[repr(C)]` + `Pod`: byte-cast for upload, so field order is layout-significant and
/// presumably mirrors the WGSL struct. 8 × `u32` = 32 bytes. Has the same layout as
/// `ExpandParams`.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct TransposeParams {
    pub rank: u32,
    pub out_total: u32,
    pub in_off: u32,
    pub out_off: u32,
    // Presumably these occupy the former `_p0`/`_p1` slots (padding starts at `_p2`).
    pub bucket_outermost: u32,
    pub out_dim_0: u32,
    // Presumably unused padding.
    pub _p2: u32,
    pub _p3: u32,
}
429
/// Parameter block presumably shared by the narrow and concat shaders (inferred from the name).
///
/// `#[repr(C)]` + `Pod`: byte-cast for upload, so field order is layout-significant and
/// presumably mirrors the WGSL struct. 8 × `u32` = 32 bytes.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct NarrowConcatParams {
    pub total: u32,
    pub outer: u32,
    pub inner: u32,
    // Extent of the sliced axis on the input and output sides.
    pub axis_in_size: u32,
    pub axis_out_size: u32,
    // Presumably the position along the axis where the slice begins.
    pub start: u32,
    pub in_off: u32,
    pub out_off: u32,
}
443
/// Parameter block for the gather (table-lookup) shader (association inferred from the name).
///
/// `#[repr(C)]` + `Pod`: byte-cast for upload, so field order is layout-significant and
/// presumably mirrors the WGSL struct. 8 × `u32` = 32 bytes.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct GatherParams {
    pub n_out: u32,
    pub n_idx: u32,
    // Presumably the row width and the number of table rows.
    pub dim: u32,
    pub vocab: u32,
    pub in_off: u32,
    pub idx_off: u32,
    pub out_off: u32,
    // Presumably unused padding.
    pub _p0: u32,
}
457
/// Parameter block for the axis-wise gather shader (association inferred from the name).
///
/// `#[repr(C)]` + `Pod`: byte-cast for upload, so field order is layout-significant and
/// presumably mirrors the WGSL struct. 8 × `u32` = 32 bytes.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct GatherAxisParams {
    pub total: u32,
    // Presumably the shape around the gathered axis: [outer, axis_dim, trailing].
    pub outer: u32,
    pub axis_dim: u32,
    pub num_idx: u32,
    pub trailing: u32,
    pub table_off: u32,
    pub idx_off: u32,
    pub out_off: u32,
}
471
/// Parameter block for the attention shader (association inferred from the name).
///
/// `#[repr(C)]` + `Pod`: byte-cast for upload, so field order is layout-significant and
/// presumably mirrors the WGSL struct. 36 × `u32` = 144 bytes. The four trailing
/// stride groups (q/k/v/o) are each 3 strides + 1 pad and start at byte offsets
/// 80, 96, 112 and 128, i.e. on 16-byte boundaries.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct AttentionParams {
    pub batch: u32,
    pub heads: u32,
    pub seq_q: u32,
    pub seq_k: u32,
    pub head_dim: u32,
    pub q_off: u32,
    pub k_off: u32,
    pub v_off: u32,
    pub out_off: u32,
    // Mask buffer offset and a selector for the mask variant (mapping defined elsewhere).
    pub mask_off: u32,
    pub mask_kind: u32,
    // Softmax scale stored as raw bits in a u32 (presumably `f32::to_bits`).
    pub scale_bits: u32,
    // Presumably a window size for windowed masks.
    pub window: u32,
    pub seq_q_stride: u32,
    pub seq_k_stride: u32,
    pub mask_batch_stride: u32,
    pub mask_head_stride: u32,
    // Presumably padding so the q-stride group below starts at byte 80.
    pub _pad_mask_0: u32,
    pub _pad_mask_1: u32,
    pub _pad_mask_2: u32,

    // Q strides (+ pad).
    pub q_batch_stride: u32,
    pub q_head_stride: u32,
    pub q_seq_stride: u32,
    pub _pad_q: u32,

    // K strides (+ pad).
    pub k_batch_stride: u32,
    pub k_head_stride: u32,
    pub k_seq_stride: u32,
    pub _pad_k: u32,

    // V strides (+ pad).
    pub v_batch_stride: u32,
    pub v_head_stride: u32,
    pub v_seq_stride: u32,
    pub _pad_v: u32,

    // Output strides (+ pad).
    pub o_batch_stride: u32,
    pub o_head_stride: u32,
    pub o_seq_stride: u32,
    pub _pad_o: u32,
}
544
/// Parameter block for the attention backward shader (association inferred from the name).
///
/// `#[repr(C)]` + `Pod`: byte-cast for upload, so field order is layout-significant and
/// presumably mirrors the WGSL struct. 38 × `u32` = 152 bytes.
///
/// NOTE(review): `dy_off` and `wrt` add two words compared with `AttentionParams`, so
/// the q/k/v/o stride groups here start at bytes 88/104/120/136, which are not 16-byte
/// aligned. That is fine if the WGSL struct uses scalar `u32` fields, but it is a layout
/// mismatch if it packs these groups as `vec4<u32>` — confirm against the shader.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct AttentionBwdParams {
    pub batch: u32,
    pub heads: u32,
    pub seq_q: u32,
    pub seq_k: u32,
    pub head_dim: u32,
    pub q_off: u32,
    pub k_off: u32,
    pub v_off: u32,
    // Upstream gradient offset.
    pub dy_off: u32,
    pub out_off: u32,
    pub mask_off: u32,
    pub mask_kind: u32,
    // Softmax scale stored as raw bits in a u32 (presumably `f32::to_bits`).
    pub scale_bits: u32,
    pub window: u32,
    // Presumably selects which gradient (dq/dk/dv) the dispatch produces.
    pub wrt: u32,
    pub seq_q_stride: u32,
    pub seq_k_stride: u32,
    pub mask_batch_stride: u32,
    pub mask_head_stride: u32,
    // Presumably unused padding.
    pub _pad_mask_0: u32,
    pub _pad_mask_1: u32,
    pub _pad_mask_2: u32,
    // Q strides (+ pad).
    pub q_batch_stride: u32,
    pub q_head_stride: u32,
    pub q_seq_stride: u32,
    pub _pad_q: u32,
    // K strides (+ pad).
    pub k_batch_stride: u32,
    pub k_head_stride: u32,
    pub k_seq_stride: u32,
    pub _pad_k: u32,
    // V strides (+ pad).
    pub v_batch_stride: u32,
    pub v_head_stride: u32,
    pub v_seq_stride: u32,
    pub _pad_v: u32,
    // Output strides (+ pad).
    pub o_batch_stride: u32,
    pub o_head_stride: u32,
    pub o_seq_stride: u32,
    pub _pad_o: u32,
}
588
/// Parameter block for the rotary-embedding shader (association inferred from the name).
///
/// `#[repr(C)]` + `Pod`: byte-cast for upload, so field order is layout-significant and
/// presumably mirrors the WGSL struct. 12 × `u32` = 48 bytes.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct RopeParams {
    pub n_total: u32,
    pub seq: u32,
    pub head_dim: u32,
    // Presumably head_dim / 2 (the rotation pairs count) — TODO confirm.
    pub half: u32,
    pub in_off: u32,
    pub cos_off: u32,
    pub sin_off: u32,
    pub out_off: u32,
    pub last_dim: u32,
    pub batch: u32,
    pub seq_stride: u32,
    // Presumably unused padding.
    pub _p2: u32,
}
611
/// Parameter block for the expand/broadcast shader (association inferred from the name).
///
/// `#[repr(C)]` + `Pod`: byte-cast for upload, so field order is layout-significant and
/// presumably mirrors the WGSL struct. 8 × `u32` = 32 bytes. Has the same layout as
/// `TransposeParams`.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct ExpandParams {
    pub rank: u32,
    pub out_total: u32,
    pub in_off: u32,
    pub out_off: u32,
    // Presumably these occupy the former `_p0`/`_p1` slots (padding starts at `_p2`).
    pub bucket_outermost: u32,
    pub out_dim_0: u32,
    // Presumably unused padding.
    pub _p2: u32,
    pub _p3: u32,
}
629
/// Parameter block for the argmax shader (association inferred from the name).
///
/// `#[repr(C)]` + `Pod`: byte-cast for upload, so field order is layout-significant and
/// presumably mirrors the WGSL struct. 8 × `u32` = 32 bytes.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct ArgmaxParams {
    // Presumably `outer` rows of `inner` elements each — TODO confirm.
    pub outer: u32,
    pub inner: u32,
    pub in_off: u32,
    pub out_off: u32,
    // Presumably unused padding.
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
    pub _p3: u32,
}
643
/// Parameter block for the 2-D pooling shader (association inferred from the name).
///
/// `#[repr(C)]` + `Pod`: byte-cast for upload, so field order is layout-significant and
/// presumably mirrors the WGSL struct. 18 × `u32` = 72 bytes.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct Pool2dParams {
    // Input shape (presumably NCHW) and output spatial size.
    pub n: u32,
    pub c: u32,
    pub h: u32,
    pub w: u32,
    pub h_out: u32,
    pub w_out: u32,
    // Kernel size, stride and padding per spatial axis.
    pub kh: u32,
    pub kw: u32,
    pub sh: u32,
    pub sw: u32,
    pub ph: u32,
    pub pw: u32,
    // Pooling-mode selector; mapping defined outside this view.
    pub op: u32,
    pub in_off: u32,
    pub out_off: u32,
    // Presumably unused padding.
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
}
667
/// Parameter block for the 2-D convolution shader (association inferred from the name).
///
/// `#[repr(C)]` + `Pod`: byte-cast for upload, so field order is layout-significant and
/// presumably mirrors the WGSL struct. 19 × `u32` = 76 bytes (no padding fields).
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct Conv2dParams {
    pub n: u32,
    pub c_in: u32,
    pub c_out: u32,
    pub h: u32,
    pub w: u32,
    pub h_out: u32,
    pub w_out: u32,
    // Kernel size, stride, padding and (presumably) dilation per spatial axis.
    pub kh: u32,
    pub kw: u32,
    pub sh: u32,
    pub sw: u32,
    pub ph: u32,
    pub pw: u32,
    pub dh: u32,
    pub dw: u32,
    pub groups: u32,
    // Offsets of the input, the weights, and the output.
    pub in_off: u32,
    pub w_off: u32,
    pub out_off: u32,
}
692
/// Parameter block for the 1-D pooling shader (association inferred from the name).
///
/// `#[repr(C)]` + `Pod`: byte-cast for upload, so field order is layout-significant and
/// presumably mirrors the WGSL struct. 16 × `u32` = 64 bytes.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct Pool1dParams {
    pub n: u32,
    pub c: u32,
    pub l: u32,
    pub l_out: u32,
    // Kernel length, stride and padding.
    pub kl: u32,
    pub sl: u32,
    pub pl: u32,
    // Pooling-mode selector; mapping defined outside this view.
    pub op: u32,
    pub in_off: u32,
    pub out_off: u32,
    // Presumably unused padding.
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
    pub _p3: u32,
    pub _p4: u32,
    pub _p5: u32,
}
714
/// Parameter block for the 3-D pooling shader (association inferred from the name).
///
/// `#[repr(C)]` + `Pod`: byte-cast for upload, so field order is layout-significant and
/// presumably mirrors the WGSL struct. 22 × `u32` = 88 bytes.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct Pool3dParams {
    pub n: u32,
    pub c: u32,
    pub d: u32,
    pub h: u32,
    pub w: u32,
    pub d_out: u32,
    pub h_out: u32,
    pub w_out: u32,
    // Kernel size, stride and padding per spatial axis.
    pub kd: u32,
    pub kh: u32,
    pub kw: u32,
    pub sd: u32,
    pub sh: u32,
    pub sw: u32,
    pub pd: u32,
    pub ph: u32,
    pub pw: u32,
    // Pooling-mode selector; mapping defined outside this view.
    pub op: u32,
    pub in_off: u32,
    pub out_off: u32,
    // Presumably unused padding.
    pub _p0: u32,
    pub _p1: u32,
}
742
/// Parameter block for the 1-D convolution shader (association inferred from the name).
///
/// `#[repr(C)]` + `Pod`: byte-cast for upload, so field order is layout-significant and
/// presumably mirrors the WGSL struct. 16 × `u32` = 64 bytes.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct Conv1dParams {
    pub n: u32,
    pub c_in: u32,
    pub c_out: u32,
    pub l: u32,
    pub l_out: u32,
    // Kernel length, stride, padding and (presumably) dilation.
    pub kl: u32,
    pub sl: u32,
    pub pl: u32,
    pub dl: u32,
    pub groups: u32,
    pub in_off: u32,
    pub w_off: u32,
    pub out_off: u32,
    // Presumably unused padding.
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
}
764
/// Parameter block for the dequantizing matmul shader (association inferred from the name).
///
/// `#[repr(C)]` + `Pod`: byte-cast for upload, so field order is layout-significant and
/// presumably mirrors the WGSL struct. 12 × `u32` = 48 bytes.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct DequantMatmulParams {
    pub m: u32,
    pub k: u32,
    pub n: u32,
    // Quantization block size and scheme selector (mapping defined outside this view).
    pub block_size: u32,
    pub scheme_id: u32,
    pub x_off: u32,
    pub w_off: u32,
    // Offsets of the per-block scales and (presumably) zero points.
    pub scale_off: u32,
    pub zp_off: u32,
    pub out_off: u32,
    // Presumably unused padding.
    pub _p0: u32,
    pub _p1: u32,
}
782
/// Parameter block for the fused residual + layer-norm "tee" shader (association inferred
/// from the name). Unlike `FusedResidualLnParams`, it has two destinations:
/// `sum_off` and `ln_out_off`.
///
/// `#[repr(C)]` + `Pod`: byte-cast for upload, so field order is layout-significant and
/// presumably mirrors the WGSL struct. 12 × `u32` = 48 bytes.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct FusedResidualLnTeeParams {
    pub outer: u32,
    pub inner: u32,
    pub in_off: u32,
    pub residual_off: u32,
    pub bias_off: u32,
    pub gamma_off: u32,
    pub beta_off: u32,
    // Presumably: the pre-norm residual sum and the normalized output.
    pub sum_off: u32,
    pub ln_out_off: u32,
    // Epsilon stored as raw bits in a u32 (presumably `f32::to_bits`).
    pub eps_bits: u32,
    // Presumably a 0/1 flag enabling the bias read at `bias_off`.
    pub has_bias: u32,
    // Presumably unused padding.
    pub _p0: u32,
}
800
/// Parameter block for the fused QKV matmul shaders (association inferred from the name).
///
/// `#[repr(C)]` + `Pod`: byte-cast for upload, so field order is layout-significant and
/// presumably mirrors the WGSL struct. 16 × `u32` = 64 bytes.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct MatmulQkvParams {
    pub m: u32,
    pub k: u32,
    pub n: u32,
    pub a_off: u32,
    pub b_off: u32,
    // Separate destination offsets for the Q, K and V outputs.
    pub q_off: u32,
    pub k_off: u32,
    pub v_off: u32,
    // Presumably the column width of each of the Q/K/V slices — TODO confirm.
    pub head_width: u32,
    // Presumably a 0/1 flag enabling the bias read at `bias_off`.
    pub has_bias: u32,
    pub bias_off: u32,
    // Presumably unused padding.
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
    pub _p3: u32,
    pub _p4: u32,
}
823
/// The fused residual + RMS-norm path reuses the fused residual + layer-norm parameter
/// layout. NOTE(review): presumably `beta_off` is ignored for RMS norm — confirm in
/// the shader.
pub type FusedResidualRmsNormParams = FusedResidualLnParams;
826
/// Parameter block for the fused residual + layer-norm shader (association inferred from
/// the name). `FusedResidualRmsNormParams` is a type alias of this struct.
///
/// `#[repr(C)]` + `Pod`: byte-cast for upload, so field order is layout-significant and
/// presumably mirrors the WGSL struct. 12 × `u32` = 48 bytes.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct FusedResidualLnParams {
    pub outer: u32,
    pub inner: u32,
    pub in_off: u32,
    pub residual_off: u32,
    pub bias_off: u32,
    pub gamma_off: u32,
    pub beta_off: u32,
    pub out_off: u32,
    // Epsilon stored as raw bits in a u32 (presumably `f32::to_bits`).
    pub eps_bits: u32,
    // Presumably a 0/1 flag enabling the bias read at `bias_off`.
    pub has_bias: u32,
    // Presumably unused padding.
    pub _p0: u32,
    pub _p1: u32,
}
844
/// Parameter block for the selective-scan shader (association inferred from the name).
///
/// `#[repr(C)]` + `Pod`: byte-cast for upload, so field order is layout-significant and
/// presumably mirrors the WGSL struct. 16 × `u32` = 64 bytes.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct SelectiveScanParams {
    pub batch: u32,
    pub seq: u32,
    pub hidden: u32,
    pub state_size: u32,
    // Offsets of the scan inputs (x, delta, A, B, C) and the output.
    pub x_off: u32,
    pub delta_off: u32,
    pub a_off: u32,
    pub b_off: u32,
    pub c_off: u32,
    pub out_off: u32,
    // Presumably occupies the former `_p0` slot (padding starts at `_p1`).
    pub seq_stride: u32,
    // Presumably unused padding.
    pub _p1: u32,
    pub _p2: u32,
    pub _p3: u32,
    pub _p4: u32,
    pub _p5: u32,
}
869
/// Parameter block for the sampling shader (association inferred from the name).
///
/// `#[repr(C)]` + `Pod`: byte-cast for upload, so field order is layout-significant and
/// presumably mirrors the WGSL struct. 12 × `u32` = 48 bytes.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct SampleParams {
    pub outer: u32,
    pub inner: u32,
    pub in_off: u32,
    pub out_off: u32,
    pub top_k: u32,
    // Float settings stored as raw bits in a u32 (presumably `f32::to_bits`).
    pub top_p_bits: u32,
    pub temp_bits: u32,
    // Presumably the low/high 32-bit halves of a 64-bit RNG seed.
    pub seed_lo: u32,
    pub seed_hi: u32,
    // Presumably unused padding.
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
}
887
/// Parameter block for the grouped (per-expert) matmul shader (association inferred from the name).
///
/// `#[repr(C)]` + `Pod`: byte-cast for upload, so field order is layout-significant and
/// presumably mirrors the WGSL struct. 8 × `u32` = 32 bytes.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct GroupedMatmulParams {
    pub m: u32,
    pub k: u32,
    pub n: u32,
    pub num_experts: u32,
    pub in_off: u32,
    pub w_off: u32,
    // Presumably per-row expert indices selecting which weight matrix to use.
    pub idx_off: u32,
    pub out_off: u32,
}
901
/// Parameter block for the top-k shader (association inferred from the name).
///
/// `#[repr(C)]` + `Pod`: byte-cast for upload, so field order is layout-significant and
/// presumably mirrors the WGSL struct. 8 × `u32` = 32 bytes.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct TopKParams {
    pub outer: u32,
    pub inner: u32,
    pub k: u32,
    pub in_off: u32,
    pub out_off: u32,
    // Presumably unused padding.
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
}
915
/// Parameter block for the Welch-spectrum peak-finding shader (association inferred from the name).
///
/// `#[repr(C)]` + `Pod`: byte-cast for upload, so field order is layout-significant and
/// presumably mirrors the WGSL struct. 9 × `u32` = 36 bytes.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct WelchPeaksGpuParams {
    pub spec_off: u32,
    pub dst_off: u32,
    pub welch_batch: u32,
    pub n_fft: u32,
    pub n_segments: u32,
    // Presumably the number of peaks to keep.
    pub k: u32,
    pub n_bins: u32,
    // Presumably unused padding.
    pub _p0: u32,
    pub _p1: u32,
}
930
/// Parameter block for the UMAP k-nearest-neighbour shader (association inferred from the name).
///
/// `#[repr(C)]` + `Pod`: byte-cast for upload, so field order is layout-significant and
/// presumably mirrors the WGSL struct. 7 × `u32` = 28 bytes.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct UmapKnnParams {
    pub n: u32,
    pub k: u32,
    // Presumably the offset of a pairwise-distance matrix.
    pub pw_off: u32,
    pub out_off: u32,
    // Presumably unused padding.
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
}
943
/// Parameter block for the scatter-add shader (association inferred from the name).
///
/// `#[repr(C)]` + `Pod`: byte-cast for upload, so field order is layout-significant and
/// presumably mirrors the WGSL struct. 8 × `u32` = 32 bytes. Unlike most param structs,
/// `op` is the first field here.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct ScatterAddParams {
    // Operation selector; mapping defined outside this view.
    pub op: u32,
    pub out_off: u32,
    // Offsets of the update values and their destination indices.
    pub upd_off: u32,
    pub idx_off: u32,
    pub out_total: u32,
    pub num_updates: u32,
    pub trailing: u32,
    pub out_dim: u32,
}
957
/// Parameter block for the 3-D convolution shader (association inferred from the name).
///
/// `#[repr(C)]` + `Pod`: byte-cast for upload, so field order is layout-significant and
/// presumably mirrors the WGSL struct. 26 × `u32` = 104 bytes.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct Conv3dParams {
    pub n: u32,
    pub c_in: u32,
    pub c_out: u32,
    pub d: u32,
    pub h: u32,
    pub w: u32,
    pub d_out: u32,
    pub h_out: u32,
    pub w_out: u32,
    // Kernel size, stride, padding and (presumably) dilation per spatial axis.
    pub kd: u32,
    pub kh: u32,
    pub kw: u32,
    pub sd: u32,
    pub sh: u32,
    pub sw: u32,
    pub pd: u32,
    pub ph: u32,
    pub pw: u32,
    pub dd: u32,
    pub dh: u32,
    pub dw: u32,
    pub groups: u32,
    pub in_off: u32,
    pub w_off: u32,
    pub out_off: u32,
    // Presumably unused padding.
    pub _p0: u32,
}
989
990pub struct Kernel {
992 pub pipeline: wgpu::ComputePipeline,
993 pub bgl: wgpu::BindGroupLayout,
994}
995
996impl Kernel {
997 pub fn bind_two(
998 &self,
999 device: &wgpu::Device,
1000 arena: &wgpu::Buffer,
1001 uniform: &wgpu::Buffer,
1002 ) -> wgpu::BindGroup {
1003 device.create_bind_group(&wgpu::BindGroupDescriptor {
1004 label: Some("rlx-wgpu fft gpu bg"),
1005 layout: &self.bgl,
1006 entries: &[
1007 wgpu::BindGroupEntry {
1008 binding: 0,
1009 resource: arena.as_entire_binding(),
1010 },
1011 wgpu::BindGroupEntry {
1012 binding: 1,
1013 resource: uniform.as_entire_binding(),
1014 },
1015 ],
1016 })
1017 }
1018}
1019
1020#[allow(dead_code)]
1026fn build_kernel_4(
1029 device: &wgpu::Device,
1030 label: &'static str,
1031 wgsl: &str,
1032 entry_point: &'static str,
1033) -> Kernel {
1034 let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
1035 label: Some(label),
1036 source: wgpu::ShaderSource::Wgsl(wgsl.into()),
1037 });
1038 let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
1039 label: Some(label),
1040 entries: &[
1041 wgpu::BindGroupLayoutEntry {
1042 binding: 0,
1043 visibility: wgpu::ShaderStages::COMPUTE,
1044 ty: wgpu::BindingType::Buffer {
1045 ty: wgpu::BufferBindingType::Storage { read_only: false },
1046 has_dynamic_offset: false,
1047 min_binding_size: None,
1048 },
1049 count: None,
1050 },
1051 wgpu::BindGroupLayoutEntry {
1052 binding: 1,
1053 visibility: wgpu::ShaderStages::COMPUTE,
1054 ty: wgpu::BindingType::Buffer {
1055 ty: wgpu::BufferBindingType::Uniform,
1056 has_dynamic_offset: false,
1057 min_binding_size: None,
1058 },
1059 count: None,
1060 },
1061 wgpu::BindGroupLayoutEntry {
1062 binding: 2,
1063 visibility: wgpu::ShaderStages::COMPUTE,
1064 ty: wgpu::BindingType::Buffer {
1065 ty: wgpu::BufferBindingType::Storage { read_only: true },
1066 has_dynamic_offset: false,
1067 min_binding_size: None,
1068 },
1069 count: None,
1070 },
1071 wgpu::BindGroupLayoutEntry {
1072 binding: 3,
1073 visibility: wgpu::ShaderStages::COMPUTE,
1074 ty: wgpu::BindingType::Buffer {
1075 ty: wgpu::BufferBindingType::Storage { read_only: true },
1076 has_dynamic_offset: false,
1077 min_binding_size: None,
1078 },
1079 count: None,
1080 },
1081 ],
1082 });
1083 let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
1084 label: Some(label),
1085 bind_group_layouts: &[Some(&bgl)],
1086 immediate_size: 0,
1087 });
1088 let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
1089 label: Some(label),
1090 layout: Some(&layout),
1091 module: &module,
1092 entry_point: Some(entry_point),
1093 compilation_options: Default::default(),
1094 cache: None,
1095 });
1096 Kernel { pipeline, bgl }
1097}
1098
1099fn build_kernel_3(
1100 device: &wgpu::Device,
1101 label: &'static str,
1102 wgsl: &str,
1103 entry_point: &'static str,
1104) -> Kernel {
1105 let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
1106 label: Some(label),
1107 source: wgpu::ShaderSource::Wgsl(wgsl.into()),
1108 });
1109 let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
1110 label: Some(label),
1111 entries: &[
1112 wgpu::BindGroupLayoutEntry {
1113 binding: 0,
1114 visibility: wgpu::ShaderStages::COMPUTE,
1115 ty: wgpu::BindingType::Buffer {
1116 ty: wgpu::BufferBindingType::Storage { read_only: false },
1117 has_dynamic_offset: false,
1118 min_binding_size: None,
1119 },
1120 count: None,
1121 },
1122 wgpu::BindGroupLayoutEntry {
1123 binding: 1,
1124 visibility: wgpu::ShaderStages::COMPUTE,
1125 ty: wgpu::BindingType::Buffer {
1126 ty: wgpu::BufferBindingType::Uniform,
1127 has_dynamic_offset: false,
1128 min_binding_size: None,
1129 },
1130 count: None,
1131 },
1132 wgpu::BindGroupLayoutEntry {
1133 binding: 2,
1134 visibility: wgpu::ShaderStages::COMPUTE,
1135 ty: wgpu::BindingType::Buffer {
1136 ty: wgpu::BufferBindingType::Storage { read_only: true },
1137 has_dynamic_offset: false,
1138 min_binding_size: None,
1139 },
1140 count: None,
1141 },
1142 ],
1143 });
1144 let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
1145 label: Some(label),
1146 bind_group_layouts: &[Some(&bgl)],
1147 immediate_size: 0,
1148 });
1149 let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
1150 label: Some(label),
1151 layout: Some(&layout),
1152 module: &module,
1153 entry_point: Some(entry_point),
1154 compilation_options: Default::default(),
1155 cache: None,
1156 });
1157 Kernel { pipeline, bgl }
1158}
1159
1160fn build_kernel_cast_f32_to_f16(
1164 device: &wgpu::Device,
1165 label: &'static str,
1166 wgsl: &str,
1167 entry_point: &'static str,
1168) -> Kernel {
1169 let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
1170 label: Some(label),
1171 source: wgpu::ShaderSource::Wgsl(wgsl.into()),
1172 });
1173 let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
1174 label: Some(label),
1175 entries: &[
1176 wgpu::BindGroupLayoutEntry {
1177 binding: 0,
1178 visibility: wgpu::ShaderStages::COMPUTE,
1179 ty: wgpu::BindingType::Buffer {
1180 ty: wgpu::BufferBindingType::Storage { read_only: false },
1181 has_dynamic_offset: false,
1182 min_binding_size: None,
1183 },
1184 count: None,
1185 },
1186 wgpu::BindGroupLayoutEntry {
1187 binding: 1,
1188 visibility: wgpu::ShaderStages::COMPUTE,
1189 ty: wgpu::BindingType::Buffer {
1190 ty: wgpu::BufferBindingType::Uniform,
1191 has_dynamic_offset: false,
1192 min_binding_size: None,
1193 },
1194 count: None,
1195 },
1196 wgpu::BindGroupLayoutEntry {
1197 binding: 2,
1198 visibility: wgpu::ShaderStages::COMPUTE,
1199 ty: wgpu::BindingType::Buffer {
1200 ty: wgpu::BufferBindingType::Storage { read_only: false },
1201 has_dynamic_offset: false,
1202 min_binding_size: None,
1203 },
1204 count: None,
1205 },
1206 ],
1207 });
1208 let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
1209 label: Some(label),
1210 bind_group_layouts: &[Some(&bgl)],
1211 immediate_size: 0,
1212 });
1213 let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
1214 label: Some(label),
1215 layout: Some(&layout),
1216 module: &module,
1217 entry_point: Some(entry_point),
1218 compilation_options: Default::default(),
1219 cache: None,
1220 });
1221 Kernel { pipeline, bgl }
1222}
1223
1224fn build_kernel_f32_rw_uniform_f16_rw(
1226 device: &wgpu::Device,
1227 label: &'static str,
1228 wgsl: &str,
1229 entry_point: &'static str,
1230) -> Kernel {
1231 let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
1232 label: Some(label),
1233 source: wgpu::ShaderSource::Wgsl(wgsl.into()),
1234 });
1235 let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
1236 label: Some(label),
1237 entries: &[
1238 wgpu::BindGroupLayoutEntry {
1239 binding: 0,
1240 visibility: wgpu::ShaderStages::COMPUTE,
1241 ty: wgpu::BindingType::Buffer {
1242 ty: wgpu::BufferBindingType::Storage { read_only: false },
1243 has_dynamic_offset: false,
1244 min_binding_size: None,
1245 },
1246 count: None,
1247 },
1248 wgpu::BindGroupLayoutEntry {
1249 binding: 1,
1250 visibility: wgpu::ShaderStages::COMPUTE,
1251 ty: wgpu::BindingType::Buffer {
1252 ty: wgpu::BufferBindingType::Uniform,
1253 has_dynamic_offset: false,
1254 min_binding_size: None,
1255 },
1256 count: None,
1257 },
1258 wgpu::BindGroupLayoutEntry {
1259 binding: 2,
1260 visibility: wgpu::ShaderStages::COMPUTE,
1261 ty: wgpu::BindingType::Buffer {
1262 ty: wgpu::BufferBindingType::Storage { read_only: false },
1263 has_dynamic_offset: false,
1264 min_binding_size: None,
1265 },
1266 count: None,
1267 },
1268 ],
1269 });
1270 let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
1271 label: Some(label),
1272 bind_group_layouts: &[Some(&bgl)],
1273 immediate_size: 0,
1274 });
1275 let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
1276 label: Some(label),
1277 layout: Some(&layout),
1278 module: &module,
1279 entry_point: Some(entry_point),
1280 compilation_options: Default::default(),
1281 cache: None,
1282 });
1283 Kernel { pipeline, bgl }
1284}
1285
1286fn build_kernel_coop_f16_vk(
1288 device: &wgpu::Device,
1289 label: &'static str,
1290 wgsl: &str,
1291 entry_point: &'static str,
1292) -> Kernel {
1293 let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
1294 label: Some(label),
1295 source: wgpu::ShaderSource::Wgsl(wgsl.into()),
1296 });
1297 let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
1298 label: Some(label),
1299 entries: &[
1300 wgpu::BindGroupLayoutEntry {
1301 binding: 0,
1302 visibility: wgpu::ShaderStages::COMPUTE,
1303 ty: wgpu::BindingType::Buffer {
1304 ty: wgpu::BufferBindingType::Storage { read_only: true },
1305 has_dynamic_offset: false,
1306 min_binding_size: None,
1307 },
1308 count: None,
1309 },
1310 wgpu::BindGroupLayoutEntry {
1311 binding: 1,
1312 visibility: wgpu::ShaderStages::COMPUTE,
1313 ty: wgpu::BindingType::Buffer {
1314 ty: wgpu::BufferBindingType::Storage { read_only: false },
1315 has_dynamic_offset: false,
1316 min_binding_size: None,
1317 },
1318 count: None,
1319 },
1320 wgpu::BindGroupLayoutEntry {
1321 binding: 2,
1322 visibility: wgpu::ShaderStages::COMPUTE,
1323 ty: wgpu::BindingType::Buffer {
1324 ty: wgpu::BufferBindingType::Uniform,
1325 has_dynamic_offset: false,
1326 min_binding_size: None,
1327 },
1328 count: None,
1329 },
1330 ],
1331 });
1332 let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
1333 label: Some(label),
1334 bind_group_layouts: &[Some(&bgl)],
1335 immediate_size: 0,
1336 });
1337 let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
1338 label: Some(label),
1339 layout: Some(&layout),
1340 module: &module,
1341 entry_point: Some(entry_point),
1342 compilation_options: Default::default(),
1343 cache: None,
1344 });
1345 Kernel { pipeline, bgl }
1346}
1347
1348fn try_build_kernel_coop_f16_vk(
1349 device: &wgpu::Device,
1350 label: &'static str,
1351 wgsl: &str,
1352 entry_point: &'static str,
1353) -> Option<Kernel> {
1354 std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
1355 build_kernel_coop_f16_vk(device, label, wgsl, entry_point)
1356 }))
1357 .ok()
1358}
1359
1360fn build_kernel(
1361 device: &wgpu::Device,
1362 label: &'static str,
1363 wgsl: &str,
1364 entry_point: &'static str,
1365) -> Kernel {
1366 let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
1367 label: Some(label),
1368 source: wgpu::ShaderSource::Wgsl(wgsl.into()),
1369 });
1370 let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
1371 label: Some(label),
1372 entries: &[
1373 wgpu::BindGroupLayoutEntry {
1374 binding: 0,
1375 visibility: wgpu::ShaderStages::COMPUTE,
1376 ty: wgpu::BindingType::Buffer {
1377 ty: wgpu::BufferBindingType::Storage { read_only: false },
1378 has_dynamic_offset: false,
1379 min_binding_size: None,
1380 },
1381 count: None,
1382 },
1383 wgpu::BindGroupLayoutEntry {
1384 binding: 1,
1385 visibility: wgpu::ShaderStages::COMPUTE,
1386 ty: wgpu::BindingType::Buffer {
1387 ty: wgpu::BufferBindingType::Uniform,
1388 has_dynamic_offset: false,
1389 min_binding_size: None,
1390 },
1391 count: None,
1392 },
1393 ],
1394 });
1395 let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
1396 label: Some(label),
1397 bind_group_layouts: &[Some(&bgl)],
1398 immediate_size: 0,
1399 });
1400 let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
1401 label: Some(label),
1402 layout: Some(&layout),
1403 module: &module,
1404 entry_point: Some(entry_point),
1405 compilation_options: Default::default(),
1406 cache: None,
1407 });
1408 Kernel { pipeline, bgl }
1409}
1410
1411static MATMUL: OnceLock<Kernel> = OnceLock::new();
1412static MATMUL_WIDE: OnceLock<Kernel> = OnceLock::new();
1413static MATMUL_WIDE_NV: OnceLock<Kernel> = OnceLock::new();
1414static MATMUL_F16W: OnceLock<Kernel> = OnceLock::new();
1415static MATMUL_F16_COMPUTE: OnceLock<Kernel> = OnceLock::new();
1416static MATMUL_COOP16: OnceLock<Kernel> = OnceLock::new();
1417static MATMUL_COOP_F32: OnceLock<Kernel> = OnceLock::new();
1418static MATMUL_COOP_F32_PORTABLE: OnceLock<Kernel> = OnceLock::new();
1419static MATMUL_COOP_F16_VULKAN: OnceLock<Kernel> = OnceLock::new();
1420static MATMUL_COOP_F16_VULKAN_WIDEN: OnceLock<Kernel> = OnceLock::new();
1421static MATMUL_COOP_F16_VULKAN_F32ACC: OnceLock<Option<Kernel>> = OnceLock::new();
1422static MATMUL_COOP_F16_VULKAN_WIDEN_F32ACC: OnceLock<Option<Kernel>> = OnceLock::new();
1423static CAST_F32_TO_F16: OnceLock<Kernel> = OnceLock::new();
1424static BINARY: OnceLock<Kernel> = OnceLock::new();
1425static UNARY: OnceLock<Kernel> = OnceLock::new();
1426static UNARY_F16_MIRROR: OnceLock<Kernel> = OnceLock::new();
1427static COMPARE: OnceLock<Kernel> = OnceLock::new();
1428static WHEREK: OnceLock<Kernel> = OnceLock::new();
1429static REDUCE: OnceLock<Kernel> = OnceLock::new();
1430static SOFTMAX: OnceLock<Kernel> = OnceLock::new();
1431static LAYERNORM: OnceLock<Kernel> = OnceLock::new();
1432static RMS_NORM_BWD: OnceLock<Kernel> = OnceLock::new();
1433static RMS_NORM_BWD_PARAM: OnceLock<Kernel> = OnceLock::new();
1434static LAYER_NORM_BWD_INPUT: OnceLock<Kernel> = OnceLock::new();
1435static LAYER_NORM_BWD_GAMMA: OnceLock<Kernel> = OnceLock::new();
1436static LAYER_NORM_BWD_GAMMA_REDUCE: OnceLock<Kernel> = OnceLock::new();
1437static CUMSUM_BWD: OnceLock<Kernel> = OnceLock::new();
1438static ROPE_BWD: OnceLock<Kernel> = OnceLock::new();
1439static GATHER_BWD_ZERO: OnceLock<Kernel> = OnceLock::new();
1440static GATHER_BWD_ACC: OnceLock<Kernel> = OnceLock::new();
1441static CUMSUM: OnceLock<Kernel> = OnceLock::new();
1442static FFT_GPU_RADIX2: OnceLock<Kernel> = OnceLock::new();
1443static FFT_GPU_BITREV: OnceLock<Kernel> = OnceLock::new();
1444static FFT_GPU_INNER: OnceLock<Kernel> = OnceLock::new();
1445static FFT_GPU_OUTER_R4: OnceLock<Kernel> = OnceLock::new();
1446static FFT_GPU_OUTER_R2: OnceLock<Kernel> = OnceLock::new();
1447static COPY: OnceLock<Kernel> = OnceLock::new();
1448static ELEMENTWISE_REGION: OnceLock<Kernel> = OnceLock::new();
1449static ELEMENTWISE_REGION_SPATIAL: OnceLock<Kernel> = OnceLock::new();
1450static TRANSPOSE: OnceLock<Kernel> = OnceLock::new();
1451static NARROW: OnceLock<Kernel> = OnceLock::new();
1452static CONCAT: OnceLock<Kernel> = OnceLock::new();
1453static GATHER: OnceLock<Kernel> = OnceLock::new();
1454static GATHER_AXIS: OnceLock<Kernel> = OnceLock::new();
1455static ATTENTION: OnceLock<Kernel> = OnceLock::new();
1456static ATTENTION_BWD: OnceLock<Kernel> = OnceLock::new();
1457static ROPE: OnceLock<Kernel> = OnceLock::new();
1458static EXPAND: OnceLock<Kernel> = OnceLock::new();
1459static ARGMAX: OnceLock<Kernel> = OnceLock::new();
1460static POOL2D: OnceLock<Kernel> = OnceLock::new();
1461static CONV2D: OnceLock<Kernel> = OnceLock::new();
1462static POOL1D: OnceLock<Kernel> = OnceLock::new();
1463static POOL3D: OnceLock<Kernel> = OnceLock::new();
1464static CONV1D: OnceLock<Kernel> = OnceLock::new();
1465static CONV3D: OnceLock<Kernel> = OnceLock::new();
1466static SCATTER_ADD: OnceLock<Kernel> = OnceLock::new();
1467static TOPK: OnceLock<Kernel> = OnceLock::new();
1468static WELCH_PEAKS_GPU: OnceLock<Kernel> = OnceLock::new();
1469static UMAP_KNN: OnceLock<Kernel> = OnceLock::new();
1470static GROUPED_MATMUL: OnceLock<Kernel> = OnceLock::new();
1471static SAMPLE: OnceLock<Kernel> = OnceLock::new();
1472static SELECTIVE_SCAN: OnceLock<Kernel> = OnceLock::new();
1473static DEQUANT_MATMUL: OnceLock<Kernel> = OnceLock::new();
1474static FUSED_RESIDUAL_LN: OnceLock<Kernel> = OnceLock::new();
1475static FUSED_RESIDUAL_LN_TEE: OnceLock<Kernel> = OnceLock::new();
1476static FUSED_RESIDUAL_RMS_NORM: OnceLock<Kernel> = OnceLock::new();
1477static MATMUL_QKV: OnceLock<Kernel> = OnceLock::new();
1478static MATMUL_QKV_COOP_F32: OnceLock<Kernel> = OnceLock::new();
1479static MATMUL_QKV_COOP_F16_VK: OnceLock<Kernel> = OnceLock::new();
1480static MATMUL_QKV_COOP_F16_VK_WIDEN: OnceLock<Kernel> = OnceLock::new();
1481static MATMUL_QKV_COOP_F16_VK_F32ACC: OnceLock<Option<Kernel>> = OnceLock::new();
1482static MATMUL_QKV_COOP_F16_VK_WIDEN_F32ACC: OnceLock<Option<Kernel>> = OnceLock::new();
1483
1484pub fn matmul_kernel(device: &wgpu::Device) -> &'static Kernel {
1485 MATMUL.get_or_init(|| build_kernel(device, "rlx-wgpu matmul", MATMUL_WGSL, "matmul"))
1486}
1487pub fn matmul_wide_kernel(device: &wgpu::Device) -> &'static Kernel {
1488 MATMUL_WIDE.get_or_init(|| {
1489 build_kernel(
1490 device,
1491 "rlx-wgpu matmul_wide",
1492 MATMUL_WIDE_WGSL,
1493 "matmul_wide",
1494 )
1495 })
1496}
1497pub fn matmul_wide_nv_kernel(device: &wgpu::Device) -> &'static Kernel {
1499 MATMUL_WIDE_NV.get_or_init(|| {
1500 build_kernel(
1501 device,
1502 "rlx-wgpu matmul_wide_nv",
1503 MATMUL_WIDE_NV_WGSL,
1504 "matmul_wide_nv",
1505 )
1506 })
1507}
1508pub fn matmul_f16w_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
1513 if !device.features().contains(wgpu::Features::SHADER_F16) {
1514 return None;
1515 }
1516 Some(MATMUL_F16W.get_or_init(|| {
1517 build_kernel_3(
1518 device,
1519 "rlx-wgpu matmul_f16w",
1520 MATMUL_F16W_WGSL,
1521 "matmul_f16w",
1522 )
1523 }))
1524}
1525pub fn matmul_f16_compute_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
1529 if !device.features().contains(wgpu::Features::SHADER_F16) {
1530 return None;
1531 }
1532 Some(MATMUL_F16_COMPUTE.get_or_init(|| {
1533 build_kernel_3(
1534 device,
1535 "rlx-wgpu matmul_f16_compute",
1536 MATMUL_F16_COMPUTE_WGSL,
1537 "matmul_f16_compute",
1538 )
1539 }))
1540}
1541pub fn matmul_coop16_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
1547 let feats = device.features();
1548 if !feats.contains(wgpu::Features::SHADER_F16)
1549 || !feats.contains(wgpu::Features::EXPERIMENTAL_COOPERATIVE_MATRIX)
1550 {
1551 return None;
1552 }
1553 Some(MATMUL_COOP16.get_or_init(|| {
1554 build_kernel_3(
1555 device,
1556 "rlx-wgpu matmul_coop16",
1557 MATMUL_COOP16_WGSL,
1558 "matmul_coop16",
1559 )
1560 }))
1561}
1562pub fn matmul_coop_f32_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
1568 let feats = device.features();
1569 if !feats.contains(wgpu::Features::EXPERIMENTAL_COOPERATIVE_MATRIX) {
1570 return None;
1571 }
1572 Some(MATMUL_COOP_F32.get_or_init(|| {
1573 build_kernel(
1574 device,
1575 "rlx-wgpu matmul_coop_f32",
1576 MATMUL_COOP_F32_WGSL,
1577 "matmul_coop_f32",
1578 )
1579 }))
1580}
1581pub fn matmul_coop_f32_portable_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
1583 let feats = device.features();
1584 if !feats.contains(wgpu::Features::EXPERIMENTAL_COOPERATIVE_MATRIX)
1585 || !crate::device::coop_f32_8x8_supported()
1586 {
1587 return None;
1588 }
1589 Some(MATMUL_COOP_F32_PORTABLE.get_or_init(|| {
1590 build_kernel(
1591 device,
1592 "rlx-wgpu matmul_coop_f32_portable",
1593 MATMUL_COOP_F32_PORTABLE_WGSL,
1594 "matmul_coop_f32_portable",
1595 )
1596 }))
1597}
1598fn coop_f16_vk_device_ready(device: &wgpu::Device) -> bool {
1599 if rlx_ir::env::flag("RLX_WGPU_COOP_F16_VK_DISABLE")
1604 || !rlx_ir::env::flag("RLX_WGPU_COOP_F16_VK_ENABLE")
1605 {
1606 return false;
1607 }
1608 device.features().contains(wgpu::Features::SHADER_F16)
1609 && device
1610 .features()
1611 .contains(wgpu::Features::EXPERIMENTAL_COOPERATIVE_MATRIX)
1612 && crate::device::coop_f16_16x16_supported()
1613 && crate::device::coop_discrete_backend()
1614}
1615
1616fn coop_f16_vk_f32acc_device_ready(device: &wgpu::Device) -> bool {
1617 coop_f16_vk_device_ready(device) && crate::device::coop_f16_16x16_f32_acc_supported()
1618}
1619
1620pub fn matmul_coop_f16_vulkan_f32acc_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
1621 if !coop_f16_vk_f32acc_device_ready(device) {
1622 return None;
1623 }
1624 MATMUL_COOP_F16_VULKAN_F32ACC
1625 .get_or_init(|| {
1626 try_build_kernel_coop_f16_vk(
1627 device,
1628 "rlx-wgpu matmul_coop_f16_vulkan_f32acc",
1629 MATMUL_COOP_F16_VULKAN_F32ACC_WGSL,
1630 "matmul_coop_f16_vulkan_f32acc",
1631 )
1632 })
1633 .as_ref()
1634}
1635
1636pub fn matmul_coop_f16_vulkan_widen_f32acc_kernel(
1637 device: &wgpu::Device,
1638) -> Option<&'static Kernel> {
1639 if !coop_f16_vk_f32acc_device_ready(device) {
1640 return None;
1641 }
1642 MATMUL_COOP_F16_VULKAN_WIDEN_F32ACC
1643 .get_or_init(|| {
1644 try_build_kernel_coop_f16_vk(
1645 device,
1646 "rlx-wgpu matmul_coop_f16_vulkan_widen_f32acc",
1647 MATMUL_COOP_F16_VULKAN_WIDEN_F32ACC_WGSL,
1648 "matmul_coop_f16_vulkan_widen_f32acc",
1649 )
1650 })
1651 .as_ref()
1652}
1653
1654fn coop_f16_vk_use_f32acc(device: &wgpu::Device) -> bool {
1655 !rlx_ir::env::flag("RLX_WGPU_COOP_F16_VK_NO_F32ACC")
1656 && matmul_coop_f16_vulkan_f32acc_kernel(device).is_some()
1657}
1658
1659fn pick_coop_f16_vk_matmul(
1660 device: &wgpu::Device,
1661 n: u32,
1662 loadt: fn(&wgpu::Device) -> Option<&'static Kernel>,
1663 loadt_f32acc: fn(&wgpu::Device) -> Option<&'static Kernel>,
1664 widen: fn(&wgpu::Device) -> Option<&'static Kernel>,
1665 widen_f32acc: fn(&wgpu::Device) -> Option<&'static Kernel>,
1666) -> Option<&'static Kernel> {
1667 if coop_f16_vk_use_f32acc(device) {
1668 if coop_f16_vk_widen_b_load(n) {
1669 return widen_f32acc(device).or_else(|| loadt_f32acc(device));
1670 }
1671 return loadt_f32acc(device);
1672 }
1673 if coop_f16_vk_widen_b_load(n) {
1674 widen(device).or_else(|| loadt(device))
1675 } else {
1676 loadt(device)
1677 }
1678}
1679
1680pub fn matmul_coop_f16_vulkan_active_kernel(
1682 device: &wgpu::Device,
1683 n: u32,
1684) -> Option<&'static Kernel> {
1685 pick_coop_f16_vk_matmul(
1686 device,
1687 n,
1688 matmul_coop_f16_vulkan_kernel,
1689 matmul_coop_f16_vulkan_f32acc_kernel,
1690 matmul_coop_f16_vulkan_widen_kernel,
1691 matmul_coop_f16_vulkan_widen_f32acc_kernel,
1692 )
1693}
1694
1695pub fn matmul_coop_f16_vulkan_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
1696 if !coop_f16_vk_device_ready(device) {
1697 return None;
1698 }
1699 Some(MATMUL_COOP_F16_VULKAN.get_or_init(|| {
1700 build_kernel_coop_f16_vk(
1701 device,
1702 "rlx-wgpu matmul_coop_f16_vulkan",
1703 MATMUL_COOP_F16_VULKAN_WGSL,
1704 "matmul_coop_f16_vulkan",
1705 )
1706 }))
1707}
1708pub const COOP_F16_VK_WIDEN_N: u32 = 768;
1710
1711pub fn coop_f16_vk_widen_b_load(n: u32) -> bool {
1713 n > COOP_F16_VK_WIDEN_N && !rlx_ir::env::flag("RLX_WGPU_COOP_F16_VK_LOAD_T")
1714}
1715
1716pub fn matmul_coop_f16_vulkan_widen_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
1717 if !coop_f16_vk_device_ready(device) {
1718 return None;
1719 }
1720 Some(MATMUL_COOP_F16_VULKAN_WIDEN.get_or_init(|| {
1721 build_kernel_coop_f16_vk(
1722 device,
1723 "rlx-wgpu matmul_coop_f16_vulkan_widen",
1724 MATMUL_COOP_F16_VULKAN_WIDEN_WGSL,
1725 "matmul_coop_f16_vulkan_widen",
1726 )
1727 }))
1728}
1729pub fn coop_f16_vk_f32acc_available(device: &wgpu::Device) -> bool {
1730 matmul_coop_f16_vulkan_f32acc_kernel(device).is_some()
1731}
1732pub fn matmul_coop_f32_active_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
1734 match crate::device::wgpu_device().map(|d| d.backend) {
1735 Some(wgpu::Backend::Metal) => matmul_coop_f32_kernel(device),
1736 Some(wgpu::Backend::Vulkan) | Some(wgpu::Backend::Dx12) => {
1737 matmul_coop_f32_portable_kernel(device)
1738 }
1739 _ => None,
1740 }
1741}
1742pub fn matmul_wide_active_kernel(device: &wgpu::Device) -> &'static Kernel {
1744 match crate::device::wgpu_device().map(|d| d.backend) {
1745 Some(wgpu::Backend::Vulkan) | Some(wgpu::Backend::Dx12) => matmul_wide_nv_kernel(device),
1746 _ => matmul_wide_kernel(device),
1747 }
1748}
1749pub fn cast_f32_to_f16_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
1754 if !device.features().contains(wgpu::Features::SHADER_F16) {
1755 return None;
1756 }
1757 Some(CAST_F32_TO_F16.get_or_init(|| {
1758 build_kernel_cast_f32_to_f16(
1759 device,
1760 "rlx-wgpu cast_f32_to_f16",
1761 CAST_F32_TO_F16_WGSL,
1762 "cast_f32_to_f16",
1763 )
1764 }))
1765}
1766pub fn binary_kernel(device: &wgpu::Device) -> &'static Kernel {
1767 BINARY.get_or_init(|| build_kernel(device, "rlx-wgpu binary", BINARY_WGSL, "binary"))
1768}
1769pub fn unary_kernel(device: &wgpu::Device) -> &'static Kernel {
1770 UNARY.get_or_init(|| build_kernel(device, "rlx-wgpu unary", UNARY_WGSL, "unary"))
1771}
1772pub fn unary_f16_mirror_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
1773 if !device.features().contains(wgpu::Features::SHADER_F16) {
1774 return None;
1775 }
1776 Some(UNARY_F16_MIRROR.get_or_init(|| {
1777 build_kernel_f32_rw_uniform_f16_rw(
1778 device,
1779 "rlx-wgpu unary_f16_mirror",
1780 UNARY_F16_MIRROR_WGSL,
1781 "unary_f16_mirror",
1782 )
1783 }))
1784}
1785pub fn compare_kernel(device: &wgpu::Device) -> &'static Kernel {
1786 COMPARE.get_or_init(|| build_kernel(device, "rlx-wgpu compare", COMPARE_WGSL, "compare"))
1787}
1788pub fn where_kernel(device: &wgpu::Device) -> &'static Kernel {
1789 WHEREK.get_or_init(|| build_kernel(device, "rlx-wgpu where", WHERE_WGSL, "where_select"))
1790}
1791pub fn reduce_kernel(device: &wgpu::Device) -> &'static Kernel {
1792 REDUCE.get_or_init(|| build_kernel(device, "rlx-wgpu reduce", REDUCE_WGSL, "reduce"))
1793}
1794pub fn softmax_kernel(device: &wgpu::Device) -> &'static Kernel {
1795 SOFTMAX.get_or_init(|| build_kernel(device, "rlx-wgpu softmax", SOFTMAX_WGSL, "softmax"))
1796}
1797pub fn layernorm_kernel(device: &wgpu::Device) -> &'static Kernel {
1798 LAYERNORM.get_or_init(|| build_kernel(device, "rlx-wgpu layernorm", LAYERNORM_WGSL, "norm"))
1799}
1800pub fn rms_norm_backward_kernel(device: &wgpu::Device) -> &'static Kernel {
1801 RMS_NORM_BWD.get_or_init(|| {
1802 build_kernel(
1803 device,
1804 "rlx-wgpu rms_norm_bwd",
1805 RMS_NORM_BWD_WGSL,
1806 "rms_norm_bwd",
1807 )
1808 })
1809}
1810pub fn rms_norm_backward_param_kernel(device: &wgpu::Device) -> &'static Kernel {
1811 RMS_NORM_BWD_PARAM.get_or_init(|| {
1812 build_kernel(
1813 device,
1814 "rlx-wgpu rms_norm_bwd_param",
1815 RMS_NORM_BWD_WGSL,
1816 "rms_norm_bwd_param",
1817 )
1818 })
1819}
1820pub fn layer_norm_backward_input_kernel(device: &wgpu::Device) -> &'static Kernel {
1821 LAYER_NORM_BWD_INPUT.get_or_init(|| {
1822 build_kernel(
1823 device,
1824 "rlx-wgpu layer_norm_bwd_input",
1825 LAYER_NORM_BWD_WGSL,
1826 "layer_norm_bwd_input",
1827 )
1828 })
1829}
1830pub fn layer_norm_backward_gamma_partial_kernel(device: &wgpu::Device) -> &'static Kernel {
1831 LAYER_NORM_BWD_GAMMA.get_or_init(|| {
1832 build_kernel(
1833 device,
1834 "rlx-wgpu layer_norm_bwd_gamma_partial",
1835 LAYER_NORM_BWD_WGSL,
1836 "layer_norm_bwd_gamma_partial",
1837 )
1838 })
1839}
1840
1841pub fn layer_norm_backward_gamma_reduce_kernel(device: &wgpu::Device) -> &'static Kernel {
1842 LAYER_NORM_BWD_GAMMA_REDUCE.get_or_init(|| {
1843 build_kernel(
1844 device,
1845 "rlx-wgpu layer_norm_bwd_gamma_reduce",
1846 LAYER_NORM_BWD_WGSL,
1847 "layer_norm_bwd_gamma_reduce",
1848 )
1849 })
1850}
1851pub fn cumsum_backward_kernel(device: &wgpu::Device) -> &'static Kernel {
1852 CUMSUM_BWD
1853 .get_or_init(|| build_kernel(device, "rlx-wgpu cumsum_bwd", CUMSUM_BWD_WGSL, "cumsum_bwd"))
1854}
1855pub fn rope_backward_kernel(device: &wgpu::Device) -> &'static Kernel {
1856 ROPE_BWD.get_or_init(|| build_kernel(device, "rlx-wgpu rope_bwd", ROPE_BWD_WGSL, "rope_bwd"))
1857}
1858pub fn gather_backward_zero_kernel(device: &wgpu::Device) -> &'static Kernel {
1859 GATHER_BWD_ZERO.get_or_init(|| {
1860 build_kernel(
1861 device,
1862 "rlx-wgpu gather_bwd_zero",
1863 GATHER_BWD_WGSL,
1864 "gather_bwd_zero",
1865 )
1866 })
1867}
1868pub fn gather_backward_acc_kernel(device: &wgpu::Device) -> &'static Kernel {
1869 GATHER_BWD_ACC.get_or_init(|| {
1870 build_kernel(
1871 device,
1872 "rlx-wgpu gather_bwd_acc",
1873 GATHER_BWD_WGSL,
1874 "gather_bwd_acc",
1875 )
1876 })
1877}
1878pub fn cumsum_kernel(device: &wgpu::Device) -> &'static Kernel {
1879 CUMSUM.get_or_init(|| build_kernel(device, "rlx-wgpu cumsum", CUMSUM_WGSL, "cumsum"))
1880}
1881pub fn fft_gpu_radix2_full_kernel(device: &wgpu::Device) -> &'static Kernel {
1882 FFT_GPU_RADIX2.get_or_init(|| {
1883 build_kernel(
1884 device,
1885 "rlx-wgpu fft_radix2_full",
1886 FFT_GPU_WGSL,
1887 "fft_radix2_full",
1888 )
1889 })
1890}
1891pub fn fft_gpu_bit_reverse_kernel(device: &wgpu::Device) -> &'static Kernel {
1892 FFT_GPU_BITREV.get_or_init(|| {
1893 build_kernel(
1894 device,
1895 "rlx-wgpu fft_bit_reverse",
1896 FFT_GPU_WGSL,
1897 "fft_bit_reverse",
1898 )
1899 })
1900}
1901pub fn fft_gpu_inner_kernel(device: &wgpu::Device) -> &'static Kernel {
1902 FFT_GPU_INNER
1903 .get_or_init(|| build_kernel(device, "rlx-wgpu fft_inner", FFT_GPU_WGSL, "fft_inner"))
1904}
1905pub fn fft_gpu_outer_r4_kernel(device: &wgpu::Device) -> &'static Kernel {
1906 FFT_GPU_OUTER_R4.get_or_init(|| {
1907 build_kernel(
1908 device,
1909 "rlx-wgpu fft_outer_r4",
1910 FFT_GPU_WGSL,
1911 "fft_outer_r4",
1912 )
1913 })
1914}
1915pub fn fft_gpu_outer_r2_kernel(device: &wgpu::Device) -> &'static Kernel {
1916 FFT_GPU_OUTER_R2.get_or_init(|| {
1917 build_kernel(
1918 device,
1919 "rlx-wgpu fft_outer_r2",
1920 FFT_GPU_WGSL,
1921 "fft_outer_r2",
1922 )
1923 })
1924}
1925pub fn copy_kernel(device: &wgpu::Device) -> &'static Kernel {
1926 COPY.get_or_init(|| build_kernel(device, "rlx-wgpu copy", COPY_WGSL, "copy"))
1927}
1928pub fn elementwise_region_kernel(device: &wgpu::Device) -> &'static Kernel {
1929 ELEMENTWISE_REGION.get_or_init(|| {
1934 build_kernel_region(
1935 device,
1936 "rlx-wgpu elementwise_region",
1937 ELEMENTWISE_REGION_WGSL,
1938 "elementwise_region",
1939 )
1940 })
1941}
1942
1943pub fn elementwise_region_spatial_kernel(device: &wgpu::Device) -> &'static Kernel {
1944 ELEMENTWISE_REGION_SPATIAL.get_or_init(|| {
1945 build_kernel_region(
1946 device,
1947 "rlx-wgpu elementwise_region_spatial",
1948 ELEMENTWISE_REGION_WGSL,
1949 "elementwise_region_spatial",
1950 )
1951 })
1952}
1953
1954static BATCH_ELEMENTWISE_REGION: std::sync::OnceLock<Kernel> = std::sync::OnceLock::new();
1955
1956pub fn batch_elementwise_region_kernel(device: &wgpu::Device) -> &'static Kernel {
1957 BATCH_ELEMENTWISE_REGION.get_or_init(|| {
1958 build_kernel_region(
1959 device,
1960 "rlx-wgpu batch_elementwise_region",
1961 ELEMENTWISE_REGION_WGSL,
1962 "batch_elementwise_region",
1963 )
1964 })
1965}
1966
1967fn build_kernel_region(
1968 device: &wgpu::Device,
1969 label: &'static str,
1970 wgsl: &str,
1971 entry_point: &'static str,
1972) -> Kernel {
1973 let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
1974 label: Some(label),
1975 source: wgpu::ShaderSource::Wgsl(wgsl.into()),
1976 });
1977 let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
1978 label: Some(label),
1979 entries: &[
1980 wgpu::BindGroupLayoutEntry {
1981 binding: 0,
1982 visibility: wgpu::ShaderStages::COMPUTE,
1983 ty: wgpu::BindingType::Buffer {
1984 ty: wgpu::BufferBindingType::Storage { read_only: false },
1985 has_dynamic_offset: false,
1986 min_binding_size: None,
1987 },
1988 count: None,
1989 },
1990 wgpu::BindGroupLayoutEntry {
1991 binding: 1,
1992 visibility: wgpu::ShaderStages::COMPUTE,
1993 ty: wgpu::BindingType::Buffer {
1994 ty: wgpu::BufferBindingType::Storage { read_only: true },
1996 has_dynamic_offset: false,
1997 min_binding_size: None,
1998 },
1999 count: None,
2000 },
2001 ],
2002 });
2003 let pl = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
2004 label: Some(label),
2005 bind_group_layouts: &[Some(&bgl)],
2006 immediate_size: 0,
2007 });
2008 let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
2009 label: Some(label),
2010 layout: Some(&pl),
2011 module: &module,
2012 entry_point: Some(entry_point),
2013 compilation_options: Default::default(),
2014 cache: None,
2015 });
2016 Kernel { pipeline, bgl }
2017}
2018pub fn transpose_kernel(device: &wgpu::Device) -> &'static Kernel {
2019 TRANSPOSE
2020 .get_or_init(|| build_kernel_3(device, "rlx-wgpu transpose", TRANSPOSE_WGSL, "transpose"))
2021}
2022pub fn narrow_kernel(device: &wgpu::Device) -> &'static Kernel {
2023 NARROW.get_or_init(|| build_kernel(device, "rlx-wgpu narrow", NARROW_WGSL, "narrow"))
2024}
2025pub fn concat_kernel(device: &wgpu::Device) -> &'static Kernel {
2026 CONCAT.get_or_init(|| build_kernel(device, "rlx-wgpu concat", CONCAT_WGSL, "concat"))
2027}
2028pub fn gather_kernel(device: &wgpu::Device) -> &'static Kernel {
2029 GATHER.get_or_init(|| build_kernel(device, "rlx-wgpu gather", GATHER_WGSL, "gather"))
2030}
2031pub fn gather_axis_kernel(device: &wgpu::Device) -> &'static Kernel {
2032 GATHER_AXIS.get_or_init(|| {
2033 build_kernel(
2034 device,
2035 "rlx-wgpu gather_axis",
2036 GATHER_AXIS_WGSL,
2037 "gather_axis",
2038 )
2039 })
2040}
2041pub fn attention_kernel(device: &wgpu::Device) -> &'static Kernel {
2042 ATTENTION
2043 .get_or_init(|| build_kernel(device, "rlx-wgpu attention", ATTENTION_WGSL, "attention"))
2044}
2045pub fn attention_bwd_kernel(device: &wgpu::Device) -> &'static Kernel {
2046 ATTENTION_BWD.get_or_init(|| {
2047 build_kernel(
2048 device,
2049 "rlx-wgpu attention_bwd",
2050 ATTENTION_BWD_WGSL,
2051 "attention_bwd",
2052 )
2053 })
2054}
2055pub fn rope_kernel(device: &wgpu::Device) -> &'static Kernel {
2056 ROPE.get_or_init(|| build_kernel(device, "rlx-wgpu rope", ROPE_WGSL, "rope"))
2057}
2058pub fn expand_kernel(device: &wgpu::Device) -> &'static Kernel {
2059 EXPAND.get_or_init(|| build_kernel_3(device, "rlx-wgpu expand", EXPAND_WGSL, "expand"))
2060}
2061pub fn argmax_kernel(device: &wgpu::Device) -> &'static Kernel {
2062 ARGMAX.get_or_init(|| build_kernel(device, "rlx-wgpu argmax", ARGMAX_WGSL, "argmax"))
2063}
2064pub fn pool2d_kernel(device: &wgpu::Device) -> &'static Kernel {
2065 POOL2D.get_or_init(|| build_kernel(device, "rlx-wgpu pool2d", POOL2D_WGSL, "pool2d"))
2066}
2067pub fn conv2d_kernel(device: &wgpu::Device) -> &'static Kernel {
2068 CONV2D.get_or_init(|| build_kernel(device, "rlx-wgpu conv2d", CONV2D_WGSL, "conv2d"))
2069}
2070pub fn pool1d_kernel(device: &wgpu::Device) -> &'static Kernel {
2071 POOL1D.get_or_init(|| build_kernel(device, "rlx-wgpu pool1d", POOL1D_WGSL, "pool1d"))
2072}
2073pub fn pool3d_kernel(device: &wgpu::Device) -> &'static Kernel {
2074 POOL3D.get_or_init(|| build_kernel(device, "rlx-wgpu pool3d", POOL3D_WGSL, "pool3d"))
2075}
2076pub fn conv1d_kernel(device: &wgpu::Device) -> &'static Kernel {
2077 CONV1D.get_or_init(|| build_kernel(device, "rlx-wgpu conv1d", CONV1D_WGSL, "conv1d"))
2078}
2079pub fn conv3d_kernel(device: &wgpu::Device) -> &'static Kernel {
2080 CONV3D.get_or_init(|| build_kernel(device, "rlx-wgpu conv3d", CONV3D_WGSL, "conv3d"))
2081}
2082pub fn scatter_add_kernel(device: &wgpu::Device) -> &'static Kernel {
2083 SCATTER_ADD.get_or_init(|| {
2084 build_kernel(
2085 device,
2086 "rlx-wgpu scatter_add",
2087 SCATTER_ADD_WGSL,
2088 "scatter_add",
2089 )
2090 })
2091}
2092pub fn topk_kernel(device: &wgpu::Device) -> &'static Kernel {
2093 TOPK.get_or_init(|| build_kernel(device, "rlx-wgpu topk", TOPK_WGSL, "topk"))
2094}
2095pub fn welch_peaks_gpu_kernel(device: &wgpu::Device) -> &'static Kernel {
2096 WELCH_PEAKS_GPU.get_or_init(|| {
2097 build_kernel(
2098 device,
2099 "rlx-wgpu welch_peaks_gpu",
2100 WELCH_PEAKS_GPU_WGSL,
2101 "welch_peaks_gpu",
2102 )
2103 })
2104}
2105pub fn umap_knn_kernel(device: &wgpu::Device) -> &'static Kernel {
2106 UMAP_KNN.get_or_init(|| build_kernel(device, "rlx-wgpu umap_knn", UMAP_KNN_WGSL, "umap_knn"))
2107}
2108pub fn grouped_matmul_kernel(device: &wgpu::Device) -> &'static Kernel {
2109 GROUPED_MATMUL.get_or_init(|| {
2110 build_kernel(
2111 device,
2112 "rlx-wgpu grouped_matmul",
2113 GROUPED_MATMUL_WGSL,
2114 "grouped_matmul",
2115 )
2116 })
2117}
2118pub fn sample_kernel(device: &wgpu::Device) -> &'static Kernel {
2119 SAMPLE.get_or_init(|| build_kernel(device, "rlx-wgpu sample", SAMPLE_WGSL, "sample"))
2120}
2121pub fn selective_scan_kernel(device: &wgpu::Device) -> &'static Kernel {
2122 SELECTIVE_SCAN.get_or_init(|| {
2123 build_kernel(
2124 device,
2125 "rlx-wgpu selective_scan",
2126 SELECTIVE_SCAN_WGSL,
2127 "selective_scan",
2128 )
2129 })
2130}
2131pub fn dequant_matmul_kernel(device: &wgpu::Device) -> &'static Kernel {
2132 DEQUANT_MATMUL.get_or_init(|| {
2133 build_kernel(
2134 device,
2135 "rlx-wgpu dequant_matmul",
2136 DEQUANT_MATMUL_WGSL,
2137 "dequant_matmul",
2138 )
2139 })
2140}
2141pub fn fused_residual_ln_kernel(device: &wgpu::Device) -> &'static Kernel {
2142 FUSED_RESIDUAL_LN.get_or_init(|| {
2143 build_kernel(
2144 device,
2145 "rlx-wgpu fused_residual_ln",
2146 FUSED_RESIDUAL_LN_WGSL,
2147 "fused_residual_ln",
2148 )
2149 })
2150}
2151pub fn fused_residual_ln_tee_kernel(device: &wgpu::Device) -> &'static Kernel {
2152 FUSED_RESIDUAL_LN_TEE.get_or_init(|| {
2153 build_kernel(
2154 device,
2155 "rlx-wgpu fused_residual_ln_tee",
2156 FUSED_RESIDUAL_LN_TEE_WGSL,
2157 "fused_residual_ln_tee",
2158 )
2159 })
2160}
2161pub fn fused_residual_rms_norm_kernel(device: &wgpu::Device) -> &'static Kernel {
2162 FUSED_RESIDUAL_RMS_NORM.get_or_init(|| {
2163 build_kernel(
2164 device,
2165 "rlx-wgpu fused_residual_rms_norm",
2166 FUSED_RESIDUAL_RMS_NORM_WGSL,
2167 "fused_residual_rms_norm",
2168 )
2169 })
2170}
2171pub fn matmul_qkv_kernel(device: &wgpu::Device) -> &'static Kernel {
2172 MATMUL_QKV
2173 .get_or_init(|| build_kernel(device, "rlx-wgpu matmul_qkv", MATMUL_QKV_WGSL, "matmul_qkv"))
2174}
2175pub fn matmul_qkv_coop_f32_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
2176 if !device
2177 .features()
2178 .contains(wgpu::Features::EXPERIMENTAL_COOPERATIVE_MATRIX)
2179 {
2180 return None;
2181 }
2182 Some(MATMUL_QKV_COOP_F32.get_or_init(|| {
2183 build_kernel(
2184 device,
2185 "rlx-wgpu matmul_qkv_coop_f32",
2186 MATMUL_QKV_COOP_F32_WGSL,
2187 "matmul_qkv_coop_f32",
2188 )
2189 }))
2190}
2191pub fn matmul_qkv_coop_f16_vk_f32acc_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
2192 if !coop_f16_vk_f32acc_device_ready(device) {
2193 return None;
2194 }
2195 MATMUL_QKV_COOP_F16_VK_F32ACC
2196 .get_or_init(|| {
2197 try_build_kernel_coop_f16_vk(
2198 device,
2199 "rlx-wgpu matmul_qkv_coop_f16_vk_f32acc",
2200 MATMUL_QKV_COOP_F16_VK_F32ACC_WGSL,
2201 "matmul_qkv_coop_f16_vk_f32acc",
2202 )
2203 })
2204 .as_ref()
2205}
2206
2207pub fn matmul_qkv_coop_f16_vk_widen_f32acc_kernel(
2208 device: &wgpu::Device,
2209) -> Option<&'static Kernel> {
2210 if !coop_f16_vk_f32acc_device_ready(device) {
2211 return None;
2212 }
2213 MATMUL_QKV_COOP_F16_VK_WIDEN_F32ACC
2214 .get_or_init(|| {
2215 try_build_kernel_coop_f16_vk(
2216 device,
2217 "rlx-wgpu matmul_qkv_coop_f16_vk_widen_f32acc",
2218 MATMUL_QKV_COOP_F16_VK_WIDEN_F32ACC_WGSL,
2219 "matmul_qkv_coop_f16_vk_widen_f32acc",
2220 )
2221 })
2222 .as_ref()
2223}
2224
2225pub fn matmul_qkv_coop_f16_vk_active_kernel(
2226 device: &wgpu::Device,
2227 n: u32,
2228) -> Option<&'static Kernel> {
2229 pick_coop_f16_vk_matmul(
2230 device,
2231 n,
2232 matmul_qkv_coop_f16_vk_kernel,
2233 matmul_qkv_coop_f16_vk_f32acc_kernel,
2234 matmul_qkv_coop_f16_vk_widen_kernel,
2235 matmul_qkv_coop_f16_vk_widen_f32acc_kernel,
2236 )
2237}
2238
2239pub fn matmul_qkv_coop_f16_vk_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
2240 if !coop_f16_vk_device_ready(device) {
2241 return None;
2242 }
2243 Some(MATMUL_QKV_COOP_F16_VK.get_or_init(|| {
2244 build_kernel_coop_f16_vk(
2245 device,
2246 "rlx-wgpu matmul_qkv_coop_f16_vk",
2247 MATMUL_QKV_COOP_F16_VK_WGSL,
2248 "matmul_qkv_coop_f16_vk",
2249 )
2250 }))
2251}
2252pub fn matmul_qkv_coop_f16_vk_widen_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
2253 if !coop_f16_vk_device_ready(device) {
2254 return None;
2255 }
2256 Some(MATMUL_QKV_COOP_F16_VK_WIDEN.get_or_init(|| {
2257 build_kernel_coop_f16_vk(
2258 device,
2259 "rlx-wgpu matmul_qkv_coop_f16_vk_widen",
2260 MATMUL_QKV_COOP_F16_VK_WIDEN_WGSL,
2261 "matmul_qkv_coop_f16_vk_widen",
2262 )
2263 }))
2264}