1use std::collections::{HashMap, HashSet};
23
24use rlx_ir::dynamic::{bind_graph, has_dynamic_dims, infer_bindings_from_f32_inputs, same_binding};
25use rlx_ir::op::{Activation, BinaryOp, CmpOp, MaskKind, ReduceOp};
26use rlx_ir::shape::DimBinding;
27use rlx_ir::{Graph, NodeId, Op};
28
29use crate::buffer::{Arena, plan_f32_uniform};
30use crate::device::wgpu_device;
31use crate::kernels::{
32 ArgmaxParams, AttentionBwdParams, AttentionParams, BinaryParams, Conv1dParams, Conv2dParams,
33 Conv3dParams, CopyParams, CumsumBwdParams, CumsumParams, DequantMatmulParams,
34 ElementwiseRegionParams, ExpandParams, FusedResidualLnParams, FusedResidualLnTeeParams,
35 FusedResidualRmsNormParams, GatherAxisParams, GatherBwdParams, GatherParams,
36 GroupedMatmulParams, Kernel, LayerNormParams, MatmulParams, MatmulQkvParams,
37 NarrowConcatParams, Pool1dParams, Pool2dParams, Pool3dParams, ReduceParams, RmsNormBwdParams,
38 RopeBwdParams, RopeParams, SampleParams, ScatterAddParams, SelectiveScanParams, SoftmaxParams,
39 TopKParams, TransposeParams, UmapKnnParams, UnaryParams, WhereParams, argmax_kernel,
40 attention_bwd_kernel, attention_kernel, binary_kernel, cast_f32_to_f16_kernel, compare_kernel,
41 concat_kernel, conv1d_kernel, conv2d_kernel, conv3d_kernel, copy_kernel,
42 cumsum_backward_kernel, cumsum_kernel, dequant_matmul_kernel, elementwise_region_kernel,
43 expand_kernel, fused_residual_ln_kernel, fused_residual_ln_tee_kernel,
44 fused_residual_rms_norm_kernel, gather_axis_kernel, gather_backward_acc_kernel,
45 gather_backward_zero_kernel, gather_kernel, grouped_matmul_kernel, layernorm_kernel,
46 matmul_coop_f32_kernel, matmul_coop16_kernel, matmul_f16_compute_kernel, matmul_f16w_kernel,
47 matmul_kernel, matmul_qkv_coop_f32_kernel, matmul_qkv_kernel, matmul_wide_kernel,
48 narrow_kernel, pool1d_kernel, pool2d_kernel, pool3d_kernel, reduce_kernel,
49 rms_norm_backward_kernel, rms_norm_backward_param_kernel, rope_backward_kernel, rope_kernel,
50 sample_kernel, scatter_add_kernel, selective_scan_kernel, softmax_kernel, topk_kernel,
51 transpose_kernel, umap_knn_kernel, unary_kernel, where_kernel,
52};
53use rlx_ir::op::{ChainOperand, ChainStep};
54
/// Compute path selected for a matmul.
///
/// The value is produced by `derive_matmul_compute` during lowering, stored on
/// `Step::Matmul` as `compute_precision`, and passed to
/// `build_matmul_bind_group` to pick the matching kernel.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MatmulCompute {
    /// Plain f32 path.
    F32,
    /// NOTE(review): presumably the f16-compute kernel
    /// (`matmul_f16_compute_kernel`) — confirm against the dispatcher.
    F16,
    /// NOTE(review): presumably the cooperative f16 kernel
    /// (`matmul_coop16_kernel`) — confirm against the dispatcher.
    Coop16,
    /// NOTE(review): presumably the cooperative f32 kernel
    /// (`matmul_coop_f32_kernel`) — confirm against the dispatcher.
    CoopF32,
}
76
77#[allow(dead_code)]
87#[derive(Debug, Clone, Copy)]
88struct CastF32ToF16Params {
89 pub src_off: u32, pub len: u32,
91 pub _p0: u32,
92 pub _p1: u32,
93}
94unsafe impl bytemuck::Pod for CastF32ToF16Params {}
95unsafe impl bytemuck::Zeroable for CastF32ToF16Params {}
96
/// One entry of the compiled schedule (`WgpuExecutable::schedule`).
///
/// Lowering pushes one `Step` per scheduled operation; each variant carries
/// the parameters recorded at compile time for that operation. Most variants
/// wrap the kernel's parameter struct from `crate::kernels`; the remaining
/// ones store their offsets and sizes inline. `step_name` gives a label for
/// every variant and `step_runs_on_host` lists the variants executed on the
/// host.
#[allow(dead_code)]
enum Step {
    /// f32 → f16 cast.
    CastF32ToF16 {
        params: CastF32ToF16Params,
    },
    /// Matrix multiply `C = A × B`.
    Matmul {
        /// Rows of A / C (leading dims are folded in for `[.., M, K] × [K, N]`).
        m: u32,
        /// Shared inner dimension.
        k: u32,
        /// Columns of B / C.
        n: u32,
        /// Offset of A in the arena, in f32 elements (byte offset / 4).
        a_off_f32: u32,
        /// Offset of B in the arena, in f32 elements (byte offset / 4).
        b_off_f32: u32,
        /// Offset of C in the arena, in f32 elements (byte offset / 4).
        c_off_f32: u32,
        /// Number of batched matmuls; 1 for the non-batched forms.
        batch: u32,
        /// Per-batch stride of A (`m * k` when batched, 0 otherwise).
        a_batch_stride: u32,
        /// Per-batch stride of B (`k * n` when batched, 0 otherwise).
        b_batch_stride: u32,
        /// Per-batch stride of C (`m * n` when batched, 0 otherwise).
        c_batch_stride: u32,
        /// Non-zero when a bias is applied; plain `Op::MatMul` lowers with 0.
        has_bias: u32,
        /// Offset of the bias in f32 elements (0 when `has_bias == 0`).
        bias_off_f32: u32,
        /// Fused activation id. Plain `Op::MatMul` lowers with `0xFFFF`
        /// (NOTE(review): appears to be the "no activation" sentinel — confirm
        /// against the matmul kernel).
        act_id: u32,
        /// Result of `traces_to_param` for the B operand.
        b_is_param: bool,
        /// Compute path chosen by `derive_matmul_compute`.
        compute_precision: MatmulCompute,
    },
    /// Elementwise binary op (`Op::Binary`).
    Binary {
        params: BinaryParams,
    },
    /// Elementwise comparison (`Op::Compare`); shares `BinaryParams`.
    Compare {
        params: BinaryParams,
    },
    /// Elementwise unary op (`Op::Activation`).
    Unary {
        params: UnaryParams,
    },
    /// Elementwise select (`Op::Where`).
    Where {
        params: WhereParams,
    },
    /// Reduction (`Op::Reduce`); the lowering only wires the last axis.
    Reduce {
        params: ReduceParams,
    },
    /// Softmax (`Op::Softmax`).
    Softmax {
        params: SoftmaxParams,
    },
    /// Layer normalization.
    LayerNorm {
        params: LayerNormParams,
    },
    /// Cumulative sum.
    Cumsum {
        params: CumsumParams,
    },
    /// FFT executed on the GPU.
    FftGpu {
        src_off: u32,
        dst_off: u32,
        outer: u32,
        n: u32,
        inverse: u32,
        norm_scale: f32,
    },
    /// FFT executed on the host (see `step_runs_on_host`).
    FftHost {
        src_byte_off: u32,
        dst_byte_off: u32,
        outer: u32,
        n_complex: u32,
        inverse: bool,
        norm_tag: u32,
        /// NOTE(review): presumably the encoding produced by `fft_dtype_tag`
        /// and decoded by `fft_dtype_from_tag` — confirm at the use site.
        dtype_tag: u32,
    },
    /// Buffer copy.
    Copy {
        params: CopyParams,
    },
    /// Fused chain of elementwise ops (`Op::ElementwiseRegion`).
    ElementwiseRegion {
        params: ElementwiseRegionParams,
    },
    /// Transpose.
    Transpose {
        params: TransposeParams,
        /// NOTE(review): presumably an index into
        /// `WgpuExecutable::meta_buffers` — confirm at the dispatch site.
        meta_idx: usize,
    },
    /// Narrow (slice along an axis); shares `NarrowConcatParams`.
    Narrow {
        params: NarrowConcatParams,
    },
    /// Concat; shares `NarrowConcatParams`.
    Concat {
        params: NarrowConcatParams,
    },
    /// Gather.
    Gather {
        params: GatherParams,
    },
    /// Gather along an explicit axis.
    GatherAxis {
        params: GatherAxisParams,
    },
    /// Attention forward pass, with an optional dedicated mask buffer.
    Attention {
        params: AttentionParams,
        mask_buf: Option<wgpu::Buffer>,
    },
    /// Attention backward pass, with an optional dedicated mask buffer.
    AttentionBackward {
        params: AttentionBwdParams,
        mask_buf: Option<wgpu::Buffer>,
    },
    /// Rotary position embedding.
    Rope {
        params: RopeParams,
    },
    /// Broadcast / expand.
    Expand {
        params: ExpandParams,
        /// NOTE(review): presumably an index into
        /// `WgpuExecutable::meta_buffers` — confirm at the dispatch site.
        meta_idx: usize,
    },
    /// Argmax.
    Argmax {
        params: ArgmaxParams,
    },
    /// 2-D pooling.
    Pool2d {
        params: Pool2dParams,
    },
    /// 2-D convolution.
    Conv2d {
        params: Conv2dParams,
    },
    /// 1-D pooling.
    Pool1d {
        params: Pool1dParams,
    },
    /// 3-D pooling.
    Pool3d {
        params: Pool3dParams,
    },
    /// 1-D convolution.
    Conv1d {
        params: Conv1dParams,
    },
    /// 3-D convolution.
    Conv3d {
        params: Conv3dParams,
    },
    /// Scatter-add.
    ScatterAdd {
        params: ScatterAddParams,
    },
    /// Top-k selection.
    TopK {
        params: TopKParams,
    },
    /// Grouped matmul.
    GroupedMatmul {
        params: GroupedMatmulParams,
    },
    /// Sampling.
    Sample {
        params: SampleParams,
    },
    /// Selective scan.
    SelectiveScan {
        params: SelectiveScanParams,
    },
    /// Dequantize + matmul on the GPU.
    DequantMatmul {
        params: DequantMatmulParams,
    },
    /// GGUF dequantize + matmul; runs on the host (see `step_runs_on_host`).
    DequantMatmulGguf {
        m: u32,
        k: u32,
        n: u32,
        scheme_id: u32,
        x_byte_off: u32,
        w_byte_off: u32,
        out_byte_off: u32,
    },
    /// GGUF dequantize + grouped matmul; runs on the host
    /// (see `step_runs_on_host`).
    DequantGroupedMatmulGguf {
        m: u32,
        k: u32,
        n: u32,
        num_experts: u32,
        scheme_id: u32,
        x_byte_off: u32,
        w_byte_off: u32,
        idx_byte_off: u32,
        out_byte_off: u32,
    },
    /// Gated delta-net; runs on the host (see `step_runs_on_host`).
    GatedDeltaNet {
        q_byte_off: u32,
        k_byte_off: u32,
        v_byte_off: u32,
        g_byte_off: u32,
        beta_byte_off: u32,
        state_byte_off: u32,
        dst_byte_off: u32,
        batch: u32,
        seq: u32,
        heads: u32,
        state_size: u32,
        use_carry: bool,
    },
    /// LLaDA2 group-limited gate; runs on the host (see `step_runs_on_host`).
    Llada2GroupLimitedGate {
        sig_byte_off: u32,
        route_byte_off: u32,
        out_byte_off: u32,
        n_elems: u32,
        /// Opaque attribute blob; its encoding is defined by the consumer.
        attrs: [u8; 20],
    },
    /// UMAP k-nearest-neighbours on the GPU.
    UmapKnn {
        params: UmapKnnParams,
    },
    /// UMAP k-nearest-neighbours on the host (see `step_runs_on_host`).
    UmapKnnHost {
        pairwise_byte_off: u32,
        out_byte_off: u32,
        n: u32,
        k: u32,
    },
    /// Gaussian-splat render (only with the `splat` feature); host step.
    #[cfg(feature = "splat")]
    GaussianSplatRender {
        positions_byte_off: u32,
        positions_len: u32,
        scales_byte_off: u32,
        scales_len: u32,
        rotations_byte_off: u32,
        rotations_len: u32,
        opacities_byte_off: u32,
        opacities_len: u32,
        colors_byte_off: u32,
        colors_len: u32,
        sh_coeffs_byte_off: u32,
        sh_coeffs_len: u32,
        meta_byte_off: u32,
        dst_byte_off: u32,
        dst_len: u32,
        width: u32,
        height: u32,
        tile_size: u32,
        radius_scale: f32,
        alpha_cutoff: f32,
        max_splat_steps: u32,
        transmittance_threshold: f32,
        max_list_entries: u32,
    },
    /// Gaussian-splat render backward pass (only with the `splat` feature);
    /// host step.
    #[cfg(feature = "splat")]
    GaussianSplatRenderBackward {
        positions_byte_off: u32,
        positions_len: u32,
        scales_byte_off: u32,
        scales_len: u32,
        rotations_byte_off: u32,
        rotations_len: u32,
        opacities_byte_off: u32,
        opacities_len: u32,
        colors_byte_off: u32,
        colors_len: u32,
        sh_coeffs_byte_off: u32,
        sh_coeffs_len: u32,
        meta_byte_off: u32,
        d_loss_byte_off: u32,
        d_loss_len: u32,
        packed_byte_off: u32,
        packed_len: u32,
        width: u32,
        height: u32,
        tile_size: u32,
        radius_scale: f32,
        alpha_cutoff: f32,
        max_splat_steps: u32,
        transmittance_threshold: f32,
        max_list_entries: u32,
        loss_grad_clip: f32,
        sh_band: u32,
        max_anisotropy: f32,
    },
    /// Gaussian-splat prepare stage (only with the `splat` feature); host step.
    #[cfg(feature = "splat")]
    GaussianSplatPrepare {
        positions_byte_off: u32,
        positions_len: u32,
        scales_byte_off: u32,
        scales_len: u32,
        rotations_byte_off: u32,
        rotations_len: u32,
        opacities_byte_off: u32,
        opacities_len: u32,
        colors_byte_off: u32,
        colors_len: u32,
        sh_coeffs_byte_off: u32,
        sh_coeffs_len: u32,
        meta_byte_off: u32,
        meta_len: u32,
        prep_byte_off: u32,
        prep_len: u32,
        width: u32,
        height: u32,
        tile_size: u32,
        radius_scale: f32,
        alpha_cutoff: f32,
        max_splat_steps: u32,
        transmittance_threshold: f32,
        max_list_entries: u32,
    },
    /// Gaussian-splat rasterize stage (only with the `splat` feature);
    /// host step.
    #[cfg(feature = "splat")]
    GaussianSplatRasterize {
        prep_byte_off: u32,
        prep_len: u32,
        meta_byte_off: u32,
        meta_len: u32,
        dst_byte_off: u32,
        dst_len: u32,
        count: u32,
        width: u32,
        height: u32,
        tile_size: u32,
        alpha_cutoff: f32,
        max_splat_steps: u32,
        transmittance_threshold: f32,
        max_list_entries: u32,
    },
    /// RMS-norm backward, input gradient; shares `RmsNormBwdParams`.
    RmsNormBackwardInput {
        params: RmsNormBwdParams,
    },
    /// RMS-norm backward, gamma gradient; shares `RmsNormBwdParams`.
    RmsNormBackwardGamma {
        params: RmsNormBwdParams,
    },
    /// RMS-norm backward, beta gradient; shares `RmsNormBwdParams`.
    RmsNormBackwardBeta {
        params: RmsNormBwdParams,
    },
    /// Rotary position embedding backward pass.
    RopeBackward {
        params: RopeBwdParams,
    },
    /// Cumulative-sum backward pass.
    CumsumBackward {
        params: CumsumBwdParams,
    },
    /// Gather backward pass.
    GatherBackward {
        params: GatherBwdParams,
    },
    /// Fused residual add + layer norm.
    FusedResidualLn {
        params: FusedResidualLnParams,
    },
    /// Fused QKV matmul.
    MatmulQkv {
        params: MatmulQkvParams,
        /// NOTE(review): presumably selects the cooperative kernel
        /// (`matmul_qkv_coop_f32_kernel`) — confirm at the dispatch site.
        coop: bool,
    },
    /// Fused residual add + layer norm, "tee" variant.
    FusedResidualLnTee {
        params: FusedResidualLnTeeParams,
    },
    /// Fused residual add + RMS norm.
    FusedResidualRmsNorm {
        params: FusedResidualRmsNormParams,
    },
}
461
/// A graph compiled for the wgpu backend.
///
/// Built by `compile` / `compile_with_bindings`. Graphs with dynamic
/// dimensions start out "deferred": only `unresolved` is populated and the
/// real compilation happens in `lazy_compile_for_inputs` once concrete inputs
/// are known.
pub struct WgpuExecutable {
    /// The graph this executable was built from (the unbound graph while
    /// compilation is still deferred).
    graph: Graph,
    /// Backing GPU buffer plus per-node offsets and lengths.
    arena: Arena,
    /// Steps recorded during lowering, in execution order.
    schedule: Vec<Step>,
    /// `Op::Input` name → node id.
    input_offsets: HashMap<String, NodeId>,
    /// `Op::Param` name → node id.
    param_offsets: HashMap<String, NodeId>,
    /// Per-step parameter buffers; lowering pushes one alongside each bind
    /// group.
    uniforms: Vec<wgpu::Buffer>,
    /// Per-step bind groups, pushed in the same order as `uniforms`.
    bind_groups: Vec<wgpu::BindGroup>,
    /// Auxiliary metadata buffers (NOTE(review): presumably referenced by
    /// `meta_idx` on `Step::Transpose` / `Step::Expand` — confirm).
    meta_buffers: Vec<wgpu::Buffer>,

    /// The original dynamic-shape graph; `Some` only for deferred executables.
    unresolved: Option<Graph>,
    /// Binding used by the most recent lazy compile, so an identical binding
    /// can skip recompilation.
    last_binding: Option<DimBinding>,
    /// f32 params held back and replayed through `set_param` after a lazy
    /// compile.
    pending_params: HashMap<String, Vec<f32>>,
    /// Raw-byte params held back and replayed through `set_param_bytes` after
    /// a lazy compile.
    pending_param_bytes: HashMap<String, Vec<u8>>,
    /// Active extent set through `set_active_extent`; `None` when unset.
    pub(crate) active_extent: Option<(usize, usize)>,
    /// Reset to `None` after every lazy compile. NOTE(review): presumably
    /// caches which `active_extent` the uniforms were last written for —
    /// confirm at the run site.
    uniforms_active_extent: Option<Option<(usize, usize)>>,
    /// GPU resources for `Step::FftGpu` steps.
    fft_gpu_steps: Vec<crate::fft_dispatch::FftGpuResources>,
}
506
507impl Step {
508 pub fn safe_for_active_extent(&self) -> bool {
514 match self {
515 Step::Binary { .. }
516 | Step::Compare { .. }
517 | Step::Unary { .. }
518 | Step::Where { .. }
519 | Step::Reduce { .. }
520 | Step::Softmax { .. }
521 | Step::LayerNorm { .. }
522 | Step::FusedResidualLn { .. }
523 | Step::FusedResidualLnTee { .. }
524 | Step::FusedResidualRmsNorm { .. }
525 | Step::Cumsum { .. }
526 | Step::Copy { .. }
527 | Step::ElementwiseRegion { .. }
528 | Step::Argmax { .. }
529 | Step::TopK { .. }
530 | Step::Sample { .. }
531 | Step::Gather { .. }
532 | Step::GatherAxis { .. }
533 | Step::GroupedMatmul { .. }
534 | Step::DequantMatmul { .. }
535 | Step::DequantMatmulGguf { .. }
536 | Step::DequantGroupedMatmulGguf { .. }
537 | Step::GatedDeltaNet { .. }
538 | Step::Llada2GroupLimitedGate { .. }
539 | Step::UmapKnn { .. }
540 | Step::UmapKnnHost { .. }
541 | Step::Conv1d { .. }
542 | Step::Conv2d { .. }
543 | Step::Conv3d { .. }
544 | Step::Pool1d { .. }
545 | Step::Pool2d { .. }
546 | Step::Pool3d { .. }
547 | Step::ScatterAdd { .. } => true,
548 Step::FftGpu { .. } | Step::FftHost { .. } => true,
553 Step::Matmul { .. } => true,
558 Step::MatmulQkv { .. } => true,
562 Step::CastF32ToF16 { .. } => true,
563 Step::Attention { .. } => true,
569 Step::AttentionBackward { .. } => true,
570 Step::SelectiveScan { .. } => true,
575 Step::Narrow { .. } => true,
584 Step::Concat { .. } => true,
585 Step::Rope { .. } => true,
591 Step::Transpose { params, .. } => params.bucket_outermost == 1,
597 Step::Expand { params, .. } => params.bucket_outermost == 1,
601 Step::RmsNormBackwardInput { .. }
604 | Step::RmsNormBackwardGamma { .. }
605 | Step::RmsNormBackwardBeta { .. }
606 | Step::RopeBackward { .. }
607 | Step::CumsumBackward { .. }
608 | Step::GatherBackward { .. } => false,
609 #[cfg(feature = "splat")]
610 Step::GaussianSplatRender { .. }
611 | Step::GaussianSplatRenderBackward { .. }
612 | Step::GaussianSplatPrepare { .. }
613 | Step::GaussianSplatRasterize { .. } => false,
614 }
615 }
616}
617
618fn fft_dtype_tag(dtype: rlx_ir::DType) -> u32 {
621 match dtype {
622 rlx_ir::DType::F32 => 0,
623 rlx_ir::DType::F64 => 1,
624 rlx_ir::DType::C64 => 2,
625 other => panic!("rlx-wgpu Op::Fft: unsupported dtype {other:?}"),
626 }
627}
628
629fn fft_dtype_from_tag(tag: u32) -> rlx_ir::DType {
630 match tag {
631 0 => rlx_ir::DType::F32,
632 1 => rlx_ir::DType::F64,
633 2 => rlx_ir::DType::C64,
634 other => panic!("rlx-wgpu Op::Fft: bad dtype tag {other}"),
635 }
636}
637
638fn step_name(step: &Step) -> &'static str {
639 match step {
640 Step::CastF32ToF16 { .. } => "cast_f32_to_f16",
641 Step::Matmul { .. } => "matmul",
642 Step::Binary { .. } => "binary",
643 Step::Compare { .. } => "compare",
644 Step::Unary { .. } => "unary",
645 Step::Where { .. } => "where",
646 Step::Reduce { .. } => "reduce",
647 Step::Softmax { .. } => "softmax",
648 Step::LayerNorm { .. } => "layer_norm",
649 Step::Cumsum { .. } => "cumsum",
650 Step::FftGpu { .. } => "fft_gpu",
651 Step::FftHost { .. } => "fft_host",
652 Step::Copy { .. } => "copy",
653 Step::Transpose { .. } => "transpose",
654 Step::Narrow { .. } => "narrow",
655 Step::Concat { .. } => "concat",
656 Step::Gather { .. } => "gather",
657 Step::GatherAxis { .. } => "gather_axis",
658 Step::Attention { .. } => "attention",
659 Step::AttentionBackward { .. } => "attention_bwd",
660 Step::Rope { .. } => "rope",
661 Step::Expand { .. } => "expand",
662 Step::Argmax { .. } => "argmax",
663 Step::Pool2d { .. } => "pool2d",
664 Step::Conv2d { .. } => "conv2d",
665 Step::Pool1d { .. } => "pool1d",
666 Step::Pool3d { .. } => "pool3d",
667 Step::Conv1d { .. } => "conv1d",
668 Step::Conv3d { .. } => "conv3d",
669 Step::ScatterAdd { .. } => "scatter_add",
670 Step::TopK { .. } => "topk",
671 Step::GroupedMatmul { .. } => "grouped_matmul",
672 Step::Sample { .. } => "sample",
673 Step::SelectiveScan { .. } => "selective_scan",
674 Step::DequantMatmul { .. } => "dequant_matmul",
675 Step::DequantMatmulGguf { .. } => "dequant_matmul_gguf",
676 Step::DequantGroupedMatmulGguf { .. } => "dequant_grouped_matmul_gguf",
677 Step::GatedDeltaNet { .. } => "gated_delta_net",
678 Step::Llada2GroupLimitedGate { .. } => "llada2_group_limited_gate",
679 Step::UmapKnn { .. } => "umap_knn",
680 Step::UmapKnnHost { .. } => "umap_knn_host",
681 #[cfg(feature = "splat")]
682 Step::GaussianSplatRender { .. } => "gaussian_splat_render",
683 #[cfg(feature = "splat")]
684 Step::GaussianSplatRenderBackward { .. } => "gaussian_splat_render_backward",
685 #[cfg(feature = "splat")]
686 Step::GaussianSplatPrepare { .. } => "gaussian_splat_prepare",
687 #[cfg(feature = "splat")]
688 Step::GaussianSplatRasterize { .. } => "gaussian_splat_rasterize",
689 Step::RmsNormBackwardInput { .. } => "rms_norm_backward_input",
690 Step::RmsNormBackwardGamma { .. } => "rms_norm_backward_gamma",
691 Step::RmsNormBackwardBeta { .. } => "rms_norm_backward_beta",
692 Step::RopeBackward { .. } => "rope_backward",
693 Step::CumsumBackward { .. } => "cumsum_backward",
694 Step::GatherBackward { .. } => "gather_backward",
695 Step::FusedResidualLn { .. } => "fused_residual_ln",
696 Step::FusedResidualLnTee { .. } => "fused_residual_ln_tee",
697 Step::FusedResidualRmsNorm { .. } => "fused_residual_rms_norm",
698 Step::MatmulQkv { .. } => "matmul_qkv",
699 Step::ElementwiseRegion { .. } => "elementwise_region",
700 }
701}
702
703fn step_runs_on_host(step: &Step) -> bool {
704 match step {
705 Step::DequantMatmulGguf { .. }
706 | Step::DequantGroupedMatmulGguf { .. }
707 | Step::GatedDeltaNet { .. }
708 | Step::Llada2GroupLimitedGate { .. }
709 | Step::UmapKnnHost { .. }
710 | Step::FftHost { .. } => true,
711 #[cfg(feature = "splat")]
712 Step::GaussianSplatRender { .. }
713 | Step::GaussianSplatRenderBackward { .. }
714 | Step::GaussianSplatPrepare { .. }
715 | Step::GaussianSplatRasterize { .. } => true,
716 _ => false,
717 }
718}
719
/// Maps a binary op to the numeric id written into `BinaryParams::op` for the
/// binary kernel.
///
/// The ids are a contract with the shader side; the `bin_sub` table used when
/// encoding `ElementwiseRegion` chains assigns the same values.
fn binary_op_id(op: BinaryOp) -> u32 {
    match op {
        BinaryOp::Add => 0,
        BinaryOp::Sub => 1,
        BinaryOp::Mul => 2,
        BinaryOp::Div => 3,
        BinaryOp::Max => 4,
        BinaryOp::Min => 5,
        BinaryOp::Pow => 6,
    }
}
731
/// Maps a comparison op to the numeric id written into `BinaryParams::op` for
/// the compare kernel.
///
/// The ids are a contract with the shader side; the `cmp_sub` table used when
/// encoding `ElementwiseRegion` chains assigns the same values.
fn compare_op_id(op: CmpOp) -> u32 {
    match op {
        CmpOp::Eq => 0,
        CmpOp::Ne => 1,
        CmpOp::Lt => 2,
        CmpOp::Le => 3,
        CmpOp::Gt => 4,
        CmpOp::Ge => 5,
    }
}
742
/// Maps a reduction op to the numeric id written into `ReduceParams::op`.
///
/// The ids are a contract with the reduce kernel; do not renumber without
/// updating the shader.
fn reduce_op_id(op: ReduceOp) -> u32 {
    match op {
        ReduceOp::Sum => 0,
        ReduceOp::Mean => 1,
        ReduceOp::Max => 2,
        ReduceOp::Min => 3,
        ReduceOp::Prod => 4,
    }
}
752
/// Maps an activation to the numeric id written into `UnaryParams::op` for the
/// unary kernel.
///
/// Note that this numbering is NOT the one used inside `ElementwiseRegion`
/// chains: the chain encoder has its own table (`act_sub`, where e.g. `Gelu`
/// is 0 and `Relu` is 3). The two must not be mixed up.
fn activation_op_id(act: Activation) -> u32 {
    match act {
        Activation::Relu => 0,
        Activation::Sigmoid => 1,
        Activation::Tanh => 2,
        Activation::Exp => 3,
        Activation::Log => 4,
        Activation::Sqrt => 5,
        Activation::Rsqrt => 6,
        Activation::Neg => 7,
        Activation::Abs => 8,
        Activation::Gelu => 9,
        Activation::Silu => 10,
        Activation::GeluApprox => 11,
        Activation::Round => 12,
        Activation::Sin => 13,
        Activation::Cos => 14,
        Activation::Tan => 15,
        Activation::Atan => 16,
    }
}
774
775impl WgpuExecutable {
776 fn lazy_compile_for_inputs(&mut self, inputs: &[(&str, &[f32])]) {
780 let unresolved = self
781 .unresolved
782 .as_ref()
783 .expect("lazy_compile_for_inputs called without an unresolved graph");
784 let binding = infer_bindings_from_f32_inputs(unresolved, inputs)
785 .expect("rlx-wgpu lazy compile: could not infer DimBinding from inputs");
786
787 if let Some(prev) = &self.last_binding
789 && same_binding(prev, &binding)
790 {
791 return;
792 }
793
794 let resolved = bind_graph(unresolved, &binding);
796 let original = self.unresolved.take();
797 let pending_params = std::mem::take(&mut self.pending_params);
798 let pending_bytes = std::mem::take(&mut self.pending_param_bytes);
799
800 let fresh = Self::compile_static_inner(resolved);
801
802 self.graph = fresh.graph;
805 self.arena = fresh.arena;
806 self.schedule = fresh.schedule;
807 self.input_offsets = fresh.input_offsets;
808 self.param_offsets = fresh.param_offsets;
809 self.uniforms = fresh.uniforms;
810 self.bind_groups = fresh.bind_groups;
811 self.meta_buffers = fresh.meta_buffers;
812 self.unresolved = original;
813 self.last_binding = Some(binding);
814 self.uniforms_active_extent = None;
817
818 for (name, data) in pending_params {
820 self.set_param(&name, &data);
821 }
822 for (name, data) in pending_bytes {
823 self.set_param_bytes(&name, &data);
824 }
825 }
826
827 pub fn compile_with_bindings(graph: Graph, bindings: &DimBinding) -> Self {
833 if bindings.is_empty() {
834 return Self::compile(graph);
835 }
836 let mut fresh = Graph::new(&graph.name);
838 for node in graph.nodes() {
839 let bound = node.shape.bind(bindings);
840 fresh.add_node(node.op.clone(), node.inputs.clone(), bound);
841 }
842 fresh.set_outputs(graph.outputs.clone());
843 Self::compile(fresh)
844 }
845
846 pub fn compile(graph: Graph) -> Self {
847 if has_dynamic_dims(&graph) {
848 return Self::deferred(graph);
849 }
850 Self::compile_static_inner(graph)
851 }
852
853 fn deferred(graph: Graph) -> Self {
858 let dev = wgpu_device().expect("rlx-wgpu: no compatible adapter found");
859 let placeholder = dev.device.create_buffer(&wgpu::BufferDescriptor {
861 label: Some("rlx-wgpu deferred placeholder"),
862 size: 16,
863 usage: wgpu::BufferUsages::STORAGE
864 | wgpu::BufferUsages::COPY_DST
865 | wgpu::BufferUsages::COPY_SRC,
866 mapped_at_creation: false,
867 });
868 let arena = Arena {
869 buffer: placeholder,
870 f16_buffer: None,
871 offsets: HashMap::new(),
872 lens: HashMap::new(),
873 size: 0,
874 };
875 Self {
876 graph: graph.clone(),
877 arena,
878 schedule: Vec::new(),
879 input_offsets: HashMap::new(),
880 param_offsets: HashMap::new(),
881 uniforms: Vec::new(),
882 bind_groups: Vec::new(),
883 meta_buffers: Vec::new(),
884 unresolved: Some(graph),
885 last_binding: None,
886 pending_params: HashMap::new(),
887 pending_param_bytes: HashMap::new(),
888 active_extent: None,
889 uniforms_active_extent: None,
890 fft_gpu_steps: Vec::new(),
891 }
892 }
893
    /// Sets (or, with `None`, clears) the active extent stored on this
    /// executable.
    ///
    /// This only records the value; NOTE(review): the meaning of the two
    /// `usize` components is defined where `active_extent` is consumed —
    /// confirm there.
    pub fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
        self.active_extent = extent;
    }
