Skip to main content

rlx_cpu/
thunk.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Thunks — pre-compiled kernel dispatch with zero per-call overhead.
17//!
18//! At compile time, the graph is lowered into a flat `Vec<Thunk>` where each
19//! thunk holds pre-computed arena offsets, dimensions, and kernel type.
20//! At runtime, the executor just iterates thunks and calls kernels directly.
21
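22// Edition 2024 warns (`unsafe_op_in_unsafe_fn`) on unsafe ops in an `unsafe fn` body that are not wrapped in an explicit `unsafe {}` block; `sl`/`sl_mut` stay `unsafe fn`, so keep the pre-2024 behavior for this module.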
23#![allow(unsafe_op_in_unsafe_fn)]
24//! No match dispatch, no HashMap lookup, no dimension computation.
25
26use crate::arena::Arena;
27use crate::op_registry::CpuKernel;
28use rlx_ir::op::{Activation, BinaryOp, CmpOp, ReduceOp};
29use rlx_ir::{Graph, NodeId, Op, Shape};
30use std::collections::HashMap;
31use std::sync::Arc;
32
33/// A pre-compiled kernel call with all args resolved to arena offsets.
34#[derive(Clone)]
35pub enum Thunk {
36    /// Skip (Input/Param already in arena)
37    Nop,
38    /// C = A @ B (BLAS sgemm)
39    Sgemm {
40        a: usize,
41        b: usize,
42        c: usize,
43        m: u32,
44        k: u32,
45        n: u32,
46    },
47    /// f64 dense solve `x = A⁻¹·b` via LAPACK dgesv.
48    /// `a`, `b`, `x` are byte-offsets into the arena. `n` is the matrix
49    /// dimension; `nrhs` is 1 for a vector RHS or >1 for multi-RHS.
50    /// The kernel materializes scratch copies of A and b internally
51    /// (LAPACK overwrites both with LU factors and solution).
52    DenseSolveF64 {
53        a: usize,
54        b: usize,
55        x: usize,
56        n: u32,
57        nrhs: u32,
58    },
59    /// f32 twin of `DenseSolveF64`. Calls LAPACK `sgesv` (or the
60    /// no-blas Rust fallback). Same arena byte-offset contract.
61    DenseSolveF32 {
62        a: usize,
63        b: usize,
64        x: usize,
65        n: u32,
66        nrhs: u32,
67    },
68    /// Batched f64 dense solve. `a`, `b`, `x` are byte-offsets to
69    /// the leading slice; `batch` is the number of independent
70    /// systems. Per slice the kernel calls `dgesv(A_i, b_i, n, nrhs)`
71    /// — LAPACK has no batched dgesv on Accelerate, so we loop.
72    BatchedDenseSolveF64 {
73        a: usize,
74        b: usize,
75        x: usize,
76        batch: u32,
77        n: u32,
78        nrhs: u32,
79    },
80    /// Batched f32 dense solve — loop of `sgesv` per batch slice.
81    BatchedDenseSolveF32 {
82        a: usize,
83        b: usize,
84        x: usize,
85        batch: u32,
86        n: u32,
87        nrhs: u32,
88    },
89    /// Batched f64 matmul. Both inputs and output have a leading
90    /// batch axis of size `batch`. Per-batch independent dgemm:
91    /// `C[i] = A[i] @ B[i]` for `i in 0..batch`. Used by VJP rules
92    /// that emit per-batch outer products (e.g., BatchedDenseSolve
93    /// VJP). The unbatched `Dgemm` thunk handles the rank-2 case.
94    BatchedDgemmF64 {
95        a: usize,
96        b: usize,
97        c: usize,
98        batch: u32,
99        m: u32,
100        k: u32,
101        n: u32,
102    },
103    /// Batched f32 matmul — same loop-per-batch shape as
104    /// `BatchedDgemmF64` but calling `sgemm`. Needed for attention
105    /// patterns where both operands carry a batch dim (e.g. q@k^T
106    /// and attn@v in decomposed self-attention). The 2-D `Sgemm`
107    /// flatten trick is wrong in that case because it treats `b` as
108    /// a single shared RHS across every batch.
109    BatchedSgemm {
110        a: usize,
111        b: usize,
112        c: usize,
113        batch: u32,
114        m: u32,
115        k: u32,
116        n: u32,
117    },
118    /// C = A @ B via Accelerate cblas_dgemm. Mirror of `Sgemm` at f64.
119    Dgemm {
120        a: usize,
121        b: usize,
122        c: usize,
123        m: u32,
124        k: u32,
125        n: u32,
126    },
127    /// f64 N-D index walk used for both `Op::Transpose` and `Op::Expand`.
128    /// `in_strides` carries 0s on broadcast axes (Expand) or permuted
129    /// strides (Transpose). Mirror of `Thunk::Transpose` at f64.
130    TransposeF64 {
131        src: usize,
132        dst: usize,
133        in_total: u32,
134        out_dims: Vec<u32>,
135        in_strides: Vec<u32>,
136    },
137    /// f64 element-wise activation. Single-input, single-output. The
138    /// kernel always reads from `src` and writes to `dst`, so it works
139    /// whether or not the planner aliased the two slots.
140    ActivationF64 {
141        src: usize,
142        dst: usize,
143        len: u32,
144        kind: Activation,
145    },
146    /// Element-wise complex squared-magnitude: `|z|² = re² + im²`.
147    /// Reads the C64 input at `src` as `2·len` f32 ([re,im] pairs),
148    /// writes `len` f32 to `dst`.
149    ComplexNormSqF32 {
150        src: usize,
151        dst: usize,
152        /// Logical element count (number of complex values).
153        len: u32,
154    },
155    /// Wirtinger backward for [`Thunk::ComplexNormSqF32`]: `dz = g · z` as
156    /// C64. Reads `z` at `2·len` f32 + `g` at `len` f32; writes
157    /// `2·len` f32 to `dz`.
158    ComplexNormSqBackwardF32 {
159        z: usize,
160        g: usize,
161        dz: usize,
162        len: u32,
163    },
164    /// Element-wise C64 conjugate: writes `[re_i, -im_i]` per element.
165    /// Layout matches the rest of C64 here ([re,im] interleaved f32).
166    ConjugateC64 { src: usize, dst: usize, len: u32 },
167    /// C64 element-wise activation. Only kinds with well-defined
168    /// complex extensions are supported: Neg, Exp, Log, Sqrt.
169    /// Everything else (Sigmoid, Tanh, Relu, Abs, Sin/Cos/Tan/Atan,
170    /// Round, GeLU family) is rejected at lowering — those don't have
171    /// single natural complex definitions. `len` is the **complex
172    /// element count** (the f32 buffer holds `2·len` floats).
173    ActivationC64 {
174        src: usize,
175        dst: usize,
176        len: u32,
177        kind: Activation,
178    },
179    /// f64 contiguous reduction along a single axis range. Layout
180    /// `[outer, reduced, inner]` in memory; output is `[outer, inner]`.
181    /// Sum only for now (Mean composes via 1/N multiply post-pass).
182    ReduceSumF64 {
183        src: usize,
184        dst: usize,
185        outer: u32,
186        reduced: u32,
187        inner: u32,
188    },
189    /// f64 plain copy (Reshape / Cast at the same dtype). Mirrors `Copy`
190    /// but at 8 bytes per element.
191    CopyF64 { src: usize, dst: usize, len: u32 },
192    /// f64 element-wise binary with broadcast. `len`/`lhs_len`/`rhs_len`
193    /// are element counts; kernel does `out[i] = lhs[i % lhs_len] OP rhs[i % rhs_len]`.
194    /// Mirror of `BinaryFull` at 8 bytes per element.
195    BinaryFullF64 {
196        lhs: usize,
197        rhs: usize,
198        dst: usize,
199        len: u32,
200        lhs_len: u32,
201        rhs_len: u32,
202        op: BinaryOp,
203        /// Output shape dims (row-major). Empty in the fast path. See
204        /// `BinaryFull` doc for the broadcast convention.
205        out_dims_bcast: Vec<u32>,
206        bcast_lhs_strides: Vec<u32>,
207        bcast_rhs_strides: Vec<u32>,
208    },
209    /// f64 concat — byte-for-byte mirror of `Concat` but copies
210    /// 8 bytes per element. Element-counted offsets/strides match
211    /// the f32 variant; the executor scales by elem_size internally.
212    ConcatF64 {
213        dst: usize,
214        outer: u32,
215        inner: u32,
216        total_axis: u32,
217        inputs: Vec<(usize, u32)>,
218    },
219    /// C64 element-wise binary with broadcast. Same `len` /
220    /// `lhs_len` / `rhs_len` semantics as `BinaryFull` but each
221    /// "element" is one complex value (8 bytes = `[re, im]` as two
222    /// f32s). The executor reads the underlying f32 buffer at
223    /// `2·len` floats and walks element pairs. Supports Add / Sub /
224    /// Mul / Div; Max / Min / Pow have no single natural complex
225    /// definition and panic at lowering.
226    BinaryFullC64 {
227        lhs: usize,
228        rhs: usize,
229        dst: usize,
230        /// Complex element count (NOT f32 count). f32 buffer length
231        /// is `2·len`.
232        len: u32,
233        lhs_len: u32,
234        rhs_len: u32,
235        op: BinaryOp,
236        out_dims_bcast: Vec<u32>,
237        bcast_lhs_strides: Vec<u32>,
238        bcast_rhs_strides: Vec<u32>,
239    },
240    /// Bounded scan. Holds a recursively-compiled body schedule + a
241    /// pre-initialized body arena snapshot (constants filled). Each
242    /// outer execution clones the snapshot, copies the carry-in slot
243    /// from the outer arena, runs the body schedule `length` times,
244    /// then writes the final carry to the outer arena.
245    ///
246    /// Single-carry MVP — body has exactly one carry Input and one output,
247    /// both same shape and dtype. Any `xs_inputs` / `bcast_inputs` occupy additional body Input slots.
248    Scan {
249        body: Arc<ThunkSchedule>,
250        body_init: Arc<Vec<u8>>, // pristine body arena bytes
251        body_input_off: usize,   // byte offset of the body's carry-Input slot
252        body_output_off: usize,  // byte offset of the body's output slot
253        outer_init_off: usize,   // outer-arena offset of the initial carry
254        outer_final_off: usize,  // outer-arena offset of the final carry / trajectory base
255        length: u32,
256        carry_bytes: u32, // carry size in bytes
257        /// When true, write each step's carry to the outer arena at
258        /// offset `outer_final_off + t * carry_bytes`, producing a
259        /// `[length, *carry]` stacked trajectory. When false, only the
260        /// final carry lands at `outer_final_off`.
261        save_trajectory: bool,
262        /// Per-step `xs` inputs. For each: (body_x_input_off,
263        /// outer_xs_base_off, per_step_bytes). Per iteration `t`, the
264        /// executor copies `outer_xs_base_off + t * per_step_bytes`
265        /// into `body_x_input_off`. Empty when the scan has no xs.
266        xs_inputs: Arc<Vec<(usize, usize, u32)>>,
267        /// Broadcast inputs — values constant across iterations. For
268        /// each: (body_bcast_input_off, outer_bcast_off, total_bytes).
269        /// Filled into `body_buf` ONCE before the scan loop starts
270        /// (xs in contrast are re-filled every iteration). Empty when
271        /// the scan has no bcasts.
272        bcast_inputs: Arc<Vec<(usize, usize, u32)>>,
273        /// Number of trajectory checkpoints (when `save_trajectory`).
274        /// `0` or `length` ⇒ save every iteration. Otherwise save only
275        /// `K` rows at indices `floor((k+1) * length / K) - 1` for
276        /// `k in 0..K`. Last index is always `length-1` so the final
277        /// carry is always cached.
278        num_checkpoints: u32,
279    },
280
281    /// Reverse-mode AD companion to `Thunk::Scan`. Walks `t = length-1
282    /// .. 0`, threading `dcarry` through the body's VJP. Per iteration:
283    /// writes `carry_t` (from outer init or trajectory), each `xs_i[t]`
284    /// slice, and the current `dcarry` into the body_vjp's Input
285    /// slots, runs body_vjp, reads new `dcarry` from its single output.
286    /// f32 and f64 carries are supported: the trajectory-mode upstream
287    /// accumulation (`dcarry += upstream[t]`) dispatches on `carry_elem_size`.
288    ScanBackward {
289        body_vjp: Arc<ThunkSchedule>,
290        body_init: Arc<Vec<u8>>,
291        body_carry_in_off: usize, // body_vjp's mirrored body-carry-input slot
292        body_x_offs: Arc<Vec<usize>>, // body_vjp's mirrored x_t_i Input slots, in xs order
293        body_d_output_off: usize, // body_vjp's "d_output" Input slot
294        body_dcarry_out_off: usize, // body_vjp's gradient output
295        outer_init_off: usize,    // original init carry
296        outer_traj_off: usize,    // [length-or-K, *carry] trajectory base
297        outer_upstream_off: usize, // upstream gradient (carry shape, or [length, *carry])
298        /// Per-xs entries: (outer_xs_base_off, per_step_bytes). Read
299        /// `xs_i[t]` from `outer_xs_base_off + t * per_step_bytes`.
300        outer_xs_offs: Arc<Vec<(usize, u32)>>,
301        outer_dinit_off: usize, // output: dinit
302        length: u32,
303        carry_bytes: u32,
304        /// Bytes per element in the carry tensor: 4 for f32, 8 for f64.
305        /// Used to dispatch the trajectory-mode upstream accumulation
306        /// kernel (the dcarry += upstream\[t\] add must use the right
307        /// floating-point type — a hard-coded f64 add silently does
308        /// nothing for an f32 carry whose `cb` isn't divisible by 8).
309        carry_elem_size: u32,
310        save_trajectory: bool, // true → upstream is per-step; false → just final
311        /// Recursive checkpointing config. `0` or `length` ⇒ full
312        /// trajectory cached, no recompute (existing behavior).
313        /// `0 < K < length` ⇒ trajectory has only K rows; the executor
314        /// recomputes intermediate carries via `forward_body` between
315        /// checkpoints. Memory: O(K · carry_bytes); time: O(length).
316        num_checkpoints: u32,
317        /// Forward body schedule (same compiled body as the forward
318        /// Op::Scan), used for recompute when `num_checkpoints` is
319        /// active. `None` for the All strategy.
320        forward_body: Option<Arc<ThunkSchedule>>,
321        /// Pristine forward body arena bytes (constants filled).
322        forward_body_init: Option<Arc<Vec<u8>>>,
323        /// Forward body's carry-Input and output slot offsets — needed
324        /// to seed/read the body during recompute.
325        forward_body_carry_in_off: usize,
326        forward_body_output_off: usize,
327        /// Forward body's per-step xs Input slots (one per outer xs).
328        /// Same indexing convention as `body_x_offs`.
329        forward_body_x_offs: Arc<Vec<usize>>,
330    },
331
332    /// Companion to `ScanBackward` that materializes one stacked
333    /// `dxs_i`. Same backward loop; per iteration, after running
334    /// body_vjp, copies its `body_dxs_out_off` slot into the outer
335    /// arena at `outer_dxs_off + t * per_step_bytes`. dcarry threading
336    /// is identical — we still need it for the body_vjp recurrence
337    /// even though we don't write it back to the outer arena.
338    ScanBackwardXs {
339        body_vjp: Arc<ThunkSchedule>,
340        body_init: Arc<Vec<u8>>,
341        body_carry_in_off: usize,
342        body_x_offs: Arc<Vec<usize>>,
343        body_d_output_off: usize,
344        body_dcarry_out_off: usize,
345        body_dxs_out_off: usize, // the body_vjp output we extract per step
346        outer_init_off: usize,
347        outer_traj_off: usize,
348        outer_upstream_off: usize,
349        outer_xs_offs: Arc<Vec<(usize, u32)>>,
350        outer_dxs_off: usize, // base of the stacked [length, *per_step] output
351        length: u32,
352        carry_bytes: u32,
353        /// Same role as `Thunk::ScanBackward::carry_elem_size`.
354        carry_elem_size: u32,
355        per_step_bytes: u32, // bytes per row of the dxs output
356        save_trajectory: bool,
357        /// Recursive checkpointing config. Same semantics as
358        /// `Thunk::ScanBackward::num_checkpoints` — `0` or `length`
359        /// means "save every step's carry"; `0 < K < length` means
360        /// the trajectory has only K rows and the executor recomputes
361        /// intermediate carries via `forward_body` (which must be
362        /// `Some`). Implemented via segment-cached recompute,
363        /// mirroring the `ScanBackward` path.
364        num_checkpoints: u32,
365        forward_body: Option<Arc<ThunkSchedule>>,
366        forward_body_init: Option<Arc<Vec<u8>>>,
367        forward_body_carry_in_off: usize,
368        forward_body_output_off: usize,
369        forward_body_x_offs: Arc<Vec<usize>>,
370    },
371    /// User-defined sub-graph (`Op::CustomFn`) — runs `fwd_body` once.
372    /// Per execution: clone `body_init`, copy each primal input from the
373    /// outer arena into its body Input slot, run the body schedule,
374    /// copy the body's single output back to the outer arena.
375    CustomFn {
376        body: Arc<ThunkSchedule>,
377        body_init: Arc<Vec<u8>>,
378        /// Per primal input: (body_input_off, outer_input_off, bytes).
379        inputs: Arc<Vec<(usize, usize, u32)>>,
380        body_output_off: usize,
381        outer_output_off: usize,
382        out_bytes: u32,
383    },
384    /// C = A @ B; C += bias; C = act(C)
385    FusedMmBiasAct {
386        a: usize,
387        w: usize,
388        bias: usize,
389        c: usize,
390        m: u32,
391        k: u32,
392        n: u32,
393        act: Option<Activation>,
394    },
395    /// out = LN(x + residual + bias, gamma, beta)
396    FusedResidualLN {
397        x: usize,
398        res: usize,
399        bias: usize,
400        g: usize,
401        b: usize,
402        out: usize,
403        rows: u32,
404        h: u32,
405        eps: f32,
406        has_bias: bool,
407    },
408    /// out = RmsNorm(x + residual + bias, gamma, beta)
409    FusedResidualRmsNorm {
410        x: usize,
411        res: usize,
412        bias: usize,
413        g: usize,
414        b: usize,
415        out: usize,
416        rows: u32,
417        h: u32,
418        eps: f32,
419        has_bias: bool,
420    },
421    /// out = bias_add(data, bias, m, n) for Binary::Add with broadcast
422    BiasAdd {
423        src: usize,
424        bias: usize,
425        dst: usize,
426        m: u32,
427        n: u32,
428    },
429    /// Element-wise binary op with NumPy-style broadcast.
430    ///
431    /// Fast path (`lhs_len == rhs_len == len`): plain element-wise loop,
432    /// SIMD-vectorized on aarch64 for `Add`/`Mul`. `bcast_*` fields
433    /// are unused.
434    ///
435    /// Broadcast path: uses `out_dims_bcast` + `bcast_lhs_strides` +
436    /// `bcast_rhs_strides` to compute per-cell indices into each
437    /// operand. The strides are precomputed at thunk-construction
438    /// time from the operands' true shapes (with stride 0 on any axis
439    /// where the operand has size 1). This is the only correct way
440    /// to handle bidirectional broadcasts like `[N, 1] op [1, S]
441    /// → [N, S]`, which simple `i % lhs_len` modulo indexing maps to
442    /// wrong cells.
443    BinaryFull {
444        lhs: usize,
445        rhs: usize,
446        dst: usize,
447        len: u32,
448        lhs_len: u32,
449        rhs_len: u32,
450        op: BinaryOp,
451        /// Output shape dims (row-major). Empty in the fast path.
452        out_dims_bcast: Vec<u32>,
453        /// Per-dim stride into `lhs` (0 where lhs broadcasts).
454        bcast_lhs_strides: Vec<u32>,
455        /// Per-dim stride into `rhs`.
456        bcast_rhs_strides: Vec<u32>,
457    },
458    /// Activation in-place
459    ActivationInPlace {
460        data: usize,
461        len: u32,
462        act: Activation,
463    },
464    /// Gather axis=0: table\[idx\] → out
465    Gather {
466        table: usize,
467        table_len: u32,
468        idx: usize,
469        dst: usize,
470        num_idx: u32,
471        trailing: u32,
472    },
473    /// Narrow: copy slice (`elem_bytes` = source element size: 4 for f32, 8 for f64).
474    Narrow {
475        src: usize,
476        dst: usize,
477        outer: u32,
478        src_stride: u32,
479        dst_stride: u32,
480        inner: u32,
481        elem_bytes: u8,
482    },
483    /// Copy (reshape, expand)
484    Copy { src: usize, dst: usize, len: u32 },
485    /// LayerNorm standalone
486    LayerNorm {
487        src: usize,
488        g: usize,
489        b: usize,
490        dst: usize,
491        rows: u32,
492        h: u32,
493        eps: f32,
494    },
495    /// GroupNorm on NCHW `[N,C,H,W]`.
496    GroupNorm {
497        src: usize,
498        g: usize,
499        b: usize,
500        dst: usize,
501        n: u32,
502        c: u32,
503        h: u32,
504        w: u32,
505        num_groups: u32,
506        eps: f32,
507    },
508    /// LayerNorm2d on NCHW (SAM / candle semantics).
509    LayerNorm2d {
510        src: usize,
511        g: usize,
512        b: usize,
513        dst: usize,
514        n: u32,
515        c: u32,
516        h: u32,
517        w: u32,
518        eps: f32,
519    },
520    /// ConvTranspose2d on NCHW.
521    ConvTranspose2d {
522        src: usize,
523        weight: usize,
524        dst: usize,
525        n: u32,
526        c_in: u32,
527        h: u32,
528        w_in: u32,
529        c_out: u32,
530        h_out: u32,
531        w_out: u32,
532        kh: u32,
533        kw: u32,
534        sh: u32,
535        sw: u32,
536        ph: u32,
537        pw: u32,
538        dh: u32,
539        dw: u32,
540        groups: u32,
541    },
542    /// Nearest 2× upsample on NCHW (per-batch slice).
543    ResizeNearest2x {
544        src: usize,
545        dst: usize,
546        n: u32,
547        c: u32,
548        h: u32,
549        w: u32,
550    },
551    /// SAM2 axial 2-D RoPE on `[batch, seq, num_heads * head_dim]`.
552    AxialRope2d {
553        src: usize,
554        dst: usize,
555        batch: u32,
556        seq: u32,
557        hidden: u32,
558        end_x: u32,
559        end_y: u32,
560        head_dim: u32,
561        num_heads: u32,
562        theta: f32,
563        repeat_factor: u32,
564    },
565    /// RMSNorm: out = (x / sqrt(mean(x^2) + eps)) * gamma + beta. No mean
566    /// subtraction, hence cheaper than LayerNorm. Used by Llama-class models.
567    RmsNorm {
568        src: usize,
569        g: usize,
570        b: usize,
571        dst: usize,
572        rows: u32,
573        h: u32,
574        eps: f32,
575    },
576    /// Softmax
577    Softmax { data: usize, rows: u32, cols: u32 },
578    /// Inclusive (or exclusive) cumulative sum along the last axis
579    /// (callers pre-flatten higher-dim cumsums via reshape views).
580    Cumsum {
581        src: usize,
582        dst: usize,
583        rows: u32,
584        cols: u32,
585        exclusive: bool,
586    },
587    /// Mamba-style selective scan (plan #15).
588    /// Inputs: x, delta \[b,s,h\], a \[h,n\], b \[b,s,n\], c \[b,s,n\].
589    /// Output: y \[b,s,h\]. State h carries through the seq.
590    SelectiveScan {
591        x: usize,
592        delta: usize,
593        a: usize,
594        b: usize,
595        c: usize,
596        dst: usize,
597        batch: u32,
598        seq: u32,
599        hidden: u32,
600        state_size: u32,
601    },
602
603    /// Gated DeltaNet linear-attention scan (Qwen3.5/3.6 trunk).
604    /// Inputs: q, k, v `[b, s, h, n]`; g, beta `[b, s, h]`. Output:
605    /// `[b, s, h, n]`. See `Op::GatedDeltaNet` for math.
606    GatedDeltaNet {
607        q: usize,
608        k: usize,
609        v: usize,
610        g: usize,
611        beta: usize,
612        /// When non-zero, load initial `[b, h, n, n]` state and write
613        /// the final state back in place after the scan.
614        state: usize,
615        dst: usize,
616        batch: u32,
617        seq: u32,
618        heads: u32,
619        state_size: u32,
620    },
621
622    /// 1×1 conv fast path (plan #26). The general Conv2D thunk
623    /// runs the textbook 7-deep loop; a 1×1 stride-1 padding-0
624    /// groups-1 conv is mathematically a per-batch matmul, and
625    /// dispatching it through BLAS is 3-10× faster than the
626    /// scalar nest. Common case: ViT patch-projection follow-on,
627    /// transformer "expert" reductions in some MoE designs.
628    ///
629    /// Per batch: weight `[c_out, c_in]` × input `[c_in, h*w]`
630    ///         = output `[c_out, h*w]`.
631    Conv2D1x1 {
632        src: usize,
633        weight: usize,
634        dst: usize,
635        n: u32,
636        c_in: u32,
637        c_out: u32,
638        hw: u32,
639    },
640
641    /// Fused dequant + matmul (plan #5). Today supports
642    /// `QuantScheme::Int8Block` (symmetric); other schemes panic
643    /// at lowering time with a clear message until kernels are added.
644    DequantMatMul {
645        x: usize,
646        w_q: usize,   // packed i8 bytes for Int8 schemes
647        scale: usize, // [k/block, n] f32 scale
648        zp: usize,    // [k/block, n] f32 zero-point (0 for sym)
649        dst: usize,
650        m: u32,
651        k: u32,
652        n: u32,
653        block_size: u32,
654        is_asymmetric: bool,
655    },
656
657    /// GGUF-format dequant + matmul. Weight is a packed byte tensor
658    /// in one of the K-quant super-block layouts (Q4_K, Q5_K, Q6_K,
659    /// Q8_K). Scales / mins live inside the packed bytes — no
660    /// side-channel scale tensor.
661    ///
662    /// Today this is a "dequant-to-scratch then sgemm" kernel — it
663    /// keeps the *arena* memory footprint down (weights stay packed)
664    /// but the dequant itself happens per matmul. A future fully
665    /// fused tile-streaming kernel would close the compute gap.
666    DequantMatMulGguf {
667        x: usize,   // f32 activations [m, k]
668        w_q: usize, // packed weight bytes (k*n elements packed)
669        dst: usize, // f32 output [m, n]
670        m: u32,
671        k: u32,
672        n: u32,
673        scheme: rlx_ir::quant::QuantScheme,
674    },
675
676    /// Int4 block dequant + matmul (packed nibbles, side scale/zp).
677    DequantMatMulInt4 {
678        x: usize,
679        w_q: usize,
680        scale: usize,
681        zp: usize,
682        dst: usize,
683        m: u32,
684        k: u32,
685        n: u32,
686        block_size: u32,
687        is_asymmetric: bool,
688    },
689
690    /// FP8 dequant + matmul (per-tensor or per-column scale).
691    DequantMatMulFp8 {
692        x: usize,
693        w_q: usize,
694        scale: usize,
695        dst: usize,
696        m: u32,
697        k: u32,
698        n: u32,
699        e5m2: bool,
700    },
701
702    /// NVFP4 (E2M1) block dequant + matmul — 16-wide groups, FP8 scales.
703    DequantMatMulNvfp4 {
704        x: usize,
705        w_q: usize,
706        scale: usize,
707        global_scale: usize,
708        dst: usize,
709        m: u32,
710        k: u32,
711        n: u32,
712    },
713
714    /// Fused LoRA matmul (plan #9): out = x·W + scale * (x·A)·B.
715    /// `r` is the LoRA rank (typically 4-64) — the rank-r
716    /// intermediate `x·A` lives in scratch, never on the arena.
717    LoraMatMul {
718        x: usize,
719        w: usize,
720        a: usize,
721        b: usize,
722        dst: usize,
723        m: u32,
724        k: u32,
725        n: u32,
726        r: u32,
727        scale: f32,
728    },
729    /// Fused sample: logits [batch, vocab] → token ids \[batch\].
730    /// See Op::Sample. Output values are f32-encoded usize indices
731    /// (matches the rest of the IR's "ids as f32" convention).
732    Sample {
733        logits: usize,
734        dst: usize,
735        batch: u32,
736        vocab: u32,
737        top_k: u32,       // 0 = disabled
738        top_p: f32,       // 1.0 = disabled
739        temperature: f32, // 1.0 = neutral
740        seed: u64,
741    },
742    /// Attention SDPA. `mask` is the offset of the optional mask tensor
743    /// (only meaningful when `mask_kind == MaskKind::Custom`); other
744    /// kinds synthesize the mask in-kernel.
745    ///
746    /// Q/K/V each carry a `_row_stride` (elements per source row).
747    /// Defaults to `heads * head_dim` — matches the standalone
748    /// "Q/K/V are their own contiguous buffers" case. The Narrow→
749    /// Attention fusion below rewrites these to the parent QKV stride
750    /// (typically `3 * heads * head_dim`) so the kernel reads QKV
751    /// directly without materializing the per-head buffers (plan #46).
752    Attention {
753        q: usize,
754        k: usize,
755        v: usize,
756        mask: usize,
757        out: usize,
758        batch: u32,
759        /// Query sequence length.
760        seq: u32,
761        /// Key/value sequence length. Differs from `seq` during cached decode.
762        kv_seq: u32,
763        heads: u32,
764        head_dim: u32,
765        mask_kind: rlx_ir::op::MaskKind,
766        q_row_stride: u32,
767        k_row_stride: u32,
768        v_row_stride: u32,
769        /// Memory layout flag. `false` (the historical default) →
770        /// `[B, S, H, D]` row-major: per-head offset is
771        /// `bi*S*H*D + si*H*D + hi*D`. `true` → `[B, H, S, D]`
772        /// (head-major), matching the convention used by rlx-cuda /
773        /// rlx-rocm / rlx-tpu: per-head offset is
774        /// `bi*H*S*D + hi*S*D + si*D`. Detected at lowering time
775        /// from the input shape vs `num_heads` / `head_dim`.
776        bhsd: bool,
777    },
778    /// [`Op::AttentionBackward`] — emits dQ, dK, or dV (see `wrt`).
779    AttentionBackward {
780        q: usize,
781        k: usize,
782        v: usize,
783        dy: usize,
784        mask: usize,
785        out: usize,
786        batch: u32,
787        seq: u32,
788        kv_seq: u32,
789        heads: u32,
790        head_dim: u32,
791        mask_kind: rlx_ir::op::MaskKind,
792        wrt: rlx_ir::op::AttentionBwdWrt,
793        bhsd: bool,
794    },
795    /// RoPE (rotary position embeddings).
796    /// `src_row_stride` is elements per source row (defaults to `hidden`
797    /// for the standalone case; set to `qkv_axis * inner` when the
798    /// thunk fusion pass below rewires Rope to read directly from the
799    /// fused QKV buffer — plan #45).
800    Rope {
801        src: usize,
802        cos: usize,
803        sin: usize,
804        dst: usize,
805        batch: u32,
806        seq: u32,
807        hidden: u32,
808        head_dim: u32,
809        n_rot: u32,
810        cos_len: u32,
811        src_row_stride: u32,
812    },
813    /// Fused attention block: QKV proj → split → \[RoPE\] → SDPA → output proj.
814    /// All intermediates stay in L1 cache. Zero arena writes between ops.
815    FusedAttnBlock {
816        hidden: usize,
817        qkv_w: usize,
818        out_w: usize,
819        mask: usize,
820        out: usize,
821        qkv_b: usize,
822        out_b: usize, // 0 = no bias
823        cos: usize,
824        sin: usize,
825        cos_len: u32, // 0 = no RoPE
826        batch: u32,
827        seq: u32,
828        hs: u32,
829        nh: u32,
830        dh: u32,
831        has_bias: bool,
832        has_rope: bool,
833    },
834    /// Fused ENTIRE transformer layer: attention + residual + LN + FFN + residual + LN.
835    /// Combines ~10 thunks into 1. All intermediates on stack. Zero arena traffic.
836    FusedBertLayer {
837        // attention
838        hidden: usize,
839        qkv_w: usize,
840        qkv_b: usize,
841        out_w: usize,
842        out_b: usize,
843        mask: usize,
844        // LN1
845        ln1_g: usize,
846        ln1_b: usize,
847        eps1: f32,
848        // FFN (GELU)
849        fc1_w: usize,
850        fc1_b: usize,
851        fc2_w: usize,
852        fc2_b: usize,
853        // LN2
854        ln2_g: usize,
855        ln2_b: usize,
856        eps2: f32,
857        // output
858        out: usize,
859        // dims
860        batch: u32,
861        seq: u32,
862        hs: u32,
863        nh: u32,
864        dh: u32,
865        int_dim: u32,
866    },
867    /// Fused Nomic transformer layer: attention+RoPE + residual + LN + SwiGLU FFN + residual + LN.
868    FusedNomicLayer {
869        hidden: usize,
870        qkv_w: usize,
871        out_w: usize,
872        mask: usize,
873        cos: usize,
874        sin: usize,
875        cos_len: u32,
876        ln1_g: usize,
877        ln1_b: usize,
878        eps1: f32,
879        fc11_w: usize,
880        fc12_w: usize,
881        fc2_w: usize,
882        ln2_g: usize,
883        ln2_b: usize,
884        eps2: f32,
885        out: usize,
886        batch: u32,
887        seq: u32,
888        hs: u32,
889        nh: u32,
890        dh: u32,
891        int_dim: u32,
892    },
893    /// Fused SwiGLU: out\[r,i\] = x\[r,i\] * silu(x[r, n_half+i]).
894    /// Input: [outer, 2*n_half] — concatenated up||gate per row.
895    /// Output: [outer, n_half].
896    FusedSwiGLU {
897        src: usize,
898        dst: usize,
899        n_half: u32,
900        total: u32,
901        gate_first: bool,
902    },
903    /// Concat along an axis: output[outer, axis, inner] = inputs concatenated.
904    /// Each entry of `inputs` is (src_offset, axis_len_for_that_input) in u32
905    /// elements. `outer`, `inner`, and `total_axis_len` are pre-computed
906    /// at compile time to avoid per-run shape work.
907    Concat {
908        dst: usize,
909        outer: u32,
910        inner: u32,
911        total_axis: u32,
912        inputs: Vec<(usize, u32)>,
913    },
914    /// Element-wise comparison: out = (lhs CMP rhs) ? 1.0 : 0.0
915    Compare {
916        lhs: usize,
917        rhs: usize,
918        dst: usize,
919        len: u32,
920        op: CmpOp,
921    },
922    /// Reduction along a contiguous range of axes. Input layout (after
923    /// shape decomposition) is `[outer, reduced, inner]`; output is
924    /// `[outer, inner]`. The single-axis cases (axis=0 → outer=1;
925    /// axis=last → inner=1) and contiguous multi-axis (e.g. reduce over
926    /// [0, 1] of an [N, C, H, W] tensor → outer=1, reduced=N*C, inner=H*W)
927    /// all map onto this triplet. Non-contiguous axes are not supported
928    /// and bail to Nop in the compile pass.
929    Reduce {
930        src: usize,
931        dst: usize,
932        outer: u32,
933        reduced: u32,
934        inner: u32,
935        op: ReduceOp,
936    },
937    /// Top-K **indices** along the last axis. Input shape `[outer, axis_dim]`,
938    /// output `[outer, k]` of f32-encoded i64 indices. Ties broken by
939    /// smaller index. Used by MoE gating + beam search.
940    TopK {
941        src: usize,
942        dst: usize,
943        outer: u32,
944        axis_dim: u32,
945        k: u32,
946    },
947    /// Indexed batched matmul: out\[i\] = input\[i\] @ weight[expert_idx\[i\]].
948    /// Naive impl per token; for real MoE workloads, sort-by-expert + run
949    /// segmented GEMM would amortize. Done when there's a workload.
950    GroupedMatMul {
951        input: usize,
952        weight: usize,
953        expert_idx: usize,
954        dst: usize,
955        m: u32,
956        k_dim: u32,
957        n: u32,
958        num_experts: u32,
959    },
960    /// GGUF K-quant packed expert stack + grouped matmul (MoE FFN).
961    DequantGroupedMatMulGguf {
962        input: usize,
963        w_q: usize,
964        expert_idx: usize,
965        dst: usize,
966        m: u32,
967        k_dim: u32,
968        n: u32,
969        num_experts: u32,
970        scheme: rlx_ir::quant::QuantScheme,
971    },
972    /// Materialize packed MoE weights to F32 `[E, K, N]` (autodiff helper).
973    DequantMoEWeightsGguf {
974        w_q: usize,
975        dst: usize,
976        k_dim: u32,
977        n: u32,
978        num_experts: u32,
979        scheme: rlx_ir::quant::QuantScheme,
980    },
981    /// Scatter-add: dst[indices\[i\] * trailing + j] += updates[i * trailing + j].
982    /// Output is zeroed first; multiple updates to the same row accumulate.
983    ScatterAdd {
984        updates: usize,
985        indices: usize,
986        dst: usize,
987        num_updates: u32,
988        out_dim: u32,
989        trailing: u32,
990    },
991    /// Ternary select: out = cond != 0 ? on_true : on_false
992    Where {
993        cond: usize,
994        on_true: usize,
995        on_false: usize,
996        dst: usize,
997        len: u32,
998    },
999    /// General N-D transpose / broadcast. `out_dims[i]` is the output's dim
1000    /// i length; `in_strides[i]` is the input stride (in elements) used to
1001    /// index that dim — 0 for broadcast dims (Expand). `in_total` is the
1002    /// total element count in the source buffer (≤ output total when
1003    /// broadcasting). Strides are pre-computed at compile time.
1004    Transpose {
1005        src: usize,
1006        dst: usize,
1007        in_total: u32,
1008        out_dims: Vec<u32>,
1009        in_strides: Vec<u32>,
1010    },
1011    /// Gather along an arbitrary axis. `outer = product(dims[..axis])`,
1012    /// `trailing = product(dims[axis+1..])`, `axis_dim` = the dimension
1013    /// being indexed into. Output: outer × num_idx × trailing.
1014    /// (axis=0 still routes to the simpler Thunk::Gather fast path.)
1015    GatherAxis {
1016        table: usize,
1017        idx: usize,
1018        dst: usize,
1019        outer: u32,
1020        axis_dim: u32,
1021        num_idx: u32,
1022        trailing: u32,
1023    },
1024    /// 2D pooling (Max or Mean). Input layout [N, C, H, W], output
1025    /// [N, C, H_out, W_out]. Padding is implicit-zero; Mean divides by
1026    /// the full kernel area (matches torch's `count_include_pad=True`).
1027    Pool2D {
1028        src: usize,
1029        dst: usize,
1030        n: u32,
1031        c: u32,
1032        h: u32,
1033        w: u32,
1034        h_out: u32,
1035        w_out: u32,
1036        kh: u32,
1037        kw: u32,
1038        sh: u32,
1039        sw: u32,
1040        ph: u32,
1041        pw: u32,
1042        kind: ReduceOp,
1043    },
1044    /// 2D convolution. Input [N, C_in, H, W], weight [C_out, C_in_per_group, kH, kW],
1045    /// output [N, C_out, H_out, W_out]. Bias is a separate Op::Binary::Add
1046    /// after the conv (matching the IR's input layout — Op::Conv has 2 inputs).
1047    /// Naive direct convolution; sufficient for correctness, not optimised.
1048    Conv2D {
1049        src: usize,
1050        weight: usize,
1051        dst: usize,
1052        n: u32,
1053        c_in: u32,
1054        h: u32,
1055        w: u32,
1056        c_out: u32,
1057        h_out: u32,
1058        w_out: u32,
1059        kh: u32,
1060        kw: u32,
1061        sh: u32,
1062        sw: u32,
1063        ph: u32,
1064        pw: u32,
1065        dh: u32,
1066        dw: u32,
1067        groups: u32,
1068    },
1069
1070    // ── Backward / training kernels ─────────────────────────────
1071    /// Real INT8 matmul with i32 accumulation.
1072    ///   `out[m, n] = requantize(bias[n] + Σₖ (x[m,k]-x_zp)·(w[k,n]-w_zp), mult, out_zp)`
1073    /// Reads `x` and `w` as i8, `bias` as i32; writes `out` as i8.
1074    /// Same kernel shape as `rlx_cortexm::dense::dense_i8` — promoted
1075    /// to a desktop thunk so a quantized graph compiled here doesn't
1076    /// have to round-trip through fake-quant.
1077    QMatMul {
1078        x: usize,
1079        w: usize,
1080        bias: usize,
1081        out: usize,
1082        m: u32,
1083        k: u32,
1084        n: u32,
1085        x_zp: i32,
1086        w_zp: i32,
1087        out_zp: i32,
1088        mult: f32,
1089    },
1090
1091    /// Real INT8 conv2d, NCHW layout. Same loop shape as `Thunk::Conv2D`
1092    /// but with i8 reads, i32 accumulation, and per-output requantize
1093    /// to i8. Bias is i32 in the accumulator scale.
1094    QConv2d {
1095        x: usize,
1096        w: usize,
1097        bias: usize,
1098        out: usize,
1099        n: u32,
1100        c_in: u32,
1101        h: u32,
1102        w_in: u32,
1103        c_out: u32,
1104        h_out: u32,
1105        w_out: u32,
1106        kh: u32,
1107        kw: u32,
1108        sh: u32,
1109        sw: u32,
1110        ph: u32,
1111        pw: u32,
1112        dh: u32,
1113        dw: u32,
1114        groups: u32,
1115        x_zp: i32,
1116        w_zp: i32,
1117        out_zp: i32,
1118        mult: f32,
1119    },
1120
1121    /// INT8 quantize. Reads `x` as f32, writes `q` as i8.
1122    /// `chan = (i / inner) % chan_dim` selects the per-channel
1123    /// scale/zp; `chan_axis` is informational only (the kernel uses
1124    /// `chan_dim` and `inner` directly).
1125    /// For per-tensor, `chan_dim = 1` and `inner = len` so `chan` is
1126    /// always 0.
1127    Quantize {
1128        x: usize,
1129        q: usize,
1130        len: u32,
1131        chan_axis: u32,
1132        chan_dim: u32,
1133        inner: u32,
1134        scales: Vec<f32>,
1135        zero_points: Vec<i32>,
1136    },
1137
1138    /// INT8 dequantize — inverse of `Thunk::Quantize`.
1139    Dequantize {
1140        q: usize,
1141        x: usize,
1142        len: u32,
1143        chan_axis: u32,
1144        chan_dim: u32,
1145        inner: u32,
1146        scales: Vec<f32>,
1147        zero_points: Vec<i32>,
1148    },
1149
1150    /// QAT fake-quantize. Per-channel (or per-tensor) symmetric
1151    /// quantize-then-dequantize on the fly. Computes
1152    ///   `s[c] = max(|x[..., c, ...]|) / q_max`
1153    /// then
1154    ///   `out[i] = clamp(round(x[i]/s[c]), -q_max, q_max) * s[c]`
1155    /// with `q_max = {127, 7, 1}` for `bits = {8, 4, 2}`. Same
1156    /// channel-layout convention as `Thunk::Quantize`: every
1157    /// element's channel is `(i / inner) % chan_dim`. The kernel
1158    /// does two passes — one to scan max-abs per channel, one to
1159    /// quant-dequant per element.
1160    FakeQuantize {
1161        x: usize,
1162        out: usize,
1163        len: u32,
1164        chan_axis: u32,
1165        chan_dim: u32,
1166        inner: u32,
1167        bits: u8,
1168        /// STE variant — informational on the forward side (output is
1169        /// the same regardless), kernel-relevant in the matching
1170        /// `FakeQuantizeBackward` thunk.
1171        ste: rlx_ir::op::SteKind,
1172        /// Scale-tracking strategy. `PerBatch` recomputes
1173        /// `max_abs/q_max` every call (the original path). `EMA{decay}`
1174        /// blends per-batch max-abs into the `state_off` buffer; `Fixed`
1175        /// reads `state_off` and never updates it.
1176        scale_mode: rlx_ir::op::ScaleMode,
1177        /// `Some(off)` for `EMA` and `Fixed`; `None` for `PerBatch`.
1178        /// Points at a `[chan_dim]` f32 buffer holding the running scale
1179        /// per channel.
1180        state_off: Option<usize>,
1181    },
1182
1183    /// Backward pass for `Op::FakeQuantize` under one of four STE
1184    /// variants. Computes `dx[i]` from the f32 forward input `x` and
1185    /// the upstream gradient `dy`, using the same per-channel scale
1186    /// scheme as the forward.
1187    FakeQuantizeBackward {
1188        x: usize,
1189        dy: usize,
1190        dx: usize,
1191        len: u32,
1192        chan_axis: u32,
1193        chan_dim: u32,
1194        inner: u32,
1195        bits: u8,
1196        ste: rlx_ir::op::SteKind,
1197    },
1198
1199    /// LSQ forward — same kernel shape as `FakeQuantize` Fixed mode.
1200    /// Reads scale from `scale_off` (a `[chan_dim]` Param tensor).
1201    FakeQuantizeLSQ {
1202        x: usize,
1203        scale_off: usize,
1204        out: usize,
1205        len: u32,
1206        chan_axis: u32,
1207        chan_dim: u32,
1208        inner: u32,
1209        bits: u8,
1210    },
1211
1212    /// LSQ backward, x-gradient. STE-clipped: passes upstream
1213    /// through inside the quantization range, zeros outside.
1214    FakeQuantizeLSQBackwardX {
1215        x: usize,
1216        scale_off: usize,
1217        dy: usize,
1218        dx: usize,
1219        len: u32,
1220        chan_axis: u32,
1221        chan_dim: u32,
1222        inner: u32,
1223        bits: u8,
1224    },
1225
1226    /// LSQ backward, scale-gradient. Per-channel:
1227    ///   `dscale[c] = sum_i ψ(x[i]/s[c]) · upstream[i]`
1228    /// where `ψ(z) = -z + round(z)` if `|z| ≤ q_max` else
1229    /// `sign(z) · q_max`. Output shape: `[chan_dim]`.
1230    FakeQuantizeLSQBackwardScale {
1231        x: usize,
1232        scale_off: usize,
1233        dy: usize,
1234        dscale: usize,
1235        len: u32,
1236        chan_axis: u32,
1237        chan_dim: u32,
1238        inner: u32,
1239        bits: u8,
1240    },
1241
1242    /// ReLU backward: `dx[i] = dy[i] if x[i] > 0 else 0`.
1243    ReluBackward {
1244        x: usize,
1245        dy: usize,
1246        dx: usize,
1247        len: u32,
1248    },
1249    /// f64 sibling of `ReluBackward` — same shape as the f32 variant
1250    /// but reads/writes 8 bytes per element. Required because
1251    /// `ReluBackward`'s `&[f32]` slot view returns half of every f64
1252    /// otherwise → backward silently produces 0 gradients on an f64
1253    /// graph. Mirrors the `ActivationBackwardF64` split.
1254    ReluBackwardF64 {
1255        x: usize,
1256        dy: usize,
1257        dx: usize,
1258        len: u32,
1259    },
1260
1261    /// Generic element-wise activation backward.
1262    /// `dx[i] = (d/dx act(x))[i] · dy[i]`. The closure dispatch is
1263    /// per-element; expensive activations (Gelu) recompute internals
1264    /// inline rather than threading an extra "saved y" tensor through.
1265    ActivationBackward {
1266        x: usize,
1267        dy: usize,
1268        dx: usize,
1269        len: u32,
1270        kind: Activation,
1271    },
1272    /// f64 sibling of `ActivationBackward` — slot offsets, len in
1273    /// elements; kernel reads/writes 8 bytes per element. Required
1274    /// because `ActivationBackward`'s `&[f32]` slot view silently
1275    /// returns garbage on an f64 graph (cb % 4 still works but every
1276    /// loaded value is half of an f64 → wrong gradient).
1277    ActivationBackwardF64 {
1278        x: usize,
1279        dy: usize,
1280        dx: usize,
1281        len: u32,
1282        kind: Activation,
1283    },
1284
1285    /// LayerNorm backward — input gradient. Recomputes mean/var/x̂ from
1286    /// `x` and emits the closed-form `d_x` per row.
1287    LayerNormBackwardInput {
1288        x: usize,
1289        gamma: usize,
1290        dy: usize,
1291        dx: usize,
1292        rows: u32,
1293        h: u32,
1294        eps: f32,
1295    },
1296
1297    /// LayerNorm backward — gamma gradient. `d_gamma[d] = Σ_row dy·x̂`.
1298    LayerNormBackwardGamma {
1299        x: usize,
1300        dy: usize,
1301        dgamma: usize,
1302        rows: u32,
1303        h: u32,
1304        eps: f32,
1305    },
1306
1307    RmsNormBackwardInput {
1308        x: usize,
1309        gamma: usize,
1310        beta: usize,
1311        dy: usize,
1312        dx: usize,
1313        rows: u32,
1314        h: u32,
1315        eps: f32,
1316    },
1317    RmsNormBackwardGamma {
1318        x: usize,
1319        gamma: usize,
1320        beta: usize,
1321        dy: usize,
1322        dgamma: usize,
1323        rows: u32,
1324        h: u32,
1325        eps: f32,
1326    },
1327    RmsNormBackwardBeta {
1328        x: usize,
1329        gamma: usize,
1330        beta: usize,
1331        dy: usize,
1332        dbeta: usize,
1333        rows: u32,
1334        h: u32,
1335        eps: f32,
1336    },
1337    RopeBackward {
1338        dy: usize,
1339        cos: usize,
1340        sin: usize,
1341        dx: usize,
1342        batch: u32,
1343        seq: u32,
1344        hidden: u32,
1345        head_dim: u32,
1346        n_rot: u32,
1347        cos_len: u32,
1348    },
1349    CumsumBackward {
1350        dy: usize,
1351        dx: usize,
1352        rows: u32,
1353        cols: u32,
1354        exclusive: bool,
1355    },
1356    GatherBackward {
1357        dy: usize,
1358        indices: usize,
1359        dst: usize,
1360        outer: u32,
1361        axis_dim: u32,
1362        num_idx: u32,
1363        trailing: u32,
1364    },
1365
1366    GroupNormBackwardInput {
1367        x: usize,
1368        gamma: usize,
1369        beta: usize,
1370        dy: usize,
1371        dx: usize,
1372        n: u32,
1373        c: u32,
1374        h: u32,
1375        w: u32,
1376        num_groups: u32,
1377        eps: f32,
1378    },
1379    GroupNormBackwardGamma {
1380        x: usize,
1381        dy: usize,
1382        dgamma: usize,
1383        n: u32,
1384        c: u32,
1385        h: u32,
1386        w: u32,
1387        num_groups: u32,
1388        eps: f32,
1389    },
1390    GroupNormBackwardBeta {
1391        dy: usize,
1392        dbeta: usize,
1393        n: u32,
1394        c: u32,
1395        h: u32,
1396        w: u32,
1397    },
1398
1399    /// 2D max-pool backward (NCHW). Recomputes the argmax position
1400    /// inside each window and accumulates `dy` into `dx` at that
1401    /// position. Output is zeroed first; ties resolve to the first
1402    /// hit (lowest (kh,kw) index), matching what the forward kernel
1403    /// does with `acc.max(v)`.
1404    MaxPool2dBackward {
1405        x: usize,
1406        dy: usize,
1407        dx: usize,
1408        n: u32,
1409        c: u32,
1410        h: u32,
1411        w: u32,
1412        h_out: u32,
1413        w_out: u32,
1414        kh: u32,
1415        kw: u32,
1416        sh: u32,
1417        sw: u32,
1418        ph: u32,
1419        pw: u32,
1420    },
1421
1422    /// 2D conv backward w.r.t. input (`dx = conv_transpose(dy, w)`).
1423    /// `dy [N, C_out, H_out, W_out]`, `w [C_out, C_in_per_group, kH, kW]`,
1424    /// `dx [N, C_in, H, W]`.
1425    Conv2dBackwardInput {
1426        dy: usize,
1427        w: usize,
1428        dx: usize,
1429        n: u32,
1430        c_in: u32,
1431        h: u32,
1432        w_in: u32,
1433        c_out: u32,
1434        h_out: u32,
1435        w_out: u32,
1436        kh: u32,
1437        kw: u32,
1438        sh: u32,
1439        sw: u32,
1440        ph: u32,
1441        pw: u32,
1442        dh: u32,
1443        dw: u32,
1444        groups: u32,
1445    },
1446
1447    /// 2D conv backward w.r.t. weight. `x [N, C_in, H, W]`,
1448    /// `dy [N, C_out, H_out, W_out]`, `dw [C_out, C_in_per_group, kH, kW]`.
1449    /// `dw` is zeroed before accumulation.
1450    Conv2dBackwardWeight {
1451        x: usize,
1452        dy: usize,
1453        dw: usize,
1454        n: u32,
1455        c_in: u32,
1456        h: u32,
1457        w: u32,
1458        c_out: u32,
1459        h_out: u32,
1460        w_out: u32,
1461        kh: u32,
1462        kw: u32,
1463        sh: u32,
1464        sw: u32,
1465        ph: u32,
1466        pw: u32,
1467        dh: u32,
1468        dw_dil: u32,
1469        groups: u32,
1470    },
1471
1472    /// Fused softmax + cross-entropy loss with f32-encoded integer
1473    /// labels. `logits [N, C]`, `labels [N]`, output `[N]` per-row loss.
1474    /// Numerically stable (max-subtract before exp).
1475    SoftmaxCrossEntropy {
1476        logits: usize,
1477        labels: usize,
1478        dst: usize,
1479        n: u32,
1480        c: u32,
1481    },
1482
1483    /// Backward of the fused loss above.
1484    /// `dlogits[n, k] = (softmax(logits[n])[k] - one_hot(labels[n])[k]) * d_loss[n]`.
1485    SoftmaxCrossEntropyBackward {
1486        logits: usize,
1487        labels: usize,
1488        d_loss: usize,
1489        dlogits: usize,
1490        n: u32,
1491        c: u32,
1492    },
1493
1494    /// User-registered custom op (CPU side). Lowered from `Op::Custom`.
1495    /// `kernel` is resolved against the global CPU kernel registry at
1496    /// compile time and stored as `Arc<dyn CpuKernel>` so execution
1497    /// avoids per-call lookups. v1: f32 contiguous only — see
1498    /// `op_registry::CpuKernel::execute_f32`.
1499    CustomOp {
1500        kernel: Arc<dyn CpuKernel>,
1501        inputs: Vec<(usize, u32, Shape)>, // (offset, len_elements, shape)
1502        output: (usize, u32, Shape),      // (offset, len_elements, shape)
1503        attrs: Vec<u8>,
1504    },
1505
    // NOTE: the following paragraph documents `Thunk::Fft1d` (the last
    // variant of this enum), not `GaussianSplatRender`. It is kept as a
    // plain comment so rustdoc does not attach it to the wrong variant.
    // 1D FFT along the last axis. Input/output are `[..., 2N]`
    // real-block layout (first N real, second N imag along the
    // transformed axis). `outer` is the product of all leading axes;
    // `n_complex` is N (the number of complex points). Both halves
    // of the real-block layout are read together by the kernel.
    // `dtype` selects the f32 or f64 path; the two share structure
    // but not buffers, so a flag at compile time avoids per-row
    // dispatch.
    /// CPU reference 3D Gaussian splat render ([`rlx_ir::Op::GaussianSplatRender`]).
1515    GaussianSplatRender {
1516        positions_off: usize,
1517        positions_len: usize,
1518        scales_off: usize,
1519        scales_len: usize,
1520        rotations_off: usize,
1521        rotations_len: usize,
1522        opacities_off: usize,
1523        opacities_len: usize,
1524        colors_off: usize,
1525        colors_len: usize,
1526        sh_coeffs_off: usize,
1527        sh_coeffs_len: usize,
1528        meta_off: usize,
1529        dst_off: usize,
1530        dst_len: usize,
1531        width: u32,
1532        height: u32,
1533        tile_size: u32,
1534        radius_scale: f32,
1535        alpha_cutoff: f32,
1536        max_splat_steps: u32,
1537        transmittance_threshold: f32,
1538        max_list_entries: u32,
1539    },
1540    GaussianSplatRenderBackward {
1541        positions_off: usize,
1542        positions_len: usize,
1543        scales_off: usize,
1544        scales_len: usize,
1545        rotations_off: usize,
1546        rotations_len: usize,
1547        opacities_off: usize,
1548        opacities_len: usize,
1549        colors_off: usize,
1550        colors_len: usize,
1551        sh_coeffs_off: usize,
1552        sh_coeffs_len: usize,
1553        meta_off: usize,
1554        d_loss_off: usize,
1555        d_loss_len: usize,
1556        packed_off: usize,
1557        packed_len: usize,
1558        width: u32,
1559        height: u32,
1560        tile_size: u32,
1561        radius_scale: f32,
1562        alpha_cutoff: f32,
1563        max_splat_steps: u32,
1564        transmittance_threshold: f32,
1565        max_list_entries: u32,
1566        loss_grad_clip: f32,
1567        sh_band: u32,
1568        max_anisotropy: f32,
1569    },
1570    /// Strict IR stage 1 — project + bin + sort + rays ([`Op::GaussianSplatPrepare`]).
1571    GaussianSplatPrepare {
1572        positions_off: usize,
1573        positions_len: usize,
1574        scales_off: usize,
1575        scales_len: usize,
1576        rotations_off: usize,
1577        rotations_len: usize,
1578        opacities_off: usize,
1579        opacities_len: usize,
1580        colors_off: usize,
1581        colors_len: usize,
1582        sh_coeffs_off: usize,
1583        sh_coeffs_len: usize,
1584        meta_off: usize,
1585        meta_len: usize,
1586        prep_off: usize,
1587        prep_len: usize,
1588        width: u32,
1589        height: u32,
1590        tile_size: u32,
1591        radius_scale: f32,
1592        alpha_cutoff: f32,
1593        max_splat_steps: u32,
1594        transmittance_threshold: f32,
1595        max_list_entries: u32,
1596    },
1597    /// Strict IR stage 2 — tile raster from prepare buffer ([`Op::GaussianSplatRasterize`]).
1598    GaussianSplatRasterize {
1599        prep_off: usize,
1600        prep_len: usize,
1601        meta_off: usize,
1602        meta_len: usize,
1603        dst_off: usize,
1604        dst_len: usize,
1605        count: usize,
1606        width: u32,
1607        height: u32,
1608        tile_size: u32,
1609        alpha_cutoff: f32,
1610        max_splat_steps: u32,
1611        transmittance_threshold: f32,
1612        max_list_entries: u32,
1613    },
1614    Fft1d {
1615        src: usize,
1616        dst: usize,
1617        outer: u32,
1618        n_complex: u32,
1619        inverse: bool,
1620        dtype: rlx_ir::DType,
1621    },
1622}
1623
1624/// Compiled thunk schedule — the runtime hot path.
1625/// Nop thunks are filtered out at compile time for zero iteration overhead.
1626#[derive(Clone)]
1627pub struct ThunkSchedule {
1628    pub thunks: Vec<Thunk>,
1629    /// TIDE merged placement mask (union across layers).
1630    pub moe_resident: Option<std::sync::Arc<[bool]>>,
1631    /// Per MoE layer placement (`layer[e]`); preferred when set.
1632    pub moe_resident_layers: Option<std::sync::Arc<Vec<std::sync::Arc<[bool]>>>>,
1633    /// MoE router TopK capture (per-layer refresh).
1634    pub moe_topk_capture: Option<std::sync::Arc<crate::moe_topk_capture::MoeTopkCapture>>,
1635    /// Cached config values.
1636    pub mask_threshold: f32,
1637    pub mask_neg_inf: f32,
1638    pub score_skip: f32,
1639    /// Pre-compiled closure dispatch (zero match overhead). `Arc` (not
1640    /// `Box`) so the schedule can be `Clone` — multiple parallel
1641    /// executors share the same compiled closures (they're read-only
1642    /// `Fn(*mut u8)` so concurrent dispatch is safe; the arena pointer
1643    /// they receive is the only mutable state and is per-executor).
1644    pub compiled_fns: Vec<Arc<dyn Fn(*mut u8) + Send + Sync>>,
1645}
1646
1647impl ThunkSchedule {
1648    pub fn strip_nops(&mut self) {
1649        self.thunks.retain(|t| !matches!(t, Thunk::Nop));
1650        // compiled_fns must be rebuilt after stripping — caller should
1651        // call strip_nops() before compile_closures().
1652        self.compiled_fns.clear();
1653    }
1654}
1655
1656/// Get the arena byte offset for a node.
1657fn node_offset(arena: &Arena, id: NodeId) -> usize {
1658    if arena.has_buffer(id) {
1659        arena.byte_offset(id)
1660    } else {
1661        usize::MAX
1662    }
1663}
1664
1665/// Every byte-offset that a thunk reads from. Used by the Narrow→Rope
1666/// fusion (#45) to verify a Narrow's dst has exactly one consumer
1667/// before eliding it. Conservative: when in doubt about reads (an op
1668/// not yet listed here), the fusion will skip — correctness over
1669/// completeness.
1670fn thunk_read_offsets(t: &Thunk) -> Vec<usize> {
1671    match t {
1672        Thunk::Sgemm { a, b, .. } => vec![*a, *b],
1673        Thunk::DenseSolveF64 { a, b, .. } => vec![*a, *b],
1674        Thunk::DenseSolveF32 { a, b, .. } => vec![*a, *b],
1675        Thunk::BatchedDenseSolveF64 { a, b, .. } => vec![*a, *b],
1676        Thunk::BatchedDgemmF64 { a, b, .. } => vec![*a, *b],
1677        Thunk::BatchedSgemm { a, b, .. } => vec![*a, *b],
1678        Thunk::FusedMmBiasAct { a, w, bias, .. } => vec![*a, *w, *bias],
1679        Thunk::BiasAdd { src, bias, .. } => vec![*src, *bias],
1680        Thunk::BinaryFull { lhs, rhs, .. } => vec![*lhs, *rhs],
1681        Thunk::BinaryFullF64 { lhs, rhs, .. } => vec![*lhs, *rhs],
1682        Thunk::BinaryFullC64 { lhs, rhs, .. } => vec![*lhs, *rhs],
1683        Thunk::ComplexNormSqF32 { src, .. } => vec![*src],
1684        Thunk::ComplexNormSqBackwardF32 { z, g, .. } => vec![*z, *g],
1685        Thunk::ConjugateC64 { src, .. } => vec![*src],
1686        Thunk::Scan {
1687            outer_init_off,
1688            xs_inputs,
1689            ..
1690        } => {
1691            let mut v = vec![*outer_init_off];
1692            for (_, outer_xs_off, _) in xs_inputs.iter() {
1693                v.push(*outer_xs_off);
1694            }
1695            v
1696        }
1697        Thunk::ScanBackward {
1698            outer_init_off,
1699            outer_traj_off,
1700            outer_upstream_off,
1701            outer_xs_offs,
1702            ..
1703        } => {
1704            let mut v = vec![*outer_init_off, *outer_traj_off, *outer_upstream_off];
1705            for (off, _) in outer_xs_offs.iter() {
1706                v.push(*off);
1707            }
1708            v
1709        }
1710        Thunk::ScanBackwardXs {
1711            outer_init_off,
1712            outer_traj_off,
1713            outer_upstream_off,
1714            outer_xs_offs,
1715            ..
1716        } => {
1717            let mut v = vec![*outer_init_off, *outer_traj_off, *outer_upstream_off];
1718            for (off, _) in outer_xs_offs.iter() {
1719                v.push(*off);
1720            }
1721            v
1722        }
1723        Thunk::CustomFn { inputs, .. } => {
1724            inputs.iter().map(|(_, outer_off, _)| *outer_off).collect()
1725        }
1726        Thunk::ActivationInPlace { data, .. } => vec![*data],
1727        Thunk::LayerNorm { src, g, b, .. } | Thunk::GroupNorm { src, g, b, .. } => {
1728            vec![*src, *g, *b]
1729        }
1730        Thunk::ResizeNearest2x { src, .. } => vec![*src],
1731        Thunk::AxialRope2d { src, .. } => vec![*src],
1732        Thunk::FusedResidualLN {
1733            x, res, bias, g, b, ..
1734        } => vec![*x, *res, *bias, *g, *b],
1735        Thunk::FusedResidualRmsNorm {
1736            x, res, bias, g, b, ..
1737        } => vec![*x, *res, *bias, *g, *b],
1738        Thunk::RmsNorm { src, g, b, .. } => vec![*src, *g, *b],
1739        Thunk::Softmax { data, .. } => vec![*data],
1740        Thunk::Cumsum { src, .. } => vec![*src],
1741        Thunk::Sample { logits, .. } => vec![*logits],
1742        Thunk::LoraMatMul { x, w, a, b, .. } => vec![*x, *w, *a, *b],
1743        Thunk::DequantMatMul {
1744            x, w_q, scale, zp, ..
1745        } => vec![*x, *w_q, *scale, *zp],
1746        Thunk::DequantMatMulGguf { x, w_q, .. } => vec![*x, *w_q],
1747        Thunk::DequantMatMulInt4 {
1748            x, w_q, scale, zp, ..
1749        } => vec![*x, *w_q, *scale, *zp],
1750        Thunk::DequantMatMulFp8 { x, w_q, scale, .. } => vec![*x, *w_q, *scale],
1751        Thunk::DequantMatMulNvfp4 {
1752            x,
1753            w_q,
1754            scale,
1755            global_scale,
1756            ..
1757        } => vec![*x, *w_q, *scale, *global_scale],
1758        Thunk::Conv2D1x1 { src, weight, .. } => vec![*src, *weight],
1759        Thunk::SelectiveScan {
1760            x, delta, a, b, c, ..
1761        } => vec![*x, *delta, *a, *b, *c],
1762        Thunk::GatedDeltaNet {
1763            q,
1764            k,
1765            v,
1766            g,
1767            beta,
1768            state,
1769            ..
1770        } => {
1771            let mut v = vec![*q, *k, *v, *g, *beta];
1772            if *state != 0 {
1773                v.push(*state);
1774            }
1775            v
1776        }
1777        Thunk::Attention { q, k, v, mask, .. } => vec![*q, *k, *v, *mask],
1778        Thunk::AttentionBackward {
1779            q, k, v, dy, mask, ..
1780        } => {
1781            let mut v = vec![*q, *k, *v, *dy];
1782            if *mask != 0 {
1783                v.push(*mask);
1784            }
1785            v
1786        }
1787        Thunk::Rope { src, cos, sin, .. } => vec![*src, *cos, *sin],
1788        Thunk::FusedAttnBlock {
1789            hidden,
1790            qkv_w,
1791            out_w,
1792            mask,
1793            qkv_b,
1794            out_b,
1795            cos,
1796            sin,
1797            ..
1798        } => vec![*hidden, *qkv_w, *out_w, *mask, *qkv_b, *out_b, *cos, *sin],
1799        Thunk::FusedSwiGLU { src, .. } => vec![*src],
1800        Thunk::Concat { inputs, .. } => inputs.iter().map(|(off, _)| *off).collect(),
1801        Thunk::ConcatF64 { inputs, .. } => inputs.iter().map(|(off, _)| *off).collect(),
1802        Thunk::Narrow { src, .. } => vec![*src],
1803        Thunk::Copy { src, .. } => vec![*src],
1804        Thunk::Gather { table, idx, .. } => vec![*table, *idx],
1805        // Anything not enumerated → return the dst as a "read" too,
1806        // forcing the fusion to bail (read_count >= 2 → skip). Keeps
1807        // this list safe to be incomplete.
1808        _ => vec![],
1809    }
1810}
1811
/// Fused dequant + matmul (plan #5). Int8-blockwise weights: each
/// `block_size` consecutive elements of a column share one f32
/// scale (and optionally a zero-point). The dequant happens inside
/// the inner accumulate so the f32 weight is never materialized.
///
/// Computes, for every `i < m`, `j < n`:
/// `out[i, j] = Σ_p x[i, p] · (w[p, j] - z[p / block_size, j]) · s[p / block_size, j]`.
///
/// * `x` — row-major activations `[m, k]`.
/// * `w_bytes` — row-major i8 weight matrix `[k, n]`.
/// * `scales` / `zps` — row-major `[ceil(k / block_size), n]`. When
///   `asym == false` the zero-point is taken as 0 and `zps` is never read
///   (it may be empty).
/// * `out` — row-major `[m, n]`; every element in that range is overwritten.
///
/// # Panics
/// If any buffer is shorter than its stated shape, or if `block_size == 0`
/// while `k > 0` (division by zero when locating the block).
///
/// Today this is the reference scalar implementation — the win is
/// memory bandwidth, not flops, since LLM weights dominate the
/// working set. A NEON SIMD path that loads 16 i8 → splat-scale →
/// fused-multiply-add is the natural follow-on.
#[allow(clippy::too_many_arguments)]
fn dequant_matmul_int8(
    x: &[f32],       // [m, k]
    w_bytes: &[i8],  // [k, n]
    scales: &[f32],  // [ceil(k/block), n]
    zps: &[f32],     // [ceil(k/block), n] or empty
    out: &mut [f32], // [m, n]
    m: usize,
    k: usize,
    n: usize,
    block_size: usize,
    asym: bool,
) {
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0f32;
            for p in 0..k {
                // Row `block` of the scale / zero-point tables covers weight
                // rows [block * block_size, (block + 1) * block_size).
                let block = p / block_size;
                let s = scales[block * n + j];
                let z = if asym { zps[block * n + j] } else { 0.0 };
                let q = w_bytes[p * n + j] as f32;
                let dequantized = (q - z) * s;
                acc += x[i * k + p] * dequantized;
            }
            out[i * n + j] = acc;
        }
    }
}
1855
/// Fused dequant + matmul for 4-bit blockwise weights (the packed sibling
/// of `dequant_matmul_int8`).
///
/// `w_bytes` packs the row-major `[k, n]` weight two codes per byte: flat
/// element `e = p * n + j` lives in byte `e / 2`, in the low nibble when `e`
/// is even and the high nibble when `e` is odd. Codes are unsigned
/// (`0..=15`) and dequantize to `(code - z) * s`, where `s` / `z` come from
/// the row-major `[ceil(k / block_size), n]` tables `scales` / `zps`. When
/// `asym` is false, `z` is 0 and `zps` is never read.
///
/// `x` is row-major `[m, k]`; `out` is row-major `[m, n]` and is overwritten.
#[allow(clippy::too_many_arguments)]
fn dequant_matmul_int4(
    x: &[f32],
    w_bytes: &[u8],
    scales: &[f32],
    zps: &[f32],
    out: &mut [f32],
    m: usize,
    k: usize,
    n: usize,
    block_size: usize,
    asym: bool,
) {
    for row in 0..m {
        for col in 0..n {
            let total = (0..k).fold(0f32, |sum, kk| {
                // Shared index into the scale / zero-point tables.
                let group = kk / block_size * n + col;
                let scale = scales[group];
                let zero = if asym { zps[group] } else { 0.0 };
                // Unpack the 4-bit code for weight element (kk, col).
                let flat = kk * n + col;
                let packed = w_bytes[flat >> 1];
                let code = if flat % 2 == 0 { packed & 0x0F } else { packed >> 4 };
                sum + x[row * k + kk] * ((code as f32 - zero) * scale)
            });
            out[row * n + col] = total;
        }
    }
}
1889
/// Decode one OCP FP8 E4M3 ("E4M3FN") byte to `f32`.
///
/// Layout: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits.
/// Unlike E5M2, E4M3 reserves no encodings for infinity. The all-ones
/// exponent still carries normal values (`S.1111.000` = ±256 up to
/// `S.1111.110` = ±448), and only `S.1111.111` is NaN. Zero exponent
/// encodes zero / subnormals `mant · 2⁻⁹`.
fn fp8_e4m3_to_f32(b: u8) -> f32 {
    let sign = if b & 0x80 != 0 { -1.0 } else { 1.0 };
    let exp = (b >> 3) & 0x0F;
    let mant = b & 0x07;
    // The single NaN pattern (either sign); there is no Inf in E4M3FN.
    if exp == 0x0F && mant == 0x07 {
        return f32::NAN;
    }
    if exp == 0 {
        if mant == 0 {
            return 0.0;
        }
        // Subnormal: 2^(1 - bias) · mant / 8 = mant · 2^-9.
        return sign * (mant as f32) * 2f32.powi(-9);
    }
    sign * (1.0 + mant as f32 / 8.0) * 2f32.powi(exp as i32 - 7)
}
1909
/// Decode one FP8 E5M2 byte to `f32`.
///
/// Layout: 1 sign bit, 5 exponent bits (bias 15), 2 mantissa bits, with
/// IEEE-754-style specials. An all-ones exponent is ±Inf (mantissa 0) or
/// NaN (mantissa ≠ 0). A zero exponent is zero or a subnormal
/// `mant · 2⁻¹⁶`. Both zero encodings decode to `+0.0`.
fn fp8_e5m2_to_f32(b: u8) -> f32 {
    let sign: f32 = if b & 0x80 != 0 { -1.0 } else { 1.0 };
    let biased_exp = (b & 0x7C) >> 2;
    let frac = (b & 0x03) as f32;
    match (biased_exp, frac == 0.0) {
        (0, true) => 0.0,
        (0, false) => sign * frac * 2f32.powi(-16),
        (0x1F, true) => sign * f32::INFINITY,
        (0x1F, false) => f32::NAN,
        (e, _) => sign * (1.0 + frac / 4.0) * 2f32.powi(i32::from(e) - 15),
    }
}
1929
1930#[allow(clippy::too_many_arguments)]
1931fn dequant_matmul_fp8(
1932    x: &[f32],
1933    w_bytes: &[u8],
1934    scales: &[f32],
1935    out: &mut [f32],
1936    m: usize,
1937    k: usize,
1938    n: usize,
1939    e5m2: bool,
1940) {
1941    let dequant = if e5m2 {
1942        fp8_e5m2_to_f32
1943    } else {
1944        fp8_e4m3_to_f32
1945    };
1946    for i in 0..m {
1947        for j in 0..n {
1948            let mut acc = 0f32;
1949            for p in 0..k {
1950                let w = dequant(w_bytes[p * n + j]);
1951                let s = scales.get(j).copied().unwrap_or(1.0);
1952                acc += x[i * k + p] * w * s;
1953            }
1954            out[i * n + j] = acc;
1955        }
1956    }
1957}
1958
1959#[allow(clippy::too_many_arguments)]
1960pub fn dequant_matmul_nvfp4(
1961    x: &[f32],
1962    w_bytes: &[u8],
1963    scale_bytes: &[u8],
1964    global_scale: f32,
1965    out: &mut [f32],
1966    m: usize,
1967    k: usize,
1968    n: usize,
1969) {
1970    use rlx_ir::{NVFP4_GROUP_SIZE, fp4_e2m1_to_f32, fp8_e4m3_scale_to_f32};
1971    let gs = NVFP4_GROUP_SIZE;
1972    for i in 0..m {
1973        for j in 0..n {
1974            let mut acc = 0f32;
1975            for p in 0..k {
1976                let byte_idx = (p * n + j) / 2;
1977                let nibble = if (p * n + j) & 1 == 0 {
1978                    w_bytes[byte_idx] & 0x0F
1979                } else {
1980                    w_bytes[byte_idx] >> 4
1981                };
1982                let block = p / gs;
1983                let scale = fp8_e4m3_scale_to_f32(scale_bytes[block * n + j]);
1984                let w = fp4_e2m1_to_f32(nibble) * scale * global_scale;
1985                acc += x[i * k + p] * w;
1986            }
1987            out[i * n + j] = acc;
1988        }
1989    }
1990}
1991
1992/// Fused sampling step: logits → top-k filter → top-p truncation
1993/// → softmax → multinomial sample. Operates on one row of length
1994/// `vocab` and returns the sampled index. Plan #42.
1995///
1996/// Internal scratch is on the stack via SmallVec-style fallback —
1997/// for `vocab > 8192` we heap-allocate a working buffer; below
1998/// that we keep things in a fixed array. (TODO: thread the
1999/// scratch through ThunkSchedule like sdpa_scores does.)
2000fn sample_row(
2001    logits: &[f32],
2002    top_k: usize,
2003    top_p: f32,
2004    temperature: f32,
2005    rng: &mut rlx_ir::Philox4x32,
2006) -> usize {
2007    let v = logits.len();
2008    if v == 0 {
2009        return 0;
2010    }
2011    let temp = temperature.max(1e-6);
2012    // Copy + temperature-scale into a working buffer.
2013    let mut scaled: Vec<f32> = logits.iter().map(|&x| x / temp).collect();
2014
2015    // Top-k: zero out everything but the k largest by setting to -inf.
2016    if top_k > 0 && top_k < v {
2017        // Partial selection: find k-th largest then mask below.
2018        let mut indexed: Vec<(usize, f32)> = scaled.iter().copied().enumerate().collect();
2019        // Sort descending; partial would be O(n log k), full sort is fine
2020        // for typical vocab sizes (32k-128k) — single-row work.
2021        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
2022        let cutoff = indexed[top_k - 1].1;
2023        for x in scaled.iter_mut() {
2024            if *x < cutoff {
2025                *x = f32::NEG_INFINITY;
2026            }
2027        }
2028    }
2029
2030    // Stable softmax.
2031    let mut max_l = f32::NEG_INFINITY;
2032    for &x in &scaled {
2033        if x > max_l {
2034            max_l = x;
2035        }
2036    }
2037    let mut sum = 0.0f32;
2038    for x in scaled.iter_mut() {
2039        *x = (*x - max_l).exp();
2040        sum += *x;
2041    }
2042    let inv = 1.0 / sum.max(f32::MIN_POSITIVE);
2043    for x in scaled.iter_mut() {
2044        *x *= inv;
2045    }
2046
2047    // Top-p: keep the smallest set of tokens whose cumulative
2048    // probability exceeds top_p (after sorting descending).
2049    if top_p < 1.0 {
2050        let mut indexed: Vec<(usize, f32)> = scaled.iter().copied().enumerate().collect();
2051        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
2052        let mut cum = 0.0f32;
2053        let mut keep = vec![false; v];
2054        for (idx, p) in indexed.iter() {
2055            keep[*idx] = true;
2056            cum += *p;
2057            if cum >= top_p {
2058                break;
2059            }
2060        }
2061        let mut new_sum = 0.0f32;
2062        for (i, x) in scaled.iter_mut().enumerate() {
2063            if !keep[i] {
2064                *x = 0.0;
2065            }
2066            new_sum += *x;
2067        }
2068        let inv = 1.0 / new_sum.max(f32::MIN_POSITIVE);
2069        for x in scaled.iter_mut() {
2070            *x *= inv;
2071        }
2072    }
2073
2074    // Multinomial sample via inverse-CDF.
2075    let r = rng.next_f32();
2076    let mut acc = 0.0f32;
2077    for (i, &p) in scaled.iter().enumerate() {
2078        acc += p;
2079        if r <= acc {
2080            return i;
2081        }
2082    }
2083    v - 1 // floating-point edge case fallback
2084}
2085
2086/// Apply a synthetic (kernel-generated) attention mask to a `[q_seq, k_seq]`
2087/// scores matrix. Custom masks are read from a tensor and not handled here.
2088/// `None` is a no-op so callers don't need to special-case it.
2089#[inline]
2090fn apply_synthetic_mask(
2091    scores: &mut [f32],
2092    q_seq: usize,
2093    k_seq: usize,
2094    kind: rlx_ir::op::MaskKind,
2095) {
2096    let neg = crate::config::RuntimeConfig::global().attn_mask_neg_inf;
2097    let q_offset = k_seq.saturating_sub(q_seq);
2098    match kind {
2099        rlx_ir::op::MaskKind::None | rlx_ir::op::MaskKind::Custom | rlx_ir::op::MaskKind::Bias => {}
2100        rlx_ir::op::MaskKind::Causal => {
2101            for qi in 0..q_seq {
2102                let abs_q = q_offset + qi;
2103                for ki in (abs_q + 1)..k_seq {
2104                    scores[qi * k_seq + ki] = neg;
2105                }
2106            }
2107        }
2108        rlx_ir::op::MaskKind::SlidingWindow(w) => {
2109            for qi in 0..q_seq {
2110                let abs_q = q_offset + qi;
2111                let lo = abs_q.saturating_sub(w);
2112                for ki in 0..k_seq {
2113                    if ki < lo || ki > abs_q {
2114                        scores[qi * k_seq + ki] = neg;
2115                    }
2116                }
2117            }
2118        }
2119    }
2120}
2121
2122/// Compile graph into thunk schedule.
2123pub fn compile_thunks(graph: &Graph, arena: &Arena) -> ThunkSchedule {
2124    let mut thunks = Vec::with_capacity(graph.len());
2125
2126    for node in graph.nodes() {
2127        // View ops (Reshape / same-dtype Cast / axis-0 Narrow) are aliased
2128        // to their parent's slot by the memory planner — no copy needed.
2129        // Plan #46.
2130        if rlx_opt::is_pure_view(graph, node) {
2131            thunks.push(Thunk::Nop);
2132            continue;
2133        }
2134        let t = match &node.op {
2135            Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => Thunk::Nop,
2136
2137            Op::FusedMatMulBiasAct { activation } => {
2138                let shape = &node.shape;
2139                let n = shape.dim(shape.rank() - 1).unwrap_static();
2140                let total = shape.num_elements().unwrap();
2141                let m = total / n;
2142                let a_len = get_len(graph, node.inputs[0]);
2143                let k = a_len / m;
2144                Thunk::FusedMmBiasAct {
2145                    a: node_offset(arena, node.inputs[0]),
2146                    w: node_offset(arena, node.inputs[1]),
2147                    bias: node_offset(arena, node.inputs[2]),
2148                    c: node_offset(arena, node.id),
2149                    m: m as u32,
2150                    k: k as u32,
2151                    n: n as u32,
2152                    act: *activation,
2153                }
2154            }
2155
2156            Op::FusedResidualLN { has_bias, eps } => {
2157                let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
2158                let total = node.shape.num_elements().unwrap();
2159                let rows = total / h;
2160                let (g_idx, b_idx) = if *has_bias { (3, 4) } else { (2, 3) };
2161                Thunk::FusedResidualLN {
2162                    x: node_offset(arena, node.inputs[0]),
2163                    res: node_offset(arena, node.inputs[1]),
2164                    bias: if *has_bias {
2165                        node_offset(arena, node.inputs[2])
2166                    } else {
2167                        0
2168                    },
2169                    g: node_offset(arena, node.inputs[g_idx]),
2170                    b: node_offset(arena, node.inputs[b_idx]),
2171                    out: node_offset(arena, node.id),
2172                    rows: rows as u32,
2173                    h: h as u32,
2174                    eps: *eps,
2175                    has_bias: *has_bias,
2176                }
2177            }
2178
2179            Op::FusedResidualRmsNorm { has_bias, eps } => {
2180                let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
2181                let total = node.shape.num_elements().unwrap();
2182                let rows = total / h;
2183                let (g_idx, b_idx) = if *has_bias { (3, 4) } else { (2, 3) };
2184                Thunk::FusedResidualRmsNorm {
2185                    x: node_offset(arena, node.inputs[0]),
2186                    res: node_offset(arena, node.inputs[1]),
2187                    bias: if *has_bias {
2188                        node_offset(arena, node.inputs[2])
2189                    } else {
2190                        0
2191                    },
2192                    g: node_offset(arena, node.inputs[g_idx]),
2193                    b: node_offset(arena, node.inputs[b_idx]),
2194                    out: node_offset(arena, node.id),
2195                    rows: rows as u32,
2196                    h: h as u32,
2197                    eps: *eps,
2198                    has_bias: *has_bias,
2199                }
2200            }
2201
2202            Op::MatMul => {
2203                let shape = &node.shape;
2204                let a_shape = &graph.node(node.inputs[0]).shape;
2205                let b_shape = &graph.node(node.inputs[1]).shape;
2206                let n = shape.dim(shape.rank() - 1).unwrap_static();
2207
2208                // Detect batched matmul: any rank where both inputs
2209                // and output share the same leading batch dims and
2210                // the last 2 dims form an [M, K] @ [K, N] = [M, N].
2211                // The 2-D MatMul lowering's flatten-and-call-dgemm trick
2212                // is wrong when both operands carry independent batch
2213                // dims (per-batch K dimension differs).
2214                let batched_3d = a_shape.rank() >= 3
2215                    && b_shape.rank() == a_shape.rank()
2216                    && shape.rank() == a_shape.rank()
2217                    && {
2218                        // All leading dims (everything except last 2) match.
2219                        let mut ok = true;
2220                        for d in 0..a_shape.rank() - 2 {
2221                            if a_shape.dim(d) != b_shape.dim(d) || a_shape.dim(d) != shape.dim(d) {
2222                                ok = false;
2223                                break;
2224                            }
2225                        }
2226                        ok
2227                    };
2228                if batched_3d && shape.dtype() == rlx_ir::DType::F64 {
2229                    // Batch is the product of all leading dims (every
2230                    // dim except the last 2); m/k/n are the inner
2231                    // matmul dims. Works for any rank >= 3.
2232                    let r = shape.rank();
2233                    let mut batch_prod = 1usize;
2234                    for d in 0..r - 2 {
2235                        batch_prod *= shape.dim(d).unwrap_static();
2236                    }
2237                    let m_dim = shape.dim(r - 2).unwrap_static();
2238                    let k_dim = a_shape.dim(r - 1).unwrap_static();
2239                    debug_assert_eq!(k_dim, b_shape.dim(r - 2).unwrap_static());
2240                    Thunk::BatchedDgemmF64 {
2241                        a: node_offset(arena, node.inputs[0]),
2242                        b: node_offset(arena, node.inputs[1]),
2243                        c: node_offset(arena, node.id),
2244                        batch: batch_prod as u32,
2245                        m: m_dim as u32,
2246                        k: k_dim as u32,
2247                        n: n as u32,
2248                    }
2249                } else if batched_3d && shape.dtype() == rlx_ir::DType::F32 {
2250                    // f32 batched matmul for any rank >= 3 (collapse all
2251                    // leading batch dims into a single batch count).
2252                    let r = shape.rank();
2253                    let mut batch_prod = 1usize;
2254                    for d in 0..r - 2 {
2255                        batch_prod *= shape.dim(d).unwrap_static();
2256                    }
2257                    let m_dim = shape.dim(r - 2).unwrap_static();
2258                    let k_dim = a_shape.dim(r - 1).unwrap_static();
2259                    debug_assert_eq!(k_dim, b_shape.dim(r - 2).unwrap_static());
2260                    Thunk::BatchedSgemm {
2261                        a: node_offset(arena, node.inputs[0]),
2262                        b: node_offset(arena, node.inputs[1]),
2263                        c: node_offset(arena, node.id),
2264                        batch: batch_prod as u32,
2265                        m: m_dim as u32,
2266                        k: k_dim as u32,
2267                        n: n as u32,
2268                    }
2269                } else {
2270                    let total = shape.num_elements().unwrap();
2271                    let m = total / n;
2272                    let a_len = get_len(graph, node.inputs[0]);
2273                    let k = a_len / m;
2274                    match shape.dtype() {
2275                        rlx_ir::DType::F64 => Thunk::Dgemm {
2276                            a: node_offset(arena, node.inputs[0]),
2277                            b: node_offset(arena, node.inputs[1]),
2278                            c: node_offset(arena, node.id),
2279                            m: m as u32,
2280                            k: k as u32,
2281                            n: n as u32,
2282                        },
2283                        _ => Thunk::Sgemm {
2284                            a: node_offset(arena, node.inputs[0]),
2285                            b: node_offset(arena, node.inputs[1]),
2286                            c: node_offset(arena, node.id),
2287                            m: m as u32,
2288                            k: k as u32,
2289                            n: n as u32,
2290                        },
2291                    }
2292                }
2293            }
2294
2295            Op::Binary(op) => {
2296                let lhs_len = get_len(graph, node.inputs[0]);
2297                let rhs_len = get_len(graph, node.inputs[1]);
2298                let out_len = node.shape.num_elements().unwrap();
2299                if node.shape.dtype() == rlx_ir::DType::C64 {
2300                    // Native C64 element-wise. Add/Sub/Mul/Div lower
2301                    // to `BinaryFullC64`; the rest don't have a
2302                    // single natural complex definition.
2303                    match op {
2304                        BinaryOp::Add | BinaryOp::Sub | BinaryOp::Mul | BinaryOp::Div => {}
2305                        BinaryOp::Max | BinaryOp::Min | BinaryOp::Pow => panic!(
2306                            "Op::Binary({op:?}) on DType::C64: complex \
2307                             max/min/pow have no single natural definition \
2308                             — caller should drop to 2N-real-block (see \
2309                             spike-ac) and pick a convention there"
2310                        ),
2311                    }
2312                }
2313                // Compute broadcast strides for the slow path. Empty
2314                // vectors when no broadcast is needed (the fast-path
2315                // kernel ignores them anyway).
2316                let (out_dims_bcast, bcast_lhs_strides, bcast_rhs_strides) =
2317                    if lhs_len == out_len && rhs_len == out_len {
2318                        (Vec::new(), Vec::new(), Vec::new())
2319                    } else {
2320                        let lhs_dims = get_static_dims(graph, node.inputs[0]);
2321                        let rhs_dims = get_static_dims(graph, node.inputs[1]);
2322                        let out_dims_v = get_static_dims(graph, node.id);
2323                        if lhs_dims.is_empty() || rhs_dims.is_empty() || out_dims_v.is_empty() {
2324                            // Dynamic shape — fall back to the legacy
2325                            // modulo path (correct for scalar / last-
2326                            // axis broadcast, which is the only
2327                            // dynamic case in practice).
2328                            (Vec::new(), Vec::new(), Vec::new())
2329                        } else {
2330                            let ls = broadcast_strides(&lhs_dims, &out_dims_v);
2331                            let rs = broadcast_strides(&rhs_dims, &out_dims_v);
2332                            let od: Vec<u32> = out_dims_v.iter().map(|x| *x as u32).collect();
2333                            (od, ls, rs)
2334                        }
2335                    };
2336                if node.shape.dtype() == rlx_ir::DType::C64 {
2337                    Thunk::BinaryFullC64 {
2338                        lhs: node_offset(arena, node.inputs[0]),
2339                        rhs: node_offset(arena, node.inputs[1]),
2340                        dst: node_offset(arena, node.id),
2341                        len: out_len as u32,
2342                        lhs_len: lhs_len as u32,
2343                        rhs_len: rhs_len as u32,
2344                        op: *op,
2345                        out_dims_bcast,
2346                        bcast_lhs_strides,
2347                        bcast_rhs_strides,
2348                    }
2349                } else if node.shape.dtype() == rlx_ir::DType::F64 {
2350                    // f64 path — no BiasAdd fast-path (yet); use the
2351                    // general binary-with-broadcast kernel.
2352                    Thunk::BinaryFullF64 {
2353                        lhs: node_offset(arena, node.inputs[0]),
2354                        rhs: node_offset(arena, node.inputs[1]),
2355                        dst: node_offset(arena, node.id),
2356                        len: out_len as u32,
2357                        lhs_len: lhs_len as u32,
2358                        rhs_len: rhs_len as u32,
2359                        op: *op,
2360                        out_dims_bcast,
2361                        bcast_lhs_strides,
2362                        bcast_rhs_strides,
2363                    }
2364                } else if matches!(op, BinaryOp::Add)
2365                    && rhs_len < out_len
2366                    && out_len % rhs_len == 0
2367                    && is_trailing_bias_broadcast(
2368                        graph.node(node.inputs[1]).shape.dims(),
2369                        graph.node(node.id).shape.dims(),
2370                    )
2371                {
2372                    // `BiasAdd` is only correct when the bias is a
2373                    // *trailing* broadcast — rhs dims match the right-
2374                    // hand side of the output dims (with size-1 only
2375                    // allowed in left-padded outer positions).
2376                    // SAM's rel-pos `[bh, h, w, 1, w] + [bh, h, w, h, w]`
2377                    // has rhs_len divide out_len cleanly but is a
2378                    // mid-shape singleton, NOT a trailing broadcast.
2379                    // Routing it through BiasAdd silently treats it as
2380                    // last-`rhs_len`-cols repeated — wrong values.
2381                    Thunk::BiasAdd {
2382                        src: node_offset(arena, node.inputs[0]),
2383                        bias: node_offset(arena, node.inputs[1]),
2384                        dst: node_offset(arena, node.id),
2385                        m: (out_len / rhs_len) as u32,
2386                        n: rhs_len as u32,
2387                    }
2388                } else {
2389                    let lhs_len = get_len(graph, node.inputs[0]);
2390                    Thunk::BinaryFull {
2391                        lhs: node_offset(arena, node.inputs[0]),
2392                        rhs: node_offset(arena, node.inputs[1]),
2393                        dst: node_offset(arena, node.id),
2394                        len: out_len as u32,
2395                        lhs_len: lhs_len as u32,
2396                        rhs_len: rhs_len as u32,
2397                        op: *op,
2398                        out_dims_bcast,
2399                        bcast_lhs_strides,
2400                        bcast_rhs_strides,
2401                    }
2402                }
2403            }
2404
2405            Op::Activation(act) => {
2406                let len = node.shape.num_elements().unwrap();
2407                let in_off = node_offset(arena, node.inputs[0]);
2408                let out_off = node_offset(arena, node.id);
2409                if node.shape.dtype() == rlx_ir::DType::C64 {
2410                    // Only Neg/Exp/Log/Sqrt have natural complex
2411                    // extensions used in signal-processing graphs.
2412                    // Everything else (Sigmoid, Tanh, Relu, Abs,
2413                    // Sin/Cos/Tan/Atan, Round, GeLU family) is rejected.
2414                    match act {
2415                        Activation::Neg | Activation::Exp | Activation::Log | Activation::Sqrt => {}
2416                        other => panic!(
2417                            "Op::Activation({other:?}) on DType::C64: no \
2418                             natural complex extension — supported on C64: \
2419                             Neg, Exp, Log, Sqrt"
2420                        ),
2421                    }
2422                    Thunk::ActivationC64 {
2423                        src: in_off,
2424                        dst: out_off,
2425                        len: len as u32,
2426                        kind: *act,
2427                    }
2428                } else if node.shape.dtype() == rlx_ir::DType::F64 {
2429                    Thunk::ActivationF64 {
2430                        src: in_off,
2431                        dst: out_off,
2432                        len: len as u32,
2433                        kind: *act,
2434                    }
2435                } else if in_off == out_off {
2436                    // ActivationInPlace operates on a single buffer. When the
2437                    // planner has assigned input and output the same slot
2438                    // (typical post-fusion case), we just run on that slot.
2439                    Thunk::ActivationInPlace {
2440                        data: out_off,
2441                        len: len as u32,
2442                        act: *act,
2443                    }
2444                } else {
2445                    // Two-step: copy input → output, then activate output in place.
2446                    // The schedule executes them in this order; downstream
2447                    // thunks see the activated output at out_off.
2448                    thunks.push(Thunk::Copy {
2449                        src: in_off,
2450                        dst: out_off,
2451                        len: len as u32,
2452                    });
2453                    Thunk::ActivationInPlace {
2454                        data: out_off,
2455                        len: len as u32,
2456                        act: *act,
2457                    }
2458                }
2459            }
2460
2461            Op::Gather { axis } if *axis == 0 => {
2462                let table_shape = &graph.node(node.inputs[0]).shape;
2463                let table_total = table_shape.num_elements().unwrap();
2464                let trailing: usize = (1..table_shape.rank())
2465                    .map(|i| table_shape.dim(i).unwrap_static())
2466                    .product();
2467                let idx_len = get_len(graph, node.inputs[1]);
2468                Thunk::Gather {
2469                    table: node_offset(arena, node.inputs[0]),
2470                    table_len: table_total as u32,
2471                    idx: node_offset(arena, node.inputs[1]),
2472                    dst: node_offset(arena, node.id),
2473                    num_idx: idx_len as u32,
2474                    trailing: trailing as u32,
2475                }
2476            }
2477
2478            Op::Gather { axis } => {
2479                // Non-zero axis: outer × num_idx × trailing layout.
2480                let table_shape = &graph.node(node.inputs[0]).shape;
2481                let rank = table_shape.rank();
2482                let outer: usize = (0..*axis)
2483                    .map(|i| table_shape.dim(i).unwrap_static())
2484                    .product::<usize>()
2485                    .max(1);
2486                let trailing: usize = (*axis + 1..rank)
2487                    .map(|i| table_shape.dim(i).unwrap_static())
2488                    .product::<usize>()
2489                    .max(1);
2490                let axis_dim = table_shape.dim(*axis).unwrap_static();
2491                let idx_len = get_len(graph, node.inputs[1]);
2492                Thunk::GatherAxis {
2493                    table: node_offset(arena, node.inputs[0]),
2494                    idx: node_offset(arena, node.inputs[1]),
2495                    dst: node_offset(arena, node.id),
2496                    outer: outer as u32,
2497                    axis_dim: axis_dim as u32,
2498                    num_idx: idx_len as u32,
2499                    trailing: trailing as u32,
2500                }
2501            }
2502
2503            Op::Narrow { axis, start, len } => {
2504                let in_shape = &graph.node(node.inputs[0]).shape;
2505                let elem_bytes = in_shape.dtype().size_bytes() as u8;
2506                let rank = in_shape.rank();
2507                let outer: usize = (0..*axis)
2508                    .map(|i| in_shape.dim(i).unwrap_static())
2509                    .product::<usize>()
2510                    .max(1);
2511                let inner: usize = (*axis + 1..rank)
2512                    .map(|i| in_shape.dim(i).unwrap_static())
2513                    .product::<usize>()
2514                    .max(1);
2515                let in_axis = in_shape.dim(*axis).unwrap_static();
2516                let src_byte_offset =
2517                    node_offset(arena, node.inputs[0]) + start * inner * elem_bytes as usize;
2518                Thunk::Narrow {
2519                    src: src_byte_offset,
2520                    dst: node_offset(arena, node.id),
2521                    outer: outer as u32,
2522                    src_stride: (in_axis * inner) as u32, // elements per outer step in source
2523                    dst_stride: (*len * inner) as u32,    // elements per outer step in dest
2524                    inner: (*len * inner) as u32,         // elements to copy per outer step
2525                    elem_bytes,
2526                }
2527            }
2528
2529            Op::Reshape { .. } | Op::Cast { .. } => {
2530                // Pure layout/dtype change: same total element count, plain copy.
2531                let len = node.shape.num_elements().unwrap();
2532                let src = node_offset(arena, node.inputs[0]);
2533                let dst = node_offset(arena, node.id);
2534                match node.shape.dtype() {
2535                    rlx_ir::DType::F64 => Thunk::CopyF64 {
2536                        src,
2537                        dst,
2538                        len: len as u32,
2539                    },
2540                    _ => Thunk::Copy {
2541                        src,
2542                        dst,
2543                        len: len as u32,
2544                    },
2545                }
2546            }
2547
2548            Op::Quantize {
2549                axis,
2550                scales,
2551                zero_points,
2552            } => {
2553                let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2554                Thunk::Quantize {
2555                    x: node_offset(arena, node.inputs[0]),
2556                    q: node_offset(arena, node.id),
2557                    len: node.shape.num_elements().unwrap() as u32,
2558                    chan_axis: chan_axis as u32,
2559                    chan_dim: chan_dim as u32,
2560                    inner: inner as u32,
2561                    scales: scales.clone(),
2562                    zero_points: zero_points.clone(),
2563                }
2564            }
2565
2566            Op::FakeQuantize {
2567                bits,
2568                axis,
2569                ste,
2570                scale_mode,
2571            } => {
2572                let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2573                let state_off = match scale_mode {
2574                    rlx_ir::op::ScaleMode::PerBatch => None,
2575                    rlx_ir::op::ScaleMode::EMA { .. } | rlx_ir::op::ScaleMode::Fixed => {
2576                        // Second input carries the [chan_dim] scale state.
2577                        debug_assert_eq!(
2578                            node.inputs.len(),
2579                            2,
2580                            "EMA/Fixed FakeQuantize needs a state input"
2581                        );
2582                        Some(node_offset(arena, node.inputs[1]))
2583                    }
2584                };
2585                Thunk::FakeQuantize {
2586                    x: node_offset(arena, node.inputs[0]),
2587                    out: node_offset(arena, node.id),
2588                    len: node.shape.num_elements().unwrap() as u32,
2589                    chan_axis: chan_axis as u32,
2590                    chan_dim: chan_dim as u32,
2591                    inner: inner as u32,
2592                    bits: *bits,
2593                    ste: *ste,
2594                    scale_mode: *scale_mode,
2595                    state_off,
2596                }
2597            }
2598
2599            Op::FakeQuantizeLSQ { bits, axis } => {
2600                let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2601                Thunk::FakeQuantizeLSQ {
2602                    x: node_offset(arena, node.inputs[0]),
2603                    scale_off: node_offset(arena, node.inputs[1]),
2604                    out: node_offset(arena, node.id),
2605                    len: node.shape.num_elements().unwrap() as u32,
2606                    chan_axis: chan_axis as u32,
2607                    chan_dim: chan_dim as u32,
2608                    inner: inner as u32,
2609                    bits: *bits,
2610                }
2611            }
2612
2613            Op::FakeQuantizeLSQBackwardX { bits, axis } => {
2614                let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2615                Thunk::FakeQuantizeLSQBackwardX {
2616                    x: node_offset(arena, node.inputs[0]),
2617                    scale_off: node_offset(arena, node.inputs[1]),
2618                    dy: node_offset(arena, node.inputs[2]),
2619                    dx: node_offset(arena, node.id),
2620                    len: node.shape.num_elements().unwrap() as u32,
2621                    chan_axis: chan_axis as u32,
2622                    chan_dim: chan_dim as u32,
2623                    inner: inner as u32,
2624                    bits: *bits,
2625                }
2626            }
2627
2628            Op::FakeQuantizeLSQBackwardScale { bits, axis } => {
2629                // Output shape is [chan_dim] — node.shape doesn't
2630                // describe the input data layout, but inputs[0] does.
2631                let in_shape = &graph.node(node.inputs[0]).shape;
2632                let (chan_axis, chan_dim, inner) = quant_layout(in_shape, *axis);
2633                Thunk::FakeQuantizeLSQBackwardScale {
2634                    x: node_offset(arena, node.inputs[0]),
2635                    scale_off: node_offset(arena, node.inputs[1]),
2636                    dy: node_offset(arena, node.inputs[2]),
2637                    dscale: node_offset(arena, node.id),
2638                    len: in_shape.num_elements().unwrap() as u32,
2639                    chan_axis: chan_axis as u32,
2640                    chan_dim: chan_dim as u32,
2641                    inner: inner as u32,
2642                    bits: *bits,
2643                }
2644            }
2645
2646            Op::FakeQuantizeBackward { bits, axis, ste } => {
2647                let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2648                Thunk::FakeQuantizeBackward {
2649                    x: node_offset(arena, node.inputs[0]),
2650                    dy: node_offset(arena, node.inputs[1]),
2651                    dx: node_offset(arena, node.id),
2652                    len: node.shape.num_elements().unwrap() as u32,
2653                    chan_axis: chan_axis as u32,
2654                    chan_dim: chan_dim as u32,
2655                    inner: inner as u32,
2656                    bits: *bits,
2657                    ste: *ste,
2658                }
2659            }
2660
2661            Op::Dequantize {
2662                axis,
2663                scales,
2664                zero_points,
2665            } => {
2666                let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2667                Thunk::Dequantize {
2668                    q: node_offset(arena, node.inputs[0]),
2669                    x: node_offset(arena, node.id),
2670                    len: node.shape.num_elements().unwrap() as u32,
2671                    chan_axis: chan_axis as u32,
2672                    chan_dim: chan_dim as u32,
2673                    inner: inner as u32,
2674                    scales: scales.clone(),
2675                    zero_points: zero_points.clone(),
2676                }
2677            }
2678
2679            Op::Expand { .. } => {
2680                // Broadcast: build per-output-dim strides where any input dim
2681                // of size 1 has stride 0 (read the same element repeatedly).
2682                // Reuses the Thunk::Transpose runtime — N-D walk with strides
2683                // is identical; only the strides differ.
2684                let in_shape = &graph.node(node.inputs[0]).shape;
2685                let out_shape = &node.shape;
2686                let in_rank = in_shape.rank();
2687                let out_rank = out_shape.rank();
2688                // Implicit leading 1s if input has lower rank.
2689                let pad = out_rank.saturating_sub(in_rank);
2690                let in_dims: Vec<usize> = (0..out_rank)
2691                    .map(|i| {
2692                        if i < pad {
2693                            1
2694                        } else {
2695                            in_shape.dim(i - pad).unwrap_static()
2696                        }
2697                    })
2698                    .collect();
2699                // Row-major input strides (over the padded shape).
2700                let mut in_strides_full = vec![1usize; out_rank];
2701                for d in (0..out_rank.saturating_sub(1)).rev() {
2702                    in_strides_full[d] = in_strides_full[d + 1] * in_dims[d + 1];
2703                }
2704                let out_dims: Vec<u32> = (0..out_rank)
2705                    .map(|i| out_shape.dim(i).unwrap_static() as u32)
2706                    .collect();
2707                // Stride is 0 for broadcast dims (in_dim == 1 && out_dim > 1).
2708                let in_strides: Vec<u32> = (0..out_rank)
2709                    .map(|i| {
2710                        if in_dims[i] == 1 && (out_dims[i] as usize) > 1 {
2711                            0
2712                        } else {
2713                            in_strides_full[i] as u32
2714                        }
2715                    })
2716                    .collect();
2717                let in_total = in_dims.iter().product::<usize>() as u32;
2718                let src = node_offset(arena, node.inputs[0]);
2719                let dst = node_offset(arena, node.id);
2720                match node.shape.dtype() {
2721                    rlx_ir::DType::F64 => Thunk::TransposeF64 {
2722                        src,
2723                        dst,
2724                        in_total,
2725                        out_dims,
2726                        in_strides,
2727                    },
2728                    _ => Thunk::Transpose {
2729                        src,
2730                        dst,
2731                        in_total,
2732                        out_dims,
2733                        in_strides,
2734                    },
2735                }
2736            }
2737
2738            Op::RmsNorm { eps, .. } => {
2739                let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
2740                let total = node.shape.num_elements().unwrap();
2741                Thunk::RmsNorm {
2742                    src: node_offset(arena, node.inputs[0]),
2743                    g: node_offset(arena, node.inputs[1]),
2744                    b: node_offset(arena, node.inputs[2]),
2745                    dst: node_offset(arena, node.id),
2746                    rows: (total / h) as u32,
2747                    h: h as u32,
2748                    eps: *eps,
2749                }
2750            }
2751
2752            Op::LayerNorm { eps, .. } => {
2753                let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
2754                let total = node.shape.num_elements().unwrap();
2755                Thunk::LayerNorm {
2756                    src: node_offset(arena, node.inputs[0]),
2757                    g: node_offset(arena, node.inputs[1]),
2758                    b: node_offset(arena, node.inputs[2]),
2759                    dst: node_offset(arena, node.id),
2760                    rows: (total / h) as u32,
2761                    h: h as u32,
2762                    eps: *eps,
2763                }
2764            }
2765
2766            Op::GroupNorm { num_groups, eps } => {
2767                let in_shape = &graph.node(node.inputs[0]).shape;
2768                Thunk::GroupNorm {
2769                    src: node_offset(arena, node.inputs[0]),
2770                    g: node_offset(arena, node.inputs[1]),
2771                    b: node_offset(arena, node.inputs[2]),
2772                    dst: node_offset(arena, node.id),
2773                    n: in_shape.dim(0).unwrap_static() as u32,
2774                    c: in_shape.dim(1).unwrap_static() as u32,
2775                    h: in_shape.dim(2).unwrap_static() as u32,
2776                    w: in_shape.dim(3).unwrap_static() as u32,
2777                    num_groups: *num_groups as u32,
2778                    eps: *eps,
2779                }
2780            }
2781
2782            Op::LayerNorm2d { eps } => {
2783                let in_shape = &graph.node(node.inputs[0]).shape;
2784                Thunk::LayerNorm2d {
2785                    src: node_offset(arena, node.inputs[0]),
2786                    g: node_offset(arena, node.inputs[1]),
2787                    b: node_offset(arena, node.inputs[2]),
2788                    dst: node_offset(arena, node.id),
2789                    n: in_shape.dim(0).unwrap_static() as u32,
2790                    c: in_shape.dim(1).unwrap_static() as u32,
2791                    h: in_shape.dim(2).unwrap_static() as u32,
2792                    w: in_shape.dim(3).unwrap_static() as u32,
2793                    eps: *eps,
2794                }
2795            }
2796
2797            Op::ConvTranspose2d {
2798                kernel_size,
2799                stride,
2800                padding,
2801                dilation,
2802                output_padding: _,
2803                groups,
2804            } => {
2805                let in_shape = &graph.node(node.inputs[0]).shape;
2806                let out_shape = &node.shape;
2807                Thunk::ConvTranspose2d {
2808                    src: node_offset(arena, node.inputs[0]),
2809                    weight: node_offset(arena, node.inputs[1]),
2810                    dst: node_offset(arena, node.id),
2811                    n: in_shape.dim(0).unwrap_static() as u32,
2812                    c_in: in_shape.dim(1).unwrap_static() as u32,
2813                    h: in_shape.dim(2).unwrap_static() as u32,
2814                    w_in: in_shape.dim(3).unwrap_static() as u32,
2815                    c_out: out_shape.dim(1).unwrap_static() as u32,
2816                    h_out: out_shape.dim(2).unwrap_static() as u32,
2817                    w_out: out_shape.dim(3).unwrap_static() as u32,
2818                    kh: kernel_size[0] as u32,
2819                    kw: kernel_size[1] as u32,
2820                    sh: stride.first().copied().unwrap_or(1) as u32,
2821                    sw: stride.get(1).copied().unwrap_or(1) as u32,
2822                    ph: padding.first().copied().unwrap_or(0) as u32,
2823                    pw: padding.get(1).copied().unwrap_or(0) as u32,
2824                    dh: dilation.first().copied().unwrap_or(1) as u32,
2825                    dw: dilation.get(1).copied().unwrap_or(1) as u32,
2826                    groups: *groups as u32,
2827                }
2828            }
2829
2830            Op::ResizeNearest2x => {
2831                let in_shape = &graph.node(node.inputs[0]).shape;
2832                Thunk::ResizeNearest2x {
2833                    src: node_offset(arena, node.inputs[0]),
2834                    dst: node_offset(arena, node.id),
2835                    n: in_shape.dim(0).unwrap_static() as u32,
2836                    c: in_shape.dim(1).unwrap_static() as u32,
2837                    h: in_shape.dim(2).unwrap_static() as u32,
2838                    w: in_shape.dim(3).unwrap_static() as u32,
2839                }
2840            }
2841
2842            Op::AxialRope2d {
2843                end_x,
2844                end_y,
2845                head_dim,
2846                num_heads,
2847                theta,
2848                repeat_factor,
2849            } => {
2850                let in_shape = &graph.node(node.inputs[0]).shape;
2851                let batch = in_shape.dim(0).unwrap_static() as u32;
2852                let seq = in_shape.dim(1).unwrap_static() as u32;
2853                let hidden = in_shape.dim(2).unwrap_static() as u32;
2854                Thunk::AxialRope2d {
2855                    src: node_offset(arena, node.inputs[0]),
2856                    dst: node_offset(arena, node.id),
2857                    batch,
2858                    seq,
2859                    hidden,
2860                    end_x: *end_x as u32,
2861                    end_y: *end_y as u32,
2862                    head_dim: *head_dim as u32,
2863                    num_heads: *num_heads as u32,
2864                    theta: *theta,
2865                    repeat_factor: *repeat_factor as u32,
2866                }
2867            }
2868
2869            Op::Softmax { axis } => {
2870                let rank = node.shape.rank();
2871                let ax = if *axis < 0 {
2872                    (rank as i32 + axis) as usize
2873                } else {
2874                    *axis as usize
2875                };
2876                let cols = node.shape.dim(ax).unwrap_static();
2877                let total = node.shape.num_elements().unwrap();
2878                let in_off = node_offset(arena, node.inputs[0]);
2879                let out_off = node_offset(arena, node.id);
2880                // Softmax kernel runs in-place on its data buffer. If the
2881                // planner gave input and output separate slots (their live
2882                // ranges overlap, so no aliasing), the output starts
2883                // uninitialized — emit a Copy first so the data is there.
2884                // Same pattern as Op::Activation.
2885                if in_off != out_off {
2886                    thunks.push(Thunk::Copy {
2887                        src: in_off,
2888                        dst: out_off,
2889                        len: total as u32,
2890                    });
2891                }
2892                Thunk::Softmax {
2893                    data: out_off,
2894                    rows: (total / cols) as u32,
2895                    cols: cols as u32,
2896                }
2897            }
2898
2899            Op::SelectiveScan { state_size } => {
2900                let in_shape = &graph.node(node.inputs[0]).shape;
2901                let (batch, seq, hidden) = (
2902                    in_shape.dim(0).unwrap_static(),
2903                    in_shape.dim(1).unwrap_static(),
2904                    in_shape.dim(2).unwrap_static(),
2905                );
2906                Thunk::SelectiveScan {
2907                    x: node_offset(arena, node.inputs[0]),
2908                    delta: node_offset(arena, node.inputs[1]),
2909                    a: node_offset(arena, node.inputs[2]),
2910                    b: node_offset(arena, node.inputs[3]),
2911                    c: node_offset(arena, node.inputs[4]),
2912                    dst: node_offset(arena, node.id),
2913                    batch: batch as u32,
2914                    seq: seq as u32,
2915                    hidden: hidden as u32,
2916                    state_size: *state_size as u32,
2917                }
2918            }
2919
2920            Op::GatedDeltaNet {
2921                state_size,
2922                carry_state,
2923            } => {
2924                let q_shape = &graph.node(node.inputs[0]).shape;
2925                let (batch, seq, heads) = (
2926                    q_shape.dim(0).unwrap_static(),
2927                    q_shape.dim(1).unwrap_static(),
2928                    q_shape.dim(2).unwrap_static(),
2929                );
2930                let state_off = if *carry_state {
2931                    node_offset(arena, node.inputs[5])
2932                } else {
2933                    0
2934                };
2935                Thunk::GatedDeltaNet {
2936                    q: node_offset(arena, node.inputs[0]),
2937                    k: node_offset(arena, node.inputs[1]),
2938                    v: node_offset(arena, node.inputs[2]),
2939                    g: node_offset(arena, node.inputs[3]),
2940                    beta: node_offset(arena, node.inputs[4]),
2941                    state: state_off,
2942                    dst: node_offset(arena, node.id),
2943                    batch: batch as u32,
2944                    seq: seq as u32,
2945                    heads: heads as u32,
2946                    state_size: *state_size as u32,
2947                }
2948            }
2949
2950            Op::QMatMul {
2951                x_zp,
2952                w_zp,
2953                out_zp,
2954                mult,
2955            } => {
2956                let x_shape = &graph.node(node.inputs[0]).shape;
2957                let w_shape = &graph.node(node.inputs[1]).shape;
2958                let m = x_shape.dim(0).unwrap_static();
2959                let k = x_shape.dim(1).unwrap_static();
2960                let n = w_shape.dim(1).unwrap_static();
2961                Thunk::QMatMul {
2962                    x: node_offset(arena, node.inputs[0]),
2963                    w: node_offset(arena, node.inputs[1]),
2964                    bias: node_offset(arena, node.inputs[2]),
2965                    out: node_offset(arena, node.id),
2966                    m: m as u32,
2967                    k: k as u32,
2968                    n: n as u32,
2969                    x_zp: *x_zp,
2970                    w_zp: *w_zp,
2971                    out_zp: *out_zp,
2972                    mult: *mult,
2973                }
2974            }
2975
2976            Op::QConv2d {
2977                kernel_size,
2978                stride,
2979                padding,
2980                dilation,
2981                groups,
2982                x_zp,
2983                w_zp,
2984                out_zp,
2985                mult,
2986            } => {
2987                let in_shape = &graph.node(node.inputs[0]).shape;
2988                let w_shape = &graph.node(node.inputs[1]).shape;
2989                let out_shape = &node.shape;
2990                if kernel_size.len() == 2
2991                    && in_shape.rank() == 4
2992                    && w_shape.rank() == 4
2993                    && out_shape.rank() == 4
2994                {
2995                    Thunk::QConv2d {
2996                        x: node_offset(arena, node.inputs[0]),
2997                        w: node_offset(arena, node.inputs[1]),
2998                        bias: node_offset(arena, node.inputs[2]),
2999                        out: node_offset(arena, node.id),
3000                        n: in_shape.dim(0).unwrap_static() as u32,
3001                        c_in: in_shape.dim(1).unwrap_static() as u32,
3002                        h: in_shape.dim(2).unwrap_static() as u32,
3003                        w_in: in_shape.dim(3).unwrap_static() as u32,
3004                        c_out: out_shape.dim(1).unwrap_static() as u32,
3005                        h_out: out_shape.dim(2).unwrap_static() as u32,
3006                        w_out: out_shape.dim(3).unwrap_static() as u32,
3007                        kh: kernel_size[0] as u32,
3008                        kw: kernel_size[1] as u32,
3009                        sh: stride.first().copied().unwrap_or(1) as u32,
3010                        sw: stride.get(1).copied().unwrap_or(1) as u32,
3011                        ph: padding.first().copied().unwrap_or(0) as u32,
3012                        pw: padding.get(1).copied().unwrap_or(0) as u32,
3013                        dh: dilation.first().copied().unwrap_or(1) as u32,
3014                        dw: dilation.get(1).copied().unwrap_or(1) as u32,
3015                        groups: *groups as u32,
3016                        x_zp: *x_zp,
3017                        w_zp: *w_zp,
3018                        out_zp: *out_zp,
3019                        mult: *mult,
3020                    }
3021                } else {
3022                    Thunk::Nop
3023                }
3024            }
3025
3026            Op::DequantMatMul { scheme } => {
3027                use rlx_ir::quant::QuantScheme;
3028                let n = node.shape.dim(node.shape.rank() - 1).unwrap_static();
3029                let total = node.shape.num_elements().unwrap();
3030                let m = total / n.max(1);
3031                let x_total = graph.node(node.inputs[0]).shape.num_elements().unwrap();
3032                let k = x_total / m.max(1);
3033                if scheme.is_gguf() {
3034                    Thunk::DequantMatMulGguf {
3035                        x: node_offset(arena, node.inputs[0]),
3036                        w_q: node_offset(arena, node.inputs[1]),
3037                        dst: node_offset(arena, node.id),
3038                        m: m as u32,
3039                        k: k as u32,
3040                        n: n as u32,
3041                        scheme: *scheme,
3042                    }
3043                } else {
3044                    match scheme {
3045                        QuantScheme::Nvfp4Block => Thunk::DequantMatMulNvfp4 {
3046                            x: node_offset(arena, node.inputs[0]),
3047                            w_q: node_offset(arena, node.inputs[1]),
3048                            scale: node_offset(arena, node.inputs[2]),
3049                            global_scale: node_offset(arena, node.inputs[3]),
3050                            dst: node_offset(arena, node.id),
3051                            m: m as u32,
3052                            k: k as u32,
3053                            n: n as u32,
3054                        },
3055                        QuantScheme::Int4Block { block_size } => Thunk::DequantMatMulInt4 {
3056                            x: node_offset(arena, node.inputs[0]),
3057                            w_q: node_offset(arena, node.inputs[1]),
3058                            scale: node_offset(arena, node.inputs[2]),
3059                            zp: node_offset(arena, node.inputs[3]),
3060                            dst: node_offset(arena, node.id),
3061                            m: m as u32,
3062                            k: k as u32,
3063                            n: n as u32,
3064                            block_size: *block_size,
3065                            is_asymmetric: false,
3066                        },
3067                        QuantScheme::Fp8E4m3 => Thunk::DequantMatMulFp8 {
3068                            x: node_offset(arena, node.inputs[0]),
3069                            w_q: node_offset(arena, node.inputs[1]),
3070                            scale: node_offset(arena, node.inputs[2]),
3071                            dst: node_offset(arena, node.id),
3072                            m: m as u32,
3073                            k: k as u32,
3074                            n: n as u32,
3075                            e5m2: false,
3076                        },
3077                        QuantScheme::Fp8E5m2 => Thunk::DequantMatMulFp8 {
3078                            x: node_offset(arena, node.inputs[0]),
3079                            w_q: node_offset(arena, node.inputs[1]),
3080                            scale: node_offset(arena, node.inputs[2]),
3081                            dst: node_offset(arena, node.id),
3082                            m: m as u32,
3083                            k: k as u32,
3084                            n: n as u32,
3085                            e5m2: true,
3086                        },
3087                        QuantScheme::Int8Block { block_size } => Thunk::DequantMatMul {
3088                            x: node_offset(arena, node.inputs[0]),
3089                            w_q: node_offset(arena, node.inputs[1]),
3090                            scale: node_offset(arena, node.inputs[2]),
3091                            zp: node_offset(arena, node.inputs[3]),
3092                            dst: node_offset(arena, node.id),
3093                            m: m as u32,
3094                            k: k as u32,
3095                            n: n as u32,
3096                            block_size: *block_size,
3097                            is_asymmetric: false,
3098                        },
3099                        QuantScheme::Int8BlockAsym { block_size } => Thunk::DequantMatMul {
3100                            x: node_offset(arena, node.inputs[0]),
3101                            w_q: node_offset(arena, node.inputs[1]),
3102                            scale: node_offset(arena, node.inputs[2]),
3103                            zp: node_offset(arena, node.inputs[3]),
3104                            dst: node_offset(arena, node.id),
3105                            m: m as u32,
3106                            k: k as u32,
3107                            n: n as u32,
3108                            block_size: *block_size,
3109                            is_asymmetric: true,
3110                        },
3111                        other => panic!(
3112                            "DequantMatMul on CPU supports Int8/Int4/FP8/NVFP4 legacy or GGUF schemes; got {other}"
3113                        ),
3114                    }
3115                }
3116            }
3117
3118            Op::LoraMatMul { scale } => {
3119                // x [m, k], w [k, n], a [k, r], b [r, n].
3120                let n = node.shape.dim(node.shape.rank() - 1).unwrap_static();
3121                let total = node.shape.num_elements().unwrap();
3122                let m = total / n.max(1);
3123                let x_total = graph.node(node.inputs[0]).shape.num_elements().unwrap();
3124                let k = x_total / m.max(1);
3125                let a_total = graph.node(node.inputs[2]).shape.num_elements().unwrap();
3126                let r = a_total / k.max(1);
3127                Thunk::LoraMatMul {
3128                    x: node_offset(arena, node.inputs[0]),
3129                    w: node_offset(arena, node.inputs[1]),
3130                    a: node_offset(arena, node.inputs[2]),
3131                    b: node_offset(arena, node.inputs[3]),
3132                    dst: node_offset(arena, node.id),
3133                    m: m as u32,
3134                    k: k as u32,
3135                    n: n as u32,
3136                    r: r as u32,
3137                    scale: *scale,
3138                }
3139            }
3140
3141            Op::Sample {
3142                top_k,
3143                top_p,
3144                temperature,
3145                seed,
3146            } => {
3147                let in_shape = &graph.node(node.inputs[0]).shape;
3148                // Logits are [batch, vocab] (or [vocab] → batch=1).
3149                let (batch, vocab) = if in_shape.rank() >= 2 {
3150                    (
3151                        in_shape.dim(0).unwrap_static(),
3152                        in_shape.dim(in_shape.rank() - 1).unwrap_static(),
3153                    )
3154                } else {
3155                    (1, in_shape.num_elements().unwrap_or(0))
3156                };
3157                Thunk::Sample {
3158                    logits: node_offset(arena, node.inputs[0]),
3159                    dst: node_offset(arena, node.id),
3160                    batch: batch as u32,
3161                    vocab: vocab as u32,
3162                    top_k: *top_k as u32,
3163                    top_p: *top_p,
3164                    temperature: *temperature,
3165                    seed: *seed,
3166                }
3167            }
3168
3169            Op::Cumsum { axis, exclusive } => {
3170                // For now CPU only supports last-axis cumsum (the
3171                // common case for sampling / ragged offsets).
3172                // Other axes can lower via Transpose → Cumsum →
3173                // Transpose; not on the hot path today.
3174                let rank = node.shape.rank();
3175                let ax = if *axis < 0 {
3176                    (rank as i32 + axis) as usize
3177                } else {
3178                    *axis as usize
3179                };
3180                assert_eq!(
3181                    ax,
3182                    rank - 1,
3183                    "Cumsum only supports the last axis on CPU today"
3184                );
3185                let cols = node.shape.dim(ax).unwrap_static();
3186                let total = node.shape.num_elements().unwrap();
3187                Thunk::Cumsum {
3188                    src: node_offset(arena, node.inputs[0]),
3189                    dst: node_offset(arena, node.id),
3190                    rows: (total / cols) as u32,
3191                    cols: cols as u32,
3192                    exclusive: *exclusive,
3193                }
3194            }
3195
3196            Op::Attention {
3197                num_heads,
3198                head_dim,
3199                mask_kind,
3200                score_scale: _,
3201                attn_logit_softcap: _,
3202            } => {
3203                // Layout dispatch: rank-4 input could be either
3204                // `[B, S, H, D]` (CPU's historical convention) or
3205                // `[B, H, S, D]` (the convention the GPU/TPU backends
3206                // share). Disambiguate by which axis matches
3207                // `num_heads`. Rank-3 is always `[B, S, H*D]`.
3208                let q_shape = &graph.node(node.inputs[0]).shape;
3209                let k_shape = &graph.node(node.inputs[1]).shape;
3210                let rank = q_shape.rank();
3211                let (batch, seq, kv_seq, bhsd) = if rank == 4 {
3212                    let d1 = q_shape.dim(1).unwrap_static();
3213                    let d2 = q_shape.dim(2).unwrap_static();
3214                    if d1 == *num_heads {
3215                        // [B, H, S, D]
3216                        (
3217                            q_shape.dim(0).unwrap_static(),
3218                            d2,
3219                            k_shape.dim(2).unwrap_static(),
3220                            true,
3221                        )
3222                    } else {
3223                        // [B, S, H, D]
3224                        (
3225                            q_shape.dim(0).unwrap_static(),
3226                            d1,
3227                            k_shape.dim(1).unwrap_static(),
3228                            false,
3229                        )
3230                    }
3231                } else if rank >= 3 {
3232                    (
3233                        q_shape.dim(0).unwrap_static(),
3234                        q_shape.dim(1).unwrap_static(),
3235                        k_shape.dim(1).unwrap_static(),
3236                        false,
3237                    )
3238                } else {
3239                    (
3240                        1,
3241                        q_shape.dim(0).unwrap_static(),
3242                        k_shape.dim(0).unwrap_static(),
3243                        false,
3244                    )
3245                };
3246                let mask_off = if matches!(
3247                    mask_kind,
3248                    rlx_ir::op::MaskKind::Custom | rlx_ir::op::MaskKind::Bias
3249                ) {
3250                    node_offset(arena, node.inputs[3])
3251                } else {
3252                    0
3253                };
3254                let hs = (*num_heads * *head_dim) as u32;
3255                Thunk::Attention {
3256                    q: node_offset(arena, node.inputs[0]),
3257                    k: node_offset(arena, node.inputs[1]),
3258                    v: node_offset(arena, node.inputs[2]),
3259                    mask: mask_off,
3260                    out: node_offset(arena, node.id),
3261                    batch: batch as u32,
3262                    seq: seq as u32,
3263                    kv_seq: kv_seq as u32,
3264                    heads: *num_heads as u32,
3265                    head_dim: *head_dim as u32,
3266                    mask_kind: *mask_kind,
3267                    // Defaults: each input is its own contiguous buffer
3268                    // with row stride = hidden. Rewritten by the
3269                    // Narrow→Attention fusion when applicable.
3270                    q_row_stride: hs,
3271                    k_row_stride: hs,
3272                    v_row_stride: hs,
3273                    bhsd,
3274                }
3275            }
3276
3277            Op::AttentionBackward {
3278                num_heads,
3279                head_dim,
3280                mask_kind,
3281                wrt,
3282            } => {
3283                let q_shape = &graph.node(node.inputs[0]).shape;
3284                let k_shape = &graph.node(node.inputs[1]).shape;
3285                let rank = q_shape.rank();
3286                let (batch, seq, kv_seq, bhsd) = if rank == 4 {
3287                    let d1 = q_shape.dim(1).unwrap_static();
3288                    let d2 = q_shape.dim(2).unwrap_static();
3289                    if d1 == *num_heads {
3290                        (
3291                            q_shape.dim(0).unwrap_static(),
3292                            d2,
3293                            k_shape.dim(2).unwrap_static(),
3294                            true,
3295                        )
3296                    } else {
3297                        (
3298                            q_shape.dim(0).unwrap_static(),
3299                            d1,
3300                            k_shape.dim(1).unwrap_static(),
3301                            false,
3302                        )
3303                    }
3304                } else if rank >= 3 {
3305                    (
3306                        q_shape.dim(0).unwrap_static(),
3307                        q_shape.dim(1).unwrap_static(),
3308                        k_shape.dim(1).unwrap_static(),
3309                        false,
3310                    )
3311                } else {
3312                    (
3313                        1,
3314                        q_shape.dim(0).unwrap_static(),
3315                        k_shape.dim(0).unwrap_static(),
3316                        false,
3317                    )
3318                };
3319                let mask_off = if matches!(
3320                    mask_kind,
3321                    rlx_ir::op::MaskKind::Custom | rlx_ir::op::MaskKind::Bias
3322                ) {
3323                    node_offset(arena, node.inputs[4])
3324                } else {
3325                    0
3326                };
3327                Thunk::AttentionBackward {
3328                    q: node_offset(arena, node.inputs[0]),
3329                    k: node_offset(arena, node.inputs[1]),
3330                    v: node_offset(arena, node.inputs[2]),
3331                    dy: node_offset(arena, node.inputs[3]),
3332                    mask: mask_off,
3333                    out: node_offset(arena, node.id),
3334                    batch: batch as u32,
3335                    seq: seq as u32,
3336                    kv_seq: kv_seq as u32,
3337                    heads: *num_heads as u32,
3338                    head_dim: *head_dim as u32,
3339                    mask_kind: *mask_kind,
3340                    wrt: *wrt,
3341                    bhsd,
3342                }
3343            }
3344
3345            Op::FusedAttentionBlock {
3346                num_heads,
3347                head_dim,
3348                has_bias,
3349                has_rope,
3350            } => {
3351                let x_shape = &graph.node(node.inputs[0]).shape;
3352                let (batch, seq) = if x_shape.rank() >= 3 {
3353                    (
3354                        x_shape.dim(0).unwrap_static(),
3355                        x_shape.dim(1).unwrap_static(),
3356                    )
3357                } else {
3358                    let total = x_shape.num_elements().unwrap();
3359                    let s = x_shape.dim(x_shape.rank() - 2).unwrap_static();
3360                    (total / (s * num_heads * head_dim), s)
3361                };
3362                let hs = (*num_heads * *head_dim) as u32;
3363                // Inputs: hidden, qkv_w, out_w, mask, [qkv_b, out_b], [cos, sin]
3364                let mut idx = 4;
3365                let (qkv_b_off, out_b_off) = if *has_bias {
3366                    let qb = node_offset(arena, node.inputs[idx]);
3367                    let ob = node_offset(arena, node.inputs[idx + 1]);
3368                    idx += 2;
3369                    (qb, ob)
3370                } else {
3371                    (0, 0)
3372                };
3373                let (cos_off, sin_off, cl) = if *has_rope {
3374                    let c = node_offset(arena, node.inputs[idx]);
3375                    let s = node_offset(arena, node.inputs[idx + 1]);
3376                    let clen = get_len(graph, node.inputs[idx]);
3377                    (c, s, clen as u32)
3378                } else {
3379                    (0, 0, 0)
3380                };
3381
3382                Thunk::FusedAttnBlock {
3383                    hidden: node_offset(arena, node.inputs[0]),
3384                    qkv_w: node_offset(arena, node.inputs[1]),
3385                    out_w: node_offset(arena, node.inputs[2]),
3386                    mask: node_offset(arena, node.inputs[3]),
3387                    out: node_offset(arena, node.id),
3388                    qkv_b: qkv_b_off,
3389                    out_b: out_b_off,
3390                    cos: cos_off,
3391                    sin: sin_off,
3392                    cos_len: cl,
3393                    batch: batch as u32,
3394                    seq: seq as u32,
3395                    hs,
3396                    nh: *num_heads as u32,
3397                    dh: *head_dim as u32,
3398                    has_bias: *has_bias,
3399                    has_rope: *has_rope,
3400                }
3401            }
3402
3403            Op::Rope { head_dim, n_rot } => {
3404                let x_shape = &graph.node(node.inputs[0]).shape;
3405                let (batch, seq, hidden) = if x_shape.rank() >= 3 {
3406                    (
3407                        x_shape.dim(0).unwrap_static(),
3408                        x_shape.dim(1).unwrap_static(),
3409                        x_shape.dim(2).unwrap_static(),
3410                    )
3411                } else {
3412                    let total = x_shape.num_elements().unwrap();
3413                    (
3414                        1,
3415                        x_shape.dim(0).unwrap_static(),
3416                        total / x_shape.dim(0).unwrap_static(),
3417                    )
3418                };
3419                let cos_len = get_len(graph, node.inputs[1]);
3420                Thunk::Rope {
3421                    src: node_offset(arena, node.inputs[0]),
3422                    cos: node_offset(arena, node.inputs[1]),
3423                    sin: node_offset(arena, node.inputs[2]),
3424                    dst: node_offset(arena, node.id),
3425                    batch: batch as u32,
3426                    seq: seq as u32,
3427                    hidden: hidden as u32,
3428                    head_dim: *head_dim as u32,
3429                    n_rot: *n_rot as u32,
3430                    cos_len: cos_len as u32,
3431                    // Default: source rows are tightly packed (rewritten
3432                    // by the Narrow→Rope fusion pass below if Rope ends
3433                    // up reading from a wider parent like QKV).
3434                    src_row_stride: hidden as u32,
3435                }
3436            }
3437
3438            Op::FusedSwiGLU {
3439                cast_to: _,
3440                gate_first,
3441            } => {
3442                let n_half = node.shape.dim(node.shape.rank() - 1).unwrap_static();
3443                let total = node.shape.num_elements().unwrap();
3444                Thunk::FusedSwiGLU {
3445                    src: node_offset(arena, node.inputs[0]),
3446                    dst: node_offset(arena, node.id),
3447                    n_half: n_half as u32,
3448                    total: total as u32,
3449                    gate_first: *gate_first,
3450                }
3451            }
3452
3453            Op::Conv {
3454                kernel_size,
3455                stride,
3456                padding,
3457                dilation,
3458                groups,
3459            } => {
3460                let in_shape = &graph.node(node.inputs[0]).shape;
3461                let w_shape = &graph.node(node.inputs[1]).shape;
3462                let out_shape = &node.shape;
3463                // 1×1 fast path (plan #26): kH=kW=1, stride=1,
3464                // padding=0, dilation=1, groups=1. Emits a single
3465                // Conv2D1x1 thunk that BLAS-dispatches per batch.
3466                let is_1x1_simple = kernel_size.len() == 2
3467                    && kernel_size[0] == 1
3468                    && kernel_size[1] == 1
3469                    && stride.iter().all(|&s| s == 1)
3470                    && padding.iter().all(|&p| p == 0)
3471                    && dilation.iter().all(|&d| d == 1)
3472                    && *groups == 1;
3473                if is_1x1_simple && in_shape.rank() == 4 && out_shape.rank() == 4 {
3474                    let n = in_shape.dim(0).unwrap_static();
3475                    let c_in = in_shape.dim(1).unwrap_static();
3476                    let c_out = out_shape.dim(1).unwrap_static();
3477                    let h = in_shape.dim(2).unwrap_static();
3478                    let w = in_shape.dim(3).unwrap_static();
3479                    Thunk::Conv2D1x1 {
3480                        src: node_offset(arena, node.inputs[0]),
3481                        weight: node_offset(arena, node.inputs[1]),
3482                        dst: node_offset(arena, node.id),
3483                        n: n as u32,
3484                        c_in: c_in as u32,
3485                        c_out: c_out as u32,
3486                        hw: (h * w) as u32,
3487                    }
3488                } else if kernel_size.len() == 2
3489                    && in_shape.rank() == 4
3490                    && w_shape.rank() == 4
3491                    && out_shape.rank() == 4
3492                {
3493                    Thunk::Conv2D {
3494                        src: node_offset(arena, node.inputs[0]),
3495                        weight: node_offset(arena, node.inputs[1]),
3496                        dst: node_offset(arena, node.id),
3497                        n: in_shape.dim(0).unwrap_static() as u32,
3498                        c_in: in_shape.dim(1).unwrap_static() as u32,
3499                        h: in_shape.dim(2).unwrap_static() as u32,
3500                        w: in_shape.dim(3).unwrap_static() as u32,
3501                        c_out: out_shape.dim(1).unwrap_static() as u32,
3502                        h_out: out_shape.dim(2).unwrap_static() as u32,
3503                        w_out: out_shape.dim(3).unwrap_static() as u32,
3504                        kh: kernel_size[0] as u32,
3505                        kw: kernel_size[1] as u32,
3506                        sh: stride.first().copied().unwrap_or(1) as u32,
3507                        sw: stride.get(1).copied().unwrap_or(1) as u32,
3508                        ph: padding.first().copied().unwrap_or(0) as u32,
3509                        pw: padding.get(1).copied().unwrap_or(0) as u32,
3510                        dh: dilation.first().copied().unwrap_or(1) as u32,
3511                        dw: dilation.get(1).copied().unwrap_or(1) as u32,
3512                        groups: *groups as u32,
3513                    }
3514                } else {
3515                    Thunk::Nop
3516                }
3517            }
3518
3519            Op::Pool {
3520                kind,
3521                kernel_size,
3522                stride,
3523                padding,
3524            } => {
3525                // Currently support 2D pooling on rank-4 NCHW tensors.
3526                let in_shape = &graph.node(node.inputs[0]).shape;
3527                let out_shape = &node.shape;
3528                if kernel_size.len() == 2 && in_shape.rank() == 4 && out_shape.rank() == 4 {
3529                    Thunk::Pool2D {
3530                        src: node_offset(arena, node.inputs[0]),
3531                        dst: node_offset(arena, node.id),
3532                        n: in_shape.dim(0).unwrap_static() as u32,
3533                        c: in_shape.dim(1).unwrap_static() as u32,
3534                        h: in_shape.dim(2).unwrap_static() as u32,
3535                        w: in_shape.dim(3).unwrap_static() as u32,
3536                        h_out: out_shape.dim(2).unwrap_static() as u32,
3537                        w_out: out_shape.dim(3).unwrap_static() as u32,
3538                        kh: kernel_size[0] as u32,
3539                        kw: kernel_size[1] as u32,
3540                        sh: stride.first().copied().unwrap_or(1) as u32,
3541                        sw: stride.get(1).copied().unwrap_or(1) as u32,
3542                        ph: padding.first().copied().unwrap_or(0) as u32,
3543                        pw: padding.get(1).copied().unwrap_or(0) as u32,
3544                        kind: *kind,
3545                    }
3546                } else {
3547                    Thunk::Nop
3548                }
3549            }
3550
3551            Op::Transpose { perm } => {
3552                // Pre-compute (out_dims, in_strides_for_each_out_dim) so the
3553                // runtime loop is just an N-D index walk + scatter.
3554                let in_shape = &graph.node(node.inputs[0]).shape;
3555                let in_rank = in_shape.rank();
3556                let in_dims: Vec<usize> = (0..in_rank)
3557                    .map(|i| in_shape.dim(i).unwrap_static())
3558                    .collect();
3559                // Row-major input strides: stride[d] = product of dims[d+1..].
3560                let mut in_strides_full = vec![1usize; in_rank];
3561                for d in (0..in_rank.saturating_sub(1)).rev() {
3562                    in_strides_full[d] = in_strides_full[d + 1] * in_dims[d + 1];
3563                }
3564                let out_dims: Vec<u32> = perm.iter().map(|&p| in_dims[p] as u32).collect();
3565                let in_strides: Vec<u32> =
3566                    perm.iter().map(|&p| in_strides_full[p] as u32).collect();
3567                let in_total = in_dims.iter().product::<usize>() as u32;
3568                let src = node_offset(arena, node.inputs[0]);
3569                let dst = node_offset(arena, node.id);
3570                match node.shape.dtype() {
3571                    rlx_ir::DType::F64 => Thunk::TransposeF64 {
3572                        src,
3573                        dst,
3574                        in_total,
3575                        out_dims,
3576                        in_strides,
3577                    },
3578                    _ => Thunk::Transpose {
3579                        src,
3580                        dst,
3581                        in_total,
3582                        out_dims,
3583                        in_strides,
3584                    },
3585                }
3586            }
3587
3588            Op::ScatterAdd => {
3589                // updates: [num_updates, ...trailing], indices: [num_updates],
3590                // output: [out_dim, ...trailing]
3591                let upd_shape = &graph.node(node.inputs[0]).shape;
3592                let out_shape = &node.shape;
3593                let num_updates = upd_shape.dim(0).unwrap_static();
3594                let out_dim = out_shape.dim(0).unwrap_static();
3595                let trailing: usize = (1..out_shape.rank())
3596                    .map(|i| out_shape.dim(i).unwrap_static())
3597                    .product::<usize>()
3598                    .max(1);
3599                Thunk::ScatterAdd {
3600                    updates: node_offset(arena, node.inputs[0]),
3601                    indices: node_offset(arena, node.inputs[1]),
3602                    dst: node_offset(arena, node.id),
3603                    num_updates: num_updates as u32,
3604                    out_dim: out_dim as u32,
3605                    trailing: trailing as u32,
3606                }
3607            }
3608
3609            Op::GroupedMatMul => {
3610                // Inputs: [input(M, K), weight(E, K, N), expert_idx(M)]
3611                let in_shape = &graph.node(node.inputs[0]).shape;
3612                let w_shape = &graph.node(node.inputs[1]).shape;
3613                let m = in_shape.dim(in_shape.rank() - 2).unwrap_static();
3614                let k_dim = in_shape.dim(in_shape.rank() - 1).unwrap_static();
3615                let num_experts = w_shape.dim(0).unwrap_static();
3616                let n = w_shape.dim(2).unwrap_static();
3617                Thunk::GroupedMatMul {
3618                    input: node_offset(arena, node.inputs[0]),
3619                    weight: node_offset(arena, node.inputs[1]),
3620                    expert_idx: node_offset(arena, node.inputs[2]),
3621                    dst: node_offset(arena, node.id),
3622                    m: m as u32,
3623                    k_dim: k_dim as u32,
3624                    n: n as u32,
3625                    num_experts: num_experts as u32,
3626                }
3627            }
3628
3629            Op::DequantGroupedMatMul { scheme } => {
3630                let in_shape = &graph.node(node.inputs[0]).shape;
3631                let w_shape = &graph.node(node.inputs[1]).shape;
3632                let m = in_shape.dim(in_shape.rank() - 2).unwrap_static();
3633                let k_dim = in_shape.dim(in_shape.rank() - 1).unwrap_static();
3634                let out_shape = &node.shape;
3635                let n = out_shape.dim(out_shape.rank() - 1).unwrap_static();
3636                let block_elems = scheme.gguf_block_size() as usize;
3637                let block_bytes = scheme.gguf_block_bytes() as usize;
3638                let slab_bytes = (k_dim * n) / block_elems * block_bytes;
3639                let total_bytes = w_shape.num_elements().unwrap();
3640                let num_experts = total_bytes / slab_bytes.max(1);
3641                Thunk::DequantGroupedMatMulGguf {
3642                    input: node_offset(arena, node.inputs[0]),
3643                    w_q: node_offset(arena, node.inputs[1]),
3644                    expert_idx: node_offset(arena, node.inputs[2]),
3645                    dst: node_offset(arena, node.id),
3646                    m: m as u32,
3647                    k_dim: k_dim as u32,
3648                    n: n as u32,
3649                    num_experts: num_experts as u32,
3650                    scheme: *scheme,
3651                }
3652            }
3653
3654            Op::DequantMoEWeights { scheme } => {
3655                let w_shape = &graph.node(node.inputs[0]).shape;
3656                let out_shape = &node.shape;
3657                let num_experts = out_shape.dim(0).unwrap_static();
3658                let k_dim = out_shape.dim(1).unwrap_static();
3659                let n = out_shape.dim(2).unwrap_static();
3660                let block_elems = scheme.gguf_block_size() as usize;
3661                let block_bytes = scheme.gguf_block_bytes() as usize;
3662                let slab_bytes = (k_dim * n) / block_elems * block_bytes;
3663                let total_bytes = w_shape.num_elements().unwrap();
3664                assert_eq!(
3665                    total_bytes,
3666                    num_experts * slab_bytes,
3667                    "DequantMoEWeights packed bytes mismatch"
3668                );
3669                Thunk::DequantMoEWeightsGguf {
3670                    w_q: node_offset(arena, node.inputs[0]),
3671                    dst: node_offset(arena, node.id),
3672                    k_dim: k_dim as u32,
3673                    n: n as u32,
3674                    num_experts: num_experts as u32,
3675                    scheme: *scheme,
3676                }
3677            }
3678
3679            Op::TopK { k } => {
3680                let in_shape = &graph.node(node.inputs[0]).shape;
3681                let rank = in_shape.rank();
3682                let axis_dim = in_shape.dim(rank - 1).unwrap_static();
3683                let outer = in_shape.num_elements().unwrap() / axis_dim;
3684                Thunk::TopK {
3685                    src: node_offset(arena, node.inputs[0]),
3686                    dst: node_offset(arena, node.id),
3687                    outer: outer as u32,
3688                    axis_dim: axis_dim as u32,
3689                    k: *k as u32,
3690                }
3691            }
3692
3693            Op::Reduce {
3694                op,
3695                axes,
3696                keep_dim: _,
3697            } => {
3698                // Decompose the input shape into [outer, reduced, inner]
3699                // around the reduced axis range. Non-contiguous reduced
3700                // axes aren't supported here — caller must transpose them
3701                // contiguous first (the coverage tool would surface the
3702                // gap if a model needs it).
3703                let in_shape = &graph.node(node.inputs[0]).shape;
3704                let rank = in_shape.rank();
3705                let mut sorted = axes.clone();
3706                sorted.sort();
3707                sorted.dedup();
3708                let contiguous = sorted.windows(2).all(|w| w[1] == w[0] + 1)
3709                    && !sorted.is_empty()
3710                    && *sorted.last().unwrap() < rank;
3711                if !contiguous {
3712                    Thunk::Nop
3713                } else {
3714                    let first = sorted[0];
3715                    let last = *sorted.last().unwrap();
3716                    let outer: usize = (0..first)
3717                        .map(|i| in_shape.dim(i).unwrap_static())
3718                        .product::<usize>()
3719                        .max(1);
3720                    let reduced: usize = (first..=last)
3721                        .map(|i| in_shape.dim(i).unwrap_static())
3722                        .product();
3723                    let inner: usize = (last + 1..rank)
3724                        .map(|i| in_shape.dim(i).unwrap_static())
3725                        .product::<usize>()
3726                        .max(1);
3727                    let src = node_offset(arena, node.inputs[0]);
3728                    let dst = node_offset(arena, node.id);
3729                    if node.shape.dtype() == rlx_ir::DType::F64 && matches!(op, ReduceOp::Sum) {
3730                        Thunk::ReduceSumF64 {
3731                            src,
3732                            dst,
3733                            outer: outer as u32,
3734                            reduced: reduced as u32,
3735                            inner: inner as u32,
3736                        }
3737                    } else {
3738                        Thunk::Reduce {
3739                            src,
3740                            dst,
3741                            outer: outer as u32,
3742                            reduced: reduced as u32,
3743                            inner: inner as u32,
3744                            op: *op,
3745                        }
3746                    }
3747                }
3748            }
3749
3750            Op::Compare(cmp) => {
3751                let len = node.shape.num_elements().unwrap();
3752                Thunk::Compare {
3753                    lhs: node_offset(arena, node.inputs[0]),
3754                    rhs: node_offset(arena, node.inputs[1]),
3755                    dst: node_offset(arena, node.id),
3756                    len: len as u32,
3757                    op: *cmp,
3758                }
3759            }
3760
3761            Op::Where => {
3762                let len = node.shape.num_elements().unwrap();
3763                Thunk::Where {
3764                    cond: node_offset(arena, node.inputs[0]),
3765                    on_true: node_offset(arena, node.inputs[1]),
3766                    on_false: node_offset(arena, node.inputs[2]),
3767                    dst: node_offset(arena, node.id),
3768                    len: len as u32,
3769                }
3770            }
3771
3772            Op::ReluBackward => {
3773                let len: usize = (0..node.shape.rank())
3774                    .map(|i| node.shape.dim(i).unwrap_static())
3775                    .product();
3776                let x = node_offset(arena, node.inputs[0]);
3777                let dy = node_offset(arena, node.inputs[1]);
3778                let dx = node_offset(arena, node.id);
3779                match node.shape.dtype() {
3780                    rlx_ir::DType::F64 => Thunk::ReluBackwardF64 {
3781                        x,
3782                        dy,
3783                        dx,
3784                        len: len as u32,
3785                    },
3786                    _ => Thunk::ReluBackward {
3787                        x,
3788                        dy,
3789                        dx,
3790                        len: len as u32,
3791                    },
3792                }
3793            }
3794
3795            Op::ComplexNormSq => {
3796                let len: usize = (0..node.shape.rank())
3797                    .map(|i| node.shape.dim(i).unwrap_static())
3798                    .product();
3799                let src = node_offset(arena, node.inputs[0]);
3800                let dst = node_offset(arena, node.id);
3801                Thunk::ComplexNormSqF32 {
3802                    src,
3803                    dst,
3804                    len: len as u32,
3805                }
3806            }
3807
3808            Op::ComplexNormSqBackward => {
3809                let len: usize = (0..node.shape.rank())
3810                    .map(|i| node.shape.dim(i).unwrap_static())
3811                    .product();
3812                let z = node_offset(arena, node.inputs[0]);
3813                let g = node_offset(arena, node.inputs[1]);
3814                let dz = node_offset(arena, node.id);
3815                Thunk::ComplexNormSqBackwardF32 {
3816                    z,
3817                    g,
3818                    dz,
3819                    len: len as u32,
3820                }
3821            }
3822
3823            Op::Conjugate => {
3824                let len: usize = (0..node.shape.rank())
3825                    .map(|i| node.shape.dim(i).unwrap_static())
3826                    .product();
3827                Thunk::ConjugateC64 {
3828                    src: node_offset(arena, node.inputs[0]),
3829                    dst: node_offset(arena, node.id),
3830                    len: len as u32,
3831                }
3832            }
3833
3834            Op::ActivationBackward { kind } => {
3835                let len: usize = (0..node.shape.rank())
3836                    .map(|i| node.shape.dim(i).unwrap_static())
3837                    .product();
3838                let x = node_offset(arena, node.inputs[0]);
3839                let dy = node_offset(arena, node.inputs[1]);
3840                let dx = node_offset(arena, node.id);
3841                match node.shape.dtype() {
3842                    rlx_ir::DType::F64 => Thunk::ActivationBackwardF64 {
3843                        x,
3844                        dy,
3845                        dx,
3846                        len: len as u32,
3847                        kind: *kind,
3848                    },
3849                    _ => Thunk::ActivationBackward {
3850                        x,
3851                        dy,
3852                        dx,
3853                        len: len as u32,
3854                        kind: *kind,
3855                    },
3856                }
3857            }
3858
3859            Op::LayerNormBackwardInput { eps, .. } => {
3860                // axis = -1 only (matches forward LayerNorm thunk).
3861                let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
3862                let total = node.shape.num_elements().unwrap();
3863                Thunk::LayerNormBackwardInput {
3864                    x: node_offset(arena, node.inputs[0]),
3865                    gamma: node_offset(arena, node.inputs[1]),
3866                    dy: node_offset(arena, node.inputs[2]),
3867                    dx: node_offset(arena, node.id),
3868                    rows: (total / h) as u32,
3869                    h: h as u32,
3870                    eps: *eps,
3871                }
3872            }
3873
3874            Op::LayerNormBackwardGamma { eps, .. } => {
3875                let x_shape = &graph.node(node.inputs[0]).shape;
3876                let h = x_shape.dim(x_shape.rank() - 1).unwrap_static();
3877                let x_total = x_shape.num_elements().unwrap();
3878                Thunk::LayerNormBackwardGamma {
3879                    x: node_offset(arena, node.inputs[0]),
3880                    dy: node_offset(arena, node.inputs[1]),
3881                    dgamma: node_offset(arena, node.id),
3882                    rows: (x_total / h) as u32,
3883                    h: h as u32,
3884                    eps: *eps,
3885                }
3886            }
3887
3888            Op::RmsNormBackwardInput { eps, .. }
3889            | Op::RmsNormBackwardGamma { eps, .. }
3890            | Op::RmsNormBackwardBeta { eps, .. } => {
3891                let x_shape = &graph.node(node.inputs[0]).shape;
3892                let h = x_shape.dim(x_shape.rank() - 1).unwrap_static();
3893                let rows = (x_shape.num_elements().unwrap() / h) as u32;
3894                let off = |i: usize| node_offset(arena, node.inputs[i]);
3895                let common = (off(0), off(1), off(2), off(3), rows, h as u32, *eps);
3896                match &node.op {
3897                    Op::RmsNormBackwardInput { .. } => Thunk::RmsNormBackwardInput {
3898                        x: common.0,
3899                        gamma: common.1,
3900                        beta: common.2,
3901                        dy: common.3,
3902                        dx: node_offset(arena, node.id),
3903                        rows: common.4,
3904                        h: common.5,
3905                        eps: common.6,
3906                    },
3907                    Op::RmsNormBackwardGamma { .. } => Thunk::RmsNormBackwardGamma {
3908                        x: common.0,
3909                        gamma: common.1,
3910                        beta: common.2,
3911                        dy: common.3,
3912                        dgamma: node_offset(arena, node.id),
3913                        rows: common.4,
3914                        h: common.5,
3915                        eps: common.6,
3916                    },
3917                    Op::RmsNormBackwardBeta { .. } => Thunk::RmsNormBackwardBeta {
3918                        x: common.0,
3919                        gamma: common.1,
3920                        beta: common.2,
3921                        dy: common.3,
3922                        dbeta: node_offset(arena, node.id),
3923                        rows: common.4,
3924                        h: common.5,
3925                        eps: common.6,
3926                    },
3927                    _ => unreachable!(),
3928                }
3929            }
3930
3931            Op::RopeBackward { head_dim, n_rot } => {
3932                let dy_shape = &graph.node(node.inputs[0]).shape;
3933                let (batch, seq, hidden) = if dy_shape.rank() >= 3 {
3934                    (
3935                        dy_shape.dim(0).unwrap_static(),
3936                        dy_shape.dim(1).unwrap_static(),
3937                        dy_shape.dim(2).unwrap_static(),
3938                    )
3939                } else {
3940                    (
3941                        1,
3942                        dy_shape.dim(0).unwrap_static(),
3943                        dy_shape.dim(1).unwrap_static(),
3944                    )
3945                };
3946                let cos_shape = &graph.node(node.inputs[1]).shape;
3947                let cos_len = cos_shape.num_elements().unwrap();
3948                Thunk::RopeBackward {
3949                    dy: node_offset(arena, node.inputs[0]),
3950                    cos: node_offset(arena, node.inputs[1]),
3951                    sin: node_offset(arena, node.inputs[2]),
3952                    dx: node_offset(arena, node.id),
3953                    batch: batch as u32,
3954                    seq: seq as u32,
3955                    hidden: hidden as u32,
3956                    head_dim: *head_dim as u32,
3957                    n_rot: *n_rot as u32,
3958                    cos_len: cos_len as u32,
3959                }
3960            }
3961
3962            Op::CumsumBackward { exclusive, .. } => {
3963                let dy_shape = &graph.node(node.inputs[0]).shape;
3964                let rank = dy_shape.rank();
3965                let cols = dy_shape.dim(rank - 1).unwrap_static();
3966                let rows = dy_shape.num_elements().unwrap() / cols;
3967                Thunk::CumsumBackward {
3968                    dy: node_offset(arena, node.inputs[0]),
3969                    dx: node_offset(arena, node.id),
3970                    rows: rows as u32,
3971                    cols: cols as u32,
3972                    exclusive: *exclusive,
3973                }
3974            }
3975
3976            Op::GatherBackward { .. } => {
3977                let dy_shape = &graph.node(node.inputs[0]).shape;
3978                let idx_shape = &graph.node(node.inputs[1]).shape;
3979                let out_shape = &node.shape;
3980                let rank = out_shape.rank();
3981                let axis = match &node.op {
3982                    Op::GatherBackward { axis } => *axis,
3983                    _ => 0,
3984                };
3985                let axis_u = if axis < 0 {
3986                    (rank as i32 + axis) as usize
3987                } else {
3988                    axis as usize
3989                };
3990                let outer: usize = (0..axis_u)
3991                    .map(|i| dy_shape.dim(i).unwrap_static())
3992                    .product::<usize>()
3993                    .max(1);
3994                let num_idx = idx_shape.dim(axis_u).unwrap_static();
3995                let trailing: usize = (axis_u + 1..dy_shape.rank())
3996                    .map(|i| dy_shape.dim(i).unwrap_static())
3997                    .product::<usize>()
3998                    .max(1);
3999                let axis_dim = out_shape.dim(axis_u).unwrap_static();
4000                Thunk::GatherBackward {
4001                    dy: node_offset(arena, node.inputs[0]),
4002                    indices: node_offset(arena, node.inputs[1]),
4003                    dst: node_offset(arena, node.id),
4004                    outer: outer as u32,
4005                    axis_dim: axis_dim as u32,
4006                    num_idx: num_idx as u32,
4007                    trailing: trailing as u32,
4008                }
4009            }
4010
4011            Op::GroupNormBackwardInput { num_groups, eps }
4012            | Op::GroupNormBackwardGamma { num_groups, eps }
4013            | Op::GroupNormBackwardBeta { num_groups, eps } => {
4014                let x_shape = &graph.node(node.inputs[0]).shape;
4015                let n = x_shape.dim(0).unwrap_static() as u32;
4016                let c = x_shape.dim(1).unwrap_static() as u32;
4017                let h = x_shape.dim(2).unwrap_static() as u32;
4018                let w = x_shape.dim(3).unwrap_static() as u32;
4019                match &node.op {
4020                    Op::GroupNormBackwardInput { .. } => Thunk::GroupNormBackwardInput {
4021                        x: node_offset(arena, node.inputs[0]),
4022                        gamma: node_offset(arena, node.inputs[1]),
4023                        beta: node_offset(arena, node.inputs[2]),
4024                        dy: node_offset(arena, node.inputs[3]),
4025                        dx: node_offset(arena, node.id),
4026                        n,
4027                        c,
4028                        h,
4029                        w,
4030                        num_groups: *num_groups as u32,
4031                        eps: *eps,
4032                    },
4033                    Op::GroupNormBackwardGamma { .. } => Thunk::GroupNormBackwardGamma {
4034                        x: node_offset(arena, node.inputs[0]),
4035                        dy: node_offset(arena, node.inputs[1]),
4036                        dgamma: node_offset(arena, node.id),
4037                        n,
4038                        c,
4039                        h,
4040                        w,
4041                        num_groups: *num_groups as u32,
4042                        eps: *eps,
4043                    },
4044                    Op::GroupNormBackwardBeta { .. } => Thunk::GroupNormBackwardBeta {
4045                        dy: node_offset(arena, node.inputs[1]),
4046                        dbeta: node_offset(arena, node.id),
4047                        n,
4048                        c,
4049                        h,
4050                        w,
4051                    },
4052                    _ => unreachable!(),
4053                }
4054            }
4055
4056            Op::MaxPool2dBackward {
4057                kernel_size,
4058                stride,
4059                padding,
4060            } => {
4061                let x_shape = &graph.node(node.inputs[0]).shape;
4062                let dy_shape = &graph.node(node.inputs[1]).shape;
4063                if kernel_size.len() == 2 && x_shape.rank() == 4 && dy_shape.rank() == 4 {
4064                    Thunk::MaxPool2dBackward {
4065                        x: node_offset(arena, node.inputs[0]),
4066                        dy: node_offset(arena, node.inputs[1]),
4067                        dx: node_offset(arena, node.id),
4068                        n: x_shape.dim(0).unwrap_static() as u32,
4069                        c: x_shape.dim(1).unwrap_static() as u32,
4070                        h: x_shape.dim(2).unwrap_static() as u32,
4071                        w: x_shape.dim(3).unwrap_static() as u32,
4072                        h_out: dy_shape.dim(2).unwrap_static() as u32,
4073                        w_out: dy_shape.dim(3).unwrap_static() as u32,
4074                        kh: kernel_size[0] as u32,
4075                        kw: kernel_size[1] as u32,
4076                        sh: stride.first().copied().unwrap_or(1) as u32,
4077                        sw: stride.get(1).copied().unwrap_or(1) as u32,
4078                        ph: padding.first().copied().unwrap_or(0) as u32,
4079                        pw: padding.get(1).copied().unwrap_or(0) as u32,
4080                    }
4081                } else {
4082                    Thunk::Nop
4083                }
4084            }
4085
4086            Op::Conv2dBackwardInput {
4087                kernel_size,
4088                stride,
4089                padding,
4090                dilation,
4091                groups,
4092            } => {
4093                let dy_shape = &graph.node(node.inputs[0]).shape;
4094                let w_shape = &graph.node(node.inputs[1]).shape;
4095                let out_shape = &node.shape;
4096                if kernel_size.len() == 2
4097                    && dy_shape.rank() == 4
4098                    && w_shape.rank() == 4
4099                    && out_shape.rank() == 4
4100                {
4101                    Thunk::Conv2dBackwardInput {
4102                        dy: node_offset(arena, node.inputs[0]),
4103                        w: node_offset(arena, node.inputs[1]),
4104                        dx: node_offset(arena, node.id),
4105                        n: out_shape.dim(0).unwrap_static() as u32,
4106                        c_in: out_shape.dim(1).unwrap_static() as u32,
4107                        h: out_shape.dim(2).unwrap_static() as u32,
4108                        w_in: out_shape.dim(3).unwrap_static() as u32,
4109                        c_out: dy_shape.dim(1).unwrap_static() as u32,
4110                        h_out: dy_shape.dim(2).unwrap_static() as u32,
4111                        w_out: dy_shape.dim(3).unwrap_static() as u32,
4112                        kh: kernel_size[0] as u32,
4113                        kw: kernel_size[1] as u32,
4114                        sh: stride.first().copied().unwrap_or(1) as u32,
4115                        sw: stride.get(1).copied().unwrap_or(1) as u32,
4116                        ph: padding.first().copied().unwrap_or(0) as u32,
4117                        pw: padding.get(1).copied().unwrap_or(0) as u32,
4118                        dh: dilation.first().copied().unwrap_or(1) as u32,
4119                        dw: dilation.get(1).copied().unwrap_or(1) as u32,
4120                        groups: *groups as u32,
4121                    }
4122                } else {
4123                    Thunk::Nop
4124                }
4125            }
4126
4127            Op::Conv2dBackwardWeight {
4128                kernel_size,
4129                stride,
4130                padding,
4131                dilation,
4132                groups,
4133            } => {
4134                let x_shape = &graph.node(node.inputs[0]).shape;
4135                let dy_shape = &graph.node(node.inputs[1]).shape;
4136                let dw_shape = &node.shape;
4137                if kernel_size.len() == 2
4138                    && x_shape.rank() == 4
4139                    && dy_shape.rank() == 4
4140                    && dw_shape.rank() == 4
4141                {
4142                    Thunk::Conv2dBackwardWeight {
4143                        x: node_offset(arena, node.inputs[0]),
4144                        dy: node_offset(arena, node.inputs[1]),
4145                        dw: node_offset(arena, node.id),
4146                        n: x_shape.dim(0).unwrap_static() as u32,
4147                        c_in: x_shape.dim(1).unwrap_static() as u32,
4148                        h: x_shape.dim(2).unwrap_static() as u32,
4149                        w: x_shape.dim(3).unwrap_static() as u32,
4150                        c_out: dy_shape.dim(1).unwrap_static() as u32,
4151                        h_out: dy_shape.dim(2).unwrap_static() as u32,
4152                        w_out: dy_shape.dim(3).unwrap_static() as u32,
4153                        kh: kernel_size[0] as u32,
4154                        kw: kernel_size[1] as u32,
4155                        sh: stride.first().copied().unwrap_or(1) as u32,
4156                        sw: stride.get(1).copied().unwrap_or(1) as u32,
4157                        ph: padding.first().copied().unwrap_or(0) as u32,
4158                        pw: padding.get(1).copied().unwrap_or(0) as u32,
4159                        dh: dilation.first().copied().unwrap_or(1) as u32,
4160                        dw_dil: dilation.get(1).copied().unwrap_or(1) as u32,
4161                        groups: *groups as u32,
4162                    }
4163                } else {
4164                    Thunk::Nop
4165                }
4166            }
4167
4168            Op::SoftmaxCrossEntropyWithLogits => {
4169                let logits_shape = &graph.node(node.inputs[0]).shape;
4170                if logits_shape.rank() == 2 {
4171                    Thunk::SoftmaxCrossEntropy {
4172                        logits: node_offset(arena, node.inputs[0]),
4173                        labels: node_offset(arena, node.inputs[1]),
4174                        dst: node_offset(arena, node.id),
4175                        n: logits_shape.dim(0).unwrap_static() as u32,
4176                        c: logits_shape.dim(1).unwrap_static() as u32,
4177                    }
4178                } else {
4179                    Thunk::Nop
4180                }
4181            }
4182
4183            Op::SoftmaxCrossEntropyBackward => {
4184                let logits_shape = &graph.node(node.inputs[0]).shape;
4185                if logits_shape.rank() == 2 {
4186                    Thunk::SoftmaxCrossEntropyBackward {
4187                        logits: node_offset(arena, node.inputs[0]),
4188                        labels: node_offset(arena, node.inputs[1]),
4189                        d_loss: node_offset(arena, node.inputs[2]),
4190                        dlogits: node_offset(arena, node.id),
4191                        n: logits_shape.dim(0).unwrap_static() as u32,
4192                        c: logits_shape.dim(1).unwrap_static() as u32,
4193                    }
4194                } else {
4195                    Thunk::Nop
4196                }
4197            }
4198
4199            Op::DenseSolve => {
4200                // A: [n, n], b: [n] or [n, nrhs]. Output matches b.
4201                let a_shape = &graph.node(node.inputs[0]).shape;
4202                let n = a_shape.dim(0).unwrap_static();
4203                debug_assert_eq!(
4204                    n,
4205                    a_shape.dim(1).unwrap_static(),
4206                    "DenseSolve: A must be square"
4207                );
4208                let b_elems = node.shape.num_elements().unwrap();
4209                let nrhs = b_elems / n;
4210                match node.shape.dtype() {
4211                    rlx_ir::DType::F64 => Thunk::DenseSolveF64 {
4212                        a: node_offset(arena, node.inputs[0]),
4213                        b: node_offset(arena, node.inputs[1]),
4214                        x: node_offset(arena, node.id),
4215                        n: n as u32,
4216                        nrhs: nrhs as u32,
4217                    },
4218                    rlx_ir::DType::F32 => Thunk::DenseSolveF32 {
4219                        a: node_offset(arena, node.inputs[0]),
4220                        b: node_offset(arena, node.inputs[1]),
4221                        x: node_offset(arena, node.id),
4222                        n: n as u32,
4223                        nrhs: nrhs as u32,
4224                    },
4225                    other => panic!(
4226                        "DenseSolve: F32 + F64 lowered; got {other:?}. \
4227                         Add another variant when needed."
4228                    ),
4229                }
4230            }
4231
4232            Op::BatchedDenseSolve => {
4233                // A: [B, N, N], b: [B, N] or [B, N, K]. Output matches b.
4234                let a_shape = &graph.node(node.inputs[0]).shape;
4235                assert_eq!(a_shape.rank(), 3, "BatchedDenseSolve: A rank must be 3");
4236                let batch = a_shape.dim(0).unwrap_static();
4237                let n = a_shape.dim(1).unwrap_static();
4238                debug_assert_eq!(
4239                    n,
4240                    a_shape.dim(2).unwrap_static(),
4241                    "BatchedDenseSolve: A's last two dims must match"
4242                );
4243                let total = node.shape.num_elements().unwrap();
4244                let nrhs = total / (batch * n);
4245                match node.shape.dtype() {
4246                    rlx_ir::DType::F32 => Thunk::BatchedDenseSolveF32 {
4247                        a: node_offset(arena, node.inputs[0]),
4248                        b: node_offset(arena, node.inputs[1]),
4249                        x: node_offset(arena, node.id),
4250                        batch: batch as u32,
4251                        n: n as u32,
4252                        nrhs: nrhs as u32,
4253                    },
4254                    rlx_ir::DType::F64 => Thunk::BatchedDenseSolveF64 {
4255                        a: node_offset(arena, node.inputs[0]),
4256                        b: node_offset(arena, node.inputs[1]),
4257                        x: node_offset(arena, node.id),
4258                        batch: batch as u32,
4259                        n: n as u32,
4260                        nrhs: nrhs as u32,
4261                    },
4262                    other => panic!("BatchedDenseSolve: F32 + F64 only, got {other:?}"),
4263                }
4264            }
4265
4266            Op::Scan {
4267                body,
4268                length,
4269                save_trajectory,
4270                num_bcast,
4271                num_xs,
4272                num_checkpoints,
4273            } => {
4274                assert!(
4275                    *num_checkpoints == 0 || *num_checkpoints <= *length,
4276                    "Op::Scan: num_checkpoints={} must be 0 or ≤ length={}",
4277                    *num_checkpoints,
4278                    *length
4279                );
4280                if *num_checkpoints != 0 && *num_checkpoints != *length {
4281                    assert!(
4282                        *save_trajectory,
4283                        "Op::Scan: num_checkpoints<length only meaningful when save_trajectory=true"
4284                    );
4285                }
4286                // Plan + compile the body sub-graph standalone. The body
4287                // gets its own Arena; per execution we clone its
4288                // pristine bytes, copy the outer carry (and per-step xs
4289                // slices, if any) into the body's Input slots, run the
4290                // body schedule N times, then copy the body's output
4291                // back to the outer arena.
4292                //
4293                // Body invariants: 1 + num_xs Op::Inputs in NodeId order
4294                // — first declared is the carry, rest are x_t_i. Single
4295                // graph output (the next carry), same shape as carry.
4296                let body_plan = rlx_opt::memory::plan_memory(body);
4297                let _body_arena_size = body_plan.arena_size;
4298                // Snapshot per-input byte offsets before plan_memory
4299                // moves into the Arena below.
4300                let body_offsets: HashMap<NodeId, usize> = body_plan
4301                    .assignments
4302                    .iter()
4303                    .map(|(id, slot)| (*id, slot.offset))
4304                    .collect();
4305
4306                // Collect body Input nodes in NodeId order; first is
4307                // carry, rest are per-step xs in matching order.
4308                let mut body_inputs: Vec<NodeId> = body
4309                    .nodes()
4310                    .iter()
4311                    .filter(|n| matches!(n.op, Op::Input { .. }))
4312                    .map(|n| n.id)
4313                    .collect();
4314                body_inputs.sort();
4315                let n_body_inputs = body_inputs.len();
4316                let expected = 1 + *num_bcast as usize + *num_xs as usize;
4317                if n_body_inputs != expected {
4318                    let names: Vec<String> = body
4319                        .nodes()
4320                        .iter()
4321                        .filter_map(|n| match &n.op {
4322                            Op::Input { name } => Some(format!("{}={}", n.id, name)),
4323                            _ => None,
4324                        })
4325                        .collect();
4326                    panic!(
4327                        "Op::Scan body has {} Op::Input nodes; expected {} \
4328                            (1 carry + {} bcast + {} xs). Inputs by NodeId: [{}]",
4329                        n_body_inputs,
4330                        expected,
4331                        *num_bcast,
4332                        *num_xs,
4333                        names.join(", ")
4334                    );
4335                }
4336
4337                let body_input_id = body_inputs[0];
4338                let body_input_off = body_offsets[&body_input_id];
4339                let body_output_id = body
4340                    .outputs
4341                    .first()
4342                    .copied()
4343                    .expect("Op::Scan body must declare one output");
4344                let body_output_off = body_offsets[&body_output_id];
4345
4346                let mut body_arena = crate::arena::Arena::from_plan(body_plan);
4347                // Fill body Constant nodes — mirror the outer-graph logic
4348                // in rlx-runtime/src/backend.rs (dtype-aware).
4349                for n in body.nodes() {
4350                    if let Op::Constant { data } = &n.op
4351                        && body_arena.has_buffer(n.id)
4352                        && !data.is_empty()
4353                    {
4354                        match n.shape.dtype() {
4355                            rlx_ir::DType::F64 => {
4356                                let off = body_arena.byte_offset(n.id);
4357                                let buf = body_arena.raw_buf_mut();
4358                                let nbytes = (buf.len() - off).min(data.len());
4359                                buf[off..off + nbytes].copy_from_slice(&data[..nbytes]);
4360                            }
4361                            _ => {
4362                                let buf = body_arena.slice_mut(n.id);
4363                                let n_floats = data.len() / 4;
4364                                let n_lim = buf.len().min(n_floats);
4365                                for i in 0..n_lim {
4366                                    let bytes = [
4367                                        data[i * 4],
4368                                        data[i * 4 + 1],
4369                                        data[i * 4 + 2],
4370                                        data[i * 4 + 3],
4371                                    ];
4372                                    buf[i] = f32::from_le_bytes(bytes);
4373                                }
4374                            }
4375                        }
4376                    }
4377                }
4378                let body_init = body_arena.raw_buf().to_vec();
4379                let body_schedule = compile_thunks(body, &body_arena);
4380
4381                // Carry bytes — for trajectory mode, the outer node's
4382                // shape is [length, *carry_shape], so dividing by length
4383                // gives one row's bytes; the body's input slot still
4384                // holds carry_shape bytes.
4385                let carry_bytes = if *save_trajectory {
4386                    let total = node
4387                        .shape
4388                        .size_bytes()
4389                        .expect("Op::Scan trajectory output must have static shape");
4390                    total / *length as usize
4391                } else {
4392                    node.shape
4393                        .size_bytes()
4394                        .expect("Op::Scan carry must have static shape")
4395                };
4396
4397                // Bcast inputs occupy body_inputs[1..1+num_bcast] and
4398                // outer node.inputs[1..1+num_bcast]. They keep their
4399                // natural shape (no [length, ...] prefix) and are
4400                // copied into body_buf ONCE before the scan loop.
4401                let mut bcast_inputs: Vec<(usize, usize, u32)> =
4402                    Vec::with_capacity(*num_bcast as usize);
4403                for i in 0..*num_bcast as usize {
4404                    let body_b_id = body_inputs[1 + i];
4405                    let body_b_off = body_offsets[&body_b_id];
4406                    let outer_b_id = node.inputs[1 + i];
4407                    let outer_b_off = node_offset(arena, outer_b_id);
4408                    let outer_b_shape = &graph.node(outer_b_id).shape;
4409                    let total = outer_b_shape
4410                        .size_bytes()
4411                        .expect("Op::Scan bcast must have static shape");
4412                    bcast_inputs.push((body_b_off, outer_b_off, total as u32));
4413                }
4414
4415                // xs occupy body_inputs[1+num_bcast..] and node.inputs
4416                // [1+num_bcast..]. Each has shape [length, *per_step];
4417                // per-step bytes = total / length.
4418                let mut xs_inputs: Vec<(usize, usize, u32)> = Vec::with_capacity(*num_xs as usize);
4419                let xs_base = 1 + *num_bcast as usize;
4420                for i in 0..*num_xs as usize {
4421                    let body_x_id = body_inputs[xs_base + i];
4422                    let body_x_off = body_offsets[&body_x_id];
4423                    let outer_xs_id = node.inputs[xs_base + i];
4424                    let outer_xs_off = node_offset(arena, outer_xs_id);
4425                    let outer_xs_shape = &graph.node(outer_xs_id).shape;
4426                    let total = outer_xs_shape
4427                        .size_bytes()
4428                        .expect("Op::Scan xs must have static shape");
4429                    let per_step = total / *length as usize;
4430                    xs_inputs.push((body_x_off, outer_xs_off, per_step as u32));
4431                }
4432
4433                Thunk::Scan {
4434                    body: Arc::new(body_schedule),
4435                    body_init: Arc::new(body_init),
4436                    body_input_off,
4437                    body_output_off,
4438                    outer_init_off: node_offset(arena, node.inputs[0]),
4439                    outer_final_off: node_offset(arena, node.id),
4440                    length: *length,
4441                    carry_bytes: carry_bytes as u32,
4442                    save_trajectory: *save_trajectory,
4443                    xs_inputs: Arc::new(xs_inputs),
4444                    bcast_inputs: Arc::new(bcast_inputs),
4445                    num_checkpoints: *num_checkpoints,
4446                }
4447            }
4448
4449            Op::ScanBackward {
4450                body_vjp,
4451                length,
4452                save_trajectory,
4453                num_xs,
4454                num_checkpoints,
4455                forward_body,
4456            } => {
4457                let is_recursive = *num_checkpoints != 0 && *num_checkpoints != *length;
4458                if is_recursive {
4459                    assert!(
4460                        forward_body.is_some(),
4461                        "Op::ScanBackward with num_checkpoints<length requires forward_body"
4462                    );
4463                }
4464                // body_vjp has signature
4465                //   (carry, x_t_0, ..., x_t_{num_xs-1}, d_output) → dcarry
4466                // Identify slots:
4467                //   * "d_output" by exact name (AD-introduced seed Input).
4468                //   * Remaining Inputs sorted by NodeId — first is the
4469                //     carry mirror, rest are x_t_i mirrors in body's
4470                //     original Op::Input declaration order.
4471                let body_plan = rlx_opt::memory::plan_memory(body_vjp);
4472                let body_offsets: HashMap<NodeId, usize> = body_plan
4473                    .assignments
4474                    .iter()
4475                    .map(|(id, slot)| (*id, slot.offset))
4476                    .collect();
4477                let mut body_d_output_off: Option<usize> = None;
4478                let mut body_other_inputs: Vec<(NodeId, usize)> = Vec::new();
4479                for n in body_vjp.nodes() {
4480                    if let Op::Input { name } = &n.op {
4481                        let off = body_offsets[&n.id];
4482                        if name == "d_output" {
4483                            body_d_output_off = Some(off);
4484                        } else {
4485                            body_other_inputs.push((n.id, off));
4486                        }
4487                    }
4488                }
4489                body_other_inputs.sort_by_key(|(id, _)| *id);
4490                let body_d_output_off =
4491                    body_d_output_off.expect("ScanBackward body_vjp missing 'd_output' Input");
4492                let expected_others = 1 + *num_xs as usize;
4493                assert_eq!(
4494                    body_other_inputs.len(),
4495                    expected_others,
4496                    "ScanBackward body_vjp has {} non-d_output Inputs; \
4497                     expected {} (1 carry + {} xs)",
4498                    body_other_inputs.len(),
4499                    expected_others,
4500                    num_xs
4501                );
4502                let body_carry_in_off = body_other_inputs[0].1;
4503                let body_x_offs: Vec<usize> = body_other_inputs
4504                    .iter()
4505                    .skip(1)
4506                    .map(|(_, off)| *off)
4507                    .collect();
4508                let body_dcarry_out_off = body_offsets[&body_vjp.outputs[0]];
4509
4510                let mut body_arena = crate::arena::Arena::from_plan(body_plan);
4511                // Fill body_vjp's Constants (mirrors the Scan lowering).
4512                for n in body_vjp.nodes() {
4513                    if let Op::Constant { data } = &n.op
4514                        && body_arena.has_buffer(n.id)
4515                        && !data.is_empty()
4516                    {
4517                        match n.shape.dtype() {
4518                            rlx_ir::DType::F64 => {
4519                                let off = body_arena.byte_offset(n.id);
4520                                let buf = body_arena.raw_buf_mut();
4521                                let nb = (buf.len() - off).min(data.len());
4522                                buf[off..off + nb].copy_from_slice(&data[..nb]);
4523                            }
4524                            _ => {
4525                                let buf = body_arena.slice_mut(n.id);
4526                                let nf = data.len() / 4;
4527                                let nl = buf.len().min(nf);
4528                                for i in 0..nl {
4529                                    let bytes = [
4530                                        data[i * 4],
4531                                        data[i * 4 + 1],
4532                                        data[i * 4 + 2],
4533                                        data[i * 4 + 3],
4534                                    ];
4535                                    buf[i] = f32::from_le_bytes(bytes);
4536                                }
4537                            }
4538                        }
4539                    }
4540                }
4541                let body_init = body_arena.raw_buf().to_vec();
4542                let body_schedule = compile_thunks(body_vjp, &body_arena);
4543
4544                // Carry bytes from the dcarry output node (== carry shape).
4545                let carry_bytes = body_vjp
4546                    .node(body_vjp.outputs[0])
4547                    .shape
4548                    .size_bytes()
4549                    .expect("ScanBackward dcarry must be statically shaped");
4550                let carry_elem_size = body_vjp
4551                    .node(body_vjp.outputs[0])
4552                    .shape
4553                    .dtype()
4554                    .size_bytes() as u32;
4555
4556                // For each xs input on the outer node:
4557                // (outer_xs_base, per_step_bytes).
4558                let mut outer_xs_offs: Vec<(usize, u32)> = Vec::with_capacity(*num_xs as usize);
4559                for i in 0..*num_xs as usize {
4560                    let outer_xs_id = node.inputs[3 + i];
4561                    let outer_xs_off = node_offset(arena, outer_xs_id);
4562                    let outer_xs_shape = &graph.node(outer_xs_id).shape;
4563                    let total = outer_xs_shape
4564                        .size_bytes()
4565                        .expect("ScanBackward xs must have static shape");
4566                    let per_step = total / *length as usize;
4567                    outer_xs_offs.push((outer_xs_off, per_step as u32));
4568                }
4569
4570                // If recursive checkpointing is active, we also compile
4571                // the forward body so the executor can recompute
4572                // intermediate carries. The forward body is supplied
4573                // by the AD pass via `forward_body: Some(_)`.
4574                let (fb_schedule, fb_init, fb_carry_in_off, fb_output_off, fb_x_offs) =
4575                    if is_recursive {
4576                        let fb = forward_body.as_ref().unwrap();
4577                        let fb_plan = rlx_opt::memory::plan_memory(fb);
4578                        let fb_offsets: HashMap<NodeId, usize> = fb_plan
4579                            .assignments
4580                            .iter()
4581                            .map(|(id, slot)| (*id, slot.offset))
4582                            .collect();
4583                        let mut fb_inputs: Vec<NodeId> = fb
4584                            .nodes()
4585                            .iter()
4586                            .filter(|n| matches!(n.op, Op::Input { .. }))
4587                            .map(|n| n.id)
4588                            .collect();
4589                        fb_inputs.sort();
4590                        let fb_carry = fb_offsets[&fb_inputs[0]];
4591                        let fb_xs: Vec<usize> = (1..fb_inputs.len())
4592                            .map(|i| fb_offsets[&fb_inputs[i]])
4593                            .collect();
4594                        let fb_out = fb_offsets[&fb.outputs[0]];
4595                        let mut fb_arena = crate::arena::Arena::from_plan(fb_plan);
4596                        for n in fb.nodes() {
4597                            if let Op::Constant { data } = &n.op
4598                                && fb_arena.has_buffer(n.id)
4599                                && !data.is_empty()
4600                            {
4601                                // Byte-copy works for any
4602                                // numeric dtype as long as the
4603                                // arena slot is sized to hold
4604                                // it — the Constant's `data`
4605                                // already encodes the right
4606                                // bytes per element.
4607                                let off = fb_arena.byte_offset(n.id);
4608                                let buf = fb_arena.raw_buf_mut();
4609                                let nb = (buf.len() - off).min(data.len());
4610                                buf[off..off + nb].copy_from_slice(&data[..nb]);
4611                            }
4612                        }
4613                        let fb_init_bytes = fb_arena.raw_buf().to_vec();
4614                        let fb_sched = compile_thunks(fb, &fb_arena);
4615                        (
4616                            Some(Arc::new(fb_sched)),
4617                            Some(Arc::new(fb_init_bytes)),
4618                            fb_carry,
4619                            fb_out,
4620                            fb_xs,
4621                        )
4622                    } else {
4623                        (None, None, 0, 0, Vec::new())
4624                    };
4625
4626                Thunk::ScanBackward {
4627                    body_vjp: Arc::new(body_schedule),
4628                    body_init: Arc::new(body_init),
4629                    body_carry_in_off,
4630                    body_x_offs: Arc::new(body_x_offs),
4631                    body_d_output_off,
4632                    body_dcarry_out_off,
4633                    outer_init_off: node_offset(arena, node.inputs[0]),
4634                    outer_traj_off: node_offset(arena, node.inputs[1]),
4635                    outer_upstream_off: node_offset(arena, node.inputs[2]),
4636                    outer_xs_offs: Arc::new(outer_xs_offs),
4637                    outer_dinit_off: node_offset(arena, node.id),
4638                    length: *length,
4639                    carry_bytes: carry_bytes as u32,
4640                    carry_elem_size,
4641                    save_trajectory: *save_trajectory,
4642                    num_checkpoints: *num_checkpoints,
4643                    forward_body: fb_schedule,
4644                    forward_body_init: fb_init,
4645                    forward_body_carry_in_off: fb_carry_in_off,
4646                    forward_body_output_off: fb_output_off,
4647                    forward_body_x_offs: Arc::new(fb_x_offs),
4648                }
4649            }
4650
4651            Op::ScanBackwardXs {
4652                body_vjp,
4653                length,
4654                save_trajectory,
4655                num_xs,
4656                xs_idx,
4657                num_checkpoints,
4658                forward_body,
4659            } => {
4660                assert!(
4661                    *num_checkpoints == 0 || *num_checkpoints <= *length,
4662                    "Op::ScanBackwardXs: num_checkpoints={} must be 0 or ≤ length={}",
4663                    *num_checkpoints,
4664                    *length
4665                );
4666                let is_recursive = *num_checkpoints != 0 && *num_checkpoints != *length;
4667                if is_recursive {
4668                    assert!(
4669                        forward_body.is_some(),
4670                        "Op::ScanBackwardXs with num_checkpoints<length \
4671                         requires forward_body"
4672                    );
4673                }
4674                // Mirror ScanBackward's body_vjp slot identification +
4675                // arena prep, then add: per-iteration extraction of the
4676                // body_vjp output that corresponds to the chosen xs.
4677                //
4678                // body_vjp's outputs (from `grad(body, [carry, xs_0, ..., xs_{num_xs-1}])`):
4679                //   outputs[0]      = dcarry
4680                //   outputs[1 + i]  = dx_t_i
4681                let body_plan = rlx_opt::memory::plan_memory(body_vjp);
4682                let body_offsets: HashMap<NodeId, usize> = body_plan
4683                    .assignments
4684                    .iter()
4685                    .map(|(id, slot)| (*id, slot.offset))
4686                    .collect();
4687                let mut body_d_output_off: Option<usize> = None;
4688                let mut body_other_inputs: Vec<(NodeId, usize)> = Vec::new();
4689                for n in body_vjp.nodes() {
4690                    if let Op::Input { name } = &n.op {
4691                        let off = body_offsets[&n.id];
4692                        if name == "d_output" {
4693                            body_d_output_off = Some(off);
4694                        } else {
4695                            body_other_inputs.push((n.id, off));
4696                        }
4697                    }
4698                }
4699                body_other_inputs.sort_by_key(|(id, _)| *id);
4700                let body_d_output_off =
4701                    body_d_output_off.expect("ScanBackwardXs body_vjp missing 'd_output' Input");
4702                let expected_others = 1 + *num_xs as usize;
4703                assert_eq!(
4704                    body_other_inputs.len(),
4705                    expected_others,
4706                    "ScanBackwardXs body_vjp has {} non-d_output Inputs; expected {}",
4707                    body_other_inputs.len(),
4708                    expected_others
4709                );
4710                let body_carry_in_off = body_other_inputs[0].1;
4711                let body_x_offs: Vec<usize> = body_other_inputs
4712                    .iter()
4713                    .skip(1)
4714                    .map(|(_, off)| *off)
4715                    .collect();
4716                let body_dcarry_out_off = body_offsets[&body_vjp.outputs[0]];
4717                let dxs_out_node = body_vjp.outputs[1 + *xs_idx as usize];
4718                let body_dxs_out_off = body_offsets[&dxs_out_node];
4719
4720                let mut body_arena = crate::arena::Arena::from_plan(body_plan);
4721                for n in body_vjp.nodes() {
4722                    if let Op::Constant { data } = &n.op
4723                        && body_arena.has_buffer(n.id)
4724                        && !data.is_empty()
4725                    {
4726                        match n.shape.dtype() {
4727                            rlx_ir::DType::F64 => {
4728                                let off = body_arena.byte_offset(n.id);
4729                                let buf = body_arena.raw_buf_mut();
4730                                let nb = (buf.len() - off).min(data.len());
4731                                buf[off..off + nb].copy_from_slice(&data[..nb]);
4732                            }
4733                            _ => {
4734                                let buf = body_arena.slice_mut(n.id);
4735                                let nf = data.len() / 4;
4736                                let nl = buf.len().min(nf);
4737                                for i in 0..nl {
4738                                    let bytes = [
4739                                        data[i * 4],
4740                                        data[i * 4 + 1],
4741                                        data[i * 4 + 2],
4742                                        data[i * 4 + 3],
4743                                    ];
4744                                    buf[i] = f32::from_le_bytes(bytes);
4745                                }
4746                            }
4747                        }
4748                    }
4749                }
4750                let body_init = body_arena.raw_buf().to_vec();
4751                let body_schedule = compile_thunks(body_vjp, &body_arena);
4752
4753                let carry_bytes = body_vjp
4754                    .node(body_vjp.outputs[0])
4755                    .shape
4756                    .size_bytes()
4757                    .expect("ScanBackwardXs dcarry must be statically shaped");
4758                let carry_elem_size = body_vjp
4759                    .node(body_vjp.outputs[0])
4760                    .shape
4761                    .dtype()
4762                    .size_bytes() as u32;
4763                let per_step_bytes = body_vjp
4764                    .node(dxs_out_node)
4765                    .shape
4766                    .size_bytes()
4767                    .expect("ScanBackwardXs dxs body output must be statically shaped");
4768
4769                let mut outer_xs_offs: Vec<(usize, u32)> = Vec::with_capacity(*num_xs as usize);
4770                for i in 0..*num_xs as usize {
4771                    let outer_xs_id = node.inputs[3 + i];
4772                    let outer_xs_off = node_offset(arena, outer_xs_id);
4773                    let outer_xs_shape = &graph.node(outer_xs_id).shape;
4774                    let total = outer_xs_shape
4775                        .size_bytes()
4776                        .expect("ScanBackwardXs xs must have static shape");
4777                    let per_step = total / *length as usize;
4778                    outer_xs_offs.push((outer_xs_off, per_step as u32));
4779                }
4780
4781                // Compile forward_body for recompute when checkpointed.
4782                // Mirrors the same code path in the ScanBackward arm.
4783                let (fb_schedule, fb_init, fb_carry_in_off, fb_output_off, fb_x_offs) =
4784                    if is_recursive {
4785                        let fb = forward_body.as_ref().unwrap();
4786                        let fb_plan = rlx_opt::memory::plan_memory(fb);
4787                        let fb_offsets: HashMap<NodeId, usize> = fb_plan
4788                            .assignments
4789                            .iter()
4790                            .map(|(id, slot)| (*id, slot.offset))
4791                            .collect();
4792                        let mut fb_inputs: Vec<NodeId> = fb
4793                            .nodes()
4794                            .iter()
4795                            .filter(|n| matches!(n.op, Op::Input { .. }))
4796                            .map(|n| n.id)
4797                            .collect();
4798                        fb_inputs.sort();
4799                        let fb_carry = fb_offsets[&fb_inputs[0]];
4800                        let fb_xs: Vec<usize> = (1..fb_inputs.len())
4801                            .map(|i| fb_offsets[&fb_inputs[i]])
4802                            .collect();
4803                        let fb_out = fb_offsets[&fb.outputs[0]];
4804                        let mut fb_arena = crate::arena::Arena::from_plan(fb_plan);
4805                        for n in fb.nodes() {
4806                            if let Op::Constant { data } = &n.op
4807                                && fb_arena.has_buffer(n.id)
4808                                && !data.is_empty()
4809                            {
4810                                // Byte-copy works for any
4811                                // numeric dtype as long as the
4812                                // arena slot is sized to hold
4813                                // it — the Constant's `data`
4814                                // already encodes the right
4815                                // bytes per element.
4816                                let off = fb_arena.byte_offset(n.id);
4817                                let buf = fb_arena.raw_buf_mut();
4818                                let nb = (buf.len() - off).min(data.len());
4819                                buf[off..off + nb].copy_from_slice(&data[..nb]);
4820                            }
4821                        }
4822                        let fb_init_bytes = fb_arena.raw_buf().to_vec();
4823                        let fb_sched = compile_thunks(fb, &fb_arena);
4824                        (
4825                            Some(Arc::new(fb_sched)),
4826                            Some(Arc::new(fb_init_bytes)),
4827                            fb_carry,
4828                            fb_out,
4829                            fb_xs,
4830                        )
4831                    } else {
4832                        (None, None, 0, 0, Vec::new())
4833                    };
4834
4835                Thunk::ScanBackwardXs {
4836                    body_vjp: Arc::new(body_schedule),
4837                    body_init: Arc::new(body_init),
4838                    body_carry_in_off,
4839                    body_x_offs: Arc::new(body_x_offs),
4840                    body_d_output_off,
4841                    body_dcarry_out_off,
4842                    body_dxs_out_off,
4843                    outer_init_off: node_offset(arena, node.inputs[0]),
4844                    outer_traj_off: node_offset(arena, node.inputs[1]),
4845                    outer_upstream_off: node_offset(arena, node.inputs[2]),
4846                    outer_xs_offs: Arc::new(outer_xs_offs),
4847                    outer_dxs_off: node_offset(arena, node.id),
4848                    length: *length,
4849                    carry_bytes: carry_bytes as u32,
4850                    carry_elem_size,
4851                    per_step_bytes: per_step_bytes as u32,
4852                    save_trajectory: *save_trajectory,
4853                    num_checkpoints: *num_checkpoints,
4854                    forward_body: fb_schedule,
4855                    forward_body_init: fb_init,
4856                    forward_body_carry_in_off: fb_carry_in_off,
4857                    forward_body_output_off: fb_output_off,
4858                    forward_body_x_offs: Arc::new(fb_x_offs),
4859                }
4860            }
4861
4862            Op::Concat { axis } => {
4863                // Compute outer/inner from the OUTPUT shape: all inputs share
4864                // the same shape except along `axis`. The output's leading
4865                // and trailing dims match.
4866                let out_shape = &node.shape;
4867                let rank = out_shape.rank();
4868                let outer: usize = (0..*axis)
4869                    .map(|i| out_shape.dim(i).unwrap_static())
4870                    .product::<usize>()
4871                    .max(1);
4872                let inner: usize = (*axis + 1..rank)
4873                    .map(|i| out_shape.dim(i).unwrap_static())
4874                    .product::<usize>()
4875                    .max(1);
4876                let total_axis = out_shape.dim(*axis).unwrap_static();
4877                let inputs: Vec<(usize, u32)> = node
4878                    .inputs
4879                    .iter()
4880                    .map(|&in_id| {
4881                        let in_shape = &graph.node(in_id).shape;
4882                        let in_axis = in_shape.dim(*axis).unwrap_static();
4883                        (node_offset(arena, in_id), in_axis as u32)
4884                    })
4885                    .collect();
4886                let dst = node_offset(arena, node.id);
4887                match out_shape.dtype() {
4888                    rlx_ir::DType::F64 => Thunk::ConcatF64 {
4889                        dst,
4890                        outer: outer as u32,
4891                        inner: inner as u32,
4892                        total_axis: total_axis as u32,
4893                        inputs,
4894                    },
4895                    _ => Thunk::Concat {
4896                        dst,
4897                        outer: outer as u32,
4898                        inner: inner as u32,
4899                        total_axis: total_axis as u32,
4900                        inputs,
4901                    },
4902                }
4903            }
4904
4905            Op::GaussianSplatRender {
4906                width,
4907                height,
4908                tile_size,
4909                radius_scale,
4910                alpha_cutoff,
4911                max_splat_steps,
4912                transmittance_threshold,
4913                max_list_entries,
4914            } => {
4915                let elem_len =
4916                    |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
4917                Thunk::GaussianSplatRender {
4918                    positions_off: node_offset(arena, node.inputs[0]),
4919                    positions_len: elem_len(node.inputs[0]),
4920                    scales_off: node_offset(arena, node.inputs[1]),
4921                    scales_len: elem_len(node.inputs[1]),
4922                    rotations_off: node_offset(arena, node.inputs[2]),
4923                    rotations_len: elem_len(node.inputs[2]),
4924                    opacities_off: node_offset(arena, node.inputs[3]),
4925                    opacities_len: elem_len(node.inputs[3]),
4926                    colors_off: node_offset(arena, node.inputs[4]),
4927                    colors_len: elem_len(node.inputs[4]),
4928                    sh_coeffs_off: node_offset(arena, node.inputs[5]),
4929                    sh_coeffs_len: elem_len(node.inputs[5]),
4930                    meta_off: node_offset(arena, node.inputs[6]),
4931                    dst_off: node_offset(arena, node.id),
4932                    dst_len: node.shape.num_elements().unwrap_or(0),
4933                    width: *width,
4934                    height: *height,
4935                    tile_size: *tile_size,
4936                    radius_scale: *radius_scale,
4937                    alpha_cutoff: *alpha_cutoff,
4938                    max_splat_steps: *max_splat_steps,
4939                    transmittance_threshold: *transmittance_threshold,
4940                    max_list_entries: *max_list_entries,
4941                }
4942            }
4943
4944            Op::GaussianSplatRenderBackward {
4945                width,
4946                height,
4947                tile_size,
4948                radius_scale,
4949                alpha_cutoff,
4950                max_splat_steps,
4951                transmittance_threshold,
4952                max_list_entries,
4953                loss_grad_clip,
4954                sh_band,
4955                max_anisotropy,
4956            } => {
4957                let elem_len =
4958                    |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
4959                Thunk::GaussianSplatRenderBackward {
4960                    positions_off: node_offset(arena, node.inputs[0]),
4961                    positions_len: elem_len(node.inputs[0]),
4962                    scales_off: node_offset(arena, node.inputs[1]),
4963                    scales_len: elem_len(node.inputs[1]),
4964                    rotations_off: node_offset(arena, node.inputs[2]),
4965                    rotations_len: elem_len(node.inputs[2]),
4966                    opacities_off: node_offset(arena, node.inputs[3]),
4967                    opacities_len: elem_len(node.inputs[3]),
4968                    colors_off: node_offset(arena, node.inputs[4]),
4969                    colors_len: elem_len(node.inputs[4]),
4970                    sh_coeffs_off: node_offset(arena, node.inputs[5]),
4971                    sh_coeffs_len: elem_len(node.inputs[5]),
4972                    meta_off: node_offset(arena, node.inputs[6]),
4973                    d_loss_off: node_offset(arena, node.inputs[7]),
4974                    d_loss_len: elem_len(node.inputs[7]),
4975                    packed_off: node_offset(arena, node.id),
4976                    packed_len: node.shape.num_elements().unwrap_or(0),
4977                    width: *width,
4978                    height: *height,
4979                    tile_size: *tile_size,
4980                    radius_scale: *radius_scale,
4981                    alpha_cutoff: *alpha_cutoff,
4982                    max_splat_steps: *max_splat_steps,
4983                    transmittance_threshold: *transmittance_threshold,
4984                    max_list_entries: *max_list_entries,
4985                    loss_grad_clip: *loss_grad_clip,
4986                    sh_band: *sh_band,
4987                    max_anisotropy: *max_anisotropy,
4988                }
4989            }
4990
4991            Op::GaussianSplatPrepare {
4992                width,
4993                height,
4994                tile_size,
4995                radius_scale,
4996                alpha_cutoff,
4997                max_splat_steps,
4998                transmittance_threshold,
4999                max_list_entries,
5000            } => {
5001                let elem_len =
5002                    |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
5003                Thunk::GaussianSplatPrepare {
5004                    positions_off: node_offset(arena, node.inputs[0]),
5005                    positions_len: elem_len(node.inputs[0]),
5006                    scales_off: node_offset(arena, node.inputs[1]),
5007                    scales_len: elem_len(node.inputs[1]),
5008                    rotations_off: node_offset(arena, node.inputs[2]),
5009                    rotations_len: elem_len(node.inputs[2]),
5010                    opacities_off: node_offset(arena, node.inputs[3]),
5011                    opacities_len: elem_len(node.inputs[3]),
5012                    colors_off: node_offset(arena, node.inputs[4]),
5013                    colors_len: elem_len(node.inputs[4]),
5014                    sh_coeffs_off: node_offset(arena, node.inputs[5]),
5015                    sh_coeffs_len: elem_len(node.inputs[5]),
5016                    meta_off: node_offset(arena, node.inputs[6]),
5017                    meta_len: elem_len(node.inputs[6]),
5018                    prep_off: node_offset(arena, node.id),
5019                    prep_len: node.shape.num_elements().unwrap_or(0),
5020                    width: *width,
5021                    height: *height,
5022                    tile_size: *tile_size,
5023                    radius_scale: *radius_scale,
5024                    alpha_cutoff: *alpha_cutoff,
5025                    max_splat_steps: *max_splat_steps,
5026                    transmittance_threshold: *transmittance_threshold,
5027                    max_list_entries: *max_list_entries,
5028                }
5029            }
5030
5031            Op::GaussianSplatRasterize {
5032                width,
5033                height,
5034                tile_size,
5035                alpha_cutoff,
5036                max_splat_steps,
5037                transmittance_threshold,
5038                max_list_entries,
5039            } => {
5040                let elem_len =
5041                    |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
5042                let prep_id = node.inputs[0];
5043                let count = match &graph.node(prep_id).op {
5044                    rlx_ir::Op::GaussianSplatPrepare { .. } => {
5045                        elem_len(graph.node(prep_id).inputs[0]) / 3
5046                    }
5047                    _ => 1,
5048                };
5049                Thunk::GaussianSplatRasterize {
5050                    prep_off: node_offset(arena, prep_id),
5051                    prep_len: elem_len(prep_id),
5052                    meta_off: node_offset(arena, node.inputs[1]),
5053                    meta_len: elem_len(node.inputs[1]),
5054                    dst_off: node_offset(arena, node.id),
5055                    dst_len: node.shape.num_elements().unwrap_or(0),
5056                    count,
5057                    width: *width,
5058                    height: *height,
5059                    tile_size: *tile_size,
5060                    alpha_cutoff: *alpha_cutoff,
5061                    max_splat_steps: *max_splat_steps,
5062                    transmittance_threshold: *transmittance_threshold,
5063                    max_list_entries: *max_list_entries,
5064                }
5065            }
5066
5067            Op::Custom { name, attrs, .. } => {
5068                let kernel = crate::op_registry::lookup_cpu_kernel(name).unwrap_or_else(|| {
5069                    panic!(
5070                        "compile_thunks: no CPU kernel registered for \
5071                         Op::Custom('{name}'). Register one via \
5072                         rlx_cpu::op_registry::register_cpu_kernel \
5073                         before compiling on the CPU backend."
5074                    )
5075                });
5076                let inputs_v: Vec<(usize, u32, Shape)> = node
5077                    .inputs
5078                    .iter()
5079                    .map(|&in_id| {
5080                        let s = graph.node(in_id).shape.clone();
5081                        let len = s.num_elements().unwrap_or(0) as u32;
5082                        (node_offset(arena, in_id), len, s)
5083                    })
5084                    .collect();
5085                let out_len = node.shape.num_elements().unwrap_or(0) as u32;
5086                Thunk::CustomOp {
5087                    kernel,
5088                    inputs: inputs_v,
5089                    output: (node_offset(arena, node.id), out_len, node.shape.clone()),
5090                    attrs: attrs.clone(),
5091                }
5092            }
5093
5094            Op::Fft { inverse } => {
5095                // Last axis carries the 2N real-block layout; complex
5096                // points = N = last_dim / 2. `outer` is the product
5097                // of all preceding axes — the kernel iterates one
5098                // batch-row at a time. f32 and f64 share the same
5099                // radix-2 structure but use separate scratch buffers;
5100                // the dtype is captured here so the closure dispatches
5101                // without per-row branching.
5102                let shape = &node.shape;
5103                let last = shape.dim(shape.rank() - 1).unwrap_static();
5104                let n_complex = (last / 2) as u32;
5105                let total = shape.num_elements().unwrap_or(0);
5106                let outer = (total / last) as u32;
5107                let dtype = shape.dtype();
5108                assert!(
5109                    matches!(dtype, rlx_ir::DType::F32 | rlx_ir::DType::F64),
5110                    "Op::Fft on CPU requires F32 or F64, got {dtype:?}"
5111                );
5112                Thunk::Fft1d {
5113                    src: node_offset(arena, node.inputs[0]),
5114                    dst: node_offset(arena, node.id),
5115                    outer,
5116                    n_complex,
5117                    inverse: *inverse,
5118                    dtype,
5119                }
5120            }
5121
5122            Op::CustomFn {
5123                fwd_body,
5124                num_inputs,
5125                ..
5126            } => {
5127                // Plan + compile the body sub-graph standalone, fill its
5128                // Constants (mirrors the Op::Scan body lowering), then
5129                // capture per-input copy specs and the output spec.
5130                // Body Inputs in NodeId order match the outer node's
5131                // operand vector by position.
5132                let body_plan = rlx_opt::memory::plan_memory(fwd_body);
5133                let body_offsets: HashMap<NodeId, usize> = body_plan
5134                    .assignments
5135                    .iter()
5136                    .map(|(id, slot)| (*id, slot.offset))
5137                    .collect();
5138
5139                let mut body_input_ids: Vec<NodeId> = fwd_body
5140                    .nodes()
5141                    .iter()
5142                    .filter(|n| matches!(n.op, Op::Input { .. }))
5143                    .map(|n| n.id)
5144                    .collect();
5145                body_input_ids.sort();
5146                assert_eq!(
5147                    body_input_ids.len(),
5148                    *num_inputs as usize,
5149                    "Op::CustomFn fwd_body has {} Op::Input(s); declared num_inputs={}",
5150                    body_input_ids.len(),
5151                    *num_inputs,
5152                );
5153
5154                let mut body_arena = crate::arena::Arena::from_plan(body_plan);
5155                for n in fwd_body.nodes() {
5156                    if let Op::Constant { data } = &n.op
5157                        && body_arena.has_buffer(n.id)
5158                        && !data.is_empty()
5159                    {
5160                        match n.shape.dtype() {
5161                            rlx_ir::DType::F64 => {
5162                                let off = body_arena.byte_offset(n.id);
5163                                let buf = body_arena.raw_buf_mut();
5164                                let nb = (buf.len() - off).min(data.len());
5165                                buf[off..off + nb].copy_from_slice(&data[..nb]);
5166                            }
5167                            _ => {
5168                                let buf = body_arena.slice_mut(n.id);
5169                                let nf = data.len() / 4;
5170                                let nl = buf.len().min(nf);
5171                                for i in 0..nl {
5172                                    let bytes = [
5173                                        data[i * 4],
5174                                        data[i * 4 + 1],
5175                                        data[i * 4 + 2],
5176                                        data[i * 4 + 3],
5177                                    ];
5178                                    buf[i] = f32::from_le_bytes(bytes);
5179                                }
5180                            }
5181                        }
5182                    }
5183                }
5184                let body_init = body_arena.raw_buf().to_vec();
5185                let body_schedule = compile_thunks(fwd_body, &body_arena);
5186
5187                // Per primal input: (body_input_off, outer_input_off, bytes).
5188                let inputs_v: Vec<(usize, usize, u32)> = (0..*num_inputs as usize)
5189                    .map(|i| {
5190                        let body_in = body_input_ids[i];
5191                        let body_off = body_offsets[&body_in];
5192                        let outer_in = node.inputs[i];
5193                        let outer_off = node_offset(arena, outer_in);
5194                        let bytes = graph
5195                            .node(outer_in)
5196                            .shape
5197                            .size_bytes()
5198                            .expect("Op::CustomFn primal input must have static shape");
5199                        (body_off, outer_off, bytes as u32)
5200                    })
5201                    .collect();
5202
5203                let body_output_id = fwd_body
5204                    .outputs
5205                    .first()
5206                    .copied()
5207                    .expect("Op::CustomFn fwd_body must declare exactly one output");
5208                let body_output_off = body_offsets[&body_output_id];
5209                let out_bytes = node
5210                    .shape
5211                    .size_bytes()
5212                    .expect("Op::CustomFn output must have static shape");
5213
5214                Thunk::CustomFn {
5215                    body: Arc::new(body_schedule),
5216                    body_init: Arc::new(body_init),
5217                    inputs: Arc::new(inputs_v),
5218                    body_output_off,
5219                    outer_output_off: node_offset(arena, node.id),
5220                    out_bytes: out_bytes as u32,
5221                }
5222            }
5223
5224            _ => Thunk::Nop,
5225        };
5226        thunks.push(t);
5227    }
5228
5229    let cfg = crate::config::RuntimeConfig::global();
5230    let mask_thr = cfg.mask_binary_threshold;
5231    let mask_neg = cfg.attn_mask_neg_inf;
5232    let score_skip = cfg.score_skip_threshold;
5233
5234    // Pre-compile closures (skip Nops — they're filtered out)
5235    let compiled_fns: Vec<Arc<dyn Fn(*mut u8) + Send + Sync>> = thunks
5236        .iter()
5237        .filter(|t| !matches!(t, Thunk::Nop))
5238        .map(|thunk| {
5239            match thunk.clone() {
5240                Thunk::Nop => Arc::new(|_: *mut u8| {}) as Arc<dyn Fn(*mut u8) + Send + Sync>,
5241
5242                Thunk::Sgemm { a, b, c, m, k, n } => {
5243                    let (m, k, n) = (m as usize, k as usize, n as usize);
5244                    Arc::new(move |base: *mut u8| unsafe {
5245                        crate::blas::sgemm(
5246                            sl(a, base, m * k),
5247                            sl(b, base, k * n),
5248                            sl_mut(c, base, m * n),
5249                            m,
5250                            k,
5251                            n,
5252                        );
5253                    })
5254                }
5255
5256                Thunk::DenseSolveF64 { a, b, x, n, nrhs } => {
5257                    let (n_, nrhs_) = (n as usize, nrhs as usize);
5258                    Arc::new(move |base: *mut u8| unsafe {
5259                        let a_src = sl_f64(a, base, n_ * n_);
5260                        let b_src = sl_f64(b, base, n_ * nrhs_);
5261                        let mut a_scratch: Vec<f64> = a_src.to_vec();
5262                        let mut x_buf: Vec<f64> = b_src.to_vec();
5263                        let info = crate::blas::dgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
5264                        if info != 0 {
5265                            panic!("DenseSolveF64: singular (info={info})");
5266                        }
5267                        sl_mut_f64(x, base, n_ * nrhs_).copy_from_slice(&x_buf);
5268                    })
5269                }
5270
5271                Thunk::DenseSolveF32 { a, b, x, n, nrhs } => {
5272                    let (n_, nrhs_) = (n as usize, nrhs as usize);
5273                    Arc::new(move |base: *mut u8| unsafe {
5274                        let a_src = sl(a, base, n_ * n_);
5275                        let b_src = sl(b, base, n_ * nrhs_);
5276                        let mut a_scratch: Vec<f32> = a_src.to_vec();
5277                        let mut x_buf: Vec<f32> = b_src.to_vec();
5278                        let info = crate::blas::sgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
5279                        if info != 0 {
5280                            panic!("DenseSolveF32: singular (info={info})");
5281                        }
5282                        sl_mut(x, base, n_ * nrhs_).copy_from_slice(&x_buf);
5283                    })
5284                }
5285
5286                Thunk::FusedMmBiasAct {
5287                    a,
5288                    w,
5289                    bias,
5290                    c,
5291                    m,
5292                    k,
5293                    n,
5294                    act,
5295                } => {
5296                    let (m, k, n) = (m as usize, k as usize, n as usize);
5297                    Arc::new(move |base: *mut u8| unsafe {
5298                        let out = sl_mut(c, base, m * n);
5299                        crate::blas::sgemm(sl(a, base, m * k), sl(w, base, k * n), out, m, k, n);
5300                        // Bias + activation epilogue. Gelu uses the fused
5301                        // `par_bias_gelu` kernel (bias add + Gelu in one
5302                        // pass). For everything else, do the bias add first
5303                        // and then apply the activation per-element. The
5304                        // pre-fix code dispatched `_ => bias_add` and dropped
5305                        // the activation entirely — silent correctness bug
5306                        // for Silu/Relu/Sigmoid/etc.
5307                        match act {
5308                            Some(Activation::Gelu) => {
5309                                crate::kernels::par_bias_gelu(out, sl(bias, base, n), m, n)
5310                            }
5311                            Some(other) => {
5312                                crate::blas::bias_add(out, sl(bias, base, n), m, n);
5313                                apply_activation_inplace(out, other);
5314                            }
5315                            None => crate::blas::bias_add(out, sl(bias, base, n), m, n),
5316                        }
5317                    })
5318                }
5319
5320                Thunk::FusedResidualLN {
5321                    x,
5322                    res,
5323                    bias,
5324                    g,
5325                    b,
5326                    out,
5327                    rows,
5328                    h,
5329                    eps,
5330                    has_bias,
5331                } => {
5332                    let (rows, h) = (rows as usize, h as usize);
5333                    Arc::new(move |base: *mut u8| unsafe {
5334                        let zero = vec![0f32; h]; // closure only — not hot path
5335                        let bi = if has_bias { sl(bias, base, h) } else { &zero };
5336                        let xp = sl(x, base, rows * h).as_ptr() as usize;
5337                        let rp = sl(res, base, rows * h).as_ptr() as usize;
5338                        let op = sl_mut(out, base, rows * h).as_mut_ptr() as usize;
5339                        let bp = bi.as_ptr() as usize;
5340                        let gp = sl(g, base, h).as_ptr() as usize;
5341                        let bbp = sl(b, base, h).as_ptr() as usize;
5342                        crate::pool::par_for(rows, 4, &|off, cnt| {
5343                            let xs = std::slice::from_raw_parts(
5344                                (xp as *const f32).add(off * h),
5345                                cnt * h,
5346                            );
5347                            let rs = std::slice::from_raw_parts(
5348                                (rp as *const f32).add(off * h),
5349                                cnt * h,
5350                            );
5351                            let os = std::slice::from_raw_parts_mut(
5352                                (op as *mut f32).add(off * h),
5353                                cnt * h,
5354                            );
5355                            let bi = std::slice::from_raw_parts(bp as *const f32, h);
5356                            let g = std::slice::from_raw_parts(gp as *const f32, h);
5357                            let b = std::slice::from_raw_parts(bbp as *const f32, h);
5358                            crate::kernels::residual_bias_layer_norm(
5359                                xs, rs, bi, g, b, os, cnt, h, eps,
5360                            );
5361                        });
5362                    })
5363                }
5364
5365                Thunk::BiasAdd {
5366                    src,
5367                    bias,
5368                    dst,
5369                    m,
5370                    n,
5371                } => {
5372                    let (m, n) = (m as usize, n as usize);
5373                    Arc::new(move |base: *mut u8| unsafe {
5374                        let out = sl_mut(dst, base, m * n);
5375                        out.copy_from_slice(sl(src, base, m * n));
5376                        crate::blas::bias_add(out, sl(bias, base, n), m, n);
5377                    })
5378                }
5379
5380                Thunk::Gather {
5381                    table,
5382                    table_len,
5383                    idx,
5384                    dst,
5385                    num_idx,
5386                    trailing,
5387                } => {
5388                    let (ni, tr, tl) = (num_idx as usize, trailing as usize, table_len as usize);
5389                    Arc::new(move |base: *mut u8| unsafe {
5390                        let tab = sl(table, base, tl);
5391                        let ids = sl(idx, base, ni);
5392                        let out = sl_mut(dst, base, ni * tr);
5393                        for i in 0..ni {
5394                            let row = ids[i] as usize;
5395                            out[i * tr..(i + 1) * tr]
5396                                .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
5397                        }
5398                    })
5399                }
5400
5401                Thunk::Narrow {
5402                    src,
5403                    dst,
5404                    outer,
5405                    src_stride,
5406                    dst_stride,
5407                    inner,
5408                    elem_bytes,
5409                } => {
5410                    narrow_thunk_closure(src, dst, outer, src_stride, dst_stride, inner, elem_bytes)
5411                }
5412
5413                Thunk::Copy { src, dst, len } => {
5414                    let len = len as usize;
5415                    Arc::new(move |base: *mut u8| unsafe {
5416                        sl_mut(dst, base, len).copy_from_slice(sl(src, base, len));
5417                    })
5418                }
5419
5420                Thunk::Softmax { data, rows, cols } => {
5421                    let (rows, cols) = (rows as usize, cols as usize);
5422                    Arc::new(move |base: *mut u8| unsafe {
5423                        crate::naive::softmax(sl_mut(data, base, rows * cols), rows, cols);
5424                    })
5425                }
5426
5427                Thunk::Cumsum {
5428                    src,
5429                    dst,
5430                    rows,
5431                    cols,
5432                    exclusive,
5433                } => {
5434                    let (rows, cols) = (rows as usize, cols as usize);
5435                    Arc::new(move |base: *mut u8| unsafe {
5436                        let s = sl(src, base, rows * cols);
5437                        let d = sl_mut(dst, base, rows * cols);
5438                        if exclusive {
5439                            for r in 0..rows {
5440                                let mut acc = 0.0f32;
5441                                for c in 0..cols {
5442                                    d[r * cols + c] = acc;
5443                                    acc += s[r * cols + c];
5444                                }
5445                            }
5446                        } else {
5447                            for r in 0..rows {
5448                                let mut acc = 0.0f32;
5449                                for c in 0..cols {
5450                                    acc += s[r * cols + c];
5451                                    d[r * cols + c] = acc;
5452                                }
5453                            }
5454                        }
5455                    })
5456                }
5457
5458                Thunk::Sample {
5459                    logits,
5460                    dst,
5461                    batch,
5462                    vocab,
5463                    top_k,
5464                    top_p,
5465                    temperature,
5466                    seed,
5467                } => {
5468                    let (b, v) = (batch as usize, vocab as usize);
5469                    let k = (top_k as usize).min(v);
5470                    Arc::new(move |base: *mut u8| unsafe {
5471                        let lg = sl(logits, base, b * v);
5472                        let out = sl_mut(dst, base, b);
5473                        let mut rng =
5474                            rlx_ir::Philox4x32::new(if seed == 0 { 0xDEADBEEF } else { seed });
5475                        for bi in 0..b {
5476                            let row = &lg[bi * v..(bi + 1) * v];
5477                            out[bi] = sample_row(row, k, top_p, temperature, &mut rng) as f32;
5478                        }
5479                    })
5480                }
5481
5482                Thunk::DequantMatMul {
5483                    x,
5484                    w_q,
5485                    scale,
5486                    zp,
5487                    dst,
5488                    m,
5489                    k,
5490                    n,
5491                    block_size,
5492                    is_asymmetric,
5493                } => {
5494                    let (m, k, n, bs) = (m as usize, k as usize, n as usize, block_size as usize);
5495                    let n_blocks_per_col = k.div_ceil(bs);
5496                    Arc::new(move |base: *mut u8| unsafe {
5497                        let xs = sl(x, base, m * k);
5498                        // w_q is packed i8 — use raw byte slice + reinterpret.
5499                        let raw = base.add(w_q);
5500                        let w_bytes = std::slice::from_raw_parts(raw as *const i8, k * n);
5501                        let scales = sl(scale, base, n_blocks_per_col * n);
5502                        let zps = if is_asymmetric {
5503                            sl(zp, base, n_blocks_per_col * n)
5504                        } else {
5505                            &[][..]
5506                        };
5507                        let out = sl_mut(dst, base, m * n);
5508                        dequant_matmul_int8(
5509                            xs,
5510                            w_bytes,
5511                            scales,
5512                            zps,
5513                            out,
5514                            m,
5515                            k,
5516                            n,
5517                            bs,
5518                            is_asymmetric,
5519                        );
5520                    })
5521                }
5522
5523                Thunk::DequantMatMulGguf {
5524                    x,
5525                    w_q,
5526                    dst,
5527                    m,
5528                    k,
5529                    n,
5530                    scheme,
5531                } => {
5532                    let (m, k, n) = (m as usize, k as usize, n as usize);
5533                    let block_bytes = scheme.gguf_block_bytes() as usize;
5534                    let block_elems = scheme.gguf_block_size() as usize;
5535                    let total_bytes = (k * n) / block_elems * block_bytes;
5536                    Arc::new(move |base: *mut u8| unsafe {
5537                        let xs = sl(x, base, m * k);
5538                        let w_bytes =
5539                            std::slice::from_raw_parts(base.add(w_q) as *const u8, total_bytes);
5540                        let out = sl_mut(dst, base, m * n);
5541                        crate::gguf_matmul::gguf_matmul_bt(xs, w_bytes, out, m, k, n, scheme);
5542                    })
5543                }
5544
5545                Thunk::DequantMatMulInt4 {
5546                    x,
5547                    w_q,
5548                    scale,
5549                    zp,
5550                    dst,
5551                    m,
5552                    k,
5553                    n,
5554                    block_size,
5555                    is_asymmetric,
5556                } => {
5557                    let (m, k, n, bs) = (m as usize, k as usize, n as usize, block_size as usize);
5558                    let n_blocks = k.div_ceil(bs);
5559                    Arc::new(move |base: *mut u8| unsafe {
5560                        let xs = sl(x, base, m * k);
5561                        let w_bytes = std::slice::from_raw_parts(
5562                            base.add(w_q) as *const u8,
5563                            (k * n).div_ceil(2),
5564                        );
5565                        let scales = sl(scale, base, n_blocks * n);
5566                        let zps = if is_asymmetric {
5567                            sl(zp, base, n_blocks * n)
5568                        } else {
5569                            &[][..]
5570                        };
5571                        let out = sl_mut(dst, base, m * n);
5572                        dequant_matmul_int4(
5573                            xs,
5574                            w_bytes,
5575                            scales,
5576                            zps,
5577                            out,
5578                            m,
5579                            k,
5580                            n,
5581                            bs,
5582                            is_asymmetric,
5583                        );
5584                    })
5585                }
5586
5587                Thunk::DequantMatMulFp8 {
5588                    x,
5589                    w_q,
5590                    scale,
5591                    dst,
5592                    m,
5593                    k,
5594                    n,
5595                    e5m2,
5596                } => {
5597                    let (m, k, n) = (m as usize, k as usize, n as usize);
5598                    Arc::new(move |base: *mut u8| unsafe {
5599                        let xs = sl(x, base, m * k);
5600                        let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, k * n);
5601                        let scales = sl(scale, base, n);
5602                        let out = sl_mut(dst, base, m * n);
5603                        dequant_matmul_fp8(xs, w_bytes, scales, out, m, k, n, e5m2);
5604                    })
5605                }
5606
5607                Thunk::DequantMatMulNvfp4 {
5608                    x,
5609                    w_q,
5610                    scale,
5611                    global_scale,
5612                    dst,
5613                    m,
5614                    k,
5615                    n,
5616                } => {
5617                    let (m, k, n) = (m as usize, k as usize, n as usize);
5618                    let n_scale = k.div_ceil(rlx_ir::NVFP4_GROUP_SIZE) * n;
5619                    Arc::new(move |base: *mut u8| unsafe {
5620                        let xs = sl(x, base, m * k);
5621                        let w_bytes = std::slice::from_raw_parts(
5622                            base.add(w_q) as *const u8,
5623                            (k * n).div_ceil(2),
5624                        );
5625                        let scale_bytes =
5626                            std::slice::from_raw_parts(base.add(scale) as *const u8, n_scale);
5627                        let gs = sl(global_scale, base, 1)[0];
5628                        let out = sl_mut(dst, base, m * n);
5629                        dequant_matmul_nvfp4(xs, w_bytes, scale_bytes, gs, out, m, k, n);
5630                    })
5631                }
5632
5633                Thunk::LoraMatMul {
5634                    x,
5635                    w,
5636                    a,
5637                    b,
5638                    dst,
5639                    m,
5640                    k,
5641                    n,
5642                    r,
5643                    scale,
5644                } => {
5645                    let (m, k, n, r) = (m as usize, k as usize, n as usize, r as usize);
5646                    Arc::new(move |base: *mut u8| unsafe {
5647                        let xs = sl(x, base, m * k);
5648                        let ws = sl(w, base, k * n);
5649                        let a_s = sl(a, base, k * r);
5650                        let bs = sl(b, base, r * n);
5651                        let out = sl_mut(dst, base, m * n);
5652                        // Step 1: out = x · W.
5653                        crate::blas::sgemm(xs, ws, out, m, k, n);
5654                        // Step 2: tmp = x · A (rank-r intermediate; tiny).
5655                        let mut tmp = vec![0f32; m * r];
5656                        crate::blas::sgemm(xs, a_s, &mut tmp, m, k, r);
5657                        // Step 3: out += scale * (tmp · B).
5658                        // sgemm_accumulate uses alpha=1.0 internally, so
5659                        // scale tmp first.
5660                        if scale != 1.0 {
5661                            for v in tmp.iter_mut() {
5662                                *v *= scale;
5663                            }
5664                        }
5665                        crate::blas::sgemm_accumulate(&tmp, bs, out, m, r, n);
5666                    })
5667                }
5668
5669                Thunk::LayerNorm {
5670                    src,
5671                    g,
5672                    b,
5673                    dst,
5674                    rows,
5675                    h,
5676                    eps,
5677                } => {
5678                    let (rows, h) = (rows as usize, h as usize);
5679                    Arc::new(move |base: *mut u8| unsafe {
5680                        let inp = sl(src, base, rows * h);
5681                        let gamma = sl(g, base, h);
5682                        let beta = sl(b, base, h);
5683                        let out = sl_mut(dst, base, rows * h);
5684                        for row in 0..rows {
5685                            crate::kernels::layer_norm_row(
5686                                &inp[row * h..(row + 1) * h],
5687                                gamma,
5688                                beta,
5689                                &mut out[row * h..(row + 1) * h],
5690                                h,
5691                                eps,
5692                            );
5693                        }
5694                    })
5695                }
5696
5697                Thunk::Attention {
5698                    q,
5699                    k,
5700                    v,
5701                    mask,
5702                    out,
5703                    batch,
5704                    seq,
5705                    kv_seq: _,
5706                    heads,
5707                    head_dim,
5708                    mask_kind,
5709                    q_row_stride,
5710                    k_row_stride,
5711                    v_row_stride,
5712                    bhsd,
5713                } => {
5714                    let (b, s, nh, dh) = (
5715                        batch as usize,
5716                        seq as usize,
5717                        heads as usize,
5718                        head_dim as usize,
5719                    );
5720                    let hs = nh * dh;
5721                    let qrs = q_row_stride as usize;
5722                    let krs = k_row_stride as usize;
5723                    let vrs = v_row_stride as usize;
5724                    let scale = (dh as f32).powf(-0.5);
5725                    Arc::new(move |base: *mut u8| unsafe {
5726                        // Slice lengths use the source's row stride so the
5727                        // compiler-emitted bounds checks cover the whole
5728                        // strided span (the kernel walks with q/k/v_rs).
5729                        // For [B, H, S, D] the buffer is dense B*H*S*D.
5730                        let (q_len, k_len, v_len, o_len) = if bhsd {
5731                            let n = b * nh * s * dh;
5732                            (n, n, n, n)
5733                        } else {
5734                            (b * s * qrs, b * s * krs, b * s * vrs, b * s * hs)
5735                        };
5736                        let q_d = sl(q, base, q_len);
5737                        let k_d = sl(k, base, k_len);
5738                        let v_d = sl(v, base, v_len);
5739                        let m_d: &[f32] = match mask_kind {
5740                            rlx_ir::op::MaskKind::Custom => sl(mask, base, b * s),
5741                            rlx_ir::op::MaskKind::Bias => sl(mask, base, b * nh * s * s),
5742                            _ => &[],
5743                        };
5744                        let o_d = sl_mut(out, base, o_len);
5745                        let sdh = s * dh;
5746                        let mut qh = vec![0f32; sdh];
5747                        let mut kh = vec![0f32; sdh];
5748                        let mut vh = vec![0f32; sdh];
5749                        let mut sc = vec![0f32; s * s];
5750                        let mut oh = vec![0f32; sdh];
5751                        for bi in 0..b {
5752                            for hi in 0..nh {
5753                                for si in 0..s {
5754                                    // Two layouts:
5755                                    //   bhsd=false: [B, S, H, D] (default) →
5756                                    //     off = bi*S*RS + si*RS + hi*D
5757                                    //   bhsd=true:  [B, H, S, D] (GPU/TPU
5758                                    //     convention) →
5759                                    //     off = bi*H*S*D + hi*S*D + si*D
5760                                    // The thunk-fusion pass below sets row
5761                                    // strides, but only for the [B, S, H, D]
5762                                    // case. For bhsd we always use the dense
5763                                    // contiguous stride (qrs == krs == vrs ==
5764                                    // H*D from compile_thunks).
5765                                    let (q_off, k_off, v_off) = if bhsd {
5766                                        (
5767                                            bi * nh * s * dh + hi * s * dh + si * dh,
5768                                            bi * nh * s * dh + hi * s * dh + si * dh,
5769                                            bi * nh * s * dh + hi * s * dh + si * dh,
5770                                        )
5771                                    } else {
5772                                        (
5773                                            bi * s * qrs + si * qrs + hi * dh,
5774                                            bi * s * krs + si * krs + hi * dh,
5775                                            bi * s * vrs + si * vrs + hi * dh,
5776                                        )
5777                                    };
5778                                    qh[si * dh..(si + 1) * dh]
5779                                        .copy_from_slice(&q_d[q_off..q_off + dh]);
5780                                    kh[si * dh..(si + 1) * dh]
5781                                        .copy_from_slice(&k_d[k_off..k_off + dh]);
5782                                    vh[si * dh..(si + 1) * dh]
5783                                        .copy_from_slice(&v_d[v_off..v_off + dh]);
5784                                }
5785                                for qi in 0..s {
5786                                    for ki in 0..s {
5787                                        let mut dot = 0f32;
5788                                        for d in 0..dh {
5789                                            dot += qh[qi * dh + d] * kh[ki * dh + d];
5790                                        }
5791                                        sc[qi * s + ki] = dot * scale;
5792                                    }
5793                                }
5794                                // Apply mask kind — None skips entirely, Causal /
5795                                // SlidingWindow synthesize, Custom reads m_d.
5796                                match mask_kind {
5797                                    rlx_ir::op::MaskKind::None => {}
5798                                    rlx_ir::op::MaskKind::Causal => {
5799                                        for qi in 0..s {
5800                                            for ki in (qi + 1)..s {
5801                                                sc[qi * s + ki] = mask_neg;
5802                                            }
5803                                        }
5804                                    }
5805                                    rlx_ir::op::MaskKind::SlidingWindow(w) => {
5806                                        for qi in 0..s {
5807                                            let lo = qi.saturating_sub(w);
5808                                            for ki in 0..s {
5809                                                if ki < lo || ki > qi {
5810                                                    sc[qi * s + ki] = mask_neg;
5811                                                }
5812                                            }
5813                                        }
5814                                    }
5815                                    rlx_ir::op::MaskKind::Custom => {
5816                                        for qi in 0..s {
5817                                            for ki in 0..s {
5818                                                if m_d[bi * s + ki] < mask_thr {
5819                                                    sc[qi * s + ki] = mask_neg;
5820                                                }
5821                                            }
5822                                        }
5823                                    }
5824                                    rlx_ir::op::MaskKind::Bias => {
5825                                        let per_bh = s * s;
5826                                        let off = (bi * nh + hi) * per_bh;
5827                                        for i in 0..per_bh {
5828                                            sc[i] += m_d[off + i];
5829                                        }
5830                                    }
5831                                }
5832                                crate::naive::softmax(&mut sc, s, s);
5833                                oh.fill(0.0);
5834                                for qi in 0..s {
5835                                    for ki in 0..s {
5836                                        let w = sc[qi * s + ki];
5837                                        if w > score_skip {
5838                                            for d in 0..dh {
5839                                                oh[qi * dh + d] += w * vh[ki * dh + d];
5840                                            }
5841                                        }
5842                                    }
5843                                }
5844                                for si in 0..s {
5845                                    let off = if bhsd {
5846                                        bi * nh * s * dh + hi * s * dh + si * dh
5847                                    } else {
5848                                        bi * s * hs + si * hs + hi * dh
5849                                    };
5850                                    o_d[off..off + dh].copy_from_slice(&oh[si * dh..(si + 1) * dh]);
5851                                }
5852                            }
5853                        }
5854                    })
5855                }
5856
5857                Thunk::FusedSwiGLU {
5858                    src,
5859                    dst,
5860                    n_half,
5861                    total,
5862                    gate_first,
5863                } => {
5864                    let n = n_half as usize;
5865                    let t = total as usize;
5866                    let outer = t / n;
5867                    let in_total = outer * 2 * n;
5868                    Arc::new(move |base: *mut u8| unsafe {
5869                        let inp = sl(src, base, in_total);
5870                        let out = sl_mut(dst, base, t);
5871                        for o in 0..outer {
5872                            let in_row = &inp[o * 2 * n..(o + 1) * 2 * n];
5873                            let out_row = &mut out[o * n..(o + 1) * n];
5874                            for i in 0..n {
5875                                let (up, gate) = if gate_first {
5876                                    (in_row[n + i], in_row[i])
5877                                } else {
5878                                    (in_row[i], in_row[n + i])
5879                                };
5880                                out_row[i] = up * (gate / (1.0 + (-gate).exp()));
5881                            }
5882                        }
5883                    })
5884                }
5885
5886                Thunk::Concat {
5887                    dst,
5888                    outer,
5889                    inner,
5890                    total_axis,
5891                    inputs,
5892                } => {
5893                    let outer = outer as usize;
5894                    let inner = inner as usize;
5895                    let total_axis = total_axis as usize;
5896                    let out_total = outer * total_axis * inner;
5897                    // Pre-compute the destination row offset for each input
5898                    // (cumulative axis offsets times inner).
5899                    let mut layout: Vec<(usize, usize, usize)> = Vec::with_capacity(inputs.len());
5900                    let mut cum: usize = 0;
5901                    for (src_off, in_axis) in &inputs {
5902                        let in_axis = *in_axis as usize;
5903                        layout.push((*src_off, cum * inner, in_axis * inner));
5904                        cum += in_axis;
5905                    }
5906                    Arc::new(move |base: *mut u8| unsafe {
5907                        let out = sl_mut(dst, base, out_total);
5908                        let row_stride = total_axis * inner;
5909                        for (src_off, dst_col_off, copy_per_row) in &layout {
5910                            let in_total = outer * *copy_per_row;
5911                            let inp = sl(*src_off, base, in_total);
5912                            for o in 0..outer {
5913                                let dst_row_start = o * row_stride + *dst_col_off;
5914                                let src_row_start = o * *copy_per_row;
5915                                out[dst_row_start..dst_row_start + *copy_per_row].copy_from_slice(
5916                                    &inp[src_row_start..src_row_start + *copy_per_row],
5917                                );
5918                            }
5919                        }
5920                    })
5921                }
5922
5923                Thunk::CustomOp {
5924                    kernel,
5925                    inputs,
5926                    output,
5927                    attrs,
5928                } => {
5929                    // Capture-by-move: clone the Arc and Vecs once into the
5930                    // closure. Dispatch by output dtype each call (the
5931                    // dtype is fixed at compile time but it's cheaper to
5932                    // branch once per execution than to monomorphize a
5933                    // dozen closure variants).
5934                    let kernel = kernel.clone();
5935                    let attrs = attrs.clone();
5936                    let inputs = inputs.clone();
5937                    let (out_off, out_len, out_shape) = output.clone();
5938                    Arc::new(move |base: *mut u8| unsafe {
5939                        dispatch_custom_op(
5940                            &*kernel, &inputs, out_off, out_len, &out_shape, &attrs, base,
5941                        );
5942                    })
5943                }
5944
                // Forward Gaussian-splat render. The closure moves in every field
                // destructured below. On each call it forwards them, in the order
                // listed in the pattern and followed by the arena base pointer, to
                // `crate::splat::execute_gaussian_splat_render`. The call is
                // positional: keep this argument order in sync with that
                // function's parameter list.
                Thunk::GaussianSplatRender {
                    positions_off,
                    positions_len,
                    scales_off,
                    scales_len,
                    rotations_off,
                    rotations_len,
                    opacities_off,
                    opacities_len,
                    colors_off,
                    colors_len,
                    sh_coeffs_off,
                    sh_coeffs_len,
                    meta_off,
                    dst_off,
                    dst_len,
                    width,
                    height,
                    tile_size,
                    radius_scale,
                    alpha_cutoff,
                    max_splat_steps,
                    transmittance_threshold,
                    max_list_entries,
                } => Arc::new(move |base: *mut u8| unsafe {
                    crate::splat::execute_gaussian_splat_render(
                        positions_off,
                        positions_len,
                        scales_off,
                        scales_len,
                        rotations_off,
                        rotations_len,
                        opacities_off,
                        opacities_len,
                        colors_off,
                        colors_len,
                        sh_coeffs_off,
                        sh_coeffs_len,
                        meta_off,
                        dst_off,
                        dst_len,
                        width,
                        height,
                        tile_size,
                        radius_scale,
                        alpha_cutoff,
                        max_splat_steps,
                        transmittance_threshold,
                        max_list_entries,
                        base,
                    );
                }),
5997
                // Backward pass of the Gaussian-splat render. Compared with the
                // forward arm it takes `d_loss_*` and `packed_*` buffers instead
                // of `dst_*`, plus three extra settings: `loss_grad_clip`,
                // `sh_band` and `max_anisotropy`. All fields, followed by the
                // arena base pointer, are forwarded positionally to
                // `crate::splat::execute_gaussian_splat_render_backward`, so this
                // argument order must match that function's parameter list.
                Thunk::GaussianSplatRenderBackward {
                    positions_off,
                    positions_len,
                    scales_off,
                    scales_len,
                    rotations_off,
                    rotations_len,
                    opacities_off,
                    opacities_len,
                    colors_off,
                    colors_len,
                    sh_coeffs_off,
                    sh_coeffs_len,
                    meta_off,
                    d_loss_off,
                    d_loss_len,
                    packed_off,
                    packed_len,
                    width,
                    height,
                    tile_size,
                    radius_scale,
                    alpha_cutoff,
                    max_splat_steps,
                    transmittance_threshold,
                    max_list_entries,
                    loss_grad_clip,
                    sh_band,
                    max_anisotropy,
                } => Arc::new(move |base: *mut u8| unsafe {
                    crate::splat::execute_gaussian_splat_render_backward(
                        positions_off,
                        positions_len,
                        scales_off,
                        scales_len,
                        rotations_off,
                        rotations_len,
                        opacities_off,
                        opacities_len,
                        colors_off,
                        colors_len,
                        sh_coeffs_off,
                        sh_coeffs_len,
                        meta_off,
                        d_loss_off,
                        d_loss_len,
                        packed_off,
                        packed_len,
                        width,
                        height,
                        tile_size,
                        radius_scale,
                        alpha_cutoff,
                        max_splat_steps,
                        transmittance_threshold,
                        max_list_entries,
                        loss_grad_clip,
                        sh_band,
                        max_anisotropy,
                        base,
                    );
                }),
6060
                // Preparation stage of the split Gaussian-splat pipeline. All
                // fields, followed by the arena base pointer, are forwarded
                // positionally to `crate::splat::execute_gaussian_splat_prepare`;
                // keep the order in sync with that function's parameter list.
                // NOTE(review): `prep_off`/`prep_len` presumably name the buffer
                // that `GaussianSplatRasterize` later reads — confirm.
                Thunk::GaussianSplatPrepare {
                    positions_off,
                    positions_len,
                    scales_off,
                    scales_len,
                    rotations_off,
                    rotations_len,
                    opacities_off,
                    opacities_len,
                    colors_off,
                    colors_len,
                    sh_coeffs_off,
                    sh_coeffs_len,
                    meta_off,
                    meta_len,
                    prep_off,
                    prep_len,
                    width,
                    height,
                    tile_size,
                    radius_scale,
                    alpha_cutoff,
                    max_splat_steps,
                    transmittance_threshold,
                    max_list_entries,
                } => Arc::new(move |base: *mut u8| unsafe {
                    crate::splat::execute_gaussian_splat_prepare(
                        positions_off,
                        positions_len,
                        scales_off,
                        scales_len,
                        rotations_off,
                        rotations_len,
                        opacities_off,
                        opacities_len,
                        colors_off,
                        colors_len,
                        sh_coeffs_off,
                        sh_coeffs_len,
                        meta_off,
                        meta_len,
                        prep_off,
                        prep_len,
                        width,
                        height,
                        tile_size,
                        radius_scale,
                        alpha_cutoff,
                        max_splat_steps,
                        transmittance_threshold,
                        max_list_entries,
                        base,
                    );
                }),
6115
                // Rasterization stage of the split Gaussian-splat pipeline. It
                // takes the `prep_*` and `meta_*` buffers, the `dst_*` buffer,
                // `count` and the render settings, and forwards them (followed by
                // the arena base pointer) positionally to
                // `crate::splat::execute_gaussian_splat_rasterize`. The argument
                // order must match that function's parameter list.
                Thunk::GaussianSplatRasterize {
                    prep_off,
                    prep_len,
                    meta_off,
                    meta_len,
                    dst_off,
                    dst_len,
                    count,
                    width,
                    height,
                    tile_size,
                    alpha_cutoff,
                    max_splat_steps,
                    transmittance_threshold,
                    max_list_entries,
                } => Arc::new(move |base: *mut u8| unsafe {
                    crate::splat::execute_gaussian_splat_rasterize(
                        prep_off,
                        prep_len,
                        meta_off,
                        meta_len,
                        dst_off,
                        dst_len,
                        count,
                        width,
                        height,
                        tile_size,
                        alpha_cutoff,
                        max_splat_steps,
                        transmittance_threshold,
                        max_list_entries,
                        base,
                    );
                }),
6150
6151                Thunk::Fft1d {
6152                    src,
6153                    dst,
6154                    outer,
6155                    n_complex,
6156                    inverse,
6157                    dtype,
6158                } => {
6159                    let f: Arc<dyn Fn(*mut u8) + Send + Sync> = match dtype {
6160                        rlx_ir::DType::F64 => Arc::new(move |base: *mut u8| unsafe {
6161                            execute_fft1d_f64(
6162                                src,
6163                                dst,
6164                                outer as usize,
6165                                n_complex as usize,
6166                                inverse,
6167                                base,
6168                            );
6169                        }),
6170                        rlx_ir::DType::F32 => Arc::new(move |base: *mut u8| unsafe {
6171                            execute_fft1d_f32(
6172                                src,
6173                                dst,
6174                                outer as usize,
6175                                n_complex as usize,
6176                                inverse,
6177                                base,
6178                            );
6179                        }),
6180                        other => panic!("Op::Fft on CPU requires F32/F64, got {other:?}"),
6181                    };
6182                    f
6183                }
6184
6185                _ => Arc::new(|_: *mut u8| {}),
6186            }
6187        })
6188        .collect();
6189
6190    // ── Thunk-level attention fusion ──────────────────────
6191    // For small batch*seq, fuse QKV→Narrow×3→[Rope×2]→Attention→OutProj
6192    // into a single FusedAttnBlock. Auto-detects from Attention thunks.
6193    let fuse_threshold: usize = rlx_ir::env::var("RLX_FUSE_ATTN_THRESHOLD")
6194        .and_then(|v| v.parse().ok())
6195        .unwrap_or(64);
6196    let should_fuse = thunks.iter().any(|t| match t {
6197        Thunk::Attention { batch, seq, .. } => {
6198            (*batch as usize) * (*seq as usize) <= fuse_threshold
6199        }
6200        _ => false,
6201    });
6202
6203    if should_fuse {
6204        // Build non-Nop index for pattern matching across Nop gaps
6205        let active: Vec<usize> = thunks
6206            .iter()
6207            .enumerate()
6208            .filter(|(_, t)| !matches!(t, Thunk::Nop))
6209            .map(|(i, _)| i)
6210            .collect();
6211
6212        let mut kill = vec![false; thunks.len()]; // mark thunks to remove
6213        let mut insertions: Vec<(usize, Thunk)> = Vec::new(); // (position, replacement)
6214
6215        let mut ai = 0;
6216        while ai < active.len() {
6217            // Helper: get active thunk at offset from current
6218            let a = |off: usize| -> Option<(usize, &Thunk)> {
6219                active.get(ai + off).map(|&idx| (idx, &thunks[idx]))
6220            };
6221
6222            // Try BERT pattern: FusedMmBiasAct(QKV) → Narrow×3 → Attention → FusedMmBiasAct(out)
6223            let matched = (|| {
6224                let (_i0, t0) = a(0)?;
6225                let (_, t1) = a(1)?;
6226                let (_, t2) = a(2)?;
6227                let (_, t3) = a(3)?;
6228
6229                // a[0] must be FusedMmBiasAct or Sgemm (QKV projection)
6230                let (hidden, qkv_w, qkv_b, has_b) = match t0 {
6231                    Thunk::FusedMmBiasAct {
6232                        a,
6233                        w,
6234                        bias,
6235                        n: _,
6236                        act: None,
6237                        ..
6238                    } => (*a, *w, *bias, true),
6239                    Thunk::Sgemm { a, b, n: _, .. } => (*a, *b, 0, false),
6240                    _ => return None,
6241                };
6242
6243                // a[1..3] must be Narrows
6244                if !matches!(t1, Thunk::Narrow { .. }) {
6245                    return None;
6246                }
6247                if !matches!(t2, Thunk::Narrow { .. }) {
6248                    return None;
6249                }
6250                if !matches!(t3, Thunk::Narrow { .. }) {
6251                    return None;
6252                }
6253
6254                // Look for optional Rope×2 then Attention
6255                let (has_rope, attn_ai, cos_off, sin_off, cl) = if let Some((
6256                    _,
6257                    Thunk::Rope {
6258                        cos, sin, cos_len, ..
6259                    },
6260                )) = a(4)
6261                {
6262                    if matches!(a(5).map(|x| x.1), Some(Thunk::Rope { .. })) {
6263                        if matches!(a(6).map(|x| x.1), Some(Thunk::Attention { .. })) {
6264                            (true, 6, *cos, *sin, *cos_len)
6265                        } else {
6266                            return None;
6267                        }
6268                    } else {
6269                        return None;
6270                    }
6271                } else if matches!(a(4).map(|x| x.1), Some(Thunk::Attention { .. })) {
6272                    (false, 4, 0, 0, 0)
6273                } else {
6274                    return None;
6275                };
6276
6277                let (_attn_real_idx, attn_t) = a(attn_ai)?;
6278                let (batch, seq, heads, head_dim, mask) = match attn_t {
6279                    Thunk::Attention {
6280                        batch,
6281                        seq,
6282                        heads,
6283                        head_dim,
6284                        mask,
6285                        ..
6286                    } => (*batch, *seq, *heads, *head_dim, *mask),
6287                    _ => return None,
6288                };
6289
6290                // Next active must be out projection (FusedMmBiasAct or Sgemm)
6291                let (_out_real_idx, out_t) = a(attn_ai + 1)?;
6292                let (out_w, out_b, out_dst) = match out_t {
6293                    Thunk::FusedMmBiasAct {
6294                        w,
6295                        bias,
6296                        c,
6297                        act: None,
6298                        ..
6299                    } => (*w, *bias, *c),
6300                    Thunk::Sgemm { b: w, c, .. } => (*w, 0, *c),
6301                    _ => return None,
6302                };
6303
6304                let hs = heads * head_dim;
6305                let total_active = attn_ai + 2; // number of active thunks consumed
6306
6307                Some((
6308                    total_active,
6309                    Thunk::FusedAttnBlock {
6310                        hidden,
6311                        qkv_w,
6312                        out_w,
6313                        mask,
6314                        out: out_dst,
6315                        qkv_b: if has_b { qkv_b } else { 0 },
6316                        out_b: if has_b { out_b } else { 0 },
6317                        cos: cos_off,
6318                        sin: sin_off,
6319                        cos_len: cl,
6320                        batch,
6321                        seq,
6322                        hs,
6323                        nh: heads,
6324                        dh: head_dim,
6325                        has_bias: has_b,
6326                        has_rope,
6327                    },
6328                ))
6329            })();
6330
6331            if let Some((count, fused_thunk)) = matched {
6332                // Mark consumed thunks for removal
6333                for off in 0..count {
6334                    if let Some(&idx) = active.get(ai + off) {
6335                        kill[idx] = true;
6336                    }
6337                }
6338                // Insert replacement at position of the QKV thunk
6339                insertions.push((active[ai], fused_thunk));
6340                ai += count;
6341            } else {
6342                ai += 1;
6343            }
6344        }
6345
6346        // Rebuild thunk list: keep non-killed, insert fused at right positions
6347        if !insertions.is_empty() {
6348            let mut new_thunks = Vec::with_capacity(thunks.len());
6349            let mut insert_idx = 0;
6350            for (i, t) in thunks.into_iter().enumerate() {
6351                if insert_idx < insertions.len() && insertions[insert_idx].0 == i {
6352                    new_thunks.push(insertions[insert_idx].1.clone());
6353                    insert_idx += 1;
6354                }
6355                if !kill[i] {
6356                    new_thunks.push(t);
6357                }
6358            }
6359            if cfg.verbose >= 1 {
6360                eprintln!(
6361                    "[rlx] fused_attention: {} attention blocks fused",
6362                    insertions.len()
6363                );
6364            }
6365            thunks = new_thunks;
6366        }
6367    }
6368
6369    // ── Full layer fusion ──────────────────────────────────
6370    // After attention blocks are fused, scan for full layer patterns:
    // BERT:  FusedAttnBlock → FusedResidualLN → FusedMmBiasAct(gelu) → FusedMmBiasAct(none) → FusedResidualLN
6372    // Nomic: FusedAttnBlock → BinaryFull(add) → LayerNorm → Sgemm → [Narrow×2 → Silu → BinaryFull(mul)] → Sgemm → BinaryFull(add) → LayerNorm
6373    if should_fuse {
6374        let active: Vec<usize> = thunks
6375            .iter()
6376            .enumerate()
6377            .filter(|(_, t)| !matches!(t, Thunk::Nop))
6378            .map(|(i, _)| i)
6379            .collect();
6380
6381        let mut kill = vec![false; thunks.len()];
6382        let mut insertions: Vec<(usize, Thunk)> = Vec::new();
6383
6384        let a = |ai: usize| -> Option<&Thunk> { active.get(ai).map(|&i| &thunks[i]) };
6385
6386        let mut ai = 0;
6387        while ai < active.len() {
6388            // BERT pattern: FusedAttnBlock → FusedResidualLN → FusedMmBiasAct(gelu) → FusedMmBiasAct(none) → FusedResidualLN
6389            let bert_match = (|| -> Option<usize> {
6390                let fab = a(ai)?;
6391                let rln1 = a(ai + 1)?;
6392                let ffn1 = a(ai + 2)?;
6393                let ffn2 = a(ai + 3)?;
6394                let rln2 = a(ai + 4)?;
6395
6396                let (hidden, qkv_w, qkv_b, out_w, out_b, mask, batch, seq, hs, nh, dh) = match fab {
6397                    Thunk::FusedAttnBlock {
6398                        hidden,
6399                        qkv_w,
6400                        qkv_b,
6401                        out_w,
6402                        out_b,
6403                        mask,
6404                        batch,
6405                        seq,
6406                        hs,
6407                        nh,
6408                        dh,
6409                        has_bias: true,
6410                        has_rope: false,
6411                        ..
6412                    } => (
6413                        *hidden, *qkv_w, *qkv_b, *out_w, *out_b, *mask, *batch, *seq, *hs, *nh, *dh,
6414                    ),
6415                    _ => return None,
6416                };
6417                let (ln1_g, ln1_b, eps1) = match rln1 {
6418                    Thunk::FusedResidualLN { g, b, eps, .. } => (*g, *b, *eps),
6419                    _ => return None,
6420                };
6421                let (fc1_w, fc1_b, int_dim) = match ffn1 {
6422                    Thunk::FusedMmBiasAct {
6423                        w,
6424                        bias,
6425                        n,
6426                        act: Some(Activation::Gelu),
6427                        ..
6428                    } => (*w, *bias, *n),
6429                    _ => return None,
6430                };
6431                let (fc2_w, fc2_b) = match ffn2 {
6432                    Thunk::FusedMmBiasAct {
6433                        w, bias, act: None, ..
6434                    } => (*w, *bias),
6435                    _ => return None,
6436                };
6437                let (ln2_g, ln2_b, eps2, out) = match rln2 {
6438                    Thunk::FusedResidualLN { g, b, eps, out, .. } => (*g, *b, *eps, *out),
6439                    _ => return None,
6440                };
6441
6442                for off in 0..5 {
6443                    kill[active[ai + off]] = true;
6444                }
6445                insertions.push((
6446                    active[ai],
6447                    Thunk::FusedBertLayer {
6448                        hidden,
6449                        qkv_w,
6450                        qkv_b,
6451                        out_w,
6452                        out_b,
6453                        mask,
6454                        ln1_g,
6455                        ln1_b,
6456                        eps1,
6457                        fc1_w,
6458                        fc1_b,
6459                        fc2_w,
6460                        fc2_b,
6461                        ln2_g,
6462                        ln2_b,
6463                        eps2,
6464                        out,
6465                        batch,
6466                        seq,
6467                        hs,
6468                        nh,
6469                        dh,
6470                        int_dim,
6471                    },
6472                ));
6473                Some(5)
6474            })();
6475            if let Some(n) = bert_match {
6476                ai += n;
6477                continue;
6478            }
6479
6480            // Nomic full layer fusion — disabled pending SwiGLU stride debugging.
6481            // Nomic still benefits from FusedAttnBlock (attention-level fusion).
6482            // The body below is kept as reference for when the stride bug is fixed.
6483            #[allow(unreachable_code)]
6484            let nomic_match = (|| -> Option<usize> {
6485                return None; // TODO: fix SwiGLU strided fc2 output mismatch
6486                let fab = a(ai)?;
6487                let (hidden, qkv_w, out_w, mask, cos, sin, cos_len, batch, seq, hs, nh, dh) =
6488                    match fab {
6489                        Thunk::FusedAttnBlock {
6490                            hidden,
6491                            qkv_w,
6492                            out_w,
6493                            mask,
6494                            cos,
6495                            sin,
6496                            cos_len,
6497                            batch,
6498                            seq,
6499                            hs,
6500                            nh,
6501                            dh,
6502                            has_bias: false,
6503                            has_rope: true,
6504                            ..
6505                        } => (
6506                            *hidden, *qkv_w, *out_w, *mask, *cos, *sin, *cos_len, *batch, *seq,
6507                            *hs, *nh, *dh,
6508                        ),
6509                        _ => return None,
6510                    };
6511                // FusedResidualLN for LN1
6512                let (ln1_g, ln1_b, eps1) = match a(ai + 1)? {
6513                    Thunk::FusedResidualLN { g, b, eps, .. } => (*g, *b, *eps),
6514                    _ => return None,
6515                };
6516                // Sgemm (fused fc11+fc12)
6517                let fused_fc_w = match a(ai + 2)? {
6518                    Thunk::Sgemm { b: w, .. } => *w,
6519                    _ => return None,
6520                };
6521                // Narrow×2 for split
6522                if !matches!(a(ai + 3)?, Thunk::Narrow { .. }) {
6523                    return None;
6524                }
6525                if !matches!(a(ai + 4)?, Thunk::Narrow { .. }) {
6526                    return None;
6527                }
6528                // SiLU
6529                if !matches!(
6530                    a(ai + 5)?,
6531                    Thunk::ActivationInPlace {
6532                        act: Activation::Silu,
6533                        ..
6534                    }
6535                ) {
6536                    return None;
6537                }
6538                // BinaryFull(Mul) for gate
6539                if !matches!(
6540                    a(ai + 6)?,
6541                    Thunk::BinaryFull {
6542                        op: BinaryOp::Mul,
6543                        ..
6544                    }
6545                ) {
6546                    return None;
6547                }
6548                // Sgemm (fc2)
6549                let fc2_w = match a(ai + 7)? {
6550                    Thunk::Sgemm { b: w, .. } => *w,
6551                    _ => return None,
6552                };
6553                // Get int_dim from the Narrow (inner = int_dim for last-axis narrow)
6554                let int_dim = match a(ai + 3)? {
6555                    Thunk::Narrow { inner, .. } => *inner,
6556                    _ => return None,
6557                };
6558                // FusedResidualLN for LN2
6559                let (ln2_g, ln2_b, eps2, out) = match a(ai + 8)? {
6560                    Thunk::FusedResidualLN { g, b, eps, out, .. } => (*g, *b, *eps, *out),
6561                    _ => return None,
6562                };
6563
6564                for off in 0..9 {
6565                    kill[active[ai + off]] = true;
6566                }
6567                insertions.push((
6568                    active[ai],
6569                    Thunk::FusedNomicLayer {
6570                        hidden,
6571                        qkv_w,
6572                        out_w,
6573                        mask,
6574                        cos,
6575                        sin,
6576                        cos_len,
6577                        ln1_g,
6578                        ln1_b,
6579                        eps1,
6580                        fc11_w: fused_fc_w,
6581                        fc12_w: 0,
6582                        fc2_w,
6583                        ln2_g,
6584                        ln2_b,
6585                        eps2,
6586                        out,
6587                        batch,
6588                        seq,
6589                        hs,
6590                        nh,
6591                        dh,
6592                        int_dim,
6593                    },
6594                ));
6595                Some(9)
6596            })();
6597            if let Some(n) = nomic_match {
6598                ai += n;
6599                continue;
6600            }
6601
6602            ai += 1;
6603        }
6604
6605        if !insertions.is_empty() {
6606            let mut new_thunks = Vec::with_capacity(thunks.len());
6607            let mut ins_idx = 0;
6608            for (i, t) in thunks.into_iter().enumerate() {
6609                if ins_idx < insertions.len() && insertions[ins_idx].0 == i {
6610                    new_thunks.push(insertions[ins_idx].1.clone());
6611                    ins_idx += 1;
6612                }
6613                if !kill[i] {
6614                    new_thunks.push(t);
6615                }
6616            }
6617            if cfg.verbose >= 1 {
6618                eprintln!(
6619                    "[rlx] fused_layer: {} full transformer layers fused",
6620                    insertions.len()
6621                );
6622            }
6623            thunks = new_thunks;
6624        }
6625    }
6626
6627    // ── Narrow → Rope thunk fusion (plan #45) ──────────────
6628    // Runs *after* FusedAttnBlock fusion so it only catches the medium-
6629    // batch path (batch*seq > 64) where the bigger fusion didn't fire.
6630    // Pattern: a Rope thunk whose `src` is the dst of an immediately-
6631    // preceding Narrow whose dst has no other consumer in this schedule.
6632    // Rewrite Rope to read directly from the parent buffer with the
6633    // parent's row stride; the Narrow becomes a Nop.
6634    //
6635    // Skipping the Narrow's write saves one full pass over Q/K (B*S*hs
6636    // f32) per Rope. For Nomic h=768 / batch=8 / seq=15 / 12 layers
6637    // that's 2 ropes/layer × 369 KB = ~8.9 MB of write traffic gone.
6638    {
6639        // Collect every byte-offset that's read as a thunk's `src` so
6640        // we know whether a Narrow's dst has consumers other than Rope.
6641        let mut read_offsets: HashMap<usize, usize> = HashMap::new();
6642        for t in &thunks {
6643            for off in thunk_read_offsets(t) {
6644                *read_offsets.entry(off).or_insert(0) += 1;
6645            }
6646        }
6647
6648        let mut fused_count = 0usize;
6649        for i in 0..thunks.len().saturating_sub(1) {
6650            // Look for Rope at i+1 reading from Narrow at i (skip Nops
6651            // between them since the planner left them in place).
6652            let narrow = match &thunks[i] {
6653                Thunk::Narrow { .. } => i,
6654                _ => continue,
6655            };
6656            // Find the next non-Nop thunk
6657            let mut j = narrow + 1;
6658            while j < thunks.len() && matches!(thunks[j], Thunk::Nop) {
6659                j += 1;
6660            }
6661            if j >= thunks.len() {
6662                continue;
6663            }
6664            // Must be Rope reading Narrow's dst
6665            let (n_src, n_dst, n_src_stride) = match &thunks[narrow] {
6666                Thunk::Narrow {
6667                    src,
6668                    dst,
6669                    src_stride,
6670                    ..
6671                } => (*src, *dst, *src_stride),
6672                _ => continue,
6673            };
6674            let rope_reads_narrow = matches!(&thunks[j],
6675                Thunk::Rope { src, .. } if *src == n_dst);
6676            if !rope_reads_narrow {
6677                continue;
6678            }
6679            // Conservatively require that the Narrow's dst has exactly
6680            // one reader (the Rope). Anything else and rewriting would
6681            // skip a needed write.
6682            if read_offsets.get(&n_dst).copied().unwrap_or(0) != 1 {
6683                continue;
6684            }
6685
6686            // Rewire: Rope reads from Narrow's adjusted source with the
6687            // parent buffer's row stride.
6688            if let Thunk::Rope {
6689                src,
6690                src_row_stride,
6691                ..
6692            } = &mut thunks[j]
6693            {
6694                *src = n_src;
6695                *src_row_stride = n_src_stride;
6696            }
6697            thunks[narrow] = Thunk::Nop;
6698            fused_count += 1;
6699        }
6700
6701        if fused_count > 0 && cfg.verbose >= 1 {
6702            eprintln!(
6703                "[rlx] fused_qk_rope: {} Narrow→Rope pairs collapsed",
6704                fused_count
6705            );
6706        }
6707    }
6708
6709    // ── Narrow×3 → Attention thunk fusion (plan #46 deep) ────
6710    // For each Attention thunk in the schedule, look up the producers
6711    // of its q/k/v inputs. If each is a Narrow whose dst has exactly
6712    // one consumer (the Attention), rewire Attention to read directly
6713    // from the parent buffer with the parent's row stride. The three
6714    // Narrows become Nops.
6715    //
6716    // This catches the BERT/Nomic QKV split path that FusedAttnBlock
6717    // misses (batch*seq > 64) — eliminates Q/K/V copies entirely.
6718    // For minilm6 batch=32 seq=16 hs=384: 3 × 32*16*384*4 = 2.3 MB
6719    // per layer × 6 layers = ~14 MB of write traffic gone.
6720    {
6721        let mut read_counts: HashMap<usize, usize> = HashMap::new();
6722        for t in &thunks {
6723            for off in thunk_read_offsets(t) {
6724                *read_counts.entry(off).or_insert(0) += 1;
6725            }
6726        }
6727        // Build dst→index map for fast producer lookup.
6728        let mut dst_to_idx: HashMap<usize, usize> = HashMap::new();
6729        for (i, t) in thunks.iter().enumerate() {
6730            if let Thunk::Narrow { dst, .. } = t {
6731                dst_to_idx.insert(*dst, i);
6732            }
6733        }
6734
6735        let mut fused_count = 0usize;
6736        for i in 0..thunks.len() {
6737            let (q_off, k_off, v_off) = match &thunks[i] {
6738                Thunk::Attention { q, k, v, .. } => (*q, *k, *v),
6739                _ => continue,
6740            };
6741            // All three inputs must come from Narrows.
6742            let q_n = match dst_to_idx.get(&q_off).copied() {
6743                Some(x) => x,
6744                None => continue,
6745            };
6746            let k_n = match dst_to_idx.get(&k_off).copied() {
6747                Some(x) => x,
6748                None => continue,
6749            };
6750            let v_n = match dst_to_idx.get(&v_off).copied() {
6751                Some(x) => x,
6752                None => continue,
6753            };
6754            // Each Narrow's dst must have exactly one reader (this Attn).
6755            if read_counts.get(&q_off).copied().unwrap_or(0) != 1 {
6756                continue;
6757            }
6758            if read_counts.get(&k_off).copied().unwrap_or(0) != 1 {
6759                continue;
6760            }
6761            if read_counts.get(&v_off).copied().unwrap_or(0) != 1 {
6762                continue;
6763            }
6764
6765            let (q_src, q_stride) = match &thunks[q_n] {
6766                Thunk::Narrow {
6767                    src, src_stride, ..
6768                } => (*src, *src_stride),
6769                _ => continue,
6770            };
6771            let (k_src, k_stride) = match &thunks[k_n] {
6772                Thunk::Narrow {
6773                    src, src_stride, ..
6774                } => (*src, *src_stride),
6775                _ => continue,
6776            };
6777            let (v_src, v_stride) = match &thunks[v_n] {
6778                Thunk::Narrow {
6779                    src, src_stride, ..
6780                } => (*src, *src_stride),
6781                _ => continue,
6782            };
6783
6784            if let Thunk::Attention {
6785                q,
6786                k,
6787                v,
6788                q_row_stride,
6789                k_row_stride,
6790                v_row_stride,
6791                ..
6792            } = &mut thunks[i]
6793            {
6794                *q = q_src;
6795                *k = k_src;
6796                *v = v_src;
6797                *q_row_stride = q_stride;
6798                *k_row_stride = k_stride;
6799                *v_row_stride = v_stride;
6800            }
6801            thunks[q_n] = Thunk::Nop;
6802            thunks[k_n] = Thunk::Nop;
6803            thunks[v_n] = Thunk::Nop;
6804            fused_count += 1;
6805        }
6806
6807        if fused_count > 0 && cfg.verbose >= 1 {
6808            eprintln!(
6809                "[rlx] fused_strided_attn: {} Narrow×3→Attention rewrites",
6810                fused_count
6811            );
6812        }
6813    }
6814
6815    ThunkSchedule {
6816        thunks,
6817        moe_resident: None,
6818        moe_resident_layers: None,
6819        moe_topk_capture: None,
6820        mask_threshold: cfg.mask_binary_threshold,
6821        mask_neg_inf: cfg.attn_mask_neg_inf,
6822        score_skip: cfg.score_skip_threshold,
6823        compiled_fns,
6824    }
6825}
6826
6827fn get_len(graph: &Graph, id: NodeId) -> usize {
6828    graph.node(id).shape.num_elements().unwrap_or(0)
6829}
6830
6831/// Static `usize` dims of a node's shape, or empty if any dim is dynamic.
6832fn get_static_dims(graph: &Graph, id: NodeId) -> Vec<usize> {
6833    let dims = graph.node(id).shape.dims();
6834    let mut out = Vec::with_capacity(dims.len());
6835    for d in dims {
6836        if let Some(s) = match d {
6837            rlx_ir::Dim::Static(s) => Some(*s),
6838            _ => None,
6839        } {
6840            out.push(s);
6841        } else {
6842            return Vec::new();
6843        }
6844    }
6845    out
6846}
6847
6848/// NumPy-style broadcast strides for one operand into the flat output
6849/// buffer. Returns a length-`out_dims.len()` `Vec<u32>` where entry
6850/// `d` is `0` if the input is size-1 (broadcast) at output dim `d`
6851/// (after left-padding with size-1 to match ranks), otherwise the
6852/// natural row-major stride into the *input* buffer.
6853///
6854/// Caller iterates output flat index `i` → output coords (row-major)
6855/// → input flat index = dot(coords, strides). The result is correct
6856/// for any broadcast pattern (scalar, last-axis, middle-axis,
6857/// bidirectional).
6858/// True when `rhs_dims` describes a *trailing* broadcast of `out_dims`
6859/// — i.e. every rhs dim either equals the corresponding output dim
6860/// (counting from the right) or rhs is shorter (left-padded with 1s).
6861/// Mid-shape singletons (e.g. rhs `[a, b, 1, d]` into out `[a, b, c, d]`
6862/// where `c > 1`) are NOT trailing broadcasts and require the
6863/// shape-aware `BinaryFull` slow path — `BiasAdd`'s linear bias-replicated
6864/// kernel silently miscomputes them.
6865fn is_trailing_bias_broadcast(rhs_dims: &[rlx_ir::Dim], out_dims: &[rlx_ir::Dim]) -> bool {
6866    if rhs_dims.len() > out_dims.len() {
6867        return false;
6868    }
6869    let off = out_dims.len() - rhs_dims.len();
6870    for i in 0..rhs_dims.len() {
6871        let r = match rhs_dims[i] {
6872            rlx_ir::Dim::Static(n) => n,
6873            _ => return false,
6874        };
6875        let o = match out_dims[off + i] {
6876            rlx_ir::Dim::Static(n) => n,
6877            _ => return false,
6878        };
6879        if r != o {
6880            return false;
6881        }
6882    }
6883    true
6884}
6885
6886fn broadcast_strides(in_dims: &[usize], out_dims: &[usize]) -> Vec<u32> {
6887    let r_out = out_dims.len();
6888    let r_in = in_dims.len();
6889    assert!(
6890        r_in <= r_out,
6891        "broadcast: input rank {r_in} > output rank {r_out}"
6892    );
6893    let pad = r_out - r_in;
6894    let mut strides = vec![0u32; r_out];
6895    let mut acc: usize = 1;
6896    for d in (0..r_out).rev() {
6897        let in_size = if d < pad { 1 } else { in_dims[d - pad] };
6898        if in_size == 1 {
6899            strides[d] = 0;
6900        } else {
6901            assert_eq!(
6902                in_size, out_dims[d],
6903                "broadcast: input dim {in_size} doesn't match output dim {} at axis {d}",
6904                out_dims[d]
6905            );
6906            strides[d] = acc as u32;
6907            acc *= in_size;
6908        }
6909    }
6910    strides
6911}
6912
6913/// Execute a thunk schedule on a raw arena buffer.
6914/// Fastest executor: call pre-compiled closures sequentially.
6915/// Zero match dispatch — each closure is a direct kernel call.
6916pub fn execute_compiled(schedule: &ThunkSchedule, arena_buf: &mut [u8]) {
6917    let base = arena_buf.as_mut_ptr();
6918    for f in &schedule.compiled_fns {
6919        f(base);
6920    }
6921}
6922
6923/// Active-extent execution stub. The runtime calls this when it has an
6924/// active-extent hint set. CPU doesn't implement per-thunk active-extent
6925/// scaling yet — return false so the caller falls back to the full
6926/// `execute_thunks` path.
6927pub fn execute_thunks_active(
6928    schedule: &ThunkSchedule,
6929    _arena_buf: &mut [u8],
6930    _actual: usize,
6931    _upper: usize,
6932) -> bool {
6933    let _ = schedule;
6934    false
6935}
6936
6937/// Match-based executor (fallback, used by tests).
6938struct MoeResidencyGuard;
6939impl Drop for MoeResidencyGuard {
6940    fn drop(&mut self) {
6941        if let Some(stats) = crate::moe_residency::take_stats() {
6942            crate::moe_residency::stash_last_forward_stats(stats);
6943        } else {
6944            crate::moe_residency::clear_mask();
6945        }
6946    }
6947}
6948
6949pub fn execute_thunks(schedule: &ThunkSchedule, arena_buf: &mut [u8]) {
6950    crate::moe_residency::reset_gmm_counters();
6951    if let Some(layers) = schedule.moe_resident_layers.clone() {
6952        crate::moe_residency::set_per_layer_masks(Some(layers));
6953    } else {
6954        crate::moe_residency::set_mask(schedule.moe_resident.clone());
6955    }
6956    if let Some(cap) = schedule.moe_topk_capture.as_ref() {
6957        cap.clear();
6958    }
6959    let _moe_guard = MoeResidencyGuard;
6960    let base = arena_buf.as_mut_ptr();
6961    let mask_thr = schedule.mask_threshold;
6962    let mask_neg = schedule.mask_neg_inf;
6963    let score_thr = schedule.score_skip;
6964    let thunks = &schedule.thunks;
6965    let len = thunks.len();
6966
6967    // Pre-allocate ALL reusable buffers once (zero per-call allocation)
6968    let max_h = thunks
6969        .iter()
6970        .filter_map(|t| match t {
6971            Thunk::FusedResidualLN { h, .. }
6972            | Thunk::FusedResidualRmsNorm { h, .. }
6973            | Thunk::LayerNorm { h, .. } => Some(*h as usize),
6974            _ => None,
6975        })
6976        .max()
6977        .unwrap_or(0);
6978    let zero_bias = vec![0f32; max_h];
6979
6980    // Pre-allocate per-(batch,head) score buffers for parallel SDPA.
6981    // Q/K/V/out are accessed via strided BLAS — no deinterleave copy needed.
6982    let max_sdpa = thunks
6983        .iter()
6984        .filter_map(|t| match t {
6985            Thunk::Attention {
6986                batch,
6987                seq,
6988                kv_seq,
6989                heads,
6990                head_dim,
6991                ..
6992            } => Some((
6993                *batch as usize,
6994                (*seq as usize).max(*kv_seq as usize),
6995                *heads as usize,
6996                *head_dim as usize,
6997            )),
6998            _ => None,
6999        })
7000        .fold((0, 0, 0, 0), |(mb, ms, mh, md), (b, s, h, d)| {
7001            (mb.max(b), ms.max(s), mh.max(h), md.max(d))
7002        });
7003    let (max_batch, max_seq, max_heads, _max_dh) = max_sdpa;
7004    let max_units = max_batch * max_heads;
7005    let mut sdpa_scores = vec![0f32; max_units * max_seq * max_seq];
7006
7007    // Pre-allocate fused layer buffers (reused across all 12+ layers — zero malloc per layer)
7008    let fl = thunks
7009        .iter()
7010        .filter_map(|t| match t {
7011            Thunk::FusedBertLayer {
7012                batch,
7013                seq,
7014                hs,
7015                int_dim,
7016                ..
7017            } => {
7018                let m = (*batch as usize) * (*seq as usize);
7019                let h = *hs as usize;
7020                let id = *int_dim as usize;
7021                Some((m, h, id, m * (*seq as usize)))
7022            }
7023            Thunk::FusedNomicLayer {
7024                batch,
7025                seq,
7026                hs,
7027                int_dim,
7028                ..
7029            } => {
7030                let m = (*batch as usize) * (*seq as usize);
7031                let h = *hs as usize;
7032                let id = *int_dim as usize;
7033                Some((m, h, id, m * (*seq as usize)))
7034            }
7035            _ => None,
7036        })
7037        .fold((0, 0, 0, 0), |(mm, mh, mi, ms), (m, h, id, ss)| {
7038            (mm.max(m), mh.max(h), mi.max(id), ms.max(ss))
7039        });
7040    let (fl_m, fl_h, fl_int, fl_ss) = fl;
7041    let mut fl_qkv = vec![0f32; fl_m * 3 * fl_h];
7042    let mut fl_attn = vec![0f32; fl_m * fl_h];
7043    let mut fl_res = vec![0f32; fl_m * fl_h];
7044    let mut fl_normed = vec![0f32; fl_m * fl_h];
7045    let mut fl_ffn = vec![0f32; fl_m * fl_int.max(2 * fl_int)]; // Nomic needs 2×int for fused fc11+fc12
7046    let mut fl_sc = vec![0f32; fl_ss.max(1)];
7047
7048    for i in 0..len {
7049        let thunk = unsafe { thunks.get_unchecked(i) };
7050        match thunk {
7051            Thunk::Nop => {}
7052
7053            Thunk::GaussianSplatRender {
7054                positions_off,
7055                positions_len,
7056                scales_off,
7057                scales_len,
7058                rotations_off,
7059                rotations_len,
7060                opacities_off,
7061                opacities_len,
7062                colors_off,
7063                colors_len,
7064                sh_coeffs_off,
7065                sh_coeffs_len,
7066                meta_off,
7067                dst_off,
7068                dst_len,
7069                width,
7070                height,
7071                tile_size,
7072                radius_scale,
7073                alpha_cutoff,
7074                max_splat_steps,
7075                transmittance_threshold,
7076                max_list_entries,
7077            } => unsafe {
7078                crate::splat::execute_gaussian_splat_render(
7079                    *positions_off,
7080                    *positions_len,
7081                    *scales_off,
7082                    *scales_len,
7083                    *rotations_off,
7084                    *rotations_len,
7085                    *opacities_off,
7086                    *opacities_len,
7087                    *colors_off,
7088                    *colors_len,
7089                    *sh_coeffs_off,
7090                    *sh_coeffs_len,
7091                    *meta_off,
7092                    *dst_off,
7093                    *dst_len,
7094                    *width,
7095                    *height,
7096                    *tile_size,
7097                    *radius_scale,
7098                    *alpha_cutoff,
7099                    *max_splat_steps,
7100                    *transmittance_threshold,
7101                    *max_list_entries,
7102                    base,
7103                );
7104            },
7105
7106            Thunk::GaussianSplatRenderBackward {
7107                positions_off,
7108                positions_len,
7109                scales_off,
7110                scales_len,
7111                rotations_off,
7112                rotations_len,
7113                opacities_off,
7114                opacities_len,
7115                colors_off,
7116                colors_len,
7117                sh_coeffs_off,
7118                sh_coeffs_len,
7119                meta_off,
7120                d_loss_off,
7121                d_loss_len,
7122                packed_off,
7123                packed_len,
7124                width,
7125                height,
7126                tile_size,
7127                radius_scale,
7128                alpha_cutoff,
7129                max_splat_steps,
7130                transmittance_threshold,
7131                max_list_entries,
7132                loss_grad_clip,
7133                sh_band,
7134                max_anisotropy,
7135            } => unsafe {
7136                crate::splat::execute_gaussian_splat_render_backward(
7137                    *positions_off,
7138                    *positions_len,
7139                    *scales_off,
7140                    *scales_len,
7141                    *rotations_off,
7142                    *rotations_len,
7143                    *opacities_off,
7144                    *opacities_len,
7145                    *colors_off,
7146                    *colors_len,
7147                    *sh_coeffs_off,
7148                    *sh_coeffs_len,
7149                    *meta_off,
7150                    *d_loss_off,
7151                    *d_loss_len,
7152                    *packed_off,
7153                    *packed_len,
7154                    *width,
7155                    *height,
7156                    *tile_size,
7157                    *radius_scale,
7158                    *alpha_cutoff,
7159                    *max_splat_steps,
7160                    *transmittance_threshold,
7161                    *max_list_entries,
7162                    *loss_grad_clip,
7163                    *sh_band,
7164                    *max_anisotropy,
7165                    base,
7166                );
7167            },
7168
7169            Thunk::GaussianSplatPrepare {
7170                positions_off,
7171                positions_len,
7172                scales_off,
7173                scales_len,
7174                rotations_off,
7175                rotations_len,
7176                opacities_off,
7177                opacities_len,
7178                colors_off,
7179                colors_len,
7180                sh_coeffs_off,
7181                sh_coeffs_len,
7182                meta_off,
7183                meta_len,
7184                prep_off,
7185                prep_len,
7186                width,
7187                height,
7188                tile_size,
7189                radius_scale,
7190                alpha_cutoff,
7191                max_splat_steps,
7192                transmittance_threshold,
7193                max_list_entries,
7194            } => unsafe {
7195                crate::splat::execute_gaussian_splat_prepare(
7196                    *positions_off,
7197                    *positions_len,
7198                    *scales_off,
7199                    *scales_len,
7200                    *rotations_off,
7201                    *rotations_len,
7202                    *opacities_off,
7203                    *opacities_len,
7204                    *colors_off,
7205                    *colors_len,
7206                    *sh_coeffs_off,
7207                    *sh_coeffs_len,
7208                    *meta_off,
7209                    *meta_len,
7210                    *prep_off,
7211                    *prep_len,
7212                    *width,
7213                    *height,
7214                    *tile_size,
7215                    *radius_scale,
7216                    *alpha_cutoff,
7217                    *max_splat_steps,
7218                    *transmittance_threshold,
7219                    *max_list_entries,
7220                    base,
7221                );
7222            },
7223
7224            Thunk::GaussianSplatRasterize {
7225                prep_off,
7226                prep_len,
7227                meta_off,
7228                meta_len,
7229                dst_off,
7230                dst_len,
7231                count,
7232                width,
7233                height,
7234                tile_size,
7235                alpha_cutoff,
7236                max_splat_steps,
7237                transmittance_threshold,
7238                max_list_entries,
7239            } => unsafe {
7240                crate::splat::execute_gaussian_splat_rasterize(
7241                    *prep_off,
7242                    *prep_len,
7243                    *meta_off,
7244                    *meta_len,
7245                    *dst_off,
7246                    *dst_len,
7247                    *count,
7248                    *width,
7249                    *height,
7250                    *tile_size,
7251                    *alpha_cutoff,
7252                    *max_splat_steps,
7253                    *transmittance_threshold,
7254                    *max_list_entries,
7255                    base,
7256                );
7257            },
7258
7259            Thunk::Fft1d {
7260                src,
7261                dst,
7262                outer,
7263                n_complex,
7264                inverse,
7265                dtype,
7266            } => unsafe {
7267                match dtype {
7268                    rlx_ir::DType::F64 => execute_fft1d_f64(
7269                        *src,
7270                        *dst,
7271                        *outer as usize,
7272                        *n_complex as usize,
7273                        *inverse,
7274                        base,
7275                    ),
7276                    rlx_ir::DType::F32 => execute_fft1d_f32(
7277                        *src,
7278                        *dst,
7279                        *outer as usize,
7280                        *n_complex as usize,
7281                        *inverse,
7282                        base,
7283                    ),
7284                    other => panic!("Op::Fft on CPU requires F32/F64, got {other:?}"),
7285                }
7286            },
7287
7288            // CustomFn dispatch (interpreted path). Mirrors the
7289            // pre-compiled-closure variant elsewhere in this file.
7290            // Patched by rlx-eda.
7291            Thunk::CustomFn {
7292                body,
7293                body_init,
7294                inputs,
7295                body_output_off,
7296                outer_output_off,
7297                out_bytes,
7298            } => {
7299                let mut body_buf: Vec<u8> = (**body_init).clone();
7300                unsafe {
7301                    for (body_in_off, outer_in_off, n_bytes) in inputs.iter() {
7302                        let src = (base as *const u8).add(*outer_in_off);
7303                        let dst = body_buf.as_mut_ptr().add(*body_in_off);
7304                        std::ptr::copy_nonoverlapping(src, dst, *n_bytes as usize);
7305                    }
7306                }
7307                execute_thunks(body, &mut body_buf);
7308                unsafe {
7309                    let src = body_buf.as_ptr().add(*body_output_off);
7310                    let dst = base.add(*outer_output_off);
7311                    std::ptr::copy_nonoverlapping(src, dst, *out_bytes as usize);
7312                }
7313            }
7314
7315            Thunk::Sgemm { a, b, c, m, k, n } => {
7316                let (m, k, n) = (*m as usize, *k as usize, *n as usize);
7317                unsafe {
7318                    crate::blas::sgemm_auto(
7319                        sl(*a, base, m * k),
7320                        sl(*b, base, k * n),
7321                        sl_mut(*c, base, m * n),
7322                        m,
7323                        k,
7324                        n,
7325                    );
7326                }
7327            }
7328
7329            Thunk::DenseSolveF64 { a, b, x, n, nrhs } => {
7330                let (n_, nrhs_) = (*n as usize, *nrhs as usize);
7331                // LAPACK overwrites both A and B; clone into scratch
7332                // each call. Caller's A and b must be preserved for
7333                // VJP recompute. (Eventually: swap to a factor-once /
7334                // solve-many scheme; that's the symbolic-reuse story
7335                // and lives with the sparse path.)
7336                unsafe {
7337                    let a_src = sl_f64(*a, base, n_ * n_);
7338                    let b_src = sl_f64(*b, base, n_ * nrhs_);
7339                    let mut a_scratch: Vec<f64> = a_src.to_vec();
7340                    let mut x_buf: Vec<f64> = b_src.to_vec();
7341                    let info = crate::blas::dgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
7342                    if info != 0 {
7343                        panic!(
7344                            "DenseSolveF64: dgesv reported singular matrix \
7345                                (info={info}, n={n_}, nrhs={nrhs_})"
7346                        );
7347                    }
7348                    let dst = sl_mut_f64(*x, base, n_ * nrhs_);
7349                    dst.copy_from_slice(&x_buf);
7350                }
7351            }
7352
7353            Thunk::DenseSolveF32 { a, b, x, n, nrhs } => {
7354                let (n_, nrhs_) = (*n as usize, *nrhs as usize);
7355                unsafe {
7356                    let a_src = sl(*a, base, n_ * n_);
7357                    let b_src = sl(*b, base, n_ * nrhs_);
7358                    let mut a_scratch: Vec<f32> = a_src.to_vec();
7359                    let mut x_buf: Vec<f32> = b_src.to_vec();
7360                    let info = crate::blas::sgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
7361                    if info != 0 {
7362                        panic!(
7363                            "DenseSolveF32: sgesv reported singular matrix \
7364                             (info={info}, n={n_}, nrhs={nrhs_})"
7365                        );
7366                    }
7367                    let dst = sl_mut(*x, base, n_ * nrhs_);
7368                    dst.copy_from_slice(&x_buf);
7369                }
7370            }
7371
7372            Thunk::BatchedDenseSolveF64 {
7373                a,
7374                b,
7375                x,
7376                batch,
7377                n,
7378                nrhs,
7379            } => {
7380                // Per slice: extract A_i and b_i, dgesv, write x_i.
7381                // LAPACK has no batched dgesv on Accelerate, so this
7382                // is a serial loop over the batch axis. cuSOLVER /
7383                // hipSOLVER expose `getrfBatched` / `getrsBatched` for
7384                // the GPU path — we'll wire that in rlx-cuda when
7385                // someone needs Linux+CUDA.
7386                let (b_, n_, nrhs_) = (*batch as usize, *n as usize, *nrhs as usize);
7387                let a_stride = n_ * n_;
7388                let b_stride = n_ * nrhs_;
7389                unsafe {
7390                    let a_full = sl_f64(*a, base, b_ * a_stride);
7391                    let b_full = sl_f64(*b, base, b_ * b_stride);
7392                    let x_full = sl_mut_f64(*x, base, b_ * b_stride);
7393                    for bi in 0..b_ {
7394                        let mut a_scratch: Vec<f64> =
7395                            a_full[bi * a_stride..(bi + 1) * a_stride].to_vec();
7396                        let mut x_buf: Vec<f64> =
7397                            b_full[bi * b_stride..(bi + 1) * b_stride].to_vec();
7398                        let info = crate::blas::dgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
7399                        if info != 0 {
7400                            panic!(
7401                                "BatchedDenseSolveF64: slice {bi} \
7402                                    singular (info={info}, n={n_}, nrhs={nrhs_})"
7403                            );
7404                        }
7405                        x_full[bi * b_stride..(bi + 1) * b_stride].copy_from_slice(&x_buf);
7406                    }
7407                }
7408            }
7409
7410            Thunk::BatchedDenseSolveF32 {
7411                a,
7412                b,
7413                x,
7414                batch,
7415                n,
7416                nrhs,
7417            } => {
7418                let (b_, n_, nrhs_) = (*batch as usize, *n as usize, *nrhs as usize);
7419                let a_stride = n_ * n_;
7420                let b_stride = n_ * nrhs_;
7421                unsafe {
7422                    let a_full = sl(*a, base, b_ * a_stride);
7423                    let b_full = sl(*b, base, b_ * b_stride);
7424                    let x_full = sl_mut(*x, base, b_ * b_stride);
7425                    for bi in 0..b_ {
7426                        let mut a_scratch = a_full[bi * a_stride..(bi + 1) * a_stride].to_vec();
7427                        let mut x_buf = b_full[bi * b_stride..(bi + 1) * b_stride].to_vec();
7428                        let info = crate::blas::sgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
7429                        if info != 0 {
7430                            panic!("BatchedDenseSolveF32: slice {bi} singular (info={info})");
7431                        }
7432                        x_full[bi * b_stride..(bi + 1) * b_stride].copy_from_slice(&x_buf);
7433                    }
7434                }
7435            }
7436
7437            Thunk::BatchedDgemmF64 {
7438                a,
7439                b,
7440                c,
7441                batch,
7442                m,
7443                k,
7444                n,
7445            } => {
7446                let (b_, m_, k_, n_) = (*batch as usize, *m as usize, *k as usize, *n as usize);
7447                let a_stride = m_ * k_;
7448                let b_stride = k_ * n_;
7449                let c_stride = m_ * n_;
7450                unsafe {
7451                    let a_full = sl_f64(*a, base, b_ * a_stride);
7452                    let b_full = sl_f64(*b, base, b_ * b_stride);
7453                    let c_full = sl_mut_f64(*c, base, b_ * c_stride);
7454                    for bi in 0..b_ {
7455                        let a_slice = &a_full[bi * a_stride..(bi + 1) * a_stride];
7456                        let b_slice = &b_full[bi * b_stride..(bi + 1) * b_stride];
7457                        let c_slice = &mut c_full[bi * c_stride..(bi + 1) * c_stride];
7458                        crate::blas::dgemm(a_slice, b_slice, c_slice, m_, k_, n_);
7459                    }
7460                }
7461            }
7462
7463            Thunk::BatchedSgemm {
7464                a,
7465                b,
7466                c,
7467                batch,
7468                m,
7469                k,
7470                n,
7471            } => {
7472                let (b_, m_, k_, n_) = (*batch as usize, *m as usize, *k as usize, *n as usize);
7473                let a_stride = m_ * k_;
7474                let b_stride = k_ * n_;
7475                let c_stride = m_ * n_;
7476                unsafe {
7477                    let a_full = sl(*a, base, b_ * a_stride);
7478                    let b_full = sl(*b, base, b_ * b_stride);
7479                    let c_full = sl_mut(*c, base, b_ * c_stride);
7480                    for bi in 0..b_ {
7481                        let a_slice = &a_full[bi * a_stride..(bi + 1) * a_stride];
7482                        let b_slice = &b_full[bi * b_stride..(bi + 1) * b_stride];
7483                        let c_slice = &mut c_full[bi * c_stride..(bi + 1) * c_stride];
7484                        crate::blas::sgemm_auto(a_slice, b_slice, c_slice, m_, k_, n_);
7485                    }
7486                }
7487            }
7488
7489            Thunk::Dgemm { a, b, c, m, k, n } => {
7490                let (m, k, n) = (*m as usize, *k as usize, *n as usize);
7491                unsafe {
7492                    crate::blas::dgemm(
7493                        sl_f64(*a, base, m * k),
7494                        sl_f64(*b, base, k * n),
7495                        sl_mut_f64(*c, base, m * n),
7496                        m,
7497                        k,
7498                        n,
7499                    );
7500                }
7501            }
7502
7503            Thunk::TransposeF64 {
7504                src,
7505                dst,
7506                in_total,
7507                out_dims,
7508                in_strides,
7509            } => unsafe {
7510                let inp = sl_f64(*src, base, *in_total as usize);
7511                let out_total: usize = out_dims.iter().map(|d| *d as usize).product();
7512                let out = sl_mut_f64(*dst, base, out_total);
7513                transpose_walk_f64(inp, out, out_dims, in_strides);
7514            },
7515
7516            Thunk::ActivationF64 {
7517                src,
7518                dst,
7519                len,
7520                kind,
7521            } => {
7522                let len = *len as usize;
7523                unsafe {
7524                    let inp = sl_f64(*src, base, len);
7525                    let out = sl_mut_f64(*dst, base, len);
7526                    apply_activation_f64(inp, out, *kind);
7527                }
7528            }
7529
7530            Thunk::ReduceSumF64 {
7531                src,
7532                dst,
7533                outer,
7534                reduced,
7535                inner,
7536            } => {
7537                let (o, r, n) = (*outer as usize, *reduced as usize, *inner as usize);
7538                unsafe {
7539                    let inp = sl_f64(*src, base, o * r * n);
7540                    let out = sl_mut_f64(*dst, base, o * n);
7541                    reduce_sum_f64(inp, out, o, r, n);
7542                }
7543            }
7544
7545            Thunk::CopyF64 { src, dst, len } => {
7546                let len = *len as usize;
7547                if *src == *dst { /* aliased, no copy needed */
7548                } else {
7549                    unsafe {
7550                        let s = sl_f64(*src, base, len);
7551                        let d = sl_mut_f64(*dst, base, len);
7552                        d.copy_from_slice(s);
7553                    }
7554                }
7555            }
7556
7557            Thunk::BinaryFullF64 {
7558                lhs,
7559                rhs,
7560                dst,
7561                len,
7562                lhs_len,
7563                rhs_len,
7564                op,
7565                out_dims_bcast,
7566                bcast_lhs_strides,
7567                bcast_rhs_strides,
7568            } => {
7569                let len = *len as usize;
7570                let lhs_len = *lhs_len as usize;
7571                let rhs_len = *rhs_len as usize;
7572                unsafe {
7573                    let l = sl_f64(*lhs, base, lhs_len);
7574                    let r = sl_f64(*rhs, base, rhs_len);
7575                    let d = sl_mut_f64(*dst, base, len);
7576                    if lhs_len == len && rhs_len == len {
7577                        for i in 0..len {
7578                            d[i] = binary_op_f64(*op, l[i], r[i]);
7579                        }
7580                    } else if !out_dims_bcast.is_empty() {
7581                        // Shape-aware broadcast path: correct for
7582                        // arbitrary NumPy-style broadcasts including
7583                        // bidirectional `[N,1] op [1,S]`.
7584                        let rank = out_dims_bcast.len();
7585                        let mut coords = vec![0u32; rank];
7586                        for i in 0..len {
7587                            let mut rem = i;
7588                            for ax in (0..rank).rev() {
7589                                let sz = out_dims_bcast[ax] as usize;
7590                                coords[ax] = (rem % sz) as u32;
7591                                rem /= sz;
7592                            }
7593                            let mut li: usize = 0;
7594                            let mut ri: usize = 0;
7595                            for ax in 0..rank {
7596                                li += coords[ax] as usize * bcast_lhs_strides[ax] as usize;
7597                                ri += coords[ax] as usize * bcast_rhs_strides[ax] as usize;
7598                            }
7599                            d[i] = binary_op_f64(*op, l[li], r[ri]);
7600                        }
7601                    } else {
7602                        // Fallback: legacy modulo path (preserved for
7603                        // dynamic-shape graphs where strides can't be
7604                        // precomputed). Only correct for scalar /
7605                        // last-axis broadcast.
7606                        for i in 0..len {
7607                            d[i] = binary_op_f64(*op, l[i % lhs_len], r[i % rhs_len]);
7608                        }
7609                    }
7610                }
7611            }
7612
7613            Thunk::BinaryFullC64 {
7614                lhs,
7615                rhs,
7616                dst,
7617                len,
7618                lhs_len,
7619                rhs_len,
7620                op,
7621                out_dims_bcast,
7622                bcast_lhs_strides,
7623                bcast_rhs_strides,
7624            } => {
7625                // Complex element layout: [re_0, im_0, re_1, im_1, ...]
7626                // Underlying f32 buffer length is 2·N (N = complex
7627                // element count). All offsets are byte offsets; the
7628                // `sl` helper reads as f32 starting at the byte
7629                // offset, so f32-length = 2·complex-len.
7630                let n_out = *len as usize;
7631                let n_l = *lhs_len as usize;
7632                let n_r = *rhs_len as usize;
7633                unsafe {
7634                    let l = sl(*lhs, base, 2 * n_l);
7635                    let r = sl(*rhs, base, 2 * n_r);
7636                    let d = sl_mut(*dst, base, 2 * n_out);
7637                    let do_c64 = |a_re: f32, a_im: f32, b_re: f32, b_im: f32| -> (f32, f32) {
7638                        match op {
7639                            BinaryOp::Add => (a_re + b_re, a_im + b_im),
7640                            BinaryOp::Sub => (a_re - b_re, a_im - b_im),
7641                            BinaryOp::Mul => (a_re * b_re - a_im * b_im, a_re * b_im + a_im * b_re),
7642                            BinaryOp::Div => {
7643                                let denom = b_re * b_re + b_im * b_im;
7644                                (
7645                                    (a_re * b_re + a_im * b_im) / denom,
7646                                    (a_im * b_re - a_re * b_im) / denom,
7647                                )
7648                            }
7649                            BinaryOp::Max | BinaryOp::Min | BinaryOp::Pow => {
7650                                unreachable!("C64 max/min/pow rejected at lowering")
7651                            }
7652                        }
7653                    };
7654                    if n_l == n_out && n_r == n_out {
7655                        for i in 0..n_out {
7656                            let (re, im) = do_c64(l[2 * i], l[2 * i + 1], r[2 * i], r[2 * i + 1]);
7657                            d[2 * i] = re;
7658                            d[2 * i + 1] = im;
7659                        }
7660                    } else if !out_dims_bcast.is_empty() {
7661                        // Strided complex broadcast: strides are in
7662                        // *complex element* units; multiply by 2 when
7663                        // indexing into the f32 buffer.
7664                        let rank = out_dims_bcast.len();
7665                        let mut coords = vec![0u32; rank];
7666                        for i in 0..n_out {
7667                            let mut rem = i;
7668                            for ax in (0..rank).rev() {
7669                                let sz = out_dims_bcast[ax] as usize;
7670                                coords[ax] = (rem % sz) as u32;
7671                                rem /= sz;
7672                            }
7673                            let mut li: usize = 0;
7674                            let mut ri: usize = 0;
7675                            for ax in 0..rank {
7676                                li += coords[ax] as usize * bcast_lhs_strides[ax] as usize;
7677                                ri += coords[ax] as usize * bcast_rhs_strides[ax] as usize;
7678                            }
7679                            let (re, im) =
7680                                do_c64(l[2 * li], l[2 * li + 1], r[2 * ri], r[2 * ri + 1]);
7681                            d[2 * i] = re;
7682                            d[2 * i + 1] = im;
7683                        }
7684                    } else {
7685                        // Modulo fallback (scalar / last-axis broadcast).
7686                        for i in 0..n_out {
7687                            let li = if n_l == 1 { 0 } else { i % n_l };
7688                            let ri = if n_r == 1 { 0 } else { i % n_r };
7689                            let (re, im) =
7690                                do_c64(l[2 * li], l[2 * li + 1], r[2 * ri], r[2 * ri + 1]);
7691                            d[2 * i] = re;
7692                            d[2 * i + 1] = im;
7693                        }
7694                    }
7695                }
7696            }
7697
7698            Thunk::ComplexNormSqF32 { src, dst, len } => {
7699                let n = *len as usize;
7700                unsafe {
7701                    let s = sl(*src, base, 2 * n);
7702                    let d = sl_mut(*dst, base, n);
7703                    for i in 0..n {
7704                        let re = s[2 * i];
7705                        let im = s[2 * i + 1];
7706                        d[i] = re * re + im * im;
7707                    }
7708                }
7709            }
7710
7711            Thunk::ComplexNormSqBackwardF32 { z, g, dz, len } => {
7712                // Wirtinger: dz = g · z, element-wise complex
7713                // (g is real, z is complex).
7714                let n = *len as usize;
7715                unsafe {
7716                    let zb = sl(*z, base, 2 * n);
7717                    let gb = sl(*g, base, n);
7718                    let db = sl_mut(*dz, base, 2 * n);
7719                    for i in 0..n {
7720                        let re = zb[2 * i];
7721                        let im = zb[2 * i + 1];
7722                        let gv = gb[i];
7723                        db[2 * i] = gv * re;
7724                        db[2 * i + 1] = gv * im;
7725                    }
7726                }
7727            }
7728
7729            Thunk::ConjugateC64 { src, dst, len } => {
7730                let n = *len as usize;
7731                unsafe {
7732                    let s = sl(*src, base, 2 * n);
7733                    let d = sl_mut(*dst, base, 2 * n);
7734                    for i in 0..n {
7735                        d[2 * i] = s[2 * i];
7736                        d[2 * i + 1] = -s[2 * i + 1];
7737                    }
7738                }
7739            }
7740
7741            Thunk::ActivationC64 {
7742                src,
7743                dst,
7744                len,
7745                kind,
7746            } => {
7747                let n = *len as usize;
7748                unsafe {
7749                    let s = sl(*src, base, 2 * n);
7750                    let d = sl_mut(*dst, base, 2 * n);
7751                    for i in 0..n {
7752                        let a = s[2 * i];
7753                        let b = s[2 * i + 1];
7754                        let (re, im) = match kind {
7755                            Activation::Neg => (-a, -b),
7756                            Activation::Exp => {
7757                                // exp(a + bi) = e^a · (cos b + i·sin b)
7758                                let ea = a.exp();
7759                                (ea * b.cos(), ea * b.sin())
7760                            }
7761                            Activation::Log => {
7762                                // log(z) = log|z| + i·arg(z), principal branch
7763                                let r = (a * a + b * b).sqrt();
7764                                (r.ln(), b.atan2(a))
7765                            }
7766                            Activation::Sqrt => {
7767                                // sqrt(a+bi) = sqrt((|z|+a)/2) + sign(b)·i·sqrt((|z|-a)/2)
7768                                // Principal branch; for b == 0 and a < 0 returns +i·sqrt(|a|).
7769                                let r = (a * a + b * b).sqrt();
7770                                let re = ((r + a) * 0.5).max(0.0).sqrt();
7771                                let im_mag = ((r - a) * 0.5).max(0.0).sqrt();
7772                                let im = if b >= 0.0 { im_mag } else { -im_mag };
7773                                (re, im)
7774                            }
7775                            _ => unreachable!("non-C64 activation kind survived lowering"),
7776                        };
7777                        d[2 * i] = re;
7778                        d[2 * i + 1] = im;
7779                    }
7780                }
7781            }
7782
7783            Thunk::Scan {
7784                body,
7785                body_init,
7786                body_input_off,
7787                body_output_off,
7788                outer_init_off,
7789                outer_final_off,
7790                length,
7791                carry_bytes,
7792                save_trajectory,
7793                xs_inputs,
7794                bcast_inputs,
7795                num_checkpoints,
7796            } => {
7797                let cb = *carry_bytes as usize;
7798                let n_steps = *length as usize;
7799                // Checkpoint mode: when 0 < K < length, save trajectory[k]
7800                // only when t == c_k = floor((k+1) * length / K) - 1.
7801                // The last index c_{K-1} = length - 1 always.
7802                let k_total = if *num_checkpoints == 0 || *num_checkpoints == *length {
7803                    n_steps // save every step
7804                } else {
7805                    *num_checkpoints as usize
7806                };
7807                let checkpoint_t_for_k = |k: usize| -> usize {
7808                    if k_total == n_steps {
7809                        k
7810                    } else {
7811                        ((k + 1) * n_steps)
7812                            .div_ceil(k_total)
7813                            .saturating_sub(1)
7814                            .min(n_steps - 1)
7815                    }
7816                };
7817                let mut next_k = 0usize;
7818
7819                let mut body_buf: Vec<u8> = (**body_init).clone();
7820                unsafe {
7821                    std::ptr::copy_nonoverlapping(
7822                        base.add(*outer_init_off),
7823                        body_buf.as_mut_ptr().add(*body_input_off),
7824                        cb,
7825                    );
7826                    // Broadcast inputs: copy each one into the body's
7827                    // input slot ONCE. They aren't touched in the
7828                    // iteration loop below (in contrast to xs).
7829                    for (body_b_off, outer_b_off, total_bytes) in bcast_inputs.iter() {
7830                        std::ptr::copy_nonoverlapping(
7831                            base.add(*outer_b_off),
7832                            body_buf.as_mut_ptr().add(*body_b_off),
7833                            *total_bytes as usize,
7834                        );
7835                    }
7836                }
7837                for t in 0..n_steps {
7838                    for (body_x_off, outer_xs_off, per_step_bytes) in xs_inputs.iter() {
7839                        let psb = *per_step_bytes as usize;
7840                        unsafe {
7841                            std::ptr::copy_nonoverlapping(
7842                                base.add(*outer_xs_off + t * psb),
7843                                body_buf.as_mut_ptr().add(*body_x_off),
7844                                psb,
7845                            );
7846                        }
7847                    }
7848
7849                    execute_thunks(body, &mut body_buf);
7850
7851                    if *save_trajectory && next_k < k_total && t == checkpoint_t_for_k(next_k) {
7852                        unsafe {
7853                            std::ptr::copy_nonoverlapping(
7854                                body_buf.as_ptr().add(*body_output_off),
7855                                base.add(*outer_final_off + next_k * cb),
7856                                cb,
7857                            );
7858                        }
7859                        next_k += 1;
7860                    }
7861
7862                    if *body_output_off != *body_input_off {
7863                        body_buf
7864                            .copy_within(*body_output_off..*body_output_off + cb, *body_input_off);
7865                    }
7866                }
7867
7868                if !*save_trajectory {
7869                    // Single final-carry write.
7870                    unsafe {
7871                        std::ptr::copy_nonoverlapping(
7872                            body_buf.as_ptr().add(*body_output_off),
7873                            base.add(*outer_final_off),
7874                            cb,
7875                        );
7876                    }
7877                }
7878            }
7879
7880            Thunk::ScanBackward {
7881                body_vjp,
7882                body_init,
7883                body_carry_in_off,
7884                body_x_offs,
7885                body_d_output_off,
7886                body_dcarry_out_off,
7887                outer_init_off,
7888                outer_traj_off,
7889                outer_upstream_off,
7890                outer_xs_offs,
7891                outer_dinit_off,
7892                length,
7893                carry_bytes,
7894                save_trajectory,
7895                num_checkpoints,
7896                forward_body,
7897                forward_body_init,
7898                forward_body_carry_in_off,
7899                forward_body_output_off,
7900                forward_body_x_offs,
7901                carry_elem_size,
7902            } => {
7903                // Two backward paths share the same per-iteration body
7904                // (body_vjp run + dcarry threading). The "All" path
7905                // reads the carry directly from the saved trajectory
7906                // each step. The "Recursive checkpointing" path stores
7907                // only K saved checkpoints and reconstructs intermediate
7908                // carries via Griewank-style recursive subdivision —
7909                // see [`griewank_process_segment`]. Auxiliary memory
7910                // is `O(log(segment_size) · carry_bytes)` for the
7911                // recursion stack, vs the old segment-cache scheme's
7912                // `O(segment_size · carry_bytes)`. Total recompute work
7913                // grows from `O(length)` to `O(length · log)`, which
7914                // is the canonical Griewank trade.
7915                let cb = *carry_bytes as usize;
7916                let n_steps = *length as usize;
7917                let k_total = *num_checkpoints as usize;
7918                let is_recursive = k_total != 0 && k_total != n_steps;
7919                let checkpoint_t_for_k = |k: usize| -> usize {
7920                    ((k + 1) * n_steps)
7921                        .div_ceil(k_total)
7922                        .saturating_sub(1)
7923                        .min(n_steps - 1)
7924                };
7925
7926                let mut fwd_buf: Vec<u8> = if is_recursive {
7927                    (**forward_body_init.as_ref().unwrap()).clone()
7928                } else {
7929                    Vec::new()
7930                };
7931
7932                let mut dcarry: Vec<u8> = vec![0u8; cb];
7933                if !*save_trajectory {
7934                    unsafe {
7935                        std::ptr::copy_nonoverlapping(
7936                            base.add(*outer_upstream_off),
7937                            dcarry.as_mut_ptr(),
7938                            cb,
7939                        );
7940                    }
7941                }
7942
7943                let mut body_buf: Vec<u8> = (**body_init).clone();
7944
7945                // Per-iteration backward action — shared between the
7946                // direct-trajectory (All) and Griewank (Recursive) paths.
7947                // Both feed the same body_vjp run with carry-at-t,
7948                // x_t_i, and d_output, then thread dcarry backward.
7949                let process_iter =
7950                    |t: usize, carry_in: &[u8], dcarry: &mut Vec<u8>, body_buf: &mut Vec<u8>| {
7951                        if *save_trajectory {
7952                            unsafe {
7953                                let up_off = *outer_upstream_off + t * cb;
7954                                match *carry_elem_size {
7955                                    4 => {
7956                                        let up_ptr = base.add(up_off) as *const f32;
7957                                        let dc_ptr = dcarry.as_mut_ptr() as *mut f32;
7958                                        let n_elems = cb / 4;
7959                                        for i in 0..n_elems {
7960                                            *dc_ptr.add(i) += *up_ptr.add(i);
7961                                        }
7962                                    }
7963                                    8 => {
7964                                        let up_ptr = base.add(up_off) as *const f64;
7965                                        let dc_ptr = dcarry.as_mut_ptr() as *mut f64;
7966                                        let n_elems = cb / 8;
7967                                        for i in 0..n_elems {
7968                                            *dc_ptr.add(i) += *up_ptr.add(i);
7969                                        }
7970                                    }
7971                                    other => panic!(
7972                                        "ScanBackward: unsupported carry elem size {other} \
7973                                     (only f32/f64 carries are supported today)"
7974                                    ),
7975                                }
7976                            }
7977                        }
7978                        body_buf[*body_carry_in_off..*body_carry_in_off + cb]
7979                            .copy_from_slice(carry_in);
7980                        unsafe {
7981                            for (i, body_x_off) in body_x_offs.iter().enumerate() {
7982                                let (outer_xs_off, per_step_bytes) = outer_xs_offs[i];
7983                                let psb = per_step_bytes as usize;
7984                                std::ptr::copy_nonoverlapping(
7985                                    base.add(outer_xs_off + t * psb),
7986                                    body_buf.as_mut_ptr().add(*body_x_off),
7987                                    psb,
7988                                );
7989                            }
7990                            std::ptr::copy_nonoverlapping(
7991                                dcarry.as_ptr(),
7992                                body_buf.as_mut_ptr().add(*body_d_output_off),
7993                                cb,
7994                            );
7995                        }
7996                        execute_thunks(body_vjp, body_buf);
7997                        unsafe {
7998                            std::ptr::copy_nonoverlapping(
7999                                body_buf.as_ptr().add(*body_dcarry_out_off),
8000                                dcarry.as_mut_ptr(),
8001                                cb,
8002                            );
8003                        }
8004                    };
8005
8006                if is_recursive {
8007                    // Griewank treeverse path. Process saved-checkpoint
8008                    // segments from highest-t to lowest-t; within each,
8009                    // recursive binary subdivision via
8010                    // `griewank_process_segment`. Auxiliary memory:
8011                    // O(log(seg_size) · cb) for the recursion stack
8012                    // (vs O(seg_size · cb) for the older segment-cache
8013                    // scheme); recompute work: O(seg_size · log).
8014                    let leaf_threshold = 4usize;
8015                    let fb_sched = forward_body.as_ref().unwrap();
8016                    let fb_init = forward_body_init.as_ref().unwrap().as_slice();
8017                    let mut segment_end = n_steps - 1;
8018                    for seg_k in (0..k_total).rev() {
8019                        let segment_start = if seg_k == 0 {
8020                            0
8021                        } else {
8022                            checkpoint_t_for_k(seg_k - 1) + 1
8023                        };
8024                        let mut anchor: Vec<u8> = vec![0u8; cb];
8025                        unsafe {
8026                            let src = if seg_k == 0 {
8027                                base.add(*outer_init_off)
8028                            } else {
8029                                base.add(*outer_traj_off + (seg_k - 1) * cb)
8030                            };
8031                            std::ptr::copy_nonoverlapping(src, anchor.as_mut_ptr(), cb);
8032                        }
8033                        // Closure adapter for the helper's signature
8034                        // (mutably re-borrows dcarry / body_buf each call).
8035                        let mut leaf_action = |t: usize, carry_in: &[u8]| {
8036                            process_iter(t, carry_in, &mut dcarry, &mut body_buf);
8037                        };
8038                        unsafe {
8039                            griewank_process_segment(
8040                                segment_start,
8041                                segment_end,
8042                                &anchor,
8043                                cb,
8044                                fb_sched,
8045                                fb_init,
8046                                *forward_body_carry_in_off,
8047                                *forward_body_output_off,
8048                                forward_body_x_offs,
8049                                base,
8050                                outer_xs_offs,
8051                                &mut fwd_buf,
8052                                leaf_threshold,
8053                                &mut leaf_action,
8054                            );
8055                        }
8056                        if seg_k == 0 {
8057                            break;
8058                        }
8059                        segment_end = segment_start - 1;
8060                    }
8061                } else {
8062                    // All-trajectory path: read each carry directly
8063                    // from the saved trajectory buffer.
8064                    let mut carry_buf: Vec<u8> = vec![0u8; cb];
8065                    for t in (0..n_steps).rev() {
8066                        unsafe {
8067                            let src = if t == 0 {
8068                                base.add(*outer_init_off)
8069                            } else {
8070                                base.add(*outer_traj_off + (t - 1) * cb)
8071                            };
8072                            std::ptr::copy_nonoverlapping(src, carry_buf.as_mut_ptr(), cb);
8073                        }
8074                        process_iter(t, &carry_buf, &mut dcarry, &mut body_buf);
8075                    }
8076                }
8077
8078                unsafe {
8079                    std::ptr::copy_nonoverlapping(dcarry.as_ptr(), base.add(*outer_dinit_off), cb);
8080                }
8081            }
8082
            // Reverse pass of a scan whose gradient also flows into the per-step xs.
            // Walks t = length-1 down to 0. For each step it seeds `body_vjp` with the
            // carry entering step t, the step-t x slices, and the running carry gradient
            // `dcarry`. It then writes the resulting dxs row to
            // `outer_dxs_off + t * per_step_bytes` and feeds `body_dcarry_out` back as
            // `dcarry` for the next (earlier) step. Unlike the ScanBackward arm, this arm
            // writes no d(init) output.
            Thunk::ScanBackwardXs {
                body_vjp,
                body_init,
                body_carry_in_off,
                body_x_offs,
                body_d_output_off,
                body_dcarry_out_off,
                body_dxs_out_off,
                outer_init_off,
                outer_traj_off,
                outer_upstream_off,
                outer_xs_offs,
                outer_dxs_off,
                length,
                carry_bytes,
                carry_elem_size,
                per_step_bytes,
                save_trajectory,
                num_checkpoints,
                forward_body,
                forward_body_init,
                forward_body_carry_in_off,
                forward_body_output_off,
                forward_body_x_offs,
            } => {
                let cb = *carry_bytes as usize;
                let psb = *per_step_bytes as usize;
                let n_steps = *length as usize;
                let k_total = *num_checkpoints as usize;
                // `num_checkpoints` of 0 or `length` => every carry is available in the
                // trajectory buffer (all-trajectory mode). Any other value => only
                // `k_total` checkpoints exist and carries must be recomputed.
                let is_recursive = k_total != 0 && k_total != n_steps;
                // Last step index of the k-th of `k_total` near-equal segments.
                // Checkpoint slot k is read below as the carry entering step
                // `checkpoint_t_for_k(k) + 1`.
                let checkpoint_t_for_k = |k: usize| -> usize {
                    ((k + 1) * n_steps)
                        .div_ceil(k_total)
                        .saturating_sub(1)
                        .min(n_steps - 1)
                };

                // Forward-body recompute scratch + segment cache for Recursive mode.
                // NOTE: this is the older segment-cache scheme, not the Griewank
                // treeverse that the ScanBackward arm now uses. A whole segment of
                // carries is recomputed from its anchor and cached, which costs
                // O(seg_size · cb) auxiliary memory. Because t is visited in decreasing
                // order and segments are contiguous, each segment is recomputed at
                // most once, so total recompute work is O(length).
                let mut fwd_buf: Vec<u8> = if is_recursive {
                    (**forward_body_init.as_ref().unwrap()).clone()
                } else {
                    Vec::new()
                };
                let mut seg_cache: Vec<u8> = Vec::new();
                // `usize::MAX` marks "no segment cached yet".
                let mut seg_start_t: usize = usize::MAX;
                let mut seg_count: usize = 0;
                // Writes the carry entering step `t` into `dst`. The mutable scratch
                // state is threaded through as `&mut` parameters rather than captured.
                // - All-trajectory mode: copy from init (t == 0) or trajectory slot t-1.
                // - Recursive mode: serve from `seg_cache` when `t` is in the cached
                //   segment. Otherwise re-run `forward_body` from the segment anchor
                //   (init, or checkpoint slot seg_k-1) and refill the cache.
                let recompute_carry_t =
                    |t: usize,
                     dst: &mut [u8],
                     fwd_buf: &mut Vec<u8>,
                     seg_cache: &mut Vec<u8>,
                     seg_start_t: &mut usize,
                     seg_count: &mut usize| {
                        if !is_recursive {
                            unsafe {
                                let src = if t == 0 {
                                    base.add(*outer_init_off)
                                } else {
                                    base.add(*outer_traj_off + (t - 1) * cb)
                                };
                                std::ptr::copy_nonoverlapping(src, dst.as_mut_ptr(), cb);
                            }
                            return;
                        }
                        // Cache hit: t lies inside the most recently recomputed segment.
                        if *seg_start_t != usize::MAX
                            && t >= *seg_start_t
                            && t < *seg_start_t + *seg_count
                        {
                            let off = (t - *seg_start_t) * cb;
                            dst.copy_from_slice(&seg_cache[off..off + cb]);
                            return;
                        }
                        // Locate the segment containing t and its anchor carry.
                        let seg_k = (0..k_total)
                            .find(|&k| t <= checkpoint_t_for_k(k))
                            .unwrap_or(k_total - 1);
                        let (anchor_t, anchor_ptr): (usize, *const u8) = if seg_k == 0 {
                            (0, unsafe { base.add(*outer_init_off) as *const u8 })
                        } else {
                            let prev_ck = checkpoint_t_for_k(seg_k - 1);
                            (prev_ck + 1, unsafe {
                                base.add(*outer_traj_off + (seg_k - 1) * cb) as *const u8
                            })
                        };
                        let seg_end_t = checkpoint_t_for_k(seg_k);
                        let seg_size = seg_end_t - anchor_t + 1;

                        // Reset the forward body's arena and plant the anchor carry.
                        fwd_buf.copy_from_slice(forward_body_init.as_ref().unwrap());
                        unsafe {
                            std::ptr::copy_nonoverlapping(
                                anchor_ptr,
                                fwd_buf.as_mut_ptr().add(*forward_body_carry_in_off),
                                cb,
                            );
                        }
                        // seg_cache[i] holds the carry entering step `anchor_t + i`.
                        seg_cache.resize(seg_size * cb, 0u8);
                        seg_cache[0..cb].copy_from_slice(
                            &fwd_buf[*forward_body_carry_in_off..*forward_body_carry_in_off + cb],
                        );
                        let fb_sched = forward_body.as_ref().unwrap();
                        for i in 1..seg_size {
                            // Run forward step `cur_iter` to produce the carry for step
                            // `cur_iter + 1`.
                            let cur_iter = anchor_t + i - 1;
                            for (idx, fb_x_off) in forward_body_x_offs.iter().enumerate() {
                                let (outer_xs_off, x_psb) = outer_xs_offs[idx];
                                let xb = x_psb as usize;
                                unsafe {
                                    std::ptr::copy_nonoverlapping(
                                        base.add(outer_xs_off + cur_iter * xb),
                                        fwd_buf.as_mut_ptr().add(*fb_x_off),
                                        xb,
                                    );
                                }
                            }
                            execute_thunks(fb_sched, fwd_buf);
                            // Feed the step output back as the next carry input. This is
                            // skipped when the body writes the output in place.
                            if *forward_body_output_off != *forward_body_carry_in_off {
                                fwd_buf.copy_within(
                                    *forward_body_output_off..*forward_body_output_off + cb,
                                    *forward_body_carry_in_off,
                                );
                            }
                            let cache_off = i * cb;
                            seg_cache[cache_off..cache_off + cb].copy_from_slice(
                                &fwd_buf
                                    [*forward_body_carry_in_off..*forward_body_carry_in_off + cb],
                            );
                        }
                        *seg_start_t = anchor_t;
                        *seg_count = seg_size;

                        let off = (t - anchor_t) * cb;
                        dst.copy_from_slice(&seg_cache[off..off + cb]);
                    };

                // Without a per-step trajectory, `dcarry` is seeded once from the
                // `cb`-byte upstream buffer (presumably the final carry's gradient).
                let mut dcarry: Vec<u8> = vec![0u8; cb];
                if !*save_trajectory {
                    unsafe {
                        std::ptr::copy_nonoverlapping(
                            base.add(*outer_upstream_off),
                            dcarry.as_mut_ptr(),
                            cb,
                        );
                    }
                }

                let mut body_buf: Vec<u8> = (**body_init).clone();

                for t in (0..n_steps).rev() {
                    // With a trajectory output, row t of the upstream buffer is added
                    // into `dcarry`. Only f32 and f64 carries are supported.
                    if *save_trajectory {
                        unsafe {
                            let up_off = *outer_upstream_off + t * cb;
                            match *carry_elem_size {
                                4 => {
                                    let up_ptr = base.add(up_off) as *const f32;
                                    let dc_ptr = dcarry.as_mut_ptr() as *mut f32;
                                    let n_elems = cb / 4;
                                    for i in 0..n_elems {
                                        *dc_ptr.add(i) += *up_ptr.add(i);
                                    }
                                }
                                8 => {
                                    let up_ptr = base.add(up_off) as *const f64;
                                    let dc_ptr = dcarry.as_mut_ptr() as *mut f64;
                                    let n_elems = cb / 8;
                                    for i in 0..n_elems {
                                        *dc_ptr.add(i) += *up_ptr.add(i);
                                    }
                                }
                                other => panic!(
                                    "ScanBackwardXs: unsupported carry elem size {other} \
                                     (only f32/f64 carries are supported today)"
                                ),
                            }
                        }
                    }

                    // Seed body_vjp's carry input via the recompute
                    // helper (works for both All and Recursive modes),
                    // then x_t_i + d_output.
                    let carry_dst_start = *body_carry_in_off;
                    {
                        let carry_slice = &mut body_buf[carry_dst_start..carry_dst_start + cb];
                        recompute_carry_t(
                            t,
                            carry_slice,
                            &mut fwd_buf,
                            &mut seg_cache,
                            &mut seg_start_t,
                            &mut seg_count,
                        );
                    }
                    unsafe {
                        for (i, body_x_off) in body_x_offs.iter().enumerate() {
                            let (outer_xs_off, x_psb) = outer_xs_offs[i];
                            let xb = x_psb as usize;
                            std::ptr::copy_nonoverlapping(
                                base.add(outer_xs_off + t * xb),
                                body_buf.as_mut_ptr().add(*body_x_off),
                                xb,
                            );
                        }
                        std::ptr::copy_nonoverlapping(
                            dcarry.as_ptr(),
                            body_buf.as_mut_ptr().add(*body_d_output_off),
                            cb,
                        );
                    }

                    execute_thunks(body_vjp, &mut body_buf);

                    // Stash this step's dxs into row `t` of the outer
                    // [length, *per_step_xs] output.
                    // NOTE(review): only one `per_step_bytes` dxs slot is written per
                    // step even if `body_x_offs` lists several xs. Confirm that
                    // multi-xs scans are not lowered to this thunk.
                    unsafe {
                        std::ptr::copy_nonoverlapping(
                            body_buf.as_ptr().add(*body_dxs_out_off),
                            base.add(*outer_dxs_off + t * psb),
                            psb,
                        );
                    }

                    // Update dcarry for next backward iteration.
                    unsafe {
                        std::ptr::copy_nonoverlapping(
                            body_buf.as_ptr().add(*body_dcarry_out_off),
                            dcarry.as_mut_ptr(),
                            cb,
                        );
                    }
                }
            }
8313
8314            Thunk::FusedMmBiasAct {
8315                a,
8316                w,
8317                bias,
8318                c,
8319                m,
8320                k,
8321                n,
8322                act,
8323            } => {
8324                let (m, k, n) = (*m as usize, *k as usize, *n as usize);
8325                unsafe {
8326                    let out = sl_mut(*c, base, m * n);
8327                    crate::blas::sgemm_auto(sl(*a, base, m * k), sl(*w, base, k * n), out, m, k, n);
8328                    match act {
8329                        Some(Activation::Gelu) => {
8330                            crate::kernels::par_bias_gelu(out, sl(*bias, base, n), m, n)
8331                        }
8332                        Some(other) => {
8333                            crate::blas::bias_add(out, sl(*bias, base, n), m, n);
8334                            apply_activation_inplace(out, *other);
8335                        }
8336                        None => crate::blas::bias_add(out, sl(*bias, base, n), m, n),
8337                    }
8338                }
8339            }
8340
8341            Thunk::FusedResidualLN {
8342                x,
8343                res,
8344                bias,
8345                g,
8346                b,
8347                out,
8348                rows,
8349                h,
8350                eps,
8351                has_bias,
8352            } => {
8353                let (rows, h) = (*rows as usize, *h as usize);
8354                unsafe {
8355                    let zero = &zero_bias[..h];
8356                    let bi = if *has_bias { sl(*bias, base, h) } else { zero };
8357                    let x_ptr = sl(*x, base, rows * h).as_ptr() as usize;
8358                    let r_ptr = sl(*res, base, rows * h).as_ptr() as usize;
8359                    let o_ptr = sl_mut(*out, base, rows * h).as_mut_ptr() as usize;
8360                    let bi_ptr = bi.as_ptr() as usize;
8361                    let g_ptr = sl(*g, base, h).as_ptr() as usize;
8362                    let b_ptr = sl(*b, base, h).as_ptr() as usize;
8363                    let e = *eps;
8364                    crate::pool::par_for(rows, 4, &|off, cnt| {
8365                        let xs =
8366                            std::slice::from_raw_parts((x_ptr as *const f32).add(off * h), cnt * h);
8367                        let rs =
8368                            std::slice::from_raw_parts((r_ptr as *const f32).add(off * h), cnt * h);
8369                        let os = std::slice::from_raw_parts_mut(
8370                            (o_ptr as *mut f32).add(off * h),
8371                            cnt * h,
8372                        );
8373                        let bi = std::slice::from_raw_parts(bi_ptr as *const f32, h);
8374                        let g = std::slice::from_raw_parts(g_ptr as *const f32, h);
8375                        let b = std::slice::from_raw_parts(b_ptr as *const f32, h);
8376                        crate::kernels::residual_bias_layer_norm(xs, rs, bi, g, b, os, cnt, h, e);
8377                    });
8378                }
8379            }
8380
8381            Thunk::FusedResidualRmsNorm {
8382                x,
8383                res,
8384                bias,
8385                g,
8386                b,
8387                out,
8388                rows,
8389                h,
8390                eps,
8391                has_bias,
8392            } => {
8393                let (rows, h) = (*rows as usize, *h as usize);
8394                unsafe {
8395                    let zero = &zero_bias[..h];
8396                    let bi = if *has_bias { sl(*bias, base, h) } else { zero };
8397                    let x_ptr = sl(*x, base, rows * h).as_ptr() as usize;
8398                    let r_ptr = sl(*res, base, rows * h).as_ptr() as usize;
8399                    let o_ptr = sl_mut(*out, base, rows * h).as_mut_ptr() as usize;
8400                    let bi_ptr = bi.as_ptr() as usize;
8401                    let g_ptr = sl(*g, base, h).as_ptr() as usize;
8402                    let b_ptr = sl(*b, base, h).as_ptr() as usize;
8403                    let e = *eps;
8404                    crate::pool::par_for(rows, 4, &|off, cnt| {
8405                        let xs =
8406                            std::slice::from_raw_parts((x_ptr as *const f32).add(off * h), cnt * h);
8407                        let rs =
8408                            std::slice::from_raw_parts((r_ptr as *const f32).add(off * h), cnt * h);
8409                        let os = std::slice::from_raw_parts_mut(
8410                            (o_ptr as *mut f32).add(off * h),
8411                            cnt * h,
8412                        );
8413                        let bi = std::slice::from_raw_parts(bi_ptr as *const f32, h);
8414                        let g = std::slice::from_raw_parts(g_ptr as *const f32, h);
8415                        let b = std::slice::from_raw_parts(b_ptr as *const f32, h);
8416                        crate::kernels::residual_bias_rms_norm(xs, rs, bi, g, b, os, cnt, h, e);
8417                    });
8418                }
8419            }
8420
8421            Thunk::BiasAdd {
8422                src,
8423                bias,
8424                dst,
8425                m,
8426                n,
8427            } => {
8428                let (m, n) = (*m as usize, *n as usize);
8429                unsafe {
8430                    let out = sl_mut(*dst, base, m * n);
8431                    out.copy_from_slice(sl(*src, base, m * n));
8432                    crate::blas::bias_add(out, sl(*bias, base, n), m, n);
8433                }
8434            }
8435
8436            Thunk::BinaryFull {
8437                lhs,
8438                rhs,
8439                dst,
8440                len,
8441                lhs_len,
8442                rhs_len,
8443                op,
8444                out_dims_bcast,
8445                bcast_lhs_strides,
8446                bcast_rhs_strides,
8447            } => {
8448                let len = *len as usize;
8449                let ll = (*lhs_len as usize).max(1);
8450                let rl = (*rhs_len as usize).max(1);
8451                unsafe {
8452                    let l = sl(*lhs, base, ll);
8453                    let r = sl(*rhs, base, rl);
8454                    let o = sl_mut(*dst, base, len);
8455                    // Fast path: shapes match exactly → NEON-vectorized loop.
8456                    if ll == len && rl == len {
8457                        #[cfg(target_arch = "aarch64")]
8458                        if matches!(op, BinaryOp::Add | BinaryOp::Mul) {
8459                            use std::arch::aarch64::*;
8460                            let chunks = len / 4;
8461                            for c in 0..chunks {
8462                                let off = c * 4;
8463                                let vl = vld1q_f32(l.as_ptr().add(off));
8464                                let vr = vld1q_f32(r.as_ptr().add(off));
8465                                let res = match op {
8466                                    BinaryOp::Add => vaddq_f32(vl, vr),
8467                                    BinaryOp::Mul => vmulq_f32(vl, vr),
8468                                    _ => unreachable!(),
8469                                };
8470                                vst1q_f32(o.as_mut_ptr().add(off), res);
8471                            }
8472                            for i in (chunks * 4)..len {
8473                                o[i] = match op {
8474                                    BinaryOp::Add => l[i] + r[i],
8475                                    BinaryOp::Mul => l[i] * r[i],
8476                                    _ => unreachable!(),
8477                                };
8478                            }
8479                            // `continue` to next thunk in the schedule — a
8480                            // bare `return` here used to exit execute_thunks
8481                            // entirely, silently dropping every thunk after
8482                            // the first BinaryFull (catastrophic for chained
8483                            // adds in BERT embedding stage).
8484                            continue;
8485                        }
8486                    }
8487                    if !out_dims_bcast.is_empty() {
8488                        // Shape-aware broadcast path: correct for
8489                        // bidirectional `[N,1] op [1,S]` etc.
8490                        let rank = out_dims_bcast.len();
8491                        let mut coords = vec![0u32; rank];
8492                        for i in 0..len {
8493                            let mut rem = i;
8494                            for ax in (0..rank).rev() {
8495                                let sz = out_dims_bcast[ax] as usize;
8496                                coords[ax] = (rem % sz) as u32;
8497                                rem /= sz;
8498                            }
8499                            let mut li: usize = 0;
8500                            let mut ri: usize = 0;
8501                            for ax in 0..rank {
8502                                li += coords[ax] as usize * bcast_lhs_strides[ax] as usize;
8503                                ri += coords[ax] as usize * bcast_rhs_strides[ax] as usize;
8504                            }
8505                            o[i] = match op {
8506                                BinaryOp::Add => l[li] + r[ri],
8507                                BinaryOp::Sub => l[li] - r[ri],
8508                                BinaryOp::Mul => l[li] * r[ri],
8509                                BinaryOp::Div => l[li] / r[ri],
8510                                BinaryOp::Max => l[li].max(r[ri]),
8511                                BinaryOp::Min => l[li].min(r[ri]),
8512                                BinaryOp::Pow => l[li].powf(r[ri]),
8513                            };
8514                        }
8515                    } else {
8516                        // Fallback: legacy modulo path (dynamic shapes only).
8517                        for i in 0..len {
8518                            let li = if ll == 1 { 0 } else { i % ll };
8519                            let ri = if rl == 1 { 0 } else { i % rl };
8520                            o[i] = match op {
8521                                BinaryOp::Add => l[li] + r[ri],
8522                                BinaryOp::Sub => l[li] - r[ri],
8523                                BinaryOp::Mul => l[li] * r[ri],
8524                                BinaryOp::Div => l[li] / r[ri],
8525                                BinaryOp::Max => l[li].max(r[ri]),
8526                                BinaryOp::Min => l[li].min(r[ri]),
8527                                BinaryOp::Pow => l[li].powf(r[ri]),
8528                            };
8529                        }
8530                    }
8531                }
8532            }
8533
8534            Thunk::Gather {
8535                table,
8536                table_len,
8537                idx,
8538                dst,
8539                num_idx,
8540                trailing,
8541            } => {
8542                let (ni, tr) = (*num_idx as usize, *trailing as usize);
8543                unsafe {
8544                    let tab = sl(*table, base, *table_len as usize);
8545                    let ids = sl(*idx, base, ni);
8546                    let out = sl_mut(*dst, base, ni * tr);
8547                    for i in 0..ni {
8548                        let row = ids[i] as usize;
8549                        out[i * tr..(i + 1) * tr].copy_from_slice(&tab[row * tr..(row + 1) * tr]);
8550                    }
8551                }
8552            }
8553
8554            Thunk::Narrow {
8555                src,
8556                dst,
8557                outer,
8558                src_stride,
8559                dst_stride,
8560                inner,
8561                elem_bytes,
8562            } => {
8563                let f = narrow_thunk_closure(
8564                    *src,
8565                    *dst,
8566                    *outer,
8567                    *src_stride,
8568                    *dst_stride,
8569                    *inner,
8570                    *elem_bytes,
8571                );
8572                f(base);
8573            }
8574
8575            Thunk::Copy { src, dst, len } => {
8576                let len = *len as usize;
8577                unsafe {
8578                    let s = sl(*src, base, len);
8579                    let d = sl_mut(*dst, base, len);
8580                    d.copy_from_slice(s);
8581                }
8582            }
8583
8584            Thunk::LayerNorm {
8585                src,
8586                g,
8587                b,
8588                dst,
8589                rows,
8590                h,
8591                eps,
8592            } => {
8593                let (rows, h) = (*rows as usize, *h as usize);
8594                unsafe {
8595                    let input = sl(*src, base, rows * h);
8596                    let gamma = sl(*g, base, h);
8597                    let beta = sl(*b, base, h);
8598                    let output = sl_mut(*dst, base, rows * h);
8599                    // Parallelize across rows (same pattern as FusedResidualLN)
8600                    if rows >= 4 && rows * h >= 30_000 {
8601                        let i_ptr = input.as_ptr() as usize;
8602                        let o_ptr = output.as_mut_ptr() as usize;
8603                        let g_ptr = gamma.as_ptr() as usize;
8604                        let b_ptr = beta.as_ptr() as usize;
8605                        let e = *eps;
8606                        crate::pool::par_for(rows, 4, &|off, cnt| {
8607                            let inp = std::slice::from_raw_parts(
8608                                (i_ptr as *const f32).add(off * h),
8609                                cnt * h,
8610                            );
8611                            let out = std::slice::from_raw_parts_mut(
8612                                (o_ptr as *mut f32).add(off * h),
8613                                cnt * h,
8614                            );
8615                            let g = std::slice::from_raw_parts(g_ptr as *const f32, h);
8616                            let b = std::slice::from_raw_parts(b_ptr as *const f32, h);
8617                            for row in 0..cnt {
8618                                crate::kernels::layer_norm_row(
8619                                    &inp[row * h..(row + 1) * h],
8620                                    g,
8621                                    b,
8622                                    &mut out[row * h..(row + 1) * h],
8623                                    h,
8624                                    e,
8625                                );
8626                            }
8627                        });
8628                    } else {
8629                        for row in 0..rows {
8630                            crate::kernels::layer_norm_row(
8631                                &input[row * h..(row + 1) * h],
8632                                gamma,
8633                                beta,
8634                                &mut output[row * h..(row + 1) * h],
8635                                h,
8636                                *eps,
8637                            );
8638                        }
8639                    }
8640                }
8641            }
8642
8643            Thunk::GroupNorm {
8644                src,
8645                g,
8646                b,
8647                dst,
8648                n,
8649                c,
8650                h,
8651                w,
8652                num_groups,
8653                eps,
8654            } => {
8655                let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
8656                let plane = c * h * w;
8657                unsafe {
8658                    for ni in 0..n {
8659                        let input = sl(*src, base.add(ni * plane), plane);
8660                        let gamma = sl(*g, base, c);
8661                        let beta = sl(*b, base, c);
8662                        let output = sl_mut(*dst, base.add(ni * plane), plane);
8663                        crate::kernels::group_norm_nchw(
8664                            input,
8665                            gamma,
8666                            beta,
8667                            output,
8668                            1,
8669                            c,
8670                            h,
8671                            w,
8672                            *num_groups as usize,
8673                            *eps,
8674                        );
8675                    }
8676                }
8677            }
8678
8679            Thunk::LayerNorm2d {
8680                src,
8681                g,
8682                b,
8683                dst,
8684                n,
8685                c,
8686                h,
8687                w,
8688                eps,
8689            } => {
8690                let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
8691                let plane = c * h * w;
8692                unsafe {
8693                    let input = sl(*src, base, n * plane);
8694                    let gamma = sl(*g, base, c);
8695                    let beta = sl(*b, base, c);
8696                    let output = sl_mut(*dst, base, n * plane);
8697                    crate::kernels::layer_norm2d_nchw(input, gamma, beta, output, n, c, h, w, *eps);
8698                }
8699            }
8700
8701            Thunk::ConvTranspose2d {
8702                src,
8703                weight,
8704                dst,
8705                n,
8706                c_in,
8707                h,
8708                w_in,
8709                c_out,
8710                h_out,
8711                w_out,
8712                kh,
8713                kw,
8714                sh,
8715                sw,
8716                ph,
8717                pw,
8718                dh,
8719                dw,
8720                groups,
8721            } => {
8722                let n = *n as usize;
8723                let c_in = *c_in as usize;
8724                let h = *h as usize;
8725                let w_in = *w_in as usize;
8726                let c_out = *c_out as usize;
8727                let h_out = *h_out as usize;
8728                let w_out = *w_out as usize;
8729                unsafe {
8730                    let inp = sl(*src, base, n * c_in * h * w_in);
8731                    let wt = sl(
8732                        *weight,
8733                        base,
8734                        c_in * (c_out / *groups as usize) * (*kh as usize) * (*kw as usize),
8735                    );
8736                    let out = sl_mut(*dst, base, n * c_out * h_out * w_out);
8737                    crate::kernels::conv_transpose2d_nchw(
8738                        inp,
8739                        wt,
8740                        out,
8741                        n,
8742                        c_in,
8743                        h,
8744                        w_in,
8745                        c_out,
8746                        h_out,
8747                        w_out,
8748                        *kh as usize,
8749                        *kw as usize,
8750                        *sh as usize,
8751                        *sw as usize,
8752                        *ph as usize,
8753                        *pw as usize,
8754                        *dh as usize,
8755                        *dw as usize,
8756                        *groups as usize,
8757                    );
8758                }
8759            }
8760
8761            Thunk::ResizeNearest2x {
8762                src,
8763                dst,
8764                n,
8765                c,
8766                h,
8767                w,
8768            } => {
8769                let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
8770                let in_plane = c * h * w;
8771                let out_plane = c * h * 2 * w * 2;
8772                unsafe {
8773                    for ni in 0..n {
8774                        let input = sl(*src, base.add(ni * in_plane), in_plane);
8775                        let output = sl_mut(*dst, base.add(ni * out_plane), out_plane);
8776                        crate::kernels::resize_nearest_2x_nchw(input, output, c, h, w);
8777                    }
8778                }
8779            }
8780
8781            Thunk::AxialRope2d {
8782                src,
8783                dst,
8784                batch,
8785                seq,
8786                hidden,
8787                end_x,
8788                end_y,
8789                head_dim,
8790                num_heads,
8791                theta,
8792                repeat_factor,
8793            } => {
8794                let b = *batch as usize;
8795                let s = *seq as usize;
8796                let hdim = *head_dim as usize;
8797                let nh = *num_heads as usize;
8798                let plane = s * (*hidden as usize);
8799                unsafe {
8800                    for bi in 0..b {
8801                        let input = sl(*src, base.add(bi * plane), plane);
8802                        let output = sl_mut(*dst, base.add(bi * plane), plane);
8803                        let rotated = rlx_ir::ops::axial_rope2d::apply_axial_rope2d(
8804                            input,
8805                            nh,
8806                            s,
8807                            hdim,
8808                            *end_x as usize,
8809                            *end_y as usize,
8810                            *theta,
8811                            *repeat_factor as usize,
8812                        );
8813                        output.copy_from_slice(&rotated);
8814                    }
8815                }
8816            }
8817
8818            Thunk::RmsNorm {
8819                src,
8820                g,
8821                b,
8822                dst,
8823                rows,
8824                h,
8825                eps,
8826            } => {
8827                let (rows, h) = (*rows as usize, *h as usize);
8828                unsafe {
8829                    let input = sl(*src, base, rows * h);
8830                    let gamma = sl(*g, base, h);
8831                    let beta = sl(*b, base, h);
8832                    let output = sl_mut(*dst, base, rows * h);
8833                    let inv_h = 1.0 / h as f32;
8834                    for row in 0..rows {
8835                        let in_row = &input[row * h..(row + 1) * h];
8836                        let out_row = &mut output[row * h..(row + 1) * h];
8837                        // RMS = sqrt(mean(x^2) + eps); scale = 1/RMS.
8838                        let mut sumsq = 0f32;
8839                        for &v in in_row {
8840                            sumsq += v * v;
8841                        }
8842                        let inv_rms = (sumsq * inv_h + *eps).sqrt().recip();
8843                        for i in 0..h {
8844                            out_row[i] = in_row[i] * inv_rms * gamma[i] + beta[i];
8845                        }
8846                    }
8847                }
8848            }
8849
8850            Thunk::Softmax { data, rows, cols } => {
8851                let (rows, cols) = (*rows as usize, *cols as usize);
8852                unsafe {
8853                    crate::kernels::neon_softmax(sl_mut(*data, base, rows * cols), rows, cols);
8854                }
8855            }
8856
8857            Thunk::Cumsum {
8858                src,
8859                dst,
8860                rows,
8861                cols,
8862                exclusive,
8863            } => {
8864                let (rows, cols) = (*rows as usize, *cols as usize);
8865                unsafe {
8866                    let s = sl(*src, base, rows * cols);
8867                    let d = sl_mut(*dst, base, rows * cols);
8868                    if *exclusive {
8869                        for r in 0..rows {
8870                            let mut acc = 0.0f32;
8871                            for c in 0..cols {
8872                                d[r * cols + c] = acc;
8873                                acc += s[r * cols + c];
8874                            }
8875                        }
8876                    } else {
8877                        for r in 0..rows {
8878                            let mut acc = 0.0f32;
8879                            for c in 0..cols {
8880                                acc += s[r * cols + c];
8881                                d[r * cols + c] = acc;
8882                            }
8883                        }
8884                    }
8885                }
8886            }
8887
8888            Thunk::Sample {
8889                logits,
8890                dst,
8891                batch,
8892                vocab,
8893                top_k,
8894                top_p,
8895                temperature,
8896                seed,
8897            } => {
8898                let (b, v) = (*batch as usize, *vocab as usize);
8899                let k = (*top_k as usize).min(v);
8900                unsafe {
8901                    let lg = sl(*logits, base, b * v);
8902                    let out = sl_mut(*dst, base, b);
8903                    let mut rng =
8904                        rlx_ir::Philox4x32::new(if *seed == 0 { 0xDEADBEEF } else { *seed });
8905                    for bi in 0..b {
8906                        let row = &lg[bi * v..(bi + 1) * v];
8907                        out[bi] = sample_row(row, k, *top_p, *temperature, &mut rng) as f32;
8908                    }
8909                }
8910            }
8911
8912            Thunk::GatedDeltaNet {
8913                q,
8914                k,
8915                v,
8916                g,
8917                beta,
8918                state,
8919                dst,
8920                batch,
8921                seq,
8922                heads,
8923                state_size,
8924            } => unsafe {
8925                execute_gated_delta_net_f32(
8926                    *q,
8927                    *k,
8928                    *v,
8929                    *g,
8930                    *beta,
8931                    *state,
8932                    *dst,
8933                    *batch as usize,
8934                    *seq as usize,
8935                    *heads as usize,
8936                    *state_size as usize,
8937                    base,
8938                );
8939            },
8940
8941            Thunk::SelectiveScan {
8942                x,
8943                delta,
8944                a,
8945                b: bp,
8946                c: cp,
8947                dst,
8948                batch,
8949                seq,
8950                hidden,
8951                state_size,
8952            } => {
8953                let (b, s, h, n) = (
8954                    *batch as usize,
8955                    *seq as usize,
8956                    *hidden as usize,
8957                    *state_size as usize,
8958                );
8959                unsafe {
8960                    let xs = sl(*x, base, b * s * h);
8961                    let dt = sl(*delta, base, b * s * h);
8962                    let am = sl(*a, base, h * n);
8963                    let bm = sl(*bp, base, b * s * n);
8964                    let cm = sl(*cp, base, b * s * n);
8965                    let out = sl_mut(*dst, base, b * s * h);
8966
8967                    // State buffer per-batch: h channels × n state.
8968                    // Sequential along the seq dimension; could
8969                    // parallelize over batch+channel later.
8970                    let mut state = vec![0f32; h * n];
8971                    for bi in 0..b {
8972                        // Reset state at the start of each batch row.
8973                        for v in state.iter_mut() {
8974                            *v = 0.0;
8975                        }
8976                        for si in 0..s {
8977                            let x_row = &xs[bi * s * h + si * h..bi * s * h + (si + 1) * h];
8978                            let dt_row = &dt[bi * s * h + si * h..bi * s * h + (si + 1) * h];
8979                            let b_row = &bm[bi * s * n + si * n..bi * s * n + (si + 1) * n];
8980                            let c_row = &cm[bi * s * n + si * n..bi * s * n + (si + 1) * n];
8981                            let out_row = &mut out[bi * s * h + si * h..bi * s * h + (si + 1) * h];
8982
8983                            for ci in 0..h {
8984                                let d = dt_row[ci];
8985                                let xv = x_row[ci];
8986                                let mut acc = 0f32;
8987                                for ni in 0..n {
8988                                    // Discretize: exp(d * a) and d * b.
8989                                    let da = (d * am[ci * n + ni]).exp();
8990                                    state[ci * n + ni] =
8991                                        da * state[ci * n + ni] + d * b_row[ni] * xv;
8992                                    acc += c_row[ni] * state[ci * n + ni];
8993                                }
8994                                out_row[ci] = acc;
8995                            }
8996                        }
8997                    }
8998                }
8999            }
9000
9001            Thunk::DequantMatMul {
9002                x,
9003                w_q,
9004                scale,
9005                zp,
9006                dst,
9007                m,
9008                k,
9009                n,
9010                block_size,
9011                is_asymmetric,
9012            } => {
9013                let (m, k, n, bs) = (*m as usize, *k as usize, *n as usize, *block_size as usize);
9014                let n_blocks = k.div_ceil(bs);
9015                unsafe {
9016                    let xs = sl(*x, base, m * k);
9017                    let w_bytes = std::slice::from_raw_parts(base.add(*w_q) as *const i8, k * n);
9018                    let scales = sl(*scale, base, n_blocks * n);
9019                    let zps = if *is_asymmetric {
9020                        sl(*zp, base, n_blocks * n)
9021                    } else {
9022                        &[][..]
9023                    };
9024                    let out = sl_mut(*dst, base, m * n);
9025                    dequant_matmul_int8(xs, w_bytes, scales, zps, out, m, k, n, bs, *is_asymmetric);
9026                }
9027            }
9028
9029            Thunk::DequantMatMulGguf {
9030                x,
9031                w_q,
9032                dst,
9033                m,
9034                k,
9035                n,
9036                scheme,
9037            } => {
9038                let (m, k, n) = (*m as usize, *k as usize, *n as usize);
9039                let block_bytes = scheme.gguf_block_bytes() as usize;
9040                let block_elems = scheme.gguf_block_size() as usize;
9041                debug_assert!(
9042                    block_bytes > 0 && block_elems > 0,
9043                    "non-GGUF scheme in GGUF arm"
9044                );
9045                debug_assert!(
9046                    (k * n).is_multiple_of(block_elems),
9047                    "k*n={} not aligned to GGUF block size {}",
9048                    k * n,
9049                    block_elems
9050                );
9051                let total_bytes = (k * n) / block_elems * block_bytes;
9052                unsafe {
9053                    let xs = sl(*x, base, m * k);
9054                    let w_bytes_ptr = base.add(*w_q) as *const u8;
9055                    let w_bytes = std::slice::from_raw_parts(w_bytes_ptr, total_bytes);
9056                    let out = sl_mut(*dst, base, m * n);
9057                    crate::gguf_matmul::gguf_matmul_bt(xs, w_bytes, out, m, k, n, *scheme);
9058                }
9059            }
9060
9061            Thunk::DequantMatMulInt4 {
9062                x,
9063                w_q,
9064                scale,
9065                zp,
9066                dst,
9067                m,
9068                k,
9069                n,
9070                block_size,
9071                is_asymmetric,
9072            } => {
9073                let (m, k, n, bs) = (*m as usize, *k as usize, *n as usize, *block_size as usize);
9074                let n_blocks = k.div_ceil(bs);
9075                unsafe {
9076                    let xs = sl(*x, base, m * k);
9077                    let w_bytes = std::slice::from_raw_parts(
9078                        base.add(*w_q) as *const u8,
9079                        (k * n).div_ceil(2),
9080                    );
9081                    let scales = sl(*scale, base, n_blocks * n);
9082                    let zps = if *is_asymmetric {
9083                        sl(*zp, base, n_blocks * n)
9084                    } else {
9085                        &[][..]
9086                    };
9087                    let out = sl_mut(*dst, base, m * n);
9088                    dequant_matmul_int4(xs, w_bytes, scales, zps, out, m, k, n, bs, *is_asymmetric);
9089                }
9090            }
9091
9092            Thunk::DequantMatMulFp8 {
9093                x,
9094                w_q,
9095                scale,
9096                dst,
9097                m,
9098                k,
9099                n,
9100                e5m2,
9101            } => {
9102                let (m, k, n) = (*m as usize, *k as usize, *n as usize);
9103                unsafe {
9104                    let xs = sl(*x, base, m * k);
9105                    let w_bytes = std::slice::from_raw_parts(base.add(*w_q) as *const u8, k * n);
9106                    let scales = sl(*scale, base, n);
9107                    let out = sl_mut(*dst, base, m * n);
9108                    dequant_matmul_fp8(xs, w_bytes, scales, out, m, k, n, *e5m2);
9109                }
9110            }
9111
9112            Thunk::DequantMatMulNvfp4 {
9113                x,
9114                w_q,
9115                scale,
9116                global_scale,
9117                dst,
9118                m,
9119                k,
9120                n,
9121            } => {
9122                let (m, k, n) = (*m as usize, *k as usize, *n as usize);
9123                let n_scale = k.div_ceil(rlx_ir::NVFP4_GROUP_SIZE) * n;
9124                unsafe {
9125                    let xs = sl(*x, base, m * k);
9126                    let w_bytes = std::slice::from_raw_parts(
9127                        base.add(*w_q) as *const u8,
9128                        (k * n).div_ceil(2),
9129                    );
9130                    let scale_bytes =
9131                        std::slice::from_raw_parts(base.add(*scale) as *const u8, n_scale);
9132                    let gs = sl(*global_scale, base, 1)[0];
9133                    let out = sl_mut(*dst, base, m * n);
9134                    dequant_matmul_nvfp4(xs, w_bytes, scale_bytes, gs, out, m, k, n);
9135                }
9136            }
9137
9138            Thunk::LoraMatMul {
9139                x,
9140                w,
9141                a,
9142                b,
9143                dst,
9144                m,
9145                k,
9146                n,
9147                r,
9148                scale,
9149            } => {
9150                let (m, k, n, r) = (*m as usize, *k as usize, *n as usize, *r as usize);
9151                unsafe {
9152                    let xs = sl(*x, base, m * k);
9153                    let ws = sl(*w, base, k * n);
9154                    let a_s = sl(*a, base, k * r);
9155                    let bs = sl(*b, base, r * n);
9156                    let out = sl_mut(*dst, base, m * n);
9157                    crate::blas::sgemm(xs, ws, out, m, k, n);
9158                    let mut tmp = vec![0f32; m * r];
9159                    crate::blas::sgemm(xs, a_s, &mut tmp, m, k, r);
9160                    if *scale != 1.0 {
9161                        for v in tmp.iter_mut() {
9162                            *v *= *scale;
9163                        }
9164                    }
9165                    crate::blas::sgemm_accumulate(&tmp, bs, out, m, r, n);
9166                }
9167            }
9168
9169            Thunk::Attention {
9170                q,
9171                k,
9172                v,
9173                mask,
9174                out,
9175                batch,
9176                seq,
9177                kv_seq,
9178                heads,
9179                head_dim,
9180                mask_kind,
9181                q_row_stride,
9182                k_row_stride,
9183                v_row_stride,
9184                bhsd,
9185            } => {
9186                let (b, q_s, k_s, nh, dh) = (
9187                    *batch as usize,
9188                    *seq as usize,
9189                    *kv_seq as usize,
9190                    *heads as usize,
9191                    *head_dim as usize,
9192                );
9193                let hs = nh * dh;
9194                // For [B, H, S, D] layout each (b, h) tile is dense
9195                // contiguous; the qrs/krs/vrs strides are not used.
9196                let (qrs, krs, vrs) = if *bhsd {
9197                    (dh, dh, dh)
9198                } else {
9199                    (
9200                        *q_row_stride as usize,
9201                        *k_row_stride as usize,
9202                        *v_row_stride as usize,
9203                    )
9204                };
9205                let bhsd = *bhsd;
9206                let _ = (q_row_stride, k_row_stride, v_row_stride);
9207                let scale = (dh as f32).powf(-0.5);
9208                let ss = q_s * k_s;
9209                let cfg = crate::config::RuntimeConfig::global();
9210                unsafe {
9211                    // Slice lengths cover the strided span. When Q/K/V
9212                    // alias the parent QKV (post-#46-fusion), the same
9213                    // bytes back all three slices — compiler bounds
9214                    // checks see the right size. For [B, H, S, D] the
9215                    // buffer is densely B*H*S*D elements; the row
9216                    // strides aren't used.
9217                    let q_len = if bhsd {
9218                        b * nh * q_s * dh
9219                    } else {
9220                        b * q_s * qrs
9221                    };
9222                    let k_len = if bhsd {
9223                        b * nh * k_s * dh
9224                    } else {
9225                        b * k_s * krs
9226                    };
9227                    let v_len = if bhsd {
9228                        b * nh * k_s * dh
9229                    } else {
9230                        b * k_s * vrs
9231                    };
9232                    let q_data = sl(*q, base, q_len);
9233                    let k_data = sl(*k, base, k_len);
9234                    let v_data = sl(*v, base, v_len);
9235                    let mask_data: &[f32] = match mask_kind {
9236                        rlx_ir::op::MaskKind::Custom => sl(*mask, base, b * k_s),
9237                        rlx_ir::op::MaskKind::Bias => sl(*mask, base, b * nh * q_s * k_s),
9238                        _ => &[],
9239                    };
9240                    let out_len = if bhsd {
9241                        b * nh * q_s * dh
9242                    } else {
9243                        b * q_s * hs
9244                    };
9245                    let out_data = sl_mut(*out, base, out_len);
9246
9247                    // ── [B, H, S, D] fallback ──────────────────────
9248                    // The NEON / strided-BLAS specializations below
9249                    // are written for the [B, S, H, D] layout. When
9250                    // the input is head-major ([B, H, S, D] —
9251                    // matching rlx-cuda / rlx-rocm / rlx-tpu), bypass
9252                    // them and run a simple (correct but slower)
9253                    // scalar implementation. Production-CPU inference
9254                    // graphs use [B, S, H, D] so they still hit the
9255                    // hot path; cross-backend parity tests use
9256                    // [B, H, S, D] and land here.
9257                    if bhsd {
9258                        let scores = &mut sdpa_scores[..ss];
9259                        for bi in 0..b {
9260                            for hi in 0..nh {
9261                                let q_head_base = bi * nh * q_s * dh + hi * q_s * dh;
9262                                let k_head_base = bi * nh * k_s * dh + hi * k_s * dh;
9263                                // Q@K^T
9264                                for qi in 0..q_s {
9265                                    let q_base = q_head_base + qi * dh;
9266                                    for ki in 0..k_s {
9267                                        let k_base = k_head_base + ki * dh;
9268                                        let mut dot = 0f32;
9269                                        for d in 0..dh {
9270                                            dot += q_data[q_base + d] * k_data[k_base + d];
9271                                        }
9272                                        scores[qi * k_s + ki] = dot * scale;
9273                                        if matches!(mask_kind, rlx_ir::op::MaskKind::Custom)
9274                                            && !mask_data.is_empty()
9275                                            && mask_data[bi * k_s + ki] < mask_thr
9276                                        {
9277                                            scores[qi * k_s + ki] = mask_neg;
9278                                        }
9279                                    }
9280                                }
9281                                if matches!(mask_kind, rlx_ir::op::MaskKind::Bias) {
9282                                    let off = (bi * nh + hi) * q_s * k_s;
9283                                    for i in 0..q_s * k_s {
9284                                        scores[i] += mask_data[off + i];
9285                                    }
9286                                }
9287                                apply_synthetic_mask(scores, q_s, k_s, *mask_kind);
9288                                crate::kernels::neon_softmax(scores, q_s, k_s);
9289                                // score @ V
9290                                for qi in 0..q_s {
9291                                    let o_base = q_head_base + qi * dh;
9292                                    for d in 0..dh {
9293                                        out_data[o_base + d] = 0.0;
9294                                    }
9295                                    for ki in 0..k_s {
9296                                        let sc = scores[qi * k_s + ki];
9297                                        if sc > score_thr {
9298                                            let v_base = k_head_base + ki * dh;
9299                                            for d in 0..dh {
9300                                                out_data[o_base + d] += sc * v_data[v_base + d];
9301                                            }
9302                                        }
9303                                    }
9304                                }
9305                            }
9306                        }
9307                        continue;
9308                    }
9309
9310                    // ── Auto-select kernel: NEON dots vs strided BLAS ───
9311                    // For tiny inputs (batch=1, short seq), per-head BLAS call
9312                    // overhead (~0.5µs × 2 calls × num_heads × num_layers)
9313                    // exceeds the NEON compute cost. Use direct strided NEON
9314                    // with zero dispatch overhead.
9315                    // For batch≥2: always BLAS + par_for (parallelism wins).
9316                    if b == 1 && q_s.max(k_s) <= cfg.sdpa_seq_threshold {
9317                        // ── Sequential NEON path (zero overhead) ──
9318                        let scores = &mut sdpa_scores[..ss];
9319                        #[cfg(target_arch = "aarch64")]
9320                        let neon_chunks = dh / 4;
9321
9322                        for bi in 0..b {
9323                            for hi in 0..nh {
9324                                // Q@K^T via strided NEON dot products
9325                                for qi in 0..q_s {
9326                                    let q_off = bi * q_s * qrs + qi * qrs + hi * dh;
9327                                    for ki in 0..k_s {
9328                                        let k_off = bi * k_s * krs + ki * krs + hi * dh;
9329                                        #[cfg(target_arch = "aarch64")]
9330                                        let mut dot;
9331                                        #[cfg(not(target_arch = "aarch64"))]
9332                                        let mut dot = 0f32;
9333                                        #[cfg(target_arch = "aarch64")]
9334                                        {
9335                                            use std::arch::aarch64::*;
9336                                            let mut acc = vdupq_n_f32(0.0);
9337                                            for c in 0..neon_chunks {
9338                                                let vq =
9339                                                    vld1q_f32(q_data.as_ptr().add(q_off + c * 4));
9340                                                let vk =
9341                                                    vld1q_f32(k_data.as_ptr().add(k_off + c * 4));
9342                                                acc = vfmaq_f32(acc, vq, vk);
9343                                            }
9344                                            dot = vaddvq_f32(acc);
9345                                            for d in (neon_chunks * 4)..dh {
9346                                                dot += q_data[q_off + d] * k_data[k_off + d];
9347                                            }
9348                                        }
9349                                        #[cfg(not(target_arch = "aarch64"))]
9350                                        for d in 0..dh {
9351                                            dot += q_data[q_off + d] * k_data[k_off + d];
9352                                        }
9353                                        scores[qi * k_s + ki] = dot * scale;
9354                                        // Inner-loop Custom mask check —
9355                                        // Causal / SlidingWindow / None
9356                                        // apply outside the loop below.
9357                                        // Skip for Bias — that mask is a
9358                                        // per-head additive tensor, not a
9359                                        // 0/1 key-padding mask.
9360                                        if matches!(mask_kind, rlx_ir::op::MaskKind::Custom)
9361                                            && !mask_data.is_empty()
9362                                            && mask_data[bi * k_s + ki] < mask_thr
9363                                        {
9364                                            scores[qi * k_s + ki] = mask_neg;
9365                                        }
9366                                    }
9367                                }
9368
9369                                if matches!(mask_kind, rlx_ir::op::MaskKind::Bias) {
9370                                    let off = (bi * nh + hi) * q_s * k_s;
9371                                    for i in 0..q_s * k_s {
9372                                        scores[i] += mask_data[off + i];
9373                                    }
9374                                }
9375                                apply_synthetic_mask(scores, q_s, k_s, *mask_kind);
9376                                crate::kernels::neon_softmax(scores, q_s, k_s);
9377
9378                                // Score@V via strided NEON accumulation (zero-copy)
9379                                for qi in 0..q_s {
9380                                    let o_off = bi * q_s * hs + qi * hs + hi * dh;
9381                                    // Zero output for this head position
9382                                    for d in 0..dh {
9383                                        out_data[o_off + d] = 0.0;
9384                                    }
9385                                    for ki in 0..k_s {
9386                                        let sc = scores[qi * k_s + ki];
9387                                        if sc > score_thr {
9388                                            let v_off = bi * k_s * vrs + ki * vrs + hi * dh;
9389                                            #[cfg(target_arch = "aarch64")]
9390                                            {
9391                                                use std::arch::aarch64::*;
9392                                                let vsc = vdupq_n_f32(sc);
9393                                                for c in 0..neon_chunks {
9394                                                    let off = c * 4;
9395                                                    let vo = vld1q_f32(
9396                                                        out_data.as_ptr().add(o_off + off),
9397                                                    );
9398                                                    let vv =
9399                                                        vld1q_f32(v_data.as_ptr().add(v_off + off));
9400                                                    vst1q_f32(
9401                                                        out_data.as_mut_ptr().add(o_off + off),
9402                                                        vfmaq_f32(vo, vsc, vv),
9403                                                    );
9404                                                }
9405                                            }
9406                                            #[cfg(not(target_arch = "aarch64"))]
9407                                            for d in 0..dh {
9408                                                out_data[o_off + d] += sc * v_data[v_off + d];
9409                                            }
9410                                        }
9411                                    }
9412                                }
9413                            }
9414                        }
9415                    } else {
9416                        // ── Parallel strided BLAS path (high throughput) ──
9417                        let total_work = b * nh;
9418                        let q_addr = q_data.as_ptr() as usize;
9419                        let k_addr = k_data.as_ptr() as usize;
9420                        let v_addr = v_data.as_ptr() as usize;
9421                        let m_addr = mask_data.as_ptr() as usize;
9422                        let o_addr = out_data.as_mut_ptr() as usize;
9423                        let sc_addr = sdpa_scores.as_mut_ptr() as usize;
9424
9425                        crate::pool::par_for(total_work, 1, &|off, cnt| {
9426                            for idx in off..off + cnt {
9427                                let bi = idx / nh;
9428                                let hi = idx % nh;
9429
9430                                let q_start = (q_addr as *const f32).add(bi * q_s * qrs + hi * dh);
9431                                let k_start = (k_addr as *const f32).add(bi * k_s * krs + hi * dh);
9432                                let v_start = (v_addr as *const f32).add(bi * k_s * vrs + hi * dh);
9433                                let o_start = (o_addr as *mut f32).add(bi * q_s * hs + hi * dh);
9434                                let sc = std::slice::from_raw_parts_mut(
9435                                    (sc_addr as *mut f32).add(idx * ss),
9436                                    ss,
9437                                );
9438
9439                                // LDA = qrs, LDB = krs (parent row strides
9440                                // when fused; hs otherwise).
9441                                crate::blas::sgemm_general(
9442                                    q_start,
9443                                    k_start,
9444                                    sc.as_mut_ptr(),
9445                                    q_s,
9446                                    k_s,
9447                                    dh,
9448                                    scale,
9449                                    0.0,
9450                                    qrs,
9451                                    krs,
9452                                    k_s,
9453                                    false,
9454                                    true,
9455                                );
9456
9457                                match mask_kind {
9458                                    rlx_ir::op::MaskKind::Custom => {
9459                                        let mask_bi = std::slice::from_raw_parts(
9460                                            (m_addr as *const f32).add(bi * k_s),
9461                                            k_s,
9462                                        );
9463                                        for ki in 0..k_s {
9464                                            if mask_bi[ki] < mask_thr {
9465                                                for qi in 0..q_s {
9466                                                    sc[qi * k_s + ki] = mask_neg;
9467                                                }
9468                                            }
9469                                        }
9470                                    }
9471                                    rlx_ir::op::MaskKind::Bias => {
9472                                        // Per-head additive bias slice.
9473                                        let bias = std::slice::from_raw_parts(
9474                                            (m_addr as *const f32).add((bi * nh + hi) * q_s * k_s),
9475                                            q_s * k_s,
9476                                        );
9477                                        for i in 0..q_s * k_s {
9478                                            sc[i] += bias[i];
9479                                        }
9480                                    }
9481                                    _ => apply_synthetic_mask(sc, q_s, k_s, *mask_kind),
9482                                }
9483
9484                                crate::kernels::neon_softmax(sc, q_s, k_s);
9485
9486                                // LDB = vrs (parent row stride when
9487                                // fused; hs otherwise). LDC stays hs —
9488                                // output is its own contiguous buffer.
9489                                crate::blas::sgemm_general(
9490                                    sc.as_ptr(),
9491                                    v_start,
9492                                    o_start,
9493                                    q_s,
9494                                    dh,
9495                                    k_s,
9496                                    1.0,
9497                                    0.0,
9498                                    k_s,
9499                                    vrs,
9500                                    hs,
9501                                    false,
9502                                    false,
9503                                );
9504                            }
9505                        });
9506                    }
9507                }
9508            }
9509
9510            Thunk::AttentionBackward {
9511                q,
9512                k,
9513                v,
9514                dy,
9515                mask,
9516                out,
9517                batch,
9518                seq,
9519                kv_seq,
9520                heads,
9521                head_dim,
9522                mask_kind,
9523                wrt,
9524                bhsd,
9525            } => {
9526                let (b, q_s, k_s, nh, dh) = (
9527                    *batch as usize,
9528                    *seq as usize,
9529                    *kv_seq as usize,
9530                    *heads as usize,
9531                    *head_dim as usize,
9532                );
9533                unsafe {
9534                    let q_len = if *bhsd {
9535                        b * nh * q_s * dh
9536                    } else {
9537                        b * q_s * nh * dh
9538                    };
9539                    let k_len = if *bhsd {
9540                        b * nh * k_s * dh
9541                    } else {
9542                        b * k_s * nh * dh
9543                    };
9544                    let out_len = match wrt {
9545                        rlx_ir::op::AttentionBwdWrt::Key | rlx_ir::op::AttentionBwdWrt::Value => {
9546                            k_len
9547                        }
9548                        rlx_ir::op::AttentionBwdWrt::Query => q_len,
9549                    };
9550                    let q_data = sl(*q, base, q_len);
9551                    let k_data = sl(*k, base, k_len);
9552                    let v_data = sl(*v, base, k_len);
9553                    let dy_data = sl(*dy, base, q_len);
9554                    let out_data = sl_mut(*out, base, out_len);
9555                    let mask_data: &[f32] = if *mask != 0 {
9556                        let ml = match mask_kind {
9557                            rlx_ir::op::MaskKind::Custom => b * k_s,
9558                            rlx_ir::op::MaskKind::Bias => b * nh * q_s * k_s,
9559                            _ => 0,
9560                        };
9561                        sl(*mask, base, ml)
9562                    } else {
9563                        &[]
9564                    };
9565                    crate::attention_bwd::attention_backward(
9566                        *wrt, q_data, k_data, v_data, dy_data, out_data, b, nh, q_s, k_s, dh,
9567                        *mask_kind, mask_data, *bhsd,
9568                    );
9569                }
9570            }
9571
9572            Thunk::ActivationInPlace { data, len, act } => {
9573                let len = *len as usize;
9574                unsafe {
9575                    let d = sl_mut(*data, base, len);
9576                    match act {
9577                        Activation::Gelu => crate::kernels::par_gelu_inplace(d),
9578                        Activation::GeluApprox => crate::kernels::par_gelu_approx_inplace(d),
9579                        Activation::Silu => crate::kernels::par_silu_inplace(d),
9580                        Activation::Relu => {
9581                            for v in d.iter_mut() {
9582                                *v = v.max(0.0);
9583                            }
9584                        }
9585                        Activation::Sigmoid => {
9586                            for v in d.iter_mut() {
9587                                *v = 1.0 / (1.0 + (-*v).exp());
9588                            }
9589                        }
9590                        Activation::Tanh => {
9591                            for v in d.iter_mut() {
9592                                *v = v.tanh();
9593                            }
9594                        }
9595                        Activation::Exp => {
9596                            for v in d.iter_mut() {
9597                                *v = v.exp();
9598                            }
9599                        }
9600                        Activation::Log => {
9601                            for v in d.iter_mut() {
9602                                *v = v.ln();
9603                            }
9604                        }
9605                        Activation::Sqrt => {
9606                            for v in d.iter_mut() {
9607                                *v = v.sqrt();
9608                            }
9609                        }
9610                        Activation::Rsqrt => {
9611                            for v in d.iter_mut() {
9612                                *v = 1.0 / v.sqrt();
9613                            }
9614                        }
9615                        Activation::Neg => {
9616                            for v in d.iter_mut() {
9617                                *v = -*v;
9618                            }
9619                        }
9620                        Activation::Abs => {
9621                            for v in d.iter_mut() {
9622                                *v = v.abs();
9623                            }
9624                        }
9625                        Activation::Round => {
9626                            for v in d.iter_mut() {
9627                                *v = v.round();
9628                            }
9629                        }
9630                        Activation::Sin => {
9631                            for v in d.iter_mut() {
9632                                *v = v.sin();
9633                            }
9634                        }
9635                        Activation::Cos => {
9636                            for v in d.iter_mut() {
9637                                *v = v.cos();
9638                            }
9639                        }
9640                        Activation::Tan => {
9641                            for v in d.iter_mut() {
9642                                *v = v.tan();
9643                            }
9644                        }
9645                        Activation::Atan => {
9646                            for v in d.iter_mut() {
9647                                *v = v.atan();
9648                            }
9649                        }
9650                    }
9651                }
9652            }
9653
9654            Thunk::FusedAttnBlock {
9655                hidden,
9656                qkv_w,
9657                out_w,
9658                mask,
9659                out,
9660                qkv_b,
9661                out_b,
9662                cos,
9663                sin,
9664                cos_len,
9665                batch,
9666                seq,
9667                hs,
9668                nh,
9669                dh,
9670                has_bias,
9671                has_rope,
9672            } => {
9673                let (b, s) = (*batch as usize, *seq as usize);
9674                let (h, n_h, d_h) = (*hs as usize, *nh as usize, *dh as usize);
9675                let m = b * s;
9676                let scale = (d_h as f32).powf(-0.5);
9677                let half = d_h / 2;
9678                unsafe {
9679                    let inp = sl(*hidden, base, m * h);
9680                    let wq = sl(*qkv_w, base, h * 3 * h);
9681                    let wo = sl(*out_w, base, h * h);
9682                    let mk = sl(*mask, base, b * s);
9683                    let dst = sl_mut(*out, base, m * h);
9684
9685                    // Stack-allocated intermediates — all fit in L1 cache for small batch
9686                    let mut qkv = vec![0f32; m * 3 * h];
9687                    let mut attn_out = vec![0f32; m * h];
9688                    let mut scores_buf = vec![0f32; s * s]; // one head at a time
9689
9690                    // 1. QKV projection: [m, h] @ [h, 3h] → [m, 3h]
9691                    crate::blas::sgemm(inp, wq, &mut qkv, m, h, 3 * h);
9692                    if *has_bias {
9693                        let bias = sl(*qkv_b, base, 3 * h);
9694                        crate::blas::bias_add(&mut qkv, bias, m, 3 * h);
9695                    }
9696
9697                    // 2. Multi-head SDPA (Q/K/V are views into qkv at offsets 0, h, 2h)
9698                    //    Process heads sequentially with inline RoPE — zero copy.
9699                    #[cfg(target_arch = "aarch64")]
9700                    let neon_chunks = d_h / 4;
9701                    #[cfg(target_arch = "aarch64")]
9702                    let _rope_chunks = half / 4;
9703
9704                    for bi in 0..b {
9705                        for hi in 0..n_h {
9706                            // For each (query_pos, key_pos): compute Q@K^T with inline RoPE
9707                            for qi in 0..s {
9708                                let q_base = bi * s * 3 * h + qi * 3 * h + hi * d_h;
9709                                for ki in 0..s {
9710                                    let k_base = bi * s * 3 * h + ki * 3 * h + h + hi * d_h;
9711                                    let mut dot = 0f32;
9712
9713                                    if *has_rope {
9714                                        // Apply RoPE inline during dot product
9715                                        let q_cos = qi * half;
9716                                        let k_cos = ki * half;
9717                                        let cos_tab = sl(*cos, base, *cos_len as usize);
9718                                        let sin_tab = sl(*sin, base, *cos_len as usize);
9719                                        // First half: (q1*c - q2*s) * (k1*c - k2*s)
9720                                        // Second half: (q2*c + q1*s) * (k2*c + k1*s)
9721                                        for i in 0..half {
9722                                            let q1 = qkv[q_base + i];
9723                                            let q2 = qkv[q_base + half + i];
9724                                            let k1 = qkv[k_base + i];
9725                                            let k2 = qkv[k_base + half + i];
9726                                            let c_q = cos_tab[q_cos + i];
9727                                            let s_q = sin_tab[q_cos + i];
9728                                            let c_k = cos_tab[k_cos + i];
9729                                            let s_k = sin_tab[k_cos + i];
9730                                            let qr1 = q1 * c_q - q2 * s_q;
9731                                            let kr1 = k1 * c_k - k2 * s_k;
9732                                            let qr2 = q2 * c_q + q1 * s_q;
9733                                            let kr2 = k2 * c_k + k1 * s_k;
9734                                            dot += qr1 * kr1 + qr2 * kr2;
9735                                        }
9736                                    } else {
9737                                        // Standard dot product
9738                                        #[cfg(target_arch = "aarch64")]
9739                                        {
9740                                            use std::arch::aarch64::*;
9741                                            let mut acc = vdupq_n_f32(0.0);
9742                                            for c in 0..neon_chunks {
9743                                                let vq =
9744                                                    vld1q_f32(qkv.as_ptr().add(q_base + c * 4));
9745                                                let vk =
9746                                                    vld1q_f32(qkv.as_ptr().add(k_base + c * 4));
9747                                                acc = vfmaq_f32(acc, vq, vk);
9748                                            }
9749                                            dot = vaddvq_f32(acc);
9750                                            for d in (neon_chunks * 4)..d_h {
9751                                                dot += qkv[q_base + d] * qkv[k_base + d];
9752                                            }
9753                                        }
9754                                        #[cfg(not(target_arch = "aarch64"))]
9755                                        for d in 0..d_h {
9756                                            dot += qkv[q_base + d] * qkv[k_base + d];
9757                                        }
9758                                    }
9759
9760                                    scores_buf[qi * s + ki] = dot * scale;
9761                                    if mk[bi * s + ki] < mask_thr {
9762                                        scores_buf[qi * s + ki] = mask_neg;
9763                                    }
9764                                }
9765                            }
9766
9767                            // Softmax
9768                            crate::kernels::neon_softmax(&mut scores_buf[..s * s], s, s);
9769
9770                            // Score @ V accumulation (V at offset 2h in QKV)
9771                            for qi in 0..s {
9772                                let o_base = bi * s * h + qi * h + hi * d_h;
9773                                for d in 0..d_h {
9774                                    attn_out[o_base + d] = 0.0;
9775                                }
9776                                for ki in 0..s {
9777                                    let sc = scores_buf[qi * s + ki];
9778                                    if sc > score_thr {
9779                                        let v_base = bi * s * 3 * h + ki * 3 * h + 2 * h + hi * d_h;
9780                                        #[cfg(target_arch = "aarch64")]
9781                                        {
9782                                            use std::arch::aarch64::*;
9783                                            let vsc = vdupq_n_f32(sc);
9784                                            for c in 0..neon_chunks {
9785                                                let off = c * 4;
9786                                                let vo =
9787                                                    vld1q_f32(attn_out.as_ptr().add(o_base + off));
9788                                                let vv = vld1q_f32(qkv.as_ptr().add(v_base + off));
9789                                                vst1q_f32(
9790                                                    attn_out.as_mut_ptr().add(o_base + off),
9791                                                    vfmaq_f32(vo, vsc, vv),
9792                                                );
9793                                            }
9794                                        }
9795                                        #[cfg(not(target_arch = "aarch64"))]
9796                                        for d in 0..d_h {
9797                                            attn_out[o_base + d] += sc * qkv[v_base + d];
9798                                        }
9799                                    }
9800                                }
9801                            }
9802                        }
9803                    }
9804
9805                    // 3. Output projection: [m, h] @ [h, h] → dst
9806                    crate::blas::sgemm(&attn_out, wo, dst, m, h, h);
9807                    if *has_bias {
9808                        let bias = sl(*out_b, base, h);
9809                        crate::blas::bias_add(dst, bias, m, h);
9810                    }
9811                }
9812            }
9813
9814            Thunk::Rope {
9815                src,
9816                cos,
9817                sin,
9818                dst,
9819                batch,
9820                seq,
9821                hidden,
9822                head_dim,
9823                n_rot,
9824                cos_len,
9825                src_row_stride,
9826            } => {
9827                let (b, s, hs, dh, nr) = (
9828                    *batch as usize,
9829                    *seq as usize,
9830                    *hidden as usize,
9831                    *head_dim as usize,
9832                    *n_rot as usize,
9833                );
9834                let tab_half = dh / 2;
9835                let rot_half = nr / 2;
9836                let nh = hs / dh;
9837                let cl = *cos_len as usize;
9838                let src_rs = *src_row_stride as usize;
9839                unsafe {
9840                    let x = sl(*src, base, b * s * src_rs);
9841                    let cos_tab = sl(*cos, base, cl);
9842                    let sin_tab = sl(*sin, base, cl);
9843                    let out = sl_mut(*dst, base, b * s * hs);
9844
9845                    let total = b * s;
9846                    let x_ptr = x.as_ptr() as usize;
9847                    let o_ptr = out.as_mut_ptr() as usize;
9848                    let c_ptr = cos_tab.as_ptr() as usize;
9849                    let s_ptr = sin_tab.as_ptr() as usize;
9850
9851                    crate::pool::par_for(total, 4, &|off, cnt| {
9852                        for idx in off..off + cnt {
9853                            let bi = idx / s;
9854                            let si = idx % s;
9855                            let tab_off = si * tab_half;
9856
9857                            for hi in 0..nh {
9858                                let src_base = bi * s * src_rs + si * src_rs + hi * dh;
9859                                let dst_base = bi * s * hs + si * hs + hi * dh;
9860                                let xp = (x_ptr as *const f32).add(src_base);
9861                                let op = (o_ptr as *mut f32).add(dst_base);
9862                                let cp = (c_ptr as *const f32).add(tab_off);
9863                                let sp = (s_ptr as *const f32).add(tab_off);
9864
9865                                for i in 0..rot_half {
9866                                    let x1 = *xp.add(i);
9867                                    let x2 = *xp.add(rot_half + i);
9868                                    let cv = *cp.add(i);
9869                                    let sv = *sp.add(i);
9870                                    *op.add(i) = x1 * cv - x2 * sv;
9871                                    *op.add(rot_half + i) = x2 * cv + x1 * sv;
9872                                }
9873                                for j in nr..dh {
9874                                    *op.add(j) = *xp.add(j);
9875                                }
9876                            }
9877                        }
9878                    });
9879                }
9880            }
9881            Thunk::FusedBertLayer {
9882                hidden,
9883                qkv_w,
9884                qkv_b,
9885                out_w,
9886                out_b,
9887                mask,
9888                ln1_g,
9889                ln1_b,
9890                eps1,
9891                fc1_w,
9892                fc1_b,
9893                fc2_w,
9894                fc2_b,
9895                ln2_g,
9896                ln2_b,
9897                eps2,
9898                out,
9899                batch,
9900                seq,
9901                hs,
9902                nh,
9903                dh,
9904                int_dim,
9905            } => {
9906                let (b, s, h, n_h, d_h) = (
9907                    *batch as usize,
9908                    *seq as usize,
9909                    *hs as usize,
9910                    *nh as usize,
9911                    *dh as usize,
9912                );
9913                let m = b * s;
9914                let id = *int_dim as usize;
9915                let scale = (d_h as f32).powf(-0.5);
9916                let _half = d_h / 2;
9917                #[cfg(target_arch = "aarch64")]
9918                let neon_chunks = d_h / 4;
9919                unsafe {
9920                    let inp = sl(*hidden, base, m * h);
9921                    let dst = sl_mut(*out, base, m * h);
9922                    let mk = sl(*mask, base, b * s);
9923
9924                    // Pre-allocated buffers (zero malloc per layer — allocated once before thunk loop)
9925                    let qkv = std::slice::from_raw_parts_mut(fl_qkv.as_mut_ptr(), m * 3 * h);
9926                    let attn = std::slice::from_raw_parts_mut(fl_attn.as_mut_ptr(), m * h);
9927                    let res = std::slice::from_raw_parts_mut(fl_res.as_mut_ptr(), m * h);
9928                    let normed = std::slice::from_raw_parts_mut(fl_normed.as_mut_ptr(), m * h);
9929                    let ffn = std::slice::from_raw_parts_mut(fl_ffn.as_mut_ptr(), m * id);
9930                    let sc = std::slice::from_raw_parts_mut(fl_sc.as_mut_ptr(), s * s);
9931
9932                    // QKV (parallelized across cores — multiple AMX coprocessors)
9933                    crate::blas::par_sgemm_bias(
9934                        inp,
9935                        sl(*qkv_w, base, h * 3 * h),
9936                        sl(*qkv_b, base, 3 * h),
9937                        qkv,
9938                        m,
9939                        h,
9940                        3 * h,
9941                    );
9942
9943                    // SDPA per head (sequential NEON, inline — zero overhead)
9944                    for bi in 0..b {
9945                        for hi in 0..n_h {
9946                            for qi in 0..s {
9947                                for ki in 0..s {
9948                                    let q_base = bi * s * 3 * h + qi * 3 * h + hi * d_h;
9949                                    let k_base = bi * s * 3 * h + ki * 3 * h + h + hi * d_h;
9950                                    #[cfg(target_arch = "aarch64")]
9951                                    let dot;
9952                                    #[cfg(not(target_arch = "aarch64"))]
9953                                    let mut dot = 0f32;
9954                                    #[cfg(target_arch = "aarch64")]
9955                                    {
9956                                        use std::arch::aarch64::*;
9957                                        let mut acc = vdupq_n_f32(0.0);
9958                                        for c in 0..neon_chunks {
9959                                            acc = vfmaq_f32(
9960                                                acc,
9961                                                vld1q_f32(qkv.as_ptr().add(q_base + c * 4)),
9962                                                vld1q_f32(qkv.as_ptr().add(k_base + c * 4)),
9963                                            );
9964                                        }
9965                                        dot = vaddvq_f32(acc);
9966                                    }
9967                                    #[cfg(not(target_arch = "aarch64"))]
9968                                    for d in 0..d_h {
9969                                        dot += qkv[q_base + d] * qkv[k_base + d];
9970                                    }
9971                                    sc[qi * s + ki] = dot * scale;
9972                                    if mk[bi * s + ki] < mask_thr {
9973                                        sc[qi * s + ki] = mask_neg;
9974                                    }
9975                                }
9976                            }
9977                            crate::kernels::neon_softmax(&mut sc[..s * s], s, s);
9978                            for qi in 0..s {
9979                                let o = bi * s * h + qi * h + hi * d_h;
9980                                for d in 0..d_h {
9981                                    attn[o + d] = 0.0;
9982                                }
9983                                for ki in 0..s {
9984                                    let w = sc[qi * s + ki];
9985                                    if w > score_thr {
9986                                        let v = bi * s * 3 * h + ki * 3 * h + 2 * h + hi * d_h;
9987                                        #[cfg(target_arch = "aarch64")]
9988                                        {
9989                                            use std::arch::aarch64::*;
9990                                            let vw = vdupq_n_f32(w);
9991                                            for c in 0..neon_chunks {
9992                                                let off = c * 4;
9993                                                vst1q_f32(
9994                                                    attn.as_mut_ptr().add(o + off),
9995                                                    vfmaq_f32(
9996                                                        vld1q_f32(attn.as_ptr().add(o + off)),
9997                                                        vw,
9998                                                        vld1q_f32(qkv.as_ptr().add(v + off)),
9999                                                    ),
10000                                                );
10001                                            }
10002                                        }
10003                                        #[cfg(not(target_arch = "aarch64"))]
10004                                        for d in 0..d_h {
10005                                            attn[o + d] += w * qkv[v + d];
10006                                        }
10007                                    }
10008                                }
10009                            }
10010                        }
10011                    }
10012
10013                    // Out proj (sgemm + bias fused) + residual add with NEON
10014                    crate::blas::sgemm_bias(
10015                        attn,
10016                        sl(*out_w, base, h * h),
10017                        sl(*out_b, base, h),
10018                        res,
10019                        m,
10020                        h,
10021                        h,
10022                    );
10023                    #[cfg(target_arch = "aarch64")]
10024                    {
10025                        use std::arch::aarch64::*;
10026                        let chunks_h = (m * h) / 4;
10027                        for c in 0..chunks_h {
10028                            let off = c * 4;
10029                            vst1q_f32(
10030                                res.as_mut_ptr().add(off),
10031                                vaddq_f32(
10032                                    vld1q_f32(res.as_ptr().add(off)),
10033                                    vld1q_f32(inp.as_ptr().add(off)),
10034                                ),
10035                            );
10036                        }
10037                        for i in (chunks_h * 4)..(m * h) {
10038                            res[i] += inp[i];
10039                        }
10040                    }
10041                    #[cfg(not(target_arch = "aarch64"))]
10042                    for i in 0..m * h {
10043                        res[i] += inp[i];
10044                    }
10045
10046                    // LN1 (fused residual already done above — just normalize)
10047                    let g1 = sl(*ln1_g, base, h);
10048                    let b1 = sl(*ln1_b, base, h);
10049                    for r in 0..m {
10050                        crate::kernels::layer_norm_row(
10051                            &res[r * h..(r + 1) * h],
10052                            g1,
10053                            b1,
10054                            &mut normed[r * h..(r + 1) * h],
10055                            h,
10056                            *eps1,
10057                        );
10058                    }
10059
10060                    // FFN: fc1 (parallel across cores) + GELU
10061                    crate::blas::par_sgemm_bias(
10062                        normed,
10063                        sl(*fc1_w, base, h * id),
10064                        sl(*fc1_b, base, id),
10065                        ffn,
10066                        m,
10067                        h,
10068                        id,
10069                    );
10070                    crate::kernels::par_gelu_inplace(ffn);
10071
10072                    // fc2 + bias (parallel across cores) + residual with NEON
10073                    crate::blas::par_sgemm_bias(
10074                        ffn,
10075                        sl(*fc2_w, base, id * h),
10076                        sl(*fc2_b, base, h),
10077                        res,
10078                        m,
10079                        id,
10080                        h,
10081                    );
10082                    #[cfg(target_arch = "aarch64")]
10083                    {
10084                        use std::arch::aarch64::*;
10085                        let chunks_h = (m * h) / 4;
10086                        for c in 0..chunks_h {
10087                            let off = c * 4;
10088                            vst1q_f32(
10089                                res.as_mut_ptr().add(off),
10090                                vaddq_f32(
10091                                    vld1q_f32(res.as_ptr().add(off)),
10092                                    vld1q_f32(normed.as_ptr().add(off)),
10093                                ),
10094                            );
10095                        }
10096                        for i in (chunks_h * 4)..(m * h) {
10097                            res[i] += normed[i];
10098                        }
10099                    }
10100                    #[cfg(not(target_arch = "aarch64"))]
10101                    for i in 0..m * h {
10102                        res[i] += normed[i];
10103                    }
10104
10105                    // LN2 → output
10106                    let g2 = sl(*ln2_g, base, h);
10107                    let b2 = sl(*ln2_b, base, h);
10108                    for r in 0..m {
10109                        crate::kernels::layer_norm_row(
10110                            &res[r * h..(r + 1) * h],
10111                            g2,
10112                            b2,
10113                            &mut dst[r * h..(r + 1) * h],
10114                            h,
10115                            *eps2,
10116                        );
10117                    }
10118                }
10119            }
10120
10121            Thunk::FusedNomicLayer {
10122                hidden,
10123                qkv_w,
10124                out_w,
10125                mask,
10126                cos,
10127                sin,
10128                cos_len,
10129                ln1_g,
10130                ln1_b,
10131                eps1,
10132                fc11_w,
10133                fc12_w: _,
10134                fc2_w,
10135                ln2_g,
10136                ln2_b,
10137                eps2,
10138                out,
10139                batch,
10140                seq,
10141                hs,
10142                nh,
10143                dh,
10144                int_dim,
10145            } => {
10146                let (b, s, h, n_h, d_h) = (
10147                    *batch as usize,
10148                    *seq as usize,
10149                    *hs as usize,
10150                    *nh as usize,
10151                    *dh as usize,
10152                );
10153                let m = b * s;
10154                let id = *int_dim as usize;
10155                let scale = (d_h as f32).powf(-0.5);
10156                let half_dh = d_h / 2;
10157                #[cfg(target_arch = "aarch64")]
10158                let neon_chunks = d_h / 4;
10159                unsafe {
10160                    let inp = sl(*hidden, base, m * h);
10161                    let dst = sl_mut(*out, base, m * h);
10162                    let mk = sl(*mask, base, b * s);
10163                    let cos_tab = sl(*cos, base, *cos_len as usize);
10164                    let sin_tab = sl(*sin, base, *cos_len as usize);
10165                    // fc11_w is the fused [h, 2*int_dim] weight (fc11 || fc12 concatenated)
10166                    let fused_fc_w = sl(*fc11_w, base, h * 2 * id);
10167
10168                    let mut qkv = vec![0f32; m * 3 * h];
10169                    let mut attn = vec![0f32; m * h];
10170                    let mut res = vec![0f32; m * h];
10171                    let mut normed = vec![0f32; m * h];
10172                    let mut ffn_concat = vec![0f32; m * 2 * id]; // fc11||fc12 output
10173                    let mut sc = vec![0f32; s * s];
10174
10175                    // QKV (no bias)
10176                    crate::blas::sgemm(inp, sl(*qkv_w, base, h * 3 * h), &mut qkv, m, h, 3 * h);
10177
10178                    // SDPA with inline RoPE
10179                    for bi in 0..b {
10180                        for hi in 0..n_h {
10181                            for qi in 0..s {
10182                                for ki in 0..s {
10183                                    let q_base = bi * s * 3 * h + qi * 3 * h + hi * d_h;
10184                                    let k_base = bi * s * 3 * h + ki * 3 * h + h + hi * d_h;
10185                                    let mut dot = 0f32;
10186                                    for i in 0..half_dh {
10187                                        let q1 = qkv[q_base + i];
10188                                        let q2 = qkv[q_base + half_dh + i];
10189                                        let k1 = qkv[k_base + i];
10190                                        let k2 = qkv[k_base + half_dh + i];
10191                                        let cq = cos_tab[qi * half_dh + i];
10192                                        let sq = sin_tab[qi * half_dh + i];
10193                                        let ck = cos_tab[ki * half_dh + i];
10194                                        let sk = sin_tab[ki * half_dh + i];
10195                                        dot += (q1 * cq - q2 * sq) * (k1 * ck - k2 * sk)
10196                                            + (q2 * cq + q1 * sq) * (k2 * ck + k1 * sk);
10197                                    }
10198                                    sc[qi * s + ki] = dot * scale;
10199                                    if mk[bi * s + ki] < mask_thr {
10200                                        sc[qi * s + ki] = mask_neg;
10201                                    }
10202                                }
10203                            }
10204                            crate::kernels::neon_softmax(&mut sc[..s * s], s, s);
10205                            for qi in 0..s {
10206                                let o = bi * s * h + qi * h + hi * d_h;
10207                                for d in 0..d_h {
10208                                    attn[o + d] = 0.0;
10209                                }
10210                                for ki in 0..s {
10211                                    let w = sc[qi * s + ki];
10212                                    if w > score_thr {
10213                                        let v = bi * s * 3 * h + ki * 3 * h + 2 * h + hi * d_h;
10214                                        #[cfg(target_arch = "aarch64")]
10215                                        {
10216                                            use std::arch::aarch64::*;
10217                                            let vw = vdupq_n_f32(w);
10218                                            for c in 0..neon_chunks {
10219                                                let off = c * 4;
10220                                                vst1q_f32(
10221                                                    attn.as_mut_ptr().add(o + off),
10222                                                    vfmaq_f32(
10223                                                        vld1q_f32(attn.as_ptr().add(o + off)),
10224                                                        vw,
10225                                                        vld1q_f32(qkv.as_ptr().add(v + off)),
10226                                                    ),
10227                                                );
10228                                            }
10229                                        }
10230                                        #[cfg(not(target_arch = "aarch64"))]
10231                                        for d in 0..d_h {
10232                                            attn[o + d] += w * qkv[v + d];
10233                                        }
10234                                    }
10235                                }
10236                            }
10237                        }
10238                    }
10239
10240                    // Out proj (no bias) + residual
10241                    crate::blas::sgemm(&attn, sl(*out_w, base, h * h), &mut res, m, h, h);
10242                    for i in 0..m * h {
10243                        res[i] += inp[i];
10244                    }
10245
10246                    // LN1
10247                    let g1 = sl(*ln1_g, base, h);
10248                    let b1 = sl(*ln1_b, base, h);
10249                    for r in 0..m {
10250                        crate::kernels::layer_norm_row(
10251                            &res[r * h..(r + 1) * h],
10252                            g1,
10253                            b1,
10254                            &mut normed[r * h..(r + 1) * h],
10255                            h,
10256                            *eps1,
10257                        );
10258                    }
10259
10260                    // SwiGLU: fused fc11+fc12 sgemm, then split, silu, mul
10261                    crate::blas::sgemm(&normed, fused_fc_w, &mut ffn_concat, m, h, 2 * id);
10262                    // Split: first id cols = fc11 (up), second id cols = fc12 (gate)
10263                    // SiLU on gate, then multiply up * gate → store in up region
10264                    for row in 0..m {
10265                        let bo = row * 2 * id;
10266                        // SiLU in-place on gate portion
10267                        for j in 0..id {
10268                            let x = ffn_concat[bo + id + j];
10269                            ffn_concat[bo + id + j] = x / (1.0 + (-x).exp());
10270                        }
10271                        // Multiply: up[j] *= gate[j]
10272                        for j in 0..id {
10273                            ffn_concat[bo + j] *= ffn_concat[bo + id + j];
10274                        }
10275                    }
10276
10277                    // fc2 (no bias) + residual  — read from first id cols of ffn_concat
10278                    // Need contiguous [m, id] for sgemm. Copy or use strided sgemm.
10279                    // The up*gate result is at ffn_concat[row * 2*id .. row * 2*id + id]
10280                    // Stride = 2*id. Use sgemm_general with lda = 2*id.
10281                    crate::blas::sgemm_general(
10282                        ffn_concat.as_ptr(),
10283                        sl(*fc2_w, base, id * h).as_ptr(),
10284                        res.as_mut_ptr(),
10285                        m,
10286                        h,
10287                        id,
10288                        1.0,
10289                        0.0,
10290                        2 * id,
10291                        h,
10292                        h,
10293                        false,
10294                        false,
10295                    );
10296                    for i in 0..m * h {
10297                        res[i] += normed[i];
10298                    }
10299
10300                    // LN2 → output
10301                    let g2 = sl(*ln2_g, base, h);
10302                    let b2 = sl(*ln2_b, base, h);
10303                    for r in 0..m {
10304                        crate::kernels::layer_norm_row(
10305                            &res[r * h..(r + 1) * h],
10306                            g2,
10307                            b2,
10308                            &mut dst[r * h..(r + 1) * h],
10309                            h,
10310                            *eps2,
10311                        );
10312                    }
10313                }
10314            }
10315
10316            Thunk::FusedSwiGLU {
10317                src,
10318                dst,
10319                n_half,
10320                total,
10321                gate_first,
10322            } => {
10323                let n = *n_half as usize;
10324                let t = *total as usize;
10325                let outer = t / n;
10326                let in_total = outer * 2 * n;
10327                let gate_first = *gate_first;
10328                unsafe {
10329                    let inp = sl(*src, base, in_total);
10330                    let out = sl_mut(*dst, base, t);
10331                    for o in 0..outer {
10332                        let in_row = &inp[o * 2 * n..(o + 1) * 2 * n];
10333                        let out_row = &mut out[o * n..(o + 1) * n];
10334                        for i in 0..n {
10335                            let (up, gate) = if gate_first {
10336                                (in_row[n + i], in_row[i])
10337                            } else {
10338                                (in_row[i], in_row[n + i])
10339                            };
10340                            out_row[i] = up * (gate / (1.0 + (-gate).exp()));
10341                        }
10342                    }
10343                }
10344            }
10345
10346            Thunk::Concat {
10347                dst,
10348                outer,
10349                inner,
10350                total_axis,
10351                inputs,
10352            } => {
10353                let outer = *outer as usize;
10354                let inner = *inner as usize;
10355                let total_axis = *total_axis as usize;
10356                let row_stride = total_axis * inner;
10357                let out_total = outer * row_stride;
10358                unsafe {
10359                    let out = sl_mut(*dst, base, out_total);
10360                    let mut cum: usize = 0;
10361                    for (src_off, in_axis) in inputs {
10362                        let in_axis = *in_axis as usize;
10363                        let copy_per_row = in_axis * inner;
10364                        let dst_col_off = cum * inner;
10365                        let in_total = outer * copy_per_row;
10366                        let inp = sl(*src_off, base, in_total);
10367                        for o in 0..outer {
10368                            let dst_row_start = o * row_stride + dst_col_off;
10369                            let src_row_start = o * copy_per_row;
10370                            out[dst_row_start..dst_row_start + copy_per_row]
10371                                .copy_from_slice(&inp[src_row_start..src_row_start + copy_per_row]);
10372                        }
10373                        cum += in_axis;
10374                    }
10375                }
10376            }
10377
10378            Thunk::ConcatF64 {
10379                dst,
10380                outer,
10381                inner,
10382                total_axis,
10383                inputs,
10384            } => {
10385                let outer = *outer as usize;
10386                let inner = *inner as usize;
10387                let total_axis = *total_axis as usize;
10388                let row_stride = total_axis * inner;
10389                let out_total = outer * row_stride;
10390                unsafe {
10391                    let out = sl_mut_f64(*dst, base, out_total);
10392                    let mut cum: usize = 0;
10393                    for (src_off, in_axis) in inputs {
10394                        let in_axis = *in_axis as usize;
10395                        let copy_per_row = in_axis * inner;
10396                        let dst_col_off = cum * inner;
10397                        let in_total = outer * copy_per_row;
10398                        let inp = sl_f64(*src_off, base, in_total);
10399                        for o in 0..outer {
10400                            let dst_row_start = o * row_stride + dst_col_off;
10401                            let src_row_start = o * copy_per_row;
10402                            out[dst_row_start..dst_row_start + copy_per_row]
10403                                .copy_from_slice(&inp[src_row_start..src_row_start + copy_per_row]);
10404                        }
10405                        cum += in_axis;
10406                    }
10407                }
10408            }
10409
10410            Thunk::Compare {
10411                lhs,
10412                rhs,
10413                dst,
10414                len,
10415                op,
10416            } => {
10417                let len = *len as usize;
10418                unsafe {
10419                    let l = sl(*lhs, base, len);
10420                    let r = sl(*rhs, base, len);
10421                    let o = sl_mut(*dst, base, len);
10422                    for i in 0..len {
10423                        o[i] = match op {
10424                            CmpOp::Eq => (l[i] == r[i]) as u32 as f32,
10425                            CmpOp::Ne => (l[i] != r[i]) as u32 as f32,
10426                            CmpOp::Lt => (l[i] < r[i]) as u32 as f32,
10427                            CmpOp::Le => (l[i] <= r[i]) as u32 as f32,
10428                            CmpOp::Gt => (l[i] > r[i]) as u32 as f32,
10429                            CmpOp::Ge => (l[i] >= r[i]) as u32 as f32,
10430                        };
10431                    }
10432                }
10433            }
10434
10435            Thunk::Where {
10436                cond,
10437                on_true,
10438                on_false,
10439                dst,
10440                len,
10441            } => {
10442                let len = *len as usize;
10443                unsafe {
10444                    let c = sl(*cond, base, len);
10445                    let t = sl(*on_true, base, len);
10446                    let e = sl(*on_false, base, len);
10447                    let o = sl_mut(*dst, base, len);
10448                    for i in 0..len {
10449                        // Treat cond as boolean: any non-zero → true.
10450                        o[i] = if c[i] != 0.0 { t[i] } else { e[i] };
10451                    }
10452                }
10453            }
10454
10455            Thunk::ScatterAdd {
10456                updates,
10457                indices,
10458                dst,
10459                num_updates,
10460                out_dim,
10461                trailing,
10462            } => {
10463                let num_updates = *num_updates as usize;
10464                let out_dim = *out_dim as usize;
10465                let trailing = *trailing as usize;
10466                unsafe {
10467                    let upd = sl(*updates, base, num_updates * trailing);
10468                    let ids = sl(*indices, base, num_updates);
10469                    let out = sl_mut(*dst, base, out_dim * trailing);
10470                    // Zero the output first — semantics are accumulate-into-zeros.
10471                    for v in out.iter_mut() {
10472                        *v = 0.0;
10473                    }
10474                    for i in 0..num_updates {
10475                        let row = ids[i] as usize;
10476                        debug_assert!(row < out_dim, "ScatterAdd index out of range");
10477                        let src_off = i * trailing;
10478                        let dst_off = row * trailing;
10479                        for j in 0..trailing {
10480                            out[dst_off + j] += upd[src_off + j];
10481                        }
10482                    }
10483                }
10484            }
10485
            Thunk::GroupedMatMul {
                input,
                weight,
                expert_idx,
                dst,
                m,
                k_dim,
                n,
                num_experts,
            } => {
                // Per-token expert matmul: out[i, :] = input[i, :] @ W[expert_idx[i]].
                // Tokens are grouped by expert so each expert runs one sgemm over a
                // contiguous block of rows. The results are then scattered back into
                // the original token order.
                let m = *m as usize;
                let k_dim = *k_dim as usize;
                let n = *n as usize;
                let num_experts = *num_experts as usize;
                unsafe {
                    let inp = sl(*input, base, m * k_dim);
                    // All experts' weights sit back-to-back in the arena, each slab
                    // `k_dim * n` floats long.
                    let wt = sl(*weight, base, num_experts * k_dim * n);
                    let ids = sl(*expert_idx, base, m);
                    let out = sl_mut(*dst, base, m * n);

                    // Counting-sort tokens by their assigned expert.
                    // counts[e] = how many tokens routed to expert e.
                    // NOTE(review): `as usize` on the f32 id saturates, so negative/NaN
                    // ids map to expert 0; out-of-range ids panic on `counts[e]` in
                    // release builds and trip the debug_assert in debug builds.
                    let mut counts = vec![0usize; num_experts];
                    for i in 0..m {
                        let e = ids[i] as usize;
                        debug_assert!(
                            e < num_experts,
                            "expert_idx out of range: {e} >= {num_experts}"
                        );
                        counts[e] += 1;
                    }
                    // Cumulative offsets into the packed buffer.
                    let mut offsets = vec![0usize; num_experts + 1];
                    for e in 0..num_experts {
                        offsets[e + 1] = offsets[e] + counts[e];
                    }
                    // Pack: each expert's rows land contiguously in `packed_in`.
                    // `original_pos[packed_idx] = original_token_idx` for the
                    // unpermute step at the end.
                    let mut packed_in = vec![0f32; m * k_dim];
                    let mut original_pos = vec![0usize; m];
                    // Next free row within each expert's packed block.
                    let mut write_idx = vec![0usize; num_experts];
                    for i in 0..m {
                        let e = ids[i] as usize;
                        let dst_row = offsets[e] + write_idx[e];
                        packed_in[dst_row * k_dim..(dst_row + 1) * k_dim]
                            .copy_from_slice(&inp[i * k_dim..(i + 1) * k_dim]);
                        original_pos[dst_row] = i;
                        write_idx[e] += 1;
                    }

                    // One BLAS sgemm per expert. Skip experts with no
                    // tokens — common at the tail when M is much smaller
                    // than num_experts × k.
                    let mut packed_out = vec![0f32; m * n];
                    let expert_stride = k_dim * n;
                    // A global grouped-matmul ordinal is drawn exactly once per
                    // execution of this thunk (even when m == 0). It keys the host
                    // weight lookup below.
                    // NOTE(review): dividing by 3 presumably assumes three grouped
                    // matmuls per MoE layer; confirm against moe_residency.
                    let gmm_ord = crate::moe_residency::next_gmm_ord();
                    let moe_layer = gmm_ord / 3;
                    for e in 0..num_experts {
                        let count = counts[e];
                        if count == 0 {
                            continue;
                        }
                        // Routing statistics are recorded only for experts that
                        // received tokens, in ascending expert order.
                        crate::moe_residency::record_expert_tokens(moe_layer, e, count);
                        let in_start = offsets[e];
                        let in_slice = &packed_in[in_start * k_dim..(in_start + count) * k_dim];
                        // Weight source: when the residency tracker reports the expert
                        // as not on device and a host pointer is available, read the
                        // slab from that host pointer. The pointer must be valid for
                        // `expert_stride` f32s. Otherwise use the arena copy.
                        let w_slab: &[f32] =
                            if !crate::moe_residency::expert_on_device_for_layer(moe_layer, e) {
                                if let Some(ptr) =
                                    crate::moe_residency::host_expert_weight_ptr(gmm_ord, e)
                                {
                                    std::slice::from_raw_parts(ptr, expert_stride)
                                } else {
                                    &wt[e * expert_stride..(e + 1) * expert_stride]
                                }
                            } else {
                                &wt[e * expert_stride..(e + 1) * expert_stride]
                            };
                        let out_slice = &mut packed_out[in_start * n..(in_start + count) * n];
                        crate::blas::sgemm(in_slice, w_slab, out_slice, count, k_dim, n);
                    }

                    // Unpermute back to original token order.
                    for packed_idx in 0..m {
                        let i = original_pos[packed_idx];
                        out[i * n..(i + 1) * n]
                            .copy_from_slice(&packed_out[packed_idx * n..(packed_idx + 1) * n]);
                    }
                }
            }
10576
10577            Thunk::DequantGroupedMatMulGguf {
10578                input,
10579                w_q,
10580                expert_idx,
10581                dst,
10582                m,
10583                k_dim,
10584                n,
10585                num_experts,
10586                scheme,
10587            } => {
10588                let m = *m as usize;
10589                let k_dim = *k_dim as usize;
10590                let n = *n as usize;
10591                let num_experts = *num_experts as usize;
10592                let block_elems = scheme.gguf_block_size() as usize;
10593                let block_bytes = scheme.gguf_block_bytes() as usize;
10594                let slab_bytes = (k_dim * n) / block_elems * block_bytes;
10595                unsafe {
10596                    let inp = sl(*input, base, m * k_dim);
10597                    let wt = std::slice::from_raw_parts(
10598                        base.add(*w_q) as *const u8,
10599                        num_experts * slab_bytes,
10600                    );
10601                    let ids = sl(*expert_idx, base, m);
10602                    let out = sl_mut(*dst, base, m * n);
10603                    crate::gguf_matmul::gguf_grouped_matmul_bt(
10604                        inp,
10605                        wt,
10606                        ids,
10607                        out,
10608                        m,
10609                        k_dim,
10610                        n,
10611                        num_experts,
10612                        *scheme,
10613                    );
10614                }
10615            }
10616
10617            Thunk::DequantMoEWeightsGguf {
10618                w_q,
10619                dst,
10620                k_dim,
10621                n,
10622                num_experts,
10623                scheme,
10624            } => {
10625                let k_dim = *k_dim as usize;
10626                let n = *n as usize;
10627                let num_experts = *num_experts as usize;
10628                let block_elems = scheme.gguf_block_size() as usize;
10629                let block_bytes = scheme.gguf_block_bytes() as usize;
10630                let slab_bytes = (k_dim * n) / block_elems * block_bytes;
10631                unsafe {
10632                    let wt = std::slice::from_raw_parts(
10633                        base.add(*w_q) as *const u8,
10634                        num_experts * slab_bytes,
10635                    );
10636                    let out = sl_mut(*dst, base, num_experts * k_dim * n);
10637                    crate::gguf_matmul::dequant_moe_weights_to_grouped_f32(
10638                        wt,
10639                        out,
10640                        num_experts,
10641                        k_dim,
10642                        n,
10643                        *scheme,
10644                    );
10645                }
10646            }
10647
10648            Thunk::TopK {
10649                src,
10650                dst,
10651                outer,
10652                axis_dim,
10653                k,
10654            } => {
10655                let outer = *outer as usize;
10656                let axis_dim = *axis_dim as usize;
10657                let k = *k as usize;
10658                unsafe {
10659                    let inp = sl(*src, base, outer * axis_dim);
10660                    let out = sl_mut(*dst, base, outer * k);
10661                    // Repeated argmax with masking. O(k * axis_dim) per row;
10662                    // good enough for small k (MoE typical k=2–8). For larger
10663                    // k a partial heap would win.
10664                    let mut row_buf: Vec<f32> = vec![0.0; axis_dim];
10665                    for o in 0..outer {
10666                        row_buf.copy_from_slice(&inp[o * axis_dim..(o + 1) * axis_dim]);
10667                        for ki in 0..k {
10668                            // Find argmax with tie-break to smaller index.
10669                            let mut best_i = 0usize;
10670                            let mut best_v = row_buf[0];
10671                            for i in 1..axis_dim {
10672                                let v = row_buf[i];
10673                                if v > best_v {
10674                                    best_v = v;
10675                                    best_i = i;
10676                                }
10677                            }
10678                            out[o * k + ki] = best_i as f32;
10679                            // Mask the chosen index so the next pass picks
10680                            // the next-largest instead.
10681                            row_buf[best_i] = f32::NEG_INFINITY;
10682                        }
10683                    }
10684                    if let Some(cap) = schedule.moe_topk_capture.as_ref() {
10685                        cap.push_topk_f32(&out[..outer * k], axis_dim);
10686                    }
10687                }
10688            }
10689
10690            Thunk::Reduce {
10691                src,
10692                dst,
10693                outer,
10694                reduced,
10695                inner,
10696                op,
10697            } => {
10698                let outer = *outer as usize;
10699                let reduced = *reduced as usize;
10700                let inner = *inner as usize;
10701                let in_total = outer * reduced * inner;
10702                let out_total = outer * inner;
10703                unsafe {
10704                    let inp = sl(*src, base, in_total);
10705                    let out = sl_mut(*dst, base, out_total);
10706                    for o in 0..outer {
10707                        for i in 0..inner {
10708                            let mut acc = match op {
10709                                ReduceOp::Max => f32::NEG_INFINITY,
10710                                ReduceOp::Min => f32::INFINITY,
10711                                ReduceOp::Prod => 1.0f32,
10712                                _ => 0.0f32, // Sum / Mean
10713                            };
10714                            // Walk the reduced axis with stride `inner`.
10715                            for r in 0..reduced {
10716                                let v = inp[o * reduced * inner + r * inner + i];
10717                                acc = match op {
10718                                    ReduceOp::Sum | ReduceOp::Mean => acc + v,
10719                                    ReduceOp::Max => acc.max(v),
10720                                    ReduceOp::Min => acc.min(v),
10721                                    ReduceOp::Prod => acc * v,
10722                                };
10723                            }
10724                            if matches!(op, ReduceOp::Mean) {
10725                                acc /= reduced as f32;
10726                            }
10727                            out[o * inner + i] = acc;
10728                        }
10729                    }
10730                }
10731            }
10732
10733            Thunk::Conv2D1x1 {
10734                src,
10735                weight,
10736                dst,
10737                n,
10738                c_in,
10739                c_out,
10740                hw,
10741            } => {
10742                let n = *n as usize;
10743                let c_in = *c_in as usize;
10744                let c_out = *c_out as usize;
10745                let hw = *hw as usize;
10746                unsafe {
10747                    let inp = sl(*src, base, n * c_in * hw);
10748                    let wt = sl(*weight, base, c_out * c_in);
10749                    let out = sl_mut(*dst, base, n * c_out * hw);
10750                    // Per-batch sgemm: weight [c_out, c_in] @ input
10751                    // [c_in, hw] = output [c_out, hw]. The weight is
10752                    // shared across batches, so we get to dispatch
10753                    // BLAS once per N (typically 1).
10754                    for ni in 0..n {
10755                        let in_off = ni * c_in * hw;
10756                        let out_off = ni * c_out * hw;
10757                        crate::blas::sgemm(
10758                            wt,
10759                            &inp[in_off..in_off + c_in * hw],
10760                            &mut out[out_off..out_off + c_out * hw],
10761                            c_out,
10762                            c_in,
10763                            hw,
10764                        );
10765                    }
10766                }
10767            }
10768
10769            Thunk::Conv2D {
10770                src,
10771                weight,
10772                dst,
10773                n,
10774                c_in,
10775                h,
10776                w,
10777                c_out,
10778                h_out,
10779                w_out,
10780                kh,
10781                kw,
10782                sh,
10783                sw,
10784                ph,
10785                pw,
10786                dh,
10787                dw,
10788                groups,
10789            } => {
10790                let n = *n as usize;
10791                let c_in = *c_in as usize;
10792                let h = *h as usize;
10793                let w = *w as usize;
10794                let c_out = *c_out as usize;
10795                let h_out = *h_out as usize;
10796                let w_out = *w_out as usize;
10797                let kh = *kh as usize;
10798                let kw = *kw as usize;
10799                let sh = *sh as usize;
10800                let sw = *sw as usize;
10801                let ph = *ph as usize;
10802                let pw = *pw as usize;
10803                let dh = *dh as usize;
10804                let dw = *dw as usize;
10805                let groups = *groups as usize;
10806                let c_in_per_g = c_in / groups;
10807                let c_out_per_g = c_out / groups;
10808                unsafe {
10809                    let inp = sl(*src, base, n * c_in * h * w);
10810                    let wt = sl(*weight, base, c_out * c_in_per_g * kh * kw);
10811                    let out = sl_mut(*dst, base, n * c_out * h_out * w_out);
10812                    for ni in 0..n {
10813                        for co in 0..c_out {
10814                            let g = co / c_out_per_g;
10815                            let ci_start = g * c_in_per_g;
10816                            for ho in 0..h_out {
10817                                for wo in 0..w_out {
10818                                    let mut acc = 0f32;
10819                                    for ci_off in 0..c_in_per_g {
10820                                        let ci = ci_start + ci_off;
10821                                        let in_chan = ((ni * c_in) + ci) * h * w;
10822                                        let wt_chan = ((co * c_in_per_g) + ci_off) * kh * kw;
10823                                        for ki in 0..kh {
10824                                            for kj in 0..kw {
10825                                                let hi = ho * sh + ki * dh;
10826                                                let wi = wo * sw + kj * dw;
10827                                                if hi < ph || wi < pw {
10828                                                    continue;
10829                                                }
10830                                                let hi = hi - ph;
10831                                                let wi = wi - pw;
10832                                                if hi >= h || wi >= w {
10833                                                    continue;
10834                                                }
10835                                                acc += inp[in_chan + hi * w + wi]
10836                                                    * wt[wt_chan + ki * kw + kj];
10837                                            }
10838                                        }
10839                                    }
10840                                    out[((ni * c_out) + co) * h_out * w_out + ho * w_out + wo] =
10841                                        acc;
10842                                }
10843                            }
10844                        }
10845                    }
10846                }
10847            }
10848
10849            Thunk::Pool2D {
10850                src,
10851                dst,
10852                n,
10853                c,
10854                h,
10855                w,
10856                h_out,
10857                w_out,
10858                kh,
10859                kw,
10860                sh,
10861                sw,
10862                ph,
10863                pw,
10864                kind,
10865            } => {
10866                let n = *n as usize;
10867                let c = *c as usize;
10868                let h = *h as usize;
10869                let w = *w as usize;
10870                let h_out = *h_out as usize;
10871                let w_out = *w_out as usize;
10872                let kh = *kh as usize;
10873                let kw = *kw as usize;
10874                let sh = *sh as usize;
10875                let sw = *sw as usize;
10876                let ph = *ph as usize;
10877                let pw = *pw as usize;
10878                let kernel_area = (kh * kw) as f32;
10879                unsafe {
10880                    let inp = sl(*src, base, n * c * h * w);
10881                    let out = sl_mut(*dst, base, n * c * h_out * w_out);
10882                    for ni in 0..n {
10883                        for ci in 0..c {
10884                            let in_chan = ni * c * h * w + ci * h * w;
10885                            let out_chan = ni * c * h_out * w_out + ci * h_out * w_out;
10886                            for ho in 0..h_out {
10887                                for wo in 0..w_out {
10888                                    let mut acc = match kind {
10889                                        ReduceOp::Max => f32::NEG_INFINITY,
10890                                        _ => 0f32, // Mean (and Sum/Min/Prod fall back here)
10891                                    };
10892                                    for ki in 0..kh {
10893                                        for kj in 0..kw {
10894                                            let hi = ho * sh + ki;
10895                                            let wi = wo * sw + kj;
10896                                            // Padded-zero region.
10897                                            if hi < ph || wi < pw {
10898                                                continue;
10899                                            }
10900                                            let hi = hi - ph;
10901                                            let wi = wi - pw;
10902                                            if hi >= h || wi >= w {
10903                                                continue;
10904                                            }
10905                                            let v = inp[in_chan + hi * w + wi];
10906                                            match kind {
10907                                                ReduceOp::Max => acc = acc.max(v),
10908                                                _ => acc += v,
10909                                            }
10910                                        }
10911                                    }
10912                                    if matches!(kind, ReduceOp::Mean) {
10913                                        acc /= kernel_area;
10914                                    }
10915                                    out[out_chan + ho * w_out + wo] = acc;
10916                                }
10917                            }
10918                        }
10919                    }
10920                }
10921            }
10922
10923            Thunk::ReluBackward { x, dy, dx, len } => {
10924                let len = *len as usize;
10925                unsafe {
10926                    let xs = sl(*x, base, len);
10927                    let dys = sl(*dy, base, len);
10928                    let out = sl_mut(*dx, base, len);
10929                    for i in 0..len {
10930                        out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
10931                    }
10932                }
10933            }
10934
10935            Thunk::ReluBackwardF64 { x, dy, dx, len } => {
10936                let len = *len as usize;
10937                unsafe {
10938                    let xs = sl_f64(*x, base, len);
10939                    let dys = sl_f64(*dy, base, len);
10940                    let out = sl_mut_f64(*dx, base, len);
10941                    for i in 0..len {
10942                        out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
10943                    }
10944                }
10945            }
10946
10947            Thunk::QMatMul {
10948                x,
10949                w,
10950                bias,
10951                out,
10952                m,
10953                k,
10954                n,
10955                x_zp,
10956                w_zp,
10957                out_zp,
10958                mult,
10959            } => {
10960                let m = *m as usize;
10961                let k = *k as usize;
10962                let n = *n as usize;
10963                unsafe {
10964                    let x_ptr = base.add(*x) as *const i8;
10965                    let w_ptr = base.add(*w) as *const i8;
10966                    let bias_ptr = base.add(*bias) as *const i32;
10967                    let out_ptr = base.add(*out) as *mut i8;
10968                    for mi in 0..m {
10969                        for ni in 0..n {
10970                            let mut acc: i32 = *bias_ptr.add(ni);
10971                            for ki in 0..k {
10972                                let xv = *x_ptr.add(mi * k + ki) as i32 - *x_zp;
10973                                let wv = *w_ptr.add(ki * n + ni) as i32 - *w_zp;
10974                                acc += xv * wv;
10975                            }
10976                            // Requantize: round(acc · mult) + out_zp,
10977                            // clamped to i8.
10978                            let r = (acc as f32 * *mult).round() as i32 + *out_zp;
10979                            let r = r.clamp(-128, 127) as i8;
10980                            *out_ptr.add(mi * n + ni) = r;
10981                        }
10982                    }
10983                }
10984            }
10985
10986            Thunk::QConv2d {
10987                x,
10988                w,
10989                bias,
10990                out,
10991                n,
10992                c_in,
10993                h,
10994                w_in,
10995                c_out,
10996                h_out,
10997                w_out,
10998                kh,
10999                kw,
11000                sh,
11001                sw,
11002                ph,
11003                pw,
11004                dh,
11005                dw,
11006                groups,
11007                x_zp,
11008                w_zp,
11009                out_zp,
11010                mult,
11011            } => {
11012                let n = *n as usize;
11013                let c_in = *c_in as usize;
11014                let h = *h as usize;
11015                let w_in = *w_in as usize;
11016                let c_out = *c_out as usize;
11017                let h_out = *h_out as usize;
11018                let w_out = *w_out as usize;
11019                let kh = *kh as usize;
11020                let kw = *kw as usize;
11021                let sh = *sh as usize;
11022                let sw = *sw as usize;
11023                let ph = *ph as usize;
11024                let pw = *pw as usize;
11025                let dh = *dh as usize;
11026                let dw = *dw as usize;
11027                let groups = *groups as usize;
11028                let c_in_per_g = c_in / groups;
11029                let c_out_per_g = c_out / groups;
11030                unsafe {
11031                    let x_ptr = base.add(*x) as *const i8;
11032                    let w_ptr = base.add(*w) as *const i8;
11033                    let bias_ptr = base.add(*bias) as *const i32;
11034                    let out_ptr = base.add(*out) as *mut i8;
11035                    for ni in 0..n {
11036                        for co in 0..c_out {
11037                            let g = co / c_out_per_g;
11038                            let ci_start = g * c_in_per_g;
11039                            for ho in 0..h_out {
11040                                for wo in 0..w_out {
11041                                    let mut acc: i32 = *bias_ptr.add(co);
11042                                    for ci_off in 0..c_in_per_g {
11043                                        let ci = ci_start + ci_off;
11044                                        let in_chan = ((ni * c_in) + ci) * h * w_in;
11045                                        let wt_chan = ((co * c_in_per_g) + ci_off) * kh * kw;
11046                                        for ki in 0..kh {
11047                                            for kj in 0..kw {
11048                                                let hi = ho * sh + ki * dh;
11049                                                let wi = wo * sw + kj * dw;
11050                                                if hi < ph || wi < pw {
11051                                                    continue;
11052                                                }
11053                                                let hi = hi - ph;
11054                                                let wi = wi - pw;
11055                                                if hi >= h || wi >= w_in {
11056                                                    continue;
11057                                                }
11058                                                let xv = *x_ptr.add(in_chan + hi * w_in + wi)
11059                                                    as i32
11060                                                    - *x_zp;
11061                                                let wv = *w_ptr.add(wt_chan + ki * kw + kj) as i32
11062                                                    - *w_zp;
11063                                                acc += xv * wv;
11064                                            }
11065                                        }
11066                                    }
11067                                    let r = (acc as f32 * *mult).round() as i32 + *out_zp;
11068                                    let r = r.clamp(-128, 127) as i8;
11069                                    let dst = ((ni * c_out) + co) * h_out * w_out + ho * w_out + wo;
11070                                    *out_ptr.add(dst) = r;
11071                                }
11072                            }
11073                        }
11074                    }
11075                }
11076            }
11077
11078            Thunk::Quantize {
11079                x,
11080                q,
11081                len,
11082                chan_axis: _,
11083                chan_dim,
11084                inner,
11085                scales,
11086                zero_points,
11087            } => {
11088                let len = *len as usize;
11089                let chan_dim = *chan_dim as usize;
11090                let inner = *inner as usize;
11091                unsafe {
11092                    let xs = sl(*x, base, len);
11093                    let q_ptr = base.add(*q) as *mut i8;
11094                    for i in 0..len {
11095                        let c = if chan_dim == 1 {
11096                            0
11097                        } else {
11098                            (i / inner) % chan_dim
11099                        };
11100                        let inv_scale = 1.0 / scales[c];
11101                        let zp = zero_points[c];
11102                        let v = (xs[i] * inv_scale).round() as i32 + zp;
11103                        *q_ptr.add(i) = v.clamp(-128, 127) as i8;
11104                    }
11105                }
11106            }
11107
11108            Thunk::Dequantize {
11109                q,
11110                x,
11111                len,
11112                chan_axis: _,
11113                chan_dim,
11114                inner,
11115                scales,
11116                zero_points,
11117            } => {
11118                let len = *len as usize;
11119                let chan_dim = *chan_dim as usize;
11120                let inner = *inner as usize;
11121                unsafe {
11122                    let q_ptr = base.add(*q) as *const i8;
11123                    let out = sl_mut(*x, base, len);
11124                    for i in 0..len {
11125                        let c = if chan_dim == 1 {
11126                            0
11127                        } else {
11128                            (i / inner) % chan_dim
11129                        };
11130                        let scale = scales[c];
11131                        let zp = zero_points[c];
11132                        let qv = *q_ptr.add(i) as i32;
11133                        out[i] = (qv - zp) as f32 * scale;
11134                    }
11135                }
11136            }
11137
11138            Thunk::FakeQuantize {
11139                x,
11140                out,
11141                len,
11142                chan_axis: _,
11143                chan_dim,
11144                inner,
11145                bits,
11146                ste: _,
11147                scale_mode,
11148                state_off,
11149            } => {
11150                use rlx_ir::op::ScaleMode;
11151                let len = *len as usize;
11152                let chan_dim = *chan_dim as usize;
11153                let inner = *inner as usize;
11154                let q_max: f32 = match *bits {
11155                    8 => 127.0,
11156                    4 => 7.0,
11157                    2 => 1.0,
11158                    n => panic!("FakeQuantize: unsupported bits {n}"),
11159                };
11160                unsafe {
11161                    let xs = sl(*x, base, len);
11162                    let outs = sl_mut(*out, base, len);
11163
11164                    let mut scale = vec![0f32; chan_dim];
11165                    match scale_mode {
11166                        ScaleMode::PerBatch => {
11167                            let mut max_abs = vec![0f32; chan_dim];
11168                            for i in 0..len {
11169                                let c = if chan_dim == 1 {
11170                                    0
11171                                } else {
11172                                    (i / inner) % chan_dim
11173                                };
11174                                let a = xs[i].abs();
11175                                if a > max_abs[c] {
11176                                    max_abs[c] = a;
11177                                }
11178                            }
11179                            for c in 0..chan_dim {
11180                                scale[c] = (max_abs[c] / q_max).max(1e-12);
11181                            }
11182                        }
11183                        ScaleMode::EMA { decay } => {
11184                            // Per-channel current max-abs, then blend
11185                            // into the running state in place.
11186                            let mut max_abs = vec![0f32; chan_dim];
11187                            for i in 0..len {
11188                                let c = if chan_dim == 1 {
11189                                    0
11190                                } else {
11191                                    (i / inner) % chan_dim
11192                                };
11193                                let a = xs[i].abs();
11194                                if a > max_abs[c] {
11195                                    max_abs[c] = a;
11196                                }
11197                            }
11198                            let state =
11199                                sl_mut(state_off.expect("EMA needs state_off"), base, chan_dim);
11200                            for c in 0..chan_dim {
11201                                let cur = (max_abs[c] / q_max).max(1e-12);
11202                                // Cold-start: state==0 → seed directly.
11203                                let blended = if state[c] <= 0.0 {
11204                                    cur
11205                                } else {
11206                                    *decay * state[c] + (1.0 - *decay) * cur
11207                                };
11208                                state[c] = blended;
11209                                scale[c] = blended;
11210                            }
11211                        }
11212                        ScaleMode::Fixed => {
11213                            let state =
11214                                sl(state_off.expect("Fixed needs state_off"), base, chan_dim);
11215                            for c in 0..chan_dim {
11216                                scale[c] = state[c].max(1e-12);
11217                            }
11218                        }
11219                    }
11220
11221                    for i in 0..len {
11222                        let c = if chan_dim == 1 {
11223                            0
11224                        } else {
11225                            (i / inner) % chan_dim
11226                        };
11227                        let s = scale[c];
11228                        let qv = (xs[i] / s).round().clamp(-q_max, q_max);
11229                        outs[i] = qv * s;
11230                    }
11231                }
11232            }
11233
11234            Thunk::ActivationBackward {
11235                x,
11236                dy,
11237                dx,
11238                len,
11239                kind,
11240            } => {
11241                let len = *len as usize;
11242                unsafe {
11243                    let xs = sl(*x, base, len);
11244                    let dys = sl(*dy, base, len);
11245                    let out = sl_mut(*dx, base, len);
11246                    activation_backward_kernel(*kind, xs, dys, out);
11247                }
11248            }
11249
11250            Thunk::ActivationBackwardF64 {
11251                x,
11252                dy,
11253                dx,
11254                len,
11255                kind,
11256            } => {
11257                let len = *len as usize;
11258                unsafe {
11259                    let xs = sl_f64(*x, base, len);
11260                    let dys = sl_f64(*dy, base, len);
11261                    let out = sl_mut_f64(*dx, base, len);
11262                    activation_backward_kernel_f64(*kind, xs, dys, out);
11263                }
11264            }
11265
11266            Thunk::FakeQuantizeLSQ {
11267                x,
11268                scale_off,
11269                out,
11270                len,
11271                chan_axis: _,
11272                chan_dim,
11273                inner,
11274                bits,
11275            } => {
11276                let len = *len as usize;
11277                let chan_dim = *chan_dim as usize;
11278                let inner = *inner as usize;
11279                let q_max: f32 = match *bits {
11280                    8 => 127.0,
11281                    4 => 7.0,
11282                    2 => 1.0,
11283                    n => panic!("FakeQuantizeLSQ: bad bits {n}"),
11284                };
11285                unsafe {
11286                    let xs = sl(*x, base, len);
11287                    let scale = sl(*scale_off, base, chan_dim);
11288                    let outs = sl_mut(*out, base, len);
11289                    for i in 0..len {
11290                        let c = if chan_dim == 1 {
11291                            0
11292                        } else {
11293                            (i / inner) % chan_dim
11294                        };
11295                        let s = scale[c].max(1e-12);
11296                        let qv = (xs[i] / s).round().clamp(-q_max, q_max);
11297                        outs[i] = qv * s;
11298                    }
11299                }
11300            }
11301
11302            Thunk::FakeQuantizeLSQBackwardX {
11303                x,
11304                scale_off,
11305                dy,
11306                dx,
11307                len,
11308                chan_axis: _,
11309                chan_dim,
11310                inner,
11311                bits,
11312            } => {
11313                let len = *len as usize;
11314                let chan_dim = *chan_dim as usize;
11315                let inner = *inner as usize;
11316                let q_max: f32 = match *bits {
11317                    8 => 127.0,
11318                    4 => 7.0,
11319                    2 => 1.0,
11320                    n => panic!("FakeQuantizeLSQBackwardX: bad bits {n}"),
11321                };
11322                unsafe {
11323                    let xs = sl(*x, base, len);
11324                    let scale = sl(*scale_off, base, chan_dim);
11325                    let dys = sl(*dy, base, len);
11326                    let outs = sl_mut(*dx, base, len);
11327                    // STE-clipped: dx = dy when |x/s| ≤ q_max, else 0.
11328                    for i in 0..len {
11329                        let c = if chan_dim == 1 {
11330                            0
11331                        } else {
11332                            (i / inner) % chan_dim
11333                        };
11334                        let z = xs[i] / scale[c].max(1e-12);
11335                        outs[i] = if z.abs() <= q_max { dys[i] } else { 0.0 };
11336                    }
11337                }
11338            }
11339
11340            Thunk::FakeQuantizeLSQBackwardScale {
11341                x,
11342                scale_off,
11343                dy,
11344                dscale,
11345                len,
11346                chan_axis: _,
11347                chan_dim,
11348                inner,
11349                bits,
11350            } => {
11351                let len = *len as usize;
11352                let chan_dim = *chan_dim as usize;
11353                let inner = *inner as usize;
11354                let q_max: f32 = match *bits {
11355                    8 => 127.0,
11356                    4 => 7.0,
11357                    2 => 1.0,
11358                    n => panic!("FakeQuantizeLSQBackwardScale: bad bits {n}"),
11359                };
11360                unsafe {
11361                    let xs = sl(*x, base, len);
11362                    let scale = sl(*scale_off, base, chan_dim);
11363                    let dys = sl(*dy, base, len);
11364                    let outs = sl_mut(*dscale, base, chan_dim);
11365                    for v in outs.iter_mut() {
11366                        *v = 0.0;
11367                    }
11368                    // ψ(z) = -z + round(z) inside range, sign(z)·q_max outside.
11369                    // dscale[c] = sum_i ψ(x_i/s[c]) * upstream[i].
11370                    for i in 0..len {
11371                        let c = if chan_dim == 1 {
11372                            0
11373                        } else {
11374                            (i / inner) % chan_dim
11375                        };
11376                        let s = scale[c].max(1e-12);
11377                        let z = xs[i] / s;
11378                        let psi = if z.abs() <= q_max {
11379                            -z + z.round()
11380                        } else if z > 0.0 {
11381                            q_max
11382                        } else {
11383                            -q_max
11384                        };
11385                        outs[c] += psi * dys[i];
11386                    }
11387                }
11388            }
11389
11390            Thunk::FakeQuantizeBackward {
11391                x,
11392                dy,
11393                dx,
11394                len,
11395                chan_axis: _,
11396                chan_dim,
11397                inner,
11398                bits,
11399                ste,
11400            } => {
11401                use rlx_ir::op::SteKind;
11402                let len = *len as usize;
11403                let chan_dim = *chan_dim as usize;
11404                let inner = *inner as usize;
11405                let q_max: f32 = match *bits {
11406                    8 => 127.0,
11407                    4 => 7.0,
11408                    2 => 1.0,
11409                    n => panic!("FakeQuantizeBackward: bad bits {n}"),
11410                };
11411                unsafe {
11412                    let xs = sl(*x, base, len);
11413                    let dys = sl(*dy, base, len);
11414                    let outs = sl_mut(*dx, base, len);
11415
11416                    // Per-channel max-abs → scale, same as forward.
11417                    let mut max_abs = vec![0f32; chan_dim];
11418                    for i in 0..len {
11419                        let c = if chan_dim == 1 {
11420                            0
11421                        } else {
11422                            (i / inner) % chan_dim
11423                        };
11424                        let a = xs[i].abs();
11425                        if a > max_abs[c] {
11426                            max_abs[c] = a;
11427                        }
11428                    }
11429                    let mut scale = vec![0f32; chan_dim];
11430                    for c in 0..chan_dim {
11431                        scale[c] = (max_abs[c] / q_max).max(1e-12);
11432                    }
11433
11434                    match *ste {
11435                        SteKind::Identity => {
11436                            // dx = dy unchanged.
11437                            outs.copy_from_slice(dys);
11438                        }
11439                        SteKind::ClippedIdentity => {
11440                            // dx = dy * (|x| <= q_max·s); zero if the
11441                            // forward saturated.
11442                            for i in 0..len {
11443                                let c = if chan_dim == 1 {
11444                                    0
11445                                } else {
11446                                    (i / inner) % chan_dim
11447                                };
11448                                let bound = q_max * scale[c];
11449                                outs[i] = if xs[i].abs() <= bound { dys[i] } else { 0.0 };
11450                            }
11451                        }
11452                        SteKind::Tanh => {
11453                            // dx = dy * (1 - tanh²(x/s)).
11454                            for i in 0..len {
11455                                let c = if chan_dim == 1 {
11456                                    0
11457                                } else {
11458                                    (i / inner) % chan_dim
11459                                };
11460                                let t = (xs[i] / scale[c]).tanh();
11461                                outs[i] = dys[i] * (1.0 - t * t);
11462                            }
11463                        }
11464                        SteKind::HardTanh => {
11465                            // dx = dy * max(0, 1 - |x/(q_max·s)|).
11466                            for i in 0..len {
11467                                let c = if chan_dim == 1 {
11468                                    0
11469                                } else {
11470                                    (i / inner) % chan_dim
11471                                };
11472                                let bound = q_max * scale[c];
11473                                let attenuation = (1.0 - (xs[i] / bound).abs()).max(0.0);
11474                                outs[i] = dys[i] * attenuation;
11475                            }
11476                        }
11477                    }
11478                }
11479            }
11480
11481            Thunk::LayerNormBackwardInput {
11482                x,
11483                gamma,
11484                dy,
11485                dx,
11486                rows,
11487                h,
11488                eps,
11489            } => {
11490                let rows = *rows as usize;
11491                let h = *h as usize;
11492                let eps = *eps;
11493                unsafe {
11494                    let xs = sl(*x, base, rows * h);
11495                    let g = sl(*gamma, base, h);
11496                    let dys = sl(*dy, base, rows * h);
11497                    let out = sl_mut(*dx, base, rows * h);
11498                    let n_inv = 1.0 / h as f32;
11499                    for r in 0..rows {
11500                        let xr = &xs[r * h..(r + 1) * h];
11501                        let dyr = &dys[r * h..(r + 1) * h];
11502                        // Per-row mean and inv_std (recompute — no saved
11503                        // tensor from the forward pass).
11504                        let mut sum = 0f32;
11505                        for &v in xr {
11506                            sum += v;
11507                        }
11508                        let mean = sum * n_inv;
11509                        let mut var = 0f32;
11510                        for &v in xr {
11511                            let d = v - mean;
11512                            var += d * d;
11513                        }
11514                        let inv_std = 1.0 / (var * n_inv + eps).sqrt();
11515
11516                        // sums needed for the closed-form:
11517                        //   mean(dy·γ) and mean(dy·γ·x̂)
11518                        let mut s_sy = 0f32;
11519                        let mut s_sxh = 0f32;
11520                        for d in 0..h {
11521                            let xh = (xr[d] - mean) * inv_std;
11522                            let sy = dyr[d] * g[d];
11523                            s_sy += sy;
11524                            s_sxh += sy * xh;
11525                        }
11526                        let m_sy = s_sy * n_inv;
11527                        let m_sxh = s_sxh * n_inv;
11528
11529                        for d in 0..h {
11530                            let xh = (xr[d] - mean) * inv_std;
11531                            let sy = dyr[d] * g[d];
11532                            out[r * h + d] = inv_std * (sy - m_sy - xh * m_sxh);
11533                        }
11534                    }
11535                }
11536            }
11537
11538            Thunk::LayerNormBackwardGamma {
11539                x,
11540                dy,
11541                dgamma,
11542                rows,
11543                h,
11544                eps,
11545            } => {
11546                let rows = *rows as usize;
11547                let h = *h as usize;
11548                let eps = *eps;
11549                unsafe {
11550                    let xs = sl(*x, base, rows * h);
11551                    let dys = sl(*dy, base, rows * h);
11552                    let out = sl_mut(*dgamma, base, h);
11553                    for v in out.iter_mut() {
11554                        *v = 0.0;
11555                    }
11556                    let n_inv = 1.0 / h as f32;
11557                    for r in 0..rows {
11558                        let xr = &xs[r * h..(r + 1) * h];
11559                        let dyr = &dys[r * h..(r + 1) * h];
11560                        let mut sum = 0f32;
11561                        for &v in xr {
11562                            sum += v;
11563                        }
11564                        let mean = sum * n_inv;
11565                        let mut var = 0f32;
11566                        for &v in xr {
11567                            let d = v - mean;
11568                            var += d * d;
11569                        }
11570                        let inv_std = 1.0 / (var * n_inv + eps).sqrt();
11571                        for d in 0..h {
11572                            let xh = (xr[d] - mean) * inv_std;
11573                            out[d] += dyr[d] * xh;
11574                        }
11575                    }
11576                }
11577            }
11578
11579            Thunk::RmsNormBackwardInput {
11580                x,
11581                gamma,
11582                beta,
11583                dy,
11584                dx,
11585                rows,
11586                h,
11587                eps,
11588            } => {
11589                let (rows, h) = (*rows as usize, *h as usize);
11590                unsafe {
11591                    let xs = sl(*x, base, rows * h);
11592                    let g = sl(*gamma, base, h);
11593                    let b = sl(*beta, base, h);
11594                    let dys = sl(*dy, base, rows * h);
11595                    let out = sl_mut(*dx, base, rows * h);
11596                    let mut dg = vec![0f32; h];
11597                    let mut db = vec![0f32; h];
11598                    for r in 0..rows {
11599                        crate::training_bwd::rms_norm_backward_row(
11600                            &xs[r * h..(r + 1) * h],
11601                            g,
11602                            b,
11603                            &dys[r * h..(r + 1) * h],
11604                            &mut out[r * h..(r + 1) * h],
11605                            &mut dg,
11606                            &mut db,
11607                            *eps,
11608                        );
11609                    }
11610                }
11611            }
11612
11613            Thunk::RmsNormBackwardGamma {
11614                x,
11615                gamma,
11616                beta,
11617                dy,
11618                dgamma,
11619                rows,
11620                h,
11621                eps,
11622            } => {
11623                let (rows, h) = (*rows as usize, *h as usize);
11624                unsafe {
11625                    let xs = sl(*x, base, rows * h);
11626                    let g = sl(*gamma, base, h);
11627                    let b = sl(*beta, base, h);
11628                    let dys = sl(*dy, base, rows * h);
11629                    let out = sl_mut(*dgamma, base, h);
11630                    for v in out.iter_mut() {
11631                        *v = 0.0;
11632                    }
11633                    let mut dx = vec![0f32; h];
11634                    let mut db = vec![0f32; h];
11635                    for r in 0..rows {
11636                        crate::training_bwd::rms_norm_backward_row(
11637                            &xs[r * h..(r + 1) * h],
11638                            g,
11639                            b,
11640                            &dys[r * h..(r + 1) * h],
11641                            &mut dx,
11642                            &mut *out,
11643                            &mut db,
11644                            *eps,
11645                        );
11646                    }
11647                }
11648            }
11649
11650            Thunk::RmsNormBackwardBeta {
11651                x,
11652                gamma,
11653                beta,
11654                dy,
11655                dbeta,
11656                rows,
11657                h,
11658                eps,
11659            } => {
11660                let (rows, h) = (*rows as usize, *h as usize);
11661                unsafe {
11662                    let xs = sl(*x, base, rows * h);
11663                    let g = sl(*gamma, base, h);
11664                    let b = sl(*beta, base, h);
11665                    let dys = sl(*dy, base, rows * h);
11666                    let out = sl_mut(*dbeta, base, h);
11667                    for v in out.iter_mut() {
11668                        *v = 0.0;
11669                    }
11670                    let mut dx = vec![0f32; h];
11671                    let mut dg = vec![0f32; h];
11672                    for r in 0..rows {
11673                        crate::training_bwd::rms_norm_backward_row(
11674                            &xs[r * h..(r + 1) * h],
11675                            g,
11676                            b,
11677                            &dys[r * h..(r + 1) * h],
11678                            &mut dx,
11679                            &mut dg,
11680                            &mut *out,
11681                            *eps,
11682                        );
11683                    }
11684                }
11685            }
11686
11687            Thunk::RopeBackward {
11688                dy,
11689                cos,
11690                sin,
11691                dx,
11692                batch,
11693                seq,
11694                hidden,
11695                head_dim,
11696                n_rot,
11697                cos_len,
11698            } => {
11699                let (b, s, hs, dh, nr, cl) = (
11700                    *batch as usize,
11701                    *seq as usize,
11702                    *hidden as usize,
11703                    *head_dim as usize,
11704                    *n_rot as usize,
11705                    *cos_len as usize,
11706                );
11707                let nh = hs / dh;
11708                let tab_half = dh / 2;
11709                unsafe {
11710                    let dys = sl(*dy, base, b * s * hs);
11711                    let cos_tab = sl(*cos, base, cl);
11712                    let sin_tab = sl(*sin, base, cl);
11713                    let out = sl_mut(*dx, base, b * s * hs);
11714                    for bi in 0..b {
11715                        for si in 0..s {
11716                            let tab_off = si.saturating_mul(tab_half) % cl.max(1);
11717                            let cp = &cos_tab[tab_off..tab_off + tab_half.min(cl)];
11718                            let sp = &sin_tab[tab_off..tab_off + tab_half.min(cl)];
11719                            for hi in 0..nh {
11720                                let base_idx = bi * s * hs + si * hs + hi * dh;
11721                                crate::training_bwd::rope_backward_row(
11722                                    &dys[base_idx..base_idx + dh],
11723                                    cp,
11724                                    sp,
11725                                    &mut out[base_idx..base_idx + dh],
11726                                    dh,
11727                                    nr,
11728                                );
11729                            }
11730                        }
11731                    }
11732                }
11733            }
11734
11735            Thunk::CumsumBackward {
11736                dy,
11737                dx,
11738                rows,
11739                cols,
11740                exclusive,
11741            } => {
11742                let (rows, cols) = (*rows as usize, *cols as usize);
11743                unsafe {
11744                    let dys = sl(*dy, base, rows * cols);
11745                    let out = sl_mut(*dx, base, rows * cols);
11746                    for r in 0..rows {
11747                        crate::training_bwd::cumsum_backward_row(
11748                            &dys[r * cols..(r + 1) * cols],
11749                            &mut out[r * cols..(r + 1) * cols],
11750                            *exclusive,
11751                        );
11752                    }
11753                }
11754            }
11755
11756            Thunk::GroupNormBackwardInput {
11757                x,
11758                gamma,
11759                beta: _beta,
11760                dy,
11761                dx,
11762                n,
11763                c,
11764                h,
11765                w,
11766                num_groups,
11767                eps,
11768            } => {
11769                let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
11770                let plane = c * h * w;
11771                unsafe {
11772                    let xs = sl(*x, base, n * plane);
11773                    let g = sl(*gamma, base, c);
11774                    let dys = sl(*dy, base, n * plane);
11775                    let out = sl_mut(*dx, base, n * plane);
11776                    crate::training_bwd::group_norm_backward_input_nchw(
11777                        xs,
11778                        g,
11779                        dys,
11780                        out,
11781                        n,
11782                        c,
11783                        h,
11784                        w,
11785                        *num_groups as usize,
11786                        *eps,
11787                    );
11788                }
11789            }
11790
11791            Thunk::GroupNormBackwardGamma {
11792                x,
11793                dy,
11794                dgamma,
11795                n,
11796                c,
11797                h,
11798                w,
11799                num_groups,
11800                eps,
11801            } => {
11802                let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
11803                let plane = c * h * w;
11804                unsafe {
11805                    let xs = sl(*x, base, n * plane);
11806                    let dys = sl(*dy, base, n * plane);
11807                    let out = sl_mut(*dgamma, base, c);
11808                    crate::training_bwd::group_norm_backward_gamma_nchw(
11809                        xs,
11810                        dys,
11811                        out,
11812                        n,
11813                        c,
11814                        h,
11815                        w,
11816                        *num_groups as usize,
11817                        *eps,
11818                    );
11819                }
11820            }
11821
11822            Thunk::GroupNormBackwardBeta {
11823                dy,
11824                dbeta,
11825                n,
11826                c,
11827                h,
11828                w,
11829            } => {
11830                let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
11831                let plane = c * h * w;
11832                unsafe {
11833                    let dys = sl(*dy, base, n * plane);
11834                    let out = sl_mut(*dbeta, base, c);
11835                    crate::training_bwd::group_norm_backward_beta_nchw(dys, out, n, c, h, w);
11836                }
11837            }
11838
11839            Thunk::GatherBackward {
11840                dy,
11841                indices,
11842                dst,
11843                outer,
11844                axis_dim,
11845                num_idx,
11846                trailing,
11847            } => {
11848                let (outer, axis_dim, num_idx, trailing) = (
11849                    *outer as usize,
11850                    *axis_dim as usize,
11851                    *num_idx as usize,
11852                    *trailing as usize,
11853                );
11854                unsafe {
11855                    let dys = sl(*dy, base, outer * num_idx * trailing);
11856                    let ids = sl(*indices, base, num_idx);
11857                    let out = sl_mut(*dst, base, outer * axis_dim * trailing);
11858                    for v in out.iter_mut() {
11859                        *v = 0.0;
11860                    }
11861                    crate::training_bwd::gather_axis_backward(
11862                        dys, ids, out, outer, axis_dim, num_idx, trailing,
11863                    );
11864                }
11865            }
11866
11867            Thunk::MaxPool2dBackward {
11868                x,
11869                dy,
11870                dx,
11871                n,
11872                c,
11873                h,
11874                w,
11875                h_out,
11876                w_out,
11877                kh,
11878                kw,
11879                sh,
11880                sw,
11881                ph,
11882                pw,
11883            } => {
11884                let n = *n as usize;
11885                let c = *c as usize;
11886                let h = *h as usize;
11887                let w = *w as usize;
11888                let h_out = *h_out as usize;
11889                let w_out = *w_out as usize;
11890                let kh = *kh as usize;
11891                let kw = *kw as usize;
11892                let sh = *sh as usize;
11893                let sw = *sw as usize;
11894                let ph = *ph as usize;
11895                let pw = *pw as usize;
11896                unsafe {
11897                    let xs = sl(*x, base, n * c * h * w);
11898                    let dys = sl(*dy, base, n * c * h_out * w_out);
11899                    let dxs = sl_mut(*dx, base, n * c * h * w);
11900                    // Zero before scatter — multiple windows can write
11901                    // to the same input position when stride < kernel.
11902                    for v in dxs.iter_mut() {
11903                        *v = 0.0;
11904                    }
11905                    for ni in 0..n {
11906                        for ci in 0..c {
11907                            let in_chan = (ni * c + ci) * h * w;
11908                            let out_chan = (ni * c + ci) * h_out * w_out;
11909                            for ho in 0..h_out {
11910                                for wo in 0..w_out {
11911                                    // Recompute argmax inside this window.
11912                                    let mut best_v = f32::NEG_INFINITY;
11913                                    let mut best_idx: Option<usize> = None;
11914                                    for ki in 0..kh {
11915                                        for kj in 0..kw {
11916                                            let hi = ho * sh + ki;
11917                                            let wi = wo * sw + kj;
11918                                            if hi < ph || wi < pw {
11919                                                continue;
11920                                            }
11921                                            let hi = hi - ph;
11922                                            let wi = wi - pw;
11923                                            if hi >= h || wi >= w {
11924                                                continue;
11925                                            }
11926                                            let idx = in_chan + hi * w + wi;
11927                                            let v = xs[idx];
11928                                            // Tie-break: keep first hit
11929                                            // (matches forward's `acc.max(v)`
11930                                            // — strict greater-than wins).
11931                                            if v > best_v {
11932                                                best_v = v;
11933                                                best_idx = Some(idx);
11934                                            }
11935                                        }
11936                                    }
11937                                    if let Some(idx) = best_idx {
11938                                        dxs[idx] += dys[out_chan + ho * w_out + wo];
11939                                    }
11940                                }
11941                            }
11942                        }
11943                    }
11944                }
11945            }
11946
            Thunk::Conv2dBackwardInput {
                dy,
                w,
                dx,
                n,
                c_in,
                h,
                w_in,
                c_out,
                h_out,
                w_out,
                kh,
                kw,
                sh,
                sw,
                ph,
                pw,
                dh,
                dw,
                groups,
            } => {
                // Per-group GEMM + col2im. Two orders of magnitude faster
                // than the naive 6-deep nested loop on training shapes.
                //
                //   dcol_n_g = w_g^T  @  dy_n_g            (sgemm)
                //   dx_n_g  += col2im(dcol_n_g)            (scatter-add)
                //
                // Layouts (all row-major):
                //   w_g       [c_out_per_g, c_in_per_g · kh · kw]
                //   dy_n_g    [c_out_per_g, h_out · w_out]
                //   dcol_n_g  [c_in_per_g · kh · kw, h_out · w_out]
                //   dx_n_g    [c_in_per_g, h · w_in]
                //
                // Here `w` is the weight tensor offset and `w_in` the input
                // width; `dh`/`dw` are passed to col2im alongside the stride
                // and padding (presumably dilation — confirm against col2im).
                let n = *n as usize;
                let c_in = *c_in as usize;
                let h = *h as usize;
                let w_in = *w_in as usize;
                let c_out = *c_out as usize;
                let h_out = *h_out as usize;
                let w_out = *w_out as usize;
                let kh = *kh as usize;
                let kw = *kw as usize;
                let sh = *sh as usize;
                let sw = *sw as usize;
                let ph = *ph as usize;
                let pw = *pw as usize;
                let dh = *dh as usize;
                let dw = *dw as usize;
                let groups = *groups as usize;
                // NOTE(review): integer division assumes c_in and c_out are
                // multiples of `groups` (and groups > 0); presumably validated
                // when the thunk is built — confirm.
                let c_in_per_g = c_in / groups;
                let c_out_per_g = c_out / groups;

                // GEMM shape: dcol is [m_dim, n_dim], contraction over k_dim.
                let m_dim = c_in_per_g * kh * kw;
                let n_dim = h_out * w_out;
                let k_dim = c_out_per_g;

                // Element strides between consecutive batch items (`_n`) and
                // consecutive groups (`_g`) in each tensor.
                let dy_stride_n = c_out * h_out * w_out;
                let dy_stride_g = c_out_per_g * h_out * w_out;
                let w_stride_g = c_out_per_g * c_in_per_g * kh * kw;
                let dx_stride_n = c_in * h * w_in;
                let dx_stride_g = c_in_per_g * h * w_in;

                unsafe {
                    let dys = sl(*dy, base, n * c_out * h_out * w_out);
                    let ws = sl(*w, base, c_out * c_in_per_g * kh * kw);
                    let dxs = sl_mut(*dx, base, n * c_in * h * w_in);
                    // dx is built by scatter-add (`+=` in the col2im step
                    // above), so it must start from zero.
                    for v in dxs.iter_mut() {
                        *v = 0.0;
                    }

                    // Reused scratch buffer for the [m_dim, n_dim] dcol.
                    // Allocated once per execution of this thunk and shared by
                    // every (batch, group) pair below.
                    let mut dcol = vec![0f32; m_dim * n_dim];

                    for ni in 0..n {
                        for g in 0..groups {
                            let w_g_off = g * w_stride_g;
                            let dy_n_g_off = ni * dy_stride_n + g * dy_stride_g;
                            let dx_n_g_off = ni * dx_stride_n + g * dx_stride_g;

                            // dcol = w_g^T @ dy_n_g
                            // w_g  is stored as [k_dim rows, m_dim cols] row-major
                            // (i.e. K×M storage with lda = M = m_dim — exactly what
                            // sgemm_general wants for trans_a=true).
                            // The 1.0 / 0.0 scalars are presumably alpha / beta;
                            // with beta = 0 dcol would be overwritten on every
                            // call, which is why it is not re-zeroed between
                            // iterations — confirm against crate::blas.
                            crate::blas::sgemm_general(
                                ws.as_ptr().add(w_g_off),
                                dys.as_ptr().add(dy_n_g_off),
                                dcol.as_mut_ptr(),
                                m_dim,
                                n_dim,
                                k_dim,
                                1.0,
                                0.0,
                                /*lda=*/ m_dim,
                                /*ldb=*/ n_dim,
                                /*ldc=*/ n_dim,
                                /*trans_a=*/ true,
                                /*trans_b=*/ false,
                            );

                            // dx_n_g += col2im(dcol)
                            // Only this (batch, group) slice of dx is handed over.
                            col2im(
                                &dcol,
                                &mut dxs[dx_n_g_off..dx_n_g_off + dx_stride_g],
                                c_in_per_g,
                                h,
                                w_in,
                                h_out,
                                w_out,
                                kh,
                                kw,
                                sh,
                                sw,
                                ph,
                                pw,
                                dh,
                                dw,
                            );
                        }
                    }
                }
            }
12067
12068            Thunk::Conv2dBackwardWeight {
12069                x,
12070                dy,
12071                dw,
12072                n,
12073                c_in,
12074                h,
12075                w,
12076                c_out,
12077                h_out,
12078                w_out,
12079                kh,
12080                kw,
12081                sh,
12082                sw,
12083                ph,
12084                pw,
12085                dh,
12086                dw_dil,
12087                groups,
12088            } => {
12089                let n = *n as usize;
12090                let c_in = *c_in as usize;
12091                let h = *h as usize;
12092                let w = *w as usize;
12093                // Per-group im2col + GEMM, summed across batch.
12094                //
12095                //   col_n_g  = im2col(x_n_g)               (gather)
12096                //   dw_g    += dy_n_g  @  col_n_g^T        (sgemm, β=1)
12097                //
12098                // Layouts:
12099                //   x_n_g     [c_in_per_g, h · w]
12100                //   col_n_g   [c_in_per_g · kh · kw, h_out · w_out]
12101                //   dy_n_g    [c_out_per_g, h_out · w_out]
12102                //   dw_g      [c_out_per_g, c_in_per_g · kh · kw]
12103                let c_out = *c_out as usize;
12104                let h_out = *h_out as usize;
12105                let w_out = *w_out as usize;
12106                let kh = *kh as usize;
12107                let kw = *kw as usize;
12108                let sh = *sh as usize;
12109                let sw = *sw as usize;
12110                let ph = *ph as usize;
12111                let pw = *pw as usize;
12112                let dh = *dh as usize;
12113                let dw_dil = *dw_dil as usize;
12114                let groups = *groups as usize;
12115                let c_in_per_g = c_in / groups;
12116                let c_out_per_g = c_out / groups;
12117
12118                let m_dim = c_out_per_g;
12119                let n_dim = c_in_per_g * kh * kw;
12120                let k_dim = h_out * w_out;
12121
12122                let x_stride_n = c_in * h * w;
12123                let x_stride_g = c_in_per_g * h * w;
12124                let dy_stride_n = c_out * h_out * w_out;
12125                let dy_stride_g = c_out_per_g * h_out * w_out;
12126                let dw_stride_g = c_out_per_g * c_in_per_g * kh * kw;
12127
12128                unsafe {
12129                    let xs = sl(*x, base, n * c_in * h * w);
12130                    let dys = sl(*dy, base, n * c_out * h_out * w_out);
12131                    let dws = sl_mut(*dw, base, c_out * c_in_per_g * kh * kw);
12132                    for v in dws.iter_mut() {
12133                        *v = 0.0;
12134                    }
12135
12136                    let mut col = vec![0f32; n_dim * k_dim];
12137
12138                    for ni in 0..n {
12139                        for g in 0..groups {
12140                            let x_n_g_off = ni * x_stride_n + g * x_stride_g;
12141                            im2col(
12142                                &xs[x_n_g_off..x_n_g_off + x_stride_g],
12143                                &mut col,
12144                                c_in_per_g,
12145                                h,
12146                                w,
12147                                h_out,
12148                                w_out,
12149                                kh,
12150                                kw,
12151                                sh,
12152                                sw,
12153                                ph,
12154                                pw,
12155                                dh,
12156                                dw_dil,
12157                            );
12158
12159                            let dy_n_g_off = ni * dy_stride_n + g * dy_stride_g;
12160                            let dw_g_off = g * dw_stride_g;
12161
12162                            // dw_g += dy_n_g @ col^T
12163                            //
12164                            // Output shape m × n_out = c_out_per_g × (c_in_per_g·kh·kw).
12165                            // dy_n_g is stored M×K row-major (lda = K = k_dim).
12166                            // col is stored as N×K row-major; with trans_b=true,
12167                            // sgemm_general uses ldb = K = k_dim and treats it as
12168                            // transposed. β=1 accumulates across the batch loop.
12169                            crate::blas::sgemm_general(
12170                                dys.as_ptr().add(dy_n_g_off),
12171                                col.as_ptr(),
12172                                dws.as_mut_ptr().add(dw_g_off),
12173                                m_dim,
12174                                n_dim,
12175                                k_dim,
12176                                1.0,
12177                                1.0,
12178                                /*lda=*/ k_dim,
12179                                /*ldb=*/ k_dim,
12180                                /*ldc=*/ n_dim,
12181                                /*trans_a=*/ false,
12182                                /*trans_b=*/ true,
12183                            );
12184                        }
12185                    }
12186                }
12187            }
12188
12189            Thunk::SoftmaxCrossEntropy {
12190                logits,
12191                labels,
12192                dst,
12193                n,
12194                c,
12195            } => {
12196                let n = *n as usize;
12197                let c = *c as usize;
12198                unsafe {
12199                    let lg = sl(*logits, base, n * c);
12200                    let lb = sl(*labels, base, n);
12201                    let out = sl_mut(*dst, base, n);
12202                    for ni in 0..n {
12203                        let row = &lg[ni * c..(ni + 1) * c];
12204                        // log-sum-exp: max-subtract for stability.
12205                        let mut m = f32::NEG_INFINITY;
12206                        for &v in row {
12207                            if v > m {
12208                                m = v;
12209                            }
12210                        }
12211                        let mut sum = 0f32;
12212                        for &v in row {
12213                            sum += (v - m).exp();
12214                        }
12215                        let lse = m + sum.ln();
12216                        let label_idx = lb[ni] as usize;
12217                        // loss = -(logits[label] - lse) = lse - logits[label].
12218                        out[ni] = lse - row[label_idx];
12219                    }
12220                }
12221            }
12222
12223            Thunk::SoftmaxCrossEntropyBackward {
12224                logits,
12225                labels,
12226                d_loss,
12227                dlogits,
12228                n,
12229                c,
12230            } => {
12231                let n = *n as usize;
12232                let c = *c as usize;
12233                unsafe {
12234                    let lg = sl(*logits, base, n * c);
12235                    let lb = sl(*labels, base, n);
12236                    let dl = sl(*d_loss, base, n);
12237                    let out = sl_mut(*dlogits, base, n * c);
12238                    for ni in 0..n {
12239                        let row = &lg[ni * c..(ni + 1) * c];
12240                        let label_idx = lb[ni] as usize;
12241                        let scale = dl[ni];
12242                        let mut m = f32::NEG_INFINITY;
12243                        for &v in row {
12244                            if v > m {
12245                                m = v;
12246                            }
12247                        }
12248                        let mut sum = 0f32;
12249                        for &v in row {
12250                            sum += (v - m).exp();
12251                        }
12252                        let inv_sum = 1.0 / sum;
12253                        let dst_row = &mut out[ni * c..(ni + 1) * c];
12254                        for k in 0..c {
12255                            let p = (row[k] - m).exp() * inv_sum;
12256                            let one_hot = if k == label_idx { 1.0 } else { 0.0 };
12257                            dst_row[k] = (p - one_hot) * scale;
12258                        }
12259                    }
12260                }
12261            }
12262
12263            Thunk::GatherAxis {
12264                table,
12265                idx,
12266                dst,
12267                outer,
12268                axis_dim,
12269                num_idx,
12270                trailing,
12271            } => {
12272                let outer = *outer as usize;
12273                let axis_dim = *axis_dim as usize;
12274                let num_idx = *num_idx as usize;
12275                let trailing = *trailing as usize;
12276                unsafe {
12277                    let tab = sl(*table, base, outer * axis_dim * trailing);
12278                    let ids = sl(*idx, base, num_idx);
12279                    let out = sl_mut(*dst, base, outer * num_idx * trailing);
12280                    for o in 0..outer {
12281                        let tab_outer = o * axis_dim * trailing;
12282                        let out_outer = o * num_idx * trailing;
12283                        for k in 0..num_idx {
12284                            let row = ids[k] as usize;
12285                            let tab_row = tab_outer + row * trailing;
12286                            let out_row = out_outer + k * trailing;
12287                            out[out_row..out_row + trailing]
12288                                .copy_from_slice(&tab[tab_row..tab_row + trailing]);
12289                        }
12290                    }
12291                }
12292            }
12293
12294            Thunk::Transpose {
12295                src,
12296                dst,
12297                in_total,
12298                out_dims,
12299                in_strides,
12300            } => {
12301                // N-D index walk: for each output flat index, decompose into
12302                // multi-dim coords using out_dims, then dot with in_strides
12303                // to find the source flat index. Stride 0 = broadcast (read
12304                // the same input element repeatedly along that dim).
12305                let rank = out_dims.len();
12306                let total: usize = out_dims.iter().map(|&d| d as usize).product();
12307                let in_total = *in_total as usize;
12308                unsafe {
12309                    let inp = sl(*src, base, in_total);
12310                    let out = sl_mut(*dst, base, total);
12311                    let mut idx = vec![0usize; rank];
12312                    for o in 0..total {
12313                        let mut src_idx = 0usize;
12314                        for d in 0..rank {
12315                            src_idx += idx[d] * in_strides[d] as usize;
12316                        }
12317                        out[o] = inp[src_idx];
12318                        // Increment multi-index (innermost dim first).
12319                        for d in (0..rank).rev() {
12320                            idx[d] += 1;
12321                            if idx[d] < out_dims[d] as usize {
12322                                break;
12323                            }
12324                            idx[d] = 0;
12325                        }
12326                    }
12327                }
12328            }
12329
12330            // (Thunk::DenseSolveF64 / Thunk::ScanBackward had panic
12331            // stubs here as placeholders during the wire-up; both
12332            // are now reached by the real implementations earlier in
12333            // this same match — the stubs were dead code shadowed by
12334            // the specific-pattern arms above. Removed.)
12335            Thunk::CustomOp {
12336                kernel,
12337                inputs,
12338                output,
12339                attrs,
12340            } => {
12341                let (out_off, out_len, out_shape) = output;
12342                unsafe {
12343                    dispatch_custom_op(
12344                        &**kernel, inputs, *out_off, *out_len, out_shape, attrs, base,
12345                    );
12346                }
12347            }
12348        }
12349    }
12350}
12351
12352/// Griewank treeverse: process backward iterations `[t_lo..=t_hi]` (with
12353/// the carry entering iteration `t_lo` supplied as `anchor_carry`) by
12354/// recursive binary subdivision. Total work `O((t_hi-t_lo+1) · log)`,
12355/// auxiliary memory `O(log · carry_bytes)` for the recursion stack.
12356///
12357/// Compared to the iterative segment-cached scheme, this trades extra
12358/// recompute for less working memory — each level of recursion holds
12359/// one `cb`-sized intermediate carry on the stack but never the whole
12360/// segment at once. With K saved outer checkpoints, the outer driver
12361/// invokes this helper once per segment.
12362///
12363/// `process_iter(t, carry_at_t)` is the per-iteration leaf action: it
12364/// runs `body_vjp` at iteration `t` with the supplied carry, threads
12365/// `dcarry` backward, and (for ScanBackwardXs) writes `dxs[t]`.
12366#[allow(clippy::too_many_arguments)]
12367unsafe fn griewank_process_segment(
12368    t_lo: usize,
12369    t_hi: usize,
12370    anchor_carry: &[u8],
12371    cb: usize,
12372    fwd_sched: &ThunkSchedule,
12373    fwd_init: &[u8],
12374    fwd_carry_in_off: usize,
12375    fwd_output_off: usize,
12376    fwd_x_offs: &[usize],
12377    base: *mut u8,
12378    outer_xs_offs: &[(usize, u32)],
12379    fwd_buf: &mut Vec<u8>,
12380    leaf_threshold: usize,
12381    process_iter: &mut dyn FnMut(usize, &[u8]),
12382) {
12383    unsafe {
12384        let size = t_hi - t_lo + 1;
12385        if size == 1 {
12386            process_iter(t_lo, anchor_carry);
12387            return;
12388        }
12389        if size <= leaf_threshold {
12390            // Walk forward, cache each carry, run backward in reverse.
12391            let mut cache: Vec<u8> = Vec::with_capacity(size * cb);
12392            cache.extend_from_slice(anchor_carry);
12393            fwd_buf.copy_from_slice(fwd_init);
12394            std::ptr::copy_nonoverlapping(
12395                anchor_carry.as_ptr(),
12396                fwd_buf.as_mut_ptr().add(fwd_carry_in_off),
12397                cb,
12398            );
12399            for i in 1..size {
12400                let cur_iter = t_lo + i - 1;
12401                for (idx, fb_x_off) in fwd_x_offs.iter().enumerate() {
12402                    let (outer_xs_off, x_psb) = outer_xs_offs[idx];
12403                    let xb = x_psb as usize;
12404                    std::ptr::copy_nonoverlapping(
12405                        base.add(outer_xs_off + cur_iter * xb),
12406                        fwd_buf.as_mut_ptr().add(*fb_x_off),
12407                        xb,
12408                    );
12409                }
12410                execute_thunks(fwd_sched, fwd_buf);
12411                if fwd_output_off != fwd_carry_in_off {
12412                    fwd_buf.copy_within(fwd_output_off..fwd_output_off + cb, fwd_carry_in_off);
12413                }
12414                cache.extend_from_slice(&fwd_buf[fwd_carry_in_off..fwd_carry_in_off + cb]);
12415            }
12416            // Process backward.
12417            for t in (t_lo..=t_hi).rev() {
12418                let idx = t - t_lo;
12419                let carry = &cache[idx * cb..(idx + 1) * cb];
12420                process_iter(t, carry);
12421            }
12422            return;
12423        }
12424
12425        // Split: walk forward from anchor to compute carry entering `mid`.
12426        // (We need `mid - t_lo` body executions: one per iteration in
12427        // [t_lo, mid).)
12428        let mid = t_lo + size / 2;
12429        fwd_buf.copy_from_slice(fwd_init);
12430        std::ptr::copy_nonoverlapping(
12431            anchor_carry.as_ptr(),
12432            fwd_buf.as_mut_ptr().add(fwd_carry_in_off),
12433            cb,
12434        );
12435        for cur_iter in t_lo..mid {
12436            for (idx, fb_x_off) in fwd_x_offs.iter().enumerate() {
12437                let (outer_xs_off, x_psb) = outer_xs_offs[idx];
12438                let xb = x_psb as usize;
12439                std::ptr::copy_nonoverlapping(
12440                    base.add(outer_xs_off + cur_iter * xb),
12441                    fwd_buf.as_mut_ptr().add(*fb_x_off),
12442                    xb,
12443                );
12444            }
12445            execute_thunks(fwd_sched, fwd_buf);
12446            if fwd_output_off != fwd_carry_in_off {
12447                fwd_buf.copy_within(fwd_output_off..fwd_output_off + cb, fwd_carry_in_off);
12448            }
12449        }
12450        let mid_carry: Vec<u8> = fwd_buf[fwd_carry_in_off..fwd_carry_in_off + cb].to_vec();
12451
12452        // Right half first (higher t values processed first to match the
12453        // canonical reverse-mode iteration order: dcarry threads from
12454        // t=length-1 down to t=0).
12455        griewank_process_segment(
12456            mid,
12457            t_hi,
12458            &mid_carry,
12459            cb,
12460            fwd_sched,
12461            fwd_init,
12462            fwd_carry_in_off,
12463            fwd_output_off,
12464            fwd_x_offs,
12465            base,
12466            outer_xs_offs,
12467            fwd_buf,
12468            leaf_threshold,
12469            process_iter,
12470        );
12471        // Then left half with original anchor.
12472        griewank_process_segment(
12473            t_lo,
12474            mid - 1,
12475            anchor_carry,
12476            cb,
12477            fwd_sched,
12478            fwd_init,
12479            fwd_carry_in_off,
12480            fwd_output_off,
12481            fwd_x_offs,
12482            base,
12483            outer_xs_offs,
12484            fwd_buf,
12485            leaf_threshold,
12486            process_iter,
12487        );
12488    }
12489}
12490
12491/// Execute a batched 1D FFT in the f64 2N-real-block layout.
12492/// Each "row" is `2N` f64 elements: first `N` real, then `N` imag.
12493/// The `outer` rows are independent and processed sequentially.
12494///
12495/// Both forward and inverse use the same Cooley-Tukey radix-2 DIT
12496/// kernel — only the twiddle-factor sign differs. Power-of-2 only
12497/// (the IR builder rejects non-power-of-2 sizes at graph-build time).
12498/// Batched 1D FFT on the f64 2N-real-block layout. Public so other
12499/// backend crates can invoke this as a host fallback against a
12500/// unified-memory arena (e.g. rlx-metal: sync the command buffer,
12501/// pass the Metal `Buffer::contents()` pointer as `base`, restart the
12502/// command buffer). Self-contained — no rlx-cpu state required.
12503///
12504/// Safety: `base + src` and `base + dst` must be valid for the
12505/// `outer * 2 * n_complex * sizeof::<f64>()` byte range and stay
12506/// alive for the duration of the call.
12507pub unsafe fn execute_fft1d_f64(
12508    src: usize,
12509    dst: usize,
12510    outer: usize,
12511    n_complex: usize,
12512    inverse: bool,
12513    base: *mut u8,
12514) {
12515    let row_elems = 2 * n_complex;
12516    let mut re = vec![0f64; n_complex];
12517    let mut im = vec![0f64; n_complex];
12518    // Scratch reused across rows for the Bluestein path. Empty when
12519    // we're on the radix-2 fast path.
12520    let mut scratch = if n_complex.is_power_of_two() {
12521        BluesteinScratchF64::empty()
12522    } else {
12523        BluesteinScratchF64::build(n_complex, inverse)
12524    };
12525    for o in 0..outer {
12526        let row_offset = src + o * row_elems * std::mem::size_of::<f64>();
12527        let s = unsafe { sl_f64(row_offset, base, row_elems) };
12528        re.copy_from_slice(&s[..n_complex]);
12529        im.copy_from_slice(&s[n_complex..]);
12530        if n_complex.is_power_of_two() {
12531            fft_radix2_inplace_f64(&mut re, &mut im, inverse);
12532        } else {
12533            fft_bluestein_inplace_f64(&mut re, &mut im, inverse, &mut scratch);
12534        }
12535        let dst_offset = dst + o * row_elems * std::mem::size_of::<f64>();
12536        let d = unsafe { sl_mut_f64(dst_offset, base, row_elems) };
12537        d[..n_complex].copy_from_slice(&re);
12538        d[n_complex..].copy_from_slice(&im);
12539    }
12540}
12541
12542/// f32 counterpart of `execute_fft1d_f64`. Same 2N-real-block layout
12543/// (first N real, second N imag per row), same unnormalized
12544/// convention; only the element width differs. Twiddle factors are
12545/// computed in f64 and cast to f32 to keep large-N error closer to
12546/// the f64 path (the savings from f32 are in memory bandwidth, not in
12547/// twiddle precision).
12548/// Host-fallback entry for `Op::GatedDeltaNet` (Metal / unified memory).
12549/// When `state == 0`, uses a zero-initialized scratch state per batch item.
12550pub unsafe fn execute_gated_delta_net_f32(
12551    q: usize,
12552    k: usize,
12553    v: usize,
12554    g: usize,
12555    beta: usize,
12556    state: usize,
12557    dst: usize,
12558    batch: usize,
12559    seq: usize,
12560    heads: usize,
12561    state_size: usize,
12562    base: *mut u8,
12563) {
12564    use rayon::prelude::*;
12565
12566    #[derive(Copy, Clone)]
12567    struct ArenaPtr(usize);
12568    unsafe impl Send for ArenaPtr {}
12569    unsafe impl Sync for ArenaPtr {}
12570    impl ArenaPtr {
12571        #[inline]
12572        fn get(self) -> *mut u8 {
12573            self.0 as *mut u8
12574        }
12575    }
12576
12577    unsafe {
12578        let arena = ArenaPtr(base as usize);
12579        let (b, s, h, n) = (batch, seq, heads, state_size);
12580        let scale = 1.0f32 / (n as f32).sqrt();
12581        let use_external = state != 0;
12582        let mut owned_state = vec![0f32; h * n * n];
12583
12584        crate::pool::num_threads();
12585
12586        assert!(
12587            n <= crate::gdn::GDN_MAX_STATE,
12588            "GatedDeltaNet state_size={n} exceeds stack scratch ({})",
12589            crate::gdn::GDN_MAX_STATE
12590        );
12591
12592        let qs = sl(q, arena.get(), b * s * h * n);
12593        let ks = sl(k, arena.get(), b * s * h * n);
12594        let vs = sl(v, arena.get(), b * s * h * n);
12595        let gs = sl(g, arena.get(), b * s * h);
12596        let betas = sl(beta, arena.get(), b * s * h);
12597        let _out = sl_mut(dst, arena.get(), b * s * h * n);
12598        let hs_n = h * n;
12599
12600        let run_head = |bi: usize, hi: usize, s_mat: &mut [f32], sk: &mut [f32]| {
12601            for ti in 0..s {
12602                let qkv_step = bi * s * hs_n + ti * hs_n + hi * n;
12603                let gb_step = bi * s * h + ti * h + hi;
12604                let out_row = sl_mut(dst + qkv_step * std::mem::size_of::<f32>(), arena.get(), n);
12605                crate::gdn::gdn_step_blas(
12606                    s_mat,
12607                    &qs[qkv_step..qkv_step + n],
12608                    &ks[qkv_step..qkv_step + n],
12609                    &vs[qkv_step..qkv_step + n],
12610                    gs[gb_step],
12611                    betas[gb_step],
12612                    out_row,
12613                    sk,
12614                    n,
12615                    scale,
12616                );
12617            }
12618        };
12619
12620        // Prefill (seq>1, ephemeral state): time-outer, parallel over heads —
12621        // better occupancy than head-outer when prompt length dominates.
12622        if !use_external && s > 1 {
12623            for bi in 0..b {
12624                (0..h).into_par_iter().for_each(|hi| {
12625                    let mut sk_buf = [0f32; crate::gdn::GDN_MAX_STATE];
12626                    let sk = &mut sk_buf[..n];
12627                    let mut local_state =
12628                        [0f32; crate::gdn::GDN_MAX_STATE * crate::gdn::GDN_MAX_STATE];
12629                    let s_mat = &mut local_state[..n * n];
12630                    s_mat.fill(0.0);
12631                    run_head(bi, hi, s_mat, sk);
12632                });
12633            }
12634            return;
12635        }
12636
12637        if use_external {
12638            let state_bytes = state;
12639            (0..b * h).into_par_iter().for_each(|bhi| {
12640                let bi = bhi / h;
12641                let hi = bhi % h;
12642                let elem_off = bi * h * n * n + hi * n * n;
12643                let s_mat = sl_mut(
12644                    state_bytes + elem_off * std::mem::size_of::<f32>(),
12645                    arena.get(),
12646                    n * n,
12647                );
12648                let mut sk_buf = [0f32; crate::gdn::GDN_MAX_STATE];
12649                run_head(bi, hi, s_mat, &mut sk_buf[..n]);
12650            });
12651        } else {
12652            for bi in 0..b {
12653                owned_state.fill(0.0);
12654                owned_state
12655                    .par_chunks_mut(n * n)
12656                    .enumerate()
12657                    .for_each(|(hi, s_mat)| {
12658                        let mut sk_buf = [0f32; crate::gdn::GDN_MAX_STATE];
12659                        run_head(bi, hi, s_mat, &mut sk_buf[..n]);
12660                    });
12661            }
12662        }
12663    }
12664}
12665
12666/// Host-fallback: `Op::RmsNormBackwardInput` (GPU unified-memory / D2H arenas).
12667pub unsafe fn execute_rms_norm_backward_input_f32(
12668    x: usize,
12669    gamma: usize,
12670    beta: usize,
12671    dy: usize,
12672    dx: usize,
12673    rows: u32,
12674    h: u32,
12675    eps: f32,
12676    base: *mut u8,
12677) {
12678    let (rows, h) = (rows as usize, h as usize);
12679    let mut dg = vec![0f32; h];
12680    let mut db = vec![0f32; h];
12681    let xs = sl(x, base, rows * h);
12682    let dys = sl(dy, base, rows * h);
12683    let g = sl(gamma, base, h);
12684    let b = sl(beta, base, h);
12685    let out = sl_mut(dx, base, rows * h);
12686    for r in 0..rows {
12687        crate::training_bwd::rms_norm_backward_row(
12688            &xs[r * h..(r + 1) * h],
12689            g,
12690            b,
12691            &dys[r * h..(r + 1) * h],
12692            &mut out[r * h..(r + 1) * h],
12693            &mut dg,
12694            &mut db,
12695            eps,
12696        );
12697    }
12698}
12699
12700pub unsafe fn execute_rms_norm_backward_gamma_f32(
12701    x: usize,
12702    gamma: usize,
12703    beta: usize,
12704    dy: usize,
12705    dgamma: usize,
12706    rows: u32,
12707    h: u32,
12708    eps: f32,
12709    base: *mut u8,
12710) {
12711    let (rows, h) = (rows as usize, h as usize);
12712    let out = sl_mut(dgamma, base, h);
12713    out.fill(0.0);
12714    let mut dx = vec![0f32; h];
12715    let mut db = vec![0f32; h];
12716    let xs = sl(x, base, rows * h);
12717    let dys = sl(dy, base, rows * h);
12718    let g = sl(gamma, base, h);
12719    let b = sl(beta, base, h);
12720    for r in 0..rows {
12721        crate::training_bwd::rms_norm_backward_row(
12722            &xs[r * h..(r + 1) * h],
12723            g,
12724            b,
12725            &dys[r * h..(r + 1) * h],
12726            &mut dx,
12727            out,
12728            &mut db,
12729            eps,
12730        );
12731    }
12732}
12733
12734pub unsafe fn execute_rms_norm_backward_beta_f32(
12735    x: usize,
12736    gamma: usize,
12737    beta: usize,
12738    dy: usize,
12739    dbeta: usize,
12740    rows: u32,
12741    h: u32,
12742    eps: f32,
12743    base: *mut u8,
12744) {
12745    let (rows, h) = (rows as usize, h as usize);
12746    let out = sl_mut(dbeta, base, h);
12747    out.fill(0.0);
12748    let mut dx = vec![0f32; h];
12749    let mut dg = vec![0f32; h];
12750    let xs = sl(x, base, rows * h);
12751    let dys = sl(dy, base, rows * h);
12752    let g = sl(gamma, base, h);
12753    let b = sl(beta, base, h);
12754    for r in 0..rows {
12755        crate::training_bwd::rms_norm_backward_row(
12756            &xs[r * h..(r + 1) * h],
12757            g,
12758            b,
12759            &dys[r * h..(r + 1) * h],
12760            &mut dx,
12761            &mut dg,
12762            out,
12763            eps,
12764        );
12765    }
12766}
12767
12768pub unsafe fn execute_rope_backward_f32(
12769    dy: usize,
12770    cos: usize,
12771    sin: usize,
12772    dx: usize,
12773    batch: u32,
12774    seq: u32,
12775    hidden: u32,
12776    head_dim: u32,
12777    n_rot: u32,
12778    cos_len: u32,
12779    base: *mut u8,
12780) {
12781    let (b, s, hs, dh, nr, cl) = (
12782        batch as usize,
12783        seq as usize,
12784        hidden as usize,
12785        head_dim as usize,
12786        n_rot as usize,
12787        cos_len as usize,
12788    );
12789    let nh = hs / dh;
12790    let tab_half = dh / 2;
12791    let dys = sl(dy, base, b * s * hs);
12792    let cos_tab = sl(cos, base, cl);
12793    let sin_tab = sl(sin, base, cl);
12794    let out = sl_mut(dx, base, b * s * hs);
12795    for bi in 0..b {
12796        for si in 0..s {
12797            let tab_off = si.saturating_mul(tab_half) % cl.max(1);
12798            let cp = &cos_tab[tab_off..tab_off + tab_half.min(cl)];
12799            let sp = &sin_tab[tab_off..tab_off + tab_half.min(cl)];
12800            for hi in 0..nh {
12801                let base_idx = bi * s * hs + si * hs + hi * dh;
12802                crate::training_bwd::rope_backward_row(
12803                    &dys[base_idx..base_idx + dh],
12804                    cp,
12805                    sp,
12806                    &mut out[base_idx..base_idx + dh],
12807                    dh,
12808                    nr,
12809                );
12810            }
12811        }
12812    }
12813}
12814
12815pub unsafe fn execute_cumsum_backward_f32(
12816    dy: usize,
12817    dx: usize,
12818    rows: u32,
12819    cols: u32,
12820    exclusive: bool,
12821    base: *mut u8,
12822) {
12823    let (rows, cols) = (rows as usize, cols as usize);
12824    let dys = sl(dy, base, rows * cols);
12825    let out = sl_mut(dx, base, rows * cols);
12826    for r in 0..rows {
12827        crate::training_bwd::cumsum_backward_row(
12828            &dys[r * cols..(r + 1) * cols],
12829            &mut out[r * cols..(r + 1) * cols],
12830            exclusive,
12831        );
12832    }
12833}
12834
12835pub unsafe fn execute_gather_backward_f32(
12836    dy: usize,
12837    indices: usize,
12838    dst: usize,
12839    outer: u32,
12840    axis_dim: u32,
12841    num_idx: u32,
12842    trailing: u32,
12843    base: *mut u8,
12844) {
12845    let (outer, axis_dim, num_idx, trailing) = (
12846        outer as usize,
12847        axis_dim as usize,
12848        num_idx as usize,
12849        trailing as usize,
12850    );
12851    let out = sl_mut(dst, base, outer * axis_dim * trailing);
12852    out.fill(0.0);
12853    crate::training_bwd::gather_axis_backward(
12854        sl(dy, base, outer * num_idx * trailing),
12855        sl(indices, base, num_idx),
12856        out,
12857        outer,
12858        axis_dim,
12859        num_idx,
12860        trailing,
12861    );
12862}
12863
12864/// Host-fallback entry for GGUF `Op::DequantMatMul` (Metal unified memory).
12865pub unsafe fn execute_dequant_matmul_gguf_f32(
12866    x: usize,
12867    w_q: usize,
12868    dst: usize,
12869    m: usize,
12870    k: usize,
12871    n: usize,
12872    scheme: rlx_ir::quant::QuantScheme,
12873    base: *mut u8,
12874) {
12875    unsafe {
12876        let block_bytes = scheme.gguf_block_bytes() as usize;
12877        let block_elems = scheme.gguf_block_size() as usize;
12878        let total_bytes = (k * n) / block_elems * block_bytes;
12879        let xs = sl(x, base, m * k);
12880        let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, total_bytes);
12881        let out = sl_mut(dst, base, m * n);
12882        crate::gguf_matmul::gguf_matmul_bt(xs, w_bytes, out, m, k, n, scheme);
12883    }
12884}
12885
12886/// Host-fallback entry for GGUF `Op::DequantGroupedMatMul` (MoE expert stack).
12887pub unsafe fn execute_dequant_grouped_matmul_gguf_f32(
12888    input: usize,
12889    w_q: usize,
12890    expert_idx: usize,
12891    dst: usize,
12892    m: usize,
12893    k: usize,
12894    n: usize,
12895    num_experts: usize,
12896    scheme: rlx_ir::quant::QuantScheme,
12897    base: *mut u8,
12898) {
12899    unsafe {
12900        let block_bytes = scheme.gguf_block_bytes() as usize;
12901        let block_elems = scheme.gguf_block_size() as usize;
12902        let slab_bytes = (k * n) / block_elems * block_bytes;
12903        let xs = sl(input, base, m * k);
12904        let w_bytes =
12905            std::slice::from_raw_parts(base.add(w_q) as *const u8, num_experts * slab_bytes);
12906        let ids = sl(expert_idx, base, m);
12907        let out = sl_mut(dst, base, m * n);
12908        crate::gguf_matmul::gguf_grouped_matmul_bt(
12909            xs,
12910            w_bytes,
12911            ids,
12912            out,
12913            m,
12914            k,
12915            n,
12916            num_experts,
12917            scheme,
12918        );
12919    }
12920}
12921
12922/// Host-fallback entry for Int4 `Op::DequantMatMul` (Metal unified memory).
12923pub unsafe fn execute_dequant_matmul_int4_f32(
12924    x: usize,
12925    w_q: usize,
12926    scale: usize,
12927    zp: usize,
12928    dst: usize,
12929    m: usize,
12930    k: usize,
12931    n: usize,
12932    block_size: u32,
12933    is_asymmetric: bool,
12934    base: *mut u8,
12935) {
12936    let bs = block_size as usize;
12937    let n_blocks = k.div_ceil(bs);
12938    unsafe {
12939        let xs = sl(x, base, m * k);
12940        let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, (k * n).div_ceil(2));
12941        let scales = sl(scale, base, n_blocks * n);
12942        let zps = if is_asymmetric {
12943            sl(zp, base, n_blocks * n)
12944        } else {
12945            &[][..]
12946        };
12947        let out = sl_mut(dst, base, m * n);
12948        dequant_matmul_int4(xs, w_bytes, scales, zps, out, m, k, n, bs, is_asymmetric);
12949    }
12950}
12951
12952/// Host-fallback entry for FP8 `Op::DequantMatMul` (Metal unified memory).
12953pub unsafe fn execute_dequant_matmul_fp8_f32(
12954    x: usize,
12955    w_q: usize,
12956    scale: usize,
12957    dst: usize,
12958    m: usize,
12959    k: usize,
12960    n: usize,
12961    e5m2: bool,
12962    base: *mut u8,
12963) {
12964    unsafe {
12965        let xs = sl(x, base, m * k);
12966        let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, k * n);
12967        let scales = sl(scale, base, n);
12968        let out = sl_mut(dst, base, m * n);
12969        dequant_matmul_fp8(xs, w_bytes, scales, out, m, k, n, e5m2);
12970    }
12971}
12972
12973/// Host-fallback entry for NVFP4 `Op::DequantMatMul` (Metal unified memory).
12974pub unsafe fn execute_dequant_matmul_nvfp4_f32(
12975    x: usize,
12976    w_q: usize,
12977    scale: usize,
12978    global_scale: usize,
12979    dst: usize,
12980    m: usize,
12981    k: usize,
12982    n: usize,
12983    base: *mut u8,
12984) {
12985    let n_scale = k.div_ceil(rlx_ir::NVFP4_GROUP_SIZE) * n;
12986    unsafe {
12987        let xs = sl(x, base, m * k);
12988        let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, (k * n).div_ceil(2));
12989        let scale_bytes = std::slice::from_raw_parts(base.add(scale) as *const u8, n_scale);
12990        let gs = sl(global_scale, base, 1)[0];
12991        let out = sl_mut(dst, base, m * n);
12992        dequant_matmul_nvfp4(xs, w_bytes, scale_bytes, gs, out, m, k, n);
12993    }
12994}
12995
12996/// Host-fallback entry for f16 `Op::GatedDeltaNet` tensors on Metal.
12997pub unsafe fn execute_gated_delta_net_f16(
12998    q: usize,
12999    k: usize,
13000    v: usize,
13001    g: usize,
13002    beta: usize,
13003    state: usize,
13004    dst: usize,
13005    batch: usize,
13006    seq: usize,
13007    heads: usize,
13008    state_size: usize,
13009    base: *mut u8,
13010) {
13011    use half::f16;
13012    unsafe {
13013        let read_f16 = |off: usize, len: usize| -> Vec<f32> {
13014            let raw = std::slice::from_raw_parts(base.add(off) as *const u8, len * 2);
13015            raw.chunks_exact(2)
13016                .map(|c| f16::from_le_bytes([c[0], c[1]]).to_f32())
13017                .collect()
13018        };
13019        let write_f16 = |off: usize, data: &[f32]| {
13020            let out = std::slice::from_raw_parts_mut(base.add(off), data.len() * 2);
13021            for (i, &v) in data.iter().enumerate() {
13022                let le = f16::from_f32(v).to_le_bytes();
13023                out[i * 2] = le[0];
13024                out[i * 2 + 1] = le[1];
13025            }
13026        };
13027
13028        let (b, s, h, n) = (batch, seq, heads, state_size);
13029        let q_f = read_f16(q, b * s * h * n);
13030        let k_f = read_f16(k, b * s * h * n);
13031        let v_f = read_f16(v, b * s * h * n);
13032        let g_f = read_f16(g, b * s * h);
13033        let b_f = read_f16(beta, b * s * h);
13034        let mut state_f = if state != 0 {
13035            read_f16(state, b * h * n * n)
13036        } else {
13037            vec![0f32; b * h * n * n]
13038        };
13039        let mut out_f = vec![0f32; b * s * h * n];
13040        let scale = 1.0f32 / (n as f32).sqrt();
13041        let mut sk_buf = vec![0f32; n];
13042        let mut owned_state = vec![0f32; h * n * n];
13043
13044        for bi in 0..b {
13045            let state_slice: &mut [f32] = if state != 0 {
13046                let start = bi * h * n * n;
13047                &mut state_f[start..start + h * n * n]
13048            } else {
13049                owned_state.fill(0.0);
13050                &mut owned_state
13051            };
13052
13053            for ti in 0..s {
13054                let qkv_step_base = bi * s * h * n + ti * h * n;
13055                let gb_step_base = bi * s * h + ti * h;
13056
13057                for hi in 0..h {
13058                    let q_row = &q_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
13059                    let k_row = &k_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
13060                    let v_row = &v_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
13061                    let g_t = g_f[gb_step_base + hi];
13062                    let beta_t = b_f[gb_step_base + hi];
13063
13064                    let s_base = hi * n * n;
13065                    let s_mat = &mut state_slice[s_base..s_base + n * n];
13066
13067                    let g_exp = g_t.exp();
13068                    for st in s_mat.iter_mut() {
13069                        *st *= g_exp;
13070                    }
13071
13072                    for j in 0..n {
13073                        let mut acc = 0f32;
13074                        for i in 0..n {
13075                            acc += s_mat[i * n + j] * k_row[i];
13076                        }
13077                        sk_buf[j] = acc;
13078                    }
13079
13080                    for j in 0..n {
13081                        sk_buf[j] = (v_row[j] - sk_buf[j]) * beta_t;
13082                    }
13083
13084                    for i in 0..n {
13085                        let ki = k_row[i];
13086                        for j in 0..n {
13087                            s_mat[i * n + j] += ki * sk_buf[j];
13088                        }
13089                    }
13090
13091                    let out_row = &mut out_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
13092                    for j in 0..n {
13093                        let mut acc = 0f32;
13094                        for i in 0..n {
13095                            acc += s_mat[i * n + j] * q_row[i];
13096                        }
13097                        out_row[j] = acc * scale;
13098                    }
13099                }
13100            }
13101        }
13102
13103        write_f16(dst, &out_f);
13104        if state != 0 {
13105            write_f16(state, &state_f);
13106        }
13107    }
13108}
13109
13110/// Host fallback for NCHW group norm (Metal unified-memory arena).
13111pub unsafe fn execute_group_norm_nchw_f32(
13112    src: usize,
13113    g: usize,
13114    b: usize,
13115    dst: usize,
13116    n: usize,
13117    c: usize,
13118    h: usize,
13119    w: usize,
13120    num_groups: usize,
13121    eps: f32,
13122    base: *mut u8,
13123) {
13124    let plane = c * h * w;
13125    for ni in 0..n {
13126        let input = unsafe { sl(src + ni * plane * std::mem::size_of::<f32>(), base, plane) };
13127        let gamma = unsafe { sl(g, base, c) };
13128        let beta = unsafe { sl(b, base, c) };
13129        let output = unsafe { sl_mut(dst + ni * plane * std::mem::size_of::<f32>(), base, plane) };
13130        crate::kernels::group_norm_nchw(input, gamma, beta, output, 1, c, h, w, num_groups, eps);
13131    }
13132}
13133
13134/// Host fallback for NCHW LayerNorm2d (SAM / candle semantics).
13135pub unsafe fn execute_layer_norm2d_nchw_f32(
13136    src: usize,
13137    g: usize,
13138    b: usize,
13139    dst: usize,
13140    n: usize,
13141    c: usize,
13142    h: usize,
13143    w: usize,
13144    eps: f32,
13145    base: *mut u8,
13146) {
13147    let plane = c * h * w;
13148    unsafe {
13149        let input = sl(src, base, n * plane);
13150        let gamma = sl(g, base, c);
13151        let beta = sl(b, base, c);
13152        let output = sl_mut(dst, base, n * plane);
13153        crate::kernels::layer_norm2d_nchw(input, gamma, beta, output, n, c, h, w, eps);
13154    }
13155}
13156
13157/// Host fallback for NCHW ConvTranspose2d.
13158pub unsafe fn execute_conv_transpose2d_nchw_f32(
13159    src: usize,
13160    weight: usize,
13161    dst: usize,
13162    n: usize,
13163    c_in: usize,
13164    h: usize,
13165    w_in: usize,
13166    c_out: usize,
13167    h_out: usize,
13168    w_out: usize,
13169    kh: usize,
13170    kw: usize,
13171    sh: usize,
13172    sw: usize,
13173    ph: usize,
13174    pw: usize,
13175    dh: usize,
13176    dw: usize,
13177    groups: usize,
13178    base: *mut u8,
13179) {
13180    let in_elems = n * c_in * h * w_in;
13181    let w_elems = c_in * (c_out / groups) * kh * kw;
13182    let out_elems = n * c_out * h_out * w_out;
13183    unsafe {
13184        let input = sl(src, base, in_elems);
13185        let wt = sl(weight, base, w_elems);
13186        let output = sl_mut(dst, base, out_elems);
13187        crate::kernels::conv_transpose2d_nchw(
13188            input, wt, output, n, c_in, h, w_in, c_out, h_out, w_out, kh, kw, sh, sw, ph, pw, dh,
13189            dw, groups,
13190        );
13191    }
13192}
13193
13194/// Host fallback for nearest 2× upsample on NCHW.
13195pub unsafe fn execute_resize_nearest_2x_f32(
13196    src: usize,
13197    dst: usize,
13198    n: usize,
13199    c: usize,
13200    h: usize,
13201    w: usize,
13202    base: *mut u8,
13203) {
13204    let in_plane = c * h * w;
13205    let out_plane = c * h * 2 * w * 2;
13206    for ni in 0..n {
13207        let input = unsafe {
13208            sl(
13209                src + ni * in_plane * std::mem::size_of::<f32>(),
13210                base,
13211                in_plane,
13212            )
13213        };
13214        let output = unsafe {
13215            sl_mut(
13216                dst + ni * out_plane * std::mem::size_of::<f32>(),
13217                base,
13218                out_plane,
13219            )
13220        };
13221        crate::kernels::resize_nearest_2x_nchw(input, output, c, h, w);
13222    }
13223}
13224
13225/// Host axial 2-D RoPE for Metal (and other) fallbacks on unified memory.
13226pub unsafe fn execute_axial_rope2d_f32(
13227    src: usize,
13228    dst: usize,
13229    batch: usize,
13230    seq: usize,
13231    hidden: usize,
13232    end_x: usize,
13233    end_y: usize,
13234    head_dim: usize,
13235    num_heads: usize,
13236    theta: f32,
13237    repeat_factor: usize,
13238    base: *mut u8,
13239) {
13240    let plane = seq * hidden;
13241    let plane_bytes = plane * std::mem::size_of::<f32>();
13242    for bi in 0..batch {
13243        let in_off = src + bi * plane_bytes;
13244        let input = unsafe { sl(in_off, base, plane) };
13245        let rotated = rlx_ir::ops::axial_rope2d::apply_axial_rope2d(
13246            input,
13247            num_heads,
13248            seq,
13249            head_dim,
13250            end_x,
13251            end_y,
13252            theta,
13253            repeat_factor,
13254        );
13255        let out_off = dst + bi * plane_bytes;
13256        let output = unsafe { sl_mut(out_off, base, plane) };
13257        output.copy_from_slice(&rotated);
13258    }
13259}
13260
13261/// f32 mirror of `execute_fft1d_f64`. Same public-host-fallback role.
13262pub unsafe fn execute_fft1d_f32(
13263    src: usize,
13264    dst: usize,
13265    outer: usize,
13266    n_complex: usize,
13267    inverse: bool,
13268    base: *mut u8,
13269) {
13270    let row_elems = 2 * n_complex;
13271    let mut re = vec![0f32; n_complex];
13272    let mut im = vec![0f32; n_complex];
13273    let mut scratch = if n_complex.is_power_of_two() {
13274        BluesteinScratchF32::empty()
13275    } else {
13276        BluesteinScratchF32::build(n_complex, inverse)
13277    };
13278    for o in 0..outer {
13279        let row_offset = src + o * row_elems * std::mem::size_of::<f32>();
13280        let s = unsafe { sl(row_offset, base, row_elems) };
13281        re.copy_from_slice(&s[..n_complex]);
13282        im.copy_from_slice(&s[n_complex..]);
13283        if n_complex.is_power_of_two() {
13284            fft_radix2_inplace_f32(&mut re, &mut im, inverse);
13285        } else {
13286            fft_bluestein_inplace_f32(&mut re, &mut im, inverse, &mut scratch);
13287        }
13288        let dst_offset = dst + o * row_elems * std::mem::size_of::<f32>();
13289        let d = unsafe { sl_mut(dst_offset, base, row_elems) };
13290        d[..n_complex].copy_from_slice(&re);
13291        d[n_complex..].copy_from_slice(&im);
13292    }
13293}
13294
13295/// f32 in-place radix-2 DIT Cooley-Tukey. Structurally identical to
13296/// the f64 path; twiddle recurrence is kept in f64 so accumulated
13297/// rotation drift doesn't dominate the per-stage error budget at
13298/// larger N.
13299fn fft_radix2_inplace_f32(re: &mut [f32], im: &mut [f32], inverse: bool) {
13300    let n = re.len();
13301    debug_assert_eq!(im.len(), n);
13302    debug_assert!(
13303        n.is_power_of_two(),
13304        "fft_radix2_f32: n={n} must be a power of two"
13305    );
13306    if n <= 1 {
13307        return;
13308    }
13309
13310    let mut j = 0usize;
13311    for i in 1..n {
13312        let mut bit = n >> 1;
13313        while j & bit != 0 {
13314            j ^= bit;
13315            bit >>= 1;
13316        }
13317        j ^= bit;
13318        if i < j {
13319            re.swap(i, j);
13320            im.swap(i, j);
13321        }
13322    }
13323
13324    let sign = if inverse { 1.0_f64 } else { -1.0_f64 };
13325    let mut len = 2usize;
13326    while len <= n {
13327        let half = len / 2;
13328        let theta = sign * 2.0 * std::f64::consts::PI / (len as f64);
13329        let w_re_step = theta.cos();
13330        let w_im_step = theta.sin();
13331        let mut i = 0usize;
13332        while i < n {
13333            let mut wre = 1.0_f64;
13334            let mut wim = 0.0_f64;
13335            for k in 0..half {
13336                let wre_f = wre as f32;
13337                let wim_f = wim as f32;
13338                let t_re = wre_f * re[i + k + half] - wim_f * im[i + k + half];
13339                let t_im = wre_f * im[i + k + half] + wim_f * re[i + k + half];
13340                let u_re = re[i + k];
13341                let u_im = im[i + k];
13342                re[i + k] = u_re + t_re;
13343                im[i + k] = u_im + t_im;
13344                re[i + k + half] = u_re - t_re;
13345                im[i + k + half] = u_im - t_im;
13346                let new_wre = wre * w_re_step - wim * w_im_step;
13347                let new_wim = wre * w_im_step + wim * w_re_step;
13348                wre = new_wre;
13349                wim = new_wim;
13350            }
13351            i += len;
13352        }
13353        len <<= 1;
13354    }
13355}
13356
13357/// In-place radix-2 DIT Cooley-Tukey FFT on split (real, imag) f64
13358/// arrays. `n = re.len() = im.len()` must be a power of two. Forward
13359/// uses ω = exp(-2πi/n); inverse uses ω = exp(+2πi/n) (no 1/N scale).
13360fn fft_radix2_inplace_f64(re: &mut [f64], im: &mut [f64], inverse: bool) {
13361    let n = re.len();
13362    debug_assert_eq!(im.len(), n);
13363    debug_assert!(
13364        n.is_power_of_two(),
13365        "fft_radix2: n={n} must be a power of two"
13366    );
13367    if n <= 1 {
13368        return;
13369    }
13370
13371    // Bit-reverse permutation.
13372    let mut j = 0usize;
13373    for i in 1..n {
13374        let mut bit = n >> 1;
13375        while j & bit != 0 {
13376            j ^= bit;
13377            bit >>= 1;
13378        }
13379        j ^= bit;
13380        if i < j {
13381            re.swap(i, j);
13382            im.swap(i, j);
13383        }
13384    }
13385
13386    // Cooley-Tukey butterflies: ω_len = exp(±2πi/len).
13387    let sign = if inverse { 1.0 } else { -1.0 };
13388    let mut len = 2usize;
13389    while len <= n {
13390        let half = len / 2;
13391        let theta = sign * 2.0 * std::f64::consts::PI / (len as f64);
13392        let w_re_step = theta.cos();
13393        let w_im_step = theta.sin();
13394        let mut i = 0usize;
13395        while i < n {
13396            // Twiddle starts at 1+0i for each segment.
13397            let mut wre = 1.0_f64;
13398            let mut wim = 0.0_f64;
13399            for k in 0..half {
13400                let t_re = wre * re[i + k + half] - wim * im[i + k + half];
13401                let t_im = wre * im[i + k + half] + wim * re[i + k + half];
13402                let u_re = re[i + k];
13403                let u_im = im[i + k];
13404                re[i + k] = u_re + t_re;
13405                im[i + k] = u_im + t_im;
13406                re[i + k + half] = u_re - t_re;
13407                im[i + k + half] = u_im - t_im;
13408                let new_wre = wre * w_re_step - wim * w_im_step;
13409                let new_wim = wre * w_im_step + wim * w_re_step;
13410                wre = new_wre;
13411                wim = new_wim;
13412            }
13413            i += len;
13414        }
13415        len <<= 1;
13416    }
13417}
13418
13419/// Pre-computed chirp + filter-spectrum for one (N, direction) pair.
13420/// Built once per call to `execute_fft1d_f64` and reused across rows
13421/// when `outer > 1` — the chirp and FFT(b) don't depend on the input.
13422struct BluesteinScratchF64 {
13423    /// Power-of-two convolution length, ≥ 2N - 1.
13424    m: usize,
13425    /// `w[k] = exp(sign · iπ · k² / N)` for k=0..N, where sign matches
13426    /// the requested direction. Forward chirp on the way in, output
13427    /// chirp on the way out.
13428    w_re: Vec<f64>,
13429    w_im: Vec<f64>,
13430    /// FFT of the embedded filter `b[k] = conj(w[|k|])` in length-M.
13431    /// Doesn't depend on the input — precomputed once.
13432    bf_re: Vec<f64>,
13433    bf_im: Vec<f64>,
13434    /// Workspace reused per row (avoids per-row allocation).
13435    ar: Vec<f64>,
13436    ai: Vec<f64>,
13437}
13438
13439impl BluesteinScratchF64 {
13440    fn empty() -> Self {
13441        Self {
13442            m: 0,
13443            w_re: Vec::new(),
13444            w_im: Vec::new(),
13445            bf_re: Vec::new(),
13446            bf_im: Vec::new(),
13447            ar: Vec::new(),
13448            ai: Vec::new(),
13449        }
13450    }
13451
13452    fn build(n: usize, inverse: bool) -> Self {
13453        // M = next power of two ≥ 2N - 1 keeps the inner FFT on the
13454        // fast radix-2 path. For N=1 fall back to M=1 (no-op convolution).
13455        let m = if n <= 1 {
13456            1
13457        } else {
13458            (2 * n - 1).next_power_of_two()
13459        };
13460
13461        // Chirp arg reduced via k² mod 2N — without this, large N
13462        // bleeds precision into the trig call (n² grows quadratically).
13463        let mod_2n = (2 * n) as u64;
13464        let sign = if inverse { 1.0_f64 } else { -1.0_f64 };
13465        let mut w_re = vec![0.0_f64; n];
13466        let mut w_im = vec![0.0_f64; n];
13467        for k in 0..n {
13468            let k2 = (k as u64).wrapping_mul(k as u64) % mod_2n;
13469            let theta = sign * std::f64::consts::PI * (k2 as f64) / (n as f64);
13470            w_re[k] = theta.cos();
13471            w_im[k] = theta.sin();
13472        }
13473
13474        // Embed b[k] = conj(w[|k|]) into length M with the negative
13475        // indices wrapping to the tail: b[-j] → B[M-j] for j=1..N-1.
13476        let mut bf_re = vec![0.0_f64; m];
13477        let mut bf_im = vec![0.0_f64; m];
13478        if n > 0 {
13479            bf_re[0] = w_re[0];
13480            bf_im[0] = -w_im[0];
13481            for k in 1..n {
13482                bf_re[k] = w_re[k];
13483                bf_im[k] = -w_im[k];
13484                bf_re[m - k] = w_re[k];
13485                bf_im[m - k] = -w_im[k];
13486            }
13487        }
13488        if m > 1 {
13489            fft_radix2_inplace_f64(&mut bf_re, &mut bf_im, false);
13490        }
13491
13492        Self {
13493            m,
13494            w_re,
13495            w_im,
13496            bf_re,
13497            bf_im,
13498            ar: vec![0.0_f64; m],
13499            ai: vec![0.0_f64; m],
13500        }
13501    }
13502}
13503
13504/// Bluestein (chirp-z) FFT for arbitrary N. Identity used:
13505///   `n·k = (n² + k² - (k-n)²) / 2`
13506/// which lets the DFT be written as a linear convolution sandwiched
13507/// between two chirp multiplies:
13508///   `X[k] = w[k] · ((x·w) ⊛ conj(w))[k]`   where `w[n] = exp(±iπ·n²/N)`.
13509/// The convolution is computed via a length-M radix-2 FFT (M ≥ 2N-1).
13510/// Both directions stay unnormalized to match the radix-2 path, so the
13511/// chain rule keeps working without scaling.
13512fn fft_bluestein_inplace_f64(
13513    re: &mut [f64],
13514    im: &mut [f64],
13515    _inverse: bool,
13516    s: &mut BluesteinScratchF64,
13517) {
13518    let n = re.len();
13519    debug_assert_eq!(im.len(), n);
13520    debug_assert_eq!(s.w_re.len(), n);
13521    if n <= 1 {
13522        return;
13523    }
13524    let m = s.m;
13525
13526    // Pre-chirp: a[k] = x[k] · w[k], zero-padded to M.
13527    for k in 0..m {
13528        s.ar[k] = 0.0;
13529        s.ai[k] = 0.0;
13530    }
13531    for k in 0..n {
13532        s.ar[k] = re[k] * s.w_re[k] - im[k] * s.w_im[k];
13533        s.ai[k] = re[k] * s.w_im[k] + im[k] * s.w_re[k];
13534    }
13535
13536    // Length-M forward FFT of the padded chirped input.
13537    fft_radix2_inplace_f64(&mut s.ar, &mut s.ai, false);
13538
13539    // Pointwise product with FFT(b). Stored back into (ar, ai).
13540    for k in 0..m {
13541        let ar = s.ar[k];
13542        let ai = s.ai[k];
13543        let br = s.bf_re[k];
13544        let bi = s.bf_im[k];
13545        s.ar[k] = ar * br - ai * bi;
13546        s.ai[k] = ar * bi + ai * br;
13547    }
13548
13549    // Inverse FFT — radix-2 here is the unnormalized inverse, so we
13550    // divide by M to recover the true circular convolution.
13551    fft_radix2_inplace_f64(&mut s.ar, &mut s.ai, true);
13552    let inv_m = 1.0 / (m as f64);
13553
13554    // Post-chirp: X[k] = w[k] · Y[k] / M for k = 0..N.
13555    for k in 0..n {
13556        let yr = s.ar[k] * inv_m;
13557        let yi = s.ai[k] * inv_m;
13558        re[k] = yr * s.w_re[k] - yi * s.w_im[k];
13559        im[k] = yr * s.w_im[k] + yi * s.w_re[k];
13560    }
13561}
13562
13563/// f32 mirror of `BluesteinScratchF64`. Chirp is computed in f64 for
13564/// precision (same justification as the radix-2 f32 path: twiddles in
13565/// f64, butterflies in f32). The actual conv buffers are f32.
13566struct BluesteinScratchF32 {
13567    m: usize,
13568    w_re: Vec<f32>,
13569    w_im: Vec<f32>,
13570    bf_re: Vec<f32>,
13571    bf_im: Vec<f32>,
13572    ar: Vec<f32>,
13573    ai: Vec<f32>,
13574}
13575
13576impl BluesteinScratchF32 {
13577    fn empty() -> Self {
13578        Self {
13579            m: 0,
13580            w_re: Vec::new(),
13581            w_im: Vec::new(),
13582            bf_re: Vec::new(),
13583            bf_im: Vec::new(),
13584            ar: Vec::new(),
13585            ai: Vec::new(),
13586        }
13587    }
13588
13589    fn build(n: usize, inverse: bool) -> Self {
13590        let m = if n <= 1 {
13591            1
13592        } else {
13593            (2 * n - 1).next_power_of_two()
13594        };
13595
13596        let mod_2n = (2 * n) as u64;
13597        let sign = if inverse { 1.0_f64 } else { -1.0_f64 };
13598        let mut w_re = vec![0.0_f32; n];
13599        let mut w_im = vec![0.0_f32; n];
13600        for k in 0..n {
13601            let k2 = (k as u64).wrapping_mul(k as u64) % mod_2n;
13602            let theta = sign * std::f64::consts::PI * (k2 as f64) / (n as f64);
13603            w_re[k] = theta.cos() as f32;
13604            w_im[k] = theta.sin() as f32;
13605        }
13606
13607        let mut bf_re = vec![0.0_f32; m];
13608        let mut bf_im = vec![0.0_f32; m];
13609        if n > 0 {
13610            bf_re[0] = w_re[0];
13611            bf_im[0] = -w_im[0];
13612            for k in 1..n {
13613                bf_re[k] = w_re[k];
13614                bf_im[k] = -w_im[k];
13615                bf_re[m - k] = w_re[k];
13616                bf_im[m - k] = -w_im[k];
13617            }
13618        }
13619        if m > 1 {
13620            fft_radix2_inplace_f32(&mut bf_re, &mut bf_im, false);
13621        }
13622
13623        Self {
13624            m,
13625            w_re,
13626            w_im,
13627            bf_re,
13628            bf_im,
13629            ar: vec![0.0_f32; m],
13630            ai: vec![0.0_f32; m],
13631        }
13632    }
13633}
13634
13635fn fft_bluestein_inplace_f32(
13636    re: &mut [f32],
13637    im: &mut [f32],
13638    _inverse: bool,
13639    s: &mut BluesteinScratchF32,
13640) {
13641    let n = re.len();
13642    debug_assert_eq!(im.len(), n);
13643    debug_assert_eq!(s.w_re.len(), n);
13644    if n <= 1 {
13645        return;
13646    }
13647    let m = s.m;
13648
13649    for k in 0..m {
13650        s.ar[k] = 0.0;
13651        s.ai[k] = 0.0;
13652    }
13653    for k in 0..n {
13654        s.ar[k] = re[k] * s.w_re[k] - im[k] * s.w_im[k];
13655        s.ai[k] = re[k] * s.w_im[k] + im[k] * s.w_re[k];
13656    }
13657
13658    fft_radix2_inplace_f32(&mut s.ar, &mut s.ai, false);
13659
13660    for k in 0..m {
13661        let ar = s.ar[k];
13662        let ai = s.ai[k];
13663        let br = s.bf_re[k];
13664        let bi = s.bf_im[k];
13665        s.ar[k] = ar * br - ai * bi;
13666        s.ai[k] = ar * bi + ai * br;
13667    }
13668
13669    fft_radix2_inplace_f32(&mut s.ar, &mut s.ai, true);
13670    let inv_m = 1.0_f32 / (m as f32);
13671
13672    for k in 0..n {
13673        let yr = s.ar[k] * inv_m;
13674        let yi = s.ai[k] * inv_m;
13675        re[k] = yr * s.w_re[k] - yi * s.w_im[k];
13676        im[k] = yr * s.w_im[k] + yi * s.w_re[k];
13677    }
13678}
13679
13680/// Shared dispatch path for `Thunk::CustomOp`. Builds a typed
13681/// [`CpuTensorRef`] for each input *at that input's declared dtype*
13682/// (so a sparse-LU op with mixed F64/I32 inputs gets the right
13683/// typed slices) and a [`CpuTensorMut`] for the output, then calls
13684/// the kernel's single `execute` method.
13685unsafe fn dispatch_custom_op(
13686    kernel: &dyn crate::op_registry::CpuKernel,
13687    inputs: &[(usize, u32, Shape)],
13688    out_off: usize,
13689    out_len: u32,
13690    out_shape: &Shape,
13691    attrs: &[u8],
13692    base: *mut u8,
13693) {
13694    use crate::op_registry::{CpuTensorMut, CpuTensorRef};
13695    use rlx_ir::DType;
13696
13697    // One arm per `DType` variant — single source of truth for
13698    // "which dtypes the CPU custom-op dispatcher wires." If a new
13699    // DType lands in `rlx-ir`, the compiler flags this match as
13700    // non-exhaustive and the gap gets named at the right place.
13701    macro_rules! build_in_view {
13702        ($shape:expr, $off:expr, $n:expr, $variant:ident, $rust_ty:ty) => {
13703            CpuTensorRef::$variant {
13704                data: unsafe { sl_typed::<$rust_ty>($off, base, $n) },
13705                shape: $shape,
13706            }
13707        };
13708    }
13709    macro_rules! build_out_view {
13710        ($variant:ident, $rust_ty:ty) => {
13711            CpuTensorMut::$variant {
13712                data: unsafe { sl_mut_typed::<$rust_ty>(out_off, base, out_len as usize) },
13713                shape: out_shape,
13714            }
13715        };
13716    }
13717
13718    let in_views: Vec<CpuTensorRef<'_>> = inputs
13719        .iter()
13720        .map(|(off, len, shape)| {
13721            let n = *len as usize;
13722            let off = *off;
13723            match shape.dtype() {
13724                DType::F32 => build_in_view!(shape, off, n, F32, f32),
13725                DType::F64 => build_in_view!(shape, off, n, F64, f64),
13726                DType::F16 => build_in_view!(shape, off, n, F16, half::f16),
13727                DType::BF16 => build_in_view!(shape, off, n, BF16, half::bf16),
13728                DType::I8 => build_in_view!(shape, off, n, I8, i8),
13729                DType::I16 => build_in_view!(shape, off, n, I16, i16),
13730                DType::I32 => build_in_view!(shape, off, n, I32, i32),
13731                DType::I64 => build_in_view!(shape, off, n, I64, i64),
13732                DType::U8 => build_in_view!(shape, off, n, U8, u8),
13733                DType::U32 => build_in_view!(shape, off, n, U32, u32),
13734                DType::Bool => build_in_view!(shape, off, n, Bool, u8),
13735                // C64 isn't a CpuTensor variant today; the user-registered
13736                // op_registry path doesn't see complex inputs (those are
13737                // handled by built-in ops with dedicated kernels).
13738                DType::C64 => panic!(
13739                    "Op::Custom kernel input has DType::C64 — built-in \
13740                 complex ops handle their own kernels; user-registered \
13741                 ops don't yet see complex tensors"
13742                ),
13743            }
13744        })
13745        .collect();
13746
13747    let result = match out_shape.dtype() {
13748        DType::F32 => kernel.execute(&in_views, build_out_view!(F32, f32), attrs),
13749        DType::F64 => kernel.execute(&in_views, build_out_view!(F64, f64), attrs),
13750        DType::F16 => kernel.execute(&in_views, build_out_view!(F16, half::f16), attrs),
13751        DType::BF16 => kernel.execute(&in_views, build_out_view!(BF16, half::bf16), attrs),
13752        DType::I8 => kernel.execute(&in_views, build_out_view!(I8, i8), attrs),
13753        DType::I16 => kernel.execute(&in_views, build_out_view!(I16, i16), attrs),
13754        DType::I32 => kernel.execute(&in_views, build_out_view!(I32, i32), attrs),
13755        DType::I64 => kernel.execute(&in_views, build_out_view!(I64, i64), attrs),
13756        DType::U8 => kernel.execute(&in_views, build_out_view!(U8, u8), attrs),
13757        DType::U32 => kernel.execute(&in_views, build_out_view!(U32, u32), attrs),
13758        DType::Bool => kernel.execute(&in_views, build_out_view!(Bool, u8), attrs),
13759        DType::C64 => panic!("Op::Custom output DType::C64 not supported"),
13760    };
13761    if let Err(e) = result {
13762        panic!("Op::Custom('{}') CPU kernel failed: {e}", kernel.name());
13763    }
13764}
13765
13766/// Generic raw-cast slice helper. The existing per-dtype `sl_*` /
13767/// `sl_mut_*` helpers stay in place for the rest of `thunk.rs` (which
13768/// uses them at call sites with concrete dtypes); the custom-op
13769/// dispatcher uses these to enumerate every `DType` uniformly without
13770/// listing one helper per dtype.
13771#[inline(always)]
13772unsafe fn sl_typed<T>(offset: usize, base: *mut u8, len: usize) -> &'static [T] {
13773    if offset == usize::MAX {
13774        return &[];
13775    }
13776    unsafe { std::slice::from_raw_parts(base.add(offset) as *const T, len) }
13777}
13778
13779#[inline(always)]
13780unsafe fn sl_mut_typed<T>(offset: usize, base: *mut u8, len: usize) -> &'static mut [T] {
13781    unsafe { std::slice::from_raw_parts_mut(base.add(offset) as *mut T, len) }
13782}
13783
13784// Unsafe helpers to create slices from arena base + offset
13785#[inline(always)]
13786/// In-place per-element activation. Mirrors the dispatch in
13787/// `Thunk::ActivationInPlace`. Used by `Thunk::FusedMmBiasAct` to
13788/// apply the activation after `bias_add` for all non-Gelu cases.
13789fn apply_activation_inplace(d: &mut [f32], act: rlx_ir::op::Activation) {
13790    use rlx_ir::op::Activation;
13791    match act {
13792        Activation::Gelu => crate::kernels::par_gelu_inplace(d),
13793        Activation::GeluApprox => crate::kernels::par_gelu_approx_inplace(d),
13794        Activation::Silu => crate::kernels::par_silu_inplace(d),
13795        Activation::Relu => {
13796            for v in d.iter_mut() {
13797                *v = v.max(0.0);
13798            }
13799        }
13800        Activation::Sigmoid => {
13801            for v in d.iter_mut() {
13802                *v = 1.0 / (1.0 + (-*v).exp());
13803            }
13804        }
13805        Activation::Tanh => {
13806            for v in d.iter_mut() {
13807                *v = v.tanh();
13808            }
13809        }
13810        Activation::Exp => {
13811            for v in d.iter_mut() {
13812                *v = v.exp();
13813            }
13814        }
13815        Activation::Log => {
13816            for v in d.iter_mut() {
13817                *v = v.ln();
13818            }
13819        }
13820        Activation::Sqrt => {
13821            for v in d.iter_mut() {
13822                *v = v.sqrt();
13823            }
13824        }
13825        Activation::Rsqrt => {
13826            for v in d.iter_mut() {
13827                *v = 1.0 / v.sqrt();
13828            }
13829        }
13830        Activation::Neg => {
13831            for v in d.iter_mut() {
13832                *v = -*v;
13833            }
13834        }
13835        Activation::Abs => {
13836            for v in d.iter_mut() {
13837                *v = v.abs();
13838            }
13839        }
13840        Activation::Round => {
13841            for v in d.iter_mut() {
13842                *v = v.round();
13843            }
13844        }
13845        Activation::Sin => {
13846            for v in d.iter_mut() {
13847                *v = v.sin();
13848            }
13849        }
13850        Activation::Cos => {
13851            for v in d.iter_mut() {
13852                *v = v.cos();
13853            }
13854        }
13855        Activation::Tan => {
13856            for v in d.iter_mut() {
13857                *v = v.tan();
13858            }
13859        }
13860        Activation::Atan => {
13861            for v in d.iter_mut() {
13862                *v = v.atan();
13863            }
13864        }
13865    }
13866}
13867
/// im2col for one image (single batch + group slice).
///
/// Unfolds `x`, stored row-major as `[c_in, H, W]`, into `col`, stored
/// row-major as `[c_in · kH · kW, H_out · W_out]`. Each row of `col` is one
/// kernel tap `(ci, ki, kj)`; each column is one output position `(ho, wo)`:
///
/// `col[(ci · kH · kW + ki · kW + kj) · n_dim + ho · W_out + wo] =
///    x[ci, ho·sh + ki·dh − ph, wo·sw + kj·dw_dil − pw]`
///
/// Taps that land in the zero-padding border are written as `0.0`, so `col`
/// does not need to be cleared beforehand.
#[allow(clippy::too_many_arguments)]
fn im2col(
    x: &[f32],
    col: &mut [f32],
    c_in: usize,
    h: usize,
    w: usize,
    h_out: usize,
    w_out: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    dh: usize,
    dw_dil: usize,
) {
    let plane = h_out * w_out;
    let taps = kh * kw;
    debug_assert_eq!(col.len(), c_in * taps * plane);
    debug_assert_eq!(x.len(), c_in * h * w);
    // Signed copies: a coordinate inside the padding is negative before the
    // bounds test.
    let (rows_in, cols_in) = (h as isize, w as isize);
    let (pad_top, pad_left) = (ph as isize, pw as isize);

    for row in 0..c_in * taps {
        // Recover (channel, tap-row, tap-col) from the flat row index.
        let ci = row / taps;
        let ki = (row % taps) / kw;
        let kj = row % kw;
        let row_base = row * plane;
        for ho in 0..h_out {
            let dst = row_base + ho * w_out;
            let src_h = (ho * sh + ki * dh) as isize - pad_top;
            if !(0..rows_in).contains(&src_h) {
                // The whole output row reads from the top/bottom padding.
                col[dst..dst + w_out].fill(0.0);
                continue;
            }
            let src_row = (ci * h + src_h as usize) * w;
            for wo in 0..w_out {
                let src_w = (wo * sw + kj * dw_dil) as isize - pad_left;
                col[dst + wo] = if (0..cols_in).contains(&src_w) {
                    x[src_row + src_w as usize]
                } else {
                    0.0
                };
            }
        }
    }
}
13929
/// col2im — adjoint of `im2col`. Scatter-adds every column entry back onto
/// the input pixel it was gathered from:
///
/// `x[ci, hi, wi] += col[(ci · kH · kW + ki · kW + kj) · n_dim + ho · W_out + wo]`
/// for all `(ki, kj, ho, wo)` whose `(hi, wi)` lands in `[0, H) × [0, W)`.
/// Entries that came from the padding border are dropped.
///
/// `x` is accumulated into, not overwritten. The caller must zero it first
/// if it does not already start at zero (the conv-input-grad path zeros
/// once before the batch loop).
#[allow(clippy::too_many_arguments)]
fn col2im(
    col: &[f32],
    x: &mut [f32],
    c_in: usize,
    h: usize,
    w: usize,
    h_out: usize,
    w_out: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    dh: usize,
    dw_dil: usize,
) {
    let plane = h_out * w_out;
    debug_assert_eq!(col.len(), c_in * kh * kw * plane);
    debug_assert_eq!(x.len(), c_in * h * w);
    // Flat tap index `(ci · kH + ki) · kW + kj`, advanced in loop order.
    let mut row = 0usize;
    for ci in 0..c_in {
        for ki in 0..kh {
            for kj in 0..kw {
                let row_base = row * plane;
                row += 1;
                for ho in 0..h_out {
                    // `checked_sub` is `None` exactly when the tap sits in the
                    // top padding; `filter` rejects the bottom padding.
                    let Some(hi) = (ho * sh + ki * dh).checked_sub(ph).filter(|&v| v < h) else {
                        continue;
                    };
                    let x_row = (ci * h + hi) * w;
                    let col_row = row_base + ho * w_out;
                    for wo in 0..w_out {
                        if let Some(wi) = (wo * sw + kj * dw_dil).checked_sub(pw).filter(|&v| v < w)
                        {
                            x[x_row + wi] += col[col_row + wo];
                        }
                    }
                }
            }
        }
    }
}
13985
13986/// Element-wise backward for `Op::Activation`. `xs` is the original
13987/// input to the forward activation; `dys` is the upstream gradient.
13988/// Writes `out[i] = (d/dx act(xs[i])) * dys[i]`.
13989/// Decompose a per-channel quantization shape into the
13990/// `(chan_axis, chan_dim, inner)` triplet the kernel needs to map a
13991/// flat output index to a channel index. Per-tensor (`axis = None`)
13992/// degenerates to `chan_dim = 1, inner = len`, which makes the
13993/// kernel's `(i / inner) % chan_dim` always 0 — same fast path the
13994/// scalar version used.
13995fn quant_layout(shape: &rlx_ir::Shape, axis: Option<usize>) -> (usize, usize, usize) {
13996    match axis {
13997        None => (0, 1, shape.num_elements().unwrap_or(0).max(1)),
13998        Some(d) => {
13999            let chan_dim = shape.dim(d).unwrap_static();
14000            let inner: usize = (d + 1..shape.rank())
14001                .map(|i| shape.dim(i).unwrap_static())
14002                .product::<usize>()
14003                .max(1);
14004            (d, chan_dim, inner)
14005        }
14006    }
14007}
14008
14009fn activation_backward_kernel(
14010    act: rlx_ir::op::Activation,
14011    xs: &[f32],
14012    dys: &[f32],
14013    out: &mut [f32],
14014) {
14015    use rlx_ir::op::Activation;
14016    let n = xs.len();
14017    debug_assert_eq!(dys.len(), n);
14018    debug_assert_eq!(out.len(), n);
14019    match act {
14020        Activation::Relu => {
14021            for i in 0..n {
14022                out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
14023            }
14024        }
14025        Activation::Sigmoid => {
14026            for i in 0..n {
14027                let s = 1.0 / (1.0 + (-xs[i]).exp());
14028                out[i] = s * (1.0 - s) * dys[i];
14029            }
14030        }
14031        Activation::Tanh => {
14032            for i in 0..n {
14033                let t = xs[i].tanh();
14034                out[i] = (1.0 - t * t) * dys[i];
14035            }
14036        }
14037        Activation::Silu => {
14038            // y = x * σ(x);  dy/dx = σ(x) * (1 + x * (1 - σ(x))).
14039            for i in 0..n {
14040                let s = 1.0 / (1.0 + (-xs[i]).exp());
14041                out[i] = s * (1.0 + xs[i] * (1.0 - s)) * dys[i];
14042            }
14043        }
14044        Activation::Gelu => {
14045            // Exact erf-based GELU:  y = 0.5 x (1 + erf(x / √2)).
14046            //   dy/dx = 0.5 (1 + erf(x/√2)) + (x / √(2π)) · exp(-x²/2)
14047            const INV_SQRT2: f32 = 0.707_106_77;
14048            const INV_SQRT_2PI: f32 = 0.398_942_3;
14049            for i in 0..n {
14050                let x = xs[i];
14051                let phi = 0.5 * (1.0 + erf_f32(x * INV_SQRT2));
14052                let pdf = INV_SQRT_2PI * (-(x * x) * 0.5).exp();
14053                out[i] = (phi + x * pdf) * dys[i];
14054            }
14055        }
14056        Activation::GeluApprox => {
14057            // Tanh-approximation:
14058            //   y = 0.5 x (1 + tanh(c · (x + 0.044715 x³))) where c = √(2/π).
14059            const C: f32 = 0.797_884_6; // √(2/π)
14060            const A: f32 = 0.044_715;
14061            for i in 0..n {
14062                let x = xs[i];
14063                let inner = C * (x + A * x * x * x);
14064                let t = inner.tanh();
14065                let dinner = C * (1.0 + 3.0 * A * x * x);
14066                let d = 0.5 * (1.0 + t) + 0.5 * x * (1.0 - t * t) * dinner;
14067                out[i] = d * dys[i];
14068            }
14069        }
14070        Activation::Exp => {
14071            for i in 0..n {
14072                out[i] = xs[i].exp() * dys[i];
14073            }
14074        }
14075        Activation::Log => {
14076            for i in 0..n {
14077                out[i] = dys[i] / xs[i];
14078            }
14079        }
14080        Activation::Sqrt => {
14081            // d/dx √x = 0.5 / √x — undefined at x=0; clamp to 0.
14082            for i in 0..n {
14083                let s = xs[i].sqrt();
14084                out[i] = if s > 0.0 { 0.5 * dys[i] / s } else { 0.0 };
14085            }
14086        }
14087        Activation::Rsqrt => {
14088            // d/dx (1/√x) = -0.5 · x^(-3/2).
14089            for i in 0..n {
14090                let s = xs[i].sqrt();
14091                out[i] = if s > 0.0 {
14092                    -0.5 * dys[i] / (xs[i] * s)
14093                } else {
14094                    0.0
14095                };
14096            }
14097        }
14098        Activation::Neg => {
14099            for i in 0..n {
14100                out[i] = -dys[i];
14101            }
14102        }
14103        Activation::Abs => {
14104            // sign(x); 0 at x=0.
14105            for i in 0..n {
14106                let x = xs[i];
14107                let s = if x > 0.0 {
14108                    1.0
14109                } else if x < 0.0 {
14110                    -1.0
14111                } else {
14112                    0.0
14113                };
14114                out[i] = s * dys[i];
14115            }
14116        }
14117        Activation::Round => {
14118            // STE: pretend the round was identity in the backward
14119            // pass. The round step has zero gradient almost
14120            // everywhere, so without this trick the optimizer can't
14121            // learn through it.
14122            out.copy_from_slice(dys);
14123        }
14124        Activation::Sin => {
14125            // d/dx sin(x) = cos(x).
14126            for i in 0..n {
14127                out[i] = xs[i].cos() * dys[i];
14128            }
14129        }
14130        Activation::Cos => {
14131            for i in 0..n {
14132                out[i] = -xs[i].sin() * dys[i];
14133            }
14134        }
14135        Activation::Tan => {
14136            // d/dx tan(x) = sec²(x) = 1 + tan²(x)
14137            for i in 0..n {
14138                let t = xs[i].tan();
14139                out[i] = (1.0 + t * t) * dys[i];
14140            }
14141        }
14142        Activation::Atan => {
14143            // d/dx atan(x) = 1 / (1 + x²)
14144            for i in 0..n {
14145                let x = xs[i];
14146                out[i] = dys[i] / (1.0 + x * x);
14147            }
14148        }
14149    }
14150}
14151
14152/// f64 sibling of `activation_backward_kernel`. Same math, twice the
14153/// precision — used by f64 graphs where the f32 kernel reading bytes
14154/// as `&[f32]` would silently discard half of every f64 value.
14155fn activation_backward_kernel_f64(
14156    act: rlx_ir::op::Activation,
14157    xs: &[f64],
14158    dys: &[f64],
14159    out: &mut [f64],
14160) {
14161    use rlx_ir::op::Activation;
14162    let n = xs.len();
14163    debug_assert_eq!(dys.len(), n);
14164    debug_assert_eq!(out.len(), n);
14165    match act {
14166        Activation::Relu => {
14167            for i in 0..n {
14168                out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
14169            }
14170        }
14171        Activation::Sigmoid => {
14172            for i in 0..n {
14173                let s = 1.0 / (1.0 + (-xs[i]).exp());
14174                out[i] = s * (1.0 - s) * dys[i];
14175            }
14176        }
14177        Activation::Tanh => {
14178            for i in 0..n {
14179                let t = xs[i].tanh();
14180                out[i] = (1.0 - t * t) * dys[i];
14181            }
14182        }
14183        Activation::Silu => {
14184            for i in 0..n {
14185                let s = 1.0 / (1.0 + (-xs[i]).exp());
14186                out[i] = s * (1.0 + xs[i] * (1.0 - s)) * dys[i];
14187            }
14188        }
14189        Activation::Gelu | Activation::GeluApprox => {
14190            // Both rare on f64 paths; use the high-quality libm erf.
14191            const INV_SQRT2: f64 = std::f64::consts::FRAC_1_SQRT_2;
14192            const INV_SQRT_2PI: f64 = 0.398_942_280_401_432_7;
14193            for i in 0..n {
14194                let x = xs[i];
14195                let phi = 0.5 * (1.0 + erf_f64(x * INV_SQRT2));
14196                let pdf = INV_SQRT_2PI * (-(x * x) * 0.5).exp();
14197                out[i] = (phi + x * pdf) * dys[i];
14198            }
14199        }
14200        Activation::Exp => {
14201            for i in 0..n {
14202                out[i] = xs[i].exp() * dys[i];
14203            }
14204        }
14205        Activation::Log => {
14206            for i in 0..n {
14207                out[i] = dys[i] / xs[i];
14208            }
14209        }
14210        Activation::Sqrt => {
14211            for i in 0..n {
14212                let s = xs[i].sqrt();
14213                out[i] = if s > 0.0 { 0.5 * dys[i] / s } else { 0.0 };
14214            }
14215        }
14216        Activation::Rsqrt => {
14217            for i in 0..n {
14218                let s = xs[i].sqrt();
14219                out[i] = if s > 0.0 {
14220                    -0.5 * dys[i] / (xs[i] * s)
14221                } else {
14222                    0.0
14223                };
14224            }
14225        }
14226        Activation::Neg => {
14227            for i in 0..n {
14228                out[i] = -dys[i];
14229            }
14230        }
14231        Activation::Abs => {
14232            for i in 0..n {
14233                let x = xs[i];
14234                let s = if x > 0.0 {
14235                    1.0
14236                } else if x < 0.0 {
14237                    -1.0
14238                } else {
14239                    0.0
14240                };
14241                out[i] = s * dys[i];
14242            }
14243        }
14244        Activation::Round => {
14245            out.copy_from_slice(dys);
14246        }
14247        Activation::Sin => {
14248            for i in 0..n {
14249                out[i] = xs[i].cos() * dys[i];
14250            }
14251        }
14252        Activation::Cos => {
14253            for i in 0..n {
14254                out[i] = -xs[i].sin() * dys[i];
14255            }
14256        }
14257        Activation::Tan => {
14258            for i in 0..n {
14259                let t = xs[i].tan();
14260                out[i] = (1.0 + t * t) * dys[i];
14261            }
14262        }
14263        Activation::Atan => {
14264            for i in 0..n {
14265                let x = xs[i];
14266                out[i] = dys[i] / (1.0 + x * x);
14267            }
14268        }
14269    }
14270}
14271
/// f64 erf via Abramowitz & Stegun 7.1.26, using the same coefficients as
/// `erf_f32` but evaluated at f64 width.
///
/// Max error is ~1.5e-7, limited by the polynomial rather than the
/// arithmetic. That is adequate for gradient kernels; if higher precision is
/// needed, swap in a libm dependency.
///
/// The approximation is defined for `x ≥ 0`. erf is odd, so the sign of `x`
/// is reapplied at the end.
#[inline(always)]
fn erf_f64(x: f64) -> f64 {
    let sign = x.signum();
    let ax = x.abs();
    let t = 1.0 / (1.0 + 0.327_591_1 * ax);
    // Horner evaluation of a5·t⁴ + a4·t³ + a3·t² + a2·t + a1.
    let mut poly = 1.061_405_43 * t - 1.453_152_03;
    poly = poly * t + 1.421_413_75;
    poly = poly * t - 0.284_496_74;
    poly = poly * t + 0.254_829_59;
    let tail = poly * t * (-ax * ax).exp();
    sign * (1.0 - tail)
}
14288
/// Cheap f32 erf approximation (Abramowitz & Stegun 7.1.26). Max error is
/// ~1.5e-7 over all of ℝ — plenty for f32 gradient kernels.
///
/// The approximation is defined for `x ≥ 0`. erf is odd, so the sign of `x`
/// is reapplied at the end.
#[inline(always)]
fn erf_f32(x: f32) -> f32 {
    let sign = x.signum();
    let ax = x.abs();
    let t = 1.0 / (1.0 + 0.327_591_1 * ax);
    // Horner evaluation of a5·t⁴ + a4·t³ + a3·t² + a2·t + a1.
    let mut poly = 1.061_405_4 * t - 1.453_152_1;
    poly = poly * t + 1.421_413_8;
    poly = poly * t - 0.284_496_74;
    poly = poly * t + 0.254_829_6;
    let tail = poly * t * (-ax * ax).exp();
    sign * (1.0 - tail)
}
14303
/// Build the runtime closure for a strided `Narrow` copy.
///
/// The closure receives the arena base pointer at execution time. It copies
/// `outer` rows of `inner` elements each:
/// * from the region at byte offset `src`, where consecutive rows are
///   `src_stride` elements apart;
/// * to the region at byte offset `dst`, where rows are `dst_stride`
///   elements apart.
///
/// Every element is `elem_bytes` bytes wide. The copy is bit-exact and
/// byte-wise, so any element width (1, 2, 4, 8, …) is handled, not just
/// 4-byte and 8-byte dtypes.
///
/// The source and destination views cover exactly the bytes that are read
/// or written: `(outer − 1) · stride + inner` elements. They never span a
/// full `outer · stride`, which would run past the end of the parent tensor
/// whenever the narrow starts at a non-zero offset.
fn narrow_thunk_closure(
    src: usize,
    dst: usize,
    outer: u32,
    src_stride: u32,
    dst_stride: u32,
    inner: u32,
    elem_bytes: u8,
) -> Arc<dyn Fn(*mut u8) + Send + Sync> {
    let width = elem_bytes as usize;
    debug_assert!(width > 0, "narrow: zero-width element");
    debug_assert!(
        inner <= src_stride && inner <= dst_stride,
        "narrow: row ({inner}) wider than stride ({src_stride}/{dst_stride})"
    );
    let rows = outer as usize;
    let row_bytes = inner as usize * width;
    let src_pitch = src_stride as usize * width;
    let dst_pitch = dst_stride as usize * width;
    // Every row except the last spans a full pitch; the last only `row_bytes`.
    let last_row = rows.saturating_sub(1);
    let src_len = last_row * src_pitch + row_bytes;
    let dst_len = last_row * dst_pitch + row_bytes;

    Arc::new(move |base: *mut u8| {
        if rows == 0 || row_bytes == 0 {
            return;
        }
        // SAFETY: assumes (as the previous typed-slice implementation did)
        // that `src` / `dst` are byte offsets of a live source region and a
        // non-overlapping destination region inside the arena. The views
        // cover exactly the bytes read / written, and `u8` has no alignment
        // requirement.
        let (s, d) = unsafe {
            (
                std::slice::from_raw_parts(base.add(src).cast_const(), src_len),
                std::slice::from_raw_parts_mut(base.add(dst), dst_len),
            )
        };
        for r in 0..rows {
            let (from, to) = (r * src_pitch, r * dst_pitch);
            d[to..to + row_bytes].copy_from_slice(&s[from..from + row_bytes]);
        }
    })
}
14337
// Unsafe helpers to create slices from arena base + offset.

/// View `len` `f32`s starting at byte `offset` of the arena.
///
/// An `offset` of `usize::MAX` is the "no buffer" sentinel and yields an
/// empty slice.
///
/// # Safety
/// Unless `offset == usize::MAX`, `base + offset` must be 4-byte aligned and
/// point to `len` initialized `f32`s. These must remain valid, and must not
/// be mutated through another reference, for as long as the slice is used.
#[inline(always)]
unsafe fn sl(offset: usize, base: *mut u8, len: usize) -> &'static [f32] {
    if offset == usize::MAX {
        return &[];
    }
    // SAFETY: upheld by the caller per the contract above.
    unsafe { std::slice::from_raw_parts(base.add(offset) as *const f32, len) }
}
14344
/// Mutable view of `len` `f32`s at byte `offset` of the arena. Unlike `sl`,
/// there is no `usize::MAX` sentinel.
///
/// # Safety
/// `base + offset` must be 4-byte aligned and point to `len` initialized
/// `f32`s, with no other live reference to them while the slice is in use.
#[inline(always)]
unsafe fn sl_mut(offset: usize, base: *mut u8, len: usize) -> &'static mut [f32] {
    let first = unsafe { base.add(offset) }.cast::<f32>();
    // SAFETY: upheld by the caller per the contract above.
    unsafe { std::slice::from_raw_parts_mut(first, len) }
}

/// View of `len` `f64`s at byte `offset` of the arena. An `offset` of
/// `usize::MAX` yields an empty slice.
///
/// # Safety
/// Unless `offset == usize::MAX`, `base + offset` must be 8-byte aligned and
/// point to `len` initialized `f64`s that are not mutated elsewhere while
/// the slice is in use.
#[inline(always)]
unsafe fn sl_f64(offset: usize, base: *mut u8, len: usize) -> &'static [f64] {
    match offset {
        usize::MAX => &[],
        start => {
            let first = unsafe { base.add(start) }.cast::<f64>().cast_const();
            // SAFETY: upheld by the caller per the contract above.
            unsafe { std::slice::from_raw_parts(first, len) }
        }
    }
}

/// Mutable view of `len` `f64`s at byte `offset` of the arena (no sentinel).
///
/// # Safety
/// `base + offset` must be 8-byte aligned and point to `len` initialized
/// `f64`s, with no other live reference to them while the slice is in use.
#[inline(always)]
unsafe fn sl_mut_f64(offset: usize, base: *mut u8, len: usize) -> &'static mut [f64] {
    let first = unsafe { base.add(offset) }.cast::<f64>();
    // SAFETY: upheld by the caller per the contract above.
    unsafe { std::slice::from_raw_parts_mut(first, len) }
}
14362
// i32 / i64 typed slice helpers — integer siblings of `sl` / `sl_f64` and
// their `_mut` forms. Kept for integer-tensor thunks that haven't landed
// yet (Sample, Gather index buffers); deleting them now would force
// re-deriving the unsafe boilerplate when the next int-typed thunk lands.
//
// Contract, same as the float helpers:
// - the shared views map an `offset` of `usize::MAX` to an empty slice;
// - the mutable views have no sentinel;
// - `base + offset` must be aligned for the element type and cover `len`
//   initialized elements that are not aliased mutably elsewhere while the
//   returned slice is in use.
#[allow(dead_code)]
#[inline(always)]
unsafe fn sl_i32(offset: usize, base: *mut u8, len: usize) -> &'static [i32] {
    match offset {
        usize::MAX => &[],
        start => {
            let first = unsafe { base.add(start) }.cast::<i32>().cast_const();
            // SAFETY: upheld by the caller per the contract above.
            unsafe { std::slice::from_raw_parts(first, len) }
        }
    }
}

#[allow(dead_code)]
#[inline(always)]
unsafe fn sl_mut_i32(offset: usize, base: *mut u8, len: usize) -> &'static mut [i32] {
    let first = unsafe { base.add(offset) }.cast::<i32>();
    // SAFETY: upheld by the caller per the contract above.
    unsafe { std::slice::from_raw_parts_mut(first, len) }
}

#[allow(dead_code)]
#[inline(always)]
unsafe fn sl_i64(offset: usize, base: *mut u8, len: usize) -> &'static [i64] {
    match offset {
        usize::MAX => &[],
        start => {
            let first = unsafe { base.add(start) }.cast::<i64>().cast_const();
            // SAFETY: upheld by the caller per the contract above.
            unsafe { std::slice::from_raw_parts(first, len) }
        }
    }
}

#[allow(dead_code)]
#[inline(always)]
unsafe fn sl_mut_i64(offset: usize, base: *mut u8, len: usize) -> &'static mut [i64] {
    let first = unsafe { base.add(offset) }.cast::<i64>();
    // SAFETY: upheld by the caller per the contract above.
    unsafe { std::slice::from_raw_parts_mut(first, len) }
}
14396
/// f64 N-D index walk shared by Transpose and Expand.
///
/// Visits the output in row-major order (last axis fastest), keeping a
/// multi-index `coord` over `out_dims`. Each output element is read from
/// `inp` at `Σ coord[d] · in_strides[d]`. A stride of 0 repeats the source
/// along that axis, which is how broadcast (Expand) axes are expressed.
fn transpose_walk_f64(inp: &[f64], out: &mut [f64], out_dims: &[u32], in_strides: &[u32]) {
    let rank = out_dims.len();
    let mut coord = vec![0u32; rank];
    for slot in out.iter_mut() {
        let src: usize = (0..rank)
            .map(|d| coord[d] as usize * in_strides[d] as usize)
            .sum();
        *slot = inp[src];
        // Odometer increment: bump the last axis and carry leftwards.
        for d in (0..rank).rev() {
            coord[d] += 1;
            if coord[d] < out_dims[d] {
                break;
            }
            coord[d] = 0;
        }
    }
}
14419
14420/// f64 elementwise activation. Reads `inp`, writes `out`. For now
14421/// covers what the autodiff-emitted gradient graph needs (Neg, Exp,
14422/// Log, Sqrt, Rsqrt, Abs, Tanh, Sigmoid, Relu — the
14423/// transcendental-free subset). Approximate Gelu/Silu deferred until a
14424/// workload demands them at f64.
14425fn apply_activation_f64(inp: &[f64], out: &mut [f64], kind: Activation) {
14426    match kind {
14427        Activation::Neg => {
14428            for (o, &v) in out.iter_mut().zip(inp) {
14429                *o = -v;
14430            }
14431        }
14432        Activation::Exp => {
14433            for (o, &v) in out.iter_mut().zip(inp) {
14434                *o = v.exp();
14435            }
14436        }
14437        Activation::Log => {
14438            for (o, &v) in out.iter_mut().zip(inp) {
14439                *o = v.ln();
14440            }
14441        }
14442        Activation::Sqrt => {
14443            for (o, &v) in out.iter_mut().zip(inp) {
14444                *o = v.sqrt();
14445            }
14446        }
14447        Activation::Rsqrt => {
14448            for (o, &v) in out.iter_mut().zip(inp) {
14449                *o = 1.0 / v.sqrt();
14450            }
14451        }
14452        Activation::Abs => {
14453            for (o, &v) in out.iter_mut().zip(inp) {
14454                *o = v.abs();
14455            }
14456        }
14457        Activation::Tanh => {
14458            for (o, &v) in out.iter_mut().zip(inp) {
14459                *o = v.tanh();
14460            }
14461        }
14462        Activation::Sigmoid => {
14463            for (o, &v) in out.iter_mut().zip(inp) {
14464                *o = 1.0 / (1.0 + (-v).exp());
14465            }
14466        }
14467        Activation::Relu => {
14468            for (o, &v) in out.iter_mut().zip(inp) {
14469                *o = v.max(0.0);
14470            }
14471        }
14472        Activation::Round => {
14473            for (o, &v) in out.iter_mut().zip(inp) {
14474                *o = v.round_ties_even();
14475            }
14476        }
14477        Activation::Sin => {
14478            for (o, &v) in out.iter_mut().zip(inp) {
14479                *o = v.sin();
14480            }
14481        }
14482        Activation::Cos => {
14483            for (o, &v) in out.iter_mut().zip(inp) {
14484                *o = v.cos();
14485            }
14486        }
14487        Activation::Tan => {
14488            for (o, &v) in out.iter_mut().zip(inp) {
14489                *o = v.tan();
14490            }
14491        }
14492        Activation::Atan => {
14493            for (o, &v) in out.iter_mut().zip(inp) {
14494                *o = v.atan();
14495            }
14496        }
14497        Activation::Gelu | Activation::GeluApprox | Activation::Silu => {
14498            panic!(
14499                "apply_activation_f64: {kind:?} not yet implemented at f64. \
14500                    Add when a workload needs it."
14501            );
14502        }
14503    }
14504}
14505
14506#[inline]
14507fn binary_op_f64(op: BinaryOp, a: f64, b: f64) -> f64 {
14508    match op {
14509        BinaryOp::Add => a + b,
14510        BinaryOp::Sub => a - b,
14511        BinaryOp::Mul => a * b,
14512        BinaryOp::Div => a / b,
14513        BinaryOp::Max => a.max(b),
14514        BinaryOp::Min => a.min(b),
14515        BinaryOp::Pow => a.powf(b),
14516    }
14517}
14518
/// f64 sum reduction over the middle axis of a contiguous buffer.
///
/// `inp` is laid out as `[outer, reduced, inner]` and `out` as
/// `[outer, inner]`. Each output is the left-to-right sum of its `reduced`
/// inputs, starting from `+0.0`, so an empty reduction writes `+0.0`.
fn reduce_sum_f64(inp: &[f64], out: &mut [f64], outer: usize, reduced: usize, inner: usize) {
    for o in 0..outer {
        let slab = o * reduced * inner;
        for n in 0..inner {
            out[o * inner + n] =
                (0..reduced).fold(0.0_f64, |acc, r| acc + inp[slab + r * inner + n]);
        }
    }
}
14532
14533#[cfg(test)]
14534mod tests {
14535    use super::*;
14536    use rlx_ir::*;
14537
14538    /// Plan #45: when a Narrow's only consumer is a Rope, the thunk
14539    /// fusion pass collapses them — the Narrow becomes Nop, and the
14540    /// Rope reads from the parent buffer with its row stride. This
14541    /// test runs the unfused path (batch*seq > FusedAttnBlock
14542    /// threshold) and asserts the rewrite happened.
14543    #[test]
14544    fn narrow_rope_fuses_in_unfused_path() {
14545        let f = DType::F32;
14546        let mut g = Graph::new("nr_fuse");
14547        // Force batch*seq > 64 so FusedAttnBlock doesn't pre-empt us.
14548        let qkv = g.input("qkv", Shape::new(&[16, 8, 192], f)); // 16*8=128 > 64
14549        let cos = g.input("cos", Shape::new(&[16], f));
14550        let sin = g.input("sin", Shape::new(&[16], f));
14551        // Last-axis narrow: Q = qkv[..., 0..64]
14552        let q = g.narrow_(qkv, 2, 0, 64);
14553        let q_rope = g.rope(q, cos, sin, 16);
14554        g.set_outputs(vec![q_rope]);
14555
14556        let plan = rlx_opt::memory::plan_memory(&g);
14557        let arena = crate::arena::Arena::from_plan(plan);
14558        let sched = compile_thunks(&g, &arena);
14559
14560        let mut narrow_count = 0;
14561        let mut rope_with_stride: Option<u32> = None;
14562        for t in &sched.thunks {
14563            match t {
14564                Thunk::Narrow { .. } => narrow_count += 1,
14565                Thunk::Rope { src_row_stride, .. } => rope_with_stride = Some(*src_row_stride),
14566                _ => {}
14567            }
14568        }
14569        // After fusion the Narrow is gone; only the Rope remains, and
14570        // it now walks with the parent QKV's row stride (3 * 64 = 192).
14571        assert_eq!(
14572            narrow_count, 0,
14573            "Narrow→Rope fusion should leave zero Narrow thunks; saw {narrow_count}"
14574        );
14575        assert_eq!(
14576            rope_with_stride,
14577            Some(192),
14578            "Rope's src_row_stride should be 192 (parent qkv axis), saw {rope_with_stride:?}"
14579        );
14580    }
14581
    /// Plan #15: SSM selective scan matches a naive, Python-style
    /// sequential reference implementation.
14584    #[test]
14585    fn ssm_selective_scan_matches_reference() {
14586        use rlx_ir::Philox4x32;
14587        let bch = 1usize;
14588        let s = 4usize;
14589        let h = 3usize;
14590        let n = 2usize;
14591
14592        let mut rng = Philox4x32::new(13);
14593        let mut x = vec![0f32; bch * s * h];
14594        rng.fill_normal(&mut x);
14595        let mut delta = vec![0f32; bch * s * h];
14596        // Keep Δ small so exp(Δ·A) doesn't blow up.
14597        for v in delta.iter_mut() {
14598            *v = (rng.next_f32() - 0.5) * 0.1;
14599        }
14600        let mut a = vec![0f32; h * n];
14601        for v in a.iter_mut() {
14602            *v = -(rng.next_f32() * 0.5 + 0.1);
14603        } // negative for stability
14604        let mut b = vec![0f32; bch * s * n];
14605        rng.fill_normal(&mut b);
14606        let mut c = vec![0f32; bch * s * n];
14607        rng.fill_normal(&mut c);
14608
14609        // Reference scan.
14610        let mut expected = vec![0f32; bch * s * h];
14611        for bi in 0..bch {
14612            let mut state = vec![0f32; h * n];
14613            for si in 0..s {
14614                for ci in 0..h {
14615                    let d = delta[bi * s * h + si * h + ci];
14616                    let xv = x[bi * s * h + si * h + ci];
14617                    let mut acc = 0f32;
14618                    for ni in 0..n {
14619                        let da = (d * a[ci * n + ni]).exp();
14620                        state[ci * n + ni] =
14621                            da * state[ci * n + ni] + d * b[bi * s * n + si * n + ni] * xv;
14622                        acc += c[bi * s * n + si * n + ni] * state[ci * n + ni];
14623                    }
14624                    expected[bi * s * h + si * h + ci] = acc;
14625                }
14626            }
14627        }
14628
14629        // RLX path.
14630        let f = DType::F32;
14631        let mut g = Graph::new("ssm");
14632        let xn = g.input("x", Shape::new(&[bch, s, h], f));
14633        let dn = g.input("delta", Shape::new(&[bch, s, h], f));
14634        let an = g.param("a", Shape::new(&[h, n], f));
14635        let bn = g.param("b", Shape::new(&[bch, s, n], f));
14636        let cn = g.param("c", Shape::new(&[bch, s, n], f));
14637        let yn = g.selective_scan(xn, dn, an, bn, cn, n, Shape::new(&[bch, s, h], f));
14638        g.set_outputs(vec![yn]);
14639
14640        let plan = rlx_opt::memory::plan_memory(&g);
14641        let mut arena = crate::arena::Arena::from_plan(plan);
14642        let sched = compile_thunks(&g, &arena);
14643
14644        let xn_off = arena.byte_offset(xn);
14645        let dn_off = arena.byte_offset(dn);
14646        let an_off = arena.byte_offset(an);
14647        let bn_off = arena.byte_offset(bn);
14648        let cn_off = arena.byte_offset(cn);
14649        let yn_off = arena.byte_offset(yn);
14650        let buf = arena.raw_buf_mut();
14651        unsafe {
14652            let copy = |dst: *mut f32, data: &[f32]| {
14653                for (i, &v) in data.iter().enumerate() {
14654                    *dst.add(i) = v;
14655                }
14656            };
14657            copy(buf.as_mut_ptr().add(xn_off) as *mut f32, &x);
14658            copy(buf.as_mut_ptr().add(dn_off) as *mut f32, &delta);
14659            copy(buf.as_mut_ptr().add(an_off) as *mut f32, &a);
14660            copy(buf.as_mut_ptr().add(bn_off) as *mut f32, &b);
14661            copy(buf.as_mut_ptr().add(cn_off) as *mut f32, &c);
14662        }
14663        execute_thunks(&sched, arena.raw_buf_mut());
14664
14665        let actual: Vec<f32> = unsafe {
14666            let p = arena.raw_buf().as_ptr().add(yn_off) as *const f32;
14667            (0..bch * s * h).map(|i| *p.add(i)).collect()
14668        };
14669
14670        for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
14671            assert!(
14672                (e - a).abs() < 1e-3,
14673                "mismatch at {i}: expected {e}, got {a}"
14674            );
14675        }
14676    }
14677
14678    /// Plan #26: 1×1 conv lowers to per-batch sgemm and matches the
14679    /// scalar 7-loop reference.
14680    #[test]
14681    fn conv_1x1_fast_path_matches_scalar() {
14682        use rlx_ir::Philox4x32;
14683        // [N=2, C_in=4, H=3, W=3]
14684        let n = 2usize;
14685        let c_in = 4usize;
14686        let h = 3usize;
14687        let w = 3usize;
14688        let c_out = 5usize;
14689        let mut rng = Philox4x32::new(31);
14690        let mut x = vec![0f32; n * c_in * h * w];
14691        rng.fill_normal(&mut x);
14692        let mut weight = vec![0f32; c_out * c_in];
14693        rng.fill_normal(&mut weight);
14694
14695        // Reference: scalar 1×1 conv = per-batch matmul
14696        // out[ni, co, hi, wi] = sum_ci weight[co, ci] * x[ni, ci, hi, wi]
14697        let mut expected = vec![0f32; n * c_out * h * w];
14698        for ni in 0..n {
14699            for co in 0..c_out {
14700                for hi in 0..h {
14701                    for wi in 0..w {
14702                        let mut acc = 0f32;
14703                        for ci in 0..c_in {
14704                            acc += weight[co * c_in + ci]
14705                                * x[((ni * c_in) + ci) * h * w + hi * w + wi];
14706                        }
14707                        expected[((ni * c_out) + co) * h * w + hi * w + wi] = acc;
14708                    }
14709                }
14710            }
14711        }
14712
14713        // RLX path: build a graph with Op::Conv (kernel=[1,1], stride=[1,1], etc).
14714        let f = DType::F32;
14715        let mut g = Graph::new("conv1x1");
14716        let xn = g.input("x", Shape::new(&[n, c_in, h, w], f));
14717        let wn = g.param("w", Shape::new(&[c_out, c_in, 1, 1], f));
14718        // Manually add Op::Conv since there's no `g.conv()` helper.
14719        let cn = g.add_node(
14720            rlx_ir::Op::Conv {
14721                kernel_size: vec![1, 1],
14722                stride: vec![1, 1],
14723                padding: vec![0, 0],
14724                dilation: vec![1, 1],
14725                groups: 1,
14726            },
14727            vec![xn, wn],
14728            Shape::new(&[n, c_out, h, w], f),
14729        );
14730        g.set_outputs(vec![cn]);
14731
14732        let plan = rlx_opt::memory::plan_memory(&g);
14733        let mut arena = crate::arena::Arena::from_plan(plan);
14734        let sched = compile_thunks(&g, &arena);
14735
14736        // Verify the fast path was selected.
14737        let saw_fast = sched
14738            .thunks
14739            .iter()
14740            .any(|t| matches!(t, Thunk::Conv2D1x1 { .. }));
14741        let saw_slow = sched
14742            .thunks
14743            .iter()
14744            .any(|t| matches!(t, Thunk::Conv2D { .. }));
14745        assert!(saw_fast, "1×1 conv should emit Conv2D1x1");
14746        assert!(!saw_slow, "1×1 conv must not fall through to scalar Conv2D");
14747
14748        let xn_off = arena.byte_offset(xn);
14749        let wn_off = arena.byte_offset(wn);
14750        let cn_off = arena.byte_offset(cn);
14751        let buf = arena.raw_buf_mut();
14752        unsafe {
14753            let xp = buf.as_mut_ptr().add(xn_off) as *mut f32;
14754            for (i, &v) in x.iter().enumerate() {
14755                *xp.add(i) = v;
14756            }
14757            let wp = buf.as_mut_ptr().add(wn_off) as *mut f32;
14758            for (i, &v) in weight.iter().enumerate() {
14759                *wp.add(i) = v;
14760            }
14761        }
14762        execute_thunks(&sched, arena.raw_buf_mut());
14763
14764        let actual: Vec<f32> = unsafe {
14765            let p = arena.raw_buf().as_ptr().add(cn_off) as *const f32;
14766            (0..(n * c_out * h * w)).map(|i| *p.add(i)).collect()
14767        };
14768
14769        for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
14770            assert!(
14771                (e - a).abs() < 1e-3,
14772                "mismatch at {i}: expected {e}, got {a}"
14773            );
14774        }
14775    }
14776
14777    /// Plan #5: fused dequant matmul matches the dequant-then-matmul
14778    /// reference (i.e. `(scale * (q - z)) @ x` materialized).
14779    #[test]
14780    fn dequant_matmul_int8_sym_matches_reference() {
14781        use rlx_ir::Philox4x32;
14782        use rlx_ir::quant::QuantScheme;
14783
14784        let m = 3usize;
14785        let k = 8usize;
14786        let n = 4usize;
14787        let block_size = 4usize; // 2 blocks per column
14788        let blocks_per_col = k / block_size;
14789
14790        // Random inputs: x f32, w_q i8, scales f32. Symmetric → no zp.
14791        let mut rng = Philox4x32::new(99);
14792        let mut x = vec![0f32; m * k];
14793        rng.fill_normal(&mut x);
14794        let w_q: Vec<i8> = (0..(k * n))
14795            .map(|i| ((i as i32 * 13 + 7) % 127 - 63) as i8)
14796            .collect();
14797        let scales: Vec<f32> = (0..(blocks_per_col * n))
14798            .map(|i| 0.01 + 0.001 * i as f32)
14799            .collect();
14800
14801        // Reference: build f32 weights from (q * scale) per block.
14802        let mut w_f32 = vec![0f32; k * n];
14803        for p in 0..k {
14804            let block = p / block_size;
14805            for j in 0..n {
14806                let s = scales[block * n + j];
14807                w_f32[p * n + j] = w_q[p * n + j] as f32 * s;
14808            }
14809        }
14810        let mut expected = vec![0f32; m * n];
14811        for i in 0..m {
14812            for j in 0..n {
14813                let mut acc = 0f32;
14814                for p in 0..k {
14815                    acc += x[i * k + p] * w_f32[p * n + j];
14816                }
14817                expected[i * n + j] = acc;
14818            }
14819        }
14820
14821        // RLX path.
14822        let f = DType::F32;
14823        let mut g = Graph::new("dq");
14824        let xn = g.input("x", Shape::new(&[m, k], f));
14825        let wn = g.param("w", Shape::new(&[k, n], DType::I8));
14826        let sn = g.param("scale", Shape::new(&[blocks_per_col, n], f));
14827        let zn = g.param("zp", Shape::new(&[blocks_per_col, n], f)); // unused (sym)
14828        let dq = g.dequant_matmul(
14829            xn,
14830            wn,
14831            sn,
14832            zn,
14833            QuantScheme::Int8Block {
14834                block_size: block_size as u32,
14835            },
14836            Shape::new(&[m, n], f),
14837        );
14838        g.set_outputs(vec![dq]);
14839
14840        let plan = rlx_opt::memory::plan_memory(&g);
14841        let mut arena = crate::arena::Arena::from_plan(plan);
14842        let sched = compile_thunks(&g, &arena);
14843
14844        let xn_off = arena.byte_offset(xn);
14845        let wn_off = arena.byte_offset(wn);
14846        let sn_off = arena.byte_offset(sn);
14847        let zn_off = arena.byte_offset(zn);
14848        let dq_off = arena.byte_offset(dq);
14849        let buf = arena.raw_buf_mut();
14850        unsafe {
14851            // Seed f32 inputs.
14852            let xp = buf.as_mut_ptr().add(xn_off) as *mut f32;
14853            for (i, &v) in x.iter().enumerate() {
14854                *xp.add(i) = v;
14855            }
14856            let sp = buf.as_mut_ptr().add(sn_off) as *mut f32;
14857            for (i, &v) in scales.iter().enumerate() {
14858                *sp.add(i) = v;
14859            }
14860            let zp = buf.as_mut_ptr().add(zn_off) as *mut f32;
14861            for i in 0..(blocks_per_col * n) {
14862                *zp.add(i) = 0.0;
14863            }
14864            // Seed i8 weights byte-by-byte.
14865            let wp = buf.as_mut_ptr().add(wn_off) as *mut i8;
14866            for (i, &v) in w_q.iter().enumerate() {
14867                *wp.add(i) = v;
14868            }
14869        }
14870        execute_thunks(&sched, arena.raw_buf_mut());
14871
14872        let actual: Vec<f32> = unsafe {
14873            let p = arena.raw_buf().as_ptr().add(dq_off) as *const f32;
14874            (0..m * n).map(|i| *p.add(i)).collect()
14875        };
14876
14877        for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
14878            assert!(
14879                (e - a).abs() < 1e-3,
14880                "mismatch at {i}: expected {e}, got {a}"
14881            );
14882        }
14883    }
14884
14885    /// Plan #9: LoRA matmul matches the unfused 3-matmul reference.
14886    #[test]
14887    fn lora_matmul_matches_unfused_reference() {
14888        use rlx_ir::Philox4x32;
14889
14890        let m = 4usize;
14891        let k = 8usize;
14892        let n = 6usize;
14893        let r = 2usize;
14894        let scale = 0.5f32;
14895
14896        // Random inputs (deterministic via Philox).
14897        let mut rng = Philox4x32::new(42);
14898        let mut x = vec![0f32; m * k];
14899        rng.fill_normal(&mut x);
14900        let mut w = vec![0f32; k * n];
14901        rng.fill_normal(&mut w);
14902        let mut a = vec![0f32; k * r];
14903        rng.fill_normal(&mut a);
14904        let mut b = vec![0f32; r * n];
14905        rng.fill_normal(&mut b);
14906
14907        // Reference: out = x·W + scale * x·A·B. Naive triple-loop.
14908        let naive = |a_buf: &[f32], b_buf: &[f32], rows: usize, inner: usize, cols: usize| {
14909            let mut o = vec![0f32; rows * cols];
14910            for i in 0..rows {
14911                for j in 0..cols {
14912                    let mut acc = 0f32;
14913                    for p in 0..inner {
14914                        acc += a_buf[i * inner + p] * b_buf[p * cols + j];
14915                    }
14916                    o[i * cols + j] = acc;
14917                }
14918            }
14919            o
14920        };
14921        let xw = naive(&x, &w, m, k, n);
14922        let xa = naive(&x, &a, m, k, r);
14923        let xab = naive(&xa, &b, m, r, n);
14924        let mut expected = xw;
14925        for i in 0..(m * n) {
14926            expected[i] += scale * xab[i];
14927        }
14928
14929        // RLX path: build a graph with one LoraMatMul.
14930        let f = DType::F32;
14931        let mut g = Graph::new("lora");
14932        let xn = g.input("x", Shape::new(&[m, k], f));
14933        let wn = g.param("w", Shape::new(&[k, n], f));
14934        let an = g.param("a", Shape::new(&[k, r], f));
14935        let bn = g.param("b", Shape::new(&[r, n], f));
14936        let lm = g.lora_matmul(xn, wn, an, bn, scale, Shape::new(&[m, n], f));
14937        g.set_outputs(vec![lm]);
14938
14939        let plan = rlx_opt::memory::plan_memory(&g);
14940        let mut arena = crate::arena::Arena::from_plan(plan);
14941        let sched = compile_thunks(&g, &arena);
14942
14943        let xn_off = arena.byte_offset(xn);
14944        let wn_off = arena.byte_offset(wn);
14945        let an_off = arena.byte_offset(an);
14946        let bn_off = arena.byte_offset(bn);
14947        let lm_off = arena.byte_offset(lm);
14948        let buf = arena.raw_buf_mut();
14949        unsafe {
14950            let copy = |dst: *mut f32, data: &[f32]| {
14951                for (i, &v) in data.iter().enumerate() {
14952                    *dst.add(i) = v;
14953                }
14954            };
14955            copy(buf.as_mut_ptr().add(xn_off) as *mut f32, &x);
14956            copy(buf.as_mut_ptr().add(wn_off) as *mut f32, &w);
14957            copy(buf.as_mut_ptr().add(an_off) as *mut f32, &a);
14958            copy(buf.as_mut_ptr().add(bn_off) as *mut f32, &b);
14959        }
14960        execute_thunks(&sched, arena.raw_buf_mut());
14961
14962        let actual: Vec<f32> = unsafe {
14963            let p = arena.raw_buf().as_ptr().add(lm_off) as *const f32;
14964            (0..m * n).map(|i| *p.add(i)).collect()
14965        };
14966
14967        for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
14968            assert!(
14969                (e - a).abs() < 1e-3,
14970                "mismatch at {i}: expected {e}, got {a}"
14971            );
14972        }
14973    }
14974
14975    /// Plan #42: fused sampling kernel determinism + greedy fallback.
14976    #[test]
14977    fn sample_temperature_zero_is_argmax() {
14978        // Very low temperature → distribution collapses on argmax.
14979        // Same seed → same output bit-for-bit.
14980        let f = DType::F32;
14981        let mut g = Graph::new("samp");
14982        let logits = g.input("logits", Shape::new(&[1, 8], f));
14983        let s = g.sample(logits, 0, 1.0, 1e-3, 42, Shape::new(&[1], f));
14984        g.set_outputs(vec![s]);
14985        let plan = rlx_opt::memory::plan_memory(&g);
14986        let mut arena = crate::arena::Arena::from_plan(plan);
14987        let sched = compile_thunks(&g, &arena);
14988
14989        let logits_off = arena.byte_offset(logits);
14990        let s_off = arena.byte_offset(s);
14991        let buf = arena.raw_buf_mut();
14992        unsafe {
14993            let p = buf.as_mut_ptr().add(logits_off) as *mut f32;
14994            // argmax = index 5 (value 9.0).
14995            let inputs = [0.1f32, 0.2, 0.3, 0.4, 0.5, 9.0, 0.7, 0.8];
14996            for (i, &v) in inputs.iter().enumerate() {
14997                *p.add(i) = v;
14998            }
14999        }
15000        execute_thunks(&sched, arena.raw_buf_mut());
15001
15002        let token = unsafe {
15003            let p = arena.raw_buf().as_ptr().add(s_off) as *const f32;
15004            *p as usize
15005        };
15006        assert_eq!(token, 5, "low-temp sampling should pick the argmax");
15007    }
15008
15009    #[test]
15010    fn sample_top_k_one_is_deterministic() {
15011        // top_k=1 forces only the argmax to have nonzero probability.
15012        let f = DType::F32;
15013        let mut g = Graph::new("samp_k1");
15014        let logits = g.input("logits", Shape::new(&[1, 4], f));
15015        let s = g.sample(logits, 1, 1.0, 1.0, 7, Shape::new(&[1], f));
15016        g.set_outputs(vec![s]);
15017        let plan = rlx_opt::memory::plan_memory(&g);
15018        let mut arena = crate::arena::Arena::from_plan(plan);
15019        let sched = compile_thunks(&g, &arena);
15020
15021        let logits_off = arena.byte_offset(logits);
15022        let s_off = arena.byte_offset(s);
15023        let buf = arena.raw_buf_mut();
15024        unsafe {
15025            let p = buf.as_mut_ptr().add(logits_off) as *mut f32;
15026            let inputs = [0.1f32, 5.0, 0.3, 0.4]; // argmax = 1
15027            for (i, &v) in inputs.iter().enumerate() {
15028                *p.add(i) = v;
15029            }
15030        }
15031        execute_thunks(&sched, arena.raw_buf_mut());
15032        let token = unsafe {
15033            let p = arena.raw_buf().as_ptr().add(s_off) as *const f32;
15034            *p as usize
15035        };
15036        assert_eq!(token, 1);
15037    }
15038
15039    /// Plan #44: cumsum primitive parity vs. naive scan.
15040    #[test]
15041    fn cumsum_inclusive_matches_naive() {
15042        let f = DType::F32;
15043        let mut g = Graph::new("cumsum");
15044        let x = g.input("x", Shape::new(&[2, 4], f));
15045        let cs = g.cumsum(x, -1, false, Shape::new(&[2, 4], f));
15046        g.set_outputs(vec![cs]);
15047        let plan = rlx_opt::memory::plan_memory(&g);
15048        let mut arena = crate::arena::Arena::from_plan(plan);
15049        let sched = compile_thunks(&g, &arena);
15050
15051        // Cache offsets up-front so we can drop the immutable borrow.
15052        let x_off = arena.byte_offset(x);
15053        let out_off = arena.byte_offset(cs);
15054        let buf = arena.raw_buf_mut();
15055        unsafe {
15056            let p = buf.as_mut_ptr().add(x_off) as *mut f32;
15057            let inputs = [1.0f32, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0];
15058            for (i, &v) in inputs.iter().enumerate() {
15059                *p.add(i) = v;
15060            }
15061        }
15062        execute_thunks(&sched, arena.raw_buf_mut());
15063
15064        let out: Vec<f32> = unsafe {
15065            let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
15066            (0..8).map(|i| *p.add(i)).collect()
15067        };
15068        assert_eq!(out, vec![1.0, 3.0, 6.0, 10.0, 10.0, 30.0, 60.0, 100.0]);
15069    }
15070
15071    /// Plan #46 deep: Narrow×3 → Attention fusion. The three QKV
15072    /// narrows that BERT/Nomic emit on the unfused (batch*seq > 64)
15073    /// path collapse into a single strided-Attention thunk.
15074    #[test]
15075    fn narrow_attention_fuses_in_unfused_path() {
15076        let f = DType::F32;
15077        let mut g = Graph::new("nattn_fuse");
15078        // batch*seq = 8*16 = 128 > 64 so FusedAttnBlock skips.
15079        let qkv = g.input("qkv", Shape::new(&[8, 16, 192], f)); // 3*64 = 192
15080        let mask = g.input("mask", Shape::new(&[8, 16], f));
15081        let q = g.narrow_(qkv, 2, 0, 64);
15082        let k = g.narrow_(qkv, 2, 64, 64);
15083        let v = g.narrow_(qkv, 2, 128, 64);
15084        let attn = g.attention(q, k, v, mask, 4, 16, Shape::new(&[8, 16, 64], f));
15085        g.set_outputs(vec![attn]);
15086
15087        let plan = rlx_opt::memory::plan_memory(&g);
15088        let arena = crate::arena::Arena::from_plan(plan);
15089        let sched = compile_thunks(&g, &arena);
15090
15091        let mut narrow_count = 0;
15092        let mut attn_strides: Option<(u32, u32, u32)> = None;
15093        for t in &sched.thunks {
15094            match t {
15095                Thunk::Narrow { .. } => narrow_count += 1,
15096                Thunk::Attention {
15097                    q_row_stride,
15098                    k_row_stride,
15099                    v_row_stride,
15100                    ..
15101                } => attn_strides = Some((*q_row_stride, *k_row_stride, *v_row_stride)),
15102                _ => {}
15103            }
15104        }
15105        // After fusion the 3 narrows are gone; Attention now walks the
15106        // QKV with parent row stride = 192 (3 × 64) on all three inputs.
15107        assert_eq!(
15108            narrow_count, 0,
15109            "Narrow×3→Attention fusion should eliminate all 3 narrows; saw {narrow_count}"
15110        );
15111        assert_eq!(
15112            attn_strides,
15113            Some((192, 192, 192)),
15114            "Attention should walk Q/K/V with parent row stride 192"
15115        );
15116    }
15117
    // ── Backward / training op parity tests ────────────────────
    //
    // Strategy: build a graph that contains exactly the backward op
    // under test (plus its inputs as graph Inputs), execute it, and
    // compare against a hand-rolled scalar reference. The
    // Conv2dBackwardInput and Conv2dBackwardWeight tests use a different
    // reference: the central-difference numerical gradient of a scalar
    // forward Conv2D. That is the gold-standard test, because it validates
    // the math rather than just consistency between two implementations
    // of the same formula.
15127
15128    fn run_graph(
15129        g: &Graph,
15130        inputs: &[(NodeId, &[f32])],
15131        out_id: NodeId,
15132        out_len: usize,
15133    ) -> Vec<f32> {
15134        let plan = rlx_opt::memory::plan_memory(g);
15135        let mut arena = crate::arena::Arena::from_plan(plan);
15136        let sched = compile_thunks(g, &arena);
15137        for &(id, data) in inputs {
15138            let off = arena.byte_offset(id);
15139            let buf = arena.raw_buf_mut();
15140            unsafe {
15141                let p = buf.as_mut_ptr().add(off) as *mut f32;
15142                for (i, &v) in data.iter().enumerate() {
15143                    *p.add(i) = v;
15144                }
15145            }
15146        }
15147        execute_thunks(&sched, arena.raw_buf_mut());
15148        let off = arena.byte_offset(out_id);
15149        unsafe {
15150            let p = arena.raw_buf().as_ptr().add(off) as *const f32;
15151            (0..out_len).map(|i| *p.add(i)).collect()
15152        }
15153    }
15154
15155    #[test]
15156    fn relu_backward_matches_mask() {
15157        let f = DType::F32;
15158        let len = 7usize;
15159        let x: Vec<f32> = vec![-2.0, -0.1, 0.0, 0.1, 1.0, 3.0, -5.0];
15160        let dy: Vec<f32> = vec![0.5, 1.5, 2.5, -0.7, 4.0, -1.0, 9.0];
15161
15162        let mut g = Graph::new("relu_bw");
15163        let xn = g.input("x", Shape::new(&[len], f));
15164        let dyn_ = g.input("dy", Shape::new(&[len], f));
15165        let dx = g.relu_backward(xn, dyn_);
15166        g.set_outputs(vec![dx]);
15167
15168        let actual = run_graph(&g, &[(xn, &x), (dyn_, &dy)], dx, len);
15169        // Reference: gradient is dy where x>0 strictly, else 0.
15170        // (zero is not "positive" — the forward applied max(0, x), and at
15171        // x=0 the subgradient could be anything in [0, dy]; we pick 0.)
15172        let expected: Vec<f32> = x
15173            .iter()
15174            .zip(&dy)
15175            .map(|(&xi, &dyi)| if xi > 0.0 { dyi } else { 0.0 })
15176            .collect();
15177        for (a, e) in actual.iter().zip(&expected) {
15178            assert!((a - e).abs() < 1e-6, "relu_bw mismatch: {a} vs {e}");
15179        }
15180    }
15181
15182    #[test]
15183    fn maxpool2d_backward_routes_to_argmax() {
15184        let f = DType::F32;
15185        // [N=1, C=1, H=4, W=4] → 2x2 max-pool stride 2 → [1,1,2,2].
15186        let x: Vec<f32> = vec![
15187            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
15188        ];
15189        // Argmax of each 2x2 window:
15190        //   (0,0)→6 (idx 5), (0,1)→8 (idx 7),
15191        //   (1,0)→14(idx 13),(1,1)→16(idx 15).
15192        let dy: Vec<f32> = vec![0.5, 1.0, 2.0, 4.0];
15193
15194        let mut g = Graph::new("maxpool_bw");
15195        let xn = g.input("x", Shape::new(&[1, 1, 4, 4], f));
15196        let dyn_ = g.input("dy", Shape::new(&[1, 1, 2, 2], f));
15197        let dx = g.maxpool2d_backward(xn, dyn_, vec![2, 2], vec![2, 2], vec![0, 0]);
15198        g.set_outputs(vec![dx]);
15199
15200        let actual = run_graph(&g, &[(xn, &x), (dyn_, &dy)], dx, 16);
15201        let mut expected = vec![0f32; 16];
15202        expected[5] = 0.5;
15203        expected[7] = 1.0;
15204        expected[13] = 2.0;
15205        expected[15] = 4.0;
15206        for (i, (a, e)) in actual.iter().zip(&expected).enumerate() {
15207            assert!((a - e).abs() < 1e-6, "maxpool_bw[{i}] mismatch: {a} vs {e}");
15208        }
15209    }
15210
15211    #[test]
15212    fn conv2d_backward_input_matches_numerical_gradient() {
15213        use rlx_ir::Philox4x32;
15214        // Small enough to numerically differentiate exhaustively but
15215        // big enough to exercise stride/padding edge cases.
15216        let n = 1usize;
15217        let c_in = 2usize;
15218        let h = 4usize;
15219        let w = 4usize;
15220        let c_out = 3usize;
15221        let kh = 3usize;
15222        let kw = 3usize;
15223        let ph = 1usize;
15224        let pw = 1usize;
15225        let sh = 1usize;
15226        let sw = 1usize;
15227        // Output dims with padding=1, stride=1: same as input.
15228        let h_out = (h + 2 * ph - kh) / sh + 1;
15229        let w_out = (w + 2 * pw - kw) / sw + 1;
15230        assert_eq!(h_out, 4);
15231        assert_eq!(w_out, 4);
15232
15233        let mut rng = Philox4x32::new(7);
15234        let mut x = vec![0f32; n * c_in * h * w];
15235        rng.fill_normal(&mut x);
15236        let mut wt = vec![0f32; c_out * c_in * kh * kw];
15237        rng.fill_normal(&mut wt);
15238        let mut dy = vec![0f32; n * c_out * h_out * w_out];
15239        rng.fill_normal(&mut dy);
15240
15241        // Analytical: Conv2dBackwardInput on (dy, w).
15242        let f = DType::F32;
15243        let mut g = Graph::new("conv_bwi");
15244        let dy_in = g.input("dy", Shape::new(&[n, c_out, h_out, w_out], f));
15245        let w_in = g.input("w", Shape::new(&[c_out, c_in, kh, kw], f));
15246        let dx = g.conv2d_backward_input(
15247            dy_in,
15248            w_in,
15249            Shape::new(&[n, c_in, h, w], f),
15250            vec![kh, kw],
15251            vec![sh, sw],
15252            vec![ph, pw],
15253            vec![1, 1],
15254            1,
15255        );
15256        g.set_outputs(vec![dx]);
15257        let analytical = run_graph(&g, &[(dy_in, &dy), (w_in, &wt)], dx, n * c_in * h * w);
15258
15259        // Numerical: for each x[i], finite-difference forward conv twice.
15260        // Forward: y[j] = sum over filter window of w * x ; dot(dy, y) is
15261        // the scalar we differentiate. Then dx[i] = ∂(dot(dy, y))/∂x[i].
15262        let forward = |x: &[f32]| -> Vec<f32> {
15263            let mut out = vec![0f32; n * c_out * h_out * w_out];
15264            for ni in 0..n {
15265                for co in 0..c_out {
15266                    for ho in 0..h_out {
15267                        for wo in 0..w_out {
15268                            let mut acc = 0f32;
15269                            for ci in 0..c_in {
15270                                for ki in 0..kh {
15271                                    for kj in 0..kw {
15272                                        let hi = ho * sh + ki;
15273                                        let wi = wo * sw + kj;
15274                                        if hi < ph || wi < pw {
15275                                            continue;
15276                                        }
15277                                        let hi = hi - ph;
15278                                        let wi = wi - pw;
15279                                        if hi >= h || wi >= w {
15280                                            continue;
15281                                        }
15282                                        let xv = x[((ni * c_in) + ci) * h * w + hi * w + wi];
15283                                        let wv = wt[((co * c_in) + ci) * kh * kw + ki * kw + kj];
15284                                        acc += xv * wv;
15285                                    }
15286                                }
15287                            }
15288                            out[((ni * c_out) + co) * h_out * w_out + ho * w_out + wo] = acc;
15289                        }
15290                    }
15291                }
15292            }
15293            out
15294        };
15295        let dot = |a: &[f32], b: &[f32]| -> f32 { a.iter().zip(b).map(|(&u, &v)| u * v).sum() };
15296        let eps = 1e-3f32;
15297        let mut numerical = vec![0f32; x.len()];
15298        for i in 0..x.len() {
15299            let saved = x[i];
15300            x[i] = saved + eps;
15301            let plus = dot(&forward(&x), &dy);
15302            x[i] = saved - eps;
15303            let minus = dot(&forward(&x), &dy);
15304            x[i] = saved;
15305            numerical[i] = (plus - minus) / (2.0 * eps);
15306        }
15307        for (i, (a, n)) in analytical.iter().zip(&numerical).enumerate() {
15308            // f32 + eps=1e-3 numerical grad → ~1e-3 absolute is realistic.
15309            assert!(
15310                (a - n).abs() < 5e-3,
15311                "conv_bw_input[{i}]: analytical {a} vs numerical {n}"
15312            );
15313        }
15314    }
15315
15316    #[test]
15317    fn conv2d_backward_weight_matches_numerical_gradient() {
15318        use rlx_ir::Philox4x32;
15319        let n = 2usize;
15320        let c_in = 2usize;
15321        let h = 4usize;
15322        let w = 4usize;
15323        let c_out = 2usize;
15324        let kh = 3usize;
15325        let kw = 3usize;
15326        let ph = 0usize;
15327        let pw = 0usize;
15328        let sh = 1usize;
15329        let sw = 1usize;
15330        let h_out = (h + 2 * ph - kh) / sh + 1;
15331        let w_out = (w + 2 * pw - kw) / sw + 1;
15332
15333        let mut rng = Philox4x32::new(11);
15334        let mut x = vec![0f32; n * c_in * h * w];
15335        rng.fill_normal(&mut x);
15336        let mut wt = vec![0f32; c_out * c_in * kh * kw];
15337        rng.fill_normal(&mut wt);
15338        let mut dy = vec![0f32; n * c_out * h_out * w_out];
15339        rng.fill_normal(&mut dy);
15340
15341        let f = DType::F32;
15342        let mut g = Graph::new("conv_bww");
15343        let xn = g.input("x", Shape::new(&[n, c_in, h, w], f));
15344        let dyn_ = g.input("dy", Shape::new(&[n, c_out, h_out, w_out], f));
15345        let dwn = g.conv2d_backward_weight(
15346            xn,
15347            dyn_,
15348            Shape::new(&[c_out, c_in, kh, kw], f),
15349            vec![kh, kw],
15350            vec![sh, sw],
15351            vec![ph, pw],
15352            vec![1, 1],
15353            1,
15354        );
15355        g.set_outputs(vec![dwn]);
15356        let analytical = run_graph(&g, &[(xn, &x), (dyn_, &dy)], dwn, c_out * c_in * kh * kw);
15357
15358        let forward = |wt: &[f32]| -> Vec<f32> {
15359            let mut out = vec![0f32; n * c_out * h_out * w_out];
15360            for ni in 0..n {
15361                for co in 0..c_out {
15362                    for ho in 0..h_out {
15363                        for wo in 0..w_out {
15364                            let mut acc = 0f32;
15365                            for ci in 0..c_in {
15366                                for ki in 0..kh {
15367                                    for kj in 0..kw {
15368                                        let hi = ho + ki;
15369                                        let wi = wo + kj;
15370                                        let xv = x[((ni * c_in) + ci) * h * w + hi * w + wi];
15371                                        let wv = wt[((co * c_in) + ci) * kh * kw + ki * kw + kj];
15372                                        acc += xv * wv;
15373                                    }
15374                                }
15375                            }
15376                            out[((ni * c_out) + co) * h_out * w_out + ho * w_out + wo] = acc;
15377                        }
15378                    }
15379                }
15380            }
15381            out
15382        };
15383        let dot = |a: &[f32], b: &[f32]| -> f32 { a.iter().zip(b).map(|(&u, &v)| u * v).sum() };
15384        let eps = 1e-3f32;
15385        let mut numerical = vec![0f32; wt.len()];
15386        for i in 0..wt.len() {
15387            let saved = wt[i];
15388            wt[i] = saved + eps;
15389            let plus = dot(&forward(&wt), &dy);
15390            wt[i] = saved - eps;
15391            let minus = dot(&forward(&wt), &dy);
15392            wt[i] = saved;
15393            numerical[i] = (plus - minus) / (2.0 * eps);
15394        }
15395        for (i, (a, n)) in analytical.iter().zip(&numerical).enumerate() {
15396            assert!(
15397                (a - n).abs() < 5e-3,
15398                "conv_bw_weight[{i}]: analytical {a} vs numerical {n}"
15399            );
15400        }
15401    }
15402
15403    #[test]
15404    fn softmax_cross_entropy_matches_reference() {
15405        let f = DType::F32;
15406        let logits: Vec<f32> = vec![
15407            1.0, 2.0, 3.0, // row 0: max=3 (idx 2)
15408            -1.0, 0.0, 4.0, // row 1: max=4 (idx 2)
15409            5.0, 5.0, 5.0, // row 2: uniform
15410        ];
15411        let labels: Vec<f32> = vec![2.0, 0.0, 1.0];
15412
15413        let mut g = Graph::new("sce");
15414        let lg = g.input("logits", Shape::new(&[3, 3], f));
15415        let lb = g.input("labels", Shape::new(&[3], f));
15416        let loss = g.softmax_cross_entropy_with_logits(lg, lb);
15417        g.set_outputs(vec![loss]);
15418        let actual = run_graph(&g, &[(lg, &logits), (lb, &labels)], loss, 3);
15419
15420        // Reference per-row: -log(softmax(row)[label]).
15421        let mut expected = vec![0f32; 3];
15422        for ni in 0..3 {
15423            let row = &logits[ni * 3..(ni + 1) * 3];
15424            let m = row.iter().fold(f32::NEG_INFINITY, |a, &v| a.max(v));
15425            let sum: f32 = row.iter().map(|&v| (v - m).exp()).sum();
15426            let lse = m + sum.ln();
15427            let label_idx = labels[ni] as usize;
15428            expected[ni] = lse - row[label_idx];
15429        }
15430        for (i, (a, e)) in actual.iter().zip(&expected).enumerate() {
15431            assert!((a - e).abs() < 1e-5, "sce loss[{i}]: {a} vs {e}");
15432        }
15433    }
15434
15435    #[test]
15436    fn softmax_cross_entropy_backward_matches_numerical_gradient() {
15437        use rlx_ir::Philox4x32;
15438        let n = 4usize;
15439        let c = 5usize;
15440        let mut rng = Philox4x32::new(23);
15441        let mut logits = vec![0f32; n * c];
15442        rng.fill_normal(&mut logits);
15443        let labels: Vec<f32> = (0..n).map(|i| (i % c) as f32).collect();
15444        let mut d_loss = vec![0f32; n];
15445        rng.fill_normal(&mut d_loss);
15446
15447        let f = DType::F32;
15448        let mut g = Graph::new("sce_bw");
15449        let lg = g.input("logits", Shape::new(&[n, c], f));
15450        let lb = g.input("labels", Shape::new(&[n], f));
15451        let dl = g.input("d_loss", Shape::new(&[n], f));
15452        let dlogits = g.softmax_cross_entropy_backward(lg, lb, dl);
15453        g.set_outputs(vec![dlogits]);
15454        let analytical = run_graph(
15455            &g,
15456            &[(lg, &logits), (lb, &labels), (dl, &d_loss)],
15457            dlogits,
15458            n * c,
15459        );
15460
15461        // Numerical: differentiate dot(d_loss, sce_loss(logits)) w.r.t. each logit.
15462        let sce_loss = |logits: &[f32]| -> Vec<f32> {
15463            let mut out = vec![0f32; n];
15464            for ni in 0..n {
15465                let row = &logits[ni * c..(ni + 1) * c];
15466                let m = row.iter().fold(f32::NEG_INFINITY, |a, &v| a.max(v));
15467                let sum: f32 = row.iter().map(|&v| (v - m).exp()).sum();
15468                out[ni] = (m + sum.ln()) - row[labels[ni] as usize];
15469            }
15470            out
15471        };
15472        let dot = |a: &[f32], b: &[f32]| a.iter().zip(b).map(|(&u, &v)| u * v).sum::<f32>();
15473        let eps = 1e-3f32;
15474        let mut numerical = vec![0f32; logits.len()];
15475        for i in 0..logits.len() {
15476            let saved = logits[i];
15477            logits[i] = saved + eps;
15478            let plus = dot(&sce_loss(&logits), &d_loss);
15479            logits[i] = saved - eps;
15480            let minus = dot(&sce_loss(&logits), &d_loss);
15481            logits[i] = saved;
15482            numerical[i] = (plus - minus) / (2.0 * eps);
15483        }
15484        for (i, (a, num)) in analytical.iter().zip(&numerical).enumerate() {
15485            assert!(
15486                (a - num).abs() < 5e-3,
15487                "sce_bw[{i}]: analytical {a} vs numerical {num}"
15488            );
15489        }
15490    }
15491
    // ── End-to-end autodiff parity tests ──────────────────────
    //
    // Build a forward graph, run `grad_with_loss` to produce a graph
    // that emits [loss, gradients...], execute it through rlx-cpu,
    // and compare each gradient against a reference: a finite-difference
    // estimate obtained by re-running the forward computation with each
    // parameter entry perturbed, or a closed-form answer where one exists.
    // With f32 and ε=1e-3, finite-difference agreement can only be expected
    // to about 5e-3 absolute error; the f64 tests use much tighter tolerances.
15500
15501    /// Initialize Op::Constant slots in the arena with their literal
15502    /// data. Mirrors the loop in rlx_runtime::backend (which serves
15503    /// the same role for production runs).
15504    fn fill_constants_into_arena(graph: &Graph, arena: &mut crate::arena::Arena) {
15505        for node in graph.nodes() {
15506            if let Op::Constant { data } = &node.op
15507                && arena.has_buffer(node.id)
15508                && !data.is_empty()
15509            {
15510                let buf = arena.slice_mut(node.id);
15511                let n_floats = data.len() / 4;
15512                let n = buf.len().min(n_floats);
15513                for i in 0..n {
15514                    let bytes = [
15515                        data[i * 4],
15516                        data[i * 4 + 1],
15517                        data[i * 4 + 2],
15518                        data[i * 4 + 3],
15519                    ];
15520                    buf[i] = f32::from_le_bytes(bytes);
15521                }
15522            }
15523        }
15524    }
15525
15526    /// Compile + arena-prep helper for these tests. Returns the
15527    /// schedule and a populated arena. `seed_inputs` writes f32 input
15528    /// data into the arena slot for each (NodeId, &[f32]) pair.
15529    fn prepare(
15530        graph: &Graph,
15531        seed_inputs: &[(NodeId, &[f32])],
15532    ) -> (ThunkSchedule, crate::arena::Arena) {
15533        let plan = rlx_opt::memory::plan_memory(graph);
15534        let mut arena = crate::arena::Arena::from_plan(plan);
15535        let sched = compile_thunks(graph, &arena);
15536        fill_constants_into_arena(graph, &mut arena);
15537        for &(id, data) in seed_inputs {
15538            let off = arena.byte_offset(id);
15539            let buf = arena.raw_buf_mut();
15540            unsafe {
15541                let p = buf.as_mut_ptr().add(off) as *mut f32;
15542                for (i, &v) in data.iter().enumerate() {
15543                    *p.add(i) = v;
15544                }
15545            }
15546        }
15547        (sched, arena)
15548    }
15549
15550    fn read_arena(arena: &crate::arena::Arena, id: NodeId, len: usize) -> Vec<f32> {
15551        let off = arena.byte_offset(id);
15552        unsafe {
15553            let p = arena.raw_buf().as_ptr().add(off) as *const f32;
15554            (0..len).map(|i| *p.add(i)).collect()
15555        }
15556    }
15557
15558    fn write_arena(arena: &mut crate::arena::Arena, id: NodeId, data: &[f32]) {
15559        let off = arena.byte_offset(id);
15560        let buf = arena.raw_buf_mut();
15561        unsafe {
15562            let p = buf.as_mut_ptr().add(off) as *mut f32;
15563            for (i, &v) in data.iter().enumerate() {
15564                *p.add(i) = v;
15565            }
15566        }
15567    }
15568
15569    /// f64 sibling of `prepare`. Writes f64 input data into the arena.
15570    fn prepare_f64(
15571        graph: &Graph,
15572        seed_inputs: &[(NodeId, &[f64])],
15573    ) -> (ThunkSchedule, crate::arena::Arena) {
15574        let plan = rlx_opt::memory::plan_memory(graph);
15575        let mut arena = crate::arena::Arena::from_plan(plan);
15576        let sched = compile_thunks(graph, &arena);
15577        fill_constants_into_arena(graph, &mut arena);
15578        for &(id, data) in seed_inputs {
15579            let off = arena.byte_offset(id);
15580            let buf = arena.raw_buf_mut();
15581            unsafe {
15582                let p = buf.as_mut_ptr().add(off) as *mut f64;
15583                for (i, &v) in data.iter().enumerate() {
15584                    *p.add(i) = v;
15585                }
15586            }
15587        }
15588        (sched, arena)
15589    }
15590
15591    fn read_arena_f64(arena: &crate::arena::Arena, id: NodeId, len: usize) -> Vec<f64> {
15592        let off = arena.byte_offset(id);
15593        unsafe {
15594            let p = arena.raw_buf().as_ptr().add(off) as *const f64;
15595            (0..len).map(|i| *p.add(i)).collect()
15596        }
15597    }
15598
15599    /// End-to-end f64 DenseSolve through the full compile + execute
15600    /// path. Validates: IR shape inference, memory planner f64 sizing,
15601    /// arena f64 accessors, Thunk::DenseSolveF64 lowering, executor
15602    /// dispatch, Accelerate dgesv FFI.
15603    ///
15604    /// System:
15605    ///   A = [[2, 1],
15606    ///        [1, 3]]   b = [5, 10]
15607    ///   ⇒  x = [1, 3]   (verified by hand)
15608    #[test]
15609    fn dense_solve_f64_end_to_end() {
15610        let mut g = Graph::new("solve_e2e");
15611        let a = g.input("A", Shape::new(&[2, 2], DType::F64));
15612        let b = g.input("b", Shape::new(&[2], DType::F64));
15613        let x = g.dense_solve(a, b, Shape::new(&[2], DType::F64));
15614        g.set_outputs(vec![x]);
15615
15616        let a_data = [2.0, 1.0, 1.0, 3.0_f64];
15617        let b_data = [5.0, 10.0_f64];
15618        let (sched, mut arena) = prepare_f64(&g, &[(a, &a_data), (b, &b_data)]);
15619        execute_thunks(&sched, arena.raw_buf_mut());
15620
15621        let got = read_arena_f64(&arena, x, 2);
15622        let want = [1.0, 3.0_f64];
15623        for i in 0..2 {
15624            assert!(
15625                (got[i] - want[i]).abs() < 1e-12,
15626                "x[{i}] = {} (expected {})",
15627                got[i],
15628                want[i]
15629            );
15630        }
15631    }
15632
15633    /// Scaled-up f64 DenseSolve — tridiagonal Laplacian-shape (typical
15634    /// MNA structure for a passive RC mesh in Circulax). Validates
15635    /// that the solve scales beyond the trivial 2×2 and that the
15636    /// row-major ↔ col-major dance in `dgesv` is correct for the
15637    /// general case.
15638    #[test]
15639    fn dense_solve_f64_5x5_laplacian() {
15640        let n = 5usize;
15641        let mut g = Graph::new("solve_5x5");
15642        let a = g.input("A", Shape::new(&[n, n], DType::F64));
15643        let b = g.input("b", Shape::new(&[n], DType::F64));
15644        let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
15645        g.set_outputs(vec![x]);
15646
15647        // 1-D Laplacian: 2 on diagonal, -1 on off-diagonals, 0 elsewhere.
15648        let mut a_data = vec![0.0_f64; n * n];
15649        for i in 0..n {
15650            a_data[i * n + i] = 2.0;
15651            if i > 0 {
15652                a_data[i * n + (i - 1)] = -1.0;
15653            }
15654            if i + 1 < n {
15655                a_data[i * n + (i + 1)] = -1.0;
15656            }
15657        }
15658        let b_data: Vec<f64> = (0..n).map(|i| (i + 1) as f64).collect();
15659        let (sched, mut arena) = prepare_f64(&g, &[(a, &a_data), (b, &b_data)]);
15660        execute_thunks(&sched, arena.raw_buf_mut());
15661
15662        let got = read_arena_f64(&arena, x, n);
15663        // Verify A·x ≈ b by computing the residual.
15664        let mut residual = vec![0.0_f64; n];
15665        for i in 0..n {
15666            for j in 0..n {
15667                residual[i] += a_data[i * n + j] * got[j];
15668            }
15669        }
15670        for i in 0..n {
15671            assert!(
15672                (residual[i] - b_data[i]).abs() < 1e-10,
15673                "row {i}: residual {} vs b {}",
15674                residual[i],
15675                b_data[i]
15676            );
15677        }
15678    }
15679
15680    /// Hello Resistor: end-to-end f64 gradient through a dense solve.
15681    ///
15682    /// Forward:
15683    ///   A      : Param  [N, N]   f64
15684    ///   b      : Input  [N]      f64
15685    ///   x      = solve(A, b)            (DenseSolve)
15686    ///   loss   = sum(x)                 (Reduce::Sum)
15687    ///
15688    /// Backward (via grad_with_loss):
15689    ///   ones [N] = expand(d_output, [N])      (Reduce::Sum VJP)
15690    ///   dx_int   = solve(Aᵀ, ones)             (DenseSolve VJP step 1)
15691    ///   dA       = -outer(dx_int, x)           (DenseSolve VJP step 2)
15692    ///   db       = dx_int                       (DenseSolve VJP step 3)
15693    ///
15694    /// Closed form: with loss = sum(solve(A, b)) = ones·x and
15695    /// implicit-function calculus, db = (Aᵀ)⁻¹·ones, dA = -db ⊗ x.
15696    /// We verify this against the autodiff-emitted graph's output and
15697    /// against a finite-difference baseline.
15698    #[test]
15699    fn hello_resistor_gradient_end_to_end() {
15700        use rlx_opt::autodiff::grad_with_loss;
15701        let n = 3usize;
15702
15703        // ── Build forward graph ──
15704        let mut g = Graph::new("hello_resistor");
15705        let a = g.param("A", Shape::new(&[n, n], DType::F64));
15706        let b = g.input("b", Shape::new(&[n], DType::F64));
15707        let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
15708        let loss = g.reduce(
15709            x,
15710            ReduceOp::Sum,
15711            vec![0],
15712            false,
15713            Shape::new(&[1], DType::F64),
15714        );
15715        g.set_outputs(vec![loss]);
15716
15717        // ── Run reverse-mode AD ──
15718        let bwd = grad_with_loss(&g, &[a, b]);
15719        assert_eq!(bwd.outputs.len(), 3, "expect [loss, dA, db]");
15720
15721        // ── Locate the inputs the bwd graph still needs from us ──
15722        // grad_with_loss copies forward nodes into bwd, so A/b/d_output
15723        // appear under their original names. Find them by name.
15724        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
15725            for node in graph.nodes() {
15726                let name = match &node.op {
15727                    rlx_ir::Op::Input { name } => Some(name.as_str()),
15728                    rlx_ir::Op::Param { name } => Some(name.as_str()),
15729                    _ => None,
15730                };
15731                if name == Some(want) {
15732                    return node.id;
15733                }
15734            }
15735            panic!("no node named {want:?} in bwd graph");
15736        };
15737        let a_bwd = find_by_name(&bwd, "A");
15738        let b_bwd = find_by_name(&bwd, "b");
15739        let d_out_bwd = find_by_name(&bwd, "d_output");
15740
15741        // ── Test data ──
15742        // A = [[2,1,0],[1,3,1],[0,1,2]]   (SPD tridiagonal, well-conditioned)
15743        // b = [1,2,3]
15744        let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
15745        let b_data = [1.0, 2.0, 3.0_f64];
15746        let d_output = [1.0_f64]; // ∂loss/∂loss
15747
15748        // ── Compile + execute backward graph ──
15749        let (sched, mut arena) = prepare_f64(
15750            &bwd,
15751            &[(a_bwd, &a_data), (b_bwd, &b_data), (d_out_bwd, &d_output)],
15752        );
15753        execute_thunks(&sched, arena.raw_buf_mut());
15754
15755        let loss_out = read_arena_f64(&arena, bwd.outputs[0], 1);
15756        let da_out = read_arena_f64(&arena, bwd.outputs[1], n * n);
15757        let db_out = read_arena_f64(&arena, bwd.outputs[2], n);
15758
15759        // ── Closed-form reference ──
15760        // x = A⁻¹ b ; loss = sum(x).
15761        let x_ref = {
15762            let mut a = a_data;
15763            let mut b = b_data;
15764            let info = crate::blas::dgesv(&mut a, &mut b, n, 1);
15765            assert_eq!(info, 0);
15766            b
15767        };
15768        let loss_ref: f64 = x_ref.iter().sum();
15769        // db = (Aᵀ)⁻¹ · 1
15770        let db_ref = {
15771            let mut at = [0.0_f64; 9];
15772            for i in 0..n {
15773                for j in 0..n {
15774                    at[i * n + j] = a_data[j * n + i];
15775                }
15776            }
15777            let mut ones = [1.0_f64; 3];
15778            let info = crate::blas::dgesv(&mut at, &mut ones, n, 1);
15779            assert_eq!(info, 0);
15780            ones
15781        };
15782        // dA = -outer(db, x) ; dA[i,j] = -db[i] * x[j]
15783        let mut da_ref = [0.0_f64; 9];
15784        for i in 0..n {
15785            for j in 0..n {
15786                da_ref[i * n + j] = -db_ref[i] * x_ref[j];
15787            }
15788        }
15789
15790        // ── Assertions vs analytic answer ──
15791        assert!(
15792            (loss_out[0] - loss_ref).abs() < 1e-10,
15793            "loss: got {}, want {}",
15794            loss_out[0],
15795            loss_ref
15796        );
15797        for i in 0..n {
15798            assert!(
15799                (db_out[i] - db_ref[i]).abs() < 1e-10,
15800                "db[{i}]: got {}, want {}",
15801                db_out[i],
15802                db_ref[i]
15803            );
15804        }
15805        for i in 0..n * n {
15806            assert!(
15807                (da_out[i] - da_ref[i]).abs() < 1e-10,
15808                "dA[{i}]: got {}, want {}",
15809                da_out[i],
15810                da_ref[i]
15811            );
15812        }
15813
15814        // ── Cross-check vs finite differences on db (a few entries) ──
15815        // ∂loss/∂b[k] ≈ (loss(b + h·e_k) - loss(b - h·e_k)) / (2h).
15816        let h = 1e-6_f64;
15817        for k in 0..n {
15818            let mut bp = b_data;
15819            bp[k] += h;
15820            let mut bm = b_data;
15821            bm[k] -= h;
15822            let lp = {
15823                let mut ac = a_data;
15824                let info = crate::blas::dgesv(&mut ac, &mut bp, n, 1);
15825                assert_eq!(info, 0);
15826                bp.iter().sum::<f64>()
15827            };
15828            let lm = {
15829                let mut ac = a_data;
15830                let info = crate::blas::dgesv(&mut ac, &mut bm, n, 1);
15831                assert_eq!(info, 0);
15832                bm.iter().sum::<f64>()
15833            };
15834            let fd = (lp - lm) / (2.0 * h);
15835            assert!(
15836                (db_out[k] - fd).abs() < 1e-7,
15837                "FD mismatch on db[{k}]: AD={} FD={}",
15838                db_out[k],
15839                fd
15840            );
15841        }
15842    }
15843
15844    /// Smallest possible Op::Scan basic test: geometric growth.
15845    /// init = [1, 1, 1] f64, body = (x → x + 0.1·x) = (x → 1.1·x),
15846    /// length = 10. Final carry must equal init·(1.1)^10 ≈ 2.5937…
15847    /// to f64 precision.
15848    #[test]
15849    fn scan_geometric_growth_f64() {
15850        let n = 3usize;
15851        let length = 10u32;
15852
15853        // Body: (x) → x + 0.1·x. One Input, one output, same shape/dtype.
15854        let mut body = Graph::new("scan_body");
15855        let x = body.input("carry", Shape::new(&[n], DType::F64));
15856        let scale_bytes: Vec<u8> = (0..n).flat_map(|_| 0.1_f64.to_le_bytes()).collect();
15857        let scale = body.add_node(
15858            Op::Constant { data: scale_bytes },
15859            vec![],
15860            Shape::new(&[n], DType::F64),
15861        );
15862        let scaled = body.binary(BinaryOp::Mul, x, scale, Shape::new(&[n], DType::F64));
15863        let next = body.binary(BinaryOp::Add, x, scaled, Shape::new(&[n], DType::F64));
15864        body.set_outputs(vec![next]);
15865
15866        // Outer graph: scan(init, body, length).
15867        let mut g = Graph::new("scan_outer");
15868        let init = g.input("init", Shape::new(&[n], DType::F64));
15869        let final_carry = g.scan(init, body, length);
15870        g.set_outputs(vec![final_carry]);
15871
15872        let init_data = vec![1.0_f64; n];
15873        let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data)]);
15874        execute_thunks(&sched, arena.raw_buf_mut());
15875        let got = read_arena_f64(&arena, final_carry, n);
15876        let want: f64 = 1.1_f64.powi(length as i32);
15877        for i in 0..n {
15878            assert!(
15879                (got[i] - want).abs() < 1e-12,
15880                "got[{i}] = {} want {}",
15881                got[i],
15882                want
15883            );
15884        }
15885    }
15886
15887    /// Per-step xs scan: cumulative-sum.
15888    ///   carry_0 = init
15889    ///   carry_{t+1} = carry_t + xs\[t\]
15890    ///   final = sum_{t<length} xs\[t\] + init
15891    /// Body has 2 inputs (carry, x_t) in that NodeId order; one output
15892    /// (next carry). Validates the per-step-input plumbing end-to-end.
15893    #[test]
15894    fn scan_with_xs_cumulative_sum() {
15895        let n = 3usize;
15896        let length = 4u32;
15897
15898        let mut body = Graph::new("cumsum_body");
15899        // carry must come first in NodeId order — declare it first.
15900        let carry = body.input("carry", Shape::new(&[n], DType::F64));
15901        let x_t = body.input("x_t", Shape::new(&[n], DType::F64));
15902        let next = body.binary(BinaryOp::Add, carry, x_t, Shape::new(&[n], DType::F64));
15903        body.set_outputs(vec![next]);
15904
15905        let mut g = Graph::new("cumsum_outer");
15906        let init = g.input("init", Shape::new(&[n], DType::F64));
15907        let xs = g.input("xs", Shape::new(&[length as usize, n], DType::F64));
15908        let final_carry = g.scan_with_xs(init, &[xs], body, length);
15909        g.set_outputs(vec![final_carry]);
15910
15911        let init_data = vec![0.0_f64; n];
15912        let xs_data: Vec<f64> = (0..length as usize * n).map(|i| (i + 1) as f64).collect(); // 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12
15913        let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data), (xs, &xs_data)]);
15914        execute_thunks(&sched, arena.raw_buf_mut());
15915        let got = read_arena_f64(&arena, final_carry, n);
15916
15917        // Reference: column-wise sum of xs rows + init. With our row-major
15918        // layout, column j of xs is xs_data[j], xs_data[n+j], xs_data[2n+j], ...
15919        // (per-step row at offset t*n contributes element j to slot j).
15920        let mut want = init_data.clone();
15921        for t in 0..length as usize {
15922            for j in 0..n {
15923                want[j] += xs_data[t * n + j];
15924            }
15925        }
15926        for i in 0..n {
15927            assert!(
15928                (got[i] - want[i]).abs() < 1e-12,
15929                "got[{i}] = {} want {}",
15930                got[i],
15931                want[i]
15932            );
15933        }
15934    }
15935
15936    /// Per-step xs scan composing with DenseSolve — Circulax-shaped:
15937    ///   carry_{t+1} = solve(M, carry_t + xs\[t\])
15938    /// Models a Backward-Euler step driven by a time-varying source.
15939    #[test]
15940    fn scan_with_xs_be_with_drive() {
15941        let n = 3usize;
15942        let length = 4u32;
15943        let dt = 0.1_f64;
15944
15945        let mut m_data = vec![0.0_f64; n * n];
15946        for i in 0..n {
15947            m_data[i * n + i] = 1.0 + dt * 2.0;
15948            if i > 0 {
15949                m_data[i * n + (i - 1)] = -dt;
15950            }
15951            if i + 1 < n {
15952                m_data[i * n + (i + 1)] = -dt;
15953            }
15954        }
15955        let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
15956
15957        let mut body = Graph::new("be_drive_body");
15958        let carry = body.input("carry", Shape::new(&[n], DType::F64));
15959        let drive = body.input("drive", Shape::new(&[n], DType::F64));
15960        let m = body.add_node(
15961            Op::Constant { data: m_bytes },
15962            vec![],
15963            Shape::new(&[n, n], DType::F64),
15964        );
15965        let driven = body.binary(BinaryOp::Add, carry, drive, Shape::new(&[n], DType::F64));
15966        let next = body.dense_solve(m, driven, Shape::new(&[n], DType::F64));
15967        body.set_outputs(vec![next]);
15968
15969        let mut g = Graph::new("be_drive_outer");
15970        let init = g.input("init", Shape::new(&[n], DType::F64));
15971        let xs = g.input("xs", Shape::new(&[length as usize, n], DType::F64));
15972        let final_carry = g.scan_with_xs(init, &[xs], body, length);
15973        g.set_outputs(vec![final_carry]);
15974
15975        let init_data = vec![0.0_f64; n];
15976        // Drive the system with a unit pulse on element 0 at t=0,
15977        // zeros after.
15978        let mut xs_data = vec![0.0_f64; length as usize * n];
15979        xs_data[0] = 1.0;
15980
15981        let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data), (xs, &xs_data)]);
15982        execute_thunks(&sched, arena.raw_buf_mut());
15983        let got = read_arena_f64(&arena, final_carry, n);
15984
15985        // Reference: per-step in pure Rust.
15986        let mut x = init_data.clone();
15987        for t in 0..length as usize {
15988            for j in 0..n {
15989                x[j] += xs_data[t * n + j];
15990            }
15991            let mut a_copy = m_data.clone();
15992            crate::blas::dgesv(&mut a_copy, &mut x, n, 1);
15993        }
15994        for i in 0..n {
15995            assert!(
15996                (got[i] - x[i]).abs() < 1e-12,
15997                "got[{i}] = {} ref {}",
15998                got[i],
15999                x[i]
16000            );
16001        }
16002    }
16003
16004    /// Reverse-mode AD through Op::BatchedDenseSolve. Forward solves
16005    /// `[B, N, N] · x = [B, N]`; loss = sum of all entries. Closed
16006    /// form: dB = (Aᵀ)⁻¹·1, dA = -(Aᵀ)⁻¹·1 ⊗ x. Verified analytically
16007    /// per batch (each slice matches what the unbatched DenseSolve VJP
16008    /// would compute).
16009    #[test]
16010    fn batched_dense_solve_gradient_matches_per_batch_analytic() {
16011        use rlx_opt::autodiff::grad_with_loss;
16012        let n = 3usize;
16013        let batch = 4usize;
16014
16015        let mut g = Graph::new("bds_grad");
16016        let a = g.param("A", Shape::new(&[batch, n, n], DType::F64));
16017        let b = g.input("b", Shape::new(&[batch, n], DType::F64));
16018        let x = g.batched_dense_solve(a, b, Shape::new(&[batch, n], DType::F64));
16019        let loss = g.reduce(
16020            x,
16021            ReduceOp::Sum,
16022            vec![0, 1],
16023            false,
16024            Shape::new(&[1], DType::F64),
16025        );
16026        g.set_outputs(vec![loss]);
16027
16028        let bwd = grad_with_loss(&g, &[a, b]);
16029
16030        let find = |graph: &Graph, want: &str| -> NodeId {
16031            for node in graph.nodes() {
16032                let name = match &node.op {
16033                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
16034                    _ => None,
16035                };
16036                if name == Some(want) {
16037                    return node.id;
16038                }
16039            }
16040            panic!("no node named {want}");
16041        };
16042        let a_id = find(&bwd, "A");
16043        let b_id = find(&bwd, "b");
16044        let d_out_id = find(&bwd, "d_output");
16045
16046        let mut rng = rlx_ir::Philox4x32::new(0x57e1_u64);
16047        let mut a_data = vec![0.0_f64; batch * n * n];
16048        let mut b_data = vec![0.0_f64; batch * n];
16049        for bi in 0..batch {
16050            for i in 0..n {
16051                for j in 0..n {
16052                    a_data[bi * n * n + i * n + j] = rng.next_f32() as f64 * 0.1;
16053                }
16054                a_data[bi * n * n + i * n + i] += 1.0 + n as f64;
16055            }
16056            for i in 0..n {
16057                b_data[bi * n + i] = rng.next_f32() as f64;
16058            }
16059        }
16060        let d_seed = [1.0_f64];
16061
16062        let (sched, mut arena) = prepare_f64(
16063            &bwd,
16064            &[(a_id, &a_data), (b_id, &b_data), (d_out_id, &d_seed)],
16065        );
16066        execute_thunks(&sched, arena.raw_buf_mut());
16067        let da_out = read_arena_f64(&arena, bwd.outputs[1], batch * n * n);
16068        let db_out = read_arena_f64(&arena, bwd.outputs[2], batch * n);
16069
16070        // Reference: per-batch analytic solve. dB_i = (A_iᵀ)⁻¹ · 1,
16071        // dA_i = -dB_i ⊗ x_i.
16072        for bi in 0..batch {
16073            let a_slice: Vec<f64> = a_data[bi * n * n..(bi + 1) * n * n].to_vec();
16074            let mut b_slice: Vec<f64> = b_data[bi * n..(bi + 1) * n].to_vec();
16075            let mut a_copy = a_slice.clone();
16076            crate::blas::dgesv(&mut a_copy, &mut b_slice, n, 1);
16077            let x_ref = b_slice.clone();
16078            // dB: solve(A^T, ones)
16079            let mut at = vec![0.0_f64; n * n];
16080            for i in 0..n {
16081                for j in 0..n {
16082                    at[i * n + j] = a_slice[j * n + i];
16083                }
16084            }
16085            let mut ones = vec![1.0_f64; n];
16086            crate::blas::dgesv(&mut at, &mut ones, n, 1);
16087            let db_ref = ones;
16088            for i in 0..n {
16089                let got = db_out[bi * n + i];
16090                assert!(
16091                    (got - db_ref[i]).abs() < 1e-10,
16092                    "batch {bi}, db[{i}]: got {got} ref {}",
16093                    db_ref[i]
16094                );
16095            }
16096            // dA: -outer(db, x)
16097            for i in 0..n {
16098                for j in 0..n {
16099                    let got = da_out[bi * n * n + i * n + j];
16100                    let want = -db_ref[i] * x_ref[j];
16101                    assert!(
16102                        (got - want).abs() < 1e-10,
16103                        "batch {bi}, dA[{i},{j}]: got {got} ref {want}"
16104                    );
16105                }
16106            }
16107        }
16108    }
16109
16110    /// AD knob: gradient through `scan_checkpointed` automatically
16111    /// uses the recompute backward path. Compares dinit from a plain
16112    /// scan against the same forward written with `scan_checkpointed`,
16113    /// both run through `grad_with_loss`. They must match to f64.
16114    #[test]
16115    fn scan_checkpointed_grad_matches_plain_scan_grad() {
16116        use rlx_opt::autodiff::grad_with_loss;
16117        let n = 2usize;
16118        let length = 6u32;
16119
16120        let make_body = || {
16121            let mut body = Graph::new("ck_body");
16122            let carry = body.input("carry", Shape::new(&[n], DType::F64));
16123            let scale_bytes: Vec<u8> = (0..n).flat_map(|_| 1.05_f64.to_le_bytes()).collect();
16124            let scale = body.add_node(
16125                Op::Constant { data: scale_bytes },
16126                vec![],
16127                Shape::new(&[n], DType::F64),
16128            );
16129            let next = body.binary(BinaryOp::Mul, carry, scale, Shape::new(&[n], DType::F64));
16130            body.set_outputs(vec![next]);
16131            body
16132        };
16133
16134        // Plain scan path.
16135        let mut g_plain = Graph::new("ck_plain");
16136        let init_p = g_plain.input("init", Shape::new(&[n], DType::F64));
16137        let final_p = g_plain.scan(init_p, make_body(), length);
16138        let loss_p = g_plain.reduce(
16139            final_p,
16140            ReduceOp::Sum,
16141            vec![0],
16142            false,
16143            Shape::new(&[1], DType::F64),
16144        );
16145        g_plain.set_outputs(vec![loss_p]);
16146        let bwd_p = grad_with_loss(&g_plain, &[init_p]);
16147
16148        // Checkpointed scan path with K=2 (length=6).
16149        let mut g_ck = Graph::new("ck_ckpt");
16150        let init_c = g_ck.input("init", Shape::new(&[n], DType::F64));
16151        let final_c = g_ck.scan_checkpointed(init_c, make_body(), length, 2);
16152        let loss_c = g_ck.reduce(
16153            final_c,
16154            ReduceOp::Sum,
16155            vec![0],
16156            false,
16157            Shape::new(&[1], DType::F64),
16158        );
16159        g_ck.set_outputs(vec![loss_c]);
16160        let bwd_c = grad_with_loss(&g_ck, &[init_c]);
16161
16162        let find = |graph: &Graph, want: &str| -> NodeId {
16163            for node in graph.nodes() {
16164                let name = match &node.op {
16165                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
16166                    _ => None,
16167                };
16168                if name == Some(want) {
16169                    return node.id;
16170                }
16171            }
16172            panic!("no {want}");
16173        };
16174
16175        let init_data = vec![0.5_f64, -0.5];
16176        let d_seed = [1.0_f64];
16177
16178        let (s_p, mut a_p) = prepare_f64(
16179            &bwd_p,
16180            &[
16181                (find(&bwd_p, "init"), &init_data),
16182                (find(&bwd_p, "d_output"), &d_seed),
16183            ],
16184        );
16185        execute_thunks(&s_p, a_p.raw_buf_mut());
16186        let dinit_p = read_arena_f64(&a_p, bwd_p.outputs[1], n);
16187
16188        let (s_c, mut a_c) = prepare_f64(
16189            &bwd_c,
16190            &[
16191                (find(&bwd_c, "init"), &init_data),
16192                (find(&bwd_c, "d_output"), &d_seed),
16193            ],
16194        );
16195        execute_thunks(&s_c, a_c.raw_buf_mut());
16196        let dinit_c = read_arena_f64(&a_c, bwd_c.outputs[1], n);
16197
16198        for i in 0..n {
16199            assert!(
16200                (dinit_p[i] - dinit_c[i]).abs() < 1e-12,
16201                "dinit[{i}]: plain={} checkpointed={}",
16202                dinit_p[i],
16203                dinit_c[i]
16204            );
16205        }
16206    }
16207
16208    /// Recursive checkpointing end-to-end: build a ScanBackward
16209    /// configured with K=2 checkpoints (for length=4), and compare
16210    /// dinit against the same backward graph with full trajectory
16211    /// (K=0). Forward computes a cumulative-sum-style scan; loss = sum.
16212    /// Both paths must agree to f64 precision.
16213    #[test]
16214    fn recursive_checkpointing_matches_full_trajectory() {
16215        let n = 2usize;
16216        let length = 4u32;
16217
16218        // Body: carry + ones (deterministic, no xs)
16219        let build_body = || -> Graph {
16220            let mut body = Graph::new("rc_body");
16221            let carry = body.input("carry", Shape::new(&[n], DType::F64));
16222            let ones_bytes: Vec<u8> = (0..n).flat_map(|_| 1.0_f64.to_le_bytes()).collect();
16223            let ones = body.add_node(
16224                Op::Constant { data: ones_bytes },
16225                vec![],
16226                Shape::new(&[n], DType::F64),
16227            );
16228            let next = body.binary(BinaryOp::Add, carry, ones, Shape::new(&[n], DType::F64));
16229            body.set_outputs(vec![next]);
16230            body
16231        };
16232
16233        // body_vjp: same body + d_output, output dcarry. body_vjp is
16234        // used by ScanBackward to walk the chain rule per step.
16235        let body_vjp_for = || -> Graph {
16236            use rlx_opt::autodiff::grad;
16237            let body = build_body();
16238            // grad(body, [carry_id]) → graph with dcarry as the output.
16239            let carry_id = body
16240                .nodes()
16241                .iter()
16242                .find(|n| matches!(n.op, Op::Input { .. }))
16243                .map(|n| n.id)
16244                .unwrap();
16245            grad(&body, &[carry_id])
16246        };
16247
16248        // ── Forward (All-strategy): scan with full trajectory ──
16249        let mut g_full = Graph::new("rc_outer_full");
16250        let init_full = g_full.input("init", Shape::new(&[n], DType::F64));
16251        let traj_full_id = g_full.scan_trajectory(init_full, build_body(), length);
16252        // Hand-build a ScanBackward node that reads the full trajectory.
16253        let upstream_full = g_full.input("upstream", Shape::new(&[length as usize, n], DType::F64));
16254        let dinit_full_id = g_full.scan_backward(
16255            init_full,
16256            traj_full_id,
16257            upstream_full,
16258            &[],
16259            body_vjp_for(),
16260            length,
16261            true,
16262            Shape::new(&[n], DType::F64),
16263        );
16264        g_full.set_outputs(vec![dinit_full_id]);
16265
16266        // ── Forward (Recursive-2): scan saves only K=2 rows ──
16267        // Build the trajectory shape [K, *carry] = [2, 2].
16268        let k = 2u32;
16269        let mut g_rec = Graph::new("rc_outer_rec");
16270        let init_rec = g_rec.input("init", Shape::new(&[n], DType::F64));
16271        let traj_rec_id = g_rec.add_node(
16272            Op::Scan {
16273                body: Box::new(build_body()),
16274                length,
16275                save_trajectory: true,
16276                num_bcast: 0,
16277                num_xs: 0,
16278                num_checkpoints: k,
16279            },
16280            vec![init_rec],
16281            Shape::new(&[k as usize, n], DType::F64),
16282        );
16283        // Same upstream shape as the full version (the upstream is per
16284        // *forward step*, length rows — independent of K).
16285        let upstream_rec = g_rec.input("upstream", Shape::new(&[length as usize, n], DType::F64));
16286        let dinit_rec_id = g_rec.add_node(
16287            Op::ScanBackward {
16288                body_vjp: Box::new(body_vjp_for()),
16289                length,
16290                save_trajectory: true,
16291                num_xs: 0,
16292                num_checkpoints: k,
16293                forward_body: Some(Box::new(build_body())),
16294            },
16295            vec![init_rec, traj_rec_id, upstream_rec],
16296            Shape::new(&[n], DType::F64),
16297        );
16298        g_rec.set_outputs(vec![dinit_rec_id]);
16299
16300        // ── Run both, same inputs ──
16301        let init_data = vec![0.5_f64, -0.5];
16302        let upstream_data: Vec<f64> = (0..length as usize * n).map(|i| (i as f64) * 0.1).collect();
16303
16304        let find = |graph: &Graph, want: &str| -> NodeId {
16305            for node in graph.nodes() {
16306                if let Op::Input { name } = &node.op
16307                    && name == want
16308                {
16309                    return node.id;
16310                }
16311            }
16312            panic!("no input {want}");
16313        };
16314
16315        let (s_full, mut a_full) = prepare_f64(
16316            &g_full,
16317            &[
16318                (find(&g_full, "init"), &init_data),
16319                (find(&g_full, "upstream"), &upstream_data),
16320            ],
16321        );
16322        execute_thunks(&s_full, a_full.raw_buf_mut());
16323        let dinit_full = read_arena_f64(&a_full, g_full.outputs[0], n);
16324
16325        let (s_rec, mut a_rec) = prepare_f64(
16326            &g_rec,
16327            &[
16328                (find(&g_rec, "init"), &init_data),
16329                (find(&g_rec, "upstream"), &upstream_data),
16330            ],
16331        );
16332        execute_thunks(&s_rec, a_rec.raw_buf_mut());
16333        let dinit_rec = read_arena_f64(&a_rec, g_rec.outputs[0], n);
16334
16335        for i in 0..n {
16336            assert!(
16337                (dinit_full[i] - dinit_rec[i]).abs() < 1e-12,
16338                "i={i}: full={} rec={}",
16339                dinit_full[i],
16340                dinit_rec[i]
16341            );
16342        }
16343    }
16344
16345    /// vmap-of-grad: gradient through Scan, vmap'd over init.
16346    /// Forward (per row):
16347    ///   carry_{t+1} = carry_t + ones    (body adds a constant)
16348    ///   loss = sum(carry_length) = sum(init) + length·n
16349    /// Closed form: dloss/dinit_i = 1 for every i. vmap over init at
16350    /// batch=3 → dinit_batched is all-ones [3, n]. Cross-checks
16351    /// against per-row grad_with_loss runs. Validates the vmap rule
16352    /// for Op::ScanBackward.
16353    #[test]
16354    fn vmap_of_grad_scan_matches_per_row_runs() {
16355        use rlx_opt::autodiff::grad_with_loss;
16356        use rlx_opt::vmap::vmap;
16357        let n = 2usize;
16358        let length = 3u32;
16359        let batch = 3usize;
16360
16361        let mut body = Graph::new("scan_grad_body");
16362        let carry = body.input("carry", Shape::new(&[n], DType::F64));
16363        let ones_bytes: Vec<u8> = (0..n).flat_map(|_| 1.0_f64.to_le_bytes()).collect();
16364        let ones = body.add_node(
16365            Op::Constant { data: ones_bytes },
16366            vec![],
16367            Shape::new(&[n], DType::F64),
16368        );
16369        let next = body.binary(BinaryOp::Add, carry, ones, Shape::new(&[n], DType::F64));
16370        body.set_outputs(vec![next]);
16371
16372        let mut g = Graph::new("scan_grad_outer");
16373        let init = g.input("init", Shape::new(&[n], DType::F64));
16374        let final_x = g.scan(init, body, length);
16375        let loss = g.reduce(
16376            final_x,
16377            ReduceOp::Sum,
16378            vec![0],
16379            false,
16380            Shape::new(&[1], DType::F64),
16381        );
16382        g.set_outputs(vec![loss]);
16383
16384        let bwd = grad_with_loss(&g, &[init]);
16385        let bg = vmap(&bwd, &["init"], batch);
16386
16387        let find = |graph: &Graph, want: &str| -> NodeId {
16388            for node in graph.nodes() {
16389                let name = match &node.op {
16390                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
16391                    _ => None,
16392                };
16393                if name == Some(want) {
16394                    return node.id;
16395                }
16396            }
16397            panic!("no node named {want}");
16398        };
16399        let init_b = find(&bg, "init");
16400        let d_out_b = find(&bg, "d_output");
16401
16402        let init_data: Vec<f64> = (0..batch * n).map(|i| (i as f64) * 0.5).collect();
16403        let d_seed = [1.0_f64];
16404
16405        let (sched, mut arena) = prepare_f64(&bg, &[(init_b, &init_data), (d_out_b, &d_seed)]);
16406        execute_thunks(&sched, arena.raw_buf_mut());
16407        let dinit_b = read_arena_f64(&arena, bg.outputs[1], batch * n);
16408
16409        for i in 0..batch * n {
16410            assert!(
16411                (dinit_b[i] - 1.0).abs() < 1e-12,
16412                "dinit[{i}] = {} (expected 1.0)",
16413                dinit_b[i]
16414            );
16415        }
16416
16417        // Cross-check vs per-row grad_with_loss.
16418        for bi in 0..batch {
16419            let row = &init_data[bi * n..(bi + 1) * n];
16420            let mut g2 = Graph::new("per_row_grad");
16421            let init2 = g2.input("init", Shape::new(&[n], DType::F64));
16422            let mut body2 = Graph::new("per_row_body");
16423            let c2 = body2.input("carry", Shape::new(&[n], DType::F64));
16424            let ones2_bytes: Vec<u8> = (0..n).flat_map(|_| 1.0_f64.to_le_bytes()).collect();
16425            let ones2 = body2.add_node(
16426                Op::Constant { data: ones2_bytes },
16427                vec![],
16428                Shape::new(&[n], DType::F64),
16429            );
16430            let next2 = body2.binary(BinaryOp::Add, c2, ones2, Shape::new(&[n], DType::F64));
16431            body2.set_outputs(vec![next2]);
16432            let final2 = g2.scan(init2, body2, length);
16433            let loss2 = g2.reduce(
16434                final2,
16435                ReduceOp::Sum,
16436                vec![0],
16437                false,
16438                Shape::new(&[1], DType::F64),
16439            );
16440            g2.set_outputs(vec![loss2]);
16441            let bwd2 = grad_with_loss(&g2, &[init2]);
16442            let init2_id = find(&bwd2, "init");
16443            let d_out2_id = find(&bwd2, "d_output");
16444            let (s2, mut a2) = prepare_f64(&bwd2, &[(init2_id, row), (d_out2_id, &d_seed)]);
16445            execute_thunks(&s2, a2.raw_buf_mut());
16446            let row_dinit = read_arena_f64(&a2, bwd2.outputs[1], n);
16447            for j in 0..n {
16448                let got = dinit_b[bi * n + j];
16449                let want = row_dinit[j];
16450                assert!(
16451                    (got - want).abs() < 1e-12,
16452                    "row {bi}, j {j}: vmap'd={got} per-row={want}"
16453                );
16454            }
16455        }
16456    }
16457
16458    /// vmap of Op::Scan: batched cumulative-sum. Forward
16459    ///   carry_{t+1} = carry_t + xs\[t\]
16460    ///   final = init + sum(xs)
16461    /// vmap over both init and xs at batch=3. Each batch row should
16462    /// equal the scalar run of the same body+xs subset.
16463    #[test]
16464    fn vmap_scan_cumulative_sum_matches_scalar_runs() {
16465        use rlx_opt::vmap::vmap;
16466        let n = 2usize;
16467        let length = 4u32;
16468        let batch = 3usize;
16469
16470        // Body: (carry, x_t) → carry + x_t
16471        let mut body = Graph::new("scan_body_cumsum");
16472        let carry = body.input("carry", Shape::new(&[n], DType::F64));
16473        let x_t = body.input("x_t", Shape::new(&[n], DType::F64));
16474        let next = body.binary(BinaryOp::Add, carry, x_t, Shape::new(&[n], DType::F64));
16475        body.set_outputs(vec![next]);
16476
16477        let mut g = Graph::new("scan_outer_cumsum");
16478        let init = g.input("init", Shape::new(&[n], DType::F64));
16479        let xs = g.input("xs", Shape::new(&[length as usize, n], DType::F64));
16480        let final_carry = g.scan_with_xs(init, &[xs], body, length);
16481        g.set_outputs(vec![final_carry]);
16482
16483        // vmap over both init and xs.
16484        let bg = vmap(&g, &["init", "xs"], batch);
16485
16486        // Test data — distinct per-batch rows.
16487        let init_data: Vec<f64> = (0..batch * n).map(|i| (i + 1) as f64).collect();
16488        // xs has shape [B, length, n] after vmap (the outer's xs is
16489        // [length, n]; vmap lifts it to [B, length, n]).
16490        let xs_data: Vec<f64> = (0..batch * length as usize * n)
16491            .map(|i| 0.1 * (i as f64))
16492            .collect();
16493
16494        let find = |graph: &Graph, want: &str| -> NodeId {
16495            for node in graph.nodes() {
16496                if let Op::Input { name } = &node.op
16497                    && name == want
16498                {
16499                    return node.id;
16500                }
16501            }
16502            panic!("no input {want}");
16503        };
16504        let init_b = find(&bg, "init");
16505        let xs_b = find(&bg, "xs");
16506        let (sched, mut arena) = prepare_f64(&bg, &[(init_b, &init_data), (xs_b, &xs_data)]);
16507        execute_thunks(&sched, arena.raw_buf_mut());
16508        let batched_out = read_arena_f64(&arena, bg.outputs[0], batch * n);
16509
16510        // Reference: per-batch scalar Scan.
16511        for bi in 0..batch {
16512            let init_slice = &init_data[bi * n..(bi + 1) * n];
16513            let mut x = init_slice.to_vec();
16514            for t in 0..length as usize {
16515                for j in 0..n {
16516                    x[j] += xs_data[bi * length as usize * n + t * n + j];
16517                }
16518            }
16519
16520            for i in 0..n {
16521                let got = batched_out[bi * n + i];
16522                assert!(
16523                    (got - x[i]).abs() < 1e-12,
16524                    "row {bi}, i {i}: got {got} ref {}",
16525                    x[i]
16526                );
16527            }
16528        }
16529    }
16530
16531    /// vmap of dense solve — Circulax-shaped batched parameter sweep.
16532    /// Forward: x = solve(A, b). vmap over both A (batched [B,N,N])
16533    /// and b (batched [B,N]). Run on CPU and compare each batch row
16534    /// against an independent scalar dgesv.
16535    #[test]
16536    fn vmap_dense_solve_matches_scalar_runs() {
16537        use rlx_opt::vmap::vmap;
16538        let n = 3usize;
16539        let batch = 4usize;
16540
16541        let mut g = Graph::new("solve_forward");
16542        let a = g.input("A", Shape::new(&[n, n], DType::F64));
16543        let b = g.input("b", Shape::new(&[n], DType::F64));
16544        let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
16545        g.set_outputs(vec![x]);
16546
16547        // vmap both A and b across the batch.
16548        let bg = vmap(&g, &["A", "b"], batch);
16549
16550        // Independent A and b per batch row.
16551        let mut rng = rlx_ir::Philox4x32::new(0xb47c_u64);
16552        let mut a_data = vec![0.0_f64; batch * n * n];
16553        let mut b_data = vec![0.0_f64; batch * n];
16554        for bi in 0..batch {
16555            // Diagonally dominant A — guaranteed non-singular.
16556            for i in 0..n {
16557                for j in 0..n {
16558                    a_data[bi * n * n + i * n + j] = rng.next_f32() as f64 * 0.1;
16559                }
16560                a_data[bi * n * n + i * n + i] += 1.0 + n as f64;
16561            }
16562            for i in 0..n {
16563                b_data[bi * n + i] = rng.next_f32() as f64;
16564            }
16565        }
16566
16567        let find = |graph: &Graph, want: &str| -> NodeId {
16568            for node in graph.nodes() {
16569                if let Op::Input { name } = &node.op
16570                    && name == want
16571                {
16572                    return node.id;
16573                }
16574            }
16575            panic!("no input named {want}");
16576        };
16577        let ba = find(&bg, "A");
16578        let bb = find(&bg, "b");
16579        let (sched, mut arena) = prepare_f64(&bg, &[(ba, &a_data), (bb, &b_data)]);
16580        execute_thunks(&sched, arena.raw_buf_mut());
16581        let batched_x = read_arena_f64(&arena, bg.outputs[0], batch * n);
16582
16583        // Reference: per-batch dgesv.
16584        for bi in 0..batch {
16585            let mut a_slice: Vec<f64> = a_data[bi * n * n..(bi + 1) * n * n].to_vec();
16586            let mut b_slice: Vec<f64> = b_data[bi * n..(bi + 1) * n].to_vec();
16587            crate::blas::dgesv(&mut a_slice, &mut b_slice, n, 1);
16588            for i in 0..n {
16589                let got = batched_x[bi * n + i];
16590                let want = b_slice[i];
16591                assert!(
16592                    (got - want).abs() < 1e-12,
16593                    "row {bi}, i {i}: got {got} want {want}"
16594                );
16595            }
16596        }
16597    }
16598
16599    /// vmap end-to-end: build a graph that computes y = MatMul(x, w) + b
16600    /// and reduces to a per-element loss. vmap over x with batch=4.
16601    /// Run the batched graph and compare each output row against an
16602    /// independent scalar run of the original graph. Validates the
16603    /// structural lift + the runtime path for batched MatMul +
16604    /// batched Binary + batched Reduce.
16605    #[test]
16606    fn vmap_matmul_add_reduce_matches_scalar_runs() {
16607        use rlx_opt::vmap::vmap;
16608        let n = 3usize;
16609        let batch = 4usize;
16610
16611        // Forward graph: y = MatMul(reshape(x, [1,n]), w) + b ; loss = sum(y).
16612        let mut g = Graph::new("vmap_e2e_forward");
16613        let x = g.input("x", Shape::new(&[n], DType::F64));
16614        let w = g.input("w", Shape::new(&[n, n], DType::F64));
16615        let b = g.input("b", Shape::new(&[n], DType::F64));
16616        let x_row = g.add_node(
16617            Op::Reshape {
16618                new_shape: vec![1, n as i64],
16619            },
16620            vec![x],
16621            Shape::new(&[1, n], DType::F64),
16622        );
16623        let mm = g.matmul(x_row, w, Shape::new(&[1, n], DType::F64));
16624        let mm_flat = g.add_node(
16625            Op::Reshape {
16626                new_shape: vec![n as i64],
16627            },
16628            vec![mm],
16629            Shape::new(&[n], DType::F64),
16630        );
16631        let yv = g.binary(BinaryOp::Add, mm_flat, b, Shape::new(&[n], DType::F64));
16632        let loss = g.reduce(
16633            yv,
16634            ReduceOp::Sum,
16635            vec![0],
16636            false,
16637            Shape::new(&[1], DType::F64),
16638        );
16639        g.set_outputs(vec![loss]);
16640
16641        // Build the vmap'd version (batch over x; w and b shared).
16642        let bg = vmap(&g, &["x"], batch);
16643
16644        // Test data — distinct rows so we can verify the per-row dispatch.
16645        let mut rng = rlx_ir::Philox4x32::new(0xc1c0_u64);
16646        let n_w = n * n;
16647        let w_data: Vec<f64> = (0..n_w).map(|_| rng.next_f32() as f64).collect();
16648        let b_data: Vec<f64> = (0..n).map(|_| rng.next_f32() as f64).collect();
16649        let mut x_data_batched: Vec<f64> = Vec::with_capacity(batch * n);
16650        for _ in 0..batch * n {
16651            x_data_batched.push(rng.next_f32() as f64);
16652        }
16653
16654        // Run the batched graph.
16655        let find = |graph: &Graph, want: &str| -> NodeId {
16656            for node in graph.nodes() {
16657                if let Op::Input { name } = &node.op
16658                    && name == want
16659                {
16660                    return node.id;
16661                }
16662            }
16663            panic!("no input named {want}");
16664        };
16665        let bx = find(&bg, "x");
16666        let bw = find(&bg, "w");
16667        let bb = find(&bg, "b");
16668        let (sched, mut arena) =
16669            prepare_f64(&bg, &[(bx, &x_data_batched), (bw, &w_data), (bb, &b_data)]);
16670        execute_thunks(&sched, arena.raw_buf_mut());
16671        // Reduce::Sum on shifted axis 1 with keep_dim=false → output [B, 1]
16672        // (it preserves the leading batch axis but reduces what was [n] to [].
16673        // Since the original output was [1] f64 and the reduce was over
16674        // axis 0, after vmap the leading-axis-shifted reduce keeps the
16675        // leading 1 from the original output's [1] shape.)
16676        let batched_out = read_arena_f64(&arena, bg.outputs[0], batch);
16677
16678        // Reference: run the original (un-batched) graph once per batch row.
16679        for bi in 0..batch {
16680            let xs_slice = &x_data_batched[bi * n..(bi + 1) * n];
16681            let mut g2 = Graph::new("scalar_run");
16682            let x2 = g2.input("x", Shape::new(&[n], DType::F64));
16683            let w2 = g2.input("w", Shape::new(&[n, n], DType::F64));
16684            let b2 = g2.input("b", Shape::new(&[n], DType::F64));
16685            let xr = g2.add_node(
16686                Op::Reshape {
16687                    new_shape: vec![1, n as i64],
16688                },
16689                vec![x2],
16690                Shape::new(&[1, n], DType::F64),
16691            );
16692            let m = g2.matmul(xr, w2, Shape::new(&[1, n], DType::F64));
16693            let mf = g2.add_node(
16694                Op::Reshape {
16695                    new_shape: vec![n as i64],
16696                },
16697                vec![m],
16698                Shape::new(&[n], DType::F64),
16699            );
16700            let yv2 = g2.binary(BinaryOp::Add, mf, b2, Shape::new(&[n], DType::F64));
16701            let l2 = g2.reduce(
16702                yv2,
16703                ReduceOp::Sum,
16704                vec![0],
16705                false,
16706                Shape::new(&[1], DType::F64),
16707            );
16708            g2.set_outputs(vec![l2]);
16709            let (s2, mut a2) = prepare_f64(&g2, &[(x2, xs_slice), (w2, &w_data), (b2, &b_data)]);
16710            execute_thunks(&s2, a2.raw_buf_mut());
16711            let scalar_out = read_arena_f64(&a2, l2, 1);
16712            assert!(
16713                (batched_out[bi] - scalar_out[0]).abs() < 1e-12,
16714                "row {bi}: batched={} scalar={}",
16715                batched_out[bi],
16716                scalar_out[0]
16717            );
16718        }
16719    }
16720
16721    /// Full gradient through scan-with-xs: dinit AND dxs both checked
16722    /// against finite differences. Forward
16723    ///   carry_{t+1} = solve(M, carry_t + xs\[t\])
16724    ///   loss        = sum(carry_length)
16725    /// Verifies that grad_with_loss returns gradients w.r.t. both
16726    /// `init` and `xs` and that dxs matches per-element FD.
16727    #[test]
16728    fn scan_with_xs_dxs_matches_fd() {
16729        use rlx_opt::autodiff::grad_with_loss;
16730        let n = 3usize;
16731        let length = 3u32;
16732        let dt = 0.1_f64;
16733
16734        let mut m_data = vec![0.0_f64; n * n];
16735        for i in 0..n {
16736            m_data[i * n + i] = 1.0 + dt * 2.0;
16737            if i > 0 {
16738                m_data[i * n + (i - 1)] = -dt;
16739            }
16740            if i + 1 < n {
16741                m_data[i * n + (i + 1)] = -dt;
16742            }
16743        }
16744        let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
16745
16746        let mut body = Graph::new("be_dxs_body");
16747        let carry = body.input("carry", Shape::new(&[n], DType::F64));
16748        let drive = body.input("drive", Shape::new(&[n], DType::F64));
16749        let m = body.add_node(
16750            Op::Constant { data: m_bytes },
16751            vec![],
16752            Shape::new(&[n, n], DType::F64),
16753        );
16754        let driven = body.binary(BinaryOp::Add, carry, drive, Shape::new(&[n], DType::F64));
16755        let next = body.dense_solve(m, driven, Shape::new(&[n], DType::F64));
16756        body.set_outputs(vec![next]);
16757
16758        let mut g = Graph::new("be_dxs_outer");
16759        let init = g.input("init", Shape::new(&[n], DType::F64));
16760        let xs = g.input("xs", Shape::new(&[length as usize, n], DType::F64));
16761        let final_carry = g.scan_with_xs(init, &[xs], body, length);
16762        let loss = g.reduce(
16763            final_carry,
16764            ReduceOp::Sum,
16765            vec![0],
16766            false,
16767            Shape::new(&[1], DType::F64),
16768        );
16769        g.set_outputs(vec![loss]);
16770
16771        // wrt = [init, xs] — get both gradients back.
16772        let bwd = grad_with_loss(&g, &[init, xs]);
16773        assert_eq!(bwd.outputs.len(), 3, "[loss, dinit, dxs]");
16774
16775        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
16776            for node in graph.nodes() {
16777                let name = match &node.op {
16778                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
16779                    _ => None,
16780                };
16781                if name == Some(want) {
16782                    return node.id;
16783                }
16784            }
16785            panic!("no node named {want:?}");
16786        };
16787        let init_bwd = find_by_name(&bwd, "init");
16788        let xs_bwd = find_by_name(&bwd, "xs");
16789        let d_out_bwd = find_by_name(&bwd, "d_output");
16790
16791        let init_data = vec![0.5_f64, 0.0, -0.5];
16792        let xs_data: Vec<f64> = (0..length as usize * n)
16793            .map(|i| 0.1_f64 * ((i as f64) - 4.0))
16794            .collect();
16795        let d_seed = [1.0_f64];
16796
16797        let (sched, mut arena) = prepare_f64(
16798            &bwd,
16799            &[
16800                (init_bwd, &init_data),
16801                (xs_bwd, &xs_data),
16802                (d_out_bwd, &d_seed),
16803            ],
16804        );
16805        execute_thunks(&sched, arena.raw_buf_mut());
16806        let dinit = read_arena_f64(&arena, bwd.outputs[1], n);
16807        let dxs = read_arena_f64(&arena, bwd.outputs[2], length as usize * n);
16808
16809        let h = 1e-6;
16810        let loss_at = |x0: &[f64], xs_in: &[f64]| -> f64 {
16811            let mut acc = x0.to_vec();
16812            for t in 0..length as usize {
16813                for j in 0..n {
16814                    acc[j] += xs_in[t * n + j];
16815                }
16816                let mut a_copy = m_data.clone();
16817                crate::blas::dgesv(&mut a_copy, &mut acc, n, 1);
16818            }
16819            acc.iter().sum()
16820        };
16821
16822        // FD on dinit (sanity).
16823        for i in 0..n {
16824            let mut ip = init_data.to_vec();
16825            ip[i] += h;
16826            let mut im = init_data.to_vec();
16827            im[i] -= h;
16828            let fd = (loss_at(&ip, &xs_data) - loss_at(&im, &xs_data)) / (2.0 * h);
16829            assert!(
16830                (dinit[i] - fd).abs() < 1e-7,
16831                "FD dinit[{i}]: AD={} FD={}",
16832                dinit[i],
16833                fd
16834            );
16835        }
16836
16837        // FD on every dxs entry — full per-step gradient check.
16838        for t in 0..length as usize {
16839            for j in 0..n {
16840                let idx = t * n + j;
16841                let mut xp = xs_data.clone();
16842                xp[idx] += h;
16843                let mut xm = xs_data.clone();
16844                xm[idx] -= h;
16845                let fd = (loss_at(&init_data, &xp) - loss_at(&init_data, &xm)) / (2.0 * h);
16846                assert!(
16847                    (dxs[idx] - fd).abs() < 1e-7,
16848                    "FD dxs[t={t},j={j}]: AD={} FD={}",
16849                    dxs[idx],
16850                    fd
16851                );
16852            }
16853        }
16854    }
16855
16856    /// Gradient through a scan with per-step xs (Circulax-shaped).
16857    /// Forward:
16858    ///   carry_{t+1} = solve(M, carry_t + xs\[t\])
16859    ///   loss = sum(carry_length)
16860    /// dxs is out of MVP (asserted in the VJP rule's body_vjp `wrt`),
16861    /// but `dinit` flows correctly through the body's reverse Jacobian
16862    /// even with xs in the chain. Verify dinit against finite differences.
16863    #[test]
16864    fn scan_with_xs_gradient_dinit_matches_fd() {
16865        use rlx_opt::autodiff::grad_with_loss;
16866        let n = 3usize;
16867        let length = 3u32;
16868        let dt = 0.1_f64;
16869
16870        let mut m_data = vec![0.0_f64; n * n];
16871        for i in 0..n {
16872            m_data[i * n + i] = 1.0 + dt * 2.0;
16873            if i > 0 {
16874                m_data[i * n + (i - 1)] = -dt;
16875            }
16876            if i + 1 < n {
16877                m_data[i * n + (i + 1)] = -dt;
16878            }
16879        }
16880        let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
16881
16882        let mut body = Graph::new("be_xs_grad_body");
16883        let carry = body.input("carry", Shape::new(&[n], DType::F64));
16884        let drive = body.input("drive", Shape::new(&[n], DType::F64));
16885        let m = body.add_node(
16886            Op::Constant { data: m_bytes },
16887            vec![],
16888            Shape::new(&[n, n], DType::F64),
16889        );
16890        let driven = body.binary(BinaryOp::Add, carry, drive, Shape::new(&[n], DType::F64));
16891        let next = body.dense_solve(m, driven, Shape::new(&[n], DType::F64));
16892        body.set_outputs(vec![next]);
16893
16894        let mut g = Graph::new("be_xs_grad_outer");
16895        let init = g.input("init", Shape::new(&[n], DType::F64));
16896        let xs = g.input("xs", Shape::new(&[length as usize, n], DType::F64));
16897        let final_carry = g.scan_with_xs(init, &[xs], body, length);
16898        let loss = g.reduce(
16899            final_carry,
16900            ReduceOp::Sum,
16901            vec![0],
16902            false,
16903            Shape::new(&[1], DType::F64),
16904        );
16905        g.set_outputs(vec![loss]);
16906
16907        let bwd = grad_with_loss(&g, &[init]);
16908
16909        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
16910            for node in graph.nodes() {
16911                let name = match &node.op {
16912                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
16913                    _ => None,
16914                };
16915                if name == Some(want) {
16916                    return node.id;
16917                }
16918            }
16919            panic!("no node named {want:?}");
16920        };
16921        let init_bwd = find_by_name(&bwd, "init");
16922        let xs_bwd = find_by_name(&bwd, "xs");
16923        let d_out_bwd = find_by_name(&bwd, "d_output");
16924
16925        let init_data = vec![0.5_f64, 0.0, -0.5];
16926        // Drive: small per-step pulse, varying per element.
16927        let xs_data: Vec<f64> = (0..length as usize * n)
16928            .map(|i| 0.1_f64 * ((i as f64) - 4.0))
16929            .collect();
16930        let d_seed = [1.0_f64];
16931
16932        let (sched, mut arena) = prepare_f64(
16933            &bwd,
16934            &[
16935                (init_bwd, &init_data),
16936                (xs_bwd, &xs_data),
16937                (d_out_bwd, &d_seed),
16938            ],
16939        );
16940        execute_thunks(&sched, arena.raw_buf_mut());
16941        let dinit = read_arena_f64(&arena, bwd.outputs[1], n);
16942
16943        let h = 1e-6;
16944        let loss_at = |x0: &[f64]| -> f64 {
16945            let mut acc = x0.to_vec();
16946            for t in 0..length as usize {
16947                for j in 0..n {
16948                    acc[j] += xs_data[t * n + j];
16949                }
16950                let mut a_copy = m_data.clone();
16951                crate::blas::dgesv(&mut a_copy, &mut acc, n, 1);
16952            }
16953            acc.iter().sum()
16954        };
16955        for i in 0..n {
16956            let mut ip = init_data.to_vec();
16957            ip[i] += h;
16958            let mut im = init_data.to_vec();
16959            im[i] -= h;
16960            let fd = (loss_at(&ip) - loss_at(&im)) / (2.0 * h);
16961            assert!(
16962                (dinit[i] - fd).abs() < 1e-7,
16963                "FD dinit[{i}]: AD={} FD={}",
16964                dinit[i],
16965                fd
16966            );
16967        }
16968    }
16969
16970    /// Gradient through a geometric-growth scan: forward
16971    ///   x_{t+1} = 1.1 · x_t,    x_0 = init
16972    ///   final   = x_length     = init · 1.1^length
16973    ///   loss    = sum(final)
16974    /// closed-form ∂loss/∂init\[i\] = 1.1^length for every i.
16975    /// Validates the VJP path: AD pre-pass rewrites save_trajectory=false
16976    /// to true, autodiff emits Op::ScanBackward, executor walks t back.
16977    #[test]
16978    fn scan_gradient_geometric_matches_closed_form() {
16979        use rlx_opt::autodiff::grad_with_loss;
16980        let n = 3usize;
16981        let length = 5u32;
16982
16983        let mut body = Graph::new("scan_grad_body");
16984        let x = body.input("carry", Shape::new(&[n], DType::F64));
16985        let scale_bytes: Vec<u8> = (0..n).flat_map(|_| 1.1_f64.to_le_bytes()).collect();
16986        let scale = body.add_node(
16987            Op::Constant { data: scale_bytes },
16988            vec![],
16989            Shape::new(&[n], DType::F64),
16990        );
16991        let next = body.binary(BinaryOp::Mul, x, scale, Shape::new(&[n], DType::F64));
16992        body.set_outputs(vec![next]);
16993
16994        let mut g = Graph::new("scan_grad_outer");
16995        let init = g.input("init", Shape::new(&[n], DType::F64));
16996        let final_x = g.scan(init, body, length);
16997        let loss = g.reduce(
16998            final_x,
16999            ReduceOp::Sum,
17000            vec![0],
17001            false,
17002            Shape::new(&[1], DType::F64),
17003        );
17004        g.set_outputs(vec![loss]);
17005
17006        let bwd = grad_with_loss(&g, &[init]);
17007        assert_eq!(bwd.outputs.len(), 2);
17008
17009        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
17010            for node in graph.nodes() {
17011                let name = match &node.op {
17012                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
17013                    _ => None,
17014                };
17015                if name == Some(want) {
17016                    return node.id;
17017                }
17018            }
17019            panic!("no node named {want:?}");
17020        };
17021        let init_bwd = find_by_name(&bwd, "init");
17022        let d_out_bwd = find_by_name(&bwd, "d_output");
17023
17024        let init_data = vec![1.0_f64; n];
17025        let d_seed = [1.0_f64];
17026        let (sched, mut arena) = prepare_f64(&bwd, &[(init_bwd, &init_data), (d_out_bwd, &d_seed)]);
17027        execute_thunks(&sched, arena.raw_buf_mut());
17028        let dinit = read_arena_f64(&arena, bwd.outputs[1], n);
17029
17030        let want = 1.1_f64.powi(length as i32);
17031        for i in 0..n {
17032            assert!(
17033                (dinit[i] - want).abs() < 1e-12,
17034                "dinit[{i}] = {} want {}",
17035                dinit[i],
17036                want
17037            );
17038        }
17039
17040        // Finite-difference cross-check on init[0].
17041        let h = 1e-6;
17042        let loss_at = |x: &[f64]| -> f64 {
17043            let mut acc = x.to_vec();
17044            for _ in 0..length {
17045                for v in acc.iter_mut() {
17046                    *v *= 1.1;
17047                }
17048            }
17049            acc.iter().sum()
17050        };
17051        let mut ip = init_data.clone();
17052        ip[0] += h;
17053        let mut im = init_data.clone();
17054        im[0] -= h;
17055        let fd = (loss_at(&ip) - loss_at(&im)) / (2.0 * h);
17056        assert!(
17057            (dinit[0] - fd).abs() < 1e-7,
17058            "FD dinit[0]: AD={} FD={}",
17059            dinit[0],
17060            fd
17061        );
17062    }
17063
17064    /// Gradient through Backward Euler scan composing with DenseSolve.
17065    /// Asserts dinit matches finite-difference per coordinate.
17066    #[test]
17067    fn scan_gradient_backward_euler_matches_fd() {
17068        use rlx_opt::autodiff::grad_with_loss;
17069        let n = 4usize;
17070        let length = 3u32;
17071        let dt = 0.05_f64;
17072
17073        let mut m_data = vec![0.0_f64; n * n];
17074        for i in 0..n {
17075            m_data[i * n + i] = 1.0 + dt * 2.0;
17076            if i > 0 {
17077                m_data[i * n + (i - 1)] = -dt;
17078            }
17079            if i + 1 < n {
17080                m_data[i * n + (i + 1)] = -dt;
17081            }
17082        }
17083        let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
17084
17085        let mut body = Graph::new("be_grad_body");
17086        let x = body.input("x", Shape::new(&[n], DType::F64));
17087        let m = body.add_node(
17088            Op::Constant { data: m_bytes },
17089            vec![],
17090            Shape::new(&[n, n], DType::F64),
17091        );
17092        let next = body.dense_solve(m, x, Shape::new(&[n], DType::F64));
17093        body.set_outputs(vec![next]);
17094
17095        let mut g = Graph::new("be_grad_outer");
17096        let init = g.input("x0", Shape::new(&[n], DType::F64));
17097        let final_x = g.scan(init, body, length);
17098        let loss = g.reduce(
17099            final_x,
17100            ReduceOp::Sum,
17101            vec![0],
17102            false,
17103            Shape::new(&[1], DType::F64),
17104        );
17105        g.set_outputs(vec![loss]);
17106
17107        let bwd = grad_with_loss(&g, &[init]);
17108
17109        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
17110            for node in graph.nodes() {
17111                let name = match &node.op {
17112                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
17113                    _ => None,
17114                };
17115                if name == Some(want) {
17116                    return node.id;
17117                }
17118            }
17119            panic!("no node named {want:?}");
17120        };
17121        let init_bwd = find_by_name(&bwd, "x0");
17122        let d_out_bwd = find_by_name(&bwd, "d_output");
17123
17124        let init_data: [f64; 4] = [0.0, 1.0, 0.0, 0.0];
17125        let d_seed = [1.0_f64];
17126        let (sched, mut arena) = prepare_f64(&bwd, &[(init_bwd, &init_data), (d_out_bwd, &d_seed)]);
17127        execute_thunks(&sched, arena.raw_buf_mut());
17128        let dinit = read_arena_f64(&arena, bwd.outputs[1], n);
17129
17130        let h = 1e-6;
17131        let loss_at = |x0: &[f64]| -> f64 {
17132            let mut acc = x0.to_vec();
17133            for _ in 0..length {
17134                let mut a_copy = m_data.clone();
17135                crate::blas::dgesv(&mut a_copy, &mut acc, n, 1);
17136            }
17137            acc.iter().sum()
17138        };
17139        for i in 0..n {
17140            let mut ip = init_data.to_vec();
17141            ip[i] += h;
17142            let mut im = init_data.to_vec();
17143            im[i] -= h;
17144            let fd = (loss_at(&ip) - loss_at(&im)) / (2.0 * h);
17145            assert!(
17146                (dinit[i] - fd).abs() < 1e-7,
17147                "FD dinit[{i}]: AD={} FD={}",
17148                dinit[i],
17149                fd
17150            );
17151        }
17152    }
17153
17154    /// Trajectory-mode scan: same Backward Euler body, but record the
17155    /// carry at every step. Output is `[length, n]` — row `t` is the
17156    /// state after step `t+1`. Validates the SaveAt-style waveform
17157    /// recording end-to-end, including that the last row equals what
17158    /// the no-trajectory variant would have returned.
17159    #[test]
17160    fn scan_trajectory_backward_euler_records_waveform() {
17161        let n = 4usize;
17162        let length = 5u32;
17163        let dt = 0.05_f64;
17164
17165        let mut m_data = vec![0.0_f64; n * n];
17166        for i in 0..n {
17167            m_data[i * n + i] = 1.0 + dt * 2.0;
17168            if i > 0 {
17169                m_data[i * n + (i - 1)] = -dt;
17170            }
17171            if i + 1 < n {
17172                m_data[i * n + (i + 1)] = -dt;
17173            }
17174        }
17175        let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
17176
17177        let mut body = Graph::new("be_traj_body");
17178        let x = body.input("x", Shape::new(&[n], DType::F64));
17179        let m = body.add_node(
17180            Op::Constant { data: m_bytes },
17181            vec![],
17182            Shape::new(&[n, n], DType::F64),
17183        );
17184        let next = body.dense_solve(m, x, Shape::new(&[n], DType::F64));
17185        body.set_outputs(vec![next]);
17186
17187        let mut g = Graph::new("be_traj_outer");
17188        let init = g.input("x0", Shape::new(&[n], DType::F64));
17189        let traj = g.scan_trajectory(init, body, length);
17190        g.set_outputs(vec![traj]);
17191
17192        let init_data: [f64; 4] = [0.0, 1.0, 0.0, 0.0];
17193        let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data)]);
17194        execute_thunks(&sched, arena.raw_buf_mut());
17195        let got = read_arena_f64(&arena, traj, length as usize * n);
17196
17197        // Reference: each step's solve, recorded.
17198        let mut want = Vec::<f64>::with_capacity(length as usize * n);
17199        let mut x_ref = init_data.to_vec();
17200        for _ in 0..length {
17201            let mut a_copy = m_data.clone();
17202            crate::blas::dgesv(&mut a_copy, &mut x_ref, n, 1);
17203            want.extend_from_slice(&x_ref);
17204        }
17205        for i in 0..length as usize * n {
17206            assert!(
17207                (got[i] - want[i]).abs() < 1e-12,
17208                "got[{i}] = {} ref {}",
17209                got[i],
17210                want[i]
17211            );
17212        }
17213
17214        // Sanity: trajectory rows are monotone-decreasing in mass
17215        // (Backward Euler diffuses; boundary leak removes mass).
17216        for t in 1..length as usize {
17217            let prev: f64 = got[(t - 1) * n..t * n].iter().sum();
17218            let curr: f64 = got[t * n..(t + 1) * n].iter().sum();
17219            assert!(
17220                curr <= prev + 1e-15,
17221                "mass should decay: row {} sum {prev}, row {t} sum {curr}",
17222                t - 1
17223            );
17224        }
17225
17226        // Last row of the trajectory equals what a non-trajectory
17227        // scan returns — verify by running the same forward through
17228        // the simpler API and comparing.
17229        let mut body2 = Graph::new("be_final_body");
17230        let x2 = body2.input("x", Shape::new(&[n], DType::F64));
17231        let m_bytes2: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
17232        let m2 = body2.add_node(
17233            Op::Constant { data: m_bytes2 },
17234            vec![],
17235            Shape::new(&[n, n], DType::F64),
17236        );
17237        let next2 = body2.dense_solve(m2, x2, Shape::new(&[n], DType::F64));
17238        body2.set_outputs(vec![next2]);
17239
17240        let mut g2 = Graph::new("be_final_outer");
17241        let init2 = g2.input("x0", Shape::new(&[n], DType::F64));
17242        let final_x = g2.scan(init2, body2, length);
17243        g2.set_outputs(vec![final_x]);
17244        let (sched2, mut arena2) = prepare_f64(&g2, &[(init2, &init_data)]);
17245        execute_thunks(&sched2, arena2.raw_buf_mut());
17246        let final_got = read_arena_f64(&arena2, final_x, n);
17247
17248        let last_row = &got[(length as usize - 1) * n..length as usize * n];
17249        for i in 0..n {
17250            assert!(
17251                (last_row[i] - final_got[i]).abs() < 1e-15,
17252                "last trajectory row[{i}] = {} vs final-scan = {}",
17253                last_row[i],
17254                final_got[i]
17255            );
17256        }
17257    }
17258
17259    /// Op::Scan composing with Op::DenseSolve — the Circulax-shaped
17260    /// pattern for Backward Euler.
17261    /// Body: x_{t+1} = solve(I + dt·A, x_t).
17262    /// 1-D heat-equation Laplacian A; analytic ground truth from
17263    /// composing the same per-step solve in Rust.
17264    #[test]
17265    fn scan_backward_euler_heat_f64() {
17266        let n = 4usize;
17267        let length = 5u32;
17268        let dt = 0.05_f64;
17269
17270        // Construct M = I + dt · L  where L is the Laplacian (-1, 2, -1).
17271        // M is constant across iterations; embed it in the body via Op::Constant.
17272        let mut m_data = vec![0.0_f64; n * n];
17273        for i in 0..n {
17274            m_data[i * n + i] = 1.0 + dt * 2.0;
17275            if i > 0 {
17276                m_data[i * n + (i - 1)] = -dt;
17277            }
17278            if i + 1 < n {
17279                m_data[i * n + (i + 1)] = -dt;
17280            }
17281        }
17282        let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
17283
17284        let mut body = Graph::new("be_body");
17285        let x = body.input("x", Shape::new(&[n], DType::F64));
17286        let m = body.add_node(
17287            Op::Constant { data: m_bytes },
17288            vec![],
17289            Shape::new(&[n, n], DType::F64),
17290        );
17291        let next = body.dense_solve(m, x, Shape::new(&[n], DType::F64));
17292        body.set_outputs(vec![next]);
17293
17294        let mut g = Graph::new("be_outer");
17295        let init = g.input("x0", Shape::new(&[n], DType::F64));
17296        let final_x = g.scan(init, body, length);
17297        g.set_outputs(vec![final_x]);
17298
17299        // Initial: a sharp pulse at index 1.
17300        let init_data: [f64; 4] = [0.0, 1.0, 0.0, 0.0];
17301        let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data)]);
17302        execute_thunks(&sched, arena.raw_buf_mut());
17303        let got = read_arena_f64(&arena, final_x, n);
17304
17305        // Reference: apply the same M-solve `length` times in pure Rust.
17306        let mut ref_x = init_data.to_vec();
17307        for _ in 0..length {
17308            let mut a_copy = m_data.clone();
17309            crate::blas::dgesv(&mut a_copy, &mut ref_x, n, 1);
17310        }
17311        for i in 0..n {
17312            assert!(
17313                (got[i] - ref_x[i]).abs() < 1e-12,
17314                "got[{i}] = {} ref {}",
17315                got[i],
17316                ref_x[i]
17317            );
17318        }
17319        // Sanity: pulse should diffuse, mass should be conserved-ish
17320        // (Backward Euler is mass-conserving for this stencil with
17321        // zero-flux boundaries — but our boundaries leak, so check
17322        // that mass strictly decreases instead).
17323        let mass: f64 = got.iter().sum();
17324        assert!(mass > 0.0 && mass < 1.0, "diffusion mass: {mass}");
17325    }
17326
17327    /// Multi-RHS forward DenseSolve: X = solve(A, B) with B [N, K]
17328    /// stays correct end-to-end. Verifies the executor/lowering and
17329    /// the LAPACK column-major dance both honour `nrhs > 1`.
17330    #[test]
17331    fn dense_solve_f64_multi_rhs_forward() {
17332        let n = 3usize;
17333        let k = 2usize;
17334        let mut g = Graph::new("solve_multi_rhs");
17335        let a = g.input("A", Shape::new(&[n, n], DType::F64));
17336        let b = g.input("B", Shape::new(&[n, k], DType::F64));
17337        let x = g.dense_solve(a, b, Shape::new(&[n, k], DType::F64));
17338        g.set_outputs(vec![x]);
17339
17340        let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
17341        let b_data = [1.0, 4.0, 2.0, -1.0, 3.0, 2.0_f64];
17342        let (sched, mut arena) = prepare_f64(&g, &[(a, &a_data), (b, &b_data)]);
17343        execute_thunks(&sched, arena.raw_buf_mut());
17344        let x_got = read_arena_f64(&arena, x, n * k);
17345        for c in 0..k {
17346            for i in 0..n {
17347                let mut acc = 0.0_f64;
17348                for j in 0..n {
17349                    acc += a_data[i * n + j] * x_got[j * k + c];
17350                }
17351                let want = b_data[i * k + c];
17352                assert!(
17353                    (acc - want).abs() < 1e-10,
17354                    "col {c} row {i}: got {acc} want {want}"
17355                );
17356            }
17357        }
17358    }
17359
17360    /// Multi-RHS reverse-mode VJP: dB = (Aᵀ)⁻¹·1, dA = -dB · Xᵀ.
17361    /// Verified analytically + finite differences on dB[0,0].
17362    #[test]
17363    fn dense_solve_f64_multi_rhs_gradient() {
17364        use rlx_opt::autodiff::grad_with_loss;
17365        let n = 3usize;
17366        let k = 2usize;
17367        let mut g = Graph::new("solve_mrhs_grad");
17368        let a = g.param("A", Shape::new(&[n, n], DType::F64));
17369        let b = g.input("B", Shape::new(&[n, k], DType::F64));
17370        let x = g.dense_solve(a, b, Shape::new(&[n, k], DType::F64));
17371        let loss = g.reduce(
17372            x,
17373            ReduceOp::Sum,
17374            vec![0, 1],
17375            false,
17376            Shape::new(&[1], DType::F64),
17377        );
17378        g.set_outputs(vec![loss]);
17379
17380        let bwd = grad_with_loss(&g, &[a, b]);
17381        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
17382            for node in graph.nodes() {
17383                let name = match &node.op {
17384                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
17385                    _ => None,
17386                };
17387                if name == Some(want) {
17388                    return node.id;
17389                }
17390            }
17391            panic!("no node named {want:?}");
17392        };
17393        let a_bwd = find_by_name(&bwd, "A");
17394        let b_bwd = find_by_name(&bwd, "B");
17395        let d_out = find_by_name(&bwd, "d_output");
17396
17397        let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
17398        let b_data = [1.0, 4.0, 2.0, -1.0, 3.0, 2.0_f64];
17399        let d_seed = [1.0_f64];
17400
17401        let (sched, mut arena) = prepare_f64(
17402            &bwd,
17403            &[(a_bwd, &a_data), (b_bwd, &b_data), (d_out, &d_seed)],
17404        );
17405        execute_thunks(&sched, arena.raw_buf_mut());
17406        let da_got = read_arena_f64(&arena, bwd.outputs[1], n * n);
17407        let db_got = read_arena_f64(&arena, bwd.outputs[2], n * k);
17408
17409        // Reference.
17410        let mut x_ref = b_data;
17411        {
17412            let mut a_copy = a_data;
17413            crate::blas::dgesv(&mut a_copy, &mut x_ref, n, k);
17414        }
17415        let mut at = [0.0_f64; 9];
17416        for i in 0..n {
17417            for j in 0..n {
17418                at[i * n + j] = a_data[j * n + i];
17419            }
17420        }
17421        let mut ones_nk = vec![1.0_f64; n * k];
17422        crate::blas::dgesv(&mut at, &mut ones_nk, n, k);
17423        let db_ref = ones_nk;
17424        let mut da_ref = [0.0_f64; 9];
17425        for i in 0..n {
17426            for j in 0..n {
17427                let mut acc = 0.0_f64;
17428                for c in 0..k {
17429                    acc += db_ref[i * k + c] * x_ref[j * k + c];
17430                }
17431                da_ref[i * n + j] = -acc;
17432            }
17433        }
17434        for i in 0..n * k {
17435            assert!(
17436                (db_got[i] - db_ref[i]).abs() < 1e-10,
17437                "dB[{i}]: got {} want {}",
17438                db_got[i],
17439                db_ref[i]
17440            );
17441        }
17442        for i in 0..n * n {
17443            assert!(
17444                (da_got[i] - da_ref[i]).abs() < 1e-10,
17445                "dA[{i}]: got {} want {}",
17446                da_got[i],
17447                da_ref[i]
17448            );
17449        }
17450
17451        // FD on dB[0,0].
17452        let h = 1e-6;
17453        let mut bp = b_data;
17454        bp[0] += h;
17455        let mut bm = b_data;
17456        bm[0] -= h;
17457        let xp = {
17458            let mut a_copy = a_data;
17459            crate::blas::dgesv(&mut a_copy, &mut bp, n, k);
17460            bp
17461        };
17462        let xm = {
17463            let mut a_copy = a_data;
17464            crate::blas::dgesv(&mut a_copy, &mut bm, n, k);
17465            bm
17466        };
17467        let lp: f64 = xp.iter().sum();
17468        let lm: f64 = xm.iter().sum();
17469        let fd = (lp - lm) / (2.0 * h);
17470        assert!(
17471            (db_got[0] - fd).abs() < 1e-7,
17472            "FD dB[0,0]: AD={} FD={}",
17473            db_got[0],
17474            fd
17475        );
17476    }
17477
17478    /// Multi-RHS forward-mode JVP w.r.t. B. Closed form: t_X = solve(A, t_B).
17479    #[test]
17480    fn dense_solve_f64_multi_rhs_jvp() {
17481        use rlx_opt::autodiff_fwd::jvp;
17482        let n = 3usize;
17483        let k = 2usize;
17484        let mut g = Graph::new("solve_mrhs_jvp");
17485        let a = g.input("A", Shape::new(&[n, n], DType::F64));
17486        let b = g.input("B", Shape::new(&[n, k], DType::F64));
17487        let x = g.dense_solve(a, b, Shape::new(&[n, k], DType::F64));
17488        g.set_outputs(vec![x]);
17489
17490        let jg = jvp(&g, &[b]);
17491        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
17492            for node in graph.nodes() {
17493                let name = match &node.op {
17494                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
17495                    _ => None,
17496                };
17497                if name == Some(want) {
17498                    return node.id;
17499                }
17500            }
17501            panic!("no node named {want:?}");
17502        };
17503        let a_id = find_by_name(&jg, "A");
17504        let b_id = find_by_name(&jg, "B");
17505        let tb_id = find_by_name(&jg, "tangent_B");
17506
17507        let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
17508        let b_data = [1.0, 4.0, 2.0, -1.0, 3.0, 2.0_f64];
17509        let tb_data = [0.5, 0.0, -0.25, 1.0, 1.0, -0.5_f64];
17510
17511        let (sched, mut arena) =
17512            prepare_f64(&jg, &[(a_id, &a_data), (b_id, &b_data), (tb_id, &tb_data)]);
17513        execute_thunks(&sched, arena.raw_buf_mut());
17514        let tangent_x = read_arena_f64(&arena, jg.outputs[1], n * k);
17515
17516        let mut a_copy = a_data;
17517        let mut tb_copy = tb_data;
17518        crate::blas::dgesv(&mut a_copy, &mut tb_copy, n, k);
17519        for i in 0..n * k {
17520            assert!(
17521                (tangent_x[i] - tb_copy[i]).abs() < 1e-10,
17522                "t_X[{i}]: AD={} ref={}",
17523                tangent_x[i],
17524                tb_copy[i]
17525            );
17526        }
17527
17528        let h = 1e-6;
17529        let mut bp = b_data;
17530        let mut bm = b_data;
17531        for i in 0..n * k {
17532            bp[i] += h * tb_data[i];
17533            bm[i] -= h * tb_data[i];
17534        }
17535        let xp = {
17536            let mut a_copy = a_data;
17537            crate::blas::dgesv(&mut a_copy, &mut bp, n, k);
17538            bp
17539        };
17540        let xm = {
17541            let mut a_copy = a_data;
17542            crate::blas::dgesv(&mut a_copy, &mut bm, n, k);
17543            bm
17544        };
17545        for i in 0..n * k {
17546            let fd = (xp[i] - xm[i]) / (2.0 * h);
17547            assert!(
17548                (tangent_x[i] - fd).abs() < 1e-7,
17549                "FD t_X[{i}]: AD={} FD={}",
17550                tangent_x[i],
17551                fd
17552            );
17553        }
17554    }
17555
17556    /// Forward-mode JVP through DenseSolve, end-to-end at f64.
17557    ///
17558    /// Build forward x = solve(A, b), call `jvp(forward, [b])`,
17559    /// compile + run, and check the tangent output matches the
17560    /// closed form `t_x = solve(A, t_b)` plus a finite-difference
17561    /// cross-check `(solve(A, b + h·t_b) − solve(A, b − h·t_b)) / 2h`.
17562    #[test]
17563    fn jvp_dense_solve_b_runs_and_matches_fd() {
17564        use rlx_opt::autodiff_fwd::jvp;
17565        let n = 3usize;
17566
17567        // Forward.
17568        let mut g = Graph::new("jvp_b_e2e");
17569        let a = g.input("A", Shape::new(&[n, n], DType::F64));
17570        let b = g.input("b", Shape::new(&[n], DType::F64));
17571        let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
17572        g.set_outputs(vec![x]);
17573
17574        // JVP graph perturbing b only.
17575        let jg = jvp(&g, &[b]);
17576        // The JVP graph holds a fresh "tangent_b" Input on top of A and b.
17577        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
17578            for node in graph.nodes() {
17579                let name = match &node.op {
17580                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
17581                    _ => None,
17582                };
17583                if name == Some(want) {
17584                    return node.id;
17585                }
17586            }
17587            panic!("no node named {want:?}");
17588        };
17589        let a_id = find_by_name(&jg, "A");
17590        let b_id = find_by_name(&jg, "b");
17591        let tb_id = find_by_name(&jg, "tangent_b");
17592
17593        let a_data: [f64; 9] = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0];
17594        let b_data: [f64; 3] = [1.0, 2.0, 3.0];
17595        // Pick an arbitrary perturbation direction.
17596        let tb_data: [f64; 3] = [0.5, -0.25, 1.0];
17597
17598        let (sched, mut arena) =
17599            prepare_f64(&jg, &[(a_id, &a_data), (b_id, &b_data), (tb_id, &tb_data)]);
17600        execute_thunks(&sched, arena.raw_buf_mut());
17601
17602        // Outputs: [primal_x, tangent_x].
17603        let primal_x = read_arena_f64(&arena, jg.outputs[0], n);
17604        let tangent_x = read_arena_f64(&arena, jg.outputs[1], n);
17605
17606        // Closed form: t_x = solve(A, t_b).
17607        let t_x_ref = {
17608            let mut a = a_data;
17609            let mut tb = tb_data;
17610            let info = crate::blas::dgesv(&mut a, &mut tb, n, 1);
17611            assert_eq!(info, 0);
17612            tb
17613        };
17614        for i in 0..n {
17615            assert!(
17616                (tangent_x[i] - t_x_ref[i]).abs() < 1e-10,
17617                "t_x[{i}]: got {} want {}",
17618                tangent_x[i],
17619                t_x_ref[i]
17620            );
17621        }
17622
17623        // FD: x(b + h·tb) − x(b − h·tb)) / 2h
17624        let h = 1e-6;
17625        let mut bp = b_data;
17626        let mut bm = b_data;
17627        for i in 0..n {
17628            bp[i] += h * tb_data[i];
17629            bm[i] -= h * tb_data[i];
17630        }
17631        let xp = {
17632            let mut a = a_data;
17633            let info = crate::blas::dgesv(&mut a, &mut bp, n, 1);
17634            assert_eq!(info, 0);
17635            bp
17636        };
17637        let xm = {
17638            let mut a = a_data;
17639            let info = crate::blas::dgesv(&mut a, &mut bm, n, 1);
17640            assert_eq!(info, 0);
17641            bm
17642        };
17643        let fd: Vec<f64> = (0..n).map(|i| (xp[i] - xm[i]) / (2.0 * h)).collect();
17644        for i in 0..n {
17645            assert!(
17646                (tangent_x[i] - fd[i]).abs() < 1e-7,
17647                "FD mismatch t_x[{i}]: AD={} FD={}",
17648                tangent_x[i],
17649                fd[i]
17650            );
17651        }
17652        // Sanity: primal output is the actual solve.
17653        let primal_ref = {
17654            let mut a = a_data;
17655            let mut b = b_data;
17656            crate::blas::dgesv(&mut a, &mut b, n, 1);
17657            b
17658        };
17659        for i in 0..n {
17660            assert!((primal_x[i] - primal_ref[i]).abs() < 1e-10);
17661        }
17662    }
17663
17664    /// Forward-mode JVP through DenseSolve perturbing A. The tangent
17665    /// path includes the −t_A·x correction term.
17666    /// `t_x = −solve(A, t_A · x)` should match a finite-difference
17667    /// directional derivative of `solve(A, b)` w.r.t. A in the
17668    /// `t_A` direction.
17669    #[test]
17670    fn jvp_dense_solve_a_runs_and_matches_fd() {
17671        use rlx_opt::autodiff_fwd::jvp;
17672        let n = 3usize;
17673
17674        let mut g = Graph::new("jvp_a_e2e");
17675        let a = g.input("A", Shape::new(&[n, n], DType::F64));
17676        let b = g.input("b", Shape::new(&[n], DType::F64));
17677        let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
17678        g.set_outputs(vec![x]);
17679
17680        let jg = jvp(&g, &[a]);
17681        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
17682            for node in graph.nodes() {
17683                let name = match &node.op {
17684                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
17685                    _ => None,
17686                };
17687                if name == Some(want) {
17688                    return node.id;
17689                }
17690            }
17691            panic!("no node named {want:?}");
17692        };
17693        let a_id = find_by_name(&jg, "A");
17694        let b_id = find_by_name(&jg, "b");
17695        let ta_id = find_by_name(&jg, "tangent_A");
17696
17697        let a_data: [f64; 9] = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0];
17698        let b_data: [f64; 3] = [1.0, 2.0, 3.0];
17699        // Asymmetric perturbation direction for A.
17700        let ta_data: [f64; 9] = [0.10, -0.05, 0.02, 0.03, 0.20, -0.04, -0.01, 0.07, 0.15];
17701
17702        let (sched, mut arena) =
17703            prepare_f64(&jg, &[(a_id, &a_data), (b_id, &b_data), (ta_id, &ta_data)]);
17704        execute_thunks(&sched, arena.raw_buf_mut());
17705
17706        let tangent_x = read_arena_f64(&arena, jg.outputs[1], n);
17707
17708        // Closed form: x = solve(A, b); t_x = −solve(A, t_A · x).
17709        let x_ref = {
17710            let mut a = a_data;
17711            let mut b = b_data;
17712            crate::blas::dgesv(&mut a, &mut b, n, 1);
17713            b
17714        };
17715        let mut prod = [0.0_f64; 3];
17716        for i in 0..n {
17717            for j in 0..n {
17718                prod[i] += ta_data[i * n + j] * x_ref[j];
17719            }
17720        }
17721        let t_x_ref = {
17722            let mut a = a_data;
17723            let mut p = prod;
17724            crate::blas::dgesv(&mut a, &mut p, n, 1);
17725            [-p[0], -p[1], -p[2]]
17726        };
17727        for i in 0..n {
17728            assert!(
17729                (tangent_x[i] - t_x_ref[i]).abs() < 1e-10,
17730                "closed-form t_x[{i}]: AD={} ref={}",
17731                tangent_x[i],
17732                t_x_ref[i]
17733            );
17734        }
17735
17736        // FD: solve(A + h·t_A, b) and solve(A − h·t_A, b).
17737        let h = 1e-6;
17738        let mut ap = a_data;
17739        let mut am = a_data;
17740        for i in 0..n * n {
17741            ap[i] += h * ta_data[i];
17742            am[i] -= h * ta_data[i];
17743        }
17744        let xp = {
17745            let mut a = ap;
17746            let mut b = b_data;
17747            crate::blas::dgesv(&mut a, &mut b, n, 1);
17748            b
17749        };
17750        let xm = {
17751            let mut a = am;
17752            let mut b = b_data;
17753            crate::blas::dgesv(&mut a, &mut b, n, 1);
17754            b
17755        };
17756        for i in 0..n {
17757            let fd = (xp[i] - xm[i]) / (2.0 * h);
17758            assert!(
17759                (tangent_x[i] - fd).abs() < 1e-7,
17760                "FD t_x[{i}]: AD={} FD={}",
17761                tangent_x[i],
17762                fd
17763            );
17764        }
17765    }
17766
17767    /// Real INT8 conv2d parity. Same setup as QMatMul: pre-quantize
17768    /// f32 inputs to i8, run `Op::QConv2d`, compare against an
17769    /// in-test reference loop that does the same i32 accumulation
17770    /// and requantize math. Symmetric quant (zp=0) to keep the math
17771    /// head-to-head.
17772    #[test]
17773    fn q_conv2d_matches_reference() {
17774        use rlx_ir::Philox4x32;
17775        // Small NCHW shape — enough to exercise stride/padding edges.
17776        let n = 1usize;
17777        let c_in = 2usize;
17778        let h = 5usize;
17779        let w_in = 5usize;
17780        let c_out = 3usize;
17781        let kh = 3usize;
17782        let kw = 3usize;
17783        let ph = 1usize;
17784        let pw = 1usize;
17785        let sh = 1usize;
17786        let sw = 1usize;
17787        let h_out = (h + 2 * ph - kh) / sh + 1;
17788        let w_out = (w_in + 2 * pw - kw) / sw + 1;
17789
17790        let x_scale = 0.04f32;
17791        let w_scale = 0.02f32;
17792        let out_scale = 0.5f32;
17793        let mult = x_scale * w_scale / out_scale;
17794
17795        let mut rng = Philox4x32::new(2099);
17796        let mut xf = vec![0f32; n * c_in * h * w_in];
17797        rng.fill_normal(&mut xf);
17798        let mut wf = vec![0f32; c_out * c_in * kh * kw];
17799        rng.fill_normal(&mut wf);
17800        let xq: Vec<i8> = xf
17801            .iter()
17802            .map(|&v| ((v / x_scale).round() as i32).clamp(-128, 127) as i8)
17803            .collect();
17804        let wq: Vec<i8> = wf
17805            .iter()
17806            .map(|&v| ((v / w_scale).round() as i32).clamp(-128, 127) as i8)
17807            .collect();
17808        let bias: Vec<i32> = vec![0i32; c_out];
17809
17810        let mut g = Graph::new("qconv");
17811        let xn = g.input("x", Shape::new(&[n, c_in, h, w_in], DType::I8));
17812        let wn = g.input("w", Shape::new(&[c_out, c_in, kh, kw], DType::I8));
17813        let bn = g.input("b", Shape::new(&[c_out], DType::I32));
17814        let out = g.q_conv2d(
17815            xn,
17816            wn,
17817            bn,
17818            vec![kh, kw],
17819            vec![sh, sw],
17820            vec![ph, pw],
17821            vec![1, 1],
17822            1,
17823            0,
17824            0,
17825            0,
17826            mult,
17827            Shape::new(&[n, c_out, h_out, w_out], DType::I8),
17828        );
17829        g.set_outputs(vec![out]);
17830
17831        let plan = rlx_opt::memory::plan_memory(&g);
17832        let mut arena = crate::arena::Arena::from_plan(plan);
17833        let sched = compile_thunks(&g, &arena);
17834        // Capture offsets before borrowing the buf mutably (avoids
17835        // overlap between &mut and the &arena.byte_offset reads).
17836        let xn_off = arena.byte_offset(xn);
17837        let wn_off = arena.byte_offset(wn);
17838        let bn_off = arena.byte_offset(bn);
17839        let out_off = arena.byte_offset(out);
17840        let buf = arena.raw_buf_mut();
17841        unsafe {
17842            let p = buf.as_mut_ptr().add(xn_off) as *mut i8;
17843            for (i, &v) in xq.iter().enumerate() {
17844                *p.add(i) = v;
17845            }
17846            let p = buf.as_mut_ptr().add(wn_off) as *mut i8;
17847            for (i, &v) in wq.iter().enumerate() {
17848                *p.add(i) = v;
17849            }
17850            let p = buf.as_mut_ptr().add(bn_off) as *mut i32;
17851            for (i, &v) in bias.iter().enumerate() {
17852                *p.add(i) = v;
17853            }
17854        }
17855        execute_thunks(&sched, arena.raw_buf_mut());
17856        let out_q: Vec<i8> = unsafe {
17857            let p = arena.raw_buf().as_ptr().add(out_off) as *const i8;
17858            (0..n * c_out * h_out * w_out).map(|i| *p.add(i)).collect()
17859        };
17860
17861        // Reference: scalar loop in NCHW with the same requantize.
17862        let mut out_ref = vec![0i8; n * c_out * h_out * w_out];
17863        for ni in 0..n {
17864            for co in 0..c_out {
17865                for ho in 0..h_out {
17866                    for wo in 0..w_out {
17867                        let mut acc: i32 = 0;
17868                        for ci in 0..c_in {
17869                            for ki in 0..kh {
17870                                for kj in 0..kw {
17871                                    let hi = ho * sh + ki;
17872                                    let wi = wo * sw + kj;
17873                                    if hi < ph || wi < pw {
17874                                        continue;
17875                                    }
17876                                    let hi = hi - ph;
17877                                    let wi = wi - pw;
17878                                    if hi >= h || wi >= w_in {
17879                                        continue;
17880                                    }
17881                                    let xv =
17882                                        xq[((ni * c_in) + ci) * h * w_in + hi * w_in + wi] as i32;
17883                                    let wv = wq[((co * c_in) + ci) * kh * kw + ki * kw + kj] as i32;
17884                                    acc += xv * wv;
17885                                }
17886                            }
17887                        }
17888                        let r = (acc as f32 * mult).round() as i32;
17889                        let r = r.clamp(-128, 127) as i8;
17890                        out_ref[((ni * c_out) + co) * h_out * w_out + ho * w_out + wo] = r;
17891                    }
17892                }
17893            }
17894        }
17895
17896        for (i, (a, r)) in out_q.iter().zip(&out_ref).enumerate() {
17897            assert_eq!(a, r, "q_conv2d[{i}]: kernel {a} vs reference {r}");
17898        }
17899    }
17900
17901    /// Real INT8 matmul parity: compare `Op::QMatMul` against the
17902    /// fake-quant reference `Dequantize → MatMul → Quantize` that
17903    /// would produce the same output if we round-tripped through
17904    /// f32. Both should agree element-for-element (or within ±1 i8
17905    /// step, since rounding in the requantize uses different code
17906    /// paths). Symmetric quantization (zp=0) for both paths to keep
17907    /// the math head-to-head.
17908    #[test]
17909    fn q_matmul_matches_fake_quant_reference() {
17910        use rlx_ir::Philox4x32;
17911        let m = 3usize;
17912        let k = 8usize;
17913        let n = 5usize;
17914        let mut rng = Philox4x32::new(2031);
17915
17916        // Pick scales and quantize random f32 inputs to i8.
17917        let x_scale = 0.05f32;
17918        let w_scale = 0.03f32;
17919        let out_scale = 0.4f32;
17920        let mult = x_scale * w_scale / out_scale;
17921        let mut xf = vec![0f32; m * k];
17922        rng.fill_normal(&mut xf);
17923        let mut wf = vec![0f32; k * n];
17924        rng.fill_normal(&mut wf);
17925        let xq: Vec<i8> = xf
17926            .iter()
17927            .map(|&v| ((v / x_scale).round() as i32).clamp(-128, 127) as i8)
17928            .collect();
17929        let wq: Vec<i8> = wf
17930            .iter()
17931            .map(|&v| ((v / w_scale).round() as i32).clamp(-128, 127) as i8)
17932            .collect();
17933        let bias: Vec<i32> = vec![0i32; n];
17934
17935        // ── Direct INT8 path ──
17936        let _f = DType::F32;
17937        let mut g_q = Graph::new("qmm_direct");
17938        let xn = g_q.input("x", Shape::new(&[m, k], DType::I8));
17939        let wn = g_q.input("w", Shape::new(&[k, n], DType::I8));
17940        let bn = g_q.input("b", Shape::new(&[n], DType::I32));
17941        let out = g_q.q_matmul(xn, wn, bn, 0, 0, 0, mult, Shape::new(&[m, n], DType::I8));
17942        g_q.set_outputs(vec![out]);
17943        let plan = rlx_opt::memory::plan_memory(&g_q);
17944        let mut arena = crate::arena::Arena::from_plan(plan);
17945        let sched = compile_thunks(&g_q, &arena);
17946
17947        // Fill inputs.
17948        let xn_off = arena.byte_offset(xn);
17949        let wn_off = arena.byte_offset(wn);
17950        let bn_off = arena.byte_offset(bn);
17951        let out_off = arena.byte_offset(out);
17952        let buf = arena.raw_buf_mut();
17953        unsafe {
17954            let p = buf.as_mut_ptr().add(xn_off) as *mut i8;
17955            for (i, &v) in xq.iter().enumerate() {
17956                *p.add(i) = v;
17957            }
17958            let p = buf.as_mut_ptr().add(wn_off) as *mut i8;
17959            for (i, &v) in wq.iter().enumerate() {
17960                *p.add(i) = v;
17961            }
17962            let p = buf.as_mut_ptr().add(bn_off) as *mut i32;
17963            for (i, &v) in bias.iter().enumerate() {
17964                *p.add(i) = v;
17965            }
17966        }
17967        execute_thunks(&sched, arena.raw_buf_mut());
17968        let out_q: Vec<i8> = unsafe {
17969            let p = arena.raw_buf().as_ptr().add(out_off) as *const i8;
17970            (0..m * n).map(|i| *p.add(i)).collect()
17971        };
17972
17973        // ── Fake-quant reference: scalar emulation in plain Rust ──
17974        // Same arithmetic the kernel does, but in a verifier loop:
17975        //   acc = Σ (x[m,k]) · (w[k,n]),  // zps are 0
17976        //   out[m,n] = saturate_i8(round(acc · mult) + 0)
17977        let mut out_ref = vec![0i8; m * n];
17978        for mi in 0..m {
17979            for ni in 0..n {
17980                let mut acc: i32 = 0;
17981                for ki in 0..k {
17982                    acc += (xq[mi * k + ki] as i32) * (wq[ki * n + ni] as i32);
17983                }
17984                let r = (acc as f32 * mult).round() as i32;
17985                out_ref[mi * n + ni] = r.clamp(-128, 127) as i8;
17986            }
17987        }
17988
17989        for (i, (a, r)) in out_q.iter().zip(&out_ref).enumerate() {
17990            assert_eq!(a, r, "q_matmul[{i}]: kernel {a} vs reference {r}");
17991        }
17992    }
17993
17994    /// Quantize/Dequantize round-trip — quantize an f32 tensor, then
17995    /// dequantize back, and confirm the result tracks the input
17996    /// within the per-element scale (the inevitable rounding error).
17997    /// Also pins the kernel's saturation behavior at the i8 limits.
17998    #[test]
17999    fn quantize_dequantize_round_trip() {
18000        use rlx_ir::Philox4x32;
18001        let len = 64;
18002        let mut rng = Philox4x32::new(2027);
18003        let mut x = vec![0f32; len];
18004        rng.fill_normal(&mut x);
18005        // Stretch a couple values past the +/- saturation cliff so
18006        // the saturate_i8 path is exercised.
18007        x[0] = 999.0;
18008        x[1] = -999.0;
18009
18010        let scale = 0.05f32;
18011        let zp = 3i32;
18012
18013        let f = DType::F32;
18014        let mut g = Graph::new("qdq");
18015        let xn = g.input("x", Shape::new(&[len], f));
18016        let q = g.quantize(xn, scale, zp);
18017        let dq = g.dequantize(q, scale, zp);
18018        g.set_outputs(vec![dq]);
18019
18020        let plan = rlx_opt::memory::plan_memory(&g);
18021        let mut arena = crate::arena::Arena::from_plan(plan);
18022        let sched = compile_thunks(&g, &arena);
18023        let xn_off = arena.byte_offset(xn);
18024        let dq_off = arena.byte_offset(dq);
18025        let buf = arena.raw_buf_mut();
18026        unsafe {
18027            let p = buf.as_mut_ptr().add(xn_off) as *mut f32;
18028            for (i, &v) in x.iter().enumerate() {
18029                *p.add(i) = v;
18030            }
18031        }
18032        execute_thunks(&sched, arena.raw_buf_mut());
18033        let out: Vec<f32> = unsafe {
18034            let p = arena.raw_buf().as_ptr().add(dq_off) as *const f32;
18035            (0..len).map(|i| *p.add(i)).collect()
18036        };
18037
18038        // Saturated values at i=0,1 should clamp to ±127's dequant
18039        // range (= (±127 - zp) · scale).
18040        let sat_pos = (127 - zp) as f32 * scale;
18041        let sat_neg = (-128 - zp) as f32 * scale;
18042        assert!((out[0] - sat_pos).abs() < 1e-6, "+sat: {}", out[0]);
18043        assert!((out[1] - sat_neg).abs() < 1e-6, "-sat: {}", out[1]);
18044
18045        // Everything else should round-trip within `scale` (one quant
18046        // step = the worst-case rounding error).
18047        for i in 2..len {
18048            assert!(
18049                (out[i] - x[i]).abs() <= scale + 1e-5,
18050                "qdq[{i}]: {} → {}, scale={scale}",
18051                x[i],
18052                out[i]
18053            );
18054        }
18055    }
18056
18057    /// Per-channel quantize / dequantize: independent scale and zp
18058    /// per slice along an axis. Verifies (a) each channel uses its
18059    /// own scale (not a shared one), (b) saturation still respects
18060    /// the i8 range, (c) channel data layout decomposition is
18061    /// correct (no cross-channel leakage).
18062    #[test]
18063    fn quantize_per_channel_round_trip() {
18064        let c = 4usize;
18065        let inner = 5usize;
18066        // Different magnitudes per channel — proves the per-channel
18067        // scale is actually being read for each row.
18068        let mags = [0.01f32, 0.5, 5.0, 50.0];
18069        let mut x = vec![0f32; c * inner];
18070        for ci in 0..c {
18071            for ii in 0..inner {
18072                // Sweep through values that span [-max_abs, +max_abs]
18073                // for each channel, plus one value past the cliff to
18074                // trigger saturation.
18075                x[ci * inner + ii] = match ii {
18076                    0 => -mags[ci],
18077                    1 => 0.0,
18078                    2 => mags[ci],
18079                    3 => mags[ci] * 1000.0,  // saturates +
18080                    _ => -mags[ci] * 1000.0, // saturates -
18081                };
18082            }
18083        }
18084        let scales: Vec<f32> = mags.iter().map(|&m| m / 127.0).collect();
18085        let zps: Vec<i32> = vec![0, 0, 0, 0];
18086
18087        let f = DType::F32;
18088        let mut g = Graph::new("qdq_pc");
18089        let xn = g.input("x", Shape::new(&[c, inner], f));
18090        let q = g.quantize_per_channel(xn, 0, scales.clone(), zps.clone());
18091        let dq = g.dequantize_per_channel(q, 0, scales.clone(), zps);
18092        g.set_outputs(vec![dq]);
18093
18094        let plan = rlx_opt::memory::plan_memory(&g);
18095        let mut arena = crate::arena::Arena::from_plan(plan);
18096        let sched = compile_thunks(&g, &arena);
18097        let xn_off = arena.byte_offset(xn);
18098        let dq_off = arena.byte_offset(dq);
18099        let buf = arena.raw_buf_mut();
18100        unsafe {
18101            let p = buf.as_mut_ptr().add(xn_off) as *mut f32;
18102            for (i, &v) in x.iter().enumerate() {
18103                *p.add(i) = v;
18104            }
18105        }
18106        execute_thunks(&sched, arena.raw_buf_mut());
18107        let out: Vec<f32> = unsafe {
18108            let p = arena.raw_buf().as_ptr().add(dq_off) as *const f32;
18109            (0..c * inner).map(|i| *p.add(i)).collect()
18110        };
18111
18112        for ci in 0..c {
18113            // Within-range entries (positions 0, 1, 2) must round-trip
18114            // within one quant step of *that channel's* scale.
18115            for ii in 0..3 {
18116                let idx = ci * inner + ii;
18117                assert!(
18118                    (out[idx] - x[idx]).abs() <= scales[ci] + 1e-5,
18119                    "ch {ci} idx {ii}: {} vs {}",
18120                    x[idx],
18121                    out[idx]
18122                );
18123            }
18124            // Saturated positions clamp to ±127 · scale[ci].
18125            let sat_pos = 127.0 * scales[ci];
18126            let sat_neg = -128.0 * scales[ci];
18127            assert!(
18128                (out[ci * inner + 3] - sat_pos).abs() < 1e-5,
18129                "ch {ci} +sat: {}",
18130                out[ci * inner + 3]
18131            );
18132            assert!(
18133                (out[ci * inner + 4] - sat_neg).abs() < 1e-5,
18134                "ch {ci} -sat: {}",
18135                out[ci * inner + 4]
18136            );
18137        }
18138    }
18139
18140    /// `Op::ActivationBackward` parity for every supported kind.
18141    /// Builds a single-op graph `dx = activation_backward(x, dy)` and
18142    /// compares each `dx[i]` to the central-difference `(act(x+ε) -
18143    /// act(x-ε)) / (2ε) · dy\[i\]`. Sweeps the closed-form covered by
18144    /// the kernel.
18145    #[test]
18146    fn activation_backward_matches_numerical_per_kind() {
18147        use rlx_ir::Philox4x32;
18148        use rlx_ir::op::Activation;
18149        let mut rng = Philox4x32::new(91);
18150        let len = 32;
18151        // x sampled away from kink/branch points: shifted positive
18152        // (exp/sqrt/log domain) for the unary-positive activations;
18153        // wide range otherwise. Two parallel tests would be cleaner
18154        // but this is concise enough.
18155        let mut x_pos = vec![0f32; len];
18156        rng.fill_normal(&mut x_pos);
18157        for v in x_pos.iter_mut() {
18158            *v = v.abs() + 0.5;
18159        }
18160        let mut x_any = vec![0f32; len];
18161        rng.fill_normal(&mut x_any);
18162        let mut dy = vec![0f32; len];
18163        rng.fill_normal(&mut dy);
18164
18165        for &(kind, x_data, eps, tol) in &[
18166            (Activation::Sigmoid, &x_any[..], 1e-3, 5e-3),
18167            (Activation::Tanh, &x_any[..], 1e-3, 5e-3),
18168            (Activation::Silu, &x_any[..], 1e-3, 5e-3),
18169            (Activation::Gelu, &x_any[..], 1e-3, 5e-3),
18170            (Activation::GeluApprox, &x_any[..], 1e-3, 5e-3),
18171            (Activation::Exp, &x_any[..], 1e-4, 5e-3),
18172            (Activation::Log, &x_pos[..], 1e-4, 5e-3),
18173            (Activation::Sqrt, &x_pos[..], 1e-4, 5e-3),
18174            (Activation::Rsqrt, &x_pos[..], 1e-4, 5e-3),
18175            (Activation::Neg, &x_any[..], 1e-3, 5e-4),
18176        ] {
18177            let f = DType::F32;
18178            let mut g = Graph::new("act_bw");
18179            let xn = g.input("x", Shape::new(&[len], f));
18180            let dyn_ = g.input("dy", Shape::new(&[len], f));
18181            let dx = g.activation_backward(kind, xn, dyn_);
18182            g.set_outputs(vec![dx]);
18183
18184            let plan = rlx_opt::memory::plan_memory(&g);
18185            let mut arena = crate::arena::Arena::from_plan(plan);
18186            let sched = compile_thunks(&g, &arena);
18187
18188            let xn_off = arena.byte_offset(xn);
18189            let dyn_off = arena.byte_offset(dyn_);
18190            let dx_off = arena.byte_offset(dx);
18191            let buf = arena.raw_buf_mut();
18192            unsafe {
18193                let p = buf.as_mut_ptr().add(xn_off) as *mut f32;
18194                for (i, &v) in x_data.iter().enumerate() {
18195                    *p.add(i) = v;
18196                }
18197                let p = buf.as_mut_ptr().add(dyn_off) as *mut f32;
18198                for (i, &v) in dy.iter().enumerate() {
18199                    *p.add(i) = v;
18200                }
18201            }
18202            execute_thunks(&sched, arena.raw_buf_mut());
18203            let analytical: Vec<f32> = unsafe {
18204                let p = arena.raw_buf().as_ptr().add(dx_off) as *const f32;
18205                (0..len).map(|i| *p.add(i)).collect()
18206            };
18207
18208            // Apply the forward activation manually; finite-difference
18209            // each element.
18210            let act_apply = |kind: Activation, x: f32| -> f32 {
18211                match kind {
18212                    Activation::Sigmoid => 1.0 / (1.0 + (-x).exp()),
18213                    Activation::Tanh => x.tanh(),
18214                    Activation::Silu => x / (1.0 + (-x).exp()),
18215                    Activation::Gelu => {
18216                        // Match the kernel's exact erf form.
18217                        const INV_SQRT2: f32 = 0.707_106_77;
18218                        0.5 * x * (1.0 + erf_f32(x * INV_SQRT2))
18219                    }
18220                    Activation::GeluApprox => {
18221                        const C: f32 = 0.797_884_6;
18222                        const A: f32 = 0.044_715;
18223                        let inner = C * (x + A * x * x * x);
18224                        0.5 * x * (1.0 + inner.tanh())
18225                    }
18226                    Activation::Exp => x.exp(),
18227                    Activation::Log => x.ln(),
18228                    Activation::Sqrt => x.sqrt(),
18229                    Activation::Rsqrt => 1.0 / x.sqrt(),
18230                    Activation::Neg => -x,
18231                    Activation::Relu => x.max(0.0),
18232                    Activation::Abs => x.abs(),
18233                    Activation::Round => x.round(),
18234                    Activation::Sin => x.sin(),
18235                    Activation::Cos => x.cos(),
18236                    Activation::Tan => x.tan(),
18237                    Activation::Atan => x.atan(),
18238                }
18239            };
18240            for i in 0..len {
18241                let xv = x_data[i];
18242                let plus = act_apply(kind, xv + eps);
18243                let minus = act_apply(kind, xv - eps);
18244                let num = (plus - minus) / (2.0 * eps) * dy[i];
18245                assert!(
18246                    (analytical[i] - num).abs() < tol,
18247                    "{kind:?}[{i}]: analytical {} vs numerical {num}",
18248                    analytical[i]
18249                );
18250            }
18251        }
18252    }
18253
18254    /// Batched 3-D MatMul VJP — the transformer-attention shape
18255    /// `[B, M, K] @ [B, K, N] = [B, M, N]`. Both gradients flow through
18256    /// `Op::Transpose` with a perm that swaps the last two dims.
18257    #[test]
18258    fn matmul_3d_gradient_matches_numerical() {
18259        use rlx_ir::Philox4x32;
18260        let batch = 2usize;
18261        let m = 3usize;
18262        let k = 4usize;
18263        let n = 5usize;
18264        let mut rng = Philox4x32::new(101);
18265        let mut a_data = vec![0f32; batch * m * k];
18266        rng.fill_normal(&mut a_data);
18267        let mut b_data = vec![0f32; batch * k * n];
18268        rng.fill_normal(&mut b_data);
18269
18270        let f = DType::F32;
18271        let mut fwd = Graph::new("matmul_3d");
18272        let an = fwd.input("a", Shape::new(&[batch, m, k], f));
18273        let bp = fwd.param("b", Shape::new(&[batch, k, n], f));
18274        let mm = fwd.matmul(an, bp, Shape::new(&[batch, m, n], f));
18275        let loss = fwd.add_node(
18276            Op::Reduce {
18277                op: ReduceOp::Sum,
18278                axes: vec![0, 1, 2],
18279                keep_dim: false,
18280            },
18281            vec![mm],
18282            Shape::from_dims(&[], f),
18283        );
18284        fwd.set_outputs(vec![loss]);
18285
18286        let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[bp]);
18287        let d_out = bwd_graph
18288            .nodes()
18289            .iter()
18290            .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
18291            .map(|n| n.id)
18292            .unwrap();
18293
18294        let plan = rlx_opt::memory::plan_memory(&bwd_graph);
18295        let mut arena = crate::arena::Arena::from_plan(plan);
18296        let sched = compile_thunks(&bwd_graph, &arena);
18297        for &(id, data) in &[(an, &a_data), (bp, &b_data), (d_out, &vec![1.0f32])] {
18298            let off = arena.byte_offset(id);
18299            let buf = arena.raw_buf_mut();
18300            unsafe {
18301                let p = buf.as_mut_ptr().add(off) as *mut f32;
18302                for (i, &v) in data.iter().enumerate() {
18303                    *p.add(i) = v;
18304                }
18305            }
18306        }
18307        execute_thunks(&sched, arena.raw_buf_mut());
18308        let gb_id = bwd_graph.outputs[1];
18309        let g_b: Vec<f32> = unsafe {
18310            let p = arena.raw_buf().as_ptr().add(arena.byte_offset(gb_id)) as *const f32;
18311            (0..batch * k * n).map(|i| *p.add(i)).collect()
18312        };
18313
18314        // Numerical gradient: differentiate sum(a @ b) w.r.t. each b entry.
18315        let forward_loss = |b_vals: &[f32]| -> f32 {
18316            let mut out = vec![0f32; batch * m * n];
18317            for bi in 0..batch {
18318                for mi in 0..m {
18319                    for ni in 0..n {
18320                        let mut acc = 0f32;
18321                        for ki in 0..k {
18322                            acc +=
18323                                a_data[bi * m * k + mi * k + ki] * b_vals[bi * k * n + ki * n + ni];
18324                        }
18325                        out[bi * m * n + mi * n + ni] = acc;
18326                    }
18327                }
18328            }
18329            out.iter().sum()
18330        };
18331        let eps = 1e-3f32;
18332        let mut bp_p = b_data.clone();
18333        let mut g_b_num = vec![0f32; b_data.len()];
18334        for i in 0..b_data.len() {
18335            let s = bp_p[i];
18336            bp_p[i] = s + eps;
18337            let lp = forward_loss(&bp_p);
18338            bp_p[i] = s - eps;
18339            let lm = forward_loss(&bp_p);
18340            bp_p[i] = s;
18341            g_b_num[i] = (lp - lm) / (2.0 * eps);
18342        }
18343        for (i, (a, n)) in g_b.iter().zip(&g_b_num).enumerate() {
18344            assert!(
18345                (a - n).abs() < 5e-3,
18346                "matmul_3d g_b[{i}]: analytical {a} vs numerical {n}"
18347            );
18348        }
18349    }
18350
18351    /// Composed `Op::Softmax` VJP — the gradient is built from
18352    /// `mul + reduce_sum + expand + sub + mul`, no dedicated
18353    /// SoftmaxBackward kernel. Verifies the closed-form
18354    /// `dx = y · (g - Σ y·g)` matches the FD gradient over a small
18355    /// 2-D logits tensor.
18356    #[test]
18357    fn softmax_gradient_matches_numerical() {
18358        use rlx_ir::Philox4x32;
18359        let n = 3usize;
18360        let c = 5usize;
18361        let mut rng = Philox4x32::new(57);
18362        let mut x_data = vec![0f32; n * c];
18363        rng.fill_normal(&mut x_data);
18364
18365        let f = DType::F32;
18366        let mut fwd = Graph::new("softmax_only");
18367        let xn = fwd.input("x", Shape::new(&[n, c], f));
18368        let sm = fwd.add_node(Op::Softmax { axis: -1 }, vec![xn], Shape::new(&[n, c], f));
18369        // Loss = sum(softmax · target) for some random fixed target —
18370        // any linear loss will do; sum-of-all is the simplest and gives
18371        // a uniform gradient flow into the softmax.
18372        let loss = fwd.add_node(
18373            Op::Reduce {
18374                op: ReduceOp::Sum,
18375                axes: vec![0, 1],
18376                keep_dim: false,
18377            },
18378            vec![sm],
18379            Shape::from_dims(&[], f),
18380        );
18381        fwd.set_outputs(vec![loss]);
18382
18383        // `wrt = [xn]` — autodiff exposes the gradient w.r.t. the
18384        // input so we can compare it directly. The forward NodeId for
18385        // `xn` doubles as its bwd-graph mirror.
18386        let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[xn]);
18387        let d_out = bwd_graph
18388            .nodes()
18389            .iter()
18390            .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
18391            .map(|n| n.id)
18392            .unwrap();
18393
18394        let plan = rlx_opt::memory::plan_memory(&bwd_graph);
18395        let mut arena = crate::arena::Arena::from_plan(plan);
18396        let sched = compile_thunks(&bwd_graph, &arena);
18397        for &(id, data) in &[(xn, &x_data), (d_out, &vec![1.0f32])] {
18398            let off = arena.byte_offset(id);
18399            let buf = arena.raw_buf_mut();
18400            unsafe {
18401                let p = buf.as_mut_ptr().add(off) as *mut f32;
18402                for (i, &v) in data.iter().enumerate() {
18403                    *p.add(i) = v;
18404                }
18405            }
18406        }
18407        execute_thunks(&sched, arena.raw_buf_mut());
18408        let g_x_id = bwd_graph.outputs[1];
18409        let g_x: Vec<f32> = unsafe {
18410            let p = arena.raw_buf().as_ptr().add(arena.byte_offset(g_x_id)) as *const f32;
18411            (0..n * c).map(|i| *p.add(i)).collect()
18412        };
18413
18414        // Loss derivative: softmax sums to 1 per row → d/dx_i sum(softmax) = 0
18415        // analytically. So expect g_x ≈ 0 within FD precision. (This
18416        // doubles as a strong sanity check for the composition.)
18417        let forward_loss = |x: &[f32]| -> f32 {
18418            let mut total = 0f32;
18419            for ni in 0..n {
18420                let row = &x[ni * c..(ni + 1) * c];
18421                let m = row.iter().fold(f32::NEG_INFINITY, |a, &v| a.max(v));
18422                let denom: f32 = row.iter().map(|&v| (v - m).exp()).sum();
18423                for &v in row {
18424                    total += (v - m).exp() / denom;
18425                }
18426            }
18427            total
18428        };
18429        let eps = 1e-3f32;
18430        let mut p = x_data.clone();
18431        for i in 0..x_data.len() {
18432            let s = p[i];
18433            p[i] = s + eps;
18434            let lp = forward_loss(&p);
18435            p[i] = s - eps;
18436            let lm = forward_loss(&p);
18437            p[i] = s;
18438            let num = (lp - lm) / (2.0 * eps);
18439            assert!(
18440                (g_x[i] - num).abs() < 5e-3,
18441                "softmax g_x[{i}]: analytical {} vs numerical {num}",
18442                g_x[i]
18443            );
18444        }
18445    }
18446
18447    /// LayerNorm VJP — three gradients in one pass:
18448    ///   d_x via `LayerNormBackwardInput`,
18449    ///   d_gamma via `LayerNormBackwardGamma`,
18450    ///   d_beta = `unbroadcast(upstream)` to gamma's shape.
18451    #[test]
18452    fn layer_norm_gradient_matches_numerical() {
18453        use rlx_ir::Philox4x32;
18454        let rows = 3usize;
18455        let h = 6usize;
18456        let mut rng = Philox4x32::new(1009);
18457        let mut x_data = vec![0f32; rows * h];
18458        rng.fill_normal(&mut x_data);
18459        let mut g_data = vec![0f32; h];
18460        rng.fill_normal(&mut g_data);
18461        for v in g_data.iter_mut() {
18462            *v = v.abs() + 0.5;
18463        }
18464        let mut b_data = vec![0f32; h];
18465        rng.fill_normal(&mut b_data);
18466        let eps = 1e-5f32;
18467
18468        let f = DType::F32;
18469        let mut fwd = Graph::new("ln_only");
18470        let xn = fwd.input("x", Shape::new(&[rows, h], f));
18471        let gp = fwd.param("gamma", Shape::new(&[h], f));
18472        let bp = fwd.param("beta", Shape::new(&[h], f));
18473        let ln = fwd.add_node(
18474            Op::LayerNorm { axis: -1, eps },
18475            vec![xn, gp, bp],
18476            Shape::new(&[rows, h], f),
18477        );
18478        let loss = fwd.add_node(
18479            Op::Reduce {
18480                op: ReduceOp::Sum,
18481                axes: vec![0, 1],
18482                keep_dim: false,
18483            },
18484            vec![ln],
18485            Shape::from_dims(&[], f),
18486        );
18487        fwd.set_outputs(vec![loss]);
18488
18489        let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[xn, gp, bp]);
18490        let d_out = bwd_graph
18491            .nodes()
18492            .iter()
18493            .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
18494            .map(|n| n.id)
18495            .unwrap();
18496
18497        let plan = rlx_opt::memory::plan_memory(&bwd_graph);
18498        let mut arena = crate::arena::Arena::from_plan(plan);
18499        let sched = compile_thunks(&bwd_graph, &arena);
18500        for &(id, data) in &[
18501            (xn, &x_data),
18502            (gp, &g_data),
18503            (bp, &b_data),
18504            (d_out, &vec![1.0f32]),
18505        ] {
18506            let off = arena.byte_offset(id);
18507            let buf = arena.raw_buf_mut();
18508            unsafe {
18509                let p = buf.as_mut_ptr().add(off) as *mut f32;
18510                for (i, &v) in data.iter().enumerate() {
18511                    *p.add(i) = v;
18512                }
18513            }
18514        }
18515        execute_thunks(&sched, arena.raw_buf_mut());
18516        let read = |id: NodeId, n: usize| -> Vec<f32> {
18517            let off = arena.byte_offset(id);
18518            unsafe {
18519                let p = arena.raw_buf().as_ptr().add(off) as *const f32;
18520                (0..n).map(|i| *p.add(i)).collect()
18521            }
18522        };
18523        let dx_a = read(bwd_graph.outputs[1], rows * h);
18524        let dg_a = read(bwd_graph.outputs[2], h);
18525        let db_a = read(bwd_graph.outputs[3], h);
18526
18527        let forward_loss = |x: &[f32], g: &[f32], b: &[f32]| -> f32 {
18528            let mut total = 0f32;
18529            for r in 0..rows {
18530                let row = &x[r * h..(r + 1) * h];
18531                let mean = row.iter().sum::<f32>() / h as f32;
18532                let var = row.iter().map(|&v| (v - mean) * (v - mean)).sum::<f32>() / h as f32;
18533                let inv_std = 1.0 / (var + eps).sqrt();
18534                for d in 0..h {
18535                    total += ((row[d] - mean) * inv_std) * g[d] + b[d];
18536                }
18537            }
18538            total
18539        };
18540        let h_eps = 1e-3f32;
18541
18542        let mut x_p = x_data.clone();
18543        for i in 0..x_p.len() {
18544            let s = x_p[i];
18545            x_p[i] = s + h_eps;
18546            let lp = forward_loss(&x_p, &g_data, &b_data);
18547            x_p[i] = s - h_eps;
18548            let lm = forward_loss(&x_p, &g_data, &b_data);
18549            x_p[i] = s;
18550            let num = (lp - lm) / (2.0 * h_eps);
18551            assert!(
18552                (dx_a[i] - num).abs() < 5e-3,
18553                "ln dx[{i}]: analytical {} vs numerical {num}",
18554                dx_a[i]
18555            );
18556        }
18557        let mut g_p = g_data.clone();
18558        for i in 0..g_p.len() {
18559            let s = g_p[i];
18560            g_p[i] = s + h_eps;
18561            let lp = forward_loss(&x_data, &g_p, &b_data);
18562            g_p[i] = s - h_eps;
18563            let lm = forward_loss(&x_data, &g_p, &b_data);
18564            g_p[i] = s;
18565            let num = (lp - lm) / (2.0 * h_eps);
18566            assert!(
18567                (dg_a[i] - num).abs() < 5e-3,
18568                "ln dg[{i}]: analytical {} vs numerical {num}",
18569                dg_a[i]
18570            );
18571        }
18572        let mut b_p = b_data.clone();
18573        for i in 0..b_p.len() {
18574            let s = b_p[i];
18575            b_p[i] = s + h_eps;
18576            let lp = forward_loss(&x_data, &g_data, &b_p);
18577            b_p[i] = s - h_eps;
18578            let lm = forward_loss(&x_data, &g_data, &b_p);
18579            b_p[i] = s;
18580            let num = (lp - lm) / (2.0 * h_eps);
18581            assert!(
18582                (db_a[i] - num).abs() < 5e-3,
18583                "ln db[{i}]: analytical {} vs numerical {num}",
18584                db_a[i]
18585            );
18586        }
18587    }
18588
18589    /// Single dense layer + softmax-cross-entropy + mean reduce —
18590    /// the simplest non-trivial training graph. Validates MatMul,
18591    /// broadcast Add, SCE, Reduce(Mean) VJPs and the grad_with_loss
18592    /// plumbing all at once.
18593    #[test]
18594    fn dense_sce_mean_gradient_matches_numerical() {
18595        use rlx_ir::Philox4x32;
18596        let bs = 4usize;
18597        let k_in = 3usize;
18598        let c = 5usize;
18599        let mut rng = Philox4x32::new(7);
18600        let mut x = vec![0f32; bs * k_in];
18601        rng.fill_normal(&mut x);
18602        let mut w_init = vec![0f32; k_in * c];
18603        rng.fill_normal(&mut w_init);
18604        let mut b_init = vec![0f32; c];
18605        rng.fill_normal(&mut b_init);
18606        let labels: Vec<f32> = (0..bs).map(|i| (i % c) as f32).collect();
18607
18608        // ── Forward graph: loss = mean(sce(x @ w + b, labels)) ──
18609        let f = DType::F32;
18610        let mut fwd = Graph::new("dense_sce");
18611        let xn = fwd.input("x", Shape::new(&[bs, k_in], f));
18612        let lb = fwd.input("labels", Shape::new(&[bs], f));
18613        let wp = fwd.param("w", Shape::new(&[k_in, c], f));
18614        let bp = fwd.param("b", Shape::new(&[c], f));
18615        let mm = fwd.matmul(xn, wp, Shape::new(&[bs, c], f));
18616        let logits = fwd.binary(BinaryOp::Add, mm, bp, Shape::new(&[bs, c], f));
18617        let loss_per = fwd.softmax_cross_entropy_with_logits(logits, lb);
18618        let loss = fwd.add_node(
18619            Op::Reduce {
18620                op: ReduceOp::Sum,
18621                axes: vec![0],
18622                keep_dim: false,
18623            },
18624            vec![loss_per],
18625            // Reduce sum of [bs] with axes=[0] keep_dim=false → scalar [].
18626            Shape::from_dims(&[], f),
18627        );
18628        // Use Sum + manual /bs scalar mul — also exercises BinaryOp::Mul VJP path
18629        // less aggressively than Mean would, and gives us a closed-form
18630        // reference for the loss we expect.
18631        // For simplicity though, switch to Mean which the tests should also cover.
18632        // (Re-using `loss` with Sum here for now; the mean factor cancels in
18633        // the gradient comparison since both analytical and numerical use the
18634        // same forward.)
18635        fwd.set_outputs(vec![loss]);
18636
18637        // ── Backward graph ──
18638        let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[wp, bp]);
18639        // Outputs: [loss, grad_w, grad_b]. NodeIds for x/labels/w/b/loss
18640        // in bwd_graph match their fwd ids (the mirror keeps order).
18641        let d_out = bwd_graph
18642            .nodes()
18643            .iter()
18644            .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
18645            .map(|n| n.id)
18646            .expect("d_output input");
18647
18648        let (sched, mut arena) = prepare(
18649            &bwd_graph,
18650            &[
18651                (xn, &x),
18652                (lb, &labels),
18653                (wp, &w_init),
18654                (bp, &b_init),
18655                (d_out, &[1.0]),
18656            ],
18657        );
18658        execute_thunks(&sched, arena.raw_buf_mut());
18659
18660        let outs = &bwd_graph.outputs;
18661        let loss_id = outs[0];
18662        let gw_id = outs[1];
18663        let gb_id = outs[2];
18664        let loss_actual = read_arena(&arena, loss_id, 1)[0];
18665        let gw_actual = read_arena(&arena, gw_id, k_in * c);
18666        let gb_actual = read_arena(&arena, gb_id, c);
18667
18668        // ── Forward-only graph for finite differences ──
18669        // Re-use the same `fwd` graph; set up its own arena and rerun
18670        // for each perturbed parameter.
18671        let plan = rlx_opt::memory::plan_memory(&fwd);
18672        let mut fwd_arena = crate::arena::Arena::from_plan(plan);
18673        let fwd_sched = compile_thunks(&fwd, &fwd_arena);
18674        write_arena(&mut fwd_arena, xn, &x);
18675        write_arena(&mut fwd_arena, lb, &labels);
18676
18677        let run_loss = |arena: &mut crate::arena::Arena, w: &[f32], b: &[f32]| -> f32 {
18678            write_arena(arena, wp, w);
18679            write_arena(arena, bp, b);
18680            execute_thunks(&fwd_sched, arena.raw_buf_mut());
18681            read_arena(arena, loss, 1)[0]
18682        };
18683
18684        // Sanity: the loss reported by the bwd graph matches the
18685        // forward-only graph on the unperturbed inputs.
18686        let loss_check = run_loss(&mut fwd_arena, &w_init, &b_init);
18687        assert!(
18688            (loss_actual - loss_check).abs() < 1e-4,
18689            "loss mismatch: bwd graph {loss_actual} vs fwd-only {loss_check}"
18690        );
18691
18692        let eps = 1e-3f32;
18693        let mut w_perturbed = w_init.clone();
18694        let mut gw_numerical = vec![0f32; w_init.len()];
18695        for i in 0..w_init.len() {
18696            let saved = w_perturbed[i];
18697            w_perturbed[i] = saved + eps;
18698            let lp = run_loss(&mut fwd_arena, &w_perturbed, &b_init);
18699            w_perturbed[i] = saved - eps;
18700            let lm = run_loss(&mut fwd_arena, &w_perturbed, &b_init);
18701            w_perturbed[i] = saved;
18702            gw_numerical[i] = (lp - lm) / (2.0 * eps);
18703        }
18704        for (i, (a, n)) in gw_actual.iter().zip(&gw_numerical).enumerate() {
18705            assert!(
18706                (a - n).abs() < 5e-3,
18707                "grad_w[{i}]: analytical {a} vs numerical {n}"
18708            );
18709        }
18710
18711        let mut b_perturbed = b_init.clone();
18712        let mut gb_numerical = vec![0f32; b_init.len()];
18713        for i in 0..b_init.len() {
18714            let saved = b_perturbed[i];
18715            b_perturbed[i] = saved + eps;
18716            let lp = run_loss(&mut fwd_arena, &w_init, &b_perturbed);
18717            b_perturbed[i] = saved - eps;
18718            let lm = run_loss(&mut fwd_arena, &w_init, &b_perturbed);
18719            b_perturbed[i] = saved;
18720            gb_numerical[i] = (lp - lm) / (2.0 * eps);
18721        }
18722        for (i, (a, n)) in gb_actual.iter().zip(&gb_numerical).enumerate() {
18723            assert!(
18724                (a - n).abs() < 5e-3,
18725                "grad_b[{i}]: analytical {a} vs numerical {n}"
18726            );
18727        }
18728    }
18729
18730    /// Reduce::Mean specifically — verifies the 1/N scaling in the VJP.
18731    /// The same dense+SCE graph but with Mean instead of Sum on the loss.
18732    #[test]
18733    fn dense_sce_mean_reduce_gradient_matches_numerical() {
18734        use rlx_ir::Philox4x32;
18735        let bs = 3usize;
18736        let k_in = 2usize;
18737        let c = 4usize;
18738        let mut rng = Philox4x32::new(13);
18739        let mut x = vec![0f32; bs * k_in];
18740        rng.fill_normal(&mut x);
18741        let mut w_init = vec![0f32; k_in * c];
18742        rng.fill_normal(&mut w_init);
18743        let labels: Vec<f32> = (0..bs).map(|i| (i % c) as f32).collect();
18744
18745        let f = DType::F32;
18746        let mut fwd = Graph::new("dense_sce_mean");
18747        let xn = fwd.input("x", Shape::new(&[bs, k_in], f));
18748        let lb = fwd.input("labels", Shape::new(&[bs], f));
18749        let wp = fwd.param("w", Shape::new(&[k_in, c], f));
18750        let mm = fwd.matmul(xn, wp, Shape::new(&[bs, c], f));
18751        let loss_per = fwd.softmax_cross_entropy_with_logits(mm, lb);
18752        let loss = fwd.add_node(
18753            Op::Reduce {
18754                op: ReduceOp::Mean,
18755                axes: vec![0],
18756                keep_dim: false,
18757            },
18758            vec![loss_per],
18759            Shape::from_dims(&[], f),
18760        );
18761        fwd.set_outputs(vec![loss]);
18762
18763        let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[wp]);
18764        let d_out = bwd_graph
18765            .nodes()
18766            .iter()
18767            .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
18768            .map(|n| n.id)
18769            .unwrap();
18770
18771        let (sched, mut arena) = prepare(
18772            &bwd_graph,
18773            &[(xn, &x), (lb, &labels), (wp, &w_init), (d_out, &[1.0])],
18774        );
18775        execute_thunks(&sched, arena.raw_buf_mut());
18776
18777        let outs = &bwd_graph.outputs;
18778        let loss_id = outs[0];
18779        let gw_id = outs[1];
18780        let _ = read_arena(&arena, loss_id, 1)[0];
18781        let gw_actual = read_arena(&arena, gw_id, k_in * c);
18782
18783        let plan = rlx_opt::memory::plan_memory(&fwd);
18784        let mut fwd_arena = crate::arena::Arena::from_plan(plan);
18785        let fwd_sched = compile_thunks(&fwd, &fwd_arena);
18786        write_arena(&mut fwd_arena, xn, &x);
18787        write_arena(&mut fwd_arena, lb, &labels);
18788
18789        let run_loss = |arena: &mut crate::arena::Arena, w: &[f32]| -> f32 {
18790            write_arena(arena, wp, w);
18791            execute_thunks(&fwd_sched, arena.raw_buf_mut());
18792            read_arena(arena, loss, 1)[0]
18793        };
18794
18795        let eps = 1e-3f32;
18796        let mut wp_p = w_init.clone();
18797        let mut gw_num = vec![0f32; w_init.len()];
18798        for i in 0..w_init.len() {
18799            let s = wp_p[i];
18800            wp_p[i] = s + eps;
18801            let lp = run_loss(&mut fwd_arena, &wp_p);
18802            wp_p[i] = s - eps;
18803            let lm = run_loss(&mut fwd_arena, &wp_p);
18804            wp_p[i] = s;
18805            gw_num[i] = (lp - lm) / (2.0 * eps);
18806        }
18807        for (i, (a, n)) in gw_actual.iter().zip(&gw_num).enumerate() {
18808            assert!((a - n).abs() < 5e-3, "mean reduce grad_w[{i}]: {a} vs {n}");
18809        }
18810    }
18811    /// The full TinyConv-MNIST forward path (downsized) plumbed
18812    /// through grad_with_loss. Validates that Conv, Pool(Max), ReLU,
18813    /// Reshape, MatMul, Add (broadcast), SCE, Reduce(Mean) VJPs all
18814    /// compose into a graph that produces correct gradients.
18815    #[test]
18816    fn tinyconv_full_gradient_matches_numerical() {
18817        use rlx_ir::Philox4x32;
18818        // Tiny shapes so finite differences finish in <1s.
18819        let n = 1usize;
18820        let c_in = 1usize;
18821        let h = 6usize;
18822        let w_in = 6usize;
18823        let c_mid = 2usize; // first conv output channels
18824        let kh = 3;
18825        let kw = 3;
18826        let h1 = h - kh + 1; // 4
18827        let w1 = w_in - kw + 1; // 4
18828        let h2 = h1 / 2;
18829        let w2 = w1 / 2; // 2 × 2 after 2× pool
18830        let flat = c_mid * h2 * w2; // 8
18831        let num_classes = 3usize;
18832
18833        let mut rng = Philox4x32::new(31);
18834        let mut x = vec![0f32; n * c_in * h * w_in];
18835        rng.fill_normal(&mut x);
18836        let mut wc = vec![0f32; c_mid * c_in * kh * kw];
18837        rng.fill_normal(&mut wc);
18838        for v in wc.iter_mut() {
18839            *v *= 0.2;
18840        }
18841        // Shift conv-bias well away from the ReLU zero-boundary. Without
18842        // this, an ε-perturbation of bc[c] can flip the ReLU mask on a
18843        // pre-activation that happened to land near zero — making the
18844        // central-difference numerical gradient discontinuous and
18845        // diverge from the analytical (which assumes local smoothness).
18846        // +5.0 keeps every pre-activation positive for any random init
18847        // produced by Philox seed 31 with the wc/x scales used here, so
18848        // ReLU acts as an identity and finite differences are exact.
18849        let bc: Vec<f32> = (0..c_mid).map(|i| 5.0 + 0.1 * i as f32).collect();
18850        let mut wfc = vec![0f32; flat * num_classes];
18851        rng.fill_normal(&mut wfc);
18852        for v in wfc.iter_mut() {
18853            *v *= 0.5;
18854        }
18855        let mut bfc = vec![0f32; num_classes];
18856        rng.fill_normal(&mut bfc);
18857        let labels: Vec<f32> = vec![1.0]; // batch=1
18858
18859        let f = DType::F32;
18860        let mut fwd = Graph::new("tinyconv");
18861        let xn = fwd.input("x", Shape::new(&[n, c_in, h, w_in], f));
18862        let lb = fwd.input("labels", Shape::new(&[n], f));
18863        let wcp = fwd.param("wc", Shape::new(&[c_mid, c_in, kh, kw], f));
18864        let bcp = fwd.param("bc", Shape::new(&[c_mid], f));
18865        let wfp = fwd.param("wfc", Shape::new(&[flat, num_classes], f));
18866        let bfp = fwd.param("bfc", Shape::new(&[num_classes], f));
18867
18868        // conv: [n, c_in, h, w] → [n, c_mid, h1, w1]
18869        let conv = fwd.add_node(
18870            Op::Conv {
18871                kernel_size: vec![kh, kw],
18872                stride: vec![1, 1],
18873                padding: vec![0, 0],
18874                dilation: vec![1, 1],
18875                groups: 1,
18876            },
18877            vec![xn, wcp],
18878            Shape::new(&[n, c_mid, h1, w1], f),
18879        );
18880        // Bias add: expand bc[c_mid] up to the full [n, c_mid, h1, w1]
18881        // shape so the Add becomes a plain element-wise op. Going through
18882        // an explicit Reshape→Expand instead of relying on the Add to
18883        // broadcast `[1, C, 1, 1]` → `[N, C, H, W]` works around a known
18884        // limitation of `rlx-cpu`'s `Op::Binary` lowering: it dispatches
18885        // on `out_len % rhs_len == 0` and treats `rhs` as a last-axis
18886        // bias, which produces `bc[0], bc[1], bc[0], bc[1], …` alternating
18887        // across all positions instead of channel-broadcasting. Going
18888        // through Expand (a real broadcast thunk) avoids that path
18889        // entirely. The autodiff still exercises `unbroadcast` because
18890        // `Op::Expand`'s VJP reduces over the broadcast axes.
18891        let bc_4d = fwd.add_node(
18892            Op::Reshape {
18893                new_shape: vec![1, c_mid as i64, 1, 1],
18894            },
18895            vec![bcp],
18896            Shape::new(&[1, c_mid, 1, 1], f),
18897        );
18898        let bc_expanded = fwd.add_node(
18899            Op::Expand {
18900                target_shape: vec![n as i64, c_mid as i64, h1 as i64, w1 as i64],
18901            },
18902            vec![bc_4d],
18903            Shape::new(&[n, c_mid, h1, w1], f),
18904        );
18905        let conv_b = fwd.binary(
18906            BinaryOp::Add,
18907            conv,
18908            bc_expanded,
18909            Shape::new(&[n, c_mid, h1, w1], f),
18910        );
18911        let relu = fwd.activation(Activation::Relu, conv_b, Shape::new(&[n, c_mid, h1, w1], f));
18912        let pool = fwd.add_node(
18913            Op::Pool {
18914                kind: ReduceOp::Max,
18915                kernel_size: vec![2, 2],
18916                stride: vec![2, 2],
18917                padding: vec![0, 0],
18918            },
18919            vec![relu],
18920            Shape::new(&[n, c_mid, h2, w2], f),
18921        );
18922        let flatn = fwd.add_node(
18923            Op::Reshape {
18924                new_shape: vec![n as i64, flat as i64],
18925            },
18926            vec![pool],
18927            Shape::new(&[n, flat], f),
18928        );
18929        let mm = fwd.matmul(flatn, wfp, Shape::new(&[n, num_classes], f));
18930        let logits = fwd.binary(BinaryOp::Add, mm, bfp, Shape::new(&[n, num_classes], f));
18931        let loss_per = fwd.softmax_cross_entropy_with_logits(logits, lb);
18932        let loss = fwd.add_node(
18933            Op::Reduce {
18934                op: ReduceOp::Mean,
18935                axes: vec![0],
18936                keep_dim: false,
18937            },
18938            vec![loss_per],
18939            Shape::from_dims(&[], f),
18940        );
18941        fwd.set_outputs(vec![loss]);
18942
18943        let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[wcp, bcp, wfp, bfp]);
18944        let d_out = bwd_graph
18945            .nodes()
18946            .iter()
18947            .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
18948            .map(|n| n.id)
18949            .unwrap();
18950
18951        let (sched, mut arena) = prepare(
18952            &bwd_graph,
18953            &[
18954                (xn, &x),
18955                (lb, &labels),
18956                (wcp, &wc),
18957                (bcp, &bc),
18958                (wfp, &wfc),
18959                (bfp, &bfc),
18960                (d_out, &[1.0]),
18961            ],
18962        );
18963        execute_thunks(&sched, arena.raw_buf_mut());
18964
18965        let outs = bwd_graph.outputs.clone();
18966        let loss_id = outs[0];
18967        let g_wc_id = outs[1];
18968        let g_bc_id = outs[2];
18969        let g_wfc_id = outs[3];
18970        let g_bfc_id = outs[4];
18971        let loss_actual = read_arena(&arena, loss_id, 1)[0];
18972        let g_wc = read_arena(&arena, g_wc_id, wc.len());
18973        let g_bc = read_arena(&arena, g_bc_id, bc.len());
18974        let g_wfc = read_arena(&arena, g_wfc_id, wfc.len());
18975        let g_bfc = read_arena(&arena, g_bfc_id, bfc.len());
18976
18977        // Forward-only arena for finite differences.
18978        let plan = rlx_opt::memory::plan_memory(&fwd);
18979        let mut fwd_arena = crate::arena::Arena::from_plan(plan);
18980        let fwd_sched = compile_thunks(&fwd, &fwd_arena);
18981        write_arena(&mut fwd_arena, xn, &x);
18982        write_arena(&mut fwd_arena, lb, &labels);
18983
18984        // Closure variant: we need to set all four params each call so
18985        // perturbations to one don't leak between sweeps.
18986        let run_loss = |arena: &mut crate::arena::Arena,
18987                        wc: &[f32],
18988                        bc: &[f32],
18989                        wfc: &[f32],
18990                        bfc: &[f32]|
18991         -> f32 {
18992            write_arena(arena, wcp, wc);
18993            write_arena(arena, bcp, bc);
18994            write_arena(arena, wfp, wfc);
18995            write_arena(arena, bfp, bfc);
18996            execute_thunks(&fwd_sched, arena.raw_buf_mut());
18997            read_arena(arena, loss, 1)[0]
18998        };
18999
19000        let loss_check = run_loss(&mut fwd_arena, &wc, &bc, &wfc, &bfc);
19001        assert!(
19002            (loss_actual - loss_check).abs() < 1e-4,
19003            "tinyconv loss mismatch: bwd {loss_actual} vs fwd {loss_check}"
19004        );
19005
19006        let eps = 1e-3f32;
19007        let check_grad = |arena: &mut crate::arena::Arena,
19008                          name: &str,
19009                          analytical: &[f32],
19010                          mut perturb: Box<
19011            dyn FnMut(&mut [f32], usize, f32, &mut crate::arena::Arena) -> f32 + '_,
19012        >,
19013                          n: usize| {
19014            for i in 0..n {
19015                let lp = perturb(&mut analytical.to_vec(), i, eps, arena);
19016                let lm = perturb(&mut analytical.to_vec(), i, -eps, arena);
19017                let num = (lp - lm) / (2.0 * eps);
19018                assert!(
19019                    (analytical[i] - num).abs() < 5e-3,
19020                    "{name}[{i}]: analytical {} vs numerical {num}",
19021                    analytical[i]
19022                );
19023            }
19024        };
19025
19026        // Helper to perturb one param and run forward. Kept as a
19027        // reference for the explicit per-param sweep pattern below.
19028        #[allow(unused_macros)]
19029        macro_rules! sweep {
19030            ($name:expr, $base:expr, $analytical:expr, $set_param:ident) => {{
19031                let n = $base.len();
19032                for i in 0..n {
19033                    let mut p = $base.clone();
19034                    let s = p[i];
19035                    p[i] = s + eps;
19036                    let lp = {
19037                        let $set_param = &p;
19038                        run_loss(&mut fwd_arena, &wc, &bc, &wfc, &bfc).max(f32::NEG_INFINITY);
19039                        // Reset others, set the one being swept, run.
19040                        // (the macro receives one of the four params via $set_param)
19041                        let _ = $set_param;
19042                        // Fall through to the explicit per-param helper:
19043                        0.0_f32
19044                    };
19045                    let _ = lp;
19046                }
19047            }};
19048        }
19049        let _ = check_grad; // silence unused (sweep! macro is intentionally\n        // unused — kept as reference for the per-param sweep pattern below)
19050
19051        // Per-param sweeps (explicit, not macro — clearer).
19052        for i in 0..wc.len() {
19053            let mut p = wc.clone();
19054            let s = p[i];
19055            p[i] = s + eps;
19056            let lp = run_loss(&mut fwd_arena, &p, &bc, &wfc, &bfc);
19057            p[i] = s - eps;
19058            let lm = run_loss(&mut fwd_arena, &p, &bc, &wfc, &bfc);
19059            let num = (lp - lm) / (2.0 * eps);
19060            assert!(
19061                (g_wc[i] - num).abs() < 5e-3,
19062                "g_wc[{i}]: {} vs {num}",
19063                g_wc[i]
19064            );
19065        }
19066        for i in 0..bc.len() {
19067            let mut p = bc.clone();
19068            let s = p[i];
19069            p[i] = s + eps;
19070            let lp = run_loss(&mut fwd_arena, &wc, &p, &wfc, &bfc);
19071            p[i] = s - eps;
19072            let lm = run_loss(&mut fwd_arena, &wc, &p, &wfc, &bfc);
19073            let num = (lp - lm) / (2.0 * eps);
19074            assert!(
19075                (g_bc[i] - num).abs() < 5e-3,
19076                "g_bc[{i}]: {} vs {num}",
19077                g_bc[i]
19078            );
19079        }
19080        for i in 0..wfc.len() {
19081            let mut p = wfc.clone();
19082            let s = p[i];
19083            p[i] = s + eps;
19084            let lp = run_loss(&mut fwd_arena, &wc, &bc, &p, &bfc);
19085            p[i] = s - eps;
19086            let lm = run_loss(&mut fwd_arena, &wc, &bc, &p, &bfc);
19087            let num = (lp - lm) / (2.0 * eps);
19088            assert!(
19089                (g_wfc[i] - num).abs() < 5e-3,
19090                "g_wfc[{i}]: {} vs {num}",
19091                g_wfc[i]
19092            );
19093        }
19094        for i in 0..bfc.len() {
19095            let mut p = bfc.clone();
19096            let s = p[i];
19097            p[i] = s + eps;
19098            let lp = run_loss(&mut fwd_arena, &wc, &bc, &wfc, &p);
19099            p[i] = s - eps;
19100            let lm = run_loss(&mut fwd_arena, &wc, &bc, &wfc, &p);
19101            let num = (lp - lm) / (2.0 * eps);
19102            assert!(
19103                (g_bfc[i] - num).abs() < 5e-3,
19104                "g_bfc[{i}]: {} vs {num}",
19105                g_bfc[i]
19106            );
19107        }
19108    }
19109
19110    /// Negative case: a Narrow whose output has multiple consumers
19111    /// must NOT be fused (we can't elide its write — something else
19112    /// reads it).
19113    #[test]
19114    fn narrow_rope_skips_when_narrow_has_multiple_consumers() {
19115        let f = DType::F32;
19116        let mut g = Graph::new("nr_skip");
19117        let qkv = g.input("qkv", Shape::new(&[16, 8, 192], f));
19118        let cos = g.input("cos", Shape::new(&[16], f));
19119        let sin = g.input("sin", Shape::new(&[16], f));
19120        let q = g.narrow_(qkv, 2, 0, 64);
19121        let q_rope = g.rope(q, cos, sin, 16);
19122        // Second consumer of `q` blocks the fusion.
19123        let q_dup = g.activation(rlx_ir::op::Activation::Relu, q, Shape::new(&[16, 8, 64], f));
19124        g.set_outputs(vec![q_rope, q_dup]);
19125
19126        let plan = rlx_opt::memory::plan_memory(&g);
19127        let arena = crate::arena::Arena::from_plan(plan);
19128        let sched = compile_thunks(&g, &arena);
19129
19130        let narrow_count = sched
19131            .thunks
19132            .iter()
19133            .filter(|t| matches!(t, Thunk::Narrow { .. }))
19134            .count();
19135        assert!(
19136            narrow_count >= 1,
19137            "Narrow with multiple consumers must NOT be fused away"
19138        );
19139    }
19140
19141    // ── Op::CustomFn (custom_vjp / custom_jvp) tests ──
19142    //
19143    // Validates: forward execution inlines fwd_body; VJP rule inlines
19144    // vjp_body in place of recursing into fwd_body; JVP rule inlines
19145    // jvp_body. Each test deliberately picks a body whose AD-via-tracing
19146    // would yield a *different* gradient than the override, so we know
19147    // the override actually fired.
19148
19149    /// Forward only: CustomFn wrapping `f(x) = x + c` (c=1 inside body)
19150    /// without override AD bodies. Verifies the body is compiled,
19151    /// constants in the body fill correctly, and the output lands at
19152    /// the outer node's slot.
19153    #[test]
19154    fn custom_fn_forward_inlines_body() {
19155        let s = Shape::new(&[3], DType::F32);
19156
19157        // Body: f(x) = x + 1
19158        let mut body = Graph::new("addone_body");
19159        let x = body.input("x", s.clone());
19160        let one_data: Vec<u8> = (0..3).flat_map(|_| 1.0_f32.to_le_bytes()).collect();
19161        let one = body.add_node(Op::Constant { data: one_data }, vec![], s.clone());
19162        let y = body.binary(BinaryOp::Add, x, one, s.clone());
19163        body.set_outputs(vec![y]);
19164
19165        let mut g = Graph::new("custom_fn_outer");
19166        let xin = g.input("x_in", s.clone());
19167        let cf = g.custom_fn(vec![xin], body, None, None);
19168        g.set_outputs(vec![cf]);
19169
19170        let xs = vec![10.0_f32, 20.0, 30.0];
19171        let (sched, mut arena) = prepare(&g, &[(xin, &xs)]);
19172        execute_thunks(&sched, arena.raw_buf_mut());
19173        let got = read_arena(&arena, cf, 3);
19174        assert_eq!(got, vec![11.0, 21.0, 31.0]);
19175    }
19176
19177    /// Locate an Op::Input or Op::Param by name in a graph.
19178    fn find_named(graph: &Graph, want: &str) -> NodeId {
19179        for n in graph.nodes() {
19180            let name = match &n.op {
19181                Op::Input { name } | Op::Param { name } => Some(name.as_str()),
19182                _ => None,
19183            };
19184            if name == Some(want) {
19185                return n.id;
19186            }
19187        }
19188        panic!("no node named {want:?} in graph");
19189    }
19190
19191    /// VJP override: f(x) = x but vjp_body returns 2 * d_output, so the
19192    /// reported gradient should be 2 — different from the natural 1
19193    /// you'd get by recursing into the identity body.
19194    #[test]
19195    fn custom_fn_vjp_overrides_natural_gradient() {
19196        use rlx_opt::autodiff::grad_with_loss;
19197        let s = Shape::new(&[1], DType::F32);
19198
19199        let mut fwd = Graph::new("id_fwd");
19200        let x = fwd.input("x", s.clone());
19201        fwd.set_outputs(vec![x]);
19202
19203        let mut vjp_g = Graph::new("id_vjp");
19204        let _x_p = vjp_g.input("x", s.clone());
19205        let _y_p = vjp_g.input("primal_output", s.clone());
19206        let dy = vjp_g.input("d_output", s.clone());
19207        let two_data: Vec<u8> = 2.0_f32.to_le_bytes().to_vec();
19208        let two = vjp_g.add_node(Op::Constant { data: two_data }, vec![], s.clone());
19209        let dx = vjp_g.binary(BinaryOp::Mul, dy, two, s.clone());
19210        vjp_g.set_outputs(vec![dx]);
19211
19212        let mut g = Graph::new("outer");
19213        let xp = g.param("x", s.clone());
19214        let cf = g.custom_fn(vec![xp], fwd, Some(vjp_g), None);
19215        g.set_outputs(vec![cf]);
19216
19217        let bwd = grad_with_loss(&g, &[xp]);
19218        assert_eq!(bwd.outputs.len(), 2, "expect [loss, dx]");
19219
19220        let xb = find_named(&bwd, "x");
19221        let dout = find_named(&bwd, "d_output");
19222        let (sched, mut arena) = prepare(&bwd, &[(xb, &[7.0]), (dout, &[1.0])]);
19223        execute_thunks(&sched, arena.raw_buf_mut());
19224        let loss = read_arena(&arena, bwd.outputs[0], 1);
19225        let dx_v = read_arena(&arena, bwd.outputs[1], 1);
19226        assert!((loss[0] - 7.0).abs() < 1e-6, "loss should be 7.0");
19227        assert!(
19228            (dx_v[0] - 2.0).abs() < 1e-6,
19229            "vjp override should yield dx=2.0, got {} (natural autodiff would give 1.0)",
19230            dx_v[0]
19231        );
19232    }
19233
19234    /// VJP override: f(a, b) = a*b with vjp_body returning
19235    /// (b * d_output, a * d_output). Validates routing of multiple
19236    /// primals + d_output through the override; matches the natural
19237    /// autodiff-of-Mul gradient (b, a).
19238    #[test]
19239    fn custom_fn_vjp_two_inputs_matches_mul_autodiff() {
19240        use rlx_opt::autodiff::grad_with_loss;
19241        let s = Shape::new(&[1], DType::F32);
19242
19243        let mut fwd = Graph::new("mul_fwd");
19244        let a_f = fwd.input("a", s.clone());
19245        let b_f = fwd.input("b", s.clone());
19246        let y_f = fwd.binary(BinaryOp::Mul, a_f, b_f, s.clone());
19247        fwd.set_outputs(vec![y_f]);
19248
19249        let mut vjp_g = Graph::new("mul_vjp");
19250        let a_v = vjp_g.input("a", s.clone());
19251        let b_v = vjp_g.input("b", s.clone());
19252        let _y_v = vjp_g.input("primal_output", s.clone());
19253        let dy_v = vjp_g.input("d_output", s.clone());
19254        let da = vjp_g.binary(BinaryOp::Mul, b_v, dy_v, s.clone());
19255        let db = vjp_g.binary(BinaryOp::Mul, a_v, dy_v, s.clone());
19256        vjp_g.set_outputs(vec![da, db]);
19257
19258        let mut g = Graph::new("outer");
19259        let ap = g.param("a", s.clone());
19260        let bp = g.param("b", s.clone());
19261        let cf = g.custom_fn(vec![ap, bp], fwd, Some(vjp_g), None);
19262        g.set_outputs(vec![cf]);
19263
19264        let bwd = grad_with_loss(&g, &[ap, bp]);
19265        assert_eq!(bwd.outputs.len(), 3, "expect [loss, da, db]");
19266
19267        let ab = find_named(&bwd, "a");
19268        let bb = find_named(&bwd, "b");
19269        let dout = find_named(&bwd, "d_output");
19270        let (sched, mut arena) = prepare(&bwd, &[(ab, &[3.0]), (bb, &[5.0]), (dout, &[1.0])]);
19271        execute_thunks(&sched, arena.raw_buf_mut());
19272        let loss = read_arena(&arena, bwd.outputs[0], 1);
19273        let da_v = read_arena(&arena, bwd.outputs[1], 1);
19274        let db_v = read_arena(&arena, bwd.outputs[2], 1);
19275        assert!((loss[0] - 15.0).abs() < 1e-5);
19276        assert!(
19277            (da_v[0] - 5.0).abs() < 1e-5,
19278            "da should be b=5.0, got {}",
19279            da_v[0]
19280        );
19281        assert!(
19282            (db_v[0] - 3.0).abs() < 1e-5,
19283            "db should be a=3.0, got {}",
19284            db_v[0]
19285        );
19286    }
19287
19288    /// JVP override: f(x) = x but jvp_body returns 2 * tangent_0.
19289    /// Forward-mode tangent should be 2x the seed (1.0) → 2.0.
19290    #[test]
19291    fn custom_fn_jvp_overrides_natural_tangent() {
19292        use rlx_opt::autodiff_fwd::jvp;
19293        let s = Shape::new(&[1], DType::F32);
19294
19295        let mut fwd = Graph::new("id_fwd");
19296        let x = fwd.input("x", s.clone());
19297        fwd.set_outputs(vec![x]);
19298
19299        let mut jvp_g = Graph::new("id_jvp");
19300        let _x_p = jvp_g.input("x", s.clone());
19301        let tx = jvp_g.input("tangent_0", s.clone());
19302        let two_data: Vec<u8> = 2.0_f32.to_le_bytes().to_vec();
19303        let two = jvp_g.add_node(Op::Constant { data: two_data }, vec![], s.clone());
19304        let ty = jvp_g.binary(BinaryOp::Mul, tx, two, s.clone());
19305        jvp_g.set_outputs(vec![ty]);
19306
19307        let mut g = Graph::new("outer");
19308        let xin = g.input("x_in", s.clone());
19309        let cf = g.custom_fn(vec![xin], fwd, None, Some(jvp_g));
19310        g.set_outputs(vec![cf]);
19311
19312        let fwd_g = jvp(&g, &[xin]);
19313        assert_eq!(fwd_g.outputs.len(), 2, "expect [primal_y, tangent_y]");
19314
19315        let xb = find_named(&fwd_g, "x_in");
19316        let tan = find_named(&fwd_g, "tangent_x_in");
19317        let (sched, mut arena) = prepare(&fwd_g, &[(xb, &[7.0]), (tan, &[1.0])]);
19318        execute_thunks(&sched, arena.raw_buf_mut());
19319        let y = read_arena(&arena, fwd_g.outputs[0], 1);
19320        let ty_v = read_arena(&arena, fwd_g.outputs[1], 1);
19321        assert!((y[0] - 7.0).abs() < 1e-6);
19322        assert!(
19323            (ty_v[0] - 2.0).abs() < 1e-6,
19324            "jvp override should yield t_y=2.0 (natural autodiff would give 1.0), got {}",
19325            ty_v[0]
19326        );
19327    }
19328
19329    /// IR-level basic test: `DType::C64` is wired through the dtype
19330    /// table — `size_bytes() == 8`, `is_complex()` reports true, and
19331    /// a `[2]`-shaped C64 buffer in the arena occupies the expected
19332    /// 16 bytes.
19333    #[test]
19334    fn c64_dtype_storage_layout() {
19335        assert_eq!(
19336            DType::C64.size_bytes(),
19337            8,
19338            "C64 should be 8 bytes (f32 real + f32 imag)"
19339        );
19340        assert!(DType::C64.is_complex());
19341        assert!(!DType::C64.is_float());
19342
19343        // A length-2 C64 buffer should have shape size_bytes = 16.
19344        let s = Shape::new(&[2], DType::C64);
19345        assert_eq!(s.size_bytes().unwrap(), 16);
19346    }
19347
19348    // ── C64 element-wise binary kernel witnesses (2026-05-17) ──────
19349    //
    // Build a tiny graph: Input `a` + Input `b` (both C64 `[n]`, one
    // element per test pair), output = a OP b. Compile it with
    // `compile_thunks`, run it with `execute_thunks`, and compare against
    // the closed-form complex arithmetic for each chosen pair.
19353
19354    fn run_c64_binary(op: BinaryOp, a: &[(f32, f32)], b: &[(f32, f32)]) -> Vec<(f32, f32)> {
19355        let n = a.len();
19356        let s = Shape::new(&[n], DType::C64);
19357        let mut g = Graph::new("c64_bin");
19358        let in_a = g.input("a", s.clone());
19359        let in_b = g.input("b", s.clone());
19360        let out = g.binary(op, in_a, in_b, s.clone());
19361        g.set_outputs(vec![out]);
19362
19363        let plan = rlx_opt::memory::plan_memory(&g);
19364        let mut arena = crate::arena::Arena::from_plan(plan);
19365        let sched = compile_thunks(&g, &arena);
19366
19367        let a_off = arena.byte_offset(in_a);
19368        let b_off = arena.byte_offset(in_b);
19369        let out_off = arena.byte_offset(out);
19370        // Interleave [re_0, im_0, re_1, im_1, ...] in the f32 buffer.
19371        let buf = arena.raw_buf_mut();
19372        unsafe {
19373            let pa = buf.as_mut_ptr().add(a_off) as *mut f32;
19374            let pb = buf.as_mut_ptr().add(b_off) as *mut f32;
19375            for (i, &(re, im)) in a.iter().enumerate() {
19376                *pa.add(2 * i) = re;
19377                *pa.add(2 * i + 1) = im;
19378            }
19379            for (i, &(re, im)) in b.iter().enumerate() {
19380                *pb.add(2 * i) = re;
19381                *pb.add(2 * i + 1) = im;
19382            }
19383        }
19384        execute_thunks(&sched, arena.raw_buf_mut());
19385        let raw_out: Vec<f32> = unsafe {
19386            let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
19387            (0..(2 * n)).map(|i| *p.add(i)).collect()
19388        };
19389        (0..n)
19390            .map(|i| (raw_out[2 * i], raw_out[2 * i + 1]))
19391            .collect()
19392    }
19393
19394    #[track_caller]
19395    fn assert_close_c(got: (f32, f32), expected: (f32, f32), tol: f32, label: &str) {
19396        let dr = (got.0 - expected.0).abs();
19397        let di = (got.1 - expected.1).abs();
19398        assert!(
19399            dr < tol && di < tol,
19400            "[{label}] got ({:+.4}, {:+.4}), expected ({:+.4}, {:+.4})",
19401            got.0,
19402            got.1,
19403            expected.0,
19404            expected.1
19405        );
19406    }
19407
19408    #[test]
19409    fn c64_binary_add_matches_complex_arithmetic() {
19410        let a = [(1.0_f32, 2.0_f32), (3.0_f32, -1.0_f32)];
19411        let b = [(4.0_f32, -1.0_f32), (0.5_f32, 0.5_f32)];
19412        let out = run_c64_binary(BinaryOp::Add, &a, &b);
19413        assert_close_c(out[0], (5.0, 1.0), 1e-6, "add[0]");
19414        assert_close_c(out[1], (3.5, -0.5), 1e-6, "add[1]");
19415    }
19416
19417    #[test]
19418    fn c64_binary_sub_matches_complex_arithmetic() {
19419        let a = [(5.0_f32, 1.0_f32)];
19420        let b = [(2.0_f32, 3.0_f32)];
19421        let out = run_c64_binary(BinaryOp::Sub, &a, &b);
19422        assert_close_c(out[0], (3.0, -2.0), 1e-6, "sub");
19423    }
19424
19425    #[test]
19426    fn c64_binary_mul_matches_complex_arithmetic() {
19427        // (1 + 2i)(3 + 4i) = 3 + 4i + 6i + 8i² = -5 + 10i.
19428        let a = [(1.0_f32, 2.0_f32)];
19429        let b = [(3.0_f32, 4.0_f32)];
19430        let out = run_c64_binary(BinaryOp::Mul, &a, &b);
19431        assert_close_c(out[0], (-5.0, 10.0), 1e-5, "mul");
19432    }
19433
19434    #[test]
19435    fn c64_binary_div_matches_complex_arithmetic() {
19436        // (1 + 2i) / (3 + 4i) = ((1·3 + 2·4) + (2·3 − 1·4)i) / 25
19437        //                     = (11 + 2i) / 25
19438        //                     = 0.44 + 0.08i
19439        let a = [(1.0_f32, 2.0_f32)];
19440        let b = [(3.0_f32, 4.0_f32)];
19441        let out = run_c64_binary(BinaryOp::Div, &a, &b);
19442        assert_close_c(out[0], (0.44, 0.08), 1e-5, "div");
19443    }
19444
19445    #[test]
19446    fn c64_binary_mul_identity_one_is_no_op() {
19447        // (a + bi) · (1 + 0i) = a + bi.
19448        let a = [(3.5_f32, -1.25_f32), (-2.0_f32, 7.0_f32)];
19449        let b = [(1.0_f32, 0.0_f32), (1.0_f32, 0.0_f32)];
19450        let out = run_c64_binary(BinaryOp::Mul, &a, &b);
19451        assert_close_c(out[0], a[0], 1e-6, "mul·1[0]");
19452        assert_close_c(out[1], a[1], 1e-6, "mul·1[1]");
19453    }
19454
19455    #[test]
19456    fn c64_binary_mul_by_i_rotates_90_degrees() {
19457        // (a + bi) · i = (a + bi)(0 + i) = -b + ai. 90° CCW rotation.
19458        let a = [(1.0_f32, 0.0_f32)];
19459        let b = [(0.0_f32, 1.0_f32)];
19460        let out = run_c64_binary(BinaryOp::Mul, &a, &b);
19461        assert_close_c(out[0], (0.0, 1.0), 1e-6, "1·i");
19462    }
19463
19464    #[test]
19465    fn c64_binary_div_by_self_gives_unity() {
19466        let a = [(2.5_f32, -1.5_f32), (-0.7_f32, 4.2_f32)];
19467        let out = run_c64_binary(BinaryOp::Div, &a, &a);
19468        assert_close_c(out[0], (1.0, 0.0), 1e-5, "div_self[0]");
19469        assert_close_c(out[1], (1.0, 0.0), 1e-5, "div_self[1]");
19470    }
19471
19472    #[test]
19473    #[should_panic(expected = "C64: complex max/min/pow")]
19474    fn c64_binary_max_is_rejected_at_lowering() {
19475        run_c64_binary(BinaryOp::Max, &[(1.0_f32, 2.0_f32)], &[(3.0_f32, 4.0_f32)]);
19476    }
19477
19478    fn run_c64_activation(act: Activation, a: &[(f32, f32)]) -> Vec<(f32, f32)> {
19479        let n = a.len();
19480        let s = Shape::new(&[n], DType::C64);
19481        let mut g = Graph::new("c64_act");
19482        let in_a = g.input("a", s.clone());
19483        let out = g.activation(act, in_a, s.clone());
19484        g.set_outputs(vec![out]);
19485        let plan = rlx_opt::memory::plan_memory(&g);
19486        let mut arena = crate::arena::Arena::from_plan(plan);
19487        let sched = compile_thunks(&g, &arena);
19488        let a_off = arena.byte_offset(in_a);
19489        let out_off = arena.byte_offset(out);
19490        let buf = arena.raw_buf_mut();
19491        unsafe {
19492            let pa = buf.as_mut_ptr().add(a_off) as *mut f32;
19493            for (i, &(re, im)) in a.iter().enumerate() {
19494                *pa.add(2 * i) = re;
19495                *pa.add(2 * i + 1) = im;
19496            }
19497        }
19498        execute_thunks(&sched, arena.raw_buf_mut());
19499        let raw: Vec<f32> = unsafe {
19500            let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
19501            (0..(2 * n)).map(|i| *p.add(i)).collect()
19502        };
19503        (0..n).map(|i| (raw[2 * i], raw[2 * i + 1])).collect()
19504    }
19505
19506    #[test]
19507    fn c64_activation_neg_negates_both_components() {
19508        let inp = [(3.5_f32, -1.25_f32), (-2.0_f32, 0.0_f32)];
19509        let out = run_c64_activation(Activation::Neg, &inp);
19510        assert_close_c(out[0], (-3.5, 1.25), 1e-6, "neg[0]");
19511        assert_close_c(out[1], (2.0, 0.0), 1e-6, "neg[1]");
19512    }
19513
19514    #[test]
19515    fn c64_activation_exp_matches_euler() {
19516        // exp(0 + i·π) = -1 + 0i.
19517        // exp(1 + 0i) = e ≈ 2.71828.
19518        let inp = [(0.0_f32, std::f32::consts::PI), (1.0_f32, 0.0_f32)];
19519        let out = run_c64_activation(Activation::Exp, &inp);
19520        assert_close_c(out[0], (-1.0, 0.0), 1e-5, "exp(iπ)");
19521        assert_close_c(out[1], (std::f32::consts::E, 0.0), 1e-5, "exp(1)");
19522    }
19523
19524    #[test]
19525    fn c64_activation_log_matches_principal_branch() {
19526        // log(1 + 0i) = 0.
19527        // log(0 + i) = log(1) + i·π/2 = 0 + i·π/2.
19528        // log(-1 + 0i) = 0 + i·π.
19529        let inp = [(1.0_f32, 0.0_f32), (0.0_f32, 1.0_f32), (-1.0_f32, 0.0_f32)];
19530        let out = run_c64_activation(Activation::Log, &inp);
19531        assert_close_c(out[0], (0.0, 0.0), 1e-5, "log(1)");
19532        assert_close_c(out[1], (0.0, std::f32::consts::FRAC_PI_2), 1e-5, "log(i)");
19533        assert_close_c(out[2], (0.0, std::f32::consts::PI), 1e-5, "log(-1)");
19534    }
19535
19536    #[test]
19537    fn c64_activation_sqrt_squared_recovers_input() {
19538        // For positive-real-part inputs, sqrt(z)² should equal z exactly
19539        // to f32 noise.
19540        let inp = [(4.0_f32, 0.0_f32), (3.0_f32, 4.0_f32)];
19541        let roots = run_c64_activation(Activation::Sqrt, &inp);
19542        // sqrt(4) = 2 + 0i; sqrt(3+4i) = 2 + i (since (2+i)² = 4+4i-1 = 3+4i).
19543        assert_close_c(roots[0], (2.0, 0.0), 1e-5, "sqrt(4)");
19544        assert_close_c(roots[1], (2.0, 1.0), 1e-5, "sqrt(3+4i)");
19545    }
19546
19547    #[test]
19548    #[should_panic(expected = "no natural complex extension")]
19549    fn c64_activation_relu_is_rejected_at_lowering() {
19550        run_c64_activation(Activation::Relu, &[(1.0_f32, 2.0_f32)]);
19551    }
19552
19553    // ── ComplexNormSq + Wirtinger backward witnesses ───────────────
19554
19555    /// Forward `|z|²`: returns `[n]` f32.
19556    fn run_complex_norm_sq(z: &[(f32, f32)]) -> Vec<f32> {
19557        let n = z.len();
19558        let mut g = Graph::new("cns_fwd");
19559        let in_z = g.input("z", Shape::new(&[n], DType::C64));
19560        let out = g.complex_norm_sq(in_z);
19561        g.set_outputs(vec![out]);
19562        let plan = rlx_opt::memory::plan_memory(&g);
19563        let mut arena = crate::arena::Arena::from_plan(plan);
19564        let sched = compile_thunks(&g, &arena);
19565        let z_off = arena.byte_offset(in_z);
19566        let out_off = arena.byte_offset(out);
19567        let buf = arena.raw_buf_mut();
19568        unsafe {
19569            let pz = buf.as_mut_ptr().add(z_off) as *mut f32;
19570            for (i, &(re, im)) in z.iter().enumerate() {
19571                *pz.add(2 * i) = re;
19572                *pz.add(2 * i + 1) = im;
19573            }
19574        }
19575        execute_thunks(&sched, arena.raw_buf_mut());
19576        unsafe {
19577            let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
19578            (0..n).map(|i| *p.add(i)).collect()
19579        }
19580    }
19581
19582    /// Backward: given z and upstream g, return dz = g·z element-wise (C64).
19583    fn run_complex_norm_sq_bwd(z: &[(f32, f32)], g: &[f32]) -> Vec<(f32, f32)> {
19584        let n = z.len();
19585        let mut gr = Graph::new("cns_bwd");
19586        let in_z = gr.input("z", Shape::new(&[n], DType::C64));
19587        let in_g = gr.input("g", Shape::new(&[n], DType::F32));
19588        let out = gr.complex_norm_sq_backward(in_z, in_g);
19589        gr.set_outputs(vec![out]);
19590        let plan = rlx_opt::memory::plan_memory(&gr);
19591        let mut arena = crate::arena::Arena::from_plan(plan);
19592        let sched = compile_thunks(&gr, &arena);
19593        let z_off = arena.byte_offset(in_z);
19594        let g_off = arena.byte_offset(in_g);
19595        let out_off = arena.byte_offset(out);
19596        let buf = arena.raw_buf_mut();
19597        unsafe {
19598            let pz = buf.as_mut_ptr().add(z_off) as *mut f32;
19599            let pg = buf.as_mut_ptr().add(g_off) as *mut f32;
19600            for (i, &(re, im)) in z.iter().enumerate() {
19601                *pz.add(2 * i) = re;
19602                *pz.add(2 * i + 1) = im;
19603            }
19604            for (i, &v) in g.iter().enumerate() {
19605                *pg.add(i) = v;
19606            }
19607        }
19608        execute_thunks(&sched, arena.raw_buf_mut());
19609        unsafe {
19610            let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
19611            (0..n).map(|i| (*p.add(2 * i), *p.add(2 * i + 1))).collect()
19612        }
19613    }
19614
19615    #[test]
19616    fn complex_norm_sq_matches_textbook() {
19617        // |3 + 4i|² = 9 + 16 = 25.
19618        // |1 + 0i|² = 1.
19619        // |0 + 0i|² = 0.
19620        let z = [(3.0_f32, 4.0_f32), (1.0_f32, 0.0_f32), (0.0_f32, 0.0_f32)];
19621        let out = run_complex_norm_sq(&z);
19622        assert!((out[0] - 25.0).abs() < 1e-5);
19623        assert!((out[1] - 1.0).abs() < 1e-6);
19624        assert!(out[2].abs() < 1e-6);
19625    }
19626
19627    #[test]
19628    fn complex_norm_sq_backward_matches_wirtinger_formula() {
19629        // Wirtinger: ∂|z|²/∂z̄ = z. With upstream g = 1, dz = z.
19630        let z = [(3.0_f32, 4.0_f32), (1.5_f32, -2.5_f32)];
19631        let g = [1.0_f32, 1.0_f32];
19632        let dz = run_complex_norm_sq_bwd(&z, &g);
19633        assert_close_c(dz[0], z[0], 1e-6, "dz[0] = g·z[0]");
19634        assert_close_c(dz[1], z[1], 1e-6, "dz[1] = g·z[1]");
19635    }
19636
19637    #[test]
19638    fn complex_norm_sq_backward_scales_with_upstream() {
19639        // With upstream g[i] ≠ 1: dz[i] = g[i]·z[i].
19640        let z = [(2.0_f32, 1.0_f32), (-1.0_f32, 3.0_f32)];
19641        let g = [0.5_f32, -2.0_f32];
19642        let dz = run_complex_norm_sq_bwd(&z, &g);
19643        assert_close_c(dz[0], (1.0, 0.5), 1e-6, "g=0.5 · (2,1)");
19644        assert_close_c(dz[1], (2.0, -6.0), 1e-6, "g=-2 · (-1,3)");
19645    }
19646
19647    /// Multi-output Op::CustomFn via the concat-with-Narrow design
19648    /// (rlx-ir::Graph::custom_fn_multi). Build a custom_fn whose
19649    /// fwd_body returns two outputs (x², 2x), then materialize each
19650    /// via the MultiOutputHandle and verify both numerically.
19651    #[test]
19652    fn custom_fn_multi_extracts_each_subgraph_output() {
19653        use rlx_ir::ops::special::MultiOutputHandle;
19654
19655        let _ = MultiOutputHandle {
19656            source: NodeId(0),
19657            sub_shapes: vec![],
19658            offsets: vec![],
19659        }; // import sanity
19660
19661        // Inner body: input x [3] f32, outputs (x², 2x) both [3] f32.
19662        let mut body = Graph::new("multi_body");
19663        let s3 = Shape::new(&[3], DType::F32);
19664        let x = body.input("x", s3.clone());
19665        let x_sq = body.binary(BinaryOp::Mul, x, x, s3.clone());
19666        let two = body.add_node(
19667            Op::Constant {
19668                data: vec![
19669                    2.0_f32.to_le_bytes(),
19670                    2.0_f32.to_le_bytes(),
19671                    2.0_f32.to_le_bytes(),
19672                ]
19673                .into_iter()
19674                .flatten()
19675                .collect(),
19676            },
19677            vec![],
19678            s3.clone(),
19679        );
19680        let two_x = body.binary(BinaryOp::Mul, two, x, s3.clone());
19681        body.set_outputs(vec![x_sq, two_x]);
19682
19683        // Outer graph: feed in_x → custom_fn_multi → handle.output(0/1).
19684        let mut outer = Graph::new("multi_outer");
19685        let in_x = outer.input("xin", s3.clone());
19686        let handle = outer.custom_fn_multi(vec![in_x], body);
19687        assert_eq!(handle.n_outputs(), 2);
19688        let out0 = handle.output(&mut outer, 0); // x²
19689        let out1 = handle.output(&mut outer, 1); // 2x
19690        outer.set_outputs(vec![out0, out1]);
19691
19692        let plan = rlx_opt::memory::plan_memory(&outer);
19693        let mut arena = crate::arena::Arena::from_plan(plan);
19694        let sched = compile_thunks(&outer, &arena);
19695        let xin_off = arena.byte_offset(in_x);
19696        let out0_off = arena.byte_offset(out0);
19697        let out1_off = arena.byte_offset(out1);
19698        let xs = [1.0_f32, 2.0, 3.0];
19699        unsafe {
19700            let p = arena.raw_buf_mut().as_mut_ptr().add(xin_off) as *mut f32;
19701            for (i, &v) in xs.iter().enumerate() {
19702                *p.add(i) = v;
19703            }
19704        }
19705        execute_thunks(&sched, arena.raw_buf_mut());
19706        let out0_v: Vec<f32> = unsafe {
19707            let p = arena.raw_buf().as_ptr().add(out0_off) as *const f32;
19708            (0..3).map(|i| *p.add(i)).collect()
19709        };
19710        let out1_v: Vec<f32> = unsafe {
19711            let p = arena.raw_buf().as_ptr().add(out1_off) as *const f32;
19712            (0..3).map(|i| *p.add(i)).collect()
19713        };
19714        // x² = [1, 4, 9]; 2x = [2, 4, 6].
19715        for i in 0..3 {
19716            assert!(
19717                (out0_v[i] - xs[i] * xs[i]).abs() < 1e-5,
19718                "out0[{i}] = {} != x² = {}",
19719                out0_v[i],
19720                xs[i] * xs[i]
19721            );
19722            assert!(
19723                (out1_v[i] - 2.0 * xs[i]).abs() < 1e-5,
19724                "out1[{i}] = {} != 2x = {}",
19725                out1_v[i],
19726                2.0 * xs[i]
19727            );
19728        }
19729    }
19730
19731    #[test]
19732    fn complex_norm_sq_gradient_matches_finite_difference() {
19733        // Numerical sanity: perturb z[0].re by ε, observe Δ|z|² ≈ 2·re·ε.
19734        let z = [(3.0_f32, 4.0_f32)];
19735        let eps = 1e-3_f32;
19736        let v0 = run_complex_norm_sq(&z)[0];
19737        let z_pert = [(3.0_f32 + eps, 4.0_f32)];
19738        let v1 = run_complex_norm_sq(&z_pert)[0];
19739        let fd_re = (v1 - v0) / eps;
19740        let analytic_re = 2.0 * z[0].0;
19741        assert!((fd_re - analytic_re).abs() < 1e-2);
19742
19743        // ∂/∂im at z = (3, 4) is 2·im = 8.
19744        let z_pert_im = [(3.0_f32, 4.0_f32 + eps)];
19745        let v2 = run_complex_norm_sq(&z_pert_im)[0];
19746        let fd_im = (v2 - v0) / eps;
19747        let analytic_im = 2.0 * z[0].1;
19748        assert!((fd_im - analytic_im).abs() < 1e-2);
19749
19750        // Compare with the Wirtinger backward at upstream g = 1.
19751        // Wirtinger ∂/∂z̄ = z gives dz = (re, im). The "real
19752        // gradient" wrt (re, im) is 2·(re, im), i.e. 2·dz = (2·re,
19753        // 2·im) — that's the factor 2 difference between Wirtinger
19754        // ∂/∂z̄ and the real-vector gradient on (re, im).
19755        let dz = run_complex_norm_sq_bwd(&z, &[1.0_f32]);
19756        assert!((2.0 * dz[0].0 - analytic_re).abs() < 1e-5);
19757        assert!((2.0 * dz[0].1 - analytic_im).abs() < 1e-5);
19758    }
19759
19760    /// Direct regression test for the 5-D mid-shape singleton broadcast
19761    /// (SAM rel_pos pattern: `[bh, h, w, 1, w] + [bh, h, w, h, w]`).
19762    /// The SAM port worked around this by `concat`-tiling the rhs; this
19763    /// test verifies the in-graph broadcast path is bit-correct.
19764    #[test]
19765    fn binary_full_5d_mid_singleton_broadcast() {
19766        let bh = 2usize;
19767        let h = 3;
19768        let w = 4;
19769        let f = DType::F32;
19770
19771        let mut g = Graph::new("bcast_5d");
19772        let lhs = g.input("lhs", Shape::new(&[bh, h, w, h, w], f));
19773        // rhs shape with size-1 at axis 3 (mid-shape singleton).
19774        let rhs = g.input("rhs", Shape::new(&[bh, h, w, 1, w], f));
19775        let out = g.binary(BinaryOp::Add, lhs, rhs, Shape::new(&[bh, h, w, h, w], f));
19776        g.set_outputs(vec![out]);
19777
19778        // Deterministic data.
19779        let lhs_data: Vec<f32> = (0..bh * h * w * h * w).map(|i| i as f32 * 0.01).collect();
19780        let rhs_data: Vec<f32> = (0..bh * h * w * w)
19781            .map(|i| (i as f32 + 100.0) * 0.01)
19782            .collect();
19783
19784        // Compute expected output by hand.
19785        let mut expected = vec![0f32; bh * h * w * h * w];
19786        for b_ in 0..bh {
19787            for hq in 0..h {
19788                for wq in 0..w {
19789                    for hk in 0..h {
19790                        for wk in 0..w {
19791                            let li = (((b_ * h + hq) * w + wq) * h + hk) * w + wk;
19792                            // rhs has hk dim = 1, so it's always index 0 there.
19793                            let ri = ((b_ * h + hq) * w + wq) * w + wk;
19794                            expected[li] = lhs_data[li] + rhs_data[ri];
19795                        }
19796                    }
19797                }
19798            }
19799        }
19800
19801        let plan = rlx_opt::memory::plan_memory(&g);
19802        let mut arena = crate::arena::Arena::from_plan(plan);
19803        let sched = compile_thunks(&g, &arena);
19804        let lhs_off = arena.byte_offset(lhs);
19805        let rhs_off = arena.byte_offset(rhs);
19806        let out_off = arena.byte_offset(out);
19807        let buf = arena.raw_buf_mut();
19808        unsafe {
19809            let p = buf.as_mut_ptr().add(lhs_off) as *mut f32;
19810            for (i, &v) in lhs_data.iter().enumerate() {
19811                *p.add(i) = v;
19812            }
19813            let p = buf.as_mut_ptr().add(rhs_off) as *mut f32;
19814            for (i, &v) in rhs_data.iter().enumerate() {
19815                *p.add(i) = v;
19816            }
19817        }
19818        execute_thunks(&sched, arena.raw_buf_mut());
19819        let actual: Vec<f32> = unsafe {
19820            let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
19821            (0..bh * h * w * h * w).map(|i| *p.add(i)).collect()
19822        };
19823
19824        // Bit-exact check.
19825        let mut max_diff = 0f32;
19826        let mut max_idx = 0;
19827        for i in 0..actual.len() {
19828            let d = (actual[i] - expected[i]).abs();
19829            if d > max_diff {
19830                max_diff = d;
19831                max_idx = i;
19832            }
19833        }
19834        assert!(
19835            max_diff < 1e-6,
19836            "5D mid-shape singleton broadcast wrong: max |Δ| = {max_diff} at idx {max_idx} \
19837             (actual={}, expected={})",
19838            actual[max_idx],
19839            expected[max_idx]
19840        );
19841    }
19842
19843    #[test]
19844    fn layer_norm2d_and_conv_transpose2d_kernels() {
19845        let mut out = vec![0f32; 8];
19846        crate::kernels::layer_norm2d_nchw(
19847            &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
19848            &[1.0, 1.0],
19849            &[0.0, 0.0],
19850            &mut out,
19851            1,
19852            2,
19853            2,
19854            2,
19855            1e-5,
19856        );
19857        let mean0: f32 = (1.0 + 3.0) / 2.0;
19858        assert!((out[0] - mean0).abs() > 0.1);
19859
19860        let mut up = vec![0f32; 4];
19861        crate::kernels::conv_transpose2d_nchw(
19862            &[2.0],
19863            &[1.0, 0.0, 0.0, 1.0],
19864            &mut up,
19865            1,
19866            1,
19867            1,
19868            1,
19869            1,
19870            2,
19871            2,
19872            2,
19873            2,
19874            2,
19875            2,
19876            0,
19877            0,
19878            1,
19879            1,
19880            1,
19881        );
19882        assert!((up[0] - 2.0).abs() < 1e-5);
19883        assert!((up[3] - 2.0).abs() < 1e-5);
19884    }
19885}