Skip to main content

rlx_cuda/
backend.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! `CudaExecutable` — lowers an rlx-ir Graph into a sequence of CUDA
//! kernel launches against a pre-allocated device buffer.
//!
//! v2 op coverage: MatMul (tiled SGEMM), Binary, Compare, Activation, Where,
//! Reduce, Softmax, LayerNorm, RmsNorm, FusedResidualLN, Gather, Narrow,
//! Argmax, Reshape/Cast (no-op via slot aliasing), leaf nodes. Anything
//! else panics at compile time with a "fall back to CPU/Metal/MLX/WGPU"
//! diagnostic. Op coverage is grown incrementally — each new op is one
//! `.cu` source + one Step variant + one match arm.
25
26use std::collections::HashMap;
27use std::sync::{Arc, Mutex, Once};
28
29use cudarc::cublas::{CudaBlas, sys as cublas_sys};
30use cudarc::cublaslt::{result as cublaslt_result, sys as cublaslt_sys};
31use cudarc::cudnn::{result as cudnn_result, sys as cudnn_sys};
32use cudarc::driver::{CudaContext, DevicePtrMut, LaunchConfig, PushKernelArg};
33use rlx_ir::op::{Activation, BinaryOp, CmpOp, MaskKind, ReduceOp};
34use rlx_ir::{Graph, NodeId, Op};
35use rlx_opt::rlx_fusion::lower_reduce_axes::LowerNonLastAxisReduce;
36use rlx_opt::rlx_fusion::pass::Pass as _;
37
38use crate::arena::{Arena, plan_f32_uniform};
39use crate::device::{
40    CUBLASLT_WORKSPACE_BYTES, CUDNN_WORKSPACE_BYTES, cuda_blas, cuda_blas_lt_handle,
41    cuda_blas_lt_workspace, cuda_context, cuda_dnn_handle, cuda_dnn_workspace,
42};
43use crate::host_staging::F32HostSlot;
44use crate::kernels::{
45    argmax_kernel, attention_bwd_kernel, attention_kernel, attention_row_kernel,
46    batch_elementwise_region_kernel, binary_kernel, compare_kernel, concat_kernel,
47    conv_transpose2d_kernel, conv1d_kernel, conv2d_kernel, conv3d_kernel, cumsum_backward_kernel,
48    cumsum_kernel, dequant_matmul_kernel, dispatch_grid_1d, dispatch_grid_prologue_nchw,
49    elementwise_region_kernel, expand_kernel, fused_binary_unary_kernel, fused_residual_ln_kernel,
50    fused_residual_rms_norm_kernel, gather_axis_kernel, gather_backward_kernel, gather_kernel,
51    group_norm_kernel, grouped_matmul_kernel, im2col_kernel, layer_norm2d_kernel, layernorm_kernel,
52    matmul_epilogue_kernel, matmul_kernel, matmul_wmma_kernel, narrow_kernel, pool1d_kernel,
53    pool2d_kernel, pool3d_kernel, reduce_kernel, resize_nearest_2x_kernel,
54    rms_norm_backward_kernel, rms_norm_bwd_zero_kernel, rope_backward_kernel, rope_kernel,
55    sample_kernel, scatter_add_acc_kernel, scatter_add_zero_kernel, selective_scan_kernel,
56    softmax_kernel, topk_kernel, transpose_kernel, unary_kernel, where_kernel,
57};
58
59/// Opt-in WMMA Tensor Core matmul. Reads `RLX_CUDA_WMMA=1` from env at
60/// process start (cached behind a `OnceLock`). When true and cuBLAS is
61/// unavailable, the scalar matmul kernel is replaced by the WMMA kernel
62/// for plain (non-fused) matmul. Tensor Cores require SM 70+; on older
63/// hardware NVRTC's `load_module` will fail and we fall back to scalar.
64fn use_wmma() -> bool {
65    use std::sync::OnceLock;
66    static FLAG: OnceLock<bool> = OnceLock::new();
67    *FLAG.get_or_init(|| {
68        rlx_ir::env::var("RLX_CUDA_WMMA")
69            .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
70            .unwrap_or(false)
71    })
72}
73
74/// Strict f32 matmul for encoder parity: tiled `matmul.cu` kernel (same
75/// family as wgpu), not cuBLASLt / cuBLAS heuristics.
76fn matmul_parity_mode() -> bool {
77    use std::sync::OnceLock;
78    static FLAG: OnceLock<bool> = OnceLock::new();
79    *FLAG.get_or_init(|| {
80        rlx_ir::env::flag("RLX_CUDA_NO_TF32")
81            || rlx_ir::env::flag("RLX_CUDA_PARITY")
82            || rlx_ir::env::flag("RLX_CUDA_NO_CUBLASLT")
83    })
84}
85
/// One launch step in the compiled schedule.
///
/// Each variant carries the plain-data arguments for one dispatch: problem
/// dimensions, buffer offsets and small op/flag selectors. Every field is
/// an integer, `bool`, `f32` or a fixed-size integer array, so the enum is
/// cheap to `Clone`.
///
/// Offset fields come in three naming flavours: `*_off`, `*_off_f32` and
/// `*_byte_off`. NOTE(review): the suffix presumably distinguishes element
/// offsets from byte offsets — confirm against the dispatch code before
/// mixing them. Fields named `*_bits` (`eps_bits`, `scale_bits`, ...) are
/// presumably `f32::to_bits` encodings — TODO confirm.
#[derive(Clone)]
enum Step {
    Matmul {
        m: u32,
        k: u32,
        n: u32,
        a_off_f32: u32,
        b_off_f32: u32,
        c_off_f32: u32,
        batch: u32,
        a_batch_stride: u32,
        b_batch_stride: u32,
        c_batch_stride: u32,
        has_bias: u32,
        bias_off_f32: u32,
        act_id: u32,
    },
    Binary {
        n: u32,
        a_off: u32,
        b_off: u32,
        c_off: u32,
        op: u32,
    },
    Compare {
        n: u32,
        a_off: u32,
        b_off: u32,
        c_off: u32,
        op: u32,
    },
    Unary {
        n: u32,
        in_off: u32,
        out_off: u32,
        op: u32,
    },
    Where {
        n: u32,
        cond_off: u32,
        x_off: u32,
        y_off: u32,
        out_off: u32,
    },
    Reduce {
        outer: u32,
        inner: u32,
        in_off: u32,
        out_off: u32,
        op: u32,
    },
    Softmax {
        outer: u32,
        inner: u32,
        in_off: u32,
        out_off: u32,
    },
    LayerNorm {
        outer: u32,
        inner: u32,
        in_off: u32,
        out_off: u32,
        gamma_off: u32,
        beta_off: u32,
        eps_bits: u32,
        op: u32,
    },
    FusedResidualLn {
        outer: u32,
        inner: u32,
        in_off: u32,
        residual_off: u32,
        bias_off: u32,
        gamma_off: u32,
        beta_off: u32,
        out_off: u32,
        eps_bits: u32,
        has_bias: u32,
    },
    FusedResidualRmsNorm {
        outer: u32,
        inner: u32,
        in_off: u32,
        residual_off: u32,
        bias_off: u32,
        gamma_off: u32,
        beta_off: u32,
        out_off: u32,
        eps_bits: u32,
        has_bias: u32,
    },
    Gather {
        n_out: u32,
        n_idx: u32,
        dim: u32,
        vocab: u32,
        in_off: u32,
        idx_off: u32,
        out_off: u32,
    },
    GatherAxis {
        total: u32,
        outer: u32,
        axis_dim: u32,
        num_idx: u32,
        trailing: u32,
        table_off: u32,
        idx_off: u32,
        out_off: u32,
    },
    Narrow {
        total: u32,
        outer: u32,
        inner: u32,
        axis_in_size: u32,
        axis_out_size: u32,
        start: u32,
        in_off: u32,
        out_off: u32,
    },
    Argmax {
        outer: u32,
        inner: u32,
        in_off: u32,
        out_off: u32,
    },
    Transpose {
        rank: u32,
        out_total: u32,
        in_off: u32,
        out_off: u32,
        meta_idx: usize,
    },
    Expand {
        rank: u32,
        out_total: u32,
        in_off: u32,
        out_off: u32,
        meta_idx: usize,
    },
    Concat {
        total: u32,
        outer: u32,
        inner: u32,
        axis_in_size: u32,
        axis_out_size: u32,
        start: u32,
        in_off: u32,
        out_off: u32,
    },
    Attention {
        batch: u32,
        heads: u32,
        seq_q: u32,
        seq_k: u32,
        head_dim: u32,
        q_off: u32,
        k_off: u32,
        v_off: u32,
        out_off: u32,
        mask_off: u32,
        mask_kind: u32,
        scale_bits: u32,
        window: u32,
        seq_q_stride: u32,
        seq_k_stride: u32,
        mask_batch_stride: u32,
        mask_head_stride: u32,
        q_batch_stride: u32,
        q_head_stride: u32,
        q_seq_stride: u32,
        k_batch_stride: u32,
        k_head_stride: u32,
        k_seq_stride: u32,
        v_batch_stride: u32,
        v_head_stride: u32,
        v_seq_stride: u32,
        o_batch_stride: u32,
        o_head_stride: u32,
        o_seq_stride: u32,
    },
    AttentionBackward {
        batch: u32,
        heads: u32,
        seq_q: u32,
        seq_k: u32,
        head_dim: u32,
        q_off: u32,
        k_off: u32,
        v_off: u32,
        dy_off: u32,
        out_off: u32,
        mask_off: u32,
        mask_kind: u32,
        scale_bits: u32,
        window: u32,
        wrt: u32,
    },
    Rope {
        n_total: u32,
        seq: u32,
        head_dim: u32,
        half: u32,
        in_off: u32,
        cos_off: u32,
        sin_off: u32,
        out_off: u32,
        last_dim: u32,
    },
    Cumsum {
        outer: u32,
        inner: u32,
        in_off: u32,
        out_off: u32,
        exclusive: u32,
    },
    TopK {
        outer: u32,
        inner: u32,
        k: u32,
        in_off: u32,
        out_off: u32,
    },
    GroupedMatmul {
        m: u32,
        k: u32,
        n: u32,
        num_experts: u32,
        in_off: u32,
        w_off: u32,
        idx_off: u32,
        out_off: u32,
    },
    ScatterAddZero {
        out_off: u32,
        out_total: u32,
    },
    ScatterAddAcc {
        out_off: u32,
        upd_off: u32,
        idx_off: u32,
        num_updates: u32,
        trailing: u32,
        out_dim: u32,
    },
    DequantMatmul {
        m: u32,
        k: u32,
        n: u32,
        block_size: u32,
        scheme_id: u32,
        x_off: u32,
        w_off: u32,
        scale_off: u32,
        zp_off: u32,
        out_off: u32,
    },
    /// GGUF K-quant weights — GPU dequant scratch + cuBLAS (host fallback).
    DequantMatmulGguf {
        m: u32,
        k: u32,
        n: u32,
        scheme_id: u32,
        x_byte_off: u32,
        w_byte_off: u32,
        out_byte_off: u32,
    },
    DequantGroupedMatmulGguf {
        m: u32,
        k: u32,
        n: u32,
        num_experts: u32,
        scheme_id: u32,
        x_byte_off: u32,
        w_byte_off: u32,
        idx_byte_off: u32,
        out_byte_off: u32,
    },
    Sample {
        outer: u32,
        inner: u32,
        in_off: u32,
        out_off: u32,
        top_k: u32,
        top_p_bits: u32,
        temp_bits: u32,
        seed_lo: u32,
        seed_hi: u32,
    },
    SelectiveScan {
        batch: u32,
        seq: u32,
        hidden: u32,
        state_size: u32,
        x_off: u32,
        delta_off: u32,
        a_off: u32,
        b_off: u32,
        c_off: u32,
        out_off: u32,
    },
    /// 1D FFT — native GPU (f32 pow2) or host fallback.
    Fft {
        src_byte_off: u32,
        dst_byte_off: u32,
        outer: u32,
        n_complex: u32,
        inverse: bool,
        norm_tag: u32,
        dtype_tag: u32,
        use_gpu: bool,
    },
    /// Log-mel from block-layout FFT spectrum — host fallback.
    LogMelHost {
        spec_byte_off: u32,
        filt_byte_off: u32,
        dst_byte_off: u32,
        outer: u32,
        n_fft: u32,
        n_bins: u32,
        n_mels: u32,
    },
    LogMelBackwardHost {
        spec_byte_off: u32,
        filt_byte_off: u32,
        dy_byte_off: u32,
        dst_byte_off: u32,
        outer: u32,
        n_fft: u32,
        n_bins: u32,
        n_mels: u32,
    },
    /// Welch PSD top-K from block-layout spectra — host fallback.
    WelchPeaksHost {
        spec_byte_off: u32,
        dst_byte_off: u32,
        welch_batch: u32,
        n_fft: u32,
        n_segments: u32,
        k: u32,
    },
    /// Native GPU WelchPeaks (in-arena, no D2H).
    WelchPeaksGpu {
        spec_off: u32,
        dst_off: u32,
        welch_batch: u32,
        n_fft: u32,
        n_segments: u32,
        k: u32,
        n_bins: u32,
    },
    /// NCHW im2col — GPU kernel or host fallback (dynamic batch / `RLX_CUDA_IM2COL_HOST=1`).
    Im2ColHost {
        x_byte_off: u32,
        col_byte_off: u32,
        n: u32,
        c_in: u32,
        h: u32,
        w: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
        dh: u32,
        dw_dil: u32,
        use_gpu: bool,
    },
    /// Gated-DeltaNet — host scan between GPU segments (qwen35 linear layers).
    GatedDeltaNet {
        q_byte_off: u32,
        k_byte_off: u32,
        v_byte_off: u32,
        g_byte_off: u32,
        beta_byte_off: u32,
        state_byte_off: u32,
        dst_byte_off: u32,
        batch: u32,
        seq: u32,
        heads: u32,
        state_size: u32,
        use_carry: bool,
    },
    /// LLaDA2 / TIDE group-limited MoE gate (host TopK between GPU segments).
    Llada2GroupLimitedGate {
        sig_off: u32,
        route_off: u32,
        out_off: u32,
        n_elems: u32,
        attrs: [u8; 20],
    },
    UmapKnn {
        pairwise_off: u32,
        out_off: u32,
        n: u32,
        k: u32,
    },
    /// 3D Gaussian splat — host reference between GPU segments.
    GaussianSplatRender {
        positions_off: u32,
        positions_len: u32,
        scales_off: u32,
        scales_len: u32,
        rotations_off: u32,
        rotations_len: u32,
        opacities_off: u32,
        opacities_len: u32,
        colors_off: u32,
        colors_len: u32,
        sh_coeffs_off: u32,
        sh_coeffs_len: u32,
        meta_off: u32,
        dst_off: u32,
        dst_len: u32,
        width: u32,
        height: u32,
        tile_size: u32,
        radius_scale: f32,
        alpha_cutoff: f32,
        max_splat_steps: u32,
        transmittance_threshold: f32,
        max_list_entries: u32,
    },
    GaussianSplatRenderBackward {
        positions_off: u32,
        positions_len: u32,
        scales_off: u32,
        scales_len: u32,
        rotations_off: u32,
        rotations_len: u32,
        opacities_off: u32,
        opacities_len: u32,
        colors_off: u32,
        colors_len: u32,
        sh_coeffs_off: u32,
        sh_coeffs_len: u32,
        meta_off: u32,
        d_loss_off: u32,
        d_loss_len: u32,
        packed_off: u32,
        packed_len: u32,
        width: u32,
        height: u32,
        tile_size: u32,
        radius_scale: f32,
        alpha_cutoff: f32,
        max_splat_steps: u32,
        transmittance_threshold: f32,
        max_list_entries: u32,
        loss_grad_clip: f32,
        sh_band: u32,
        max_anisotropy: f32,
    },
    GaussianSplatPrepare {
        positions_off: u32,
        positions_len: u32,
        scales_off: u32,
        scales_len: u32,
        rotations_off: u32,
        rotations_len: u32,
        opacities_off: u32,
        opacities_len: u32,
        colors_off: u32,
        colors_len: u32,
        sh_coeffs_off: u32,
        sh_coeffs_len: u32,
        meta_off: u32,
        meta_len: u32,
        prep_off: u32,
        prep_len: u32,
        width: u32,
        height: u32,
        tile_size: u32,
        radius_scale: f32,
        alpha_cutoff: f32,
        max_splat_steps: u32,
        transmittance_threshold: f32,
        max_list_entries: u32,
    },
    GaussianSplatRasterize {
        prep_off: u32,
        prep_len: u32,
        meta_off: u32,
        meta_len: u32,
        dst_off: u32,
        dst_len: u32,
        count: u32,
        width: u32,
        height: u32,
        tile_size: u32,
        alpha_cutoff: f32,
        max_splat_steps: u32,
        transmittance_threshold: f32,
        max_list_entries: u32,
    },
    RmsNormBackwardInput {
        x_byte_off: u32,
        gamma_byte_off: u32,
        beta_byte_off: u32,
        dy_byte_off: u32,
        dx_byte_off: u32,
        rows: u32,
        h: u32,
        eps_bits: u32,
    },
    RmsNormBackwardGamma {
        x_byte_off: u32,
        gamma_byte_off: u32,
        beta_byte_off: u32,
        dy_byte_off: u32,
        dgamma_byte_off: u32,
        rows: u32,
        h: u32,
        eps_bits: u32,
    },
    RmsNormBackwardBeta {
        x_byte_off: u32,
        gamma_byte_off: u32,
        beta_byte_off: u32,
        dy_byte_off: u32,
        dbeta_byte_off: u32,
        rows: u32,
        h: u32,
        eps_bits: u32,
    },
    RopeBackward {
        dy_byte_off: u32,
        cos_byte_off: u32,
        sin_byte_off: u32,
        dx_byte_off: u32,
        batch: u32,
        seq: u32,
        hidden: u32,
        head_dim: u32,
        n_rot: u32,
        cos_len: u32,
    },
    CumsumBackward {
        dy_byte_off: u32,
        dx_byte_off: u32,
        rows: u32,
        cols: u32,
        exclusive: bool,
    },
    GatherBackward {
        dy_byte_off: u32,
        indices_byte_off: u32,
        dst_byte_off: u32,
        outer: u32,
        axis_dim: u32,
        num_idx: u32,
        trailing: u32,
    },
    MaxPool2dBackward {
        x_byte_off: u32,
        dy_byte_off: u32,
        dx_byte_off: u32,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
    },
    Conv2dBackwardInput {
        dy_byte_off: u32,
        w_byte_off: u32,
        dx_byte_off: u32,
        n: u32,
        c_in: u32,
        h: u32,
        w_in: u32,
        c_out: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
        dh: u32,
        dw: u32,
        groups: u32,
    },
    Conv2dBackwardWeight {
        x_byte_off: u32,
        dy_byte_off: u32,
        dw_byte_off: u32,
        n: u32,
        c_in: u32,
        h: u32,
        w: u32,
        c_out: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
        dh: u32,
        dw_dil: u32,
        groups: u32,
    },
    Pool1d {
        n: u32,
        c: u32,
        l: u32,
        l_out: u32,
        kl: u32,
        sl: u32,
        pl: u32,
        op: u32,
        in_off: u32,
        out_off: u32,
    },
    Pool2d {
        n: u32,
        c: u32,
        h: u32,
        w: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
        op: u32,
        in_off: u32,
        out_off: u32,
    },
    Pool3d {
        n: u32,
        c: u32,
        d: u32,
        h: u32,
        w: u32,
        d_out: u32,
        h_out: u32,
        w_out: u32,
        kd: u32,
        kh: u32,
        kw: u32,
        sd: u32,
        sh: u32,
        sw: u32,
        pd: u32,
        ph: u32,
        pw: u32,
        op: u32,
        in_off: u32,
        out_off: u32,
    },
    Conv1d {
        n: u32,
        c_in: u32,
        c_out: u32,
        l: u32,
        l_out: u32,
        kl: u32,
        sl: u32,
        pl: u32,
        dl: u32,
        groups: u32,
        in_off: u32,
        w_off: u32,
        out_off: u32,
    },
    Conv2d {
        n: u32,
        c_in: u32,
        c_out: u32,
        h: u32,
        w: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
        dh: u32,
        dw: u32,
        groups: u32,
        in_off: u32,
        w_off: u32,
        out_off: u32,
    },
    Conv3d {
        n: u32,
        c_in: u32,
        c_out: u32,
        d: u32,
        h: u32,
        w: u32,
        d_out: u32,
        h_out: u32,
        w_out: u32,
        kd: u32,
        kh: u32,
        kw: u32,
        sd: u32,
        sh: u32,
        sw: u32,
        pd: u32,
        ph: u32,
        pw: u32,
        dd: u32,
        dh: u32,
        dw: u32,
        groups: u32,
        in_off: u32,
        w_off: u32,
        out_off: u32,
    },
    /// NCHW LayerNorm2d (SAM semantics).
    LayerNorm2d {
        src_off: u32,
        g_off: u32,
        b_off: u32,
        dst_off: u32,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
        eps_bits: u32,
    },
    /// NCHW ConvTranspose2d (PyTorch weight layout).
    ConvTranspose2d {
        src_off: u32,
        w_off: u32,
        dst_off: u32,
        n: u32,
        c_in: u32,
        h: u32,
        w_in: u32,
        c_out: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
        dh: u32,
        dw: u32,
        groups: u32,
    },
    /// NCHW group norm.
    GroupNorm {
        src_off: u32,
        g_off: u32,
        b_off: u32,
        dst_off: u32,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
        num_groups: u32,
        eps_bits: u32,
    },
    /// Nearest-neighbor 2× upsample on NCHW.
    ResizeNearest2x {
        src_off: u32,
        dst_off: u32,
        n: u32,
        c: u32,
        h: u32,
        w: u32,
    },
    /// Backend-level fusion of `Binary → Unary` element-wise chains.
    /// Emitted by `fuse_elementwise_chains` when the intermediate
    /// offset has exactly one consumer in the schedule. Avoids one
    /// kernel launch + one round-trip to global memory for the
    /// intermediate result.
    FusedBinaryUnary {
        n: u32,
        a_off: u32,
        b_off: u32,
        out_off: u32,
        bin_op: u32,
        un_op: u32,
    },
    /// PLAN L2 — interpreted N-ary element-wise chain. The chain
    /// encoding (input_offs[8] + chain[64]) lives in `meta_buffers`
    /// and is indexed via `meta_idx`. One thread per output element;
    /// each thread walks the chain in registers and writes the final
    /// result to `arena[dst_off + i]`. Caps: 16 steps, 8 inputs.
    /// Emitted from `Op::ElementwiseRegion` by `MarkElementwiseRegions`
    /// (replaces the prior `UnfuseElementwiseRegions` decomposer
    /// fallback). `input_offs` mirrors what's packed in `meta` and is
    /// kept in the Step so the multi-stream scheduler can resolve
    /// producer-consumer dependencies without unpacking metadata.
    ///
    /// NOTE(review): the text above says 8 inputs (`input_offs[8]`), but
    /// the `input_offs` and `input_modulus` arrays below hold 16 entries —
    /// confirm which cap is current and update the stale side.
    ElementwiseRegion {
        len: u32,
        num_inputs: u32,
        num_steps: u32,
        dst_off: u32,
        input_offs: [u32; 16],
        /// PLAN L2 quality fast path: per-input scalar bitfield.
        /// Bit `i` ⇒ input `i` is a single-element broadcast.
        scalar_input_mask: u32,
        /// PLAN L2 quality general broadcast: per-input element count.
        /// `0` ⇒ no broadcast (kernel reads gid); `>0` ⇒ kernel reads
        /// `arena[input_offs[i] + (gid % input_modulus[i])]`.
        input_modulus: [u32; 16],
        meta_idx: usize,
        /// When true, launch a W×H×(N·C) grid (resize prologue).
        spatial_prologue: bool,
        prologue_w: u32,
        prologue_h: u32,
        prologue_nc: u32,
    },
    /// FKL batch region: one launch over `num_batch` slices (`blockIdx.z`).
    BatchElementwiseRegion {
        slice_len: u32,
        num_batch: u32,
        num_steps: u32,
        base_dst_off: u32,
        slice_elems: u32,
        /// Host copy for schedule dependency edges.
        batch_input_offs: [u32; 64],
        batch_offs_idx: usize,
        meta_idx: usize,
        scalar_input_mask: u32,
        input_modulus: [u32; 16],
    },
}
930
/// When kernels turn into PTX device code.
///
/// `Jit` is the default — each kernel NVRTC-compiles on first dispatch,
/// then the cuModule is cached for the rest of the process. `Aot`
/// pre-compiles every kernel at executable construction so the first
/// `run()` doesn't pay any compile latency. The full AOT pass is ~1-3s
/// (10-100ms × 32 kernels) but moves that cost out of the critical path.
///
/// NOTE(review): the "32 kernels" figure looks stale — this file imports
/// noticeably more kernel entry points than that; re-measure before
/// quoting the total.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CompileMode {
    /// Compile each kernel on first dispatch (default).
    #[default]
    Jit,
    /// Compile every kernel up front, at executable construction.
    Aot,
}
944
/// How the schedule executes.
///
/// - `Stream` (default) launches each Step on the default stream every
///   `run()`.
/// - `Graph` captures the full schedule into a CUDA Graph on first run and
///   replays the captured graph on subsequent runs — eliminates per-launch
///   dispatch overhead (~10-20% on small-batch inference).
/// - `Eager` is a one-shot helper that compiles + runs + drops the
///   executable in one call; useful for interactive debugging.
/// - `MultiStream(n)` allocates a pool of `n` streams and assigns each
///   `Step` to a stream based on data dependencies — independent ops
///   (e.g. unfused Q/K/V projections, FFN gate/up) run in parallel.
///   Cross-stream synchronization uses CUDA events at producer-consumer
///   boundaries.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ExecMode {
    /// Per-step launches on the default stream (default).
    #[default]
    Stream,
    /// Capture once into a CUDA Graph, replay afterwards.
    Graph,
    /// One-shot compile + run + drop.
    Eager,
    /// Dependency-aware dispatch across a pool of this many streams.
    MultiStream(usize),
}
966
967pub struct CudaExecutable {
968    ctx: Arc<CudaContext>,
969    /// cuBLAS handle bound to the same default stream as `ctx`. Used for
970    /// plain matmul (no fused bias/activation); falls back to the custom
971    /// kernel when cuBLAS isn't available (e.g., on Mac via the panic-
972    /// catch probe).
973    blas: Option<Arc<Mutex<CudaBlas>>>,
974    /// cuBLASLt handle for fused matmul + bias + activation. Falls back
975    /// to plain cuBLAS sgemm + epilogue kernel when unavailable.
976    blas_lt: Option<cublaslt_sys::cublasLtHandle_t>,
977    /// Shared cuBLASLt scratch — process singleton, only referenced when
978    /// the schedule uses cublasLt-fusable matmul.
979    blas_lt_workspace: Option<Arc<Mutex<cudarc::driver::CudaSlice<u8>>>>,
980    /// cuDNN handle for convolution dispatch (conv1d/2d/3d). Falls back
981    /// to the custom direct-convolution kernels when unavailable.
982    dnn: Option<cudnn_sys::cudnnHandle_t>,
983    /// Shared cuDNN scratch — process singleton, only referenced when the
984    /// schedule contains conv steps.
985    dnn_workspace: Option<Arc<Mutex<cudarc::driver::CudaSlice<u8>>>>,
986    /// Scratch f16 buffer for casting activations on-the-fly when the
987    /// matching weight is half-stored. Sized to fit the largest
988    /// per-call M·K product seen in matmul dispatch; grown lazily.
989    half_act_scratch: Option<cudarc::driver::CudaSlice<u16>>,
990    /// Byte offset in the f32 arena for GGUF dequant scratch (max k×n f32).
991    dequant_scratch_off: usize,
992    graph: Graph,
993    arena: Arena,
994    schedule: Vec<Step>,
995    input_offsets: HashMap<String, NodeId>,
996    param_offsets: HashMap<String, NodeId>,
997    /// Per-step side buffers for kernels that need per-axis u32 metadata
998    /// (Transpose, Expand). Indexed via `Step::Transpose.meta_idx` etc.
999    meta_buffers: Vec<cudarc::driver::CudaSlice<u32>>,
1000    exec_mode: ExecMode,
1001    /// Captured CUDA Graph (built on first `run()` when `exec_mode ==
1002    /// Graph`). Replayed on subsequent runs to skip per-launch dispatch.
1003    captured_graph: Option<cudarc::driver::CudaGraph>,
1004    /// Stream pool for `ExecMode::MultiStream(n)`. Empty for the other
1005    /// modes (which use the context's default stream).
1006    streams: Vec<Arc<cudarc::driver::CudaStream>>,
1007    /// Active-extent hint (`Some((actual, upper))`) for L1 bucketed
1008    /// dispatch. When set AND every step in `schedule` is in the
1009    /// safe set, `run` bypasses the captured CUDA Graph (recorded at
1010    /// full extent) and dispatches per-step with scaled launch dims.
1011    /// Otherwise full-extent fallback. See PLAN L1.
1012    pub(crate) active_extent: Option<(usize, usize)>,
1013    /// Reused host output buffers (stable addresses for CUDA Graph dtoh capture).
1014    output_staging: Vec<F32HostSlot>,
1015    /// Pinned/pageable host staging for fixed-size graph inputs.
1016    input_staging: HashMap<String, F32HostSlot>,
1017    /// Reused event for graph replay completion (avoids full stream sync when possible).
1018    replay_event: Option<cudarc::driver::CudaEvent>,
1019    /// Persistent KV inputs (host mirror + device upload each run).
1020    gpu_handles: HashMap<String, Vec<f32>>,
1021    gpu_handle_feeds: HashMap<String, usize>,
1022    gpu_handle_resident: std::collections::HashSet<String>,
1023    /// When set, only these output indices are read back from device (KV feeds stay on GPU).
1024    pending_read_indices: Option<Vec<usize>>,
1025    /// Reused sorted/deduped output indices for the current run (avoids alloc in `readback_plan`).
1026    readback_plan_buf: Vec<usize>,
1027    /// Output indices baked into the captured CUDA graph (must match on replay).
1028    captured_readback_plan: Option<Vec<usize>>,
1029    /// Graph input names in declaration order (parallel to `input_slots`).
1030    input_slot_names: Vec<String>,
1031    /// Graph inputs in declaration order: `(arena_byte_offset, max_f32_elems)`.
1032    input_slots: Vec<(usize, usize)>,
1033    /// Host readback layout: `(byte_offset_in_host_arena, f32_elems)` per graph output.
1034    output_slots: Vec<(usize, usize)>,
1035    /// Pinned/pageable host mirror for `run_slots` / `arena_ptr` (not GPU arena).
1036    host_arena: Vec<f32>,
1037}
1038
1039impl Step {
1040    /// True when this Step variant honors active-extent dispatch (PLAN L1).
1041    /// Initial coverage: simple element-wise ops + reductions + softmax +
1042    /// LayerNorm + cumsum. Matmul, Attention, Conv, Pool, GroupedMatmul,
1043    /// DequantMatmul, Sample, SelectiveScan, Rope, ScatterAdd, Transpose,
1044    /// Expand, Concat, Narrow, Gather, GatherAxis, Argmax, TopK still
1045    /// default to unsafe — opt in once each Step's per-tier dispatch +
1046    /// kernel offset arithmetic has been verified to scale safely.
1047    pub fn safe_for_active_extent(&self) -> bool {
1048        matches!(
1049            self,
1050            Step::Binary { .. }
1051                | Step::Compare { .. }
1052                | Step::Unary { .. }
1053                | Step::Where { .. }
1054                | Step::Reduce { .. }
1055                | Step::Softmax { .. }
1056                | Step::LayerNorm { .. }
1057                | Step::FusedResidualLn { .. }
1058                | Step::FusedResidualRmsNorm { .. }
1059                | Step::Cumsum { .. }
1060                | Step::FusedBinaryUnary { .. }
1061                | Step::ElementwiseRegion { .. }
1062                | Step::BatchElementwiseRegion { .. }
1063        )
1064    }
1065
1066    /// False when the step performs host-side work or stream sync during dispatch.
1067    pub fn graph_capture_safe(&self) -> bool {
1068        match self {
1069            Step::Im2ColHost { use_gpu, .. } | Step::Fft { use_gpu, .. } => *use_gpu,
1070            Step::GatedDeltaNet { .. }
1071            | Step::Llada2GroupLimitedGate { .. }
1072            | Step::UmapKnn { .. }
1073            | Step::LogMelHost { .. }
1074            | Step::LogMelBackwardHost { .. }
1075            | Step::WelchPeaksHost { .. }
1076            | Step::GaussianSplatRender { .. }
1077            | Step::GaussianSplatRenderBackward { .. }
1078            | Step::GaussianSplatPrepare { .. } => false,
1079            _ => true,
1080        }
1081    }
1082}
1083
1084fn schedule_graph_capture_safe(schedule: &[Step]) -> bool {
1085    schedule.iter().all(Step::graph_capture_safe)
1086}
1087
1088fn step_is_tail_host(step: &Step) -> bool {
1089    matches!(
1090        step,
1091        Step::LogMelHost { .. } | Step::LogMelBackwardHost { .. } | Step::WelchPeaksHost { .. }
1092    )
1093}
1094
1095fn run_tail_host_audio_ops(
1096    schedule: &[Step],
1097    stream: &Arc<cudarc::driver::CudaStream>,
1098    buffer: &mut cudarc::driver::CudaSlice<f32>,
1099    pre_sync: bool,
1100) {
1101    if !schedule.iter().any(step_is_tail_host) {
1102        return;
1103    }
1104    if pre_sync {
1105        stream
1106            .synchronize()
1107            .expect("rlx-cuda: tail host pre-sync failed");
1108    }
1109    for step in schedule {
1110        match step {
1111            Step::LogMelHost {
1112                spec_byte_off,
1113                filt_byte_off,
1114                dst_byte_off,
1115                outer,
1116                n_fft,
1117                n_bins,
1118                n_mels,
1119            } => {
1120                crate::log_mel_host::run_log_mel(
1121                    stream,
1122                    buffer,
1123                    *spec_byte_off as usize,
1124                    *filt_byte_off as usize,
1125                    *dst_byte_off as usize,
1126                    *outer as usize,
1127                    *n_fft as usize,
1128                    *n_bins as usize,
1129                    *n_mels as usize,
1130                    false,
1131                );
1132            }
1133            Step::LogMelBackwardHost {
1134                spec_byte_off,
1135                filt_byte_off,
1136                dy_byte_off,
1137                dst_byte_off,
1138                outer,
1139                n_fft,
1140                n_bins,
1141                n_mels,
1142            } => {
1143                crate::log_mel_backward_host::run_log_mel_backward(
1144                    stream,
1145                    buffer,
1146                    *spec_byte_off as usize,
1147                    *filt_byte_off as usize,
1148                    *dy_byte_off as usize,
1149                    *dst_byte_off as usize,
1150                    *outer as usize,
1151                    *n_fft as usize,
1152                    *n_bins as usize,
1153                    *n_mels as usize,
1154                    false,
1155                );
1156            }
1157            Step::WelchPeaksHost {
1158                spec_byte_off,
1159                dst_byte_off,
1160                welch_batch,
1161                n_fft,
1162                n_segments,
1163                k,
1164            } => {
1165                crate::welch_peaks_host::run_welch_peaks(
1166                    stream,
1167                    buffer,
1168                    *spec_byte_off as usize,
1169                    *dst_byte_off as usize,
1170                    *welch_batch as usize,
1171                    *n_fft as usize,
1172                    *n_segments as usize,
1173                    *k as usize,
1174                    false,
1175                );
1176            }
1177            _ => {}
1178        }
1179    }
1180}
1181
1182fn schedule_needs_blas_lt(schedule: &[Step]) -> bool {
1183    schedule.iter().any(|s| {
1184        matches!(
1185            s,
1186            Step::Matmul { act_id, .. } if cublaslt_act_supported(*act_id)
1187        )
1188    })
1189}
1190
1191fn schedule_needs_dnn(schedule: &[Step]) -> bool {
1192    schedule.iter().any(|s| {
1193        matches!(
1194            s,
1195            Step::Conv1d { .. } | Step::Conv2d { .. } | Step::Conv3d { .. }
1196        )
1197    })
1198}
1199
1200/// Map our internal activation id (matches the `unary` kernel table)
1201/// to a cuBLASLt epilogue activation, if it's natively fusable.
1202/// cuBLASLt only supports Relu and Gelu in the epilogue — anything else
1203/// (sigmoid, tanh, silu, abs, neg, sqrt) returns None and the caller
1204/// falls back to plain sgemm + the matmul_epilogue kernel.
1205fn cublaslt_act_for(act_id: u32) -> Option<cublaslt_sys::cublasLtEpilogue_t> {
1206    None.or(match act_id {
1207        // Identity
1208        0xFFFFu32 => Some(None),
1209        // Relu = 0; Gelu = 9; GeluApprox = 11 (treat as Gelu).
1210        0 => Some(Some(
1211            cublaslt_sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU,
1212        )),
1213        9 | 11 => Some(Some(
1214            cublaslt_sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU,
1215        )),
1216        _ => Some(None),
1217    })
1218    .flatten()
1219}
1220
/// Whether `act_id` can ride along in cuBLASLt's epilogue: the `0xFFFF`
/// "no activation" sentinel, or one of the ids 0 (Relu), 9 (Gelu) and
/// 11 (GeluApprox).
fn cublaslt_act_supported(act_id: u32) -> bool {
    const FUSABLE_ACT_IDS: [u32; 4] = [0xFFFF, 0, 9, 11];
    FUSABLE_ACT_IDS.contains(&act_id)
}
1225
1226/// Single cuBLASLt fused matmul. Consumes one descriptor + three matrix
1227/// layouts + one preference object per call (descriptors are cheap to
1228/// create; future optimization could cache them by shape). Returns
1229/// `Err` on any setup failure so the caller can fall back to plain
1230/// cuBLAS sgemm + epilogue kernel.
1231unsafe fn cublaslt_matmul_fused(
1232    handle: cublaslt_sys::cublasLtHandle_t,
1233    workspace_dev_ptr: u64,
1234    workspace_size: usize,
1235    arena_dev_ptr: u64,
1236    m: u32,
1237    k: u32,
1238    n: u32,
1239    a_off_f32: u32,
1240    b_off_f32: u32,
1241    c_off_f32: u32,
1242    has_bias: bool,
1243    bias_off_f32: u32,
1244    epilogue_act: Option<cublaslt_sys::cublasLtEpilogue_t>,
1245    batch: u32,
1246    a_batch_stride: u32,
1247    b_batch_stride: u32,
1248    c_batch_stride: u32,
1249    cu_stream: cudarc::driver::sys::CUstream,
1250) -> Result<(), cublaslt_result::CublasError> {
1251    use core::ffi::c_void;
1252    use core::mem;
1253
1254    // cuBLASLt is column-major. We swap A↔B so that "computing C^T =
1255    // B^T·A^T in column-major" matches "C = A·B in row-major".
1256    let a_ptr = (arena_dev_ptr + (b_off_f32 as u64) * 4) as *const c_void; // = our B
1257    let b_ptr = (arena_dev_ptr + (a_off_f32 as u64) * 4) as *const c_void; // = our A
1258    let c_ptr = (arena_dev_ptr + (c_off_f32 as u64) * 4) as *const c_void;
1259    let d_ptr = c_ptr as *mut c_void;
1260
1261    let dt = cublaslt_sys::cudaDataType_t::CUDA_R_32F;
1262
1263    // Layouts. After A↔B swap: cuBLASLt sees a [n,k] · [k,m] = [n,m].
1264    let a_layout = cublaslt_result::create_matrix_layout(dt, n as u64, k as u64, n as i64)?;
1265    let b_layout = cublaslt_result::create_matrix_layout(dt, k as u64, m as u64, k as i64)?;
1266    let c_layout = cublaslt_result::create_matrix_layout(dt, n as u64, m as u64, n as i64)?;
1267
1268    if batch > 1 {
1269        unsafe {
1270            let bsz = batch as i32;
1271            for &layout in &[a_layout, b_layout, c_layout] {
1272                cublaslt_result::set_matrix_layout_attribute(
1273                layout,
1274                cublaslt_sys::cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
1275                &bsz as *const _ as *const _,
1276                mem::size_of::<i32>(),
1277            )?;
1278            }
1279            let stride_b = b_batch_stride as i64;
1280            let stride_a = a_batch_stride as i64;
1281            let stride_c = c_batch_stride as i64;
1282            cublaslt_result::set_matrix_layout_attribute(
1283            a_layout,
1284            cublaslt_sys::cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
1285            &stride_b as *const _ as *const _, mem::size_of::<i64>())?;
1286            cublaslt_result::set_matrix_layout_attribute(
1287            b_layout,
1288            cublaslt_sys::cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
1289            &stride_a as *const _ as *const _, mem::size_of::<i64>())?;
1290            cublaslt_result::set_matrix_layout_attribute(
1291            c_layout,
1292            cublaslt_sys::cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
1293            &stride_c as *const _ as *const _, mem::size_of::<i64>())?;
1294        }
1295    }
1296
1297    // CUBLAS_COMPUTE_32F_FAST_TF32 enables Tensor-Core paths on Ampere+.
1298    // Set RLX_CUDA_NO_TF32=1 (or RLX_CUDA_PARITY=1) for strict f32 parity
1299    // vs CPU / wgpu reference paths.
1300    let compute_type =
1301        if rlx_ir::env::flag("RLX_CUDA_NO_TF32") || rlx_ir::env::flag("RLX_CUDA_PARITY") {
1302            cublaslt_sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
1303        } else {
1304            cublaslt_sys::cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_TF32
1305        };
1306    let matmul_desc = cublaslt_result::create_matmul_desc(compute_type, dt)?;
1307
1308    // Pick the epilogue mode. cuBLASLt fuses bias broadcast over the
1309    // M dimension (in cuBLASLt's view). With our A↔B swap, cuBLASLt's
1310    // M = our row-major N, so a bias[N] vector broadcasts across M
1311    // rows of row-major C — exactly what we want.
1312    let epilogue = match (has_bias, epilogue_act) {
1313        (true, Some(cublaslt_sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU)) => {
1314            cublaslt_sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU_BIAS
1315        }
1316        (true, Some(cublaslt_sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU)) => {
1317            cublaslt_sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU_BIAS
1318        }
1319        (true, None) => cublaslt_sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_BIAS,
1320        (false, Some(act)) => act,
1321        (false, None) => cublaslt_sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT,
1322        _ => cublaslt_sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT,
1323    };
1324    unsafe {
1325        cublaslt_result::set_matmul_desc_attribute(
1326            matmul_desc,
1327            cublaslt_sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_EPILOGUE,
1328            &epilogue as *const _ as *const _,
1329            mem::size_of::<cublaslt_sys::cublasLtEpilogue_t>(),
1330        )?;
1331    }
1332
1333    if has_bias {
1334        let bias_dev_ptr = arena_dev_ptr + (bias_off_f32 as u64) * 4;
1335        unsafe {
1336            cublaslt_result::set_matmul_desc_attribute(
1337                matmul_desc,
1338                cublaslt_sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_BIAS_POINTER,
1339                &bias_dev_ptr as *const _ as *const _,
1340                mem::size_of::<u64>(),
1341            )?;
1342        }
1343    }
1344
1345    let matmul_pref = cublaslt_result::create_matmul_pref()?;
1346    unsafe {
1347        cublaslt_result::set_matmul_pref_attribute(
1348            matmul_pref,
1349            cublaslt_sys::cublasLtMatmulPreferenceAttributes_t::CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
1350            &workspace_size as *const _ as *const _,
1351            mem::size_of::<usize>(),
1352        )?;
1353    }
1354
1355    let heuristic = unsafe {
1356        cublaslt_result::get_matmul_algo_heuristic(
1357            handle,
1358            matmul_desc,
1359            a_layout,
1360            b_layout,
1361            c_layout,
1362            c_layout,
1363            matmul_pref,
1364        )
1365    }?;
1366
1367    let alpha = 1.0_f32;
1368    let beta = 0.0_f32;
1369    let workspace_ptr = workspace_dev_ptr as *mut c_void;
1370
1371    let result = unsafe {
1372        cublaslt_result::matmul(
1373            handle,
1374            matmul_desc,
1375            &alpha as *const _ as *const c_void,
1376            &beta as *const _ as *const c_void,
1377            a_ptr,
1378            a_layout,
1379            b_ptr,
1380            b_layout,
1381            c_ptr,
1382            c_layout,
1383            d_ptr,
1384            c_layout,
1385            &heuristic.algo as *const _,
1386            workspace_ptr,
1387            workspace_size,
1388            cu_stream as cublaslt_sys::cudaStream_t,
1389        )
1390    };
1391
1392    // Always destroy descriptors (success or fail).
1393    unsafe {
1394        let _ = cublaslt_result::destroy_matmul_pref(matmul_pref);
1395        let _ = cublaslt_result::destroy_matmul_desc(matmul_desc);
1396        let _ = cublaslt_result::destroy_matrix_layout(c_layout);
1397        let _ = cublaslt_result::destroy_matrix_layout(b_layout);
1398        let _ = cublaslt_result::destroy_matrix_layout(a_layout);
1399    }
1400
1401    result
1402}
1403
1404/// cuDNN forward 2D convolution against arena offsets. NCHW input,
1405/// KCRS filter, NCHW output. Uses the v7 algorithm heuristic to pick
1406/// the fastest algo that fits in the supplied workspace. Returns
1407/// `Err` on any setup failure so the caller can fall back to the
1408/// direct-convolution kernel.
1409unsafe fn cudnn_conv2d_forward(
1410    handle: cudnn_sys::cudnnHandle_t,
1411    workspace_dev_ptr: u64,
1412    workspace_size: usize,
1413    arena_dev_ptr: u64,
1414    n: u32,
1415    c_in: u32,
1416    c_out: u32,
1417    h: u32,
1418    w: u32,
1419    h_out: u32,
1420    w_out: u32,
1421    kh: u32,
1422    kw: u32,
1423    sh: u32,
1424    sw: u32,
1425    ph: u32,
1426    pw: u32,
1427    dh: u32,
1428    dw: u32,
1429    groups: u32,
1430    in_off_f32: u32,
1431    w_off_f32: u32,
1432    out_off_f32: u32,
1433) -> Result<(), cudnn_result::CudnnError> {
1434    use core::ffi::c_void;
1435
1436    let dt = cudnn_sys::cudnnDataType_t::CUDNN_DATA_FLOAT;
1437    let fmt = cudnn_sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW;
1438
1439    let x_desc = cudnn_result::create_tensor_descriptor()?;
1440    let y_desc = cudnn_result::create_tensor_descriptor()?;
1441    let conv_desc = cudnn_result::create_convolution_descriptor()?;
1442
1443    let w_desc = unsafe {
1444        let mut w_desc_uninit = std::mem::MaybeUninit::uninit();
1445        cudnn_sys::cudnnCreateFilterDescriptor(w_desc_uninit.as_mut_ptr()).result()?;
1446        w_desc_uninit.assume_init()
1447    };
1448
1449    let setup = unsafe {
1450        cudnn_result::set_tensor4d_descriptor(
1451            x_desc,
1452            fmt,
1453            dt,
1454            [n as i32, c_in as i32, h as i32, w as i32],
1455        )?;
1456        cudnn_result::set_tensor4d_descriptor(
1457            y_desc,
1458            fmt,
1459            dt,
1460            [n as i32, c_out as i32, h_out as i32, w_out as i32],
1461        )?;
1462        cudnn_result::set_filter4d_descriptor(
1463            w_desc,
1464            dt,
1465            fmt,
1466            [
1467                c_out as i32,
1468                (c_in / groups.max(1)) as i32,
1469                kh as i32,
1470                kw as i32,
1471            ],
1472        )?;
1473        cudnn_result::set_convolution2d_descriptor(
1474            conv_desc,
1475            ph as i32,
1476            pw as i32,
1477            sh as i32,
1478            sw as i32,
1479            dh as i32,
1480            dw as i32,
1481            cudnn_sys::cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION,
1482            dt,
1483        )?;
1484        if groups > 1 {
1485            cudnn_sys::cudnnSetConvolutionGroupCount(conv_desc, groups as i32).result()?;
1486        }
1487        Ok::<(), cudnn_result::CudnnError>(())
1488    };
1489
1490    let result = setup.and_then(|()| unsafe {
1491        // Pick the fastest fwd algo via the v7 heuristic.
1492        let mut returned_count: i32 = 0;
1493        let mut perf = std::mem::MaybeUninit::<cudnn_sys::cudnnConvolutionFwdAlgoPerf_t>::uninit();
1494        cudnn_result::get_convolution_forward_algorithm(
1495            handle,
1496            x_desc,
1497            w_desc,
1498            conv_desc,
1499            y_desc,
1500            1,
1501            &mut returned_count,
1502            perf.as_mut_ptr(),
1503        )?;
1504        if returned_count == 0 {
1505            return Err(cudnn_result::CudnnError(
1506                cudnn_sys::cudnnStatus_t::CUDNN_STATUS_NOT_SUPPORTED,
1507            ));
1508        }
1509        let algo = perf.assume_init().algo;
1510
1511        let needed = cudnn_result::get_convolution_forward_workspace_size(
1512            handle, x_desc, w_desc, conv_desc, y_desc, algo,
1513        )?;
1514        if needed > workspace_size {
1515            return Err(cudnn_result::CudnnError(
1516                cudnn_sys::cudnnStatus_t::CUDNN_STATUS_NOT_SUPPORTED,
1517            ));
1518        }
1519
1520        let alpha: f32 = 1.0;
1521        let beta: f32 = 0.0;
1522        let x_ptr = (arena_dev_ptr + (in_off_f32 as u64) * 4) as *const c_void;
1523        let w_ptr = (arena_dev_ptr + (w_off_f32 as u64) * 4) as *const c_void;
1524        let y_ptr = (arena_dev_ptr + (out_off_f32 as u64) * 4) as *mut c_void;
1525        let workspace_ptr = workspace_dev_ptr as *mut c_void;
1526
1527        cudnn_result::convolution_forward(
1528            handle,
1529            &alpha as *const _ as *const c_void,
1530            x_desc,
1531            x_ptr,
1532            w_desc,
1533            w_ptr,
1534            conv_desc,
1535            algo,
1536            workspace_ptr,
1537            workspace_size,
1538            &beta as *const _ as *const c_void,
1539            y_desc,
1540            y_ptr,
1541        )
1542    });
1543
1544    unsafe {
1545        let _ = cudnn_result::destroy_convolution_descriptor(conv_desc);
1546        let _ = cudnn_result::destroy_filter_descriptor(w_desc);
1547        let _ = cudnn_result::destroy_tensor_descriptor(y_desc);
1548        let _ = cudnn_result::destroy_tensor_descriptor(x_desc);
1549    }
1550
1551    result
1552}
1553
1554/// cuDNN forward 3-D convolution. NCDHW input, KCDRS filter, NCDHW
1555/// output. Uses cuDNN's nd-descriptor APIs (set_tensornd / set_filternd
1556/// / set_convolutionnd) since the 4D versions only cover up to 2D conv.
1557unsafe fn cudnn_conv3d_forward(
1558    handle: cudnn_sys::cudnnHandle_t,
1559    workspace_dev_ptr: u64,
1560    workspace_size: usize,
1561    arena_dev_ptr: u64,
1562    n: u32,
1563    c_in: u32,
1564    c_out: u32,
1565    d: u32,
1566    h: u32,
1567    w: u32,
1568    d_out: u32,
1569    h_out: u32,
1570    w_out: u32,
1571    kd: u32,
1572    kh: u32,
1573    kw: u32,
1574    sd: u32,
1575    sh: u32,
1576    sw: u32,
1577    pd: u32,
1578    ph: u32,
1579    pw: u32,
1580    dd: u32,
1581    dh: u32,
1582    dw: u32,
1583    groups: u32,
1584    in_off_f32: u32,
1585    w_off_f32: u32,
1586    out_off_f32: u32,
1587) -> Result<(), cudnn_result::CudnnError> {
1588    use core::ffi::c_void;
1589
1590    let dt = cudnn_sys::cudnnDataType_t::CUDNN_DATA_FLOAT;
1591    let fmt = cudnn_sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW;
1592
1593    let x_desc = cudnn_result::create_tensor_descriptor()?;
1594    let y_desc = cudnn_result::create_tensor_descriptor()?;
1595    let conv_desc = cudnn_result::create_convolution_descriptor()?;
1596    let w_desc = unsafe {
1597        let mut w_desc_uninit = std::mem::MaybeUninit::uninit();
1598        cudnn_sys::cudnnCreateFilterDescriptor(w_desc_uninit.as_mut_ptr()).result()?;
1599        w_desc_uninit.assume_init()
1600    };
1601
1602    // 5-D tensor: [N, C, D, H, W] with row-major contiguous strides.
1603    let x_dims: [i32; 5] = [n as i32, c_in as i32, d as i32, h as i32, w as i32];
1604    let x_strides: [i32; 5] = [
1605        (c_in * d * h * w) as i32,
1606        (d * h * w) as i32,
1607        (h * w) as i32,
1608        w as i32,
1609        1,
1610    ];
1611    let y_dims: [i32; 5] = [
1612        n as i32,
1613        c_out as i32,
1614        d_out as i32,
1615        h_out as i32,
1616        w_out as i32,
1617    ];
1618    let y_strides: [i32; 5] = [
1619        (c_out * d_out * h_out * w_out) as i32,
1620        (d_out * h_out * w_out) as i32,
1621        (h_out * w_out) as i32,
1622        w_out as i32,
1623        1,
1624    ];
1625    let f_dims: [i32; 5] = [
1626        c_out as i32,
1627        (c_in / groups.max(1)) as i32,
1628        kd as i32,
1629        kh as i32,
1630        kw as i32,
1631    ];
1632    let pads: [i32; 3] = [pd as i32, ph as i32, pw as i32];
1633    let strides: [i32; 3] = [sd as i32, sh as i32, sw as i32];
1634    let dilations: [i32; 3] = [dd as i32, dh as i32, dw as i32];
1635
1636    let setup = unsafe {
1637        cudnn_result::set_tensornd_descriptor(x_desc, dt, 5, x_dims.as_ptr(), x_strides.as_ptr())?;
1638        cudnn_result::set_tensornd_descriptor(y_desc, dt, 5, y_dims.as_ptr(), y_strides.as_ptr())?;
1639        cudnn_result::set_filternd_descriptor(w_desc, dt, fmt, 5, f_dims.as_ptr())?;
1640        cudnn_result::set_convolutionnd_descriptor(
1641            conv_desc,
1642            3,
1643            pads.as_ptr(),
1644            strides.as_ptr(),
1645            dilations.as_ptr(),
1646            cudnn_sys::cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION,
1647            dt,
1648        )?;
1649        if groups > 1 {
1650            cudnn_sys::cudnnSetConvolutionGroupCount(conv_desc, groups as i32).result()?;
1651        }
1652        Ok::<(), cudnn_result::CudnnError>(())
1653    };
1654
1655    let result = setup.and_then(|()| unsafe {
1656        let mut returned_count: i32 = 0;
1657        let mut perf = std::mem::MaybeUninit::<cudnn_sys::cudnnConvolutionFwdAlgoPerf_t>::uninit();
1658        cudnn_result::get_convolution_forward_algorithm(
1659            handle,
1660            x_desc,
1661            w_desc,
1662            conv_desc,
1663            y_desc,
1664            1,
1665            &mut returned_count,
1666            perf.as_mut_ptr(),
1667        )?;
1668        if returned_count == 0 {
1669            return Err(cudnn_result::CudnnError(
1670                cudnn_sys::cudnnStatus_t::CUDNN_STATUS_NOT_SUPPORTED,
1671            ));
1672        }
1673        let algo = perf.assume_init().algo;
1674
1675        let needed = cudnn_result::get_convolution_forward_workspace_size(
1676            handle, x_desc, w_desc, conv_desc, y_desc, algo,
1677        )?;
1678        if needed > workspace_size {
1679            return Err(cudnn_result::CudnnError(
1680                cudnn_sys::cudnnStatus_t::CUDNN_STATUS_NOT_SUPPORTED,
1681            ));
1682        }
1683
1684        let alpha: f32 = 1.0;
1685        let beta: f32 = 0.0;
1686        let x_ptr = (arena_dev_ptr + (in_off_f32 as u64) * 4) as *const c_void;
1687        let w_ptr = (arena_dev_ptr + (w_off_f32 as u64) * 4) as *const c_void;
1688        let y_ptr = (arena_dev_ptr + (out_off_f32 as u64) * 4) as *mut c_void;
1689        let workspace_ptr = workspace_dev_ptr as *mut c_void;
1690
1691        cudnn_result::convolution_forward(
1692            handle,
1693            &alpha as *const _ as *const c_void,
1694            x_desc,
1695            x_ptr,
1696            w_desc,
1697            w_ptr,
1698            conv_desc,
1699            algo,
1700            workspace_ptr,
1701            workspace_size,
1702            &beta as *const _ as *const c_void,
1703            y_desc,
1704            y_ptr,
1705        )
1706    });
1707
1708    unsafe {
1709        let _ = cudnn_result::destroy_convolution_descriptor(conv_desc);
1710        let _ = cudnn_result::destroy_filter_descriptor(w_desc);
1711        let _ = cudnn_result::destroy_tensor_descriptor(y_desc);
1712        let _ = cudnn_result::destroy_tensor_descriptor(x_desc);
1713    }
1714
1715    result
1716}
1717
1718/// Decode a Matmul/FusedMatMulBiasAct node's input shapes into the
1719/// (m, k, n, batch, a_stride, b_stride, c_stride, a_id, b_id) tuple
1720/// the kernel expects. Three patterns:
1721///   • 2D × 2D                       → batch=1, all strides 0
1722///   • [..,M,K] × [K,N] (broadcast)  → batch=1, leading dims flattened into M
1723///   • [..,M,K] × [..,K,N] (matched) → batch=prod(leading), per-batch strides
1724fn matmul_shape(
1725    graph: &Graph,
1726    node: &rlx_ir::Node,
1727    op_label: &str,
1728) -> (u32, u32, u32, u32, u32, u32, u32, NodeId, NodeId) {
1729    let a_id = node.inputs[0];
1730    let b_id = node.inputs[1];
1731    let a_shape = graph.node(a_id).shape.dims();
1732    let b_shape = graph.node(b_id).shape.dims();
1733    let out_shape = node.shape.dims();
1734    if a_shape.len() == 2 && b_shape.len() == 2 && out_shape.len() == 2 {
1735        let m = a_shape[0].unwrap_static() as u32;
1736        let k = a_shape[1].unwrap_static() as u32;
1737        let n = b_shape[1].unwrap_static() as u32;
1738        (m, k, n, 1, 0, 0, 0, a_id, b_id)
1739    } else if a_shape.len() >= 2 && b_shape.len() == 2 && out_shape.len() == a_shape.len() {
1740        let leading: usize = a_shape[..a_shape.len() - 2]
1741            .iter()
1742            .map(|d| d.unwrap_static())
1743            .product();
1744        let m_inner = a_shape[a_shape.len() - 2].unwrap_static();
1745        let k_inner = a_shape[a_shape.len() - 1].unwrap_static();
1746        let n_inner = b_shape[1].unwrap_static();
1747        (
1748            (leading * m_inner) as u32,
1749            k_inner as u32,
1750            n_inner as u32,
1751            1,
1752            0,
1753            0,
1754            0,
1755            a_id,
1756            b_id,
1757        )
1758    } else if a_shape.len() == b_shape.len() && a_shape.len() >= 3 {
1759        let leading_a: Vec<usize> = a_shape[..a_shape.len() - 2]
1760            .iter()
1761            .map(|d| d.unwrap_static())
1762            .collect();
1763        let leading_b: Vec<usize> = b_shape[..b_shape.len() - 2]
1764            .iter()
1765            .map(|d| d.unwrap_static())
1766            .collect();
1767        if leading_a != leading_b {
1768            panic!(
1769                "rlx-cuda {op_label}: batched shape mismatch \
1770                    a_leading={leading_a:?} b_leading={leading_b:?}"
1771            );
1772        }
1773        let b_count: usize = leading_a.iter().product();
1774        let m_inner = a_shape[a_shape.len() - 2].unwrap_static();
1775        let k_inner = a_shape[a_shape.len() - 1].unwrap_static();
1776        let n_inner = b_shape[b_shape.len() - 1].unwrap_static();
1777        (
1778            m_inner as u32,
1779            k_inner as u32,
1780            n_inner as u32,
1781            b_count as u32,
1782            (m_inner * k_inner) as u32,
1783            (k_inner * n_inner) as u32,
1784            (m_inner * n_inner) as u32,
1785            a_id,
1786            b_id,
1787        )
1788    } else {
1789        panic!(
1790            "rlx-cuda {op_label}: unsupported shapes a={a_shape:?} b={b_shape:?} out={out_shape:?}"
1791        );
1792    }
1793}
1794
1795fn binary_op_id(op: BinaryOp) -> u32 {
1796    match op {
1797        BinaryOp::Add => 0,
1798        BinaryOp::Sub => 1,
1799        BinaryOp::Mul => 2,
1800        BinaryOp::Div => 3,
1801        BinaryOp::Max => 4,
1802        BinaryOp::Min => 5,
1803        BinaryOp::Pow => 6,
1804    }
1805}
1806
1807fn compare_op_id(op: CmpOp) -> u32 {
1808    match op {
1809        CmpOp::Eq => 0,
1810        CmpOp::Ne => 1,
1811        CmpOp::Lt => 2,
1812        CmpOp::Le => 3,
1813        CmpOp::Gt => 4,
1814        CmpOp::Ge => 5,
1815    }
1816}
1817
1818fn reduce_op_id(op: ReduceOp) -> u32 {
1819    match op {
1820        ReduceOp::Sum => 0,
1821        ReduceOp::Mean => 1,
1822        ReduceOp::Max => 2,
1823        ReduceOp::Min => 3,
1824        ReduceOp::Prod => 4,
1825    }
1826}
1827
1828fn activation_op_id(act: Activation) -> u32 {
1829    match act {
1830        Activation::Relu => 0,
1831        Activation::Sigmoid => 1,
1832        Activation::Tanh => 2,
1833        Activation::Exp => 3,
1834        Activation::Log => 4,
1835        Activation::Sqrt => 5,
1836        Activation::Rsqrt => 6,
1837        Activation::Neg => 7,
1838        Activation::Abs => 8,
1839        Activation::Gelu => 9,
1840        Activation::Silu => 10,
1841        Activation::GeluApprox => 11,
1842        Activation::Round => 12,
1843        Activation::Sin => 13,
1844        Activation::Cos => 14,
1845        Activation::Tan => 15,
1846        Activation::Atan => 16,
1847    }
1848}
1849
1850/// Mixed-precision matmul tier-0: when the weight (B input) is stored
1851/// in the half-arena, cast f32 activations to f16/bf16 in the scratch
1852/// buffer and run `cublasGemmEx` with both inputs half + f32
1853/// accumulator. Returns `true` on success.
1854///
1855/// Free function (rather than `&mut self` method) so the caller can
1856/// hold `&self.schedule` across the call without violating disjoint-
1857/// field borrow checks.
1858#[allow(clippy::too_many_arguments)]
1859fn try_mixed_precision_gemm(
1860    ctx: &Arc<CudaContext>,
1861    arena: &mut crate::arena::Arena,
1862    half_act_scratch: &mut Option<cudarc::driver::CudaSlice<u16>>,
1863    blas: Option<&Arc<Mutex<CudaBlas>>>,
1864    stream: &Arc<cudarc::driver::CudaStream>,
1865    m: u32,
1866    k: u32,
1867    n: u32,
1868    batch: u32,
1869    a_off_f32: u32,
1870    b_off_f32: u32,
1871    c_off_f32: u32,
1872) -> bool {
1873    let (half_off, half_dtype) = match arena.half_by_f32_off.get(&b_off_f32).copied() {
1874        Some(v) => v,
1875        None => return false,
1876    };
1877    let blas = match blas {
1878        Some(b) => b,
1879        None => return false,
1880    };
1881
1882    let act_elems = (m * k * batch.max(1)) as usize;
1883    let need_resize = half_act_scratch
1884        .as_ref()
1885        .is_none_or(|s| s.len() < act_elems);
1886    if need_resize {
1887        *half_act_scratch = stream.alloc_zeros::<u16>(act_elems.max(4)).ok();
1888    }
1889    if half_act_scratch.is_none() {
1890        return false;
1891    }
1892
1893    // Phase 1: cast activations f32 → f16/bf16 into the scratch.
1894    let n_total = m * k * batch.max(1);
1895    let dtype_id: u32 = match half_dtype {
1896        crate::arena::HalfDtype::F16 => 0,
1897        crate::arena::HalfDtype::Bf16 => 1,
1898    };
1899    {
1900        let kernel = crate::kernels::cast_f32_to_half_kernel(ctx);
1901        let (grid, block) = dispatch_grid_1d(n_total, 256);
1902        let cfg = LaunchConfig {
1903            grid_dim: (grid, 1, 1),
1904            block_dim: (block, 1, 1),
1905            shared_mem_bytes: 0,
1906        };
1907        let src_view = arena
1908            .f32_buf()
1909            .slice(a_off_f32 as usize..a_off_f32 as usize + n_total as usize);
1910        let scratch_mut = half_act_scratch.as_mut().unwrap();
1911        let mut launcher = stream.launch_builder(&kernel.function);
1912        launcher
1913            .arg(&src_view)
1914            .arg(scratch_mut)
1915            .arg(&n_total)
1916            .arg(&dtype_id);
1917        if unsafe { launcher.launch(cfg) }.is_err() {
1918            return false;
1919        }
1920    }
1921
1922    // Phase 2: cublasGemmEx with both inputs half + f32 output.
1923    let blas = blas.lock().unwrap();
1924    let arena_ptr_u64 = {
1925        let (p, _ar) = arena.buffer.device_ptr_mut(stream);
1926        p
1927    };
1928    let (half_buf_ptr, _hb) = arena.half_buffer.as_mut().unwrap().device_ptr_mut(stream);
1929    let scratch_ptr_u64 = {
1930        let s = half_act_scratch.as_mut().unwrap();
1931        let (p, _r) = s.device_ptr_mut(stream);
1932        p
1933    };
1934    let weight_dev = half_buf_ptr + (half_off as u64) * 2; // u16 = 2 bytes
1935    let act_dev = scratch_ptr_u64;
1936    let c_dev = arena_ptr_u64 + (c_off_f32 as u64) * 4;
1937    let alpha: f32 = 1.0;
1938    let beta: f32 = 0.0;
1939    let cuda_dt = match half_dtype {
1940        crate::arena::HalfDtype::F16 => cublas_sys::cudaDataType_t::CUDA_R_16F,
1941        crate::arena::HalfDtype::Bf16 => cublas_sys::cudaDataType_t::CUDA_R_16BF,
1942    };
1943    let compute_ty = match half_dtype {
1944        crate::arena::HalfDtype::F16 => {
1945            cublas_sys::cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_16F
1946        }
1947        crate::arena::HalfDtype::Bf16 => {
1948            cublas_sys::cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_16BF
1949        }
1950    };
1951    let result = unsafe {
1952        cudarc::cublas::result::gemm_ex(
1953            *blas.handle(),
1954            cublas_sys::cublasOperation_t::CUBLAS_OP_N,
1955            cublas_sys::cublasOperation_t::CUBLAS_OP_N,
1956            n as i32,
1957            m as i32,
1958            k as i32,
1959            &alpha as *const f32 as *const _,
1960            weight_dev as *const _,
1961            cuda_dt,
1962            n as i32,
1963            act_dev as *const _,
1964            cuda_dt,
1965            k as i32,
1966            &beta as *const f32 as *const _,
1967            c_dev as *mut _,
1968            cublas_sys::cudaDataType_t::CUDA_R_32F,
1969            n as i32,
1970            compute_ty,
1971            cublas_sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT,
1972        )
1973    };
1974    if let Err(ref e) = result {
1975        log_fallback("matmul.gemmEx (mixed-precision)", e);
1976    }
1977    result.is_ok()
1978}
1979
1980/// One-time-per-tier log when a fast-path dispatch silently falls
1981/// back. Helps cloud-GPU debugging see *why* the slow path took over —
1982/// otherwise the only signal is unexpectedly low throughput.
1983/// Gated behind `RLX_CUDA_LOG_FALLBACK=1` so production isn't spammed.
1984fn log_fallback(tier: &str, err: impl std::fmt::Debug) {
1985    use std::sync::OnceLock;
1986    static ENABLED: OnceLock<bool> = OnceLock::new();
1987    let enabled = *ENABLED.get_or_init(|| {
1988        rlx_ir::env::var("RLX_CUDA_LOG_FALLBACK")
1989            .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
1990            .unwrap_or(false)
1991    });
1992    if enabled {
1993        eprintln!("rlx-cuda: tier '{tier}' fell back: {err:?}");
1994    }
1995}
1996
1997/// Stable, profiler-friendly name for an NVTX range covering a Step
1998/// dispatch. Matches the variant name; nsight-systems / nvprof show
1999/// these as range boundaries in the timeline.
2000fn fft_dtype_tag(dtype: rlx_ir::DType) -> u32 {
2001    match dtype {
2002        rlx_ir::DType::F32 => 0,
2003        rlx_ir::DType::F64 => 1,
2004        rlx_ir::DType::C64 => 2,
2005        other => panic!("rlx-cuda Op::Fft: unsupported dtype {other:?}"),
2006    }
2007}
2008
2009fn fft_dtype_from_tag(tag: u32) -> rlx_ir::DType {
2010    match tag {
2011        0 => rlx_ir::DType::F32,
2012        1 => rlx_ir::DType::F64,
2013        2 => rlx_ir::DType::C64,
2014        other => panic!("rlx-cuda Op::Fft: bad dtype tag {other}"),
2015    }
2016}
2017
/// Stable, profiler-friendly name for a `Step`, intended as the label of
/// the NVTX range covering that step's dispatch (so nsight-systems / nvprof
/// can show it as a range boundary in the timeline).
///
/// Each name is `rlx::` followed by the variant name, with three
/// exceptions: `FusedResidualLn` maps to `rlx::FusedResidualLN`, and the
/// two scatter-add phases map to `rlx::ScatterAdd::zero` /
/// `rlx::ScatterAdd::acc`.
///
/// The match has no wildcard arm, so adding a `Step` variant without a
/// name here is a compile error.
fn step_name(step: &Step) -> &'static str {
    match step {
        Step::Matmul { .. } => "rlx::Matmul",
        Step::Binary { .. } => "rlx::Binary",
        Step::Compare { .. } => "rlx::Compare",
        Step::Unary { .. } => "rlx::Unary",
        Step::Where { .. } => "rlx::Where",
        Step::Reduce { .. } => "rlx::Reduce",
        Step::Softmax { .. } => "rlx::Softmax",
        Step::LayerNorm { .. } => "rlx::LayerNorm",
        Step::FusedResidualLn { .. } => "rlx::FusedResidualLN",
        Step::FusedResidualRmsNorm { .. } => "rlx::FusedResidualRmsNorm",
        Step::Gather { .. } => "rlx::Gather",
        Step::GatherAxis { .. } => "rlx::GatherAxis",
        Step::Narrow { .. } => "rlx::Narrow",
        Step::Concat { .. } => "rlx::Concat",
        Step::Transpose { .. } => "rlx::Transpose",
        Step::Expand { .. } => "rlx::Expand",
        Step::Argmax { .. } => "rlx::Argmax",
        Step::Attention { .. } => "rlx::Attention",
        Step::AttentionBackward { .. } => "rlx::AttentionBackward",
        Step::Rope { .. } => "rlx::Rope",
        Step::Cumsum { .. } => "rlx::Cumsum",
        Step::TopK { .. } => "rlx::TopK",
        Step::GroupedMatmul { .. } => "rlx::GroupedMatmul",
        Step::ScatterAddZero { .. } => "rlx::ScatterAdd::zero",
        Step::ScatterAddAcc { .. } => "rlx::ScatterAdd::acc",
        Step::DequantMatmul { .. } => "rlx::DequantMatmul",
        Step::DequantMatmulGguf { .. } => "rlx::DequantMatmulGguf",
        Step::DequantGroupedMatmulGguf { .. } => "rlx::DequantGroupedMatmulGguf",
        Step::Sample { .. } => "rlx::Sample",
        Step::SelectiveScan { .. } => "rlx::SelectiveScan",
        Step::Fft { .. } => "rlx::Fft",
        Step::LogMelHost { .. } => "rlx::LogMelHost",
        Step::LogMelBackwardHost { .. } => "rlx::LogMelBackwardHost",
        Step::WelchPeaksHost { .. } => "rlx::WelchPeaksHost",
        Step::WelchPeaksGpu { .. } => "rlx::WelchPeaksGpu",
        Step::Im2ColHost { .. } => "rlx::Im2ColHost",
        Step::GatedDeltaNet { .. } => "rlx::GatedDeltaNet",
        Step::Llada2GroupLimitedGate { .. } => "rlx::Llada2GroupLimitedGate",
        Step::UmapKnn { .. } => "rlx::UmapKnn",
        Step::GaussianSplatRender { .. } => "rlx::GaussianSplatRender",
        Step::GaussianSplatRenderBackward { .. } => "rlx::GaussianSplatRenderBackward",
        Step::GaussianSplatPrepare { .. } => "rlx::GaussianSplatPrepare",
        Step::GaussianSplatRasterize { .. } => "rlx::GaussianSplatRasterize",
        Step::RmsNormBackwardInput { .. } => "rlx::RmsNormBackwardInput",
        Step::RmsNormBackwardGamma { .. } => "rlx::RmsNormBackwardGamma",
        Step::RmsNormBackwardBeta { .. } => "rlx::RmsNormBackwardBeta",
        Step::RopeBackward { .. } => "rlx::RopeBackward",
        Step::CumsumBackward { .. } => "rlx::CumsumBackward",
        Step::GatherBackward { .. } => "rlx::GatherBackward",
        Step::MaxPool2dBackward { .. } => "rlx::MaxPool2dBackward",
        Step::Conv2dBackwardInput { .. } => "rlx::Conv2dBackwardInput",
        Step::Conv2dBackwardWeight { .. } => "rlx::Conv2dBackwardWeight",
        Step::Pool1d { .. } => "rlx::Pool1d",
        Step::Pool2d { .. } => "rlx::Pool2d",
        Step::Pool3d { .. } => "rlx::Pool3d",
        Step::Conv1d { .. } => "rlx::Conv1d",
        Step::Conv2d { .. } => "rlx::Conv2d",
        Step::Conv3d { .. } => "rlx::Conv3d",
        Step::LayerNorm2d { .. } => "rlx::LayerNorm2d",
        Step::ConvTranspose2d { .. } => "rlx::ConvTranspose2d",
        Step::GroupNorm { .. } => "rlx::GroupNorm",
        Step::ResizeNearest2x { .. } => "rlx::ResizeNearest2x",
        Step::FusedBinaryUnary { .. } => "rlx::FusedBinaryUnary",
        Step::ElementwiseRegion { .. } => "rlx::ElementwiseRegion",
        Step::BatchElementwiseRegion { .. } => "rlx::BatchElementwiseRegion",
    }
}
2087
2088/// Walk a freshly-built schedule and merge `Binary → Unary` element-wise
2089/// chains into `FusedBinaryUnary`. Conditions for fusion:
2090///   1. The pair has matching element count `n`.
2091///   2. The Unary's input offset == the Binary's output offset.
2092///   3. The intermediate offset has exactly one consumer in the
2093///      schedule (= no other Step reads it). This guarantees we can
2094///      drop the round-trip to global memory for the intermediate
2095///      without breaking any other Step's input.
2096fn fuse_elementwise_chains(schedule: Vec<Step>) -> Vec<Step> {
2097    // Tally consumer counts per offset: how many Steps in the schedule
2098    // read each offset.
2099    let mut consumer_counts: HashMap<u32, usize> = HashMap::new();
2100    for step in &schedule {
2101        let (reads, _) = step_offsets(step);
2102        for r in &reads {
2103            *consumer_counts.entry(*r).or_insert(0) += 1;
2104        }
2105    }
2106
2107    let mut out = Vec::with_capacity(schedule.len());
2108    let mut i = 0;
2109    while i < schedule.len() {
2110        if i + 1 < schedule.len() {
2111            let pair = (&schedule[i], &schedule[i + 1]);
2112            if let (
2113                Step::Binary {
2114                    n,
2115                    a_off,
2116                    b_off,
2117                    c_off,
2118                    op: bin_op,
2119                },
2120                Step::Unary {
2121                    n: n2,
2122                    in_off,
2123                    out_off,
2124                    op: un_op,
2125                },
2126            ) = pair
2127            {
2128                let single_consumer = consumer_counts.get(c_off).copied() == Some(1);
2129                if n == n2 && c_off == in_off && single_consumer {
2130                    out.push(Step::FusedBinaryUnary {
2131                        n: *n,
2132                        a_off: *a_off,
2133                        b_off: *b_off,
2134                        out_off: *out_off,
2135                        bin_op: *bin_op,
2136                        un_op: *un_op,
2137                    });
2138                    i += 2;
2139                    continue;
2140                }
2141            }
2142        }
2143        out.push(schedule[i].clone());
2144        i += 1;
2145    }
2146    out
2147}
2148
2149/// (read offsets, write offsets) for a Step. Used by the multi-stream
2150/// scheduler to decide which streams each step depends on. Offsets are
2151/// the leading f32-element offset of each input/output tensor — a
2152/// coarse approximation that's correct for our planner since each
2153/// node has its own slot (Reshape/Cast aliasing maps consumers to the
2154/// same slot, which is exactly what the dependency tracker wants).
2155fn step_offsets(step: &Step) -> (Vec<u32>, Vec<u32>) {
2156    match step {
2157        Step::Matmul {
2158            a_off_f32,
2159            b_off_f32,
2160            c_off_f32,
2161            has_bias,
2162            bias_off_f32,
2163            ..
2164        } => {
2165            let mut r = vec![*a_off_f32, *b_off_f32];
2166            if *has_bias != 0 {
2167                r.push(*bias_off_f32);
2168            }
2169            (r, vec![*c_off_f32])
2170        }
2171        Step::Binary {
2172            a_off,
2173            b_off,
2174            c_off,
2175            ..
2176        }
2177        | Step::Compare {
2178            a_off,
2179            b_off,
2180            c_off,
2181            ..
2182        } => (vec![*a_off, *b_off], vec![*c_off]),
2183        Step::Unary {
2184            in_off, out_off, ..
2185        } => (vec![*in_off], vec![*out_off]),
2186        Step::Where {
2187            cond_off,
2188            x_off,
2189            y_off,
2190            out_off,
2191            ..
2192        } => (vec![*cond_off, *x_off, *y_off], vec![*out_off]),
2193        Step::Reduce {
2194            in_off, out_off, ..
2195        }
2196        | Step::Softmax {
2197            in_off, out_off, ..
2198        }
2199        | Step::Argmax {
2200            in_off, out_off, ..
2201        }
2202        | Step::Cumsum {
2203            in_off, out_off, ..
2204        }
2205        | Step::Sample {
2206            in_off, out_off, ..
2207        } => (vec![*in_off], vec![*out_off]),
2208        Step::TopK {
2209            in_off, out_off, ..
2210        } => (vec![*in_off], vec![*out_off]),
2211        Step::LayerNorm {
2212            in_off,
2213            gamma_off,
2214            beta_off,
2215            out_off,
2216            ..
2217        } => (vec![*in_off, *gamma_off, *beta_off], vec![*out_off]),
2218        Step::FusedResidualLn {
2219            in_off,
2220            residual_off,
2221            bias_off,
2222            gamma_off,
2223            beta_off,
2224            out_off,
2225            has_bias,
2226            ..
2227        } => {
2228            let mut r = vec![*in_off, *residual_off, *gamma_off, *beta_off];
2229            if *has_bias != 0 {
2230                r.push(*bias_off);
2231            }
2232            (r, vec![*out_off])
2233        }
2234        Step::FusedResidualRmsNorm {
2235            in_off,
2236            residual_off,
2237            bias_off,
2238            gamma_off,
2239            beta_off,
2240            out_off,
2241            has_bias,
2242            ..
2243        } => {
2244            let mut r = vec![*in_off, *residual_off, *gamma_off, *beta_off];
2245            if *has_bias != 0 {
2246                r.push(*bias_off);
2247            }
2248            (r, vec![*out_off])
2249        }
2250        Step::Gather {
2251            in_off,
2252            idx_off,
2253            out_off,
2254            ..
2255        } => (vec![*in_off, *idx_off], vec![*out_off]),
2256        Step::GatherAxis {
2257            table_off,
2258            idx_off,
2259            out_off,
2260            ..
2261        } => (vec![*table_off, *idx_off], vec![*out_off]),
2262        Step::Narrow {
2263            in_off, out_off, ..
2264        }
2265        | Step::Concat {
2266            in_off, out_off, ..
2267        } => (vec![*in_off], vec![*out_off]),
2268        Step::Transpose {
2269            in_off, out_off, ..
2270        }
2271        | Step::Expand {
2272            in_off, out_off, ..
2273        } => (vec![*in_off], vec![*out_off]),
2274        Step::Attention {
2275            q_off,
2276            k_off,
2277            v_off,
2278            mask_off,
2279            mask_kind,
2280            out_off,
2281            ..
2282        } => {
2283            let mut r = vec![*q_off, *k_off, *v_off];
2284            if *mask_kind == 2 || *mask_kind == 4 {
2285                r.push(*mask_off);
2286            }
2287            (r, vec![*out_off])
2288        }
2289        Step::AttentionBackward {
2290            q_off,
2291            k_off,
2292            v_off,
2293            dy_off,
2294            mask_off,
2295            mask_kind,
2296            out_off,
2297            ..
2298        } => {
2299            let mut r = vec![*q_off, *k_off, *v_off, *dy_off];
2300            if *mask_kind == 2 || *mask_kind == 4 {
2301                r.push(*mask_off);
2302            }
2303            (r, vec![*out_off])
2304        }
2305        Step::Rope {
2306            in_off,
2307            cos_off,
2308            sin_off,
2309            out_off,
2310            ..
2311        } => (vec![*in_off, *cos_off, *sin_off], vec![*out_off]),
2312        Step::GroupedMatmul {
2313            in_off,
2314            w_off,
2315            idx_off,
2316            out_off,
2317            ..
2318        } => (vec![*in_off, *w_off, *idx_off], vec![*out_off]),
2319        Step::ScatterAddZero { out_off, .. } => (vec![], vec![*out_off]),
2320        Step::ScatterAddAcc {
2321            upd_off,
2322            idx_off,
2323            out_off,
2324            ..
2325        } =>
2326        // out_off is read-modify-write — list it as both a read and
2327        // a write so the scheduler waits on the prior zero.
2328        {
2329            (vec![*upd_off, *idx_off, *out_off], vec![*out_off])
2330        }
2331        Step::DequantMatmul {
2332            x_off,
2333            w_off,
2334            scale_off,
2335            zp_off,
2336            out_off,
2337            scheme_id,
2338            ..
2339        } => {
2340            let mut r = vec![*x_off, *w_off, *scale_off];
2341            if *scheme_id == 1 {
2342                r.push(*zp_off);
2343            }
2344            (r, vec![*out_off])
2345        }
2346        Step::DequantMatmulGguf {
2347            x_byte_off,
2348            w_byte_off,
2349            out_byte_off,
2350            ..
2351        } => (vec![x_byte_off / 4, w_byte_off / 4], vec![out_byte_off / 4]),
2352        Step::DequantGroupedMatmulGguf {
2353            x_byte_off,
2354            w_byte_off,
2355            idx_byte_off,
2356            out_byte_off,
2357            ..
2358        } => (
2359            vec![x_byte_off / 4, w_byte_off / 4, idx_byte_off / 4],
2360            vec![out_byte_off / 4],
2361        ),
2362        Step::SelectiveScan {
2363            x_off,
2364            delta_off,
2365            a_off,
2366            b_off,
2367            c_off,
2368            out_off,
2369            ..
2370        } => (
2371            vec![*x_off, *delta_off, *a_off, *b_off, *c_off],
2372            vec![*out_off],
2373        ),
2374        Step::Fft {
2375            src_byte_off,
2376            dst_byte_off,
2377            ..
2378        } => (vec![*src_byte_off / 4], vec![*dst_byte_off / 4]),
2379        Step::LogMelHost {
2380            spec_byte_off,
2381            filt_byte_off,
2382            dst_byte_off,
2383            ..
2384        } => (
2385            vec![*spec_byte_off / 4, *filt_byte_off / 4],
2386            vec![*dst_byte_off / 4],
2387        ),
2388        Step::LogMelBackwardHost {
2389            spec_byte_off,
2390            filt_byte_off,
2391            dy_byte_off,
2392            dst_byte_off,
2393            ..
2394        } => (
2395            vec![*spec_byte_off / 4, *filt_byte_off / 4, *dy_byte_off / 4],
2396            vec![*dst_byte_off / 4],
2397        ),
2398        Step::WelchPeaksHost {
2399            spec_byte_off,
2400            dst_byte_off,
2401            ..
2402        } => (vec![*spec_byte_off / 4], vec![*dst_byte_off / 4]),
2403        Step::WelchPeaksGpu {
2404            spec_off, dst_off, ..
2405        } => (vec![*spec_off], vec![*dst_off]),
2406        Step::Im2ColHost {
2407            x_byte_off,
2408            col_byte_off,
2409            ..
2410        } => (vec![*x_byte_off / 4], vec![*col_byte_off / 4]),
2411        Step::GatedDeltaNet {
2412            q_byte_off,
2413            k_byte_off,
2414            v_byte_off,
2415            g_byte_off,
2416            beta_byte_off,
2417            state_byte_off,
2418            dst_byte_off,
2419            use_carry,
2420            ..
2421        } => {
2422            let mut reads = vec![
2423                q_byte_off / 4,
2424                k_byte_off / 4,
2425                v_byte_off / 4,
2426                g_byte_off / 4,
2427                beta_byte_off / 4,
2428            ];
2429            if *use_carry {
2430                reads.push(state_byte_off / 4);
2431            }
2432            let mut writes = vec![dst_byte_off / 4];
2433            if *use_carry {
2434                writes.push(state_byte_off / 4);
2435            }
2436            (reads, writes)
2437        }
2438        Step::Llada2GroupLimitedGate {
2439            sig_off,
2440            route_off,
2441            out_off,
2442            ..
2443        } => (vec![*sig_off, *route_off], vec![*out_off]),
2444        Step::UmapKnn {
2445            pairwise_off,
2446            out_off,
2447            ..
2448        } => (vec![*pairwise_off], vec![*out_off]),
2449        Step::GaussianSplatRender {
2450            positions_off,
2451            positions_len: _,
2452            scales_off,
2453            scales_len: _,
2454            rotations_off,
2455            rotations_len: _,
2456            opacities_off,
2457            opacities_len: _,
2458            colors_off,
2459            colors_len: _,
2460            sh_coeffs_off,
2461            sh_coeffs_len: _,
2462            meta_off,
2463            dst_off,
2464            dst_len: _,
2465            ..
2466        } => (
2467            vec![
2468                positions_off / 4,
2469                scales_off / 4,
2470                rotations_off / 4,
2471                opacities_off / 4,
2472                colors_off / 4,
2473                sh_coeffs_off / 4,
2474                meta_off / 4,
2475            ],
2476            vec![dst_off / 4],
2477        ),
2478        Step::GaussianSplatRenderBackward {
2479            positions_off,
2480            positions_len: _,
2481            scales_off,
2482            scales_len: _,
2483            rotations_off,
2484            rotations_len: _,
2485            opacities_off,
2486            opacities_len: _,
2487            colors_off,
2488            colors_len: _,
2489            sh_coeffs_off,
2490            sh_coeffs_len: _,
2491            meta_off,
2492            d_loss_off,
2493            d_loss_len: _,
2494            packed_off,
2495            packed_len: _,
2496            ..
2497        } => (
2498            vec![
2499                positions_off / 4,
2500                scales_off / 4,
2501                rotations_off / 4,
2502                opacities_off / 4,
2503                colors_off / 4,
2504                sh_coeffs_off / 4,
2505                meta_off / 4,
2506                d_loss_off / 4,
2507            ],
2508            vec![packed_off / 4],
2509        ),
2510        Step::RmsNormBackwardInput {
2511            x_byte_off,
2512            gamma_byte_off,
2513            beta_byte_off,
2514            dy_byte_off,
2515            dx_byte_off,
2516            ..
2517        } => (
2518            vec![
2519                x_byte_off / 4,
2520                gamma_byte_off / 4,
2521                beta_byte_off / 4,
2522                dy_byte_off / 4,
2523            ],
2524            vec![dx_byte_off / 4],
2525        ),
2526        Step::RmsNormBackwardGamma {
2527            x_byte_off,
2528            gamma_byte_off,
2529            beta_byte_off,
2530            dy_byte_off,
2531            dgamma_byte_off,
2532            ..
2533        } => (
2534            vec![
2535                x_byte_off / 4,
2536                gamma_byte_off / 4,
2537                beta_byte_off / 4,
2538                dy_byte_off / 4,
2539            ],
2540            vec![dgamma_byte_off / 4],
2541        ),
2542        Step::RmsNormBackwardBeta {
2543            x_byte_off,
2544            gamma_byte_off,
2545            beta_byte_off,
2546            dy_byte_off,
2547            dbeta_byte_off,
2548            ..
2549        } => (
2550            vec![
2551                x_byte_off / 4,
2552                gamma_byte_off / 4,
2553                beta_byte_off / 4,
2554                dy_byte_off / 4,
2555            ],
2556            vec![dbeta_byte_off / 4],
2557        ),
2558        Step::RopeBackward {
2559            dy_byte_off,
2560            cos_byte_off,
2561            sin_byte_off,
2562            dx_byte_off,
2563            ..
2564        } => (
2565            vec![dy_byte_off / 4, cos_byte_off / 4, sin_byte_off / 4],
2566            vec![dx_byte_off / 4],
2567        ),
2568        Step::CumsumBackward {
2569            dy_byte_off,
2570            dx_byte_off,
2571            ..
2572        } => (vec![dy_byte_off / 4], vec![dx_byte_off / 4]),
2573        Step::GatherBackward {
2574            dy_byte_off,
2575            indices_byte_off,
2576            dst_byte_off,
2577            ..
2578        } => (
2579            vec![dy_byte_off / 4, indices_byte_off / 4],
2580            vec![dst_byte_off / 4],
2581        ),
2582        Step::MaxPool2dBackward {
2583            x_byte_off,
2584            dy_byte_off,
2585            dx_byte_off,
2586            ..
2587        } => (
2588            vec![*x_byte_off / 4, *dy_byte_off / 4],
2589            vec![*dx_byte_off / 4],
2590        ),
2591        Step::Conv2dBackwardInput {
2592            dy_byte_off,
2593            w_byte_off,
2594            dx_byte_off,
2595            ..
2596        } => (
2597            vec![*dy_byte_off / 4, *w_byte_off / 4],
2598            vec![*dx_byte_off / 4],
2599        ),
2600        Step::Conv2dBackwardWeight {
2601            x_byte_off,
2602            dy_byte_off,
2603            dw_byte_off,
2604            ..
2605        } => (
2606            vec![*x_byte_off / 4, *dy_byte_off / 4],
2607            vec![*dw_byte_off / 4],
2608        ),
2609        Step::Pool1d {
2610            in_off, out_off, ..
2611        }
2612        | Step::Pool2d {
2613            in_off, out_off, ..
2614        }
2615        | Step::Pool3d {
2616            in_off, out_off, ..
2617        } => (vec![*in_off], vec![*out_off]),
2618        Step::Conv1d {
2619            in_off,
2620            w_off,
2621            out_off,
2622            ..
2623        }
2624        | Step::Conv2d {
2625            in_off,
2626            w_off,
2627            out_off,
2628            ..
2629        }
2630        | Step::Conv3d {
2631            in_off,
2632            w_off,
2633            out_off,
2634            ..
2635        } => (vec![*in_off, *w_off], vec![*out_off]),
2636        Step::LayerNorm2d {
2637            src_off,
2638            g_off,
2639            b_off,
2640            dst_off,
2641            ..
2642        } => (vec![*src_off, *g_off, *b_off], vec![*dst_off]),
2643        Step::ConvTranspose2d {
2644            src_off,
2645            w_off,
2646            dst_off,
2647            ..
2648        } => (vec![*src_off, *w_off], vec![*dst_off]),
2649        Step::GroupNorm {
2650            src_off,
2651            g_off,
2652            b_off,
2653            dst_off,
2654            ..
2655        } => (vec![*src_off, *g_off, *b_off], vec![*dst_off]),
2656        Step::ResizeNearest2x {
2657            src_off, dst_off, ..
2658        } => (vec![*src_off], vec![*dst_off]),
2659        Step::FusedBinaryUnary {
2660            a_off,
2661            b_off,
2662            out_off,
2663            ..
2664        } => (vec![*a_off, *b_off], vec![*out_off]),
2665        Step::ElementwiseRegion {
2666            dst_off,
2667            input_offs,
2668            num_inputs,
2669            ..
2670        } => {
2671            let n = (*num_inputs as usize).min(input_offs.len());
2672            (input_offs[..n].to_vec(), vec![*dst_off])
2673        }
2674        Step::BatchElementwiseRegion {
2675            base_dst_off,
2676            batch_input_offs,
2677            num_batch,
2678            ..
2679        } => {
2680            let n = (*num_batch as usize).min(64);
2681            (batch_input_offs[..n].to_vec(), vec![*base_dst_off])
2682        }
2683        Step::GaussianSplatPrepare {
2684            positions_off,
2685            scales_off,
2686            rotations_off,
2687            opacities_off,
2688            colors_off,
2689            sh_coeffs_off,
2690            meta_off,
2691            prep_off,
2692            ..
2693        } => (
2694            vec![
2695                positions_off / 4,
2696                scales_off / 4,
2697                rotations_off / 4,
2698                opacities_off / 4,
2699                colors_off / 4,
2700                sh_coeffs_off / 4,
2701                meta_off / 4,
2702            ],
2703            vec![prep_off / 4],
2704        ),
2705        Step::GaussianSplatRasterize {
2706            prep_off,
2707            meta_off,
2708            dst_off,
2709            ..
2710        } => (vec![prep_off / 4, meta_off / 4], vec![dst_off / 4]),
2711    }
2712}
2713
2714/// Pre-compile every NVRTC kernel against `ctx`. Used by AOT mode to
2715/// move JIT compile cost out of the first-run critical path. Runs at
2716/// most once per process — later `CompileMode::Aot` compiles skip it.
2717static AOT_PREWARM_ONCE: Once = Once::new();
2718
2719fn prewarm_all(ctx: &Arc<CudaContext>) {
2720    AOT_PREWARM_ONCE.call_once(|| prewarm_all_kernels(ctx));
2721}
2722
2723fn prewarm_all_kernels(ctx: &Arc<CudaContext>) {
2724    use crate::kernels::*;
2725    let _ = binary_kernel(ctx);
2726    let _ = fused_binary_unary_kernel(ctx);
2727    let _ = unary_kernel(ctx);
2728    let _ = copy_kernel(ctx);
2729    let _ = matmul_kernel(ctx);
2730    let _ = matmul_epilogue_kernel(ctx);
2731    let _ = compare_kernel(ctx);
2732    let _ = where_kernel(ctx);
2733    let _ = reduce_kernel(ctx);
2734    let _ = softmax_kernel(ctx);
2735    let _ = layernorm_kernel(ctx);
2736    let _ = fused_residual_ln_kernel(ctx);
2737    let _ = fused_residual_rms_norm_kernel(ctx);
2738    let _ = gather_kernel(ctx);
2739    let _ = gather_axis_kernel(ctx);
2740    let _ = narrow_kernel(ctx);
2741    let _ = concat_kernel(ctx);
2742    let _ = transpose_kernel(ctx);
2743    let _ = expand_kernel(ctx);
2744    let _ = attention_kernel(ctx);
2745    let _ = attention_row_kernel(ctx);
2746    let _ = attention_bwd_kernel(ctx);
2747    let _ = argmax_kernel(ctx);
2748    let _ = rope_kernel(ctx);
2749    let _ = cumsum_kernel(ctx);
2750    let _ = topk_kernel(ctx);
2751    let _ = grouped_matmul_kernel(ctx);
2752    let _ = scatter_add_zero_kernel(ctx);
2753    let _ = scatter_add_acc_kernel(ctx);
2754    let _ = dequant_matmul_kernel(ctx);
2755    let _ = dequant_gguf_kernel(ctx);
2756    let _ = sample_kernel(ctx);
2757    let _ = selective_scan_kernel(ctx);
2758    let _ = pool1d_kernel(ctx);
2759    let _ = pool2d_kernel(ctx);
2760    let _ = pool3d_kernel(ctx);
2761    let _ = conv1d_kernel(ctx);
2762    let _ = conv2d_kernel(ctx);
2763    let _ = im2col_kernel(ctx);
2764    let _ = conv3d_kernel(ctx);
2765    let _ = layer_norm2d_kernel(ctx);
2766    let _ = conv_transpose2d_kernel(ctx);
2767    let _ = group_norm_kernel(ctx);
2768    let _ = resize_nearest_2x_kernel(ctx);
2769    let _ = elementwise_region_kernel(ctx);
2770    let _ = batch_elementwise_region_kernel(ctx);
2771    // matmul_wmma deliberately excluded: requires SM 70+ and may fail
2772    // load_module on older GPUs. Compile lazily on first opt-in dispatch.
2773}
2774
2775fn im2col_use_gpu(n: u32, exec_mode: ExecMode) -> bool {
2776    if rlx_ir::env::var("RLX_CUDA_IM2COL_HOST").is_some() {
2777        return false;
2778    }
2779    if matches!(exec_mode, ExecMode::Graph) {
2780        return n > 0;
2781    }
2782    n > 0
2783}
2784
2785fn pinned_host_io_disabled() -> bool {
2786    rlx_ir::env::var("RLX_CUDA_PINNED_IO").is_some_and(|v| v.eq_ignore_ascii_case("0"))
2787}
2788
2789/// Pinned host output staging (faster D2H). On by default; set `RLX_CUDA_PINNED_IO=0` to disable.
2790fn pinned_output_staging_enabled() -> bool {
2791    !pinned_host_io_disabled()
2792}
2793
2794/// Pinned host input staging for H2D. Graph mode always; stream mode when `RLX_CUDA_PINNED_IO=1`.
2795fn pinned_input_staging_enabled(exec_mode: ExecMode) -> bool {
2796    if pinned_host_io_disabled() {
2797        return false;
2798    }
2799    matches!(exec_mode, ExecMode::Graph)
2800        || rlx_ir::env::var("RLX_CUDA_PINNED_IO").is_some_and(|v| !v.eq_ignore_ascii_case("0"))
2801}
2802
/// Sort `buf` ascending and drop duplicate entries, in place.
///
/// Buffers with zero or one element are already normalized and are left
/// untouched.
fn normalize_read_indices(buf: &mut Vec<usize>) {
    if buf.len() <= 1 {
        return;
    }
    buf.sort_unstable();
    buf.dedup();
}
2809
2810fn compile_mode_from_env() -> CompileMode {
2811    match rlx_ir::env::var("RLX_CUDA_COMPILE_MODE").as_deref() {
2812        Some(mode) if mode.eq_ignore_ascii_case("aot") => CompileMode::Aot,
2813        _ => CompileMode::Jit,
2814    }
2815}
2816
2817fn exec_mode_from_env() -> ExecMode {
2818    match rlx_ir::env::var("RLX_CUDA_EXEC_MODE").as_deref() {
2819        Some(mode) if mode.eq_ignore_ascii_case("graph") => ExecMode::Graph,
2820        Some(mode) => {
2821            let lower = mode.to_ascii_lowercase();
2822            if let Some(rest) = lower.strip_prefix("multistream") {
2823                let n = rest.trim_start_matches([':', '=']).parse().unwrap_or(2);
2824                ExecMode::MultiStream(n.max(1))
2825            } else {
2826                ExecMode::Stream
2827            }
2828        }
2829        _ => ExecMode::Stream,
2830    }
2831}
2832
2833impl CudaExecutable {
2834    /// JIT compile, stream-mode execution. Default entry point.
2835    ///
2836    /// Honors `RLX_CUDA_COMPILE_MODE=aot` and `RLX_CUDA_EXEC_MODE=graph|multistream:N`.
2837    pub fn compile(graph: Graph) -> Self {
2838        Self::compile_with(graph, compile_mode_from_env(), exec_mode_from_env())
2839    }
2840
2841    /// One-shot eager run. Compiles, executes once with the given
2842    /// inputs, and drops the executable. No persistent state.
2843    pub fn eager(graph: Graph, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
2844        let mut exec = Self::compile_with(graph, CompileMode::Jit, ExecMode::Eager);
2845        exec.run(inputs)
2846    }
2847
2848    /// Full constructor with explicit compile + exec modes.
2849    pub fn compile_with(graph: Graph, compile_mode: CompileMode, exec_mode: ExecMode) -> Self {
2850        let ctx = cuda_context().expect("rlx-cuda: no CUDA driver available");
2851
2852        if compile_mode == CompileMode::Aot {
2853            prewarm_all(&ctx);
2854        }
2855
2856        // Decompose composed ops we don't yet have native kernels for
2857        // (FusedMatMulBiasAct, canonical DotGeneral) into primitives
2858        // before memory planning. Fusion may reintroduce mid-axis Reduce
2859        // (e.g. EEG temporal mean); CUDA only schedules last-axis Reduce.
2860        let graph = LowerNonLastAxisReduce.run(crate::unfuse::unfuse(graph));
2861
2862        let dequant_scratch = crate::gguf_gpu::dequant_gguf_scratch_bytes(&graph);
2863        let mut plan = plan_f32_uniform(&graph, 16);
2864        let dequant_scratch_off = if dequant_scratch > 0 {
2865            let aligned = plan.arena_size.div_ceil(16) * 16;
2866            plan.arena_size = aligned + dequant_scratch;
2867            aligned
2868        } else {
2869            0
2870        };
2871        let mut arena = Arena::from_plan(&ctx, &plan);
2872        for node in graph.nodes() {
2873            let elems = node.shape.num_elements().unwrap_or(0);
2874            arena.set_actual_len(node.id, elems * 4);
2875        }
2876
2877        // Initial param/input offset maps for fast lookup at run time.
2878        let mut input_offsets = HashMap::new();
2879        let mut param_offsets = HashMap::new();
2880        for node in graph.nodes() {
2881            match &node.op {
2882                Op::Input { name } => {
2883                    input_offsets.insert(name.clone(), node.id);
2884                }
2885                Op::Param { name } => {
2886                    param_offsets.insert(name.clone(), node.id);
2887                }
2888                _ => {}
2889            }
2890        }
2891
2892        // Initialise Constants directly into the arena.
2893        for node in graph.nodes() {
2894            if let Op::Constant { data } = &node.op
2895                && arena.has(node.id)
2896                && !data.is_empty()
2897            {
2898                let bytes_to_write = data.len().min(arena.len_of(node.id));
2899                let n_f32 = bytes_to_write / 4;
2900                let f32_view: &[f32] =
2901                    unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n_f32) };
2902                let off_f32 = arena.offset(node.id) / 4;
2903                let stream = ctx.default_stream();
2904                let mut slot = arena.f32_buf_mut().slice_mut(off_f32..off_f32 + n_f32);
2905                stream
2906                    .memcpy_htod(f32_view, &mut slot)
2907                    .expect("rlx-cuda: constant upload failed");
2908            }
2909        }
2910
2911        let mut schedule = Vec::new();
2912        let mut meta_buffers: Vec<cudarc::driver::CudaSlice<u32>> = Vec::new();
2913        let mut packed_bshd_attn: HashMap<NodeId, (NodeId, u32)> = HashMap::new();
2914        if !rlx_ir::env::flag("RLX_CUDA_NO_PACKED_BSHD_ATTN") {
2915            for node in graph.nodes() {
2916                let Op::Attention { .. } = &node.op else {
2917                    continue;
2918                };
2919                if node.inputs.len() < 3 {
2920                    continue;
2921                }
2922                if let Some((parent, head_width, _)) = rlx_ir::detect_packed_bshd_qkv_attention(
2923                    &graph,
2924                    node.inputs[0],
2925                    node.inputs[1],
2926                    node.inputs[2],
2927                ) {
2928                    packed_bshd_attn.insert(node.id, (parent, head_width as u32));
2929                }
2930            }
2931        }
2932        for node in graph.nodes() {
2933            let elems = node.shape.num_elements().unwrap_or(0) as u32;
2934            match &node.op {
2935                Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => continue,
2936                Op::Reshape { .. } | Op::Cast { .. } => {
2937                    // No-op: arena.plan_f32_uniform already aliased the
2938                    // output slot to the input. The same row-major bytes
2939                    // are visible under the new node ID.
2940                }
2941                Op::MatMul => {
2942                    let (m, k, n, batch, a_bs, b_bs, c_bs, a_id, b_id) =
2943                        matmul_shape(&graph, node, "MatMul");
2944                    schedule.push(Step::Matmul {
2945                        m,
2946                        k,
2947                        n,
2948                        batch,
2949                        a_batch_stride: a_bs,
2950                        b_batch_stride: b_bs,
2951                        c_batch_stride: c_bs,
2952                        a_off_f32: (arena.offset(a_id) / 4) as u32,
2953                        b_off_f32: (arena.offset(b_id) / 4) as u32,
2954                        c_off_f32: (arena.offset(node.id) / 4) as u32,
2955                        has_bias: 0,
2956                        bias_off_f32: 0,
2957                        act_id: 0xFFFF,
2958                    });
2959                }
2960                Op::FusedMatMulBiasAct { activation } => {
2961                    let (m, k, n, batch, a_bs, b_bs, c_bs, a_id, b_id) =
2962                        matmul_shape(&graph, node, "FusedMatMulBiasAct");
2963                    let bias_id = node.inputs[2];
2964                    let act_id = match activation {
2965                        None => 0xFFFFu32,
2966                        Some(a) => activation_op_id(*a),
2967                    };
2968                    schedule.push(Step::Matmul {
2969                        m,
2970                        k,
2971                        n,
2972                        batch,
2973                        a_batch_stride: a_bs,
2974                        b_batch_stride: b_bs,
2975                        c_batch_stride: c_bs,
2976                        a_off_f32: (arena.offset(a_id) / 4) as u32,
2977                        b_off_f32: (arena.offset(b_id) / 4) as u32,
2978                        c_off_f32: (arena.offset(node.id) / 4) as u32,
2979                        has_bias: 1,
2980                        bias_off_f32: (arena.offset(bias_id) / 4) as u32,
2981                        act_id,
2982                    });
2983                }
2984                Op::Binary(bop) => {
2985                    schedule.push(Step::Binary {
2986                        n: elems,
2987                        a_off: (arena.offset(node.inputs[0]) / 4) as u32,
2988                        b_off: (arena.offset(node.inputs[1]) / 4) as u32,
2989                        c_off: (arena.offset(node.id) / 4) as u32,
2990                        op: binary_op_id(*bop),
2991                    });
2992                }
2993                Op::Activation(act) => {
2994                    schedule.push(Step::Unary {
2995                        n: elems,
2996                        in_off: (arena.offset(node.inputs[0]) / 4) as u32,
2997                        out_off: (arena.offset(node.id) / 4) as u32,
2998                        op: activation_op_id(*act),
2999                    });
3000                }
3001                Op::Compare(cop) => {
3002                    schedule.push(Step::Compare {
3003                        n: elems,
3004                        a_off: (arena.offset(node.inputs[0]) / 4) as u32,
3005                        b_off: (arena.offset(node.inputs[1]) / 4) as u32,
3006                        c_off: (arena.offset(node.id) / 4) as u32,
3007                        op: compare_op_id(*cop),
3008                    });
3009                }
3010                Op::Where => {
3011                    schedule.push(Step::Where {
3012                        n: elems,
3013                        cond_off: (arena.offset(node.inputs[0]) / 4) as u32,
3014                        x_off: (arena.offset(node.inputs[1]) / 4) as u32,
3015                        y_off: (arena.offset(node.inputs[2]) / 4) as u32,
3016                        out_off: (arena.offset(node.id) / 4) as u32,
3017                    });
3018                }
3019                Op::BatchElementwiseRegion {
3020                    chain,
3021                    num_batch_inputs,
3022                    scalar_input_mask,
3023                    input_modulus,
3024                    prologue,
3025                    prologue_input,
3026                } => {
3027                    let n = *num_batch_inputs as usize;
3028                    if n == 0 || chain.len() > 32 {
3029                        panic!(
3030                            "rlx-cuda BatchElementwiseRegion: num_batch_inputs={n} steps={}",
3031                            chain.len()
3032                        );
3033                    }
3034                    let slice_shape = rlx_ir::batch_region_slice_shape(&node.shape);
3035                    let slice_elems = rlx_ir::batch_region_slice_elems(&node.shape, n)
3036                        .expect("batch region static shape");
3037                    let base_dst_off = (arena.offset(node.id) / 4) as u32;
3038                    let use_single = rlx_ir::fk_batch_use_single_launch(n, *prologue);
3039                    if use_single {
3040                        let mut batch_input_offs = [0u32; 64];
3041                        for i in 0..n {
3042                            batch_input_offs[i] = (arena.offset(node.inputs[i]) / 4) as u32;
3043                        }
3044                        let input_offs_meta = [0u32; 16];
3045                        let meta_arr = rlx_ir::encode_elementwise_region_meta(
3046                            &input_offs_meta,
3047                            chain,
3048                            *prologue,
3049                            &slice_shape,
3050                            *prologue_input,
3051                        );
3052                        let meta = ctx
3053                            .default_stream()
3054                            .clone_htod(&meta_arr.to_vec())
3055                            .expect("rlx-cuda: batch elementwise_region meta upload failed");
3056                        let meta_idx = meta_buffers.len();
3057                        meta_buffers.push(meta);
3058                        let batch_vec: Vec<u32> = batch_input_offs[..n].to_vec();
3059                        let batch_dev = ctx
3060                            .default_stream()
3061                            .clone_htod(&batch_vec)
3062                            .expect("rlx-cuda: batch input offs upload failed");
3063                        let batch_offs_idx = meta_buffers.len();
3064                        meta_buffers.push(batch_dev);
3065                        schedule.push(Step::BatchElementwiseRegion {
3066                            slice_len: slice_elems,
3067                            num_batch: n as u32,
3068                            num_steps: chain.len() as u32,
3069                            base_dst_off,
3070                            slice_elems,
3071                            batch_input_offs,
3072                            batch_offs_idx,
3073                            meta_idx,
3074                            scalar_input_mask: *scalar_input_mask,
3075                            input_modulus: *input_modulus,
3076                        });
3077                    } else {
3078                        for i in 0..n {
3079                            let mut input_offs = [0u32; 16];
3080                            input_offs[0] = (arena.offset(node.inputs[i]) / 4) as u32;
3081                            let meta_arr = rlx_ir::encode_elementwise_region_meta(
3082                                &input_offs,
3083                                chain,
3084                                *prologue,
3085                                &slice_shape,
3086                                *prologue_input,
3087                            );
3088                            let meta = ctx
3089                                .default_stream()
3090                                .clone_htod(&meta_arr.to_vec())
3091                                .expect("rlx-cuda: batch elementwise_region meta upload failed");
3092                            let meta_idx = meta_buffers.len();
3093                            meta_buffers.push(meta);
3094                            let spatial =
3095                                matches!(*prologue, rlx_ir::RegionPrologue::ResizeNearest2x);
3096                            let grid = rlx_ir::PrologueLaunchGrid::from_output_shape(&slice_shape);
3097                            schedule.push(Step::ElementwiseRegion {
3098                                len: slice_elems,
3099                                num_inputs: 1,
3100                                num_steps: chain.len() as u32,
3101                                dst_off: rlx_ir::batch_region_slice_dst_off_f32(
3102                                    base_dst_off,
3103                                    slice_elems,
3104                                    i,
3105                                ),
3106                                input_offs,
3107                                scalar_input_mask: *scalar_input_mask,
3108                                input_modulus: *input_modulus,
3109                                meta_idx,
3110                                spatial_prologue: spatial,
3111                                prologue_w: grid.map(|g| g.width).unwrap_or(0),
3112                                prologue_h: grid.map(|g| g.height).unwrap_or(0),
3113                                prologue_nc: grid.map(|g| g.depth).unwrap_or(0),
3114                            });
3115                        }
3116                    }
3117                }
3118                Op::ElementwiseRegion {
3119                    chain,
3120                    num_inputs,
3121                    scalar_input_mask,
3122                    input_modulus,
3123                    prologue,
3124                    prologue_input,
3125                } => {
3126                    // PLAN L2 native lowering. Encode the chain into a
3127                    // 72-u32 metadata buffer (8 input offsets + 16 steps *
3128                    // 4 u32s) uploaded once at compile time; the kernel
3129                    // walks the chain interpretively in registers. Caps
3130                    // match the cross-backend Metal MSL / wgpu WGSL
3131                    // encoders.
3132                    let n = *num_inputs as usize;
3133                    if n > 16 || chain.len() > 32 {
3134                        panic!(
3135                            "rlx-cuda ElementwiseRegion: chain too large \
3136                                (inputs={n}, steps={}). Caps: 16 / 32. \
3137                                Run UnfuseElementwiseRegions to fall back \
3138                                to atomic ops.",
3139                            chain.len()
3140                        );
3141                    }
3142                    let mut input_offs = [0u32; 16];
3143                    for (i, &id) in node.inputs.iter().enumerate() {
3144                        input_offs[i] = (arena.offset(id) / 4) as u32;
3145                    }
3146                    let meta_arr = rlx_ir::encode_elementwise_region_meta(
3147                        &input_offs,
3148                        chain,
3149                        *prologue,
3150                        &node.shape,
3151                        *prologue_input,
3152                    );
3153                    let meta_data: Vec<u32> = meta_arr.to_vec();
3154                    let meta = ctx
3155                        .default_stream()
3156                        .clone_htod(&meta_data)
3157                        .expect("rlx-cuda: elementwise_region meta upload failed");
3158                    let meta_idx = meta_buffers.len();
3159                    meta_buffers.push(meta);
3160                    let spatial = matches!(*prologue, rlx_ir::RegionPrologue::ResizeNearest2x);
3161                    let grid = rlx_ir::PrologueLaunchGrid::from_output_shape(&node.shape);
3162                    schedule.push(Step::ElementwiseRegion {
3163                        len: elems,
3164                        num_inputs: *num_inputs,
3165                        num_steps: chain.len() as u32,
3166                        dst_off: (arena.offset(node.id) / 4) as u32,
3167                        input_offs,
3168                        scalar_input_mask: *scalar_input_mask,
3169                        input_modulus: *input_modulus,
3170                        meta_idx,
3171                        spatial_prologue: spatial,
3172                        prologue_w: grid.map(|g| g.width).unwrap_or(0),
3173                        prologue_h: grid.map(|g| g.height).unwrap_or(0),
3174                        prologue_nc: grid.map(|g| g.depth).unwrap_or(0),
3175                    });
3176                }
3177                Op::Reduce {
3178                    op,
3179                    axes,
3180                    keep_dim: _,
3181                } => {
3182                    // v2: reduce along the LAST axis only — same v1
3183                    // simplification rlx-wgpu had.
3184                    let in_id = node.inputs[0];
3185                    let in_dims = graph.node(in_id).shape.dims();
3186                    if axes.len() != 1 || axes[0] != in_dims.len() - 1 {
3187                        panic!(
3188                            "rlx-cuda Reduce: only single last-axis supported \
3189                                (got axes={axes:?}, rank={})",
3190                            in_dims.len()
3191                        );
3192                    }
3193                    let inner = in_dims.last().unwrap().unwrap_static() as u32;
3194                    let outer = in_dims[..in_dims.len() - 1]
3195                        .iter()
3196                        .map(|d| d.unwrap_static() as u32)
3197                        .product::<u32>()
3198                        .max(1);
3199                    schedule.push(Step::Reduce {
3200                        outer,
3201                        inner,
3202                        in_off: (arena.offset(in_id) / 4) as u32,
3203                        out_off: (arena.offset(node.id) / 4) as u32,
3204                        op: reduce_op_id(*op),
3205                    });
3206                }
3207                Op::Softmax { axis: _ } => {
3208                    let in_id = node.inputs[0];
3209                    let in_dims = graph.node(in_id).shape.dims();
3210                    let inner = in_dims.last().unwrap().unwrap_static() as u32;
3211                    let outer = in_dims[..in_dims.len() - 1]
3212                        .iter()
3213                        .map(|d| d.unwrap_static() as u32)
3214                        .product::<u32>()
3215                        .max(1);
3216                    schedule.push(Step::Softmax {
3217                        outer,
3218                        inner,
3219                        in_off: (arena.offset(in_id) / 4) as u32,
3220                        out_off: (arena.offset(node.id) / 4) as u32,
3221                    });
3222                }
3223                Op::LayerNorm { axis: _, eps } | Op::RmsNorm { axis: _, eps } => {
3224                    let in_id = node.inputs[0];
3225                    let in_dims = graph.node(in_id).shape.dims();
3226                    let inner = in_dims.last().unwrap().unwrap_static() as u32;
3227                    let total: u32 = in_dims.iter().map(|d| d.unwrap_static() as u32).product();
3228                    let outer = total / inner.max(1);
3229                    let is_layer = matches!(&node.op, Op::LayerNorm { .. });
3230                    let gamma_id = node.inputs[1];
3231                    let beta_id = if is_layer && node.inputs.len() >= 3 {
3232                        node.inputs[2]
3233                    } else {
3234                        gamma_id
3235                    };
3236                    schedule.push(Step::LayerNorm {
3237                        outer,
3238                        inner,
3239                        in_off: (arena.offset(in_id) / 4) as u32,
3240                        out_off: (arena.offset(node.id) / 4) as u32,
3241                        gamma_off: (arena.offset(gamma_id) / 4) as u32,
3242                        beta_off: (arena.offset(beta_id) / 4) as u32,
3243                        eps_bits: eps.to_bits(),
3244                        op: if is_layer { 0 } else { 1 },
3245                    });
3246                }
3247                Op::FusedResidualLN { has_bias, eps } => {
3248                    let x_id = node.inputs[0];
3249                    let r_id = node.inputs[1];
3250                    let (bias_id, g_id, b_id) = if *has_bias {
3251                        (node.inputs[2], node.inputs[3], node.inputs[4])
3252                    } else {
3253                        (x_id, node.inputs[2], node.inputs[3])
3254                    };
3255                    let in_dims = node.shape.dims();
3256                    let inner = in_dims.last().unwrap().unwrap_static() as u32;
3257                    let total: u32 = in_dims.iter().map(|d| d.unwrap_static() as u32).product();
3258                    let outer = total / inner.max(1);
3259                    schedule.push(Step::FusedResidualLn {
3260                        outer,
3261                        inner,
3262                        in_off: (arena.offset(x_id) / 4) as u32,
3263                        residual_off: (arena.offset(r_id) / 4) as u32,
3264                        bias_off: (arena.offset(bias_id) / 4) as u32,
3265                        gamma_off: (arena.offset(g_id) / 4) as u32,
3266                        beta_off: (arena.offset(b_id) / 4) as u32,
3267                        out_off: (arena.offset(node.id) / 4) as u32,
3268                        eps_bits: eps.to_bits(),
3269                        has_bias: if *has_bias { 1 } else { 0 },
3270                    });
3271                }
3272                Op::FusedResidualRmsNorm { has_bias, eps } => {
3273                    let x_id = node.inputs[0];
3274                    let r_id = node.inputs[1];
3275                    let (bias_id, g_id, b_id) = if *has_bias {
3276                        (node.inputs[2], node.inputs[3], node.inputs[4])
3277                    } else {
3278                        (x_id, node.inputs[2], node.inputs[3])
3279                    };
3280                    let in_dims = node.shape.dims();
3281                    let inner = in_dims.last().unwrap().unwrap_static() as u32;
3282                    let total: u32 = in_dims.iter().map(|d| d.unwrap_static() as u32).product();
3283                    let outer = total / inner.max(1);
3284                    schedule.push(Step::FusedResidualRmsNorm {
3285                        outer,
3286                        inner,
3287                        in_off: (arena.offset(x_id) / 4) as u32,
3288                        residual_off: (arena.offset(r_id) / 4) as u32,
3289                        bias_off: (arena.offset(bias_id) / 4) as u32,
3290                        gamma_off: (arena.offset(g_id) / 4) as u32,
3291                        beta_off: (arena.offset(b_id) / 4) as u32,
3292                        out_off: (arena.offset(node.id) / 4) as u32,
3293                        eps_bits: eps.to_bits(),
3294                        has_bias: if *has_bias { 1 } else { 0 },
3295                    });
3296                }
3297                Op::Gather { axis } => {
3298                    let table_id = node.inputs[0];
3299                    let idx_id = node.inputs[1];
3300                    if *axis == 0 {
3301                        let table_shape = graph.node(table_id).shape.dims();
3302                        let idx_shape = graph.node(idx_id).shape.dims();
3303                        let vocab = table_shape[0].unwrap_static() as u32;
3304                        let dim: u32 = table_shape[1..]
3305                            .iter()
3306                            .map(|d| d.unwrap_static() as u32)
3307                            .product::<u32>()
3308                            .max(1);
3309                        let n_idx: u32 =
3310                            idx_shape.iter().map(|d| d.unwrap_static() as u32).product();
3311                        schedule.push(Step::Gather {
3312                            n_out: elems,
3313                            n_idx,
3314                            dim,
3315                            vocab,
3316                            in_off: (arena.offset(table_id) / 4) as u32,
3317                            idx_off: (arena.offset(idx_id) / 4) as u32,
3318                            out_off: (arena.offset(node.id) / 4) as u32,
3319                        });
3320                    } else {
3321                        let table_shape = graph.node(table_id).shape.dims();
3322                        let idx_shape = graph.node(idx_id).shape.dims();
3323                        let outer: u32 = table_shape[..*axis]
3324                            .iter()
3325                            .map(|d| d.unwrap_static() as u32)
3326                            .product::<u32>()
3327                            .max(1);
3328                        let trailing: u32 = table_shape[*axis + 1..]
3329                            .iter()
3330                            .map(|d| d.unwrap_static() as u32)
3331                            .product::<u32>()
3332                            .max(1);
3333                        let axis_dim = table_shape[*axis].unwrap_static() as u32;
3334                        let num_idx: u32 =
3335                            idx_shape.iter().map(|d| d.unwrap_static() as u32).product();
3336                        let total = outer * num_idx * trailing;
3337                        schedule.push(Step::GatherAxis {
3338                            total,
3339                            outer,
3340                            axis_dim,
3341                            num_idx,
3342                            trailing,
3343                            table_off: (arena.offset(table_id) / 4) as u32,
3344                            idx_off: (arena.offset(idx_id) / 4) as u32,
3345                            out_off: (arena.offset(node.id) / 4) as u32,
3346                        });
3347                    }
3348                }
3349                Op::Narrow { axis, start, len } => {
3350                    let in_id = node.inputs[0];
3351                    let in_dims = graph.node(in_id).shape.dims();
3352                    let outer: u32 = in_dims[..*axis]
3353                        .iter()
3354                        .map(|d| d.unwrap_static() as u32)
3355                        .product::<u32>()
3356                        .max(1);
3357                    let inner: u32 = in_dims[*axis + 1..]
3358                        .iter()
3359                        .map(|d| d.unwrap_static() as u32)
3360                        .product::<u32>()
3361                        .max(1);
3362                    let axis_in = in_dims[*axis].unwrap_static() as u32;
3363                    schedule.push(Step::Narrow {
3364                        total: elems,
3365                        outer,
3366                        inner,
3367                        axis_in_size: axis_in,
3368                        axis_out_size: *len as u32,
3369                        start: *start as u32,
3370                        in_off: (arena.offset(in_id) / 4) as u32,
3371                        out_off: (arena.offset(node.id) / 4) as u32,
3372                    });
3373                }
3374                Op::Transpose { perm } => {
3375                    let in_id = node.inputs[0];
3376                    let in_dims = graph.node(in_id).shape.dims();
3377                    let rank = perm.len();
3378                    let in_dims_u: Vec<u32> =
3379                        in_dims.iter().map(|d| d.unwrap_static() as u32).collect();
3380                    // Cumulative input strides (row-major, innermost = 1).
3381                    let mut in_strides = vec![1u32; rank];
3382                    for i in (0..rank.saturating_sub(1)).rev() {
3383                        in_strides[i] = in_strides[i + 1] * in_dims_u[i + 1];
3384                    }
3385                    let out_dims_u: Vec<u32> = perm.iter().map(|&i| in_dims_u[i]).collect();
3386                    let strides_for_out: Vec<u32> = perm.iter().map(|&i| in_strides[i]).collect();
3387                    let mut meta_data: Vec<u32> = Vec::with_capacity(rank * 2);
3388                    meta_data.extend_from_slice(&out_dims_u);
3389                    meta_data.extend_from_slice(&strides_for_out);
3390                    let meta = ctx
3391                        .default_stream()
3392                        .clone_htod(&meta_data)
3393                        .expect("rlx-cuda: meta upload failed");
3394                    let meta_idx = meta_buffers.len();
3395                    meta_buffers.push(meta);
3396                    schedule.push(Step::Transpose {
3397                        rank: rank as u32,
3398                        out_total: elems,
3399                        in_off: (arena.offset(in_id) / 4) as u32,
3400                        out_off: (arena.offset(node.id) / 4) as u32,
3401                        meta_idx,
3402                    });
3403                }
3404                Op::Expand { target_shape } => {
3405                    let in_id = node.inputs[0];
3406                    let in_shape = graph.node(in_id).shape.dims();
3407                    let rank = target_shape.len();
3408                    if rank < in_shape.len() {
3409                        panic!(
3410                            "rlx-cuda Expand: cannot reduce rank (in={}, target={})",
3411                            in_shape.len(),
3412                            rank
3413                        );
3414                    }
3415                    let out_dims: Vec<u32> = target_shape.iter().map(|&d| d as u32).collect();
3416                    let pad = rank - in_shape.len();
3417                    let mut in_dims: Vec<u32> = vec![1; pad];
3418                    in_dims.extend(in_shape.iter().map(|d| d.unwrap_static() as u32));
3419                    let mut in_strides_row = vec![1u32; rank];
3420                    for i in (0..rank.saturating_sub(1)).rev() {
3421                        in_strides_row[i] = in_strides_row[i + 1] * in_dims[i + 1];
3422                    }
3423                    let strides_for_out: Vec<u32> = (0..rank)
3424                        .map(|i| {
3425                            if in_dims[i] == 1 && out_dims[i] != 1 {
3426                                0
3427                            } else {
3428                                in_strides_row[i]
3429                            }
3430                        })
3431                        .collect();
3432                    let mut meta_data: Vec<u32> = Vec::with_capacity(rank * 2);
3433                    meta_data.extend_from_slice(&out_dims);
3434                    meta_data.extend_from_slice(&strides_for_out);
3435                    let meta = ctx
3436                        .default_stream()
3437                        .clone_htod(&meta_data)
3438                        .expect("rlx-cuda: meta upload failed");
3439                    let meta_idx = meta_buffers.len();
3440                    meta_buffers.push(meta);
3441                    schedule.push(Step::Expand {
3442                        rank: rank as u32,
3443                        out_total: elems,
3444                        in_off: (arena.offset(in_id) / 4) as u32,
3445                        out_off: (arena.offset(node.id) / 4) as u32,
3446                        meta_idx,
3447                    });
3448                }
3449                Op::Concat { axis } => {
3450                    // Caller convention: one Step::Concat per input, copying
3451                    // each input's slice into the output at the right axis offset.
3452                    let mut start: u32 = 0;
3453                    let out_dims = node.shape.dims();
3454                    let outer: u32 = out_dims[..*axis]
3455                        .iter()
3456                        .map(|d| d.unwrap_static() as u32)
3457                        .product::<u32>()
3458                        .max(1);
3459                    let inner: u32 = out_dims[*axis + 1..]
3460                        .iter()
3461                        .map(|d| d.unwrap_static() as u32)
3462                        .product::<u32>()
3463                        .max(1);
3464                    let axis_out_size = out_dims[*axis].unwrap_static() as u32;
3465                    for &in_id in &node.inputs {
3466                        let in_dims = graph.node(in_id).shape.dims();
3467                        let axis_in = in_dims[*axis].unwrap_static() as u32;
3468                        let total: u32 = in_dims.iter().map(|d| d.unwrap_static() as u32).product();
3469                        schedule.push(Step::Concat {
3470                            total,
3471                            outer,
3472                            inner,
3473                            axis_in_size: axis_in,
3474                            axis_out_size,
3475                            start,
3476                            in_off: (arena.offset(in_id) / 4) as u32,
3477                            out_off: (arena.offset(node.id) / 4) as u32,
3478                        });
3479                        start += axis_in;
3480                    }
3481                }
3482                Op::Attention {
3483                    num_heads,
3484                    head_dim,
3485                    mask_kind,
3486                    score_scale: _,
3487                    attn_logit_softcap: _,
3488                } => {
3489                    let q_id = node.inputs[0];
3490                    let k_id = node.inputs[1];
3491                    let v_id = node.inputs[2];
3492                    let q_shape = graph.node(q_id).shape.dims();
3493                    let k_shape = graph.node(k_id).shape.dims();
3494                    if q_shape.len() != 4 {
3495                        panic!("rlx-cuda Attention: unfuse should have promoted to rank-4");
3496                    }
3497                    let q_ir = graph.node(q_id).shape.clone();
3498                    let k_ir = graph.node(k_id).shape.clone();
3499                    let geom = rlx_ir::attention_geom(&q_ir, &k_ir, *num_heads, *head_dim);
3500                    let batch = geom.batch as u32;
3501                    let heads = geom.heads as u32;
3502                    let seq_q = geom.seq_q as u32;
3503                    let seq_k = geom.seq_k as u32;
3504                    let hd = *head_dim as u32;
3505                    let scale = 1.0_f32 / (hd as f32).sqrt();
3506                    let mask_shape = if matches!(mask_kind, MaskKind::Custom | MaskKind::Bias) {
3507                        Some(graph.node(node.inputs[3]).shape.dims())
3508                    } else {
3509                        None
3510                    };
3511                    let packed_parent = packed_bshd_attn.get(&node.id).copied();
3512                    let st = if let Some((_, head_width)) = packed_parent {
3513                        let (qb, qh, qs) =
3514                            rlx_ir::packed_bshd_qkv_strides(head_width as usize, hd, seq_q);
3515                        let (ob, oh, os) =
3516                            rlx_ir::strides_for_shape(node.shape.dims(), heads, hd, seq_q, false);
3517                        let (mb, mh, mq, mk) = mask_shape
3518                            .map(|m| rlx_ir::mask_strides_for_shape(m, heads, seq_q, seq_k))
3519                            .unwrap_or_else(|| rlx_ir::mask_strides_bhsd(heads, seq_q, seq_k));
3520                        rlx_ir::AttentionLaunchStrides {
3521                            q_batch: qb,
3522                            q_head: qh,
3523                            q_seq: qs,
3524                            k_batch: qb,
3525                            k_head: qh,
3526                            k_seq: qs,
3527                            v_batch: qb,
3528                            v_head: qh,
3529                            v_seq: qs,
3530                            o_batch: ob,
3531                            o_head: oh,
3532                            o_seq: os,
3533                            mask_batch: mb,
3534                            mask_head: mh,
3535                            mask_q: mq,
3536                            mask_k: mk,
3537                        }
3538                    } else {
3539                        rlx_ir::attention_launch_strides(
3540                            geom,
3541                            q_shape,
3542                            k_shape,
3543                            graph.node(v_id).shape.dims(),
3544                            node.shape.dims(),
3545                            mask_shape,
3546                        )
3547                    };
3548                    let (q_off, k_off, v_off) = if let Some((parent, head_width)) = packed_parent {
3549                        let p = (arena.offset(parent) / 4) as u32;
3550                        (
3551                            p,
3552                            p.saturating_add(head_width),
3553                            p.saturating_add(head_width * 2),
3554                        )
3555                    } else {
3556                        (
3557                            (arena.offset(q_id) / 4) as u32,
3558                            (arena.offset(k_id) / 4) as u32,
3559                            (arena.offset(v_id) / 4) as u32,
3560                        )
3561                    };
3562                    let (mask_kind_id, mask_off, window) = match mask_kind {
3563                        MaskKind::None => (0u32, 0u32, 0u32),
3564                        MaskKind::Causal => (1u32, 0u32, 0u32),
3565                        MaskKind::Custom => (2u32, (arena.offset(node.inputs[3]) / 4) as u32, 0u32),
3566                        MaskKind::SlidingWindow(w) => (3u32, 0u32, *w as u32),
3567                        MaskKind::Bias => (4u32, (arena.offset(node.inputs[3]) / 4) as u32, 0u32),
3568                    };
3569                    schedule.push(Step::Attention {
3570                        batch,
3571                        heads,
3572                        seq_q,
3573                        seq_k,
3574                        head_dim: hd,
3575                        q_off,
3576                        k_off,
3577                        v_off,
3578                        out_off: (arena.offset(node.id) / 4) as u32,
3579                        mask_off,
3580                        mask_kind: mask_kind_id,
3581                        scale_bits: scale.to_bits(),
3582                        window,
3583                        seq_q_stride: st.mask_q,
3584                        seq_k_stride: st.mask_k,
3585                        mask_batch_stride: st.mask_batch,
3586                        mask_head_stride: st.mask_head,
3587                        q_batch_stride: st.q_batch,
3588                        q_head_stride: st.q_head,
3589                        q_seq_stride: st.q_seq,
3590                        k_batch_stride: st.k_batch,
3591                        k_head_stride: st.k_head,
3592                        k_seq_stride: st.k_seq,
3593                        v_batch_stride: st.v_batch,
3594                        v_head_stride: st.v_head,
3595                        v_seq_stride: st.v_seq,
3596                        o_batch_stride: st.o_batch,
3597                        o_head_stride: st.o_head,
3598                        o_seq_stride: st.o_seq,
3599                    });
3600                }
3601                Op::AttentionBackward {
3602                    num_heads: _,
3603                    head_dim,
3604                    mask_kind,
3605                    wrt,
3606                } => {
3607                    use rlx_ir::op::AttentionBwdWrt;
3608                    let q_id = node.inputs[0];
3609                    let k_id = node.inputs[1];
3610                    let v_id = node.inputs[2];
3611                    let dy_id = node.inputs[3];
3612                    let q_shape = graph.node(q_id).shape.dims();
3613                    let k_shape = graph.node(k_id).shape.dims();
3614                    if q_shape.len() != 4 {
3615                        panic!("rlx-cuda AttentionBackward: unfuse should have promoted to rank-4");
3616                    }
3617                    let batch = q_shape[0].unwrap_static() as u32;
3618                    let heads = q_shape[1].unwrap_static() as u32;
3619                    let seq_q = q_shape[2].unwrap_static() as u32;
3620                    let seq_k = k_shape[2].unwrap_static() as u32;
3621                    let hd = *head_dim as u32;
3622                    let scale = 1.0_f32 / (hd as f32).sqrt();
3623                    let (mask_kind_id, mask_off, window) = match mask_kind {
3624                        MaskKind::None => (0u32, 0u32, 0u32),
3625                        MaskKind::Causal => (1u32, 0u32, 0u32),
3626                        MaskKind::Custom => (2u32, (arena.offset(node.inputs[4]) / 4) as u32, 0u32),
3627                        MaskKind::SlidingWindow(w) => (3u32, 0u32, *w as u32),
3628                        MaskKind::Bias => (4u32, (arena.offset(node.inputs[4]) / 4) as u32, 0u32),
3629                    };
3630                    let wrt_id = match wrt {
3631                        AttentionBwdWrt::Query => 0u32,
3632                        AttentionBwdWrt::Key => 1u32,
3633                        AttentionBwdWrt::Value => 2u32,
3634                    };
3635                    schedule.push(Step::AttentionBackward {
3636                        batch,
3637                        heads,
3638                        seq_q,
3639                        seq_k,
3640                        head_dim: hd,
3641                        q_off: (arena.offset(q_id) / 4) as u32,
3642                        k_off: (arena.offset(k_id) / 4) as u32,
3643                        v_off: (arena.offset(v_id) / 4) as u32,
3644                        dy_off: (arena.offset(dy_id) / 4) as u32,
3645                        out_off: (arena.offset(node.id) / 4) as u32,
3646                        mask_off,
3647                        mask_kind: mask_kind_id,
3648                        scale_bits: scale.to_bits(),
3649                        window,
3650                        wrt: wrt_id,
3651                    });
3652                }
3653                Op::Rope { head_dim, n_rot: _ } => {
3654                    let x_id = node.inputs[0];
3655                    let cos_id = node.inputs[1];
3656                    let sin_id = node.inputs[2];
3657                    let x_shape = graph.node(x_id).shape.dims();
3658                    let last = x_shape.last().map(|d| d.unwrap_static()).unwrap_or(0);
3659                    if !last.is_multiple_of(*head_dim) {
3660                        panic!(
3661                            "rlx-cuda Rope: last_dim {} not multiple of head_dim {}",
3662                            last, head_dim
3663                        );
3664                    }
3665                    if head_dim % 2 != 0 {
3666                        panic!("rlx-cuda Rope: head_dim must be even");
3667                    }
3668                    let total: u32 = x_shape.iter().map(|d| d.unwrap_static() as u32).product();
3669                    let seq = x_shape[x_shape.len() - 2].unwrap_static() as u32;
3670                    schedule.push(Step::Rope {
3671                        n_total: total,
3672                        seq,
3673                        head_dim: *head_dim as u32,
3674                        half: (*head_dim / 2) as u32,
3675                        in_off: (arena.offset(x_id) / 4) as u32,
3676                        cos_off: (arena.offset(cos_id) / 4) as u32,
3677                        sin_off: (arena.offset(sin_id) / 4) as u32,
3678                        out_off: (arena.offset(node.id) / 4) as u32,
3679                        last_dim: last as u32,
3680                    });
3681                }
3682                Op::Cumsum { axis: _, exclusive } => {
3683                    let in_id = node.inputs[0];
3684                    let in_dims = graph.node(in_id).shape.dims();
3685                    let inner = in_dims.last().unwrap().unwrap_static() as u32;
3686                    let outer = in_dims[..in_dims.len() - 1]
3687                        .iter()
3688                        .map(|d| d.unwrap_static() as u32)
3689                        .product::<u32>()
3690                        .max(1);
3691                    schedule.push(Step::Cumsum {
3692                        outer,
3693                        inner,
3694                        in_off: (arena.offset(in_id) / 4) as u32,
3695                        out_off: (arena.offset(node.id) / 4) as u32,
3696                        exclusive: if *exclusive { 1 } else { 0 },
3697                    });
3698                }
3699                Op::TopK { k } => {
3700                    let in_id = node.inputs[0];
3701                    let in_dims = graph.node(in_id).shape.dims();
3702                    let inner = in_dims.last().unwrap().unwrap_static() as u32;
3703                    let outer = in_dims[..in_dims.len() - 1]
3704                        .iter()
3705                        .map(|d| d.unwrap_static() as u32)
3706                        .product::<u32>()
3707                        .max(1);
3708                    schedule.push(Step::TopK {
3709                        outer,
3710                        inner,
3711                        k: *k as u32,
3712                        in_off: (arena.offset(in_id) / 4) as u32,
3713                        out_off: (arena.offset(node.id) / 4) as u32,
3714                    });
3715                }
3716                Op::GroupedMatMul => {
3717                    let in_id = node.inputs[0];
3718                    let w_id = node.inputs[1];
3719                    let idx_id = node.inputs[2];
3720                    let in_dims = graph.node(in_id).shape.dims();
3721                    let w_dims = graph.node(w_id).shape.dims();
3722                    let m = in_dims[0].unwrap_static() as u32;
3723                    let k = in_dims[1].unwrap_static() as u32;
3724                    let n = w_dims[2].unwrap_static() as u32;
3725                    let ne = w_dims[0].unwrap_static() as u32;
3726                    schedule.push(Step::GroupedMatmul {
3727                        m,
3728                        k,
3729                        n,
3730                        num_experts: ne,
3731                        in_off: (arena.offset(in_id) / 4) as u32,
3732                        w_off: (arena.offset(w_id) / 4) as u32,
3733                        idx_off: (arena.offset(idx_id) / 4) as u32,
3734                        out_off: (arena.offset(node.id) / 4) as u32,
3735                    });
3736                }
3737                Op::DequantGroupedMatMul { scheme } => {
3738                    let in_id = node.inputs[0];
3739                    let w_id = node.inputs[1];
3740                    let idx_id = node.inputs[2];
3741                    let in_dims = graph.node(in_id).shape.dims();
3742                    let out_dims = node.shape.dims();
3743                    let m = in_dims[0].unwrap_static() as u32;
3744                    let k = in_dims[1].unwrap_static() as u32;
3745                    let n = out_dims[out_dims.len() - 1].unwrap_static() as u32;
3746                    let block_elems = scheme.gguf_block_size() as usize;
3747                    let block_bytes = scheme.gguf_block_bytes() as usize;
3748                    let slab_bytes = (k as usize * n as usize) / block_elems * block_bytes;
3749                    let total_bytes = graph.node(w_id).shape.num_elements().unwrap();
3750                    let ne = (total_bytes / slab_bytes.max(1)) as u32;
3751                    schedule.push(Step::DequantGroupedMatmulGguf {
3752                        m,
3753                        k,
3754                        n,
3755                        num_experts: ne,
3756                        scheme_id: crate::gguf_host::gguf_scheme_id(*scheme),
3757                        x_byte_off: arena.offset(in_id) as u32,
3758                        w_byte_off: arena.offset(w_id) as u32,
3759                        idx_byte_off: arena.offset(idx_id) as u32,
3760                        out_byte_off: arena.offset(node.id) as u32,
3761                    });
3762                }
3763                Op::ScatterAdd => {
3764                    let upd_id = node.inputs[0];
3765                    let idx_id = node.inputs[1];
3766                    let upd_dims = graph.node(upd_id).shape.dims();
3767                    let out_dims = node.shape.dims();
3768                    let num_updates = upd_dims[0].unwrap_static() as u32;
3769                    let trailing: u32 = upd_dims
3770                        .iter()
3771                        .skip(1)
3772                        .map(|d| d.unwrap_static() as u32)
3773                        .product::<u32>()
3774                        .max(1);
3775                    let out_dim = out_dims[0].unwrap_static() as u32;
3776                    let out_total = out_dim * trailing;
3777                    let out_off = (arena.offset(node.id) / 4) as u32;
3778                    schedule.push(Step::ScatterAddZero { out_off, out_total });
3779                    schedule.push(Step::ScatterAddAcc {
3780                        out_off,
3781                        upd_off: (arena.offset(upd_id) / 4) as u32,
3782                        idx_off: (arena.offset(idx_id) / 4) as u32,
3783                        num_updates,
3784                        trailing,
3785                        out_dim,
3786                    });
3787                }
3788                Op::DequantMatMul { scheme } => {
3789                    use rlx_ir::quant::QuantScheme;
3790                    let x_id = node.inputs[0];
3791                    let w_id = node.inputs[1];
3792                    let out_dims = node.shape.dims();
3793                    let x_dims = graph.node(x_id).shape.dims();
3794                    let m = out_dims[0].unwrap_static() as u32;
3795                    let n = out_dims[1].unwrap_static() as u32;
3796                    let k = x_dims[1].unwrap_static() as u32;
3797                    if scheme.is_gguf() {
3798                        schedule.push(Step::DequantMatmulGguf {
3799                            m,
3800                            k,
3801                            n,
3802                            scheme_id: crate::gguf_host::gguf_scheme_id(*scheme),
3803                            x_byte_off: arena.offset(x_id) as u32,
3804                            w_byte_off: arena.offset(w_id) as u32,
3805                            out_byte_off: arena.offset(node.id) as u32,
3806                        });
3807                    } else {
3808                        let (block_size, scheme_id) = match scheme {
3809                            QuantScheme::Int8Block { block_size } => (*block_size, 0u32),
3810                            QuantScheme::Int8BlockAsym { block_size } => (*block_size, 1u32),
3811                            QuantScheme::Int4Block { block_size } => (*block_size, 2u32),
3812                            QuantScheme::Fp8E4m3 => (1, 3u32),
3813                            QuantScheme::Fp8E5m2 => (1, 4u32),
3814                            QuantScheme::Nvfp4Block => (rlx_ir::NVFP4_GROUP_SIZE as u32, 5u32),
3815                            other => panic!("rlx-cuda DequantMatMul: unsupported scheme {other:?}"),
3816                        };
3817                        let scale_id = node.inputs[2];
3818                        let zp_id = node.inputs[3];
3819                        schedule.push(Step::DequantMatmul {
3820                            m,
3821                            k,
3822                            n,
3823                            block_size,
3824                            scheme_id,
3825                            x_off: (arena.offset(x_id) / 4) as u32,
3826                            w_off: (arena.offset(w_id) / 4) as u32,
3827                            scale_off: (arena.offset(scale_id) / 4) as u32,
3828                            zp_off: (arena.offset(zp_id) / 4) as u32,
3829                            out_off: (arena.offset(node.id) / 4) as u32,
3830                        });
3831                    }
3832                }
3833                Op::SelectiveScan { state_size } => {
3834                    if *state_size > 256 {
3835                        panic!("rlx-cuda SelectiveScan: state_size {state_size} > 256 cap");
3836                    }
3837                    let x_id = node.inputs[0];
3838                    let dt_id = node.inputs[1];
3839                    let a_id = node.inputs[2];
3840                    let b_id = node.inputs[3];
3841                    let c_id = node.inputs[4];
3842                    let in_dims = graph.node(x_id).shape.dims();
3843                    schedule.push(Step::SelectiveScan {
3844                        batch: in_dims[0].unwrap_static() as u32,
3845                        seq: in_dims[1].unwrap_static() as u32,
3846                        hidden: in_dims[2].unwrap_static() as u32,
3847                        state_size: *state_size as u32,
3848                        x_off: (arena.offset(x_id) / 4) as u32,
3849                        delta_off: (arena.offset(dt_id) / 4) as u32,
3850                        a_off: (arena.offset(a_id) / 4) as u32,
3851                        b_off: (arena.offset(b_id) / 4) as u32,
3852                        c_off: (arena.offset(c_id) / 4) as u32,
3853                        out_off: (arena.offset(node.id) / 4) as u32,
3854                    });
3855                }
3856                Op::Fft { inverse, norm } => {
3857                    let in_id = node.inputs[0];
3858                    let in_shape = graph.node(in_id).shape.clone();
3859                    let meta = rlx_ir::fft::fft_meta(&in_shape);
3860                    let dtype = in_shape.dtype();
3861                    let use_gpu = matches!(dtype, rlx_ir::DType::F32)
3862                        && meta.n_complex.is_power_of_two()
3863                        && meta.n_complex >= 2;
3864                    schedule.push(Step::Fft {
3865                        src_byte_off: arena.offset(in_id) as u32,
3866                        dst_byte_off: arena.offset(node.id) as u32,
3867                        outer: meta.outer as u32,
3868                        n_complex: meta.n_complex as u32,
3869                        inverse: *inverse,
3870                        norm_tag: norm.tag(),
3871                        dtype_tag: fft_dtype_tag(dtype),
3872                        use_gpu,
3873                    });
3874                }
3875                Op::LogMel => {
3876                    let spec_shape = graph.node(node.inputs[0]).shape.clone();
3877                    let filt_shape = graph.node(node.inputs[1]).shape.clone();
3878                    let meta = rlx_ir::audio::log_mel_meta(&spec_shape, &filt_shape)
3879                        .unwrap_or_else(|e| panic!("Op::LogMel: {e}"));
3880                    schedule.push(Step::LogMelHost {
3881                        spec_byte_off: arena.offset(node.inputs[0]) as u32,
3882                        filt_byte_off: arena.offset(node.inputs[1]) as u32,
3883                        dst_byte_off: arena.offset(node.id) as u32,
3884                        outer: meta.outer as u32,
3885                        n_fft: meta.n_fft as u32,
3886                        n_bins: meta.n_bins as u32,
3887                        n_mels: meta.n_mels as u32,
3888                    });
3889                }
3890                Op::LogMelBackward => {
3891                    let spec_shape = graph.node(node.inputs[0]).shape.clone();
3892                    let filt_shape = graph.node(node.inputs[1]).shape.clone();
3893                    let meta = rlx_ir::audio::log_mel_meta(&spec_shape, &filt_shape)
3894                        .unwrap_or_else(|e| panic!("Op::LogMelBackward: {e}"));
3895                    schedule.push(Step::LogMelBackwardHost {
3896                        spec_byte_off: arena.offset(node.inputs[0]) as u32,
3897                        filt_byte_off: arena.offset(node.inputs[1]) as u32,
3898                        dy_byte_off: arena.offset(node.inputs[2]) as u32,
3899                        dst_byte_off: arena.offset(node.id) as u32,
3900                        outer: meta.outer as u32,
3901                        n_fft: meta.n_fft as u32,
3902                        n_bins: meta.n_bins as u32,
3903                        n_mels: meta.n_mels as u32,
3904                    });
3905                }
3906                Op::WelchPeaks { k, n_segments } => {
3907                    let spec_shape = graph.node(node.inputs[0]).shape.clone();
3908                    let meta = rlx_ir::audio::welch_peaks_meta(&spec_shape, *k, *n_segments)
3909                        .unwrap_or_else(|e| panic!("Op::WelchPeaks: {e}"));
3910                    let use_gpu = rlx_ir::audio::welch_peaks_gpu_native_eligible(
3911                        &spec_shape,
3912                        *k,
3913                        *n_segments,
3914                    )
3915                    .unwrap_or(false);
3916                    if use_gpu {
3917                        schedule.push(Step::WelchPeaksGpu {
3918                            spec_off: (arena.offset(node.inputs[0]) / 4) as u32,
3919                            dst_off: (arena.offset(node.id) / 4) as u32,
3920                            welch_batch: meta.welch_batch as u32,
3921                            n_fft: meta.n_fft as u32,
3922                            n_segments: meta.n_segments as u32,
3923                            k: meta.k as u32,
3924                            n_bins: meta.n_bins as u32,
3925                        });
3926                    } else {
3927                        schedule.push(Step::WelchPeaksHost {
3928                            spec_byte_off: arena.offset(node.inputs[0]) as u32,
3929                            dst_byte_off: arena.offset(node.id) as u32,
3930                            welch_batch: meta.welch_batch as u32,
3931                            n_fft: meta.n_fft as u32,
3932                            n_segments: meta.n_segments as u32,
3933                            k: meta.k as u32,
3934                        });
3935                    }
3936                }
3937                Op::Im2Col {
3938                    kernel_size,
3939                    stride,
3940                    padding,
3941                    dilation,
3942                } => {
3943                    let x_shape = &graph.node(node.inputs[0]).shape;
3944                    if kernel_size.len() != 2 || x_shape.rank() != 4 {
3945                        panic!("rlx-cuda Im2Col: 2D NCHW only");
3946                    }
3947                    let n = match x_shape.dim(0) {
3948                        rlx_ir::shape::Dim::Static(v) => v as u32,
3949                        _ => 0,
3950                    };
3951                    let c_in = x_shape.dim(1).unwrap_static() as u32;
3952                    let h = x_shape.dim(2).unwrap_static() as u32;
3953                    let w = x_shape.dim(3).unwrap_static() as u32;
3954                    let kh = kernel_size[0] as u32;
3955                    let kw = kernel_size[1] as u32;
3956                    let sh = stride.first().copied().unwrap_or(1) as u32;
3957                    let sw = stride.get(1).copied().unwrap_or(1) as u32;
3958                    let ph = padding.first().copied().unwrap_or(0) as u32;
3959                    let pw = padding.get(1).copied().unwrap_or(0) as u32;
3960                    let dh = dilation.first().copied().unwrap_or(1) as u32;
3961                    let dw_dil = dilation.get(1).copied().unwrap_or(1) as u32;
3962                    let h_out = rlx_ir::shape::conv2d_spatial_output(
3963                        h as usize,
3964                        kh as usize,
3965                        sh as usize,
3966                        ph as usize,
3967                        dh as usize,
3968                    ) as u32;
3969                    let w_out = rlx_ir::shape::conv2d_spatial_output(
3970                        w as usize,
3971                        kw as usize,
3972                        sw as usize,
3973                        pw as usize,
3974                        dw_dil as usize,
3975                    ) as u32;
3976                    schedule.push(Step::Im2ColHost {
3977                        x_byte_off: arena.offset(node.inputs[0]) as u32,
3978                        col_byte_off: arena.offset(node.id) as u32,
3979                        n,
3980                        c_in,
3981                        h,
3982                        w,
3983                        h_out,
3984                        w_out,
3985                        kh,
3986                        kw,
3987                        sh,
3988                        sw,
3989                        ph,
3990                        pw,
3991                        dh,
3992                        dw_dil,
3993                        use_gpu: im2col_use_gpu(n, exec_mode),
3994                    });
3995                }
3996                Op::GatedDeltaNet {
3997                    state_size,
3998                    carry_state,
3999                } => {
4000                    if *state_size > rlx_cpu::gdn::GDN_MAX_STATE {
4001                        panic!(
4002                            "rlx-cuda GatedDeltaNet: state_size {state_size} > {}",
4003                            rlx_cpu::gdn::GDN_MAX_STATE
4004                        );
4005                    }
4006                    let q_id = node.inputs[0];
4007                    let q_shape = &graph.node(q_id).shape;
4008                    let state_off = if *carry_state {
4009                        arena.offset(node.inputs[5])
4010                    } else {
4011                        0
4012                    };
4013                    schedule.push(Step::GatedDeltaNet {
4014                        q_byte_off: arena.offset(q_id) as u32,
4015                        k_byte_off: arena.offset(node.inputs[1]) as u32,
4016                        v_byte_off: arena.offset(node.inputs[2]) as u32,
4017                        g_byte_off: arena.offset(node.inputs[3]) as u32,
4018                        beta_byte_off: arena.offset(node.inputs[4]) as u32,
4019                        state_byte_off: state_off as u32,
4020                        dst_byte_off: arena.offset(node.id) as u32,
4021                        batch: q_shape.dim(0).unwrap_static() as u32,
4022                        seq: q_shape.dim(1).unwrap_static() as u32,
4023                        heads: q_shape.dim(2).unwrap_static() as u32,
4024                        state_size: *state_size as u32,
4025                        use_carry: *carry_state,
4026                    });
4027                }
4028                Op::Custom { name, attrs, .. } => match name.as_str() {
4029                    "llada2.group_limited_gate" => {
4030                        let sig_id = node.inputs[0];
4031                        let route_id = node.inputs[1];
4032                        let n_elems = graph.node(sig_id).shape.num_elements().unwrap() as u32;
4033                        let mut attr_buf = [0u8; 20];
4034                        let n = attrs.len().min(20);
4035                        attr_buf[..n].copy_from_slice(&attrs[..n]);
4036                        schedule.push(Step::Llada2GroupLimitedGate {
4037                            sig_off: (arena.offset(sig_id) / 4) as u32,
4038                            route_off: (arena.offset(route_id) / 4) as u32,
4039                            out_off: (arena.offset(node.id) / 4) as u32,
4040                            n_elems,
4041                            attrs: attr_buf,
4042                        });
4043                    }
4044                    "umap.knn" => {
4045                        let pw_id = node.inputs[0];
4046                        let n = graph.node(pw_id).shape.dims()[0].unwrap_static() as u32;
4047                        let k = u32::from_le_bytes(attrs[..4].try_into().unwrap());
4048                        schedule.push(Step::UmapKnn {
4049                            pairwise_off: (arena.offset(pw_id) / 4) as u32,
4050                            out_off: (arena.offset(node.id) / 4) as u32,
4051                            n,
4052                            k,
4053                        });
4054                    }
4055                    other => panic!("rlx-cuda: unsupported Op::Custom('{other}')"),
4056                },
4057
4058                Op::GaussianSplatRender {
4059                    width,
4060                    height,
4061                    tile_size,
4062                    radius_scale,
4063                    alpha_cutoff,
4064                    max_splat_steps,
4065                    transmittance_threshold,
4066                    max_list_entries,
4067                } => {
4068                    let elem_len = |id: NodeId| -> u32 {
4069                        graph.node(id).shape.num_elements().unwrap_or(0) as u32
4070                    };
4071                    schedule.push(Step::GaussianSplatRender {
4072                        positions_off: arena.offset(node.inputs[0]) as u32,
4073                        positions_len: elem_len(node.inputs[0]),
4074                        scales_off: arena.offset(node.inputs[1]) as u32,
4075                        scales_len: elem_len(node.inputs[1]),
4076                        rotations_off: arena.offset(node.inputs[2]) as u32,
4077                        rotations_len: elem_len(node.inputs[2]),
4078                        opacities_off: arena.offset(node.inputs[3]) as u32,
4079                        opacities_len: elem_len(node.inputs[3]),
4080                        colors_off: arena.offset(node.inputs[4]) as u32,
4081                        colors_len: elem_len(node.inputs[4]),
4082                        sh_coeffs_off: arena.offset(node.inputs[5]) as u32,
4083                        sh_coeffs_len: elem_len(node.inputs[5]),
4084                        meta_off: arena.offset(node.inputs[6]) as u32,
4085                        dst_off: arena.offset(node.id) as u32,
4086                        dst_len: node.shape.num_elements().unwrap_or(0) as u32,
4087                        width: *width,
4088                        height: *height,
4089                        tile_size: *tile_size,
4090                        radius_scale: *radius_scale,
4091                        alpha_cutoff: *alpha_cutoff,
4092                        max_splat_steps: *max_splat_steps,
4093                        transmittance_threshold: *transmittance_threshold,
4094                        max_list_entries: *max_list_entries,
4095                    });
4096                }
4097
4098                Op::GaussianSplatRenderBackward {
4099                    width,
4100                    height,
4101                    tile_size,
4102                    radius_scale,
4103                    alpha_cutoff,
4104                    max_splat_steps,
4105                    transmittance_threshold,
4106                    max_list_entries,
4107                    loss_grad_clip,
4108                    sh_band,
4109                    max_anisotropy,
4110                } => {
4111                    let elem_len = |id: NodeId| -> u32 {
4112                        graph.node(id).shape.num_elements().unwrap_or(0) as u32
4113                    };
4114                    schedule.push(Step::GaussianSplatRenderBackward {
4115                        positions_off: arena.offset(node.inputs[0]) as u32,
4116                        positions_len: elem_len(node.inputs[0]),
4117                        scales_off: arena.offset(node.inputs[1]) as u32,
4118                        scales_len: elem_len(node.inputs[1]),
4119                        rotations_off: arena.offset(node.inputs[2]) as u32,
4120                        rotations_len: elem_len(node.inputs[2]),
4121                        opacities_off: arena.offset(node.inputs[3]) as u32,
4122                        opacities_len: elem_len(node.inputs[3]),
4123                        colors_off: arena.offset(node.inputs[4]) as u32,
4124                        colors_len: elem_len(node.inputs[4]),
4125                        sh_coeffs_off: arena.offset(node.inputs[5]) as u32,
4126                        sh_coeffs_len: elem_len(node.inputs[5]),
4127                        meta_off: arena.offset(node.inputs[6]) as u32,
4128                        d_loss_off: arena.offset(node.inputs[7]) as u32,
4129                        d_loss_len: elem_len(node.inputs[7]),
4130                        packed_off: arena.offset(node.id) as u32,
4131                        packed_len: node.shape.num_elements().unwrap_or(0) as u32,
4132                        width: *width,
4133                        height: *height,
4134                        tile_size: *tile_size,
4135                        radius_scale: *radius_scale,
4136                        alpha_cutoff: *alpha_cutoff,
4137                        max_splat_steps: *max_splat_steps,
4138                        transmittance_threshold: *transmittance_threshold,
4139                        max_list_entries: *max_list_entries,
4140                        loss_grad_clip: *loss_grad_clip,
4141                        sh_band: *sh_band,
4142                        max_anisotropy: *max_anisotropy,
4143                    });
4144                }
4145
4146                Op::GaussianSplatPrepare {
4147                    width,
4148                    height,
4149                    tile_size,
4150                    radius_scale,
4151                    alpha_cutoff,
4152                    max_splat_steps,
4153                    transmittance_threshold,
4154                    max_list_entries,
4155                } => {
4156                    let elem_len = |id: NodeId| -> u32 {
4157                        graph.node(id).shape.num_elements().unwrap_or(0) as u32
4158                    };
4159                    schedule.push(Step::GaussianSplatPrepare {
4160                        positions_off: arena.offset(node.inputs[0]) as u32,
4161                        positions_len: elem_len(node.inputs[0]),
4162                        scales_off: arena.offset(node.inputs[1]) as u32,
4163                        scales_len: elem_len(node.inputs[1]),
4164                        rotations_off: arena.offset(node.inputs[2]) as u32,
4165                        rotations_len: elem_len(node.inputs[2]),
4166                        opacities_off: arena.offset(node.inputs[3]) as u32,
4167                        opacities_len: elem_len(node.inputs[3]),
4168                        colors_off: arena.offset(node.inputs[4]) as u32,
4169                        colors_len: elem_len(node.inputs[4]),
4170                        sh_coeffs_off: arena.offset(node.inputs[5]) as u32,
4171                        sh_coeffs_len: elem_len(node.inputs[5]),
4172                        meta_off: arena.offset(node.inputs[6]) as u32,
4173                        meta_len: elem_len(node.inputs[6]),
4174                        prep_off: arena.offset(node.id) as u32,
4175                        prep_len: node.shape.num_elements().unwrap_or(0) as u32,
4176                        width: *width,
4177                        height: *height,
4178                        tile_size: *tile_size,
4179                        radius_scale: *radius_scale,
4180                        alpha_cutoff: *alpha_cutoff,
4181                        max_splat_steps: *max_splat_steps,
4182                        transmittance_threshold: *transmittance_threshold,
4183                        max_list_entries: *max_list_entries,
4184                    });
4185                }
4186
4187                Op::GaussianSplatRasterize {
4188                    width,
4189                    height,
4190                    tile_size,
4191                    alpha_cutoff,
4192                    max_splat_steps,
4193                    transmittance_threshold,
4194                    max_list_entries,
4195                } => {
4196                    let elem_len = |id: NodeId| -> u32 {
4197                        graph.node(id).shape.num_elements().unwrap_or(0) as u32
4198                    };
4199                    let prep_id = node.inputs[0];
4200                    let count = match &graph.node(prep_id).op {
4201                        rlx_ir::Op::GaussianSplatPrepare { .. } => {
4202                            elem_len(graph.node(prep_id).inputs[0]) / 3
4203                        }
4204                        _ => 1,
4205                    };
4206                    schedule.push(Step::GaussianSplatRasterize {
4207                        prep_off: arena.offset(prep_id) as u32,
4208                        prep_len: elem_len(prep_id),
4209                        meta_off: arena.offset(node.inputs[1]) as u32,
4210                        meta_len: elem_len(node.inputs[1]),
4211                        dst_off: arena.offset(node.id) as u32,
4212                        dst_len: node.shape.num_elements().unwrap_or(0) as u32,
4213                        count,
4214                        width: *width,
4215                        height: *height,
4216                        tile_size: *tile_size,
4217                        alpha_cutoff: *alpha_cutoff,
4218                        max_splat_steps: *max_splat_steps,
4219                        transmittance_threshold: *transmittance_threshold,
4220                        max_list_entries: *max_list_entries,
4221                    });
4222                }
4223
4224                Op::Pool {
4225                    kind,
4226                    kernel_size,
4227                    stride,
4228                    padding,
4229                } => {
4230                    let in_id = node.inputs[0];
4231                    let in_dims = graph.node(in_id).shape.dims();
4232                    let out_dims = node.shape.dims();
4233                    let op_id = reduce_op_id(*kind);
4234                    let in_off = (arena.offset(in_id) / 4) as u32;
4235                    let out_off = (arena.offset(node.id) / 4) as u32;
4236                    match kernel_size.len() {
4237                        1 => {
4238                            schedule.push(Step::Pool1d {
4239                                n: in_dims[0].unwrap_static() as u32,
4240                                c: in_dims[1].unwrap_static() as u32,
4241                                l: in_dims[2].unwrap_static() as u32,
4242                                l_out: out_dims[2].unwrap_static() as u32,
4243                                kl: kernel_size[0] as u32,
4244                                sl: stride[0] as u32,
4245                                pl: padding[0] as u32,
4246                                op: op_id,
4247                                in_off,
4248                                out_off,
4249                            });
4250                        }
4251                        2 => {
4252                            schedule.push(Step::Pool2d {
4253                                n: in_dims[0].unwrap_static() as u32,
4254                                c: in_dims[1].unwrap_static() as u32,
4255                                h: in_dims[2].unwrap_static() as u32,
4256                                w: in_dims[3].unwrap_static() as u32,
4257                                h_out: out_dims[2].unwrap_static() as u32,
4258                                w_out: out_dims[3].unwrap_static() as u32,
4259                                kh: kernel_size[0] as u32,
4260                                kw: kernel_size[1] as u32,
4261                                sh: stride[0] as u32,
4262                                sw: stride[1] as u32,
4263                                ph: padding[0] as u32,
4264                                pw: padding[1] as u32,
4265                                op: op_id,
4266                                in_off,
4267                                out_off,
4268                            });
4269                        }
4270                        3 => {
4271                            schedule.push(Step::Pool3d {
4272                                n: in_dims[0].unwrap_static() as u32,
4273                                c: in_dims[1].unwrap_static() as u32,
4274                                d: in_dims[2].unwrap_static() as u32,
4275                                h: in_dims[3].unwrap_static() as u32,
4276                                w: in_dims[4].unwrap_static() as u32,
4277                                d_out: out_dims[2].unwrap_static() as u32,
4278                                h_out: out_dims[3].unwrap_static() as u32,
4279                                w_out: out_dims[4].unwrap_static() as u32,
4280                                kd: kernel_size[0] as u32,
4281                                kh: kernel_size[1] as u32,
4282                                kw: kernel_size[2] as u32,
4283                                sd: stride[0] as u32,
4284                                sh: stride[1] as u32,
4285                                sw: stride[2] as u32,
4286                                pd: padding[0] as u32,
4287                                ph: padding[1] as u32,
4288                                pw: padding[2] as u32,
4289                                op: op_id,
4290                                in_off,
4291                                out_off,
4292                            });
4293                        }
4294                        other => panic!("rlx-cuda Pool: unsupported kernel rank {other}"),
4295                    }
4296                }
4297                Op::LayerNorm2d { eps } => {
4298                    let in_shape = &graph.node(node.inputs[0]).shape;
4299                    schedule.push(Step::LayerNorm2d {
4300                        src_off: (arena.offset(node.inputs[0]) / 4) as u32,
4301                        g_off: (arena.offset(node.inputs[1]) / 4) as u32,
4302                        b_off: (arena.offset(node.inputs[2]) / 4) as u32,
4303                        dst_off: (arena.offset(node.id) / 4) as u32,
4304                        n: in_shape.dim(0).unwrap_static() as u32,
4305                        c: in_shape.dim(1).unwrap_static() as u32,
4306                        h: in_shape.dim(2).unwrap_static() as u32,
4307                        w: in_shape.dim(3).unwrap_static() as u32,
4308                        eps_bits: eps.to_bits(),
4309                    });
4310                }
4311                Op::ConvTranspose2d {
4312                    kernel_size,
4313                    stride,
4314                    padding,
4315                    dilation,
4316                    output_padding: _,
4317                    groups,
4318                } => {
4319                    let in_shape = &graph.node(node.inputs[0]).shape;
4320                    let out_shape = &node.shape;
4321                    schedule.push(Step::ConvTranspose2d {
4322                        src_off: (arena.offset(node.inputs[0]) / 4) as u32,
4323                        w_off: (arena.offset(node.inputs[1]) / 4) as u32,
4324                        dst_off: (arena.offset(node.id) / 4) as u32,
4325                        n: in_shape.dim(0).unwrap_static() as u32,
4326                        c_in: in_shape.dim(1).unwrap_static() as u32,
4327                        h: in_shape.dim(2).unwrap_static() as u32,
4328                        w_in: in_shape.dim(3).unwrap_static() as u32,
4329                        c_out: out_shape.dim(1).unwrap_static() as u32,
4330                        h_out: out_shape.dim(2).unwrap_static() as u32,
4331                        w_out: out_shape.dim(3).unwrap_static() as u32,
4332                        kh: kernel_size[0] as u32,
4333                        kw: kernel_size[1] as u32,
4334                        sh: stride.first().copied().unwrap_or(1) as u32,
4335                        sw: stride.get(1).copied().unwrap_or(1) as u32,
4336                        ph: padding.first().copied().unwrap_or(0) as u32,
4337                        pw: padding.get(1).copied().unwrap_or(0) as u32,
4338                        dh: dilation.first().copied().unwrap_or(1) as u32,
4339                        dw: dilation.get(1).copied().unwrap_or(1) as u32,
4340                        groups: *groups as u32,
4341                    });
4342                }
4343                Op::GroupNorm { num_groups, eps } => {
4344                    let in_shape = &graph.node(node.inputs[0]).shape;
4345                    schedule.push(Step::GroupNorm {
4346                        src_off: (arena.offset(node.inputs[0]) / 4) as u32,
4347                        g_off: (arena.offset(node.inputs[1]) / 4) as u32,
4348                        b_off: (arena.offset(node.inputs[2]) / 4) as u32,
4349                        dst_off: (arena.offset(node.id) / 4) as u32,
4350                        n: in_shape.dim(0).unwrap_static() as u32,
4351                        c: in_shape.dim(1).unwrap_static() as u32,
4352                        h: in_shape.dim(2).unwrap_static() as u32,
4353                        w: in_shape.dim(3).unwrap_static() as u32,
4354                        num_groups: *num_groups as u32,
4355                        eps_bits: eps.to_bits(),
4356                    });
4357                }
4358                Op::ResizeNearest2x => {
4359                    let in_shape = &graph.node(node.inputs[0]).shape;
4360                    schedule.push(Step::ResizeNearest2x {
4361                        src_off: (arena.offset(node.inputs[0]) / 4) as u32,
4362                        dst_off: (arena.offset(node.id) / 4) as u32,
4363                        n: in_shape.dim(0).unwrap_static() as u32,
4364                        c: in_shape.dim(1).unwrap_static() as u32,
4365                        h: in_shape.dim(2).unwrap_static() as u32,
4366                        w: in_shape.dim(3).unwrap_static() as u32,
4367                    });
4368                }
4369                Op::Conv {
4370                    kernel_size,
4371                    stride,
4372                    padding,
4373                    dilation,
4374                    groups,
4375                } => {
4376                    let in_id = node.inputs[0];
4377                    let w_id = node.inputs[1];
4378                    let in_dims = graph.node(in_id).shape.dims();
4379                    let w_dims = graph.node(w_id).shape.dims();
4380                    let out_dims = node.shape.dims();
4381                    let in_off = (arena.offset(in_id) / 4) as u32;
4382                    let w_off = (arena.offset(w_id) / 4) as u32;
4383                    let out_off = (arena.offset(node.id) / 4) as u32;
4384                    match kernel_size.len() {
4385                        1 => {
4386                            schedule.push(Step::Conv1d {
4387                                n: in_dims[0].unwrap_static() as u32,
4388                                c_in: in_dims[1].unwrap_static() as u32,
4389                                c_out: w_dims[0].unwrap_static() as u32,
4390                                l: in_dims[2].unwrap_static() as u32,
4391                                l_out: out_dims[2].unwrap_static() as u32,
4392                                kl: kernel_size[0] as u32,
4393                                sl: stride[0] as u32,
4394                                pl: padding[0] as u32,
4395                                dl: dilation[0] as u32,
4396                                groups: *groups as u32,
4397                                in_off,
4398                                w_off,
4399                                out_off,
4400                            });
4401                        }
4402                        2 => {
4403                            schedule.push(Step::Conv2d {
4404                                n: in_dims[0].unwrap_static() as u32,
4405                                c_in: in_dims[1].unwrap_static() as u32,
4406                                c_out: w_dims[0].unwrap_static() as u32,
4407                                h: in_dims[2].unwrap_static() as u32,
4408                                w: in_dims[3].unwrap_static() as u32,
4409                                h_out: out_dims[2].unwrap_static() as u32,
4410                                w_out: out_dims[3].unwrap_static() as u32,
4411                                kh: kernel_size[0] as u32,
4412                                kw: kernel_size[1] as u32,
4413                                sh: stride[0] as u32,
4414                                sw: stride[1] as u32,
4415                                ph: padding[0] as u32,
4416                                pw: padding[1] as u32,
4417                                dh: dilation[0] as u32,
4418                                dw: dilation[1] as u32,
4419                                groups: *groups as u32,
4420                                in_off,
4421                                w_off,
4422                                out_off,
4423                            });
4424                        }
4425                        3 => {
4426                            schedule.push(Step::Conv3d {
4427                                n: in_dims[0].unwrap_static() as u32,
4428                                c_in: in_dims[1].unwrap_static() as u32,
4429                                c_out: w_dims[0].unwrap_static() as u32,
4430                                d: in_dims[2].unwrap_static() as u32,
4431                                h: in_dims[3].unwrap_static() as u32,
4432                                w: in_dims[4].unwrap_static() as u32,
4433                                d_out: out_dims[2].unwrap_static() as u32,
4434                                h_out: out_dims[3].unwrap_static() as u32,
4435                                w_out: out_dims[4].unwrap_static() as u32,
4436                                kd: kernel_size[0] as u32,
4437                                kh: kernel_size[1] as u32,
4438                                kw: kernel_size[2] as u32,
4439                                sd: stride[0] as u32,
4440                                sh: stride[1] as u32,
4441                                sw: stride[2] as u32,
4442                                pd: padding[0] as u32,
4443                                ph: padding[1] as u32,
4444                                pw: padding[2] as u32,
4445                                dd: dilation[0] as u32,
4446                                dh: dilation[1] as u32,
4447                                dw: dilation[2] as u32,
4448                                groups: *groups as u32,
4449                                in_off,
4450                                w_off,
4451                                out_off,
4452                            });
4453                        }
4454                        other => panic!("rlx-cuda Conv: unsupported kernel rank {other}"),
4455                    }
4456                }
4457                Op::Sample {
4458                    top_k,
4459                    top_p,
4460                    temperature,
4461                    seed,
4462                } => {
4463                    let in_id = node.inputs[0];
4464                    let in_dims = graph.node(in_id).shape.dims();
4465                    let inner = in_dims.last().unwrap().unwrap_static() as u32;
4466                    let outer = in_dims[..in_dims.len() - 1]
4467                        .iter()
4468                        .map(|d| d.unwrap_static() as u32)
4469                        .product::<u32>()
4470                        .max(1);
4471                    let is_greedy = *top_k == 0
4472                        && (*top_p - 1.0).abs() < 1e-6
4473                        && (*temperature - 1.0).abs() < 1e-6;
4474                    if is_greedy {
4475                        schedule.push(Step::Argmax {
4476                            outer,
4477                            inner,
4478                            in_off: (arena.offset(in_id) / 4) as u32,
4479                            out_off: (arena.offset(node.id) / 4) as u32,
4480                        });
4481                    } else {
4482                        schedule.push(Step::Sample {
4483                            outer,
4484                            inner,
4485                            in_off: (arena.offset(in_id) / 4) as u32,
4486                            out_off: (arena.offset(node.id) / 4) as u32,
4487                            top_k: *top_k as u32,
4488                            top_p_bits: top_p.to_bits(),
4489                            temp_bits: temperature.to_bits(),
4490                            seed_lo: *seed as u32,
4491                            seed_hi: (*seed >> 32) as u32,
4492                        });
4493                    }
4494                }
4495                Op::RmsNormBackwardInput { eps, .. }
4496                | Op::RmsNormBackwardGamma { eps, .. }
4497                | Op::RmsNormBackwardBeta { eps, .. } => {
4498                    let x_shape = &graph.node(node.inputs[0]).shape;
4499                    let h = x_shape.dim(x_shape.rank() - 1).unwrap_static() as u32;
4500                    let rows = (x_shape.num_elements().unwrap() / h.max(1) as usize) as u32;
4501                    let eps_bits = eps.to_bits();
4502                    let off = |i: usize| arena.offset(node.inputs[i]) as u32;
4503                    let common = (off(0), off(1), off(2), off(3), rows, h, eps_bits);
4504                    match &node.op {
4505                        Op::RmsNormBackwardInput { .. } => {
4506                            schedule.push(Step::RmsNormBackwardInput {
4507                                x_byte_off: common.0,
4508                                gamma_byte_off: common.1,
4509                                beta_byte_off: common.2,
4510                                dy_byte_off: common.3,
4511                                dx_byte_off: arena.offset(node.id) as u32,
4512                                rows: common.4,
4513                                h: common.5,
4514                                eps_bits: common.6,
4515                            });
4516                        }
4517                        Op::RmsNormBackwardGamma { .. } => {
4518                            schedule.push(Step::RmsNormBackwardGamma {
4519                                x_byte_off: common.0,
4520                                gamma_byte_off: common.1,
4521                                beta_byte_off: common.2,
4522                                dy_byte_off: common.3,
4523                                dgamma_byte_off: arena.offset(node.id) as u32,
4524                                rows: common.4,
4525                                h: common.5,
4526                                eps_bits: common.6,
4527                            });
4528                        }
4529                        Op::RmsNormBackwardBeta { .. } => {
4530                            schedule.push(Step::RmsNormBackwardBeta {
4531                                x_byte_off: common.0,
4532                                gamma_byte_off: common.1,
4533                                beta_byte_off: common.2,
4534                                dy_byte_off: common.3,
4535                                dbeta_byte_off: arena.offset(node.id) as u32,
4536                                rows: common.4,
4537                                h: common.5,
4538                                eps_bits: common.6,
4539                            });
4540                        }
4541                        _ => unreachable!(),
4542                    }
4543                }
4544                Op::RopeBackward { head_dim, n_rot } => {
4545                    let dy_shape = &graph.node(node.inputs[0]).shape;
4546                    let (batch, seq, hidden) = if dy_shape.rank() >= 3 {
4547                        (
4548                            dy_shape.dim(0).unwrap_static() as u32,
4549                            dy_shape.dim(1).unwrap_static() as u32,
4550                            dy_shape.dim(2).unwrap_static() as u32,
4551                        )
4552                    } else {
4553                        (
4554                            1,
4555                            dy_shape.dim(0).unwrap_static() as u32,
4556                            dy_shape.dim(1).unwrap_static() as u32,
4557                        )
4558                    };
4559                    let cos_len = graph.node(node.inputs[1]).shape.num_elements().unwrap() as u32;
4560                    schedule.push(Step::RopeBackward {
4561                        dy_byte_off: arena.offset(node.inputs[0]) as u32,
4562                        cos_byte_off: arena.offset(node.inputs[1]) as u32,
4563                        sin_byte_off: arena.offset(node.inputs[2]) as u32,
4564                        dx_byte_off: arena.offset(node.id) as u32,
4565                        batch,
4566                        seq,
4567                        hidden,
4568                        head_dim: *head_dim as u32,
4569                        n_rot: *n_rot as u32,
4570                        cos_len,
4571                    });
4572                }
4573                Op::CumsumBackward { exclusive, .. } => {
4574                    let dy_shape = &graph.node(node.inputs[0]).shape;
4575                    let cols = dy_shape.dim(dy_shape.rank() - 1).unwrap_static() as u32;
4576                    let rows = (dy_shape.num_elements().unwrap() / cols.max(1) as usize) as u32;
4577                    schedule.push(Step::CumsumBackward {
4578                        dy_byte_off: arena.offset(node.inputs[0]) as u32,
4579                        dx_byte_off: arena.offset(node.id) as u32,
4580                        rows,
4581                        cols,
4582                        exclusive: *exclusive,
4583                    });
4584                }
4585                Op::GatherBackward { .. } => {
4586                    let dy_shape = &graph.node(node.inputs[0]).shape;
4587                    let idx_shape = &graph.node(node.inputs[1]).shape;
4588                    let out_shape = &node.shape;
4589                    let rank = out_shape.rank();
4590                    let axis = match &node.op {
4591                        Op::GatherBackward { axis } => *axis,
4592                        _ => 0,
4593                    };
4594                    let axis_u = if axis < 0 {
4595                        (rank as i32 + axis) as usize
4596                    } else {
4597                        axis as usize
4598                    };
4599                    let outer: usize = (0..axis_u)
4600                        .map(|i| dy_shape.dim(i).unwrap_static())
4601                        .product::<usize>()
4602                        .max(1);
4603                    let num_idx = idx_shape.dim(axis_u).unwrap_static();
4604                    let trailing: usize = (axis_u + 1..dy_shape.rank())
4605                        .map(|i| dy_shape.dim(i).unwrap_static())
4606                        .product::<usize>()
4607                        .max(1);
4608                    let axis_dim = out_shape.dim(axis_u).unwrap_static();
4609                    schedule.push(Step::GatherBackward {
4610                        dy_byte_off: arena.offset(node.inputs[0]) as u32,
4611                        indices_byte_off: arena.offset(node.inputs[1]) as u32,
4612                        dst_byte_off: arena.offset(node.id) as u32,
4613                        outer: outer as u32,
4614                        axis_dim: axis_dim as u32,
4615                        num_idx: num_idx as u32,
4616                        trailing: trailing as u32,
4617                    });
4618                }
4619                Op::Conv2dBackwardInput {
4620                    kernel_size,
4621                    stride,
4622                    padding,
4623                    dilation,
4624                    groups,
4625                } => {
4626                    let dy_shape = &graph.node(node.inputs[0]).shape;
4627                    let out_shape = &node.shape;
4628                    if kernel_size.len() == 2 && dy_shape.rank() == 4 && out_shape.rank() == 4 {
4629                        schedule.push(Step::Conv2dBackwardInput {
4630                            dy_byte_off: arena.offset(node.inputs[0]) as u32,
4631                            w_byte_off: arena.offset(node.inputs[1]) as u32,
4632                            dx_byte_off: arena.offset(node.id) as u32,
4633                            n: out_shape.dim(0).unwrap_static() as u32,
4634                            c_in: out_shape.dim(1).unwrap_static() as u32,
4635                            h: out_shape.dim(2).unwrap_static() as u32,
4636                            w_in: out_shape.dim(3).unwrap_static() as u32,
4637                            c_out: dy_shape.dim(1).unwrap_static() as u32,
4638                            h_out: dy_shape.dim(2).unwrap_static() as u32,
4639                            w_out: dy_shape.dim(3).unwrap_static() as u32,
4640                            kh: kernel_size[0] as u32,
4641                            kw: kernel_size[1] as u32,
4642                            sh: stride.first().copied().unwrap_or(1) as u32,
4643                            sw: stride.get(1).copied().unwrap_or(1) as u32,
4644                            ph: padding.first().copied().unwrap_or(0) as u32,
4645                            pw: padding.get(1).copied().unwrap_or(0) as u32,
4646                            dh: dilation.first().copied().unwrap_or(1) as u32,
4647                            dw: dilation.get(1).copied().unwrap_or(1) as u32,
4648                            groups: *groups as u32,
4649                        });
4650                    } else {
4651                        panic!("rlx-cuda: Conv2dBackwardInput expects 2-D conv on NCHW tensors");
4652                    }
4653                }
4654                Op::Conv2dBackwardWeight {
4655                    kernel_size,
4656                    stride,
4657                    padding,
4658                    dilation,
4659                    groups,
4660                } => {
4661                    let x_shape = &graph.node(node.inputs[0]).shape;
4662                    let dy_shape = &graph.node(node.inputs[1]).shape;
4663                    if kernel_size.len() == 2 && x_shape.rank() == 4 && dy_shape.rank() == 4 {
4664                        schedule.push(Step::Conv2dBackwardWeight {
4665                            x_byte_off: arena.offset(node.inputs[0]) as u32,
4666                            dy_byte_off: arena.offset(node.inputs[1]) as u32,
4667                            dw_byte_off: arena.offset(node.id) as u32,
4668                            n: x_shape.dim(0).unwrap_static() as u32,
4669                            c_in: x_shape.dim(1).unwrap_static() as u32,
4670                            h: x_shape.dim(2).unwrap_static() as u32,
4671                            w: x_shape.dim(3).unwrap_static() as u32,
4672                            c_out: dy_shape.dim(1).unwrap_static() as u32,
4673                            h_out: dy_shape.dim(2).unwrap_static() as u32,
4674                            w_out: dy_shape.dim(3).unwrap_static() as u32,
4675                            kh: kernel_size[0] as u32,
4676                            kw: kernel_size[1] as u32,
4677                            sh: stride.first().copied().unwrap_or(1) as u32,
4678                            sw: stride.get(1).copied().unwrap_or(1) as u32,
4679                            ph: padding.first().copied().unwrap_or(0) as u32,
4680                            pw: padding.get(1).copied().unwrap_or(0) as u32,
4681                            dh: dilation.first().copied().unwrap_or(1) as u32,
4682                            dw_dil: dilation.get(1).copied().unwrap_or(1) as u32,
4683                            groups: *groups as u32,
4684                        });
4685                    } else {
4686                        panic!("rlx-cuda: Conv2dBackwardWeight expects 2-D conv on NCHW tensors");
4687                    }
4688                }
4689                Op::MaxPool2dBackward {
4690                    kernel_size,
4691                    stride,
4692                    padding,
4693                } => {
4694                    let x_shape = &graph.node(node.inputs[0]).shape;
4695                    let dy_shape = &graph.node(node.inputs[1]).shape;
4696                    if kernel_size.len() == 2 && x_shape.rank() == 4 && dy_shape.rank() == 4 {
4697                        schedule.push(Step::MaxPool2dBackward {
4698                            x_byte_off: arena.offset(node.inputs[0]) as u32,
4699                            dy_byte_off: arena.offset(node.inputs[1]) as u32,
4700                            dx_byte_off: arena.offset(node.id) as u32,
4701                            n: x_shape.dim(0).unwrap_static() as u32,
4702                            c: x_shape.dim(1).unwrap_static() as u32,
4703                            h: x_shape.dim(2).unwrap_static() as u32,
4704                            w: x_shape.dim(3).unwrap_static() as u32,
4705                            h_out: dy_shape.dim(2).unwrap_static() as u32,
4706                            w_out: dy_shape.dim(3).unwrap_static() as u32,
4707                            kh: kernel_size[0] as u32,
4708                            kw: kernel_size[1] as u32,
4709                            sh: stride.first().copied().unwrap_or(1) as u32,
4710                            sw: stride.get(1).copied().unwrap_or(1) as u32,
4711                            ph: padding.first().copied().unwrap_or(0) as u32,
4712                            pw: padding.get(1).copied().unwrap_or(0) as u32,
4713                        });
4714                    } else {
4715                        panic!("rlx-cuda: MaxPool2dBackward expects 2-D pool on NCHW tensors");
4716                    }
4717                }
4718                other => panic!(
4719                    "rlx-cuda: op {other:?} not yet lowered. \
4720                     Open a follow-up PR if you hit this — every other op \
4721                     in the IR is wired."
4722                ),
4723            }
4724        }
4725
4726        let schedule = fuse_elementwise_chains(schedule);
4727
4728        let blas = cuda_blas();
4729        let needs_blas_lt = schedule_needs_blas_lt(&schedule);
4730        let needs_dnn = schedule_needs_dnn(&schedule);
4731        let blas_lt = if needs_blas_lt {
4732            cuda_blas_lt_handle()
4733        } else {
4734            None
4735        };
4736        let blas_lt_workspace = if needs_blas_lt {
4737            cuda_blas_lt_workspace()
4738        } else {
4739            None
4740        };
4741        let dnn = if needs_dnn { cuda_dnn_handle() } else { None };
4742        let dnn_workspace = if needs_dnn {
4743            cuda_dnn_workspace()
4744        } else {
4745            None
4746        };
4747
4748        let streams = match exec_mode {
4749            ExecMode::MultiStream(n) if n > 1 => {
4750                let mut v = Vec::with_capacity(n);
4751                for _ in 0..n {
4752                    if let Ok(s) = ctx.new_stream() {
4753                        v.push(s);
4754                    }
4755                }
4756                v
4757            }
4758            _ => Vec::new(),
4759        };
4760
4761        let output_staging: Vec<F32HostSlot> = graph
4762            .outputs
4763            .iter()
4764            .map(|&id| {
4765                let elems = graph.node(id).shape.num_elements().unwrap_or(0);
4766                F32HostSlot::new(&ctx, elems, pinned_output_staging_enabled())
4767            })
4768            .collect();
4769
4770        let mut input_staging = HashMap::new();
4771        if pinned_input_staging_enabled(exec_mode) {
4772            for (name, &id) in &input_offsets {
4773                let elems = graph.node(id).shape.num_elements().unwrap_or(0);
4774                input_staging.insert(name.clone(), F32HostSlot::new(&ctx, elems, true));
4775            }
4776        }
4777
4778        let replay_event = if exec_mode == ExecMode::Graph {
4779            ctx.new_event(None).ok()
4780        } else {
4781            None
4782        };
4783
4784        let mut input_slot_names = Vec::new();
4785        let mut input_slots = Vec::new();
4786        for node in graph.nodes() {
4787            if let Op::Input { name } = &node.op {
4788                let off = if arena.has(node.id) {
4789                    arena.offset(node.id)
4790                } else {
4791                    0
4792                };
4793                let len = node.shape.num_elements().unwrap_or(0);
4794                input_slot_names.push(name.clone());
4795                input_slots.push((off, len));
4796            }
4797        }
4798
4799        let mut host_total = 0usize;
4800        let mut output_slots = Vec::new();
4801        for &id in &graph.outputs {
4802            let n = graph.node(id).shape.num_elements().unwrap_or(0);
4803            output_slots.push((host_total * 4, n));
4804            host_total += n;
4805        }
4806        let host_arena = vec![0.0f32; host_total];
4807
4808        Self {
4809            ctx,
4810            blas,
4811            blas_lt,
4812            blas_lt_workspace,
4813            dnn,
4814            dnn_workspace,
4815            half_act_scratch: None,
4816            dequant_scratch_off,
4817            graph,
4818            arena,
4819            schedule,
4820            input_offsets,
4821            param_offsets,
4822            meta_buffers,
4823            exec_mode,
4824            captured_graph: None,
4825            streams,
4826            active_extent: None,
4827            output_staging,
4828            input_staging,
4829            replay_event,
4830            gpu_handles: HashMap::new(),
4831            gpu_handle_feeds: HashMap::new(),
4832            gpu_handle_resident: std::collections::HashSet::new(),
4833            pending_read_indices: None,
4834            readback_plan_buf: Vec::new(),
4835            captured_readback_plan: None,
4836            input_slot_names,
4837            input_slots,
4838            output_slots,
4839            host_arena,
4840        }
4841    }
4842
4843    /// Host buffer base for reading outputs after [`Self::run_slots`].
4844    /// Offsets in the returned slot pairs are **byte** offsets into this buffer.
4845    pub fn arena_ptr(&self) -> *const u8 {
4846        self.host_arena.as_ptr() as *const u8
4847    }
4848
4849    pub fn output_slots(&self) -> &[(usize, usize)] {
4850        &self.output_slots
4851    }
4852
4853    fn upload_slot_inputs(&mut self, inputs: &[&[f32]]) {
4854        let stream = self.ctx.default_stream();
4855        for (i, data) in inputs.iter().enumerate() {
4856            let Some(&(byte_off, max_elems)) = self.input_slots.get(i) else {
4857                break;
4858            };
4859            let off_f32 = byte_off / 4;
4860            let len = data.len().min(max_elems);
4861            if len == 0 {
4862                continue;
4863            }
4864            let mut slot = self.arena.f32_buf_mut().slice_mut(off_f32..off_f32 + len);
4865            if let Some(name) = self.input_slot_names.get(i) {
4866                if let Some(host) = self.input_staging.get_mut(name.as_str()) {
4867                    host.copy_from_host(data);
4868                    let _ = host.htod(&stream, &mut slot, len);
4869                    continue;
4870                }
4871            }
4872            let _ = stream.memcpy_htod(&data[..len], &mut slot);
4873        }
4874    }
4875
4876    fn pack_host_arena(&mut self) {
4877        self.prepare_readback_plan();
4878        for &i in &self.readback_plan_buf {
4879            if i >= self.output_staging.len() || i >= self.output_slots.len() {
4880                continue;
4881            }
4882            let (byte_off, n) = self.output_slots[i];
4883            if n == 0 {
4884                continue;
4885            }
4886            let start = byte_off / 4;
4887            let end = start + n;
4888            if end <= self.host_arena.len() {
4889                self.output_staging[i].copy_into(&mut self.host_arena[start..end]);
4890            }
4891        }
4892    }
4893
4894    /// Fast path: positional inputs, D2H into [`Self::host_arena`], no per-output `Vec`.
4895    pub fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
4896        self.upload_slot_inputs(inputs);
4897        let _ = self.run_inner(&[]);
4898        self.pack_host_arena();
4899        &self.output_slots
4900    }
4901
4902    /// Hint the next `run` to process only the first `actual` rows
4903    /// along the bucket axis (out of `upper`, the compile extent).
4904    /// Honored when every step in the schedule passes
4905    /// `Step::safe_for_active_extent`. Bypasses captured CUDA Graph
4906    /// (recorded at full extent) when active. See PLAN L1.
4907    pub fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
4908        self.active_extent = extent;
4909    }
4910
4911    fn all_safe_for_active(&self) -> bool {
4912        self.schedule.iter().all(|s| s.safe_for_active_extent())
4913    }
4914
4915    /// Declared graph-output dtypes, in `graph.outputs` order. Used by
4916    /// the runtime wrapper's `run_typed` to narrow f32 outputs back to
4917    /// the declared dtype on the way out.
4918    pub fn output_dtypes(&self) -> Vec<rlx_ir::DType> {
4919        self.graph
4920            .outputs
4921            .iter()
4922            .map(|&id| self.graph.node(id).shape.dtype())
4923            .collect()
4924    }
4925
4926    pub fn set_param(&mut self, name: &str, data: &[f32]) {
4927        if let Some(&id) = self.param_offsets.get(name)
4928            && self.arena.has(id)
4929        {
4930            let off_f32 = self.arena.offset(id) / 4;
4931            let stream = self.ctx.default_stream();
4932            let mut slot = self
4933                .arena
4934                .f32_buf_mut()
4935                .slice_mut(off_f32..off_f32 + data.len());
4936            stream
4937                .memcpy_htod(data, &mut slot)
4938                .expect("rlx-cuda: param upload failed");
4939        }
4940    }
4941
4942    /// Upload packed U8/I8 GGUF weights into the param slot (byte offset).
4943    pub fn set_param_bytes(&mut self, name: &str, data: &[u8]) {
4944        if let Some(&id) = self.param_offsets.get(name)
4945            && self.arena.has(id)
4946        {
4947            let byte_off = self.arena.offset(id);
4948            let stream = self.ctx.default_stream();
4949            crate::gguf_host::upload_param_bytes(&stream, self.arena.f32_buf_mut(), byte_off, data);
4950        }
4951    }
4952
4953    /// Upload a param as packed half-precision bits (`u16` per element).
4954    /// Caller passes the raw IEEE-754 binary16 (`F16`) or BFloat16
4955    /// (`Bf16`) bit pattern; the backend stores it in the half-arena
4956    /// side-buffer and skips the f32 slot entirely. Use cases:
4957    /// 2× weight-memory savings for inference, plus Tensor Core matmul
4958    /// via `cublasGemmEx` when both A and B (or just B) are stored
4959    /// half-precision.
4960    ///
4961    /// When the same `name` is also `set_param`'d as f32, the
4962    /// half-arena entry takes precedence in the matmul dispatch. Use
4963    /// only one of the two for any given param.
4964    pub fn set_param_half(&mut self, name: &str, dtype: crate::arena::HalfDtype, bits: &[u16]) {
4965        let id = match self.param_offsets.get(name) {
4966            Some(&id) if self.arena.has(id) => id,
4967            _ => return,
4968        };
4969        let f32_off = (self.arena.offset(id) / 4) as u32;
4970        let off = self
4971            .arena
4972            .register_half_param(&self.ctx, id, f32_off, bits.len(), dtype);
4973        let stream = self.ctx.default_stream();
4974        if let Some(buf) = self.arena.half_buffer.as_mut() {
4975            let mut slot = buf.slice_mut(off..off + bits.len());
4976            stream
4977                .memcpy_htod(bits, &mut slot)
4978                .expect("rlx-cuda: half-param upload failed");
4979        }
4980    }
4981
4982    pub fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
4983        self.run_read_outputs(inputs, None)
4984    }
4985
4986    /// Run and read back only selected outputs (+ GPU handle feed outputs).
4987    pub fn run_read_outputs(
4988        &mut self,
4989        inputs: &[(&str, &[f32])],
4990        read_indices: Option<&[usize]>,
4991    ) -> Vec<Vec<f32>> {
4992        match read_indices {
4993            None => self.pending_read_indices = None,
4994            Some(ix) => {
4995                let buf = self.pending_read_indices.get_or_insert_with(Vec::new);
4996                buf.clear();
4997                buf.extend_from_slice(ix);
4998                normalize_read_indices(buf);
4999            }
5000        }
5001        let outs = self.run_inner(inputs);
5002        self.pending_read_indices = None;
5003        outs
5004    }
5005
5006    pub fn bind_gpu_handle(&mut self, name: &str, data: &[f32]) -> bool {
5007        if !self.input_offsets.contains_key(name) {
5008            return false;
5009        }
5010        self.gpu_handle_resident.remove(name);
5011        self.gpu_handles.insert(name.to_string(), data.to_vec());
5012        true
5013    }
5014
5015    pub fn has_gpu_handle(&self, name: &str) -> bool {
5016        self.gpu_handles.contains_key(name)
5017    }
5018
5019    pub fn set_gpu_handle_feed(&mut self, handle_name: &str, output_index: usize) {
5020        self.gpu_handle_feeds
5021            .insert(handle_name.to_string(), output_index);
5022    }
5023
5024    pub fn read_gpu_handle(&self, name: &str) -> Option<Vec<f32>> {
5025        if let Some(&out_idx) = self.gpu_handle_feeds.get(name) {
5026            if out_idx < self.graph.outputs.len() {
5027                let id = self.graph.outputs[out_idx];
5028                let stream = self.ctx.default_stream();
5029                let off_f32 = self.arena.offset(id) / 4;
5030                let n_f32 = self.arena.len_of(id) / 4;
5031                let mut host = vec![0f32; n_f32];
5032                let src = self.arena.f32_buf().slice(off_f32..off_f32 + n_f32);
5033                if stream.memcpy_dtoh(&src, host.as_mut_slice()).is_ok() {
5034                    return Some(host);
5035                }
5036            }
5037        }
5038        if self.gpu_handle_resident.contains(name) {
5039            if let Some(&id) = self.input_offsets.get(name) {
5040                let stream = self.ctx.default_stream();
5041                let off_f32 = self.arena.offset(id) / 4;
5042                let n_f32 = self.arena.len_of(id) / 4;
5043                let mut host = vec![0f32; n_f32];
5044                let src = self.arena.f32_buf().slice(off_f32..off_f32 + n_f32);
5045                if stream.memcpy_dtoh(&src, host.as_mut_slice()).is_ok() {
5046                    return Some(host);
5047                }
5048            }
5049        }
5050        self.gpu_handles.get(name).cloned()
5051    }
5052
5053    /// Build the sorted output readback plan into [`Self::readback_plan_buf`].
5054    fn prepare_readback_plan(&mut self) {
5055        self.readback_plan_buf.clear();
5056        let n = self.graph.outputs.len();
5057        if let Some(ref want) = self.pending_read_indices {
5058            self.readback_plan_buf.extend_from_slice(want);
5059            normalize_read_indices(&mut self.readback_plan_buf);
5060            return;
5061        }
5062        self.readback_plan_buf.extend(0..n);
5063    }
5064
5065    fn propagate_gpu_handle_feeds_d2d(&mut self, stream: &Arc<cudarc::driver::CudaStream>) {
5066        let extent = self.active_extent;
5067        for (name, &out_idx) in &self.gpu_handle_feeds {
5068            if out_idx >= self.graph.outputs.len() {
5069                continue;
5070            }
5071            let out_id = self.graph.outputs[out_idx];
5072            let Some(&in_id) = self.input_offsets.get(name.as_str()) else {
5073                continue;
5074            };
5075            if in_id != out_id {
5076                let out_bytes = self.arena.len_of(out_id);
5077                let copy_bytes = match extent {
5078                    Some((actual, upper)) if upper > 0 => {
5079                        let stride = (out_bytes / (upper + 1)).max(4);
5080                        (actual * stride).min(out_bytes)
5081                    }
5082                    _ => out_bytes,
5083                }
5084                .min(self.arena.len_of(in_id));
5085                let src_off = self.arena.offset(out_id) / 4;
5086                let dst_off = self.arena.offset(in_id) / 4;
5087                let n_f32 = copy_bytes / 4;
5088                if n_f32 > 0 && src_off != dst_off {
5089                    let mut tmp = vec![0.0f32; n_f32];
5090                    let src = self.arena.f32_buf().slice(src_off..src_off + n_f32);
5091                    if stream.memcpy_dtoh(&src, &mut tmp).is_ok() {
5092                        let mut dst = self.arena.f32_buf_mut().slice_mut(dst_off..dst_off + n_f32);
5093                        let _ = stream.memcpy_htod(&tmp, &mut dst);
5094                    }
5095                }
5096            }
5097            self.gpu_handle_resident.insert(name.clone());
5098            self.gpu_handles.insert(name.clone(), Vec::new());
5099        }
5100    }
5101
5102    fn stage_gpu_handle_inputs(
5103        &mut self,
5104        stream: &Arc<cudarc::driver::CudaStream>,
5105        inputs: &[(&str, &[f32])],
5106    ) {
5107        for (name, data) in &self.gpu_handles {
5108            if self.gpu_handle_resident.contains(name) || inputs.iter().any(|(n, _)| n == name) {
5109                continue;
5110            }
5111            if let Some(&id) = self.input_offsets.get(name.as_str())
5112                && self.arena.has(id)
5113            {
5114                let off_f32 = self.arena.offset(id) / 4;
5115                let mut slot = self
5116                    .arena
5117                    .f32_buf_mut()
5118                    .slice_mut(off_f32..off_f32 + data.len());
5119                if let Some(host) = self.input_staging.get_mut(name.as_str()) {
5120                    host.copy_from_host(data);
5121                    let _ = host.htod(stream, &mut slot, data.len());
5122                } else {
5123                    let _ = stream.memcpy_htod(data.as_slice(), &mut slot);
5124                }
5125            }
5126        }
5127    }
5128
5129    fn refresh_gpu_handles_from_staging(&mut self, plan: &[usize]) {
5130        if self.pending_read_indices.is_some() {
5131            return;
5132        }
5133        for (name, &out_idx) in &self.gpu_handle_feeds {
5134            if plan.contains(&out_idx) && out_idx < self.output_staging.len() {
5135                self.gpu_handles
5136                    .insert(name.clone(), self.output_staging[out_idx].to_vec());
5137            }
5138        }
5139    }
5140
5141    fn run_inner(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
5142        let default_stream = self.ctx.default_stream();
5143        let stream = default_stream.clone();
5144
5145        self.stage_gpu_handle_inputs(&stream, inputs);
5146
5147        // Copy inputs to device. Always done outside any graph capture
5148        // — inputs change between runs and shouldn't be baked into the
5149        // captured CUDA Graph.
5150        for &(name, data) in inputs {
5151            if let Some(&id) = self.input_offsets.get(name)
5152                && self.arena.has(id)
5153            {
5154                let off_f32 = self.arena.offset(id) / 4;
5155                let mut slot = self
5156                    .arena
5157                    .f32_buf_mut()
5158                    .slice_mut(off_f32..off_f32 + data.len());
5159                if let Some(host) = self.input_staging.get_mut(name) {
5160                    host.copy_from_host(data);
5161                    host.htod(&stream, &mut slot, data.len())
5162                        .expect("rlx-cuda: pinned input upload failed");
5163                } else {
5164                    stream
5165                        .memcpy_htod(data, &mut slot)
5166                        .expect("rlx-cuda: input upload failed");
5167                }
5168            }
5169        }
5170
5171        // Active-extent (PLAN L1): when set + every Step safe, bypass
5172        // captured CUDA Graph (recorded at full extent) and dispatch
5173        // per-step with scaled launch dims via the normal loop.
5174        let active = self.active_extent.filter(|_| self.all_safe_for_active());
5175        // Scale a count by actual/upper with ceiling-division, clamped to [0, full].
5176        let scale = |full: u32| -> u32 {
5177            match active {
5178                Some((a, u)) if u > 0 => {
5179                    let f = full as usize;
5180                    (f * a).div_ceil(u).min(f) as u32
5181                }
5182                _ => full,
5183            }
5184        };
5185
5186        // CUDA Graph fast path: replay a previously-captured schedule.
5187        // The first run with `ExecMode::Graph` falls through to the
5188        // normal dispatch loop with stream capture turned on; the
5189        // resulting graph is stashed in `self.captured_graph` and
5190        // replayed on every subsequent run.
5191        let graph_eligible = active.is_none()
5192            && self.exec_mode == ExecMode::Graph
5193            && schedule_graph_capture_safe(&self.schedule);
5194        let do_replay = graph_eligible && self.captured_graph.is_some();
5195        let do_capture = graph_eligible && self.captured_graph.is_none();
5196
5197        if do_replay {
5198            self.prepare_readback_plan();
5199            let plan_ok = self
5200                .captured_readback_plan
5201                .as_ref()
5202                .is_some_and(|p| p.as_slice() == self.readback_plan_buf.as_slice());
5203            if plan_ok {
5204                self.captured_graph
5205                    .as_ref()
5206                    .unwrap()
5207                    .launch()
5208                    .expect("rlx-cuda: graph replay failed");
5209                if let Some(evt) = &self.replay_event {
5210                    evt.record(&stream)
5211                        .expect("rlx-cuda: replay event record failed");
5212                    evt.synchronize()
5213                        .expect("rlx-cuda: replay event sync failed");
5214                } else {
5215                    stream.synchronize().expect("rlx-cuda: stream sync failed");
5216                }
5217                run_tail_host_audio_ops(&self.schedule, &stream, self.arena.f32_buf_mut(), false);
5218                let plan = self.readback_plan_buf.clone();
5219                let read_all = plan.len() == self.graph.outputs.len();
5220                // DtoH must run after every replay — inputs change each run and
5221                // must not rely on dtoh baked into the captured graph.
5222                if read_all {
5223                    self.fill_output_staging(&stream)
5224                        .expect("rlx-cuda: output dtoh failed after replay");
5225                } else {
5226                    self.fill_output_staging_indices(&stream, &plan)
5227                        .expect("rlx-cuda: partial output dtoh failed after replay");
5228                }
5229                self.refresh_gpu_handles_from_staging(&plan);
5230                return self.outputs_from_staging_plan(&plan);
5231            }
5232            // Readback plan changed (e.g. partial grads); drop stale capture and re-dispatch.
5233            self.captured_graph = None;
5234            self.captured_readback_plan = None;
5235        }
5236        let _ = do_replay;
5237
5238        let mut capturing = false;
5239        if do_capture {
5240            capturing = stream
5241                .begin_capture(
5242                    cudarc::driver::sys::CUstreamCaptureMode::CU_STREAM_CAPTURE_MODE_RELAXED,
5243                )
5244                .is_ok();
5245        }
5246
5247        // Multi-stream scheduler state. When `exec_mode ==
5248        // MultiStream(n)`, each Step gets assigned to one of `n` pool
5249        // streams based on producer-consumer dependencies on arena
5250        // offsets. Independent ops (e.g. unfused Q/K/V matmuls)
5251        // parallelise; producer-consumer chains stay on one stream.
5252        let multi_stream =
5253            matches!(self.exec_mode, ExecMode::MultiStream(_)) && !self.streams.is_empty();
5254        let mut producer_of: HashMap<u32, usize> = HashMap::new();
5255        let mut last_event: HashMap<usize, cudarc::driver::CudaEvent> = HashMap::new();
5256        let mut rr_cursor: usize = 0;
5257
5258        // Dispatch each step. Each iteration is wrapped in an NVTX
5259        // range so nsight-systems traces show step boundaries cleanly.
5260        // Gated behind the `nvtx` feature because CUDA 13 removed
5261        // `nvToolsExt.dll`; cudarc panics on first call when the lib
5262        // isn't loadable.
5263        for step in &self.schedule {
5264            #[cfg(feature = "nvtx")]
5265            let _nvtx = cudarc::nvtx::scoped_range(step_name(step));
5266            // PLAN L3: cross-backend Perfetto trace; no-op when env
5267            // var RLX_TRACE_PERFETTO unset.
5268            let _perf = rlx_ir::perfetto::TraceSpan::new(step_name(step), "cuda");
5269
5270            // Per-step stream selection. In single-stream mode `stream`
5271            // shadows to the default stream; in multi-stream mode it
5272            // shadows to the assigned pool stream (and we cross-stream
5273            // event-wait on every producer not on the chosen stream).
5274            let assigned_idx: Option<usize> = if multi_stream {
5275                let (reads, _) = step_offsets(step);
5276                let mut producer_streams: std::collections::HashSet<usize> =
5277                    std::collections::HashSet::new();
5278                for r in &reads {
5279                    if let Some(&s) = producer_of.get(r) {
5280                        producer_streams.insert(s);
5281                    }
5282                }
5283                let chosen = if producer_streams.is_empty() {
5284                    let s = rr_cursor % self.streams.len();
5285                    rr_cursor += 1;
5286                    s
5287                } else if producer_streams.len() == 1 {
5288                    *producer_streams.iter().next().unwrap()
5289                } else {
5290                    // Multiple producers — keep the chosen one's queue
5291                    // intact and event-wait on the others.
5292                    let chosen = *producer_streams.iter().next().unwrap();
5293                    for s in &producer_streams {
5294                        if *s != chosen
5295                            && let Some(evt) = last_event.get(s)
5296                        {
5297                            let _ = self.streams[chosen].wait(evt);
5298                        }
5299                    }
5300                    chosen
5301                };
5302                Some(chosen)
5303            } else {
5304                None
5305            };
5306            let stream: Arc<cudarc::driver::CudaStream> = match assigned_idx {
5307                Some(i) => self.streams[i].clone(),
5308                None => default_stream.clone(),
5309            };
5310            // Re-bind cuBLAS / cuDNN handles to the active stream so
5311            // their internal kernel launches go to the right queue.
5312            if multi_stream {
5313                if let Some(blas) = self.blas.as_ref() {
5314                    let blas = blas.lock().unwrap();
5315                    unsafe {
5316                        let _ = cudarc::cublas::result::set_stream(
5317                            *blas.handle(),
5318                            stream.cu_stream() as _,
5319                        );
5320                    }
5321                }
5322                if let Some(handle) = self.dnn {
5323                    unsafe {
5324                        let _ = cudarc::cudnn::result::set_stream(
5325                            handle,
5326                            stream.cu_stream() as cudnn_sys::cudaStream_t,
5327                        );
5328                    }
5329                }
5330            }
5331            match step {
5332                Step::Matmul {
5333                    m,
5334                    k,
5335                    n,
5336                    a_off_f32,
5337                    b_off_f32,
5338                    c_off_f32,
5339                    batch,
5340                    a_batch_stride,
5341                    b_batch_stride,
5342                    c_batch_stride,
5343                    has_bias,
5344                    bias_off_f32,
5345                    act_id,
5346                } => {
5347                    if matmul_parity_mode() {
5348                        let kernel = matmul_kernel(&self.ctx);
5349                        let cfg = LaunchConfig {
5350                            grid_dim: ((*n).div_ceil(64), (*m).div_ceil(64), *batch),
5351                            block_dim: (16, 16, 1),
5352                            shared_mem_bytes: 0,
5353                        };
5354                        let mut launcher = stream.launch_builder(&kernel.function);
5355                        launcher
5356                            .arg(self.arena.f32_buf_mut())
5357                            .arg(m)
5358                            .arg(k)
5359                            .arg(n)
5360                            .arg(a_off_f32)
5361                            .arg(b_off_f32)
5362                            .arg(c_off_f32)
5363                            .arg(batch)
5364                            .arg(a_batch_stride)
5365                            .arg(b_batch_stride)
5366                            .arg(c_batch_stride)
5367                            .arg(has_bias)
5368                            .arg(bias_off_f32)
5369                            .arg(act_id);
5370                        unsafe {
5371                            launcher
5372                                .launch(cfg)
5373                                .expect("rlx-cuda: matmul (parity) launch failed");
5374                        }
5375                        if let Some(idx) = assigned_idx {
5376                            if let Ok(evt) = stream.record_event(None) {
5377                                last_event.insert(idx, evt);
5378                            }
5379                            let (_, writes) = step_offsets(step);
5380                            for w in &writes {
5381                                producer_of.insert(*w, idx);
5382                            }
5383                        }
5384                        continue;
5385                    }
5386
5387                    // Tier 0: mixed-precision GemmEx — when B (the weight)
5388                    // is stored in the half-arena, cast activations to
5389                    // f16/bf16 in a scratch buffer and call cublasGemmEx
5390                    // with both inputs half + f32 accumulator. Falls
5391                    // through to cublasLt on any setup or runtime error.
5392                    let used_mixed = try_mixed_precision_gemm(
5393                        &self.ctx,
5394                        &mut self.arena,
5395                        &mut self.half_act_scratch,
5396                        self.blas.as_ref(),
5397                        &stream,
5398                        *m,
5399                        *k,
5400                        *n,
5401                        *batch,
5402                        *a_off_f32,
5403                        *b_off_f32,
5404                        *c_off_f32,
5405                    );
5406                    if used_mixed {
5407                        // Optional bias / activation epilogue.
5408                        if *has_bias != 0 || *act_id != 0xFFFFu32 {
5409                            let kernel = matmul_epilogue_kernel(&self.ctx);
5410                            let total = m * n * batch;
5411                            let (grid, block) = dispatch_grid_1d(total, 256);
5412                            let cfg = LaunchConfig {
5413                                grid_dim: (grid, 1, 1),
5414                                block_dim: (block, 1, 1),
5415                                shared_mem_bytes: 0,
5416                            };
5417                            let mut launcher = stream.launch_builder(&kernel.function);
5418                            launcher
5419                                .arg(self.arena.f32_buf_mut())
5420                                .arg(&total)
5421                                .arg(n)
5422                                .arg(c_off_f32)
5423                                .arg(has_bias)
5424                                .arg(bias_off_f32)
5425                                .arg(act_id);
5426                            unsafe {
5427                                launcher
5428                                    .launch(cfg)
5429                                    .expect("rlx-cuda: matmul_epilogue (mixed) failed");
5430                            }
5431                        }
5432                        // Multi-stream tail bookkeeping still runs at end of step.
5433                        if let Some(idx) = assigned_idx {
5434                            if let Ok(evt) = stream.record_event(None) {
5435                                last_event.insert(idx, evt);
5436                            }
5437                            let (_, writes) = step_offsets(step);
5438                            for w in &writes {
5439                                producer_of.insert(*w, idx);
5440                            }
5441                        }
5442                        continue;
5443                    }
5444
5445                    // Tier 1: cublasLt fused (matmul + bias + relu/gelu in
5446                    // one launch). Only used when the activation is one of
5447                    // the two cublasLt natively fuses; other acts (silu,
5448                    // sigmoid, etc.) fall through to the sgemm + epilogue
5449                    // kernel path.
5450                    let try_cublaslt = self.blas_lt.is_some()
5451                        && self.blas_lt_workspace.is_some()
5452                        && cublaslt_act_supported(*act_id);
5453                    let used_cublaslt = if try_cublaslt {
5454                        let lt_handle = self.blas_lt.unwrap();
5455                        let mut workspace =
5456                            self.blas_lt_workspace.as_ref().unwrap().lock().unwrap();
5457                        let (workspace_ptr, _ws_record) = workspace.device_ptr_mut(&stream);
5458                        let (arena_ptr, _record) = self.arena.f32_buf_mut().device_ptr_mut(&stream);
5459                        let cu_stream = stream.cu_stream();
5460                        let act = cublaslt_act_for(*act_id);
5461                        let r = unsafe {
5462                            cublaslt_matmul_fused(
5463                                lt_handle,
5464                                workspace_ptr,
5465                                CUBLASLT_WORKSPACE_BYTES,
5466                                arena_ptr,
5467                                *m,
5468                                *k,
5469                                *n,
5470                                *a_off_f32,
5471                                *b_off_f32,
5472                                *c_off_f32,
5473                                *has_bias != 0,
5474                                *bias_off_f32,
5475                                act,
5476                                *batch,
5477                                *a_batch_stride,
5478                                *b_batch_stride,
5479                                *c_batch_stride,
5480                                cu_stream,
5481                            )
5482                        };
5483                        if let Err(ref e) = r {
5484                            log_fallback("matmul.cublasLt", e);
5485                        }
5486                        r.is_ok()
5487                    } else {
5488                        false
5489                    };
5490                    if used_cublaslt {
5491                        continue;
5492                    }
5493
5494                    // Tier 2: cuBLAS sgemm via raw pointers (bypasses
5495                    // the borrow checker's same-buffer aliasing).
5496                    let used_cublas = if let Some(blas) = self.blas.as_ref() {
5497                        let blas = blas.lock().unwrap();
5498                        let (arena_ptr_u64, _record) =
5499                            self.arena.f32_buf_mut().device_ptr_mut(&stream);
5500                        let a_dev = arena_ptr_u64 + (*a_off_f32 as u64) * 4;
5501                        let b_dev = arena_ptr_u64 + (*b_off_f32 as u64) * 4;
5502                        let c_dev = arena_ptr_u64 + (*c_off_f32 as u64) * 4;
5503                        let alpha: f32 = 1.0;
5504                        let beta: f32 = 0.0;
5505                        // cuBLAS is column-major; we have row-major. Trick:
5506                        // computing C = A·B (row-major) is the same as
5507                        // computing C^T = B^T · A^T (column-major), and
5508                        // viewing our row-major arrays as column-major
5509                        // automatically yields the transpose.
5510                        let result = unsafe {
5511                            if *batch == 1 {
5512                                cudarc::cublas::result::sgemm(
5513                                    *blas.handle(),
5514                                    cublas_sys::cublasOperation_t::CUBLAS_OP_N,
5515                                    cublas_sys::cublasOperation_t::CUBLAS_OP_N,
5516                                    *n as i32,
5517                                    *m as i32,
5518                                    *k as i32,
5519                                    &alpha as *const f32,
5520                                    b_dev as *const f32,
5521                                    *n as i32,
5522                                    a_dev as *const f32,
5523                                    *k as i32,
5524                                    &beta as *const f32,
5525                                    c_dev as *mut f32,
5526                                    *n as i32,
5527                                )
5528                            } else {
5529                                cudarc::cublas::result::sgemm_strided_batched(
5530                                    *blas.handle(),
5531                                    cublas_sys::cublasOperation_t::CUBLAS_OP_N,
5532                                    cublas_sys::cublasOperation_t::CUBLAS_OP_N,
5533                                    *n as i32,
5534                                    *m as i32,
5535                                    *k as i32,
5536                                    &alpha as *const f32,
5537                                    b_dev as *const f32,
5538                                    *n as i32,
5539                                    *b_batch_stride as i64,
5540                                    a_dev as *const f32,
5541                                    *k as i32,
5542                                    *a_batch_stride as i64,
5543                                    &beta as *const f32,
5544                                    c_dev as *mut f32,
5545                                    *n as i32,
5546                                    *c_batch_stride as i64,
5547                                    *batch as i32,
5548                                )
5549                            }
5550                        };
5551                        if let Err(ref e) = result {
5552                            log_fallback("matmul.cublasSgemm", e);
5553                        }
5554                        result.is_ok()
5555                    } else {
5556                        false
5557                    };
5558
5559                    if used_cublas {
5560                        // Optional fused epilogue (bias + activation) as
5561                        // a separate element-wise kernel.
5562                        if *has_bias != 0 || *act_id != 0xFFFFu32 {
5563                            let kernel = matmul_epilogue_kernel(&self.ctx);
5564                            let total = m * n * batch;
5565                            let (grid, block) = dispatch_grid_1d(total, 256);
5566                            let cfg = LaunchConfig {
5567                                grid_dim: (grid, 1, 1),
5568                                block_dim: (block, 1, 1),
5569                                shared_mem_bytes: 0,
5570                            };
5571                            let mut launcher = stream.launch_builder(&kernel.function);
5572                            launcher
5573                                .arg(self.arena.f32_buf_mut())
5574                                .arg(&total)
5575                                .arg(n)
5576                                .arg(c_off_f32)
5577                                .arg(has_bias)
5578                                .arg(bias_off_f32)
5579                                .arg(act_id);
5580                            unsafe {
5581                                launcher
5582                                    .launch(cfg)
5583                                    .expect("rlx-cuda: matmul_epilogue launch failed");
5584                            }
5585                        }
5586                    } else if use_wmma() {
5587                        // WMMA Tensor Core path: 32×64 block tile, 128 threads/block,
5588                        // SM 70+ only. Doesn't fuse bias/activation — those go to the
5589                        // shared epilogue kernel.
5590                        let kernel = matmul_wmma_kernel(&self.ctx);
5591                        let cfg = LaunchConfig {
5592                            grid_dim: ((*n).div_ceil(64), (*m).div_ceil(32), *batch),
5593                            block_dim: (128, 1, 1),
5594                            shared_mem_bytes: 0,
5595                        };
5596                        let mut launcher = stream.launch_builder(&kernel.function);
5597                        launcher
5598                            .arg(self.arena.f32_buf_mut())
5599                            .arg(m)
5600                            .arg(k)
5601                            .arg(n)
5602                            .arg(a_off_f32)
5603                            .arg(b_off_f32)
5604                            .arg(c_off_f32)
5605                            .arg(batch)
5606                            .arg(a_batch_stride)
5607                            .arg(b_batch_stride)
5608                            .arg(c_batch_stride);
5609                        unsafe {
5610                            launcher
5611                                .launch(cfg)
5612                                .expect("rlx-cuda: matmul_wmma launch failed");
5613                        }
5614                        if *has_bias != 0 || *act_id != 0xFFFFu32 {
5615                            let kernel = matmul_epilogue_kernel(&self.ctx);
5616                            let total = m * n * batch;
5617                            let (grid, block) = dispatch_grid_1d(total, 256);
5618                            let cfg = LaunchConfig {
5619                                grid_dim: (grid, 1, 1),
5620                                block_dim: (block, 1, 1),
5621                                shared_mem_bytes: 0,
5622                            };
5623                            let mut launcher = stream.launch_builder(&kernel.function);
5624                            launcher
5625                                .arg(self.arena.f32_buf_mut())
5626                                .arg(&total)
5627                                .arg(n)
5628                                .arg(c_off_f32)
5629                                .arg(has_bias)
5630                                .arg(bias_off_f32)
5631                                .arg(act_id);
5632                            unsafe {
5633                                launcher
5634                                    .launch(cfg)
5635                                    .expect("rlx-cuda: matmul_epilogue (post-wmma) failed");
5636                            }
5637                        }
5638                    } else {
5639                        // Custom scalar kernel fallback: 64×64 block tile, 4×4 register tile.
5640                        let kernel = matmul_kernel(&self.ctx);
5641                        let cfg = LaunchConfig {
5642                            grid_dim: ((*n).div_ceil(64), (*m).div_ceil(64), *batch),
5643                            block_dim: (16, 16, 1),
5644                            shared_mem_bytes: 0,
5645                        };
5646                        let mut launcher = stream.launch_builder(&kernel.function);
5647                        launcher
5648                            .arg(self.arena.f32_buf_mut())
5649                            .arg(m)
5650                            .arg(k)
5651                            .arg(n)
5652                            .arg(a_off_f32)
5653                            .arg(b_off_f32)
5654                            .arg(c_off_f32)
5655                            .arg(batch)
5656                            .arg(a_batch_stride)
5657                            .arg(b_batch_stride)
5658                            .arg(c_batch_stride)
5659                            .arg(has_bias)
5660                            .arg(bias_off_f32)
5661                            .arg(act_id);
5662                        unsafe {
5663                            launcher
5664                                .launch(cfg)
5665                                .expect("rlx-cuda: matmul launch failed");
5666                        }
5667                    }
5668                }
5669                Step::Binary {
5670                    n,
5671                    a_off,
5672                    b_off,
5673                    c_off,
5674                    op,
5675                } => {
5676                    let n_s = scale(*n);
5677                    if n_s == 0 {
5678                        continue;
5679                    }
5680                    let kernel = binary_kernel(&self.ctx);
5681                    let (grid, block) = dispatch_grid_1d(n_s, 256);
5682                    let cfg = LaunchConfig {
5683                        grid_dim: (grid, 1, 1),
5684                        block_dim: (block, 1, 1),
5685                        shared_mem_bytes: 0,
5686                    };
5687                    let mut launcher = stream.launch_builder(&kernel.function);
5688                    launcher
5689                        .arg(self.arena.f32_buf_mut())
5690                        .arg(&n_s)
5691                        .arg(a_off)
5692                        .arg(b_off)
5693                        .arg(c_off)
5694                        .arg(op);
5695                    unsafe {
5696                        launcher
5697                            .launch(cfg)
5698                            .expect("rlx-cuda: binary launch failed");
5699                    }
5700                }
5701                Step::ElementwiseRegion {
5702                    len,
5703                    num_inputs,
5704                    num_steps,
5705                    dst_off,
5706                    input_offs: _,
5707                    scalar_input_mask,
5708                    input_modulus,
5709                    meta_idx,
5710                    spatial_prologue,
5711                    prologue_w,
5712                    prologue_h,
5713                    prologue_nc,
5714                } => {
5715                    let len_s = scale(*len);
5716                    if len_s == 0 {
5717                        continue;
5718                    }
5719                    let kernel = elementwise_region_kernel(&self.ctx);
5720                    let ((gx, gy, gz), (bx, by, bz)) = if *spatial_prologue {
5721                        dispatch_grid_prologue_nchw(*prologue_w, *prologue_h, *prologue_nc)
5722                    } else {
5723                        let (grid, block) = dispatch_grid_1d(len_s, 256);
5724                        ((grid, 1, 1), (block, 1, 1))
5725                    };
5726                    let cfg = LaunchConfig {
5727                        grid_dim: (gx, gy, gz),
5728                        block_dim: (bx, by, bz),
5729                        shared_mem_bytes: 0,
5730                    };
5731                    let mut launcher = stream.launch_builder(&kernel.function);
5732                    // input_modulus is passed by-value as a 64-byte
5733                    // const param (16 u32s). Could move to meta_buffer
5734                    // but a constant param keeps the kernel signature
5735                    // self-describing.
5736                    launcher
5737                        .arg(self.arena.f32_buf_mut())
5738                        .arg(&len_s)
5739                        .arg(num_inputs)
5740                        .arg(num_steps)
5741                        .arg(dst_off)
5742                        .arg(&self.meta_buffers[*meta_idx])
5743                        .arg(scalar_input_mask)
5744                        .arg(input_modulus);
5745                    unsafe {
5746                        launcher
5747                            .launch(cfg)
5748                            .expect("rlx-cuda: elementwise_region launch failed");
5749                    }
5750                }
5751                Step::BatchElementwiseRegion {
5752                    slice_len,
5753                    num_batch,
5754                    num_steps,
5755                    base_dst_off,
5756                    slice_elems,
5757                    batch_offs_idx,
5758                    meta_idx,
5759                    scalar_input_mask,
5760                    input_modulus,
5761                    ..
5762                } => {
5763                    let slice_len_s = scale(*slice_len);
5764                    let num_batch_s = scale(*num_batch);
5765                    if slice_len_s == 0 || num_batch_s == 0 {
5766                        continue;
5767                    }
5768                    let kernel = batch_elementwise_region_kernel(&self.ctx);
5769                    let (grid_x, block_x) = dispatch_grid_1d(slice_len_s, 256);
5770                    let cfg = LaunchConfig {
5771                        grid_dim: (grid_x, 1, num_batch_s),
5772                        block_dim: (block_x, 1, 1),
5773                        shared_mem_bytes: 0,
5774                    };
5775                    let mut launcher = stream.launch_builder(&kernel.function);
5776                    launcher
5777                        .arg(self.arena.f32_buf_mut())
5778                        .arg(&slice_len_s)
5779                        .arg(&num_batch_s)
5780                        .arg(num_steps)
5781                        .arg(base_dst_off)
5782                        .arg(slice_elems)
5783                        .arg(&self.meta_buffers[*batch_offs_idx])
5784                        .arg(&self.meta_buffers[*meta_idx])
5785                        .arg(scalar_input_mask)
5786                        .arg(input_modulus);
5787                    unsafe {
5788                        launcher
5789                            .launch(cfg)
5790                            .expect("rlx-cuda: batch_elementwise_region launch failed");
5791                    }
5792                }
5793                Step::FusedBinaryUnary {
5794                    n,
5795                    a_off,
5796                    b_off,
5797                    out_off,
5798                    bin_op,
5799                    un_op,
5800                } => {
5801                    let n_s = scale(*n);
5802                    if n_s == 0 {
5803                        continue;
5804                    }
5805                    let kernel = fused_binary_unary_kernel(&self.ctx);
5806                    let (grid, block) = dispatch_grid_1d(n_s, 256);
5807                    let cfg = LaunchConfig {
5808                        grid_dim: (grid, 1, 1),
5809                        block_dim: (block, 1, 1),
5810                        shared_mem_bytes: 0,
5811                    };
5812                    let mut launcher = stream.launch_builder(&kernel.function);
5813                    launcher
5814                        .arg(self.arena.f32_buf_mut())
5815                        .arg(&n_s)
5816                        .arg(a_off)
5817                        .arg(b_off)
5818                        .arg(out_off)
5819                        .arg(bin_op)
5820                        .arg(un_op);
5821                    unsafe {
5822                        launcher
5823                            .launch(cfg)
5824                            .expect("rlx-cuda: fused_binary_unary launch failed");
5825                    }
5826                }
5827                Step::Unary {
5828                    n,
5829                    in_off,
5830                    out_off,
5831                    op,
5832                } => {
5833                    let n_s = scale(*n);
5834                    if n_s == 0 {
5835                        continue;
5836                    }
5837                    let kernel = unary_kernel(&self.ctx);
5838                    let (grid, block) = dispatch_grid_1d(n_s, 256);
5839                    let cfg = LaunchConfig {
5840                        grid_dim: (grid, 1, 1),
5841                        block_dim: (block, 1, 1),
5842                        shared_mem_bytes: 0,
5843                    };
5844                    let mut launcher = stream.launch_builder(&kernel.function);
5845                    launcher
5846                        .arg(self.arena.f32_buf_mut())
5847                        .arg(&n_s)
5848                        .arg(in_off)
5849                        .arg(out_off)
5850                        .arg(op);
5851                    unsafe {
5852                        launcher.launch(cfg).expect("rlx-cuda: unary launch failed");
5853                    }
5854                }
5855                Step::Compare {
5856                    n,
5857                    a_off,
5858                    b_off,
5859                    c_off,
5860                    op,
5861                } => {
5862                    let n_s = scale(*n);
5863                    if n_s == 0 {
5864                        continue;
5865                    }
5866                    let kernel = compare_kernel(&self.ctx);
5867                    let (grid, block) = dispatch_grid_1d(n_s, 256);
5868                    let cfg = LaunchConfig {
5869                        grid_dim: (grid, 1, 1),
5870                        block_dim: (block, 1, 1),
5871                        shared_mem_bytes: 0,
5872                    };
5873                    let mut launcher = stream.launch_builder(&kernel.function);
5874                    launcher
5875                        .arg(self.arena.f32_buf_mut())
5876                        .arg(&n_s)
5877                        .arg(a_off)
5878                        .arg(b_off)
5879                        .arg(c_off)
5880                        .arg(op);
5881                    unsafe {
5882                        launcher
5883                            .launch(cfg)
5884                            .expect("rlx-cuda: compare launch failed");
5885                    }
5886                }
5887                Step::Where {
5888                    n,
5889                    cond_off,
5890                    x_off,
5891                    y_off,
5892                    out_off,
5893                } => {
5894                    let n_s = scale(*n);
5895                    if n_s == 0 {
5896                        continue;
5897                    }
5898                    let kernel = where_kernel(&self.ctx);
5899                    let (grid, block) = dispatch_grid_1d(n_s, 256);
5900                    let cfg = LaunchConfig {
5901                        grid_dim: (grid, 1, 1),
5902                        block_dim: (block, 1, 1),
5903                        shared_mem_bytes: 0,
5904                    };
5905                    let mut launcher = stream.launch_builder(&kernel.function);
5906                    launcher
5907                        .arg(self.arena.f32_buf_mut())
5908                        .arg(&n_s)
5909                        .arg(cond_off)
5910                        .arg(x_off)
5911                        .arg(y_off)
5912                        .arg(out_off);
5913                    unsafe {
5914                        launcher.launch(cfg).expect("rlx-cuda: where launch failed");
5915                    }
5916                }
5917                Step::Reduce {
5918                    outer,
5919                    inner,
5920                    in_off,
5921                    out_off,
5922                    op,
5923                } => {
5924                    let outer_s = scale(*outer);
5925                    if outer_s == 0 {
5926                        continue;
5927                    }
5928                    let kernel = reduce_kernel(&self.ctx);
5929                    let cfg = LaunchConfig {
5930                        grid_dim: (outer_s, 1, 1),
5931                        block_dim: (256, 1, 1),
5932                        shared_mem_bytes: 0,
5933                    };
5934                    let mut launcher = stream.launch_builder(&kernel.function);
5935                    launcher
5936                        .arg(self.arena.f32_buf_mut())
5937                        .arg(&outer_s)
5938                        .arg(inner)
5939                        .arg(in_off)
5940                        .arg(out_off)
5941                        .arg(op);
5942                    unsafe {
5943                        launcher
5944                            .launch(cfg)
5945                            .expect("rlx-cuda: reduce launch failed");
5946                    }
5947                }
5948                Step::Softmax {
5949                    outer,
5950                    inner,
5951                    in_off,
5952                    out_off,
5953                } => {
5954                    let outer_s = scale(*outer);
5955                    if outer_s == 0 {
5956                        continue;
5957                    }
5958                    let kernel = softmax_kernel(&self.ctx);
5959                    let cfg = LaunchConfig {
5960                        grid_dim: (outer_s, 1, 1),
5961                        block_dim: (256, 1, 1),
5962                        shared_mem_bytes: 0,
5963                    };
5964                    let mut launcher = stream.launch_builder(&kernel.function);
5965                    launcher
5966                        .arg(self.arena.f32_buf_mut())
5967                        .arg(&outer_s)
5968                        .arg(inner)
5969                        .arg(in_off)
5970                        .arg(out_off);
5971                    unsafe {
5972                        launcher
5973                            .launch(cfg)
5974                            .expect("rlx-cuda: softmax launch failed");
5975                    }
5976                }
5977                Step::LayerNorm {
5978                    outer,
5979                    inner,
5980                    in_off,
5981                    out_off,
5982                    gamma_off,
5983                    beta_off,
5984                    eps_bits,
5985                    op,
5986                } => {
5987                    let outer_s = scale(*outer);
5988                    if outer_s == 0 {
5989                        continue;
5990                    }
5991                    let kernel = layernorm_kernel(&self.ctx);
5992                    let cfg = LaunchConfig {
5993                        grid_dim: (outer_s, 1, 1),
5994                        block_dim: (256, 1, 1),
5995                        shared_mem_bytes: 0,
5996                    };
5997                    let mut launcher = stream.launch_builder(&kernel.function);
5998                    launcher
5999                        .arg(self.arena.f32_buf_mut())
6000                        .arg(&outer_s)
6001                        .arg(inner)
6002                        .arg(in_off)
6003                        .arg(out_off)
6004                        .arg(gamma_off)
6005                        .arg(beta_off)
6006                        .arg(eps_bits)
6007                        .arg(op);
6008                    unsafe {
6009                        launcher
6010                            .launch(cfg)
6011                            .expect("rlx-cuda: layernorm launch failed");
6012                    }
6013                }
6014                Step::FusedResidualLn {
6015                    outer,
6016                    inner,
6017                    in_off,
6018                    residual_off,
6019                    bias_off,
6020                    gamma_off,
6021                    beta_off,
6022                    out_off,
6023                    eps_bits,
6024                    has_bias,
6025                } => {
6026                    let outer_s = scale(*outer);
6027                    if outer_s == 0 {
6028                        continue;
6029                    }
6030                    let kernel = fused_residual_ln_kernel(&self.ctx);
6031                    let cfg = LaunchConfig {
6032                        grid_dim: (outer_s, 1, 1),
6033                        block_dim: (256, 1, 1),
6034                        shared_mem_bytes: 0,
6035                    };
6036                    let mut launcher = stream.launch_builder(&kernel.function);
6037                    launcher
6038                        .arg(self.arena.f32_buf_mut())
6039                        .arg(&outer_s)
6040                        .arg(inner)
6041                        .arg(in_off)
6042                        .arg(residual_off)
6043                        .arg(bias_off)
6044                        .arg(gamma_off)
6045                        .arg(beta_off)
6046                        .arg(out_off)
6047                        .arg(eps_bits)
6048                        .arg(has_bias);
6049                    unsafe {
6050                        launcher
6051                            .launch(cfg)
6052                            .expect("rlx-cuda: fused_residual_ln launch failed");
6053                    }
6054                }
6055                Step::FusedResidualRmsNorm {
6056                    outer,
6057                    inner,
6058                    in_off,
6059                    residual_off,
6060                    bias_off,
6061                    gamma_off,
6062                    beta_off,
6063                    out_off,
6064                    eps_bits,
6065                    has_bias,
6066                } => {
6067                    let outer_s = scale(*outer);
6068                    if outer_s == 0 {
6069                        continue;
6070                    }
6071                    let kernel = fused_residual_rms_norm_kernel(&self.ctx);
6072                    let cfg = LaunchConfig {
6073                        grid_dim: (outer_s, 1, 1),
6074                        block_dim: (256, 1, 1),
6075                        shared_mem_bytes: 0,
6076                    };
6077                    let mut launcher = stream.launch_builder(&kernel.function);
6078                    launcher
6079                        .arg(self.arena.f32_buf_mut())
6080                        .arg(&outer_s)
6081                        .arg(inner)
6082                        .arg(in_off)
6083                        .arg(residual_off)
6084                        .arg(bias_off)
6085                        .arg(gamma_off)
6086                        .arg(beta_off)
6087                        .arg(out_off)
6088                        .arg(eps_bits)
6089                        .arg(has_bias);
6090                    unsafe {
6091                        launcher
6092                            .launch(cfg)
6093                            .expect("rlx-cuda: fused_residual_rms_norm launch failed");
6094                    }
6095                }
6096                Step::Gather {
6097                    n_out,
6098                    n_idx,
6099                    dim,
6100                    vocab,
6101                    in_off,
6102                    idx_off,
6103                    out_off,
6104                } => {
6105                    let kernel = gather_kernel(&self.ctx);
6106                    let (grid, block) = dispatch_grid_1d(*n_out, 256);
6107                    let cfg = LaunchConfig {
6108                        grid_dim: (grid, 1, 1),
6109                        block_dim: (block, 1, 1),
6110                        shared_mem_bytes: 0,
6111                    };
6112                    let mut launcher = stream.launch_builder(&kernel.function);
6113                    launcher
6114                        .arg(self.arena.f32_buf_mut())
6115                        .arg(n_out)
6116                        .arg(n_idx)
6117                        .arg(dim)
6118                        .arg(vocab)
6119                        .arg(in_off)
6120                        .arg(idx_off)
6121                        .arg(out_off);
6122                    unsafe {
6123                        launcher
6124                            .launch(cfg)
6125                            .expect("rlx-cuda: gather launch failed");
6126                    }
6127                }
6128                Step::GatherAxis {
6129                    total,
6130                    outer,
6131                    axis_dim,
6132                    num_idx,
6133                    trailing,
6134                    table_off,
6135                    idx_off,
6136                    out_off,
6137                } => {
6138                    let kernel = gather_axis_kernel(&self.ctx);
6139                    let (grid, block) = dispatch_grid_1d(*total, 256);
6140                    let cfg = LaunchConfig {
6141                        grid_dim: (grid, 1, 1),
6142                        block_dim: (block, 1, 1),
6143                        shared_mem_bytes: 0,
6144                    };
6145                    let mut launcher = stream.launch_builder(&kernel.function);
6146                    launcher
6147                        .arg(self.arena.f32_buf_mut())
6148                        .arg(total)
6149                        .arg(outer)
6150                        .arg(axis_dim)
6151                        .arg(num_idx)
6152                        .arg(trailing)
6153                        .arg(table_off)
6154                        .arg(idx_off)
6155                        .arg(out_off);
6156                    unsafe {
6157                        launcher
6158                            .launch(cfg)
6159                            .expect("rlx-cuda: gather_axis launch failed");
6160                    }
6161                }
6162                Step::Narrow {
6163                    total,
6164                    outer,
6165                    inner,
6166                    axis_in_size,
6167                    axis_out_size,
6168                    start,
6169                    in_off,
6170                    out_off,
6171                } => {
6172                    let kernel = narrow_kernel(&self.ctx);
6173                    let (grid, block) = dispatch_grid_1d(*total, 256);
6174                    let cfg = LaunchConfig {
6175                        grid_dim: (grid, 1, 1),
6176                        block_dim: (block, 1, 1),
6177                        shared_mem_bytes: 0,
6178                    };
6179                    let mut launcher = stream.launch_builder(&kernel.function);
6180                    launcher
6181                        .arg(self.arena.f32_buf_mut())
6182                        .arg(total)
6183                        .arg(outer)
6184                        .arg(inner)
6185                        .arg(axis_in_size)
6186                        .arg(axis_out_size)
6187                        .arg(start)
6188                        .arg(in_off)
6189                        .arg(out_off);
6190                    unsafe {
6191                        launcher
6192                            .launch(cfg)
6193                            .expect("rlx-cuda: narrow launch failed");
6194                    }
6195                }
6196                Step::Argmax {
6197                    outer,
6198                    inner,
6199                    in_off,
6200                    out_off,
6201                } => {
6202                    let kernel = argmax_kernel(&self.ctx);
6203                    let (grid, block) = dispatch_grid_1d(*outer, 256);
6204                    let cfg = LaunchConfig {
6205                        grid_dim: (grid, 1, 1),
6206                        block_dim: (block, 1, 1),
6207                        shared_mem_bytes: 0,
6208                    };
6209                    let mut launcher = stream.launch_builder(&kernel.function);
6210                    launcher
6211                        .arg(self.arena.f32_buf_mut())
6212                        .arg(outer)
6213                        .arg(inner)
6214                        .arg(in_off)
6215                        .arg(out_off);
6216                    unsafe {
6217                        launcher
6218                            .launch(cfg)
6219                            .expect("rlx-cuda: argmax launch failed");
6220                    }
6221                }
6222                Step::Transpose {
6223                    rank,
6224                    out_total,
6225                    in_off,
6226                    out_off,
6227                    meta_idx,
6228                } => {
6229                    let kernel = transpose_kernel(&self.ctx);
6230                    let (grid, block) = dispatch_grid_1d(*out_total, 256);
6231                    let cfg = LaunchConfig {
6232                        grid_dim: (grid, 1, 1),
6233                        block_dim: (block, 1, 1),
6234                        shared_mem_bytes: 0,
6235                    };
6236                    let mut launcher = stream.launch_builder(&kernel.function);
6237                    launcher
6238                        .arg(self.arena.f32_buf_mut())
6239                        .arg(rank)
6240                        .arg(out_total)
6241                        .arg(in_off)
6242                        .arg(out_off)
6243                        .arg(&self.meta_buffers[*meta_idx]);
6244                    unsafe {
6245                        launcher
6246                            .launch(cfg)
6247                            .expect("rlx-cuda: transpose launch failed");
6248                    }
6249                }
6250                Step::Expand {
6251                    rank,
6252                    out_total,
6253                    in_off,
6254                    out_off,
6255                    meta_idx,
6256                } => {
6257                    let kernel = expand_kernel(&self.ctx);
6258                    let (grid, block) = dispatch_grid_1d(*out_total, 256);
6259                    let cfg = LaunchConfig {
6260                        grid_dim: (grid, 1, 1),
6261                        block_dim: (block, 1, 1),
6262                        shared_mem_bytes: 0,
6263                    };
6264                    let mut launcher = stream.launch_builder(&kernel.function);
6265                    launcher
6266                        .arg(self.arena.f32_buf_mut())
6267                        .arg(rank)
6268                        .arg(out_total)
6269                        .arg(in_off)
6270                        .arg(out_off)
6271                        .arg(&self.meta_buffers[*meta_idx]);
6272                    unsafe {
6273                        launcher
6274                            .launch(cfg)
6275                            .expect("rlx-cuda: expand launch failed");
6276                    }
6277                }
6278                Step::Concat {
6279                    total,
6280                    outer,
6281                    inner,
6282                    axis_in_size,
6283                    axis_out_size,
6284                    start,
6285                    in_off,
6286                    out_off,
6287                } => {
6288                    let kernel = concat_kernel(&self.ctx);
6289                    let (grid, block) = dispatch_grid_1d(*total, 256);
6290                    let cfg = LaunchConfig {
6291                        grid_dim: (grid, 1, 1),
6292                        block_dim: (block, 1, 1),
6293                        shared_mem_bytes: 0,
6294                    };
6295                    let mut launcher = stream.launch_builder(&kernel.function);
6296                    launcher
6297                        .arg(self.arena.f32_buf_mut())
6298                        .arg(total)
6299                        .arg(outer)
6300                        .arg(inner)
6301                        .arg(axis_in_size)
6302                        .arg(axis_out_size)
6303                        .arg(start)
6304                        .arg(in_off)
6305                        .arg(out_off);
6306                    unsafe {
6307                        launcher
6308                            .launch(cfg)
6309                            .expect("rlx-cuda: concat launch failed");
6310                    }
6311                }
6312                Step::Attention {
6313                    batch,
6314                    heads,
6315                    seq_q,
6316                    seq_k,
6317                    head_dim,
6318                    q_off,
6319                    k_off,
6320                    v_off,
6321                    out_off,
6322                    mask_off,
6323                    mask_kind,
6324                    scale_bits,
6325                    window,
6326                    seq_q_stride,
6327                    seq_k_stride,
6328                    mask_batch_stride,
6329                    mask_head_stride,
6330                    q_batch_stride,
6331                    q_head_stride,
6332                    q_seq_stride,
6333                    k_batch_stride,
6334                    k_head_stride,
6335                    k_seq_stride,
6336                    v_batch_stride,
6337                    v_head_stride,
6338                    v_seq_stride,
6339                    o_batch_stride,
6340                    o_head_stride,
6341                    o_seq_stride,
6342                } => {
6343                    // Tiled flash supports arbitrary Q/K/V strides (BSHD and BHSD).
6344                    // Row kernel only when head_dim exceeds the flash tile cap or forced.
6345                    let use_row = rlx_ir::attention_dispatch_use_row(
6346                        *head_dim,
6347                        "RLX_CUDA_FORCE_ATTENTION_ROW",
6348                    );
6349                    let mut launcher = stream.launch_builder(if use_row {
6350                        &attention_row_kernel(&self.ctx).function
6351                    } else {
6352                        &attention_kernel(&self.ctx).function
6353                    });
6354                    launcher
6355                        .arg(self.arena.f32_buf_mut())
6356                        .arg(batch)
6357                        .arg(heads)
6358                        .arg(seq_q)
6359                        .arg(seq_k)
6360                        .arg(head_dim)
6361                        .arg(q_off)
6362                        .arg(k_off)
6363                        .arg(v_off)
6364                        .arg(out_off)
6365                        .arg(mask_off)
6366                        .arg(mask_kind)
6367                        .arg(scale_bits)
6368                        .arg(window)
6369                        .arg(seq_q_stride)
6370                        .arg(seq_k_stride)
6371                        .arg(mask_batch_stride)
6372                        .arg(mask_head_stride)
6373                        .arg(q_batch_stride)
6374                        .arg(q_head_stride)
6375                        .arg(q_seq_stride)
6376                        .arg(k_batch_stride)
6377                        .arg(k_head_stride)
6378                        .arg(k_seq_stride)
6379                        .arg(v_batch_stride)
6380                        .arg(v_head_stride)
6381                        .arg(v_seq_stride)
6382                        .arg(o_batch_stride)
6383                        .arg(o_head_stride)
6384                        .arg(o_seq_stride);
6385                    let cfg = if use_row {
6386                        let total = batch * heads * seq_q;
6387                        let block = 256u32;
6388                        LaunchConfig {
6389                            grid_dim: (total.div_ceil(block), 1, 1),
6390                            block_dim: (block, 1, 1),
6391                            shared_mem_bytes: 0,
6392                        }
6393                    } else {
6394                        let q_blocks = (*seq_q).div_ceil(16);
6395                        LaunchConfig {
6396                            grid_dim: (q_blocks, batch * heads, 1),
6397                            block_dim: (128, 1, 1),
6398                            shared_mem_bytes: 0,
6399                        }
6400                    };
6401                    unsafe {
6402                        launcher
6403                            .launch(cfg)
6404                            .expect("rlx-cuda: attention launch failed");
6405                    }
6406                }
6407                Step::AttentionBackward {
6408                    batch,
6409                    heads,
6410                    seq_q,
6411                    seq_k,
6412                    head_dim,
6413                    q_off,
6414                    k_off,
6415                    v_off,
6416                    dy_off,
6417                    out_off,
6418                    mask_off,
6419                    mask_kind,
6420                    scale_bits,
6421                    window,
6422                    wrt,
6423                } => {
6424                    let kernel = attention_bwd_kernel(&self.ctx);
6425                    let seq_axis = if *wrt == 0 { *seq_q } else { *seq_k };
6426                    let y_blocks = seq_axis.div_ceil(256);
6427                    let cfg = LaunchConfig {
6428                        grid_dim: (batch * heads, y_blocks, 1),
6429                        block_dim: (256, 1, 1),
6430                        shared_mem_bytes: 0,
6431                    };
6432                    let mut launcher = stream.launch_builder(&kernel.function);
6433                    launcher
6434                        .arg(self.arena.f32_buf_mut())
6435                        .arg(batch)
6436                        .arg(heads)
6437                        .arg(seq_q)
6438                        .arg(seq_k)
6439                        .arg(head_dim)
6440                        .arg(q_off)
6441                        .arg(k_off)
6442                        .arg(v_off)
6443                        .arg(dy_off)
6444                        .arg(out_off)
6445                        .arg(mask_off)
6446                        .arg(mask_kind)
6447                        .arg(scale_bits)
6448                        .arg(window)
6449                        .arg(wrt);
6450                    unsafe {
6451                        launcher
6452                            .launch(cfg)
6453                            .expect("rlx-cuda: attention_bwd launch failed");
6454                    }
6455                }
6456                Step::Rope {
6457                    n_total,
6458                    seq,
6459                    head_dim,
6460                    half,
6461                    in_off,
6462                    cos_off,
6463                    sin_off,
6464                    out_off,
6465                    last_dim,
6466                } => {
6467                    let kernel = rope_kernel(&self.ctx);
6468                    let (grid, block) = dispatch_grid_1d(*n_total, 256);
6469                    let cfg = LaunchConfig {
6470                        grid_dim: (grid, 1, 1),
6471                        block_dim: (block, 1, 1),
6472                        shared_mem_bytes: 0,
6473                    };
6474                    let mut launcher = stream.launch_builder(&kernel.function);
6475                    launcher
6476                        .arg(self.arena.f32_buf_mut())
6477                        .arg(n_total)
6478                        .arg(seq)
6479                        .arg(head_dim)
6480                        .arg(half)
6481                        .arg(in_off)
6482                        .arg(cos_off)
6483                        .arg(sin_off)
6484                        .arg(out_off)
6485                        .arg(last_dim);
6486                    unsafe {
6487                        launcher.launch(cfg).expect("rlx-cuda: rope launch failed");
6488                    }
6489                }
6490                Step::Cumsum {
6491                    outer,
6492                    inner,
6493                    in_off,
6494                    out_off,
6495                    exclusive,
6496                } => {
6497                    let outer_s = scale(*outer);
6498                    if outer_s == 0 {
6499                        continue;
6500                    }
6501                    let kernel = cumsum_kernel(&self.ctx);
6502                    let (grid, block) = dispatch_grid_1d(outer_s, 256);
6503                    let cfg = LaunchConfig {
6504                        grid_dim: (grid, 1, 1),
6505                        block_dim: (block, 1, 1),
6506                        shared_mem_bytes: 0,
6507                    };
6508                    let mut launcher = stream.launch_builder(&kernel.function);
6509                    launcher
6510                        .arg(self.arena.f32_buf_mut())
6511                        .arg(&outer_s)
6512                        .arg(inner)
6513                        .arg(in_off)
6514                        .arg(out_off)
6515                        .arg(exclusive);
6516                    unsafe {
6517                        launcher
6518                            .launch(cfg)
6519                            .expect("rlx-cuda: cumsum launch failed");
6520                    }
6521                }
6522                Step::TopK {
6523                    outer,
6524                    inner,
6525                    k,
6526                    in_off,
6527                    out_off,
6528                } => {
6529                    let kernel = topk_kernel(&self.ctx);
6530                    let (grid, block) = dispatch_grid_1d(*outer, 256);
6531                    let cfg = LaunchConfig {
6532                        grid_dim: (grid, 1, 1),
6533                        block_dim: (block, 1, 1),
6534                        shared_mem_bytes: 0,
6535                    };
6536                    let mut launcher = stream.launch_builder(&kernel.function);
6537                    launcher
6538                        .arg(self.arena.f32_buf_mut())
6539                        .arg(outer)
6540                        .arg(inner)
6541                        .arg(k)
6542                        .arg(in_off)
6543                        .arg(out_off);
6544                    unsafe {
6545                        launcher.launch(cfg).expect("rlx-cuda: topk launch failed");
6546                    }
6547                }
6548                Step::GroupedMatmul {
6549                    m,
6550                    k,
6551                    n,
6552                    num_experts,
6553                    in_off,
6554                    w_off,
6555                    idx_off,
6556                    out_off,
6557                } => {
6558                    // Tier 1: sorted-batch dispatch via cuBLAS. Reads
6559                    // the idx buffer back to host, finds runs of
6560                    // identical consecutive expert ids, and issues one
6561                    // cublasSgemm per run. Wins big when tokens are
6562                    // pre-sorted by expert (the standard MoE upstream
6563                    // convention) — for random idx the run count is
6564                    // ~m and the launch overhead would negate the win,
6565                    // so we fall back to the kernel in that case.
6566                    let used_sorted = if let Some(blas) = self.blas.as_ref() {
6567                        // Sync first so prior writes to idx are visible.
6568                        stream
6569                            .synchronize()
6570                            .expect("rlx-cuda: stream sync before idx download");
6571                        let idx_host = {
6572                            let idx_slot = self
6573                                .arena
6574                                .f32_buf()
6575                                .slice(*idx_off as usize..(idx_off + m) as usize);
6576                            stream.clone_dtoh(&idx_slot).ok()
6577                        };
6578                        match idx_host {
6579                            Some(idx_vec) => {
6580                                let mut runs: Vec<(u32, u32, u32)> = Vec::new();
6581                                let mut i = 0usize;
6582                                let mn = *m as usize;
6583                                while i < mn {
6584                                    let e = idx_vec[i] as u32;
6585                                    let mut j = i + 1;
6586                                    while j < mn && (idx_vec[j] as u32) == e {
6587                                        j += 1;
6588                                    }
6589                                    if e < *num_experts {
6590                                        runs.push((i as u32, j as u32, e));
6591                                    }
6592                                    i = j;
6593                                }
6594                                // Heuristic: bail when the run count
6595                                // exceeds m/4 (idx isn't usefully sorted).
6596                                let threshold = (mn / 4).max(2);
6597                                if !runs.is_empty() && runs.len() <= threshold {
6598                                    let blas = blas.lock().unwrap();
6599                                    let (arena_ptr, _record) =
6600                                        self.arena.f32_buf_mut().device_ptr_mut(&stream);
6601                                    let alpha: f32 = 1.0;
6602                                    let beta: f32 = 0.0;
6603                                    let mut all_ok = true;
6604                                    for (lo, hi, e) in &runs {
6605                                        let rows = hi - lo;
6606                                        let a_dev = arena_ptr + ((*in_off + lo * k) as u64) * 4;
6607                                        let b_dev = arena_ptr + ((*w_off + e * k * n) as u64) * 4;
6608                                        let c_dev = arena_ptr + ((*out_off + lo * n) as u64) * 4;
6609                                        let r = unsafe {
6610                                            cudarc::cublas::result::sgemm(
6611                                                *blas.handle(),
6612                                                cublas_sys::cublasOperation_t::CUBLAS_OP_N,
6613                                                cublas_sys::cublasOperation_t::CUBLAS_OP_N,
6614                                                *n as i32,
6615                                                rows as i32,
6616                                                *k as i32,
6617                                                &alpha as *const f32,
6618                                                b_dev as *const f32,
6619                                                *n as i32,
6620                                                a_dev as *const f32,
6621                                                *k as i32,
6622                                                &beta as *const f32,
6623                                                c_dev as *mut f32,
6624                                                *n as i32,
6625                                            )
6626                                        };
6627                                        if r.is_err() {
6628                                            all_ok = false;
6629                                            break;
6630                                        }
6631                                    }
6632                                    all_ok
6633                                } else {
6634                                    false
6635                                }
6636                            }
6637                            None => false,
6638                        }
6639                    } else {
6640                        false
6641                    };
6642                    if used_sorted {
6643                        continue;
6644                    }
6645
6646                    // Fallback: per-token expert lookup kernel.
6647                    let kernel = grouped_matmul_kernel(&self.ctx);
6648                    let cfg = LaunchConfig {
6649                        grid_dim: ((*n).div_ceil(8), (*m).div_ceil(8), 1),
6650                        block_dim: (8, 8, 1),
6651                        shared_mem_bytes: 0,
6652                    };
6653                    let mut launcher = stream.launch_builder(&kernel.function);
6654                    launcher
6655                        .arg(self.arena.f32_buf_mut())
6656                        .arg(m)
6657                        .arg(k)
6658                        .arg(n)
6659                        .arg(num_experts)
6660                        .arg(in_off)
6661                        .arg(w_off)
6662                        .arg(idx_off)
6663                        .arg(out_off);
6664                    unsafe {
6665                        launcher
6666                            .launch(cfg)
6667                            .expect("rlx-cuda: grouped_matmul launch failed");
6668                    }
6669                }
6670                Step::ScatterAddZero { out_off, out_total } => {
6671                    let kernel = scatter_add_zero_kernel(&self.ctx);
6672                    let (grid, block) = dispatch_grid_1d(*out_total, 256);
6673                    let cfg = LaunchConfig {
6674                        grid_dim: (grid, 1, 1),
6675                        block_dim: (block, 1, 1),
6676                        shared_mem_bytes: 0,
6677                    };
6678                    let mut launcher = stream.launch_builder(&kernel.function);
6679                    launcher
6680                        .arg(self.arena.f32_buf_mut())
6681                        .arg(out_off)
6682                        .arg(out_total);
6683                    unsafe {
6684                        launcher
6685                            .launch(cfg)
6686                            .expect("rlx-cuda: scatter_add_zero launch failed");
6687                    }
6688                }
6689                Step::ScatterAddAcc {
6690                    out_off,
6691                    upd_off,
6692                    idx_off,
6693                    num_updates,
6694                    trailing,
6695                    out_dim,
6696                } => {
6697                    let kernel = scatter_add_acc_kernel(&self.ctx);
6698                    let total = num_updates * trailing;
6699                    let (grid, block) = dispatch_grid_1d(total, 256);
6700                    let cfg = LaunchConfig {
6701                        grid_dim: (grid, 1, 1),
6702                        block_dim: (block, 1, 1),
6703                        shared_mem_bytes: 0,
6704                    };
6705                    let mut launcher = stream.launch_builder(&kernel.function);
6706                    launcher
6707                        .arg(self.arena.f32_buf_mut())
6708                        .arg(out_off)
6709                        .arg(upd_off)
6710                        .arg(idx_off)
6711                        .arg(num_updates)
6712                        .arg(trailing)
6713                        .arg(out_dim);
6714                    unsafe {
6715                        launcher
6716                            .launch(cfg)
6717                            .expect("rlx-cuda: scatter_add_acc launch failed");
6718                    }
6719                }
6720                Step::DequantMatmul {
6721                    m,
6722                    k,
6723                    n,
6724                    block_size,
6725                    scheme_id,
6726                    x_off,
6727                    w_off,
6728                    scale_off,
6729                    zp_off,
6730                    out_off,
6731                } => {
6732                    let kernel = dequant_matmul_kernel(&self.ctx);
6733                    let cfg = LaunchConfig {
6734                        grid_dim: ((*n).div_ceil(8), (*m).div_ceil(8), 1),
6735                        block_dim: (8, 8, 1),
6736                        shared_mem_bytes: 0,
6737                    };
6738                    let mut launcher = stream.launch_builder(&kernel.function);
6739                    launcher
6740                        .arg(self.arena.f32_buf_mut())
6741                        .arg(m)
6742                        .arg(k)
6743                        .arg(n)
6744                        .arg(block_size)
6745                        .arg(scheme_id)
6746                        .arg(x_off)
6747                        .arg(w_off)
6748                        .arg(scale_off)
6749                        .arg(zp_off)
6750                        .arg(out_off);
6751                    unsafe {
6752                        launcher
6753                            .launch(cfg)
6754                            .expect("rlx-cuda: dequant_matmul launch failed");
6755                    }
6756                }
6757                Step::DequantMatmulGguf {
6758                    m,
6759                    k,
6760                    n,
6761                    scheme_id,
6762                    x_byte_off,
6763                    w_byte_off,
6764                    out_byte_off,
6765                } => {
6766                    let use_gpu = self.dequant_scratch_off > 0 && self.blas.is_some();
6767                    if use_gpu {
6768                        let blas = self.blas.as_ref().unwrap();
6769                        crate::gguf_gpu::run_dequant_matmul_gguf_gpu(
6770                            &self.ctx,
6771                            &stream,
6772                            self.arena.f32_buf_mut(),
6773                            blas,
6774                            *m as usize,
6775                            *k as usize,
6776                            *n as usize,
6777                            *scheme_id,
6778                            *x_byte_off as usize,
6779                            *w_byte_off as usize,
6780                            self.dequant_scratch_off,
6781                            *out_byte_off as usize,
6782                        );
6783                    } else {
6784                        crate::gguf_host::run_dequant_matmul_gguf(
6785                            &stream,
6786                            self.arena.f32_buf_mut(),
6787                            *m as usize,
6788                            *k as usize,
6789                            *n as usize,
6790                            *scheme_id,
6791                            *x_byte_off as usize,
6792                            *w_byte_off as usize,
6793                            *out_byte_off as usize,
6794                        );
6795                    }
6796                }
6797                Step::DequantGroupedMatmulGguf {
6798                    m,
6799                    k,
6800                    n,
6801                    num_experts,
6802                    scheme_id,
6803                    x_byte_off,
6804                    w_byte_off,
6805                    idx_byte_off,
6806                    out_byte_off,
6807                } => {
6808                    let use_gpu = self.dequant_scratch_off > 0 && self.blas.is_some();
6809                    if use_gpu {
6810                        let blas = self.blas.as_ref().unwrap();
6811                        crate::gguf_gpu::run_dequant_grouped_matmul_gguf_gpu(
6812                            &self.ctx,
6813                            &stream,
6814                            self.arena.f32_buf_mut(),
6815                            blas,
6816                            *m as usize,
6817                            *k as usize,
6818                            *n as usize,
6819                            *num_experts as usize,
6820                            *scheme_id,
6821                            *x_byte_off as usize,
6822                            *w_byte_off as usize,
6823                            *idx_byte_off as usize,
6824                            self.dequant_scratch_off,
6825                            *out_byte_off as usize,
6826                        );
6827                    } else {
6828                        crate::gguf_host::run_dequant_grouped_matmul_gguf(
6829                            &stream,
6830                            self.arena.f32_buf_mut(),
6831                            *m as usize,
6832                            *k as usize,
6833                            *n as usize,
6834                            *num_experts as usize,
6835                            *scheme_id,
6836                            *x_byte_off as usize,
6837                            *w_byte_off as usize,
6838                            *idx_byte_off as usize,
6839                            *out_byte_off as usize,
6840                        );
6841                    }
6842                }
6843                Step::Sample {
6844                    outer,
6845                    inner,
6846                    in_off,
6847                    out_off,
6848                    top_k,
6849                    top_p_bits,
6850                    temp_bits,
6851                    seed_lo,
6852                    seed_hi,
6853                } => {
6854                    let kernel = sample_kernel(&self.ctx);
6855                    let (grid, block) = dispatch_grid_1d(*outer, 256);
6856                    let cfg = LaunchConfig {
6857                        grid_dim: (grid, 1, 1),
6858                        block_dim: (block, 1, 1),
6859                        shared_mem_bytes: 0,
6860                    };
6861                    let mut launcher = stream.launch_builder(&kernel.function);
6862                    launcher
6863                        .arg(self.arena.f32_buf_mut())
6864                        .arg(outer)
6865                        .arg(inner)
6866                        .arg(in_off)
6867                        .arg(out_off)
6868                        .arg(top_k)
6869                        .arg(top_p_bits)
6870                        .arg(temp_bits)
6871                        .arg(seed_lo)
6872                        .arg(seed_hi);
6873                    unsafe {
6874                        launcher
6875                            .launch(cfg)
6876                            .expect("rlx-cuda: sample launch failed");
6877                    }
6878                }
6879                Step::SelectiveScan {
6880                    batch,
6881                    seq,
6882                    hidden,
6883                    state_size,
6884                    x_off,
6885                    delta_off,
6886                    a_off,
6887                    b_off,
6888                    c_off,
6889                    out_off,
6890                } => {
6891                    let kernel = selective_scan_kernel(&self.ctx);
6892                    let total = batch * hidden;
6893                    let (grid, block) = dispatch_grid_1d(total, 256);
6894                    let cfg = LaunchConfig {
6895                        grid_dim: (grid, 1, 1),
6896                        block_dim: (block, 1, 1),
6897                        shared_mem_bytes: 0,
6898                    };
6899                    let mut launcher = stream.launch_builder(&kernel.function);
6900                    launcher
6901                        .arg(self.arena.f32_buf_mut())
6902                        .arg(batch)
6903                        .arg(seq)
6904                        .arg(hidden)
6905                        .arg(state_size)
6906                        .arg(x_off)
6907                        .arg(delta_off)
6908                        .arg(a_off)
6909                        .arg(b_off)
6910                        .arg(c_off)
6911                        .arg(out_off);
6912                    unsafe {
6913                        launcher
6914                            .launch(cfg)
6915                            .expect("rlx-cuda: selective_scan launch failed");
6916                    }
6917                }
6918                Step::Fft {
6919                    src_byte_off,
6920                    dst_byte_off,
6921                    outer,
6922                    n_complex,
6923                    inverse,
6924                    norm_tag,
6925                    dtype_tag,
6926                    use_gpu,
6927                } => {
6928                    if *use_gpu {
6929                        let norm = rlx_ir::fft::FftNorm::from_tag(*norm_tag);
6930                        let scale = norm.output_scale(*n_complex as usize, *inverse) as f32;
6931                        crate::fft_dispatch::run_fft_gpu(
6932                            &self.ctx,
6933                            &stream,
6934                            self.arena.f32_buf_mut(),
6935                            *src_byte_off / 4,
6936                            *dst_byte_off / 4,
6937                            *outer,
6938                            *n_complex,
6939                            *inverse,
6940                            scale,
6941                        );
6942                    } else {
6943                        let (buf, arena_size) = self.arena.f32_buf_and_size();
6944                        crate::fft_host::run_fft1d(
6945                            &stream,
6946                            buf,
6947                            arena_size,
6948                            *src_byte_off as usize,
6949                            *dst_byte_off as usize,
6950                            *outer as usize,
6951                            *n_complex as usize,
6952                            *inverse,
6953                            *norm_tag,
6954                            fft_dtype_from_tag(*dtype_tag),
6955                        );
6956                    }
6957                }
6958                Step::WelchPeaksGpu {
6959                    spec_off,
6960                    dst_off,
6961                    welch_batch,
6962                    n_fft,
6963                    n_segments,
6964                    k,
6965                    n_bins,
6966                } => {
6967                    crate::welch_peaks_dispatch::run_welch_peaks_gpu(
6968                        &self.ctx,
6969                        &stream,
6970                        self.arena.f32_buf_mut(),
6971                        *spec_off,
6972                        *dst_off,
6973                        *welch_batch,
6974                        *n_fft,
6975                        *n_segments,
6976                        *k,
6977                        *n_bins,
6978                    );
6979                }
                // These three variants do nothing in this match: the arm
                // body is empty, so no kernel launch or host call happens.
                // NOTE(review): presumably they are executed by a separate
                // host-side path outside this loop — confirm.
                Step::LogMelHost { .. }
                | Step::LogMelBackwardHost { .. }
                | Step::WelchPeaksHost { .. } => {}
6983                Step::Im2ColHost {
6984                    x_byte_off,
6985                    col_byte_off,
6986                    n,
6987                    c_in,
6988                    h,
6989                    w,
6990                    h_out,
6991                    w_out,
6992                    kh,
6993                    kw,
6994                    sh,
6995                    sw,
6996                    ph,
6997                    pw,
6998                    dh,
6999                    dw_dil,
7000                    use_gpu,
7001                } => {
7002                    if *use_gpu {
7003                        let kernel = im2col_kernel(&self.ctx);
7004                        let m = *n * *h_out * *w_out;
7005                        let k = *c_in * *kh * *kw;
7006                        let total = m * k;
7007                        let (grid, block) = dispatch_grid_1d(total, 256);
7008                        let cfg = LaunchConfig {
7009                            grid_dim: (grid, 1, 1),
7010                            block_dim: (block, 1, 1),
7011                            shared_mem_bytes: 0,
7012                        };
7013                        let x_off = *x_byte_off / 4;
7014                        let col_off = *col_byte_off / 4;
7015                        let mut launcher = stream.launch_builder(&kernel.function);
7016                        launcher
7017                            .arg(self.arena.f32_buf_mut())
7018                            .arg(n)
7019                            .arg(c_in)
7020                            .arg(h)
7021                            .arg(w)
7022                            .arg(h_out)
7023                            .arg(w_out)
7024                            .arg(kh)
7025                            .arg(kw)
7026                            .arg(sh)
7027                            .arg(sw)
7028                            .arg(ph)
7029                            .arg(pw)
7030                            .arg(dh)
7031                            .arg(dw_dil)
7032                            .arg(&x_off)
7033                            .arg(&col_off);
7034                        unsafe {
7035                            launcher
7036                                .launch(cfg)
7037                                .expect("rlx-cuda: im2col launch failed");
7038                        }
7039                    } else {
7040                        crate::im2col_host::run_im2col(
7041                            &stream,
7042                            self.arena.f32_buf_mut(),
7043                            *x_byte_off as usize,
7044                            *col_byte_off as usize,
7045                            *n,
7046                            *c_in,
7047                            *h,
7048                            *w,
7049                            *h_out,
7050                            *w_out,
7051                            *kh,
7052                            *kw,
7053                            *sh,
7054                            *sw,
7055                            *ph,
7056                            *pw,
7057                            *dh,
7058                            *dw_dil,
7059                        );
7060                    }
7061                }
7062                Step::GatedDeltaNet {
7063                    q_byte_off,
7064                    k_byte_off,
7065                    v_byte_off,
7066                    g_byte_off,
7067                    beta_byte_off,
7068                    state_byte_off,
7069                    dst_byte_off,
7070                    batch,
7071                    seq,
7072                    heads,
7073                    state_size,
7074                    use_carry,
7075                } => {
7076                    let (buf, arena_size) = self.arena.f32_buf_and_size();
7077                    crate::gdn_host::run_gated_delta_net(
7078                        &stream,
7079                        buf,
7080                        arena_size,
7081                        *q_byte_off as usize,
7082                        *k_byte_off as usize,
7083                        *v_byte_off as usize,
7084                        *g_byte_off as usize,
7085                        *beta_byte_off as usize,
7086                        *state_byte_off as usize,
7087                        *dst_byte_off as usize,
7088                        *batch as usize,
7089                        *seq as usize,
7090                        *heads as usize,
7091                        *state_size as usize,
7092                        *use_carry,
7093                    );
7094                }
7095                Step::Llada2GroupLimitedGate {
7096                    sig_off,
7097                    route_off,
7098                    out_off,
7099                    n_elems,
7100                    attrs,
7101                } => {
7102                    let (buf, arena_size) = self.arena.f32_buf_and_size();
7103                    crate::llada2_gate_host::run_llada2_group_limited_gate(
7104                        &stream,
7105                        buf,
7106                        arena_size,
7107                        *sig_off as usize,
7108                        *route_off as usize,
7109                        *out_off as usize,
7110                        *n_elems as usize,
7111                        attrs,
7112                    );
7113                }
7114                Step::UmapKnn {
7115                    pairwise_off,
7116                    out_off,
7117                    n,
7118                    k,
7119                } => {
7120                    let (buf, arena_size) = self.arena.f32_buf_and_size();
7121                    crate::umap_knn_host::run_umap_knn(
7122                        &stream,
7123                        buf,
7124                        arena_size,
7125                        *pairwise_off as usize,
7126                        *out_off as usize,
7127                        *n as usize,
7128                        *k as usize,
7129                    );
7130                }
7131                Step::LayerNorm2d {
7132                    src_off,
7133                    g_off,
7134                    b_off,
7135                    dst_off,
7136                    n,
7137                    c,
7138                    h,
7139                    w,
7140                    eps_bits,
7141                } => {
7142                    let kernel = layer_norm2d_kernel(&self.ctx);
7143                    let total = n * h * w;
7144                    let (grid, block) = dispatch_grid_1d(total, 256);
7145                    let cfg = LaunchConfig {
7146                        grid_dim: (grid, 1, 1),
7147                        block_dim: (block, 1, 1),
7148                        shared_mem_bytes: 0,
7149                    };
7150                    let mut launcher = stream.launch_builder(&kernel.function);
7151                    launcher
7152                        .arg(self.arena.f32_buf_mut())
7153                        .arg(src_off)
7154                        .arg(g_off)
7155                        .arg(b_off)
7156                        .arg(dst_off)
7157                        .arg(n)
7158                        .arg(c)
7159                        .arg(h)
7160                        .arg(w)
7161                        .arg(eps_bits);
7162                    unsafe {
7163                        launcher
7164                            .launch(cfg)
7165                            .expect("rlx-cuda: layer_norm2d launch failed");
7166                    }
7167                }
7168                Step::ConvTranspose2d {
7169                    src_off,
7170                    w_off,
7171                    dst_off,
7172                    n,
7173                    c_in,
7174                    h,
7175                    w_in,
7176                    c_out,
7177                    h_out,
7178                    w_out,
7179                    kh,
7180                    kw,
7181                    sh,
7182                    sw,
7183                    ph,
7184                    pw,
7185                    dh,
7186                    dw,
7187                    groups,
7188                } => {
7189                    let kernel = conv_transpose2d_kernel(&self.ctx);
7190                    let total = n * c_out * h_out * w_out;
7191                    let (grid, block) = dispatch_grid_1d(total, 256);
7192                    let cfg = LaunchConfig {
7193                        grid_dim: (grid, 1, 1),
7194                        block_dim: (block, 1, 1),
7195                        shared_mem_bytes: 0,
7196                    };
7197                    let mut launcher = stream.launch_builder(&kernel.function);
7198                    launcher
7199                        .arg(self.arena.f32_buf_mut())
7200                        .arg(src_off)
7201                        .arg(w_off)
7202                        .arg(dst_off)
7203                        .arg(n)
7204                        .arg(c_in)
7205                        .arg(h)
7206                        .arg(w_in)
7207                        .arg(c_out)
7208                        .arg(h_out)
7209                        .arg(w_out)
7210                        .arg(kh)
7211                        .arg(kw)
7212                        .arg(sh)
7213                        .arg(sw)
7214                        .arg(ph)
7215                        .arg(pw)
7216                        .arg(dh)
7217                        .arg(dw)
7218                        .arg(groups);
7219                    unsafe {
7220                        launcher
7221                            .launch(cfg)
7222                            .expect("rlx-cuda: conv_transpose2d launch failed");
7223                    }
7224                }
7225                Step::GroupNorm {
7226                    src_off,
7227                    g_off,
7228                    b_off,
7229                    dst_off,
7230                    n,
7231                    c,
7232                    h,
7233                    w,
7234                    num_groups,
7235                    eps_bits,
7236                } => {
7237                    let kernel = group_norm_kernel(&self.ctx);
7238                    let grid = n * num_groups;
7239                    let cfg = LaunchConfig {
7240                        grid_dim: (grid, 1, 1),
7241                        block_dim: (256, 1, 1),
7242                        shared_mem_bytes: 0,
7243                    };
7244                    let mut launcher = stream.launch_builder(&kernel.function);
7245                    launcher
7246                        .arg(self.arena.f32_buf_mut())
7247                        .arg(src_off)
7248                        .arg(g_off)
7249                        .arg(b_off)
7250                        .arg(dst_off)
7251                        .arg(n)
7252                        .arg(c)
7253                        .arg(h)
7254                        .arg(w)
7255                        .arg(num_groups)
7256                        .arg(eps_bits);
7257                    unsafe {
7258                        launcher
7259                            .launch(cfg)
7260                            .expect("rlx-cuda: group_norm launch failed");
7261                    }
7262                }
7263                Step::ResizeNearest2x {
7264                    src_off,
7265                    dst_off,
7266                    n,
7267                    c,
7268                    h,
7269                    w,
7270                } => {
7271                    let kernel = resize_nearest_2x_kernel(&self.ctx);
7272                    let total = n * c * h * 2 * w * 2;
7273                    let (grid, block) = dispatch_grid_1d(total, 256);
7274                    let cfg = LaunchConfig {
7275                        grid_dim: (grid, 1, 1),
7276                        block_dim: (block, 1, 1),
7277                        shared_mem_bytes: 0,
7278                    };
7279                    let mut launcher = stream.launch_builder(&kernel.function);
7280                    launcher
7281                        .arg(self.arena.f32_buf_mut())
7282                        .arg(src_off)
7283                        .arg(dst_off)
7284                        .arg(n)
7285                        .arg(c)
7286                        .arg(h)
7287                        .arg(w);
7288                    unsafe {
7289                        launcher
7290                            .launch(cfg)
7291                            .expect("rlx-cuda: resize_nearest_2x launch failed");
7292                    }
7293                }
7294                Step::GaussianSplatRender {
7295                    positions_off,
7296                    positions_len,
7297                    scales_off,
7298                    scales_len,
7299                    rotations_off,
7300                    rotations_len,
7301                    opacities_off,
7302                    opacities_len,
7303                    colors_off,
7304                    colors_len,
7305                    sh_coeffs_off,
7306                    sh_coeffs_len,
7307                    meta_off,
7308                    dst_off,
7309                    dst_len,
7310                    width,
7311                    height,
7312                    tile_size,
7313                    radius_scale,
7314                    alpha_cutoff,
7315                    max_splat_steps,
7316                    transmittance_threshold,
7317                    max_list_entries,
7318                } => {
7319                    let (buf, arena_size) = self.arena.f32_buf_and_size();
7320                    #[cfg(feature = "native-splat")]
7321                    crate::splat_native::run_gaussian_splat_render_native(
7322                        &stream,
7323                        buf,
7324                        arena_size,
7325                        *positions_off as usize,
7326                        *positions_len as usize,
7327                        *scales_off as usize,
7328                        *scales_len as usize,
7329                        *rotations_off as usize,
7330                        *rotations_len as usize,
7331                        *opacities_off as usize,
7332                        *opacities_len as usize,
7333                        *colors_off as usize,
7334                        *colors_len as usize,
7335                        *sh_coeffs_off as usize,
7336                        *sh_coeffs_len as usize,
7337                        *meta_off as usize,
7338                        *dst_off as usize,
7339                        *width,
7340                        *height,
7341                        *tile_size,
7342                        *radius_scale,
7343                        *alpha_cutoff,
7344                        *max_splat_steps,
7345                        *transmittance_threshold,
7346                        *max_list_entries,
7347                    );
7348                    #[cfg(not(feature = "native-splat"))]
7349                    crate::splat_host::run_gaussian_splat_render(
7350                        &stream,
7351                        buf,
7352                        arena_size,
7353                        *positions_off as usize,
7354                        *positions_len as usize,
7355                        *scales_off as usize,
7356                        *scales_len as usize,
7357                        *rotations_off as usize,
7358                        *rotations_len as usize,
7359                        *opacities_off as usize,
7360                        *opacities_len as usize,
7361                        *colors_off as usize,
7362                        *colors_len as usize,
7363                        *sh_coeffs_off as usize,
7364                        *sh_coeffs_len as usize,
7365                        *meta_off as usize,
7366                        *dst_off as usize,
7367                        *dst_len as usize,
7368                        *width,
7369                        *height,
7370                        *tile_size,
7371                        *radius_scale,
7372                        *alpha_cutoff,
7373                        *max_splat_steps,
7374                        *transmittance_threshold,
7375                        *max_list_entries,
7376                    );
7377                }
7378                Step::GaussianSplatPrepare { // Forwards this step's fields to `splat_host::run_gaussian_splat_prepare`.
7379                    positions_off,
7380                    positions_len,
7381                    scales_off,
7382                    scales_len,
7383                    rotations_off,
7384                    rotations_len,
7385                    opacities_off,
7386                    opacities_len,
7387                    colors_off,
7388                    colors_len,
7389                    sh_coeffs_off,
7390                    sh_coeffs_len,
7391                    meta_off,
7392                    meta_len,
7393                    prep_off,
7394                    prep_len,
7395                    width,
7396                    height,
7397                    tile_size,
7398                    radius_scale,
7399                    alpha_cutoff,
7400                    max_splat_steps,
7401                    transmittance_threshold,
7402                    max_list_entries,
7403                } => {
7404                    let (buf, arena_size) = self.arena.f32_buf_and_size(); // arena buffer plus its size; both are handed to the host routine
7405                    crate::splat_host::run_gaussian_splat_prepare(
7406                        &stream,
7407                        buf,
7408                        arena_size,
7409                        *positions_off as usize, // every *_off / *_len field is cast to usize; the unit (bytes vs elements) is not visible here
7410                        *positions_len as usize,
7411                        *scales_off as usize,
7412                        *scales_len as usize,
7413                        *rotations_off as usize,
7414                        *rotations_len as usize,
7415                        *opacities_off as usize,
7416                        *opacities_len as usize,
7417                        *colors_off as usize,
7418                        *colors_len as usize,
7419                        *sh_coeffs_off as usize,
7420                        *sh_coeffs_len as usize,
7421                        *meta_off as usize,
7422                        *meta_len as usize,
7423                        *prep_off as usize, // NOTE(review): presumably the region this step writes — confirm in splat_host
7424                        *prep_len as usize,
7425                        *width, // the remaining parameters are dereferenced and passed without conversion
7426                        *height,
7427                        *tile_size,
7428                        *radius_scale,
7429                        *alpha_cutoff,
7430                        *max_splat_steps,
7431                        *transmittance_threshold,
7432                        *max_list_entries,
7433                    );
7434                }
7435                Step::GaussianSplatRasterize { // Forwards this step's fields to `splat_host::run_gaussian_splat_rasterize`.
7436                    prep_off,
7437                    prep_len,
7438                    meta_off,
7439                    meta_len,
7440                    dst_off,
7441                    dst_len,
7442                    count,
7443                    width,
7444                    height,
7445                    tile_size,
7446                    alpha_cutoff,
7447                    max_splat_steps,
7448                    transmittance_threshold,
7449                    max_list_entries,
7450                } => {
7451                    let (buf, arena_size) = self.arena.f32_buf_and_size(); // arena buffer plus its size; both are handed to the host routine
7452                    crate::splat_host::run_gaussian_splat_rasterize(
7453                        &stream,
7454                        buf,
7455                        arena_size,
7456                        *prep_off as usize, // NOTE(review): presumably the region filled by the GaussianSplatPrepare step — confirm
7457                        *prep_len as usize,
7458                        *meta_off as usize,
7459                        *meta_len as usize,
7460                        *dst_off as usize,
7461                        *dst_len as usize,
7462                        *count as usize, // cast to usize like the offsets and lengths above
7463                        *width, // the remaining parameters are dereferenced and passed without conversion
7464                        *height,
7465                        *tile_size,
7466                        *alpha_cutoff,
7467                        *max_splat_steps,
7468                        *transmittance_threshold,
7469                        *max_list_entries,
7470                    );
7471                }
7472                Step::GaussianSplatRenderBackward { // Forwards this step's fields to `splat_host::run_gaussian_splat_render_backward`.
7473                    positions_off,
7474                    positions_len,
7475                    scales_off,
7476                    scales_len,
7477                    rotations_off,
7478                    rotations_len,
7479                    opacities_off,
7480                    opacities_len,
7481                    colors_off,
7482                    colors_len,
7483                    sh_coeffs_off,
7484                    sh_coeffs_len,
7485                    meta_off, // this arm binds no meta_len; only the offset is forwarded
7486                    d_loss_off,
7487                    d_loss_len,
7488                    packed_off,
7489                    packed_len,
7490                    width,
7491                    height,
7492                    tile_size,
7493                    radius_scale,
7494                    alpha_cutoff,
7495                    max_splat_steps,
7496                    transmittance_threshold,
7497                    max_list_entries,
7498                    loss_grad_clip,
7499                    sh_band,
7500                    max_anisotropy,
7501                } => {
7502                    let (buf, arena_size) = self.arena.f32_buf_and_size(); // arena buffer plus its size; both are handed to the host routine
7503                    crate::splat_host::run_gaussian_splat_render_backward(
7504                        &stream,
7505                        buf,
7506                        arena_size,
7507                        *positions_off as usize, // every *_off / *_len field is cast to usize; the unit (bytes vs elements) is not visible here
7508                        *positions_len as usize,
7509                        *scales_off as usize,
7510                        *scales_len as usize,
7511                        *rotations_off as usize,
7512                        *rotations_len as usize,
7513                        *opacities_off as usize,
7514                        *opacities_len as usize,
7515                        *colors_off as usize,
7516                        *colors_len as usize,
7517                        *sh_coeffs_off as usize,
7518                        *sh_coeffs_len as usize,
7519                        *meta_off as usize,
7520                        *d_loss_off as usize, // NOTE(review): presumably the incoming loss gradient — confirm in splat_host
7521                        *d_loss_len as usize,
7522                        *packed_off as usize, // NOTE(review): presumably the packed gradient output region — confirm in splat_host
7523                        *packed_len as usize,
7524                        *width, // the remaining parameters are dereferenced and passed without conversion
7525                        *height,
7526                        *tile_size,
7527                        *radius_scale,
7528                        *alpha_cutoff,
7529                        *max_splat_steps,
7530                        *transmittance_threshold,
7531                        *max_list_entries,
7532                        *loss_grad_clip,
7533                        *sh_band,
7534                        *max_anisotropy,
7535                    );
7536                }
7537                Step::RmsNormBackwardInput { // Calls `launch_rms_norm_bwd` with selector literal 0.
7538                    x_byte_off,
7539                    gamma_byte_off,
7540                    beta_byte_off,
7541                    dy_byte_off,
7542                    dx_byte_off,
7543                    rows,
7544                    h,
7545                    eps_bits,
7546                } => {
7547                    launch_rms_norm_bwd(
7548                        &self.ctx,
7549                        &stream,
7550                        self.arena.f32_buf_mut(),
7551                        *rows,
7552                        *h,
7553                        *x_byte_off / 4, // byte offset / 4: presumably an index in f32 elements (4 bytes each)
7554                        *gamma_byte_off / 4,
7555                        *beta_byte_off / 4,
7556                        *dy_byte_off / 4,
7557                        *dx_byte_off / 4, // this slot differs per RmsNormBackward* arm (dx / dgamma / dbeta offset)
7558                        *eps_bits, // NOTE(review): name suggests the raw bit pattern of an f32 epsilon — confirm
7559                        0, // the Gamma and Beta arms pass 1 and 2 here; the meaning is defined by launch_rms_norm_bwd
7560                    );
7561                }
7562                Step::RmsNormBackwardGamma { // Calls `launch_rms_norm_bwd` with selector literal 1.
7563                    x_byte_off,
7564                    gamma_byte_off,
7565                    beta_byte_off,
7566                    dy_byte_off,
7567                    dgamma_byte_off,
7568                    rows,
7569                    h,
7570                    eps_bits,
7571                } => {
7572                    launch_rms_norm_bwd(
7573                        &self.ctx,
7574                        &stream,
7575                        self.arena.f32_buf_mut(),
7576                        *rows,
7577                        *h,
7578                        *x_byte_off / 4, // byte offset / 4: presumably an index in f32 elements (4 bytes each)
7579                        *gamma_byte_off / 4,
7580                        *beta_byte_off / 4,
7581                        *dy_byte_off / 4,
7582                        *dgamma_byte_off / 4, // this slot differs per RmsNormBackward* arm (dx / dgamma / dbeta offset)
7583                        *eps_bits, // NOTE(review): name suggests the raw bit pattern of an f32 epsilon — confirm
7584                        1, // the Input and Beta arms pass 0 and 2 here; the meaning is defined by launch_rms_norm_bwd
7585                    );
7586                }
7587                Step::RmsNormBackwardBeta { // Calls `launch_rms_norm_bwd` with selector literal 2.
7588                    x_byte_off,
7589                    gamma_byte_off,
7590                    beta_byte_off,
7591                    dy_byte_off,
7592                    dbeta_byte_off,
7593                    rows,
7594                    h,
7595                    eps_bits,
7596                } => {
7597                    launch_rms_norm_bwd(
7598                        &self.ctx,
7599                        &stream,
7600                        self.arena.f32_buf_mut(),
7601                        *rows,
7602                        *h,
7603                        *x_byte_off / 4, // byte offset / 4: presumably an index in f32 elements (4 bytes each)
7604                        *gamma_byte_off / 4,
7605                        *beta_byte_off / 4,
7606                        *dy_byte_off / 4,
7607                        *dbeta_byte_off / 4, // this slot differs per RmsNormBackward* arm (dx / dgamma / dbeta offset)
7608                        *eps_bits, // NOTE(review): name suggests the raw bit pattern of an f32 epsilon — confirm
7609                        2, // the Input and Gamma arms pass 0 and 1 here; the meaning is defined by launch_rms_norm_bwd
7610                    );
7611                }
7612                Step::RopeBackward { // Forwards this step's fields to `launch_rope_bwd`.
7613                    dy_byte_off,
7614                    cos_byte_off,
7615                    sin_byte_off,
7616                    dx_byte_off,
7617                    batch,
7618                    seq,
7619                    hidden,
7620                    head_dim,
7621                    n_rot,
7622                    cos_len,
7623                } => {
7624                    launch_rope_bwd(
7625                        &self.ctx,
7626                        &stream,
7627                        self.arena.f32_buf_mut(),
7628                        *batch, // call order differs from field order: dimensions first, then offsets, then cos_len
7629                        *seq,
7630                        *hidden,
7631                        *head_dim,
7632                        *n_rot,
7633                        *dy_byte_off / 4, // byte offset / 4: presumably an index in f32 elements (4 bytes each)
7634                        *cos_byte_off / 4,
7635                        *sin_byte_off / 4,
7636                        *dx_byte_off / 4,
7637                        *cos_len, // passed as-is, without the /4 applied to the offsets
7638                    );
7639                }
7640                Step::CumsumBackward { // Forwards this step's fields to `launch_cumsum_bwd`.
7641                    dy_byte_off,
7642                    dx_byte_off,
7643                    rows,
7644                    cols,
7645                    exclusive,
7646                } => {
7647                    launch_cumsum_bwd(
7648                        &self.ctx,
7649                        &stream,
7650                        self.arena.f32_buf_mut(),
7651                        *rows,
7652                        *cols,
7653                        *dy_byte_off / 4, // byte offset / 4: presumably an index in f32 elements (4 bytes each)
7654                        *dx_byte_off / 4,
7655                        if *exclusive { 1 } else { 0 }, // the bool is passed to the launcher as an integer flag
7656                    );
7657                }
7658                Step::GatherBackward { // Forwards this step's fields to `launch_gather_bwd`.
7659                    dy_byte_off,
7660                    indices_byte_off,
7661                    dst_byte_off,
7662                    outer,
7663                    axis_dim,
7664                    num_idx,
7665                    trailing,
7666                } => {
7667                    launch_gather_bwd(
7668                        &self.ctx,
7669                        &stream,
7670                        self.arena.f32_buf_mut(),
7671                        *outer, // call order differs from field order: shape parameters first, then offsets
7672                        *axis_dim,
7673                        *num_idx,
7674                        *trailing,
7675                        *dy_byte_off / 4, // byte offset / 4: presumably an index in f32 elements (4 bytes each)
7676                        *indices_byte_off / 4, // indices are addressed through the same arena buffer with the same /4 conversion
7677                        *dst_byte_off / 4,
7678                    );
7679                }
7680                Step::MaxPool2dBackward { // Forwards this step's fields to `training_bwd_host::run_maxpool2d_backward`.
7681                    x_byte_off,
7682                    dy_byte_off,
7683                    dx_byte_off,
7684                    n,
7685                    c,
7686                    h,
7687                    w,
7688                    h_out,
7689                    w_out,
7690                    kh,
7691                    kw,
7692                    sh,
7693                    sw,
7694                    ph,
7695                    pw,
7696                } => {
7697                    let buf = self.arena.f32_buf_mut();
7698                    crate::training_bwd_host::run_maxpool2d_backward(
7699                        &stream,
7700                        buf,
7701                        *x_byte_off as usize / 4, // `as` binds tighter than `/`: cast to usize first, then divide by 4
7702                        *dy_byte_off as usize / 4,
7703                        *dx_byte_off as usize / 4,
7704                        *n, // shape, kernel, stride and padding fields are passed in declaration order, unconverted
7705                        *c,
7706                        *h,
7707                        *w,
7708                        *h_out,
7709                        *w_out,
7710                        *kh,
7711                        *kw,
7712                        *sh,
7713                        *sw,
7714                        *ph,
7715                        *pw,
7716                    );
7717                }
7718                Step::Conv2dBackwardInput { // Forwards this step's fields to `training_bwd_host::run_conv2d_backward_input`.
7719                    dy_byte_off,
7720                    w_byte_off,
7721                    dx_byte_off,
7722                    n,
7723                    c_in,
7724                    h,
7725                    w_in, // named w_in here; the Conv2dBackwardWeight arm calls the same positional slot `w`
7726                    c_out,
7727                    h_out,
7728                    w_out,
7729                    kh,
7730                    kw,
7731                    sh,
7732                    sw,
7733                    ph,
7734                    pw,
7735                    dh,
7736                    dw,
7737                    groups,
7738                } => {
7739                    let buf = self.arena.f32_buf_mut();
7740                    crate::training_bwd_host::run_conv2d_backward_input(
7741                        &stream,
7742                        buf,
7743                        *dy_byte_off as usize / 4, // `as` binds tighter than `/`: cast to usize first, then divide by 4
7744                        *w_byte_off as usize / 4,
7745                        *dx_byte_off as usize / 4,
7746                        *n, // remaining fields are passed in declaration order, unconverted
7747                        *c_in,
7748                        *h,
7749                        *w_in,
7750                        *c_out,
7751                        *h_out,
7752                        *w_out,
7753                        *kh,
7754                        *kw,
7755                        *sh,
7756                        *sw,
7757                        *ph,
7758                        *pw,
7759                        *dh,
7760                        *dw,
7761                        *groups,
7762                    );
7763                }
7764                Step::Conv2dBackwardWeight { // Forwards this step's fields to `training_bwd_host::run_conv2d_backward_weight`.
7765                    x_byte_off,
7766                    dy_byte_off,
7767                    dw_byte_off, // offset field; distinct from `dw_dil` below
7768                    n,
7769                    c_in,
7770                    h,
7771                    w,
7772                    c_out,
7773                    h_out,
7774                    w_out,
7775                    kh,
7776                    kw,
7777                    sh,
7778                    sw,
7779                    ph,
7780                    pw,
7781                    dh,
7782                    dw_dil, // NOTE(review): presumably the width dilation, renamed to avoid clashing with dw_byte_off — confirm
7783                    groups,
7784                } => {
7785                    let buf = self.arena.f32_buf_mut();
7786                    crate::training_bwd_host::run_conv2d_backward_weight(
7787                        &stream,
7788                        buf,
7789                        *x_byte_off as usize / 4, // `as` binds tighter than `/`: cast to usize first, then divide by 4
7790                        *dy_byte_off as usize / 4,
7791                        *dw_byte_off as usize / 4,
7792                        *n, // remaining fields are passed in declaration order, unconverted
7793                        *c_in,
7794                        *h,
7795                        *w,
7796                        *c_out,
7797                        *h_out,
7798                        *w_out,
7799                        *kh,
7800                        *kw,
7801                        *sh,
7802                        *sw,
7803                        *ph,
7804                        *pw,
7805                        *dh,
7806                        *dw_dil,
7807                        *groups,
7808                    );
7809                }
7810                Step::Pool1d { // Launches the kernel returned by `pool1d_kernel` on the arena buffer.
7811                    n,
7812                    c,
7813                    l,
7814                    l_out,
7815                    kl,
7816                    sl,
7817                    pl,
7818                    op,
7819                    in_off,
7820                    out_off,
7821                } => {
7822                    let kernel = pool1d_kernel(&self.ctx);
7823                    let total = n * c * l_out; // NOTE(review): presumably one work item per output element — confirm in the kernel source
7824                    let (grid, block) = dispatch_grid_1d(total, 256); // helper yields the grid and block extents used below
7825                    let cfg = LaunchConfig {
7826                        grid_dim: (grid, 1, 1), // 1-D launch: y and z extents are fixed at 1
7827                        block_dim: (block, 1, 1),
7828                        shared_mem_bytes: 0,
7829                    };
7830                    let mut launcher = stream.launch_builder(&kernel.function);
7831                    launcher
7832                        .arg(self.arena.f32_buf_mut()) // arena buffer goes first; the pushes below must follow the kernel's parameter order
7833                        .arg(n)
7834                        .arg(c)
7835                        .arg(l)
7836                        .arg(l_out)
7837                        .arg(kl)
7838                        .arg(sl)
7839                        .arg(pl)
7840                        .arg(op)
7841                        .arg(in_off)
7842                        .arg(out_off);
7843                    unsafe {
7844                        launcher
7845                            .launch(cfg)
7846                            .expect("rlx-cuda: pool1d launch failed"); // panics with this message if the launch returns Err
7847                    }
7848                }
7849                Step::Pool2d { // Launches the kernel returned by `pool2d_kernel` on the arena buffer.
7850                    n,
7851                    c,
7852                    h,
7853                    w,
7854                    h_out,
7855                    w_out,
7856                    kh,
7857                    kw,
7858                    sh,
7859                    sw,
7860                    ph,
7861                    pw,
7862                    op,
7863                    in_off,
7864                    out_off,
7865                } => {
7866                    let kernel = pool2d_kernel(&self.ctx);
7867                    let total = n * c * h_out * w_out; // NOTE(review): presumably one work item per output element — confirm in the kernel source
7868                    let (grid, block) = dispatch_grid_1d(total, 256); // helper yields the grid and block extents used below
7869                    let cfg = LaunchConfig {
7870                        grid_dim: (grid, 1, 1), // 1-D launch: y and z extents are fixed at 1
7871                        block_dim: (block, 1, 1),
7872                        shared_mem_bytes: 0,
7873                    };
7874                    let mut launcher = stream.launch_builder(&kernel.function);
7875                    launcher
7876                        .arg(self.arena.f32_buf_mut()) // arena buffer goes first; the pushes below must follow the kernel's parameter order
7877                        .arg(n)
7878                        .arg(c)
7879                        .arg(h)
7880                        .arg(w)
7881                        .arg(h_out)
7882                        .arg(w_out)
7883                        .arg(kh)
7884                        .arg(kw)
7885                        .arg(sh)
7886                        .arg(sw)
7887                        .arg(ph)
7888                        .arg(pw)
7889                        .arg(op)
7890                        .arg(in_off)
7891                        .arg(out_off);
7892                    unsafe {
7893                        launcher
7894                            .launch(cfg)
7895                            .expect("rlx-cuda: pool2d launch failed"); // panics with this message if the launch returns Err
7896                    }
7897                }
7898                Step::Pool3d { // Launches the kernel returned by `pool3d_kernel` on the arena buffer.
7899                    n,
7900                    c,
7901                    d,
7902                    h,
7903                    w,
7904                    d_out,
7905                    h_out,
7906                    w_out,
7907                    kd,
7908                    kh,
7909                    kw,
7910                    sd,
7911                    sh,
7912                    sw,
7913                    pd,
7914                    ph,
7915                    pw,
7916                    op,
7917                    in_off,
7918                    out_off,
7919                } => {
7920                    let kernel = pool3d_kernel(&self.ctx);
7921                    let total = n * c * d_out * h_out * w_out; // NOTE(review): presumably one work item per output element — confirm in the kernel source
7922                    let (grid, block) = dispatch_grid_1d(total, 256); // helper yields the grid and block extents used below
7923                    let cfg = LaunchConfig {
7924                        grid_dim: (grid, 1, 1), // 1-D launch: y and z extents are fixed at 1
7925                        block_dim: (block, 1, 1),
7926                        shared_mem_bytes: 0,
7927                    };
7928                    let mut launcher = stream.launch_builder(&kernel.function);
7929                    launcher
7930                        .arg(self.arena.f32_buf_mut()) // arena buffer goes first; the pushes below must follow the kernel's parameter order
7931                        .arg(n)
7932                        .arg(c)
7933                        .arg(d)
7934                        .arg(h)
7935                        .arg(w)
7936                        .arg(d_out)
7937                        .arg(h_out)
7938                        .arg(w_out)
7939                        .arg(kd)
7940                        .arg(kh)
7941                        .arg(kw)
7942                        .arg(sd)
7943                        .arg(sh)
7944                        .arg(sw)
7945                        .arg(pd)
7946                        .arg(ph)
7947                        .arg(pw)
7948                        .arg(op)
7949                        .arg(in_off)
7950                        .arg(out_off);
7951                    unsafe {
7952                        launcher
7953                            .launch(cfg)
7954                            .expect("rlx-cuda: pool3d launch failed"); // panics with this message if the launch returns Err
7955                    }
7956                }
7957                Step::Conv1d { // 1-D convolution: tries cuDNN first, then falls back to the kernel from `conv1d_kernel`.
7958                    n,
7959                    c_in,
7960                    c_out,
7961                    l,
7962                    l_out,
7963                    kl,
7964                    sl,
7965                    pl,
7966                    dl,
7967                    groups,
7968                    in_off,
7969                    w_off,
7970                    out_off,
7971                } => {
7972                    // Tier 1: cuDNN — 1-D conv as a degenerate 2-D conv
7973                    // with H=1, kh=1, sh=1, ph=0, dh=1. Same descriptors
7974                    // as conv2d; the H axis just collapses to 1.
7975                    let used_cudnn = if let (Some(handle), Some(workspace)) = // RLX_CUDA_NO_CUDNN (also honored by the Conv2d arm) masks the handle so cuDNN is skipped
7976                        (self.dnn.filter(|_| !rlx_ir::env::flag("RLX_CUDA_NO_CUDNN")), self.dnn_workspace.as_ref())
7977                    {
7978                        let mut workspace = workspace.lock().unwrap(); // panics if the workspace mutex is poisoned
7979                        let (ws_ptr, _ws_record) = workspace.device_ptr_mut(&stream); // `_ws_record` stays bound until this block ends
7980                        let (arena_ptr, _arena_record) =
7981                            self.arena.f32_buf_mut().device_ptr_mut(&stream);
7982                        let r = unsafe {
7983                            cudnn_conv2d_forward(
7984                                handle,
7985                                ws_ptr,
7986                                CUDNN_WORKSPACE_BYTES,
7987                                arena_ptr,
7988                                *n,
7989                                *c_in,
7990                                *c_out,
7991                                /*h*/ 1,
7992                                *l,
7993                                /*h_out*/ 1,
7994                                *l_out,
7995                                /*kh*/ 1,
7996                                *kl,
7997                                /*sh*/ 1,
7998                                *sl,
7999                                /*ph*/ 0,
8000                                *pl,
8001                                /*dh*/ 1,
8002                                *dl,
8003                                *groups,
8004                                *in_off,
8005                                *w_off,
8006                                *out_off,
8007                            )
8008                        };
8009                        if let Err(ref e) = r {
8010                            log_fallback("conv1d.cudnn", e); // a cuDNN error is logged, not fatal; the custom kernel below runs instead
8011                        }
8012                        r.is_ok()
8013                    } else {
8014                        false
8015                    };
8016                    if used_cudnn {
8017                        continue; // cuDNN succeeded: skip the fallback and move on to the next loop iteration
8018                    }
8019
8020                    // Fallback: custom direct-convolution kernel.
8021                    let kernel = conv1d_kernel(&self.ctx);
8022                    let total = n * c_out * l_out; // NOTE(review): presumably one work item per output element — confirm in the kernel source
8023                    let (grid, block) = dispatch_grid_1d(total, 256);
8024                    let cfg = LaunchConfig {
8025                        grid_dim: (grid, 1, 1), // 1-D launch: y and z extents are fixed at 1
8026                        block_dim: (block, 1, 1),
8027                        shared_mem_bytes: 0,
8028                    };
8029                    let mut launcher = stream.launch_builder(&kernel.function);
8030                    launcher
8031                        .arg(self.arena.f32_buf_mut()) // arena buffer goes first; the pushes below must follow the kernel's parameter order
8032                        .arg(n)
8033                        .arg(c_in)
8034                        .arg(c_out)
8035                        .arg(l)
8036                        .arg(l_out)
8037                        .arg(kl)
8038                        .arg(sl)
8039                        .arg(pl)
8040                        .arg(dl)
8041                        .arg(groups)
8042                        .arg(in_off)
8043                        .arg(w_off)
8044                        .arg(out_off);
8045                    unsafe {
8046                        launcher
8047                            .launch(cfg)
8048                            .expect("rlx-cuda: conv1d launch failed"); // panics with this message if the launch returns Err
8049                    }
8050                }
8051                Step::Conv2d { // 2-D convolution: tries cuDNN first, then falls back to the kernel from `conv2d_kernel`.
8052                    n,
8053                    c_in,
8054                    c_out,
8055                    h,
8056                    w,
8057                    h_out,
8058                    w_out,
8059                    kh,
8060                    kw,
8061                    sh,
8062                    sw,
8063                    ph,
8064                    pw,
8065                    dh,
8066                    dw,
8067                    groups,
8068                    in_off,
8069                    w_off,
8070                    out_off,
8071                } => {
8072                    // Tier 1: cuDNN — picks the fastest algo via the v7
8073                    // heuristic for the supplied shape + workspace size.
8074                    // Matmul parity (RLX_CUDA_PARITY) must not disable cuDNN conv — the
8075                    // custom conv2d.cu fallback drifts vs CPU on Deep4; cuDNN matches CPU.
8076                    let try_cudnn = self.dnn.is_some()
8077                        && self.dnn_workspace.is_some()
8078                        && !rlx_ir::env::flag("RLX_CUDA_NO_CUDNN"); // the flag is only evaluated when both handle and workspace exist
8079                    let used_cudnn = if try_cudnn {
8080                        let handle = self.dnn.expect("dnn handle"); // cannot panic here: try_cudnn implies is_some()
8081                        let workspace = self.dnn_workspace.as_ref().expect("dnn workspace");
8082                        let mut workspace = workspace.lock().unwrap(); // panics if the workspace mutex is poisoned
8083                        let (ws_ptr, _ws_record) = workspace.device_ptr_mut(&stream); // `_ws_record` stays bound until this block ends
8084                        let (arena_ptr, _arena_record) =
8085                            self.arena.f32_buf_mut().device_ptr_mut(&stream);
8086                        let r = unsafe {
8087                            cudnn_conv2d_forward(
8088                                handle,
8089                                ws_ptr,
8090                                CUDNN_WORKSPACE_BYTES,
8091                                arena_ptr,
8092                                *n,
8093                                *c_in,
8094                                *c_out,
8095                                *h,
8096                                *w,
8097                                *h_out,
8098                                *w_out,
8099                                *kh,
8100                                *kw,
8101                                *sh,
8102                                *sw,
8103                                *ph,
8104                                *pw,
8105                                *dh,
8106                                *dw,
8107                                *groups,
8108                                *in_off,
8109                                *w_off,
8110                                *out_off,
8111                            )
8112                        };
8113                        if let Err(ref e) = r {
8114                            log_fallback("conv2d.cudnn", e); // a cuDNN error is logged, not fatal; the custom kernel below runs instead
8115                        }
8116                        r.is_ok()
8117                    } else {
8118                        false
8119                    };
8120                    if used_cudnn {
8121                        continue; // cuDNN succeeded: skip the fallback and move on to the next loop iteration
8122                    }
8123
8124                    // Fallback: custom direct-convolution kernel (cuDNN preferred via PATH).
8125                    let kernel = conv2d_kernel(&self.ctx);
8126                    let total = n * c_out * h_out * w_out; // NOTE(review): presumably one work item per output element — confirm in the kernel source
8127                    let (grid, block) = dispatch_grid_1d(total, 256);
8128                    let cfg = LaunchConfig {
8129                        grid_dim: (grid, 1, 1), // 1-D launch: y and z extents are fixed at 1
8130                        block_dim: (block, 1, 1),
8131                        shared_mem_bytes: 0,
8132                    };
8133                    let mut launcher = stream.launch_builder(&kernel.function);
8134                    launcher
8135                        .arg(self.arena.f32_buf_mut()) // arena buffer goes first; the pushes below must follow the kernel's parameter order
8136                        .arg(n)
8137                        .arg(c_in)
8138                        .arg(c_out)
8139                        .arg(h)
8140                        .arg(w)
8141                        .arg(h_out)
8142                        .arg(w_out)
8143                        .arg(kh)
8144                        .arg(kw)
8145                        .arg(sh)
8146                        .arg(sw)
8147                        .arg(ph)
8148                        .arg(pw)
8149                        .arg(dh)
8150                        .arg(dw)
8151                        .arg(groups)
8152                        .arg(in_off)
8153                        .arg(w_off)
8154                        .arg(out_off);
8155                    unsafe {
8156                        launcher
8157                            .launch(cfg)
8158                            .expect("rlx-cuda: conv2d launch failed"); // panics with this message if the launch returns Err
8159                    }
8160                }
8161                Step::Conv3d {
8162                    n,
8163                    c_in,
8164                    c_out,
8165                    d,
8166                    h,
8167                    w,
8168                    d_out,
8169                    h_out,
8170                    w_out,
8171                    kd,
8172                    kh,
8173                    kw,
8174                    sd,
8175                    sh,
8176                    sw,
8177                    pd,
8178                    ph,
8179                    pw,
8180                    dd,
8181                    dh,
8182                    dw,
8183                    groups,
8184                    in_off,
8185                    w_off,
8186                    out_off,
8187                } => {
8188                    // Tier 1: cuDNN nd-conv (NCDHW + 3-D pads/strides/dilations).
8189                    let used_cudnn = if let (Some(handle), Some(workspace)) =
8190                        (self.dnn, self.dnn_workspace.as_ref())
8191                    {
8192                        let mut workspace = workspace.lock().unwrap();
8193                        let (ws_ptr, _ws_record) = workspace.device_ptr_mut(&stream);
8194                        let (arena_ptr, _arena_record) =
8195                            self.arena.f32_buf_mut().device_ptr_mut(&stream);
8196                        let r = unsafe {
8197                            cudnn_conv3d_forward(
8198                                handle,
8199                                ws_ptr,
8200                                CUDNN_WORKSPACE_BYTES,
8201                                arena_ptr,
8202                                *n,
8203                                *c_in,
8204                                *c_out,
8205                                *d,
8206                                *h,
8207                                *w,
8208                                *d_out,
8209                                *h_out,
8210                                *w_out,
8211                                *kd,
8212                                *kh,
8213                                *kw,
8214                                *sd,
8215                                *sh,
8216                                *sw,
8217                                *pd,
8218                                *ph,
8219                                *pw,
8220                                *dd,
8221                                *dh,
8222                                *dw,
8223                                *groups,
8224                                *in_off,
8225                                *w_off,
8226                                *out_off,
8227                            )
8228                        };
8229                        if let Err(ref e) = r {
8230                            log_fallback("conv3d.cudnn", e);
8231                        }
8232                        r.is_ok()
8233                    } else {
8234                        false
8235                    };
8236                    if used_cudnn {
8237                        continue;
8238                    }
8239
8240                    // Fallback: custom direct-convolution kernel.
8241                    let kernel = conv3d_kernel(&self.ctx);
8242                    let total = n * c_out * d_out * h_out * w_out;
8243                    let (grid, block) = dispatch_grid_1d(total, 256);
8244                    let cfg = LaunchConfig {
8245                        grid_dim: (grid, 1, 1),
8246                        block_dim: (block, 1, 1),
8247                        shared_mem_bytes: 0,
8248                    };
8249                    let mut launcher = stream.launch_builder(&kernel.function);
8250                    launcher
8251                        .arg(self.arena.f32_buf_mut())
8252                        .arg(n)
8253                        .arg(c_in)
8254                        .arg(c_out)
8255                        .arg(d)
8256                        .arg(h)
8257                        .arg(w)
8258                        .arg(d_out)
8259                        .arg(h_out)
8260                        .arg(w_out)
8261                        .arg(kd)
8262                        .arg(kh)
8263                        .arg(kw)
8264                        .arg(sd)
8265                        .arg(sh)
8266                        .arg(sw)
8267                        .arg(pd)
8268                        .arg(ph)
8269                        .arg(pw)
8270                        .arg(dd)
8271                        .arg(dh)
8272                        .arg(dw)
8273                        .arg(groups)
8274                        .arg(in_off)
8275                        .arg(w_off)
8276                        .arg(out_off);
8277                    unsafe {
8278                        launcher
8279                            .launch(cfg)
8280                            .expect("rlx-cuda: conv3d launch failed");
8281                    }
8282                }
8283            }
8284
8285            // Multi-stream tail: record an event so future steps can
8286            // wait on this one, then update producer_of with the
8287            // offsets this step wrote.
8288            if let Some(idx) = assigned_idx {
8289                if let Ok(evt) = stream.record_event(None) {
8290                    last_event.insert(idx, evt);
8291                }
8292                let (_, writes) = step_offsets(step);
8293                for w in &writes {
8294                    producer_of.insert(*w, idx);
8295                }
8296            }
8297        }
8298
8299        // Multi-stream: sync every pool stream so output reads see all
8300        // produced data.
8301        if multi_stream {
8302            for s in &self.streams {
8303                let _ = s.synchronize();
8304            }
8305        }
8306
8307        self.prepare_readback_plan();
8308        let plan = self.readback_plan_buf.clone();
8309        run_tail_host_audio_ops(&self.schedule, &stream, self.arena.f32_buf_mut(), true);
8310        if !self.gpu_handle_feeds.is_empty() {
8311            self.propagate_gpu_handle_feeds_d2d(&stream);
8312        }
8313        let read_all = plan.len() == self.graph.outputs.len();
8314
8315        if capturing {
8316            // End capture before dtoh — the graph records compute kernels only.
8317            let cu_graph = stream.end_capture(
8318                cudarc::driver::sys::CUgraphInstantiate_flags
8319                    ::CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH
8320            ).expect("rlx-cuda: end_capture failed");
8321            if let Some(g) = cu_graph {
8322                g.upload().expect("rlx-cuda: graph upload failed");
8323                g.launch().expect("rlx-cuda: graph first launch failed");
8324                self.captured_graph = Some(g);
8325                self.captured_readback_plan = Some(plan.clone());
8326            }
8327        }
8328
8329        if read_all {
8330            self.fill_output_staging(&stream)
8331                .expect("rlx-cuda: output dtoh failed");
8332        } else {
8333            self.fill_output_staging_indices(&stream, &plan)
8334                .expect("rlx-cuda: partial output dtoh failed");
8335        }
8336        self.refresh_gpu_handles_from_staging(&plan);
8337        stream.synchronize().expect("rlx-cuda: stream sync failed");
8338        self.outputs_from_staging_plan(&plan)
8339    }
8340
8341    fn fill_output_staging_indices(
8342        &mut self,
8343        stream: &Arc<cudarc::driver::CudaStream>,
8344        indices: &[usize],
8345    ) -> Result<(), cudarc::driver::DriverError> {
8346        for &i in indices {
8347            let id = self.graph.outputs[i];
8348            let off_f32 = self.arena.offset(id) / 4;
8349            let elems = self.graph.node(id).shape.num_elements().unwrap_or(0);
8350            debug_assert_eq!(self.output_staging[i].len(), elems);
8351            let slot = self.arena.f32_buf().slice(off_f32..off_f32 + elems);
8352            self.output_staging[i].dtoh(stream, &slot)?;
8353        }
8354        Ok(())
8355    }
8356
8357    fn outputs_from_staging_plan(&self, plan: &[usize]) -> Vec<Vec<f32>> {
8358        if plan.len() == self.graph.outputs.len() {
8359            return self.outputs_from_staging();
8360        }
8361        plan.iter()
8362            .map(|&i| self.output_staging[i].to_vec())
8363            .collect()
8364    }
8365
8366    fn fill_output_staging(
8367        &mut self,
8368        stream: &Arc<cudarc::driver::CudaStream>,
8369    ) -> Result<(), cudarc::driver::DriverError> {
8370        for (i, &id) in self.graph.outputs.iter().enumerate() {
8371            let off_f32 = self.arena.offset(id) / 4;
8372            let elems = self.graph.node(id).shape.num_elements().unwrap_or(0);
8373            debug_assert_eq!(self.output_staging[i].len(), elems);
8374            let slot = self.arena.f32_buf().slice(off_f32..off_f32 + elems);
8375            self.output_staging[i].dtoh(stream, &slot)?;
8376        }
8377        Ok(())
8378    }
8379
8380    fn outputs_from_staging(&self) -> Vec<Vec<f32>> {
8381        self.output_staging
8382            .iter()
8383            .map(F32HostSlot::to_vec)
8384            .collect()
8385    }
8386}
8387
8388fn launch_cumsum_bwd(
8389    ctx: &Arc<CudaContext>,
8390    stream: &cudarc::driver::CudaStream,
8391    buffer: &mut cudarc::driver::CudaSlice<f32>,
8392    outer: u32,
8393    inner: u32,
8394    dy_off: u32,
8395    dx_off: u32,
8396    exclusive: u32,
8397) {
8398    let kernel = cumsum_backward_kernel(ctx);
8399    let (grid, block) = dispatch_grid_1d(outer, 256);
8400    let cfg = LaunchConfig {
8401        grid_dim: (grid, 1, 1),
8402        block_dim: (block, 1, 1),
8403        shared_mem_bytes: 0,
8404    };
8405    let mut launcher = stream.launch_builder(&kernel.function);
8406    launcher
8407        .arg(buffer)
8408        .arg(&outer)
8409        .arg(&inner)
8410        .arg(&dy_off)
8411        .arg(&dx_off)
8412        .arg(&exclusive);
8413    unsafe {
8414        launcher
8415            .launch(cfg)
8416            .expect("rlx-cuda: cumsum_bwd launch failed");
8417    }
8418}
8419
8420fn launch_rope_bwd(
8421    ctx: &Arc<CudaContext>,
8422    stream: &cudarc::driver::CudaStream,
8423    buffer: &mut cudarc::driver::CudaSlice<f32>,
8424    batch: u32,
8425    seq: u32,
8426    hidden: u32,
8427    head_dim: u32,
8428    n_rot: u32,
8429    dy_off: u32,
8430    cos_off: u32,
8431    sin_off: u32,
8432    dx_off: u32,
8433    cos_len: u32,
8434) {
8435    let total = batch * seq * hidden;
8436    let kernel = rope_backward_kernel(ctx);
8437    let (grid, block) = dispatch_grid_1d(total, 256);
8438    let cfg = LaunchConfig {
8439        grid_dim: (grid, 1, 1),
8440        block_dim: (block, 1, 1),
8441        shared_mem_bytes: 0,
8442    };
8443    let mut launcher = stream.launch_builder(&kernel.function);
8444    launcher
8445        .arg(buffer)
8446        .arg(&batch)
8447        .arg(&seq)
8448        .arg(&hidden)
8449        .arg(&head_dim)
8450        .arg(&n_rot)
8451        .arg(&dy_off)
8452        .arg(&cos_off)
8453        .arg(&sin_off)
8454        .arg(&dx_off)
8455        .arg(&cos_len);
8456    unsafe {
8457        launcher
8458            .launch(cfg)
8459            .expect("rlx-cuda: rope_bwd launch failed");
8460    }
8461}
8462
8463fn launch_gather_bwd(
8464    ctx: &Arc<CudaContext>,
8465    stream: &cudarc::driver::CudaStream,
8466    buffer: &mut cudarc::driver::CudaSlice<f32>,
8467    outer: u32,
8468    axis_dim: u32,
8469    num_idx: u32,
8470    trailing: u32,
8471    dy_off: u32,
8472    idx_off: u32,
8473    dst_off: u32,
8474) {
8475    let total = outer * axis_dim * trailing;
8476    if total > 0 {
8477        let zk = rms_norm_bwd_zero_kernel(ctx);
8478        let (grid, block) = dispatch_grid_1d(total, 256);
8479        let cfg = LaunchConfig {
8480            grid_dim: (grid, 1, 1),
8481            block_dim: (block, 1, 1),
8482            shared_mem_bytes: 0,
8483        };
8484        let mut zl = stream.launch_builder(&zk.function);
8485        zl.arg(&mut *buffer).arg(&dst_off).arg(&total);
8486        unsafe {
8487            zl.launch(cfg)
8488                .expect("rlx-cuda: gather_bwd zero launch failed");
8489        }
8490    }
8491    let kernel = gather_backward_kernel(ctx);
8492    let cfg = LaunchConfig {
8493        grid_dim: (outer, (num_idx * trailing).div_ceil(256), 1),
8494        block_dim: (256, 1, 1),
8495        shared_mem_bytes: 0,
8496    };
8497    let mut launcher = stream.launch_builder(&kernel.function);
8498    launcher
8499        .arg(&mut *buffer)
8500        .arg(&outer)
8501        .arg(&axis_dim)
8502        .arg(&num_idx)
8503        .arg(&trailing)
8504        .arg(&dy_off)
8505        .arg(&idx_off)
8506        .arg(&dst_off);
8507    unsafe {
8508        launcher
8509            .launch(cfg)
8510            .expect("rlx-cuda: gather_bwd launch failed");
8511    }
8512}
8513
8514fn launch_rms_norm_bwd(
8515    ctx: &Arc<CudaContext>,
8516    stream: &cudarc::driver::CudaStream,
8517    buffer: &mut cudarc::driver::CudaSlice<f32>,
8518    rows: u32,
8519    inner: u32,
8520    x_off: u32,
8521    gamma_off: u32,
8522    beta_off: u32,
8523    dy_off: u32,
8524    out_off: u32,
8525    eps_bits: u32,
8526    wrt: u32,
8527) {
8528    if wrt != 0 {
8529        let zk = rms_norm_bwd_zero_kernel(ctx);
8530        let (grid, block) = dispatch_grid_1d(inner, 256);
8531        let cfg = LaunchConfig {
8532            grid_dim: (grid, 1, 1),
8533            block_dim: (block, 1, 1),
8534            shared_mem_bytes: 0,
8535        };
8536        let mut zl = stream.launch_builder(&zk.function);
8537        zl.arg(&mut *buffer).arg(&out_off).arg(&inner);
8538        unsafe {
8539            zl.launch(cfg)
8540                .expect("rlx-cuda: rms_norm_bwd zero launch failed");
8541        }
8542    }
8543    let kernel = rms_norm_backward_kernel(ctx);
8544    let cfg = LaunchConfig {
8545        grid_dim: (rows, 1, 1),
8546        block_dim: (256, 1, 1),
8547        shared_mem_bytes: 0,
8548    };
8549    let mut launcher = stream.launch_builder(&kernel.function);
8550    launcher
8551        .arg(&mut *buffer)
8552        .arg(&rows)
8553        .arg(&inner)
8554        .arg(&x_off)
8555        .arg(&gamma_off)
8556        .arg(&beta_off)
8557        .arg(&dy_off)
8558        .arg(&out_off)
8559        .arg(&eps_bits)
8560        .arg(&wrt);
8561    unsafe {
8562        launcher
8563            .launch(cfg)
8564            .expect("rlx-cuda: rms_norm_bwd launch failed");
8565    }
8566}
8567
#[cfg(test)]
mod tests {
    //! Unit tests for the pure helpers behind the multi-stream scheduler
    //! analysis (`normalize_read_indices`, `step_offsets`) and for the
    //! element-wise fusion pass (`fuse_elementwise_chains`). Every test feeds
    //! hand-built `Step` values to plain Rust functions, so no CUDA driver is
    //! needed and the suite also runs on a Mac.
    use super::*;

    // `Step::Binary` from positional `(n, a_off, b_off, c_off, op)`.
    macro_rules! binary_step {
        ($n:expr, $a_off:expr, $b_off:expr, $c_off:expr, $op:expr) => {
            Step::Binary {
                n: $n,
                a_off: $a_off,
                b_off: $b_off,
                c_off: $c_off,
                op: $op,
            }
        };
    }

    // `Step::Unary` from positional `(n, in_off, out_off, op)`.
    macro_rules! unary_step {
        ($n:expr, $in_off:expr, $out_off:expr, $op:expr) => {
            Step::Unary {
                n: $n,
                in_off: $in_off,
                out_off: $out_off,
                op: $op,
            }
        };
    }

    // Un-batched `Step::Matmul` (m = 4, k = 8, n = 4) reading offsets 10 and
    // 20 and writing offset 30; only the bias fields vary between tests.
    macro_rules! matmul_step {
        (has_bias: $has_bias:expr, bias_off: $bias_off:expr) => {
            Step::Matmul {
                m: 4,
                k: 8,
                n: 4,
                a_off_f32: 10,
                b_off_f32: 20,
                c_off_f32: 30,
                batch: 1,
                a_batch_stride: 0,
                b_batch_stride: 0,
                c_batch_stride: 0,
                has_bias: $has_bias,
                bias_off_f32: $bias_off,
                act_id: 0xFFFF,
            }
        };
    }

    // Single-batch, single-head `Step::Attention` with q/k/v/out at offsets
    // 0/100/200/300 and a sentinel `mask_off` of 9999; only `mask_kind`
    // varies between tests.
    macro_rules! attention_step {
        (mask_kind: $mask_kind:expr) => {{
            let (mb, mh, mq, mk) = rlx_ir::mask_strides_bhsd(1, 8, 8);
            let (qb, qh, qs) = rlx_ir::strides_bhsd(1, 64, 8);
            Step::Attention {
                batch: 1,
                heads: 1,
                seq_q: 8,
                seq_k: 8,
                head_dim: 64,
                q_off: 0,
                k_off: 100,
                v_off: 200,
                out_off: 300,
                mask_off: 9999,
                mask_kind: $mask_kind,
                scale_bits: 0,
                window: 0,
                seq_q_stride: mq,
                seq_k_stride: mk,
                mask_batch_stride: mb,
                mask_head_stride: mh,
                q_batch_stride: qb,
                q_head_stride: qh,
                q_seq_stride: qs,
                k_batch_stride: qb,
                k_head_stride: qh,
                k_seq_stride: qs,
                v_batch_stride: qb,
                v_head_stride: qh,
                v_seq_stride: qs,
                o_batch_stride: qb,
                o_head_stride: qh,
                o_seq_stride: qs,
            }
        }};
    }

    #[test]
    fn normalize_read_indices_dedupes() {
        let mut indices = vec![3, 1, 2, 1, 0];
        normalize_read_indices(&mut indices);
        assert_eq!(indices, vec![0, 1, 2, 3]);
    }

    #[test]
    fn step_offsets_binary() {
        let step = binary_step!(8, 100, 200, 300, 0);
        let (reads, writes) = step_offsets(&step);
        assert_eq!(reads, vec![100, 200]);
        assert_eq!(writes, vec![300]);
    }

    #[test]
    fn step_offsets_matmul_with_bias() {
        let step = matmul_step!(has_bias: 1, bias_off: 40);
        let (reads, writes) = step_offsets(&step);
        assert_eq!(reads, vec![10, 20, 40]);
        assert_eq!(writes, vec![30]);
    }

    #[test]
    fn step_offsets_matmul_no_bias() {
        let step = matmul_step!(has_bias: 0, bias_off: 0);
        let (reads, writes) = step_offsets(&step);
        assert_eq!(reads, vec![10, 20]);
        assert_eq!(writes, vec![30]);
    }

    #[test]
    fn step_offsets_attention_causal_no_mask_arg() {
        // mask_kind 1 = causal — mask_off ignored
        let step = attention_step!(mask_kind: 1);
        let (reads, _) = step_offsets(&step);
        assert!(
            !reads.contains(&9999),
            "causal mask must not consume mask_off"
        );
        assert_eq!(reads, vec![0, 100, 200]);
    }

    #[test]
    fn step_offsets_attention_custom_mask_pulls_mask() {
        // mask_kind 2 = custom mask
        let step = attention_step!(mask_kind: 2);
        let (reads, _) = step_offsets(&step);
        assert!(reads.contains(&9999));
    }

    #[test]
    fn step_offsets_scatter_add_acc_marks_out_as_rmw() {
        let step = Step::ScatterAddAcc {
            out_off: 100,
            upd_off: 200,
            idx_off: 300,
            num_updates: 4,
            trailing: 1,
            out_dim: 16,
        };
        let (reads, writes) = step_offsets(&step);
        // The output is read-modify-write, so its offset must show up in BOTH
        // lists — that is what lets the multi-stream scheduler force the
        // prior ScatterAddZero to complete before the accumulate launches.
        assert!(reads.contains(&100));
        assert!(writes.contains(&100));
    }

    #[test]
    fn fuse_elementwise_merges_binary_then_unary() {
        let schedule = vec![
            // c = a + b
            binary_step!(4, 0, 4, 8, 0),
            // d = relu(c)
            unary_step!(4, 8, 12, 0),
        ];
        let fused = fuse_elementwise_chains(schedule);
        assert_eq!(fused.len(), 1, "expected exactly one fused step");

        let Step::FusedBinaryUnary {
            n,
            a_off,
            b_off,
            out_off,
            bin_op,
            un_op,
        } = &fused[0]
        else {
            panic!("expected FusedBinaryUnary, got {}", step_name(&fused[0]));
        };
        assert_eq!(
            (*n, *a_off, *b_off, *out_off, *bin_op, *un_op),
            (4, 0, 4, 12, 0, 0)
        );
    }

    #[test]
    fn fuse_elementwise_skips_when_intermediate_has_two_consumers() {
        let schedule = vec![
            // c = a + b
            binary_step!(4, 0, 4, 8, 0),
            // d = relu(c)
            unary_step!(4, 8, 12, 0),
            // e = c * c   ← second consumer of c, blocks fusion
            binary_step!(4, 8, 8, 16, 2),
        ];
        let fused = fuse_elementwise_chains(schedule);
        assert_eq!(fused.len(), 3, "no fusion: c has multiple consumers");
        assert!(matches!(
            fused.as_slice(),
            [
                Step::Binary { .. },
                Step::Unary { .. },
                Step::Binary { .. }
            ]
        ));
    }

    #[test]
    fn fuse_elementwise_skips_when_n_mismatch() {
        // Different element counts → can't fuse (different launch grid).
        let schedule = vec![binary_step!(4, 0, 4, 8, 0), unary_step!(8, 8, 16, 0)];
        let fused = fuse_elementwise_chains(schedule);
        assert_eq!(fused.len(), 2);
    }

    #[test]
    fn fuse_elementwise_skips_when_unary_input_isnt_binary_output() {
        // Unary reads a different offset than what Binary wrote.
        let schedule = vec![binary_step!(4, 0, 4, 8, 0), unary_step!(4, 99, 16, 0)];
        let fused = fuse_elementwise_chains(schedule);
        assert_eq!(fused.len(), 2);
    }

    #[test]
    fn fuse_elementwise_handles_multiple_chains() {
        // Two independent Binary→Unary chains in a row — both should fuse.
        let schedule = vec![
            binary_step!(4, 0, 4, 8, 0),
            unary_step!(4, 8, 12, 0),
            binary_step!(4, 16, 20, 24, 2),
            unary_step!(4, 24, 28, 9),
        ];
        let fused = fuse_elementwise_chains(schedule);
        assert_eq!(fused.len(), 2);
        assert!(fused
            .iter()
            .all(|step| matches!(step, Step::FusedBinaryUnary { .. })));
    }
}