900
901 fn all_safe_for_active(&self) -> bool {
902 self.schedule.iter().all(|s| s.safe_for_active_extent())
903 }
904
905 fn compile_static_inner(graph: Graph) -> Self {
906 let dev = wgpu_device().expect("rlx-wgpu: no compatible adapter found");
907
908 let graph = crate::unfuse::unfuse(graph);
915
916 let plan = plan_f32_uniform(&graph, 16);
918 let mut arena = Arena::from_plan(&dev.device, &plan);
919 for node in graph.nodes() {
923 let elems = node.shape.num_elements().unwrap_or(0);
924 arena.set_actual_len(node.id, elems * 4);
925 }
926
927 for node in graph.nodes() {
929 if let Op::Constant { data } = &node.op
930 && arena.has(node.id)
931 && !data.is_empty()
932 {
933 let bytes_to_write = data.len().min(arena.len_of(node.id));
934 dev.queue.write_buffer(
935 &arena.buffer,
936 arena.offset(node.id) as u64,
937 &data[..bytes_to_write],
938 );
939 }
940 }
941
942 let mut input_offsets = HashMap::new();
943 let mut param_offsets = HashMap::new();
944 for node in graph.nodes() {
945 match &node.op {
946 Op::Input { name } => {
947 input_offsets.insert(name.clone(), node.id);
948 }
949 Op::Param { name } => {
950 param_offsets.insert(name.clone(), node.id);
951 }
952 _ => {}
953 }
954 }
955
956 let mm_k = matmul_kernel(&dev.device);
957 let mm_w = matmul_wide_kernel(&dev.device);
958 let mm_f16w = matmul_f16w_kernel(&dev.device);
959 let mm_f16c = matmul_f16_compute_kernel(&dev.device);
960 let mm_coop = matmul_coop16_kernel(&dev.device);
961 let mm_coop_f32 = matmul_coop_f32_kernel(&dev.device);
962 let mm_cast = cast_f32_to_f16_kernel(&dev.device);
963 let bk = binary_kernel(&dev.device);
964 let uk = unary_kernel(&dev.device);
965 let ck = compare_kernel(&dev.device);
966 let wk = where_kernel(&dev.device);
967
968 let mut schedule = Vec::new();
969 let mut uniforms = Vec::new();
970 let mut bind_groups = Vec::new();
971 let mut fft_gpu_steps: Vec<crate::fft_dispatch::FftGpuResources> = Vec::new();
972 let mut gguf_host_pad: Option<(wgpu::Buffer, wgpu::BindGroup)> = None;
973 let mut meta_buffers: Vec<wgpu::Buffer> = Vec::new();
974
975 let mut qkv_split: HashMap<NodeId, (NodeId, NodeId, NodeId)> = HashMap::new();
988 for (parent_id, qkv) in detect_split_qkv_pattern(&graph) {
989 let parent = graph.node(parent_id);
990 let a_id = parent.inputs[0];
993 let b_id = parent.inputs[1];
994 let a_dims = graph.node(a_id).shape.dims();
995 let b_dims = graph.node(b_id).shape.dims();
996 let out_dims = parent.shape.dims();
997 let (m, k, n) =
998 if a_dims.len() >= 2 && b_dims.len() == 2 && out_dims.len() == a_dims.len() {
999 let leading: usize = a_dims[..a_dims.len() - 2]
1000 .iter()
1001 .map(|d| d.unwrap_static())
1002 .product();
1003 let m_inner = a_dims[a_dims.len() - 2].unwrap_static();
1004 let k_inner = a_dims[a_dims.len() - 1].unwrap_static();
1005 let n_inner = b_dims[1].unwrap_static();
1006 ((leading * m_inner) as u32, k_inner as u32, n_inner as u32)
1007 } else if a_dims.len() == 2 && b_dims.len() == 2 {
1008 (
1009 a_dims[0].unwrap_static() as u32,
1010 a_dims[1].unwrap_static() as u32,
1011 b_dims[1].unwrap_static() as u32,
1012 )
1013 } else {
1014 continue; };
1016 let cp = derive_matmul_compute(&dev.device, &graph, a_id, b_id, m, k, n);
1017 if cp == MatmulCompute::F32 || cp == MatmulCompute::CoopF32 {
1022 qkv_split.insert(parent_id, qkv);
1023 }
1024 }
1025 let qkv_skip_narrows: HashSet<NodeId> = qkv_split
1026 .values()
1027 .flat_map(|&(q, k, v)| [q, k, v])
1028 .collect();
1029
1030 let (ln_to_tee, skip_adds) = detect_residual_ln_tee_pattern(&graph);
1040
1041 let emit_uniform = |size: usize| -> wgpu::Buffer {
1042 dev.device.create_buffer(&wgpu::BufferDescriptor {
1043 label: Some("rlx-wgpu uniform"),
1044 size: size as u64,
1045 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
1046 mapped_at_creation: false,
1047 })
1048 };
1049
1050 for node in graph.nodes() {
1051 let elems = node.shape.num_elements().unwrap_or(0) as u32;
1055 match &node.op {
1056 Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => continue,
1057 Op::MatMul => {
1058 let a_id = node.inputs[0];
1059 let b_id = node.inputs[1];
1060 let a_shape = graph.node(a_id).shape.dims();
1061 let b_shape = graph.node(b_id).shape.dims();
1062 let out_shape = node.shape.dims();
1063 let (m, k, n, batch, a_bs, b_bs, c_bs) = if a_shape.len() == 2
1068 && b_shape.len() == 2
1069 && out_shape.len() == 2
1070 {
1071 (
1072 a_shape[0].unwrap_static() as u32,
1073 a_shape[1].unwrap_static() as u32,
1074 b_shape[1].unwrap_static() as u32,
1075 1u32,
1076 0u32,
1077 0u32,
1078 0u32,
1079 )
1080 } else if a_shape.len() >= 2
1081 && b_shape.len() == 2
1082 && out_shape.len() == a_shape.len()
1083 {
1084 let leading: usize = a_shape[..a_shape.len() - 2]
1085 .iter()
1086 .map(|d| d.unwrap_static())
1087 .product();
1088 let m_inner = a_shape[a_shape.len() - 2].unwrap_static();
1089 let k_inner = a_shape[a_shape.len() - 1].unwrap_static();
1090 let n_inner = b_shape[1].unwrap_static();
1091 (
1092 (leading * m_inner) as u32,
1093 k_inner as u32,
1094 n_inner as u32,
1095 1u32,
1096 0u32,
1097 0u32,
1098 0u32,
1099 )
1100 } else if a_shape.len() == b_shape.len()
1101 && a_shape.len() >= 3
1102 && out_shape.len() == a_shape.len()
1103 {
1104 let leading_a: Vec<usize> = a_shape[..a_shape.len() - 2]
1106 .iter()
1107 .map(|d| d.unwrap_static())
1108 .collect();
1109 let leading_b: Vec<usize> = b_shape[..b_shape.len() - 2]
1110 .iter()
1111 .map(|d| d.unwrap_static())
1112 .collect();
1113 if leading_a != leading_b {
1114 panic!(
1115 "rlx-wgpu MatMul: batched shape mismatch \
1116 a_leading={leading_a:?} b_leading={leading_b:?}"
1117 );
1118 }
1119 let b_count: usize = leading_a.iter().product();
1120 let m_inner = a_shape[a_shape.len() - 2].unwrap_static();
1121 let k_inner = a_shape[a_shape.len() - 1].unwrap_static();
1122 let n_inner = b_shape[b_shape.len() - 1].unwrap_static();
1123 (
1124 m_inner as u32,
1125 k_inner as u32,
1126 n_inner as u32,
1127 b_count as u32,
1128 (m_inner * k_inner) as u32,
1129 (k_inner * n_inner) as u32,
1130 (m_inner * n_inner) as u32,
1131 )
1132 } else {
1133 panic!(
1134 "rlx-wgpu MatMul: unsupported shapes a={a_shape:?} b={b_shape:?} \
1135 out={out_shape:?} (supported: 2D×2D, [..,M,K]×[K,N], [..,M,K]×[..,K,N])"
1136 );
1137 };
1138 let b_is_param = traces_to_param(&graph, b_id);
1139 let compute_precision =
1140 derive_matmul_compute(&dev.device, &graph, a_id, b_id, m, k, n);
1141 let _ = mm_cast;
1145 schedule.push(Step::Matmul {
1146 m,
1147 k,
1148 n,
1149 batch,
1150 a_batch_stride: a_bs,
1151 b_batch_stride: b_bs,
1152 c_batch_stride: c_bs,
1153 a_off_f32: (arena.offset(a_id) / 4) as u32,
1154 b_off_f32: (arena.offset(b_id) / 4) as u32,
1155 c_off_f32: (arena.offset(node.id) / 4) as u32,
1156 has_bias: 0,
1157 bias_off_f32: 0,
1158 act_id: 0xFFFF,
1159 b_is_param,
1160 compute_precision,
1161 });
1162 let u = emit_uniform(std::mem::size_of::<MatmulParams>());
1163 let bg = build_matmul_bind_group(
1164 &dev.device,
1165 mm_k,
1166 mm_w,
1167 &mm_f16w,
1168 &mm_f16c,
1169 &mm_coop,
1170 &mm_coop_f32,
1171 &arena,
1172 &u,
1173 b_is_param,
1174 compute_precision,
1175 );
1176 uniforms.push(u);
1177 bind_groups.push(bg);
1178 }
1179 Op::Binary(bop) => {
1180 if skip_adds.contains(&node.id) {
1185 continue;
1186 }
1187 require_equal_shapes(&graph, &node.inputs, "Binary");
1188 let p = BinaryParams {
1189 n: elems,
1190 a_off: (arena.offset(node.inputs[0]) / 4) as u32,
1191 b_off: (arena.offset(node.inputs[1]) / 4) as u32,
1192 c_off: (arena.offset(node.id) / 4) as u32,
1193 op: binary_op_id(*bop),
1194 _p0: 0,
1195 _p1: 0,
1196 _p2: 0,
1197 };
1198 schedule.push(Step::Binary { params: p });
1199 let u = emit_uniform(std::mem::size_of::<BinaryParams>());
1200 let bg = bind_two(&dev.device, bk, &arena.buffer, &u);
1201 uniforms.push(u);
1202 bind_groups.push(bg);
1203 }
1204 Op::Compare(cop) => {
1205 require_equal_shapes(&graph, &node.inputs, "Compare");
1206 let p = BinaryParams {
1207 n: elems,
1208 a_off: (arena.offset(node.inputs[0]) / 4) as u32,
1209 b_off: (arena.offset(node.inputs[1]) / 4) as u32,
1210 c_off: (arena.offset(node.id) / 4) as u32,
1211 op: compare_op_id(*cop),
1212 _p0: 0,
1213 _p1: 0,
1214 _p2: 0,
1215 };
1216 schedule.push(Step::Compare { params: p });
1217 let u = emit_uniform(std::mem::size_of::<BinaryParams>());
1218 let bg = bind_two(&dev.device, ck, &arena.buffer, &u);
1219 uniforms.push(u);
1220 bind_groups.push(bg);
1221 }
1222 Op::Activation(act) => {
1223 let p = UnaryParams {
1224 n: elems,
1225 in_off: (arena.offset(node.inputs[0]) / 4) as u32,
1226 out_off: (arena.offset(node.id) / 4) as u32,
1227 op: activation_op_id(*act),
1228 _p0: 0,
1229 _p1: 0,
1230 _p2: 0,
1231 _p3: 0,
1232 };
1233 schedule.push(Step::Unary { params: p });
1234 let u = emit_uniform(std::mem::size_of::<UnaryParams>());
1235 let bg = bind_two(&dev.device, uk, &arena.buffer, &u);
1236 uniforms.push(u);
1237 bind_groups.push(bg);
1238 }
1239 Op::Where => {
1240 let p = WhereParams {
1241 n: elems,
1242 cond_off: (arena.offset(node.inputs[0]) / 4) as u32,
1243 x_off: (arena.offset(node.inputs[1]) / 4) as u32,
1244 y_off: (arena.offset(node.inputs[2]) / 4) as u32,
1245 out_off: (arena.offset(node.id) / 4) as u32,
1246 _p0: 0,
1247 _p1: 0,
1248 _p2: 0,
1249 };
1250 schedule.push(Step::Where { params: p });
1251 let u = emit_uniform(std::mem::size_of::<WhereParams>());
1252 let bg = bind_two(&dev.device, wk, &arena.buffer, &u);
1253 uniforms.push(u);
1254 bind_groups.push(bg);
1255 }
1256
1257 Op::ElementwiseRegion {
1258 chain,
1259 num_inputs,
1260 scalar_input_mask,
1261 input_modulus,
1262 } => {
1263 let n = *num_inputs as usize;
1266 if n > 16 || chain.len() > 32 {
1267 panic!(
1268 "rlx-wgpu ElementwiseRegion: chain too large \
1269 (inputs={n}, steps={}). Caps: 16 / 32. \
1270 Use UnfuseElementwiseRegions to fall back.",
1271 chain.len()
1272 );
1273 }
1274 let mut input_offs = [0u32; 16];
1275 for (i, &id) in node.inputs.iter().enumerate() {
1276 input_offs[i] = (arena.offset(id) / 4) as u32;
1277 }
1278 let encode_operand = |op: &ChainOperand| -> u32 {
1279 match *op {
1280 ChainOperand::Input(i) => i & 0x7FFF_FFFFu32,
1281 ChainOperand::Step(i) => 0x8000_0000u32 | (i & 0x7FFF_FFFFu32),
1282 }
1283 };
1284 let act_sub = |a: Activation| match a {
1285 Activation::Gelu => 0u32,
1286 Activation::GeluApprox => 1,
1287 Activation::Silu => 2,
1288 Activation::Relu => 3,
1289 Activation::Sigmoid => 4,
1290 Activation::Tanh => 5,
1291 Activation::Exp => 6,
1292 Activation::Log => 7,
1293 Activation::Sqrt => 8,
1294 Activation::Rsqrt => 9,
1295 Activation::Neg => 10,
1296 Activation::Abs => 11,
1297 Activation::Round => 12,
1298 Activation::Sin => 13,
1299 Activation::Cos => 14,
1300 Activation::Tan => 15,
1301 Activation::Atan => 16,
1302 };
1303 let bin_sub = |b: BinaryOp| match b {
1304 BinaryOp::Add => 0u32,
1305 BinaryOp::Sub => 1,
1306 BinaryOp::Mul => 2,
1307 BinaryOp::Div => 3,
1308 BinaryOp::Max => 4,
1309 BinaryOp::Min => 5,
1310 BinaryOp::Pow => 6,
1311 };
1312 let cmp_sub = |c: CmpOp| match c {
1313 CmpOp::Eq => 0u32,
1314 CmpOp::Ne => 1,
1315 CmpOp::Lt => 2,
1316 CmpOp::Le => 3,
1317 CmpOp::Gt => 4,
1318 CmpOp::Ge => 5,
1319 };
1320 let mut chain_enc = [0u32; 128];
1321 for (k, step) in chain.iter().enumerate() {
1322 let base = k * 4;
1323 let (kind, sub, lhs, rhs) = match step {
1324 ChainStep::Activation(a, src) => {
1325 (0u32, act_sub(*a), encode_operand(src), 0u32)
1326 }
1327 ChainStep::Cast(_, src) => (1u32, 0, encode_operand(src), 0u32),
1328 ChainStep::Binary(op, l, r) => {
1329 (2u32, bin_sub(*op), encode_operand(l), encode_operand(r))
1330 }
1331 ChainStep::Compare(op, l, r) => {
1332 (3u32, cmp_sub(*op), encode_operand(l), encode_operand(r))
1333 }
1334 ChainStep::Where(c, t, f) =>
1335 {
1338 (
1339 4u32,
1340 encode_operand(c),
1341 encode_operand(t),
1342 encode_operand(f),
1343 )
1344 }
1345 };
1346 chain_enc[base] = kind;
1347 chain_enc[base + 1] = sub;
1348 chain_enc[base + 2] = lhs;
1349 chain_enc[base + 3] = rhs;
1350 }
1351 let p = ElementwiseRegionParams {
1352 len: elems,
1353 num_inputs: *num_inputs,
1354 num_steps: chain.len() as u32,
1355 dst_off: (arena.offset(node.id) / 4) as u32,
1356 input_offs,
1357 chain: chain_enc,
1358 scalar_input_mask: *scalar_input_mask,
1359 _pad0: 0,
1360 _pad1: 0,
1361 _pad2: 0,
1362 input_modulus: *input_modulus,
1363 };
1364 schedule.push(Step::ElementwiseRegion { params: p });
1365 let ek = elementwise_region_kernel(&dev.device);
1366 let u = dev.device.create_buffer(&wgpu::BufferDescriptor {
1370 label: Some("rlx-wgpu region params"),
1371 size: std::mem::size_of::<ElementwiseRegionParams>() as u64,
1372 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
1373 mapped_at_creation: false,
1374 });
1375 let bg = bind_two(&dev.device, ek, &arena.buffer, &u);
1376 uniforms.push(u);
1377 bind_groups.push(bg);
1378 }
1379
1380 Op::Reduce {
1381 op: rop,
1382 axes,
1383 keep_dim: _,
1384 } => {
1385 let in_id = node.inputs[0];
1389 let in_shape = graph.node(in_id).shape.dims();
1390 let last = in_shape.len() - 1;
1391 if axes.as_slice() != [last] {
1392 panic!(
1393 "rlx-wgpu Reduce: only last-axis is wired \
1394 (got axes={axes:?}, rank={})",
1395 in_shape.len()
1396 );
1397 }
1398 let inner = in_shape[last].unwrap_static() as u32;
1399 let total: u32 = in_shape.iter().map(|d| d.unwrap_static() as u32).product();
1400 let outer = total / inner.max(1);
1401 let p = ReduceParams {
1402 outer,
1403 inner,
1404 in_off: (arena.offset(in_id) / 4) as u32,
1405 out_off: (arena.offset(node.id) / 4) as u32,
1406 op: reduce_op_id(*rop),
1407 _p0: 0,
1408 _p1: 0,
1409 _p2: 0,
1410 };
1411 schedule.push(Step::Reduce { params: p });
1412 let rk = reduce_kernel(&dev.device);
1413 let u = emit_uniform(std::mem::size_of::<ReduceParams>());
1414 let bg = bind_two(&dev.device, rk, &arena.buffer, &u);
1415 uniforms.push(u);
1416 bind_groups.push(bg);
1417 }
1418
1419 Op::Softmax { axis } => {
1420 let in_id = node.inputs[0];
1421 let in_shape = graph.node(in_id).shape.dims();
1422 let last = (in_shape.len() - 1) as i32;
1423 if *axis != -1 && *axis != last {
1424 panic!("rlx-wgpu Softmax: only last-axis wired (got axis={axis})");
1425 }
1426 let inner = in_shape[in_shape.len() - 1].unwrap_static() as u32;
1427 let total: u32 = in_shape.iter().map(|d| d.unwrap_static() as u32).product();
1428 let outer = total / inner.max(1);
1429 let p = SoftmaxParams {
1430 outer,
1431 inner,
1432 in_off: (arena.offset(in_id) / 4) as u32,
1433 out_off: (arena.offset(node.id) / 4) as u32,
1434 _p0: 0,
1435 _p1: 0,
1436 _p2: 0,
1437 _p3: 0,
1438 };
1439 schedule.push(Step::Softmax { params: p });
1440 let sk = softmax_kernel(&dev.device);
1441 let u = emit_uniform(std::mem::size_of::<SoftmaxParams>());
1442 let bg = bind_two(&dev.device, sk, &arena.buffer, &u);
1443 uniforms.push(u);
1444 bind_groups.push(bg);
1445 }
1446
1447 Op::LayerNorm { axis: _, eps } | Op::RmsNorm { axis: _, eps } => {
1448 let in_id = node.inputs[0];
1449 let in_shape = graph.node(in_id).shape.dims();
1450 let inner = in_shape[in_shape.len() - 1].unwrap_static() as u32;
1451 let total: u32 = in_shape.iter().map(|d| d.unwrap_static() as u32).product();
1452 let outer = total / inner.max(1);
1453 let is_layer_norm = matches!(&node.op, Op::LayerNorm { .. });
1454
1455 if is_layer_norm
1462 && let Some(&(h_id, delta_id, gamma_id, beta_id, sum_id)) =
1463 ln_to_tee.get(&node.id)
1464 {
1465 let p = FusedResidualLnTeeParams {
1466 outer,
1467 inner,
1468 in_off: (arena.offset(h_id) / 4) as u32,
1469 residual_off: (arena.offset(delta_id) / 4) as u32,
1470 bias_off: 0, gamma_off: (arena.offset(gamma_id) / 4) as u32,
1472 beta_off: (arena.offset(beta_id) / 4) as u32,
1473 sum_off: (arena.offset(sum_id) / 4) as u32,
1474 ln_out_off: (arena.offset(node.id) / 4) as u32,
1475 eps_bits: eps.to_bits(),
1476 has_bias: 0,
1477 _p0: 0,
1478 };
1479 schedule.push(Step::FusedResidualLnTee { params: p });
1480 let frtk = fused_residual_ln_tee_kernel(&dev.device);
1481 let u = emit_uniform(std::mem::size_of::<FusedResidualLnTeeParams>());
1482 let bg = bind_two(&dev.device, frtk, &arena.buffer, &u);
1483 uniforms.push(u);
1484 bind_groups.push(bg);
1485 continue;
1486 }
1487
1488 let gamma_id = node.inputs[1];
1489 let beta_id = if is_layer_norm && node.inputs.len() >= 3 {
1492 node.inputs[2]
1493 } else {
1494 gamma_id
1497 };
1498 let p = LayerNormParams {
1499 outer,
1500 inner,
1501 in_off: (arena.offset(in_id) / 4) as u32,
1502 out_off: (arena.offset(node.id) / 4) as u32,
1503 gamma_off: (arena.offset(gamma_id) / 4) as u32,
1504 beta_off: (arena.offset(beta_id) / 4) as u32,
1505 eps_bits: eps.to_bits(),
1506 op: if is_layer_norm { 0 } else { 1 },
1507 };
1508 schedule.push(Step::LayerNorm { params: p });
1509 let lk = layernorm_kernel(&dev.device);
1510 let u = emit_uniform(std::mem::size_of::<LayerNormParams>());
1511 let bg = bind_two(&dev.device, lk, &arena.buffer, &u);
1512 uniforms.push(u);
1513 bind_groups.push(bg);
1514 }
1515
1516 Op::Reshape { .. } | Op::Cast { .. } => {
1517 }
1519
1520 Op::Transpose { perm } => {
1521 let in_id = node.inputs[0];
1522 let in_shape = graph.node(in_id).shape.dims();
1523 let out_shape = node.shape.dims();
1524 let rank = perm.len();
1525 if rank != in_shape.len() || rank != out_shape.len() {
1526 panic!("rlx-wgpu Transpose: rank mismatch");
1527 }
1528 let in_dims: Vec<u32> =
1529 in_shape.iter().map(|d| d.unwrap_static() as u32).collect();
1530 let out_dims: Vec<u32> =
1531 out_shape.iter().map(|d| d.unwrap_static() as u32).collect();
1532 let mut in_strides = vec![1u32; rank];
1534 for i in (0..rank.saturating_sub(1)).rev() {
1535 in_strides[i] = in_strides[i + 1] * in_dims[i + 1];
1536 }
1537 let strides_for_out: Vec<u32> =
1540 (0..rank).map(|i| in_strides[perm[i]]).collect();
1541
1542 let mut meta_data: Vec<u32> = Vec::with_capacity(rank * 2);
1544 meta_data.extend_from_slice(&out_dims);
1545 meta_data.extend_from_slice(&strides_for_out);
1546 let meta_buf = dev.device.create_buffer(&wgpu::BufferDescriptor {
1547 label: Some("rlx-wgpu transpose meta"),
1548 size: (meta_data.len() * 4).max(4) as u64,
1549 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
1550 mapped_at_creation: false,
1551 });
1552 dev.queue
1553 .write_buffer(&meta_buf, 0, bytemuck::cast_slice(&meta_data));
1554 let meta_idx = meta_buffers.len();
1555 meta_buffers.push(meta_buf);
1556
1557 let bucket_outermost = if perm[0] == 0 { 1u32 } else { 0u32 };
1561 let p = TransposeParams {
1562 rank: rank as u32,
1563 out_total: elems,
1564 in_off: (arena.offset(in_id) / 4) as u32,
1565 out_off: (arena.offset(node.id) / 4) as u32,
1566 bucket_outermost,
1567 out_dim_0: out_dims[0],
1568 _p2: 0,
1569 _p3: 0,
1570 };
1571 schedule.push(Step::Transpose {
1572 params: p,
1573 meta_idx,
1574 });
1575 let tk = transpose_kernel(&dev.device);
1576 let u = emit_uniform(std::mem::size_of::<TransposeParams>());
1577 let bg = dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
1578 label: Some("rlx-wgpu transpose bg"),
1579 layout: &tk.bgl,
1580 entries: &[
1581 wgpu::BindGroupEntry {
1582 binding: 0,
1583 resource: arena.buffer.as_entire_binding(),
1584 },
1585 wgpu::BindGroupEntry {
1586 binding: 1,
1587 resource: u.as_entire_binding(),
1588 },
1589 wgpu::BindGroupEntry {
1590 binding: 2,
1591 resource: meta_buffers[meta_idx].as_entire_binding(),
1592 },
1593 ],
1594 });
1595 uniforms.push(u);
1596 bind_groups.push(bg);
1597 }
1598
1599 Op::Narrow { axis, start, len } => {
1600 if qkv_skip_narrows.contains(&node.id) {
1605 continue;
1606 }
1607 let in_id = node.inputs[0];
1608 let in_shape = graph.node(in_id).shape.dims();
1609 let outer: u32 = in_shape[..*axis]
1610 .iter()
1611 .map(|d| d.unwrap_static() as u32)
1612 .product::<u32>()
1613 .max(1);
1614 let inner: u32 = in_shape[*axis + 1..]
1615 .iter()
1616 .map(|d| d.unwrap_static() as u32)
1617 .product::<u32>()
1618 .max(1);
1619 let axis_in = in_shape[*axis].unwrap_static() as u32;
1620 let p = NarrowConcatParams {
1621 total: elems,
1622 outer,
1623 inner,
1624 axis_in_size: axis_in,
1625 axis_out_size: *len as u32,
1626 start: *start as u32,
1627 in_off: (arena.offset(in_id) / 4) as u32,
1628 out_off: (arena.offset(node.id) / 4) as u32,
1629 };
1630 schedule.push(Step::Narrow { params: p });
1631 let nk = narrow_kernel(&dev.device);
1632 let u = emit_uniform(std::mem::size_of::<NarrowConcatParams>());
1633 let bg = bind_two(&dev.device, nk, &arena.buffer, &u);
1634 uniforms.push(u);
1635 bind_groups.push(bg);
1636 }
1637
1638 Op::Concat { axis } => {
1639 let out_shape = node.shape.dims();
1640 let outer: u32 = out_shape[..*axis]
1641 .iter()
1642 .map(|d| d.unwrap_static() as u32)
1643 .product::<u32>()
1644 .max(1);
1645 let inner: u32 = out_shape[*axis + 1..]
1646 .iter()
1647 .map(|d| d.unwrap_static() as u32)
1648 .product::<u32>()
1649 .max(1);
1650 let axis_out = out_shape[*axis].unwrap_static() as u32;
1651
1652 let mut start_pos: u32 = 0;
1653 for &in_id in &node.inputs {
1654 let in_shape = graph.node(in_id).shape.dims();
1655 let axis_in = in_shape[*axis].unwrap_static() as u32;
1656 let in_total: u32 =
1657 in_shape.iter().map(|d| d.unwrap_static() as u32).product();
1658 let p = NarrowConcatParams {
1659 total: in_total,
1660 outer,
1661 inner,
1662 axis_in_size: axis_in,
1663 axis_out_size: axis_out,
1664 start: start_pos,
1665 in_off: (arena.offset(in_id) / 4) as u32,
1666 out_off: (arena.offset(node.id) / 4) as u32,
1667 };
1668 schedule.push(Step::Concat { params: p });
1669 let cck = concat_kernel(&dev.device);
1670 let u = emit_uniform(std::mem::size_of::<NarrowConcatParams>());
1671 let bg = bind_two(&dev.device, cck, &arena.buffer, &u);
1672 uniforms.push(u);
1673 bind_groups.push(bg);
1674 start_pos += axis_in;
1675 }
1676 }
1677
1678 Op::Attention {
1679 num_heads,
1680 head_dim,
1681 mask_kind,
1682 score_scale: _,
1683 attn_logit_softcap: _,
1684 } => {
1685 let q_id = node.inputs[0];
1688 let k_id = node.inputs[1];
1689 let v_id = node.inputs[2];
1690 let q_shape = graph.node(q_id).shape.dims();
1691 let k_shape = graph.node(k_id).shape.dims();
1692 let h = *num_heads as u32;
1698 let hd = *head_dim as u32;
1699 let (batch, heads, seq_q, seq_k) = match q_shape.len() {
1700 4 => (
1701 q_shape[0].unwrap_static() as u32,
1702 q_shape[1].unwrap_static() as u32,
1703 q_shape[2].unwrap_static() as u32,
1704 k_shape[2].unwrap_static() as u32,
1705 ),
1706 3 => {
1707 let last = q_shape[2].unwrap_static() as u32;
1714 if last == h * hd {
1715 (
1717 q_shape[0].unwrap_static() as u32,
1718 h,
1719 q_shape[1].unwrap_static() as u32,
1720 k_shape[1].unwrap_static() as u32,
1721 )
1722 } else {
1723 let leading = q_shape[0].unwrap_static() as u32;
1725 if !leading.is_multiple_of(h) {
1726 panic!(
1727 "rlx-wgpu Attention: rank-3 leading dim {leading} \
1728 not divisible by num_heads {h} (and last dim \
1729 {last} ≠ H·D = {})",
1730 h * hd
1731 );
1732 }
1733 (
1734 leading / h,
1735 h,
1736 q_shape[1].unwrap_static() as u32,
1737 k_shape[1].unwrap_static() as u32,
1738 )
1739 }
1740 }
1741 other => panic!(
1742 "rlx-wgpu Attention: only rank-3 / rank-4 Q,K,V \
1743 inputs supported (got rank {other})"
1744 ),
1745 };
1746 let scale = 1.0_f32 / (hd as f32).sqrt();
1747
1748 let (mask_kind_id, mask_off, mask_buf, window) = match mask_kind {
1749 MaskKind::None => (0u32, 0u32, None, 0u32),
1750 MaskKind::Causal => (1u32, 0u32, None, 0u32),
1751 MaskKind::Custom | MaskKind::Bias => {
1752 let m_id = node.inputs[3];
1753 (2u32, (arena.offset(m_id) / 4) as u32, None, 0u32)
1754 }
1755 MaskKind::SlidingWindow(w) => (3u32, 0u32, None, *w as u32),
1756 };
1757
1758 struct MStrides {
1765 b: u32,
1766 h: u32,
1767 q: u32,
1768 k: u32,
1769 }
1770 let mask_strides = if mask_kind_id == 2u32 {
1771 let m_dims = graph.node(node.inputs[3]).shape.dims();
1772 let dim = |i: usize| m_dims[i].unwrap_static() as u32;
1773 match m_dims.len() {
1774 2 => MStrides {
1775 b: dim(1),
1776 h: 0,
1777 q: 0,
1778 k: 1,
1779 },
1780 3 => MStrides {
1781 b: dim(1) * dim(2),
1782 h: 0,
1783 q: dim(2),
1784 k: 1,
1785 },
1786 4 => MStrides {
1787 b: dim(1) * dim(2) * dim(3),
1788 h: dim(2) * dim(3),
1789 q: dim(3),
1790 k: 1,
1791 },
1792 _ => MStrides {
1793 b: heads * seq_q * seq_k,
1794 h: seq_q * seq_k,
1795 q: seq_k,
1796 k: 1,
1797 },
1798 }
1799 } else {
1800 MStrides {
1801 b: heads * seq_q * seq_k,
1802 h: seq_q * seq_k,
1803 q: seq_k,
1804 k: 1,
1805 }
1806 };
1807
1808 let infer_strides =
1816 |shape: &[rlx_ir::shape::Dim], seq_extent: u32| -> (u32, u32, u32) {
1817 let last = shape[shape.len() - 1].unwrap_static() as u32;
1818 if shape.len() == 3 && last == (heads * hd) {
1819 let head_dim_total = heads * hd;
1821 (seq_extent * head_dim_total, hd, head_dim_total)
1822 } else {
1823 (heads * seq_extent * hd, seq_extent * hd, hd)
1825 }
1826 };
1827 let (q_b, q_h, q_s) = infer_strides(q_shape, seq_q);
1828 let (k_b, k_h, k_s) = infer_strides(k_shape, seq_k);
1829 let v_shape = graph.node(v_id).shape.dims();
1830 let (v_b, v_h, v_s) = infer_strides(v_shape, seq_k);
1831 let out_shape = node.shape.dims();
1832 let (o_b, o_h, o_s) = infer_strides(out_shape, seq_q);
1833 let p = AttentionParams {
1834 batch,
1835 heads,
1836 seq_q,
1837 seq_k,
1838 head_dim: hd,
1839 q_off: (arena.offset(q_id) / 4) as u32,
1840 k_off: (arena.offset(k_id) / 4) as u32,
1841 v_off: (arena.offset(v_id) / 4) as u32,
1842 out_off: (arena.offset(node.id) / 4) as u32,
1843 mask_off,
1844 mask_kind: mask_kind_id,
1845 scale_bits: scale.to_bits(),
1846 window,
1847 seq_q_stride: mask_strides.q,
1856 seq_k_stride: mask_strides.k,
1857 mask_batch_stride: mask_strides.b,
1858 mask_head_stride: mask_strides.h,
1859 _pad_mask_0: 0,
1860 _pad_mask_1: 0,
1861 _pad_mask_2: 0,
1862 q_batch_stride: q_b,
1863 q_head_stride: q_h,
1864 q_seq_stride: q_s,
1865 _pad_q: 0,
1866 k_batch_stride: k_b,
1867 k_head_stride: k_h,
1868 k_seq_stride: k_s,
1869 _pad_k: 0,
1870 v_batch_stride: v_b,
1871 v_head_stride: v_h,
1872 v_seq_stride: v_s,
1873 _pad_v: 0,
1874 o_batch_stride: o_b,
1875 o_head_stride: o_h,
1876 o_seq_stride: o_s,
1877 _pad_o: 0,
1878 };
1879 let _ = num_heads;
1880 schedule.push(Step::Attention {
1881 params: p,
1882 mask_buf,
1883 });
1884 let ak = attention_kernel(&dev.device);
1885 let u = emit_uniform(std::mem::size_of::<AttentionParams>());
1886 let bg = bind_two(&dev.device, ak, &arena.buffer, &u);
1887 uniforms.push(u);
1888 bind_groups.push(bg);
1889 }
1890
1891 Op::AttentionBackward {
1892 num_heads: _,
1893 head_dim,
1894 mask_kind,
1895 wrt,
1896 } => {
1897 use rlx_ir::op::AttentionBwdWrt;
1898 let q_id = node.inputs[0];
1899 let k_id = node.inputs[1];
1900 let v_id = node.inputs[2];
1901 let dy_id = node.inputs[3];
1902 let q_shape = graph.node(q_id).shape.dims();
1903 let k_shape = graph.node(k_id).shape.dims();
1904 let hd = *head_dim as u32;
1905 let (batch, heads, seq_q, seq_k) = match q_shape.len() {
1906 4 => (
1907 q_shape[0].unwrap_static() as u32,
1908 q_shape[1].unwrap_static() as u32,
1909 q_shape[2].unwrap_static() as u32,
1910 k_shape[2].unwrap_static() as u32,
1911 ),
1912 3 => {
1913 let h = q_shape[2].unwrap_static() as u32 / hd;
1914 (
1915 q_shape[0].unwrap_static() as u32 / h,
1916 h,
1917 q_shape[1].unwrap_static() as u32,
1918 k_shape[1].unwrap_static() as u32,
1919 )
1920 }
1921 other => panic!(
1922 "rlx-wgpu AttentionBackward: only rank-3/4 Q,K,V (got rank {other})"
1923 ),
1924 };
1925 let scale = 1.0_f32 / (hd as f32).sqrt();
1926 let (mask_kind_id, mask_off, mask_buf, window) = match mask_kind {
1927 MaskKind::None => (0u32, 0u32, None, 0u32),
1928 MaskKind::Causal => (1u32, 0u32, None, 0u32),
1929 MaskKind::Custom => {
1930 (2u32, (arena.offset(node.inputs[4]) / 4) as u32, None, 0u32)
1931 }
1932 MaskKind::Bias => {
1933 (4u32, (arena.offset(node.inputs[4]) / 4) as u32, None, 0u32)
1934 }
1935 MaskKind::SlidingWindow(w) => (3u32, 0u32, None, *w as u32),
1936 };
1937 struct MStrides {
1938 b: u32,
1939 h: u32,
1940 q: u32,
1941 k: u32,
1942 }
1943 let mask_strides = if mask_kind_id == 2 || mask_kind_id == 4 {
1944 let m_dims = graph.node(node.inputs[4]).shape.dims();
1945 let dim = |i: usize| m_dims[i].unwrap_static() as u32;
1946 match m_dims.len() {
1947 2 => MStrides {
1948 b: dim(1),
1949 h: 0,
1950 q: 0,
1951 k: 1,
1952 },
1953 3 => MStrides {
1954 b: dim(1) * dim(2),
1955 h: 0,
1956 q: dim(2),
1957 k: 1,
1958 },
1959 4 => MStrides {
1960 b: dim(1) * dim(2) * dim(3),
1961 h: dim(2) * dim(3),
1962 q: dim(3),
1963 k: 1,
1964 },
1965 _ => MStrides {
1966 b: heads * seq_q * seq_k,
1967 h: seq_q * seq_k,
1968 q: seq_k,
1969 k: 1,
1970 },
1971 }
1972 } else {
1973 MStrides {
1974 b: heads * seq_q * seq_k,
1975 h: seq_q * seq_k,
1976 q: seq_k,
1977 k: 1,
1978 }
1979 };
1980 let infer_strides =
1981 |shape: &[rlx_ir::shape::Dim], seq_extent: u32| -> (u32, u32, u32) {
1982 let last = shape[shape.len() - 1].unwrap_static() as u32;
1983 if shape.len() == 3 && last == (heads * hd) {
1984 let head_dim_total = heads * hd;
1985 (seq_extent * head_dim_total, hd, head_dim_total)
1986 } else {
1987 (heads * seq_extent * hd, seq_extent * hd, hd)
1988 }
1989 };
1990 let (q_b, q_h, q_s) = infer_strides(q_shape, seq_q);
1991 let (k_b, k_h, k_s) = infer_strides(k_shape, seq_k);
1992 let v_shape = graph.node(v_id).shape.dims();
1993 let (v_b, v_h, v_s) = infer_strides(v_shape, seq_k);
1994 let out_shape = node.shape.dims();
1995 let out_seq = match wrt {
1996 AttentionBwdWrt::Query => seq_q,
1997 AttentionBwdWrt::Key | AttentionBwdWrt::Value => seq_k,
1998 };
1999 let (o_b, o_h, o_s) = infer_strides(out_shape, out_seq);
2000 let wrt_id = match wrt {
2001 AttentionBwdWrt::Query => 0u32,
2002 AttentionBwdWrt::Key => 1u32,
2003 AttentionBwdWrt::Value => 2u32,
2004 };
2005 let p = AttentionBwdParams {
2006 batch,
2007 heads,
2008 seq_q,
2009 seq_k,
2010 head_dim: hd,
2011 q_off: (arena.offset(q_id) / 4) as u32,
2012 k_off: (arena.offset(k_id) / 4) as u32,
2013 v_off: (arena.offset(v_id) / 4) as u32,
2014 dy_off: (arena.offset(dy_id) / 4) as u32,
2015 out_off: (arena.offset(node.id) / 4) as u32,
2016 mask_off,
2017 mask_kind: mask_kind_id,
2018 scale_bits: scale.to_bits(),
2019 window,
2020 wrt: wrt_id,
2021 seq_q_stride: mask_strides.q,
2022 seq_k_stride: mask_strides.k,
2023 mask_batch_stride: mask_strides.b,
2024 mask_head_stride: mask_strides.h,
2025 _pad_mask_0: 0,
2026 _pad_mask_1: 0,
2027 _pad_mask_2: 0,
2028 q_batch_stride: q_b,
2029 q_head_stride: q_h,
2030 q_seq_stride: q_s,
2031 _pad_q: 0,
2032 k_batch_stride: k_b,
2033 k_head_stride: k_h,
2034 k_seq_stride: k_s,
2035 _pad_k: 0,
2036 v_batch_stride: v_b,
2037 v_head_stride: v_h,
2038 v_seq_stride: v_s,
2039 _pad_v: 0,
2040 o_batch_stride: o_b,
2041 o_head_stride: o_h,
2042 o_seq_stride: o_s,
2043 _pad_o: 0,
2044 };
2045 schedule.push(Step::AttentionBackward {
2046 params: p,
2047 mask_buf,
2048 });
2049 let ak = attention_bwd_kernel(&dev.device);
2050 let u = emit_uniform(std::mem::size_of::<AttentionBwdParams>());
2051 let bg = bind_two(&dev.device, ak, &arena.buffer, &u);
2052 uniforms.push(u);
2053 bind_groups.push(bg);
2054 }
2055
2056 Op::Rope { head_dim, n_rot: _ } => {
2057 let x_id = node.inputs[0];
2058 let cos_id = node.inputs[1];
2059 let sin_id = node.inputs[2];
2060 let x_shape = graph.node(x_id).shape.dims();
2061 let last = x_shape.last().map(|d| d.unwrap_static()).unwrap_or(0);
2062 if !last.is_multiple_of(*head_dim) {
2063 panic!(
2064 "rlx-wgpu Rope: last_dim ({last}) must be a multiple \
2065 of head_dim ({head_dim})"
2066 );
2067 }
2068 if head_dim % 2 != 0 {
2069 panic!("rlx-wgpu Rope: head_dim must be even");
2070 }
2071 let total: u32 = x_shape.iter().map(|d| d.unwrap_static() as u32).product();
2072 let seq = x_shape[x_shape.len() - 2].unwrap_static() as u32;
2073 let batch = total / (seq * last as u32).max(1);
2078 let p = RopeParams {
2079 n_total: total,
2080 seq,
2081 head_dim: *head_dim as u32,
2082 half: (*head_dim / 2) as u32,
2083 in_off: (arena.offset(x_id) / 4) as u32,
2084 cos_off: (arena.offset(cos_id) / 4) as u32,
2085 sin_off: (arena.offset(sin_id) / 4) as u32,
2086 out_off: (arena.offset(node.id) / 4) as u32,
2087 last_dim: last as u32,
2088 batch,
2089 seq_stride: seq,
2090 _p2: 0,
2091 };
2092 schedule.push(Step::Rope { params: p });
2093 let rk = rope_kernel(&dev.device);
2094 let u = emit_uniform(std::mem::size_of::<RopeParams>());
2095 let bg = bind_two(&dev.device, rk, &arena.buffer, &u);
2096 uniforms.push(u);
2097 bind_groups.push(bg);
2098 }
2099
2100 Op::Expand { target_shape } => {
2101 let in_id = node.inputs[0];
2102 let in_shape = graph.node(in_id).shape.dims();
2103 let rank = target_shape.len();
2104 if rank != in_shape.len() {
2105 panic!(
2106 "rlx-wgpu Expand: rank mismatch \
2107 (in_rank={}, target_rank={})",
2108 in_shape.len(),
2109 rank
2110 );
2111 }
2112 let out_dims: Vec<u32> = target_shape.iter().map(|&d| d as u32).collect();
2113 let in_dims: Vec<u32> =
2114 in_shape.iter().map(|d| d.unwrap_static() as u32).collect();
2115 let mut in_strides_row = vec![1u32; rank];
2119 for i in (0..rank.saturating_sub(1)).rev() {
2120 in_strides_row[i] = in_strides_row[i + 1] * in_dims[i + 1];
2121 }
2122 let strides_for_out: Vec<u32> = (0..rank)
2123 .map(|i| {
2124 if in_dims[i] == 1 && out_dims[i] != 1 {
2125 0
2126 } else {
2127 in_strides_row[i]
2128 }
2129 })
2130 .collect();
2131
2132 let mut meta_data: Vec<u32> = Vec::with_capacity(rank * 2);
2133 meta_data.extend_from_slice(&out_dims);
2134 meta_data.extend_from_slice(&strides_for_out);
2135 let meta_buf = dev.device.create_buffer(&wgpu::BufferDescriptor {
2136 label: Some("rlx-wgpu expand meta"),
2137 size: (meta_data.len() * 4).max(4) as u64,
2138 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
2139 mapped_at_creation: false,
2140 });
2141 dev.queue
2142 .write_buffer(&meta_buf, 0, bytemuck::cast_slice(&meta_data));
2143 let meta_idx = meta_buffers.len();
2144 meta_buffers.push(meta_buf);
2145
2146 let bucket_outermost = if in_dims[0] == out_dims[0] {
2152 1u32
2153 } else {
2154 0u32
2155 };
2156 let p = ExpandParams {
2157 rank: rank as u32,
2158 out_total: elems,
2159 in_off: (arena.offset(in_id) / 4) as u32,
2160 out_off: (arena.offset(node.id) / 4) as u32,
2161 bucket_outermost,
2162 out_dim_0: out_dims[0],
2163 _p2: 0,
2164 _p3: 0,
2165 };
2166 schedule.push(Step::Expand {
2167 params: p,
2168 meta_idx,
2169 });
2170 let ek = expand_kernel(&dev.device);
2171 let u = emit_uniform(std::mem::size_of::<ExpandParams>());
2172 let bg = dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
2173 label: Some("rlx-wgpu expand bg"),
2174 layout: &ek.bgl,
2175 entries: &[
2176 wgpu::BindGroupEntry {
2177 binding: 0,
2178 resource: arena.buffer.as_entire_binding(),
2179 },
2180 wgpu::BindGroupEntry {
2181 binding: 1,
2182 resource: u.as_entire_binding(),
2183 },
2184 wgpu::BindGroupEntry {
2185 binding: 2,
2186 resource: meta_buffers[meta_idx].as_entire_binding(),
2187 },
2188 ],
2189 });
2190 uniforms.push(u);
2191 bind_groups.push(bg);
2192 }
2193
2194 Op::Gather { axis } => {
2195 let table_id = node.inputs[0];
2196 let idx_id = node.inputs[1];
2197 if *axis == 0 {
2198 let table_shape = graph.node(table_id).shape.dims();
2199 let idx_shape = graph.node(idx_id).shape.dims();
2200 let vocab = table_shape[0].unwrap_static() as u32;
2201 let dim: u32 = table_shape[1..]
2202 .iter()
2203 .map(|d| d.unwrap_static() as u32)
2204 .product::<u32>()
2205 .max(1);
2206 let n_idx: u32 =
2207 idx_shape.iter().map(|d| d.unwrap_static() as u32).product();
2208 let p = GatherParams {
2209 n_out: elems,
2210 n_idx,
2211 dim,
2212 vocab,
2213 in_off: (arena.offset(table_id) / 4) as u32,
2214 idx_off: (arena.offset(idx_id) / 4) as u32,
2215 out_off: (arena.offset(node.id) / 4) as u32,
2216 _p0: 0,
2217 };
2218 schedule.push(Step::Gather { params: p });
2219 let gk = gather_kernel(&dev.device);
2220 let u = emit_uniform(std::mem::size_of::<GatherParams>());
2221 let bg = bind_two(&dev.device, gk, &arena.buffer, &u);
2222 uniforms.push(u);
2223 bind_groups.push(bg);
2224 } else {
2225 let table_shape = graph.node(table_id).shape.dims();
2226 let idx_shape = graph.node(idx_id).shape.dims();
2227 let outer: u32 = table_shape[..*axis]
2228 .iter()
2229 .map(|d| d.unwrap_static() as u32)
2230 .product::<u32>()
2231 .max(1);
2232 let trailing: u32 = table_shape[*axis + 1..]
2233 .iter()
2234 .map(|d| d.unwrap_static() as u32)
2235 .product::<u32>()
2236 .max(1);
2237 let axis_dim = table_shape[*axis].unwrap_static() as u32;
2238 let num_idx: u32 =
2239 idx_shape.iter().map(|d| d.unwrap_static() as u32).product();
2240 let total = outer * num_idx * trailing;
2241 let p = GatherAxisParams {
2242 total,
2243 outer,
2244 axis_dim,
2245 num_idx,
2246 trailing,
2247 table_off: (arena.offset(table_id) / 4) as u32,
2248 idx_off: (arena.offset(idx_id) / 4) as u32,
2249 out_off: (arena.offset(node.id) / 4) as u32,
2250 };
2251 schedule.push(Step::GatherAxis { params: p });
2252 let gk = gather_axis_kernel(&dev.device);
2253 let u = emit_uniform(std::mem::size_of::<GatherAxisParams>());
2254 let bg = bind_two(&dev.device, gk, &arena.buffer, &u);
2255 uniforms.push(u);
2256 bind_groups.push(bg);
2257 }
2258 }
2259
2260 Op::FusedMatMulBiasAct { activation } => {
2261 let a_id = node.inputs[0];
2264 let b_id = node.inputs[1];
2265 let bias_id = node.inputs[2];
2266 let a_shape = graph.node(a_id).shape.dims();
2267 let b_shape = graph.node(b_id).shape.dims();
2268 let out_shape = node.shape.dims();
2269 let (m, k, n) =
2270 if a_shape.len() == 2 && b_shape.len() == 2 && out_shape.len() == 2 {
2271 (
2272 a_shape[0].unwrap_static() as u32,
2273 a_shape[1].unwrap_static() as u32,
2274 b_shape[1].unwrap_static() as u32,
2275 )
2276 } else if a_shape.len() >= 2
2277 && b_shape.len() == 2
2278 && out_shape.len() == a_shape.len()
2279 {
2280 let leading: usize = a_shape[..a_shape.len() - 2]
2281 .iter()
2282 .map(|d| d.unwrap_static())
2283 .product();
2284 let m_inner = a_shape[a_shape.len() - 2].unwrap_static();
2285 let k_inner = a_shape[a_shape.len() - 1].unwrap_static();
2286 let n_inner = b_shape[1].unwrap_static();
2287 ((leading * m_inner) as u32, k_inner as u32, n_inner as u32)
2288 } else {
2289 panic!(
2290 "rlx-wgpu FusedMatMulBiasAct: unsupported shapes \
2291 a={a_shape:?} b={b_shape:?}"
2292 );
2293 };
2294 let act_id = match activation {
2295 None => 0xFFFFu32,
2296 Some(a) => activation_op_id(*a),
2297 };
2298 let b_is_param = traces_to_param(&graph, b_id);
2299 let compute_precision =
2300 derive_matmul_compute(&dev.device, &graph, a_id, b_id, m, k, n);
2301
2302 let mqk_eligible = act_id == 0xFFFFu32
2311 && (compute_precision == MatmulCompute::F32
2312 || compute_precision == MatmulCompute::CoopF32);
2313 if mqk_eligible && let Some(&(q_id, k_id_n, v_id)) = qkv_split.get(&node.id) {
2314 let head_width = n / 3;
2315 let coop = compute_precision == MatmulCompute::CoopF32;
2316 let mqk_kernel = if coop {
2317 matmul_qkv_coop_f32_kernel(&dev.device)
2318 .expect("coop matmul_qkv kernel: hardware feature was checked but kernel missing")
2319 } else {
2320 matmul_qkv_kernel(&dev.device)
2321 };
2322 let p = MatmulQkvParams {
2323 m,
2324 k,
2325 n,
2326 a_off: (arena.offset(a_id) / 4) as u32,
2327 b_off: (arena.offset(b_id) / 4) as u32,
2328 q_off: (arena.offset(q_id) / 4) as u32,
2329 k_off: (arena.offset(k_id_n) / 4) as u32,
2330 v_off: (arena.offset(v_id) / 4) as u32,
2331 head_width,
2332 has_bias: 1,
2333 bias_off: (arena.offset(bias_id) / 4) as u32,
2334 _p0: 0,
2335 _p1: 0,
2336 _p2: 0,
2337 _p3: 0,
2338 _p4: 0,
2339 };
2340 schedule.push(Step::MatmulQkv { params: p, coop });
2341 let u = emit_uniform(std::mem::size_of::<MatmulQkvParams>());
2342 let bg = bind_two(&dev.device, mqk_kernel, &arena.buffer, &u);
2343 uniforms.push(u);
2344 bind_groups.push(bg);
2345 } else {
2346 schedule.push(Step::Matmul {
2347 m,
2348 k,
2349 n,
2350 batch: 1,
2351 a_batch_stride: 0,
2352 b_batch_stride: 0,
2353 c_batch_stride: 0,
2354 a_off_f32: (arena.offset(a_id) / 4) as u32,
2355 b_off_f32: (arena.offset(b_id) / 4) as u32,
2356 c_off_f32: (arena.offset(node.id) / 4) as u32,
2357 has_bias: 1,
2358 bias_off_f32: (arena.offset(bias_id) / 4) as u32,
2359 act_id,
2360 b_is_param,
2361 compute_precision,
2362 });
2363 let u = emit_uniform(std::mem::size_of::<MatmulParams>());
2364 let bg = build_matmul_bind_group(
2365 &dev.device,
2366 mm_k,
2367 mm_w,
2368 &mm_f16w,
2369 &mm_f16c,
2370 &mm_coop,
2371 &mm_coop_f32,
2372 &arena,
2373 &u,
2374 b_is_param,
2375 compute_precision,
2376 );
2377 uniforms.push(u);
2378 bind_groups.push(bg);
2379 }
2380 }
2381
2382 Op::DotGeneral { .. } => {
2383 panic!(
2388 "rlx-wgpu DotGeneral: leaked past unfusion pass — \
2389 check unfuse.rs::expand_dot_general for missing patterns"
2390 );
2391 }
2392
2393 Op::Sample {
2394 top_k,
2395 top_p,
2396 temperature,
2397 seed,
2398 } => {
2399 let in_id = node.inputs[0];
2400 let in_shape = graph.node(in_id).shape.dims();
2401 let inner = in_shape[in_shape.len() - 1].unwrap_static() as u32;
2402 let total: u32 = in_shape.iter().map(|d| d.unwrap_static() as u32).product();
2403 let outer = total / inner.max(1);
2404 let is_greedy = *top_k == 0
2407 && (*top_p - 1.0).abs() < 1e-6
2408 && (*temperature - 1.0).abs() < 1e-6;
2409 if is_greedy {
2410 let p = ArgmaxParams {
2411 outer,
2412 inner,
2413 in_off: (arena.offset(in_id) / 4) as u32,
2414 out_off: (arena.offset(node.id) / 4) as u32,
2415 _p0: 0,
2416 _p1: 0,
2417 _p2: 0,
2418 _p3: 0,
2419 };
2420 schedule.push(Step::Argmax { params: p });
2421 let amk = argmax_kernel(&dev.device);
2422 let u = emit_uniform(std::mem::size_of::<ArgmaxParams>());
2423 let bg = bind_two(&dev.device, amk, &arena.buffer, &u);
2424 uniforms.push(u);
2425 bind_groups.push(bg);
2426 } else {
2427 let p = SampleParams {
2428 outer,
2429 inner,
2430 in_off: (arena.offset(in_id) / 4) as u32,
2431 out_off: (arena.offset(node.id) / 4) as u32,
2432 top_k: *top_k as u32,
2433 top_p_bits: top_p.to_bits(),
2434 temp_bits: temperature.to_bits(),
2435 seed_lo: *seed as u32,
2436 seed_hi: (*seed >> 32) as u32,
2437 _p0: 0,
2438 _p1: 0,
2439 _p2: 0,
2440 };
2441 schedule.push(Step::Sample { params: p });
2442 let sk = sample_kernel(&dev.device);
2443 let u = emit_uniform(std::mem::size_of::<SampleParams>());
2444 let bg = bind_two(&dev.device, sk, &arena.buffer, &u);
2445 uniforms.push(u);
2446 bind_groups.push(bg);
2447 }
2448 }
2449
2450 Op::Pool {
2451 kind,
2452 kernel_size,
2453 stride,
2454 padding,
2455 } => {
2456 let in_shape = graph.node(node.inputs[0]).shape.dims();
2457 let out_shape = node.shape.dims();
2458 let op_id: u32 = match kind {
2459 ReduceOp::Sum => 0,
2460 ReduceOp::Mean => 1,
2461 ReduceOp::Max => 2,
2462 ReduceOp::Min => 3,
2463 ReduceOp::Prod => 4,
2464 };
2465 match (kernel_size.len(), in_shape.len(), out_shape.len()) {
2466 (1, 3, 3) => {
2467 let p = Pool1dParams {
2468 n: in_shape[0].unwrap_static() as u32,
2469 c: in_shape[1].unwrap_static() as u32,
2470 l: in_shape[2].unwrap_static() as u32,
2471 l_out: out_shape[2].unwrap_static() as u32,
2472 kl: kernel_size[0] as u32,
2473 sl: stride.first().copied().unwrap_or(1) as u32,
2474 pl: padding.first().copied().unwrap_or(0) as u32,
2475 op: op_id,
2476 in_off: (arena.offset(node.inputs[0]) / 4) as u32,
2477 out_off: (arena.offset(node.id) / 4) as u32,
2478 _p0: 0,
2479 _p1: 0,
2480 _p2: 0,
2481 _p3: 0,
2482 _p4: 0,
2483 _p5: 0,
2484 };
2485 schedule.push(Step::Pool1d { params: p });
2486 let pk = pool1d_kernel(&dev.device);
2487 let u = emit_uniform(std::mem::size_of::<Pool1dParams>());
2488 let bg = bind_two(&dev.device, pk, &arena.buffer, &u);
2489 uniforms.push(u);
2490 bind_groups.push(bg);
2491 }
2492 (2, 4, 4) => {
2493 let p = Pool2dParams {
2494 n: in_shape[0].unwrap_static() as u32,
2495 c: in_shape[1].unwrap_static() as u32,
2496 h: in_shape[2].unwrap_static() as u32,
2497 w: in_shape[3].unwrap_static() as u32,
2498 h_out: out_shape[2].unwrap_static() as u32,
2499 w_out: out_shape[3].unwrap_static() as u32,
2500 kh: kernel_size[0] as u32,
2501 kw: kernel_size[1] as u32,
2502 sh: stride.first().copied().unwrap_or(1) as u32,
2503 sw: stride.get(1).copied().unwrap_or(1) as u32,
2504 ph: padding.first().copied().unwrap_or(0) as u32,
2505 pw: padding.get(1).copied().unwrap_or(0) as u32,
2506 op: op_id,
2507 in_off: (arena.offset(node.inputs[0]) / 4) as u32,
2508 out_off: (arena.offset(node.id) / 4) as u32,
2509 _p0: 0,
2510 _p1: 0,
2511 _p2: 0,
2512 };
2513 schedule.push(Step::Pool2d { params: p });
2514 let pk = pool2d_kernel(&dev.device);
2515 let u = emit_uniform(std::mem::size_of::<Pool2dParams>());
2516 let bg = bind_two(&dev.device, pk, &arena.buffer, &u);
2517 uniforms.push(u);
2518 bind_groups.push(bg);
2519 }
2520 (3, 5, 5) => {
2521 let p = Pool3dParams {
2522 n: in_shape[0].unwrap_static() as u32,
2523 c: in_shape[1].unwrap_static() as u32,
2524 d: in_shape[2].unwrap_static() as u32,
2525 h: in_shape[3].unwrap_static() as u32,
2526 w: in_shape[4].unwrap_static() as u32,
2527 d_out: out_shape[2].unwrap_static() as u32,
2528 h_out: out_shape[3].unwrap_static() as u32,
2529 w_out: out_shape[4].unwrap_static() as u32,
2530 kd: kernel_size[0] as u32,
2531 kh: kernel_size[1] as u32,
2532 kw: kernel_size[2] as u32,
2533 sd: stride.first().copied().unwrap_or(1) as u32,
2534 sh: stride.get(1).copied().unwrap_or(1) as u32,
2535 sw: stride.get(2).copied().unwrap_or(1) as u32,
2536 pd: padding.first().copied().unwrap_or(0) as u32,
2537 ph: padding.get(1).copied().unwrap_or(0) as u32,
2538 pw: padding.get(2).copied().unwrap_or(0) as u32,
2539 op: op_id,
2540 in_off: (arena.offset(node.inputs[0]) / 4) as u32,
2541 out_off: (arena.offset(node.id) / 4) as u32,
2542 _p0: 0,
2543 _p1: 0,
2544 };
2545 schedule.push(Step::Pool3d { params: p });
2546 let pk = pool3d_kernel(&dev.device);
2547 let u = emit_uniform(std::mem::size_of::<Pool3dParams>());
2548 let bg = bind_two(&dev.device, pk, &arena.buffer, &u);
2549 uniforms.push(u);
2550 bind_groups.push(bg);
2551 }
2552 (k, n, m) => panic!(
2553 "rlx-wgpu Pool: kernel-rank {k} with input rank {n} / \
2554 output rank {m} not supported (use 1D/2D/3D NCHW)"
2555 ),
2556 }
2557 }
2558
2559 Op::Conv {
2560 kernel_size,
2561 stride,
2562 padding,
2563 dilation,
2564 groups,
2565 } => {
2566 let in_shape = graph.node(node.inputs[0]).shape.dims();
2567 let w_shape = graph.node(node.inputs[1]).shape.dims();
2568 let out_shape = node.shape.dims();
2569 let s = |i: usize| stride.get(i).copied().unwrap_or(1) as u32;
2570 let p = |i: usize| padding.get(i).copied().unwrap_or(0) as u32;
2571 let d = |i: usize| dilation.get(i).copied().unwrap_or(1) as u32;
2572 match (
2573 kernel_size.len(),
2574 in_shape.len(),
2575 w_shape.len(),
2576 out_shape.len(),
2577 ) {
2578 (1, 3, 3, 3) => {
2579 let p1 = Conv1dParams {
2580 n: in_shape[0].unwrap_static() as u32,
2581 c_in: in_shape[1].unwrap_static() as u32,
2582 c_out: out_shape[1].unwrap_static() as u32,
2583 l: in_shape[2].unwrap_static() as u32,
2584 l_out: out_shape[2].unwrap_static() as u32,
2585 kl: kernel_size[0] as u32,
2586 sl: s(0),
2587 pl: p(0),
2588 dl: d(0),
2589 groups: *groups as u32,
2590 in_off: (arena.offset(node.inputs[0]) / 4) as u32,
2591 w_off: (arena.offset(node.inputs[1]) / 4) as u32,
2592 out_off: (arena.offset(node.id) / 4) as u32,
2593 _p0: 0,
2594 _p1: 0,
2595 _p2: 0,
2596 };
2597 schedule.push(Step::Conv1d { params: p1 });
2598 let ck = conv1d_kernel(&dev.device);
2599 let u = emit_uniform(std::mem::size_of::<Conv1dParams>());
2600 let bg = bind_two(&dev.device, ck, &arena.buffer, &u);
2601 uniforms.push(u);
2602 bind_groups.push(bg);
2603 }
2604 (2, 4, 4, 4) => {
2605 let p2 = Conv2dParams {
2606 n: in_shape[0].unwrap_static() as u32,
2607 c_in: in_shape[1].unwrap_static() as u32,
2608 c_out: out_shape[1].unwrap_static() as u32,
2609 h: in_shape[2].unwrap_static() as u32,
2610 w: in_shape[3].unwrap_static() as u32,
2611 h_out: out_shape[2].unwrap_static() as u32,
2612 w_out: out_shape[3].unwrap_static() as u32,
2613 kh: kernel_size[0] as u32,
2614 kw: kernel_size[1] as u32,
2615 sh: s(0),
2616 sw: s(1),
2617 ph: p(0),
2618 pw: p(1),
2619 dh: d(0),
2620 dw: d(1),
2621 groups: *groups as u32,
2622 in_off: (arena.offset(node.inputs[0]) / 4) as u32,
2623 w_off: (arena.offset(node.inputs[1]) / 4) as u32,
2624 out_off: (arena.offset(node.id) / 4) as u32,
2625 };
2626 schedule.push(Step::Conv2d { params: p2 });
2627 let ck = conv2d_kernel(&dev.device);
2628 let u = emit_uniform(std::mem::size_of::<Conv2dParams>());
2629 let bg = bind_two(&dev.device, ck, &arena.buffer, &u);
2630 uniforms.push(u);
2631 bind_groups.push(bg);
2632 }
2633 (3, 5, 5, 5) => {
2634 let p3 = Conv3dParams {
2635 n: in_shape[0].unwrap_static() as u32,
2636 c_in: in_shape[1].unwrap_static() as u32,
2637 c_out: out_shape[1].unwrap_static() as u32,
2638 d: in_shape[2].unwrap_static() as u32,
2639 h: in_shape[3].unwrap_static() as u32,
2640 w: in_shape[4].unwrap_static() as u32,
2641 d_out: out_shape[2].unwrap_static() as u32,
2642 h_out: out_shape[3].unwrap_static() as u32,
2643 w_out: out_shape[4].unwrap_static() as u32,
2644 kd: kernel_size[0] as u32,
2645 kh: kernel_size[1] as u32,
2646 kw: kernel_size[2] as u32,
2647 sd: s(0),
2648 sh: s(1),
2649 sw: s(2),
2650 pd: p(0),
2651 ph: p(1),
2652 pw: p(2),
2653 dd: d(0),
2654 dh: d(1),
2655 dw: d(2),
2656 groups: *groups as u32,
2657 in_off: (arena.offset(node.inputs[0]) / 4) as u32,
2658 w_off: (arena.offset(node.inputs[1]) / 4) as u32,
2659 out_off: (arena.offset(node.id) / 4) as u32,
2660 _p0: 0,
2661 };
2662 schedule.push(Step::Conv3d { params: p3 });
2663 let ck = conv3d_kernel(&dev.device);
2664 let u = emit_uniform(std::mem::size_of::<Conv3dParams>());
2665 let bg = bind_two(&dev.device, ck, &arena.buffer, &u);
2666 uniforms.push(u);
2667 bind_groups.push(bg);
2668 }
2669 (k, ni, wi, mi) => panic!(
2670 "rlx-wgpu Conv: rank kernel={k} in={ni} weight={wi} out={mi} \
2671 not supported (use 1D/2D/3D NCHW)"
2672 ),
2673 }
2674 }
2675
2676 Op::Cumsum { axis, exclusive } => {
2677 let in_id = node.inputs[0];
2678 let in_shape = graph.node(in_id).shape.dims();
2679 let last = (in_shape.len() - 1) as i32;
2680 if *axis != -1 && *axis != last {
2681 panic!("rlx-wgpu Cumsum: only last-axis wired (got axis={axis})");
2682 }
2683 let inner = in_shape[in_shape.len() - 1].unwrap_static() as u32;
2684 let total: u32 = in_shape.iter().map(|d| d.unwrap_static() as u32).product();
2685 let outer = total / inner.max(1);
2686 let p = CumsumParams {
2687 outer,
2688 inner,
2689 in_off: (arena.offset(in_id) / 4) as u32,
2690 out_off: (arena.offset(node.id) / 4) as u32,
2691 exclusive: if *exclusive { 1 } else { 0 },
2692 _p0: 0,
2693 _p1: 0,
2694 _p2: 0,
2695 };
2696 schedule.push(Step::Cumsum { params: p });
2697 let ck2 = cumsum_kernel(&dev.device);
2698 let u = emit_uniform(std::mem::size_of::<CumsumParams>());
2699 let bg = bind_two(&dev.device, ck2, &arena.buffer, &u);
2700 uniforms.push(u);
2701 bind_groups.push(bg);
2702 }
2703 Op::Fft { inverse, norm } => {
2704 let in_id = node.inputs[0];
2705 let in_shape = graph.node(in_id).shape.clone();
2706 let meta = rlx_ir::fft::fft_meta(&in_shape);
2707 let dtype = in_shape.dtype();
2708 let use_gpu = rlx_ir::fft::gpu_fft_native_eligible(dtype, meta.n_complex)
2709 && meta.n_complex >= 2;
2710 let scale = norm.output_scale(meta.n_complex, *inverse) as f32;
2711 if use_gpu {
2712 schedule.push(Step::FftGpu {
2713 src_off: (arena.offset(in_id) / 4) as u32,
2714 dst_off: (arena.offset(node.id) / 4) as u32,
2715 outer: meta.outer as u32,
2716 n: meta.n_complex as u32,
2717 inverse: if *inverse { 1 } else { 0 },
2718 norm_scale: scale,
2719 });
2720 fft_gpu_steps.push(crate::fft_dispatch::FftGpuResources::new(
2721 &dev.device,
2722 &arena.buffer,
2723 ));
2724 } else {
2725 schedule.push(Step::FftHost {
2726 src_byte_off: arena.offset(in_id) as u32,
2727 dst_byte_off: arena.offset(node.id) as u32,
2728 outer: meta.outer as u32,
2729 n_complex: meta.n_complex as u32,
2730 inverse: *inverse,
2731 norm_tag: norm.tag(),
2732 dtype_tag: fft_dtype_tag(dtype),
2733 });
2734 }
2735 }
2736 Op::SelectiveScan { state_size } => {
2737 if *state_size > 256 {
2738 panic!(
2739 "rlx-wgpu SelectiveScan: state_size {} exceeds compile-time \
2740 cap of 256 (kernel uses fixed-size private array)",
2741 state_size
2742 );
2743 }
2744 let x_id = node.inputs[0];
2745 let dt_id = node.inputs[1];
2746 let a_id = node.inputs[2];
2747 let b_id = node.inputs[3];
2748 let c_id = node.inputs[4];
2749 let in_dims = graph.node(x_id).shape.dims();
2750 let seq = in_dims[1].unwrap_static() as u32;
2751 let p = SelectiveScanParams {
2752 batch: in_dims[0].unwrap_static() as u32,
2753 seq,
2754 hidden: in_dims[2].unwrap_static() as u32,
2755 state_size: *state_size as u32,
2756 x_off: (arena.offset(x_id) / 4) as u32,
2757 delta_off: (arena.offset(dt_id) / 4) as u32,
2758 a_off: (arena.offset(a_id) / 4) as u32,
2759 b_off: (arena.offset(b_id) / 4) as u32,
2760 c_off: (arena.offset(c_id) / 4) as u32,
2761 out_off: (arena.offset(node.id) / 4) as u32,
2762 seq_stride: seq,
2765 _p1: 0,
2766 _p2: 0,
2767 _p3: 0,
2768 _p4: 0,
2769 _p5: 0,
2770 };
2771 schedule.push(Step::SelectiveScan { params: p });
2772 let ssk = selective_scan_kernel(&dev.device);
2773 let u = emit_uniform(std::mem::size_of::<SelectiveScanParams>());
2774 let bg = bind_two(&dev.device, ssk, &arena.buffer, &u);
2775 uniforms.push(u);
2776 bind_groups.push(bg);
2777 }
2778 Op::GatedDeltaNet {
2779 state_size,
2780 carry_state,
2781 } => {
2782 if *state_size > rlx_cpu::gdn::GDN_MAX_STATE {
2783 panic!(
2784 "rlx-wgpu GatedDeltaNet: state_size {state_size} > {}",
2785 rlx_cpu::gdn::GDN_MAX_STATE
2786 );
2787 }
2788 let q_id = node.inputs[0];
2789 let q_shape = &graph.node(q_id).shape;
2790 let state_off = if *carry_state {
2791 arena.offset(node.inputs[5])
2792 } else {
2793 0
2794 };
2795 schedule.push(Step::GatedDeltaNet {
2796 q_byte_off: arena.offset(q_id) as u32,
2797 k_byte_off: arena.offset(node.inputs[1]) as u32,
2798 v_byte_off: arena.offset(node.inputs[2]) as u32,
2799 g_byte_off: arena.offset(node.inputs[3]) as u32,
2800 beta_byte_off: arena.offset(node.inputs[4]) as u32,
2801 state_byte_off: state_off as u32,
2802 dst_byte_off: arena.offset(node.id) as u32,
2803 batch: q_shape.dim(0).unwrap_static() as u32,
2804 seq: q_shape.dim(1).unwrap_static() as u32,
2805 heads: q_shape.dim(2).unwrap_static() as u32,
2806 state_size: *state_size as u32,
2807 use_carry: *carry_state,
2808 });
2809 if gguf_host_pad.is_none() {
2810 let bk = binary_kernel(&dev.device);
2811 let u = emit_uniform(256);
2812 gguf_host_pad =
2813 Some((u.clone(), bind_two(&dev.device, bk, &arena.buffer, &u)));
2814 }
2815 let (u, bg) = gguf_host_pad.as_ref().unwrap();
2816 uniforms.push(u.clone());
2817 bind_groups.push(bg.clone());
2818 }
2819 Op::Custom { name, attrs, .. } => match name.as_str() {
2820 "llada2.group_limited_gate" => {
2821 let sig_id = node.inputs[0];
2822 let route_id = node.inputs[1];
2823 let n_elems = graph.node(sig_id).shape.num_elements().unwrap() as u32;
2824 let mut attr_buf = [0u8; 20];
2825 let n = attrs.len().min(20);
2826 attr_buf[..n].copy_from_slice(&attrs[..n]);
2827 schedule.push(Step::Llada2GroupLimitedGate {
2828 sig_byte_off: arena.offset(sig_id) as u32,
2829 route_byte_off: arena.offset(route_id) as u32,
2830 out_byte_off: arena.offset(node.id) as u32,
2831 n_elems,
2832 attrs: attr_buf,
2833 });
2834 }
2835 "umap.knn" => {
2836 let pw_id = node.inputs[0];
2837 let pw_shape = graph.node(pw_id).shape.dims();
2838 let n = pw_shape[0].unwrap_static() as u32;
2839 let k = if attrs.len() >= 4 {
2840 u32::from_le_bytes(attrs[..4].try_into().unwrap())
2841 } else {
2842 panic!("rlx-wgpu: umap.knn attrs missing k");
2843 };
2844 let pw_off = arena.offset(pw_id) as u32;
2845 let out_off = arena.offset(node.id) as u32;
2846 if n as usize >= crate::umap_knn_host::UMAP_KNN_GPU_MIN_N {
2847 let p = UmapKnnParams {
2848 n,
2849 k,
2850 pw_off: pw_off / 4,
2851 out_off: out_off / 4,
2852 _p0: 0,
2853 _p1: 0,
2854 _p2: 0,
2855 };
2856 schedule.push(Step::UmapKnn { params: p });
2857 let uk = umap_knn_kernel(&dev.device);
2858 let u = emit_uniform(std::mem::size_of::<UmapKnnParams>());
2859 let bg = bind_two(&dev.device, uk, &arena.buffer, &u);
2860 uniforms.push(u);
2861 bind_groups.push(bg);
2862 } else {
2863 schedule.push(Step::UmapKnnHost {
2864 pairwise_byte_off: pw_off,
2865 out_byte_off: out_off,
2866 n,
2867 k,
2868 });
2869 }
2870 }
2871 other => panic!("rlx-wgpu: unsupported Op::Custom('{other}')"),
2872 },
2873 Op::GroupedMatMul => {
2874 let in_id = node.inputs[0];
2876 let w_id = node.inputs[1];
2877 let idx_id = node.inputs[2];
2878 let in_dims = graph.node(in_id).shape.dims();
2879 let w_dims = graph.node(w_id).shape.dims();
2880 let m = in_dims[0].unwrap_static() as u32;
2881 let k = in_dims[1].unwrap_static() as u32;
2882 let n = w_dims[2].unwrap_static() as u32;
2883 let ne = w_dims[0].unwrap_static() as u32;
2884 let p = GroupedMatmulParams {
2885 m,
2886 k,
2887 n,
2888 num_experts: ne,
2889 in_off: (arena.offset(in_id) / 4) as u32,
2890 w_off: (arena.offset(w_id) / 4) as u32,
2891 idx_off: (arena.offset(idx_id) / 4) as u32,
2892 out_off: (arena.offset(node.id) / 4) as u32,
2893 };
2894 schedule.push(Step::GroupedMatmul { params: p });
2895 let gk = grouped_matmul_kernel(&dev.device);
2896 let u = emit_uniform(std::mem::size_of::<GroupedMatmulParams>());
2897 let bg = bind_two(&dev.device, gk, &arena.buffer, &u);
2898 uniforms.push(u);
2899 bind_groups.push(bg);
2900 }
2901 Op::DequantGroupedMatMul { scheme } => {
2902 let in_id = node.inputs[0];
2903 let w_id = node.inputs[1];
2904 let idx_id = node.inputs[2];
2905 let in_dims = graph.node(in_id).shape.dims();
2906 let out_dims = node.shape.dims();
2907 let m = in_dims[0].unwrap_static() as u32;
2908 let k = in_dims[1].unwrap_static() as u32;
2909 let n = out_dims[out_dims.len() - 1].unwrap_static() as u32;
2910 let block_elems = scheme.gguf_block_size() as usize;
2911 let block_bytes = scheme.gguf_block_bytes() as usize;
2912 let slab_bytes = (k as usize * n as usize) / block_elems * block_bytes;
2913 let total_bytes = graph.node(w_id).shape.num_elements().unwrap();
2914 let ne = (total_bytes / slab_bytes.max(1)) as u32;
2915 schedule.push(Step::DequantGroupedMatmulGguf {
2916 m,
2917 k,
2918 n,
2919 num_experts: ne,
2920 scheme_id: crate::gguf_host::gguf_scheme_id(*scheme),
2921 x_byte_off: arena.offset(in_id) as u32,
2922 w_byte_off: arena.offset(w_id) as u32,
2923 idx_byte_off: arena.offset(idx_id) as u32,
2924 out_byte_off: arena.offset(node.id) as u32,
2925 });
2926 if gguf_host_pad.is_none() {
2927 let bk = binary_kernel(&dev.device);
2928 let u = emit_uniform(256);
2929 gguf_host_pad =
2930 Some((u.clone(), bind_two(&dev.device, bk, &arena.buffer, &u)));
2931 }
2932 let (u, bg) = gguf_host_pad.as_ref().unwrap();
2933 uniforms.push(u.clone());
2934 bind_groups.push(bg.clone());
2935 }
2936 Op::TopK { k } => {
2937 let in_id = node.inputs[0];
2938 let in_dims = graph.node(in_id).shape.dims();
2939 let inner = in_dims.last().unwrap().unwrap_static() as u32;
2940 let outer: u32 = in_dims[..in_dims.len() - 1]
2941 .iter()
2942 .map(|d| d.unwrap_static() as u32)
2943 .product::<u32>()
2944 .max(1);
2945 let p = TopKParams {
2946 outer,
2947 inner,
2948 k: *k as u32,
2949 in_off: (arena.offset(in_id) / 4) as u32,
2950 out_off: (arena.offset(node.id) / 4) as u32,
2951 _p0: 0,
2952 _p1: 0,
2953 _p2: 0,
2954 };
2955 schedule.push(Step::TopK { params: p });
2956 let tk = topk_kernel(&dev.device);
2957 let u = emit_uniform(std::mem::size_of::<TopKParams>());
2958 let bg = bind_two(&dev.device, tk, &arena.buffer, &u);
2959 uniforms.push(u);
2960 bind_groups.push(bg);
2961 }
2962 Op::ScatterAdd => {
2963 let upd_id = node.inputs[0];
2968 let idx_id = node.inputs[1];
2969 let upd_dims = graph.node(upd_id).shape.dims();
2970 let out_dims = node.shape.dims();
2971 let num_updates = upd_dims[0].unwrap_static() as u32;
2972 let trailing: u32 = upd_dims
2973 .iter()
2974 .skip(1)
2975 .map(|d| d.unwrap_static() as u32)
2976 .product::<u32>()
2977 .max(1);
2978 let out_dim = out_dims[0].unwrap_static() as u32;
2979 let out_total = out_dim * trailing;
2980
2981 let common = ScatterAddParams {
2982 op: 0,
2983 out_off: (arena.offset(node.id) / 4) as u32,
2984 upd_off: (arena.offset(upd_id) / 4) as u32,
2985 idx_off: (arena.offset(idx_id) / 4) as u32,
2986 out_total,
2987 num_updates,
2988 trailing,
2989 out_dim,
2990 };
2991 let sk = scatter_add_kernel(&dev.device);
2992
2993 schedule.push(Step::ScatterAdd { params: common });
2995 let u0 = emit_uniform(std::mem::size_of::<ScatterAddParams>());
2996 let bg0 = bind_two(&dev.device, sk, &arena.buffer, &u0);
2997 uniforms.push(u0);
2998 bind_groups.push(bg0);
2999
3000 let mut acc = common;
3002 acc.op = 1;
3003 schedule.push(Step::ScatterAdd { params: acc });
3004 let u1 = emit_uniform(std::mem::size_of::<ScatterAddParams>());
3005 let bg1 = bind_two(&dev.device, sk, &arena.buffer, &u1);
3006 uniforms.push(u1);
3007 bind_groups.push(bg1);
3008 }
3009 Op::FusedResidualLN { has_bias, eps } => {
3010 let x_id = node.inputs[0];
3012 let r_id = node.inputs[1];
3013 let (bias_id, g_id, b_id) = if *has_bias {
3014 (node.inputs[2], node.inputs[3], node.inputs[4])
3015 } else {
3016 (x_id, node.inputs[2], node.inputs[3]) };
3018 let in_dims = node.shape.dims();
3019 let inner = in_dims[in_dims.len() - 1].unwrap_static() as u32;
3020 let total: u32 = in_dims.iter().map(|d| d.unwrap_static() as u32).product();
3021 let outer = total / inner.max(1);
3022 let p = FusedResidualLnParams {
3023 outer,
3024 inner,
3025 in_off: (arena.offset(x_id) / 4) as u32,
3026 residual_off: (arena.offset(r_id) / 4) as u32,
3027 bias_off: (arena.offset(bias_id) / 4) as u32,
3028 gamma_off: (arena.offset(g_id) / 4) as u32,
3029 beta_off: (arena.offset(b_id) / 4) as u32,
3030 out_off: (arena.offset(node.id) / 4) as u32,
3031 eps_bits: eps.to_bits(),
3032 has_bias: if *has_bias { 1 } else { 0 },
3033 _p0: 0,
3034 _p1: 0,
3035 };
3036 schedule.push(Step::FusedResidualLn { params: p });
3037 let frk = fused_residual_ln_kernel(&dev.device);
3038 let u = emit_uniform(std::mem::size_of::<FusedResidualLnParams>());
3039 let bg = bind_two(&dev.device, frk, &arena.buffer, &u);
3040 uniforms.push(u);
3041 bind_groups.push(bg);
3042 }
3043 Op::FusedResidualRmsNorm { has_bias, eps } => {
3044 let x_id = node.inputs[0];
3045 let r_id = node.inputs[1];
3046 let (bias_id, g_id, b_id) = if *has_bias {
3047 (node.inputs[2], node.inputs[3], node.inputs[4])
3048 } else {
3049 (x_id, node.inputs[2], node.inputs[3])
3050 };
3051 let in_dims = node.shape.dims();
3052 let inner = in_dims[in_dims.len() - 1].unwrap_static() as u32;
3053 let total: u32 = in_dims.iter().map(|d| d.unwrap_static() as u32).product();
3054 let outer = total / inner.max(1);
3055 let p = FusedResidualRmsNormParams {
3056 outer,
3057 inner,
3058 in_off: (arena.offset(x_id) / 4) as u32,
3059 residual_off: (arena.offset(r_id) / 4) as u32,
3060 bias_off: (arena.offset(bias_id) / 4) as u32,
3061 gamma_off: (arena.offset(g_id) / 4) as u32,
3062 beta_off: (arena.offset(b_id) / 4) as u32,
3063 out_off: (arena.offset(node.id) / 4) as u32,
3064 eps_bits: eps.to_bits(),
3065 has_bias: if *has_bias { 1 } else { 0 },
3066 _p0: 0,
3067 _p1: 0,
3068 };
3069 schedule.push(Step::FusedResidualRmsNorm { params: p });
3070 let frk = fused_residual_rms_norm_kernel(&dev.device);
3071 let u = emit_uniform(std::mem::size_of::<FusedResidualRmsNormParams>());
3072 let bg = bind_two(&dev.device, frk, &arena.buffer, &u);
3073 uniforms.push(u);
3074 bind_groups.push(bg);
3075 }
3076 Op::DequantMatMul { scheme } => {
3077 use rlx_ir::QuantScheme;
3078 let x_id = node.inputs[0];
3079 let w_id = node.inputs[1];
3080 let out_dims = node.shape.dims();
3081 let x_dims = graph.node(x_id).shape.dims();
3082 let m = out_dims[0].unwrap_static() as u32;
3083 let n = out_dims[1].unwrap_static() as u32;
3084 let k = x_dims[1].unwrap_static() as u32;
3085 if scheme.is_gguf() {
3086 schedule.push(Step::DequantMatmulGguf {
3087 m,
3088 k,
3089 n,
3090 scheme_id: crate::gguf_host::gguf_scheme_id(*scheme),
3091 x_byte_off: arena.offset(x_id) as u32,
3092 w_byte_off: arena.offset(w_id) as u32,
3093 out_byte_off: arena.offset(node.id) as u32,
3094 });
3095 if gguf_host_pad.is_none() {
3096 let bk = binary_kernel(&dev.device);
3097 let u = emit_uniform(256);
3098 gguf_host_pad =
3099 Some((u.clone(), bind_two(&dev.device, bk, &arena.buffer, &u)));
3100 }
3101 let (u, bg) = gguf_host_pad.as_ref().unwrap();
3102 uniforms.push(u.clone());
3103 bind_groups.push(bg.clone());
3104 } else {
3105 let (block_size, scheme_id) = match scheme {
3106 QuantScheme::Int8Block { block_size } => (*block_size, 0u32),
3107 QuantScheme::Int8BlockAsym { block_size } => (*block_size, 1u32),
3108 QuantScheme::Int4Block { block_size } => (*block_size, 2u32),
3109 QuantScheme::Fp8E4m3 => (1, 3u32),
3110 QuantScheme::Fp8E5m2 => (1, 4u32),
3111 QuantScheme::Nvfp4Block => (rlx_ir::NVFP4_GROUP_SIZE as u32, 5u32),
3112 other => panic!("rlx-wgpu DequantMatMul: unsupported scheme {other:?}"),
3113 };
3114 let scale_id = node.inputs[2];
3115 let zp_id = node.inputs[3];
3116 let p = DequantMatmulParams {
3117 m,
3118 k,
3119 n,
3120 block_size,
3121 scheme_id,
3122 x_off: (arena.offset(x_id) / 4) as u32,
3123 w_off: (arena.offset(w_id) / 4) as u32,
3124 scale_off: (arena.offset(scale_id) / 4) as u32,
3125 zp_off: (arena.offset(zp_id) / 4) as u32,
3126 out_off: (arena.offset(node.id) / 4) as u32,
3127 _p0: 0,
3128 _p1: 0,
3129 };
3130 schedule.push(Step::DequantMatmul { params: p });
3131 let dk = dequant_matmul_kernel(&dev.device);
3132 let u = emit_uniform(std::mem::size_of::<DequantMatmulParams>());
3133 let bg = bind_two(&dev.device, dk, &arena.buffer, &u);
3134 uniforms.push(u);
3135 bind_groups.push(bg);
3136 }
3137 }
3138 Op::RmsNormBackwardInput { eps, .. }
3139 | Op::RmsNormBackwardGamma { eps, .. }
3140 | Op::RmsNormBackwardBeta { eps, .. } => {
3141 let x_shape = &graph.node(node.inputs[0]).shape;
3142 let h = x_shape.dim(x_shape.rank() - 1).unwrap_static() as u32;
3143 let rows = (x_shape.num_elements().unwrap() / h.max(1) as usize) as u32;
3144 let foff = |i: usize| (arena.offset(node.inputs[i]) / 4) as u32;
3145 let wrt = match &node.op {
3146 Op::RmsNormBackwardInput { .. } => 0u32,
3147 Op::RmsNormBackwardGamma { .. } => 1u32,
3148 Op::RmsNormBackwardBeta { .. } => 2u32,
3149 _ => unreachable!(),
3150 };
3151 let p = RmsNormBwdParams {
3152 outer: rows,
3153 inner: h,
3154 x_off: foff(0),
3155 gamma_off: foff(1),
3156 beta_off: foff(2),
3157 dy_off: foff(3),
3158 out_off: (arena.offset(node.id) / 4) as u32,
3159 eps_bits: eps.to_bits(),
3160 wrt,
3161 };
3162 let rk = if wrt == 0 {
3163 rms_norm_backward_kernel(&dev.device)
3164 } else {
3165 rms_norm_backward_param_kernel(&dev.device)
3166 };
3167 let u = emit_uniform(std::mem::size_of::<RmsNormBwdParams>());
3168 let bg = bind_two(&dev.device, rk, &arena.buffer, &u);
3169 match &node.op {
3170 Op::RmsNormBackwardInput { .. } => {
3171 schedule.push(Step::RmsNormBackwardInput { params: p });
3172 }
3173 Op::RmsNormBackwardGamma { .. } => {
3174 schedule.push(Step::RmsNormBackwardGamma { params: p });
3175 }
3176 Op::RmsNormBackwardBeta { .. } => {
3177 schedule.push(Step::RmsNormBackwardBeta { params: p });
3178 }
3179 _ => unreachable!(),
3180 }
3181 uniforms.push(u);
3182 bind_groups.push(bg);
3183 }
3184 Op::RopeBackward { head_dim, n_rot } => {
3185 let dy_shape = &graph.node(node.inputs[0]).shape;
3186 let (batch, seq, hidden) = if dy_shape.rank() >= 3 {
3187 (
3188 dy_shape.dim(0).unwrap_static() as u32,
3189 dy_shape.dim(1).unwrap_static() as u32,
3190 dy_shape.dim(2).unwrap_static() as u32,
3191 )
3192 } else {
3193 (
3194 1,
3195 dy_shape.dim(0).unwrap_static() as u32,
3196 dy_shape.dim(1).unwrap_static() as u32,
3197 )
3198 };
3199 let cos_len = graph.node(node.inputs[1]).shape.num_elements().unwrap() as u32;
3200 let p = RopeBwdParams {
3201 batch,
3202 seq,
3203 hidden,
3204 head_dim: *head_dim as u32,
3205 n_rot: *n_rot as u32,
3206 dy_off: (arena.offset(node.inputs[0]) / 4) as u32,
3207 cos_off: (arena.offset(node.inputs[1]) / 4) as u32,
3208 sin_off: (arena.offset(node.inputs[2]) / 4) as u32,
3209 dx_off: (arena.offset(node.id) / 4) as u32,
3210 cos_len,
3211 };
3212 let rk = rope_backward_kernel(&dev.device);
3213 let u = emit_uniform(std::mem::size_of::<RopeBwdParams>());
3214 let bg = bind_two(&dev.device, rk, &arena.buffer, &u);
3215 schedule.push(Step::RopeBackward { params: p });
3216 uniforms.push(u);
3217 bind_groups.push(bg);
3218 }
3219 Op::CumsumBackward { exclusive, .. } => {
3220 let dy_shape = &graph.node(node.inputs[0]).shape;
3221 let cols = dy_shape.dim(dy_shape.rank() - 1).unwrap_static() as u32;
3222 let rows = (dy_shape.num_elements().unwrap() / cols.max(1) as usize) as u32;
3223 let p = CumsumBwdParams {
3224 outer: rows,
3225 inner: cols,
3226 dy_off: (arena.offset(node.inputs[0]) / 4) as u32,
3227 dx_off: (arena.offset(node.id) / 4) as u32,
3228 exclusive: if *exclusive { 1 } else { 0 },
3229 _p0: 0,
3230 _p1: 0,
3231 _p2: 0,
3232 };
3233 let ck = cumsum_backward_kernel(&dev.device);
3234 let u = emit_uniform(std::mem::size_of::<CumsumBwdParams>());
3235 let bg = bind_two(&dev.device, ck, &arena.buffer, &u);
3236 schedule.push(Step::CumsumBackward { params: p });
3237 uniforms.push(u);
3238 bind_groups.push(bg);
3239 }
3240 Op::GatherBackward { .. } => {
3241 let dy_shape = &graph.node(node.inputs[0]).shape;
3242 let idx_shape = &graph.node(node.inputs[1]).shape;
3243 let out_shape = &node.shape;
3244 let rank = out_shape.rank();
3245 let axis = match &node.op {
3246 Op::GatherBackward { axis } => *axis,
3247 _ => 0,
3248 };
3249 let axis_u = if axis < 0 {
3250 (rank as i32 + axis) as usize
3251 } else {
3252 axis as usize
3253 };
3254 let outer: usize = (0..axis_u)
3255 .map(|i| dy_shape.dim(i).unwrap_static())
3256 .product::<usize>()
3257 .max(1);
3258 let num_idx = idx_shape.dim(axis_u).unwrap_static();
3259 let trailing: usize = (axis_u + 1..dy_shape.rank())
3260 .map(|i| dy_shape.dim(i).unwrap_static())
3261 .product::<usize>()
3262 .max(1);
3263 let axis_dim = out_shape.dim(axis_u).unwrap_static();
3264 let p = GatherBwdParams {
3265 outer: outer as u32,
3266 axis_dim: axis_dim as u32,
3267 num_idx: num_idx as u32,
3268 trailing: trailing as u32,
3269 dy_off: (arena.offset(node.inputs[0]) / 4) as u32,
3270 idx_off: (arena.offset(node.inputs[1]) / 4) as u32,
3271 dst_off: (arena.offset(node.id) / 4) as u32,
3272 _p0: 0,
3273 };
3274 let zk = gather_backward_zero_kernel(&dev.device);
3275 let u = emit_uniform(std::mem::size_of::<GatherBwdParams>());
3276 let bg = bind_two(&dev.device, zk, &arena.buffer, &u);
3277 schedule.push(Step::GatherBackward { params: p });
3278 uniforms.push(u);
3279 bind_groups.push(bg);
3280 }
3281 #[cfg(feature = "splat")]
3282 Op::GaussianSplatRender {
3283 width,
3284 height,
3285 tile_size,
3286 radius_scale,
3287 alpha_cutoff,
3288 max_splat_steps,
3289 transmittance_threshold,
3290 max_list_entries,
3291 } => {
3292 let elem_len = |id: NodeId| -> u32 {
3293 graph.node(id).shape.num_elements().unwrap_or(0) as u32
3294 };
3295 schedule.push(Step::GaussianSplatRender {
3296 positions_byte_off: arena.offset(node.inputs[0]) as u32,
3297 positions_len: elem_len(node.inputs[0]),
3298 scales_byte_off: arena.offset(node.inputs[1]) as u32,
3299 scales_len: elem_len(node.inputs[1]),
3300 rotations_byte_off: arena.offset(node.inputs[2]) as u32,
3301 rotations_len: elem_len(node.inputs[2]),
3302 opacities_byte_off: arena.offset(node.inputs[3]) as u32,
3303 opacities_len: elem_len(node.inputs[3]),
3304 colors_byte_off: arena.offset(node.inputs[4]) as u32,
3305 colors_len: elem_len(node.inputs[4]),
3306 sh_coeffs_byte_off: arena.offset(node.inputs[5]) as u32,
3307 sh_coeffs_len: elem_len(node.inputs[5]),
3308 meta_byte_off: arena.offset(node.inputs[6]) as u32,
3309 dst_byte_off: arena.offset(node.id) as u32,
3310 dst_len: node.shape.num_elements().unwrap_or(0) as u32,
3311 width: *width,
3312 height: *height,
3313 tile_size: *tile_size,
3314 radius_scale: *radius_scale,
3315 alpha_cutoff: *alpha_cutoff,
3316 max_splat_steps: *max_splat_steps,
3317 transmittance_threshold: *transmittance_threshold,
3318 max_list_entries: *max_list_entries,
3319 });
3320 }
3321
3322 #[cfg(feature = "splat")]
3323 Op::GaussianSplatRenderBackward {
3324 width,
3325 height,
3326 tile_size,
3327 radius_scale,
3328 alpha_cutoff,
3329 max_splat_steps,
3330 transmittance_threshold,
3331 max_list_entries,
3332 loss_grad_clip,
3333 sh_band,
3334 max_anisotropy,
3335 } => {
3336 let elem_len = |id: NodeId| -> u32 {
3337 graph.node(id).shape.num_elements().unwrap_or(0) as u32
3338 };
3339 schedule.push(Step::GaussianSplatRenderBackward {
3340 positions_byte_off: arena.offset(node.inputs[0]) as u32,
3341 positions_len: elem_len(node.inputs[0]),
3342 scales_byte_off: arena.offset(node.inputs[1]) as u32,
3343 scales_len: elem_len(node.inputs[1]),
3344 rotations_byte_off: arena.offset(node.inputs[2]) as u32,
3345 rotations_len: elem_len(node.inputs[2]),
3346 opacities_byte_off: arena.offset(node.inputs[3]) as u32,
3347 opacities_len: elem_len(node.inputs[3]),
3348 colors_byte_off: arena.offset(node.inputs[4]) as u32,
3349 colors_len: elem_len(node.inputs[4]),
3350 sh_coeffs_byte_off: arena.offset(node.inputs[5]) as u32,
3351 sh_coeffs_len: elem_len(node.inputs[5]),
3352 meta_byte_off: arena.offset(node.inputs[6]) as u32,
3353 d_loss_byte_off: arena.offset(node.inputs[7]) as u32,
3354 d_loss_len: elem_len(node.inputs[7]),
3355 packed_byte_off: arena.offset(node.id) as u32,
3356 packed_len: node.shape.num_elements().unwrap_or(0) as u32,
3357 width: *width,
3358 height: *height,
3359 tile_size: *tile_size,
3360 radius_scale: *radius_scale,
3361 alpha_cutoff: *alpha_cutoff,
3362 max_splat_steps: *max_splat_steps,
3363 transmittance_threshold: *transmittance_threshold,
3364 max_list_entries: *max_list_entries,
3365 loss_grad_clip: *loss_grad_clip,
3366 sh_band: *sh_band,
3367 max_anisotropy: *max_anisotropy,
3368 });
3369 }
3370
3371 #[cfg(feature = "splat")]
3372 Op::GaussianSplatPrepare {
3373 width,
3374 height,
3375 tile_size,
3376 radius_scale,
3377 alpha_cutoff,
3378 max_splat_steps,
3379 transmittance_threshold,
3380 max_list_entries,
3381 } => {
3382 let elem_len = |id: NodeId| -> u32 {
3383 graph.node(id).shape.num_elements().unwrap_or(0) as u32
3384 };
3385 schedule.push(Step::GaussianSplatPrepare {
3386 positions_byte_off: arena.offset(node.inputs[0]) as u32,
3387 positions_len: elem_len(node.inputs[0]),
3388 scales_byte_off: arena.offset(node.inputs[1]) as u32,
3389 scales_len: elem_len(node.inputs[1]),
3390 rotations_byte_off: arena.offset(node.inputs[2]) as u32,
3391 rotations_len: elem_len(node.inputs[2]),
3392 opacities_byte_off: arena.offset(node.inputs[3]) as u32,
3393 opacities_len: elem_len(node.inputs[3]),
3394 colors_byte_off: arena.offset(node.inputs[4]) as u32,
3395 colors_len: elem_len(node.inputs[4]),
3396 sh_coeffs_byte_off: arena.offset(node.inputs[5]) as u32,
3397 sh_coeffs_len: elem_len(node.inputs[5]),
3398 meta_byte_off: arena.offset(node.inputs[6]) as u32,
3399 meta_len: elem_len(node.inputs[6]),
3400 prep_byte_off: arena.offset(node.id) as u32,
3401 prep_len: node.shape.num_elements().unwrap_or(0) as u32,
3402 width: *width,
3403 height: *height,
3404 tile_size: *tile_size,
3405 radius_scale: *radius_scale,
3406 alpha_cutoff: *alpha_cutoff,
3407 max_splat_steps: *max_splat_steps,
3408 transmittance_threshold: *transmittance_threshold,
3409 max_list_entries: *max_list_entries,
3410 });
3411 }
3412
3413 #[cfg(feature = "splat")]
3414 Op::GaussianSplatRasterize {
3415 width,
3416 height,
3417 tile_size,
3418 alpha_cutoff,
3419 max_splat_steps,
3420 transmittance_threshold,
3421 max_list_entries,
3422 } => {
3423 let elem_len = |id: NodeId| -> u32 {
3424 graph.node(id).shape.num_elements().unwrap_or(0) as u32
3425 };
3426 let prep_id = node.inputs[0];
3427 let count = match &graph.node(prep_id).op {
3428 rlx_ir::Op::GaussianSplatPrepare { .. } => {
3429 elem_len(graph.node(prep_id).inputs[0]) / 3
3430 }
3431 _ => 1,
3432 };
3433 schedule.push(Step::GaussianSplatRasterize {
3434 prep_byte_off: arena.offset(prep_id) as u32,
3435 prep_len: elem_len(prep_id),
3436 meta_byte_off: arena.offset(node.inputs[1]) as u32,
3437 meta_len: elem_len(node.inputs[1]),
3438 dst_byte_off: arena.offset(node.id) as u32,
3439 dst_len: node.shape.num_elements().unwrap_or(0) as u32,
3440 count,
3441 width: *width,
3442 height: *height,
3443 tile_size: *tile_size,
3444 alpha_cutoff: *alpha_cutoff,
3445 max_splat_steps: *max_splat_steps,
3446 transmittance_threshold: *transmittance_threshold,
3447 max_list_entries: *max_list_entries,
3448 });
3449 }
3450
3451 Op::If { .. } | Op::While { .. } => {
3452 panic!(
3457 "rlx-wgpu: Op::If/While leaked past unfusion pass — \
3458 check unfuse.rs::expand_if / expand_while"
3459 );
3460 }
3461 other => panic!(
3462 "rlx-wgpu: op {other:?} not yet lowered (v2 covers Matmul, \
3463 Binary, Compare, Activation, Where — fall back to CPU/Metal/MLX)"
3464 ),
3465 }
3466 }
3467
3468 if rlx_ir::env::flag("RLX_WGPU_SCHEDULE") || rlx_ir::env::flag("RLX_DISPATCH_REPORT") {
3469 let mut counts: std::collections::BTreeMap<&'static str, usize> =
3470 std::collections::BTreeMap::new();
3471 let mut fft_gpu = 0usize;
3472 let mut fft_host = 0usize;
3473 for s in &schedule {
3474 *counts.entry(step_name(s)).or_insert(0) += 1;
3475 match s {
3476 Step::FftGpu { .. } => fft_gpu += 1,
3477 Step::FftHost { .. } => fft_host += 1,
3478 _ => {}
3479 }
3480 }
3481 let arena_mb = arena.size as f64 / (1u64 << 20) as f64;
3482 eprintln!(
3483 "[rlx-wgpu] schedule: {} steps, arena={arena_mb:.1} MiB, fft_gpu={fft_gpu}, fft_host={fft_host}",
3484 schedule.len()
3485 );
3486 for (n, c) in &counts {
3487 eprintln!(" {c:>4} × {n}");
3488 }
3489 }
3490
3491 Self {
3492 graph,
3493 arena,
3494 schedule,
3495 input_offsets,
3496 param_offsets,
3497 uniforms,
3498 bind_groups,
3499 meta_buffers,
3500 unresolved: None,
3501 last_binding: None,
3502 pending_params: HashMap::new(),
3503 pending_param_bytes: HashMap::new(),
3504 active_extent: None,
3505 uniforms_active_extent: None,
3506 fft_gpu_steps,
3507 }
3508 }
3509
3510 pub fn set_param(&mut self, name: &str, data: &[f32]) {
3511 if self.unresolved.is_some() {
3512 self.pending_params.insert(name.to_string(), data.to_vec());
3513 return;
3514 }
3515 let dev = wgpu_device().expect("rlx-wgpu: device gone");
3516 if let Some(&id) = self.param_offsets.get(name)
3517 && self.arena.has(id)
3518 {
3519 self.arena.write_f32(&dev.queue, id, data);
3520 }
3521 }
3522
3523 pub fn debug_first_nan_node(
3528 &mut self,
3529 inputs: &[(&str, &[f32])],
3530 ) -> Option<(usize, String, String)> {
3531 let _ = self.run(inputs);
3532 let dev = wgpu_device().expect("rlx-wgpu: device gone");
3533 let mut prev_summary = String::from("(none)");
3534 for (i, node) in self.graph.nodes().iter().enumerate() {
3535 if !self.arena.has(node.id) {
3536 continue;
3537 }
3538 let elems = node.shape.num_elements().unwrap_or(0);
3539 if elems == 0 {
3540 continue;
3541 }
3542 let data = self.arena.read_f32(&dev.device, &dev.queue, node.id);
3543 let nan_count = data.iter().filter(|v| v.is_nan()).count();
3544 let inf_count = data.iter().filter(|v| v.is_infinite()).count();
3545 if nan_count > 0 || inf_count > 0 {
3546 return Some((i, format!("{:?}", node.op), prev_summary));
3547 }
3548 let max = data.iter().copied().fold(f32::NEG_INFINITY, f32::max);
3549 let min = data.iter().copied().fold(f32::INFINITY, f32::min);
3550 let abs_max = data.iter().map(|v| v.abs()).fold(0.0_f32, f32::max);
3551 prev_summary = format!(
3552 "node #{i} {:?} shape={:?} min={min:.6e} max={max:.6e} |max|={abs_max:.6e}",
3553 node.op,
3554 node.shape
3555 .dims()
3556 .iter()
3557 .map(|d| format!("{d:?}"))
3558 .collect::<Vec<_>>()
3559 );
3560 }
3561 None
3562 }
3563
3564 pub fn output_dtypes(&self) -> Vec<rlx_ir::DType> {
3568 self.graph
3569 .outputs
3570 .iter()
3571 .map(|&id| self.graph.node(id).shape.dtype())
3572 .collect()
3573 }
3574
3575 pub fn set_param_bytes(&mut self, name: &str, data: &[u8]) {
3580 if self.unresolved.is_some() {
3581 self.pending_param_bytes
3582 .insert(name.to_string(), data.to_vec());
3583 return;
3584 }
3585 let dev = wgpu_device().expect("rlx-wgpu: device gone");
3586 if let Some(&id) = self.param_offsets.get(name)
3587 && self.arena.has(id)
3588 {
3589 dev.queue
3590 .write_buffer(&self.arena.buffer, self.arena.offset(id) as u64, data);
3591 }
3592 }
3593
3594 pub fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
3595 if self.unresolved.is_some() {
3598 self.lazy_compile_for_inputs(inputs);
3599 }
3600 let dev = wgpu_device().expect("rlx-wgpu: device gone");
3601 for &(name, data) in inputs {
3602 if let Some(&id) = self.input_offsets.get(name)
3603 && self.arena.has(id)
3604 {
3605 self.arena.write_f32(&dev.queue, id, data);
3606 }
3607 }
3608
3609 let active = self.active_extent.filter(|_| self.all_safe_for_active());
3614 let scale = |full: u32| -> u32 {
3615 match active {
3616 Some((a, u)) if u > 0 => {
3617 let f = full as usize;
3618 (f * a).div_ceil(u).min(f) as u32
3619 }
3620 _ => full,
3621 }
3622 };
3623
3624 let need_uniform_writes = self.uniforms_active_extent != Some(active);
3630 if need_uniform_writes {
3631 let mut gpu_ui = 0usize;
3632 for step in self.schedule.iter() {
3633 if step_runs_on_host(step) {
3634 continue;
3635 }
3636 match step {
3637 Step::CastF32ToF16 { .. } => {
3638 }
3642 Step::Matmul {
3643 m,
3644 k,
3645 n,
3646 a_off_f32,
3647 b_off_f32,
3648 c_off_f32,
3649 batch,
3650 a_batch_stride,
3651 b_batch_stride,
3652 c_batch_stride,
3653 has_bias,
3654 bias_off_f32,
3655 act_id,
3656 b_is_param: _,
3657 compute_precision: _,
3658 } => {
3659 let m_scaled = scale(*m);
3663 let p = MatmulParams {
3664 m: m_scaled,
3665 k: *k,
3666 n: *n,
3667 a_off: *a_off_f32,
3668 b_off: *b_off_f32,
3669 c_off: *c_off_f32,
3670 batch: *batch,
3671 a_batch_stride: *a_batch_stride,
3672 b_batch_stride: *b_batch_stride,
3673 c_batch_stride: *c_batch_stride,
3674 has_bias: *has_bias,
3675 bias_off: *bias_off_f32,
3676 act_id: *act_id,
3677 _pad0: 0,
3678 _pad1: 0,
3679 _pad2: 0,
3680 };
3681 dev.queue
3682 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3683 }
3684 Step::Binary { params } | Step::Compare { params } => {
3685 let mut p = *params;
3686 p.n = scale(p.n);
3687 dev.queue
3688 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3689 }
3690 Step::Unary { params } => {
3691 let mut p = *params;
3692 p.n = scale(p.n);
3693 dev.queue
3694 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3695 }
3696 Step::Where { params } => {
3697 let mut p = *params;
3698 p.n = scale(p.n);
3699 dev.queue
3700 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3701 }
3702 Step::Reduce { params } => {
3703 let mut p = *params;
3704 p.outer = scale(p.outer);
3705 dev.queue
3706 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3707 }
3708 Step::Softmax { params } => {
3709 let mut p = *params;
3710 p.outer = scale(p.outer);
3711 dev.queue
3712 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3713 }
3714 Step::LayerNorm { params } => {
3715 let mut p = *params;
3716 p.outer = scale(p.outer);
3717 dev.queue
3718 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3719 }
3720 Step::RmsNormBackwardInput { params }
3721 | Step::RmsNormBackwardGamma { params }
3722 | Step::RmsNormBackwardBeta { params } => {
3723 let mut p = *params;
3724 p.outer = scale(p.outer);
3725 dev.queue
3726 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3727 }
3728 Step::CumsumBackward { params } => {
3729 let mut p = *params;
3730 p.outer = scale(p.outer);
3731 dev.queue
3732 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3733 }
3734 Step::RopeBackward { params } => {
3735 let mut p = *params;
3736 p.seq = scale(p.seq);
3737 dev.queue
3738 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3739 }
3740 Step::GatherBackward { params } => {
3741 let mut p = *params;
3742 p.outer = scale(p.outer);
3743 dev.queue
3744 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3745 }
3746 Step::Cumsum { params } => {
3747 let mut p = *params;
3748 p.outer = scale(p.outer);
3749 dev.queue
3750 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3751 }
3752 Step::FftGpu { .. } => {}
3753 Step::Copy { params } => {
3754 let mut p = *params;
3755 p.n = scale(p.n);
3756 dev.queue
3757 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3758 }
3759 Step::ElementwiseRegion { params } => {
3760 let mut p = *params;
3762 p.len = scale(p.len);
3763 dev.queue
3764 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3765 }
3766 Step::Transpose { params, .. } => {
3767 let mut p = *params;
3772 if p.bucket_outermost == 1 && p.out_dim_0 > 0 {
3773 let scaled_d0 = scale(p.out_dim_0);
3774 let inner = p.out_total / p.out_dim_0;
3775 p.out_total = scaled_d0 * inner;
3776 }
3777 dev.queue
3778 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3779 }
3780 Step::Narrow { params } => {
3781 let mut p = *params;
3782 p.total = scale(p.total);
3783 dev.queue
3784 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3785 }
3786 Step::Concat { params } => {
3787 let mut p = *params;
3788 p.total = scale(p.total);
3789 dev.queue
3790 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3791 }
3792 Step::Gather { params } => {
3793 let mut p = *params;
3794 p.n_out = scale(p.n_out);
3795 dev.queue
3796 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3797 }
3798 Step::GatherAxis { params } => {
3799 let mut p = *params;
3800 p.total = scale(p.total);
3801 dev.queue
3802 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3803 }
3804 Step::Attention { params, .. } => {
3805 let mut p = *params;
3810 p.seq_q = scale(p.seq_q);
3811 p.seq_k = scale(p.seq_k);
3812 dev.queue
3813 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3814 }
3815 Step::AttentionBackward { params, .. } => {
3816 let mut p = *params;
3817 if p.wrt == 0 {
3818 p.seq_q = scale(p.seq_q);
3819 } else {
3820 p.seq_k = scale(p.seq_k);
3821 }
3822 dev.queue
3823 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3824 }
3825 Step::Rope { params } => {
3826 let mut p = *params;
3831 let s_active = scale(p.seq);
3832 p.seq = s_active;
3833 p.n_total = p.batch * s_active * p.last_dim;
3834 dev.queue
3835 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3836 }
3837 Step::Expand { params, .. } => {
3838 let mut p = *params;
3840 if p.bucket_outermost == 1 && p.out_dim_0 > 0 {
3841 let scaled_d0 = scale(p.out_dim_0);
3842 let inner = p.out_total / p.out_dim_0;
3843 p.out_total = scaled_d0 * inner;
3844 }
3845 dev.queue
3846 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3847 }
3848 Step::Argmax { params } => {
3849 let mut p = *params;
3850 p.outer = scale(p.outer);
3851 dev.queue
3852 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3853 }
3854 Step::Pool2d { params } => {
3855 let mut p = *params;
3856 p.n = scale(p.n);
3857 dev.queue
3858 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3859 }
3860 Step::Conv2d { params } => {
3861 let mut p = *params;
3862 p.n = scale(p.n);
3863 dev.queue
3864 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3865 }
3866 Step::Pool1d { params } => {
3867 let mut p = *params;
3868 p.n = scale(p.n);
3869 dev.queue
3870 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3871 }
3872 Step::Pool3d { params } => {
3873 let mut p = *params;
3874 p.n = scale(p.n);
3875 dev.queue
3876 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3877 }
3878 Step::Conv1d { params } => {
3879 let mut p = *params;
3880 p.n = scale(p.n);
3881 dev.queue
3882 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3883 }
3884 Step::Conv3d { params } => {
3885 let mut p = *params;
3886 p.n = scale(p.n);
3887 dev.queue
3888 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3889 }
3890 Step::ScatterAdd { params } => {
3891 let mut p = *params;
3895 if p.op == 1 {
3896 p.num_updates = scale(p.num_updates);
3897 }
3898 dev.queue
3899 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3900 }
3901 Step::TopK { params } => {
3902 let mut p = *params;
3903 p.outer = scale(p.outer);
3904 dev.queue
3905 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3906 }
3907 Step::UmapKnn { params } => {
3908 let mut p = *params;
3909 p.n = scale(p.n);
3910 dev.queue
3911 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3912 }
3913 Step::GroupedMatmul { params } => {
3914 let mut p = *params;
3915 p.m = scale(p.m);
3916 dev.queue
3917 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3918 }
3919 Step::Sample { params } => {
3920 let mut p = *params;
3921 p.outer = scale(p.outer);
3922 dev.queue
3923 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3924 }
3925 Step::SelectiveScan { params } => {
3926 let mut p = *params;
3928 p.seq = scale(p.seq);
3929 dev.queue
3930 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3931 }
3932 Step::DequantMatmul { params } => {
3933 let mut p = *params;
3934 p.m = scale(p.m);
3935 dev.queue
3936 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3937 }
3938 Step::DequantMatmulGguf { .. }
3939 | Step::DequantGroupedMatmulGguf { .. }
3940 | Step::GatedDeltaNet { .. }
3941 | Step::Llada2GroupLimitedGate { .. }
3942 | Step::UmapKnnHost { .. }
3943 | Step::FftHost { .. } => {}
3944 Step::FusedResidualLn { params } => {
3945 let mut p = *params;
3946 p.outer = scale(p.outer);
3947 dev.queue
3948 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3949 }
3950 Step::FusedResidualLnTee { params } => {
3951 let mut p = *params;
3952 p.outer = scale(p.outer);
3953 dev.queue
3954 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3955 }
3956 Step::FusedResidualRmsNorm { params } => {
3957 let mut p = *params;
3958 p.outer = scale(p.outer);
3959 dev.queue
3960 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3961 }
3962 Step::MatmulQkv { params, coop: _ } => {
3963 let mut p = *params;
3964 p.m = scale(p.m);
3965 dev.queue
3966 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3967 }
3968 #[cfg(feature = "splat")]
3969 Step::GaussianSplatRender { .. }
3970 | Step::GaussianSplatRenderBackward { .. }
3971 | Step::GaussianSplatPrepare { .. }
3972 | Step::GaussianSplatRasterize { .. } => {}
3973 }
3974 if !matches!(step, Step::FftGpu { .. }) {
3975 gpu_ui += 1;
3976 }
3977 }
3978 self.uniforms_active_extent = Some(active);
3979 }
3980
3981 let mm_k = matmul_kernel(&dev.device);
3983 let mm_w = matmul_wide_kernel(&dev.device);
3984 let mm_f16w = matmul_f16w_kernel(&dev.device);
3985 let mm_f16c = matmul_f16_compute_kernel(&dev.device);
3986 let mm_coop = matmul_coop16_kernel(&dev.device);
3987 let mm_coop_f32 = matmul_coop_f32_kernel(&dev.device);
3988 let mm_cast = cast_f32_to_f16_kernel(&dev.device);
3989 let bk = binary_kernel(&dev.device);
3990 let uk = unary_kernel(&dev.device);
3991 let ck = compare_kernel(&dev.device);
3992 let wk = where_kernel(&dev.device);
3993 let mut step_i = 0;
3994 let mut gpu_bi = 0usize;
3995 let mut fft_i = 0usize;
3996 while step_i < self.schedule.len() {
3997 let mut enc = dev
3998 .device
3999 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
4000 label: Some("rlx-wgpu run"),
4001 });
4002 {
4003 let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
4004 label: Some("rlx-wgpu compute pass"),
4005 timestamp_writes: None,
4006 });
4007 while step_i < self.schedule.len() {
4008 if step_runs_on_host(&self.schedule[step_i]) {
4009 break;
4010 }
4011 let step = &self.schedule[step_i];
4012 let _perf = rlx_ir::perfetto::TraceSpan::new(step_name(step), "wgpu");
4015 match step {
4016 Step::CastF32ToF16 { params } => {
4017 if let Some(cast_k) = mm_cast {
4022 pass.set_pipeline(&cast_k.pipeline);
4023 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4024 let (gx, gy, gz) = dispatch_dims(params.len, 64);
4025 pass.dispatch_workgroups(gx, gy, gz);
4026 }
4027 }
4028 Step::Matmul {
4029 m,
4030 n,
4031 batch,
4032 b_is_param,
4033 compute_precision,
4034 ..
4035 } =>
4036 {
4043 #[allow(clippy::unnecessary_unwrap)]
4044 let m_s = scale(*m);
4048 if m_s == 0 {
4049 continue;
4050 }
4051 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4052 let f16w_opt_in = rlx_ir::env::flag("RLX_WGPU_F16_WEIGHTS");
4062 if let Some(coop) = mm_coop.as_ref()
4063 && *b_is_param
4064 && *compute_precision == MatmulCompute::Coop16
4065 {
4066 pass.set_pipeline(&coop.pipeline);
4072 pass.dispatch_workgroups(n / 32, m_s.div_ceil(32), *batch);
4073 } else if let Some(coop_f32) = mm_coop_f32.as_ref()
4074 && *b_is_param
4075 && *compute_precision == MatmulCompute::CoopF32
4076 {
4077 pass.set_pipeline(&coop_f32.pipeline);
4081 pass.dispatch_workgroups(n / 32, m_s.div_ceil(32), *batch);
4082 } else if let Some(f16c) = mm_f16c.as_ref()
4083 && *b_is_param
4084 && *compute_precision == MatmulCompute::F16
4085 {
4086 pass.set_pipeline(&f16c.pipeline);
4087 pass.dispatch_workgroups(n.div_ceil(32), m_s.div_ceil(32), *batch);
4088 } else if let Some(f16w) = mm_f16w.as_ref()
4089 && *b_is_param
4090 && f16w_opt_in
4091 {
4092 pass.set_pipeline(&f16w.pipeline);
4093 pass.dispatch_workgroups(n.div_ceil(32), m_s.div_ceil(32), *batch);
4094 } else if m_s >= 32 && *n >= 64 {
4095 pass.set_pipeline(&mm_w.pipeline);
4096 pass.dispatch_workgroups(n.div_ceil(64), m_s.div_ceil(32), *batch);
4097 } else {
4098 pass.set_pipeline(&mm_k.pipeline);
4099 pass.dispatch_workgroups(n.div_ceil(32), m_s.div_ceil(32), *batch);
4100 }
4101 }
4102 Step::Binary { params } => {
4103 let n_s = scale(params.n);
4104 if n_s == 0 {
4105 continue;
4106 }
4107 pass.set_pipeline(&bk.pipeline);
4108 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4109 let (gx, gy, gz) = dispatch_dims(n_s, 64);
4110 pass.dispatch_workgroups(gx, gy, gz);
4111 }
4112 Step::Compare { params } => {
4113 let n_s = scale(params.n);
4114 if n_s == 0 {
4115 continue;
4116 }
4117 pass.set_pipeline(&ck.pipeline);
4118 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4119 let (gx, gy, gz) = dispatch_dims(n_s, 64);
4120 pass.dispatch_workgroups(gx, gy, gz);
4121 }
4122 Step::Unary { params } => {
4123 let n_s = scale(params.n);
4124 if n_s == 0 {
4125 continue;
4126 }
4127 pass.set_pipeline(&uk.pipeline);
4128 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4129 let (gx, gy, gz) = dispatch_dims(n_s, 64);
4130 pass.dispatch_workgroups(gx, gy, gz);
4131 }
4132 Step::Where { params } => {
4133 let n_s = scale(params.n);
4134 if n_s == 0 {
4135 continue;
4136 }
4137 pass.set_pipeline(&wk.pipeline);
4138 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4139 let (gx, gy, gz) = dispatch_dims(n_s, 64);
4140 pass.dispatch_workgroups(gx, gy, gz);
4141 }
4142 Step::Reduce { params } => {
4143 let outer_s = scale(params.outer);
4144 if outer_s == 0 {
4145 continue;
4146 }
4147 let rk = reduce_kernel(&dev.device);
4148 pass.set_pipeline(&rk.pipeline);
4149 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4150 let (gx, gy, gz) = dispatch_dims(outer_s, 64);
4151 pass.dispatch_workgroups(gx, gy, gz);
4152 }
4153 Step::Softmax { params } => {
4154 let outer_s = scale(params.outer);
4155 if outer_s == 0 {
4156 continue;
4157 }
4158 let sk = softmax_kernel(&dev.device);
4159 pass.set_pipeline(&sk.pipeline);
4160 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4161 let (gx, gy, gz) = dispatch_dims(outer_s, 64);
4162 pass.dispatch_workgroups(gx, gy, gz);
4163 }
4164 Step::LayerNorm { params } => {
4165 let outer_s = scale(params.outer);
4166 if outer_s == 0 {
4167 continue;
4168 }
4169 let lk = layernorm_kernel(&dev.device);
4170 pass.set_pipeline(&lk.pipeline);
4171 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4172 let (gx, gy, gz) = dispatch_dims(outer_s, 64);
4173 pass.dispatch_workgroups(gx, gy, gz);
4174 }
4175 Step::RmsNormBackwardInput { params } => {
4176 let outer_s = scale(params.outer);
4177 if outer_s == 0 {
4178 continue;
4179 }
4180 let rk = rms_norm_backward_kernel(&dev.device);
4181 pass.set_pipeline(&rk.pipeline);
4182 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4183 pass.dispatch_workgroups(outer_s, 1, 1);
4184 }
4185 Step::RmsNormBackwardGamma { params }
4186 | Step::RmsNormBackwardBeta { params } => {
4187 if params.inner == 0 {
4188 continue;
4189 }
4190 let rk = rms_norm_backward_param_kernel(&dev.device);
4191 pass.set_pipeline(&rk.pipeline);
4192 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4193 pass.dispatch_workgroups(1, 1, 1);
4194 }
4195 Step::CumsumBackward { params } => {
4196 let outer_s = scale(params.outer);
4197 if outer_s == 0 {
4198 continue;
4199 }
4200 let ck = cumsum_backward_kernel(&dev.device);
4201 pass.set_pipeline(&ck.pipeline);
4202 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4203 let (gx, gy, gz) = dispatch_dims(outer_s, 64);
4204 pass.dispatch_workgroups(gx, gy, gz);
4205 }
4206 Step::RopeBackward { params } => {
4207 let seq_s = scale(params.seq);
4208 if seq_s == 0 {
4209 continue;
4210 }
4211 let rk = rope_backward_kernel(&dev.device);
4212 pass.set_pipeline(&rk.pipeline);
4213 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4214 let total = params.batch * seq_s * params.hidden;
4215 let (gx, gy, gz) = dispatch_dims(total, 64);
4216 pass.dispatch_workgroups(gx, gy, gz);
4217 }
4218 Step::GatherBackward { params } => {
4219 let outer_s = scale(params.outer);
4220 if outer_s == 0 {
4221 continue;
4222 }
4223 let total = outer_s * params.axis_dim * params.trailing;
4224 if total > 0 {
4225 let zk = gather_backward_zero_kernel(&dev.device);
4226 pass.set_pipeline(&zk.pipeline);
4227 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4228 let (gx, _, _) = dispatch_dims(total, 256);
4229 pass.dispatch_workgroups(gx, 1, 1);
4230 }
4231 let ak = gather_backward_acc_kernel(&dev.device);
4232 pass.set_pipeline(&ak.pipeline);
4233 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4234 pass.dispatch_workgroups(outer_s, 1, 1);
4235 }
4236 Step::Cumsum { params } => {
4237 let outer_s = scale(params.outer);
4238 if outer_s == 0 {
4239 continue;
4240 }
4241 let ck2 = cumsum_kernel(&dev.device);
4242 pass.set_pipeline(&ck2.pipeline);
4243 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4244 let (gx, gy, gz) = dispatch_dims(outer_s, 64);
4245 pass.dispatch_workgroups(gx, gy, gz);
4246 }
4247 Step::FftGpu {
4248 src_off,
4249 dst_off,
4250 outer,
4251 n,
4252 inverse,
4253 norm_scale,
4254 } => {
4255 let res = &self.fft_gpu_steps[fft_i];
4256 fft_i += 1;
4257 crate::fft_dispatch::dispatch_fft_gpu_in_pass(
4258 &dev.device,
4259 &dev.queue,
4260 &mut pass,
4261 res,
4262 *src_off,
4263 *dst_off,
4264 *outer,
4265 *n,
4266 *inverse != 0,
4267 *norm_scale,
4268 );
4269 }
4270 Step::Copy { params } => {
4271 let n_s = scale(params.n);
4272 if n_s == 0 {
4273 continue;
4274 }
4275 let ck2 = copy_kernel(&dev.device);
4276 pass.set_pipeline(&ck2.pipeline);
4277 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4278 let (gx, gy, gz) = dispatch_dims(n_s, 64);
4279 pass.dispatch_workgroups(gx, gy, gz);
4280 }
4281 Step::ElementwiseRegion { params } => {
4282 let len_s = scale(params.len);
4283 if len_s == 0 {
4284 continue;
4285 }
4286 let ek = elementwise_region_kernel(&dev.device);
4287 pass.set_pipeline(&ek.pipeline);
4288 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4289 let (gx, gy, gz) = dispatch_dims(len_s, 64);
4290 pass.dispatch_workgroups(gx, gy, gz);
4291 }
4292 Step::Transpose { params, .. } => {
4293 let total_s = if params.bucket_outermost == 1 && params.out_dim_0 > 0 {
4297 let scaled_d0 = scale(params.out_dim_0);
4298 let inner = params.out_total / params.out_dim_0;
4299 scaled_d0 * inner
4300 } else {
4301 params.out_total
4302 };
4303 if total_s == 0 {
4304 continue;
4305 }
4306 let tk = transpose_kernel(&dev.device);
4307 pass.set_pipeline(&tk.pipeline);
4308 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4309 let (gx, gy, gz) = dispatch_dims(total_s, 64);
4310 pass.dispatch_workgroups(gx, gy, gz);
4311 }
4312 Step::Narrow { params } => {
4313 let total_s = scale(params.total);
4314 if total_s == 0 {
4315 continue;
4316 }
4317 let nk = narrow_kernel(&dev.device);
4318 pass.set_pipeline(&nk.pipeline);
4319 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4320 let (gx, gy, gz) = dispatch_dims(total_s, 64);
4321 pass.dispatch_workgroups(gx, gy, gz);
4322 }
4323 Step::Concat { params } => {
4324 let total_s = scale(params.total);
4325 if total_s == 0 {
4326 continue;
4327 }
4328 let cck = concat_kernel(&dev.device);
4329 pass.set_pipeline(&cck.pipeline);
4330 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4331 let (gx, gy, gz) = dispatch_dims(total_s, 64);
4332 pass.dispatch_workgroups(gx, gy, gz);
4333 }
4334 Step::Gather { params } => {
4335 let n_out_s = scale(params.n_out);
4336 if n_out_s == 0 {
4337 continue;
4338 }
4339 let gk = gather_kernel(&dev.device);
4340 pass.set_pipeline(&gk.pipeline);
4341 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4342 let (gx, gy, gz) = dispatch_dims(n_out_s, 64);
4343 pass.dispatch_workgroups(gx, gy, gz);
4344 }
4345 Step::GatherAxis { params } => {
4346 let total_s = scale(params.total);
4347 if total_s == 0 {
4348 continue;
4349 }
4350 let gk = gather_axis_kernel(&dev.device);
4351 pass.set_pipeline(&gk.pipeline);
4352 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4353 let (gx, gy, gz) = dispatch_dims(total_s, 64);
4354 pass.dispatch_workgroups(gx, gy, gz);
4355 }
4356 Step::Attention { params, .. } => {
4357 let seq_q_s = scale(params.seq_q);
4361 if seq_q_s == 0 {
4362 continue;
4363 }
4364 let ak = attention_kernel(&dev.device);
4365 pass.set_pipeline(&ak.pipeline);
4366 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4367 let total = params.batch * params.heads * seq_q_s;
4368 let (gx, gy, gz) = dispatch_dims(total, 64);
4369 pass.dispatch_workgroups(gx, gy, gz);
4370 }
4371 Step::AttentionBackward { params, .. } => {
4372 let axis = if params.wrt == 0 {
4373 params.seq_q
4374 } else {
4375 params.seq_k
4376 };
4377 let axis_s = scale(axis);
4378 if axis_s == 0 {
4379 continue;
4380 }
4381 let ak = attention_bwd_kernel(&dev.device);
4382 pass.set_pipeline(&ak.pipeline);
4383 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4384 let total = params.batch * params.heads * axis_s;
4385 let (gx, gy, gz) = dispatch_dims(total, 64);
4386 pass.dispatch_workgroups(gx, gy, gz);
4387 }
4388 Step::Rope { params } => {
4389 let s_active = scale(params.seq);
4392 let total_s = params.batch * s_active * params.last_dim;
4393 if total_s == 0 {
4394 continue;
4395 }
4396 let rk = rope_kernel(&dev.device);
4397 pass.set_pipeline(&rk.pipeline);
4398 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4399 let (gx, gy, gz) = dispatch_dims(total_s, 64);
4400 pass.dispatch_workgroups(gx, gy, gz);
4401 }
4402 Step::Expand { params, .. } => {
4403 let total_s = if params.bucket_outermost == 1 && params.out_dim_0 > 0 {
4404 let scaled_d0 = scale(params.out_dim_0);
4405 let inner = params.out_total / params.out_dim_0;
4406 scaled_d0 * inner
4407 } else {
4408 params.out_total
4409 };
4410 if total_s == 0 {
4411 continue;
4412 }
4413 let ek = expand_kernel(&dev.device);
4414 pass.set_pipeline(&ek.pipeline);
4415 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4416 let (gx, gy, gz) = dispatch_dims(total_s, 64);
4417 pass.dispatch_workgroups(gx, gy, gz);
4418 }
4419 Step::Argmax { params } => {
4420 let outer_s = scale(params.outer);
4421 if outer_s == 0 {
4422 continue;
4423 }
4424 let amk = argmax_kernel(&dev.device);
4425 pass.set_pipeline(&amk.pipeline);
4426 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4427 let (gx, gy, gz) = dispatch_dims(outer_s, 64);
4428 pass.dispatch_workgroups(gx, gy, gz);
4429 }
4430 Step::Pool2d { params } => {
4431 let n_s = scale(params.n);
4432 if n_s == 0 {
4433 continue;
4434 }
4435 let pk = pool2d_kernel(&dev.device);
4436 pass.set_pipeline(&pk.pipeline);
4437 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4438 let total = n_s * params.c * params.h_out * params.w_out;
4439 let (gx, gy, gz) = dispatch_dims(total, 64);
4440 pass.dispatch_workgroups(gx, gy, gz);
4441 }
4442 Step::Conv2d { params } => {
4443 let n_s = scale(params.n);
4444 if n_s == 0 {
4445 continue;
4446 }
4447 let ck2 = conv2d_kernel(&dev.device);
4448 pass.set_pipeline(&ck2.pipeline);
4449 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4450 let total = n_s * params.c_out * params.h_out * params.w_out;
4451 let (gx, gy, gz) = dispatch_dims(total, 64);
4452 pass.dispatch_workgroups(gx, gy, gz);
4453 }
4454 Step::Pool1d { params } => {
4455 let n_s = scale(params.n);
4456 if n_s == 0 {
4457 continue;
4458 }
4459 let pk = pool1d_kernel(&dev.device);
4460 pass.set_pipeline(&pk.pipeline);
4461 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4462 let total = n_s * params.c * params.l_out;
4463 let (gx, gy, gz) = dispatch_dims(total, 64);
4464 pass.dispatch_workgroups(gx, gy, gz);
4465 }
4466 Step::Pool3d { params } => {
4467 let n_s = scale(params.n);
4468 if n_s == 0 {
4469 continue;
4470 }
4471 let pk = pool3d_kernel(&dev.device);
4472 pass.set_pipeline(&pk.pipeline);
4473 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4474 let total = n_s * params.c * params.d_out * params.h_out * params.w_out;
4475 let (gx, gy, gz) = dispatch_dims(total, 64);
4476 pass.dispatch_workgroups(gx, gy, gz);
4477 }
4478 Step::Conv1d { params } => {
4479 let n_s = scale(params.n);
4480 if n_s == 0 {
4481 continue;
4482 }
4483 let ck = conv1d_kernel(&dev.device);
4484 pass.set_pipeline(&ck.pipeline);
4485 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4486 let total = n_s * params.c_out * params.l_out;
4487 let (gx, gy, gz) = dispatch_dims(total, 64);
4488 pass.dispatch_workgroups(gx, gy, gz);
4489 }
4490 Step::Conv3d { params } => {
4491 let n_s = scale(params.n);
4492 if n_s == 0 {
4493 continue;
4494 }
4495 let ck = conv3d_kernel(&dev.device);
4496 pass.set_pipeline(&ck.pipeline);
4497 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4498 let total =
4499 n_s * params.c_out * params.d_out * params.h_out * params.w_out;
4500 let (gx, gy, gz) = dispatch_dims(total, 64);
4501 pass.dispatch_workgroups(gx, gy, gz);
4502 }
4503 Step::ScatterAdd { params } => {
4504 let sk = scatter_add_kernel(&dev.device);
4505 pass.set_pipeline(&sk.pipeline);
4506 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4507 if params.op == 0 {
4513 let (gx, gy, gz) = dispatch_dims(params.out_total, 64);
4514 pass.dispatch_workgroups(gx, gy, gz);
4515 } else {
4516 pass.dispatch_workgroups(1, 1, 1);
4517 }
4518 }
4519 Step::TopK { params } => {
4520 let outer_s = scale(params.outer);
4521 if outer_s == 0 {
4522 continue;
4523 }
4524 let tk = topk_kernel(&dev.device);
4525 pass.set_pipeline(&tk.pipeline);
4526 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4527 let (gx, gy, gz) = dispatch_dims(outer_s, 64);
4528 pass.dispatch_workgroups(gx, gy, gz);
4529 }
4530 Step::UmapKnn { params } => {
4531 let n_s = scale(params.n);
4532 if n_s == 0 {
4533 continue;
4534 }
4535 let uk = umap_knn_kernel(&dev.device);
4536 pass.set_pipeline(&uk.pipeline);
4537 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4538 let (gx, gy, gz) = dispatch_dims(n_s, 64);
4539 pass.dispatch_workgroups(gx, gy, gz);
4540 }
4541 Step::GroupedMatmul { params } => {
4542 let m_s = scale(params.m);
4543 if m_s == 0 {
4544 continue;
4545 }
4546 let gk = grouped_matmul_kernel(&dev.device);
4547 pass.set_pipeline(&gk.pipeline);
4548 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4549 pass.dispatch_workgroups(params.n.div_ceil(8), m_s.div_ceil(8), 1);
4550 }
4551 Step::Sample { params } => {
4552 let outer_s = scale(params.outer);
4553 if outer_s == 0 {
4554 continue;
4555 }
4556 let sk = sample_kernel(&dev.device);
4557 pass.set_pipeline(&sk.pipeline);
4558 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4559 let (gx, gy, gz) = dispatch_dims(outer_s, 64);
4560 pass.dispatch_workgroups(gx, gy, gz);
4561 }
4562 Step::SelectiveScan { params } => {
4563 let ssk = selective_scan_kernel(&dev.device);
4568 pass.set_pipeline(&ssk.pipeline);
4569 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4570 let total = params.batch * params.hidden;
4571 let (gx, gy, gz) = dispatch_dims(total, 64);
4572 pass.dispatch_workgroups(gx, gy, gz);
4573 }
4574 Step::DequantMatmul { params } => {
4575 let m_s = scale(params.m);
4576 if m_s == 0 {
4577 continue;
4578 }
4579 let dk = dequant_matmul_kernel(&dev.device);
4580 pass.set_pipeline(&dk.pipeline);
4581 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4582 pass.dispatch_workgroups(params.n.div_ceil(8), m_s.div_ceil(8), 1);
4583 }
4584 Step::FusedResidualLn { params } => {
4585 let outer_s = scale(params.outer);
4586 if outer_s == 0 {
4587 continue;
4588 }
4589 let frk = fused_residual_ln_kernel(&dev.device);
4590 pass.set_pipeline(&frk.pipeline);
4591 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4592 let (gx, gy, gz) = dispatch_dims(outer_s, 64);
4593 pass.dispatch_workgroups(gx, gy, gz);
4594 }
4595 Step::FusedResidualLnTee { params } => {
4596 let outer_s = scale(params.outer);
4597 if outer_s == 0 {
4598 continue;
4599 }
4600 let frtk = fused_residual_ln_tee_kernel(&dev.device);
4601 pass.set_pipeline(&frtk.pipeline);
4602 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4603 let (gx, gy, gz) = dispatch_dims(outer_s, 64);
4604 pass.dispatch_workgroups(gx, gy, gz);
4605 }
4606 Step::FusedResidualRmsNorm { params } => {
4607 let outer_s = scale(params.outer);
4608 if outer_s == 0 {
4609 continue;
4610 }
4611 let frk = fused_residual_rms_norm_kernel(&dev.device);
4612 pass.set_pipeline(&frk.pipeline);
4613 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4614 let (gx, gy, gz) = dispatch_dims(outer_s, 64);
4615 pass.dispatch_workgroups(gx, gy, gz);
4616 }
4617 Step::MatmulQkv { params, coop } => {
4618 let m_s = scale(params.m);
4619 if m_s == 0 {
4620 continue;
4621 }
4622 let pipe = if *coop {
4625 &matmul_qkv_coop_f32_kernel(&dev.device)
4626 .expect("coop matmul_qkv kernel missing")
4627 .pipeline
4628 } else {
4629 &matmul_qkv_kernel(&dev.device).pipeline
4630 };
4631 pass.set_pipeline(pipe);
4632 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4633 pass.dispatch_workgroups(params.n.div_ceil(32), m_s.div_ceil(32), 1);
4634 }
4635 Step::DequantMatmulGguf { .. }
4636 | Step::DequantGroupedMatmulGguf { .. }
4637 | Step::GatedDeltaNet { .. }
4638 | Step::Llada2GroupLimitedGate { .. }
4639 | Step::UmapKnnHost { .. }
4640 | Step::FftHost { .. } => {}
4641 #[cfg(feature = "splat")]
4642 Step::GaussianSplatRender { .. }
4643 | Step::GaussianSplatRenderBackward { .. }
4644 | Step::GaussianSplatPrepare { .. }
4645 | Step::GaussianSplatRasterize { .. } => {}
4646 }
4647 if !matches!(step, Step::FftGpu { .. }) {
4648 gpu_bi += 1;
4649 }
4650 step_i += 1;
4651 }
4652 }
4653 dev.queue.submit(std::iter::once(enc.finish()));
4654 let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());
4655 if step_i >= self.schedule.len() {
4656 break;
4657 }
4658 match &self.schedule[step_i] {
4659 Step::DequantMatmulGguf {
4660 m,
4661 k,
4662 n,
4663 scheme_id,
4664 x_byte_off,
4665 w_byte_off,
4666 out_byte_off,
4667 } => {
4668 crate::gguf_host::run_dequant_matmul_gguf(
4669 &self.arena,
4670 &dev.device,
4671 &dev.queue,
4672 *m as usize,
4673 *k as usize,
4674 *n as usize,
4675 *scheme_id,
4676 *x_byte_off as usize,
4677 *w_byte_off as usize,
4678 *out_byte_off as usize,
4679 );
4680 }
4681 Step::DequantGroupedMatmulGguf {
4682 m,
4683 k,
4684 n,
4685 num_experts,
4686 scheme_id,
4687 x_byte_off,
4688 w_byte_off,
4689 idx_byte_off,
4690 out_byte_off,
4691 } => {
4692 crate::gguf_host::run_dequant_grouped_matmul_gguf(
4693 &self.arena,
4694 &dev.device,
4695 &dev.queue,
4696 *m as usize,
4697 *k as usize,
4698 *n as usize,
4699 *num_experts as usize,
4700 *scheme_id,
4701 *x_byte_off as usize,
4702 *w_byte_off as usize,
4703 *idx_byte_off as usize,
4704 *out_byte_off as usize,
4705 );
4706 }
4707 Step::GatedDeltaNet {
4708 q_byte_off,
4709 k_byte_off,
4710 v_byte_off,
4711 g_byte_off,
4712 beta_byte_off,
4713 state_byte_off,
4714 dst_byte_off,
4715 batch,
4716 seq,
4717 heads,
4718 state_size,
4719 use_carry,
4720 } => {
4721 crate::gdn_host::run_gated_delta_net(
4722 &self.arena,
4723 &dev.device,
4724 &dev.queue,
4725 *q_byte_off as usize,
4726 *k_byte_off as usize,
4727 *v_byte_off as usize,
4728 *g_byte_off as usize,
4729 *beta_byte_off as usize,
4730 *state_byte_off as usize,
4731 *dst_byte_off as usize,
4732 *batch as usize,
4733 *seq as usize,
4734 *heads as usize,
4735 *state_size as usize,
4736 *use_carry,
4737 );
4738 }
4739 Step::Llada2GroupLimitedGate {
4740 sig_byte_off,
4741 route_byte_off,
4742 out_byte_off,
4743 n_elems,
4744 attrs,
4745 } => {
4746 crate::llada2_gate_host::run_llada2_group_limited_gate(
4747 &self.arena,
4748 &dev.device,
4749 &dev.queue,
4750 *sig_byte_off as usize,
4751 *route_byte_off as usize,
4752 *out_byte_off as usize,
4753 *n_elems as usize,
4754 attrs,
4755 );
4756 }
4757 Step::UmapKnnHost {
4758 pairwise_byte_off,
4759 out_byte_off,
4760 n,
4761 k,
4762 } => {
4763 crate::umap_knn_host::run_umap_knn(
4764 &self.arena,
4765 &dev.device,
4766 &dev.queue,
4767 *pairwise_byte_off as usize,
4768 *out_byte_off as usize,
4769 *n as usize,
4770 *k as usize,
4771 );
4772 }
4773 Step::FftHost {
4774 src_byte_off,
4775 dst_byte_off,
4776 outer,
4777 n_complex,
4778 inverse,
4779 norm_tag,
4780 dtype_tag,
4781 } => {
4782 crate::fft_host::run_fft1d(
4783 &self.arena,
4784 &dev.device,
4785 &dev.queue,
4786 *src_byte_off as usize,
4787 *dst_byte_off as usize,
4788 *outer as usize,
4789 *n_complex as usize,
4790 *inverse,
4791 *norm_tag,
4792 fft_dtype_from_tag(*dtype_tag),
4793 );
4794 }
4795 #[cfg(feature = "splat")]
4796 Step::GaussianSplatRender {
4797 positions_byte_off,
4798 positions_len,
4799 scales_byte_off,
4800 scales_len,
4801 rotations_byte_off,
4802 rotations_len,
4803 opacities_byte_off,
4804 opacities_len,
4805 colors_byte_off,
4806 colors_len,
4807 sh_coeffs_byte_off,
4808 sh_coeffs_len,
4809 meta_byte_off,
4810 dst_byte_off,
4811 dst_len,
4812 width,
4813 height,
4814 tile_size,
4815 radius_scale,
4816 alpha_cutoff,
4817 max_splat_steps,
4818 transmittance_threshold,
4819 max_list_entries,
4820 } => {
4821 crate::splat::run_gaussian_splat_render(
4822 &self.arena,
4823 &dev.device,
4824 &dev.queue,
4825 *positions_byte_off as usize,
4826 *positions_len as usize,
4827 *scales_byte_off as usize,
4828 *scales_len as usize,
4829 *rotations_byte_off as usize,
4830 *rotations_len as usize,
4831 *opacities_byte_off as usize,
4832 *opacities_len as usize,
4833 *colors_byte_off as usize,
4834 *colors_len as usize,
4835 *sh_coeffs_byte_off as usize,
4836 *sh_coeffs_len as usize,
4837 *meta_byte_off as usize,
4838 *dst_byte_off as usize,
4839 *dst_len as usize,
4840 *width,
4841 *height,
4842 *tile_size,
4843 *radius_scale,
4844 *alpha_cutoff,
4845 *max_splat_steps,
4846 *transmittance_threshold,
4847 *max_list_entries,
4848 );
4849 }
4850 #[cfg(feature = "splat")]
4851 Step::GaussianSplatPrepare {
4852 positions_byte_off,
4853 positions_len,
4854 scales_byte_off,
4855 scales_len,
4856 rotations_byte_off,
4857 rotations_len,
4858 opacities_byte_off,
4859 opacities_len,
4860 colors_byte_off,
4861 colors_len,
4862 sh_coeffs_byte_off,
4863 sh_coeffs_len,
4864 meta_byte_off,
4865 meta_len,
4866 prep_byte_off,
4867 prep_len,
4868 width,
4869 height,
4870 tile_size,
4871 radius_scale,
4872 alpha_cutoff,
4873 max_splat_steps,
4874 transmittance_threshold,
4875 max_list_entries,
4876 } => {
4877 crate::splat::run_gaussian_splat_prepare(
4878 &self.arena,
4879 &dev.device,
4880 &dev.queue,
4881 *positions_byte_off as usize,
4882 *positions_len as usize,
4883 *scales_byte_off as usize,
4884 *scales_len as usize,
4885 *rotations_byte_off as usize,
4886 *rotations_len as usize,
4887 *opacities_byte_off as usize,
4888 *opacities_len as usize,
4889 *colors_byte_off as usize,
4890 *colors_len as usize,
4891 *sh_coeffs_byte_off as usize,
4892 *sh_coeffs_len as usize,
4893 *meta_byte_off as usize,
4894 *meta_len as usize,
4895 *prep_byte_off as usize,
4896 *prep_len as usize,
4897 *width,
4898 *height,
4899 *tile_size,
4900 *radius_scale,
4901 *alpha_cutoff,
4902 *max_splat_steps,
4903 *transmittance_threshold,
4904 *max_list_entries,
4905 );
4906 }
4907 #[cfg(feature = "splat")]
4908 Step::GaussianSplatRasterize {
4909 prep_byte_off,
4910 prep_len,
4911 meta_byte_off,
4912 meta_len,
4913 dst_byte_off,
4914 dst_len,
4915 count,
4916 width,
4917 height,
4918 tile_size,
4919 alpha_cutoff,
4920 max_splat_steps,
4921 transmittance_threshold,
4922 max_list_entries,
4923 } => {
4924 crate::splat::run_gaussian_splat_rasterize(
4925 &self.arena,
4926 &dev.device,
4927 &dev.queue,
4928 *prep_byte_off as usize,
4929 *prep_len as usize,
4930 *meta_byte_off as usize,
4931 *meta_len as usize,
4932 *dst_byte_off as usize,
4933 *dst_len as usize,
4934 *count as usize,
4935 *width,
4936 *height,
4937 *tile_size,
4938 *alpha_cutoff,
4939 *max_splat_steps,
4940 *transmittance_threshold,
4941 *max_list_entries,
4942 );
4943 }
4944 #[cfg(feature = "splat")]
4945 Step::GaussianSplatRenderBackward {
4946 positions_byte_off,
4947 positions_len,
4948 scales_byte_off,
4949 scales_len,
4950 rotations_byte_off,
4951 rotations_len,
4952 opacities_byte_off,
4953 opacities_len,
4954 colors_byte_off,
4955 colors_len,
4956 sh_coeffs_byte_off,
4957 sh_coeffs_len,
4958 meta_byte_off,
4959 d_loss_byte_off,
4960 d_loss_len,
4961 packed_byte_off,
4962 packed_len,
4963 width,
4964 height,
4965 tile_size,
4966 radius_scale,
4967 alpha_cutoff,
4968 max_splat_steps,
4969 transmittance_threshold,
4970 max_list_entries,
4971 loss_grad_clip,
4972 sh_band,
4973 max_anisotropy,
4974 } => {
4975 crate::splat::run_gaussian_splat_render_backward(
4976 &self.arena,
4977 &dev.device,
4978 &dev.queue,
4979 *positions_byte_off as usize,
4980 *positions_len as usize,
4981 *scales_byte_off as usize,
4982 *scales_len as usize,
4983 *rotations_byte_off as usize,
4984 *rotations_len as usize,
4985 *opacities_byte_off as usize,
4986 *opacities_len as usize,
4987 *colors_byte_off as usize,
4988 *colors_len as usize,
4989 *sh_coeffs_byte_off as usize,
4990 *sh_coeffs_len as usize,
4991 *meta_byte_off as usize,
4992 *d_loss_byte_off as usize,
4993 *d_loss_len as usize,
4994 *packed_byte_off as usize,
4995 *packed_len as usize,
4996 *width,
4997 *height,
4998 *tile_size,
4999 *radius_scale,
5000 *alpha_cutoff,
5001 *max_splat_steps,
5002 *transmittance_threshold,
5003 *max_list_entries,
5004 *loss_grad_clip,
5005 *sh_band,
5006 *max_anisotropy,
5007 );
5008 }
5009 _ => break,
5010 }
5011 step_i += 1;
5012 }
5013
5014 if rlx_ir::env::flag("RLX_WGPU_NAN_TRACE") {
5020 let mut bad_nodes = Vec::new();
5021 for node in self.graph.nodes() {
5022 if !self.arena.has(node.id) {
5023 continue;
5024 }
5025 if matches!(
5027 node.op,
5028 rlx_ir::Op::Input { .. }
5029 | rlx_ir::Op::Param { .. }
5030 | rlx_ir::Op::Constant { .. }
5031 ) {
5032 continue;
5033 }
5034 let data = self.arena.read_f32(&dev.device, &dev.queue, node.id);
5035 let nan_count = data.iter().filter(|v| v.is_nan()).count();
5036 let inf_count = data.iter().filter(|v| v.is_infinite()).count();
5037 if nan_count > 0 || inf_count > 0 {
5038 let first_nan = data.iter().position(|v| v.is_nan());
5040 if let Some(idx) = first_nan {
5041 let lo = idx.saturating_sub(2);
5042 let hi = (idx + 3).min(data.len());
5043 eprintln!(
5044 " node {:?} op={:?} len={} nan={} inf={} \
5045 first_nan_idx={} ctx={:?}",
5046 node.id,
5047 node.op,
5048 data.len(),
5049 nan_count,
5050 inf_count,
5051 idx,
5052 &data[lo..hi]
5053 );
5054 }
5055 bad_nodes.push((node.id, data.len(), nan_count, inf_count));
5056 if bad_nodes.len() >= 3 {
5057 break;
5058 }
5059 }
5060 }
5061 if bad_nodes.is_empty() {
5062 eprintln!("[wgpu-nan-trace] no NaN/Inf in any node — clean run");
5063 } else {
5064 eprintln!(
5065 "[wgpu-nan-trace] first {} bad nodes (above)",
5066 bad_nodes.len()
5067 );
5068 }
5069 }
5070
5071 self.graph
5072 .outputs
5073 .iter()
5074 .map(|&id| {
5075 if rlx_ir::env::flag("RLX_BENCH_DISPATCH_ONLY") {
5076 let n = self.graph.node(id).shape.num_elements().unwrap_or(0);
5077 vec![0.0; n]
5078 } else {
5079 self.arena.read_f32(&dev.device, &dev.queue, id)
5080 }
5081 })
5082 .collect()
5083 }
5084}
5085
/// Converts a flat thread count into an `(x, y, z)` workgroup grid.
///
/// The number of workgroups is `ceil(threads_total / workgroup_size)`. When
/// that fits in a single dimension (at most 65535, the WebGPU default limit
/// for workgroups per dimension) the grid is `(groups, 1, 1)`. Otherwise `x`
/// is pinned at 65535 and the rest spills into `y`; `z` is always 1.
///
/// In the spilled case `x * y` may exceed the number of workgroups actually
/// required. NOTE(review): kernels dispatched through this helper presumably
/// bounds-check their flattened index — confirm against the shaders.
///
/// # Panics
/// Panics if `workgroup_size` is zero.
fn dispatch_dims(threads_total: u32, workgroup_size: u32) -> (u32, u32, u32) {
    const MAX_GROUPS_PER_DIM: u32 = 65535;
    let workgroups = threads_total.div_ceil(workgroup_size);
    match workgroups {
        0..=MAX_GROUPS_PER_DIM => (workgroups, 1, 1),
        _ => (
            MAX_GROUPS_PER_DIM,
            workgroups.div_ceil(MAX_GROUPS_PER_DIM),
            1,
        ),
    }
}
5102
5103fn require_equal_shapes(graph: &Graph, ids: &[NodeId], op_name: &str) {
5104 let s0 = graph.node(ids[0]).shape.num_elements().unwrap_or(0);
5105 for &id in &ids[1..] {
5106 let si = graph.node(id).shape.num_elements().unwrap_or(0);
5107 if si != s0 {
5108 panic!(
5109 "rlx-wgpu {op_name}: broadcasting not yet implemented; \
5110 inputs must have the same element count (got {s0} vs {si})"
5111 );
5112 }
5113 }
5114}
5115
5116fn bind_two(
5117 device: &wgpu::Device,
5118 kernel: &Kernel,
5119 buf0: &wgpu::Buffer,
5120 buf1: &wgpu::Buffer,
5121) -> wgpu::BindGroup {
5122 device.create_bind_group(&wgpu::BindGroupDescriptor {
5123 label: Some("rlx-wgpu bg"),
5124 layout: &kernel.bgl,
5125 entries: &[
5126 wgpu::BindGroupEntry {
5127 binding: 0,
5128 resource: buf0.as_entire_binding(),
5129 },
5130 wgpu::BindGroupEntry {
5131 binding: 1,
5132 resource: buf1.as_entire_binding(),
5133 },
5134 ],
5135 })
5136}
5137
5138fn derive_matmul_compute(
5161 dev: &wgpu::Device,
5162 graph: &Graph,
5163 a_id: NodeId,
5164 b_id: NodeId,
5165 m: u32,
5166 k: u32,
5167 n: u32,
5168) -> MatmulCompute {
5169 use rlx_ir::DType;
5170 let a_dt = graph.node(a_id).shape.dtype();
5171 let b_dt = graph.node(b_id).shape.dtype();
5172 let any_low =
5173 matches!(a_dt, DType::F16 | DType::BF16) || matches!(b_dt, DType::F16 | DType::BF16);
5174 let coop16_aligned = m.is_multiple_of(32) && k.is_multiple_of(8) && n.is_multiple_of(32);
5181 let coop_f32_aligned = k.is_multiple_of(8) && n.is_multiple_of(32);
5182 let has_coop = dev
5183 .features()
5184 .contains(wgpu::Features::EXPERIMENTAL_COOPERATIVE_MATRIX);
5185 if any_low
5190 && has_coop
5191 && dev.features().contains(wgpu::Features::SHADER_F16)
5192 && traces_to_param(graph, b_id)
5193 && coop16_aligned
5194 {
5195 return MatmulCompute::Coop16;
5196 }
5197 let disabled = rlx_ir::env::flag("RLX_WGPU_NO_COOP_F32");
5211 let forced = rlx_ir::env::flag("RLX_WGPU_FORCE_COOP_F32");
5212 let backend_ok = forced
5213 || matches!(
5214 crate::device::wgpu_device().map(|d| d.backend),
5215 Some(wgpu::Backend::Metal)
5216 );
5217 if !disabled && backend_ok && has_coop && coop_f32_aligned && traces_to_param(graph, b_id) {
5218 return MatmulCompute::CoopF32;
5219 }
5220 MatmulCompute::F32
5221}
5222
5223#[allow(dead_code)]
5243fn detect_qkv_narrow_pattern(
5244 graph: &Graph,
5245 q_id: NodeId,
5246 k_id: NodeId,
5247 v_id: NodeId,
5248) -> Option<(NodeId, u32)> {
5249 let unwrap_narrow = |id: NodeId| -> Option<(NodeId, usize, usize, usize)> {
5250 let node = graph.node(id);
5251 match &node.op {
5252 Op::Narrow { axis, start, len } => Some((node.inputs[0], *axis, *start, *len)),
5253 _ => None,
5254 }
5255 };
5256 let (q_src, q_axis, q_start, q_len) = unwrap_narrow(q_id)?;
5257 let (k_src, k_axis, k_start, k_len) = unwrap_narrow(k_id)?;
5258 let (v_src, v_axis, v_start, v_len) = unwrap_narrow(v_id)?;
5259 if q_src != k_src || k_src != v_src {
5261 return None;
5262 }
5263 if q_len != k_len || k_len != v_len {
5265 return None;
5266 }
5267 if q_start != 0 || k_start != q_len || v_start != q_len * 2 {
5269 return None;
5270 }
5271 let src_rank = graph.node(q_src).shape.dims().len();
5273 if q_axis + 1 != src_rank || k_axis + 1 != src_rank || v_axis + 1 != src_rank {
5274 return None;
5275 }
5276 Some((q_src, q_len as u32))
5277}
5278
5279fn detect_residual_ln_tee_pattern(
5309 graph: &Graph,
5310) -> (
5311 HashMap<NodeId, (NodeId, NodeId, NodeId, NodeId, NodeId)>,
5312 HashSet<NodeId>,
5313) {
5314 use rlx_ir::op::BinaryOp;
5315 let mut consumers: HashMap<NodeId, usize> = HashMap::new();
5317 for node in graph.nodes() {
5318 for &input in &node.inputs {
5319 *consumers.entry(input).or_insert(0) += 1;
5320 }
5321 }
5322 for &out in &graph.outputs {
5323 *consumers.entry(out).or_insert(0) += 1;
5324 }
5325
5326 let mut ln_to_tee = HashMap::new();
5327 let mut skip_adds = HashSet::new();
5328 for node in graph.nodes() {
5329 let Op::LayerNorm { axis: _, eps: _ } = &node.op else {
5330 continue;
5331 };
5332 if node.inputs.len() < 3 {
5333 continue;
5334 } let in_id = node.inputs[0];
5336 let in_node = graph.node(in_id);
5337 if !matches!(in_node.op, Op::Binary(BinaryOp::Add)) {
5338 continue;
5339 }
5340 if consumers.get(&in_id).copied().unwrap_or(0) < 2 {
5343 continue;
5344 }
5345 if in_node.inputs.len() != 2 {
5348 continue;
5349 }
5350 let h_id = in_node.inputs[0];
5351 let delta_id = in_node.inputs[1];
5352 if graph.node(h_id).shape.dims() != node.shape.dims() {
5353 continue;
5354 }
5355 if graph.node(delta_id).shape.dims() != node.shape.dims() {
5356 continue;
5357 }
5358 let gamma_id = node.inputs[1];
5359 let beta_id = node.inputs[2];
5360 ln_to_tee.insert(node.id, (h_id, delta_id, gamma_id, beta_id, in_id));
5361 skip_adds.insert(in_id);
5362 }
5363 (ln_to_tee, skip_adds)
5364}
5365
5366fn detect_split_qkv_pattern(graph: &Graph) -> HashMap<NodeId, (NodeId, NodeId, NodeId)> {
5367 let mut consumers: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
5369 for node in graph.nodes() {
5370 for &input in &node.inputs {
5371 consumers.entry(input).or_default().push(node.id);
5372 }
5373 }
5374 for &out_id in &graph.outputs {
5377 consumers.entry(out_id).or_default().push(NodeId(u32::MAX));
5378 }
5379
5380 let mut result = HashMap::new();
5381 for node in graph.nodes() {
5382 if !matches!(node.op, Op::FusedMatMulBiasAct { activation: None }) {
5383 continue;
5384 }
5385 let cs = match consumers.get(&node.id) {
5386 Some(c) if c.len() == 3 => c,
5387 _ => continue,
5388 };
5389 let dims = node.shape.dims();
5390 if dims.is_empty() {
5391 continue;
5392 }
5393 let last_axis = dims.len() - 1;
5394 let n = dims[last_axis].unwrap_static();
5395 if n % 3 != 0 {
5396 continue;
5397 }
5398 let head_width = n / 3;
5399
5400 let mut narrows: Vec<(usize, NodeId)> = Vec::with_capacity(3);
5402 let mut all_match = true;
5403 for &c in cs {
5404 let cn = graph.node(c);
5405 match cn.op {
5406 Op::Narrow { axis, start, len }
5407 if axis == last_axis && len == head_width && cn.inputs[0] == node.id =>
5408 {
5409 narrows.push((start, c));
5410 }
5411 _ => {
5412 all_match = false;
5413 break;
5414 }
5415 }
5416 }
5417 if !all_match {
5418 continue;
5419 }
5420 narrows.sort_by_key(|&(start, _)| start);
5421 if narrows[0].0 != 0 || narrows[1].0 != head_width || narrows[2].0 != 2 * head_width {
5422 continue;
5423 }
5424 result.insert(node.id, (narrows[0].1, narrows[1].1, narrows[2].1));
5425 }
5426 result
5427}
5428
5429fn traces_to_param(graph: &Graph, mut id: NodeId) -> bool {
5435 loop {
5436 let node = graph.node(id);
5437 match &node.op {
5438 Op::Param { .. } => return true,
5439 Op::Cast { .. } | Op::Reshape { .. } => {
5440 if node.inputs.is_empty() {
5441 return false;
5442 }
5443 id = node.inputs[0];
5444 }
5445 _ => return false,
5446 }
5447 }
5448}
5449
5450#[allow(dead_code)]
5467fn push_cast_f32_to_f16_step(
5468 device: &wgpu::Device,
5469 arena: &Arena,
5470 schedule: &mut Vec<Step>,
5471 uniforms: &mut Vec<wgpu::Buffer>,
5472 bind_groups: &mut Vec<wgpu::BindGroup>,
5473 mm_cast: &Option<&'static Kernel>,
5474 src_off: u32,
5475 len: u32,
5476) {
5477 let kernel = match mm_cast {
5478 Some(k) => *k,
5479 None => return, };
5481 let f16_buf = match &arena.f16_buffer {
5482 Some(b) => b,
5483 None => return,
5484 };
5485 let p = CastF32ToF16Params {
5486 src_off,
5487 len,
5488 _p0: 0,
5489 _p1: 0,
5490 };
5491 let u = device.create_buffer(&wgpu::BufferDescriptor {
5492 label: Some("rlx-wgpu cast_f32_to_f16 uniform"),
5493 size: std::mem::size_of::<CastF32ToF16Params>() as u64,
5494 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
5495 mapped_at_creation: false,
5496 });
5497 let dev = wgpu_device().expect("rlx-wgpu: device gone");
5499 dev.queue.write_buffer(&u, 0, bytemuck::bytes_of(&p));
5500 let bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
5501 label: Some("rlx-wgpu cast_f32_to_f16 bg"),
5502 layout: &kernel.bgl,
5503 entries: &[
5504 wgpu::BindGroupEntry {
5505 binding: 0,
5506 resource: f16_buf.as_entire_binding(),
5507 },
5508 wgpu::BindGroupEntry {
5509 binding: 1,
5510 resource: u.as_entire_binding(),
5511 },
5512 wgpu::BindGroupEntry {
5513 binding: 2,
5514 resource: arena.buffer.as_entire_binding(),
5515 },
5516 ],
5517 });
5518 schedule.push(Step::CastF32ToF16 { params: p });
5519 uniforms.push(u);
5520 bind_groups.push(bg);
5521}
5522
5523fn build_matmul_bind_group(
5524 device: &wgpu::Device,
5525 mm_k: &Kernel,
5526 _mm_w: &Kernel,
5527 mm_f16w: &Option<&'static Kernel>,
5528 mm_f16c: &Option<&'static Kernel>,
5529 mm_coop: &Option<&'static Kernel>,
5530 mm_coop_f32: &Option<&'static Kernel>,
5531 arena: &Arena,
5532 params: &wgpu::Buffer,
5533 b_is_param: bool,
5534 compute_precision: MatmulCompute,
5535) -> wgpu::BindGroup {
5536 if b_is_param
5537 && compute_precision == MatmulCompute::CoopF32
5538 && let Some(coop_f32) = mm_coop_f32
5539 {
5540 return device.create_bind_group(&wgpu::BindGroupDescriptor {
5543 label: Some("rlx-wgpu matmul_coop_f32 bg"),
5544 layout: &coop_f32.bgl,
5545 entries: &[
5546 wgpu::BindGroupEntry {
5547 binding: 0,
5548 resource: arena.buffer.as_entire_binding(),
5549 },
5550 wgpu::BindGroupEntry {
5551 binding: 1,
5552 resource: params.as_entire_binding(),
5553 },
5554 ],
5555 });
5556 }
5557 if b_is_param
5558 && compute_precision == MatmulCompute::Coop16
5559 && let (Some(f16_buf), Some(coop)) = (&arena.f16_buffer, mm_coop)
5560 {
5561 return device.create_bind_group(&wgpu::BindGroupDescriptor {
5565 label: Some("rlx-wgpu matmul_coop16 bg"),
5566 layout: &coop.bgl,
5567 entries: &[
5568 wgpu::BindGroupEntry {
5569 binding: 0,
5570 resource: arena.buffer.as_entire_binding(),
5571 },
5572 wgpu::BindGroupEntry {
5573 binding: 1,
5574 resource: params.as_entire_binding(),
5575 },
5576 wgpu::BindGroupEntry {
5577 binding: 2,
5578 resource: f16_buf.as_entire_binding(),
5579 }, ],
5581 });
5582 }
5583 if b_is_param
5584 && compute_precision == MatmulCompute::F16
5585 && let (Some(f16_buf), Some(f16c)) = (&arena.f16_buffer, mm_f16c)
5586 {
5587 return device.create_bind_group(&wgpu::BindGroupDescriptor {
5588 label: Some("rlx-wgpu matmul_f16_compute bg"),
5589 layout: &f16c.bgl,
5590 entries: &[
5591 wgpu::BindGroupEntry {
5592 binding: 0,
5593 resource: arena.buffer.as_entire_binding(),
5594 },
5595 wgpu::BindGroupEntry {
5596 binding: 1,
5597 resource: params.as_entire_binding(),
5598 },
5599 wgpu::BindGroupEntry {
5600 binding: 2,
5601 resource: f16_buf.as_entire_binding(),
5602 },
5603 ],
5604 });
5605 }
5606 let f16w_opt_in = rlx_ir::env::flag("RLX_WGPU_F16_WEIGHTS");
5607 if b_is_param
5608 && f16w_opt_in
5609 && let (Some(f16_buf), Some(f16w)) = (&arena.f16_buffer, mm_f16w)
5610 {
5611 return device.create_bind_group(&wgpu::BindGroupDescriptor {
5612 label: Some("rlx-wgpu matmul_f16w bg"),
5613 layout: &f16w.bgl,
5614 entries: &[
5615 wgpu::BindGroupEntry {
5616 binding: 0,
5617 resource: arena.buffer.as_entire_binding(),
5618 },
5619 wgpu::BindGroupEntry {
5620 binding: 1,
5621 resource: params.as_entire_binding(),
5622 },
5623 wgpu::BindGroupEntry {
5624 binding: 2,
5625 resource: f16_buf.as_entire_binding(),
5626 },
5627 ],
5628 });
5629 }
5630 bind_two(device, mm_k, &arena.buffer, params)
5631}