1use std::sync::OnceLock;
24
25use bytemuck::{Pod, Zeroable};
26
// WGSL shader sources, embedded into the binary at compile time with
// `include_str!`. Each path is resolved relative to this source file, so the
// `.wgsl` files must sit next to it. The constants are public, so other
// modules can reference the raw shader text as well.
pub const MATMUL_WGSL: &str = include_str!("matmul.wgsl");
pub const MATMUL_WIDE_WGSL: &str = include_str!("matmul_wide.wgsl");
pub const MATMUL_F16W_WGSL: &str = include_str!("matmul_f16w.wgsl");
pub const MATMUL_F16_COMPUTE_WGSL: &str = include_str!("matmul_f16_compute.wgsl");
pub const MATMUL_COOP16_WGSL: &str = include_str!("matmul_coop16.wgsl");
pub const MATMUL_COOP_F32_WGSL: &str = include_str!("matmul_coop_f32.wgsl");
pub const CAST_F32_TO_F16_WGSL: &str = include_str!("cast_f32_to_f16.wgsl");
pub const BINARY_WGSL: &str = include_str!("binary.wgsl");
pub const UNARY_WGSL: &str = include_str!("unary.wgsl");
pub const COMPARE_WGSL: &str = include_str!("compare.wgsl");
pub const WHERE_WGSL: &str = include_str!("where.wgsl");
pub const REDUCE_WGSL: &str = include_str!("reduce.wgsl");
pub const SOFTMAX_WGSL: &str = include_str!("softmax.wgsl");
pub const LAYERNORM_WGSL: &str = include_str!("layernorm.wgsl");
pub const RMS_NORM_BWD_WGSL: &str = include_str!("rms_norm_backward.wgsl");
pub const CUMSUM_BWD_WGSL: &str = include_str!("cumsum_backward.wgsl");
pub const ROPE_BWD_WGSL: &str = include_str!("rope_backward.wgsl");
pub const GATHER_BWD_WGSL: &str = include_str!("gather_backward.wgsl");
pub const CUMSUM_WGSL: &str = include_str!("cumsum.wgsl");
pub const FFT_GPU_WGSL: &str = include_str!("fft_gpu.wgsl");
pub const COPY_WGSL: &str = include_str!("copy.wgsl");
pub const ELEMENTWISE_REGION_WGSL: &str = include_str!("elementwise_region.wgsl");
pub const TRANSPOSE_WGSL: &str = include_str!("transpose.wgsl");
pub const NARROW_WGSL: &str = include_str!("narrow.wgsl");
pub const CONCAT_WGSL: &str = include_str!("concat.wgsl");
pub const GATHER_WGSL: &str = include_str!("gather.wgsl");
pub const GATHER_AXIS_WGSL: &str = include_str!("gather_axis.wgsl");
pub const ATTENTION_WGSL: &str = include_str!("attention.wgsl");
pub const ATTENTION_BWD_WGSL: &str = include_str!("attention_bwd.wgsl");
pub const ROPE_WGSL: &str = include_str!("rope.wgsl");
pub const EXPAND_WGSL: &str = include_str!("expand.wgsl");
pub const ARGMAX_WGSL: &str = include_str!("argmax.wgsl");
pub const POOL2D_WGSL: &str = include_str!("pool2d.wgsl");
pub const CONV2D_WGSL: &str = include_str!("conv2d.wgsl");
pub const POOL1D_WGSL: &str = include_str!("pool1d.wgsl");
pub const POOL3D_WGSL: &str = include_str!("pool3d.wgsl");
pub const CONV1D_WGSL: &str = include_str!("conv1d.wgsl");
pub const CONV3D_WGSL: &str = include_str!("conv3d.wgsl");
pub const SCATTER_ADD_WGSL: &str = include_str!("scatter_add.wgsl");
pub const TOPK_WGSL: &str = include_str!("topk.wgsl");
pub const UMAP_KNN_WGSL: &str = include_str!("umap_knn.wgsl");
pub const GROUPED_MATMUL_WGSL: &str = include_str!("grouped_matmul.wgsl");
pub const SAMPLE_WGSL: &str = include_str!("sample.wgsl");
pub const SELECTIVE_SCAN_WGSL: &str = include_str!("selective_scan.wgsl");
pub const DEQUANT_MATMUL_WGSL: &str = include_str!("dequant_matmul.wgsl");
pub const FUSED_RESIDUAL_LN_WGSL: &str = include_str!("fused_residual_ln.wgsl");
pub const FUSED_RESIDUAL_LN_TEE_WGSL: &str = include_str!("fused_residual_ln_tee.wgsl");
pub const FUSED_RESIDUAL_RMS_NORM_WGSL: &str = include_str!("fused_residual_rms_norm.wgsl");
pub const MATMUL_QKV_WGSL: &str = include_str!("matmul_qkv.wgsl");
pub const MATMUL_QKV_COOP_F32_WGSL: &str = include_str!("matmul_qkv_coop_f32.wgsl");
77
/// Plain-data record of matmul parameters: 16 `u32` fields (64 bytes),
/// `#[repr(C)]`, with `Pod`/`Zeroable` derived so it can be viewed as bytes.
// NOTE(review): presumably the uniform block read by the matmul shaders; if
// so, field order and the trailing `_pad*` words must match the WGSL struct.
// The meaning of `act_id` / `has_bias` values is not visible here — confirm
// against the shader before changing anything.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct MatmulParams {
    pub m: u32,
    pub k: u32,
    pub n: u32,
    pub a_off: u32,
    pub b_off: u32,
    pub c_off: u32,
    pub batch: u32,
    pub a_batch_stride: u32,
    pub b_batch_stride: u32,
    pub c_batch_stride: u32,
    pub has_bias: u32,
    pub bias_off: u32,
    pub act_id: u32,
    pub _pad0: u32,
    pub _pad1: u32,
    pub _pad2: u32,
}
98
/// Plain-data record for the binary kernel: 8 `u32` fields (32 bytes),
/// `#[repr(C)]`, `Pod`/`Zeroable`.
// NOTE(review): presumably mirrors the uniform struct in `binary.wgsl`; the
// `_p*` fields look like padding and `op` like an operation selector —
// confirm the encoding against the shader.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct BinaryParams {
    pub n: u32,
    pub a_off: u32,
    pub b_off: u32,
    pub c_off: u32,
    pub op: u32,
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
}
112
/// Plain-data record for the unary kernel: 8 `u32` fields (32 bytes),
/// `#[repr(C)]`, `Pod`/`Zeroable`.
// NOTE(review): presumably mirrors the uniform struct in `unary.wgsl`; `_p0`
// through `_p3` look like padding — confirm against the shader.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct UnaryParams {
    pub n: u32,
    pub in_off: u32,
    pub out_off: u32,
    pub op: u32,
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
    pub _p3: u32,
}
126
/// Plain-data record for the `where` kernel: 8 `u32` fields (32 bytes),
/// `#[repr(C)]`, `Pod`/`Zeroable`.
// NOTE(review): presumably mirrors the uniform struct in `where.wgsl`, with
// `cond_off` / `x_off` / `y_off` / `out_off` locating the four operands —
// confirm units and order against the shader.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct WhereParams {
    pub n: u32,
    pub cond_off: u32,
    pub x_off: u32,
    pub y_off: u32,
    pub out_off: u32,
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
}
140
/// Plain-data record for the reduce kernel: 8 `u32` fields (32 bytes),
/// `#[repr(C)]`, `Pod`/`Zeroable`.
// NOTE(review): presumably mirrors the uniform struct in `reduce.wgsl`;
// `outer` / `inner` look like the outer count and the reduced extent, `op`
// like a reduction selector — confirm against the shader.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct ReduceParams {
    pub outer: u32,
    pub inner: u32,
    pub in_off: u32,
    pub out_off: u32,
    pub op: u32,
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
}
154
/// Plain-data record for the softmax kernel: 8 `u32` fields (32 bytes),
/// `#[repr(C)]`, `Pod`/`Zeroable`.
// NOTE(review): presumably mirrors the uniform struct in `softmax.wgsl`;
// `_p0` through `_p3` look like padding — confirm against the shader.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct SoftmaxParams {
    pub outer: u32,
    pub inner: u32,
    pub in_off: u32,
    pub out_off: u32,
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
    pub _p3: u32,
}
168
/// Plain-data record for the layer-norm kernel: 8 `u32` fields (32 bytes) with
/// no padding fields, `#[repr(C)]`, `Pod`/`Zeroable`.
// NOTE(review): presumably mirrors the uniform struct in `layernorm.wgsl`.
// `eps_bits` looks like an `f32` epsilon carried as its bit pattern, and `op`
// like a variant selector — confirm both against the shader.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct LayerNormParams {
    pub outer: u32,
    pub inner: u32,
    pub in_off: u32,
    pub out_off: u32,
    pub gamma_off: u32,
    pub beta_off: u32,
    pub eps_bits: u32,
    pub op: u32,
}
182
/// Plain-data record for the RMS-norm backward kernels: 9 `u32` fields
/// (36 bytes) with no padding fields, `#[repr(C)]`, `Pod`/`Zeroable`.
// NOTE(review): presumably mirrors the uniform struct in
// `rms_norm_backward.wgsl`. `wrt` looks like a "gradient with respect to"
// selector and `eps_bits` like an `f32` bit pattern — confirm. The size is
// not a multiple of 16 bytes; verify that this matches the shader-side struct.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct RmsNormBwdParams {
    pub outer: u32,
    pub inner: u32,
    pub x_off: u32,
    pub gamma_off: u32,
    pub beta_off: u32,
    pub dy_off: u32,
    pub out_off: u32,
    pub eps_bits: u32,
    pub wrt: u32,
}
197
/// Plain-data record for the cumsum backward kernel: 8 `u32` fields
/// (32 bytes), `#[repr(C)]`, `Pod`/`Zeroable`.
// NOTE(review): presumably mirrors the uniform struct in
// `cumsum_backward.wgsl`; `exclusive` looks like a 0/1 flag — confirm.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct CumsumBwdParams {
    pub outer: u32,
    pub inner: u32,
    pub dy_off: u32,
    pub dx_off: u32,
    pub exclusive: u32,
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
}
210
/// Plain-data record for the RoPE backward kernel: 10 `u32` fields (40 bytes)
/// with no padding fields, `#[repr(C)]`, `Pod`/`Zeroable`.
// NOTE(review): presumably mirrors the uniform struct in `rope_backward.wgsl`;
// the exact meaning of `n_rot` and `cos_len` is not visible here — confirm
// against the shader.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct RopeBwdParams {
    pub batch: u32,
    pub seq: u32,
    pub hidden: u32,
    pub head_dim: u32,
    pub n_rot: u32,
    pub dy_off: u32,
    pub cos_off: u32,
    pub sin_off: u32,
    pub dx_off: u32,
    pub cos_len: u32,
}
225
/// Plain-data record for the gather backward kernels: 8 `u32` fields
/// (32 bytes), `#[repr(C)]`, `Pod`/`Zeroable`.
// NOTE(review): presumably mirrors the uniform struct in
// `gather_backward.wgsl`, shared by its zero and accumulate entry points —
// confirm against the shader.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct GatherBwdParams {
    pub outer: u32,
    pub axis_dim: u32,
    pub num_idx: u32,
    pub trailing: u32,
    pub dy_off: u32,
    pub idx_off: u32,
    pub dst_off: u32,
    pub _p0: u32,
}
238
/// Plain-data record for the cumsum kernel: 8 `u32` fields (32 bytes),
/// `#[repr(C)]`, `Pod`/`Zeroable`.
// NOTE(review): presumably mirrors the uniform struct in `cumsum.wgsl`;
// `exclusive` looks like a 0/1 flag — confirm against the shader.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct CumsumParams {
    pub outer: u32,
    pub inner: u32,
    pub in_off: u32,
    pub out_off: u32,
    pub exclusive: u32,
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
}
252
/// Plain-data record of FFT parameters: seven `u32` fields plus one `f32`
/// (`norm_scale`), 32 bytes in total, `#[repr(C)]`, `Pod`/`Zeroable`.
///
/// The trailing fields are named `_p1` and `_p2`; there is no `_p0`.
// NOTE(review): which shader consumes this layout is not visible in this part
// of the file (`FftGpuParams` exists separately) — confirm the consumer before
// changing field order.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct FftParams {
    pub src_off: u32,
    pub dst_off: u32,
    pub n: u32,
    pub log2n: u32,
    pub inverse: u32,
    pub norm_scale: f32,
    pub _p1: u32,
    pub _p2: u32,
}
266
/// Plain-data record for the GPU FFT kernels: nine `u32` fields plus one `f32`
/// (`norm_scale`), 40 bytes in total, `#[repr(C)]`, `Pod`/`Zeroable`.
// NOTE(review): presumably mirrors the uniform struct in `fft_gpu.wgsl`, which
// several entry points share. `q_or_hs` appears to be reused with a different
// meaning per entry point — confirm against the shader.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct FftGpuParams {
    pub off: u32,
    pub dst_off: u32,
    pub n: u32,
    pub log2n: u32,
    pub inverse: u32,
    pub norm_scale: f32,
    pub outer: u32,
    pub tile: u32,
    pub inner_stages: u32,
    pub q_or_hs: u32,
}
282
/// Plain-data record for the elementwise-region kernel, `#[repr(C)]`,
/// `Pod`/`Zeroable`.
///
/// Layout in `u32` words: 4 scalars, `input_offs` (16), `chain` (128),
/// 4 scalars (`scalar_input_mask` plus three `_pad*`), `input_modulus` (16).
/// That is 168 words, 672 bytes.
// NOTE(review): presumably mirrors the uniform struct in
// `elementwise_region.wgsl`. How `chain` encodes its steps and how
// `scalar_input_mask` / `input_modulus` are interpreted is not visible here —
// confirm against the shader.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct ElementwiseRegionParams {
    pub len: u32,
    pub num_inputs: u32,
    pub num_steps: u32,
    pub dst_off: u32,
    pub input_offs: [u32; 16],
    pub chain: [u32; 128],
    pub scalar_input_mask: u32,
    pub _pad0: u32,
    pub _pad1: u32,
    pub _pad2: u32,
    pub input_modulus: [u32; 16],
}
308
/// Plain-data record for the copy kernel: 8 `u32` fields (32 bytes),
/// `#[repr(C)]`, `Pod`/`Zeroable`. Only the first three fields carry names
/// other than `_p*`.
// NOTE(review): presumably mirrors the uniform struct in `copy.wgsl` —
// confirm against the shader.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct CopyParams {
    pub n: u32,
    pub in_off: u32,
    pub out_off: u32,
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
    pub _p3: u32,
    pub _p4: u32,
}
322
/// Plain-data record for the transpose kernel: 8 `u32` fields (32 bytes),
/// `#[repr(C)]`, `Pod`/`Zeroable`. The field list matches `ExpandParams`.
// NOTE(review): presumably mirrors the uniform struct in `transpose.wgsl`.
// The roles of `bucket_outermost` and `out_dim_0` are not visible here, and
// the padding fields start at `_p2` — confirm against the shader.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct TransposeParams {
    pub rank: u32,
    pub out_total: u32,
    pub in_off: u32,
    pub out_off: u32,
    pub bucket_outermost: u32,
    pub out_dim_0: u32,
    pub _p2: u32,
    pub _p3: u32,
}
340
/// Plain-data record named for both the narrow and concat kernels: 8 `u32`
/// fields (32 bytes) with no padding fields, `#[repr(C)]`, `Pod`/`Zeroable`.
// NOTE(review): presumably shared by `narrow.wgsl` and `concat.wgsl`; how
// `start`, `axis_in_size` and `axis_out_size` are read by each shader is not
// visible here — confirm.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct NarrowConcatParams {
    pub total: u32,
    pub outer: u32,
    pub inner: u32,
    pub axis_in_size: u32,
    pub axis_out_size: u32,
    pub start: u32,
    pub in_off: u32,
    pub out_off: u32,
}
354
/// Plain-data record for the gather kernel: 8 `u32` fields (32 bytes),
/// `#[repr(C)]`, `Pod`/`Zeroable`.
// NOTE(review): presumably mirrors the uniform struct in `gather.wgsl`; the
// `vocab` / `dim` names suggest an embedding-table lookup — confirm against
// the shader.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct GatherParams {
    pub n_out: u32,
    pub n_idx: u32,
    pub dim: u32,
    pub vocab: u32,
    pub in_off: u32,
    pub idx_off: u32,
    pub out_off: u32,
    pub _p0: u32,
}
368
/// Plain-data record for the gather-along-axis kernel: 8 `u32` fields
/// (32 bytes) with no padding fields, `#[repr(C)]`, `Pod`/`Zeroable`.
// NOTE(review): presumably mirrors the uniform struct in `gather_axis.wgsl` —
// confirm against the shader.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct GatherAxisParams {
    pub total: u32,
    pub outer: u32,
    pub axis_dim: u32,
    pub num_idx: u32,
    pub trailing: u32,
    pub table_off: u32,
    pub idx_off: u32,
    pub out_off: u32,
}
382
/// Plain-data record for the attention kernel: 36 `u32` fields (144 bytes),
/// `#[repr(C)]`, `Pod`/`Zeroable`.
///
/// The first 17 fields are followed by three `_pad_mask_*` words, so the
/// per-tensor stride groups begin at word 20. Each of the `q`/`k`/`v`/`o`
/// groups is three strides plus one pad word.
// NOTE(review): presumably mirrors the uniform struct in `attention.wgsl`.
// `scale_bits` looks like an `f32` bit pattern; the `mask_kind` encoding and
// the `window` semantics are not visible here — confirm against the shader.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct AttentionParams {
    pub batch: u32,
    pub heads: u32,
    pub seq_q: u32,
    pub seq_k: u32,
    pub head_dim: u32,
    pub q_off: u32,
    pub k_off: u32,
    pub v_off: u32,
    pub out_off: u32,
    pub mask_off: u32,
    pub mask_kind: u32,
    pub scale_bits: u32,
    pub window: u32,
    pub seq_q_stride: u32,
    pub seq_k_stride: u32,
    pub mask_batch_stride: u32,
    pub mask_head_stride: u32,
    pub _pad_mask_0: u32,
    pub _pad_mask_1: u32,
    pub _pad_mask_2: u32,

    pub q_batch_stride: u32,
    pub q_head_stride: u32,
    pub q_seq_stride: u32,
    pub _pad_q: u32,

    pub k_batch_stride: u32,
    pub k_head_stride: u32,
    pub k_seq_stride: u32,
    pub _pad_k: u32,

    pub v_batch_stride: u32,
    pub v_head_stride: u32,
    pub v_seq_stride: u32,
    pub _pad_v: u32,

    pub o_batch_stride: u32,
    pub o_head_stride: u32,
    pub o_seq_stride: u32,
    pub _pad_o: u32,
}
455
/// Plain-data record for the attention backward kernel: 38 `u32` fields
/// (152 bytes), `#[repr(C)]`, `Pod`/`Zeroable`.
///
/// It has the same fields as `AttentionParams` plus `dy_off` and `wrt`.
// NOTE(review): presumably mirrors the uniform struct in `attention_bwd.wgsl`.
// The two extra fields mean the three `_pad_mask_*` words end at word 22,
// whereas in `AttentionParams` they end at word 20. The stride groups
// therefore do not start on a 4-word boundary here. Confirm that the shader
// struct has exactly this layout before relying on it.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct AttentionBwdParams {
    pub batch: u32,
    pub heads: u32,
    pub seq_q: u32,
    pub seq_k: u32,
    pub head_dim: u32,
    pub q_off: u32,
    pub k_off: u32,
    pub v_off: u32,
    pub dy_off: u32,
    pub out_off: u32,
    pub mask_off: u32,
    pub mask_kind: u32,
    pub scale_bits: u32,
    pub window: u32,
    pub wrt: u32,
    pub seq_q_stride: u32,
    pub seq_k_stride: u32,
    pub mask_batch_stride: u32,
    pub mask_head_stride: u32,
    pub _pad_mask_0: u32,
    pub _pad_mask_1: u32,
    pub _pad_mask_2: u32,
    pub q_batch_stride: u32,
    pub q_head_stride: u32,
    pub q_seq_stride: u32,
    pub _pad_q: u32,
    pub k_batch_stride: u32,
    pub k_head_stride: u32,
    pub k_seq_stride: u32,
    pub _pad_k: u32,
    pub v_batch_stride: u32,
    pub v_head_stride: u32,
    pub v_seq_stride: u32,
    pub _pad_v: u32,
    pub o_batch_stride: u32,
    pub o_head_stride: u32,
    pub o_seq_stride: u32,
    pub _pad_o: u32,
}
499
/// Plain-data record for the RoPE kernel: 12 `u32` fields (48 bytes),
/// `#[repr(C)]`, `Pod`/`Zeroable`. The only padding-style field is `_p2`.
// NOTE(review): presumably mirrors the uniform struct in `rope.wgsl`. How
// `half`, `last_dim` and `seq_stride` relate to `head_dim` is not visible
// here — confirm against the shader.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct RopeParams {
    pub n_total: u32,
    pub seq: u32,
    pub head_dim: u32,
    pub half: u32,
    pub in_off: u32,
    pub cos_off: u32,
    pub sin_off: u32,
    pub out_off: u32,
    pub last_dim: u32,
    pub batch: u32,
    pub seq_stride: u32,
    pub _p2: u32,
}
522
/// Plain-data record for the expand kernel: 8 `u32` fields (32 bytes),
/// `#[repr(C)]`, `Pod`/`Zeroable`. The field list matches `TransposeParams`.
// NOTE(review): presumably mirrors the uniform struct in `expand.wgsl`. The
// roles of `bucket_outermost` and `out_dim_0` are not visible here — confirm
// against the shader.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct ExpandParams {
    pub rank: u32,
    pub out_total: u32,
    pub in_off: u32,
    pub out_off: u32,
    pub bucket_outermost: u32,
    pub out_dim_0: u32,
    pub _p2: u32,
    pub _p3: u32,
}
540
/// Plain-data record for the argmax kernel: 8 `u32` fields (32 bytes),
/// `#[repr(C)]`, `Pod`/`Zeroable`.
// NOTE(review): presumably mirrors the uniform struct in `argmax.wgsl`; `_p0`
// through `_p3` look like padding — confirm against the shader.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct ArgmaxParams {
    pub outer: u32,
    pub inner: u32,
    pub in_off: u32,
    pub out_off: u32,
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
    pub _p3: u32,
}
554
/// Plain-data record for the 2-D pooling kernel: 18 `u32` fields (72 bytes),
/// `#[repr(C)]`, `Pod`/`Zeroable`.
// NOTE(review): presumably mirrors the uniform struct in `pool2d.wgsl`. The
// `k*` / `s*` / `p*` prefixes read as kernel size / stride / padding per axis,
// and `op` as the pooling variant — confirm against the shader.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct Pool2dParams {
    pub n: u32,
    pub c: u32,
    pub h: u32,
    pub w: u32,
    pub h_out: u32,
    pub w_out: u32,
    pub kh: u32,
    pub kw: u32,
    pub sh: u32,
    pub sw: u32,
    pub ph: u32,
    pub pw: u32,
    pub op: u32,
    pub in_off: u32,
    pub out_off: u32,
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
}
578
/// Plain-data record for the 2-D convolution kernel: 19 `u32` fields
/// (76 bytes) with no padding fields, `#[repr(C)]`, `Pod`/`Zeroable`.
// NOTE(review): presumably mirrors the uniform struct in `conv2d.wgsl`. The
// `k*` / `s*` / `p*` / `d*` prefixes read as kernel size / stride / padding /
// dilation per axis — confirm against the shader.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct Conv2dParams {
    pub n: u32,
    pub c_in: u32,
    pub c_out: u32,
    pub h: u32,
    pub w: u32,
    pub h_out: u32,
    pub w_out: u32,
    pub kh: u32,
    pub kw: u32,
    pub sh: u32,
    pub sw: u32,
    pub ph: u32,
    pub pw: u32,
    pub dh: u32,
    pub dw: u32,
    pub groups: u32,
    pub in_off: u32,
    pub w_off: u32,
    pub out_off: u32,
}
603
/// Plain-data record for the 1-D pooling kernel: 16 `u32` fields (64 bytes),
/// of which the last six are `_p*` fields. `#[repr(C)]`, `Pod`/`Zeroable`.
// NOTE(review): presumably mirrors the uniform struct in `pool1d.wgsl`. `kl` /
// `sl` / `pl` read as kernel size / stride / padding along `l` — confirm
// against the shader.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct Pool1dParams {
    pub n: u32,
    pub c: u32,
    pub l: u32,
    pub l_out: u32,
    pub kl: u32,
    pub sl: u32,
    pub pl: u32,
    pub op: u32,
    pub in_off: u32,
    pub out_off: u32,
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
    pub _p3: u32,
    pub _p4: u32,
    pub _p5: u32,
}
625
/// Plain-data record for the 3-D pooling kernel: 22 `u32` fields (88 bytes),
/// `#[repr(C)]`, `Pod`/`Zeroable`.
// NOTE(review): presumably mirrors the uniform struct in `pool3d.wgsl`. The
// `k*` / `s*` / `p*` prefixes read as kernel size / stride / padding per axis
// — confirm against the shader.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct Pool3dParams {
    pub n: u32,
    pub c: u32,
    pub d: u32,
    pub h: u32,
    pub w: u32,
    pub d_out: u32,
    pub h_out: u32,
    pub w_out: u32,
    pub kd: u32,
    pub kh: u32,
    pub kw: u32,
    pub sd: u32,
    pub sh: u32,
    pub sw: u32,
    pub pd: u32,
    pub ph: u32,
    pub pw: u32,
    pub op: u32,
    pub in_off: u32,
    pub out_off: u32,
    pub _p0: u32,
    pub _p1: u32,
}
653
/// Plain-data record for the 1-D convolution kernel: 16 `u32` fields
/// (64 bytes), `#[repr(C)]`, `Pod`/`Zeroable`.
// NOTE(review): presumably mirrors the uniform struct in `conv1d.wgsl`. `kl` /
// `sl` / `pl` / `dl` read as kernel size / stride / padding / dilation along
// `l` — confirm against the shader.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct Conv1dParams {
    pub n: u32,
    pub c_in: u32,
    pub c_out: u32,
    pub l: u32,
    pub l_out: u32,
    pub kl: u32,
    pub sl: u32,
    pub pl: u32,
    pub dl: u32,
    pub groups: u32,
    pub in_off: u32,
    pub w_off: u32,
    pub out_off: u32,
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
}
675
/// Plain-data record for the dequantising matmul kernel: 12 `u32` fields
/// (48 bytes), `#[repr(C)]`, `Pod`/`Zeroable`.
// NOTE(review): presumably mirrors the uniform struct in
// `dequant_matmul.wgsl`. The `scheme_id` encoding and the unit of
// `block_size` are not visible here — confirm against the shader.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct DequantMatmulParams {
    pub m: u32,
    pub k: u32,
    pub n: u32,
    pub block_size: u32,
    pub scheme_id: u32,
    pub x_off: u32,
    pub w_off: u32,
    pub scale_off: u32,
    pub zp_off: u32,
    pub out_off: u32,
    pub _p0: u32,
    pub _p1: u32,
}
693
/// Plain-data record for the fused residual + layer-norm "tee" kernel:
/// 12 `u32` fields (48 bytes), `#[repr(C)]`, `Pod`/`Zeroable`.
///
/// `FusedResidualLnParams` has a single `out_off`; this struct carries two
/// output offsets instead, `sum_off` and `ln_out_off`.
// NOTE(review): presumably mirrors the uniform struct in
// `fused_residual_ln_tee.wgsl`. `eps_bits` looks like an `f32` bit pattern
// and `has_bias` like a 0/1 flag — confirm against the shader.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct FusedResidualLnTeeParams {
    pub outer: u32,
    pub inner: u32,
    pub in_off: u32,
    pub residual_off: u32,
    pub bias_off: u32,
    pub gamma_off: u32,
    pub beta_off: u32,
    pub sum_off: u32,
    pub ln_out_off: u32,
    pub eps_bits: u32,
    pub has_bias: u32,
    pub _p0: u32,
}
711
/// Plain-data record for the fused QKV matmul kernels: 16 `u32` fields
/// (64 bytes), `#[repr(C)]`, `Pod`/`Zeroable`.
///
/// `k` is the second matrix dimension and `k_off` is one of the three output
/// offsets (`q_off`, `k_off`, `v_off`). They are distinct fields.
// NOTE(review): presumably mirrors the uniform struct in `matmul_qkv.wgsl` and
// `matmul_qkv_coop_f32.wgsl`. How `head_width` splits the output is not
// visible here — confirm against the shaders.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct MatmulQkvParams {
    pub m: u32,
    pub k: u32,
    pub n: u32,
    pub a_off: u32,
    pub b_off: u32,
    pub q_off: u32,
    pub k_off: u32,
    pub v_off: u32,
    pub head_width: u32,
    pub has_bias: u32,
    pub bias_off: u32,
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
    pub _p3: u32,
    pub _p4: u32,
}
734
/// The fused residual + RMS-norm parameters reuse the `FusedResidualLnParams`
/// layout unchanged; this alias only provides a kernel-specific name.
pub type FusedResidualRmsNormParams = FusedResidualLnParams;
737
/// Plain-data record for the fused residual + layer-norm kernel: 12 `u32`
/// fields (48 bytes), `#[repr(C)]`, `Pod`/`Zeroable`.
///
/// The `FusedResidualRmsNormParams` alias also exposes this layout.
// NOTE(review): presumably mirrors the uniform struct in
// `fused_residual_ln.wgsl` (and the RMS-norm variant). `eps_bits` looks like
// an `f32` bit pattern and `has_bias` like a 0/1 flag — confirm against the
// shaders.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct FusedResidualLnParams {
    pub outer: u32,
    pub inner: u32,
    pub in_off: u32,
    pub residual_off: u32,
    pub bias_off: u32,
    pub gamma_off: u32,
    pub beta_off: u32,
    pub out_off: u32,
    pub eps_bits: u32,
    pub has_bias: u32,
    pub _p0: u32,
    pub _p1: u32,
}
755
/// Plain-data record for the selective-scan kernel: 16 `u32` fields
/// (64 bytes), `#[repr(C)]`, `Pod`/`Zeroable`. The padding-style fields run
/// from `_p1` to `_p5`; there is no `_p0`.
// NOTE(review): presumably mirrors the uniform struct in
// `selective_scan.wgsl`. The unit and role of `seq_stride` are not visible
// here — confirm against the shader.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct SelectiveScanParams {
    pub batch: u32,
    pub seq: u32,
    pub hidden: u32,
    pub state_size: u32,
    pub x_off: u32,
    pub delta_off: u32,
    pub a_off: u32,
    pub b_off: u32,
    pub c_off: u32,
    pub out_off: u32,
    pub seq_stride: u32,
    pub _p1: u32,
    pub _p2: u32,
    pub _p3: u32,
    pub _p4: u32,
    pub _p5: u32,
}
780
/// Plain-data record for the sampling kernel: 12 `u32` fields (48 bytes),
/// `#[repr(C)]`, `Pod`/`Zeroable`.
// NOTE(review): presumably mirrors the uniform struct in `sample.wgsl`.
// `top_p_bits` / `temp_bits` look like `f32` bit patterns, and `seed_lo` /
// `seed_hi` like the two halves of a 64-bit seed — confirm against the shader.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct SampleParams {
    pub outer: u32,
    pub inner: u32,
    pub in_off: u32,
    pub out_off: u32,
    pub top_k: u32,
    pub top_p_bits: u32,
    pub temp_bits: u32,
    pub seed_lo: u32,
    pub seed_hi: u32,
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
}
798
/// Plain-data record for the grouped matmul kernel: 8 `u32` fields (32 bytes)
/// with no padding fields, `#[repr(C)]`, `Pod`/`Zeroable`.
// NOTE(review): presumably mirrors the uniform struct in
// `grouped_matmul.wgsl`. `idx_off` presumably locates the per-row expert
// indices — confirm against the shader.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct GroupedMatmulParams {
    pub m: u32,
    pub k: u32,
    pub n: u32,
    pub num_experts: u32,
    pub in_off: u32,
    pub w_off: u32,
    pub idx_off: u32,
    pub out_off: u32,
}
812
/// Plain-data record for the top-k kernel: 8 `u32` fields (32 bytes),
/// `#[repr(C)]`, `Pod`/`Zeroable`.
// NOTE(review): presumably mirrors the uniform struct in `topk.wgsl`. What is
// written at `out_off` (values, indices or both) is not visible here —
// confirm against the shader.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct TopKParams {
    pub outer: u32,
    pub inner: u32,
    pub k: u32,
    pub in_off: u32,
    pub out_off: u32,
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
}
826
/// Plain-data record for the UMAP k-NN kernel: 7 `u32` fields (28 bytes),
/// `#[repr(C)]`, `Pod`/`Zeroable`.
// NOTE(review): presumably mirrors the uniform struct in `umap_knn.wgsl`.
// `pw_off` presumably locates a pairwise-distance matrix — confirm. With three
// `_p*` fields the struct has 7 words, not 8; verify that this matches the
// shader-side struct size.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct UmapKnnParams {
    pub n: u32,
    pub k: u32,
    pub pw_off: u32,
    pub out_off: u32,
    pub _p0: u32,
    pub _p1: u32,
    pub _p2: u32,
}
839
/// Plain-data record for the scatter-add kernel: 8 `u32` fields (32 bytes)
/// with no padding fields, `#[repr(C)]`, `Pod`/`Zeroable`.
// NOTE(review): presumably mirrors the uniform struct in `scatter_add.wgsl`.
// `op` is the first field and presumably selects a pass or mode of the
// shader; its encoding is not visible here — confirm.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct ScatterAddParams {
    pub op: u32,
    pub out_off: u32,
    pub upd_off: u32,
    pub idx_off: u32,
    pub out_total: u32,
    pub num_updates: u32,
    pub trailing: u32,
    pub out_dim: u32,
}
853
/// Plain-data record for the 3-D convolution kernel: 26 `u32` fields
/// (104 bytes), `#[repr(C)]`, `Pod`/`Zeroable`.
// NOTE(review): presumably mirrors the uniform struct in `conv3d.wgsl`. The
// `k*` / `s*` / `p*` / `d*` prefixes read as kernel size / stride / padding /
// dilation per axis — confirm against the shader.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct Conv3dParams {
    pub n: u32,
    pub c_in: u32,
    pub c_out: u32,
    pub d: u32,
    pub h: u32,
    pub w: u32,
    pub d_out: u32,
    pub h_out: u32,
    pub w_out: u32,
    pub kd: u32,
    pub kh: u32,
    pub kw: u32,
    pub sd: u32,
    pub sh: u32,
    pub sw: u32,
    pub pd: u32,
    pub ph: u32,
    pub pw: u32,
    pub dd: u32,
    pub dh: u32,
    pub dw: u32,
    pub groups: u32,
    pub in_off: u32,
    pub w_off: u32,
    pub out_off: u32,
    pub _p0: u32,
}
885
/// A compiled compute pipeline together with the bind group layout it was
/// created against.
///
/// The builder functions in this file produce both from the same layout, so
/// bind groups made from `bgl` are compatible with `pipeline`.
pub struct Kernel {
    /// The compute pipeline for one shader entry point.
    pub pipeline: wgpu::ComputePipeline,
    /// Layout of the single bind group the pipeline expects.
    pub bgl: wgpu::BindGroupLayout,
}
891
892impl Kernel {
893 pub fn bind_two(
894 &self,
895 device: &wgpu::Device,
896 arena: &wgpu::Buffer,
897 uniform: &wgpu::Buffer,
898 ) -> wgpu::BindGroup {
899 device.create_bind_group(&wgpu::BindGroupDescriptor {
900 label: Some("rlx-wgpu fft gpu bg"),
901 layout: &self.bgl,
902 entries: &[
903 wgpu::BindGroupEntry {
904 binding: 0,
905 resource: arena.as_entire_binding(),
906 },
907 wgpu::BindGroupEntry {
908 binding: 1,
909 resource: uniform.as_entire_binding(),
910 },
911 ],
912 })
913 }
914}
915
916#[allow(dead_code)]
922fn build_kernel_4(
925 device: &wgpu::Device,
926 label: &'static str,
927 wgsl: &str,
928 entry_point: &'static str,
929) -> Kernel {
930 let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
931 label: Some(label),
932 source: wgpu::ShaderSource::Wgsl(wgsl.into()),
933 });
934 let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
935 label: Some(label),
936 entries: &[
937 wgpu::BindGroupLayoutEntry {
938 binding: 0,
939 visibility: wgpu::ShaderStages::COMPUTE,
940 ty: wgpu::BindingType::Buffer {
941 ty: wgpu::BufferBindingType::Storage { read_only: false },
942 has_dynamic_offset: false,
943 min_binding_size: None,
944 },
945 count: None,
946 },
947 wgpu::BindGroupLayoutEntry {
948 binding: 1,
949 visibility: wgpu::ShaderStages::COMPUTE,
950 ty: wgpu::BindingType::Buffer {
951 ty: wgpu::BufferBindingType::Uniform,
952 has_dynamic_offset: false,
953 min_binding_size: None,
954 },
955 count: None,
956 },
957 wgpu::BindGroupLayoutEntry {
958 binding: 2,
959 visibility: wgpu::ShaderStages::COMPUTE,
960 ty: wgpu::BindingType::Buffer {
961 ty: wgpu::BufferBindingType::Storage { read_only: true },
962 has_dynamic_offset: false,
963 min_binding_size: None,
964 },
965 count: None,
966 },
967 wgpu::BindGroupLayoutEntry {
968 binding: 3,
969 visibility: wgpu::ShaderStages::COMPUTE,
970 ty: wgpu::BindingType::Buffer {
971 ty: wgpu::BufferBindingType::Storage { read_only: true },
972 has_dynamic_offset: false,
973 min_binding_size: None,
974 },
975 count: None,
976 },
977 ],
978 });
979 let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
980 label: Some(label),
981 bind_group_layouts: &[Some(&bgl)],
982 immediate_size: 0,
983 });
984 let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
985 label: Some(label),
986 layout: Some(&layout),
987 module: &module,
988 entry_point: Some(entry_point),
989 compilation_options: Default::default(),
990 cache: None,
991 });
992 Kernel { pipeline, bgl }
993}
994
995fn build_kernel_3(
996 device: &wgpu::Device,
997 label: &'static str,
998 wgsl: &str,
999 entry_point: &'static str,
1000) -> Kernel {
1001 let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
1002 label: Some(label),
1003 source: wgpu::ShaderSource::Wgsl(wgsl.into()),
1004 });
1005 let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
1006 label: Some(label),
1007 entries: &[
1008 wgpu::BindGroupLayoutEntry {
1009 binding: 0,
1010 visibility: wgpu::ShaderStages::COMPUTE,
1011 ty: wgpu::BindingType::Buffer {
1012 ty: wgpu::BufferBindingType::Storage { read_only: false },
1013 has_dynamic_offset: false,
1014 min_binding_size: None,
1015 },
1016 count: None,
1017 },
1018 wgpu::BindGroupLayoutEntry {
1019 binding: 1,
1020 visibility: wgpu::ShaderStages::COMPUTE,
1021 ty: wgpu::BindingType::Buffer {
1022 ty: wgpu::BufferBindingType::Uniform,
1023 has_dynamic_offset: false,
1024 min_binding_size: None,
1025 },
1026 count: None,
1027 },
1028 wgpu::BindGroupLayoutEntry {
1029 binding: 2,
1030 visibility: wgpu::ShaderStages::COMPUTE,
1031 ty: wgpu::BindingType::Buffer {
1032 ty: wgpu::BufferBindingType::Storage { read_only: true },
1033 has_dynamic_offset: false,
1034 min_binding_size: None,
1035 },
1036 count: None,
1037 },
1038 ],
1039 });
1040 let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
1041 label: Some(label),
1042 bind_group_layouts: &[Some(&bgl)],
1043 immediate_size: 0,
1044 });
1045 let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
1046 label: Some(label),
1047 layout: Some(&layout),
1048 module: &module,
1049 entry_point: Some(entry_point),
1050 compilation_options: Default::default(),
1051 cache: None,
1052 });
1053 Kernel { pipeline, bgl }
1054}
1055
1056fn build_kernel(
1057 device: &wgpu::Device,
1058 label: &'static str,
1059 wgsl: &str,
1060 entry_point: &'static str,
1061) -> Kernel {
1062 let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
1063 label: Some(label),
1064 source: wgpu::ShaderSource::Wgsl(wgsl.into()),
1065 });
1066 let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
1067 label: Some(label),
1068 entries: &[
1069 wgpu::BindGroupLayoutEntry {
1070 binding: 0,
1071 visibility: wgpu::ShaderStages::COMPUTE,
1072 ty: wgpu::BindingType::Buffer {
1073 ty: wgpu::BufferBindingType::Storage { read_only: false },
1074 has_dynamic_offset: false,
1075 min_binding_size: None,
1076 },
1077 count: None,
1078 },
1079 wgpu::BindGroupLayoutEntry {
1080 binding: 1,
1081 visibility: wgpu::ShaderStages::COMPUTE,
1082 ty: wgpu::BindingType::Buffer {
1083 ty: wgpu::BufferBindingType::Uniform,
1084 has_dynamic_offset: false,
1085 min_binding_size: None,
1086 },
1087 count: None,
1088 },
1089 ],
1090 });
1091 let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
1092 label: Some(label),
1093 bind_group_layouts: &[Some(&bgl)],
1094 immediate_size: 0,
1095 });
1096 let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
1097 label: Some(label),
1098 layout: Some(&layout),
1099 module: &module,
1100 entry_point: Some(entry_point),
1101 compilation_options: Default::default(),
1102 cache: None,
1103 });
1104 Kernel { pipeline, bgl }
1105}
1106
// Process-wide kernel caches, one `OnceLock` per shader entry point.
//
// The accessor functions visible in this file fill each cell on first use via
// `get_or_init` and hand out `&'static Kernel` afterwards. The kernel stored is
// therefore the one built from the `wgpu::Device` of the first call; later
// calls get that same kernel back whatever device they pass.
// NOTE(review): this looks like it assumes a single device per process —
// confirm before introducing a second device.
static MATMUL: OnceLock<Kernel> = OnceLock::new();
static MATMUL_WIDE: OnceLock<Kernel> = OnceLock::new();
static MATMUL_F16W: OnceLock<Kernel> = OnceLock::new();
static MATMUL_F16_COMPUTE: OnceLock<Kernel> = OnceLock::new();
static MATMUL_COOP16: OnceLock<Kernel> = OnceLock::new();
static MATMUL_COOP_F32: OnceLock<Kernel> = OnceLock::new();
static CAST_F32_TO_F16: OnceLock<Kernel> = OnceLock::new();
static BINARY: OnceLock<Kernel> = OnceLock::new();
static UNARY: OnceLock<Kernel> = OnceLock::new();
static COMPARE: OnceLock<Kernel> = OnceLock::new();
static WHEREK: OnceLock<Kernel> = OnceLock::new();
static REDUCE: OnceLock<Kernel> = OnceLock::new();
static SOFTMAX: OnceLock<Kernel> = OnceLock::new();
static LAYERNORM: OnceLock<Kernel> = OnceLock::new();
static RMS_NORM_BWD: OnceLock<Kernel> = OnceLock::new();
static RMS_NORM_BWD_PARAM: OnceLock<Kernel> = OnceLock::new();
static CUMSUM_BWD: OnceLock<Kernel> = OnceLock::new();
static ROPE_BWD: OnceLock<Kernel> = OnceLock::new();
static GATHER_BWD_ZERO: OnceLock<Kernel> = OnceLock::new();
static GATHER_BWD_ACC: OnceLock<Kernel> = OnceLock::new();
static CUMSUM: OnceLock<Kernel> = OnceLock::new();
static FFT_GPU_RADIX2: OnceLock<Kernel> = OnceLock::new();
static FFT_GPU_BITREV: OnceLock<Kernel> = OnceLock::new();
static FFT_GPU_INNER: OnceLock<Kernel> = OnceLock::new();
static FFT_GPU_OUTER_R4: OnceLock<Kernel> = OnceLock::new();
static FFT_GPU_OUTER_R2: OnceLock<Kernel> = OnceLock::new();
static COPY: OnceLock<Kernel> = OnceLock::new();
static ELEMENTWISE_REGION: OnceLock<Kernel> = OnceLock::new();
static TRANSPOSE: OnceLock<Kernel> = OnceLock::new();
static NARROW: OnceLock<Kernel> = OnceLock::new();
static CONCAT: OnceLock<Kernel> = OnceLock::new();
static GATHER: OnceLock<Kernel> = OnceLock::new();
static GATHER_AXIS: OnceLock<Kernel> = OnceLock::new();
static ATTENTION: OnceLock<Kernel> = OnceLock::new();
static ATTENTION_BWD: OnceLock<Kernel> = OnceLock::new();
static ROPE: OnceLock<Kernel> = OnceLock::new();
static EXPAND: OnceLock<Kernel> = OnceLock::new();
static ARGMAX: OnceLock<Kernel> = OnceLock::new();
static POOL2D: OnceLock<Kernel> = OnceLock::new();
static CONV2D: OnceLock<Kernel> = OnceLock::new();
static POOL1D: OnceLock<Kernel> = OnceLock::new();
static POOL3D: OnceLock<Kernel> = OnceLock::new();
static CONV1D: OnceLock<Kernel> = OnceLock::new();
static CONV3D: OnceLock<Kernel> = OnceLock::new();
static SCATTER_ADD: OnceLock<Kernel> = OnceLock::new();
static TOPK: OnceLock<Kernel> = OnceLock::new();
static UMAP_KNN: OnceLock<Kernel> = OnceLock::new();
static GROUPED_MATMUL: OnceLock<Kernel> = OnceLock::new();
static SAMPLE: OnceLock<Kernel> = OnceLock::new();
static SELECTIVE_SCAN: OnceLock<Kernel> = OnceLock::new();
static DEQUANT_MATMUL: OnceLock<Kernel> = OnceLock::new();
static FUSED_RESIDUAL_LN: OnceLock<Kernel> = OnceLock::new();
static FUSED_RESIDUAL_LN_TEE: OnceLock<Kernel> = OnceLock::new();
static FUSED_RESIDUAL_RMS_NORM: OnceLock<Kernel> = OnceLock::new();
static MATMUL_QKV: OnceLock<Kernel> = OnceLock::new();
static MATMUL_QKV_COOP_F32: OnceLock<Kernel> = OnceLock::new();
1163
1164pub fn matmul_kernel(device: &wgpu::Device) -> &'static Kernel {
1165 MATMUL.get_or_init(|| build_kernel(device, "rlx-wgpu matmul", MATMUL_WGSL, "matmul"))
1166}
1167pub fn matmul_wide_kernel(device: &wgpu::Device) -> &'static Kernel {
1168 MATMUL_WIDE.get_or_init(|| {
1169 build_kernel(
1170 device,
1171 "rlx-wgpu matmul_wide",
1172 MATMUL_WIDE_WGSL,
1173 "matmul_wide",
1174 )
1175 })
1176}
1177pub fn matmul_f16w_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
1182 if !device.features().contains(wgpu::Features::SHADER_F16) {
1183 return None;
1184 }
1185 Some(MATMUL_F16W.get_or_init(|| {
1186 build_kernel_3(
1187 device,
1188 "rlx-wgpu matmul_f16w",
1189 MATMUL_F16W_WGSL,
1190 "matmul_f16w",
1191 )
1192 }))
1193}
1194pub fn matmul_f16_compute_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
1198 if !device.features().contains(wgpu::Features::SHADER_F16) {
1199 return None;
1200 }
1201 Some(MATMUL_F16_COMPUTE.get_or_init(|| {
1202 build_kernel_3(
1203 device,
1204 "rlx-wgpu matmul_f16_compute",
1205 MATMUL_F16_COMPUTE_WGSL,
1206 "matmul_f16_compute",
1207 )
1208 }))
1209}
1210pub fn matmul_coop16_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
1216 let feats = device.features();
1217 if !feats.contains(wgpu::Features::SHADER_F16)
1218 || !feats.contains(wgpu::Features::EXPERIMENTAL_COOPERATIVE_MATRIX)
1219 {
1220 return None;
1221 }
1222 Some(MATMUL_COOP16.get_or_init(|| {
1223 build_kernel_3(
1224 device,
1225 "rlx-wgpu matmul_coop16",
1226 MATMUL_COOP16_WGSL,
1227 "matmul_coop16",
1228 )
1229 }))
1230}
1231pub fn matmul_coop_f32_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
1237 let feats = device.features();
1238 if !feats.contains(wgpu::Features::EXPERIMENTAL_COOPERATIVE_MATRIX) {
1239 return None;
1240 }
1241 Some(MATMUL_COOP_F32.get_or_init(|| {
1242 build_kernel(
1243 device,
1244 "rlx-wgpu matmul_coop_f32",
1245 MATMUL_COOP_F32_WGSL,
1246 "matmul_coop_f32",
1247 )
1248 }))
1249}
1250pub fn cast_f32_to_f16_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
1255 if !device.features().contains(wgpu::Features::SHADER_F16) {
1256 return None;
1257 }
1258 Some(CAST_F32_TO_F16.get_or_init(|| {
1259 build_kernel_3(
1260 device,
1261 "rlx-wgpu cast_f32_to_f16",
1262 CAST_F32_TO_F16_WGSL,
1263 "cast_f32_to_f16",
1264 )
1265 }))
1266}
1267pub fn binary_kernel(device: &wgpu::Device) -> &'static Kernel {
1268 BINARY.get_or_init(|| build_kernel(device, "rlx-wgpu binary", BINARY_WGSL, "binary"))
1269}
1270pub fn unary_kernel(device: &wgpu::Device) -> &'static Kernel {
1271 UNARY.get_or_init(|| build_kernel(device, "rlx-wgpu unary", UNARY_WGSL, "unary"))
1272}
1273pub fn compare_kernel(device: &wgpu::Device) -> &'static Kernel {
1274 COMPARE.get_or_init(|| build_kernel(device, "rlx-wgpu compare", COMPARE_WGSL, "compare"))
1275}
1276pub fn where_kernel(device: &wgpu::Device) -> &'static Kernel {
1277 WHEREK.get_or_init(|| build_kernel(device, "rlx-wgpu where", WHERE_WGSL, "where_select"))
1278}
1279pub fn reduce_kernel(device: &wgpu::Device) -> &'static Kernel {
1280 REDUCE.get_or_init(|| build_kernel(device, "rlx-wgpu reduce", REDUCE_WGSL, "reduce"))
1281}
1282pub fn softmax_kernel(device: &wgpu::Device) -> &'static Kernel {
1283 SOFTMAX.get_or_init(|| build_kernel(device, "rlx-wgpu softmax", SOFTMAX_WGSL, "softmax"))
1284}
1285pub fn layernorm_kernel(device: &wgpu::Device) -> &'static Kernel {
1286 LAYERNORM.get_or_init(|| build_kernel(device, "rlx-wgpu layernorm", LAYERNORM_WGSL, "norm"))
1287}
1288pub fn rms_norm_backward_kernel(device: &wgpu::Device) -> &'static Kernel {
1289 RMS_NORM_BWD.get_or_init(|| {
1290 build_kernel(
1291 device,
1292 "rlx-wgpu rms_norm_bwd",
1293 RMS_NORM_BWD_WGSL,
1294 "rms_norm_bwd",
1295 )
1296 })
1297}
1298pub fn rms_norm_backward_param_kernel(device: &wgpu::Device) -> &'static Kernel {
1299 RMS_NORM_BWD_PARAM.get_or_init(|| {
1300 build_kernel(
1301 device,
1302 "rlx-wgpu rms_norm_bwd_param",
1303 RMS_NORM_BWD_WGSL,
1304 "rms_norm_bwd_param",
1305 )
1306 })
1307}
1308pub fn cumsum_backward_kernel(device: &wgpu::Device) -> &'static Kernel {
1309 CUMSUM_BWD
1310 .get_or_init(|| build_kernel(device, "rlx-wgpu cumsum_bwd", CUMSUM_BWD_WGSL, "cumsum_bwd"))
1311}
1312pub fn rope_backward_kernel(device: &wgpu::Device) -> &'static Kernel {
1313 ROPE_BWD.get_or_init(|| build_kernel(device, "rlx-wgpu rope_bwd", ROPE_BWD_WGSL, "rope_bwd"))
1314}
1315pub fn gather_backward_zero_kernel(device: &wgpu::Device) -> &'static Kernel {
1316 GATHER_BWD_ZERO.get_or_init(|| {
1317 build_kernel(
1318 device,
1319 "rlx-wgpu gather_bwd_zero",
1320 GATHER_BWD_WGSL,
1321 "gather_bwd_zero",
1322 )
1323 })
1324}
1325pub fn gather_backward_acc_kernel(device: &wgpu::Device) -> &'static Kernel {
1326 GATHER_BWD_ACC.get_or_init(|| {
1327 build_kernel(
1328 device,
1329 "rlx-wgpu gather_bwd_acc",
1330 GATHER_BWD_WGSL,
1331 "gather_bwd_acc",
1332 )
1333 })
1334}
1335pub fn cumsum_kernel(device: &wgpu::Device) -> &'static Kernel {
1336 CUMSUM.get_or_init(|| build_kernel(device, "rlx-wgpu cumsum", CUMSUM_WGSL, "cumsum"))
1337}
1338pub fn fft_gpu_radix2_full_kernel(device: &wgpu::Device) -> &'static Kernel {
1339 FFT_GPU_RADIX2.get_or_init(|| {
1340 build_kernel(
1341 device,
1342 "rlx-wgpu fft_radix2_full",
1343 FFT_GPU_WGSL,
1344 "fft_radix2_full",
1345 )
1346 })
1347}
1348pub fn fft_gpu_bit_reverse_kernel(device: &wgpu::Device) -> &'static Kernel {
1349 FFT_GPU_BITREV.get_or_init(|| {
1350 build_kernel(
1351 device,
1352 "rlx-wgpu fft_bit_reverse",
1353 FFT_GPU_WGSL,
1354 "fft_bit_reverse",
1355 )
1356 })
1357}
1358pub fn fft_gpu_inner_kernel(device: &wgpu::Device) -> &'static Kernel {
1359 FFT_GPU_INNER
1360 .get_or_init(|| build_kernel(device, "rlx-wgpu fft_inner", FFT_GPU_WGSL, "fft_inner"))
1361}
1362pub fn fft_gpu_outer_r4_kernel(device: &wgpu::Device) -> &'static Kernel {
1363 FFT_GPU_OUTER_R4.get_or_init(|| {
1364 build_kernel(
1365 device,
1366 "rlx-wgpu fft_outer_r4",
1367 FFT_GPU_WGSL,
1368 "fft_outer_r4",
1369 )
1370 })
1371}
1372pub fn fft_gpu_outer_r2_kernel(device: &wgpu::Device) -> &'static Kernel {
1373 FFT_GPU_OUTER_R2.get_or_init(|| {
1374 build_kernel(
1375 device,
1376 "rlx-wgpu fft_outer_r2",
1377 FFT_GPU_WGSL,
1378 "fft_outer_r2",
1379 )
1380 })
1381}
1382pub fn copy_kernel(device: &wgpu::Device) -> &'static Kernel {
1383 COPY.get_or_init(|| build_kernel(device, "rlx-wgpu copy", COPY_WGSL, "copy"))
1384}
1385pub fn elementwise_region_kernel(device: &wgpu::Device) -> &'static Kernel {
1386 ELEMENTWISE_REGION.get_or_init(|| {
1391 build_kernel_region(
1392 device,
1393 "rlx-wgpu elementwise_region",
1394 ELEMENTWISE_REGION_WGSL,
1395 "elementwise_region",
1396 )
1397 })
1398}
1399
1400fn build_kernel_region(
1401 device: &wgpu::Device,
1402 label: &'static str,
1403 wgsl: &str,
1404 entry_point: &'static str,
1405) -> Kernel {
1406 let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
1407 label: Some(label),
1408 source: wgpu::ShaderSource::Wgsl(wgsl.into()),
1409 });
1410 let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
1411 label: Some(label),
1412 entries: &[
1413 wgpu::BindGroupLayoutEntry {
1414 binding: 0,
1415 visibility: wgpu::ShaderStages::COMPUTE,
1416 ty: wgpu::BindingType::Buffer {
1417 ty: wgpu::BufferBindingType::Storage { read_only: false },
1418 has_dynamic_offset: false,
1419 min_binding_size: None,
1420 },
1421 count: None,
1422 },
1423 wgpu::BindGroupLayoutEntry {
1424 binding: 1,
1425 visibility: wgpu::ShaderStages::COMPUTE,
1426 ty: wgpu::BindingType::Buffer {
1427 ty: wgpu::BufferBindingType::Storage { read_only: true },
1429 has_dynamic_offset: false,
1430 min_binding_size: None,
1431 },
1432 count: None,
1433 },
1434 ],
1435 });
1436 let pl = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
1437 label: Some(label),
1438 bind_group_layouts: &[Some(&bgl)],
1439 immediate_size: 0,
1440 });
1441 let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
1442 label: Some(label),
1443 layout: Some(&pl),
1444 module: &module,
1445 entry_point: Some(entry_point),
1446 compilation_options: Default::default(),
1447 cache: None,
1448 });
1449 Kernel { pipeline, bgl }
1450}
1451pub fn transpose_kernel(device: &wgpu::Device) -> &'static Kernel {
1452 TRANSPOSE
1453 .get_or_init(|| build_kernel_3(device, "rlx-wgpu transpose", TRANSPOSE_WGSL, "transpose"))
1454}
1455pub fn narrow_kernel(device: &wgpu::Device) -> &'static Kernel {
1456 NARROW.get_or_init(|| build_kernel(device, "rlx-wgpu narrow", NARROW_WGSL, "narrow"))
1457}
1458pub fn concat_kernel(device: &wgpu::Device) -> &'static Kernel {
1459 CONCAT.get_or_init(|| build_kernel(device, "rlx-wgpu concat", CONCAT_WGSL, "concat"))
1460}
1461pub fn gather_kernel(device: &wgpu::Device) -> &'static Kernel {
1462 GATHER.get_or_init(|| build_kernel(device, "rlx-wgpu gather", GATHER_WGSL, "gather"))
1463}
1464pub fn gather_axis_kernel(device: &wgpu::Device) -> &'static Kernel {
1465 GATHER_AXIS.get_or_init(|| {
1466 build_kernel(
1467 device,
1468 "rlx-wgpu gather_axis",
1469 GATHER_AXIS_WGSL,
1470 "gather_axis",
1471 )
1472 })
1473}
1474pub fn attention_kernel(device: &wgpu::Device) -> &'static Kernel {
1475 ATTENTION
1476 .get_or_init(|| build_kernel(device, "rlx-wgpu attention", ATTENTION_WGSL, "attention"))
1477}
1478pub fn attention_bwd_kernel(device: &wgpu::Device) -> &'static Kernel {
1479 ATTENTION_BWD.get_or_init(|| {
1480 build_kernel(
1481 device,
1482 "rlx-wgpu attention_bwd",
1483 ATTENTION_BWD_WGSL,
1484 "attention_bwd",
1485 )
1486 })
1487}
1488pub fn rope_kernel(device: &wgpu::Device) -> &'static Kernel {
1489 ROPE.get_or_init(|| build_kernel(device, "rlx-wgpu rope", ROPE_WGSL, "rope"))
1490}
1491pub fn expand_kernel(device: &wgpu::Device) -> &'static Kernel {
1492 EXPAND.get_or_init(|| build_kernel_3(device, "rlx-wgpu expand", EXPAND_WGSL, "expand"))
1493}
1494pub fn argmax_kernel(device: &wgpu::Device) -> &'static Kernel {
1495 ARGMAX.get_or_init(|| build_kernel(device, "rlx-wgpu argmax", ARGMAX_WGSL, "argmax"))
1496}
1497pub fn pool2d_kernel(device: &wgpu::Device) -> &'static Kernel {
1498 POOL2D.get_or_init(|| build_kernel(device, "rlx-wgpu pool2d", POOL2D_WGSL, "pool2d"))
1499}
1500pub fn conv2d_kernel(device: &wgpu::Device) -> &'static Kernel {
1501 CONV2D.get_or_init(|| build_kernel(device, "rlx-wgpu conv2d", CONV2D_WGSL, "conv2d"))
1502}
1503pub fn pool1d_kernel(device: &wgpu::Device) -> &'static Kernel {
1504 POOL1D.get_or_init(|| build_kernel(device, "rlx-wgpu pool1d", POOL1D_WGSL, "pool1d"))
1505}
1506pub fn pool3d_kernel(device: &wgpu::Device) -> &'static Kernel {
1507 POOL3D.get_or_init(|| build_kernel(device, "rlx-wgpu pool3d", POOL3D_WGSL, "pool3d"))
1508}
1509pub fn conv1d_kernel(device: &wgpu::Device) -> &'static Kernel {
1510 CONV1D.get_or_init(|| build_kernel(device, "rlx-wgpu conv1d", CONV1D_WGSL, "conv1d"))
1511}
1512pub fn conv3d_kernel(device: &wgpu::Device) -> &'static Kernel {
1513 CONV3D.get_or_init(|| build_kernel(device, "rlx-wgpu conv3d", CONV3D_WGSL, "conv3d"))
1514}
1515pub fn scatter_add_kernel(device: &wgpu::Device) -> &'static Kernel {
1516 SCATTER_ADD.get_or_init(|| {
1517 build_kernel(
1518 device,
1519 "rlx-wgpu scatter_add",
1520 SCATTER_ADD_WGSL,
1521 "scatter_add",
1522 )
1523 })
1524}
1525pub fn topk_kernel(device: &wgpu::Device) -> &'static Kernel {
1526 TOPK.get_or_init(|| build_kernel(device, "rlx-wgpu topk", TOPK_WGSL, "topk"))
1527}
1528pub fn umap_knn_kernel(device: &wgpu::Device) -> &'static Kernel {
1529 UMAP_KNN.get_or_init(|| build_kernel(device, "rlx-wgpu umap_knn", UMAP_KNN_WGSL, "umap_knn"))
1530}
1531pub fn grouped_matmul_kernel(device: &wgpu::Device) -> &'static Kernel {
1532 GROUPED_MATMUL.get_or_init(|| {
1533 build_kernel(
1534 device,
1535 "rlx-wgpu grouped_matmul",
1536 GROUPED_MATMUL_WGSL,
1537 "grouped_matmul",
1538 )
1539 })
1540}
1541pub fn sample_kernel(device: &wgpu::Device) -> &'static Kernel {
1542 SAMPLE.get_or_init(|| build_kernel(device, "rlx-wgpu sample", SAMPLE_WGSL, "sample"))
1543}
1544pub fn selective_scan_kernel(device: &wgpu::Device) -> &'static Kernel {
1545 SELECTIVE_SCAN.get_or_init(|| {
1546 build_kernel(
1547 device,
1548 "rlx-wgpu selective_scan",
1549 SELECTIVE_SCAN_WGSL,
1550 "selective_scan",
1551 )
1552 })
1553}
1554pub fn dequant_matmul_kernel(device: &wgpu::Device) -> &'static Kernel {
1555 DEQUANT_MATMUL.get_or_init(|| {
1556 build_kernel(
1557 device,
1558 "rlx-wgpu dequant_matmul",
1559 DEQUANT_MATMUL_WGSL,
1560 "dequant_matmul",
1561 )
1562 })
1563}
1564pub fn fused_residual_ln_kernel(device: &wgpu::Device) -> &'static Kernel {
1565 FUSED_RESIDUAL_LN.get_or_init(|| {
1566 build_kernel(
1567 device,
1568 "rlx-wgpu fused_residual_ln",
1569 FUSED_RESIDUAL_LN_WGSL,
1570 "fused_residual_ln",
1571 )
1572 })
1573}
1574pub fn fused_residual_ln_tee_kernel(device: &wgpu::Device) -> &'static Kernel {
1575 FUSED_RESIDUAL_LN_TEE.get_or_init(|| {
1576 build_kernel(
1577 device,
1578 "rlx-wgpu fused_residual_ln_tee",
1579 FUSED_RESIDUAL_LN_TEE_WGSL,
1580 "fused_residual_ln_tee",
1581 )
1582 })
1583}
1584pub fn fused_residual_rms_norm_kernel(device: &wgpu::Device) -> &'static Kernel {
1585 FUSED_RESIDUAL_RMS_NORM.get_or_init(|| {
1586 build_kernel(
1587 device,
1588 "rlx-wgpu fused_residual_rms_norm",
1589 FUSED_RESIDUAL_RMS_NORM_WGSL,
1590 "fused_residual_rms_norm",
1591 )
1592 })
1593}
1594pub fn matmul_qkv_kernel(device: &wgpu::Device) -> &'static Kernel {
1595 MATMUL_QKV
1596 .get_or_init(|| build_kernel(device, "rlx-wgpu matmul_qkv", MATMUL_QKV_WGSL, "matmul_qkv"))
1597}
1598pub fn matmul_qkv_coop_f32_kernel(device: &wgpu::Device) -> Option<&'static Kernel> {
1599 if !device
1600 .features()
1601 .contains(wgpu::Features::EXPERIMENTAL_COOPERATIVE_MATRIX)
1602 {
1603 return None;
1604 }
1605 Some(MATMUL_QKV_COOP_F32.get_or_init(|| {
1606 build_kernel(
1607 device,
1608 "rlx-wgpu matmul_qkv_coop_f32",
1609 MATMUL_QKV_COOP_F32_WGSL,
1610 "matmul_qkv_coop_f32",
1611 )
1612 }))
1613}