Skip to main content

rlx_wgpu/
backend.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
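//! `WgpuExecutable` — compiles an rlx-ir Graph into a sequence of
//! kernel dispatches against a pre-allocated arena buffer.
//!
//! Op coverage: the v2 baseline was MatMul + element-wise families
//! (Binary 7, Unary 12, Compare 6, Where) + leaves. The `Step` enum in
//! this file has since grown well beyond that baseline (reductions,
//! normalisation, attention, conv/pool, gather/scatter, FFT, and several
//! host round-trip steps), so treat `Step` as the authoritative list of
//! what can be scheduled. Anything else panics at compile time.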
21
22use std::collections::{HashMap, HashSet};
23
24use rlx_ir::dynamic::{bind_graph, has_dynamic_dims, infer_bindings_from_f32_inputs, same_binding};
25use rlx_ir::op::{Activation, BinaryOp, CmpOp, MaskKind, ReduceOp};
26use rlx_ir::shape::DimBinding;
27use rlx_ir::{Graph, NodeId, Op};
28
29use crate::buffer::{Arena, plan_f32_uniform};
30use crate::device::wgpu_device;
31use crate::kernels::{
32    ArgmaxParams, AttentionBwdParams, AttentionParams, BinaryParams, Conv1dParams, Conv2dParams,
33    Conv3dParams, CopyParams, CumsumBwdParams, CumsumParams, DequantMatmulParams,
34    ElementwiseRegionParams, ExpandParams, FusedResidualLnParams, FusedResidualLnTeeParams,
35    FusedResidualRmsNormParams, GatherAxisParams, GatherBwdParams, GatherParams,
36    GroupedMatmulParams, Kernel, LayerNormParams, MatmulParams, MatmulQkvParams,
37    NarrowConcatParams, Pool1dParams, Pool2dParams, Pool3dParams, ReduceParams, RmsNormBwdParams,
38    RopeBwdParams, RopeParams, SampleParams, ScatterAddParams, SelectiveScanParams, SoftmaxParams,
39    TopKParams, TransposeParams, UmapKnnParams, UnaryParams, WhereParams, argmax_kernel,
40    attention_bwd_kernel, attention_kernel, binary_kernel, cast_f32_to_f16_kernel, compare_kernel,
41    concat_kernel, conv1d_kernel, conv2d_kernel, conv3d_kernel, copy_kernel,
42    cumsum_backward_kernel, cumsum_kernel, dequant_matmul_kernel, elementwise_region_kernel,
43    expand_kernel, fused_residual_ln_kernel, fused_residual_ln_tee_kernel,
44    fused_residual_rms_norm_kernel, gather_axis_kernel, gather_backward_acc_kernel,
45    gather_backward_zero_kernel, gather_kernel, grouped_matmul_kernel, layernorm_kernel,
46    matmul_coop_f32_kernel, matmul_coop16_kernel, matmul_f16_compute_kernel, matmul_f16w_kernel,
47    matmul_kernel, matmul_qkv_coop_f32_kernel, matmul_qkv_kernel, matmul_wide_kernel,
48    narrow_kernel, pool1d_kernel, pool2d_kernel, pool3d_kernel, reduce_kernel,
49    rms_norm_backward_kernel, rms_norm_backward_param_kernel, rope_backward_kernel, rope_kernel,
50    sample_kernel, scatter_add_kernel, selective_scan_kernel, softmax_kernel, topk_kernel,
51    transpose_kernel, umap_knn_kernel, unary_kernel, where_kernel,
52};
53use rlx_ir::op::{ChainOperand, ChainStep};
54
/// Inner-FMA precision for matmul.
///
/// Selects which matmul kernel family a `Step::Matmul` dispatches to;
/// see the per-variant notes for the kernel file and its requirements.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MatmulCompute {
    /// Full f32 path (matmul.wgsl / matmul_wide.wgsl).
    F32,
    /// f16 multiply, f32 accumulator (matmul_f16_compute.wgsl).
    F16,
    /// Cooperative-matrix 8×8 hardware GEMM (matmul_coop16.wgsl,
    /// simdgroup_multiply_accumulate on Apple,
    /// OpCooperativeMatrixMulAddKHR on Vulkan).
    ///
    /// Requires M/N/K multiples of 8, b is a Param, and both
    /// SHADER_F16 + EXPERIMENTAL_COOPERATIVE_MATRIX.
    /// Caller must ensure A is mirrored to arena_f16 first
    /// (the lowering inserts a `Step::CastF32ToF16` pre-pass).
    ///
    /// NOTE(review): the doc on `CastF32ToF16Params` states that the
    /// cast pre-pass is currently unused because the kernel stages A
    /// from the f32 arena directly — one of the two statements is
    /// stale; confirm against the lowering code.
    Coop16,
    /// Cooperative-matrix on Apple's `simdgroup_float8x8` — same hardware
    /// GEMM unit as Coop16 but with f32 operands and f32 accumulator.
    /// No precision loss vs F32 baseline; no f16 overflow risk in deep
    /// FFN sums. Used when alignment + features allow but the IR is f32.
    CoopF32,
}
76
77/// f32 → f16 element-wise cast, mirroring an arena region into the
78/// f16 shadow buffer. Used as a pre-pass before `matmul_coop16` so
79/// the matmul's A operand (a runtime activation, not a Param) is
80/// readable as f16.
81///
82/// Currently unused — the matmul_coop16 kernel stages A through
83/// workgroup-shared memory directly from the f32 arena. Kept for
84/// future paths that may want a one-shot cast (e.g. before a chain
85/// of f16-only kernels operating on a fixed activation region).
86#[allow(dead_code)]
87#[derive(Debug, Clone, Copy)]
88struct CastF32ToF16Params {
89    pub src_off: u32, // f32-element offset into arena (also f16-element offset)
90    pub len: u32,
91    pub _p0: u32,
92    pub _p1: u32,
93}
94unsafe impl bytemuck::Pod for CastF32ToF16Params {}
95unsafe impl bytemuck::Zeroable for CastF32ToF16Params {}
96
/// One dispatch step in the compiled schedule.
///
/// `dead_code` is allowed at the enum level: several variants carry
/// fields (mask_buf, meta_idx, compute_precision discriminants) that
/// are only consulted at compile time during bind-group construction,
/// or are kept to extend buffer lifetimes (mask_buf). A few variants
/// (CastF32ToF16, Copy, the unreachable F16 compute_precision) are
/// retained for future paths.
///
/// Most variants simply wrap the `*Params` struct of the kernel they
/// dispatch (imported from `crate::kernels`). Variants that spell out
/// their fields inline are either the hand-packed matmul / FFT steps or
/// steps that `step_runs_on_host` reports as host round-trips.
///
/// Field naming, as far as the names themselves indicate (TODO confirm
/// against the lowering code): `*_off_f32` are offsets counted in f32
/// elements, `*_byte_off` are offsets counted in bytes, and `*_len`
/// accompanies the matching `*_byte_off`.
#[allow(dead_code)]
enum Step {
    /// f32 → f16 mirror of an arena region (see `CastF32ToF16Params`).
    CastF32ToF16 {
        params: CastF32ToF16Params,
    },
    /// Matmul with optional bias and activation; the kernel family is
    /// chosen by `compute_precision`.
    Matmul {
        m: u32,
        k: u32,
        n: u32,
        a_off_f32: u32,
        b_off_f32: u32,
        c_off_f32: u32,
        batch: u32,
        a_batch_stride: u32,
        b_batch_stride: u32,
        c_batch_stride: u32,
        has_bias: u32,
        bias_off_f32: u32,
        act_id: u32, // 0xFFFF = no activation
        // True iff input B is a Param node — i.e. a model weight that
        // doesn't change between `run()` calls. Read from the f16
        // shadow buffer (half memory bandwidth) when set + the device
        // exposes SHADER_F16. Set at compile time; consulted only by
        // the dispatch arm.
        b_is_param: bool,
        // Compute precision for the inner FMA. F32 = full precision
        // (the historical / default path). F16 = mixed-precision
        // (operands cast to f16, multiply in f16 for 2× ALU on Apple,
        // accumulator in f32). Set at compile time from the IR's
        // dtype after AutoMixedPrecision policy.
        compute_precision: MatmulCompute,
    },
    /// Element-wise binary op.
    Binary {
        params: BinaryParams,
    },
    /// Element-wise comparison; shares `BinaryParams` with `Binary`.
    Compare {
        params: BinaryParams,
    },
    /// Element-wise unary op.
    Unary {
        params: UnaryParams,
    },
    Where {
        params: WhereParams,
    },
    Reduce {
        params: ReduceParams,
    },
    Softmax {
        params: SoftmaxParams,
    },
    LayerNorm {
        params: LayerNormParams,
    },
    Cumsum {
        params: CumsumParams,
    },
    /// Native multi-kernel f32 FFT (gpu-fft dispatch strategy).
    FftGpu {
        src_off: u32,
        dst_off: u32,
        outer: u32,
        n: u32,
        inverse: u32,
        norm_scale: f32,
    },
    /// Explicit host FFT (D2H → rlx-cpu → H2D). Used when the native
    /// WGSL kernel cannot handle dtype / size / non-pow-2 constraints.
    ///
    /// `dtype_tag` uses the encoding of `fft_dtype_tag` /
    /// `fft_dtype_from_tag` (presumably — both live in this file and map
    /// `rlx_ir::DType` to/from a `u32`; confirm at the construction site).
    FftHost {
        src_byte_off: u32,
        dst_byte_off: u32,
        outer: u32,
        n_complex: u32,
        inverse: bool,
        norm_tag: u32,
        dtype_tag: u32,
    },
    /// Plain copy. Retained for future paths (see the enum-level note).
    Copy {
        params: CopyParams,
    },
    /// PLAN L2 — fused N-ary element-wise region. Lowered from
    /// `Op::ElementwiseRegion` by `MarkElementwiseRegions`. Kernel
    /// interprets the chain encoding per-element (saves N kernel
    /// dispatches + N global-memory round-trips vs the decomposed
    /// atomic ops).
    ElementwiseRegion {
        params: ElementwiseRegionParams,
    },
    /// Transpose. `meta_idx` selects this step's entry in
    /// `WgpuExecutable::meta_buffers`.
    Transpose {
        params: TransposeParams,
        meta_idx: usize,
    },
    Narrow {
        params: NarrowConcatParams,
    },
    /// Concat — one `Step` is emitted per input.
    Concat {
        params: NarrowConcatParams,
    }, // one Step per input
    Gather {
        params: GatherParams,
    },
    GatherAxis {
        params: GatherAxisParams,
    },
    /// Attention forward. `mask_buf` is held here to extend the mask
    /// buffer's lifetime (see the enum-level note).
    Attention {
        params: AttentionParams,
        mask_buf: Option<wgpu::Buffer>,
    },
    /// Attention backward. `mask_buf` plays the same lifetime-extending
    /// role as in `Attention`.
    AttentionBackward {
        params: AttentionBwdParams,
        mask_buf: Option<wgpu::Buffer>,
    },
    Rope {
        params: RopeParams,
    },
    /// Expand (broadcast). Carries a `meta_idx` like `Transpose`
    /// (presumably also an index into `WgpuExecutable::meta_buffers` —
    /// TODO confirm).
    Expand {
        params: ExpandParams,
        meta_idx: usize,
    },
    Argmax {
        params: ArgmaxParams,
    },
    Pool2d {
        params: Pool2dParams,
    },
    Conv2d {
        params: Conv2dParams,
    },
    Pool1d {
        params: Pool1dParams,
    },
    Pool3d {
        params: Pool3dParams,
    },
    Conv1d {
        params: Conv1dParams,
    },
    Conv3d {
        params: Conv3dParams,
    },
    ScatterAdd {
        params: ScatterAddParams,
    },
    TopK {
        params: TopKParams,
    },
    GroupedMatmul {
        params: GroupedMatmulParams,
    },
    Sample {
        params: SampleParams,
    },
    SelectiveScan {
        params: SelectiveScanParams,
    },
    DequantMatmul {
        params: DequantMatmulParams,
    },
    /// GGUF K-quant — host fused dequant+matmul between GPU segments.
    DequantMatmulGguf {
        m: u32,
        k: u32,
        n: u32,
        scheme_id: u32,
        x_byte_off: u32,
        w_byte_off: u32,
        out_byte_off: u32,
    },
    /// GGUF K-quant — host fused dequant+grouped matmul between GPU segments.
    DequantGroupedMatmulGguf {
        m: u32,
        k: u32,
        n: u32,
        num_experts: u32,
        scheme_id: u32,
        x_byte_off: u32,
        w_byte_off: u32,
        idx_byte_off: u32,
        out_byte_off: u32,
    },
    /// Gated-DeltaNet — host scan between GPU segments (qwen35 linear layers).
    GatedDeltaNet {
        q_byte_off: u32,
        k_byte_off: u32,
        v_byte_off: u32,
        g_byte_off: u32,
        beta_byte_off: u32,
        state_byte_off: u32,
        dst_byte_off: u32,
        batch: u32,
        seq: u32,
        heads: u32,
        state_size: u32,
        use_carry: bool,
    },
    /// LLaDA2 group-limited gate. Runs on the host (it is listed in
    /// `step_runs_on_host`). `attrs` is an opaque 20-byte attribute blob
    /// — its encoding is not visible in this file.
    Llada2GroupLimitedGate {
        sig_byte_off: u32,
        route_byte_off: u32,
        out_byte_off: u32,
        n_elems: u32,
        attrs: [u8; 20],
    },
    /// GPU k-NN kernel for UMAP.
    UmapKnn {
        params: UmapKnnParams,
    },
    /// Small-`n` host k-NN (partial arena read/write; avoids GPU launch overhead).
    UmapKnnHost {
        pairwise_byte_off: u32,
        out_byte_off: u32,
        n: u32,
        k: u32,
    },
    /// 3D Gaussian splat forward (CPU reference between segments).
    #[cfg(feature = "splat")]
    GaussianSplatRender {
        positions_byte_off: u32,
        positions_len: u32,
        scales_byte_off: u32,
        scales_len: u32,
        rotations_byte_off: u32,
        rotations_len: u32,
        opacities_byte_off: u32,
        opacities_len: u32,
        colors_byte_off: u32,
        colors_len: u32,
        sh_coeffs_byte_off: u32,
        sh_coeffs_len: u32,
        meta_byte_off: u32,
        dst_byte_off: u32,
        dst_len: u32,
        width: u32,
        height: u32,
        tile_size: u32,
        radius_scale: f32,
        alpha_cutoff: f32,
        max_splat_steps: u32,
        transmittance_threshold: f32,
        max_list_entries: u32,
    },
    /// Backward splat — host round-trip via rlx-cpu/splat.
    #[cfg(feature = "splat")]
    GaussianSplatRenderBackward {
        positions_byte_off: u32,
        positions_len: u32,
        scales_byte_off: u32,
        scales_len: u32,
        rotations_byte_off: u32,
        rotations_len: u32,
        opacities_byte_off: u32,
        opacities_len: u32,
        colors_byte_off: u32,
        colors_len: u32,
        sh_coeffs_byte_off: u32,
        sh_coeffs_len: u32,
        meta_byte_off: u32,
        d_loss_byte_off: u32,
        d_loss_len: u32,
        packed_byte_off: u32,
        packed_len: u32,
        width: u32,
        height: u32,
        tile_size: u32,
        radius_scale: f32,
        alpha_cutoff: f32,
        max_splat_steps: u32,
        transmittance_threshold: f32,
        max_list_entries: u32,
        loss_grad_clip: f32,
        sh_band: u32,
        max_anisotropy: f32,
    },
    /// Splat "prepare" stage; writes a `prep_*` region. Runs on the
    /// host (listed in `step_runs_on_host`).
    #[cfg(feature = "splat")]
    GaussianSplatPrepare {
        positions_byte_off: u32,
        positions_len: u32,
        scales_byte_off: u32,
        scales_len: u32,
        rotations_byte_off: u32,
        rotations_len: u32,
        opacities_byte_off: u32,
        opacities_len: u32,
        colors_byte_off: u32,
        colors_len: u32,
        sh_coeffs_byte_off: u32,
        sh_coeffs_len: u32,
        meta_byte_off: u32,
        meta_len: u32,
        prep_byte_off: u32,
        prep_len: u32,
        width: u32,
        height: u32,
        tile_size: u32,
        radius_scale: f32,
        alpha_cutoff: f32,
        max_splat_steps: u32,
        transmittance_threshold: f32,
        max_list_entries: u32,
    },
    /// Splat "rasterize" stage; reads a `prep_*` region (presumably the
    /// one written by `GaussianSplatPrepare` — TODO confirm). Runs on
    /// the host (listed in `step_runs_on_host`).
    #[cfg(feature = "splat")]
    GaussianSplatRasterize {
        prep_byte_off: u32,
        prep_len: u32,
        meta_byte_off: u32,
        meta_len: u32,
        dst_byte_off: u32,
        dst_len: u32,
        count: u32,
        width: u32,
        height: u32,
        tile_size: u32,
        alpha_cutoff: f32,
        max_splat_steps: u32,
        transmittance_threshold: f32,
        max_list_entries: u32,
    },
    // The three RmsNorm backward steps share one params struct.
    RmsNormBackwardInput {
        params: RmsNormBwdParams,
    },
    RmsNormBackwardGamma {
        params: RmsNormBwdParams,
    },
    RmsNormBackwardBeta {
        params: RmsNormBwdParams,
    },
    RopeBackward {
        params: RopeBwdParams,
    },
    CumsumBackward {
        params: CumsumBwdParams,
    },
    GatherBackward {
        params: GatherBwdParams,
    },
    FusedResidualLn {
        params: FusedResidualLnParams,
    },
    /// Split-write QKV matmul. Replaces a (FusedMatMulBiasAct → Narrow×3)
    /// pattern with one dispatch that writes Q, K, V into separate
    /// contiguous buffers from a single matmul pass. See
    /// `kernels/matmul_qkv.wgsl`.
    MatmulQkv {
        params: MatmulQkvParams,
        /// True → `matmul_qkv_coop_f32` (cooperative_matrix → simdgroup
        /// f32 hw GEMM). False → `matmul_qkv` (portable f32 tile).
        /// Both have identical bind groups and dispatch grid.
        coop: bool,
    },
    /// `fused_residual_ln_tee` — does (Add → LN) but writes the sum to
    /// a separate arena slot (the eliminated Add's old slot). Fires
    /// when the Add has multi-consumer downstream (vision pre-norm).
    FusedResidualLnTee {
        params: FusedResidualLnTeeParams,
    },
    FusedResidualRmsNorm {
        params: FusedResidualRmsNormParams,
    },
}
461
/// A graph compiled for the wgpu backend: the resolved graph, its arena,
/// the dispatch schedule, and the per-step GPU resources that go with
/// it, plus the bookkeeping needed to (re)compile lazily when the
/// source graph has dynamic dimensions.
pub struct WgpuExecutable {
    graph: Graph,
    arena: Arena,
    schedule: Vec<Step>,
    // Name → node lookups for graph inputs and params. (The fields are
    // called `*_offsets` but the values are `NodeId`s.)
    input_offsets: HashMap<String, NodeId>,
    param_offsets: HashMap<String, NodeId>,
    /// One uniform buffer + bind group per dispatch step. Pre-allocated
    /// so run() just writes new bytes per step.
    uniforms: Vec<wgpu::Buffer>,
    bind_groups: Vec<wgpu::BindGroup>,
    /// Per-step metadata storage buffers.
    /// Indexed by `Step::Transpose.meta_idx`.
    ///
    /// NOTE(review): this used to say "only Transpose uses them", but
    /// `Step::Expand` also carries a `meta_idx` — presumably it indexes
    /// this vector as well; confirm at the dispatch site.
    meta_buffers: Vec<wgpu::Buffer>,

    // ── Lazy dynamic-shape state ─────────────────────────────────
    /// The originally-supplied graph (pre-resolution). Only set when
    /// the input graph contained `Dim::Dynamic` entries — otherwise
    /// `None` and the compiled fields above are authoritative. On each
    /// `run()` we infer a `DimBinding` from the live input data, and
    /// if it differs from `last_binding` we re-resolve + recompile.
    unresolved: Option<Graph>,
    /// Binding the compiled fields were last built for; `None` until the
    /// first lazy compile has happened.
    last_binding: Option<DimBinding>,
    /// Buffered params written via `set_param` / `set_param_bytes`
    /// before the first `run()`. Replayed against the freshly compiled
    /// arena once shapes resolve.
    pending_params: HashMap<String, Vec<f32>>,
    pending_param_bytes: HashMap<String, Vec<u8>>,
    /// Active-extent hint (PLAN L1). When set + every Step in the
    /// safe set, both the uniform write and the dispatch workgroup
    /// count are scaled by `actual / upper`. Otherwise full-extent.
    pub(crate) active_extent: Option<(usize, usize)>,
    /// Skip-redundant-uniform-writes guard. Each `run()` would
    /// otherwise re-`queue.write_buffer` ~115 per-step uniforms (one
    /// per dispatched op in BERT) even when their bytes are identical
    /// to the previous call's. At small batches, that fixed write +
    /// staging-copy overhead is the dominant cost. We track the last
    /// active-extent value the uniforms were written for; subsequent
    /// `run()`s with the same `active_extent` (and `recompile`-clean
    /// schedule) skip the entire uniform-write loop. `None` ⇒ never
    /// written; `Some(x)` ⇒ uniforms hold params for active_extent=x.
    uniforms_active_extent: Option<Option<(usize, usize)>>,
    /// Per-`FftGpu` step: isolated uniform buffers + bind groups (one vec entry per op).
    fft_gpu_steps: Vec<crate::fft_dispatch::FftGpuResources>,
}
506
507impl Step {
508    /// True when this Step variant honors active-extent dispatch (PLAN L1).
509    /// Coverage: simple element-wise + reductions + matmul + linalg
510    /// + reductions/argmax/topk/sample + gather + conv + pool +
511    /// scatter (zero output + scale num_updates) + macros gated to
512    /// batch=1 (Attention, SelectiveScan).
513    pub fn safe_for_active_extent(&self) -> bool {
514        match self {
515            Step::Binary { .. }
516            | Step::Compare { .. }
517            | Step::Unary { .. }
518            | Step::Where { .. }
519            | Step::Reduce { .. }
520            | Step::Softmax { .. }
521            | Step::LayerNorm { .. }
522            | Step::FusedResidualLn { .. }
523            | Step::FusedResidualLnTee { .. }
524            | Step::FusedResidualRmsNorm { .. }
525            | Step::Cumsum { .. }
526            | Step::Copy { .. }
527            | Step::ElementwiseRegion { .. }
528            | Step::Argmax { .. }
529            | Step::TopK { .. }
530            | Step::Sample { .. }
531            | Step::Gather { .. }
532            | Step::GatherAxis { .. }
533            | Step::GroupedMatmul { .. }
534            | Step::DequantMatmul { .. }
535            | Step::DequantMatmulGguf { .. }
536            | Step::DequantGroupedMatmulGguf { .. }
537            | Step::GatedDeltaNet { .. }
538            | Step::Llada2GroupLimitedGate { .. }
539            | Step::UmapKnn { .. }
540            | Step::UmapKnnHost { .. }
541            | Step::Conv1d { .. }
542            | Step::Conv2d { .. }
543            | Step::Conv3d { .. }
544            | Step::Pool1d { .. }
545            | Step::Pool2d { .. }
546            | Step::Pool3d { .. }
547            | Step::ScatterAdd { .. } => true,
548            // FFT: full-extent transform per row, no active-extent
549            // scaling. Marking true so a graph that mixes FFT with
550            // active-extent-safe ops still gets the optimization for
551            // the rest of the schedule.
552            Step::FftGpu { .. } | Step::FftHost { .. } => true,
553            // Matmul: c_batch_stride is set at compile time at full m,
554            // independent of params.m. With scaled m, threads with
555            // global_row >= m early-return; per-batch output offsets
556            // stay correct. Safe at any batch.
557            Step::Matmul { .. } => true,
558            // Same active-extent reasoning as Matmul: per-batch output
559            // strides are baked at compile time, scaling m only adjusts
560            // the per-thread bound check.
561            Step::MatmulQkv { .. } => true,
562            Step::CastF32ToF16 { .. } => true,
563            // Attention: WGSL kernel uses `seq_q_stride`/`seq_k_stride`
564            // (full extent, set at compile time) for per-(batch, head)
565            // offset math, and `params.seq_q`/`params.seq_k` for loop
566            // bounds only. Scaling seq_q/seq_k shrinks the iteration
567            // without corrupting per-head strides. Safe at any batch.
568            Step::Attention { .. } => true,
569            Step::AttentionBackward { .. } => true,
570            // SelectiveScan: WGSL kernel uses `params.seq_stride`
571            // (full extent, set at compile time) for per-batch stride
572            // math; `params.seq` is the loop bound only. Safe at any
573            // batch under active-extent scaling of seq.
574            Step::SelectiveScan { .. } => true,
575            // Narrow + Concat: kernel iterates `params.total` in
576            // row-major order with outer as the leading dim. Scaling
577            // total by actual/upper effectively scales outer by the
578            // same factor (since total = outer * axis_size * inner).
579            // Output positions past scaled_total stay untouched.
580            // **Conservative assumption**: bucket axis is outer.
581            // Cases where the bucket axis is the narrow/concat axis
582            // itself are unsafe — fall back to full extent there.
583            Step::Narrow { .. } => true,
584            Step::Concat { .. } => true,
585            // Rope: WGSL kernel uses `seq_stride` (full extent, set
586            // at compile time) for per-batch buffer offset math and
587            // explicit `batch` for index decomposition. `params.seq`
588            // and `params.n_total` are runtime-scaled iteration
589            // bounds. Safe at any batch.
590            Step::Rope { .. } => true,
591            // Transpose: precomputed `bucket_outermost` flag in
592            // params (set to 1 at compile time iff `perm[0] == 0`).
593            // Active path scales `out_total` by `actual / upper`
594            // proportional to `out_dim_0`. Other transposes (where
595            // bucket axis moves) fall back to full extent.
596            Step::Transpose { params, .. } => params.bucket_outermost == 1,
597            // Expand: same shape as Transpose. `bucket_outermost` is
598            // 1 iff `in_dims[0] == out_dims[0]` (no broadcast at the
599            // bucket axis).
600            Step::Expand { params, .. } => params.bucket_outermost == 1,
601            // Training backward ops: not used in inference; disable
602            // active-extent fast path until individually audited.
603            Step::RmsNormBackwardInput { .. }
604            | Step::RmsNormBackwardGamma { .. }
605            | Step::RmsNormBackwardBeta { .. }
606            | Step::RopeBackward { .. }
607            | Step::CumsumBackward { .. }
608            | Step::GatherBackward { .. } => false,
609            #[cfg(feature = "splat")]
610            Step::GaussianSplatRender { .. }
611            | Step::GaussianSplatRenderBackward { .. }
612            | Step::GaussianSplatPrepare { .. }
613            | Step::GaussianSplatRasterize { .. } => false,
614        }
615    }
616}
617
618/// Static-string label for each Step variant — used by the Perfetto
619/// trace layer (PLAN L3) to mark per-step events without allocating.
620fn fft_dtype_tag(dtype: rlx_ir::DType) -> u32 {
621    match dtype {
622        rlx_ir::DType::F32 => 0,
623        rlx_ir::DType::F64 => 1,
624        rlx_ir::DType::C64 => 2,
625        other => panic!("rlx-wgpu Op::Fft: unsupported dtype {other:?}"),
626    }
627}
628
629fn fft_dtype_from_tag(tag: u32) -> rlx_ir::DType {
630    match tag {
631        0 => rlx_ir::DType::F32,
632        1 => rlx_ir::DType::F64,
633        2 => rlx_ir::DType::C64,
634        other => panic!("rlx-wgpu Op::Fft: bad dtype tag {other}"),
635    }
636}
637
638fn step_name(step: &Step) -> &'static str {
639    match step {
640        Step::CastF32ToF16 { .. } => "cast_f32_to_f16",
641        Step::Matmul { .. } => "matmul",
642        Step::Binary { .. } => "binary",
643        Step::Compare { .. } => "compare",
644        Step::Unary { .. } => "unary",
645        Step::Where { .. } => "where",
646        Step::Reduce { .. } => "reduce",
647        Step::Softmax { .. } => "softmax",
648        Step::LayerNorm { .. } => "layer_norm",
649        Step::Cumsum { .. } => "cumsum",
650        Step::FftGpu { .. } => "fft_gpu",
651        Step::FftHost { .. } => "fft_host",
652        Step::Copy { .. } => "copy",
653        Step::Transpose { .. } => "transpose",
654        Step::Narrow { .. } => "narrow",
655        Step::Concat { .. } => "concat",
656        Step::Gather { .. } => "gather",
657        Step::GatherAxis { .. } => "gather_axis",
658        Step::Attention { .. } => "attention",
659        Step::AttentionBackward { .. } => "attention_bwd",
660        Step::Rope { .. } => "rope",
661        Step::Expand { .. } => "expand",
662        Step::Argmax { .. } => "argmax",
663        Step::Pool2d { .. } => "pool2d",
664        Step::Conv2d { .. } => "conv2d",
665        Step::Pool1d { .. } => "pool1d",
666        Step::Pool3d { .. } => "pool3d",
667        Step::Conv1d { .. } => "conv1d",
668        Step::Conv3d { .. } => "conv3d",
669        Step::ScatterAdd { .. } => "scatter_add",
670        Step::TopK { .. } => "topk",
671        Step::GroupedMatmul { .. } => "grouped_matmul",
672        Step::Sample { .. } => "sample",
673        Step::SelectiveScan { .. } => "selective_scan",
674        Step::DequantMatmul { .. } => "dequant_matmul",
675        Step::DequantMatmulGguf { .. } => "dequant_matmul_gguf",
676        Step::DequantGroupedMatmulGguf { .. } => "dequant_grouped_matmul_gguf",
677        Step::GatedDeltaNet { .. } => "gated_delta_net",
678        Step::Llada2GroupLimitedGate { .. } => "llada2_group_limited_gate",
679        Step::UmapKnn { .. } => "umap_knn",
680        Step::UmapKnnHost { .. } => "umap_knn_host",
681        #[cfg(feature = "splat")]
682        Step::GaussianSplatRender { .. } => "gaussian_splat_render",
683        #[cfg(feature = "splat")]
684        Step::GaussianSplatRenderBackward { .. } => "gaussian_splat_render_backward",
685        #[cfg(feature = "splat")]
686        Step::GaussianSplatPrepare { .. } => "gaussian_splat_prepare",
687        #[cfg(feature = "splat")]
688        Step::GaussianSplatRasterize { .. } => "gaussian_splat_rasterize",
689        Step::RmsNormBackwardInput { .. } => "rms_norm_backward_input",
690        Step::RmsNormBackwardGamma { .. } => "rms_norm_backward_gamma",
691        Step::RmsNormBackwardBeta { .. } => "rms_norm_backward_beta",
692        Step::RopeBackward { .. } => "rope_backward",
693        Step::CumsumBackward { .. } => "cumsum_backward",
694        Step::GatherBackward { .. } => "gather_backward",
695        Step::FusedResidualLn { .. } => "fused_residual_ln",
696        Step::FusedResidualLnTee { .. } => "fused_residual_ln_tee",
697        Step::FusedResidualRmsNorm { .. } => "fused_residual_rms_norm",
698        Step::MatmulQkv { .. } => "matmul_qkv",
699        Step::ElementwiseRegion { .. } => "elementwise_region",
700    }
701}
702
703fn step_runs_on_host(step: &Step) -> bool {
704    match step {
705        Step::DequantMatmulGguf { .. }
706        | Step::DequantGroupedMatmulGguf { .. }
707        | Step::GatedDeltaNet { .. }
708        | Step::Llada2GroupLimitedGate { .. }
709        | Step::UmapKnnHost { .. }
710        | Step::FftHost { .. } => true,
711        #[cfg(feature = "splat")]
712        Step::GaussianSplatRender { .. }
713        | Step::GaussianSplatRenderBackward { .. }
714        | Step::GaussianSplatPrepare { .. }
715        | Step::GaussianSplatRasterize { .. } => true,
716        _ => false,
717    }
718}
719
720fn binary_op_id(op: BinaryOp) -> u32 {
721    match op {
722        BinaryOp::Add => 0,
723        BinaryOp::Sub => 1,
724        BinaryOp::Mul => 2,
725        BinaryOp::Div => 3,
726        BinaryOp::Max => 4,
727        BinaryOp::Min => 5,
728        BinaryOp::Pow => 6,
729    }
730}
731
732fn compare_op_id(op: CmpOp) -> u32 {
733    match op {
734        CmpOp::Eq => 0,
735        CmpOp::Ne => 1,
736        CmpOp::Lt => 2,
737        CmpOp::Le => 3,
738        CmpOp::Gt => 4,
739        CmpOp::Ge => 5,
740    }
741}
742
743fn reduce_op_id(op: ReduceOp) -> u32 {
744    match op {
745        ReduceOp::Sum => 0,
746        ReduceOp::Mean => 1,
747        ReduceOp::Max => 2,
748        ReduceOp::Min => 3,
749        ReduceOp::Prod => 4,
750    }
751}
752
753fn activation_op_id(act: Activation) -> u32 {
754    match act {
755        Activation::Relu => 0,
756        Activation::Sigmoid => 1,
757        Activation::Tanh => 2,
758        Activation::Exp => 3,
759        Activation::Log => 4,
760        Activation::Sqrt => 5,
761        Activation::Rsqrt => 6,
762        Activation::Neg => 7,
763        Activation::Abs => 8,
764        Activation::Gelu => 9,
765        Activation::Silu => 10,
766        Activation::GeluApprox => 11,
767        Activation::Round => 12,
768        Activation::Sin => 13,
769        Activation::Cos => 14,
770        Activation::Tan => 15,
771        Activation::Atan => 16,
772    }
773}
774
775impl WgpuExecutable {
776    /// Resolve the deferred graph against bindings inferred from
777    /// `inputs`, recompile the inner state if the bindings changed
778    /// since the last call, and replay any pending params.
779    fn lazy_compile_for_inputs(&mut self, inputs: &[(&str, &[f32])]) {
780        let unresolved = self
781            .unresolved
782            .as_ref()
783            .expect("lazy_compile_for_inputs called without an unresolved graph");
784        let binding = infer_bindings_from_f32_inputs(unresolved, inputs)
785            .expect("rlx-wgpu lazy compile: could not infer DimBinding from inputs");
786
787        // No-op if shapes haven't changed since the last compile.
788        if let Some(prev) = &self.last_binding
789            && same_binding(prev, &binding)
790        {
791            return;
792        }
793
794        // Resolve and recompile.
795        let resolved = bind_graph(unresolved, &binding);
796        let original = self.unresolved.take();
797        let pending_params = std::mem::take(&mut self.pending_params);
798        let pending_bytes = std::mem::take(&mut self.pending_param_bytes);
799
800        let fresh = Self::compile_static_inner(resolved);
801
802        // Move the freshly-compiled fields into self, preserve the
803        // unresolved+binding state for the next round.
804        self.graph = fresh.graph;
805        self.arena = fresh.arena;
806        self.schedule = fresh.schedule;
807        self.input_offsets = fresh.input_offsets;
808        self.param_offsets = fresh.param_offsets;
809        self.uniforms = fresh.uniforms;
810        self.bind_groups = fresh.bind_groups;
811        self.meta_buffers = fresh.meta_buffers;
812        self.unresolved = original;
813        self.last_binding = Some(binding);
814        // Recompiled — uniforms are now empty buffers; force re-write
815        // on next run().
816        self.uniforms_active_extent = None;
817
818        // Replay pending param uploads against the new arena.
819        for (name, data) in pending_params {
820            self.set_param(&name, &data);
821        }
822        for (name, data) in pending_bytes {
823            self.set_param_bytes(&name, &data);
824        }
825    }
826
827    /// Compile against an explicit `DimBinding`. Each `Dim::Dynamic`
828    /// in the graph that maps to a symbol in `bindings` is replaced
829    /// with `Dim::Static(size)` before the standard compile runs.
830    /// Symbols not in the binding stay dynamic — and then `compile`
831    /// will panic with the usual diagnostic.
832    pub fn compile_with_bindings(graph: Graph, bindings: &DimBinding) -> Self {
833        if bindings.is_empty() {
834            return Self::compile(graph);
835        }
836        // Walk the graph and bind every node's shape.
837        let mut fresh = Graph::new(&graph.name);
838        for node in graph.nodes() {
839            let bound = node.shape.bind(bindings);
840            fresh.add_node(node.op.clone(), node.inputs.clone(), bound);
841        }
842        fresh.set_outputs(graph.outputs.clone());
843        Self::compile(fresh)
844    }
845
846    pub fn compile(graph: Graph) -> Self {
847        if has_dynamic_dims(&graph) {
848            return Self::deferred(graph);
849        }
850        Self::compile_static_inner(graph)
851    }
852
853    /// Compile placeholder for a graph with `Dim::Dynamic` entries.
854    /// The real compile happens on the first `run()` once input data
855    /// reveals the symbol → size bindings. Buffered params (set via
856    /// `set_param` / `set_param_bytes` before run) are replayed.
857    fn deferred(graph: Graph) -> Self {
858        let dev = wgpu_device().expect("rlx-wgpu: no compatible adapter found");
859        // Minimal valid arena buffer. Replaced on first run().
860        let placeholder = dev.device.create_buffer(&wgpu::BufferDescriptor {
861            label: Some("rlx-wgpu deferred placeholder"),
862            size: 16,
863            usage: wgpu::BufferUsages::STORAGE
864                | wgpu::BufferUsages::COPY_DST
865                | wgpu::BufferUsages::COPY_SRC,
866            mapped_at_creation: false,
867        });
868        let arena = Arena {
869            buffer: placeholder,
870            f16_buffer: None,
871            offsets: HashMap::new(),
872            lens: HashMap::new(),
873            size: 0,
874        };
875        Self {
876            graph: graph.clone(),
877            arena,
878            schedule: Vec::new(),
879            input_offsets: HashMap::new(),
880            param_offsets: HashMap::new(),
881            uniforms: Vec::new(),
882            bind_groups: Vec::new(),
883            meta_buffers: Vec::new(),
884            unresolved: Some(graph),
885            last_binding: None,
886            pending_params: HashMap::new(),
887            pending_param_bytes: HashMap::new(),
888            active_extent: None,
889            uniforms_active_extent: None,
890            fft_gpu_steps: Vec::new(),
891        }
892    }
893
894    /// Hint the next `run` to process only the first `actual` rows
895    /// along the bucket axis (out of `upper`, the compile extent).
896    /// Honored when every Step is in the safe set. See PLAN L1.
897    pub fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
898        self.active_extent = extent;
899    }
900
901    fn all_safe_for_active(&self) -> bool {
902        self.schedule.iter().all(|s| s.safe_for_active_extent())
903    }
904
905    fn compile_static_inner(graph: Graph) -> Self {
906        let dev = wgpu_device().expect("rlx-wgpu: no compatible adapter found");
907
908        // Decompose composed/fused ops (FusedMatMulBiasAct, LoraMatMul,
909        // FusedAttentionBlock, FusedTransformerLayer, ...) into primitive
910        // sequences before memory planning so every intermediate gets a
911        // regular arena slot. CPU/Metal/MLX lower the fused variants
912        // directly with bespoke kernels; we choose simplicity over peak
913        // throughput here.
914        let graph = crate::unfuse::unfuse(graph);
915
916        // f32-uniform slots + liveness reuse (pairwise `[n,n]` graphs).
917        let plan = plan_f32_uniform(&graph, 16);
918        let mut arena = Arena::from_plan(&dev.device, &plan);
919        // Override slot lengths with the actual elem*4 byte counts so
920        // readback returns the right element count (slots may be
921        // padded for alignment).
922        for node in graph.nodes() {
923            let elems = node.shape.num_elements().unwrap_or(0);
924            arena.set_actual_len(node.id, elems * 4);
925        }
926
927        // Initialize Constants directly into the arena.
928        for node in graph.nodes() {
929            if let Op::Constant { data } = &node.op
930                && arena.has(node.id)
931                && !data.is_empty()
932            {
933                let bytes_to_write = data.len().min(arena.len_of(node.id));
934                dev.queue.write_buffer(
935                    &arena.buffer,
936                    arena.offset(node.id) as u64,
937                    &data[..bytes_to_write],
938                );
939            }
940        }
941
942        let mut input_offsets = HashMap::new();
943        let mut param_offsets = HashMap::new();
944        for node in graph.nodes() {
945            match &node.op {
946                Op::Input { name } => {
947                    input_offsets.insert(name.clone(), node.id);
948                }
949                Op::Param { name } => {
950                    param_offsets.insert(name.clone(), node.id);
951                }
952                _ => {}
953            }
954        }
955
956        let mm_k = matmul_kernel(&dev.device);
957        let mm_w = matmul_wide_kernel(&dev.device);
958        let mm_f16w = matmul_f16w_kernel(&dev.device);
959        let mm_f16c = matmul_f16_compute_kernel(&dev.device);
960        let mm_coop = matmul_coop16_kernel(&dev.device);
961        let mm_coop_f32 = matmul_coop_f32_kernel(&dev.device);
962        let mm_cast = cast_f32_to_f16_kernel(&dev.device);
963        let bk = binary_kernel(&dev.device);
964        let uk = unary_kernel(&dev.device);
965        let ck = compare_kernel(&dev.device);
966        let wk = where_kernel(&dev.device);
967
968        let mut schedule = Vec::new();
969        let mut uniforms = Vec::new();
970        let mut bind_groups = Vec::new();
971        let mut fft_gpu_steps: Vec<crate::fft_dispatch::FftGpuResources> = Vec::new();
972        let mut gguf_host_pad: Option<(wgpu::Buffer, wgpu::BindGroup)> = None;
973        let mut meta_buffers: Vec<wgpu::Buffer> = Vec::new();
974
975        // Detect (FusedMatMulBiasAct → Narrow×3) split-QKV pattern. Returns
976        // a map parent_node_id → (q_narrow_id, k_narrow_id, v_narrow_id).
977        // The matmul_qkv kernel collapses the matmul + 3 narrows into one
978        // dispatch by routing each output column to the right Q/K/V sink.
979        //
        // CRITICAL: only mark a pattern site for elision when the parent
        // FMB will actually take a MatmulQkv path (F32 → matmul_qkv,
        // CoopF32 → matmul_qkv_coop_f32). For Coop16-eligible FMBs the
        // kernel writes to the FMB's *own* output slot, NOT the 3 narrow
        // slots — skipping the narrows would leave Q/K/V uninitialized
        // and attention would read garbage. Predict the compute precision
        // the FMB will receive; only skip when it is F32 or CoopF32.
987        let mut qkv_split: HashMap<NodeId, (NodeId, NodeId, NodeId)> = HashMap::new();
988        for (parent_id, qkv) in detect_split_qkv_pattern(&graph) {
989            let parent = graph.node(parent_id);
990            // Mirror the lowering's precision derivation. FMB inputs:
991            // [a, w, bias]; we need (m, k, n) to query.
992            let a_id = parent.inputs[0];
993            let b_id = parent.inputs[1];
994            let a_dims = graph.node(a_id).shape.dims();
995            let b_dims = graph.node(b_id).shape.dims();
996            let out_dims = parent.shape.dims();
997            let (m, k, n) =
998                if a_dims.len() >= 2 && b_dims.len() == 2 && out_dims.len() == a_dims.len() {
999                    let leading: usize = a_dims[..a_dims.len() - 2]
1000                        .iter()
1001                        .map(|d| d.unwrap_static())
1002                        .product();
1003                    let m_inner = a_dims[a_dims.len() - 2].unwrap_static();
1004                    let k_inner = a_dims[a_dims.len() - 1].unwrap_static();
1005                    let n_inner = b_dims[1].unwrap_static();
1006                    ((leading * m_inner) as u32, k_inner as u32, n_inner as u32)
1007                } else if a_dims.len() == 2 && b_dims.len() == 2 {
1008                    (
1009                        a_dims[0].unwrap_static() as u32,
1010                        a_dims[1].unwrap_static() as u32,
1011                        b_dims[1].unwrap_static() as u32,
1012                    )
1013                } else {
1014                    continue; // unusual shape — let the regular FMB path handle
1015                };
1016            let cp = derive_matmul_compute(&dev.device, &graph, a_id, b_id, m, k, n);
1017            // F32 → matmul_qkv. CoopF32 → matmul_qkv_coop_f32. Both write
1018            // Q/K/V into the narrow output slots, so the narrows can be
1019            // elided. Coop16 still falls back to FMB+narrows (kernel
1020            // would need an f16-acc variant; deferred).
1021            if cp == MatmulCompute::F32 || cp == MatmulCompute::CoopF32 {
1022                qkv_split.insert(parent_id, qkv);
1023            }
1024        }
1025        let qkv_skip_narrows: HashSet<NodeId> = qkv_split
1026            .values()
1027            .flat_map(|&(q, k, v)| [q, k, v])
1028            .collect();
1029
1030        // Detect (Add → LayerNorm) where Add has multi-consumer downstream.
1031        // The standard `FuseResidualLN` pass declines to fuse these (its
1032        // single-consumer guard forces materializing the sum); we collapse
1033        // them here at the wgpu lowering level via `Step::FusedResidualLnTee`.
1034        // Returns:
1035        //   ln_to_tee: ln_id  → (h, delta, gamma, beta, sum_arena_id)
1036        //   skip_adds: { add_id }  — these Add nodes are computed by the
1037        //                            tee step; their normal Step emission
1038        //                            is suppressed.
1039        let (ln_to_tee, skip_adds) = detect_residual_ln_tee_pattern(&graph);
1040
1041        let emit_uniform = |size: usize| -> wgpu::Buffer {
1042            dev.device.create_buffer(&wgpu::BufferDescriptor {
1043                label: Some("rlx-wgpu uniform"),
1044                size: size as u64,
1045                usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
1046                mapped_at_creation: false,
1047            })
1048        };
1049
1050        for node in graph.nodes() {
1051            // Helpers — capture device + arena into closures isn't
1052            // ergonomic in the loop, so inline the bind-group build
1053            // when each step is emitted below.
1054            let elems = node.shape.num_elements().unwrap_or(0) as u32;
1055            match &node.op {
1056                Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => continue,
1057                Op::MatMul => {
1058                    let a_id = node.inputs[0];
1059                    let b_id = node.inputs[1];
1060                    let a_shape = graph.node(a_id).shape.dims();
1061                    let b_shape = graph.node(b_id).shape.dims();
1062                    let out_shape = node.shape.dims();
1063                    // Three patterns:
1064                    //   • 2D×2D                              → batch=1
1065                    //   • [..,M,K] × [K,N]  (broadcast rhs)  → batch=1, flatten leading into M
1066                    //   • [..,M,K] × [..,K,N] (matched batch)→ batch=prod(leading), per-batch strides
1067                    let (m, k, n, batch, a_bs, b_bs, c_bs) = if a_shape.len() == 2
1068                        && b_shape.len() == 2
1069                        && out_shape.len() == 2
1070                    {
1071                        (
1072                            a_shape[0].unwrap_static() as u32,
1073                            a_shape[1].unwrap_static() as u32,
1074                            b_shape[1].unwrap_static() as u32,
1075                            1u32,
1076                            0u32,
1077                            0u32,
1078                            0u32,
1079                        )
1080                    } else if a_shape.len() >= 2
1081                        && b_shape.len() == 2
1082                        && out_shape.len() == a_shape.len()
1083                    {
1084                        let leading: usize = a_shape[..a_shape.len() - 2]
1085                            .iter()
1086                            .map(|d| d.unwrap_static())
1087                            .product();
1088                        let m_inner = a_shape[a_shape.len() - 2].unwrap_static();
1089                        let k_inner = a_shape[a_shape.len() - 1].unwrap_static();
1090                        let n_inner = b_shape[1].unwrap_static();
1091                        (
1092                            (leading * m_inner) as u32,
1093                            k_inner as u32,
1094                            n_inner as u32,
1095                            1u32,
1096                            0u32,
1097                            0u32,
1098                            0u32,
1099                        )
1100                    } else if a_shape.len() == b_shape.len()
1101                        && a_shape.len() >= 3
1102                        && out_shape.len() == a_shape.len()
1103                    {
1104                        // True batched: leading dims must match.
1105                        let leading_a: Vec<usize> = a_shape[..a_shape.len() - 2]
1106                            .iter()
1107                            .map(|d| d.unwrap_static())
1108                            .collect();
1109                        let leading_b: Vec<usize> = b_shape[..b_shape.len() - 2]
1110                            .iter()
1111                            .map(|d| d.unwrap_static())
1112                            .collect();
1113                        if leading_a != leading_b {
1114                            panic!(
1115                                "rlx-wgpu MatMul: batched shape mismatch \
1116                                    a_leading={leading_a:?} b_leading={leading_b:?}"
1117                            );
1118                        }
1119                        let b_count: usize = leading_a.iter().product();
1120                        let m_inner = a_shape[a_shape.len() - 2].unwrap_static();
1121                        let k_inner = a_shape[a_shape.len() - 1].unwrap_static();
1122                        let n_inner = b_shape[b_shape.len() - 1].unwrap_static();
1123                        (
1124                            m_inner as u32,
1125                            k_inner as u32,
1126                            n_inner as u32,
1127                            b_count as u32,
1128                            (m_inner * k_inner) as u32,
1129                            (k_inner * n_inner) as u32,
1130                            (m_inner * n_inner) as u32,
1131                        )
1132                    } else {
1133                        panic!(
1134                            "rlx-wgpu MatMul: unsupported shapes a={a_shape:?} b={b_shape:?} \
1135                                out={out_shape:?} (supported: 2D×2D, [..,M,K]×[K,N], [..,M,K]×[..,K,N])"
1136                        );
1137                    };
1138                    let b_is_param = traces_to_param(&graph, b_id);
1139                    let compute_precision =
1140                        derive_matmul_compute(&dev.device, &graph, a_id, b_id, m, k, n);
1141                    // No cast pre-pass needed for Coop16 anymore — the
1142                    // kernel stages A through workgroup-shared memory
1143                    // directly from the f32 arena.
1144                    let _ = mm_cast;
1145                    schedule.push(Step::Matmul {
1146                        m,
1147                        k,
1148                        n,
1149                        batch,
1150                        a_batch_stride: a_bs,
1151                        b_batch_stride: b_bs,
1152                        c_batch_stride: c_bs,
1153                        a_off_f32: (arena.offset(a_id) / 4) as u32,
1154                        b_off_f32: (arena.offset(b_id) / 4) as u32,
1155                        c_off_f32: (arena.offset(node.id) / 4) as u32,
1156                        has_bias: 0,
1157                        bias_off_f32: 0,
1158                        act_id: 0xFFFF,
1159                        b_is_param,
1160                        compute_precision,
1161                    });
1162                    let u = emit_uniform(std::mem::size_of::<MatmulParams>());
1163                    let bg = build_matmul_bind_group(
1164                        &dev.device,
1165                        mm_k,
1166                        mm_w,
1167                        &mm_f16w,
1168                        &mm_f16c,
1169                        &mm_coop,
1170                        &mm_coop_f32,
1171                        &arena,
1172                        &u,
1173                        b_is_param,
1174                        compute_precision,
1175                    );
1176                    uniforms.push(u);
1177                    bind_groups.push(bg);
1178                }
1179                Op::Binary(bop) => {
1180                    // Skip emit when this Add is consumed by a downstream
1181                    // FRLTee — the tee step writes the sum to this node's
1182                    // arena slot directly. Subsequent consumers read the
1183                    // same slot and find correct data.
1184                    if skip_adds.contains(&node.id) {
1185                        continue;
1186                    }
1187                    require_equal_shapes(&graph, &node.inputs, "Binary");
1188                    let p = BinaryParams {
1189                        n: elems,
1190                        a_off: (arena.offset(node.inputs[0]) / 4) as u32,
1191                        b_off: (arena.offset(node.inputs[1]) / 4) as u32,
1192                        c_off: (arena.offset(node.id) / 4) as u32,
1193                        op: binary_op_id(*bop),
1194                        _p0: 0,
1195                        _p1: 0,
1196                        _p2: 0,
1197                    };
1198                    schedule.push(Step::Binary { params: p });
1199                    let u = emit_uniform(std::mem::size_of::<BinaryParams>());
1200                    let bg = bind_two(&dev.device, bk, &arena.buffer, &u);
1201                    uniforms.push(u);
1202                    bind_groups.push(bg);
1203                }
1204                Op::Compare(cop) => {
1205                    require_equal_shapes(&graph, &node.inputs, "Compare");
1206                    let p = BinaryParams {
1207                        n: elems,
1208                        a_off: (arena.offset(node.inputs[0]) / 4) as u32,
1209                        b_off: (arena.offset(node.inputs[1]) / 4) as u32,
1210                        c_off: (arena.offset(node.id) / 4) as u32,
1211                        op: compare_op_id(*cop),
1212                        _p0: 0,
1213                        _p1: 0,
1214                        _p2: 0,
1215                    };
1216                    schedule.push(Step::Compare { params: p });
1217                    let u = emit_uniform(std::mem::size_of::<BinaryParams>());
1218                    let bg = bind_two(&dev.device, ck, &arena.buffer, &u);
1219                    uniforms.push(u);
1220                    bind_groups.push(bg);
1221                }
1222                Op::Activation(act) => {
1223                    let p = UnaryParams {
1224                        n: elems,
1225                        in_off: (arena.offset(node.inputs[0]) / 4) as u32,
1226                        out_off: (arena.offset(node.id) / 4) as u32,
1227                        op: activation_op_id(*act),
1228                        _p0: 0,
1229                        _p1: 0,
1230                        _p2: 0,
1231                        _p3: 0,
1232                    };
1233                    schedule.push(Step::Unary { params: p });
1234                    let u = emit_uniform(std::mem::size_of::<UnaryParams>());
1235                    let bg = bind_two(&dev.device, uk, &arena.buffer, &u);
1236                    uniforms.push(u);
1237                    bind_groups.push(bg);
1238                }
1239                Op::Where => {
1240                    let p = WhereParams {
1241                        n: elems,
1242                        cond_off: (arena.offset(node.inputs[0]) / 4) as u32,
1243                        x_off: (arena.offset(node.inputs[1]) / 4) as u32,
1244                        y_off: (arena.offset(node.inputs[2]) / 4) as u32,
1245                        out_off: (arena.offset(node.id) / 4) as u32,
1246                        _p0: 0,
1247                        _p1: 0,
1248                        _p2: 0,
1249                    };
1250                    schedule.push(Step::Where { params: p });
1251                    let u = emit_uniform(std::mem::size_of::<WhereParams>());
1252                    let bg = bind_two(&dev.device, wk, &arena.buffer, &u);
1253                    uniforms.push(u);
1254                    bind_groups.push(bg);
1255                }
1256
1257                Op::ElementwiseRegion {
1258                    chain,
1259                    num_inputs,
1260                    scalar_input_mask,
1261                    input_modulus,
1262                } => {
1263                    // PLAN L2 native lowering. Encode the chain into a
1264                    // fixed-size u32 buffer; one uniform per region.
1265                    let n = *num_inputs as usize;
1266                    if n > 16 || chain.len() > 32 {
1267                        panic!(
1268                            "rlx-wgpu ElementwiseRegion: chain too large \
1269                                (inputs={n}, steps={}). Caps: 16 / 32. \
1270                                Use UnfuseElementwiseRegions to fall back.",
1271                            chain.len()
1272                        );
1273                    }
1274                    let mut input_offs = [0u32; 16];
1275                    for (i, &id) in node.inputs.iter().enumerate() {
1276                        input_offs[i] = (arena.offset(id) / 4) as u32;
1277                    }
1278                    let encode_operand = |op: &ChainOperand| -> u32 {
1279                        match *op {
1280                            ChainOperand::Input(i) => i & 0x7FFF_FFFFu32,
1281                            ChainOperand::Step(i) => 0x8000_0000u32 | (i & 0x7FFF_FFFFu32),
1282                        }
1283                    };
1284                    let act_sub = |a: Activation| match a {
1285                        Activation::Gelu => 0u32,
1286                        Activation::GeluApprox => 1,
1287                        Activation::Silu => 2,
1288                        Activation::Relu => 3,
1289                        Activation::Sigmoid => 4,
1290                        Activation::Tanh => 5,
1291                        Activation::Exp => 6,
1292                        Activation::Log => 7,
1293                        Activation::Sqrt => 8,
1294                        Activation::Rsqrt => 9,
1295                        Activation::Neg => 10,
1296                        Activation::Abs => 11,
1297                        Activation::Round => 12,
1298                        Activation::Sin => 13,
1299                        Activation::Cos => 14,
1300                        Activation::Tan => 15,
1301                        Activation::Atan => 16,
1302                    };
1303                    let bin_sub = |b: BinaryOp| match b {
1304                        BinaryOp::Add => 0u32,
1305                        BinaryOp::Sub => 1,
1306                        BinaryOp::Mul => 2,
1307                        BinaryOp::Div => 3,
1308                        BinaryOp::Max => 4,
1309                        BinaryOp::Min => 5,
1310                        BinaryOp::Pow => 6,
1311                    };
1312                    let cmp_sub = |c: CmpOp| match c {
1313                        CmpOp::Eq => 0u32,
1314                        CmpOp::Ne => 1,
1315                        CmpOp::Lt => 2,
1316                        CmpOp::Le => 3,
1317                        CmpOp::Gt => 4,
1318                        CmpOp::Ge => 5,
1319                    };
1320                    let mut chain_enc = [0u32; 128];
1321                    for (k, step) in chain.iter().enumerate() {
1322                        let base = k * 4;
1323                        let (kind, sub, lhs, rhs) = match step {
1324                            ChainStep::Activation(a, src) => {
1325                                (0u32, act_sub(*a), encode_operand(src), 0u32)
1326                            }
1327                            ChainStep::Cast(_, src) => (1u32, 0, encode_operand(src), 0u32),
1328                            ChainStep::Binary(op, l, r) => {
1329                                (2u32, bin_sub(*op), encode_operand(l), encode_operand(r))
1330                            }
1331                            ChainStep::Compare(op, l, r) => {
1332                                (3u32, cmp_sub(*op), encode_operand(l), encode_operand(r))
1333                            }
1334                            ChainStep::Where(c, t, f) =>
1335                            // Pack 3 operands into the 4-u32 step:
1336                            // op_sub=cond, lhs=on_true, rhs=on_false.
1337                            {
1338                                (
1339                                    4u32,
1340                                    encode_operand(c),
1341                                    encode_operand(t),
1342                                    encode_operand(f),
1343                                )
1344                            }
1345                        };
1346                        chain_enc[base] = kind;
1347                        chain_enc[base + 1] = sub;
1348                        chain_enc[base + 2] = lhs;
1349                        chain_enc[base + 3] = rhs;
1350                    }
1351                    let p = ElementwiseRegionParams {
1352                        len: elems,
1353                        num_inputs: *num_inputs,
1354                        num_steps: chain.len() as u32,
1355                        dst_off: (arena.offset(node.id) / 4) as u32,
1356                        input_offs,
1357                        chain: chain_enc,
1358                        scalar_input_mask: *scalar_input_mask,
1359                        _pad0: 0,
1360                        _pad1: 0,
1361                        _pad2: 0,
1362                        input_modulus: *input_modulus,
1363                    };
1364                    schedule.push(Step::ElementwiseRegion { params: p });
1365                    let ek = elementwise_region_kernel(&dev.device);
1366                    // STORAGE (not UNIFORM) — the WGSL params struct
1367                    // contains `array<u32, N>` arrays whose 4-byte
1368                    // stride violates uniform's 16-byte stride rule.
1369                    let u = dev.device.create_buffer(&wgpu::BufferDescriptor {
1370                        label: Some("rlx-wgpu region params"),
1371                        size: std::mem::size_of::<ElementwiseRegionParams>() as u64,
1372                        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
1373                        mapped_at_creation: false,
1374                    });
1375                    let bg = bind_two(&dev.device, ek, &arena.buffer, &u);
1376                    uniforms.push(u);
1377                    bind_groups.push(bg);
1378                }
1379
1380                Op::Reduce {
1381                    op: rop,
1382                    axes,
1383                    keep_dim: _,
1384                } => {
1385                    // v3: only reduce-last-axis is supported. The
1386                    // kernel reads inner contiguously and writes one
1387                    // f32 per output row.
1388                    let in_id = node.inputs[0];
1389                    let in_shape = graph.node(in_id).shape.dims();
1390                    let last = in_shape.len() - 1;
1391                    if axes.as_slice() != [last] {
1392                        panic!(
1393                            "rlx-wgpu Reduce: only last-axis is wired \
1394                             (got axes={axes:?}, rank={})",
1395                            in_shape.len()
1396                        );
1397                    }
1398                    let inner = in_shape[last].unwrap_static() as u32;
1399                    let total: u32 = in_shape.iter().map(|d| d.unwrap_static() as u32).product();
1400                    let outer = total / inner.max(1);
1401                    let p = ReduceParams {
1402                        outer,
1403                        inner,
1404                        in_off: (arena.offset(in_id) / 4) as u32,
1405                        out_off: (arena.offset(node.id) / 4) as u32,
1406                        op: reduce_op_id(*rop),
1407                        _p0: 0,
1408                        _p1: 0,
1409                        _p2: 0,
1410                    };
1411                    schedule.push(Step::Reduce { params: p });
1412                    let rk = reduce_kernel(&dev.device);
1413                    let u = emit_uniform(std::mem::size_of::<ReduceParams>());
1414                    let bg = bind_two(&dev.device, rk, &arena.buffer, &u);
1415                    uniforms.push(u);
1416                    bind_groups.push(bg);
1417                }
1418
1419                Op::Softmax { axis } => {
1420                    let in_id = node.inputs[0];
1421                    let in_shape = graph.node(in_id).shape.dims();
1422                    let last = (in_shape.len() - 1) as i32;
1423                    if *axis != -1 && *axis != last {
1424                        panic!("rlx-wgpu Softmax: only last-axis wired (got axis={axis})");
1425                    }
1426                    let inner = in_shape[in_shape.len() - 1].unwrap_static() as u32;
1427                    let total: u32 = in_shape.iter().map(|d| d.unwrap_static() as u32).product();
1428                    let outer = total / inner.max(1);
1429                    let p = SoftmaxParams {
1430                        outer,
1431                        inner,
1432                        in_off: (arena.offset(in_id) / 4) as u32,
1433                        out_off: (arena.offset(node.id) / 4) as u32,
1434                        _p0: 0,
1435                        _p1: 0,
1436                        _p2: 0,
1437                        _p3: 0,
1438                    };
1439                    schedule.push(Step::Softmax { params: p });
1440                    let sk = softmax_kernel(&dev.device);
1441                    let u = emit_uniform(std::mem::size_of::<SoftmaxParams>());
1442                    let bg = bind_two(&dev.device, sk, &arena.buffer, &u);
1443                    uniforms.push(u);
1444                    bind_groups.push(bg);
1445                }
1446
1447                Op::LayerNorm { axis: _, eps } | Op::RmsNorm { axis: _, eps } => {
1448                    let in_id = node.inputs[0];
1449                    let in_shape = graph.node(in_id).shape.dims();
1450                    let inner = in_shape[in_shape.len() - 1].unwrap_static() as u32;
1451                    let total: u32 = in_shape.iter().map(|d| d.unwrap_static() as u32).product();
1452                    let outer = total / inner.max(1);
1453                    let is_layer_norm = matches!(&node.op, Op::LayerNorm { .. });
1454
1455                    // FRLTee fast path: if this LN is the head of a
1456                    // (multi-consumer Add → LN) pattern, emit one
1457                    // `Step::FusedResidualLnTee` that writes the sum to
1458                    // the eliminated Add's arena slot AND the LN result
1459                    // to this LN's slot. The Add itself is skipped
1460                    // upstream (`skip_adds`).
1461                    if is_layer_norm
1462                        && let Some(&(h_id, delta_id, gamma_id, beta_id, sum_id)) =
1463                            ln_to_tee.get(&node.id)
1464                    {
1465                        let p = FusedResidualLnTeeParams {
1466                            outer,
1467                            inner,
1468                            in_off: (arena.offset(h_id) / 4) as u32,
1469                            residual_off: (arena.offset(delta_id) / 4) as u32,
1470                            bias_off: 0, // FRLTee currently no-bias only
1471                            gamma_off: (arena.offset(gamma_id) / 4) as u32,
1472                            beta_off: (arena.offset(beta_id) / 4) as u32,
1473                            sum_off: (arena.offset(sum_id) / 4) as u32,
1474                            ln_out_off: (arena.offset(node.id) / 4) as u32,
1475                            eps_bits: eps.to_bits(),
1476                            has_bias: 0,
1477                            _p0: 0,
1478                        };
1479                        schedule.push(Step::FusedResidualLnTee { params: p });
1480                        let frtk = fused_residual_ln_tee_kernel(&dev.device);
1481                        let u = emit_uniform(std::mem::size_of::<FusedResidualLnTeeParams>());
1482                        let bg = bind_two(&dev.device, frtk, &arena.buffer, &u);
1483                        uniforms.push(u);
1484                        bind_groups.push(bg);
1485                        continue;
1486                    }
1487
1488                    let gamma_id = node.inputs[1];
1489                    // beta is the third input for LayerNorm; RmsNorm
1490                    // ignores it (kernel branch on `op` skips the read).
1491                    let beta_id = if is_layer_norm && node.inputs.len() >= 3 {
1492                        node.inputs[2]
1493                    } else {
1494                        // Use gamma's offset as a benign placeholder;
1495                        // the RmsNorm kernel branch never reads it.
1496                        gamma_id
1497                    };
1498                    let p = LayerNormParams {
1499                        outer,
1500                        inner,
1501                        in_off: (arena.offset(in_id) / 4) as u32,
1502                        out_off: (arena.offset(node.id) / 4) as u32,
1503                        gamma_off: (arena.offset(gamma_id) / 4) as u32,
1504                        beta_off: (arena.offset(beta_id) / 4) as u32,
1505                        eps_bits: eps.to_bits(),
1506                        op: if is_layer_norm { 0 } else { 1 },
1507                    };
1508                    schedule.push(Step::LayerNorm { params: p });
1509                    let lk = layernorm_kernel(&dev.device);
1510                    let u = emit_uniform(std::mem::size_of::<LayerNormParams>());
1511                    let bg = bind_two(&dev.device, lk, &arena.buffer, &u);
1512                    uniforms.push(u);
1513                    bind_groups.push(bg);
1514                }
1515
                Op::Reshape { .. } | Op::Cast { .. } => {
                    // No-op: memory planner view-aliased this slot, so no
                    // step, uniform, or bind group is emitted for these ops.
                    // NOTE(review): treating Cast as a pure view presumably
                    // relies on every tensor in the arena sharing one element
                    // type — confirm against the memory planner.
                }
1519
1520                Op::Transpose { perm } => {
1521                    let in_id = node.inputs[0];
1522                    let in_shape = graph.node(in_id).shape.dims();
1523                    let out_shape = node.shape.dims();
1524                    let rank = perm.len();
1525                    if rank != in_shape.len() || rank != out_shape.len() {
1526                        panic!("rlx-wgpu Transpose: rank mismatch");
1527                    }
1528                    let in_dims: Vec<u32> =
1529                        in_shape.iter().map(|d| d.unwrap_static() as u32).collect();
1530                    let out_dims: Vec<u32> =
1531                        out_shape.iter().map(|d| d.unwrap_static() as u32).collect();
1532                    // Input cumulative strides (row-major).
1533                    let mut in_strides = vec![1u32; rank];
1534                    for i in (0..rank.saturating_sub(1)).rev() {
1535                        in_strides[i] = in_strides[i + 1] * in_dims[i + 1];
1536                    }
1537                    // For each *output* axis i, the corresponding input
1538                    // axis is perm[i] — its stride is in_strides[perm[i]].
1539                    let strides_for_out: Vec<u32> =
1540                        (0..rank).map(|i| in_strides[perm[i]]).collect();
1541
1542                    // Build meta buffer: dims (rank u32s) + strides (rank u32s).
1543                    let mut meta_data: Vec<u32> = Vec::with_capacity(rank * 2);
1544                    meta_data.extend_from_slice(&out_dims);
1545                    meta_data.extend_from_slice(&strides_for_out);
1546                    let meta_buf = dev.device.create_buffer(&wgpu::BufferDescriptor {
1547                        label: Some("rlx-wgpu transpose meta"),
1548                        size: (meta_data.len() * 4).max(4) as u64,
1549                        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
1550                        mapped_at_creation: false,
1551                    });
1552                    dev.queue
1553                        .write_buffer(&meta_buf, 0, bytemuck::cast_slice(&meta_data));
1554                    let meta_idx = meta_buffers.len();
1555                    meta_buffers.push(meta_buf);
1556
1557                    // PLAN L1: precompute "bucket axis stays at out
1558                    // axis 0" flag from perm. When `perm[0] == 0`,
1559                    // active-extent scaling of `out_total` is safe.
1560                    let bucket_outermost = if perm[0] == 0 { 1u32 } else { 0u32 };
1561                    let p = TransposeParams {
1562                        rank: rank as u32,
1563                        out_total: elems,
1564                        in_off: (arena.offset(in_id) / 4) as u32,
1565                        out_off: (arena.offset(node.id) / 4) as u32,
1566                        bucket_outermost,
1567                        out_dim_0: out_dims[0],
1568                        _p2: 0,
1569                        _p3: 0,
1570                    };
1571                    schedule.push(Step::Transpose {
1572                        params: p,
1573                        meta_idx,
1574                    });
1575                    let tk = transpose_kernel(&dev.device);
1576                    let u = emit_uniform(std::mem::size_of::<TransposeParams>());
1577                    let bg = dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
1578                        label: Some("rlx-wgpu transpose bg"),
1579                        layout: &tk.bgl,
1580                        entries: &[
1581                            wgpu::BindGroupEntry {
1582                                binding: 0,
1583                                resource: arena.buffer.as_entire_binding(),
1584                            },
1585                            wgpu::BindGroupEntry {
1586                                binding: 1,
1587                                resource: u.as_entire_binding(),
1588                            },
1589                            wgpu::BindGroupEntry {
1590                                binding: 2,
1591                                resource: meta_buffers[meta_idx].as_entire_binding(),
1592                            },
1593                        ],
1594                    });
1595                    uniforms.push(u);
1596                    bind_groups.push(bg);
1597                }
1598
1599                Op::Narrow { axis, start, len } => {
1600                    // Part of a split-QKV pattern: the parent FMB has been
1601                    // (or will be) replaced by Step::MatmulQkv that writes
1602                    // directly into this narrow's arena slot. Skip the
1603                    // narrow's own dispatch.
1604                    if qkv_skip_narrows.contains(&node.id) {
1605                        continue;
1606                    }
1607                    let in_id = node.inputs[0];
1608                    let in_shape = graph.node(in_id).shape.dims();
1609                    let outer: u32 = in_shape[..*axis]
1610                        .iter()
1611                        .map(|d| d.unwrap_static() as u32)
1612                        .product::<u32>()
1613                        .max(1);
1614                    let inner: u32 = in_shape[*axis + 1..]
1615                        .iter()
1616                        .map(|d| d.unwrap_static() as u32)
1617                        .product::<u32>()
1618                        .max(1);
1619                    let axis_in = in_shape[*axis].unwrap_static() as u32;
1620                    let p = NarrowConcatParams {
1621                        total: elems,
1622                        outer,
1623                        inner,
1624                        axis_in_size: axis_in,
1625                        axis_out_size: *len as u32,
1626                        start: *start as u32,
1627                        in_off: (arena.offset(in_id) / 4) as u32,
1628                        out_off: (arena.offset(node.id) / 4) as u32,
1629                    };
1630                    schedule.push(Step::Narrow { params: p });
1631                    let nk = narrow_kernel(&dev.device);
1632                    let u = emit_uniform(std::mem::size_of::<NarrowConcatParams>());
1633                    let bg = bind_two(&dev.device, nk, &arena.buffer, &u);
1634                    uniforms.push(u);
1635                    bind_groups.push(bg);
1636                }
1637
1638                Op::Concat { axis } => {
1639                    let out_shape = node.shape.dims();
1640                    let outer: u32 = out_shape[..*axis]
1641                        .iter()
1642                        .map(|d| d.unwrap_static() as u32)
1643                        .product::<u32>()
1644                        .max(1);
1645                    let inner: u32 = out_shape[*axis + 1..]
1646                        .iter()
1647                        .map(|d| d.unwrap_static() as u32)
1648                        .product::<u32>()
1649                        .max(1);
1650                    let axis_out = out_shape[*axis].unwrap_static() as u32;
1651
1652                    let mut start_pos: u32 = 0;
1653                    for &in_id in &node.inputs {
1654                        let in_shape = graph.node(in_id).shape.dims();
1655                        let axis_in = in_shape[*axis].unwrap_static() as u32;
1656                        let in_total: u32 =
1657                            in_shape.iter().map(|d| d.unwrap_static() as u32).product();
1658                        let p = NarrowConcatParams {
1659                            total: in_total,
1660                            outer,
1661                            inner,
1662                            axis_in_size: axis_in,
1663                            axis_out_size: axis_out,
1664                            start: start_pos,
1665                            in_off: (arena.offset(in_id) / 4) as u32,
1666                            out_off: (arena.offset(node.id) / 4) as u32,
1667                        };
1668                        schedule.push(Step::Concat { params: p });
1669                        let cck = concat_kernel(&dev.device);
1670                        let u = emit_uniform(std::mem::size_of::<NarrowConcatParams>());
1671                        let bg = bind_two(&dev.device, cck, &arena.buffer, &u);
1672                        uniforms.push(u);
1673                        bind_groups.push(bg);
1674                        start_pos += axis_in;
1675                    }
1676                }
1677
1678                Op::Attention {
1679                    num_heads,
1680                    head_dim,
1681                    mask_kind,
1682                    score_scale: _,
1683                    attn_logit_softcap: _,
1684                } => {
1685                    // v5: rank-4 [B, H, S, D] inputs only. SlidingWindow
1686                    // synthesizes a Custom mask host-side.
1687                    let q_id = node.inputs[0];
1688                    let k_id = node.inputs[1];
1689                    let v_id = node.inputs[2];
1690                    let q_shape = graph.node(q_id).shape.dims();
1691                    let k_shape = graph.node(k_id).shape.dims();
1692                    // Accept either rank-4 [B, H, S, D] or rank-3 [B*H, S, D]
1693                    // (the latter is what BERT-flavored builders emit). For
1694                    // rank-3 we treat the leading dim as `batch * heads`,
1695                    // setting heads = num_heads from the Op so the kernel's
1696                    // (b, h) indexing folds back to the right offset.
1697                    let h = *num_heads as u32;
1698                    let hd = *head_dim as u32;
1699                    let (batch, heads, seq_q, seq_k) = match q_shape.len() {
1700                        4 => (
1701                            q_shape[0].unwrap_static() as u32,
1702                            q_shape[1].unwrap_static() as u32,
1703                            q_shape[2].unwrap_static() as u32,
1704                            k_shape[2].unwrap_static() as u32,
1705                        ),
1706                        3 => {
1707                            // Two rank-3 layouts coexist:
1708                            //   [B, S, H·D] — transpose-elided layout
1709                            //   [B·H, S, D] — canonical compacted layout
1710                            // Distinguish by last-dim: if it equals H·D
1711                            // (the per-token feature width) it's [B, S, H·D];
1712                            // otherwise it's [B·H, S, D].
1713                            let last = q_shape[2].unwrap_static() as u32;
1714                            if last == h * hd {
1715                                // [B, S, H·D]: leading = B, seq = S
1716                                (
1717                                    q_shape[0].unwrap_static() as u32,
1718                                    h,
1719                                    q_shape[1].unwrap_static() as u32,
1720                                    k_shape[1].unwrap_static() as u32,
1721                                )
1722                            } else {
1723                                // [B·H, S, D]: leading must be divisible by H
1724                                let leading = q_shape[0].unwrap_static() as u32;
1725                                if !leading.is_multiple_of(h) {
1726                                    panic!(
1727                                        "rlx-wgpu Attention: rank-3 leading dim {leading} \
1728                                            not divisible by num_heads {h} (and last dim \
1729                                            {last} ≠ H·D = {})",
1730                                        h * hd
1731                                    );
1732                                }
1733                                (
1734                                    leading / h,
1735                                    h,
1736                                    q_shape[1].unwrap_static() as u32,
1737                                    k_shape[1].unwrap_static() as u32,
1738                                )
1739                            }
1740                        }
1741                        other => panic!(
1742                            "rlx-wgpu Attention: only rank-3 / rank-4 Q,K,V \
1743                                         inputs supported (got rank {other})"
1744                        ),
1745                    };
1746                    let scale = 1.0_f32 / (hd as f32).sqrt();
1747
1748                    let (mask_kind_id, mask_off, mask_buf, window) = match mask_kind {
1749                        MaskKind::None => (0u32, 0u32, None, 0u32),
1750                        MaskKind::Causal => (1u32, 0u32, None, 0u32),
1751                        MaskKind::Custom | MaskKind::Bias => {
1752                            let m_id = node.inputs[3];
1753                            (2u32, (arena.offset(m_id) / 4) as u32, None, 0u32)
1754                        }
1755                        MaskKind::SlidingWindow(w) => (3u32, 0u32, None, *w as u32),
1756                    };
1757
1758                    // Mask address strides. For Custom masks, derive from
1759                    // the mask's IR shape so the kernel can broadcast a
1760                    // [B, S] padding mask without materializing the full
1761                    // [B, H, S_q, S_k] expansion. Other mask kinds use
1762                    // canonical [B, H, S_q, S_k] strides (the kernel's
1763                    // mask_partial computation is harmless when not read).
1764                    struct MStrides {
1765                        b: u32,
1766                        h: u32,
1767                        q: u32,
1768                        k: u32,
1769                    }
1770                    let mask_strides = if mask_kind_id == 2u32 {
1771                        let m_dims = graph.node(node.inputs[3]).shape.dims();
1772                        let dim = |i: usize| m_dims[i].unwrap_static() as u32;
1773                        match m_dims.len() {
1774                            2 => MStrides {
1775                                b: dim(1),
1776                                h: 0,
1777                                q: 0,
1778                                k: 1,
1779                            },
1780                            3 => MStrides {
1781                                b: dim(1) * dim(2),
1782                                h: 0,
1783                                q: dim(2),
1784                                k: 1,
1785                            },
1786                            4 => MStrides {
1787                                b: dim(1) * dim(2) * dim(3),
1788                                h: dim(2) * dim(3),
1789                                q: dim(3),
1790                                k: 1,
1791                            },
1792                            _ => MStrides {
1793                                b: heads * seq_q * seq_k,
1794                                h: seq_q * seq_k,
1795                                q: seq_k,
1796                                k: 1,
1797                            },
1798                        }
1799                    } else {
1800                        MStrides {
1801                            b: heads * seq_q * seq_k,
1802                            h: seq_q * seq_k,
1803                            q: seq_k,
1804                            k: 1,
1805                        }
1806                    };
1807
1808                    // Compute per-axis strides from input shape. Supports
1809                    // both [B, H, S, D] (rank-4) / [B*H, S, D] (rank-3)
1810                    // layouts (the canonical post-`unfuse` form) and the
1811                    // future [B, S, H, D] / [B, S, H·D] layout that
1812                    // skips the unfuse transposes. Detection: if the
1813                    // input shape's rank-3 last-dim equals H·D, treat
1814                    // as [B, S, H·D] = [B, S, H, D]; otherwise canonical.
1815                    let infer_strides =
1816                        |shape: &[rlx_ir::shape::Dim], seq_extent: u32| -> (u32, u32, u32) {
1817                            let last = shape[shape.len() - 1].unwrap_static() as u32;
1818                            if shape.len() == 3 && last == (heads * hd) {
1819                                // [B, S, H·D] viewed as [B, S, H, D]
1820                                let head_dim_total = heads * hd;
1821                                (seq_extent * head_dim_total, hd, head_dim_total)
1822                            } else {
1823                                // Canonical [B, H, S, D] (or rank-3 [B*H, S, D])
1824                                (heads * seq_extent * hd, seq_extent * hd, hd)
1825                            }
1826                        };
1827                    let (q_b, q_h, q_s) = infer_strides(q_shape, seq_q);
1828                    let (k_b, k_h, k_s) = infer_strides(k_shape, seq_k);
1829                    let v_shape = graph.node(v_id).shape.dims();
1830                    let (v_b, v_h, v_s) = infer_strides(v_shape, seq_k);
1831                    let out_shape = node.shape.dims();
1832                    let (o_b, o_h, o_s) = infer_strides(out_shape, seq_q);
1833                    let p = AttentionParams {
1834                        batch,
1835                        heads,
1836                        seq_q,
1837                        seq_k,
1838                        head_dim: hd,
1839                        q_off: (arena.offset(q_id) / 4) as u32,
1840                        k_off: (arena.offset(k_id) / 4) as u32,
1841                        v_off: (arena.offset(v_id) / 4) as u32,
1842                        out_off: (arena.offset(node.id) / 4) as u32,
1843                        mask_off,
1844                        mask_kind: mask_kind_id,
1845                        scale_bits: scale.to_bits(),
1846                        window,
1847                        // Mask strides — derive from the mask's IR shape:
1848                        //   [B, S]:           (mb=S,        mh=0,    mq=0,   mk=1)
1849                        //   [B, S_q, S_k]:    (mb=S_q·S_k,  mh=0,    mq=S_k, mk=1)
1850                        //   [B, H, S_q, S_k]: (mb=H·S_q·S_k mh=S_q·S_k mq=S_k mk=1)
1851                        // Stride 0 means the kernel broadcasts across that
1852                        // axis (reads the same element for every value of
1853                        // the index). Lets us skip the Expand pre-pass that
1854                        // unfuse used to emit per attention block.
1855                        seq_q_stride: mask_strides.q,
1856                        seq_k_stride: mask_strides.k,
1857                        mask_batch_stride: mask_strides.b,
1858                        mask_head_stride: mask_strides.h,
1859                        _pad_mask_0: 0,
1860                        _pad_mask_1: 0,
1861                        _pad_mask_2: 0,
1862                        q_batch_stride: q_b,
1863                        q_head_stride: q_h,
1864                        q_seq_stride: q_s,
1865                        _pad_q: 0,
1866                        k_batch_stride: k_b,
1867                        k_head_stride: k_h,
1868                        k_seq_stride: k_s,
1869                        _pad_k: 0,
1870                        v_batch_stride: v_b,
1871                        v_head_stride: v_h,
1872                        v_seq_stride: v_s,
1873                        _pad_v: 0,
1874                        o_batch_stride: o_b,
1875                        o_head_stride: o_h,
1876                        o_seq_stride: o_s,
1877                        _pad_o: 0,
1878                    };
1879                    let _ = num_heads;
1880                    schedule.push(Step::Attention {
1881                        params: p,
1882                        mask_buf,
1883                    });
1884                    let ak = attention_kernel(&dev.device);
1885                    let u = emit_uniform(std::mem::size_of::<AttentionParams>());
1886                    let bg = bind_two(&dev.device, ak, &arena.buffer, &u);
1887                    uniforms.push(u);
1888                    bind_groups.push(bg);
1889                }
1890
1891                Op::AttentionBackward {
1892                    num_heads: _,
1893                    head_dim,
1894                    mask_kind,
1895                    wrt,
1896                } => {
1897                    use rlx_ir::op::AttentionBwdWrt;
1898                    let q_id = node.inputs[0];
1899                    let k_id = node.inputs[1];
1900                    let v_id = node.inputs[2];
1901                    let dy_id = node.inputs[3];
1902                    let q_shape = graph.node(q_id).shape.dims();
1903                    let k_shape = graph.node(k_id).shape.dims();
1904                    let hd = *head_dim as u32;
1905                    let (batch, heads, seq_q, seq_k) = match q_shape.len() {
1906                        4 => (
1907                            q_shape[0].unwrap_static() as u32,
1908                            q_shape[1].unwrap_static() as u32,
1909                            q_shape[2].unwrap_static() as u32,
1910                            k_shape[2].unwrap_static() as u32,
1911                        ),
1912                        3 => {
1913                            let h = q_shape[2].unwrap_static() as u32 / hd;
1914                            (
1915                                q_shape[0].unwrap_static() as u32 / h,
1916                                h,
1917                                q_shape[1].unwrap_static() as u32,
1918                                k_shape[1].unwrap_static() as u32,
1919                            )
1920                        }
1921                        other => panic!(
1922                            "rlx-wgpu AttentionBackward: only rank-3/4 Q,K,V (got rank {other})"
1923                        ),
1924                    };
1925                    let scale = 1.0_f32 / (hd as f32).sqrt();
1926                    let (mask_kind_id, mask_off, mask_buf, window) = match mask_kind {
1927                        MaskKind::None => (0u32, 0u32, None, 0u32),
1928                        MaskKind::Causal => (1u32, 0u32, None, 0u32),
1929                        MaskKind::Custom => {
1930                            (2u32, (arena.offset(node.inputs[4]) / 4) as u32, None, 0u32)
1931                        }
1932                        MaskKind::Bias => {
1933                            (4u32, (arena.offset(node.inputs[4]) / 4) as u32, None, 0u32)
1934                        }
1935                        MaskKind::SlidingWindow(w) => (3u32, 0u32, None, *w as u32),
1936                    };
1937                    struct MStrides {
1938                        b: u32,
1939                        h: u32,
1940                        q: u32,
1941                        k: u32,
1942                    }
1943                    let mask_strides = if mask_kind_id == 2 || mask_kind_id == 4 {
1944                        let m_dims = graph.node(node.inputs[4]).shape.dims();
1945                        let dim = |i: usize| m_dims[i].unwrap_static() as u32;
1946                        match m_dims.len() {
1947                            2 => MStrides {
1948                                b: dim(1),
1949                                h: 0,
1950                                q: 0,
1951                                k: 1,
1952                            },
1953                            3 => MStrides {
1954                                b: dim(1) * dim(2),
1955                                h: 0,
1956                                q: dim(2),
1957                                k: 1,
1958                            },
1959                            4 => MStrides {
1960                                b: dim(1) * dim(2) * dim(3),
1961                                h: dim(2) * dim(3),
1962                                q: dim(3),
1963                                k: 1,
1964                            },
1965                            _ => MStrides {
1966                                b: heads * seq_q * seq_k,
1967                                h: seq_q * seq_k,
1968                                q: seq_k,
1969                                k: 1,
1970                            },
1971                        }
1972                    } else {
1973                        MStrides {
1974                            b: heads * seq_q * seq_k,
1975                            h: seq_q * seq_k,
1976                            q: seq_k,
1977                            k: 1,
1978                        }
1979                    };
1980                    let infer_strides =
1981                        |shape: &[rlx_ir::shape::Dim], seq_extent: u32| -> (u32, u32, u32) {
1982                            let last = shape[shape.len() - 1].unwrap_static() as u32;
1983                            if shape.len() == 3 && last == (heads * hd) {
1984                                let head_dim_total = heads * hd;
1985                                (seq_extent * head_dim_total, hd, head_dim_total)
1986                            } else {
1987                                (heads * seq_extent * hd, seq_extent * hd, hd)
1988                            }
1989                        };
1990                    let (q_b, q_h, q_s) = infer_strides(q_shape, seq_q);
1991                    let (k_b, k_h, k_s) = infer_strides(k_shape, seq_k);
1992                    let v_shape = graph.node(v_id).shape.dims();
1993                    let (v_b, v_h, v_s) = infer_strides(v_shape, seq_k);
1994                    let out_shape = node.shape.dims();
1995                    let out_seq = match wrt {
1996                        AttentionBwdWrt::Query => seq_q,
1997                        AttentionBwdWrt::Key | AttentionBwdWrt::Value => seq_k,
1998                    };
1999                    let (o_b, o_h, o_s) = infer_strides(out_shape, out_seq);
2000                    let wrt_id = match wrt {
2001                        AttentionBwdWrt::Query => 0u32,
2002                        AttentionBwdWrt::Key => 1u32,
2003                        AttentionBwdWrt::Value => 2u32,
2004                    };
2005                    let p = AttentionBwdParams {
2006                        batch,
2007                        heads,
2008                        seq_q,
2009                        seq_k,
2010                        head_dim: hd,
2011                        q_off: (arena.offset(q_id) / 4) as u32,
2012                        k_off: (arena.offset(k_id) / 4) as u32,
2013                        v_off: (arena.offset(v_id) / 4) as u32,
2014                        dy_off: (arena.offset(dy_id) / 4) as u32,
2015                        out_off: (arena.offset(node.id) / 4) as u32,
2016                        mask_off,
2017                        mask_kind: mask_kind_id,
2018                        scale_bits: scale.to_bits(),
2019                        window,
2020                        wrt: wrt_id,
2021                        seq_q_stride: mask_strides.q,
2022                        seq_k_stride: mask_strides.k,
2023                        mask_batch_stride: mask_strides.b,
2024                        mask_head_stride: mask_strides.h,
2025                        _pad_mask_0: 0,
2026                        _pad_mask_1: 0,
2027                        _pad_mask_2: 0,
2028                        q_batch_stride: q_b,
2029                        q_head_stride: q_h,
2030                        q_seq_stride: q_s,
2031                        _pad_q: 0,
2032                        k_batch_stride: k_b,
2033                        k_head_stride: k_h,
2034                        k_seq_stride: k_s,
2035                        _pad_k: 0,
2036                        v_batch_stride: v_b,
2037                        v_head_stride: v_h,
2038                        v_seq_stride: v_s,
2039                        _pad_v: 0,
2040                        o_batch_stride: o_b,
2041                        o_head_stride: o_h,
2042                        o_seq_stride: o_s,
2043                        _pad_o: 0,
2044                    };
2045                    schedule.push(Step::AttentionBackward {
2046                        params: p,
2047                        mask_buf,
2048                    });
2049                    let ak = attention_bwd_kernel(&dev.device);
2050                    let u = emit_uniform(std::mem::size_of::<AttentionBwdParams>());
2051                    let bg = bind_two(&dev.device, ak, &arena.buffer, &u);
2052                    uniforms.push(u);
2053                    bind_groups.push(bg);
2054                }
2055
2056                Op::Rope { head_dim, n_rot: _ } => {
                        // Lower a Rope node to a single `Step::Rope` dispatch plus its
                        // uniform buffer and bind group. Inputs are read positionally
                        // as [x, cos, sin]; `n_rot` is ignored by this lowering.
2057                    let x_id = node.inputs[0];
2058                    let cos_id = node.inputs[1];
2059                    let sin_id = node.inputs[2];
2060                    let x_shape = graph.node(x_id).shape.dims();
                        // Innermost extent of x; falls back to 0 when x has no dims.
2061                    let last = x_shape.last().map(|d| d.unwrap_static()).unwrap_or(0);
                        // Compile-time validation: the innermost dim must hold a whole
                        // number of heads, and each head must split into two halves.
2062                    if !last.is_multiple_of(*head_dim) {
2063                        panic!(
2064                            "rlx-wgpu Rope: last_dim ({last}) must be a multiple \
2065                                of head_dim ({head_dim})"
2066                        );
2067                    }
2068                    if head_dim % 2 != 0 {
2069                        panic!("rlx-wgpu Rope: head_dim must be even");
2070                    }
                        // `total` is the element count of x; the second-to-last dim is
                        // taken as the sequence extent.
                        // NOTE(review): `x_shape.len() - 2` underflows for inputs of
                        // rank < 2 — confirm upstream guarantees at least two dims.
2071                    let total: u32 = x_shape.iter().map(|d| d.unwrap_static() as u32).product();
2072                    let seq = x_shape[x_shape.len() - 2].unwrap_static() as u32;
2073                    // PLAN L1: derive batch from total / seq / last_dim
2074                    // (= product of leading dims). `seq_stride` stays at
2075                    // full seq for buffer offset math; `seq` becomes the
2076                    // runtime-scaled loop bound.
                        // `.max(1)` keeps the division defined when seq * last is 0.
2077                    let batch = total / (seq * last as u32).max(1);
                        // All `*_off` fields are arena offsets divided by 4 (presumably
                        // byte offset -> f32 element index; TODO confirm against Arena).
2078                    let p = RopeParams {
2079                        n_total: total,
2080                        seq,
2081                        head_dim: *head_dim as u32,
2082                        half: (*head_dim / 2) as u32,
2083                        in_off: (arena.offset(x_id) / 4) as u32,
2084                        cos_off: (arena.offset(cos_id) / 4) as u32,
2085                        sin_off: (arena.offset(sin_id) / 4) as u32,
2086                        out_off: (arena.offset(node.id) / 4) as u32,
2087                        last_dim: last as u32,
2088                        batch,
2089                        seq_stride: seq,
2090                        _p2: 0,
2091                    };
                        // Record the step, then create a uniform sized for `RopeParams`
                        // and a bind group over the arena buffer and that uniform. One
                        // entry is pushed to each of `schedule`, `uniforms`, `bind_groups`.
2092                    schedule.push(Step::Rope { params: p });
2093                    let rk = rope_kernel(&dev.device);
2094                    let u = emit_uniform(std::mem::size_of::<RopeParams>());
2095                    let bg = bind_two(&dev.device, rk, &arena.buffer, &u);
2096                    uniforms.push(u);
2097                    bind_groups.push(bg);
2098                }
2099
2100                Op::Expand { target_shape } => {
                        // Lower an Expand (broadcast) node. The input must already have
                        // the same rank as `target_shape`. Per-axis output dims and input
                        // strides are uploaded in a dedicated storage buffer that is
                        // bound next to the arena and the params uniform.
2101                    let in_id = node.inputs[0];
2102                    let in_shape = graph.node(in_id).shape.dims();
2103                    let rank = target_shape.len();
2104                    if rank != in_shape.len() {
2105                        panic!(
2106                            "rlx-wgpu Expand: rank mismatch \
2107                                (in_rank={}, target_rank={})",
2108                            in_shape.len(),
2109                            rank
2110                        );
2111                    }
2112                    let out_dims: Vec<u32> = target_shape.iter().map(|&d| d as u32).collect();
2113                    let in_dims: Vec<u32> =
2114                        in_shape.iter().map(|d| d.unwrap_static() as u32).collect();
2115                    // Cumulative input strides (row-major). When the
2116                    // input dim is 1 but target dim > 1, that axis
2117                    // broadcasts → stride = 0.
2118                    let mut in_strides_row = vec![1u32; rank];
                        // Fill from the innermost axis outward; the last axis keeps
                        // stride 1 from the initialiser above.
2119                    for i in (0..rank.saturating_sub(1)).rev() {
2120                        in_strides_row[i] = in_strides_row[i + 1] * in_dims[i + 1];
2121                    }
2122                    let strides_for_out: Vec<u32> = (0..rank)
2123                        .map(|i| {
2124                            if in_dims[i] == 1 && out_dims[i] != 1 {
2125                                0
2126                            } else {
2127                                in_strides_row[i]
2128                            }
2129                        })
2130                        .collect();
2131
                        // Metadata layout: `rank` output dims followed by `rank`
                        // effective input strides, all u32.
2132                    let mut meta_data: Vec<u32> = Vec::with_capacity(rank * 2);
2133                    meta_data.extend_from_slice(&out_dims);
2134                    meta_data.extend_from_slice(&strides_for_out);
                        // `.max(4)` keeps the buffer at least 4 bytes long even when
                        // `meta_data` is empty.
2135                    let meta_buf = dev.device.create_buffer(&wgpu::BufferDescriptor {
2136                        label: Some("rlx-wgpu expand meta"),
2137                        size: (meta_data.len() * 4).max(4) as u64,
2138                        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
2139                        mapped_at_creation: false,
2140                    });
2141                    dev.queue
2142                        .write_buffer(&meta_buf, 0, bytemuck::cast_slice(&meta_data));
                        // Remember the buffer's slot so the step can refer to it by index.
2143                    let meta_idx = meta_buffers.len();
2144                    meta_buffers.push(meta_buf);
2145
2146                    // PLAN L1: bucket axis stays at out axis 0 iff the
2147                    // expand at axis 0 isn't a broadcast (in_dims[0]
2148                    // matches out_dims[0]). When broadcast at axis 0
2149                    // (in_dims[0]==1, out_dims[0]>1), the bucket-axis
2150                    // contract doesn't apply — fall back to full extent.
                        // NOTE(review): `in_dims[0]` / `out_dims[0]` index unconditionally
                        // and panic when rank == 0, even though the meta buffer sizing
                        // above tolerates empty metadata — confirm rank-0 Expand never
                        // reaches this arm.
2151                    let bucket_outermost = if in_dims[0] == out_dims[0] {
2152                        1u32
2153                    } else {
2154                        0u32
2155                    };
                        // `elems` comes from the enclosing scope (presumably this node's
                        // output element count — TODO confirm).
2156                    let p = ExpandParams {
2157                        rank: rank as u32,
2158                        out_total: elems,
2159                        in_off: (arena.offset(in_id) / 4) as u32,
2160                        out_off: (arena.offset(node.id) / 4) as u32,
2161                        bucket_outermost,
2162                        out_dim_0: out_dims[0],
2163                        _p2: 0,
2164                        _p3: 0,
2165                    };
2166                    schedule.push(Step::Expand {
2167                        params: p,
2168                        meta_idx,
2169                    });
2170                    let ek = expand_kernel(&dev.device);
2171                    let u = emit_uniform(std::mem::size_of::<ExpandParams>());
                        // Three bindings instead of the usual two:
                        // 0 = arena buffer, 1 = params uniform, 2 = per-axis metadata.
2172                    let bg = dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
2173                        label: Some("rlx-wgpu expand bg"),
2174                        layout: &ek.bgl,
2175                        entries: &[
2176                            wgpu::BindGroupEntry {
2177                                binding: 0,
2178                                resource: arena.buffer.as_entire_binding(),
2179                            },
2180                            wgpu::BindGroupEntry {
2181                                binding: 1,
2182                                resource: u.as_entire_binding(),
2183                            },
2184                            wgpu::BindGroupEntry {
2185                                binding: 2,
2186                                resource: meta_buffers[meta_idx].as_entire_binding(),
2187                            },
2188                        ],
2189                    });
2190                    uniforms.push(u);
2191                    bind_groups.push(bg);
2192                }
2193
2194                Op::Gather { axis } => {
                        // Lower a Gather node. Inputs are read positionally as
                        // [table, indices]. Axis 0 uses `GatherParams` with the
                        // `gather_kernel`; any other axis uses `GatherAxisParams` with the
                        // `gather_axis_kernel`.
2195                    let table_id = node.inputs[0];
2196                    let idx_id = node.inputs[1];
2197                    if *axis == 0 {
2198                        let table_shape = graph.node(table_id).shape.dims();
2199                        let idx_shape = graph.node(idx_id).shape.dims();
                            // `vocab` = number of rows along axis 0; `dim` = elements per
                            // row (product of the remaining table dims, at least 1).
2200                        let vocab = table_shape[0].unwrap_static() as u32;
2201                        let dim: u32 = table_shape[1..]
2202                            .iter()
2203                            .map(|d| d.unwrap_static() as u32)
2204                            .product::<u32>()
2205                            .max(1);
                            // Total number of indices across all index dims.
2206                        let n_idx: u32 =
2207                            idx_shape.iter().map(|d| d.unwrap_static() as u32).product();
                            // `elems` comes from the enclosing scope (presumably this
                            // node's output element count — TODO confirm).
2208                        let p = GatherParams {
2209                            n_out: elems,
2210                            n_idx,
2211                            dim,
2212                            vocab,
2213                            in_off: (arena.offset(table_id) / 4) as u32,
2214                            idx_off: (arena.offset(idx_id) / 4) as u32,
2215                            out_off: (arena.offset(node.id) / 4) as u32,
2216                            _p0: 0,
2217                        };
2218                        schedule.push(Step::Gather { params: p });
2219                        let gk = gather_kernel(&dev.device);
2220                        let u = emit_uniform(std::mem::size_of::<GatherParams>());
2221                        let bg = bind_two(&dev.device, gk, &arena.buffer, &u);
2222                        uniforms.push(u);
2223                        bind_groups.push(bg);
2224                    } else {
2225                        let table_shape = graph.node(table_id).shape.dims();
2226                        let idx_shape = graph.node(idx_id).shape.dims();
                            // Split the table shape around `axis`:
                            // `outer` = product of dims before it, `trailing` = product of
                            // dims after it (each at least 1), `axis_dim` = its extent.
2227                        let outer: u32 = table_shape[..*axis]
2228                            .iter()
2229                            .map(|d| d.unwrap_static() as u32)
2230                            .product::<u32>()
2231                            .max(1);
2232                        let trailing: u32 = table_shape[*axis + 1..]
2233                            .iter()
2234                            .map(|d| d.unwrap_static() as u32)
2235                            .product::<u32>()
2236                            .max(1);
2237                        let axis_dim = table_shape[*axis].unwrap_static() as u32;
2238                        let num_idx: u32 =
2239                            idx_shape.iter().map(|d| d.unwrap_static() as u32).product();
                            // Output element count is computed locally here rather than
                            // taken from `elems`.
2240                        let total = outer * num_idx * trailing;
2241                        let p = GatherAxisParams {
2242                            total,
2243                            outer,
2244                            axis_dim,
2245                            num_idx,
2246                            trailing,
2247                            table_off: (arena.offset(table_id) / 4) as u32,
2248                            idx_off: (arena.offset(idx_id) / 4) as u32,
2249                            out_off: (arena.offset(node.id) / 4) as u32,
2250                        };
2251                        schedule.push(Step::GatherAxis { params: p });
2252                        let gk = gather_axis_kernel(&dev.device);
2253                        let u = emit_uniform(std::mem::size_of::<GatherAxisParams>());
2254                        let bg = bind_two(&dev.device, gk, &arena.buffer, &u);
2255                        uniforms.push(u);
2256                        bind_groups.push(bg);
2257                    }
2258                }
2259
2260                Op::FusedMatMulBiasAct { activation } => {
                        // Lower a fused matmul + bias (+ optional activation) node. It
                        // emits either a split-QKV matmul step that writes three separate
                        // outputs, or a regular `Step::Matmul` with the bias enabled.
2261                    // Inputs: [x, w, bias]. We require 2D × 2D or
2262                    // [..,M,K] × [K,N] (broadcast bias). Bias is shape [N].
2263                    let a_id = node.inputs[0];
2264                    let b_id = node.inputs[1];
2265                    let bias_id = node.inputs[2];
2266                    let a_shape = graph.node(a_id).shape.dims();
2267                    let b_shape = graph.node(b_id).shape.dims();
2268                    let out_shape = node.shape.dims();
                        // Derive the GEMM extents. In the batched-A case all leading dims
                        // of `a` are folded into M, so the dispatch stays a single 2D
                        // matmul against the 2D weight.
2269                    let (m, k, n) =
2270                        if a_shape.len() == 2 && b_shape.len() == 2 && out_shape.len() == 2 {
2271                            (
2272                                a_shape[0].unwrap_static() as u32,
2273                                a_shape[1].unwrap_static() as u32,
2274                                b_shape[1].unwrap_static() as u32,
2275                            )
2276                        } else if a_shape.len() >= 2
2277                            && b_shape.len() == 2
2278                            && out_shape.len() == a_shape.len()
2279                        {
2280                            let leading: usize = a_shape[..a_shape.len() - 2]
2281                                .iter()
2282                                .map(|d| d.unwrap_static())
2283                                .product();
2284                            let m_inner = a_shape[a_shape.len() - 2].unwrap_static();
2285                            let k_inner = a_shape[a_shape.len() - 1].unwrap_static();
2286                            let n_inner = b_shape[1].unwrap_static();
2287                            ((leading * m_inner) as u32, k_inner as u32, n_inner as u32)
2288                        } else {
2289                            panic!(
2290                                "rlx-wgpu FusedMatMulBiasAct: unsupported shapes \
2291                                a={a_shape:?} b={b_shape:?}"
2292                            );
2293                        };
                        // 0xFFFF is the value used when the node has no activation;
                        // otherwise the id comes from `activation_op_id`.
2294                    let act_id = match activation {
2295                        None => 0xFFFFu32,
2296                        Some(a) => activation_op_id(*a),
2297                    };
                        // `traces_to_param` and `derive_matmul_compute` are defined
                        // outside this view; their results are forwarded to the step and
                        // to the bind-group builder below.
2298                    let b_is_param = traces_to_param(&graph, b_id);
2299                    let compute_precision =
2300                        derive_matmul_compute(&dev.device, &graph, a_id, b_id, m, k, n);
2301
2302                    // Split-QKV pattern: matmul writes Q/K/V directly into
2303                    // 3 separate output buffers, eliminating the 3 Narrow
2304                    // dispatches that would otherwise follow. Two flavors:
2305                    //   F32     → matmul_qkv          (portable f32 tile)
2306                    //   CoopF32 → matmul_qkv_coop_f32 (simdgroup f32 GEMM)
2307                    // Coop16 is intentionally not handled here (the kernel
2308                    // would need an f16-acc variant — Naga 29 can't compile
2309                    // mixed-precision coop_mat).
                        // Eligibility also requires that no activation is fused.
2310                    let mqk_eligible = act_id == 0xFFFFu32
2311                        && (compute_precision == MatmulCompute::F32
2312                            || compute_precision == MatmulCompute::CoopF32);
                        // The split path is taken only when this node has an entry in
                        // `qkv_split`, which supplies the three destination node ids.
2313                    if mqk_eligible && let Some(&(q_id, k_id_n, v_id)) = qkv_split.get(&node.id) {
                            // N is divided into three equal column ranges.
                            // NOTE(review): integer division — assumes n is a multiple of
                            // 3; confirm whoever populates `qkv_split` guarantees that.
2314                        let head_width = n / 3;
2315                        let coop = compute_precision == MatmulCompute::CoopF32;
2316                        let mqk_kernel = if coop {
2317                            matmul_qkv_coop_f32_kernel(&dev.device)
2318                                .expect("coop matmul_qkv kernel: hardware feature was checked but kernel missing")
2319                        } else {
2320                            matmul_qkv_kernel(&dev.device)
2321                        };
                            // Outputs go to the arena slots of the q/k/v nodes rather than
                            // to this node's own slot; the bias is always enabled here.
2322                        let p = MatmulQkvParams {
2323                            m,
2324                            k,
2325                            n,
2326                            a_off: (arena.offset(a_id) / 4) as u32,
2327                            b_off: (arena.offset(b_id) / 4) as u32,
2328                            q_off: (arena.offset(q_id) / 4) as u32,
2329                            k_off: (arena.offset(k_id_n) / 4) as u32,
2330                            v_off: (arena.offset(v_id) / 4) as u32,
2331                            head_width,
2332                            has_bias: 1,
2333                            bias_off: (arena.offset(bias_id) / 4) as u32,
2334                            _p0: 0,
2335                            _p1: 0,
2336                            _p2: 0,
2337                            _p3: 0,
2338                            _p4: 0,
2339                        };
2340                        schedule.push(Step::MatmulQkv { params: p, coop });
2341                        let u = emit_uniform(std::mem::size_of::<MatmulQkvParams>());
2342                        let bg = bind_two(&dev.device, mqk_kernel, &arena.buffer, &u);
2343                        uniforms.push(u);
2344                        bind_groups.push(bg);
2345                    } else {
                            // Regular path: a single non-batched matmul (batch = 1, zero
                            // batch strides) with bias and the resolved activation id.
2346                        schedule.push(Step::Matmul {
2347                            m,
2348                            k,
2349                            n,
2350                            batch: 1,
2351                            a_batch_stride: 0,
2352                            b_batch_stride: 0,
2353                            c_batch_stride: 0,
2354                            a_off_f32: (arena.offset(a_id) / 4) as u32,
2355                            b_off_f32: (arena.offset(b_id) / 4) as u32,
2356                            c_off_f32: (arena.offset(node.id) / 4) as u32,
2357                            has_bias: 1,
2358                            bias_off_f32: (arena.offset(bias_id) / 4) as u32,
2359                            act_id,
2360                            b_is_param,
2361                            compute_precision,
2362                        });
2363                        let u = emit_uniform(std::mem::size_of::<MatmulParams>());
                            // The `mm_*` values are defined outside this arm (presumably
                            // the matmul kernel variants — TODO confirm).
2364                        let bg = build_matmul_bind_group(
2365                            &dev.device,
2366                            mm_k,
2367                            mm_w,
2368                            &mm_f16w,
2369                            &mm_f16c,
2370                            &mm_coop,
2371                            &mm_coop_f32,
2372                            &arena,
2373                            &u,
2374                            b_is_param,
2375                            compute_precision,
2376                        );
2377                        uniforms.push(u);
2378                        bind_groups.push(bg);
2379                    }
2380                }
2381
2382                Op::DotGeneral { .. } => {
                        // No lowering exists for DotGeneral in this backend; reaching this
                        // arm aborts compilation with a panic.
2383                    // Should be unreachable: DotGeneral is decomposed into
2384                    // MatMul + Transpose + Reshape by the unfusion pass
2385                    // before memory planning. If we hit this arm, the
2386                    // unfusion pass has a gap.
2387                    panic!(
2388                        "rlx-wgpu DotGeneral: leaked past unfusion pass — \
2389                            check unfuse.rs::expand_dot_general for missing patterns"
2390                    );
2391                }
2392
2393                Op::Sample {
2394                    top_k,
2395                    top_p,
2396                    temperature,
2397                    seed,
2398                } => {
                        // Lower a Sample node over the innermost dim of its single input.
                        // It emits either an argmax step (fast path) or a full sample
                        // step that carries the top-k / top-p / temperature / seed.
2399                    let in_id = node.inputs[0];
2400                    let in_shape = graph.node(in_id).shape.dims();
                        // `inner` = extent of the last dim; `outer` = number of rows,
                        // i.e. total elements divided by `inner` (guarded against 0).
2401                    let inner = in_shape[in_shape.len() - 1].unwrap_static() as u32;
2402                    let total: u32 = in_shape.iter().map(|d| d.unwrap_static() as u32).product();
2403                    let outer = total / inner.max(1);
2404                    // Greedy fast-path: temperature == 1.0 with no top_k/top_p
2405                    // is an argmax — same numeric result, much cheaper kernel.
                        // NOTE(review): temperature 1 with no top-k/top-p filtering
                        // usually means unrestricted sampling, which differs from argmax.
                        // Confirm that top_k == 0 / top_p == 1 / temperature == 1 is this
                        // IR's convention for greedy decoding.
                        // The float comparisons use a 1e-6 tolerance, not exact equality.
2406                    let is_greedy = *top_k == 0
2407                        && (*top_p - 1.0).abs() < 1e-6
2408                        && (*temperature - 1.0).abs() < 1e-6;
2409                    if is_greedy {
2410                        let p = ArgmaxParams {
2411                            outer,
2412                            inner,
2413                            in_off: (arena.offset(in_id) / 4) as u32,
2414                            out_off: (arena.offset(node.id) / 4) as u32,
2415                            _p0: 0,
2416                            _p1: 0,
2417                            _p2: 0,
2418                            _p3: 0,
2419                        };
2420                        schedule.push(Step::Argmax { params: p });
2421                        let amk = argmax_kernel(&dev.device);
2422                        let u = emit_uniform(std::mem::size_of::<ArgmaxParams>());
2423                        let bg = bind_two(&dev.device, amk, &arena.buffer, &u);
2424                        uniforms.push(u);
2425                        bind_groups.push(bg);
2426                    } else {
                            // Floats are passed as raw bit patterns (`to_bits`), and the
                            // seed is split into its low and high 32-bit halves, so every
                            // field of the params struct is a u32.
2427                        let p = SampleParams {
2428                            outer,
2429                            inner,
2430                            in_off: (arena.offset(in_id) / 4) as u32,
2431                            out_off: (arena.offset(node.id) / 4) as u32,
2432                            top_k: *top_k as u32,
2433                            top_p_bits: top_p.to_bits(),
2434                            temp_bits: temperature.to_bits(),
2435                            seed_lo: *seed as u32,
2436                            seed_hi: (*seed >> 32) as u32,
2437                            _p0: 0,
2438                            _p1: 0,
2439                            _p2: 0,
2440                        };
2441                        schedule.push(Step::Sample { params: p });
2442                        let sk = sample_kernel(&dev.device);
2443                        let u = emit_uniform(std::mem::size_of::<SampleParams>());
2444                        let bg = bind_two(&dev.device, sk, &arena.buffer, &u);
2445                        uniforms.push(u);
2446                        bind_groups.push(bg);
2447                    }
2448                }
2449
2450                Op::Pool {
2451                    kind,
2452                    kernel_size,
2453                    stride,
2454                    padding,
2455                } => {
                        // Lower a Pool node to a 1D, 2D or 3D pooling dispatch. The
                        // variant is chosen from the kernel rank together with the input
                        // and output ranks; any other combination panics.
2456                    let in_shape = graph.node(node.inputs[0]).shape.dims();
2457                    let out_shape = node.shape.dims();
                        // Numeric reduction code handed to the pooling kernels.
2458                    let op_id: u32 = match kind {
2459                        ReduceOp::Sum => 0,
2460                        ReduceOp::Mean => 1,
2461                        ReduceOp::Max => 2,
2462                        ReduceOp::Min => 3,
2463                        ReduceOp::Prod => 4,
2464                    };
                        // In every variant dims 0 and 1 of the input are read as `n` and
                        // `c`, and the remaining dims as spatial extents. Missing stride
                        // entries default to 1 and missing padding entries to 0.
2465                    match (kernel_size.len(), in_shape.len(), out_shape.len()) {
                            // 1D: input is rank 3 (n, c, l).
2466                        (1, 3, 3) => {
2467                            let p = Pool1dParams {
2468                                n: in_shape[0].unwrap_static() as u32,
2469                                c: in_shape[1].unwrap_static() as u32,
2470                                l: in_shape[2].unwrap_static() as u32,
2471                                l_out: out_shape[2].unwrap_static() as u32,
2472                                kl: kernel_size[0] as u32,
2473                                sl: stride.first().copied().unwrap_or(1) as u32,
2474                                pl: padding.first().copied().unwrap_or(0) as u32,
2475                                op: op_id,
2476                                in_off: (arena.offset(node.inputs[0]) / 4) as u32,
2477                                out_off: (arena.offset(node.id) / 4) as u32,
2478                                _p0: 0,
2479                                _p1: 0,
2480                                _p2: 0,
2481                                _p3: 0,
2482                                _p4: 0,
2483                                _p5: 0,
2484                            };
2485                            schedule.push(Step::Pool1d { params: p });
2486                            let pk = pool1d_kernel(&dev.device);
2487                            let u = emit_uniform(std::mem::size_of::<Pool1dParams>());
2488                            let bg = bind_two(&dev.device, pk, &arena.buffer, &u);
2489                            uniforms.push(u);
2490                            bind_groups.push(bg);
2491                        }
                            // 2D: input is rank 4 (n, c, h, w).
2492                        (2, 4, 4) => {
2493                            let p = Pool2dParams {
2494                                n: in_shape[0].unwrap_static() as u32,
2495                                c: in_shape[1].unwrap_static() as u32,
2496                                h: in_shape[2].unwrap_static() as u32,
2497                                w: in_shape[3].unwrap_static() as u32,
2498                                h_out: out_shape[2].unwrap_static() as u32,
2499                                w_out: out_shape[3].unwrap_static() as u32,
2500                                kh: kernel_size[0] as u32,
2501                                kw: kernel_size[1] as u32,
2502                                sh: stride.first().copied().unwrap_or(1) as u32,
2503                                sw: stride.get(1).copied().unwrap_or(1) as u32,
2504                                ph: padding.first().copied().unwrap_or(0) as u32,
2505                                pw: padding.get(1).copied().unwrap_or(0) as u32,
2506                                op: op_id,
2507                                in_off: (arena.offset(node.inputs[0]) / 4) as u32,
2508                                out_off: (arena.offset(node.id) / 4) as u32,
2509                                _p0: 0,
2510                                _p1: 0,
2511                                _p2: 0,
2512                            };
2513                            schedule.push(Step::Pool2d { params: p });
2514                            let pk = pool2d_kernel(&dev.device);
2515                            let u = emit_uniform(std::mem::size_of::<Pool2dParams>());
2516                            let bg = bind_two(&dev.device, pk, &arena.buffer, &u);
2517                            uniforms.push(u);
2518                            bind_groups.push(bg);
2519                        }
                            // 3D: input is rank 5 (n, c, d, h, w).
2520                        (3, 5, 5) => {
2521                            let p = Pool3dParams {
2522                                n: in_shape[0].unwrap_static() as u32,
2523                                c: in_shape[1].unwrap_static() as u32,
2524                                d: in_shape[2].unwrap_static() as u32,
2525                                h: in_shape[3].unwrap_static() as u32,
2526                                w: in_shape[4].unwrap_static() as u32,
2527                                d_out: out_shape[2].unwrap_static() as u32,
2528                                h_out: out_shape[3].unwrap_static() as u32,
2529                                w_out: out_shape[4].unwrap_static() as u32,
2530                                kd: kernel_size[0] as u32,
2531                                kh: kernel_size[1] as u32,
2532                                kw: kernel_size[2] as u32,
2533                                sd: stride.first().copied().unwrap_or(1) as u32,
2534                                sh: stride.get(1).copied().unwrap_or(1) as u32,
2535                                sw: stride.get(2).copied().unwrap_or(1) as u32,
2536                                pd: padding.first().copied().unwrap_or(0) as u32,
2537                                ph: padding.get(1).copied().unwrap_or(0) as u32,
2538                                pw: padding.get(2).copied().unwrap_or(0) as u32,
2539                                op: op_id,
2540                                in_off: (arena.offset(node.inputs[0]) / 4) as u32,
2541                                out_off: (arena.offset(node.id) / 4) as u32,
2542                                _p0: 0,
2543                                _p1: 0,
2544                            };
2545                            schedule.push(Step::Pool3d { params: p });
2546                            let pk = pool3d_kernel(&dev.device);
2547                            let u = emit_uniform(std::mem::size_of::<Pool3dParams>());
2548                            let bg = bind_two(&dev.device, pk, &arena.buffer, &u);
2549                            uniforms.push(u);
2550                            bind_groups.push(bg);
2551                        }
                            // Anything else is rejected at compile time.
2552                        (k, n, m) => panic!(
2553                            "rlx-wgpu Pool: kernel-rank {k} with input rank {n} / \
2554                             output rank {m} not supported (use 1D/2D/3D NCHW)"
2555                        ),
2556                    }
2557                }
2558
2559                Op::Conv {
2560                    kernel_size,
2561                    stride,
2562                    padding,
2563                    dilation,
2564                    groups,
2565                } => {
                        // Lower a convolution node. The tuple (kernel rank, input rank,
                        // weight rank, output rank) selects the 1D / 2D / 3D parameter
                        // struct; every other combination panics in the final arm.
2566                    let in_shape = graph.node(node.inputs[0]).shape.dims();
2567                    let w_shape = graph.node(node.inputs[1]).shape.dims();
2568                    let out_shape = node.shape.dims();
                        // Per-axis attribute lookups. A missing entry falls back to
                        // stride 1, padding 0 and dilation 1 respectively.
2569                    let s = |i: usize| stride.get(i).copied().unwrap_or(1) as u32;
2570                    let p = |i: usize| padding.get(i).copied().unwrap_or(0) as u32;
2571                    let d = |i: usize| dilation.get(i).copied().unwrap_or(1) as u32;
                        // The weight shape is only consulted for its rank here.
2572                    match (
2573                        kernel_size.len(),
2574                        in_shape.len(),
2575                        w_shape.len(),
2576                        out_shape.len(),
2577                    ) {
2578                        (1, 3, 3, 3) => {
                                // Batch, input channels and length come from the input
                                // shape; output channels and output length come from
                                // this node's own shape.
                                // Arena offsets are divided by 4 before being stored
                                // (presumably bytes -> 4-byte element indices; confirm
                                // against `Arena::offset`).
2579                            let p1 = Conv1dParams {
2580                                n: in_shape[0].unwrap_static() as u32,
2581                                c_in: in_shape[1].unwrap_static() as u32,
2582                                c_out: out_shape[1].unwrap_static() as u32,
2583                                l: in_shape[2].unwrap_static() as u32,
2584                                l_out: out_shape[2].unwrap_static() as u32,
2585                                kl: kernel_size[0] as u32,
2586                                sl: s(0),
2587                                pl: p(0),
2588                                dl: d(0),
2589                                groups: *groups as u32,
2590                                in_off: (arena.offset(node.inputs[0]) / 4) as u32,
2591                                w_off: (arena.offset(node.inputs[1]) / 4) as u32,
2592                                out_off: (arena.offset(node.id) / 4) as u32,
2593                                _p0: 0,
2594                                _p1: 0,
2595                                _p2: 0,
2596                            };
                                // Record the step, then create the uniform buffer sized
                                // for the params struct and the bind group pairing it
                                // with the arena buffer.
2597                            schedule.push(Step::Conv1d { params: p1 });
2598                            let ck = conv1d_kernel(&dev.device);
2599                            let u = emit_uniform(std::mem::size_of::<Conv1dParams>());
2600                            let bg = bind_two(&dev.device, ck, &arena.buffer, &u);
2601                            uniforms.push(u);
2602                            bind_groups.push(bg);
2603                        }
2604                        (2, 4, 4, 4) => {
                                // Same layout as the 1D case with two spatial axes:
                                // attribute index 0 feeds the `h` fields, index 1 the
                                // `w` fields.
2605                            let p2 = Conv2dParams {
2606                                n: in_shape[0].unwrap_static() as u32,
2607                                c_in: in_shape[1].unwrap_static() as u32,
2608                                c_out: out_shape[1].unwrap_static() as u32,
2609                                h: in_shape[2].unwrap_static() as u32,
2610                                w: in_shape[3].unwrap_static() as u32,
2611                                h_out: out_shape[2].unwrap_static() as u32,
2612                                w_out: out_shape[3].unwrap_static() as u32,
2613                                kh: kernel_size[0] as u32,
2614                                kw: kernel_size[1] as u32,
2615                                sh: s(0),
2616                                sw: s(1),
2617                                ph: p(0),
2618                                pw: p(1),
2619                                dh: d(0),
2620                                dw: d(1),
2621                                groups: *groups as u32,
2622                                in_off: (arena.offset(node.inputs[0]) / 4) as u32,
2623                                w_off: (arena.offset(node.inputs[1]) / 4) as u32,
2624                                out_off: (arena.offset(node.id) / 4) as u32,
2625                            };
2626                            schedule.push(Step::Conv2d { params: p2 });
2627                            let ck = conv2d_kernel(&dev.device);
2628                            let u = emit_uniform(std::mem::size_of::<Conv2dParams>());
2629                            let bg = bind_two(&dev.device, ck, &arena.buffer, &u);
2630                            uniforms.push(u);
2631                            bind_groups.push(bg);
2632                        }
2633                        (3, 5, 5, 5) => {
                                // Three spatial axes: attribute indices 0 / 1 / 2 feed
                                // the `d` / `h` / `w` fields.
2634                            let p3 = Conv3dParams {
2635                                n: in_shape[0].unwrap_static() as u32,
2636                                c_in: in_shape[1].unwrap_static() as u32,
2637                                c_out: out_shape[1].unwrap_static() as u32,
2638                                d: in_shape[2].unwrap_static() as u32,
2639                                h: in_shape[3].unwrap_static() as u32,
2640                                w: in_shape[4].unwrap_static() as u32,
2641                                d_out: out_shape[2].unwrap_static() as u32,
2642                                h_out: out_shape[3].unwrap_static() as u32,
2643                                w_out: out_shape[4].unwrap_static() as u32,
2644                                kd: kernel_size[0] as u32,
2645                                kh: kernel_size[1] as u32,
2646                                kw: kernel_size[2] as u32,
2647                                sd: s(0),
2648                                sh: s(1),
2649                                sw: s(2),
2650                                pd: p(0),
2651                                ph: p(1),
2652                                pw: p(2),
2653                                dd: d(0),
2654                                dh: d(1),
2655                                dw: d(2),
2656                                groups: *groups as u32,
2657                                in_off: (arena.offset(node.inputs[0]) / 4) as u32,
2658                                w_off: (arena.offset(node.inputs[1]) / 4) as u32,
2659                                out_off: (arena.offset(node.id) / 4) as u32,
2660                                _p0: 0,
2661                            };
2662                            schedule.push(Step::Conv3d { params: p3 });
2663                            let ck = conv3d_kernel(&dev.device);
2664                            let u = emit_uniform(std::mem::size_of::<Conv3dParams>());
2665                            let bg = bind_two(&dev.device, ck, &arena.buffer, &u);
2666                            uniforms.push(u);
2667                            bind_groups.push(bg);
2668                        }
                            // Any rank combination not matched above is rejected at
                            // compile time with the observed ranks in the message.
2669                        (k, ni, wi, mi) => panic!(
2670                            "rlx-wgpu Conv: rank kernel={k} in={ni} weight={wi} out={mi} \
2671                             not supported (use 1D/2D/3D NCHW)"
2672                        ),
2673                    }
2674                }
2675
2676                Op::Cumsum { axis, exclusive } => {
                        // Lower a cumulative sum. Only the last axis is wired: `axis`
                        // must be -1 or rank-1, anything else panics.
2677                    let in_id = node.inputs[0];
2678                    let in_shape = graph.node(in_id).shape.dims();
                        // NOTE(review): `len() - 1` assumes the input has at least one
                        // dimension; a rank-0 input is not handled here.
2679                    let last = (in_shape.len() - 1) as i32;
2680                    if *axis != -1 && *axis != last {
2681                        panic!("rlx-wgpu Cumsum: only last-axis wired (got axis={axis})");
2682                    }
                        // `inner` is the extent of the last axis; `outer` is the total
                        // element count divided by it (divisor clamped to >= 1).
2683                    let inner = in_shape[in_shape.len() - 1].unwrap_static() as u32;
2684                    let total: u32 = in_shape.iter().map(|d| d.unwrap_static() as u32).product();
2685                    let outer = total / inner.max(1);
2686                    let p = CumsumParams {
2687                        outer,
2688                        inner,
2689                        in_off: (arena.offset(in_id) / 4) as u32,
2690                        out_off: (arena.offset(node.id) / 4) as u32,
                            // The boolean attribute is passed to the kernel as 1 / 0.
2691                        exclusive: if *exclusive { 1 } else { 0 },
2692                        _p0: 0,
2693                        _p1: 0,
2694                        _p2: 0,
2695                    };
                        // Record the step together with its uniform buffer and bind group.
2696                    schedule.push(Step::Cumsum { params: p });
2697                    let ck2 = cumsum_kernel(&dev.device);
2698                    let u = emit_uniform(std::mem::size_of::<CumsumParams>());
2699                    let bg = bind_two(&dev.device, ck2, &arena.buffer, &u);
2700                    uniforms.push(u);
2701                    bind_groups.push(bg);
2702                }
2703                Op::Fft { inverse, norm } => {
                        // Lower an FFT node to either a GPU step or a host step,
                        // depending on the eligibility check below.
2704                    let in_id = node.inputs[0];
2705                    let in_shape = graph.node(in_id).shape.clone();
2706                    let meta = rlx_ir::fft::fft_meta(&in_shape);
2707                    let dtype = in_shape.dtype();
                        // The GPU path requires both the rlx-ir eligibility predicate
                        // and at least two complex points.
2708                    let use_gpu = rlx_ir::fft::gpu_fft_native_eligible(dtype, meta.n_complex)
2709                        && meta.n_complex >= 2;
                        // The normalisation scale is computed up front; only the GPU
                        // step stores it (the host step stores `norm.tag()` instead).
2710                    let scale = norm.output_scale(meta.n_complex, *inverse) as f32;
2711                    if use_gpu {
                            // GPU step: offsets are stored divided by 4, `inverse` as 1 / 0.
2712                        schedule.push(Step::FftGpu {
2713                            src_off: (arena.offset(in_id) / 4) as u32,
2714                            dst_off: (arena.offset(node.id) / 4) as u32,
2715                            outer: meta.outer as u32,
2716                            n: meta.n_complex as u32,
2717                            inverse: if *inverse { 1 } else { 0 },
2718                            norm_scale: scale,
2719                        });
                            // GPU resources for this step are kept in a separate list,
                            // not in `uniforms` / `bind_groups`.
2720                        fft_gpu_steps.push(crate::fft_dispatch::FftGpuResources::new(
2721                            &dev.device,
2722                            &arena.buffer,
2723                        ));
2724                    } else {
                            // Host step: raw arena offsets (named as byte offsets) plus
                            // the norm and dtype tags.
2725                        schedule.push(Step::FftHost {
2726                            src_byte_off: arena.offset(in_id) as u32,
2727                            dst_byte_off: arena.offset(node.id) as u32,
2728                            outer: meta.outer as u32,
2729                            n_complex: meta.n_complex as u32,
2730                            inverse: *inverse,
2731                            norm_tag: norm.tag(),
2732                            dtype_tag: fft_dtype_tag(dtype),
2733                        });
2734                    }
                        // NOTE(review): neither branch pushes into `uniforms` or
                        // `bind_groups`, unlike most arms in this match; confirm the
                        // executor does not index those vectors by schedule position
                        // for FFT steps.
2735                }
2736                Op::SelectiveScan { state_size } => {
                        // Lower a selective-scan node. State sizes above 256 are
                        // rejected at compile time (see the panic message for why).
2737                    if *state_size > 256 {
2738                        panic!(
2739                            "rlx-wgpu SelectiveScan: state_size {} exceeds compile-time \
2740                                cap of 256 (kernel uses fixed-size private array)",
2741                            state_size
2742                        );
2743                    }
                        // Five inputs are read positionally: x, delta, a, b, c.
2744                    let x_id = node.inputs[0];
2745                    let dt_id = node.inputs[1];
2746                    let a_id = node.inputs[2];
2747                    let b_id = node.inputs[3];
2748                    let c_id = node.inputs[4];
                        // Dims 0 / 1 / 2 of `x` are used as batch / seq / hidden.
2749                    let in_dims = graph.node(x_id).shape.dims();
2750                    let seq = in_dims[1].unwrap_static() as u32;
2751                    let p = SelectiveScanParams {
2752                        batch: in_dims[0].unwrap_static() as u32,
2753                        seq,
2754                        hidden: in_dims[2].unwrap_static() as u32,
2755                        state_size: *state_size as u32,
2756                        x_off: (arena.offset(x_id) / 4) as u32,
2757                        delta_off: (arena.offset(dt_id) / 4) as u32,
2758                        a_off: (arena.offset(a_id) / 4) as u32,
2759                        b_off: (arena.offset(b_id) / 4) as u32,
2760                        c_off: (arena.offset(c_id) / 4) as u32,
2761                        out_off: (arena.offset(node.id) / 4) as u32,
2762                        // PLAN L1: full-extent stride; safe under
2763                        // active-extent scaling of params.seq.
2764                        seq_stride: seq,
2765                        _p1: 0,
2766                        _p2: 0,
2767                        _p3: 0,
2768                        _p4: 0,
2769                        _p5: 0,
2770                    };
                        // Record the step together with its uniform buffer and bind group.
2771                    schedule.push(Step::SelectiveScan { params: p });
2772                    let ssk = selective_scan_kernel(&dev.device);
2773                    let u = emit_uniform(std::mem::size_of::<SelectiveScanParams>());
2774                    let bg = bind_two(&dev.device, ssk, &arena.buffer, &u);
2775                    uniforms.push(u);
2776                    bind_groups.push(bg);
2777                }
2778                Op::GatedDeltaNet {
2779                    state_size,
2780                    carry_state,
2781                } => {
                        // Lower a GatedDeltaNet node. The state size is capped by the
                        // limit exported from `rlx_cpu::gdn`; larger values panic.
2782                    if *state_size > rlx_cpu::gdn::GDN_MAX_STATE {
2783                        panic!(
2784                            "rlx-wgpu GatedDeltaNet: state_size {state_size} > {}",
2785                            rlx_cpu::gdn::GDN_MAX_STATE
2786                        );
2787                    }
2788                    let q_id = node.inputs[0];
2789                    let q_shape = &graph.node(q_id).shape;
                        // Input 5 is only read when `carry_state` is set; otherwise
                        // the state offset is recorded as 0.
2790                    let state_off = if *carry_state {
2791                        arena.offset(node.inputs[5])
2792                    } else {
2793                        0
2794                    };
                        // The step stores raw arena offsets (not divided by 4) for
                        // inputs 0..=4 and the output; batch / seq / heads come from
                        // dims 0 / 1 / 2 of the first input.
2795                    schedule.push(Step::GatedDeltaNet {
2796                        q_byte_off: arena.offset(q_id) as u32,
2797                        k_byte_off: arena.offset(node.inputs[1]) as u32,
2798                        v_byte_off: arena.offset(node.inputs[2]) as u32,
2799                        g_byte_off: arena.offset(node.inputs[3]) as u32,
2800                        beta_byte_off: arena.offset(node.inputs[4]) as u32,
2801                        state_byte_off: state_off as u32,
2802                        dst_byte_off: arena.offset(node.id) as u32,
2803                        batch: q_shape.dim(0).unwrap_static() as u32,
2804                        seq: q_shape.dim(1).unwrap_static() as u32,
2805                        heads: q_shape.dim(2).unwrap_static() as u32,
2806                        state_size: *state_size as u32,
2807                        use_carry: *carry_state,
2808                    });
                        // Lazily create the shared `gguf_host_pad` entry (a 256-size
                        // uniform bound through the binary kernel) and push clones of
                        // it for this step.
                        // NOTE(review): presumably a placeholder that keeps
                        // `uniforms` / `bind_groups` aligned with `schedule` for a
                        // step that does not dispatch its own kernel; confirm.
2809                    if gguf_host_pad.is_none() {
2810                        let bk = binary_kernel(&dev.device);
2811                        let u = emit_uniform(256);
2812                        gguf_host_pad =
2813                            Some((u.clone(), bind_two(&dev.device, bk, &arena.buffer, &u)));
2814                    }
2815                    let (u, bg) = gguf_host_pad.as_ref().unwrap();
2816                    uniforms.push(u.clone());
2817                    bind_groups.push(bg.clone());
2818                }
2819                Op::Custom { name, attrs, .. } => match name.as_str() {
                        // Custom ops are dispatched by name; unknown names panic in
                        // the final arm.
2820                    "llada2.group_limited_gate" => {
2821                        let sig_id = node.inputs[0];
2822                        let route_id = node.inputs[1];
                            // Element count of the first input; the `unwrap` panics if
                            // `num_elements()` returns `None`.
2823                        let n_elems = graph.node(sig_id).shape.num_elements().unwrap() as u32;
                            // Copy at most 20 attribute bytes into a zero-filled buffer:
                            // shorter attrs leave a zero tail, longer attrs are truncated.
2824                        let mut attr_buf = [0u8; 20];
2825                        let n = attrs.len().min(20);
2826                        attr_buf[..n].copy_from_slice(&attrs[..n]);
                            // Only a schedule entry is pushed here; no uniform or bind
                            // group is created in this arm.
2827                        schedule.push(Step::Llada2GroupLimitedGate {
2828                            sig_byte_off: arena.offset(sig_id) as u32,
2829                            route_byte_off: arena.offset(route_id) as u32,
2830                            out_byte_off: arena.offset(node.id) as u32,
2831                            n_elems,
2832                            attrs: attr_buf,
2833                        });
2834                    }
2835                    "umap.knn" => {
2836                        let pw_id = node.inputs[0];
2837                        let pw_shape = graph.node(pw_id).shape.dims();
                            // `n` is dim 0 of the first input; `k` is decoded from the
                            // first four attribute bytes as a little-endian u32.
2838                        let n = pw_shape[0].unwrap_static() as u32;
2839                        let k = if attrs.len() >= 4 {
2840                            u32::from_le_bytes(attrs[..4].try_into().unwrap())
2841                        } else {
2842                            panic!("rlx-wgpu: umap.knn attrs missing k");
2843                        };
2844                        let pw_off = arena.offset(pw_id) as u32;
2845                        let out_off = arena.offset(node.id) as u32;
                            // At or above the size threshold the GPU kernel is used
                            // (offsets divided by 4); below it a host step is recorded
                            // with the undivided offsets.
2846                        if n as usize >= crate::umap_knn_host::UMAP_KNN_GPU_MIN_N {
2847                            let p = UmapKnnParams {
2848                                n,
2849                                k,
2850                                pw_off: pw_off / 4,
2851                                out_off: out_off / 4,
2852                                _p0: 0,
2853                                _p1: 0,
2854                                _p2: 0,
2855                            };
2856                            schedule.push(Step::UmapKnn { params: p });
2857                            let uk = umap_knn_kernel(&dev.device);
2858                            let u = emit_uniform(std::mem::size_of::<UmapKnnParams>());
2859                            let bg = bind_two(&dev.device, uk, &arena.buffer, &u);
2860                            uniforms.push(u);
2861                            bind_groups.push(bg);
2862                        } else {
                                // Host fallback: schedule entry only, no uniform or
                                // bind group.
2863                            schedule.push(Step::UmapKnnHost {
2864                                pairwise_byte_off: pw_off,
2865                                out_byte_off: out_off,
2866                                n,
2867                                k,
2868                            });
2869                        }
2870                    }
2871                    other => panic!("rlx-wgpu: unsupported Op::Custom('{other}')"),
2872                },
2873                Op::GroupedMatMul => {
                        // Lower a grouped (per-expert) matmul to a single kernel step.
2874                    // Inputs: input [M, K], weight [E, K, N], expert_idx [M]
2875                    let in_id = node.inputs[0];
2876                    let w_id = node.inputs[1];
2877                    let idx_id = node.inputs[2];
2878                    let in_dims = graph.node(in_id).shape.dims();
2879                    let w_dims = graph.node(w_id).shape.dims();
                        // M and K come from the input dims; N and the expert count
                        // come from weight dims 2 and 0.
2880                    let m = in_dims[0].unwrap_static() as u32;
2881                    let k = in_dims[1].unwrap_static() as u32;
2882                    let n = w_dims[2].unwrap_static() as u32;
2883                    let ne = w_dims[0].unwrap_static() as u32;
2884                    let p = GroupedMatmulParams {
2885                        m,
2886                        k,
2887                        n,
2888                        num_experts: ne,
2889                        in_off: (arena.offset(in_id) / 4) as u32,
2890                        w_off: (arena.offset(w_id) / 4) as u32,
2891                        idx_off: (arena.offset(idx_id) / 4) as u32,
2892                        out_off: (arena.offset(node.id) / 4) as u32,
2893                    };
                        // Record the step together with its uniform buffer and bind group.
2894                    schedule.push(Step::GroupedMatmul { params: p });
2895                    let gk = grouped_matmul_kernel(&dev.device);
2896                    let u = emit_uniform(std::mem::size_of::<GroupedMatmulParams>());
2897                    let bg = bind_two(&dev.device, gk, &arena.buffer, &u);
2898                    uniforms.push(u);
2899                    bind_groups.push(bg);
2900                }
2901                Op::DequantGroupedMatMul { scheme } => {
                        // Lower a grouped matmul over GGUF-quantised weights. The step
                        // records undivided arena offsets and a scheme id.
2902                    let in_id = node.inputs[0];
2903                    let w_id = node.inputs[1];
2904                    let idx_id = node.inputs[2];
2905                    let in_dims = graph.node(in_id).shape.dims();
2906                    let out_dims = node.shape.dims();
                        // M and K come from the input dims; N is the last output dim.
2907                    let m = in_dims[0].unwrap_static() as u32;
2908                    let k = in_dims[1].unwrap_static() as u32;
2909                    let n = out_dims[out_dims.len() - 1].unwrap_static() as u32;
                        // Size of one expert's slab: (K * N) elements split into
                        // blocks of `block_elems`, each occupying `block_bytes`.
                        // NOTE(review): integer division, and no guard against
                        // `block_elems == 0`; confirm the scheme always reports a
                        // non-zero block size.
2910                    let block_elems = scheme.gguf_block_size() as usize;
2911                    let block_bytes = scheme.gguf_block_bytes() as usize;
2912                    let slab_bytes = (k as usize * n as usize) / block_elems * block_bytes;
                        // The expert count is derived by dividing the weight node's
                        // element count by the slab size (divisor clamped to >= 1).
                        // NOTE(review): this treats `num_elements()` as a byte count,
                        // presumably because the packed weight is a byte tensor; confirm.
2913                    let total_bytes = graph.node(w_id).shape.num_elements().unwrap();
2914                    let ne = (total_bytes / slab_bytes.max(1)) as u32;
2915                    schedule.push(Step::DequantGroupedMatmulGguf {
2916                        m,
2917                        k,
2918                        n,
2919                        num_experts: ne,
2920                        scheme_id: crate::gguf_host::gguf_scheme_id(*scheme),
2921                        x_byte_off: arena.offset(in_id) as u32,
2922                        w_byte_off: arena.offset(w_id) as u32,
2923                        idx_byte_off: arena.offset(idx_id) as u32,
2924                        out_byte_off: arena.offset(node.id) as u32,
2925                    });
                        // Lazily create the shared `gguf_host_pad` entry and push
                        // clones of it instead of a step-specific uniform / bind group.
2926                    if gguf_host_pad.is_none() {
2927                        let bk = binary_kernel(&dev.device);
2928                        let u = emit_uniform(256);
2929                        gguf_host_pad =
2930                            Some((u.clone(), bind_two(&dev.device, bk, &arena.buffer, &u)));
2931                    }
2932                    let (u, bg) = gguf_host_pad.as_ref().unwrap();
2933                    uniforms.push(u.clone());
2934                    bind_groups.push(bg.clone());
2935                }
2936                Op::TopK { k } => {
                        // Lower a top-k node over the last axis of its input.
2937                    let in_id = node.inputs[0];
2938                    let in_dims = graph.node(in_id).shape.dims();
                        // `inner` is the last dim (the `unwrap` panics on a rank-0
                        // input); `outer` is the product of the leading dims, with a
                        // floor of 1.
2939                    let inner = in_dims.last().unwrap().unwrap_static() as u32;
2940                    let outer: u32 = in_dims[..in_dims.len() - 1]
2941                        .iter()
2942                        .map(|d| d.unwrap_static() as u32)
2943                        .product::<u32>()
2944                        .max(1);
2945                    let p = TopKParams {
2946                        outer,
2947                        inner,
2948                        k: *k as u32,
2949                        in_off: (arena.offset(in_id) / 4) as u32,
2950                        out_off: (arena.offset(node.id) / 4) as u32,
2951                        _p0: 0,
2952                        _p1: 0,
2953                        _p2: 0,
2954                    };
                        // Record the step together with its uniform buffer and bind group.
2955                    schedule.push(Step::TopK { params: p });
2956                    let tk = topk_kernel(&dev.device);
2957                    let u = emit_uniform(std::mem::size_of::<TopKParams>());
2958                    let bg = bind_two(&dev.device, tk, &arena.buffer, &u);
2959                    uniforms.push(u);
2960                    bind_groups.push(bg);
2961                }
2962                Op::ScatterAdd => {
2963                    // Inputs: updates [num_updates, trailing], indices [num_updates].
2964                    // Output: [out_dim, trailing]. Implemented as two phases:
2965                    //   1. Zero `out_dim * trailing` slots.
2966                    //   2. CAS-loop atomic-accumulate `num_updates * trailing` updates.
2967                    let upd_id = node.inputs[0];
2968                    let idx_id = node.inputs[1];
2969                    let upd_dims = graph.node(upd_id).shape.dims();
2970                    let out_dims = node.shape.dims();
2971                    let num_updates = upd_dims[0].unwrap_static() as u32;
                        // `trailing` is the product of every update dim after the
                        // first, with a floor of 1.
2972                    let trailing: u32 = upd_dims
2973                        .iter()
2974                        .skip(1)
2975                        .map(|d| d.unwrap_static() as u32)
2976                        .product::<u32>()
2977                        .max(1);
2978                    let out_dim = out_dims[0].unwrap_static() as u32;
                        // NOTE(review): plain u32 multiplication, no overflow check.
2979                    let out_total = out_dim * trailing;
2980
                        // Both phases share these parameters; only `op` differs
                        // (0 for the first step pushed, 1 for the second).
2981                    let common = ScatterAddParams {
2982                        op: 0,
2983                        out_off: (arena.offset(node.id) / 4) as u32,
2984                        upd_off: (arena.offset(upd_id) / 4) as u32,
2985                        idx_off: (arena.offset(idx_id) / 4) as u32,
2986                        out_total,
2987                        num_updates,
2988                        trailing,
2989                        out_dim,
2990                    };
                        // The same kernel handle is bound for both phases.
2991                    let sk = scatter_add_kernel(&dev.device);
2992
2993                    // Phase 0: zero.
2994                    schedule.push(Step::ScatterAdd { params: common });
2995                    let u0 = emit_uniform(std::mem::size_of::<ScatterAddParams>());
2996                    let bg0 = bind_two(&dev.device, sk, &arena.buffer, &u0);
2997                    uniforms.push(u0);
2998                    bind_groups.push(bg0);
2999
3000                    // Phase 1: accumulate.
                        // Each phase gets its own uniform buffer and bind group, so
                        // this node contributes two entries to all three vectors.
3001                    let mut acc = common;
3002                    acc.op = 1;
3003                    schedule.push(Step::ScatterAdd { params: acc });
3004                    let u1 = emit_uniform(std::mem::size_of::<ScatterAddParams>());
3005                    let bg1 = bind_two(&dev.device, sk, &arena.buffer, &u1);
3006                    uniforms.push(u1);
3007                    bind_groups.push(bg1);
3008                }
3009                Op::FusedResidualLN { has_bias, eps } => {
                        // Lower a fused residual + layer-norm node to one kernel step.
3010                    // Inputs: [x, residual, [bias], gamma, beta].
3011                    let x_id = node.inputs[0];
3012                    let r_id = node.inputs[1];
                        // With a bias, gamma / beta sit at input slots 3 / 4; without
                        // one they shift down to 2 / 3 and `bias_id` aliases `x_id`
                        // so that an offset can still be computed.
3013                    let (bias_id, g_id, b_id) = if *has_bias {
3014                        (node.inputs[2], node.inputs[3], node.inputs[4])
3015                    } else {
3016                        (x_id, node.inputs[2], node.inputs[3]) // bias unused
3017                    };
                        // Despite the name, `in_dims` is this node's own (output)
                        // shape. `inner` is its last dim; `outer` is the total element
                        // count divided by it (divisor clamped to >= 1).
3018                    let in_dims = node.shape.dims();
3019                    let inner = in_dims[in_dims.len() - 1].unwrap_static() as u32;
3020                    let total: u32 = in_dims.iter().map(|d| d.unwrap_static() as u32).product();
3021                    let outer = total / inner.max(1);
3022                    let p = FusedResidualLnParams {
3023                        outer,
3024                        inner,
3025                        in_off: (arena.offset(x_id) / 4) as u32,
3026                        residual_off: (arena.offset(r_id) / 4) as u32,
3027                        bias_off: (arena.offset(bias_id) / 4) as u32,
3028                        gamma_off: (arena.offset(g_id) / 4) as u32,
3029                        beta_off: (arena.offset(b_id) / 4) as u32,
3030                        out_off: (arena.offset(node.id) / 4) as u32,
                            // Epsilon is passed as its raw bit pattern; the flag as 1 / 0.
3031                        eps_bits: eps.to_bits(),
3032                        has_bias: if *has_bias { 1 } else { 0 },
3033                        _p0: 0,
3034                        _p1: 0,
3035                    };
                        // Record the step together with its uniform buffer and bind group.
3036                    schedule.push(Step::FusedResidualLn { params: p });
3037                    let frk = fused_residual_ln_kernel(&dev.device);
3038                    let u = emit_uniform(std::mem::size_of::<FusedResidualLnParams>());
3039                    let bg = bind_two(&dev.device, frk, &arena.buffer, &u);
3040                    uniforms.push(u);
3041                    bind_groups.push(bg);
3042                }
3043                Op::FusedResidualRmsNorm { has_bias, eps } => {
                        // Lower a fused residual + RMS-norm node to one kernel step.
                        // Input slots are read the same way as in the FusedResidualLN
                        // arm: x, residual, optional bias, then two further inputs.
3044                    let x_id = node.inputs[0];
3045                    let r_id = node.inputs[1];
                        // Without a bias the remaining inputs shift down one slot and
                        // `bias_id` aliases `x_id` so that an offset can still be computed.
3046                    let (bias_id, g_id, b_id) = if *has_bias {
3047                        (node.inputs[2], node.inputs[3], node.inputs[4])
3048                    } else {
3049                        (x_id, node.inputs[2], node.inputs[3])
3050                    };
                        // Despite the name, `in_dims` is this node's own (output)
                        // shape. `inner` is its last dim; `outer` is the total element
                        // count divided by it (divisor clamped to >= 1).
3051                    let in_dims = node.shape.dims();
3052                    let inner = in_dims[in_dims.len() - 1].unwrap_static() as u32;
3053                    let total: u32 = in_dims.iter().map(|d| d.unwrap_static() as u32).product();
3054                    let outer = total / inner.max(1);
3055                    let p = FusedResidualRmsNormParams {
3056                        outer,
3057                        inner,
3058                        in_off: (arena.offset(x_id) / 4) as u32,
3059                        residual_off: (arena.offset(r_id) / 4) as u32,
3060                        bias_off: (arena.offset(bias_id) / 4) as u32,
3061                        gamma_off: (arena.offset(g_id) / 4) as u32,
3062                        beta_off: (arena.offset(b_id) / 4) as u32,
3063                        out_off: (arena.offset(node.id) / 4) as u32,
                            // Epsilon is passed as its raw bit pattern; the flag as 1 / 0.
3064                        eps_bits: eps.to_bits(),
3065                        has_bias: if *has_bias { 1 } else { 0 },
3066                        _p0: 0,
3067                        _p1: 0,
3068                    };
                        // Record the step together with its uniform buffer and bind group.
3069                    schedule.push(Step::FusedResidualRmsNorm { params: p });
3070                    let frk = fused_residual_rms_norm_kernel(&dev.device);
3071                    let u = emit_uniform(std::mem::size_of::<FusedResidualRmsNormParams>());
3072                    let bg = bind_two(&dev.device, frk, &arena.buffer, &u);
3073                    uniforms.push(u);
3074                    bind_groups.push(bg);
3075                }
3076                Op::DequantMatMul { scheme } => {
                        // Lower a matmul against quantised weights. GGUF schemes take
                        // the `DequantMatmulGguf` step; every other supported scheme
                        // takes the `DequantMatmul` kernel step.
3077                    use rlx_ir::QuantScheme;
3078                    let x_id = node.inputs[0];
3079                    let w_id = node.inputs[1];
3080                    let out_dims = node.shape.dims();
3081                    let x_dims = graph.node(x_id).shape.dims();
                        // M and N are output dims 0 and 1; K is dim 1 of the
                        // activation input.
3082                    let m = out_dims[0].unwrap_static() as u32;
3083                    let n = out_dims[1].unwrap_static() as u32;
3084                    let k = x_dims[1].unwrap_static() as u32;
3085                    if scheme.is_gguf() {
                            // GGUF path: undivided arena offsets plus a scheme id.
3086                        schedule.push(Step::DequantMatmulGguf {
3087                            m,
3088                            k,
3089                            n,
3090                            scheme_id: crate::gguf_host::gguf_scheme_id(*scheme),
3091                            x_byte_off: arena.offset(x_id) as u32,
3092                            w_byte_off: arena.offset(w_id) as u32,
3093                            out_byte_off: arena.offset(node.id) as u32,
3094                        });
                            // Lazily create the shared `gguf_host_pad` entry and push
                            // clones of it instead of a step-specific uniform / bind
                            // group.
3095                        if gguf_host_pad.is_none() {
3096                            let bk = binary_kernel(&dev.device);
3097                            let u = emit_uniform(256);
3098                            gguf_host_pad =
3099                                Some((u.clone(), bind_two(&dev.device, bk, &arena.buffer, &u)));
3100                        }
3101                        let (u, bg) = gguf_host_pad.as_ref().unwrap();
3102                        uniforms.push(u.clone());
3103                        bind_groups.push(bg.clone());
3104                    } else {
                            // Map the scheme to (block size, numeric scheme id). The
                            // FP8 variants use block size 1; unlisted schemes panic.
3105                        let (block_size, scheme_id) = match scheme {
3106                            QuantScheme::Int8Block { block_size } => (*block_size, 0u32),
3107                            QuantScheme::Int8BlockAsym { block_size } => (*block_size, 1u32),
3108                            QuantScheme::Int4Block { block_size } => (*block_size, 2u32),
3109                            QuantScheme::Fp8E4m3 => (1, 3u32),
3110                            QuantScheme::Fp8E5m2 => (1, 4u32),
3111                            QuantScheme::Nvfp4Block => (rlx_ir::NVFP4_GROUP_SIZE as u32, 5u32),
3112                            other => panic!("rlx-wgpu DequantMatMul: unsupported scheme {other:?}"),
3113                        };
                            // Inputs 2 and 3 are read unconditionally for every scheme.
                            // NOTE(review): presumably scale and zero-point tensors;
                            // confirm that schemes without a zero point still supply
                            // a fourth input.
3114                        let scale_id = node.inputs[2];
3115                        let zp_id = node.inputs[3];
3116                        let p = DequantMatmulParams {
3117                            m,
3118                            k,
3119                            n,
3120                            block_size,
3121                            scheme_id,
3122                            x_off: (arena.offset(x_id) / 4) as u32,
3123                            w_off: (arena.offset(w_id) / 4) as u32,
3124                            scale_off: (arena.offset(scale_id) / 4) as u32,
3125                            zp_off: (arena.offset(zp_id) / 4) as u32,
3126                            out_off: (arena.offset(node.id) / 4) as u32,
3127                            _p0: 0,
3128                            _p1: 0,
3129                        };
                            // Record the step together with its uniform buffer and
                            // bind group.
3130                        schedule.push(Step::DequantMatmul { params: p });
3131                        let dk = dequant_matmul_kernel(&dev.device);
3132                        let u = emit_uniform(std::mem::size_of::<DequantMatmulParams>());
3133                        let bg = bind_two(&dev.device, dk, &arena.buffer, &u);
3134                        uniforms.push(u);
3135                        bind_groups.push(bg);
3136                    }
3137                }
3138                Op::RmsNormBackwardInput { eps, .. }
3139                | Op::RmsNormBackwardGamma { eps, .. }
3140                | Op::RmsNormBackwardBeta { eps, .. } => {
3141                    let x_shape = &graph.node(node.inputs[0]).shape;
3142                    let h = x_shape.dim(x_shape.rank() - 1).unwrap_static() as u32;
3143                    let rows = (x_shape.num_elements().unwrap() / h.max(1) as usize) as u32;
3144                    let foff = |i: usize| (arena.offset(node.inputs[i]) / 4) as u32;
3145                    let wrt = match &node.op {
3146                        Op::RmsNormBackwardInput { .. } => 0u32,
3147                        Op::RmsNormBackwardGamma { .. } => 1u32,
3148                        Op::RmsNormBackwardBeta { .. } => 2u32,
3149                        _ => unreachable!(),
3150                    };
3151                    let p = RmsNormBwdParams {
3152                        outer: rows,
3153                        inner: h,
3154                        x_off: foff(0),
3155                        gamma_off: foff(1),
3156                        beta_off: foff(2),
3157                        dy_off: foff(3),
3158                        out_off: (arena.offset(node.id) / 4) as u32,
3159                        eps_bits: eps.to_bits(),
3160                        wrt,
3161                    };
3162                    let rk = if wrt == 0 {
3163                        rms_norm_backward_kernel(&dev.device)
3164                    } else {
3165                        rms_norm_backward_param_kernel(&dev.device)
3166                    };
3167                    let u = emit_uniform(std::mem::size_of::<RmsNormBwdParams>());
3168                    let bg = bind_two(&dev.device, rk, &arena.buffer, &u);
3169                    match &node.op {
3170                        Op::RmsNormBackwardInput { .. } => {
3171                            schedule.push(Step::RmsNormBackwardInput { params: p });
3172                        }
3173                        Op::RmsNormBackwardGamma { .. } => {
3174                            schedule.push(Step::RmsNormBackwardGamma { params: p });
3175                        }
3176                        Op::RmsNormBackwardBeta { .. } => {
3177                            schedule.push(Step::RmsNormBackwardBeta { params: p });
3178                        }
3179                        _ => unreachable!(),
3180                    }
3181                    uniforms.push(u);
3182                    bind_groups.push(bg);
3183                }
3184                Op::RopeBackward { head_dim, n_rot } => {
3185                    let dy_shape = &graph.node(node.inputs[0]).shape;
3186                    let (batch, seq, hidden) = if dy_shape.rank() >= 3 {
3187                        (
3188                            dy_shape.dim(0).unwrap_static() as u32,
3189                            dy_shape.dim(1).unwrap_static() as u32,
3190                            dy_shape.dim(2).unwrap_static() as u32,
3191                        )
3192                    } else {
3193                        (
3194                            1,
3195                            dy_shape.dim(0).unwrap_static() as u32,
3196                            dy_shape.dim(1).unwrap_static() as u32,
3197                        )
3198                    };
3199                    let cos_len = graph.node(node.inputs[1]).shape.num_elements().unwrap() as u32;
3200                    let p = RopeBwdParams {
3201                        batch,
3202                        seq,
3203                        hidden,
3204                        head_dim: *head_dim as u32,
3205                        n_rot: *n_rot as u32,
3206                        dy_off: (arena.offset(node.inputs[0]) / 4) as u32,
3207                        cos_off: (arena.offset(node.inputs[1]) / 4) as u32,
3208                        sin_off: (arena.offset(node.inputs[2]) / 4) as u32,
3209                        dx_off: (arena.offset(node.id) / 4) as u32,
3210                        cos_len,
3211                    };
3212                    let rk = rope_backward_kernel(&dev.device);
3213                    let u = emit_uniform(std::mem::size_of::<RopeBwdParams>());
3214                    let bg = bind_two(&dev.device, rk, &arena.buffer, &u);
3215                    schedule.push(Step::RopeBackward { params: p });
3216                    uniforms.push(u);
3217                    bind_groups.push(bg);
3218                }
3219                Op::CumsumBackward { exclusive, .. } => {
3220                    let dy_shape = &graph.node(node.inputs[0]).shape;
3221                    let cols = dy_shape.dim(dy_shape.rank() - 1).unwrap_static() as u32;
3222                    let rows = (dy_shape.num_elements().unwrap() / cols.max(1) as usize) as u32;
3223                    let p = CumsumBwdParams {
3224                        outer: rows,
3225                        inner: cols,
3226                        dy_off: (arena.offset(node.inputs[0]) / 4) as u32,
3227                        dx_off: (arena.offset(node.id) / 4) as u32,
3228                        exclusive: if *exclusive { 1 } else { 0 },
3229                        _p0: 0,
3230                        _p1: 0,
3231                        _p2: 0,
3232                    };
3233                    let ck = cumsum_backward_kernel(&dev.device);
3234                    let u = emit_uniform(std::mem::size_of::<CumsumBwdParams>());
3235                    let bg = bind_two(&dev.device, ck, &arena.buffer, &u);
3236                    schedule.push(Step::CumsumBackward { params: p });
3237                    uniforms.push(u);
3238                    bind_groups.push(bg);
3239                }
3240                Op::GatherBackward { .. } => {
3241                    let dy_shape = &graph.node(node.inputs[0]).shape;
3242                    let idx_shape = &graph.node(node.inputs[1]).shape;
3243                    let out_shape = &node.shape;
3244                    let rank = out_shape.rank();
3245                    let axis = match &node.op {
3246                        Op::GatherBackward { axis } => *axis,
3247                        _ => 0,
3248                    };
3249                    let axis_u = if axis < 0 {
3250                        (rank as i32 + axis) as usize
3251                    } else {
3252                        axis as usize
3253                    };
3254                    let outer: usize = (0..axis_u)
3255                        .map(|i| dy_shape.dim(i).unwrap_static())
3256                        .product::<usize>()
3257                        .max(1);
3258                    let num_idx = idx_shape.dim(axis_u).unwrap_static();
3259                    let trailing: usize = (axis_u + 1..dy_shape.rank())
3260                        .map(|i| dy_shape.dim(i).unwrap_static())
3261                        .product::<usize>()
3262                        .max(1);
3263                    let axis_dim = out_shape.dim(axis_u).unwrap_static();
3264                    let p = GatherBwdParams {
3265                        outer: outer as u32,
3266                        axis_dim: axis_dim as u32,
3267                        num_idx: num_idx as u32,
3268                        trailing: trailing as u32,
3269                        dy_off: (arena.offset(node.inputs[0]) / 4) as u32,
3270                        idx_off: (arena.offset(node.inputs[1]) / 4) as u32,
3271                        dst_off: (arena.offset(node.id) / 4) as u32,
3272                        _p0: 0,
3273                    };
3274                    let zk = gather_backward_zero_kernel(&dev.device);
3275                    let u = emit_uniform(std::mem::size_of::<GatherBwdParams>());
3276                    let bg = bind_two(&dev.device, zk, &arena.buffer, &u);
3277                    schedule.push(Step::GatherBackward { params: p });
3278                    uniforms.push(u);
3279                    bind_groups.push(bg);
3280                }
3281                #[cfg(feature = "splat")]
3282                Op::GaussianSplatRender {
3283                    width,
3284                    height,
3285                    tile_size,
3286                    radius_scale,
3287                    alpha_cutoff,
3288                    max_splat_steps,
3289                    transmittance_threshold,
3290                    max_list_entries,
3291                } => {
3292                    let elem_len = |id: NodeId| -> u32 {
3293                        graph.node(id).shape.num_elements().unwrap_or(0) as u32
3294                    };
3295                    schedule.push(Step::GaussianSplatRender {
3296                        positions_byte_off: arena.offset(node.inputs[0]) as u32,
3297                        positions_len: elem_len(node.inputs[0]),
3298                        scales_byte_off: arena.offset(node.inputs[1]) as u32,
3299                        scales_len: elem_len(node.inputs[1]),
3300                        rotations_byte_off: arena.offset(node.inputs[2]) as u32,
3301                        rotations_len: elem_len(node.inputs[2]),
3302                        opacities_byte_off: arena.offset(node.inputs[3]) as u32,
3303                        opacities_len: elem_len(node.inputs[3]),
3304                        colors_byte_off: arena.offset(node.inputs[4]) as u32,
3305                        colors_len: elem_len(node.inputs[4]),
3306                        sh_coeffs_byte_off: arena.offset(node.inputs[5]) as u32,
3307                        sh_coeffs_len: elem_len(node.inputs[5]),
3308                        meta_byte_off: arena.offset(node.inputs[6]) as u32,
3309                        dst_byte_off: arena.offset(node.id) as u32,
3310                        dst_len: node.shape.num_elements().unwrap_or(0) as u32,
3311                        width: *width,
3312                        height: *height,
3313                        tile_size: *tile_size,
3314                        radius_scale: *radius_scale,
3315                        alpha_cutoff: *alpha_cutoff,
3316                        max_splat_steps: *max_splat_steps,
3317                        transmittance_threshold: *transmittance_threshold,
3318                        max_list_entries: *max_list_entries,
3319                    });
3320                }
3321
3322                #[cfg(feature = "splat")]
3323                Op::GaussianSplatRenderBackward {
3324                    width,
3325                    height,
3326                    tile_size,
3327                    radius_scale,
3328                    alpha_cutoff,
3329                    max_splat_steps,
3330                    transmittance_threshold,
3331                    max_list_entries,
3332                    loss_grad_clip,
3333                    sh_band,
3334                    max_anisotropy,
3335                } => {
3336                    let elem_len = |id: NodeId| -> u32 {
3337                        graph.node(id).shape.num_elements().unwrap_or(0) as u32
3338                    };
3339                    schedule.push(Step::GaussianSplatRenderBackward {
3340                        positions_byte_off: arena.offset(node.inputs[0]) as u32,
3341                        positions_len: elem_len(node.inputs[0]),
3342                        scales_byte_off: arena.offset(node.inputs[1]) as u32,
3343                        scales_len: elem_len(node.inputs[1]),
3344                        rotations_byte_off: arena.offset(node.inputs[2]) as u32,
3345                        rotations_len: elem_len(node.inputs[2]),
3346                        opacities_byte_off: arena.offset(node.inputs[3]) as u32,
3347                        opacities_len: elem_len(node.inputs[3]),
3348                        colors_byte_off: arena.offset(node.inputs[4]) as u32,
3349                        colors_len: elem_len(node.inputs[4]),
3350                        sh_coeffs_byte_off: arena.offset(node.inputs[5]) as u32,
3351                        sh_coeffs_len: elem_len(node.inputs[5]),
3352                        meta_byte_off: arena.offset(node.inputs[6]) as u32,
3353                        d_loss_byte_off: arena.offset(node.inputs[7]) as u32,
3354                        d_loss_len: elem_len(node.inputs[7]),
3355                        packed_byte_off: arena.offset(node.id) as u32,
3356                        packed_len: node.shape.num_elements().unwrap_or(0) as u32,
3357                        width: *width,
3358                        height: *height,
3359                        tile_size: *tile_size,
3360                        radius_scale: *radius_scale,
3361                        alpha_cutoff: *alpha_cutoff,
3362                        max_splat_steps: *max_splat_steps,
3363                        transmittance_threshold: *transmittance_threshold,
3364                        max_list_entries: *max_list_entries,
3365                        loss_grad_clip: *loss_grad_clip,
3366                        sh_band: *sh_band,
3367                        max_anisotropy: *max_anisotropy,
3368                    });
3369                }
3370
3371                #[cfg(feature = "splat")]
3372                Op::GaussianSplatPrepare {
3373                    width,
3374                    height,
3375                    tile_size,
3376                    radius_scale,
3377                    alpha_cutoff,
3378                    max_splat_steps,
3379                    transmittance_threshold,
3380                    max_list_entries,
3381                } => {
3382                    let elem_len = |id: NodeId| -> u32 {
3383                        graph.node(id).shape.num_elements().unwrap_or(0) as u32
3384                    };
3385                    schedule.push(Step::GaussianSplatPrepare {
3386                        positions_byte_off: arena.offset(node.inputs[0]) as u32,
3387                        positions_len: elem_len(node.inputs[0]),
3388                        scales_byte_off: arena.offset(node.inputs[1]) as u32,
3389                        scales_len: elem_len(node.inputs[1]),
3390                        rotations_byte_off: arena.offset(node.inputs[2]) as u32,
3391                        rotations_len: elem_len(node.inputs[2]),
3392                        opacities_byte_off: arena.offset(node.inputs[3]) as u32,
3393                        opacities_len: elem_len(node.inputs[3]),
3394                        colors_byte_off: arena.offset(node.inputs[4]) as u32,
3395                        colors_len: elem_len(node.inputs[4]),
3396                        sh_coeffs_byte_off: arena.offset(node.inputs[5]) as u32,
3397                        sh_coeffs_len: elem_len(node.inputs[5]),
3398                        meta_byte_off: arena.offset(node.inputs[6]) as u32,
3399                        meta_len: elem_len(node.inputs[6]),
3400                        prep_byte_off: arena.offset(node.id) as u32,
3401                        prep_len: node.shape.num_elements().unwrap_or(0) as u32,
3402                        width: *width,
3403                        height: *height,
3404                        tile_size: *tile_size,
3405                        radius_scale: *radius_scale,
3406                        alpha_cutoff: *alpha_cutoff,
3407                        max_splat_steps: *max_splat_steps,
3408                        transmittance_threshold: *transmittance_threshold,
3409                        max_list_entries: *max_list_entries,
3410                    });
3411                }
3412
3413                #[cfg(feature = "splat")]
3414                Op::GaussianSplatRasterize {
3415                    width,
3416                    height,
3417                    tile_size,
3418                    alpha_cutoff,
3419                    max_splat_steps,
3420                    transmittance_threshold,
3421                    max_list_entries,
3422                } => {
3423                    let elem_len = |id: NodeId| -> u32 {
3424                        graph.node(id).shape.num_elements().unwrap_or(0) as u32
3425                    };
3426                    let prep_id = node.inputs[0];
3427                    let count = match &graph.node(prep_id).op {
3428                        rlx_ir::Op::GaussianSplatPrepare { .. } => {
3429                            elem_len(graph.node(prep_id).inputs[0]) / 3
3430                        }
3431                        _ => 1,
3432                    };
3433                    schedule.push(Step::GaussianSplatRasterize {
3434                        prep_byte_off: arena.offset(prep_id) as u32,
3435                        prep_len: elem_len(prep_id),
3436                        meta_byte_off: arena.offset(node.inputs[1]) as u32,
3437                        meta_len: elem_len(node.inputs[1]),
3438                        dst_byte_off: arena.offset(node.id) as u32,
3439                        dst_len: node.shape.num_elements().unwrap_or(0) as u32,
3440                        count,
3441                        width: *width,
3442                        height: *height,
3443                        tile_size: *tile_size,
3444                        alpha_cutoff: *alpha_cutoff,
3445                        max_splat_steps: *max_splat_steps,
3446                        transmittance_threshold: *transmittance_threshold,
3447                        max_list_entries: *max_list_entries,
3448                    });
3449                }
3450
3451                Op::If { .. } | Op::While { .. } => {
3452                    // Should be unreachable: unfuse.rs inlines both branches
3453                    // (If) or unrolls max_iterations (While) into the parent
3454                    // graph using primitive ops + Where for the gating. If
3455                    // we hit this arm, the unfusion pass has a gap.
3456                    panic!(
3457                        "rlx-wgpu: Op::If/While leaked past unfusion pass — \
3458                            check unfuse.rs::expand_if / expand_while"
3459                    );
3460                }
3461                other => panic!(
3462                    "rlx-wgpu: op {other:?} not yet lowered (v2 covers Matmul, \
3463                     Binary, Compare, Activation, Where — fall back to CPU/Metal/MLX)"
3464                ),
3465            }
3466        }
3467
3468        if rlx_ir::env::flag("RLX_WGPU_SCHEDULE") || rlx_ir::env::flag("RLX_DISPATCH_REPORT") {
3469            let mut counts: std::collections::BTreeMap<&'static str, usize> =
3470                std::collections::BTreeMap::new();
3471            let mut fft_gpu = 0usize;
3472            let mut fft_host = 0usize;
3473            for s in &schedule {
3474                *counts.entry(step_name(s)).or_insert(0) += 1;
3475                match s {
3476                    Step::FftGpu { .. } => fft_gpu += 1,
3477                    Step::FftHost { .. } => fft_host += 1,
3478                    _ => {}
3479                }
3480            }
3481            let arena_mb = arena.size as f64 / (1u64 << 20) as f64;
3482            eprintln!(
3483                "[rlx-wgpu] schedule: {} steps, arena={arena_mb:.1} MiB, fft_gpu={fft_gpu}, fft_host={fft_host}",
3484                schedule.len()
3485            );
3486            for (n, c) in &counts {
3487                eprintln!("    {c:>4} × {n}");
3488            }
3489        }
3490
3491        Self {
3492            graph,
3493            arena,
3494            schedule,
3495            input_offsets,
3496            param_offsets,
3497            uniforms,
3498            bind_groups,
3499            meta_buffers,
3500            unresolved: None,
3501            last_binding: None,
3502            pending_params: HashMap::new(),
3503            pending_param_bytes: HashMap::new(),
3504            active_extent: None,
3505            uniforms_active_extent: None,
3506            fft_gpu_steps,
3507        }
3508    }
3509
3510    pub fn set_param(&mut self, name: &str, data: &[f32]) {
3511        if self.unresolved.is_some() {
3512            self.pending_params.insert(name.to_string(), data.to_vec());
3513            return;
3514        }
3515        let dev = wgpu_device().expect("rlx-wgpu: device gone");
3516        if let Some(&id) = self.param_offsets.get(name)
3517            && self.arena.has(id)
3518        {
3519            self.arena.write_f32(&dev.queue, id, data);
3520        }
3521    }
3522
3523    /// Debug helper: run forward, then read every node slot back and
3524    /// report the first node whose output contains a NaN, plus a
3525    /// summary of the *previous* finite node's value range so the
3526    /// caller can see the input that broke. Slow — diagnosis only.
3527    pub fn debug_first_nan_node(
3528        &mut self,
3529        inputs: &[(&str, &[f32])],
3530    ) -> Option<(usize, String, String)> {
3531        let _ = self.run(inputs);
3532        let dev = wgpu_device().expect("rlx-wgpu: device gone");
3533        let mut prev_summary = String::from("(none)");
3534        for (i, node) in self.graph.nodes().iter().enumerate() {
3535            if !self.arena.has(node.id) {
3536                continue;
3537            }
3538            let elems = node.shape.num_elements().unwrap_or(0);
3539            if elems == 0 {
3540                continue;
3541            }
3542            let data = self.arena.read_f32(&dev.device, &dev.queue, node.id);
3543            let nan_count = data.iter().filter(|v| v.is_nan()).count();
3544            let inf_count = data.iter().filter(|v| v.is_infinite()).count();
3545            if nan_count > 0 || inf_count > 0 {
3546                return Some((i, format!("{:?}", node.op), prev_summary));
3547            }
3548            let max = data.iter().copied().fold(f32::NEG_INFINITY, f32::max);
3549            let min = data.iter().copied().fold(f32::INFINITY, f32::min);
3550            let abs_max = data.iter().map(|v| v.abs()).fold(0.0_f32, f32::max);
3551            prev_summary = format!(
3552                "node #{i} {:?} shape={:?}  min={min:.6e} max={max:.6e} |max|={abs_max:.6e}",
3553                node.op,
3554                node.shape
3555                    .dims()
3556                    .iter()
3557                    .map(|d| format!("{d:?}"))
3558                    .collect::<Vec<_>>()
3559            );
3560        }
3561        None
3562    }
3563
3564    /// Declared output dtypes (one per graph output). Used by the
3565    /// runtime wrapper's `run_typed` to narrow F32 results back to
3566    /// F16/BF16 etc. on the way out.
3567    pub fn output_dtypes(&self) -> Vec<rlx_ir::DType> {
3568        self.graph
3569            .outputs
3570            .iter()
3571            .map(|&id| self.graph.node(id).shape.dtype())
3572            .collect()
3573    }
3574
3575    /// Upload raw bytes for a Param. The bytes land tight-packed at
3576    /// the param's slot offset — no f32 round-trip. Used for quantized
3577    /// weights (int8 / int4) where the kernel reads the byte stream
3578    /// via `bitcast<u32>` from the f32-typed arena.
3579    pub fn set_param_bytes(&mut self, name: &str, data: &[u8]) {
3580        if self.unresolved.is_some() {
3581            self.pending_param_bytes
3582                .insert(name.to_string(), data.to_vec());
3583            return;
3584        }
3585        let dev = wgpu_device().expect("rlx-wgpu: device gone");
3586        if let Some(&id) = self.param_offsets.get(name)
3587            && self.arena.has(id)
3588        {
3589            dev.queue
3590                .write_buffer(&self.arena.buffer, self.arena.offset(id) as u64, data);
3591        }
3592    }
3593
3594    pub fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
3595        // Lazy compile path: if we deferred compile waiting for shapes,
3596        // infer the binding from input data lengths now and compile.
3597        if self.unresolved.is_some() {
3598            self.lazy_compile_for_inputs(inputs);
3599        }
3600        let dev = wgpu_device().expect("rlx-wgpu: device gone");
3601        for &(name, data) in inputs {
3602            if let Some(&id) = self.input_offsets.get(name)
3603                && self.arena.has(id)
3604            {
3605                self.arena.write_f32(&dev.queue, id, data);
3606            }
3607        }
3608
3609        // Active-extent (PLAN L1): scale safe Steps' primary dim by
3610        // actual/upper. Used in BOTH the uniform-write loop (so the
3611        // kernel sees the scaled count) AND the dispatch loop (so the
3612        // workgroup grid is shrunk).
3613        let active = self.active_extent.filter(|_| self.all_safe_for_active());
3614        let scale = |full: u32| -> u32 {
3615            match active {
3616                Some((a, u)) if u > 0 => {
3617                    let f = full as usize;
3618                    (f * a).div_ceil(u).min(f) as u32
3619                }
3620                _ => full,
3621            }
3622        };
3623
3624        // Stage uniform writes — but skip the loop entirely when the
3625        // bytes already in the uniforms match this run's active extent.
3626        // BERT inference at fixed batch hits this path: 100+ tiny
3627        // queue.write_buffer calls (one per Step) collapse to zero,
3628        // saving milliseconds of staging-copy overhead.
3629        let need_uniform_writes = self.uniforms_active_extent != Some(active);
3630        if need_uniform_writes {
3631            let mut gpu_ui = 0usize;
3632            for step in self.schedule.iter() {
3633                if step_runs_on_host(step) {
3634                    continue;
3635                }
3636                match step {
3637                    Step::CastF32ToF16 { .. } => {
3638                        // Params are static for this step (offset+len), so the
3639                        // pre-pass write at compile time is sufficient. No
3640                        // active-extent scaling — len is the full element count.
3641                    }
3642                    Step::Matmul {
3643                        m,
3644                        k,
3645                        n,
3646                        a_off_f32,
3647                        b_off_f32,
3648                        c_off_f32,
3649                        batch,
3650                        a_batch_stride,
3651                        b_batch_stride,
3652                        c_batch_stride,
3653                        has_bias,
3654                        bias_off_f32,
3655                        act_id,
3656                        b_is_param: _,
3657                        compute_precision: _,
3658                    } => {
3659                        // PLAN L1 (safe at any batch — c_batch_stride is
3660                        // pre-baked at compile time at FULL m, so scaling
3661                        // params.m only changes per-thread bound checks).
3662                        let m_scaled = scale(*m);
3663                        let p = MatmulParams {
3664                            m: m_scaled,
3665                            k: *k,
3666                            n: *n,
3667                            a_off: *a_off_f32,
3668                            b_off: *b_off_f32,
3669                            c_off: *c_off_f32,
3670                            batch: *batch,
3671                            a_batch_stride: *a_batch_stride,
3672                            b_batch_stride: *b_batch_stride,
3673                            c_batch_stride: *c_batch_stride,
3674                            has_bias: *has_bias,
3675                            bias_off: *bias_off_f32,
3676                            act_id: *act_id,
3677                            _pad0: 0,
3678                            _pad1: 0,
3679                            _pad2: 0,
3680                        };
3681                        dev.queue
3682                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3683                    }
3684                    Step::Binary { params } | Step::Compare { params } => {
3685                        let mut p = *params;
3686                        p.n = scale(p.n);
3687                        dev.queue
3688                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3689                    }
3690                    Step::Unary { params } => {
3691                        let mut p = *params;
3692                        p.n = scale(p.n);
3693                        dev.queue
3694                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3695                    }
3696                    Step::Where { params } => {
3697                        let mut p = *params;
3698                        p.n = scale(p.n);
3699                        dev.queue
3700                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3701                    }
3702                    Step::Reduce { params } => {
3703                        let mut p = *params;
3704                        p.outer = scale(p.outer);
3705                        dev.queue
3706                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3707                    }
3708                    Step::Softmax { params } => {
3709                        let mut p = *params;
3710                        p.outer = scale(p.outer);
3711                        dev.queue
3712                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3713                    }
3714                    Step::LayerNorm { params } => {
3715                        let mut p = *params;
3716                        p.outer = scale(p.outer);
3717                        dev.queue
3718                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3719                    }
3720                    Step::RmsNormBackwardInput { params }
3721                    | Step::RmsNormBackwardGamma { params }
3722                    | Step::RmsNormBackwardBeta { params } => {
3723                        let mut p = *params;
3724                        p.outer = scale(p.outer);
3725                        dev.queue
3726                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3727                    }
3728                    Step::CumsumBackward { params } => {
3729                        let mut p = *params;
3730                        p.outer = scale(p.outer);
3731                        dev.queue
3732                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3733                    }
3734                    Step::RopeBackward { params } => {
3735                        let mut p = *params;
3736                        p.seq = scale(p.seq);
3737                        dev.queue
3738                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3739                    }
3740                    Step::GatherBackward { params } => {
3741                        let mut p = *params;
3742                        p.outer = scale(p.outer);
3743                        dev.queue
3744                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3745                    }
3746                    Step::Cumsum { params } => {
3747                        let mut p = *params;
3748                        p.outer = scale(p.outer);
3749                        dev.queue
3750                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3751                    }
3752                    Step::FftGpu { .. } => {}
3753                    Step::Copy { params } => {
3754                        let mut p = *params;
3755                        p.n = scale(p.n);
3756                        dev.queue
3757                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3758                    }
3759                    Step::ElementwiseRegion { params } => {
3760                        // Active-extent: scale element count.
3761                        let mut p = *params;
3762                        p.len = scale(p.len);
3763                        dev.queue
3764                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3765                    }
3766                    Step::Transpose { params, .. } => {
3767                        // PLAN L1: when bucket_outermost == 1, scale
3768                        // `out_total` proportional to scaling `out_dim_0`.
3769                        // Other transposes leave out_total at full extent
3770                        // (predicate prevents the active-extent path).
3771                        let mut p = *params;
3772                        if p.bucket_outermost == 1 && p.out_dim_0 > 0 {
3773                            let scaled_d0 = scale(p.out_dim_0);
3774                            let inner = p.out_total / p.out_dim_0;
3775                            p.out_total = scaled_d0 * inner;
3776                        }
3777                        dev.queue
3778                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3779                    }
3780                    Step::Narrow { params } => {
3781                        let mut p = *params;
3782                        p.total = scale(p.total);
3783                        dev.queue
3784                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3785                    }
3786                    Step::Concat { params } => {
3787                        let mut p = *params;
3788                        p.total = scale(p.total);
3789                        dev.queue
3790                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3791                    }
3792                    Step::Gather { params } => {
3793                        let mut p = *params;
3794                        p.n_out = scale(p.n_out);
3795                        dev.queue
3796                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3797                    }
3798                    Step::GatherAxis { params } => {
3799                        let mut p = *params;
3800                        p.total = scale(p.total);
3801                        dev.queue
3802                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3803                    }
3804                    Step::Attention { params, .. } => {
3805                        // PLAN L1: scale seq_q + seq_k. Stride fields
3806                        // (seq_q_stride / seq_k_stride) stay at the
3807                        // compile-time full extent, so per-(batch, head)
3808                        // offset math in the WGSL stays correct.
3809                        let mut p = *params;
3810                        p.seq_q = scale(p.seq_q);
3811                        p.seq_k = scale(p.seq_k);
3812                        dev.queue
3813                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3814                    }
3815                    Step::AttentionBackward { params, .. } => {
3816                        let mut p = *params;
3817                        if p.wrt == 0 {
3818                            p.seq_q = scale(p.seq_q);
3819                        } else {
3820                            p.seq_k = scale(p.seq_k);
3821                        }
3822                        dev.queue
3823                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3824                    }
3825                    Step::Rope { params } => {
3826                        // PLAN L1: scale `seq` and `n_total` proportionally.
3827                        // `seq_stride` and `batch` stay at compile-time
3828                        // values; the WGSL kernel uses them for buffer
3829                        // offsets while `seq` / `n_total` are loop bounds.
3830                        let mut p = *params;
3831                        let s_active = scale(p.seq);
3832                        p.seq = s_active;
3833                        p.n_total = p.batch * s_active * p.last_dim;
3834                        dev.queue
3835                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3836                    }
3837                    Step::Expand { params, .. } => {
3838                        // PLAN L1: same pattern as Transpose.
3839                        let mut p = *params;
3840                        if p.bucket_outermost == 1 && p.out_dim_0 > 0 {
3841                            let scaled_d0 = scale(p.out_dim_0);
3842                            let inner = p.out_total / p.out_dim_0;
3843                            p.out_total = scaled_d0 * inner;
3844                        }
3845                        dev.queue
3846                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3847                    }
3848                    Step::Argmax { params } => {
3849                        let mut p = *params;
3850                        p.outer = scale(p.outer);
3851                        dev.queue
3852                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3853                    }
3854                    Step::Pool2d { params } => {
3855                        let mut p = *params;
3856                        p.n = scale(p.n);
3857                        dev.queue
3858                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3859                    }
3860                    Step::Conv2d { params } => {
3861                        let mut p = *params;
3862                        p.n = scale(p.n);
3863                        dev.queue
3864                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3865                    }
3866                    Step::Pool1d { params } => {
3867                        let mut p = *params;
3868                        p.n = scale(p.n);
3869                        dev.queue
3870                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3871                    }
3872                    Step::Pool3d { params } => {
3873                        let mut p = *params;
3874                        p.n = scale(p.n);
3875                        dev.queue
3876                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3877                    }
3878                    Step::Conv1d { params } => {
3879                        let mut p = *params;
3880                        p.n = scale(p.n);
3881                        dev.queue
3882                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3883                    }
3884                    Step::Conv3d { params } => {
3885                        let mut p = *params;
3886                        p.n = scale(p.n);
3887                        dev.queue
3888                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3889                    }
3890                    Step::ScatterAdd { params } => {
3891                        // Two-phase: phase 0 zeros the FULL output (preserves
3892                        // accumulator semantics); phase 1 scatters first
3893                        // num_updates_active updates only.
3894                        let mut p = *params;
3895                        if p.op == 1 {
3896                            p.num_updates = scale(p.num_updates);
3897                        }
3898                        dev.queue
3899                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3900                    }
3901                    Step::TopK { params } => {
3902                        let mut p = *params;
3903                        p.outer = scale(p.outer);
3904                        dev.queue
3905                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3906                    }
3907                    Step::UmapKnn { params } => {
3908                        let mut p = *params;
3909                        p.n = scale(p.n);
3910                        dev.queue
3911                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3912                    }
3913                    Step::GroupedMatmul { params } => {
3914                        let mut p = *params;
3915                        p.m = scale(p.m);
3916                        dev.queue
3917                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3918                    }
3919                    Step::Sample { params } => {
3920                        let mut p = *params;
3921                        p.outer = scale(p.outer);
3922                        dev.queue
3923                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3924                    }
3925                    Step::SelectiveScan { params } => {
3926                        // Predicate-gated to batch=1: scale seq.
3927                        let mut p = *params;
3928                        p.seq = scale(p.seq);
3929                        dev.queue
3930                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3931                    }
3932                    Step::DequantMatmul { params } => {
3933                        let mut p = *params;
3934                        p.m = scale(p.m);
3935                        dev.queue
3936                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3937                    }
3938                    Step::DequantMatmulGguf { .. }
3939                    | Step::DequantGroupedMatmulGguf { .. }
3940                    | Step::GatedDeltaNet { .. }
3941                    | Step::Llada2GroupLimitedGate { .. }
3942                    | Step::UmapKnnHost { .. }
3943                    | Step::FftHost { .. } => {}
3944                    Step::FusedResidualLn { params } => {
3945                        let mut p = *params;
3946                        p.outer = scale(p.outer);
3947                        dev.queue
3948                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3949                    }
3950                    Step::FusedResidualLnTee { params } => {
3951                        let mut p = *params;
3952                        p.outer = scale(p.outer);
3953                        dev.queue
3954                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3955                    }
3956                    Step::FusedResidualRmsNorm { params } => {
3957                        let mut p = *params;
3958                        p.outer = scale(p.outer);
3959                        dev.queue
3960                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3961                    }
3962                    Step::MatmulQkv { params, coop: _ } => {
3963                        let mut p = *params;
3964                        p.m = scale(p.m);
3965                        dev.queue
3966                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
3967                    }
3968                    #[cfg(feature = "splat")]
3969                    Step::GaussianSplatRender { .. }
3970                    | Step::GaussianSplatRenderBackward { .. }
3971                    | Step::GaussianSplatPrepare { .. }
3972                    | Step::GaussianSplatRasterize { .. } => {}
3973                }
3974                if !matches!(step, Step::FftGpu { .. }) {
3975                    gpu_ui += 1;
3976                }
3977            }
3978            self.uniforms_active_extent = Some(active);
3979        }
3980
3981        // Encode + submit.
3982        let mm_k = matmul_kernel(&dev.device);
3983        let mm_w = matmul_wide_kernel(&dev.device);
3984        let mm_f16w = matmul_f16w_kernel(&dev.device);
3985        let mm_f16c = matmul_f16_compute_kernel(&dev.device);
3986        let mm_coop = matmul_coop16_kernel(&dev.device);
3987        let mm_coop_f32 = matmul_coop_f32_kernel(&dev.device);
3988        let mm_cast = cast_f32_to_f16_kernel(&dev.device);
3989        let bk = binary_kernel(&dev.device);
3990        let uk = unary_kernel(&dev.device);
3991        let ck = compare_kernel(&dev.device);
3992        let wk = where_kernel(&dev.device);
3993        let mut step_i = 0;
3994        let mut gpu_bi = 0usize;
3995        let mut fft_i = 0usize;
3996        while step_i < self.schedule.len() {
3997            let mut enc = dev
3998                .device
3999                .create_command_encoder(&wgpu::CommandEncoderDescriptor {
4000                    label: Some("rlx-wgpu run"),
4001                });
4002            {
4003                let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
4004                    label: Some("rlx-wgpu compute pass"),
4005                    timestamp_writes: None,
4006                });
4007                while step_i < self.schedule.len() {
4008                    if step_runs_on_host(&self.schedule[step_i]) {
4009                        break;
4010                    }
4011                    let step = &self.schedule[step_i];
4012                    // PLAN L3: per-step Perfetto trace span; no-op when
4013                    // env var RLX_TRACE_PERFETTO unset.
4014                    let _perf = rlx_ir::perfetto::TraceSpan::new(step_name(step), "wgpu");
4015                    match step {
4016                        Step::CastF32ToF16 { params } => {
4017                            // Pre-pass for matmul_coop16: mirror f32 arena
4018                            // region into f16 shadow buffer so the matmul
4019                            // kernel can read A as f16. One thread per
4020                            // element; 64-thread workgroups.
4021                            if let Some(cast_k) = mm_cast {
4022                                pass.set_pipeline(&cast_k.pipeline);
4023                                pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4024                                let (gx, gy, gz) = dispatch_dims(params.len, 64);
4025                                pass.dispatch_workgroups(gx, gy, gz);
4026                            }
4027                        }
4028                        Step::Matmul {
4029                            m,
4030                            n,
4031                            batch,
4032                            b_is_param,
4033                            compute_precision,
4034                            ..
4035                        } =>
4036                        // The dispatch branches below form one flat
4037                        // `if let Some(p) = … && cond` let-chain to pick
4038                        // a pipeline, because each variant depends on a
4039                        // different optional kernel. Nesting a separate
4040                        // `if let` per variant would bury the fallback
4041                        // arms; the flat form is the readable shape here.
4042                        {
4043                            #[allow(clippy::unnecessary_unwrap)]
4044                            // Safe at any batch (see safe_for_active_extent
4045                            // comment); scale m, output rows past m_s per
4046                            // batch retain prior values via c_batch_stride.
4047                            let m_s = scale(*m);
4048                            if m_s == 0 {
4049                                continue;
4050                            }
4051                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4052                            // Kernel selection priority:
4053                            //   1. compute_precision == F16 + b_is_param +
4054                            //      SHADER_F16 → matmul_f16_compute
4055                            //      (f16 multiply, f32 acc — 2× ALU on Apple)
4056                            //   2. legacy RLX_WGPU_F16_WEIGHTS opt-in →
4057                            //      matmul_f16w (storage-only f16; experimental,
4058                            //      currently regresses on Apple)
4059                            //   3. wide-N (m≥32, n≥64)   → matmul_wide
4060                            //   4. otherwise            → matmul (small/skinny)
4061                            let f16w_opt_in = rlx_ir::env::flag("RLX_WGPU_F16_WEIGHTS");
4062                            if let Some(coop) = mm_coop.as_ref()
4063                                && *b_is_param
4064                                && *compute_precision == MatmulCompute::Coop16
4065                            {
4066                                // Hardware GEMM via simdgroup_matrix /
4067                                // KHR_cooperative_matrix. 32×32 output tile
4068                                // per workgroup (16 hardware-GEMM ops with
4069                                // shared A/B loads). Caller guaranteed m, n,
4070                                // k are multiples of 32/32/8.
4071                                pass.set_pipeline(&coop.pipeline);
4072                                pass.dispatch_workgroups(n / 32, m_s.div_ceil(32), *batch);
4073                            } else if let Some(coop_f32) = mm_coop_f32.as_ref()
4074                                && *b_is_param
4075                                && *compute_precision == MatmulCompute::CoopF32
4076                            {
4077                                // Pure-f32 cooperative-matrix path
4078                                // (`simdgroup_float8x8` on Apple). Same tile
4079                                // shape as Coop16; no precision loss.
4080                                pass.set_pipeline(&coop_f32.pipeline);
4081                                pass.dispatch_workgroups(n / 32, m_s.div_ceil(32), *batch);
4082                            } else if let Some(f16c) = mm_f16c.as_ref()
4083                                && *b_is_param
4084                                && *compute_precision == MatmulCompute::F16
4085                            {
4086                                pass.set_pipeline(&f16c.pipeline);
4087                                pass.dispatch_workgroups(n.div_ceil(32), m_s.div_ceil(32), *batch);
4088                            } else if let Some(f16w) = mm_f16w.as_ref()
4089                                && *b_is_param
4090                                && f16w_opt_in
4091                            {
4092                                pass.set_pipeline(&f16w.pipeline);
4093                                pass.dispatch_workgroups(n.div_ceil(32), m_s.div_ceil(32), *batch);
4094                            } else if m_s >= 32 && *n >= 64 {
4095                                pass.set_pipeline(&mm_w.pipeline);
4096                                pass.dispatch_workgroups(n.div_ceil(64), m_s.div_ceil(32), *batch);
4097                            } else {
4098                                pass.set_pipeline(&mm_k.pipeline);
4099                                pass.dispatch_workgroups(n.div_ceil(32), m_s.div_ceil(32), *batch);
4100                            }
4101                        }
4102                        Step::Binary { params } => {
4103                            let n_s = scale(params.n);
4104                            if n_s == 0 {
4105                                continue;
4106                            }
4107                            pass.set_pipeline(&bk.pipeline);
4108                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4109                            let (gx, gy, gz) = dispatch_dims(n_s, 64);
4110                            pass.dispatch_workgroups(gx, gy, gz);
4111                        }
4112                        Step::Compare { params } => {
4113                            let n_s = scale(params.n);
4114                            if n_s == 0 {
4115                                continue;
4116                            }
4117                            pass.set_pipeline(&ck.pipeline);
4118                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4119                            let (gx, gy, gz) = dispatch_dims(n_s, 64);
4120                            pass.dispatch_workgroups(gx, gy, gz);
4121                        }
4122                        Step::Unary { params } => {
4123                            let n_s = scale(params.n);
4124                            if n_s == 0 {
4125                                continue;
4126                            }
4127                            pass.set_pipeline(&uk.pipeline);
4128                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4129                            let (gx, gy, gz) = dispatch_dims(n_s, 64);
4130                            pass.dispatch_workgroups(gx, gy, gz);
4131                        }
4132                        Step::Where { params } => {
4133                            let n_s = scale(params.n);
4134                            if n_s == 0 {
4135                                continue;
4136                            }
4137                            pass.set_pipeline(&wk.pipeline);
4138                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4139                            let (gx, gy, gz) = dispatch_dims(n_s, 64);
4140                            pass.dispatch_workgroups(gx, gy, gz);
4141                        }
4142                        Step::Reduce { params } => {
4143                            let outer_s = scale(params.outer);
4144                            if outer_s == 0 {
4145                                continue;
4146                            }
4147                            let rk = reduce_kernel(&dev.device);
4148                            pass.set_pipeline(&rk.pipeline);
4149                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4150                            let (gx, gy, gz) = dispatch_dims(outer_s, 64);
4151                            pass.dispatch_workgroups(gx, gy, gz);
4152                        }
4153                        Step::Softmax { params } => {
4154                            let outer_s = scale(params.outer);
4155                            if outer_s == 0 {
4156                                continue;
4157                            }
4158                            let sk = softmax_kernel(&dev.device);
4159                            pass.set_pipeline(&sk.pipeline);
4160                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4161                            let (gx, gy, gz) = dispatch_dims(outer_s, 64);
4162                            pass.dispatch_workgroups(gx, gy, gz);
4163                        }
4164                        Step::LayerNorm { params } => {
4165                            let outer_s = scale(params.outer);
4166                            if outer_s == 0 {
4167                                continue;
4168                            }
4169                            let lk = layernorm_kernel(&dev.device);
4170                            pass.set_pipeline(&lk.pipeline);
4171                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4172                            let (gx, gy, gz) = dispatch_dims(outer_s, 64);
4173                            pass.dispatch_workgroups(gx, gy, gz);
4174                        }
4175                        Step::RmsNormBackwardInput { params } => {
4176                            let outer_s = scale(params.outer);
4177                            if outer_s == 0 {
4178                                continue;
4179                            }
4180                            let rk = rms_norm_backward_kernel(&dev.device);
4181                            pass.set_pipeline(&rk.pipeline);
4182                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4183                            pass.dispatch_workgroups(outer_s, 1, 1);
4184                        }
4185                        Step::RmsNormBackwardGamma { params }
4186                        | Step::RmsNormBackwardBeta { params } => {
4187                            if params.inner == 0 {
4188                                continue;
4189                            }
4190                            let rk = rms_norm_backward_param_kernel(&dev.device);
4191                            pass.set_pipeline(&rk.pipeline);
4192                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4193                            pass.dispatch_workgroups(1, 1, 1);
4194                        }
4195                        Step::CumsumBackward { params } => {
4196                            let outer_s = scale(params.outer);
4197                            if outer_s == 0 {
4198                                continue;
4199                            }
4200                            let ck = cumsum_backward_kernel(&dev.device);
4201                            pass.set_pipeline(&ck.pipeline);
4202                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4203                            let (gx, gy, gz) = dispatch_dims(outer_s, 64);
4204                            pass.dispatch_workgroups(gx, gy, gz);
4205                        }
4206                        Step::RopeBackward { params } => {
4207                            let seq_s = scale(params.seq);
4208                            if seq_s == 0 {
4209                                continue;
4210                            }
4211                            let rk = rope_backward_kernel(&dev.device);
4212                            pass.set_pipeline(&rk.pipeline);
4213                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4214                            let total = params.batch * seq_s * params.hidden;
4215                            let (gx, gy, gz) = dispatch_dims(total, 64);
4216                            pass.dispatch_workgroups(gx, gy, gz);
4217                        }
4218                        Step::GatherBackward { params } => {
4219                            let outer_s = scale(params.outer);
4220                            if outer_s == 0 {
4221                                continue;
4222                            }
4223                            let total = outer_s * params.axis_dim * params.trailing;
4224                            if total > 0 {
4225                                let zk = gather_backward_zero_kernel(&dev.device);
4226                                pass.set_pipeline(&zk.pipeline);
4227                                pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4228                                let (gx, _, _) = dispatch_dims(total, 256);
4229                                pass.dispatch_workgroups(gx, 1, 1);
4230                            }
4231                            let ak = gather_backward_acc_kernel(&dev.device);
4232                            pass.set_pipeline(&ak.pipeline);
4233                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4234                            pass.dispatch_workgroups(outer_s, 1, 1);
4235                        }
4236                        Step::Cumsum { params } => {
4237                            let outer_s = scale(params.outer);
4238                            if outer_s == 0 {
4239                                continue;
4240                            }
4241                            let ck2 = cumsum_kernel(&dev.device);
4242                            pass.set_pipeline(&ck2.pipeline);
4243                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4244                            let (gx, gy, gz) = dispatch_dims(outer_s, 64);
4245                            pass.dispatch_workgroups(gx, gy, gz);
4246                        }
4247                        Step::FftGpu {
4248                            src_off,
4249                            dst_off,
4250                            outer,
4251                            n,
4252                            inverse,
4253                            norm_scale,
4254                        } => {
4255                            let res = &self.fft_gpu_steps[fft_i];
4256                            fft_i += 1;
4257                            crate::fft_dispatch::dispatch_fft_gpu_in_pass(
4258                                &dev.device,
4259                                &dev.queue,
4260                                &mut pass,
4261                                res,
4262                                *src_off,
4263                                *dst_off,
4264                                *outer,
4265                                *n,
4266                                *inverse != 0,
4267                                *norm_scale,
4268                            );
4269                        }
4270                        Step::Copy { params } => {
4271                            let n_s = scale(params.n);
4272                            if n_s == 0 {
4273                                continue;
4274                            }
4275                            let ck2 = copy_kernel(&dev.device);
4276                            pass.set_pipeline(&ck2.pipeline);
4277                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4278                            let (gx, gy, gz) = dispatch_dims(n_s, 64);
4279                            pass.dispatch_workgroups(gx, gy, gz);
4280                        }
4281                        Step::ElementwiseRegion { params } => {
4282                            let len_s = scale(params.len);
4283                            if len_s == 0 {
4284                                continue;
4285                            }
4286                            let ek = elementwise_region_kernel(&dev.device);
4287                            pass.set_pipeline(&ek.pipeline);
4288                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4289                            let (gx, gy, gz) = dispatch_dims(len_s, 64);
4290                            pass.dispatch_workgroups(gx, gy, gz);
4291                        }
4292                        Step::Transpose { params, .. } => {
4293                            // Compute scaled grid count to match the
4294                            // uniform's scaled out_total when bucket axis
4295                            // is outermost.
4296                            let total_s = if params.bucket_outermost == 1 && params.out_dim_0 > 0 {
4297                                let scaled_d0 = scale(params.out_dim_0);
4298                                let inner = params.out_total / params.out_dim_0;
4299                                scaled_d0 * inner
4300                            } else {
4301                                params.out_total
4302                            };
4303                            if total_s == 0 {
4304                                continue;
4305                            }
4306                            let tk = transpose_kernel(&dev.device);
4307                            pass.set_pipeline(&tk.pipeline);
4308                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4309                            let (gx, gy, gz) = dispatch_dims(total_s, 64);
4310                            pass.dispatch_workgroups(gx, gy, gz);
4311                        }
4312                        Step::Narrow { params } => {
4313                            let total_s = scale(params.total);
4314                            if total_s == 0 {
4315                                continue;
4316                            }
4317                            let nk = narrow_kernel(&dev.device);
4318                            pass.set_pipeline(&nk.pipeline);
4319                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4320                            let (gx, gy, gz) = dispatch_dims(total_s, 64);
4321                            pass.dispatch_workgroups(gx, gy, gz);
4322                        }
4323                        Step::Concat { params } => {
4324                            let total_s = scale(params.total);
4325                            if total_s == 0 {
4326                                continue;
4327                            }
4328                            let cck = concat_kernel(&dev.device);
4329                            pass.set_pipeline(&cck.pipeline);
4330                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4331                            let (gx, gy, gz) = dispatch_dims(total_s, 64);
4332                            pass.dispatch_workgroups(gx, gy, gz);
4333                        }
4334                        Step::Gather { params } => {
4335                            let n_out_s = scale(params.n_out);
4336                            if n_out_s == 0 {
4337                                continue;
4338                            }
4339                            let gk = gather_kernel(&dev.device);
4340                            pass.set_pipeline(&gk.pipeline);
4341                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4342                            let (gx, gy, gz) = dispatch_dims(n_out_s, 64);
4343                            pass.dispatch_workgroups(gx, gy, gz);
4344                        }
4345                        Step::GatherAxis { params } => {
4346                            let total_s = scale(params.total);
4347                            if total_s == 0 {
4348                                continue;
4349                            }
4350                            let gk = gather_axis_kernel(&dev.device);
4351                            pass.set_pipeline(&gk.pipeline);
4352                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4353                            let (gx, gy, gz) = dispatch_dims(total_s, 64);
4354                            pass.dispatch_workgroups(gx, gy, gz);
4355                        }
4356                        Step::Attention { params, .. } => {
4357                            // Scale seq_q for grid dim; per-head strides
4358                            // come from seq_q_stride / seq_k_stride (full
4359                            // extent) inside the WGSL.
4360                            let seq_q_s = scale(params.seq_q);
4361                            if seq_q_s == 0 {
4362                                continue;
4363                            }
4364                            let ak = attention_kernel(&dev.device);
4365                            pass.set_pipeline(&ak.pipeline);
4366                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4367                            let total = params.batch * params.heads * seq_q_s;
4368                            let (gx, gy, gz) = dispatch_dims(total, 64);
4369                            pass.dispatch_workgroups(gx, gy, gz);
4370                        }
4371                        Step::AttentionBackward { params, .. } => {
4372                            let axis = if params.wrt == 0 {
4373                                params.seq_q
4374                            } else {
4375                                params.seq_k
4376                            };
4377                            let axis_s = scale(axis);
4378                            if axis_s == 0 {
4379                                continue;
4380                            }
4381                            let ak = attention_bwd_kernel(&dev.device);
4382                            pass.set_pipeline(&ak.pipeline);
4383                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4384                            let total = params.batch * params.heads * axis_s;
4385                            let (gx, gy, gz) = dispatch_dims(total, 64);
4386                            pass.dispatch_workgroups(gx, gy, gz);
4387                        }
4388                        Step::Rope { params } => {
4389                            // Multi-batch via stride-field WGSL fix:
4390                            // iterate `batch * scaled_seq * last_dim` items.
4391                            let s_active = scale(params.seq);
4392                            let total_s = params.batch * s_active * params.last_dim;
4393                            if total_s == 0 {
4394                                continue;
4395                            }
4396                            let rk = rope_kernel(&dev.device);
4397                            pass.set_pipeline(&rk.pipeline);
4398                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4399                            let (gx, gy, gz) = dispatch_dims(total_s, 64);
4400                            pass.dispatch_workgroups(gx, gy, gz);
4401                        }
4402                        Step::Expand { params, .. } => {
4403                            let total_s = if params.bucket_outermost == 1 && params.out_dim_0 > 0 {
4404                                let scaled_d0 = scale(params.out_dim_0);
4405                                let inner = params.out_total / params.out_dim_0;
4406                                scaled_d0 * inner
4407                            } else {
4408                                params.out_total
4409                            };
4410                            if total_s == 0 {
4411                                continue;
4412                            }
4413                            let ek = expand_kernel(&dev.device);
4414                            pass.set_pipeline(&ek.pipeline);
4415                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4416                            let (gx, gy, gz) = dispatch_dims(total_s, 64);
4417                            pass.dispatch_workgroups(gx, gy, gz);
4418                        }
4419                        Step::Argmax { params } => {
4420                            let outer_s = scale(params.outer);
4421                            if outer_s == 0 {
4422                                continue;
4423                            }
4424                            let amk = argmax_kernel(&dev.device);
4425                            pass.set_pipeline(&amk.pipeline);
4426                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4427                            let (gx, gy, gz) = dispatch_dims(outer_s, 64);
4428                            pass.dispatch_workgroups(gx, gy, gz);
4429                        }
4430                        Step::Pool2d { params } => {
4431                            let n_s = scale(params.n);
4432                            if n_s == 0 {
4433                                continue;
4434                            }
4435                            let pk = pool2d_kernel(&dev.device);
4436                            pass.set_pipeline(&pk.pipeline);
4437                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4438                            let total = n_s * params.c * params.h_out * params.w_out;
4439                            let (gx, gy, gz) = dispatch_dims(total, 64);
4440                            pass.dispatch_workgroups(gx, gy, gz);
4441                        }
4442                        Step::Conv2d { params } => {
4443                            let n_s = scale(params.n);
4444                            if n_s == 0 {
4445                                continue;
4446                            }
4447                            let ck2 = conv2d_kernel(&dev.device);
4448                            pass.set_pipeline(&ck2.pipeline);
4449                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4450                            let total = n_s * params.c_out * params.h_out * params.w_out;
4451                            let (gx, gy, gz) = dispatch_dims(total, 64);
4452                            pass.dispatch_workgroups(gx, gy, gz);
4453                        }
4454                        Step::Pool1d { params } => {
4455                            let n_s = scale(params.n);
4456                            if n_s == 0 {
4457                                continue;
4458                            }
4459                            let pk = pool1d_kernel(&dev.device);
4460                            pass.set_pipeline(&pk.pipeline);
4461                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4462                            let total = n_s * params.c * params.l_out;
4463                            let (gx, gy, gz) = dispatch_dims(total, 64);
4464                            pass.dispatch_workgroups(gx, gy, gz);
4465                        }
4466                        Step::Pool3d { params } => {
4467                            let n_s = scale(params.n);
4468                            if n_s == 0 {
4469                                continue;
4470                            }
4471                            let pk = pool3d_kernel(&dev.device);
4472                            pass.set_pipeline(&pk.pipeline);
4473                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4474                            let total = n_s * params.c * params.d_out * params.h_out * params.w_out;
4475                            let (gx, gy, gz) = dispatch_dims(total, 64);
4476                            pass.dispatch_workgroups(gx, gy, gz);
4477                        }
4478                        Step::Conv1d { params } => {
4479                            let n_s = scale(params.n);
4480                            if n_s == 0 {
4481                                continue;
4482                            }
4483                            let ck = conv1d_kernel(&dev.device);
4484                            pass.set_pipeline(&ck.pipeline);
4485                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4486                            let total = n_s * params.c_out * params.l_out;
4487                            let (gx, gy, gz) = dispatch_dims(total, 64);
4488                            pass.dispatch_workgroups(gx, gy, gz);
4489                        }
4490                        Step::Conv3d { params } => {
4491                            let n_s = scale(params.n);
4492                            if n_s == 0 {
4493                                continue;
4494                            }
4495                            let ck = conv3d_kernel(&dev.device);
4496                            pass.set_pipeline(&ck.pipeline);
4497                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4498                            let total =
4499                                n_s * params.c_out * params.d_out * params.h_out * params.w_out;
4500                            let (gx, gy, gz) = dispatch_dims(total, 64);
4501                            pass.dispatch_workgroups(gx, gy, gz);
4502                        }
4503                        Step::ScatterAdd { params } => {
4504                            let sk = scatter_add_kernel(&dev.device);
4505                            pass.set_pipeline(&sk.pipeline);
4506                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4507                            // Phase 0 zeros the FULL output (preserves
4508                            // accumulator semantics). Phase 1 scatters first
4509                            // num_updates_active updates only; serial single
4510                            // workgroup either way (atomic CAS unsupported in
4511                            // naga's MSL emitter — see scatter_add.wgsl).
4512                            if params.op == 0 {
4513                                let (gx, gy, gz) = dispatch_dims(params.out_total, 64);
4514                                pass.dispatch_workgroups(gx, gy, gz);
4515                            } else {
4516                                pass.dispatch_workgroups(1, 1, 1);
4517                            }
4518                        }
4519                        Step::TopK { params } => {
4520                            let outer_s = scale(params.outer);
4521                            if outer_s == 0 {
4522                                continue;
4523                            }
4524                            let tk = topk_kernel(&dev.device);
4525                            pass.set_pipeline(&tk.pipeline);
4526                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4527                            let (gx, gy, gz) = dispatch_dims(outer_s, 64);
4528                            pass.dispatch_workgroups(gx, gy, gz);
4529                        }
4530                        Step::UmapKnn { params } => {
4531                            let n_s = scale(params.n);
4532                            if n_s == 0 {
4533                                continue;
4534                            }
4535                            let uk = umap_knn_kernel(&dev.device);
4536                            pass.set_pipeline(&uk.pipeline);
4537                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4538                            let (gx, gy, gz) = dispatch_dims(n_s, 64);
4539                            pass.dispatch_workgroups(gx, gy, gz);
4540                        }
4541                        Step::GroupedMatmul { params } => {
4542                            let m_s = scale(params.m);
4543                            if m_s == 0 {
4544                                continue;
4545                            }
4546                            let gk = grouped_matmul_kernel(&dev.device);
4547                            pass.set_pipeline(&gk.pipeline);
4548                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4549                            pass.dispatch_workgroups(params.n.div_ceil(8), m_s.div_ceil(8), 1);
4550                        }
4551                        Step::Sample { params } => {
4552                            let outer_s = scale(params.outer);
4553                            if outer_s == 0 {
4554                                continue;
4555                            }
4556                            let sk = sample_kernel(&dev.device);
4557                            pass.set_pipeline(&sk.pipeline);
4558                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4559                            let (gx, gy, gz) = dispatch_dims(outer_s, 64);
4560                            pass.dispatch_workgroups(gx, gy, gz);
4561                        }
4562                        Step::SelectiveScan { params } => {
4563                            // Predicate-gated to batch=1; the seq scaling
4564                            // happens inside the kernel (uniform sees scaled
4565                            // seq). Dispatch grid here is per-(batch, hidden);
4566                            // unaffected by seq scaling.
4567                            let ssk = selective_scan_kernel(&dev.device);
4568                            pass.set_pipeline(&ssk.pipeline);
4569                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4570                            let total = params.batch * params.hidden;
4571                            let (gx, gy, gz) = dispatch_dims(total, 64);
4572                            pass.dispatch_workgroups(gx, gy, gz);
4573                        }
4574                        Step::DequantMatmul { params } => {
4575                            let m_s = scale(params.m);
4576                            if m_s == 0 {
4577                                continue;
4578                            }
4579                            let dk = dequant_matmul_kernel(&dev.device);
4580                            pass.set_pipeline(&dk.pipeline);
4581                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4582                            pass.dispatch_workgroups(params.n.div_ceil(8), m_s.div_ceil(8), 1);
4583                        }
4584                        Step::FusedResidualLn { params } => {
4585                            let outer_s = scale(params.outer);
4586                            if outer_s == 0 {
4587                                continue;
4588                            }
4589                            let frk = fused_residual_ln_kernel(&dev.device);
4590                            pass.set_pipeline(&frk.pipeline);
4591                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4592                            let (gx, gy, gz) = dispatch_dims(outer_s, 64);
4593                            pass.dispatch_workgroups(gx, gy, gz);
4594                        }
4595                        Step::FusedResidualLnTee { params } => {
4596                            let outer_s = scale(params.outer);
4597                            if outer_s == 0 {
4598                                continue;
4599                            }
4600                            let frtk = fused_residual_ln_tee_kernel(&dev.device);
4601                            pass.set_pipeline(&frtk.pipeline);
4602                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4603                            let (gx, gy, gz) = dispatch_dims(outer_s, 64);
4604                            pass.dispatch_workgroups(gx, gy, gz);
4605                        }
4606                        Step::FusedResidualRmsNorm { params } => {
4607                            let outer_s = scale(params.outer);
4608                            if outer_s == 0 {
4609                                continue;
4610                            }
4611                            let frk = fused_residual_rms_norm_kernel(&dev.device);
4612                            pass.set_pipeline(&frk.pipeline);
4613                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4614                            let (gx, gy, gz) = dispatch_dims(outer_s, 64);
4615                            pass.dispatch_workgroups(gx, gy, gz);
4616                        }
4617                        Step::MatmulQkv { params, coop } => {
4618                            let m_s = scale(params.m);
4619                            if m_s == 0 {
4620                                continue;
4621                            }
4622                            // Both kernels write to the same 32×32 output tile
4623                            // grid; only the inner GEMM strategy differs.
4624                            let pipe = if *coop {
4625                                &matmul_qkv_coop_f32_kernel(&dev.device)
4626                                    .expect("coop matmul_qkv kernel missing")
4627                                    .pipeline
4628                            } else {
4629                                &matmul_qkv_kernel(&dev.device).pipeline
4630                            };
4631                            pass.set_pipeline(pipe);
4632                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
4633                            pass.dispatch_workgroups(params.n.div_ceil(32), m_s.div_ceil(32), 1);
4634                        }
4635                        Step::DequantMatmulGguf { .. }
4636                        | Step::DequantGroupedMatmulGguf { .. }
4637                        | Step::GatedDeltaNet { .. }
4638                        | Step::Llada2GroupLimitedGate { .. }
4639                        | Step::UmapKnnHost { .. }
4640                        | Step::FftHost { .. } => {}
4641                        #[cfg(feature = "splat")]
4642                        Step::GaussianSplatRender { .. }
4643                        | Step::GaussianSplatRenderBackward { .. }
4644                        | Step::GaussianSplatPrepare { .. }
4645                        | Step::GaussianSplatRasterize { .. } => {}
4646                    }
4647                    if !matches!(step, Step::FftGpu { .. }) {
4648                        gpu_bi += 1;
4649                    }
4650                    step_i += 1;
4651                }
4652            }
4653            dev.queue.submit(std::iter::once(enc.finish()));
4654            let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());
4655            if step_i >= self.schedule.len() {
4656                break;
4657            }
4658            match &self.schedule[step_i] {
4659                Step::DequantMatmulGguf {
4660                    m,
4661                    k,
4662                    n,
4663                    scheme_id,
4664                    x_byte_off,
4665                    w_byte_off,
4666                    out_byte_off,
4667                } => {
4668                    crate::gguf_host::run_dequant_matmul_gguf(
4669                        &self.arena,
4670                        &dev.device,
4671                        &dev.queue,
4672                        *m as usize,
4673                        *k as usize,
4674                        *n as usize,
4675                        *scheme_id,
4676                        *x_byte_off as usize,
4677                        *w_byte_off as usize,
4678                        *out_byte_off as usize,
4679                    );
4680                }
4681                Step::DequantGroupedMatmulGguf {
4682                    m,
4683                    k,
4684                    n,
4685                    num_experts,
4686                    scheme_id,
4687                    x_byte_off,
4688                    w_byte_off,
4689                    idx_byte_off,
4690                    out_byte_off,
4691                } => {
4692                    crate::gguf_host::run_dequant_grouped_matmul_gguf(
4693                        &self.arena,
4694                        &dev.device,
4695                        &dev.queue,
4696                        *m as usize,
4697                        *k as usize,
4698                        *n as usize,
4699                        *num_experts as usize,
4700                        *scheme_id,
4701                        *x_byte_off as usize,
4702                        *w_byte_off as usize,
4703                        *idx_byte_off as usize,
4704                        *out_byte_off as usize,
4705                    );
4706                }
4707                Step::GatedDeltaNet {
4708                    q_byte_off,
4709                    k_byte_off,
4710                    v_byte_off,
4711                    g_byte_off,
4712                    beta_byte_off,
4713                    state_byte_off,
4714                    dst_byte_off,
4715                    batch,
4716                    seq,
4717                    heads,
4718                    state_size,
4719                    use_carry,
4720                } => {
4721                    crate::gdn_host::run_gated_delta_net(
4722                        &self.arena,
4723                        &dev.device,
4724                        &dev.queue,
4725                        *q_byte_off as usize,
4726                        *k_byte_off as usize,
4727                        *v_byte_off as usize,
4728                        *g_byte_off as usize,
4729                        *beta_byte_off as usize,
4730                        *state_byte_off as usize,
4731                        *dst_byte_off as usize,
4732                        *batch as usize,
4733                        *seq as usize,
4734                        *heads as usize,
4735                        *state_size as usize,
4736                        *use_carry,
4737                    );
4738                }
4739                Step::Llada2GroupLimitedGate {
4740                    sig_byte_off,
4741                    route_byte_off,
4742                    out_byte_off,
4743                    n_elems,
4744                    attrs,
4745                } => {
4746                    crate::llada2_gate_host::run_llada2_group_limited_gate(
4747                        &self.arena,
4748                        &dev.device,
4749                        &dev.queue,
4750                        *sig_byte_off as usize,
4751                        *route_byte_off as usize,
4752                        *out_byte_off as usize,
4753                        *n_elems as usize,
4754                        attrs,
4755                    );
4756                }
4757                Step::UmapKnnHost {
4758                    pairwise_byte_off,
4759                    out_byte_off,
4760                    n,
4761                    k,
4762                } => {
4763                    crate::umap_knn_host::run_umap_knn(
4764                        &self.arena,
4765                        &dev.device,
4766                        &dev.queue,
4767                        *pairwise_byte_off as usize,
4768                        *out_byte_off as usize,
4769                        *n as usize,
4770                        *k as usize,
4771                    );
4772                }
4773                Step::FftHost {
4774                    src_byte_off,
4775                    dst_byte_off,
4776                    outer,
4777                    n_complex,
4778                    inverse,
4779                    norm_tag,
4780                    dtype_tag,
4781                } => {
4782                    crate::fft_host::run_fft1d(
4783                        &self.arena,
4784                        &dev.device,
4785                        &dev.queue,
4786                        *src_byte_off as usize,
4787                        *dst_byte_off as usize,
4788                        *outer as usize,
4789                        *n_complex as usize,
4790                        *inverse,
4791                        *norm_tag,
4792                        fft_dtype_from_tag(*dtype_tag),
4793                    );
4794                }
4795                #[cfg(feature = "splat")]
4796                Step::GaussianSplatRender {
4797                    positions_byte_off,
4798                    positions_len,
4799                    scales_byte_off,
4800                    scales_len,
4801                    rotations_byte_off,
4802                    rotations_len,
4803                    opacities_byte_off,
4804                    opacities_len,
4805                    colors_byte_off,
4806                    colors_len,
4807                    sh_coeffs_byte_off,
4808                    sh_coeffs_len,
4809                    meta_byte_off,
4810                    dst_byte_off,
4811                    dst_len,
4812                    width,
4813                    height,
4814                    tile_size,
4815                    radius_scale,
4816                    alpha_cutoff,
4817                    max_splat_steps,
4818                    transmittance_threshold,
4819                    max_list_entries,
4820                } => {
4821                    crate::splat::run_gaussian_splat_render(
4822                        &self.arena,
4823                        &dev.device,
4824                        &dev.queue,
4825                        *positions_byte_off as usize,
4826                        *positions_len as usize,
4827                        *scales_byte_off as usize,
4828                        *scales_len as usize,
4829                        *rotations_byte_off as usize,
4830                        *rotations_len as usize,
4831                        *opacities_byte_off as usize,
4832                        *opacities_len as usize,
4833                        *colors_byte_off as usize,
4834                        *colors_len as usize,
4835                        *sh_coeffs_byte_off as usize,
4836                        *sh_coeffs_len as usize,
4837                        *meta_byte_off as usize,
4838                        *dst_byte_off as usize,
4839                        *dst_len as usize,
4840                        *width,
4841                        *height,
4842                        *tile_size,
4843                        *radius_scale,
4844                        *alpha_cutoff,
4845                        *max_splat_steps,
4846                        *transmittance_threshold,
4847                        *max_list_entries,
4848                    );
4849                }
4850                #[cfg(feature = "splat")]
4851                Step::GaussianSplatPrepare {
4852                    positions_byte_off,
4853                    positions_len,
4854                    scales_byte_off,
4855                    scales_len,
4856                    rotations_byte_off,
4857                    rotations_len,
4858                    opacities_byte_off,
4859                    opacities_len,
4860                    colors_byte_off,
4861                    colors_len,
4862                    sh_coeffs_byte_off,
4863                    sh_coeffs_len,
4864                    meta_byte_off,
4865                    meta_len,
4866                    prep_byte_off,
4867                    prep_len,
4868                    width,
4869                    height,
4870                    tile_size,
4871                    radius_scale,
4872                    alpha_cutoff,
4873                    max_splat_steps,
4874                    transmittance_threshold,
4875                    max_list_entries,
4876                } => {
4877                    crate::splat::run_gaussian_splat_prepare(
4878                        &self.arena,
4879                        &dev.device,
4880                        &dev.queue,
4881                        *positions_byte_off as usize,
4882                        *positions_len as usize,
4883                        *scales_byte_off as usize,
4884                        *scales_len as usize,
4885                        *rotations_byte_off as usize,
4886                        *rotations_len as usize,
4887                        *opacities_byte_off as usize,
4888                        *opacities_len as usize,
4889                        *colors_byte_off as usize,
4890                        *colors_len as usize,
4891                        *sh_coeffs_byte_off as usize,
4892                        *sh_coeffs_len as usize,
4893                        *meta_byte_off as usize,
4894                        *meta_len as usize,
4895                        *prep_byte_off as usize,
4896                        *prep_len as usize,
4897                        *width,
4898                        *height,
4899                        *tile_size,
4900                        *radius_scale,
4901                        *alpha_cutoff,
4902                        *max_splat_steps,
4903                        *transmittance_threshold,
4904                        *max_list_entries,
4905                    );
4906                }
4907                #[cfg(feature = "splat")]
4908                Step::GaussianSplatRasterize {
4909                    prep_byte_off,
4910                    prep_len,
4911                    meta_byte_off,
4912                    meta_len,
4913                    dst_byte_off,
4914                    dst_len,
4915                    count,
4916                    width,
4917                    height,
4918                    tile_size,
4919                    alpha_cutoff,
4920                    max_splat_steps,
4921                    transmittance_threshold,
4922                    max_list_entries,
4923                } => {
4924                    crate::splat::run_gaussian_splat_rasterize(
4925                        &self.arena,
4926                        &dev.device,
4927                        &dev.queue,
4928                        *prep_byte_off as usize,
4929                        *prep_len as usize,
4930                        *meta_byte_off as usize,
4931                        *meta_len as usize,
4932                        *dst_byte_off as usize,
4933                        *dst_len as usize,
4934                        *count as usize,
4935                        *width,
4936                        *height,
4937                        *tile_size,
4938                        *alpha_cutoff,
4939                        *max_splat_steps,
4940                        *transmittance_threshold,
4941                        *max_list_entries,
4942                    );
4943                }
4944                #[cfg(feature = "splat")]
4945                Step::GaussianSplatRenderBackward {
4946                    positions_byte_off,
4947                    positions_len,
4948                    scales_byte_off,
4949                    scales_len,
4950                    rotations_byte_off,
4951                    rotations_len,
4952                    opacities_byte_off,
4953                    opacities_len,
4954                    colors_byte_off,
4955                    colors_len,
4956                    sh_coeffs_byte_off,
4957                    sh_coeffs_len,
4958                    meta_byte_off,
4959                    d_loss_byte_off,
4960                    d_loss_len,
4961                    packed_byte_off,
4962                    packed_len,
4963                    width,
4964                    height,
4965                    tile_size,
4966                    radius_scale,
4967                    alpha_cutoff,
4968                    max_splat_steps,
4969                    transmittance_threshold,
4970                    max_list_entries,
4971                    loss_grad_clip,
4972                    sh_band,
4973                    max_anisotropy,
4974                } => {
4975                    crate::splat::run_gaussian_splat_render_backward(
4976                        &self.arena,
4977                        &dev.device,
4978                        &dev.queue,
4979                        *positions_byte_off as usize,
4980                        *positions_len as usize,
4981                        *scales_byte_off as usize,
4982                        *scales_len as usize,
4983                        *rotations_byte_off as usize,
4984                        *rotations_len as usize,
4985                        *opacities_byte_off as usize,
4986                        *opacities_len as usize,
4987                        *colors_byte_off as usize,
4988                        *colors_len as usize,
4989                        *sh_coeffs_byte_off as usize,
4990                        *sh_coeffs_len as usize,
4991                        *meta_byte_off as usize,
4992                        *d_loss_byte_off as usize,
4993                        *d_loss_len as usize,
4994                        *packed_byte_off as usize,
4995                        *packed_len as usize,
4996                        *width,
4997                        *height,
4998                        *tile_size,
4999                        *radius_scale,
5000                        *alpha_cutoff,
5001                        *max_splat_steps,
5002                        *transmittance_threshold,
5003                        *max_list_entries,
5004                        *loss_grad_clip,
5005                        *sh_band,
5006                        *max_anisotropy,
5007                    );
5008                }
5009                _ => break,
5010            }
5011            step_i += 1;
5012        }
5013
5014        // RLX_WGPU_NAN_TRACE=1: after submission, scan every node's
5015        // arena slot for NaN. Print the first N nodes whose output
5016        // contains NaN (in IR topo order). Used to bisect which kernel
5017        // first introduces NaN — once we know the producer, we know
5018        // which WGSL to look at.
5019        if rlx_ir::env::flag("RLX_WGPU_NAN_TRACE") {
5020            let mut bad_nodes = Vec::new();
5021            for node in self.graph.nodes() {
5022                if !self.arena.has(node.id) {
5023                    continue;
5024                }
5025                // Skip leaves — populated by host writes, not kernels.
5026                if matches!(
5027                    node.op,
5028                    rlx_ir::Op::Input { .. }
5029                        | rlx_ir::Op::Param { .. }
5030                        | rlx_ir::Op::Constant { .. }
5031                ) {
5032                    continue;
5033                }
5034                let data = self.arena.read_f32(&dev.device, &dev.queue, node.id);
5035                let nan_count = data.iter().filter(|v| v.is_nan()).count();
5036                let inf_count = data.iter().filter(|v| v.is_infinite()).count();
5037                if nan_count > 0 || inf_count > 0 {
5038                    // Capture first NaN index + the values around it.
5039                    let first_nan = data.iter().position(|v| v.is_nan());
5040                    if let Some(idx) = first_nan {
5041                        let lo = idx.saturating_sub(2);
5042                        let hi = (idx + 3).min(data.len());
5043                        eprintln!(
5044                            "  node {:?} op={:?} len={} nan={} inf={} \
5045                                   first_nan_idx={} ctx={:?}",
5046                            node.id,
5047                            node.op,
5048                            data.len(),
5049                            nan_count,
5050                            inf_count,
5051                            idx,
5052                            &data[lo..hi]
5053                        );
5054                    }
5055                    bad_nodes.push((node.id, data.len(), nan_count, inf_count));
5056                    if bad_nodes.len() >= 3 {
5057                        break;
5058                    }
5059                }
5060            }
5061            if bad_nodes.is_empty() {
5062                eprintln!("[wgpu-nan-trace] no NaN/Inf in any node — clean run");
5063            } else {
5064                eprintln!(
5065                    "[wgpu-nan-trace] first {} bad nodes (above)",
5066                    bad_nodes.len()
5067                );
5068            }
5069        }
5070
5071        self.graph
5072            .outputs
5073            .iter()
5074            .map(|&id| {
5075                if rlx_ir::env::flag("RLX_BENCH_DISPATCH_ONLY") {
5076                    let n = self.graph.node(id).shape.num_elements().unwrap_or(0);
5077                    vec![0.0; n]
5078                } else {
5079                    self.arena.read_f32(&dev.device, &dev.queue, id)
5080                }
5081            })
5082            .collect()
5083    }
5084}
5085
/// Maps a 1-D thread count onto an `(x, y, 1)` workgroup grid.
///
/// The workload needs `ceil(threads_total / workgroup_size)`
/// workgroups. WebGPU caps every `dispatch_workgroups` dimension at
/// 65535, so a count above that cap is folded into a 2-D grid: `x` is
/// pinned at 65535 and `y` is the number of such rows needed to cover
/// the workload. The folded grid can contain more workgroups than the
/// workload strictly needs. Kernels recover the linear thread index
/// via `gid.x + gid.y * num_workgroups.x * 64u`.
///
/// # Panics
///
/// Panics if `workgroup_size` is zero.
fn dispatch_dims(threads_total: u32, workgroup_size: u32) -> (u32, u32, u32) {
    const MAX_GROUPS_PER_DIM: u32 = 65535;
    match threads_total.div_ceil(workgroup_size) {
        // Too many groups for one dimension: spill the excess into `y`.
        total if total > MAX_GROUPS_PER_DIM => (
            MAX_GROUPS_PER_DIM,
            total.div_ceil(MAX_GROUPS_PER_DIM),
            1,
        ),
        total => (total, 1, 1),
    }
}
5102
5103fn require_equal_shapes(graph: &Graph, ids: &[NodeId], op_name: &str) {
5104    let s0 = graph.node(ids[0]).shape.num_elements().unwrap_or(0);
5105    for &id in &ids[1..] {
5106        let si = graph.node(id).shape.num_elements().unwrap_or(0);
5107        if si != s0 {
5108            panic!(
5109                "rlx-wgpu {op_name}: broadcasting not yet implemented; \
5110                    inputs must have the same element count (got {s0} vs {si})"
5111            );
5112        }
5113    }
5114}
5115
5116fn bind_two(
5117    device: &wgpu::Device,
5118    kernel: &Kernel,
5119    buf0: &wgpu::Buffer,
5120    buf1: &wgpu::Buffer,
5121) -> wgpu::BindGroup {
5122    device.create_bind_group(&wgpu::BindGroupDescriptor {
5123        label: Some("rlx-wgpu bg"),
5124        layout: &kernel.bgl,
5125        entries: &[
5126            wgpu::BindGroupEntry {
5127                binding: 0,
5128                resource: buf0.as_entire_binding(),
5129            },
5130            wgpu::BindGroupEntry {
5131                binding: 1,
5132                resource: buf1.as_entire_binding(),
5133            },
5134        ],
5135    })
5136}
5137
5138/// Compute precision selector: derive from IR dtypes of A and B and
5139/// the device features.
5140///
5141/// Priority:
5142///   1. Coop16 — if EXPERIMENTAL_COOPERATIVE_MATRIX + SHADER_F16 +
5143///      F16 IR tag + b traces to a Param + M/K/N are 32/8/32 aligned.
5144///      Unlocks Apple's `simdgroup_matrix` / Vulkan's KHR_cooperative
5145///      hardware GEMM units (~18× faster than f32 ALU on Apple M-series).
5146///   2. F32 — every other case, *including* when AutoMixedPrecision
5147///      tagged the matmul as F16 but it failed Coop16's alignment
5148///      check. The non-coop F16 path (`matmul_f16_compute.wgsl`) was
5149///      empirically measured 4-5× SLOWER than the f32 baseline on
5150///      Apple via wgpu/naga 29 — the WGSL→MSL emit doesn't unlock
5151///      Apple's f16 ALU through portable WGSL ALU. So at small /
5152///      unaligned shapes we lose nothing by ignoring the IR's f16
5153///      tag and using f32 — precision improves AND speed wins.
5154///
5155/// (The F16 variant of `MatmulCompute` and `matmul_f16_compute.wgsl`
5156/// remain for future use — e.g. when naga gains a portable subgroup-
5157/// matrix surface that lowers efficiently without needing the full
5158/// coop-matrix dance, or when bf16 hardware lands. Today no path
5159/// dispatches them.)
5160fn derive_matmul_compute(
5161    dev: &wgpu::Device,
5162    graph: &Graph,
5163    a_id: NodeId,
5164    b_id: NodeId,
5165    m: u32,
5166    k: u32,
5167    n: u32,
5168) -> MatmulCompute {
5169    use rlx_ir::DType;
5170    let a_dt = graph.node(a_id).shape.dtype();
5171    let b_dt = graph.node(b_id).shape.dtype();
5172    let any_low =
5173        matches!(a_dt, DType::F16 | DType::BF16) || matches!(b_dt, DType::F16 | DType::BF16);
5174    // CoopF32 (`simdgroup_float8x8`) needs K and N aligned to 8 and 32
5175    // (one micro-tile per K-iter, one 32-col workgroup per N-tile).
5176    // M can be arbitrary — the kernel pads to the next multiple of 32
5177    // and bounds-checks the output writes so out-of-range rows stay
5178    // untouched. (The Coop16 / matmul_qkv paths still require m%32==0;
5179    // their kernels don't have the same bounds check.)
5180    let coop16_aligned = m.is_multiple_of(32) && k.is_multiple_of(8) && n.is_multiple_of(32);
5181    let coop_f32_aligned = k.is_multiple_of(8) && n.is_multiple_of(32);
5182    let has_coop = dev
5183        .features()
5184        .contains(wgpu::Features::EXPERIMENTAL_COOPERATIVE_MATRIX);
5185    // Coop16 has an f16 accumulator (Naga 29 can't compile the mixed
5186    // f32-acc / f16-operand form). Sums of 3072 BERT-FFN activations
5187    // overflow f16, so we only enter on F16/BF16 IR tags — AutoMixed
5188    // users have already opted into the precision tradeoff.
5189    if any_low
5190        && has_coop
5191        && dev.features().contains(wgpu::Features::SHADER_F16)
5192        && traces_to_param(graph, b_id)
5193        && coop16_aligned
5194    {
5195        return MatmulCompute::Coop16;
5196    }
5197    // CoopF32 (`simdgroup_float8x8` on Apple): the f32 hardware-GEMM
5198    // path. Used whenever cooperative-matrix is available, B is a
5199    // Param, and shapes align — gives ~5-10× speedup over the
5200    // tiled `matmul_wide` path with no precision loss vs the f32
5201    // baseline (BERT max|Δ| stays at 2.3e-3 vs CPU on Apple).
5202    //
5203    // Backend gate: only Metal validated. On Vulkan/NVIDIA the same
5204    // kernel produces wildly wrong output (BERT max|Δ| 3.4 vs CPU,
5205    // bench 2026-05 on RTX 4090) — naga 29's lowering of
5206    // `coop_mat<f32>` to KHR_cooperative_matrix doesn't agree with
5207    // the simdgroup_float8x8 path on layout or stride. Re-enable on
5208    // Vulkan/DX12 once the path is verified end-to-end. Override
5209    // with RLX_WGPU_FORCE_COOP_F32=1 to bench the broken path.
5210    let disabled = rlx_ir::env::flag("RLX_WGPU_NO_COOP_F32");
5211    let forced = rlx_ir::env::flag("RLX_WGPU_FORCE_COOP_F32");
5212    let backend_ok = forced
5213        || matches!(
5214            crate::device::wgpu_device().map(|d| d.backend),
5215            Some(wgpu::Backend::Metal)
5216        );
5217    if !disabled && backend_ok && has_coop && coop_f32_aligned && traces_to_param(graph, b_id) {
5218        return MatmulCompute::CoopF32;
5219    }
5220    MatmulCompute::F32
5221}
5222
5223/// Detects the BERT-style fused-QKV-then-narrow-then-attention
5224/// pattern. When all three of an attention's Q/K/V inputs are
5225/// `Op::Narrow` of a single source tensor on the last axis with
5226/// sequential offsets `(0, H·D, 2·H·D)` and equal lengths `H·D`,
5227/// returns `Some((qkv_source_node, h_d))` — naming the source
5228/// tensor and per-slice width.
5229///
5230/// EMPIRICAL FINDING: the obvious "skip the narrow + read attention
5231/// directly from QKV with stride 3·H·D" optimization REGRESSED end-
5232/// to-end perf 7-15× on Apple M4 Pro. The narrow's apparent overhead
5233/// (~3 dispatches per attention block, ~150µs at small batch) is
5234/// dwarfed by the cost of strided attention reads — stepping by
5235/// 3·H·D = 4.6 KB between sequence positions defeats the hardware
5236/// prefetcher (prefetch distance maxes around 1-2 KB on M-series).
5237/// Cosine stayed 0.9999+ (output is correct, just slow).
5238///
5239/// Kept as a helper for future smarter fusions — e.g. a coop kernel
5240/// that reads Q/K/V cooperatively from QKV in a single pass over
5241/// the sequence dim, avoiding the random-access stride pattern.
5242#[allow(dead_code)]
5243fn detect_qkv_narrow_pattern(
5244    graph: &Graph,
5245    q_id: NodeId,
5246    k_id: NodeId,
5247    v_id: NodeId,
5248) -> Option<(NodeId, u32)> {
5249    let unwrap_narrow = |id: NodeId| -> Option<(NodeId, usize, usize, usize)> {
5250        let node = graph.node(id);
5251        match &node.op {
5252            Op::Narrow { axis, start, len } => Some((node.inputs[0], *axis, *start, *len)),
5253            _ => None,
5254        }
5255    };
5256    let (q_src, q_axis, q_start, q_len) = unwrap_narrow(q_id)?;
5257    let (k_src, k_axis, k_start, k_len) = unwrap_narrow(k_id)?;
5258    let (v_src, v_axis, v_start, v_len) = unwrap_narrow(v_id)?;
5259    // Same source tensor.
5260    if q_src != k_src || k_src != v_src {
5261        return None;
5262    }
5263    // Equal slice widths (= H · D).
5264    if q_len != k_len || k_len != v_len {
5265        return None;
5266    }
5267    // Sequential offsets 0, H·D, 2·H·D.
5268    if q_start != 0 || k_start != q_len || v_start != q_len * 2 {
5269        return None;
5270    }
5271    // All on the LAST axis of the source.
5272    let src_rank = graph.node(q_src).shape.dims().len();
5273    if q_axis + 1 != src_rank || k_axis + 1 != src_rank || v_axis + 1 != src_rank {
5274        return None;
5275    }
5276    Some((q_src, q_len as u32))
5277}
5278
// NOTE: the paragraph below documents `detect_split_qkv_pattern`
// (defined further down in this file), not the function that
// immediately follows.
//
// Detects the (FusedMatMulBiasAct → Narrow×3) split-QKV pattern that
// shows up at the start of every BERT-style attention block. Returns
// a map `parent_fmb_id → (q_narrow_id, k_narrow_id, v_narrow_id)`
// for every site where the pattern can be replaced by one
// `Step::MatmulQkv` dispatch.
//
// Pattern requirements:
//   - Parent is `Op::FusedMatMulBiasAct { activation: None }` with
//     output shape `[..., 3·head_width]`.
//   - The parent's *only* consumers are exactly 3 `Op::Narrow` nodes,
//     all on the last axis, with offsets `(0, head_width, 2·head_width)`
//     and equal `len = head_width`.
//
// The win is purely structural: same FMA work, but the 3 narrow
// dispatches (and their full-tensor read+write of the QKV intermediate)
// disappear. This differs from the reverted "skip narrow + read
// attention strided" approach because reads from each Q/K/V buffer
// remain sequential — the prefetcher stays happy.
5297/// Detects (`Op::Binary(Add) → Op::LayerNorm`) where the Add has more
5298/// than one consumer in the graph — the case `FuseResidualLN` declines
5299/// because its single-consumer guard would force materializing the sum.
5300///
5301/// Returns:
5302///   - `ln_to_tee`: `ln_id → (h, delta, gamma, beta, sum_id)` so the
5303///     wgpu LayerNorm lowering can emit `Step::FusedResidualLnTee`
5304///     using the existing arena slot for the sum (= the Add's slot).
5305///   - `skip_adds`: the set of Add `NodeId`s whose normal Step emission
5306///     should be suppressed; their output value is written by the tee
5307///     step instead.
5308fn detect_residual_ln_tee_pattern(
5309    graph: &Graph,
5310) -> (
5311    HashMap<NodeId, (NodeId, NodeId, NodeId, NodeId, NodeId)>,
5312    HashSet<NodeId>,
5313) {
5314    use rlx_ir::op::BinaryOp;
5315    // Consumer counts (output references count once each).
5316    let mut consumers: HashMap<NodeId, usize> = HashMap::new();
5317    for node in graph.nodes() {
5318        for &input in &node.inputs {
5319            *consumers.entry(input).or_insert(0) += 1;
5320        }
5321    }
5322    for &out in &graph.outputs {
5323        *consumers.entry(out).or_insert(0) += 1;
5324    }
5325
5326    let mut ln_to_tee = HashMap::new();
5327    let mut skip_adds = HashSet::new();
5328    for node in graph.nodes() {
5329        let Op::LayerNorm { axis: _, eps: _ } = &node.op else {
5330            continue;
5331        };
5332        if node.inputs.len() < 3 {
5333            continue;
5334        } // need [in, gamma, beta]
5335        let in_id = node.inputs[0];
5336        let in_node = graph.node(in_id);
5337        if !matches!(in_node.op, Op::Binary(BinaryOp::Add)) {
5338            continue;
5339        }
5340        // Only fire when Add has >= 2 consumers (otherwise `FuseResidualLN`
5341        // already collapses it into Op::FusedResidualLN upstream).
5342        if consumers.get(&in_id).copied().unwrap_or(0) < 2 {
5343            continue;
5344        }
5345        // Add must be plain — both operands shape-equal to LN's input
5346        // and to each other.
5347        if in_node.inputs.len() != 2 {
5348            continue;
5349        }
5350        let h_id = in_node.inputs[0];
5351        let delta_id = in_node.inputs[1];
5352        if graph.node(h_id).shape.dims() != node.shape.dims() {
5353            continue;
5354        }
5355        if graph.node(delta_id).shape.dims() != node.shape.dims() {
5356            continue;
5357        }
5358        let gamma_id = node.inputs[1];
5359        let beta_id = node.inputs[2];
5360        ln_to_tee.insert(node.id, (h_id, delta_id, gamma_id, beta_id, in_id));
5361        skip_adds.insert(in_id);
5362    }
5363    (ln_to_tee, skip_adds)
5364}
5365
5366fn detect_split_qkv_pattern(graph: &Graph) -> HashMap<NodeId, (NodeId, NodeId, NodeId)> {
5367    // consumers[parent] = list of node ids that read parent
5368    let mut consumers: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
5369    for node in graph.nodes() {
5370        for &input in &node.inputs {
5371            consumers.entry(input).or_default().push(node.id);
5372        }
5373    }
5374    // Output nodes also count as consumers — would prevent QKV elision
5375    // if the matmul output is ever read externally.
5376    for &out_id in &graph.outputs {
5377        consumers.entry(out_id).or_default().push(NodeId(u32::MAX));
5378    }
5379
5380    let mut result = HashMap::new();
5381    for node in graph.nodes() {
5382        if !matches!(node.op, Op::FusedMatMulBiasAct { activation: None }) {
5383            continue;
5384        }
5385        let cs = match consumers.get(&node.id) {
5386            Some(c) if c.len() == 3 => c,
5387            _ => continue,
5388        };
5389        let dims = node.shape.dims();
5390        if dims.is_empty() {
5391            continue;
5392        }
5393        let last_axis = dims.len() - 1;
5394        let n = dims[last_axis].unwrap_static();
5395        if n % 3 != 0 {
5396            continue;
5397        }
5398        let head_width = n / 3;
5399
5400        // Each consumer must be a Narrow on the last axis, len = head_width.
5401        let mut narrows: Vec<(usize, NodeId)> = Vec::with_capacity(3);
5402        let mut all_match = true;
5403        for &c in cs {
5404            let cn = graph.node(c);
5405            match cn.op {
5406                Op::Narrow { axis, start, len }
5407                    if axis == last_axis && len == head_width && cn.inputs[0] == node.id =>
5408                {
5409                    narrows.push((start, c));
5410                }
5411                _ => {
5412                    all_match = false;
5413                    break;
5414                }
5415            }
5416        }
5417        if !all_match {
5418            continue;
5419        }
5420        narrows.sort_by_key(|&(start, _)| start);
5421        if narrows[0].0 != 0 || narrows[1].0 != head_width || narrows[2].0 != 2 * head_width {
5422            continue;
5423        }
5424        result.insert(node.id, (narrows[0].1, narrows[1].1, narrows[2].1));
5425    }
5426    result
5427}
5428
5429/// Walk through Cast/Reshape nodes (which alias the underlying arena
5430/// slot, per `plan_f32_uniform`) to find whether `id` ultimately
5431/// refers to an `Op::Param`. AutoMixedPrecision wraps params in
5432/// Cast(F32→F16) nodes, so a literal `matches!(node.op, Op::Param)`
5433/// check on the matmul's `b_id` would miss the Cast(Param) case.
5434fn traces_to_param(graph: &Graph, mut id: NodeId) -> bool {
5435    loop {
5436        let node = graph.node(id);
5437        match &node.op {
5438            Op::Param { .. } => return true,
5439            Op::Cast { .. } | Op::Reshape { .. } => {
5440                if node.inputs.is_empty() {
5441                    return false;
5442                }
5443                id = node.inputs[0];
5444            }
5445            _ => return false,
5446        }
5447    }
5448}
5449
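// NOTE: this paragraph documents `build_matmul_bind_group` (defined
// further down in this file), not the function that immediately
// follows.
//
// Per-Matmul-step bind group builder. Branches, in priority order
// (branches 1-4 all require b_is_param):
//   1. compute_precision == CoopF32 + coop-f32 kernel present
//        → matmul_coop_f32 (2-binding, f32 arena only)
//   2. compute_precision == Coop16 + f16 buffer + coop kernel present
//        → matmul_coop16 (3-binding, f16 weights)
//   3. compute_precision == F16 + f16 buffer + f16-compute kernel present
//        → matmul_f16_compute (3-binding, f16 ALU)
//   4. legacy `RLX_WGPU_F16_WEIGHTS` env var + f16 buffer + f16w kernel present
//        → matmul_f16w (3-binding, f32 ALU; experimental, see kernel
//         docstring for why this currently regresses perf)
//   5. otherwise → matmul (2-binding, f32 ALU)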
5457/// Append a Coop16 pre-pass: mirrors `arena[off..off+len]` (f32) into
5458/// `arena_f16[off..off+len]` (f16) so the matmul kernel can read A
5459/// as f16. Caller is responsible for guaranteeing the arena has an
5460/// `f16_buffer` (should be true on any SHADER_F16-capable device).
5461///
5462/// Currently unused — superseded by the workgroup-staging path in
5463/// `matmul_coop16.wgsl`. Retained as the right primitive for future
5464/// kernels that operate on a f16-tagged activation region without
5465/// internal staging (e.g. a chain of f16-only ops).
5466#[allow(dead_code)]
5467fn push_cast_f32_to_f16_step(
5468    device: &wgpu::Device,
5469    arena: &Arena,
5470    schedule: &mut Vec<Step>,
5471    uniforms: &mut Vec<wgpu::Buffer>,
5472    bind_groups: &mut Vec<wgpu::BindGroup>,
5473    mm_cast: &Option<&'static Kernel>,
5474    src_off: u32,
5475    len: u32,
5476) {
5477    let kernel = match mm_cast {
5478        Some(k) => *k,
5479        None => return, // device lacks SHADER_F16; fall through, dispatch will skip
5480    };
5481    let f16_buf = match &arena.f16_buffer {
5482        Some(b) => b,
5483        None => return,
5484    };
5485    let p = CastF32ToF16Params {
5486        src_off,
5487        len,
5488        _p0: 0,
5489        _p1: 0,
5490    };
5491    let u = device.create_buffer(&wgpu::BufferDescriptor {
5492        label: Some("rlx-wgpu cast_f32_to_f16 uniform"),
5493        size: std::mem::size_of::<CastF32ToF16Params>() as u64,
5494        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
5495        mapped_at_creation: false,
5496    });
5497    // Write params at compile (kernel doesn't depend on active extent).
5498    let dev = wgpu_device().expect("rlx-wgpu: device gone");
5499    dev.queue.write_buffer(&u, 0, bytemuck::bytes_of(&p));
5500    let bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
5501        label: Some("rlx-wgpu cast_f32_to_f16 bg"),
5502        layout: &kernel.bgl,
5503        entries: &[
5504            wgpu::BindGroupEntry {
5505                binding: 0,
5506                resource: f16_buf.as_entire_binding(),
5507            },
5508            wgpu::BindGroupEntry {
5509                binding: 1,
5510                resource: u.as_entire_binding(),
5511            },
5512            wgpu::BindGroupEntry {
5513                binding: 2,
5514                resource: arena.buffer.as_entire_binding(),
5515            },
5516        ],
5517    });
5518    schedule.push(Step::CastF32ToF16 { params: p });
5519    uniforms.push(u);
5520    bind_groups.push(bg);
5521}
5522
5523fn build_matmul_bind_group(
5524    device: &wgpu::Device,
5525    mm_k: &Kernel,
5526    _mm_w: &Kernel,
5527    mm_f16w: &Option<&'static Kernel>,
5528    mm_f16c: &Option<&'static Kernel>,
5529    mm_coop: &Option<&'static Kernel>,
5530    mm_coop_f32: &Option<&'static Kernel>,
5531    arena: &Arena,
5532    params: &wgpu::Buffer,
5533    b_is_param: bool,
5534    compute_precision: MatmulCompute,
5535) -> wgpu::BindGroup {
5536    if b_is_param
5537        && compute_precision == MatmulCompute::CoopF32
5538        && let Some(coop_f32) = mm_coop_f32
5539    {
5540        // 2-binding layout — both A and B come from the f32 arena
5541        // (no f16 shadow buffer needed for the pure-f32 path).
5542        return device.create_bind_group(&wgpu::BindGroupDescriptor {
5543            label: Some("rlx-wgpu matmul_coop_f32 bg"),
5544            layout: &coop_f32.bgl,
5545            entries: &[
5546                wgpu::BindGroupEntry {
5547                    binding: 0,
5548                    resource: arena.buffer.as_entire_binding(),
5549                },
5550                wgpu::BindGroupEntry {
5551                    binding: 1,
5552                    resource: params.as_entire_binding(),
5553                },
5554            ],
5555        });
5556    }
5557    if b_is_param
5558        && compute_precision == MatmulCompute::Coop16
5559        && let (Some(f16_buf), Some(coop)) = (&arena.f16_buffer, mm_coop)
5560    {
5561        // 3-binding layout — A is staged from arena (f32) through
5562        // workgroup-shared memory inside the kernel, no separate
5563        // f16 binding for A.
5564        return device.create_bind_group(&wgpu::BindGroupDescriptor {
5565            label: Some("rlx-wgpu matmul_coop16 bg"),
5566            layout: &coop.bgl,
5567            entries: &[
5568                wgpu::BindGroupEntry {
5569                    binding: 0,
5570                    resource: arena.buffer.as_entire_binding(),
5571                },
5572                wgpu::BindGroupEntry {
5573                    binding: 1,
5574                    resource: params.as_entire_binding(),
5575                },
5576                wgpu::BindGroupEntry {
5577                    binding: 2,
5578                    resource: f16_buf.as_entire_binding(),
5579                }, // weights
5580            ],
5581        });
5582    }
5583    if b_is_param
5584        && compute_precision == MatmulCompute::F16
5585        && let (Some(f16_buf), Some(f16c)) = (&arena.f16_buffer, mm_f16c)
5586    {
5587        return device.create_bind_group(&wgpu::BindGroupDescriptor {
5588            label: Some("rlx-wgpu matmul_f16_compute bg"),
5589            layout: &f16c.bgl,
5590            entries: &[
5591                wgpu::BindGroupEntry {
5592                    binding: 0,
5593                    resource: arena.buffer.as_entire_binding(),
5594                },
5595                wgpu::BindGroupEntry {
5596                    binding: 1,
5597                    resource: params.as_entire_binding(),
5598                },
5599                wgpu::BindGroupEntry {
5600                    binding: 2,
5601                    resource: f16_buf.as_entire_binding(),
5602                },
5603            ],
5604        });
5605    }
5606    let f16w_opt_in = rlx_ir::env::flag("RLX_WGPU_F16_WEIGHTS");
5607    if b_is_param
5608        && f16w_opt_in
5609        && let (Some(f16_buf), Some(f16w)) = (&arena.f16_buffer, mm_f16w)
5610    {
5611        return device.create_bind_group(&wgpu::BindGroupDescriptor {
5612            label: Some("rlx-wgpu matmul_f16w bg"),
5613            layout: &f16w.bgl,
5614            entries: &[
5615                wgpu::BindGroupEntry {
5616                    binding: 0,
5617                    resource: arena.buffer.as_entire_binding(),
5618                },
5619                wgpu::BindGroupEntry {
5620                    binding: 1,
5621                    resource: params.as_entire_binding(),
5622                },
5623                wgpu::BindGroupEntry {
5624                    binding: 2,
5625                    resource: f16_buf.as_entire_binding(),
5626                },
5627            ],
5628        });
5629    }
5630    bind_two(device, mm_k, &arena.buffer, params)
5631}