Skip to main content

rlx_vulkan/
backend.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// SPDX-License-Identifier: GPL-3.0-only
5
6//! `VulkanExecutable` — compile an IR graph into a flat schedule of compute
7//! dispatches over a single f32 arena buffer, then execute it.
8//!
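//! Design (mirrors rlx-cuda / rlx-wgpu): every tensor is a slot in one arena
//! `VkBuffer` (4 bytes per element, except U8/I8 packed weights, which are
//! stored as raw bytes); each GPU schedule [`Step`] is one compute pipeline +
//! push constants + a workgroup count. A single descriptor set binds the whole
//! arena; per-op offsets/dims ride in push constants. Between dispatches we
//! insert a global shader-memory barrier (every kernel reads/writes the shared
//! arena). Each maximal run of GPU dispatches is submitted as one command
//! buffer (pre-recorded at compile time unless `RLX_VULKAN_NOCACHE=1`);
//! host-fallback ops run on the CPU against the mapped arena between those
//! submits, and outputs are read back from the host-visible mapping.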
16//!
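//! Op coverage is listed in [`SUPPORTED_OPS`]. Native kernels cover the
//! transformer-inference hot path: elementwise (binary / unary / compare /
//! where), matmul, last-axis reduce, softmax, RMS/Layer norm, RoPE, attention,
//! gather, cumsum, and the shape ops (narrow / concat / expand / transpose) via
//! one strided-copy kernel; the list also claims vision / indexing / MoE / SSM
//! ops (Conv, Pool, TopK, GroupedMatMul, SelectiveScan, …). Recurrent, FFT,
//! RNG / sampling and most GGUF-quant families have no SPIR-V kernel yet and
//! run as host-fallback steps (the CPU reference evaluated against the mapped
//! arena); `DequantMatMul` runs natively only for Q4_K/Q6_K decode GEMV.
//! Fused ops, DotGeneral, Fma, non-last-axis reduce, GroupNorm, etc. are
//! decomposed to these primitives by `legalize_or_rewrite_for_backend`.
//! Anything no rule can reduce to the supported set fails loudly at compile
//! time with a "pin to CPU" diagnostic — like rlx-wgpu's stance for ops it
//! can't lower.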
25
26use crate::buffer::Arena;
27use crate::device::vulkan_device;
28use crate::kernels::kernels;
29use ash::vk;
30use rlx_compile::memory::{BufferSlot, MemoryPlan};
31use rlx_ir::op::{Activation, BinaryOp, CmpOp, MaskKind, ReduceOp, RopeStyle};
32use rlx_ir::{DType, Graph, NodeId, Op, RngOptions};
33use std::collections::{HashMap, HashSet};
34
35/// OpKinds this backend lowers natively. Everything else is either decomposed
36/// into this set by the rewrite pass or rejected at legalize time.
37pub const SUPPORTED_OPS: &[rlx_ir::OpKind] = {
38    use rlx_ir::OpKind::*;
39    &[
40        Input,
41        Param,
42        Constant,
43        Cast,
44        StopGradient,
45        Reshape, // structural / alias
46        Binary,
47        Compare,
48        Where,
49        Activation, // elementwise
50        MatMul,
51        Reduce,
52        Softmax, // contraction / reduction
53        LayerNorm,
54        RmsNorm,
55        LayerNorm2d, // normalization
56        Rope,
57        Attention, // transformer
58        // Claimed so the block is a first-class op; `compile_rng` runs
59        // `unfuse_attention_block` to lower it to the chain above (matmul
60        // → narrow → rope → attention → matmul) before legalization.
61        FusedAttentionBlock,
62        Transpose,
63        Narrow,
64        Concat,
65        Expand,
66        Gather,
67        Cumsum,
68        Reverse, // shape / indexing
69        ArgMax,
70        ArgMin,
71        Pool,
72        ResizeNearest2x,
73        Conv,          // reductions / vision
74        GroupedMatMul, // MoE
75        SelectiveScan, // SSM / Mamba
76        Im2Col,
77        ScatterAdd,
78        TopK, // vision / indexing / generation
79        // Host-fallback (run on the CPU reference against the mapped arena —
80        // sequential / specialized families with no native SPIR-V kernel yet):
81        Lstm,
82        Gru,
83        Rnn,
84        Mamba2,
85        GatedDeltaNet,
86        ConvTranspose2d,
87        Fft,
88        DequantMatMul,
89        DequantGroupedMatMul,
90        DequantMoEWeights, // GGUF quant
91        RngNormal,
92        RngUniform,
93        Sample, // RNG / generation
94    ]
95};
96
97/// Ops with no native kernel that route to the CPU host-fallback path.
98///
99/// `DequantMatMul` is handled by its own scheduler arm: GGUF Q4_K/Q6_K decode
100/// GEMV (`m == 1`) runs natively via the `dequant_matmul` shader; every other
101/// scheme/shape still falls back to the CPU reference from that arm.
102fn is_host_fallback(op: &Op) -> bool {
103    matches!(
104        op,
105        Op::Lstm { .. }
106            | Op::Gru { .. }
107            | Op::Rnn { .. }
108            | Op::Mamba2 { .. }
109            | Op::GatedDeltaNet { .. }
110            | Op::ConvTranspose2d { .. }
111            | Op::Fft { .. }
112            | Op::DequantGroupedMatMul { .. }
113            | Op::DequantMoEWeights { .. }
114            | Op::RngNormal { .. }
115            | Op::RngUniform { .. }
116            | Op::Sample { .. }
117    )
118}
119
/// One scheduled step: either a GPU compute dispatch or a CPU host-fallback
/// op (for families with no native SPIR-V kernel yet).
#[derive(Clone)]
enum Step {
    /// A compute dispatch over the shared arena.
    Gpu {
        /// Kernel (pipeline) name; the `RLX_VULKAN_DEBUG` histogram in
        /// `build` counts dispatches by this name.
        kernel: &'static str,
        /// Raw push-constant bytes for this dispatch (presumably produced by
        /// `Push::bytes` — the builder call sites are outside this view).
        push: Vec<u8>,
        /// Workgroup counts `(x, y, z)`.
        groups: (u32, u32, u32),
    },
    /// An op evaluated on the CPU between GPU submits.
    Host {
        /// The op to evaluate.
        op: Op,
        /// Node whose arena slot receives the result.
        out: NodeId,
        /// Shape of the result.
        out_shape: rlx_ir::Shape,
        /// Nodes whose arena slots supply the operands, in operand order.
        inputs: Vec<NodeId>,
    },
}
136
/// A pre-recorded execution segment. The schedule is partitioned into maximal
/// runs of GPU dispatches (each recorded ONCE into a reusable command buffer at
/// compile time) separated by CPU host-fallback ops. At run time a GPU segment
/// is a single `queue_submit` of its prebuilt command buffer — no per-step
/// allocation, recording, or fence churn. See [`record_segments`].
enum Segment {
    /// A prebuilt command buffer covering a run of consecutive GPU dispatches.
    /// Submitted with `submit_recorded_wait` on the executable's reusable fence.
    Gpu(vk::CommandBuffer),
    /// A CPU host-fallback op, evaluated against the mapped arena between GPU
    /// segments (HOST_COHERENT memory, queue idle here — see `run_read_outputs`).
    Host {
        /// Op handed to `crate::host::eval`.
        op: Op,
        /// Node whose arena slot receives the f32 result.
        out: NodeId,
        /// Result shape handed to `crate::host::eval`.
        out_shape: rlx_ir::Shape,
        /// Operand nodes; U8/I8 operands are read from the arena as raw bytes,
        /// everything else as f32.
        inputs: Vec<NodeId>,
    },
}
154
/// A graph compiled for the Vulkan backend: one arena buffer, a flat dispatch
/// schedule, and (by default) that schedule pre-recorded into reusable
/// command buffers. Built by [`VulkanExecutable::compile`] / `compile_rng`.
pub struct VulkanExecutable {
    /// Post-legalize, f32-uniform graph (kept for `clone_for_cache`).
    graph: Graph,
    /// The single arena buffer holding every node's slot (laid out by
    /// `plan_f32_uniform`).
    arena: Arena,
    /// Flat step list from `build_schedule`; the uncached run path walks it
    /// directly.
    schedule: Vec<Step>,
    /// Pre-recorded segments (GPU command buffers + interleaved host ops). Built
    /// once from `schedule`; reused every `run`. Empty when caching is disabled
    /// (`RLX_VULKAN_NOCACHE=1`), in which case the legacy per-run record path
    /// drives `schedule` directly.
    segments: Vec<Segment>,
    /// Reusable fence for the cached submit path (reset after each wait).
    /// A null handle when caching is disabled.
    fence: vk::Fence,
    /// Whether the cached pre-recorded path is active.
    cached: bool,
    /// Graph input name → node id (target slot for host uploads).
    input_ids: HashMap<String, NodeId>,
    /// Graph param name → node id (target slot for `set_param*`).
    param_ids: HashMap<String, NodeId>,
    /// Graph outputs, in output order.
    output_ids: Vec<NodeId>,
    /// Dtype of each output, parallel to `output_ids`.
    output_dtypes: Vec<DType>,
    /// Descriptor pool that `desc_set` was allocated from.
    desc_pool: vk::DescriptorPool,
    /// The one descriptor set: the whole arena as a storage buffer at binding 0.
    desc_set: vk::DescriptorSet,
    /// RNG options given at compile time or via `set_rng`.
    rng: RngOptions,
    /// Optional `(actual_rows, upper)` extent. When `upper > 0`, handle-feed
    /// propagation copies only the first `actual_rows` rows of each fed output.
    active_extent: Option<(usize, usize)>,
    /// GPU-resident input handles (KV-cache style). Host mirror is kept only
    /// until the handle becomes resident (fed in-arena from an output), after
    /// which it is cleared — the value lives purely in the arena slot.
    gpu_handles: HashMap<String, Vec<f32>>,
    /// `handle_name → output index`: after each run, that output's arena slot
    /// is folded back into the handle's input slot (in-place, no host copy).
    gpu_handle_feeds: HashMap<String, usize>,
    /// Handles whose value is live in the arena (skip host re-upload).
    gpu_handle_resident: HashSet<String>,
    /// `handle_name → output index` for the *row* feed (decode graphs that emit
    /// the new K/V token at the LAST row of a bucket-padded output, e.g. llama32
    /// `concat(past_k, k_new)`). Driven explicitly via [`feed_kv_row`] after a
    /// logits-only run; kept separate from `gpu_handle_feeds` so the generic
    /// prefix propagation never fires for these.
    kv_row_feeds: HashMap<String, usize>,
}
193
// SAFETY: NOTE(review) — the justification is not stated in this file; confirm.
// `VulkanExecutable` is not auto-`Send`, presumably because `Arena` (accessed
// through a raw mapped pointer) or a Vulkan handle field is not `Send`. Moving
// the executable to another thread is assumed sound because the Vulkan objects
// and mapping it owns are only reached through this one value. No `Sync` impl
// is provided, so `&VulkanExecutable` cannot be shared across threads.
unsafe impl Send for VulkanExecutable {}
195
196// ── memory plan (f32-uniform bump allocator; same shape as rlx-cuda) ───────
197
198fn plan_f32_uniform(graph: &Graph, align: usize) -> MemoryPlan {
199    let mut assignments: HashMap<NodeId, BufferSlot> = HashMap::new();
200    let mut schedule = Vec::with_capacity(graph.nodes().len());
201    let mut cursor = 0usize;
202    for node in graph.nodes() {
203        if matches!(
204            node.op,
205            Op::Reshape { .. } | Op::Cast { .. } | Op::StopGradient
206        ) {
207            if let Some(in_id) = node.inputs.first()
208                && let Some(slot) = assignments.get(in_id)
209            {
210                let aliased = slot.clone();
211                assignments.insert(node.id, aliased);
212                schedule.push(node.id);
213                continue;
214            }
215        }
216        let elems = node.shape.num_elements().unwrap_or(0);
217        // f32-uniform arena: F32 (and widened F16/BF16/int) params occupy 4 bytes
218        // per element, but U8/I8 packed quant weights are stored as RAW BYTES
219        // (`set_param_bytes`) and read via byte addressing in the dequant kernel —
220        // sizing them `elems*4` like f32 inflated the arena ~4× (≈10 GB on
221        // Orpheus-3B Q4_K). Size by the real dtype. Slots stay `align`-aligned so
222        // every f32 word offset is still exact and the GEMV's word-relative
223        // weight addressing is unaffected.
224        let elem_size = match node.shape.dtype() {
225            DType::U8 | DType::I8 => 1,
226            _ => 4,
227        };
228        let bytes = (elems * elem_size).max(4);
229        let aligned = bytes.div_ceil(align) * align;
230        assignments.insert(
231            node.id,
232            BufferSlot {
233                offset: cursor,
234                size: aligned,
235            },
236        );
237        schedule.push(node.id);
238        cursor += aligned;
239    }
240    MemoryPlan {
241        arena_size: cursor.max(align),
242        assignments,
243        schedule,
244    }
245}
246
247// ── small shape helpers ────────────────────────────────────────────────────
248
249fn dims(graph: &Graph, id: NodeId) -> Vec<usize> {
250    graph
251        .node(id)
252        .shape
253        .dims()
254        .iter()
255        .map(|d| match d {
256            rlx_ir::Dim::Static(s) => *s,
257            _ => 0,
258        })
259        .collect()
260}
261
/// Number of elements described by the extents `d`.
///
/// A rank-0 (scalar) shape has one element — the empty product — and any
/// zero extent yields 0.
fn numel(d: &[usize]) -> usize {
    d.iter().product()
}
267
/// Row-major contiguous strides for `d`: the innermost axis has stride 1 and
/// every earlier axis strides over the product of all later extents.
fn contig_strides(d: &[usize]) -> Vec<usize> {
    let mut strides = vec![1usize; d.len()];
    let mut running = 1usize;
    // Right-to-left: stride[i] = d[i + 1] * ... * d[last]. Skipping the
    // innermost slot leaves it at 1, and `zip` stops before d[0] is consumed.
    for (stride, &extent) in strides.iter_mut().rev().skip(1).zip(d.iter().rev()) {
        running *= extent;
        *stride = running;
    }
    strides
}
276
/// Resolve a possibly-negative `axis` against `rank`.
///
/// Negative axes count from the end and clamp at 0; non-negative axes clamp
/// to the last axis (0 for a rank-0 shape).
fn norm_axis(axis: i32, rank: usize) -> usize {
    if axis >= 0 {
        let last = rank.saturating_sub(1);
        return (axis as usize).min(last);
    }
    let wrapped = rank as i32 + axis;
    wrapped.max(0) as usize
}
284
285// ── push-constant builder (std430, all 4-byte scalars / scalar arrays) ─────
286
/// Push-constant payload builder. Every field is a 4-byte scalar, so the
/// payload is a flat list of 32-bit words serialized little-endian.
#[derive(Default)]
struct Push {
    words: Vec<u32>,
}

impl Push {
    /// Append one `u32` word.
    fn u(self, v: u32) -> Self {
        self.us(&[v])
    }

    /// Append one `f32` as its raw IEEE-754 bit pattern.
    fn f(self, v: f32) -> Self {
        self.u(v.to_bits())
    }

    /// Append a run of `u32` words, in order.
    fn us(mut self, vs: &[u32]) -> Self {
        self.words.extend_from_slice(vs);
        self
    }

    /// Serialize to bytes: 4 little-endian bytes per word.
    fn bytes(self) -> Vec<u8> {
        self.words.iter().flat_map(|w| w.to_le_bytes()).collect()
    }
}
312
/// `ceil(n / d)`, computed in 64-bit and truncated to `u32` with `as`.
fn ceil_div(n: usize, d: u32) -> u32 {
    let numerator = n as u64;
    let denominator = u64::from(d);
    numerator.div_ceil(denominator) as u32
}
316
/// Whether `matmul_coop` may handle an M×K×N product.
///
/// The kernel writes a full 16×16 output tile per workgroup, so M and N must
/// be multiples of 16 — a partial output tile would store out of bounds. K is
/// unconstrained: the kernel zero-pads its final partial K-tile. Other shapes
/// fall back to the (fully general, fp32-exact) tiled kernel, which is the
/// better fit for them anyway.
fn coop_eligible(m: usize, _k: usize, n: usize) -> bool {
    const TILE: usize = 16;
    [m, n].iter().all(|&extent| extent.is_multiple_of(TILE))
}
325
326/// Which matmul kernel to dispatch:
327/// - default: `matmul_tiled` (shared-memory blocked **fp32**, exact) on native
328///   drivers; `matmul` (scalar) on portability drivers (MoltenVK), where
329///   tiling + barriers regress under Vulkan→Metal translation.
330/// - `RLX_VULKAN_MATMUL=coop`: `matmul_coop`, the tensor-core path (f16·f16→f32
331///   cooperative matrix). It is **opt-in** because f16 operands trade precision
332///   for throughput (not fp32-exact), so it is never auto-selected — that would
333///   silently degrade accuracy. Used only when the device advertises a usable
334///   config (`coop_matmul`) and M,N are 16-aligned (K is arbitrary); otherwise
335///   falls back to the exact tiled kernel (see `coop_eligible`).
336/// - `RLX_VULKAN_MATMUL=scalar|tiled`: force that fp32 kernel (A/B benching).
337fn matmul_kernel(m: usize, k: usize, n: usize) -> &'static str {
338    let dev = vulkan_device();
339    let portability = dev.map(|d| d.portability).unwrap_or(false);
340    let coop = dev.map(|d| d.coop_matmul).unwrap_or(false);
341    match std::env::var("RLX_VULKAN_MATMUL").ok().as_deref() {
342        Some("scalar") => "matmul",
343        Some("tiled") => "matmul_tiled",
344        Some("coop") if coop && coop_eligible(m, k, n) => "matmul_coop",
345        Some("coop") => "matmul_tiled",
346        _ if portability => "matmul",
347        _ => "matmul_tiled",
348    }
349}
350
351/// 1-D workgroup count for `n` items at `local` threads/group. Assumes the
352/// device's `maxComputeWorkGroupCount[0]` is large (true on desktop GPUs;
353/// the Vulkan minimum of 65535 caps ~16M elements/dispatch — a follow-up
354/// would switch to a grid-stride loop).
355fn groups1d(n: usize, local: u32) -> (u32, u32, u32) {
356    (ceil_div(n, local).max(1), 1, 1)
357}
358
359fn act_id(a: Activation) -> u32 {
360    match a {
361        Activation::Gelu => 0,
362        Activation::GeluApprox => 1,
363        Activation::Silu => 2,
364        Activation::Relu => 3,
365        Activation::Sigmoid => 4,
366        Activation::Tanh => 5,
367        Activation::Exp => 6,
368        Activation::Log => 7,
369        Activation::Sqrt => 8,
370        Activation::Rsqrt => 9,
371        Activation::Neg => 10,
372        Activation::Abs => 11,
373        Activation::Sin => 12,
374        Activation::Cos => 13,
375        Activation::Tan => 14,
376        Activation::Atan => 15,
377        Activation::Round => 16,
378    }
379}
380
381fn binop_id(op: BinaryOp) -> u32 {
382    match op {
383        BinaryOp::Add => 0,
384        BinaryOp::Sub => 1,
385        BinaryOp::Mul => 2,
386        BinaryOp::Div => 3,
387        BinaryOp::Max => 4,
388        BinaryOp::Min => 5,
389        BinaryOp::Pow => 6,
390    }
391}
392
393fn cmp_id(op: CmpOp) -> u32 {
394    match op {
395        CmpOp::Eq => 0,
396        CmpOp::Ne => 1,
397        CmpOp::Lt => 2,
398        CmpOp::Le => 3,
399        CmpOp::Gt => 4,
400        CmpOp::Ge => 5,
401    }
402}
403
404fn reduce_id(op: ReduceOp) -> u32 {
405    match op {
406        ReduceOp::Sum => 0,
407        ReduceOp::Mean => 1,
408        ReduceOp::Max => 2,
409        ReduceOp::Min => 3,
410        ReduceOp::Prod => 4,
411    }
412}
413
414impl VulkanExecutable {
415    pub fn compile(graph: Graph) -> Self {
416        Self::compile_rng(graph, RngOptions::default())
417    }
418
419    /// Prepare the graph (legalize → primitive set), plan the arena, and build
420    /// the dispatch schedule. Panics with a clear message if the graph
421    /// contains an op no decomposition rule can reduce to [`SUPPORTED_OPS`].
422    pub fn compile_rng(graph: Graph, rng: RngOptions) -> Self {
423        use rlx_opt::pass::Pass as _;
424
425        let graph = rlx_opt::LowerControlFlow.run(graph);
426        // `FusedAttentionBlock` is claimed (so it legalizes), but there is
427        // no monolithic fused-attention kernel — decompose it to primitives
428        // first. FAB-only (not the whole-graph unfuse) so nothing else is
429        // touched. No-op when no FAB node is present.
430        let graph = rlx_opt::unfuse::unfuse_attention_block(graph);
431        let graph = rlx_opt::legalize_or_rewrite_for_backend(graph, SUPPORTED_OPS)
432            .unwrap_or_else(|errs| panic!("{}", rlx_opt::format_legalize_error("vulkan", &errs)));
433        // Materialize mid-axis broadcasts so Binary operands are equal-shaped
434        // or trailing-broadcast (our kernels only do trailing modulus).
435        let graph = rlx_opt::LegalizeBroadcast.run(graph);
436
437        Self::build(graph, rng)
438    }
439
    /// Lower an already-legalized graph into a runnable executable.
    ///
    /// In order: plan the arena (16-byte slot alignment), upload constants,
    /// index inputs/params/outputs, build the dispatch schedule, create the
    /// descriptor set binding the whole arena, and — unless
    /// `RLX_VULKAN_NOCACHE=1` — pre-record the schedule into reusable command
    /// buffers. With `RLX_VULKAN_DEBUG` set, a schedule summary and a
    /// per-kernel dispatch histogram are printed to stderr.
    ///
    /// # Panics
    ///
    /// If no Vulkan device or kernel set is available, or if descriptor pool /
    /// descriptor set creation fails.
    fn build(graph: Graph, rng: RngOptions) -> Self {
        let dev = vulkan_device().expect("rlx-vulkan: no device");
        let kern = kernels().expect("rlx-vulkan: no kernels");

        // 16-byte slot alignment keeps every f32 word offset exact.
        let plan = plan_f32_uniform(&graph, 16);
        let arena = Arena::from_plan(&plan);

        // Upload constants (widened to f32 — the arena is f32-uniform).
        // NOTE(review): the planner sizes U8/I8 slots at one byte per element,
        // while this path writes widened f32 values; confirm U8/I8 constants
        // cannot reach here (or that `widen_const_to_f32` / `write_f32` handle it).
        for node in graph.nodes() {
            if let Op::Constant { data } = &node.op
                && arena.has(node.id)
                && !data.is_empty()
            {
                let f = widen_const_to_f32(data, node.shape.dtype());
                arena.write_f32(node.id, &f);
            }
        }

        // Name → node id maps used by the host upload APIs.
        let mut input_ids = HashMap::new();
        let mut param_ids = HashMap::new();
        for node in graph.nodes() {
            match &node.op {
                Op::Input { name } => {
                    input_ids.insert(name.clone(), node.id);
                }
                Op::Param { name } => {
                    param_ids.insert(name.clone(), node.id);
                }
                _ => {}
            }
        }

        let output_ids = graph.outputs.clone();
        let output_dtypes = output_ids
            .iter()
            .map(|&id| graph.node(id).shape.dtype())
            .collect();

        // `deps` is only forwarded to `record_segments` (presumably inter-step
        // dependency info used when recording barriers — confirm).
        let (schedule, deps) = build_schedule(&graph, &arena);

        // Descriptor set binding the whole arena to binding 0. The pool holds
        // exactly one storage-buffer descriptor in one set.
        let pool_sizes = [vk::DescriptorPoolSize::default()
            .ty(vk::DescriptorType::STORAGE_BUFFER)
            .descriptor_count(1)];
        let desc_pool = unsafe {
            dev.device.create_descriptor_pool(
                &vk::DescriptorPoolCreateInfo::default()
                    .max_sets(1)
                    .pool_sizes(&pool_sizes),
                None,
            )
        }
        .expect("vk descriptor_pool");
        let set_layouts = [kern.dsl];
        // One layout requested → one set returned, hence `[0]`.
        let desc_set = unsafe {
            dev.device.allocate_descriptor_sets(
                &vk::DescriptorSetAllocateInfo::default()
                    .descriptor_pool(desc_pool)
                    .set_layouts(&set_layouts),
            )
        }
        .expect("vk descriptor_set")[0];
        let buf_info = [vk::DescriptorBufferInfo::default()
            .buffer(arena.buffer)
            .offset(0)
            .range(vk::WHOLE_SIZE)];
        let write = vk::WriteDescriptorSet::default()
            .dst_set(desc_set)
            .dst_binding(0)
            .descriptor_type(vk::DescriptorType::STORAGE_BUFFER)
            .buffer_info(&buf_info);
        unsafe { dev.device.update_descriptor_sets(&[write], &[]) };

        // Pre-record the static schedule into reusable command buffers (one per
        // maximal GPU run). The whole schedule — kernels, push constants,
        // workgroup counts — is fixed at compile time; per-step inputs are
        // memcpy'd into the host-visible arena, never the command stream. So a
        // single recording is valid for every `run`, turning each step into one
        // `queue_submit` instead of allocate → record → fence → free.
        // `RLX_VULKAN_NOCACHE=1` disables this (empty segments, null fence).
        let cached = std::env::var("RLX_VULKAN_NOCACHE").as_deref() != Ok("1");
        let (segments, fence) = if cached {
            let segs = record_segments(dev, kern, desc_set, &schedule, &deps);
            (segs, dev.create_reusable_fence())
        } else {
            (Vec::new(), vk::Fence::null())
        };

        // Diagnostics: GPU vs host step counts, submits per run, and a
        // dispatch histogram sorted by descending count.
        if std::env::var_os("RLX_VULKAN_DEBUG").is_some() {
            let gpu = schedule
                .iter()
                .filter(|s| matches!(s, Step::Gpu { .. }))
                .count();
            let host = schedule.len() - gpu;
            let gpu_segs = segments
                .iter()
                .filter(|s| matches!(s, Segment::Gpu(_)))
                .count();
            let mut hist: HashMap<&'static str, usize> = HashMap::new();
            for s in &schedule {
                if let Step::Gpu { kernel, .. } = s {
                    *hist.entry(kernel).or_default() += 1;
                }
            }
            let mut by_count: Vec<_> = hist.into_iter().collect();
            by_count.sort_by_key(|&(_, c)| std::cmp::Reverse(c));
            eprintln!(
                "[rlx-vulkan] schedule: {gpu} gpu dispatches, {host} host ops; \
                 cached={cached} ({gpu_segs} gpu submit(s)/run)"
            );
            eprintln!("[rlx-vulkan] dispatch histogram: {by_count:?}");
        }

        Self {
            graph,
            arena,
            schedule,
            segments,
            fence,
            cached,
            input_ids,
            param_ids,
            output_ids,
            output_dtypes,
            desc_pool,
            desc_set,
            rng,
            active_extent: None,
            gpu_handles: HashMap::new(),
            gpu_handle_feeds: HashMap::new(),
            gpu_handle_resident: HashSet::new(),
            kv_row_feeds: HashMap::new(),
        }
    }
573
574    pub fn set_param(&mut self, name: &str, data: &[f32]) {
575        if let Some(&id) = self.param_ids.get(name) {
576            self.arena.write_f32(id, data);
577        }
578    }
579
580    /// Raw-byte param upload (packed weights). The arena is f32-uniform, so
581    /// callers should normally use [`set_param`]; this exists for symmetry.
582    pub fn set_param_bytes(&mut self, name: &str, data: &[u8]) {
583        if let Some(&id) = self.param_ids.get(name) {
584            self.arena.write_bytes(id, data);
585        }
586    }
587
588    pub fn output_dtypes(&self) -> Vec<DType> {
589        self.output_dtypes.clone()
590    }
591
592    pub fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
593        self.active_extent = extent;
594    }
595
596    /// Persistent input buffer for KV-cache style graphs. Writes `data` into the
597    /// input's arena slot once; subsequent decode steps reuse it (and, with a
598    /// feed wired, update it in place on-device). Returns false if `name` is not
599    /// a graph input. Mirrors the rlx-metal handle semantics.
600    pub fn bind_gpu_handle(&mut self, name: &str, data: &[f32]) -> bool {
601        let Some(&id) = self.input_ids.get(name) else {
602            return false;
603        };
604        // A fresh bind re-seeds from host, so it is no longer purely resident.
605        self.gpu_handle_resident.remove(name);
606        self.arena.write_f32(id, data);
607        // Keep a host mirror only until the first in-arena feed makes it resident.
608        self.gpu_handles.insert(name.to_string(), data.to_vec());
609        true
610    }
611
612    pub fn has_gpu_handle(&self, name: &str) -> bool {
613        self.gpu_handles.contains_key(name)
614    }
615
616    pub fn set_gpu_handle_feed(&mut self, handle_name: &str, output_index: usize) {
617        self.gpu_handle_feeds
618            .insert(handle_name.to_string(), output_index);
619    }
620
621    /// Register a *row* feed (vs the generic prefix feed): after a decode run,
622    /// row `src_row` of output `output_index` is folded into handle
623    /// `handle_name`'s input slot at row `dst_row`. For decode graphs that emit
624    /// the new K/V token at the last bucket-padded output row (llama32). Driven
625    /// explicitly via [`feed_kv_row`]; does NOT trigger the auto-propagation in
626    /// `run_read_outputs`.
627    pub fn register_kv_row_feed(&mut self, handle_name: &str, output_index: usize) {
628        self.kv_row_feeds
629            .insert(handle_name.to_string(), output_index);
630    }
631
632    /// Fold each registered row-feed's new-token row into its resident handle
633    /// slot, in-place on the arena (no host round-trip). Call after a
634    /// logits-only `run_read_outputs(.., Some(&[0]))`. `row_elems` is kv_dim.
635    pub fn feed_kv_row(&mut self, src_row: usize, dst_row: usize, row_elems: usize) {
636        let feeds: Vec<(String, usize)> = self
637            .kv_row_feeds
638            .iter()
639            .map(|(k, &v)| (k.clone(), v))
640            .collect();
641        for (name, out_idx) in feeds {
642            let Some(&out_id) = self.output_ids.get(out_idx) else {
643                continue;
644            };
645            let Some(&in_id) = self.input_ids.get(name.as_str()) else {
646                continue;
647            };
648            if in_id != out_id {
649                self.arena.copy_node_f32_range(
650                    in_id,
651                    dst_row * row_elems,
652                    out_id,
653                    src_row * row_elems,
654                    row_elems,
655                );
656            }
657            self.gpu_handle_resident.insert(name.clone());
658            self.gpu_handles.insert(name.clone(), Vec::new());
659        }
660    }
661
662    /// Read a handle back to host: from its fed output slot if wired, else the
663    /// resident arena slot, else the host mirror. Used on bucket change / sync.
664    pub fn read_gpu_handle(&self, name: &str) -> Option<Vec<f32>> {
665        if let Some(&out_idx) = self.gpu_handle_feeds.get(name)
666            && let Some(&out_id) = self.output_ids.get(out_idx)
667        {
668            let n = self.graph.node(out_id).shape.num_elements().unwrap_or(0);
669            return Some(self.arena.read_f32(out_id, n));
670        }
671        if self.gpu_handle_resident.contains(name)
672            && let Some(&id) = self.input_ids.get(name)
673        {
674            let n = self.graph.node(id).shape.num_elements().unwrap_or(0);
675            return Some(self.arena.read_f32(id, n));
676        }
677        self.gpu_handles.get(name).cloned()
678    }
679
680    /// Read one row (`row_inner` f32 elements at `row`) from graph output
681    /// `out_idx`, directly from the arena. Used by resident KV decode to pull
682    /// just the new-token K/V row to the host cache (for bucket transitions)
683    /// without a full-output readback.
684    pub fn read_output_row(
685        &self,
686        out_idx: usize,
687        row: usize,
688        row_inner: usize,
689    ) -> Option<Vec<f32>> {
690        let id = *self.output_ids.get(out_idx)?;
691        let base = self.arena.elem_offset(id) as usize + row * row_inner;
692        Some(self.arena.read_f32_at_elem(base, row_inner))
693    }
694
695    /// Fold each fed output's arena slot back into its handle input slot,
696    /// in-place (no host round-trip). The copy length honors `active_extent`
697    /// `(actual_rows, upper)` so only the valid prefix (incl. the new token row)
698    /// is carried — the rest of the bucket-padded slot stays zero.
699    fn propagate_gpu_handle_feeds_in_arena(&mut self) {
700        let extent = self.active_extent;
701        let feeds: Vec<(String, usize)> = self
702            .gpu_handle_feeds
703            .iter()
704            .map(|(k, &v)| (k.clone(), v))
705            .collect();
706        for (name, out_idx) in feeds {
707            let Some(&out_id) = self.output_ids.get(out_idx) else {
708                continue;
709            };
710            let Some(&in_id) = self.input_ids.get(name.as_str()) else {
711                continue;
712            };
713            if in_id != out_id {
714                let out_elems = self.graph.node(out_id).shape.num_elements().unwrap_or(0);
715                let copy_elems = match extent {
716                    Some((actual, upper)) if upper > 0 => actual * (out_elems / (upper + 1)).max(1),
717                    _ => out_elems,
718                };
719                self.arena
720                    .copy_node_f32_prefix(in_id, out_id, copy_elems.min(out_elems));
721            }
722            self.gpu_handle_resident.insert(name.clone());
723            // Drop the host mirror — the value now lives in the arena.
724            self.gpu_handles.insert(name.clone(), Vec::new());
725        }
726    }
727
728    /// Refresh host mirrors from fed outputs (only when all outputs are read).
729    fn refresh_gpu_handles_from_outputs(&mut self) {
730        let feeds: Vec<(String, usize)> = self
731            .gpu_handle_feeds
732            .iter()
733            .map(|(k, &v)| (k.clone(), v))
734            .collect();
735        for (name, out_idx) in feeds {
736            let Some(&out_id) = self.output_ids.get(out_idx) else {
737                continue;
738            };
739            let n = self.graph.node(out_id).shape.num_elements().unwrap_or(0);
740            let src = self.arena.read_f32(out_id, n);
741            self.gpu_handles.insert(name, src);
742        }
743    }
744
745    pub fn set_rng(&mut self, rng: RngOptions) {
746        self.rng = rng;
747    }
748
749    pub fn rng(&self) -> RngOptions {
750        self.rng
751    }
752
753    pub fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
754        self.run_read_outputs(inputs, None)
755    }
756
757    pub fn run_read_outputs(
758        &mut self,
759        inputs: &[(&str, &[f32])],
760        read_indices: Option<&[usize]>,
761    ) -> Vec<Vec<f32>> {
762        // Re-seed any GPU handle that is neither resident in the arena nor about
763        // to be overwritten by an explicit input this step (first step after a
764        // bind, or a bucket reinstall). Resident handles are skipped — their
765        // value already lives in the arena from the previous step's feed.
766        for (name, data) in &self.gpu_handles {
767            if self.gpu_handle_resident.contains(name) || inputs.iter().any(|(n, _)| n == name) {
768                continue;
769            }
770            if let Some(&id) = self.input_ids.get(name) {
771                self.arena.write_f32(id, data);
772            }
773        }
774        // Upload inputs.
775        for &(name, data) in inputs {
776            if let Some(&id) = self.input_ids.get(name) {
777                self.arena.write_f32(id, data);
778            }
779        }
780
781        // Execute the schedule in segments: runs of consecutive GPU dispatches
782        // are submitted together; a host-fallback step flushes the queue, runs
783        // on the CPU directly against the host-visible arena, and the next GPU
784        // segment picks up its result (HOST_COHERENT memory).
785        let dev = vulkan_device().expect("rlx-vulkan: no device");
786        let kern = kernels().expect("rlx-vulkan: no kernels");
787        let desc_set = self.desc_set;
788        let layout = kern.pipeline_layout;
789
790        if self.cached {
791            // Fast path: each GPU segment is a single submit of its pre-recorded
792            // command buffer; host segments run on the CPU between submits. Arena
793            // reads/writes are `&self` (interior mutability via the mapped ptr),
794            // so the whole loop borrows `self` immutably.
795            let nseg = self.segments.len();
796            for si in 0..nseg {
797                match &self.segments[si] {
798                    Segment::Gpu(cmd) => {
799                        let cmd = *cmd;
800                        dev.submit_recorded_wait(cmd, self.fence);
801                    }
802                    Segment::Host {
803                        op,
804                        out,
805                        out_shape,
806                        inputs: in_ids,
807                    } => {
808                        let in_specs: Vec<(rlx_ir::Shape, crate::host::HostBuf)> = in_ids
809                            .iter()
810                            .map(|&id| {
811                                let sh = self.graph.node(id).shape.clone();
812                                let nn = sh.num_elements().unwrap_or(0);
813                                let buf = if matches!(sh.dtype(), DType::U8 | DType::I8) {
814                                    crate::host::HostBuf::Bytes(self.arena.read_bytes(id, nn))
815                                } else {
816                                    crate::host::HostBuf::F32(self.arena.read_f32(id, nn))
817                                };
818                                (sh, buf)
819                            })
820                            .collect();
821                        let result = crate::host::eval(op, out_shape, &in_specs);
822                        self.arena.write_f32(*out, &result);
823                    }
824                }
825            }
826            // Fall through to the feed/readback tail below.
827            return self.finish_run(read_indices);
828        }
829
830        let n = self.schedule.len();
831        let mut i = 0;
832        while i < n {
833            let start = i;
834            while i < n && matches!(self.schedule[i], Step::Gpu { .. }) {
835                i += 1;
836            }
837            if i > start {
838                let gpu = self.schedule[start..i].to_vec();
839                dev.submit_and_wait(|cmd| unsafe {
840                    dev.device.cmd_bind_descriptor_sets(
841                        cmd,
842                        vk::PipelineBindPoint::COMPUTE,
843                        layout,
844                        0,
845                        &[desc_set],
846                        &[],
847                    );
848                    let barrier = vk::MemoryBarrier::default()
849                        .src_access_mask(vk::AccessFlags::SHADER_WRITE)
850                        .dst_access_mask(
851                            vk::AccessFlags::SHADER_READ | vk::AccessFlags::SHADER_WRITE,
852                        );
853                    for (j, step) in gpu.iter().enumerate() {
854                        if let Step::Gpu {
855                            kernel,
856                            push,
857                            groups,
858                        } = step
859                        {
860                            let pipeline = kern.pipeline(kernel);
861                            dev.device.cmd_bind_pipeline(
862                                cmd,
863                                vk::PipelineBindPoint::COMPUTE,
864                                pipeline,
865                            );
866                            dev.device.cmd_push_constants(
867                                cmd,
868                                layout,
869                                vk::ShaderStageFlags::COMPUTE,
870                                0,
871                                push,
872                            );
873                            dev.device.cmd_dispatch(cmd, groups.0, groups.1, groups.2);
874                            if j + 1 < gpu.len() {
875                                dev.device.cmd_pipeline_barrier(
876                                    cmd,
877                                    vk::PipelineStageFlags::COMPUTE_SHADER,
878                                    vk::PipelineStageFlags::COMPUTE_SHADER,
879                                    vk::DependencyFlags::empty(),
880                                    &[barrier],
881                                    &[],
882                                    &[],
883                                );
884                            }
885                        }
886                    }
887                });
888            }
889            if i < n {
890                if let Step::Host {
891                    op,
892                    out,
893                    out_shape,
894                    inputs: in_ids,
895                } = self.schedule[i].clone()
896                {
897                    let in_specs: Vec<(rlx_ir::Shape, crate::host::HostBuf)> = in_ids
898                        .iter()
899                        .map(|&id| {
900                            let sh = self.graph.node(id).shape.clone();
901                            let nn = sh.num_elements().unwrap_or(0);
902                            // Packed quant weights (U8/I8) are read as raw bytes;
903                            // everything else is f32 from the uniform arena.
904                            let buf = if matches!(sh.dtype(), DType::U8 | DType::I8) {
905                                crate::host::HostBuf::Bytes(self.arena.read_bytes(id, nn))
906                            } else {
907                                crate::host::HostBuf::F32(self.arena.read_f32(id, nn))
908                            };
909                            (sh, buf)
910                        })
911                        .collect();
912                    let result = crate::host::eval(&op, &out_shape, &in_specs);
913                    self.arena.write_f32(out, &result);
914                }
915                i += 1;
916            }
917        }
918
919        self.finish_run(read_indices)
920    }
921
922    /// Shared post-execution tail for both the cached and legacy run paths: fold
923    /// fed outputs (new-token K/V) back into their handle input slots in-place on
924    /// the arena — the queue is idle here so the mapped memory is coherent. When
925    /// all outputs are read back, also refresh host mirrors; for logits-only
926    /// decode (`read_indices == Some([0])`) the K/V never leaves the arena, which
927    /// is the whole point. Then read the requested outputs.
928    fn finish_run(&mut self, read_indices: Option<&[usize]>) -> Vec<Vec<f32>> {
929        if !self.gpu_handle_feeds.is_empty() {
930            self.propagate_gpu_handle_feeds_in_arena();
931            if read_indices.is_none() {
932                self.refresh_gpu_handles_from_outputs();
933            }
934        }
935
936        let want: Vec<usize> = match read_indices {
937            Some(ix) => ix.to_vec(),
938            None => (0..self.output_ids.len()).collect(),
939        };
940        want.into_iter()
941            .filter_map(|i| {
942                let id = *self.output_ids.get(i)?;
943                let n = self.graph.node(id).shape.num_elements().unwrap_or(0);
944                Some(self.arena.read_f32(id, n))
945            })
946            .collect()
947    }
948
949    /// Deep copy for `clone_box`: fresh arena/descriptors with the same params
950    /// and constants already resident.
951    pub fn clone_for_cache(&self) -> Self {
952        let mut twin = Self::build(self.graph.clone(), self.rng);
953        twin.active_extent = self.active_extent;
954        // Copy the whole arena (params + constants, plus any resident K/V)
955        // byte-for-byte, then carry the GPU-handle bookkeeping so the twin keeps
956        // feeding/resident semantics identical to the source.
957        self.arena.copy_into(&twin.arena);
958        twin.gpu_handles = self.gpu_handles.clone();
959        twin.gpu_handle_feeds = self.gpu_handle_feeds.clone();
960        twin.gpu_handle_resident = self.gpu_handle_resident.clone();
961        twin.kv_row_feeds = self.kv_row_feeds.clone();
962        twin
963    }
964}
965
966impl Drop for VulkanExecutable {
967    fn drop(&mut self) {
968        if let Some(dev) = vulkan_device() {
969            // Free the pre-recorded command buffers and the reusable fence
970            // before tearing down the pool they came from.
971            let cmds: Vec<vk::CommandBuffer> = self
972                .segments
973                .iter()
974                .filter_map(|s| match s {
975                    Segment::Gpu(cmd) => Some(*cmd),
976                    Segment::Host { .. } => None,
977                })
978                .collect();
979            if !cmds.is_empty() {
980                dev.free_cmds(&cmds);
981            }
982            if self.fence != vk::Fence::null() {
983                dev.destroy_fence(self.fence);
984            }
985            unsafe {
986                dev.device.destroy_descriptor_pool(self.desc_pool, None);
987            }
988        }
989    }
990}
991
992/// Pre-record the static schedule into reusable command buffers. The schedule is
993/// partitioned into maximal runs of consecutive GPU dispatches; each run is
994/// recorded once into a primary command buffer that is resubmitted unchanged
995/// every `run`. Host-fallback ops become `Segment::Host` markers, executed on the
996/// CPU between GPU submits. Recorded WITHOUT `ONE_TIME_SUBMIT` so the buffers can
997/// be resubmitted.
998///
999/// Barriers are placed only where a real memory hazard exists (per `deps`): a
1000/// dispatch that reads/writes a slot touched since the last barrier flushes with
1001/// one global shader-memory barrier, which both lets the driver overlap
1002/// independent dispatches and — decisively on MoltenVK, where each barrier forces
1003/// a Metal compute-encoder restart — slashes the barrier count for the typical
1004/// MLP/CNN graph (most of whose 100+ dispatches are independent elementwise/shape
1005/// glue). `RLX_VULKAN_FULLBARRIER=1` restores a barrier between every pair
1006/// (conservative fallback); `RLX_VULKAN_NOBARRIER=1` drops them all (unsafe —
1007/// diagnostic only).
1008fn record_segments(
1009    dev: &crate::device::VulkanDevice,
1010    kern: &crate::kernels::Kernels,
1011    desc_set: vk::DescriptorSet,
1012    schedule: &[Step],
1013    deps: &[StepDep],
1014) -> Vec<Segment> {
1015    let layout = kern.pipeline_layout;
1016    let no_barrier = std::env::var("RLX_VULKAN_NOBARRIER").as_deref() == Ok("1");
1017    let full_barrier = std::env::var("RLX_VULKAN_FULLBARRIER").as_deref() == Ok("1");
1018    let mut segments = Vec::new();
1019    let n = schedule.len();
1020    let mut i = 0;
1021    while i < n {
1022        let start = i;
1023        while i < n && matches!(schedule[i], Step::Gpu { .. }) {
1024            i += 1;
1025        }
1026        if i > start {
1027            let run = &schedule[start..i];
1028            let run_deps = &deps[start..i];
1029            let cmd = dev.alloc_primary_cmd();
1030            unsafe {
1031                dev.device
1032                    .begin_command_buffer(cmd, &vk::CommandBufferBeginInfo::default())
1033                    .expect("vk begin cmd");
1034                dev.device.cmd_bind_descriptor_sets(
1035                    cmd,
1036                    vk::PipelineBindPoint::COMPUTE,
1037                    layout,
1038                    0,
1039                    &[desc_set],
1040                    &[],
1041                );
1042                let barrier = vk::MemoryBarrier::default()
1043                    .src_access_mask(vk::AccessFlags::SHADER_WRITE)
1044                    .dst_access_mask(vk::AccessFlags::SHADER_READ | vk::AccessFlags::SHADER_WRITE);
1045                // Slots written / read since the last barrier (arena elem
1046                // offsets). A dispatch hazards on RAW (reads a written slot),
1047                // WAW (writes a written slot) or WAR (writes a read slot); on a
1048                // hazard we flush with one barrier and reset the sets.
1049                let mut wrote: HashSet<u32> = HashSet::new();
1050                let mut read: HashSet<u32> = HashSet::new();
1051                for (j, step) in run.iter().enumerate() {
1052                    if let Step::Gpu {
1053                        kernel,
1054                        push,
1055                        groups,
1056                    } = step
1057                    {
1058                        let dep = &run_deps[j];
1059                        let hazard = !wrote.is_empty()
1060                            && (dep.reads.iter().any(|r| wrote.contains(r))
1061                                || wrote.contains(&dep.write)
1062                                || read.contains(&dep.write));
1063                        let emit_barrier = j > 0 && !no_barrier && (full_barrier || hazard);
1064                        if emit_barrier {
1065                            dev.device.cmd_pipeline_barrier(
1066                                cmd,
1067                                vk::PipelineStageFlags::COMPUTE_SHADER,
1068                                vk::PipelineStageFlags::COMPUTE_SHADER,
1069                                vk::DependencyFlags::empty(),
1070                                &[barrier],
1071                                &[],
1072                                &[],
1073                            );
1074                            wrote.clear();
1075                            read.clear();
1076                        }
1077                        let pipeline = kern.pipeline(kernel);
1078                        dev.device
1079                            .cmd_bind_pipeline(cmd, vk::PipelineBindPoint::COMPUTE, pipeline);
1080                        dev.device.cmd_push_constants(
1081                            cmd,
1082                            layout,
1083                            vk::ShaderStageFlags::COMPUTE,
1084                            0,
1085                            push,
1086                        );
1087                        dev.device.cmd_dispatch(cmd, groups.0, groups.1, groups.2);
1088                        wrote.insert(dep.write);
1089                        for &r in &dep.reads {
1090                            read.insert(r);
1091                        }
1092                    }
1093                }
1094                dev.device.end_command_buffer(cmd).expect("vk end cmd");
1095            }
1096            segments.push(Segment::Gpu(cmd));
1097        }
1098        if i < n {
1099            if let Step::Host {
1100                op,
1101                out,
1102                out_shape,
1103                inputs,
1104            } = &schedule[i]
1105            {
1106                segments.push(Segment::Host {
1107                    op: op.clone(),
1108                    out: *out,
1109                    out_shape: out_shape.clone(),
1110                    inputs: inputs.clone(),
1111                });
1112            }
1113            i += 1;
1114        }
1115    }
1116    segments
1117}
1118
1119/// Widen a constant byte blob (any IR dtype) to f32 for the f32-uniform arena.
1120fn widen_const_to_f32(data: &[u8], dt: DType) -> Vec<f32> {
1121    match dt {
1122        DType::F32 => data
1123            .chunks_exact(4)
1124            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
1125            .collect(),
1126        DType::F16 => data
1127            .chunks_exact(2)
1128            .map(|c| half::f16::from_le_bytes([c[0], c[1]]).to_f32())
1129            .collect(),
1130        DType::BF16 => data
1131            .chunks_exact(2)
1132            .map(|c| half::bf16::from_le_bytes([c[0], c[1]]).to_f32())
1133            .collect(),
1134        DType::F64 => data
1135            .chunks_exact(8)
1136            .map(|c| f64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]) as f32)
1137            .collect(),
1138        DType::I64 => data
1139            .chunks_exact(8)
1140            .map(|c| i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]) as f32)
1141            .collect(),
1142        DType::I32 | DType::U32 => data
1143            .chunks_exact(4)
1144            .map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]]) as f32)
1145            .collect(),
1146        DType::I16 => data
1147            .chunks_exact(2)
1148            .map(|c| i16::from_le_bytes([c[0], c[1]]) as f32)
1149            .collect(),
1150        DType::I8 => data.iter().map(|&b| b as i8 as f32).collect(),
1151        DType::U8 | DType::Bool => data.iter().map(|&b| b as f32).collect(),
1152        DType::C64 => data
1153            .chunks_exact(4)
1154            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
1155            .collect(),
1156    }
1157}
1158
1159// ── schedule construction ──────────────────────────────────────────────────
1160
/// Per-GPU-step memory footprint, used to place barriers only where a real data
/// hazard exists.
///
/// The arena is a bump allocator with one unique slot per node. Only
/// Reshape/Cast/StopGradient alias, and aliases share an offset, so tracking
/// arena *element offsets* captures aliasing for free. Derives `Debug`/`Eq` so
/// barrier-placement decisions can be logged and compared when diagnosing
/// hazards.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
struct StepDep {
    /// Arena element offsets of the node's input slots.
    reads: Vec<u32>,
    /// Arena element offset of the node's output slot.
    write: u32,
}
1171
1172/// Build the dispatch schedule plus, in lockstep, the per-step dependency info
1173/// (`StepDep`) that [`record_segments`] uses to elide redundant barriers. Each
1174/// graph node contributes its node-level footprint to every `Step` it emits
1175/// (most nodes emit one; `Concat` emits one per input — conservatively sharing
1176/// the node footprint, which over-serializes only a concat's own sub-copies).
1177fn build_schedule(graph: &Graph, arena: &Arena) -> (Vec<Step>, Vec<StepDep>) {
1178    let mut steps = Vec::new();
1179    let mut deps: Vec<StepDep> = Vec::new();
1180    for node in graph.nodes() {
1181        let off = |id: NodeId| arena.elem_offset(id);
1182        let out = node.id;
1183        let before = steps.len();
1184        match &node.op {
1185            // Leaves / aliases — no dispatch.
1186            Op::Input { .. }
1187            | Op::Param { .. }
1188            | Op::Constant { .. }
1189            | Op::Reshape { .. }
1190            | Op::Cast { .. }
1191            | Op::StopGradient => {}
1192
1193            Op::Binary(op) => {
1194                let a = node.inputs[0];
1195                let b = node.inputs[1];
1196                let n = numel(&dims(graph, out));
1197                let an = numel(&dims(graph, a));
1198                let bn = numel(&dims(graph, b));
1199                let push = Push::default()
1200                    .u(n as u32)
1201                    .u(off(a))
1202                    .u(off(b))
1203                    .u(off(out))
1204                    .u(if an == n { 0 } else { an as u32 })
1205                    .u(if bn == n { 0 } else { bn as u32 })
1206                    .u(binop_id(*op))
1207                    .bytes();
1208                steps.push(Step::Gpu {
1209                    kernel: "binary",
1210                    push,
1211                    groups: groups1d(n, 256),
1212                });
1213            }
1214
1215            Op::Compare(op) => {
1216                let a = node.inputs[0];
1217                let b = node.inputs[1];
1218                let n = numel(&dims(graph, out));
1219                let an = numel(&dims(graph, a));
1220                let bn = numel(&dims(graph, b));
1221                let push = Push::default()
1222                    .u(n as u32)
1223                    .u(off(a))
1224                    .u(off(b))
1225                    .u(off(out))
1226                    .u(if an == n { 0 } else { an as u32 })
1227                    .u(if bn == n { 0 } else { bn as u32 })
1228                    .u(cmp_id(*op))
1229                    .bytes();
1230                steps.push(Step::Gpu {
1231                    kernel: "compare",
1232                    push,
1233                    groups: groups1d(n, 256),
1234                });
1235            }
1236
1237            Op::Where => {
1238                let c = node.inputs[0];
1239                let a = node.inputs[1];
1240                let b = node.inputs[2];
1241                let n = numel(&dims(graph, out));
1242                let cn = numel(&dims(graph, c));
1243                let an = numel(&dims(graph, a));
1244                let bn = numel(&dims(graph, b));
1245                let push = Push::default()
1246                    .u(n as u32)
1247                    .u(off(c))
1248                    .u(off(a))
1249                    .u(off(b))
1250                    .u(off(out))
1251                    .u(if cn == n { 0 } else { cn as u32 })
1252                    .u(if an == n { 0 } else { an as u32 })
1253                    .u(if bn == n { 0 } else { bn as u32 })
1254                    .bytes();
1255                steps.push(Step::Gpu {
1256                    kernel: "where",
1257                    push,
1258                    groups: groups1d(n, 256),
1259                });
1260            }
1261
1262            Op::Activation(act) => {
1263                let x = node.inputs[0];
1264                let n = numel(&dims(graph, out));
1265                let push = Push::default()
1266                    .u(n as u32)
1267                    .u(off(x))
1268                    .u(off(out))
1269                    .u(act_id(*act))
1270                    .bytes();
1271                steps.push(Step::Gpu {
1272                    kernel: "unary",
1273                    push,
1274                    groups: groups1d(n, 256),
1275                });
1276            }
1277
1278            Op::MatMul => {
1279                let a = node.inputs[0];
1280                let b = node.inputs[1];
1281                let ad = dims(graph, a);
1282                let bd = dims(graph, b);
1283                let od = dims(graph, out);
1284                let (m, k) = (ad[ad.len() - 2], ad[ad.len() - 1]);
1285                let n = bd[bd.len() - 1];
1286                let batch = if od.len() > 2 {
1287                    numel(&od[..od.len() - 2])
1288                } else {
1289                    1
1290                };
1291                let a_batch = if ad.len() > 2 {
1292                    numel(&ad[..ad.len() - 2])
1293                } else {
1294                    1
1295                };
1296                let b_batch = if bd.len() > 2 {
1297                    numel(&bd[..bd.len() - 2])
1298                } else {
1299                    1
1300                };
1301                let a_bs = if a_batch <= 1 { 0 } else { m * k };
1302                let b_bs = if b_batch <= 1 { 0 } else { k * n };
1303                let push = Push::default()
1304                    .u(m as u32)
1305                    .u(k as u32)
1306                    .u(n as u32)
1307                    .u(off(a))
1308                    .u(off(b))
1309                    .u(off(out))
1310                    .u(batch as u32)
1311                    .u(a_bs as u32)
1312                    .u(b_bs as u32)
1313                    .u((m * n) as u32)
1314                    .bytes();
1315                steps.push(Step::Gpu {
1316                    kernel: matmul_kernel(m, k, n),
1317                    push,
1318                    groups: (ceil_div(n, 16), ceil_div(m, 16), batch.max(1) as u32),
1319                });
1320            }
1321
1322            Op::Reduce { op, axes, .. } => {
1323                let x = node.inputs[0];
1324                let xd = dims(graph, x);
1325                let rank = xd.len();
1326                // After LowerNonLastAxisReduce: last-axis single-axis reduce.
1327                let last = rank.saturating_sub(1);
1328                debug_assert!(
1329                    axes.as_slice() == [last] || (rank <= 1),
1330                    "rlx-vulkan: non-last-axis reduce should have been lowered"
1331                );
1332                let r = *xd.get(last).unwrap_or(&1);
1333                let outer = numel(&xd) / r.max(1);
1334                let push = Push::default()
1335                    .u(outer as u32)
1336                    .u(r as u32)
1337                    .u(off(x))
1338                    .u(off(out))
1339                    .u(reduce_id(*op))
1340                    .bytes();
1341                steps.push(Step::Gpu {
1342                    kernel: "reduce",
1343                    push,
1344                    groups: groups1d(outer, 256),
1345                });
1346            }
1347
1348            Op::Softmax { axis } => {
1349                let x = node.inputs[0];
1350                let xd = dims(graph, x);
1351                let ax = norm_axis(*axis, xd.len());
1352                let axis_len = xd[ax];
1353                let outer = numel(&xd[..ax]);
1354                let inner = numel(&xd[ax + 1..]);
1355                let push = Push::default()
1356                    .u(outer as u32)
1357                    .u(axis_len as u32)
1358                    .u(inner as u32)
1359                    .u(off(x))
1360                    .u(off(out))
1361                    .bytes();
1362                steps.push(Step::Gpu {
1363                    kernel: "softmax",
1364                    push,
1365                    groups: groups1d(outer * inner, 256),
1366                });
1367            }
1368
1369            Op::RmsNorm { axis, eps } => {
1370                // Op::RmsNorm carries (x, gamma, beta): y = x*rsqrt(ms+eps)*gamma + beta.
1371                let x = node.inputs[0];
1372                let gamma = node.inputs[1];
1373                let beta = node.inputs[2];
1374                let xd = dims(graph, x);
1375                let ax = norm_axis(*axis, xd.len());
1376                debug_assert_eq!(ax, xd.len().saturating_sub(1), "rmsnorm expects last axis");
1377                let n = xd[ax];
1378                let rows = numel(&xd) / n.max(1);
1379                let push = Push::default()
1380                    .u(rows as u32)
1381                    .u(n as u32)
1382                    .u(off(x))
1383                    .u(off(gamma))
1384                    .u(off(beta))
1385                    .u(off(out))
1386                    .f(*eps)
1387                    .bytes();
1388                steps.push(Step::Gpu {
1389                    kernel: "rmsnorm",
1390                    push,
1391                    groups: groups1d(rows, 64),
1392                });
1393            }
1394
1395            Op::LayerNorm { axis, eps } => {
1396                let x = node.inputs[0];
1397                let gamma = node.inputs[1];
1398                let has_beta = node.inputs.len() >= 3;
1399                let beta = if has_beta { node.inputs[2] } else { gamma };
1400                let xd = dims(graph, x);
1401                let ax = norm_axis(*axis, xd.len());
1402                let n = xd[ax];
1403                let rows = numel(&xd) / n.max(1);
1404                let push = Push::default()
1405                    .u(rows as u32)
1406                    .u(n as u32)
1407                    .u(off(x))
1408                    .u(off(gamma))
1409                    .u(off(beta))
1410                    .u(off(out))
1411                    .u(if has_beta { 1 } else { 0 })
1412                    .f(*eps)
1413                    .bytes();
1414                steps.push(Step::Gpu {
1415                    kernel: "layernorm",
1416                    push,
1417                    groups: groups1d(rows, 64),
1418                });
1419            }
1420
1421            Op::Rope {
1422                head_dim,
1423                n_rot,
1424                style,
1425            } => {
1426                let x = node.inputs[0];
1427                let cos = node.inputs[1];
1428                let sin = node.inputs[2];
1429                let xd = dims(graph, x);
1430                let (batch, seq, hidden) = if xd.len() >= 3 {
1431                    (xd[0], xd[1], xd[2])
1432                } else {
1433                    let total = numel(&xd);
1434                    (1, xd[0], total / xd[0].max(1))
1435                };
1436                let hd = *head_dim;
1437                let nh = hidden / hd.max(1);
1438                let tab_half = hd / 2;
1439                let cos_len = numel(&dims(graph, cos));
1440                let cos_rows = cos_len / tab_half.max(1);
1441                let per_token = (cos_rows == batch * seq && cos_rows != seq) as u32;
1442                let style_id = match style {
1443                    RopeStyle::NeoX => 0u32,
1444                    RopeStyle::GptJ => 1u32,
1445                };
1446                let push = Push::default()
1447                    .u(batch as u32)
1448                    .u(seq as u32)
1449                    .u(hidden as u32)
1450                    .u(hd as u32)
1451                    .u(*n_rot as u32)
1452                    .u(nh as u32)
1453                    .u(tab_half as u32)
1454                    .u(hidden as u32) // src_row_stride (no Narrow→Rope fusion)
1455                    .u(per_token)
1456                    .u(style_id)
1457                    .u(off(x))
1458                    .u(off(cos))
1459                    .u(off(sin))
1460                    .u(off(out))
1461                    .bytes();
1462                steps.push(Step::Gpu {
1463                    kernel: "rope",
1464                    push,
1465                    groups: groups1d(batch * seq * nh, 64),
1466                });
1467            }
1468
1469            Op::Attention {
1470                num_heads,
1471                head_dim,
1472                mask_kind,
1473                score_scale,
1474                ..
1475            } => {
1476                let q = node.inputs[0];
1477                let k = node.inputs[1];
1478                let v = node.inputs[2];
1479                let qd = dims(graph, q);
1480                let kd = dims(graph, k);
1481                let nh = *num_heads;
1482                let dh = *head_dim;
1483                let (batch, q_s, k_s, bhsd) = if qd.len() == 4 {
1484                    if qd[1] == nh {
1485                        (qd[0], qd[2], kd[2], 1u32) // [B,H,S,D]
1486                    } else {
1487                        (qd[0], qd[1], kd[1], 0u32) // [B,S,H,D]
1488                    }
1489                } else if qd.len() >= 3 {
1490                    (qd[0], qd[1], kd[1], 0u32)
1491                } else {
1492                    (1, qd[0], kd[0], 0u32)
1493                };
1494                let hs = (nh * dh) as u32;
1495                let (mask_kind_id, mask_off, window) = match mask_kind {
1496                    MaskKind::None => (0u32, 0u32, 0u32),
1497                    MaskKind::Causal => (1, 0, 0),
1498                    MaskKind::SlidingWindow(w) => (2, 0, *w as u32),
1499                    MaskKind::Custom => (3, off(node.inputs[3]), 0),
1500                    MaskKind::Bias => (4, off(node.inputs[3]), 0),
1501                };
1502                let scale = score_scale.unwrap_or((dh as f32).powf(-0.5));
1503                let push = Push::default()
1504                    .u(batch as u32)
1505                    .u(nh as u32)
1506                    .u(q_s as u32)
1507                    .u(k_s as u32)
1508                    .u(dh as u32)
1509                    .u(off(q))
1510                    .u(off(k))
1511                    .u(off(v))
1512                    .u(off(out))
1513                    .u(hs)
1514                    .u(hs)
1515                    .u(hs)
1516                    .u(bhsd)
1517                    .u(mask_kind_id)
1518                    .u(mask_off)
1519                    .u(window)
1520                    .f(scale)
1521                    .f(-1.0e30)
1522                    .f(0.5)
1523                    .bytes();
1524                steps.push(Step::Gpu {
1525                    kernel: "attention",
1526                    push,
1527                    groups: groups1d(batch * nh * q_s, 64),
1528                });
1529            }
1530
1531            Op::Transpose { perm } => {
1532                let x = node.inputs[0];
1533                let xd = dims(graph, x);
1534                let od = dims(graph, out);
1535                let in_str = contig_strides(&xd);
1536                let out_str = contig_strides(&od);
1537                let rank = od.len();
1538                let mut shape = [1u32; 6];
1539                let mut istr = [0u32; 6];
1540                let mut ostr = [0u32; 6];
1541                for ax in 0..rank {
1542                    shape[ax] = od[ax] as u32;
1543                    istr[ax] = in_str[perm[ax]] as u32;
1544                    ostr[ax] = out_str[ax] as u32;
1545                }
1546                let n = numel(&od);
1547                let push = Push::default()
1548                    .u(n as u32)
1549                    .u(rank as u32)
1550                    .u(off(x))
1551                    .u(off(out))
1552                    .us(&shape)
1553                    .us(&istr)
1554                    .us(&ostr)
1555                    .bytes();
1556                steps.push(Step::Gpu {
1557                    kernel: "reindex",
1558                    push,
1559                    groups: groups1d(n, 256),
1560                });
1561            }
1562
1563            Op::Narrow { axis, start, .. } => {
1564                let x = node.inputs[0];
1565                let xd = dims(graph, x);
1566                let od = dims(graph, out);
1567                let in_str = contig_strides(&xd);
1568                let out_str = contig_strides(&od);
1569                let rank = od.len();
1570                let mut shape = [1u32; 6];
1571                let mut istr = [0u32; 6];
1572                let mut ostr = [0u32; 6];
1573                for ax in 0..rank {
1574                    shape[ax] = od[ax] as u32;
1575                    istr[ax] = in_str[ax] as u32;
1576                    ostr[ax] = out_str[ax] as u32;
1577                }
1578                let in_off = off(x) + (*start * in_str[*axis]) as u32;
1579                let n = numel(&od);
1580                let push = Push::default()
1581                    .u(n as u32)
1582                    .u(rank as u32)
1583                    .u(in_off)
1584                    .u(off(out))
1585                    .us(&shape)
1586                    .us(&istr)
1587                    .us(&ostr)
1588                    .bytes();
1589                steps.push(Step::Gpu {
1590                    kernel: "reindex",
1591                    push,
1592                    groups: groups1d(n, 256),
1593                });
1594            }
1595
1596            Op::Expand { .. } => {
1597                let x = node.inputs[0];
1598                let xd = dims(graph, x);
1599                let od = dims(graph, out);
1600                let rank = od.len();
1601                // Right-align input dims to output rank.
1602                let pad = rank - xd.len();
1603                let in_str_full = contig_strides(&xd);
1604                let out_str = contig_strides(&od);
1605                let mut shape = [1u32; 6];
1606                let mut istr = [0u32; 6];
1607                let mut ostr = [0u32; 6];
1608                for ax in 0..rank {
1609                    shape[ax] = od[ax] as u32;
1610                    ostr[ax] = out_str[ax] as u32;
1611                    if ax < pad {
1612                        istr[ax] = 0;
1613                    } else {
1614                        let xi = ax - pad;
1615                        istr[ax] = if xd[xi] == 1 && od[ax] != 1 {
1616                            0
1617                        } else {
1618                            in_str_full[xi] as u32
1619                        };
1620                    }
1621                }
1622                let n = numel(&od);
1623                let push = Push::default()
1624                    .u(n as u32)
1625                    .u(rank as u32)
1626                    .u(off(x))
1627                    .u(off(out))
1628                    .us(&shape)
1629                    .us(&istr)
1630                    .us(&ostr)
1631                    .bytes();
1632                steps.push(Step::Gpu {
1633                    kernel: "reindex",
1634                    push,
1635                    groups: groups1d(n, 256),
1636                });
1637            }
1638
1639            Op::Concat { axis } => {
1640                let od = dims(graph, out);
1641                let out_str = contig_strides(&od);
1642                let rank = od.len();
1643                let mut axis_cursor = 0usize;
1644                for &inp in &node.inputs {
1645                    let id_dims = dims(graph, inp);
1646                    let in_str = contig_strides(&id_dims);
1647                    let mut shape = [1u32; 6];
1648                    let mut istr = [0u32; 6];
1649                    let mut ostr = [0u32; 6];
1650                    for ax in 0..rank {
1651                        shape[ax] = *id_dims.get(ax).unwrap_or(&1) as u32;
1652                        istr[ax] = *in_str.get(ax).unwrap_or(&0) as u32;
1653                        ostr[ax] = out_str[ax] as u32;
1654                    }
1655                    let out_off = off(out) + (axis_cursor * out_str[*axis]) as u32;
1656                    let n = numel(&id_dims);
1657                    let push = Push::default()
1658                        .u(n as u32)
1659                        .u(rank as u32)
1660                        .u(off(inp))
1661                        .u(out_off)
1662                        .us(&shape)
1663                        .us(&istr)
1664                        .us(&ostr)
1665                        .bytes();
1666                    steps.push(Step::Gpu {
1667                        kernel: "reindex",
1668                        push,
1669                        groups: groups1d(n, 256),
1670                    });
1671                    axis_cursor += *id_dims.get(*axis).unwrap_or(&1);
1672                }
1673            }
1674
1675            Op::Gather { axis } => {
1676                let data = node.inputs[0];
1677                let idx = node.inputs[1];
1678                let dd = dims(graph, data);
1679                let ax = *axis;
1680                let out_outer = numel(&dd[..ax]);
1681                let axis_dim = dd[ax];
1682                let out_inner = numel(&dd[ax + 1..]);
1683                let n_idx = numel(&dims(graph, idx));
1684                let total = out_outer * n_idx * out_inner;
1685                let push = Push::default()
1686                    .u(out_outer as u32)
1687                    .u(n_idx as u32)
1688                    .u(out_inner as u32)
1689                    .u(axis_dim as u32)
1690                    .u(off(data))
1691                    .u(off(idx))
1692                    .u(off(out))
1693                    .bytes();
1694                steps.push(Step::Gpu {
1695                    kernel: "gather",
1696                    push,
1697                    groups: groups1d(total, 256),
1698                });
1699            }
1700
1701            Op::Cumsum { axis, exclusive } => {
1702                let x = node.inputs[0];
1703                let xd = dims(graph, x);
1704                let ax = norm_axis(*axis, xd.len());
1705                debug_assert_eq!(ax, xd.len().saturating_sub(1), "cumsum expects last axis");
1706                let cols = *xd.get(ax).unwrap_or(&1);
1707                let rows = numel(&xd) / cols.max(1);
1708                let push = Push::default()
1709                    .u(rows as u32)
1710                    .u(cols as u32)
1711                    .u(off(x))
1712                    .u(off(out))
1713                    .u(if *exclusive { 1 } else { 0 })
1714                    .bytes();
1715                steps.push(Step::Gpu {
1716                    kernel: "cumsum",
1717                    push,
1718                    groups: groups1d(rows, 64),
1719                });
1720            }
1721
1722            Op::Reverse { axes } => {
1723                let x = node.inputs[0];
1724                let xd = dims(graph, x);
1725                let rank = xd.len();
1726                let mut shape = [1u32; 6];
1727                let mut flip = [0u32; 6];
1728                for ax in 0..rank {
1729                    shape[ax] = xd[ax] as u32;
1730                    flip[ax] = if axes.contains(&ax) { 1 } else { 0 };
1731                }
1732                let n = numel(&xd);
1733                let push = Push::default()
1734                    .u(n as u32)
1735                    .u(rank as u32)
1736                    .u(off(x))
1737                    .u(off(out))
1738                    .us(&shape)
1739                    .us(&flip)
1740                    .bytes();
1741                steps.push(Step::Gpu {
1742                    kernel: "reverse",
1743                    push,
1744                    groups: groups1d(n, 256),
1745                });
1746            }
1747
1748            Op::ArgMax { axis, .. } | Op::ArgMin { axis, .. } => {
1749                let x = node.inputs[0];
1750                let xd = dims(graph, x);
1751                let ax = (*axis).min(xd.len().saturating_sub(1));
1752                let axis_len = xd[ax];
1753                let outer = numel(&xd[..ax]);
1754                let inner = numel(&xd[ax + 1..]);
1755                let op_id = if matches!(node.op, Op::ArgMax { .. }) {
1756                    0
1757                } else {
1758                    1
1759                };
1760                let push = Push::default()
1761                    .u(outer as u32)
1762                    .u(axis_len as u32)
1763                    .u(inner as u32)
1764                    .u(off(x))
1765                    .u(off(out))
1766                    .u(op_id)
1767                    .bytes();
1768                steps.push(Step::Gpu {
1769                    kernel: "argreduce",
1770                    push,
1771                    groups: groups1d(outer * inner, 256),
1772                });
1773            }
1774
1775            Op::LayerNorm2d { eps } => {
1776                // x [N,C,H,W], gamma, beta [C].
1777                let x = node.inputs[0];
1778                let gamma = node.inputs[1];
1779                let beta = node.inputs[2];
1780                let xd = dims(graph, x);
1781                let (nn, cc, hw) = (xd[0], xd[1], xd[2] * xd[3]);
1782                let positions = nn * hw;
1783                let push = Push::default()
1784                    .u(positions as u32)
1785                    .u(cc as u32)
1786                    .u(hw as u32)
1787                    .u(off(x))
1788                    .u(off(gamma))
1789                    .u(off(beta))
1790                    .u(off(out))
1791                    .f(*eps)
1792                    .bytes();
1793                steps.push(Step::Gpu {
1794                    kernel: "layernorm2d",
1795                    push,
1796                    groups: groups1d(positions, 64),
1797                });
1798            }
1799
1800            Op::Pool {
1801                kind,
1802                kernel_size,
1803                stride,
1804                padding,
1805            } => {
1806                // 2-D pooling on NCHW (kernel_size.len() == 2).
1807                let x = node.inputs[0];
1808                let xd = dims(graph, x);
1809                let od = dims(graph, out);
1810                let (nn, cc, hh, ww) = (xd[0], xd[1], xd[2], xd[3]);
1811                let (oh, ow) = (od[2], od[3]);
1812                let (kh, kw) = (kernel_size[0], kernel_size[1]);
1813                let (sh, sw) = (stride[0], stride[1]);
1814                let (ph, pw) = (padding[0], padding[1]);
1815                let kind_id = reduce_id(*kind); // Max=2, Mean=1
1816                let push = Push::default()
1817                    .us(&[nn as u32, cc as u32, hh as u32, ww as u32])
1818                    .us(&[oh as u32, ow as u32])
1819                    .us(&[
1820                        kh as u32, kw as u32, sh as u32, sw as u32, ph as u32, pw as u32,
1821                    ])
1822                    .u(off(x))
1823                    .u(off(out))
1824                    .u(kind_id)
1825                    .bytes();
1826                steps.push(Step::Gpu {
1827                    kernel: "pool2d",
1828                    push,
1829                    groups: groups1d(nn * cc * oh * ow, 64),
1830                });
1831            }
1832
1833            Op::ResizeNearest2x => {
1834                let x = node.inputs[0];
1835                let xd = dims(graph, x);
1836                let (nn, cc, hh, ww) = (xd[0], xd[1], xd[2], xd[3]);
1837                let push = Push::default()
1838                    .us(&[nn as u32, cc as u32, hh as u32, ww as u32])
1839                    .u(off(x))
1840                    .u(off(out))
1841                    .bytes();
1842                steps.push(Step::Gpu {
1843                    kernel: "resize2x",
1844                    push,
1845                    groups: groups1d(nn * cc * hh * 4 * ww, 256),
1846                });
1847            }
1848
1849            Op::GroupedMatMul => {
1850                // inputs: [input [M,K], weight [E,K,N], expert_idx [M]] → [M,N]
1851                let input = node.inputs[0];
1852                let weight = node.inputs[1];
1853                let idx = node.inputs[2];
1854                let id = dims(graph, input);
1855                let wd = dims(graph, weight);
1856                let (m, k) = (id[id.len() - 2], id[id.len() - 1]);
1857                let n = wd[wd.len() - 1];
1858                let push = Push::default()
1859                    .u(m as u32)
1860                    .u(k as u32)
1861                    .u(n as u32)
1862                    .u(off(input))
1863                    .u(off(weight))
1864                    .u(off(idx))
1865                    .u(off(out))
1866                    .bytes();
1867                steps.push(Step::Gpu {
1868                    kernel: "grouped_matmul",
1869                    push,
1870                    groups: (ceil_div(n, 16), ceil_div(m, 16), 1),
1871                });
1872            }
1873
1874            Op::Conv {
1875                kernel_size,
1876                stride,
1877                padding,
1878                dilation,
1879                groups,
1880            } => {
1881                // 2-D conv (kernel_size.len() == 2). inputs: [x, weight, bias?].
1882                let x = node.inputs[0];
1883                let weight = node.inputs[1];
1884                let has_bias = node.inputs.len() > 2;
1885                let bias = if has_bias { node.inputs[2] } else { weight };
1886                let xd = dims(graph, x);
1887                let od = dims(graph, out);
1888                let (nn, cin, hh, ww) = (xd[0], xd[1], xd[2], xd[3]);
1889                let (cout, oh, ow) = (od[1], od[2], od[3]);
1890                let (kh, kw) = (kernel_size[0], kernel_size[1]);
1891                let (sh, sw) = (stride[0], stride[1]);
1892                let (ph, pw) = (padding[0], padding[1]);
1893                let (dh, dw) = (dilation[0], dilation[1]);
1894                let push = Push::default()
1895                    .us(&[nn as u32, cin as u32, hh as u32, ww as u32])
1896                    .us(&[cout as u32, kh as u32, kw as u32])
1897                    .us(&[oh as u32, ow as u32])
1898                    .us(&[
1899                        sh as u32, sw as u32, ph as u32, pw as u32, dh as u32, dw as u32,
1900                    ])
1901                    .u(*groups as u32)
1902                    .u(if has_bias { 1 } else { 0 })
1903                    .u(off(x))
1904                    .u(off(weight))
1905                    .u(off(bias))
1906                    .u(off(out))
1907                    .bytes();
1908                steps.push(Step::Gpu {
1909                    kernel: "conv2d",
1910                    push,
1911                    groups: groups1d(nn * cout * oh * ow, 64),
1912                });
1913            }
1914
1915            Op::SelectiveScan { state_size } => {
1916                // inputs: [x, delta, a, b, c]; x,delta [B,S,H], a [H,N], b,c [B,S,N]
1917                let x = node.inputs[0];
1918                let delta = node.inputs[1];
1919                let a = node.inputs[2];
1920                let bmat = node.inputs[3];
1921                let cmat = node.inputs[4];
1922                let xd = dims(graph, x);
1923                let (bb, ss, hh) = (xd[0], xd[1], xd[2]);
1924                let nn = *state_size;
1925                let push = Push::default()
1926                    .u(bb as u32)
1927                    .u(ss as u32)
1928                    .u(hh as u32)
1929                    .u(nn as u32)
1930                    .u(off(x))
1931                    .u(off(delta))
1932                    .u(off(a))
1933                    .u(off(bmat))
1934                    .u(off(cmat))
1935                    .u(off(out))
1936                    .bytes();
1937                steps.push(Step::Gpu {
1938                    kernel: "selective_scan",
1939                    push,
1940                    groups: groups1d(bb * hh, 64),
1941                });
1942            }
1943
1944            Op::Im2Col {
1945                kernel_size,
1946                stride,
1947                padding,
1948                dilation,
1949            } => {
1950                // x [N,Cin,H,W] → [N*Ho*Wo, Cin*kH*kW]. out dims give Ho*Wo / cols.
1951                let x = node.inputs[0];
1952                let xd = dims(graph, x);
1953                let (nn, cin, hh, ww) = (xd[0], xd[1], xd[2], xd[3]);
1954                let (kh, kw) = (kernel_size[0], kernel_size[1]);
1955                let (sh, sw) = (stride[0], stride[1]);
1956                let (ph, pw) = (padding[0], padding[1]);
1957                let (dh, dw) = (dilation[0], dilation[1]);
1958                let eff_h = dh * (kh - 1) + 1;
1959                let eff_w = dw * (kw - 1) + 1;
1960                let ho = (hh + 2 * ph - eff_h) / sh + 1;
1961                let wo = (ww + 2 * pw - eff_w) / sw + 1;
1962                let push = Push::default()
1963                    .us(&[nn as u32, cin as u32, hh as u32, ww as u32])
1964                    .us(&[ho as u32, wo as u32])
1965                    .us(&[
1966                        kh as u32, kw as u32, sh as u32, sw as u32, ph as u32, pw as u32,
1967                        dh as u32, dw as u32,
1968                    ])
1969                    .u(off(x))
1970                    .u(off(out))
1971                    .bytes();
1972                steps.push(Step::Gpu {
1973                    kernel: "im2col",
1974                    push,
1975                    groups: groups1d(nn * ho * wo * cin * kh * kw, 256),
1976                });
1977            }
1978
1979            Op::ScatterAdd => {
1980                // updates [U, ...trailing], indices [U] → out [out_dim, ...trailing]
1981                let updates = node.inputs[0];
1982                let indices = node.inputs[1];
1983                let ud = dims(graph, updates);
1984                let od = dims(graph, out);
1985                let num_updates = ud[0];
1986                let trailing = numel(&ud[1..]);
1987                let out_dim = od[0];
1988                let push = Push::default()
1989                    .u(out_dim as u32)
1990                    .u(trailing as u32)
1991                    .u(num_updates as u32)
1992                    .u(off(updates))
1993                    .u(off(indices))
1994                    .u(off(out))
1995                    .bytes();
1996                steps.push(Step::Gpu {
1997                    kernel: "scatter_add",
1998                    push,
1999                    groups: groups1d(out_dim * trailing, 256),
2000                });
2001            }
2002
2003            Op::TopK { k } => {
2004                let x = node.inputs[0];
2005                let xd = dims(graph, x);
2006                let n = *xd.last().unwrap_or(&1);
2007                let rows = numel(&xd) / n.max(1);
2008                let push = Push::default()
2009                    .u(rows as u32)
2010                    .u(n as u32)
2011                    .u(*k as u32)
2012                    .u(off(x))
2013                    .u(off(out))
2014                    .bytes();
2015                steps.push(Step::Gpu {
2016                    kernel: "topk",
2017                    push,
2018                    groups: groups1d(rows, 64),
2019                });
2020            }
2021
2022            // GGUF K-quant dequant + matmul. Decode GEMV (m == 1) for the
2023            // Q4_K / Q6_K schemes runs natively; everything else (prefill
2024            // m > 1, other GGUF schemes) keeps the CPU host-fallback path.
2025            Op::DequantMatMul { scheme } => {
2026                use rlx_ir::quant::QuantScheme;
2027                let x = node.inputs[0];
2028                let xd = dims(graph, x);
2029                let od = dims(graph, out);
2030                let n = *od.last().unwrap_or(&1);
2031                let m = numel(&od) / n.max(1);
2032                let k = numel(&xd) / m.max(1);
2033                let gpu_scheme = match scheme {
2034                    QuantScheme::GgufQ4K => Some(0u32),
2035                    QuantScheme::GgufQ6K => Some(1u32),
2036                    _ => None,
2037                };
2038                match gpu_scheme {
2039                    Some(sc) if m == 1 && k.is_multiple_of(256) && n >= 1 => {
2040                        let w = node.inputs[1];
2041                        let push = Push::default()
2042                            .u(n as u32)
2043                            .u(k as u32)
2044                            .u(off(x))
2045                            .u(off(w))
2046                            .u(off(out))
2047                            .u(sc)
2048                            .bytes();
2049                        steps.push(Step::Gpu {
2050                            kernel: "dequant_matmul",
2051                            push,
2052                            groups: groups1d(n, 64),
2053                        });
2054                    }
2055                    _ => {
2056                        steps.push(Step::Host {
2057                            op: node.op.clone(),
2058                            out: node.id,
2059                            out_shape: node.shape.clone(),
2060                            inputs: node.inputs.clone(),
2061                        });
2062                    }
2063                }
2064            }
2065
2066            op if is_host_fallback(op) => {
2067                steps.push(Step::Host {
2068                    op: node.op.clone(),
2069                    out: node.id,
2070                    out_shape: node.shape.clone(),
2071                    inputs: node.inputs.clone(),
2072                });
2073            }
2074
2075            other => panic!(
2076                "rlx-vulkan: op {:?} reached the scheduler but has no kernel \
2077                 (should have been rejected at legalize). Pin this graph to Device::Cpu.",
2078                other.kind()
2079            ),
2080        }
2081
2082        // Attach the node's memory footprint to each Step it just produced. GPU
2083        // steps read the node's input slots and write its output slot; host
2084        // steps get an entry too (kept parallel to `steps`, unused at record
2085        // time since host ops sit on their own segment boundary).
2086        let added = steps.len() - before;
2087        if added > 0 {
2088            let reads: Vec<u32> = node
2089                .inputs
2090                .iter()
2091                .filter(|&&id| arena.has(id))
2092                .map(|&id| arena.elem_offset(id))
2093                .collect();
2094            let write = if arena.has(out) {
2095                arena.elem_offset(out)
2096            } else {
2097                0
2098            };
2099            for _ in 0..added {
2100                deps.push(StepDep {
2101                    reads: reads.clone(),
2102                    write,
2103                });
2104            }
2105        }
2106    }
2107    (steps, deps)
2108}