1use std::collections::HashMap;
27use std::sync::{Arc, Mutex, Once};
28
29use cudarc::cublas::{CudaBlas, sys as cublas_sys};
30use cudarc::cublaslt::{result as cublaslt_result, sys as cublaslt_sys};
31use cudarc::cudnn::{result as cudnn_result, sys as cudnn_sys};
32use cudarc::driver::{CudaContext, DevicePtrMut, LaunchConfig, PushKernelArg};
33use rlx_ir::op::{Activation, BinaryOp, CmpOp, MaskKind, ReduceOp};
34use rlx_ir::{Graph, NodeId, Op};
35use rlx_opt::rlx_fusion::lower_reduce_axes::LowerNonLastAxisReduce;
36use rlx_opt::rlx_fusion::pass::Pass as _;
37
38use crate::arena::{Arena, plan_f32_uniform};
39use crate::device::{
40 CUBLASLT_WORKSPACE_BYTES, CUDNN_WORKSPACE_BYTES, cuda_blas, cuda_blas_lt_handle,
41 cuda_blas_lt_workspace, cuda_context, cuda_dnn_handle, cuda_dnn_workspace,
42};
43use crate::host_staging::F32HostSlot;
44use crate::kernels::{
45 argmax_kernel, attention_bwd_kernel, attention_kernel, attention_row_kernel,
46 batch_elementwise_region_kernel, binary_kernel, compare_kernel, concat_kernel,
47 conv_transpose2d_kernel, conv1d_kernel, conv2d_kernel, conv3d_kernel, cumsum_backward_kernel,
48 cumsum_kernel, dequant_matmul_kernel, dispatch_grid_1d, dispatch_grid_prologue_nchw,
49 elementwise_region_kernel, expand_kernel, fused_binary_unary_kernel, fused_residual_ln_kernel,
50 fused_residual_rms_norm_kernel, gather_axis_kernel, gather_backward_kernel, gather_kernel,
51 group_norm_kernel, grouped_matmul_kernel, im2col_kernel, layer_norm2d_kernel, layernorm_kernel,
52 matmul_epilogue_kernel, matmul_kernel, matmul_wmma_kernel, narrow_kernel, pool1d_kernel,
53 pool2d_kernel, pool3d_kernel, reduce_kernel, resize_nearest_2x_kernel,
54 rms_norm_backward_kernel, rms_norm_bwd_zero_kernel, rope_backward_kernel, rope_kernel,
55 sample_kernel, scatter_add_acc_kernel, scatter_add_zero_kernel, selective_scan_kernel,
56 softmax_kernel, topk_kernel, transpose_kernel, unary_kernel, where_kernel,
57};
58
59fn use_wmma() -> bool {
65 use std::sync::OnceLock;
66 static FLAG: OnceLock<bool> = OnceLock::new();
67 *FLAG.get_or_init(|| {
68 rlx_ir::env::var("RLX_CUDA_WMMA")
69 .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
70 .unwrap_or(false)
71 })
72}
73
74fn matmul_parity_mode() -> bool {
77 use std::sync::OnceLock;
78 static FLAG: OnceLock<bool> = OnceLock::new();
79 *FLAG.get_or_init(|| {
80 rlx_ir::env::flag("RLX_CUDA_NO_TF32")
81 || rlx_ir::env::flag("RLX_CUDA_PARITY")
82 || rlx_ir::env::flag("RLX_CUDA_NO_CUBLASLT")
83 })
84}
85
/// One entry of the compiled execution schedule.
///
/// Each variant carries the pre-resolved launch parameters for a single
/// operation. Offsets locate operands inside the device arena:
/// - `*_off_f32` fields are counted in f32 elements (`cublaslt_matmul_fused`
///   scales them by 4 to get a byte address).
/// - `*_byte_off` fields are, going by their name, byte offsets — TODO confirm.
/// - Plain `*_off` fields are presumably f32-element offsets — TODO confirm.
/// - `*_bits` fields presumably hold `f32::to_bits` encodings — TODO confirm.
///
/// `Step::safe_for_active_extent` and `Step::graph_capture_safe` classify the
/// variants for partial-extent execution and CUDA-graph capture respectively.
#[derive(Clone)]
enum Step {
    /// Dense, optionally batched matmul.
    ///
    /// `act_id` values 0xFFFF / 0 / 9 / 11 are accepted by
    /// `cublaslt_act_supported`. `cublaslt_act_for` maps them to no epilogue /
    /// ReLU / GELU / GELU. `schedule_needs_blas_lt` requests a cuBLASLt handle
    /// for any such step. `has_bias` is presumably a 0/1 flag — TODO confirm.
    Matmul {
        m: u32,
        k: u32,
        n: u32,
        a_off_f32: u32,
        b_off_f32: u32,
        c_off_f32: u32,
        batch: u32,
        a_batch_stride: u32,
        b_batch_stride: u32,
        c_batch_stride: u32,
        has_bias: u32,
        bias_off_f32: u32,
        act_id: u32,
    },
    Binary {
        n: u32,
        a_off: u32,
        b_off: u32,
        c_off: u32,
        op: u32,
    },
    Compare {
        n: u32,
        a_off: u32,
        b_off: u32,
        c_off: u32,
        op: u32,
    },
    Unary {
        n: u32,
        in_off: u32,
        out_off: u32,
        op: u32,
    },
    Where {
        n: u32,
        cond_off: u32,
        x_off: u32,
        y_off: u32,
        out_off: u32,
    },
    Reduce {
        outer: u32,
        inner: u32,
        in_off: u32,
        out_off: u32,
        op: u32,
    },
    Softmax {
        outer: u32,
        inner: u32,
        in_off: u32,
        out_off: u32,
    },
    LayerNorm {
        outer: u32,
        inner: u32,
        in_off: u32,
        out_off: u32,
        gamma_off: u32,
        beta_off: u32,
        eps_bits: u32,
        op: u32,
    },
    FusedResidualLn {
        outer: u32,
        inner: u32,
        in_off: u32,
        residual_off: u32,
        bias_off: u32,
        gamma_off: u32,
        beta_off: u32,
        out_off: u32,
        eps_bits: u32,
        has_bias: u32,
    },
    FusedResidualRmsNorm {
        outer: u32,
        inner: u32,
        in_off: u32,
        residual_off: u32,
        bias_off: u32,
        gamma_off: u32,
        beta_off: u32,
        out_off: u32,
        eps_bits: u32,
        has_bias: u32,
    },
    Gather {
        n_out: u32,
        n_idx: u32,
        dim: u32,
        vocab: u32,
        in_off: u32,
        idx_off: u32,
        out_off: u32,
    },
    GatherAxis {
        total: u32,
        outer: u32,
        axis_dim: u32,
        num_idx: u32,
        trailing: u32,
        table_off: u32,
        idx_off: u32,
        out_off: u32,
    },
    Narrow {
        total: u32,
        outer: u32,
        inner: u32,
        axis_in_size: u32,
        axis_out_size: u32,
        start: u32,
        in_off: u32,
        out_off: u32,
    },
    Argmax {
        outer: u32,
        inner: u32,
        in_off: u32,
        out_off: u32,
    },
    /// `meta_idx` presumably indexes `CudaExecutable::meta_buffers` — TODO confirm.
    Transpose {
        rank: u32,
        out_total: u32,
        in_off: u32,
        out_off: u32,
        meta_idx: usize,
    },
    /// `meta_idx` presumably indexes `CudaExecutable::meta_buffers` — TODO confirm.
    Expand {
        rank: u32,
        out_total: u32,
        in_off: u32,
        out_off: u32,
        meta_idx: usize,
    },
    Concat {
        total: u32,
        outer: u32,
        inner: u32,
        axis_in_size: u32,
        axis_out_size: u32,
        start: u32,
        in_off: u32,
        out_off: u32,
    },
    /// Attention with explicit per-tensor strides for Q, K, V, the output and the mask.
    Attention {
        batch: u32,
        heads: u32,
        seq_q: u32,
        seq_k: u32,
        head_dim: u32,
        q_off: u32,
        k_off: u32,
        v_off: u32,
        out_off: u32,
        mask_off: u32,
        mask_kind: u32,
        scale_bits: u32,
        window: u32,
        seq_q_stride: u32,
        seq_k_stride: u32,
        mask_batch_stride: u32,
        mask_head_stride: u32,
        q_batch_stride: u32,
        q_head_stride: u32,
        q_seq_stride: u32,
        k_batch_stride: u32,
        k_head_stride: u32,
        k_seq_stride: u32,
        v_batch_stride: u32,
        v_head_stride: u32,
        v_seq_stride: u32,
        o_batch_stride: u32,
        o_head_stride: u32,
        o_seq_stride: u32,
    },
    /// `wrt` presumably selects which input (Q, K or V) the gradient is for — TODO confirm.
    AttentionBackward {
        batch: u32,
        heads: u32,
        seq_q: u32,
        seq_k: u32,
        head_dim: u32,
        q_off: u32,
        k_off: u32,
        v_off: u32,
        dy_off: u32,
        out_off: u32,
        mask_off: u32,
        mask_kind: u32,
        scale_bits: u32,
        window: u32,
        wrt: u32,
    },
    Rope {
        n_total: u32,
        seq: u32,
        head_dim: u32,
        half: u32,
        in_off: u32,
        cos_off: u32,
        sin_off: u32,
        out_off: u32,
        last_dim: u32,
    },
    Cumsum {
        outer: u32,
        inner: u32,
        in_off: u32,
        out_off: u32,
        exclusive: u32,
    },
    TopK {
        outer: u32,
        inner: u32,
        k: u32,
        in_off: u32,
        out_off: u32,
    },
    GroupedMatmul {
        m: u32,
        k: u32,
        n: u32,
        num_experts: u32,
        in_off: u32,
        w_off: u32,
        idx_off: u32,
        out_off: u32,
    },
    ScatterAddZero {
        out_off: u32,
        out_total: u32,
    },
    ScatterAddAcc {
        out_off: u32,
        upd_off: u32,
        idx_off: u32,
        num_updates: u32,
        trailing: u32,
        out_dim: u32,
    },
    DequantMatmul {
        m: u32,
        k: u32,
        n: u32,
        block_size: u32,
        scheme_id: u32,
        x_off: u32,
        w_off: u32,
        scale_off: u32,
        zp_off: u32,
        out_off: u32,
    },
    DequantMatmulGguf {
        m: u32,
        k: u32,
        n: u32,
        scheme_id: u32,
        x_byte_off: u32,
        w_byte_off: u32,
        out_byte_off: u32,
    },
    DequantGroupedMatmulGguf {
        m: u32,
        k: u32,
        n: u32,
        num_experts: u32,
        scheme_id: u32,
        x_byte_off: u32,
        w_byte_off: u32,
        idx_byte_off: u32,
        out_byte_off: u32,
    },
    /// Token sampling. The 64-bit seed is split into `seed_lo` / `seed_hi`
    /// (presumably low/high halves — TODO confirm).
    Sample {
        outer: u32,
        inner: u32,
        in_off: u32,
        out_off: u32,
        top_k: u32,
        top_p_bits: u32,
        temp_bits: u32,
        seed_lo: u32,
        seed_hi: u32,
    },
    SelectiveScan {
        batch: u32,
        seq: u32,
        hidden: u32,
        state_size: u32,
        x_off: u32,
        delta_off: u32,
        a_off: u32,
        b_off: u32,
        c_off: u32,
        out_off: u32,
    },
    /// Graph-capture safe only when `use_gpu` is true (see `graph_capture_safe`).
    Fft {
        src_byte_off: u32,
        dst_byte_off: u32,
        outer: u32,
        n_complex: u32,
        inverse: bool,
        norm_tag: u32,
        dtype_tag: u32,
        use_gpu: bool,
    },
    /// Executed by `run_tail_host_audio_ops`; never graph-capture safe.
    LogMelHost {
        spec_byte_off: u32,
        filt_byte_off: u32,
        dst_byte_off: u32,
        outer: u32,
        n_fft: u32,
        n_bins: u32,
        n_mels: u32,
    },
    /// Executed by `run_tail_host_audio_ops`; never graph-capture safe.
    LogMelBackwardHost {
        spec_byte_off: u32,
        filt_byte_off: u32,
        dy_byte_off: u32,
        dst_byte_off: u32,
        outer: u32,
        n_fft: u32,
        n_bins: u32,
        n_mels: u32,
    },
    /// Executed by `run_tail_host_audio_ops`; never graph-capture safe.
    WelchPeaksHost {
        spec_byte_off: u32,
        dst_byte_off: u32,
        welch_batch: u32,
        n_fft: u32,
        n_segments: u32,
        k: u32,
    },
    WelchPeaksGpu {
        spec_off: u32,
        dst_off: u32,
        welch_batch: u32,
        n_fft: u32,
        n_segments: u32,
        k: u32,
        n_bins: u32,
    },
    /// Graph-capture safe only when `use_gpu` is true (see `graph_capture_safe`).
    Im2ColHost {
        x_byte_off: u32,
        col_byte_off: u32,
        n: u32,
        c_in: u32,
        h: u32,
        w: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
        dh: u32,
        dw_dil: u32,
        use_gpu: bool,
    },
    /// Reported as not graph-capture safe by `graph_capture_safe`.
    GatedDeltaNet {
        q_byte_off: u32,
        k_byte_off: u32,
        v_byte_off: u32,
        g_byte_off: u32,
        beta_byte_off: u32,
        state_byte_off: u32,
        dst_byte_off: u32,
        batch: u32,
        seq: u32,
        heads: u32,
        state_size: u32,
        use_carry: bool,
    },
    /// Reported as not graph-capture safe by `graph_capture_safe`. `attrs` is an
    /// opaque 20-byte attribute payload; its layout is not visible here.
    Llada2GroupLimitedGate {
        sig_off: u32,
        route_off: u32,
        out_off: u32,
        n_elems: u32,
        attrs: [u8; 20],
    },
    /// Reported as not graph-capture safe by `graph_capture_safe`.
    UmapKnn {
        pairwise_off: u32,
        out_off: u32,
        n: u32,
        k: u32,
    },
    /// Reported as not graph-capture safe by `graph_capture_safe`.
    GaussianSplatRender {
        positions_off: u32,
        positions_len: u32,
        scales_off: u32,
        scales_len: u32,
        rotations_off: u32,
        rotations_len: u32,
        opacities_off: u32,
        opacities_len: u32,
        colors_off: u32,
        colors_len: u32,
        sh_coeffs_off: u32,
        sh_coeffs_len: u32,
        meta_off: u32,
        dst_off: u32,
        dst_len: u32,
        width: u32,
        height: u32,
        tile_size: u32,
        radius_scale: f32,
        alpha_cutoff: f32,
        max_splat_steps: u32,
        transmittance_threshold: f32,
        max_list_entries: u32,
    },
    /// Reported as not graph-capture safe by `graph_capture_safe`.
    GaussianSplatRenderBackward {
        positions_off: u32,
        positions_len: u32,
        scales_off: u32,
        scales_len: u32,
        rotations_off: u32,
        rotations_len: u32,
        opacities_off: u32,
        opacities_len: u32,
        colors_off: u32,
        colors_len: u32,
        sh_coeffs_off: u32,
        sh_coeffs_len: u32,
        meta_off: u32,
        d_loss_off: u32,
        d_loss_len: u32,
        packed_off: u32,
        packed_len: u32,
        width: u32,
        height: u32,
        tile_size: u32,
        radius_scale: f32,
        alpha_cutoff: f32,
        max_splat_steps: u32,
        transmittance_threshold: f32,
        max_list_entries: u32,
        loss_grad_clip: f32,
        sh_band: u32,
        max_anisotropy: f32,
    },
    /// Reported as not graph-capture safe by `graph_capture_safe`.
    GaussianSplatPrepare {
        positions_off: u32,
        positions_len: u32,
        scales_off: u32,
        scales_len: u32,
        rotations_off: u32,
        rotations_len: u32,
        opacities_off: u32,
        opacities_len: u32,
        colors_off: u32,
        colors_len: u32,
        sh_coeffs_off: u32,
        sh_coeffs_len: u32,
        meta_off: u32,
        meta_len: u32,
        prep_off: u32,
        prep_len: u32,
        width: u32,
        height: u32,
        tile_size: u32,
        radius_scale: f32,
        alpha_cutoff: f32,
        max_splat_steps: u32,
        transmittance_threshold: f32,
        max_list_entries: u32,
    },
    GaussianSplatRasterize {
        prep_off: u32,
        prep_len: u32,
        meta_off: u32,
        meta_len: u32,
        dst_off: u32,
        dst_len: u32,
        count: u32,
        width: u32,
        height: u32,
        tile_size: u32,
        alpha_cutoff: f32,
        max_splat_steps: u32,
        transmittance_threshold: f32,
        max_list_entries: u32,
    },
    RmsNormBackwardInput {
        x_byte_off: u32,
        gamma_byte_off: u32,
        beta_byte_off: u32,
        dy_byte_off: u32,
        dx_byte_off: u32,
        rows: u32,
        h: u32,
        eps_bits: u32,
    },
    RmsNormBackwardGamma {
        x_byte_off: u32,
        gamma_byte_off: u32,
        beta_byte_off: u32,
        dy_byte_off: u32,
        dgamma_byte_off: u32,
        rows: u32,
        h: u32,
        eps_bits: u32,
    },
    RmsNormBackwardBeta {
        x_byte_off: u32,
        gamma_byte_off: u32,
        beta_byte_off: u32,
        dy_byte_off: u32,
        dbeta_byte_off: u32,
        rows: u32,
        h: u32,
        eps_bits: u32,
    },
    RopeBackward {
        dy_byte_off: u32,
        cos_byte_off: u32,
        sin_byte_off: u32,
        dx_byte_off: u32,
        batch: u32,
        seq: u32,
        hidden: u32,
        head_dim: u32,
        n_rot: u32,
        cos_len: u32,
    },
    CumsumBackward {
        dy_byte_off: u32,
        dx_byte_off: u32,
        rows: u32,
        cols: u32,
        exclusive: bool,
    },
    GatherBackward {
        dy_byte_off: u32,
        indices_byte_off: u32,
        dst_byte_off: u32,
        outer: u32,
        axis_dim: u32,
        num_idx: u32,
        trailing: u32,
    },
    MaxPool2dBackward {
        x_byte_off: u32,
        dy_byte_off: u32,
        dx_byte_off: u32,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
    },
    Conv2dBackwardInput {
        dy_byte_off: u32,
        w_byte_off: u32,
        dx_byte_off: u32,
        n: u32,
        c_in: u32,
        h: u32,
        w_in: u32,
        c_out: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
        dh: u32,
        dw: u32,
        groups: u32,
    },
    Conv2dBackwardWeight {
        x_byte_off: u32,
        dy_byte_off: u32,
        dw_byte_off: u32,
        n: u32,
        c_in: u32,
        h: u32,
        w: u32,
        c_out: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
        dh: u32,
        dw_dil: u32,
        groups: u32,
    },
    Pool1d {
        n: u32,
        c: u32,
        l: u32,
        l_out: u32,
        kl: u32,
        sl: u32,
        pl: u32,
        op: u32,
        in_off: u32,
        out_off: u32,
    },
    Pool2d {
        n: u32,
        c: u32,
        h: u32,
        w: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
        op: u32,
        in_off: u32,
        out_off: u32,
    },
    Pool3d {
        n: u32,
        c: u32,
        d: u32,
        h: u32,
        w: u32,
        d_out: u32,
        h_out: u32,
        w_out: u32,
        kd: u32,
        kh: u32,
        kw: u32,
        sd: u32,
        sh: u32,
        sw: u32,
        pd: u32,
        ph: u32,
        pw: u32,
        op: u32,
        in_off: u32,
        out_off: u32,
    },
    /// Its presence makes `schedule_needs_dnn` return true.
    Conv1d {
        n: u32,
        c_in: u32,
        c_out: u32,
        l: u32,
        l_out: u32,
        kl: u32,
        sl: u32,
        pl: u32,
        dl: u32,
        groups: u32,
        in_off: u32,
        w_off: u32,
        out_off: u32,
    },
    /// Its presence makes `schedule_needs_dnn` return true.
    Conv2d {
        n: u32,
        c_in: u32,
        c_out: u32,
        h: u32,
        w: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
        dh: u32,
        dw: u32,
        groups: u32,
        in_off: u32,
        w_off: u32,
        out_off: u32,
    },
    /// Its presence makes `schedule_needs_dnn` return true.
    Conv3d {
        n: u32,
        c_in: u32,
        c_out: u32,
        d: u32,
        h: u32,
        w: u32,
        d_out: u32,
        h_out: u32,
        w_out: u32,
        kd: u32,
        kh: u32,
        kw: u32,
        sd: u32,
        sh: u32,
        sw: u32,
        pd: u32,
        ph: u32,
        pw: u32,
        dd: u32,
        dh: u32,
        dw: u32,
        groups: u32,
        in_off: u32,
        w_off: u32,
        out_off: u32,
    },
    LayerNorm2d {
        src_off: u32,
        g_off: u32,
        b_off: u32,
        dst_off: u32,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
        eps_bits: u32,
    },
    ConvTranspose2d {
        src_off: u32,
        w_off: u32,
        dst_off: u32,
        n: u32,
        c_in: u32,
        h: u32,
        w_in: u32,
        c_out: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
        dh: u32,
        dw: u32,
        groups: u32,
    },
    GroupNorm {
        src_off: u32,
        g_off: u32,
        b_off: u32,
        dst_off: u32,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
        num_groups: u32,
        eps_bits: u32,
    },
    ResizeNearest2x {
        src_off: u32,
        dst_off: u32,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
    },
    /// A binary op followed by a unary op (`bin_op`, then `un_op`, presumably —
    /// TODO confirm ordering in the kernel).
    FusedBinaryUnary {
        n: u32,
        a_off: u32,
        b_off: u32,
        out_off: u32,
        bin_op: u32,
        un_op: u32,
    },
    /// Fused elementwise region over up to 16 inputs.
    ///
    /// `scalar_input_mask` presumably has one bit per input and
    /// `input_modulus` holds per-input broadcast moduli — TODO confirm.
    /// `meta_idx` presumably indexes `CudaExecutable::meta_buffers`.
    /// The `prologue_*` fields apply when `spatial_prologue` is set.
    ElementwiseRegion {
        len: u32,
        num_inputs: u32,
        num_steps: u32,
        dst_off: u32,
        input_offs: [u32; 16],
        scalar_input_mask: u32,
        input_modulus: [u32; 16],
        meta_idx: usize,
        spatial_prologue: bool,
        prologue_w: u32,
        prologue_h: u32,
        prologue_nc: u32,
    },
    /// Batched variant of `ElementwiseRegion`; `batch_offs_idx` and `meta_idx`
    /// presumably index device-side metadata buffers — TODO confirm.
    BatchElementwiseRegion {
        slice_len: u32,
        num_batch: u32,
        num_steps: u32,
        base_dst_off: u32,
        slice_elems: u32,
        batch_input_offs: [u32; 64],
        batch_offs_idx: usize,
        meta_idx: usize,
        scalar_input_mask: u32,
        input_modulus: [u32; 16],
    },
}
930
/// Kernel compilation strategy; `Jit` is the default.
///
/// NOTE(review): how each mode is consumed is not visible here. Presumably
/// `Jit` compiles kernels at runtime and `Aot` uses precompiled ones — confirm
/// against the kernel-loading code.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CompileMode {
    #[default]
    Jit,
    Aot,
}
944
/// How a compiled schedule is dispatched; `Stream` is the default.
///
/// NOTE(review): semantics are inferred from `CudaExecutable` fields, not shown
/// directly. `Graph` presumably replays a captured CUDA graph (`captured_graph`).
/// `MultiStream(n)` presumably spreads work over `n` streams (`streams`).
/// `Eager` presumably launches step by step — confirm against the dispatcher.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ExecMode {
    #[default]
    Stream,
    Graph,
    Eager,
    // Payload presumably is the stream count — TODO confirm.
    MultiStream(usize),
}
966
/// A graph compiled for execution on a CUDA device.
///
/// Holds the CUDA context, optional library handles and their workspaces, the
/// device arena together with the linear `schedule` of `Step`s, and host-side
/// staging/bookkeeping for inputs, outputs and readback.
pub struct CudaExecutable {
    ctx: Arc<CudaContext>,
    // Shared cuBLAS handle; `None` presumably when no step needs it — TODO confirm.
    blas: Option<Arc<Mutex<CudaBlas>>>,
    // Raw cuBLASLt handle; presumably only created when `schedule_needs_blas_lt`
    // is true — TODO confirm.
    blas_lt: Option<cublaslt_sys::cublasLtHandle_t>,
    // Device workspace handed to cuBLASLt matmuls.
    blas_lt_workspace: Option<Arc<Mutex<cudarc::driver::CudaSlice<u8>>>>,
    // Raw cuDNN handle; presumably only created when `schedule_needs_dnn` is
    // true — TODO confirm.
    dnn: Option<cudnn_sys::cudnnHandle_t>,
    // Device workspace handed to cuDNN convolutions.
    dnn_workspace: Option<Arc<Mutex<cudarc::driver::CudaSlice<u8>>>>,
    // 16-bit scratch buffer; use not visible here (presumably half-precision
    // activations — TODO confirm).
    half_act_scratch: Option<cudarc::driver::CudaSlice<u16>>,
    // Offset of the dequantization scratch area (units not visible here).
    dequant_scratch_off: usize,
    graph: Graph,
    arena: Arena,
    // Ordered list of steps executed per run.
    schedule: Vec<Step>,
    // Despite the name, these map names to graph `NodeId`s.
    input_offsets: HashMap<String, NodeId>,
    param_offsets: HashMap<String, NodeId>,
    // Device-side u32 metadata; presumably indexed by the `meta_idx` fields of
    // `Step` — TODO confirm.
    meta_buffers: Vec<cudarc::driver::CudaSlice<u32>>,
    exec_mode: ExecMode,
    // Captured CUDA graph, presumably used when `exec_mode` is `Graph`.
    captured_graph: Option<cudarc::driver::CudaGraph>,
    streams: Vec<Arc<cudarc::driver::CudaStream>>,
    // Optional restricted extent; presumably only honoured when every step
    // reports `safe_for_active_extent` — TODO confirm.
    pub(crate) active_extent: Option<(usize, usize)>,
    output_staging: Vec<F32HostSlot>,
    input_staging: HashMap<String, F32HostSlot>,
    replay_event: Option<cudarc::driver::CudaEvent>,
    gpu_handles: HashMap<String, Vec<f32>>,
    gpu_handle_feeds: HashMap<String, usize>,
    gpu_handle_resident: std::collections::HashSet<String>,
    pending_read_indices: Option<Vec<usize>>,
    // Reusable buffer for building the readback plan.
    readback_plan_buf: Vec<usize>,
    captured_readback_plan: Option<Vec<usize>>,
    input_slot_names: Vec<String>,
    // `(usize, usize)` pairs; presumably (offset, length) — TODO confirm.
    input_slots: Vec<(usize, usize)>,
    output_slots: Vec<(usize, usize)>,
    host_arena: Vec<f32>,
}
1038
1039impl Step {
1040 pub fn safe_for_active_extent(&self) -> bool {
1048 matches!(
1049 self,
1050 Step::Binary { .. }
1051 | Step::Compare { .. }
1052 | Step::Unary { .. }
1053 | Step::Where { .. }
1054 | Step::Reduce { .. }
1055 | Step::Softmax { .. }
1056 | Step::LayerNorm { .. }
1057 | Step::FusedResidualLn { .. }
1058 | Step::FusedResidualRmsNorm { .. }
1059 | Step::Cumsum { .. }
1060 | Step::FusedBinaryUnary { .. }
1061 | Step::ElementwiseRegion { .. }
1062 | Step::BatchElementwiseRegion { .. }
1063 )
1064 }
1065
1066 pub fn graph_capture_safe(&self) -> bool {
1068 match self {
1069 Step::Im2ColHost { use_gpu, .. } | Step::Fft { use_gpu, .. } => *use_gpu,
1070 Step::GatedDeltaNet { .. }
1071 | Step::Llada2GroupLimitedGate { .. }
1072 | Step::UmapKnn { .. }
1073 | Step::LogMelHost { .. }
1074 | Step::LogMelBackwardHost { .. }
1075 | Step::WelchPeaksHost { .. }
1076 | Step::GaussianSplatRender { .. }
1077 | Step::GaussianSplatRenderBackward { .. }
1078 | Step::GaussianSplatPrepare { .. } => false,
1079 _ => true,
1080 }
1081 }
1082}
1083
1084fn schedule_graph_capture_safe(schedule: &[Step]) -> bool {
1085 schedule.iter().all(Step::graph_capture_safe)
1086}
1087
1088fn step_is_tail_host(step: &Step) -> bool {
1089 matches!(
1090 step,
1091 Step::LogMelHost { .. } | Step::LogMelBackwardHost { .. } | Step::WelchPeaksHost { .. }
1092 )
1093}
1094
1095fn run_tail_host_audio_ops(
1096 schedule: &[Step],
1097 stream: &Arc<cudarc::driver::CudaStream>,
1098 buffer: &mut cudarc::driver::CudaSlice<f32>,
1099 pre_sync: bool,
1100) {
1101 if !schedule.iter().any(step_is_tail_host) {
1102 return;
1103 }
1104 if pre_sync {
1105 stream
1106 .synchronize()
1107 .expect("rlx-cuda: tail host pre-sync failed");
1108 }
1109 for step in schedule {
1110 match step {
1111 Step::LogMelHost {
1112 spec_byte_off,
1113 filt_byte_off,
1114 dst_byte_off,
1115 outer,
1116 n_fft,
1117 n_bins,
1118 n_mels,
1119 } => {
1120 crate::log_mel_host::run_log_mel(
1121 stream,
1122 buffer,
1123 *spec_byte_off as usize,
1124 *filt_byte_off as usize,
1125 *dst_byte_off as usize,
1126 *outer as usize,
1127 *n_fft as usize,
1128 *n_bins as usize,
1129 *n_mels as usize,
1130 false,
1131 );
1132 }
1133 Step::LogMelBackwardHost {
1134 spec_byte_off,
1135 filt_byte_off,
1136 dy_byte_off,
1137 dst_byte_off,
1138 outer,
1139 n_fft,
1140 n_bins,
1141 n_mels,
1142 } => {
1143 crate::log_mel_backward_host::run_log_mel_backward(
1144 stream,
1145 buffer,
1146 *spec_byte_off as usize,
1147 *filt_byte_off as usize,
1148 *dy_byte_off as usize,
1149 *dst_byte_off as usize,
1150 *outer as usize,
1151 *n_fft as usize,
1152 *n_bins as usize,
1153 *n_mels as usize,
1154 false,
1155 );
1156 }
1157 Step::WelchPeaksHost {
1158 spec_byte_off,
1159 dst_byte_off,
1160 welch_batch,
1161 n_fft,
1162 n_segments,
1163 k,
1164 } => {
1165 crate::welch_peaks_host::run_welch_peaks(
1166 stream,
1167 buffer,
1168 *spec_byte_off as usize,
1169 *dst_byte_off as usize,
1170 *welch_batch as usize,
1171 *n_fft as usize,
1172 *n_segments as usize,
1173 *k as usize,
1174 false,
1175 );
1176 }
1177 _ => {}
1178 }
1179 }
1180}
1181
1182fn schedule_needs_blas_lt(schedule: &[Step]) -> bool {
1183 schedule.iter().any(|s| {
1184 matches!(
1185 s,
1186 Step::Matmul { act_id, .. } if cublaslt_act_supported(*act_id)
1187 )
1188 })
1189}
1190
1191fn schedule_needs_dnn(schedule: &[Step]) -> bool {
1192 schedule.iter().any(|s| {
1193 matches!(
1194 s,
1195 Step::Conv1d { .. } | Step::Conv2d { .. } | Step::Conv3d { .. }
1196 )
1197 })
1198}
1199
1200fn cublaslt_act_for(act_id: u32) -> Option<cublaslt_sys::cublasLtEpilogue_t> {
1206 None.or(match act_id {
1207 0xFFFFu32 => Some(None),
1209 0 => Some(Some(
1211 cublaslt_sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU,
1212 )),
1213 9 | 11 => Some(Some(
1214 cublaslt_sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU,
1215 )),
1216 _ => Some(None),
1217 })
1218 .flatten()
1219}
1220
/// Returns true when `act_id` can be handled by a cuBLASLt epilogue: the
/// `0xFFFF` id (mapped to no activation epilogue by `cublaslt_act_for`) plus
/// ids `0`, `9` and `11`.
fn cublaslt_act_supported(act_id: u32) -> bool {
    const FUSABLE_ACT_IDS: [u32; 4] = [0xFFFF, 0, 9, 11];
    FUSABLE_ACT_IDS.contains(&act_id)
}
1225
1226unsafe fn cublaslt_matmul_fused(
1232 handle: cublaslt_sys::cublasLtHandle_t,
1233 workspace_dev_ptr: u64,
1234 workspace_size: usize,
1235 arena_dev_ptr: u64,
1236 m: u32,
1237 k: u32,
1238 n: u32,
1239 a_off_f32: u32,
1240 b_off_f32: u32,
1241 c_off_f32: u32,
1242 has_bias: bool,
1243 bias_off_f32: u32,
1244 epilogue_act: Option<cublaslt_sys::cublasLtEpilogue_t>,
1245 batch: u32,
1246 a_batch_stride: u32,
1247 b_batch_stride: u32,
1248 c_batch_stride: u32,
1249 cu_stream: cudarc::driver::sys::CUstream,
1250) -> Result<(), cublaslt_result::CublasError> {
1251 use core::ffi::c_void;
1252 use core::mem;
1253
1254 let a_ptr = (arena_dev_ptr + (b_off_f32 as u64) * 4) as *const c_void; let b_ptr = (arena_dev_ptr + (a_off_f32 as u64) * 4) as *const c_void; let c_ptr = (arena_dev_ptr + (c_off_f32 as u64) * 4) as *const c_void;
1259 let d_ptr = c_ptr as *mut c_void;
1260
1261 let dt = cublaslt_sys::cudaDataType_t::CUDA_R_32F;
1262
1263 let a_layout = cublaslt_result::create_matrix_layout(dt, n as u64, k as u64, n as i64)?;
1265 let b_layout = cublaslt_result::create_matrix_layout(dt, k as u64, m as u64, k as i64)?;
1266 let c_layout = cublaslt_result::create_matrix_layout(dt, n as u64, m as u64, n as i64)?;
1267
1268 if batch > 1 {
1269 unsafe {
1270 let bsz = batch as i32;
1271 for &layout in &[a_layout, b_layout, c_layout] {
1272 cublaslt_result::set_matrix_layout_attribute(
1273 layout,
1274 cublaslt_sys::cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
1275 &bsz as *const _ as *const _,
1276 mem::size_of::<i32>(),
1277 )?;
1278 }
1279 let stride_b = b_batch_stride as i64;
1280 let stride_a = a_batch_stride as i64;
1281 let stride_c = c_batch_stride as i64;
1282 cublaslt_result::set_matrix_layout_attribute(
1283 a_layout,
1284 cublaslt_sys::cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
1285 &stride_b as *const _ as *const _, mem::size_of::<i64>())?;
1286 cublaslt_result::set_matrix_layout_attribute(
1287 b_layout,
1288 cublaslt_sys::cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
1289 &stride_a as *const _ as *const _, mem::size_of::<i64>())?;
1290 cublaslt_result::set_matrix_layout_attribute(
1291 c_layout,
1292 cublaslt_sys::cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
1293 &stride_c as *const _ as *const _, mem::size_of::<i64>())?;
1294 }
1295 }
1296
1297 let compute_type =
1301 if rlx_ir::env::flag("RLX_CUDA_NO_TF32") || rlx_ir::env::flag("RLX_CUDA_PARITY") {
1302 cublaslt_sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
1303 } else {
1304 cublaslt_sys::cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_TF32
1305 };
1306 let matmul_desc = cublaslt_result::create_matmul_desc(compute_type, dt)?;
1307
1308 let epilogue = match (has_bias, epilogue_act) {
1313 (true, Some(cublaslt_sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU)) => {
1314 cublaslt_sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU_BIAS
1315 }
1316 (true, Some(cublaslt_sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU)) => {
1317 cublaslt_sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU_BIAS
1318 }
1319 (true, None) => cublaslt_sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_BIAS,
1320 (false, Some(act)) => act,
1321 (false, None) => cublaslt_sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT,
1322 _ => cublaslt_sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT,
1323 };
1324 unsafe {
1325 cublaslt_result::set_matmul_desc_attribute(
1326 matmul_desc,
1327 cublaslt_sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_EPILOGUE,
1328 &epilogue as *const _ as *const _,
1329 mem::size_of::<cublaslt_sys::cublasLtEpilogue_t>(),
1330 )?;
1331 }
1332
1333 if has_bias {
1334 let bias_dev_ptr = arena_dev_ptr + (bias_off_f32 as u64) * 4;
1335 unsafe {
1336 cublaslt_result::set_matmul_desc_attribute(
1337 matmul_desc,
1338 cublaslt_sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_BIAS_POINTER,
1339 &bias_dev_ptr as *const _ as *const _,
1340 mem::size_of::<u64>(),
1341 )?;
1342 }
1343 }
1344
1345 let matmul_pref = cublaslt_result::create_matmul_pref()?;
1346 unsafe {
1347 cublaslt_result::set_matmul_pref_attribute(
1348 matmul_pref,
1349 cublaslt_sys::cublasLtMatmulPreferenceAttributes_t::CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
1350 &workspace_size as *const _ as *const _,
1351 mem::size_of::<usize>(),
1352 )?;
1353 }
1354
1355 let heuristic = unsafe {
1356 cublaslt_result::get_matmul_algo_heuristic(
1357 handle,
1358 matmul_desc,
1359 a_layout,
1360 b_layout,
1361 c_layout,
1362 c_layout,
1363 matmul_pref,
1364 )
1365 }?;
1366
1367 let alpha = 1.0_f32;
1368 let beta = 0.0_f32;
1369 let workspace_ptr = workspace_dev_ptr as *mut c_void;
1370
1371 let result = unsafe {
1372 cublaslt_result::matmul(
1373 handle,
1374 matmul_desc,
1375 &alpha as *const _ as *const c_void,
1376 &beta as *const _ as *const c_void,
1377 a_ptr,
1378 a_layout,
1379 b_ptr,
1380 b_layout,
1381 c_ptr,
1382 c_layout,
1383 d_ptr,
1384 c_layout,
1385 &heuristic.algo as *const _,
1386 workspace_ptr,
1387 workspace_size,
1388 cu_stream as cublaslt_sys::cudaStream_t,
1389 )
1390 };
1391
1392 unsafe {
1394 let _ = cublaslt_result::destroy_matmul_pref(matmul_pref);
1395 let _ = cublaslt_result::destroy_matmul_desc(matmul_desc);
1396 let _ = cublaslt_result::destroy_matrix_layout(c_layout);
1397 let _ = cublaslt_result::destroy_matrix_layout(b_layout);
1398 let _ = cublaslt_result::destroy_matrix_layout(a_layout);
1399 }
1400
1401 result
1402}
1403
1404unsafe fn cudnn_conv2d_forward(
1410 handle: cudnn_sys::cudnnHandle_t,
1411 workspace_dev_ptr: u64,
1412 workspace_size: usize,
1413 arena_dev_ptr: u64,
1414 n: u32,
1415 c_in: u32,
1416 c_out: u32,
1417 h: u32,
1418 w: u32,
1419 h_out: u32,
1420 w_out: u32,
1421 kh: u32,
1422 kw: u32,
1423 sh: u32,
1424 sw: u32,
1425 ph: u32,
1426 pw: u32,
1427 dh: u32,
1428 dw: u32,
1429 groups: u32,
1430 in_off_f32: u32,
1431 w_off_f32: u32,
1432 out_off_f32: u32,
1433) -> Result<(), cudnn_result::CudnnError> {
1434 use core::ffi::c_void;
1435
1436 let dt = cudnn_sys::cudnnDataType_t::CUDNN_DATA_FLOAT;
1437 let fmt = cudnn_sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW;
1438
1439 let x_desc = cudnn_result::create_tensor_descriptor()?;
1440 let y_desc = cudnn_result::create_tensor_descriptor()?;
1441 let conv_desc = cudnn_result::create_convolution_descriptor()?;
1442
1443 let w_desc = unsafe {
1444 let mut w_desc_uninit = std::mem::MaybeUninit::uninit();
1445 cudnn_sys::cudnnCreateFilterDescriptor(w_desc_uninit.as_mut_ptr()).result()?;
1446 w_desc_uninit.assume_init()
1447 };
1448
1449 let setup = unsafe {
1450 cudnn_result::set_tensor4d_descriptor(
1451 x_desc,
1452 fmt,
1453 dt,
1454 [n as i32, c_in as i32, h as i32, w as i32],
1455 )?;
1456 cudnn_result::set_tensor4d_descriptor(
1457 y_desc,
1458 fmt,
1459 dt,
1460 [n as i32, c_out as i32, h_out as i32, w_out as i32],
1461 )?;
1462 cudnn_result::set_filter4d_descriptor(
1463 w_desc,
1464 dt,
1465 fmt,
1466 [
1467 c_out as i32,
1468 (c_in / groups.max(1)) as i32,
1469 kh as i32,
1470 kw as i32,
1471 ],
1472 )?;
1473 cudnn_result::set_convolution2d_descriptor(
1474 conv_desc,
1475 ph as i32,
1476 pw as i32,
1477 sh as i32,
1478 sw as i32,
1479 dh as i32,
1480 dw as i32,
1481 cudnn_sys::cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION,
1482 dt,
1483 )?;
1484 if groups > 1 {
1485 cudnn_sys::cudnnSetConvolutionGroupCount(conv_desc, groups as i32).result()?;
1486 }
1487 Ok::<(), cudnn_result::CudnnError>(())
1488 };
1489
1490 let result = setup.and_then(|()| unsafe {
1491 let mut returned_count: i32 = 0;
1493 let mut perf = std::mem::MaybeUninit::<cudnn_sys::cudnnConvolutionFwdAlgoPerf_t>::uninit();
1494 cudnn_result::get_convolution_forward_algorithm(
1495 handle,
1496 x_desc,
1497 w_desc,
1498 conv_desc,
1499 y_desc,
1500 1,
1501 &mut returned_count,
1502 perf.as_mut_ptr(),
1503 )?;
1504 if returned_count == 0 {
1505 return Err(cudnn_result::CudnnError(
1506 cudnn_sys::cudnnStatus_t::CUDNN_STATUS_NOT_SUPPORTED,
1507 ));
1508 }
1509 let algo = perf.assume_init().algo;
1510
1511 let needed = cudnn_result::get_convolution_forward_workspace_size(
1512 handle, x_desc, w_desc, conv_desc, y_desc, algo,
1513 )?;
1514 if needed > workspace_size {
1515 return Err(cudnn_result::CudnnError(
1516 cudnn_sys::cudnnStatus_t::CUDNN_STATUS_NOT_SUPPORTED,
1517 ));
1518 }
1519
1520 let alpha: f32 = 1.0;
1521 let beta: f32 = 0.0;
1522 let x_ptr = (arena_dev_ptr + (in_off_f32 as u64) * 4) as *const c_void;
1523 let w_ptr = (arena_dev_ptr + (w_off_f32 as u64) * 4) as *const c_void;
1524 let y_ptr = (arena_dev_ptr + (out_off_f32 as u64) * 4) as *mut c_void;
1525 let workspace_ptr = workspace_dev_ptr as *mut c_void;
1526
1527 cudnn_result::convolution_forward(
1528 handle,
1529 &alpha as *const _ as *const c_void,
1530 x_desc,
1531 x_ptr,
1532 w_desc,
1533 w_ptr,
1534 conv_desc,
1535 algo,
1536 workspace_ptr,
1537 workspace_size,
1538 &beta as *const _ as *const c_void,
1539 y_desc,
1540 y_ptr,
1541 )
1542 });
1543
1544 unsafe {
1545 let _ = cudnn_result::destroy_convolution_descriptor(conv_desc);
1546 let _ = cudnn_result::destroy_filter_descriptor(w_desc);
1547 let _ = cudnn_result::destroy_tensor_descriptor(y_desc);
1548 let _ = cudnn_result::destroy_tensor_descriptor(x_desc);
1549 }
1550
1551 result
1552}
1553
1554unsafe fn cudnn_conv3d_forward(
1558 handle: cudnn_sys::cudnnHandle_t,
1559 workspace_dev_ptr: u64,
1560 workspace_size: usize,
1561 arena_dev_ptr: u64,
1562 n: u32,
1563 c_in: u32,
1564 c_out: u32,
1565 d: u32,
1566 h: u32,
1567 w: u32,
1568 d_out: u32,
1569 h_out: u32,
1570 w_out: u32,
1571 kd: u32,
1572 kh: u32,
1573 kw: u32,
1574 sd: u32,
1575 sh: u32,
1576 sw: u32,
1577 pd: u32,
1578 ph: u32,
1579 pw: u32,
1580 dd: u32,
1581 dh: u32,
1582 dw: u32,
1583 groups: u32,
1584 in_off_f32: u32,
1585 w_off_f32: u32,
1586 out_off_f32: u32,
1587) -> Result<(), cudnn_result::CudnnError> {
1588 use core::ffi::c_void;
1589
1590 let dt = cudnn_sys::cudnnDataType_t::CUDNN_DATA_FLOAT;
1591 let fmt = cudnn_sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW;
1592
1593 let x_desc = cudnn_result::create_tensor_descriptor()?;
1594 let y_desc = cudnn_result::create_tensor_descriptor()?;
1595 let conv_desc = cudnn_result::create_convolution_descriptor()?;
1596 let w_desc = unsafe {
1597 let mut w_desc_uninit = std::mem::MaybeUninit::uninit();
1598 cudnn_sys::cudnnCreateFilterDescriptor(w_desc_uninit.as_mut_ptr()).result()?;
1599 w_desc_uninit.assume_init()
1600 };
1601
1602 let x_dims: [i32; 5] = [n as i32, c_in as i32, d as i32, h as i32, w as i32];
1604 let x_strides: [i32; 5] = [
1605 (c_in * d * h * w) as i32,
1606 (d * h * w) as i32,
1607 (h * w) as i32,
1608 w as i32,
1609 1,
1610 ];
1611 let y_dims: [i32; 5] = [
1612 n as i32,
1613 c_out as i32,
1614 d_out as i32,
1615 h_out as i32,
1616 w_out as i32,
1617 ];
1618 let y_strides: [i32; 5] = [
1619 (c_out * d_out * h_out * w_out) as i32,
1620 (d_out * h_out * w_out) as i32,
1621 (h_out * w_out) as i32,
1622 w_out as i32,
1623 1,
1624 ];
1625 let f_dims: [i32; 5] = [
1626 c_out as i32,
1627 (c_in / groups.max(1)) as i32,
1628 kd as i32,
1629 kh as i32,
1630 kw as i32,
1631 ];
1632 let pads: [i32; 3] = [pd as i32, ph as i32, pw as i32];
1633 let strides: [i32; 3] = [sd as i32, sh as i32, sw as i32];
1634 let dilations: [i32; 3] = [dd as i32, dh as i32, dw as i32];
1635
1636 let setup = unsafe {
1637 cudnn_result::set_tensornd_descriptor(x_desc, dt, 5, x_dims.as_ptr(), x_strides.as_ptr())?;
1638 cudnn_result::set_tensornd_descriptor(y_desc, dt, 5, y_dims.as_ptr(), y_strides.as_ptr())?;
1639 cudnn_result::set_filternd_descriptor(w_desc, dt, fmt, 5, f_dims.as_ptr())?;
1640 cudnn_result::set_convolutionnd_descriptor(
1641 conv_desc,
1642 3,
1643 pads.as_ptr(),
1644 strides.as_ptr(),
1645 dilations.as_ptr(),
1646 cudnn_sys::cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION,
1647 dt,
1648 )?;
1649 if groups > 1 {
1650 cudnn_sys::cudnnSetConvolutionGroupCount(conv_desc, groups as i32).result()?;
1651 }
1652 Ok::<(), cudnn_result::CudnnError>(())
1653 };
1654
1655 let result = setup.and_then(|()| unsafe {
1656 let mut returned_count: i32 = 0;
1657 let mut perf = std::mem::MaybeUninit::<cudnn_sys::cudnnConvolutionFwdAlgoPerf_t>::uninit();
1658 cudnn_result::get_convolution_forward_algorithm(
1659 handle,
1660 x_desc,
1661 w_desc,
1662 conv_desc,
1663 y_desc,
1664 1,
1665 &mut returned_count,
1666 perf.as_mut_ptr(),
1667 )?;
1668 if returned_count == 0 {
1669 return Err(cudnn_result::CudnnError(
1670 cudnn_sys::cudnnStatus_t::CUDNN_STATUS_NOT_SUPPORTED,
1671 ));
1672 }
1673 let algo = perf.assume_init().algo;
1674
1675 let needed = cudnn_result::get_convolution_forward_workspace_size(
1676 handle, x_desc, w_desc, conv_desc, y_desc, algo,
1677 )?;
1678 if needed > workspace_size {
1679 return Err(cudnn_result::CudnnError(
1680 cudnn_sys::cudnnStatus_t::CUDNN_STATUS_NOT_SUPPORTED,
1681 ));
1682 }
1683
1684 let alpha: f32 = 1.0;
1685 let beta: f32 = 0.0;
1686 let x_ptr = (arena_dev_ptr + (in_off_f32 as u64) * 4) as *const c_void;
1687 let w_ptr = (arena_dev_ptr + (w_off_f32 as u64) * 4) as *const c_void;
1688 let y_ptr = (arena_dev_ptr + (out_off_f32 as u64) * 4) as *mut c_void;
1689 let workspace_ptr = workspace_dev_ptr as *mut c_void;
1690
1691 cudnn_result::convolution_forward(
1692 handle,
1693 &alpha as *const _ as *const c_void,
1694 x_desc,
1695 x_ptr,
1696 w_desc,
1697 w_ptr,
1698 conv_desc,
1699 algo,
1700 workspace_ptr,
1701 workspace_size,
1702 &beta as *const _ as *const c_void,
1703 y_desc,
1704 y_ptr,
1705 )
1706 });
1707
1708 unsafe {
1709 let _ = cudnn_result::destroy_convolution_descriptor(conv_desc);
1710 let _ = cudnn_result::destroy_filter_descriptor(w_desc);
1711 let _ = cudnn_result::destroy_tensor_descriptor(y_desc);
1712 let _ = cudnn_result::destroy_tensor_descriptor(x_desc);
1713 }
1714
1715 result
1716}
1717
1718fn matmul_shape(
1725 graph: &Graph,
1726 node: &rlx_ir::Node,
1727 op_label: &str,
1728) -> (u32, u32, u32, u32, u32, u32, u32, NodeId, NodeId) {
1729 let a_id = node.inputs[0];
1730 let b_id = node.inputs[1];
1731 let a_shape = graph.node(a_id).shape.dims();
1732 let b_shape = graph.node(b_id).shape.dims();
1733 let out_shape = node.shape.dims();
1734 if a_shape.len() == 2 && b_shape.len() == 2 && out_shape.len() == 2 {
1735 let m = a_shape[0].unwrap_static() as u32;
1736 let k = a_shape[1].unwrap_static() as u32;
1737 let n = b_shape[1].unwrap_static() as u32;
1738 (m, k, n, 1, 0, 0, 0, a_id, b_id)
1739 } else if a_shape.len() >= 2 && b_shape.len() == 2 && out_shape.len() == a_shape.len() {
1740 let leading: usize = a_shape[..a_shape.len() - 2]
1741 .iter()
1742 .map(|d| d.unwrap_static())
1743 .product();
1744 let m_inner = a_shape[a_shape.len() - 2].unwrap_static();
1745 let k_inner = a_shape[a_shape.len() - 1].unwrap_static();
1746 let n_inner = b_shape[1].unwrap_static();
1747 (
1748 (leading * m_inner) as u32,
1749 k_inner as u32,
1750 n_inner as u32,
1751 1,
1752 0,
1753 0,
1754 0,
1755 a_id,
1756 b_id,
1757 )
1758 } else if a_shape.len() == b_shape.len() && a_shape.len() >= 3 {
1759 let leading_a: Vec<usize> = a_shape[..a_shape.len() - 2]
1760 .iter()
1761 .map(|d| d.unwrap_static())
1762 .collect();
1763 let leading_b: Vec<usize> = b_shape[..b_shape.len() - 2]
1764 .iter()
1765 .map(|d| d.unwrap_static())
1766 .collect();
1767 if leading_a != leading_b {
1768 panic!(
1769 "rlx-cuda {op_label}: batched shape mismatch \
1770 a_leading={leading_a:?} b_leading={leading_b:?}"
1771 );
1772 }
1773 let b_count: usize = leading_a.iter().product();
1774 let m_inner = a_shape[a_shape.len() - 2].unwrap_static();
1775 let k_inner = a_shape[a_shape.len() - 1].unwrap_static();
1776 let n_inner = b_shape[b_shape.len() - 1].unwrap_static();
1777 (
1778 m_inner as u32,
1779 k_inner as u32,
1780 n_inner as u32,
1781 b_count as u32,
1782 (m_inner * k_inner) as u32,
1783 (k_inner * n_inner) as u32,
1784 (m_inner * n_inner) as u32,
1785 a_id,
1786 b_id,
1787 )
1788 } else {
1789 panic!(
1790 "rlx-cuda {op_label}: unsupported shapes a={a_shape:?} b={b_shape:?} out={out_shape:?}"
1791 );
1792 }
1793}
1794
1795fn binary_op_id(op: BinaryOp) -> u32 {
1796 match op {
1797 BinaryOp::Add => 0,
1798 BinaryOp::Sub => 1,
1799 BinaryOp::Mul => 2,
1800 BinaryOp::Div => 3,
1801 BinaryOp::Max => 4,
1802 BinaryOp::Min => 5,
1803 BinaryOp::Pow => 6,
1804 }
1805}
1806
1807fn compare_op_id(op: CmpOp) -> u32 {
1808 match op {
1809 CmpOp::Eq => 0,
1810 CmpOp::Ne => 1,
1811 CmpOp::Lt => 2,
1812 CmpOp::Le => 3,
1813 CmpOp::Gt => 4,
1814 CmpOp::Ge => 5,
1815 }
1816}
1817
1818fn reduce_op_id(op: ReduceOp) -> u32 {
1819 match op {
1820 ReduceOp::Sum => 0,
1821 ReduceOp::Mean => 1,
1822 ReduceOp::Max => 2,
1823 ReduceOp::Min => 3,
1824 ReduceOp::Prod => 4,
1825 }
1826}
1827
1828fn activation_op_id(act: Activation) -> u32 {
1829 match act {
1830 Activation::Relu => 0,
1831 Activation::Sigmoid => 1,
1832 Activation::Tanh => 2,
1833 Activation::Exp => 3,
1834 Activation::Log => 4,
1835 Activation::Sqrt => 5,
1836 Activation::Rsqrt => 6,
1837 Activation::Neg => 7,
1838 Activation::Abs => 8,
1839 Activation::Gelu => 9,
1840 Activation::Silu => 10,
1841 Activation::GeluApprox => 11,
1842 Activation::Round => 12,
1843 Activation::Sin => 13,
1844 Activation::Cos => 14,
1845 Activation::Tan => 15,
1846 Activation::Atan => 16,
1847 }
1848}
1849
1850#[allow(clippy::too_many_arguments)]
1859fn try_mixed_precision_gemm(
1860 ctx: &Arc<CudaContext>,
1861 arena: &mut crate::arena::Arena,
1862 half_act_scratch: &mut Option<cudarc::driver::CudaSlice<u16>>,
1863 blas: Option<&Arc<Mutex<CudaBlas>>>,
1864 stream: &Arc<cudarc::driver::CudaStream>,
1865 m: u32,
1866 k: u32,
1867 n: u32,
1868 batch: u32,
1869 a_off_f32: u32,
1870 b_off_f32: u32,
1871 c_off_f32: u32,
1872) -> bool {
1873 let (half_off, half_dtype) = match arena.half_by_f32_off.get(&b_off_f32).copied() {
1874 Some(v) => v,
1875 None => return false,
1876 };
1877 let blas = match blas {
1878 Some(b) => b,
1879 None => return false,
1880 };
1881
1882 let act_elems = (m * k * batch.max(1)) as usize;
1883 let need_resize = half_act_scratch
1884 .as_ref()
1885 .is_none_or(|s| s.len() < act_elems);
1886 if need_resize {
1887 *half_act_scratch = stream.alloc_zeros::<u16>(act_elems.max(4)).ok();
1888 }
1889 if half_act_scratch.is_none() {
1890 return false;
1891 }
1892
1893 let n_total = m * k * batch.max(1);
1895 let dtype_id: u32 = match half_dtype {
1896 crate::arena::HalfDtype::F16 => 0,
1897 crate::arena::HalfDtype::Bf16 => 1,
1898 };
1899 {
1900 let kernel = crate::kernels::cast_f32_to_half_kernel(ctx);
1901 let (grid, block) = dispatch_grid_1d(n_total, 256);
1902 let cfg = LaunchConfig {
1903 grid_dim: (grid, 1, 1),
1904 block_dim: (block, 1, 1),
1905 shared_mem_bytes: 0,
1906 };
1907 let src_view = arena
1908 .f32_buf()
1909 .slice(a_off_f32 as usize..a_off_f32 as usize + n_total as usize);
1910 let scratch_mut = half_act_scratch.as_mut().unwrap();
1911 let mut launcher = stream.launch_builder(&kernel.function);
1912 launcher
1913 .arg(&src_view)
1914 .arg(scratch_mut)
1915 .arg(&n_total)
1916 .arg(&dtype_id);
1917 if unsafe { launcher.launch(cfg) }.is_err() {
1918 return false;
1919 }
1920 }
1921
1922 let blas = blas.lock().unwrap();
1924 let arena_ptr_u64 = {
1925 let (p, _ar) = arena.buffer.device_ptr_mut(stream);
1926 p
1927 };
1928 let (half_buf_ptr, _hb) = arena.half_buffer.as_mut().unwrap().device_ptr_mut(stream);
1929 let scratch_ptr_u64 = {
1930 let s = half_act_scratch.as_mut().unwrap();
1931 let (p, _r) = s.device_ptr_mut(stream);
1932 p
1933 };
1934 let weight_dev = half_buf_ptr + (half_off as u64) * 2; let act_dev = scratch_ptr_u64;
1936 let c_dev = arena_ptr_u64 + (c_off_f32 as u64) * 4;
1937 let alpha: f32 = 1.0;
1938 let beta: f32 = 0.0;
1939 let cuda_dt = match half_dtype {
1940 crate::arena::HalfDtype::F16 => cublas_sys::cudaDataType_t::CUDA_R_16F,
1941 crate::arena::HalfDtype::Bf16 => cublas_sys::cudaDataType_t::CUDA_R_16BF,
1942 };
1943 let compute_ty = match half_dtype {
1944 crate::arena::HalfDtype::F16 => {
1945 cublas_sys::cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_16F
1946 }
1947 crate::arena::HalfDtype::Bf16 => {
1948 cublas_sys::cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_16BF
1949 }
1950 };
1951 let result = unsafe {
1952 cudarc::cublas::result::gemm_ex(
1953 *blas.handle(),
1954 cublas_sys::cublasOperation_t::CUBLAS_OP_N,
1955 cublas_sys::cublasOperation_t::CUBLAS_OP_N,
1956 n as i32,
1957 m as i32,
1958 k as i32,
1959 &alpha as *const f32 as *const _,
1960 weight_dev as *const _,
1961 cuda_dt,
1962 n as i32,
1963 act_dev as *const _,
1964 cuda_dt,
1965 k as i32,
1966 &beta as *const f32 as *const _,
1967 c_dev as *mut _,
1968 cublas_sys::cudaDataType_t::CUDA_R_32F,
1969 n as i32,
1970 compute_ty,
1971 cublas_sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT,
1972 )
1973 };
1974 if let Err(ref e) = result {
1975 log_fallback("matmul.gemmEx (mixed-precision)", e);
1976 }
1977 result.is_ok()
1978}
1979
1980fn log_fallback(tier: &str, err: impl std::fmt::Debug) {
1985 use std::sync::OnceLock;
1986 static ENABLED: OnceLock<bool> = OnceLock::new();
1987 let enabled = *ENABLED.get_or_init(|| {
1988 rlx_ir::env::var("RLX_CUDA_LOG_FALLBACK")
1989 .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
1990 .unwrap_or(false)
1991 });
1992 if enabled {
1993 eprintln!("rlx-cuda: tier '{tier}' fell back: {err:?}");
1994 }
1995}
1996
1997fn fft_dtype_tag(dtype: rlx_ir::DType) -> u32 {
2001 match dtype {
2002 rlx_ir::DType::F32 => 0,
2003 rlx_ir::DType::F64 => 1,
2004 rlx_ir::DType::C64 => 2,
2005 other => panic!("rlx-cuda Op::Fft: unsupported dtype {other:?}"),
2006 }
2007}
2008
2009fn fft_dtype_from_tag(tag: u32) -> rlx_ir::DType {
2010 match tag {
2011 0 => rlx_ir::DType::F32,
2012 1 => rlx_ir::DType::F64,
2013 2 => rlx_ir::DType::C64,
2014 other => panic!("rlx-cuda Op::Fft: bad dtype tag {other}"),
2015 }
2016}
2017
2018fn step_name(step: &Step) -> &'static str {
2019 match step {
2020 Step::Matmul { .. } => "rlx::Matmul",
2021 Step::Binary { .. } => "rlx::Binary",
2022 Step::Compare { .. } => "rlx::Compare",
2023 Step::Unary { .. } => "rlx::Unary",
2024 Step::Where { .. } => "rlx::Where",
2025 Step::Reduce { .. } => "rlx::Reduce",
2026 Step::Softmax { .. } => "rlx::Softmax",
2027 Step::LayerNorm { .. } => "rlx::LayerNorm",
2028 Step::FusedResidualLn { .. } => "rlx::FusedResidualLN",
2029 Step::FusedResidualRmsNorm { .. } => "rlx::FusedResidualRmsNorm",
2030 Step::Gather { .. } => "rlx::Gather",
2031 Step::GatherAxis { .. } => "rlx::GatherAxis",
2032 Step::Narrow { .. } => "rlx::Narrow",
2033 Step::Concat { .. } => "rlx::Concat",
2034 Step::Transpose { .. } => "rlx::Transpose",
2035 Step::Expand { .. } => "rlx::Expand",
2036 Step::Argmax { .. } => "rlx::Argmax",
2037 Step::Attention { .. } => "rlx::Attention",
2038 Step::AttentionBackward { .. } => "rlx::AttentionBackward",
2039 Step::Rope { .. } => "rlx::Rope",
2040 Step::Cumsum { .. } => "rlx::Cumsum",
2041 Step::TopK { .. } => "rlx::TopK",
2042 Step::GroupedMatmul { .. } => "rlx::GroupedMatmul",
2043 Step::ScatterAddZero { .. } => "rlx::ScatterAdd::zero",
2044 Step::ScatterAddAcc { .. } => "rlx::ScatterAdd::acc",
2045 Step::DequantMatmul { .. } => "rlx::DequantMatmul",
2046 Step::DequantMatmulGguf { .. } => "rlx::DequantMatmulGguf",
2047 Step::DequantGroupedMatmulGguf { .. } => "rlx::DequantGroupedMatmulGguf",
2048 Step::Sample { .. } => "rlx::Sample",
2049 Step::SelectiveScan { .. } => "rlx::SelectiveScan",
2050 Step::Fft { .. } => "rlx::Fft",
2051 Step::LogMelHost { .. } => "rlx::LogMelHost",
2052 Step::LogMelBackwardHost { .. } => "rlx::LogMelBackwardHost",
2053 Step::WelchPeaksHost { .. } => "rlx::WelchPeaksHost",
2054 Step::WelchPeaksGpu { .. } => "rlx::WelchPeaksGpu",
2055 Step::Im2ColHost { .. } => "rlx::Im2ColHost",
2056 Step::GatedDeltaNet { .. } => "rlx::GatedDeltaNet",
2057 Step::Llada2GroupLimitedGate { .. } => "rlx::Llada2GroupLimitedGate",
2058 Step::UmapKnn { .. } => "rlx::UmapKnn",
2059 Step::GaussianSplatRender { .. } => "rlx::GaussianSplatRender",
2060 Step::GaussianSplatRenderBackward { .. } => "rlx::GaussianSplatRenderBackward",
2061 Step::GaussianSplatPrepare { .. } => "rlx::GaussianSplatPrepare",
2062 Step::GaussianSplatRasterize { .. } => "rlx::GaussianSplatRasterize",
2063 Step::RmsNormBackwardInput { .. } => "rlx::RmsNormBackwardInput",
2064 Step::RmsNormBackwardGamma { .. } => "rlx::RmsNormBackwardGamma",
2065 Step::RmsNormBackwardBeta { .. } => "rlx::RmsNormBackwardBeta",
2066 Step::RopeBackward { .. } => "rlx::RopeBackward",
2067 Step::CumsumBackward { .. } => "rlx::CumsumBackward",
2068 Step::GatherBackward { .. } => "rlx::GatherBackward",
2069 Step::MaxPool2dBackward { .. } => "rlx::MaxPool2dBackward",
2070 Step::Conv2dBackwardInput { .. } => "rlx::Conv2dBackwardInput",
2071 Step::Conv2dBackwardWeight { .. } => "rlx::Conv2dBackwardWeight",
2072 Step::Pool1d { .. } => "rlx::Pool1d",
2073 Step::Pool2d { .. } => "rlx::Pool2d",
2074 Step::Pool3d { .. } => "rlx::Pool3d",
2075 Step::Conv1d { .. } => "rlx::Conv1d",
2076 Step::Conv2d { .. } => "rlx::Conv2d",
2077 Step::Conv3d { .. } => "rlx::Conv3d",
2078 Step::LayerNorm2d { .. } => "rlx::LayerNorm2d",
2079 Step::ConvTranspose2d { .. } => "rlx::ConvTranspose2d",
2080 Step::GroupNorm { .. } => "rlx::GroupNorm",
2081 Step::ResizeNearest2x { .. } => "rlx::ResizeNearest2x",
2082 Step::FusedBinaryUnary { .. } => "rlx::FusedBinaryUnary",
2083 Step::ElementwiseRegion { .. } => "rlx::ElementwiseRegion",
2084 Step::BatchElementwiseRegion { .. } => "rlx::BatchElementwiseRegion",
2085 }
2086}
2087
2088fn fuse_elementwise_chains(schedule: Vec<Step>) -> Vec<Step> {
2097 let mut consumer_counts: HashMap<u32, usize> = HashMap::new();
2100 for step in &schedule {
2101 let (reads, _) = step_offsets(step);
2102 for r in &reads {
2103 *consumer_counts.entry(*r).or_insert(0) += 1;
2104 }
2105 }
2106
2107 let mut out = Vec::with_capacity(schedule.len());
2108 let mut i = 0;
2109 while i < schedule.len() {
2110 if i + 1 < schedule.len() {
2111 let pair = (&schedule[i], &schedule[i + 1]);
2112 if let (
2113 Step::Binary {
2114 n,
2115 a_off,
2116 b_off,
2117 c_off,
2118 op: bin_op,
2119 },
2120 Step::Unary {
2121 n: n2,
2122 in_off,
2123 out_off,
2124 op: un_op,
2125 },
2126 ) = pair
2127 {
2128 let single_consumer = consumer_counts.get(c_off).copied() == Some(1);
2129 if n == n2 && c_off == in_off && single_consumer {
2130 out.push(Step::FusedBinaryUnary {
2131 n: *n,
2132 a_off: *a_off,
2133 b_off: *b_off,
2134 out_off: *out_off,
2135 bin_op: *bin_op,
2136 un_op: *un_op,
2137 });
2138 i += 2;
2139 continue;
2140 }
2141 }
2142 }
2143 out.push(schedule[i].clone());
2144 i += 1;
2145 }
2146 out
2147}
2148
2149fn step_offsets(step: &Step) -> (Vec<u32>, Vec<u32>) {
2156 match step {
2157 Step::Matmul {
2158 a_off_f32,
2159 b_off_f32,
2160 c_off_f32,
2161 has_bias,
2162 bias_off_f32,
2163 ..
2164 } => {
2165 let mut r = vec![*a_off_f32, *b_off_f32];
2166 if *has_bias != 0 {
2167 r.push(*bias_off_f32);
2168 }
2169 (r, vec![*c_off_f32])
2170 }
2171 Step::Binary {
2172 a_off,
2173 b_off,
2174 c_off,
2175 ..
2176 }
2177 | Step::Compare {
2178 a_off,
2179 b_off,
2180 c_off,
2181 ..
2182 } => (vec![*a_off, *b_off], vec![*c_off]),
2183 Step::Unary {
2184 in_off, out_off, ..
2185 } => (vec![*in_off], vec![*out_off]),
2186 Step::Where {
2187 cond_off,
2188 x_off,
2189 y_off,
2190 out_off,
2191 ..
2192 } => (vec![*cond_off, *x_off, *y_off], vec![*out_off]),
2193 Step::Reduce {
2194 in_off, out_off, ..
2195 }
2196 | Step::Softmax {
2197 in_off, out_off, ..
2198 }
2199 | Step::Argmax {
2200 in_off, out_off, ..
2201 }
2202 | Step::Cumsum {
2203 in_off, out_off, ..
2204 }
2205 | Step::Sample {
2206 in_off, out_off, ..
2207 } => (vec![*in_off], vec![*out_off]),
2208 Step::TopK {
2209 in_off, out_off, ..
2210 } => (vec![*in_off], vec![*out_off]),
2211 Step::LayerNorm {
2212 in_off,
2213 gamma_off,
2214 beta_off,
2215 out_off,
2216 ..
2217 } => (vec![*in_off, *gamma_off, *beta_off], vec![*out_off]),
2218 Step::FusedResidualLn {
2219 in_off,
2220 residual_off,
2221 bias_off,
2222 gamma_off,
2223 beta_off,
2224 out_off,
2225 has_bias,
2226 ..
2227 } => {
2228 let mut r = vec![*in_off, *residual_off, *gamma_off, *beta_off];
2229 if *has_bias != 0 {
2230 r.push(*bias_off);
2231 }
2232 (r, vec![*out_off])
2233 }
2234 Step::FusedResidualRmsNorm {
2235 in_off,
2236 residual_off,
2237 bias_off,
2238 gamma_off,
2239 beta_off,
2240 out_off,
2241 has_bias,
2242 ..
2243 } => {
2244 let mut r = vec![*in_off, *residual_off, *gamma_off, *beta_off];
2245 if *has_bias != 0 {
2246 r.push(*bias_off);
2247 }
2248 (r, vec![*out_off])
2249 }
2250 Step::Gather {
2251 in_off,
2252 idx_off,
2253 out_off,
2254 ..
2255 } => (vec![*in_off, *idx_off], vec![*out_off]),
2256 Step::GatherAxis {
2257 table_off,
2258 idx_off,
2259 out_off,
2260 ..
2261 } => (vec![*table_off, *idx_off], vec![*out_off]),
2262 Step::Narrow {
2263 in_off, out_off, ..
2264 }
2265 | Step::Concat {
2266 in_off, out_off, ..
2267 } => (vec![*in_off], vec![*out_off]),
2268 Step::Transpose {
2269 in_off, out_off, ..
2270 }
2271 | Step::Expand {
2272 in_off, out_off, ..
2273 } => (vec![*in_off], vec![*out_off]),
2274 Step::Attention {
2275 q_off,
2276 k_off,
2277 v_off,
2278 mask_off,
2279 mask_kind,
2280 out_off,
2281 ..
2282 } => {
2283 let mut r = vec![*q_off, *k_off, *v_off];
2284 if *mask_kind == 2 || *mask_kind == 4 {
2285 r.push(*mask_off);
2286 }
2287 (r, vec![*out_off])
2288 }
2289 Step::AttentionBackward {
2290 q_off,
2291 k_off,
2292 v_off,
2293 dy_off,
2294 mask_off,
2295 mask_kind,
2296 out_off,
2297 ..
2298 } => {
2299 let mut r = vec![*q_off, *k_off, *v_off, *dy_off];
2300 if *mask_kind == 2 || *mask_kind == 4 {
2301 r.push(*mask_off);
2302 }
2303 (r, vec![*out_off])
2304 }
2305 Step::Rope {
2306 in_off,
2307 cos_off,
2308 sin_off,
2309 out_off,
2310 ..
2311 } => (vec![*in_off, *cos_off, *sin_off], vec![*out_off]),
2312 Step::GroupedMatmul {
2313 in_off,
2314 w_off,
2315 idx_off,
2316 out_off,
2317 ..
2318 } => (vec![*in_off, *w_off, *idx_off], vec![*out_off]),
2319 Step::ScatterAddZero { out_off, .. } => (vec![], vec![*out_off]),
2320 Step::ScatterAddAcc {
2321 upd_off,
2322 idx_off,
2323 out_off,
2324 ..
2325 } =>
2326 {
2329 (vec![*upd_off, *idx_off, *out_off], vec![*out_off])
2330 }
2331 Step::DequantMatmul {
2332 x_off,
2333 w_off,
2334 scale_off,
2335 zp_off,
2336 out_off,
2337 scheme_id,
2338 ..
2339 } => {
2340 let mut r = vec![*x_off, *w_off, *scale_off];
2341 if *scheme_id == 1 {
2342 r.push(*zp_off);
2343 }
2344 (r, vec![*out_off])
2345 }
2346 Step::DequantMatmulGguf {
2347 x_byte_off,
2348 w_byte_off,
2349 out_byte_off,
2350 ..
2351 } => (vec![x_byte_off / 4, w_byte_off / 4], vec![out_byte_off / 4]),
2352 Step::DequantGroupedMatmulGguf {
2353 x_byte_off,
2354 w_byte_off,
2355 idx_byte_off,
2356 out_byte_off,
2357 ..
2358 } => (
2359 vec![x_byte_off / 4, w_byte_off / 4, idx_byte_off / 4],
2360 vec![out_byte_off / 4],
2361 ),
2362 Step::SelectiveScan {
2363 x_off,
2364 delta_off,
2365 a_off,
2366 b_off,
2367 c_off,
2368 out_off,
2369 ..
2370 } => (
2371 vec![*x_off, *delta_off, *a_off, *b_off, *c_off],
2372 vec![*out_off],
2373 ),
2374 Step::Fft {
2375 src_byte_off,
2376 dst_byte_off,
2377 ..
2378 } => (vec![*src_byte_off / 4], vec![*dst_byte_off / 4]),
2379 Step::LogMelHost {
2380 spec_byte_off,
2381 filt_byte_off,
2382 dst_byte_off,
2383 ..
2384 } => (
2385 vec![*spec_byte_off / 4, *filt_byte_off / 4],
2386 vec![*dst_byte_off / 4],
2387 ),
2388 Step::LogMelBackwardHost {
2389 spec_byte_off,
2390 filt_byte_off,
2391 dy_byte_off,
2392 dst_byte_off,
2393 ..
2394 } => (
2395 vec![*spec_byte_off / 4, *filt_byte_off / 4, *dy_byte_off / 4],
2396 vec![*dst_byte_off / 4],
2397 ),
2398 Step::WelchPeaksHost {
2399 spec_byte_off,
2400 dst_byte_off,
2401 ..
2402 } => (vec![*spec_byte_off / 4], vec![*dst_byte_off / 4]),
2403 Step::WelchPeaksGpu {
2404 spec_off, dst_off, ..
2405 } => (vec![*spec_off], vec![*dst_off]),
2406 Step::Im2ColHost {
2407 x_byte_off,
2408 col_byte_off,
2409 ..
2410 } => (vec![*x_byte_off / 4], vec![*col_byte_off / 4]),
2411 Step::GatedDeltaNet {
2412 q_byte_off,
2413 k_byte_off,
2414 v_byte_off,
2415 g_byte_off,
2416 beta_byte_off,
2417 state_byte_off,
2418 dst_byte_off,
2419 use_carry,
2420 ..
2421 } => {
2422 let mut reads = vec![
2423 q_byte_off / 4,
2424 k_byte_off / 4,
2425 v_byte_off / 4,
2426 g_byte_off / 4,
2427 beta_byte_off / 4,
2428 ];
2429 if *use_carry {
2430 reads.push(state_byte_off / 4);
2431 }
2432 let mut writes = vec![dst_byte_off / 4];
2433 if *use_carry {
2434 writes.push(state_byte_off / 4);
2435 }
2436 (reads, writes)
2437 }
2438 Step::Llada2GroupLimitedGate {
2439 sig_off,
2440 route_off,
2441 out_off,
2442 ..
2443 } => (vec![*sig_off, *route_off], vec![*out_off]),
2444 Step::UmapKnn {
2445 pairwise_off,
2446 out_off,
2447 ..
2448 } => (vec![*pairwise_off], vec![*out_off]),
2449 Step::GaussianSplatRender {
2450 positions_off,
2451 positions_len: _,
2452 scales_off,
2453 scales_len: _,
2454 rotations_off,
2455 rotations_len: _,
2456 opacities_off,
2457 opacities_len: _,
2458 colors_off,
2459 colors_len: _,
2460 sh_coeffs_off,
2461 sh_coeffs_len: _,
2462 meta_off,
2463 dst_off,
2464 dst_len: _,
2465 ..
2466 } => (
2467 vec![
2468 positions_off / 4,
2469 scales_off / 4,
2470 rotations_off / 4,
2471 opacities_off / 4,
2472 colors_off / 4,
2473 sh_coeffs_off / 4,
2474 meta_off / 4,
2475 ],
2476 vec![dst_off / 4],
2477 ),
2478 Step::GaussianSplatRenderBackward {
2479 positions_off,
2480 positions_len: _,
2481 scales_off,
2482 scales_len: _,
2483 rotations_off,
2484 rotations_len: _,
2485 opacities_off,
2486 opacities_len: _,
2487 colors_off,
2488 colors_len: _,
2489 sh_coeffs_off,
2490 sh_coeffs_len: _,
2491 meta_off,
2492 d_loss_off,
2493 d_loss_len: _,
2494 packed_off,
2495 packed_len: _,
2496 ..
2497 } => (
2498 vec![
2499 positions_off / 4,
2500 scales_off / 4,
2501 rotations_off / 4,
2502 opacities_off / 4,
2503 colors_off / 4,
2504 sh_coeffs_off / 4,
2505 meta_off / 4,
2506 d_loss_off / 4,
2507 ],
2508 vec![packed_off / 4],
2509 ),
2510 Step::RmsNormBackwardInput {
2511 x_byte_off,
2512 gamma_byte_off,
2513 beta_byte_off,
2514 dy_byte_off,
2515 dx_byte_off,
2516 ..
2517 } => (
2518 vec![
2519 x_byte_off / 4,
2520 gamma_byte_off / 4,
2521 beta_byte_off / 4,
2522 dy_byte_off / 4,
2523 ],
2524 vec![dx_byte_off / 4],
2525 ),
2526 Step::RmsNormBackwardGamma {
2527 x_byte_off,
2528 gamma_byte_off,
2529 beta_byte_off,
2530 dy_byte_off,
2531 dgamma_byte_off,
2532 ..
2533 } => (
2534 vec![
2535 x_byte_off / 4,
2536 gamma_byte_off / 4,
2537 beta_byte_off / 4,
2538 dy_byte_off / 4,
2539 ],
2540 vec![dgamma_byte_off / 4],
2541 ),
2542 Step::RmsNormBackwardBeta {
2543 x_byte_off,
2544 gamma_byte_off,
2545 beta_byte_off,
2546 dy_byte_off,
2547 dbeta_byte_off,
2548 ..
2549 } => (
2550 vec![
2551 x_byte_off / 4,
2552 gamma_byte_off / 4,
2553 beta_byte_off / 4,
2554 dy_byte_off / 4,
2555 ],
2556 vec![dbeta_byte_off / 4],
2557 ),
2558 Step::RopeBackward {
2559 dy_byte_off,
2560 cos_byte_off,
2561 sin_byte_off,
2562 dx_byte_off,
2563 ..
2564 } => (
2565 vec![dy_byte_off / 4, cos_byte_off / 4, sin_byte_off / 4],
2566 vec![dx_byte_off / 4],
2567 ),
2568 Step::CumsumBackward {
2569 dy_byte_off,
2570 dx_byte_off,
2571 ..
2572 } => (vec![dy_byte_off / 4], vec![dx_byte_off / 4]),
2573 Step::GatherBackward {
2574 dy_byte_off,
2575 indices_byte_off,
2576 dst_byte_off,
2577 ..
2578 } => (
2579 vec![dy_byte_off / 4, indices_byte_off / 4],
2580 vec![dst_byte_off / 4],
2581 ),
2582 Step::MaxPool2dBackward {
2583 x_byte_off,
2584 dy_byte_off,
2585 dx_byte_off,
2586 ..
2587 } => (
2588 vec![*x_byte_off / 4, *dy_byte_off / 4],
2589 vec![*dx_byte_off / 4],
2590 ),
2591 Step::Conv2dBackwardInput {
2592 dy_byte_off,
2593 w_byte_off,
2594 dx_byte_off,
2595 ..
2596 } => (
2597 vec![*dy_byte_off / 4, *w_byte_off / 4],
2598 vec![*dx_byte_off / 4],
2599 ),
2600 Step::Conv2dBackwardWeight {
2601 x_byte_off,
2602 dy_byte_off,
2603 dw_byte_off,
2604 ..
2605 } => (
2606 vec![*x_byte_off / 4, *dy_byte_off / 4],
2607 vec![*dw_byte_off / 4],
2608 ),
2609 Step::Pool1d {
2610 in_off, out_off, ..
2611 }
2612 | Step::Pool2d {
2613 in_off, out_off, ..
2614 }
2615 | Step::Pool3d {
2616 in_off, out_off, ..
2617 } => (vec![*in_off], vec![*out_off]),
2618 Step::Conv1d {
2619 in_off,
2620 w_off,
2621 out_off,
2622 ..
2623 }
2624 | Step::Conv2d {
2625 in_off,
2626 w_off,
2627 out_off,
2628 ..
2629 }
2630 | Step::Conv3d {
2631 in_off,
2632 w_off,
2633 out_off,
2634 ..
2635 } => (vec![*in_off, *w_off], vec![*out_off]),
2636 Step::LayerNorm2d {
2637 src_off,
2638 g_off,
2639 b_off,
2640 dst_off,
2641 ..
2642 } => (vec![*src_off, *g_off, *b_off], vec![*dst_off]),
2643 Step::ConvTranspose2d {
2644 src_off,
2645 w_off,
2646 dst_off,
2647 ..
2648 } => (vec![*src_off, *w_off], vec![*dst_off]),
2649 Step::GroupNorm {
2650 src_off,
2651 g_off,
2652 b_off,
2653 dst_off,
2654 ..
2655 } => (vec![*src_off, *g_off, *b_off], vec![*dst_off]),
2656 Step::ResizeNearest2x {
2657 src_off, dst_off, ..
2658 } => (vec![*src_off], vec![*dst_off]),
2659 Step::FusedBinaryUnary {
2660 a_off,
2661 b_off,
2662 out_off,
2663 ..
2664 } => (vec![*a_off, *b_off], vec![*out_off]),
2665 Step::ElementwiseRegion {
2666 dst_off,
2667 input_offs,
2668 num_inputs,
2669 ..
2670 } => {
2671 let n = (*num_inputs as usize).min(input_offs.len());
2672 (input_offs[..n].to_vec(), vec![*dst_off])
2673 }
2674 Step::BatchElementwiseRegion {
2675 base_dst_off,
2676 batch_input_offs,
2677 num_batch,
2678 ..
2679 } => {
2680 let n = (*num_batch as usize).min(64);
2681 (batch_input_offs[..n].to_vec(), vec![*base_dst_off])
2682 }
2683 Step::GaussianSplatPrepare {
2684 positions_off,
2685 scales_off,
2686 rotations_off,
2687 opacities_off,
2688 colors_off,
2689 sh_coeffs_off,
2690 meta_off,
2691 prep_off,
2692 ..
2693 } => (
2694 vec![
2695 positions_off / 4,
2696 scales_off / 4,
2697 rotations_off / 4,
2698 opacities_off / 4,
2699 colors_off / 4,
2700 sh_coeffs_off / 4,
2701 meta_off / 4,
2702 ],
2703 vec![prep_off / 4],
2704 ),
2705 Step::GaussianSplatRasterize {
2706 prep_off,
2707 meta_off,
2708 dst_off,
2709 ..
2710 } => (vec![prep_off / 4, meta_off / 4], vec![dst_off / 4]),
2711 }
2712}
2713
2714static AOT_PREWARM_ONCE: Once = Once::new();
2718
2719fn prewarm_all(ctx: &Arc<CudaContext>) {
2720 AOT_PREWARM_ONCE.call_once(|| prewarm_all_kernels(ctx));
2721}
2722
2723fn prewarm_all_kernels(ctx: &Arc<CudaContext>) {
2724 use crate::kernels::*;
2725 let _ = binary_kernel(ctx);
2726 let _ = fused_binary_unary_kernel(ctx);
2727 let _ = unary_kernel(ctx);
2728 let _ = copy_kernel(ctx);
2729 let _ = matmul_kernel(ctx);
2730 let _ = matmul_epilogue_kernel(ctx);
2731 let _ = compare_kernel(ctx);
2732 let _ = where_kernel(ctx);
2733 let _ = reduce_kernel(ctx);
2734 let _ = softmax_kernel(ctx);
2735 let _ = layernorm_kernel(ctx);
2736 let _ = fused_residual_ln_kernel(ctx);
2737 let _ = fused_residual_rms_norm_kernel(ctx);
2738 let _ = gather_kernel(ctx);
2739 let _ = gather_axis_kernel(ctx);
2740 let _ = narrow_kernel(ctx);
2741 let _ = concat_kernel(ctx);
2742 let _ = transpose_kernel(ctx);
2743 let _ = expand_kernel(ctx);
2744 let _ = attention_kernel(ctx);
2745 let _ = attention_row_kernel(ctx);
2746 let _ = attention_bwd_kernel(ctx);
2747 let _ = argmax_kernel(ctx);
2748 let _ = rope_kernel(ctx);
2749 let _ = cumsum_kernel(ctx);
2750 let _ = topk_kernel(ctx);
2751 let _ = grouped_matmul_kernel(ctx);
2752 let _ = scatter_add_zero_kernel(ctx);
2753 let _ = scatter_add_acc_kernel(ctx);
2754 let _ = dequant_matmul_kernel(ctx);
2755 let _ = dequant_gguf_kernel(ctx);
2756 let _ = sample_kernel(ctx);
2757 let _ = selective_scan_kernel(ctx);
2758 let _ = pool1d_kernel(ctx);
2759 let _ = pool2d_kernel(ctx);
2760 let _ = pool3d_kernel(ctx);
2761 let _ = conv1d_kernel(ctx);
2762 let _ = conv2d_kernel(ctx);
2763 let _ = im2col_kernel(ctx);
2764 let _ = conv3d_kernel(ctx);
2765 let _ = layer_norm2d_kernel(ctx);
2766 let _ = conv_transpose2d_kernel(ctx);
2767 let _ = group_norm_kernel(ctx);
2768 let _ = resize_nearest_2x_kernel(ctx);
2769 let _ = elementwise_region_kernel(ctx);
2770 let _ = batch_elementwise_region_kernel(ctx);
2771 }
2774
2775fn im2col_use_gpu(n: u32, exec_mode: ExecMode) -> bool {
2776 if rlx_ir::env::var("RLX_CUDA_IM2COL_HOST").is_some() {
2777 return false;
2778 }
2779 if matches!(exec_mode, ExecMode::Graph) {
2780 return n > 0;
2781 }
2782 n > 0
2783}
2784
2785fn pinned_host_io_disabled() -> bool {
2786 rlx_ir::env::var("RLX_CUDA_PINNED_IO").is_some_and(|v| v.eq_ignore_ascii_case("0"))
2787}
2788
2789fn pinned_output_staging_enabled() -> bool {
2791 !pinned_host_io_disabled()
2792}
2793
2794fn pinned_input_staging_enabled(exec_mode: ExecMode) -> bool {
2796 if pinned_host_io_disabled() {
2797 return false;
2798 }
2799 matches!(exec_mode, ExecMode::Graph)
2800 || rlx_ir::env::var("RLX_CUDA_PINNED_IO").is_some_and(|v| !v.eq_ignore_ascii_case("0"))
2801}
2802
/// Sorts `buf` ascending and removes duplicate entries, in place.
///
/// Buffers with zero or one element are already normalized and are returned
/// untouched.
fn normalize_read_indices(buf: &mut Vec<usize>) {
    if buf.len() <= 1 {
        return;
    }
    buf.sort_unstable();
    buf.dedup();
}
2809
2810fn compile_mode_from_env() -> CompileMode {
2811 match rlx_ir::env::var("RLX_CUDA_COMPILE_MODE").as_deref() {
2812 Some(mode) if mode.eq_ignore_ascii_case("aot") => CompileMode::Aot,
2813 _ => CompileMode::Jit,
2814 }
2815}
2816
2817fn exec_mode_from_env() -> ExecMode {
2818 match rlx_ir::env::var("RLX_CUDA_EXEC_MODE").as_deref() {
2819 Some(mode) if mode.eq_ignore_ascii_case("graph") => ExecMode::Graph,
2820 Some(mode) => {
2821 let lower = mode.to_ascii_lowercase();
2822 if let Some(rest) = lower.strip_prefix("multistream") {
2823 let n = rest.trim_start_matches([':', '=']).parse().unwrap_or(2);
2824 ExecMode::MultiStream(n.max(1))
2825 } else {
2826 ExecMode::Stream
2827 }
2828 }
2829 _ => ExecMode::Stream,
2830 }
2831}
2832
2833impl CudaExecutable {
2834 pub fn compile(graph: Graph) -> Self {
2838 Self::compile_with(graph, compile_mode_from_env(), exec_mode_from_env())
2839 }
2840
2841 pub fn eager(graph: Graph, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
2844 let mut exec = Self::compile_with(graph, CompileMode::Jit, ExecMode::Eager);
2845 exec.run(inputs)
2846 }
2847
2848 pub fn compile_with(graph: Graph, compile_mode: CompileMode, exec_mode: ExecMode) -> Self {
2850 let ctx = cuda_context().expect("rlx-cuda: no CUDA driver available");
2851
2852 if compile_mode == CompileMode::Aot {
2853 prewarm_all(&ctx);
2854 }
2855
2856 let graph = LowerNonLastAxisReduce.run(crate::unfuse::unfuse(graph));
2861
2862 let dequant_scratch = crate::gguf_gpu::dequant_gguf_scratch_bytes(&graph);
2863 let mut plan = plan_f32_uniform(&graph, 16);
2864 let dequant_scratch_off = if dequant_scratch > 0 {
2865 let aligned = plan.arena_size.div_ceil(16) * 16;
2866 plan.arena_size = aligned + dequant_scratch;
2867 aligned
2868 } else {
2869 0
2870 };
2871 let mut arena = Arena::from_plan(&ctx, &plan);
2872 for node in graph.nodes() {
2873 let elems = node.shape.num_elements().unwrap_or(0);
2874 arena.set_actual_len(node.id, elems * 4);
2875 }
2876
2877 let mut input_offsets = HashMap::new();
2879 let mut param_offsets = HashMap::new();
2880 for node in graph.nodes() {
2881 match &node.op {
2882 Op::Input { name } => {
2883 input_offsets.insert(name.clone(), node.id);
2884 }
2885 Op::Param { name } => {
2886 param_offsets.insert(name.clone(), node.id);
2887 }
2888 _ => {}
2889 }
2890 }
2891
2892 for node in graph.nodes() {
2894 if let Op::Constant { data } = &node.op
2895 && arena.has(node.id)
2896 && !data.is_empty()
2897 {
2898 let bytes_to_write = data.len().min(arena.len_of(node.id));
2899 let n_f32 = bytes_to_write / 4;
2900 let f32_view: &[f32] =
2901 unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n_f32) };
2902 let off_f32 = arena.offset(node.id) / 4;
2903 let stream = ctx.default_stream();
2904 let mut slot = arena.f32_buf_mut().slice_mut(off_f32..off_f32 + n_f32);
2905 stream
2906 .memcpy_htod(f32_view, &mut slot)
2907 .expect("rlx-cuda: constant upload failed");
2908 }
2909 }
2910
2911 let mut schedule = Vec::new();
2912 let mut meta_buffers: Vec<cudarc::driver::CudaSlice<u32>> = Vec::new();
2913 let mut packed_bshd_attn: HashMap<NodeId, (NodeId, u32)> = HashMap::new();
2914 if !rlx_ir::env::flag("RLX_CUDA_NO_PACKED_BSHD_ATTN") {
2915 for node in graph.nodes() {
2916 let Op::Attention { .. } = &node.op else {
2917 continue;
2918 };
2919 if node.inputs.len() < 3 {
2920 continue;
2921 }
2922 if let Some((parent, head_width, _)) = rlx_ir::detect_packed_bshd_qkv_attention(
2923 &graph,
2924 node.inputs[0],
2925 node.inputs[1],
2926 node.inputs[2],
2927 ) {
2928 packed_bshd_attn.insert(node.id, (parent, head_width as u32));
2929 }
2930 }
2931 }
2932 for node in graph.nodes() {
2933 let elems = node.shape.num_elements().unwrap_or(0) as u32;
2934 match &node.op {
2935 Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => continue,
2936 Op::Reshape { .. } | Op::Cast { .. } => {
2937 }
2941 Op::MatMul => {
2942 let (m, k, n, batch, a_bs, b_bs, c_bs, a_id, b_id) =
2943 matmul_shape(&graph, node, "MatMul");
2944 schedule.push(Step::Matmul {
2945 m,
2946 k,
2947 n,
2948 batch,
2949 a_batch_stride: a_bs,
2950 b_batch_stride: b_bs,
2951 c_batch_stride: c_bs,
2952 a_off_f32: (arena.offset(a_id) / 4) as u32,
2953 b_off_f32: (arena.offset(b_id) / 4) as u32,
2954 c_off_f32: (arena.offset(node.id) / 4) as u32,
2955 has_bias: 0,
2956 bias_off_f32: 0,
2957 act_id: 0xFFFF,
2958 });
2959 }
2960 Op::FusedMatMulBiasAct { activation } => {
2961 let (m, k, n, batch, a_bs, b_bs, c_bs, a_id, b_id) =
2962 matmul_shape(&graph, node, "FusedMatMulBiasAct");
2963 let bias_id = node.inputs[2];
2964 let act_id = match activation {
2965 None => 0xFFFFu32,
2966 Some(a) => activation_op_id(*a),
2967 };
2968 schedule.push(Step::Matmul {
2969 m,
2970 k,
2971 n,
2972 batch,
2973 a_batch_stride: a_bs,
2974 b_batch_stride: b_bs,
2975 c_batch_stride: c_bs,
2976 a_off_f32: (arena.offset(a_id) / 4) as u32,
2977 b_off_f32: (arena.offset(b_id) / 4) as u32,
2978 c_off_f32: (arena.offset(node.id) / 4) as u32,
2979 has_bias: 1,
2980 bias_off_f32: (arena.offset(bias_id) / 4) as u32,
2981 act_id,
2982 });
2983 }
2984 Op::Binary(bop) => {
2985 schedule.push(Step::Binary {
2986 n: elems,
2987 a_off: (arena.offset(node.inputs[0]) / 4) as u32,
2988 b_off: (arena.offset(node.inputs[1]) / 4) as u32,
2989 c_off: (arena.offset(node.id) / 4) as u32,
2990 op: binary_op_id(*bop),
2991 });
2992 }
2993 Op::Activation(act) => {
2994 schedule.push(Step::Unary {
2995 n: elems,
2996 in_off: (arena.offset(node.inputs[0]) / 4) as u32,
2997 out_off: (arena.offset(node.id) / 4) as u32,
2998 op: activation_op_id(*act),
2999 });
3000 }
3001 Op::Compare(cop) => {
3002 schedule.push(Step::Compare {
3003 n: elems,
3004 a_off: (arena.offset(node.inputs[0]) / 4) as u32,
3005 b_off: (arena.offset(node.inputs[1]) / 4) as u32,
3006 c_off: (arena.offset(node.id) / 4) as u32,
3007 op: compare_op_id(*cop),
3008 });
3009 }
3010 Op::Where => {
3011 schedule.push(Step::Where {
3012 n: elems,
3013 cond_off: (arena.offset(node.inputs[0]) / 4) as u32,
3014 x_off: (arena.offset(node.inputs[1]) / 4) as u32,
3015 y_off: (arena.offset(node.inputs[2]) / 4) as u32,
3016 out_off: (arena.offset(node.id) / 4) as u32,
3017 });
3018 }
3019 Op::BatchElementwiseRegion {
3020 chain,
3021 num_batch_inputs,
3022 scalar_input_mask,
3023 input_modulus,
3024 prologue,
3025 prologue_input,
3026 } => {
3027 let n = *num_batch_inputs as usize;
3028 if n == 0 || chain.len() > 32 {
3029 panic!(
3030 "rlx-cuda BatchElementwiseRegion: num_batch_inputs={n} steps={}",
3031 chain.len()
3032 );
3033 }
3034 let slice_shape = rlx_ir::batch_region_slice_shape(&node.shape);
3035 let slice_elems = rlx_ir::batch_region_slice_elems(&node.shape, n)
3036 .expect("batch region static shape");
3037 let base_dst_off = (arena.offset(node.id) / 4) as u32;
3038 let use_single = rlx_ir::fk_batch_use_single_launch(n, *prologue);
3039 if use_single {
3040 let mut batch_input_offs = [0u32; 64];
3041 for i in 0..n {
3042 batch_input_offs[i] = (arena.offset(node.inputs[i]) / 4) as u32;
3043 }
3044 let input_offs_meta = [0u32; 16];
3045 let meta_arr = rlx_ir::encode_elementwise_region_meta(
3046 &input_offs_meta,
3047 chain,
3048 *prologue,
3049 &slice_shape,
3050 *prologue_input,
3051 );
3052 let meta = ctx
3053 .default_stream()
3054 .clone_htod(&meta_arr.to_vec())
3055 .expect("rlx-cuda: batch elementwise_region meta upload failed");
3056 let meta_idx = meta_buffers.len();
3057 meta_buffers.push(meta);
3058 let batch_vec: Vec<u32> = batch_input_offs[..n].to_vec();
3059 let batch_dev = ctx
3060 .default_stream()
3061 .clone_htod(&batch_vec)
3062 .expect("rlx-cuda: batch input offs upload failed");
3063 let batch_offs_idx = meta_buffers.len();
3064 meta_buffers.push(batch_dev);
3065 schedule.push(Step::BatchElementwiseRegion {
3066 slice_len: slice_elems,
3067 num_batch: n as u32,
3068 num_steps: chain.len() as u32,
3069 base_dst_off,
3070 slice_elems,
3071 batch_input_offs,
3072 batch_offs_idx,
3073 meta_idx,
3074 scalar_input_mask: *scalar_input_mask,
3075 input_modulus: *input_modulus,
3076 });
3077 } else {
3078 for i in 0..n {
3079 let mut input_offs = [0u32; 16];
3080 input_offs[0] = (arena.offset(node.inputs[i]) / 4) as u32;
3081 let meta_arr = rlx_ir::encode_elementwise_region_meta(
3082 &input_offs,
3083 chain,
3084 *prologue,
3085 &slice_shape,
3086 *prologue_input,
3087 );
3088 let meta = ctx
3089 .default_stream()
3090 .clone_htod(&meta_arr.to_vec())
3091 .expect("rlx-cuda: batch elementwise_region meta upload failed");
3092 let meta_idx = meta_buffers.len();
3093 meta_buffers.push(meta);
3094 let spatial =
3095 matches!(*prologue, rlx_ir::RegionPrologue::ResizeNearest2x);
3096 let grid = rlx_ir::PrologueLaunchGrid::from_output_shape(&slice_shape);
3097 schedule.push(Step::ElementwiseRegion {
3098 len: slice_elems,
3099 num_inputs: 1,
3100 num_steps: chain.len() as u32,
3101 dst_off: rlx_ir::batch_region_slice_dst_off_f32(
3102 base_dst_off,
3103 slice_elems,
3104 i,
3105 ),
3106 input_offs,
3107 scalar_input_mask: *scalar_input_mask,
3108 input_modulus: *input_modulus,
3109 meta_idx,
3110 spatial_prologue: spatial,
3111 prologue_w: grid.map(|g| g.width).unwrap_or(0),
3112 prologue_h: grid.map(|g| g.height).unwrap_or(0),
3113 prologue_nc: grid.map(|g| g.depth).unwrap_or(0),
3114 });
3115 }
3116 }
3117 }
3118 Op::ElementwiseRegion {
3119 chain,
3120 num_inputs,
3121 scalar_input_mask,
3122 input_modulus,
3123 prologue,
3124 prologue_input,
3125 } => {
3126 let n = *num_inputs as usize;
3133 if n > 16 || chain.len() > 32 {
3134 panic!(
3135 "rlx-cuda ElementwiseRegion: chain too large \
3136 (inputs={n}, steps={}). Caps: 16 / 32. \
3137 Run UnfuseElementwiseRegions to fall back \
3138 to atomic ops.",
3139 chain.len()
3140 );
3141 }
3142 let mut input_offs = [0u32; 16];
3143 for (i, &id) in node.inputs.iter().enumerate() {
3144 input_offs[i] = (arena.offset(id) / 4) as u32;
3145 }
3146 let meta_arr = rlx_ir::encode_elementwise_region_meta(
3147 &input_offs,
3148 chain,
3149 *prologue,
3150 &node.shape,
3151 *prologue_input,
3152 );
3153 let meta_data: Vec<u32> = meta_arr.to_vec();
3154 let meta = ctx
3155 .default_stream()
3156 .clone_htod(&meta_data)
3157 .expect("rlx-cuda: elementwise_region meta upload failed");
3158 let meta_idx = meta_buffers.len();
3159 meta_buffers.push(meta);
3160 let spatial = matches!(*prologue, rlx_ir::RegionPrologue::ResizeNearest2x);
3161 let grid = rlx_ir::PrologueLaunchGrid::from_output_shape(&node.shape);
3162 schedule.push(Step::ElementwiseRegion {
3163 len: elems,
3164 num_inputs: *num_inputs,
3165 num_steps: chain.len() as u32,
3166 dst_off: (arena.offset(node.id) / 4) as u32,
3167 input_offs,
3168 scalar_input_mask: *scalar_input_mask,
3169 input_modulus: *input_modulus,
3170 meta_idx,
3171 spatial_prologue: spatial,
3172 prologue_w: grid.map(|g| g.width).unwrap_or(0),
3173 prologue_h: grid.map(|g| g.height).unwrap_or(0),
3174 prologue_nc: grid.map(|g| g.depth).unwrap_or(0),
3175 });
3176 }
3177 Op::Reduce {
3178 op,
3179 axes,
3180 keep_dim: _,
3181 } => {
3182 let in_id = node.inputs[0];
3185 let in_dims = graph.node(in_id).shape.dims();
3186 if axes.len() != 1 || axes[0] != in_dims.len() - 1 {
3187 panic!(
3188 "rlx-cuda Reduce: only single last-axis supported \
3189 (got axes={axes:?}, rank={})",
3190 in_dims.len()
3191 );
3192 }
3193 let inner = in_dims.last().unwrap().unwrap_static() as u32;
3194 let outer = in_dims[..in_dims.len() - 1]
3195 .iter()
3196 .map(|d| d.unwrap_static() as u32)
3197 .product::<u32>()
3198 .max(1);
3199 schedule.push(Step::Reduce {
3200 outer,
3201 inner,
3202 in_off: (arena.offset(in_id) / 4) as u32,
3203 out_off: (arena.offset(node.id) / 4) as u32,
3204 op: reduce_op_id(*op),
3205 });
3206 }
3207 Op::Softmax { axis: _ } => {
3208 let in_id = node.inputs[0];
3209 let in_dims = graph.node(in_id).shape.dims();
3210 let inner = in_dims.last().unwrap().unwrap_static() as u32;
3211 let outer = in_dims[..in_dims.len() - 1]
3212 .iter()
3213 .map(|d| d.unwrap_static() as u32)
3214 .product::<u32>()
3215 .max(1);
3216 schedule.push(Step::Softmax {
3217 outer,
3218 inner,
3219 in_off: (arena.offset(in_id) / 4) as u32,
3220 out_off: (arena.offset(node.id) / 4) as u32,
3221 });
3222 }
3223 Op::LayerNorm { axis: _, eps } | Op::RmsNorm { axis: _, eps } => {
3224 let in_id = node.inputs[0];
3225 let in_dims = graph.node(in_id).shape.dims();
3226 let inner = in_dims.last().unwrap().unwrap_static() as u32;
3227 let total: u32 = in_dims.iter().map(|d| d.unwrap_static() as u32).product();
3228 let outer = total / inner.max(1);
3229 let is_layer = matches!(&node.op, Op::LayerNorm { .. });
3230 let gamma_id = node.inputs[1];
3231 let beta_id = if is_layer && node.inputs.len() >= 3 {
3232 node.inputs[2]
3233 } else {
3234 gamma_id
3235 };
3236 schedule.push(Step::LayerNorm {
3237 outer,
3238 inner,
3239 in_off: (arena.offset(in_id) / 4) as u32,
3240 out_off: (arena.offset(node.id) / 4) as u32,
3241 gamma_off: (arena.offset(gamma_id) / 4) as u32,
3242 beta_off: (arena.offset(beta_id) / 4) as u32,
3243 eps_bits: eps.to_bits(),
3244 op: if is_layer { 0 } else { 1 },
3245 });
3246 }
3247 Op::FusedResidualLN { has_bias, eps } => {
3248 let x_id = node.inputs[0];
3249 let r_id = node.inputs[1];
3250 let (bias_id, g_id, b_id) = if *has_bias {
3251 (node.inputs[2], node.inputs[3], node.inputs[4])
3252 } else {
3253 (x_id, node.inputs[2], node.inputs[3])
3254 };
3255 let in_dims = node.shape.dims();
3256 let inner = in_dims.last().unwrap().unwrap_static() as u32;
3257 let total: u32 = in_dims.iter().map(|d| d.unwrap_static() as u32).product();
3258 let outer = total / inner.max(1);
3259 schedule.push(Step::FusedResidualLn {
3260 outer,
3261 inner,
3262 in_off: (arena.offset(x_id) / 4) as u32,
3263 residual_off: (arena.offset(r_id) / 4) as u32,
3264 bias_off: (arena.offset(bias_id) / 4) as u32,
3265 gamma_off: (arena.offset(g_id) / 4) as u32,
3266 beta_off: (arena.offset(b_id) / 4) as u32,
3267 out_off: (arena.offset(node.id) / 4) as u32,
3268 eps_bits: eps.to_bits(),
3269 has_bias: if *has_bias { 1 } else { 0 },
3270 });
3271 }
3272 Op::FusedResidualRmsNorm { has_bias, eps } => {
3273 let x_id = node.inputs[0];
3274 let r_id = node.inputs[1];
3275 let (bias_id, g_id, b_id) = if *has_bias {
3276 (node.inputs[2], node.inputs[3], node.inputs[4])
3277 } else {
3278 (x_id, node.inputs[2], node.inputs[3])
3279 };
3280 let in_dims = node.shape.dims();
3281 let inner = in_dims.last().unwrap().unwrap_static() as u32;
3282 let total: u32 = in_dims.iter().map(|d| d.unwrap_static() as u32).product();
3283 let outer = total / inner.max(1);
3284 schedule.push(Step::FusedResidualRmsNorm {
3285 outer,
3286 inner,
3287 in_off: (arena.offset(x_id) / 4) as u32,
3288 residual_off: (arena.offset(r_id) / 4) as u32,
3289 bias_off: (arena.offset(bias_id) / 4) as u32,
3290 gamma_off: (arena.offset(g_id) / 4) as u32,
3291 beta_off: (arena.offset(b_id) / 4) as u32,
3292 out_off: (arena.offset(node.id) / 4) as u32,
3293 eps_bits: eps.to_bits(),
3294 has_bias: if *has_bias { 1 } else { 0 },
3295 });
3296 }
3297 Op::Gather { axis } => {
3298 let table_id = node.inputs[0];
3299 let idx_id = node.inputs[1];
3300 if *axis == 0 {
3301 let table_shape = graph.node(table_id).shape.dims();
3302 let idx_shape = graph.node(idx_id).shape.dims();
3303 let vocab = table_shape[0].unwrap_static() as u32;
3304 let dim: u32 = table_shape[1..]
3305 .iter()
3306 .map(|d| d.unwrap_static() as u32)
3307 .product::<u32>()
3308 .max(1);
3309 let n_idx: u32 =
3310 idx_shape.iter().map(|d| d.unwrap_static() as u32).product();
3311 schedule.push(Step::Gather {
3312 n_out: elems,
3313 n_idx,
3314 dim,
3315 vocab,
3316 in_off: (arena.offset(table_id) / 4) as u32,
3317 idx_off: (arena.offset(idx_id) / 4) as u32,
3318 out_off: (arena.offset(node.id) / 4) as u32,
3319 });
3320 } else {
3321 let table_shape = graph.node(table_id).shape.dims();
3322 let idx_shape = graph.node(idx_id).shape.dims();
3323 let outer: u32 = table_shape[..*axis]
3324 .iter()
3325 .map(|d| d.unwrap_static() as u32)
3326 .product::<u32>()
3327 .max(1);
3328 let trailing: u32 = table_shape[*axis + 1..]
3329 .iter()
3330 .map(|d| d.unwrap_static() as u32)
3331 .product::<u32>()
3332 .max(1);
3333 let axis_dim = table_shape[*axis].unwrap_static() as u32;
3334 let num_idx: u32 =
3335 idx_shape.iter().map(|d| d.unwrap_static() as u32).product();
3336 let total = outer * num_idx * trailing;
3337 schedule.push(Step::GatherAxis {
3338 total,
3339 outer,
3340 axis_dim,
3341 num_idx,
3342 trailing,
3343 table_off: (arena.offset(table_id) / 4) as u32,
3344 idx_off: (arena.offset(idx_id) / 4) as u32,
3345 out_off: (arena.offset(node.id) / 4) as u32,
3346 });
3347 }
3348 }
3349 Op::Narrow { axis, start, len } => {
3350 let in_id = node.inputs[0];
3351 let in_dims = graph.node(in_id).shape.dims();
3352 let outer: u32 = in_dims[..*axis]
3353 .iter()
3354 .map(|d| d.unwrap_static() as u32)
3355 .product::<u32>()
3356 .max(1);
3357 let inner: u32 = in_dims[*axis + 1..]
3358 .iter()
3359 .map(|d| d.unwrap_static() as u32)
3360 .product::<u32>()
3361 .max(1);
3362 let axis_in = in_dims[*axis].unwrap_static() as u32;
3363 schedule.push(Step::Narrow {
3364 total: elems,
3365 outer,
3366 inner,
3367 axis_in_size: axis_in,
3368 axis_out_size: *len as u32,
3369 start: *start as u32,
3370 in_off: (arena.offset(in_id) / 4) as u32,
3371 out_off: (arena.offset(node.id) / 4) as u32,
3372 });
3373 }
3374 Op::Transpose { perm } => {
3375 let in_id = node.inputs[0];
3376 let in_dims = graph.node(in_id).shape.dims();
3377 let rank = perm.len();
3378 let in_dims_u: Vec<u32> =
3379 in_dims.iter().map(|d| d.unwrap_static() as u32).collect();
3380 let mut in_strides = vec![1u32; rank];
3382 for i in (0..rank.saturating_sub(1)).rev() {
3383 in_strides[i] = in_strides[i + 1] * in_dims_u[i + 1];
3384 }
3385 let out_dims_u: Vec<u32> = perm.iter().map(|&i| in_dims_u[i]).collect();
3386 let strides_for_out: Vec<u32> = perm.iter().map(|&i| in_strides[i]).collect();
3387 let mut meta_data: Vec<u32> = Vec::with_capacity(rank * 2);
3388 meta_data.extend_from_slice(&out_dims_u);
3389 meta_data.extend_from_slice(&strides_for_out);
3390 let meta = ctx
3391 .default_stream()
3392 .clone_htod(&meta_data)
3393 .expect("rlx-cuda: meta upload failed");
3394 let meta_idx = meta_buffers.len();
3395 meta_buffers.push(meta);
3396 schedule.push(Step::Transpose {
3397 rank: rank as u32,
3398 out_total: elems,
3399 in_off: (arena.offset(in_id) / 4) as u32,
3400 out_off: (arena.offset(node.id) / 4) as u32,
3401 meta_idx,
3402 });
3403 }
3404 Op::Expand { target_shape } => {
3405 let in_id = node.inputs[0];
3406 let in_shape = graph.node(in_id).shape.dims();
3407 let rank = target_shape.len();
3408 if rank < in_shape.len() {
3409 panic!(
3410 "rlx-cuda Expand: cannot reduce rank (in={}, target={})",
3411 in_shape.len(),
3412 rank
3413 );
3414 }
3415 let out_dims: Vec<u32> = target_shape.iter().map(|&d| d as u32).collect();
3416 let pad = rank - in_shape.len();
3417 let mut in_dims: Vec<u32> = vec![1; pad];
3418 in_dims.extend(in_shape.iter().map(|d| d.unwrap_static() as u32));
3419 let mut in_strides_row = vec![1u32; rank];
3420 for i in (0..rank.saturating_sub(1)).rev() {
3421 in_strides_row[i] = in_strides_row[i + 1] * in_dims[i + 1];
3422 }
3423 let strides_for_out: Vec<u32> = (0..rank)
3424 .map(|i| {
3425 if in_dims[i] == 1 && out_dims[i] != 1 {
3426 0
3427 } else {
3428 in_strides_row[i]
3429 }
3430 })
3431 .collect();
3432 let mut meta_data: Vec<u32> = Vec::with_capacity(rank * 2);
3433 meta_data.extend_from_slice(&out_dims);
3434 meta_data.extend_from_slice(&strides_for_out);
3435 let meta = ctx
3436 .default_stream()
3437 .clone_htod(&meta_data)
3438 .expect("rlx-cuda: meta upload failed");
3439 let meta_idx = meta_buffers.len();
3440 meta_buffers.push(meta);
3441 schedule.push(Step::Expand {
3442 rank: rank as u32,
3443 out_total: elems,
3444 in_off: (arena.offset(in_id) / 4) as u32,
3445 out_off: (arena.offset(node.id) / 4) as u32,
3446 meta_idx,
3447 });
3448 }
3449 Op::Concat { axis } => {
3450 let mut start: u32 = 0;
3453 let out_dims = node.shape.dims();
3454 let outer: u32 = out_dims[..*axis]
3455 .iter()
3456 .map(|d| d.unwrap_static() as u32)
3457 .product::<u32>()
3458 .max(1);
3459 let inner: u32 = out_dims[*axis + 1..]
3460 .iter()
3461 .map(|d| d.unwrap_static() as u32)
3462 .product::<u32>()
3463 .max(1);
3464 let axis_out_size = out_dims[*axis].unwrap_static() as u32;
3465 for &in_id in &node.inputs {
3466 let in_dims = graph.node(in_id).shape.dims();
3467 let axis_in = in_dims[*axis].unwrap_static() as u32;
3468 let total: u32 = in_dims.iter().map(|d| d.unwrap_static() as u32).product();
3469 schedule.push(Step::Concat {
3470 total,
3471 outer,
3472 inner,
3473 axis_in_size: axis_in,
3474 axis_out_size,
3475 start,
3476 in_off: (arena.offset(in_id) / 4) as u32,
3477 out_off: (arena.offset(node.id) / 4) as u32,
3478 });
3479 start += axis_in;
3480 }
3481 }
3482 Op::Attention {
3483 num_heads,
3484 head_dim,
3485 mask_kind,
3486 score_scale: _,
3487 attn_logit_softcap: _,
3488 } => {
3489 let q_id = node.inputs[0];
3490 let k_id = node.inputs[1];
3491 let v_id = node.inputs[2];
3492 let q_shape = graph.node(q_id).shape.dims();
3493 let k_shape = graph.node(k_id).shape.dims();
3494 if q_shape.len() != 4 {
3495 panic!("rlx-cuda Attention: unfuse should have promoted to rank-4");
3496 }
3497 let q_ir = graph.node(q_id).shape.clone();
3498 let k_ir = graph.node(k_id).shape.clone();
3499 let geom = rlx_ir::attention_geom(&q_ir, &k_ir, *num_heads, *head_dim);
3500 let batch = geom.batch as u32;
3501 let heads = geom.heads as u32;
3502 let seq_q = geom.seq_q as u32;
3503 let seq_k = geom.seq_k as u32;
3504 let hd = *head_dim as u32;
3505 let scale = 1.0_f32 / (hd as f32).sqrt();
3506 let mask_shape = if matches!(mask_kind, MaskKind::Custom | MaskKind::Bias) {
3507 Some(graph.node(node.inputs[3]).shape.dims())
3508 } else {
3509 None
3510 };
3511 let packed_parent = packed_bshd_attn.get(&node.id).copied();
3512 let st = if let Some((_, head_width)) = packed_parent {
3513 let (qb, qh, qs) =
3514 rlx_ir::packed_bshd_qkv_strides(head_width as usize, hd, seq_q);
3515 let (ob, oh, os) =
3516 rlx_ir::strides_for_shape(node.shape.dims(), heads, hd, seq_q, false);
3517 let (mb, mh, mq, mk) = mask_shape
3518 .map(|m| rlx_ir::mask_strides_for_shape(m, heads, seq_q, seq_k))
3519 .unwrap_or_else(|| rlx_ir::mask_strides_bhsd(heads, seq_q, seq_k));
3520 rlx_ir::AttentionLaunchStrides {
3521 q_batch: qb,
3522 q_head: qh,
3523 q_seq: qs,
3524 k_batch: qb,
3525 k_head: qh,
3526 k_seq: qs,
3527 v_batch: qb,
3528 v_head: qh,
3529 v_seq: qs,
3530 o_batch: ob,
3531 o_head: oh,
3532 o_seq: os,
3533 mask_batch: mb,
3534 mask_head: mh,
3535 mask_q: mq,
3536 mask_k: mk,
3537 }
3538 } else {
3539 rlx_ir::attention_launch_strides(
3540 geom,
3541 q_shape,
3542 k_shape,
3543 graph.node(v_id).shape.dims(),
3544 node.shape.dims(),
3545 mask_shape,
3546 )
3547 };
3548 let (q_off, k_off, v_off) = if let Some((parent, head_width)) = packed_parent {
3549 let p = (arena.offset(parent) / 4) as u32;
3550 (
3551 p,
3552 p.saturating_add(head_width),
3553 p.saturating_add(head_width * 2),
3554 )
3555 } else {
3556 (
3557 (arena.offset(q_id) / 4) as u32,
3558 (arena.offset(k_id) / 4) as u32,
3559 (arena.offset(v_id) / 4) as u32,
3560 )
3561 };
3562 let (mask_kind_id, mask_off, window) = match mask_kind {
3563 MaskKind::None => (0u32, 0u32, 0u32),
3564 MaskKind::Causal => (1u32, 0u32, 0u32),
3565 MaskKind::Custom => (2u32, (arena.offset(node.inputs[3]) / 4) as u32, 0u32),
3566 MaskKind::SlidingWindow(w) => (3u32, 0u32, *w as u32),
3567 MaskKind::Bias => (4u32, (arena.offset(node.inputs[3]) / 4) as u32, 0u32),
3568 };
3569 schedule.push(Step::Attention {
3570 batch,
3571 heads,
3572 seq_q,
3573 seq_k,
3574 head_dim: hd,
3575 q_off,
3576 k_off,
3577 v_off,
3578 out_off: (arena.offset(node.id) / 4) as u32,
3579 mask_off,
3580 mask_kind: mask_kind_id,
3581 scale_bits: scale.to_bits(),
3582 window,
3583 seq_q_stride: st.mask_q,
3584 seq_k_stride: st.mask_k,
3585 mask_batch_stride: st.mask_batch,
3586 mask_head_stride: st.mask_head,
3587 q_batch_stride: st.q_batch,
3588 q_head_stride: st.q_head,
3589 q_seq_stride: st.q_seq,
3590 k_batch_stride: st.k_batch,
3591 k_head_stride: st.k_head,
3592 k_seq_stride: st.k_seq,
3593 v_batch_stride: st.v_batch,
3594 v_head_stride: st.v_head,
3595 v_seq_stride: st.v_seq,
3596 o_batch_stride: st.o_batch,
3597 o_head_stride: st.o_head,
3598 o_seq_stride: st.o_seq,
3599 });
3600 }
3601 Op::AttentionBackward {
3602 num_heads: _,
3603 head_dim,
3604 mask_kind,
3605 wrt,
3606 } => {
3607 use rlx_ir::op::AttentionBwdWrt;
3608 let q_id = node.inputs[0];
3609 let k_id = node.inputs[1];
3610 let v_id = node.inputs[2];
3611 let dy_id = node.inputs[3];
3612 let q_shape = graph.node(q_id).shape.dims();
3613 let k_shape = graph.node(k_id).shape.dims();
3614 if q_shape.len() != 4 {
3615 panic!("rlx-cuda AttentionBackward: unfuse should have promoted to rank-4");
3616 }
3617 let batch = q_shape[0].unwrap_static() as u32;
3618 let heads = q_shape[1].unwrap_static() as u32;
3619 let seq_q = q_shape[2].unwrap_static() as u32;
3620 let seq_k = k_shape[2].unwrap_static() as u32;
3621 let hd = *head_dim as u32;
3622 let scale = 1.0_f32 / (hd as f32).sqrt();
3623 let (mask_kind_id, mask_off, window) = match mask_kind {
3624 MaskKind::None => (0u32, 0u32, 0u32),
3625 MaskKind::Causal => (1u32, 0u32, 0u32),
3626 MaskKind::Custom => (2u32, (arena.offset(node.inputs[4]) / 4) as u32, 0u32),
3627 MaskKind::SlidingWindow(w) => (3u32, 0u32, *w as u32),
3628 MaskKind::Bias => (4u32, (arena.offset(node.inputs[4]) / 4) as u32, 0u32),
3629 };
3630 let wrt_id = match wrt {
3631 AttentionBwdWrt::Query => 0u32,
3632 AttentionBwdWrt::Key => 1u32,
3633 AttentionBwdWrt::Value => 2u32,
3634 };
3635 schedule.push(Step::AttentionBackward {
3636 batch,
3637 heads,
3638 seq_q,
3639 seq_k,
3640 head_dim: hd,
3641 q_off: (arena.offset(q_id) / 4) as u32,
3642 k_off: (arena.offset(k_id) / 4) as u32,
3643 v_off: (arena.offset(v_id) / 4) as u32,
3644 dy_off: (arena.offset(dy_id) / 4) as u32,
3645 out_off: (arena.offset(node.id) / 4) as u32,
3646 mask_off,
3647 mask_kind: mask_kind_id,
3648 scale_bits: scale.to_bits(),
3649 window,
3650 wrt: wrt_id,
3651 });
3652 }
3653 Op::Rope { head_dim, n_rot: _ } => {
3654 let x_id = node.inputs[0];
3655 let cos_id = node.inputs[1];
3656 let sin_id = node.inputs[2];
3657 let x_shape = graph.node(x_id).shape.dims();
3658 let last = x_shape.last().map(|d| d.unwrap_static()).unwrap_or(0);
3659 if !last.is_multiple_of(*head_dim) {
3660 panic!(
3661 "rlx-cuda Rope: last_dim {} not multiple of head_dim {}",
3662 last, head_dim
3663 );
3664 }
3665 if head_dim % 2 != 0 {
3666 panic!("rlx-cuda Rope: head_dim must be even");
3667 }
3668 let total: u32 = x_shape.iter().map(|d| d.unwrap_static() as u32).product();
3669 let seq = x_shape[x_shape.len() - 2].unwrap_static() as u32;
3670 schedule.push(Step::Rope {
3671 n_total: total,
3672 seq,
3673 head_dim: *head_dim as u32,
3674 half: (*head_dim / 2) as u32,
3675 in_off: (arena.offset(x_id) / 4) as u32,
3676 cos_off: (arena.offset(cos_id) / 4) as u32,
3677 sin_off: (arena.offset(sin_id) / 4) as u32,
3678 out_off: (arena.offset(node.id) / 4) as u32,
3679 last_dim: last as u32,
3680 });
3681 }
3682 Op::Cumsum { axis: _, exclusive } => {
3683 let in_id = node.inputs[0];
3684 let in_dims = graph.node(in_id).shape.dims();
3685 let inner = in_dims.last().unwrap().unwrap_static() as u32;
3686 let outer = in_dims[..in_dims.len() - 1]
3687 .iter()
3688 .map(|d| d.unwrap_static() as u32)
3689 .product::<u32>()
3690 .max(1);
3691 schedule.push(Step::Cumsum {
3692 outer,
3693 inner,
3694 in_off: (arena.offset(in_id) / 4) as u32,
3695 out_off: (arena.offset(node.id) / 4) as u32,
3696 exclusive: if *exclusive { 1 } else { 0 },
3697 });
3698 }
3699 Op::TopK { k } => {
3700 let in_id = node.inputs[0];
3701 let in_dims = graph.node(in_id).shape.dims();
3702 let inner = in_dims.last().unwrap().unwrap_static() as u32;
3703 let outer = in_dims[..in_dims.len() - 1]
3704 .iter()
3705 .map(|d| d.unwrap_static() as u32)
3706 .product::<u32>()
3707 .max(1);
3708 schedule.push(Step::TopK {
3709 outer,
3710 inner,
3711 k: *k as u32,
3712 in_off: (arena.offset(in_id) / 4) as u32,
3713 out_off: (arena.offset(node.id) / 4) as u32,
3714 });
3715 }
3716 Op::GroupedMatMul => {
3717 let in_id = node.inputs[0];
3718 let w_id = node.inputs[1];
3719 let idx_id = node.inputs[2];
3720 let in_dims = graph.node(in_id).shape.dims();
3721 let w_dims = graph.node(w_id).shape.dims();
3722 let m = in_dims[0].unwrap_static() as u32;
3723 let k = in_dims[1].unwrap_static() as u32;
3724 let n = w_dims[2].unwrap_static() as u32;
3725 let ne = w_dims[0].unwrap_static() as u32;
3726 schedule.push(Step::GroupedMatmul {
3727 m,
3728 k,
3729 n,
3730 num_experts: ne,
3731 in_off: (arena.offset(in_id) / 4) as u32,
3732 w_off: (arena.offset(w_id) / 4) as u32,
3733 idx_off: (arena.offset(idx_id) / 4) as u32,
3734 out_off: (arena.offset(node.id) / 4) as u32,
3735 });
3736 }
3737 Op::DequantGroupedMatMul { scheme } => {
3738 let in_id = node.inputs[0];
3739 let w_id = node.inputs[1];
3740 let idx_id = node.inputs[2];
3741 let in_dims = graph.node(in_id).shape.dims();
3742 let out_dims = node.shape.dims();
3743 let m = in_dims[0].unwrap_static() as u32;
3744 let k = in_dims[1].unwrap_static() as u32;
3745 let n = out_dims[out_dims.len() - 1].unwrap_static() as u32;
3746 let block_elems = scheme.gguf_block_size() as usize;
3747 let block_bytes = scheme.gguf_block_bytes() as usize;
3748 let slab_bytes = (k as usize * n as usize) / block_elems * block_bytes;
3749 let total_bytes = graph.node(w_id).shape.num_elements().unwrap();
3750 let ne = (total_bytes / slab_bytes.max(1)) as u32;
3751 schedule.push(Step::DequantGroupedMatmulGguf {
3752 m,
3753 k,
3754 n,
3755 num_experts: ne,
3756 scheme_id: crate::gguf_host::gguf_scheme_id(*scheme),
3757 x_byte_off: arena.offset(in_id) as u32,
3758 w_byte_off: arena.offset(w_id) as u32,
3759 idx_byte_off: arena.offset(idx_id) as u32,
3760 out_byte_off: arena.offset(node.id) as u32,
3761 });
3762 }
3763 Op::ScatterAdd => {
3764 let upd_id = node.inputs[0];
3765 let idx_id = node.inputs[1];
3766 let upd_dims = graph.node(upd_id).shape.dims();
3767 let out_dims = node.shape.dims();
3768 let num_updates = upd_dims[0].unwrap_static() as u32;
3769 let trailing: u32 = upd_dims
3770 .iter()
3771 .skip(1)
3772 .map(|d| d.unwrap_static() as u32)
3773 .product::<u32>()
3774 .max(1);
3775 let out_dim = out_dims[0].unwrap_static() as u32;
3776 let out_total = out_dim * trailing;
3777 let out_off = (arena.offset(node.id) / 4) as u32;
3778 schedule.push(Step::ScatterAddZero { out_off, out_total });
3779 schedule.push(Step::ScatterAddAcc {
3780 out_off,
3781 upd_off: (arena.offset(upd_id) / 4) as u32,
3782 idx_off: (arena.offset(idx_id) / 4) as u32,
3783 num_updates,
3784 trailing,
3785 out_dim,
3786 });
3787 }
3788 Op::DequantMatMul { scheme } => {
3789 use rlx_ir::quant::QuantScheme;
3790 let x_id = node.inputs[0];
3791 let w_id = node.inputs[1];
3792 let out_dims = node.shape.dims();
3793 let x_dims = graph.node(x_id).shape.dims();
3794 let m = out_dims[0].unwrap_static() as u32;
3795 let n = out_dims[1].unwrap_static() as u32;
3796 let k = x_dims[1].unwrap_static() as u32;
3797 if scheme.is_gguf() {
3798 schedule.push(Step::DequantMatmulGguf {
3799 m,
3800 k,
3801 n,
3802 scheme_id: crate::gguf_host::gguf_scheme_id(*scheme),
3803 x_byte_off: arena.offset(x_id) as u32,
3804 w_byte_off: arena.offset(w_id) as u32,
3805 out_byte_off: arena.offset(node.id) as u32,
3806 });
3807 } else {
3808 let (block_size, scheme_id) = match scheme {
3809 QuantScheme::Int8Block { block_size } => (*block_size, 0u32),
3810 QuantScheme::Int8BlockAsym { block_size } => (*block_size, 1u32),
3811 QuantScheme::Int4Block { block_size } => (*block_size, 2u32),
3812 QuantScheme::Fp8E4m3 => (1, 3u32),
3813 QuantScheme::Fp8E5m2 => (1, 4u32),
3814 QuantScheme::Nvfp4Block => (rlx_ir::NVFP4_GROUP_SIZE as u32, 5u32),
3815 other => panic!("rlx-cuda DequantMatMul: unsupported scheme {other:?}"),
3816 };
3817 let scale_id = node.inputs[2];
3818 let zp_id = node.inputs[3];
3819 schedule.push(Step::DequantMatmul {
3820 m,
3821 k,
3822 n,
3823 block_size,
3824 scheme_id,
3825 x_off: (arena.offset(x_id) / 4) as u32,
3826 w_off: (arena.offset(w_id) / 4) as u32,
3827 scale_off: (arena.offset(scale_id) / 4) as u32,
3828 zp_off: (arena.offset(zp_id) / 4) as u32,
3829 out_off: (arena.offset(node.id) / 4) as u32,
3830 });
3831 }
3832 }
3833 Op::SelectiveScan { state_size } => {
3834 if *state_size > 256 {
3835 panic!("rlx-cuda SelectiveScan: state_size {state_size} > 256 cap");
3836 }
3837 let x_id = node.inputs[0];
3838 let dt_id = node.inputs[1];
3839 let a_id = node.inputs[2];
3840 let b_id = node.inputs[3];
3841 let c_id = node.inputs[4];
3842 let in_dims = graph.node(x_id).shape.dims();
3843 schedule.push(Step::SelectiveScan {
3844 batch: in_dims[0].unwrap_static() as u32,
3845 seq: in_dims[1].unwrap_static() as u32,
3846 hidden: in_dims[2].unwrap_static() as u32,
3847 state_size: *state_size as u32,
3848 x_off: (arena.offset(x_id) / 4) as u32,
3849 delta_off: (arena.offset(dt_id) / 4) as u32,
3850 a_off: (arena.offset(a_id) / 4) as u32,
3851 b_off: (arena.offset(b_id) / 4) as u32,
3852 c_off: (arena.offset(c_id) / 4) as u32,
3853 out_off: (arena.offset(node.id) / 4) as u32,
3854 });
3855 }
3856 Op::Fft { inverse, norm } => {
3857 let in_id = node.inputs[0];
3858 let in_shape = graph.node(in_id).shape.clone();
3859 let meta = rlx_ir::fft::fft_meta(&in_shape);
3860 let dtype = in_shape.dtype();
3861 let use_gpu = matches!(dtype, rlx_ir::DType::F32)
3862 && meta.n_complex.is_power_of_two()
3863 && meta.n_complex >= 2;
3864 schedule.push(Step::Fft {
3865 src_byte_off: arena.offset(in_id) as u32,
3866 dst_byte_off: arena.offset(node.id) as u32,
3867 outer: meta.outer as u32,
3868 n_complex: meta.n_complex as u32,
3869 inverse: *inverse,
3870 norm_tag: norm.tag(),
3871 dtype_tag: fft_dtype_tag(dtype),
3872 use_gpu,
3873 });
3874 }
3875 Op::LogMel => {
3876 let spec_shape = graph.node(node.inputs[0]).shape.clone();
3877 let filt_shape = graph.node(node.inputs[1]).shape.clone();
3878 let meta = rlx_ir::audio::log_mel_meta(&spec_shape, &filt_shape)
3879 .unwrap_or_else(|e| panic!("Op::LogMel: {e}"));
3880 schedule.push(Step::LogMelHost {
3881 spec_byte_off: arena.offset(node.inputs[0]) as u32,
3882 filt_byte_off: arena.offset(node.inputs[1]) as u32,
3883 dst_byte_off: arena.offset(node.id) as u32,
3884 outer: meta.outer as u32,
3885 n_fft: meta.n_fft as u32,
3886 n_bins: meta.n_bins as u32,
3887 n_mels: meta.n_mels as u32,
3888 });
3889 }
3890 Op::LogMelBackward => {
3891 let spec_shape = graph.node(node.inputs[0]).shape.clone();
3892 let filt_shape = graph.node(node.inputs[1]).shape.clone();
3893 let meta = rlx_ir::audio::log_mel_meta(&spec_shape, &filt_shape)
3894 .unwrap_or_else(|e| panic!("Op::LogMelBackward: {e}"));
3895 schedule.push(Step::LogMelBackwardHost {
3896 spec_byte_off: arena.offset(node.inputs[0]) as u32,
3897 filt_byte_off: arena.offset(node.inputs[1]) as u32,
3898 dy_byte_off: arena.offset(node.inputs[2]) as u32,
3899 dst_byte_off: arena.offset(node.id) as u32,
3900 outer: meta.outer as u32,
3901 n_fft: meta.n_fft as u32,
3902 n_bins: meta.n_bins as u32,
3903 n_mels: meta.n_mels as u32,
3904 });
3905 }
3906 Op::WelchPeaks { k, n_segments } => {
3907 let spec_shape = graph.node(node.inputs[0]).shape.clone();
3908 let meta = rlx_ir::audio::welch_peaks_meta(&spec_shape, *k, *n_segments)
3909 .unwrap_or_else(|e| panic!("Op::WelchPeaks: {e}"));
3910 let use_gpu = rlx_ir::audio::welch_peaks_gpu_native_eligible(
3911 &spec_shape,
3912 *k,
3913 *n_segments,
3914 )
3915 .unwrap_or(false);
3916 if use_gpu {
3917 schedule.push(Step::WelchPeaksGpu {
3918 spec_off: (arena.offset(node.inputs[0]) / 4) as u32,
3919 dst_off: (arena.offset(node.id) / 4) as u32,
3920 welch_batch: meta.welch_batch as u32,
3921 n_fft: meta.n_fft as u32,
3922 n_segments: meta.n_segments as u32,
3923 k: meta.k as u32,
3924 n_bins: meta.n_bins as u32,
3925 });
3926 } else {
3927 schedule.push(Step::WelchPeaksHost {
3928 spec_byte_off: arena.offset(node.inputs[0]) as u32,
3929 dst_byte_off: arena.offset(node.id) as u32,
3930 welch_batch: meta.welch_batch as u32,
3931 n_fft: meta.n_fft as u32,
3932 n_segments: meta.n_segments as u32,
3933 k: meta.k as u32,
3934 });
3935 }
3936 }
3937 Op::Im2Col {
3938 kernel_size,
3939 stride,
3940 padding,
3941 dilation,
3942 } => {
3943 let x_shape = &graph.node(node.inputs[0]).shape;
3944 if kernel_size.len() != 2 || x_shape.rank() != 4 {
3945 panic!("rlx-cuda Im2Col: 2D NCHW only");
3946 }
3947 let n = match x_shape.dim(0) {
3948 rlx_ir::shape::Dim::Static(v) => v as u32,
3949 _ => 0,
3950 };
3951 let c_in = x_shape.dim(1).unwrap_static() as u32;
3952 let h = x_shape.dim(2).unwrap_static() as u32;
3953 let w = x_shape.dim(3).unwrap_static() as u32;
3954 let kh = kernel_size[0] as u32;
3955 let kw = kernel_size[1] as u32;
3956 let sh = stride.first().copied().unwrap_or(1) as u32;
3957 let sw = stride.get(1).copied().unwrap_or(1) as u32;
3958 let ph = padding.first().copied().unwrap_or(0) as u32;
3959 let pw = padding.get(1).copied().unwrap_or(0) as u32;
3960 let dh = dilation.first().copied().unwrap_or(1) as u32;
3961 let dw_dil = dilation.get(1).copied().unwrap_or(1) as u32;
3962 let h_out = rlx_ir::shape::conv2d_spatial_output(
3963 h as usize,
3964 kh as usize,
3965 sh as usize,
3966 ph as usize,
3967 dh as usize,
3968 ) as u32;
3969 let w_out = rlx_ir::shape::conv2d_spatial_output(
3970 w as usize,
3971 kw as usize,
3972 sw as usize,
3973 pw as usize,
3974 dw_dil as usize,
3975 ) as u32;
3976 schedule.push(Step::Im2ColHost {
3977 x_byte_off: arena.offset(node.inputs[0]) as u32,
3978 col_byte_off: arena.offset(node.id) as u32,
3979 n,
3980 c_in,
3981 h,
3982 w,
3983 h_out,
3984 w_out,
3985 kh,
3986 kw,
3987 sh,
3988 sw,
3989 ph,
3990 pw,
3991 dh,
3992 dw_dil,
3993 use_gpu: im2col_use_gpu(n, exec_mode),
3994 });
3995 }
3996 Op::GatedDeltaNet {
3997 state_size,
3998 carry_state,
3999 } => {
4000 if *state_size > rlx_cpu::gdn::GDN_MAX_STATE {
4001 panic!(
4002 "rlx-cuda GatedDeltaNet: state_size {state_size} > {}",
4003 rlx_cpu::gdn::GDN_MAX_STATE
4004 );
4005 }
4006 let q_id = node.inputs[0];
4007 let q_shape = &graph.node(q_id).shape;
4008 let state_off = if *carry_state {
4009 arena.offset(node.inputs[5])
4010 } else {
4011 0
4012 };
4013 schedule.push(Step::GatedDeltaNet {
4014 q_byte_off: arena.offset(q_id) as u32,
4015 k_byte_off: arena.offset(node.inputs[1]) as u32,
4016 v_byte_off: arena.offset(node.inputs[2]) as u32,
4017 g_byte_off: arena.offset(node.inputs[3]) as u32,
4018 beta_byte_off: arena.offset(node.inputs[4]) as u32,
4019 state_byte_off: state_off as u32,
4020 dst_byte_off: arena.offset(node.id) as u32,
4021 batch: q_shape.dim(0).unwrap_static() as u32,
4022 seq: q_shape.dim(1).unwrap_static() as u32,
4023 heads: q_shape.dim(2).unwrap_static() as u32,
4024 state_size: *state_size as u32,
4025 use_carry: *carry_state,
4026 });
4027 }
4028 Op::Custom { name, attrs, .. } => match name.as_str() {
4029 "llada2.group_limited_gate" => {
4030 let sig_id = node.inputs[0];
4031 let route_id = node.inputs[1];
4032 let n_elems = graph.node(sig_id).shape.num_elements().unwrap() as u32;
4033 let mut attr_buf = [0u8; 20];
4034 let n = attrs.len().min(20);
4035 attr_buf[..n].copy_from_slice(&attrs[..n]);
4036 schedule.push(Step::Llada2GroupLimitedGate {
4037 sig_off: (arena.offset(sig_id) / 4) as u32,
4038 route_off: (arena.offset(route_id) / 4) as u32,
4039 out_off: (arena.offset(node.id) / 4) as u32,
4040 n_elems,
4041 attrs: attr_buf,
4042 });
4043 }
4044 "umap.knn" => {
4045 let pw_id = node.inputs[0];
4046 let n = graph.node(pw_id).shape.dims()[0].unwrap_static() as u32;
4047 let k = u32::from_le_bytes(attrs[..4].try_into().unwrap());
4048 schedule.push(Step::UmapKnn {
4049 pairwise_off: (arena.offset(pw_id) / 4) as u32,
4050 out_off: (arena.offset(node.id) / 4) as u32,
4051 n,
4052 k,
4053 });
4054 }
4055 other => panic!("rlx-cuda: unsupported Op::Custom('{other}')"),
4056 },
4057
4058 Op::GaussianSplatRender {
4059 width,
4060 height,
4061 tile_size,
4062 radius_scale,
4063 alpha_cutoff,
4064 max_splat_steps,
4065 transmittance_threshold,
4066 max_list_entries,
4067 } => {
4068 let elem_len = |id: NodeId| -> u32 {
4069 graph.node(id).shape.num_elements().unwrap_or(0) as u32
4070 };
4071 schedule.push(Step::GaussianSplatRender {
4072 positions_off: arena.offset(node.inputs[0]) as u32,
4073 positions_len: elem_len(node.inputs[0]),
4074 scales_off: arena.offset(node.inputs[1]) as u32,
4075 scales_len: elem_len(node.inputs[1]),
4076 rotations_off: arena.offset(node.inputs[2]) as u32,
4077 rotations_len: elem_len(node.inputs[2]),
4078 opacities_off: arena.offset(node.inputs[3]) as u32,
4079 opacities_len: elem_len(node.inputs[3]),
4080 colors_off: arena.offset(node.inputs[4]) as u32,
4081 colors_len: elem_len(node.inputs[4]),
4082 sh_coeffs_off: arena.offset(node.inputs[5]) as u32,
4083 sh_coeffs_len: elem_len(node.inputs[5]),
4084 meta_off: arena.offset(node.inputs[6]) as u32,
4085 dst_off: arena.offset(node.id) as u32,
4086 dst_len: node.shape.num_elements().unwrap_or(0) as u32,
4087 width: *width,
4088 height: *height,
4089 tile_size: *tile_size,
4090 radius_scale: *radius_scale,
4091 alpha_cutoff: *alpha_cutoff,
4092 max_splat_steps: *max_splat_steps,
4093 transmittance_threshold: *transmittance_threshold,
4094 max_list_entries: *max_list_entries,
4095 });
4096 }
4097
4098 Op::GaussianSplatRenderBackward {
4099 width,
4100 height,
4101 tile_size,
4102 radius_scale,
4103 alpha_cutoff,
4104 max_splat_steps,
4105 transmittance_threshold,
4106 max_list_entries,
4107 loss_grad_clip,
4108 sh_band,
4109 max_anisotropy,
4110 } => {
4111 let elem_len = |id: NodeId| -> u32 {
4112 graph.node(id).shape.num_elements().unwrap_or(0) as u32
4113 };
4114 schedule.push(Step::GaussianSplatRenderBackward {
4115 positions_off: arena.offset(node.inputs[0]) as u32,
4116 positions_len: elem_len(node.inputs[0]),
4117 scales_off: arena.offset(node.inputs[1]) as u32,
4118 scales_len: elem_len(node.inputs[1]),
4119 rotations_off: arena.offset(node.inputs[2]) as u32,
4120 rotations_len: elem_len(node.inputs[2]),
4121 opacities_off: arena.offset(node.inputs[3]) as u32,
4122 opacities_len: elem_len(node.inputs[3]),
4123 colors_off: arena.offset(node.inputs[4]) as u32,
4124 colors_len: elem_len(node.inputs[4]),
4125 sh_coeffs_off: arena.offset(node.inputs[5]) as u32,
4126 sh_coeffs_len: elem_len(node.inputs[5]),
4127 meta_off: arena.offset(node.inputs[6]) as u32,
4128 d_loss_off: arena.offset(node.inputs[7]) as u32,
4129 d_loss_len: elem_len(node.inputs[7]),
4130 packed_off: arena.offset(node.id) as u32,
4131 packed_len: node.shape.num_elements().unwrap_or(0) as u32,
4132 width: *width,
4133 height: *height,
4134 tile_size: *tile_size,
4135 radius_scale: *radius_scale,
4136 alpha_cutoff: *alpha_cutoff,
4137 max_splat_steps: *max_splat_steps,
4138 transmittance_threshold: *transmittance_threshold,
4139 max_list_entries: *max_list_entries,
4140 loss_grad_clip: *loss_grad_clip,
4141 sh_band: *sh_band,
4142 max_anisotropy: *max_anisotropy,
4143 });
4144 }
4145
4146 Op::GaussianSplatPrepare {
4147 width,
4148 height,
4149 tile_size,
4150 radius_scale,
4151 alpha_cutoff,
4152 max_splat_steps,
4153 transmittance_threshold,
4154 max_list_entries,
4155 } => {
4156 let elem_len = |id: NodeId| -> u32 {
4157 graph.node(id).shape.num_elements().unwrap_or(0) as u32
4158 };
4159 schedule.push(Step::GaussianSplatPrepare {
4160 positions_off: arena.offset(node.inputs[0]) as u32,
4161 positions_len: elem_len(node.inputs[0]),
4162 scales_off: arena.offset(node.inputs[1]) as u32,
4163 scales_len: elem_len(node.inputs[1]),
4164 rotations_off: arena.offset(node.inputs[2]) as u32,
4165 rotations_len: elem_len(node.inputs[2]),
4166 opacities_off: arena.offset(node.inputs[3]) as u32,
4167 opacities_len: elem_len(node.inputs[3]),
4168 colors_off: arena.offset(node.inputs[4]) as u32,
4169 colors_len: elem_len(node.inputs[4]),
4170 sh_coeffs_off: arena.offset(node.inputs[5]) as u32,
4171 sh_coeffs_len: elem_len(node.inputs[5]),
4172 meta_off: arena.offset(node.inputs[6]) as u32,
4173 meta_len: elem_len(node.inputs[6]),
4174 prep_off: arena.offset(node.id) as u32,
4175 prep_len: node.shape.num_elements().unwrap_or(0) as u32,
4176 width: *width,
4177 height: *height,
4178 tile_size: *tile_size,
4179 radius_scale: *radius_scale,
4180 alpha_cutoff: *alpha_cutoff,
4181 max_splat_steps: *max_splat_steps,
4182 transmittance_threshold: *transmittance_threshold,
4183 max_list_entries: *max_list_entries,
4184 });
4185 }
4186
4187 Op::GaussianSplatRasterize {
4188 width,
4189 height,
4190 tile_size,
4191 alpha_cutoff,
4192 max_splat_steps,
4193 transmittance_threshold,
4194 max_list_entries,
4195 } => {
4196 let elem_len = |id: NodeId| -> u32 {
4197 graph.node(id).shape.num_elements().unwrap_or(0) as u32
4198 };
4199 let prep_id = node.inputs[0];
4200 let count = match &graph.node(prep_id).op {
4201 rlx_ir::Op::GaussianSplatPrepare { .. } => {
4202 elem_len(graph.node(prep_id).inputs[0]) / 3
4203 }
4204 _ => 1,
4205 };
4206 schedule.push(Step::GaussianSplatRasterize {
4207 prep_off: arena.offset(prep_id) as u32,
4208 prep_len: elem_len(prep_id),
4209 meta_off: arena.offset(node.inputs[1]) as u32,
4210 meta_len: elem_len(node.inputs[1]),
4211 dst_off: arena.offset(node.id) as u32,
4212 dst_len: node.shape.num_elements().unwrap_or(0) as u32,
4213 count,
4214 width: *width,
4215 height: *height,
4216 tile_size: *tile_size,
4217 alpha_cutoff: *alpha_cutoff,
4218 max_splat_steps: *max_splat_steps,
4219 transmittance_threshold: *transmittance_threshold,
4220 max_list_entries: *max_list_entries,
4221 });
4222 }
4223
4224 Op::Pool {
4225 kind,
4226 kernel_size,
4227 stride,
4228 padding,
4229 } => {
4230 let in_id = node.inputs[0];
4231 let in_dims = graph.node(in_id).shape.dims();
4232 let out_dims = node.shape.dims();
4233 let op_id = reduce_op_id(*kind);
4234 let in_off = (arena.offset(in_id) / 4) as u32;
4235 let out_off = (arena.offset(node.id) / 4) as u32;
4236 match kernel_size.len() {
4237 1 => {
4238 schedule.push(Step::Pool1d {
4239 n: in_dims[0].unwrap_static() as u32,
4240 c: in_dims[1].unwrap_static() as u32,
4241 l: in_dims[2].unwrap_static() as u32,
4242 l_out: out_dims[2].unwrap_static() as u32,
4243 kl: kernel_size[0] as u32,
4244 sl: stride[0] as u32,
4245 pl: padding[0] as u32,
4246 op: op_id,
4247 in_off,
4248 out_off,
4249 });
4250 }
4251 2 => {
4252 schedule.push(Step::Pool2d {
4253 n: in_dims[0].unwrap_static() as u32,
4254 c: in_dims[1].unwrap_static() as u32,
4255 h: in_dims[2].unwrap_static() as u32,
4256 w: in_dims[3].unwrap_static() as u32,
4257 h_out: out_dims[2].unwrap_static() as u32,
4258 w_out: out_dims[3].unwrap_static() as u32,
4259 kh: kernel_size[0] as u32,
4260 kw: kernel_size[1] as u32,
4261 sh: stride[0] as u32,
4262 sw: stride[1] as u32,
4263 ph: padding[0] as u32,
4264 pw: padding[1] as u32,
4265 op: op_id,
4266 in_off,
4267 out_off,
4268 });
4269 }
4270 3 => {
4271 schedule.push(Step::Pool3d {
4272 n: in_dims[0].unwrap_static() as u32,
4273 c: in_dims[1].unwrap_static() as u32,
4274 d: in_dims[2].unwrap_static() as u32,
4275 h: in_dims[3].unwrap_static() as u32,
4276 w: in_dims[4].unwrap_static() as u32,
4277 d_out: out_dims[2].unwrap_static() as u32,
4278 h_out: out_dims[3].unwrap_static() as u32,
4279 w_out: out_dims[4].unwrap_static() as u32,
4280 kd: kernel_size[0] as u32,
4281 kh: kernel_size[1] as u32,
4282 kw: kernel_size[2] as u32,
4283 sd: stride[0] as u32,
4284 sh: stride[1] as u32,
4285 sw: stride[2] as u32,
4286 pd: padding[0] as u32,
4287 ph: padding[1] as u32,
4288 pw: padding[2] as u32,
4289 op: op_id,
4290 in_off,
4291 out_off,
4292 });
4293 }
4294 other => panic!("rlx-cuda Pool: unsupported kernel rank {other}"),
4295 }
4296 }
4297 Op::LayerNorm2d { eps } => {
4298 let in_shape = &graph.node(node.inputs[0]).shape;
4299 schedule.push(Step::LayerNorm2d {
4300 src_off: (arena.offset(node.inputs[0]) / 4) as u32,
4301 g_off: (arena.offset(node.inputs[1]) / 4) as u32,
4302 b_off: (arena.offset(node.inputs[2]) / 4) as u32,
4303 dst_off: (arena.offset(node.id) / 4) as u32,
4304 n: in_shape.dim(0).unwrap_static() as u32,
4305 c: in_shape.dim(1).unwrap_static() as u32,
4306 h: in_shape.dim(2).unwrap_static() as u32,
4307 w: in_shape.dim(3).unwrap_static() as u32,
4308 eps_bits: eps.to_bits(),
4309 });
4310 }
4311 Op::ConvTranspose2d {
4312 kernel_size,
4313 stride,
4314 padding,
4315 dilation,
4316 output_padding: _,
4317 groups,
4318 } => {
4319 let in_shape = &graph.node(node.inputs[0]).shape;
4320 let out_shape = &node.shape;
4321 schedule.push(Step::ConvTranspose2d {
4322 src_off: (arena.offset(node.inputs[0]) / 4) as u32,
4323 w_off: (arena.offset(node.inputs[1]) / 4) as u32,
4324 dst_off: (arena.offset(node.id) / 4) as u32,
4325 n: in_shape.dim(0).unwrap_static() as u32,
4326 c_in: in_shape.dim(1).unwrap_static() as u32,
4327 h: in_shape.dim(2).unwrap_static() as u32,
4328 w_in: in_shape.dim(3).unwrap_static() as u32,
4329 c_out: out_shape.dim(1).unwrap_static() as u32,
4330 h_out: out_shape.dim(2).unwrap_static() as u32,
4331 w_out: out_shape.dim(3).unwrap_static() as u32,
4332 kh: kernel_size[0] as u32,
4333 kw: kernel_size[1] as u32,
4334 sh: stride.first().copied().unwrap_or(1) as u32,
4335 sw: stride.get(1).copied().unwrap_or(1) as u32,
4336 ph: padding.first().copied().unwrap_or(0) as u32,
4337 pw: padding.get(1).copied().unwrap_or(0) as u32,
4338 dh: dilation.first().copied().unwrap_or(1) as u32,
4339 dw: dilation.get(1).copied().unwrap_or(1) as u32,
4340 groups: *groups as u32,
4341 });
4342 }
4343 Op::GroupNorm { num_groups, eps } => {
4344 let in_shape = &graph.node(node.inputs[0]).shape;
4345 schedule.push(Step::GroupNorm {
4346 src_off: (arena.offset(node.inputs[0]) / 4) as u32,
4347 g_off: (arena.offset(node.inputs[1]) / 4) as u32,
4348 b_off: (arena.offset(node.inputs[2]) / 4) as u32,
4349 dst_off: (arena.offset(node.id) / 4) as u32,
4350 n: in_shape.dim(0).unwrap_static() as u32,
4351 c: in_shape.dim(1).unwrap_static() as u32,
4352 h: in_shape.dim(2).unwrap_static() as u32,
4353 w: in_shape.dim(3).unwrap_static() as u32,
4354 num_groups: *num_groups as u32,
4355 eps_bits: eps.to_bits(),
4356 });
4357 }
4358 Op::ResizeNearest2x => {
4359 let in_shape = &graph.node(node.inputs[0]).shape;
4360 schedule.push(Step::ResizeNearest2x {
4361 src_off: (arena.offset(node.inputs[0]) / 4) as u32,
4362 dst_off: (arena.offset(node.id) / 4) as u32,
4363 n: in_shape.dim(0).unwrap_static() as u32,
4364 c: in_shape.dim(1).unwrap_static() as u32,
4365 h: in_shape.dim(2).unwrap_static() as u32,
4366 w: in_shape.dim(3).unwrap_static() as u32,
4367 });
4368 }
4369 Op::Conv {
4370 kernel_size,
4371 stride,
4372 padding,
4373 dilation,
4374 groups,
4375 } => {
4376 let in_id = node.inputs[0];
4377 let w_id = node.inputs[1];
4378 let in_dims = graph.node(in_id).shape.dims();
4379 let w_dims = graph.node(w_id).shape.dims();
4380 let out_dims = node.shape.dims();
4381 let in_off = (arena.offset(in_id) / 4) as u32;
4382 let w_off = (arena.offset(w_id) / 4) as u32;
4383 let out_off = (arena.offset(node.id) / 4) as u32;
4384 match kernel_size.len() {
4385 1 => {
4386 schedule.push(Step::Conv1d {
4387 n: in_dims[0].unwrap_static() as u32,
4388 c_in: in_dims[1].unwrap_static() as u32,
4389 c_out: w_dims[0].unwrap_static() as u32,
4390 l: in_dims[2].unwrap_static() as u32,
4391 l_out: out_dims[2].unwrap_static() as u32,
4392 kl: kernel_size[0] as u32,
4393 sl: stride[0] as u32,
4394 pl: padding[0] as u32,
4395 dl: dilation[0] as u32,
4396 groups: *groups as u32,
4397 in_off,
4398 w_off,
4399 out_off,
4400 });
4401 }
4402 2 => {
4403 schedule.push(Step::Conv2d {
4404 n: in_dims[0].unwrap_static() as u32,
4405 c_in: in_dims[1].unwrap_static() as u32,
4406 c_out: w_dims[0].unwrap_static() as u32,
4407 h: in_dims[2].unwrap_static() as u32,
4408 w: in_dims[3].unwrap_static() as u32,
4409 h_out: out_dims[2].unwrap_static() as u32,
4410 w_out: out_dims[3].unwrap_static() as u32,
4411 kh: kernel_size[0] as u32,
4412 kw: kernel_size[1] as u32,
4413 sh: stride[0] as u32,
4414 sw: stride[1] as u32,
4415 ph: padding[0] as u32,
4416 pw: padding[1] as u32,
4417 dh: dilation[0] as u32,
4418 dw: dilation[1] as u32,
4419 groups: *groups as u32,
4420 in_off,
4421 w_off,
4422 out_off,
4423 });
4424 }
4425 3 => {
4426 schedule.push(Step::Conv3d {
4427 n: in_dims[0].unwrap_static() as u32,
4428 c_in: in_dims[1].unwrap_static() as u32,
4429 c_out: w_dims[0].unwrap_static() as u32,
4430 d: in_dims[2].unwrap_static() as u32,
4431 h: in_dims[3].unwrap_static() as u32,
4432 w: in_dims[4].unwrap_static() as u32,
4433 d_out: out_dims[2].unwrap_static() as u32,
4434 h_out: out_dims[3].unwrap_static() as u32,
4435 w_out: out_dims[4].unwrap_static() as u32,
4436 kd: kernel_size[0] as u32,
4437 kh: kernel_size[1] as u32,
4438 kw: kernel_size[2] as u32,
4439 sd: stride[0] as u32,
4440 sh: stride[1] as u32,
4441 sw: stride[2] as u32,
4442 pd: padding[0] as u32,
4443 ph: padding[1] as u32,
4444 pw: padding[2] as u32,
4445 dd: dilation[0] as u32,
4446 dh: dilation[1] as u32,
4447 dw: dilation[2] as u32,
4448 groups: *groups as u32,
4449 in_off,
4450 w_off,
4451 out_off,
4452 });
4453 }
4454 other => panic!("rlx-cuda Conv: unsupported kernel rank {other}"),
4455 }
4456 }
4457 Op::Sample {
4458 top_k,
4459 top_p,
4460 temperature,
4461 seed,
4462 } => {
4463 let in_id = node.inputs[0];
4464 let in_dims = graph.node(in_id).shape.dims();
4465 let inner = in_dims.last().unwrap().unwrap_static() as u32;
4466 let outer = in_dims[..in_dims.len() - 1]
4467 .iter()
4468 .map(|d| d.unwrap_static() as u32)
4469 .product::<u32>()
4470 .max(1);
4471 let is_greedy = *top_k == 0
4472 && (*top_p - 1.0).abs() < 1e-6
4473 && (*temperature - 1.0).abs() < 1e-6;
4474 if is_greedy {
4475 schedule.push(Step::Argmax {
4476 outer,
4477 inner,
4478 in_off: (arena.offset(in_id) / 4) as u32,
4479 out_off: (arena.offset(node.id) / 4) as u32,
4480 });
4481 } else {
4482 schedule.push(Step::Sample {
4483 outer,
4484 inner,
4485 in_off: (arena.offset(in_id) / 4) as u32,
4486 out_off: (arena.offset(node.id) / 4) as u32,
4487 top_k: *top_k as u32,
4488 top_p_bits: top_p.to_bits(),
4489 temp_bits: temperature.to_bits(),
4490 seed_lo: *seed as u32,
4491 seed_hi: (*seed >> 32) as u32,
4492 });
4493 }
4494 }
4495 Op::RmsNormBackwardInput { eps, .. }
4496 | Op::RmsNormBackwardGamma { eps, .. }
4497 | Op::RmsNormBackwardBeta { eps, .. } => {
4498 let x_shape = &graph.node(node.inputs[0]).shape;
4499 let h = x_shape.dim(x_shape.rank() - 1).unwrap_static() as u32;
4500 let rows = (x_shape.num_elements().unwrap() / h.max(1) as usize) as u32;
4501 let eps_bits = eps.to_bits();
4502 let off = |i: usize| arena.offset(node.inputs[i]) as u32;
4503 let common = (off(0), off(1), off(2), off(3), rows, h, eps_bits);
4504 match &node.op {
4505 Op::RmsNormBackwardInput { .. } => {
4506 schedule.push(Step::RmsNormBackwardInput {
4507 x_byte_off: common.0,
4508 gamma_byte_off: common.1,
4509 beta_byte_off: common.2,
4510 dy_byte_off: common.3,
4511 dx_byte_off: arena.offset(node.id) as u32,
4512 rows: common.4,
4513 h: common.5,
4514 eps_bits: common.6,
4515 });
4516 }
4517 Op::RmsNormBackwardGamma { .. } => {
4518 schedule.push(Step::RmsNormBackwardGamma {
4519 x_byte_off: common.0,
4520 gamma_byte_off: common.1,
4521 beta_byte_off: common.2,
4522 dy_byte_off: common.3,
4523 dgamma_byte_off: arena.offset(node.id) as u32,
4524 rows: common.4,
4525 h: common.5,
4526 eps_bits: common.6,
4527 });
4528 }
4529 Op::RmsNormBackwardBeta { .. } => {
4530 schedule.push(Step::RmsNormBackwardBeta {
4531 x_byte_off: common.0,
4532 gamma_byte_off: common.1,
4533 beta_byte_off: common.2,
4534 dy_byte_off: common.3,
4535 dbeta_byte_off: arena.offset(node.id) as u32,
4536 rows: common.4,
4537 h: common.5,
4538 eps_bits: common.6,
4539 });
4540 }
4541 _ => unreachable!(),
4542 }
4543 }
4544 Op::RopeBackward { head_dim, n_rot } => {
4545 let dy_shape = &graph.node(node.inputs[0]).shape;
4546 let (batch, seq, hidden) = if dy_shape.rank() >= 3 {
4547 (
4548 dy_shape.dim(0).unwrap_static() as u32,
4549 dy_shape.dim(1).unwrap_static() as u32,
4550 dy_shape.dim(2).unwrap_static() as u32,
4551 )
4552 } else {
4553 (
4554 1,
4555 dy_shape.dim(0).unwrap_static() as u32,
4556 dy_shape.dim(1).unwrap_static() as u32,
4557 )
4558 };
4559 let cos_len = graph.node(node.inputs[1]).shape.num_elements().unwrap() as u32;
4560 schedule.push(Step::RopeBackward {
4561 dy_byte_off: arena.offset(node.inputs[0]) as u32,
4562 cos_byte_off: arena.offset(node.inputs[1]) as u32,
4563 sin_byte_off: arena.offset(node.inputs[2]) as u32,
4564 dx_byte_off: arena.offset(node.id) as u32,
4565 batch,
4566 seq,
4567 hidden,
4568 head_dim: *head_dim as u32,
4569 n_rot: *n_rot as u32,
4570 cos_len,
4571 });
4572 }
4573 Op::CumsumBackward { exclusive, .. } => {
4574 let dy_shape = &graph.node(node.inputs[0]).shape;
4575 let cols = dy_shape.dim(dy_shape.rank() - 1).unwrap_static() as u32;
4576 let rows = (dy_shape.num_elements().unwrap() / cols.max(1) as usize) as u32;
4577 schedule.push(Step::CumsumBackward {
4578 dy_byte_off: arena.offset(node.inputs[0]) as u32,
4579 dx_byte_off: arena.offset(node.id) as u32,
4580 rows,
4581 cols,
4582 exclusive: *exclusive,
4583 });
4584 }
4585 Op::GatherBackward { .. } => {
4586 let dy_shape = &graph.node(node.inputs[0]).shape;
4587 let idx_shape = &graph.node(node.inputs[1]).shape;
4588 let out_shape = &node.shape;
4589 let rank = out_shape.rank();
4590 let axis = match &node.op {
4591 Op::GatherBackward { axis } => *axis,
4592 _ => 0,
4593 };
4594 let axis_u = if axis < 0 {
4595 (rank as i32 + axis) as usize
4596 } else {
4597 axis as usize
4598 };
4599 let outer: usize = (0..axis_u)
4600 .map(|i| dy_shape.dim(i).unwrap_static())
4601 .product::<usize>()
4602 .max(1);
4603 let num_idx = idx_shape.dim(axis_u).unwrap_static();
4604 let trailing: usize = (axis_u + 1..dy_shape.rank())
4605 .map(|i| dy_shape.dim(i).unwrap_static())
4606 .product::<usize>()
4607 .max(1);
4608 let axis_dim = out_shape.dim(axis_u).unwrap_static();
4609 schedule.push(Step::GatherBackward {
4610 dy_byte_off: arena.offset(node.inputs[0]) as u32,
4611 indices_byte_off: arena.offset(node.inputs[1]) as u32,
4612 dst_byte_off: arena.offset(node.id) as u32,
4613 outer: outer as u32,
4614 axis_dim: axis_dim as u32,
4615 num_idx: num_idx as u32,
4616 trailing: trailing as u32,
4617 });
4618 }
4619 Op::Conv2dBackwardInput {
4620 kernel_size,
4621 stride,
4622 padding,
4623 dilation,
4624 groups,
4625 } => {
4626 let dy_shape = &graph.node(node.inputs[0]).shape;
4627 let out_shape = &node.shape;
4628 if kernel_size.len() == 2 && dy_shape.rank() == 4 && out_shape.rank() == 4 {
4629 schedule.push(Step::Conv2dBackwardInput {
4630 dy_byte_off: arena.offset(node.inputs[0]) as u32,
4631 w_byte_off: arena.offset(node.inputs[1]) as u32,
4632 dx_byte_off: arena.offset(node.id) as u32,
4633 n: out_shape.dim(0).unwrap_static() as u32,
4634 c_in: out_shape.dim(1).unwrap_static() as u32,
4635 h: out_shape.dim(2).unwrap_static() as u32,
4636 w_in: out_shape.dim(3).unwrap_static() as u32,
4637 c_out: dy_shape.dim(1).unwrap_static() as u32,
4638 h_out: dy_shape.dim(2).unwrap_static() as u32,
4639 w_out: dy_shape.dim(3).unwrap_static() as u32,
4640 kh: kernel_size[0] as u32,
4641 kw: kernel_size[1] as u32,
4642 sh: stride.first().copied().unwrap_or(1) as u32,
4643 sw: stride.get(1).copied().unwrap_or(1) as u32,
4644 ph: padding.first().copied().unwrap_or(0) as u32,
4645 pw: padding.get(1).copied().unwrap_or(0) as u32,
4646 dh: dilation.first().copied().unwrap_or(1) as u32,
4647 dw: dilation.get(1).copied().unwrap_or(1) as u32,
4648 groups: *groups as u32,
4649 });
4650 } else {
4651 panic!("rlx-cuda: Conv2dBackwardInput expects 2-D conv on NCHW tensors");
4652 }
4653 }
4654 Op::Conv2dBackwardWeight {
4655 kernel_size,
4656 stride,
4657 padding,
4658 dilation,
4659 groups,
4660 } => {
4661 let x_shape = &graph.node(node.inputs[0]).shape;
4662 let dy_shape = &graph.node(node.inputs[1]).shape;
4663 if kernel_size.len() == 2 && x_shape.rank() == 4 && dy_shape.rank() == 4 {
4664 schedule.push(Step::Conv2dBackwardWeight {
4665 x_byte_off: arena.offset(node.inputs[0]) as u32,
4666 dy_byte_off: arena.offset(node.inputs[1]) as u32,
4667 dw_byte_off: arena.offset(node.id) as u32,
4668 n: x_shape.dim(0).unwrap_static() as u32,
4669 c_in: x_shape.dim(1).unwrap_static() as u32,
4670 h: x_shape.dim(2).unwrap_static() as u32,
4671 w: x_shape.dim(3).unwrap_static() as u32,
4672 c_out: dy_shape.dim(1).unwrap_static() as u32,
4673 h_out: dy_shape.dim(2).unwrap_static() as u32,
4674 w_out: dy_shape.dim(3).unwrap_static() as u32,
4675 kh: kernel_size[0] as u32,
4676 kw: kernel_size[1] as u32,
4677 sh: stride.first().copied().unwrap_or(1) as u32,
4678 sw: stride.get(1).copied().unwrap_or(1) as u32,
4679 ph: padding.first().copied().unwrap_or(0) as u32,
4680 pw: padding.get(1).copied().unwrap_or(0) as u32,
4681 dh: dilation.first().copied().unwrap_or(1) as u32,
4682 dw_dil: dilation.get(1).copied().unwrap_or(1) as u32,
4683 groups: *groups as u32,
4684 });
4685 } else {
4686 panic!("rlx-cuda: Conv2dBackwardWeight expects 2-D conv on NCHW tensors");
4687 }
4688 }
4689 Op::MaxPool2dBackward {
4690 kernel_size,
4691 stride,
4692 padding,
4693 } => {
4694 let x_shape = &graph.node(node.inputs[0]).shape;
4695 let dy_shape = &graph.node(node.inputs[1]).shape;
4696 if kernel_size.len() == 2 && x_shape.rank() == 4 && dy_shape.rank() == 4 {
4697 schedule.push(Step::MaxPool2dBackward {
4698 x_byte_off: arena.offset(node.inputs[0]) as u32,
4699 dy_byte_off: arena.offset(node.inputs[1]) as u32,
4700 dx_byte_off: arena.offset(node.id) as u32,
4701 n: x_shape.dim(0).unwrap_static() as u32,
4702 c: x_shape.dim(1).unwrap_static() as u32,
4703 h: x_shape.dim(2).unwrap_static() as u32,
4704 w: x_shape.dim(3).unwrap_static() as u32,
4705 h_out: dy_shape.dim(2).unwrap_static() as u32,
4706 w_out: dy_shape.dim(3).unwrap_static() as u32,
4707 kh: kernel_size[0] as u32,
4708 kw: kernel_size[1] as u32,
4709 sh: stride.first().copied().unwrap_or(1) as u32,
4710 sw: stride.get(1).copied().unwrap_or(1) as u32,
4711 ph: padding.first().copied().unwrap_or(0) as u32,
4712 pw: padding.get(1).copied().unwrap_or(0) as u32,
4713 });
4714 } else {
4715 panic!("rlx-cuda: MaxPool2dBackward expects 2-D pool on NCHW tensors");
4716 }
4717 }
4718 other => panic!(
4719 "rlx-cuda: op {other:?} not yet lowered. \
4720 Open a follow-up PR if you hit this — every other op \
4721 in the IR is wired."
4722 ),
4723 }
4724 }
4725
4726 let schedule = fuse_elementwise_chains(schedule);
4727
4728 let blas = cuda_blas();
4729 let needs_blas_lt = schedule_needs_blas_lt(&schedule);
4730 let needs_dnn = schedule_needs_dnn(&schedule);
4731 let blas_lt = if needs_blas_lt {
4732 cuda_blas_lt_handle()
4733 } else {
4734 None
4735 };
4736 let blas_lt_workspace = if needs_blas_lt {
4737 cuda_blas_lt_workspace()
4738 } else {
4739 None
4740 };
4741 let dnn = if needs_dnn { cuda_dnn_handle() } else { None };
4742 let dnn_workspace = if needs_dnn {
4743 cuda_dnn_workspace()
4744 } else {
4745 None
4746 };
4747
4748 let streams = match exec_mode {
4749 ExecMode::MultiStream(n) if n > 1 => {
4750 let mut v = Vec::with_capacity(n);
4751 for _ in 0..n {
4752 if let Ok(s) = ctx.new_stream() {
4753 v.push(s);
4754 }
4755 }
4756 v
4757 }
4758 _ => Vec::new(),
4759 };
4760
4761 let output_staging: Vec<F32HostSlot> = graph
4762 .outputs
4763 .iter()
4764 .map(|&id| {
4765 let elems = graph.node(id).shape.num_elements().unwrap_or(0);
4766 F32HostSlot::new(&ctx, elems, pinned_output_staging_enabled())
4767 })
4768 .collect();
4769
4770 let mut input_staging = HashMap::new();
4771 if pinned_input_staging_enabled(exec_mode) {
4772 for (name, &id) in &input_offsets {
4773 let elems = graph.node(id).shape.num_elements().unwrap_or(0);
4774 input_staging.insert(name.clone(), F32HostSlot::new(&ctx, elems, true));
4775 }
4776 }
4777
4778 let replay_event = if exec_mode == ExecMode::Graph {
4779 ctx.new_event(None).ok()
4780 } else {
4781 None
4782 };
4783
4784 let mut input_slot_names = Vec::new();
4785 let mut input_slots = Vec::new();
4786 for node in graph.nodes() {
4787 if let Op::Input { name } = &node.op {
4788 let off = if arena.has(node.id) {
4789 arena.offset(node.id)
4790 } else {
4791 0
4792 };
4793 let len = node.shape.num_elements().unwrap_or(0);
4794 input_slot_names.push(name.clone());
4795 input_slots.push((off, len));
4796 }
4797 }
4798
4799 let mut host_total = 0usize;
4800 let mut output_slots = Vec::new();
4801 for &id in &graph.outputs {
4802 let n = graph.node(id).shape.num_elements().unwrap_or(0);
4803 output_slots.push((host_total * 4, n));
4804 host_total += n;
4805 }
4806 let host_arena = vec![0.0f32; host_total];
4807
4808 Self {
4809 ctx,
4810 blas,
4811 blas_lt,
4812 blas_lt_workspace,
4813 dnn,
4814 dnn_workspace,
4815 half_act_scratch: None,
4816 dequant_scratch_off,
4817 graph,
4818 arena,
4819 schedule,
4820 input_offsets,
4821 param_offsets,
4822 meta_buffers,
4823 exec_mode,
4824 captured_graph: None,
4825 streams,
4826 active_extent: None,
4827 output_staging,
4828 input_staging,
4829 replay_event,
4830 gpu_handles: HashMap::new(),
4831 gpu_handle_feeds: HashMap::new(),
4832 gpu_handle_resident: std::collections::HashSet::new(),
4833 pending_read_indices: None,
4834 readback_plan_buf: Vec::new(),
4835 captured_readback_plan: None,
4836 input_slot_names,
4837 input_slots,
4838 output_slots,
4839 host_arena,
4840 }
4841 }
4842
4843 pub fn arena_ptr(&self) -> *const u8 {
4846 self.host_arena.as_ptr() as *const u8
4847 }
4848
4849 pub fn output_slots(&self) -> &[(usize, usize)] {
4850 &self.output_slots
4851 }
4852
4853 fn upload_slot_inputs(&mut self, inputs: &[&[f32]]) {
4854 let stream = self.ctx.default_stream();
4855 for (i, data) in inputs.iter().enumerate() {
4856 let Some(&(byte_off, max_elems)) = self.input_slots.get(i) else {
4857 break;
4858 };
4859 let off_f32 = byte_off / 4;
4860 let len = data.len().min(max_elems);
4861 if len == 0 {
4862 continue;
4863 }
4864 let mut slot = self.arena.f32_buf_mut().slice_mut(off_f32..off_f32 + len);
4865 if let Some(name) = self.input_slot_names.get(i) {
4866 if let Some(host) = self.input_staging.get_mut(name.as_str()) {
4867 host.copy_from_host(data);
4868 let _ = host.htod(&stream, &mut slot, len);
4869 continue;
4870 }
4871 }
4872 let _ = stream.memcpy_htod(&data[..len], &mut slot);
4873 }
4874 }
4875
4876 fn pack_host_arena(&mut self) {
4877 self.prepare_readback_plan();
4878 for &i in &self.readback_plan_buf {
4879 if i >= self.output_staging.len() || i >= self.output_slots.len() {
4880 continue;
4881 }
4882 let (byte_off, n) = self.output_slots[i];
4883 if n == 0 {
4884 continue;
4885 }
4886 let start = byte_off / 4;
4887 let end = start + n;
4888 if end <= self.host_arena.len() {
4889 self.output_staging[i].copy_into(&mut self.host_arena[start..end]);
4890 }
4891 }
4892 }
4893
4894 pub fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
4896 self.upload_slot_inputs(inputs);
4897 let _ = self.run_inner(&[]);
4898 self.pack_host_arena();
4899 &self.output_slots
4900 }
4901
4902 pub fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
4908 self.active_extent = extent;
4909 }
4910
4911 fn all_safe_for_active(&self) -> bool {
4912 self.schedule.iter().all(|s| s.safe_for_active_extent())
4913 }
4914
4915 pub fn output_dtypes(&self) -> Vec<rlx_ir::DType> {
4919 self.graph
4920 .outputs
4921 .iter()
4922 .map(|&id| self.graph.node(id).shape.dtype())
4923 .collect()
4924 }
4925
4926 pub fn set_param(&mut self, name: &str, data: &[f32]) {
4927 if let Some(&id) = self.param_offsets.get(name)
4928 && self.arena.has(id)
4929 {
4930 let off_f32 = self.arena.offset(id) / 4;
4931 let stream = self.ctx.default_stream();
4932 let mut slot = self
4933 .arena
4934 .f32_buf_mut()
4935 .slice_mut(off_f32..off_f32 + data.len());
4936 stream
4937 .memcpy_htod(data, &mut slot)
4938 .expect("rlx-cuda: param upload failed");
4939 }
4940 }
4941
4942 pub fn set_param_bytes(&mut self, name: &str, data: &[u8]) {
4944 if let Some(&id) = self.param_offsets.get(name)
4945 && self.arena.has(id)
4946 {
4947 let byte_off = self.arena.offset(id);
4948 let stream = self.ctx.default_stream();
4949 crate::gguf_host::upload_param_bytes(&stream, self.arena.f32_buf_mut(), byte_off, data);
4950 }
4951 }
4952
4953 pub fn set_param_half(&mut self, name: &str, dtype: crate::arena::HalfDtype, bits: &[u16]) {
4965 let id = match self.param_offsets.get(name) {
4966 Some(&id) if self.arena.has(id) => id,
4967 _ => return,
4968 };
4969 let f32_off = (self.arena.offset(id) / 4) as u32;
4970 let off = self
4971 .arena
4972 .register_half_param(&self.ctx, id, f32_off, bits.len(), dtype);
4973 let stream = self.ctx.default_stream();
4974 if let Some(buf) = self.arena.half_buffer.as_mut() {
4975 let mut slot = buf.slice_mut(off..off + bits.len());
4976 stream
4977 .memcpy_htod(bits, &mut slot)
4978 .expect("rlx-cuda: half-param upload failed");
4979 }
4980 }
4981
4982 pub fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
4983 self.run_read_outputs(inputs, None)
4984 }
4985
4986 pub fn run_read_outputs(
4988 &mut self,
4989 inputs: &[(&str, &[f32])],
4990 read_indices: Option<&[usize]>,
4991 ) -> Vec<Vec<f32>> {
4992 match read_indices {
4993 None => self.pending_read_indices = None,
4994 Some(ix) => {
4995 let buf = self.pending_read_indices.get_or_insert_with(Vec::new);
4996 buf.clear();
4997 buf.extend_from_slice(ix);
4998 normalize_read_indices(buf);
4999 }
5000 }
5001 let outs = self.run_inner(inputs);
5002 self.pending_read_indices = None;
5003 outs
5004 }
5005
5006 pub fn bind_gpu_handle(&mut self, name: &str, data: &[f32]) -> bool {
5007 if !self.input_offsets.contains_key(name) {
5008 return false;
5009 }
5010 self.gpu_handle_resident.remove(name);
5011 self.gpu_handles.insert(name.to_string(), data.to_vec());
5012 true
5013 }
5014
5015 pub fn has_gpu_handle(&self, name: &str) -> bool {
5016 self.gpu_handles.contains_key(name)
5017 }
5018
5019 pub fn set_gpu_handle_feed(&mut self, handle_name: &str, output_index: usize) {
5020 self.gpu_handle_feeds
5021 .insert(handle_name.to_string(), output_index);
5022 }
5023
5024 pub fn read_gpu_handle(&self, name: &str) -> Option<Vec<f32>> {
5025 if let Some(&out_idx) = self.gpu_handle_feeds.get(name) {
5026 if out_idx < self.graph.outputs.len() {
5027 let id = self.graph.outputs[out_idx];
5028 let stream = self.ctx.default_stream();
5029 let off_f32 = self.arena.offset(id) / 4;
5030 let n_f32 = self.arena.len_of(id) / 4;
5031 let mut host = vec![0f32; n_f32];
5032 let src = self.arena.f32_buf().slice(off_f32..off_f32 + n_f32);
5033 if stream.memcpy_dtoh(&src, host.as_mut_slice()).is_ok() {
5034 return Some(host);
5035 }
5036 }
5037 }
5038 if self.gpu_handle_resident.contains(name) {
5039 if let Some(&id) = self.input_offsets.get(name) {
5040 let stream = self.ctx.default_stream();
5041 let off_f32 = self.arena.offset(id) / 4;
5042 let n_f32 = self.arena.len_of(id) / 4;
5043 let mut host = vec![0f32; n_f32];
5044 let src = self.arena.f32_buf().slice(off_f32..off_f32 + n_f32);
5045 if stream.memcpy_dtoh(&src, host.as_mut_slice()).is_ok() {
5046 return Some(host);
5047 }
5048 }
5049 }
5050 self.gpu_handles.get(name).cloned()
5051 }
5052
5053 fn prepare_readback_plan(&mut self) {
5055 self.readback_plan_buf.clear();
5056 let n = self.graph.outputs.len();
5057 if let Some(ref want) = self.pending_read_indices {
5058 self.readback_plan_buf.extend_from_slice(want);
5059 normalize_read_indices(&mut self.readback_plan_buf);
5060 return;
5061 }
5062 self.readback_plan_buf.extend(0..n);
5063 }
5064
5065 fn propagate_gpu_handle_feeds_d2d(&mut self, stream: &Arc<cudarc::driver::CudaStream>) {
5066 let extent = self.active_extent;
5067 for (name, &out_idx) in &self.gpu_handle_feeds {
5068 if out_idx >= self.graph.outputs.len() {
5069 continue;
5070 }
5071 let out_id = self.graph.outputs[out_idx];
5072 let Some(&in_id) = self.input_offsets.get(name.as_str()) else {
5073 continue;
5074 };
5075 if in_id != out_id {
5076 let out_bytes = self.arena.len_of(out_id);
5077 let copy_bytes = match extent {
5078 Some((actual, upper)) if upper > 0 => {
5079 let stride = (out_bytes / (upper + 1)).max(4);
5080 (actual * stride).min(out_bytes)
5081 }
5082 _ => out_bytes,
5083 }
5084 .min(self.arena.len_of(in_id));
5085 let src_off = self.arena.offset(out_id) / 4;
5086 let dst_off = self.arena.offset(in_id) / 4;
5087 let n_f32 = copy_bytes / 4;
5088 if n_f32 > 0 && src_off != dst_off {
5089 let mut tmp = vec![0.0f32; n_f32];
5090 let src = self.arena.f32_buf().slice(src_off..src_off + n_f32);
5091 if stream.memcpy_dtoh(&src, &mut tmp).is_ok() {
5092 let mut dst = self.arena.f32_buf_mut().slice_mut(dst_off..dst_off + n_f32);
5093 let _ = stream.memcpy_htod(&tmp, &mut dst);
5094 }
5095 }
5096 }
5097 self.gpu_handle_resident.insert(name.clone());
5098 self.gpu_handles.insert(name.clone(), Vec::new());
5099 }
5100 }
5101
5102 fn stage_gpu_handle_inputs(
5103 &mut self,
5104 stream: &Arc<cudarc::driver::CudaStream>,
5105 inputs: &[(&str, &[f32])],
5106 ) {
5107 for (name, data) in &self.gpu_handles {
5108 if self.gpu_handle_resident.contains(name) || inputs.iter().any(|(n, _)| n == name) {
5109 continue;
5110 }
5111 if let Some(&id) = self.input_offsets.get(name.as_str())
5112 && self.arena.has(id)
5113 {
5114 let off_f32 = self.arena.offset(id) / 4;
5115 let mut slot = self
5116 .arena
5117 .f32_buf_mut()
5118 .slice_mut(off_f32..off_f32 + data.len());
5119 if let Some(host) = self.input_staging.get_mut(name.as_str()) {
5120 host.copy_from_host(data);
5121 let _ = host.htod(stream, &mut slot, data.len());
5122 } else {
5123 let _ = stream.memcpy_htod(data.as_slice(), &mut slot);
5124 }
5125 }
5126 }
5127 }
5128
5129 fn refresh_gpu_handles_from_staging(&mut self, plan: &[usize]) {
5130 if self.pending_read_indices.is_some() {
5131 return;
5132 }
5133 for (name, &out_idx) in &self.gpu_handle_feeds {
5134 if plan.contains(&out_idx) && out_idx < self.output_staging.len() {
5135 self.gpu_handles
5136 .insert(name.clone(), self.output_staging[out_idx].to_vec());
5137 }
5138 }
5139 }
5140
5141 fn run_inner(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
5142 let default_stream = self.ctx.default_stream();
5143 let stream = default_stream.clone();
5144
5145 self.stage_gpu_handle_inputs(&stream, inputs);
5146
5147 for &(name, data) in inputs {
5151 if let Some(&id) = self.input_offsets.get(name)
5152 && self.arena.has(id)
5153 {
5154 let off_f32 = self.arena.offset(id) / 4;
5155 let mut slot = self
5156 .arena
5157 .f32_buf_mut()
5158 .slice_mut(off_f32..off_f32 + data.len());
5159 if let Some(host) = self.input_staging.get_mut(name) {
5160 host.copy_from_host(data);
5161 host.htod(&stream, &mut slot, data.len())
5162 .expect("rlx-cuda: pinned input upload failed");
5163 } else {
5164 stream
5165 .memcpy_htod(data, &mut slot)
5166 .expect("rlx-cuda: input upload failed");
5167 }
5168 }
5169 }
5170
5171 let active = self.active_extent.filter(|_| self.all_safe_for_active());
5175 let scale = |full: u32| -> u32 {
5177 match active {
5178 Some((a, u)) if u > 0 => {
5179 let f = full as usize;
5180 (f * a).div_ceil(u).min(f) as u32
5181 }
5182 _ => full,
5183 }
5184 };
5185
5186 let graph_eligible = active.is_none()
5192 && self.exec_mode == ExecMode::Graph
5193 && schedule_graph_capture_safe(&self.schedule);
5194 let do_replay = graph_eligible && self.captured_graph.is_some();
5195 let do_capture = graph_eligible && self.captured_graph.is_none();
5196
5197 if do_replay {
5198 self.prepare_readback_plan();
5199 let plan_ok = self
5200 .captured_readback_plan
5201 .as_ref()
5202 .is_some_and(|p| p.as_slice() == self.readback_plan_buf.as_slice());
5203 if plan_ok {
5204 self.captured_graph
5205 .as_ref()
5206 .unwrap()
5207 .launch()
5208 .expect("rlx-cuda: graph replay failed");
5209 if let Some(evt) = &self.replay_event {
5210 evt.record(&stream)
5211 .expect("rlx-cuda: replay event record failed");
5212 evt.synchronize()
5213 .expect("rlx-cuda: replay event sync failed");
5214 } else {
5215 stream.synchronize().expect("rlx-cuda: stream sync failed");
5216 }
5217 run_tail_host_audio_ops(&self.schedule, &stream, self.arena.f32_buf_mut(), false);
5218 let plan = self.readback_plan_buf.clone();
5219 let read_all = plan.len() == self.graph.outputs.len();
5220 if read_all {
5223 self.fill_output_staging(&stream)
5224 .expect("rlx-cuda: output dtoh failed after replay");
5225 } else {
5226 self.fill_output_staging_indices(&stream, &plan)
5227 .expect("rlx-cuda: partial output dtoh failed after replay");
5228 }
5229 self.refresh_gpu_handles_from_staging(&plan);
5230 return self.outputs_from_staging_plan(&plan);
5231 }
5232 self.captured_graph = None;
5234 self.captured_readback_plan = None;
5235 }
5236 let _ = do_replay;
5237
5238 let mut capturing = false;
5239 if do_capture {
5240 capturing = stream
5241 .begin_capture(
5242 cudarc::driver::sys::CUstreamCaptureMode::CU_STREAM_CAPTURE_MODE_RELAXED,
5243 )
5244 .is_ok();
5245 }
5246
5247 let multi_stream =
5253 matches!(self.exec_mode, ExecMode::MultiStream(_)) && !self.streams.is_empty();
5254 let mut producer_of: HashMap<u32, usize> = HashMap::new();
5255 let mut last_event: HashMap<usize, cudarc::driver::CudaEvent> = HashMap::new();
5256 let mut rr_cursor: usize = 0;
5257
5258 for step in &self.schedule {
5264 #[cfg(feature = "nvtx")]
5265 let _nvtx = cudarc::nvtx::scoped_range(step_name(step));
5266 let _perf = rlx_ir::perfetto::TraceSpan::new(step_name(step), "cuda");
5269
5270 let assigned_idx: Option<usize> = if multi_stream {
5275 let (reads, _) = step_offsets(step);
5276 let mut producer_streams: std::collections::HashSet<usize> =
5277 std::collections::HashSet::new();
5278 for r in &reads {
5279 if let Some(&s) = producer_of.get(r) {
5280 producer_streams.insert(s);
5281 }
5282 }
5283 let chosen = if producer_streams.is_empty() {
5284 let s = rr_cursor % self.streams.len();
5285 rr_cursor += 1;
5286 s
5287 } else if producer_streams.len() == 1 {
5288 *producer_streams.iter().next().unwrap()
5289 } else {
5290 let chosen = *producer_streams.iter().next().unwrap();
5293 for s in &producer_streams {
5294 if *s != chosen
5295 && let Some(evt) = last_event.get(s)
5296 {
5297 let _ = self.streams[chosen].wait(evt);
5298 }
5299 }
5300 chosen
5301 };
5302 Some(chosen)
5303 } else {
5304 None
5305 };
5306 let stream: Arc<cudarc::driver::CudaStream> = match assigned_idx {
5307 Some(i) => self.streams[i].clone(),
5308 None => default_stream.clone(),
5309 };
5310 if multi_stream {
5313 if let Some(blas) = self.blas.as_ref() {
5314 let blas = blas.lock().unwrap();
5315 unsafe {
5316 let _ = cudarc::cublas::result::set_stream(
5317 *blas.handle(),
5318 stream.cu_stream() as _,
5319 );
5320 }
5321 }
5322 if let Some(handle) = self.dnn {
5323 unsafe {
5324 let _ = cudarc::cudnn::result::set_stream(
5325 handle,
5326 stream.cu_stream() as cudnn_sys::cudaStream_t,
5327 );
5328 }
5329 }
5330 }
5331 match step {
5332 Step::Matmul {
5333 m,
5334 k,
5335 n,
5336 a_off_f32,
5337 b_off_f32,
5338 c_off_f32,
5339 batch,
5340 a_batch_stride,
5341 b_batch_stride,
5342 c_batch_stride,
5343 has_bias,
5344 bias_off_f32,
5345 act_id,
5346 } => {
5347 if matmul_parity_mode() {
5348 let kernel = matmul_kernel(&self.ctx);
5349 let cfg = LaunchConfig {
5350 grid_dim: ((*n).div_ceil(64), (*m).div_ceil(64), *batch),
5351 block_dim: (16, 16, 1),
5352 shared_mem_bytes: 0,
5353 };
5354 let mut launcher = stream.launch_builder(&kernel.function);
5355 launcher
5356 .arg(self.arena.f32_buf_mut())
5357 .arg(m)
5358 .arg(k)
5359 .arg(n)
5360 .arg(a_off_f32)
5361 .arg(b_off_f32)
5362 .arg(c_off_f32)
5363 .arg(batch)
5364 .arg(a_batch_stride)
5365 .arg(b_batch_stride)
5366 .arg(c_batch_stride)
5367 .arg(has_bias)
5368 .arg(bias_off_f32)
5369 .arg(act_id);
5370 unsafe {
5371 launcher
5372 .launch(cfg)
5373 .expect("rlx-cuda: matmul (parity) launch failed");
5374 }
5375 if let Some(idx) = assigned_idx {
5376 if let Ok(evt) = stream.record_event(None) {
5377 last_event.insert(idx, evt);
5378 }
5379 let (_, writes) = step_offsets(step);
5380 for w in &writes {
5381 producer_of.insert(*w, idx);
5382 }
5383 }
5384 continue;
5385 }
5386
5387 let used_mixed = try_mixed_precision_gemm(
5393 &self.ctx,
5394 &mut self.arena,
5395 &mut self.half_act_scratch,
5396 self.blas.as_ref(),
5397 &stream,
5398 *m,
5399 *k,
5400 *n,
5401 *batch,
5402 *a_off_f32,
5403 *b_off_f32,
5404 *c_off_f32,
5405 );
5406 if used_mixed {
5407 if *has_bias != 0 || *act_id != 0xFFFFu32 {
5409 let kernel = matmul_epilogue_kernel(&self.ctx);
5410 let total = m * n * batch;
5411 let (grid, block) = dispatch_grid_1d(total, 256);
5412 let cfg = LaunchConfig {
5413 grid_dim: (grid, 1, 1),
5414 block_dim: (block, 1, 1),
5415 shared_mem_bytes: 0,
5416 };
5417 let mut launcher = stream.launch_builder(&kernel.function);
5418 launcher
5419 .arg(self.arena.f32_buf_mut())
5420 .arg(&total)
5421 .arg(n)
5422 .arg(c_off_f32)
5423 .arg(has_bias)
5424 .arg(bias_off_f32)
5425 .arg(act_id);
5426 unsafe {
5427 launcher
5428 .launch(cfg)
5429 .expect("rlx-cuda: matmul_epilogue (mixed) failed");
5430 }
5431 }
5432 if let Some(idx) = assigned_idx {
5434 if let Ok(evt) = stream.record_event(None) {
5435 last_event.insert(idx, evt);
5436 }
5437 let (_, writes) = step_offsets(step);
5438 for w in &writes {
5439 producer_of.insert(*w, idx);
5440 }
5441 }
5442 continue;
5443 }
5444
5445 let try_cublaslt = self.blas_lt.is_some()
5451 && self.blas_lt_workspace.is_some()
5452 && cublaslt_act_supported(*act_id);
5453 let used_cublaslt = if try_cublaslt {
5454 let lt_handle = self.blas_lt.unwrap();
5455 let mut workspace =
5456 self.blas_lt_workspace.as_ref().unwrap().lock().unwrap();
5457 let (workspace_ptr, _ws_record) = workspace.device_ptr_mut(&stream);
5458 let (arena_ptr, _record) = self.arena.f32_buf_mut().device_ptr_mut(&stream);
5459 let cu_stream = stream.cu_stream();
5460 let act = cublaslt_act_for(*act_id);
5461 let r = unsafe {
5462 cublaslt_matmul_fused(
5463 lt_handle,
5464 workspace_ptr,
5465 CUBLASLT_WORKSPACE_BYTES,
5466 arena_ptr,
5467 *m,
5468 *k,
5469 *n,
5470 *a_off_f32,
5471 *b_off_f32,
5472 *c_off_f32,
5473 *has_bias != 0,
5474 *bias_off_f32,
5475 act,
5476 *batch,
5477 *a_batch_stride,
5478 *b_batch_stride,
5479 *c_batch_stride,
5480 cu_stream,
5481 )
5482 };
5483 if let Err(ref e) = r {
5484 log_fallback("matmul.cublasLt", e);
5485 }
5486 r.is_ok()
5487 } else {
5488 false
5489 };
5490 if used_cublaslt {
5491 continue;
5492 }
5493
5494 let used_cublas = if let Some(blas) = self.blas.as_ref() {
5497 let blas = blas.lock().unwrap();
5498 let (arena_ptr_u64, _record) =
5499 self.arena.f32_buf_mut().device_ptr_mut(&stream);
5500 let a_dev = arena_ptr_u64 + (*a_off_f32 as u64) * 4;
5501 let b_dev = arena_ptr_u64 + (*b_off_f32 as u64) * 4;
5502 let c_dev = arena_ptr_u64 + (*c_off_f32 as u64) * 4;
5503 let alpha: f32 = 1.0;
5504 let beta: f32 = 0.0;
5505 let result = unsafe {
5511 if *batch == 1 {
5512 cudarc::cublas::result::sgemm(
5513 *blas.handle(),
5514 cublas_sys::cublasOperation_t::CUBLAS_OP_N,
5515 cublas_sys::cublasOperation_t::CUBLAS_OP_N,
5516 *n as i32,
5517 *m as i32,
5518 *k as i32,
5519 &alpha as *const f32,
5520 b_dev as *const f32,
5521 *n as i32,
5522 a_dev as *const f32,
5523 *k as i32,
5524 &beta as *const f32,
5525 c_dev as *mut f32,
5526 *n as i32,
5527 )
5528 } else {
5529 cudarc::cublas::result::sgemm_strided_batched(
5530 *blas.handle(),
5531 cublas_sys::cublasOperation_t::CUBLAS_OP_N,
5532 cublas_sys::cublasOperation_t::CUBLAS_OP_N,
5533 *n as i32,
5534 *m as i32,
5535 *k as i32,
5536 &alpha as *const f32,
5537 b_dev as *const f32,
5538 *n as i32,
5539 *b_batch_stride as i64,
5540 a_dev as *const f32,
5541 *k as i32,
5542 *a_batch_stride as i64,
5543 &beta as *const f32,
5544 c_dev as *mut f32,
5545 *n as i32,
5546 *c_batch_stride as i64,
5547 *batch as i32,
5548 )
5549 }
5550 };
5551 if let Err(ref e) = result {
5552 log_fallback("matmul.cublasSgemm", e);
5553 }
5554 result.is_ok()
5555 } else {
5556 false
5557 };
5558
5559 if used_cublas {
5560 if *has_bias != 0 || *act_id != 0xFFFFu32 {
5563 let kernel = matmul_epilogue_kernel(&self.ctx);
5564 let total = m * n * batch;
5565 let (grid, block) = dispatch_grid_1d(total, 256);
5566 let cfg = LaunchConfig {
5567 grid_dim: (grid, 1, 1),
5568 block_dim: (block, 1, 1),
5569 shared_mem_bytes: 0,
5570 };
5571 let mut launcher = stream.launch_builder(&kernel.function);
5572 launcher
5573 .arg(self.arena.f32_buf_mut())
5574 .arg(&total)
5575 .arg(n)
5576 .arg(c_off_f32)
5577 .arg(has_bias)
5578 .arg(bias_off_f32)
5579 .arg(act_id);
5580 unsafe {
5581 launcher
5582 .launch(cfg)
5583 .expect("rlx-cuda: matmul_epilogue launch failed");
5584 }
5585 }
5586 } else if use_wmma() {
5587 let kernel = matmul_wmma_kernel(&self.ctx);
5591 let cfg = LaunchConfig {
5592 grid_dim: ((*n).div_ceil(64), (*m).div_ceil(32), *batch),
5593 block_dim: (128, 1, 1),
5594 shared_mem_bytes: 0,
5595 };
5596 let mut launcher = stream.launch_builder(&kernel.function);
5597 launcher
5598 .arg(self.arena.f32_buf_mut())
5599 .arg(m)
5600 .arg(k)
5601 .arg(n)
5602 .arg(a_off_f32)
5603 .arg(b_off_f32)
5604 .arg(c_off_f32)
5605 .arg(batch)
5606 .arg(a_batch_stride)
5607 .arg(b_batch_stride)
5608 .arg(c_batch_stride);
5609 unsafe {
5610 launcher
5611 .launch(cfg)
5612 .expect("rlx-cuda: matmul_wmma launch failed");
5613 }
5614 if *has_bias != 0 || *act_id != 0xFFFFu32 {
5615 let kernel = matmul_epilogue_kernel(&self.ctx);
5616 let total = m * n * batch;
5617 let (grid, block) = dispatch_grid_1d(total, 256);
5618 let cfg = LaunchConfig {
5619 grid_dim: (grid, 1, 1),
5620 block_dim: (block, 1, 1),
5621 shared_mem_bytes: 0,
5622 };
5623 let mut launcher = stream.launch_builder(&kernel.function);
5624 launcher
5625 .arg(self.arena.f32_buf_mut())
5626 .arg(&total)
5627 .arg(n)
5628 .arg(c_off_f32)
5629 .arg(has_bias)
5630 .arg(bias_off_f32)
5631 .arg(act_id);
5632 unsafe {
5633 launcher
5634 .launch(cfg)
5635 .expect("rlx-cuda: matmul_epilogue (post-wmma) failed");
5636 }
5637 }
5638 } else {
5639 let kernel = matmul_kernel(&self.ctx);
5641 let cfg = LaunchConfig {
5642 grid_dim: ((*n).div_ceil(64), (*m).div_ceil(64), *batch),
5643 block_dim: (16, 16, 1),
5644 shared_mem_bytes: 0,
5645 };
5646 let mut launcher = stream.launch_builder(&kernel.function);
5647 launcher
5648 .arg(self.arena.f32_buf_mut())
5649 .arg(m)
5650 .arg(k)
5651 .arg(n)
5652 .arg(a_off_f32)
5653 .arg(b_off_f32)
5654 .arg(c_off_f32)
5655 .arg(batch)
5656 .arg(a_batch_stride)
5657 .arg(b_batch_stride)
5658 .arg(c_batch_stride)
5659 .arg(has_bias)
5660 .arg(bias_off_f32)
5661 .arg(act_id);
5662 unsafe {
5663 launcher
5664 .launch(cfg)
5665 .expect("rlx-cuda: matmul launch failed");
5666 }
5667 }
5668 }
5669 Step::Binary {
5670 n,
5671 a_off,
5672 b_off,
5673 c_off,
5674 op,
5675 } => {
5676 let n_s = scale(*n);
5677 if n_s == 0 {
5678 continue;
5679 }
5680 let kernel = binary_kernel(&self.ctx);
5681 let (grid, block) = dispatch_grid_1d(n_s, 256);
5682 let cfg = LaunchConfig {
5683 grid_dim: (grid, 1, 1),
5684 block_dim: (block, 1, 1),
5685 shared_mem_bytes: 0,
5686 };
5687 let mut launcher = stream.launch_builder(&kernel.function);
5688 launcher
5689 .arg(self.arena.f32_buf_mut())
5690 .arg(&n_s)
5691 .arg(a_off)
5692 .arg(b_off)
5693 .arg(c_off)
5694 .arg(op);
5695 unsafe {
5696 launcher
5697 .launch(cfg)
5698 .expect("rlx-cuda: binary launch failed");
5699 }
5700 }
5701 Step::ElementwiseRegion {
5702 len,
5703 num_inputs,
5704 num_steps,
5705 dst_off,
5706 input_offs: _,
5707 scalar_input_mask,
5708 input_modulus,
5709 meta_idx,
5710 spatial_prologue,
5711 prologue_w,
5712 prologue_h,
5713 prologue_nc,
5714 } => {
5715 let len_s = scale(*len);
5716 if len_s == 0 {
5717 continue;
5718 }
5719 let kernel = elementwise_region_kernel(&self.ctx);
5720 let ((gx, gy, gz), (bx, by, bz)) = if *spatial_prologue {
5721 dispatch_grid_prologue_nchw(*prologue_w, *prologue_h, *prologue_nc)
5722 } else {
5723 let (grid, block) = dispatch_grid_1d(len_s, 256);
5724 ((grid, 1, 1), (block, 1, 1))
5725 };
5726 let cfg = LaunchConfig {
5727 grid_dim: (gx, gy, gz),
5728 block_dim: (bx, by, bz),
5729 shared_mem_bytes: 0,
5730 };
5731 let mut launcher = stream.launch_builder(&kernel.function);
5732 launcher
5737 .arg(self.arena.f32_buf_mut())
5738 .arg(&len_s)
5739 .arg(num_inputs)
5740 .arg(num_steps)
5741 .arg(dst_off)
5742 .arg(&self.meta_buffers[*meta_idx])
5743 .arg(scalar_input_mask)
5744 .arg(input_modulus);
5745 unsafe {
5746 launcher
5747 .launch(cfg)
5748 .expect("rlx-cuda: elementwise_region launch failed");
5749 }
5750 }
5751 Step::BatchElementwiseRegion {
5752 slice_len,
5753 num_batch,
5754 num_steps,
5755 base_dst_off,
5756 slice_elems,
5757 batch_offs_idx,
5758 meta_idx,
5759 scalar_input_mask,
5760 input_modulus,
5761 ..
5762 } => {
5763 let slice_len_s = scale(*slice_len);
5764 let num_batch_s = scale(*num_batch);
5765 if slice_len_s == 0 || num_batch_s == 0 {
5766 continue;
5767 }
5768 let kernel = batch_elementwise_region_kernel(&self.ctx);
5769 let (grid_x, block_x) = dispatch_grid_1d(slice_len_s, 256);
5770 let cfg = LaunchConfig {
5771 grid_dim: (grid_x, 1, num_batch_s),
5772 block_dim: (block_x, 1, 1),
5773 shared_mem_bytes: 0,
5774 };
5775 let mut launcher = stream.launch_builder(&kernel.function);
5776 launcher
5777 .arg(self.arena.f32_buf_mut())
5778 .arg(&slice_len_s)
5779 .arg(&num_batch_s)
5780 .arg(num_steps)
5781 .arg(base_dst_off)
5782 .arg(slice_elems)
5783 .arg(&self.meta_buffers[*batch_offs_idx])
5784 .arg(&self.meta_buffers[*meta_idx])
5785 .arg(scalar_input_mask)
5786 .arg(input_modulus);
5787 unsafe {
5788 launcher
5789 .launch(cfg)
5790 .expect("rlx-cuda: batch_elementwise_region launch failed");
5791 }
5792 }
5793 Step::FusedBinaryUnary {
5794 n,
5795 a_off,
5796 b_off,
5797 out_off,
5798 bin_op,
5799 un_op,
5800 } => {
5801 let n_s = scale(*n);
5802 if n_s == 0 {
5803 continue;
5804 }
5805 let kernel = fused_binary_unary_kernel(&self.ctx);
5806 let (grid, block) = dispatch_grid_1d(n_s, 256);
5807 let cfg = LaunchConfig {
5808 grid_dim: (grid, 1, 1),
5809 block_dim: (block, 1, 1),
5810 shared_mem_bytes: 0,
5811 };
5812 let mut launcher = stream.launch_builder(&kernel.function);
5813 launcher
5814 .arg(self.arena.f32_buf_mut())
5815 .arg(&n_s)
5816 .arg(a_off)
5817 .arg(b_off)
5818 .arg(out_off)
5819 .arg(bin_op)
5820 .arg(un_op);
5821 unsafe {
5822 launcher
5823 .launch(cfg)
5824 .expect("rlx-cuda: fused_binary_unary launch failed");
5825 }
5826 }
5827 Step::Unary {
5828 n,
5829 in_off,
5830 out_off,
5831 op,
5832 } => {
5833 let n_s = scale(*n);
5834 if n_s == 0 {
5835 continue;
5836 }
5837 let kernel = unary_kernel(&self.ctx);
5838 let (grid, block) = dispatch_grid_1d(n_s, 256);
5839 let cfg = LaunchConfig {
5840 grid_dim: (grid, 1, 1),
5841 block_dim: (block, 1, 1),
5842 shared_mem_bytes: 0,
5843 };
5844 let mut launcher = stream.launch_builder(&kernel.function);
5845 launcher
5846 .arg(self.arena.f32_buf_mut())
5847 .arg(&n_s)
5848 .arg(in_off)
5849 .arg(out_off)
5850 .arg(op);
5851 unsafe {
5852 launcher.launch(cfg).expect("rlx-cuda: unary launch failed");
5853 }
5854 }
5855 Step::Compare {
5856 n,
5857 a_off,
5858 b_off,
5859 c_off,
5860 op,
5861 } => {
5862 let n_s = scale(*n);
5863 if n_s == 0 {
5864 continue;
5865 }
5866 let kernel = compare_kernel(&self.ctx);
5867 let (grid, block) = dispatch_grid_1d(n_s, 256);
5868 let cfg = LaunchConfig {
5869 grid_dim: (grid, 1, 1),
5870 block_dim: (block, 1, 1),
5871 shared_mem_bytes: 0,
5872 };
5873 let mut launcher = stream.launch_builder(&kernel.function);
5874 launcher
5875 .arg(self.arena.f32_buf_mut())
5876 .arg(&n_s)
5877 .arg(a_off)
5878 .arg(b_off)
5879 .arg(c_off)
5880 .arg(op);
5881 unsafe {
5882 launcher
5883 .launch(cfg)
5884 .expect("rlx-cuda: compare launch failed");
5885 }
5886 }
5887 Step::Where {
5888 n,
5889 cond_off,
5890 x_off,
5891 y_off,
5892 out_off,
5893 } => {
5894 let n_s = scale(*n);
5895 if n_s == 0 {
5896 continue;
5897 }
5898 let kernel = where_kernel(&self.ctx);
5899 let (grid, block) = dispatch_grid_1d(n_s, 256);
5900 let cfg = LaunchConfig {
5901 grid_dim: (grid, 1, 1),
5902 block_dim: (block, 1, 1),
5903 shared_mem_bytes: 0,
5904 };
5905 let mut launcher = stream.launch_builder(&kernel.function);
5906 launcher
5907 .arg(self.arena.f32_buf_mut())
5908 .arg(&n_s)
5909 .arg(cond_off)
5910 .arg(x_off)
5911 .arg(y_off)
5912 .arg(out_off);
5913 unsafe {
5914 launcher.launch(cfg).expect("rlx-cuda: where launch failed");
5915 }
5916 }
5917 Step::Reduce {
5918 outer,
5919 inner,
5920 in_off,
5921 out_off,
5922 op,
5923 } => {
5924 let outer_s = scale(*outer);
5925 if outer_s == 0 {
5926 continue;
5927 }
5928 let kernel = reduce_kernel(&self.ctx);
5929 let cfg = LaunchConfig {
5930 grid_dim: (outer_s, 1, 1),
5931 block_dim: (256, 1, 1),
5932 shared_mem_bytes: 0,
5933 };
5934 let mut launcher = stream.launch_builder(&kernel.function);
5935 launcher
5936 .arg(self.arena.f32_buf_mut())
5937 .arg(&outer_s)
5938 .arg(inner)
5939 .arg(in_off)
5940 .arg(out_off)
5941 .arg(op);
5942 unsafe {
5943 launcher
5944 .launch(cfg)
5945 .expect("rlx-cuda: reduce launch failed");
5946 }
5947 }
5948 Step::Softmax {
5949 outer,
5950 inner,
5951 in_off,
5952 out_off,
5953 } => {
5954 let outer_s = scale(*outer);
5955 if outer_s == 0 {
5956 continue;
5957 }
5958 let kernel = softmax_kernel(&self.ctx);
5959 let cfg = LaunchConfig {
5960 grid_dim: (outer_s, 1, 1),
5961 block_dim: (256, 1, 1),
5962 shared_mem_bytes: 0,
5963 };
5964 let mut launcher = stream.launch_builder(&kernel.function);
5965 launcher
5966 .arg(self.arena.f32_buf_mut())
5967 .arg(&outer_s)
5968 .arg(inner)
5969 .arg(in_off)
5970 .arg(out_off);
5971 unsafe {
5972 launcher
5973 .launch(cfg)
5974 .expect("rlx-cuda: softmax launch failed");
5975 }
5976 }
5977 Step::LayerNorm {
5978 outer,
5979 inner,
5980 in_off,
5981 out_off,
5982 gamma_off,
5983 beta_off,
5984 eps_bits,
5985 op,
5986 } => {
5987 let outer_s = scale(*outer);
5988 if outer_s == 0 {
5989 continue;
5990 }
5991 let kernel = layernorm_kernel(&self.ctx);
5992 let cfg = LaunchConfig {
5993 grid_dim: (outer_s, 1, 1),
5994 block_dim: (256, 1, 1),
5995 shared_mem_bytes: 0,
5996 };
5997 let mut launcher = stream.launch_builder(&kernel.function);
5998 launcher
5999 .arg(self.arena.f32_buf_mut())
6000 .arg(&outer_s)
6001 .arg(inner)
6002 .arg(in_off)
6003 .arg(out_off)
6004 .arg(gamma_off)
6005 .arg(beta_off)
6006 .arg(eps_bits)
6007 .arg(op);
6008 unsafe {
6009 launcher
6010 .launch(cfg)
6011 .expect("rlx-cuda: layernorm launch failed");
6012 }
6013 }
6014 Step::FusedResidualLn {
6015 outer,
6016 inner,
6017 in_off,
6018 residual_off,
6019 bias_off,
6020 gamma_off,
6021 beta_off,
6022 out_off,
6023 eps_bits,
6024 has_bias,
6025 } => {
6026 let outer_s = scale(*outer);
6027 if outer_s == 0 {
6028 continue;
6029 }
6030 let kernel = fused_residual_ln_kernel(&self.ctx);
6031 let cfg = LaunchConfig {
6032 grid_dim: (outer_s, 1, 1),
6033 block_dim: (256, 1, 1),
6034 shared_mem_bytes: 0,
6035 };
6036 let mut launcher = stream.launch_builder(&kernel.function);
6037 launcher
6038 .arg(self.arena.f32_buf_mut())
6039 .arg(&outer_s)
6040 .arg(inner)
6041 .arg(in_off)
6042 .arg(residual_off)
6043 .arg(bias_off)
6044 .arg(gamma_off)
6045 .arg(beta_off)
6046 .arg(out_off)
6047 .arg(eps_bits)
6048 .arg(has_bias);
6049 unsafe {
6050 launcher
6051 .launch(cfg)
6052 .expect("rlx-cuda: fused_residual_ln launch failed");
6053 }
6054 }
6055 Step::FusedResidualRmsNorm {
6056 outer,
6057 inner,
6058 in_off,
6059 residual_off,
6060 bias_off,
6061 gamma_off,
6062 beta_off,
6063 out_off,
6064 eps_bits,
6065 has_bias,
6066 } => {
6067 let outer_s = scale(*outer);
6068 if outer_s == 0 {
6069 continue;
6070 }
6071 let kernel = fused_residual_rms_norm_kernel(&self.ctx);
6072 let cfg = LaunchConfig {
6073 grid_dim: (outer_s, 1, 1),
6074 block_dim: (256, 1, 1),
6075 shared_mem_bytes: 0,
6076 };
6077 let mut launcher = stream.launch_builder(&kernel.function);
6078 launcher
6079 .arg(self.arena.f32_buf_mut())
6080 .arg(&outer_s)
6081 .arg(inner)
6082 .arg(in_off)
6083 .arg(residual_off)
6084 .arg(bias_off)
6085 .arg(gamma_off)
6086 .arg(beta_off)
6087 .arg(out_off)
6088 .arg(eps_bits)
6089 .arg(has_bias);
6090 unsafe {
6091 launcher
6092 .launch(cfg)
6093 .expect("rlx-cuda: fused_residual_rms_norm launch failed");
6094 }
6095 }
6096 Step::Gather {
6097 n_out,
6098 n_idx,
6099 dim,
6100 vocab,
6101 in_off,
6102 idx_off,
6103 out_off,
6104 } => {
6105 let kernel = gather_kernel(&self.ctx);
6106 let (grid, block) = dispatch_grid_1d(*n_out, 256);
6107 let cfg = LaunchConfig {
6108 grid_dim: (grid, 1, 1),
6109 block_dim: (block, 1, 1),
6110 shared_mem_bytes: 0,
6111 };
6112 let mut launcher = stream.launch_builder(&kernel.function);
6113 launcher
6114 .arg(self.arena.f32_buf_mut())
6115 .arg(n_out)
6116 .arg(n_idx)
6117 .arg(dim)
6118 .arg(vocab)
6119 .arg(in_off)
6120 .arg(idx_off)
6121 .arg(out_off);
6122 unsafe {
6123 launcher
6124 .launch(cfg)
6125 .expect("rlx-cuda: gather launch failed");
6126 }
6127 }
// Launches `gather_axis_kernel` on a 1-D grid sized by
// `dispatch_grid_1d(total, 256)`.
Step::GatherAxis {
    total,
    outer,
    axis_dim,
    num_idx,
    trailing,
    table_off,
    idx_off,
    out_off,
} => {
    let kernel = gather_axis_kernel(&self.ctx);
    let (grid, block) = dispatch_grid_1d(*total, 256);
    let cfg = LaunchConfig {
        grid_dim: (grid, 1, 1),
        block_dim: (block, 1, 1),
        shared_mem_bytes: 0,
    };
    let mut launcher = stream.launch_builder(&kernel.function);
    // Arena buffer first, then scalars in kernel-signature order.
    launcher
        .arg(self.arena.f32_buf_mut())
        .arg(total)
        .arg(outer)
        .arg(axis_dim)
        .arg(num_idx)
        .arg(trailing)
        .arg(table_off)
        .arg(idx_off)
        .arg(out_off);
    // SAFETY: relies on the .arg list above matching the kernel's parameter
    // types and order.
    unsafe {
        launcher
            .launch(cfg)
            .expect("rlx-cuda: gather_axis launch failed");
    }
}
// Launches `narrow_kernel` on a 1-D grid sized by `dispatch_grid_1d(total, 256)`.
// Takes the same scalar parameter list as the Concat arm below.
Step::Narrow {
    total,
    outer,
    inner,
    axis_in_size,
    axis_out_size,
    start,
    in_off,
    out_off,
} => {
    let kernel = narrow_kernel(&self.ctx);
    let (grid, block) = dispatch_grid_1d(*total, 256);
    let cfg = LaunchConfig {
        grid_dim: (grid, 1, 1),
        block_dim: (block, 1, 1),
        shared_mem_bytes: 0,
    };
    let mut launcher = stream.launch_builder(&kernel.function);
    // Arena buffer first, then scalars in kernel-signature order.
    launcher
        .arg(self.arena.f32_buf_mut())
        .arg(total)
        .arg(outer)
        .arg(inner)
        .arg(axis_in_size)
        .arg(axis_out_size)
        .arg(start)
        .arg(in_off)
        .arg(out_off);
    // SAFETY: relies on the .arg list above matching the kernel's parameter
    // types and order.
    unsafe {
        launcher
            .launch(cfg)
            .expect("rlx-cuda: narrow launch failed");
    }
}
// Launches `argmax_kernel` on a 1-D grid sized by `dispatch_grid_1d(outer, 256)`
// (the grid covers `outer`, not `outer * inner`).
Step::Argmax {
    outer,
    inner,
    in_off,
    out_off,
} => {
    let kernel = argmax_kernel(&self.ctx);
    let (grid, block) = dispatch_grid_1d(*outer, 256);
    let cfg = LaunchConfig {
        grid_dim: (grid, 1, 1),
        block_dim: (block, 1, 1),
        shared_mem_bytes: 0,
    };
    let mut launcher = stream.launch_builder(&kernel.function);
    // Arena buffer first, then scalars in kernel-signature order.
    launcher
        .arg(self.arena.f32_buf_mut())
        .arg(outer)
        .arg(inner)
        .arg(in_off)
        .arg(out_off);
    // SAFETY: relies on the .arg list above matching the kernel's parameter
    // types and order.
    unsafe {
        launcher
            .launch(cfg)
            .expect("rlx-cuda: argmax launch failed");
    }
}
// Launches `transpose_kernel` over `out_total` outputs. The last argument is
// the per-step device buffer `self.meta_buffers[meta_idx]` (presumably holding
// shape/permutation metadata for `rank` dims — TODO confirm).
Step::Transpose {
    rank,
    out_total,
    in_off,
    out_off,
    meta_idx,
} => {
    let kernel = transpose_kernel(&self.ctx);
    let (grid, block) = dispatch_grid_1d(*out_total, 256);
    let cfg = LaunchConfig {
        grid_dim: (grid, 1, 1),
        block_dim: (block, 1, 1),
        shared_mem_bytes: 0,
    };
    let mut launcher = stream.launch_builder(&kernel.function);
    // Arena buffer first, then scalars, then the metadata buffer.
    launcher
        .arg(self.arena.f32_buf_mut())
        .arg(rank)
        .arg(out_total)
        .arg(in_off)
        .arg(out_off)
        .arg(&self.meta_buffers[*meta_idx]);
    // SAFETY: relies on the .arg list above matching the kernel's parameter
    // types and order.
    unsafe {
        launcher
            .launch(cfg)
            .expect("rlx-cuda: transpose launch failed");
    }
}
// Launches `expand_kernel` over `out_total` outputs. Uses the same argument
// layout as the Transpose arm, including the per-step metadata buffer
// `self.meta_buffers[meta_idx]`.
Step::Expand {
    rank,
    out_total,
    in_off,
    out_off,
    meta_idx,
} => {
    let kernel = expand_kernel(&self.ctx);
    let (grid, block) = dispatch_grid_1d(*out_total, 256);
    let cfg = LaunchConfig {
        grid_dim: (grid, 1, 1),
        block_dim: (block, 1, 1),
        shared_mem_bytes: 0,
    };
    let mut launcher = stream.launch_builder(&kernel.function);
    // Arena buffer first, then scalars, then the metadata buffer.
    launcher
        .arg(self.arena.f32_buf_mut())
        .arg(rank)
        .arg(out_total)
        .arg(in_off)
        .arg(out_off)
        .arg(&self.meta_buffers[*meta_idx]);
    // SAFETY: relies on the .arg list above matching the kernel's parameter
    // types and order.
    unsafe {
        launcher
            .launch(cfg)
            .expect("rlx-cuda: expand launch failed");
    }
}
// Launches `concat_kernel` on a 1-D grid sized by `dispatch_grid_1d(total, 256)`.
// Takes the same scalar parameter list as the Narrow arm; `start` is passed
// through unchanged.
Step::Concat {
    total,
    outer,
    inner,
    axis_in_size,
    axis_out_size,
    start,
    in_off,
    out_off,
} => {
    let kernel = concat_kernel(&self.ctx);
    let (grid, block) = dispatch_grid_1d(*total, 256);
    let cfg = LaunchConfig {
        grid_dim: (grid, 1, 1),
        block_dim: (block, 1, 1),
        shared_mem_bytes: 0,
    };
    let mut launcher = stream.launch_builder(&kernel.function);
    // Arena buffer first, then scalars in kernel-signature order.
    launcher
        .arg(self.arena.f32_buf_mut())
        .arg(total)
        .arg(outer)
        .arg(inner)
        .arg(axis_in_size)
        .arg(axis_out_size)
        .arg(start)
        .arg(in_off)
        .arg(out_off);
    // SAFETY: relies on the .arg list above matching the kernel's parameter
    // types and order.
    unsafe {
        launcher
            .launch(cfg)
            .expect("rlx-cuda: concat launch failed");
    }
}
// Forward attention. `rlx_ir::attention_dispatch_use_row` decides between the
// row kernel and the tiled kernel from `head_dim`; the string argument is
// presumably an environment-variable override name — TODO confirm. Both
// kernels take the identical argument list below, so only the function and
// the launch geometry differ.
Step::Attention {
    batch,
    heads,
    seq_q,
    seq_k,
    head_dim,
    q_off,
    k_off,
    v_off,
    out_off,
    mask_off,
    mask_kind,
    scale_bits,
    window,
    seq_q_stride,
    seq_k_stride,
    mask_batch_stride,
    mask_head_stride,
    q_batch_stride,
    q_head_stride,
    q_seq_stride,
    k_batch_stride,
    k_head_stride,
    k_seq_stride,
    v_batch_stride,
    v_head_stride,
    v_seq_stride,
    o_batch_stride,
    o_head_stride,
    o_seq_stride,
} => {
    let use_row = rlx_ir::attention_dispatch_use_row(
        *head_dim,
        "RLX_CUDA_FORCE_ATTENTION_ROW",
    );
    let mut launcher = stream.launch_builder(if use_row {
        &attention_row_kernel(&self.ctx).function
    } else {
        &attention_kernel(&self.ctx).function
    });
    // Positional arguments shared by both kernels; order is the kernel ABI.
    launcher
        .arg(self.arena.f32_buf_mut())
        .arg(batch)
        .arg(heads)
        .arg(seq_q)
        .arg(seq_k)
        .arg(head_dim)
        .arg(q_off)
        .arg(k_off)
        .arg(v_off)
        .arg(out_off)
        .arg(mask_off)
        .arg(mask_kind)
        .arg(scale_bits)
        .arg(window)
        .arg(seq_q_stride)
        .arg(seq_k_stride)
        .arg(mask_batch_stride)
        .arg(mask_head_stride)
        .arg(q_batch_stride)
        .arg(q_head_stride)
        .arg(q_seq_stride)
        .arg(k_batch_stride)
        .arg(k_head_stride)
        .arg(k_seq_stride)
        .arg(v_batch_stride)
        .arg(v_head_stride)
        .arg(v_seq_stride)
        .arg(o_batch_stride)
        .arg(o_head_stride)
        .arg(o_seq_stride);
    let cfg = if use_row {
        // Row kernel: flat 1-D grid of 256-thread blocks covering
        // batch * heads * seq_q threads.
        let total = batch * heads * seq_q;
        let block = 256u32;
        LaunchConfig {
            grid_dim: (total.div_ceil(block), 1, 1),
            block_dim: (block, 1, 1),
            shared_mem_bytes: 0,
        }
    } else {
        // Tiled kernel: x = blocks of 16 query positions, y = batch * heads,
        // 128 threads per block.
        let q_blocks = (*seq_q).div_ceil(16);
        LaunchConfig {
            grid_dim: (q_blocks, batch * heads, 1),
            block_dim: (128, 1, 1),
            shared_mem_bytes: 0,
        }
    };
    // SAFETY: relies on the .arg list above matching the selected kernel's
    // parameter types and order.
    unsafe {
        launcher
            .launch(cfg)
            .expect("rlx-cuda: attention launch failed");
    }
}
// Attention backward for one operand selected by `wrt`. The grid is
// (batch * heads, ceil(seq_axis / 256), 1) with 256-thread blocks, where
// `seq_axis` is `seq_q` when `wrt == 0` and `seq_k` otherwise.
Step::AttentionBackward {
    batch,
    heads,
    seq_q,
    seq_k,
    head_dim,
    q_off,
    k_off,
    v_off,
    dy_off,
    out_off,
    mask_off,
    mask_kind,
    scale_bits,
    window,
    wrt,
} => {
    let kernel = attention_bwd_kernel(&self.ctx);
    let seq_axis = if *wrt == 0 { *seq_q } else { *seq_k };
    let y_blocks = seq_axis.div_ceil(256);
    let cfg = LaunchConfig {
        grid_dim: (batch * heads, y_blocks, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 0,
    };
    let mut launcher = stream.launch_builder(&kernel.function);
    // Arena buffer first, then scalars in kernel-signature order.
    launcher
        .arg(self.arena.f32_buf_mut())
        .arg(batch)
        .arg(heads)
        .arg(seq_q)
        .arg(seq_k)
        .arg(head_dim)
        .arg(q_off)
        .arg(k_off)
        .arg(v_off)
        .arg(dy_off)
        .arg(out_off)
        .arg(mask_off)
        .arg(mask_kind)
        .arg(scale_bits)
        .arg(window)
        .arg(wrt);
    // SAFETY: relies on the .arg list above matching the kernel's parameter
    // types and order.
    unsafe {
        launcher
            .launch(cfg)
            .expect("rlx-cuda: attention_bwd launch failed");
    }
}
// Launches `rope_kernel` on a 1-D grid sized by `dispatch_grid_1d(n_total, 256)`.
Step::Rope {
    n_total,
    seq,
    head_dim,
    half,
    in_off,
    cos_off,
    sin_off,
    out_off,
    last_dim,
} => {
    let kernel = rope_kernel(&self.ctx);
    let (grid, block) = dispatch_grid_1d(*n_total, 256);
    let cfg = LaunchConfig {
        grid_dim: (grid, 1, 1),
        block_dim: (block, 1, 1),
        shared_mem_bytes: 0,
    };
    let mut launcher = stream.launch_builder(&kernel.function);
    // Arena buffer first, then scalars in kernel-signature order.
    launcher
        .arg(self.arena.f32_buf_mut())
        .arg(n_total)
        .arg(seq)
        .arg(head_dim)
        .arg(half)
        .arg(in_off)
        .arg(cos_off)
        .arg(sin_off)
        .arg(out_off)
        .arg(last_dim);
    // SAFETY: relies on the .arg list above matching the kernel's parameter
    // types and order.
    unsafe {
        launcher.launch(cfg).expect("rlx-cuda: rope launch failed");
    }
}
// Launches `cumsum_kernel` with the grid sized over the *scaled* outer extent
// (`scale` is defined outside this view — TODO confirm its semantics). A zero
// extent skips the launch.
Step::Cumsum {
    outer,
    inner,
    in_off,
    out_off,
    exclusive,
} => {
    let outer_s = scale(*outer);
    if outer_s == 0 {
        continue;
    }
    let kernel = cumsum_kernel(&self.ctx);
    let (grid, block) = dispatch_grid_1d(outer_s, 256);
    let cfg = LaunchConfig {
        grid_dim: (grid, 1, 1),
        block_dim: (block, 1, 1),
        shared_mem_bytes: 0,
    };
    let mut launcher = stream.launch_builder(&kernel.function);
    // The kernel receives the scaled extent, matching the grid above.
    launcher
        .arg(self.arena.f32_buf_mut())
        .arg(&outer_s)
        .arg(inner)
        .arg(in_off)
        .arg(out_off)
        .arg(exclusive);
    // SAFETY: relies on the .arg list above matching the kernel's parameter
    // types and order.
    unsafe {
        launcher
            .launch(cfg)
            .expect("rlx-cuda: cumsum launch failed");
    }
}
// Launches `topk_kernel` on a 1-D grid sized by `dispatch_grid_1d(outer, 256)`.
Step::TopK {
    outer,
    inner,
    k,
    in_off,
    out_off,
} => {
    let kernel = topk_kernel(&self.ctx);
    let (grid, block) = dispatch_grid_1d(*outer, 256);
    let cfg = LaunchConfig {
        grid_dim: (grid, 1, 1),
        block_dim: (block, 1, 1),
        shared_mem_bytes: 0,
    };
    let mut launcher = stream.launch_builder(&kernel.function);
    // Arena buffer first, then scalars in kernel-signature order.
    launcher
        .arg(self.arena.f32_buf_mut())
        .arg(outer)
        .arg(inner)
        .arg(k)
        .arg(in_off)
        .arg(out_off);
    // SAFETY: relies on the .arg list above matching the kernel's parameter
    // types and order.
    unsafe {
        launcher.launch(cfg).expect("rlx-cuda: topk launch failed");
    }
}
6548 Step::GroupedMatmul {
6549 m,
6550 k,
6551 n,
6552 num_experts,
6553 in_off,
6554 w_off,
6555 idx_off,
6556 out_off,
6557 } => {
6558 let used_sorted = if let Some(blas) = self.blas.as_ref() {
6567 stream
6569 .synchronize()
6570 .expect("rlx-cuda: stream sync before idx download");
6571 let idx_host = {
6572 let idx_slot = self
6573 .arena
6574 .f32_buf()
6575 .slice(*idx_off as usize..(idx_off + m) as usize);
6576 stream.clone_dtoh(&idx_slot).ok()
6577 };
6578 match idx_host {
6579 Some(idx_vec) => {
6580 let mut runs: Vec<(u32, u32, u32)> = Vec::new();
6581 let mut i = 0usize;
6582 let mn = *m as usize;
6583 while i < mn {
6584 let e = idx_vec[i] as u32;
6585 let mut j = i + 1;
6586 while j < mn && (idx_vec[j] as u32) == e {
6587 j += 1;
6588 }
6589 if e < *num_experts {
6590 runs.push((i as u32, j as u32, e));
6591 }
6592 i = j;
6593 }
6594 let threshold = (mn / 4).max(2);
6597 if !runs.is_empty() && runs.len() <= threshold {
6598 let blas = blas.lock().unwrap();
6599 let (arena_ptr, _record) =
6600 self.arena.f32_buf_mut().device_ptr_mut(&stream);
6601 let alpha: f32 = 1.0;
6602 let beta: f32 = 0.0;
6603 let mut all_ok = true;
6604 for (lo, hi, e) in &runs {
6605 let rows = hi - lo;
6606 let a_dev = arena_ptr + ((*in_off + lo * k) as u64) * 4;
6607 let b_dev = arena_ptr + ((*w_off + e * k * n) as u64) * 4;
6608 let c_dev = arena_ptr + ((*out_off + lo * n) as u64) * 4;
6609 let r = unsafe {
6610 cudarc::cublas::result::sgemm(
6611 *blas.handle(),
6612 cublas_sys::cublasOperation_t::CUBLAS_OP_N,
6613 cublas_sys::cublasOperation_t::CUBLAS_OP_N,
6614 *n as i32,
6615 rows as i32,
6616 *k as i32,
6617 &alpha as *const f32,
6618 b_dev as *const f32,
6619 *n as i32,
6620 a_dev as *const f32,
6621 *k as i32,
6622 &beta as *const f32,
6623 c_dev as *mut f32,
6624 *n as i32,
6625 )
6626 };
6627 if r.is_err() {
6628 all_ok = false;
6629 break;
6630 }
6631 }
6632 all_ok
6633 } else {
6634 false
6635 }
6636 }
6637 None => false,
6638 }
6639 } else {
6640 false
6641 };
6642 if used_sorted {
6643 continue;
6644 }
6645
6646 let kernel = grouped_matmul_kernel(&self.ctx);
6648 let cfg = LaunchConfig {
6649 grid_dim: ((*n).div_ceil(8), (*m).div_ceil(8), 1),
6650 block_dim: (8, 8, 1),
6651 shared_mem_bytes: 0,
6652 };
6653 let mut launcher = stream.launch_builder(&kernel.function);
6654 launcher
6655 .arg(self.arena.f32_buf_mut())
6656 .arg(m)
6657 .arg(k)
6658 .arg(n)
6659 .arg(num_experts)
6660 .arg(in_off)
6661 .arg(w_off)
6662 .arg(idx_off)
6663 .arg(out_off);
6664 unsafe {
6665 launcher
6666 .launch(cfg)
6667 .expect("rlx-cuda: grouped_matmul launch failed");
6668 }
6669 }
// Launches `scatter_add_zero_kernel` over `out_total` elements at `out_off`.
// This appears to be the initialization step paired with ScatterAddAcc —
// TODO confirm ordering in the planner.
Step::ScatterAddZero { out_off, out_total } => {
    let kernel = scatter_add_zero_kernel(&self.ctx);
    let (grid, block) = dispatch_grid_1d(*out_total, 256);
    let cfg = LaunchConfig {
        grid_dim: (grid, 1, 1),
        block_dim: (block, 1, 1),
        shared_mem_bytes: 0,
    };
    let mut launcher = stream.launch_builder(&kernel.function);
    launcher
        .arg(self.arena.f32_buf_mut())
        .arg(out_off)
        .arg(out_total);
    // SAFETY: relies on the .arg list above matching the kernel's parameter
    // types and order.
    unsafe {
        launcher
            .launch(cfg)
            .expect("rlx-cuda: scatter_add_zero launch failed");
    }
}
// Launches `scatter_add_acc_kernel` on a 1-D grid covering
// num_updates * trailing elements.
Step::ScatterAddAcc {
    out_off,
    upd_off,
    idx_off,
    num_updates,
    trailing,
    out_dim,
} => {
    let kernel = scatter_add_acc_kernel(&self.ctx);
    let total = num_updates * trailing;
    let (grid, block) = dispatch_grid_1d(total, 256);
    let cfg = LaunchConfig {
        grid_dim: (grid, 1, 1),
        block_dim: (block, 1, 1),
        shared_mem_bytes: 0,
    };
    let mut launcher = stream.launch_builder(&kernel.function);
    // Arena buffer first, then scalars in kernel-signature order.
    launcher
        .arg(self.arena.f32_buf_mut())
        .arg(out_off)
        .arg(upd_off)
        .arg(idx_off)
        .arg(num_updates)
        .arg(trailing)
        .arg(out_dim);
    // SAFETY: relies on the .arg list above matching the kernel's parameter
    // types and order.
    unsafe {
        launcher
            .launch(cfg)
            .expect("rlx-cuda: scatter_add_acc launch failed");
    }
}
// Launches `dequant_matmul_kernel` with 8x8 thread blocks tiling an
// (n, m) grid: x covers ceil(n / 8), y covers ceil(m / 8).
Step::DequantMatmul {
    m,
    k,
    n,
    block_size,
    scheme_id,
    x_off,
    w_off,
    scale_off,
    zp_off,
    out_off,
} => {
    let kernel = dequant_matmul_kernel(&self.ctx);
    let cfg = LaunchConfig {
        grid_dim: ((*n).div_ceil(8), (*m).div_ceil(8), 1),
        block_dim: (8, 8, 1),
        shared_mem_bytes: 0,
    };
    let mut launcher = stream.launch_builder(&kernel.function);
    // Arena buffer first, then scalars in kernel-signature order.
    launcher
        .arg(self.arena.f32_buf_mut())
        .arg(m)
        .arg(k)
        .arg(n)
        .arg(block_size)
        .arg(scheme_id)
        .arg(x_off)
        .arg(w_off)
        .arg(scale_off)
        .arg(zp_off)
        .arg(out_off);
    // SAFETY: relies on the .arg list above matching the kernel's parameter
    // types and order.
    unsafe {
        launcher
            .launch(cfg)
            .expect("rlx-cuda: dequant_matmul launch failed");
    }
}
6757 Step::DequantMatmulGguf {
6758 m,
6759 k,
6760 n,
6761 scheme_id,
6762 x_byte_off,
6763 w_byte_off,
6764 out_byte_off,
6765 } => {
6766 let use_gpu = self.dequant_scratch_off > 0 && self.blas.is_some();
6767 if use_gpu {
6768 let blas = self.blas.as_ref().unwrap();
6769 crate::gguf_gpu::run_dequant_matmul_gguf_gpu(
6770 &self.ctx,
6771 &stream,
6772 self.arena.f32_buf_mut(),
6773 blas,
6774 *m as usize,
6775 *k as usize,
6776 *n as usize,
6777 *scheme_id,
6778 *x_byte_off as usize,
6779 *w_byte_off as usize,
6780 self.dequant_scratch_off,
6781 *out_byte_off as usize,
6782 );
6783 } else {
6784 crate::gguf_host::run_dequant_matmul_gguf(
6785 &stream,
6786 self.arena.f32_buf_mut(),
6787 *m as usize,
6788 *k as usize,
6789 *n as usize,
6790 *scheme_id,
6791 *x_byte_off as usize,
6792 *w_byte_off as usize,
6793 *out_byte_off as usize,
6794 );
6795 }
6796 }
6797 Step::DequantGroupedMatmulGguf {
6798 m,
6799 k,
6800 n,
6801 num_experts,
6802 scheme_id,
6803 x_byte_off,
6804 w_byte_off,
6805 idx_byte_off,
6806 out_byte_off,
6807 } => {
6808 let use_gpu = self.dequant_scratch_off > 0 && self.blas.is_some();
6809 if use_gpu {
6810 let blas = self.blas.as_ref().unwrap();
6811 crate::gguf_gpu::run_dequant_grouped_matmul_gguf_gpu(
6812 &self.ctx,
6813 &stream,
6814 self.arena.f32_buf_mut(),
6815 blas,
6816 *m as usize,
6817 *k as usize,
6818 *n as usize,
6819 *num_experts as usize,
6820 *scheme_id,
6821 *x_byte_off as usize,
6822 *w_byte_off as usize,
6823 *idx_byte_off as usize,
6824 self.dequant_scratch_off,
6825 *out_byte_off as usize,
6826 );
6827 } else {
6828 crate::gguf_host::run_dequant_grouped_matmul_gguf(
6829 &stream,
6830 self.arena.f32_buf_mut(),
6831 *m as usize,
6832 *k as usize,
6833 *n as usize,
6834 *num_experts as usize,
6835 *scheme_id,
6836 *x_byte_off as usize,
6837 *w_byte_off as usize,
6838 *idx_byte_off as usize,
6839 *out_byte_off as usize,
6840 );
6841 }
6842 }
// Launches `sample_kernel` on a 1-D grid sized by `dispatch_grid_1d(outer, 256)`.
// Float parameters (top_p, temperature) travel as raw bits, and the seed is
// split into low/high words.
Step::Sample {
    outer,
    inner,
    in_off,
    out_off,
    top_k,
    top_p_bits,
    temp_bits,
    seed_lo,
    seed_hi,
} => {
    let kernel = sample_kernel(&self.ctx);
    let (grid, block) = dispatch_grid_1d(*outer, 256);
    let cfg = LaunchConfig {
        grid_dim: (grid, 1, 1),
        block_dim: (block, 1, 1),
        shared_mem_bytes: 0,
    };
    let mut launcher = stream.launch_builder(&kernel.function);
    // Arena buffer first, then scalars in kernel-signature order.
    launcher
        .arg(self.arena.f32_buf_mut())
        .arg(outer)
        .arg(inner)
        .arg(in_off)
        .arg(out_off)
        .arg(top_k)
        .arg(top_p_bits)
        .arg(temp_bits)
        .arg(seed_lo)
        .arg(seed_hi);
    // SAFETY: relies on the .arg list above matching the kernel's parameter
    // types and order.
    unsafe {
        launcher
            .launch(cfg)
            .expect("rlx-cuda: sample launch failed");
    }
}
// Launches `selective_scan_kernel` on a 1-D grid covering batch * hidden
// threads. `seq` and `state_size` are passed to the kernel but do not size
// the grid.
Step::SelectiveScan {
    batch,
    seq,
    hidden,
    state_size,
    x_off,
    delta_off,
    a_off,
    b_off,
    c_off,
    out_off,
} => {
    let kernel = selective_scan_kernel(&self.ctx);
    let total = batch * hidden;
    let (grid, block) = dispatch_grid_1d(total, 256);
    let cfg = LaunchConfig {
        grid_dim: (grid, 1, 1),
        block_dim: (block, 1, 1),
        shared_mem_bytes: 0,
    };
    let mut launcher = stream.launch_builder(&kernel.function);
    // Arena buffer first, then scalars in kernel-signature order.
    launcher
        .arg(self.arena.f32_buf_mut())
        .arg(batch)
        .arg(seq)
        .arg(hidden)
        .arg(state_size)
        .arg(x_off)
        .arg(delta_off)
        .arg(a_off)
        .arg(b_off)
        .arg(c_off)
        .arg(out_off);
    // SAFETY: relies on the .arg list above matching the kernel's parameter
    // types and order.
    unsafe {
        launcher
            .launch(cfg)
            .expect("rlx-cuda: selective_scan launch failed");
    }
}
6918 Step::Fft {
6919 src_byte_off,
6920 dst_byte_off,
6921 outer,
6922 n_complex,
6923 inverse,
6924 norm_tag,
6925 dtype_tag,
6926 use_gpu,
6927 } => {
6928 if *use_gpu {
6929 let norm = rlx_ir::fft::FftNorm::from_tag(*norm_tag);
6930 let scale = norm.output_scale(*n_complex as usize, *inverse) as f32;
6931 crate::fft_dispatch::run_fft_gpu(
6932 &self.ctx,
6933 &stream,
6934 self.arena.f32_buf_mut(),
6935 *src_byte_off / 4,
6936 *dst_byte_off / 4,
6937 *outer,
6938 *n_complex,
6939 *inverse,
6940 scale,
6941 );
6942 } else {
6943 let (buf, arena_size) = self.arena.f32_buf_and_size();
6944 crate::fft_host::run_fft1d(
6945 &stream,
6946 buf,
6947 arena_size,
6948 *src_byte_off as usize,
6949 *dst_byte_off as usize,
6950 *outer as usize,
6951 *n_complex as usize,
6952 *inverse,
6953 *norm_tag,
6954 fft_dtype_from_tag(*dtype_tag),
6955 );
6956 }
6957 }
// Delegates to `crate::welch_peaks_dispatch::run_welch_peaks_gpu`, passing
// the step's offsets and sizes through unchanged.
Step::WelchPeaksGpu {
    spec_off,
    dst_off,
    welch_batch,
    n_fft,
    n_segments,
    k,
    n_bins,
} => {
    crate::welch_peaks_dispatch::run_welch_peaks_gpu(
        &self.ctx,
        &stream,
        self.arena.f32_buf_mut(),
        *spec_off,
        *dst_off,
        *welch_batch,
        *n_fft,
        *n_segments,
        *k,
        *n_bins,
    );
}
// Intentionally no work in this loop for these host steps. Presumably they
// are executed elsewhere (e.g. a separate host pass) — TODO confirm where.
Step::LogMelHost { .. }
| Step::LogMelBackwardHost { .. }
| Step::WelchPeaksHost { .. } => {}
// im2col. With `use_gpu` it launches `im2col_kernel` over
// (n * h_out * w_out) * (c_in * kh * kw) elements, converting byte offsets to
// f32-element offsets. Otherwise it calls `crate::im2col_host::run_im2col`
// with the byte offsets directly.
Step::Im2ColHost {
    x_byte_off,
    col_byte_off,
    n,
    c_in,
    h,
    w,
    h_out,
    w_out,
    kh,
    kw,
    sh,
    sw,
    ph,
    pw,
    dh,
    dw_dil,
    use_gpu,
} => {
    if *use_gpu {
        let kernel = im2col_kernel(&self.ctx);
        // Column matrix size: rows = output positions, cols = patch size.
        let m = *n * *h_out * *w_out;
        let k = *c_in * *kh * *kw;
        let total = m * k;
        let (grid, block) = dispatch_grid_1d(total, 256);
        let cfg = LaunchConfig {
            grid_dim: (grid, 1, 1),
            block_dim: (block, 1, 1),
            shared_mem_bytes: 0,
        };
        // Byte offsets -> f32-element offsets for the kernel.
        let x_off = *x_byte_off / 4;
        let col_off = *col_byte_off / 4;
        let mut launcher = stream.launch_builder(&kernel.function);
        launcher
            .arg(self.arena.f32_buf_mut())
            .arg(n)
            .arg(c_in)
            .arg(h)
            .arg(w)
            .arg(h_out)
            .arg(w_out)
            .arg(kh)
            .arg(kw)
            .arg(sh)
            .arg(sw)
            .arg(ph)
            .arg(pw)
            .arg(dh)
            .arg(dw_dil)
            .arg(&x_off)
            .arg(&col_off);
        // SAFETY: relies on the .arg list above matching the kernel's
        // parameter types and order.
        unsafe {
            launcher
                .launch(cfg)
                .expect("rlx-cuda: im2col launch failed");
        }
    } else {
        crate::im2col_host::run_im2col(
            &stream,
            self.arena.f32_buf_mut(),
            *x_byte_off as usize,
            *col_byte_off as usize,
            *n,
            *c_in,
            *h,
            *w,
            *h_out,
            *w_out,
            *kh,
            *kw,
            *sh,
            *sw,
            *ph,
            *pw,
            *dh,
            *dw_dil,
        );
    }
}
// Delegates to `crate::gdn_host::run_gated_delta_net` with the arena buffer,
// its size, and byte offsets widened to usize.
Step::GatedDeltaNet {
    q_byte_off,
    k_byte_off,
    v_byte_off,
    g_byte_off,
    beta_byte_off,
    state_byte_off,
    dst_byte_off,
    batch,
    seq,
    heads,
    state_size,
    use_carry,
} => {
    let (buf, arena_size) = self.arena.f32_buf_and_size();
    crate::gdn_host::run_gated_delta_net(
        &stream,
        buf,
        arena_size,
        *q_byte_off as usize,
        *k_byte_off as usize,
        *v_byte_off as usize,
        *g_byte_off as usize,
        *beta_byte_off as usize,
        *state_byte_off as usize,
        *dst_byte_off as usize,
        *batch as usize,
        *seq as usize,
        *heads as usize,
        *state_size as usize,
        *use_carry,
    );
}
// Delegates to `crate::llada2_gate_host::run_llada2_group_limited_gate`;
// `attrs` is passed by reference unchanged.
Step::Llada2GroupLimitedGate {
    sig_off,
    route_off,
    out_off,
    n_elems,
    attrs,
} => {
    let (buf, arena_size) = self.arena.f32_buf_and_size();
    crate::llada2_gate_host::run_llada2_group_limited_gate(
        &stream,
        buf,
        arena_size,
        *sig_off as usize,
        *route_off as usize,
        *out_off as usize,
        *n_elems as usize,
        attrs,
    );
}
// Delegates to `crate::umap_knn_host::run_umap_knn` with the arena buffer and
// its size.
Step::UmapKnn {
    pairwise_off,
    out_off,
    n,
    k,
} => {
    let (buf, arena_size) = self.arena.f32_buf_and_size();
    crate::umap_knn_host::run_umap_knn(
        &stream,
        buf,
        arena_size,
        *pairwise_off as usize,
        *out_off as usize,
        *n as usize,
        *k as usize,
    );
}
// Launches `layer_norm2d_kernel` on a 1-D grid covering n * h * w threads.
// The channel count `c` is passed to the kernel but not part of the grid.
Step::LayerNorm2d {
    src_off,
    g_off,
    b_off,
    dst_off,
    n,
    c,
    h,
    w,
    eps_bits,
} => {
    let kernel = layer_norm2d_kernel(&self.ctx);
    let total = n * h * w;
    let (grid, block) = dispatch_grid_1d(total, 256);
    let cfg = LaunchConfig {
        grid_dim: (grid, 1, 1),
        block_dim: (block, 1, 1),
        shared_mem_bytes: 0,
    };
    let mut launcher = stream.launch_builder(&kernel.function);
    // Arena buffer first, then scalars in kernel-signature order.
    launcher
        .arg(self.arena.f32_buf_mut())
        .arg(src_off)
        .arg(g_off)
        .arg(b_off)
        .arg(dst_off)
        .arg(n)
        .arg(c)
        .arg(h)
        .arg(w)
        .arg(eps_bits);
    // SAFETY: relies on the .arg list above matching the kernel's parameter
    // types and order.
    unsafe {
        launcher
            .launch(cfg)
            .expect("rlx-cuda: layer_norm2d launch failed");
    }
}
// Launches `conv_transpose2d_kernel` on a 1-D grid covering
// n * c_out * h_out * w_out threads.
Step::ConvTranspose2d {
    src_off,
    w_off,
    dst_off,
    n,
    c_in,
    h,
    w_in,
    c_out,
    h_out,
    w_out,
    kh,
    kw,
    sh,
    sw,
    ph,
    pw,
    dh,
    dw,
    groups,
} => {
    let kernel = conv_transpose2d_kernel(&self.ctx);
    let total = n * c_out * h_out * w_out;
    let (grid, block) = dispatch_grid_1d(total, 256);
    let cfg = LaunchConfig {
        grid_dim: (grid, 1, 1),
        block_dim: (block, 1, 1),
        shared_mem_bytes: 0,
    };
    let mut launcher = stream.launch_builder(&kernel.function);
    // Arena buffer first, then scalars in kernel-signature order.
    launcher
        .arg(self.arena.f32_buf_mut())
        .arg(src_off)
        .arg(w_off)
        .arg(dst_off)
        .arg(n)
        .arg(c_in)
        .arg(h)
        .arg(w_in)
        .arg(c_out)
        .arg(h_out)
        .arg(w_out)
        .arg(kh)
        .arg(kw)
        .arg(sh)
        .arg(sw)
        .arg(ph)
        .arg(pw)
        .arg(dh)
        .arg(dw)
        .arg(groups);
    // SAFETY: relies on the .arg list above matching the kernel's parameter
    // types and order.
    unsafe {
        launcher
            .launch(cfg)
            .expect("rlx-cuda: conv_transpose2d launch failed");
    }
}
// Launches `group_norm_kernel` with one 256-thread block per
// (n, group) pair: grid x = n * num_groups.
Step::GroupNorm {
    src_off,
    g_off,
    b_off,
    dst_off,
    n,
    c,
    h,
    w,
    num_groups,
    eps_bits,
} => {
    let kernel = group_norm_kernel(&self.ctx);
    let grid = n * num_groups;
    let cfg = LaunchConfig {
        grid_dim: (grid, 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 0,
    };
    let mut launcher = stream.launch_builder(&kernel.function);
    // Arena buffer first, then scalars in kernel-signature order.
    launcher
        .arg(self.arena.f32_buf_mut())
        .arg(src_off)
        .arg(g_off)
        .arg(b_off)
        .arg(dst_off)
        .arg(n)
        .arg(c)
        .arg(h)
        .arg(w)
        .arg(num_groups)
        .arg(eps_bits);
    // SAFETY: relies on the .arg list above matching the kernel's parameter
    // types and order.
    unsafe {
        launcher
            .launch(cfg)
            .expect("rlx-cuda: group_norm launch failed");
    }
}
// Launches `resize_nearest_2x_kernel` on a 1-D grid covering
// n * c * (2h) * (2w) output elements; the kernel receives the input dims.
Step::ResizeNearest2x {
    src_off,
    dst_off,
    n,
    c,
    h,
    w,
} => {
    let kernel = resize_nearest_2x_kernel(&self.ctx);
    let total = n * c * h * 2 * w * 2;
    let (grid, block) = dispatch_grid_1d(total, 256);
    let cfg = LaunchConfig {
        grid_dim: (grid, 1, 1),
        block_dim: (block, 1, 1),
        shared_mem_bytes: 0,
    };
    let mut launcher = stream.launch_builder(&kernel.function);
    // Arena buffer first, then scalars in kernel-signature order.
    launcher
        .arg(self.arena.f32_buf_mut())
        .arg(src_off)
        .arg(dst_off)
        .arg(n)
        .arg(c)
        .arg(h)
        .arg(w);
    // SAFETY: relies on the .arg list above matching the kernel's parameter
    // types and order.
    unsafe {
        launcher
            .launch(cfg)
            .expect("rlx-cuda: resize_nearest_2x launch failed");
    }
}
7294 Step::GaussianSplatRender {
7295 positions_off,
7296 positions_len,
7297 scales_off,
7298 scales_len,
7299 rotations_off,
7300 rotations_len,
7301 opacities_off,
7302 opacities_len,
7303 colors_off,
7304 colors_len,
7305 sh_coeffs_off,
7306 sh_coeffs_len,
7307 meta_off,
7308 dst_off,
7309 dst_len,
7310 width,
7311 height,
7312 tile_size,
7313 radius_scale,
7314 alpha_cutoff,
7315 max_splat_steps,
7316 transmittance_threshold,
7317 max_list_entries,
7318 } => {
7319 let (buf, arena_size) = self.arena.f32_buf_and_size();
7320 #[cfg(feature = "native-splat")]
7321 crate::splat_native::run_gaussian_splat_render_native(
7322 &stream,
7323 buf,
7324 arena_size,
7325 *positions_off as usize,
7326 *positions_len as usize,
7327 *scales_off as usize,
7328 *scales_len as usize,
7329 *rotations_off as usize,
7330 *rotations_len as usize,
7331 *opacities_off as usize,
7332 *opacities_len as usize,
7333 *colors_off as usize,
7334 *colors_len as usize,
7335 *sh_coeffs_off as usize,
7336 *sh_coeffs_len as usize,
7337 *meta_off as usize,
7338 *dst_off as usize,
7339 *width,
7340 *height,
7341 *tile_size,
7342 *radius_scale,
7343 *alpha_cutoff,
7344 *max_splat_steps,
7345 *transmittance_threshold,
7346 *max_list_entries,
7347 );
7348 #[cfg(not(feature = "native-splat"))]
7349 crate::splat_host::run_gaussian_splat_render(
7350 &stream,
7351 buf,
7352 arena_size,
7353 *positions_off as usize,
7354 *positions_len as usize,
7355 *scales_off as usize,
7356 *scales_len as usize,
7357 *rotations_off as usize,
7358 *rotations_len as usize,
7359 *opacities_off as usize,
7360 *opacities_len as usize,
7361 *colors_off as usize,
7362 *colors_len as usize,
7363 *sh_coeffs_off as usize,
7364 *sh_coeffs_len as usize,
7365 *meta_off as usize,
7366 *dst_off as usize,
7367 *dst_len as usize,
7368 *width,
7369 *height,
7370 *tile_size,
7371 *radius_scale,
7372 *alpha_cutoff,
7373 *max_splat_steps,
7374 *transmittance_threshold,
7375 *max_list_entries,
7376 );
7377 }
// Delegates to `crate::splat_host::run_gaussian_splat_prepare`. Offsets and
// lengths are widened to usize; render parameters pass through unchanged.
Step::GaussianSplatPrepare {
    positions_off,
    positions_len,
    scales_off,
    scales_len,
    rotations_off,
    rotations_len,
    opacities_off,
    opacities_len,
    colors_off,
    colors_len,
    sh_coeffs_off,
    sh_coeffs_len,
    meta_off,
    meta_len,
    prep_off,
    prep_len,
    width,
    height,
    tile_size,
    radius_scale,
    alpha_cutoff,
    max_splat_steps,
    transmittance_threshold,
    max_list_entries,
} => {
    let (buf, arena_size) = self.arena.f32_buf_and_size();
    crate::splat_host::run_gaussian_splat_prepare(
        &stream,
        buf,
        arena_size,
        *positions_off as usize,
        *positions_len as usize,
        *scales_off as usize,
        *scales_len as usize,
        *rotations_off as usize,
        *rotations_len as usize,
        *opacities_off as usize,
        *opacities_len as usize,
        *colors_off as usize,
        *colors_len as usize,
        *sh_coeffs_off as usize,
        *sh_coeffs_len as usize,
        *meta_off as usize,
        *meta_len as usize,
        *prep_off as usize,
        *prep_len as usize,
        *width,
        *height,
        *tile_size,
        *radius_scale,
        *alpha_cutoff,
        *max_splat_steps,
        *transmittance_threshold,
        *max_list_entries,
    );
}
// Delegates to `crate::splat_host::run_gaussian_splat_rasterize`, reading the
// buffer at `prep_off` (presumably produced by a GaussianSplatPrepare step —
// TODO confirm) and writing `dst_len` values at `dst_off`.
Step::GaussianSplatRasterize {
    prep_off,
    prep_len,
    meta_off,
    meta_len,
    dst_off,
    dst_len,
    count,
    width,
    height,
    tile_size,
    alpha_cutoff,
    max_splat_steps,
    transmittance_threshold,
    max_list_entries,
} => {
    let (buf, arena_size) = self.arena.f32_buf_and_size();
    crate::splat_host::run_gaussian_splat_rasterize(
        &stream,
        buf,
        arena_size,
        *prep_off as usize,
        *prep_len as usize,
        *meta_off as usize,
        *meta_len as usize,
        *dst_off as usize,
        *dst_len as usize,
        *count as usize,
        *width,
        *height,
        *tile_size,
        *alpha_cutoff,
        *max_splat_steps,
        *transmittance_threshold,
        *max_list_entries,
    );
}
// Delegates to `crate::splat_host::run_gaussian_splat_render_backward` with
// the incoming loss gradient at `d_loss_off` and the output region at
// `packed_off`. The layout of the packed output is defined by that function.
Step::GaussianSplatRenderBackward {
    positions_off,
    positions_len,
    scales_off,
    scales_len,
    rotations_off,
    rotations_len,
    opacities_off,
    opacities_len,
    colors_off,
    colors_len,
    sh_coeffs_off,
    sh_coeffs_len,
    meta_off,
    d_loss_off,
    d_loss_len,
    packed_off,
    packed_len,
    width,
    height,
    tile_size,
    radius_scale,
    alpha_cutoff,
    max_splat_steps,
    transmittance_threshold,
    max_list_entries,
    loss_grad_clip,
    sh_band,
    max_anisotropy,
} => {
    let (buf, arena_size) = self.arena.f32_buf_and_size();
    crate::splat_host::run_gaussian_splat_render_backward(
        &stream,
        buf,
        arena_size,
        *positions_off as usize,
        *positions_len as usize,
        *scales_off as usize,
        *scales_len as usize,
        *rotations_off as usize,
        *rotations_len as usize,
        *opacities_off as usize,
        *opacities_len as usize,
        *colors_off as usize,
        *colors_len as usize,
        *sh_coeffs_off as usize,
        *sh_coeffs_len as usize,
        *meta_off as usize,
        *d_loss_off as usize,
        *d_loss_len as usize,
        *packed_off as usize,
        *packed_len as usize,
        *width,
        *height,
        *tile_size,
        *radius_scale,
        *alpha_cutoff,
        *max_splat_steps,
        *transmittance_threshold,
        *max_list_entries,
        *loss_grad_clip,
        *sh_band,
        *max_anisotropy,
    );
}
// RMS-norm backward, mode 0 (writes to `dx_byte_off`). Byte offsets become
// f32-element offsets (/ 4) for `launch_rms_norm_bwd`. The last argument
// selects the output: 0 here, 1 for gamma, 2 for beta in the sibling arms.
Step::RmsNormBackwardInput {
    x_byte_off,
    gamma_byte_off,
    beta_byte_off,
    dy_byte_off,
    dx_byte_off,
    rows,
    h,
    eps_bits,
} => {
    launch_rms_norm_bwd(
        &self.ctx,
        &stream,
        self.arena.f32_buf_mut(),
        *rows,
        *h,
        *x_byte_off / 4,
        *gamma_byte_off / 4,
        *beta_byte_off / 4,
        *dy_byte_off / 4,
        *dx_byte_off / 4,
        *eps_bits,
        0,
    );
}
// RMS-norm backward, mode 1 (writes to `dgamma_byte_off`). Same call shape as
// the input-gradient arm, with the output offset and mode changed.
Step::RmsNormBackwardGamma {
    x_byte_off,
    gamma_byte_off,
    beta_byte_off,
    dy_byte_off,
    dgamma_byte_off,
    rows,
    h,
    eps_bits,
} => {
    launch_rms_norm_bwd(
        &self.ctx,
        &stream,
        self.arena.f32_buf_mut(),
        *rows,
        *h,
        *x_byte_off / 4,
        *gamma_byte_off / 4,
        *beta_byte_off / 4,
        *dy_byte_off / 4,
        *dgamma_byte_off / 4,
        *eps_bits,
        1,
    );
}
// RMS-norm backward, mode 2 (writes to `dbeta_byte_off`). Same call shape as
// the input-gradient arm, with the output offset and mode changed.
Step::RmsNormBackwardBeta {
    x_byte_off,
    gamma_byte_off,
    beta_byte_off,
    dy_byte_off,
    dbeta_byte_off,
    rows,
    h,
    eps_bits,
} => {
    launch_rms_norm_bwd(
        &self.ctx,
        &stream,
        self.arena.f32_buf_mut(),
        *rows,
        *h,
        *x_byte_off / 4,
        *gamma_byte_off / 4,
        *beta_byte_off / 4,
        *dy_byte_off / 4,
        *dbeta_byte_off / 4,
        *eps_bits,
        2,
    );
}
// Delegates to `launch_rope_bwd` (defined outside this view), converting byte
// offsets to f32-element offsets (/ 4); `cos_len` is passed through.
Step::RopeBackward {
    dy_byte_off,
    cos_byte_off,
    sin_byte_off,
    dx_byte_off,
    batch,
    seq,
    hidden,
    head_dim,
    n_rot,
    cos_len,
} => {
    launch_rope_bwd(
        &self.ctx,
        &stream,
        self.arena.f32_buf_mut(),
        *batch,
        *seq,
        *hidden,
        *head_dim,
        *n_rot,
        *dy_byte_off / 4,
        *cos_byte_off / 4,
        *sin_byte_off / 4,
        *dx_byte_off / 4,
        *cos_len,
    );
}
7640 Step::CumsumBackward {
7641 dy_byte_off,
7642 dx_byte_off,
7643 rows,
7644 cols,
7645 exclusive,
7646 } => {
7647 launch_cumsum_bwd(
7648 &self.ctx,
7649 &stream,
7650 self.arena.f32_buf_mut(),
7651 *rows,
7652 *cols,
7653 *dy_byte_off / 4,
7654 *dx_byte_off / 4,
7655 if *exclusive { 1 } else { 0 },
7656 );
7657 }
7658 Step::GatherBackward {
7659 dy_byte_off,
7660 indices_byte_off,
7661 dst_byte_off,
7662 outer,
7663 axis_dim,
7664 num_idx,
7665 trailing,
7666 } => {
7667 launch_gather_bwd(
7668 &self.ctx,
7669 &stream,
7670 self.arena.f32_buf_mut(),
7671 *outer,
7672 *axis_dim,
7673 *num_idx,
7674 *trailing,
7675 *dy_byte_off / 4,
7676 *indices_byte_off / 4,
7677 *dst_byte_off / 4,
7678 );
7679 }
7680 Step::MaxPool2dBackward {
7681 x_byte_off,
7682 dy_byte_off,
7683 dx_byte_off,
7684 n,
7685 c,
7686 h,
7687 w,
7688 h_out,
7689 w_out,
7690 kh,
7691 kw,
7692 sh,
7693 sw,
7694 ph,
7695 pw,
7696 } => {
7697 let buf = self.arena.f32_buf_mut();
7698 crate::training_bwd_host::run_maxpool2d_backward(
7699 &stream,
7700 buf,
7701 *x_byte_off as usize / 4,
7702 *dy_byte_off as usize / 4,
7703 *dx_byte_off as usize / 4,
7704 *n,
7705 *c,
7706 *h,
7707 *w,
7708 *h_out,
7709 *w_out,
7710 *kh,
7711 *kw,
7712 *sh,
7713 *sw,
7714 *ph,
7715 *pw,
7716 );
7717 }
7718 Step::Conv2dBackwardInput {
7719 dy_byte_off,
7720 w_byte_off,
7721 dx_byte_off,
7722 n,
7723 c_in,
7724 h,
7725 w_in,
7726 c_out,
7727 h_out,
7728 w_out,
7729 kh,
7730 kw,
7731 sh,
7732 sw,
7733 ph,
7734 pw,
7735 dh,
7736 dw,
7737 groups,
7738 } => {
7739 let buf = self.arena.f32_buf_mut();
7740 crate::training_bwd_host::run_conv2d_backward_input(
7741 &stream,
7742 buf,
7743 *dy_byte_off as usize / 4,
7744 *w_byte_off as usize / 4,
7745 *dx_byte_off as usize / 4,
7746 *n,
7747 *c_in,
7748 *h,
7749 *w_in,
7750 *c_out,
7751 *h_out,
7752 *w_out,
7753 *kh,
7754 *kw,
7755 *sh,
7756 *sw,
7757 *ph,
7758 *pw,
7759 *dh,
7760 *dw,
7761 *groups,
7762 );
7763 }
7764 Step::Conv2dBackwardWeight {
7765 x_byte_off,
7766 dy_byte_off,
7767 dw_byte_off,
7768 n,
7769 c_in,
7770 h,
7771 w,
7772 c_out,
7773 h_out,
7774 w_out,
7775 kh,
7776 kw,
7777 sh,
7778 sw,
7779 ph,
7780 pw,
7781 dh,
7782 dw_dil,
7783 groups,
7784 } => {
7785 let buf = self.arena.f32_buf_mut();
7786 crate::training_bwd_host::run_conv2d_backward_weight(
7787 &stream,
7788 buf,
7789 *x_byte_off as usize / 4,
7790 *dy_byte_off as usize / 4,
7791 *dw_byte_off as usize / 4,
7792 *n,
7793 *c_in,
7794 *h,
7795 *w,
7796 *c_out,
7797 *h_out,
7798 *w_out,
7799 *kh,
7800 *kw,
7801 *sh,
7802 *sw,
7803 *ph,
7804 *pw,
7805 *dh,
7806 *dw_dil,
7807 *groups,
7808 );
7809 }
7810 Step::Pool1d {
7811 n,
7812 c,
7813 l,
7814 l_out,
7815 kl,
7816 sl,
7817 pl,
7818 op,
7819 in_off,
7820 out_off,
7821 } => {
7822 let kernel = pool1d_kernel(&self.ctx);
7823 let total = n * c * l_out;
7824 let (grid, block) = dispatch_grid_1d(total, 256);
7825 let cfg = LaunchConfig {
7826 grid_dim: (grid, 1, 1),
7827 block_dim: (block, 1, 1),
7828 shared_mem_bytes: 0,
7829 };
7830 let mut launcher = stream.launch_builder(&kernel.function);
7831 launcher
7832 .arg(self.arena.f32_buf_mut())
7833 .arg(n)
7834 .arg(c)
7835 .arg(l)
7836 .arg(l_out)
7837 .arg(kl)
7838 .arg(sl)
7839 .arg(pl)
7840 .arg(op)
7841 .arg(in_off)
7842 .arg(out_off);
7843 unsafe {
7844 launcher
7845 .launch(cfg)
7846 .expect("rlx-cuda: pool1d launch failed");
7847 }
7848 }
7849 Step::Pool2d {
7850 n,
7851 c,
7852 h,
7853 w,
7854 h_out,
7855 w_out,
7856 kh,
7857 kw,
7858 sh,
7859 sw,
7860 ph,
7861 pw,
7862 op,
7863 in_off,
7864 out_off,
7865 } => {
7866 let kernel = pool2d_kernel(&self.ctx);
7867 let total = n * c * h_out * w_out;
7868 let (grid, block) = dispatch_grid_1d(total, 256);
7869 let cfg = LaunchConfig {
7870 grid_dim: (grid, 1, 1),
7871 block_dim: (block, 1, 1),
7872 shared_mem_bytes: 0,
7873 };
7874 let mut launcher = stream.launch_builder(&kernel.function);
7875 launcher
7876 .arg(self.arena.f32_buf_mut())
7877 .arg(n)
7878 .arg(c)
7879 .arg(h)
7880 .arg(w)
7881 .arg(h_out)
7882 .arg(w_out)
7883 .arg(kh)
7884 .arg(kw)
7885 .arg(sh)
7886 .arg(sw)
7887 .arg(ph)
7888 .arg(pw)
7889 .arg(op)
7890 .arg(in_off)
7891 .arg(out_off);
7892 unsafe {
7893 launcher
7894 .launch(cfg)
7895 .expect("rlx-cuda: pool2d launch failed");
7896 }
7897 }
7898 Step::Pool3d {
7899 n,
7900 c,
7901 d,
7902 h,
7903 w,
7904 d_out,
7905 h_out,
7906 w_out,
7907 kd,
7908 kh,
7909 kw,
7910 sd,
7911 sh,
7912 sw,
7913 pd,
7914 ph,
7915 pw,
7916 op,
7917 in_off,
7918 out_off,
7919 } => {
7920 let kernel = pool3d_kernel(&self.ctx);
7921 let total = n * c * d_out * h_out * w_out;
7922 let (grid, block) = dispatch_grid_1d(total, 256);
7923 let cfg = LaunchConfig {
7924 grid_dim: (grid, 1, 1),
7925 block_dim: (block, 1, 1),
7926 shared_mem_bytes: 0,
7927 };
7928 let mut launcher = stream.launch_builder(&kernel.function);
7929 launcher
7930 .arg(self.arena.f32_buf_mut())
7931 .arg(n)
7932 .arg(c)
7933 .arg(d)
7934 .arg(h)
7935 .arg(w)
7936 .arg(d_out)
7937 .arg(h_out)
7938 .arg(w_out)
7939 .arg(kd)
7940 .arg(kh)
7941 .arg(kw)
7942 .arg(sd)
7943 .arg(sh)
7944 .arg(sw)
7945 .arg(pd)
7946 .arg(ph)
7947 .arg(pw)
7948 .arg(op)
7949 .arg(in_off)
7950 .arg(out_off);
7951 unsafe {
7952 launcher
7953 .launch(cfg)
7954 .expect("rlx-cuda: pool3d launch failed");
7955 }
7956 }
7957 Step::Conv1d {
7958 n,
7959 c_in,
7960 c_out,
7961 l,
7962 l_out,
7963 kl,
7964 sl,
7965 pl,
7966 dl,
7967 groups,
7968 in_off,
7969 w_off,
7970 out_off,
7971 } => {
7972 let used_cudnn = if let (Some(handle), Some(workspace)) =
7976 (self.dnn, self.dnn_workspace.as_ref())
7977 {
7978 let mut workspace = workspace.lock().unwrap();
7979 let (ws_ptr, _ws_record) = workspace.device_ptr_mut(&stream);
7980 let (arena_ptr, _arena_record) =
7981 self.arena.f32_buf_mut().device_ptr_mut(&stream);
7982 let r = unsafe {
7983 cudnn_conv2d_forward(
7984 handle,
7985 ws_ptr,
7986 CUDNN_WORKSPACE_BYTES,
7987 arena_ptr,
7988 *n,
7989 *c_in,
7990 *c_out,
7991 1,
7992 *l,
7993 1,
7994 *l_out,
7995 1,
7996 *kl,
7997 1,
7998 *sl,
7999 0,
8000 *pl,
8001 1,
8002 *dl,
8003 *groups,
8004 *in_off,
8005 *w_off,
8006 *out_off,
8007 )
8008 };
8009 if let Err(ref e) = r {
8010 log_fallback("conv1d.cudnn", e);
8011 }
8012 r.is_ok()
8013 } else {
8014 false
8015 };
8016 if used_cudnn {
8017 continue;
8018 }
8019
8020 let kernel = conv1d_kernel(&self.ctx);
8022 let total = n * c_out * l_out;
8023 let (grid, block) = dispatch_grid_1d(total, 256);
8024 let cfg = LaunchConfig {
8025 grid_dim: (grid, 1, 1),
8026 block_dim: (block, 1, 1),
8027 shared_mem_bytes: 0,
8028 };
8029 let mut launcher = stream.launch_builder(&kernel.function);
8030 launcher
8031 .arg(self.arena.f32_buf_mut())
8032 .arg(n)
8033 .arg(c_in)
8034 .arg(c_out)
8035 .arg(l)
8036 .arg(l_out)
8037 .arg(kl)
8038 .arg(sl)
8039 .arg(pl)
8040 .arg(dl)
8041 .arg(groups)
8042 .arg(in_off)
8043 .arg(w_off)
8044 .arg(out_off);
8045 unsafe {
8046 launcher
8047 .launch(cfg)
8048 .expect("rlx-cuda: conv1d launch failed");
8049 }
8050 }
8051 Step::Conv2d {
8052 n,
8053 c_in,
8054 c_out,
8055 h,
8056 w,
8057 h_out,
8058 w_out,
8059 kh,
8060 kw,
8061 sh,
8062 sw,
8063 ph,
8064 pw,
8065 dh,
8066 dw,
8067 groups,
8068 in_off,
8069 w_off,
8070 out_off,
8071 } => {
8072 let try_cudnn = self.dnn.is_some()
8077 && self.dnn_workspace.is_some()
8078 && !rlx_ir::env::flag("RLX_CUDA_NO_CUDNN");
8079 let used_cudnn = if try_cudnn {
8080 let handle = self.dnn.expect("dnn handle");
8081 let workspace = self.dnn_workspace.as_ref().expect("dnn workspace");
8082 let mut workspace = workspace.lock().unwrap();
8083 let (ws_ptr, _ws_record) = workspace.device_ptr_mut(&stream);
8084 let (arena_ptr, _arena_record) =
8085 self.arena.f32_buf_mut().device_ptr_mut(&stream);
8086 let r = unsafe {
8087 cudnn_conv2d_forward(
8088 handle,
8089 ws_ptr,
8090 CUDNN_WORKSPACE_BYTES,
8091 arena_ptr,
8092 *n,
8093 *c_in,
8094 *c_out,
8095 *h,
8096 *w,
8097 *h_out,
8098 *w_out,
8099 *kh,
8100 *kw,
8101 *sh,
8102 *sw,
8103 *ph,
8104 *pw,
8105 *dh,
8106 *dw,
8107 *groups,
8108 *in_off,
8109 *w_off,
8110 *out_off,
8111 )
8112 };
8113 if let Err(ref e) = r {
8114 log_fallback("conv2d.cudnn", e);
8115 }
8116 r.is_ok()
8117 } else {
8118 false
8119 };
8120 if used_cudnn {
8121 continue;
8122 }
8123
8124 let kernel = conv2d_kernel(&self.ctx);
8126 let total = n * c_out * h_out * w_out;
8127 let (grid, block) = dispatch_grid_1d(total, 256);
8128 let cfg = LaunchConfig {
8129 grid_dim: (grid, 1, 1),
8130 block_dim: (block, 1, 1),
8131 shared_mem_bytes: 0,
8132 };
8133 let mut launcher = stream.launch_builder(&kernel.function);
8134 launcher
8135 .arg(self.arena.f32_buf_mut())
8136 .arg(n)
8137 .arg(c_in)
8138 .arg(c_out)
8139 .arg(h)
8140 .arg(w)
8141 .arg(h_out)
8142 .arg(w_out)
8143 .arg(kh)
8144 .arg(kw)
8145 .arg(sh)
8146 .arg(sw)
8147 .arg(ph)
8148 .arg(pw)
8149 .arg(dh)
8150 .arg(dw)
8151 .arg(groups)
8152 .arg(in_off)
8153 .arg(w_off)
8154 .arg(out_off);
8155 unsafe {
8156 launcher
8157 .launch(cfg)
8158 .expect("rlx-cuda: conv2d launch failed");
8159 }
8160 }
8161 Step::Conv3d {
8162 n,
8163 c_in,
8164 c_out,
8165 d,
8166 h,
8167 w,
8168 d_out,
8169 h_out,
8170 w_out,
8171 kd,
8172 kh,
8173 kw,
8174 sd,
8175 sh,
8176 sw,
8177 pd,
8178 ph,
8179 pw,
8180 dd,
8181 dh,
8182 dw,
8183 groups,
8184 in_off,
8185 w_off,
8186 out_off,
8187 } => {
8188 let used_cudnn = if let (Some(handle), Some(workspace)) =
8190 (self.dnn, self.dnn_workspace.as_ref())
8191 {
8192 let mut workspace = workspace.lock().unwrap();
8193 let (ws_ptr, _ws_record) = workspace.device_ptr_mut(&stream);
8194 let (arena_ptr, _arena_record) =
8195 self.arena.f32_buf_mut().device_ptr_mut(&stream);
8196 let r = unsafe {
8197 cudnn_conv3d_forward(
8198 handle,
8199 ws_ptr,
8200 CUDNN_WORKSPACE_BYTES,
8201 arena_ptr,
8202 *n,
8203 *c_in,
8204 *c_out,
8205 *d,
8206 *h,
8207 *w,
8208 *d_out,
8209 *h_out,
8210 *w_out,
8211 *kd,
8212 *kh,
8213 *kw,
8214 *sd,
8215 *sh,
8216 *sw,
8217 *pd,
8218 *ph,
8219 *pw,
8220 *dd,
8221 *dh,
8222 *dw,
8223 *groups,
8224 *in_off,
8225 *w_off,
8226 *out_off,
8227 )
8228 };
8229 if let Err(ref e) = r {
8230 log_fallback("conv3d.cudnn", e);
8231 }
8232 r.is_ok()
8233 } else {
8234 false
8235 };
8236 if used_cudnn {
8237 continue;
8238 }
8239
8240 let kernel = conv3d_kernel(&self.ctx);
8242 let total = n * c_out * d_out * h_out * w_out;
8243 let (grid, block) = dispatch_grid_1d(total, 256);
8244 let cfg = LaunchConfig {
8245 grid_dim: (grid, 1, 1),
8246 block_dim: (block, 1, 1),
8247 shared_mem_bytes: 0,
8248 };
8249 let mut launcher = stream.launch_builder(&kernel.function);
8250 launcher
8251 .arg(self.arena.f32_buf_mut())
8252 .arg(n)
8253 .arg(c_in)
8254 .arg(c_out)
8255 .arg(d)
8256 .arg(h)
8257 .arg(w)
8258 .arg(d_out)
8259 .arg(h_out)
8260 .arg(w_out)
8261 .arg(kd)
8262 .arg(kh)
8263 .arg(kw)
8264 .arg(sd)
8265 .arg(sh)
8266 .arg(sw)
8267 .arg(pd)
8268 .arg(ph)
8269 .arg(pw)
8270 .arg(dd)
8271 .arg(dh)
8272 .arg(dw)
8273 .arg(groups)
8274 .arg(in_off)
8275 .arg(w_off)
8276 .arg(out_off);
8277 unsafe {
8278 launcher
8279 .launch(cfg)
8280 .expect("rlx-cuda: conv3d launch failed");
8281 }
8282 }
8283 }
8284
8285 if let Some(idx) = assigned_idx {
8289 if let Ok(evt) = stream.record_event(None) {
8290 last_event.insert(idx, evt);
8291 }
8292 let (_, writes) = step_offsets(step);
8293 for w in &writes {
8294 producer_of.insert(*w, idx);
8295 }
8296 }
8297 }
8298
8299 if multi_stream {
8302 for s in &self.streams {
8303 let _ = s.synchronize();
8304 }
8305 }
8306
8307 self.prepare_readback_plan();
8308 let plan = self.readback_plan_buf.clone();
8309 run_tail_host_audio_ops(&self.schedule, &stream, self.arena.f32_buf_mut(), true);
8310 if !self.gpu_handle_feeds.is_empty() {
8311 self.propagate_gpu_handle_feeds_d2d(&stream);
8312 }
8313 let read_all = plan.len() == self.graph.outputs.len();
8314
8315 if capturing {
8316 let cu_graph = stream.end_capture(
8318 cudarc::driver::sys::CUgraphInstantiate_flags
8319 ::CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH
8320 ).expect("rlx-cuda: end_capture failed");
8321 if let Some(g) = cu_graph {
8322 g.upload().expect("rlx-cuda: graph upload failed");
8323 g.launch().expect("rlx-cuda: graph first launch failed");
8324 self.captured_graph = Some(g);
8325 self.captured_readback_plan = Some(plan.clone());
8326 }
8327 }
8328
8329 if read_all {
8330 self.fill_output_staging(&stream)
8331 .expect("rlx-cuda: output dtoh failed");
8332 } else {
8333 self.fill_output_staging_indices(&stream, &plan)
8334 .expect("rlx-cuda: partial output dtoh failed");
8335 }
8336 self.refresh_gpu_handles_from_staging(&plan);
8337 stream.synchronize().expect("rlx-cuda: stream sync failed");
8338 self.outputs_from_staging_plan(&plan)
8339 }
8340
8341 fn fill_output_staging_indices(
8342 &mut self,
8343 stream: &Arc<cudarc::driver::CudaStream>,
8344 indices: &[usize],
8345 ) -> Result<(), cudarc::driver::DriverError> {
8346 for &i in indices {
8347 let id = self.graph.outputs[i];
8348 let off_f32 = self.arena.offset(id) / 4;
8349 let elems = self.graph.node(id).shape.num_elements().unwrap_or(0);
8350 debug_assert_eq!(self.output_staging[i].len(), elems);
8351 let slot = self.arena.f32_buf().slice(off_f32..off_f32 + elems);
8352 self.output_staging[i].dtoh(stream, &slot)?;
8353 }
8354 Ok(())
8355 }
8356
8357 fn outputs_from_staging_plan(&self, plan: &[usize]) -> Vec<Vec<f32>> {
8358 if plan.len() == self.graph.outputs.len() {
8359 return self.outputs_from_staging();
8360 }
8361 plan.iter()
8362 .map(|&i| self.output_staging[i].to_vec())
8363 .collect()
8364 }
8365
8366 fn fill_output_staging(
8367 &mut self,
8368 stream: &Arc<cudarc::driver::CudaStream>,
8369 ) -> Result<(), cudarc::driver::DriverError> {
8370 for (i, &id) in self.graph.outputs.iter().enumerate() {
8371 let off_f32 = self.arena.offset(id) / 4;
8372 let elems = self.graph.node(id).shape.num_elements().unwrap_or(0);
8373 debug_assert_eq!(self.output_staging[i].len(), elems);
8374 let slot = self.arena.f32_buf().slice(off_f32..off_f32 + elems);
8375 self.output_staging[i].dtoh(stream, &slot)?;
8376 }
8377 Ok(())
8378 }
8379
8380 fn outputs_from_staging(&self) -> Vec<Vec<f32>> {
8381 self.output_staging
8382 .iter()
8383 .map(F32HostSlot::to_vec)
8384 .collect()
8385 }
8386}
8387
8388fn launch_cumsum_bwd(
8389 ctx: &Arc<CudaContext>,
8390 stream: &cudarc::driver::CudaStream,
8391 buffer: &mut cudarc::driver::CudaSlice<f32>,
8392 outer: u32,
8393 inner: u32,
8394 dy_off: u32,
8395 dx_off: u32,
8396 exclusive: u32,
8397) {
8398 let kernel = cumsum_backward_kernel(ctx);
8399 let (grid, block) = dispatch_grid_1d(outer, 256);
8400 let cfg = LaunchConfig {
8401 grid_dim: (grid, 1, 1),
8402 block_dim: (block, 1, 1),
8403 shared_mem_bytes: 0,
8404 };
8405 let mut launcher = stream.launch_builder(&kernel.function);
8406 launcher
8407 .arg(buffer)
8408 .arg(&outer)
8409 .arg(&inner)
8410 .arg(&dy_off)
8411 .arg(&dx_off)
8412 .arg(&exclusive);
8413 unsafe {
8414 launcher
8415 .launch(cfg)
8416 .expect("rlx-cuda: cumsum_bwd launch failed");
8417 }
8418}
8419
8420fn launch_rope_bwd(
8421 ctx: &Arc<CudaContext>,
8422 stream: &cudarc::driver::CudaStream,
8423 buffer: &mut cudarc::driver::CudaSlice<f32>,
8424 batch: u32,
8425 seq: u32,
8426 hidden: u32,
8427 head_dim: u32,
8428 n_rot: u32,
8429 dy_off: u32,
8430 cos_off: u32,
8431 sin_off: u32,
8432 dx_off: u32,
8433 cos_len: u32,
8434) {
8435 let total = batch * seq * hidden;
8436 let kernel = rope_backward_kernel(ctx);
8437 let (grid, block) = dispatch_grid_1d(total, 256);
8438 let cfg = LaunchConfig {
8439 grid_dim: (grid, 1, 1),
8440 block_dim: (block, 1, 1),
8441 shared_mem_bytes: 0,
8442 };
8443 let mut launcher = stream.launch_builder(&kernel.function);
8444 launcher
8445 .arg(buffer)
8446 .arg(&batch)
8447 .arg(&seq)
8448 .arg(&hidden)
8449 .arg(&head_dim)
8450 .arg(&n_rot)
8451 .arg(&dy_off)
8452 .arg(&cos_off)
8453 .arg(&sin_off)
8454 .arg(&dx_off)
8455 .arg(&cos_len);
8456 unsafe {
8457 launcher
8458 .launch(cfg)
8459 .expect("rlx-cuda: rope_bwd launch failed");
8460 }
8461}
8462
8463fn launch_gather_bwd(
8464 ctx: &Arc<CudaContext>,
8465 stream: &cudarc::driver::CudaStream,
8466 buffer: &mut cudarc::driver::CudaSlice<f32>,
8467 outer: u32,
8468 axis_dim: u32,
8469 num_idx: u32,
8470 trailing: u32,
8471 dy_off: u32,
8472 idx_off: u32,
8473 dst_off: u32,
8474) {
8475 let total = outer * axis_dim * trailing;
8476 if total > 0 {
8477 let zk = rms_norm_bwd_zero_kernel(ctx);
8478 let (grid, block) = dispatch_grid_1d(total, 256);
8479 let cfg = LaunchConfig {
8480 grid_dim: (grid, 1, 1),
8481 block_dim: (block, 1, 1),
8482 shared_mem_bytes: 0,
8483 };
8484 let mut zl = stream.launch_builder(&zk.function);
8485 zl.arg(&mut *buffer).arg(&dst_off).arg(&total);
8486 unsafe {
8487 zl.launch(cfg)
8488 .expect("rlx-cuda: gather_bwd zero launch failed");
8489 }
8490 }
8491 let kernel = gather_backward_kernel(ctx);
8492 let cfg = LaunchConfig {
8493 grid_dim: (outer, (num_idx * trailing).div_ceil(256), 1),
8494 block_dim: (256, 1, 1),
8495 shared_mem_bytes: 0,
8496 };
8497 let mut launcher = stream.launch_builder(&kernel.function);
8498 launcher
8499 .arg(&mut *buffer)
8500 .arg(&outer)
8501 .arg(&axis_dim)
8502 .arg(&num_idx)
8503 .arg(&trailing)
8504 .arg(&dy_off)
8505 .arg(&idx_off)
8506 .arg(&dst_off);
8507 unsafe {
8508 launcher
8509 .launch(cfg)
8510 .expect("rlx-cuda: gather_bwd launch failed");
8511 }
8512}
8513
8514fn launch_rms_norm_bwd(
8515 ctx: &Arc<CudaContext>,
8516 stream: &cudarc::driver::CudaStream,
8517 buffer: &mut cudarc::driver::CudaSlice<f32>,
8518 rows: u32,
8519 inner: u32,
8520 x_off: u32,
8521 gamma_off: u32,
8522 beta_off: u32,
8523 dy_off: u32,
8524 out_off: u32,
8525 eps_bits: u32,
8526 wrt: u32,
8527) {
8528 if wrt != 0 {
8529 let zk = rms_norm_bwd_zero_kernel(ctx);
8530 let (grid, block) = dispatch_grid_1d(inner, 256);
8531 let cfg = LaunchConfig {
8532 grid_dim: (grid, 1, 1),
8533 block_dim: (block, 1, 1),
8534 shared_mem_bytes: 0,
8535 };
8536 let mut zl = stream.launch_builder(&zk.function);
8537 zl.arg(&mut *buffer).arg(&out_off).arg(&inner);
8538 unsafe {
8539 zl.launch(cfg)
8540 .expect("rlx-cuda: rms_norm_bwd zero launch failed");
8541 }
8542 }
8543 let kernel = rms_norm_backward_kernel(ctx);
8544 let cfg = LaunchConfig {
8545 grid_dim: (rows, 1, 1),
8546 block_dim: (256, 1, 1),
8547 shared_mem_bytes: 0,
8548 };
8549 let mut launcher = stream.launch_builder(&kernel.function);
8550 launcher
8551 .arg(&mut *buffer)
8552 .arg(&rows)
8553 .arg(&inner)
8554 .arg(&x_off)
8555 .arg(&gamma_off)
8556 .arg(&beta_off)
8557 .arg(&dy_off)
8558 .arg(&out_off)
8559 .arg(&eps_bits)
8560 .arg(&wrt);
8561 unsafe {
8562 launcher
8563 .launch(cfg)
8564 .expect("rlx-cuda: rms_norm_bwd launch failed");
8565 }
8566}
8567
8568#[cfg(test)]
8569mod tests {
8570 use super::*;
8575
8576 #[test]
8577 fn normalize_read_indices_dedupes() {
8578 let mut v = vec![3, 1, 2, 1, 0];
8579 normalize_read_indices(&mut v);
8580 assert_eq!(v, vec![0, 1, 2, 3]);
8581 }
8582
8583 #[test]
8584 fn step_offsets_binary() {
8585 let s = Step::Binary {
8586 n: 8,
8587 a_off: 100,
8588 b_off: 200,
8589 c_off: 300,
8590 op: 0,
8591 };
8592 let (r, w) = step_offsets(&s);
8593 assert_eq!(r, vec![100, 200]);
8594 assert_eq!(w, vec![300]);
8595 }
8596
8597 #[test]
8598 fn step_offsets_matmul_with_bias() {
8599 let s = Step::Matmul {
8600 m: 4,
8601 k: 8,
8602 n: 4,
8603 a_off_f32: 10,
8604 b_off_f32: 20,
8605 c_off_f32: 30,
8606 batch: 1,
8607 a_batch_stride: 0,
8608 b_batch_stride: 0,
8609 c_batch_stride: 0,
8610 has_bias: 1,
8611 bias_off_f32: 40,
8612 act_id: 0xFFFF,
8613 };
8614 let (r, w) = step_offsets(&s);
8615 assert_eq!(r, vec![10, 20, 40]);
8616 assert_eq!(w, vec![30]);
8617 }
8618
8619 #[test]
8620 fn step_offsets_matmul_no_bias() {
8621 let s = Step::Matmul {
8622 m: 4,
8623 k: 8,
8624 n: 4,
8625 a_off_f32: 10,
8626 b_off_f32: 20,
8627 c_off_f32: 30,
8628 batch: 1,
8629 a_batch_stride: 0,
8630 b_batch_stride: 0,
8631 c_batch_stride: 0,
8632 has_bias: 0,
8633 bias_off_f32: 0,
8634 act_id: 0xFFFF,
8635 };
8636 let (r, w) = step_offsets(&s);
8637 assert_eq!(r, vec![10, 20]);
8638 assert_eq!(w, vec![30]);
8639 }
8640
8641 #[test]
8642 fn step_offsets_attention_causal_no_mask_arg() {
8643 let (mb, mh, mq, mk) = rlx_ir::mask_strides_bhsd(1, 8, 8);
8644 let (qb, qh, qs) = rlx_ir::strides_bhsd(1, 64, 8);
8645 let s = Step::Attention {
8646 batch: 1,
8647 heads: 1,
8648 seq_q: 8,
8649 seq_k: 8,
8650 head_dim: 64,
8651 q_off: 0,
8652 k_off: 100,
8653 v_off: 200,
8654 out_off: 300,
8655 mask_off: 9999,
8656 mask_kind: 1, scale_bits: 0,
8658 window: 0,
8659 seq_q_stride: mq,
8660 seq_k_stride: mk,
8661 mask_batch_stride: mb,
8662 mask_head_stride: mh,
8663 q_batch_stride: qb,
8664 q_head_stride: qh,
8665 q_seq_stride: qs,
8666 k_batch_stride: qb,
8667 k_head_stride: qh,
8668 k_seq_stride: qs,
8669 v_batch_stride: qb,
8670 v_head_stride: qh,
8671 v_seq_stride: qs,
8672 o_batch_stride: qb,
8673 o_head_stride: qh,
8674 o_seq_stride: qs,
8675 };
8676 let (r, _) = step_offsets(&s);
8677 assert!(!r.contains(&9999), "causal mask must not consume mask_off");
8678 assert_eq!(r, vec![0, 100, 200]);
8679 }
8680
8681 #[test]
8682 fn step_offsets_attention_custom_mask_pulls_mask() {
8683 let (mb, mh, mq, mk) = rlx_ir::mask_strides_bhsd(1, 8, 8);
8684 let (qb, qh, qs) = rlx_ir::strides_bhsd(1, 64, 8);
8685 let s = Step::Attention {
8686 batch: 1,
8687 heads: 1,
8688 seq_q: 8,
8689 seq_k: 8,
8690 head_dim: 64,
8691 q_off: 0,
8692 k_off: 100,
8693 v_off: 200,
8694 out_off: 300,
8695 mask_off: 9999,
8696 mask_kind: 2, scale_bits: 0,
8698 window: 0,
8699 seq_q_stride: mq,
8700 seq_k_stride: mk,
8701 mask_batch_stride: mb,
8702 mask_head_stride: mh,
8703 q_batch_stride: qb,
8704 q_head_stride: qh,
8705 q_seq_stride: qs,
8706 k_batch_stride: qb,
8707 k_head_stride: qh,
8708 k_seq_stride: qs,
8709 v_batch_stride: qb,
8710 v_head_stride: qh,
8711 v_seq_stride: qs,
8712 o_batch_stride: qb,
8713 o_head_stride: qh,
8714 o_seq_stride: qs,
8715 };
8716 let (r, _) = step_offsets(&s);
8717 assert!(r.contains(&9999));
8718 }
8719
8720 #[test]
8721 fn step_offsets_scatter_add_acc_marks_out_as_rmw() {
8722 let s = Step::ScatterAddAcc {
8723 out_off: 100,
8724 upd_off: 200,
8725 idx_off: 300,
8726 num_updates: 4,
8727 trailing: 1,
8728 out_dim: 16,
8729 };
8730 let (r, w) = step_offsets(&s);
8731 assert!(r.contains(&100));
8735 assert!(w.contains(&100));
8736 }
8737
8738 #[test]
8739 fn fuse_elementwise_merges_binary_then_unary() {
8740 let schedule = vec![
8741 Step::Binary {
8743 n: 4,
8744 a_off: 0,
8745 b_off: 4,
8746 c_off: 8,
8747 op: 0,
8748 },
8749 Step::Unary {
8751 n: 4,
8752 in_off: 8,
8753 out_off: 12,
8754 op: 0,
8755 },
8756 ];
8757 let fused = fuse_elementwise_chains(schedule);
8758 assert_eq!(fused.len(), 1, "expected exactly one fused step");
8759 match &fused[0] {
8760 Step::FusedBinaryUnary {
8761 n,
8762 a_off,
8763 b_off,
8764 out_off,
8765 bin_op,
8766 un_op,
8767 } => {
8768 assert_eq!(*n, 4);
8769 assert_eq!(*a_off, 0);
8770 assert_eq!(*b_off, 4);
8771 assert_eq!(*out_off, 12);
8772 assert_eq!(*bin_op, 0);
8773 assert_eq!(*un_op, 0);
8774 }
8775 other => panic!("expected FusedBinaryUnary, got {}", step_name(other)),
8776 }
8777 }
8778
8779 #[test]
8780 fn fuse_elementwise_skips_when_intermediate_has_two_consumers() {
8781 let schedule = vec![
8785 Step::Binary {
8786 n: 4,
8787 a_off: 0,
8788 b_off: 4,
8789 c_off: 8,
8790 op: 0,
8791 },
8792 Step::Unary {
8793 n: 4,
8794 in_off: 8,
8795 out_off: 12,
8796 op: 0,
8797 },
8798 Step::Binary {
8799 n: 4,
8800 a_off: 8,
8801 b_off: 8,
8802 c_off: 16,
8803 op: 2,
8804 },
8805 ];
8806 let fused = fuse_elementwise_chains(schedule);
8807 assert_eq!(fused.len(), 3, "no fusion: c has multiple consumers");
8808 assert!(matches!(&fused[0], Step::Binary { .. }));
8809 assert!(matches!(&fused[1], Step::Unary { .. }));
8810 assert!(matches!(&fused[2], Step::Binary { .. }));
8811 }
8812
8813 #[test]
8814 fn fuse_elementwise_skips_when_n_mismatch() {
8815 let schedule = vec![
8817 Step::Binary {
8818 n: 4,
8819 a_off: 0,
8820 b_off: 4,
8821 c_off: 8,
8822 op: 0,
8823 },
8824 Step::Unary {
8825 n: 8,
8826 in_off: 8,
8827 out_off: 16,
8828 op: 0,
8829 },
8830 ];
8831 let fused = fuse_elementwise_chains(schedule);
8832 assert_eq!(fused.len(), 2);
8833 }
8834
8835 #[test]
8836 fn fuse_elementwise_skips_when_unary_input_isnt_binary_output() {
8837 let schedule = vec![
8839 Step::Binary {
8840 n: 4,
8841 a_off: 0,
8842 b_off: 4,
8843 c_off: 8,
8844 op: 0,
8845 },
8846 Step::Unary {
8847 n: 4,
8848 in_off: 99,
8849 out_off: 16,
8850 op: 0,
8851 },
8852 ];
8853 let fused = fuse_elementwise_chains(schedule);
8854 assert_eq!(fused.len(), 2);
8855 }
8856
8857 #[test]
8858 fn fuse_elementwise_handles_multiple_chains() {
8859 let schedule = vec![
8861 Step::Binary {
8862 n: 4,
8863 a_off: 0,
8864 b_off: 4,
8865 c_off: 8,
8866 op: 0,
8867 },
8868 Step::Unary {
8869 n: 4,
8870 in_off: 8,
8871 out_off: 12,
8872 op: 0,
8873 },
8874 Step::Binary {
8875 n: 4,
8876 a_off: 16,
8877 b_off: 20,
8878 c_off: 24,
8879 op: 2,
8880 },
8881 Step::Unary {
8882 n: 4,
8883 in_off: 24,
8884 out_off: 28,
8885 op: 9,
8886 },
8887 ];
8888 let fused = fuse_elementwise_chains(schedule);
8889 assert_eq!(fused.len(), 2);
8890 assert!(matches!(&fused[0], Step::FusedBinaryUnary { .. }));
8891 assert!(matches!(&fused[1], Step::FusedBinaryUnary { .. }));
8892 }
8893}