1use std::collections::{HashMap, HashSet};
23use std::num::NonZeroU64;
24
25use rlx_ir::dynamic::{bind_graph, has_dynamic_dims, infer_bindings_from_f32_inputs, same_binding};
26use rlx_ir::op::{Activation, BinaryOp, CmpOp, MaskKind, ReduceOp};
27use rlx_ir::shape::DimBinding;
28use rlx_ir::{Graph, NodeId, Op};
29
30use crate::buffer::{
31 Arena, ReadbackLayout, ReadbackStaging, TinyReadbackStaging, decode_mapped_readback_f32,
32 decode_tiny_mapped_f32, encode_readback_copies, plan_f32_uniform, read_f32_many_pooled,
33 schedule_readback_map, use_tiny_readback, wait_readback_map,
34};
35use crate::device::wgpu_device;
36use crate::kernels::{
37 ArgmaxParams, AttentionBwdParams, AttentionParams, BatchElementwiseRegionParams, BinaryParams,
38 Conv1dParams, Conv2dParams, Conv3dParams, CopyParams, CumsumBwdParams, CumsumParams,
39 DequantMatmulParams, ElementwiseRegionParams, ExpandParams, FusedResidualLnParams,
40 FusedResidualLnTeeParams, FusedResidualRmsNormParams, GatherAxisParams, GatherBwdParams,
41 GatherParams, GroupedMatmulParams, Kernel, LayerNormBwdParams, LayerNormParams, MatmulParams,
42 MatmulQkvParams, NarrowConcatParams, Pool1dParams, Pool2dParams, Pool3dParams, ReduceParams,
43 RmsNormBwdParams, RopeBwdParams, RopeParams, SampleParams, ScatterAddParams,
44 SelectiveScanParams, SoftmaxParams, TopKParams, TransposeParams, UmapKnnParams, UnaryParams,
45 WelchPeaksGpuParams, WhereParams, argmax_kernel, attention_bwd_kernel, attention_kernel,
46 batch_elementwise_region_kernel, binary_kernel, cast_f32_to_f16_kernel, compare_kernel,
47 concat_kernel, conv1d_kernel, conv2d_kernel, conv3d_kernel, copy_kernel,
48 cumsum_backward_kernel, cumsum_kernel, dequant_matmul_kernel, elementwise_region_kernel,
49 elementwise_region_spatial_kernel, expand_kernel, fused_residual_ln_kernel,
50 fused_residual_ln_tee_kernel, fused_residual_rms_norm_kernel, gather_axis_kernel,
51 gather_backward_acc_kernel, gather_backward_zero_kernel, gather_kernel, grouped_matmul_kernel,
52 layer_norm_backward_gamma_partial_kernel, layer_norm_backward_gamma_reduce_kernel,
53 layer_norm_backward_input_kernel, layernorm_kernel, matmul_coop_f16_vulkan_active_kernel,
54 matmul_coop_f16_vulkan_kernel, matmul_coop_f32_active_kernel, matmul_coop16_kernel,
55 matmul_f16_compute_kernel, matmul_f16w_kernel, matmul_kernel,
56 matmul_qkv_coop_f16_vk_active_kernel, matmul_qkv_coop_f16_vk_kernel,
57 matmul_qkv_coop_f32_kernel, matmul_qkv_kernel, matmul_wide_active_kernel, matmul_wide_kernel,
58 narrow_kernel, pool1d_kernel, pool2d_kernel, pool3d_kernel, reduce_kernel,
59 rms_norm_backward_kernel, rms_norm_backward_param_kernel, rope_backward_kernel, rope_kernel,
60 sample_kernel, scatter_add_kernel, selective_scan_kernel, softmax_kernel, topk_kernel,
61 transpose_kernel, umap_knn_kernel, unary_f16_mirror_kernel, unary_kernel,
62 welch_peaks_gpu_kernel, where_kernel,
63};
64fn compute_scratch_bytes(graph: &rlx_ir::Graph) -> usize {
68 const ROWS_PER_WG: u32 = 16;
69 let mut max_bytes = 0usize;
70 for node in graph.nodes() {
71 if matches!(
77 &node.op,
78 rlx_ir::Op::LayerNorm { .. } | rlx_ir::Op::RmsNorm { .. }
79 ) {
80 let x_shape = &graph.node(node.inputs[0]).shape;
81 let h_dim = x_shape.dim(x_shape.rank() - 1);
82 if h_dim.is_static() {
83 let h = h_dim.unwrap_static();
84 let bytes = ((h * 4).div_ceil(256) * 256) * 2;
86 if bytes > max_bytes {
87 max_bytes = bytes;
88 }
89 }
90 }
91 if let rlx_ir::Op::LayerNormBackwardGamma { .. } = &node.op {
92 let x_shape = &graph.node(node.inputs[0]).shape;
93 let Some(elems) = x_shape.num_elements() else {
94 continue;
95 };
96 let h_dim = x_shape.dim(x_shape.rank() - 1);
97 if !h_dim.is_static() {
98 continue;
99 }
100 let h = h_dim.unwrap_static();
101 if h == 0 {
102 continue;
103 }
104 let rows = (elems / h) as u32;
105 let num_workgroups = rows.div_ceil(ROWS_PER_WG.max(1));
106 let bytes = (num_workgroups as usize) * h * 4;
107 if bytes > max_bytes {
108 max_bytes = bytes;
109 }
110 }
111 }
112 max_bytes.max(64 * 1024 * 1024)
116}
117
/// Returns a 64-bit FNV-1a-style fingerprint of `data`.
///
/// The slice length is mixed in first, so slices that differ only by trailing
/// zero values hash differently. The values are then consumed two at a time:
/// each pair is packed into one 64-bit word (first value in the low 32 bits,
/// second in the high 32 bits, a missing second value counting as zero) and
/// folded in with xor-then-multiply.
///
/// The words are built from `f32::to_bits`, so the result does not depend on
/// the host byte order. On little-endian hosts it equals the value produced by
/// reinterpreting the slice as bytes and reading 8-byte little-endian chunks.
/// Values are compared bit-for-bit: `0.0` and `-0.0` hash differently.
fn hash_f32_input(data: &[f32]) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

    let seed = (FNV_OFFSET_BASIS ^ data.len() as u64).wrapping_mul(FNV_PRIME);
    data.chunks(2).fold(seed, |hash, pair| {
        let low = u64::from(pair[0].to_bits());
        // An odd-length slice ends with a lone value; its high half is zero.
        let high = pair.get(1).map_or(0, |v| u64::from(v.to_bits()) << 32);
        (hash ^ (low | high)).wrapping_mul(FNV_PRIME)
    })
}
133
/// Compute variant recorded on a `Step::Matmul` (its `compute_precision`
/// field).
///
/// NOTE(review): the variant names line up with the `matmul_*` kernel
/// constructors imported from `crate::kernels` (plain f32, f16, coop16,
/// coop-f32, coop-f16 Vulkan); confirm the exact mapping at the dispatch site.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum MatmulCompute {
    F32,
    F16,
    Coop16,
    CoopF32,
    CoopF16Vk,
}
158
/// Kernel flavour recorded on a `Step::MatmulQkv` (its `kind` field).
///
/// NOTE(review): presumably selects between the `matmul_qkv_*` kernels
/// imported from `crate::kernels`; confirm at the dispatch site.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum MatmulQkvKind {
    F32,
    CoopF32,
    CoopF16Vk,
}
166
167#[allow(dead_code)]
177#[derive(Debug, Clone, Copy)]
178struct CastF32ToF16Params {
179 pub src_off: u32, pub len: u32,
181 pub _p0: u32,
182 pub _p1: u32,
183}
184unsafe impl bytemuck::Pod for CastF32ToF16Params {}
185unsafe impl bytemuck::Zeroable for CastF32ToF16Params {}
186
187#[allow(dead_code)]
196enum Step {
197 CastF32ToF16 {
198 params: CastF32ToF16Params,
199 },
200 Matmul {
201 m: u32,
202 k: u32,
203 n: u32,
204 a_off_f32: u32,
205 b_off_f32: u32,
206 c_off_f32: u32,
207 batch: u32,
208 a_batch_stride: u32,
209 b_batch_stride: u32,
210 c_batch_stride: u32,
211 has_bias: u32,
212 bias_off_f32: u32,
213 act_id: u32, b_is_param: bool,
220 compute_precision: MatmulCompute,
226 },
227 Binary {
228 params: BinaryParams,
229 },
230 Compare {
231 params: BinaryParams,
232 },
233 Unary {
234 params: UnaryParams,
235 f16_mirror: bool,
236 },
237 Where {
238 params: WhereParams,
239 },
240 Reduce {
241 params: ReduceParams,
242 },
243 Softmax {
244 params: SoftmaxParams,
245 },
246 LayerNorm {
247 params: LayerNormParams,
248 },
249 Cumsum {
250 params: CumsumParams,
251 },
252 FftGpu {
254 src_off: u32,
255 dst_off: u32,
256 outer: u32,
257 n: u32,
258 inverse: u32,
259 norm_scale: f32,
260 },
261 FftHost {
264 src_byte_off: u32,
265 dst_byte_off: u32,
266 outer: u32,
267 n_complex: u32,
268 inverse: bool,
269 norm_tag: u32,
270 dtype_tag: u32,
271 },
272 WelchPeaksHost {
274 spec_byte_off: u32,
275 dst_byte_off: u32,
276 welch_batch: u32,
277 n_fft: u32,
278 n_segments: u32,
279 k: u32,
280 },
281 LogMelHost {
282 spec_byte_off: u32,
283 filt_byte_off: u32,
284 dst_byte_off: u32,
285 outer: u32,
286 n_fft: u32,
287 n_bins: u32,
288 n_mels: u32,
289 },
290 LogMelBackwardHost {
291 spec_byte_off: u32,
292 filt_byte_off: u32,
293 dy_byte_off: u32,
294 dst_byte_off: u32,
295 outer: u32,
296 n_fft: u32,
297 n_bins: u32,
298 n_mels: u32,
299 },
300 Im2ColHost {
302 x_byte_off: u32,
303 col_byte_off: u32,
304 n: u32,
305 c_in: u32,
306 h: u32,
307 w: u32,
308 h_out: u32,
309 w_out: u32,
310 kh: u32,
311 kw: u32,
312 sh: u32,
313 sw: u32,
314 ph: u32,
315 pw: u32,
316 dh: u32,
317 dw_dil: u32,
318 },
319 BufferCopy {
323 src_byte_off: u32,
324 dst_byte_off: u32,
325 bytes: u32,
326 },
327 Copy {
328 params: CopyParams,
329 },
330 ElementwiseRegion {
336 params: ElementwiseRegionParams,
337 },
338 BatchElementwiseRegion {
339 params: BatchElementwiseRegionParams,
340 },
341 Transpose {
342 params: TransposeParams,
343 meta_idx: usize,
344 },
345 Narrow {
346 params: NarrowConcatParams,
347 },
348 Concat {
349 params: NarrowConcatParams,
350 }, Gather {
352 params: GatherParams,
353 },
354 GatherAxis {
355 params: GatherAxisParams,
356 },
357 Attention {
358 params: AttentionParams,
359 mask_buf: Option<wgpu::Buffer>,
360 },
361 AttentionBackward {
362 params: AttentionBwdParams,
363 mask_buf: Option<wgpu::Buffer>,
364 },
365 Rope {
366 params: RopeParams,
367 },
368 Expand {
369 params: ExpandParams,
370 meta_idx: usize,
371 },
372 Argmax {
373 params: ArgmaxParams,
374 },
375 Pool2d {
376 params: Pool2dParams,
377 },
378 Conv2d {
379 params: Conv2dParams,
380 },
381 Pool1d {
382 params: Pool1dParams,
383 },
384 Pool3d {
385 params: Pool3dParams,
386 },
387 Conv1d {
388 params: Conv1dParams,
389 },
390 Conv3d {
391 params: Conv3dParams,
392 },
393 ScatterAdd {
394 params: ScatterAddParams,
395 },
396 TopK {
397 params: TopKParams,
398 },
399 WelchPeaksGpu {
400 params: WelchPeaksGpuParams,
401 },
402 GroupedMatmul {
403 params: GroupedMatmulParams,
404 },
405 Sample {
406 params: SampleParams,
407 },
408 SelectiveScan {
409 params: SelectiveScanParams,
410 },
411 DequantMatmul {
412 params: DequantMatmulParams,
413 },
414 DequantMatmulGguf {
416 m: u32,
417 k: u32,
418 n: u32,
419 scheme_id: u32,
420 x_byte_off: u32,
421 w_byte_off: u32,
422 out_byte_off: u32,
423 },
424 DequantGroupedMatmulGguf {
426 m: u32,
427 k: u32,
428 n: u32,
429 num_experts: u32,
430 scheme_id: u32,
431 x_byte_off: u32,
432 w_byte_off: u32,
433 idx_byte_off: u32,
434 out_byte_off: u32,
435 },
436 GatedDeltaNet {
438 q_byte_off: u32,
439 k_byte_off: u32,
440 v_byte_off: u32,
441 g_byte_off: u32,
442 beta_byte_off: u32,
443 state_byte_off: u32,
444 dst_byte_off: u32,
445 batch: u32,
446 seq: u32,
447 heads: u32,
448 state_size: u32,
449 use_carry: bool,
450 },
451 Llada2GroupLimitedGate {
452 sig_byte_off: u32,
453 route_byte_off: u32,
454 out_byte_off: u32,
455 n_elems: u32,
456 attrs: [u8; 20],
457 },
458 UmapKnn {
459 params: UmapKnnParams,
460 },
461 UmapKnnHost {
463 pairwise_byte_off: u32,
464 out_byte_off: u32,
465 n: u32,
466 k: u32,
467 },
468 #[cfg(feature = "splat")]
470 GaussianSplatRender {
471 positions_byte_off: u32,
472 positions_len: u32,
473 scales_byte_off: u32,
474 scales_len: u32,
475 rotations_byte_off: u32,
476 rotations_len: u32,
477 opacities_byte_off: u32,
478 opacities_len: u32,
479 colors_byte_off: u32,
480 colors_len: u32,
481 sh_coeffs_byte_off: u32,
482 sh_coeffs_len: u32,
483 meta_byte_off: u32,
484 dst_byte_off: u32,
485 dst_len: u32,
486 width: u32,
487 height: u32,
488 tile_size: u32,
489 radius_scale: f32,
490 alpha_cutoff: f32,
491 max_splat_steps: u32,
492 transmittance_threshold: f32,
493 max_list_entries: u32,
494 },
495 #[cfg(feature = "splat")]
497 GaussianSplatRenderBackward {
498 positions_byte_off: u32,
499 positions_len: u32,
500 scales_byte_off: u32,
501 scales_len: u32,
502 rotations_byte_off: u32,
503 rotations_len: u32,
504 opacities_byte_off: u32,
505 opacities_len: u32,
506 colors_byte_off: u32,
507 colors_len: u32,
508 sh_coeffs_byte_off: u32,
509 sh_coeffs_len: u32,
510 meta_byte_off: u32,
511 d_loss_byte_off: u32,
512 d_loss_len: u32,
513 packed_byte_off: u32,
514 packed_len: u32,
515 width: u32,
516 height: u32,
517 tile_size: u32,
518 radius_scale: f32,
519 alpha_cutoff: f32,
520 max_splat_steps: u32,
521 transmittance_threshold: f32,
522 max_list_entries: u32,
523 loss_grad_clip: f32,
524 sh_band: u32,
525 max_anisotropy: f32,
526 },
527 #[cfg(feature = "splat")]
528 GaussianSplatPrepare {
529 positions_byte_off: u32,
530 positions_len: u32,
531 scales_byte_off: u32,
532 scales_len: u32,
533 rotations_byte_off: u32,
534 rotations_len: u32,
535 opacities_byte_off: u32,
536 opacities_len: u32,
537 colors_byte_off: u32,
538 colors_len: u32,
539 sh_coeffs_byte_off: u32,
540 sh_coeffs_len: u32,
541 meta_byte_off: u32,
542 meta_len: u32,
543 prep_byte_off: u32,
544 prep_len: u32,
545 width: u32,
546 height: u32,
547 tile_size: u32,
548 radius_scale: f32,
549 alpha_cutoff: f32,
550 max_splat_steps: u32,
551 transmittance_threshold: f32,
552 max_list_entries: u32,
553 },
554 #[cfg(feature = "splat")]
555 GaussianSplatRasterize {
556 prep_byte_off: u32,
557 prep_len: u32,
558 meta_byte_off: u32,
559 meta_len: u32,
560 dst_byte_off: u32,
561 dst_len: u32,
562 count: u32,
563 width: u32,
564 height: u32,
565 tile_size: u32,
566 alpha_cutoff: f32,
567 max_splat_steps: u32,
568 transmittance_threshold: f32,
569 max_list_entries: u32,
570 },
571 RmsNormBackwardInput {
572 params: RmsNormBwdParams,
573 },
574 RmsNormBackwardGamma {
575 params: RmsNormBwdParams,
576 },
577 RmsNormBackwardBeta {
578 params: RmsNormBwdParams,
579 },
580 LayerNormBackwardInput {
581 params: LayerNormBwdParams,
582 },
583 LayerNormBackwardGammaPartial {
584 params: LayerNormBwdParams,
585 num_workgroups: u32,
586 },
587 LayerNormBackwardGammaReduce {
588 params: LayerNormBwdParams,
589 },
590 RopeBackward {
591 params: RopeBwdParams,
592 },
593 CumsumBackward {
594 params: CumsumBwdParams,
595 },
596 GatherBackward {
597 params: GatherBwdParams,
598 },
599 FusedResidualLn {
600 params: FusedResidualLnParams,
601 },
602 MatmulQkv {
607 params: MatmulQkvParams,
608 kind: MatmulQkvKind,
609 },
610 FusedResidualLnTee {
614 params: FusedResidualLnTeeParams,
615 },
616 FusedResidualRmsNorm {
617 params: FusedResidualRmsNormParams,
618 },
619}
620
621pub struct WgpuExecutable {
622 graph: Graph,
623 arena: Arena,
624 schedule: Vec<Step>,
625 input_offsets: HashMap<String, NodeId>,
626 param_offsets: HashMap<String, NodeId>,
627 uniforms: Vec<wgpu::Buffer>,
630 bind_groups: Vec<wgpu::BindGroup>,
631 meta_buffers: Vec<wgpu::Buffer>,
634
635 unresolved: Option<Graph>,
642 last_binding: Option<DimBinding>,
643 pending_params: HashMap<String, Vec<f32>>,
647 pending_param_bytes: HashMap<String, Vec<u8>>,
648 pub(crate) active_extent: Option<(usize, usize)>,
652 uniforms_active_extent: Option<Option<(usize, usize)>>,
662 input_staging_hashes: HashMap<String, u64>,
664 coop_f16_vk: bool,
667 coop_f16_b_param: HashMap<u32, String>,
669 coop_f16_vk_wide_b: HashSet<String>,
671 coop_f16_vk_wide_bind_groups: HashMap<usize, wgpu::BindGroup>,
673 coop_f16_host_activations: Vec<(NodeId, Activation, String)>,
675 stashed_params: HashMap<String, Vec<f32>>,
677 readback_staging: Option<ReadbackStaging>,
679 tiny_readback: Option<TinyReadbackStaging>,
681 fft_gpu_steps: Vec<crate::fft_dispatch::FftGpuResources>,
683 gpu_handles: HashMap<String, Vec<f32>>,
685 gpu_handle_feeds: HashMap<String, usize>,
686 gpu_handle_resident: HashSet<String>,
688 pending_read_indices: Option<Vec<usize>>,
689}
690
691impl Step {
692 pub fn safe_for_active_extent(&self) -> bool {
698 match self {
699 Step::Binary { .. }
700 | Step::Compare { .. }
701 | Step::Unary { .. }
702 | Step::Where { .. }
703 | Step::Reduce { .. }
704 | Step::Softmax { .. }
705 | Step::LayerNorm { .. }
706 | Step::FusedResidualLn { .. }
707 | Step::FusedResidualLnTee { .. }
708 | Step::FusedResidualRmsNorm { .. }
709 | Step::Cumsum { .. }
710 | Step::Copy { .. }
711 | Step::ElementwiseRegion { .. }
712 | Step::BatchElementwiseRegion { .. }
713 | Step::Argmax { .. }
714 | Step::TopK { .. }
715 | Step::WelchPeaksGpu { .. }
716 | Step::Sample { .. }
717 | Step::Gather { .. }
718 | Step::GatherAxis { .. }
719 | Step::GroupedMatmul { .. }
720 | Step::DequantMatmul { .. }
721 | Step::DequantMatmulGguf { .. }
722 | Step::DequantGroupedMatmulGguf { .. }
723 | Step::GatedDeltaNet { .. }
724 | Step::Llada2GroupLimitedGate { .. }
725 | Step::UmapKnn { .. }
726 | Step::UmapKnnHost { .. }
727 | Step::Conv1d { .. }
728 | Step::Conv2d { .. }
729 | Step::Conv3d { .. }
730 | Step::Pool1d { .. }
731 | Step::Pool2d { .. }
732 | Step::Pool3d { .. }
733 | Step::ScatterAdd { .. }
734 | Step::BufferCopy { .. } => true,
735 Step::FftGpu { .. } | Step::FftHost { .. } => true,
740 Step::Im2ColHost { .. }
741 | Step::WelchPeaksHost { .. }
742 | Step::LogMelHost { .. }
743 | Step::LogMelBackwardHost { .. } => true,
744 Step::Matmul { .. } => true,
749 Step::MatmulQkv { .. } => true,
753 Step::CastF32ToF16 { .. } => true,
754 Step::Attention { .. } => true,
760 Step::AttentionBackward { .. } => true,
761 Step::SelectiveScan { .. } => true,
766 Step::Narrow { .. } => true,
775 Step::Concat { .. } => true,
776 Step::Rope { .. } => true,
782 Step::Transpose { params, .. } => params.bucket_outermost == 1,
788 Step::Expand { params, .. } => params.bucket_outermost == 1,
792 Step::RmsNormBackwardInput { .. }
795 | Step::RmsNormBackwardGamma { .. }
796 | Step::RmsNormBackwardBeta { .. }
797 | Step::LayerNormBackwardInput { .. }
798 | Step::LayerNormBackwardGammaPartial { .. }
799 | Step::LayerNormBackwardGammaReduce { .. }
800 | Step::RopeBackward { .. }
801 | Step::CumsumBackward { .. }
802 | Step::GatherBackward { .. } => false,
803 #[cfg(feature = "splat")]
804 Step::GaussianSplatRender { .. }
805 | Step::GaussianSplatRenderBackward { .. }
806 | Step::GaussianSplatPrepare { .. }
807 | Step::GaussianSplatRasterize { .. } => false,
808 }
809 }
810}
811
812fn fft_dtype_tag(dtype: rlx_ir::DType) -> u32 {
815 match dtype {
816 rlx_ir::DType::F32 => 0,
817 rlx_ir::DType::F64 => 1,
818 rlx_ir::DType::C64 => 2,
819 other => panic!("rlx-wgpu Op::Fft: unsupported dtype {other:?}"),
820 }
821}
822
823fn fft_dtype_from_tag(tag: u32) -> rlx_ir::DType {
824 match tag {
825 0 => rlx_ir::DType::F32,
826 1 => rlx_ir::DType::F64,
827 2 => rlx_ir::DType::C64,
828 other => panic!("rlx-wgpu Op::Fft: bad dtype tag {other}"),
829 }
830}
831
832fn step_name(step: &Step) -> &'static str {
833 match step {
834 Step::CastF32ToF16 { .. } => "cast_f32_to_f16",
835 Step::Matmul { .. } => "matmul",
836 Step::Binary { .. } => "binary",
837 Step::Compare { .. } => "compare",
838 Step::Unary { .. } => "unary",
839 Step::Where { .. } => "where",
840 Step::Reduce { .. } => "reduce",
841 Step::Softmax { .. } => "softmax",
842 Step::LayerNorm { .. } => "layer_norm",
843 Step::Cumsum { .. } => "cumsum",
844 Step::FftGpu { .. } => "fft_gpu",
845 Step::FftHost { .. } => "fft_host",
846 Step::WelchPeaksHost { .. } => "welch_peaks_host",
847 Step::LogMelHost { .. } => "log_mel_host",
848 Step::LogMelBackwardHost { .. } => "log_mel_backward_host",
849 Step::Im2ColHost { .. } => "im2col_host",
850 Step::BufferCopy { .. } => "buffer_copy",
851 Step::Copy { .. } => "copy",
852 Step::Transpose { .. } => "transpose",
853 Step::Narrow { .. } => "narrow",
854 Step::Concat { .. } => "concat",
855 Step::Gather { .. } => "gather",
856 Step::GatherAxis { .. } => "gather_axis",
857 Step::Attention { .. } => "attention",
858 Step::AttentionBackward { .. } => "attention_bwd",
859 Step::Rope { .. } => "rope",
860 Step::Expand { .. } => "expand",
861 Step::Argmax { .. } => "argmax",
862 Step::Pool2d { .. } => "pool2d",
863 Step::Conv2d { .. } => "conv2d",
864 Step::Pool1d { .. } => "pool1d",
865 Step::Pool3d { .. } => "pool3d",
866 Step::Conv1d { .. } => "conv1d",
867 Step::Conv3d { .. } => "conv3d",
868 Step::ScatterAdd { .. } => "scatter_add",
869 Step::TopK { .. } => "topk",
870 Step::WelchPeaksGpu { .. } => "welch_peaks_gpu",
871 Step::GroupedMatmul { .. } => "grouped_matmul",
872 Step::Sample { .. } => "sample",
873 Step::SelectiveScan { .. } => "selective_scan",
874 Step::DequantMatmul { .. } => "dequant_matmul",
875 Step::DequantMatmulGguf { .. } => "dequant_matmul_gguf",
876 Step::DequantGroupedMatmulGguf { .. } => "dequant_grouped_matmul_gguf",
877 Step::GatedDeltaNet { .. } => "gated_delta_net",
878 Step::Llada2GroupLimitedGate { .. } => "llada2_group_limited_gate",
879 Step::UmapKnn { .. } => "umap_knn",
880 Step::UmapKnnHost { .. } => "umap_knn_host",
881 #[cfg(feature = "splat")]
882 Step::GaussianSplatRender { .. } => "gaussian_splat_render",
883 #[cfg(feature = "splat")]
884 Step::GaussianSplatRenderBackward { .. } => "gaussian_splat_render_backward",
885 #[cfg(feature = "splat")]
886 Step::GaussianSplatPrepare { .. } => "gaussian_splat_prepare",
887 #[cfg(feature = "splat")]
888 Step::GaussianSplatRasterize { .. } => "gaussian_splat_rasterize",
889 Step::RmsNormBackwardInput { .. } => "rms_norm_backward_input",
890 Step::RmsNormBackwardGamma { .. } => "rms_norm_backward_gamma",
891 Step::RmsNormBackwardBeta { .. } => "rms_norm_backward_beta",
892 Step::LayerNormBackwardInput { .. } => "layer_norm_backward_input",
893 Step::LayerNormBackwardGammaPartial { .. } => "layer_norm_backward_gamma_partial",
894 Step::LayerNormBackwardGammaReduce { .. } => "layer_norm_backward_gamma_reduce",
895 Step::RopeBackward { .. } => "rope_backward",
896 Step::CumsumBackward { .. } => "cumsum_backward",
897 Step::GatherBackward { .. } => "gather_backward",
898 Step::FusedResidualLn { .. } => "fused_residual_ln",
899 Step::FusedResidualLnTee { .. } => "fused_residual_ln_tee",
900 Step::FusedResidualRmsNorm { .. } => "fused_residual_rms_norm",
901 Step::MatmulQkv { .. } => "matmul_qkv",
902 Step::ElementwiseRegion { .. } => "elementwise_region",
903 Step::BatchElementwiseRegion { .. } => "batch_elementwise_region",
904 }
905}
906
907fn step_is_tail_host(step: &Step) -> bool {
908 matches!(
909 step,
910 Step::WelchPeaksHost { .. } | Step::LogMelHost { .. } | Step::LogMelBackwardHost { .. }
911 )
912}
913
914fn step_runs_on_host(step: &Step) -> bool {
915 match step {
916 Step::DequantMatmulGguf { .. }
917 | Step::DequantGroupedMatmulGguf { .. }
918 | Step::GatedDeltaNet { .. }
919 | Step::Llada2GroupLimitedGate { .. }
920 | Step::UmapKnnHost { .. }
921 | Step::FftHost { .. }
922 | Step::Im2ColHost { .. }
923 | Step::BufferCopy { .. } => true,
924 #[cfg(feature = "splat")]
925 Step::GaussianSplatRender { .. }
926 | Step::GaussianSplatRenderBackward { .. }
927 | Step::GaussianSplatPrepare { .. }
928 | Step::GaussianSplatRasterize { .. } => true,
929 _ => false,
930 }
931}
932
933fn binary_op_id(op: BinaryOp) -> u32 {
934 match op {
935 BinaryOp::Add => 0,
936 BinaryOp::Sub => 1,
937 BinaryOp::Mul => 2,
938 BinaryOp::Div => 3,
939 BinaryOp::Max => 4,
940 BinaryOp::Min => 5,
941 BinaryOp::Pow => 6,
942 }
943}
944
945fn compare_op_id(op: CmpOp) -> u32 {
946 match op {
947 CmpOp::Eq => 0,
948 CmpOp::Ne => 1,
949 CmpOp::Lt => 2,
950 CmpOp::Le => 3,
951 CmpOp::Gt => 4,
952 CmpOp::Ge => 5,
953 }
954}
955
956fn reduce_op_id(op: ReduceOp) -> u32 {
957 match op {
958 ReduceOp::Sum => 0,
959 ReduceOp::Mean => 1,
960 ReduceOp::Max => 2,
961 ReduceOp::Min => 3,
962 ReduceOp::Prod => 4,
963 }
964}
965
966fn activation_op_id(act: Activation) -> u32 {
967 match act {
968 Activation::Relu => 0,
969 Activation::Sigmoid => 1,
970 Activation::Tanh => 2,
971 Activation::Exp => 3,
972 Activation::Log => 4,
973 Activation::Sqrt => 5,
974 Activation::Rsqrt => 6,
975 Activation::Neg => 7,
976 Activation::Abs => 8,
977 Activation::Gelu => 9,
978 Activation::Silu => 10,
979 Activation::GeluApprox => 11,
980 Activation::Round => 12,
981 Activation::Sin => 13,
982 Activation::Cos => 14,
983 Activation::Tan => 15,
984 Activation::Atan => 16,
985 }
986}
987
988impl WgpuExecutable {
989 fn lazy_compile_for_inputs(&mut self, inputs: &[(&str, &[f32])]) {
993 let unresolved = self
994 .unresolved
995 .as_ref()
996 .expect("lazy_compile_for_inputs called without an unresolved graph");
997 let binding = infer_bindings_from_f32_inputs(unresolved, inputs)
998 .expect("rlx-wgpu lazy compile: could not infer DimBinding from inputs");
999
1000 if let Some(prev) = &self.last_binding
1002 && same_binding(prev, &binding)
1003 {
1004 return;
1005 }
1006
1007 let resolved = bind_graph(unresolved, &binding);
1009 let original = self.unresolved.take();
1010 let pending_params = std::mem::take(&mut self.pending_params);
1011 let pending_bytes = std::mem::take(&mut self.pending_param_bytes);
1012
1013 let fresh = Self::compile_static_inner(resolved);
1014
1015 self.graph = fresh.graph;
1018 self.arena = fresh.arena;
1019 self.schedule = fresh.schedule;
1020 self.input_offsets = fresh.input_offsets;
1021 self.param_offsets = fresh.param_offsets;
1022 self.uniforms = fresh.uniforms;
1023 self.bind_groups = fresh.bind_groups;
1024 self.meta_buffers = fresh.meta_buffers;
1025 self.unresolved = original;
1026 self.last_binding = Some(binding);
1027 self.uniforms_active_extent = None;
1030 self.input_staging_hashes.clear();
1031 self.coop_f16_vk = fresh.coop_f16_vk;
1032 self.coop_f16_b_param = fresh.coop_f16_b_param;
1033 self.coop_f16_vk_wide_bind_groups = fresh.coop_f16_vk_wide_bind_groups;
1034 self.coop_f16_host_activations = fresh.coop_f16_host_activations;
1035
1036 for (name, data) in pending_params {
1038 self.set_param(&name, &data);
1039 }
1040 for (name, data) in pending_bytes {
1041 self.set_param_bytes(&name, &data);
1042 }
1043 }
1044
1045 pub fn compile_with_bindings(graph: Graph, bindings: &DimBinding) -> Self {
1051 if bindings.is_empty() {
1052 return Self::compile(graph);
1053 }
1054 let mut fresh = Graph::new(&graph.name);
1056 for node in graph.nodes() {
1057 let bound = node.shape.bind(bindings);
1058 fresh.add_node(node.op.clone(), node.inputs.clone(), bound);
1059 }
1060 fresh.set_outputs(graph.outputs.clone());
1061 Self::compile(fresh)
1062 }
1063
1064 pub fn compile(graph: Graph) -> Self {
1065 if has_dynamic_dims(&graph) {
1066 return Self::deferred(graph);
1067 }
1068 Self::compile_static_inner(graph)
1069 }
1070
1071 #[doc(hidden)]
1073 pub fn test_attn_q_seq_stride(&self) -> Option<u32> {
1074 self.schedule.iter().find_map(|s| {
1075 if let Step::Attention { params, .. } = s {
1076 Some(params.q_seq_stride)
1077 } else {
1078 None
1079 }
1080 })
1081 }
1082
1083 #[doc(hidden)]
1085 pub fn test_attn_offsets_and_stride(&self) -> Option<(u32, u32, u32, u32)> {
1086 self.schedule.iter().find_map(|s| {
1087 if let Step::Attention { params, .. } = s {
1088 Some((
1089 params.q_off,
1090 params.k_off,
1091 params.v_off,
1092 params.q_seq_stride,
1093 ))
1094 } else {
1095 None
1096 }
1097 })
1098 }
1099
1100 #[doc(hidden)]
1102 pub fn test_arena_offset_elems(&self, id: NodeId) -> u32 {
1103 (self.arena.offset(id) / 4) as u32
1104 }
1105
1106 fn deferred(graph: Graph) -> Self {
1111 let dev = wgpu_device().expect("rlx-wgpu: no compatible adapter found");
1112 let placeholder = dev.device.create_buffer(&wgpu::BufferDescriptor {
1114 label: Some("rlx-wgpu deferred placeholder"),
1115 size: 16,
1116 usage: wgpu::BufferUsages::STORAGE
1117 | wgpu::BufferUsages::COPY_DST
1118 | wgpu::BufferUsages::COPY_SRC,
1119 mapped_at_creation: false,
1120 });
1121 let arena = Arena {
1122 buffer: placeholder,
1123 f16_buffer: None,
1124 offsets: HashMap::new(),
1125 lens: HashMap::new(),
1126 size: 0,
1127 scratch_off: 0,
1128 scratch_bytes: 0,
1129 };
1130 Self {
1131 graph: graph.clone(),
1132 arena,
1133 schedule: Vec::new(),
1134 input_offsets: HashMap::new(),
1135 param_offsets: HashMap::new(),
1136 uniforms: Vec::new(),
1137 bind_groups: Vec::new(),
1138 meta_buffers: Vec::new(),
1139 unresolved: Some(graph),
1140 last_binding: None,
1141 pending_params: HashMap::new(),
1142 pending_param_bytes: HashMap::new(),
1143 active_extent: None,
1144 uniforms_active_extent: None,
1145 input_staging_hashes: HashMap::new(),
1146 coop_f16_vk: false,
1147 coop_f16_b_param: HashMap::new(),
1148 coop_f16_vk_wide_b: HashSet::new(),
1149 coop_f16_vk_wide_bind_groups: HashMap::new(),
1150 coop_f16_host_activations: Vec::new(),
1151 stashed_params: HashMap::new(),
1152 readback_staging: None,
1153 tiny_readback: None,
1154 fft_gpu_steps: Vec::new(),
1155 gpu_handles: HashMap::new(),
1156 gpu_handle_feeds: HashMap::new(),
1157 gpu_handle_resident: HashSet::new(),
1158 pending_read_indices: None,
1159 }
1160 }
1161
1162 pub fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
1166 self.active_extent = extent;
1167 }
1168
1169 fn all_safe_for_active(&self) -> bool {
1170 self.schedule.iter().all(|s| s.safe_for_active_extent())
1171 }
1172
1173 fn compile_static_inner(graph: Graph) -> Self {
1174 let dev = wgpu_device().expect("rlx-wgpu: no compatible adapter found");
1175
1176 let graph = crate::unfuse::unfuse(graph);
1183
1184 let plan = plan_f32_uniform(&graph, 16);
1186 let scratch_bytes = compute_scratch_bytes(&graph);
1190 let mut arena = Arena::from_plan_with_scratch(&dev.device, &plan, scratch_bytes);
1191 for node in graph.nodes() {
1195 let elems = node.shape.num_elements().unwrap_or(0);
1196 arena.set_actual_len(node.id, elems * 4);
1197 }
1198
1199 for node in graph.nodes() {
1201 if let Op::Constant { data } = &node.op
1202 && arena.has(node.id)
1203 && !data.is_empty()
1204 {
1205 let bytes_to_write = data.len().min(arena.len_of(node.id));
1206 dev.queue.write_buffer(
1207 &arena.buffer,
1208 arena.offset(node.id) as u64,
1209 &data[..bytes_to_write],
1210 );
1211 }
1212 }
1213
1214 let mut input_offsets = HashMap::new();
1215 let mut param_offsets = HashMap::new();
1216 for node in graph.nodes() {
1217 match &node.op {
1218 Op::Input { name } => {
1219 input_offsets.insert(name.clone(), node.id);
1220 }
1221 Op::Param { name } => {
1222 param_offsets.insert(name.clone(), node.id);
1223 }
1224 _ => {}
1225 }
1226 }
1227
1228 let mm_k = matmul_kernel(&dev.device);
1229 let mm_w = matmul_wide_kernel(&dev.device);
1230 let _mm_w_active = matmul_wide_active_kernel(&dev.device);
1231 let mm_f16w = matmul_f16w_kernel(&dev.device);
1232 let mm_f16c = matmul_f16_compute_kernel(&dev.device);
1233 let mm_coop = matmul_coop16_kernel(&dev.device);
1234 let mm_coop_f32 = matmul_coop_f32_active_kernel(&dev.device);
1235 let mm_cast = cast_f32_to_f16_kernel(&dev.device);
1236 let bk = binary_kernel(&dev.device);
1237 let uk = unary_kernel(&dev.device);
1238 let ck = compare_kernel(&dev.device);
1239 let wk = where_kernel(&dev.device);
1240
1241 let mut schedule = Vec::new();
1242 let mut uniforms = Vec::new();
1243 let mut bind_groups = Vec::new();
1244 let mut fft_gpu_steps: Vec<crate::fft_dispatch::FftGpuResources> = Vec::new();
1245 let mut gguf_host_pad: Option<(wgpu::Buffer, wgpu::BindGroup)> = None;
1246 let mut meta_buffers: Vec<wgpu::Buffer> = Vec::new();
1247 let mut coop_f16_b_param: HashMap<u32, String> = HashMap::new();
1248 let mut coop_f16_vk_wide_bind_groups: HashMap<usize, wgpu::BindGroup> = HashMap::new();
1249 let mm_w_active_compile = matmul_wide_active_kernel(&dev.device);
1250
1251 let coop_f16_vk_mirror_acts = collect_coop_f16_vk_mirror_activations(&graph, &dev.device);
1252
1253 let mut qkv_split: HashMap<NodeId, (NodeId, NodeId, NodeId)> = HashMap::new();
1266 for (parent_id, qkv) in detect_split_qkv_pattern(&graph) {
1267 let parent = graph.node(parent_id);
1268 let a_id = parent.inputs[0];
1271 let b_id = parent.inputs[1];
1272 let a_dims = graph.node(a_id).shape.dims();
1273 let b_dims = graph.node(b_id).shape.dims();
1274 let out_dims = parent.shape.dims();
1275 let (m, k, n) =
1276 if a_dims.len() >= 2 && b_dims.len() == 2 && out_dims.len() == a_dims.len() {
1277 let leading: usize = a_dims[..a_dims.len() - 2]
1278 .iter()
1279 .map(|d| d.unwrap_static())
1280 .product();
1281 let m_inner = a_dims[a_dims.len() - 2].unwrap_static();
1282 let k_inner = a_dims[a_dims.len() - 1].unwrap_static();
1283 let n_inner = b_dims[1].unwrap_static();
1284 ((leading * m_inner) as u32, k_inner as u32, n_inner as u32)
1285 } else if a_dims.len() == 2 && b_dims.len() == 2 {
1286 (
1287 a_dims[0].unwrap_static() as u32,
1288 a_dims[1].unwrap_static() as u32,
1289 b_dims[1].unwrap_static() as u32,
1290 )
1291 } else {
1292 continue; };
1294 let cp = derive_matmul_compute(
1295 &dev.device,
1296 &graph,
1297 &coop_f16_vk_mirror_acts,
1298 a_id,
1299 b_id,
1300 m,
1301 k,
1302 n,
1303 );
1304 if cp == MatmulCompute::F32 || cp == MatmulCompute::CoopF32 {
1309 qkv_split.insert(parent_id, qkv);
1310 }
1311 }
1312 let qkv_skip_narrows: HashSet<NodeId> = qkv_split
1313 .values()
1314 .flat_map(|&(q, k, v)| [q, k, v])
1315 .collect();
1316
1317 let mut packed_bshd_attn: HashMap<NodeId, (NodeId, u32)> = HashMap::new();
1321 let mut packed_bshd_skip_narrows: HashSet<NodeId> = HashSet::new();
1322 if !rlx_ir::env::flag("RLX_WGPU_NO_PACKED_BSHD_ATTN") {
1323 for node in graph.nodes() {
1324 let Op::Attention { .. } = &node.op else {
1325 continue;
1326 };
1327 if node.inputs.len() < 3 {
1328 continue;
1329 }
1330 if let Some((parent, head_width, narrows)) =
1331 rlx_ir::detect_packed_bshd_qkv_attention(
1332 &graph,
1333 node.inputs[0],
1334 node.inputs[1],
1335 node.inputs[2],
1336 )
1337 {
1338 packed_bshd_attn.insert(node.id, (parent, head_width as u32));
1339 for narrow in narrows {
1340 if rlx_ir::packed_bshd_narrow_elidable(&graph, narrow, node.id) {
1341 packed_bshd_skip_narrows.insert(narrow);
1342 }
1343 }
1344 }
1345 }
1346 }
1347
1348 let (ln_to_tee, skip_adds) = detect_residual_ln_tee_pattern(&graph);
1358
1359 let mut coop_f16_host_activations: Vec<(NodeId, Activation, String)> = Vec::new();
1360
1361 let emit_uniform = |size: usize| -> wgpu::Buffer {
1362 dev.device.create_buffer(&wgpu::BufferDescriptor {
1363 label: Some("rlx-wgpu uniform"),
1364 size: size as u64,
1365 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
1366 mapped_at_creation: false,
1367 })
1368 };
1369
1370 for node in graph.nodes() {
1371 let elems = node.shape.num_elements().unwrap_or(0) as u32;
1375 match &node.op {
1376 Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => continue,
1377 Op::MatMul => {
1378 let a_id = node.inputs[0];
1379 let b_id = node.inputs[1];
1380 let a_shape = graph.node(a_id).shape.dims();
1381 let b_shape = graph.node(b_id).shape.dims();
1382 let out_shape = node.shape.dims();
1383 let (m, k, n, batch, a_bs, b_bs, c_bs) = if a_shape.len() == 2
1388 && b_shape.len() == 2
1389 && out_shape.len() == 2
1390 {
1391 (
1392 a_shape[0].unwrap_static() as u32,
1393 a_shape[1].unwrap_static() as u32,
1394 b_shape[1].unwrap_static() as u32,
1395 1u32,
1396 0u32,
1397 0u32,
1398 0u32,
1399 )
1400 } else if a_shape.len() >= 2
1401 && b_shape.len() == 2
1402 && out_shape.len() == a_shape.len()
1403 {
1404 let leading: usize = a_shape[..a_shape.len() - 2]
1405 .iter()
1406 .map(|d| d.unwrap_static())
1407 .product();
1408 let m_inner = a_shape[a_shape.len() - 2].unwrap_static();
1409 let k_inner = a_shape[a_shape.len() - 1].unwrap_static();
1410 let n_inner = b_shape[1].unwrap_static();
1411 (
1412 (leading * m_inner) as u32,
1413 k_inner as u32,
1414 n_inner as u32,
1415 1u32,
1416 0u32,
1417 0u32,
1418 0u32,
1419 )
1420 } else if a_shape.len() == b_shape.len()
1421 && a_shape.len() >= 3
1422 && out_shape.len() == a_shape.len()
1423 {
1424 let leading_a: Vec<usize> = a_shape[..a_shape.len() - 2]
1426 .iter()
1427 .map(|d| d.unwrap_static())
1428 .collect();
1429 let leading_b: Vec<usize> = b_shape[..b_shape.len() - 2]
1430 .iter()
1431 .map(|d| d.unwrap_static())
1432 .collect();
1433 if leading_a != leading_b {
1434 panic!(
1435 "rlx-wgpu MatMul: batched shape mismatch \
1436 a_leading={leading_a:?} b_leading={leading_b:?}"
1437 );
1438 }
1439 let b_count: usize = leading_a.iter().product();
1440 let m_inner = a_shape[a_shape.len() - 2].unwrap_static();
1441 let k_inner = a_shape[a_shape.len() - 1].unwrap_static();
1442 let n_inner = b_shape[b_shape.len() - 1].unwrap_static();
1443 (
1444 m_inner as u32,
1445 k_inner as u32,
1446 n_inner as u32,
1447 b_count as u32,
1448 (m_inner * k_inner) as u32,
1449 (k_inner * n_inner) as u32,
1450 (m_inner * n_inner) as u32,
1451 )
1452 } else {
1453 panic!(
1454 "rlx-wgpu MatMul: unsupported shapes a={a_shape:?} b={b_shape:?} \
1455 out={out_shape:?} (supported: 2D×2D, [..,M,K]×[K,N], [..,M,K]×[..,K,N])"
1456 );
1457 };
1458 let b_is_param = tensor_is_graph_param(&graph, ¶m_offsets, b_id);
1459 let b_bytes = arena.len_of(b_id) as u64;
1460 let mut compute_precision = derive_matmul_compute(
1461 &dev.device,
1462 &graph,
1463 &coop_f16_vk_mirror_acts,
1464 a_id,
1465 b_id,
1466 m,
1467 k,
1468 n,
1469 );
1470 if b_is_param && b_bytes > ARENA_STAGE_CAP && arena.param_fits_f16_mirror(b_id)
1471 {
1472 compute_precision = MatmulCompute::F16;
1473 }
1474 let (mut base, mut size, param_anchor) = arena_matmul_bind_window(
1475 &dev.device,
1476 &arena,
1477 &graph,
1478 ¶m_offsets,
1479 node.id,
1480 a_id,
1481 b_id,
1482 );
1483 let max_binding = dev.device.limits().max_storage_buffer_binding_size;
1484 arena_expand_bind_window(
1485 &arena,
1486 &[node.id, a_id, b_id],
1487 &mut base,
1488 &mut size,
1489 max_binding,
1490 );
1491 let mut scratch = arena.scratch_off as u64;
1492 if param_anchor {
1493 arena_ensure_scratch_in_window(&mut scratch, base, size);
1494 }
1495 if b_is_param && b_bytes > ARENA_STAGE_CAP {
1496 assert!(
1497 param_anchor && arena_tensor_in_window(&arena, b_id, base, size),
1498 "rlx-wgpu matmul: large param B {:?} off={} not in window base={base} size={size}",
1499 b_id,
1500 arena.offset(b_id),
1501 );
1502 }
1503 let a_off_f32 = arena_off_in_bind_window(
1504 &graph,
1505 ¶m_offsets,
1506 &dev.device,
1507 &arena,
1508 &mut schedule,
1509 &mut scratch,
1510 a_id,
1511 &mut base,
1512 &mut size,
1513 );
1514 let b_off_f32 = if b_is_param
1515 && b_bytes > ARENA_STAGE_CAP
1516 && arena_tensor_in_window(&arena, b_id, base, size)
1517 {
1518 arena_local_off_f32(&arena, b_id, base)
1519 } else {
1520 arena_off_in_bind_window(
1521 &graph,
1522 ¶m_offsets,
1523 &dev.device,
1524 &arena,
1525 &mut schedule,
1526 &mut scratch,
1527 b_id,
1528 &mut base,
1529 &mut size,
1530 )
1531 };
1532 maybe_push_coop_f16_vk_casts(
1533 &graph,
1534 a_id,
1535 b_id,
1536 &coop_f16_vk_mirror_acts,
1537 &dev.device,
1538 &arena,
1539 &mut schedule,
1540 &mut uniforms,
1541 &mut bind_groups,
1542 &mm_cast,
1543 compute_precision,
1544 a_off_f32,
1545 m,
1546 k,
1547 batch,
1548 b_off_f32,
1549 n,
1550 );
1551 schedule.push(Step::Matmul {
1552 m,
1553 k,
1554 n,
1555 batch,
1556 a_batch_stride: a_bs,
1557 b_batch_stride: b_bs,
1558 c_batch_stride: c_bs,
1559 a_off_f32,
1560 b_off_f32,
1561 c_off_f32: arena_local_off_f32(&arena, node.id, base),
1562 has_bias: 0,
1563 bias_off_f32: 0,
1564 act_id: 0xFFFF,
1565 b_is_param,
1566 compute_precision,
1567 });
1568 let b_off_global = (arena.offset(b_id) / 4) as u32;
1569 let b_off_bind = if b_is_param
1570 && matches!(
1571 compute_precision,
1572 MatmulCompute::Coop16 | MatmulCompute::CoopF16Vk | MatmulCompute::F16
1573 ) {
1574 b_off_global
1575 } else {
1576 b_off_f32
1577 };
1578 register_coop_f16_vk_b_param(
1579 &mut coop_f16_b_param,
1580 ¶m_offsets,
1581 b_id,
1582 b_off_bind,
1583 compute_precision,
1584 );
1585 let u = emit_uniform(std::mem::size_of::<MatmulParams>());
1586 let (bg, b_off_adj) = build_matmul_bind_group(
1587 &dev.device,
1588 mm_k,
1589 mm_w,
1590 &mm_f16w,
1591 &mm_f16c,
1592 &mm_coop,
1593 &mm_coop_f32,
1594 &arena,
1595 base,
1596 size,
1597 &u,
1598 b_is_param,
1599 compute_precision,
1600 k,
1601 n,
1602 batch,
1603 b_off_bind,
1604 b_bs,
1605 );
1606 if let Some(Step::Matmul { b_off_f32, .. }) = schedule.last_mut() {
1607 *b_off_f32 = b_off_adj;
1608 }
1609 uniforms.push(u);
1610 bind_groups.push(bg);
1611 if compute_precision == MatmulCompute::CoopF16Vk {
1612 coop_f16_vk_wide_bind_groups.insert(
1613 schedule.len() - 1,
1614 bind_two_buf0_window(
1615 &dev.device,
1616 mm_w_active_compile,
1617 &arena.buffer,
1618 base,
1619 size,
1620 &uniforms[uniforms.len() - 1],
1621 ),
1622 );
1623 }
1624 }
1625 Op::Binary(bop) => {
1626 if skip_adds.contains(&node.id) {
1631 continue;
1632 }
1633 require_equal_shapes(&graph, &node.inputs, "Binary");
1634 let a_id = node.inputs[0];
1635 let b_id = node.inputs[1];
1636 let win_ids = [node.id, a_id, b_id];
1637 let max_binding = dev.device.limits().max_storage_buffer_binding_size;
1638 let fits = arena_span_bytes(&arena, &win_ids) <= max_binding;
1639 let mut scratch = arena.scratch_off as u64;
1640 let (mut base, mut size, param_anchor) = arena_multi_op_window(
1641 &dev.device,
1642 &arena,
1643 &graph,
1644 ¶m_offsets,
1645 &mut schedule,
1646 &mut scratch,
1647 &win_ids,
1648 );
1649 if !fits && !param_anchor {
1650 base = arena_bind_window_covering_scratch_if_needed(
1651 &arena, base, size, scratch,
1652 );
1653 }
1654 let a_off = arena_off_in_bind_window(
1655 &graph,
1656 ¶m_offsets,
1657 &dev.device,
1658 &arena,
1659 &mut schedule,
1660 &mut scratch,
1661 a_id,
1662 &mut base,
1663 &mut size,
1664 );
1665 let b_off = arena_off_in_bind_window(
1666 &graph,
1667 ¶m_offsets,
1668 &dev.device,
1669 &arena,
1670 &mut schedule,
1671 &mut scratch,
1672 b_id,
1673 &mut base,
1674 &mut size,
1675 );
1676 let p = BinaryParams {
1677 n: elems,
1678 a_off,
1679 b_off,
1680 c_off: arena_local_off_f32(&arena, node.id, base),
1681 op: binary_op_id(*bop),
1682 _p0: 0,
1683 _p1: 0,
1684 _p2: 0,
1685 };
1686 schedule.push(Step::Binary { params: p });
1687 let u = emit_uniform(std::mem::size_of::<BinaryParams>());
1688 let bg = bind_two_buf0_window(&dev.device, bk, &arena.buffer, base, size, &u);
1689 uniforms.push(u);
1690 bind_groups.push(bg);
1691 }
1692 Op::Compare(cop) => {
1693 require_equal_shapes(&graph, &node.inputs, "Compare");
1694 let (mut base, size) = arena_window_for_nodes(&dev.device, &arena, &[node.id]);
1695 let a_id = node.inputs[0];
1696 let b_id = node.inputs[1];
1697 let a_src = arena.offset(a_id) as u64;
1698 let b_src = arena.offset(b_id) as u64;
1699 let a_len = arena.len_of(a_id) as u64;
1700 let b_len = arena.len_of(b_id) as u64;
1701 let a_in = a_src >= base && a_src + a_len <= base + size;
1702 let b_in = b_src >= base && b_src + b_len <= base + size;
1703 let a_dst = arena.scratch_off as u64;
1704 let a_aligned = a_len.div_ceil(256) * 256;
1705 let b_dst = a_dst + a_aligned;
1706 if a_dst < base || b_dst + b_len > base + size {
1707 base = (arena.size as u64).saturating_sub(size);
1708 base = (base / 256) * 256;
1709 }
1710 let a_off = if a_in {
1711 arena_local_off_f32(&arena, a_id, base)
1712 } else {
1713 if a_len > 64 * 1024 * 1024 {
1714 panic!("rlx-wgpu: Compare staging operand A too large ({a_len} bytes)");
1715 }
1716 schedule.push(Step::BufferCopy {
1717 src_byte_off: a_src as u32,
1718 dst_byte_off: a_dst as u32,
1719 bytes: a_len as u32,
1720 });
1721 ((a_dst.saturating_sub(base)) / 4) as u32
1722 };
1723 let b_off = if b_in {
1724 arena_local_off_f32(&arena, b_id, base)
1725 } else {
1726 if b_len > 64 * 1024 * 1024 {
1727 panic!("rlx-wgpu: Compare staging operand B too large ({b_len} bytes)");
1728 }
1729 schedule.push(Step::BufferCopy {
1730 src_byte_off: b_src as u32,
1731 dst_byte_off: b_dst as u32,
1732 bytes: b_len as u32,
1733 });
1734 ((b_dst.saturating_sub(base)) / 4) as u32
1735 };
1736 let p = BinaryParams {
1737 n: elems,
1738 a_off,
1739 b_off,
1740 c_off: arena_local_off_f32(&arena, node.id, base),
1741 op: compare_op_id(*cop),
1742 _p0: 0,
1743 _p1: 0,
1744 _p2: 0,
1745 };
1746 schedule.push(Step::Compare { params: p });
1747 let u = emit_uniform(std::mem::size_of::<BinaryParams>());
1748 let bg = bind_two_buf0_window(&dev.device, ck, &arena.buffer, base, size, &u);
1749 uniforms.push(u);
1750 bind_groups.push(bg);
1751 }
1752 Op::Activation(act) => {
1753 if coop_f16_vk_mirror_acts.contains(&node.id) {
1754 let src_name =
1755 tensor_host_name(&input_offsets, ¶m_offsets, node.inputs[0]);
1756 coop_f16_host_activations.push((node.id, *act, src_name));
1757 continue;
1758 }
1759 let in_id = node.inputs[0];
1760 let win_ids = [node.id, in_id];
1761 let max_binding = dev.device.limits().max_storage_buffer_binding_size;
1762 let fits = arena_span_bytes(&arena, &win_ids) <= max_binding;
1763 let mut scratch = arena.scratch_off as u64;
1764 let (mut base, mut size, param_anchor) = arena_multi_op_window(
1765 &dev.device,
1766 &arena,
1767 &graph,
1768 ¶m_offsets,
1769 &mut schedule,
1770 &mut scratch,
1771 &win_ids,
1772 );
1773 if !fits && !param_anchor {
1774 base = arena_bind_window_covering_scratch_if_needed(
1775 &arena, base, size, scratch,
1776 );
1777 }
1778 let in_off = arena_off_in_bind_window(
1779 &graph,
1780 ¶m_offsets,
1781 &dev.device,
1782 &arena,
1783 &mut schedule,
1784 &mut scratch,
1785 in_id,
1786 &mut base,
1787 &mut size,
1788 );
1789 let p = UnaryParams {
1790 n: elems,
1791 in_off,
1792 out_off: arena_local_off_f32(&arena, node.id, base),
1793 op: activation_op_id(*act),
1794 _p0: 0,
1795 _p1: 0,
1796 _p2: 0,
1797 _p3: 0,
1798 };
1799 schedule.push(Step::Unary {
1800 params: p,
1801 f16_mirror: false,
1802 });
1803 let u = emit_uniform(std::mem::size_of::<UnaryParams>());
1804 let bg = bind_two_buf0_window(&dev.device, uk, &arena.buffer, base, size, &u);
1805 uniforms.push(u);
1806 bind_groups.push(bg);
1807 }
1808 Op::Where => {
1809 let (mut base, size) = arena_window_for_nodes(&dev.device, &arena, &[node.id]);
1810 let cond_id = node.inputs[0];
1811 let x_id = node.inputs[1];
1812 let y_id = node.inputs[2];
1813 let cond_src = arena.offset(cond_id) as u64;
1814 let x_src = arena.offset(x_id) as u64;
1815 let y_src = arena.offset(y_id) as u64;
1816 let cond_len = arena.len_of(cond_id) as u64;
1817 let x_len = arena.len_of(x_id) as u64;
1818 let y_len = arena.len_of(y_id) as u64;
1819 let cond_in = cond_src >= base && cond_src + cond_len <= base + size;
1820 let x_in = x_src >= base && x_src + x_len <= base + size;
1821 let y_in = y_src >= base && y_src + y_len <= base + size;
1822 let cond_dst = arena.scratch_off as u64;
1823 let cond_aligned = cond_len.div_ceil(256) * 256;
1824 let x_dst = cond_dst + cond_aligned;
1825 let x_aligned = x_len.div_ceil(256) * 256;
1826 let y_dst = x_dst + x_aligned;
1827 if cond_dst < base || y_dst + y_len > base + size {
1828 base = (arena.size as u64).saturating_sub(size);
1829 base = (base / 256) * 256;
1830 }
1831 let cond_off = if cond_in {
1832 arena_local_off_f32(&arena, cond_id, base)
1833 } else {
1834 if cond_len > 64 * 1024 * 1024 {
1835 panic!("rlx-wgpu: Where staging cond too large ({cond_len} bytes)");
1836 }
1837 schedule.push(Step::BufferCopy {
1838 src_byte_off: cond_src as u32,
1839 dst_byte_off: cond_dst as u32,
1840 bytes: cond_len as u32,
1841 });
1842 ((cond_dst.saturating_sub(base)) / 4) as u32
1843 };
1844 let x_off = if x_in {
1845 arena_local_off_f32(&arena, x_id, base)
1846 } else {
1847 if x_len > 64 * 1024 * 1024 {
1848 panic!("rlx-wgpu: Where staging x too large ({x_len} bytes)");
1849 }
1850 schedule.push(Step::BufferCopy {
1851 src_byte_off: x_src as u32,
1852 dst_byte_off: x_dst as u32,
1853 bytes: x_len as u32,
1854 });
1855 ((x_dst.saturating_sub(base)) / 4) as u32
1856 };
1857 let y_off = if y_in {
1858 arena_local_off_f32(&arena, y_id, base)
1859 } else {
1860 if y_len > 64 * 1024 * 1024 {
1861 panic!("rlx-wgpu: Where staging y too large ({y_len} bytes)");
1862 }
1863 schedule.push(Step::BufferCopy {
1864 src_byte_off: y_src as u32,
1865 dst_byte_off: y_dst as u32,
1866 bytes: y_len as u32,
1867 });
1868 ((y_dst.saturating_sub(base)) / 4) as u32
1869 };
1870 let p = WhereParams {
1871 n: elems,
1872 cond_off,
1873 x_off,
1874 y_off,
1875 out_off: arena_local_off_f32(&arena, node.id, base),
1876 _p0: 0,
1877 _p1: 0,
1878 _p2: 0,
1879 };
1880 schedule.push(Step::Where { params: p });
1881 let u = emit_uniform(std::mem::size_of::<WhereParams>());
1882 let bg = bind_two_buf0_window(&dev.device, wk, &arena.buffer, base, size, &u);
1883 uniforms.push(u);
1884 bind_groups.push(bg);
1885 }
1886
1887 Op::BatchElementwiseRegion {
1888 chain,
1889 num_batch_inputs,
1890 scalar_input_mask,
1891 input_modulus,
1892 prologue,
1893 prologue_input,
1894 } => {
1895 let n = *num_batch_inputs as usize;
1896 if n == 0 || chain.len() > 32 {
1897 panic!(
1898 "rlx-wgpu BatchElementwiseRegion: num_batch_inputs={n} steps={}",
1899 chain.len()
1900 );
1901 }
1902 let slice_shape = rlx_ir::batch_region_slice_shape(&node.shape);
1903 let slice_elems = rlx_ir::batch_region_slice_elems(&node.shape, n)
1904 .expect("batch region static shape");
1905 let mut win_ids: Vec<NodeId> = vec![node.id];
1906 win_ids.extend(node.inputs.iter().copied());
1907 let max_binding = dev.device.limits().max_storage_buffer_binding_size;
1908 let fits = arena_span_bytes(&arena, &win_ids) <= max_binding;
1909 let mut scratch = arena.scratch_off as u64;
1910 let (mut base, mut size, param_anchor) = arena_multi_op_window(
1911 &dev.device,
1912 &arena,
1913 &graph,
1914 ¶m_offsets,
1915 &mut schedule,
1916 &mut scratch,
1917 &win_ids,
1918 );
1919 if !fits && !param_anchor {
1920 base = arena_bind_window_covering_scratch_if_needed(
1921 &arena, base, size, scratch,
1922 );
1923 }
1924 let chain_enc = rlx_ir::encode_chain_steps(chain);
1925 let tail =
1926 rlx_ir::encode_prologue_tail(*prologue, &slice_shape, *prologue_input);
1927 let base_dst = arena_local_off_f32(&arena, node.id, base);
1928 let use_single = rlx_ir::fk_batch_use_single_launch(n, *prologue);
1929 if use_single {
1930 let mut batch_input_offs = [0u32; 64];
1931 for i in 0..n {
1932 batch_input_offs[i] = arena_off_in_bind_window(
1933 &graph,
1934 ¶m_offsets,
1935 &dev.device,
1936 &arena,
1937 &mut schedule,
1938 &mut scratch,
1939 node.inputs[i],
1940 &mut base,
1941 &mut size,
1942 );
1943 }
1944 let p = BatchElementwiseRegionParams {
1945 slice_len: slice_elems,
1946 num_batch: n as u32,
1947 num_steps: chain.len() as u32,
1948 base_dst_off: base_dst,
1949 slice_elems,
1950 batch_input_offs,
1951 chain: chain_enc,
1952 scalar_input_mask: *scalar_input_mask,
1953 input_modulus: *input_modulus,
1954 };
1955 schedule.push(Step::BatchElementwiseRegion { params: p });
1956 let ek = batch_elementwise_region_kernel(&dev.device);
1957 let u = dev.device.create_buffer(&wgpu::BufferDescriptor {
1958 label: Some("rlx-wgpu batch region params"),
1959 size: std::mem::size_of::<BatchElementwiseRegionParams>() as u64,
1960 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
1961 mapped_at_creation: false,
1962 });
1963 let bg =
1964 bind_two_buf0_window(&dev.device, ek, &arena.buffer, base, size, &u);
1965 uniforms.push(u);
1966 bind_groups.push(bg);
1967 } else {
1968 let spatial = tail[0] == rlx_ir::REGION_PROLOGUE_RESIZE_NEAREST_2X_NCHW;
1969 let ek = if spatial {
1970 elementwise_region_spatial_kernel(&dev.device)
1971 } else {
1972 elementwise_region_kernel(&dev.device)
1973 };
1974 for i in 0..n {
1975 let mut input_offs = [0u32; 16];
1976 input_offs[0] = arena_off_in_bind_window(
1977 &graph,
1978 ¶m_offsets,
1979 &dev.device,
1980 &arena,
1981 &mut schedule,
1982 &mut scratch,
1983 node.inputs[i],
1984 &mut base,
1985 &mut size,
1986 );
1987 let p = ElementwiseRegionParams {
1988 len: slice_elems,
1989 num_inputs: 1,
1990 num_steps: chain.len() as u32,
1991 dst_off: rlx_ir::batch_region_slice_dst_off_f32(
1992 base_dst,
1993 slice_elems,
1994 i,
1995 ),
1996 input_offs,
1997 chain: chain_enc,
1998 scalar_input_mask: *scalar_input_mask,
1999 prologue: tail[0],
2000 out_n: tail[1],
2001 out_c: tail[2],
2002 out_h: tail[3],
2003 out_w: tail[4],
2004 prologue_input: tail[5],
2005 input_modulus: *input_modulus,
2006 };
2007 schedule.push(Step::ElementwiseRegion { params: p });
2008 let u = dev.device.create_buffer(&wgpu::BufferDescriptor {
2009 label: Some("rlx-wgpu batch region params"),
2010 size: std::mem::size_of::<ElementwiseRegionParams>() as u64,
2011 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
2012 mapped_at_creation: false,
2013 });
2014 let bg = bind_two_buf0_window(
2015 &dev.device,
2016 ek,
2017 &arena.buffer,
2018 base,
2019 size,
2020 &u,
2021 );
2022 uniforms.push(u);
2023 bind_groups.push(bg);
2024 }
2025 }
2026 }
2027 Op::ElementwiseRegion {
2028 chain,
2029 num_inputs,
2030 scalar_input_mask,
2031 input_modulus,
2032 prologue,
2033 prologue_input,
2034 } => {
2035 let n = *num_inputs as usize;
2038 if n > 16 || chain.len() > 32 {
2039 panic!(
2040 "rlx-wgpu ElementwiseRegion: chain too large \
2041 (inputs={n}, steps={}). Caps: 16 / 32. \
2042 Use UnfuseElementwiseRegions to fall back.",
2043 chain.len()
2044 );
2045 }
2046 let mut win_ids: Vec<NodeId> = vec![node.id];
2047 win_ids.extend(node.inputs.iter().copied());
2048 let max_binding = dev.device.limits().max_storage_buffer_binding_size;
2049 let fits = arena_span_bytes(&arena, &win_ids) <= max_binding;
2050 let mut scratch = arena.scratch_off as u64;
2051 let (mut base, mut size, param_anchor) = arena_multi_op_window(
2052 &dev.device,
2053 &arena,
2054 &graph,
2055 ¶m_offsets,
2056 &mut schedule,
2057 &mut scratch,
2058 &win_ids,
2059 );
2060 if !fits && !param_anchor {
2061 base = arena_bind_window_covering_scratch_if_needed(
2062 &arena, base, size, scratch,
2063 );
2064 }
2065 let mut input_offs = [0u32; 16];
2066 for (i, &id) in node.inputs.iter().enumerate() {
2067 input_offs[i] = arena_off_in_bind_window(
2068 &graph,
2069 ¶m_offsets,
2070 &dev.device,
2071 &arena,
2072 &mut schedule,
2073 &mut scratch,
2074 id,
2075 &mut base,
2076 &mut size,
2077 );
2078 }
2079 let chain_enc = rlx_ir::encode_chain_steps(chain);
2080 let tail =
2081 rlx_ir::encode_prologue_tail(*prologue, &node.shape, *prologue_input);
2082 let p = ElementwiseRegionParams {
2083 len: elems,
2084 num_inputs: *num_inputs,
2085 num_steps: chain.len() as u32,
2086 dst_off: arena_local_off_f32(&arena, node.id, base),
2087 input_offs,
2088 chain: chain_enc,
2089 scalar_input_mask: *scalar_input_mask,
2090 prologue: tail[0],
2091 out_n: tail[1],
2092 out_c: tail[2],
2093 out_h: tail[3],
2094 out_w: tail[4],
2095 prologue_input: tail[5],
2096 input_modulus: *input_modulus,
2097 };
2098 schedule.push(Step::ElementwiseRegion { params: p });
2099 let ek = if p.prologue == rlx_ir::REGION_PROLOGUE_RESIZE_NEAREST_2X_NCHW {
2100 elementwise_region_spatial_kernel(&dev.device)
2101 } else {
2102 elementwise_region_kernel(&dev.device)
2103 };
2104 let u = dev.device.create_buffer(&wgpu::BufferDescriptor {
2108 label: Some("rlx-wgpu region params"),
2109 size: std::mem::size_of::<ElementwiseRegionParams>() as u64,
2110 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
2111 mapped_at_creation: false,
2112 });
2113 let bg = bind_two_buf0_window(&dev.device, ek, &arena.buffer, base, size, &u);
2114 uniforms.push(u);
2115 bind_groups.push(bg);
2116 }
2117
2118 Op::Reduce {
2119 op: rop,
2120 axes,
2121 keep_dim: _,
2122 } => {
2123 let in_id = node.inputs[0];
2131 let in_shape = graph.node(in_id).shape.dims();
2132 let mut sorted = axes.clone();
2133 sorted.sort_unstable();
2134 let contiguous = sorted.windows(2).all(|w| w[1] == w[0] + 1);
2135 if !contiguous {
2136 panic!(
2137 "rlx-wgpu Reduce: non-contiguous axes not yet wired \
2138 (got axes={axes:?}, rank={})",
2139 in_shape.len()
2140 );
2141 }
2142 let ax_first = sorted[0];
2143 let ax_last = *sorted.last().unwrap();
2144 let dims_u32: Vec<u32> =
2145 in_shape.iter().map(|d| d.unwrap_static() as u32).collect();
2146 let outer: u32 = dims_u32[..ax_first].iter().product();
2147 let reduce_dim: u32 = dims_u32[ax_first..=ax_last].iter().product();
2148 let inner: u32 = dims_u32[ax_last + 1..].iter().product();
2149 let red_ids = [node.id, in_id];
2150 let max_binding = dev.device.limits().max_storage_buffer_binding_size;
2151 let red_fits = arena_span_bytes(&arena, &red_ids) <= max_binding;
2152 let mut scratch = arena.scratch_off as u64;
2153 let (mut base, mut size, param_anchor) = arena_multi_op_window(
2154 &dev.device,
2155 &arena,
2156 &graph,
2157 ¶m_offsets,
2158 &mut schedule,
2159 &mut scratch,
2160 &red_ids,
2161 );
2162 if !red_fits && !param_anchor {
2163 base = arena_bind_window_covering_scratch_if_needed(
2164 &arena, base, size, scratch,
2165 );
2166 }
2167 let in_off = arena_off_in_bind_window(
2168 &graph,
2169 ¶m_offsets,
2170 &dev.device,
2171 &arena,
2172 &mut schedule,
2173 &mut scratch,
2174 in_id,
2175 &mut base,
2176 &mut size,
2177 );
2178 let p = ReduceParams {
2179 outer,
2180 reduce_dim,
2181 inner,
2182 in_off,
2183 out_off: arena_local_off_f32(&arena, node.id, base),
2184 op: reduce_op_id(*rop),
2185 _p0: 0,
2186 _p1: 0,
2187 };
2188 schedule.push(Step::Reduce { params: p });
2189 let rk = reduce_kernel(&dev.device);
2190 let u = emit_uniform(std::mem::size_of::<ReduceParams>());
2191 let bg = bind_two_buf0_window(&dev.device, rk, &arena.buffer, base, size, &u);
2192 uniforms.push(u);
2193 bind_groups.push(bg);
2194 }
2195
2196 Op::Softmax { axis } => {
2197 let in_id = node.inputs[0];
2198 let in_shape = graph.node(in_id).shape.dims();
2199 let last = (in_shape.len() - 1) as i32;
2200 if *axis != -1 && *axis != last {
2201 panic!("rlx-wgpu Softmax: only last-axis wired (got axis={axis})");
2202 }
2203 let inner = in_shape[in_shape.len() - 1].unwrap_static() as u32;
2204 let total: u32 = in_shape.iter().map(|d| d.unwrap_static() as u32).product();
2205 let outer = total / inner.max(1);
2206 let sm_ids = [node.id, in_id];
2207 let max_binding = dev.device.limits().max_storage_buffer_binding_size;
2208 let sm_fits = arena_span_bytes(&arena, &sm_ids) <= max_binding;
2209 let mut scratch = arena.scratch_off as u64;
2210 let (mut base, mut size, param_anchor) = arena_multi_op_window(
2211 &dev.device,
2212 &arena,
2213 &graph,
2214 ¶m_offsets,
2215 &mut schedule,
2216 &mut scratch,
2217 &sm_ids,
2218 );
2219 if !sm_fits && !param_anchor {
2220 base = arena_bind_window_covering_scratch_if_needed(
2221 &arena, base, size, scratch,
2222 );
2223 }
2224 let in_off = arena_off_in_bind_window(
2225 &graph,
2226 ¶m_offsets,
2227 &dev.device,
2228 &arena,
2229 &mut schedule,
2230 &mut scratch,
2231 in_id,
2232 &mut base,
2233 &mut size,
2234 );
2235 let p = SoftmaxParams {
2236 outer,
2237 inner,
2238 in_off,
2239 out_off: arena_local_off_f32(&arena, node.id, base),
2240 _p0: 0,
2241 _p1: 0,
2242 _p2: 0,
2243 _p3: 0,
2244 };
2245 schedule.push(Step::Softmax { params: p });
2246 let sk = softmax_kernel(&dev.device);
2247 let u = emit_uniform(std::mem::size_of::<SoftmaxParams>());
2248 let bg = bind_two_buf0_window(&dev.device, sk, &arena.buffer, base, size, &u);
2249 uniforms.push(u);
2250 bind_groups.push(bg);
2251 }
2252
2253 Op::LayerNorm { axis: _, eps } | Op::RmsNorm { axis: _, eps } => {
2254 let in_id = node.inputs[0];
2255 let in_shape = graph.node(in_id).shape.dims();
2256 let inner = in_shape[in_shape.len() - 1].unwrap_static() as u32;
2257 let total: u32 = in_shape.iter().map(|d| d.unwrap_static() as u32).product();
2258 let outer = total / inner.max(1);
2259 let is_layer_norm = matches!(&node.op, Op::LayerNorm { .. });
2260
2261 if is_layer_norm
2268 && let Some(&(h_id, delta_id, gamma_id, beta_id, sum_id)) =
2269 ln_to_tee.get(&node.id)
2270 {
2271 let gamma_is_param =
2272 tensor_is_graph_param(&graph, ¶m_offsets, gamma_id);
2273 let gamma_bytes = arena.len_of(gamma_id) as u64;
2274 let frlt_win: Vec<NodeId> =
2275 if gamma_is_param && gamma_bytes > ARENA_STAGE_CAP {
2276 vec![gamma_id, node.id, h_id, delta_id, beta_id, sum_id]
2277 } else {
2278 vec![node.id, h_id, delta_id, gamma_id, beta_id, sum_id]
2279 };
2280 let mut scratch = arena.scratch_off as u64;
2281 let (mut base, mut size, param_anchor) = arena_multi_op_window(
2282 &dev.device,
2283 &arena,
2284 &graph,
2285 ¶m_offsets,
2286 &mut schedule,
2287 &mut scratch,
2288 &frlt_win,
2289 );
2290 if !param_anchor {
2291 base = arena_bind_window_covering_scratch_if_needed(
2292 &arena, base, size, scratch,
2293 );
2294 }
2295 let in_off = arena_off_in_bind_window(
2296 &graph,
2297 ¶m_offsets,
2298 &dev.device,
2299 &arena,
2300 &mut schedule,
2301 &mut scratch,
2302 h_id,
2303 &mut base,
2304 &mut size,
2305 );
2306 let residual_off = arena_off_in_bind_window(
2307 &graph,
2308 ¶m_offsets,
2309 &dev.device,
2310 &arena,
2311 &mut schedule,
2312 &mut scratch,
2313 delta_id,
2314 &mut base,
2315 &mut size,
2316 );
2317 let sum_off = arena_off_in_bind_window(
2318 &graph,
2319 ¶m_offsets,
2320 &dev.device,
2321 &arena,
2322 &mut schedule,
2323 &mut scratch,
2324 sum_id,
2325 &mut base,
2326 &mut size,
2327 );
2328 let gamma_off = arena_off_in_bind_window(
2329 &graph,
2330 ¶m_offsets,
2331 &dev.device,
2332 &arena,
2333 &mut schedule,
2334 &mut scratch,
2335 gamma_id,
2336 &mut base,
2337 &mut size,
2338 );
2339 let beta_off = arena_off_in_bind_window(
2340 &graph,
2341 ¶m_offsets,
2342 &dev.device,
2343 &arena,
2344 &mut schedule,
2345 &mut scratch,
2346 beta_id,
2347 &mut base,
2348 &mut size,
2349 );
2350 let p = FusedResidualLnTeeParams {
2351 outer,
2352 inner,
2353 in_off,
2354 residual_off,
2355 bias_off: 0, gamma_off,
2357 beta_off,
2358 sum_off,
2359 ln_out_off: arena_local_off_f32(&arena, node.id, base),
2360 eps_bits: eps.to_bits(),
2361 has_bias: 0,
2362 _p0: 0,
2363 };
2364 schedule.push(Step::FusedResidualLnTee { params: p });
2365 let frtk = fused_residual_ln_tee_kernel(&dev.device);
2366 let u = emit_uniform(std::mem::size_of::<FusedResidualLnTeeParams>());
2367 let bg =
2368 bind_two_buf0_window(&dev.device, frtk, &arena.buffer, base, size, &u);
2369 uniforms.push(u);
2370 bind_groups.push(bg);
2371 continue;
2372 }
2373
2374 let gamma_id = node.inputs[1];
2375 let beta_id = if is_layer_norm && node.inputs.len() >= 3 {
2378 node.inputs[2]
2379 } else {
2380 gamma_id
2383 };
2384 let gamma_is_param = tensor_is_graph_param(&graph, ¶m_offsets, gamma_id);
2385 let gamma_bytes = arena.len_of(gamma_id) as u64;
2386 let ln_win: Vec<NodeId> = if gamma_is_param && gamma_bytes > ARENA_STAGE_CAP {
2387 vec![gamma_id, node.id, in_id]
2388 } else {
2389 let mut v = vec![node.id, in_id];
2390 if gamma_is_param {
2391 v.push(gamma_id);
2392 }
2393 if is_layer_norm {
2394 v.push(beta_id);
2395 }
2396 v
2397 };
2398 let max_binding = dev.device.limits().max_storage_buffer_binding_size;
2399 let ln_fits = arena_span_bytes(&arena, &ln_win) <= max_binding;
2400 let mut scratch = arena.scratch_off as u64;
2401 let (mut base, mut size, param_anchor) = arena_multi_op_window(
2402 &dev.device,
2403 &arena,
2404 &graph,
2405 ¶m_offsets,
2406 &mut schedule,
2407 &mut scratch,
2408 &ln_win,
2409 );
2410 if !ln_fits && !param_anchor {
2411 base = arena_bind_window_covering_scratch_if_needed(
2412 &arena, base, size, scratch,
2413 );
2414 }
2415 let in_off = arena_off_in_bind_window(
2416 &graph,
2417 ¶m_offsets,
2418 &dev.device,
2419 &arena,
2420 &mut schedule,
2421 &mut scratch,
2422 in_id,
2423 &mut base,
2424 &mut size,
2425 );
2426 let gamma_off = arena_off_in_bind_window(
2427 &graph,
2428 ¶m_offsets,
2429 &dev.device,
2430 &arena,
2431 &mut schedule,
2432 &mut scratch,
2433 gamma_id,
2434 &mut base,
2435 &mut size,
2436 );
2437 let beta_off = arena_off_in_bind_window(
2438 &graph,
2439 ¶m_offsets,
2440 &dev.device,
2441 &arena,
2442 &mut schedule,
2443 &mut scratch,
2444 beta_id,
2445 &mut base,
2446 &mut size,
2447 );
2448 let p = LayerNormParams {
2449 outer,
2450 inner,
2451 in_off,
2452 out_off: arena_local_off_f32(&arena, node.id, base),
2453 gamma_off,
2454 beta_off,
2455 eps_bits: eps.to_bits(),
2456 op: if is_layer_norm { 0 } else { 1 },
2457 };
2458 schedule.push(Step::LayerNorm { params: p });
2459 let lk = layernorm_kernel(&dev.device);
2460 let u = emit_uniform(std::mem::size_of::<LayerNormParams>());
2461 let bg = bind_two_buf0_window(&dev.device, lk, &arena.buffer, base, size, &u);
2462 uniforms.push(u);
2463 bind_groups.push(bg);
2464 }
2465
                // Reshape and Cast push no step, uniform or bind group: the arm is empty.
                // NOTE(review): presumably these nodes share storage with their input in the
                // arena — confirm against the arena planner.
                Op::Reshape { .. } | Op::Cast { .. } => {
                }
2469
2470 Op::Transpose { perm } => {
2471 let in_id = node.inputs[0];
2472 let in_shape = graph.node(in_id).shape.dims();
2473 let out_shape = node.shape.dims();
2474 let rank = perm.len();
2475 if rank != in_shape.len() || rank != out_shape.len() {
2476 panic!("rlx-wgpu Transpose: rank mismatch");
2477 }
2478 let in_dims: Vec<u32> =
2479 in_shape.iter().map(|d| d.unwrap_static() as u32).collect();
2480 let out_dims: Vec<u32> =
2481 out_shape.iter().map(|d| d.unwrap_static() as u32).collect();
2482 let mut in_strides = vec![1u32; rank];
2484 for i in (0..rank.saturating_sub(1)).rev() {
2485 in_strides[i] = in_strides[i + 1] * in_dims[i + 1];
2486 }
2487 let strides_for_out: Vec<u32> =
2490 (0..rank).map(|i| in_strides[perm[i]]).collect();
2491
2492 let mut meta_data: Vec<u32> = Vec::with_capacity(rank * 2);
2494 meta_data.extend_from_slice(&out_dims);
2495 meta_data.extend_from_slice(&strides_for_out);
2496 let meta_buf = dev.device.create_buffer(&wgpu::BufferDescriptor {
2497 label: Some("rlx-wgpu transpose meta"),
2498 size: (meta_data.len() * 4).max(4) as u64,
2499 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
2500 mapped_at_creation: false,
2501 });
2502 dev.queue
2503 .write_buffer(&meta_buf, 0, bytemuck::cast_slice(&meta_data));
2504 let meta_idx = meta_buffers.len();
2505 meta_buffers.push(meta_buf);
2506
2507 let bucket_outermost = if perm[0] == 0 { 1u32 } else { 0u32 };
2511 let tr_ids = [node.id, in_id];
2512 let max_binding = dev.device.limits().max_storage_buffer_binding_size;
2513 let in_is_param = tensor_is_graph_param(&graph, ¶m_offsets, in_id);
2514 let in_bytes = arena.len_of(in_id) as u64;
2515 let (mut base, mut size) = if in_is_param && in_bytes <= max_binding {
2516 arena_window_for_nodes(&dev.device, &arena, &[in_id])
2517 } else if arena_span_bytes(&arena, &tr_ids) <= max_binding {
2518 arena_window_for_nodes(&dev.device, &arena, &tr_ids)
2519 } else {
2520 arena_window_for_nodes(&dev.device, &arena, &[node.id])
2521 };
2522 let mut scratch = arena.scratch_off as u64;
2523 let in_off = arena_off_in_bind_window(
2524 &graph,
2525 ¶m_offsets,
2526 &dev.device,
2527 &arena,
2528 &mut schedule,
2529 &mut scratch,
2530 in_id,
2531 &mut base,
2532 &mut size,
2533 );
2534 let out_off = arena_off_in_bind_window(
2535 &graph,
2536 ¶m_offsets,
2537 &dev.device,
2538 &arena,
2539 &mut schedule,
2540 &mut scratch,
2541 node.id,
2542 &mut base,
2543 &mut size,
2544 );
2545 let p = TransposeParams {
2546 rank: rank as u32,
2547 out_total: elems,
2548 in_off,
2549 out_off,
2550 bucket_outermost,
2551 out_dim_0: out_dims[0],
2552 _p2: 0,
2553 _p3: 0,
2554 };
2555 schedule.push(Step::Transpose {
2556 params: p,
2557 meta_idx,
2558 });
2559 let tk = transpose_kernel(&dev.device);
2560 let u = emit_uniform(std::mem::size_of::<TransposeParams>());
2561 let bg = dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
2562 label: Some("rlx-wgpu transpose bg"),
2563 layout: &tk.bgl,
2564 entries: &[
2565 wgpu::BindGroupEntry {
2566 binding: 0,
2567 resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
2568 buffer: &arena.buffer,
2569 offset: base,
2570 size: NonZeroU64::new(size),
2571 }),
2572 },
2573 wgpu::BindGroupEntry {
2574 binding: 1,
2575 resource: u.as_entire_binding(),
2576 },
2577 wgpu::BindGroupEntry {
2578 binding: 2,
2579 resource: meta_buffers[meta_idx].as_entire_binding(),
2580 },
2581 ],
2582 });
2583 uniforms.push(u);
2584 bind_groups.push(bg);
2585 }
2586
2587 Op::Narrow { axis, start, len } => {
2588 if qkv_skip_narrows.contains(&node.id)
2593 || packed_bshd_skip_narrows.contains(&node.id)
2594 {
2595 continue;
2596 }
2597 let in_id = node.inputs[0];
2598 let in_shape = graph.node(in_id).shape.dims();
2599 let outer: u32 = in_shape[..*axis]
2600 .iter()
2601 .map(|d| d.unwrap_static() as u32)
2602 .product::<u32>()
2603 .max(1);
2604 let inner: u32 = in_shape[*axis + 1..]
2605 .iter()
2606 .map(|d| d.unwrap_static() as u32)
2607 .product::<u32>()
2608 .max(1);
2609 let axis_in = in_shape[*axis].unwrap_static() as u32;
2610 let win_ids = [node.id, in_id];
2611 let max_binding = dev.device.limits().max_storage_buffer_binding_size;
2612 let fits = arena_span_bytes(&arena, &win_ids) <= max_binding;
2613 let mut scratch = arena.scratch_off as u64;
2614 let (mut base, mut size, param_anchor) = arena_multi_op_window(
2615 &dev.device,
2616 &arena,
2617 &graph,
2618 ¶m_offsets,
2619 &mut schedule,
2620 &mut scratch,
2621 &win_ids,
2622 );
2623 if !fits && !param_anchor {
2624 base = arena_bind_window_covering_scratch_if_needed(
2625 &arena, base, size, scratch,
2626 );
2627 }
2628 let in_off = arena_off_in_bind_window(
2629 &graph,
2630 ¶m_offsets,
2631 &dev.device,
2632 &arena,
2633 &mut schedule,
2634 &mut scratch,
2635 in_id,
2636 &mut base,
2637 &mut size,
2638 );
2639 let out_off = arena_off_in_bind_window(
2640 &graph,
2641 ¶m_offsets,
2642 &dev.device,
2643 &arena,
2644 &mut schedule,
2645 &mut scratch,
2646 node.id,
2647 &mut base,
2648 &mut size,
2649 );
2650 let p = NarrowConcatParams {
2651 total: elems,
2652 outer,
2653 inner,
2654 axis_in_size: axis_in,
2655 axis_out_size: *len as u32,
2656 start: *start as u32,
2657 in_off,
2658 out_off,
2659 };
2660 schedule.push(Step::Narrow { params: p });
2661 let nk = narrow_kernel(&dev.device);
2662 let u = emit_uniform(std::mem::size_of::<NarrowConcatParams>());
2663 let bg = bind_two_buf0_window(&dev.device, nk, &arena.buffer, base, size, &u);
2664 uniforms.push(u);
2665 bind_groups.push(bg);
2666 }
2667
2668 Op::Concat { axis } => {
2669 let out_shape = node.shape.dims();
2670 let outer: u32 = out_shape[..*axis]
2671 .iter()
2672 .map(|d| d.unwrap_static() as u32)
2673 .product::<u32>()
2674 .max(1);
2675 let inner: u32 = out_shape[*axis + 1..]
2676 .iter()
2677 .map(|d| d.unwrap_static() as u32)
2678 .product::<u32>()
2679 .max(1);
2680 let axis_out = out_shape[*axis].unwrap_static() as u32;
2681
2682 let all_ids: Vec<NodeId> = std::iter::once(node.id)
2683 .chain(node.inputs.iter().copied())
2684 .collect();
2685 let max_binding = dev.device.limits().max_storage_buffer_binding_size;
2686 let fits_all = arena_span_bytes(&arena, &all_ids) <= max_binding;
2687 let mut scratch = arena.scratch_off as u64;
2688 let (mut base, mut size, param_anchor) = arena_multi_op_window(
2689 &dev.device,
2690 &arena,
2691 &graph,
2692 ¶m_offsets,
2693 &mut schedule,
2694 &mut scratch,
2695 &all_ids,
2696 );
2697 arena_expand_bind_window(&arena, &all_ids, &mut base, &mut size, max_binding);
2698 if !fits_all && !param_anchor {
2699 base = arena_bind_window_covering_scratch_if_needed(
2700 &arena, base, size, scratch,
2701 );
2702 }
2703 let out_off = arena_local_off_f32(&arena, node.id, base);
2704
2705 let mut start_pos: u32 = 0;
2706 for &in_id in &node.inputs {
2707 let in_shape = graph.node(in_id).shape.dims();
2708 let axis_in = in_shape[*axis].unwrap_static() as u32;
2709 let in_total: u32 =
2710 in_shape.iter().map(|d| d.unwrap_static() as u32).product();
2711 let _win_ids = [node.id, in_id];
2712 let in_off = arena_off_in_bind_window(
2713 &graph,
2714 ¶m_offsets,
2715 &dev.device,
2716 &arena,
2717 &mut schedule,
2718 &mut scratch,
2719 in_id,
2720 &mut base,
2721 &mut size,
2722 );
2723 let p = NarrowConcatParams {
2724 total: in_total,
2725 outer,
2726 inner,
2727 axis_in_size: axis_in,
2728 axis_out_size: axis_out,
2729 start: start_pos,
2730 in_off,
2731 out_off,
2732 };
2733 schedule.push(Step::Concat { params: p });
2734 let cck = concat_kernel(&dev.device);
2735 let u = emit_uniform(std::mem::size_of::<NarrowConcatParams>());
2736 let bg =
2737 bind_two_buf0_window(&dev.device, cck, &arena.buffer, base, size, &u);
2738 uniforms.push(u);
2739 bind_groups.push(bg);
2740 start_pos += axis_in;
2741 }
2742 }
2743
// Attention over Q, K, V (inputs 0, 1, 2). Input 3 is read as the mask only
// when the mask kind is Custom or Bias.
//
// The arm derives (batch, heads, seq_q, seq_k), the mask strides and the
// per-tensor strides, chooses an arena bind window, resolves window-relative
// offsets, and finally pushes `Step::Attention` (plus a `Step::BufferCopy`
// when the output had to be redirected to scratch), a uniform and a bind group.
//
// NOTE(review): `score_scale` and `attn_logit_softcap` are bound to `_` and the
// scale below is always 1/sqrt(head_dim) — confirm those fields are honoured
// elsewhere or intentionally unsupported on this backend.
2744 Op::Attention {
2745 num_heads,
2746 head_dim,
2747 mask_kind,
2748 score_scale: _,
2749 attn_logit_softcap: _,
2750 } => {
2751 let q_id = node.inputs[0];
2754 let k_id = node.inputs[1];
2755 let v_id = node.inputs[2];
2756 let q_shape = graph.node(q_id).shape.dims();
2757 let k_shape = graph.node(k_id).shape.dims();
2758 let h = *num_heads as u32;
2764 let hd = *head_dim as u32;
2765 let q_ir = graph.node(q_id).shape.clone();
2766 let k_ir = graph.node(k_id).shape.clone();
2767 let geom = rlx_ir::attention_geom(&q_ir, &k_ir, *num_heads, *head_dim);
2768 let bhsd = geom.bhsd;
// Rank 4 takes the geometry from `attention_geom`. Rank 3 is resolved here:
// if the last dim equals H*D the leading dim is the batch, otherwise the
// leading dim must be divisible by H and is split into batch * heads.
2769 let (batch, heads, seq_q, seq_k) = match q_shape.len() {
2770 4 => (
2771 geom.batch as u32,
2772 geom.heads as u32,
2773 geom.seq_q as u32,
2774 geom.seq_k as u32,
2775 ),
2776 3 => {
2777 let last = q_shape[2].unwrap_static() as u32;
2784 if last == h * hd {
2785 (
2787 q_shape[0].unwrap_static() as u32,
2788 h,
2789 q_shape[1].unwrap_static() as u32,
2790 k_shape[1].unwrap_static() as u32,
2791 )
2792 } else {
2793 let leading = q_shape[0].unwrap_static() as u32;
2795 if !leading.is_multiple_of(h) {
2796 panic!(
2797 "rlx-wgpu Attention: rank-3 leading dim {leading} \
2798 not divisible by num_heads {h} (and last dim \
2799 {last} ≠ H·D = {})",
2800 h * hd
2801 );
2802 }
2803 (
2804 leading / h,
2805 h,
2806 q_shape[1].unwrap_static() as u32,
2807 k_shape[1].unwrap_static() as u32,
2808 )
2809 }
2810 }
2811 other => panic!(
2812 "rlx-wgpu Attention: only rank-3 / rank-4 Q,K,V \
2813 inputs supported (got rank {other})"
2814 ),
2815 };
2816 let scale = 1.0_f32 / (hd as f32).sqrt();
2817
// Mask kind ids written to the uniform: 0 = none, 1 = causal, 2 = Custom or
// Bias (both share id 2 here), 3 = sliding window. `mask_buf` is always None.
2818 let (mask_kind_id, mask_buf, window) = match mask_kind {
2819 MaskKind::None => (0u32, None, 0u32),
2820 MaskKind::Causal => (1u32, None, 0u32),
2821 MaskKind::Custom | MaskKind::Bias => (2u32, None, 0u32),
2822 MaskKind::SlidingWindow(w) => (3u32, None, *w as u32),
2823 };
2824
// Element strides used to index the mask per (batch, head, query, key).
2825 struct MStrides {
2832 b: u32,
2833 h: u32,
2834 q: u32,
2835 k: u32,
2836 }
// For an explicit mask the strides follow the mask tensor's rank; a stride of
// 0 means that index does not advance the mask position. Any other rank, and
// the no-mask case, use the dense [heads, seq_q, seq_k] strides per batch.
2837 let mask_strides = if mask_kind_id == 2u32 {
2838 let m_dims = graph.node(node.inputs[3]).shape.dims();
2839 let dim = |i: usize| m_dims[i].unwrap_static() as u32;
2840 match m_dims.len() {
2841 2 => MStrides {
2842 b: dim(1),
2843 h: 0,
2844 q: 0,
2845 k: 1,
2846 },
2847 3 => MStrides {
2848 b: dim(1) * dim(2),
2849 h: 0,
2850 q: dim(2),
2851 k: 1,
2852 },
2853 4 => MStrides {
2854 b: dim(1) * dim(2) * dim(3),
2855 h: dim(2) * dim(3),
2856 q: dim(3),
2857 k: 1,
2858 },
2859 _ => MStrides {
2860 b: heads * seq_q * seq_k,
2861 h: seq_q * seq_k,
2862 q: seq_k,
2863 k: 1,
2864 },
2865 }
2866 } else {
2867 MStrides {
2868 b: heads * seq_q * seq_k,
2869 h: seq_q * seq_k,
2870 q: seq_k,
2871 k: 1,
2872 }
2873 };
2874
// (batch, head, seq) strides for a tensor of the given shape.
2875 let stride = |shape: &[rlx_ir::shape::Dim], seq_extent: u32| {
2876 rlx_ir::strides_for_shape(shape, heads, hd, seq_extent, bhsd)
2877 };
// When this node has an entry in `packed_bshd_attn`, Q, K and V all use the
// same packed strides and are addressed inside the recorded parent tensor.
2878 let packed_parent = packed_bshd_attn.get(&node.id).copied();
2879 let (q_b, q_h, q_s, k_b, k_h, k_s, v_b, v_h, v_s) =
2880 if let Some((_parent, head_width)) = packed_parent {
2881 let (batch_stride, head_stride, pack_seq) =
2882 rlx_ir::packed_bshd_qkv_strides(head_width as usize, hd, seq_q);
2883 (
2884 batch_stride,
2885 head_stride,
2886 pack_seq,
2887 batch_stride,
2888 head_stride,
2889 pack_seq,
2890 batch_stride,
2891 head_stride,
2892 pack_seq,
2893 )
2894 } else {
2895 let (qb, qh, qs) = stride(q_shape, seq_q);
2896 let (kb, kh, ks) = stride(k_shape, seq_k);
2897 let v_shape = graph.node(v_id).shape.dims();
2898 let (vb, vh, vs) = stride(v_shape, seq_k);
2899 (qb, qh, qs, kb, kh, ks, vb, vh, vs)
2900 };
2901 let out_shape = node.shape.dims();
2902 let (o_b, o_h, o_s) = stride(out_shape, seq_q);
// Tensors the bind window is built around: output plus either the packed
// parent or Q/K/V, plus the mask when one is read.
2903 let mut attn_ids = if let Some((parent, _)) = packed_parent {
2904 vec![node.id, parent]
2905 } else {
2906 vec![node.id, q_id, k_id, v_id]
2907 };
2908 if mask_kind_id == 2 {
2909 attn_ids.push(node.inputs[3]);
2910 }
2911 let mut scratch = arena.scratch_off as u64;
2912 let (mut base, mut size, param_anchor) = arena_multi_op_window(
2913 &dev.device,
2914 &arena,
2915 &graph,
2916 ¶m_offsets,
2917 &mut schedule,
2918 &mut scratch,
2919 &attn_ids,
2920 );
2921 if !param_anchor {
2922 base = arena_bind_window_covering_scratch_if_needed(
2923 &arena, base, size, scratch,
2924 );
2925 }
// Packed case: K and V sit `head_width` and `2 * head_width` elements after Q
// inside the parent. Otherwise each tensor is resolved on its own.
2926 let (q_off, k_off, v_off) = if let Some((parent, head_width)) = packed_parent {
2927 let parent_off = arena_off_in_bind_window(
2928 &graph,
2929 ¶m_offsets,
2930 &dev.device,
2931 &arena,
2932 &mut schedule,
2933 &mut scratch,
2934 parent,
2935 &mut base,
2936 &mut size,
2937 );
2938 (
2939 parent_off,
2940 parent_off.saturating_add(head_width),
2941 parent_off.saturating_add(head_width * 2),
2942 )
2943 } else {
2944 let q_off = arena_off_in_bind_window(
2945 &graph,
2946 ¶m_offsets,
2947 &dev.device,
2948 &arena,
2949 &mut schedule,
2950 &mut scratch,
2951 q_id,
2952 &mut base,
2953 &mut size,
2954 );
2955 let k_off = arena_off_in_bind_window(
2956 &graph,
2957 ¶m_offsets,
2958 &dev.device,
2959 &arena,
2960 &mut schedule,
2961 &mut scratch,
2962 k_id,
2963 &mut base,
2964 &mut size,
2965 );
2966 let v_off = arena_off_in_bind_window(
2967 &graph,
2968 ¶m_offsets,
2969 &dev.device,
2970 &arena,
2971 &mut schedule,
2972 &mut scratch,
2973 v_id,
2974 &mut base,
2975 &mut size,
2976 );
2977 (q_off, k_off, v_off)
2978 };
2979 let out_byte = arena.offset(node.id) as u64;
2980 let out_len = arena.len_of(node.id) as u64;
// True when the output's arena range overlaps Q, K, V or the packed parent.
2981 let out_aliases_qkv = arena_tensors_overlap(&arena, node.id, q_id)
2982 || arena_tensors_overlap(&arena, node.id, k_id)
2983 || arena_tensors_overlap(&arena, node.id, v_id)
2984 || packed_parent.is_some_and(|(parent, _)| {
2985 arena_tensors_overlap(&arena, node.id, parent)
2986 });
2987 let mut kernel_out_off = arena_off_in_bind_window(
2988 &graph,
2989 ¶m_offsets,
2990 &dev.device,
2991 &arena,
2992 &mut schedule,
2993 &mut scratch,
2994 node.id,
2995 &mut base,
2996 &mut size,
2997 );
2998 let mut attn_scratch_copy: Option<(u64, u32)> = None;
// Optional diagnostics, enabled by the RLX_WGPU_DEBUG_ATTN_ALIAS flag.
2999 if out_aliases_qkv && rlx_ir::env::flag("RLX_WGPU_DEBUG_ATTN_ALIAS") {
3000 eprintln!(
3001 "rlx-wgpu Attention alias: out={:?}@{}+{} q={:?}@{} k={:?}@{} v={:?}@{}",
3002 node.id,
3003 out_byte,
3004 out_len,
3005 q_id,
3006 arena.offset(q_id),
3007 k_id,
3008 arena.offset(k_id),
3009 v_id,
3010 arena.offset(v_id),
3011 );
3012 }
// On overlap the kernel writes into scratch instead: reserve `out_len` bytes
// rounded up to a multiple of 256 at the current scratch cursor, point the
// kernel output there, and remember to copy the result back afterwards.
3013 if out_aliases_qkv {
3014 let tmp_byte = scratch;
3015 let tmp_aligned = out_len.div_ceil(256) * 256;
3016 scratch = scratch.saturating_add(tmp_aligned);
3017 if param_anchor {
3018 arena_ensure_scratch_in_window(&mut scratch, base, size);
3019 } else {
3020 base = arena_bind_window_covering_scratch_if_needed(
3021 &arena, base, size, scratch,
3022 );
3023 }
3024 kernel_out_off = ((tmp_byte.saturating_sub(base)) / 4) as u32;
3025 attn_scratch_copy = Some((tmp_byte, out_len as u32));
3026 }
3027 let mask_off = if mask_kind_id == 2 {
3028 arena_off_in_bind_window(
3029 &graph,
3030 ¶m_offsets,
3031 &dev.device,
3032 &arena,
3033 &mut schedule,
3034 &mut scratch,
3035 node.inputs[3],
3036 &mut base,
3037 &mut size,
3038 )
3039 } else {
3040 0
3041 };
// `seq_q_stride` / `seq_k_stride` carry the mask's query/key strides; the
// scale is passed as raw f32 bits.
3042 let p = AttentionParams {
3043 batch,
3044 heads,
3045 seq_q,
3046 seq_k,
3047 head_dim: hd,
3048 q_off,
3049 k_off,
3050 v_off,
3051 out_off: kernel_out_off,
3052 mask_off,
3053 mask_kind: mask_kind_id,
3054 scale_bits: scale.to_bits(),
3055 window,
3056 seq_q_stride: mask_strides.q,
3065 seq_k_stride: mask_strides.k,
3066 mask_batch_stride: mask_strides.b,
3067 mask_head_stride: mask_strides.h,
3068 _pad_mask_0: 0,
3069 _pad_mask_1: 0,
3070 _pad_mask_2: 0,
3071 q_batch_stride: q_b,
3072 q_head_stride: q_h,
3073 q_seq_stride: q_s,
3074 _pad_q: 0,
3075 k_batch_stride: k_b,
3076 k_head_stride: k_h,
3077 k_seq_stride: k_s,
3078 _pad_k: 0,
3079 v_batch_stride: v_b,
3080 v_head_stride: v_h,
3081 v_seq_stride: v_s,
3082 _pad_v: 0,
3083 o_batch_stride: o_b,
3084 o_head_stride: o_h,
3085 o_seq_stride: o_s,
3086 _pad_o: 0,
3087 };
// No-op: `num_heads` is already read above.
3088 let _ = num_heads;
3089 schedule.push(Step::Attention {
3090 params: p,
3091 mask_buf,
3092 });
// Copy the scratch result back into the node's real arena slot.
3093 if let Some((tmp_byte, bytes)) = attn_scratch_copy {
3094 schedule.push(Step::BufferCopy {
3095 src_byte_off: tmp_byte as u32,
3096 dst_byte_off: out_byte as u32,
3097 bytes,
3098 });
3099 }
3100 let ak = attention_kernel(&dev.device);
3101 let u = emit_uniform(std::mem::size_of::<AttentionParams>());
3102 let bg = bind_two_buf0_window(&dev.device, ak, &arena.buffer, base, size, &u);
3103 uniforms.push(u);
3104 bind_groups.push(bg);
3105 }
3106
3107 Op::AttentionBackward {
3108 num_heads,
3109 head_dim,
3110 mask_kind,
3111 wrt,
3112 } => {
3113 use rlx_ir::op::AttentionBwdWrt;
3114 let q_id = node.inputs[0];
3115 let k_id = node.inputs[1];
3116 let v_id = node.inputs[2];
3117 let dy_id = node.inputs[3];
3118 let q_shape = graph.node(q_id).shape.dims();
3119 let k_shape = graph.node(k_id).shape.dims();
3120 let hd = *head_dim as u32;
3121 let q_ir = graph.node(q_id).shape.clone();
3122 let k_ir = graph.node(k_id).shape.clone();
3123 let geom = rlx_ir::attention_geom(&q_ir, &k_ir, *num_heads, *head_dim);
3124 let bhsd = geom.bhsd;
3125 let (batch, heads, seq_q, seq_k) = match q_shape.len() {
3126 4 => (
3127 geom.batch as u32,
3128 geom.heads as u32,
3129 geom.seq_q as u32,
3130 geom.seq_k as u32,
3131 ),
3132 3 => {
3133 let h = q_shape[2].unwrap_static() as u32 / hd;
3134 (
3135 q_shape[0].unwrap_static() as u32 / h,
3136 h,
3137 q_shape[1].unwrap_static() as u32,
3138 k_shape[1].unwrap_static() as u32,
3139 )
3140 }
3141 other => panic!(
3142 "rlx-wgpu AttentionBackward: only rank-3/4 Q,K,V (got rank {other})"
3143 ),
3144 };
3145 let scale = 1.0_f32 / (hd as f32).sqrt();
3146 let (mask_kind_id, mask_off, mask_buf, window) = match mask_kind {
3147 MaskKind::None => (0u32, 0u32, None, 0u32),
3148 MaskKind::Causal => (1u32, 0u32, None, 0u32),
3149 MaskKind::Custom => {
3150 (2u32, (arena.offset(node.inputs[4]) / 4) as u32, None, 0u32)
3151 }
3152 MaskKind::Bias => {
3153 (4u32, (arena.offset(node.inputs[4]) / 4) as u32, None, 0u32)
3154 }
3155 MaskKind::SlidingWindow(w) => (3u32, 0u32, None, *w as u32),
3156 };
3157 struct MStrides {
3158 b: u32,
3159 h: u32,
3160 q: u32,
3161 k: u32,
3162 }
3163 let mask_strides = if mask_kind_id == 2 || mask_kind_id == 4 {
3164 let m_dims = graph.node(node.inputs[4]).shape.dims();
3165 let dim = |i: usize| m_dims[i].unwrap_static() as u32;
3166 match m_dims.len() {
3167 2 => MStrides {
3168 b: dim(1),
3169 h: 0,
3170 q: 0,
3171 k: 1,
3172 },
3173 3 => MStrides {
3174 b: dim(1) * dim(2),
3175 h: 0,
3176 q: dim(2),
3177 k: 1,
3178 },
3179 4 => MStrides {
3180 b: dim(1) * dim(2) * dim(3),
3181 h: dim(2) * dim(3),
3182 q: dim(3),
3183 k: 1,
3184 },
3185 _ => MStrides {
3186 b: heads * seq_q * seq_k,
3187 h: seq_q * seq_k,
3188 q: seq_k,
3189 k: 1,
3190 },
3191 }
3192 } else {
3193 MStrides {
3194 b: heads * seq_q * seq_k,
3195 h: seq_q * seq_k,
3196 q: seq_k,
3197 k: 1,
3198 }
3199 };
3200 let stride = |shape: &[rlx_ir::shape::Dim], seq_extent: u32| {
3201 rlx_ir::strides_for_shape(shape, heads, hd, seq_extent, bhsd)
3202 };
3203 let (q_b, q_h, q_s) = stride(q_shape, seq_q);
3204 let (k_b, k_h, k_s) = stride(k_shape, seq_k);
3205 let v_shape = graph.node(v_id).shape.dims();
3206 let (v_b, v_h, v_s) = stride(v_shape, seq_k);
3207 let out_shape = node.shape.dims();
3208 let out_seq = match wrt {
3209 AttentionBwdWrt::Query => seq_q,
3210 AttentionBwdWrt::Key | AttentionBwdWrt::Value => seq_k,
3211 };
3212 let (o_b, o_h, o_s) = stride(out_shape, out_seq);
3213 let wrt_id = match wrt {
3214 AttentionBwdWrt::Query => 0u32,
3215 AttentionBwdWrt::Key => 1u32,
3216 AttentionBwdWrt::Value => 2u32,
3217 };
3218 let p = AttentionBwdParams {
3219 batch,
3220 heads,
3221 seq_q,
3222 seq_k,
3223 head_dim: hd,
3224 q_off: (arena.offset(q_id) / 4) as u32,
3225 k_off: (arena.offset(k_id) / 4) as u32,
3226 v_off: (arena.offset(v_id) / 4) as u32,
3227 dy_off: (arena.offset(dy_id) / 4) as u32,
3228 out_off: (arena.offset(node.id) / 4) as u32,
3229 mask_off,
3230 mask_kind: mask_kind_id,
3231 scale_bits: scale.to_bits(),
3232 window,
3233 wrt: wrt_id,
3234 seq_q_stride: mask_strides.q,
3235 seq_k_stride: mask_strides.k,
3236 mask_batch_stride: mask_strides.b,
3237 mask_head_stride: mask_strides.h,
3238 _pad_mask_0: 0,
3239 _pad_mask_1: 0,
3240 _pad_mask_2: 0,
3241 q_batch_stride: q_b,
3242 q_head_stride: q_h,
3243 q_seq_stride: q_s,
3244 _pad_q: 0,
3245 k_batch_stride: k_b,
3246 k_head_stride: k_h,
3247 k_seq_stride: k_s,
3248 _pad_k: 0,
3249 v_batch_stride: v_b,
3250 v_head_stride: v_h,
3251 v_seq_stride: v_s,
3252 _pad_v: 0,
3253 o_batch_stride: o_b,
3254 o_head_stride: o_h,
3255 o_seq_stride: o_s,
3256 _pad_o: 0,
3257 };
3258 schedule.push(Step::AttentionBackward {
3259 params: p,
3260 mask_buf,
3261 });
3262 let ak = attention_bwd_kernel(&dev.device);
3263 let u = emit_uniform(std::mem::size_of::<AttentionBwdParams>());
3264 let bg = bind_op_output_window(&dev.device, ak, &arena, node.id, &u);
3265 uniforms.push(u);
3266 bind_groups.push(bg);
3267 }
3268
// Rope: applies the `cos` (input 1) and `sin` (input 2) tables to `x` (input 0)
// and writes the result to this node's arena slot via `Step::Rope`.
//
// Panics when the last dimension of `x` is not a multiple of `head_dim`, or
// when `head_dim` is odd.
// NOTE(review): `n_rot` is bound to `_` and never reaches `RopeParams` —
// confirm partial-rotary graphs are handled before this point.
3269 Op::Rope { head_dim, n_rot: _ } => {
3270 let x_id = node.inputs[0];
3271 let cos_id = node.inputs[1];
3272 let sin_id = node.inputs[2];
3273 let x_shape = graph.node(x_id).shape.dims();
// An empty shape yields `last == 0`, which passes the multiple-of check below.
3274 let last = x_shape.last().map(|d| d.unwrap_static()).unwrap_or(0);
3275 if !last.is_multiple_of(*head_dim) {
3276 panic!(
3277 "rlx-wgpu Rope: last_dim ({last}) must be a multiple \
3278 of head_dim ({head_dim})"
3279 );
3280 }
3281 if head_dim % 2 != 0 {
3282 panic!("rlx-wgpu Rope: head_dim must be even");
3283 }
3284 let total: u32 = x_shape.iter().map(|d| d.unwrap_static() as u32).product();
// The sequence length is read from the second-to-last axis, so `x` needs
// rank >= 2 (the index computation underflows otherwise).
3285 let seq = x_shape[x_shape.len() - 2].unwrap_static() as u32;
// Everything in front of the last two axes is folded into `batch`; the
// divisor is clamped to 1 to avoid dividing by zero.
3286 let batch = total / (seq * last as u32).max(1);
3291 let cos_is_param = tensor_is_graph_param(&graph, ¶m_offsets, cos_id);
3292 let cos_bytes = arena.len_of(cos_id) as u64;
// The id order changes when `cos` is a graph parameter larger than
// ARENA_STAGE_CAP: cos/sin are then listed first.
// NOTE(review): presumably the leading id anchors the bind window — confirm
// in `arena_multi_op_window`.
3293 let rope_win: Vec<NodeId> = if cos_is_param && cos_bytes > ARENA_STAGE_CAP {
3294 vec![cos_id, sin_id, node.id, x_id]
3295 } else {
3296 vec![node.id, x_id, cos_id, sin_id]
3297 };
3298 let mut scratch = arena.scratch_off as u64;
3299 let (mut base, mut size, param_anchor) = arena_multi_op_window(
3300 &dev.device,
3301 &arena,
3302 &graph,
3303 ¶m_offsets,
3304 &mut schedule,
3305 &mut scratch,
3306 &rope_win,
3307 );
3308 if !param_anchor {
3309 base = arena_bind_window_covering_scratch_if_needed(
3310 &arena, base, size, scratch,
3311 );
3312 }
// Window-relative offsets for x, cos and sin, resolved in that order; each
// call receives the window and scratch cursor by mutable reference.
3313 let in_off = arena_off_in_bind_window(
3314 &graph,
3315 ¶m_offsets,
3316 &dev.device,
3317 &arena,
3318 &mut schedule,
3319 &mut scratch,
3320 x_id,
3321 &mut base,
3322 &mut size,
3323 );
3324 let cos_off = arena_off_in_bind_window(
3325 &graph,
3326 ¶m_offsets,
3327 &dev.device,
3328 &arena,
3329 &mut schedule,
3330 &mut scratch,
3331 cos_id,
3332 &mut base,
3333 &mut size,
3334 );
3335 let sin_off = arena_off_in_bind_window(
3336 &graph,
3337 ¶m_offsets,
3338 &dev.device,
3339 &arena,
3340 &mut schedule,
3341 &mut scratch,
3342 sin_id,
3343 &mut base,
3344 &mut size,
3345 );
// `out_off` is taken against the final `base`, after the calls above.
3346 let p = RopeParams {
3347 n_total: total,
3348 seq,
3349 head_dim: *head_dim as u32,
3350 half: (*head_dim / 2) as u32,
3351 in_off,
3352 cos_off,
3353 sin_off,
3354 out_off: arena_local_off_f32(&arena, node.id, base),
3355 last_dim: last as u32,
3356 batch,
3357 seq_stride: seq,
3358 _p2: 0,
3359 };
3360 schedule.push(Step::Rope { params: p });
3361 let rk = rope_kernel(&dev.device);
3362 let u = emit_uniform(std::mem::size_of::<RopeParams>());
3363 let bg = bind_two_buf0_window(&dev.device, rk, &arena.buffer, base, size, &u);
3364 uniforms.push(u);
3365 bind_groups.push(bg);
3366 }
3367
// Expand: broadcasts the single input to `target_shape` via `Step::Expand`.
//
// Binding layout of the emitted bind group:
//   0 = window into the arena buffer, 1 = `ExpandParams` uniform,
//   2 = meta storage buffer holding `[out_dims..., strides_for_out...]`.
//
// Panics when the input has more dimensions than `target_shape`.
3368 Op::Expand { target_shape } => {
3369 let in_id = node.inputs[0];
3370 let in_shape = graph.node(in_id).shape.dims();
3371 let in_rank = in_shape.len();
3372 let rank = target_shape.len();
3373 if in_rank > rank {
3374 panic!(
3375 "rlx-wgpu Expand: rank mismatch \
3376 (in_rank={in_rank}, target_rank={rank})"
3377 );
3378 }
// Left-pad the input shape with 1s up to the target rank.
3379 let pad = rank.saturating_sub(in_rank);
3382 let out_dims: Vec<u32> = target_shape.iter().map(|&d| d as u32).collect();
3383 let in_dims: Vec<u32> = (0..rank)
3384 .map(|i| {
3385 if i < pad {
3386 1
3387 } else {
3388 in_shape[i - pad].unwrap_static() as u32
3389 }
3390 })
3391 .collect();
// Row-major strides of the padded input.
3392 let mut in_strides_row = vec![1u32; rank];
3396 for i in (0..rank.saturating_sub(1)).rev() {
3397 in_strides_row[i] = in_strides_row[i + 1] * in_dims[i + 1];
3398 }
// Axes where the input extent is 1 but the output extent is not get stride 0,
// so every output index along them maps to the same input element.
3399 let strides_for_out: Vec<u32> = (0..rank)
3400 .map(|i| {
3401 if in_dims[i] == 1 && out_dims[i] != 1 {
3402 0
3403 } else {
3404 in_strides_row[i]
3405 }
3406 })
3407 .collect();
3408
// Upload the per-axis metadata: output extents followed by strides. The
// buffer size is clamped to at least 4 bytes.
3409 let mut meta_data: Vec<u32> = Vec::with_capacity(rank * 2);
3410 meta_data.extend_from_slice(&out_dims);
3411 meta_data.extend_from_slice(&strides_for_out);
3412 let meta_buf = dev.device.create_buffer(&wgpu::BufferDescriptor {
3413 label: Some("rlx-wgpu expand meta"),
3414 size: (meta_data.len() * 4).max(4) as u64,
3415 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
3416 mapped_at_creation: false,
3417 });
3418 dev.queue
3419 .write_buffer(&meta_buf, 0, bytemuck::cast_slice(&meta_data));
3420 let meta_idx = meta_buffers.len();
3421 meta_buffers.push(meta_buf);
3422
// 1 when axis 0 has the same extent in input and output, else 0. Indexing
// `[0]` here requires rank >= 1.
3423 let bucket_outermost = if in_dims[0] == out_dims[0] {
3429 1u32
3430 } else {
3431 0u32
3432 };
3433 let exp_ids = [node.id, in_id];
3434 let max_binding = dev.device.limits().max_storage_buffer_binding_size;
3435 let exp_fits = arena_span_bytes(&arena, &exp_ids) <= max_binding;
3436 let mut scratch = arena.scratch_off as u64;
3437 let (mut base, mut size, param_anchor) = arena_multi_op_window(
3438 &dev.device,
3439 &arena,
3440 &graph,
3441 ¶m_offsets,
3442 &mut schedule,
3443 &mut scratch,
3444 &exp_ids,
3445 );
// Only when output + input do not fit one binding and the window is not
// anchored on a parameter is the base re-derived to cover scratch.
3446 if !exp_fits && !param_anchor {
3447 base = arena_bind_window_covering_scratch_if_needed(
3448 &arena, base, size, scratch,
3449 );
3450 }
// Window-relative offsets: input first, then output.
3451 let in_off = arena_off_in_bind_window(
3452 &graph,
3453 ¶m_offsets,
3454 &dev.device,
3455 &arena,
3456 &mut schedule,
3457 &mut scratch,
3458 in_id,
3459 &mut base,
3460 &mut size,
3461 );
3462 let out_off = arena_off_in_bind_window(
3463 &graph,
3464 ¶m_offsets,
3465 &dev.device,
3466 &arena,
3467 &mut schedule,
3468 &mut scratch,
3469 node.id,
3470 &mut base,
3471 &mut size,
3472 );
3473 let p = ExpandParams {
3474 rank: rank as u32,
3475 out_total: elems,
3476 in_off,
3477 out_off,
3478 bucket_outermost,
3479 out_dim_0: out_dims[0],
3480 _p2: 0,
3481 _p3: 0,
3482 };
3483 schedule.push(Step::Expand {
3484 params: p,
3485 meta_idx,
3486 });
3487 let ek = expand_kernel(&dev.device);
3488 let u = emit_uniform(std::mem::size_of::<ExpandParams>());
3489 let bg = dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
3490 label: Some("rlx-wgpu expand bg"),
3491 layout: &ek.bgl,
3492 entries: &[
3493 wgpu::BindGroupEntry {
3494 binding: 0,
3495 resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
3496 buffer: &arena.buffer,
3497 offset: base,
3498 size: NonZeroU64::new(size),
3499 }),
3500 },
3501 wgpu::BindGroupEntry {
3502 binding: 1,
3503 resource: u.as_entire_binding(),
3504 },
3505 wgpu::BindGroupEntry {
3506 binding: 2,
3507 resource: meta_buffers[meta_idx].as_entire_binding(),
3508 },
3509 ],
3510 });
3511 uniforms.push(u);
3512 bind_groups.push(bg);
3513 }
3514
3515 Op::Gather { axis } => {
3516 let table_id = node.inputs[0];
3517 let idx_id = node.inputs[1];
3518 let table_is_param = tensor_is_graph_param(&graph, ¶m_offsets, table_id);
3519 let table_bytes = arena.len_of(table_id) as u64;
3520 let gather_win: Vec<NodeId> = if table_is_param && table_bytes > ARENA_STAGE_CAP
3521 {
3522 vec![table_id, node.id, idx_id]
3523 } else {
3524 vec![node.id, idx_id, table_id]
3525 };
3526 let mut scratch = arena.scratch_off as u64;
3527 let (mut base, mut size, table_anchor) = arena_multi_op_window(
3528 &dev.device,
3529 &arena,
3530 &graph,
3531 ¶m_offsets,
3532 &mut schedule,
3533 &mut scratch,
3534 &gather_win,
3535 );
3536 if !table_anchor {
3537 base = arena_bind_window_covering_scratch_if_needed(
3538 &arena, base, size, scratch,
3539 );
3540 }
3541 let in_off =
3542 if table_anchor && arena_tensor_in_window(&arena, table_id, base, size) {
3543 arena_local_off_f32(&arena, table_id, base)
3544 } else {
3545 arena_off_in_bind_window(
3546 &graph,
3547 ¶m_offsets,
3548 &dev.device,
3549 &arena,
3550 &mut schedule,
3551 &mut scratch,
3552 table_id,
3553 &mut base,
3554 &mut size,
3555 )
3556 };
3557 let idx_off = arena_off_in_bind_window(
3558 &graph,
3559 ¶m_offsets,
3560 &dev.device,
3561 &arena,
3562 &mut schedule,
3563 &mut scratch,
3564 idx_id,
3565 &mut base,
3566 &mut size,
3567 );
3568 let out_off = arena_local_off_f32(&arena, node.id, base);
3569 if *axis == 0 {
3570 let table_shape = graph.node(table_id).shape.dims();
3571 let idx_shape = graph.node(idx_id).shape.dims();
3572 let vocab = table_shape[0].unwrap_static() as u32;
3573 let dim: u32 = table_shape[1..]
3574 .iter()
3575 .map(|d| d.unwrap_static() as u32)
3576 .product::<u32>()
3577 .max(1);
3578 let n_idx: u32 =
3579 idx_shape.iter().map(|d| d.unwrap_static() as u32).product();
3580 let p = GatherParams {
3581 n_out: elems,
3582 n_idx,
3583 dim,
3584 vocab,
3585 in_off,
3586 idx_off,
3587 out_off,
3588 _p0: 0,
3589 };
3590 schedule.push(Step::Gather { params: p });
3591 let gk = gather_kernel(&dev.device);
3592 let u = emit_uniform(std::mem::size_of::<GatherParams>());
3593 let bg =
3594 bind_two_buf0_window(&dev.device, gk, &arena.buffer, base, size, &u);
3595 uniforms.push(u);
3596 bind_groups.push(bg);
3597 } else {
3598 let table_shape = graph.node(table_id).shape.dims();
3599 let idx_shape = graph.node(idx_id).shape.dims();
3600 let outer: u32 = table_shape[..*axis]
3601 .iter()
3602 .map(|d| d.unwrap_static() as u32)
3603 .product::<u32>()
3604 .max(1);
3605 let trailing: u32 = table_shape[*axis + 1..]
3606 .iter()
3607 .map(|d| d.unwrap_static() as u32)
3608 .product::<u32>()
3609 .max(1);
3610 let axis_dim = table_shape[*axis].unwrap_static() as u32;
3611 let num_idx: u32 =
3612 idx_shape.iter().map(|d| d.unwrap_static() as u32).product();
3613 let total = outer * num_idx * trailing;
3614 let p = GatherAxisParams {
3615 total,
3616 outer,
3617 axis_dim,
3618 num_idx,
3619 trailing,
3620 table_off: in_off,
3621 idx_off,
3622 out_off,
3623 };
3624 schedule.push(Step::GatherAxis { params: p });
3625 let gk = gather_axis_kernel(&dev.device);
3626 let u = emit_uniform(std::mem::size_of::<GatherAxisParams>());
3627 let bg =
3628 bind_two_buf0_window(&dev.device, gk, &arena.buffer, base, size, &u);
3629 uniforms.push(u);
3630 bind_groups.push(bg);
3631 }
3632 }
3633
3634 Op::FusedMatMulBiasAct { activation } => {
3635 let a_id = node.inputs[0];
3638 let b_id = node.inputs[1];
3639 let bias_id = node.inputs[2];
3640 let a_shape = graph.node(a_id).shape.dims();
3641 let b_shape = graph.node(b_id).shape.dims();
3642 let out_shape = node.shape.dims();
3643 let (m, k, n) =
3644 if a_shape.len() == 2 && b_shape.len() == 2 && out_shape.len() == 2 {
3645 (
3646 a_shape[0].unwrap_static() as u32,
3647 a_shape[1].unwrap_static() as u32,
3648 b_shape[1].unwrap_static() as u32,
3649 )
3650 } else if a_shape.len() >= 2
3651 && b_shape.len() == 2
3652 && out_shape.len() == a_shape.len()
3653 {
3654 let leading: usize = a_shape[..a_shape.len() - 2]
3655 .iter()
3656 .map(|d| d.unwrap_static())
3657 .product();
3658 let m_inner = a_shape[a_shape.len() - 2].unwrap_static();
3659 let k_inner = a_shape[a_shape.len() - 1].unwrap_static();
3660 let n_inner = b_shape[1].unwrap_static();
3661 ((leading * m_inner) as u32, k_inner as u32, n_inner as u32)
3662 } else {
3663 panic!(
3664 "rlx-wgpu FusedMatMulBiasAct: unsupported shapes \
3665 a={a_shape:?} b={b_shape:?}"
3666 );
3667 };
3668 let act_id = match activation {
3669 None => 0xFFFFu32,
3670 Some(a) => activation_op_id(*a),
3671 };
3672 let b_is_param = tensor_is_graph_param(&graph, ¶m_offsets, b_id);
3673 let b_bytes = arena.len_of(b_id) as u64;
3674 let mut compute_precision = derive_matmul_compute(
3675 &dev.device,
3676 &graph,
3677 &coop_f16_vk_mirror_acts,
3678 a_id,
3679 b_id,
3680 m,
3681 k,
3682 n,
3683 );
3684 if b_is_param && b_bytes > ARENA_STAGE_CAP && arena.param_fits_f16_mirror(b_id)
3685 {
3686 compute_precision = MatmulCompute::F16;
3687 }
3688
3689 let mqk_eligible = act_id == 0xFFFFu32
3693 && matches!(
3694 compute_precision,
3695 MatmulCompute::F32 | MatmulCompute::CoopF32 | MatmulCompute::CoopF16Vk
3696 );
3697 if mqk_eligible && let Some(&(q_id, k_id_n, v_id)) = qkv_split.get(&node.id) {
3698 let head_width = n / 3;
3699 let qkv_kind = match compute_precision {
3700 MatmulCompute::CoopF16Vk => MatmulQkvKind::CoopF16Vk,
3701 MatmulCompute::CoopF32 => MatmulQkvKind::CoopF32,
3702 _ => MatmulQkvKind::F32,
3703 };
3704 let (mut base, mut size, param_anchor) = arena_matmul_bind_window(
3705 &dev.device,
3706 &arena,
3707 &graph,
3708 ¶m_offsets,
3709 q_id,
3710 a_id,
3711 b_id,
3712 );
3713 let mut scratch = arena.scratch_off as u64;
3714 if param_anchor {
3715 arena_ensure_scratch_in_window(&mut scratch, base, size);
3716 }
3717 if b_is_param && b_bytes > ARENA_STAGE_CAP {
3718 assert!(
3719 param_anchor && arena_tensor_in_window(&arena, b_id, base, size),
3720 "rlx-wgpu FusedMatMul QKV: large param B {:?} not in bind window",
3721 b_id,
3722 );
3723 }
3724 let a_off = arena_off_in_bind_window(
3725 &graph,
3726 ¶m_offsets,
3727 &dev.device,
3728 &arena,
3729 &mut schedule,
3730 &mut scratch,
3731 a_id,
3732 &mut base,
3733 &mut size,
3734 );
3735 let q_off = arena_off_in_bind_window(
3736 &graph,
3737 ¶m_offsets,
3738 &dev.device,
3739 &arena,
3740 &mut schedule,
3741 &mut scratch,
3742 q_id,
3743 &mut base,
3744 &mut size,
3745 );
3746 let k_off = arena_off_in_bind_window(
3747 &graph,
3748 ¶m_offsets,
3749 &dev.device,
3750 &arena,
3751 &mut schedule,
3752 &mut scratch,
3753 k_id_n,
3754 &mut base,
3755 &mut size,
3756 );
3757 let v_off = arena_off_in_bind_window(
3758 &graph,
3759 ¶m_offsets,
3760 &dev.device,
3761 &arena,
3762 &mut schedule,
3763 &mut scratch,
3764 v_id,
3765 &mut base,
3766 &mut size,
3767 );
3768 let bias_off = arena_off_in_bind_window(
3769 &graph,
3770 ¶m_offsets,
3771 &dev.device,
3772 &arena,
3773 &mut schedule,
3774 &mut scratch,
3775 bias_id,
3776 &mut base,
3777 &mut size,
3778 );
3779 let b_off_f32 = if b_is_param
3780 && b_bytes > ARENA_STAGE_CAP
3781 && arena_tensor_in_window(&arena, b_id, base, size)
3782 {
3783 arena_local_off_f32(&arena, b_id, base)
3784 } else {
3785 arena_off_in_bind_window(
3786 &graph,
3787 ¶m_offsets,
3788 &dev.device,
3789 &arena,
3790 &mut schedule,
3791 &mut scratch,
3792 b_id,
3793 &mut base,
3794 &mut size,
3795 )
3796 };
3797 let b_off_global = (arena.offset(b_id) / 4) as u32;
3798 maybe_push_coop_f16_vk_casts(
3799 &graph,
3800 a_id,
3801 b_id,
3802 &coop_f16_vk_mirror_acts,
3803 &dev.device,
3804 &arena,
3805 &mut schedule,
3806 &mut uniforms,
3807 &mut bind_groups,
3808 &mm_cast,
3809 compute_precision,
3810 a_off,
3811 m,
3812 k,
3813 1,
3814 if qkv_kind == MatmulQkvKind::CoopF16Vk {
3815 b_off_global
3816 } else {
3817 b_off_f32
3818 },
3819 n,
3820 );
3821 let p = MatmulQkvParams {
3822 m,
3823 k,
3824 n,
3825 a_off,
3826 b_off: if qkv_kind == MatmulQkvKind::CoopF16Vk {
3827 b_off_global
3828 } else {
3829 b_off_f32
3830 },
3831 q_off,
3832 k_off,
3833 v_off,
3834 head_width,
3835 has_bias: 1,
3836 bias_off,
3837 _p0: 0,
3838 _p1: 0,
3839 _p2: 0,
3840 _p3: 0,
3841 _p4: 0,
3842 };
3843 schedule.push(Step::MatmulQkv {
3844 params: p,
3845 kind: qkv_kind,
3846 });
3847 register_coop_f16_vk_b_param(
3848 &mut coop_f16_b_param,
3849 ¶m_offsets,
3850 b_id,
3851 p.b_off,
3852 match qkv_kind {
3853 MatmulQkvKind::CoopF16Vk => MatmulCompute::CoopF16Vk,
3854 MatmulQkvKind::CoopF32 => MatmulCompute::CoopF32,
3855 MatmulQkvKind::F32 => MatmulCompute::F32,
3856 },
3857 );
3858 let u = emit_uniform(std::mem::size_of::<MatmulQkvParams>());
3859 let bg = match qkv_kind {
3860 MatmulQkvKind::CoopF16Vk => {
3861 let mqk = matmul_qkv_coop_f16_vk_kernel(&dev.device).expect(
3862 "coop f16 matmul_qkv kernel: feature was checked but missing",
3863 );
3864 let (bg, b_off_adj) = build_matmul_qkv_coop_f16_vk_bind_group(
3865 &dev.device,
3866 mqk,
3867 &arena,
3868 base,
3869 size,
3870 &u,
3871 k,
3872 n,
3873 p.b_off,
3874 );
3875 if let Some(Step::MatmulQkv { params, .. }) = schedule.last_mut() {
3876 params.b_off = b_off_adj;
3877 }
3878 bg
3879 }
3880 MatmulQkvKind::CoopF32 => bind_two_buf0_window(
3881 &dev.device,
3882 matmul_qkv_coop_f32_kernel(&dev.device).expect(
3883 "coop matmul_qkv kernel: hardware feature was checked but kernel missing",
3884 ),
3885 &arena.buffer,
3886 base,
3887 size,
3888 &u,
3889 ),
3890 MatmulQkvKind::F32 => bind_two_buf0_window(
3891 &dev.device,
3892 matmul_qkv_kernel(&dev.device),
3893 &arena.buffer,
3894 base,
3895 size,
3896 &u,
3897 ),
3898 };
3899 uniforms.push(u);
3900 bind_groups.push(bg);
3901 if qkv_kind == MatmulQkvKind::CoopF16Vk {
3902 coop_f16_vk_wide_bind_groups.insert(
3903 schedule.len() - 1,
3904 bind_two_buf0_window(
3905 &dev.device,
3906 matmul_qkv_kernel(&dev.device),
3907 &arena.buffer,
3908 base,
3909 size,
3910 &uniforms[uniforms.len() - 1],
3911 ),
3912 );
3913 }
3914 } else {
3915 let (mut base, mut size, param_anchor) = arena_matmul_bind_window(
3916 &dev.device,
3917 &arena,
3918 &graph,
3919 ¶m_offsets,
3920 node.id,
3921 a_id,
3922 b_id,
3923 );
3924 let mut scratch = arena.scratch_off as u64;
3925 if param_anchor {
3926 arena_ensure_scratch_in_window(&mut scratch, base, size);
3927 }
3928 if b_is_param && b_bytes > ARENA_STAGE_CAP {
3929 assert!(
3930 param_anchor && arena_tensor_in_window(&arena, b_id, base, size),
3931 "rlx-wgpu FusedMatMul: large param B {:?} not in bind window",
3932 b_id,
3933 );
3934 }
3935 let a_off_f32 = arena_off_in_bind_window(
3936 &graph,
3937 ¶m_offsets,
3938 &dev.device,
3939 &arena,
3940 &mut schedule,
3941 &mut scratch,
3942 a_id,
3943 &mut base,
3944 &mut size,
3945 );
3946 let b_off_f32 = if b_is_param
3947 && b_bytes > ARENA_STAGE_CAP
3948 && arena_tensor_in_window(&arena, b_id, base, size)
3949 {
3950 arena_local_off_f32(&arena, b_id, base)
3951 } else {
3952 arena_off_in_bind_window(
3953 &graph,
3954 ¶m_offsets,
3955 &dev.device,
3956 &arena,
3957 &mut schedule,
3958 &mut scratch,
3959 b_id,
3960 &mut base,
3961 &mut size,
3962 )
3963 };
3964 let bias_off_f32 = arena_off_in_bind_window(
3965 &graph,
3966 ¶m_offsets,
3967 &dev.device,
3968 &arena,
3969 &mut schedule,
3970 &mut scratch,
3971 bias_id,
3972 &mut base,
3973 &mut size,
3974 );
3975 let b_off_global = (arena.offset(b_id) / 4) as u32;
3976 let b_off_bind = if b_is_param
3977 && matches!(
3978 compute_precision,
3979 MatmulCompute::Coop16
3980 | MatmulCompute::CoopF16Vk
3981 | MatmulCompute::F16
3982 ) {
3983 b_off_global
3984 } else {
3985 b_off_f32
3986 };
3987 maybe_push_coop_f16_vk_casts(
3988 &graph,
3989 a_id,
3990 b_id,
3991 &coop_f16_vk_mirror_acts,
3992 &dev.device,
3993 &arena,
3994 &mut schedule,
3995 &mut uniforms,
3996 &mut bind_groups,
3997 &mm_cast,
3998 compute_precision,
3999 a_off_f32,
4000 m,
4001 k,
4002 1,
4003 b_off_bind,
4004 n,
4005 );
4006 schedule.push(Step::Matmul {
4007 m,
4008 k,
4009 n,
4010 batch: 1,
4011 a_batch_stride: 0,
4012 b_batch_stride: 0,
4013 c_batch_stride: 0,
4014 a_off_f32,
4015 b_off_f32,
4016 c_off_f32: arena_local_off_f32(&arena, node.id, base),
4017 has_bias: 1,
4018 bias_off_f32,
4019 act_id,
4020 b_is_param,
4021 compute_precision,
4022 });
4023 register_coop_f16_vk_b_param(
4024 &mut coop_f16_b_param,
4025 ¶m_offsets,
4026 b_id,
4027 b_off_bind,
4028 compute_precision,
4029 );
4030 let u = emit_uniform(std::mem::size_of::<MatmulParams>());
4031 let (bg, b_off_adj) = build_matmul_bind_group(
4032 &dev.device,
4033 mm_k,
4034 mm_w,
4035 &mm_f16w,
4036 &mm_f16c,
4037 &mm_coop,
4038 &mm_coop_f32,
4039 &arena,
4040 base,
4041 size,
4042 &u,
4043 b_is_param,
4044 compute_precision,
4045 k,
4046 n,
4047 1,
4048 b_off_bind,
4049 0,
4050 );
4051 if let Some(Step::Matmul { b_off_f32, .. }) = schedule.last_mut() {
4052 *b_off_f32 = b_off_adj;
4053 }
4054 uniforms.push(u);
4055 bind_groups.push(bg);
4056 if compute_precision == MatmulCompute::CoopF16Vk {
4057 coop_f16_vk_wide_bind_groups.insert(
4058 schedule.len() - 1,
4059 bind_two_buf0_window(
4060 &dev.device,
4061 mm_w_active_compile,
4062 &arena.buffer,
4063 base,
4064 size,
4065 &uniforms[uniforms.len() - 1],
4066 ),
4067 );
4068 }
4069 }
4070 }
4071
// `DotGeneral` has no lowering in this backend: per the panic message it is
// expected to be rewritten by the unfusion pass (`unfuse.rs::expand_dot_general`)
// before scheduling, so reaching this arm aborts with a pointer to that pass.
4072 Op::DotGeneral { .. } => {
4073 panic!(
4078 "rlx-wgpu DotGeneral: leaked past unfusion pass — \
4079 check unfuse.rs::expand_dot_general for missing patterns"
4080 );
4081 }
4082
4083 Op::Sample {
4084 top_k,
4085 top_p,
4086 temperature,
4087 seed,
4088 } => {
4089 let in_id = node.inputs[0];
4090 let in_shape = graph.node(in_id).shape.dims();
4091 let inner = in_shape[in_shape.len() - 1].unwrap_static() as u32;
4092 let total: u32 = in_shape.iter().map(|d| d.unwrap_static() as u32).product();
4093 let outer = total / inner.max(1);
4094 let is_greedy = *top_k == 0
4097 && (*top_p - 1.0).abs() < 1e-6
4098 && (*temperature - 1.0).abs() < 1e-6;
4099 if is_greedy {
4100 let p = ArgmaxParams {
4101 outer,
4102 inner,
4103 in_off: (arena.offset(in_id) / 4) as u32,
4104 out_off: (arena.offset(node.id) / 4) as u32,
4105 _p0: 0,
4106 _p1: 0,
4107 _p2: 0,
4108 _p3: 0,
4109 };
4110 schedule.push(Step::Argmax { params: p });
4111 let amk = argmax_kernel(&dev.device);
4112 let u = emit_uniform(std::mem::size_of::<ArgmaxParams>());
4113 let bg = bind_op_output_window(&dev.device, amk, &arena, node.id, &u);
4114 uniforms.push(u);
4115 bind_groups.push(bg);
4116 } else {
4117 let p = SampleParams {
4118 outer,
4119 inner,
4120 in_off: (arena.offset(in_id) / 4) as u32,
4121 out_off: (arena.offset(node.id) / 4) as u32,
4122 top_k: *top_k as u32,
4123 top_p_bits: top_p.to_bits(),
4124 temp_bits: temperature.to_bits(),
4125 seed_lo: *seed as u32,
4126 seed_hi: (*seed >> 32) as u32,
4127 _p0: 0,
4128 _p1: 0,
4129 _p2: 0,
4130 };
4131 schedule.push(Step::Sample { params: p });
4132 let sk = sample_kernel(&dev.device);
4133 let u = emit_uniform(std::mem::size_of::<SampleParams>());
4134 let bg = bind_op_output_window(&dev.device, sk, &arena, node.id, &u);
4135 uniforms.push(u);
4136 bind_groups.push(bg);
4137 }
4138 }
4139
4140 Op::Pool {
4141 kind,
4142 kernel_size,
4143 stride,
4144 padding,
4145 } => {
4146 let in_shape = graph.node(node.inputs[0]).shape.dims();
4147 let out_shape = node.shape.dims();
4148 let op_id: u32 = match kind {
4149 ReduceOp::Sum => 0,
4150 ReduceOp::Mean => 1,
4151 ReduceOp::Max => 2,
4152 ReduceOp::Min => 3,
4153 ReduceOp::Prod => 4,
4154 };
4155 match (kernel_size.len(), in_shape.len(), out_shape.len()) {
4156 (1, 3, 3) => {
4157 let p = Pool1dParams {
4158 n: in_shape[0].unwrap_static() as u32,
4159 c: in_shape[1].unwrap_static() as u32,
4160 l: in_shape[2].unwrap_static() as u32,
4161 l_out: out_shape[2].unwrap_static() as u32,
4162 kl: kernel_size[0] as u32,
4163 sl: stride.first().copied().unwrap_or(1) as u32,
4164 pl: padding.first().copied().unwrap_or(0) as u32,
4165 op: op_id,
4166 in_off: (arena.offset(node.inputs[0]) / 4) as u32,
4167 out_off: (arena.offset(node.id) / 4) as u32,
4168 _p0: 0,
4169 _p1: 0,
4170 _p2: 0,
4171 _p3: 0,
4172 _p4: 0,
4173 _p5: 0,
4174 };
4175 schedule.push(Step::Pool1d { params: p });
4176 let pk = pool1d_kernel(&dev.device);
4177 let u = emit_uniform(std::mem::size_of::<Pool1dParams>());
4178 let bg = bind_op_output_window(&dev.device, pk, &arena, node.id, &u);
4179 uniforms.push(u);
4180 bind_groups.push(bg);
4181 }
4182 (2, 4, 4) => {
4183 let p = Pool2dParams {
4184 n: in_shape[0].unwrap_static() as u32,
4185 c: in_shape[1].unwrap_static() as u32,
4186 h: in_shape[2].unwrap_static() as u32,
4187 w: in_shape[3].unwrap_static() as u32,
4188 h_out: out_shape[2].unwrap_static() as u32,
4189 w_out: out_shape[3].unwrap_static() as u32,
4190 kh: kernel_size[0] as u32,
4191 kw: kernel_size[1] as u32,
4192 sh: stride.first().copied().unwrap_or(1) as u32,
4193 sw: stride.get(1).copied().unwrap_or(1) as u32,
4194 ph: padding.first().copied().unwrap_or(0) as u32,
4195 pw: padding.get(1).copied().unwrap_or(0) as u32,
4196 op: op_id,
4197 in_off: (arena.offset(node.inputs[0]) / 4) as u32,
4198 out_off: (arena.offset(node.id) / 4) as u32,
4199 _p0: 0,
4200 _p1: 0,
4201 _p2: 0,
4202 };
4203 schedule.push(Step::Pool2d { params: p });
4204 let pk = pool2d_kernel(&dev.device);
4205 let u = emit_uniform(std::mem::size_of::<Pool2dParams>());
4206 let bg = bind_op_output_window(&dev.device, pk, &arena, node.id, &u);
4207 uniforms.push(u);
4208 bind_groups.push(bg);
4209 }
4210 (3, 5, 5) => {
4211 let p = Pool3dParams {
4212 n: in_shape[0].unwrap_static() as u32,
4213 c: in_shape[1].unwrap_static() as u32,
4214 d: in_shape[2].unwrap_static() as u32,
4215 h: in_shape[3].unwrap_static() as u32,
4216 w: in_shape[4].unwrap_static() as u32,
4217 d_out: out_shape[2].unwrap_static() as u32,
4218 h_out: out_shape[3].unwrap_static() as u32,
4219 w_out: out_shape[4].unwrap_static() as u32,
4220 kd: kernel_size[0] as u32,
4221 kh: kernel_size[1] as u32,
4222 kw: kernel_size[2] as u32,
4223 sd: stride.first().copied().unwrap_or(1) as u32,
4224 sh: stride.get(1).copied().unwrap_or(1) as u32,
4225 sw: stride.get(2).copied().unwrap_or(1) as u32,
4226 pd: padding.first().copied().unwrap_or(0) as u32,
4227 ph: padding.get(1).copied().unwrap_or(0) as u32,
4228 pw: padding.get(2).copied().unwrap_or(0) as u32,
4229 op: op_id,
4230 in_off: (arena.offset(node.inputs[0]) / 4) as u32,
4231 out_off: (arena.offset(node.id) / 4) as u32,
4232 _p0: 0,
4233 _p1: 0,
4234 };
4235 schedule.push(Step::Pool3d { params: p });
4236 let pk = pool3d_kernel(&dev.device);
4237 let u = emit_uniform(std::mem::size_of::<Pool3dParams>());
4238 let bg = bind_op_output_window(&dev.device, pk, &arena, node.id, &u);
4239 uniforms.push(u);
4240 bind_groups.push(bg);
4241 }
4242 (k, n, m) => panic!(
4243 "rlx-wgpu Pool: kernel-rank {k} with input rank {n} / \
4244 output rank {m} not supported (use 1D/2D/3D NCHW)"
4245 ),
4246 }
4247 }
4248
// Convolution lowering for 1D/2D/3D inputs. Chooses an arena bind window for
// the output, input and weight tensors, resolves each tensor's offset through
// `arena_off_in_bind_window`, then schedules the rank-specific conv kernel
// bound to that window. Panics for any unsupported rank combination.
4249 Op::Conv {
4250 kernel_size,
4251 stride,
4252 padding,
4253 dilation,
4254 groups,
4255 } => {
4256 let in_id = node.inputs[0];
4257 let w_id = node.inputs[1];
// Tensors that must be reachable through the single storage binding.
4258 let win_ids = [node.id, in_id, w_id];
4259 let max_binding = dev.device.limits().max_storage_buffer_binding_size;
// `fits` records whether the byte span of those tensors is within the
// device's storage-binding size limit.
4260 let fits = arena_span_bytes(&arena, &win_ids) <= max_binding;
4261 let mut scratch = arena.scratch_off as u64;
// NOTE(review): `param_anchor` presumably means the window was anchored on a
// parameter tensor — confirm against `arena_multi_op_window`.
4262 let (mut base, mut size, param_anchor) = arena_multi_op_window(
4263 &dev.device,
4264 &arena,
4265 &graph,
4266 ¶m_offsets,
4267 &mut schedule,
4268 &mut scratch,
4269 &win_ids,
4270 );
4271 arena_expand_bind_window(&arena, &win_ids, &mut base, &mut size, max_binding);
// Only when the tensors do not fit and the window is not parameter-anchored
// is the base re-derived so that it may cover the scratch region.
4272 if !fits && !param_anchor {
4273 base = arena_bind_window_covering_scratch_if_needed(
4274 &arena, base, size, scratch,
4275 );
4276 }
// Offsets are resolved in the order input, weight, output. Each call receives
// `schedule`, `scratch`, `base` and `size` mutably, so the order is kept as is.
4277 let in_off = arena_off_in_bind_window(
4278 &graph,
4279 ¶m_offsets,
4280 &dev.device,
4281 &arena,
4282 &mut schedule,
4283 &mut scratch,
4284 in_id,
4285 &mut base,
4286 &mut size,
4287 );
4288 let w_off = arena_off_in_bind_window(
4289 &graph,
4290 ¶m_offsets,
4291 &dev.device,
4292 &arena,
4293 &mut schedule,
4294 &mut scratch,
4295 w_id,
4296 &mut base,
4297 &mut size,
4298 );
4299 let out_off = arena_off_in_bind_window(
4300 &graph,
4301 ¶m_offsets,
4302 &dev.device,
4303 &arena,
4304 &mut schedule,
4305 &mut scratch,
4306 node.id,
4307 &mut base,
4308 &mut size,
4309 );
4310
4311 let in_shape = graph.node(in_id).shape.dims();
4312 let w_shape = graph.node(w_id).shape.dims();
4313 let out_shape = node.shape.dims();
// Per-axis accessors: missing stride/dilation entries default to 1, missing
// padding entries default to 0.
4314 let s = |i: usize| stride.get(i).copied().unwrap_or(1) as u32;
4315 let p = |i: usize| padding.get(i).copied().unwrap_or(0) as u32;
4316 let d = |i: usize| dilation.get(i).copied().unwrap_or(1) as u32;
// Dispatch on (kernel rank, input rank, weight rank, output rank).
4317 match (
4318 kernel_size.len(),
4319 in_shape.len(),
4320 w_shape.len(),
4321 out_shape.len(),
4322 ) {
4323 (1, 3, 3, 3) => {
4324 let p1 = Conv1dParams {
4325 n: in_shape[0].unwrap_static() as u32,
4326 c_in: in_shape[1].unwrap_static() as u32,
4327 c_out: out_shape[1].unwrap_static() as u32,
4328 l: in_shape[2].unwrap_static() as u32,
4329 l_out: out_shape[2].unwrap_static() as u32,
4330 kl: kernel_size[0] as u32,
4331 sl: s(0),
4332 pl: p(0),
4333 dl: d(0),
4334 groups: *groups as u32,
4335 in_off,
4336 w_off,
4337 out_off,
4338 _p0: 0,
4339 _p1: 0,
4340 _p2: 0,
4341 };
4342 schedule.push(Step::Conv1d { params: p1 });
4343 let ck = conv1d_kernel(&dev.device);
4344 let u = emit_uniform(std::mem::size_of::<Conv1dParams>());
4345 let bg = bind_two_buf0_window(
4346 &dev.device,
4347 ck,
4348 &arena.buffer,
4349 base,
4350 size,
4351 &u,
4352 );
4353 uniforms.push(u);
4354 bind_groups.push(bg);
4355 }
4356 (2, 4, 4, 4) => {
4357 let p2 = Conv2dParams {
4358 n: in_shape[0].unwrap_static() as u32,
4359 c_in: in_shape[1].unwrap_static() as u32,
4360 c_out: out_shape[1].unwrap_static() as u32,
4361 h: in_shape[2].unwrap_static() as u32,
4362 w: in_shape[3].unwrap_static() as u32,
4363 h_out: out_shape[2].unwrap_static() as u32,
4364 w_out: out_shape[3].unwrap_static() as u32,
4365 kh: kernel_size[0] as u32,
4366 kw: kernel_size[1] as u32,
4367 sh: s(0),
4368 sw: s(1),
4369 ph: p(0),
4370 pw: p(1),
4371 dh: d(0),
4372 dw: d(1),
4373 groups: *groups as u32,
4374 in_off,
4375 w_off,
4376 out_off,
4377 };
4378 schedule.push(Step::Conv2d { params: p2 });
4379 let ck = conv2d_kernel(&dev.device);
4380 let u = emit_uniform(std::mem::size_of::<Conv2dParams>());
4381 let bg = bind_two_buf0_window(
4382 &dev.device,
4383 ck,
4384 &arena.buffer,
4385 base,
4386 size,
4387 &u,
4388 );
4389 uniforms.push(u);
4390 bind_groups.push(bg);
4391 }
4392 (3, 5, 5, 5) => {
4393 let p3 = Conv3dParams {
4394 n: in_shape[0].unwrap_static() as u32,
4395 c_in: in_shape[1].unwrap_static() as u32,
4396 c_out: out_shape[1].unwrap_static() as u32,
4397 d: in_shape[2].unwrap_static() as u32,
4398 h: in_shape[3].unwrap_static() as u32,
4399 w: in_shape[4].unwrap_static() as u32,
4400 d_out: out_shape[2].unwrap_static() as u32,
4401 h_out: out_shape[3].unwrap_static() as u32,
4402 w_out: out_shape[4].unwrap_static() as u32,
4403 kd: kernel_size[0] as u32,
4404 kh: kernel_size[1] as u32,
4405 kw: kernel_size[2] as u32,
4406 sd: s(0),
4407 sh: s(1),
4408 sw: s(2),
4409 pd: p(0),
4410 ph: p(1),
4411 pw: p(2),
4412 dd: d(0),
4413 dh: d(1),
4414 dw: d(2),
4415 groups: *groups as u32,
4416 in_off,
4417 w_off,
4418 out_off,
4419 _p0: 0,
4420 };
4421 schedule.push(Step::Conv3d { params: p3 });
4422 let ck = conv3d_kernel(&dev.device);
4423 let u = emit_uniform(std::mem::size_of::<Conv3dParams>());
4424 let bg = bind_two_buf0_window(
4425 &dev.device,
4426 ck,
4427 &arena.buffer,
4428 base,
4429 size,
4430 &u,
4431 );
4432 uniforms.push(u);
4433 bind_groups.push(bg);
4434 }
// Every other rank combination is rejected.
4435 (k, ni, wi, mi) => panic!(
4436 "rlx-wgpu Conv: rank kernel={k} in={ni} weight={wi} out={mi} \
4437 not supported (use 1D/2D/3D NCHW)"
4438 ),
4439 }
4440 }
4441
4442 Op::Im2Col {
4443 kernel_size,
4444 stride,
4445 padding,
4446 dilation,
4447 } => {
4448 let x_shape = &graph.node(node.inputs[0]).shape;
4449 if kernel_size.len() != 2 || x_shape.rank() != 4 {
4450 panic!("rlx-wgpu Im2Col: 2D NCHW only");
4451 }
4452 let n = match x_shape.dim(0) {
4453 rlx_ir::shape::Dim::Static(v) => v as u32,
4454 _ => 0,
4455 };
4456 let c_in = x_shape.dim(1).unwrap_static() as u32;
4457 let h = x_shape.dim(2).unwrap_static() as u32;
4458 let w = x_shape.dim(3).unwrap_static() as u32;
4459 let kh = kernel_size[0] as u32;
4460 let kw = kernel_size[1] as u32;
4461 let sh = stride.first().copied().unwrap_or(1) as u32;
4462 let sw = stride.get(1).copied().unwrap_or(1) as u32;
4463 let ph = padding.first().copied().unwrap_or(0) as u32;
4464 let pw = padding.get(1).copied().unwrap_or(0) as u32;
4465 let dh = dilation.first().copied().unwrap_or(1) as u32;
4466 let dw_dil = dilation.get(1).copied().unwrap_or(1) as u32;
4467 let h_out = rlx_ir::shape::conv2d_spatial_output(
4468 h as usize,
4469 kh as usize,
4470 sh as usize,
4471 ph as usize,
4472 dh as usize,
4473 ) as u32;
4474 let w_out = rlx_ir::shape::conv2d_spatial_output(
4475 w as usize,
4476 kw as usize,
4477 sw as usize,
4478 pw as usize,
4479 dw_dil as usize,
4480 ) as u32;
4481 schedule.push(Step::Im2ColHost {
4482 x_byte_off: arena.offset(node.inputs[0]) as u32,
4483 col_byte_off: arena.offset(node.id) as u32,
4484 n,
4485 c_in,
4486 h,
4487 w,
4488 h_out,
4489 w_out,
4490 kh,
4491 kw,
4492 sh,
4493 sw,
4494 ph,
4495 pw,
4496 dh,
4497 dw_dil,
4498 });
4499 }
4500
4501 Op::Cumsum { axis, exclusive } => {
4502 let in_id = node.inputs[0];
4503 let in_shape = graph.node(in_id).shape.dims();
4504 let last = (in_shape.len() - 1) as i32;
4505 if *axis != -1 && *axis != last {
4506 panic!("rlx-wgpu Cumsum: only last-axis wired (got axis={axis})");
4507 }
4508 let inner = in_shape[in_shape.len() - 1].unwrap_static() as u32;
4509 let total: u32 = in_shape.iter().map(|d| d.unwrap_static() as u32).product();
4510 let outer = total / inner.max(1);
4511 let p = CumsumParams {
4512 outer,
4513 inner,
4514 in_off: (arena.offset(in_id) / 4) as u32,
4515 out_off: (arena.offset(node.id) / 4) as u32,
4516 exclusive: if *exclusive { 1 } else { 0 },
4517 _p0: 0,
4518 _p1: 0,
4519 _p2: 0,
4520 };
4521 schedule.push(Step::Cumsum { params: p });
4522 let ck2 = cumsum_kernel(&dev.device);
4523 let u = emit_uniform(std::mem::size_of::<CumsumParams>());
4524 let bg = bind_op_output_window(&dev.device, ck2, &arena, node.id, &u);
4525 uniforms.push(u);
4526 bind_groups.push(bg);
4527 }
4528 Op::Fft { inverse, norm } => {
4529 let in_id = node.inputs[0];
4530 let in_shape = graph.node(in_id).shape.clone();
4531 let meta = rlx_ir::fft::fft_meta(&in_shape);
4532 let dtype = in_shape.dtype();
4533 let use_gpu = rlx_ir::fft::gpu_fft_native_eligible(dtype, meta.n_complex)
4534 && meta.n_complex >= 2;
4535 let scale = norm.output_scale(meta.n_complex, *inverse) as f32;
4536 if use_gpu {
4537 schedule.push(Step::FftGpu {
4538 src_off: (arena.offset(in_id) / 4) as u32,
4539 dst_off: (arena.offset(node.id) / 4) as u32,
4540 outer: meta.outer as u32,
4541 n: meta.n_complex as u32,
4542 inverse: if *inverse { 1 } else { 0 },
4543 norm_scale: scale,
4544 });
4545 fft_gpu_steps.push(crate::fft_dispatch::FftGpuResources::new(
4546 &dev.device,
4547 &arena.buffer,
4548 ));
4549 } else {
4550 schedule.push(Step::FftHost {
4551 src_byte_off: arena.offset(in_id) as u32,
4552 dst_byte_off: arena.offset(node.id) as u32,
4553 outer: meta.outer as u32,
4554 n_complex: meta.n_complex as u32,
4555 inverse: *inverse,
4556 norm_tag: norm.tag(),
4557 dtype_tag: fft_dtype_tag(dtype),
4558 });
4559 }
4560 }
4561 Op::WelchPeaks { k, n_segments } => {
4562 let spec_shape = graph.node(node.inputs[0]).shape.clone();
4563 let meta = rlx_ir::audio::welch_peaks_meta(&spec_shape, *k, *n_segments)
4564 .unwrap_or_else(|e| panic!("Op::WelchPeaks: {e}"));
4565 let use_gpu = rlx_ir::audio::welch_peaks_gpu_native_eligible(
4566 &spec_shape,
4567 *k,
4568 *n_segments,
4569 )
4570 .unwrap_or(false);
4571 if use_gpu {
4572 let p = WelchPeaksGpuParams {
4573 spec_off: (arena.offset(node.inputs[0]) / 4) as u32,
4574 dst_off: (arena.offset(node.id) / 4) as u32,
4575 welch_batch: meta.welch_batch as u32,
4576 n_fft: meta.n_fft as u32,
4577 n_segments: meta.n_segments as u32,
4578 k: meta.k as u32,
4579 n_bins: meta.n_bins as u32,
4580 _p0: 0,
4581 _p1: 0,
4582 };
4583 schedule.push(Step::WelchPeaksGpu { params: p });
4584 let wk = welch_peaks_gpu_kernel(&dev.device);
4585 let u = emit_uniform(std::mem::size_of::<WelchPeaksGpuParams>());
4586 let bg = bind_op_output_window(&dev.device, wk, &arena, node.id, &u);
4587 uniforms.push(u);
4588 bind_groups.push(bg);
4589 } else {
4590 schedule.push(Step::WelchPeaksHost {
4591 spec_byte_off: arena.offset(node.inputs[0]) as u32,
4592 dst_byte_off: arena.offset(node.id) as u32,
4593 welch_batch: meta.welch_batch as u32,
4594 n_fft: meta.n_fft as u32,
4595 n_segments: meta.n_segments as u32,
4596 k: meta.k as u32,
4597 });
4598 }
4599 }
4600 Op::LogMel => {
4601 let spec_shape = graph.node(node.inputs[0]).shape.clone();
4602 let filt_shape = graph.node(node.inputs[1]).shape.clone();
4603 let meta = rlx_ir::audio::log_mel_meta(&spec_shape, &filt_shape)
4604 .unwrap_or_else(|e| panic!("Op::LogMel: {e}"));
4605 schedule.push(Step::LogMelHost {
4606 spec_byte_off: arena.offset(node.inputs[0]) as u32,
4607 filt_byte_off: arena.offset(node.inputs[1]) as u32,
4608 dst_byte_off: arena.offset(node.id) as u32,
4609 outer: meta.outer as u32,
4610 n_fft: meta.n_fft as u32,
4611 n_bins: meta.n_bins as u32,
4612 n_mels: meta.n_mels as u32,
4613 });
4614 }
4615 Op::LogMelBackward => {
4616 let spec_shape = graph.node(node.inputs[0]).shape.clone();
4617 let filt_shape = graph.node(node.inputs[1]).shape.clone();
4618 let meta = rlx_ir::audio::log_mel_meta(&spec_shape, &filt_shape)
4619 .unwrap_or_else(|e| panic!("Op::LogMelBackward: {e}"));
4620 schedule.push(Step::LogMelBackwardHost {
4621 spec_byte_off: arena.offset(node.inputs[0]) as u32,
4622 filt_byte_off: arena.offset(node.inputs[1]) as u32,
4623 dy_byte_off: arena.offset(node.inputs[2]) as u32,
4624 dst_byte_off: arena.offset(node.id) as u32,
4625 outer: meta.outer as u32,
4626 n_fft: meta.n_fft as u32,
4627 n_bins: meta.n_bins as u32,
4628 n_mels: meta.n_mels as u32,
4629 });
4630 }
4631 Op::SelectiveScan { state_size } => {
4632 if *state_size > 256 {
4633 panic!(
4634 "rlx-wgpu SelectiveScan: state_size {} exceeds compile-time \
4635 cap of 256 (kernel uses fixed-size private array)",
4636 state_size
4637 );
4638 }
4639 let x_id = node.inputs[0];
4640 let dt_id = node.inputs[1];
4641 let a_id = node.inputs[2];
4642 let b_id = node.inputs[3];
4643 let c_id = node.inputs[4];
4644 let in_dims = graph.node(x_id).shape.dims();
4645 let seq = in_dims[1].unwrap_static() as u32;
4646 let p = SelectiveScanParams {
4647 batch: in_dims[0].unwrap_static() as u32,
4648 seq,
4649 hidden: in_dims[2].unwrap_static() as u32,
4650 state_size: *state_size as u32,
4651 x_off: (arena.offset(x_id) / 4) as u32,
4652 delta_off: (arena.offset(dt_id) / 4) as u32,
4653 a_off: (arena.offset(a_id) / 4) as u32,
4654 b_off: (arena.offset(b_id) / 4) as u32,
4655 c_off: (arena.offset(c_id) / 4) as u32,
4656 out_off: (arena.offset(node.id) / 4) as u32,
4657 seq_stride: seq,
4660 _p1: 0,
4661 _p2: 0,
4662 _p3: 0,
4663 _p4: 0,
4664 _p5: 0,
4665 };
4666 schedule.push(Step::SelectiveScan { params: p });
4667 let ssk = selective_scan_kernel(&dev.device);
4668 let u = emit_uniform(std::mem::size_of::<SelectiveScanParams>());
4669 let bg = bind_op_output_window(&dev.device, ssk, &arena, node.id, &u);
4670 uniforms.push(u);
4671 bind_groups.push(bg);
4672 }
// Gated delta-net lowering: records a `Step::GatedDeltaNet` carrying byte
// offsets of inputs 0..=4 (q, k, v, g, beta), an optional carried-state offset,
// and batch/seq/heads taken from dimensions 0/1/2 of `q`.
4673 Op::GatedDeltaNet {
4674 state_size,
4675 carry_state,
4676 } => {
// State sizes above the limit exported by `rlx_cpu::gdn` are rejected.
4677 if *state_size > rlx_cpu::gdn::GDN_MAX_STATE {
4678 panic!(
4679 "rlx-wgpu GatedDeltaNet: state_size {state_size} > {}",
4680 rlx_cpu::gdn::GDN_MAX_STATE
4681 );
4682 }
4683 let q_id = node.inputs[0];
4684 let q_shape = &graph.node(q_id).shape;
// Input 5 is only read when `carry_state` is set; otherwise the state offset
// is 0 and `use_carry` tells the step to ignore it.
4685 let state_off = if *carry_state {
4686 arena.offset(node.inputs[5])
4687 } else {
4688 0
4689 };
4690 schedule.push(Step::GatedDeltaNet {
4691 q_byte_off: arena.offset(q_id) as u32,
4692 k_byte_off: arena.offset(node.inputs[1]) as u32,
4693 v_byte_off: arena.offset(node.inputs[2]) as u32,
4694 g_byte_off: arena.offset(node.inputs[3]) as u32,
4695 beta_byte_off: arena.offset(node.inputs[4]) as u32,
4696 state_byte_off: state_off as u32,
4697 dst_byte_off: arena.offset(node.id) as u32,
4698 batch: q_shape.dim(0).unwrap_static() as u32,
4699 seq: q_shape.dim(1).unwrap_static() as u32,
4700 heads: q_shape.dim(2).unwrap_static() as u32,
4701 state_size: *state_size as u32,
4702 use_carry: *carry_state,
4703 });
// `gguf_host_pad` is created once (binary kernel, 256-byte uniform, bound to
// this node's output window) and reused afterwards; a clone of the pair is
// pushed so this step contributes one uniform and one bind group.
// NOTE(review): presumably a placeholder keeping `uniforms`/`bind_groups`
// aligned with `schedule` for a step that does not run this kernel — confirm.
4704 if gguf_host_pad.is_none() {
4705 let bk = binary_kernel(&dev.device);
4706 let u = emit_uniform(256);
4707 gguf_host_pad = Some((
4708 u.clone(),
4709 bind_op_output_window(&dev.device, bk, &arena, node.id, &u),
4710 ));
4711 }
4712 let (u, bg) = gguf_host_pad.as_ref().unwrap();
4713 uniforms.push(u.clone());
4714 bind_groups.push(bg.clone());
4715 }
4716 Op::Custom { name, attrs, .. } => match name.as_str() {
4717 "llada2.group_limited_gate" => {
4718 let sig_id = node.inputs[0];
4719 let route_id = node.inputs[1];
4720 let n_elems = graph.node(sig_id).shape.num_elements().unwrap() as u32;
4721 let mut attr_buf = [0u8; 20];
4722 let n = attrs.len().min(20);
4723 attr_buf[..n].copy_from_slice(&attrs[..n]);
4724 schedule.push(Step::Llada2GroupLimitedGate {
4725 sig_byte_off: arena.offset(sig_id) as u32,
4726 route_byte_off: arena.offset(route_id) as u32,
4727 out_byte_off: arena.offset(node.id) as u32,
4728 n_elems,
4729 attrs: attr_buf,
4730 });
4731 }
4732 "umap.knn" => {
4733 let pw_id = node.inputs[0];
4734 let pw_shape = graph.node(pw_id).shape.dims();
4735 let n = pw_shape[0].unwrap_static() as u32;
4736 let k = if attrs.len() >= 4 {
4737 u32::from_le_bytes(attrs[..4].try_into().unwrap())
4738 } else {
4739 panic!("rlx-wgpu: umap.knn attrs missing k");
4740 };
4741 let pw_off = arena.offset(pw_id) as u32;
4742 let out_off = arena.offset(node.id) as u32;
4743 if n as usize >= crate::umap_knn_host::UMAP_KNN_GPU_MIN_N {
4744 let p = UmapKnnParams {
4745 n,
4746 k,
4747 pw_off: pw_off / 4,
4748 out_off: out_off / 4,
4749 _p0: 0,
4750 _p1: 0,
4751 _p2: 0,
4752 };
4753 schedule.push(Step::UmapKnn { params: p });
4754 let uk = umap_knn_kernel(&dev.device);
4755 let u = emit_uniform(std::mem::size_of::<UmapKnnParams>());
4756 let bg = bind_op_output_window(&dev.device, uk, &arena, node.id, &u);
4757 uniforms.push(u);
4758 bind_groups.push(bg);
4759 } else {
4760 schedule.push(Step::UmapKnnHost {
4761 pairwise_byte_off: pw_off,
4762 out_byte_off: out_off,
4763 n,
4764 k,
4765 });
4766 }
4767 }
4768 other => panic!("rlx-wgpu: unsupported Op::Custom('{other}')"),
4769 },
4770 Op::GroupedMatMul => {
4771 let in_id = node.inputs[0];
4773 let w_id = node.inputs[1];
4774 let idx_id = node.inputs[2];
4775 let in_dims = graph.node(in_id).shape.dims();
4776 let w_dims = graph.node(w_id).shape.dims();
4777 let m = in_dims[0].unwrap_static() as u32;
4778 let k = in_dims[1].unwrap_static() as u32;
4779 let n = w_dims[2].unwrap_static() as u32;
4780 let ne = w_dims[0].unwrap_static() as u32;
4781 let p = GroupedMatmulParams {
4782 m,
4783 k,
4784 n,
4785 num_experts: ne,
4786 in_off: (arena.offset(in_id) / 4) as u32,
4787 w_off: (arena.offset(w_id) / 4) as u32,
4788 idx_off: (arena.offset(idx_id) / 4) as u32,
4789 out_off: (arena.offset(node.id) / 4) as u32,
4790 };
4791 schedule.push(Step::GroupedMatmul { params: p });
4792 let gk = grouped_matmul_kernel(&dev.device);
4793 let u = emit_uniform(std::mem::size_of::<GroupedMatmulParams>());
4794 let bg = bind_op_output_window(&dev.device, gk, &arena, node.id, &u);
4795 uniforms.push(u);
4796 bind_groups.push(bg);
4797 }
// Quantised grouped matmul: records a `Step::DequantGroupedMatmulGguf` with
// byte offsets for the input, weight, index and output tensors plus the
// scheme id, then pushes the shared placeholder uniform / bind group.
4798 Op::DequantGroupedMatMul { scheme } => {
4799 let in_id = node.inputs[0];
4800 let w_id = node.inputs[1];
4801 let idx_id = node.inputs[2];
4802 let in_dims = graph.node(in_id).shape.dims();
4803 let out_dims = node.shape.dims();
// m and k come from the input's dims 0 and 1; n is the output's last dim.
4804 let m = in_dims[0].unwrap_static() as u32;
4805 let k = in_dims[1].unwrap_static() as u32;
4806 let n = out_dims[out_dims.len() - 1].unwrap_static() as u32;
4807 let block_elems = scheme.gguf_block_size() as usize;
4808 let block_bytes = scheme.gguf_block_bytes() as usize;
// Bytes of one k×n slab under this scheme; the expert count is the weight
// tensor's element count divided by that slab size (guarded against zero).
// NOTE(review): this treats the weight's element count as a byte count —
// presumably the weight is stored as raw bytes; confirm.
4809 let slab_bytes = (k as usize * n as usize) / block_elems * block_bytes;
4810 let total_bytes = graph.node(w_id).shape.num_elements().unwrap();
4811 let ne = (total_bytes / slab_bytes.max(1)) as u32;
4812 schedule.push(Step::DequantGroupedMatmulGguf {
4813 m,
4814 k,
4815 n,
4816 num_experts: ne,
4817 scheme_id: crate::gguf_host::gguf_scheme_id(*scheme),
4818 x_byte_off: arena.offset(in_id) as u32,
4819 w_byte_off: arena.offset(w_id) as u32,
4820 idx_byte_off: arena.offset(idx_id) as u32,
4821 out_byte_off: arena.offset(node.id) as u32,
4822 });
// `gguf_host_pad` is created once (binary kernel, 256-byte uniform, bound to
// this node's output window) and reused afterwards; a clone of the pair is
// pushed so this step contributes one uniform and one bind group.
// NOTE(review): presumably a placeholder keeping `uniforms`/`bind_groups`
// aligned with `schedule` for a step that does not run this kernel — confirm.
4823 if gguf_host_pad.is_none() {
4824 let bk = binary_kernel(&dev.device);
4825 let u = emit_uniform(256);
4826 gguf_host_pad = Some((
4827 u.clone(),
4828 bind_op_output_window(&dev.device, bk, &arena, node.id, &u),
4829 ));
4830 }
4831 let (u, bg) = gguf_host_pad.as_ref().unwrap();
4832 uniforms.push(u.clone());
4833 bind_groups.push(bg.clone());
4834 }
4835 Op::TopK { k } => {
4836 let in_id = node.inputs[0];
4837 let in_dims = graph.node(in_id).shape.dims();
4838 let inner = in_dims.last().unwrap().unwrap_static() as u32;
4839 let outer: u32 = in_dims[..in_dims.len() - 1]
4840 .iter()
4841 .map(|d| d.unwrap_static() as u32)
4842 .product::<u32>()
4843 .max(1);
4844 let p = TopKParams {
4845 outer,
4846 inner,
4847 k: *k as u32,
4848 in_off: (arena.offset(in_id) / 4) as u32,
4849 out_off: (arena.offset(node.id) / 4) as u32,
4850 _p0: 0,
4851 _p1: 0,
4852 _p2: 0,
4853 };
4854 schedule.push(Step::TopK { params: p });
4855 let tk = topk_kernel(&dev.device);
4856 let u = emit_uniform(std::mem::size_of::<TopKParams>());
4857 let bg = bind_op_output_window(&dev.device, tk, &arena, node.id, &u);
4858 uniforms.push(u);
4859 bind_groups.push(bg);
4860 }
4861 Op::ScatterAdd => {
4862 let upd_id = node.inputs[0];
4867 let idx_id = node.inputs[1];
4868 let upd_dims = graph.node(upd_id).shape.dims();
4869 let out_dims = node.shape.dims();
4870 let num_updates = upd_dims[0].unwrap_static() as u32;
4871 let trailing: u32 = upd_dims
4872 .iter()
4873 .skip(1)
4874 .map(|d| d.unwrap_static() as u32)
4875 .product::<u32>()
4876 .max(1);
4877 let out_dim = out_dims[0].unwrap_static() as u32;
4878 let out_total = out_dim * trailing;
4879
4880 let common = ScatterAddParams {
4881 op: 0,
4882 out_off: (arena.offset(node.id) / 4) as u32,
4883 upd_off: (arena.offset(upd_id) / 4) as u32,
4884 idx_off: (arena.offset(idx_id) / 4) as u32,
4885 out_total,
4886 num_updates,
4887 trailing,
4888 out_dim,
4889 };
4890 let sk = scatter_add_kernel(&dev.device);
4891
4892 schedule.push(Step::ScatterAdd { params: common });
4894 let u0 = emit_uniform(std::mem::size_of::<ScatterAddParams>());
4895 let bg0 = bind_op_output_window(&dev.device, sk, &arena, node.id, &u0);
4896 uniforms.push(u0);
4897 bind_groups.push(bg0);
4898
4899 let mut acc = common;
4901 acc.op = 1;
4902 schedule.push(Step::ScatterAdd { params: acc });
4903 let u1 = emit_uniform(std::mem::size_of::<ScatterAddParams>());
4904 let bg1 = bind_op_output_window(&dev.device, sk, &arena, node.id, &u1);
4905 uniforms.push(u1);
4906 bind_groups.push(bg1);
4907 }
4908 Op::FusedResidualLN { has_bias, eps } => {
4909 let x_id = node.inputs[0];
4911 let r_id = node.inputs[1];
4912 let (bias_id, g_id, b_id) = if *has_bias {
4913 (node.inputs[2], node.inputs[3], node.inputs[4])
4914 } else {
4915 (x_id, node.inputs[2], node.inputs[3]) };
4917 let in_dims = node.shape.dims();
4918 let inner = in_dims[in_dims.len() - 1].unwrap_static() as u32;
4919 let total: u32 = in_dims.iter().map(|d| d.unwrap_static() as u32).product();
4920 let outer = total / inner.max(1);
4921 let p = FusedResidualLnParams {
4922 outer,
4923 inner,
4924 in_off: (arena.offset(x_id) / 4) as u32,
4925 residual_off: (arena.offset(r_id) / 4) as u32,
4926 bias_off: (arena.offset(bias_id) / 4) as u32,
4927 gamma_off: (arena.offset(g_id) / 4) as u32,
4928 beta_off: (arena.offset(b_id) / 4) as u32,
4929 out_off: (arena.offset(node.id) / 4) as u32,
4930 eps_bits: eps.to_bits(),
4931 has_bias: if *has_bias { 1 } else { 0 },
4932 _p0: 0,
4933 _p1: 0,
4934 };
4935 schedule.push(Step::FusedResidualLn { params: p });
4936 let frk = fused_residual_ln_kernel(&dev.device);
4937 let u = emit_uniform(std::mem::size_of::<FusedResidualLnParams>());
4938 let bg = bind_op_output_window(&dev.device, frk, &arena, node.id, &u);
4939 uniforms.push(u);
4940 bind_groups.push(bg);
4941 }
4942 Op::FusedResidualRmsNorm { has_bias, eps } => {
4943 let x_id = node.inputs[0];
4944 let r_id = node.inputs[1];
4945 let (bias_id, g_id, b_id) = if *has_bias {
4946 (node.inputs[2], node.inputs[3], node.inputs[4])
4947 } else {
4948 (x_id, node.inputs[2], node.inputs[3])
4949 };
4950 let in_dims = node.shape.dims();
4951 let inner = in_dims[in_dims.len() - 1].unwrap_static() as u32;
4952 let total: u32 = in_dims.iter().map(|d| d.unwrap_static() as u32).product();
4953 let outer = total / inner.max(1);
4954 let p = FusedResidualRmsNormParams {
4955 outer,
4956 inner,
4957 in_off: (arena.offset(x_id) / 4) as u32,
4958 residual_off: (arena.offset(r_id) / 4) as u32,
4959 bias_off: (arena.offset(bias_id) / 4) as u32,
4960 gamma_off: (arena.offset(g_id) / 4) as u32,
4961 beta_off: (arena.offset(b_id) / 4) as u32,
4962 out_off: (arena.offset(node.id) / 4) as u32,
4963 eps_bits: eps.to_bits(),
4964 has_bias: if *has_bias { 1 } else { 0 },
4965 _p0: 0,
4966 _p1: 0,
4967 };
4968 schedule.push(Step::FusedResidualRmsNorm { params: p });
4969 let frk = fused_residual_rms_norm_kernel(&dev.device);
4970 let u = emit_uniform(std::mem::size_of::<FusedResidualRmsNormParams>());
4971 let bg = bind_op_output_window(&dev.device, frk, &arena, node.id, &u);
4972 uniforms.push(u);
4973 bind_groups.push(bg);
4974 }
4975 Op::DequantMatMul { scheme } => {
4976 use rlx_ir::QuantScheme;
4977 let x_id = node.inputs[0];
4978 let w_id = node.inputs[1];
4979 let out_dims = node.shape.dims();
4980 let x_dims = graph.node(x_id).shape.dims();
4981 let m = out_dims[0].unwrap_static() as u32;
4982 let n = out_dims[1].unwrap_static() as u32;
4983 let k = x_dims[1].unwrap_static() as u32;
4984 if scheme.is_gguf() {
4985 schedule.push(Step::DequantMatmulGguf {
4986 m,
4987 k,
4988 n,
4989 scheme_id: crate::gguf_host::gguf_scheme_id(*scheme),
4990 x_byte_off: arena.offset(x_id) as u32,
4991 w_byte_off: arena.offset(w_id) as u32,
4992 out_byte_off: arena.offset(node.id) as u32,
4993 });
4994 if gguf_host_pad.is_none() {
4995 let bk = binary_kernel(&dev.device);
4996 let u = emit_uniform(256);
4997 gguf_host_pad = Some((
4998 u.clone(),
4999 bind_op_output_window(&dev.device, bk, &arena, node.id, &u),
5000 ));
5001 }
5002 let (u, bg) = gguf_host_pad.as_ref().unwrap();
5003 uniforms.push(u.clone());
5004 bind_groups.push(bg.clone());
5005 } else {
5006 let (block_size, scheme_id) = match scheme {
5007 QuantScheme::Int8Block { block_size } => (*block_size, 0u32),
5008 QuantScheme::Int8BlockAsym { block_size } => (*block_size, 1u32),
5009 QuantScheme::Int4Block { block_size } => (*block_size, 2u32),
5010 QuantScheme::Fp8E4m3 => (1, 3u32),
5011 QuantScheme::Fp8E5m2 => (1, 4u32),
5012 QuantScheme::Nvfp4Block => (rlx_ir::NVFP4_GROUP_SIZE as u32, 5u32),
5013 other => panic!("rlx-wgpu DequantMatMul: unsupported scheme {other:?}"),
5014 };
5015 let scale_id = node.inputs[2];
5016 let zp_id = node.inputs[3];
5017 let p = DequantMatmulParams {
5018 m,
5019 k,
5020 n,
5021 block_size,
5022 scheme_id,
5023 x_off: (arena.offset(x_id) / 4) as u32,
5024 w_off: (arena.offset(w_id) / 4) as u32,
5025 scale_off: (arena.offset(scale_id) / 4) as u32,
5026 zp_off: (arena.offset(zp_id) / 4) as u32,
5027 out_off: (arena.offset(node.id) / 4) as u32,
5028 _p0: 0,
5029 _p1: 0,
5030 };
5031 schedule.push(Step::DequantMatmul { params: p });
5032 let dk = dequant_matmul_kernel(&dev.device);
5033 let u = emit_uniform(std::mem::size_of::<DequantMatmulParams>());
5034 let bg = bind_op_output_window(&dev.device, dk, &arena, node.id, &u);
5035 uniforms.push(u);
5036 bind_groups.push(bg);
5037 }
5038 }
5039 Op::RmsNormBackwardInput { eps, .. }
5040 | Op::RmsNormBackwardGamma { eps, .. }
5041 | Op::RmsNormBackwardBeta { eps, .. } => {
5042 let x_shape = &graph.node(node.inputs[0]).shape;
5043 let h = x_shape.dim(x_shape.rank() - 1).unwrap_static() as u32;
5044 let rows = (x_shape.num_elements().unwrap() / h.max(1) as usize) as u32;
5045 let foff = |i: usize| (arena.offset(node.inputs[i]) / 4) as u32;
5046 let wrt = match &node.op {
5047 Op::RmsNormBackwardInput { .. } => 0u32,
5048 Op::RmsNormBackwardGamma { .. } => 1u32,
5049 Op::RmsNormBackwardBeta { .. } => 2u32,
5050 _ => unreachable!(),
5051 };
5052 let p = RmsNormBwdParams {
5053 outer: rows,
5054 inner: h,
5055 x_off: foff(0),
5056 gamma_off: foff(1),
5057 beta_off: foff(2),
5058 dy_off: foff(3),
5059 out_off: (arena.offset(node.id) / 4) as u32,
5060 eps_bits: eps.to_bits(),
5061 wrt,
5062 };
5063 let rk = if wrt == 0 {
5064 rms_norm_backward_kernel(&dev.device)
5065 } else {
5066 rms_norm_backward_param_kernel(&dev.device)
5067 };
5068 let u = emit_uniform(std::mem::size_of::<RmsNormBwdParams>());
5069 let bg = bind_op_output_window(&dev.device, rk, &arena, node.id, &u);
5070 match &node.op {
5071 Op::RmsNormBackwardInput { .. } => {
5072 schedule.push(Step::RmsNormBackwardInput { params: p });
5073 }
5074 Op::RmsNormBackwardGamma { .. } => {
5075 schedule.push(Step::RmsNormBackwardGamma { params: p });
5076 }
5077 Op::RmsNormBackwardBeta { .. } => {
5078 schedule.push(Step::RmsNormBackwardBeta { params: p });
5079 }
5080 _ => unreachable!(),
5081 }
5082 uniforms.push(u);
5083 bind_groups.push(bg);
5084 }
5085 Op::LayerNormBackwardInput { eps, .. } => {
5086 let x_shape = &graph.node(node.inputs[0]).shape;
5087 let h = x_shape.dim(x_shape.rank() - 1).unwrap_static() as u32;
5088 let rows = (x_shape.num_elements().unwrap() / h.max(1) as usize) as u32;
5089 let p = LayerNormBwdParams {
5090 outer: rows,
5091 inner: h,
5092 x_off: (arena.offset(node.inputs[0]) / 4) as u32,
5093 gamma_off: (arena.offset(node.inputs[1]) / 4) as u32,
5094 dy_off: (arena.offset(node.inputs[2]) / 4) as u32,
5095 out_off: (arena.offset(node.id) / 4) as u32,
5096 eps_bits: eps.to_bits(),
5097 scratch_off: 0,
5098 };
5099 let rk = layer_norm_backward_input_kernel(&dev.device);
5100 let u = emit_uniform(std::mem::size_of::<LayerNormBwdParams>());
5101 let bg = bind_op_output_window(&dev.device, rk, &arena, node.id, &u);
5102 schedule.push(Step::LayerNormBackwardInput { params: p });
5103 uniforms.push(u);
5104 bind_groups.push(bg);
5105 }
5106 Op::LayerNormBackwardGamma { eps, .. } => {
5107 let x_shape = &graph.node(node.inputs[0]).shape;
5113 let h = x_shape.dim(x_shape.rank() - 1).unwrap_static() as u32;
5114 let rows = (x_shape.num_elements().unwrap() / h.max(1) as usize) as u32;
5115 const ROWS_PER_WG: u32 = 16;
5116 let num_workgroups = rows.div_ceil(ROWS_PER_WG.max(1));
5117 let scratch_off_words = (arena.scratch_off / 4) as u32;
5118 let partial_params = LayerNormBwdParams {
5119 outer: rows,
5120 inner: h,
5121 x_off: (arena.offset(node.inputs[0]) / 4) as u32,
5122 gamma_off: 0,
5123 dy_off: (arena.offset(node.inputs[1]) / 4) as u32,
5124 out_off: 0, eps_bits: eps.to_bits(),
5126 scratch_off: scratch_off_words,
5127 };
5128 let reduce_params = LayerNormBwdParams {
5129 outer: num_workgroups,
5132 inner: h,
5133 x_off: 0,
5134 gamma_off: 0,
5135 dy_off: 0,
5136 out_off: (arena.offset(node.id) / 4) as u32,
5137 eps_bits: eps.to_bits(),
5138 scratch_off: scratch_off_words,
5139 };
5140 let p_k = layer_norm_backward_gamma_partial_kernel(&dev.device);
5141 let r_k = layer_norm_backward_gamma_reduce_kernel(&dev.device);
5142 let p_u = emit_uniform(std::mem::size_of::<LayerNormBwdParams>());
5143 let r_u = emit_uniform(std::mem::size_of::<LayerNormBwdParams>());
5144 let p_bg = bind_op_output_window(&dev.device, p_k, &arena, node.id, &p_u);
5145 let r_bg = bind_op_output_window(&dev.device, r_k, &arena, node.id, &r_u);
5146 schedule.push(Step::LayerNormBackwardGammaPartial {
5147 params: partial_params,
5148 num_workgroups,
5149 });
5150 schedule.push(Step::LayerNormBackwardGammaReduce {
5151 params: reduce_params,
5152 });
5153 uniforms.push(p_u);
5154 uniforms.push(r_u);
5155 bind_groups.push(p_bg);
5156 bind_groups.push(r_bg);
5157 }
5158 Op::RopeBackward { head_dim, n_rot } => {
5159 let dy_shape = &graph.node(node.inputs[0]).shape;
5160 let (batch, seq, hidden) = if dy_shape.rank() >= 3 {
5161 (
5162 dy_shape.dim(0).unwrap_static() as u32,
5163 dy_shape.dim(1).unwrap_static() as u32,
5164 dy_shape.dim(2).unwrap_static() as u32,
5165 )
5166 } else {
5167 (
5168 1,
5169 dy_shape.dim(0).unwrap_static() as u32,
5170 dy_shape.dim(1).unwrap_static() as u32,
5171 )
5172 };
5173 let cos_len = graph.node(node.inputs[1]).shape.num_elements().unwrap() as u32;
5174 let p = RopeBwdParams {
5175 batch,
5176 seq,
5177 hidden,
5178 head_dim: *head_dim as u32,
5179 n_rot: *n_rot as u32,
5180 dy_off: (arena.offset(node.inputs[0]) / 4) as u32,
5181 cos_off: (arena.offset(node.inputs[1]) / 4) as u32,
5182 sin_off: (arena.offset(node.inputs[2]) / 4) as u32,
5183 dx_off: (arena.offset(node.id) / 4) as u32,
5184 cos_len,
5185 };
5186 let rk = rope_backward_kernel(&dev.device);
5187 let u = emit_uniform(std::mem::size_of::<RopeBwdParams>());
5188 let bg = bind_op_output_window(&dev.device, rk, &arena, node.id, &u);
5189 schedule.push(Step::RopeBackward { params: p });
5190 uniforms.push(u);
5191 bind_groups.push(bg);
5192 }
5193 Op::CumsumBackward { exclusive, .. } => {
5194 let dy_shape = &graph.node(node.inputs[0]).shape;
5195 let cols = dy_shape.dim(dy_shape.rank() - 1).unwrap_static() as u32;
5196 let rows = (dy_shape.num_elements().unwrap() / cols.max(1) as usize) as u32;
5197 let p = CumsumBwdParams {
5198 outer: rows,
5199 inner: cols,
5200 dy_off: (arena.offset(node.inputs[0]) / 4) as u32,
5201 dx_off: (arena.offset(node.id) / 4) as u32,
5202 exclusive: if *exclusive { 1 } else { 0 },
5203 _p0: 0,
5204 _p1: 0,
5205 _p2: 0,
5206 };
5207 let ck = cumsum_backward_kernel(&dev.device);
5208 let u = emit_uniform(std::mem::size_of::<CumsumBwdParams>());
5209 let bg = bind_op_output_window(&dev.device, ck, &arena, node.id, &u);
5210 schedule.push(Step::CumsumBackward { params: p });
5211 uniforms.push(u);
5212 bind_groups.push(bg);
5213 }
5214 Op::GatherBackward { .. } => {
5215 let dy_shape = &graph.node(node.inputs[0]).shape;
5216 let idx_shape = &graph.node(node.inputs[1]).shape;
5217 let out_shape = &node.shape;
5218 let rank = out_shape.rank();
5219 let axis = match &node.op {
5220 Op::GatherBackward { axis } => *axis,
5221 _ => 0,
5222 };
5223 let axis_u = if axis < 0 {
5224 (rank as i32 + axis) as usize
5225 } else {
5226 axis as usize
5227 };
5228 let outer: usize = (0..axis_u)
5229 .map(|i| dy_shape.dim(i).unwrap_static())
5230 .product::<usize>()
5231 .max(1);
5232 let num_idx = idx_shape.dim(axis_u).unwrap_static();
5233 let trailing: usize = (axis_u + 1..dy_shape.rank())
5234 .map(|i| dy_shape.dim(i).unwrap_static())
5235 .product::<usize>()
5236 .max(1);
5237 let axis_dim = out_shape.dim(axis_u).unwrap_static();
5238 let p = GatherBwdParams {
5239 outer: outer as u32,
5240 axis_dim: axis_dim as u32,
5241 num_idx: num_idx as u32,
5242 trailing: trailing as u32,
5243 dy_off: (arena.offset(node.inputs[0]) / 4) as u32,
5244 idx_off: (arena.offset(node.inputs[1]) / 4) as u32,
5245 dst_off: (arena.offset(node.id) / 4) as u32,
5246 _p0: 0,
5247 };
5248 let zk = gather_backward_zero_kernel(&dev.device);
5249 let u = emit_uniform(std::mem::size_of::<GatherBwdParams>());
5250 let bg = bind_op_output_window(&dev.device, zk, &arena, node.id, &u);
5251 schedule.push(Step::GatherBackward { params: p });
5252 uniforms.push(u);
5253 bind_groups.push(bg);
5254 }
5255 #[cfg(feature = "splat")]
5256 Op::GaussianSplatRender {
5257 width,
5258 height,
5259 tile_size,
5260 radius_scale,
5261 alpha_cutoff,
5262 max_splat_steps,
5263 transmittance_threshold,
5264 max_list_entries,
5265 } => {
5266 let elem_len = |id: NodeId| -> u32 {
5267 graph.node(id).shape.num_elements().unwrap_or(0) as u32
5268 };
5269 schedule.push(Step::GaussianSplatRender {
5270 positions_byte_off: arena.offset(node.inputs[0]) as u32,
5271 positions_len: elem_len(node.inputs[0]),
5272 scales_byte_off: arena.offset(node.inputs[1]) as u32,
5273 scales_len: elem_len(node.inputs[1]),
5274 rotations_byte_off: arena.offset(node.inputs[2]) as u32,
5275 rotations_len: elem_len(node.inputs[2]),
5276 opacities_byte_off: arena.offset(node.inputs[3]) as u32,
5277 opacities_len: elem_len(node.inputs[3]),
5278 colors_byte_off: arena.offset(node.inputs[4]) as u32,
5279 colors_len: elem_len(node.inputs[4]),
5280 sh_coeffs_byte_off: arena.offset(node.inputs[5]) as u32,
5281 sh_coeffs_len: elem_len(node.inputs[5]),
5282 meta_byte_off: arena.offset(node.inputs[6]) as u32,
5283 dst_byte_off: arena.offset(node.id) as u32,
5284 dst_len: node.shape.num_elements().unwrap_or(0) as u32,
5285 width: *width,
5286 height: *height,
5287 tile_size: *tile_size,
5288 radius_scale: *radius_scale,
5289 alpha_cutoff: *alpha_cutoff,
5290 max_splat_steps: *max_splat_steps,
5291 transmittance_threshold: *transmittance_threshold,
5292 max_list_entries: *max_list_entries,
5293 });
5294 }
5295
5296 #[cfg(feature = "splat")]
5297 Op::GaussianSplatRenderBackward {
5298 width,
5299 height,
5300 tile_size,
5301 radius_scale,
5302 alpha_cutoff,
5303 max_splat_steps,
5304 transmittance_threshold,
5305 max_list_entries,
5306 loss_grad_clip,
5307 sh_band,
5308 max_anisotropy,
5309 } => {
5310 let elem_len = |id: NodeId| -> u32 {
5311 graph.node(id).shape.num_elements().unwrap_or(0) as u32
5312 };
5313 schedule.push(Step::GaussianSplatRenderBackward {
5314 positions_byte_off: arena.offset(node.inputs[0]) as u32,
5315 positions_len: elem_len(node.inputs[0]),
5316 scales_byte_off: arena.offset(node.inputs[1]) as u32,
5317 scales_len: elem_len(node.inputs[1]),
5318 rotations_byte_off: arena.offset(node.inputs[2]) as u32,
5319 rotations_len: elem_len(node.inputs[2]),
5320 opacities_byte_off: arena.offset(node.inputs[3]) as u32,
5321 opacities_len: elem_len(node.inputs[3]),
5322 colors_byte_off: arena.offset(node.inputs[4]) as u32,
5323 colors_len: elem_len(node.inputs[4]),
5324 sh_coeffs_byte_off: arena.offset(node.inputs[5]) as u32,
5325 sh_coeffs_len: elem_len(node.inputs[5]),
5326 meta_byte_off: arena.offset(node.inputs[6]) as u32,
5327 d_loss_byte_off: arena.offset(node.inputs[7]) as u32,
5328 d_loss_len: elem_len(node.inputs[7]),
5329 packed_byte_off: arena.offset(node.id) as u32,
5330 packed_len: node.shape.num_elements().unwrap_or(0) as u32,
5331 width: *width,
5332 height: *height,
5333 tile_size: *tile_size,
5334 radius_scale: *radius_scale,
5335 alpha_cutoff: *alpha_cutoff,
5336 max_splat_steps: *max_splat_steps,
5337 transmittance_threshold: *transmittance_threshold,
5338 max_list_entries: *max_list_entries,
5339 loss_grad_clip: *loss_grad_clip,
5340 sh_band: *sh_band,
5341 max_anisotropy: *max_anisotropy,
5342 });
5343 }
5344
5345 #[cfg(feature = "splat")]
5346 Op::GaussianSplatPrepare {
5347 width,
5348 height,
5349 tile_size,
5350 radius_scale,
5351 alpha_cutoff,
5352 max_splat_steps,
5353 transmittance_threshold,
5354 max_list_entries,
5355 } => {
5356 let elem_len = |id: NodeId| -> u32 {
5357 graph.node(id).shape.num_elements().unwrap_or(0) as u32
5358 };
5359 schedule.push(Step::GaussianSplatPrepare {
5360 positions_byte_off: arena.offset(node.inputs[0]) as u32,
5361 positions_len: elem_len(node.inputs[0]),
5362 scales_byte_off: arena.offset(node.inputs[1]) as u32,
5363 scales_len: elem_len(node.inputs[1]),
5364 rotations_byte_off: arena.offset(node.inputs[2]) as u32,
5365 rotations_len: elem_len(node.inputs[2]),
5366 opacities_byte_off: arena.offset(node.inputs[3]) as u32,
5367 opacities_len: elem_len(node.inputs[3]),
5368 colors_byte_off: arena.offset(node.inputs[4]) as u32,
5369 colors_len: elem_len(node.inputs[4]),
5370 sh_coeffs_byte_off: arena.offset(node.inputs[5]) as u32,
5371 sh_coeffs_len: elem_len(node.inputs[5]),
5372 meta_byte_off: arena.offset(node.inputs[6]) as u32,
5373 meta_len: elem_len(node.inputs[6]),
5374 prep_byte_off: arena.offset(node.id) as u32,
5375 prep_len: node.shape.num_elements().unwrap_or(0) as u32,
5376 width: *width,
5377 height: *height,
5378 tile_size: *tile_size,
5379 radius_scale: *radius_scale,
5380 alpha_cutoff: *alpha_cutoff,
5381 max_splat_steps: *max_splat_steps,
5382 transmittance_threshold: *transmittance_threshold,
5383 max_list_entries: *max_list_entries,
5384 });
5385 }
5386
5387 #[cfg(feature = "splat")]
5388 Op::GaussianSplatRasterize {
5389 width,
5390 height,
5391 tile_size,
5392 alpha_cutoff,
5393 max_splat_steps,
5394 transmittance_threshold,
5395 max_list_entries,
5396 } => {
5397 let elem_len = |id: NodeId| -> u32 {
5398 graph.node(id).shape.num_elements().unwrap_or(0) as u32
5399 };
5400 let prep_id = node.inputs[0];
5401 let count = match &graph.node(prep_id).op {
5402 rlx_ir::Op::GaussianSplatPrepare { .. } => {
5403 elem_len(graph.node(prep_id).inputs[0]) / 3
5404 }
5405 _ => 1,
5406 };
5407 schedule.push(Step::GaussianSplatRasterize {
5408 prep_byte_off: arena.offset(prep_id) as u32,
5409 prep_len: elem_len(prep_id),
5410 meta_byte_off: arena.offset(node.inputs[1]) as u32,
5411 meta_len: elem_len(node.inputs[1]),
5412 dst_byte_off: arena.offset(node.id) as u32,
5413 dst_len: node.shape.num_elements().unwrap_or(0) as u32,
5414 count,
5415 width: *width,
5416 height: *height,
5417 tile_size: *tile_size,
5418 alpha_cutoff: *alpha_cutoff,
5419 max_splat_steps: *max_splat_steps,
5420 transmittance_threshold: *transmittance_threshold,
5421 max_list_entries: *max_list_entries,
5422 });
5423 }
5424
5425 Op::If { .. } | Op::While { .. } => {
5426 panic!(
5431 "rlx-wgpu: Op::If/While leaked past unfusion pass — \
5432 check unfuse.rs::expand_if / expand_while"
5433 );
5434 }
5435 other => panic!(
5436 "rlx-wgpu: op {other:?} not yet lowered (v2 covers Matmul, \
5437 Binary, Compare, Activation, Where — fall back to CPU/Metal/MLX)"
5438 ),
5439 }
5440 }
5441
5442 if rlx_ir::env::flag("RLX_WGPU_SCHEDULE") || rlx_ir::env::flag("RLX_DISPATCH_REPORT") {
5443 let mut counts: std::collections::BTreeMap<&'static str, usize> =
5444 std::collections::BTreeMap::new();
5445 let mut fft_gpu = 0usize;
5446 let mut fft_host = 0usize;
5447 for s in &schedule {
5448 *counts.entry(step_name(s)).or_insert(0) += 1;
5449 match s {
5450 Step::FftGpu { .. } => fft_gpu += 1,
5451 Step::FftHost { .. } => fft_host += 1,
5452 _ => {}
5453 }
5454 }
5455 let arena_mb = arena.size as f64 / (1u64 << 20) as f64;
5456 eprintln!(
5457 "[rlx-wgpu] schedule: {} steps, arena={arena_mb:.1} MiB, fft_gpu={fft_gpu}, fft_host={fft_host}",
5458 schedule.len()
5459 );
5460 for (n, c) in &counts {
5461 eprintln!(" {c:>4} × {n}");
5462 }
5463 }
5464
5465 let coop_f16_vk = schedule_uses_coop_f16_vk(&schedule);
5466
5467 Self {
5468 graph,
5469 arena,
5470 schedule,
5471 input_offsets,
5472 param_offsets,
5473 uniforms,
5474 bind_groups,
5475 meta_buffers,
5476 unresolved: None,
5477 last_binding: None,
5478 pending_params: HashMap::new(),
5479 pending_param_bytes: HashMap::new(),
5480 active_extent: None,
5481 uniforms_active_extent: None,
5482 input_staging_hashes: HashMap::new(),
5483 coop_f16_vk,
5484 coop_f16_b_param,
5485 coop_f16_vk_wide_b: HashSet::new(),
5486 coop_f16_vk_wide_bind_groups,
5487 coop_f16_host_activations,
5488 stashed_params: HashMap::new(),
5489 readback_staging: None,
5490 tiny_readback: None,
5491 fft_gpu_steps,
5492 gpu_handles: HashMap::new(),
5493 gpu_handle_feeds: HashMap::new(),
5494 gpu_handle_resident: HashSet::new(),
5495 pending_read_indices: None,
5496 }
5497 }
5498
5499 pub fn set_param(&mut self, name: &str, data: &[f32]) {
5500 const STASH_MAX_BYTES: usize = 16 * 1024 * 1024;
5501 if data.len() * 4 <= STASH_MAX_BYTES {
5502 self.stashed_params.insert(name.to_string(), data.to_vec());
5503 }
5504 if self.coop_f16_vk {
5505 crate::coop_f16_vk::refresh_wide_b_flag(&mut self.coop_f16_vk_wide_b, name, data);
5506 }
5507 if self.unresolved.is_some() {
5508 self.pending_params.insert(name.to_string(), data.to_vec());
5509 return;
5510 }
5511 let dev = wgpu_device().expect("rlx-wgpu: device gone");
5512 if let Some(&id) = self.param_offsets.get(name)
5513 && self.arena.has(id)
5514 {
5515 self.arena.write_f32(&dev.queue, id, data);
5516 }
5517 }
5518
5519 pub fn debug_first_nan_node(
5524 &mut self,
5525 inputs: &[(&str, &[f32])],
5526 ) -> Option<(usize, String, String)> {
5527 let _ = self.run(inputs);
5528 let dev = wgpu_device().expect("rlx-wgpu: device gone");
5529 let mut prev_summary = String::from("(none)");
5530 for (i, node) in self.graph.nodes().iter().enumerate() {
5531 if !self.arena.has(node.id) {
5532 continue;
5533 }
5534 let elems = node.shape.num_elements().unwrap_or(0);
5535 if elems == 0 {
5536 continue;
5537 }
5538 let data = self.arena.read_f32(&dev.device, &dev.queue, node.id);
5539 let nan_count = data.iter().filter(|v| v.is_nan()).count();
5540 let inf_count = data.iter().filter(|v| v.is_infinite()).count();
5541 if nan_count > 0 || inf_count > 0 {
5542 return Some((i, format!("{:?}", node.op), prev_summary));
5543 }
5544 let max = data.iter().copied().fold(f32::NEG_INFINITY, f32::max);
5545 let min = data.iter().copied().fold(f32::INFINITY, f32::min);
5546 let abs_max = data.iter().map(|v| v.abs()).fold(0.0_f32, f32::max);
5547 prev_summary = format!(
5548 "node #{i} {:?} shape={:?} min={min:.6e} max={max:.6e} |max|={abs_max:.6e}",
5549 node.op,
5550 node.shape
5551 .dims()
5552 .iter()
5553 .map(|d| format!("{d:?}"))
5554 .collect::<Vec<_>>()
5555 );
5556 }
5557 None
5558 }
5559
5560 pub fn output_dtypes(&self) -> Vec<rlx_ir::DType> {
5564 self.graph
5565 .outputs
5566 .iter()
5567 .map(|&id| self.graph.node(id).shape.dtype())
5568 .collect()
5569 }
5570
5571 pub fn set_param_bytes(&mut self, name: &str, data: &[u8]) {
5576 if self.unresolved.is_some() {
5577 self.pending_param_bytes
5578 .insert(name.to_string(), data.to_vec());
5579 return;
5580 }
5581 let dev = wgpu_device().expect("rlx-wgpu: device gone");
5582 if let Some(&id) = self.param_offsets.get(name)
5583 && self.arena.has(id)
5584 {
5585 dev.queue
5586 .write_buffer(&self.arena.buffer, self.arena.offset(id) as u64, data);
5587 }
5588 }
5589
5590 fn dump_node_stats_if_requested(&self, dev: &crate::device::WgpuDevice) {
5591 if !rlx_ir::env::flag("RLX_WGPU_DUMP_NODES") {
5592 return;
5593 }
5594 let flat_probe = rlx_ir::env::parse_or::<usize>("RLX_WGPU_DUMP_FLAT", usize::MAX);
5595 let limit = rlx_ir::env::parse_or("RLX_WGPU_DUMP_NODES_LIMIT", 40usize);
5596 eprintln!(
5597 "[rlx-wgpu-dump] per-node max |x| (topo order, limit={limit}{})",
5598 if flat_probe != usize::MAX {
5599 format!(", flat[{flat_probe}]")
5600 } else {
5601 String::new()
5602 }
5603 );
5604 let mut shown = 0usize;
5605 for (i, node) in self.graph.nodes().iter().enumerate() {
5606 if !self.arena.has(node.id) {
5607 continue;
5608 }
5609 if matches!(
5610 node.op,
5611 rlx_ir::Op::Input { .. }
5612 | rlx_ir::Op::Param { .. }
5613 | rlx_ir::Op::Constant { .. }
5614 | rlx_ir::Op::Reshape { .. }
5615 | rlx_ir::Op::Cast { .. }
5616 ) {
5617 continue;
5618 }
5619 let data = self.arena.read_f32(&dev.device, &dev.queue, node.id);
5620 let max = data.iter().fold(0.0f32, |m, &v| m.max(v.abs()));
5621 let nz = data.iter().filter(|&&v| v != 0.0).count();
5622 let flat_s = if flat_probe < data.len() {
5623 format!(" flat[{flat_probe}]={:.6}", data[flat_probe])
5624 } else {
5625 String::new()
5626 };
5627 eprintln!(
5628 " [{i:>3}] {:?} max={max:.6} nonzero={}/{}{flat_s}",
5629 node.op,
5630 nz,
5631 data.len()
5632 );
5633 shown += 1;
5634 if shown >= limit {
5635 break;
5636 }
5637 }
5638 }
5639
5640 pub fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
5641 self.run_read_outputs(inputs, None)
5642 }
5643
5644 pub fn run_read_outputs(
5645 &mut self,
5646 inputs: &[(&str, &[f32])],
5647 read_indices: Option<&[usize]>,
5648 ) -> Vec<Vec<f32>> {
5649 self.pending_read_indices = read_indices.map(|s| s.to_vec());
5650 let outs = self.run_inner(inputs);
5651 self.pending_read_indices = None;
5652 outs
5653 }
5654
5655 pub fn bind_gpu_handle(&mut self, name: &str, data: &[f32]) -> bool {
5656 if !self.input_offsets.contains_key(name) {
5657 return false;
5658 }
5659 self.gpu_handle_resident.remove(name);
5660 self.gpu_handles.insert(name.to_string(), data.to_vec());
5661 true
5662 }
5663
5664 pub fn has_gpu_handle(&self, name: &str) -> bool {
5665 self.gpu_handles.contains_key(name)
5666 }
5667
5668 pub fn set_gpu_handle_feed(&mut self, handle_name: &str, output_index: usize) {
5669 self.gpu_handle_feeds
5670 .insert(handle_name.to_string(), output_index);
5671 }
5672
5673 pub fn read_gpu_handle(&self, name: &str) -> Option<Vec<f32>> {
5674 if let Some(&out_idx) = self.gpu_handle_feeds.get(name) {
5675 if out_idx < self.graph.outputs.len() {
5676 let id = self.graph.outputs[out_idx];
5677 if self.arena.has(id) {
5678 let dev = wgpu_device().expect("rlx-wgpu: device gone");
5679 return Some(self.arena.read_f32(&dev.device, &dev.queue, id));
5680 }
5681 }
5682 }
5683 if self.gpu_handle_resident.contains(name) {
5684 if let Some(&id) = self.input_offsets.get(name) {
5685 if self.arena.has(id) {
5686 let dev = wgpu_device().expect("rlx-wgpu: device gone");
5687 return Some(self.arena.read_f32(&dev.device, &dev.queue, id));
5688 }
5689 }
5690 }
5691 self.gpu_handles.get(name).cloned()
5692 }
5693
5694 fn readback_plan(&self) -> Vec<usize> {
5695 let n = self.graph.outputs.len();
5696 if self.pending_read_indices.is_none() && self.gpu_handle_feeds.is_empty() {
5697 return (0..n).collect();
5698 }
5699 if let Some(ref want) = self.pending_read_indices {
5700 let mut v: Vec<_> = want.to_vec();
5701 v.sort_unstable();
5702 return v;
5703 }
5704 (0..n).collect()
5705 }
5706
5707 fn propagate_gpu_handle_feeds_on_gpu(
5708 &mut self,
5709 dev: &crate::device::WgpuDevice,
5710 enc: &mut wgpu::CommandEncoder,
5711 ) {
5712 let extent = self.active_extent;
5713 let feeds: Vec<(String, usize)> = self
5714 .gpu_handle_feeds
5715 .iter()
5716 .map(|(n, &i)| (n.clone(), i))
5717 .collect();
5718 for (name, out_idx) in feeds {
5719 if out_idx >= self.graph.outputs.len() {
5720 continue;
5721 }
5722 let out_id = self.graph.outputs[out_idx];
5723 let Some(&in_id) = self.input_offsets.get(name.as_str()) else {
5724 continue;
5725 };
5726 if in_id != out_id {
5727 let out_bytes = self.arena.len_of(out_id);
5728 let copy_bytes = match extent {
5729 Some((actual, upper)) if upper > 0 => {
5730 let stride = (out_bytes / (upper + 1)).max(4);
5731 (actual * stride).min(out_bytes)
5732 }
5733 _ => out_bytes,
5734 };
5735 self.dispatch_arena_copy_bytes(dev, enc, out_id, in_id, copy_bytes);
5736 }
5737 self.gpu_handle_resident.insert(name.clone());
5738 self.gpu_handles.insert(name.clone(), Vec::new());
5739 }
5740 }
5741
5742 fn dispatch_arena_copy_bytes(
5743 &self,
5744 dev: &crate::device::WgpuDevice,
5745 enc: &mut wgpu::CommandEncoder,
5746 src_id: NodeId,
5747 dst_id: NodeId,
5748 nbytes: usize,
5749 ) {
5750 if nbytes == 0 {
5751 return;
5752 }
5753 let src = self.arena.offset(src_id) as u64;
5754 let dst = self.arena.offset(dst_id) as u64;
5755 let nbytes = nbytes
5756 .min(self.arena.len_of(src_id))
5757 .min(self.arena.len_of(dst_id)) as u64;
5758 let elems = (nbytes / 4).max(1) as u32;
5759 let lo = src.min(dst);
5760 let hi = src.saturating_add(nbytes).max(dst.saturating_add(nbytes));
5761 let max_binding = dev.device.limits().max_storage_buffer_binding_size;
5762 let mut size = hi.saturating_sub(lo).div_ceil(256) * 256;
5763 size = size.max(256).min(max_binding);
5764 let mut base = (lo / 256) * 256;
5765 if base.saturating_add(size) > self.arena.size as u64 {
5766 base = (self.arena.size as u64).saturating_sub(size);
5767 base = (base / 256) * 256;
5768 }
5769 let p = CopyParams {
5770 n: elems,
5771 in_off: (src.saturating_sub(base) / 4) as u32,
5772 out_off: (dst.saturating_sub(base) / 4) as u32,
5773 _p0: 0,
5774 _p1: 0,
5775 _p2: 0,
5776 _p3: 0,
5777 _p4: 0,
5778 };
5779 let ck = copy_kernel(&dev.device);
5780 let u = dev.device.create_buffer(&wgpu::BufferDescriptor {
5781 label: Some("rlx-wgpu kv_feed_copy uniform"),
5782 size: std::mem::size_of::<CopyParams>() as u64,
5783 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
5784 mapped_at_creation: false,
5785 });
5786 dev.queue.write_buffer(&u, 0, bytemuck::bytes_of(&p));
5787 let bg = bind_two_buf0_window(&dev.device, ck, &self.arena.buffer, base, size, &u);
5788 let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
5789 label: Some("rlx-wgpu kv_feed_copy pass"),
5790 ..Default::default()
5791 });
5792 pass.set_pipeline(&ck.pipeline);
5793 pass.set_bind_group(0, &bg, &[]);
5794 let (gx, gy, gz) = dispatch_dims(elems, 64);
5795 pass.dispatch_workgroups(gx, gy, gz);
5796 }
5797
5798 #[allow(dead_code)]
5799 fn dispatch_arena_copy_between_nodes(
5800 &self,
5801 dev: &crate::device::WgpuDevice,
5802 enc: &mut wgpu::CommandEncoder,
5803 src_id: NodeId,
5804 dst_id: NodeId,
5805 ) {
5806 let nbytes = self.arena.len_of(src_id).min(self.arena.len_of(dst_id));
5807 self.dispatch_arena_copy_bytes(dev, enc, src_id, dst_id, nbytes);
5808 }
5809
5810 fn stage_gpu_handle_inputs(
5811 &mut self,
5812 dev: &crate::device::WgpuDevice,
5813 inputs: &[(&str, &[f32])],
5814 ) {
5815 for (name, data) in &self.gpu_handles {
5816 if self.gpu_handle_resident.contains(name) || inputs.iter().any(|(n, _)| n == name) {
5817 continue;
5818 }
5819 if let Some(&id) = self.input_offsets.get(name.as_str())
5820 && self.arena.has(id)
5821 {
5822 self.arena.write_f32(&dev.queue, id, data);
5823 self.input_staging_hashes.remove(name);
5824 }
5825 }
5826 }
5827
5828 fn pack_readback_outputs(&mut self, plan: &[usize], partial: Vec<Vec<f32>>) -> Vec<Vec<f32>> {
5829 if self.pending_read_indices.is_none() {
5830 for (pos, &out_i) in plan.iter().enumerate() {
5831 if let Some(data) = partial.get(pos) {
5832 for (name, &feed_i) in &self.gpu_handle_feeds {
5833 if feed_i == out_i {
5834 self.gpu_handles.insert(name.clone(), data.clone());
5835 }
5836 }
5837 }
5838 }
5839 }
5840 if self.pending_read_indices.is_none() && plan.len() == self.graph.outputs.len() {
5841 return partial;
5842 }
5843 let want = self.pending_read_indices.as_deref().unwrap_or(plan);
5844 let mut by_idx = std::collections::HashMap::new();
5845 for (pos, &i) in plan.iter().enumerate() {
5846 if let Some(d) = partial.get(pos) {
5847 by_idx.insert(i, d.clone());
5848 }
5849 }
5850 want.iter()
5851 .map(|&i| {
5852 by_idx
5853 .get(&i)
5854 .cloned()
5855 .expect("readback plan missing output")
5856 })
5857 .collect()
5858 }
5859
5860 fn run_tail_host_audio_ops(&self, dev: &crate::device::WgpuDevice) {
5861 if !self.schedule.iter().any(step_is_tail_host) {
5862 return;
5863 }
5864 for step in &self.schedule {
5865 if !step_is_tail_host(step) {
5866 continue;
5867 }
5868 match step {
5869 Step::WelchPeaksHost {
5870 spec_byte_off,
5871 dst_byte_off,
5872 welch_batch,
5873 n_fft,
5874 n_segments,
5875 k,
5876 } => {
5877 crate::welch_peaks_host::run_welch_peaks(
5878 &self.arena,
5879 &dev.device,
5880 &dev.queue,
5881 *spec_byte_off as usize,
5882 *dst_byte_off as usize,
5883 *welch_batch as usize,
5884 *n_fft as usize,
5885 *n_segments as usize,
5886 *k as usize,
5887 );
5888 }
5889 Step::LogMelHost {
5890 spec_byte_off,
5891 filt_byte_off,
5892 dst_byte_off,
5893 outer,
5894 n_fft,
5895 n_bins,
5896 n_mels,
5897 } => {
5898 crate::log_mel_host::run_log_mel(
5899 &self.arena,
5900 &dev.device,
5901 &dev.queue,
5902 *spec_byte_off as usize,
5903 *filt_byte_off as usize,
5904 *dst_byte_off as usize,
5905 *outer as usize,
5906 *n_fft as usize,
5907 *n_bins as usize,
5908 *n_mels as usize,
5909 );
5910 }
5911 Step::LogMelBackwardHost {
5912 spec_byte_off,
5913 filt_byte_off,
5914 dy_byte_off,
5915 dst_byte_off,
5916 outer,
5917 n_fft,
5918 n_bins,
5919 n_mels,
5920 } => {
5921 crate::log_mel_host::run_log_mel_backward(
5922 &self.arena,
5923 &dev.device,
5924 &dev.queue,
5925 *spec_byte_off as usize,
5926 *filt_byte_off as usize,
5927 *dy_byte_off as usize,
5928 *dst_byte_off as usize,
5929 *outer as usize,
5930 *n_fft as usize,
5931 *n_bins as usize,
5932 *n_mels as usize,
5933 );
5934 }
5935 _ => {}
5936 }
5937 }
5938 }
5939
5940 fn run_inner(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
5941 if self.unresolved.is_some() {
5944 self.lazy_compile_for_inputs(inputs);
5945 }
5946 let dev = wgpu_device().expect("rlx-wgpu: device gone");
5947 self.stage_gpu_handle_inputs(dev, inputs);
5948 let skip_input_upload =
5949 !rlx_ir::env::flag("RLX_WGPU_FORCE_INPUT_UPLOAD") && !self.coop_f16_vk;
5950 for &(name, data) in inputs {
5951 if let Some(&id) = self.input_offsets.get(name)
5952 && self.arena.has(id)
5953 {
5954 if skip_input_upload {
5955 let h = hash_f32_input(data);
5956 if self.input_staging_hashes.get(name) == Some(&h) {
5957 if self.arena.f16_buffer.is_some() {
5958 self.arena.write_f16_shadow(&dev.queue, id, data);
5959 }
5960 continue;
5961 }
5962 self.arena.write_f32(&dev.queue, id, data);
5963 self.input_staging_hashes.insert(name.to_string(), h);
5964 } else {
5965 self.arena.write_f32(&dev.queue, id, data);
5966 }
5967 }
5968 }
5969 for &(act_id, act, ref src_name) in &self.coop_f16_host_activations {
5970 let src =
5971 host_tensor_f32(src_name, inputs, &self.stashed_params).unwrap_or_else(|| {
5972 panic!("rlx-wgpu CoopF16Vk host activation: missing tensor {src_name:?}")
5973 });
5974 let mirrored = apply_activation_host(act, src);
5975 self.arena.write_f32(&dev.queue, act_id, &mirrored);
5976 }
5977 if !self.coop_f16_host_activations.is_empty() {
5978 let flush = dev
5980 .device
5981 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
5982 label: Some("rlx-wgpu host mirror flush"),
5983 });
5984 dev.queue.submit(std::iter::once(flush.finish()));
5985 let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());
5986 }
5987
5988 let active = self.active_extent.filter(|_| self.all_safe_for_active());
5993 let scale = |full: u32| -> u32 {
5994 match active {
5995 Some((a, u)) if u > 0 => {
5996 let f = full as usize;
5997 (f * a).div_ceil(u).min(f) as u32
5998 }
5999 _ => full,
6000 }
6001 };
6002
6003 let need_uniform_writes = self.uniforms_active_extent != Some(active);
6009 if need_uniform_writes {
6010 let mut gpu_ui = 0usize;
6011 for step in self.schedule.iter() {
6012 if step_runs_on_host(step) {
6013 continue;
6014 }
6015 match step {
6016 Step::CastF32ToF16 { .. } => {
6017 }
6021 Step::Matmul {
6022 m,
6023 k,
6024 n,
6025 a_off_f32,
6026 b_off_f32,
6027 c_off_f32,
6028 batch,
6029 a_batch_stride,
6030 b_batch_stride,
6031 c_batch_stride,
6032 has_bias,
6033 bias_off_f32,
6034 act_id,
6035 b_is_param: _,
6036 compute_precision: _,
6037 } => {
6038 let m_scaled = scale(*m);
6042 let p = MatmulParams {
6043 m: m_scaled,
6044 k: *k,
6045 n: *n,
6046 a_off: *a_off_f32,
6047 b_off: *b_off_f32,
6048 c_off: *c_off_f32,
6049 batch: *batch,
6050 a_batch_stride: *a_batch_stride,
6051 b_batch_stride: *b_batch_stride,
6052 c_batch_stride: *c_batch_stride,
6053 has_bias: *has_bias,
6054 bias_off: *bias_off_f32,
6055 act_id: *act_id,
6056 _pad0: 0,
6057 _pad1: 0,
6058 _pad2: 0,
6059 };
6060 dev.queue
6061 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6062 }
6063 Step::Binary { params } | Step::Compare { params } => {
6064 let mut p = *params;
6065 p.n = scale(p.n);
6066 dev.queue
6067 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6068 }
6069 Step::Unary { params, .. } => {
6070 let mut p = *params;
6071 p.n = scale(p.n);
6072 dev.queue
6073 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6074 }
6075 Step::Where { params } => {
6076 let mut p = *params;
6077 p.n = scale(p.n);
6078 dev.queue
6079 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6080 }
6081 Step::Reduce { params } => {
6082 let mut p = *params;
6083 p.outer = scale(p.outer);
6084 dev.queue
6085 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6086 }
6087 Step::Softmax { params } => {
6088 let mut p = *params;
6089 p.outer = scale(p.outer);
6090 dev.queue
6091 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6092 }
6093 Step::LayerNorm { params } => {
6094 let mut p = *params;
6095 p.outer = scale(p.outer);
6096 dev.queue
6097 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6098 }
6099 Step::RmsNormBackwardInput { params }
6100 | Step::RmsNormBackwardGamma { params }
6101 | Step::RmsNormBackwardBeta { params } => {
6102 let mut p = *params;
6103 p.outer = scale(p.outer);
6104 dev.queue
6105 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6106 }
6107 Step::LayerNormBackwardInput { params } => {
6108 let mut p = *params;
6109 p.outer = scale(p.outer);
6110 dev.queue
6111 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6112 }
6113 Step::LayerNormBackwardGammaPartial { params, .. } => {
6114 let mut p = *params;
6115 p.outer = scale(p.outer);
6116 dev.queue
6117 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6118 }
6119 Step::LayerNormBackwardGammaReduce { params } => {
6120 dev.queue.write_buffer(
6124 &self.uniforms[gpu_ui],
6125 0,
6126 bytemuck::bytes_of(params),
6127 );
6128 }
6129 Step::CumsumBackward { params } => {
6130 let mut p = *params;
6131 p.outer = scale(p.outer);
6132 dev.queue
6133 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6134 }
6135 Step::RopeBackward { params } => {
6136 let mut p = *params;
6137 p.seq = scale(p.seq);
6138 dev.queue
6139 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6140 }
6141 Step::GatherBackward { params } => {
6142 let mut p = *params;
6143 p.outer = scale(p.outer);
6144 dev.queue
6145 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6146 }
6147 Step::Cumsum { params } => {
6148 let mut p = *params;
6149 p.outer = scale(p.outer);
6150 dev.queue
6151 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6152 }
6153 Step::FftGpu { .. } => {}
6154 Step::Copy { params } => {
6155 let mut p = *params;
6156 p.n = scale(p.n);
6157 dev.queue
6158 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6159 }
6160 Step::BufferCopy { .. } => {}
6161 Step::ElementwiseRegion { params } => {
6162 let mut p = *params;
6164 p.len = scale(p.len);
6165 dev.queue
6166 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6167 }
6168 Step::BatchElementwiseRegion { params } => {
6169 let mut p = *params;
6170 p.slice_len = scale(p.slice_len);
6171 p.num_batch = scale(p.num_batch);
6172 dev.queue
6173 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6174 }
6175 Step::Transpose { params, .. } => {
6176 let mut p = *params;
6181 if p.bucket_outermost == 1 && p.out_dim_0 > 0 {
6182 let scaled_d0 = scale(p.out_dim_0);
6183 let inner = p.out_total / p.out_dim_0;
6184 p.out_total = scaled_d0 * inner;
6185 }
6186 dev.queue
6187 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6188 }
6189 Step::Narrow { params } => {
6190 let mut p = *params;
6191 p.total = scale(p.total);
6192 dev.queue
6193 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6194 }
6195 Step::Concat { params } => {
6196 let mut p = *params;
6197 p.total = scale(p.total);
6198 dev.queue
6199 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6200 }
6201 Step::Gather { params } => {
6202 let mut p = *params;
6203 p.n_out = scale(p.n_out);
6204 dev.queue
6205 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6206 }
6207 Step::GatherAxis { params } => {
6208 let mut p = *params;
6209 p.total = scale(p.total);
6210 dev.queue
6211 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6212 }
6213 Step::Attention { params, .. } => {
6214 let mut p = *params;
6219 p.seq_q = scale(p.seq_q);
6220 p.seq_k = scale(p.seq_k);
6221 dev.queue
6222 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6223 }
6224 Step::AttentionBackward { params, .. } => {
6225 let mut p = *params;
6226 if p.wrt == 0 {
6227 p.seq_q = scale(p.seq_q);
6228 } else {
6229 p.seq_k = scale(p.seq_k);
6230 }
6231 dev.queue
6232 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6233 }
6234 Step::Rope { params } => {
6235 let mut p = *params;
6240 let s_active = scale(p.seq);
6241 p.seq = s_active;
6242 p.n_total = p.batch * s_active * p.last_dim;
6243 dev.queue
6244 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6245 }
6246 Step::Expand { params, .. } => {
6247 let mut p = *params;
6249 if p.bucket_outermost == 1 && p.out_dim_0 > 0 {
6250 let scaled_d0 = scale(p.out_dim_0);
6251 let inner = p.out_total / p.out_dim_0;
6252 p.out_total = scaled_d0 * inner;
6253 }
6254 dev.queue
6255 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6256 }
6257 Step::Argmax { params } => {
6258 let mut p = *params;
6259 p.outer = scale(p.outer);
6260 dev.queue
6261 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6262 }
6263 Step::Pool2d { params } => {
6264 let mut p = *params;
6265 p.n = scale(p.n);
6266 dev.queue
6267 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6268 }
6269 Step::Conv2d { params } => {
6270 let mut p = *params;
6271 p.n = scale(p.n);
6272 dev.queue
6273 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6274 }
6275 Step::Pool1d { params } => {
6276 let mut p = *params;
6277 p.n = scale(p.n);
6278 dev.queue
6279 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6280 }
6281 Step::Pool3d { params } => {
6282 let mut p = *params;
6283 p.n = scale(p.n);
6284 dev.queue
6285 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6286 }
6287 Step::Conv1d { params } => {
6288 let mut p = *params;
6289 p.n = scale(p.n);
6290 dev.queue
6291 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6292 }
6293 Step::Conv3d { params } => {
6294 let mut p = *params;
6295 p.n = scale(p.n);
6296 dev.queue
6297 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6298 }
6299 Step::ScatterAdd { params } => {
6300 let mut p = *params;
6304 if p.op == 1 {
6305 p.num_updates = scale(p.num_updates);
6306 }
6307 dev.queue
6308 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6309 }
6310 Step::TopK { params } => {
6311 let mut p = *params;
6312 p.outer = scale(p.outer);
6313 dev.queue
6314 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6315 }
6316 Step::WelchPeaksGpu { params } => {
6317 let mut p = *params;
6318 p.welch_batch = scale(p.welch_batch);
6319 dev.queue
6320 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6321 }
6322 Step::UmapKnn { params } => {
6323 let mut p = *params;
6324 p.n = scale(p.n);
6325 dev.queue
6326 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6327 }
6328 Step::GroupedMatmul { params } => {
6329 let mut p = *params;
6330 p.m = scale(p.m);
6331 dev.queue
6332 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6333 }
6334 Step::Sample { params } => {
6335 let mut p = *params;
6336 p.outer = scale(p.outer);
6337 dev.queue
6338 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6339 }
6340 Step::SelectiveScan { params } => {
6341 let mut p = *params;
6343 p.seq = scale(p.seq);
6344 dev.queue
6345 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6346 }
6347 Step::DequantMatmul { params } => {
6348 let mut p = *params;
6349 p.m = scale(p.m);
6350 dev.queue
6351 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6352 }
6353 Step::DequantMatmulGguf { .. }
6354 | Step::DequantGroupedMatmulGguf { .. }
6355 | Step::GatedDeltaNet { .. }
6356 | Step::Llada2GroupLimitedGate { .. }
6357 | Step::UmapKnnHost { .. }
6358 | Step::FftHost { .. }
6359 | Step::Im2ColHost { .. }
6360 | Step::WelchPeaksHost { .. }
6361 | Step::LogMelHost { .. }
6362 | Step::LogMelBackwardHost { .. } => {}
6363 Step::FusedResidualLn { params } => {
6364 let mut p = *params;
6365 p.outer = scale(p.outer);
6366 dev.queue
6367 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6368 }
6369 Step::FusedResidualLnTee { params } => {
6370 let mut p = *params;
6371 p.outer = scale(p.outer);
6372 dev.queue
6373 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6374 }
6375 Step::FusedResidualRmsNorm { params } => {
6376 let mut p = *params;
6377 p.outer = scale(p.outer);
6378 dev.queue
6379 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6380 }
6381 Step::MatmulQkv { params, kind: _ } => {
6382 let mut p = *params;
6383 p.m = scale(p.m);
6384 dev.queue
6385 .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6386 }
6387 #[cfg(feature = "splat")]
6388 Step::GaussianSplatRender { .. }
6389 | Step::GaussianSplatRenderBackward { .. }
6390 | Step::GaussianSplatPrepare { .. }
6391 | Step::GaussianSplatRasterize { .. } => {}
6392 }
6393 if !matches!(step, Step::FftGpu { .. }) {
6394 gpu_ui += 1;
6395 }
6396 }
6397 self.uniforms_active_extent = Some(active);
6398 }
6399
6400 let mm_k = matmul_kernel(&dev.device);
6402 let mm_w_active = matmul_wide_active_kernel(&dev.device);
6403 let mm_f16w = matmul_f16w_kernel(&dev.device);
6404 let mm_f16c = matmul_f16_compute_kernel(&dev.device);
6405 let mm_coop = matmul_coop16_kernel(&dev.device);
6406 let mm_coop_f16_vk = matmul_coop_f16_vulkan_kernel(&dev.device);
6407 let mm_coop_f32 = matmul_coop_f32_active_kernel(&dev.device);
6408 let mm_cast = cast_f32_to_f16_kernel(&dev.device);
6409 let bk = binary_kernel(&dev.device);
6410 let uk = unary_kernel(&dev.device);
6411 let ck = compare_kernel(&dev.device);
6412 let wk = where_kernel(&dev.device);
6413 let mut step_i = 0;
6414 let mut gpu_bi = 0usize;
6415 let mut fft_i = 0usize;
6416 while step_i < self.schedule.len() {
6417 let mut enc = dev
6418 .device
6419 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
6420 label: Some("rlx-wgpu run"),
6421 });
6422 {
6423 let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
6424 label: Some("rlx-wgpu compute pass"),
6425 timestamp_writes: None,
6426 });
6427 let mut pass_dispatched = false;
6428 while step_i < self.schedule.len() {
6429 if step_is_tail_host(&self.schedule[step_i]) {
6430 step_i += 1;
6431 continue;
6432 }
6433 if step_runs_on_host(&self.schedule[step_i]) {
6434 break;
6435 }
6436 if pass_dispatched
6441 && step_i > 0
6442 && step_needs_pass_flush(&self.schedule[step_i], &self.schedule[step_i - 1])
6443 {
6444 break;
6445 }
6446 let step = &self.schedule[step_i];
6447 let _perf = rlx_ir::perfetto::TraceSpan::new(step_name(step), "wgpu");
6450 match step {
6451 Step::CastF32ToF16 { params } => {
6452 if let Some(cast_k) = mm_cast {
6457 pass.set_pipeline(&cast_k.pipeline);
6458 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6459 let (gx, gy, gz) = dispatch_dims(params.len, 64);
6460 pass.dispatch_workgroups(gx, gy, gz);
6461 }
6462 }
6463 Step::Matmul {
6464 m,
6465 n,
6466 batch,
6467 b_off_f32,
6468 b_is_param,
6469 compute_precision,
6470 ..
6471 } =>
6472 {
6479 #[allow(clippy::unnecessary_unwrap)]
6480 let m_s = scale(*m);
6484 if m_s == 0 {
6485 continue;
6486 }
6487 let coop_f16_wide = mm_coop_f16_vk.is_some()
6488 && *compute_precision == MatmulCompute::CoopF16Vk
6489 && crate::coop_f16_vk::use_wide_matmul(
6490 *b_off_f32,
6491 *n,
6492 &self.coop_f16_b_param,
6493 &self.coop_f16_vk_wide_b,
6494 );
6495 pass.set_bind_group(
6496 0,
6497 coop_f16_vk_bind_group(self, gpu_bi, coop_f16_wide),
6498 &[],
6499 );
6500 let f16w_opt_in = rlx_ir::env::flag("RLX_WGPU_F16_WEIGHTS");
6510 if let Some(coop) = mm_coop.as_ref()
6511 && *b_is_param
6512 && *compute_precision == MatmulCompute::Coop16
6513 {
6514 pass.set_pipeline(&coop.pipeline);
6520 pass.dispatch_workgroups(n.div_ceil(32), m_s.div_ceil(32), *batch);
6521 } else if mm_coop_f16_vk.is_some()
6522 && *compute_precision == MatmulCompute::CoopF16Vk
6523 {
6524 if coop_f16_wide {
6525 dispatch_wide_f32_matmul(
6526 &mut pass,
6527 mm_w_active,
6528 mm_k,
6529 m_s,
6530 *n,
6531 *batch,
6532 );
6533 } else {
6534 let n_eff = scale(*n);
6535 let coop_vk =
6536 matmul_coop_f16_vulkan_active_kernel(&dev.device, n_eff)
6537 .expect("coop f16 vk kernel missing");
6538 pass.set_pipeline(&coop_vk.pipeline);
6539 pass.dispatch_workgroups(
6540 m_s.div_ceil(16),
6541 n.div_ceil(16),
6542 *batch,
6543 );
6544 }
6545 } else if let Some(coop_f32) = mm_coop_f32.as_ref()
6546 && *b_is_param
6547 && *compute_precision == MatmulCompute::CoopF32
6548 {
6549 pass.set_pipeline(&coop_f32.pipeline);
6552 let backend = wgpu_device()
6553 .map(|d| d.backend)
6554 .unwrap_or(wgpu::Backend::Noop);
6555 let (gx, gy) = if backend == wgpu::Backend::Metal {
6556 (n.div_ceil(32), m_s.div_ceil(32))
6557 } else {
6558 (m_s.div_ceil(8), n.div_ceil(8))
6559 };
6560 pass.dispatch_workgroups(gx, gy, *batch);
6561 } else if let Some(f16c) = mm_f16c.as_ref()
6562 && *b_is_param
6563 && *compute_precision == MatmulCompute::F16
6564 {
6565 pass.set_pipeline(&f16c.pipeline);
6566 pass.dispatch_workgroups(n.div_ceil(32), m_s.div_ceil(32), *batch);
6567 } else if let Some(f16w) = mm_f16w.as_ref()
6568 && *b_is_param
6569 && f16w_opt_in
6570 {
6571 pass.set_pipeline(&f16w.pipeline);
6572 pass.dispatch_workgroups(n.div_ceil(32), m_s.div_ceil(32), *batch);
6573 } else if m_s >= 32 && *n >= 64 {
6574 pass.set_pipeline(&mm_w_active.pipeline);
6575 let backend = wgpu_device()
6576 .map(|d| d.backend)
6577 .unwrap_or(wgpu::Backend::Noop);
6578 let (gx, gy) = if matches!(
6579 backend,
6580 wgpu::Backend::Vulkan | wgpu::Backend::Dx12
6581 ) {
6582 (n.div_ceil(64), m_s.div_ceil(64))
6583 } else {
6584 (n.div_ceil(64), m_s.div_ceil(32))
6585 };
6586 pass.dispatch_workgroups(gx, gy, *batch);
6587 } else {
6588 pass.set_pipeline(&mm_k.pipeline);
6589 pass.dispatch_workgroups(n.div_ceil(32), m_s.div_ceil(32), *batch);
6590 }
6591 }
6592 Step::Binary { params } => {
6593 let n_s = scale(params.n);
6594 if n_s == 0 {
6595 continue;
6596 }
6597 pass.set_pipeline(&bk.pipeline);
6598 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6599 let (gx, gy, gz) = dispatch_dims(n_s, 64);
6600 pass.dispatch_workgroups(gx, gy, gz);
6601 }
6602 Step::Compare { params } => {
6603 let n_s = scale(params.n);
6604 if n_s == 0 {
6605 continue;
6606 }
6607 pass.set_pipeline(&ck.pipeline);
6608 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6609 let (gx, gy, gz) = dispatch_dims(n_s, 64);
6610 pass.dispatch_workgroups(gx, gy, gz);
6611 }
6612 Step::Unary { params, f16_mirror } => {
6613 let n_s = scale(params.n);
6614 if n_s == 0 {
6615 continue;
6616 }
6617 if *f16_mirror {
6618 if let Some(uk_f16) = unary_f16_mirror_kernel(&dev.device) {
6619 pass.set_pipeline(&uk_f16.pipeline);
6620 } else {
6621 pass.set_pipeline(&uk.pipeline);
6622 }
6623 } else {
6624 pass.set_pipeline(&uk.pipeline);
6625 }
6626 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6627 let (gx, gy, gz) = dispatch_dims(n_s, 64);
6628 pass.dispatch_workgroups(gx, gy, gz);
6629 }
6630 Step::Where { params } => {
6631 let n_s = scale(params.n);
6632 if n_s == 0 {
6633 continue;
6634 }
6635 pass.set_pipeline(&wk.pipeline);
6636 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6637 let (gx, gy, gz) = dispatch_dims(n_s, 64);
6638 pass.dispatch_workgroups(gx, gy, gz);
6639 }
6640 Step::Reduce { params } => {
6641 let outer_s = scale(params.outer);
6642 if outer_s == 0 {
6643 continue;
6644 }
6645 let rk = reduce_kernel(&dev.device);
6646 pass.set_pipeline(&rk.pipeline);
6647 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6648 let total_out = outer_s.saturating_mul(params.inner);
6649 if params.reduce_dim <= 64 {
6650 let (gx, gy, gz) = dispatch_dims(total_out, 64);
6652 pass.dispatch_workgroups(gx, gy, gz);
6653 } else {
6654 let (gx, gy, gz) = dispatch_dims(total_out, 1);
6658 pass.dispatch_workgroups(gx, gy, gz);
6659 }
6660 }
6661 Step::Softmax { params } => {
6662 let outer_s = scale(params.outer);
6663 if outer_s == 0 {
6664 continue;
6665 }
6666 let sk = softmax_kernel(&dev.device);
6667 pass.set_pipeline(&sk.pipeline);
6668 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6669 let (gx, gy, gz) = dispatch_dims(outer_s, 64);
6670 pass.dispatch_workgroups(gx, gy, gz);
6671 }
6672 Step::LayerNorm { params } => {
6673 let outer_s = scale(params.outer);
6674 if outer_s == 0 {
6675 continue;
6676 }
6677 let lk = layernorm_kernel(&dev.device);
6678 pass.set_pipeline(&lk.pipeline);
6679 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6680 let (gx, gy, gz) = dispatch_dims(outer_s, 64);
6681 pass.dispatch_workgroups(gx, gy, gz);
6682 }
6683 Step::RmsNormBackwardInput { params } => {
6684 let outer_s = scale(params.outer);
6685 if outer_s == 0 {
6686 continue;
6687 }
6688 let rk = rms_norm_backward_kernel(&dev.device);
6689 pass.set_pipeline(&rk.pipeline);
6690 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6691 pass.dispatch_workgroups(outer_s, 1, 1);
6692 }
6693 Step::RmsNormBackwardGamma { params }
6694 | Step::RmsNormBackwardBeta { params } => {
6695 if params.inner == 0 {
6696 continue;
6697 }
6698 let rk = rms_norm_backward_param_kernel(&dev.device);
6699 pass.set_pipeline(&rk.pipeline);
6700 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6701 pass.dispatch_workgroups(1, 1, 1);
6702 }
6703 Step::LayerNormBackwardInput { params } => {
6704 let outer_s = scale(params.outer);
6705 if outer_s == 0 {
6706 continue;
6707 }
6708 let lk = layer_norm_backward_input_kernel(&dev.device);
6709 pass.set_pipeline(&lk.pipeline);
6710 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6711 pass.dispatch_workgroups(outer_s, 1, 1);
6712 }
6713 Step::LayerNormBackwardGammaPartial {
6714 params,
6715 num_workgroups,
6716 } => {
6717 if params.inner == 0 || *num_workgroups == 0 {
6718 continue;
6719 }
6720 let lk = layer_norm_backward_gamma_partial_kernel(&dev.device);
6721 pass.set_pipeline(&lk.pipeline);
6722 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6723 pass.dispatch_workgroups(*num_workgroups, 1, 1);
6724 }
6725 Step::LayerNormBackwardGammaReduce { params } => {
6726 if params.inner == 0 {
6727 continue;
6728 }
6729 let lk = layer_norm_backward_gamma_reduce_kernel(&dev.device);
6730 pass.set_pipeline(&lk.pipeline);
6731 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6732 pass.dispatch_workgroups(1, 1, 1);
6733 }
6734 Step::CumsumBackward { params } => {
6735 let outer_s = scale(params.outer);
6736 if outer_s == 0 {
6737 continue;
6738 }
6739 let ck = cumsum_backward_kernel(&dev.device);
6740 pass.set_pipeline(&ck.pipeline);
6741 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6742 let (gx, gy, gz) = dispatch_dims(outer_s, 64);
6743 pass.dispatch_workgroups(gx, gy, gz);
6744 }
6745 Step::RopeBackward { params } => {
6746 let seq_s = scale(params.seq);
6747 if seq_s == 0 {
6748 continue;
6749 }
6750 let rk = rope_backward_kernel(&dev.device);
6751 pass.set_pipeline(&rk.pipeline);
6752 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6753 let total = params.batch * seq_s * params.hidden;
6754 let (gx, gy, gz) = dispatch_dims(total, 64);
6755 pass.dispatch_workgroups(gx, gy, gz);
6756 }
6757 Step::GatherBackward { params } => {
6758 let outer_s = scale(params.outer);
6759 if outer_s == 0 {
6760 continue;
6761 }
6762 let total = outer_s * params.axis_dim * params.trailing;
6763 if total > 0 {
6764 let zk = gather_backward_zero_kernel(&dev.device);
6765 pass.set_pipeline(&zk.pipeline);
6766 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6767 let (gx, _, _) = dispatch_dims(total, 256);
6768 pass.dispatch_workgroups(gx, 1, 1);
6769 }
6770 let ak = gather_backward_acc_kernel(&dev.device);
6771 pass.set_pipeline(&ak.pipeline);
6772 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6773 pass.dispatch_workgroups(outer_s, 1, 1);
6774 }
6775 Step::Cumsum { params } => {
6776 let outer_s = scale(params.outer);
6777 if outer_s == 0 {
6778 continue;
6779 }
6780 let ck2 = cumsum_kernel(&dev.device);
6781 pass.set_pipeline(&ck2.pipeline);
6782 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6783 let (gx, gy, gz) = dispatch_dims(outer_s, 64);
6784 pass.dispatch_workgroups(gx, gy, gz);
6785 }
6786 Step::FftGpu {
6787 src_off,
6788 dst_off,
6789 outer,
6790 n,
6791 inverse,
6792 norm_scale,
6793 } => {
6794 let res = &self.fft_gpu_steps[fft_i];
6795 fft_i += 1;
6796 crate::fft_dispatch::dispatch_fft_gpu_in_pass(
6797 &dev.device,
6798 &dev.queue,
6799 &mut pass,
6800 res,
6801 *src_off,
6802 *dst_off,
6803 *outer,
6804 *n,
6805 *inverse != 0,
6806 *norm_scale,
6807 );
6808 }
6809 Step::Copy { params } => {
6810 let n_s = scale(params.n);
6811 if n_s == 0 {
6812 continue;
6813 }
6814 let ck2 = copy_kernel(&dev.device);
6815 pass.set_pipeline(&ck2.pipeline);
6816 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6817 let (gx, gy, gz) = dispatch_dims(n_s, 64);
6818 pass.dispatch_workgroups(gx, gy, gz);
6819 }
6820 Step::BufferCopy { .. } => {
6821 }
6823 Step::ElementwiseRegion { params } => {
6824 let len_s = scale(params.len);
6825 if len_s == 0 {
6826 continue;
6827 }
6828 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6829 if params.prologue == rlx_ir::REGION_PROLOGUE_RESIZE_NEAREST_2X_NCHW {
6830 let ek = elementwise_region_spatial_kernel(&dev.device);
6831 pass.set_pipeline(&ek.pipeline);
6832 let (gx, gy, gz) = dispatch_prologue_nchw(
6833 params.out_w,
6834 params.out_h,
6835 params.out_n * params.out_c,
6836 );
6837 pass.dispatch_workgroups(gx, gy, gz);
6838 } else {
6839 let ek = elementwise_region_kernel(&dev.device);
6840 pass.set_pipeline(&ek.pipeline);
6841 let (gx, gy, gz) = dispatch_dims(len_s, 64);
6842 pass.dispatch_workgroups(gx, gy, gz);
6843 }
6844 }
6845 Step::BatchElementwiseRegion { params } => {
6846 let slice_len_s = scale(params.slice_len);
6847 let num_batch_s = scale(params.num_batch);
6848 if slice_len_s == 0 || num_batch_s == 0 {
6849 continue;
6850 }
6851 let ek = batch_elementwise_region_kernel(&dev.device);
6852 pass.set_pipeline(&ek.pipeline);
6853 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6854 let (gx, gy, _) = dispatch_dims(slice_len_s, 64);
6855 pass.dispatch_workgroups(gx, gy, num_batch_s);
6856 }
6857 Step::Transpose { params, .. } => {
6858 let total_s = if params.bucket_outermost == 1 && params.out_dim_0 > 0 {
6862 let scaled_d0 = scale(params.out_dim_0);
6863 let inner = params.out_total / params.out_dim_0;
6864 scaled_d0 * inner
6865 } else {
6866 params.out_total
6867 };
6868 if total_s == 0 {
6869 continue;
6870 }
6871 let tk = transpose_kernel(&dev.device);
6872 pass.set_pipeline(&tk.pipeline);
6873 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6874 let (gx, gy, gz) = dispatch_dims(total_s, 64);
6875 pass.dispatch_workgroups(gx, gy, gz);
6876 }
6877 Step::Narrow { params } => {
6878 let total_s = scale(params.total);
6879 if total_s == 0 {
6880 continue;
6881 }
6882 let nk = narrow_kernel(&dev.device);
6883 pass.set_pipeline(&nk.pipeline);
6884 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6885 let (gx, gy, gz) = dispatch_dims(total_s, 64);
6886 pass.dispatch_workgroups(gx, gy, gz);
6887 }
6888 Step::Concat { params } => {
6889 let total_s = scale(params.total);
6890 if total_s == 0 {
6891 continue;
6892 }
6893 let cck = concat_kernel(&dev.device);
6894 pass.set_pipeline(&cck.pipeline);
6895 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6896 let (gx, gy, gz) = dispatch_dims(total_s, 64);
6897 pass.dispatch_workgroups(gx, gy, gz);
6898 }
6899 Step::Gather { params } => {
6900 let n_out_s = scale(params.n_out);
6901 if n_out_s == 0 {
6902 continue;
6903 }
6904 let gk = gather_kernel(&dev.device);
6905 pass.set_pipeline(&gk.pipeline);
6906 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6907 let (gx, gy, gz) = dispatch_dims(n_out_s, 64);
6908 pass.dispatch_workgroups(gx, gy, gz);
6909 }
6910 Step::GatherAxis { params } => {
6911 let total_s = scale(params.total);
6912 if total_s == 0 {
6913 continue;
6914 }
6915 let gk = gather_axis_kernel(&dev.device);
6916 pass.set_pipeline(&gk.pipeline);
6917 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6918 let (gx, gy, gz) = dispatch_dims(total_s, 64);
6919 pass.dispatch_workgroups(gx, gy, gz);
6920 }
6921 Step::Attention { params, .. } => {
6922 let seq_q_s = scale(params.seq_q);
6926 if seq_q_s == 0 {
6927 continue;
6928 }
6929 let ak = attention_kernel(&dev.device);
6930 pass.set_pipeline(&ak.pipeline);
6931 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6932 let total = params.batch * params.heads * seq_q_s;
6933 let (gx, gy, gz) = dispatch_dims(total, 64);
6934 pass.dispatch_workgroups(gx, gy, gz);
6935 }
6936 Step::AttentionBackward { params, .. } => {
6937 let axis = if params.wrt == 0 {
6938 params.seq_q
6939 } else {
6940 params.seq_k
6941 };
6942 let axis_s = scale(axis);
6943 if axis_s == 0 {
6944 continue;
6945 }
6946 let ak = attention_bwd_kernel(&dev.device);
6947 pass.set_pipeline(&ak.pipeline);
6948 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6949 let total = params.batch * params.heads * axis_s;
6950 let (gx, gy, gz) = dispatch_dims(total, 64);
6951 pass.dispatch_workgroups(gx, gy, gz);
6952 }
6953 Step::Rope { params } => {
6954 let s_active = scale(params.seq);
6957 let total_s = params.batch * s_active * params.last_dim;
6958 if total_s == 0 {
6959 continue;
6960 }
6961 let rk = rope_kernel(&dev.device);
6962 pass.set_pipeline(&rk.pipeline);
6963 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6964 let (gx, gy, gz) = dispatch_dims(total_s, 64);
6965 pass.dispatch_workgroups(gx, gy, gz);
6966 }
6967 Step::Expand { params, .. } => {
6968 let total_s = if params.bucket_outermost == 1 && params.out_dim_0 > 0 {
6969 let scaled_d0 = scale(params.out_dim_0);
6970 let inner = params.out_total / params.out_dim_0;
6971 scaled_d0 * inner
6972 } else {
6973 params.out_total
6974 };
6975 if total_s == 0 {
6976 continue;
6977 }
6978 let ek = expand_kernel(&dev.device);
6979 pass.set_pipeline(&ek.pipeline);
6980 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6981 let (gx, gy, gz) = dispatch_dims(total_s, 64);
6982 pass.dispatch_workgroups(gx, gy, gz);
6983 }
6984 Step::Argmax { params } => {
6985 let outer_s = scale(params.outer);
6986 if outer_s == 0 {
6987 continue;
6988 }
6989 let amk = argmax_kernel(&dev.device);
6990 pass.set_pipeline(&amk.pipeline);
6991 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6992 let (gx, gy, gz) = dispatch_dims(outer_s, 64);
6993 pass.dispatch_workgroups(gx, gy, gz);
6994 }
6995 Step::Pool2d { params } => {
6996 let n_s = scale(params.n);
6997 if n_s == 0 {
6998 continue;
6999 }
7000 let pk = pool2d_kernel(&dev.device);
7001 pass.set_pipeline(&pk.pipeline);
7002 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
7003 let total = n_s * params.c * params.h_out * params.w_out;
7004 let (gx, gy, gz) = dispatch_dims(total, 64);
7005 pass.dispatch_workgroups(gx, gy, gz);
7006 }
7007 Step::Conv2d { params } => {
7008 let n_s = scale(params.n);
7009 if n_s == 0 {
7010 continue;
7011 }
7012 let ck2 = conv2d_kernel(&dev.device);
7013 pass.set_pipeline(&ck2.pipeline);
7014 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
7015 let total = n_s * params.c_out * params.h_out * params.w_out;
7016 let (gx, gy, gz) = dispatch_dims(total, 64);
7017 pass.dispatch_workgroups(gx, gy, gz);
7018 }
7019 Step::Pool1d { params } => {
7020 let n_s = scale(params.n);
7021 if n_s == 0 {
7022 continue;
7023 }
7024 let pk = pool1d_kernel(&dev.device);
7025 pass.set_pipeline(&pk.pipeline);
7026 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
7027 let total = n_s * params.c * params.l_out;
7028 let (gx, gy, gz) = dispatch_dims(total, 64);
7029 pass.dispatch_workgroups(gx, gy, gz);
7030 }
7031 Step::Pool3d { params } => {
7032 let n_s = scale(params.n);
7033 if n_s == 0 {
7034 continue;
7035 }
7036 let pk = pool3d_kernel(&dev.device);
7037 pass.set_pipeline(&pk.pipeline);
7038 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
7039 let total = n_s * params.c * params.d_out * params.h_out * params.w_out;
7040 let (gx, gy, gz) = dispatch_dims(total, 64);
7041 pass.dispatch_workgroups(gx, gy, gz);
7042 }
7043 Step::Conv1d { params } => {
7044 let n_s = scale(params.n);
7045 if n_s == 0 {
7046 continue;
7047 }
7048 let ck = conv1d_kernel(&dev.device);
7049 pass.set_pipeline(&ck.pipeline);
7050 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
7051 let total = n_s * params.c_out * params.l_out;
7052 let (gx, gy, gz) = dispatch_dims(total, 64);
7053 pass.dispatch_workgroups(gx, gy, gz);
7054 }
7055 Step::Conv3d { params } => {
7056 let n_s = scale(params.n);
7057 if n_s == 0 {
7058 continue;
7059 }
7060 let ck = conv3d_kernel(&dev.device);
7061 pass.set_pipeline(&ck.pipeline);
7062 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
7063 let total =
7064 n_s * params.c_out * params.d_out * params.h_out * params.w_out;
7065 let (gx, gy, gz) = dispatch_dims(total, 64);
7066 pass.dispatch_workgroups(gx, gy, gz);
7067 }
7068 Step::ScatterAdd { params } => {
7069 let sk = scatter_add_kernel(&dev.device);
7070 pass.set_pipeline(&sk.pipeline);
7071 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
7072 if params.op == 0 {
7078 let (gx, gy, gz) = dispatch_dims(params.out_total, 64);
7079 pass.dispatch_workgroups(gx, gy, gz);
7080 } else {
7081 pass.dispatch_workgroups(1, 1, 1);
7082 }
7083 }
7084 Step::TopK { params } => {
7085 let outer_s = scale(params.outer);
7086 if outer_s == 0 {
7087 continue;
7088 }
7089 let tk = topk_kernel(&dev.device);
7090 pass.set_pipeline(&tk.pipeline);
7091 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
7092 let (gx, gy, gz) = dispatch_dims(outer_s, 64);
7093 pass.dispatch_workgroups(gx, gy, gz);
7094 }
7095 Step::WelchPeaksGpu { params } => {
7096 let batch_s = scale(params.welch_batch);
7097 if batch_s == 0 {
7098 continue;
7099 }
7100 let wk = welch_peaks_gpu_kernel(&dev.device);
7101 pass.set_pipeline(&wk.pipeline);
7102 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
7103 let (gx, gy, gz) = dispatch_dims(batch_s, 64);
7104 pass.dispatch_workgroups(gx, gy, gz);
7105 }
7106 Step::UmapKnn { params } => {
7107 let n_s = scale(params.n);
7108 if n_s == 0 {
7109 continue;
7110 }
7111 let uk = umap_knn_kernel(&dev.device);
7112 pass.set_pipeline(&uk.pipeline);
7113 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
7114 let (gx, gy, gz) = dispatch_dims(n_s, 64);
7115 pass.dispatch_workgroups(gx, gy, gz);
7116 }
7117 Step::GroupedMatmul { params } => {
7118 let m_s = scale(params.m);
7119 if m_s == 0 {
7120 continue;
7121 }
7122 let gk = grouped_matmul_kernel(&dev.device);
7123 pass.set_pipeline(&gk.pipeline);
7124 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
7125 pass.dispatch_workgroups(params.n.div_ceil(8), m_s.div_ceil(8), 1);
7126 }
7127 Step::Sample { params } => {
7128 let outer_s = scale(params.outer);
7129 if outer_s == 0 {
7130 continue;
7131 }
7132 let sk = sample_kernel(&dev.device);
7133 pass.set_pipeline(&sk.pipeline);
7134 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
7135 let (gx, gy, gz) = dispatch_dims(outer_s, 64);
7136 pass.dispatch_workgroups(gx, gy, gz);
7137 }
7138 Step::SelectiveScan { params } => {
7139 let ssk = selective_scan_kernel(&dev.device);
7144 pass.set_pipeline(&ssk.pipeline);
7145 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
7146 let total = params.batch * params.hidden;
7147 let (gx, gy, gz) = dispatch_dims(total, 64);
7148 pass.dispatch_workgroups(gx, gy, gz);
7149 }
7150 Step::DequantMatmul { params } => {
7151 let m_s = scale(params.m);
7152 if m_s == 0 {
7153 continue;
7154 }
7155 let dk = dequant_matmul_kernel(&dev.device);
7156 pass.set_pipeline(&dk.pipeline);
7157 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
7158 pass.dispatch_workgroups(params.n.div_ceil(8), m_s.div_ceil(8), 1);
7159 }
7160 Step::FusedResidualLn { params } => {
7161 let outer_s = scale(params.outer);
7162 if outer_s == 0 {
7163 continue;
7164 }
7165 let frk = fused_residual_ln_kernel(&dev.device);
7166 pass.set_pipeline(&frk.pipeline);
7167 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
7168 let (gx, gy, gz) = dispatch_dims(outer_s, 64);
7169 pass.dispatch_workgroups(gx, gy, gz);
7170 }
7171 Step::FusedResidualLnTee { params } => {
7172 let outer_s = scale(params.outer);
7173 if outer_s == 0 {
7174 continue;
7175 }
7176 let frtk = fused_residual_ln_tee_kernel(&dev.device);
7177 pass.set_pipeline(&frtk.pipeline);
7178 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
7179 let (gx, gy, gz) = dispatch_dims(outer_s, 64);
7180 pass.dispatch_workgroups(gx, gy, gz);
7181 }
7182 Step::FusedResidualRmsNorm { params } => {
7183 let outer_s = scale(params.outer);
7184 if outer_s == 0 {
7185 continue;
7186 }
7187 let frk = fused_residual_rms_norm_kernel(&dev.device);
7188 pass.set_pipeline(&frk.pipeline);
7189 pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
7190 let (gx, gy, gz) = dispatch_dims(outer_s, 64);
7191 pass.dispatch_workgroups(gx, gy, gz);
7192 }
7193 Step::MatmulQkv { params, kind } => {
7194 let m_s = scale(params.m);
7195 if m_s == 0 {
7196 continue;
7197 }
7198 let qkv_coop_wide = matches!(kind, MatmulQkvKind::CoopF16Vk)
7199 && crate::coop_f16_vk::use_wide_matmul(
7200 params.b_off,
7201 params.n,
7202 &self.coop_f16_b_param,
7203 &self.coop_f16_vk_wide_b,
7204 );
7205 pass.set_bind_group(
7206 0,
7207 coop_f16_vk_bind_group(self, gpu_bi, qkv_coop_wide),
7208 &[],
7209 );
7210 match kind {
7211 MatmulQkvKind::CoopF16Vk => {
7212 if qkv_coop_wide {
7213 pass.set_pipeline(&matmul_qkv_kernel(&dev.device).pipeline);
7214 pass.dispatch_workgroups(
7215 params.n.div_ceil(32),
7216 m_s.div_ceil(32),
7217 1,
7218 );
7219 } else {
7220 let n_eff = scale(params.n);
7221 let mqk = matmul_qkv_coop_f16_vk_active_kernel(
7222 &dev.device,
7223 n_eff,
7224 )
7225 .expect("coop f16 matmul_qkv kernel missing");
7226 pass.set_pipeline(&mqk.pipeline);
7227 pass.dispatch_workgroups(
7228 m_s.div_ceil(16),
7229 params.n.div_ceil(16),
7230 1,
7231 );
7232 }
7233 }
7234 MatmulQkvKind::CoopF32 => {
7235 pass.set_pipeline(
7236 &matmul_qkv_coop_f32_kernel(&dev.device)
7237 .expect("coop matmul_qkv kernel missing")
7238 .pipeline,
7239 );
7240 pass.dispatch_workgroups(
7241 params.n.div_ceil(32),
7242 m_s.div_ceil(32),
7243 1,
7244 );
7245 }
7246 MatmulQkvKind::F32 => {
7247 pass.set_pipeline(&matmul_qkv_kernel(&dev.device).pipeline);
7248 pass.dispatch_workgroups(
7249 params.n.div_ceil(32),
7250 m_s.div_ceil(32),
7251 1,
7252 );
7253 }
7254 }
7255 }
7256 Step::DequantMatmulGguf { .. }
7257 | Step::DequantGroupedMatmulGguf { .. }
7258 | Step::GatedDeltaNet { .. }
7259 | Step::Llada2GroupLimitedGate { .. }
7260 | Step::UmapKnnHost { .. }
7261 | Step::FftHost { .. }
7262 | Step::Im2ColHost { .. }
7263 | Step::WelchPeaksHost { .. }
7264 | Step::LogMelHost { .. }
7265 | Step::LogMelBackwardHost { .. } => {}
7266 #[cfg(feature = "splat")]
7267 Step::GaussianSplatRender { .. }
7268 | Step::GaussianSplatRenderBackward { .. }
7269 | Step::GaussianSplatPrepare { .. }
7270 | Step::GaussianSplatRasterize { .. } => {}
7271 }
7272 if !matches!(step, Step::FftGpu { .. }) {
7273 gpu_bi += 1;
7274 }
7275 step_i += 1;
7276 pass_dispatched = true;
7277 }
7278 }
7279 let needs_f16_drain = step_i < self.schedule.len()
7280 && !step_runs_on_host(&self.schedule[step_i])
7281 && step_i > 0
7282 && step_needs_pass_flush(&self.schedule[step_i], &self.schedule[step_i - 1]);
7283 let gpu_schedule_done = step_i >= self.schedule.len();
7284 let skip_readback = rlx_ir::env::flag("RLX_BENCH_DISPATCH_ONLY");
7285 let defer_tail = gpu_schedule_done && self.schedule.iter().any(step_is_tail_host);
7286 let mut fused_readback: Option<(
7287 ReadbackLayout,
7288 std::sync::mpsc::Receiver<Result<(), wgpu::BufferAsyncError>>,
7289 Vec<usize>,
7290 )> = None;
7291 if gpu_schedule_done && !skip_readback && !defer_tail {
7292 if !self.gpu_handle_feeds.is_empty() {
7293 self.propagate_gpu_handle_feeds_on_gpu(dev, &mut enc);
7294 }
7295 let plan = self.readback_plan();
7296 let out_ids_all: Vec<_> = self.graph.outputs.clone();
7297 let out_ids: Vec<_> = plan.iter().map(|&i| out_ids_all[i]).collect();
7298 let layout = ReadbackLayout::for_nodes(&self.arena, &out_ids);
7299 if use_tiny_readback(&layout, out_ids.len()) && plan == vec![0] {
7300 if self.tiny_readback.is_none() {
7301 self.tiny_readback = Some(TinyReadbackStaging::new(&dev.device));
7302 }
7303 let tiny = self.tiny_readback.as_ref().expect("tiny readback");
7304 encode_readback_copies(&mut enc, &self.arena, tiny.buffer(), &out_ids, &layout);
7305 let map_rx = schedule_readback_map(&mut enc, tiny.buffer(), &layout);
7306 dev.queue.submit(std::iter::once(enc.finish()));
7307 let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());
7308 wait_readback_map(&dev.device, &map_rx, layout.total_bytes);
7309 map_rx.recv().unwrap().unwrap();
7310 return self.pack_readback_outputs(
7311 &plan,
7312 vec![decode_tiny_mapped_f32(tiny.buffer(), layout.total_bytes)],
7313 );
7314 }
7315 ReadbackStaging::prepare(
7316 &dev.device,
7317 &mut self.readback_staging,
7318 layout.total_bytes,
7319 );
7320 if let Some(staging) = self.readback_staging.as_ref() {
7321 encode_readback_copies(
7322 &mut enc,
7323 &self.arena,
7324 staging.buffer(),
7325 &out_ids,
7326 &layout,
7327 );
7328 let map_rx = schedule_readback_map(&mut enc, staging.buffer(), &layout);
7329 fused_readback = Some((layout, map_rx, plan));
7330 }
7331 }
7332 dev.queue.submit(std::iter::once(enc.finish()));
7333 if defer_tail {
7334 let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());
7335 self.run_tail_host_audio_ops(dev);
7336 if !skip_readback {
7337 let mut rb_enc =
7338 dev.device
7339 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
7340 label: Some("rlx-wgpu readback after tail-host"),
7341 });
7342 if !self.gpu_handle_feeds.is_empty() {
7343 self.propagate_gpu_handle_feeds_on_gpu(dev, &mut rb_enc);
7344 }
7345 let plan = self.readback_plan();
7346 let out_ids_all: Vec<_> = self.graph.outputs.clone();
7347 let out_ids: Vec<_> = plan.iter().map(|&i| out_ids_all[i]).collect();
7348 let layout = ReadbackLayout::for_nodes(&self.arena, &out_ids);
7349 if use_tiny_readback(&layout, out_ids.len()) && plan == vec![0] {
7350 if self.tiny_readback.is_none() {
7351 self.tiny_readback = Some(TinyReadbackStaging::new(&dev.device));
7352 }
7353 let tiny = self.tiny_readback.as_ref().expect("tiny readback");
7354 encode_readback_copies(
7355 &mut rb_enc,
7356 &self.arena,
7357 tiny.buffer(),
7358 &out_ids,
7359 &layout,
7360 );
7361 let map_rx = schedule_readback_map(&mut rb_enc, tiny.buffer(), &layout);
7362 dev.queue.submit(std::iter::once(rb_enc.finish()));
7363 wait_readback_map(&dev.device, &map_rx, layout.total_bytes);
7364 map_rx.recv().unwrap().unwrap();
7365 return self.pack_readback_outputs(
7366 &plan,
7367 vec![decode_tiny_mapped_f32(tiny.buffer(), layout.total_bytes)],
7368 );
7369 }
7370 ReadbackStaging::prepare(
7371 &dev.device,
7372 &mut self.readback_staging,
7373 layout.total_bytes,
7374 );
7375 if let Some(staging) = self.readback_staging.as_ref() {
7376 encode_readback_copies(
7377 &mut rb_enc,
7378 &self.arena,
7379 staging.buffer(),
7380 &out_ids,
7381 &layout,
7382 );
7383 let map_rx = schedule_readback_map(&mut rb_enc, staging.buffer(), &layout);
7384 dev.queue.submit(std::iter::once(rb_enc.finish()));
7385 wait_readback_map(&dev.device, &map_rx, layout.total_bytes);
7386 map_rx.recv().unwrap().unwrap();
7387 self.dump_node_stats_if_requested(dev);
7388 let partial = decode_mapped_readback_f32(staging.buffer(), &layout);
7389 return self.pack_readback_outputs(&plan, partial);
7390 }
7391 }
7392 }
7393 if needs_f16_drain {
7394 let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());
7395 }
7396 let need_host_sync =
7397 step_i < self.schedule.len() && step_runs_on_host(&self.schedule[step_i]);
7398 if need_host_sync {
7399 let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());
7400 }
7401 if gpu_schedule_done {
7402 if skip_readback || defer_tail {
7403 return self
7404 .graph
7405 .outputs
7406 .iter()
7407 .map(|&id| {
7408 let n = self.graph.node(id).shape.num_elements().unwrap_or(0);
7409 vec![0.0; n]
7410 })
7411 .collect();
7412 }
7413 if let (Some((layout, map_rx, plan)), Some(staging)) =
7414 (fused_readback, self.readback_staging.as_ref())
7415 {
7416 wait_readback_map(&dev.device, &map_rx, layout.total_bytes);
7417 map_rx.recv().unwrap().unwrap();
7418 self.dump_node_stats_if_requested(dev);
7419 let partial = decode_mapped_readback_f32(staging.buffer(), &layout);
7420 return self.pack_readback_outputs(&plan, partial);
7421 }
7422 break;
7423 }
7424 match &self.schedule[step_i] {
7425 Step::BufferCopy {
7426 src_byte_off,
7427 dst_byte_off,
7428 bytes,
7429 } => {
7430 let src = *src_byte_off as u64;
7433 let dst = *dst_byte_off as u64;
7434 let nbytes = *bytes as u64;
7435 let elems = (nbytes / 4).max(1) as u32;
7436 let lo = src.min(dst);
7437 let hi = src.saturating_add(nbytes).max(dst.saturating_add(nbytes));
7438 let max_binding = dev.device.limits().max_storage_buffer_binding_size;
7439 let span = hi.saturating_sub(lo).max(1);
7440 let mut size = span.div_ceil(256) * 256;
7441 size = size.max(256).min(max_binding);
7442 let mut base = (lo / 256) * 256;
7443 if base.saturating_add(size) > self.arena.size as u64 {
7444 base = (self.arena.size as u64).saturating_sub(size);
7445 base = (base / 256) * 256;
7446 }
7447 let p = CopyParams {
7448 n: elems,
7449 in_off: (src.saturating_sub(base) / 4) as u32,
7450 out_off: (dst.saturating_sub(base) / 4) as u32,
7451 _p0: 0,
7452 _p1: 0,
7453 _p2: 0,
7454 _p3: 0,
7455 _p4: 0,
7456 };
7457 let ck = copy_kernel(&dev.device);
7458 let u = dev.device.create_buffer(&wgpu::BufferDescriptor {
7459 label: Some("rlx-wgpu arena_copy uniform"),
7460 size: std::mem::size_of::<CopyParams>() as u64,
7461 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
7462 mapped_at_creation: false,
7463 });
7464 dev.queue.write_buffer(&u, 0, bytemuck::bytes_of(&p));
7465 let bg =
7466 bind_two_buf0_window(&dev.device, ck, &self.arena.buffer, base, size, &u);
7467 let mut enc =
7468 dev.device
7469 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
7470 label: Some("rlx-wgpu arena_copy"),
7471 });
7472 {
7473 let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
7474 label: Some("rlx-wgpu arena_copy pass"),
7475 ..Default::default()
7476 });
7477 pass.set_pipeline(&ck.pipeline);
7478 pass.set_bind_group(0, &bg, &[]);
7479 let (gx, gy, gz) = dispatch_dims(elems, 64);
7480 pass.dispatch_workgroups(gx, gy, gz);
7481 }
7482 dev.queue.submit(std::iter::once(enc.finish()));
7483 let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());
7484 }
7485 Step::DequantMatmulGguf {
7486 m,
7487 k,
7488 n,
7489 scheme_id,
7490 x_byte_off,
7491 w_byte_off,
7492 out_byte_off,
7493 } => {
7494 crate::gguf_host::run_dequant_matmul_gguf(
7495 &self.arena,
7496 &dev.device,
7497 &dev.queue,
7498 *m as usize,
7499 *k as usize,
7500 *n as usize,
7501 *scheme_id,
7502 *x_byte_off as usize,
7503 *w_byte_off as usize,
7504 *out_byte_off as usize,
7505 );
7506 }
7507 Step::DequantGroupedMatmulGguf {
7508 m,
7509 k,
7510 n,
7511 num_experts,
7512 scheme_id,
7513 x_byte_off,
7514 w_byte_off,
7515 idx_byte_off,
7516 out_byte_off,
7517 } => {
7518 crate::gguf_host::run_dequant_grouped_matmul_gguf(
7519 &self.arena,
7520 &dev.device,
7521 &dev.queue,
7522 *m as usize,
7523 *k as usize,
7524 *n as usize,
7525 *num_experts as usize,
7526 *scheme_id,
7527 *x_byte_off as usize,
7528 *w_byte_off as usize,
7529 *idx_byte_off as usize,
7530 *out_byte_off as usize,
7531 );
7532 }
7533 Step::GatedDeltaNet {
7534 q_byte_off,
7535 k_byte_off,
7536 v_byte_off,
7537 g_byte_off,
7538 beta_byte_off,
7539 state_byte_off,
7540 dst_byte_off,
7541 batch,
7542 seq,
7543 heads,
7544 state_size,
7545 use_carry,
7546 } => {
7547 crate::gdn_host::run_gated_delta_net(
7548 &self.arena,
7549 &dev.device,
7550 &dev.queue,
7551 *q_byte_off as usize,
7552 *k_byte_off as usize,
7553 *v_byte_off as usize,
7554 *g_byte_off as usize,
7555 *beta_byte_off as usize,
7556 *state_byte_off as usize,
7557 *dst_byte_off as usize,
7558 *batch as usize,
7559 *seq as usize,
7560 *heads as usize,
7561 *state_size as usize,
7562 *use_carry,
7563 );
7564 }
7565 Step::Llada2GroupLimitedGate {
7566 sig_byte_off,
7567 route_byte_off,
7568 out_byte_off,
7569 n_elems,
7570 attrs,
7571 } => {
7572 crate::llada2_gate_host::run_llada2_group_limited_gate(
7573 &self.arena,
7574 &dev.device,
7575 &dev.queue,
7576 *sig_byte_off as usize,
7577 *route_byte_off as usize,
7578 *out_byte_off as usize,
7579 *n_elems as usize,
7580 attrs,
7581 );
7582 }
7583 Step::UmapKnnHost {
7584 pairwise_byte_off,
7585 out_byte_off,
7586 n,
7587 k,
7588 } => {
7589 crate::umap_knn_host::run_umap_knn(
7590 &self.arena,
7591 &dev.device,
7592 &dev.queue,
7593 *pairwise_byte_off as usize,
7594 *out_byte_off as usize,
7595 *n as usize,
7596 *k as usize,
7597 );
7598 }
7599 Step::FftHost {
7600 src_byte_off,
7601 dst_byte_off,
7602 outer,
7603 n_complex,
7604 inverse,
7605 norm_tag,
7606 dtype_tag,
7607 } => {
7608 crate::fft_host::run_fft1d(
7609 &self.arena,
7610 &dev.device,
7611 &dev.queue,
7612 *src_byte_off as usize,
7613 *dst_byte_off as usize,
7614 *outer as usize,
7615 *n_complex as usize,
7616 *inverse,
7617 *norm_tag,
7618 fft_dtype_from_tag(*dtype_tag),
7619 );
7620 }
7621 Step::WelchPeaksHost { .. }
7622 | Step::LogMelHost { .. }
7623 | Step::LogMelBackwardHost { .. } => {
7624 unreachable!("tail-host audio ops run after GPU wait")
7625 }
7626 Step::Im2ColHost {
7627 x_byte_off,
7628 col_byte_off,
7629 n,
7630 c_in,
7631 h,
7632 w,
7633 h_out,
7634 w_out,
7635 kh,
7636 kw,
7637 sh,
7638 sw,
7639 ph,
7640 pw,
7641 dh,
7642 dw_dil,
7643 } => {
7644 crate::im2col_host::run_im2col(
7645 &self.arena,
7646 &dev.device,
7647 &dev.queue,
7648 *x_byte_off as usize,
7649 *col_byte_off as usize,
7650 *n,
7651 *c_in,
7652 *h,
7653 *w,
7654 *h_out,
7655 *w_out,
7656 *kh,
7657 *kw,
7658 *sh,
7659 *sw,
7660 *ph,
7661 *pw,
7662 *dh,
7663 *dw_dil,
7664 );
7665 }
7666 #[cfg(feature = "splat")]
7667 Step::GaussianSplatRender {
7668 positions_byte_off,
7669 positions_len,
7670 scales_byte_off,
7671 scales_len,
7672 rotations_byte_off,
7673 rotations_len,
7674 opacities_byte_off,
7675 opacities_len,
7676 colors_byte_off,
7677 colors_len,
7678 sh_coeffs_byte_off,
7679 sh_coeffs_len,
7680 meta_byte_off,
7681 dst_byte_off,
7682 dst_len,
7683 width,
7684 height,
7685 tile_size,
7686 radius_scale,
7687 alpha_cutoff,
7688 max_splat_steps,
7689 transmittance_threshold,
7690 max_list_entries,
7691 } => {
7692 crate::splat::run_gaussian_splat_render(
7693 &self.arena,
7694 &dev.device,
7695 &dev.queue,
7696 *positions_byte_off as usize,
7697 *positions_len as usize,
7698 *scales_byte_off as usize,
7699 *scales_len as usize,
7700 *rotations_byte_off as usize,
7701 *rotations_len as usize,
7702 *opacities_byte_off as usize,
7703 *opacities_len as usize,
7704 *colors_byte_off as usize,
7705 *colors_len as usize,
7706 *sh_coeffs_byte_off as usize,
7707 *sh_coeffs_len as usize,
7708 *meta_byte_off as usize,
7709 *dst_byte_off as usize,
7710 *dst_len as usize,
7711 *width,
7712 *height,
7713 *tile_size,
7714 *radius_scale,
7715 *alpha_cutoff,
7716 *max_splat_steps,
7717 *transmittance_threshold,
7718 *max_list_entries,
7719 );
7720 }
7721 #[cfg(feature = "splat")]
7722 Step::GaussianSplatPrepare {
7723 positions_byte_off,
7724 positions_len,
7725 scales_byte_off,
7726 scales_len,
7727 rotations_byte_off,
7728 rotations_len,
7729 opacities_byte_off,
7730 opacities_len,
7731 colors_byte_off,
7732 colors_len,
7733 sh_coeffs_byte_off,
7734 sh_coeffs_len,
7735 meta_byte_off,
7736 meta_len,
7737 prep_byte_off,
7738 prep_len,
7739 width,
7740 height,
7741 tile_size,
7742 radius_scale,
7743 alpha_cutoff,
7744 max_splat_steps,
7745 transmittance_threshold,
7746 max_list_entries,
7747 } => {
7748 crate::splat::run_gaussian_splat_prepare(
7749 &self.arena,
7750 &dev.device,
7751 &dev.queue,
7752 *positions_byte_off as usize,
7753 *positions_len as usize,
7754 *scales_byte_off as usize,
7755 *scales_len as usize,
7756 *rotations_byte_off as usize,
7757 *rotations_len as usize,
7758 *opacities_byte_off as usize,
7759 *opacities_len as usize,
7760 *colors_byte_off as usize,
7761 *colors_len as usize,
7762 *sh_coeffs_byte_off as usize,
7763 *sh_coeffs_len as usize,
7764 *meta_byte_off as usize,
7765 *meta_len as usize,
7766 *prep_byte_off as usize,
7767 *prep_len as usize,
7768 *width,
7769 *height,
7770 *tile_size,
7771 *radius_scale,
7772 *alpha_cutoff,
7773 *max_splat_steps,
7774 *transmittance_threshold,
7775 *max_list_entries,
7776 );
7777 }
7778 #[cfg(feature = "splat")]
7779 Step::GaussianSplatRasterize {
7780 prep_byte_off,
7781 prep_len,
7782 meta_byte_off,
7783 meta_len,
7784 dst_byte_off,
7785 dst_len,
7786 count,
7787 width,
7788 height,
7789 tile_size,
7790 alpha_cutoff,
7791 max_splat_steps,
7792 transmittance_threshold,
7793 max_list_entries,
7794 } => {
7795 crate::splat::run_gaussian_splat_rasterize(
7796 &self.arena,
7797 &dev.device,
7798 &dev.queue,
7799 *prep_byte_off as usize,
7800 *prep_len as usize,
7801 *meta_byte_off as usize,
7802 *meta_len as usize,
7803 *dst_byte_off as usize,
7804 *dst_len as usize,
7805 *count as usize,
7806 *width,
7807 *height,
7808 *tile_size,
7809 *alpha_cutoff,
7810 *max_splat_steps,
7811 *transmittance_threshold,
7812 *max_list_entries,
7813 );
7814 }
7815 #[cfg(feature = "splat")]
7816 Step::GaussianSplatRenderBackward {
7817 positions_byte_off,
7818 positions_len,
7819 scales_byte_off,
7820 scales_len,
7821 rotations_byte_off,
7822 rotations_len,
7823 opacities_byte_off,
7824 opacities_len,
7825 colors_byte_off,
7826 colors_len,
7827 sh_coeffs_byte_off,
7828 sh_coeffs_len,
7829 meta_byte_off,
7830 d_loss_byte_off,
7831 d_loss_len,
7832 packed_byte_off,
7833 packed_len,
7834 width,
7835 height,
7836 tile_size,
7837 radius_scale,
7838 alpha_cutoff,
7839 max_splat_steps,
7840 transmittance_threshold,
7841 max_list_entries,
7842 loss_grad_clip,
7843 sh_band,
7844 max_anisotropy,
7845 } => {
7846 crate::splat::run_gaussian_splat_render_backward(
7847 &self.arena,
7848 &dev.device,
7849 &dev.queue,
7850 *positions_byte_off as usize,
7851 *positions_len as usize,
7852 *scales_byte_off as usize,
7853 *scales_len as usize,
7854 *rotations_byte_off as usize,
7855 *rotations_len as usize,
7856 *opacities_byte_off as usize,
7857 *opacities_len as usize,
7858 *colors_byte_off as usize,
7859 *colors_len as usize,
7860 *sh_coeffs_byte_off as usize,
7861 *sh_coeffs_len as usize,
7862 *meta_byte_off as usize,
7863 *d_loss_byte_off as usize,
7864 *d_loss_len as usize,
7865 *packed_byte_off as usize,
7866 *packed_len as usize,
7867 *width,
7868 *height,
7869 *tile_size,
7870 *radius_scale,
7871 *alpha_cutoff,
7872 *max_splat_steps,
7873 *transmittance_threshold,
7874 *max_list_entries,
7875 *loss_grad_clip,
7876 *sh_band,
7877 *max_anisotropy,
7878 );
7879 }
7880 _ => break,
7881 }
7882 step_i += 1;
7883 }
7884
7885 self.dump_node_stats_if_requested(dev);
7886
7887 if rlx_ir::env::flag("RLX_WGPU_NAN_TRACE") {
7888 let mut bad_nodes = Vec::new();
7889 for node in self.graph.nodes() {
7890 if !self.arena.has(node.id) {
7891 continue;
7892 }
7893 if matches!(
7895 node.op,
7896 rlx_ir::Op::Input { .. }
7897 | rlx_ir::Op::Param { .. }
7898 | rlx_ir::Op::Constant { .. }
7899 ) {
7900 continue;
7901 }
7902 let data = self.arena.read_f32(&dev.device, &dev.queue, node.id);
7903 let nan_count = data.iter().filter(|v| v.is_nan()).count();
7904 let inf_count = data.iter().filter(|v| v.is_infinite()).count();
7905 if nan_count > 0 || inf_count > 0 {
7906 let first_nan = data.iter().position(|v| v.is_nan());
7908 if let Some(idx) = first_nan {
7909 let lo = idx.saturating_sub(2);
7910 let hi = (idx + 3).min(data.len());
7911 eprintln!(
7912 " node {:?} op={:?} len={} nan={} inf={} \
7913 first_nan_idx={} ctx={:?}",
7914 node.id,
7915 node.op,
7916 data.len(),
7917 nan_count,
7918 inf_count,
7919 idx,
7920 &data[lo..hi]
7921 );
7922 }
7923 bad_nodes.push((node.id, data.len(), nan_count, inf_count));
7924 if bad_nodes.len() >= 3 {
7925 break;
7926 }
7927 }
7928 }
7929 if bad_nodes.is_empty() {
7930 eprintln!("[wgpu-nan-trace] no NaN/Inf in any node — clean run");
7931 } else {
7932 eprintln!(
7933 "[wgpu-nan-trace] first {} bad nodes (above)",
7934 bad_nodes.len()
7935 );
7936 }
7937 }
7938
7939 if rlx_ir::env::flag("RLX_BENCH_DISPATCH_ONLY") {
7940 return self
7941 .graph
7942 .outputs
7943 .iter()
7944 .map(|&id| {
7945 let n = self.graph.node(id).shape.num_elements().unwrap_or(0);
7946 vec![0.0; n]
7947 })
7948 .collect();
7949 }
7950 let out_ids: Vec<_> = self.graph.outputs.clone();
7951 read_f32_many_pooled(
7952 &self.arena,
7953 &dev.device,
7954 &dev.queue,
7955 &out_ids,
7956 &mut self.readback_staging,
7957 )
7958 }
7959}
7960
/// Computes the workgroup grid `(x, y, z)` for an NCHW prologue dispatch.
///
/// `w` and `h` are each divided into tiles of 8 (rounding up) and `nc` is used
/// directly for the third axis. Every component is clamped to at least 1, so a
/// zero-sized extent still yields a non-empty grid.
///
/// NOTE(review): the tile size of 8 presumably mirrors an 8x8 workgroup in the
/// matching shader — confirm against the kernel source.
fn dispatch_prologue_nchw(w: u32, h: u32, nc: u32) -> (u32, u32, u32) {
    const TILE: u32 = 8;
    // Number of tiles needed to cover `extent`, never less than one.
    let tiles = |extent: u32| extent.div_ceil(TILE).max(1);
    (tiles(w), tiles(h), nc.max(1))
}
7970
/// Converts a flat thread count into `(x, y, z)` workgroup counts.
///
/// The number of workgroups is `ceil(threads_total / workgroup_size)`. While
/// that fits in 65535 it is placed entirely on the x axis; otherwise x is
/// pinned to 65535 and the remainder spills onto y (rounded up). z is always 1.
///
/// When the count spills onto y, `x * y` can exceed the number of groups
/// actually required, so the grid may contain surplus workgroups.
///
/// # Panics
/// Panics if `workgroup_size` is zero (division by zero in `div_ceil`).
fn dispatch_dims(threads_total: u32, workgroup_size: u32) -> (u32, u32, u32) {
    const MAX_GROUPS_X: u32 = 65535;
    let groups = threads_total.div_ceil(workgroup_size);
    match groups {
        0..=MAX_GROUPS_X => (groups, 1, 1),
        _ => (MAX_GROUPS_X, groups.div_ceil(MAX_GROUPS_X), 1),
    }
}
7981
7982fn coop_f16_vk_eligible(dev: &wgpu::Device, m: u32, k: u32, n: u32) -> bool {
8004 if rlx_ir::env::flag("RLX_WGPU_NO_COOP_F16_VK")
8005 || rlx_ir::env::flag("RLX_WGPU_COOP_F16_VK_DISABLE")
8006 {
8007 return false;
8008 }
8009 if !rlx_ir::env::flag("RLX_WGPU_COOP_F16_VK_ENABLE") {
8010 return false;
8011 }
8012 m.is_multiple_of(16)
8013 && k.is_multiple_of(16)
8014 && n.is_multiple_of(16)
8015 && dev
8016 .features()
8017 .contains(wgpu::Features::EXPERIMENTAL_COOPERATIVE_MATRIX)
8018 && dev.features().contains(wgpu::Features::SHADER_F16)
8019 && crate::device::coop_discrete_backend()
8020 && crate::device::coop_f16_16x16_supported()
8021}
8022
8023fn step_needs_pass_flush(step: &Step, prev: &Step) -> bool {
8024 match step {
8025 Step::CastF32ToF16 { .. } => matches!(
8026 prev,
8027 Step::Unary {
8028 f16_mirror: false,
8029 ..
8030 }
8031 ),
8032 Step::Matmul {
8033 compute_precision: MatmulCompute::CoopF16Vk,
8034 ..
8035 }
8036 | Step::MatmulQkv {
8037 kind: MatmulQkvKind::CoopF16Vk,
8038 ..
8039 } => matches!(prev, Step::Unary { .. } | Step::CastF32ToF16 { .. }),
8040 _ => false,
8041 }
8042}
8043
8044fn dispatch_wide_f32_matmul(
8045 pass: &mut wgpu::ComputePass<'_>,
8046 mm_w_active: &Kernel,
8047 mm_k: &Kernel,
8048 m_s: u32,
8049 n: u32,
8050 batch: u32,
8051) {
8052 let backend = wgpu_device()
8067 .map(|d| d.backend)
8068 .unwrap_or(wgpu::Backend::Noop);
8069 let is_vulkan_dx12 = matches!(backend, wgpu::Backend::Vulkan | wgpu::Backend::Dx12);
8070 let prefer_small_for_m = is_vulkan_dx12 && m_s < 64;
8071 let use_wide = !prefer_small_for_m && m_s >= 32 && n >= 64;
8072 if use_wide {
8073 pass.set_pipeline(&mm_w_active.pipeline);
8074 let (gx, gy) = if is_vulkan_dx12 {
8075 (n.div_ceil(64), m_s.div_ceil(64))
8076 } else {
8077 (n.div_ceil(64), m_s.div_ceil(32))
8078 };
8079 pass.dispatch_workgroups(gx, gy, batch);
8080 } else {
8081 pass.set_pipeline(&mm_k.pipeline);
8082 pass.dispatch_workgroups(n.div_ceil(32), m_s.div_ceil(32), batch);
8083 }
8084}
8085
8086fn coop_f16_vk_bind_group(exe: &WgpuExecutable, gpu_bi: usize, use_wide: bool) -> &wgpu::BindGroup {
8087 if use_wide {
8088 exe.coop_f16_vk_wide_bind_groups
8089 .get(&gpu_bi)
8090 .unwrap_or(&exe.bind_groups[gpu_bi])
8091 } else {
8092 &exe.bind_groups[gpu_bi]
8093 }
8094}
8095
8096fn require_equal_shapes(graph: &Graph, ids: &[NodeId], op_name: &str) {
8097 let s0 = graph.node(ids[0]).shape.num_elements().unwrap_or(0);
8098 for &id in &ids[1..] {
8099 let si = graph.node(id).shape.num_elements().unwrap_or(0);
8100 if si != s0 {
8101 panic!(
8102 "rlx-wgpu {op_name}: broadcasting not yet implemented; \
8103 inputs must have the same element count (got {s0} vs {si})"
8104 );
8105 }
8106 }
8107}
8108
8109fn arena_whole_arena_bind(arena: &Arena, max_binding: u64) -> Option<(u64, u64)> {
8111 let need = arena.size as u64;
8112 if need > max_binding {
8113 return None;
8114 }
8115 let buf_bytes = arena.buffer.size();
8117 let size = need.min(buf_bytes).max(256);
8118 Some((0, size))
8119}
8120
8121fn arena_window_for_nodes(dev: &wgpu::Device, arena: &Arena, ids: &[NodeId]) -> (u64, u64) {
8122 const ALIGN: u64 = 256;
8124 let max_binding = dev.limits().max_storage_buffer_binding_size;
8125 if let Some(w) = arena_whole_arena_bind(arena, max_binding) {
8126 return w;
8127 }
8128 let mut lo: u64 = u64::MAX;
8129 let mut hi: u64 = 0;
8130 for &id in ids {
8131 let off = arena.offset(id) as u64;
8132 let len = arena.len_of(id) as u64;
8133 lo = lo.min(off);
8134 hi = hi.max(off.saturating_add(len));
8135 }
8136 if lo == u64::MAX {
8137 return (0, max_binding.max(256));
8138 }
8139 let span = hi.saturating_sub(lo).max(1);
8140 if span > max_binding {
8141 let mut details = String::new();
8142 for &id in ids.iter().take(6) {
8143 let off = arena.offset(id);
8144 let len = arena.len_of(id);
8145 details.push_str(&format!(" id={id:?}@{off}+{len};"));
8146 }
8147 panic!(
8148 "rlx-wgpu: op needs {} bytes of arena span (>{});{}",
8149 span, max_binding, details
8150 );
8151 }
8152 let mut base = (lo / ALIGN) * ALIGN;
8153 let mut size = span.div_ceil(ALIGN) * ALIGN;
8156 size = size.max(256).min(max_binding);
8157 if base.saturating_add(size) > arena.size as u64 {
8158 base = (arena.size as u64).saturating_sub(size);
8159 base = (base / ALIGN) * ALIGN;
8160 }
8161 if base > lo || base.saturating_add(size) < hi {
8162 base = (lo / ALIGN) * ALIGN;
8163 size = hi.saturating_sub(base).div_ceil(ALIGN) * ALIGN;
8164 size = size.max(256).min(max_binding);
8165 if base.saturating_add(size) > arena.size as u64 {
8166 base = hi.saturating_sub(size);
8167 base = (base / ALIGN) * ALIGN;
8168 }
8169 }
8170 (base, size)
8171}
8172
8173fn arena_local_off_f32(arena: &Arena, id: NodeId, base: u64) -> u32 {
8174 (((arena.offset(id) as u64).saturating_sub(base)) / 4) as u32
8175}
8176
8177fn arena_tensor_in_window(arena: &Arena, id: NodeId, base: u64, size: u64) -> bool {
8178 let src = arena.offset(id) as u64;
8179 let len = arena.len_of(id) as u64;
8180 src >= base && src.saturating_add(len) <= base.saturating_add(size)
8181}
8182
8183fn arena_tensors_overlap(arena: &Arena, a: NodeId, b: NodeId) -> bool {
8185 if a == b {
8186 return true;
8187 }
8188 let (a0, al) = (arena.offset(a) as u64, arena.len_of(a) as u64);
8189 let (b0, bl) = (arena.offset(b) as u64, arena.len_of(b) as u64);
8190 if al == 0 || bl == 0 {
8191 return false;
8192 }
8193 let a1 = a0.saturating_add(al);
8194 let b1 = b0.saturating_add(bl);
8195 a0 < b1 && b0 < a1
8196}
8197
8198fn arena_matmul_bind_window(
8201 device: &wgpu::Device,
8202 arena: &Arena,
8203 graph: &Graph,
8204 param_offsets: &HashMap<String, NodeId>,
8205 out_id: NodeId,
8206 a_id: NodeId,
8207 b_id: NodeId,
8208) -> (u64, u64, bool) {
8209 let max_binding = device.limits().max_storage_buffer_binding_size;
8210 if let Some((base, size)) = arena_whole_arena_bind(arena, max_binding) {
8211 return (base, size, false);
8212 }
8213 let ids = [out_id, a_id, b_id];
8214 let all_fits = arena_span_bytes(arena, &ids) <= max_binding;
8215 let b_bytes = arena.len_of(b_id) as u64;
8216 let b_is_param = tensor_is_graph_param(graph, param_offsets, b_id);
8217 let param_anchor =
8218 b_is_param && b_bytes <= max_binding && (!all_fits || b_bytes > ARENA_STAGE_CAP);
8219 let (mut base, mut size) = if param_anchor {
8220 arena_window_for_nodes(device, arena, &[b_id])
8221 } else if all_fits {
8222 arena_window_for_nodes(device, arena, &ids)
8223 } else {
8224 arena_window_for_nodes(device, arena, &[out_id])
8225 };
8226 let param_anchor = param_anchor
8227 || (b_is_param
8228 && b_bytes <= max_binding
8229 && !arena_tensor_in_window(arena, b_id, base, size));
8230 if param_anchor && !arena_tensor_in_window(arena, b_id, base, size) {
8231 (base, size) = arena_window_for_nodes(device, arena, &[b_id]);
8232 }
8233 (base, size, param_anchor)
8234}
8235
8236fn arena_expand_bind_window(
8239 arena: &Arena,
8240 ids: &[NodeId],
8241 base: &mut u64,
8242 size: &mut u64,
8243 max_binding: u64,
8244) {
8245 const ALIGN: u64 = 256;
8246 let mut lo = *base;
8247 let mut hi = base.saturating_add(*size);
8248 for &id in ids {
8249 let off = arena.offset(id) as u64;
8250 let len = arena.len_of(id) as u64;
8251 lo = lo.min(off);
8252 hi = hi.max(off.saturating_add(len));
8253 }
8254 let span = hi.saturating_sub(lo).max(1);
8255 if span > max_binding {
8256 return;
8257 }
8258 *base = (lo / ALIGN) * ALIGN;
8259 *size = span.div_ceil(ALIGN) * ALIGN;
8260 *size = (*size).max(256).min(max_binding);
8261 if (*base).saturating_add(*size) > arena.size as u64 {
8262 *base = (arena.size as u64).saturating_sub(*size);
8263 *base = (*base / ALIGN) * ALIGN;
8264 }
8265}
8266
8267fn arena_off_in_bind_window(
8268 graph: &Graph,
8269 param_offsets: &HashMap<String, NodeId>,
8270 device: &wgpu::Device,
8271 arena: &Arena,
8272 schedule: &mut Vec<Step>,
8273 scratch: &mut u64,
8274 id: NodeId,
8275 base: &mut u64,
8276 size: &mut u64,
8277) -> u32 {
8278 let max_binding = device.limits().max_storage_buffer_binding_size;
8279 if let Some((b, s)) = arena_whole_arena_bind(arena, max_binding) {
8280 *base = b;
8281 *size = s;
8282 return arena_local_off_f32(arena, id, b);
8283 }
8284 if arena_tensor_in_window(arena, id, *base, *size) {
8285 arena_local_off_f32(arena, id, *base)
8286 } else {
8287 let len = arena.len_of(id) as u64;
8288 if tensor_is_graph_param(graph, param_offsets, id) && len > max_binding {
8289 panic!(
8290 "rlx-wgpu: param node {:?} ({} bytes) exceeds max_storage_buffer_binding_size \
8291 ({max_binding}); split weights or use f16 shadow binds",
8292 id, len
8293 );
8294 }
8295 if len > ARENA_STAGE_CAP {
8296 let op = &graph.node(id).op;
8297 panic!(
8298 "rlx-wgpu: bind_window would stage {} bytes for {:?} op={op:?} \
8299 (off={}, base={}, bind_size={})",
8300 len,
8301 id,
8302 arena.offset(id),
8303 *base,
8304 *size,
8305 );
8306 }
8307 arena_off_in_window_or_stage(arena, schedule, scratch, base, size, max_binding, id)
8308 }
8309}
8310
8311fn arena_multi_op_window(
8315 dev: &wgpu::Device,
8316 arena: &Arena,
8317 graph: &Graph,
8318 param_offsets: &HashMap<String, NodeId>,
8319 _schedule: &mut Vec<Step>,
8320 scratch: &mut u64,
8321 ids: &[NodeId],
8322) -> (u64, u64, bool) {
8323 let max_binding = dev.limits().max_storage_buffer_binding_size;
8324 if let Some((base, size)) = arena_whole_arena_bind(arena, max_binding) {
8325 *scratch = arena.scratch_off as u64;
8326 return (base, size, false);
8327 }
8328 let param_anchor = if arena_span_bytes(arena, ids) > max_binding {
8329 ids.iter()
8330 .find(|&&id| {
8331 let nbytes = arena.len_of(id) as u64;
8332 tensor_is_graph_param(graph, param_offsets, id) && nbytes <= max_binding
8333 })
8334 .copied()
8335 } else {
8336 None
8337 };
8338 let mut param_anchored = param_anchor.is_some();
8339 let (mut base, mut size) = if arena_span_bytes(arena, ids) <= max_binding {
8340 arena_window_for_nodes(dev, arena, ids)
8341 } else if let Some(id) = param_anchor {
8342 arena_window_for_nodes(dev, arena, &[id])
8343 } else {
8344 arena_window_for_nodes(dev, arena, &[ids[0]])
8345 };
8346 if let Some(id) = param_anchor {
8347 if !arena_tensor_in_window(arena, id, base, size) {
8348 (base, size) = arena_window_for_nodes(dev, arena, &[id]);
8349 }
8350 param_anchored = true;
8351 } else {
8352 for &id in ids {
8353 let nbytes = arena.len_of(id) as u64;
8354 if tensor_is_graph_param(graph, param_offsets, id)
8355 && nbytes <= max_binding
8356 && !arena_tensor_in_window(arena, id, base, size)
8357 {
8358 (base, size) = arena_window_for_nodes(dev, arena, &[id]);
8359 param_anchored = true;
8360 break;
8361 }
8362 }
8363 }
8364 *scratch = arena.scratch_off as u64;
8365 if param_anchored {
8366 arena_ensure_scratch_in_window(scratch, base, size);
8367 }
8368 (base, size, param_anchored)
8369}
8370
8371fn arena_bind_window_covering_scratch_if_needed(
8372 arena: &Arena,
8373 base: u64,
8374 size: u64,
8375 scratch: u64,
8376) -> u64 {
8377 if scratch <= arena.scratch_off as u64 {
8380 return base;
8381 }
8382 if scratch >= base && scratch.saturating_add(ARENA_STAGE_CAP) <= base.saturating_add(size) {
8383 return base;
8384 }
8385 arena_window_covering_scratch(arena, base, size)
8386}
8387
8388fn arena_ensure_scratch_in_window(scratch: &mut u64, base: u64, size: u64) {
8391 let cap = ARENA_STAGE_CAP.min(size);
8392 let end = base.saturating_add(size);
8393 if *scratch < base || scratch.saturating_add(cap) > end {
8394 *scratch = end.saturating_sub(cap);
8395 *scratch = (*scratch / 256) * 256;
8396 }
8397}
8398
8399#[allow(dead_code)]
8400fn arena_off_for_window(
8401 arena: &Arena,
8402 schedule: &mut Vec<Step>,
8403 scratch: &mut u64,
8404 id: NodeId,
8405 _window_ids: &[NodeId],
8406 mut base: u64,
8407 mut size: u64,
8408 max_binding: u64,
8409 _fits_in_one_binding: bool,
8410) -> u32 {
8411 let src = arena.offset(id) as u64;
8412 let len = arena.len_of(id) as u64;
8413 if src >= base && src.saturating_add(len) <= base.saturating_add(size) {
8414 arena_local_off_f32(arena, id, base)
8415 } else {
8416 arena_off_in_window_or_stage(
8417 arena,
8418 schedule,
8419 scratch,
8420 &mut base,
8421 &mut size,
8422 max_binding,
8423 id,
8424 )
8425 }
8426}
8427
/// Derives a `(base, size)` bind range in the f16 buffer from an arena byte window.
///
/// Both the arena base and size are halved. The base is aligned down to 256 bytes; the
/// size is rounded up to 256 bytes, raised to at least 256 and capped at
/// `f16_buf_bytes`. If the range would run past the end of the f16 buffer, the base is
/// pulled back (and re-aligned) so the range ends inside it.
fn f16_shadow_bind_range(arena_base: u64, arena_size: u64, f16_buf_bytes: u64) -> (u64, u64) {
    const ALIGN: u64 = 256;
    let align_down = |v: u64| v / ALIGN * ALIGN;

    let size = ((arena_size / 2).div_ceil(ALIGN) * ALIGN)
        .max(256)
        .min(f16_buf_bytes);
    let base = align_down(arena_base / 2);
    if base.saturating_add(size) > f16_buf_bytes {
        (align_down(f16_buf_bytes.saturating_sub(size)), size)
    } else {
        (base, size)
    }
}
8440
8441fn f16_weight_bind_range(
8444 dev: &wgpu::Device,
8445 f16_buf_bytes: u64,
8446 b_off: u32,
8447 k: u32,
8448 n: u32,
8449 batch: u32,
8450 b_batch_stride: u32,
8451) -> (u64, u64, u32) {
8452 const ALIGN: u64 = 256;
8453 let max_binding = dev.limits().max_storage_buffer_binding_size;
8454 let b0 = b_off as u64;
8455 let span = (k as u64).saturating_mul(n as u64);
8456 let batch_n = batch.max(1) as u64;
8457 let stride = if batch_n > 1 {
8458 b_batch_stride as u64
8459 } else {
8460 span
8461 };
8462 let hi_elems = b0
8463 .saturating_add((batch_n - 1).saturating_mul(stride))
8464 .saturating_add(span);
8465 let lo_byte = b0.saturating_mul(2);
8466 let hi_byte = hi_elems.saturating_mul(2).saturating_add(8);
8467 let need = hi_byte.saturating_sub(lo_byte).max(1);
8468 if need > max_binding {
8469 panic!(
8470 "rlx-wgpu: f16 weight region needs {need} bytes (> {max_binding}); \
8471 matmul k={k} n={n} batch={batch}"
8472 );
8473 }
8474 let mut base = (lo_byte / ALIGN) * ALIGN;
8475 let mut size = need.div_ceil(ALIGN) * ALIGN;
8476 size = size.max(256).min(max_binding).min(f16_buf_bytes);
8477 if base.saturating_add(size) < hi_byte {
8478 base = hi_byte.saturating_sub(size);
8479 base = (base / ALIGN) * ALIGN;
8480 }
8481 if base.saturating_add(size) > f16_buf_bytes {
8482 base = f16_buf_bytes.saturating_sub(size);
8483 base = (base / ALIGN) * ALIGN;
8484 }
8485 let rebased = b_off.saturating_sub((base / 2) as u32);
8486 (base, size, rebased)
8487}
8488
/// Largest tensor, in bytes (256 MiB), that the staging helpers in this file will copy
/// into arena scratch; larger requests panic.
const ARENA_STAGE_CAP: u64 = 256 << 20;
8490
8491fn arena_off_in_window_or_stage(
8494 arena: &Arena,
8495 schedule: &mut Vec<Step>,
8496 scratch: &mut u64,
8497 base: &mut u64,
8498 size: &mut u64,
8499 max_binding: u64,
8500 id: NodeId,
8501) -> u32 {
8502 let src = arena.offset(id) as u64;
8503 let len = arena.len_of(id) as u64;
8504 if src >= *base && src.saturating_add(len) <= (*base).saturating_add(*size) {
8505 return arena_local_off_f32(arena, id, *base);
8506 }
8507 if len > ARENA_STAGE_CAP {
8508 panic!(
8509 "rlx-wgpu: cannot stage {} bytes for node {:?} (cap {ARENA_STAGE_CAP})",
8510 len, id
8511 );
8512 }
8513 let aligned = len.div_ceil(256) * 256;
8514 let dst = *scratch;
8515 *scratch = scratch.saturating_add(aligned);
8516 schedule.push(Step::BufferCopy {
8517 src_byte_off: src as u32,
8518 dst_byte_off: dst as u32,
8519 bytes: len as u32,
8520 });
8521 let lo = (*base).min(dst);
8522 let hi = (*base)
8523 .saturating_add(*size)
8524 .max(dst.saturating_add(aligned));
8525 let span = hi.saturating_sub(lo).max(1);
8526 if span <= max_binding {
8527 const ALIGN: u64 = 256;
8528 *base = (lo / ALIGN) * ALIGN;
8529 *size = span.div_ceil(ALIGN) * ALIGN;
8530 *size = (*size).max(256).min(max_binding);
8531 if (*base).saturating_add(*size) > arena.size as u64 {
8532 *base = (arena.size as u64).saturating_sub(*size);
8533 *base = (*base / ALIGN) * ALIGN;
8534 }
8535 }
8536 if arena_tensor_in_window(arena, id, *base, *size) {
8537 arena_local_off_f32(arena, id, *base)
8538 } else {
8539 ((dst.saturating_sub(*base)) / 4) as u32
8540 }
8541}
8542
8543fn arena_window_covering_scratch(arena: &Arena, base: u64, size: u64) -> u64 {
8545 let scratch = arena.scratch_off as u64;
8546 if scratch >= base && scratch.saturating_add(ARENA_STAGE_CAP) <= base.saturating_add(size) {
8547 return base;
8548 }
8549 let new_base = (arena.size as u64).saturating_sub(size);
8550 (new_base / 256) * 256
8551}
8552
8553fn arena_span_bytes(arena: &Arena, ids: &[NodeId]) -> u64 {
8554 let mut lo: u64 = u64::MAX;
8555 let mut hi: u64 = 0;
8556 for &id in ids {
8557 let off = arena.offset(id) as u64;
8558 let len = arena.len_of(id) as u64;
8559 lo = lo.min(off);
8560 hi = hi.max(off.saturating_add(len));
8561 }
8562 if lo == u64::MAX {
8563 0
8564 } else {
8565 hi.saturating_sub(lo)
8566 }
8567}
8568
8569#[allow(dead_code)]
8570fn bind_two(
8571 device: &wgpu::Device,
8572 kernel: &Kernel,
8573 buf0: &wgpu::Buffer,
8574 buf1: &wgpu::Buffer,
8575) -> wgpu::BindGroup {
8576 let max_binding = device.limits().max_storage_buffer_binding_size;
8577 if buf0.size() > max_binding {
8578 panic!(
8579 "rlx-wgpu: bind_two buffer {} bytes exceeds max_storage_buffer_binding_size {}; \
8580 use bind_two_buf0_window or bind_op_output_window",
8581 buf0.size(),
8582 max_binding
8583 );
8584 }
8585 device.create_bind_group(&wgpu::BindGroupDescriptor {
8586 label: Some("rlx-wgpu bg"),
8587 layout: &kernel.bgl,
8588 entries: &[
8589 wgpu::BindGroupEntry {
8590 binding: 0,
8591 resource: buf0.as_entire_binding(),
8592 },
8593 wgpu::BindGroupEntry {
8594 binding: 1,
8595 resource: buf1.as_entire_binding(),
8596 },
8597 ],
8598 })
8599}
8600
8601fn bind_op_output_window(
8605 device: &wgpu::Device,
8606 kernel: &Kernel,
8607 arena: &Arena,
8608 out_id: NodeId,
8609 params: &wgpu::Buffer,
8610) -> wgpu::BindGroup {
8611 bind_op_window(device, kernel, arena, &[out_id], params)
8612}
8613
8614fn bind_op_window(
8615 device: &wgpu::Device,
8616 kernel: &Kernel,
8617 arena: &Arena,
8618 ids: &[NodeId],
8619 params: &wgpu::Buffer,
8620) -> wgpu::BindGroup {
8621 let max_binding = device.limits().max_storage_buffer_binding_size;
8622 let (base, size) = if arena_span_bytes(arena, ids) <= max_binding {
8623 arena_window_for_nodes(device, arena, ids)
8624 } else {
8625 arena_window_for_nodes(device, arena, &[ids[0]])
8626 };
8627 bind_two_buf0_window(device, kernel, &arena.buffer, base, size, params)
8628}
8629
8630fn bind_two_buf0_window(
8631 device: &wgpu::Device,
8632 kernel: &Kernel,
8633 buf0: &wgpu::Buffer,
8634 buf0_base: u64,
8635 buf0_size: u64,
8636 buf1: &wgpu::Buffer,
8637) -> wgpu::BindGroup {
8638 device.create_bind_group(&wgpu::BindGroupDescriptor {
8639 label: Some("rlx-wgpu bg window"),
8640 layout: &kernel.bgl,
8641 entries: &[
8642 wgpu::BindGroupEntry {
8643 binding: 0,
8644 resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
8645 buffer: buf0,
8646 offset: buf0_base,
8647 size: NonZeroU64::new(buf0_size),
8648 }),
8649 },
8650 wgpu::BindGroupEntry {
8651 binding: 1,
8652 resource: buf1.as_entire_binding(),
8653 },
8654 ],
8655 })
8656}
8657
8658fn derive_matmul_compute(
8681 dev: &wgpu::Device,
8682 graph: &Graph,
8683 mirror_acts: &HashSet<NodeId>,
8684 a_id: NodeId,
8685 b_id: NodeId,
8686 m: u32,
8687 k: u32,
8688 n: u32,
8689) -> MatmulCompute {
8690 if rlx_ir::env::flag("RLX_WGPU_MATMUL_F32_ONLY") {
8691 return MatmulCompute::F32;
8692 }
8693 use rlx_ir::DType;
8694 let a_dt = graph.node(a_id).shape.dtype();
8695 let b_dt = graph.node(b_id).shape.dtype();
8696 let any_low =
8697 matches!(a_dt, DType::F16 | DType::BF16) || matches!(b_dt, DType::F16 | DType::BF16);
8698 let coop16_aligned = m.is_multiple_of(32) && k.is_multiple_of(8) && n.is_multiple_of(32);
8708 let coop_f32_metal_aligned = k.is_multiple_of(8) && n.is_multiple_of(32);
8709 let coop_f32_portable_aligned = k.is_multiple_of(8) && n.is_multiple_of(8);
8710 let has_coop = dev
8711 .features()
8712 .contains(wgpu::Features::EXPERIMENTAL_COOPERATIVE_MATRIX);
8713 let backend = crate::device::wgpu_device().map(|d| d.backend);
8714 if any_low
8719 && has_coop
8720 && dev.features().contains(wgpu::Features::SHADER_F16)
8721 && traces_to_param(graph, b_id)
8722 && coop16_aligned
8723 {
8724 return MatmulCompute::Coop16;
8725 }
8726 if !any_low && coop_f16_vk_eligible(dev, m, k, n) {
8727 if traces_to_param(graph, b_id)
8728 && !mirror_acts.contains(&a_id)
8729 && !mirror_acts.contains(&b_id)
8730 {
8731 return MatmulCompute::CoopF16Vk;
8732 }
8733 }
8734 let disabled = rlx_ir::env::flag("RLX_WGPU_NO_COOP_F32");
8743 let forced = rlx_ir::env::flag("RLX_WGPU_FORCE_COOP_F32");
8744 let metal_coop = !disabled
8745 && has_coop
8746 && coop_f32_metal_aligned
8747 && traces_to_param(graph, b_id)
8748 && (forced || matches!(backend, Some(wgpu::Backend::Metal)));
8749 let vulkan_coop = !disabled
8750 && has_coop
8751 && coop_f32_portable_aligned
8752 && traces_to_param(graph, b_id)
8753 && crate::device::coop_discrete_backend()
8754 && crate::device::coop_f32_8x8_supported();
8755 if metal_coop
8756 || vulkan_coop
8757 || (forced
8758 && has_coop
8759 && traces_to_param(graph, b_id)
8760 && (coop_f32_metal_aligned || coop_f32_portable_aligned))
8761 {
8762 return MatmulCompute::CoopF32;
8763 }
8764 MatmulCompute::F32
8765}
8766
8767#[allow(dead_code)]
8787fn detect_qkv_narrow_pattern(
8788 graph: &Graph,
8789 q_id: NodeId,
8790 k_id: NodeId,
8791 v_id: NodeId,
8792) -> Option<(NodeId, u32)> {
8793 let unwrap_narrow = |id: NodeId| -> Option<(NodeId, usize, usize, usize)> {
8794 let node = graph.node(id);
8795 match &node.op {
8796 Op::Narrow { axis, start, len } => Some((node.inputs[0], *axis, *start, *len)),
8797 _ => None,
8798 }
8799 };
8800 let (q_src, q_axis, q_start, q_len) = unwrap_narrow(q_id)?;
8801 let (k_src, k_axis, k_start, k_len) = unwrap_narrow(k_id)?;
8802 let (v_src, v_axis, v_start, v_len) = unwrap_narrow(v_id)?;
8803 if q_src != k_src || k_src != v_src {
8805 return None;
8806 }
8807 if q_len != k_len || k_len != v_len {
8809 return None;
8810 }
8811 if q_start != 0 || k_start != q_len || v_start != q_len * 2 {
8813 return None;
8814 }
8815 let src_rank = graph.node(q_src).shape.dims().len();
8817 if q_axis + 1 != src_rank || k_axis + 1 != src_rank || v_axis + 1 != src_rank {
8818 return None;
8819 }
8820 Some((q_src, q_len as u32))
8821}
8822
8823fn detect_residual_ln_tee_pattern(
8853 graph: &Graph,
8854) -> (
8855 HashMap<NodeId, (NodeId, NodeId, NodeId, NodeId, NodeId)>,
8856 HashSet<NodeId>,
8857) {
8858 use rlx_ir::op::BinaryOp;
8859 let mut consumers: HashMap<NodeId, usize> = HashMap::new();
8861 for node in graph.nodes() {
8862 for &input in &node.inputs {
8863 *consumers.entry(input).or_insert(0) += 1;
8864 }
8865 }
8866 for &out in &graph.outputs {
8867 *consumers.entry(out).or_insert(0) += 1;
8868 }
8869
8870 let mut ln_to_tee = HashMap::new();
8871 let mut skip_adds = HashSet::new();
8872 for node in graph.nodes() {
8873 let Op::LayerNorm { axis: _, eps: _ } = &node.op else {
8874 continue;
8875 };
8876 if node.inputs.len() < 3 {
8877 continue;
8878 } let in_id = node.inputs[0];
8880 let in_node = graph.node(in_id);
8881 if !matches!(in_node.op, Op::Binary(BinaryOp::Add)) {
8882 continue;
8883 }
8884 if consumers.get(&in_id).copied().unwrap_or(0) < 2 {
8887 continue;
8888 }
8889 if in_node.inputs.len() != 2 {
8892 continue;
8893 }
8894 let h_id = in_node.inputs[0];
8895 let delta_id = in_node.inputs[1];
8896 if graph.node(h_id).shape.dims() != node.shape.dims() {
8897 continue;
8898 }
8899 if graph.node(delta_id).shape.dims() != node.shape.dims() {
8900 continue;
8901 }
8902 let gamma_id = node.inputs[1];
8903 let beta_id = node.inputs[2];
8904 ln_to_tee.insert(node.id, (h_id, delta_id, gamma_id, beta_id, in_id));
8905 skip_adds.insert(in_id);
8906 }
8907 (ln_to_tee, skip_adds)
8908}
8909
8910fn detect_split_qkv_pattern(graph: &Graph) -> HashMap<NodeId, (NodeId, NodeId, NodeId)> {
8911 let mut consumers: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
8913 for node in graph.nodes() {
8914 for &input in &node.inputs {
8915 consumers.entry(input).or_default().push(node.id);
8916 }
8917 }
8918 for &out_id in &graph.outputs {
8921 consumers.entry(out_id).or_default().push(NodeId(u32::MAX));
8922 }
8923
8924 let mut result = HashMap::new();
8925 for node in graph.nodes() {
8926 if !matches!(node.op, Op::FusedMatMulBiasAct { activation: None }) {
8927 continue;
8928 }
8929 let cs = match consumers.get(&node.id) {
8930 Some(c) if c.len() == 3 => c,
8931 _ => continue,
8932 };
8933 let dims = node.shape.dims();
8934 if dims.is_empty() {
8935 continue;
8936 }
8937 let last_axis = dims.len() - 1;
8938 let n = dims[last_axis].unwrap_static();
8939 if n % 3 != 0 {
8940 continue;
8941 }
8942 let head_width = n / 3;
8943
8944 let mut narrows: Vec<(usize, NodeId)> = Vec::with_capacity(3);
8946 let mut all_match = true;
8947 for &c in cs {
8948 let cn = graph.node(c);
8949 match cn.op {
8950 Op::Narrow { axis, start, len }
8951 if axis == last_axis && len == head_width && cn.inputs[0] == node.id =>
8952 {
8953 narrows.push((start, c));
8954 }
8955 _ => {
8956 all_match = false;
8957 break;
8958 }
8959 }
8960 }
8961 if !all_match {
8962 continue;
8963 }
8964 narrows.sort_by_key(|&(start, _)| start);
8965 if narrows[0].0 != 0 || narrows[1].0 != head_width || narrows[2].0 != 2 * head_width {
8966 continue;
8967 }
8968 result.insert(node.id, (narrows[0].1, narrows[1].1, narrows[2].1));
8969 }
8970 result
8971}
8972
8973fn node_is_arena_param(param_offsets: &HashMap<String, NodeId>, id: NodeId) -> bool {
8979 param_offsets.values().any(|&nid| nid == id)
8980}
8981
8982fn traces_to_param(graph: &Graph, mut id: NodeId) -> bool {
8983 loop {
8984 let node = graph.node(id);
8985 match &node.op {
8986 Op::Param { .. } => return true,
8987 Op::Cast { .. } | Op::Reshape { .. } | Op::Transpose { .. } => {
8988 if node.inputs.is_empty() {
8989 return false;
8990 }
8991 id = node.inputs[0];
8992 }
8993 _ => return false,
8994 }
8995 }
8996}
8997
8998fn tensor_is_graph_param(
8999 graph: &Graph,
9000 param_offsets: &HashMap<String, NodeId>,
9001 id: NodeId,
9002) -> bool {
9003 node_is_arena_param(param_offsets, id) || traces_to_param(graph, id)
9004}
9005
9006fn traces_to_input(graph: &Graph, mut id: NodeId) -> bool {
9007 loop {
9008 let node = graph.node(id);
9009 match &node.op {
9010 Op::Input { .. } => return true,
9011 Op::Cast { .. } | Op::Reshape { .. } => {
9012 if node.inputs.is_empty() {
9013 return false;
9014 }
9015 id = node.inputs[0];
9016 }
9017 _ => return false,
9018 }
9019 }
9020}
9021
9022fn schedule_uses_coop_f16_vk(schedule: &[Step]) -> bool {
9025 schedule.iter().any(|s| {
9026 matches!(
9027 s,
9028 Step::Matmul {
9029 compute_precision: MatmulCompute::CoopF16Vk,
9030 ..
9031 } | Step::MatmulQkv {
9032 kind: MatmulQkvKind::CoopF16Vk,
9033 ..
9034 }
9035 )
9036 })
9037}
9038
9039fn register_coop_f16_vk_b_param(
9040 map: &mut HashMap<u32, String>,
9041 param_offsets: &HashMap<String, NodeId>,
9042 b_id: NodeId,
9043 b_off_f32: u32,
9044 compute: MatmulCompute,
9045) {
9046 if compute != MatmulCompute::CoopF16Vk {
9047 return;
9048 }
9049 for (name, &id) in param_offsets {
9050 if id == b_id {
9051 map.insert(b_off_f32, name.clone());
9052 return;
9053 }
9054 }
9055}
9056
9057fn tensor_host_name(
9058 input_offsets: &HashMap<String, NodeId>,
9059 param_offsets: &HashMap<String, NodeId>,
9060 id: NodeId,
9061) -> String {
9062 for (name, &nid) in input_offsets {
9063 if nid == id {
9064 return name.clone();
9065 }
9066 }
9067 for (name, &nid) in param_offsets {
9068 if nid == id {
9069 return name.clone();
9070 }
9071 }
9072 panic!("rlx-wgpu: CoopF16Vk host activation source {id} is not an input or param");
9073}
9074
/// Finds the host f32 data for `name`.
///
/// The first entry of `inputs` with a matching name wins; otherwise the data stored
/// under `name` in `stashed_params` is returned. Returns `None` when neither has it.
fn host_tensor_f32<'a>(
    name: &str,
    inputs: &'a [(&str, &[f32])],
    stashed_params: &'a HashMap<String, Vec<f32>>,
) -> Option<&'a [f32]> {
    for &(candidate, data) in inputs {
        if candidate == name {
            return Some(data);
        }
    }
    stashed_params.get(name).map(Vec::as_slice)
}
9086
9087fn apply_activation_host(act: Activation, data: &[f32]) -> Vec<f32> {
9088 data.iter()
9089 .map(|&x| match act {
9090 Activation::Relu => x.max(0.0),
9091 Activation::Sigmoid => 1.0 / (1.0 + (-x).exp()),
9092 Activation::Tanh => x.tanh(),
9093 Activation::Exp => x.exp(),
9094 Activation::Log => x.ln(),
9095 Activation::Sqrt => x.sqrt(),
9096 Activation::Rsqrt => 1.0 / x.sqrt(),
9097 Activation::Neg => -x,
9098 Activation::Abs => x.abs(),
9099 Activation::Gelu | Activation::GeluApprox => {
9100 let c = 0.797_884_6_f32;
9101 let x3 = x * x * x;
9102 let inner = (c * (x + 0.044_715 * x3)).clamp(-15.0, 15.0);
9103 0.5 * x * (1.0 + inner.tanh())
9104 }
9105 Activation::Silu => {
9106 let nx = (-x).clamp(-88.0, 88.0);
9107 x / (1.0 + nx.exp())
9108 }
9109 Activation::Round => x.round(),
9110 Activation::Sin => x.sin(),
9111 Activation::Cos => x.cos(),
9112 Activation::Tan => x.tan(),
9113 Activation::Atan => x.atan(),
9114 })
9115 .collect()
9116}
9117
9118fn collect_coop_f16_vk_mirror_activations(graph: &Graph, dev: &wgpu::Device) -> HashSet<NodeId> {
9120 let mut acts = HashSet::new();
9121 for node in graph.nodes() {
9122 if !matches!(node.op, Op::MatMul) {
9123 continue;
9124 }
9125 let a_id = node.inputs[0];
9126 let b_id = node.inputs[1];
9127 let a_shape = graph.node(a_id).shape.dims();
9128 let b_shape = graph.node(b_id).shape.dims();
9129 if a_shape.len() != 2 || b_shape.len() != 2 {
9130 continue;
9131 }
9132 let m = a_shape[0].unwrap_static() as u32;
9133 let k = a_shape[1].unwrap_static() as u32;
9134 let n = b_shape[1].unwrap_static() as u32;
9135 if !coop_f16_vk_eligible(dev, m, k, n) || !traces_to_param(graph, b_id) {
9136 continue;
9137 }
9138 if matches!(graph.node(a_id).op, Op::Activation(_)) {
9139 acts.insert(a_id);
9140 }
9141 if matches!(graph.node(b_id).op, Op::Activation(_)) {
9142 acts.insert(b_id);
9143 }
9144 }
9145 acts
9146}
9147
9148fn maybe_push_coop_f16_vk_casts(
9151 graph: &Graph,
9152 a_id: NodeId,
9153 b_id: NodeId,
9154 mirror_acts: &HashSet<NodeId>,
9155 device: &wgpu::Device,
9156 arena: &Arena,
9157 schedule: &mut Vec<Step>,
9158 uniforms: &mut Vec<wgpu::Buffer>,
9159 bind_groups: &mut Vec<wgpu::BindGroup>,
9160 mm_cast: &Option<&'static Kernel>,
9161 compute_precision: MatmulCompute,
9162 a_off_f32: u32,
9163 m: u32,
9164 k: u32,
9165 batch: u32,
9166 b_off_f32: u32,
9167 n: u32,
9168) {
9169 if compute_precision != MatmulCompute::CoopF16Vk {
9170 return;
9171 }
9172 let batch_n = batch.max(1);
9173 if !traces_to_input(graph, a_id)
9174 && !traces_to_param(graph, a_id)
9175 && !mirror_acts.contains(&a_id)
9176 {
9177 let a_elems = m.saturating_mul(k).saturating_mul(batch_n);
9178 let (base, size) = arena_window_for_nodes(device, arena, &[a_id]);
9179 push_cast_f32_to_f16_step(
9180 device,
9181 arena,
9182 base,
9183 size,
9184 schedule,
9185 uniforms,
9186 bind_groups,
9187 mm_cast,
9188 a_off_f32,
9189 a_elems,
9190 );
9191 }
9192 if !traces_to_input(graph, b_id)
9193 && !traces_to_param(graph, b_id)
9194 && !mirror_acts.contains(&b_id)
9195 {
9196 let b_elems = k.saturating_mul(n).saturating_mul(batch_n);
9197 let (base, size) = arena_window_for_nodes(device, arena, &[b_id]);
9198 push_cast_f32_to_f16_step(
9199 device,
9200 arena,
9201 base,
9202 size,
9203 schedule,
9204 uniforms,
9205 bind_groups,
9206 mm_cast,
9207 b_off_f32,
9208 b_elems,
9209 );
9210 }
9211}
9212
9213fn build_matmul_qkv_coop_f16_vk_bind_group(
9214 device: &wgpu::Device,
9215 mqk: &Kernel,
9216 arena: &Arena,
9217 arena_base: u64,
9218 arena_size: u64,
9219 params: &wgpu::Buffer,
9220 k: u32,
9221 n: u32,
9222 b_off: u32,
9223) -> (wgpu::BindGroup, u32) {
9224 let f16_buf = arena
9225 .f16_buffer
9226 .as_ref()
9227 .expect("CoopF16Vk QKV requires SHADER_F16 f16 shadow arena");
9228 let (f16_res, rebased_b) = {
9229 let (base, size, rebased) =
9230 f16_weight_bind_range(device, f16_buf.size(), b_off, k, n, 1, 0);
9231 (
9232 wgpu::BindingResource::Buffer(wgpu::BufferBinding {
9233 buffer: f16_buf,
9234 offset: base,
9235 size: NonZeroU64::new(size),
9236 }),
9237 rebased,
9238 )
9239 };
9240 (
9241 device.create_bind_group(&wgpu::BindGroupDescriptor {
9242 label: Some("rlx-wgpu matmul_qkv_coop_f16_vk bg"),
9243 layout: &mqk.bgl,
9244 entries: &[
9245 wgpu::BindGroupEntry {
9246 binding: 0,
9247 resource: f16_res,
9248 },
9249 wgpu::BindGroupEntry {
9250 binding: 1,
9251 resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
9252 buffer: &arena.buffer,
9253 offset: arena_base,
9254 size: NonZeroU64::new(arena_size),
9255 }),
9256 },
9257 wgpu::BindGroupEntry {
9258 binding: 2,
9259 resource: params.as_entire_binding(),
9260 },
9261 ],
9262 }),
9263 rebased_b,
9264 )
9265}
9266fn push_cast_f32_to_f16_step(
9270 device: &wgpu::Device,
9271 arena: &Arena,
9272 arena_base: u64,
9273 arena_size: u64,
9274 schedule: &mut Vec<Step>,
9275 uniforms: &mut Vec<wgpu::Buffer>,
9276 bind_groups: &mut Vec<wgpu::BindGroup>,
9277 mm_cast: &Option<&'static Kernel>,
9278 src_off: u32,
9279 len: u32,
9280) {
9281 let kernel = match mm_cast {
9282 Some(k) => *k,
9283 None => return, };
9285 let f16_buf = match &arena.f16_buffer {
9286 Some(b) => b,
9287 None => return,
9288 };
9289 let p = CastF32ToF16Params {
9290 src_off: src_off.saturating_sub((arena_base / 4) as u32),
9291 len,
9292 _p0: 0,
9293 _p1: 0,
9294 };
9295 let u = device.create_buffer(&wgpu::BufferDescriptor {
9296 label: Some("rlx-wgpu cast_f32_to_f16 uniform"),
9297 size: std::mem::size_of::<CastF32ToF16Params>() as u64,
9298 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
9299 mapped_at_creation: false,
9300 });
9301 let dev = wgpu_device().expect("rlx-wgpu: device gone");
9303 dev.queue.write_buffer(&u, 0, bytemuck::bytes_of(&p));
9304 let (f16_base, f16_size) = f16_shadow_bind_range(arena_base, arena_size, f16_buf.size());
9305 let bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
9306 label: Some("rlx-wgpu cast_f32_to_f16 bg"),
9307 layout: &kernel.bgl,
9308 entries: &[
9309 wgpu::BindGroupEntry {
9310 binding: 0,
9311 resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
9312 buffer: f16_buf,
9313 offset: f16_base,
9314 size: NonZeroU64::new(f16_size),
9315 }),
9316 },
9317 wgpu::BindGroupEntry {
9318 binding: 1,
9319 resource: u.as_entire_binding(),
9320 },
9321 wgpu::BindGroupEntry {
9322 binding: 2,
9323 resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
9324 buffer: &arena.buffer,
9325 offset: arena_base,
9326 size: NonZeroU64::new(arena_size),
9327 }),
9328 },
9329 ],
9330 });
9331 schedule.push(Step::CastF32ToF16 { params: p });
9332 uniforms.push(u);
9333 bind_groups.push(bg);
9334}
9335
9336fn build_matmul_bind_group(
9340 device: &wgpu::Device,
9341 mm_k: &Kernel,
9342 _mm_w: &Kernel,
9343 mm_f16w: &Option<&'static Kernel>,
9344 mm_f16c: &Option<&'static Kernel>,
9345 mm_coop: &Option<&'static Kernel>,
9346 mm_coop_f32: &Option<&'static Kernel>,
9347 arena: &Arena,
9348 arena_base: u64,
9349 arena_size: u64,
9350 params: &wgpu::Buffer,
9351 b_is_param: bool,
9352 compute_precision: MatmulCompute,
9353 k: u32,
9354 n: u32,
9355 batch: u32,
9356 b_off: u32,
9357 b_batch_stride: u32,
9358) -> (wgpu::BindGroup, u32) {
9359 let f16_bind = |b_off: u32| -> (wgpu::BindingResource<'_>, u32) {
9360 let f16_buf = arena
9361 .f16_buffer
9362 .as_ref()
9363 .expect("f16 weight bind without f16_buffer");
9364 let (base, size, rebased) =
9365 f16_weight_bind_range(device, f16_buf.size(), b_off, k, n, batch, b_batch_stride);
9366 (
9367 wgpu::BindingResource::Buffer(wgpu::BufferBinding {
9368 buffer: f16_buf,
9369 offset: base,
9370 size: NonZeroU64::new(size),
9371 }),
9372 rebased,
9373 )
9374 };
9375 if compute_precision == MatmulCompute::CoopF16Vk
9376 && let (Some(coop_vk), Some(_f16_buf)) =
9377 (matmul_coop_f16_vulkan_kernel(device), &arena.f16_buffer)
9378 {
9379 let (f16_res, rebased_b) = f16_bind(b_off);
9380 return (
9381 device.create_bind_group(&wgpu::BindGroupDescriptor {
9382 label: Some("rlx-wgpu matmul_coop_f16_vulkan bg"),
9383 layout: &coop_vk.bgl,
9384 entries: &[
9385 wgpu::BindGroupEntry {
9386 binding: 0,
9387 resource: f16_res,
9388 },
9389 wgpu::BindGroupEntry {
9390 binding: 1,
9391 resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
9392 buffer: &arena.buffer,
9393 offset: arena_base,
9394 size: NonZeroU64::new(arena_size),
9395 }),
9396 },
9397 wgpu::BindGroupEntry {
9398 binding: 2,
9399 resource: params.as_entire_binding(),
9400 },
9401 ],
9402 }),
9403 rebased_b,
9404 );
9405 }
9406 if b_is_param
9407 && compute_precision == MatmulCompute::CoopF32
9408 && let Some(coop_f32) = mm_coop_f32
9409 {
9410 return (
9413 device.create_bind_group(&wgpu::BindGroupDescriptor {
9414 label: Some("rlx-wgpu matmul_coop_f32 bg"),
9415 layout: &coop_f32.bgl,
9416 entries: &[
9417 wgpu::BindGroupEntry {
9418 binding: 0,
9419 resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
9420 buffer: &arena.buffer,
9421 offset: arena_base,
9422 size: NonZeroU64::new(arena_size),
9423 }),
9424 },
9425 wgpu::BindGroupEntry {
9426 binding: 1,
9427 resource: params.as_entire_binding(),
9428 },
9429 ],
9430 }),
9431 b_off,
9432 );
9433 }
9434 if b_is_param
9435 && compute_precision == MatmulCompute::Coop16
9436 && let (Some(_f16_buf), Some(coop)) = (&arena.f16_buffer, mm_coop)
9437 {
9438 let (f16_res, rebased_b) = f16_bind(b_off);
9439 return (
9443 device.create_bind_group(&wgpu::BindGroupDescriptor {
9444 label: Some("rlx-wgpu matmul_coop16 bg"),
9445 layout: &coop.bgl,
9446 entries: &[
9447 wgpu::BindGroupEntry {
9448 binding: 0,
9449 resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
9450 buffer: &arena.buffer,
9451 offset: arena_base,
9452 size: NonZeroU64::new(arena_size),
9453 }),
9454 },
9455 wgpu::BindGroupEntry {
9456 binding: 1,
9457 resource: params.as_entire_binding(),
9458 },
9459 wgpu::BindGroupEntry {
9460 binding: 2,
9461 resource: f16_res,
9462 }, ],
9464 }),
9465 rebased_b,
9466 );
9467 }
9468 if b_is_param
9469 && compute_precision == MatmulCompute::F16
9470 && let (Some(_f16_buf), Some(f16c)) = (&arena.f16_buffer, mm_f16c)
9471 {
9472 let (f16_res, rebased_b) = f16_bind(b_off);
9473 return (
9474 device.create_bind_group(&wgpu::BindGroupDescriptor {
9475 label: Some("rlx-wgpu matmul_f16_compute bg"),
9476 layout: &f16c.bgl,
9477 entries: &[
9478 wgpu::BindGroupEntry {
9479 binding: 0,
9480 resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
9481 buffer: &arena.buffer,
9482 offset: arena_base,
9483 size: NonZeroU64::new(arena_size),
9484 }),
9485 },
9486 wgpu::BindGroupEntry {
9487 binding: 1,
9488 resource: params.as_entire_binding(),
9489 },
9490 wgpu::BindGroupEntry {
9491 binding: 2,
9492 resource: f16_res,
9493 },
9494 ],
9495 }),
9496 rebased_b,
9497 );
9498 }
9499 let f16w_opt_in = rlx_ir::env::flag("RLX_WGPU_F16_WEIGHTS");
9500 if b_is_param
9501 && f16w_opt_in
9502 && let (Some(_f16_buf), Some(f16w)) = (&arena.f16_buffer, mm_f16w)
9503 {
9504 let (f16_res, rebased_b) = f16_bind(b_off);
9505 return (
9506 device.create_bind_group(&wgpu::BindGroupDescriptor {
9507 label: Some("rlx-wgpu matmul_f16w bg"),
9508 layout: &f16w.bgl,
9509 entries: &[
9510 wgpu::BindGroupEntry {
9511 binding: 0,
9512 resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
9513 buffer: &arena.buffer,
9514 offset: arena_base,
9515 size: NonZeroU64::new(arena_size),
9516 }),
9517 },
9518 wgpu::BindGroupEntry {
9519 binding: 1,
9520 resource: params.as_entire_binding(),
9521 },
9522 wgpu::BindGroupEntry {
9523 binding: 2,
9524 resource: f16_res,
9525 },
9526 ],
9527 }),
9528 rebased_b,
9529 );
9530 }
9531 (
9532 bind_two_buf0_window(device, mm_k, &arena.buffer, arena_base, arena_size, params),
9533 b_off,
9534 )
9535}