Skip to main content

rlx_wgpu/
backend.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! `WgpuExecutable` — compiles an rlx-ir Graph into a sequence of
//! kernel dispatches against a pre-allocated arena buffer.
//!
//! Op coverage began (v2) as MatMul + element-wise families (Binary 7,
//! Unary 12, Compare 6, Where) + leaves and has grown since: the `Step`
//! enum below lists every kernel dispatch and host round-trip a compiled
//! schedule can contain. Anything else panics at compile time.
21
22use std::collections::{HashMap, HashSet};
23use std::num::NonZeroU64;
24
25use rlx_ir::dynamic::{bind_graph, has_dynamic_dims, infer_bindings_from_f32_inputs, same_binding};
26use rlx_ir::op::{Activation, BinaryOp, CmpOp, MaskKind, ReduceOp};
27use rlx_ir::shape::DimBinding;
28use rlx_ir::{Graph, NodeId, Op};
29
30use crate::buffer::{
31    Arena, ReadbackLayout, ReadbackStaging, TinyReadbackStaging, decode_mapped_readback_f32,
32    decode_tiny_mapped_f32, encode_readback_copies, plan_f32_uniform, read_f32_many_pooled,
33    schedule_readback_map, use_tiny_readback, wait_readback_map,
34};
35use crate::device::wgpu_device;
36use crate::kernels::{
37    ArgmaxParams, AttentionBwdParams, AttentionParams, BatchElementwiseRegionParams, BinaryParams,
38    Conv1dParams, Conv2dParams, Conv3dParams, CopyParams, CumsumBwdParams, CumsumParams,
39    DequantMatmulParams, ElementwiseRegionParams, ExpandParams, FusedResidualLnParams,
40    FusedResidualLnTeeParams, FusedResidualRmsNormParams, GatherAxisParams, GatherBwdParams,
41    GatherParams, GroupedMatmulParams, Kernel, LayerNormBwdParams, LayerNormParams, MatmulParams,
42    MatmulQkvParams, NarrowConcatParams, Pool1dParams, Pool2dParams, Pool3dParams, ReduceParams,
43    RmsNormBwdParams, RopeBwdParams, RopeParams, SampleParams, ScatterAddParams,
44    SelectiveScanParams, SoftmaxParams, TopKParams, TransposeParams, UmapKnnParams, UnaryParams,
45    WelchPeaksGpuParams, WhereParams, argmax_kernel, attention_bwd_kernel, attention_kernel,
46    batch_elementwise_region_kernel, binary_kernel, cast_f32_to_f16_kernel, compare_kernel,
47    concat_kernel, conv1d_kernel, conv2d_kernel, conv3d_kernel, copy_kernel,
48    cumsum_backward_kernel, cumsum_kernel, dequant_matmul_kernel, elementwise_region_kernel,
49    elementwise_region_spatial_kernel, expand_kernel, fused_residual_ln_kernel,
50    fused_residual_ln_tee_kernel, fused_residual_rms_norm_kernel, gather_axis_kernel,
51    gather_backward_acc_kernel, gather_backward_zero_kernel, gather_kernel, grouped_matmul_kernel,
52    layer_norm_backward_gamma_partial_kernel, layer_norm_backward_gamma_reduce_kernel,
53    layer_norm_backward_input_kernel, layernorm_kernel, matmul_coop_f16_vulkan_active_kernel,
54    matmul_coop_f16_vulkan_kernel, matmul_coop_f32_active_kernel, matmul_coop16_kernel,
55    matmul_f16_compute_kernel, matmul_f16w_kernel, matmul_kernel,
56    matmul_qkv_coop_f16_vk_active_kernel, matmul_qkv_coop_f16_vk_kernel,
57    matmul_qkv_coop_f32_kernel, matmul_qkv_kernel, matmul_wide_active_kernel, matmul_wide_kernel,
58    narrow_kernel, pool1d_kernel, pool2d_kernel, pool3d_kernel, reduce_kernel,
59    rms_norm_backward_kernel, rms_norm_backward_param_kernel, rope_backward_kernel, rope_kernel,
60    sample_kernel, scatter_add_kernel, selective_scan_kernel, softmax_kernel, topk_kernel,
61    transpose_kernel, umap_knn_kernel, unary_f16_mirror_kernel, unary_kernel,
62    welch_peaks_gpu_kernel, where_kernel,
63};
64/// Compute the maximum tail-scratch bytes any single op needs across
65/// the graph. Currently only `Op::LayerNormBackwardGamma` uses scratch
66/// — it stores `num_workgroups * H` f32 partial sums.
67fn compute_scratch_bytes(graph: &rlx_ir::Graph) -> usize {
68    const ROWS_PER_WG: u32 = 16;
69    let mut max_bytes = 0usize;
70    for node in graph.nodes() {
71        // Norm staging: when params live far from activations in the arena,
72        // wgpu's `max_storage_buffer_binding_size` can prevent binding a
73        // single window that covers both. We reserve a small scratch tail
74        // zone so we can copy gamma/beta next to activations via
75        // `copy_buffer_to_buffer` and keep shader bindings local.
76        if matches!(
77            &node.op,
78            rlx_ir::Op::LayerNorm { .. } | rlx_ir::Op::RmsNorm { .. }
79        ) {
80            let x_shape = &graph.node(node.inputs[0]).shape;
81            let h_dim = x_shape.dim(x_shape.rank() - 1);
82            if h_dim.is_static() {
83                let h = h_dim.unwrap_static();
84                // gamma + beta, 256B-aligned for binding offsets.
85                let bytes = ((h * 4).div_ceil(256) * 256) * 2;
86                if bytes > max_bytes {
87                    max_bytes = bytes;
88                }
89            }
90        }
91        if let rlx_ir::Op::LayerNormBackwardGamma { .. } = &node.op {
92            let x_shape = &graph.node(node.inputs[0]).shape;
93            let Some(elems) = x_shape.num_elements() else {
94                continue;
95            };
96            let h_dim = x_shape.dim(x_shape.rank() - 1);
97            if !h_dim.is_static() {
98                continue;
99            }
100            let h = h_dim.unwrap_static();
101            if h == 0 {
102                continue;
103            }
104            let rows = (elems / h) as u32;
105            let num_workgroups = rows.div_ceil(ROWS_PER_WG.max(1));
106            let bytes = (num_workgroups as usize) * h * 4;
107            if bytes > max_bytes {
108                max_bytes = bytes;
109            }
110        }
111    }
112    // Reserve extra scratch for staging small far-apart operands when the
113    // arena exceeds wgpu's binding window. This keeps compile-time simple
114    // and avoids per-op scratch sizing plumbing.
115    max_bytes.max(64 * 1024 * 1024)
116}
117
/// FNV-1a-style fingerprint of an f32 payload — lets callers skip a redundant
/// `queue.write_buffer` when bench/inference feeds identical input tensors
/// across runs.
///
/// The element count is mixed in first; the payload is then folded 8 bytes
/// (two f32 values) at a time, with a trailing odd element zero-padded to a
/// full word. Each mix step is `h = (h ^ word) * FNV_PRIME` (wrapping).
fn hash_f32_input(data: &[f32]) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf29ce484222325;
    const PRIME: u64 = 0x100000001b3;

    let seed = (OFFSET_BASIS ^ data.len() as u64).wrapping_mul(PRIME);
    data.chunks(2).fold(seed, |acc, pair| {
        // In-memory bytes of up to two floats, zero-filled when only one is left.
        let mut word = [0u8; 8];
        for (slot, value) in word.chunks_exact_mut(4).zip(pair) {
            slot.copy_from_slice(&value.to_ne_bytes());
        }
        (acc ^ u64::from_le_bytes(word)).wrapping_mul(PRIME)
    })
}
133
/// Inner-FMA precision for matmul.
///
/// Stored per step in `Step::Matmul::compute_precision`; each variant names
/// the kernel family the dispatch uses.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MatmulCompute {
    /// Full f32 path (matmul.wgsl / matmul_wide.wgsl).
    F32,
    /// f16 multiply, f32 accumulate (matmul_f16_compute.wgsl).
    F16,
    /// Cooperative-matrix 8×8 hardware GEMM (matmul_coop16.wgsl;
    /// simdgroup_multiply_accumulate on Apple,
    /// OpCooperativeMatrixMulAddKHR on Vulkan).
    ///
    /// Requires M/N/K multiples of 8, b is a Param, and both SHADER_F16 +
    /// EXPERIMENTAL_COOPERATIVE_MATRIX. Caller must ensure A is mirrored
    /// to arena_f16 first (the lowering inserts a `Step::CastF32ToF16`
    /// pre-pass).
    ///
    /// NOTE(review): the `CastF32ToF16Params` docs in this file say the
    /// cast pre-pass is currently unused because the kernel stages A from
    /// the f32 arena directly — confirm which statement is current.
    Coop16,
    /// Cooperative-matrix on Apple's `simdgroup_float8x8` — same hardware
    /// GEMM unit as Coop16 but with f32 operands and f32 accumulator.
    /// No precision loss vs F32 baseline; no f16 overflow risk in deep
    /// FFN sums. Used when alignment + features allow but the IR is f32.
    CoopF32,
    /// Vulkan/NVIDIA 16×16 f16 tensor-core matmul with K-slab f32
    /// reduction (avoids Naga mixed f16/f32 coop_mat bugs).
    CoopF16Vk,
}
158
/// Split-write QKV matmul kernel selection.
///
/// Carried by `Step::MatmulQkv::kind`. The variant names mirror the
/// like-named `MatmulCompute` variants; presumably each selects the
/// corresponding `matmul_qkv*` kernel — confirm against the dispatch code.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MatmulQkvKind {
    /// Plain f32 kernel (presumably `matmul_qkv_kernel`).
    F32,
    /// f32 cooperative-matrix kernel (presumably `matmul_qkv_coop_f32_kernel`).
    CoopF32,
    /// Vulkan f16 cooperative-matrix kernel (presumably
    /// `matmul_qkv_coop_f16_vk_kernel`).
    CoopF16Vk,
}
166
167/// f32 → f16 element-wise cast, mirroring an arena region into the
168/// f16 shadow buffer. Used as a pre-pass before `matmul_coop16` so
169/// the matmul's A operand (a runtime activation, not a Param) is
170/// readable as f16.
171///
172/// Currently unused — the matmul_coop16 kernel stages A through
173/// workgroup-shared memory directly from the f32 arena. Kept for
174/// future paths that may want a one-shot cast (e.g. before a chain
175/// of f16-only kernels operating on a fixed activation region).
176#[allow(dead_code)]
177#[derive(Debug, Clone, Copy)]
178struct CastF32ToF16Params {
179    pub src_off: u32, // f32-element offset into arena (also f16-element offset)
180    pub len: u32,
181    pub _p0: u32,
182    pub _p1: u32,
183}
184unsafe impl bytemuck::Pod for CastF32ToF16Params {}
185unsafe impl bytemuck::Zeroable for CastF32ToF16Params {}
186
/// One dispatch step in the compiled schedule.
///
/// Roughly three shapes of variant appear below: kernel dispatches that
/// carry a `params` struct from `crate::kernels`, host round-trip steps
/// that carry raw arena byte offsets (`*_byte_off`) plus shape scalars,
/// and encoder-level copies (`BufferCopy`).
///
/// `dead_code` is allowed at the enum level: several variants carry
/// fields (mask_buf, meta_idx, compute_precision discriminants) that
/// are only consulted at compile time during bind-group construction,
/// or are kept to extend buffer lifetimes (mask_buf). A few variants
/// (CastF32ToF16, Copy, the unreachable F16 compute_precision) are
/// retained for future paths.
#[allow(dead_code)]
enum Step {
    /// f32 → f16 mirror of an arena region (retained for future paths).
    CastF32ToF16 {
        params: CastF32ToF16Params,
    },
    /// Matmul with optional bias / activation (`has_bias`, `act_id`) and a
    /// `batch` count with per-operand batch strides. Offsets and strides
    /// are named `*_f32`, i.e. presumably counted in f32 elements rather
    /// than bytes — confirm against the kernel.
    Matmul {
        m: u32,
        k: u32,
        n: u32,
        a_off_f32: u32,
        b_off_f32: u32,
        c_off_f32: u32,
        batch: u32,
        a_batch_stride: u32,
        b_batch_stride: u32,
        c_batch_stride: u32,
        has_bias: u32,
        bias_off_f32: u32,
        act_id: u32, // 0xFFFF = no activation
        // True iff input B is a Param node — i.e. a model weight that
        // doesn't change between `run()` calls. Read from the f16
        // shadow buffer (half memory bandwidth) when set + the device
        // exposes SHADER_F16. Set at compile time; consulted only by
        // the dispatch arm.
        b_is_param: bool,
        // Compute precision for the inner FMA. F32 = full precision
        // (the historical / default path). F16 = mixed-precision
        // (operands cast to f16, multiply in f16 for 2× ALU on Apple,
        // accumulator in f32). Set at compile time from the IR's
        // dtype after AutoMixedPrecision policy.
        compute_precision: MatmulCompute,
    },
    /// Element-wise binary op.
    Binary {
        params: BinaryParams,
    },
    /// Element-wise comparison; shares `BinaryParams` with `Binary`.
    Compare {
        params: BinaryParams,
    },
    /// Element-wise unary op. `f16_mirror` presumably selects the
    /// `unary_f16_mirror` kernel variant — confirm at the dispatch site.
    Unary {
        params: UnaryParams,
        f16_mirror: bool,
    },
    Where {
        params: WhereParams,
    },
    Reduce {
        params: ReduceParams,
    },
    Softmax {
        params: SoftmaxParams,
    },
    LayerNorm {
        params: LayerNormParams,
    },
    Cumsum {
        params: CumsumParams,
    },
    /// Native multi-kernel f32 FFT (gpu-fft dispatch strategy).
    FftGpu {
        src_off: u32,
        dst_off: u32,
        outer: u32,
        n: u32,
        inverse: u32,
        norm_scale: f32,
    },
    /// Explicit host FFT (D2H → rlx-cpu → H2D). Used when the native
    /// WGSL kernel cannot handle dtype / size / non-pow-2 constraints.
    FftHost {
        src_byte_off: u32,
        dst_byte_off: u32,
        outer: u32,
        n_complex: u32,
        inverse: bool,
        norm_tag: u32,
        // Encoded element dtype; see `fft_dtype_tag` / `fft_dtype_from_tag`.
        dtype_tag: u32,
    },
    /// Welch PSD top-K — D2H → rlx-cpu → H2D.
    WelchPeaksHost {
        spec_byte_off: u32,
        dst_byte_off: u32,
        welch_batch: u32,
        n_fft: u32,
        n_segments: u32,
        k: u32,
    },
    /// Log-mel step; the `Host` suffix suggests a host round-trip like the
    /// other `*Host` variants — confirm at the dispatch site.
    LogMelHost {
        spec_byte_off: u32,
        filt_byte_off: u32,
        dst_byte_off: u32,
        outer: u32,
        n_fft: u32,
        n_bins: u32,
        n_mels: u32,
    },
    /// Backward counterpart of `LogMelHost`; additionally takes the
    /// upstream gradient at `dy_byte_off`.
    LogMelBackwardHost {
        spec_byte_off: u32,
        filt_byte_off: u32,
        dy_byte_off: u32,
        dst_byte_off: u32,
        outer: u32,
        n_fft: u32,
        n_bins: u32,
        n_mels: u32,
    },
    /// NCHW im2col host path (D2H → rlx-cpu → H2D).
    Im2ColHost {
        x_byte_off: u32,
        col_byte_off: u32,
        n: u32,
        c_in: u32,
        h: u32,
        w: u32,
        h_out: u32,
        w_out: u32,
        kh: u32,
        kw: u32,
        sh: u32,
        sw: u32,
        ph: u32,
        pw: u32,
        dh: u32,
        dw_dil: u32,
    },
    /// Host-side buffer copy (recorded into a command encoder) used to
    /// stage small param tensors into the tail scratch region so kernels
    /// can bind a ≤4GiB window of the arena.
    BufferCopy {
        src_byte_off: u32,
        dst_byte_off: u32,
        bytes: u32,
    },
    /// Kernel-based copy (retained for future paths, per the enum docs).
    Copy {
        params: CopyParams,
    },
    /// PLAN L2 — fused N-ary element-wise region. Lowered from
    /// `Op::ElementwiseRegion` by `MarkElementwiseRegions`. Kernel
    /// interprets the chain encoding per-element (saves N kernel
    /// dispatches + N global-memory round-trips vs the decomposed
    /// atomic ops).
    ElementwiseRegion {
        params: ElementwiseRegionParams,
    },
    BatchElementwiseRegion {
        params: BatchElementwiseRegionParams,
    },
    /// `meta_idx` indexes `WgpuExecutable::meta_buffers`.
    Transpose {
        params: TransposeParams,
        meta_idx: usize,
    },
    Narrow {
        params: NarrowConcatParams,
    },
    Concat {
        params: NarrowConcatParams,
    }, // one Step per input
    Gather {
        params: GatherParams,
    },
    GatherAxis {
        params: GatherAxisParams,
    },
    /// `mask_buf` is held here to keep the mask buffer alive (see the
    /// enum-level docs).
    Attention {
        params: AttentionParams,
        mask_buf: Option<wgpu::Buffer>,
    },
    /// `mask_buf` is held here to keep the mask buffer alive (see the
    /// enum-level docs).
    AttentionBackward {
        params: AttentionBwdParams,
        mask_buf: Option<wgpu::Buffer>,
    },
    Rope {
        params: RopeParams,
    },
    /// `meta_idx` presumably indexes `WgpuExecutable::meta_buffers`, as for
    /// `Transpose` — confirm.
    Expand {
        params: ExpandParams,
        meta_idx: usize,
    },
    Argmax {
        params: ArgmaxParams,
    },
    Pool2d {
        params: Pool2dParams,
    },
    Conv2d {
        params: Conv2dParams,
    },
    Pool1d {
        params: Pool1dParams,
    },
    Pool3d {
        params: Pool3dParams,
    },
    Conv1d {
        params: Conv1dParams,
    },
    Conv3d {
        params: Conv3dParams,
    },
    ScatterAdd {
        params: ScatterAddParams,
    },
    TopK {
        params: TopKParams,
    },
    /// GPU counterpart of `WelchPeaksHost`.
    WelchPeaksGpu {
        params: WelchPeaksGpuParams,
    },
    GroupedMatmul {
        params: GroupedMatmulParams,
    },
    Sample {
        params: SampleParams,
    },
    SelectiveScan {
        params: SelectiveScanParams,
    },
    DequantMatmul {
        params: DequantMatmulParams,
    },
    /// GGUF K-quant — host fused dequant+matmul between GPU segments.
    DequantMatmulGguf {
        m: u32,
        k: u32,
        n: u32,
        scheme_id: u32,
        x_byte_off: u32,
        w_byte_off: u32,
        out_byte_off: u32,
    },
    /// GGUF K-quant — host fused dequant+grouped matmul between GPU segments.
    DequantGroupedMatmulGguf {
        m: u32,
        k: u32,
        n: u32,
        num_experts: u32,
        scheme_id: u32,
        x_byte_off: u32,
        w_byte_off: u32,
        idx_byte_off: u32,
        out_byte_off: u32,
    },
    /// Gated-DeltaNet — host scan between GPU segments (qwen35 linear layers).
    GatedDeltaNet {
        q_byte_off: u32,
        k_byte_off: u32,
        v_byte_off: u32,
        g_byte_off: u32,
        beta_byte_off: u32,
        state_byte_off: u32,
        dst_byte_off: u32,
        batch: u32,
        seq: u32,
        heads: u32,
        state_size: u32,
        use_carry: bool,
    },
    /// NOTE(review): undocumented in the original; the byte-offset fields
    /// suggest another host-side step, and `attrs` is an opaque 20-byte
    /// attribute blob — confirm both at the dispatch site.
    Llada2GroupLimitedGate {
        sig_byte_off: u32,
        route_byte_off: u32,
        out_byte_off: u32,
        n_elems: u32,
        attrs: [u8; 20],
    },
    UmapKnn {
        params: UmapKnnParams,
    },
    /// Small-`n` host k-NN (partial arena read/write; avoids GPU launch overhead).
    UmapKnnHost {
        pairwise_byte_off: u32,
        out_byte_off: u32,
        n: u32,
        k: u32,
    },
    /// 3D Gaussian splat forward (CPU reference between segments).
    #[cfg(feature = "splat")]
    GaussianSplatRender {
        positions_byte_off: u32,
        positions_len: u32,
        scales_byte_off: u32,
        scales_len: u32,
        rotations_byte_off: u32,
        rotations_len: u32,
        opacities_byte_off: u32,
        opacities_len: u32,
        colors_byte_off: u32,
        colors_len: u32,
        sh_coeffs_byte_off: u32,
        sh_coeffs_len: u32,
        meta_byte_off: u32,
        dst_byte_off: u32,
        dst_len: u32,
        width: u32,
        height: u32,
        tile_size: u32,
        radius_scale: f32,
        alpha_cutoff: f32,
        max_splat_steps: u32,
        transmittance_threshold: f32,
        max_list_entries: u32,
    },
    /// Backward splat — host round-trip via rlx-cpu/splat.
    #[cfg(feature = "splat")]
    GaussianSplatRenderBackward {
        positions_byte_off: u32,
        positions_len: u32,
        scales_byte_off: u32,
        scales_len: u32,
        rotations_byte_off: u32,
        rotations_len: u32,
        opacities_byte_off: u32,
        opacities_len: u32,
        colors_byte_off: u32,
        colors_len: u32,
        sh_coeffs_byte_off: u32,
        sh_coeffs_len: u32,
        meta_byte_off: u32,
        d_loss_byte_off: u32,
        d_loss_len: u32,
        packed_byte_off: u32,
        packed_len: u32,
        width: u32,
        height: u32,
        tile_size: u32,
        radius_scale: f32,
        alpha_cutoff: f32,
        max_splat_steps: u32,
        transmittance_threshold: f32,
        max_list_entries: u32,
        loss_grad_clip: f32,
        sh_band: u32,
        max_anisotropy: f32,
    },
    /// Splat "prepare" stage: takes the same per-splat inputs as
    /// `GaussianSplatRender` and writes a `prep` region, which
    /// `GaussianSplatRasterize` takes as input (going by the field names —
    /// confirm at the dispatch site).
    #[cfg(feature = "splat")]
    GaussianSplatPrepare {
        positions_byte_off: u32,
        positions_len: u32,
        scales_byte_off: u32,
        scales_len: u32,
        rotations_byte_off: u32,
        rotations_len: u32,
        opacities_byte_off: u32,
        opacities_len: u32,
        colors_byte_off: u32,
        colors_len: u32,
        sh_coeffs_byte_off: u32,
        sh_coeffs_len: u32,
        meta_byte_off: u32,
        meta_len: u32,
        prep_byte_off: u32,
        prep_len: u32,
        width: u32,
        height: u32,
        tile_size: u32,
        radius_scale: f32,
        alpha_cutoff: f32,
        max_splat_steps: u32,
        transmittance_threshold: f32,
        max_list_entries: u32,
    },
    /// Splat "rasterize" stage: reads a `prep` region and writes `dst`
    /// (going by the field names — confirm at the dispatch site).
    #[cfg(feature = "splat")]
    GaussianSplatRasterize {
        prep_byte_off: u32,
        prep_len: u32,
        meta_byte_off: u32,
        meta_len: u32,
        dst_byte_off: u32,
        dst_len: u32,
        count: u32,
        width: u32,
        height: u32,
        tile_size: u32,
        alpha_cutoff: f32,
        max_splat_steps: u32,
        transmittance_threshold: f32,
        max_list_entries: u32,
    },
    // The three RmsNorm backward steps below share `RmsNormBwdParams`.
    RmsNormBackwardInput {
        params: RmsNormBwdParams,
    },
    RmsNormBackwardGamma {
        params: RmsNormBwdParams,
    },
    RmsNormBackwardBeta {
        params: RmsNormBwdParams,
    },
    LayerNormBackwardInput {
        params: LayerNormBwdParams,
    },
    /// First half of the LayerNorm gamma gradient. The partial sums it
    /// needs (`num_workgroups * H` f32 values) are what
    /// `compute_scratch_bytes` sizes the scratch region for.
    LayerNormBackwardGammaPartial {
        params: LayerNormBwdParams,
        num_workgroups: u32,
    },
    /// Second half of the LayerNorm gamma gradient; presumably reduces the
    /// partials written by `LayerNormBackwardGammaPartial` — confirm.
    LayerNormBackwardGammaReduce {
        params: LayerNormBwdParams,
    },
    RopeBackward {
        params: RopeBwdParams,
    },
    CumsumBackward {
        params: CumsumBwdParams,
    },
    GatherBackward {
        params: GatherBwdParams,
    },
    FusedResidualLn {
        params: FusedResidualLnParams,
    },
    /// Split-write QKV matmul. Replaces a (FusedMatMulBiasAct → Narrow×3)
    /// pattern with one dispatch that writes Q, K, V into separate
    /// contiguous buffers from a single matmul pass. See
    /// `kernels/matmul_qkv.wgsl`.
    MatmulQkv {
        params: MatmulQkvParams,
        kind: MatmulQkvKind,
    },
    /// `fused_residual_ln_tee` — does (Add → LN) but writes the sum to
    /// a separate arena slot (the eliminated Add's old slot). Fires
    /// when the Add has multi-consumer downstream (vision pre-norm).
    FusedResidualLnTee {
        params: FusedResidualLnTeeParams,
    },
    FusedResidualRmsNorm {
        params: FusedResidualRmsNormParams,
    },
}
620
/// A graph compiled for the wgpu backend: the arena, the ordered dispatch
/// schedule, and the per-step GPU resources and caches that `run()` reuses
/// between calls.
pub struct WgpuExecutable {
    /// The graph the schedule was compiled from (the resolved one when the
    /// original had dynamic dims — see `unresolved`).
    graph: Graph,
    /// Pre-allocated arena buffer the schedule dispatches against.
    arena: Arena,
    /// Dispatch steps, in execution order.
    schedule: Vec<Step>,
    // Name → node lookups for inputs and params. Despite the `_offsets`
    // suffix the values are `NodeId`s; presumably resolved to arena
    // offsets at use — confirm.
    input_offsets: HashMap<String, NodeId>,
    param_offsets: HashMap<String, NodeId>,
    /// One uniform buffer + bind group per dispatch step. Pre-allocated
    /// so run() just writes new bytes per step.
    uniforms: Vec<wgpu::Buffer>,
    bind_groups: Vec<wgpu::BindGroup>,
    /// Per-step metadata storage buffers (only Transpose uses them).
    /// Indexed by `Step::Transpose.meta_idx`.
    ///
    /// NOTE(review): `Step::Expand` also carries a `meta_idx`, so the
    /// "only Transpose" remark looks stale — confirm and update.
    meta_buffers: Vec<wgpu::Buffer>,

    // ── Lazy dynamic-shape state ─────────────────────────────────
    /// The originally-supplied graph (pre-resolution). Only set when
    /// the input graph contained `Dim::Dynamic` entries — otherwise
    /// `None` and the compiled fields above are authoritative. On each
    /// `run()` we infer a `DimBinding` from the live input data, and
    /// if it differs from `last_binding` we re-resolve + recompile.
    unresolved: Option<Graph>,
    /// Binding the current schedule was compiled for (compared against
    /// the freshly inferred one on each `run()`, per the note above).
    last_binding: Option<DimBinding>,
    /// Buffered params written via `set_param` / `set_param_bytes`
    /// before the first `run()`. Replayed against the freshly compiled
    /// arena once shapes resolve.
    pending_params: HashMap<String, Vec<f32>>,
    pending_param_bytes: HashMap<String, Vec<u8>>,
    /// Active-extent hint (PLAN L1). When set + every Step in the
    /// safe set, both the uniform write and the dispatch workgroup
    /// count are scaled by `actual / upper`. Otherwise full-extent.
    /// The tuple is presumably `(actual, upper)` — confirm at the setter.
    pub(crate) active_extent: Option<(usize, usize)>,
    /// Skip-redundant-uniform-writes guard. Each `run()` would
    /// otherwise re-`queue.write_buffer` ~115 per-step uniforms (one
    /// per dispatched op in BERT) even when their bytes are identical
    /// to the previous call's. At small batches, that fixed write +
    /// staging-copy overhead is the dominant cost. We track the last
    /// active-extent value the uniforms were written for; subsequent
    /// `run()`s with the same `active_extent` (and `recompile`-clean
    /// schedule) skip the entire uniform-write loop. `None` ⇒ never
    /// written; `Some(x)` ⇒ uniforms hold params for active_extent=x.
    uniforms_active_extent: Option<Option<(usize, usize)>>,
    /// Last-upload fingerprint per input name; skips staging when unchanged.
    /// Fingerprints are `u64`s, presumably from `hash_f32_input`.
    input_staging_hashes: HashMap<String, u64>,
    /// True when the schedule contains CoopF16Vk matmul (disables f32-only
    /// input upload skip — the f16 shadow must stay in sync each run).
    coop_f16_vk: bool,
    /// CoopF16Vk Param B offsets (f32 arena / 4) → param name for wide routing.
    coop_f16_b_param: HashMap<u32, String>,
    /// Param names flagged by the oscillation probe for wide f32 fallback.
    coop_f16_vk_wide_b: HashSet<String>,
    /// Wide f32 bind groups for CoopF16Vk steps (schedule index → bg).
    coop_f16_vk_wide_bind_groups: HashMap<usize, wgpu::BindGroup>,
    /// CoopF16Vk activation operands mirrored on the host each `run()` (f32+f16).
    coop_f16_host_activations: Vec<(NodeId, Activation, String)>,
    /// Last `set_param` f32 payload per name (for host activation mirrors).
    stashed_params: HashMap<String, Vec<f32>>,
    /// Reused output readback staging (avoids per-run buffer alloc).
    readback_staging: Option<ReadbackStaging>,
    /// Persistent tiny readback buffer for single scalar outputs.
    tiny_readback: Option<TinyReadbackStaging>,
    /// Per-`FftGpu` step: isolated uniform buffers + bind groups (one vec entry per op).
    fft_gpu_steps: Vec<crate::fft_dispatch::FftGpuResources>,
    /// Persistent KV inputs (host staging uploaded each run).
    gpu_handles: HashMap<String, Vec<f32>>,
    // NOTE(review): undocumented in the original; maps a handle name to a
    // `usize` — meaning not visible from here, confirm at the use site.
    gpu_handle_feeds: HashMap<String, usize>,
    /// Arena input slots authoritative — skip host KV mirror each decode step.
    gpu_handle_resident: HashSet<String>,
    // NOTE(review): undocumented in the original; presumably output indices
    // queued for a deferred readback — confirm at the use site.
    pending_read_indices: Option<Vec<usize>>,
}
690
691impl Step {
692    /// True when this Step variant honors active-extent dispatch (PLAN L1).
693    /// Coverage: simple element-wise + reductions + matmul + linalg
694    /// + reductions/argmax/topk/sample + gather + conv + pool +
695    /// scatter (zero output + scale num_updates) + macros gated to
696    /// batch=1 (Attention, SelectiveScan).
697    pub fn safe_for_active_extent(&self) -> bool {
698        match self {
699            Step::Binary { .. }
700            | Step::Compare { .. }
701            | Step::Unary { .. }
702            | Step::Where { .. }
703            | Step::Reduce { .. }
704            | Step::Softmax { .. }
705            | Step::LayerNorm { .. }
706            | Step::FusedResidualLn { .. }
707            | Step::FusedResidualLnTee { .. }
708            | Step::FusedResidualRmsNorm { .. }
709            | Step::Cumsum { .. }
710            | Step::Copy { .. }
711            | Step::ElementwiseRegion { .. }
712            | Step::BatchElementwiseRegion { .. }
713            | Step::Argmax { .. }
714            | Step::TopK { .. }
715            | Step::WelchPeaksGpu { .. }
716            | Step::Sample { .. }
717            | Step::Gather { .. }
718            | Step::GatherAxis { .. }
719            | Step::GroupedMatmul { .. }
720            | Step::DequantMatmul { .. }
721            | Step::DequantMatmulGguf { .. }
722            | Step::DequantGroupedMatmulGguf { .. }
723            | Step::GatedDeltaNet { .. }
724            | Step::Llada2GroupLimitedGate { .. }
725            | Step::UmapKnn { .. }
726            | Step::UmapKnnHost { .. }
727            | Step::Conv1d { .. }
728            | Step::Conv2d { .. }
729            | Step::Conv3d { .. }
730            | Step::Pool1d { .. }
731            | Step::Pool2d { .. }
732            | Step::Pool3d { .. }
733            | Step::ScatterAdd { .. }
734            | Step::BufferCopy { .. } => true,
735            // FFT: full-extent transform per row, no active-extent
736            // scaling. Marking true so a graph that mixes FFT with
737            // active-extent-safe ops still gets the optimization for
738            // the rest of the schedule.
739            Step::FftGpu { .. } | Step::FftHost { .. } => true,
740            Step::Im2ColHost { .. }
741            | Step::WelchPeaksHost { .. }
742            | Step::LogMelHost { .. }
743            | Step::LogMelBackwardHost { .. } => true,
744            // Matmul: c_batch_stride is set at compile time at full m,
745            // independent of params.m. With scaled m, threads with
746            // global_row >= m early-return; per-batch output offsets
747            // stay correct. Safe at any batch.
748            Step::Matmul { .. } => true,
749            // Same active-extent reasoning as Matmul: per-batch output
750            // strides are baked at compile time, scaling m only adjusts
751            // the per-thread bound check.
752            Step::MatmulQkv { .. } => true,
753            Step::CastF32ToF16 { .. } => true,
754            // Attention: WGSL kernel uses `seq_q_stride`/`seq_k_stride`
755            // (full extent, set at compile time) for per-(batch, head)
756            // offset math, and `params.seq_q`/`params.seq_k` for loop
757            // bounds only. Scaling seq_q/seq_k shrinks the iteration
758            // without corrupting per-head strides. Safe at any batch.
759            Step::Attention { .. } => true,
760            Step::AttentionBackward { .. } => true,
761            // SelectiveScan: WGSL kernel uses `params.seq_stride`
762            // (full extent, set at compile time) for per-batch stride
763            // math; `params.seq` is the loop bound only. Safe at any
764            // batch under active-extent scaling of seq.
765            Step::SelectiveScan { .. } => true,
766            // Narrow + Concat: kernel iterates `params.total` in
767            // row-major order with outer as the leading dim. Scaling
768            // total by actual/upper effectively scales outer by the
769            // same factor (since total = outer * axis_size * inner).
770            // Output positions past scaled_total stay untouched.
771            // **Conservative assumption**: bucket axis is outer.
772            // Cases where the bucket axis is the narrow/concat axis
773            // itself are unsafe — fall back to full extent there.
774            Step::Narrow { .. } => true,
775            Step::Concat { .. } => true,
776            // Rope: WGSL kernel uses `seq_stride` (full extent, set
777            // at compile time) for per-batch buffer offset math and
778            // explicit `batch` for index decomposition. `params.seq`
779            // and `params.n_total` are runtime-scaled iteration
780            // bounds. Safe at any batch.
781            Step::Rope { .. } => true,
782            // Transpose: precomputed `bucket_outermost` flag in
783            // params (set to 1 at compile time iff `perm[0] == 0`).
784            // Active path scales `out_total` by `actual / upper`
785            // proportional to `out_dim_0`. Other transposes (where
786            // bucket axis moves) fall back to full extent.
787            Step::Transpose { params, .. } => params.bucket_outermost == 1,
788            // Expand: same shape as Transpose. `bucket_outermost` is
789            // 1 iff `in_dims[0] == out_dims[0]` (no broadcast at the
790            // bucket axis).
791            Step::Expand { params, .. } => params.bucket_outermost == 1,
792            // Training backward ops: not used in inference; disable
793            // active-extent fast path until individually audited.
794            Step::RmsNormBackwardInput { .. }
795            | Step::RmsNormBackwardGamma { .. }
796            | Step::RmsNormBackwardBeta { .. }
797            | Step::LayerNormBackwardInput { .. }
798            | Step::LayerNormBackwardGammaPartial { .. }
799            | Step::LayerNormBackwardGammaReduce { .. }
800            | Step::RopeBackward { .. }
801            | Step::CumsumBackward { .. }
802            | Step::GatherBackward { .. } => false,
803            #[cfg(feature = "splat")]
804            Step::GaussianSplatRender { .. }
805            | Step::GaussianSplatRenderBackward { .. }
806            | Step::GaussianSplatPrepare { .. }
807            | Step::GaussianSplatRasterize { .. } => false,
808        }
809    }
810}
811
812/// Static-string label for each Step variant — used by the Perfetto
813/// trace layer (PLAN L3) to mark per-step events without allocating.
814fn fft_dtype_tag(dtype: rlx_ir::DType) -> u32 {
815    match dtype {
816        rlx_ir::DType::F32 => 0,
817        rlx_ir::DType::F64 => 1,
818        rlx_ir::DType::C64 => 2,
819        other => panic!("rlx-wgpu Op::Fft: unsupported dtype {other:?}"),
820    }
821}
822
823fn fft_dtype_from_tag(tag: u32) -> rlx_ir::DType {
824    match tag {
825        0 => rlx_ir::DType::F32,
826        1 => rlx_ir::DType::F64,
827        2 => rlx_ir::DType::C64,
828        other => panic!("rlx-wgpu Op::Fft: bad dtype tag {other}"),
829    }
830}
831
832fn step_name(step: &Step) -> &'static str {
833    match step {
834        Step::CastF32ToF16 { .. } => "cast_f32_to_f16",
835        Step::Matmul { .. } => "matmul",
836        Step::Binary { .. } => "binary",
837        Step::Compare { .. } => "compare",
838        Step::Unary { .. } => "unary",
839        Step::Where { .. } => "where",
840        Step::Reduce { .. } => "reduce",
841        Step::Softmax { .. } => "softmax",
842        Step::LayerNorm { .. } => "layer_norm",
843        Step::Cumsum { .. } => "cumsum",
844        Step::FftGpu { .. } => "fft_gpu",
845        Step::FftHost { .. } => "fft_host",
846        Step::WelchPeaksHost { .. } => "welch_peaks_host",
847        Step::LogMelHost { .. } => "log_mel_host",
848        Step::LogMelBackwardHost { .. } => "log_mel_backward_host",
849        Step::Im2ColHost { .. } => "im2col_host",
850        Step::BufferCopy { .. } => "buffer_copy",
851        Step::Copy { .. } => "copy",
852        Step::Transpose { .. } => "transpose",
853        Step::Narrow { .. } => "narrow",
854        Step::Concat { .. } => "concat",
855        Step::Gather { .. } => "gather",
856        Step::GatherAxis { .. } => "gather_axis",
857        Step::Attention { .. } => "attention",
858        Step::AttentionBackward { .. } => "attention_bwd",
859        Step::Rope { .. } => "rope",
860        Step::Expand { .. } => "expand",
861        Step::Argmax { .. } => "argmax",
862        Step::Pool2d { .. } => "pool2d",
863        Step::Conv2d { .. } => "conv2d",
864        Step::Pool1d { .. } => "pool1d",
865        Step::Pool3d { .. } => "pool3d",
866        Step::Conv1d { .. } => "conv1d",
867        Step::Conv3d { .. } => "conv3d",
868        Step::ScatterAdd { .. } => "scatter_add",
869        Step::TopK { .. } => "topk",
870        Step::WelchPeaksGpu { .. } => "welch_peaks_gpu",
871        Step::GroupedMatmul { .. } => "grouped_matmul",
872        Step::Sample { .. } => "sample",
873        Step::SelectiveScan { .. } => "selective_scan",
874        Step::DequantMatmul { .. } => "dequant_matmul",
875        Step::DequantMatmulGguf { .. } => "dequant_matmul_gguf",
876        Step::DequantGroupedMatmulGguf { .. } => "dequant_grouped_matmul_gguf",
877        Step::GatedDeltaNet { .. } => "gated_delta_net",
878        Step::Llada2GroupLimitedGate { .. } => "llada2_group_limited_gate",
879        Step::UmapKnn { .. } => "umap_knn",
880        Step::UmapKnnHost { .. } => "umap_knn_host",
881        #[cfg(feature = "splat")]
882        Step::GaussianSplatRender { .. } => "gaussian_splat_render",
883        #[cfg(feature = "splat")]
884        Step::GaussianSplatRenderBackward { .. } => "gaussian_splat_render_backward",
885        #[cfg(feature = "splat")]
886        Step::GaussianSplatPrepare { .. } => "gaussian_splat_prepare",
887        #[cfg(feature = "splat")]
888        Step::GaussianSplatRasterize { .. } => "gaussian_splat_rasterize",
889        Step::RmsNormBackwardInput { .. } => "rms_norm_backward_input",
890        Step::RmsNormBackwardGamma { .. } => "rms_norm_backward_gamma",
891        Step::RmsNormBackwardBeta { .. } => "rms_norm_backward_beta",
892        Step::LayerNormBackwardInput { .. } => "layer_norm_backward_input",
893        Step::LayerNormBackwardGammaPartial { .. } => "layer_norm_backward_gamma_partial",
894        Step::LayerNormBackwardGammaReduce { .. } => "layer_norm_backward_gamma_reduce",
895        Step::RopeBackward { .. } => "rope_backward",
896        Step::CumsumBackward { .. } => "cumsum_backward",
897        Step::GatherBackward { .. } => "gather_backward",
898        Step::FusedResidualLn { .. } => "fused_residual_ln",
899        Step::FusedResidualLnTee { .. } => "fused_residual_ln_tee",
900        Step::FusedResidualRmsNorm { .. } => "fused_residual_rms_norm",
901        Step::MatmulQkv { .. } => "matmul_qkv",
902        Step::ElementwiseRegion { .. } => "elementwise_region",
903        Step::BatchElementwiseRegion { .. } => "batch_elementwise_region",
904    }
905}
906
907fn step_is_tail_host(step: &Step) -> bool {
908    matches!(
909        step,
910        Step::WelchPeaksHost { .. } | Step::LogMelHost { .. } | Step::LogMelBackwardHost { .. }
911    )
912}
913
914fn step_runs_on_host(step: &Step) -> bool {
915    match step {
916        Step::DequantMatmulGguf { .. }
917        | Step::DequantGroupedMatmulGguf { .. }
918        | Step::GatedDeltaNet { .. }
919        | Step::Llada2GroupLimitedGate { .. }
920        | Step::UmapKnnHost { .. }
921        | Step::FftHost { .. }
922        | Step::Im2ColHost { .. }
923        | Step::BufferCopy { .. } => true,
924        #[cfg(feature = "splat")]
925        Step::GaussianSplatRender { .. }
926        | Step::GaussianSplatRenderBackward { .. }
927        | Step::GaussianSplatPrepare { .. }
928        | Step::GaussianSplatRasterize { .. } => true,
929        _ => false,
930    }
931}
932
933fn binary_op_id(op: BinaryOp) -> u32 {
934    match op {
935        BinaryOp::Add => 0,
936        BinaryOp::Sub => 1,
937        BinaryOp::Mul => 2,
938        BinaryOp::Div => 3,
939        BinaryOp::Max => 4,
940        BinaryOp::Min => 5,
941        BinaryOp::Pow => 6,
942    }
943}
944
945fn compare_op_id(op: CmpOp) -> u32 {
946    match op {
947        CmpOp::Eq => 0,
948        CmpOp::Ne => 1,
949        CmpOp::Lt => 2,
950        CmpOp::Le => 3,
951        CmpOp::Gt => 4,
952        CmpOp::Ge => 5,
953    }
954}
955
956fn reduce_op_id(op: ReduceOp) -> u32 {
957    match op {
958        ReduceOp::Sum => 0,
959        ReduceOp::Mean => 1,
960        ReduceOp::Max => 2,
961        ReduceOp::Min => 3,
962        ReduceOp::Prod => 4,
963    }
964}
965
966fn activation_op_id(act: Activation) -> u32 {
967    match act {
968        Activation::Relu => 0,
969        Activation::Sigmoid => 1,
970        Activation::Tanh => 2,
971        Activation::Exp => 3,
972        Activation::Log => 4,
973        Activation::Sqrt => 5,
974        Activation::Rsqrt => 6,
975        Activation::Neg => 7,
976        Activation::Abs => 8,
977        Activation::Gelu => 9,
978        Activation::Silu => 10,
979        Activation::GeluApprox => 11,
980        Activation::Round => 12,
981        Activation::Sin => 13,
982        Activation::Cos => 14,
983        Activation::Tan => 15,
984        Activation::Atan => 16,
985    }
986}
987
988impl WgpuExecutable {
989    /// Resolve the deferred graph against bindings inferred from
990    /// `inputs`, recompile the inner state if the bindings changed
991    /// since the last call, and replay any pending params.
992    fn lazy_compile_for_inputs(&mut self, inputs: &[(&str, &[f32])]) {
993        let unresolved = self
994            .unresolved
995            .as_ref()
996            .expect("lazy_compile_for_inputs called without an unresolved graph");
997        let binding = infer_bindings_from_f32_inputs(unresolved, inputs)
998            .expect("rlx-wgpu lazy compile: could not infer DimBinding from inputs");
999
1000        // No-op if shapes haven't changed since the last compile.
1001        if let Some(prev) = &self.last_binding
1002            && same_binding(prev, &binding)
1003        {
1004            return;
1005        }
1006
1007        // Resolve and recompile.
1008        let resolved = bind_graph(unresolved, &binding);
1009        let original = self.unresolved.take();
1010        let pending_params = std::mem::take(&mut self.pending_params);
1011        let pending_bytes = std::mem::take(&mut self.pending_param_bytes);
1012
1013        let fresh = Self::compile_static_inner(resolved);
1014
1015        // Move the freshly-compiled fields into self, preserve the
1016        // unresolved+binding state for the next round.
1017        self.graph = fresh.graph;
1018        self.arena = fresh.arena;
1019        self.schedule = fresh.schedule;
1020        self.input_offsets = fresh.input_offsets;
1021        self.param_offsets = fresh.param_offsets;
1022        self.uniforms = fresh.uniforms;
1023        self.bind_groups = fresh.bind_groups;
1024        self.meta_buffers = fresh.meta_buffers;
1025        self.unresolved = original;
1026        self.last_binding = Some(binding);
1027        // Recompiled — uniforms are now empty buffers; force re-write
1028        // on next run().
1029        self.uniforms_active_extent = None;
1030        self.input_staging_hashes.clear();
1031        self.coop_f16_vk = fresh.coop_f16_vk;
1032        self.coop_f16_b_param = fresh.coop_f16_b_param;
1033        self.coop_f16_vk_wide_bind_groups = fresh.coop_f16_vk_wide_bind_groups;
1034        self.coop_f16_host_activations = fresh.coop_f16_host_activations;
1035
1036        // Replay pending param uploads against the new arena.
1037        for (name, data) in pending_params {
1038            self.set_param(&name, &data);
1039        }
1040        for (name, data) in pending_bytes {
1041            self.set_param_bytes(&name, &data);
1042        }
1043    }
1044
1045    /// Compile against an explicit `DimBinding`. Each `Dim::Dynamic`
1046    /// in the graph that maps to a symbol in `bindings` is replaced
1047    /// with `Dim::Static(size)` before the standard compile runs.
1048    /// Symbols not in the binding stay dynamic — and then `compile`
1049    /// will panic with the usual diagnostic.
1050    pub fn compile_with_bindings(graph: Graph, bindings: &DimBinding) -> Self {
1051        if bindings.is_empty() {
1052            return Self::compile(graph);
1053        }
1054        // Walk the graph and bind every node's shape.
1055        let mut fresh = Graph::new(&graph.name);
1056        for node in graph.nodes() {
1057            let bound = node.shape.bind(bindings);
1058            fresh.add_node(node.op.clone(), node.inputs.clone(), bound);
1059        }
1060        fresh.set_outputs(graph.outputs.clone());
1061        Self::compile(fresh)
1062    }
1063
1064    pub fn compile(graph: Graph) -> Self {
1065        if has_dynamic_dims(&graph) {
1066            return Self::deferred(graph);
1067        }
1068        Self::compile_static_inner(graph)
1069    }
1070
1071    /// Test hook: first `Step::Attention` Q sequence stride (600 = packed QKV).
1072    #[doc(hidden)]
1073    pub fn test_attn_q_seq_stride(&self) -> Option<u32> {
1074        self.schedule.iter().find_map(|s| {
1075            if let Step::Attention { params, .. } = s {
1076                Some(params.q_seq_stride)
1077            } else {
1078                None
1079            }
1080        })
1081    }
1082
1083    /// Test hook: `(q_off, k_off, v_off, q_seq_stride)` for the first attention step.
1084    #[doc(hidden)]
1085    pub fn test_attn_offsets_and_stride(&self) -> Option<(u32, u32, u32, u32)> {
1086        self.schedule.iter().find_map(|s| {
1087            if let Step::Attention { params, .. } = s {
1088                Some((
1089                    params.q_off,
1090                    params.k_off,
1091                    params.v_off,
1092                    params.q_seq_stride,
1093                ))
1094            } else {
1095                None
1096            }
1097        })
1098    }
1099
1100    /// Global arena offset in f32 elements (not bind-window-local).
1101    #[doc(hidden)]
1102    pub fn test_arena_offset_elems(&self, id: NodeId) -> u32 {
1103        (self.arena.offset(id) / 4) as u32
1104    }
1105
1106    /// Compile placeholder for a graph with `Dim::Dynamic` entries.
1107    /// The real compile happens on the first `run()` once input data
1108    /// reveals the symbol → size bindings. Buffered params (set via
1109    /// `set_param` / `set_param_bytes` before run) are replayed.
1110    fn deferred(graph: Graph) -> Self {
1111        let dev = wgpu_device().expect("rlx-wgpu: no compatible adapter found");
1112        // Minimal valid arena buffer. Replaced on first run().
1113        let placeholder = dev.device.create_buffer(&wgpu::BufferDescriptor {
1114            label: Some("rlx-wgpu deferred placeholder"),
1115            size: 16,
1116            usage: wgpu::BufferUsages::STORAGE
1117                | wgpu::BufferUsages::COPY_DST
1118                | wgpu::BufferUsages::COPY_SRC,
1119            mapped_at_creation: false,
1120        });
1121        let arena = Arena {
1122            buffer: placeholder,
1123            f16_buffer: None,
1124            offsets: HashMap::new(),
1125            lens: HashMap::new(),
1126            size: 0,
1127            scratch_off: 0,
1128            scratch_bytes: 0,
1129        };
1130        Self {
1131            graph: graph.clone(),
1132            arena,
1133            schedule: Vec::new(),
1134            input_offsets: HashMap::new(),
1135            param_offsets: HashMap::new(),
1136            uniforms: Vec::new(),
1137            bind_groups: Vec::new(),
1138            meta_buffers: Vec::new(),
1139            unresolved: Some(graph),
1140            last_binding: None,
1141            pending_params: HashMap::new(),
1142            pending_param_bytes: HashMap::new(),
1143            active_extent: None,
1144            uniforms_active_extent: None,
1145            input_staging_hashes: HashMap::new(),
1146            coop_f16_vk: false,
1147            coop_f16_b_param: HashMap::new(),
1148            coop_f16_vk_wide_b: HashSet::new(),
1149            coop_f16_vk_wide_bind_groups: HashMap::new(),
1150            coop_f16_host_activations: Vec::new(),
1151            stashed_params: HashMap::new(),
1152            readback_staging: None,
1153            tiny_readback: None,
1154            fft_gpu_steps: Vec::new(),
1155            gpu_handles: HashMap::new(),
1156            gpu_handle_feeds: HashMap::new(),
1157            gpu_handle_resident: HashSet::new(),
1158            pending_read_indices: None,
1159        }
1160    }
1161
    /// Hint the next `run` to process only the first `actual` rows
    /// along the bucket axis (out of `upper`, the compile extent).
    /// Honored when every Step is in the safe set. See PLAN L1.
    ///
    /// The value is stored as-is with no validation; passing `None`
    /// clears a previously set hint.
    // NOTE(review): the tuple is presumably `(actual, upper)` — confirm
    // against the code that consumes `active_extent`.
    pub fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
        self.active_extent = extent;
    }
1168
1169    fn all_safe_for_active(&self) -> bool {
1170        self.schedule.iter().all(|s| s.safe_for_active_extent())
1171    }
1172
1173    fn compile_static_inner(graph: Graph) -> Self {
1174        let dev = wgpu_device().expect("rlx-wgpu: no compatible adapter found");
1175
1176        // Decompose composed/fused ops (FusedMatMulBiasAct, LoraMatMul,
1177        // FusedAttentionBlock, FusedTransformerLayer, ...) into primitive
1178        // sequences before memory planning so every intermediate gets a
1179        // regular arena slot. CPU/Metal/MLX lower the fused variants
1180        // directly with bespoke kernels; we choose simplicity over peak
1181        // throughput here.
1182        let graph = crate::unfuse::unfuse(graph);
1183
1184        // f32-uniform slots + liveness reuse (pairwise `[n,n]` graphs).
1185        let plan = plan_f32_uniform(&graph, 16);
1186        // Pre-walk to compute the max scratch any single op needs.
1187        // Currently only `Op::LayerNormBackwardGamma` uses scratch
1188        // (`num_workgroups * H * 4` bytes for the partial-sums buffer).
1189        let scratch_bytes = compute_scratch_bytes(&graph);
1190        let mut arena = Arena::from_plan_with_scratch(&dev.device, &plan, scratch_bytes);
1191        // Override slot lengths with the actual elem*4 byte counts so
1192        // readback returns the right element count (slots may be
1193        // padded for alignment).
1194        for node in graph.nodes() {
1195            let elems = node.shape.num_elements().unwrap_or(0);
1196            arena.set_actual_len(node.id, elems * 4);
1197        }
1198
1199        // Initialize Constants directly into the arena.
1200        for node in graph.nodes() {
1201            if let Op::Constant { data } = &node.op
1202                && arena.has(node.id)
1203                && !data.is_empty()
1204            {
1205                let bytes_to_write = data.len().min(arena.len_of(node.id));
1206                dev.queue.write_buffer(
1207                    &arena.buffer,
1208                    arena.offset(node.id) as u64,
1209                    &data[..bytes_to_write],
1210                );
1211            }
1212        }
1213
1214        let mut input_offsets = HashMap::new();
1215        let mut param_offsets = HashMap::new();
1216        for node in graph.nodes() {
1217            match &node.op {
1218                Op::Input { name } => {
1219                    input_offsets.insert(name.clone(), node.id);
1220                }
1221                Op::Param { name } => {
1222                    param_offsets.insert(name.clone(), node.id);
1223                }
1224                _ => {}
1225            }
1226        }
1227
1228        let mm_k = matmul_kernel(&dev.device);
1229        let mm_w = matmul_wide_kernel(&dev.device);
1230        let _mm_w_active = matmul_wide_active_kernel(&dev.device);
1231        let mm_f16w = matmul_f16w_kernel(&dev.device);
1232        let mm_f16c = matmul_f16_compute_kernel(&dev.device);
1233        let mm_coop = matmul_coop16_kernel(&dev.device);
1234        let mm_coop_f32 = matmul_coop_f32_active_kernel(&dev.device);
1235        let mm_cast = cast_f32_to_f16_kernel(&dev.device);
1236        let bk = binary_kernel(&dev.device);
1237        let uk = unary_kernel(&dev.device);
1238        let ck = compare_kernel(&dev.device);
1239        let wk = where_kernel(&dev.device);
1240
1241        let mut schedule = Vec::new();
1242        let mut uniforms = Vec::new();
1243        let mut bind_groups = Vec::new();
1244        let mut fft_gpu_steps: Vec<crate::fft_dispatch::FftGpuResources> = Vec::new();
1245        let mut gguf_host_pad: Option<(wgpu::Buffer, wgpu::BindGroup)> = None;
1246        let mut meta_buffers: Vec<wgpu::Buffer> = Vec::new();
1247        let mut coop_f16_b_param: HashMap<u32, String> = HashMap::new();
1248        let mut coop_f16_vk_wide_bind_groups: HashMap<usize, wgpu::BindGroup> = HashMap::new();
1249        let mm_w_active_compile = matmul_wide_active_kernel(&dev.device);
1250
1251        let coop_f16_vk_mirror_acts = collect_coop_f16_vk_mirror_activations(&graph, &dev.device);
1252
1253        // Detect (FusedMatMulBiasAct → Narrow×3) split-QKV pattern. Returns
1254        // a map parent_node_id → (q_narrow_id, k_narrow_id, v_narrow_id).
1255        // The matmul_qkv kernel collapses the matmul + 3 narrows into one
1256        // dispatch by routing each output column to the right Q/K/V sink.
1257        //
1258        // CRITICAL: only mark a pattern site for elision when the parent
1259        // FMB will actually take the MatmulQkv path (which only fires
1260        // for F32 compute precision). For Coop16/CoopF32-eligible FMBs,
1261        // those kernels write to the FMB's *own* output slot, NOT the
1262        // 3 narrow slots — skipping the narrows would leave Q/K/V
1263        // uninitialized and attention would read garbage. Predict the
1264        // compute precision the FMB will receive; only skip when F32.
1265        let mut qkv_split: HashMap<NodeId, (NodeId, NodeId, NodeId)> = HashMap::new();
1266        for (parent_id, qkv) in detect_split_qkv_pattern(&graph) {
1267            let parent = graph.node(parent_id);
1268            // Mirror the lowering's precision derivation. FMB inputs:
1269            // [a, w, bias]; we need (m, k, n) to query.
1270            let a_id = parent.inputs[0];
1271            let b_id = parent.inputs[1];
1272            let a_dims = graph.node(a_id).shape.dims();
1273            let b_dims = graph.node(b_id).shape.dims();
1274            let out_dims = parent.shape.dims();
1275            let (m, k, n) =
1276                if a_dims.len() >= 2 && b_dims.len() == 2 && out_dims.len() == a_dims.len() {
1277                    let leading: usize = a_dims[..a_dims.len() - 2]
1278                        .iter()
1279                        .map(|d| d.unwrap_static())
1280                        .product();
1281                    let m_inner = a_dims[a_dims.len() - 2].unwrap_static();
1282                    let k_inner = a_dims[a_dims.len() - 1].unwrap_static();
1283                    let n_inner = b_dims[1].unwrap_static();
1284                    ((leading * m_inner) as u32, k_inner as u32, n_inner as u32)
1285                } else if a_dims.len() == 2 && b_dims.len() == 2 {
1286                    (
1287                        a_dims[0].unwrap_static() as u32,
1288                        a_dims[1].unwrap_static() as u32,
1289                        b_dims[1].unwrap_static() as u32,
1290                    )
1291                } else {
1292                    continue; // unusual shape — let the regular FMB path handle
1293                };
1294            let cp = derive_matmul_compute(
1295                &dev.device,
1296                &graph,
1297                &coop_f16_vk_mirror_acts,
1298                a_id,
1299                b_id,
1300                m,
1301                k,
1302                n,
1303            );
1304            // F32 → matmul_qkv. CoopF32 → matmul_qkv_coop_f32. Both write
1305            // Q/K/V into the narrow output slots, so the narrows can be
1306            // elided. Coop16 still falls back to FMB+narrows (kernel
1307            // would need an f16-acc variant; deferred).
1308            if cp == MatmulCompute::F32 || cp == MatmulCompute::CoopF32 {
1309                qkv_split.insert(parent_id, qkv);
1310            }
1311        }
1312        let qkv_skip_narrows: HashSet<NodeId> = qkv_split
1313            .values()
1314            .flat_map(|&(q, k, v)| [q, k, v])
1315            .collect();
1316
1317        // EEG-DINO / packed QKV: FMB → [B,S,3,H,D] → Narrow×3 (axis 2) → Attention.
1318        // Match CPU `compile_thunks` fused_strided_attn: read Q/K/V from the
1319        // packed parent with seq stride 3·H·D instead of materializing narrows.
1320        let mut packed_bshd_attn: HashMap<NodeId, (NodeId, u32)> = HashMap::new();
1321        let mut packed_bshd_skip_narrows: HashSet<NodeId> = HashSet::new();
1322        if !rlx_ir::env::flag("RLX_WGPU_NO_PACKED_BSHD_ATTN") {
1323            for node in graph.nodes() {
1324                let Op::Attention { .. } = &node.op else {
1325                    continue;
1326                };
1327                if node.inputs.len() < 3 {
1328                    continue;
1329                }
1330                if let Some((parent, head_width, narrows)) =
1331                    rlx_ir::detect_packed_bshd_qkv_attention(
1332                        &graph,
1333                        node.inputs[0],
1334                        node.inputs[1],
1335                        node.inputs[2],
1336                    )
1337                {
1338                    packed_bshd_attn.insert(node.id, (parent, head_width as u32));
1339                    for narrow in narrows {
1340                        if rlx_ir::packed_bshd_narrow_elidable(&graph, narrow, node.id) {
1341                            packed_bshd_skip_narrows.insert(narrow);
1342                        }
1343                    }
1344                }
1345            }
1346        }
1347
1348        // Detect (Add → LayerNorm) where Add has multi-consumer downstream.
1349        // The standard `FuseResidualLN` pass declines to fuse these (its
1350        // single-consumer guard forces materializing the sum); we collapse
1351        // them here at the wgpu lowering level via `Step::FusedResidualLnTee`.
1352        // Returns:
1353        //   ln_to_tee: ln_id  → (h, delta, gamma, beta, sum_arena_id)
1354        //   skip_adds: { add_id }  — these Add nodes are computed by the
1355        //                            tee step; their normal Step emission
1356        //                            is suppressed.
1357        let (ln_to_tee, skip_adds) = detect_residual_ln_tee_pattern(&graph);
1358
1359        let mut coop_f16_host_activations: Vec<(NodeId, Activation, String)> = Vec::new();
1360
1361        let emit_uniform = |size: usize| -> wgpu::Buffer {
1362            dev.device.create_buffer(&wgpu::BufferDescriptor {
1363                label: Some("rlx-wgpu uniform"),
1364                size: size as u64,
1365                usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
1366                mapped_at_creation: false,
1367            })
1368        };
1369
1370        for node in graph.nodes() {
1371            // Helpers — capture device + arena into closures isn't
1372            // ergonomic in the loop, so inline the bind-group build
1373            // when each step is emitted below.
1374            let elems = node.shape.num_elements().unwrap_or(0) as u32;
1375            match &node.op {
1376                Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => continue,
1377                Op::MatMul => {
1378                    let a_id = node.inputs[0];
1379                    let b_id = node.inputs[1];
1380                    let a_shape = graph.node(a_id).shape.dims();
1381                    let b_shape = graph.node(b_id).shape.dims();
1382                    let out_shape = node.shape.dims();
1383                    // Three patterns:
1384                    //   • 2D×2D                              → batch=1
1385                    //   • [..,M,K] × [K,N]  (broadcast rhs)  → batch=1, flatten leading into M
1386                    //   • [..,M,K] × [..,K,N] (matched batch)→ batch=prod(leading), per-batch strides
1387                    let (m, k, n, batch, a_bs, b_bs, c_bs) = if a_shape.len() == 2
1388                        && b_shape.len() == 2
1389                        && out_shape.len() == 2
1390                    {
1391                        (
1392                            a_shape[0].unwrap_static() as u32,
1393                            a_shape[1].unwrap_static() as u32,
1394                            b_shape[1].unwrap_static() as u32,
1395                            1u32,
1396                            0u32,
1397                            0u32,
1398                            0u32,
1399                        )
1400                    } else if a_shape.len() >= 2
1401                        && b_shape.len() == 2
1402                        && out_shape.len() == a_shape.len()
1403                    {
1404                        let leading: usize = a_shape[..a_shape.len() - 2]
1405                            .iter()
1406                            .map(|d| d.unwrap_static())
1407                            .product();
1408                        let m_inner = a_shape[a_shape.len() - 2].unwrap_static();
1409                        let k_inner = a_shape[a_shape.len() - 1].unwrap_static();
1410                        let n_inner = b_shape[1].unwrap_static();
1411                        (
1412                            (leading * m_inner) as u32,
1413                            k_inner as u32,
1414                            n_inner as u32,
1415                            1u32,
1416                            0u32,
1417                            0u32,
1418                            0u32,
1419                        )
1420                    } else if a_shape.len() == b_shape.len()
1421                        && a_shape.len() >= 3
1422                        && out_shape.len() == a_shape.len()
1423                    {
1424                        // True batched: leading dims must match.
1425                        let leading_a: Vec<usize> = a_shape[..a_shape.len() - 2]
1426                            .iter()
1427                            .map(|d| d.unwrap_static())
1428                            .collect();
1429                        let leading_b: Vec<usize> = b_shape[..b_shape.len() - 2]
1430                            .iter()
1431                            .map(|d| d.unwrap_static())
1432                            .collect();
1433                        if leading_a != leading_b {
1434                            panic!(
1435                                "rlx-wgpu MatMul: batched shape mismatch \
1436                                    a_leading={leading_a:?} b_leading={leading_b:?}"
1437                            );
1438                        }
1439                        let b_count: usize = leading_a.iter().product();
1440                        let m_inner = a_shape[a_shape.len() - 2].unwrap_static();
1441                        let k_inner = a_shape[a_shape.len() - 1].unwrap_static();
1442                        let n_inner = b_shape[b_shape.len() - 1].unwrap_static();
1443                        (
1444                            m_inner as u32,
1445                            k_inner as u32,
1446                            n_inner as u32,
1447                            b_count as u32,
1448                            (m_inner * k_inner) as u32,
1449                            (k_inner * n_inner) as u32,
1450                            (m_inner * n_inner) as u32,
1451                        )
1452                    } else {
1453                        panic!(
1454                            "rlx-wgpu MatMul: unsupported shapes a={a_shape:?} b={b_shape:?} \
1455                                out={out_shape:?} (supported: 2D×2D, [..,M,K]×[K,N], [..,M,K]×[..,K,N])"
1456                        );
1457                    };
1458                    let b_is_param = tensor_is_graph_param(&graph, &param_offsets, b_id);
1459                    let b_bytes = arena.len_of(b_id) as u64;
1460                    let mut compute_precision = derive_matmul_compute(
1461                        &dev.device,
1462                        &graph,
1463                        &coop_f16_vk_mirror_acts,
1464                        a_id,
1465                        b_id,
1466                        m,
1467                        k,
1468                        n,
1469                    );
1470                    if b_is_param && b_bytes > ARENA_STAGE_CAP && arena.param_fits_f16_mirror(b_id)
1471                    {
1472                        compute_precision = MatmulCompute::F16;
1473                    }
1474                    let (mut base, mut size, param_anchor) = arena_matmul_bind_window(
1475                        &dev.device,
1476                        &arena,
1477                        &graph,
1478                        &param_offsets,
1479                        node.id,
1480                        a_id,
1481                        b_id,
1482                    );
1483                    let max_binding = dev.device.limits().max_storage_buffer_binding_size;
1484                    arena_expand_bind_window(
1485                        &arena,
1486                        &[node.id, a_id, b_id],
1487                        &mut base,
1488                        &mut size,
1489                        max_binding,
1490                    );
1491                    let mut scratch = arena.scratch_off as u64;
1492                    if param_anchor {
1493                        arena_ensure_scratch_in_window(&mut scratch, base, size);
1494                    }
1495                    if b_is_param && b_bytes > ARENA_STAGE_CAP {
1496                        assert!(
1497                            param_anchor && arena_tensor_in_window(&arena, b_id, base, size),
1498                            "rlx-wgpu matmul: large param B {:?} off={} not in window base={base} size={size}",
1499                            b_id,
1500                            arena.offset(b_id),
1501                        );
1502                    }
1503                    let a_off_f32 = arena_off_in_bind_window(
1504                        &graph,
1505                        &param_offsets,
1506                        &dev.device,
1507                        &arena,
1508                        &mut schedule,
1509                        &mut scratch,
1510                        a_id,
1511                        &mut base,
1512                        &mut size,
1513                    );
1514                    let b_off_f32 = if b_is_param
1515                        && b_bytes > ARENA_STAGE_CAP
1516                        && arena_tensor_in_window(&arena, b_id, base, size)
1517                    {
1518                        arena_local_off_f32(&arena, b_id, base)
1519                    } else {
1520                        arena_off_in_bind_window(
1521                            &graph,
1522                            &param_offsets,
1523                            &dev.device,
1524                            &arena,
1525                            &mut schedule,
1526                            &mut scratch,
1527                            b_id,
1528                            &mut base,
1529                            &mut size,
1530                        )
1531                    };
1532                    maybe_push_coop_f16_vk_casts(
1533                        &graph,
1534                        a_id,
1535                        b_id,
1536                        &coop_f16_vk_mirror_acts,
1537                        &dev.device,
1538                        &arena,
1539                        &mut schedule,
1540                        &mut uniforms,
1541                        &mut bind_groups,
1542                        &mm_cast,
1543                        compute_precision,
1544                        a_off_f32,
1545                        m,
1546                        k,
1547                        batch,
1548                        b_off_f32,
1549                        n,
1550                    );
1551                    schedule.push(Step::Matmul {
1552                        m,
1553                        k,
1554                        n,
1555                        batch,
1556                        a_batch_stride: a_bs,
1557                        b_batch_stride: b_bs,
1558                        c_batch_stride: c_bs,
1559                        a_off_f32,
1560                        b_off_f32,
1561                        c_off_f32: arena_local_off_f32(&arena, node.id, base),
1562                        has_bias: 0,
1563                        bias_off_f32: 0,
1564                        act_id: 0xFFFF,
1565                        b_is_param,
1566                        compute_precision,
1567                    });
1568                    let b_off_global = (arena.offset(b_id) / 4) as u32;
1569                    let b_off_bind = if b_is_param
1570                        && matches!(
1571                            compute_precision,
1572                            MatmulCompute::Coop16 | MatmulCompute::CoopF16Vk | MatmulCompute::F16
1573                        ) {
1574                        b_off_global
1575                    } else {
1576                        b_off_f32
1577                    };
1578                    register_coop_f16_vk_b_param(
1579                        &mut coop_f16_b_param,
1580                        &param_offsets,
1581                        b_id,
1582                        b_off_bind,
1583                        compute_precision,
1584                    );
1585                    let u = emit_uniform(std::mem::size_of::<MatmulParams>());
1586                    let (bg, b_off_adj) = build_matmul_bind_group(
1587                        &dev.device,
1588                        mm_k,
1589                        mm_w,
1590                        &mm_f16w,
1591                        &mm_f16c,
1592                        &mm_coop,
1593                        &mm_coop_f32,
1594                        &arena,
1595                        base,
1596                        size,
1597                        &u,
1598                        b_is_param,
1599                        compute_precision,
1600                        k,
1601                        n,
1602                        batch,
1603                        b_off_bind,
1604                        b_bs,
1605                    );
1606                    if let Some(Step::Matmul { b_off_f32, .. }) = schedule.last_mut() {
1607                        *b_off_f32 = b_off_adj;
1608                    }
1609                    uniforms.push(u);
1610                    bind_groups.push(bg);
1611                    if compute_precision == MatmulCompute::CoopF16Vk {
1612                        coop_f16_vk_wide_bind_groups.insert(
1613                            schedule.len() - 1,
1614                            bind_two_buf0_window(
1615                                &dev.device,
1616                                mm_w_active_compile,
1617                                &arena.buffer,
1618                                base,
1619                                size,
1620                                &uniforms[uniforms.len() - 1],
1621                            ),
1622                        );
1623                    }
1624                }
1625                Op::Binary(bop) => {
1626                    // Skip emit when this Add is consumed by a downstream
1627                    // FRLTee — the tee step writes the sum to this node's
1628                    // arena slot directly. Subsequent consumers read the
1629                    // same slot and find correct data.
1630                    if skip_adds.contains(&node.id) {
1631                        continue;
1632                    }
1633                    require_equal_shapes(&graph, &node.inputs, "Binary");
1634                    let a_id = node.inputs[0];
1635                    let b_id = node.inputs[1];
1636                    let win_ids = [node.id, a_id, b_id];
1637                    let max_binding = dev.device.limits().max_storage_buffer_binding_size;
1638                    let fits = arena_span_bytes(&arena, &win_ids) <= max_binding;
1639                    let mut scratch = arena.scratch_off as u64;
1640                    let (mut base, mut size, param_anchor) = arena_multi_op_window(
1641                        &dev.device,
1642                        &arena,
1643                        &graph,
1644                        &param_offsets,
1645                        &mut schedule,
1646                        &mut scratch,
1647                        &win_ids,
1648                    );
1649                    if !fits && !param_anchor {
1650                        base = arena_bind_window_covering_scratch_if_needed(
1651                            &arena, base, size, scratch,
1652                        );
1653                    }
1654                    let a_off = arena_off_in_bind_window(
1655                        &graph,
1656                        &param_offsets,
1657                        &dev.device,
1658                        &arena,
1659                        &mut schedule,
1660                        &mut scratch,
1661                        a_id,
1662                        &mut base,
1663                        &mut size,
1664                    );
1665                    let b_off = arena_off_in_bind_window(
1666                        &graph,
1667                        &param_offsets,
1668                        &dev.device,
1669                        &arena,
1670                        &mut schedule,
1671                        &mut scratch,
1672                        b_id,
1673                        &mut base,
1674                        &mut size,
1675                    );
1676                    let p = BinaryParams {
1677                        n: elems,
1678                        a_off,
1679                        b_off,
1680                        c_off: arena_local_off_f32(&arena, node.id, base),
1681                        op: binary_op_id(*bop),
1682                        _p0: 0,
1683                        _p1: 0,
1684                        _p2: 0,
1685                    };
1686                    schedule.push(Step::Binary { params: p });
1687                    let u = emit_uniform(std::mem::size_of::<BinaryParams>());
1688                    let bg = bind_two_buf0_window(&dev.device, bk, &arena.buffer, base, size, &u);
1689                    uniforms.push(u);
1690                    bind_groups.push(bg);
1691                }
1692                Op::Compare(cop) => {
1693                    require_equal_shapes(&graph, &node.inputs, "Compare");
1694                    let (mut base, size) = arena_window_for_nodes(&dev.device, &arena, &[node.id]);
1695                    let a_id = node.inputs[0];
1696                    let b_id = node.inputs[1];
1697                    let a_src = arena.offset(a_id) as u64;
1698                    let b_src = arena.offset(b_id) as u64;
1699                    let a_len = arena.len_of(a_id) as u64;
1700                    let b_len = arena.len_of(b_id) as u64;
1701                    let a_in = a_src >= base && a_src + a_len <= base + size;
1702                    let b_in = b_src >= base && b_src + b_len <= base + size;
1703                    let a_dst = arena.scratch_off as u64;
1704                    let a_aligned = a_len.div_ceil(256) * 256;
1705                    let b_dst = a_dst + a_aligned;
1706                    if a_dst < base || b_dst + b_len > base + size {
1707                        base = (arena.size as u64).saturating_sub(size);
1708                        base = (base / 256) * 256;
1709                    }
1710                    let a_off = if a_in {
1711                        arena_local_off_f32(&arena, a_id, base)
1712                    } else {
1713                        if a_len > 64 * 1024 * 1024 {
1714                            panic!("rlx-wgpu: Compare staging operand A too large ({a_len} bytes)");
1715                        }
1716                        schedule.push(Step::BufferCopy {
1717                            src_byte_off: a_src as u32,
1718                            dst_byte_off: a_dst as u32,
1719                            bytes: a_len as u32,
1720                        });
1721                        ((a_dst.saturating_sub(base)) / 4) as u32
1722                    };
1723                    let b_off = if b_in {
1724                        arena_local_off_f32(&arena, b_id, base)
1725                    } else {
1726                        if b_len > 64 * 1024 * 1024 {
1727                            panic!("rlx-wgpu: Compare staging operand B too large ({b_len} bytes)");
1728                        }
1729                        schedule.push(Step::BufferCopy {
1730                            src_byte_off: b_src as u32,
1731                            dst_byte_off: b_dst as u32,
1732                            bytes: b_len as u32,
1733                        });
1734                        ((b_dst.saturating_sub(base)) / 4) as u32
1735                    };
1736                    let p = BinaryParams {
1737                        n: elems,
1738                        a_off,
1739                        b_off,
1740                        c_off: arena_local_off_f32(&arena, node.id, base),
1741                        op: compare_op_id(*cop),
1742                        _p0: 0,
1743                        _p1: 0,
1744                        _p2: 0,
1745                    };
1746                    schedule.push(Step::Compare { params: p });
1747                    let u = emit_uniform(std::mem::size_of::<BinaryParams>());
1748                    let bg = bind_two_buf0_window(&dev.device, ck, &arena.buffer, base, size, &u);
1749                    uniforms.push(u);
1750                    bind_groups.push(bg);
1751                }
1752                Op::Activation(act) => {
1753                    if coop_f16_vk_mirror_acts.contains(&node.id) {
1754                        let src_name =
1755                            tensor_host_name(&input_offsets, &param_offsets, node.inputs[0]);
1756                        coop_f16_host_activations.push((node.id, *act, src_name));
1757                        continue;
1758                    }
1759                    let in_id = node.inputs[0];
1760                    let win_ids = [node.id, in_id];
1761                    let max_binding = dev.device.limits().max_storage_buffer_binding_size;
1762                    let fits = arena_span_bytes(&arena, &win_ids) <= max_binding;
1763                    let mut scratch = arena.scratch_off as u64;
1764                    let (mut base, mut size, param_anchor) = arena_multi_op_window(
1765                        &dev.device,
1766                        &arena,
1767                        &graph,
1768                        &param_offsets,
1769                        &mut schedule,
1770                        &mut scratch,
1771                        &win_ids,
1772                    );
1773                    if !fits && !param_anchor {
1774                        base = arena_bind_window_covering_scratch_if_needed(
1775                            &arena, base, size, scratch,
1776                        );
1777                    }
1778                    let in_off = arena_off_in_bind_window(
1779                        &graph,
1780                        &param_offsets,
1781                        &dev.device,
1782                        &arena,
1783                        &mut schedule,
1784                        &mut scratch,
1785                        in_id,
1786                        &mut base,
1787                        &mut size,
1788                    );
1789                    let p = UnaryParams {
1790                        n: elems,
1791                        in_off,
1792                        out_off: arena_local_off_f32(&arena, node.id, base),
1793                        op: activation_op_id(*act),
1794                        _p0: 0,
1795                        _p1: 0,
1796                        _p2: 0,
1797                        _p3: 0,
1798                    };
1799                    schedule.push(Step::Unary {
1800                        params: p,
1801                        f16_mirror: false,
1802                    });
1803                    let u = emit_uniform(std::mem::size_of::<UnaryParams>());
1804                    let bg = bind_two_buf0_window(&dev.device, uk, &arena.buffer, base, size, &u);
1805                    uniforms.push(u);
1806                    bind_groups.push(bg);
1807                }
1808                Op::Where => {
1809                    let (mut base, size) = arena_window_for_nodes(&dev.device, &arena, &[node.id]);
1810                    let cond_id = node.inputs[0];
1811                    let x_id = node.inputs[1];
1812                    let y_id = node.inputs[2];
1813                    let cond_src = arena.offset(cond_id) as u64;
1814                    let x_src = arena.offset(x_id) as u64;
1815                    let y_src = arena.offset(y_id) as u64;
1816                    let cond_len = arena.len_of(cond_id) as u64;
1817                    let x_len = arena.len_of(x_id) as u64;
1818                    let y_len = arena.len_of(y_id) as u64;
1819                    let cond_in = cond_src >= base && cond_src + cond_len <= base + size;
1820                    let x_in = x_src >= base && x_src + x_len <= base + size;
1821                    let y_in = y_src >= base && y_src + y_len <= base + size;
1822                    let cond_dst = arena.scratch_off as u64;
1823                    let cond_aligned = cond_len.div_ceil(256) * 256;
1824                    let x_dst = cond_dst + cond_aligned;
1825                    let x_aligned = x_len.div_ceil(256) * 256;
1826                    let y_dst = x_dst + x_aligned;
1827                    if cond_dst < base || y_dst + y_len > base + size {
1828                        base = (arena.size as u64).saturating_sub(size);
1829                        base = (base / 256) * 256;
1830                    }
1831                    let cond_off = if cond_in {
1832                        arena_local_off_f32(&arena, cond_id, base)
1833                    } else {
1834                        if cond_len > 64 * 1024 * 1024 {
1835                            panic!("rlx-wgpu: Where staging cond too large ({cond_len} bytes)");
1836                        }
1837                        schedule.push(Step::BufferCopy {
1838                            src_byte_off: cond_src as u32,
1839                            dst_byte_off: cond_dst as u32,
1840                            bytes: cond_len as u32,
1841                        });
1842                        ((cond_dst.saturating_sub(base)) / 4) as u32
1843                    };
1844                    let x_off = if x_in {
1845                        arena_local_off_f32(&arena, x_id, base)
1846                    } else {
1847                        if x_len > 64 * 1024 * 1024 {
1848                            panic!("rlx-wgpu: Where staging x too large ({x_len} bytes)");
1849                        }
1850                        schedule.push(Step::BufferCopy {
1851                            src_byte_off: x_src as u32,
1852                            dst_byte_off: x_dst as u32,
1853                            bytes: x_len as u32,
1854                        });
1855                        ((x_dst.saturating_sub(base)) / 4) as u32
1856                    };
1857                    let y_off = if y_in {
1858                        arena_local_off_f32(&arena, y_id, base)
1859                    } else {
1860                        if y_len > 64 * 1024 * 1024 {
1861                            panic!("rlx-wgpu: Where staging y too large ({y_len} bytes)");
1862                        }
1863                        schedule.push(Step::BufferCopy {
1864                            src_byte_off: y_src as u32,
1865                            dst_byte_off: y_dst as u32,
1866                            bytes: y_len as u32,
1867                        });
1868                        ((y_dst.saturating_sub(base)) / 4) as u32
1869                    };
1870                    let p = WhereParams {
1871                        n: elems,
1872                        cond_off,
1873                        x_off,
1874                        y_off,
1875                        out_off: arena_local_off_f32(&arena, node.id, base),
1876                        _p0: 0,
1877                        _p1: 0,
1878                        _p2: 0,
1879                    };
1880                    schedule.push(Step::Where { params: p });
1881                    let u = emit_uniform(std::mem::size_of::<WhereParams>());
1882                    let bg = bind_two_buf0_window(&dev.device, wk, &arena.buffer, base, size, &u);
1883                    uniforms.push(u);
1884                    bind_groups.push(bg);
1885                }
1886
1887                Op::BatchElementwiseRegion {
1888                    chain,
1889                    num_batch_inputs,
1890                    scalar_input_mask,
1891                    input_modulus,
1892                    prologue,
1893                    prologue_input,
1894                } => {
1895                    let n = *num_batch_inputs as usize;
1896                    if n == 0 || chain.len() > 32 {
1897                        panic!(
1898                            "rlx-wgpu BatchElementwiseRegion: num_batch_inputs={n} steps={}",
1899                            chain.len()
1900                        );
1901                    }
1902                    let slice_shape = rlx_ir::batch_region_slice_shape(&node.shape);
1903                    let slice_elems = rlx_ir::batch_region_slice_elems(&node.shape, n)
1904                        .expect("batch region static shape");
1905                    let mut win_ids: Vec<NodeId> = vec![node.id];
1906                    win_ids.extend(node.inputs.iter().copied());
1907                    let max_binding = dev.device.limits().max_storage_buffer_binding_size;
1908                    let fits = arena_span_bytes(&arena, &win_ids) <= max_binding;
1909                    let mut scratch = arena.scratch_off as u64;
1910                    let (mut base, mut size, param_anchor) = arena_multi_op_window(
1911                        &dev.device,
1912                        &arena,
1913                        &graph,
1914                        &param_offsets,
1915                        &mut schedule,
1916                        &mut scratch,
1917                        &win_ids,
1918                    );
1919                    if !fits && !param_anchor {
1920                        base = arena_bind_window_covering_scratch_if_needed(
1921                            &arena, base, size, scratch,
1922                        );
1923                    }
1924                    let chain_enc = rlx_ir::encode_chain_steps(chain);
1925                    let tail =
1926                        rlx_ir::encode_prologue_tail(*prologue, &slice_shape, *prologue_input);
1927                    let base_dst = arena_local_off_f32(&arena, node.id, base);
1928                    let use_single = rlx_ir::fk_batch_use_single_launch(n, *prologue);
1929                    if use_single {
1930                        let mut batch_input_offs = [0u32; 64];
1931                        for i in 0..n {
1932                            batch_input_offs[i] = arena_off_in_bind_window(
1933                                &graph,
1934                                &param_offsets,
1935                                &dev.device,
1936                                &arena,
1937                                &mut schedule,
1938                                &mut scratch,
1939                                node.inputs[i],
1940                                &mut base,
1941                                &mut size,
1942                            );
1943                        }
1944                        let p = BatchElementwiseRegionParams {
1945                            slice_len: slice_elems,
1946                            num_batch: n as u32,
1947                            num_steps: chain.len() as u32,
1948                            base_dst_off: base_dst,
1949                            slice_elems,
1950                            batch_input_offs,
1951                            chain: chain_enc,
1952                            scalar_input_mask: *scalar_input_mask,
1953                            input_modulus: *input_modulus,
1954                        };
1955                        schedule.push(Step::BatchElementwiseRegion { params: p });
1956                        let ek = batch_elementwise_region_kernel(&dev.device);
1957                        let u = dev.device.create_buffer(&wgpu::BufferDescriptor {
1958                            label: Some("rlx-wgpu batch region params"),
1959                            size: std::mem::size_of::<BatchElementwiseRegionParams>() as u64,
1960                            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
1961                            mapped_at_creation: false,
1962                        });
1963                        let bg =
1964                            bind_two_buf0_window(&dev.device, ek, &arena.buffer, base, size, &u);
1965                        uniforms.push(u);
1966                        bind_groups.push(bg);
1967                    } else {
1968                        let spatial = tail[0] == rlx_ir::REGION_PROLOGUE_RESIZE_NEAREST_2X_NCHW;
1969                        let ek = if spatial {
1970                            elementwise_region_spatial_kernel(&dev.device)
1971                        } else {
1972                            elementwise_region_kernel(&dev.device)
1973                        };
1974                        for i in 0..n {
1975                            let mut input_offs = [0u32; 16];
1976                            input_offs[0] = arena_off_in_bind_window(
1977                                &graph,
1978                                &param_offsets,
1979                                &dev.device,
1980                                &arena,
1981                                &mut schedule,
1982                                &mut scratch,
1983                                node.inputs[i],
1984                                &mut base,
1985                                &mut size,
1986                            );
1987                            let p = ElementwiseRegionParams {
1988                                len: slice_elems,
1989                                num_inputs: 1,
1990                                num_steps: chain.len() as u32,
1991                                dst_off: rlx_ir::batch_region_slice_dst_off_f32(
1992                                    base_dst,
1993                                    slice_elems,
1994                                    i,
1995                                ),
1996                                input_offs,
1997                                chain: chain_enc,
1998                                scalar_input_mask: *scalar_input_mask,
1999                                prologue: tail[0],
2000                                out_n: tail[1],
2001                                out_c: tail[2],
2002                                out_h: tail[3],
2003                                out_w: tail[4],
2004                                prologue_input: tail[5],
2005                                input_modulus: *input_modulus,
2006                            };
2007                            schedule.push(Step::ElementwiseRegion { params: p });
2008                            let u = dev.device.create_buffer(&wgpu::BufferDescriptor {
2009                                label: Some("rlx-wgpu batch region params"),
2010                                size: std::mem::size_of::<ElementwiseRegionParams>() as u64,
2011                                usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
2012                                mapped_at_creation: false,
2013                            });
2014                            let bg = bind_two_buf0_window(
2015                                &dev.device,
2016                                ek,
2017                                &arena.buffer,
2018                                base,
2019                                size,
2020                                &u,
2021                            );
2022                            uniforms.push(u);
2023                            bind_groups.push(bg);
2024                        }
2025                    }
2026                }
2027                Op::ElementwiseRegion {
2028                    chain,
2029                    num_inputs,
2030                    scalar_input_mask,
2031                    input_modulus,
2032                    prologue,
2033                    prologue_input,
2034                } => {
2035                    // PLAN L2 native lowering. Encode the chain into a
2036                    // fixed-size u32 buffer; one uniform per region.
2037                    let n = *num_inputs as usize;
2038                    if n > 16 || chain.len() > 32 {
2039                        panic!(
2040                            "rlx-wgpu ElementwiseRegion: chain too large \
2041                                (inputs={n}, steps={}). Caps: 16 / 32. \
2042                                Use UnfuseElementwiseRegions to fall back.",
2043                            chain.len()
2044                        );
2045                    }
2046                    let mut win_ids: Vec<NodeId> = vec![node.id];
2047                    win_ids.extend(node.inputs.iter().copied());
2048                    let max_binding = dev.device.limits().max_storage_buffer_binding_size;
2049                    let fits = arena_span_bytes(&arena, &win_ids) <= max_binding;
2050                    let mut scratch = arena.scratch_off as u64;
2051                    let (mut base, mut size, param_anchor) = arena_multi_op_window(
2052                        &dev.device,
2053                        &arena,
2054                        &graph,
2055                        &param_offsets,
2056                        &mut schedule,
2057                        &mut scratch,
2058                        &win_ids,
2059                    );
2060                    if !fits && !param_anchor {
2061                        base = arena_bind_window_covering_scratch_if_needed(
2062                            &arena, base, size, scratch,
2063                        );
2064                    }
2065                    let mut input_offs = [0u32; 16];
2066                    for (i, &id) in node.inputs.iter().enumerate() {
2067                        input_offs[i] = arena_off_in_bind_window(
2068                            &graph,
2069                            &param_offsets,
2070                            &dev.device,
2071                            &arena,
2072                            &mut schedule,
2073                            &mut scratch,
2074                            id,
2075                            &mut base,
2076                            &mut size,
2077                        );
2078                    }
2079                    let chain_enc = rlx_ir::encode_chain_steps(chain);
2080                    let tail =
2081                        rlx_ir::encode_prologue_tail(*prologue, &node.shape, *prologue_input);
2082                    let p = ElementwiseRegionParams {
2083                        len: elems,
2084                        num_inputs: *num_inputs,
2085                        num_steps: chain.len() as u32,
2086                        dst_off: arena_local_off_f32(&arena, node.id, base),
2087                        input_offs,
2088                        chain: chain_enc,
2089                        scalar_input_mask: *scalar_input_mask,
2090                        prologue: tail[0],
2091                        out_n: tail[1],
2092                        out_c: tail[2],
2093                        out_h: tail[3],
2094                        out_w: tail[4],
2095                        prologue_input: tail[5],
2096                        input_modulus: *input_modulus,
2097                    };
2098                    schedule.push(Step::ElementwiseRegion { params: p });
2099                    let ek = if p.prologue == rlx_ir::REGION_PROLOGUE_RESIZE_NEAREST_2X_NCHW {
2100                        elementwise_region_spatial_kernel(&dev.device)
2101                    } else {
2102                        elementwise_region_kernel(&dev.device)
2103                    };
2104                    // STORAGE (not UNIFORM) — the WGSL params struct
2105                    // contains `array<u32, N>` arrays whose 4-byte
2106                    // stride violates uniform's 16-byte stride rule.
2107                    let u = dev.device.create_buffer(&wgpu::BufferDescriptor {
2108                        label: Some("rlx-wgpu region params"),
2109                        size: std::mem::size_of::<ElementwiseRegionParams>() as u64,
2110                        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
2111                        mapped_at_creation: false,
2112                    });
2113                    let bg = bind_two_buf0_window(&dev.device, ek, &arena.buffer, base, size, &u);
2114                    uniforms.push(u);
2115                    bind_groups.push(bg);
2116                }
2117
2118                Op::Reduce {
2119                    op: rop,
2120                    axes,
2121                    keep_dim: _,
2122                } => {
2123                    // Single-axis reduce OR contiguous multi-axis reduce.
2124                    // The kernel walks the input as `[outer, reduce_dim,
2125                    // inner]` — for contiguous axes [k..k+m], we set
2126                    // `reduce_dim = product(dims[k..k+m])`.
2127                    // Non-contiguous reductions are not yet wired (no
2128                    // model has hit them); transposing into contiguous
2129                    // form first is the future fix.
2130                    let in_id = node.inputs[0];
2131                    let in_shape = graph.node(in_id).shape.dims();
2132                    let mut sorted = axes.clone();
2133                    sorted.sort_unstable();
2134                    let contiguous = sorted.windows(2).all(|w| w[1] == w[0] + 1);
2135                    if !contiguous {
2136                        panic!(
2137                            "rlx-wgpu Reduce: non-contiguous axes not yet wired \
2138                             (got axes={axes:?}, rank={})",
2139                            in_shape.len()
2140                        );
2141                    }
2142                    let ax_first = sorted[0];
2143                    let ax_last = *sorted.last().unwrap();
2144                    let dims_u32: Vec<u32> =
2145                        in_shape.iter().map(|d| d.unwrap_static() as u32).collect();
2146                    let outer: u32 = dims_u32[..ax_first].iter().product();
2147                    let reduce_dim: u32 = dims_u32[ax_first..=ax_last].iter().product();
2148                    let inner: u32 = dims_u32[ax_last + 1..].iter().product();
2149                    let red_ids = [node.id, in_id];
2150                    let max_binding = dev.device.limits().max_storage_buffer_binding_size;
2151                    let red_fits = arena_span_bytes(&arena, &red_ids) <= max_binding;
2152                    let mut scratch = arena.scratch_off as u64;
2153                    let (mut base, mut size, param_anchor) = arena_multi_op_window(
2154                        &dev.device,
2155                        &arena,
2156                        &graph,
2157                        &param_offsets,
2158                        &mut schedule,
2159                        &mut scratch,
2160                        &red_ids,
2161                    );
2162                    if !red_fits && !param_anchor {
2163                        base = arena_bind_window_covering_scratch_if_needed(
2164                            &arena, base, size, scratch,
2165                        );
2166                    }
2167                    let in_off = arena_off_in_bind_window(
2168                        &graph,
2169                        &param_offsets,
2170                        &dev.device,
2171                        &arena,
2172                        &mut schedule,
2173                        &mut scratch,
2174                        in_id,
2175                        &mut base,
2176                        &mut size,
2177                    );
2178                    let p = ReduceParams {
2179                        outer,
2180                        reduce_dim,
2181                        inner,
2182                        in_off,
2183                        out_off: arena_local_off_f32(&arena, node.id, base),
2184                        op: reduce_op_id(*rop),
2185                        _p0: 0,
2186                        _p1: 0,
2187                    };
2188                    schedule.push(Step::Reduce { params: p });
2189                    let rk = reduce_kernel(&dev.device);
2190                    let u = emit_uniform(std::mem::size_of::<ReduceParams>());
2191                    let bg = bind_two_buf0_window(&dev.device, rk, &arena.buffer, base, size, &u);
2192                    uniforms.push(u);
2193                    bind_groups.push(bg);
2194                }
2195
2196                Op::Softmax { axis } => {
2197                    let in_id = node.inputs[0];
2198                    let in_shape = graph.node(in_id).shape.dims();
2199                    let last = (in_shape.len() - 1) as i32;
2200                    if *axis != -1 && *axis != last {
2201                        panic!("rlx-wgpu Softmax: only last-axis wired (got axis={axis})");
2202                    }
2203                    let inner = in_shape[in_shape.len() - 1].unwrap_static() as u32;
2204                    let total: u32 = in_shape.iter().map(|d| d.unwrap_static() as u32).product();
2205                    let outer = total / inner.max(1);
2206                    let sm_ids = [node.id, in_id];
2207                    let max_binding = dev.device.limits().max_storage_buffer_binding_size;
2208                    let sm_fits = arena_span_bytes(&arena, &sm_ids) <= max_binding;
2209                    let mut scratch = arena.scratch_off as u64;
2210                    let (mut base, mut size, param_anchor) = arena_multi_op_window(
2211                        &dev.device,
2212                        &arena,
2213                        &graph,
2214                        &param_offsets,
2215                        &mut schedule,
2216                        &mut scratch,
2217                        &sm_ids,
2218                    );
2219                    if !sm_fits && !param_anchor {
2220                        base = arena_bind_window_covering_scratch_if_needed(
2221                            &arena, base, size, scratch,
2222                        );
2223                    }
2224                    let in_off = arena_off_in_bind_window(
2225                        &graph,
2226                        &param_offsets,
2227                        &dev.device,
2228                        &arena,
2229                        &mut schedule,
2230                        &mut scratch,
2231                        in_id,
2232                        &mut base,
2233                        &mut size,
2234                    );
2235                    let p = SoftmaxParams {
2236                        outer,
2237                        inner,
2238                        in_off,
2239                        out_off: arena_local_off_f32(&arena, node.id, base),
2240                        _p0: 0,
2241                        _p1: 0,
2242                        _p2: 0,
2243                        _p3: 0,
2244                    };
2245                    schedule.push(Step::Softmax { params: p });
2246                    let sk = softmax_kernel(&dev.device);
2247                    let u = emit_uniform(std::mem::size_of::<SoftmaxParams>());
2248                    let bg = bind_two_buf0_window(&dev.device, sk, &arena.buffer, base, size, &u);
2249                    uniforms.push(u);
2250                    bind_groups.push(bg);
2251                }
2252
2253                Op::LayerNorm { axis: _, eps } | Op::RmsNorm { axis: _, eps } => {
2254                    let in_id = node.inputs[0];
2255                    let in_shape = graph.node(in_id).shape.dims();
2256                    let inner = in_shape[in_shape.len() - 1].unwrap_static() as u32;
2257                    let total: u32 = in_shape.iter().map(|d| d.unwrap_static() as u32).product();
2258                    let outer = total / inner.max(1);
2259                    let is_layer_norm = matches!(&node.op, Op::LayerNorm { .. });
2260
2261                    // FRLTee fast path: if this LN is the head of a
2262                    // (multi-consumer Add → LN) pattern, emit one
2263                    // `Step::FusedResidualLnTee` that writes the sum to
2264                    // the eliminated Add's arena slot AND the LN result
2265                    // to this LN's slot. The Add itself is skipped
2266                    // upstream (`skip_adds`).
2267                    if is_layer_norm
2268                        && let Some(&(h_id, delta_id, gamma_id, beta_id, sum_id)) =
2269                            ln_to_tee.get(&node.id)
2270                    {
2271                        let gamma_is_param =
2272                            tensor_is_graph_param(&graph, &param_offsets, gamma_id);
2273                        let gamma_bytes = arena.len_of(gamma_id) as u64;
2274                        let frlt_win: Vec<NodeId> =
2275                            if gamma_is_param && gamma_bytes > ARENA_STAGE_CAP {
2276                                vec![gamma_id, node.id, h_id, delta_id, beta_id, sum_id]
2277                            } else {
2278                                vec![node.id, h_id, delta_id, gamma_id, beta_id, sum_id]
2279                            };
2280                        let mut scratch = arena.scratch_off as u64;
2281                        let (mut base, mut size, param_anchor) = arena_multi_op_window(
2282                            &dev.device,
2283                            &arena,
2284                            &graph,
2285                            &param_offsets,
2286                            &mut schedule,
2287                            &mut scratch,
2288                            &frlt_win,
2289                        );
2290                        if !param_anchor {
2291                            base = arena_bind_window_covering_scratch_if_needed(
2292                                &arena, base, size, scratch,
2293                            );
2294                        }
2295                        let in_off = arena_off_in_bind_window(
2296                            &graph,
2297                            &param_offsets,
2298                            &dev.device,
2299                            &arena,
2300                            &mut schedule,
2301                            &mut scratch,
2302                            h_id,
2303                            &mut base,
2304                            &mut size,
2305                        );
2306                        let residual_off = arena_off_in_bind_window(
2307                            &graph,
2308                            &param_offsets,
2309                            &dev.device,
2310                            &arena,
2311                            &mut schedule,
2312                            &mut scratch,
2313                            delta_id,
2314                            &mut base,
2315                            &mut size,
2316                        );
2317                        let sum_off = arena_off_in_bind_window(
2318                            &graph,
2319                            &param_offsets,
2320                            &dev.device,
2321                            &arena,
2322                            &mut schedule,
2323                            &mut scratch,
2324                            sum_id,
2325                            &mut base,
2326                            &mut size,
2327                        );
2328                        let gamma_off = arena_off_in_bind_window(
2329                            &graph,
2330                            &param_offsets,
2331                            &dev.device,
2332                            &arena,
2333                            &mut schedule,
2334                            &mut scratch,
2335                            gamma_id,
2336                            &mut base,
2337                            &mut size,
2338                        );
2339                        let beta_off = arena_off_in_bind_window(
2340                            &graph,
2341                            &param_offsets,
2342                            &dev.device,
2343                            &arena,
2344                            &mut schedule,
2345                            &mut scratch,
2346                            beta_id,
2347                            &mut base,
2348                            &mut size,
2349                        );
2350                        let p = FusedResidualLnTeeParams {
2351                            outer,
2352                            inner,
2353                            in_off,
2354                            residual_off,
2355                            bias_off: 0, // FRLTee currently no-bias only
2356                            gamma_off,
2357                            beta_off,
2358                            sum_off,
2359                            ln_out_off: arena_local_off_f32(&arena, node.id, base),
2360                            eps_bits: eps.to_bits(),
2361                            has_bias: 0,
2362                            _p0: 0,
2363                        };
2364                        schedule.push(Step::FusedResidualLnTee { params: p });
2365                        let frtk = fused_residual_ln_tee_kernel(&dev.device);
2366                        let u = emit_uniform(std::mem::size_of::<FusedResidualLnTeeParams>());
2367                        let bg =
2368                            bind_two_buf0_window(&dev.device, frtk, &arena.buffer, base, size, &u);
2369                        uniforms.push(u);
2370                        bind_groups.push(bg);
2371                        continue;
2372                    }
2373
2374                    let gamma_id = node.inputs[1];
2375                    // beta is the third input for LayerNorm; RmsNorm
2376                    // ignores it (kernel branch on `op` skips the read).
2377                    let beta_id = if is_layer_norm && node.inputs.len() >= 3 {
2378                        node.inputs[2]
2379                    } else {
2380                        // Use gamma's offset as a benign placeholder;
2381                        // the RmsNorm kernel branch never reads it.
2382                        gamma_id
2383                    };
2384                    let gamma_is_param = tensor_is_graph_param(&graph, &param_offsets, gamma_id);
2385                    let gamma_bytes = arena.len_of(gamma_id) as u64;
2386                    let ln_win: Vec<NodeId> = if gamma_is_param && gamma_bytes > ARENA_STAGE_CAP {
2387                        vec![gamma_id, node.id, in_id]
2388                    } else {
2389                        let mut v = vec![node.id, in_id];
2390                        if gamma_is_param {
2391                            v.push(gamma_id);
2392                        }
2393                        if is_layer_norm {
2394                            v.push(beta_id);
2395                        }
2396                        v
2397                    };
2398                    let max_binding = dev.device.limits().max_storage_buffer_binding_size;
2399                    let ln_fits = arena_span_bytes(&arena, &ln_win) <= max_binding;
2400                    let mut scratch = arena.scratch_off as u64;
2401                    let (mut base, mut size, param_anchor) = arena_multi_op_window(
2402                        &dev.device,
2403                        &arena,
2404                        &graph,
2405                        &param_offsets,
2406                        &mut schedule,
2407                        &mut scratch,
2408                        &ln_win,
2409                    );
2410                    if !ln_fits && !param_anchor {
2411                        base = arena_bind_window_covering_scratch_if_needed(
2412                            &arena, base, size, scratch,
2413                        );
2414                    }
2415                    let in_off = arena_off_in_bind_window(
2416                        &graph,
2417                        &param_offsets,
2418                        &dev.device,
2419                        &arena,
2420                        &mut schedule,
2421                        &mut scratch,
2422                        in_id,
2423                        &mut base,
2424                        &mut size,
2425                    );
2426                    let gamma_off = arena_off_in_bind_window(
2427                        &graph,
2428                        &param_offsets,
2429                        &dev.device,
2430                        &arena,
2431                        &mut schedule,
2432                        &mut scratch,
2433                        gamma_id,
2434                        &mut base,
2435                        &mut size,
2436                    );
2437                    let beta_off = arena_off_in_bind_window(
2438                        &graph,
2439                        &param_offsets,
2440                        &dev.device,
2441                        &arena,
2442                        &mut schedule,
2443                        &mut scratch,
2444                        beta_id,
2445                        &mut base,
2446                        &mut size,
2447                    );
2448                    let p = LayerNormParams {
2449                        outer,
2450                        inner,
2451                        in_off,
2452                        out_off: arena_local_off_f32(&arena, node.id, base),
2453                        gamma_off,
2454                        beta_off,
2455                        eps_bits: eps.to_bits(),
2456                        op: if is_layer_norm { 0 } else { 1 },
2457                    };
2458                    schedule.push(Step::LayerNorm { params: p });
2459                    let lk = layernorm_kernel(&dev.device);
2460                    let u = emit_uniform(std::mem::size_of::<LayerNormParams>());
2461                    let bg = bind_two_buf0_window(&dev.device, lk, &arena.buffer, base, size, &u);
2462                    uniforms.push(u);
2463                    bind_groups.push(bg);
2464                }
2465
2466                Op::Reshape { .. } | Op::Cast { .. } => {
2467                    // No-op: memory planner view-aliased this slot.
2468                }
2469
2470                Op::Transpose { perm } => {
2471                    let in_id = node.inputs[0];
2472                    let in_shape = graph.node(in_id).shape.dims();
2473                    let out_shape = node.shape.dims();
2474                    let rank = perm.len();
2475                    if rank != in_shape.len() || rank != out_shape.len() {
2476                        panic!("rlx-wgpu Transpose: rank mismatch");
2477                    }
2478                    let in_dims: Vec<u32> =
2479                        in_shape.iter().map(|d| d.unwrap_static() as u32).collect();
2480                    let out_dims: Vec<u32> =
2481                        out_shape.iter().map(|d| d.unwrap_static() as u32).collect();
2482                    // Input cumulative strides (row-major).
2483                    let mut in_strides = vec![1u32; rank];
2484                    for i in (0..rank.saturating_sub(1)).rev() {
2485                        in_strides[i] = in_strides[i + 1] * in_dims[i + 1];
2486                    }
2487                    // For each *output* axis i, the corresponding input
2488                    // axis is perm[i] — its stride is in_strides[perm[i]].
2489                    let strides_for_out: Vec<u32> =
2490                        (0..rank).map(|i| in_strides[perm[i]]).collect();
2491
2492                    // Build meta buffer: dims (rank u32s) + strides (rank u32s).
2493                    let mut meta_data: Vec<u32> = Vec::with_capacity(rank * 2);
2494                    meta_data.extend_from_slice(&out_dims);
2495                    meta_data.extend_from_slice(&strides_for_out);
2496                    let meta_buf = dev.device.create_buffer(&wgpu::BufferDescriptor {
2497                        label: Some("rlx-wgpu transpose meta"),
2498                        size: (meta_data.len() * 4).max(4) as u64,
2499                        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
2500                        mapped_at_creation: false,
2501                    });
2502                    dev.queue
2503                        .write_buffer(&meta_buf, 0, bytemuck::cast_slice(&meta_data));
2504                    let meta_idx = meta_buffers.len();
2505                    meta_buffers.push(meta_buf);
2506
2507                    // PLAN L1: precompute "bucket axis stays at out
2508                    // axis 0" flag from perm. When `perm[0] == 0`,
2509                    // active-extent scaling of `out_total` is safe.
2510                    let bucket_outermost = if perm[0] == 0 { 1u32 } else { 0u32 };
2511                    let tr_ids = [node.id, in_id];
2512                    let max_binding = dev.device.limits().max_storage_buffer_binding_size;
2513                    let in_is_param = tensor_is_graph_param(&graph, &param_offsets, in_id);
2514                    let in_bytes = arena.len_of(in_id) as u64;
2515                    let (mut base, mut size) = if in_is_param && in_bytes <= max_binding {
2516                        arena_window_for_nodes(&dev.device, &arena, &[in_id])
2517                    } else if arena_span_bytes(&arena, &tr_ids) <= max_binding {
2518                        arena_window_for_nodes(&dev.device, &arena, &tr_ids)
2519                    } else {
2520                        arena_window_for_nodes(&dev.device, &arena, &[node.id])
2521                    };
2522                    let mut scratch = arena.scratch_off as u64;
2523                    let in_off = arena_off_in_bind_window(
2524                        &graph,
2525                        &param_offsets,
2526                        &dev.device,
2527                        &arena,
2528                        &mut schedule,
2529                        &mut scratch,
2530                        in_id,
2531                        &mut base,
2532                        &mut size,
2533                    );
2534                    let out_off = arena_off_in_bind_window(
2535                        &graph,
2536                        &param_offsets,
2537                        &dev.device,
2538                        &arena,
2539                        &mut schedule,
2540                        &mut scratch,
2541                        node.id,
2542                        &mut base,
2543                        &mut size,
2544                    );
2545                    let p = TransposeParams {
2546                        rank: rank as u32,
2547                        out_total: elems,
2548                        in_off,
2549                        out_off,
2550                        bucket_outermost,
2551                        out_dim_0: out_dims[0],
2552                        _p2: 0,
2553                        _p3: 0,
2554                    };
2555                    schedule.push(Step::Transpose {
2556                        params: p,
2557                        meta_idx,
2558                    });
2559                    let tk = transpose_kernel(&dev.device);
2560                    let u = emit_uniform(std::mem::size_of::<TransposeParams>());
2561                    let bg = dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
2562                        label: Some("rlx-wgpu transpose bg"),
2563                        layout: &tk.bgl,
2564                        entries: &[
2565                            wgpu::BindGroupEntry {
2566                                binding: 0,
2567                                resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
2568                                    buffer: &arena.buffer,
2569                                    offset: base,
2570                                    size: NonZeroU64::new(size),
2571                                }),
2572                            },
2573                            wgpu::BindGroupEntry {
2574                                binding: 1,
2575                                resource: u.as_entire_binding(),
2576                            },
2577                            wgpu::BindGroupEntry {
2578                                binding: 2,
2579                                resource: meta_buffers[meta_idx].as_entire_binding(),
2580                            },
2581                        ],
2582                    });
2583                    uniforms.push(u);
2584                    bind_groups.push(bg);
2585                }
2586
2587                Op::Narrow { axis, start, len } => {
2588                    // Part of a split-QKV pattern: the parent FMB has been
2589                    // (or will be) replaced by Step::MatmulQkv that writes
2590                    // directly into this narrow's arena slot. Skip the
2591                    // narrow's own dispatch.
2592                    if qkv_skip_narrows.contains(&node.id)
2593                        || packed_bshd_skip_narrows.contains(&node.id)
2594                    {
2595                        continue;
2596                    }
2597                    let in_id = node.inputs[0];
2598                    let in_shape = graph.node(in_id).shape.dims();
2599                    let outer: u32 = in_shape[..*axis]
2600                        .iter()
2601                        .map(|d| d.unwrap_static() as u32)
2602                        .product::<u32>()
2603                        .max(1);
2604                    let inner: u32 = in_shape[*axis + 1..]
2605                        .iter()
2606                        .map(|d| d.unwrap_static() as u32)
2607                        .product::<u32>()
2608                        .max(1);
2609                    let axis_in = in_shape[*axis].unwrap_static() as u32;
2610                    let win_ids = [node.id, in_id];
2611                    let max_binding = dev.device.limits().max_storage_buffer_binding_size;
2612                    let fits = arena_span_bytes(&arena, &win_ids) <= max_binding;
2613                    let mut scratch = arena.scratch_off as u64;
2614                    let (mut base, mut size, param_anchor) = arena_multi_op_window(
2615                        &dev.device,
2616                        &arena,
2617                        &graph,
2618                        &param_offsets,
2619                        &mut schedule,
2620                        &mut scratch,
2621                        &win_ids,
2622                    );
2623                    if !fits && !param_anchor {
2624                        base = arena_bind_window_covering_scratch_if_needed(
2625                            &arena, base, size, scratch,
2626                        );
2627                    }
2628                    let in_off = arena_off_in_bind_window(
2629                        &graph,
2630                        &param_offsets,
2631                        &dev.device,
2632                        &arena,
2633                        &mut schedule,
2634                        &mut scratch,
2635                        in_id,
2636                        &mut base,
2637                        &mut size,
2638                    );
2639                    let out_off = arena_off_in_bind_window(
2640                        &graph,
2641                        &param_offsets,
2642                        &dev.device,
2643                        &arena,
2644                        &mut schedule,
2645                        &mut scratch,
2646                        node.id,
2647                        &mut base,
2648                        &mut size,
2649                    );
2650                    let p = NarrowConcatParams {
2651                        total: elems,
2652                        outer,
2653                        inner,
2654                        axis_in_size: axis_in,
2655                        axis_out_size: *len as u32,
2656                        start: *start as u32,
2657                        in_off,
2658                        out_off,
2659                    };
2660                    schedule.push(Step::Narrow { params: p });
2661                    let nk = narrow_kernel(&dev.device);
2662                    let u = emit_uniform(std::mem::size_of::<NarrowConcatParams>());
2663                    let bg = bind_two_buf0_window(&dev.device, nk, &arena.buffer, base, size, &u);
2664                    uniforms.push(u);
2665                    bind_groups.push(bg);
2666                }
2667
2668                Op::Concat { axis } => {
2669                    let out_shape = node.shape.dims();
2670                    let outer: u32 = out_shape[..*axis]
2671                        .iter()
2672                        .map(|d| d.unwrap_static() as u32)
2673                        .product::<u32>()
2674                        .max(1);
2675                    let inner: u32 = out_shape[*axis + 1..]
2676                        .iter()
2677                        .map(|d| d.unwrap_static() as u32)
2678                        .product::<u32>()
2679                        .max(1);
2680                    let axis_out = out_shape[*axis].unwrap_static() as u32;
2681
2682                    let all_ids: Vec<NodeId> = std::iter::once(node.id)
2683                        .chain(node.inputs.iter().copied())
2684                        .collect();
2685                    let max_binding = dev.device.limits().max_storage_buffer_binding_size;
2686                    let fits_all = arena_span_bytes(&arena, &all_ids) <= max_binding;
2687                    let mut scratch = arena.scratch_off as u64;
2688                    let (mut base, mut size, param_anchor) = arena_multi_op_window(
2689                        &dev.device,
2690                        &arena,
2691                        &graph,
2692                        &param_offsets,
2693                        &mut schedule,
2694                        &mut scratch,
2695                        &all_ids,
2696                    );
2697                    arena_expand_bind_window(&arena, &all_ids, &mut base, &mut size, max_binding);
2698                    if !fits_all && !param_anchor {
2699                        base = arena_bind_window_covering_scratch_if_needed(
2700                            &arena, base, size, scratch,
2701                        );
2702                    }
2703                    let out_off = arena_local_off_f32(&arena, node.id, base);
2704
2705                    let mut start_pos: u32 = 0;
2706                    for &in_id in &node.inputs {
2707                        let in_shape = graph.node(in_id).shape.dims();
2708                        let axis_in = in_shape[*axis].unwrap_static() as u32;
2709                        let in_total: u32 =
2710                            in_shape.iter().map(|d| d.unwrap_static() as u32).product();
2711                        let _win_ids = [node.id, in_id];
2712                        let in_off = arena_off_in_bind_window(
2713                            &graph,
2714                            &param_offsets,
2715                            &dev.device,
2716                            &arena,
2717                            &mut schedule,
2718                            &mut scratch,
2719                            in_id,
2720                            &mut base,
2721                            &mut size,
2722                        );
2723                        let p = NarrowConcatParams {
2724                            total: in_total,
2725                            outer,
2726                            inner,
2727                            axis_in_size: axis_in,
2728                            axis_out_size: axis_out,
2729                            start: start_pos,
2730                            in_off,
2731                            out_off,
2732                        };
2733                        schedule.push(Step::Concat { params: p });
2734                        let cck = concat_kernel(&dev.device);
2735                        let u = emit_uniform(std::mem::size_of::<NarrowConcatParams>());
2736                        let bg =
2737                            bind_two_buf0_window(&dev.device, cck, &arena.buffer, base, size, &u);
2738                        uniforms.push(u);
2739                        bind_groups.push(bg);
2740                        start_pos += axis_in;
2741                    }
2742                }
2743
2744                Op::Attention {
2745                    num_heads,
2746                    head_dim,
2747                    mask_kind,
2748                    score_scale: _,
2749                    attn_logit_softcap: _,
2750                } => {
                        // Lowers a fused attention node. The arm derives the
                        // (batch, heads, seq_q, seq_k) geometry and per-tensor
                        // strides, resolves offsets for Q/K/V/out/mask inside one
                        // arena bind window, then pushes a `Step::Attention`
                        // (optionally followed by a `Step::BufferCopy`), one
                        // uniform and one bind group.
                        //
                        // NOTE(review): `score_scale` and `attn_logit_softcap` are
                        // discarded by the pattern above; the scale handed to the
                        // kernel is always 1/sqrt(head_dim). Confirm this is
                        // intentional for ops that carry a non-default value.
2751                    // Inputs 0..=2 are Q, K, V; input 3 is read as the mask when
2752                    // the mask kind maps to code 2 (Custom / Bias) below.
2753                    let q_id = node.inputs[0];
2754                    let k_id = node.inputs[1];
2755                    let v_id = node.inputs[2];
2756                    let q_shape = graph.node(q_id).shape.dims();
2757                    let k_shape = graph.node(k_id).shape.dims();
2758                    // Accept either rank-4 [B, H, S, D] or rank-3 [B*H, S, D]
2759                    // (the latter is what BERT-flavored builders emit). For
2760                    // rank-3 we treat the leading dim as `batch * heads`,
2761                    // setting heads = num_heads from the Op so the kernel's
2762                    // (b, h) indexing folds back to the right offset.
2763                    let h = *num_heads as u32;
2764                    let hd = *head_dim as u32;
2765                    let q_ir = graph.node(q_id).shape.clone();
2766                    let k_ir = graph.node(k_id).shape.clone();
                        // `geom` supplies the rank-4 dimensions and the `bhsd`
                        // flag that is forwarded to `strides_for_shape` further
                        // down. (Exact meaning of `bhsd` lives in rlx_ir —
                        // presumably a layout selector; TODO confirm.)
2767                    let geom = rlx_ir::attention_geom(&q_ir, &k_ir, *num_heads, *head_dim);
2768                    let bhsd = geom.bhsd;
2769                    let (batch, heads, seq_q, seq_k) = match q_shape.len() {
2770                        4 => (
2771                            geom.batch as u32,
2772                            geom.heads as u32,
2773                            geom.seq_q as u32,
2774                            geom.seq_k as u32,
2775                        ),
2776                        3 => {
2777                            // Two rank-3 layouts coexist:
2778                            //   [B, S, H·D] — transpose-elided layout
2779                            //   [B·H, S, D] — canonical compacted layout
2780                            // Distinguish by last-dim: if it equals H·D
2781                            // (the per-token feature width) it's [B, S, H·D];
2782                            // otherwise it's [B·H, S, D].
2783                            let last = q_shape[2].unwrap_static() as u32;
2784                            if last == h * hd {
2785                                // [B, S, H·D]: leading = B, seq = S
2786                                (
2787                                    q_shape[0].unwrap_static() as u32,
2788                                    h,
2789                                    q_shape[1].unwrap_static() as u32,
2790                                    k_shape[1].unwrap_static() as u32,
2791                                )
2792                            } else {
2793                                // [B·H, S, D]: leading must be divisible by H
2794                                let leading = q_shape[0].unwrap_static() as u32;
2795                                if !leading.is_multiple_of(h) {
2796                                    panic!(
2797                                        "rlx-wgpu Attention: rank-3 leading dim {leading} \
2798                                            not divisible by num_heads {h} (and last dim \
2799                                            {last} ≠ H·D = {})",
2800                                        h * hd
2801                                    );
2802                                }
2803                                (
2804                                    leading / h,
2805                                    h,
2806                                    q_shape[1].unwrap_static() as u32,
2807                                    k_shape[1].unwrap_static() as u32,
2808                                )
2809                            }
2810                        }
2811                        other => panic!(
2812                            "rlx-wgpu Attention: only rank-3 / rank-4 Q,K,V \
2813                                         inputs supported (got rank {other})"
2814                        ),
2815                    };
                        // Softmax scale is fixed at 1/sqrt(head_dim).
2816                    let scale = 1.0_f32 / (hd as f32).sqrt();
2817
                        // Kernel mask codes used by this arm: 0 = none, 1 = causal,
                        // 2 = mask tensor taken from input 3 (Custom and Bias share
                        // this code here), 3 = sliding window of width `w`.
                        // `mask_buf` is `None` for every kind, so no separate
                        // host-built mask buffer is attached in this arm.
2818                    let (mask_kind_id, mask_buf, window) = match mask_kind {
2819                        MaskKind::None => (0u32, None, 0u32),
2820                        MaskKind::Causal => (1u32, None, 0u32),
2821                        MaskKind::Custom | MaskKind::Bias => (2u32, None, 0u32),
2822                        MaskKind::SlidingWindow(w) => (3u32, None, *w as u32),
2823                    };
2824
2825                    // Mask address strides. For Custom masks, derive from
2826                    // the mask's IR shape so the kernel can broadcast a
2827                    // [B, S] padding mask without materializing the full
2828                    // [B, H, S_q, S_k] expansion. Other mask kinds use
2829                    // canonical [B, H, S_q, S_k] strides (the kernel's
2830                    // mask_partial computation is harmless when not read).
2831                    struct MStrides {
2832                        b: u32,
2833                        h: u32,
2834                        q: u32,
2835                        k: u32,
2836                    }
2837                    let mask_strides = if mask_kind_id == 2u32 {
2838                        let m_dims = graph.node(node.inputs[3]).shape.dims();
2839                        let dim = |i: usize| m_dims[i].unwrap_static() as u32;
                            // Rank 2/3/4 masks get strides from their own dims (a
                            // zero stride broadcasts along that axis); any other
                            // rank falls back to the canonical dense strides.
2840                        match m_dims.len() {
2841                            2 => MStrides {
2842                                b: dim(1),
2843                                h: 0,
2844                                q: 0,
2845                                k: 1,
2846                            },
2847                            3 => MStrides {
2848                                b: dim(1) * dim(2),
2849                                h: 0,
2850                                q: dim(2),
2851                                k: 1,
2852                            },
2853                            4 => MStrides {
2854                                b: dim(1) * dim(2) * dim(3),
2855                                h: dim(2) * dim(3),
2856                                q: dim(3),
2857                                k: 1,
2858                            },
2859                            _ => MStrides {
2860                                b: heads * seq_q * seq_k,
2861                                h: seq_q * seq_k,
2862                                q: seq_k,
2863                                k: 1,
2864                            },
2865                        }
2866                    } else {
2867                        MStrides {
2868                            b: heads * seq_q * seq_k,
2869                            h: seq_q * seq_k,
2870                            q: seq_k,
2871                            k: 1,
2872                        }
2873                    };
2874
                        // Per-tensor (batch, head, seq) strides come from
                        // `rlx_ir::strides_for_shape`, parameterised by the
                        // geometry computed above.
2875                    let stride = |shape: &[rlx_ir::shape::Dim], seq_extent: u32| {
2876                        rlx_ir::strides_for_shape(shape, heads, hd, seq_extent, bhsd)
2877                    };
                        // When this node has an entry in `packed_bshd_attn`, Q, K
                        // and V all use the same strides from
                        // `packed_bshd_qkv_strides` and are addressed through the
                        // recorded parent tensor; otherwise each tensor's strides
                        // come from its own shape.
2878                    let packed_parent = packed_bshd_attn.get(&node.id).copied();
2879                    let (q_b, q_h, q_s, k_b, k_h, k_s, v_b, v_h, v_s) =
2880                        if let Some((_parent, head_width)) = packed_parent {
2881                            let (batch_stride, head_stride, pack_seq) =
2882                                rlx_ir::packed_bshd_qkv_strides(head_width as usize, hd, seq_q);
2883                            (
2884                                batch_stride,
2885                                head_stride,
2886                                pack_seq,
2887                                batch_stride,
2888                                head_stride,
2889                                pack_seq,
2890                                batch_stride,
2891                                head_stride,
2892                                pack_seq,
2893                            )
2894                        } else {
2895                            let (qb, qh, qs) = stride(q_shape, seq_q);
2896                            let (kb, kh, ks) = stride(k_shape, seq_k);
2897                            let v_shape = graph.node(v_id).shape.dims();
2898                            let (vb, vh, vs) = stride(v_shape, seq_k);
2899                            (qb, qh, qs, kb, kh, ks, vb, vh, vs)
2900                        };
2901                    let out_shape = node.shape.dims();
2902                    let (o_b, o_h, o_s) = stride(out_shape, seq_q);
                        // Tensor ids handed to `arena_multi_op_window`: the output
                        // plus either the packed parent or Q/K/V, and the mask
                        // tensor when one is read.
2903                    let mut attn_ids = if let Some((parent, _)) = packed_parent {
2904                        vec![node.id, parent]
2905                    } else {
2906                        vec![node.id, q_id, k_id, v_id]
2907                    };
2908                    if mask_kind_id == 2 {
2909                        attn_ids.push(node.inputs[3]);
2910                    }
2911                    let mut scratch = arena.scratch_off as u64;
2912                    let (mut base, mut size, param_anchor) = arena_multi_op_window(
2913                        &dev.device,
2914                        &arena,
2915                        &graph,
2916                        &param_offsets,
2917                        &mut schedule,
2918                        &mut scratch,
2919                        &attn_ids,
2920                    );
2921                    if !param_anchor {
2922                        base = arena_bind_window_covering_scratch_if_needed(
2923                            &arena, base, size, scratch,
2924                        );
2925                    }
                        // Resolve Q/K/V offsets through the bind window. In the
                        // packed case K and V are placed `head_width` and
                        // `2 * head_width` after Q's offset inside the parent
                        // (saturating adds, same unit as the helper's return
                        // value — presumably f32 elements; TODO confirm).
2926                    let (q_off, k_off, v_off) = if let Some((parent, head_width)) = packed_parent {
2927                        let parent_off = arena_off_in_bind_window(
2928                            &graph,
2929                            &param_offsets,
2930                            &dev.device,
2931                            &arena,
2932                            &mut schedule,
2933                            &mut scratch,
2934                            parent,
2935                            &mut base,
2936                            &mut size,
2937                        );
2938                        (
2939                            parent_off,
2940                            parent_off.saturating_add(head_width),
2941                            parent_off.saturating_add(head_width * 2),
2942                        )
2943                    } else {
2944                        let q_off = arena_off_in_bind_window(
2945                            &graph,
2946                            &param_offsets,
2947                            &dev.device,
2948                            &arena,
2949                            &mut schedule,
2950                            &mut scratch,
2951                            q_id,
2952                            &mut base,
2953                            &mut size,
2954                        );
2955                        let k_off = arena_off_in_bind_window(
2956                            &graph,
2957                            &param_offsets,
2958                            &dev.device,
2959                            &arena,
2960                            &mut schedule,
2961                            &mut scratch,
2962                            k_id,
2963                            &mut base,
2964                            &mut size,
2965                        );
2966                        let v_off = arena_off_in_bind_window(
2967                            &graph,
2968                            &param_offsets,
2969                            &dev.device,
2970                            &arena,
2971                            &mut schedule,
2972                            &mut scratch,
2973                            v_id,
2974                            &mut base,
2975                            &mut size,
2976                        );
2977                        (q_off, k_off, v_off)
2978                    };
                        // Absolute byte offset / length of the output tensor; used
                        // for the debug print and as the BufferCopy destination.
2979                    let out_byte = arena.offset(node.id) as u64;
2980                    let out_len = arena.len_of(node.id) as u64;
                        // True when the output's arena range overlaps Q, K, V or
                        // the packed parent; the kernel must then not write the
                        // output in place.
2981                    let out_aliases_qkv = arena_tensors_overlap(&arena, node.id, q_id)
2982                        || arena_tensors_overlap(&arena, node.id, k_id)
2983                        || arena_tensors_overlap(&arena, node.id, v_id)
2984                        || packed_parent.is_some_and(|(parent, _)| {
2985                            arena_tensors_overlap(&arena, node.id, parent)
2986                        });
2987                    let mut kernel_out_off = arena_off_in_bind_window(
2988                        &graph,
2989                        &param_offsets,
2990                        &dev.device,
2991                        &arena,
2992                        &mut schedule,
2993                        &mut scratch,
2994                        node.id,
2995                        &mut base,
2996                        &mut size,
2997                    );
2998                    let mut attn_scratch_copy: Option<(u64, u32)> = None;
                        // Optional diagnostics, gated on the
                        // RLX_WGPU_DEBUG_ATTN_ALIAS flag.
2999                    if out_aliases_qkv && rlx_ir::env::flag("RLX_WGPU_DEBUG_ATTN_ALIAS") {
3000                        eprintln!(
3001                            "rlx-wgpu Attention alias: out={:?}@{}+{} q={:?}@{} k={:?}@{} v={:?}@{}",
3002                            node.id,
3003                            out_byte,
3004                            out_len,
3005                            q_id,
3006                            arena.offset(q_id),
3007                            k_id,
3008                            arena.offset(k_id),
3009                            v_id,
3010                            arena.offset(v_id),
3011                        );
3012                    }
3013                    if out_aliases_qkv {
                            // Redirect the kernel's output to a scratch region
                            // (size rounded up to a multiple of 256 bytes) and
                            // remember to copy it to the real output afterwards.
                            //
                            // NOTE(review): `base` may be reassigned below, and
                            // `kernel_out_off` is recomputed against it, but
                            // q_off/k_off/v_off were resolved earlier and are not
                            // refreshed. Verify the helper never moves `base` in
                            // a way that invalidates those offsets.
3014                        let tmp_byte = scratch;
3015                        let tmp_aligned = out_len.div_ceil(256) * 256;
3016                        scratch = scratch.saturating_add(tmp_aligned);
3017                        if param_anchor {
3018                            arena_ensure_scratch_in_window(&mut scratch, base, size);
3019                        } else {
3020                            base = arena_bind_window_covering_scratch_if_needed(
3021                                &arena, base, size, scratch,
3022                            );
3023                        }
                            // Byte distance from the window base, divided by 4
                            // to express it in 4-byte elements.
3024                        kernel_out_off = ((tmp_byte.saturating_sub(base)) / 4) as u32;
3025                        attn_scratch_copy = Some((tmp_byte, out_len as u32));
3026                    }
3027                    let mask_off = if mask_kind_id == 2 {
3028                        arena_off_in_bind_window(
3029                            &graph,
3030                            &param_offsets,
3031                            &dev.device,
3032                            &arena,
3033                            &mut schedule,
3034                            &mut scratch,
3035                            node.inputs[3],
3036                            &mut base,
3037                            &mut size,
3038                        )
3039                    } else {
3040                        0
3041                    };
3042                    let p = AttentionParams {
3043                        batch,
3044                        heads,
3045                        seq_q,
3046                        seq_k,
3047                        head_dim: hd,
3048                        q_off,
3049                        k_off,
3050                        v_off,
3051                        out_off: kernel_out_off,
3052                        mask_off,
3053                        mask_kind: mask_kind_id,
3054                        scale_bits: scale.to_bits(),
3055                        window,
3056                        // Mask strides — derive from the mask's IR shape:
3057                        //   [B, S]:           (mb=S,        mh=0,    mq=0,   mk=1)
3058                        //   [B, S_q, S_k]:    (mb=S_q·S_k,  mh=0,    mq=S_k, mk=1)
3059                        //   [B, H, S_q, S_k]: (mb=H·S_q·S_k mh=S_q·S_k mq=S_k mk=1)
3060                        // Stride 0 means the kernel broadcasts across that
3061                        // axis (reads the same element for every value of
3062                        // the index). Lets us skip the Expand pre-pass that
3063                        // unfuse used to emit per attention block.
3064                        seq_q_stride: mask_strides.q,
3065                        seq_k_stride: mask_strides.k,
3066                        mask_batch_stride: mask_strides.b,
3067                        mask_head_stride: mask_strides.h,
3068                        _pad_mask_0: 0,
3069                        _pad_mask_1: 0,
3070                        _pad_mask_2: 0,
3071                        q_batch_stride: q_b,
3072                        q_head_stride: q_h,
3073                        q_seq_stride: q_s,
3074                        _pad_q: 0,
3075                        k_batch_stride: k_b,
3076                        k_head_stride: k_h,
3077                        k_seq_stride: k_s,
3078                        _pad_k: 0,
3079                        v_batch_stride: v_b,
3080                        v_head_stride: v_h,
3081                        v_seq_stride: v_s,
3082                        _pad_v: 0,
3083                        o_batch_stride: o_b,
3084                        o_head_stride: o_h,
3085                        o_seq_stride: o_s,
3086                        _pad_o: 0,
3087                    };
                        // No-op: `num_heads` is already consumed above (`h`,
                        // `attention_geom`), so this discard has no effect.
3088                    let _ = num_heads;
3089                    schedule.push(Step::Attention {
3090                        params: p,
3091                        mask_buf,
3092                    });
                        // Aliased case: copy the scratch result into the real
                        // output right after the attention step.
                        // NOTE(review): `tmp_byte`/`out_byte` are u64 and `bytes`
                        // came from a u64 length; the `as u32` casts truncate
                        // silently if a value exceeds u32::MAX.
3093                    if let Some((tmp_byte, bytes)) = attn_scratch_copy {
3094                        schedule.push(Step::BufferCopy {
3095                            src_byte_off: tmp_byte as u32,
3096                            dst_byte_off: out_byte as u32,
3097                            bytes,
3098                        });
3099                    }
                        // One uniform sized for `AttentionParams` and one bind
                        // group over the arena window [base, base + size).
3100                    let ak = attention_kernel(&dev.device);
3101                    let u = emit_uniform(std::mem::size_of::<AttentionParams>());
3102                    let bg = bind_two_buf0_window(&dev.device, ak, &arena.buffer, base, size, &u);
3103                    uniforms.push(u);
3104                    bind_groups.push(bg);
3105                }
3106
3107                Op::AttentionBackward {
3108                    num_heads,
3109                    head_dim,
3110                    mask_kind,
3111                    wrt,
3112                } => {
                        // Lowers the attention gradient node. Inputs are Q, K, V,
                        // dY (indices 0..=3) and, for Custom / Bias masks, the mask
                        // at index 4. `wrt` selects which gradient (query, key or
                        // value) this node produces. Pushes one
                        // `Step::AttentionBackward`, one uniform and one bind group.
3113                    use rlx_ir::op::AttentionBwdWrt;
3114                    let q_id = node.inputs[0];
3115                    let k_id = node.inputs[1];
3116                    let v_id = node.inputs[2];
3117                    let dy_id = node.inputs[3];
3118                    let q_shape = graph.node(q_id).shape.dims();
3119                    let k_shape = graph.node(k_id).shape.dims();
3120                    let hd = *head_dim as u32;
3121                    let q_ir = graph.node(q_id).shape.clone();
3122                    let k_ir = graph.node(k_id).shape.clone();
3123                    let geom = rlx_ir::attention_geom(&q_ir, &k_ir, *num_heads, *head_dim);
3124                    let bhsd = geom.bhsd;
3125                    let (batch, heads, seq_q, seq_k) = match q_shape.len() {
3126                        4 => (
3127                            geom.batch as u32,
3128                            geom.heads as u32,
3129                            geom.seq_q as u32,
3130                            geom.seq_k as u32,
3131                        ),
3132                        3 => {
                                // Rank-3: the head count is derived as
                                // last_dim / head_dim and the leading dim is then
                                // divided by it.
                                // NOTE(review): this differs from the forward
                                // `Op::Attention` arm, which uses `num_heads` and
                                // distinguishes [B, S, H·D] from [B·H, S, D]. Here
                                // a [B, S, H·D] input yields batch = B / H, and
                                // last_dim < head_dim gives h = 0 (division by
                                // zero on the next line). Confirm which rank-3
                                // layouts can reach the backward op.
3133                            let h = q_shape[2].unwrap_static() as u32 / hd;
3134                            (
3135                                q_shape[0].unwrap_static() as u32 / h,
3136                                h,
3137                                q_shape[1].unwrap_static() as u32,
3138                                k_shape[1].unwrap_static() as u32,
3139                            )
3140                        }
3141                        other => panic!(
3142                            "rlx-wgpu AttentionBackward: only rank-3/4 Q,K,V (got rank {other})"
3143                        ),
3144                    };
                        // Softmax scale is fixed at 1/sqrt(head_dim).
3145                    let scale = 1.0_f32 / (hd as f32).sqrt();
                        // Kernel mask codes: 0 = none, 1 = causal, 2 = Custom mask,
                        // 3 = sliding window of width `w`, 4 = Bias. For codes 2
                        // and 4 the mask offset is the absolute arena byte offset
                        // of input 4 divided by 4. `mask_buf` is always `None`.
                        // NOTE(review): the forward `Op::Attention` arm maps Bias
                        // to code 2, not 4 — confirm both kernels agree.
3146                    let (mask_kind_id, mask_off, mask_buf, window) = match mask_kind {
3147                        MaskKind::None => (0u32, 0u32, None, 0u32),
3148                        MaskKind::Causal => (1u32, 0u32, None, 0u32),
3149                        MaskKind::Custom => {
3150                            (2u32, (arena.offset(node.inputs[4]) / 4) as u32, None, 0u32)
3151                        }
3152                        MaskKind::Bias => {
3153                            (4u32, (arena.offset(node.inputs[4]) / 4) as u32, None, 0u32)
3154                        }
3155                        MaskKind::SlidingWindow(w) => (3u32, 0u32, None, *w as u32),
3156                    };
                        // Mask strides along (batch, head, query, key).
3157                    struct MStrides {
3158                        b: u32,
3159                        h: u32,
3160                        q: u32,
3161                        k: u32,
3162                    }
                        // For mask tensors (codes 2 and 4) the strides follow the
                        // mask's own rank-2/3/4 shape, with 0 on axes the mask
                        // lacks; every other case uses dense
                        // [heads, seq_q, seq_k] strides.
3163                    let mask_strides = if mask_kind_id == 2 || mask_kind_id == 4 {
3164                        let m_dims = graph.node(node.inputs[4]).shape.dims();
3165                        let dim = |i: usize| m_dims[i].unwrap_static() as u32;
3166                        match m_dims.len() {
3167                            2 => MStrides {
3168                                b: dim(1),
3169                                h: 0,
3170                                q: 0,
3171                                k: 1,
3172                            },
3173                            3 => MStrides {
3174                                b: dim(1) * dim(2),
3175                                h: 0,
3176                                q: dim(2),
3177                                k: 1,
3178                            },
3179                            4 => MStrides {
3180                                b: dim(1) * dim(2) * dim(3),
3181                                h: dim(2) * dim(3),
3182                                q: dim(3),
3183                                k: 1,
3184                            },
3185                            _ => MStrides {
3186                                b: heads * seq_q * seq_k,
3187                                h: seq_q * seq_k,
3188                                q: seq_k,
3189                                k: 1,
3190                            },
3191                        }
3192                    } else {
3193                        MStrides {
3194                            b: heads * seq_q * seq_k,
3195                            h: seq_q * seq_k,
3196                            q: seq_k,
3197                            k: 1,
3198                        }
3199                    };
                        // Per-tensor (batch, head, seq) strides from
                        // `rlx_ir::strides_for_shape`.
3200                    let stride = |shape: &[rlx_ir::shape::Dim], seq_extent: u32| {
3201                        rlx_ir::strides_for_shape(shape, heads, hd, seq_extent, bhsd)
3202                    };
3203                    let (q_b, q_h, q_s) = stride(q_shape, seq_q);
3204                    let (k_b, k_h, k_s) = stride(k_shape, seq_k);
3205                    let v_shape = graph.node(v_id).shape.dims();
3206                    let (v_b, v_h, v_s) = stride(v_shape, seq_k);
3207                    let out_shape = node.shape.dims();
                        // The gradient has the sequence extent of the tensor it
                        // is taken with respect to: seq_q for Query, seq_k for
                        // Key and Value.
3208                    let out_seq = match wrt {
3209                        AttentionBwdWrt::Query => seq_q,
3210                        AttentionBwdWrt::Key | AttentionBwdWrt::Value => seq_k,
3211                    };
3212                    let (o_b, o_h, o_s) = stride(out_shape, out_seq);
                        // Numeric selector passed to the kernel: 0 = dQ, 1 = dK,
                        // 2 = dV.
3213                    let wrt_id = match wrt {
3214                        AttentionBwdWrt::Query => 0u32,
3215                        AttentionBwdWrt::Key => 1u32,
3216                        AttentionBwdWrt::Value => 2u32,
3217                    };
                        // All tensor offsets below are absolute arena byte
                        // offsets divided by 4.
                        // NOTE(review): the forward arm uses window-relative
                        // offsets; here the bind group comes from
                        // `bind_op_output_window`. Confirm that window makes
                        // these absolute offsets valid for every input.
3218                    let p = AttentionBwdParams {
3219                        batch,
3220                        heads,
3221                        seq_q,
3222                        seq_k,
3223                        head_dim: hd,
3224                        q_off: (arena.offset(q_id) / 4) as u32,
3225                        k_off: (arena.offset(k_id) / 4) as u32,
3226                        v_off: (arena.offset(v_id) / 4) as u32,
3227                        dy_off: (arena.offset(dy_id) / 4) as u32,
3228                        out_off: (arena.offset(node.id) / 4) as u32,
3229                        mask_off,
3230                        mask_kind: mask_kind_id,
3231                        scale_bits: scale.to_bits(),
3232                        window,
3233                        wrt: wrt_id,
3234                        seq_q_stride: mask_strides.q,
3235                        seq_k_stride: mask_strides.k,
3236                        mask_batch_stride: mask_strides.b,
3237                        mask_head_stride: mask_strides.h,
3238                        _pad_mask_0: 0,
3239                        _pad_mask_1: 0,
3240                        _pad_mask_2: 0,
3241                        q_batch_stride: q_b,
3242                        q_head_stride: q_h,
3243                        q_seq_stride: q_s,
3244                        _pad_q: 0,
3245                        k_batch_stride: k_b,
3246                        k_head_stride: k_h,
3247                        k_seq_stride: k_s,
3248                        _pad_k: 0,
3249                        v_batch_stride: v_b,
3250                        v_head_stride: v_h,
3251                        v_seq_stride: v_s,
3252                        _pad_v: 0,
3253                        o_batch_stride: o_b,
3254                        o_head_stride: o_h,
3255                        o_seq_stride: o_s,
3256                        _pad_o: 0,
3257                    };
3258                    schedule.push(Step::AttentionBackward {
3259                        params: p,
3260                        mask_buf,
3261                    });
                        // One uniform sized for `AttentionBwdParams` plus the bind
                        // group built for this node's output window.
3262                    let ak = attention_bwd_kernel(&dev.device);
3263                    let u = emit_uniform(std::mem::size_of::<AttentionBwdParams>());
3264                    let bg = bind_op_output_window(&dev.device, ak, &arena, node.id, &u);
3265                    uniforms.push(u);
3266                    bind_groups.push(bg);
3267                }
3268
3269                Op::Rope { head_dim, n_rot: _ } => {
                        // Lowers a rotary-embedding node. Inputs are x, cos and sin
                        // (indices 0..=2). Validates the last dimension against
                        // `head_dim`, resolves offsets inside one arena bind
                        // window, then pushes a `Step::Rope`, one uniform and one
                        // bind group.
                        // NOTE(review): `n_rot` is ignored by this arm; confirm
                        // partial-rotation ops are not expected on this backend.
3270                    let x_id = node.inputs[0];
3271                    let cos_id = node.inputs[1];
3272                    let sin_id = node.inputs[2];
3273                    let x_shape = graph.node(x_id).shape.dims();
                        // Last dimension of x, or 0 when the shape is empty.
3274                    let last = x_shape.last().map(|d| d.unwrap_static()).unwrap_or(0);
3275                    if !last.is_multiple_of(*head_dim) {
3276                        panic!(
3277                            "rlx-wgpu Rope: last_dim ({last}) must be a multiple \
3278                                of head_dim ({head_dim})"
3279                        );
3280                    }
3281                    if head_dim % 2 != 0 {
3282                        panic!("rlx-wgpu Rope: head_dim must be even");
3283                    }
                        // Total element count of x, and the second-to-last
                        // dimension taken as the sequence length.
                        // NOTE(review): `x_shape.len() - 2` underflows for
                        // rank < 2, so such inputs panic here without a
                        // Rope-specific message.
3284                    let total: u32 = x_shape.iter().map(|d| d.unwrap_static() as u32).product();
3285                    let seq = x_shape[x_shape.len() - 2].unwrap_static() as u32;
3286                    // PLAN L1: derive batch from total / seq / last_dim
3287                    // (= product of leading dims). `seq_stride` stays at
3288                    // full seq for buffer offset math; `seq` becomes the
3289                    // runtime-scaled loop bound.
                        // `.max(1)` guards the division when seq * last_dim is 0.
3290                    let batch = total / (seq * last as u32).max(1);
                        // The id list passed to `arena_multi_op_window` puts
                        // cos/sin first when cos is a graph parameter larger than
                        // ARENA_STAGE_CAP, and the output/x first otherwise.
                        // (Presumably the order steers which tensor anchors the
                        // window — TODO confirm against the helper.)
3291                    let cos_is_param = tensor_is_graph_param(&graph, &param_offsets, cos_id);
3292                    let cos_bytes = arena.len_of(cos_id) as u64;
3293                    let rope_win: Vec<NodeId> = if cos_is_param && cos_bytes > ARENA_STAGE_CAP {
3294                        vec![cos_id, sin_id, node.id, x_id]
3295                    } else {
3296                        vec![node.id, x_id, cos_id, sin_id]
3297                    };
3298                    let mut scratch = arena.scratch_off as u64;
3299                    let (mut base, mut size, param_anchor) = arena_multi_op_window(
3300                        &dev.device,
3301                        &arena,
3302                        &graph,
3303                        &param_offsets,
3304                        &mut schedule,
3305                        &mut scratch,
3306                        &rope_win,
3307                    );
3308                    if !param_anchor {
3309                        base = arena_bind_window_covering_scratch_if_needed(
3310                            &arena, base, size, scratch,
3311                        );
3312                    }
                        // Resolve x, cos and sin through the bind window; each call
                        // may update `schedule`, `scratch`, `base` and `size`.
3313                    let in_off = arena_off_in_bind_window(
3314                        &graph,
3315                        &param_offsets,
3316                        &dev.device,
3317                        &arena,
3318                        &mut schedule,
3319                        &mut scratch,
3320                        x_id,
3321                        &mut base,
3322                        &mut size,
3323                    );
3324                    let cos_off = arena_off_in_bind_window(
3325                        &graph,
3326                        &param_offsets,
3327                        &dev.device,
3328                        &arena,
3329                        &mut schedule,
3330                        &mut scratch,
3331                        cos_id,
3332                        &mut base,
3333                        &mut size,
3334                    );
3335                    let sin_off = arena_off_in_bind_window(
3336                        &graph,
3337                        &param_offsets,
3338                        &dev.device,
3339                        &arena,
3340                        &mut schedule,
3341                        &mut scratch,
3342                        sin_id,
3343                        &mut base,
3344                        &mut size,
3345                    );
                        // `half` is head_dim / 2 (head_dim was checked to be even
                        // above); the output offset is taken relative to the final
                        // window base.
3346                    let p = RopeParams {
3347                        n_total: total,
3348                        seq,
3349                        head_dim: *head_dim as u32,
3350                        half: (*head_dim / 2) as u32,
3351                        in_off,
3352                        cos_off,
3353                        sin_off,
3354                        out_off: arena_local_off_f32(&arena, node.id, base),
3355                        last_dim: last as u32,
3356                        batch,
3357                        seq_stride: seq,
3358                        _p2: 0,
3359                    };
3360                    schedule.push(Step::Rope { params: p });
                        // One uniform sized for `RopeParams` and one bind group
                        // over the arena window [base, base + size).
3361                    let rk = rope_kernel(&dev.device);
3362                    let u = emit_uniform(std::mem::size_of::<RopeParams>());
3363                    let bg = bind_two_buf0_window(&dev.device, rk, &arena.buffer, base, size, &u);
3364                    uniforms.push(u);
3365                    bind_groups.push(bg);
3366                }
3367
3368                Op::Expand { target_shape } => {
3369                    let in_id = node.inputs[0];
3370                    let in_shape = graph.node(in_id).shape.dims();
3371                    let in_rank = in_shape.len();
3372                    let rank = target_shape.len();
3373                    if in_rank > rank {
3374                        panic!(
3375                            "rlx-wgpu Expand: rank mismatch \
3376                                (in_rank={in_rank}, target_rank={rank})"
3377                        );
3378                    }
3379                    // Implicit leading 1s when input rank < target rank (e.g.
3380                    // scalar → vector from `LegalizeBroadcast`).
3381                    let pad = rank.saturating_sub(in_rank);
3382                    let out_dims: Vec<u32> = target_shape.iter().map(|&d| d as u32).collect();
3383                    let in_dims: Vec<u32> = (0..rank)
3384                        .map(|i| {
3385                            if i < pad {
3386                                1
3387                            } else {
3388                                in_shape[i - pad].unwrap_static() as u32
3389                            }
3390                        })
3391                        .collect();
3392                    // Cumulative input strides (row-major). When the
3393                    // input dim is 1 but target dim > 1, that axis
3394                    // broadcasts → stride = 0.
3395                    let mut in_strides_row = vec![1u32; rank];
3396                    for i in (0..rank.saturating_sub(1)).rev() {
3397                        in_strides_row[i] = in_strides_row[i + 1] * in_dims[i + 1];
3398                    }
3399                    let strides_for_out: Vec<u32> = (0..rank)
3400                        .map(|i| {
3401                            if in_dims[i] == 1 && out_dims[i] != 1 {
3402                                0
3403                            } else {
3404                                in_strides_row[i]
3405                            }
3406                        })
3407                        .collect();
3408
3409                    let mut meta_data: Vec<u32> = Vec::with_capacity(rank * 2);
3410                    meta_data.extend_from_slice(&out_dims);
3411                    meta_data.extend_from_slice(&strides_for_out);
3412                    let meta_buf = dev.device.create_buffer(&wgpu::BufferDescriptor {
3413                        label: Some("rlx-wgpu expand meta"),
3414                        size: (meta_data.len() * 4).max(4) as u64,
3415                        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
3416                        mapped_at_creation: false,
3417                    });
3418                    dev.queue
3419                        .write_buffer(&meta_buf, 0, bytemuck::cast_slice(&meta_data));
3420                    let meta_idx = meta_buffers.len();
3421                    meta_buffers.push(meta_buf);
3422
3423                    // PLAN L1: bucket axis stays at out axis 0 iff the
3424                    // expand at axis 0 isn't a broadcast (in_dims[0]
3425                    // matches out_dims[0]). When broadcast at axis 0
3426                    // (in_dims[0]==1, out_dims[0]>1), the bucket-axis
3427                    // contract doesn't apply — fall back to full extent.
3428                    let bucket_outermost = if in_dims[0] == out_dims[0] {
3429                        1u32
3430                    } else {
3431                        0u32
3432                    };
3433                    let exp_ids = [node.id, in_id];
3434                    let max_binding = dev.device.limits().max_storage_buffer_binding_size;
3435                    let exp_fits = arena_span_bytes(&arena, &exp_ids) <= max_binding;
3436                    let mut scratch = arena.scratch_off as u64;
3437                    let (mut base, mut size, param_anchor) = arena_multi_op_window(
3438                        &dev.device,
3439                        &arena,
3440                        &graph,
3441                        &param_offsets,
3442                        &mut schedule,
3443                        &mut scratch,
3444                        &exp_ids,
3445                    );
3446                    if !exp_fits && !param_anchor {
3447                        base = arena_bind_window_covering_scratch_if_needed(
3448                            &arena, base, size, scratch,
3449                        );
3450                    }
3451                    let in_off = arena_off_in_bind_window(
3452                        &graph,
3453                        &param_offsets,
3454                        &dev.device,
3455                        &arena,
3456                        &mut schedule,
3457                        &mut scratch,
3458                        in_id,
3459                        &mut base,
3460                        &mut size,
3461                    );
3462                    let out_off = arena_off_in_bind_window(
3463                        &graph,
3464                        &param_offsets,
3465                        &dev.device,
3466                        &arena,
3467                        &mut schedule,
3468                        &mut scratch,
3469                        node.id,
3470                        &mut base,
3471                        &mut size,
3472                    );
3473                    let p = ExpandParams {
3474                        rank: rank as u32,
3475                        out_total: elems,
3476                        in_off,
3477                        out_off,
3478                        bucket_outermost,
3479                        out_dim_0: out_dims[0],
3480                        _p2: 0,
3481                        _p3: 0,
3482                    };
3483                    schedule.push(Step::Expand {
3484                        params: p,
3485                        meta_idx,
3486                    });
3487                    let ek = expand_kernel(&dev.device);
3488                    let u = emit_uniform(std::mem::size_of::<ExpandParams>());
3489                    let bg = dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
3490                        label: Some("rlx-wgpu expand bg"),
3491                        layout: &ek.bgl,
3492                        entries: &[
3493                            wgpu::BindGroupEntry {
3494                                binding: 0,
3495                                resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
3496                                    buffer: &arena.buffer,
3497                                    offset: base,
3498                                    size: NonZeroU64::new(size),
3499                                }),
3500                            },
3501                            wgpu::BindGroupEntry {
3502                                binding: 1,
3503                                resource: u.as_entire_binding(),
3504                            },
3505                            wgpu::BindGroupEntry {
3506                                binding: 2,
3507                                resource: meta_buffers[meta_idx].as_entire_binding(),
3508                            },
3509                        ],
3510                    });
3511                    uniforms.push(u);
3512                    bind_groups.push(bg);
3513                }
3514
3515                Op::Gather { axis } => {
3516                    let table_id = node.inputs[0];
3517                    let idx_id = node.inputs[1];
3518                    let table_is_param = tensor_is_graph_param(&graph, &param_offsets, table_id);
3519                    let table_bytes = arena.len_of(table_id) as u64;
3520                    let gather_win: Vec<NodeId> = if table_is_param && table_bytes > ARENA_STAGE_CAP
3521                    {
3522                        vec![table_id, node.id, idx_id]
3523                    } else {
3524                        vec![node.id, idx_id, table_id]
3525                    };
3526                    let mut scratch = arena.scratch_off as u64;
3527                    let (mut base, mut size, table_anchor) = arena_multi_op_window(
3528                        &dev.device,
3529                        &arena,
3530                        &graph,
3531                        &param_offsets,
3532                        &mut schedule,
3533                        &mut scratch,
3534                        &gather_win,
3535                    );
3536                    if !table_anchor {
3537                        base = arena_bind_window_covering_scratch_if_needed(
3538                            &arena, base, size, scratch,
3539                        );
3540                    }
3541                    let in_off =
3542                        if table_anchor && arena_tensor_in_window(&arena, table_id, base, size) {
3543                            arena_local_off_f32(&arena, table_id, base)
3544                        } else {
3545                            arena_off_in_bind_window(
3546                                &graph,
3547                                &param_offsets,
3548                                &dev.device,
3549                                &arena,
3550                                &mut schedule,
3551                                &mut scratch,
3552                                table_id,
3553                                &mut base,
3554                                &mut size,
3555                            )
3556                        };
3557                    let idx_off = arena_off_in_bind_window(
3558                        &graph,
3559                        &param_offsets,
3560                        &dev.device,
3561                        &arena,
3562                        &mut schedule,
3563                        &mut scratch,
3564                        idx_id,
3565                        &mut base,
3566                        &mut size,
3567                    );
3568                    let out_off = arena_local_off_f32(&arena, node.id, base);
3569                    if *axis == 0 {
3570                        let table_shape = graph.node(table_id).shape.dims();
3571                        let idx_shape = graph.node(idx_id).shape.dims();
3572                        let vocab = table_shape[0].unwrap_static() as u32;
3573                        let dim: u32 = table_shape[1..]
3574                            .iter()
3575                            .map(|d| d.unwrap_static() as u32)
3576                            .product::<u32>()
3577                            .max(1);
3578                        let n_idx: u32 =
3579                            idx_shape.iter().map(|d| d.unwrap_static() as u32).product();
3580                        let p = GatherParams {
3581                            n_out: elems,
3582                            n_idx,
3583                            dim,
3584                            vocab,
3585                            in_off,
3586                            idx_off,
3587                            out_off,
3588                            _p0: 0,
3589                        };
3590                        schedule.push(Step::Gather { params: p });
3591                        let gk = gather_kernel(&dev.device);
3592                        let u = emit_uniform(std::mem::size_of::<GatherParams>());
3593                        let bg =
3594                            bind_two_buf0_window(&dev.device, gk, &arena.buffer, base, size, &u);
3595                        uniforms.push(u);
3596                        bind_groups.push(bg);
3597                    } else {
3598                        let table_shape = graph.node(table_id).shape.dims();
3599                        let idx_shape = graph.node(idx_id).shape.dims();
3600                        let outer: u32 = table_shape[..*axis]
3601                            .iter()
3602                            .map(|d| d.unwrap_static() as u32)
3603                            .product::<u32>()
3604                            .max(1);
3605                        let trailing: u32 = table_shape[*axis + 1..]
3606                            .iter()
3607                            .map(|d| d.unwrap_static() as u32)
3608                            .product::<u32>()
3609                            .max(1);
3610                        let axis_dim = table_shape[*axis].unwrap_static() as u32;
3611                        let num_idx: u32 =
3612                            idx_shape.iter().map(|d| d.unwrap_static() as u32).product();
3613                        let total = outer * num_idx * trailing;
3614                        let p = GatherAxisParams {
3615                            total,
3616                            outer,
3617                            axis_dim,
3618                            num_idx,
3619                            trailing,
3620                            table_off: in_off,
3621                            idx_off,
3622                            out_off,
3623                        };
3624                        schedule.push(Step::GatherAxis { params: p });
3625                        let gk = gather_axis_kernel(&dev.device);
3626                        let u = emit_uniform(std::mem::size_of::<GatherAxisParams>());
3627                        let bg =
3628                            bind_two_buf0_window(&dev.device, gk, &arena.buffer, base, size, &u);
3629                        uniforms.push(u);
3630                        bind_groups.push(bg);
3631                    }
3632                }
3633
3634                Op::FusedMatMulBiasAct { activation } => {
3635                    // Inputs: [x, w, bias]. We require 2D × 2D or
3636                    // [..,M,K] × [K,N] (broadcast bias). Bias is shape [N].
3637                    let a_id = node.inputs[0];
3638                    let b_id = node.inputs[1];
3639                    let bias_id = node.inputs[2];
3640                    let a_shape = graph.node(a_id).shape.dims();
3641                    let b_shape = graph.node(b_id).shape.dims();
3642                    let out_shape = node.shape.dims();
3643                    let (m, k, n) =
3644                        if a_shape.len() == 2 && b_shape.len() == 2 && out_shape.len() == 2 {
3645                            (
3646                                a_shape[0].unwrap_static() as u32,
3647                                a_shape[1].unwrap_static() as u32,
3648                                b_shape[1].unwrap_static() as u32,
3649                            )
3650                        } else if a_shape.len() >= 2
3651                            && b_shape.len() == 2
3652                            && out_shape.len() == a_shape.len()
3653                        {
3654                            let leading: usize = a_shape[..a_shape.len() - 2]
3655                                .iter()
3656                                .map(|d| d.unwrap_static())
3657                                .product();
3658                            let m_inner = a_shape[a_shape.len() - 2].unwrap_static();
3659                            let k_inner = a_shape[a_shape.len() - 1].unwrap_static();
3660                            let n_inner = b_shape[1].unwrap_static();
3661                            ((leading * m_inner) as u32, k_inner as u32, n_inner as u32)
3662                        } else {
3663                            panic!(
3664                                "rlx-wgpu FusedMatMulBiasAct: unsupported shapes \
3665                                a={a_shape:?} b={b_shape:?}"
3666                            );
3667                        };
3668                    let act_id = match activation {
3669                        None => 0xFFFFu32,
3670                        Some(a) => activation_op_id(*a),
3671                    };
3672                    let b_is_param = tensor_is_graph_param(&graph, &param_offsets, b_id);
3673                    let b_bytes = arena.len_of(b_id) as u64;
3674                    let mut compute_precision = derive_matmul_compute(
3675                        &dev.device,
3676                        &graph,
3677                        &coop_f16_vk_mirror_acts,
3678                        a_id,
3679                        b_id,
3680                        m,
3681                        k,
3682                        n,
3683                    );
3684                    if b_is_param && b_bytes > ARENA_STAGE_CAP && arena.param_fits_f16_mirror(b_id)
3685                    {
3686                        compute_precision = MatmulCompute::F16;
3687                    }
3688
3689                    // Split-QKV pattern: matmul writes Q/K/V directly into
3690                    // 3 separate output buffers, eliminating the 3 Narrow
3691                    // dispatches that would otherwise follow.
3692                    let mqk_eligible = act_id == 0xFFFFu32
3693                        && matches!(
3694                            compute_precision,
3695                            MatmulCompute::F32 | MatmulCompute::CoopF32 | MatmulCompute::CoopF16Vk
3696                        );
3697                    if mqk_eligible && let Some(&(q_id, k_id_n, v_id)) = qkv_split.get(&node.id) {
3698                        let head_width = n / 3;
3699                        let qkv_kind = match compute_precision {
3700                            MatmulCompute::CoopF16Vk => MatmulQkvKind::CoopF16Vk,
3701                            MatmulCompute::CoopF32 => MatmulQkvKind::CoopF32,
3702                            _ => MatmulQkvKind::F32,
3703                        };
3704                        let (mut base, mut size, param_anchor) = arena_matmul_bind_window(
3705                            &dev.device,
3706                            &arena,
3707                            &graph,
3708                            &param_offsets,
3709                            q_id,
3710                            a_id,
3711                            b_id,
3712                        );
3713                        let mut scratch = arena.scratch_off as u64;
3714                        if param_anchor {
3715                            arena_ensure_scratch_in_window(&mut scratch, base, size);
3716                        }
3717                        if b_is_param && b_bytes > ARENA_STAGE_CAP {
3718                            assert!(
3719                                param_anchor && arena_tensor_in_window(&arena, b_id, base, size),
3720                                "rlx-wgpu FusedMatMul QKV: large param B {:?} not in bind window",
3721                                b_id,
3722                            );
3723                        }
3724                        let a_off = arena_off_in_bind_window(
3725                            &graph,
3726                            &param_offsets,
3727                            &dev.device,
3728                            &arena,
3729                            &mut schedule,
3730                            &mut scratch,
3731                            a_id,
3732                            &mut base,
3733                            &mut size,
3734                        );
3735                        let q_off = arena_off_in_bind_window(
3736                            &graph,
3737                            &param_offsets,
3738                            &dev.device,
3739                            &arena,
3740                            &mut schedule,
3741                            &mut scratch,
3742                            q_id,
3743                            &mut base,
3744                            &mut size,
3745                        );
3746                        let k_off = arena_off_in_bind_window(
3747                            &graph,
3748                            &param_offsets,
3749                            &dev.device,
3750                            &arena,
3751                            &mut schedule,
3752                            &mut scratch,
3753                            k_id_n,
3754                            &mut base,
3755                            &mut size,
3756                        );
3757                        let v_off = arena_off_in_bind_window(
3758                            &graph,
3759                            &param_offsets,
3760                            &dev.device,
3761                            &arena,
3762                            &mut schedule,
3763                            &mut scratch,
3764                            v_id,
3765                            &mut base,
3766                            &mut size,
3767                        );
3768                        let bias_off = arena_off_in_bind_window(
3769                            &graph,
3770                            &param_offsets,
3771                            &dev.device,
3772                            &arena,
3773                            &mut schedule,
3774                            &mut scratch,
3775                            bias_id,
3776                            &mut base,
3777                            &mut size,
3778                        );
3779                        let b_off_f32 = if b_is_param
3780                            && b_bytes > ARENA_STAGE_CAP
3781                            && arena_tensor_in_window(&arena, b_id, base, size)
3782                        {
3783                            arena_local_off_f32(&arena, b_id, base)
3784                        } else {
3785                            arena_off_in_bind_window(
3786                                &graph,
3787                                &param_offsets,
3788                                &dev.device,
3789                                &arena,
3790                                &mut schedule,
3791                                &mut scratch,
3792                                b_id,
3793                                &mut base,
3794                                &mut size,
3795                            )
3796                        };
3797                        let b_off_global = (arena.offset(b_id) / 4) as u32;
3798                        maybe_push_coop_f16_vk_casts(
3799                            &graph,
3800                            a_id,
3801                            b_id,
3802                            &coop_f16_vk_mirror_acts,
3803                            &dev.device,
3804                            &arena,
3805                            &mut schedule,
3806                            &mut uniforms,
3807                            &mut bind_groups,
3808                            &mm_cast,
3809                            compute_precision,
3810                            a_off,
3811                            m,
3812                            k,
3813                            1,
3814                            if qkv_kind == MatmulQkvKind::CoopF16Vk {
3815                                b_off_global
3816                            } else {
3817                                b_off_f32
3818                            },
3819                            n,
3820                        );
3821                        let p = MatmulQkvParams {
3822                            m,
3823                            k,
3824                            n,
3825                            a_off,
3826                            b_off: if qkv_kind == MatmulQkvKind::CoopF16Vk {
3827                                b_off_global
3828                            } else {
3829                                b_off_f32
3830                            },
3831                            q_off,
3832                            k_off,
3833                            v_off,
3834                            head_width,
3835                            has_bias: 1,
3836                            bias_off,
3837                            _p0: 0,
3838                            _p1: 0,
3839                            _p2: 0,
3840                            _p3: 0,
3841                            _p4: 0,
3842                        };
3843                        schedule.push(Step::MatmulQkv {
3844                            params: p,
3845                            kind: qkv_kind,
3846                        });
3847                        register_coop_f16_vk_b_param(
3848                            &mut coop_f16_b_param,
3849                            &param_offsets,
3850                            b_id,
3851                            p.b_off,
3852                            match qkv_kind {
3853                                MatmulQkvKind::CoopF16Vk => MatmulCompute::CoopF16Vk,
3854                                MatmulQkvKind::CoopF32 => MatmulCompute::CoopF32,
3855                                MatmulQkvKind::F32 => MatmulCompute::F32,
3856                            },
3857                        );
3858                        let u = emit_uniform(std::mem::size_of::<MatmulQkvParams>());
3859                        let bg = match qkv_kind {
3860                            MatmulQkvKind::CoopF16Vk => {
3861                                let mqk = matmul_qkv_coop_f16_vk_kernel(&dev.device).expect(
3862                                    "coop f16 matmul_qkv kernel: feature was checked but missing",
3863                                );
3864                                let (bg, b_off_adj) = build_matmul_qkv_coop_f16_vk_bind_group(
3865                                    &dev.device,
3866                                    mqk,
3867                                    &arena,
3868                                    base,
3869                                    size,
3870                                    &u,
3871                                    k,
3872                                    n,
3873                                    p.b_off,
3874                                );
3875                                if let Some(Step::MatmulQkv { params, .. }) = schedule.last_mut() {
3876                                    params.b_off = b_off_adj;
3877                                }
3878                                bg
3879                            }
3880                            MatmulQkvKind::CoopF32 => bind_two_buf0_window(
3881                                &dev.device,
3882                                matmul_qkv_coop_f32_kernel(&dev.device).expect(
3883                                    "coop matmul_qkv kernel: hardware feature was checked but kernel missing",
3884                                ),
3885                                &arena.buffer,
3886                                base,
3887                                size,
3888                                &u,
3889                            ),
3890                            MatmulQkvKind::F32 => bind_two_buf0_window(
3891                                &dev.device,
3892                                matmul_qkv_kernel(&dev.device),
3893                                &arena.buffer,
3894                                base,
3895                                size,
3896                                &u,
3897                            ),
3898                        };
3899                        uniforms.push(u);
3900                        bind_groups.push(bg);
3901                        if qkv_kind == MatmulQkvKind::CoopF16Vk {
3902                            coop_f16_vk_wide_bind_groups.insert(
3903                                schedule.len() - 1,
3904                                bind_two_buf0_window(
3905                                    &dev.device,
3906                                    matmul_qkv_kernel(&dev.device),
3907                                    &arena.buffer,
3908                                    base,
3909                                    size,
3910                                    &uniforms[uniforms.len() - 1],
3911                                ),
3912                            );
3913                        }
3914                    } else {
3915                        let (mut base, mut size, param_anchor) = arena_matmul_bind_window(
3916                            &dev.device,
3917                            &arena,
3918                            &graph,
3919                            &param_offsets,
3920                            node.id,
3921                            a_id,
3922                            b_id,
3923                        );
3924                        let mut scratch = arena.scratch_off as u64;
3925                        if param_anchor {
3926                            arena_ensure_scratch_in_window(&mut scratch, base, size);
3927                        }
3928                        if b_is_param && b_bytes > ARENA_STAGE_CAP {
3929                            assert!(
3930                                param_anchor && arena_tensor_in_window(&arena, b_id, base, size),
3931                                "rlx-wgpu FusedMatMul: large param B {:?} not in bind window",
3932                                b_id,
3933                            );
3934                        }
3935                        let a_off_f32 = arena_off_in_bind_window(
3936                            &graph,
3937                            &param_offsets,
3938                            &dev.device,
3939                            &arena,
3940                            &mut schedule,
3941                            &mut scratch,
3942                            a_id,
3943                            &mut base,
3944                            &mut size,
3945                        );
3946                        let b_off_f32 = if b_is_param
3947                            && b_bytes > ARENA_STAGE_CAP
3948                            && arena_tensor_in_window(&arena, b_id, base, size)
3949                        {
3950                            arena_local_off_f32(&arena, b_id, base)
3951                        } else {
3952                            arena_off_in_bind_window(
3953                                &graph,
3954                                &param_offsets,
3955                                &dev.device,
3956                                &arena,
3957                                &mut schedule,
3958                                &mut scratch,
3959                                b_id,
3960                                &mut base,
3961                                &mut size,
3962                            )
3963                        };
3964                        let bias_off_f32 = arena_off_in_bind_window(
3965                            &graph,
3966                            &param_offsets,
3967                            &dev.device,
3968                            &arena,
3969                            &mut schedule,
3970                            &mut scratch,
3971                            bias_id,
3972                            &mut base,
3973                            &mut size,
3974                        );
3975                        let b_off_global = (arena.offset(b_id) / 4) as u32;
3976                        let b_off_bind = if b_is_param
3977                            && matches!(
3978                                compute_precision,
3979                                MatmulCompute::Coop16
3980                                    | MatmulCompute::CoopF16Vk
3981                                    | MatmulCompute::F16
3982                            ) {
3983                            b_off_global
3984                        } else {
3985                            b_off_f32
3986                        };
3987                        maybe_push_coop_f16_vk_casts(
3988                            &graph,
3989                            a_id,
3990                            b_id,
3991                            &coop_f16_vk_mirror_acts,
3992                            &dev.device,
3993                            &arena,
3994                            &mut schedule,
3995                            &mut uniforms,
3996                            &mut bind_groups,
3997                            &mm_cast,
3998                            compute_precision,
3999                            a_off_f32,
4000                            m,
4001                            k,
4002                            1,
4003                            b_off_bind,
4004                            n,
4005                        );
4006                        schedule.push(Step::Matmul {
4007                            m,
4008                            k,
4009                            n,
4010                            batch: 1,
4011                            a_batch_stride: 0,
4012                            b_batch_stride: 0,
4013                            c_batch_stride: 0,
4014                            a_off_f32,
4015                            b_off_f32,
4016                            c_off_f32: arena_local_off_f32(&arena, node.id, base),
4017                            has_bias: 1,
4018                            bias_off_f32,
4019                            act_id,
4020                            b_is_param,
4021                            compute_precision,
4022                        });
4023                        register_coop_f16_vk_b_param(
4024                            &mut coop_f16_b_param,
4025                            &param_offsets,
4026                            b_id,
4027                            b_off_bind,
4028                            compute_precision,
4029                        );
4030                        let u = emit_uniform(std::mem::size_of::<MatmulParams>());
4031                        let (bg, b_off_adj) = build_matmul_bind_group(
4032                            &dev.device,
4033                            mm_k,
4034                            mm_w,
4035                            &mm_f16w,
4036                            &mm_f16c,
4037                            &mm_coop,
4038                            &mm_coop_f32,
4039                            &arena,
4040                            base,
4041                            size,
4042                            &u,
4043                            b_is_param,
4044                            compute_precision,
4045                            k,
4046                            n,
4047                            1,
4048                            b_off_bind,
4049                            0,
4050                        );
4051                        if let Some(Step::Matmul { b_off_f32, .. }) = schedule.last_mut() {
4052                            *b_off_f32 = b_off_adj;
4053                        }
4054                        uniforms.push(u);
4055                        bind_groups.push(bg);
4056                        if compute_precision == MatmulCompute::CoopF16Vk {
4057                            coop_f16_vk_wide_bind_groups.insert(
4058                                schedule.len() - 1,
4059                                bind_two_buf0_window(
4060                                    &dev.device,
4061                                    mm_w_active_compile,
4062                                    &arena.buffer,
4063                                    base,
4064                                    size,
4065                                    &uniforms[uniforms.len() - 1],
4066                                ),
4067                            );
4068                        }
4069                    }
4070                }
4071
4072                Op::DotGeneral { .. } => {
                        // Compile-time guard: this arm emits no schedule step,
                        // uniform, or bind group — it only panics.
4073                    // Should be unreachable: DotGeneral is decomposed into
4074                    // MatMul + Transpose + Reshape by the unfusion pass
4075                    // before memory planning. If we hit this arm, the
4076                    // unfusion pass has a gap.
4077                    panic!(
4078                        "rlx-wgpu DotGeneral: leaked past unfusion pass — \
4079                            check unfuse.rs::expand_dot_general for missing patterns"
4080                    );
4081                }
4082
4083                Op::Sample {
4084                    top_k,
4085                    top_p,
4086                    temperature,
4087                    seed,
4088                } => {
                        // Lowers sampling over the last axis of input 0 to either a
                        // Step::Argmax or a Step::Sample, plus one uniform and one
                        // bind group for the chosen kernel.
4089                    let in_id = node.inputs[0];
4090                    let in_shape = graph.node(in_id).shape.dims();
                        // `inner` is the size of the last dimension; `outer` is the
                        // total element count divided by `inner` (the row count).
                        // `.max(1)` keeps the division defined when `inner` is 0.
                        // Every dimension is read with `unwrap_static`, so dims are
                        // presumably required to be static here — TODO confirm.
4091                    let inner = in_shape[in_shape.len() - 1].unwrap_static() as u32;
4092                    let total: u32 = in_shape.iter().map(|d| d.unwrap_static() as u32).product();
4093                    let outer = total / inner.max(1);
4094                    // Greedy fast-path: temperature == 1.0 with no top_k/top_p
4095                    // is an argmax — same numeric result, much cheaper kernel.
                        // The float comparisons use a 1e-6 tolerance around 1.0.
                        // NOTE(review): unfiltered sampling at temperature 1.0 is
                        // normally stochastic rather than an argmax — confirm the
                        // sample kernel's semantics make this substitution exact.
4096                    let is_greedy = *top_k == 0
4097                        && (*top_p - 1.0).abs() < 1e-6
4098                        && (*temperature - 1.0).abs() < 1e-6;
4099                    if is_greedy {
                            // Offsets are arena byte offsets divided by 4 —
                            // presumably 4-byte element indices; verify against
                            // the kernel.
4100                        let p = ArgmaxParams {
4101                            outer,
4102                            inner,
4103                            in_off: (arena.offset(in_id) / 4) as u32,
4104                            out_off: (arena.offset(node.id) / 4) as u32,
4105                            _p0: 0,
4106                            _p1: 0,
4107                            _p2: 0,
4108                            _p3: 0,
4109                        };
4110                        schedule.push(Step::Argmax { params: p });
4111                        let amk = argmax_kernel(&dev.device);
                            // Uniform is sized for ArgmaxParams; the bind group is
                            // built by bind_op_output_window for this node.
4112                        let u = emit_uniform(std::mem::size_of::<ArgmaxParams>());
4113                        let bg = bind_op_output_window(&dev.device, amk, &arena, node.id, &u);
4114                        uniforms.push(u);
4115                        bind_groups.push(bg);
4116                    } else {
                            // `top_p` and `temperature` are forwarded as raw bit
                            // patterns (`to_bits`); the seed is split into its low
                            // 32 bits and the bits above bit 32.
4117                        let p = SampleParams {
4118                            outer,
4119                            inner,
4120                            in_off: (arena.offset(in_id) / 4) as u32,
4121                            out_off: (arena.offset(node.id) / 4) as u32,
4122                            top_k: *top_k as u32,
4123                            top_p_bits: top_p.to_bits(),
4124                            temp_bits: temperature.to_bits(),
4125                            seed_lo: *seed as u32,
4126                            seed_hi: (*seed >> 32) as u32,
4127                            _p0: 0,
4128                            _p1: 0,
4129                            _p2: 0,
4130                        };
4131                        schedule.push(Step::Sample { params: p });
4132                        let sk = sample_kernel(&dev.device);
4133                        let u = emit_uniform(std::mem::size_of::<SampleParams>());
4134                        let bg = bind_op_output_window(&dev.device, sk, &arena, node.id, &u);
4135                        uniforms.push(u);
4136                        bind_groups.push(bg);
4137                    }
4138                }
4139
4140                Op::Pool {
4141                    kind,
4142                    kernel_size,
4143                    stride,
4144                    padding,
4145                } => {
                        // Lowers pooling to a Pool1d / Pool2d / Pool3d step, chosen
                        // by the rank of `kernel_size` together with the input and
                        // output ranks. Any other combination panics.
4146                    let in_shape = graph.node(node.inputs[0]).shape.dims();
4147                    let out_shape = node.shape.dims();
                        // Numeric reduction selector written into the params' `op`
                        // field.
4148                    let op_id: u32 = match kind {
4149                        ReduceOp::Sum => 0,
4150                        ReduceOp::Mean => 1,
4151                        ReduceOp::Max => 2,
4152                        ReduceOp::Min => 3,
4153                        ReduceOp::Prod => 4,
4154                    };
                        // In every case dims 0 and 1 are read as `n` and `c`, and
                        // the remaining dims as spatial extents. Missing stride
                        // entries default to 1 and missing padding entries to 0.
                        // Offsets are arena byte offsets divided by 4 — presumably
                        // 4-byte element indices; verify against the kernels.
4155                    match (kernel_size.len(), in_shape.len(), out_shape.len()) {
4156                        (1, 3, 3) => {
4157                            let p = Pool1dParams {
4158                                n: in_shape[0].unwrap_static() as u32,
4159                                c: in_shape[1].unwrap_static() as u32,
4160                                l: in_shape[2].unwrap_static() as u32,
4161                                l_out: out_shape[2].unwrap_static() as u32,
4162                                kl: kernel_size[0] as u32,
4163                                sl: stride.first().copied().unwrap_or(1) as u32,
4164                                pl: padding.first().copied().unwrap_or(0) as u32,
4165                                op: op_id,
4166                                in_off: (arena.offset(node.inputs[0]) / 4) as u32,
4167                                out_off: (arena.offset(node.id) / 4) as u32,
4168                                _p0: 0,
4169                                _p1: 0,
4170                                _p2: 0,
4171                                _p3: 0,
4172                                _p4: 0,
4173                                _p5: 0,
4174                            };
4175                            schedule.push(Step::Pool1d { params: p });
4176                            let pk = pool1d_kernel(&dev.device);
4177                            let u = emit_uniform(std::mem::size_of::<Pool1dParams>());
4178                            let bg = bind_op_output_window(&dev.device, pk, &arena, node.id, &u);
4179                            uniforms.push(u);
4180                            bind_groups.push(bg);
4181                        }
4182                        (2, 4, 4) => {
4183                            let p = Pool2dParams {
4184                                n: in_shape[0].unwrap_static() as u32,
4185                                c: in_shape[1].unwrap_static() as u32,
4186                                h: in_shape[2].unwrap_static() as u32,
4187                                w: in_shape[3].unwrap_static() as u32,
4188                                h_out: out_shape[2].unwrap_static() as u32,
4189                                w_out: out_shape[3].unwrap_static() as u32,
4190                                kh: kernel_size[0] as u32,
4191                                kw: kernel_size[1] as u32,
4192                                sh: stride.first().copied().unwrap_or(1) as u32,
4193                                sw: stride.get(1).copied().unwrap_or(1) as u32,
4194                                ph: padding.first().copied().unwrap_or(0) as u32,
4195                                pw: padding.get(1).copied().unwrap_or(0) as u32,
4196                                op: op_id,
4197                                in_off: (arena.offset(node.inputs[0]) / 4) as u32,
4198                                out_off: (arena.offset(node.id) / 4) as u32,
4199                                _p0: 0,
4200                                _p1: 0,
4201                                _p2: 0,
4202                            };
4203                            schedule.push(Step::Pool2d { params: p });
4204                            let pk = pool2d_kernel(&dev.device);
4205                            let u = emit_uniform(std::mem::size_of::<Pool2dParams>());
4206                            let bg = bind_op_output_window(&dev.device, pk, &arena, node.id, &u);
4207                            uniforms.push(u);
4208                            bind_groups.push(bg);
4209                        }
4210                        (3, 5, 5) => {
4211                            let p = Pool3dParams {
4212                                n: in_shape[0].unwrap_static() as u32,
4213                                c: in_shape[1].unwrap_static() as u32,
4214                                d: in_shape[2].unwrap_static() as u32,
4215                                h: in_shape[3].unwrap_static() as u32,
4216                                w: in_shape[4].unwrap_static() as u32,
4217                                d_out: out_shape[2].unwrap_static() as u32,
4218                                h_out: out_shape[3].unwrap_static() as u32,
4219                                w_out: out_shape[4].unwrap_static() as u32,
4220                                kd: kernel_size[0] as u32,
4221                                kh: kernel_size[1] as u32,
4222                                kw: kernel_size[2] as u32,
4223                                sd: stride.first().copied().unwrap_or(1) as u32,
4224                                sh: stride.get(1).copied().unwrap_or(1) as u32,
4225                                sw: stride.get(2).copied().unwrap_or(1) as u32,
4226                                pd: padding.first().copied().unwrap_or(0) as u32,
4227                                ph: padding.get(1).copied().unwrap_or(0) as u32,
4228                                pw: padding.get(2).copied().unwrap_or(0) as u32,
4229                                op: op_id,
4230                                in_off: (arena.offset(node.inputs[0]) / 4) as u32,
4231                                out_off: (arena.offset(node.id) / 4) as u32,
4232                                _p0: 0,
4233                                _p1: 0,
4234                            };
4235                            schedule.push(Step::Pool3d { params: p });
4236                            let pk = pool3d_kernel(&dev.device);
4237                            let u = emit_uniform(std::mem::size_of::<Pool3dParams>());
4238                            let bg = bind_op_output_window(&dev.device, pk, &arena, node.id, &u);
4239                            uniforms.push(u);
4240                            bind_groups.push(bg);
4241                        }
                            // Unsupported rank combination: fail at compile time.
4242                        (k, n, m) => panic!(
4243                            "rlx-wgpu Pool: kernel-rank {k} with input rank {n} / \
4244                             output rank {m} not supported (use 1D/2D/3D NCHW)"
4245                        ),
4246                    }
4247                }
4248
4249                Op::Conv {
4250                    kernel_size,
4251                    stride,
4252                    padding,
4253                    dilation,
4254                    groups,
4255                } => {
                        // Lowers convolution to a Conv1d / Conv2d / Conv3d step.
                        // Input 0 is the activation and input 1 the weight. The
                        // kernel is bound to a (base, size) window of the arena
                        // buffer that is negotiated below.
4256                    let in_id = node.inputs[0];
4257                    let w_id = node.inputs[1];
                        // The three tensors the window has to cover: output,
                        // input, weight.
4258                    let win_ids = [node.id, in_id, w_id];
4259                    let max_binding = dev.device.limits().max_storage_buffer_binding_size;
                        // `fits`: the arena span of the three tensors is within the
                        // device's storage-buffer binding limit.
4260                    let fits = arena_span_bytes(&arena, &win_ids) <= max_binding;
4261                    let mut scratch = arena.scratch_off as u64;
                        // arena_multi_op_window receives `schedule` and `scratch`
                        // mutably, so it may append steps / advance the scratch
                        // cursor — presumably to stage tensors; verify in helper.
4262                    let (mut base, mut size, param_anchor) = arena_multi_op_window(
4263                        &dev.device,
4264                        &arena,
4265                        &graph,
4266                        &param_offsets,
4267                        &mut schedule,
4268                        &mut scratch,
4269                        &win_ids,
4270                    );
4271                    arena_expand_bind_window(&arena, &win_ids, &mut base, &mut size, max_binding);
                        // Only when the span does not fit and the window is not
                        // anchored on a parameter is `base` re-derived so that the
                        // window covers the scratch cursor.
4272                    if !fits && !param_anchor {
4273                        base = arena_bind_window_covering_scratch_if_needed(
4274                            &arena, base, size, scratch,
4275                        );
4276                    }
                        // Resolve each tensor's offset for the kernel params. Each
                        // call may again mutate `schedule`, `scratch`, `base` and
                        // `size` (all passed by `&mut`), so the input → weight →
                        // output order is kept exactly as written.
4277                    let in_off = arena_off_in_bind_window(
4278                        &graph,
4279                        &param_offsets,
4280                        &dev.device,
4281                        &arena,
4282                        &mut schedule,
4283                        &mut scratch,
4284                        in_id,
4285                        &mut base,
4286                        &mut size,
4287                    );
4288                    let w_off = arena_off_in_bind_window(
4289                        &graph,
4290                        &param_offsets,
4291                        &dev.device,
4292                        &arena,
4293                        &mut schedule,
4294                        &mut scratch,
4295                        w_id,
4296                        &mut base,
4297                        &mut size,
4298                    );
4299                    let out_off = arena_off_in_bind_window(
4300                        &graph,
4301                        &param_offsets,
4302                        &dev.device,
4303                        &arena,
4304                        &mut schedule,
4305                        &mut scratch,
4306                        node.id,
4307                        &mut base,
4308                        &mut size,
4309                    );
4310
4311                    let in_shape = graph.node(in_id).shape.dims();
4312                    let w_shape = graph.node(w_id).shape.dims();
4313                    let out_shape = node.shape.dims();
                        // Per-axis accessors with defaults for missing entries:
                        // stride 1, padding 0, dilation 1.
4314                    let s = |i: usize| stride.get(i).copied().unwrap_or(1) as u32;
4315                    let p = |i: usize| padding.get(i).copied().unwrap_or(0) as u32;
4316                    let d = |i: usize| dilation.get(i).copied().unwrap_or(1) as u32;
                        // Dispatch on (kernel rank, input rank, weight rank, output
                        // rank). The weight shape is used only for its rank here;
                        // `c_out` comes from dim 1 of the output shape. Params are
                        // named p1/p2/p3 because `p` is the padding closure.
4317                    match (
4318                        kernel_size.len(),
4319                        in_shape.len(),
4320                        w_shape.len(),
4321                        out_shape.len(),
4322                    ) {
4323                        (1, 3, 3, 3) => {
4324                            let p1 = Conv1dParams {
4325                                n: in_shape[0].unwrap_static() as u32,
4326                                c_in: in_shape[1].unwrap_static() as u32,
4327                                c_out: out_shape[1].unwrap_static() as u32,
4328                                l: in_shape[2].unwrap_static() as u32,
4329                                l_out: out_shape[2].unwrap_static() as u32,
4330                                kl: kernel_size[0] as u32,
4331                                sl: s(0),
4332                                pl: p(0),
4333                                dl: d(0),
4334                                groups: *groups as u32,
4335                                in_off,
4336                                w_off,
4337                                out_off,
4338                                _p0: 0,
4339                                _p1: 0,
4340                                _p2: 0,
4341                            };
4342                            schedule.push(Step::Conv1d { params: p1 });
4343                            let ck = conv1d_kernel(&dev.device);
4344                            let u = emit_uniform(std::mem::size_of::<Conv1dParams>());
                                // Bind the negotiated arena window plus the uniform.
4345                            let bg = bind_two_buf0_window(
4346                                &dev.device,
4347                                ck,
4348                                &arena.buffer,
4349                                base,
4350                                size,
4351                                &u,
4352                            );
4353                            uniforms.push(u);
4354                            bind_groups.push(bg);
4355                        }
4356                        (2, 4, 4, 4) => {
4357                            let p2 = Conv2dParams {
4358                                n: in_shape[0].unwrap_static() as u32,
4359                                c_in: in_shape[1].unwrap_static() as u32,
4360                                c_out: out_shape[1].unwrap_static() as u32,
4361                                h: in_shape[2].unwrap_static() as u32,
4362                                w: in_shape[3].unwrap_static() as u32,
4363                                h_out: out_shape[2].unwrap_static() as u32,
4364                                w_out: out_shape[3].unwrap_static() as u32,
4365                                kh: kernel_size[0] as u32,
4366                                kw: kernel_size[1] as u32,
4367                                sh: s(0),
4368                                sw: s(1),
4369                                ph: p(0),
4370                                pw: p(1),
4371                                dh: d(0),
4372                                dw: d(1),
4373                                groups: *groups as u32,
4374                                in_off,
4375                                w_off,
4376                                out_off,
4377                            };
4378                            schedule.push(Step::Conv2d { params: p2 });
4379                            let ck = conv2d_kernel(&dev.device);
4380                            let u = emit_uniform(std::mem::size_of::<Conv2dParams>());
4381                            let bg = bind_two_buf0_window(
4382                                &dev.device,
4383                                ck,
4384                                &arena.buffer,
4385                                base,
4386                                size,
4387                                &u,
4388                            );
4389                            uniforms.push(u);
4390                            bind_groups.push(bg);
4391                        }
4392                        (3, 5, 5, 5) => {
4393                            let p3 = Conv3dParams {
4394                                n: in_shape[0].unwrap_static() as u32,
4395                                c_in: in_shape[1].unwrap_static() as u32,
4396                                c_out: out_shape[1].unwrap_static() as u32,
4397                                d: in_shape[2].unwrap_static() as u32,
4398                                h: in_shape[3].unwrap_static() as u32,
4399                                w: in_shape[4].unwrap_static() as u32,
4400                                d_out: out_shape[2].unwrap_static() as u32,
4401                                h_out: out_shape[3].unwrap_static() as u32,
4402                                w_out: out_shape[4].unwrap_static() as u32,
4403                                kd: kernel_size[0] as u32,
4404                                kh: kernel_size[1] as u32,
4405                                kw: kernel_size[2] as u32,
4406                                sd: s(0),
4407                                sh: s(1),
4408                                sw: s(2),
4409                                pd: p(0),
4410                                ph: p(1),
4411                                pw: p(2),
4412                                dd: d(0),
4413                                dh: d(1),
4414                                dw: d(2),
4415                                groups: *groups as u32,
4416                                in_off,
4417                                w_off,
4418                                out_off,
4419                                _p0: 0,
4420                            };
4421                            schedule.push(Step::Conv3d { params: p3 });
4422                            let ck = conv3d_kernel(&dev.device);
4423                            let u = emit_uniform(std::mem::size_of::<Conv3dParams>());
4424                            let bg = bind_two_buf0_window(
4425                                &dev.device,
4426                                ck,
4427                                &arena.buffer,
4428                                base,
4429                                size,
4430                                &u,
4431                            );
4432                            uniforms.push(u);
4433                            bind_groups.push(bg);
4434                        }
                            // Unsupported rank combination: fail at compile time.
4435                        (k, ni, wi, mi) => panic!(
4436                            "rlx-wgpu Conv: rank kernel={k} in={ni} weight={wi} out={mi} \
4437                             not supported (use 1D/2D/3D NCHW)"
4438                        ),
4439                    }
4440                }
4441
4442                Op::Im2Col {
4443                    kernel_size,
4444                    stride,
4445                    padding,
4446                    dilation,
4447                } => {
                        // Lowers Im2Col to a single Step::Im2ColHost. Only a rank-2
                        // kernel on a rank-4 input is accepted; anything else
                        // panics. No uniform or bind group is pushed in this arm.
4448                    let x_shape = &graph.node(node.inputs[0]).shape;
4449                    if kernel_size.len() != 2 || x_shape.rank() != 4 {
4450                        panic!("rlx-wgpu Im2Col: 2D NCHW only");
4451                    }
                        // A non-static batch dimension is recorded as 0, whereas
                        // the other three dims go through `unwrap_static`.
                        // NOTE(review): presumably 0 is a placeholder resolved when
                        // the step runs — confirm Im2ColHost handles n == 0.
4452                    let n = match x_shape.dim(0) {
4453                        rlx_ir::shape::Dim::Static(v) => v as u32,
4454                        _ => 0,
4455                    };
4456                    let c_in = x_shape.dim(1).unwrap_static() as u32;
4457                    let h = x_shape.dim(2).unwrap_static() as u32;
4458                    let w = x_shape.dim(3).unwrap_static() as u32;
4459                    let kh = kernel_size[0] as u32;
4460                    let kw = kernel_size[1] as u32;
                        // Missing entries default to stride 1, padding 0,
                        // dilation 1.
4461                    let sh = stride.first().copied().unwrap_or(1) as u32;
4462                    let sw = stride.get(1).copied().unwrap_or(1) as u32;
4463                    let ph = padding.first().copied().unwrap_or(0) as u32;
4464                    let pw = padding.get(1).copied().unwrap_or(0) as u32;
4465                    let dh = dilation.first().copied().unwrap_or(1) as u32;
4466                    let dw_dil = dilation.get(1).copied().unwrap_or(1) as u32;
                        // Output extents are recomputed from the input geometry via
                        // conv2d_spatial_output rather than read from node.shape.
4467                    let h_out = rlx_ir::shape::conv2d_spatial_output(
4468                        h as usize,
4469                        kh as usize,
4470                        sh as usize,
4471                        ph as usize,
4472                        dh as usize,
4473                    ) as u32;
4474                    let w_out = rlx_ir::shape::conv2d_spatial_output(
4475                        w as usize,
4476                        kw as usize,
4477                        sw as usize,
4478                        pw as usize,
4479                        dw_dil as usize,
4480                    ) as u32;
                        // Offsets here are raw arena byte offsets (not divided
                        // by 4).
4481                    schedule.push(Step::Im2ColHost {
4482                        x_byte_off: arena.offset(node.inputs[0]) as u32,
4483                        col_byte_off: arena.offset(node.id) as u32,
4484                        n,
4485                        c_in,
4486                        h,
4487                        w,
4488                        h_out,
4489                        w_out,
4490                        kh,
4491                        kw,
4492                        sh,
4493                        sw,
4494                        ph,
4495                        pw,
4496                        dh,
4497                        dw_dil,
4498                    });
4499                }
4500
4501                Op::Cumsum { axis, exclusive } => {
                        // Lowers a cumulative sum to Step::Cumsum. Only the last
                        // axis is supported (axis == -1 or axis == rank - 1); any
                        // other axis panics at compile time.
4502                    let in_id = node.inputs[0];
4503                    let in_shape = graph.node(in_id).shape.dims();
                        // Assumes rank >= 1: `len() - 1` underflows on a rank-0
                        // input.
4504                    let last = (in_shape.len() - 1) as i32;
4505                    if *axis != -1 && *axis != last {
4506                        panic!("rlx-wgpu Cumsum: only last-axis wired (got axis={axis})");
4507                    }
                        // `inner` is the scanned (last) dimension; `outer` is the
                        // element count divided by it. `.max(1)` keeps the division
                        // defined when `inner` is 0.
4508                    let inner = in_shape[in_shape.len() - 1].unwrap_static() as u32;
4509                    let total: u32 = in_shape.iter().map(|d| d.unwrap_static() as u32).product();
4510                    let outer = total / inner.max(1);
                        // Offsets are arena byte offsets divided by 4; `exclusive`
                        // is passed as a 0/1 flag.
4511                    let p = CumsumParams {
4512                        outer,
4513                        inner,
4514                        in_off: (arena.offset(in_id) / 4) as u32,
4515                        out_off: (arena.offset(node.id) / 4) as u32,
4516                        exclusive: if *exclusive { 1 } else { 0 },
4517                        _p0: 0,
4518                        _p1: 0,
4519                        _p2: 0,
4520                    };
4521                    schedule.push(Step::Cumsum { params: p });
4522                    let ck2 = cumsum_kernel(&dev.device);
4523                    let u = emit_uniform(std::mem::size_of::<CumsumParams>());
4524                    let bg = bind_op_output_window(&dev.device, ck2, &arena, node.id, &u);
4525                    uniforms.push(u);
4526                    bind_groups.push(bg);
4527                }
4528                Op::Fft { inverse, norm } => {
                        // Lowers an FFT to either Step::FftGpu or Step::FftHost,
                        // depending on the eligibility check below.
4529                    let in_id = node.inputs[0];
4530                    let in_shape = graph.node(in_id).shape.clone();
4531                    let meta = rlx_ir::fft::fft_meta(&in_shape);
4532                    let dtype = in_shape.dtype();
                        // GPU path requires gpu_fft_native_eligible(dtype, n) and a
                        // transform length of at least 2.
4533                    let use_gpu = rlx_ir::fft::gpu_fft_native_eligible(dtype, meta.n_complex)
4534                        && meta.n_complex >= 2;
                        // Normalisation scale is computed up front; it is used only
                        // by the GPU step (the host step carries `norm.tag()`).
4535                    let scale = norm.output_scale(meta.n_complex, *inverse) as f32;
4536                    if use_gpu {
                            // GPU step: offsets are byte offsets divided by 4, and
                            // one FftGpuResources is appended to `fft_gpu_steps`
                            // for each FftGpu step pushed here.
4537                        schedule.push(Step::FftGpu {
4538                            src_off: (arena.offset(in_id) / 4) as u32,
4539                            dst_off: (arena.offset(node.id) / 4) as u32,
4540                            outer: meta.outer as u32,
4541                            n: meta.n_complex as u32,
4542                            inverse: if *inverse { 1 } else { 0 },
4543                            norm_scale: scale,
4544                        });
4545                        fft_gpu_steps.push(crate::fft_dispatch::FftGpuResources::new(
4546                            &dev.device,
4547                            &arena.buffer,
4548                        ));
4549                    } else {
                            // Host step: raw byte offsets plus norm and dtype tags;
                            // no uniform or bind group is pushed.
4550                        schedule.push(Step::FftHost {
4551                            src_byte_off: arena.offset(in_id) as u32,
4552                            dst_byte_off: arena.offset(node.id) as u32,
4553                            outer: meta.outer as u32,
4554                            n_complex: meta.n_complex as u32,
4555                            inverse: *inverse,
4556                            norm_tag: norm.tag(),
4557                            dtype_tag: fft_dtype_tag(dtype),
4558                        });
4559                    }
4560                }
4561                Op::WelchPeaks { k, n_segments } => {
                        // Lowers WelchPeaks to Step::WelchPeaksGpu when the GPU
                        // eligibility check passes, otherwise Step::WelchPeaksHost.
4562                    let spec_shape = graph.node(node.inputs[0]).shape.clone();
                        // A metadata error is fatal and panics at compile time.
4563                    let meta = rlx_ir::audio::welch_peaks_meta(&spec_shape, *k, *n_segments)
4564                        .unwrap_or_else(|e| panic!("Op::WelchPeaks: {e}"));
                        // An error from the eligibility check is not fatal: it is
                        // treated as "not eligible" and the host step is used.
4565                    let use_gpu = rlx_ir::audio::welch_peaks_gpu_native_eligible(
4566                        &spec_shape,
4567                        *k,
4568                        *n_segments,
4569                    )
4570                    .unwrap_or(false);
4571                    if use_gpu {
                            // GPU step: offsets are arena byte offsets divided
                            // by 4; one uniform and one bind group are pushed.
4572                        let p = WelchPeaksGpuParams {
4573                            spec_off: (arena.offset(node.inputs[0]) / 4) as u32,
4574                            dst_off: (arena.offset(node.id) / 4) as u32,
4575                            welch_batch: meta.welch_batch as u32,
4576                            n_fft: meta.n_fft as u32,
4577                            n_segments: meta.n_segments as u32,
4578                            k: meta.k as u32,
4579                            n_bins: meta.n_bins as u32,
4580                            _p0: 0,
4581                            _p1: 0,
4582                        };
4583                        schedule.push(Step::WelchPeaksGpu { params: p });
4584                        let wk = welch_peaks_gpu_kernel(&dev.device);
4585                        let u = emit_uniform(std::mem::size_of::<WelchPeaksGpuParams>());
4586                        let bg = bind_op_output_window(&dev.device, wk, &arena, node.id, &u);
4587                        uniforms.push(u);
4588                        bind_groups.push(bg);
4589                    } else {
                            // Host step: raw byte offsets; `n_bins` is not
                            // forwarded and no uniform or bind group is pushed.
4590                        schedule.push(Step::WelchPeaksHost {
4591                            spec_byte_off: arena.offset(node.inputs[0]) as u32,
4592                            dst_byte_off: arena.offset(node.id) as u32,
4593                            welch_batch: meta.welch_batch as u32,
4594                            n_fft: meta.n_fft as u32,
4595                            n_segments: meta.n_segments as u32,
4596                            k: meta.k as u32,
4597                        });
4598                    }
4599                }
4600                Op::LogMel => {
4601                    let spec_shape = graph.node(node.inputs[0]).shape.clone();
4602                    let filt_shape = graph.node(node.inputs[1]).shape.clone();
4603                    let meta = rlx_ir::audio::log_mel_meta(&spec_shape, &filt_shape)
4604                        .unwrap_or_else(|e| panic!("Op::LogMel: {e}"));
4605                    schedule.push(Step::LogMelHost {
4606                        spec_byte_off: arena.offset(node.inputs[0]) as u32,
4607                        filt_byte_off: arena.offset(node.inputs[1]) as u32,
4608                        dst_byte_off: arena.offset(node.id) as u32,
4609                        outer: meta.outer as u32,
4610                        n_fft: meta.n_fft as u32,
4611                        n_bins: meta.n_bins as u32,
4612                        n_mels: meta.n_mels as u32,
4613                    });
4614                }
4615                Op::LogMelBackward => {
4616                    let spec_shape = graph.node(node.inputs[0]).shape.clone();
4617                    let filt_shape = graph.node(node.inputs[1]).shape.clone();
4618                    let meta = rlx_ir::audio::log_mel_meta(&spec_shape, &filt_shape)
4619                        .unwrap_or_else(|e| panic!("Op::LogMelBackward: {e}"));
4620                    schedule.push(Step::LogMelBackwardHost {
4621                        spec_byte_off: arena.offset(node.inputs[0]) as u32,
4622                        filt_byte_off: arena.offset(node.inputs[1]) as u32,
4623                        dy_byte_off: arena.offset(node.inputs[2]) as u32,
4624                        dst_byte_off: arena.offset(node.id) as u32,
4625                        outer: meta.outer as u32,
4626                        n_fft: meta.n_fft as u32,
4627                        n_bins: meta.n_bins as u32,
4628                        n_mels: meta.n_mels as u32,
4629                    });
4630                }
4631                Op::SelectiveScan { state_size } => {
4632                    if *state_size > 256 {
4633                        panic!(
4634                            "rlx-wgpu SelectiveScan: state_size {} exceeds compile-time \
4635                                cap of 256 (kernel uses fixed-size private array)",
4636                            state_size
4637                        );
4638                    }
4639                    let x_id = node.inputs[0];
4640                    let dt_id = node.inputs[1];
4641                    let a_id = node.inputs[2];
4642                    let b_id = node.inputs[3];
4643                    let c_id = node.inputs[4];
4644                    let in_dims = graph.node(x_id).shape.dims();
4645                    let seq = in_dims[1].unwrap_static() as u32;
4646                    let p = SelectiveScanParams {
4647                        batch: in_dims[0].unwrap_static() as u32,
4648                        seq,
4649                        hidden: in_dims[2].unwrap_static() as u32,
4650                        state_size: *state_size as u32,
4651                        x_off: (arena.offset(x_id) / 4) as u32,
4652                        delta_off: (arena.offset(dt_id) / 4) as u32,
4653                        a_off: (arena.offset(a_id) / 4) as u32,
4654                        b_off: (arena.offset(b_id) / 4) as u32,
4655                        c_off: (arena.offset(c_id) / 4) as u32,
4656                        out_off: (arena.offset(node.id) / 4) as u32,
4657                        // PLAN L1: full-extent stride; safe under
4658                        // active-extent scaling of params.seq.
4659                        seq_stride: seq,
4660                        _p1: 0,
4661                        _p2: 0,
4662                        _p3: 0,
4663                        _p4: 0,
4664                        _p5: 0,
4665                    };
4666                    schedule.push(Step::SelectiveScan { params: p });
4667                    let ssk = selective_scan_kernel(&dev.device);
4668                    let u = emit_uniform(std::mem::size_of::<SelectiveScanParams>());
4669                    let bg = bind_op_output_window(&dev.device, ssk, &arena, node.id, &u);
4670                    uniforms.push(u);
4671                    bind_groups.push(bg);
4672                }
4673                Op::GatedDeltaNet {
4674                    state_size,
4675                    carry_state,
4676                } => {
4677                    if *state_size > rlx_cpu::gdn::GDN_MAX_STATE {
4678                        panic!(
4679                            "rlx-wgpu GatedDeltaNet: state_size {state_size} > {}",
4680                            rlx_cpu::gdn::GDN_MAX_STATE
4681                        );
4682                    }
4683                    let q_id = node.inputs[0];
4684                    let q_shape = &graph.node(q_id).shape;
4685                    let state_off = if *carry_state {
4686                        arena.offset(node.inputs[5])
4687                    } else {
4688                        0
4689                    };
4690                    schedule.push(Step::GatedDeltaNet {
4691                        q_byte_off: arena.offset(q_id) as u32,
4692                        k_byte_off: arena.offset(node.inputs[1]) as u32,
4693                        v_byte_off: arena.offset(node.inputs[2]) as u32,
4694                        g_byte_off: arena.offset(node.inputs[3]) as u32,
4695                        beta_byte_off: arena.offset(node.inputs[4]) as u32,
4696                        state_byte_off: state_off as u32,
4697                        dst_byte_off: arena.offset(node.id) as u32,
4698                        batch: q_shape.dim(0).unwrap_static() as u32,
4699                        seq: q_shape.dim(1).unwrap_static() as u32,
4700                        heads: q_shape.dim(2).unwrap_static() as u32,
4701                        state_size: *state_size as u32,
4702                        use_carry: *carry_state,
4703                    });
4704                    if gguf_host_pad.is_none() {
4705                        let bk = binary_kernel(&dev.device);
4706                        let u = emit_uniform(256);
4707                        gguf_host_pad = Some((
4708                            u.clone(),
4709                            bind_op_output_window(&dev.device, bk, &arena, node.id, &u),
4710                        ));
4711                    }
4712                    let (u, bg) = gguf_host_pad.as_ref().unwrap();
4713                    uniforms.push(u.clone());
4714                    bind_groups.push(bg.clone());
4715                }
4716                Op::Custom { name, attrs, .. } => match name.as_str() {
4717                    "llada2.group_limited_gate" => {
4718                        let sig_id = node.inputs[0];
4719                        let route_id = node.inputs[1];
4720                        let n_elems = graph.node(sig_id).shape.num_elements().unwrap() as u32;
4721                        let mut attr_buf = [0u8; 20];
4722                        let n = attrs.len().min(20);
4723                        attr_buf[..n].copy_from_slice(&attrs[..n]);
4724                        schedule.push(Step::Llada2GroupLimitedGate {
4725                            sig_byte_off: arena.offset(sig_id) as u32,
4726                            route_byte_off: arena.offset(route_id) as u32,
4727                            out_byte_off: arena.offset(node.id) as u32,
4728                            n_elems,
4729                            attrs: attr_buf,
4730                        });
4731                    }
4732                    "umap.knn" => {
4733                        let pw_id = node.inputs[0];
4734                        let pw_shape = graph.node(pw_id).shape.dims();
4735                        let n = pw_shape[0].unwrap_static() as u32;
4736                        let k = if attrs.len() >= 4 {
4737                            u32::from_le_bytes(attrs[..4].try_into().unwrap())
4738                        } else {
4739                            panic!("rlx-wgpu: umap.knn attrs missing k");
4740                        };
4741                        let pw_off = arena.offset(pw_id) as u32;
4742                        let out_off = arena.offset(node.id) as u32;
4743                        if n as usize >= crate::umap_knn_host::UMAP_KNN_GPU_MIN_N {
4744                            let p = UmapKnnParams {
4745                                n,
4746                                k,
4747                                pw_off: pw_off / 4,
4748                                out_off: out_off / 4,
4749                                _p0: 0,
4750                                _p1: 0,
4751                                _p2: 0,
4752                            };
4753                            schedule.push(Step::UmapKnn { params: p });
4754                            let uk = umap_knn_kernel(&dev.device);
4755                            let u = emit_uniform(std::mem::size_of::<UmapKnnParams>());
4756                            let bg = bind_op_output_window(&dev.device, uk, &arena, node.id, &u);
4757                            uniforms.push(u);
4758                            bind_groups.push(bg);
4759                        } else {
4760                            schedule.push(Step::UmapKnnHost {
4761                                pairwise_byte_off: pw_off,
4762                                out_byte_off: out_off,
4763                                n,
4764                                k,
4765                            });
4766                        }
4767                    }
4768                    other => panic!("rlx-wgpu: unsupported Op::Custom('{other}')"),
4769                },
4770                Op::GroupedMatMul => {
4771                    // Inputs: input [M, K], weight [E, K, N], expert_idx [M]
4772                    let in_id = node.inputs[0];
4773                    let w_id = node.inputs[1];
4774                    let idx_id = node.inputs[2];
4775                    let in_dims = graph.node(in_id).shape.dims();
4776                    let w_dims = graph.node(w_id).shape.dims();
4777                    let m = in_dims[0].unwrap_static() as u32;
4778                    let k = in_dims[1].unwrap_static() as u32;
4779                    let n = w_dims[2].unwrap_static() as u32;
4780                    let ne = w_dims[0].unwrap_static() as u32;
4781                    let p = GroupedMatmulParams {
4782                        m,
4783                        k,
4784                        n,
4785                        num_experts: ne,
4786                        in_off: (arena.offset(in_id) / 4) as u32,
4787                        w_off: (arena.offset(w_id) / 4) as u32,
4788                        idx_off: (arena.offset(idx_id) / 4) as u32,
4789                        out_off: (arena.offset(node.id) / 4) as u32,
4790                    };
4791                    schedule.push(Step::GroupedMatmul { params: p });
4792                    let gk = grouped_matmul_kernel(&dev.device);
4793                    let u = emit_uniform(std::mem::size_of::<GroupedMatmulParams>());
4794                    let bg = bind_op_output_window(&dev.device, gk, &arena, node.id, &u);
4795                    uniforms.push(u);
4796                    bind_groups.push(bg);
4797                }
4798                Op::DequantGroupedMatMul { scheme } => {
4799                    let in_id = node.inputs[0];
4800                    let w_id = node.inputs[1];
4801                    let idx_id = node.inputs[2];
4802                    let in_dims = graph.node(in_id).shape.dims();
4803                    let out_dims = node.shape.dims();
4804                    let m = in_dims[0].unwrap_static() as u32;
4805                    let k = in_dims[1].unwrap_static() as u32;
4806                    let n = out_dims[out_dims.len() - 1].unwrap_static() as u32;
4807                    let block_elems = scheme.gguf_block_size() as usize;
4808                    let block_bytes = scheme.gguf_block_bytes() as usize;
4809                    let slab_bytes = (k as usize * n as usize) / block_elems * block_bytes;
4810                    let total_bytes = graph.node(w_id).shape.num_elements().unwrap();
4811                    let ne = (total_bytes / slab_bytes.max(1)) as u32;
4812                    schedule.push(Step::DequantGroupedMatmulGguf {
4813                        m,
4814                        k,
4815                        n,
4816                        num_experts: ne,
4817                        scheme_id: crate::gguf_host::gguf_scheme_id(*scheme),
4818                        x_byte_off: arena.offset(in_id) as u32,
4819                        w_byte_off: arena.offset(w_id) as u32,
4820                        idx_byte_off: arena.offset(idx_id) as u32,
4821                        out_byte_off: arena.offset(node.id) as u32,
4822                    });
4823                    if gguf_host_pad.is_none() {
4824                        let bk = binary_kernel(&dev.device);
4825                        let u = emit_uniform(256);
4826                        gguf_host_pad = Some((
4827                            u.clone(),
4828                            bind_op_output_window(&dev.device, bk, &arena, node.id, &u),
4829                        ));
4830                    }
4831                    let (u, bg) = gguf_host_pad.as_ref().unwrap();
4832                    uniforms.push(u.clone());
4833                    bind_groups.push(bg.clone());
4834                }
4835                Op::TopK { k } => {
4836                    let in_id = node.inputs[0];
4837                    let in_dims = graph.node(in_id).shape.dims();
4838                    let inner = in_dims.last().unwrap().unwrap_static() as u32;
4839                    let outer: u32 = in_dims[..in_dims.len() - 1]
4840                        .iter()
4841                        .map(|d| d.unwrap_static() as u32)
4842                        .product::<u32>()
4843                        .max(1);
4844                    let p = TopKParams {
4845                        outer,
4846                        inner,
4847                        k: *k as u32,
4848                        in_off: (arena.offset(in_id) / 4) as u32,
4849                        out_off: (arena.offset(node.id) / 4) as u32,
4850                        _p0: 0,
4851                        _p1: 0,
4852                        _p2: 0,
4853                    };
4854                    schedule.push(Step::TopK { params: p });
4855                    let tk = topk_kernel(&dev.device);
4856                    let u = emit_uniform(std::mem::size_of::<TopKParams>());
4857                    let bg = bind_op_output_window(&dev.device, tk, &arena, node.id, &u);
4858                    uniforms.push(u);
4859                    bind_groups.push(bg);
4860                }
4861                Op::ScatterAdd => {
4862                    // Inputs: updates [num_updates, trailing], indices [num_updates].
4863                    // Output: [out_dim, trailing]. Implemented as two phases:
4864                    //   1. Zero `out_dim * trailing` slots.
4865                    //   2. CAS-loop atomic-accumulate `num_updates * trailing` updates.
4866                    let upd_id = node.inputs[0];
4867                    let idx_id = node.inputs[1];
4868                    let upd_dims = graph.node(upd_id).shape.dims();
4869                    let out_dims = node.shape.dims();
4870                    let num_updates = upd_dims[0].unwrap_static() as u32;
4871                    let trailing: u32 = upd_dims
4872                        .iter()
4873                        .skip(1)
4874                        .map(|d| d.unwrap_static() as u32)
4875                        .product::<u32>()
4876                        .max(1);
4877                    let out_dim = out_dims[0].unwrap_static() as u32;
4878                    let out_total = out_dim * trailing;
4879
4880                    let common = ScatterAddParams {
4881                        op: 0,
4882                        out_off: (arena.offset(node.id) / 4) as u32,
4883                        upd_off: (arena.offset(upd_id) / 4) as u32,
4884                        idx_off: (arena.offset(idx_id) / 4) as u32,
4885                        out_total,
4886                        num_updates,
4887                        trailing,
4888                        out_dim,
4889                    };
4890                    let sk = scatter_add_kernel(&dev.device);
4891
4892                    // Phase 0: zero.
4893                    schedule.push(Step::ScatterAdd { params: common });
4894                    let u0 = emit_uniform(std::mem::size_of::<ScatterAddParams>());
4895                    let bg0 = bind_op_output_window(&dev.device, sk, &arena, node.id, &u0);
4896                    uniforms.push(u0);
4897                    bind_groups.push(bg0);
4898
4899                    // Phase 1: accumulate.
4900                    let mut acc = common;
4901                    acc.op = 1;
4902                    schedule.push(Step::ScatterAdd { params: acc });
4903                    let u1 = emit_uniform(std::mem::size_of::<ScatterAddParams>());
4904                    let bg1 = bind_op_output_window(&dev.device, sk, &arena, node.id, &u1);
4905                    uniforms.push(u1);
4906                    bind_groups.push(bg1);
4907                }
4908                Op::FusedResidualLN { has_bias, eps } => {
4909                    // Inputs: [x, residual, [bias], gamma, beta].
4910                    let x_id = node.inputs[0];
4911                    let r_id = node.inputs[1];
4912                    let (bias_id, g_id, b_id) = if *has_bias {
4913                        (node.inputs[2], node.inputs[3], node.inputs[4])
4914                    } else {
4915                        (x_id, node.inputs[2], node.inputs[3]) // bias unused
4916                    };
4917                    let in_dims = node.shape.dims();
4918                    let inner = in_dims[in_dims.len() - 1].unwrap_static() as u32;
4919                    let total: u32 = in_dims.iter().map(|d| d.unwrap_static() as u32).product();
4920                    let outer = total / inner.max(1);
4921                    let p = FusedResidualLnParams {
4922                        outer,
4923                        inner,
4924                        in_off: (arena.offset(x_id) / 4) as u32,
4925                        residual_off: (arena.offset(r_id) / 4) as u32,
4926                        bias_off: (arena.offset(bias_id) / 4) as u32,
4927                        gamma_off: (arena.offset(g_id) / 4) as u32,
4928                        beta_off: (arena.offset(b_id) / 4) as u32,
4929                        out_off: (arena.offset(node.id) / 4) as u32,
4930                        eps_bits: eps.to_bits(),
4931                        has_bias: if *has_bias { 1 } else { 0 },
4932                        _p0: 0,
4933                        _p1: 0,
4934                    };
4935                    schedule.push(Step::FusedResidualLn { params: p });
4936                    let frk = fused_residual_ln_kernel(&dev.device);
4937                    let u = emit_uniform(std::mem::size_of::<FusedResidualLnParams>());
4938                    let bg = bind_op_output_window(&dev.device, frk, &arena, node.id, &u);
4939                    uniforms.push(u);
4940                    bind_groups.push(bg);
4941                }
4942                Op::FusedResidualRmsNorm { has_bias, eps } => {
4943                    let x_id = node.inputs[0];
4944                    let r_id = node.inputs[1];
4945                    let (bias_id, g_id, b_id) = if *has_bias {
4946                        (node.inputs[2], node.inputs[3], node.inputs[4])
4947                    } else {
4948                        (x_id, node.inputs[2], node.inputs[3])
4949                    };
4950                    let in_dims = node.shape.dims();
4951                    let inner = in_dims[in_dims.len() - 1].unwrap_static() as u32;
4952                    let total: u32 = in_dims.iter().map(|d| d.unwrap_static() as u32).product();
4953                    let outer = total / inner.max(1);
4954                    let p = FusedResidualRmsNormParams {
4955                        outer,
4956                        inner,
4957                        in_off: (arena.offset(x_id) / 4) as u32,
4958                        residual_off: (arena.offset(r_id) / 4) as u32,
4959                        bias_off: (arena.offset(bias_id) / 4) as u32,
4960                        gamma_off: (arena.offset(g_id) / 4) as u32,
4961                        beta_off: (arena.offset(b_id) / 4) as u32,
4962                        out_off: (arena.offset(node.id) / 4) as u32,
4963                        eps_bits: eps.to_bits(),
4964                        has_bias: if *has_bias { 1 } else { 0 },
4965                        _p0: 0,
4966                        _p1: 0,
4967                    };
4968                    schedule.push(Step::FusedResidualRmsNorm { params: p });
4969                    let frk = fused_residual_rms_norm_kernel(&dev.device);
4970                    let u = emit_uniform(std::mem::size_of::<FusedResidualRmsNormParams>());
4971                    let bg = bind_op_output_window(&dev.device, frk, &arena, node.id, &u);
4972                    uniforms.push(u);
4973                    bind_groups.push(bg);
4974                }
4975                Op::DequantMatMul { scheme } => {
4976                    use rlx_ir::QuantScheme;
4977                    let x_id = node.inputs[0];
4978                    let w_id = node.inputs[1];
4979                    let out_dims = node.shape.dims();
4980                    let x_dims = graph.node(x_id).shape.dims();
4981                    let m = out_dims[0].unwrap_static() as u32;
4982                    let n = out_dims[1].unwrap_static() as u32;
4983                    let k = x_dims[1].unwrap_static() as u32;
4984                    if scheme.is_gguf() {
4985                        schedule.push(Step::DequantMatmulGguf {
4986                            m,
4987                            k,
4988                            n,
4989                            scheme_id: crate::gguf_host::gguf_scheme_id(*scheme),
4990                            x_byte_off: arena.offset(x_id) as u32,
4991                            w_byte_off: arena.offset(w_id) as u32,
4992                            out_byte_off: arena.offset(node.id) as u32,
4993                        });
4994                        if gguf_host_pad.is_none() {
4995                            let bk = binary_kernel(&dev.device);
4996                            let u = emit_uniform(256);
4997                            gguf_host_pad = Some((
4998                                u.clone(),
4999                                bind_op_output_window(&dev.device, bk, &arena, node.id, &u),
5000                            ));
5001                        }
5002                        let (u, bg) = gguf_host_pad.as_ref().unwrap();
5003                        uniforms.push(u.clone());
5004                        bind_groups.push(bg.clone());
5005                    } else {
5006                        let (block_size, scheme_id) = match scheme {
5007                            QuantScheme::Int8Block { block_size } => (*block_size, 0u32),
5008                            QuantScheme::Int8BlockAsym { block_size } => (*block_size, 1u32),
5009                            QuantScheme::Int4Block { block_size } => (*block_size, 2u32),
5010                            QuantScheme::Fp8E4m3 => (1, 3u32),
5011                            QuantScheme::Fp8E5m2 => (1, 4u32),
5012                            QuantScheme::Nvfp4Block => (rlx_ir::NVFP4_GROUP_SIZE as u32, 5u32),
5013                            other => panic!("rlx-wgpu DequantMatMul: unsupported scheme {other:?}"),
5014                        };
5015                        let scale_id = node.inputs[2];
5016                        let zp_id = node.inputs[3];
5017                        let p = DequantMatmulParams {
5018                            m,
5019                            k,
5020                            n,
5021                            block_size,
5022                            scheme_id,
5023                            x_off: (arena.offset(x_id) / 4) as u32,
5024                            w_off: (arena.offset(w_id) / 4) as u32,
5025                            scale_off: (arena.offset(scale_id) / 4) as u32,
5026                            zp_off: (arena.offset(zp_id) / 4) as u32,
5027                            out_off: (arena.offset(node.id) / 4) as u32,
5028                            _p0: 0,
5029                            _p1: 0,
5030                        };
5031                        schedule.push(Step::DequantMatmul { params: p });
5032                        let dk = dequant_matmul_kernel(&dev.device);
5033                        let u = emit_uniform(std::mem::size_of::<DequantMatmulParams>());
5034                        let bg = bind_op_output_window(&dev.device, dk, &arena, node.id, &u);
5035                        uniforms.push(u);
5036                        bind_groups.push(bg);
5037                    }
5038                }
5039                Op::RmsNormBackwardInput { eps, .. }
5040                | Op::RmsNormBackwardGamma { eps, .. }
5041                | Op::RmsNormBackwardBeta { eps, .. } => {
5042                    let x_shape = &graph.node(node.inputs[0]).shape;
5043                    let h = x_shape.dim(x_shape.rank() - 1).unwrap_static() as u32;
5044                    let rows = (x_shape.num_elements().unwrap() / h.max(1) as usize) as u32;
5045                    let foff = |i: usize| (arena.offset(node.inputs[i]) / 4) as u32;
5046                    let wrt = match &node.op {
5047                        Op::RmsNormBackwardInput { .. } => 0u32,
5048                        Op::RmsNormBackwardGamma { .. } => 1u32,
5049                        Op::RmsNormBackwardBeta { .. } => 2u32,
5050                        _ => unreachable!(),
5051                    };
5052                    let p = RmsNormBwdParams {
5053                        outer: rows,
5054                        inner: h,
5055                        x_off: foff(0),
5056                        gamma_off: foff(1),
5057                        beta_off: foff(2),
5058                        dy_off: foff(3),
5059                        out_off: (arena.offset(node.id) / 4) as u32,
5060                        eps_bits: eps.to_bits(),
5061                        wrt,
5062                    };
5063                    let rk = if wrt == 0 {
5064                        rms_norm_backward_kernel(&dev.device)
5065                    } else {
5066                        rms_norm_backward_param_kernel(&dev.device)
5067                    };
5068                    let u = emit_uniform(std::mem::size_of::<RmsNormBwdParams>());
5069                    let bg = bind_op_output_window(&dev.device, rk, &arena, node.id, &u);
5070                    match &node.op {
5071                        Op::RmsNormBackwardInput { .. } => {
5072                            schedule.push(Step::RmsNormBackwardInput { params: p });
5073                        }
5074                        Op::RmsNormBackwardGamma { .. } => {
5075                            schedule.push(Step::RmsNormBackwardGamma { params: p });
5076                        }
5077                        Op::RmsNormBackwardBeta { .. } => {
5078                            schedule.push(Step::RmsNormBackwardBeta { params: p });
5079                        }
5080                        _ => unreachable!(),
5081                    }
5082                    uniforms.push(u);
5083                    bind_groups.push(bg);
5084                }
5085                Op::LayerNormBackwardInput { eps, .. } => {
5086                    let x_shape = &graph.node(node.inputs[0]).shape;
5087                    let h = x_shape.dim(x_shape.rank() - 1).unwrap_static() as u32;
5088                    let rows = (x_shape.num_elements().unwrap() / h.max(1) as usize) as u32;
5089                    let p = LayerNormBwdParams {
5090                        outer: rows,
5091                        inner: h,
5092                        x_off: (arena.offset(node.inputs[0]) / 4) as u32,
5093                        gamma_off: (arena.offset(node.inputs[1]) / 4) as u32,
5094                        dy_off: (arena.offset(node.inputs[2]) / 4) as u32,
5095                        out_off: (arena.offset(node.id) / 4) as u32,
5096                        eps_bits: eps.to_bits(),
5097                        scratch_off: 0,
5098                    };
5099                    let rk = layer_norm_backward_input_kernel(&dev.device);
5100                    let u = emit_uniform(std::mem::size_of::<LayerNormBwdParams>());
5101                    let bg = bind_op_output_window(&dev.device, rk, &arena, node.id, &u);
5102                    schedule.push(Step::LayerNormBackwardInput { params: p });
5103                    uniforms.push(u);
5104                    bind_groups.push(bg);
5105                }
5106                Op::LayerNormBackwardGamma { eps, .. } => {
5107                    // Inputs: [x, dy] — gamma_off is unused for this op.
5108                    // Emit two steps: a multi-workgroup partial that
5109                    // writes per-chunk dgamma to the tail scratch zone,
5110                    // and a single-workgroup reduce that sums chunks
5111                    // into the final dgamma slot.
5112                    let x_shape = &graph.node(node.inputs[0]).shape;
5113                    let h = x_shape.dim(x_shape.rank() - 1).unwrap_static() as u32;
5114                    let rows = (x_shape.num_elements().unwrap() / h.max(1) as usize) as u32;
5115                    const ROWS_PER_WG: u32 = 16;
5116                    let num_workgroups = rows.div_ceil(ROWS_PER_WG.max(1));
5117                    let scratch_off_words = (arena.scratch_off / 4) as u32;
5118                    let partial_params = LayerNormBwdParams {
5119                        outer: rows,
5120                        inner: h,
5121                        x_off: (arena.offset(node.inputs[0]) / 4) as u32,
5122                        gamma_off: 0,
5123                        dy_off: (arena.offset(node.inputs[1]) / 4) as u32,
5124                        out_off: 0, // unused by the partial kernel
5125                        eps_bits: eps.to_bits(),
5126                        scratch_off: scratch_off_words,
5127                    };
5128                    let reduce_params = LayerNormBwdParams {
5129                        // `outer` for the reduce kernel carries the
5130                        // number of partial chunks we just emitted.
5131                        outer: num_workgroups,
5132                        inner: h,
5133                        x_off: 0,
5134                        gamma_off: 0,
5135                        dy_off: 0,
5136                        out_off: (arena.offset(node.id) / 4) as u32,
5137                        eps_bits: eps.to_bits(),
5138                        scratch_off: scratch_off_words,
5139                    };
5140                    let p_k = layer_norm_backward_gamma_partial_kernel(&dev.device);
5141                    let r_k = layer_norm_backward_gamma_reduce_kernel(&dev.device);
5142                    let p_u = emit_uniform(std::mem::size_of::<LayerNormBwdParams>());
5143                    let r_u = emit_uniform(std::mem::size_of::<LayerNormBwdParams>());
5144                    let p_bg = bind_op_output_window(&dev.device, p_k, &arena, node.id, &p_u);
5145                    let r_bg = bind_op_output_window(&dev.device, r_k, &arena, node.id, &r_u);
5146                    schedule.push(Step::LayerNormBackwardGammaPartial {
5147                        params: partial_params,
5148                        num_workgroups,
5149                    });
5150                    schedule.push(Step::LayerNormBackwardGammaReduce {
5151                        params: reduce_params,
5152                    });
5153                    uniforms.push(p_u);
5154                    uniforms.push(r_u);
5155                    bind_groups.push(p_bg);
5156                    bind_groups.push(r_bg);
5157                }
5158                Op::RopeBackward { head_dim, n_rot } => {
5159                    let dy_shape = &graph.node(node.inputs[0]).shape;
5160                    let (batch, seq, hidden) = if dy_shape.rank() >= 3 {
5161                        (
5162                            dy_shape.dim(0).unwrap_static() as u32,
5163                            dy_shape.dim(1).unwrap_static() as u32,
5164                            dy_shape.dim(2).unwrap_static() as u32,
5165                        )
5166                    } else {
5167                        (
5168                            1,
5169                            dy_shape.dim(0).unwrap_static() as u32,
5170                            dy_shape.dim(1).unwrap_static() as u32,
5171                        )
5172                    };
5173                    let cos_len = graph.node(node.inputs[1]).shape.num_elements().unwrap() as u32;
5174                    let p = RopeBwdParams {
5175                        batch,
5176                        seq,
5177                        hidden,
5178                        head_dim: *head_dim as u32,
5179                        n_rot: *n_rot as u32,
5180                        dy_off: (arena.offset(node.inputs[0]) / 4) as u32,
5181                        cos_off: (arena.offset(node.inputs[1]) / 4) as u32,
5182                        sin_off: (arena.offset(node.inputs[2]) / 4) as u32,
5183                        dx_off: (arena.offset(node.id) / 4) as u32,
5184                        cos_len,
5185                    };
5186                    let rk = rope_backward_kernel(&dev.device);
5187                    let u = emit_uniform(std::mem::size_of::<RopeBwdParams>());
5188                    let bg = bind_op_output_window(&dev.device, rk, &arena, node.id, &u);
5189                    schedule.push(Step::RopeBackward { params: p });
5190                    uniforms.push(u);
5191                    bind_groups.push(bg);
5192                }
5193                Op::CumsumBackward { exclusive, .. } => {
5194                    let dy_shape = &graph.node(node.inputs[0]).shape;
5195                    let cols = dy_shape.dim(dy_shape.rank() - 1).unwrap_static() as u32;
5196                    let rows = (dy_shape.num_elements().unwrap() / cols.max(1) as usize) as u32;
5197                    let p = CumsumBwdParams {
5198                        outer: rows,
5199                        inner: cols,
5200                        dy_off: (arena.offset(node.inputs[0]) / 4) as u32,
5201                        dx_off: (arena.offset(node.id) / 4) as u32,
5202                        exclusive: if *exclusive { 1 } else { 0 },
5203                        _p0: 0,
5204                        _p1: 0,
5205                        _p2: 0,
5206                    };
5207                    let ck = cumsum_backward_kernel(&dev.device);
5208                    let u = emit_uniform(std::mem::size_of::<CumsumBwdParams>());
5209                    let bg = bind_op_output_window(&dev.device, ck, &arena, node.id, &u);
5210                    schedule.push(Step::CumsumBackward { params: p });
5211                    uniforms.push(u);
5212                    bind_groups.push(bg);
5213                }
5214                Op::GatherBackward { .. } => {
5215                    let dy_shape = &graph.node(node.inputs[0]).shape;
5216                    let idx_shape = &graph.node(node.inputs[1]).shape;
5217                    let out_shape = &node.shape;
5218                    let rank = out_shape.rank();
5219                    let axis = match &node.op {
5220                        Op::GatherBackward { axis } => *axis,
5221                        _ => 0,
5222                    };
5223                    let axis_u = if axis < 0 {
5224                        (rank as i32 + axis) as usize
5225                    } else {
5226                        axis as usize
5227                    };
5228                    let outer: usize = (0..axis_u)
5229                        .map(|i| dy_shape.dim(i).unwrap_static())
5230                        .product::<usize>()
5231                        .max(1);
5232                    let num_idx = idx_shape.dim(axis_u).unwrap_static();
5233                    let trailing: usize = (axis_u + 1..dy_shape.rank())
5234                        .map(|i| dy_shape.dim(i).unwrap_static())
5235                        .product::<usize>()
5236                        .max(1);
5237                    let axis_dim = out_shape.dim(axis_u).unwrap_static();
5238                    let p = GatherBwdParams {
5239                        outer: outer as u32,
5240                        axis_dim: axis_dim as u32,
5241                        num_idx: num_idx as u32,
5242                        trailing: trailing as u32,
5243                        dy_off: (arena.offset(node.inputs[0]) / 4) as u32,
5244                        idx_off: (arena.offset(node.inputs[1]) / 4) as u32,
5245                        dst_off: (arena.offset(node.id) / 4) as u32,
5246                        _p0: 0,
5247                    };
5248                    let zk = gather_backward_zero_kernel(&dev.device);
5249                    let u = emit_uniform(std::mem::size_of::<GatherBwdParams>());
5250                    let bg = bind_op_output_window(&dev.device, zk, &arena, node.id, &u);
5251                    schedule.push(Step::GatherBackward { params: p });
5252                    uniforms.push(u);
5253                    bind_groups.push(bg);
5254                }
5255                #[cfg(feature = "splat")]
5256                Op::GaussianSplatRender {
5257                    width,
5258                    height,
5259                    tile_size,
5260                    radius_scale,
5261                    alpha_cutoff,
5262                    max_splat_steps,
5263                    transmittance_threshold,
5264                    max_list_entries,
5265                } => {
5266                    let elem_len = |id: NodeId| -> u32 {
5267                        graph.node(id).shape.num_elements().unwrap_or(0) as u32
5268                    };
5269                    schedule.push(Step::GaussianSplatRender {
5270                        positions_byte_off: arena.offset(node.inputs[0]) as u32,
5271                        positions_len: elem_len(node.inputs[0]),
5272                        scales_byte_off: arena.offset(node.inputs[1]) as u32,
5273                        scales_len: elem_len(node.inputs[1]),
5274                        rotations_byte_off: arena.offset(node.inputs[2]) as u32,
5275                        rotations_len: elem_len(node.inputs[2]),
5276                        opacities_byte_off: arena.offset(node.inputs[3]) as u32,
5277                        opacities_len: elem_len(node.inputs[3]),
5278                        colors_byte_off: arena.offset(node.inputs[4]) as u32,
5279                        colors_len: elem_len(node.inputs[4]),
5280                        sh_coeffs_byte_off: arena.offset(node.inputs[5]) as u32,
5281                        sh_coeffs_len: elem_len(node.inputs[5]),
5282                        meta_byte_off: arena.offset(node.inputs[6]) as u32,
5283                        dst_byte_off: arena.offset(node.id) as u32,
5284                        dst_len: node.shape.num_elements().unwrap_or(0) as u32,
5285                        width: *width,
5286                        height: *height,
5287                        tile_size: *tile_size,
5288                        radius_scale: *radius_scale,
5289                        alpha_cutoff: *alpha_cutoff,
5290                        max_splat_steps: *max_splat_steps,
5291                        transmittance_threshold: *transmittance_threshold,
5292                        max_list_entries: *max_list_entries,
5293                    });
5294                }
5295
5296                #[cfg(feature = "splat")]
5297                Op::GaussianSplatRenderBackward {
5298                    width,
5299                    height,
5300                    tile_size,
5301                    radius_scale,
5302                    alpha_cutoff,
5303                    max_splat_steps,
5304                    transmittance_threshold,
5305                    max_list_entries,
5306                    loss_grad_clip,
5307                    sh_band,
5308                    max_anisotropy,
5309                } => {
5310                    let elem_len = |id: NodeId| -> u32 {
5311                        graph.node(id).shape.num_elements().unwrap_or(0) as u32
5312                    };
5313                    schedule.push(Step::GaussianSplatRenderBackward {
5314                        positions_byte_off: arena.offset(node.inputs[0]) as u32,
5315                        positions_len: elem_len(node.inputs[0]),
5316                        scales_byte_off: arena.offset(node.inputs[1]) as u32,
5317                        scales_len: elem_len(node.inputs[1]),
5318                        rotations_byte_off: arena.offset(node.inputs[2]) as u32,
5319                        rotations_len: elem_len(node.inputs[2]),
5320                        opacities_byte_off: arena.offset(node.inputs[3]) as u32,
5321                        opacities_len: elem_len(node.inputs[3]),
5322                        colors_byte_off: arena.offset(node.inputs[4]) as u32,
5323                        colors_len: elem_len(node.inputs[4]),
5324                        sh_coeffs_byte_off: arena.offset(node.inputs[5]) as u32,
5325                        sh_coeffs_len: elem_len(node.inputs[5]),
5326                        meta_byte_off: arena.offset(node.inputs[6]) as u32,
5327                        d_loss_byte_off: arena.offset(node.inputs[7]) as u32,
5328                        d_loss_len: elem_len(node.inputs[7]),
5329                        packed_byte_off: arena.offset(node.id) as u32,
5330                        packed_len: node.shape.num_elements().unwrap_or(0) as u32,
5331                        width: *width,
5332                        height: *height,
5333                        tile_size: *tile_size,
5334                        radius_scale: *radius_scale,
5335                        alpha_cutoff: *alpha_cutoff,
5336                        max_splat_steps: *max_splat_steps,
5337                        transmittance_threshold: *transmittance_threshold,
5338                        max_list_entries: *max_list_entries,
5339                        loss_grad_clip: *loss_grad_clip,
5340                        sh_band: *sh_band,
5341                        max_anisotropy: *max_anisotropy,
5342                    });
5343                }
5344
5345                #[cfg(feature = "splat")]
5346                Op::GaussianSplatPrepare {
5347                    width,
5348                    height,
5349                    tile_size,
5350                    radius_scale,
5351                    alpha_cutoff,
5352                    max_splat_steps,
5353                    transmittance_threshold,
5354                    max_list_entries,
5355                } => {
5356                    let elem_len = |id: NodeId| -> u32 {
5357                        graph.node(id).shape.num_elements().unwrap_or(0) as u32
5358                    };
5359                    schedule.push(Step::GaussianSplatPrepare {
5360                        positions_byte_off: arena.offset(node.inputs[0]) as u32,
5361                        positions_len: elem_len(node.inputs[0]),
5362                        scales_byte_off: arena.offset(node.inputs[1]) as u32,
5363                        scales_len: elem_len(node.inputs[1]),
5364                        rotations_byte_off: arena.offset(node.inputs[2]) as u32,
5365                        rotations_len: elem_len(node.inputs[2]),
5366                        opacities_byte_off: arena.offset(node.inputs[3]) as u32,
5367                        opacities_len: elem_len(node.inputs[3]),
5368                        colors_byte_off: arena.offset(node.inputs[4]) as u32,
5369                        colors_len: elem_len(node.inputs[4]),
5370                        sh_coeffs_byte_off: arena.offset(node.inputs[5]) as u32,
5371                        sh_coeffs_len: elem_len(node.inputs[5]),
5372                        meta_byte_off: arena.offset(node.inputs[6]) as u32,
5373                        meta_len: elem_len(node.inputs[6]),
5374                        prep_byte_off: arena.offset(node.id) as u32,
5375                        prep_len: node.shape.num_elements().unwrap_or(0) as u32,
5376                        width: *width,
5377                        height: *height,
5378                        tile_size: *tile_size,
5379                        radius_scale: *radius_scale,
5380                        alpha_cutoff: *alpha_cutoff,
5381                        max_splat_steps: *max_splat_steps,
5382                        transmittance_threshold: *transmittance_threshold,
5383                        max_list_entries: *max_list_entries,
5384                    });
5385                }
5386
5387                #[cfg(feature = "splat")]
5388                Op::GaussianSplatRasterize {
5389                    width,
5390                    height,
5391                    tile_size,
5392                    alpha_cutoff,
5393                    max_splat_steps,
5394                    transmittance_threshold,
5395                    max_list_entries,
5396                } => {
5397                    let elem_len = |id: NodeId| -> u32 {
5398                        graph.node(id).shape.num_elements().unwrap_or(0) as u32
5399                    };
5400                    let prep_id = node.inputs[0];
5401                    let count = match &graph.node(prep_id).op {
5402                        rlx_ir::Op::GaussianSplatPrepare { .. } => {
5403                            elem_len(graph.node(prep_id).inputs[0]) / 3
5404                        }
5405                        _ => 1,
5406                    };
5407                    schedule.push(Step::GaussianSplatRasterize {
5408                        prep_byte_off: arena.offset(prep_id) as u32,
5409                        prep_len: elem_len(prep_id),
5410                        meta_byte_off: arena.offset(node.inputs[1]) as u32,
5411                        meta_len: elem_len(node.inputs[1]),
5412                        dst_byte_off: arena.offset(node.id) as u32,
5413                        dst_len: node.shape.num_elements().unwrap_or(0) as u32,
5414                        count,
5415                        width: *width,
5416                        height: *height,
5417                        tile_size: *tile_size,
5418                        alpha_cutoff: *alpha_cutoff,
5419                        max_splat_steps: *max_splat_steps,
5420                        transmittance_threshold: *transmittance_threshold,
5421                        max_list_entries: *max_list_entries,
5422                    });
5423                }
5424
5425                Op::If { .. } | Op::While { .. } => {
5426                    // Should be unreachable: unfuse.rs inlines both branches
5427                    // (If) or unrolls max_iterations (While) into the parent
5428                    // graph using primitive ops + Where for the gating. If
5429                    // we hit this arm, the unfusion pass has a gap.
5430                    panic!(
5431                        "rlx-wgpu: Op::If/While leaked past unfusion pass — \
5432                            check unfuse.rs::expand_if / expand_while"
5433                    );
5434                }
5435                other => panic!(
5436                    "rlx-wgpu: op {other:?} not yet lowered (v2 covers Matmul, \
5437                     Binary, Compare, Activation, Where — fall back to CPU/Metal/MLX)"
5438                ),
5439            }
5440        }
5441
5442        if rlx_ir::env::flag("RLX_WGPU_SCHEDULE") || rlx_ir::env::flag("RLX_DISPATCH_REPORT") {
5443            let mut counts: std::collections::BTreeMap<&'static str, usize> =
5444                std::collections::BTreeMap::new();
5445            let mut fft_gpu = 0usize;
5446            let mut fft_host = 0usize;
5447            for s in &schedule {
5448                *counts.entry(step_name(s)).or_insert(0) += 1;
5449                match s {
5450                    Step::FftGpu { .. } => fft_gpu += 1,
5451                    Step::FftHost { .. } => fft_host += 1,
5452                    _ => {}
5453                }
5454            }
5455            let arena_mb = arena.size as f64 / (1u64 << 20) as f64;
5456            eprintln!(
5457                "[rlx-wgpu] schedule: {} steps, arena={arena_mb:.1} MiB, fft_gpu={fft_gpu}, fft_host={fft_host}",
5458                schedule.len()
5459            );
5460            for (n, c) in &counts {
5461                eprintln!("    {c:>4} × {n}");
5462            }
5463        }
5464
5465        let coop_f16_vk = schedule_uses_coop_f16_vk(&schedule);
5466
5467        Self {
5468            graph,
5469            arena,
5470            schedule,
5471            input_offsets,
5472            param_offsets,
5473            uniforms,
5474            bind_groups,
5475            meta_buffers,
5476            unresolved: None,
5477            last_binding: None,
5478            pending_params: HashMap::new(),
5479            pending_param_bytes: HashMap::new(),
5480            active_extent: None,
5481            uniforms_active_extent: None,
5482            input_staging_hashes: HashMap::new(),
5483            coop_f16_vk,
5484            coop_f16_b_param,
5485            coop_f16_vk_wide_b: HashSet::new(),
5486            coop_f16_vk_wide_bind_groups,
5487            coop_f16_host_activations,
5488            stashed_params: HashMap::new(),
5489            readback_staging: None,
5490            tiny_readback: None,
5491            fft_gpu_steps,
5492            gpu_handles: HashMap::new(),
5493            gpu_handle_feeds: HashMap::new(),
5494            gpu_handle_resident: HashSet::new(),
5495            pending_read_indices: None,
5496        }
5497    }
5498
5499    pub fn set_param(&mut self, name: &str, data: &[f32]) {
5500        const STASH_MAX_BYTES: usize = 16 * 1024 * 1024;
5501        if data.len() * 4 <= STASH_MAX_BYTES {
5502            self.stashed_params.insert(name.to_string(), data.to_vec());
5503        }
5504        if self.coop_f16_vk {
5505            crate::coop_f16_vk::refresh_wide_b_flag(&mut self.coop_f16_vk_wide_b, name, data);
5506        }
5507        if self.unresolved.is_some() {
5508            self.pending_params.insert(name.to_string(), data.to_vec());
5509            return;
5510        }
5511        let dev = wgpu_device().expect("rlx-wgpu: device gone");
5512        if let Some(&id) = self.param_offsets.get(name)
5513            && self.arena.has(id)
5514        {
5515            self.arena.write_f32(&dev.queue, id, data);
5516        }
5517    }
5518
5519    /// Debug helper: run forward, then read every node slot back and
5520    /// report the first node whose output contains a NaN, plus a
5521    /// summary of the *previous* finite node's value range so the
5522    /// caller can see the input that broke. Slow — diagnosis only.
5523    pub fn debug_first_nan_node(
5524        &mut self,
5525        inputs: &[(&str, &[f32])],
5526    ) -> Option<(usize, String, String)> {
5527        let _ = self.run(inputs);
5528        let dev = wgpu_device().expect("rlx-wgpu: device gone");
5529        let mut prev_summary = String::from("(none)");
5530        for (i, node) in self.graph.nodes().iter().enumerate() {
5531            if !self.arena.has(node.id) {
5532                continue;
5533            }
5534            let elems = node.shape.num_elements().unwrap_or(0);
5535            if elems == 0 {
5536                continue;
5537            }
5538            let data = self.arena.read_f32(&dev.device, &dev.queue, node.id);
5539            let nan_count = data.iter().filter(|v| v.is_nan()).count();
5540            let inf_count = data.iter().filter(|v| v.is_infinite()).count();
5541            if nan_count > 0 || inf_count > 0 {
5542                return Some((i, format!("{:?}", node.op), prev_summary));
5543            }
5544            let max = data.iter().copied().fold(f32::NEG_INFINITY, f32::max);
5545            let min = data.iter().copied().fold(f32::INFINITY, f32::min);
5546            let abs_max = data.iter().map(|v| v.abs()).fold(0.0_f32, f32::max);
5547            prev_summary = format!(
5548                "node #{i} {:?} shape={:?}  min={min:.6e} max={max:.6e} |max|={abs_max:.6e}",
5549                node.op,
5550                node.shape
5551                    .dims()
5552                    .iter()
5553                    .map(|d| format!("{d:?}"))
5554                    .collect::<Vec<_>>()
5555            );
5556        }
5557        None
5558    }
5559
5560    /// Declared output dtypes (one per graph output). Used by the
5561    /// runtime wrapper's `run_typed` to narrow F32 results back to
5562    /// F16/BF16 etc. on the way out.
5563    pub fn output_dtypes(&self) -> Vec<rlx_ir::DType> {
5564        self.graph
5565            .outputs
5566            .iter()
5567            .map(|&id| self.graph.node(id).shape.dtype())
5568            .collect()
5569    }
5570
5571    /// Upload raw bytes for a Param. The bytes land tight-packed at
5572    /// the param's slot offset — no f32 round-trip. Used for quantized
5573    /// weights (int8 / int4) where the kernel reads the byte stream
5574    /// via `bitcast<u32>` from the f32-typed arena.
5575    pub fn set_param_bytes(&mut self, name: &str, data: &[u8]) {
5576        if self.unresolved.is_some() {
5577            self.pending_param_bytes
5578                .insert(name.to_string(), data.to_vec());
5579            return;
5580        }
5581        let dev = wgpu_device().expect("rlx-wgpu: device gone");
5582        if let Some(&id) = self.param_offsets.get(name)
5583            && self.arena.has(id)
5584        {
5585            dev.queue
5586                .write_buffer(&self.arena.buffer, self.arena.offset(id) as u64, data);
5587        }
5588    }
5589
5590    fn dump_node_stats_if_requested(&self, dev: &crate::device::WgpuDevice) {
5591        if !rlx_ir::env::flag("RLX_WGPU_DUMP_NODES") {
5592            return;
5593        }
5594        let flat_probe = rlx_ir::env::parse_or::<usize>("RLX_WGPU_DUMP_FLAT", usize::MAX);
5595        let limit = rlx_ir::env::parse_or("RLX_WGPU_DUMP_NODES_LIMIT", 40usize);
5596        eprintln!(
5597            "[rlx-wgpu-dump] per-node max |x| (topo order, limit={limit}{})",
5598            if flat_probe != usize::MAX {
5599                format!(", flat[{flat_probe}]")
5600            } else {
5601                String::new()
5602            }
5603        );
5604        let mut shown = 0usize;
5605        for (i, node) in self.graph.nodes().iter().enumerate() {
5606            if !self.arena.has(node.id) {
5607                continue;
5608            }
5609            if matches!(
5610                node.op,
5611                rlx_ir::Op::Input { .. }
5612                    | rlx_ir::Op::Param { .. }
5613                    | rlx_ir::Op::Constant { .. }
5614                    | rlx_ir::Op::Reshape { .. }
5615                    | rlx_ir::Op::Cast { .. }
5616            ) {
5617                continue;
5618            }
5619            let data = self.arena.read_f32(&dev.device, &dev.queue, node.id);
5620            let max = data.iter().fold(0.0f32, |m, &v| m.max(v.abs()));
5621            let nz = data.iter().filter(|&&v| v != 0.0).count();
5622            let flat_s = if flat_probe < data.len() {
5623                format!(" flat[{flat_probe}]={:.6}", data[flat_probe])
5624            } else {
5625                String::new()
5626            };
5627            eprintln!(
5628                "  [{i:>3}] {:?} max={max:.6} nonzero={}/{}{flat_s}",
5629                node.op,
5630                nz,
5631                data.len()
5632            );
5633            shown += 1;
5634            if shown >= limit {
5635                break;
5636            }
5637        }
5638    }
5639
5640    pub fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
5641        self.run_read_outputs(inputs, None)
5642    }
5643
5644    pub fn run_read_outputs(
5645        &mut self,
5646        inputs: &[(&str, &[f32])],
5647        read_indices: Option<&[usize]>,
5648    ) -> Vec<Vec<f32>> {
5649        self.pending_read_indices = read_indices.map(|s| s.to_vec());
5650        let outs = self.run_inner(inputs);
5651        self.pending_read_indices = None;
5652        outs
5653    }
5654
5655    pub fn bind_gpu_handle(&mut self, name: &str, data: &[f32]) -> bool {
5656        if !self.input_offsets.contains_key(name) {
5657            return false;
5658        }
5659        self.gpu_handle_resident.remove(name);
5660        self.gpu_handles.insert(name.to_string(), data.to_vec());
5661        true
5662    }
5663
5664    pub fn has_gpu_handle(&self, name: &str) -> bool {
5665        self.gpu_handles.contains_key(name)
5666    }
5667
5668    pub fn set_gpu_handle_feed(&mut self, handle_name: &str, output_index: usize) {
5669        self.gpu_handle_feeds
5670            .insert(handle_name.to_string(), output_index);
5671    }
5672
5673    pub fn read_gpu_handle(&self, name: &str) -> Option<Vec<f32>> {
5674        if let Some(&out_idx) = self.gpu_handle_feeds.get(name) {
5675            if out_idx < self.graph.outputs.len() {
5676                let id = self.graph.outputs[out_idx];
5677                if self.arena.has(id) {
5678                    let dev = wgpu_device().expect("rlx-wgpu: device gone");
5679                    return Some(self.arena.read_f32(&dev.device, &dev.queue, id));
5680                }
5681            }
5682        }
5683        if self.gpu_handle_resident.contains(name) {
5684            if let Some(&id) = self.input_offsets.get(name) {
5685                if self.arena.has(id) {
5686                    let dev = wgpu_device().expect("rlx-wgpu: device gone");
5687                    return Some(self.arena.read_f32(&dev.device, &dev.queue, id));
5688                }
5689            }
5690        }
5691        self.gpu_handles.get(name).cloned()
5692    }
5693
5694    fn readback_plan(&self) -> Vec<usize> {
5695        let n = self.graph.outputs.len();
5696        if self.pending_read_indices.is_none() && self.gpu_handle_feeds.is_empty() {
5697            return (0..n).collect();
5698        }
5699        if let Some(ref want) = self.pending_read_indices {
5700            let mut v: Vec<_> = want.to_vec();
5701            v.sort_unstable();
5702            return v;
5703        }
5704        (0..n).collect()
5705    }
5706
5707    fn propagate_gpu_handle_feeds_on_gpu(
5708        &mut self,
5709        dev: &crate::device::WgpuDevice,
5710        enc: &mut wgpu::CommandEncoder,
5711    ) {
5712        let extent = self.active_extent;
5713        let feeds: Vec<(String, usize)> = self
5714            .gpu_handle_feeds
5715            .iter()
5716            .map(|(n, &i)| (n.clone(), i))
5717            .collect();
5718        for (name, out_idx) in feeds {
5719            if out_idx >= self.graph.outputs.len() {
5720                continue;
5721            }
5722            let out_id = self.graph.outputs[out_idx];
5723            let Some(&in_id) = self.input_offsets.get(name.as_str()) else {
5724                continue;
5725            };
5726            if in_id != out_id {
5727                let out_bytes = self.arena.len_of(out_id);
5728                let copy_bytes = match extent {
5729                    Some((actual, upper)) if upper > 0 => {
5730                        let stride = (out_bytes / (upper + 1)).max(4);
5731                        (actual * stride).min(out_bytes)
5732                    }
5733                    _ => out_bytes,
5734                };
5735                self.dispatch_arena_copy_bytes(dev, enc, out_id, in_id, copy_bytes);
5736            }
5737            self.gpu_handle_resident.insert(name.clone());
5738            self.gpu_handles.insert(name.clone(), Vec::new());
5739        }
5740    }
5741
5742    fn dispatch_arena_copy_bytes(
5743        &self,
5744        dev: &crate::device::WgpuDevice,
5745        enc: &mut wgpu::CommandEncoder,
5746        src_id: NodeId,
5747        dst_id: NodeId,
5748        nbytes: usize,
5749    ) {
5750        if nbytes == 0 {
5751            return;
5752        }
5753        let src = self.arena.offset(src_id) as u64;
5754        let dst = self.arena.offset(dst_id) as u64;
5755        let nbytes = nbytes
5756            .min(self.arena.len_of(src_id))
5757            .min(self.arena.len_of(dst_id)) as u64;
5758        let elems = (nbytes / 4).max(1) as u32;
5759        let lo = src.min(dst);
5760        let hi = src.saturating_add(nbytes).max(dst.saturating_add(nbytes));
5761        let max_binding = dev.device.limits().max_storage_buffer_binding_size;
5762        let mut size = hi.saturating_sub(lo).div_ceil(256) * 256;
5763        size = size.max(256).min(max_binding);
5764        let mut base = (lo / 256) * 256;
5765        if base.saturating_add(size) > self.arena.size as u64 {
5766            base = (self.arena.size as u64).saturating_sub(size);
5767            base = (base / 256) * 256;
5768        }
5769        let p = CopyParams {
5770            n: elems,
5771            in_off: (src.saturating_sub(base) / 4) as u32,
5772            out_off: (dst.saturating_sub(base) / 4) as u32,
5773            _p0: 0,
5774            _p1: 0,
5775            _p2: 0,
5776            _p3: 0,
5777            _p4: 0,
5778        };
5779        let ck = copy_kernel(&dev.device);
5780        let u = dev.device.create_buffer(&wgpu::BufferDescriptor {
5781            label: Some("rlx-wgpu kv_feed_copy uniform"),
5782            size: std::mem::size_of::<CopyParams>() as u64,
5783            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
5784            mapped_at_creation: false,
5785        });
5786        dev.queue.write_buffer(&u, 0, bytemuck::bytes_of(&p));
5787        let bg = bind_two_buf0_window(&dev.device, ck, &self.arena.buffer, base, size, &u);
5788        let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
5789            label: Some("rlx-wgpu kv_feed_copy pass"),
5790            ..Default::default()
5791        });
5792        pass.set_pipeline(&ck.pipeline);
5793        pass.set_bind_group(0, &bg, &[]);
5794        let (gx, gy, gz) = dispatch_dims(elems, 64);
5795        pass.dispatch_workgroups(gx, gy, gz);
5796    }
5797
5798    #[allow(dead_code)]
5799    fn dispatch_arena_copy_between_nodes(
5800        &self,
5801        dev: &crate::device::WgpuDevice,
5802        enc: &mut wgpu::CommandEncoder,
5803        src_id: NodeId,
5804        dst_id: NodeId,
5805    ) {
5806        let nbytes = self.arena.len_of(src_id).min(self.arena.len_of(dst_id));
5807        self.dispatch_arena_copy_bytes(dev, enc, src_id, dst_id, nbytes);
5808    }
5809
5810    fn stage_gpu_handle_inputs(
5811        &mut self,
5812        dev: &crate::device::WgpuDevice,
5813        inputs: &[(&str, &[f32])],
5814    ) {
5815        for (name, data) in &self.gpu_handles {
5816            if self.gpu_handle_resident.contains(name) || inputs.iter().any(|(n, _)| n == name) {
5817                continue;
5818            }
5819            if let Some(&id) = self.input_offsets.get(name.as_str())
5820                && self.arena.has(id)
5821            {
5822                self.arena.write_f32(&dev.queue, id, data);
5823                self.input_staging_hashes.remove(name);
5824            }
5825        }
5826    }
5827
5828    fn pack_readback_outputs(&mut self, plan: &[usize], partial: Vec<Vec<f32>>) -> Vec<Vec<f32>> {
5829        if self.pending_read_indices.is_none() {
5830            for (pos, &out_i) in plan.iter().enumerate() {
5831                if let Some(data) = partial.get(pos) {
5832                    for (name, &feed_i) in &self.gpu_handle_feeds {
5833                        if feed_i == out_i {
5834                            self.gpu_handles.insert(name.clone(), data.clone());
5835                        }
5836                    }
5837                }
5838            }
5839        }
5840        if self.pending_read_indices.is_none() && plan.len() == self.graph.outputs.len() {
5841            return partial;
5842        }
5843        let want = self.pending_read_indices.as_deref().unwrap_or(plan);
5844        let mut by_idx = std::collections::HashMap::new();
5845        for (pos, &i) in plan.iter().enumerate() {
5846            if let Some(d) = partial.get(pos) {
5847                by_idx.insert(i, d.clone());
5848            }
5849        }
5850        want.iter()
5851            .map(|&i| {
5852                by_idx
5853                    .get(&i)
5854                    .cloned()
5855                    .expect("readback plan missing output")
5856            })
5857            .collect()
5858    }
5859
5860    fn run_tail_host_audio_ops(&self, dev: &crate::device::WgpuDevice) {
5861        if !self.schedule.iter().any(step_is_tail_host) {
5862            return;
5863        }
5864        for step in &self.schedule {
5865            if !step_is_tail_host(step) {
5866                continue;
5867            }
5868            match step {
5869                Step::WelchPeaksHost {
5870                    spec_byte_off,
5871                    dst_byte_off,
5872                    welch_batch,
5873                    n_fft,
5874                    n_segments,
5875                    k,
5876                } => {
5877                    crate::welch_peaks_host::run_welch_peaks(
5878                        &self.arena,
5879                        &dev.device,
5880                        &dev.queue,
5881                        *spec_byte_off as usize,
5882                        *dst_byte_off as usize,
5883                        *welch_batch as usize,
5884                        *n_fft as usize,
5885                        *n_segments as usize,
5886                        *k as usize,
5887                    );
5888                }
5889                Step::LogMelHost {
5890                    spec_byte_off,
5891                    filt_byte_off,
5892                    dst_byte_off,
5893                    outer,
5894                    n_fft,
5895                    n_bins,
5896                    n_mels,
5897                } => {
5898                    crate::log_mel_host::run_log_mel(
5899                        &self.arena,
5900                        &dev.device,
5901                        &dev.queue,
5902                        *spec_byte_off as usize,
5903                        *filt_byte_off as usize,
5904                        *dst_byte_off as usize,
5905                        *outer as usize,
5906                        *n_fft as usize,
5907                        *n_bins as usize,
5908                        *n_mels as usize,
5909                    );
5910                }
5911                Step::LogMelBackwardHost {
5912                    spec_byte_off,
5913                    filt_byte_off,
5914                    dy_byte_off,
5915                    dst_byte_off,
5916                    outer,
5917                    n_fft,
5918                    n_bins,
5919                    n_mels,
5920                } => {
5921                    crate::log_mel_host::run_log_mel_backward(
5922                        &self.arena,
5923                        &dev.device,
5924                        &dev.queue,
5925                        *spec_byte_off as usize,
5926                        *filt_byte_off as usize,
5927                        *dy_byte_off as usize,
5928                        *dst_byte_off as usize,
5929                        *outer as usize,
5930                        *n_fft as usize,
5931                        *n_bins as usize,
5932                        *n_mels as usize,
5933                    );
5934                }
5935                _ => {}
5936            }
5937        }
5938    }
5939
5940    fn run_inner(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
5941        // Lazy compile path: if we deferred compile waiting for shapes,
5942        // infer the binding from input data lengths now and compile.
5943        if self.unresolved.is_some() {
5944            self.lazy_compile_for_inputs(inputs);
5945        }
5946        let dev = wgpu_device().expect("rlx-wgpu: device gone");
5947        self.stage_gpu_handle_inputs(dev, inputs);
5948        let skip_input_upload =
5949            !rlx_ir::env::flag("RLX_WGPU_FORCE_INPUT_UPLOAD") && !self.coop_f16_vk;
5950        for &(name, data) in inputs {
5951            if let Some(&id) = self.input_offsets.get(name)
5952                && self.arena.has(id)
5953            {
5954                if skip_input_upload {
5955                    let h = hash_f32_input(data);
5956                    if self.input_staging_hashes.get(name) == Some(&h) {
5957                        if self.arena.f16_buffer.is_some() {
5958                            self.arena.write_f16_shadow(&dev.queue, id, data);
5959                        }
5960                        continue;
5961                    }
5962                    self.arena.write_f32(&dev.queue, id, data);
5963                    self.input_staging_hashes.insert(name.to_string(), h);
5964                } else {
5965                    self.arena.write_f32(&dev.queue, id, data);
5966                }
5967            }
5968        }
5969        for &(act_id, act, ref src_name) in &self.coop_f16_host_activations {
5970            let src =
5971                host_tensor_f32(src_name, inputs, &self.stashed_params).unwrap_or_else(|| {
5972                    panic!("rlx-wgpu CoopF16Vk host activation: missing tensor {src_name:?}")
5973                });
5974            let mirrored = apply_activation_host(act, src);
5975            self.arena.write_f32(&dev.queue, act_id, &mirrored);
5976        }
5977        if !self.coop_f16_host_activations.is_empty() {
5978            // Ensure host staging writes are visible before CoopF16Vk reads f16.
5979            let flush = dev
5980                .device
5981                .create_command_encoder(&wgpu::CommandEncoderDescriptor {
5982                    label: Some("rlx-wgpu host mirror flush"),
5983                });
5984            dev.queue.submit(std::iter::once(flush.finish()));
5985            let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());
5986        }
5987
5988        // Active-extent (PLAN L1): scale safe Steps' primary dim by
5989        // actual/upper. Used in BOTH the uniform-write loop (so the
5990        // kernel sees the scaled count) AND the dispatch loop (so the
5991        // workgroup grid is shrunk).
5992        let active = self.active_extent.filter(|_| self.all_safe_for_active());
5993        let scale = |full: u32| -> u32 {
5994            match active {
5995                Some((a, u)) if u > 0 => {
5996                    let f = full as usize;
5997                    (f * a).div_ceil(u).min(f) as u32
5998                }
5999                _ => full,
6000            }
6001        };
6002
6003        // Stage uniform writes — but skip the loop entirely when the
6004        // bytes already in the uniforms match this run's active extent.
6005        // BERT inference at fixed batch hits this path: 100+ tiny
6006        // queue.write_buffer calls (one per Step) collapse to zero,
6007        // saving milliseconds of staging-copy overhead.
6008        let need_uniform_writes = self.uniforms_active_extent != Some(active);
6009        if need_uniform_writes {
6010            let mut gpu_ui = 0usize;
6011            for step in self.schedule.iter() {
6012                if step_runs_on_host(step) {
6013                    continue;
6014                }
6015                match step {
6016                    Step::CastF32ToF16 { .. } => {
6017                        // Params are static for this step (offset+len), so the
6018                        // pre-pass write at compile time is sufficient. No
6019                        // active-extent scaling — len is the full element count.
6020                    }
6021                    Step::Matmul {
6022                        m,
6023                        k,
6024                        n,
6025                        a_off_f32,
6026                        b_off_f32,
6027                        c_off_f32,
6028                        batch,
6029                        a_batch_stride,
6030                        b_batch_stride,
6031                        c_batch_stride,
6032                        has_bias,
6033                        bias_off_f32,
6034                        act_id,
6035                        b_is_param: _,
6036                        compute_precision: _,
6037                    } => {
6038                        // PLAN L1 (safe at any batch — c_batch_stride is
6039                        // pre-baked at compile time at FULL m, so scaling
6040                        // params.m only changes per-thread bound checks).
6041                        let m_scaled = scale(*m);
6042                        let p = MatmulParams {
6043                            m: m_scaled,
6044                            k: *k,
6045                            n: *n,
6046                            a_off: *a_off_f32,
6047                            b_off: *b_off_f32,
6048                            c_off: *c_off_f32,
6049                            batch: *batch,
6050                            a_batch_stride: *a_batch_stride,
6051                            b_batch_stride: *b_batch_stride,
6052                            c_batch_stride: *c_batch_stride,
6053                            has_bias: *has_bias,
6054                            bias_off: *bias_off_f32,
6055                            act_id: *act_id,
6056                            _pad0: 0,
6057                            _pad1: 0,
6058                            _pad2: 0,
6059                        };
6060                        dev.queue
6061                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6062                    }
6063                    Step::Binary { params } | Step::Compare { params } => {
6064                        let mut p = *params;
6065                        p.n = scale(p.n);
6066                        dev.queue
6067                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6068                    }
6069                    Step::Unary { params, .. } => {
6070                        let mut p = *params;
6071                        p.n = scale(p.n);
6072                        dev.queue
6073                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6074                    }
6075                    Step::Where { params } => {
6076                        let mut p = *params;
6077                        p.n = scale(p.n);
6078                        dev.queue
6079                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6080                    }
6081                    Step::Reduce { params } => {
6082                        let mut p = *params;
6083                        p.outer = scale(p.outer);
6084                        dev.queue
6085                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6086                    }
6087                    Step::Softmax { params } => {
6088                        let mut p = *params;
6089                        p.outer = scale(p.outer);
6090                        dev.queue
6091                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6092                    }
6093                    Step::LayerNorm { params } => {
6094                        let mut p = *params;
6095                        p.outer = scale(p.outer);
6096                        dev.queue
6097                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6098                    }
6099                    Step::RmsNormBackwardInput { params }
6100                    | Step::RmsNormBackwardGamma { params }
6101                    | Step::RmsNormBackwardBeta { params } => {
6102                        let mut p = *params;
6103                        p.outer = scale(p.outer);
6104                        dev.queue
6105                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6106                    }
6107                    Step::LayerNormBackwardInput { params } => {
6108                        let mut p = *params;
6109                        p.outer = scale(p.outer);
6110                        dev.queue
6111                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6112                    }
6113                    Step::LayerNormBackwardGammaPartial { params, .. } => {
6114                        let mut p = *params;
6115                        p.outer = scale(p.outer);
6116                        dev.queue
6117                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6118                    }
6119                    Step::LayerNormBackwardGammaReduce { params } => {
6120                        // `outer` here is the partial chunk count (not
6121                        // a batch dim) — do NOT apply active-extent
6122                        // scaling.
6123                        dev.queue.write_buffer(
6124                            &self.uniforms[gpu_ui],
6125                            0,
6126                            bytemuck::bytes_of(params),
6127                        );
6128                    }
6129                    Step::CumsumBackward { params } => {
6130                        let mut p = *params;
6131                        p.outer = scale(p.outer);
6132                        dev.queue
6133                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6134                    }
6135                    Step::RopeBackward { params } => {
6136                        let mut p = *params;
6137                        p.seq = scale(p.seq);
6138                        dev.queue
6139                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6140                    }
6141                    Step::GatherBackward { params } => {
6142                        let mut p = *params;
6143                        p.outer = scale(p.outer);
6144                        dev.queue
6145                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6146                    }
6147                    Step::Cumsum { params } => {
6148                        let mut p = *params;
6149                        p.outer = scale(p.outer);
6150                        dev.queue
6151                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6152                    }
6153                    Step::FftGpu { .. } => {}
6154                    Step::Copy { params } => {
6155                        let mut p = *params;
6156                        p.n = scale(p.n);
6157                        dev.queue
6158                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6159                    }
6160                    Step::BufferCopy { .. } => {}
6161                    Step::ElementwiseRegion { params } => {
6162                        // Active-extent: scale element count.
6163                        let mut p = *params;
6164                        p.len = scale(p.len);
6165                        dev.queue
6166                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6167                    }
6168                    Step::BatchElementwiseRegion { params } => {
6169                        let mut p = *params;
6170                        p.slice_len = scale(p.slice_len);
6171                        p.num_batch = scale(p.num_batch);
6172                        dev.queue
6173                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6174                    }
6175                    Step::Transpose { params, .. } => {
6176                        // PLAN L1: when bucket_outermost == 1, scale
6177                        // `out_total` proportional to scaling `out_dim_0`.
6178                        // Other transposes leave out_total at full extent
6179                        // (predicate prevents the active-extent path).
6180                        let mut p = *params;
6181                        if p.bucket_outermost == 1 && p.out_dim_0 > 0 {
6182                            let scaled_d0 = scale(p.out_dim_0);
6183                            let inner = p.out_total / p.out_dim_0;
6184                            p.out_total = scaled_d0 * inner;
6185                        }
6186                        dev.queue
6187                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6188                    }
6189                    Step::Narrow { params } => {
6190                        let mut p = *params;
6191                        p.total = scale(p.total);
6192                        dev.queue
6193                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6194                    }
6195                    Step::Concat { params } => {
6196                        let mut p = *params;
6197                        p.total = scale(p.total);
6198                        dev.queue
6199                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6200                    }
6201                    Step::Gather { params } => {
6202                        let mut p = *params;
6203                        p.n_out = scale(p.n_out);
6204                        dev.queue
6205                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6206                    }
6207                    Step::GatherAxis { params } => {
6208                        let mut p = *params;
6209                        p.total = scale(p.total);
6210                        dev.queue
6211                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6212                    }
6213                    Step::Attention { params, .. } => {
6214                        // PLAN L1: scale seq_q + seq_k. Stride fields
6215                        // (seq_q_stride / seq_k_stride) stay at the
6216                        // compile-time full extent, so per-(batch, head)
6217                        // offset math in the WGSL stays correct.
6218                        let mut p = *params;
6219                        p.seq_q = scale(p.seq_q);
6220                        p.seq_k = scale(p.seq_k);
6221                        dev.queue
6222                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6223                    }
6224                    Step::AttentionBackward { params, .. } => {
6225                        let mut p = *params;
6226                        if p.wrt == 0 {
6227                            p.seq_q = scale(p.seq_q);
6228                        } else {
6229                            p.seq_k = scale(p.seq_k);
6230                        }
6231                        dev.queue
6232                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6233                    }
6234                    Step::Rope { params } => {
6235                        // PLAN L1: scale `seq` and `n_total` proportionally.
6236                        // `seq_stride` and `batch` stay at compile-time
6237                        // values; the WGSL kernel uses them for buffer
6238                        // offsets while `seq` / `n_total` are loop bounds.
6239                        let mut p = *params;
6240                        let s_active = scale(p.seq);
6241                        p.seq = s_active;
6242                        p.n_total = p.batch * s_active * p.last_dim;
6243                        dev.queue
6244                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6245                    }
6246                    Step::Expand { params, .. } => {
6247                        // PLAN L1: same pattern as Transpose.
6248                        let mut p = *params;
6249                        if p.bucket_outermost == 1 && p.out_dim_0 > 0 {
6250                            let scaled_d0 = scale(p.out_dim_0);
6251                            let inner = p.out_total / p.out_dim_0;
6252                            p.out_total = scaled_d0 * inner;
6253                        }
6254                        dev.queue
6255                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6256                    }
6257                    Step::Argmax { params } => {
6258                        let mut p = *params;
6259                        p.outer = scale(p.outer);
6260                        dev.queue
6261                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6262                    }
6263                    Step::Pool2d { params } => {
6264                        let mut p = *params;
6265                        p.n = scale(p.n);
6266                        dev.queue
6267                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6268                    }
6269                    Step::Conv2d { params } => {
6270                        let mut p = *params;
6271                        p.n = scale(p.n);
6272                        dev.queue
6273                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6274                    }
6275                    Step::Pool1d { params } => {
6276                        let mut p = *params;
6277                        p.n = scale(p.n);
6278                        dev.queue
6279                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6280                    }
6281                    Step::Pool3d { params } => {
6282                        let mut p = *params;
6283                        p.n = scale(p.n);
6284                        dev.queue
6285                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6286                    }
6287                    Step::Conv1d { params } => {
6288                        let mut p = *params;
6289                        p.n = scale(p.n);
6290                        dev.queue
6291                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6292                    }
6293                    Step::Conv3d { params } => {
6294                        let mut p = *params;
6295                        p.n = scale(p.n);
6296                        dev.queue
6297                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6298                    }
6299                    Step::ScatterAdd { params } => {
6300                        // Two-phase: phase 0 zeros the FULL output (preserves
6301                        // accumulator semantics); phase 1 scatters first
6302                        // num_updates_active updates only.
6303                        let mut p = *params;
6304                        if p.op == 1 {
6305                            p.num_updates = scale(p.num_updates);
6306                        }
6307                        dev.queue
6308                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6309                    }
6310                    Step::TopK { params } => {
6311                        let mut p = *params;
6312                        p.outer = scale(p.outer);
6313                        dev.queue
6314                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6315                    }
6316                    Step::WelchPeaksGpu { params } => {
6317                        let mut p = *params;
6318                        p.welch_batch = scale(p.welch_batch);
6319                        dev.queue
6320                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6321                    }
6322                    Step::UmapKnn { params } => {
6323                        let mut p = *params;
6324                        p.n = scale(p.n);
6325                        dev.queue
6326                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6327                    }
6328                    Step::GroupedMatmul { params } => {
6329                        let mut p = *params;
6330                        p.m = scale(p.m);
6331                        dev.queue
6332                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6333                    }
6334                    Step::Sample { params } => {
6335                        let mut p = *params;
6336                        p.outer = scale(p.outer);
6337                        dev.queue
6338                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6339                    }
6340                    Step::SelectiveScan { params } => {
6341                        // Predicate-gated to batch=1: scale seq.
6342                        let mut p = *params;
6343                        p.seq = scale(p.seq);
6344                        dev.queue
6345                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6346                    }
6347                    Step::DequantMatmul { params } => {
6348                        let mut p = *params;
6349                        p.m = scale(p.m);
6350                        dev.queue
6351                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6352                    }
6353                    Step::DequantMatmulGguf { .. }
6354                    | Step::DequantGroupedMatmulGguf { .. }
6355                    | Step::GatedDeltaNet { .. }
6356                    | Step::Llada2GroupLimitedGate { .. }
6357                    | Step::UmapKnnHost { .. }
6358                    | Step::FftHost { .. }
6359                    | Step::Im2ColHost { .. }
6360                    | Step::WelchPeaksHost { .. }
6361                    | Step::LogMelHost { .. }
6362                    | Step::LogMelBackwardHost { .. } => {}
6363                    Step::FusedResidualLn { params } => {
6364                        let mut p = *params;
6365                        p.outer = scale(p.outer);
6366                        dev.queue
6367                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6368                    }
6369                    Step::FusedResidualLnTee { params } => {
6370                        let mut p = *params;
6371                        p.outer = scale(p.outer);
6372                        dev.queue
6373                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6374                    }
6375                    Step::FusedResidualRmsNorm { params } => {
6376                        let mut p = *params;
6377                        p.outer = scale(p.outer);
6378                        dev.queue
6379                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6380                    }
6381                    Step::MatmulQkv { params, kind: _ } => {
6382                        let mut p = *params;
6383                        p.m = scale(p.m);
6384                        dev.queue
6385                            .write_buffer(&self.uniforms[gpu_ui], 0, bytemuck::bytes_of(&p));
6386                    }
6387                    #[cfg(feature = "splat")]
6388                    Step::GaussianSplatRender { .. }
6389                    | Step::GaussianSplatRenderBackward { .. }
6390                    | Step::GaussianSplatPrepare { .. }
6391                    | Step::GaussianSplatRasterize { .. } => {}
6392                }
6393                if !matches!(step, Step::FftGpu { .. }) {
6394                    gpu_ui += 1;
6395                }
6396            }
6397            self.uniforms_active_extent = Some(active);
6398        }
6399
6400        // Encode + submit.
6401        let mm_k = matmul_kernel(&dev.device);
6402        let mm_w_active = matmul_wide_active_kernel(&dev.device);
6403        let mm_f16w = matmul_f16w_kernel(&dev.device);
6404        let mm_f16c = matmul_f16_compute_kernel(&dev.device);
6405        let mm_coop = matmul_coop16_kernel(&dev.device);
6406        let mm_coop_f16_vk = matmul_coop_f16_vulkan_kernel(&dev.device);
6407        let mm_coop_f32 = matmul_coop_f32_active_kernel(&dev.device);
6408        let mm_cast = cast_f32_to_f16_kernel(&dev.device);
6409        let bk = binary_kernel(&dev.device);
6410        let uk = unary_kernel(&dev.device);
6411        let ck = compare_kernel(&dev.device);
6412        let wk = where_kernel(&dev.device);
6413        let mut step_i = 0;
6414        let mut gpu_bi = 0usize;
6415        let mut fft_i = 0usize;
6416        while step_i < self.schedule.len() {
6417            let mut enc = dev
6418                .device
6419                .create_command_encoder(&wgpu::CommandEncoderDescriptor {
6420                    label: Some("rlx-wgpu run"),
6421                });
6422            {
6423                let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
6424                    label: Some("rlx-wgpu compute pass"),
6425                    timestamp_writes: None,
6426                });
6427                let mut pass_dispatched = false;
6428                while step_i < self.schedule.len() {
6429                    if step_is_tail_host(&self.schedule[step_i]) {
6430                        step_i += 1;
6431                        continue;
6432                    }
6433                    if step_runs_on_host(&self.schedule[step_i]) {
6434                        break;
6435                    }
6436                    // Vulkan/DX12: end the pass after unary/cast so f32→f16
6437                    // mirrors are visible to the next step. Only split once
6438                    // we've dispatched in *this* pass — otherwise the step that
6439                    // needs the flush would never run (infinite empty passes).
6440                    if pass_dispatched
6441                        && step_i > 0
6442                        && step_needs_pass_flush(&self.schedule[step_i], &self.schedule[step_i - 1])
6443                    {
6444                        break;
6445                    }
6446                    let step = &self.schedule[step_i];
6447                    // PLAN L3: per-step Perfetto trace span; no-op when
6448                    // env var RLX_TRACE_PERFETTO unset.
6449                    let _perf = rlx_ir::perfetto::TraceSpan::new(step_name(step), "wgpu");
6450                    match step {
6451                        Step::CastF32ToF16 { params } => {
6452                            // Pre-pass for matmul_coop16: mirror f32 arena
6453                            // region into f16 shadow buffer so the matmul
6454                            // kernel can read A as f16. One thread per
6455                            // element; 64-thread workgroups.
6456                            if let Some(cast_k) = mm_cast {
6457                                pass.set_pipeline(&cast_k.pipeline);
6458                                pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6459                                let (gx, gy, gz) = dispatch_dims(params.len, 64);
6460                                pass.dispatch_workgroups(gx, gy, gz);
6461                            }
6462                        }
6463                        Step::Matmul {
6464                            m,
6465                            n,
6466                            batch,
6467                            b_off_f32,
6468                            b_is_param,
6469                            compute_precision,
6470                            ..
6471                        } =>
6472                        // The dispatch branches below form one flat
6473                        // `else if` chain: each variant tests its own
6474                        // Option<Pipeline> (via `if let Some(p) = … && …`
6475                        // or `is_some() && …`) together with its other
6476                        // selection conditions, so no per-variant
6477                        // nesting is needed.
6478                        {
6479                            #[allow(clippy::unnecessary_unwrap)]
6480                            // Safe at any batch (see safe_for_active_extent
6481                            // comment); scale m, output rows past m_s per
6482                            // batch retain prior values via c_batch_stride.
6483                            let m_s = scale(*m);
6484                            if m_s == 0 {
6485                                continue;
6486                            }
6487                            let coop_f16_wide = mm_coop_f16_vk.is_some()
6488                                && *compute_precision == MatmulCompute::CoopF16Vk
6489                                && crate::coop_f16_vk::use_wide_matmul(
6490                                    *b_off_f32,
6491                                    *n,
6492                                    &self.coop_f16_b_param,
6493                                    &self.coop_f16_vk_wide_b,
6494                                );
6495                            pass.set_bind_group(
6496                                0,
6497                                coop_f16_vk_bind_group(self, gpu_bi, coop_f16_wide),
6498                                &[],
6499                            );
6500                            // Kernel selection priority (first match wins):
6501                            //   0. compute_precision == Coop16 / CoopF16Vk / CoopF32
6502                            //      with its pipeline present → coop branches below
6503                            //   1. compute_precision == F16 + b_is_param + SHADER_F16 →
6504                            //      matmul_f16_compute (f16 multiply, f32 acc — 2× ALU on Apple)
6505                            //   2. legacy RLX_WGPU_F16_WEIGHTS opt-in → matmul_f16w
6506                            //      (storage-only f16; experimental, currently regresses on Apple)
6507                            //   3. wide-N (scaled m≥32, n≥64) → matmul_wide
6508                            //   4. otherwise            → matmul (small/skinny)
6509                            let f16w_opt_in = rlx_ir::env::flag("RLX_WGPU_F16_WEIGHTS");
6510                            if let Some(coop) = mm_coop.as_ref()
6511                                && *b_is_param
6512                                && *compute_precision == MatmulCompute::Coop16
6513                            {
6514                                // Hardware GEMM via simdgroup_matrix /
6515                                // KHR_cooperative_matrix. 32×32 output tile
6516                                // per workgroup (16 hardware-GEMM ops with
6517                                // shared A/B loads). Caller guaranteed m, n,
6518                                // k are multiples of 32/32/8.
6519                                pass.set_pipeline(&coop.pipeline);
6520                                pass.dispatch_workgroups(n.div_ceil(32), m_s.div_ceil(32), *batch);
6521                            } else if mm_coop_f16_vk.is_some()
6522                                && *compute_precision == MatmulCompute::CoopF16Vk
6523                            {
6524                                if coop_f16_wide {
6525                                    dispatch_wide_f32_matmul(
6526                                        &mut pass,
6527                                        mm_w_active,
6528                                        mm_k,
6529                                        m_s,
6530                                        *n,
6531                                        *batch,
6532                                    );
6533                                } else {
6534                                    let n_eff = scale(*n);
6535                                    let coop_vk =
6536                                        matmul_coop_f16_vulkan_active_kernel(&dev.device, n_eff)
6537                                            .expect("coop f16 vk kernel missing");
6538                                    pass.set_pipeline(&coop_vk.pipeline);
6539                                    pass.dispatch_workgroups(
6540                                        m_s.div_ceil(16),
6541                                        n.div_ceil(16),
6542                                        *batch,
6543                                    );
6544                                }
6545                            } else if let Some(coop_f32) = mm_coop_f32.as_ref()
6546                                && *b_is_param
6547                                && *compute_precision == MatmulCompute::CoopF32
6548                            {
6549                                // CoopF32: Metal uses 32×32 simdgroup tiles;
6550                                // Vulkan uses 8×8 coopLoadT portable kernel.
6551                                pass.set_pipeline(&coop_f32.pipeline);
6552                                let backend = wgpu_device()
6553                                    .map(|d| d.backend)
6554                                    .unwrap_or(wgpu::Backend::Noop);
6555                                let (gx, gy) = if backend == wgpu::Backend::Metal {
6556                                    (n.div_ceil(32), m_s.div_ceil(32))
6557                                } else {
6558                                    (m_s.div_ceil(8), n.div_ceil(8))
6559                                };
6560                                pass.dispatch_workgroups(gx, gy, *batch);
6561                            } else if let Some(f16c) = mm_f16c.as_ref()
6562                                && *b_is_param
6563                                && *compute_precision == MatmulCompute::F16
6564                            {
6565                                pass.set_pipeline(&f16c.pipeline);
6566                                pass.dispatch_workgroups(n.div_ceil(32), m_s.div_ceil(32), *batch);
6567                            } else if let Some(f16w) = mm_f16w.as_ref()
6568                                && *b_is_param
6569                                && f16w_opt_in
6570                            {
6571                                pass.set_pipeline(&f16w.pipeline);
6572                                pass.dispatch_workgroups(n.div_ceil(32), m_s.div_ceil(32), *batch);
6573                            } else if m_s >= 32 && *n >= 64 {
6574                                pass.set_pipeline(&mm_w_active.pipeline);
6575                                let backend = wgpu_device()
6576                                    .map(|d| d.backend)
6577                                    .unwrap_or(wgpu::Backend::Noop);
6578                                let (gx, gy) = if matches!(
6579                                    backend,
6580                                    wgpu::Backend::Vulkan | wgpu::Backend::Dx12
6581                                ) {
6582                                    (n.div_ceil(64), m_s.div_ceil(64))
6583                                } else {
6584                                    (n.div_ceil(64), m_s.div_ceil(32))
6585                                };
6586                                pass.dispatch_workgroups(gx, gy, *batch);
6587                            } else {
6588                                pass.set_pipeline(&mm_k.pipeline);
6589                                pass.dispatch_workgroups(n.div_ceil(32), m_s.div_ceil(32), *batch);
6590                            }
6591                        }
6592                        Step::Binary { params } => {
6593                            let n_s = scale(params.n);
6594                            if n_s == 0 {
6595                                continue;
6596                            }
6597                            pass.set_pipeline(&bk.pipeline);
6598                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6599                            let (gx, gy, gz) = dispatch_dims(n_s, 64);
6600                            pass.dispatch_workgroups(gx, gy, gz);
6601                        }
6602                        Step::Compare { params } => {
6603                            let n_s = scale(params.n);
6604                            if n_s == 0 {
6605                                continue;
6606                            }
6607                            pass.set_pipeline(&ck.pipeline);
6608                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6609                            let (gx, gy, gz) = dispatch_dims(n_s, 64);
6610                            pass.dispatch_workgroups(gx, gy, gz);
6611                        }
6612                        Step::Unary { params, f16_mirror } => {
6613                            let n_s = scale(params.n);
6614                            if n_s == 0 {
6615                                continue;
6616                            }
6617                            if *f16_mirror {
6618                                if let Some(uk_f16) = unary_f16_mirror_kernel(&dev.device) {
6619                                    pass.set_pipeline(&uk_f16.pipeline);
6620                                } else {
6621                                    pass.set_pipeline(&uk.pipeline);
6622                                }
6623                            } else {
6624                                pass.set_pipeline(&uk.pipeline);
6625                            }
6626                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6627                            let (gx, gy, gz) = dispatch_dims(n_s, 64);
6628                            pass.dispatch_workgroups(gx, gy, gz);
6629                        }
6630                        Step::Where { params } => {
6631                            let n_s = scale(params.n);
6632                            if n_s == 0 {
6633                                continue;
6634                            }
6635                            pass.set_pipeline(&wk.pipeline);
6636                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6637                            let (gx, gy, gz) = dispatch_dims(n_s, 64);
6638                            pass.dispatch_workgroups(gx, gy, gz);
6639                        }
6640                        Step::Reduce { params } => {
6641                            let outer_s = scale(params.outer);
6642                            if outer_s == 0 {
6643                                continue;
6644                            }
6645                            let rk = reduce_kernel(&dev.device);
6646                            pass.set_pipeline(&rk.pipeline);
6647                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6648                            let total_out = outer_s.saturating_mul(params.inner);
6649                            if params.reduce_dim <= 64 {
6650                                // Fast path: 1 thread per output cell.
6651                                let (gx, gy, gz) = dispatch_dims(total_out, 64);
6652                                pass.dispatch_workgroups(gx, gy, gz);
6653                            } else {
6654                                // Tree-reduce path: 1 workgroup (64
6655                                // threads) per output cell, parallel
6656                                // reduction with shared scratch.
6657                                let (gx, gy, gz) = dispatch_dims(total_out, 1);
6658                                pass.dispatch_workgroups(gx, gy, gz);
6659                            }
6660                        }
6661                        Step::Softmax { params } => {
6662                            let outer_s = scale(params.outer);
6663                            if outer_s == 0 {
6664                                continue;
6665                            }
6666                            let sk = softmax_kernel(&dev.device);
6667                            pass.set_pipeline(&sk.pipeline);
6668                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6669                            let (gx, gy, gz) = dispatch_dims(outer_s, 64);
6670                            pass.dispatch_workgroups(gx, gy, gz);
6671                        }
6672                        Step::LayerNorm { params } => {
6673                            let outer_s = scale(params.outer);
6674                            if outer_s == 0 {
6675                                continue;
6676                            }
6677                            let lk = layernorm_kernel(&dev.device);
6678                            pass.set_pipeline(&lk.pipeline);
6679                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6680                            let (gx, gy, gz) = dispatch_dims(outer_s, 64);
6681                            pass.dispatch_workgroups(gx, gy, gz);
6682                        }
6683                        Step::RmsNormBackwardInput { params } => {
6684                            let outer_s = scale(params.outer);
6685                            if outer_s == 0 {
6686                                continue;
6687                            }
6688                            let rk = rms_norm_backward_kernel(&dev.device);
6689                            pass.set_pipeline(&rk.pipeline);
6690                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6691                            pass.dispatch_workgroups(outer_s, 1, 1);
6692                        }
6693                        Step::RmsNormBackwardGamma { params }
6694                        | Step::RmsNormBackwardBeta { params } => {
6695                            if params.inner == 0 {
6696                                continue;
6697                            }
6698                            let rk = rms_norm_backward_param_kernel(&dev.device);
6699                            pass.set_pipeline(&rk.pipeline);
6700                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6701                            pass.dispatch_workgroups(1, 1, 1);
6702                        }
6703                        Step::LayerNormBackwardInput { params } => {
6704                            let outer_s = scale(params.outer);
6705                            if outer_s == 0 {
6706                                continue;
6707                            }
6708                            let lk = layer_norm_backward_input_kernel(&dev.device);
6709                            pass.set_pipeline(&lk.pipeline);
6710                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6711                            pass.dispatch_workgroups(outer_s, 1, 1);
6712                        }
6713                        Step::LayerNormBackwardGammaPartial {
6714                            params,
6715                            num_workgroups,
6716                        } => {
6717                            if params.inner == 0 || *num_workgroups == 0 {
6718                                continue;
6719                            }
6720                            let lk = layer_norm_backward_gamma_partial_kernel(&dev.device);
6721                            pass.set_pipeline(&lk.pipeline);
6722                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6723                            pass.dispatch_workgroups(*num_workgroups, 1, 1);
6724                        }
6725                        Step::LayerNormBackwardGammaReduce { params } => {
6726                            if params.inner == 0 {
6727                                continue;
6728                            }
6729                            let lk = layer_norm_backward_gamma_reduce_kernel(&dev.device);
6730                            pass.set_pipeline(&lk.pipeline);
6731                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6732                            pass.dispatch_workgroups(1, 1, 1);
6733                        }
6734                        Step::CumsumBackward { params } => {
6735                            let outer_s = scale(params.outer);
6736                            if outer_s == 0 {
6737                                continue;
6738                            }
6739                            let ck = cumsum_backward_kernel(&dev.device);
6740                            pass.set_pipeline(&ck.pipeline);
6741                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6742                            let (gx, gy, gz) = dispatch_dims(outer_s, 64);
6743                            pass.dispatch_workgroups(gx, gy, gz);
6744                        }
6745                        Step::RopeBackward { params } => {
6746                            let seq_s = scale(params.seq);
6747                            if seq_s == 0 {
6748                                continue;
6749                            }
6750                            let rk = rope_backward_kernel(&dev.device);
6751                            pass.set_pipeline(&rk.pipeline);
6752                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6753                            let total = params.batch * seq_s * params.hidden;
6754                            let (gx, gy, gz) = dispatch_dims(total, 64);
6755                            pass.dispatch_workgroups(gx, gy, gz);
6756                        }
6757                        Step::GatherBackward { params } => {
6758                            let outer_s = scale(params.outer);
6759                            if outer_s == 0 {
6760                                continue;
6761                            }
6762                            let total = outer_s * params.axis_dim * params.trailing;
6763                            if total > 0 {
6764                                let zk = gather_backward_zero_kernel(&dev.device);
6765                                pass.set_pipeline(&zk.pipeline);
6766                                pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6767                                let (gx, _, _) = dispatch_dims(total, 256);
6768                                pass.dispatch_workgroups(gx, 1, 1);
6769                            }
6770                            let ak = gather_backward_acc_kernel(&dev.device);
6771                            pass.set_pipeline(&ak.pipeline);
6772                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6773                            pass.dispatch_workgroups(outer_s, 1, 1);
6774                        }
6775                        Step::Cumsum { params } => {
6776                            let outer_s = scale(params.outer);
6777                            if outer_s == 0 {
6778                                continue;
6779                            }
6780                            let ck2 = cumsum_kernel(&dev.device);
6781                            pass.set_pipeline(&ck2.pipeline);
6782                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6783                            let (gx, gy, gz) = dispatch_dims(outer_s, 64);
6784                            pass.dispatch_workgroups(gx, gy, gz);
6785                        }
6786                        Step::FftGpu {
6787                            src_off,
6788                            dst_off,
6789                            outer,
6790                            n,
6791                            inverse,
6792                            norm_scale,
6793                        } => {
6794                            let res = &self.fft_gpu_steps[fft_i];
6795                            fft_i += 1;
6796                            crate::fft_dispatch::dispatch_fft_gpu_in_pass(
6797                                &dev.device,
6798                                &dev.queue,
6799                                &mut pass,
6800                                res,
6801                                *src_off,
6802                                *dst_off,
6803                                *outer,
6804                                *n,
6805                                *inverse != 0,
6806                                *norm_scale,
6807                            );
6808                        }
6809                        Step::Copy { params } => {
6810                            let n_s = scale(params.n);
6811                            if n_s == 0 {
6812                                continue;
6813                            }
6814                            let ck2 = copy_kernel(&dev.device);
6815                            pass.set_pipeline(&ck2.pipeline);
6816                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6817                            let (gx, gy, gz) = dispatch_dims(n_s, 64);
6818                            pass.dispatch_workgroups(gx, gy, gz);
6819                        }
6820                        Step::BufferCopy { .. } => {
6821                            // Host step: `copy_buffer_to_buffer` runs outside compute passes.
6822                        }
6823                        Step::ElementwiseRegion { params } => {
6824                            let len_s = scale(params.len);
6825                            if len_s == 0 {
6826                                continue;
6827                            }
6828                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6829                            if params.prologue == rlx_ir::REGION_PROLOGUE_RESIZE_NEAREST_2X_NCHW {
6830                                let ek = elementwise_region_spatial_kernel(&dev.device);
6831                                pass.set_pipeline(&ek.pipeline);
6832                                let (gx, gy, gz) = dispatch_prologue_nchw(
6833                                    params.out_w,
6834                                    params.out_h,
6835                                    params.out_n * params.out_c,
6836                                );
6837                                pass.dispatch_workgroups(gx, gy, gz);
6838                            } else {
6839                                let ek = elementwise_region_kernel(&dev.device);
6840                                pass.set_pipeline(&ek.pipeline);
6841                                let (gx, gy, gz) = dispatch_dims(len_s, 64);
6842                                pass.dispatch_workgroups(gx, gy, gz);
6843                            }
6844                        }
6845                        Step::BatchElementwiseRegion { params } => {
6846                            let slice_len_s = scale(params.slice_len);
6847                            let num_batch_s = scale(params.num_batch);
6848                            if slice_len_s == 0 || num_batch_s == 0 {
6849                                continue;
6850                            }
6851                            let ek = batch_elementwise_region_kernel(&dev.device);
6852                            pass.set_pipeline(&ek.pipeline);
6853                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6854                            let (gx, gy, _) = dispatch_dims(slice_len_s, 64);
6855                            pass.dispatch_workgroups(gx, gy, num_batch_s);
6856                        }
6857                        Step::Transpose { params, .. } => {
6858                            // Compute scaled grid count to match the
6859                            // uniform's scaled out_total when bucket axis
6860                            // is outermost.
6861                            let total_s = if params.bucket_outermost == 1 && params.out_dim_0 > 0 {
6862                                let scaled_d0 = scale(params.out_dim_0);
6863                                let inner = params.out_total / params.out_dim_0;
6864                                scaled_d0 * inner
6865                            } else {
6866                                params.out_total
6867                            };
6868                            if total_s == 0 {
6869                                continue;
6870                            }
6871                            let tk = transpose_kernel(&dev.device);
6872                            pass.set_pipeline(&tk.pipeline);
6873                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6874                            let (gx, gy, gz) = dispatch_dims(total_s, 64);
6875                            pass.dispatch_workgroups(gx, gy, gz);
6876                        }
6877                        Step::Narrow { params } => {
6878                            let total_s = scale(params.total);
6879                            if total_s == 0 {
6880                                continue;
6881                            }
6882                            let nk = narrow_kernel(&dev.device);
6883                            pass.set_pipeline(&nk.pipeline);
6884                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6885                            let (gx, gy, gz) = dispatch_dims(total_s, 64);
6886                            pass.dispatch_workgroups(gx, gy, gz);
6887                        }
6888                        Step::Concat { params } => {
6889                            let total_s = scale(params.total);
6890                            if total_s == 0 {
6891                                continue;
6892                            }
6893                            let cck = concat_kernel(&dev.device);
6894                            pass.set_pipeline(&cck.pipeline);
6895                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6896                            let (gx, gy, gz) = dispatch_dims(total_s, 64);
6897                            pass.dispatch_workgroups(gx, gy, gz);
6898                        }
6899                        Step::Gather { params } => {
6900                            let n_out_s = scale(params.n_out);
6901                            if n_out_s == 0 {
6902                                continue;
6903                            }
6904                            let gk = gather_kernel(&dev.device);
6905                            pass.set_pipeline(&gk.pipeline);
6906                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6907                            let (gx, gy, gz) = dispatch_dims(n_out_s, 64);
6908                            pass.dispatch_workgroups(gx, gy, gz);
6909                        }
6910                        Step::GatherAxis { params } => {
6911                            let total_s = scale(params.total);
6912                            if total_s == 0 {
6913                                continue;
6914                            }
6915                            let gk = gather_axis_kernel(&dev.device);
6916                            pass.set_pipeline(&gk.pipeline);
6917                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6918                            let (gx, gy, gz) = dispatch_dims(total_s, 64);
6919                            pass.dispatch_workgroups(gx, gy, gz);
6920                        }
6921                        Step::Attention { params, .. } => {
6922                            // Scale seq_q for grid dim; per-head strides
6923                            // come from seq_q_stride / seq_k_stride (full
6924                            // extent) inside the WGSL.
6925                            let seq_q_s = scale(params.seq_q);
6926                            if seq_q_s == 0 {
6927                                continue;
6928                            }
6929                            let ak = attention_kernel(&dev.device);
6930                            pass.set_pipeline(&ak.pipeline);
6931                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6932                            let total = params.batch * params.heads * seq_q_s;
6933                            let (gx, gy, gz) = dispatch_dims(total, 64);
6934                            pass.dispatch_workgroups(gx, gy, gz);
6935                        }
6936                        Step::AttentionBackward { params, .. } => {
6937                            let axis = if params.wrt == 0 {
6938                                params.seq_q
6939                            } else {
6940                                params.seq_k
6941                            };
6942                            let axis_s = scale(axis);
6943                            if axis_s == 0 {
6944                                continue;
6945                            }
6946                            let ak = attention_bwd_kernel(&dev.device);
6947                            pass.set_pipeline(&ak.pipeline);
6948                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6949                            let total = params.batch * params.heads * axis_s;
6950                            let (gx, gy, gz) = dispatch_dims(total, 64);
6951                            pass.dispatch_workgroups(gx, gy, gz);
6952                        }
6953                        Step::Rope { params } => {
6954                            // Multi-batch via stride-field WGSL fix:
6955                            // iterate `batch * scaled_seq * last_dim` items.
6956                            let s_active = scale(params.seq);
6957                            let total_s = params.batch * s_active * params.last_dim;
6958                            if total_s == 0 {
6959                                continue;
6960                            }
6961                            let rk = rope_kernel(&dev.device);
6962                            pass.set_pipeline(&rk.pipeline);
6963                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6964                            let (gx, gy, gz) = dispatch_dims(total_s, 64);
6965                            pass.dispatch_workgroups(gx, gy, gz);
6966                        }
6967                        Step::Expand { params, .. } => {
6968                            let total_s = if params.bucket_outermost == 1 && params.out_dim_0 > 0 {
6969                                let scaled_d0 = scale(params.out_dim_0);
6970                                let inner = params.out_total / params.out_dim_0;
6971                                scaled_d0 * inner
6972                            } else {
6973                                params.out_total
6974                            };
6975                            if total_s == 0 {
6976                                continue;
6977                            }
6978                            let ek = expand_kernel(&dev.device);
6979                            pass.set_pipeline(&ek.pipeline);
6980                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6981                            let (gx, gy, gz) = dispatch_dims(total_s, 64);
6982                            pass.dispatch_workgroups(gx, gy, gz);
6983                        }
6984                        Step::Argmax { params } => {
6985                            let outer_s = scale(params.outer);
6986                            if outer_s == 0 {
6987                                continue;
6988                            }
6989                            let amk = argmax_kernel(&dev.device);
6990                            pass.set_pipeline(&amk.pipeline);
6991                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
6992                            let (gx, gy, gz) = dispatch_dims(outer_s, 64);
6993                            pass.dispatch_workgroups(gx, gy, gz);
6994                        }
6995                        Step::Pool2d { params } => {
6996                            let n_s = scale(params.n);
6997                            if n_s == 0 {
6998                                continue;
6999                            }
7000                            let pk = pool2d_kernel(&dev.device);
7001                            pass.set_pipeline(&pk.pipeline);
7002                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
7003                            let total = n_s * params.c * params.h_out * params.w_out;
7004                            let (gx, gy, gz) = dispatch_dims(total, 64);
7005                            pass.dispatch_workgroups(gx, gy, gz);
7006                        }
7007                        Step::Conv2d { params } => {
7008                            let n_s = scale(params.n);
7009                            if n_s == 0 {
7010                                continue;
7011                            }
7012                            let ck2 = conv2d_kernel(&dev.device);
7013                            pass.set_pipeline(&ck2.pipeline);
7014                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
7015                            let total = n_s * params.c_out * params.h_out * params.w_out;
7016                            let (gx, gy, gz) = dispatch_dims(total, 64);
7017                            pass.dispatch_workgroups(gx, gy, gz);
7018                        }
7019                        Step::Pool1d { params } => {
7020                            let n_s = scale(params.n);
7021                            if n_s == 0 {
7022                                continue;
7023                            }
7024                            let pk = pool1d_kernel(&dev.device);
7025                            pass.set_pipeline(&pk.pipeline);
7026                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
7027                            let total = n_s * params.c * params.l_out;
7028                            let (gx, gy, gz) = dispatch_dims(total, 64);
7029                            pass.dispatch_workgroups(gx, gy, gz);
7030                        }
7031                        Step::Pool3d { params } => {
7032                            let n_s = scale(params.n);
7033                            if n_s == 0 {
7034                                continue;
7035                            }
7036                            let pk = pool3d_kernel(&dev.device);
7037                            pass.set_pipeline(&pk.pipeline);
7038                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
7039                            let total = n_s * params.c * params.d_out * params.h_out * params.w_out;
7040                            let (gx, gy, gz) = dispatch_dims(total, 64);
7041                            pass.dispatch_workgroups(gx, gy, gz);
7042                        }
7043                        Step::Conv1d { params } => {
7044                            let n_s = scale(params.n);
7045                            if n_s == 0 {
7046                                continue;
7047                            }
7048                            let ck = conv1d_kernel(&dev.device);
7049                            pass.set_pipeline(&ck.pipeline);
7050                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
7051                            let total = n_s * params.c_out * params.l_out;
7052                            let (gx, gy, gz) = dispatch_dims(total, 64);
7053                            pass.dispatch_workgroups(gx, gy, gz);
7054                        }
7055                        Step::Conv3d { params } => {
7056                            let n_s = scale(params.n);
7057                            if n_s == 0 {
7058                                continue;
7059                            }
7060                            let ck = conv3d_kernel(&dev.device);
7061                            pass.set_pipeline(&ck.pipeline);
7062                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
7063                            let total =
7064                                n_s * params.c_out * params.d_out * params.h_out * params.w_out;
7065                            let (gx, gy, gz) = dispatch_dims(total, 64);
7066                            pass.dispatch_workgroups(gx, gy, gz);
7067                        }
7068                        Step::ScatterAdd { params } => {
7069                            let sk = scatter_add_kernel(&dev.device);
7070                            pass.set_pipeline(&sk.pipeline);
7071                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
7072                            // Phase 0 zeros the FULL output (preserves
7073                            // accumulator semantics). Phase 1 scatters first
7074                            // num_updates_active updates only; serial single
7075                            // workgroup either way (atomic CAS unsupported in
7076                            // naga's MSL emitter — see scatter_add.wgsl).
7077                            if params.op == 0 {
7078                                let (gx, gy, gz) = dispatch_dims(params.out_total, 64);
7079                                pass.dispatch_workgroups(gx, gy, gz);
7080                            } else {
7081                                pass.dispatch_workgroups(1, 1, 1);
7082                            }
7083                        }
7084                        Step::TopK { params } => {
7085                            let outer_s = scale(params.outer);
7086                            if outer_s == 0 {
7087                                continue;
7088                            }
7089                            let tk = topk_kernel(&dev.device);
7090                            pass.set_pipeline(&tk.pipeline);
7091                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
7092                            let (gx, gy, gz) = dispatch_dims(outer_s, 64);
7093                            pass.dispatch_workgroups(gx, gy, gz);
7094                        }
7095                        Step::WelchPeaksGpu { params } => {
7096                            let batch_s = scale(params.welch_batch);
7097                            if batch_s == 0 {
7098                                continue;
7099                            }
7100                            let wk = welch_peaks_gpu_kernel(&dev.device);
7101                            pass.set_pipeline(&wk.pipeline);
7102                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
7103                            let (gx, gy, gz) = dispatch_dims(batch_s, 64);
7104                            pass.dispatch_workgroups(gx, gy, gz);
7105                        }
7106                        Step::UmapKnn { params } => {
7107                            let n_s = scale(params.n);
7108                            if n_s == 0 {
7109                                continue;
7110                            }
7111                            let uk = umap_knn_kernel(&dev.device);
7112                            pass.set_pipeline(&uk.pipeline);
7113                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
7114                            let (gx, gy, gz) = dispatch_dims(n_s, 64);
7115                            pass.dispatch_workgroups(gx, gy, gz);
7116                        }
7117                        Step::GroupedMatmul { params } => {
7118                            let m_s = scale(params.m);
7119                            if m_s == 0 {
7120                                continue;
7121                            }
7122                            let gk = grouped_matmul_kernel(&dev.device);
7123                            pass.set_pipeline(&gk.pipeline);
7124                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
7125                            pass.dispatch_workgroups(params.n.div_ceil(8), m_s.div_ceil(8), 1);
7126                        }
7127                        Step::Sample { params } => {
7128                            let outer_s = scale(params.outer);
7129                            if outer_s == 0 {
7130                                continue;
7131                            }
7132                            let sk = sample_kernel(&dev.device);
7133                            pass.set_pipeline(&sk.pipeline);
7134                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
7135                            let (gx, gy, gz) = dispatch_dims(outer_s, 64);
7136                            pass.dispatch_workgroups(gx, gy, gz);
7137                        }
7138                        Step::SelectiveScan { params } => {
7139                            // Predicate-gated to batch=1; the seq scaling
7140                            // happens inside the kernel (uniform sees scaled
7141                            // seq). Dispatch grid here is per-(batch, hidden);
7142                            // unaffected by seq scaling.
7143                            let ssk = selective_scan_kernel(&dev.device);
7144                            pass.set_pipeline(&ssk.pipeline);
7145                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
7146                            let total = params.batch * params.hidden;
7147                            let (gx, gy, gz) = dispatch_dims(total, 64);
7148                            pass.dispatch_workgroups(gx, gy, gz);
7149                        }
7150                        Step::DequantMatmul { params } => {
7151                            let m_s = scale(params.m);
7152                            if m_s == 0 {
7153                                continue;
7154                            }
7155                            let dk = dequant_matmul_kernel(&dev.device);
7156                            pass.set_pipeline(&dk.pipeline);
7157                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
7158                            pass.dispatch_workgroups(params.n.div_ceil(8), m_s.div_ceil(8), 1);
7159                        }
7160                        Step::FusedResidualLn { params } => {
7161                            let outer_s = scale(params.outer);
7162                            if outer_s == 0 {
7163                                continue;
7164                            }
7165                            let frk = fused_residual_ln_kernel(&dev.device);
7166                            pass.set_pipeline(&frk.pipeline);
7167                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
7168                            let (gx, gy, gz) = dispatch_dims(outer_s, 64);
7169                            pass.dispatch_workgroups(gx, gy, gz);
7170                        }
7171                        Step::FusedResidualLnTee { params } => {
7172                            let outer_s = scale(params.outer);
7173                            if outer_s == 0 {
7174                                continue;
7175                            }
7176                            let frtk = fused_residual_ln_tee_kernel(&dev.device);
7177                            pass.set_pipeline(&frtk.pipeline);
7178                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
7179                            let (gx, gy, gz) = dispatch_dims(outer_s, 64);
7180                            pass.dispatch_workgroups(gx, gy, gz);
7181                        }
7182                        Step::FusedResidualRmsNorm { params } => {
7183                            let outer_s = scale(params.outer);
7184                            if outer_s == 0 {
7185                                continue;
7186                            }
7187                            let frk = fused_residual_rms_norm_kernel(&dev.device);
7188                            pass.set_pipeline(&frk.pipeline);
7189                            pass.set_bind_group(0, &self.bind_groups[gpu_bi], &[]);
7190                            let (gx, gy, gz) = dispatch_dims(outer_s, 64);
7191                            pass.dispatch_workgroups(gx, gy, gz);
7192                        }
7193                        Step::MatmulQkv { params, kind } => {
7194                            let m_s = scale(params.m);
7195                            if m_s == 0 {
7196                                continue;
7197                            }
7198                            let qkv_coop_wide = matches!(kind, MatmulQkvKind::CoopF16Vk)
7199                                && crate::coop_f16_vk::use_wide_matmul(
7200                                    params.b_off,
7201                                    params.n,
7202                                    &self.coop_f16_b_param,
7203                                    &self.coop_f16_vk_wide_b,
7204                                );
7205                            pass.set_bind_group(
7206                                0,
7207                                coop_f16_vk_bind_group(self, gpu_bi, qkv_coop_wide),
7208                                &[],
7209                            );
7210                            match kind {
7211                                MatmulQkvKind::CoopF16Vk => {
7212                                    if qkv_coop_wide {
7213                                        pass.set_pipeline(&matmul_qkv_kernel(&dev.device).pipeline);
7214                                        pass.dispatch_workgroups(
7215                                            params.n.div_ceil(32),
7216                                            m_s.div_ceil(32),
7217                                            1,
7218                                        );
7219                                    } else {
7220                                        let n_eff = scale(params.n);
7221                                        let mqk = matmul_qkv_coop_f16_vk_active_kernel(
7222                                            &dev.device,
7223                                            n_eff,
7224                                        )
7225                                        .expect("coop f16 matmul_qkv kernel missing");
7226                                        pass.set_pipeline(&mqk.pipeline);
7227                                        pass.dispatch_workgroups(
7228                                            m_s.div_ceil(16),
7229                                            params.n.div_ceil(16),
7230                                            1,
7231                                        );
7232                                    }
7233                                }
7234                                MatmulQkvKind::CoopF32 => {
7235                                    pass.set_pipeline(
7236                                        &matmul_qkv_coop_f32_kernel(&dev.device)
7237                                            .expect("coop matmul_qkv kernel missing")
7238                                            .pipeline,
7239                                    );
7240                                    pass.dispatch_workgroups(
7241                                        params.n.div_ceil(32),
7242                                        m_s.div_ceil(32),
7243                                        1,
7244                                    );
7245                                }
7246                                MatmulQkvKind::F32 => {
7247                                    pass.set_pipeline(&matmul_qkv_kernel(&dev.device).pipeline);
7248                                    pass.dispatch_workgroups(
7249                                        params.n.div_ceil(32),
7250                                        m_s.div_ceil(32),
7251                                        1,
7252                                    );
7253                                }
7254                            }
7255                        }
7256                        Step::DequantMatmulGguf { .. }
7257                        | Step::DequantGroupedMatmulGguf { .. }
7258                        | Step::GatedDeltaNet { .. }
7259                        | Step::Llada2GroupLimitedGate { .. }
7260                        | Step::UmapKnnHost { .. }
7261                        | Step::FftHost { .. }
7262                        | Step::Im2ColHost { .. }
7263                        | Step::WelchPeaksHost { .. }
7264                        | Step::LogMelHost { .. }
7265                        | Step::LogMelBackwardHost { .. } => {}
7266                        #[cfg(feature = "splat")]
7267                        Step::GaussianSplatRender { .. }
7268                        | Step::GaussianSplatRenderBackward { .. }
7269                        | Step::GaussianSplatPrepare { .. }
7270                        | Step::GaussianSplatRasterize { .. } => {}
7271                    }
7272                    if !matches!(step, Step::FftGpu { .. }) {
7273                        gpu_bi += 1;
7274                    }
7275                    step_i += 1;
7276                    pass_dispatched = true;
7277                }
7278            }
7279            let needs_f16_drain = step_i < self.schedule.len()
7280                && !step_runs_on_host(&self.schedule[step_i])
7281                && step_i > 0
7282                && step_needs_pass_flush(&self.schedule[step_i], &self.schedule[step_i - 1]);
7283            let gpu_schedule_done = step_i >= self.schedule.len();
7284            let skip_readback = rlx_ir::env::flag("RLX_BENCH_DISPATCH_ONLY");
7285            let defer_tail = gpu_schedule_done && self.schedule.iter().any(step_is_tail_host);
7286            let mut fused_readback: Option<(
7287                ReadbackLayout,
7288                std::sync::mpsc::Receiver<Result<(), wgpu::BufferAsyncError>>,
7289                Vec<usize>,
7290            )> = None;
7291            if gpu_schedule_done && !skip_readback && !defer_tail {
7292                if !self.gpu_handle_feeds.is_empty() {
7293                    self.propagate_gpu_handle_feeds_on_gpu(dev, &mut enc);
7294                }
7295                let plan = self.readback_plan();
7296                let out_ids_all: Vec<_> = self.graph.outputs.clone();
7297                let out_ids: Vec<_> = plan.iter().map(|&i| out_ids_all[i]).collect();
7298                let layout = ReadbackLayout::for_nodes(&self.arena, &out_ids);
7299                if use_tiny_readback(&layout, out_ids.len()) && plan == vec![0] {
7300                    if self.tiny_readback.is_none() {
7301                        self.tiny_readback = Some(TinyReadbackStaging::new(&dev.device));
7302                    }
7303                    let tiny = self.tiny_readback.as_ref().expect("tiny readback");
7304                    encode_readback_copies(&mut enc, &self.arena, tiny.buffer(), &out_ids, &layout);
7305                    let map_rx = schedule_readback_map(&mut enc, tiny.buffer(), &layout);
7306                    dev.queue.submit(std::iter::once(enc.finish()));
7307                    let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());
7308                    wait_readback_map(&dev.device, &map_rx, layout.total_bytes);
7309                    map_rx.recv().unwrap().unwrap();
7310                    return self.pack_readback_outputs(
7311                        &plan,
7312                        vec![decode_tiny_mapped_f32(tiny.buffer(), layout.total_bytes)],
7313                    );
7314                }
7315                ReadbackStaging::prepare(
7316                    &dev.device,
7317                    &mut self.readback_staging,
7318                    layout.total_bytes,
7319                );
7320                if let Some(staging) = self.readback_staging.as_ref() {
7321                    encode_readback_copies(
7322                        &mut enc,
7323                        &self.arena,
7324                        staging.buffer(),
7325                        &out_ids,
7326                        &layout,
7327                    );
7328                    let map_rx = schedule_readback_map(&mut enc, staging.buffer(), &layout);
7329                    fused_readback = Some((layout, map_rx, plan));
7330                }
7331            }
7332            dev.queue.submit(std::iter::once(enc.finish()));
7333            if defer_tail {
7334                let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());
7335                self.run_tail_host_audio_ops(dev);
7336                if !skip_readback {
7337                    let mut rb_enc =
7338                        dev.device
7339                            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
7340                                label: Some("rlx-wgpu readback after tail-host"),
7341                            });
7342                    if !self.gpu_handle_feeds.is_empty() {
7343                        self.propagate_gpu_handle_feeds_on_gpu(dev, &mut rb_enc);
7344                    }
7345                    let plan = self.readback_plan();
7346                    let out_ids_all: Vec<_> = self.graph.outputs.clone();
7347                    let out_ids: Vec<_> = plan.iter().map(|&i| out_ids_all[i]).collect();
7348                    let layout = ReadbackLayout::for_nodes(&self.arena, &out_ids);
7349                    if use_tiny_readback(&layout, out_ids.len()) && plan == vec![0] {
7350                        if self.tiny_readback.is_none() {
7351                            self.tiny_readback = Some(TinyReadbackStaging::new(&dev.device));
7352                        }
7353                        let tiny = self.tiny_readback.as_ref().expect("tiny readback");
7354                        encode_readback_copies(
7355                            &mut rb_enc,
7356                            &self.arena,
7357                            tiny.buffer(),
7358                            &out_ids,
7359                            &layout,
7360                        );
7361                        let map_rx = schedule_readback_map(&mut rb_enc, tiny.buffer(), &layout);
7362                        dev.queue.submit(std::iter::once(rb_enc.finish()));
7363                        wait_readback_map(&dev.device, &map_rx, layout.total_bytes);
7364                        map_rx.recv().unwrap().unwrap();
7365                        return self.pack_readback_outputs(
7366                            &plan,
7367                            vec![decode_tiny_mapped_f32(tiny.buffer(), layout.total_bytes)],
7368                        );
7369                    }
7370                    ReadbackStaging::prepare(
7371                        &dev.device,
7372                        &mut self.readback_staging,
7373                        layout.total_bytes,
7374                    );
7375                    if let Some(staging) = self.readback_staging.as_ref() {
7376                        encode_readback_copies(
7377                            &mut rb_enc,
7378                            &self.arena,
7379                            staging.buffer(),
7380                            &out_ids,
7381                            &layout,
7382                        );
7383                        let map_rx = schedule_readback_map(&mut rb_enc, staging.buffer(), &layout);
7384                        dev.queue.submit(std::iter::once(rb_enc.finish()));
7385                        wait_readback_map(&dev.device, &map_rx, layout.total_bytes);
7386                        map_rx.recv().unwrap().unwrap();
7387                        self.dump_node_stats_if_requested(dev);
7388                        let partial = decode_mapped_readback_f32(staging.buffer(), &layout);
7389                        return self.pack_readback_outputs(&plan, partial);
7390                    }
7391                }
7392            }
7393            if needs_f16_drain {
7394                let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());
7395            }
7396            let need_host_sync =
7397                step_i < self.schedule.len() && step_runs_on_host(&self.schedule[step_i]);
7398            if need_host_sync {
7399                let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());
7400            }
7401            if gpu_schedule_done {
7402                if skip_readback || defer_tail {
7403                    return self
7404                        .graph
7405                        .outputs
7406                        .iter()
7407                        .map(|&id| {
7408                            let n = self.graph.node(id).shape.num_elements().unwrap_or(0);
7409                            vec![0.0; n]
7410                        })
7411                        .collect();
7412                }
7413                if let (Some((layout, map_rx, plan)), Some(staging)) =
7414                    (fused_readback, self.readback_staging.as_ref())
7415                {
7416                    wait_readback_map(&dev.device, &map_rx, layout.total_bytes);
7417                    map_rx.recv().unwrap().unwrap();
7418                    self.dump_node_stats_if_requested(dev);
7419                    let partial = decode_mapped_readback_f32(staging.buffer(), &layout);
7420                    return self.pack_readback_outputs(&plan, partial);
7421                }
7422                break;
7423            }
7424            match &self.schedule[step_i] {
7425                Step::BufferCopy {
7426                    src_byte_off,
7427                    dst_byte_off,
7428                    bytes,
7429                } => {
7430                    // wgpu forbids `copy_buffer_to_buffer` on the same buffer;
7431                    // use the generic copy compute kernel instead.
7432                    let src = *src_byte_off as u64;
7433                    let dst = *dst_byte_off as u64;
7434                    let nbytes = *bytes as u64;
7435                    let elems = (nbytes / 4).max(1) as u32;
7436                    let lo = src.min(dst);
7437                    let hi = src.saturating_add(nbytes).max(dst.saturating_add(nbytes));
7438                    let max_binding = dev.device.limits().max_storage_buffer_binding_size;
7439                    let span = hi.saturating_sub(lo).max(1);
7440                    let mut size = span.div_ceil(256) * 256;
7441                    size = size.max(256).min(max_binding);
7442                    let mut base = (lo / 256) * 256;
7443                    if base.saturating_add(size) > self.arena.size as u64 {
7444                        base = (self.arena.size as u64).saturating_sub(size);
7445                        base = (base / 256) * 256;
7446                    }
7447                    let p = CopyParams {
7448                        n: elems,
7449                        in_off: (src.saturating_sub(base) / 4) as u32,
7450                        out_off: (dst.saturating_sub(base) / 4) as u32,
7451                        _p0: 0,
7452                        _p1: 0,
7453                        _p2: 0,
7454                        _p3: 0,
7455                        _p4: 0,
7456                    };
7457                    let ck = copy_kernel(&dev.device);
7458                    let u = dev.device.create_buffer(&wgpu::BufferDescriptor {
7459                        label: Some("rlx-wgpu arena_copy uniform"),
7460                        size: std::mem::size_of::<CopyParams>() as u64,
7461                        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
7462                        mapped_at_creation: false,
7463                    });
7464                    dev.queue.write_buffer(&u, 0, bytemuck::bytes_of(&p));
7465                    let bg =
7466                        bind_two_buf0_window(&dev.device, ck, &self.arena.buffer, base, size, &u);
7467                    let mut enc =
7468                        dev.device
7469                            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
7470                                label: Some("rlx-wgpu arena_copy"),
7471                            });
7472                    {
7473                        let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
7474                            label: Some("rlx-wgpu arena_copy pass"),
7475                            ..Default::default()
7476                        });
7477                        pass.set_pipeline(&ck.pipeline);
7478                        pass.set_bind_group(0, &bg, &[]);
7479                        let (gx, gy, gz) = dispatch_dims(elems, 64);
7480                        pass.dispatch_workgroups(gx, gy, gz);
7481                    }
7482                    dev.queue.submit(std::iter::once(enc.finish()));
7483                    let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());
7484                }
7485                Step::DequantMatmulGguf {
7486                    m,
7487                    k,
7488                    n,
7489                    scheme_id,
7490                    x_byte_off,
7491                    w_byte_off,
7492                    out_byte_off,
7493                } => {
7494                    crate::gguf_host::run_dequant_matmul_gguf(
7495                        &self.arena,
7496                        &dev.device,
7497                        &dev.queue,
7498                        *m as usize,
7499                        *k as usize,
7500                        *n as usize,
7501                        *scheme_id,
7502                        *x_byte_off as usize,
7503                        *w_byte_off as usize,
7504                        *out_byte_off as usize,
7505                    );
7506                }
7507                Step::DequantGroupedMatmulGguf {
7508                    m,
7509                    k,
7510                    n,
7511                    num_experts,
7512                    scheme_id,
7513                    x_byte_off,
7514                    w_byte_off,
7515                    idx_byte_off,
7516                    out_byte_off,
7517                } => {
7518                    crate::gguf_host::run_dequant_grouped_matmul_gguf(
7519                        &self.arena,
7520                        &dev.device,
7521                        &dev.queue,
7522                        *m as usize,
7523                        *k as usize,
7524                        *n as usize,
7525                        *num_experts as usize,
7526                        *scheme_id,
7527                        *x_byte_off as usize,
7528                        *w_byte_off as usize,
7529                        *idx_byte_off as usize,
7530                        *out_byte_off as usize,
7531                    );
7532                }
7533                Step::GatedDeltaNet {
7534                    q_byte_off,
7535                    k_byte_off,
7536                    v_byte_off,
7537                    g_byte_off,
7538                    beta_byte_off,
7539                    state_byte_off,
7540                    dst_byte_off,
7541                    batch,
7542                    seq,
7543                    heads,
7544                    state_size,
7545                    use_carry,
7546                } => {
7547                    crate::gdn_host::run_gated_delta_net(
7548                        &self.arena,
7549                        &dev.device,
7550                        &dev.queue,
7551                        *q_byte_off as usize,
7552                        *k_byte_off as usize,
7553                        *v_byte_off as usize,
7554                        *g_byte_off as usize,
7555                        *beta_byte_off as usize,
7556                        *state_byte_off as usize,
7557                        *dst_byte_off as usize,
7558                        *batch as usize,
7559                        *seq as usize,
7560                        *heads as usize,
7561                        *state_size as usize,
7562                        *use_carry,
7563                    );
7564                }
7565                Step::Llada2GroupLimitedGate {
7566                    sig_byte_off,
7567                    route_byte_off,
7568                    out_byte_off,
7569                    n_elems,
7570                    attrs,
7571                } => {
7572                    crate::llada2_gate_host::run_llada2_group_limited_gate(
7573                        &self.arena,
7574                        &dev.device,
7575                        &dev.queue,
7576                        *sig_byte_off as usize,
7577                        *route_byte_off as usize,
7578                        *out_byte_off as usize,
7579                        *n_elems as usize,
7580                        attrs,
7581                    );
7582                }
7583                Step::UmapKnnHost {
7584                    pairwise_byte_off,
7585                    out_byte_off,
7586                    n,
7587                    k,
7588                } => {
7589                    crate::umap_knn_host::run_umap_knn(
7590                        &self.arena,
7591                        &dev.device,
7592                        &dev.queue,
7593                        *pairwise_byte_off as usize,
7594                        *out_byte_off as usize,
7595                        *n as usize,
7596                        *k as usize,
7597                    );
7598                }
7599                Step::FftHost {
7600                    src_byte_off,
7601                    dst_byte_off,
7602                    outer,
7603                    n_complex,
7604                    inverse,
7605                    norm_tag,
7606                    dtype_tag,
7607                } => {
7608                    crate::fft_host::run_fft1d(
7609                        &self.arena,
7610                        &dev.device,
7611                        &dev.queue,
7612                        *src_byte_off as usize,
7613                        *dst_byte_off as usize,
7614                        *outer as usize,
7615                        *n_complex as usize,
7616                        *inverse,
7617                        *norm_tag,
7618                        fft_dtype_from_tag(*dtype_tag),
7619                    );
7620                }
7621                Step::WelchPeaksHost { .. }
7622                | Step::LogMelHost { .. }
7623                | Step::LogMelBackwardHost { .. } => {
7624                    unreachable!("tail-host audio ops run after GPU wait")
7625                }
7626                Step::Im2ColHost {
7627                    x_byte_off,
7628                    col_byte_off,
7629                    n,
7630                    c_in,
7631                    h,
7632                    w,
7633                    h_out,
7634                    w_out,
7635                    kh,
7636                    kw,
7637                    sh,
7638                    sw,
7639                    ph,
7640                    pw,
7641                    dh,
7642                    dw_dil,
7643                } => {
7644                    crate::im2col_host::run_im2col(
7645                        &self.arena,
7646                        &dev.device,
7647                        &dev.queue,
7648                        *x_byte_off as usize,
7649                        *col_byte_off as usize,
7650                        *n,
7651                        *c_in,
7652                        *h,
7653                        *w,
7654                        *h_out,
7655                        *w_out,
7656                        *kh,
7657                        *kw,
7658                        *sh,
7659                        *sw,
7660                        *ph,
7661                        *pw,
7662                        *dh,
7663                        *dw_dil,
7664                    );
7665                }
7666                #[cfg(feature = "splat")]
7667                Step::GaussianSplatRender {
7668                    positions_byte_off,
7669                    positions_len,
7670                    scales_byte_off,
7671                    scales_len,
7672                    rotations_byte_off,
7673                    rotations_len,
7674                    opacities_byte_off,
7675                    opacities_len,
7676                    colors_byte_off,
7677                    colors_len,
7678                    sh_coeffs_byte_off,
7679                    sh_coeffs_len,
7680                    meta_byte_off,
7681                    dst_byte_off,
7682                    dst_len,
7683                    width,
7684                    height,
7685                    tile_size,
7686                    radius_scale,
7687                    alpha_cutoff,
7688                    max_splat_steps,
7689                    transmittance_threshold,
7690                    max_list_entries,
7691                } => {
7692                    crate::splat::run_gaussian_splat_render(
7693                        &self.arena,
7694                        &dev.device,
7695                        &dev.queue,
7696                        *positions_byte_off as usize,
7697                        *positions_len as usize,
7698                        *scales_byte_off as usize,
7699                        *scales_len as usize,
7700                        *rotations_byte_off as usize,
7701                        *rotations_len as usize,
7702                        *opacities_byte_off as usize,
7703                        *opacities_len as usize,
7704                        *colors_byte_off as usize,
7705                        *colors_len as usize,
7706                        *sh_coeffs_byte_off as usize,
7707                        *sh_coeffs_len as usize,
7708                        *meta_byte_off as usize,
7709                        *dst_byte_off as usize,
7710                        *dst_len as usize,
7711                        *width,
7712                        *height,
7713                        *tile_size,
7714                        *radius_scale,
7715                        *alpha_cutoff,
7716                        *max_splat_steps,
7717                        *transmittance_threshold,
7718                        *max_list_entries,
7719                    );
7720                }
7721                #[cfg(feature = "splat")]
7722                Step::GaussianSplatPrepare {
7723                    positions_byte_off,
7724                    positions_len,
7725                    scales_byte_off,
7726                    scales_len,
7727                    rotations_byte_off,
7728                    rotations_len,
7729                    opacities_byte_off,
7730                    opacities_len,
7731                    colors_byte_off,
7732                    colors_len,
7733                    sh_coeffs_byte_off,
7734                    sh_coeffs_len,
7735                    meta_byte_off,
7736                    meta_len,
7737                    prep_byte_off,
7738                    prep_len,
7739                    width,
7740                    height,
7741                    tile_size,
7742                    radius_scale,
7743                    alpha_cutoff,
7744                    max_splat_steps,
7745                    transmittance_threshold,
7746                    max_list_entries,
7747                } => {
7748                    crate::splat::run_gaussian_splat_prepare(
7749                        &self.arena,
7750                        &dev.device,
7751                        &dev.queue,
7752                        *positions_byte_off as usize,
7753                        *positions_len as usize,
7754                        *scales_byte_off as usize,
7755                        *scales_len as usize,
7756                        *rotations_byte_off as usize,
7757                        *rotations_len as usize,
7758                        *opacities_byte_off as usize,
7759                        *opacities_len as usize,
7760                        *colors_byte_off as usize,
7761                        *colors_len as usize,
7762                        *sh_coeffs_byte_off as usize,
7763                        *sh_coeffs_len as usize,
7764                        *meta_byte_off as usize,
7765                        *meta_len as usize,
7766                        *prep_byte_off as usize,
7767                        *prep_len as usize,
7768                        *width,
7769                        *height,
7770                        *tile_size,
7771                        *radius_scale,
7772                        *alpha_cutoff,
7773                        *max_splat_steps,
7774                        *transmittance_threshold,
7775                        *max_list_entries,
7776                    );
7777                }
7778                #[cfg(feature = "splat")]
7779                Step::GaussianSplatRasterize {
7780                    prep_byte_off,
7781                    prep_len,
7782                    meta_byte_off,
7783                    meta_len,
7784                    dst_byte_off,
7785                    dst_len,
7786                    count,
7787                    width,
7788                    height,
7789                    tile_size,
7790                    alpha_cutoff,
7791                    max_splat_steps,
7792                    transmittance_threshold,
7793                    max_list_entries,
7794                } => {
7795                    crate::splat::run_gaussian_splat_rasterize(
7796                        &self.arena,
7797                        &dev.device,
7798                        &dev.queue,
7799                        *prep_byte_off as usize,
7800                        *prep_len as usize,
7801                        *meta_byte_off as usize,
7802                        *meta_len as usize,
7803                        *dst_byte_off as usize,
7804                        *dst_len as usize,
7805                        *count as usize,
7806                        *width,
7807                        *height,
7808                        *tile_size,
7809                        *alpha_cutoff,
7810                        *max_splat_steps,
7811                        *transmittance_threshold,
7812                        *max_list_entries,
7813                    );
7814                }
7815                #[cfg(feature = "splat")]
7816                Step::GaussianSplatRenderBackward {
7817                    positions_byte_off,
7818                    positions_len,
7819                    scales_byte_off,
7820                    scales_len,
7821                    rotations_byte_off,
7822                    rotations_len,
7823                    opacities_byte_off,
7824                    opacities_len,
7825                    colors_byte_off,
7826                    colors_len,
7827                    sh_coeffs_byte_off,
7828                    sh_coeffs_len,
7829                    meta_byte_off,
7830                    d_loss_byte_off,
7831                    d_loss_len,
7832                    packed_byte_off,
7833                    packed_len,
7834                    width,
7835                    height,
7836                    tile_size,
7837                    radius_scale,
7838                    alpha_cutoff,
7839                    max_splat_steps,
7840                    transmittance_threshold,
7841                    max_list_entries,
7842                    loss_grad_clip,
7843                    sh_band,
7844                    max_anisotropy,
7845                } => {
7846                    crate::splat::run_gaussian_splat_render_backward(
7847                        &self.arena,
7848                        &dev.device,
7849                        &dev.queue,
7850                        *positions_byte_off as usize,
7851                        *positions_len as usize,
7852                        *scales_byte_off as usize,
7853                        *scales_len as usize,
7854                        *rotations_byte_off as usize,
7855                        *rotations_len as usize,
7856                        *opacities_byte_off as usize,
7857                        *opacities_len as usize,
7858                        *colors_byte_off as usize,
7859                        *colors_len as usize,
7860                        *sh_coeffs_byte_off as usize,
7861                        *sh_coeffs_len as usize,
7862                        *meta_byte_off as usize,
7863                        *d_loss_byte_off as usize,
7864                        *d_loss_len as usize,
7865                        *packed_byte_off as usize,
7866                        *packed_len as usize,
7867                        *width,
7868                        *height,
7869                        *tile_size,
7870                        *radius_scale,
7871                        *alpha_cutoff,
7872                        *max_splat_steps,
7873                        *transmittance_threshold,
7874                        *max_list_entries,
7875                        *loss_grad_clip,
7876                        *sh_band,
7877                        *max_anisotropy,
7878                    );
7879                }
7880                _ => break,
7881            }
7882            step_i += 1;
7883        }
7884
7885        self.dump_node_stats_if_requested(dev);
7886
7887        if rlx_ir::env::flag("RLX_WGPU_NAN_TRACE") {
7888            let mut bad_nodes = Vec::new();
7889            for node in self.graph.nodes() {
7890                if !self.arena.has(node.id) {
7891                    continue;
7892                }
7893                // Skip leaves — populated by host writes, not kernels.
7894                if matches!(
7895                    node.op,
7896                    rlx_ir::Op::Input { .. }
7897                        | rlx_ir::Op::Param { .. }
7898                        | rlx_ir::Op::Constant { .. }
7899                ) {
7900                    continue;
7901                }
7902                let data = self.arena.read_f32(&dev.device, &dev.queue, node.id);
7903                let nan_count = data.iter().filter(|v| v.is_nan()).count();
7904                let inf_count = data.iter().filter(|v| v.is_infinite()).count();
7905                if nan_count > 0 || inf_count > 0 {
7906                    // Capture first NaN index + the values around it.
7907                    let first_nan = data.iter().position(|v| v.is_nan());
7908                    if let Some(idx) = first_nan {
7909                        let lo = idx.saturating_sub(2);
7910                        let hi = (idx + 3).min(data.len());
7911                        eprintln!(
7912                            "  node {:?} op={:?} len={} nan={} inf={} \
7913                                   first_nan_idx={} ctx={:?}",
7914                            node.id,
7915                            node.op,
7916                            data.len(),
7917                            nan_count,
7918                            inf_count,
7919                            idx,
7920                            &data[lo..hi]
7921                        );
7922                    }
7923                    bad_nodes.push((node.id, data.len(), nan_count, inf_count));
7924                    if bad_nodes.len() >= 3 {
7925                        break;
7926                    }
7927                }
7928            }
7929            if bad_nodes.is_empty() {
7930                eprintln!("[wgpu-nan-trace] no NaN/Inf in any node — clean run");
7931            } else {
7932                eprintln!(
7933                    "[wgpu-nan-trace] first {} bad nodes (above)",
7934                    bad_nodes.len()
7935                );
7936            }
7937        }
7938
7939        if rlx_ir::env::flag("RLX_BENCH_DISPATCH_ONLY") {
7940            return self
7941                .graph
7942                .outputs
7943                .iter()
7944                .map(|&id| {
7945                    let n = self.graph.node(id).shape.num_elements().unwrap_or(0);
7946                    vec![0.0; n]
7947                })
7948                .collect();
7949        }
7950        let out_ids: Vec<_> = self.graph.outputs.clone();
7951        read_f32_many_pooled(
7952            &self.arena,
7953            &dev.device,
7954            &dev.queue,
7955            &out_ids,
7956            &mut self.readback_staging,
7957        )
7958    }
7959}
7960
/// Compute the workgroup grid for a kernel dispatched over a
/// width × height × planes workload.
///
/// `w` and `h` are each divided into tiles of 8 (rounded up) and `nc` is
/// used directly as the Z dimension. Every dimension is clamped to at
/// least 1, so a zero-sized input still yields a `(1, 1, 1)` dispatch.
///
/// NOTE(review): the tile edge of 8 presumably mirrors an 8×8 workgroup
/// size in the matching kernel, and `nc` is presumably batch × channels —
/// confirm against the shader.
fn dispatch_prologue_nchw(w: u32, h: u32, nc: u32) -> (u32, u32, u32) {
    (w.div_ceil(8).max(1), h.div_ceil(8).max(1), nc.max(1))
}

/// Compute a (X, Y, 1) workgroup grid for a 1-D workload.
///
/// WebGPU caps `dispatch_workgroups` per-dimension at 65535. For
/// workloads beyond `65535 × workgroup_size_x` threads we split into
/// a 2-D grid; kernels recover the linear thread index via
/// `gid.x + gid.y * num_workgroups.x * 64u`.
///
/// * `threads_total` — number of threads the workload needs.
/// * `workgroup_size` — threads per workgroup along X.
///
/// Returns `(groups, 1, 1)` when the rounded-up group count fits in one
/// dimension, otherwise `(65535, ceil(groups / 65535), 1)`. A zero
/// `threads_total` yields `(0, 1, 1)`.
///
/// # Panics
///
/// Panics if `workgroup_size` is 0 (division by zero in `div_ceil`).
fn dispatch_dims(threads_total: u32, workgroup_size: u32) -> (u32, u32, u32) {
    /// Per-dimension workgroup-count cap applied to the X axis.
    const MAX_GROUPS_PER_DIM: u32 = 65535;

    let groups = threads_total.div_ceil(workgroup_size);
    if groups <= MAX_GROUPS_PER_DIM {
        (groups, 1, 1)
    } else {
        // Saturate X at the cap and spill the remainder into Y. The last Y
        // row may be only partially needed.
        (
            MAX_GROUPS_PER_DIM,
            groups.div_ceil(MAX_GROUPS_PER_DIM),
            1,
        )
    }
}
7981
7982/// Shape/feature gate for CoopF16Vk (no operand tracing — avoids circular
7983/// dependency with compile-time f16 mirror planning).
7984///
7985/// **Default OFF.** The Vulkan/DX12 cooperative-matrix matmul path
7986/// silently produces wrong output on BERT-family attention chains on at
7987/// least RTX 4090 (verified empirically against Bio_ClinicalBERT:
7988/// encoder cosine collapses from ≈1.0 on the wide-F32 fallback to ≈0.09
7989/// when the coop path runs, regardless of whether the kernel uses
7990/// F16-acc or F32-acc accumulators). The root cause is upstream — likely
7991/// in how wgpu's `coopLoadT` / `coopMultiplyAdd` interact with strided
7992/// arena buffers on non-Apple drivers — and needs a focused
7993/// reproducer before it can be fixed in `rlx-wgpu`. Until then the
7994/// correctness-first default is to route Vulkan/DX12 matmuls through the
7995/// wide-F32 path, even though it's substantially slower (~80× on this
7996/// shape).
7997///
7998/// Opt back in (at the user's risk) with `RLX_WGPU_COOP_F16_VK_ENABLE=1`
7999/// — useful for measuring the perf headroom or for non-BERT models
8000/// where the precision loss may be acceptable. Legacy
8001/// `RLX_WGPU_NO_COOP_F16_VK=1` and explicit
8002/// `RLX_WGPU_COOP_F16_VK_DISABLE=1` are honored for completeness.
8003fn coop_f16_vk_eligible(dev: &wgpu::Device, m: u32, k: u32, n: u32) -> bool {
8004    if rlx_ir::env::flag("RLX_WGPU_NO_COOP_F16_VK")
8005        || rlx_ir::env::flag("RLX_WGPU_COOP_F16_VK_DISABLE")
8006    {
8007        return false;
8008    }
8009    if !rlx_ir::env::flag("RLX_WGPU_COOP_F16_VK_ENABLE") {
8010        return false;
8011    }
8012    m.is_multiple_of(16)
8013        && k.is_multiple_of(16)
8014        && n.is_multiple_of(16)
8015        && dev
8016            .features()
8017            .contains(wgpu::Features::EXPERIMENTAL_COOPERATIVE_MATRIX)
8018        && dev.features().contains(wgpu::Features::SHADER_F16)
8019        && crate::device::coop_discrete_backend()
8020        && crate::device::coop_f16_16x16_supported()
8021}
8022
8023fn step_needs_pass_flush(step: &Step, prev: &Step) -> bool {
8024    match step {
8025        Step::CastF32ToF16 { .. } => matches!(
8026            prev,
8027            Step::Unary {
8028                f16_mirror: false,
8029                ..
8030            }
8031        ),
8032        Step::Matmul {
8033            compute_precision: MatmulCompute::CoopF16Vk,
8034            ..
8035        }
8036        | Step::MatmulQkv {
8037            kind: MatmulQkvKind::CoopF16Vk,
8038            ..
8039        } => matches!(prev, Step::Unary { .. } | Step::CastF32ToF16 { .. }),
8040        _ => false,
8041    }
8042}
8043
8044fn dispatch_wide_f32_matmul(
8045    pass: &mut wgpu::ComputePass<'_>,
8046    mm_w_active: &Kernel,
8047    mm_k: &Kernel,
8048    m_s: u32,
8049    n: u32,
8050    batch: u32,
8051) {
8052    // Tile-size selection differs by GPU backend.
8053    //
8054    // **Vulkan / DX12** (`matmul_wide_nv`, 64×64 tile): when `m_s < 64`
8055    // the bottom rows of every workgroup's M-axis tile contain padded
8056    // zeros that the kernel still computes and writes back — pure
8057    // wasted work on small-M shapes like BERT-base prefill (m=32). The
8058    // regular 32×32-tile kernel sidesteps the M-axis padding and is
8059    // ~8% faster end-to-end on RTX 4090 (verified on Bio_ClinicalBERT:
8060    // encoder forward 58.9 ms → 54.1 ms at cosine 0.9999995 vs HF).
8061    //
8062    // **Metal / other** (`matmul_wide`, 64×64 tile): the wider tile
8063    // wins even on small M — Apple GPUs prefer the larger workgroup
8064    // and amortize the M-padding well. Forcing the 32×32 kernel here
8065    // regresses Mac WGPU encoder time (26.6 → 29.1 ms verified).
8066    let backend = wgpu_device()
8067        .map(|d| d.backend)
8068        .unwrap_or(wgpu::Backend::Noop);
8069    let is_vulkan_dx12 = matches!(backend, wgpu::Backend::Vulkan | wgpu::Backend::Dx12);
8070    let prefer_small_for_m = is_vulkan_dx12 && m_s < 64;
8071    let use_wide = !prefer_small_for_m && m_s >= 32 && n >= 64;
8072    if use_wide {
8073        pass.set_pipeline(&mm_w_active.pipeline);
8074        let (gx, gy) = if is_vulkan_dx12 {
8075            (n.div_ceil(64), m_s.div_ceil(64))
8076        } else {
8077            (n.div_ceil(64), m_s.div_ceil(32))
8078        };
8079        pass.dispatch_workgroups(gx, gy, batch);
8080    } else {
8081        pass.set_pipeline(&mm_k.pipeline);
8082        pass.dispatch_workgroups(n.div_ceil(32), m_s.div_ceil(32), batch);
8083    }
8084}
8085
8086fn coop_f16_vk_bind_group(exe: &WgpuExecutable, gpu_bi: usize, use_wide: bool) -> &wgpu::BindGroup {
8087    if use_wide {
8088        exe.coop_f16_vk_wide_bind_groups
8089            .get(&gpu_bi)
8090            .unwrap_or(&exe.bind_groups[gpu_bi])
8091    } else {
8092        &exe.bind_groups[gpu_bi]
8093    }
8094}
8095
8096fn require_equal_shapes(graph: &Graph, ids: &[NodeId], op_name: &str) {
8097    let s0 = graph.node(ids[0]).shape.num_elements().unwrap_or(0);
8098    for &id in &ids[1..] {
8099        let si = graph.node(id).shape.num_elements().unwrap_or(0);
8100        if si != s0 {
8101            panic!(
8102                "rlx-wgpu {op_name}: broadcasting not yet implemented; \
8103                    inputs must have the same element count (got {s0} vs {si})"
8104            );
8105        }
8106    }
8107}
8108
8109/// Bind the entire arena in one storage buffer range when it fits the device limit.
8110fn arena_whole_arena_bind(arena: &Arena, max_binding: u64) -> Option<(u64, u64)> {
8111    let need = arena.size as u64;
8112    if need > max_binding {
8113        return None;
8114    }
8115    // Bind size must not exceed the allocated buffer (planner may leave a small tail gap).
8116    let buf_bytes = arena.buffer.size();
8117    let size = need.min(buf_bytes).max(256);
8118    Some((0, size))
8119}
8120
8121fn arena_window_for_nodes(dev: &wgpu::Device, arena: &Arena, ids: &[NodeId]) -> (u64, u64) {
8122    // wgpu requires storage buffer binding offsets aligned to 256 bytes.
8123    const ALIGN: u64 = 256;
8124    let max_binding = dev.limits().max_storage_buffer_binding_size;
8125    if let Some(w) = arena_whole_arena_bind(arena, max_binding) {
8126        return w;
8127    }
8128    let mut lo: u64 = u64::MAX;
8129    let mut hi: u64 = 0;
8130    for &id in ids {
8131        let off = arena.offset(id) as u64;
8132        let len = arena.len_of(id) as u64;
8133        lo = lo.min(off);
8134        hi = hi.max(off.saturating_add(len));
8135    }
8136    if lo == u64::MAX {
8137        return (0, max_binding.max(256));
8138    }
8139    let span = hi.saturating_sub(lo).max(1);
8140    if span > max_binding {
8141        let mut details = String::new();
8142        for &id in ids.iter().take(6) {
8143            let off = arena.offset(id);
8144            let len = arena.len_of(id);
8145            details.push_str(&format!(" id={id:?}@{off}+{len};"));
8146        }
8147        panic!(
8148            "rlx-wgpu: op needs {} bytes of arena span (>{});{}",
8149            span, max_binding, details
8150        );
8151    }
8152    let mut base = (lo / ALIGN) * ALIGN;
8153    // Bind only the byte span the op needs (not the full 4 GiB cap) so we
8154    // don't slide the window to the arena tail and drop low-offset tensors.
8155    let mut size = span.div_ceil(ALIGN) * ALIGN;
8156    size = size.max(256).min(max_binding);
8157    if base.saturating_add(size) > arena.size as u64 {
8158        base = (arena.size as u64).saturating_sub(size);
8159        base = (base / ALIGN) * ALIGN;
8160    }
8161    if base > lo || base.saturating_add(size) < hi {
8162        base = (lo / ALIGN) * ALIGN;
8163        size = hi.saturating_sub(base).div_ceil(ALIGN) * ALIGN;
8164        size = size.max(256).min(max_binding);
8165        if base.saturating_add(size) > arena.size as u64 {
8166            base = hi.saturating_sub(size);
8167            base = (base / ALIGN) * ALIGN;
8168        }
8169    }
8170    (base, size)
8171}
8172
8173fn arena_local_off_f32(arena: &Arena, id: NodeId, base: u64) -> u32 {
8174    (((arena.offset(id) as u64).saturating_sub(base)) / 4) as u32
8175}
8176
8177fn arena_tensor_in_window(arena: &Arena, id: NodeId, base: u64, size: u64) -> bool {
8178    let src = arena.offset(id) as u64;
8179    let len = arena.len_of(id) as u64;
8180    src >= base && src.saturating_add(len) <= base.saturating_add(size)
8181}
8182
8183/// True when two planned arena slots share any byte (memory planner reuse).
8184fn arena_tensors_overlap(arena: &Arena, a: NodeId, b: NodeId) -> bool {
8185    if a == b {
8186        return true;
8187    }
8188    let (a0, al) = (arena.offset(a) as u64, arena.len_of(a) as u64);
8189    let (b0, bl) = (arena.offset(b) as u64, arena.len_of(b) as u64);
8190    if al == 0 || bl == 0 {
8191        return false;
8192    }
8193    let a1 = a0.saturating_add(al);
8194    let b1 = b0.saturating_add(bl);
8195    a0 < b1 && b0 < a1
8196}
8197
8198/// Arena bind window for matmul: when the weight alone fits the bind limit but
8199/// activations + weight do not, anchor on the param tensor (e.g. tied `LmHead`).
8200fn arena_matmul_bind_window(
8201    device: &wgpu::Device,
8202    arena: &Arena,
8203    graph: &Graph,
8204    param_offsets: &HashMap<String, NodeId>,
8205    out_id: NodeId,
8206    a_id: NodeId,
8207    b_id: NodeId,
8208) -> (u64, u64, bool) {
8209    let max_binding = device.limits().max_storage_buffer_binding_size;
8210    if let Some((base, size)) = arena_whole_arena_bind(arena, max_binding) {
8211        return (base, size, false);
8212    }
8213    let ids = [out_id, a_id, b_id];
8214    let all_fits = arena_span_bytes(arena, &ids) <= max_binding;
8215    let b_bytes = arena.len_of(b_id) as u64;
8216    let b_is_param = tensor_is_graph_param(graph, param_offsets, b_id);
8217    let param_anchor =
8218        b_is_param && b_bytes <= max_binding && (!all_fits || b_bytes > ARENA_STAGE_CAP);
8219    let (mut base, mut size) = if param_anchor {
8220        arena_window_for_nodes(device, arena, &[b_id])
8221    } else if all_fits {
8222        arena_window_for_nodes(device, arena, &ids)
8223    } else {
8224        arena_window_for_nodes(device, arena, &[out_id])
8225    };
8226    let param_anchor = param_anchor
8227        || (b_is_param
8228            && b_bytes <= max_binding
8229            && !arena_tensor_in_window(arena, b_id, base, size));
8230    if param_anchor && !arena_tensor_in_window(arena, b_id, base, size) {
8231        (base, size) = arena_window_for_nodes(device, arena, &[b_id]);
8232    }
8233    (base, size, param_anchor)
8234}
8235
8236/// Grow `[base, base+size)` to cover all listed tensors when the span still
8237/// fits `max_storage_buffer_binding_size` (avoids spurious staging copies).
8238fn arena_expand_bind_window(
8239    arena: &Arena,
8240    ids: &[NodeId],
8241    base: &mut u64,
8242    size: &mut u64,
8243    max_binding: u64,
8244) {
8245    const ALIGN: u64 = 256;
8246    let mut lo = *base;
8247    let mut hi = base.saturating_add(*size);
8248    for &id in ids {
8249        let off = arena.offset(id) as u64;
8250        let len = arena.len_of(id) as u64;
8251        lo = lo.min(off);
8252        hi = hi.max(off.saturating_add(len));
8253    }
8254    let span = hi.saturating_sub(lo).max(1);
8255    if span > max_binding {
8256        return;
8257    }
8258    *base = (lo / ALIGN) * ALIGN;
8259    *size = span.div_ceil(ALIGN) * ALIGN;
8260    *size = (*size).max(256).min(max_binding);
8261    if (*base).saturating_add(*size) > arena.size as u64 {
8262        *base = (arena.size as u64).saturating_sub(*size);
8263        *base = (*base / ALIGN) * ALIGN;
8264    }
8265}
8266
8267fn arena_off_in_bind_window(
8268    graph: &Graph,
8269    param_offsets: &HashMap<String, NodeId>,
8270    device: &wgpu::Device,
8271    arena: &Arena,
8272    schedule: &mut Vec<Step>,
8273    scratch: &mut u64,
8274    id: NodeId,
8275    base: &mut u64,
8276    size: &mut u64,
8277) -> u32 {
8278    let max_binding = device.limits().max_storage_buffer_binding_size;
8279    if let Some((b, s)) = arena_whole_arena_bind(arena, max_binding) {
8280        *base = b;
8281        *size = s;
8282        return arena_local_off_f32(arena, id, b);
8283    }
8284    if arena_tensor_in_window(arena, id, *base, *size) {
8285        arena_local_off_f32(arena, id, *base)
8286    } else {
8287        let len = arena.len_of(id) as u64;
8288        if tensor_is_graph_param(graph, param_offsets, id) && len > max_binding {
8289            panic!(
8290                "rlx-wgpu: param node {:?} ({} bytes) exceeds max_storage_buffer_binding_size \
8291                 ({max_binding}); split weights or use f16 shadow binds",
8292                id, len
8293            );
8294        }
8295        if len > ARENA_STAGE_CAP {
8296            let op = &graph.node(id).op;
8297            panic!(
8298                "rlx-wgpu: bind_window would stage {} bytes for {:?} op={op:?} \
8299                 (off={}, base={}, bind_size={})",
8300                len,
8301                id,
8302                arena.offset(id),
8303                *base,
8304                *size,
8305            );
8306        }
8307        arena_off_in_window_or_stage(arena, schedule, scratch, base, size, max_binding, id)
8308    }
8309}
8310
8311/// Bind window for ops that read/write multiple arena tensors (conv, concat, …).
8312/// Returns `(base, size)` and rebased f32 offsets; stages operands that fall outside
8313/// the window when the full span exceeds `max_storage_buffer_binding_size`.
8314fn arena_multi_op_window(
8315    dev: &wgpu::Device,
8316    arena: &Arena,
8317    graph: &Graph,
8318    param_offsets: &HashMap<String, NodeId>,
8319    _schedule: &mut Vec<Step>,
8320    scratch: &mut u64,
8321    ids: &[NodeId],
8322) -> (u64, u64, bool) {
8323    let max_binding = dev.limits().max_storage_buffer_binding_size;
8324    if let Some((base, size)) = arena_whole_arena_bind(arena, max_binding) {
8325        *scratch = arena.scratch_off as u64;
8326        return (base, size, false);
8327    }
8328    let param_anchor = if arena_span_bytes(arena, ids) > max_binding {
8329        ids.iter()
8330            .find(|&&id| {
8331                let nbytes = arena.len_of(id) as u64;
8332                tensor_is_graph_param(graph, param_offsets, id) && nbytes <= max_binding
8333            })
8334            .copied()
8335    } else {
8336        None
8337    };
8338    let mut param_anchored = param_anchor.is_some();
8339    let (mut base, mut size) = if arena_span_bytes(arena, ids) <= max_binding {
8340        arena_window_for_nodes(dev, arena, ids)
8341    } else if let Some(id) = param_anchor {
8342        arena_window_for_nodes(dev, arena, &[id])
8343    } else {
8344        arena_window_for_nodes(dev, arena, &[ids[0]])
8345    };
8346    if let Some(id) = param_anchor {
8347        if !arena_tensor_in_window(arena, id, base, size) {
8348            (base, size) = arena_window_for_nodes(dev, arena, &[id]);
8349        }
8350        param_anchored = true;
8351    } else {
8352        for &id in ids {
8353            let nbytes = arena.len_of(id) as u64;
8354            if tensor_is_graph_param(graph, param_offsets, id)
8355                && nbytes <= max_binding
8356                && !arena_tensor_in_window(arena, id, base, size)
8357            {
8358                (base, size) = arena_window_for_nodes(dev, arena, &[id]);
8359                param_anchored = true;
8360                break;
8361            }
8362        }
8363    }
8364    *scratch = arena.scratch_off as u64;
8365    if param_anchored {
8366        arena_ensure_scratch_in_window(scratch, base, size);
8367    }
8368    (base, size, param_anchored)
8369}
8370
8371fn arena_bind_window_covering_scratch_if_needed(
8372    arena: &Arena,
8373    base: u64,
8374    size: u64,
8375    scratch: u64,
8376) -> u64 {
8377    // Planner places scratch at the arena tail; do not relocate the bind
8378    // window until this op has actually started staging into scratch.
8379    if scratch <= arena.scratch_off as u64 {
8380        return base;
8381    }
8382    if scratch >= base && scratch.saturating_add(ARENA_STAGE_CAP) <= base.saturating_add(size) {
8383        return base;
8384    }
8385    arena_window_covering_scratch(arena, base, size)
8386}
8387
8388/// Keep staging writes inside `[base, base+size)` when the bind window is anchored on a
8389/// param far from the arena tail scratch zone.
8390fn arena_ensure_scratch_in_window(scratch: &mut u64, base: u64, size: u64) {
8391    let cap = ARENA_STAGE_CAP.min(size);
8392    let end = base.saturating_add(size);
8393    if *scratch < base || scratch.saturating_add(cap) > end {
8394        *scratch = end.saturating_sub(cap);
8395        *scratch = (*scratch / 256) * 256;
8396    }
8397}
8398
8399#[allow(dead_code)]
8400fn arena_off_for_window(
8401    arena: &Arena,
8402    schedule: &mut Vec<Step>,
8403    scratch: &mut u64,
8404    id: NodeId,
8405    _window_ids: &[NodeId],
8406    mut base: u64,
8407    mut size: u64,
8408    max_binding: u64,
8409    _fits_in_one_binding: bool,
8410) -> u32 {
8411    let src = arena.offset(id) as u64;
8412    let len = arena.len_of(id) as u64;
8413    if src >= base && src.saturating_add(len) <= base.saturating_add(size) {
8414        arena_local_off_f32(arena, id, base)
8415    } else {
8416        arena_off_in_window_or_stage(
8417            arena,
8418            schedule,
8419            scratch,
8420            &mut base,
8421            &mut size,
8422            max_binding,
8423            id,
8424        )
8425    }
8426}
8427
/// Compute the `(base, size)` byte window into the f16 shadow buffer that
/// corresponds to the f32 arena bind `[arena_base, arena_base + arena_size)`.
///
/// Both values are halved (f16 is half the width of f32). The base is rounded
/// down and the size rounded up to 256 bytes; the size is at least 256 and at
/// most `f16_buf_bytes`. If the window would run past the end of the shadow
/// buffer, the base is pulled back (and re-aligned down) so it ends inside it.
///
/// NOTE(review): because the base is rounded down independently of the size,
/// the window is only guaranteed to cover the halved range when `arena_base / 2`
/// is itself 256-aligned; confirm callers uphold that.
fn f16_shadow_bind_range(arena_base: u64, arena_size: u64, f16_buf_bytes: u64) -> (u64, u64) {
    const ALIGN: u64 = 256;
    let align_down = |v: u64| v - v % ALIGN;

    let size = ((arena_size / 2).div_ceil(ALIGN) * ALIGN)
        .max(ALIGN)
        .min(f16_buf_bytes);

    let wanted_base = align_down(arena_base / 2);
    let base = if wanted_base.saturating_add(size) > f16_buf_bytes {
        align_down(f16_buf_bytes.saturating_sub(size))
    } else {
        wanted_base
    };

    (base, size)
}
8440
8441/// Window into `f16_buffer` for matmul weight reads (`params.b_off` is in
8442/// f16-element indices, matching the f32 arena word index).
8443fn f16_weight_bind_range(
8444    dev: &wgpu::Device,
8445    f16_buf_bytes: u64,
8446    b_off: u32,
8447    k: u32,
8448    n: u32,
8449    batch: u32,
8450    b_batch_stride: u32,
8451) -> (u64, u64, u32) {
8452    const ALIGN: u64 = 256;
8453    let max_binding = dev.limits().max_storage_buffer_binding_size;
8454    let b0 = b_off as u64;
8455    let span = (k as u64).saturating_mul(n as u64);
8456    let batch_n = batch.max(1) as u64;
8457    let stride = if batch_n > 1 {
8458        b_batch_stride as u64
8459    } else {
8460        span
8461    };
8462    let hi_elems = b0
8463        .saturating_add((batch_n - 1).saturating_mul(stride))
8464        .saturating_add(span);
8465    let lo_byte = b0.saturating_mul(2);
8466    let hi_byte = hi_elems.saturating_mul(2).saturating_add(8);
8467    let need = hi_byte.saturating_sub(lo_byte).max(1);
8468    if need > max_binding {
8469        panic!(
8470            "rlx-wgpu: f16 weight region needs {need} bytes (> {max_binding}); \
8471             matmul k={k} n={n} batch={batch}"
8472        );
8473    }
8474    let mut base = (lo_byte / ALIGN) * ALIGN;
8475    let mut size = need.div_ceil(ALIGN) * ALIGN;
8476    size = size.max(256).min(max_binding).min(f16_buf_bytes);
8477    if base.saturating_add(size) < hi_byte {
8478        base = hi_byte.saturating_sub(size);
8479        base = (base / ALIGN) * ALIGN;
8480    }
8481    if base.saturating_add(size) > f16_buf_bytes {
8482        base = f16_buf_bytes.saturating_sub(size);
8483        base = (base / ALIGN) * ALIGN;
8484    }
8485    let rebased = b_off.saturating_sub((base / 2) as u32);
8486    (base, size, rebased)
8487}
8488
/// Staging cap in bytes (256 MiB). The staging helpers panic rather than copy a
/// tensor larger than this into arena scratch, and the window helpers use it as
/// the amount of scratch space they try to keep inside a bind window.
const ARENA_STAGE_CAP: u64 = 256 * 1024 * 1024;
8490
8491/// Return a window-local f32 offset, staging into scratch when the tensor lies
8492/// outside the bind window (via `copy_buffer_to_buffer`).
8493fn arena_off_in_window_or_stage(
8494    arena: &Arena,
8495    schedule: &mut Vec<Step>,
8496    scratch: &mut u64,
8497    base: &mut u64,
8498    size: &mut u64,
8499    max_binding: u64,
8500    id: NodeId,
8501) -> u32 {
8502    let src = arena.offset(id) as u64;
8503    let len = arena.len_of(id) as u64;
8504    if src >= *base && src.saturating_add(len) <= (*base).saturating_add(*size) {
8505        return arena_local_off_f32(arena, id, *base);
8506    }
8507    if len > ARENA_STAGE_CAP {
8508        panic!(
8509            "rlx-wgpu: cannot stage {} bytes for node {:?} (cap {ARENA_STAGE_CAP})",
8510            len, id
8511        );
8512    }
8513    let aligned = len.div_ceil(256) * 256;
8514    let dst = *scratch;
8515    *scratch = scratch.saturating_add(aligned);
8516    schedule.push(Step::BufferCopy {
8517        src_byte_off: src as u32,
8518        dst_byte_off: dst as u32,
8519        bytes: len as u32,
8520    });
8521    let lo = (*base).min(dst);
8522    let hi = (*base)
8523        .saturating_add(*size)
8524        .max(dst.saturating_add(aligned));
8525    let span = hi.saturating_sub(lo).max(1);
8526    if span <= max_binding {
8527        const ALIGN: u64 = 256;
8528        *base = (lo / ALIGN) * ALIGN;
8529        *size = span.div_ceil(ALIGN) * ALIGN;
8530        *size = (*size).max(256).min(max_binding);
8531        if (*base).saturating_add(*size) > arena.size as u64 {
8532            *base = (arena.size as u64).saturating_sub(*size);
8533            *base = (*base / ALIGN) * ALIGN;
8534        }
8535    }
8536    if arena_tensor_in_window(arena, id, *base, *size) {
8537        arena_local_off_f32(arena, id, *base)
8538    } else {
8539        ((dst.saturating_sub(*base)) / 4) as u32
8540    }
8541}
8542
8543/// If scratch does not fall inside `[base, base+size)`, slide the window to the tail.
8544fn arena_window_covering_scratch(arena: &Arena, base: u64, size: u64) -> u64 {
8545    let scratch = arena.scratch_off as u64;
8546    if scratch >= base && scratch.saturating_add(ARENA_STAGE_CAP) <= base.saturating_add(size) {
8547        return base;
8548    }
8549    let new_base = (arena.size as u64).saturating_sub(size);
8550    (new_base / 256) * 256
8551}
8552
8553fn arena_span_bytes(arena: &Arena, ids: &[NodeId]) -> u64 {
8554    let mut lo: u64 = u64::MAX;
8555    let mut hi: u64 = 0;
8556    for &id in ids {
8557        let off = arena.offset(id) as u64;
8558        let len = arena.len_of(id) as u64;
8559        lo = lo.min(off);
8560        hi = hi.max(off.saturating_add(len));
8561    }
8562    if lo == u64::MAX {
8563        0
8564    } else {
8565        hi.saturating_sub(lo)
8566    }
8567}
8568
8569#[allow(dead_code)]
8570fn bind_two(
8571    device: &wgpu::Device,
8572    kernel: &Kernel,
8573    buf0: &wgpu::Buffer,
8574    buf1: &wgpu::Buffer,
8575) -> wgpu::BindGroup {
8576    let max_binding = device.limits().max_storage_buffer_binding_size;
8577    if buf0.size() > max_binding {
8578        panic!(
8579            "rlx-wgpu: bind_two buffer {} bytes exceeds max_storage_buffer_binding_size {}; \
8580             use bind_two_buf0_window or bind_op_output_window",
8581            buf0.size(),
8582            max_binding
8583        );
8584    }
8585    device.create_bind_group(&wgpu::BindGroupDescriptor {
8586        label: Some("rlx-wgpu bg"),
8587        layout: &kernel.bgl,
8588        entries: &[
8589            wgpu::BindGroupEntry {
8590                binding: 0,
8591                resource: buf0.as_entire_binding(),
8592            },
8593            wgpu::BindGroupEntry {
8594                binding: 1,
8595                resource: buf1.as_entire_binding(),
8596            },
8597        ],
8598    })
8599}
8600
/// Windowed arena bind for a single output tensor.
///
/// Thin wrapper around `bind_op_window` with `out_id` as the only node: the
/// arena is bound at binding 0 through the window chosen for `out_id`, and
/// `params` is bound whole at binding 1. Operands are not considered here, so
/// callers that read other arena tensors must make sure those are addressable
/// through this window themselves.
fn bind_op_output_window(
    device: &wgpu::Device,
    kernel: &Kernel,
    arena: &Arena,
    out_id: NodeId,
    params: &wgpu::Buffer,
) -> wgpu::BindGroup {
    bind_op_window(device, kernel, arena, &[out_id], params)
}
8613
8614fn bind_op_window(
8615    device: &wgpu::Device,
8616    kernel: &Kernel,
8617    arena: &Arena,
8618    ids: &[NodeId],
8619    params: &wgpu::Buffer,
8620) -> wgpu::BindGroup {
8621    let max_binding = device.limits().max_storage_buffer_binding_size;
8622    let (base, size) = if arena_span_bytes(arena, ids) <= max_binding {
8623        arena_window_for_nodes(device, arena, ids)
8624    } else {
8625        arena_window_for_nodes(device, arena, &[ids[0]])
8626    };
8627    bind_two_buf0_window(device, kernel, &arena.buffer, base, size, params)
8628}
8629
8630fn bind_two_buf0_window(
8631    device: &wgpu::Device,
8632    kernel: &Kernel,
8633    buf0: &wgpu::Buffer,
8634    buf0_base: u64,
8635    buf0_size: u64,
8636    buf1: &wgpu::Buffer,
8637) -> wgpu::BindGroup {
8638    device.create_bind_group(&wgpu::BindGroupDescriptor {
8639        label: Some("rlx-wgpu bg window"),
8640        layout: &kernel.bgl,
8641        entries: &[
8642            wgpu::BindGroupEntry {
8643                binding: 0,
8644                resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
8645                    buffer: buf0,
8646                    offset: buf0_base,
8647                    size: NonZeroU64::new(buf0_size),
8648                }),
8649            },
8650            wgpu::BindGroupEntry {
8651                binding: 1,
8652                resource: buf1.as_entire_binding(),
8653            },
8654        ],
8655    })
8656}
8657
8658/// Compute precision selector: derive from IR dtypes of A and B and
8659/// the device features.
8660///
8661/// Priority:
8662///   1. Coop16 — if EXPERIMENTAL_COOPERATIVE_MATRIX + SHADER_F16 +
8663///      F16 IR tag + b traces to a Param + M/K/N are 32/8/32 aligned.
8664///      Unlocks Apple's `simdgroup_matrix` / Vulkan's KHR_cooperative
8665///      hardware GEMM units (~18× faster than f32 ALU on Apple M-series).
8666///   2. F32 — every other case, *including* when AutoMixedPrecision
8667///      tagged the matmul as F16 but it failed Coop16's alignment
8668///      check. The non-coop F16 path (`matmul_f16_compute.wgsl`) was
8669///      empirically measured 4-5× SLOWER than the f32 baseline on
8670///      Apple via wgpu/naga 29 — the WGSL→MSL emit doesn't unlock
8671///      Apple's f16 ALU through portable WGSL ALU. So at small /
8672///      unaligned shapes we lose nothing by ignoring the IR's f16
8673///      tag and using f32 — precision improves AND speed wins.
8674///
8675/// (The F16 variant of `MatmulCompute` and `matmul_f16_compute.wgsl`
8676/// remain for future use — e.g. when naga gains a portable subgroup-
8677/// matrix surface that lowers efficiently without needing the full
8678/// coop-matrix dance, or when bf16 hardware lands. Today no path
8679/// dispatches them.)
8680fn derive_matmul_compute(
8681    dev: &wgpu::Device,
8682    graph: &Graph,
8683    mirror_acts: &HashSet<NodeId>,
8684    a_id: NodeId,
8685    b_id: NodeId,
8686    m: u32,
8687    k: u32,
8688    n: u32,
8689) -> MatmulCompute {
8690    if rlx_ir::env::flag("RLX_WGPU_MATMUL_F32_ONLY") {
8691        return MatmulCompute::F32;
8692    }
8693    use rlx_ir::DType;
8694    let a_dt = graph.node(a_id).shape.dtype();
8695    let b_dt = graph.node(b_id).shape.dtype();
8696    let any_low =
8697        matches!(a_dt, DType::F16 | DType::BF16) || matches!(b_dt, DType::F16 | DType::BF16);
8698    // CoopF32 (`simdgroup_float8x8`) needs K and N aligned to 8 and 32
8699    // (one micro-tile per K-iter, one 32-col workgroup per N-tile).
8700    // M can be arbitrary — the kernel pads to the next multiple of 32
8701    // and bounds-checks the output writes so out-of-range rows stay
8702    // untouched. (The Coop16 / matmul_qkv paths still require m%32==0;
8703    // their kernels don't have the same bounds check.)
8704    //
8705    // Vulkan uses `matmul_coop_f32_portable` (8×8 tiles, coopLoadT) which
8706    // only requires k%8 and n%8.
8707    let coop16_aligned = m.is_multiple_of(32) && k.is_multiple_of(8) && n.is_multiple_of(32);
8708    let coop_f32_metal_aligned = k.is_multiple_of(8) && n.is_multiple_of(32);
8709    let coop_f32_portable_aligned = k.is_multiple_of(8) && n.is_multiple_of(8);
8710    let has_coop = dev
8711        .features()
8712        .contains(wgpu::Features::EXPERIMENTAL_COOPERATIVE_MATRIX);
8713    let backend = crate::device::wgpu_device().map(|d| d.backend);
8714    // Coop16 has an f16 accumulator (Naga 29 can't compile the mixed
8715    // f32-acc / f16-operand form). Sums of 3072 BERT-FFN activations
8716    // overflow f16, so we only enter on F16/BF16 IR tags — AutoMixed
8717    // users have already opted into the precision tradeoff.
8718    if any_low
8719        && has_coop
8720        && dev.features().contains(wgpu::Features::SHADER_F16)
8721        && traces_to_param(graph, b_id)
8722        && coop16_aligned
8723    {
8724        return MatmulCompute::Coop16;
8725    }
8726    if !any_low && coop_f16_vk_eligible(dev, m, k, n) {
8727        if traces_to_param(graph, b_id)
8728            && !mirror_acts.contains(&a_id)
8729            && !mirror_acts.contains(&b_id)
8730        {
8731            return MatmulCompute::CoopF16Vk;
8732        }
8733    }
8734    // CoopF32 (`simdgroup_float8x8` on Apple): the f32 hardware-GEMM
8735    // path. Used whenever cooperative-matrix is available, B is a
8736    // Param, and shapes align — gives ~5-10× speedup over the
8737    // tiled `matmul_wide` path with no precision loss vs the f32
8738    // baseline (BERT max|Δ| stays at 2.3e-3 vs CPU on Apple).
8739    //
8740    // CoopF32: Metal-only by default. Vulkan portable 8×8 is opt-in via
8741    // RLX_WGPU_FORCE_COOP_F32 (RTX lacks 8×8 f32 coop; output is unreliable).
8742    let disabled = rlx_ir::env::flag("RLX_WGPU_NO_COOP_F32");
8743    let forced = rlx_ir::env::flag("RLX_WGPU_FORCE_COOP_F32");
8744    let metal_coop = !disabled
8745        && has_coop
8746        && coop_f32_metal_aligned
8747        && traces_to_param(graph, b_id)
8748        && (forced || matches!(backend, Some(wgpu::Backend::Metal)));
8749    let vulkan_coop = !disabled
8750        && has_coop
8751        && coop_f32_portable_aligned
8752        && traces_to_param(graph, b_id)
8753        && crate::device::coop_discrete_backend()
8754        && crate::device::coop_f32_8x8_supported();
8755    if metal_coop
8756        || vulkan_coop
8757        || (forced
8758            && has_coop
8759            && traces_to_param(graph, b_id)
8760            && (coop_f32_metal_aligned || coop_f32_portable_aligned))
8761    {
8762        return MatmulCompute::CoopF32;
8763    }
8764    MatmulCompute::F32
8765}
8766
8767/// Detects the BERT-style fused-QKV-then-narrow-then-attention
8768/// pattern. When all three of an attention's Q/K/V inputs are
8769/// `Op::Narrow` of a single source tensor on the last axis with
8770/// sequential offsets `(0, H·D, 2·H·D)` and equal lengths `H·D`,
8771/// returns `Some((qkv_source_node, h_d))` — naming the source
8772/// tensor and per-slice width.
8773///
8774/// EMPIRICAL FINDING: the obvious "skip the narrow + read attention
8775/// directly from QKV with stride 3·H·D" optimization REGRESSED end-
8776/// to-end perf 7-15× on Apple M4 Pro. The narrow's apparent overhead
8777/// (~3 dispatches per attention block, ~150µs at small batch) is
8778/// dwarfed by the cost of strided attention reads — stepping by
8779/// 3·H·D = 4.6 KB between sequence positions defeats the hardware
8780/// prefetcher (prefetch distance maxes around 1-2 KB on M-series).
8781/// Cosine stayed 0.9999+ (output is correct, just slow).
8782///
8783/// Kept as a helper for future smarter fusions — e.g. a coop kernel
8784/// that reads Q/K/V cooperatively from QKV in a single pass over
8785/// the sequence dim, avoiding the random-access stride pattern.
8786#[allow(dead_code)]
8787fn detect_qkv_narrow_pattern(
8788    graph: &Graph,
8789    q_id: NodeId,
8790    k_id: NodeId,
8791    v_id: NodeId,
8792) -> Option<(NodeId, u32)> {
8793    let unwrap_narrow = |id: NodeId| -> Option<(NodeId, usize, usize, usize)> {
8794        let node = graph.node(id);
8795        match &node.op {
8796            Op::Narrow { axis, start, len } => Some((node.inputs[0], *axis, *start, *len)),
8797            _ => None,
8798        }
8799    };
8800    let (q_src, q_axis, q_start, q_len) = unwrap_narrow(q_id)?;
8801    let (k_src, k_axis, k_start, k_len) = unwrap_narrow(k_id)?;
8802    let (v_src, v_axis, v_start, v_len) = unwrap_narrow(v_id)?;
8803    // Same source tensor.
8804    if q_src != k_src || k_src != v_src {
8805        return None;
8806    }
8807    // Equal slice widths (= H · D).
8808    if q_len != k_len || k_len != v_len {
8809        return None;
8810    }
8811    // Sequential offsets 0, H·D, 2·H·D.
8812    if q_start != 0 || k_start != q_len || v_start != q_len * 2 {
8813        return None;
8814    }
8815    // All on the LAST axis of the source.
8816    let src_rank = graph.node(q_src).shape.dims().len();
8817    if q_axis + 1 != src_rank || k_axis + 1 != src_rank || v_axis + 1 != src_rank {
8818        return None;
8819    }
8820    Some((q_src, q_len as u32))
8821}
8822
8823/// Detects the (FusedMatMulBiasAct → Narrow×3) split-QKV pattern that
8824/// shows up at the start of every BERT-style attention block. Returns
8825/// a map `parent_fmb_id → (q_narrow_id, k_narrow_id, v_narrow_id)`
8826/// for every site where the pattern can be replaced by one
8827/// `Step::MatmulQkv` dispatch.
8828///
8829/// Pattern requirements:
8830///   - Parent is `Op::FusedMatMulBiasAct { activation: None }` with
8831///     output shape `[..., 3·head_width]`.
8832///   - The parent's *only* consumers are exactly 3 `Op::Narrow` nodes,
8833///     all on the last axis, with offsets `(0, head_width, 2·head_width)`
8834///     and equal `len = head_width`.
8835///
8836/// The win is purely structural: same FMA work, but the 3 narrow
8837/// dispatches (and their full-tensor read+write of the QKV intermediate)
8838/// disappear. Different from the reverted "skip narrow + read attention
8839/// strided" approach because reads from each Q/K/V buffer remain
8840/// sequential — the prefetcher stays happy.
8841/// Detects (`Op::Binary(Add) → Op::LayerNorm`) where the Add has more
8842/// than one consumer in the graph — the case `FuseResidualLN` declines
8843/// because its single-consumer guard would force materializing the sum.
8844///
8845/// Returns:
8846///   - `ln_to_tee`: `ln_id → (h, delta, gamma, beta, sum_id)` so the
8847///     wgpu LayerNorm lowering can emit `Step::FusedResidualLnTee`
8848///     using the existing arena slot for the sum (= the Add's slot).
8849///   - `skip_adds`: the set of Add `NodeId`s whose normal Step emission
8850///     should be suppressed; their output value is written by the tee
8851///     step instead.
8852fn detect_residual_ln_tee_pattern(
8853    graph: &Graph,
8854) -> (
8855    HashMap<NodeId, (NodeId, NodeId, NodeId, NodeId, NodeId)>,
8856    HashSet<NodeId>,
8857) {
8858    use rlx_ir::op::BinaryOp;
8859    // Consumer counts (output references count once each).
8860    let mut consumers: HashMap<NodeId, usize> = HashMap::new();
8861    for node in graph.nodes() {
8862        for &input in &node.inputs {
8863            *consumers.entry(input).or_insert(0) += 1;
8864        }
8865    }
8866    for &out in &graph.outputs {
8867        *consumers.entry(out).or_insert(0) += 1;
8868    }
8869
8870    let mut ln_to_tee = HashMap::new();
8871    let mut skip_adds = HashSet::new();
8872    for node in graph.nodes() {
8873        let Op::LayerNorm { axis: _, eps: _ } = &node.op else {
8874            continue;
8875        };
8876        if node.inputs.len() < 3 {
8877            continue;
8878        } // need [in, gamma, beta]
8879        let in_id = node.inputs[0];
8880        let in_node = graph.node(in_id);
8881        if !matches!(in_node.op, Op::Binary(BinaryOp::Add)) {
8882            continue;
8883        }
8884        // Only fire when Add has >= 2 consumers (otherwise `FuseResidualLN`
8885        // already collapses it into Op::FusedResidualLN upstream).
8886        if consumers.get(&in_id).copied().unwrap_or(0) < 2 {
8887            continue;
8888        }
8889        // Add must be plain — both operands shape-equal to LN's input
8890        // and to each other.
8891        if in_node.inputs.len() != 2 {
8892            continue;
8893        }
8894        let h_id = in_node.inputs[0];
8895        let delta_id = in_node.inputs[1];
8896        if graph.node(h_id).shape.dims() != node.shape.dims() {
8897            continue;
8898        }
8899        if graph.node(delta_id).shape.dims() != node.shape.dims() {
8900            continue;
8901        }
8902        let gamma_id = node.inputs[1];
8903        let beta_id = node.inputs[2];
8904        ln_to_tee.insert(node.id, (h_id, delta_id, gamma_id, beta_id, in_id));
8905        skip_adds.insert(in_id);
8906    }
8907    (ln_to_tee, skip_adds)
8908}
8909
8910fn detect_split_qkv_pattern(graph: &Graph) -> HashMap<NodeId, (NodeId, NodeId, NodeId)> {
8911    // consumers[parent] = list of node ids that read parent
8912    let mut consumers: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
8913    for node in graph.nodes() {
8914        for &input in &node.inputs {
8915            consumers.entry(input).or_default().push(node.id);
8916        }
8917    }
8918    // Output nodes also count as consumers — would prevent QKV elision
8919    // if the matmul output is ever read externally.
8920    for &out_id in &graph.outputs {
8921        consumers.entry(out_id).or_default().push(NodeId(u32::MAX));
8922    }
8923
8924    let mut result = HashMap::new();
8925    for node in graph.nodes() {
8926        if !matches!(node.op, Op::FusedMatMulBiasAct { activation: None }) {
8927            continue;
8928        }
8929        let cs = match consumers.get(&node.id) {
8930            Some(c) if c.len() == 3 => c,
8931            _ => continue,
8932        };
8933        let dims = node.shape.dims();
8934        if dims.is_empty() {
8935            continue;
8936        }
8937        let last_axis = dims.len() - 1;
8938        let n = dims[last_axis].unwrap_static();
8939        if n % 3 != 0 {
8940            continue;
8941        }
8942        let head_width = n / 3;
8943
8944        // Each consumer must be a Narrow on the last axis, len = head_width.
8945        let mut narrows: Vec<(usize, NodeId)> = Vec::with_capacity(3);
8946        let mut all_match = true;
8947        for &c in cs {
8948            let cn = graph.node(c);
8949            match cn.op {
8950                Op::Narrow { axis, start, len }
8951                    if axis == last_axis && len == head_width && cn.inputs[0] == node.id =>
8952                {
8953                    narrows.push((start, c));
8954                }
8955                _ => {
8956                    all_match = false;
8957                    break;
8958                }
8959            }
8960        }
8961        if !all_match {
8962            continue;
8963        }
8964        narrows.sort_by_key(|&(start, _)| start);
8965        if narrows[0].0 != 0 || narrows[1].0 != head_width || narrows[2].0 != 2 * head_width {
8966            continue;
8967        }
8968        result.insert(node.id, (narrows[0].1, narrows[1].1, narrows[2].1));
8969    }
8970    result
8971}
8972
8973/// Walk through Cast/Reshape nodes (which alias the underlying arena
8974/// slot, per `plan_f32_uniform`) to find whether `id` ultimately
8975/// refers to an `Op::Param`. AutoMixedPrecision wraps params in
8976/// Cast(F32→F16) nodes, so a literal `matches!(node.op, Op::Param)`
8977/// check on the matmul's `b_id` would miss the Cast(Param) case.
8978fn node_is_arena_param(param_offsets: &HashMap<String, NodeId>, id: NodeId) -> bool {
8979    param_offsets.values().any(|&nid| nid == id)
8980}
8981
8982fn traces_to_param(graph: &Graph, mut id: NodeId) -> bool {
8983    loop {
8984        let node = graph.node(id);
8985        match &node.op {
8986            Op::Param { .. } => return true,
8987            Op::Cast { .. } | Op::Reshape { .. } | Op::Transpose { .. } => {
8988                if node.inputs.is_empty() {
8989                    return false;
8990                }
8991                id = node.inputs[0];
8992            }
8993            _ => return false,
8994        }
8995    }
8996}
8997
/// Whether `id` should be treated as a graph parameter: either it is registered
/// directly in `param_offsets`, or `traces_to_param` reaches an `Op::Param`
/// from it.
fn tensor_is_graph_param(
    graph: &Graph,
    param_offsets: &HashMap<String, NodeId>,
    id: NodeId,
) -> bool {
    node_is_arena_param(param_offsets, id) || traces_to_param(graph, id)
}
9005
9006fn traces_to_input(graph: &Graph, mut id: NodeId) -> bool {
9007    loop {
9008        let node = graph.node(id);
9009        match &node.op {
9010            Op::Input { .. } => return true,
9011            Op::Cast { .. } | Op::Reshape { .. } => {
9012                if node.inputs.is_empty() {
9013                    return false;
9014                }
9015                id = node.inputs[0];
9016            }
9017            _ => return false,
9018        }
9019    }
9020}
9021
9022/// Mirror A/B into the f16 shadow buffer before CoopF16Vk when the operand
9023/// is not already mirrored (Inputs/Params are written via `write_f32`).
9024fn schedule_uses_coop_f16_vk(schedule: &[Step]) -> bool {
9025    schedule.iter().any(|s| {
9026        matches!(
9027            s,
9028            Step::Matmul {
9029                compute_precision: MatmulCompute::CoopF16Vk,
9030                ..
9031            } | Step::MatmulQkv {
9032                kind: MatmulQkvKind::CoopF16Vk,
9033                ..
9034            }
9035        )
9036    })
9037}
9038
9039fn register_coop_f16_vk_b_param(
9040    map: &mut HashMap<u32, String>,
9041    param_offsets: &HashMap<String, NodeId>,
9042    b_id: NodeId,
9043    b_off_f32: u32,
9044    compute: MatmulCompute,
9045) {
9046    if compute != MatmulCompute::CoopF16Vk {
9047        return;
9048    }
9049    for (name, &id) in param_offsets {
9050        if id == b_id {
9051            map.insert(b_off_f32, name.clone());
9052            return;
9053        }
9054    }
9055}
9056
9057fn tensor_host_name(
9058    input_offsets: &HashMap<String, NodeId>,
9059    param_offsets: &HashMap<String, NodeId>,
9060    id: NodeId,
9061) -> String {
9062    for (name, &nid) in input_offsets {
9063        if nid == id {
9064            return name.clone();
9065        }
9066    }
9067    for (name, &nid) in param_offsets {
9068        if nid == id {
9069            return name.clone();
9070        }
9071    }
9072    panic!("rlx-wgpu: CoopF16Vk host activation source {id} is not an input or param");
9073}
9074
/// Find the host-side f32 data for `name`.
///
/// The first entry of `inputs` with a matching name wins; if there is none, the
/// value stored under `name` in `stashed_params` is used. Returns `None` when
/// the name is found in neither place.
fn host_tensor_f32<'a>(
    name: &str,
    inputs: &'a [(&str, &[f32])],
    stashed_params: &'a HashMap<String, Vec<f32>>,
) -> Option<&'a [f32]> {
    for &(input_name, data) in inputs {
        if input_name == name {
            return Some(data);
        }
    }
    stashed_params.get(name).map(|stored| stored.as_slice())
}
9086
9087fn apply_activation_host(act: Activation, data: &[f32]) -> Vec<f32> {
9088    data.iter()
9089        .map(|&x| match act {
9090            Activation::Relu => x.max(0.0),
9091            Activation::Sigmoid => 1.0 / (1.0 + (-x).exp()),
9092            Activation::Tanh => x.tanh(),
9093            Activation::Exp => x.exp(),
9094            Activation::Log => x.ln(),
9095            Activation::Sqrt => x.sqrt(),
9096            Activation::Rsqrt => 1.0 / x.sqrt(),
9097            Activation::Neg => -x,
9098            Activation::Abs => x.abs(),
9099            Activation::Gelu | Activation::GeluApprox => {
9100                let c = 0.797_884_6_f32;
9101                let x3 = x * x * x;
9102                let inner = (c * (x + 0.044_715 * x3)).clamp(-15.0, 15.0);
9103                0.5 * x * (1.0 + inner.tanh())
9104            }
9105            Activation::Silu => {
9106                let nx = (-x).clamp(-88.0, 88.0);
9107                x / (1.0 + nx.exp())
9108            }
9109            Activation::Round => x.round(),
9110            Activation::Sin => x.sin(),
9111            Activation::Cos => x.cos(),
9112            Activation::Tan => x.tan(),
9113            Activation::Atan => x.atan(),
9114        })
9115        .collect()
9116}
9117
9118/// Activation node ids consumed as CoopF16Vk matmul A/B operands.
9119fn collect_coop_f16_vk_mirror_activations(graph: &Graph, dev: &wgpu::Device) -> HashSet<NodeId> {
9120    let mut acts = HashSet::new();
9121    for node in graph.nodes() {
9122        if !matches!(node.op, Op::MatMul) {
9123            continue;
9124        }
9125        let a_id = node.inputs[0];
9126        let b_id = node.inputs[1];
9127        let a_shape = graph.node(a_id).shape.dims();
9128        let b_shape = graph.node(b_id).shape.dims();
9129        if a_shape.len() != 2 || b_shape.len() != 2 {
9130            continue;
9131        }
9132        let m = a_shape[0].unwrap_static() as u32;
9133        let k = a_shape[1].unwrap_static() as u32;
9134        let n = b_shape[1].unwrap_static() as u32;
9135        if !coop_f16_vk_eligible(dev, m, k, n) || !traces_to_param(graph, b_id) {
9136            continue;
9137        }
9138        if matches!(graph.node(a_id).op, Op::Activation(_)) {
9139            acts.insert(a_id);
9140        }
9141        if matches!(graph.node(b_id).op, Op::Activation(_)) {
9142            acts.insert(b_id);
9143        }
9144    }
9145    acts
9146}
9147
9148/// When A/B are computed (not Input/Param), mirror f32 arena into f16 shadow
9149/// via `cast_f32_to_f16` before CoopF16Vk matmul (non-activation intermediates).
9150fn maybe_push_coop_f16_vk_casts(
9151    graph: &Graph,
9152    a_id: NodeId,
9153    b_id: NodeId,
9154    mirror_acts: &HashSet<NodeId>,
9155    device: &wgpu::Device,
9156    arena: &Arena,
9157    schedule: &mut Vec<Step>,
9158    uniforms: &mut Vec<wgpu::Buffer>,
9159    bind_groups: &mut Vec<wgpu::BindGroup>,
9160    mm_cast: &Option<&'static Kernel>,
9161    compute_precision: MatmulCompute,
9162    a_off_f32: u32,
9163    m: u32,
9164    k: u32,
9165    batch: u32,
9166    b_off_f32: u32,
9167    n: u32,
9168) {
9169    if compute_precision != MatmulCompute::CoopF16Vk {
9170        return;
9171    }
9172    let batch_n = batch.max(1);
9173    if !traces_to_input(graph, a_id)
9174        && !traces_to_param(graph, a_id)
9175        && !mirror_acts.contains(&a_id)
9176    {
9177        let a_elems = m.saturating_mul(k).saturating_mul(batch_n);
9178        let (base, size) = arena_window_for_nodes(device, arena, &[a_id]);
9179        push_cast_f32_to_f16_step(
9180            device,
9181            arena,
9182            base,
9183            size,
9184            schedule,
9185            uniforms,
9186            bind_groups,
9187            mm_cast,
9188            a_off_f32,
9189            a_elems,
9190        );
9191    }
9192    if !traces_to_input(graph, b_id)
9193        && !traces_to_param(graph, b_id)
9194        && !mirror_acts.contains(&b_id)
9195    {
9196        let b_elems = k.saturating_mul(n).saturating_mul(batch_n);
9197        let (base, size) = arena_window_for_nodes(device, arena, &[b_id]);
9198        push_cast_f32_to_f16_step(
9199            device,
9200            arena,
9201            base,
9202            size,
9203            schedule,
9204            uniforms,
9205            bind_groups,
9206            mm_cast,
9207            b_off_f32,
9208            b_elems,
9209        );
9210    }
9211}
9212
9213fn build_matmul_qkv_coop_f16_vk_bind_group(
9214    device: &wgpu::Device,
9215    mqk: &Kernel,
9216    arena: &Arena,
9217    arena_base: u64,
9218    arena_size: u64,
9219    params: &wgpu::Buffer,
9220    k: u32,
9221    n: u32,
9222    b_off: u32,
9223) -> (wgpu::BindGroup, u32) {
9224    let f16_buf = arena
9225        .f16_buffer
9226        .as_ref()
9227        .expect("CoopF16Vk QKV requires SHADER_F16 f16 shadow arena");
9228    let (f16_res, rebased_b) = {
9229        let (base, size, rebased) =
9230            f16_weight_bind_range(device, f16_buf.size(), b_off, k, n, 1, 0);
9231        (
9232            wgpu::BindingResource::Buffer(wgpu::BufferBinding {
9233                buffer: f16_buf,
9234                offset: base,
9235                size: NonZeroU64::new(size),
9236            }),
9237            rebased,
9238        )
9239    };
9240    (
9241        device.create_bind_group(&wgpu::BindGroupDescriptor {
9242            label: Some("rlx-wgpu matmul_qkv_coop_f16_vk bg"),
9243            layout: &mqk.bgl,
9244            entries: &[
9245                wgpu::BindGroupEntry {
9246                    binding: 0,
9247                    resource: f16_res,
9248                },
9249                wgpu::BindGroupEntry {
9250                    binding: 1,
9251                    resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
9252                        buffer: &arena.buffer,
9253                        offset: arena_base,
9254                        size: NonZeroU64::new(arena_size),
9255                    }),
9256                },
9257                wgpu::BindGroupEntry {
9258                    binding: 2,
9259                    resource: params.as_entire_binding(),
9260                },
9261            ],
9262        }),
9263        rebased_b,
9264    )
9265}
9266/// Append a CastF32ToF16 pre-pass: mirrors `arena[off..off+len]` (f32) into
9267/// `arena_f16[off..off+len]` (f16) so coop matmul kernels can read operands
9268/// as f16. Used before CoopF16Vk when A/B are computed activations.
9269fn push_cast_f32_to_f16_step(
9270    device: &wgpu::Device,
9271    arena: &Arena,
9272    arena_base: u64,
9273    arena_size: u64,
9274    schedule: &mut Vec<Step>,
9275    uniforms: &mut Vec<wgpu::Buffer>,
9276    bind_groups: &mut Vec<wgpu::BindGroup>,
9277    mm_cast: &Option<&'static Kernel>,
9278    src_off: u32,
9279    len: u32,
9280) {
9281    let kernel = match mm_cast {
9282        Some(k) => *k,
9283        None => return, // device lacks SHADER_F16; fall through, dispatch will skip
9284    };
9285    let f16_buf = match &arena.f16_buffer {
9286        Some(b) => b,
9287        None => return,
9288    };
9289    let p = CastF32ToF16Params {
9290        src_off: src_off.saturating_sub((arena_base / 4) as u32),
9291        len,
9292        _p0: 0,
9293        _p1: 0,
9294    };
9295    let u = device.create_buffer(&wgpu::BufferDescriptor {
9296        label: Some("rlx-wgpu cast_f32_to_f16 uniform"),
9297        size: std::mem::size_of::<CastF32ToF16Params>() as u64,
9298        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
9299        mapped_at_creation: false,
9300    });
9301    // Write params at compile (kernel doesn't depend on active extent).
9302    let dev = wgpu_device().expect("rlx-wgpu: device gone");
9303    dev.queue.write_buffer(&u, 0, bytemuck::bytes_of(&p));
9304    let (f16_base, f16_size) = f16_shadow_bind_range(arena_base, arena_size, f16_buf.size());
9305    let bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
9306        label: Some("rlx-wgpu cast_f32_to_f16 bg"),
9307        layout: &kernel.bgl,
9308        entries: &[
9309            wgpu::BindGroupEntry {
9310                binding: 0,
9311                resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
9312                    buffer: f16_buf,
9313                    offset: f16_base,
9314                    size: NonZeroU64::new(f16_size),
9315                }),
9316            },
9317            wgpu::BindGroupEntry {
9318                binding: 1,
9319                resource: u.as_entire_binding(),
9320            },
9321            wgpu::BindGroupEntry {
9322                binding: 2,
9323                resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
9324                    buffer: &arena.buffer,
9325                    offset: arena_base,
9326                    size: NonZeroU64::new(arena_size),
9327                }),
9328            },
9329        ],
9330    });
9331    schedule.push(Step::CastF32ToF16 { params: p });
9332    uniforms.push(u);
9333    bind_groups.push(bg);
9334}
9335
9336/// Per-Matmul-step bind group builder. Returns `(bind_group, rebased_b_off)`;
9337/// `rebased_b_off` adjusts `MatmulParams.b_off` when the f16 weight buffer is
9338/// window-bound.
9339fn build_matmul_bind_group(
9340    device: &wgpu::Device,
9341    mm_k: &Kernel,
9342    _mm_w: &Kernel,
9343    mm_f16w: &Option<&'static Kernel>,
9344    mm_f16c: &Option<&'static Kernel>,
9345    mm_coop: &Option<&'static Kernel>,
9346    mm_coop_f32: &Option<&'static Kernel>,
9347    arena: &Arena,
9348    arena_base: u64,
9349    arena_size: u64,
9350    params: &wgpu::Buffer,
9351    b_is_param: bool,
9352    compute_precision: MatmulCompute,
9353    k: u32,
9354    n: u32,
9355    batch: u32,
9356    b_off: u32,
9357    b_batch_stride: u32,
9358) -> (wgpu::BindGroup, u32) {
9359    let f16_bind = |b_off: u32| -> (wgpu::BindingResource<'_>, u32) {
9360        let f16_buf = arena
9361            .f16_buffer
9362            .as_ref()
9363            .expect("f16 weight bind without f16_buffer");
9364        let (base, size, rebased) =
9365            f16_weight_bind_range(device, f16_buf.size(), b_off, k, n, batch, b_batch_stride);
9366        (
9367            wgpu::BindingResource::Buffer(wgpu::BufferBinding {
9368                buffer: f16_buf,
9369                offset: base,
9370                size: NonZeroU64::new(size),
9371            }),
9372            rebased,
9373        )
9374    };
9375    if compute_precision == MatmulCompute::CoopF16Vk
9376        && let (Some(coop_vk), Some(_f16_buf)) =
9377            (matmul_coop_f16_vulkan_kernel(device), &arena.f16_buffer)
9378    {
9379        let (f16_res, rebased_b) = f16_bind(b_off);
9380        return (
9381            device.create_bind_group(&wgpu::BindGroupDescriptor {
9382                label: Some("rlx-wgpu matmul_coop_f16_vulkan bg"),
9383                layout: &coop_vk.bgl,
9384                entries: &[
9385                    wgpu::BindGroupEntry {
9386                        binding: 0,
9387                        resource: f16_res,
9388                    },
9389                    wgpu::BindGroupEntry {
9390                        binding: 1,
9391                        resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
9392                            buffer: &arena.buffer,
9393                            offset: arena_base,
9394                            size: NonZeroU64::new(arena_size),
9395                        }),
9396                    },
9397                    wgpu::BindGroupEntry {
9398                        binding: 2,
9399                        resource: params.as_entire_binding(),
9400                    },
9401                ],
9402            }),
9403            rebased_b,
9404        );
9405    }
9406    if b_is_param
9407        && compute_precision == MatmulCompute::CoopF32
9408        && let Some(coop_f32) = mm_coop_f32
9409    {
9410        // 2-binding layout — both A and B come from the f32 arena
9411        // (no f16 shadow buffer needed for the pure-f32 path).
9412        return (
9413            device.create_bind_group(&wgpu::BindGroupDescriptor {
9414                label: Some("rlx-wgpu matmul_coop_f32 bg"),
9415                layout: &coop_f32.bgl,
9416                entries: &[
9417                    wgpu::BindGroupEntry {
9418                        binding: 0,
9419                        resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
9420                            buffer: &arena.buffer,
9421                            offset: arena_base,
9422                            size: NonZeroU64::new(arena_size),
9423                        }),
9424                    },
9425                    wgpu::BindGroupEntry {
9426                        binding: 1,
9427                        resource: params.as_entire_binding(),
9428                    },
9429                ],
9430            }),
9431            b_off,
9432        );
9433    }
9434    if b_is_param
9435        && compute_precision == MatmulCompute::Coop16
9436        && let (Some(_f16_buf), Some(coop)) = (&arena.f16_buffer, mm_coop)
9437    {
9438        let (f16_res, rebased_b) = f16_bind(b_off);
9439        // 3-binding layout — A is staged from arena (f32) through
9440        // workgroup-shared memory inside the kernel, no separate
9441        // f16 binding for A.
9442        return (
9443            device.create_bind_group(&wgpu::BindGroupDescriptor {
9444                label: Some("rlx-wgpu matmul_coop16 bg"),
9445                layout: &coop.bgl,
9446                entries: &[
9447                    wgpu::BindGroupEntry {
9448                        binding: 0,
9449                        resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
9450                            buffer: &arena.buffer,
9451                            offset: arena_base,
9452                            size: NonZeroU64::new(arena_size),
9453                        }),
9454                    },
9455                    wgpu::BindGroupEntry {
9456                        binding: 1,
9457                        resource: params.as_entire_binding(),
9458                    },
9459                    wgpu::BindGroupEntry {
9460                        binding: 2,
9461                        resource: f16_res,
9462                    }, // weights
9463                ],
9464            }),
9465            rebased_b,
9466        );
9467    }
9468    if b_is_param
9469        && compute_precision == MatmulCompute::F16
9470        && let (Some(_f16_buf), Some(f16c)) = (&arena.f16_buffer, mm_f16c)
9471    {
9472        let (f16_res, rebased_b) = f16_bind(b_off);
9473        return (
9474            device.create_bind_group(&wgpu::BindGroupDescriptor {
9475                label: Some("rlx-wgpu matmul_f16_compute bg"),
9476                layout: &f16c.bgl,
9477                entries: &[
9478                    wgpu::BindGroupEntry {
9479                        binding: 0,
9480                        resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
9481                            buffer: &arena.buffer,
9482                            offset: arena_base,
9483                            size: NonZeroU64::new(arena_size),
9484                        }),
9485                    },
9486                    wgpu::BindGroupEntry {
9487                        binding: 1,
9488                        resource: params.as_entire_binding(),
9489                    },
9490                    wgpu::BindGroupEntry {
9491                        binding: 2,
9492                        resource: f16_res,
9493                    },
9494                ],
9495            }),
9496            rebased_b,
9497        );
9498    }
9499    let f16w_opt_in = rlx_ir::env::flag("RLX_WGPU_F16_WEIGHTS");
9500    if b_is_param
9501        && f16w_opt_in
9502        && let (Some(_f16_buf), Some(f16w)) = (&arena.f16_buffer, mm_f16w)
9503    {
9504        let (f16_res, rebased_b) = f16_bind(b_off);
9505        return (
9506            device.create_bind_group(&wgpu::BindGroupDescriptor {
9507                label: Some("rlx-wgpu matmul_f16w bg"),
9508                layout: &f16w.bgl,
9509                entries: &[
9510                    wgpu::BindGroupEntry {
9511                        binding: 0,
9512                        resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
9513                            buffer: &arena.buffer,
9514                            offset: arena_base,
9515                            size: NonZeroU64::new(arena_size),
9516                        }),
9517                    },
9518                    wgpu::BindGroupEntry {
9519                        binding: 1,
9520                        resource: params.as_entire_binding(),
9521                    },
9522                    wgpu::BindGroupEntry {
9523                        binding: 2,
9524                        resource: f16_res,
9525                    },
9526                ],
9527            }),
9528            rebased_b,
9529        );
9530    }
9531    (
9532        bind_two_buf0_window(device, mm_k, &arena.buffer, arena_base, arena_size, params),
9533        b_off,
9534    )
9535}