Skip to main content

rlx_cpu/
thunk.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Thunks — pre-compiled kernel dispatch with zero per-call overhead.
17//!
18//! At compile time, the graph is lowered into a flat `Vec<Thunk>` where each
19//! thunk holds pre-computed arena offsets, dimensions, and kernel type.
20//! At runtime, the executor just iterates thunks and calls kernels directly.
21
22// Edition 2024: bodies of `unsafe fn` are safe by default; `sl`/`sl_mut` stay `unsafe fn`.
23#![allow(unsafe_op_in_unsafe_fn)]
24//! No match dispatch, no HashMap lookup, no dimension computation.
25
26use crate::arena::Arena;
27use crate::op_registry::CpuKernel;
28use rlx_ir::op::{Activation, BinaryOp, CmpOp, ReduceOp};
29use rlx_ir::{Graph, NodeId, Op, Shape};
30use std::collections::HashMap;
31use std::sync::Arc;
32
33/// A pre-compiled kernel call with all args resolved to arena offsets.
34#[derive(Clone)]
35pub enum Thunk {
36    /// Skip (Input/Param already in arena)
37    Nop,
38    /// C = A @ B (BLAS sgemm)
39    Sgemm {
40        a: usize,
41        b: usize,
42        c: usize,
43        m: u32,
44        k: u32,
45        n: u32,
46    },
47    /// f64 dense solve `x = A⁻¹·b` via LAPACK dgesv.
48    /// `a`, `b`, `x` are byte-offsets into the arena. `n` is the matrix
49    /// dimension; `nrhs` is 1 for a vector RHS or >1 for multi-RHS.
50    /// The kernel materializes scratch copies of A and b internally
51    /// (LAPACK overwrites both with LU factors and solution).
52    DenseSolveF64 {
53        a: usize,
54        b: usize,
55        x: usize,
56        n: u32,
57        nrhs: u32,
58    },
59    /// f32 twin of `DenseSolveF64`. Calls LAPACK `sgesv` (or the
60    /// no-blas Rust fallback). Same arena byte-offset contract.
61    DenseSolveF32 {
62        a: usize,
63        b: usize,
64        x: usize,
65        n: u32,
66        nrhs: u32,
67    },
68    /// Batched f64 dense solve. `a`, `b`, `x` are byte-offsets to
69    /// the leading slice; `batch` is the number of independent
70    /// systems. Per slice the kernel calls `dgesv(A_i, b_i, n, nrhs)`
71    /// — LAPACK has no batched dgesv on Accelerate, so we loop.
72    BatchedDenseSolveF64 {
73        a: usize,
74        b: usize,
75        x: usize,
76        batch: u32,
77        n: u32,
78        nrhs: u32,
79    },
80    /// Batched f32 dense solve — loop of `sgesv` per batch slice.
81    BatchedDenseSolveF32 {
82        a: usize,
83        b: usize,
84        x: usize,
85        batch: u32,
86        n: u32,
87        nrhs: u32,
88    },
89    /// Batched f64 matmul. Both inputs and output have a leading
90    /// batch axis of size `batch`. Per-batch independent dgemm:
91    /// `C[i] = A[i] @ B[i]` for `i in 0..batch`. Used by VJP rules
92    /// that emit per-batch outer products (e.g., BatchedDenseSolve
93    /// VJP). The unbatched `Dgemm` thunk handles the rank-2 case.
94    BatchedDgemmF64 {
95        a: usize,
96        b: usize,
97        c: usize,
98        batch: u32,
99        m: u32,
100        k: u32,
101        n: u32,
102    },
103    /// Batched f32 matmul — same loop-per-batch shape as
104    /// `BatchedDgemmF64` but calling `sgemm`. Needed for attention
105    /// patterns where both operands carry a batch dim (e.g. q@k^T
106    /// and attn@v in decomposed self-attention). The 2-D `Sgemm`
107    /// flatten trick is wrong in that case because it treats `b` as
108    /// a single shared RHS across every batch.
109    BatchedSgemm {
110        a: usize,
111        b: usize,
112        c: usize,
113        batch: u32,
114        m: u32,
115        k: u32,
116        n: u32,
117    },
118    /// C = A @ B via Accelerate cblas_dgemm. Mirror of `Sgemm` at f64.
119    Dgemm {
120        a: usize,
121        b: usize,
122        c: usize,
123        m: u32,
124        k: u32,
125        n: u32,
126    },
127    /// f64 N-D index walk used for both `Op::Transpose` and `Op::Expand`.
128    /// `in_strides` carries 0s on broadcast axes (Expand) or permuted
129    /// strides (Transpose). Mirror of `Thunk::Transpose` at f64.
130    TransposeF64 {
131        src: usize,
132        dst: usize,
133        in_total: u32,
134        out_dims: Vec<u32>,
135        in_strides: Vec<u32>,
136    },
137    /// f64 element-wise activation. Single-input, single-output. The
138    /// kernel always reads from `src` and writes to `dst`, so it works
139    /// whether or not the planner aliased the two slots.
140    ActivationF64 {
141        src: usize,
142        dst: usize,
143        len: u32,
144        kind: Activation,
145    },
146    /// Element-wise complex squared-magnitude: `|z|² = re² + im²`.
147    /// Reads the C64 input at `src` as `2·len` f32 ([re,im] pairs),
148    /// writes `len` f32 to `dst`.
149    ComplexNormSqF32 {
150        src: usize,
151        dst: usize,
152        /// Logical element count (number of complex values).
153        len: u32,
154    },
155    /// Wirtinger backward for [`ComplexNormSqF32`]: `dz = g · z` as
156    /// C64. Reads `z` at `2·len` f32 + `g` at `len` f32; writes
157    /// `2·len` f32 to `dz`.
158    ComplexNormSqBackwardF32 {
159        z: usize,
160        g: usize,
161        dz: usize,
162        len: u32,
163    },
164    /// Element-wise C64 conjugate: writes `[re_i, -im_i]` per element.
165    /// Layout matches the rest of C64 here ([re,im] interleaved f32).
166    ConjugateC64 { src: usize, dst: usize, len: u32 },
167    /// C64 element-wise activation. Only kinds with well-defined
168    /// complex extensions are supported: Neg, Exp, Log, Sqrt.
169    /// Everything else (Sigmoid, Tanh, Relu, Abs, Sin/Cos/Tan/Atan,
170    /// Round, GeLU family) is rejected at lowering — those don't have
171    /// single natural complex definitions. `len` is the **complex
172    /// element count** (the f32 buffer holds `2·len` floats).
173    ActivationC64 {
174        src: usize,
175        dst: usize,
176        len: u32,
177        kind: Activation,
178    },
179    /// f64 contiguous reduction along a single axis range. Layout
180    /// `[outer, reduced, inner]` in memory; output is `[outer, inner]`.
181    /// Sum only for now (Mean composes via 1/N multiply post-pass).
182    ReduceSumF64 {
183        src: usize,
184        dst: usize,
185        outer: u32,
186        reduced: u32,
187        inner: u32,
188    },
189    /// f64 plain copy (Reshape / Cast at the same dtype). Mirrors `Copy`
190    /// but at 8 bytes per element.
191    CopyF64 { src: usize, dst: usize, len: u32 },
192    /// f64 element-wise binary with broadcast. `len`/`lhs_len`/`rhs_len`
193    /// are element counts; kernel does `out[i] = lhs[i % lhs_len] OP rhs[i % rhs_len]`.
194    /// Mirror of `BinaryFull` at 8 bytes per element.
195    BinaryFullF64 {
196        lhs: usize,
197        rhs: usize,
198        dst: usize,
199        len: u32,
200        lhs_len: u32,
201        rhs_len: u32,
202        op: BinaryOp,
203        /// Output shape dims (row-major). Empty in the fast path. See
204        /// `BinaryFull` doc for the broadcast convention.
205        out_dims_bcast: Vec<u32>,
206        bcast_lhs_strides: Vec<u32>,
207        bcast_rhs_strides: Vec<u32>,
208    },
209    /// f64 concat — byte-for-byte mirror of `Concat` but copies
210    /// 8 bytes per element. Element-counted offsets/strides match
211    /// the f32 variant; the executor scales by elem_size internally.
212    ConcatF64 {
213        dst: usize,
214        outer: u32,
215        inner: u32,
216        total_axis: u32,
217        inputs: Vec<(usize, u32)>,
218    },
219    /// C64 element-wise binary with broadcast. Same `len` /
220    /// `lhs_len` / `rhs_len` semantics as `BinaryFull` but each
221    /// "element" is one complex value (8 bytes = `[re, im]` as two
222    /// f32s). The executor reads the underlying f32 buffer at
223    /// `2·len` floats and walks element pairs. Supports Add / Sub /
224    /// Mul / Div; Max / Min / Pow have no single natural complex
225    /// definition and panic at lowering.
226    BinaryFullC64 {
227        lhs: usize,
228        rhs: usize,
229        dst: usize,
230        /// Complex element count (NOT f32 count). f32 buffer length
231        /// is `2·len`.
232        len: u32,
233        lhs_len: u32,
234        rhs_len: u32,
235        op: BinaryOp,
236        out_dims_bcast: Vec<u32>,
237        bcast_lhs_strides: Vec<u32>,
238        bcast_rhs_strides: Vec<u32>,
239    },
240    /// Bounded scan. Holds a recursively-compiled body schedule + a
241    /// pre-initialized body arena snapshot (constants filled). Each
242    /// outer execution clones the snapshot, copies the carry-in slot
243    /// from the outer arena, runs the body schedule `length` times,
244    /// then writes the final carry to the outer arena.
245    ///
246    /// Single-carry MVP — body has exactly one Input and one output,
247    /// both same shape and dtype.
248    Scan {
249        body: Arc<ThunkSchedule>,
250        body_init: Arc<Vec<u8>>, // pristine body arena bytes
251        body_input_off: usize,   // byte offset of the body's carry-Input slot
252        body_output_off: usize,  // byte offset of the body's output slot
253        outer_init_off: usize,   // outer-arena offset of the initial carry
254        outer_final_off: usize,  // outer-arena offset of the final carry / trajectory base
255        length: u32,
256        carry_bytes: u32, // carry size in bytes
257        /// When true, write each step's carry to the outer arena at
258        /// offset `outer_final_off + t * carry_bytes`, producing a
259        /// `[length, *carry]` stacked trajectory. When false, only the
260        /// final carry lands at `outer_final_off`.
261        save_trajectory: bool,
262        /// Per-step `xs` inputs. For each: (body_x_input_off,
263        /// outer_xs_base_off, per_step_bytes). Per iteration `t`, the
264        /// executor copies `outer_xs_base_off + t * per_step_bytes`
265        /// into `body_x_input_off`. Empty when the scan has no xs.
266        xs_inputs: Arc<Vec<(usize, usize, u32)>>,
267        /// Broadcast inputs — values constant across iterations. For
268        /// each: (body_bcast_input_off, outer_bcast_off, total_bytes).
269        /// Filled into `body_buf` ONCE before the scan loop starts
270        /// (xs in contrast are re-filled every iteration). Empty when
271        /// the scan has no bcasts.
272        bcast_inputs: Arc<Vec<(usize, usize, u32)>>,
273        /// Number of trajectory checkpoints (when `save_trajectory`).
274        /// `0` or `length` ⇒ save every iteration. Otherwise save only
275        /// `K` rows at indices `floor((k+1) * length / K) - 1` for
276        /// `k in 0..K`. Last index is always `length-1` so the final
277        /// carry is always cached.
278        num_checkpoints: u32,
279    },
280
281    /// Reverse-mode AD companion to `Thunk::Scan`. Walks `t = length-1
282    /// .. 0`, threading `dcarry` through the body's VJP. Per iteration:
283    /// writes `carry_t` (from outer init or trajectory), each `xs_i[t]`
284    /// slice, and the current `dcarry` into the body_vjp's Input
285    /// slots, runs body_vjp, reads new `dcarry` from its single output.
286    /// f64 carry only — the upstream-accumulation step in trajectory
287    /// mode does an element-wise f64 add.
288    ScanBackward {
289        body_vjp: Arc<ThunkSchedule>,
290        body_init: Arc<Vec<u8>>,
291        body_carry_in_off: usize, // body_vjp's mirrored body-carry-input slot
292        body_x_offs: Arc<Vec<usize>>, // body_vjp's mirrored x_t_i Input slots, in xs order
293        body_d_output_off: usize, // body_vjp's "d_output" Input slot
294        body_dcarry_out_off: usize, // body_vjp's gradient output
295        outer_init_off: usize,    // original init carry
296        outer_traj_off: usize,    // [length-or-K, *carry] trajectory base
297        outer_upstream_off: usize, // upstream gradient (carry shape, or [length, *carry])
298        /// Per-xs entries: (outer_xs_base_off, per_step_bytes). Read
299        /// `xs_i[t]` from `outer_xs_base_off + t * per_step_bytes`.
300        outer_xs_offs: Arc<Vec<(usize, u32)>>,
301        outer_dinit_off: usize, // output: dinit
302        length: u32,
303        carry_bytes: u32,
304        /// Bytes per element in the carry tensor: 4 for f32, 8 for f64.
305        /// Used to dispatch the trajectory-mode upstream accumulation
306        /// kernel (the dcarry += upstream\[t\] add must use the right
307        /// floating-point type — a hard-coded f64 add silently does
308        /// nothing for an f32 carry whose `cb` isn't divisible by 8).
309        carry_elem_size: u32,
310        save_trajectory: bool, // true → upstream is per-step; false → just final
311        /// Recursive checkpointing config. `0` or `length` ⇒ full
312        /// trajectory cached, no recompute (existing behavior).
313        /// `0 < K < length` ⇒ trajectory has only K rows; the executor
314        /// recomputes intermediate carries via `forward_body` between
315        /// checkpoints. Memory: O(K · carry_bytes); time: O(length).
316        num_checkpoints: u32,
317        /// Forward body schedule (same compiled body as the forward
318        /// Op::Scan), used for recompute when `num_checkpoints` is
319        /// active. `None` for the All strategy.
320        forward_body: Option<Arc<ThunkSchedule>>,
321        /// Pristine forward body arena bytes (constants filled).
322        forward_body_init: Option<Arc<Vec<u8>>>,
323        /// Forward body's carry-Input and output slot offsets — needed
324        /// to seed/read the body during recompute.
325        forward_body_carry_in_off: usize,
326        forward_body_output_off: usize,
327        /// Forward body's per-step xs Input slots (one per outer xs).
328        /// Same indexing convention as `body_x_offs`.
329        forward_body_x_offs: Arc<Vec<usize>>,
330    },
331
332    /// Companion to `ScanBackward` that materializes one stacked
333    /// `dxs_i`. Same backward loop; per iteration, after running
334    /// body_vjp, copies its `body_dxs_out_off` slot into the outer
335    /// arena at `outer_dxs_off + t * per_step_bytes`. dcarry threading
336    /// is identical — we still need it for the body_vjp recurrence
337    /// even though we don't write it back to the outer arena.
338    ScanBackwardXs {
339        body_vjp: Arc<ThunkSchedule>,
340        body_init: Arc<Vec<u8>>,
341        body_carry_in_off: usize,
342        body_x_offs: Arc<Vec<usize>>,
343        body_d_output_off: usize,
344        body_dcarry_out_off: usize,
345        body_dxs_out_off: usize, // the body_vjp output we extract per step
346        outer_init_off: usize,
347        outer_traj_off: usize,
348        outer_upstream_off: usize,
349        outer_xs_offs: Arc<Vec<(usize, u32)>>,
350        outer_dxs_off: usize, // base of the stacked [length, *per_step] output
351        length: u32,
352        carry_bytes: u32,
353        /// Same role as `Thunk::ScanBackward::carry_elem_size`.
354        carry_elem_size: u32,
355        per_step_bytes: u32, // bytes per row of the dxs output
356        save_trajectory: bool,
357        /// Recursive checkpointing config. Same semantics as
358        /// `Thunk::ScanBackward::num_checkpoints` — `0` or `length`
359        /// means "save every step's carry"; `0 < K < length` means
360        /// the trajectory has only K rows and the executor recomputes
361        /// intermediate carries via `forward_body` (which must be
362        /// `Some`). Implemented via segment-cached recompute,
363        /// mirroring the `ScanBackward` path.
364        num_checkpoints: u32,
365        forward_body: Option<Arc<ThunkSchedule>>,
366        forward_body_init: Option<Arc<Vec<u8>>>,
367        forward_body_carry_in_off: usize,
368        forward_body_output_off: usize,
369        forward_body_x_offs: Arc<Vec<usize>>,
370    },
371    /// User-defined sub-graph (`Op::CustomFn`) — runs `fwd_body` once.
372    /// Per execution: clone `body_init`, copy each primal input from the
373    /// outer arena into its body Input slot, run the body schedule,
374    /// copy the body's single output back to the outer arena.
375    CustomFn {
376        body: Arc<ThunkSchedule>,
377        body_init: Arc<Vec<u8>>,
378        /// Per primal input: (body_input_off, outer_input_off, bytes).
379        inputs: Arc<Vec<(usize, usize, u32)>>,
380        body_output_off: usize,
381        outer_output_off: usize,
382        out_bytes: u32,
383    },
384    /// C = A @ B; C += bias; C = act(C)
385    FusedMmBiasAct {
386        a: usize,
387        w: usize,
388        bias: usize,
389        c: usize,
390        m: u32,
391        k: u32,
392        n: u32,
393        act: Option<Activation>,
394    },
395    /// out = LN(x + residual + bias, gamma, beta)
396    FusedResidualLN {
397        x: usize,
398        res: usize,
399        bias: usize,
400        g: usize,
401        b: usize,
402        out: usize,
403        rows: u32,
404        h: u32,
405        eps: f32,
406        has_bias: bool,
407    },
408    /// out = RmsNorm(x + residual + bias, gamma, beta)
409    FusedResidualRmsNorm {
410        x: usize,
411        res: usize,
412        bias: usize,
413        g: usize,
414        b: usize,
415        out: usize,
416        rows: u32,
417        h: u32,
418        eps: f32,
419        has_bias: bool,
420    },
421    /// out = bias_add(data, bias, m, n) for Binary::Add with broadcast
422    BiasAdd {
423        src: usize,
424        bias: usize,
425        dst: usize,
426        m: u32,
427        n: u32,
428    },
429    /// Element-wise binary op with NumPy-style broadcast.
430    ///
431    /// Fast path (`lhs_len == rhs_len == len`): plain element-wise loop,
432    /// SIMD-vectorized on aarch64 for `Add`/`Mul`. `bcast_*` fields
433    /// are unused.
434    ///
435    /// Broadcast path: uses `out_dims_bcast` + `bcast_lhs_strides` +
436    /// `bcast_rhs_strides` to compute per-cell indices into each
437    /// operand. The strides are precomputed at thunk-construction
438    /// time from the operands' true shapes (with stride 0 on any axis
439    /// where the operand has size 1). This is the only correct way
440    /// to handle bidirectional broadcasts like `[N, 1] op [1, S]
441    /// → [N, S]`, which simple `i % lhs_len` modulo indexing maps to
442    /// wrong cells.
443    BinaryFull {
444        lhs: usize,
445        rhs: usize,
446        dst: usize,
447        len: u32,
448        lhs_len: u32,
449        rhs_len: u32,
450        op: BinaryOp,
451        /// Output shape dims (row-major). Empty in the fast path.
452        out_dims_bcast: Vec<u32>,
453        /// Per-dim stride into `lhs` (0 where lhs broadcasts).
454        bcast_lhs_strides: Vec<u32>,
455        /// Per-dim stride into `rhs`.
456        bcast_rhs_strides: Vec<u32>,
457    },
458    /// Activation in-place
459    ActivationInPlace {
460        data: usize,
461        len: u32,
462        act: Activation,
463    },
464    /// Gather axis=0: table\[idx\] → out
465    Gather {
466        table: usize,
467        table_len: u32,
468        idx: usize,
469        dst: usize,
470        num_idx: u32,
471        trailing: u32,
472    },
473    /// Narrow: copy slice (`elem_bytes` = source element size: 4 for f32, 8 for f64).
474    Narrow {
475        src: usize,
476        dst: usize,
477        outer: u32,
478        src_stride: u32,
479        dst_stride: u32,
480        inner: u32,
481        elem_bytes: u8,
482    },
483    /// Copy (reshape, expand)
484    Copy { src: usize, dst: usize, len: u32 },
485    /// LayerNorm standalone
486    LayerNorm {
487        src: usize,
488        g: usize,
489        b: usize,
490        dst: usize,
491        rows: u32,
492        h: u32,
493        eps: f32,
494    },
495    /// GroupNorm on NCHW `[N,C,H,W]`.
496    GroupNorm {
497        src: usize,
498        g: usize,
499        b: usize,
500        dst: usize,
501        n: u32,
502        c: u32,
503        h: u32,
504        w: u32,
505        num_groups: u32,
506        eps: f32,
507    },
508    /// LayerNorm2d on NCHW (SAM / candle semantics).
509    LayerNorm2d {
510        src: usize,
511        g: usize,
512        b: usize,
513        dst: usize,
514        n: u32,
515        c: u32,
516        h: u32,
517        w: u32,
518        eps: f32,
519    },
520    /// ConvTranspose2d on NCHW.
521    ConvTranspose2d {
522        src: usize,
523        weight: usize,
524        dst: usize,
525        n: u32,
526        c_in: u32,
527        h: u32,
528        w_in: u32,
529        c_out: u32,
530        h_out: u32,
531        w_out: u32,
532        kh: u32,
533        kw: u32,
534        sh: u32,
535        sw: u32,
536        ph: u32,
537        pw: u32,
538        dh: u32,
539        dw: u32,
540        groups: u32,
541    },
542    /// Nearest 2× upsample on NCHW (per-batch slice).
543    ResizeNearest2x {
544        src: usize,
545        dst: usize,
546        n: u32,
547        c: u32,
548        h: u32,
549        w: u32,
550    },
551    /// SAM2 axial 2-D RoPE on `[batch, seq, num_heads * head_dim]`.
552    AxialRope2d {
553        src: usize,
554        dst: usize,
555        batch: u32,
556        seq: u32,
557        hidden: u32,
558        end_x: u32,
559        end_y: u32,
560        head_dim: u32,
561        num_heads: u32,
562        theta: f32,
563        repeat_factor: u32,
564    },
565    /// RMSNorm: out = (x / sqrt(mean(x^2) + eps)) * gamma + beta. No mean
566    /// subtraction, hence cheaper than LayerNorm. Used by Llama-class models.
567    RmsNorm {
568        src: usize,
569        g: usize,
570        b: usize,
571        dst: usize,
572        rows: u32,
573        h: u32,
574        eps: f32,
575    },
576    /// Softmax
577    Softmax { data: usize, rows: u32, cols: u32 },
578    /// Inclusive (or exclusive) cumulative sum along the last axis
579    /// (callers pre-flatten higher-dim cumsums via reshape views).
580    Cumsum {
581        src: usize,
582        dst: usize,
583        rows: u32,
584        cols: u32,
585        exclusive: bool,
586    },
587    /// Mamba-style selective scan (plan #15).
588    /// Inputs: x, delta \[b,s,h\], a \[h,n\], b \[b,s,n\], c \[b,s,n\].
589    /// Output: y \[b,s,h\]. State h carries through the seq.
590    SelectiveScan {
591        x: usize,
592        delta: usize,
593        a: usize,
594        b: usize,
595        c: usize,
596        dst: usize,
597        batch: u32,
598        seq: u32,
599        hidden: u32,
600        state_size: u32,
601    },
602
603    /// Gated DeltaNet linear-attention scan (Qwen3.5/3.6 trunk).
604    /// Inputs: q, k, v `[b, s, h, n]`; g, beta `[b, s, h]`. Output:
605    /// `[b, s, h, n]`. See `Op::GatedDeltaNet` for math.
606    GatedDeltaNet {
607        q: usize,
608        k: usize,
609        v: usize,
610        g: usize,
611        beta: usize,
612        /// When non-zero, load initial `[b, h, n, n]` state and write
613        /// the final state back in place after the scan.
614        state: usize,
615        dst: usize,
616        batch: u32,
617        seq: u32,
618        heads: u32,
619        state_size: u32,
620    },
621
622    /// 1×1 conv fast path (plan #26). The general Conv2D thunk
623    /// runs the textbook 7-deep loop; a 1×1 stride-1 padding-0
624    /// groups-1 conv is mathematically a per-batch matmul, and
625    /// dispatching it through BLAS is 3-10× faster than the
626    /// scalar nest. Common case: ViT patch-projection follow-on,
627    /// transformer "expert" reductions in some MoE designs.
628    ///
629    /// Per batch: weight `[c_out, c_in]` × input `[c_in, h*w]`
630    ///         = output `[c_out, h*w]`.
631    Conv2D1x1 {
632        src: usize,
633        weight: usize,
634        dst: usize,
635        n: u32,
636        c_in: u32,
637        c_out: u32,
638        hw: u32,
639    },
640
641    /// Fused dequant + matmul (plan #5). Today supports
642    /// `QuantScheme::Int8Block` (symmetric); other schemes panic
643    /// at lowering time with a clear message until kernels are added.
644    DequantMatMul {
645        x: usize,
646        w_q: usize,   // packed i8 bytes for Int8 schemes
647        scale: usize, // [k/block, n] f32 scale
648        zp: usize,    // [k/block, n] f32 zero-point (0 for sym)
649        dst: usize,
650        m: u32,
651        k: u32,
652        n: u32,
653        block_size: u32,
654        is_asymmetric: bool,
655    },
656
657    /// GGUF-format dequant + matmul. Weight is a packed byte tensor
658    /// in one of the K-quant super-block layouts (Q4_K, Q5_K, Q6_K,
659    /// Q8_K). Scales / mins live inside the packed bytes — no
660    /// side-channel scale tensor.
661    ///
662    /// Today this is a "dequant-to-scratch then sgemm" kernel — it
663    /// keeps the *arena* memory footprint down (weights stay packed)
664    /// but the dequant itself happens per matmul. A future fully
665    /// fused tile-streaming kernel would close the compute gap.
666    DequantMatMulGguf {
667        x: usize,   // f32 activations [m, k]
668        w_q: usize, // packed weight bytes (k*n elements packed)
669        dst: usize, // f32 output [m, n]
670        m: u32,
671        k: u32,
672        n: u32,
673        scheme: rlx_ir::quant::QuantScheme,
674    },
675
676    /// Int4 block dequant + matmul (packed nibbles, side scale/zp).
677    DequantMatMulInt4 {
678        x: usize,
679        w_q: usize,
680        scale: usize,
681        zp: usize,
682        dst: usize,
683        m: u32,
684        k: u32,
685        n: u32,
686        block_size: u32,
687        is_asymmetric: bool,
688    },
689
690    /// FP8 dequant + matmul (per-tensor or per-column scale).
691    DequantMatMulFp8 {
692        x: usize,
693        w_q: usize,
694        scale: usize,
695        dst: usize,
696        m: u32,
697        k: u32,
698        n: u32,
699        e5m2: bool,
700    },
701
702    /// NVFP4 (E2M1) block dequant + matmul — 16-wide groups, FP8 scales.
703    DequantMatMulNvfp4 {
704        x: usize,
705        w_q: usize,
706        scale: usize,
707        global_scale: usize,
708        dst: usize,
709        m: u32,
710        k: u32,
711        n: u32,
712    },
713
714    /// Fused LoRA matmul (plan #9): out = x·W + scale * (x·A)·B.
715    /// `r` is the LoRA rank (typically 4-64) — the rank-r
716    /// intermediate `x·A` lives in scratch, never on the arena.
717    LoraMatMul {
718        x: usize,
719        w: usize,
720        a: usize,
721        b: usize,
722        dst: usize,
723        m: u32,
724        k: u32,
725        n: u32,
726        r: u32,
727        scale: f32,
728    },
729    /// Fused sample: logits [batch, vocab] → token ids \[batch\].
730    /// See Op::Sample. Output values are f32-encoded usize indices
731    /// (matches the rest of the IR's "ids as f32" convention).
732    Sample {
733        logits: usize,
734        dst: usize,
735        batch: u32,
736        vocab: u32,
737        top_k: u32,       // 0 = disabled
738        top_p: f32,       // 1.0 = disabled
739        temperature: f32, // 1.0 = neutral
740        seed: u64,
741    },
742    /// Attention SDPA. `mask` is the offset of the optional mask tensor
743    /// (only meaningful when `mask_kind == MaskKind::Custom`); other
744    /// kinds synthesize the mask in-kernel.
745    ///
746    /// Q/K/V each carry a `_row_stride` (elements per source row).
747    /// Defaults to `heads * head_dim` — matches the standalone
748    /// "Q/K/V are their own contiguous buffers" case. The Narrow→
749    /// Attention fusion below rewrites these to the parent QKV stride
750    /// (typically `3 * heads * head_dim`) so the kernel reads QKV
751    /// directly without materializing the per-head buffers (plan #46).
752    Attention {
753        q: usize,
754        k: usize,
755        v: usize,
756        mask: usize,
757        out: usize,
758        batch: u32,
759        /// Query sequence length.
760        seq: u32,
761        /// Key/value sequence length. Differs from `seq` during cached decode.
762        kv_seq: u32,
763        heads: u32,
764        head_dim: u32,
765        mask_kind: rlx_ir::op::MaskKind,
766        q_row_stride: u32,
767        k_row_stride: u32,
768        v_row_stride: u32,
769        /// Memory layout flag. `false` (the historical default) →
770        /// `[B, S, H, D]` row-major: per-head offset is
771        /// `bi*S*H*D + si*H*D + hi*D`. `true` → `[B, H, S, D]`
772        /// (head-major), matching the convention used by rlx-cuda /
773        /// rlx-rocm / rlx-tpu: per-head offset is
774        /// `bi*H*S*D + hi*S*D + si*D`. Detected at lowering time
775        /// from the input shape vs `num_heads` / `head_dim`.
776        bhsd: bool,
777    },
778    /// [`Op::AttentionBackward`] — emits dQ, dK, or dV (see `wrt`).
779    AttentionBackward {
780        q: usize,
781        k: usize,
782        v: usize,
783        dy: usize,
784        mask: usize,
785        out: usize,
786        batch: u32,
787        seq: u32,
788        kv_seq: u32,
789        heads: u32,
790        head_dim: u32,
791        mask_kind: rlx_ir::op::MaskKind,
792        wrt: rlx_ir::op::AttentionBwdWrt,
793        bhsd: bool,
794    },
795    /// RoPE (rotary position embeddings).
796    /// `src_row_stride` is elements per source row (defaults to `hidden`
797    /// for the standalone case; set to `qkv_axis * inner` when the
798    /// thunk fusion pass below rewires Rope to read directly from the
799    /// fused QKV buffer — plan #45).
800    Rope {
801        src: usize,
802        cos: usize,
803        sin: usize,
804        dst: usize,
805        batch: u32,
806        seq: u32,
807        hidden: u32,
808        head_dim: u32,
809        n_rot: u32,
810        cos_len: u32,
811        src_row_stride: u32,
812    },
813    /// Fused attention block: QKV proj → split → \[RoPE\] → SDPA → output proj.
814    /// All intermediates stay in L1 cache. Zero arena writes between ops.
815    FusedAttnBlock {
816        hidden: usize,
817        qkv_w: usize,
818        out_w: usize,
819        mask: usize,
820        out: usize,
821        qkv_b: usize,
822        out_b: usize, // 0 = no bias
823        cos: usize,
824        sin: usize,
825        cos_len: u32, // 0 = no RoPE
826        batch: u32,
827        seq: u32,
828        hs: u32,
829        nh: u32,
830        dh: u32,
831        has_bias: bool,
832        has_rope: bool,
833    },
834    /// Fused ENTIRE transformer layer: attention + residual + LN + FFN + residual + LN.
835    /// Combines ~10 thunks into 1. All intermediates on stack. Zero arena traffic.
836    FusedBertLayer {
837        // attention
838        hidden: usize,
839        qkv_w: usize,
840        qkv_b: usize,
841        out_w: usize,
842        out_b: usize,
843        mask: usize,
844        // LN1
845        ln1_g: usize,
846        ln1_b: usize,
847        eps1: f32,
848        // FFN (GELU)
849        fc1_w: usize,
850        fc1_b: usize,
851        fc2_w: usize,
852        fc2_b: usize,
853        // LN2
854        ln2_g: usize,
855        ln2_b: usize,
856        eps2: f32,
857        // output
858        out: usize,
859        // dims
860        batch: u32,
861        seq: u32,
862        hs: u32,
863        nh: u32,
864        dh: u32,
865        int_dim: u32,
866    },
867    /// Fused Nomic transformer layer: attention+RoPE + residual + LN + SwiGLU FFN + residual + LN.
868    FusedNomicLayer {
869        hidden: usize,
870        qkv_w: usize,
871        out_w: usize,
872        mask: usize,
873        cos: usize,
874        sin: usize,
875        cos_len: u32,
876        ln1_g: usize,
877        ln1_b: usize,
878        eps1: f32,
879        fc11_w: usize,
880        fc12_w: usize,
881        fc2_w: usize,
882        ln2_g: usize,
883        ln2_b: usize,
884        eps2: f32,
885        out: usize,
886        batch: u32,
887        seq: u32,
888        hs: u32,
889        nh: u32,
890        dh: u32,
891        int_dim: u32,
892    },
893    /// Fused SwiGLU: out\[r,i\] = x\[r,i\] * silu(x[r, n_half+i]).
894    /// Input: [outer, 2*n_half] — concatenated up||gate per row.
895    /// Output: [outer, n_half].
896    FusedSwiGLU {
897        src: usize,
898        dst: usize,
899        n_half: u32,
900        total: u32,
901        gate_first: bool,
902    },
903    /// Concat along an axis: output[outer, axis, inner] = inputs concatenated.
904    /// Each entry of `inputs` is (src_offset, axis_len_for_that_input) in u32
905    /// elements. `outer`, `inner`, and `total_axis_len` are pre-computed
906    /// at compile time to avoid per-run shape work.
907    Concat {
908        dst: usize,
909        outer: u32,
910        inner: u32,
911        total_axis: u32,
912        inputs: Vec<(usize, u32)>,
913    },
914    /// Element-wise comparison: out = (lhs CMP rhs) ? 1.0 : 0.0
915    Compare {
916        lhs: usize,
917        rhs: usize,
918        dst: usize,
919        len: u32,
920        op: CmpOp,
921    },
922    /// Reduction along a contiguous range of axes. Input layout (after
923    /// shape decomposition) is `[outer, reduced, inner]`; output is
924    /// `[outer, inner]`. The single-axis cases (axis=0 → outer=1;
925    /// axis=last → inner=1) and contiguous multi-axis (e.g. reduce over
926    /// [0, 1] of an [N, C, H, W] tensor → outer=1, reduced=N*C, inner=H*W)
927    /// all map onto this triplet. Non-contiguous axes are not supported
928    /// and bail to Nop in the compile pass.
929    Reduce {
930        src: usize,
931        dst: usize,
932        outer: u32,
933        reduced: u32,
934        inner: u32,
935        op: ReduceOp,
936    },
937    /// Top-K **indices** along the last axis. Input shape `[outer, axis_dim]`,
938    /// output `[outer, k]` of f32-encoded i64 indices. Ties broken by
939    /// smaller index. Used by MoE gating + beam search.
940    TopK {
941        src: usize,
942        dst: usize,
943        outer: u32,
944        axis_dim: u32,
945        k: u32,
946    },
947    /// Indexed batched matmul: out\[i\] = input\[i\] @ weight[expert_idx\[i\]].
948    /// Naive impl per token; for real MoE workloads, sort-by-expert + run
949    /// segmented GEMM would amortize. Done when there's a workload.
950    GroupedMatMul {
951        input: usize,
952        weight: usize,
953        expert_idx: usize,
954        dst: usize,
955        m: u32,
956        k_dim: u32,
957        n: u32,
958        num_experts: u32,
959    },
960    /// GGUF K-quant packed expert stack + grouped matmul (MoE FFN).
961    DequantGroupedMatMulGguf {
962        input: usize,
963        w_q: usize,
964        expert_idx: usize,
965        dst: usize,
966        m: u32,
967        k_dim: u32,
968        n: u32,
969        num_experts: u32,
970        scheme: rlx_ir::quant::QuantScheme,
971    },
972    /// Materialize packed MoE weights to F32 `[E, K, N]` (autodiff helper).
973    DequantMoEWeightsGguf {
974        w_q: usize,
975        dst: usize,
976        k_dim: u32,
977        n: u32,
978        num_experts: u32,
979        scheme: rlx_ir::quant::QuantScheme,
980    },
981    /// Scatter-add: dst[indices\[i\] * trailing + j] += updates[i * trailing + j].
982    /// Output is zeroed first; multiple updates to the same row accumulate.
983    ScatterAdd {
984        updates: usize,
985        indices: usize,
986        dst: usize,
987        num_updates: u32,
988        out_dim: u32,
989        trailing: u32,
990    },
991    /// Ternary select: out = cond != 0 ? on_true : on_false
992    Where {
993        cond: usize,
994        on_true: usize,
995        on_false: usize,
996        dst: usize,
997        len: u32,
998    },
999    /// General N-D transpose / broadcast. `out_dims[i]` is the output's dim
1000    /// i length; `in_strides[i]` is the input stride (in elements) used to
1001    /// index that dim — 0 for broadcast dims (Expand). `in_total` is the
1002    /// total element count in the source buffer (≤ output total when
1003    /// broadcasting). Strides are pre-computed at compile time.
1004    Transpose {
1005        src: usize,
1006        dst: usize,
1007        in_total: u32,
1008        out_dims: Vec<u32>,
1009        in_strides: Vec<u32>,
1010    },
1011    /// Gather along an arbitrary axis. `outer = product(dims[..axis])`,
1012    /// `trailing = product(dims[axis+1..])`, `axis_dim` = the dimension
1013    /// being indexed into. Output: outer × num_idx × trailing.
1014    /// (axis=0 still routes to the simpler Thunk::Gather fast path.)
1015    GatherAxis {
1016        table: usize,
1017        idx: usize,
1018        dst: usize,
1019        outer: u32,
1020        axis_dim: u32,
1021        num_idx: u32,
1022        trailing: u32,
1023    },
1024    /// 2D pooling (Max or Mean). Input layout [N, C, H, W], output
1025    /// [N, C, H_out, W_out]. Padding is implicit-zero; Mean divides by
1026    /// the full kernel area (matches torch's `count_include_pad=True`).
1027    Pool2D {
1028        src: usize,
1029        dst: usize,
1030        n: u32,
1031        c: u32,
1032        h: u32,
1033        w: u32,
1034        h_out: u32,
1035        w_out: u32,
1036        kh: u32,
1037        kw: u32,
1038        sh: u32,
1039        sw: u32,
1040        ph: u32,
1041        pw: u32,
1042        kind: ReduceOp,
1043    },
1044    /// 2D convolution. Input [N, C_in, H, W], weight [C_out, C_in_per_group, kH, kW],
1045    /// output [N, C_out, H_out, W_out]. Bias is a separate Op::Binary::Add
1046    /// after the conv (matching the IR's input layout — Op::Conv has 2 inputs).
1047    /// Naive direct convolution; sufficient for correctness, not optimised.
1048    Conv2D {
1049        src: usize,
1050        weight: usize,
1051        dst: usize,
1052        n: u32,
1053        c_in: u32,
1054        h: u32,
1055        w: u32,
1056        c_out: u32,
1057        h_out: u32,
1058        w_out: u32,
1059        kh: u32,
1060        kw: u32,
1061        sh: u32,
1062        sw: u32,
1063        ph: u32,
1064        pw: u32,
1065        dh: u32,
1066        dw: u32,
1067        groups: u32,
1068    },
1069
1070    // ── Backward / training kernels ─────────────────────────────
1071    /// Real INT8 matmul with i32 accumulation.
1072    ///   `out[m, n] = requantize(bias[n] + Σₖ (x[m,k]-x_zp)·(w[k,n]-w_zp), mult, out_zp)`
1073    /// Reads `x` and `w` as i8, `bias` as i32; writes `out` as i8.
1074    /// Same kernel shape as `rlx_cortexm::dense::dense_i8` — promoted
1075    /// to a desktop thunk so a quantized graph compiled here doesn't
1076    /// have to round-trip through fake-quant.
1077    QMatMul {
1078        x: usize,
1079        w: usize,
1080        bias: usize,
1081        out: usize,
1082        m: u32,
1083        k: u32,
1084        n: u32,
1085        x_zp: i32,
1086        w_zp: i32,
1087        out_zp: i32,
1088        mult: f32,
1089    },
1090
1091    /// Real INT8 conv2d, NCHW layout. Same loop shape as `Thunk::Conv2D`
1092    /// but with i8 reads, i32 accumulation, and per-output requantize
1093    /// to i8. Bias is i32 in the accumulator scale.
1094    QConv2d {
1095        x: usize,
1096        w: usize,
1097        bias: usize,
1098        out: usize,
1099        n: u32,
1100        c_in: u32,
1101        h: u32,
1102        w_in: u32,
1103        c_out: u32,
1104        h_out: u32,
1105        w_out: u32,
1106        kh: u32,
1107        kw: u32,
1108        sh: u32,
1109        sw: u32,
1110        ph: u32,
1111        pw: u32,
1112        dh: u32,
1113        dw: u32,
1114        groups: u32,
1115        x_zp: i32,
1116        w_zp: i32,
1117        out_zp: i32,
1118        mult: f32,
1119    },
1120
1121    /// INT8 quantize. Reads `x` as f32, writes `q` as i8.
1122    /// `chan = (i / inner) % chan_dim` selects the per-channel
1123    /// scale/zp; `chan_axis` is informational only (the kernel uses
1124    /// `chan_dim` and `inner` directly).
1125    /// For per-tensor, `chan_dim = 1` and `inner = len` so `chan` is
1126    /// always 0.
1127    Quantize {
1128        x: usize,
1129        q: usize,
1130        len: u32,
1131        chan_axis: u32,
1132        chan_dim: u32,
1133        inner: u32,
1134        scales: Vec<f32>,
1135        zero_points: Vec<i32>,
1136    },
1137
1138    /// INT8 dequantize — inverse of `Thunk::Quantize`.
1139    Dequantize {
1140        q: usize,
1141        x: usize,
1142        len: u32,
1143        chan_axis: u32,
1144        chan_dim: u32,
1145        inner: u32,
1146        scales: Vec<f32>,
1147        zero_points: Vec<i32>,
1148    },
1149
1150    /// QAT fake-quantize. Per-channel (or per-tensor) symmetric
1151    /// quantize-then-dequantize on the fly. Computes
1152    ///   `s[c] = max(|x[..., c, ...]|) / q_max`
1153    /// then
1154    ///   `out[i] = clamp(round(x[i]/s[c]), -q_max, q_max) * s[c]`
1155    /// with `q_max = {127, 7, 1}` for `bits = {8, 4, 2}`. Same
1156    /// channel-layout convention as `Thunk::Quantize`: every
1157    /// element's channel is `(i / inner) % chan_dim`. The kernel
1158    /// does two passes — one to scan max-abs per channel, one to
1159    /// quant-dequant per element.
1160    FakeQuantize {
1161        x: usize,
1162        out: usize,
1163        len: u32,
1164        chan_axis: u32,
1165        chan_dim: u32,
1166        inner: u32,
1167        bits: u8,
1168        /// STE variant — informational on the forward side (output is
1169        /// the same regardless), kernel-relevant in the matching
1170        /// `FakeQuantizeBackward` thunk.
1171        ste: rlx_ir::op::SteKind,
1172        /// Scale-tracking strategy. `PerBatch` recomputes
1173        /// `max_abs/q_max` every call (the original path). `EMA{decay}`
1174        /// blends per-batch max-abs into the `state_off` buffer; `Fixed`
1175        /// reads `state_off` and never updates it.
1176        scale_mode: rlx_ir::op::ScaleMode,
1177        /// `Some(off)` for `EMA` and `Fixed`; `None` for `PerBatch`.
1178        /// Points at a `[chan_dim]` f32 buffer holding the running scale
1179        /// per channel.
1180        state_off: Option<usize>,
1181    },
1182
1183    /// Backward pass for `Op::FakeQuantize` under one of four STE
1184    /// variants. Computes `dx[i]` from the f32 forward input `x` and
1185    /// the upstream gradient `dy`, using the same per-channel scale
1186    /// scheme as the forward.
1187    FakeQuantizeBackward {
1188        x: usize,
1189        dy: usize,
1190        dx: usize,
1191        len: u32,
1192        chan_axis: u32,
1193        chan_dim: u32,
1194        inner: u32,
1195        bits: u8,
1196        ste: rlx_ir::op::SteKind,
1197    },
1198
1199    /// LSQ forward — same kernel shape as `FakeQuantize` Fixed mode.
1200    /// Reads scale from `scale_off` (a `[chan_dim]` Param tensor).
1201    FakeQuantizeLSQ {
1202        x: usize,
1203        scale_off: usize,
1204        out: usize,
1205        len: u32,
1206        chan_axis: u32,
1207        chan_dim: u32,
1208        inner: u32,
1209        bits: u8,
1210    },
1211
1212    /// LSQ backward, x-gradient. STE-clipped: passes upstream
1213    /// through inside the quantization range, zeros outside.
1214    FakeQuantizeLSQBackwardX {
1215        x: usize,
1216        scale_off: usize,
1217        dy: usize,
1218        dx: usize,
1219        len: u32,
1220        chan_axis: u32,
1221        chan_dim: u32,
1222        inner: u32,
1223        bits: u8,
1224    },
1225
1226    /// LSQ backward, scale-gradient. Per-channel:
1227    ///   `dscale[c] = sum_i ψ(x[i]/s[c]) · upstream[i]`
1228    /// where `ψ(z) = -z + round(z)` if `|z| ≤ q_max` else
1229    /// `sign(z) · q_max`. Output shape: `[chan_dim]`.
1230    FakeQuantizeLSQBackwardScale {
1231        x: usize,
1232        scale_off: usize,
1233        dy: usize,
1234        dscale: usize,
1235        len: u32,
1236        chan_axis: u32,
1237        chan_dim: u32,
1238        inner: u32,
1239        bits: u8,
1240    },
1241
1242    /// ReLU backward: `dx[i] = dy[i] if x[i] > 0 else 0`.
1243    ReluBackward {
1244        x: usize,
1245        dy: usize,
1246        dx: usize,
1247        len: u32,
1248    },
1249    /// f64 sibling of `ReluBackward` — same shape as the f32 variant
1250    /// but reads/writes 8 bytes per element. Required because
1251    /// `ReluBackward`'s `&[f32]` slot view returns half of every f64
1252    /// otherwise → backward silently produces 0 gradients on an f64
1253    /// graph. Mirrors the `ActivationBackwardF64` split.
1254    ReluBackwardF64 {
1255        x: usize,
1256        dy: usize,
1257        dx: usize,
1258        len: u32,
1259    },
1260
1261    /// Generic element-wise activation backward.
1262    /// `dx[i] = (d/dx act(x))[i] · dy[i]`. The closure dispatch is
1263    /// per-element; expensive activations (Gelu) recompute internals
1264    /// inline rather than threading an extra "saved y" tensor through.
1265    ActivationBackward {
1266        x: usize,
1267        dy: usize,
1268        dx: usize,
1269        len: u32,
1270        kind: Activation,
1271    },
1272    /// f64 sibling of `ActivationBackward` — slot offsets, len in
1273    /// elements; kernel reads/writes 8 bytes per element. Required
1274    /// because `ActivationBackward`'s `&[f32]` slot view silently
1275    /// returns garbage on an f64 graph (cb % 4 still works but every
1276    /// loaded value is half of an f64 → wrong gradient).
1277    ActivationBackwardF64 {
1278        x: usize,
1279        dy: usize,
1280        dx: usize,
1281        len: u32,
1282        kind: Activation,
1283    },
1284
1285    /// LayerNorm backward — input gradient. Recomputes mean/var/x̂ from
1286    /// `x` and emits the closed-form `d_x` per row.
1287    LayerNormBackwardInput {
1288        x: usize,
1289        gamma: usize,
1290        dy: usize,
1291        dx: usize,
1292        rows: u32,
1293        h: u32,
1294        eps: f32,
1295    },
1296
1297    /// LayerNorm backward — gamma gradient. `d_gamma[d] = Σ_row dy·x̂`.
1298    LayerNormBackwardGamma {
1299        x: usize,
1300        dy: usize,
1301        dgamma: usize,
1302        rows: u32,
1303        h: u32,
1304        eps: f32,
1305    },
1306
1307    RmsNormBackwardInput {
1308        x: usize,
1309        gamma: usize,
1310        beta: usize,
1311        dy: usize,
1312        dx: usize,
1313        rows: u32,
1314        h: u32,
1315        eps: f32,
1316    },
1317    RmsNormBackwardGamma {
1318        x: usize,
1319        gamma: usize,
1320        beta: usize,
1321        dy: usize,
1322        dgamma: usize,
1323        rows: u32,
1324        h: u32,
1325        eps: f32,
1326    },
1327    RmsNormBackwardBeta {
1328        x: usize,
1329        gamma: usize,
1330        beta: usize,
1331        dy: usize,
1332        dbeta: usize,
1333        rows: u32,
1334        h: u32,
1335        eps: f32,
1336    },
1337    RopeBackward {
1338        dy: usize,
1339        cos: usize,
1340        sin: usize,
1341        dx: usize,
1342        batch: u32,
1343        seq: u32,
1344        hidden: u32,
1345        head_dim: u32,
1346        n_rot: u32,
1347        cos_len: u32,
1348    },
1349    CumsumBackward {
1350        dy: usize,
1351        dx: usize,
1352        rows: u32,
1353        cols: u32,
1354        exclusive: bool,
1355    },
1356    GatherBackward {
1357        dy: usize,
1358        indices: usize,
1359        dst: usize,
1360        outer: u32,
1361        axis_dim: u32,
1362        num_idx: u32,
1363        trailing: u32,
1364    },
1365
1366    GroupNormBackwardInput {
1367        x: usize,
1368        gamma: usize,
1369        beta: usize,
1370        dy: usize,
1371        dx: usize,
1372        n: u32,
1373        c: u32,
1374        h: u32,
1375        w: u32,
1376        num_groups: u32,
1377        eps: f32,
1378    },
1379    GroupNormBackwardGamma {
1380        x: usize,
1381        dy: usize,
1382        dgamma: usize,
1383        n: u32,
1384        c: u32,
1385        h: u32,
1386        w: u32,
1387        num_groups: u32,
1388        eps: f32,
1389    },
1390    GroupNormBackwardBeta {
1391        dy: usize,
1392        dbeta: usize,
1393        n: u32,
1394        c: u32,
1395        h: u32,
1396        w: u32,
1397    },
1398
1399    /// 2D max-pool backward (NCHW). Recomputes the argmax position
1400    /// inside each window and accumulates `dy` into `dx` at that
1401    /// position. Output is zeroed first; ties resolve to the first
1402    /// hit (lowest (kh,kw) index), matching what the forward kernel
1403    /// does with `acc.max(v)`.
1404    MaxPool2dBackward {
1405        x: usize,
1406        dy: usize,
1407        dx: usize,
1408        n: u32,
1409        c: u32,
1410        h: u32,
1411        w: u32,
1412        h_out: u32,
1413        w_out: u32,
1414        kh: u32,
1415        kw: u32,
1416        sh: u32,
1417        sw: u32,
1418        ph: u32,
1419        pw: u32,
1420    },
1421
1422    /// 2D conv backward w.r.t. input (`dx = conv_transpose(dy, w)`).
1423    /// `dy [N, C_out, H_out, W_out]`, `w [C_out, C_in_per_group, kH, kW]`,
1424    /// `dx [N, C_in, H, W]`.
1425    Conv2dBackwardInput {
1426        dy: usize,
1427        w: usize,
1428        dx: usize,
1429        n: u32,
1430        c_in: u32,
1431        h: u32,
1432        w_in: u32,
1433        c_out: u32,
1434        h_out: u32,
1435        w_out: u32,
1436        kh: u32,
1437        kw: u32,
1438        sh: u32,
1439        sw: u32,
1440        ph: u32,
1441        pw: u32,
1442        dh: u32,
1443        dw: u32,
1444        groups: u32,
1445    },
1446
1447    /// 2D conv backward w.r.t. weight. `x [N, C_in, H, W]`,
1448    /// `dy [N, C_out, H_out, W_out]`, `dw [C_out, C_in_per_group, kH, kW]`.
1449    /// `dw` is zeroed before accumulation.
1450    Conv2dBackwardWeight {
1451        x: usize,
1452        dy: usize,
1453        dw: usize,
1454        n: u32,
1455        c_in: u32,
1456        h: u32,
1457        w: u32,
1458        c_out: u32,
1459        h_out: u32,
1460        w_out: u32,
1461        kh: u32,
1462        kw: u32,
1463        sh: u32,
1464        sw: u32,
1465        ph: u32,
1466        pw: u32,
1467        dh: u32,
1468        dw_dil: u32,
1469        groups: u32,
1470    },
1471
1472    /// Fused softmax + cross-entropy loss with f32-encoded integer
1473    /// labels. `logits [N, C]`, `labels [N]`, output `[N]` per-row loss.
1474    /// Numerically stable (max-subtract before exp).
1475    SoftmaxCrossEntropy {
1476        logits: usize,
1477        labels: usize,
1478        dst: usize,
1479        n: u32,
1480        c: u32,
1481    },
1482
1483    /// Backward of the fused loss above.
1484    /// `dlogits[n, k] = (softmax(logits[n])[k] - one_hot(labels[n])[k]) * d_loss[n]`.
1485    SoftmaxCrossEntropyBackward {
1486        logits: usize,
1487        labels: usize,
1488        d_loss: usize,
1489        dlogits: usize,
1490        n: u32,
1491        c: u32,
1492    },
1493
1494    /// User-registered custom op (CPU side). Lowered from `Op::Custom`.
1495    /// `kernel` is resolved against the global CPU kernel registry at
1496    /// compile time and stored as `Arc<dyn CpuKernel>` so execution
1497    /// avoids per-call lookups. v1: f32 contiguous only — see
1498    /// `op_registry::CpuKernel::execute_f32`.
1499    CustomOp {
1500        kernel: Arc<dyn CpuKernel>,
1501        inputs: Vec<(usize, u32, Shape)>, // (offset, len_elements, shape)
1502        output: (usize, u32, Shape),      // (offset, len_elements, shape)
1503        attrs: Vec<u8>,
1504    },
1505
1506    /// 1D FFT along the last axis. Input/output are `[..., 2N]`
1507    /// real-block layout (first N real, second N imag along the
1508    /// transformed axis). `outer` is the product of all leading axes;
1509    /// `n_complex` is N (the number of complex points). Both halves
1510    /// of the real-block layout are read together by the kernel.
1511    /// `dtype` selects the f32 or f64 path; the two share structure
1512    /// but not buffers, so a flag at compile time avoids per-row
1513    /// dispatch.
1514    /// CPU reference 3D Gaussian splat render ([`rlx_ir::Op::GaussianSplatRender`]).
1515    GaussianSplatRender {
1516        positions_off: usize,
1517        positions_len: usize,
1518        scales_off: usize,
1519        scales_len: usize,
1520        rotations_off: usize,
1521        rotations_len: usize,
1522        opacities_off: usize,
1523        opacities_len: usize,
1524        colors_off: usize,
1525        colors_len: usize,
1526        sh_coeffs_off: usize,
1527        sh_coeffs_len: usize,
1528        meta_off: usize,
1529        dst_off: usize,
1530        dst_len: usize,
1531        width: u32,
1532        height: u32,
1533        tile_size: u32,
1534        radius_scale: f32,
1535        alpha_cutoff: f32,
1536        max_splat_steps: u32,
1537        transmittance_threshold: f32,
1538        max_list_entries: u32,
1539    },
1540    GaussianSplatRenderBackward {
1541        positions_off: usize,
1542        positions_len: usize,
1543        scales_off: usize,
1544        scales_len: usize,
1545        rotations_off: usize,
1546        rotations_len: usize,
1547        opacities_off: usize,
1548        opacities_len: usize,
1549        colors_off: usize,
1550        colors_len: usize,
1551        sh_coeffs_off: usize,
1552        sh_coeffs_len: usize,
1553        meta_off: usize,
1554        d_loss_off: usize,
1555        d_loss_len: usize,
1556        packed_off: usize,
1557        packed_len: usize,
1558        width: u32,
1559        height: u32,
1560        tile_size: u32,
1561        radius_scale: f32,
1562        alpha_cutoff: f32,
1563        max_splat_steps: u32,
1564        transmittance_threshold: f32,
1565        max_list_entries: u32,
1566        loss_grad_clip: f32,
1567        sh_band: u32,
1568        max_anisotropy: f32,
1569    },
1570    /// Strict IR stage 1 — project + bin + sort + rays ([`Op::GaussianSplatPrepare`]).
1571    GaussianSplatPrepare {
1572        positions_off: usize,
1573        positions_len: usize,
1574        scales_off: usize,
1575        scales_len: usize,
1576        rotations_off: usize,
1577        rotations_len: usize,
1578        opacities_off: usize,
1579        opacities_len: usize,
1580        colors_off: usize,
1581        colors_len: usize,
1582        sh_coeffs_off: usize,
1583        sh_coeffs_len: usize,
1584        meta_off: usize,
1585        meta_len: usize,
1586        prep_off: usize,
1587        prep_len: usize,
1588        width: u32,
1589        height: u32,
1590        tile_size: u32,
1591        radius_scale: f32,
1592        alpha_cutoff: f32,
1593        max_splat_steps: u32,
1594        transmittance_threshold: f32,
1595        max_list_entries: u32,
1596    },
1597    /// Strict IR stage 2 — tile raster from prepare buffer ([`Op::GaussianSplatRasterize`]).
1598    GaussianSplatRasterize {
1599        prep_off: usize,
1600        prep_len: usize,
1601        meta_off: usize,
1602        meta_len: usize,
1603        dst_off: usize,
1604        dst_len: usize,
1605        count: usize,
1606        width: u32,
1607        height: u32,
1608        tile_size: u32,
1609        alpha_cutoff: f32,
1610        max_splat_steps: u32,
1611        transmittance_threshold: f32,
1612        max_list_entries: u32,
1613    },
1614    Fft1d {
1615        src: usize,
1616        dst: usize,
1617        outer: u32,
1618        n_complex: u32,
1619        inverse: bool,
1620        dtype: rlx_ir::DType,
1621    },
1622}
1623
1624/// Compiled thunk schedule — the runtime hot path.
1625/// Nop thunks are filtered out at compile time for zero iteration overhead.
1626#[derive(Clone)]
1627pub struct ThunkSchedule {
1628    pub thunks: Vec<Thunk>,
1629    /// TIDE merged placement mask (union across layers).
1630    pub moe_resident: Option<std::sync::Arc<[bool]>>,
1631    /// Per MoE layer placement (`layer[e]`); preferred when set.
1632    pub moe_resident_layers: Option<std::sync::Arc<Vec<std::sync::Arc<[bool]>>>>,
1633    /// MoE router TopK capture (per-layer refresh).
1634    pub moe_topk_capture: Option<std::sync::Arc<crate::moe_topk_capture::MoeTopkCapture>>,
1635    /// Cached config values.
1636    pub mask_threshold: f32,
1637    pub mask_neg_inf: f32,
1638    pub score_skip: f32,
1639    /// Pre-compiled closure dispatch (zero match overhead). `Arc` (not
1640    /// `Box`) so the schedule can be `Clone` — multiple parallel
1641    /// executors share the same compiled closures (they're read-only
1642    /// `Fn(*mut u8)` so concurrent dispatch is safe; the arena pointer
1643    /// they receive is the only mutable state and is per-executor).
1644    pub compiled_fns: Vec<Arc<dyn Fn(*mut u8) + Send + Sync>>,
1645}
1646
1647impl ThunkSchedule {
1648    pub fn strip_nops(&mut self) {
1649        self.thunks.retain(|t| !matches!(t, Thunk::Nop));
1650        // compiled_fns must be rebuilt after stripping — caller should
1651        // call strip_nops() before compile_closures().
1652        self.compiled_fns.clear();
1653    }
1654}
1655
1656/// Get the arena byte offset for a node.
1657fn node_offset(arena: &Arena, id: NodeId) -> usize {
1658    if arena.has_buffer(id) {
1659        arena.byte_offset(id)
1660    } else {
1661        usize::MAX
1662    }
1663}
1664
1665/// Every byte-offset that a thunk reads from. Used by the Narrow→Rope
1666/// fusion (#45) to verify a Narrow's dst has exactly one consumer
1667/// before eliding it. Conservative: when in doubt about reads (an op
1668/// not yet listed here), the fusion will skip — correctness over
1669/// completeness.
1670fn thunk_read_offsets(t: &Thunk) -> Vec<usize> {
1671    match t {
1672        Thunk::Sgemm { a, b, .. } => vec![*a, *b],
1673        Thunk::DenseSolveF64 { a, b, .. } => vec![*a, *b],
1674        Thunk::DenseSolveF32 { a, b, .. } => vec![*a, *b],
1675        Thunk::BatchedDenseSolveF64 { a, b, .. } => vec![*a, *b],
1676        Thunk::BatchedDgemmF64 { a, b, .. } => vec![*a, *b],
1677        Thunk::BatchedSgemm { a, b, .. } => vec![*a, *b],
1678        Thunk::FusedMmBiasAct { a, w, bias, .. } => vec![*a, *w, *bias],
1679        Thunk::BiasAdd { src, bias, .. } => vec![*src, *bias],
1680        Thunk::BinaryFull { lhs, rhs, .. } => vec![*lhs, *rhs],
1681        Thunk::BinaryFullF64 { lhs, rhs, .. } => vec![*lhs, *rhs],
1682        Thunk::BinaryFullC64 { lhs, rhs, .. } => vec![*lhs, *rhs],
1683        Thunk::ComplexNormSqF32 { src, .. } => vec![*src],
1684        Thunk::ComplexNormSqBackwardF32 { z, g, .. } => vec![*z, *g],
1685        Thunk::ConjugateC64 { src, .. } => vec![*src],
1686        Thunk::Scan {
1687            outer_init_off,
1688            xs_inputs,
1689            ..
1690        } => {
1691            let mut v = vec![*outer_init_off];
1692            for (_, outer_xs_off, _) in xs_inputs.iter() {
1693                v.push(*outer_xs_off);
1694            }
1695            v
1696        }
1697        Thunk::ScanBackward {
1698            outer_init_off,
1699            outer_traj_off,
1700            outer_upstream_off,
1701            outer_xs_offs,
1702            ..
1703        } => {
1704            let mut v = vec![*outer_init_off, *outer_traj_off, *outer_upstream_off];
1705            for (off, _) in outer_xs_offs.iter() {
1706                v.push(*off);
1707            }
1708            v
1709        }
1710        Thunk::ScanBackwardXs {
1711            outer_init_off,
1712            outer_traj_off,
1713            outer_upstream_off,
1714            outer_xs_offs,
1715            ..
1716        } => {
1717            let mut v = vec![*outer_init_off, *outer_traj_off, *outer_upstream_off];
1718            for (off, _) in outer_xs_offs.iter() {
1719                v.push(*off);
1720            }
1721            v
1722        }
1723        Thunk::CustomFn { inputs, .. } => {
1724            inputs.iter().map(|(_, outer_off, _)| *outer_off).collect()
1725        }
1726        Thunk::ActivationInPlace { data, .. } => vec![*data],
1727        Thunk::LayerNorm { src, g, b, .. } | Thunk::GroupNorm { src, g, b, .. } => {
1728            vec![*src, *g, *b]
1729        }
1730        Thunk::ResizeNearest2x { src, .. } => vec![*src],
1731        Thunk::AxialRope2d { src, .. } => vec![*src],
1732        Thunk::FusedResidualLN {
1733            x, res, bias, g, b, ..
1734        } => vec![*x, *res, *bias, *g, *b],
1735        Thunk::FusedResidualRmsNorm {
1736            x, res, bias, g, b, ..
1737        } => vec![*x, *res, *bias, *g, *b],
1738        Thunk::RmsNorm { src, g, b, .. } => vec![*src, *g, *b],
1739        Thunk::Softmax { data, .. } => vec![*data],
1740        Thunk::Cumsum { src, .. } => vec![*src],
1741        Thunk::Sample { logits, .. } => vec![*logits],
1742        Thunk::LoraMatMul { x, w, a, b, .. } => vec![*x, *w, *a, *b],
1743        Thunk::DequantMatMul {
1744            x, w_q, scale, zp, ..
1745        } => vec![*x, *w_q, *scale, *zp],
1746        Thunk::DequantMatMulGguf { x, w_q, .. } => vec![*x, *w_q],
1747        Thunk::DequantMatMulInt4 {
1748            x, w_q, scale, zp, ..
1749        } => vec![*x, *w_q, *scale, *zp],
1750        Thunk::DequantMatMulFp8 { x, w_q, scale, .. } => vec![*x, *w_q, *scale],
1751        Thunk::DequantMatMulNvfp4 {
1752            x,
1753            w_q,
1754            scale,
1755            global_scale,
1756            ..
1757        } => vec![*x, *w_q, *scale, *global_scale],
1758        Thunk::Conv2D1x1 { src, weight, .. } => vec![*src, *weight],
1759        Thunk::SelectiveScan {
1760            x, delta, a, b, c, ..
1761        } => vec![*x, *delta, *a, *b, *c],
1762        Thunk::GatedDeltaNet {
1763            q,
1764            k,
1765            v,
1766            g,
1767            beta,
1768            state,
1769            ..
1770        } => {
1771            let mut v = vec![*q, *k, *v, *g, *beta];
1772            if *state != 0 {
1773                v.push(*state);
1774            }
1775            v
1776        }
1777        Thunk::Attention { q, k, v, mask, .. } => vec![*q, *k, *v, *mask],
1778        Thunk::AttentionBackward {
1779            q, k, v, dy, mask, ..
1780        } => {
1781            let mut v = vec![*q, *k, *v, *dy];
1782            if *mask != 0 {
1783                v.push(*mask);
1784            }
1785            v
1786        }
1787        Thunk::Rope { src, cos, sin, .. } => vec![*src, *cos, *sin],
1788        Thunk::FusedAttnBlock {
1789            hidden,
1790            qkv_w,
1791            out_w,
1792            mask,
1793            qkv_b,
1794            out_b,
1795            cos,
1796            sin,
1797            ..
1798        } => vec![*hidden, *qkv_w, *out_w, *mask, *qkv_b, *out_b, *cos, *sin],
1799        Thunk::FusedSwiGLU { src, .. } => vec![*src],
1800        Thunk::Concat { inputs, .. } => inputs.iter().map(|(off, _)| *off).collect(),
1801        Thunk::ConcatF64 { inputs, .. } => inputs.iter().map(|(off, _)| *off).collect(),
1802        Thunk::Narrow { src, .. } => vec![*src],
1803        Thunk::Copy { src, .. } => vec![*src],
1804        Thunk::Gather { table, idx, .. } => vec![*table, *idx],
1805        // Anything not enumerated → return the dst as a "read" too,
1806        // forcing the fusion to bail (read_count >= 2 → skip). Keeps
1807        // this list safe to be incomplete.
1808        _ => vec![],
1809    }
1810}
1811
1812/// Fused dequant + matmul (plan #5). Int8-blockwise weights: each
1813/// `block_size` consecutive elements of a column share one f32
1814/// scale (and optionally a zero-point). The dequant happens inside
1815/// the inner accumulate so the f32 weight is never materialized.
1816///
1817/// `w_bytes` is the row-major i8 weight matrix `[k, n]`. `scales`
1818/// and `zps` are `[k/block, n]`. When `asym=false`, `zps` may be
1819/// empty.
1820///
1821/// Today this is the reference scalar implementation — the win is
1822/// memory bandwidth, not flops, since LLM weights dominate the
1823/// working set. A NEON SIMD path that loads 16 i8 → splat-scale →
1824/// fused-multiply-add is the natural follow-on.
1825#[allow(clippy::too_many_arguments)]
1826fn dequant_matmul_int8(
1827    x: &[f32],       // [m, k]
1828    w_bytes: &[i8],  // [k, n]
1829    scales: &[f32],  // [k/block, n]
1830    zps: &[f32],     // [k/block, n] or empty
1831    out: &mut [f32], // [m, n]
1832    m: usize,
1833    k: usize,
1834    n: usize,
1835    block_size: usize,
1836    asym: bool,
1837) {
1838    let blocks_per_col = k.div_ceil(block_size);
1839    for i in 0..m {
1840        for j in 0..n {
1841            let mut acc = 0f32;
1842            for p in 0..k {
1843                let block = p / block_size;
1844                let s = scales[block * n + j];
1845                let z = if asym { zps[block * n + j] } else { 0.0 };
1846                let q = w_bytes[p * n + j] as f32;
1847                let dequantized = (q - z) * s;
1848                acc += x[i * k + p] * dequantized;
1849            }
1850            out[i * n + j] = acc;
1851        }
1852    }
1853    let _ = blocks_per_col;
1854}
1855
1856#[allow(clippy::too_many_arguments)]
1857fn dequant_matmul_int4(
1858    x: &[f32],
1859    w_bytes: &[u8],
1860    scales: &[f32],
1861    zps: &[f32],
1862    out: &mut [f32],
1863    m: usize,
1864    k: usize,
1865    n: usize,
1866    block_size: usize,
1867    asym: bool,
1868) {
1869    for i in 0..m {
1870        for j in 0..n {
1871            let mut acc = 0f32;
1872            for p in 0..k {
1873                let block = p / block_size;
1874                let s = scales[block * n + j];
1875                let z = if asym { zps[block * n + j] } else { 0.0 };
1876                let byte_idx = (p * n + j) / 2;
1877                let nibble = if (p * n + j) & 1 == 0 {
1878                    w_bytes[byte_idx] & 0x0F
1879                } else {
1880                    w_bytes[byte_idx] >> 4
1881                };
1882                let dequantized = (nibble as f32 - z) * s;
1883                acc += x[i * k + p] * dequantized;
1884            }
1885            out[i * n + j] = acc;
1886        }
1887    }
1888}
1889
1890fn fp8_e4m3_to_f32(b: u8) -> f32 {
1891    let sign = if b & 0x80 != 0 { -1.0 } else { 1.0 };
1892    let exp = (b >> 3) & 0x0F;
1893    let mant = b & 0x07;
1894    if exp == 0 {
1895        if mant == 0 {
1896            return 0.0;
1897        }
1898        return sign * (mant as f32) * 2f32.powi(-9);
1899    }
1900    if exp == 0x0F {
1901        return if mant == 0 {
1902            sign * f32::INFINITY
1903        } else {
1904            f32::NAN
1905        };
1906    }
1907    sign * (1.0 + mant as f32 / 8.0) * 2f32.powi(exp as i32 - 7)
1908}
1909
1910fn fp8_e5m2_to_f32(b: u8) -> f32 {
1911    let sign = if b & 0x80 != 0 { -1.0 } else { 1.0 };
1912    let exp = (b >> 2) & 0x1F;
1913    let mant = b & 0x03;
1914    if exp == 0 {
1915        if mant == 0 {
1916            return 0.0;
1917        }
1918        return sign * (mant as f32) * 2f32.powi(-16);
1919    }
1920    if exp == 0x1F {
1921        return if mant == 0 {
1922            sign * f32::INFINITY
1923        } else {
1924            f32::NAN
1925        };
1926    }
1927    sign * (1.0 + mant as f32 / 4.0) * 2f32.powi(exp as i32 - 15)
1928}
1929
1930#[allow(clippy::too_many_arguments)]
1931fn dequant_matmul_fp8(
1932    x: &[f32],
1933    w_bytes: &[u8],
1934    scales: &[f32],
1935    out: &mut [f32],
1936    m: usize,
1937    k: usize,
1938    n: usize,
1939    e5m2: bool,
1940) {
1941    let dequant = if e5m2 {
1942        fp8_e5m2_to_f32
1943    } else {
1944        fp8_e4m3_to_f32
1945    };
1946    for i in 0..m {
1947        for j in 0..n {
1948            let mut acc = 0f32;
1949            for p in 0..k {
1950                let w = dequant(w_bytes[p * n + j]);
1951                let s = scales.get(j).copied().unwrap_or(1.0);
1952                acc += x[i * k + p] * w * s;
1953            }
1954            out[i * n + j] = acc;
1955        }
1956    }
1957}
1958
1959#[allow(clippy::too_many_arguments)]
1960pub fn dequant_matmul_nvfp4(
1961    x: &[f32],
1962    w_bytes: &[u8],
1963    scale_bytes: &[u8],
1964    global_scale: f32,
1965    out: &mut [f32],
1966    m: usize,
1967    k: usize,
1968    n: usize,
1969) {
1970    use rlx_ir::{NVFP4_GROUP_SIZE, fp4_e2m1_to_f32, fp8_e4m3_scale_to_f32};
1971    let gs = NVFP4_GROUP_SIZE;
1972    for i in 0..m {
1973        for j in 0..n {
1974            let mut acc = 0f32;
1975            for p in 0..k {
1976                let byte_idx = (p * n + j) / 2;
1977                let nibble = if (p * n + j) & 1 == 0 {
1978                    w_bytes[byte_idx] & 0x0F
1979                } else {
1980                    w_bytes[byte_idx] >> 4
1981                };
1982                let block = p / gs;
1983                let scale = fp8_e4m3_scale_to_f32(scale_bytes[block * n + j]);
1984                let w = fp4_e2m1_to_f32(nibble) * scale * global_scale;
1985                acc += x[i * k + p] * w;
1986            }
1987            out[i * n + j] = acc;
1988        }
1989    }
1990}
1991
1992/// Fused sampling step: logits → top-k filter → top-p truncation
1993/// → softmax → multinomial sample. Operates on one row of length
1994/// `vocab` and returns the sampled index. Plan #42.
1995///
1996/// Internal scratch is on the stack via SmallVec-style fallback —
1997/// for `vocab > 8192` we heap-allocate a working buffer; below
1998/// that we keep things in a fixed array. (TODO: thread the
1999/// scratch through ThunkSchedule like sdpa_scores does.)
2000fn sample_row(
2001    logits: &[f32],
2002    top_k: usize,
2003    top_p: f32,
2004    temperature: f32,
2005    rng: &mut rlx_ir::Philox4x32,
2006) -> usize {
2007    let v = logits.len();
2008    if v == 0 {
2009        return 0;
2010    }
2011    let temp = temperature.max(1e-6);
2012    // Copy + temperature-scale into a working buffer.
2013    let mut scaled: Vec<f32> = logits.iter().map(|&x| x / temp).collect();
2014
2015    // Top-k: zero out everything but the k largest by setting to -inf.
2016    if top_k > 0 && top_k < v {
2017        // Partial selection: find k-th largest then mask below.
2018        let mut indexed: Vec<(usize, f32)> = scaled.iter().copied().enumerate().collect();
2019        // Sort descending; partial would be O(n log k), full sort is fine
2020        // for typical vocab sizes (32k-128k) — single-row work.
2021        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
2022        let cutoff = indexed[top_k - 1].1;
2023        for x in scaled.iter_mut() {
2024            if *x < cutoff {
2025                *x = f32::NEG_INFINITY;
2026            }
2027        }
2028    }
2029
2030    // Stable softmax.
2031    let mut max_l = f32::NEG_INFINITY;
2032    for &x in &scaled {
2033        if x > max_l {
2034            max_l = x;
2035        }
2036    }
2037    let mut sum = 0.0f32;
2038    for x in scaled.iter_mut() {
2039        *x = (*x - max_l).exp();
2040        sum += *x;
2041    }
2042    let inv = 1.0 / sum.max(f32::MIN_POSITIVE);
2043    for x in scaled.iter_mut() {
2044        *x *= inv;
2045    }
2046
2047    // Top-p: keep the smallest set of tokens whose cumulative
2048    // probability exceeds top_p (after sorting descending).
2049    if top_p < 1.0 {
2050        let mut indexed: Vec<(usize, f32)> = scaled.iter().copied().enumerate().collect();
2051        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
2052        let mut cum = 0.0f32;
2053        let mut keep = vec![false; v];
2054        for (idx, p) in indexed.iter() {
2055            keep[*idx] = true;
2056            cum += *p;
2057            if cum >= top_p {
2058                break;
2059            }
2060        }
2061        let mut new_sum = 0.0f32;
2062        for (i, x) in scaled.iter_mut().enumerate() {
2063            if !keep[i] {
2064                *x = 0.0;
2065            }
2066            new_sum += *x;
2067        }
2068        let inv = 1.0 / new_sum.max(f32::MIN_POSITIVE);
2069        for x in scaled.iter_mut() {
2070            *x *= inv;
2071        }
2072    }
2073
2074    // Multinomial sample via inverse-CDF.
2075    let r = rng.next_f32();
2076    let mut acc = 0.0f32;
2077    for (i, &p) in scaled.iter().enumerate() {
2078        acc += p;
2079        if r <= acc {
2080            return i;
2081        }
2082    }
2083    v - 1 // floating-point edge case fallback
2084}
2085
2086/// Apply a synthetic (kernel-generated) attention mask to a `[q_seq, k_seq]`
2087/// scores matrix. Custom masks are read from a tensor and not handled here.
2088/// `None` is a no-op so callers don't need to special-case it.
2089#[inline]
2090fn apply_synthetic_mask(
2091    scores: &mut [f32],
2092    q_seq: usize,
2093    k_seq: usize,
2094    kind: rlx_ir::op::MaskKind,
2095) {
2096    let neg = crate::config::RuntimeConfig::global().attn_mask_neg_inf;
2097    let q_offset = k_seq.saturating_sub(q_seq);
2098    match kind {
2099        rlx_ir::op::MaskKind::None | rlx_ir::op::MaskKind::Custom | rlx_ir::op::MaskKind::Bias => {}
2100        rlx_ir::op::MaskKind::Causal => {
2101            for qi in 0..q_seq {
2102                let abs_q = q_offset + qi;
2103                for ki in (abs_q + 1)..k_seq {
2104                    scores[qi * k_seq + ki] = neg;
2105                }
2106            }
2107        }
2108        rlx_ir::op::MaskKind::SlidingWindow(w) => {
2109            for qi in 0..q_seq {
2110                let abs_q = q_offset + qi;
2111                let lo = abs_q.saturating_sub(w);
2112                for ki in 0..k_seq {
2113                    if ki < lo || ki > abs_q {
2114                        scores[qi * k_seq + ki] = neg;
2115                    }
2116                }
2117            }
2118        }
2119    }
2120}
2121
2122/// Compile graph into thunk schedule.
2123pub fn compile_thunks(graph: &Graph, arena: &Arena) -> ThunkSchedule {
2124    let mut thunks = Vec::with_capacity(graph.len());
2125
2126    for node in graph.nodes() {
2127        // View ops (Reshape / same-dtype Cast / axis-0 Narrow) are aliased
2128        // to their parent's slot by the memory planner — no copy needed.
2129        // Plan #46.
2130        if rlx_opt::is_pure_view(graph, node) {
2131            thunks.push(Thunk::Nop);
2132            continue;
2133        }
2134        let t = match &node.op {
2135            Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => Thunk::Nop,
2136
2137            Op::FusedMatMulBiasAct { activation } => {
2138                let shape = &node.shape;
2139                let n = shape.dim(shape.rank() - 1).unwrap_static();
2140                let total = shape.num_elements().unwrap();
2141                let m = total / n;
2142                let a_len = get_len(graph, node.inputs[0]);
2143                let k = a_len / m;
2144                Thunk::FusedMmBiasAct {
2145                    a: node_offset(arena, node.inputs[0]),
2146                    w: node_offset(arena, node.inputs[1]),
2147                    bias: node_offset(arena, node.inputs[2]),
2148                    c: node_offset(arena, node.id),
2149                    m: m as u32,
2150                    k: k as u32,
2151                    n: n as u32,
2152                    act: *activation,
2153                }
2154            }
2155
2156            Op::FusedResidualLN { has_bias, eps } => {
2157                let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
2158                let total = node.shape.num_elements().unwrap();
2159                let rows = total / h;
2160                let (g_idx, b_idx) = if *has_bias { (3, 4) } else { (2, 3) };
2161                Thunk::FusedResidualLN {
2162                    x: node_offset(arena, node.inputs[0]),
2163                    res: node_offset(arena, node.inputs[1]),
2164                    bias: if *has_bias {
2165                        node_offset(arena, node.inputs[2])
2166                    } else {
2167                        0
2168                    },
2169                    g: node_offset(arena, node.inputs[g_idx]),
2170                    b: node_offset(arena, node.inputs[b_idx]),
2171                    out: node_offset(arena, node.id),
2172                    rows: rows as u32,
2173                    h: h as u32,
2174                    eps: *eps,
2175                    has_bias: *has_bias,
2176                }
2177            }
2178
2179            Op::FusedResidualRmsNorm { has_bias, eps } => {
2180                let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
2181                let total = node.shape.num_elements().unwrap();
2182                let rows = total / h;
2183                let (g_idx, b_idx) = if *has_bias { (3, 4) } else { (2, 3) };
2184                Thunk::FusedResidualRmsNorm {
2185                    x: node_offset(arena, node.inputs[0]),
2186                    res: node_offset(arena, node.inputs[1]),
2187                    bias: if *has_bias {
2188                        node_offset(arena, node.inputs[2])
2189                    } else {
2190                        0
2191                    },
2192                    g: node_offset(arena, node.inputs[g_idx]),
2193                    b: node_offset(arena, node.inputs[b_idx]),
2194                    out: node_offset(arena, node.id),
2195                    rows: rows as u32,
2196                    h: h as u32,
2197                    eps: *eps,
2198                    has_bias: *has_bias,
2199                }
2200            }
2201
2202            Op::MatMul => {
2203                let shape = &node.shape;
2204                let a_shape = &graph.node(node.inputs[0]).shape;
2205                let b_shape = &graph.node(node.inputs[1]).shape;
2206                let n = shape.dim(shape.rank() - 1).unwrap_static();
2207
2208                // Detect batched matmul: any rank where both inputs
2209                // and output share the same leading batch dims and
2210                // the last 2 dims form an [M, K] @ [K, N] = [M, N].
2211                // The 2-D MatMul lowering's flatten-and-call-dgemm trick
2212                // is wrong when both operands carry independent batch
2213                // dims (per-batch K dimension differs).
2214                let batched_3d = a_shape.rank() >= 3
2215                    && b_shape.rank() == a_shape.rank()
2216                    && shape.rank() == a_shape.rank()
2217                    && {
2218                        // All leading dims (everything except last 2) match.
2219                        let mut ok = true;
2220                        for d in 0..a_shape.rank() - 2 {
2221                            if a_shape.dim(d) != b_shape.dim(d) || a_shape.dim(d) != shape.dim(d) {
2222                                ok = false;
2223                                break;
2224                            }
2225                        }
2226                        ok
2227                    };
2228                if batched_3d && shape.dtype() == rlx_ir::DType::F64 {
2229                    // Batch is the product of all leading dims (every
2230                    // dim except the last 2); m/k/n are the inner
2231                    // matmul dims. Works for any rank >= 3.
2232                    let r = shape.rank();
2233                    let mut batch_prod = 1usize;
2234                    for d in 0..r - 2 {
2235                        batch_prod *= shape.dim(d).unwrap_static();
2236                    }
2237                    let m_dim = shape.dim(r - 2).unwrap_static();
2238                    let k_dim = a_shape.dim(r - 1).unwrap_static();
2239                    debug_assert_eq!(k_dim, b_shape.dim(r - 2).unwrap_static());
2240                    Thunk::BatchedDgemmF64 {
2241                        a: node_offset(arena, node.inputs[0]),
2242                        b: node_offset(arena, node.inputs[1]),
2243                        c: node_offset(arena, node.id),
2244                        batch: batch_prod as u32,
2245                        m: m_dim as u32,
2246                        k: k_dim as u32,
2247                        n: n as u32,
2248                    }
2249                } else if batched_3d && shape.dtype() == rlx_ir::DType::F32 {
2250                    // f32 batched matmul for any rank >= 3 (collapse all
2251                    // leading batch dims into a single batch count).
2252                    let r = shape.rank();
2253                    let mut batch_prod = 1usize;
2254                    for d in 0..r - 2 {
2255                        batch_prod *= shape.dim(d).unwrap_static();
2256                    }
2257                    let m_dim = shape.dim(r - 2).unwrap_static();
2258                    let k_dim = a_shape.dim(r - 1).unwrap_static();
2259                    debug_assert_eq!(k_dim, b_shape.dim(r - 2).unwrap_static());
2260                    Thunk::BatchedSgemm {
2261                        a: node_offset(arena, node.inputs[0]),
2262                        b: node_offset(arena, node.inputs[1]),
2263                        c: node_offset(arena, node.id),
2264                        batch: batch_prod as u32,
2265                        m: m_dim as u32,
2266                        k: k_dim as u32,
2267                        n: n as u32,
2268                    }
2269                } else {
2270                    let total = shape.num_elements().unwrap();
2271                    let m = total / n;
2272                    let a_len = get_len(graph, node.inputs[0]);
2273                    let k = a_len / m;
2274                    match shape.dtype() {
2275                        rlx_ir::DType::F64 => Thunk::Dgemm {
2276                            a: node_offset(arena, node.inputs[0]),
2277                            b: node_offset(arena, node.inputs[1]),
2278                            c: node_offset(arena, node.id),
2279                            m: m as u32,
2280                            k: k as u32,
2281                            n: n as u32,
2282                        },
2283                        _ => Thunk::Sgemm {
2284                            a: node_offset(arena, node.inputs[0]),
2285                            b: node_offset(arena, node.inputs[1]),
2286                            c: node_offset(arena, node.id),
2287                            m: m as u32,
2288                            k: k as u32,
2289                            n: n as u32,
2290                        },
2291                    }
2292                }
2293            }
2294
2295            Op::Binary(op) => {
2296                let lhs_len = get_len(graph, node.inputs[0]);
2297                let rhs_len = get_len(graph, node.inputs[1]);
2298                let out_len = node.shape.num_elements().unwrap();
2299                if node.shape.dtype() == rlx_ir::DType::C64 {
2300                    // Native C64 element-wise. Add/Sub/Mul/Div lower
2301                    // to `BinaryFullC64`; the rest don't have a
2302                    // single natural complex definition.
2303                    match op {
2304                        BinaryOp::Add | BinaryOp::Sub | BinaryOp::Mul | BinaryOp::Div => {}
2305                        BinaryOp::Max | BinaryOp::Min | BinaryOp::Pow => panic!(
2306                            "Op::Binary({op:?}) on DType::C64: complex \
2307                             max/min/pow have no single natural definition \
2308                             — caller should drop to 2N-real-block (see \
2309                             spike-ac) and pick a convention there"
2310                        ),
2311                    }
2312                }
2313                // Compute broadcast strides for the slow path. Empty
2314                // vectors when no broadcast is needed (the fast-path
2315                // kernel ignores them anyway).
2316                let (out_dims_bcast, bcast_lhs_strides, bcast_rhs_strides) =
2317                    if lhs_len == out_len && rhs_len == out_len {
2318                        (Vec::new(), Vec::new(), Vec::new())
2319                    } else {
2320                        let lhs_dims = get_static_dims(graph, node.inputs[0]);
2321                        let rhs_dims = get_static_dims(graph, node.inputs[1]);
2322                        let out_dims_v = get_static_dims(graph, node.id);
2323                        if lhs_dims.is_empty() || rhs_dims.is_empty() || out_dims_v.is_empty() {
2324                            // Dynamic shape — fall back to the legacy
2325                            // modulo path (correct for scalar / last-
2326                            // axis broadcast, which is the only
2327                            // dynamic case in practice).
2328                            (Vec::new(), Vec::new(), Vec::new())
2329                        } else {
2330                            let ls = broadcast_strides(&lhs_dims, &out_dims_v);
2331                            let rs = broadcast_strides(&rhs_dims, &out_dims_v);
2332                            let od: Vec<u32> = out_dims_v.iter().map(|x| *x as u32).collect();
2333                            (od, ls, rs)
2334                        }
2335                    };
2336                if node.shape.dtype() == rlx_ir::DType::C64 {
2337                    Thunk::BinaryFullC64 {
2338                        lhs: node_offset(arena, node.inputs[0]),
2339                        rhs: node_offset(arena, node.inputs[1]),
2340                        dst: node_offset(arena, node.id),
2341                        len: out_len as u32,
2342                        lhs_len: lhs_len as u32,
2343                        rhs_len: rhs_len as u32,
2344                        op: *op,
2345                        out_dims_bcast,
2346                        bcast_lhs_strides,
2347                        bcast_rhs_strides,
2348                    }
2349                } else if node.shape.dtype() == rlx_ir::DType::F64 {
2350                    // f64 path — no BiasAdd fast-path (yet); use the
2351                    // general binary-with-broadcast kernel.
2352                    Thunk::BinaryFullF64 {
2353                        lhs: node_offset(arena, node.inputs[0]),
2354                        rhs: node_offset(arena, node.inputs[1]),
2355                        dst: node_offset(arena, node.id),
2356                        len: out_len as u32,
2357                        lhs_len: lhs_len as u32,
2358                        rhs_len: rhs_len as u32,
2359                        op: *op,
2360                        out_dims_bcast,
2361                        bcast_lhs_strides,
2362                        bcast_rhs_strides,
2363                    }
2364                } else if matches!(op, BinaryOp::Add)
2365                    && rhs_len < out_len
2366                    && out_len % rhs_len == 0
2367                    && is_trailing_bias_broadcast(
2368                        graph.node(node.inputs[1]).shape.dims(),
2369                        graph.node(node.id).shape.dims(),
2370                    )
2371                {
2372                    // `BiasAdd` is only correct when the bias is a
2373                    // *trailing* broadcast — rhs dims match the right-
2374                    // hand side of the output dims (with size-1 only
2375                    // allowed in left-padded outer positions).
2376                    // SAM's rel-pos `[bh, h, w, 1, w] + [bh, h, w, h, w]`
2377                    // has rhs_len divide out_len cleanly but is a
2378                    // mid-shape singleton, NOT a trailing broadcast.
2379                    // Routing it through BiasAdd silently treats it as
2380                    // last-`rhs_len`-cols repeated — wrong values.
2381                    Thunk::BiasAdd {
2382                        src: node_offset(arena, node.inputs[0]),
2383                        bias: node_offset(arena, node.inputs[1]),
2384                        dst: node_offset(arena, node.id),
2385                        m: (out_len / rhs_len) as u32,
2386                        n: rhs_len as u32,
2387                    }
2388                } else {
2389                    let lhs_len = get_len(graph, node.inputs[0]);
2390                    Thunk::BinaryFull {
2391                        lhs: node_offset(arena, node.inputs[0]),
2392                        rhs: node_offset(arena, node.inputs[1]),
2393                        dst: node_offset(arena, node.id),
2394                        len: out_len as u32,
2395                        lhs_len: lhs_len as u32,
2396                        rhs_len: rhs_len as u32,
2397                        op: *op,
2398                        out_dims_bcast,
2399                        bcast_lhs_strides,
2400                        bcast_rhs_strides,
2401                    }
2402                }
2403            }
2404
2405            Op::Activation(act) => {
2406                let len = node.shape.num_elements().unwrap();
2407                let in_off = node_offset(arena, node.inputs[0]);
2408                let out_off = node_offset(arena, node.id);
2409                if node.shape.dtype() == rlx_ir::DType::C64 {
2410                    // Only Neg/Exp/Log/Sqrt have natural complex
2411                    // extensions used in signal-processing graphs.
2412                    // Everything else (Sigmoid, Tanh, Relu, Abs,
2413                    // Sin/Cos/Tan/Atan, Round, GeLU family) is rejected.
2414                    match act {
2415                        Activation::Neg | Activation::Exp | Activation::Log | Activation::Sqrt => {}
2416                        other => panic!(
2417                            "Op::Activation({other:?}) on DType::C64: no \
2418                             natural complex extension — supported on C64: \
2419                             Neg, Exp, Log, Sqrt"
2420                        ),
2421                    }
2422                    Thunk::ActivationC64 {
2423                        src: in_off,
2424                        dst: out_off,
2425                        len: len as u32,
2426                        kind: *act,
2427                    }
2428                } else if node.shape.dtype() == rlx_ir::DType::F64 {
2429                    Thunk::ActivationF64 {
2430                        src: in_off,
2431                        dst: out_off,
2432                        len: len as u32,
2433                        kind: *act,
2434                    }
2435                } else if in_off == out_off {
2436                    // ActivationInPlace operates on a single buffer. When the
2437                    // planner has assigned input and output the same slot
2438                    // (typical post-fusion case), we just run on that slot.
2439                    Thunk::ActivationInPlace {
2440                        data: out_off,
2441                        len: len as u32,
2442                        act: *act,
2443                    }
2444                } else {
2445                    // Two-step: copy input → output, then activate output in place.
2446                    // The schedule executes them in this order; downstream
2447                    // thunks see the activated output at out_off.
2448                    thunks.push(Thunk::Copy {
2449                        src: in_off,
2450                        dst: out_off,
2451                        len: len as u32,
2452                    });
2453                    Thunk::ActivationInPlace {
2454                        data: out_off,
2455                        len: len as u32,
2456                        act: *act,
2457                    }
2458                }
2459            }
2460
2461            Op::Gather { axis } if *axis == 0 => {
2462                let table_shape = &graph.node(node.inputs[0]).shape;
2463                let table_total = table_shape.num_elements().unwrap();
2464                let trailing: usize = (1..table_shape.rank())
2465                    .map(|i| table_shape.dim(i).unwrap_static())
2466                    .product();
2467                let idx_len = get_len(graph, node.inputs[1]);
2468                Thunk::Gather {
2469                    table: node_offset(arena, node.inputs[0]),
2470                    table_len: table_total as u32,
2471                    idx: node_offset(arena, node.inputs[1]),
2472                    dst: node_offset(arena, node.id),
2473                    num_idx: idx_len as u32,
2474                    trailing: trailing as u32,
2475                }
2476            }
2477
2478            Op::Gather { axis } => {
2479                // Non-zero axis: outer × num_idx × trailing layout.
2480                let table_shape = &graph.node(node.inputs[0]).shape;
2481                let rank = table_shape.rank();
2482                let outer: usize = (0..*axis)
2483                    .map(|i| table_shape.dim(i).unwrap_static())
2484                    .product::<usize>()
2485                    .max(1);
2486                let trailing: usize = (*axis + 1..rank)
2487                    .map(|i| table_shape.dim(i).unwrap_static())
2488                    .product::<usize>()
2489                    .max(1);
2490                let axis_dim = table_shape.dim(*axis).unwrap_static();
2491                let idx_len = get_len(graph, node.inputs[1]);
2492                Thunk::GatherAxis {
2493                    table: node_offset(arena, node.inputs[0]),
2494                    idx: node_offset(arena, node.inputs[1]),
2495                    dst: node_offset(arena, node.id),
2496                    outer: outer as u32,
2497                    axis_dim: axis_dim as u32,
2498                    num_idx: idx_len as u32,
2499                    trailing: trailing as u32,
2500                }
2501            }
2502
2503            Op::Narrow { axis, start, len } => {
2504                let in_shape = &graph.node(node.inputs[0]).shape;
2505                let elem_bytes = in_shape.dtype().size_bytes() as u8;
2506                let rank = in_shape.rank();
2507                let outer: usize = (0..*axis)
2508                    .map(|i| in_shape.dim(i).unwrap_static())
2509                    .product::<usize>()
2510                    .max(1);
2511                let inner: usize = (*axis + 1..rank)
2512                    .map(|i| in_shape.dim(i).unwrap_static())
2513                    .product::<usize>()
2514                    .max(1);
2515                let in_axis = in_shape.dim(*axis).unwrap_static();
2516                let src_byte_offset =
2517                    node_offset(arena, node.inputs[0]) + start * inner * elem_bytes as usize;
2518                Thunk::Narrow {
2519                    src: src_byte_offset,
2520                    dst: node_offset(arena, node.id),
2521                    outer: outer as u32,
2522                    src_stride: (in_axis * inner) as u32, // elements per outer step in source
2523                    dst_stride: (*len * inner) as u32,    // elements per outer step in dest
2524                    inner: (*len * inner) as u32,         // elements to copy per outer step
2525                    elem_bytes,
2526                }
2527            }
2528
2529            Op::Reshape { .. } | Op::Cast { .. } => {
2530                // Pure layout/dtype change: same total element count, plain copy.
2531                let len = node.shape.num_elements().unwrap();
2532                let src = node_offset(arena, node.inputs[0]);
2533                let dst = node_offset(arena, node.id);
2534                match node.shape.dtype() {
2535                    rlx_ir::DType::F64 => Thunk::CopyF64 {
2536                        src,
2537                        dst,
2538                        len: len as u32,
2539                    },
2540                    _ => Thunk::Copy {
2541                        src,
2542                        dst,
2543                        len: len as u32,
2544                    },
2545                }
2546            }
2547
2548            Op::Quantize {
2549                axis,
2550                scales,
2551                zero_points,
2552            } => {
2553                let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2554                Thunk::Quantize {
2555                    x: node_offset(arena, node.inputs[0]),
2556                    q: node_offset(arena, node.id),
2557                    len: node.shape.num_elements().unwrap() as u32,
2558                    chan_axis: chan_axis as u32,
2559                    chan_dim: chan_dim as u32,
2560                    inner: inner as u32,
2561                    scales: scales.clone(),
2562                    zero_points: zero_points.clone(),
2563                }
2564            }
2565
2566            Op::FakeQuantize {
2567                bits,
2568                axis,
2569                ste,
2570                scale_mode,
2571            } => {
2572                let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2573                let state_off = match scale_mode {
2574                    rlx_ir::op::ScaleMode::PerBatch => None,
2575                    rlx_ir::op::ScaleMode::EMA { .. } | rlx_ir::op::ScaleMode::Fixed => {
2576                        // Second input carries the [chan_dim] scale state.
2577                        debug_assert_eq!(
2578                            node.inputs.len(),
2579                            2,
2580                            "EMA/Fixed FakeQuantize needs a state input"
2581                        );
2582                        Some(node_offset(arena, node.inputs[1]))
2583                    }
2584                };
2585                Thunk::FakeQuantize {
2586                    x: node_offset(arena, node.inputs[0]),
2587                    out: node_offset(arena, node.id),
2588                    len: node.shape.num_elements().unwrap() as u32,
2589                    chan_axis: chan_axis as u32,
2590                    chan_dim: chan_dim as u32,
2591                    inner: inner as u32,
2592                    bits: *bits,
2593                    ste: *ste,
2594                    scale_mode: *scale_mode,
2595                    state_off,
2596                }
2597            }
2598
2599            Op::FakeQuantizeLSQ { bits, axis } => {
2600                let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2601                Thunk::FakeQuantizeLSQ {
2602                    x: node_offset(arena, node.inputs[0]),
2603                    scale_off: node_offset(arena, node.inputs[1]),
2604                    out: node_offset(arena, node.id),
2605                    len: node.shape.num_elements().unwrap() as u32,
2606                    chan_axis: chan_axis as u32,
2607                    chan_dim: chan_dim as u32,
2608                    inner: inner as u32,
2609                    bits: *bits,
2610                }
2611            }
2612
2613            Op::FakeQuantizeLSQBackwardX { bits, axis } => {
2614                let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2615                Thunk::FakeQuantizeLSQBackwardX {
2616                    x: node_offset(arena, node.inputs[0]),
2617                    scale_off: node_offset(arena, node.inputs[1]),
2618                    dy: node_offset(arena, node.inputs[2]),
2619                    dx: node_offset(arena, node.id),
2620                    len: node.shape.num_elements().unwrap() as u32,
2621                    chan_axis: chan_axis as u32,
2622                    chan_dim: chan_dim as u32,
2623                    inner: inner as u32,
2624                    bits: *bits,
2625                }
2626            }
2627
2628            Op::FakeQuantizeLSQBackwardScale { bits, axis } => {
2629                // Output shape is [chan_dim] — node.shape doesn't
2630                // describe the input data layout, but inputs[0] does.
2631                let in_shape = &graph.node(node.inputs[0]).shape;
2632                let (chan_axis, chan_dim, inner) = quant_layout(in_shape, *axis);
2633                Thunk::FakeQuantizeLSQBackwardScale {
2634                    x: node_offset(arena, node.inputs[0]),
2635                    scale_off: node_offset(arena, node.inputs[1]),
2636                    dy: node_offset(arena, node.inputs[2]),
2637                    dscale: node_offset(arena, node.id),
2638                    len: in_shape.num_elements().unwrap() as u32,
2639                    chan_axis: chan_axis as u32,
2640                    chan_dim: chan_dim as u32,
2641                    inner: inner as u32,
2642                    bits: *bits,
2643                }
2644            }
2645
2646            Op::FakeQuantizeBackward { bits, axis, ste } => {
2647                let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2648                Thunk::FakeQuantizeBackward {
2649                    x: node_offset(arena, node.inputs[0]),
2650                    dy: node_offset(arena, node.inputs[1]),
2651                    dx: node_offset(arena, node.id),
2652                    len: node.shape.num_elements().unwrap() as u32,
2653                    chan_axis: chan_axis as u32,
2654                    chan_dim: chan_dim as u32,
2655                    inner: inner as u32,
2656                    bits: *bits,
2657                    ste: *ste,
2658                }
2659            }
2660
2661            Op::Dequantize {
2662                axis,
2663                scales,
2664                zero_points,
2665            } => {
2666                let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2667                Thunk::Dequantize {
2668                    q: node_offset(arena, node.inputs[0]),
2669                    x: node_offset(arena, node.id),
2670                    len: node.shape.num_elements().unwrap() as u32,
2671                    chan_axis: chan_axis as u32,
2672                    chan_dim: chan_dim as u32,
2673                    inner: inner as u32,
2674                    scales: scales.clone(),
2675                    zero_points: zero_points.clone(),
2676                }
2677            }
2678
2679            Op::Expand { .. } => {
2680                // Broadcast: build per-output-dim strides where any input dim
2681                // of size 1 has stride 0 (read the same element repeatedly).
2682                // Reuses the Thunk::Transpose runtime — N-D walk with strides
2683                // is identical; only the strides differ.
2684                let in_shape = &graph.node(node.inputs[0]).shape;
2685                let out_shape = &node.shape;
2686                let in_rank = in_shape.rank();
2687                let out_rank = out_shape.rank();
2688                // Implicit leading 1s if input has lower rank.
2689                let pad = out_rank.saturating_sub(in_rank);
2690                let in_dims: Vec<usize> = (0..out_rank)
2691                    .map(|i| {
2692                        if i < pad {
2693                            1
2694                        } else {
2695                            in_shape.dim(i - pad).unwrap_static()
2696                        }
2697                    })
2698                    .collect();
2699                // Row-major input strides (over the padded shape).
2700                let mut in_strides_full = vec![1usize; out_rank];
2701                for d in (0..out_rank.saturating_sub(1)).rev() {
2702                    in_strides_full[d] = in_strides_full[d + 1] * in_dims[d + 1];
2703                }
2704                let out_dims: Vec<u32> = (0..out_rank)
2705                    .map(|i| out_shape.dim(i).unwrap_static() as u32)
2706                    .collect();
2707                // Stride is 0 for broadcast dims (in_dim == 1 && out_dim > 1).
2708                let in_strides: Vec<u32> = (0..out_rank)
2709                    .map(|i| {
2710                        if in_dims[i] == 1 && (out_dims[i] as usize) > 1 {
2711                            0
2712                        } else {
2713                            in_strides_full[i] as u32
2714                        }
2715                    })
2716                    .collect();
2717                let in_total = in_dims.iter().product::<usize>() as u32;
2718                let src = node_offset(arena, node.inputs[0]);
2719                let dst = node_offset(arena, node.id);
2720                match node.shape.dtype() {
2721                    rlx_ir::DType::F64 => Thunk::TransposeF64 {
2722                        src,
2723                        dst,
2724                        in_total,
2725                        out_dims,
2726                        in_strides,
2727                    },
2728                    _ => Thunk::Transpose {
2729                        src,
2730                        dst,
2731                        in_total,
2732                        out_dims,
2733                        in_strides,
2734                    },
2735                }
2736            }
2737
2738            Op::RmsNorm { eps, .. } => {
2739                let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
2740                let total = node.shape.num_elements().unwrap();
2741                Thunk::RmsNorm {
2742                    src: node_offset(arena, node.inputs[0]),
2743                    g: node_offset(arena, node.inputs[1]),
2744                    b: node_offset(arena, node.inputs[2]),
2745                    dst: node_offset(arena, node.id),
2746                    rows: (total / h) as u32,
2747                    h: h as u32,
2748                    eps: *eps,
2749                }
2750            }
2751
2752            Op::LayerNorm { eps, .. } => {
2753                let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
2754                let total = node.shape.num_elements().unwrap();
2755                Thunk::LayerNorm {
2756                    src: node_offset(arena, node.inputs[0]),
2757                    g: node_offset(arena, node.inputs[1]),
2758                    b: node_offset(arena, node.inputs[2]),
2759                    dst: node_offset(arena, node.id),
2760                    rows: (total / h) as u32,
2761                    h: h as u32,
2762                    eps: *eps,
2763                }
2764            }
2765
2766            Op::GroupNorm { num_groups, eps } => {
2767                let in_shape = &graph.node(node.inputs[0]).shape;
2768                Thunk::GroupNorm {
2769                    src: node_offset(arena, node.inputs[0]),
2770                    g: node_offset(arena, node.inputs[1]),
2771                    b: node_offset(arena, node.inputs[2]),
2772                    dst: node_offset(arena, node.id),
2773                    n: in_shape.dim(0).unwrap_static() as u32,
2774                    c: in_shape.dim(1).unwrap_static() as u32,
2775                    h: in_shape.dim(2).unwrap_static() as u32,
2776                    w: in_shape.dim(3).unwrap_static() as u32,
2777                    num_groups: *num_groups as u32,
2778                    eps: *eps,
2779                }
2780            }
2781
2782            Op::LayerNorm2d { eps } => {
2783                let in_shape = &graph.node(node.inputs[0]).shape;
2784                Thunk::LayerNorm2d {
2785                    src: node_offset(arena, node.inputs[0]),
2786                    g: node_offset(arena, node.inputs[1]),
2787                    b: node_offset(arena, node.inputs[2]),
2788                    dst: node_offset(arena, node.id),
2789                    n: in_shape.dim(0).unwrap_static() as u32,
2790                    c: in_shape.dim(1).unwrap_static() as u32,
2791                    h: in_shape.dim(2).unwrap_static() as u32,
2792                    w: in_shape.dim(3).unwrap_static() as u32,
2793                    eps: *eps,
2794                }
2795            }
2796
2797            Op::ConvTranspose2d {
2798                kernel_size,
2799                stride,
2800                padding,
2801                dilation,
2802                output_padding: _,
2803                groups,
2804            } => {
2805                let in_shape = &graph.node(node.inputs[0]).shape;
2806                let out_shape = &node.shape;
2807                Thunk::ConvTranspose2d {
2808                    src: node_offset(arena, node.inputs[0]),
2809                    weight: node_offset(arena, node.inputs[1]),
2810                    dst: node_offset(arena, node.id),
2811                    n: in_shape.dim(0).unwrap_static() as u32,
2812                    c_in: in_shape.dim(1).unwrap_static() as u32,
2813                    h: in_shape.dim(2).unwrap_static() as u32,
2814                    w_in: in_shape.dim(3).unwrap_static() as u32,
2815                    c_out: out_shape.dim(1).unwrap_static() as u32,
2816                    h_out: out_shape.dim(2).unwrap_static() as u32,
2817                    w_out: out_shape.dim(3).unwrap_static() as u32,
2818                    kh: kernel_size[0] as u32,
2819                    kw: kernel_size[1] as u32,
2820                    sh: stride.first().copied().unwrap_or(1) as u32,
2821                    sw: stride.get(1).copied().unwrap_or(1) as u32,
2822                    ph: padding.first().copied().unwrap_or(0) as u32,
2823                    pw: padding.get(1).copied().unwrap_or(0) as u32,
2824                    dh: dilation.first().copied().unwrap_or(1) as u32,
2825                    dw: dilation.get(1).copied().unwrap_or(1) as u32,
2826                    groups: *groups as u32,
2827                }
2828            }
2829
2830            Op::ResizeNearest2x => {
2831                let in_shape = &graph.node(node.inputs[0]).shape;
2832                Thunk::ResizeNearest2x {
2833                    src: node_offset(arena, node.inputs[0]),
2834                    dst: node_offset(arena, node.id),
2835                    n: in_shape.dim(0).unwrap_static() as u32,
2836                    c: in_shape.dim(1).unwrap_static() as u32,
2837                    h: in_shape.dim(2).unwrap_static() as u32,
2838                    w: in_shape.dim(3).unwrap_static() as u32,
2839                }
2840            }
2841
2842            Op::AxialRope2d {
2843                end_x,
2844                end_y,
2845                head_dim,
2846                num_heads,
2847                theta,
2848                repeat_factor,
2849            } => {
2850                let in_shape = &graph.node(node.inputs[0]).shape;
2851                let batch = in_shape.dim(0).unwrap_static() as u32;
2852                let seq = in_shape.dim(1).unwrap_static() as u32;
2853                let hidden = in_shape.dim(2).unwrap_static() as u32;
2854                Thunk::AxialRope2d {
2855                    src: node_offset(arena, node.inputs[0]),
2856                    dst: node_offset(arena, node.id),
2857                    batch,
2858                    seq,
2859                    hidden,
2860                    end_x: *end_x as u32,
2861                    end_y: *end_y as u32,
2862                    head_dim: *head_dim as u32,
2863                    num_heads: *num_heads as u32,
2864                    theta: *theta,
2865                    repeat_factor: *repeat_factor as u32,
2866                }
2867            }
2868
2869            Op::Softmax { axis } => {
2870                let rank = node.shape.rank();
2871                let ax = if *axis < 0 {
2872                    (rank as i32 + axis) as usize
2873                } else {
2874                    *axis as usize
2875                };
2876                let cols = node.shape.dim(ax).unwrap_static();
2877                let total = node.shape.num_elements().unwrap();
2878                let in_off = node_offset(arena, node.inputs[0]);
2879                let out_off = node_offset(arena, node.id);
2880                // Softmax kernel runs in-place on its data buffer. If the
2881                // planner gave input and output separate slots (their live
2882                // ranges overlap, so no aliasing), the output starts
2883                // uninitialized — emit a Copy first so the data is there.
2884                // Same pattern as Op::Activation.
2885                if in_off != out_off {
2886                    thunks.push(Thunk::Copy {
2887                        src: in_off,
2888                        dst: out_off,
2889                        len: total as u32,
2890                    });
2891                }
2892                Thunk::Softmax {
2893                    data: out_off,
2894                    rows: (total / cols) as u32,
2895                    cols: cols as u32,
2896                }
2897            }
2898
2899            Op::SelectiveScan { state_size } => {
2900                let in_shape = &graph.node(node.inputs[0]).shape;
2901                let (batch, seq, hidden) = (
2902                    in_shape.dim(0).unwrap_static(),
2903                    in_shape.dim(1).unwrap_static(),
2904                    in_shape.dim(2).unwrap_static(),
2905                );
2906                Thunk::SelectiveScan {
2907                    x: node_offset(arena, node.inputs[0]),
2908                    delta: node_offset(arena, node.inputs[1]),
2909                    a: node_offset(arena, node.inputs[2]),
2910                    b: node_offset(arena, node.inputs[3]),
2911                    c: node_offset(arena, node.inputs[4]),
2912                    dst: node_offset(arena, node.id),
2913                    batch: batch as u32,
2914                    seq: seq as u32,
2915                    hidden: hidden as u32,
2916                    state_size: *state_size as u32,
2917                }
2918            }
2919
2920            Op::GatedDeltaNet {
2921                state_size,
2922                carry_state,
2923            } => {
2924                let q_shape = &graph.node(node.inputs[0]).shape;
2925                let (batch, seq, heads) = (
2926                    q_shape.dim(0).unwrap_static(),
2927                    q_shape.dim(1).unwrap_static(),
2928                    q_shape.dim(2).unwrap_static(),
2929                );
2930                let state_off = if *carry_state {
2931                    node_offset(arena, node.inputs[5])
2932                } else {
2933                    0
2934                };
2935                Thunk::GatedDeltaNet {
2936                    q: node_offset(arena, node.inputs[0]),
2937                    k: node_offset(arena, node.inputs[1]),
2938                    v: node_offset(arena, node.inputs[2]),
2939                    g: node_offset(arena, node.inputs[3]),
2940                    beta: node_offset(arena, node.inputs[4]),
2941                    state: state_off,
2942                    dst: node_offset(arena, node.id),
2943                    batch: batch as u32,
2944                    seq: seq as u32,
2945                    heads: heads as u32,
2946                    state_size: *state_size as u32,
2947                }
2948            }
2949
2950            Op::QMatMul {
2951                x_zp,
2952                w_zp,
2953                out_zp,
2954                mult,
2955            } => {
2956                let x_shape = &graph.node(node.inputs[0]).shape;
2957                let w_shape = &graph.node(node.inputs[1]).shape;
2958                let m = x_shape.dim(0).unwrap_static();
2959                let k = x_shape.dim(1).unwrap_static();
2960                let n = w_shape.dim(1).unwrap_static();
2961                Thunk::QMatMul {
2962                    x: node_offset(arena, node.inputs[0]),
2963                    w: node_offset(arena, node.inputs[1]),
2964                    bias: node_offset(arena, node.inputs[2]),
2965                    out: node_offset(arena, node.id),
2966                    m: m as u32,
2967                    k: k as u32,
2968                    n: n as u32,
2969                    x_zp: *x_zp,
2970                    w_zp: *w_zp,
2971                    out_zp: *out_zp,
2972                    mult: *mult,
2973                }
2974            }
2975
2976            Op::QConv2d {
2977                kernel_size,
2978                stride,
2979                padding,
2980                dilation,
2981                groups,
2982                x_zp,
2983                w_zp,
2984                out_zp,
2985                mult,
2986            } => {
2987                let in_shape = &graph.node(node.inputs[0]).shape;
2988                let w_shape = &graph.node(node.inputs[1]).shape;
2989                let out_shape = &node.shape;
2990                if kernel_size.len() == 2
2991                    && in_shape.rank() == 4
2992                    && w_shape.rank() == 4
2993                    && out_shape.rank() == 4
2994                {
2995                    Thunk::QConv2d {
2996                        x: node_offset(arena, node.inputs[0]),
2997                        w: node_offset(arena, node.inputs[1]),
2998                        bias: node_offset(arena, node.inputs[2]),
2999                        out: node_offset(arena, node.id),
3000                        n: in_shape.dim(0).unwrap_static() as u32,
3001                        c_in: in_shape.dim(1).unwrap_static() as u32,
3002                        h: in_shape.dim(2).unwrap_static() as u32,
3003                        w_in: in_shape.dim(3).unwrap_static() as u32,
3004                        c_out: out_shape.dim(1).unwrap_static() as u32,
3005                        h_out: out_shape.dim(2).unwrap_static() as u32,
3006                        w_out: out_shape.dim(3).unwrap_static() as u32,
3007                        kh: kernel_size[0] as u32,
3008                        kw: kernel_size[1] as u32,
3009                        sh: stride.first().copied().unwrap_or(1) as u32,
3010                        sw: stride.get(1).copied().unwrap_or(1) as u32,
3011                        ph: padding.first().copied().unwrap_or(0) as u32,
3012                        pw: padding.get(1).copied().unwrap_or(0) as u32,
3013                        dh: dilation.first().copied().unwrap_or(1) as u32,
3014                        dw: dilation.get(1).copied().unwrap_or(1) as u32,
3015                        groups: *groups as u32,
3016                        x_zp: *x_zp,
3017                        w_zp: *w_zp,
3018                        out_zp: *out_zp,
3019                        mult: *mult,
3020                    }
3021                } else {
3022                    Thunk::Nop
3023                }
3024            }
3025
3026            Op::DequantMatMul { scheme } => {
3027                use rlx_ir::quant::QuantScheme;
3028                let n = node.shape.dim(node.shape.rank() - 1).unwrap_static();
3029                let total = node.shape.num_elements().unwrap();
3030                let m = total / n.max(1);
3031                let x_total = graph.node(node.inputs[0]).shape.num_elements().unwrap();
3032                let k = x_total / m.max(1);
3033                if scheme.is_gguf() {
3034                    Thunk::DequantMatMulGguf {
3035                        x: node_offset(arena, node.inputs[0]),
3036                        w_q: node_offset(arena, node.inputs[1]),
3037                        dst: node_offset(arena, node.id),
3038                        m: m as u32,
3039                        k: k as u32,
3040                        n: n as u32,
3041                        scheme: *scheme,
3042                    }
3043                } else {
3044                    match scheme {
3045                        QuantScheme::Nvfp4Block => Thunk::DequantMatMulNvfp4 {
3046                            x: node_offset(arena, node.inputs[0]),
3047                            w_q: node_offset(arena, node.inputs[1]),
3048                            scale: node_offset(arena, node.inputs[2]),
3049                            global_scale: node_offset(arena, node.inputs[3]),
3050                            dst: node_offset(arena, node.id),
3051                            m: m as u32,
3052                            k: k as u32,
3053                            n: n as u32,
3054                        },
3055                        QuantScheme::Int4Block { block_size } => Thunk::DequantMatMulInt4 {
3056                            x: node_offset(arena, node.inputs[0]),
3057                            w_q: node_offset(arena, node.inputs[1]),
3058                            scale: node_offset(arena, node.inputs[2]),
3059                            zp: node_offset(arena, node.inputs[3]),
3060                            dst: node_offset(arena, node.id),
3061                            m: m as u32,
3062                            k: k as u32,
3063                            n: n as u32,
3064                            block_size: *block_size,
3065                            is_asymmetric: false,
3066                        },
3067                        QuantScheme::Fp8E4m3 => Thunk::DequantMatMulFp8 {
3068                            x: node_offset(arena, node.inputs[0]),
3069                            w_q: node_offset(arena, node.inputs[1]),
3070                            scale: node_offset(arena, node.inputs[2]),
3071                            dst: node_offset(arena, node.id),
3072                            m: m as u32,
3073                            k: k as u32,
3074                            n: n as u32,
3075                            e5m2: false,
3076                        },
3077                        QuantScheme::Fp8E5m2 => Thunk::DequantMatMulFp8 {
3078                            x: node_offset(arena, node.inputs[0]),
3079                            w_q: node_offset(arena, node.inputs[1]),
3080                            scale: node_offset(arena, node.inputs[2]),
3081                            dst: node_offset(arena, node.id),
3082                            m: m as u32,
3083                            k: k as u32,
3084                            n: n as u32,
3085                            e5m2: true,
3086                        },
3087                        QuantScheme::Int8Block { block_size } => Thunk::DequantMatMul {
3088                            x: node_offset(arena, node.inputs[0]),
3089                            w_q: node_offset(arena, node.inputs[1]),
3090                            scale: node_offset(arena, node.inputs[2]),
3091                            zp: node_offset(arena, node.inputs[3]),
3092                            dst: node_offset(arena, node.id),
3093                            m: m as u32,
3094                            k: k as u32,
3095                            n: n as u32,
3096                            block_size: *block_size,
3097                            is_asymmetric: false,
3098                        },
3099                        QuantScheme::Int8BlockAsym { block_size } => Thunk::DequantMatMul {
3100                            x: node_offset(arena, node.inputs[0]),
3101                            w_q: node_offset(arena, node.inputs[1]),
3102                            scale: node_offset(arena, node.inputs[2]),
3103                            zp: node_offset(arena, node.inputs[3]),
3104                            dst: node_offset(arena, node.id),
3105                            m: m as u32,
3106                            k: k as u32,
3107                            n: n as u32,
3108                            block_size: *block_size,
3109                            is_asymmetric: true,
3110                        },
3111                        other => panic!(
3112                            "DequantMatMul on CPU supports Int8/Int4/FP8/NVFP4 legacy or GGUF schemes; got {other}"
3113                        ),
3114                    }
3115                }
3116            }
3117
3118            Op::LoraMatMul { scale } => {
3119                // x [m, k], w [k, n], a [k, r], b [r, n].
3120                let n = node.shape.dim(node.shape.rank() - 1).unwrap_static();
3121                let total = node.shape.num_elements().unwrap();
3122                let m = total / n.max(1);
3123                let x_total = graph.node(node.inputs[0]).shape.num_elements().unwrap();
3124                let k = x_total / m.max(1);
3125                let a_total = graph.node(node.inputs[2]).shape.num_elements().unwrap();
3126                let r = a_total / k.max(1);
3127                Thunk::LoraMatMul {
3128                    x: node_offset(arena, node.inputs[0]),
3129                    w: node_offset(arena, node.inputs[1]),
3130                    a: node_offset(arena, node.inputs[2]),
3131                    b: node_offset(arena, node.inputs[3]),
3132                    dst: node_offset(arena, node.id),
3133                    m: m as u32,
3134                    k: k as u32,
3135                    n: n as u32,
3136                    r: r as u32,
3137                    scale: *scale,
3138                }
3139            }
3140
3141            Op::Sample {
3142                top_k,
3143                top_p,
3144                temperature,
3145                seed,
3146            } => {
3147                let in_shape = &graph.node(node.inputs[0]).shape;
3148                // Logits are [batch, vocab] (or [vocab] → batch=1).
3149                let (batch, vocab) = if in_shape.rank() >= 2 {
3150                    (
3151                        in_shape.dim(0).unwrap_static(),
3152                        in_shape.dim(in_shape.rank() - 1).unwrap_static(),
3153                    )
3154                } else {
3155                    (1, in_shape.num_elements().unwrap_or(0))
3156                };
3157                Thunk::Sample {
3158                    logits: node_offset(arena, node.inputs[0]),
3159                    dst: node_offset(arena, node.id),
3160                    batch: batch as u32,
3161                    vocab: vocab as u32,
3162                    top_k: *top_k as u32,
3163                    top_p: *top_p,
3164                    temperature: *temperature,
3165                    seed: *seed,
3166                }
3167            }
3168
3169            Op::Cumsum { axis, exclusive } => {
3170                // For now CPU only supports last-axis cumsum (the
3171                // common case for sampling / ragged offsets).
3172                // Other axes can lower via Transpose → Cumsum →
3173                // Transpose; not on the hot path today.
3174                let rank = node.shape.rank();
3175                let ax = if *axis < 0 {
3176                    (rank as i32 + axis) as usize
3177                } else {
3178                    *axis as usize
3179                };
3180                assert_eq!(
3181                    ax,
3182                    rank - 1,
3183                    "Cumsum only supports the last axis on CPU today"
3184                );
3185                let cols = node.shape.dim(ax).unwrap_static();
3186                let total = node.shape.num_elements().unwrap();
3187                Thunk::Cumsum {
3188                    src: node_offset(arena, node.inputs[0]),
3189                    dst: node_offset(arena, node.id),
3190                    rows: (total / cols) as u32,
3191                    cols: cols as u32,
3192                    exclusive: *exclusive,
3193                }
3194            }
3195
3196            Op::Attention {
3197                num_heads,
3198                head_dim,
3199                mask_kind,
3200            } => {
3201                // Layout dispatch: rank-4 input could be either
3202                // `[B, S, H, D]` (CPU's historical convention) or
3203                // `[B, H, S, D]` (the convention the GPU/TPU backends
3204                // share). Disambiguate by which axis matches
3205                // `num_heads`. Rank-3 is always `[B, S, H*D]`.
3206                let q_shape = &graph.node(node.inputs[0]).shape;
3207                let k_shape = &graph.node(node.inputs[1]).shape;
3208                let rank = q_shape.rank();
3209                let (batch, seq, kv_seq, bhsd) = if rank == 4 {
3210                    let d1 = q_shape.dim(1).unwrap_static();
3211                    let d2 = q_shape.dim(2).unwrap_static();
3212                    if d1 == *num_heads {
3213                        // [B, H, S, D]
3214                        (
3215                            q_shape.dim(0).unwrap_static(),
3216                            d2,
3217                            k_shape.dim(2).unwrap_static(),
3218                            true,
3219                        )
3220                    } else {
3221                        // [B, S, H, D]
3222                        (
3223                            q_shape.dim(0).unwrap_static(),
3224                            d1,
3225                            k_shape.dim(1).unwrap_static(),
3226                            false,
3227                        )
3228                    }
3229                } else if rank >= 3 {
3230                    (
3231                        q_shape.dim(0).unwrap_static(),
3232                        q_shape.dim(1).unwrap_static(),
3233                        k_shape.dim(1).unwrap_static(),
3234                        false,
3235                    )
3236                } else {
3237                    (
3238                        1,
3239                        q_shape.dim(0).unwrap_static(),
3240                        k_shape.dim(0).unwrap_static(),
3241                        false,
3242                    )
3243                };
3244                let mask_off = if matches!(
3245                    mask_kind,
3246                    rlx_ir::op::MaskKind::Custom | rlx_ir::op::MaskKind::Bias
3247                ) {
3248                    node_offset(arena, node.inputs[3])
3249                } else {
3250                    0
3251                };
3252                let hs = (*num_heads * *head_dim) as u32;
3253                Thunk::Attention {
3254                    q: node_offset(arena, node.inputs[0]),
3255                    k: node_offset(arena, node.inputs[1]),
3256                    v: node_offset(arena, node.inputs[2]),
3257                    mask: mask_off,
3258                    out: node_offset(arena, node.id),
3259                    batch: batch as u32,
3260                    seq: seq as u32,
3261                    kv_seq: kv_seq as u32,
3262                    heads: *num_heads as u32,
3263                    head_dim: *head_dim as u32,
3264                    mask_kind: *mask_kind,
3265                    // Defaults: each input is its own contiguous buffer
3266                    // with row stride = hidden. Rewritten by the
3267                    // Narrow→Attention fusion when applicable.
3268                    q_row_stride: hs,
3269                    k_row_stride: hs,
3270                    v_row_stride: hs,
3271                    bhsd,
3272                }
3273            }
3274
3275            Op::AttentionBackward {
3276                num_heads,
3277                head_dim,
3278                mask_kind,
3279                wrt,
3280            } => {
3281                let q_shape = &graph.node(node.inputs[0]).shape;
3282                let k_shape = &graph.node(node.inputs[1]).shape;
3283                let rank = q_shape.rank();
3284                let (batch, seq, kv_seq, bhsd) = if rank == 4 {
3285                    let d1 = q_shape.dim(1).unwrap_static();
3286                    let d2 = q_shape.dim(2).unwrap_static();
3287                    if d1 == *num_heads {
3288                        (
3289                            q_shape.dim(0).unwrap_static(),
3290                            d2,
3291                            k_shape.dim(2).unwrap_static(),
3292                            true,
3293                        )
3294                    } else {
3295                        (
3296                            q_shape.dim(0).unwrap_static(),
3297                            d1,
3298                            k_shape.dim(1).unwrap_static(),
3299                            false,
3300                        )
3301                    }
3302                } else if rank >= 3 {
3303                    (
3304                        q_shape.dim(0).unwrap_static(),
3305                        q_shape.dim(1).unwrap_static(),
3306                        k_shape.dim(1).unwrap_static(),
3307                        false,
3308                    )
3309                } else {
3310                    (
3311                        1,
3312                        q_shape.dim(0).unwrap_static(),
3313                        k_shape.dim(0).unwrap_static(),
3314                        false,
3315                    )
3316                };
3317                let mask_off = if matches!(
3318                    mask_kind,
3319                    rlx_ir::op::MaskKind::Custom | rlx_ir::op::MaskKind::Bias
3320                ) {
3321                    node_offset(arena, node.inputs[4])
3322                } else {
3323                    0
3324                };
3325                Thunk::AttentionBackward {
3326                    q: node_offset(arena, node.inputs[0]),
3327                    k: node_offset(arena, node.inputs[1]),
3328                    v: node_offset(arena, node.inputs[2]),
3329                    dy: node_offset(arena, node.inputs[3]),
3330                    mask: mask_off,
3331                    out: node_offset(arena, node.id),
3332                    batch: batch as u32,
3333                    seq: seq as u32,
3334                    kv_seq: kv_seq as u32,
3335                    heads: *num_heads as u32,
3336                    head_dim: *head_dim as u32,
3337                    mask_kind: *mask_kind,
3338                    wrt: *wrt,
3339                    bhsd,
3340                }
3341            }
3342
3343            Op::FusedAttentionBlock {
3344                num_heads,
3345                head_dim,
3346                has_bias,
3347                has_rope,
3348            } => {
3349                let x_shape = &graph.node(node.inputs[0]).shape;
3350                let (batch, seq) = if x_shape.rank() >= 3 {
3351                    (
3352                        x_shape.dim(0).unwrap_static(),
3353                        x_shape.dim(1).unwrap_static(),
3354                    )
3355                } else {
3356                    let total = x_shape.num_elements().unwrap();
3357                    let s = x_shape.dim(x_shape.rank() - 2).unwrap_static();
3358                    (total / (s * num_heads * head_dim), s)
3359                };
3360                let hs = (*num_heads * *head_dim) as u32;
3361                // Inputs: hidden, qkv_w, out_w, mask, [qkv_b, out_b], [cos, sin]
3362                let mut idx = 4;
3363                let (qkv_b_off, out_b_off) = if *has_bias {
3364                    let qb = node_offset(arena, node.inputs[idx]);
3365                    let ob = node_offset(arena, node.inputs[idx + 1]);
3366                    idx += 2;
3367                    (qb, ob)
3368                } else {
3369                    (0, 0)
3370                };
3371                let (cos_off, sin_off, cl) = if *has_rope {
3372                    let c = node_offset(arena, node.inputs[idx]);
3373                    let s = node_offset(arena, node.inputs[idx + 1]);
3374                    let clen = get_len(graph, node.inputs[idx]);
3375                    (c, s, clen as u32)
3376                } else {
3377                    (0, 0, 0)
3378                };
3379
3380                Thunk::FusedAttnBlock {
3381                    hidden: node_offset(arena, node.inputs[0]),
3382                    qkv_w: node_offset(arena, node.inputs[1]),
3383                    out_w: node_offset(arena, node.inputs[2]),
3384                    mask: node_offset(arena, node.inputs[3]),
3385                    out: node_offset(arena, node.id),
3386                    qkv_b: qkv_b_off,
3387                    out_b: out_b_off,
3388                    cos: cos_off,
3389                    sin: sin_off,
3390                    cos_len: cl,
3391                    batch: batch as u32,
3392                    seq: seq as u32,
3393                    hs,
3394                    nh: *num_heads as u32,
3395                    dh: *head_dim as u32,
3396                    has_bias: *has_bias,
3397                    has_rope: *has_rope,
3398                }
3399            }
3400
3401            Op::Rope { head_dim, n_rot } => {
3402                let x_shape = &graph.node(node.inputs[0]).shape;
3403                let (batch, seq, hidden) = if x_shape.rank() >= 3 {
3404                    (
3405                        x_shape.dim(0).unwrap_static(),
3406                        x_shape.dim(1).unwrap_static(),
3407                        x_shape.dim(2).unwrap_static(),
3408                    )
3409                } else {
3410                    let total = x_shape.num_elements().unwrap();
3411                    (
3412                        1,
3413                        x_shape.dim(0).unwrap_static(),
3414                        total / x_shape.dim(0).unwrap_static(),
3415                    )
3416                };
3417                let cos_len = get_len(graph, node.inputs[1]);
3418                Thunk::Rope {
3419                    src: node_offset(arena, node.inputs[0]),
3420                    cos: node_offset(arena, node.inputs[1]),
3421                    sin: node_offset(arena, node.inputs[2]),
3422                    dst: node_offset(arena, node.id),
3423                    batch: batch as u32,
3424                    seq: seq as u32,
3425                    hidden: hidden as u32,
3426                    head_dim: *head_dim as u32,
3427                    n_rot: *n_rot as u32,
3428                    cos_len: cos_len as u32,
3429                    // Default: source rows are tightly packed (rewritten
3430                    // by the Narrow→Rope fusion pass below if Rope ends
3431                    // up reading from a wider parent like QKV).
3432                    src_row_stride: hidden as u32,
3433                }
3434            }
3435
3436            Op::FusedSwiGLU {
3437                cast_to: _,
3438                gate_first,
3439            } => {
3440                let n_half = node.shape.dim(node.shape.rank() - 1).unwrap_static();
3441                let total = node.shape.num_elements().unwrap();
3442                Thunk::FusedSwiGLU {
3443                    src: node_offset(arena, node.inputs[0]),
3444                    dst: node_offset(arena, node.id),
3445                    n_half: n_half as u32,
3446                    total: total as u32,
3447                    gate_first: *gate_first,
3448                }
3449            }
3450
3451            Op::Conv {
3452                kernel_size,
3453                stride,
3454                padding,
3455                dilation,
3456                groups,
3457            } => {
3458                let in_shape = &graph.node(node.inputs[0]).shape;
3459                let w_shape = &graph.node(node.inputs[1]).shape;
3460                let out_shape = &node.shape;
3461                // 1×1 fast path (plan #26): kH=kW=1, stride=1,
3462                // padding=0, dilation=1, groups=1. Emits a single
3463                // Conv2D1x1 thunk that BLAS-dispatches per batch.
3464                let is_1x1_simple = kernel_size.len() == 2
3465                    && kernel_size[0] == 1
3466                    && kernel_size[1] == 1
3467                    && stride.iter().all(|&s| s == 1)
3468                    && padding.iter().all(|&p| p == 0)
3469                    && dilation.iter().all(|&d| d == 1)
3470                    && *groups == 1;
3471                if is_1x1_simple && in_shape.rank() == 4 && out_shape.rank() == 4 {
3472                    let n = in_shape.dim(0).unwrap_static();
3473                    let c_in = in_shape.dim(1).unwrap_static();
3474                    let c_out = out_shape.dim(1).unwrap_static();
3475                    let h = in_shape.dim(2).unwrap_static();
3476                    let w = in_shape.dim(3).unwrap_static();
3477                    Thunk::Conv2D1x1 {
3478                        src: node_offset(arena, node.inputs[0]),
3479                        weight: node_offset(arena, node.inputs[1]),
3480                        dst: node_offset(arena, node.id),
3481                        n: n as u32,
3482                        c_in: c_in as u32,
3483                        c_out: c_out as u32,
3484                        hw: (h * w) as u32,
3485                    }
3486                } else if kernel_size.len() == 2
3487                    && in_shape.rank() == 4
3488                    && w_shape.rank() == 4
3489                    && out_shape.rank() == 4
3490                {
3491                    Thunk::Conv2D {
3492                        src: node_offset(arena, node.inputs[0]),
3493                        weight: node_offset(arena, node.inputs[1]),
3494                        dst: node_offset(arena, node.id),
3495                        n: in_shape.dim(0).unwrap_static() as u32,
3496                        c_in: in_shape.dim(1).unwrap_static() as u32,
3497                        h: in_shape.dim(2).unwrap_static() as u32,
3498                        w: in_shape.dim(3).unwrap_static() as u32,
3499                        c_out: out_shape.dim(1).unwrap_static() as u32,
3500                        h_out: out_shape.dim(2).unwrap_static() as u32,
3501                        w_out: out_shape.dim(3).unwrap_static() as u32,
3502                        kh: kernel_size[0] as u32,
3503                        kw: kernel_size[1] as u32,
3504                        sh: stride.first().copied().unwrap_or(1) as u32,
3505                        sw: stride.get(1).copied().unwrap_or(1) as u32,
3506                        ph: padding.first().copied().unwrap_or(0) as u32,
3507                        pw: padding.get(1).copied().unwrap_or(0) as u32,
3508                        dh: dilation.first().copied().unwrap_or(1) as u32,
3509                        dw: dilation.get(1).copied().unwrap_or(1) as u32,
3510                        groups: *groups as u32,
3511                    }
3512                } else {
3513                    Thunk::Nop
3514                }
3515            }
3516
3517            Op::Pool {
3518                kind,
3519                kernel_size,
3520                stride,
3521                padding,
3522            } => {
3523                // Currently support 2D pooling on rank-4 NCHW tensors.
3524                let in_shape = &graph.node(node.inputs[0]).shape;
3525                let out_shape = &node.shape;
3526                if kernel_size.len() == 2 && in_shape.rank() == 4 && out_shape.rank() == 4 {
3527                    Thunk::Pool2D {
3528                        src: node_offset(arena, node.inputs[0]),
3529                        dst: node_offset(arena, node.id),
3530                        n: in_shape.dim(0).unwrap_static() as u32,
3531                        c: in_shape.dim(1).unwrap_static() as u32,
3532                        h: in_shape.dim(2).unwrap_static() as u32,
3533                        w: in_shape.dim(3).unwrap_static() as u32,
3534                        h_out: out_shape.dim(2).unwrap_static() as u32,
3535                        w_out: out_shape.dim(3).unwrap_static() as u32,
3536                        kh: kernel_size[0] as u32,
3537                        kw: kernel_size[1] as u32,
3538                        sh: stride.first().copied().unwrap_or(1) as u32,
3539                        sw: stride.get(1).copied().unwrap_or(1) as u32,
3540                        ph: padding.first().copied().unwrap_or(0) as u32,
3541                        pw: padding.get(1).copied().unwrap_or(0) as u32,
3542                        kind: *kind,
3543                    }
3544                } else {
3545                    Thunk::Nop
3546                }
3547            }
3548
3549            Op::Transpose { perm } => {
3550                // Pre-compute (out_dims, in_strides_for_each_out_dim) so the
3551                // runtime loop is just an N-D index walk + scatter.
3552                let in_shape = &graph.node(node.inputs[0]).shape;
3553                let in_rank = in_shape.rank();
3554                let in_dims: Vec<usize> = (0..in_rank)
3555                    .map(|i| in_shape.dim(i).unwrap_static())
3556                    .collect();
3557                // Row-major input strides: stride[d] = product of dims[d+1..].
3558                let mut in_strides_full = vec![1usize; in_rank];
3559                for d in (0..in_rank.saturating_sub(1)).rev() {
3560                    in_strides_full[d] = in_strides_full[d + 1] * in_dims[d + 1];
3561                }
3562                let out_dims: Vec<u32> = perm.iter().map(|&p| in_dims[p] as u32).collect();
3563                let in_strides: Vec<u32> =
3564                    perm.iter().map(|&p| in_strides_full[p] as u32).collect();
3565                let in_total = in_dims.iter().product::<usize>() as u32;
3566                let src = node_offset(arena, node.inputs[0]);
3567                let dst = node_offset(arena, node.id);
3568                match node.shape.dtype() {
3569                    rlx_ir::DType::F64 => Thunk::TransposeF64 {
3570                        src,
3571                        dst,
3572                        in_total,
3573                        out_dims,
3574                        in_strides,
3575                    },
3576                    _ => Thunk::Transpose {
3577                        src,
3578                        dst,
3579                        in_total,
3580                        out_dims,
3581                        in_strides,
3582                    },
3583                }
3584            }
3585
3586            Op::ScatterAdd => {
3587                // updates: [num_updates, ...trailing], indices: [num_updates],
3588                // output: [out_dim, ...trailing]
3589                let upd_shape = &graph.node(node.inputs[0]).shape;
3590                let out_shape = &node.shape;
3591                let num_updates = upd_shape.dim(0).unwrap_static();
3592                let out_dim = out_shape.dim(0).unwrap_static();
3593                let trailing: usize = (1..out_shape.rank())
3594                    .map(|i| out_shape.dim(i).unwrap_static())
3595                    .product::<usize>()
3596                    .max(1);
3597                Thunk::ScatterAdd {
3598                    updates: node_offset(arena, node.inputs[0]),
3599                    indices: node_offset(arena, node.inputs[1]),
3600                    dst: node_offset(arena, node.id),
3601                    num_updates: num_updates as u32,
3602                    out_dim: out_dim as u32,
3603                    trailing: trailing as u32,
3604                }
3605            }
3606
3607            Op::GroupedMatMul => {
3608                // Inputs: [input(M, K), weight(E, K, N), expert_idx(M)]
3609                let in_shape = &graph.node(node.inputs[0]).shape;
3610                let w_shape = &graph.node(node.inputs[1]).shape;
3611                let m = in_shape.dim(in_shape.rank() - 2).unwrap_static();
3612                let k_dim = in_shape.dim(in_shape.rank() - 1).unwrap_static();
3613                let num_experts = w_shape.dim(0).unwrap_static();
3614                let n = w_shape.dim(2).unwrap_static();
3615                Thunk::GroupedMatMul {
3616                    input: node_offset(arena, node.inputs[0]),
3617                    weight: node_offset(arena, node.inputs[1]),
3618                    expert_idx: node_offset(arena, node.inputs[2]),
3619                    dst: node_offset(arena, node.id),
3620                    m: m as u32,
3621                    k_dim: k_dim as u32,
3622                    n: n as u32,
3623                    num_experts: num_experts as u32,
3624                }
3625            }
3626
3627            Op::DequantGroupedMatMul { scheme } => {
3628                let in_shape = &graph.node(node.inputs[0]).shape;
3629                let w_shape = &graph.node(node.inputs[1]).shape;
3630                let m = in_shape.dim(in_shape.rank() - 2).unwrap_static();
3631                let k_dim = in_shape.dim(in_shape.rank() - 1).unwrap_static();
3632                let out_shape = &node.shape;
3633                let n = out_shape.dim(out_shape.rank() - 1).unwrap_static();
3634                let block_elems = scheme.gguf_block_size() as usize;
3635                let block_bytes = scheme.gguf_block_bytes() as usize;
3636                let slab_bytes = (k_dim * n) / block_elems * block_bytes;
3637                let total_bytes = w_shape.num_elements().unwrap();
3638                let num_experts = total_bytes / slab_bytes.max(1);
3639                Thunk::DequantGroupedMatMulGguf {
3640                    input: node_offset(arena, node.inputs[0]),
3641                    w_q: node_offset(arena, node.inputs[1]),
3642                    expert_idx: node_offset(arena, node.inputs[2]),
3643                    dst: node_offset(arena, node.id),
3644                    m: m as u32,
3645                    k_dim: k_dim as u32,
3646                    n: n as u32,
3647                    num_experts: num_experts as u32,
3648                    scheme: *scheme,
3649                }
3650            }
3651
3652            Op::DequantMoEWeights { scheme } => {
3653                let w_shape = &graph.node(node.inputs[0]).shape;
3654                let out_shape = &node.shape;
3655                let num_experts = out_shape.dim(0).unwrap_static();
3656                let k_dim = out_shape.dim(1).unwrap_static();
3657                let n = out_shape.dim(2).unwrap_static();
3658                let block_elems = scheme.gguf_block_size() as usize;
3659                let block_bytes = scheme.gguf_block_bytes() as usize;
3660                let slab_bytes = (k_dim * n) / block_elems * block_bytes;
3661                let total_bytes = w_shape.num_elements().unwrap();
3662                assert_eq!(
3663                    total_bytes,
3664                    num_experts * slab_bytes,
3665                    "DequantMoEWeights packed bytes mismatch"
3666                );
3667                Thunk::DequantMoEWeightsGguf {
3668                    w_q: node_offset(arena, node.inputs[0]),
3669                    dst: node_offset(arena, node.id),
3670                    k_dim: k_dim as u32,
3671                    n: n as u32,
3672                    num_experts: num_experts as u32,
3673                    scheme: *scheme,
3674                }
3675            }
3676
3677            Op::TopK { k } => {
3678                let in_shape = &graph.node(node.inputs[0]).shape;
3679                let rank = in_shape.rank();
3680                let axis_dim = in_shape.dim(rank - 1).unwrap_static();
3681                let outer = in_shape.num_elements().unwrap() / axis_dim;
3682                Thunk::TopK {
3683                    src: node_offset(arena, node.inputs[0]),
3684                    dst: node_offset(arena, node.id),
3685                    outer: outer as u32,
3686                    axis_dim: axis_dim as u32,
3687                    k: *k as u32,
3688                }
3689            }
3690
3691            Op::Reduce {
3692                op,
3693                axes,
3694                keep_dim: _,
3695            } => {
3696                // Decompose the input shape into [outer, reduced, inner]
3697                // around the reduced axis range. Non-contiguous reduced
3698                // axes aren't supported here — caller must transpose them
3699                // contiguous first (the coverage tool would surface the
3700                // gap if a model needs it).
3701                let in_shape = &graph.node(node.inputs[0]).shape;
3702                let rank = in_shape.rank();
3703                let mut sorted = axes.clone();
3704                sorted.sort();
3705                sorted.dedup();
3706                let contiguous = sorted.windows(2).all(|w| w[1] == w[0] + 1)
3707                    && !sorted.is_empty()
3708                    && *sorted.last().unwrap() < rank;
3709                if !contiguous {
3710                    Thunk::Nop
3711                } else {
3712                    let first = sorted[0];
3713                    let last = *sorted.last().unwrap();
3714                    let outer: usize = (0..first)
3715                        .map(|i| in_shape.dim(i).unwrap_static())
3716                        .product::<usize>()
3717                        .max(1);
3718                    let reduced: usize = (first..=last)
3719                        .map(|i| in_shape.dim(i).unwrap_static())
3720                        .product();
3721                    let inner: usize = (last + 1..rank)
3722                        .map(|i| in_shape.dim(i).unwrap_static())
3723                        .product::<usize>()
3724                        .max(1);
3725                    let src = node_offset(arena, node.inputs[0]);
3726                    let dst = node_offset(arena, node.id);
3727                    if node.shape.dtype() == rlx_ir::DType::F64 && matches!(op, ReduceOp::Sum) {
3728                        Thunk::ReduceSumF64 {
3729                            src,
3730                            dst,
3731                            outer: outer as u32,
3732                            reduced: reduced as u32,
3733                            inner: inner as u32,
3734                        }
3735                    } else {
3736                        Thunk::Reduce {
3737                            src,
3738                            dst,
3739                            outer: outer as u32,
3740                            reduced: reduced as u32,
3741                            inner: inner as u32,
3742                            op: *op,
3743                        }
3744                    }
3745                }
3746            }
3747
3748            Op::Compare(cmp) => {
3749                let len = node.shape.num_elements().unwrap();
3750                Thunk::Compare {
3751                    lhs: node_offset(arena, node.inputs[0]),
3752                    rhs: node_offset(arena, node.inputs[1]),
3753                    dst: node_offset(arena, node.id),
3754                    len: len as u32,
3755                    op: *cmp,
3756                }
3757            }
3758
3759            Op::Where => {
3760                let len = node.shape.num_elements().unwrap();
3761                Thunk::Where {
3762                    cond: node_offset(arena, node.inputs[0]),
3763                    on_true: node_offset(arena, node.inputs[1]),
3764                    on_false: node_offset(arena, node.inputs[2]),
3765                    dst: node_offset(arena, node.id),
3766                    len: len as u32,
3767                }
3768            }
3769
3770            Op::ReluBackward => {
3771                let len: usize = (0..node.shape.rank())
3772                    .map(|i| node.shape.dim(i).unwrap_static())
3773                    .product();
3774                let x = node_offset(arena, node.inputs[0]);
3775                let dy = node_offset(arena, node.inputs[1]);
3776                let dx = node_offset(arena, node.id);
3777                match node.shape.dtype() {
3778                    rlx_ir::DType::F64 => Thunk::ReluBackwardF64 {
3779                        x,
3780                        dy,
3781                        dx,
3782                        len: len as u32,
3783                    },
3784                    _ => Thunk::ReluBackward {
3785                        x,
3786                        dy,
3787                        dx,
3788                        len: len as u32,
3789                    },
3790                }
3791            }
3792
3793            Op::ComplexNormSq => {
3794                let len: usize = (0..node.shape.rank())
3795                    .map(|i| node.shape.dim(i).unwrap_static())
3796                    .product();
3797                let src = node_offset(arena, node.inputs[0]);
3798                let dst = node_offset(arena, node.id);
3799                Thunk::ComplexNormSqF32 {
3800                    src,
3801                    dst,
3802                    len: len as u32,
3803                }
3804            }
3805
3806            Op::ComplexNormSqBackward => {
3807                let len: usize = (0..node.shape.rank())
3808                    .map(|i| node.shape.dim(i).unwrap_static())
3809                    .product();
3810                let z = node_offset(arena, node.inputs[0]);
3811                let g = node_offset(arena, node.inputs[1]);
3812                let dz = node_offset(arena, node.id);
3813                Thunk::ComplexNormSqBackwardF32 {
3814                    z,
3815                    g,
3816                    dz,
3817                    len: len as u32,
3818                }
3819            }
3820
3821            Op::Conjugate => {
3822                let len: usize = (0..node.shape.rank())
3823                    .map(|i| node.shape.dim(i).unwrap_static())
3824                    .product();
3825                Thunk::ConjugateC64 {
3826                    src: node_offset(arena, node.inputs[0]),
3827                    dst: node_offset(arena, node.id),
3828                    len: len as u32,
3829                }
3830            }
3831
3832            Op::ActivationBackward { kind } => {
3833                let len: usize = (0..node.shape.rank())
3834                    .map(|i| node.shape.dim(i).unwrap_static())
3835                    .product();
3836                let x = node_offset(arena, node.inputs[0]);
3837                let dy = node_offset(arena, node.inputs[1]);
3838                let dx = node_offset(arena, node.id);
3839                match node.shape.dtype() {
3840                    rlx_ir::DType::F64 => Thunk::ActivationBackwardF64 {
3841                        x,
3842                        dy,
3843                        dx,
3844                        len: len as u32,
3845                        kind: *kind,
3846                    },
3847                    _ => Thunk::ActivationBackward {
3848                        x,
3849                        dy,
3850                        dx,
3851                        len: len as u32,
3852                        kind: *kind,
3853                    },
3854                }
3855            }
3856
3857            Op::LayerNormBackwardInput { eps, .. } => {
3858                // axis = -1 only (matches forward LayerNorm thunk).
3859                let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
3860                let total = node.shape.num_elements().unwrap();
3861                Thunk::LayerNormBackwardInput {
3862                    x: node_offset(arena, node.inputs[0]),
3863                    gamma: node_offset(arena, node.inputs[1]),
3864                    dy: node_offset(arena, node.inputs[2]),
3865                    dx: node_offset(arena, node.id),
3866                    rows: (total / h) as u32,
3867                    h: h as u32,
3868                    eps: *eps,
3869                }
3870            }
3871
3872            Op::LayerNormBackwardGamma { eps, .. } => {
3873                let x_shape = &graph.node(node.inputs[0]).shape;
3874                let h = x_shape.dim(x_shape.rank() - 1).unwrap_static();
3875                let x_total = x_shape.num_elements().unwrap();
3876                Thunk::LayerNormBackwardGamma {
3877                    x: node_offset(arena, node.inputs[0]),
3878                    dy: node_offset(arena, node.inputs[1]),
3879                    dgamma: node_offset(arena, node.id),
3880                    rows: (x_total / h) as u32,
3881                    h: h as u32,
3882                    eps: *eps,
3883                }
3884            }
3885
3886            Op::RmsNormBackwardInput { eps, .. }
3887            | Op::RmsNormBackwardGamma { eps, .. }
3888            | Op::RmsNormBackwardBeta { eps, .. } => {
3889                let x_shape = &graph.node(node.inputs[0]).shape;
3890                let h = x_shape.dim(x_shape.rank() - 1).unwrap_static();
3891                let rows = (x_shape.num_elements().unwrap() / h) as u32;
3892                let off = |i: usize| node_offset(arena, node.inputs[i]);
3893                let common = (off(0), off(1), off(2), off(3), rows, h as u32, *eps);
3894                match &node.op {
3895                    Op::RmsNormBackwardInput { .. } => Thunk::RmsNormBackwardInput {
3896                        x: common.0,
3897                        gamma: common.1,
3898                        beta: common.2,
3899                        dy: common.3,
3900                        dx: node_offset(arena, node.id),
3901                        rows: common.4,
3902                        h: common.5,
3903                        eps: common.6,
3904                    },
3905                    Op::RmsNormBackwardGamma { .. } => Thunk::RmsNormBackwardGamma {
3906                        x: common.0,
3907                        gamma: common.1,
3908                        beta: common.2,
3909                        dy: common.3,
3910                        dgamma: node_offset(arena, node.id),
3911                        rows: common.4,
3912                        h: common.5,
3913                        eps: common.6,
3914                    },
3915                    Op::RmsNormBackwardBeta { .. } => Thunk::RmsNormBackwardBeta {
3916                        x: common.0,
3917                        gamma: common.1,
3918                        beta: common.2,
3919                        dy: common.3,
3920                        dbeta: node_offset(arena, node.id),
3921                        rows: common.4,
3922                        h: common.5,
3923                        eps: common.6,
3924                    },
3925                    _ => unreachable!(),
3926                }
3927            }
3928
3929            Op::RopeBackward { head_dim, n_rot } => {
3930                let dy_shape = &graph.node(node.inputs[0]).shape;
3931                let (batch, seq, hidden) = if dy_shape.rank() >= 3 {
3932                    (
3933                        dy_shape.dim(0).unwrap_static(),
3934                        dy_shape.dim(1).unwrap_static(),
3935                        dy_shape.dim(2).unwrap_static(),
3936                    )
3937                } else {
3938                    (
3939                        1,
3940                        dy_shape.dim(0).unwrap_static(),
3941                        dy_shape.dim(1).unwrap_static(),
3942                    )
3943                };
3944                let cos_shape = &graph.node(node.inputs[1]).shape;
3945                let cos_len = cos_shape.num_elements().unwrap();
3946                Thunk::RopeBackward {
3947                    dy: node_offset(arena, node.inputs[0]),
3948                    cos: node_offset(arena, node.inputs[1]),
3949                    sin: node_offset(arena, node.inputs[2]),
3950                    dx: node_offset(arena, node.id),
3951                    batch: batch as u32,
3952                    seq: seq as u32,
3953                    hidden: hidden as u32,
3954                    head_dim: *head_dim as u32,
3955                    n_rot: *n_rot as u32,
3956                    cos_len: cos_len as u32,
3957                }
3958            }
3959
3960            Op::CumsumBackward { exclusive, .. } => {
3961                let dy_shape = &graph.node(node.inputs[0]).shape;
3962                let rank = dy_shape.rank();
3963                let cols = dy_shape.dim(rank - 1).unwrap_static();
3964                let rows = dy_shape.num_elements().unwrap() / cols;
3965                Thunk::CumsumBackward {
3966                    dy: node_offset(arena, node.inputs[0]),
3967                    dx: node_offset(arena, node.id),
3968                    rows: rows as u32,
3969                    cols: cols as u32,
3970                    exclusive: *exclusive,
3971                }
3972            }
3973
3974            Op::GatherBackward { .. } => {
3975                let dy_shape = &graph.node(node.inputs[0]).shape;
3976                let idx_shape = &graph.node(node.inputs[1]).shape;
3977                let out_shape = &node.shape;
3978                let rank = out_shape.rank();
3979                let axis = match &node.op {
3980                    Op::GatherBackward { axis } => *axis,
3981                    _ => 0,
3982                };
3983                let axis_u = if axis < 0 {
3984                    (rank as i32 + axis) as usize
3985                } else {
3986                    axis as usize
3987                };
3988                let outer: usize = (0..axis_u)
3989                    .map(|i| dy_shape.dim(i).unwrap_static())
3990                    .product::<usize>()
3991                    .max(1);
3992                let num_idx = idx_shape.dim(axis_u).unwrap_static();
3993                let trailing: usize = (axis_u + 1..dy_shape.rank())
3994                    .map(|i| dy_shape.dim(i).unwrap_static())
3995                    .product::<usize>()
3996                    .max(1);
3997                let axis_dim = out_shape.dim(axis_u).unwrap_static();
3998                Thunk::GatherBackward {
3999                    dy: node_offset(arena, node.inputs[0]),
4000                    indices: node_offset(arena, node.inputs[1]),
4001                    dst: node_offset(arena, node.id),
4002                    outer: outer as u32,
4003                    axis_dim: axis_dim as u32,
4004                    num_idx: num_idx as u32,
4005                    trailing: trailing as u32,
4006                }
4007            }
4008
4009            Op::GroupNormBackwardInput { num_groups, eps }
4010            | Op::GroupNormBackwardGamma { num_groups, eps }
4011            | Op::GroupNormBackwardBeta { num_groups, eps } => {
4012                let x_shape = &graph.node(node.inputs[0]).shape;
4013                let n = x_shape.dim(0).unwrap_static() as u32;
4014                let c = x_shape.dim(1).unwrap_static() as u32;
4015                let h = x_shape.dim(2).unwrap_static() as u32;
4016                let w = x_shape.dim(3).unwrap_static() as u32;
4017                match &node.op {
4018                    Op::GroupNormBackwardInput { .. } => Thunk::GroupNormBackwardInput {
4019                        x: node_offset(arena, node.inputs[0]),
4020                        gamma: node_offset(arena, node.inputs[1]),
4021                        beta: node_offset(arena, node.inputs[2]),
4022                        dy: node_offset(arena, node.inputs[3]),
4023                        dx: node_offset(arena, node.id),
4024                        n,
4025                        c,
4026                        h,
4027                        w,
4028                        num_groups: *num_groups as u32,
4029                        eps: *eps,
4030                    },
4031                    Op::GroupNormBackwardGamma { .. } => Thunk::GroupNormBackwardGamma {
4032                        x: node_offset(arena, node.inputs[0]),
4033                        dy: node_offset(arena, node.inputs[1]),
4034                        dgamma: node_offset(arena, node.id),
4035                        n,
4036                        c,
4037                        h,
4038                        w,
4039                        num_groups: *num_groups as u32,
4040                        eps: *eps,
4041                    },
4042                    Op::GroupNormBackwardBeta { .. } => Thunk::GroupNormBackwardBeta {
4043                        dy: node_offset(arena, node.inputs[1]),
4044                        dbeta: node_offset(arena, node.id),
4045                        n,
4046                        c,
4047                        h,
4048                        w,
4049                    },
4050                    _ => unreachable!(),
4051                }
4052            }
4053
4054            Op::MaxPool2dBackward {
4055                kernel_size,
4056                stride,
4057                padding,
4058            } => {
4059                let x_shape = &graph.node(node.inputs[0]).shape;
4060                let dy_shape = &graph.node(node.inputs[1]).shape;
4061                if kernel_size.len() == 2 && x_shape.rank() == 4 && dy_shape.rank() == 4 {
4062                    Thunk::MaxPool2dBackward {
4063                        x: node_offset(arena, node.inputs[0]),
4064                        dy: node_offset(arena, node.inputs[1]),
4065                        dx: node_offset(arena, node.id),
4066                        n: x_shape.dim(0).unwrap_static() as u32,
4067                        c: x_shape.dim(1).unwrap_static() as u32,
4068                        h: x_shape.dim(2).unwrap_static() as u32,
4069                        w: x_shape.dim(3).unwrap_static() as u32,
4070                        h_out: dy_shape.dim(2).unwrap_static() as u32,
4071                        w_out: dy_shape.dim(3).unwrap_static() as u32,
4072                        kh: kernel_size[0] as u32,
4073                        kw: kernel_size[1] as u32,
4074                        sh: stride.first().copied().unwrap_or(1) as u32,
4075                        sw: stride.get(1).copied().unwrap_or(1) as u32,
4076                        ph: padding.first().copied().unwrap_or(0) as u32,
4077                        pw: padding.get(1).copied().unwrap_or(0) as u32,
4078                    }
4079                } else {
4080                    Thunk::Nop
4081                }
4082            }
4083
4084            Op::Conv2dBackwardInput {
4085                kernel_size,
4086                stride,
4087                padding,
4088                dilation,
4089                groups,
4090            } => {
4091                let dy_shape = &graph.node(node.inputs[0]).shape;
4092                let w_shape = &graph.node(node.inputs[1]).shape;
4093                let out_shape = &node.shape;
4094                if kernel_size.len() == 2
4095                    && dy_shape.rank() == 4
4096                    && w_shape.rank() == 4
4097                    && out_shape.rank() == 4
4098                {
4099                    Thunk::Conv2dBackwardInput {
4100                        dy: node_offset(arena, node.inputs[0]),
4101                        w: node_offset(arena, node.inputs[1]),
4102                        dx: node_offset(arena, node.id),
4103                        n: out_shape.dim(0).unwrap_static() as u32,
4104                        c_in: out_shape.dim(1).unwrap_static() as u32,
4105                        h: out_shape.dim(2).unwrap_static() as u32,
4106                        w_in: out_shape.dim(3).unwrap_static() as u32,
4107                        c_out: dy_shape.dim(1).unwrap_static() as u32,
4108                        h_out: dy_shape.dim(2).unwrap_static() as u32,
4109                        w_out: dy_shape.dim(3).unwrap_static() as u32,
4110                        kh: kernel_size[0] as u32,
4111                        kw: kernel_size[1] as u32,
4112                        sh: stride.first().copied().unwrap_or(1) as u32,
4113                        sw: stride.get(1).copied().unwrap_or(1) as u32,
4114                        ph: padding.first().copied().unwrap_or(0) as u32,
4115                        pw: padding.get(1).copied().unwrap_or(0) as u32,
4116                        dh: dilation.first().copied().unwrap_or(1) as u32,
4117                        dw: dilation.get(1).copied().unwrap_or(1) as u32,
4118                        groups: *groups as u32,
4119                    }
4120                } else {
4121                    Thunk::Nop
4122                }
4123            }
4124
4125            Op::Conv2dBackwardWeight {
4126                kernel_size,
4127                stride,
4128                padding,
4129                dilation,
4130                groups,
4131            } => {
4132                let x_shape = &graph.node(node.inputs[0]).shape;
4133                let dy_shape = &graph.node(node.inputs[1]).shape;
4134                let dw_shape = &node.shape;
4135                if kernel_size.len() == 2
4136                    && x_shape.rank() == 4
4137                    && dy_shape.rank() == 4
4138                    && dw_shape.rank() == 4
4139                {
4140                    Thunk::Conv2dBackwardWeight {
4141                        x: node_offset(arena, node.inputs[0]),
4142                        dy: node_offset(arena, node.inputs[1]),
4143                        dw: node_offset(arena, node.id),
4144                        n: x_shape.dim(0).unwrap_static() as u32,
4145                        c_in: x_shape.dim(1).unwrap_static() as u32,
4146                        h: x_shape.dim(2).unwrap_static() as u32,
4147                        w: x_shape.dim(3).unwrap_static() as u32,
4148                        c_out: dy_shape.dim(1).unwrap_static() as u32,
4149                        h_out: dy_shape.dim(2).unwrap_static() as u32,
4150                        w_out: dy_shape.dim(3).unwrap_static() as u32,
4151                        kh: kernel_size[0] as u32,
4152                        kw: kernel_size[1] as u32,
4153                        sh: stride.first().copied().unwrap_or(1) as u32,
4154                        sw: stride.get(1).copied().unwrap_or(1) as u32,
4155                        ph: padding.first().copied().unwrap_or(0) as u32,
4156                        pw: padding.get(1).copied().unwrap_or(0) as u32,
4157                        dh: dilation.first().copied().unwrap_or(1) as u32,
4158                        dw_dil: dilation.get(1).copied().unwrap_or(1) as u32,
4159                        groups: *groups as u32,
4160                    }
4161                } else {
4162                    Thunk::Nop
4163                }
4164            }
4165
4166            Op::SoftmaxCrossEntropyWithLogits => {
4167                let logits_shape = &graph.node(node.inputs[0]).shape;
4168                if logits_shape.rank() == 2 {
4169                    Thunk::SoftmaxCrossEntropy {
4170                        logits: node_offset(arena, node.inputs[0]),
4171                        labels: node_offset(arena, node.inputs[1]),
4172                        dst: node_offset(arena, node.id),
4173                        n: logits_shape.dim(0).unwrap_static() as u32,
4174                        c: logits_shape.dim(1).unwrap_static() as u32,
4175                    }
4176                } else {
4177                    Thunk::Nop
4178                }
4179            }
4180
4181            Op::SoftmaxCrossEntropyBackward => {
4182                let logits_shape = &graph.node(node.inputs[0]).shape;
4183                if logits_shape.rank() == 2 {
4184                    Thunk::SoftmaxCrossEntropyBackward {
4185                        logits: node_offset(arena, node.inputs[0]),
4186                        labels: node_offset(arena, node.inputs[1]),
4187                        d_loss: node_offset(arena, node.inputs[2]),
4188                        dlogits: node_offset(arena, node.id),
4189                        n: logits_shape.dim(0).unwrap_static() as u32,
4190                        c: logits_shape.dim(1).unwrap_static() as u32,
4191                    }
4192                } else {
4193                    Thunk::Nop
4194                }
4195            }
4196
4197            Op::DenseSolve => {
4198                // A: [n, n], b: [n] or [n, nrhs]. Output matches b.
4199                let a_shape = &graph.node(node.inputs[0]).shape;
4200                let n = a_shape.dim(0).unwrap_static();
4201                debug_assert_eq!(
4202                    n,
4203                    a_shape.dim(1).unwrap_static(),
4204                    "DenseSolve: A must be square"
4205                );
4206                let b_elems = node.shape.num_elements().unwrap();
4207                let nrhs = b_elems / n;
4208                match node.shape.dtype() {
4209                    rlx_ir::DType::F64 => Thunk::DenseSolveF64 {
4210                        a: node_offset(arena, node.inputs[0]),
4211                        b: node_offset(arena, node.inputs[1]),
4212                        x: node_offset(arena, node.id),
4213                        n: n as u32,
4214                        nrhs: nrhs as u32,
4215                    },
4216                    rlx_ir::DType::F32 => Thunk::DenseSolveF32 {
4217                        a: node_offset(arena, node.inputs[0]),
4218                        b: node_offset(arena, node.inputs[1]),
4219                        x: node_offset(arena, node.id),
4220                        n: n as u32,
4221                        nrhs: nrhs as u32,
4222                    },
4223                    other => panic!(
4224                        "DenseSolve: F32 + F64 lowered; got {other:?}. \
4225                         Add another variant when needed."
4226                    ),
4227                }
4228            }
4229
4230            Op::BatchedDenseSolve => {
4231                // A: [B, N, N], b: [B, N] or [B, N, K]. Output matches b.
4232                let a_shape = &graph.node(node.inputs[0]).shape;
4233                assert_eq!(a_shape.rank(), 3, "BatchedDenseSolve: A rank must be 3");
4234                let batch = a_shape.dim(0).unwrap_static();
4235                let n = a_shape.dim(1).unwrap_static();
4236                debug_assert_eq!(
4237                    n,
4238                    a_shape.dim(2).unwrap_static(),
4239                    "BatchedDenseSolve: A's last two dims must match"
4240                );
4241                let total = node.shape.num_elements().unwrap();
4242                let nrhs = total / (batch * n);
4243                match node.shape.dtype() {
4244                    rlx_ir::DType::F32 => Thunk::BatchedDenseSolveF32 {
4245                        a: node_offset(arena, node.inputs[0]),
4246                        b: node_offset(arena, node.inputs[1]),
4247                        x: node_offset(arena, node.id),
4248                        batch: batch as u32,
4249                        n: n as u32,
4250                        nrhs: nrhs as u32,
4251                    },
4252                    rlx_ir::DType::F64 => Thunk::BatchedDenseSolveF64 {
4253                        a: node_offset(arena, node.inputs[0]),
4254                        b: node_offset(arena, node.inputs[1]),
4255                        x: node_offset(arena, node.id),
4256                        batch: batch as u32,
4257                        n: n as u32,
4258                        nrhs: nrhs as u32,
4259                    },
4260                    other => panic!("BatchedDenseSolve: F32 + F64 only, got {other:?}"),
4261                }
4262            }
4263
4264            Op::Scan {
4265                body,
4266                length,
4267                save_trajectory,
4268                num_bcast,
4269                num_xs,
4270                num_checkpoints,
4271            } => {
4272                assert!(
4273                    *num_checkpoints == 0 || *num_checkpoints <= *length,
4274                    "Op::Scan: num_checkpoints={} must be 0 or ≤ length={}",
4275                    *num_checkpoints,
4276                    *length
4277                );
4278                if *num_checkpoints != 0 && *num_checkpoints != *length {
4279                    assert!(
4280                        *save_trajectory,
4281                        "Op::Scan: num_checkpoints<length only meaningful when save_trajectory=true"
4282                    );
4283                }
4284                // Plan + compile the body sub-graph standalone. The body
4285                // gets its own Arena; per execution we clone its
4286                // pristine bytes, copy the outer carry (and per-step xs
4287                // slices, if any) into the body's Input slots, run the
4288                // body schedule N times, then copy the body's output
4289                // back to the outer arena.
4290                //
4291                // Body invariants: 1 + num_xs Op::Inputs in NodeId order
4292                // — first declared is the carry, rest are x_t_i. Single
4293                // graph output (the next carry), same shape as carry.
4294                let body_plan = rlx_opt::memory::plan_memory(body);
4295                let _body_arena_size = body_plan.arena_size;
4296                // Snapshot per-input byte offsets before plan_memory
4297                // moves into the Arena below.
4298                let body_offsets: HashMap<NodeId, usize> = body_plan
4299                    .assignments
4300                    .iter()
4301                    .map(|(id, slot)| (*id, slot.offset))
4302                    .collect();
4303
4304                // Collect body Input nodes in NodeId order; first is
4305                // carry, rest are per-step xs in matching order.
4306                let mut body_inputs: Vec<NodeId> = body
4307                    .nodes()
4308                    .iter()
4309                    .filter(|n| matches!(n.op, Op::Input { .. }))
4310                    .map(|n| n.id)
4311                    .collect();
4312                body_inputs.sort();
4313                let n_body_inputs = body_inputs.len();
4314                let expected = 1 + *num_bcast as usize + *num_xs as usize;
4315                if n_body_inputs != expected {
4316                    let names: Vec<String> = body
4317                        .nodes()
4318                        .iter()
4319                        .filter_map(|n| match &n.op {
4320                            Op::Input { name } => Some(format!("{}={}", n.id, name)),
4321                            _ => None,
4322                        })
4323                        .collect();
4324                    panic!(
4325                        "Op::Scan body has {} Op::Input nodes; expected {} \
4326                            (1 carry + {} bcast + {} xs). Inputs by NodeId: [{}]",
4327                        n_body_inputs,
4328                        expected,
4329                        *num_bcast,
4330                        *num_xs,
4331                        names.join(", ")
4332                    );
4333                }
4334
4335                let body_input_id = body_inputs[0];
4336                let body_input_off = body_offsets[&body_input_id];
4337                let body_output_id = body
4338                    .outputs
4339                    .first()
4340                    .copied()
4341                    .expect("Op::Scan body must declare one output");
4342                let body_output_off = body_offsets[&body_output_id];
4343
4344                let mut body_arena = crate::arena::Arena::from_plan(body_plan);
4345                // Fill body Constant nodes — mirror the outer-graph logic
4346                // in rlx-runtime/src/backend.rs (dtype-aware).
4347                for n in body.nodes() {
4348                    if let Op::Constant { data } = &n.op
4349                        && body_arena.has_buffer(n.id)
4350                        && !data.is_empty()
4351                    {
4352                        match n.shape.dtype() {
4353                            rlx_ir::DType::F64 => {
4354                                let off = body_arena.byte_offset(n.id);
4355                                let buf = body_arena.raw_buf_mut();
4356                                let nbytes = (buf.len() - off).min(data.len());
4357                                buf[off..off + nbytes].copy_from_slice(&data[..nbytes]);
4358                            }
4359                            _ => {
4360                                let buf = body_arena.slice_mut(n.id);
4361                                let n_floats = data.len() / 4;
4362                                let n_lim = buf.len().min(n_floats);
4363                                for i in 0..n_lim {
4364                                    let bytes = [
4365                                        data[i * 4],
4366                                        data[i * 4 + 1],
4367                                        data[i * 4 + 2],
4368                                        data[i * 4 + 3],
4369                                    ];
4370                                    buf[i] = f32::from_le_bytes(bytes);
4371                                }
4372                            }
4373                        }
4374                    }
4375                }
4376                let body_init = body_arena.raw_buf().to_vec();
4377                let body_schedule = compile_thunks(body, &body_arena);
4378
4379                // Carry bytes — for trajectory mode, the outer node's
4380                // shape is [length, *carry_shape], so dividing by length
4381                // gives one row's bytes; the body's input slot still
4382                // holds carry_shape bytes.
4383                let carry_bytes = if *save_trajectory {
4384                    let total = node
4385                        .shape
4386                        .size_bytes()
4387                        .expect("Op::Scan trajectory output must have static shape");
4388                    total / *length as usize
4389                } else {
4390                    node.shape
4391                        .size_bytes()
4392                        .expect("Op::Scan carry must have static shape")
4393                };
4394
4395                // Bcast inputs occupy body_inputs[1..1+num_bcast] and
4396                // outer node.inputs[1..1+num_bcast]. They keep their
4397                // natural shape (no [length, ...] prefix) and are
4398                // copied into body_buf ONCE before the scan loop.
4399                let mut bcast_inputs: Vec<(usize, usize, u32)> =
4400                    Vec::with_capacity(*num_bcast as usize);
4401                for i in 0..*num_bcast as usize {
4402                    let body_b_id = body_inputs[1 + i];
4403                    let body_b_off = body_offsets[&body_b_id];
4404                    let outer_b_id = node.inputs[1 + i];
4405                    let outer_b_off = node_offset(arena, outer_b_id);
4406                    let outer_b_shape = &graph.node(outer_b_id).shape;
4407                    let total = outer_b_shape
4408                        .size_bytes()
4409                        .expect("Op::Scan bcast must have static shape");
4410                    bcast_inputs.push((body_b_off, outer_b_off, total as u32));
4411                }
4412
4413                // xs occupy body_inputs[1+num_bcast..] and node.inputs
4414                // [1+num_bcast..]. Each has shape [length, *per_step];
4415                // per-step bytes = total / length.
4416                let mut xs_inputs: Vec<(usize, usize, u32)> = Vec::with_capacity(*num_xs as usize);
4417                let xs_base = 1 + *num_bcast as usize;
4418                for i in 0..*num_xs as usize {
4419                    let body_x_id = body_inputs[xs_base + i];
4420                    let body_x_off = body_offsets[&body_x_id];
4421                    let outer_xs_id = node.inputs[xs_base + i];
4422                    let outer_xs_off = node_offset(arena, outer_xs_id);
4423                    let outer_xs_shape = &graph.node(outer_xs_id).shape;
4424                    let total = outer_xs_shape
4425                        .size_bytes()
4426                        .expect("Op::Scan xs must have static shape");
4427                    let per_step = total / *length as usize;
4428                    xs_inputs.push((body_x_off, outer_xs_off, per_step as u32));
4429                }
4430
4431                Thunk::Scan {
4432                    body: Arc::new(body_schedule),
4433                    body_init: Arc::new(body_init),
4434                    body_input_off,
4435                    body_output_off,
4436                    outer_init_off: node_offset(arena, node.inputs[0]),
4437                    outer_final_off: node_offset(arena, node.id),
4438                    length: *length,
4439                    carry_bytes: carry_bytes as u32,
4440                    save_trajectory: *save_trajectory,
4441                    xs_inputs: Arc::new(xs_inputs),
4442                    bcast_inputs: Arc::new(bcast_inputs),
4443                    num_checkpoints: *num_checkpoints,
4444                }
4445            }
4446
4447            Op::ScanBackward {
4448                body_vjp,
4449                length,
4450                save_trajectory,
4451                num_xs,
4452                num_checkpoints,
4453                forward_body,
4454            } => {
4455                let is_recursive = *num_checkpoints != 0 && *num_checkpoints != *length;
4456                if is_recursive {
4457                    assert!(
4458                        forward_body.is_some(),
4459                        "Op::ScanBackward with num_checkpoints<length requires forward_body"
4460                    );
4461                }
4462                // body_vjp has signature
4463                //   (carry, x_t_0, ..., x_t_{num_xs-1}, d_output) → dcarry
4464                // Identify slots:
4465                //   * "d_output" by exact name (AD-introduced seed Input).
4466                //   * Remaining Inputs sorted by NodeId — first is the
4467                //     carry mirror, rest are x_t_i mirrors in body's
4468                //     original Op::Input declaration order.
4469                let body_plan = rlx_opt::memory::plan_memory(body_vjp);
4470                let body_offsets: HashMap<NodeId, usize> = body_plan
4471                    .assignments
4472                    .iter()
4473                    .map(|(id, slot)| (*id, slot.offset))
4474                    .collect();
4475                let mut body_d_output_off: Option<usize> = None;
4476                let mut body_other_inputs: Vec<(NodeId, usize)> = Vec::new();
4477                for n in body_vjp.nodes() {
4478                    if let Op::Input { name } = &n.op {
4479                        let off = body_offsets[&n.id];
4480                        if name == "d_output" {
4481                            body_d_output_off = Some(off);
4482                        } else {
4483                            body_other_inputs.push((n.id, off));
4484                        }
4485                    }
4486                }
4487                body_other_inputs.sort_by_key(|(id, _)| *id);
4488                let body_d_output_off =
4489                    body_d_output_off.expect("ScanBackward body_vjp missing 'd_output' Input");
4490                let expected_others = 1 + *num_xs as usize;
4491                assert_eq!(
4492                    body_other_inputs.len(),
4493                    expected_others,
4494                    "ScanBackward body_vjp has {} non-d_output Inputs; \
4495                     expected {} (1 carry + {} xs)",
4496                    body_other_inputs.len(),
4497                    expected_others,
4498                    num_xs
4499                );
4500                let body_carry_in_off = body_other_inputs[0].1;
4501                let body_x_offs: Vec<usize> = body_other_inputs
4502                    .iter()
4503                    .skip(1)
4504                    .map(|(_, off)| *off)
4505                    .collect();
4506                let body_dcarry_out_off = body_offsets[&body_vjp.outputs[0]];
4507
4508                let mut body_arena = crate::arena::Arena::from_plan(body_plan);
4509                // Fill body_vjp's Constants (mirrors the Scan lowering).
4510                for n in body_vjp.nodes() {
4511                    if let Op::Constant { data } = &n.op
4512                        && body_arena.has_buffer(n.id)
4513                        && !data.is_empty()
4514                    {
4515                        match n.shape.dtype() {
4516                            rlx_ir::DType::F64 => {
4517                                let off = body_arena.byte_offset(n.id);
4518                                let buf = body_arena.raw_buf_mut();
4519                                let nb = (buf.len() - off).min(data.len());
4520                                buf[off..off + nb].copy_from_slice(&data[..nb]);
4521                            }
4522                            _ => {
4523                                let buf = body_arena.slice_mut(n.id);
4524                                let nf = data.len() / 4;
4525                                let nl = buf.len().min(nf);
4526                                for i in 0..nl {
4527                                    let bytes = [
4528                                        data[i * 4],
4529                                        data[i * 4 + 1],
4530                                        data[i * 4 + 2],
4531                                        data[i * 4 + 3],
4532                                    ];
4533                                    buf[i] = f32::from_le_bytes(bytes);
4534                                }
4535                            }
4536                        }
4537                    }
4538                }
4539                let body_init = body_arena.raw_buf().to_vec();
4540                let body_schedule = compile_thunks(body_vjp, &body_arena);
4541
4542                // Carry bytes from the dcarry output node (== carry shape).
4543                let carry_bytes = body_vjp
4544                    .node(body_vjp.outputs[0])
4545                    .shape
4546                    .size_bytes()
4547                    .expect("ScanBackward dcarry must be statically shaped");
4548                let carry_elem_size = body_vjp
4549                    .node(body_vjp.outputs[0])
4550                    .shape
4551                    .dtype()
4552                    .size_bytes() as u32;
4553
4554                // For each xs input on the outer node:
4555                // (outer_xs_base, per_step_bytes).
4556                let mut outer_xs_offs: Vec<(usize, u32)> = Vec::with_capacity(*num_xs as usize);
4557                for i in 0..*num_xs as usize {
4558                    let outer_xs_id = node.inputs[3 + i];
4559                    let outer_xs_off = node_offset(arena, outer_xs_id);
4560                    let outer_xs_shape = &graph.node(outer_xs_id).shape;
4561                    let total = outer_xs_shape
4562                        .size_bytes()
4563                        .expect("ScanBackward xs must have static shape");
4564                    let per_step = total / *length as usize;
4565                    outer_xs_offs.push((outer_xs_off, per_step as u32));
4566                }
4567
4568                // If recursive checkpointing is active, we also compile
4569                // the forward body so the executor can recompute
4570                // intermediate carries. The forward body is supplied
4571                // by the AD pass via `forward_body: Some(_)`.
4572                let (fb_schedule, fb_init, fb_carry_in_off, fb_output_off, fb_x_offs) =
4573                    if is_recursive {
4574                        let fb = forward_body.as_ref().unwrap();
4575                        let fb_plan = rlx_opt::memory::plan_memory(fb);
4576                        let fb_offsets: HashMap<NodeId, usize> = fb_plan
4577                            .assignments
4578                            .iter()
4579                            .map(|(id, slot)| (*id, slot.offset))
4580                            .collect();
4581                        let mut fb_inputs: Vec<NodeId> = fb
4582                            .nodes()
4583                            .iter()
4584                            .filter(|n| matches!(n.op, Op::Input { .. }))
4585                            .map(|n| n.id)
4586                            .collect();
4587                        fb_inputs.sort();
4588                        let fb_carry = fb_offsets[&fb_inputs[0]];
4589                        let fb_xs: Vec<usize> = (1..fb_inputs.len())
4590                            .map(|i| fb_offsets[&fb_inputs[i]])
4591                            .collect();
4592                        let fb_out = fb_offsets[&fb.outputs[0]];
4593                        let mut fb_arena = crate::arena::Arena::from_plan(fb_plan);
4594                        for n in fb.nodes() {
4595                            if let Op::Constant { data } = &n.op
4596                                && fb_arena.has_buffer(n.id)
4597                                && !data.is_empty()
4598                            {
4599                                // Byte-copy works for any
4600                                // numeric dtype as long as the
4601                                // arena slot is sized to hold
4602                                // it — the Constant's `data`
4603                                // already encodes the right
4604                                // bytes per element.
4605                                let off = fb_arena.byte_offset(n.id);
4606                                let buf = fb_arena.raw_buf_mut();
4607                                let nb = (buf.len() - off).min(data.len());
4608                                buf[off..off + nb].copy_from_slice(&data[..nb]);
4609                            }
4610                        }
4611                        let fb_init_bytes = fb_arena.raw_buf().to_vec();
4612                        let fb_sched = compile_thunks(fb, &fb_arena);
4613                        (
4614                            Some(Arc::new(fb_sched)),
4615                            Some(Arc::new(fb_init_bytes)),
4616                            fb_carry,
4617                            fb_out,
4618                            fb_xs,
4619                        )
4620                    } else {
4621                        (None, None, 0, 0, Vec::new())
4622                    };
4623
4624                Thunk::ScanBackward {
4625                    body_vjp: Arc::new(body_schedule),
4626                    body_init: Arc::new(body_init),
4627                    body_carry_in_off,
4628                    body_x_offs: Arc::new(body_x_offs),
4629                    body_d_output_off,
4630                    body_dcarry_out_off,
4631                    outer_init_off: node_offset(arena, node.inputs[0]),
4632                    outer_traj_off: node_offset(arena, node.inputs[1]),
4633                    outer_upstream_off: node_offset(arena, node.inputs[2]),
4634                    outer_xs_offs: Arc::new(outer_xs_offs),
4635                    outer_dinit_off: node_offset(arena, node.id),
4636                    length: *length,
4637                    carry_bytes: carry_bytes as u32,
4638                    carry_elem_size,
4639                    save_trajectory: *save_trajectory,
4640                    num_checkpoints: *num_checkpoints,
4641                    forward_body: fb_schedule,
4642                    forward_body_init: fb_init,
4643                    forward_body_carry_in_off: fb_carry_in_off,
4644                    forward_body_output_off: fb_output_off,
4645                    forward_body_x_offs: Arc::new(fb_x_offs),
4646                }
4647            }
4648
4649            Op::ScanBackwardXs {
4650                body_vjp,
4651                length,
4652                save_trajectory,
4653                num_xs,
4654                xs_idx,
4655                num_checkpoints,
4656                forward_body,
4657            } => {
4658                assert!(
4659                    *num_checkpoints == 0 || *num_checkpoints <= *length,
4660                    "Op::ScanBackwardXs: num_checkpoints={} must be 0 or ≤ length={}",
4661                    *num_checkpoints,
4662                    *length
4663                );
4664                let is_recursive = *num_checkpoints != 0 && *num_checkpoints != *length;
4665                if is_recursive {
4666                    assert!(
4667                        forward_body.is_some(),
4668                        "Op::ScanBackwardXs with num_checkpoints<length \
4669                         requires forward_body"
4670                    );
4671                }
4672                // Mirror ScanBackward's body_vjp slot identification +
4673                // arena prep, then add: per-iteration extraction of the
4674                // body_vjp output that corresponds to the chosen xs.
4675                //
4676                // body_vjp's outputs (from `grad(body, [carry, xs_0, ..., xs_{num_xs-1}])`):
4677                //   outputs[0]      = dcarry
4678                //   outputs[1 + i]  = dx_t_i
4679                let body_plan = rlx_opt::memory::plan_memory(body_vjp);
4680                let body_offsets: HashMap<NodeId, usize> = body_plan
4681                    .assignments
4682                    .iter()
4683                    .map(|(id, slot)| (*id, slot.offset))
4684                    .collect();
4685                let mut body_d_output_off: Option<usize> = None;
4686                let mut body_other_inputs: Vec<(NodeId, usize)> = Vec::new();
4687                for n in body_vjp.nodes() {
4688                    if let Op::Input { name } = &n.op {
4689                        let off = body_offsets[&n.id];
4690                        if name == "d_output" {
4691                            body_d_output_off = Some(off);
4692                        } else {
4693                            body_other_inputs.push((n.id, off));
4694                        }
4695                    }
4696                }
4697                body_other_inputs.sort_by_key(|(id, _)| *id);
4698                let body_d_output_off =
4699                    body_d_output_off.expect("ScanBackwardXs body_vjp missing 'd_output' Input");
4700                let expected_others = 1 + *num_xs as usize;
4701                assert_eq!(
4702                    body_other_inputs.len(),
4703                    expected_others,
4704                    "ScanBackwardXs body_vjp has {} non-d_output Inputs; expected {}",
4705                    body_other_inputs.len(),
4706                    expected_others
4707                );
4708                let body_carry_in_off = body_other_inputs[0].1;
4709                let body_x_offs: Vec<usize> = body_other_inputs
4710                    .iter()
4711                    .skip(1)
4712                    .map(|(_, off)| *off)
4713                    .collect();
4714                let body_dcarry_out_off = body_offsets[&body_vjp.outputs[0]];
4715                let dxs_out_node = body_vjp.outputs[1 + *xs_idx as usize];
4716                let body_dxs_out_off = body_offsets[&dxs_out_node];
4717
4718                let mut body_arena = crate::arena::Arena::from_plan(body_plan);
4719                for n in body_vjp.nodes() {
4720                    if let Op::Constant { data } = &n.op
4721                        && body_arena.has_buffer(n.id)
4722                        && !data.is_empty()
4723                    {
4724                        match n.shape.dtype() {
4725                            rlx_ir::DType::F64 => {
4726                                let off = body_arena.byte_offset(n.id);
4727                                let buf = body_arena.raw_buf_mut();
4728                                let nb = (buf.len() - off).min(data.len());
4729                                buf[off..off + nb].copy_from_slice(&data[..nb]);
4730                            }
4731                            _ => {
4732                                let buf = body_arena.slice_mut(n.id);
4733                                let nf = data.len() / 4;
4734                                let nl = buf.len().min(nf);
4735                                for i in 0..nl {
4736                                    let bytes = [
4737                                        data[i * 4],
4738                                        data[i * 4 + 1],
4739                                        data[i * 4 + 2],
4740                                        data[i * 4 + 3],
4741                                    ];
4742                                    buf[i] = f32::from_le_bytes(bytes);
4743                                }
4744                            }
4745                        }
4746                    }
4747                }
4748                let body_init = body_arena.raw_buf().to_vec();
4749                let body_schedule = compile_thunks(body_vjp, &body_arena);
4750
4751                let carry_bytes = body_vjp
4752                    .node(body_vjp.outputs[0])
4753                    .shape
4754                    .size_bytes()
4755                    .expect("ScanBackwardXs dcarry must be statically shaped");
4756                let carry_elem_size = body_vjp
4757                    .node(body_vjp.outputs[0])
4758                    .shape
4759                    .dtype()
4760                    .size_bytes() as u32;
4761                let per_step_bytes = body_vjp
4762                    .node(dxs_out_node)
4763                    .shape
4764                    .size_bytes()
4765                    .expect("ScanBackwardXs dxs body output must be statically shaped");
4766
4767                let mut outer_xs_offs: Vec<(usize, u32)> = Vec::with_capacity(*num_xs as usize);
4768                for i in 0..*num_xs as usize {
4769                    let outer_xs_id = node.inputs[3 + i];
4770                    let outer_xs_off = node_offset(arena, outer_xs_id);
4771                    let outer_xs_shape = &graph.node(outer_xs_id).shape;
4772                    let total = outer_xs_shape
4773                        .size_bytes()
4774                        .expect("ScanBackwardXs xs must have static shape");
4775                    let per_step = total / *length as usize;
4776                    outer_xs_offs.push((outer_xs_off, per_step as u32));
4777                }
4778
4779                // Compile forward_body for recompute when checkpointed.
4780                // Mirrors the same code path in the ScanBackward arm.
4781                let (fb_schedule, fb_init, fb_carry_in_off, fb_output_off, fb_x_offs) =
4782                    if is_recursive {
4783                        let fb = forward_body.as_ref().unwrap();
4784                        let fb_plan = rlx_opt::memory::plan_memory(fb);
4785                        let fb_offsets: HashMap<NodeId, usize> = fb_plan
4786                            .assignments
4787                            .iter()
4788                            .map(|(id, slot)| (*id, slot.offset))
4789                            .collect();
4790                        let mut fb_inputs: Vec<NodeId> = fb
4791                            .nodes()
4792                            .iter()
4793                            .filter(|n| matches!(n.op, Op::Input { .. }))
4794                            .map(|n| n.id)
4795                            .collect();
4796                        fb_inputs.sort();
4797                        let fb_carry = fb_offsets[&fb_inputs[0]];
4798                        let fb_xs: Vec<usize> = (1..fb_inputs.len())
4799                            .map(|i| fb_offsets[&fb_inputs[i]])
4800                            .collect();
4801                        let fb_out = fb_offsets[&fb.outputs[0]];
4802                        let mut fb_arena = crate::arena::Arena::from_plan(fb_plan);
4803                        for n in fb.nodes() {
4804                            if let Op::Constant { data } = &n.op
4805                                && fb_arena.has_buffer(n.id)
4806                                && !data.is_empty()
4807                            {
4808                                // Byte-copy works for any
4809                                // numeric dtype as long as the
4810                                // arena slot is sized to hold
4811                                // it — the Constant's `data`
4812                                // already encodes the right
4813                                // bytes per element.
4814                                let off = fb_arena.byte_offset(n.id);
4815                                let buf = fb_arena.raw_buf_mut();
4816                                let nb = (buf.len() - off).min(data.len());
4817                                buf[off..off + nb].copy_from_slice(&data[..nb]);
4818                            }
4819                        }
4820                        let fb_init_bytes = fb_arena.raw_buf().to_vec();
4821                        let fb_sched = compile_thunks(fb, &fb_arena);
4822                        (
4823                            Some(Arc::new(fb_sched)),
4824                            Some(Arc::new(fb_init_bytes)),
4825                            fb_carry,
4826                            fb_out,
4827                            fb_xs,
4828                        )
4829                    } else {
4830                        (None, None, 0, 0, Vec::new())
4831                    };
4832
4833                Thunk::ScanBackwardXs {
4834                    body_vjp: Arc::new(body_schedule),
4835                    body_init: Arc::new(body_init),
4836                    body_carry_in_off,
4837                    body_x_offs: Arc::new(body_x_offs),
4838                    body_d_output_off,
4839                    body_dcarry_out_off,
4840                    body_dxs_out_off,
4841                    outer_init_off: node_offset(arena, node.inputs[0]),
4842                    outer_traj_off: node_offset(arena, node.inputs[1]),
4843                    outer_upstream_off: node_offset(arena, node.inputs[2]),
4844                    outer_xs_offs: Arc::new(outer_xs_offs),
4845                    outer_dxs_off: node_offset(arena, node.id),
4846                    length: *length,
4847                    carry_bytes: carry_bytes as u32,
4848                    carry_elem_size,
4849                    per_step_bytes: per_step_bytes as u32,
4850                    save_trajectory: *save_trajectory,
4851                    num_checkpoints: *num_checkpoints,
4852                    forward_body: fb_schedule,
4853                    forward_body_init: fb_init,
4854                    forward_body_carry_in_off: fb_carry_in_off,
4855                    forward_body_output_off: fb_output_off,
4856                    forward_body_x_offs: Arc::new(fb_x_offs),
4857                }
4858            }
4859
4860            Op::Concat { axis } => {
4861                // Compute outer/inner from the OUTPUT shape: all inputs share
4862                // the same shape except along `axis`. The output's leading
4863                // and trailing dims match.
4864                let out_shape = &node.shape;
4865                let rank = out_shape.rank();
4866                let outer: usize = (0..*axis)
4867                    .map(|i| out_shape.dim(i).unwrap_static())
4868                    .product::<usize>()
4869                    .max(1);
4870                let inner: usize = (*axis + 1..rank)
4871                    .map(|i| out_shape.dim(i).unwrap_static())
4872                    .product::<usize>()
4873                    .max(1);
4874                let total_axis = out_shape.dim(*axis).unwrap_static();
4875                let inputs: Vec<(usize, u32)> = node
4876                    .inputs
4877                    .iter()
4878                    .map(|&in_id| {
4879                        let in_shape = &graph.node(in_id).shape;
4880                        let in_axis = in_shape.dim(*axis).unwrap_static();
4881                        (node_offset(arena, in_id), in_axis as u32)
4882                    })
4883                    .collect();
4884                let dst = node_offset(arena, node.id);
4885                match out_shape.dtype() {
4886                    rlx_ir::DType::F64 => Thunk::ConcatF64 {
4887                        dst,
4888                        outer: outer as u32,
4889                        inner: inner as u32,
4890                        total_axis: total_axis as u32,
4891                        inputs,
4892                    },
4893                    _ => Thunk::Concat {
4894                        dst,
4895                        outer: outer as u32,
4896                        inner: inner as u32,
4897                        total_axis: total_axis as u32,
4898                        inputs,
4899                    },
4900                }
4901            }
4902
4903            Op::GaussianSplatRender {
4904                width,
4905                height,
4906                tile_size,
4907                radius_scale,
4908                alpha_cutoff,
4909                max_splat_steps,
4910                transmittance_threshold,
4911                max_list_entries,
4912            } => {
4913                let elem_len =
4914                    |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
4915                Thunk::GaussianSplatRender {
4916                    positions_off: node_offset(arena, node.inputs[0]),
4917                    positions_len: elem_len(node.inputs[0]),
4918                    scales_off: node_offset(arena, node.inputs[1]),
4919                    scales_len: elem_len(node.inputs[1]),
4920                    rotations_off: node_offset(arena, node.inputs[2]),
4921                    rotations_len: elem_len(node.inputs[2]),
4922                    opacities_off: node_offset(arena, node.inputs[3]),
4923                    opacities_len: elem_len(node.inputs[3]),
4924                    colors_off: node_offset(arena, node.inputs[4]),
4925                    colors_len: elem_len(node.inputs[4]),
4926                    sh_coeffs_off: node_offset(arena, node.inputs[5]),
4927                    sh_coeffs_len: elem_len(node.inputs[5]),
4928                    meta_off: node_offset(arena, node.inputs[6]),
4929                    dst_off: node_offset(arena, node.id),
4930                    dst_len: node.shape.num_elements().unwrap_or(0),
4931                    width: *width,
4932                    height: *height,
4933                    tile_size: *tile_size,
4934                    radius_scale: *radius_scale,
4935                    alpha_cutoff: *alpha_cutoff,
4936                    max_splat_steps: *max_splat_steps,
4937                    transmittance_threshold: *transmittance_threshold,
4938                    max_list_entries: *max_list_entries,
4939                }
4940            }
4941
4942            Op::GaussianSplatRenderBackward {
4943                width,
4944                height,
4945                tile_size,
4946                radius_scale,
4947                alpha_cutoff,
4948                max_splat_steps,
4949                transmittance_threshold,
4950                max_list_entries,
4951                loss_grad_clip,
4952                sh_band,
4953                max_anisotropy,
4954            } => {
4955                let elem_len =
4956                    |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
4957                Thunk::GaussianSplatRenderBackward {
4958                    positions_off: node_offset(arena, node.inputs[0]),
4959                    positions_len: elem_len(node.inputs[0]),
4960                    scales_off: node_offset(arena, node.inputs[1]),
4961                    scales_len: elem_len(node.inputs[1]),
4962                    rotations_off: node_offset(arena, node.inputs[2]),
4963                    rotations_len: elem_len(node.inputs[2]),
4964                    opacities_off: node_offset(arena, node.inputs[3]),
4965                    opacities_len: elem_len(node.inputs[3]),
4966                    colors_off: node_offset(arena, node.inputs[4]),
4967                    colors_len: elem_len(node.inputs[4]),
4968                    sh_coeffs_off: node_offset(arena, node.inputs[5]),
4969                    sh_coeffs_len: elem_len(node.inputs[5]),
4970                    meta_off: node_offset(arena, node.inputs[6]),
4971                    d_loss_off: node_offset(arena, node.inputs[7]),
4972                    d_loss_len: elem_len(node.inputs[7]),
4973                    packed_off: node_offset(arena, node.id),
4974                    packed_len: node.shape.num_elements().unwrap_or(0),
4975                    width: *width,
4976                    height: *height,
4977                    tile_size: *tile_size,
4978                    radius_scale: *radius_scale,
4979                    alpha_cutoff: *alpha_cutoff,
4980                    max_splat_steps: *max_splat_steps,
4981                    transmittance_threshold: *transmittance_threshold,
4982                    max_list_entries: *max_list_entries,
4983                    loss_grad_clip: *loss_grad_clip,
4984                    sh_band: *sh_band,
4985                    max_anisotropy: *max_anisotropy,
4986                }
4987            }
4988
4989            Op::GaussianSplatPrepare {
4990                width,
4991                height,
4992                tile_size,
4993                radius_scale,
4994                alpha_cutoff,
4995                max_splat_steps,
4996                transmittance_threshold,
4997                max_list_entries,
4998            } => {
4999                let elem_len =
5000                    |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
5001                Thunk::GaussianSplatPrepare {
5002                    positions_off: node_offset(arena, node.inputs[0]),
5003                    positions_len: elem_len(node.inputs[0]),
5004                    scales_off: node_offset(arena, node.inputs[1]),
5005                    scales_len: elem_len(node.inputs[1]),
5006                    rotations_off: node_offset(arena, node.inputs[2]),
5007                    rotations_len: elem_len(node.inputs[2]),
5008                    opacities_off: node_offset(arena, node.inputs[3]),
5009                    opacities_len: elem_len(node.inputs[3]),
5010                    colors_off: node_offset(arena, node.inputs[4]),
5011                    colors_len: elem_len(node.inputs[4]),
5012                    sh_coeffs_off: node_offset(arena, node.inputs[5]),
5013                    sh_coeffs_len: elem_len(node.inputs[5]),
5014                    meta_off: node_offset(arena, node.inputs[6]),
5015                    meta_len: elem_len(node.inputs[6]),
5016                    prep_off: node_offset(arena, node.id),
5017                    prep_len: node.shape.num_elements().unwrap_or(0),
5018                    width: *width,
5019                    height: *height,
5020                    tile_size: *tile_size,
5021                    radius_scale: *radius_scale,
5022                    alpha_cutoff: *alpha_cutoff,
5023                    max_splat_steps: *max_splat_steps,
5024                    transmittance_threshold: *transmittance_threshold,
5025                    max_list_entries: *max_list_entries,
5026                }
5027            }
5028
5029            Op::GaussianSplatRasterize {
5030                width,
5031                height,
5032                tile_size,
5033                alpha_cutoff,
5034                max_splat_steps,
5035                transmittance_threshold,
5036                max_list_entries,
5037            } => {
5038                let elem_len =
5039                    |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
5040                let prep_id = node.inputs[0];
5041                let count = match &graph.node(prep_id).op {
5042                    rlx_ir::Op::GaussianSplatPrepare { .. } => {
5043                        elem_len(graph.node(prep_id).inputs[0]) / 3
5044                    }
5045                    _ => 1,
5046                };
5047                Thunk::GaussianSplatRasterize {
5048                    prep_off: node_offset(arena, prep_id),
5049                    prep_len: elem_len(prep_id),
5050                    meta_off: node_offset(arena, node.inputs[1]),
5051                    meta_len: elem_len(node.inputs[1]),
5052                    dst_off: node_offset(arena, node.id),
5053                    dst_len: node.shape.num_elements().unwrap_or(0),
5054                    count,
5055                    width: *width,
5056                    height: *height,
5057                    tile_size: *tile_size,
5058                    alpha_cutoff: *alpha_cutoff,
5059                    max_splat_steps: *max_splat_steps,
5060                    transmittance_threshold: *transmittance_threshold,
5061                    max_list_entries: *max_list_entries,
5062                }
5063            }
5064
5065            Op::Custom { name, attrs, .. } => {
5066                let kernel = crate::op_registry::lookup_cpu_kernel(name).unwrap_or_else(|| {
5067                    panic!(
5068                        "compile_thunks: no CPU kernel registered for \
5069                         Op::Custom('{name}'). Register one via \
5070                         rlx_cpu::op_registry::register_cpu_kernel \
5071                         before compiling on the CPU backend."
5072                    )
5073                });
5074                let inputs_v: Vec<(usize, u32, Shape)> = node
5075                    .inputs
5076                    .iter()
5077                    .map(|&in_id| {
5078                        let s = graph.node(in_id).shape.clone();
5079                        let len = s.num_elements().unwrap_or(0) as u32;
5080                        (node_offset(arena, in_id), len, s)
5081                    })
5082                    .collect();
5083                let out_len = node.shape.num_elements().unwrap_or(0) as u32;
5084                Thunk::CustomOp {
5085                    kernel,
5086                    inputs: inputs_v,
5087                    output: (node_offset(arena, node.id), out_len, node.shape.clone()),
5088                    attrs: attrs.clone(),
5089                }
5090            }
5091
5092            Op::Fft { inverse } => {
5093                // Last axis carries the 2N real-block layout; complex
5094                // points = N = last_dim / 2. `outer` is the product
5095                // of all preceding axes — the kernel iterates one
5096                // batch-row at a time. f32 and f64 share the same
5097                // radix-2 structure but use separate scratch buffers;
5098                // the dtype is captured here so the closure dispatches
5099                // without per-row branching.
5100                let shape = &node.shape;
5101                let last = shape.dim(shape.rank() - 1).unwrap_static();
5102                let n_complex = (last / 2) as u32;
5103                let total = shape.num_elements().unwrap_or(0);
5104                let outer = (total / last) as u32;
5105                let dtype = shape.dtype();
5106                assert!(
5107                    matches!(dtype, rlx_ir::DType::F32 | rlx_ir::DType::F64),
5108                    "Op::Fft on CPU requires F32 or F64, got {dtype:?}"
5109                );
5110                Thunk::Fft1d {
5111                    src: node_offset(arena, node.inputs[0]),
5112                    dst: node_offset(arena, node.id),
5113                    outer,
5114                    n_complex,
5115                    inverse: *inverse,
5116                    dtype,
5117                }
5118            }
5119
5120            Op::CustomFn {
5121                fwd_body,
5122                num_inputs,
5123                ..
5124            } => {
5125                // Plan + compile the body sub-graph standalone, fill its
5126                // Constants (mirrors the Op::Scan body lowering), then
5127                // capture per-input copy specs and the output spec.
5128                // Body Inputs in NodeId order match the outer node's
5129                // operand vector by position.
5130                let body_plan = rlx_opt::memory::plan_memory(fwd_body);
5131                let body_offsets: HashMap<NodeId, usize> = body_plan
5132                    .assignments
5133                    .iter()
5134                    .map(|(id, slot)| (*id, slot.offset))
5135                    .collect();
5136
5137                let mut body_input_ids: Vec<NodeId> = fwd_body
5138                    .nodes()
5139                    .iter()
5140                    .filter(|n| matches!(n.op, Op::Input { .. }))
5141                    .map(|n| n.id)
5142                    .collect();
5143                body_input_ids.sort();
5144                assert_eq!(
5145                    body_input_ids.len(),
5146                    *num_inputs as usize,
5147                    "Op::CustomFn fwd_body has {} Op::Input(s); declared num_inputs={}",
5148                    body_input_ids.len(),
5149                    *num_inputs,
5150                );
5151
5152                let mut body_arena = crate::arena::Arena::from_plan(body_plan);
5153                for n in fwd_body.nodes() {
5154                    if let Op::Constant { data } = &n.op
5155                        && body_arena.has_buffer(n.id)
5156                        && !data.is_empty()
5157                    {
5158                        match n.shape.dtype() {
5159                            rlx_ir::DType::F64 => {
5160                                let off = body_arena.byte_offset(n.id);
5161                                let buf = body_arena.raw_buf_mut();
5162                                let nb = (buf.len() - off).min(data.len());
5163                                buf[off..off + nb].copy_from_slice(&data[..nb]);
5164                            }
5165                            _ => {
5166                                let buf = body_arena.slice_mut(n.id);
5167                                let nf = data.len() / 4;
5168                                let nl = buf.len().min(nf);
5169                                for i in 0..nl {
5170                                    let bytes = [
5171                                        data[i * 4],
5172                                        data[i * 4 + 1],
5173                                        data[i * 4 + 2],
5174                                        data[i * 4 + 3],
5175                                    ];
5176                                    buf[i] = f32::from_le_bytes(bytes);
5177                                }
5178                            }
5179                        }
5180                    }
5181                }
5182                let body_init = body_arena.raw_buf().to_vec();
5183                let body_schedule = compile_thunks(fwd_body, &body_arena);
5184
5185                // Per primal input: (body_input_off, outer_input_off, bytes).
5186                let inputs_v: Vec<(usize, usize, u32)> = (0..*num_inputs as usize)
5187                    .map(|i| {
5188                        let body_in = body_input_ids[i];
5189                        let body_off = body_offsets[&body_in];
5190                        let outer_in = node.inputs[i];
5191                        let outer_off = node_offset(arena, outer_in);
5192                        let bytes = graph
5193                            .node(outer_in)
5194                            .shape
5195                            .size_bytes()
5196                            .expect("Op::CustomFn primal input must have static shape");
5197                        (body_off, outer_off, bytes as u32)
5198                    })
5199                    .collect();
5200
5201                let body_output_id = fwd_body
5202                    .outputs
5203                    .first()
5204                    .copied()
5205                    .expect("Op::CustomFn fwd_body must declare exactly one output");
5206                let body_output_off = body_offsets[&body_output_id];
5207                let out_bytes = node
5208                    .shape
5209                    .size_bytes()
5210                    .expect("Op::CustomFn output must have static shape");
5211
5212                Thunk::CustomFn {
5213                    body: Arc::new(body_schedule),
5214                    body_init: Arc::new(body_init),
5215                    inputs: Arc::new(inputs_v),
5216                    body_output_off,
5217                    outer_output_off: node_offset(arena, node.id),
5218                    out_bytes: out_bytes as u32,
5219                }
5220            }
5221
5222            _ => Thunk::Nop,
5223        };
5224        thunks.push(t);
5225    }
5226
5227    let cfg = crate::config::RuntimeConfig::global();
5228    let mask_thr = cfg.mask_binary_threshold;
5229    let mask_neg = cfg.attn_mask_neg_inf;
5230    let score_skip = cfg.score_skip_threshold;
5231
5232    // Pre-compile closures (skip Nops — they're filtered out)
5233    let compiled_fns: Vec<Arc<dyn Fn(*mut u8) + Send + Sync>> = thunks
5234        .iter()
5235        .filter(|t| !matches!(t, Thunk::Nop))
5236        .map(|thunk| {
5237            match thunk.clone() {
5238                Thunk::Nop => Arc::new(|_: *mut u8| {}) as Arc<dyn Fn(*mut u8) + Send + Sync>,
5239
5240                Thunk::Sgemm { a, b, c, m, k, n } => {
5241                    let (m, k, n) = (m as usize, k as usize, n as usize);
5242                    Arc::new(move |base: *mut u8| unsafe {
5243                        crate::blas::sgemm(
5244                            sl(a, base, m * k),
5245                            sl(b, base, k * n),
5246                            sl_mut(c, base, m * n),
5247                            m,
5248                            k,
5249                            n,
5250                        );
5251                    })
5252                }
5253
5254                Thunk::DenseSolveF64 { a, b, x, n, nrhs } => {
5255                    let (n_, nrhs_) = (n as usize, nrhs as usize);
5256                    Arc::new(move |base: *mut u8| unsafe {
5257                        let a_src = sl_f64(a, base, n_ * n_);
5258                        let b_src = sl_f64(b, base, n_ * nrhs_);
5259                        let mut a_scratch: Vec<f64> = a_src.to_vec();
5260                        let mut x_buf: Vec<f64> = b_src.to_vec();
5261                        let info = crate::blas::dgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
5262                        if info != 0 {
5263                            panic!("DenseSolveF64: singular (info={info})");
5264                        }
5265                        sl_mut_f64(x, base, n_ * nrhs_).copy_from_slice(&x_buf);
5266                    })
5267                }
5268
5269                Thunk::DenseSolveF32 { a, b, x, n, nrhs } => {
5270                    let (n_, nrhs_) = (n as usize, nrhs as usize);
5271                    Arc::new(move |base: *mut u8| unsafe {
5272                        let a_src = sl(a, base, n_ * n_);
5273                        let b_src = sl(b, base, n_ * nrhs_);
5274                        let mut a_scratch: Vec<f32> = a_src.to_vec();
5275                        let mut x_buf: Vec<f32> = b_src.to_vec();
5276                        let info = crate::blas::sgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
5277                        if info != 0 {
5278                            panic!("DenseSolveF32: singular (info={info})");
5279                        }
5280                        sl_mut(x, base, n_ * nrhs_).copy_from_slice(&x_buf);
5281                    })
5282                }
5283
5284                Thunk::FusedMmBiasAct {
5285                    a,
5286                    w,
5287                    bias,
5288                    c,
5289                    m,
5290                    k,
5291                    n,
5292                    act,
5293                } => {
5294                    let (m, k, n) = (m as usize, k as usize, n as usize);
5295                    Arc::new(move |base: *mut u8| unsafe {
5296                        let out = sl_mut(c, base, m * n);
5297                        crate::blas::sgemm(sl(a, base, m * k), sl(w, base, k * n), out, m, k, n);
5298                        // Bias + activation epilogue. Gelu uses the fused
5299                        // `par_bias_gelu` kernel (bias add + Gelu in one
5300                        // pass). For everything else, do the bias add first
5301                        // and then apply the activation per-element. The
5302                        // pre-fix code dispatched `_ => bias_add` and dropped
5303                        // the activation entirely — silent correctness bug
5304                        // for Silu/Relu/Sigmoid/etc.
5305                        match act {
5306                            Some(Activation::Gelu) => {
5307                                crate::kernels::par_bias_gelu(out, sl(bias, base, n), m, n)
5308                            }
5309                            Some(other) => {
5310                                crate::blas::bias_add(out, sl(bias, base, n), m, n);
5311                                apply_activation_inplace(out, other);
5312                            }
5313                            None => crate::blas::bias_add(out, sl(bias, base, n), m, n),
5314                        }
5315                    })
5316                }
5317
5318                Thunk::FusedResidualLN {
5319                    x,
5320                    res,
5321                    bias,
5322                    g,
5323                    b,
5324                    out,
5325                    rows,
5326                    h,
5327                    eps,
5328                    has_bias,
5329                } => {
5330                    let (rows, h) = (rows as usize, h as usize);
5331                    Arc::new(move |base: *mut u8| unsafe {
5332                        let zero = vec![0f32; h]; // closure only — not hot path
5333                        let bi = if has_bias { sl(bias, base, h) } else { &zero };
5334                        let xp = sl(x, base, rows * h).as_ptr() as usize;
5335                        let rp = sl(res, base, rows * h).as_ptr() as usize;
5336                        let op = sl_mut(out, base, rows * h).as_mut_ptr() as usize;
5337                        let bp = bi.as_ptr() as usize;
5338                        let gp = sl(g, base, h).as_ptr() as usize;
5339                        let bbp = sl(b, base, h).as_ptr() as usize;
5340                        crate::pool::par_for(rows, 4, &|off, cnt| {
5341                            let xs = std::slice::from_raw_parts(
5342                                (xp as *const f32).add(off * h),
5343                                cnt * h,
5344                            );
5345                            let rs = std::slice::from_raw_parts(
5346                                (rp as *const f32).add(off * h),
5347                                cnt * h,
5348                            );
5349                            let os = std::slice::from_raw_parts_mut(
5350                                (op as *mut f32).add(off * h),
5351                                cnt * h,
5352                            );
5353                            let bi = std::slice::from_raw_parts(bp as *const f32, h);
5354                            let g = std::slice::from_raw_parts(gp as *const f32, h);
5355                            let b = std::slice::from_raw_parts(bbp as *const f32, h);
5356                            crate::kernels::residual_bias_layer_norm(
5357                                xs, rs, bi, g, b, os, cnt, h, eps,
5358                            );
5359                        });
5360                    })
5361                }
5362
5363                Thunk::BiasAdd {
5364                    src,
5365                    bias,
5366                    dst,
5367                    m,
5368                    n,
5369                } => {
5370                    let (m, n) = (m as usize, n as usize);
5371                    Arc::new(move |base: *mut u8| unsafe {
5372                        let out = sl_mut(dst, base, m * n);
5373                        out.copy_from_slice(sl(src, base, m * n));
5374                        crate::blas::bias_add(out, sl(bias, base, n), m, n);
5375                    })
5376                }
5377
5378                Thunk::Gather {
5379                    table,
5380                    table_len,
5381                    idx,
5382                    dst,
5383                    num_idx,
5384                    trailing,
5385                } => {
5386                    let (ni, tr, tl) = (num_idx as usize, trailing as usize, table_len as usize);
5387                    Arc::new(move |base: *mut u8| unsafe {
5388                        let tab = sl(table, base, tl);
5389                        let ids = sl(idx, base, ni);
5390                        let out = sl_mut(dst, base, ni * tr);
5391                        for i in 0..ni {
5392                            let row = ids[i] as usize;
5393                            out[i * tr..(i + 1) * tr]
5394                                .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
5395                        }
5396                    })
5397                }
5398
5399                Thunk::Narrow {
5400                    src,
5401                    dst,
5402                    outer,
5403                    src_stride,
5404                    dst_stride,
5405                    inner,
5406                    elem_bytes,
5407                } => {
5408                    narrow_thunk_closure(src, dst, outer, src_stride, dst_stride, inner, elem_bytes)
5409                }
5410
5411                Thunk::Copy { src, dst, len } => {
5412                    let len = len as usize;
5413                    Arc::new(move |base: *mut u8| unsafe {
5414                        sl_mut(dst, base, len).copy_from_slice(sl(src, base, len));
5415                    })
5416                }
5417
5418                Thunk::Softmax { data, rows, cols } => {
5419                    let (rows, cols) = (rows as usize, cols as usize);
5420                    Arc::new(move |base: *mut u8| unsafe {
5421                        crate::naive::softmax(sl_mut(data, base, rows * cols), rows, cols);
5422                    })
5423                }
5424
5425                Thunk::Cumsum {
5426                    src,
5427                    dst,
5428                    rows,
5429                    cols,
5430                    exclusive,
5431                } => {
5432                    let (rows, cols) = (rows as usize, cols as usize);
5433                    Arc::new(move |base: *mut u8| unsafe {
5434                        let s = sl(src, base, rows * cols);
5435                        let d = sl_mut(dst, base, rows * cols);
5436                        if exclusive {
5437                            for r in 0..rows {
5438                                let mut acc = 0.0f32;
5439                                for c in 0..cols {
5440                                    d[r * cols + c] = acc;
5441                                    acc += s[r * cols + c];
5442                                }
5443                            }
5444                        } else {
5445                            for r in 0..rows {
5446                                let mut acc = 0.0f32;
5447                                for c in 0..cols {
5448                                    acc += s[r * cols + c];
5449                                    d[r * cols + c] = acc;
5450                                }
5451                            }
5452                        }
5453                    })
5454                }
5455
5456                Thunk::Sample {
5457                    logits,
5458                    dst,
5459                    batch,
5460                    vocab,
5461                    top_k,
5462                    top_p,
5463                    temperature,
5464                    seed,
5465                } => {
5466                    let (b, v) = (batch as usize, vocab as usize);
5467                    let k = (top_k as usize).min(v);
5468                    Arc::new(move |base: *mut u8| unsafe {
5469                        let lg = sl(logits, base, b * v);
5470                        let out = sl_mut(dst, base, b);
5471                        let mut rng =
5472                            rlx_ir::Philox4x32::new(if seed == 0 { 0xDEADBEEF } else { seed });
5473                        for bi in 0..b {
5474                            let row = &lg[bi * v..(bi + 1) * v];
5475                            out[bi] = sample_row(row, k, top_p, temperature, &mut rng) as f32;
5476                        }
5477                    })
5478                }
5479
5480                Thunk::DequantMatMul {
5481                    x,
5482                    w_q,
5483                    scale,
5484                    zp,
5485                    dst,
5486                    m,
5487                    k,
5488                    n,
5489                    block_size,
5490                    is_asymmetric,
5491                } => {
5492                    let (m, k, n, bs) = (m as usize, k as usize, n as usize, block_size as usize);
5493                    let n_blocks_per_col = k.div_ceil(bs);
5494                    Arc::new(move |base: *mut u8| unsafe {
5495                        let xs = sl(x, base, m * k);
5496                        // w_q is packed i8 — use raw byte slice + reinterpret.
5497                        let raw = base.add(w_q);
5498                        let w_bytes = std::slice::from_raw_parts(raw as *const i8, k * n);
5499                        let scales = sl(scale, base, n_blocks_per_col * n);
5500                        let zps = if is_asymmetric {
5501                            sl(zp, base, n_blocks_per_col * n)
5502                        } else {
5503                            &[][..]
5504                        };
5505                        let out = sl_mut(dst, base, m * n);
5506                        dequant_matmul_int8(
5507                            xs,
5508                            w_bytes,
5509                            scales,
5510                            zps,
5511                            out,
5512                            m,
5513                            k,
5514                            n,
5515                            bs,
5516                            is_asymmetric,
5517                        );
5518                    })
5519                }
5520
5521                Thunk::DequantMatMulGguf {
5522                    x,
5523                    w_q,
5524                    dst,
5525                    m,
5526                    k,
5527                    n,
5528                    scheme,
5529                } => {
5530                    let (m, k, n) = (m as usize, k as usize, n as usize);
5531                    let block_bytes = scheme.gguf_block_bytes() as usize;
5532                    let block_elems = scheme.gguf_block_size() as usize;
5533                    let total_bytes = (k * n) / block_elems * block_bytes;
5534                    Arc::new(move |base: *mut u8| unsafe {
5535                        let xs = sl(x, base, m * k);
5536                        let w_bytes =
5537                            std::slice::from_raw_parts(base.add(w_q) as *const u8, total_bytes);
5538                        let out = sl_mut(dst, base, m * n);
5539                        crate::gguf_matmul::gguf_matmul_bt(xs, w_bytes, out, m, k, n, scheme);
5540                    })
5541                }
5542
5543                Thunk::DequantMatMulInt4 {
5544                    x,
5545                    w_q,
5546                    scale,
5547                    zp,
5548                    dst,
5549                    m,
5550                    k,
5551                    n,
5552                    block_size,
5553                    is_asymmetric,
5554                } => {
5555                    let (m, k, n, bs) = (m as usize, k as usize, n as usize, block_size as usize);
5556                    let n_blocks = k.div_ceil(bs);
5557                    Arc::new(move |base: *mut u8| unsafe {
5558                        let xs = sl(x, base, m * k);
5559                        let w_bytes = std::slice::from_raw_parts(
5560                            base.add(w_q) as *const u8,
5561                            (k * n).div_ceil(2),
5562                        );
5563                        let scales = sl(scale, base, n_blocks * n);
5564                        let zps = if is_asymmetric {
5565                            sl(zp, base, n_blocks * n)
5566                        } else {
5567                            &[][..]
5568                        };
5569                        let out = sl_mut(dst, base, m * n);
5570                        dequant_matmul_int4(
5571                            xs,
5572                            w_bytes,
5573                            scales,
5574                            zps,
5575                            out,
5576                            m,
5577                            k,
5578                            n,
5579                            bs,
5580                            is_asymmetric,
5581                        );
5582                    })
5583                }
5584
5585                Thunk::DequantMatMulFp8 {
5586                    x,
5587                    w_q,
5588                    scale,
5589                    dst,
5590                    m,
5591                    k,
5592                    n,
5593                    e5m2,
5594                } => {
5595                    let (m, k, n) = (m as usize, k as usize, n as usize);
5596                    Arc::new(move |base: *mut u8| unsafe {
5597                        let xs = sl(x, base, m * k);
5598                        let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, k * n);
5599                        let scales = sl(scale, base, n);
5600                        let out = sl_mut(dst, base, m * n);
5601                        dequant_matmul_fp8(xs, w_bytes, scales, out, m, k, n, e5m2);
5602                    })
5603                }
5604
5605                Thunk::DequantMatMulNvfp4 {
5606                    x,
5607                    w_q,
5608                    scale,
5609                    global_scale,
5610                    dst,
5611                    m,
5612                    k,
5613                    n,
5614                } => {
5615                    let (m, k, n) = (m as usize, k as usize, n as usize);
5616                    let n_scale = k.div_ceil(rlx_ir::NVFP4_GROUP_SIZE) * n;
5617                    Arc::new(move |base: *mut u8| unsafe {
5618                        let xs = sl(x, base, m * k);
5619                        let w_bytes = std::slice::from_raw_parts(
5620                            base.add(w_q) as *const u8,
5621                            (k * n).div_ceil(2),
5622                        );
5623                        let scale_bytes =
5624                            std::slice::from_raw_parts(base.add(scale) as *const u8, n_scale);
5625                        let gs = sl(global_scale, base, 1)[0];
5626                        let out = sl_mut(dst, base, m * n);
5627                        dequant_matmul_nvfp4(xs, w_bytes, scale_bytes, gs, out, m, k, n);
5628                    })
5629                }
5630
5631                Thunk::LoraMatMul {
5632                    x,
5633                    w,
5634                    a,
5635                    b,
5636                    dst,
5637                    m,
5638                    k,
5639                    n,
5640                    r,
5641                    scale,
5642                } => {
5643                    let (m, k, n, r) = (m as usize, k as usize, n as usize, r as usize);
5644                    Arc::new(move |base: *mut u8| unsafe {
5645                        let xs = sl(x, base, m * k);
5646                        let ws = sl(w, base, k * n);
5647                        let a_s = sl(a, base, k * r);
5648                        let bs = sl(b, base, r * n);
5649                        let out = sl_mut(dst, base, m * n);
5650                        // Step 1: out = x · W.
5651                        crate::blas::sgemm(xs, ws, out, m, k, n);
5652                        // Step 2: tmp = x · A (rank-r intermediate; tiny).
5653                        let mut tmp = vec![0f32; m * r];
5654                        crate::blas::sgemm(xs, a_s, &mut tmp, m, k, r);
5655                        // Step 3: out += scale * (tmp · B).
5656                        // sgemm_accumulate uses alpha=1.0 internally, so
5657                        // scale tmp first.
5658                        if scale != 1.0 {
5659                            for v in tmp.iter_mut() {
5660                                *v *= scale;
5661                            }
5662                        }
5663                        crate::blas::sgemm_accumulate(&tmp, bs, out, m, r, n);
5664                    })
5665                }
5666
5667                Thunk::LayerNorm {
5668                    src,
5669                    g,
5670                    b,
5671                    dst,
5672                    rows,
5673                    h,
5674                    eps,
5675                } => {
5676                    let (rows, h) = (rows as usize, h as usize);
5677                    Arc::new(move |base: *mut u8| unsafe {
5678                        let inp = sl(src, base, rows * h);
5679                        let gamma = sl(g, base, h);
5680                        let beta = sl(b, base, h);
5681                        let out = sl_mut(dst, base, rows * h);
5682                        for row in 0..rows {
5683                            crate::kernels::layer_norm_row(
5684                                &inp[row * h..(row + 1) * h],
5685                                gamma,
5686                                beta,
5687                                &mut out[row * h..(row + 1) * h],
5688                                h,
5689                                eps,
5690                            );
5691                        }
5692                    })
5693                }
5694
5695                Thunk::Attention {
5696                    q,
5697                    k,
5698                    v,
5699                    mask,
5700                    out,
5701                    batch,
5702                    seq,
5703                    kv_seq: _,
5704                    heads,
5705                    head_dim,
5706                    mask_kind,
5707                    q_row_stride,
5708                    k_row_stride,
5709                    v_row_stride,
5710                    bhsd,
5711                } => {
5712                    let (b, s, nh, dh) = (
5713                        batch as usize,
5714                        seq as usize,
5715                        heads as usize,
5716                        head_dim as usize,
5717                    );
5718                    let hs = nh * dh;
5719                    let qrs = q_row_stride as usize;
5720                    let krs = k_row_stride as usize;
5721                    let vrs = v_row_stride as usize;
5722                    let scale = (dh as f32).powf(-0.5);
5723                    Arc::new(move |base: *mut u8| unsafe {
5724                        // Slice lengths use the source's row stride so the
5725                        // compiler-emitted bounds checks cover the whole
5726                        // strided span (the kernel walks with q/k/v_rs).
5727                        // For [B, H, S, D] the buffer is dense B*H*S*D.
5728                        let (q_len, k_len, v_len, o_len) = if bhsd {
5729                            let n = b * nh * s * dh;
5730                            (n, n, n, n)
5731                        } else {
5732                            (b * s * qrs, b * s * krs, b * s * vrs, b * s * hs)
5733                        };
5734                        let q_d = sl(q, base, q_len);
5735                        let k_d = sl(k, base, k_len);
5736                        let v_d = sl(v, base, v_len);
5737                        let m_d: &[f32] = match mask_kind {
5738                            rlx_ir::op::MaskKind::Custom => sl(mask, base, b * s),
5739                            rlx_ir::op::MaskKind::Bias => sl(mask, base, b * nh * s * s),
5740                            _ => &[],
5741                        };
5742                        let o_d = sl_mut(out, base, o_len);
5743                        let sdh = s * dh;
5744                        let mut qh = vec![0f32; sdh];
5745                        let mut kh = vec![0f32; sdh];
5746                        let mut vh = vec![0f32; sdh];
5747                        let mut sc = vec![0f32; s * s];
5748                        let mut oh = vec![0f32; sdh];
5749                        for bi in 0..b {
5750                            for hi in 0..nh {
5751                                for si in 0..s {
5752                                    // Two layouts:
5753                                    //   bhsd=false: [B, S, H, D] (default) →
5754                                    //     off = bi*S*RS + si*RS + hi*D
5755                                    //   bhsd=true:  [B, H, S, D] (GPU/TPU
5756                                    //     convention) →
5757                                    //     off = bi*H*S*D + hi*S*D + si*D
5758                                    // The thunk-fusion pass below sets row
5759                                    // strides, but only for the [B, S, H, D]
5760                                    // case. For bhsd we always use the dense
5761                                    // contiguous stride (qrs == krs == vrs ==
5762                                    // H*D from compile_thunks).
5763                                    let (q_off, k_off, v_off) = if bhsd {
5764                                        (
5765                                            bi * nh * s * dh + hi * s * dh + si * dh,
5766                                            bi * nh * s * dh + hi * s * dh + si * dh,
5767                                            bi * nh * s * dh + hi * s * dh + si * dh,
5768                                        )
5769                                    } else {
5770                                        (
5771                                            bi * s * qrs + si * qrs + hi * dh,
5772                                            bi * s * krs + si * krs + hi * dh,
5773                                            bi * s * vrs + si * vrs + hi * dh,
5774                                        )
5775                                    };
5776                                    qh[si * dh..(si + 1) * dh]
5777                                        .copy_from_slice(&q_d[q_off..q_off + dh]);
5778                                    kh[si * dh..(si + 1) * dh]
5779                                        .copy_from_slice(&k_d[k_off..k_off + dh]);
5780                                    vh[si * dh..(si + 1) * dh]
5781                                        .copy_from_slice(&v_d[v_off..v_off + dh]);
5782                                }
5783                                for qi in 0..s {
5784                                    for ki in 0..s {
5785                                        let mut dot = 0f32;
5786                                        for d in 0..dh {
5787                                            dot += qh[qi * dh + d] * kh[ki * dh + d];
5788                                        }
5789                                        sc[qi * s + ki] = dot * scale;
5790                                    }
5791                                }
5792                                // Apply mask kind — None skips entirely, Causal /
5793                                // SlidingWindow synthesize, Custom reads m_d.
5794                                match mask_kind {
5795                                    rlx_ir::op::MaskKind::None => {}
5796                                    rlx_ir::op::MaskKind::Causal => {
5797                                        for qi in 0..s {
5798                                            for ki in (qi + 1)..s {
5799                                                sc[qi * s + ki] = mask_neg;
5800                                            }
5801                                        }
5802                                    }
5803                                    rlx_ir::op::MaskKind::SlidingWindow(w) => {
5804                                        for qi in 0..s {
5805                                            let lo = qi.saturating_sub(w);
5806                                            for ki in 0..s {
5807                                                if ki < lo || ki > qi {
5808                                                    sc[qi * s + ki] = mask_neg;
5809                                                }
5810                                            }
5811                                        }
5812                                    }
5813                                    rlx_ir::op::MaskKind::Custom => {
5814                                        for qi in 0..s {
5815                                            for ki in 0..s {
5816                                                if m_d[bi * s + ki] < mask_thr {
5817                                                    sc[qi * s + ki] = mask_neg;
5818                                                }
5819                                            }
5820                                        }
5821                                    }
5822                                    rlx_ir::op::MaskKind::Bias => {
5823                                        let per_bh = s * s;
5824                                        let off = (bi * nh + hi) * per_bh;
5825                                        for i in 0..per_bh {
5826                                            sc[i] += m_d[off + i];
5827                                        }
5828                                    }
5829                                }
5830                                crate::naive::softmax(&mut sc, s, s);
5831                                oh.fill(0.0);
5832                                for qi in 0..s {
5833                                    for ki in 0..s {
5834                                        let w = sc[qi * s + ki];
5835                                        if w > score_skip {
5836                                            for d in 0..dh {
5837                                                oh[qi * dh + d] += w * vh[ki * dh + d];
5838                                            }
5839                                        }
5840                                    }
5841                                }
5842                                for si in 0..s {
5843                                    let off = if bhsd {
5844                                        bi * nh * s * dh + hi * s * dh + si * dh
5845                                    } else {
5846                                        bi * s * hs + si * hs + hi * dh
5847                                    };
5848                                    o_d[off..off + dh].copy_from_slice(&oh[si * dh..(si + 1) * dh]);
5849                                }
5850                            }
5851                        }
5852                    })
5853                }
5854
5855                Thunk::FusedSwiGLU {
5856                    src,
5857                    dst,
5858                    n_half,
5859                    total,
5860                    gate_first,
5861                } => {
5862                    let n = n_half as usize;
5863                    let t = total as usize;
5864                    let outer = t / n;
5865                    let in_total = outer * 2 * n;
5866                    Arc::new(move |base: *mut u8| unsafe {
5867                        let inp = sl(src, base, in_total);
5868                        let out = sl_mut(dst, base, t);
5869                        for o in 0..outer {
5870                            let in_row = &inp[o * 2 * n..(o + 1) * 2 * n];
5871                            let out_row = &mut out[o * n..(o + 1) * n];
5872                            for i in 0..n {
5873                                let (up, gate) = if gate_first {
5874                                    (in_row[n + i], in_row[i])
5875                                } else {
5876                                    (in_row[i], in_row[n + i])
5877                                };
5878                                out_row[i] = up * (gate / (1.0 + (-gate).exp()));
5879                            }
5880                        }
5881                    })
5882                }
5883
5884                Thunk::Concat {
5885                    dst,
5886                    outer,
5887                    inner,
5888                    total_axis,
5889                    inputs,
5890                } => {
5891                    let outer = outer as usize;
5892                    let inner = inner as usize;
5893                    let total_axis = total_axis as usize;
5894                    let out_total = outer * total_axis * inner;
5895                    // Pre-compute the destination row offset for each input
5896                    // (cumulative axis offsets times inner).
5897                    let mut layout: Vec<(usize, usize, usize)> = Vec::with_capacity(inputs.len());
5898                    let mut cum: usize = 0;
5899                    for (src_off, in_axis) in &inputs {
5900                        let in_axis = *in_axis as usize;
5901                        layout.push((*src_off, cum * inner, in_axis * inner));
5902                        cum += in_axis;
5903                    }
5904                    Arc::new(move |base: *mut u8| unsafe {
5905                        let out = sl_mut(dst, base, out_total);
5906                        let row_stride = total_axis * inner;
5907                        for (src_off, dst_col_off, copy_per_row) in &layout {
5908                            let in_total = outer * *copy_per_row;
5909                            let inp = sl(*src_off, base, in_total);
5910                            for o in 0..outer {
5911                                let dst_row_start = o * row_stride + *dst_col_off;
5912                                let src_row_start = o * *copy_per_row;
5913                                out[dst_row_start..dst_row_start + *copy_per_row].copy_from_slice(
5914                                    &inp[src_row_start..src_row_start + *copy_per_row],
5915                                );
5916                            }
5917                        }
5918                    })
5919                }
5920
5921                Thunk::CustomOp {
5922                    kernel,
5923                    inputs,
5924                    output,
5925                    attrs,
5926                } => {
5927                    // Capture-by-move: clone the Arc and Vecs once into the
5928                    // closure. Dispatch by output dtype each call (the
5929                    // dtype is fixed at compile time but it's cheaper to
5930                    // branch once per execution than to monomorphize a
5931                    // dozen closure variants).
5932                    let kernel = kernel.clone();
5933                    let attrs = attrs.clone();
5934                    let inputs = inputs.clone();
5935                    let (out_off, out_len, out_shape) = output.clone();
5936                    Arc::new(move |base: *mut u8| unsafe {
5937                        dispatch_custom_op(
5938                            &*kernel, &inputs, out_off, out_len, &out_shape, &attrs, base,
5939                        );
5940                    })
5941                }
5942
5943                Thunk::GaussianSplatRender {
5944                    positions_off,
5945                    positions_len,
5946                    scales_off,
5947                    scales_len,
5948                    rotations_off,
5949                    rotations_len,
5950                    opacities_off,
5951                    opacities_len,
5952                    colors_off,
5953                    colors_len,
5954                    sh_coeffs_off,
5955                    sh_coeffs_len,
5956                    meta_off,
5957                    dst_off,
5958                    dst_len,
5959                    width,
5960                    height,
5961                    tile_size,
5962                    radius_scale,
5963                    alpha_cutoff,
5964                    max_splat_steps,
5965                    transmittance_threshold,
5966                    max_list_entries,
5967                } => Arc::new(move |base: *mut u8| unsafe {
5968                    crate::splat::execute_gaussian_splat_render(
5969                        positions_off,
5970                        positions_len,
5971                        scales_off,
5972                        scales_len,
5973                        rotations_off,
5974                        rotations_len,
5975                        opacities_off,
5976                        opacities_len,
5977                        colors_off,
5978                        colors_len,
5979                        sh_coeffs_off,
5980                        sh_coeffs_len,
5981                        meta_off,
5982                        dst_off,
5983                        dst_len,
5984                        width,
5985                        height,
5986                        tile_size,
5987                        radius_scale,
5988                        alpha_cutoff,
5989                        max_splat_steps,
5990                        transmittance_threshold,
5991                        max_list_entries,
5992                        base,
5993                    );
5994                }),
5995
5996                Thunk::GaussianSplatRenderBackward {
5997                    positions_off,
5998                    positions_len,
5999                    scales_off,
6000                    scales_len,
6001                    rotations_off,
6002                    rotations_len,
6003                    opacities_off,
6004                    opacities_len,
6005                    colors_off,
6006                    colors_len,
6007                    sh_coeffs_off,
6008                    sh_coeffs_len,
6009                    meta_off,
6010                    d_loss_off,
6011                    d_loss_len,
6012                    packed_off,
6013                    packed_len,
6014                    width,
6015                    height,
6016                    tile_size,
6017                    radius_scale,
6018                    alpha_cutoff,
6019                    max_splat_steps,
6020                    transmittance_threshold,
6021                    max_list_entries,
6022                    loss_grad_clip,
6023                    sh_band,
6024                    max_anisotropy,
6025                } => Arc::new(move |base: *mut u8| unsafe {
6026                    crate::splat::execute_gaussian_splat_render_backward(
6027                        positions_off,
6028                        positions_len,
6029                        scales_off,
6030                        scales_len,
6031                        rotations_off,
6032                        rotations_len,
6033                        opacities_off,
6034                        opacities_len,
6035                        colors_off,
6036                        colors_len,
6037                        sh_coeffs_off,
6038                        sh_coeffs_len,
6039                        meta_off,
6040                        d_loss_off,
6041                        d_loss_len,
6042                        packed_off,
6043                        packed_len,
6044                        width,
6045                        height,
6046                        tile_size,
6047                        radius_scale,
6048                        alpha_cutoff,
6049                        max_splat_steps,
6050                        transmittance_threshold,
6051                        max_list_entries,
6052                        loss_grad_clip,
6053                        sh_band,
6054                        max_anisotropy,
6055                        base,
6056                    );
6057                }),
6058
6059                Thunk::GaussianSplatPrepare {
6060                    positions_off,
6061                    positions_len,
6062                    scales_off,
6063                    scales_len,
6064                    rotations_off,
6065                    rotations_len,
6066                    opacities_off,
6067                    opacities_len,
6068                    colors_off,
6069                    colors_len,
6070                    sh_coeffs_off,
6071                    sh_coeffs_len,
6072                    meta_off,
6073                    meta_len,
6074                    prep_off,
6075                    prep_len,
6076                    width,
6077                    height,
6078                    tile_size,
6079                    radius_scale,
6080                    alpha_cutoff,
6081                    max_splat_steps,
6082                    transmittance_threshold,
6083                    max_list_entries,
6084                } => Arc::new(move |base: *mut u8| unsafe {
6085                    crate::splat::execute_gaussian_splat_prepare(
6086                        positions_off,
6087                        positions_len,
6088                        scales_off,
6089                        scales_len,
6090                        rotations_off,
6091                        rotations_len,
6092                        opacities_off,
6093                        opacities_len,
6094                        colors_off,
6095                        colors_len,
6096                        sh_coeffs_off,
6097                        sh_coeffs_len,
6098                        meta_off,
6099                        meta_len,
6100                        prep_off,
6101                        prep_len,
6102                        width,
6103                        height,
6104                        tile_size,
6105                        radius_scale,
6106                        alpha_cutoff,
6107                        max_splat_steps,
6108                        transmittance_threshold,
6109                        max_list_entries,
6110                        base,
6111                    );
6112                }),
6113
6114                Thunk::GaussianSplatRasterize {
6115                    prep_off,
6116                    prep_len,
6117                    meta_off,
6118                    meta_len,
6119                    dst_off,
6120                    dst_len,
6121                    count,
6122                    width,
6123                    height,
6124                    tile_size,
6125                    alpha_cutoff,
6126                    max_splat_steps,
6127                    transmittance_threshold,
6128                    max_list_entries,
6129                } => Arc::new(move |base: *mut u8| unsafe {
6130                    crate::splat::execute_gaussian_splat_rasterize(
6131                        prep_off,
6132                        prep_len,
6133                        meta_off,
6134                        meta_len,
6135                        dst_off,
6136                        dst_len,
6137                        count,
6138                        width,
6139                        height,
6140                        tile_size,
6141                        alpha_cutoff,
6142                        max_splat_steps,
6143                        transmittance_threshold,
6144                        max_list_entries,
6145                        base,
6146                    );
6147                }),
6148
6149                Thunk::Fft1d {
6150                    src,
6151                    dst,
6152                    outer,
6153                    n_complex,
6154                    inverse,
6155                    dtype,
6156                } => {
6157                    let f: Arc<dyn Fn(*mut u8) + Send + Sync> = match dtype {
6158                        rlx_ir::DType::F64 => Arc::new(move |base: *mut u8| unsafe {
6159                            execute_fft1d_f64(
6160                                src,
6161                                dst,
6162                                outer as usize,
6163                                n_complex as usize,
6164                                inverse,
6165                                base,
6166                            );
6167                        }),
6168                        rlx_ir::DType::F32 => Arc::new(move |base: *mut u8| unsafe {
6169                            execute_fft1d_f32(
6170                                src,
6171                                dst,
6172                                outer as usize,
6173                                n_complex as usize,
6174                                inverse,
6175                                base,
6176                            );
6177                        }),
6178                        other => panic!("Op::Fft on CPU requires F32/F64, got {other:?}"),
6179                    };
6180                    f
6181                }
6182
6183                _ => Arc::new(|_: *mut u8| {}),
6184            }
6185        })
6186        .collect();
6187
6188    // ── Thunk-level attention fusion ──────────────────────
6189    // For small batch*seq, fuse QKV→Narrow×3→[Rope×2]→Attention→OutProj
6190    // into a single FusedAttnBlock. Auto-detects from Attention thunks.
6191    let fuse_threshold: usize = rlx_ir::env::var("RLX_FUSE_ATTN_THRESHOLD")
6192        .and_then(|v| v.parse().ok())
6193        .unwrap_or(64);
6194    let should_fuse = thunks.iter().any(|t| match t {
6195        Thunk::Attention { batch, seq, .. } => {
6196            (*batch as usize) * (*seq as usize) <= fuse_threshold
6197        }
6198        _ => false,
6199    });
6200
6201    if should_fuse {
6202        // Build non-Nop index for pattern matching across Nop gaps
6203        let active: Vec<usize> = thunks
6204            .iter()
6205            .enumerate()
6206            .filter(|(_, t)| !matches!(t, Thunk::Nop))
6207            .map(|(i, _)| i)
6208            .collect();
6209
6210        let mut kill = vec![false; thunks.len()]; // mark thunks to remove
6211        let mut insertions: Vec<(usize, Thunk)> = Vec::new(); // (position, replacement)
6212
6213        let mut ai = 0;
6214        while ai < active.len() {
6215            // Helper: get active thunk at offset from current
6216            let a = |off: usize| -> Option<(usize, &Thunk)> {
6217                active.get(ai + off).map(|&idx| (idx, &thunks[idx]))
6218            };
6219
6220            // Try BERT pattern: FusedMmBiasAct(QKV) → Narrow×3 → Attention → FusedMmBiasAct(out)
6221            let matched = (|| {
6222                let (_i0, t0) = a(0)?;
6223                let (_, t1) = a(1)?;
6224                let (_, t2) = a(2)?;
6225                let (_, t3) = a(3)?;
6226
6227                // a[0] must be FusedMmBiasAct or Sgemm (QKV projection)
6228                let (hidden, qkv_w, qkv_b, has_b) = match t0 {
6229                    Thunk::FusedMmBiasAct {
6230                        a,
6231                        w,
6232                        bias,
6233                        n: _,
6234                        act: None,
6235                        ..
6236                    } => (*a, *w, *bias, true),
6237                    Thunk::Sgemm { a, b, n: _, .. } => (*a, *b, 0, false),
6238                    _ => return None,
6239                };
6240
6241                // a[1..3] must be Narrows
6242                if !matches!(t1, Thunk::Narrow { .. }) {
6243                    return None;
6244                }
6245                if !matches!(t2, Thunk::Narrow { .. }) {
6246                    return None;
6247                }
6248                if !matches!(t3, Thunk::Narrow { .. }) {
6249                    return None;
6250                }
6251
6252                // Look for optional Rope×2 then Attention
6253                let (has_rope, attn_ai, cos_off, sin_off, cl) = if let Some((
6254                    _,
6255                    Thunk::Rope {
6256                        cos, sin, cos_len, ..
6257                    },
6258                )) = a(4)
6259                {
6260                    if matches!(a(5).map(|x| x.1), Some(Thunk::Rope { .. })) {
6261                        if matches!(a(6).map(|x| x.1), Some(Thunk::Attention { .. })) {
6262                            (true, 6, *cos, *sin, *cos_len)
6263                        } else {
6264                            return None;
6265                        }
6266                    } else {
6267                        return None;
6268                    }
6269                } else if matches!(a(4).map(|x| x.1), Some(Thunk::Attention { .. })) {
6270                    (false, 4, 0, 0, 0)
6271                } else {
6272                    return None;
6273                };
6274
6275                let (_attn_real_idx, attn_t) = a(attn_ai)?;
6276                let (batch, seq, heads, head_dim, mask) = match attn_t {
6277                    Thunk::Attention {
6278                        batch,
6279                        seq,
6280                        heads,
6281                        head_dim,
6282                        mask,
6283                        ..
6284                    } => (*batch, *seq, *heads, *head_dim, *mask),
6285                    _ => return None,
6286                };
6287
6288                // Next active must be out projection (FusedMmBiasAct or Sgemm)
6289                let (_out_real_idx, out_t) = a(attn_ai + 1)?;
6290                let (out_w, out_b, out_dst) = match out_t {
6291                    Thunk::FusedMmBiasAct {
6292                        w,
6293                        bias,
6294                        c,
6295                        act: None,
6296                        ..
6297                    } => (*w, *bias, *c),
6298                    Thunk::Sgemm { b: w, c, .. } => (*w, 0, *c),
6299                    _ => return None,
6300                };
6301
6302                let hs = heads * head_dim;
6303                let total_active = attn_ai + 2; // number of active thunks consumed
6304
6305                Some((
6306                    total_active,
6307                    Thunk::FusedAttnBlock {
6308                        hidden,
6309                        qkv_w,
6310                        out_w,
6311                        mask,
6312                        out: out_dst,
6313                        qkv_b: if has_b { qkv_b } else { 0 },
6314                        out_b: if has_b { out_b } else { 0 },
6315                        cos: cos_off,
6316                        sin: sin_off,
6317                        cos_len: cl,
6318                        batch,
6319                        seq,
6320                        hs,
6321                        nh: heads,
6322                        dh: head_dim,
6323                        has_bias: has_b,
6324                        has_rope,
6325                    },
6326                ))
6327            })();
6328
6329            if let Some((count, fused_thunk)) = matched {
6330                // Mark consumed thunks for removal
6331                for off in 0..count {
6332                    if let Some(&idx) = active.get(ai + off) {
6333                        kill[idx] = true;
6334                    }
6335                }
6336                // Insert replacement at position of the QKV thunk
6337                insertions.push((active[ai], fused_thunk));
6338                ai += count;
6339            } else {
6340                ai += 1;
6341            }
6342        }
6343
6344        // Rebuild thunk list: keep non-killed, insert fused at right positions
6345        if !insertions.is_empty() {
6346            let mut new_thunks = Vec::with_capacity(thunks.len());
6347            let mut insert_idx = 0;
6348            for (i, t) in thunks.into_iter().enumerate() {
6349                if insert_idx < insertions.len() && insertions[insert_idx].0 == i {
6350                    new_thunks.push(insertions[insert_idx].1.clone());
6351                    insert_idx += 1;
6352                }
6353                if !kill[i] {
6354                    new_thunks.push(t);
6355                }
6356            }
6357            if cfg.verbose >= 1 {
6358                eprintln!(
6359                    "[rlx] fused_attention: {} attention blocks fused",
6360                    insertions.len()
6361                );
6362            }
6363            thunks = new_thunks;
6364        }
6365    }
6366
6367    // ── Full layer fusion ──────────────────────────────────
6368    // After attention blocks are fused, scan for full layer patterns:
6369    // BERT:  FusedAttnBlock → FusedResidualLN → FusedMmBiasAct(gelu) → Sgemm → BiasAdd → FusedResidualLN
6370    // Nomic: FusedAttnBlock → BinaryFull(add) → LayerNorm → Sgemm → [Narrow×2 → Silu → BinaryFull(mul)] → Sgemm → BinaryFull(add) → LayerNorm
6371    if should_fuse {
6372        let active: Vec<usize> = thunks
6373            .iter()
6374            .enumerate()
6375            .filter(|(_, t)| !matches!(t, Thunk::Nop))
6376            .map(|(i, _)| i)
6377            .collect();
6378
6379        let mut kill = vec![false; thunks.len()];
6380        let mut insertions: Vec<(usize, Thunk)> = Vec::new();
6381
6382        let a = |ai: usize| -> Option<&Thunk> { active.get(ai).map(|&i| &thunks[i]) };
6383
6384        let mut ai = 0;
6385        while ai < active.len() {
6386            // BERT pattern: FusedAttnBlock → FusedResidualLN → FusedMmBiasAct(gelu) → FusedMmBiasAct(none) → FusedResidualLN
6387            let bert_match = (|| -> Option<usize> {
6388                let fab = a(ai)?;
6389                let rln1 = a(ai + 1)?;
6390                let ffn1 = a(ai + 2)?;
6391                let ffn2 = a(ai + 3)?;
6392                let rln2 = a(ai + 4)?;
6393
6394                let (hidden, qkv_w, qkv_b, out_w, out_b, mask, batch, seq, hs, nh, dh) = match fab {
6395                    Thunk::FusedAttnBlock {
6396                        hidden,
6397                        qkv_w,
6398                        qkv_b,
6399                        out_w,
6400                        out_b,
6401                        mask,
6402                        batch,
6403                        seq,
6404                        hs,
6405                        nh,
6406                        dh,
6407                        has_bias: true,
6408                        has_rope: false,
6409                        ..
6410                    } => (
6411                        *hidden, *qkv_w, *qkv_b, *out_w, *out_b, *mask, *batch, *seq, *hs, *nh, *dh,
6412                    ),
6413                    _ => return None,
6414                };
6415                let (ln1_g, ln1_b, eps1) = match rln1 {
6416                    Thunk::FusedResidualLN { g, b, eps, .. } => (*g, *b, *eps),
6417                    _ => return None,
6418                };
6419                let (fc1_w, fc1_b, int_dim) = match ffn1 {
6420                    Thunk::FusedMmBiasAct {
6421                        w,
6422                        bias,
6423                        n,
6424                        act: Some(Activation::Gelu),
6425                        ..
6426                    } => (*w, *bias, *n),
6427                    _ => return None,
6428                };
6429                let (fc2_w, fc2_b) = match ffn2 {
6430                    Thunk::FusedMmBiasAct {
6431                        w, bias, act: None, ..
6432                    } => (*w, *bias),
6433                    _ => return None,
6434                };
6435                let (ln2_g, ln2_b, eps2, out) = match rln2 {
6436                    Thunk::FusedResidualLN { g, b, eps, out, .. } => (*g, *b, *eps, *out),
6437                    _ => return None,
6438                };
6439
6440                for off in 0..5 {
6441                    kill[active[ai + off]] = true;
6442                }
6443                insertions.push((
6444                    active[ai],
6445                    Thunk::FusedBertLayer {
6446                        hidden,
6447                        qkv_w,
6448                        qkv_b,
6449                        out_w,
6450                        out_b,
6451                        mask,
6452                        ln1_g,
6453                        ln1_b,
6454                        eps1,
6455                        fc1_w,
6456                        fc1_b,
6457                        fc2_w,
6458                        fc2_b,
6459                        ln2_g,
6460                        ln2_b,
6461                        eps2,
6462                        out,
6463                        batch,
6464                        seq,
6465                        hs,
6466                        nh,
6467                        dh,
6468                        int_dim,
6469                    },
6470                ));
6471                Some(5)
6472            })();
6473            if let Some(n) = bert_match {
6474                ai += n;
6475                continue;
6476            }
6477
6478            // Nomic full layer fusion — disabled pending SwiGLU stride debugging.
6479            // Nomic still benefits from FusedAttnBlock (attention-level fusion).
6480            // The body below is kept as reference for when the stride bug is fixed.
6481            #[allow(unreachable_code)]
6482            let nomic_match = (|| -> Option<usize> {
6483                return None; // TODO: fix SwiGLU strided fc2 output mismatch
6484                let fab = a(ai)?;
6485                let (hidden, qkv_w, out_w, mask, cos, sin, cos_len, batch, seq, hs, nh, dh) =
6486                    match fab {
6487                        Thunk::FusedAttnBlock {
6488                            hidden,
6489                            qkv_w,
6490                            out_w,
6491                            mask,
6492                            cos,
6493                            sin,
6494                            cos_len,
6495                            batch,
6496                            seq,
6497                            hs,
6498                            nh,
6499                            dh,
6500                            has_bias: false,
6501                            has_rope: true,
6502                            ..
6503                        } => (
6504                            *hidden, *qkv_w, *out_w, *mask, *cos, *sin, *cos_len, *batch, *seq,
6505                            *hs, *nh, *dh,
6506                        ),
6507                        _ => return None,
6508                    };
6509                // FusedResidualLN for LN1
6510                let (ln1_g, ln1_b, eps1) = match a(ai + 1)? {
6511                    Thunk::FusedResidualLN { g, b, eps, .. } => (*g, *b, *eps),
6512                    _ => return None,
6513                };
6514                // Sgemm (fused fc11+fc12)
6515                let fused_fc_w = match a(ai + 2)? {
6516                    Thunk::Sgemm { b: w, .. } => *w,
6517                    _ => return None,
6518                };
6519                // Narrow×2 for split
6520                if !matches!(a(ai + 3)?, Thunk::Narrow { .. }) {
6521                    return None;
6522                }
6523                if !matches!(a(ai + 4)?, Thunk::Narrow { .. }) {
6524                    return None;
6525                }
6526                // SiLU
6527                if !matches!(
6528                    a(ai + 5)?,
6529                    Thunk::ActivationInPlace {
6530                        act: Activation::Silu,
6531                        ..
6532                    }
6533                ) {
6534                    return None;
6535                }
6536                // BinaryFull(Mul) for gate
6537                if !matches!(
6538                    a(ai + 6)?,
6539                    Thunk::BinaryFull {
6540                        op: BinaryOp::Mul,
6541                        ..
6542                    }
6543                ) {
6544                    return None;
6545                }
6546                // Sgemm (fc2)
6547                let fc2_w = match a(ai + 7)? {
6548                    Thunk::Sgemm { b: w, .. } => *w,
6549                    _ => return None,
6550                };
6551                // Get int_dim from the Narrow (inner = int_dim for last-axis narrow)
6552                let int_dim = match a(ai + 3)? {
6553                    Thunk::Narrow { inner, .. } => *inner,
6554                    _ => return None,
6555                };
6556                // FusedResidualLN for LN2
6557                let (ln2_g, ln2_b, eps2, out) = match a(ai + 8)? {
6558                    Thunk::FusedResidualLN { g, b, eps, out, .. } => (*g, *b, *eps, *out),
6559                    _ => return None,
6560                };
6561
6562                for off in 0..9 {
6563                    kill[active[ai + off]] = true;
6564                }
6565                insertions.push((
6566                    active[ai],
6567                    Thunk::FusedNomicLayer {
6568                        hidden,
6569                        qkv_w,
6570                        out_w,
6571                        mask,
6572                        cos,
6573                        sin,
6574                        cos_len,
6575                        ln1_g,
6576                        ln1_b,
6577                        eps1,
6578                        fc11_w: fused_fc_w,
6579                        fc12_w: 0,
6580                        fc2_w,
6581                        ln2_g,
6582                        ln2_b,
6583                        eps2,
6584                        out,
6585                        batch,
6586                        seq,
6587                        hs,
6588                        nh,
6589                        dh,
6590                        int_dim,
6591                    },
6592                ));
6593                Some(9)
6594            })();
6595            if let Some(n) = nomic_match {
6596                ai += n;
6597                continue;
6598            }
6599
6600            ai += 1;
6601        }
6602
6603        if !insertions.is_empty() {
6604            let mut new_thunks = Vec::with_capacity(thunks.len());
6605            let mut ins_idx = 0;
6606            for (i, t) in thunks.into_iter().enumerate() {
6607                if ins_idx < insertions.len() && insertions[ins_idx].0 == i {
6608                    new_thunks.push(insertions[ins_idx].1.clone());
6609                    ins_idx += 1;
6610                }
6611                if !kill[i] {
6612                    new_thunks.push(t);
6613                }
6614            }
6615            if cfg.verbose >= 1 {
6616                eprintln!(
6617                    "[rlx] fused_layer: {} full transformer layers fused",
6618                    insertions.len()
6619                );
6620            }
6621            thunks = new_thunks;
6622        }
6623    }
6624
6625    // ── Narrow → Rope thunk fusion (plan #45) ──────────────
6626    // Runs *after* FusedAttnBlock fusion so it only catches the medium-
6627    // batch path (batch*seq > 64) where the bigger fusion didn't fire.
6628    // Pattern: a Rope thunk whose `src` is the dst of an immediately-
6629    // preceding Narrow whose dst has no other consumer in this schedule.
6630    // Rewrite Rope to read directly from the parent buffer with the
6631    // parent's row stride; the Narrow becomes a Nop.
6632    //
6633    // Skipping the Narrow's write saves one full pass over Q/K (B*S*hs
6634    // f32) per Rope. For Nomic h=768 / batch=8 / seq=15 / 12 layers
6635    // that's 2 ropes/layer × 369 KB = ~8.9 MB of write traffic gone.
6636    {
6637        // Collect every byte-offset that's read as a thunk's `src` so
6638        // we know whether a Narrow's dst has consumers other than Rope.
6639        let mut read_offsets: HashMap<usize, usize> = HashMap::new();
6640        for t in &thunks {
6641            for off in thunk_read_offsets(t) {
6642                *read_offsets.entry(off).or_insert(0) += 1;
6643            }
6644        }
6645
6646        let mut fused_count = 0usize;
6647        for i in 0..thunks.len().saturating_sub(1) {
6648            // Look for Rope at i+1 reading from Narrow at i (skip Nops
6649            // between them since the planner left them in place).
6650            let narrow = match &thunks[i] {
6651                Thunk::Narrow { .. } => i,
6652                _ => continue,
6653            };
6654            // Find the next non-Nop thunk
6655            let mut j = narrow + 1;
6656            while j < thunks.len() && matches!(thunks[j], Thunk::Nop) {
6657                j += 1;
6658            }
6659            if j >= thunks.len() {
6660                continue;
6661            }
6662            // Must be Rope reading Narrow's dst
6663            let (n_src, n_dst, n_src_stride) = match &thunks[narrow] {
6664                Thunk::Narrow {
6665                    src,
6666                    dst,
6667                    src_stride,
6668                    ..
6669                } => (*src, *dst, *src_stride),
6670                _ => continue,
6671            };
6672            let rope_reads_narrow = matches!(&thunks[j],
6673                Thunk::Rope { src, .. } if *src == n_dst);
6674            if !rope_reads_narrow {
6675                continue;
6676            }
6677            // Conservatively require that the Narrow's dst has exactly
6678            // one reader (the Rope). Anything else and rewriting would
6679            // skip a needed write.
6680            if read_offsets.get(&n_dst).copied().unwrap_or(0) != 1 {
6681                continue;
6682            }
6683
6684            // Rewire: Rope reads from Narrow's adjusted source with the
6685            // parent buffer's row stride.
6686            if let Thunk::Rope {
6687                src,
6688                src_row_stride,
6689                ..
6690            } = &mut thunks[j]
6691            {
6692                *src = n_src;
6693                *src_row_stride = n_src_stride;
6694            }
6695            thunks[narrow] = Thunk::Nop;
6696            fused_count += 1;
6697        }
6698
6699        if fused_count > 0 && cfg.verbose >= 1 {
6700            eprintln!(
6701                "[rlx] fused_qk_rope: {} Narrow→Rope pairs collapsed",
6702                fused_count
6703            );
6704        }
6705    }
6706
6707    // ── Narrow×3 → Attention thunk fusion (plan #46 deep) ────
6708    // For each Attention thunk in the schedule, look up the producers
6709    // of its q/k/v inputs. If each is a Narrow whose dst has exactly
6710    // one consumer (the Attention), rewire Attention to read directly
6711    // from the parent buffer with the parent's row stride. The three
6712    // Narrows become Nops.
6713    //
6714    // This catches the BERT/Nomic QKV split path that FusedAttnBlock
6715    // misses (batch*seq > 64) — eliminates Q/K/V copies entirely.
6716    // For minilm6 batch=32 seq=16 hs=384: 3 × 32*16*384*4 = 2.3 MB
6717    // per layer × 6 layers = ~14 MB of write traffic gone.
6718    {
6719        let mut read_counts: HashMap<usize, usize> = HashMap::new();
6720        for t in &thunks {
6721            for off in thunk_read_offsets(t) {
6722                *read_counts.entry(off).or_insert(0) += 1;
6723            }
6724        }
6725        // Build dst→index map for fast producer lookup.
6726        let mut dst_to_idx: HashMap<usize, usize> = HashMap::new();
6727        for (i, t) in thunks.iter().enumerate() {
6728            if let Thunk::Narrow { dst, .. } = t {
6729                dst_to_idx.insert(*dst, i);
6730            }
6731        }
6732
6733        let mut fused_count = 0usize;
6734        for i in 0..thunks.len() {
6735            let (q_off, k_off, v_off) = match &thunks[i] {
6736                Thunk::Attention { q, k, v, .. } => (*q, *k, *v),
6737                _ => continue,
6738            };
6739            // All three inputs must come from Narrows.
6740            let q_n = match dst_to_idx.get(&q_off).copied() {
6741                Some(x) => x,
6742                None => continue,
6743            };
6744            let k_n = match dst_to_idx.get(&k_off).copied() {
6745                Some(x) => x,
6746                None => continue,
6747            };
6748            let v_n = match dst_to_idx.get(&v_off).copied() {
6749                Some(x) => x,
6750                None => continue,
6751            };
6752            // Each Narrow's dst must have exactly one reader (this Attn).
6753            if read_counts.get(&q_off).copied().unwrap_or(0) != 1 {
6754                continue;
6755            }
6756            if read_counts.get(&k_off).copied().unwrap_or(0) != 1 {
6757                continue;
6758            }
6759            if read_counts.get(&v_off).copied().unwrap_or(0) != 1 {
6760                continue;
6761            }
6762
6763            let (q_src, q_stride) = match &thunks[q_n] {
6764                Thunk::Narrow {
6765                    src, src_stride, ..
6766                } => (*src, *src_stride),
6767                _ => continue,
6768            };
6769            let (k_src, k_stride) = match &thunks[k_n] {
6770                Thunk::Narrow {
6771                    src, src_stride, ..
6772                } => (*src, *src_stride),
6773                _ => continue,
6774            };
6775            let (v_src, v_stride) = match &thunks[v_n] {
6776                Thunk::Narrow {
6777                    src, src_stride, ..
6778                } => (*src, *src_stride),
6779                _ => continue,
6780            };
6781
6782            if let Thunk::Attention {
6783                q,
6784                k,
6785                v,
6786                q_row_stride,
6787                k_row_stride,
6788                v_row_stride,
6789                ..
6790            } = &mut thunks[i]
6791            {
6792                *q = q_src;
6793                *k = k_src;
6794                *v = v_src;
6795                *q_row_stride = q_stride;
6796                *k_row_stride = k_stride;
6797                *v_row_stride = v_stride;
6798            }
6799            thunks[q_n] = Thunk::Nop;
6800            thunks[k_n] = Thunk::Nop;
6801            thunks[v_n] = Thunk::Nop;
6802            fused_count += 1;
6803        }
6804
6805        if fused_count > 0 && cfg.verbose >= 1 {
6806            eprintln!(
6807                "[rlx] fused_strided_attn: {} Narrow×3→Attention rewrites",
6808                fused_count
6809            );
6810        }
6811    }
6812
6813    ThunkSchedule {
6814        thunks,
6815        moe_resident: None,
6816        moe_resident_layers: None,
6817        moe_topk_capture: None,
6818        mask_threshold: cfg.mask_binary_threshold,
6819        mask_neg_inf: cfg.attn_mask_neg_inf,
6820        score_skip: cfg.score_skip_threshold,
6821        compiled_fns,
6822    }
6823}
6824
6825fn get_len(graph: &Graph, id: NodeId) -> usize {
6826    graph.node(id).shape.num_elements().unwrap_or(0)
6827}
6828
6829/// Static `usize` dims of a node's shape, or empty if any dim is dynamic.
6830fn get_static_dims(graph: &Graph, id: NodeId) -> Vec<usize> {
6831    let dims = graph.node(id).shape.dims();
6832    let mut out = Vec::with_capacity(dims.len());
6833    for d in dims {
6834        if let Some(s) = match d {
6835            rlx_ir::Dim::Static(s) => Some(*s),
6836            _ => None,
6837        } {
6838            out.push(s);
6839        } else {
6840            return Vec::new();
6841        }
6842    }
6843    out
6844}
6845
6846/// NumPy-style broadcast strides for one operand into the flat output
6847/// buffer. Returns a length-`out_dims.len()` `Vec<u32>` where entry
6848/// `d` is `0` if the input is size-1 (broadcast) at output dim `d`
6849/// (after left-padding with size-1 to match ranks), otherwise the
6850/// natural row-major stride into the *input* buffer.
6851///
6852/// Caller iterates output flat index `i` → output coords (row-major)
6853/// → input flat index = dot(coords, strides). The result is correct
6854/// for any broadcast pattern (scalar, last-axis, middle-axis,
6855/// bidirectional).
6856/// True when `rhs_dims` describes a *trailing* broadcast of `out_dims`
6857/// — i.e. every rhs dim either equals the corresponding output dim
6858/// (counting from the right) or rhs is shorter (left-padded with 1s).
6859/// Mid-shape singletons (e.g. rhs `[a, b, 1, d]` into out `[a, b, c, d]`
6860/// where `c > 1`) are NOT trailing broadcasts and require the
6861/// shape-aware `BinaryFull` slow path — `BiasAdd`'s linear bias-replicated
6862/// kernel silently miscomputes them.
6863fn is_trailing_bias_broadcast(rhs_dims: &[rlx_ir::Dim], out_dims: &[rlx_ir::Dim]) -> bool {
6864    if rhs_dims.len() > out_dims.len() {
6865        return false;
6866    }
6867    let off = out_dims.len() - rhs_dims.len();
6868    for i in 0..rhs_dims.len() {
6869        let r = match rhs_dims[i] {
6870            rlx_ir::Dim::Static(n) => n,
6871            _ => return false,
6872        };
6873        let o = match out_dims[off + i] {
6874            rlx_ir::Dim::Static(n) => n,
6875            _ => return false,
6876        };
6877        if r != o {
6878            return false;
6879        }
6880    }
6881    true
6882}
6883
6884fn broadcast_strides(in_dims: &[usize], out_dims: &[usize]) -> Vec<u32> {
6885    let r_out = out_dims.len();
6886    let r_in = in_dims.len();
6887    assert!(
6888        r_in <= r_out,
6889        "broadcast: input rank {r_in} > output rank {r_out}"
6890    );
6891    let pad = r_out - r_in;
6892    let mut strides = vec![0u32; r_out];
6893    let mut acc: usize = 1;
6894    for d in (0..r_out).rev() {
6895        let in_size = if d < pad { 1 } else { in_dims[d - pad] };
6896        if in_size == 1 {
6897            strides[d] = 0;
6898        } else {
6899            assert_eq!(
6900                in_size, out_dims[d],
6901                "broadcast: input dim {in_size} doesn't match output dim {} at axis {d}",
6902                out_dims[d]
6903            );
6904            strides[d] = acc as u32;
6905            acc *= in_size;
6906        }
6907    }
6908    strides
6909}
6910
6911/// Execute a thunk schedule on a raw arena buffer.
6912/// Fastest executor: call pre-compiled closures sequentially.
6913/// Zero match dispatch — each closure is a direct kernel call.
6914pub fn execute_compiled(schedule: &ThunkSchedule, arena_buf: &mut [u8]) {
6915    let base = arena_buf.as_mut_ptr();
6916    for f in &schedule.compiled_fns {
6917        f(base);
6918    }
6919}
6920
6921/// Active-extent execution stub. The runtime calls this when it has an
6922/// active-extent hint set. CPU doesn't implement per-thunk active-extent
6923/// scaling yet — return false so the caller falls back to the full
6924/// `execute_thunks` path.
6925pub fn execute_thunks_active(
6926    schedule: &ThunkSchedule,
6927    _arena_buf: &mut [u8],
6928    _actual: usize,
6929    _upper: usize,
6930) -> bool {
6931    let _ = schedule;
6932    false
6933}
6934
6935/// Match-based executor (fallback, used by tests).
6936struct MoeResidencyGuard;
6937impl Drop for MoeResidencyGuard {
6938    fn drop(&mut self) {
6939        if let Some(stats) = crate::moe_residency::take_stats() {
6940            crate::moe_residency::stash_last_forward_stats(stats);
6941        } else {
6942            crate::moe_residency::clear_mask();
6943        }
6944    }
6945}
6946
6947pub fn execute_thunks(schedule: &ThunkSchedule, arena_buf: &mut [u8]) {
6948    crate::moe_residency::reset_gmm_counters();
6949    if let Some(layers) = schedule.moe_resident_layers.clone() {
6950        crate::moe_residency::set_per_layer_masks(Some(layers));
6951    } else {
6952        crate::moe_residency::set_mask(schedule.moe_resident.clone());
6953    }
6954    if let Some(cap) = schedule.moe_topk_capture.as_ref() {
6955        cap.clear();
6956    }
6957    let _moe_guard = MoeResidencyGuard;
6958    let base = arena_buf.as_mut_ptr();
6959    let mask_thr = schedule.mask_threshold;
6960    let mask_neg = schedule.mask_neg_inf;
6961    let score_thr = schedule.score_skip;
6962    let thunks = &schedule.thunks;
6963    let len = thunks.len();
6964
6965    // Pre-allocate ALL reusable buffers once (zero per-call allocation)
6966    let max_h = thunks
6967        .iter()
6968        .filter_map(|t| match t {
6969            Thunk::FusedResidualLN { h, .. }
6970            | Thunk::FusedResidualRmsNorm { h, .. }
6971            | Thunk::LayerNorm { h, .. } => Some(*h as usize),
6972            _ => None,
6973        })
6974        .max()
6975        .unwrap_or(0);
6976    let zero_bias = vec![0f32; max_h];
6977
6978    // Pre-allocate per-(batch,head) score buffers for parallel SDPA.
6979    // Q/K/V/out are accessed via strided BLAS — no deinterleave copy needed.
6980    let max_sdpa = thunks
6981        .iter()
6982        .filter_map(|t| match t {
6983            Thunk::Attention {
6984                batch,
6985                seq,
6986                kv_seq,
6987                heads,
6988                head_dim,
6989                ..
6990            } => Some((
6991                *batch as usize,
6992                (*seq as usize).max(*kv_seq as usize),
6993                *heads as usize,
6994                *head_dim as usize,
6995            )),
6996            _ => None,
6997        })
6998        .fold((0, 0, 0, 0), |(mb, ms, mh, md), (b, s, h, d)| {
6999            (mb.max(b), ms.max(s), mh.max(h), md.max(d))
7000        });
7001    let (max_batch, max_seq, max_heads, _max_dh) = max_sdpa;
7002    let max_units = max_batch * max_heads;
7003    let mut sdpa_scores = vec![0f32; max_units * max_seq * max_seq];
7004
7005    // Pre-allocate fused layer buffers (reused across all 12+ layers — zero malloc per layer)
7006    let fl = thunks
7007        .iter()
7008        .filter_map(|t| match t {
7009            Thunk::FusedBertLayer {
7010                batch,
7011                seq,
7012                hs,
7013                int_dim,
7014                ..
7015            } => {
7016                let m = (*batch as usize) * (*seq as usize);
7017                let h = *hs as usize;
7018                let id = *int_dim as usize;
7019                Some((m, h, id, m * (*seq as usize)))
7020            }
7021            Thunk::FusedNomicLayer {
7022                batch,
7023                seq,
7024                hs,
7025                int_dim,
7026                ..
7027            } => {
7028                let m = (*batch as usize) * (*seq as usize);
7029                let h = *hs as usize;
7030                let id = *int_dim as usize;
7031                Some((m, h, id, m * (*seq as usize)))
7032            }
7033            _ => None,
7034        })
7035        .fold((0, 0, 0, 0), |(mm, mh, mi, ms), (m, h, id, ss)| {
7036            (mm.max(m), mh.max(h), mi.max(id), ms.max(ss))
7037        });
7038    let (fl_m, fl_h, fl_int, fl_ss) = fl;
7039    let mut fl_qkv = vec![0f32; fl_m * 3 * fl_h];
7040    let mut fl_attn = vec![0f32; fl_m * fl_h];
7041    let mut fl_res = vec![0f32; fl_m * fl_h];
7042    let mut fl_normed = vec![0f32; fl_m * fl_h];
7043    let mut fl_ffn = vec![0f32; fl_m * fl_int.max(2 * fl_int)]; // Nomic needs 2×int for fused fc11+fc12
7044    let mut fl_sc = vec![0f32; fl_ss.max(1)];
7045
7046    for i in 0..len {
7047        let thunk = unsafe { thunks.get_unchecked(i) };
7048        match thunk {
7049            Thunk::Nop => {}
7050
7051            Thunk::GaussianSplatRender {
7052                positions_off,
7053                positions_len,
7054                scales_off,
7055                scales_len,
7056                rotations_off,
7057                rotations_len,
7058                opacities_off,
7059                opacities_len,
7060                colors_off,
7061                colors_len,
7062                sh_coeffs_off,
7063                sh_coeffs_len,
7064                meta_off,
7065                dst_off,
7066                dst_len,
7067                width,
7068                height,
7069                tile_size,
7070                radius_scale,
7071                alpha_cutoff,
7072                max_splat_steps,
7073                transmittance_threshold,
7074                max_list_entries,
7075            } => unsafe {
7076                crate::splat::execute_gaussian_splat_render(
7077                    *positions_off,
7078                    *positions_len,
7079                    *scales_off,
7080                    *scales_len,
7081                    *rotations_off,
7082                    *rotations_len,
7083                    *opacities_off,
7084                    *opacities_len,
7085                    *colors_off,
7086                    *colors_len,
7087                    *sh_coeffs_off,
7088                    *sh_coeffs_len,
7089                    *meta_off,
7090                    *dst_off,
7091                    *dst_len,
7092                    *width,
7093                    *height,
7094                    *tile_size,
7095                    *radius_scale,
7096                    *alpha_cutoff,
7097                    *max_splat_steps,
7098                    *transmittance_threshold,
7099                    *max_list_entries,
7100                    base,
7101                );
7102            },
7103
7104            Thunk::GaussianSplatRenderBackward {
7105                positions_off,
7106                positions_len,
7107                scales_off,
7108                scales_len,
7109                rotations_off,
7110                rotations_len,
7111                opacities_off,
7112                opacities_len,
7113                colors_off,
7114                colors_len,
7115                sh_coeffs_off,
7116                sh_coeffs_len,
7117                meta_off,
7118                d_loss_off,
7119                d_loss_len,
7120                packed_off,
7121                packed_len,
7122                width,
7123                height,
7124                tile_size,
7125                radius_scale,
7126                alpha_cutoff,
7127                max_splat_steps,
7128                transmittance_threshold,
7129                max_list_entries,
7130                loss_grad_clip,
7131                sh_band,
7132                max_anisotropy,
7133            } => unsafe {
7134                crate::splat::execute_gaussian_splat_render_backward(
7135                    *positions_off,
7136                    *positions_len,
7137                    *scales_off,
7138                    *scales_len,
7139                    *rotations_off,
7140                    *rotations_len,
7141                    *opacities_off,
7142                    *opacities_len,
7143                    *colors_off,
7144                    *colors_len,
7145                    *sh_coeffs_off,
7146                    *sh_coeffs_len,
7147                    *meta_off,
7148                    *d_loss_off,
7149                    *d_loss_len,
7150                    *packed_off,
7151                    *packed_len,
7152                    *width,
7153                    *height,
7154                    *tile_size,
7155                    *radius_scale,
7156                    *alpha_cutoff,
7157                    *max_splat_steps,
7158                    *transmittance_threshold,
7159                    *max_list_entries,
7160                    *loss_grad_clip,
7161                    *sh_band,
7162                    *max_anisotropy,
7163                    base,
7164                );
7165            },
7166
7167            Thunk::GaussianSplatPrepare {
7168                positions_off,
7169                positions_len,
7170                scales_off,
7171                scales_len,
7172                rotations_off,
7173                rotations_len,
7174                opacities_off,
7175                opacities_len,
7176                colors_off,
7177                colors_len,
7178                sh_coeffs_off,
7179                sh_coeffs_len,
7180                meta_off,
7181                meta_len,
7182                prep_off,
7183                prep_len,
7184                width,
7185                height,
7186                tile_size,
7187                radius_scale,
7188                alpha_cutoff,
7189                max_splat_steps,
7190                transmittance_threshold,
7191                max_list_entries,
7192            } => unsafe {
7193                crate::splat::execute_gaussian_splat_prepare(
7194                    *positions_off,
7195                    *positions_len,
7196                    *scales_off,
7197                    *scales_len,
7198                    *rotations_off,
7199                    *rotations_len,
7200                    *opacities_off,
7201                    *opacities_len,
7202                    *colors_off,
7203                    *colors_len,
7204                    *sh_coeffs_off,
7205                    *sh_coeffs_len,
7206                    *meta_off,
7207                    *meta_len,
7208                    *prep_off,
7209                    *prep_len,
7210                    *width,
7211                    *height,
7212                    *tile_size,
7213                    *radius_scale,
7214                    *alpha_cutoff,
7215                    *max_splat_steps,
7216                    *transmittance_threshold,
7217                    *max_list_entries,
7218                    base,
7219                );
7220            },
7221
7222            Thunk::GaussianSplatRasterize {
7223                prep_off,
7224                prep_len,
7225                meta_off,
7226                meta_len,
7227                dst_off,
7228                dst_len,
7229                count,
7230                width,
7231                height,
7232                tile_size,
7233                alpha_cutoff,
7234                max_splat_steps,
7235                transmittance_threshold,
7236                max_list_entries,
7237            } => unsafe {
7238                crate::splat::execute_gaussian_splat_rasterize(
7239                    *prep_off,
7240                    *prep_len,
7241                    *meta_off,
7242                    *meta_len,
7243                    *dst_off,
7244                    *dst_len,
7245                    *count,
7246                    *width,
7247                    *height,
7248                    *tile_size,
7249                    *alpha_cutoff,
7250                    *max_splat_steps,
7251                    *transmittance_threshold,
7252                    *max_list_entries,
7253                    base,
7254                );
7255            },
7256
7257            Thunk::Fft1d {
7258                src,
7259                dst,
7260                outer,
7261                n_complex,
7262                inverse,
7263                dtype,
7264            } => unsafe {
7265                match dtype {
7266                    rlx_ir::DType::F64 => execute_fft1d_f64(
7267                        *src,
7268                        *dst,
7269                        *outer as usize,
7270                        *n_complex as usize,
7271                        *inverse,
7272                        base,
7273                    ),
7274                    rlx_ir::DType::F32 => execute_fft1d_f32(
7275                        *src,
7276                        *dst,
7277                        *outer as usize,
7278                        *n_complex as usize,
7279                        *inverse,
7280                        base,
7281                    ),
7282                    other => panic!("Op::Fft on CPU requires F32/F64, got {other:?}"),
7283                }
7284            },
7285
7286            // CustomFn dispatch (interpreted path). Mirrors the
7287            // pre-compiled-closure variant elsewhere in this file.
7288            // Patched by rlx-eda.
7289            Thunk::CustomFn {
7290                body,
7291                body_init,
7292                inputs,
7293                body_output_off,
7294                outer_output_off,
7295                out_bytes,
7296            } => {
7297                let mut body_buf: Vec<u8> = (**body_init).clone();
7298                unsafe {
7299                    for (body_in_off, outer_in_off, n_bytes) in inputs.iter() {
7300                        let src = (base as *const u8).add(*outer_in_off);
7301                        let dst = body_buf.as_mut_ptr().add(*body_in_off);
7302                        std::ptr::copy_nonoverlapping(src, dst, *n_bytes as usize);
7303                    }
7304                }
7305                execute_thunks(body, &mut body_buf);
7306                unsafe {
7307                    let src = body_buf.as_ptr().add(*body_output_off);
7308                    let dst = base.add(*outer_output_off);
7309                    std::ptr::copy_nonoverlapping(src, dst, *out_bytes as usize);
7310                }
7311            }
7312
7313            Thunk::Sgemm { a, b, c, m, k, n } => {
7314                let (m, k, n) = (*m as usize, *k as usize, *n as usize);
7315                unsafe {
7316                    crate::blas::sgemm_auto(
7317                        sl(*a, base, m * k),
7318                        sl(*b, base, k * n),
7319                        sl_mut(*c, base, m * n),
7320                        m,
7321                        k,
7322                        n,
7323                    );
7324                }
7325            }
7326
7327            Thunk::DenseSolveF64 { a, b, x, n, nrhs } => {
7328                let (n_, nrhs_) = (*n as usize, *nrhs as usize);
7329                // LAPACK overwrites both A and B; clone into scratch
7330                // each call. Caller's A and b must be preserved for
7331                // VJP recompute. (Eventually: swap to a factor-once /
7332                // solve-many scheme; that's the symbolic-reuse story
7333                // and lives with the sparse path.)
7334                unsafe {
7335                    let a_src = sl_f64(*a, base, n_ * n_);
7336                    let b_src = sl_f64(*b, base, n_ * nrhs_);
7337                    let mut a_scratch: Vec<f64> = a_src.to_vec();
7338                    let mut x_buf: Vec<f64> = b_src.to_vec();
7339                    let info = crate::blas::dgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
7340                    if info != 0 {
7341                        panic!(
7342                            "DenseSolveF64: dgesv reported singular matrix \
7343                                (info={info}, n={n_}, nrhs={nrhs_})"
7344                        );
7345                    }
7346                    let dst = sl_mut_f64(*x, base, n_ * nrhs_);
7347                    dst.copy_from_slice(&x_buf);
7348                }
7349            }
7350
7351            Thunk::DenseSolveF32 { a, b, x, n, nrhs } => {
7352                let (n_, nrhs_) = (*n as usize, *nrhs as usize);
7353                unsafe {
7354                    let a_src = sl(*a, base, n_ * n_);
7355                    let b_src = sl(*b, base, n_ * nrhs_);
7356                    let mut a_scratch: Vec<f32> = a_src.to_vec();
7357                    let mut x_buf: Vec<f32> = b_src.to_vec();
7358                    let info = crate::blas::sgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
7359                    if info != 0 {
7360                        panic!(
7361                            "DenseSolveF32: sgesv reported singular matrix \
7362                             (info={info}, n={n_}, nrhs={nrhs_})"
7363                        );
7364                    }
7365                    let dst = sl_mut(*x, base, n_ * nrhs_);
7366                    dst.copy_from_slice(&x_buf);
7367                }
7368            }
7369
7370            Thunk::BatchedDenseSolveF64 {
7371                a,
7372                b,
7373                x,
7374                batch,
7375                n,
7376                nrhs,
7377            } => {
7378                // Per slice: extract A_i and b_i, dgesv, write x_i.
7379                // LAPACK has no batched dgesv on Accelerate, so this
7380                // is a serial loop over the batch axis. cuSOLVER /
7381                // hipSOLVER expose `getrfBatched` / `getrsBatched` for
7382                // the GPU path — we'll wire that in rlx-cuda when
7383                // someone needs Linux+CUDA.
7384                let (b_, n_, nrhs_) = (*batch as usize, *n as usize, *nrhs as usize);
7385                let a_stride = n_ * n_;
7386                let b_stride = n_ * nrhs_;
7387                unsafe {
7388                    let a_full = sl_f64(*a, base, b_ * a_stride);
7389                    let b_full = sl_f64(*b, base, b_ * b_stride);
7390                    let x_full = sl_mut_f64(*x, base, b_ * b_stride);
7391                    for bi in 0..b_ {
7392                        let mut a_scratch: Vec<f64> =
7393                            a_full[bi * a_stride..(bi + 1) * a_stride].to_vec();
7394                        let mut x_buf: Vec<f64> =
7395                            b_full[bi * b_stride..(bi + 1) * b_stride].to_vec();
7396                        let info = crate::blas::dgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
7397                        if info != 0 {
7398                            panic!(
7399                                "BatchedDenseSolveF64: slice {bi} \
7400                                    singular (info={info}, n={n_}, nrhs={nrhs_})"
7401                            );
7402                        }
7403                        x_full[bi * b_stride..(bi + 1) * b_stride].copy_from_slice(&x_buf);
7404                    }
7405                }
7406            }
7407
7408            Thunk::BatchedDenseSolveF32 {
7409                a,
7410                b,
7411                x,
7412                batch,
7413                n,
7414                nrhs,
7415            } => {
7416                let (b_, n_, nrhs_) = (*batch as usize, *n as usize, *nrhs as usize);
7417                let a_stride = n_ * n_;
7418                let b_stride = n_ * nrhs_;
7419                unsafe {
7420                    let a_full = sl(*a, base, b_ * a_stride);
7421                    let b_full = sl(*b, base, b_ * b_stride);
7422                    let x_full = sl_mut(*x, base, b_ * b_stride);
7423                    for bi in 0..b_ {
7424                        let mut a_scratch = a_full[bi * a_stride..(bi + 1) * a_stride].to_vec();
7425                        let mut x_buf = b_full[bi * b_stride..(bi + 1) * b_stride].to_vec();
7426                        let info = crate::blas::sgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
7427                        if info != 0 {
7428                            panic!("BatchedDenseSolveF32: slice {bi} singular (info={info})");
7429                        }
7430                        x_full[bi * b_stride..(bi + 1) * b_stride].copy_from_slice(&x_buf);
7431                    }
7432                }
7433            }
7434
7435            Thunk::BatchedDgemmF64 {
7436                a,
7437                b,
7438                c,
7439                batch,
7440                m,
7441                k,
7442                n,
7443            } => {
7444                let (b_, m_, k_, n_) = (*batch as usize, *m as usize, *k as usize, *n as usize);
7445                let a_stride = m_ * k_;
7446                let b_stride = k_ * n_;
7447                let c_stride = m_ * n_;
7448                unsafe {
7449                    let a_full = sl_f64(*a, base, b_ * a_stride);
7450                    let b_full = sl_f64(*b, base, b_ * b_stride);
7451                    let c_full = sl_mut_f64(*c, base, b_ * c_stride);
7452                    for bi in 0..b_ {
7453                        let a_slice = &a_full[bi * a_stride..(bi + 1) * a_stride];
7454                        let b_slice = &b_full[bi * b_stride..(bi + 1) * b_stride];
7455                        let c_slice = &mut c_full[bi * c_stride..(bi + 1) * c_stride];
7456                        crate::blas::dgemm(a_slice, b_slice, c_slice, m_, k_, n_);
7457                    }
7458                }
7459            }
7460
7461            Thunk::BatchedSgemm {
7462                a,
7463                b,
7464                c,
7465                batch,
7466                m,
7467                k,
7468                n,
7469            } => {
7470                let (b_, m_, k_, n_) = (*batch as usize, *m as usize, *k as usize, *n as usize);
7471                let a_stride = m_ * k_;
7472                let b_stride = k_ * n_;
7473                let c_stride = m_ * n_;
7474                unsafe {
7475                    let a_full = sl(*a, base, b_ * a_stride);
7476                    let b_full = sl(*b, base, b_ * b_stride);
7477                    let c_full = sl_mut(*c, base, b_ * c_stride);
7478                    for bi in 0..b_ {
7479                        let a_slice = &a_full[bi * a_stride..(bi + 1) * a_stride];
7480                        let b_slice = &b_full[bi * b_stride..(bi + 1) * b_stride];
7481                        let c_slice = &mut c_full[bi * c_stride..(bi + 1) * c_stride];
7482                        crate::blas::sgemm_auto(a_slice, b_slice, c_slice, m_, k_, n_);
7483                    }
7484                }
7485            }
7486
7487            Thunk::Dgemm { a, b, c, m, k, n } => {
7488                let (m, k, n) = (*m as usize, *k as usize, *n as usize);
7489                unsafe {
7490                    crate::blas::dgemm(
7491                        sl_f64(*a, base, m * k),
7492                        sl_f64(*b, base, k * n),
7493                        sl_mut_f64(*c, base, m * n),
7494                        m,
7495                        k,
7496                        n,
7497                    );
7498                }
7499            }
7500
7501            Thunk::TransposeF64 {
7502                src,
7503                dst,
7504                in_total,
7505                out_dims,
7506                in_strides,
7507            } => unsafe {
7508                let inp = sl_f64(*src, base, *in_total as usize);
7509                let out_total: usize = out_dims.iter().map(|d| *d as usize).product();
7510                let out = sl_mut_f64(*dst, base, out_total);
7511                transpose_walk_f64(inp, out, out_dims, in_strides);
7512            },
7513
7514            Thunk::ActivationF64 {
7515                src,
7516                dst,
7517                len,
7518                kind,
7519            } => {
7520                let len = *len as usize;
7521                unsafe {
7522                    let inp = sl_f64(*src, base, len);
7523                    let out = sl_mut_f64(*dst, base, len);
7524                    apply_activation_f64(inp, out, *kind);
7525                }
7526            }
7527
7528            Thunk::ReduceSumF64 {
7529                src,
7530                dst,
7531                outer,
7532                reduced,
7533                inner,
7534            } => {
7535                let (o, r, n) = (*outer as usize, *reduced as usize, *inner as usize);
7536                unsafe {
7537                    let inp = sl_f64(*src, base, o * r * n);
7538                    let out = sl_mut_f64(*dst, base, o * n);
7539                    reduce_sum_f64(inp, out, o, r, n);
7540                }
7541            }
7542
7543            Thunk::CopyF64 { src, dst, len } => {
7544                let len = *len as usize;
7545                if *src == *dst { /* aliased, no copy needed */
7546                } else {
7547                    unsafe {
7548                        let s = sl_f64(*src, base, len);
7549                        let d = sl_mut_f64(*dst, base, len);
7550                        d.copy_from_slice(s);
7551                    }
7552                }
7553            }
7554
7555            Thunk::BinaryFullF64 {
7556                lhs,
7557                rhs,
7558                dst,
7559                len,
7560                lhs_len,
7561                rhs_len,
7562                op,
7563                out_dims_bcast,
7564                bcast_lhs_strides,
7565                bcast_rhs_strides,
7566            } => {
7567                let len = *len as usize;
7568                let lhs_len = *lhs_len as usize;
7569                let rhs_len = *rhs_len as usize;
7570                unsafe {
7571                    let l = sl_f64(*lhs, base, lhs_len);
7572                    let r = sl_f64(*rhs, base, rhs_len);
7573                    let d = sl_mut_f64(*dst, base, len);
7574                    if lhs_len == len && rhs_len == len {
7575                        for i in 0..len {
7576                            d[i] = binary_op_f64(*op, l[i], r[i]);
7577                        }
7578                    } else if !out_dims_bcast.is_empty() {
7579                        // Shape-aware broadcast path: correct for
7580                        // arbitrary NumPy-style broadcasts including
7581                        // bidirectional `[N,1] op [1,S]`.
7582                        let rank = out_dims_bcast.len();
7583                        let mut coords = vec![0u32; rank];
7584                        for i in 0..len {
7585                            let mut rem = i;
7586                            for ax in (0..rank).rev() {
7587                                let sz = out_dims_bcast[ax] as usize;
7588                                coords[ax] = (rem % sz) as u32;
7589                                rem /= sz;
7590                            }
7591                            let mut li: usize = 0;
7592                            let mut ri: usize = 0;
7593                            for ax in 0..rank {
7594                                li += coords[ax] as usize * bcast_lhs_strides[ax] as usize;
7595                                ri += coords[ax] as usize * bcast_rhs_strides[ax] as usize;
7596                            }
7597                            d[i] = binary_op_f64(*op, l[li], r[ri]);
7598                        }
7599                    } else {
7600                        // Fallback: legacy modulo path (preserved for
7601                        // dynamic-shape graphs where strides can't be
7602                        // precomputed). Only correct for scalar /
7603                        // last-axis broadcast.
7604                        for i in 0..len {
7605                            d[i] = binary_op_f64(*op, l[i % lhs_len], r[i % rhs_len]);
7606                        }
7607                    }
7608                }
7609            }
7610
7611            Thunk::BinaryFullC64 {
7612                lhs,
7613                rhs,
7614                dst,
7615                len,
7616                lhs_len,
7617                rhs_len,
7618                op,
7619                out_dims_bcast,
7620                bcast_lhs_strides,
7621                bcast_rhs_strides,
7622            } => {
7623                // Complex element layout: [re_0, im_0, re_1, im_1, ...]
7624                // Underlying f32 buffer length is 2·N (N = complex
7625                // element count). All offsets are byte offsets; the
7626                // `sl` helper reads as f32 starting at the byte
7627                // offset, so f32-length = 2·complex-len.
7628                let n_out = *len as usize;
7629                let n_l = *lhs_len as usize;
7630                let n_r = *rhs_len as usize;
7631                unsafe {
7632                    let l = sl(*lhs, base, 2 * n_l);
7633                    let r = sl(*rhs, base, 2 * n_r);
7634                    let d = sl_mut(*dst, base, 2 * n_out);
7635                    let do_c64 = |a_re: f32, a_im: f32, b_re: f32, b_im: f32| -> (f32, f32) {
7636                        match op {
7637                            BinaryOp::Add => (a_re + b_re, a_im + b_im),
7638                            BinaryOp::Sub => (a_re - b_re, a_im - b_im),
7639                            BinaryOp::Mul => (a_re * b_re - a_im * b_im, a_re * b_im + a_im * b_re),
7640                            BinaryOp::Div => {
7641                                let denom = b_re * b_re + b_im * b_im;
7642                                (
7643                                    (a_re * b_re + a_im * b_im) / denom,
7644                                    (a_im * b_re - a_re * b_im) / denom,
7645                                )
7646                            }
7647                            BinaryOp::Max | BinaryOp::Min | BinaryOp::Pow => {
7648                                unreachable!("C64 max/min/pow rejected at lowering")
7649                            }
7650                        }
7651                    };
7652                    if n_l == n_out && n_r == n_out {
7653                        for i in 0..n_out {
7654                            let (re, im) = do_c64(l[2 * i], l[2 * i + 1], r[2 * i], r[2 * i + 1]);
7655                            d[2 * i] = re;
7656                            d[2 * i + 1] = im;
7657                        }
7658                    } else if !out_dims_bcast.is_empty() {
7659                        // Strided complex broadcast: strides are in
7660                        // *complex element* units; multiply by 2 when
7661                        // indexing into the f32 buffer.
7662                        let rank = out_dims_bcast.len();
7663                        let mut coords = vec![0u32; rank];
7664                        for i in 0..n_out {
7665                            let mut rem = i;
7666                            for ax in (0..rank).rev() {
7667                                let sz = out_dims_bcast[ax] as usize;
7668                                coords[ax] = (rem % sz) as u32;
7669                                rem /= sz;
7670                            }
7671                            let mut li: usize = 0;
7672                            let mut ri: usize = 0;
7673                            for ax in 0..rank {
7674                                li += coords[ax] as usize * bcast_lhs_strides[ax] as usize;
7675                                ri += coords[ax] as usize * bcast_rhs_strides[ax] as usize;
7676                            }
7677                            let (re, im) =
7678                                do_c64(l[2 * li], l[2 * li + 1], r[2 * ri], r[2 * ri + 1]);
7679                            d[2 * i] = re;
7680                            d[2 * i + 1] = im;
7681                        }
7682                    } else {
7683                        // Modulo fallback (scalar / last-axis broadcast).
7684                        for i in 0..n_out {
7685                            let li = if n_l == 1 { 0 } else { i % n_l };
7686                            let ri = if n_r == 1 { 0 } else { i % n_r };
7687                            let (re, im) =
7688                                do_c64(l[2 * li], l[2 * li + 1], r[2 * ri], r[2 * ri + 1]);
7689                            d[2 * i] = re;
7690                            d[2 * i + 1] = im;
7691                        }
7692                    }
7693                }
7694            }
7695
7696            Thunk::ComplexNormSqF32 { src, dst, len } => {
7697                let n = *len as usize;
7698                unsafe {
7699                    let s = sl(*src, base, 2 * n);
7700                    let d = sl_mut(*dst, base, n);
7701                    for i in 0..n {
7702                        let re = s[2 * i];
7703                        let im = s[2 * i + 1];
7704                        d[i] = re * re + im * im;
7705                    }
7706                }
7707            }
7708
7709            Thunk::ComplexNormSqBackwardF32 { z, g, dz, len } => {
7710                // Wirtinger: dz = g · z, element-wise complex
7711                // (g is real, z is complex).
7712                let n = *len as usize;
7713                unsafe {
7714                    let zb = sl(*z, base, 2 * n);
7715                    let gb = sl(*g, base, n);
7716                    let db = sl_mut(*dz, base, 2 * n);
7717                    for i in 0..n {
7718                        let re = zb[2 * i];
7719                        let im = zb[2 * i + 1];
7720                        let gv = gb[i];
7721                        db[2 * i] = gv * re;
7722                        db[2 * i + 1] = gv * im;
7723                    }
7724                }
7725            }
7726
7727            Thunk::ConjugateC64 { src, dst, len } => {
7728                let n = *len as usize;
7729                unsafe {
7730                    let s = sl(*src, base, 2 * n);
7731                    let d = sl_mut(*dst, base, 2 * n);
7732                    for i in 0..n {
7733                        d[2 * i] = s[2 * i];
7734                        d[2 * i + 1] = -s[2 * i + 1];
7735                    }
7736                }
7737            }
7738
7739            Thunk::ActivationC64 {
7740                src,
7741                dst,
7742                len,
7743                kind,
7744            } => {
7745                let n = *len as usize;
7746                unsafe {
7747                    let s = sl(*src, base, 2 * n);
7748                    let d = sl_mut(*dst, base, 2 * n);
7749                    for i in 0..n {
7750                        let a = s[2 * i];
7751                        let b = s[2 * i + 1];
7752                        let (re, im) = match kind {
7753                            Activation::Neg => (-a, -b),
7754                            Activation::Exp => {
7755                                // exp(a + bi) = e^a · (cos b + i·sin b)
7756                                let ea = a.exp();
7757                                (ea * b.cos(), ea * b.sin())
7758                            }
7759                            Activation::Log => {
7760                                // log(z) = log|z| + i·arg(z), principal branch
7761                                let r = (a * a + b * b).sqrt();
7762                                (r.ln(), b.atan2(a))
7763                            }
7764                            Activation::Sqrt => {
7765                                // sqrt(a+bi) = sqrt((|z|+a)/2) + sign(b)·i·sqrt((|z|-a)/2)
7766                                // Principal branch; for b == 0 and a < 0 returns +i·sqrt(|a|).
7767                                let r = (a * a + b * b).sqrt();
7768                                let re = ((r + a) * 0.5).max(0.0).sqrt();
7769                                let im_mag = ((r - a) * 0.5).max(0.0).sqrt();
7770                                let im = if b >= 0.0 { im_mag } else { -im_mag };
7771                                (re, im)
7772                            }
7773                            _ => unreachable!("non-C64 activation kind survived lowering"),
7774                        };
7775                        d[2 * i] = re;
7776                        d[2 * i + 1] = im;
7777                    }
7778                }
7779            }
7780
7781            Thunk::Scan {
7782                body,
7783                body_init,
7784                body_input_off,
7785                body_output_off,
7786                outer_init_off,
7787                outer_final_off,
7788                length,
7789                carry_bytes,
7790                save_trajectory,
7791                xs_inputs,
7792                bcast_inputs,
7793                num_checkpoints,
7794            } => {
7795                let cb = *carry_bytes as usize;
7796                let n_steps = *length as usize;
7797                // Checkpoint mode: when 0 < K < length, save trajectory[k]
7798                // only when t == c_k = floor((k+1) * length / K) - 1.
7799                // The last index c_{K-1} = length - 1 always.
7800                let k_total = if *num_checkpoints == 0 || *num_checkpoints == *length {
7801                    n_steps // save every step
7802                } else {
7803                    *num_checkpoints as usize
7804                };
7805                let checkpoint_t_for_k = |k: usize| -> usize {
7806                    if k_total == n_steps {
7807                        k
7808                    } else {
7809                        ((k + 1) * n_steps)
7810                            .div_ceil(k_total)
7811                            .saturating_sub(1)
7812                            .min(n_steps - 1)
7813                    }
7814                };
7815                let mut next_k = 0usize;
7816
7817                let mut body_buf: Vec<u8> = (**body_init).clone();
7818                unsafe {
7819                    std::ptr::copy_nonoverlapping(
7820                        base.add(*outer_init_off),
7821                        body_buf.as_mut_ptr().add(*body_input_off),
7822                        cb,
7823                    );
7824                    // Broadcast inputs: copy each one into the body's
7825                    // input slot ONCE. They aren't touched in the
7826                    // iteration loop below (in contrast to xs).
7827                    for (body_b_off, outer_b_off, total_bytes) in bcast_inputs.iter() {
7828                        std::ptr::copy_nonoverlapping(
7829                            base.add(*outer_b_off),
7830                            body_buf.as_mut_ptr().add(*body_b_off),
7831                            *total_bytes as usize,
7832                        );
7833                    }
7834                }
7835                for t in 0..n_steps {
7836                    for (body_x_off, outer_xs_off, per_step_bytes) in xs_inputs.iter() {
7837                        let psb = *per_step_bytes as usize;
7838                        unsafe {
7839                            std::ptr::copy_nonoverlapping(
7840                                base.add(*outer_xs_off + t * psb),
7841                                body_buf.as_mut_ptr().add(*body_x_off),
7842                                psb,
7843                            );
7844                        }
7845                    }
7846
7847                    execute_thunks(body, &mut body_buf);
7848
7849                    if *save_trajectory && next_k < k_total && t == checkpoint_t_for_k(next_k) {
7850                        unsafe {
7851                            std::ptr::copy_nonoverlapping(
7852                                body_buf.as_ptr().add(*body_output_off),
7853                                base.add(*outer_final_off + next_k * cb),
7854                                cb,
7855                            );
7856                        }
7857                        next_k += 1;
7858                    }
7859
7860                    if *body_output_off != *body_input_off {
7861                        body_buf
7862                            .copy_within(*body_output_off..*body_output_off + cb, *body_input_off);
7863                    }
7864                }
7865
7866                if !*save_trajectory {
7867                    // Single final-carry write.
7868                    unsafe {
7869                        std::ptr::copy_nonoverlapping(
7870                            body_buf.as_ptr().add(*body_output_off),
7871                            base.add(*outer_final_off),
7872                            cb,
7873                        );
7874                    }
7875                }
7876            }
7877
7878            Thunk::ScanBackward {
7879                body_vjp,
7880                body_init,
7881                body_carry_in_off,
7882                body_x_offs,
7883                body_d_output_off,
7884                body_dcarry_out_off,
7885                outer_init_off,
7886                outer_traj_off,
7887                outer_upstream_off,
7888                outer_xs_offs,
7889                outer_dinit_off,
7890                length,
7891                carry_bytes,
7892                save_trajectory,
7893                num_checkpoints,
7894                forward_body,
7895                forward_body_init,
7896                forward_body_carry_in_off,
7897                forward_body_output_off,
7898                forward_body_x_offs,
7899                carry_elem_size,
7900            } => {
7901                // Two backward paths share the same per-iteration body
7902                // (body_vjp run + dcarry threading). The "All" path
7903                // reads the carry directly from the saved trajectory
7904                // each step. The "Recursive checkpointing" path stores
7905                // only K saved checkpoints and reconstructs intermediate
7906                // carries via Griewank-style recursive subdivision —
7907                // see [`griewank_process_segment`]. Auxiliary memory
7908                // is `O(log(segment_size) · carry_bytes)` for the
7909                // recursion stack, vs the old segment-cache scheme's
7910                // `O(segment_size · carry_bytes)`. Total recompute work
7911                // grows from `O(length)` to `O(length · log)`, which
7912                // is the canonical Griewank trade.
7913                let cb = *carry_bytes as usize;
7914                let n_steps = *length as usize;
7915                let k_total = *num_checkpoints as usize;
7916                let is_recursive = k_total != 0 && k_total != n_steps;
7917                let checkpoint_t_for_k = |k: usize| -> usize {
7918                    ((k + 1) * n_steps)
7919                        .div_ceil(k_total)
7920                        .saturating_sub(1)
7921                        .min(n_steps - 1)
7922                };
7923
7924                let mut fwd_buf: Vec<u8> = if is_recursive {
7925                    (**forward_body_init.as_ref().unwrap()).clone()
7926                } else {
7927                    Vec::new()
7928                };
7929
7930                let mut dcarry: Vec<u8> = vec![0u8; cb];
7931                if !*save_trajectory {
7932                    unsafe {
7933                        std::ptr::copy_nonoverlapping(
7934                            base.add(*outer_upstream_off),
7935                            dcarry.as_mut_ptr(),
7936                            cb,
7937                        );
7938                    }
7939                }
7940
7941                let mut body_buf: Vec<u8> = (**body_init).clone();
7942
7943                // Per-iteration backward action — shared between the
7944                // direct-trajectory (All) and Griewank (Recursive) paths.
7945                // Both feed the same body_vjp run with carry-at-t,
7946                // x_t_i, and d_output, then thread dcarry backward.
7947                let process_iter =
7948                    |t: usize, carry_in: &[u8], dcarry: &mut Vec<u8>, body_buf: &mut Vec<u8>| {
7949                        if *save_trajectory {
7950                            unsafe {
7951                                let up_off = *outer_upstream_off + t * cb;
7952                                match *carry_elem_size {
7953                                    4 => {
7954                                        let up_ptr = base.add(up_off) as *const f32;
7955                                        let dc_ptr = dcarry.as_mut_ptr() as *mut f32;
7956                                        let n_elems = cb / 4;
7957                                        for i in 0..n_elems {
7958                                            *dc_ptr.add(i) += *up_ptr.add(i);
7959                                        }
7960                                    }
7961                                    8 => {
7962                                        let up_ptr = base.add(up_off) as *const f64;
7963                                        let dc_ptr = dcarry.as_mut_ptr() as *mut f64;
7964                                        let n_elems = cb / 8;
7965                                        for i in 0..n_elems {
7966                                            *dc_ptr.add(i) += *up_ptr.add(i);
7967                                        }
7968                                    }
7969                                    other => panic!(
7970                                        "ScanBackward: unsupported carry elem size {other} \
7971                                     (only f32/f64 carries are supported today)"
7972                                    ),
7973                                }
7974                            }
7975                        }
7976                        body_buf[*body_carry_in_off..*body_carry_in_off + cb]
7977                            .copy_from_slice(carry_in);
7978                        unsafe {
7979                            for (i, body_x_off) in body_x_offs.iter().enumerate() {
7980                                let (outer_xs_off, per_step_bytes) = outer_xs_offs[i];
7981                                let psb = per_step_bytes as usize;
7982                                std::ptr::copy_nonoverlapping(
7983                                    base.add(outer_xs_off + t * psb),
7984                                    body_buf.as_mut_ptr().add(*body_x_off),
7985                                    psb,
7986                                );
7987                            }
7988                            std::ptr::copy_nonoverlapping(
7989                                dcarry.as_ptr(),
7990                                body_buf.as_mut_ptr().add(*body_d_output_off),
7991                                cb,
7992                            );
7993                        }
7994                        execute_thunks(body_vjp, body_buf);
7995                        unsafe {
7996                            std::ptr::copy_nonoverlapping(
7997                                body_buf.as_ptr().add(*body_dcarry_out_off),
7998                                dcarry.as_mut_ptr(),
7999                                cb,
8000                            );
8001                        }
8002                    };
8003
8004                if is_recursive {
8005                    // Griewank treeverse path. Process saved-checkpoint
8006                    // segments from highest-t to lowest-t; within each,
8007                    // recursive binary subdivision via
8008                    // `griewank_process_segment`. Auxiliary memory:
8009                    // O(log(seg_size) · cb) for the recursion stack
8010                    // (vs O(seg_size · cb) for the older segment-cache
8011                    // scheme); recompute work: O(seg_size · log).
8012                    let leaf_threshold = 4usize;
8013                    let fb_sched = forward_body.as_ref().unwrap();
8014                    let fb_init = forward_body_init.as_ref().unwrap().as_slice();
8015                    let mut segment_end = n_steps - 1;
8016                    for seg_k in (0..k_total).rev() {
8017                        let segment_start = if seg_k == 0 {
8018                            0
8019                        } else {
8020                            checkpoint_t_for_k(seg_k - 1) + 1
8021                        };
8022                        let mut anchor: Vec<u8> = vec![0u8; cb];
8023                        unsafe {
8024                            let src = if seg_k == 0 {
8025                                base.add(*outer_init_off)
8026                            } else {
8027                                base.add(*outer_traj_off + (seg_k - 1) * cb)
8028                            };
8029                            std::ptr::copy_nonoverlapping(src, anchor.as_mut_ptr(), cb);
8030                        }
8031                        // Closure adapter for the helper's signature
8032                        // (mutably re-borrows dcarry / body_buf each call).
8033                        let mut leaf_action = |t: usize, carry_in: &[u8]| {
8034                            process_iter(t, carry_in, &mut dcarry, &mut body_buf);
8035                        };
8036                        unsafe {
8037                            griewank_process_segment(
8038                                segment_start,
8039                                segment_end,
8040                                &anchor,
8041                                cb,
8042                                fb_sched,
8043                                fb_init,
8044                                *forward_body_carry_in_off,
8045                                *forward_body_output_off,
8046                                forward_body_x_offs,
8047                                base,
8048                                outer_xs_offs,
8049                                &mut fwd_buf,
8050                                leaf_threshold,
8051                                &mut leaf_action,
8052                            );
8053                        }
8054                        if seg_k == 0 {
8055                            break;
8056                        }
8057                        segment_end = segment_start - 1;
8058                    }
8059                } else {
8060                    // All-trajectory path: read each carry directly
8061                    // from the saved trajectory buffer.
8062                    let mut carry_buf: Vec<u8> = vec![0u8; cb];
8063                    for t in (0..n_steps).rev() {
8064                        unsafe {
8065                            let src = if t == 0 {
8066                                base.add(*outer_init_off)
8067                            } else {
8068                                base.add(*outer_traj_off + (t - 1) * cb)
8069                            };
8070                            std::ptr::copy_nonoverlapping(src, carry_buf.as_mut_ptr(), cb);
8071                        }
8072                        process_iter(t, &carry_buf, &mut dcarry, &mut body_buf);
8073                    }
8074                }
8075
8076                unsafe {
8077                    std::ptr::copy_nonoverlapping(dcarry.as_ptr(), base.add(*outer_dinit_off), cb);
8078                }
8079            }
8080
8081            Thunk::ScanBackwardXs {
8082                body_vjp,
8083                body_init,
8084                body_carry_in_off,
8085                body_x_offs,
8086                body_d_output_off,
8087                body_dcarry_out_off,
8088                body_dxs_out_off,
8089                outer_init_off,
8090                outer_traj_off,
8091                outer_upstream_off,
8092                outer_xs_offs,
8093                outer_dxs_off,
8094                length,
8095                carry_bytes,
8096                carry_elem_size,
8097                per_step_bytes,
8098                save_trajectory,
8099                num_checkpoints,
8100                forward_body,
8101                forward_body_init,
8102                forward_body_carry_in_off,
8103                forward_body_output_off,
8104                forward_body_x_offs,
8105            } => {
8106                let cb = *carry_bytes as usize;
8107                let psb = *per_step_bytes as usize;
8108                let n_steps = *length as usize;
8109                let k_total = *num_checkpoints as usize;
8110                let is_recursive = k_total != 0 && k_total != n_steps;
8111                let checkpoint_t_for_k = |k: usize| -> usize {
8112                    ((k + 1) * n_steps)
8113                        .div_ceil(k_total)
8114                        .saturating_sub(1)
8115                        .min(n_steps - 1)
8116                };
8117
8118                // Forward-body recompute scratch + segment cache —
8119                // exact mirror of the ScanBackward path. With ≈√length
8120                // checkpoints, total recompute work is O(length).
8121                let mut fwd_buf: Vec<u8> = if is_recursive {
8122                    (**forward_body_init.as_ref().unwrap()).clone()
8123                } else {
8124                    Vec::new()
8125                };
8126                let mut seg_cache: Vec<u8> = Vec::new();
8127                let mut seg_start_t: usize = usize::MAX;
8128                let mut seg_count: usize = 0;
8129                let recompute_carry_t =
8130                    |t: usize,
8131                     dst: &mut [u8],
8132                     fwd_buf: &mut Vec<u8>,
8133                     seg_cache: &mut Vec<u8>,
8134                     seg_start_t: &mut usize,
8135                     seg_count: &mut usize| {
8136                        if !is_recursive {
8137                            unsafe {
8138                                let src = if t == 0 {
8139                                    base.add(*outer_init_off)
8140                                } else {
8141                                    base.add(*outer_traj_off + (t - 1) * cb)
8142                                };
8143                                std::ptr::copy_nonoverlapping(src, dst.as_mut_ptr(), cb);
8144                            }
8145                            return;
8146                        }
8147                        if *seg_start_t != usize::MAX
8148                            && t >= *seg_start_t
8149                            && t < *seg_start_t + *seg_count
8150                        {
8151                            let off = (t - *seg_start_t) * cb;
8152                            dst.copy_from_slice(&seg_cache[off..off + cb]);
8153                            return;
8154                        }
8155                        let seg_k = (0..k_total)
8156                            .find(|&k| t <= checkpoint_t_for_k(k))
8157                            .unwrap_or(k_total - 1);
8158                        let (anchor_t, anchor_ptr): (usize, *const u8) = if seg_k == 0 {
8159                            (0, unsafe { base.add(*outer_init_off) as *const u8 })
8160                        } else {
8161                            let prev_ck = checkpoint_t_for_k(seg_k - 1);
8162                            (prev_ck + 1, unsafe {
8163                                base.add(*outer_traj_off + (seg_k - 1) * cb) as *const u8
8164                            })
8165                        };
8166                        let seg_end_t = checkpoint_t_for_k(seg_k);
8167                        let seg_size = seg_end_t - anchor_t + 1;
8168
8169                        fwd_buf.copy_from_slice(forward_body_init.as_ref().unwrap());
8170                        unsafe {
8171                            std::ptr::copy_nonoverlapping(
8172                                anchor_ptr,
8173                                fwd_buf.as_mut_ptr().add(*forward_body_carry_in_off),
8174                                cb,
8175                            );
8176                        }
8177                        seg_cache.resize(seg_size * cb, 0u8);
8178                        seg_cache[0..cb].copy_from_slice(
8179                            &fwd_buf[*forward_body_carry_in_off..*forward_body_carry_in_off + cb],
8180                        );
8181                        let fb_sched = forward_body.as_ref().unwrap();
8182                        for i in 1..seg_size {
8183                            let cur_iter = anchor_t + i - 1;
8184                            for (idx, fb_x_off) in forward_body_x_offs.iter().enumerate() {
8185                                let (outer_xs_off, x_psb) = outer_xs_offs[idx];
8186                                let xb = x_psb as usize;
8187                                unsafe {
8188                                    std::ptr::copy_nonoverlapping(
8189                                        base.add(outer_xs_off + cur_iter * xb),
8190                                        fwd_buf.as_mut_ptr().add(*fb_x_off),
8191                                        xb,
8192                                    );
8193                                }
8194                            }
8195                            execute_thunks(fb_sched, fwd_buf);
8196                            if *forward_body_output_off != *forward_body_carry_in_off {
8197                                fwd_buf.copy_within(
8198                                    *forward_body_output_off..*forward_body_output_off + cb,
8199                                    *forward_body_carry_in_off,
8200                                );
8201                            }
8202                            let cache_off = i * cb;
8203                            seg_cache[cache_off..cache_off + cb].copy_from_slice(
8204                                &fwd_buf
8205                                    [*forward_body_carry_in_off..*forward_body_carry_in_off + cb],
8206                            );
8207                        }
8208                        *seg_start_t = anchor_t;
8209                        *seg_count = seg_size;
8210
8211                        let off = (t - anchor_t) * cb;
8212                        dst.copy_from_slice(&seg_cache[off..off + cb]);
8213                    };
8214
8215                let mut dcarry: Vec<u8> = vec![0u8; cb];
8216                if !*save_trajectory {
8217                    unsafe {
8218                        std::ptr::copy_nonoverlapping(
8219                            base.add(*outer_upstream_off),
8220                            dcarry.as_mut_ptr(),
8221                            cb,
8222                        );
8223                    }
8224                }
8225
8226                let mut body_buf: Vec<u8> = (**body_init).clone();
8227
8228                for t in (0..n_steps).rev() {
8229                    if *save_trajectory {
8230                        unsafe {
8231                            let up_off = *outer_upstream_off + t * cb;
8232                            match *carry_elem_size {
8233                                4 => {
8234                                    let up_ptr = base.add(up_off) as *const f32;
8235                                    let dc_ptr = dcarry.as_mut_ptr() as *mut f32;
8236                                    let n_elems = cb / 4;
8237                                    for i in 0..n_elems {
8238                                        *dc_ptr.add(i) += *up_ptr.add(i);
8239                                    }
8240                                }
8241                                8 => {
8242                                    let up_ptr = base.add(up_off) as *const f64;
8243                                    let dc_ptr = dcarry.as_mut_ptr() as *mut f64;
8244                                    let n_elems = cb / 8;
8245                                    for i in 0..n_elems {
8246                                        *dc_ptr.add(i) += *up_ptr.add(i);
8247                                    }
8248                                }
8249                                other => panic!(
8250                                    "ScanBackwardXs: unsupported carry elem size {other} \
8251                                     (only f32/f64 carries are supported today)"
8252                                ),
8253                            }
8254                        }
8255                    }
8256
8257                    // Seed body_vjp's carry input via the recompute
8258                    // helper (works for both All and Recursive modes),
8259                    // then x_t_i + d_output.
8260                    let carry_dst_start = *body_carry_in_off;
8261                    {
8262                        let carry_slice = &mut body_buf[carry_dst_start..carry_dst_start + cb];
8263                        recompute_carry_t(
8264                            t,
8265                            carry_slice,
8266                            &mut fwd_buf,
8267                            &mut seg_cache,
8268                            &mut seg_start_t,
8269                            &mut seg_count,
8270                        );
8271                    }
8272                    unsafe {
8273                        for (i, body_x_off) in body_x_offs.iter().enumerate() {
8274                            let (outer_xs_off, x_psb) = outer_xs_offs[i];
8275                            let xb = x_psb as usize;
8276                            std::ptr::copy_nonoverlapping(
8277                                base.add(outer_xs_off + t * xb),
8278                                body_buf.as_mut_ptr().add(*body_x_off),
8279                                xb,
8280                            );
8281                        }
8282                        std::ptr::copy_nonoverlapping(
8283                            dcarry.as_ptr(),
8284                            body_buf.as_mut_ptr().add(*body_d_output_off),
8285                            cb,
8286                        );
8287                    }
8288
8289                    execute_thunks(body_vjp, &mut body_buf);
8290
8291                    // Stash this step's dxs into row `t` of the outer
8292                    // [length, *per_step_xs] output.
8293                    unsafe {
8294                        std::ptr::copy_nonoverlapping(
8295                            body_buf.as_ptr().add(*body_dxs_out_off),
8296                            base.add(*outer_dxs_off + t * psb),
8297                            psb,
8298                        );
8299                    }
8300
8301                    // Update dcarry for next backward iteration.
8302                    unsafe {
8303                        std::ptr::copy_nonoverlapping(
8304                            body_buf.as_ptr().add(*body_dcarry_out_off),
8305                            dcarry.as_mut_ptr(),
8306                            cb,
8307                        );
8308                    }
8309                }
8310            }
8311
8312            Thunk::FusedMmBiasAct {
8313                a,
8314                w,
8315                bias,
8316                c,
8317                m,
8318                k,
8319                n,
8320                act,
8321            } => {
8322                let (m, k, n) = (*m as usize, *k as usize, *n as usize);
8323                unsafe {
8324                    let out = sl_mut(*c, base, m * n);
8325                    crate::blas::sgemm_auto(sl(*a, base, m * k), sl(*w, base, k * n), out, m, k, n);
8326                    match act {
8327                        Some(Activation::Gelu) => {
8328                            crate::kernels::par_bias_gelu(out, sl(*bias, base, n), m, n)
8329                        }
8330                        Some(other) => {
8331                            crate::blas::bias_add(out, sl(*bias, base, n), m, n);
8332                            apply_activation_inplace(out, *other);
8333                        }
8334                        None => crate::blas::bias_add(out, sl(*bias, base, n), m, n),
8335                    }
8336                }
8337            }
8338
8339            Thunk::FusedResidualLN {
8340                x,
8341                res,
8342                bias,
8343                g,
8344                b,
8345                out,
8346                rows,
8347                h,
8348                eps,
8349                has_bias,
8350            } => {
8351                let (rows, h) = (*rows as usize, *h as usize);
8352                unsafe {
8353                    let zero = &zero_bias[..h];
8354                    let bi = if *has_bias { sl(*bias, base, h) } else { zero };
8355                    let x_ptr = sl(*x, base, rows * h).as_ptr() as usize;
8356                    let r_ptr = sl(*res, base, rows * h).as_ptr() as usize;
8357                    let o_ptr = sl_mut(*out, base, rows * h).as_mut_ptr() as usize;
8358                    let bi_ptr = bi.as_ptr() as usize;
8359                    let g_ptr = sl(*g, base, h).as_ptr() as usize;
8360                    let b_ptr = sl(*b, base, h).as_ptr() as usize;
8361                    let e = *eps;
8362                    crate::pool::par_for(rows, 4, &|off, cnt| {
8363                        let xs =
8364                            std::slice::from_raw_parts((x_ptr as *const f32).add(off * h), cnt * h);
8365                        let rs =
8366                            std::slice::from_raw_parts((r_ptr as *const f32).add(off * h), cnt * h);
8367                        let os = std::slice::from_raw_parts_mut(
8368                            (o_ptr as *mut f32).add(off * h),
8369                            cnt * h,
8370                        );
8371                        let bi = std::slice::from_raw_parts(bi_ptr as *const f32, h);
8372                        let g = std::slice::from_raw_parts(g_ptr as *const f32, h);
8373                        let b = std::slice::from_raw_parts(b_ptr as *const f32, h);
8374                        crate::kernels::residual_bias_layer_norm(xs, rs, bi, g, b, os, cnt, h, e);
8375                    });
8376                }
8377            }
8378
8379            Thunk::FusedResidualRmsNorm {
8380                x,
8381                res,
8382                bias,
8383                g,
8384                b,
8385                out,
8386                rows,
8387                h,
8388                eps,
8389                has_bias,
8390            } => {
8391                let (rows, h) = (*rows as usize, *h as usize);
8392                unsafe {
8393                    let zero = &zero_bias[..h];
8394                    let bi = if *has_bias { sl(*bias, base, h) } else { zero };
8395                    let x_ptr = sl(*x, base, rows * h).as_ptr() as usize;
8396                    let r_ptr = sl(*res, base, rows * h).as_ptr() as usize;
8397                    let o_ptr = sl_mut(*out, base, rows * h).as_mut_ptr() as usize;
8398                    let bi_ptr = bi.as_ptr() as usize;
8399                    let g_ptr = sl(*g, base, h).as_ptr() as usize;
8400                    let b_ptr = sl(*b, base, h).as_ptr() as usize;
8401                    let e = *eps;
8402                    crate::pool::par_for(rows, 4, &|off, cnt| {
8403                        let xs =
8404                            std::slice::from_raw_parts((x_ptr as *const f32).add(off * h), cnt * h);
8405                        let rs =
8406                            std::slice::from_raw_parts((r_ptr as *const f32).add(off * h), cnt * h);
8407                        let os = std::slice::from_raw_parts_mut(
8408                            (o_ptr as *mut f32).add(off * h),
8409                            cnt * h,
8410                        );
8411                        let bi = std::slice::from_raw_parts(bi_ptr as *const f32, h);
8412                        let g = std::slice::from_raw_parts(g_ptr as *const f32, h);
8413                        let b = std::slice::from_raw_parts(b_ptr as *const f32, h);
8414                        crate::kernels::residual_bias_rms_norm(xs, rs, bi, g, b, os, cnt, h, e);
8415                    });
8416                }
8417            }
8418
8419            Thunk::BiasAdd {
8420                src,
8421                bias,
8422                dst,
8423                m,
8424                n,
8425            } => {
8426                let (m, n) = (*m as usize, *n as usize);
8427                unsafe {
8428                    let out = sl_mut(*dst, base, m * n);
8429                    out.copy_from_slice(sl(*src, base, m * n));
8430                    crate::blas::bias_add(out, sl(*bias, base, n), m, n);
8431                }
8432            }
8433
8434            Thunk::BinaryFull {
8435                lhs,
8436                rhs,
8437                dst,
8438                len,
8439                lhs_len,
8440                rhs_len,
8441                op,
8442                out_dims_bcast,
8443                bcast_lhs_strides,
8444                bcast_rhs_strides,
8445            } => {
8446                let len = *len as usize;
8447                let ll = (*lhs_len as usize).max(1);
8448                let rl = (*rhs_len as usize).max(1);
8449                unsafe {
8450                    let l = sl(*lhs, base, ll);
8451                    let r = sl(*rhs, base, rl);
8452                    let o = sl_mut(*dst, base, len);
8453                    // Fast path: shapes match exactly → NEON-vectorized loop.
8454                    if ll == len && rl == len {
8455                        #[cfg(target_arch = "aarch64")]
8456                        if matches!(op, BinaryOp::Add | BinaryOp::Mul) {
8457                            use std::arch::aarch64::*;
8458                            let chunks = len / 4;
8459                            for c in 0..chunks {
8460                                let off = c * 4;
8461                                let vl = vld1q_f32(l.as_ptr().add(off));
8462                                let vr = vld1q_f32(r.as_ptr().add(off));
8463                                let res = match op {
8464                                    BinaryOp::Add => vaddq_f32(vl, vr),
8465                                    BinaryOp::Mul => vmulq_f32(vl, vr),
8466                                    _ => unreachable!(),
8467                                };
8468                                vst1q_f32(o.as_mut_ptr().add(off), res);
8469                            }
8470                            for i in (chunks * 4)..len {
8471                                o[i] = match op {
8472                                    BinaryOp::Add => l[i] + r[i],
8473                                    BinaryOp::Mul => l[i] * r[i],
8474                                    _ => unreachable!(),
8475                                };
8476                            }
8477                            // `continue` to next thunk in the schedule — a
8478                            // bare `return` here used to exit execute_thunks
8479                            // entirely, silently dropping every thunk after
8480                            // the first BinaryFull (catastrophic for chained
8481                            // adds in BERT embedding stage).
8482                            continue;
8483                        }
8484                    }
8485                    if !out_dims_bcast.is_empty() {
8486                        // Shape-aware broadcast path: correct for
8487                        // bidirectional `[N,1] op [1,S]` etc.
8488                        let rank = out_dims_bcast.len();
8489                        let mut coords = vec![0u32; rank];
8490                        for i in 0..len {
8491                            let mut rem = i;
8492                            for ax in (0..rank).rev() {
8493                                let sz = out_dims_bcast[ax] as usize;
8494                                coords[ax] = (rem % sz) as u32;
8495                                rem /= sz;
8496                            }
8497                            let mut li: usize = 0;
8498                            let mut ri: usize = 0;
8499                            for ax in 0..rank {
8500                                li += coords[ax] as usize * bcast_lhs_strides[ax] as usize;
8501                                ri += coords[ax] as usize * bcast_rhs_strides[ax] as usize;
8502                            }
8503                            o[i] = match op {
8504                                BinaryOp::Add => l[li] + r[ri],
8505                                BinaryOp::Sub => l[li] - r[ri],
8506                                BinaryOp::Mul => l[li] * r[ri],
8507                                BinaryOp::Div => l[li] / r[ri],
8508                                BinaryOp::Max => l[li].max(r[ri]),
8509                                BinaryOp::Min => l[li].min(r[ri]),
8510                                BinaryOp::Pow => l[li].powf(r[ri]),
8511                            };
8512                        }
8513                    } else {
8514                        // Fallback: legacy modulo path (dynamic shapes only).
8515                        for i in 0..len {
8516                            let li = if ll == 1 { 0 } else { i % ll };
8517                            let ri = if rl == 1 { 0 } else { i % rl };
8518                            o[i] = match op {
8519                                BinaryOp::Add => l[li] + r[ri],
8520                                BinaryOp::Sub => l[li] - r[ri],
8521                                BinaryOp::Mul => l[li] * r[ri],
8522                                BinaryOp::Div => l[li] / r[ri],
8523                                BinaryOp::Max => l[li].max(r[ri]),
8524                                BinaryOp::Min => l[li].min(r[ri]),
8525                                BinaryOp::Pow => l[li].powf(r[ri]),
8526                            };
8527                        }
8528                    }
8529                }
8530            }
8531
8532            Thunk::Gather {
8533                table,
8534                table_len,
8535                idx,
8536                dst,
8537                num_idx,
8538                trailing,
8539            } => {
8540                let (ni, tr) = (*num_idx as usize, *trailing as usize);
8541                unsafe {
8542                    let tab = sl(*table, base, *table_len as usize);
8543                    let ids = sl(*idx, base, ni);
8544                    let out = sl_mut(*dst, base, ni * tr);
8545                    for i in 0..ni {
8546                        let row = ids[i] as usize;
8547                        out[i * tr..(i + 1) * tr].copy_from_slice(&tab[row * tr..(row + 1) * tr]);
8548                    }
8549                }
8550            }
8551
8552            Thunk::Narrow {
8553                src,
8554                dst,
8555                outer,
8556                src_stride,
8557                dst_stride,
8558                inner,
8559                elem_bytes,
8560            } => {
8561                let f = narrow_thunk_closure(
8562                    *src,
8563                    *dst,
8564                    *outer,
8565                    *src_stride,
8566                    *dst_stride,
8567                    *inner,
8568                    *elem_bytes,
8569                );
8570                f(base);
8571            }
8572
8573            Thunk::Copy { src, dst, len } => {
8574                let len = *len as usize;
8575                unsafe {
8576                    let s = sl(*src, base, len);
8577                    let d = sl_mut(*dst, base, len);
8578                    d.copy_from_slice(s);
8579                }
8580            }
8581
8582            Thunk::LayerNorm {
8583                src,
8584                g,
8585                b,
8586                dst,
8587                rows,
8588                h,
8589                eps,
8590            } => {
8591                let (rows, h) = (*rows as usize, *h as usize);
8592                unsafe {
8593                    let input = sl(*src, base, rows * h);
8594                    let gamma = sl(*g, base, h);
8595                    let beta = sl(*b, base, h);
8596                    let output = sl_mut(*dst, base, rows * h);
8597                    // Parallelize across rows (same pattern as FusedResidualLN)
8598                    if rows >= 4 && rows * h >= 30_000 {
8599                        let i_ptr = input.as_ptr() as usize;
8600                        let o_ptr = output.as_mut_ptr() as usize;
8601                        let g_ptr = gamma.as_ptr() as usize;
8602                        let b_ptr = beta.as_ptr() as usize;
8603                        let e = *eps;
8604                        crate::pool::par_for(rows, 4, &|off, cnt| {
8605                            let inp = std::slice::from_raw_parts(
8606                                (i_ptr as *const f32).add(off * h),
8607                                cnt * h,
8608                            );
8609                            let out = std::slice::from_raw_parts_mut(
8610                                (o_ptr as *mut f32).add(off * h),
8611                                cnt * h,
8612                            );
8613                            let g = std::slice::from_raw_parts(g_ptr as *const f32, h);
8614                            let b = std::slice::from_raw_parts(b_ptr as *const f32, h);
8615                            for row in 0..cnt {
8616                                crate::kernels::layer_norm_row(
8617                                    &inp[row * h..(row + 1) * h],
8618                                    g,
8619                                    b,
8620                                    &mut out[row * h..(row + 1) * h],
8621                                    h,
8622                                    e,
8623                                );
8624                            }
8625                        });
8626                    } else {
8627                        for row in 0..rows {
8628                            crate::kernels::layer_norm_row(
8629                                &input[row * h..(row + 1) * h],
8630                                gamma,
8631                                beta,
8632                                &mut output[row * h..(row + 1) * h],
8633                                h,
8634                                *eps,
8635                            );
8636                        }
8637                    }
8638                }
8639            }
8640
8641            Thunk::GroupNorm {
8642                src,
8643                g,
8644                b,
8645                dst,
8646                n,
8647                c,
8648                h,
8649                w,
8650                num_groups,
8651                eps,
8652            } => {
8653                let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
8654                let plane = c * h * w;
8655                unsafe {
8656                    for ni in 0..n {
8657                        let input = sl(*src, base.add(ni * plane), plane);
8658                        let gamma = sl(*g, base, c);
8659                        let beta = sl(*b, base, c);
8660                        let output = sl_mut(*dst, base.add(ni * plane), plane);
8661                        crate::kernels::group_norm_nchw(
8662                            input,
8663                            gamma,
8664                            beta,
8665                            output,
8666                            1,
8667                            c,
8668                            h,
8669                            w,
8670                            *num_groups as usize,
8671                            *eps,
8672                        );
8673                    }
8674                }
8675            }
8676
8677            Thunk::LayerNorm2d {
8678                src,
8679                g,
8680                b,
8681                dst,
8682                n,
8683                c,
8684                h,
8685                w,
8686                eps,
8687            } => {
8688                let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
8689                let plane = c * h * w;
8690                unsafe {
8691                    let input = sl(*src, base, n * plane);
8692                    let gamma = sl(*g, base, c);
8693                    let beta = sl(*b, base, c);
8694                    let output = sl_mut(*dst, base, n * plane);
8695                    crate::kernels::layer_norm2d_nchw(input, gamma, beta, output, n, c, h, w, *eps);
8696                }
8697            }
8698
8699            Thunk::ConvTranspose2d {
8700                src,
8701                weight,
8702                dst,
8703                n,
8704                c_in,
8705                h,
8706                w_in,
8707                c_out,
8708                h_out,
8709                w_out,
8710                kh,
8711                kw,
8712                sh,
8713                sw,
8714                ph,
8715                pw,
8716                dh,
8717                dw,
8718                groups,
8719            } => {
8720                let n = *n as usize;
8721                let c_in = *c_in as usize;
8722                let h = *h as usize;
8723                let w_in = *w_in as usize;
8724                let c_out = *c_out as usize;
8725                let h_out = *h_out as usize;
8726                let w_out = *w_out as usize;
8727                unsafe {
8728                    let inp = sl(*src, base, n * c_in * h * w_in);
8729                    let wt = sl(
8730                        *weight,
8731                        base,
8732                        c_in * (c_out / *groups as usize) * (*kh as usize) * (*kw as usize),
8733                    );
8734                    let out = sl_mut(*dst, base, n * c_out * h_out * w_out);
8735                    crate::kernels::conv_transpose2d_nchw(
8736                        inp,
8737                        wt,
8738                        out,
8739                        n,
8740                        c_in,
8741                        h,
8742                        w_in,
8743                        c_out,
8744                        h_out,
8745                        w_out,
8746                        *kh as usize,
8747                        *kw as usize,
8748                        *sh as usize,
8749                        *sw as usize,
8750                        *ph as usize,
8751                        *pw as usize,
8752                        *dh as usize,
8753                        *dw as usize,
8754                        *groups as usize,
8755                    );
8756                }
8757            }
8758
8759            Thunk::ResizeNearest2x {
8760                src,
8761                dst,
8762                n,
8763                c,
8764                h,
8765                w,
8766            } => {
8767                let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
8768                let in_plane = c * h * w;
8769                let out_plane = c * h * 2 * w * 2;
8770                unsafe {
8771                    for ni in 0..n {
8772                        let input = sl(*src, base.add(ni * in_plane), in_plane);
8773                        let output = sl_mut(*dst, base.add(ni * out_plane), out_plane);
8774                        crate::kernels::resize_nearest_2x_nchw(input, output, c, h, w);
8775                    }
8776                }
8777            }
8778
8779            Thunk::AxialRope2d {
8780                src,
8781                dst,
8782                batch,
8783                seq,
8784                hidden,
8785                end_x,
8786                end_y,
8787                head_dim,
8788                num_heads,
8789                theta,
8790                repeat_factor,
8791            } => {
8792                let b = *batch as usize;
8793                let s = *seq as usize;
8794                let hdim = *head_dim as usize;
8795                let nh = *num_heads as usize;
8796                let plane = s * (*hidden as usize);
8797                unsafe {
8798                    for bi in 0..b {
8799                        let input = sl(*src, base.add(bi * plane), plane);
8800                        let output = sl_mut(*dst, base.add(bi * plane), plane);
8801                        let rotated = rlx_ir::ops::axial_rope2d::apply_axial_rope2d(
8802                            input,
8803                            nh,
8804                            s,
8805                            hdim,
8806                            *end_x as usize,
8807                            *end_y as usize,
8808                            *theta,
8809                            *repeat_factor as usize,
8810                        );
8811                        output.copy_from_slice(&rotated);
8812                    }
8813                }
8814            }
8815
8816            Thunk::RmsNorm {
8817                src,
8818                g,
8819                b,
8820                dst,
8821                rows,
8822                h,
8823                eps,
8824            } => {
8825                let (rows, h) = (*rows as usize, *h as usize);
8826                unsafe {
8827                    let input = sl(*src, base, rows * h);
8828                    let gamma = sl(*g, base, h);
8829                    let beta = sl(*b, base, h);
8830                    let output = sl_mut(*dst, base, rows * h);
8831                    let inv_h = 1.0 / h as f32;
8832                    for row in 0..rows {
8833                        let in_row = &input[row * h..(row + 1) * h];
8834                        let out_row = &mut output[row * h..(row + 1) * h];
8835                        // RMS = sqrt(mean(x^2) + eps); scale = 1/RMS.
8836                        let mut sumsq = 0f32;
8837                        for &v in in_row {
8838                            sumsq += v * v;
8839                        }
8840                        let inv_rms = (sumsq * inv_h + *eps).sqrt().recip();
8841                        for i in 0..h {
8842                            out_row[i] = in_row[i] * inv_rms * gamma[i] + beta[i];
8843                        }
8844                    }
8845                }
8846            }
8847
8848            Thunk::Softmax { data, rows, cols } => {
8849                let (rows, cols) = (*rows as usize, *cols as usize);
8850                unsafe {
8851                    crate::kernels::neon_softmax(sl_mut(*data, base, rows * cols), rows, cols);
8852                }
8853            }
8854
8855            Thunk::Cumsum {
8856                src,
8857                dst,
8858                rows,
8859                cols,
8860                exclusive,
8861            } => {
8862                let (rows, cols) = (*rows as usize, *cols as usize);
8863                unsafe {
8864                    let s = sl(*src, base, rows * cols);
8865                    let d = sl_mut(*dst, base, rows * cols);
8866                    if *exclusive {
8867                        for r in 0..rows {
8868                            let mut acc = 0.0f32;
8869                            for c in 0..cols {
8870                                d[r * cols + c] = acc;
8871                                acc += s[r * cols + c];
8872                            }
8873                        }
8874                    } else {
8875                        for r in 0..rows {
8876                            let mut acc = 0.0f32;
8877                            for c in 0..cols {
8878                                acc += s[r * cols + c];
8879                                d[r * cols + c] = acc;
8880                            }
8881                        }
8882                    }
8883                }
8884            }
8885
8886            Thunk::Sample {
8887                logits,
8888                dst,
8889                batch,
8890                vocab,
8891                top_k,
8892                top_p,
8893                temperature,
8894                seed,
8895            } => {
8896                let (b, v) = (*batch as usize, *vocab as usize);
8897                let k = (*top_k as usize).min(v);
8898                unsafe {
8899                    let lg = sl(*logits, base, b * v);
8900                    let out = sl_mut(*dst, base, b);
8901                    let mut rng =
8902                        rlx_ir::Philox4x32::new(if *seed == 0 { 0xDEADBEEF } else { *seed });
8903                    for bi in 0..b {
8904                        let row = &lg[bi * v..(bi + 1) * v];
8905                        out[bi] = sample_row(row, k, *top_p, *temperature, &mut rng) as f32;
8906                    }
8907                }
8908            }
8909
8910            Thunk::GatedDeltaNet {
8911                q,
8912                k,
8913                v,
8914                g,
8915                beta,
8916                state,
8917                dst,
8918                batch,
8919                seq,
8920                heads,
8921                state_size,
8922            } => unsafe {
8923                execute_gated_delta_net_f32(
8924                    *q,
8925                    *k,
8926                    *v,
8927                    *g,
8928                    *beta,
8929                    *state,
8930                    *dst,
8931                    *batch as usize,
8932                    *seq as usize,
8933                    *heads as usize,
8934                    *state_size as usize,
8935                    base,
8936                );
8937            },
8938
8939            Thunk::SelectiveScan {
8940                x,
8941                delta,
8942                a,
8943                b: bp,
8944                c: cp,
8945                dst,
8946                batch,
8947                seq,
8948                hidden,
8949                state_size,
8950            } => {
8951                let (b, s, h, n) = (
8952                    *batch as usize,
8953                    *seq as usize,
8954                    *hidden as usize,
8955                    *state_size as usize,
8956                );
8957                unsafe {
8958                    let xs = sl(*x, base, b * s * h);
8959                    let dt = sl(*delta, base, b * s * h);
8960                    let am = sl(*a, base, h * n);
8961                    let bm = sl(*bp, base, b * s * n);
8962                    let cm = sl(*cp, base, b * s * n);
8963                    let out = sl_mut(*dst, base, b * s * h);
8964
8965                    // State buffer per-batch: h channels × n state.
8966                    // Sequential along the seq dimension; could
8967                    // parallelize over batch+channel later.
8968                    let mut state = vec![0f32; h * n];
8969                    for bi in 0..b {
8970                        // Reset state at the start of each batch row.
8971                        for v in state.iter_mut() {
8972                            *v = 0.0;
8973                        }
8974                        for si in 0..s {
8975                            let x_row = &xs[bi * s * h + si * h..bi * s * h + (si + 1) * h];
8976                            let dt_row = &dt[bi * s * h + si * h..bi * s * h + (si + 1) * h];
8977                            let b_row = &bm[bi * s * n + si * n..bi * s * n + (si + 1) * n];
8978                            let c_row = &cm[bi * s * n + si * n..bi * s * n + (si + 1) * n];
8979                            let out_row = &mut out[bi * s * h + si * h..bi * s * h + (si + 1) * h];
8980
8981                            for ci in 0..h {
8982                                let d = dt_row[ci];
8983                                let xv = x_row[ci];
8984                                let mut acc = 0f32;
8985                                for ni in 0..n {
8986                                    // Discretize: exp(d * a) and d * b.
8987                                    let da = (d * am[ci * n + ni]).exp();
8988                                    state[ci * n + ni] =
8989                                        da * state[ci * n + ni] + d * b_row[ni] * xv;
8990                                    acc += c_row[ni] * state[ci * n + ni];
8991                                }
8992                                out_row[ci] = acc;
8993                            }
8994                        }
8995                    }
8996                }
8997            }
8998
8999            Thunk::DequantMatMul {
9000                x,
9001                w_q,
9002                scale,
9003                zp,
9004                dst,
9005                m,
9006                k,
9007                n,
9008                block_size,
9009                is_asymmetric,
9010            } => {
9011                let (m, k, n, bs) = (*m as usize, *k as usize, *n as usize, *block_size as usize);
9012                let n_blocks = k.div_ceil(bs);
9013                unsafe {
9014                    let xs = sl(*x, base, m * k);
9015                    let w_bytes = std::slice::from_raw_parts(base.add(*w_q) as *const i8, k * n);
9016                    let scales = sl(*scale, base, n_blocks * n);
9017                    let zps = if *is_asymmetric {
9018                        sl(*zp, base, n_blocks * n)
9019                    } else {
9020                        &[][..]
9021                    };
9022                    let out = sl_mut(*dst, base, m * n);
9023                    dequant_matmul_int8(xs, w_bytes, scales, zps, out, m, k, n, bs, *is_asymmetric);
9024                }
9025            }
9026
9027            Thunk::DequantMatMulGguf {
9028                x,
9029                w_q,
9030                dst,
9031                m,
9032                k,
9033                n,
9034                scheme,
9035            } => {
9036                let (m, k, n) = (*m as usize, *k as usize, *n as usize);
9037                let block_bytes = scheme.gguf_block_bytes() as usize;
9038                let block_elems = scheme.gguf_block_size() as usize;
9039                debug_assert!(
9040                    block_bytes > 0 && block_elems > 0,
9041                    "non-GGUF scheme in GGUF arm"
9042                );
9043                debug_assert!(
9044                    (k * n).is_multiple_of(block_elems),
9045                    "k*n={} not aligned to GGUF block size {}",
9046                    k * n,
9047                    block_elems
9048                );
9049                let total_bytes = (k * n) / block_elems * block_bytes;
9050                unsafe {
9051                    let xs = sl(*x, base, m * k);
9052                    let w_bytes_ptr = base.add(*w_q) as *const u8;
9053                    let w_bytes = std::slice::from_raw_parts(w_bytes_ptr, total_bytes);
9054                    let out = sl_mut(*dst, base, m * n);
9055                    crate::gguf_matmul::gguf_matmul_bt(xs, w_bytes, out, m, k, n, *scheme);
9056                }
9057            }
9058
9059            Thunk::DequantMatMulInt4 {
9060                x,
9061                w_q,
9062                scale,
9063                zp,
9064                dst,
9065                m,
9066                k,
9067                n,
9068                block_size,
9069                is_asymmetric,
9070            } => {
9071                let (m, k, n, bs) = (*m as usize, *k as usize, *n as usize, *block_size as usize);
9072                let n_blocks = k.div_ceil(bs);
9073                unsafe {
9074                    let xs = sl(*x, base, m * k);
9075                    let w_bytes = std::slice::from_raw_parts(
9076                        base.add(*w_q) as *const u8,
9077                        (k * n).div_ceil(2),
9078                    );
9079                    let scales = sl(*scale, base, n_blocks * n);
9080                    let zps = if *is_asymmetric {
9081                        sl(*zp, base, n_blocks * n)
9082                    } else {
9083                        &[][..]
9084                    };
9085                    let out = sl_mut(*dst, base, m * n);
9086                    dequant_matmul_int4(xs, w_bytes, scales, zps, out, m, k, n, bs, *is_asymmetric);
9087                }
9088            }
9089
9090            Thunk::DequantMatMulFp8 {
9091                x,
9092                w_q,
9093                scale,
9094                dst,
9095                m,
9096                k,
9097                n,
9098                e5m2,
9099            } => {
9100                let (m, k, n) = (*m as usize, *k as usize, *n as usize);
9101                unsafe {
9102                    let xs = sl(*x, base, m * k);
9103                    let w_bytes = std::slice::from_raw_parts(base.add(*w_q) as *const u8, k * n);
9104                    let scales = sl(*scale, base, n);
9105                    let out = sl_mut(*dst, base, m * n);
9106                    dequant_matmul_fp8(xs, w_bytes, scales, out, m, k, n, *e5m2);
9107                }
9108            }
9109
9110            Thunk::DequantMatMulNvfp4 {
9111                x,
9112                w_q,
9113                scale,
9114                global_scale,
9115                dst,
9116                m,
9117                k,
9118                n,
9119            } => {
9120                let (m, k, n) = (*m as usize, *k as usize, *n as usize);
9121                let n_scale = k.div_ceil(rlx_ir::NVFP4_GROUP_SIZE) * n;
9122                unsafe {
9123                    let xs = sl(*x, base, m * k);
9124                    let w_bytes = std::slice::from_raw_parts(
9125                        base.add(*w_q) as *const u8,
9126                        (k * n).div_ceil(2),
9127                    );
9128                    let scale_bytes =
9129                        std::slice::from_raw_parts(base.add(*scale) as *const u8, n_scale);
9130                    let gs = sl(*global_scale, base, 1)[0];
9131                    let out = sl_mut(*dst, base, m * n);
9132                    dequant_matmul_nvfp4(xs, w_bytes, scale_bytes, gs, out, m, k, n);
9133                }
9134            }
9135
9136            Thunk::LoraMatMul {
9137                x,
9138                w,
9139                a,
9140                b,
9141                dst,
9142                m,
9143                k,
9144                n,
9145                r,
9146                scale,
9147            } => {
9148                let (m, k, n, r) = (*m as usize, *k as usize, *n as usize, *r as usize);
9149                unsafe {
9150                    let xs = sl(*x, base, m * k);
9151                    let ws = sl(*w, base, k * n);
9152                    let a_s = sl(*a, base, k * r);
9153                    let bs = sl(*b, base, r * n);
9154                    let out = sl_mut(*dst, base, m * n);
9155                    crate::blas::sgemm(xs, ws, out, m, k, n);
9156                    let mut tmp = vec![0f32; m * r];
9157                    crate::blas::sgemm(xs, a_s, &mut tmp, m, k, r);
9158                    if *scale != 1.0 {
9159                        for v in tmp.iter_mut() {
9160                            *v *= *scale;
9161                        }
9162                    }
9163                    crate::blas::sgemm_accumulate(&tmp, bs, out, m, r, n);
9164                }
9165            }
9166
9167            Thunk::Attention {
9168                q,
9169                k,
9170                v,
9171                mask,
9172                out,
9173                batch,
9174                seq,
9175                kv_seq,
9176                heads,
9177                head_dim,
9178                mask_kind,
9179                q_row_stride,
9180                k_row_stride,
9181                v_row_stride,
9182                bhsd,
9183            } => {
9184                let (b, q_s, k_s, nh, dh) = (
9185                    *batch as usize,
9186                    *seq as usize,
9187                    *kv_seq as usize,
9188                    *heads as usize,
9189                    *head_dim as usize,
9190                );
9191                let hs = nh * dh;
9192                // For [B, H, S, D] layout each (b, h) tile is dense
9193                // contiguous; the qrs/krs/vrs strides are not used.
9194                let (qrs, krs, vrs) = if *bhsd {
9195                    (dh, dh, dh)
9196                } else {
9197                    (
9198                        *q_row_stride as usize,
9199                        *k_row_stride as usize,
9200                        *v_row_stride as usize,
9201                    )
9202                };
9203                let bhsd = *bhsd;
9204                let _ = (q_row_stride, k_row_stride, v_row_stride);
9205                let scale = (dh as f32).powf(-0.5);
9206                let ss = q_s * k_s;
9207                let cfg = crate::config::RuntimeConfig::global();
9208                unsafe {
9209                    // Slice lengths cover the strided span. When Q/K/V
9210                    // alias the parent QKV (post-#46-fusion), the same
9211                    // bytes back all three slices — compiler bounds
9212                    // checks see the right size. For [B, H, S, D] the
9213                    // buffer is densely B*H*S*D elements; the row
9214                    // strides aren't used.
9215                    let q_len = if bhsd {
9216                        b * nh * q_s * dh
9217                    } else {
9218                        b * q_s * qrs
9219                    };
9220                    let k_len = if bhsd {
9221                        b * nh * k_s * dh
9222                    } else {
9223                        b * k_s * krs
9224                    };
9225                    let v_len = if bhsd {
9226                        b * nh * k_s * dh
9227                    } else {
9228                        b * k_s * vrs
9229                    };
9230                    let q_data = sl(*q, base, q_len);
9231                    let k_data = sl(*k, base, k_len);
9232                    let v_data = sl(*v, base, v_len);
9233                    let mask_data: &[f32] = match mask_kind {
9234                        rlx_ir::op::MaskKind::Custom => sl(*mask, base, b * k_s),
9235                        rlx_ir::op::MaskKind::Bias => sl(*mask, base, b * nh * q_s * k_s),
9236                        _ => &[],
9237                    };
9238                    let out_len = if bhsd {
9239                        b * nh * q_s * dh
9240                    } else {
9241                        b * q_s * hs
9242                    };
9243                    let out_data = sl_mut(*out, base, out_len);
9244
9245                    // ── [B, H, S, D] fallback ──────────────────────
9246                    // The NEON / strided-BLAS specializations below
9247                    // are written for the [B, S, H, D] layout. When
9248                    // the input is head-major ([B, H, S, D] —
9249                    // matching rlx-cuda / rlx-rocm / rlx-tpu), bypass
9250                    // them and run a simple (correct but slower)
9251                    // scalar implementation. Production-CPU inference
9252                    // graphs use [B, S, H, D] so they still hit the
9253                    // hot path; cross-backend parity tests use
9254                    // [B, H, S, D] and land here.
9255                    if bhsd {
9256                        let scores = &mut sdpa_scores[..ss];
9257                        for bi in 0..b {
9258                            for hi in 0..nh {
9259                                let q_head_base = bi * nh * q_s * dh + hi * q_s * dh;
9260                                let k_head_base = bi * nh * k_s * dh + hi * k_s * dh;
9261                                // Q@K^T
9262                                for qi in 0..q_s {
9263                                    let q_base = q_head_base + qi * dh;
9264                                    for ki in 0..k_s {
9265                                        let k_base = k_head_base + ki * dh;
9266                                        let mut dot = 0f32;
9267                                        for d in 0..dh {
9268                                            dot += q_data[q_base + d] * k_data[k_base + d];
9269                                        }
9270                                        scores[qi * k_s + ki] = dot * scale;
9271                                        if matches!(mask_kind, rlx_ir::op::MaskKind::Custom)
9272                                            && !mask_data.is_empty()
9273                                            && mask_data[bi * k_s + ki] < mask_thr
9274                                        {
9275                                            scores[qi * k_s + ki] = mask_neg;
9276                                        }
9277                                    }
9278                                }
9279                                if matches!(mask_kind, rlx_ir::op::MaskKind::Bias) {
9280                                    let off = (bi * nh + hi) * q_s * k_s;
9281                                    for i in 0..q_s * k_s {
9282                                        scores[i] += mask_data[off + i];
9283                                    }
9284                                }
9285                                apply_synthetic_mask(scores, q_s, k_s, *mask_kind);
9286                                crate::kernels::neon_softmax(scores, q_s, k_s);
9287                                // score @ V
9288                                for qi in 0..q_s {
9289                                    let o_base = q_head_base + qi * dh;
9290                                    for d in 0..dh {
9291                                        out_data[o_base + d] = 0.0;
9292                                    }
9293                                    for ki in 0..k_s {
9294                                        let sc = scores[qi * k_s + ki];
9295                                        if sc > score_thr {
9296                                            let v_base = k_head_base + ki * dh;
9297                                            for d in 0..dh {
9298                                                out_data[o_base + d] += sc * v_data[v_base + d];
9299                                            }
9300                                        }
9301                                    }
9302                                }
9303                            }
9304                        }
9305                        continue;
9306                    }
9307
9308                    // ── Auto-select kernel: NEON dots vs strided BLAS ───
9309                    // For tiny inputs (batch=1, short seq), per-head BLAS call
9310                    // overhead (~0.5µs × 2 calls × num_heads × num_layers)
9311                    // exceeds the NEON compute cost. Use direct strided NEON
9312                    // with zero dispatch overhead.
9313                    // For batch≥2: always BLAS + par_for (parallelism wins).
9314                    if b == 1 && q_s.max(k_s) <= cfg.sdpa_seq_threshold {
9315                        // ── Sequential NEON path (zero overhead) ──
9316                        let scores = &mut sdpa_scores[..ss];
9317                        #[cfg(target_arch = "aarch64")]
9318                        let neon_chunks = dh / 4;
9319
9320                        for bi in 0..b {
9321                            for hi in 0..nh {
9322                                // Q@K^T via strided NEON dot products
9323                                for qi in 0..q_s {
9324                                    let q_off = bi * q_s * qrs + qi * qrs + hi * dh;
9325                                    for ki in 0..k_s {
9326                                        let k_off = bi * k_s * krs + ki * krs + hi * dh;
9327                                        #[cfg(target_arch = "aarch64")]
9328                                        let mut dot;
9329                                        #[cfg(not(target_arch = "aarch64"))]
9330                                        let mut dot = 0f32;
9331                                        #[cfg(target_arch = "aarch64")]
9332                                        {
9333                                            use std::arch::aarch64::*;
9334                                            let mut acc = vdupq_n_f32(0.0);
9335                                            for c in 0..neon_chunks {
9336                                                let vq =
9337                                                    vld1q_f32(q_data.as_ptr().add(q_off + c * 4));
9338                                                let vk =
9339                                                    vld1q_f32(k_data.as_ptr().add(k_off + c * 4));
9340                                                acc = vfmaq_f32(acc, vq, vk);
9341                                            }
9342                                            dot = vaddvq_f32(acc);
9343                                            for d in (neon_chunks * 4)..dh {
9344                                                dot += q_data[q_off + d] * k_data[k_off + d];
9345                                            }
9346                                        }
9347                                        #[cfg(not(target_arch = "aarch64"))]
9348                                        for d in 0..dh {
9349                                            dot += q_data[q_off + d] * k_data[k_off + d];
9350                                        }
9351                                        scores[qi * k_s + ki] = dot * scale;
9352                                        // Inner-loop Custom mask check —
9353                                        // Causal / SlidingWindow / None
9354                                        // apply outside the loop below.
9355                                        // Skip for Bias — that mask is a
9356                                        // per-head additive tensor, not a
9357                                        // 0/1 key-padding mask.
9358                                        if matches!(mask_kind, rlx_ir::op::MaskKind::Custom)
9359                                            && !mask_data.is_empty()
9360                                            && mask_data[bi * k_s + ki] < mask_thr
9361                                        {
9362                                            scores[qi * k_s + ki] = mask_neg;
9363                                        }
9364                                    }
9365                                }
9366
9367                                if matches!(mask_kind, rlx_ir::op::MaskKind::Bias) {
9368                                    let off = (bi * nh + hi) * q_s * k_s;
9369                                    for i in 0..q_s * k_s {
9370                                        scores[i] += mask_data[off + i];
9371                                    }
9372                                }
9373                                apply_synthetic_mask(scores, q_s, k_s, *mask_kind);
9374                                crate::kernels::neon_softmax(scores, q_s, k_s);
9375
9376                                // Score@V via strided NEON accumulation (zero-copy)
9377                                for qi in 0..q_s {
9378                                    let o_off = bi * q_s * hs + qi * hs + hi * dh;
9379                                    // Zero output for this head position
9380                                    for d in 0..dh {
9381                                        out_data[o_off + d] = 0.0;
9382                                    }
9383                                    for ki in 0..k_s {
9384                                        let sc = scores[qi * k_s + ki];
9385                                        if sc > score_thr {
9386                                            let v_off = bi * k_s * vrs + ki * vrs + hi * dh;
9387                                            #[cfg(target_arch = "aarch64")]
9388                                            {
9389                                                use std::arch::aarch64::*;
9390                                                let vsc = vdupq_n_f32(sc);
9391                                                for c in 0..neon_chunks {
9392                                                    let off = c * 4;
9393                                                    let vo = vld1q_f32(
9394                                                        out_data.as_ptr().add(o_off + off),
9395                                                    );
9396                                                    let vv =
9397                                                        vld1q_f32(v_data.as_ptr().add(v_off + off));
9398                                                    vst1q_f32(
9399                                                        out_data.as_mut_ptr().add(o_off + off),
9400                                                        vfmaq_f32(vo, vsc, vv),
9401                                                    );
9402                                                }
9403                                            }
9404                                            #[cfg(not(target_arch = "aarch64"))]
9405                                            for d in 0..dh {
9406                                                out_data[o_off + d] += sc * v_data[v_off + d];
9407                                            }
9408                                        }
9409                                    }
9410                                }
9411                            }
9412                        }
9413                    } else {
9414                        // ── Parallel strided BLAS path (high throughput) ──
9415                        let total_work = b * nh;
9416                        let q_addr = q_data.as_ptr() as usize;
9417                        let k_addr = k_data.as_ptr() as usize;
9418                        let v_addr = v_data.as_ptr() as usize;
9419                        let m_addr = mask_data.as_ptr() as usize;
9420                        let o_addr = out_data.as_mut_ptr() as usize;
9421                        let sc_addr = sdpa_scores.as_mut_ptr() as usize;
9422
9423                        crate::pool::par_for(total_work, 1, &|off, cnt| {
9424                            for idx in off..off + cnt {
9425                                let bi = idx / nh;
9426                                let hi = idx % nh;
9427
9428                                let q_start = (q_addr as *const f32).add(bi * q_s * qrs + hi * dh);
9429                                let k_start = (k_addr as *const f32).add(bi * k_s * krs + hi * dh);
9430                                let v_start = (v_addr as *const f32).add(bi * k_s * vrs + hi * dh);
9431                                let o_start = (o_addr as *mut f32).add(bi * q_s * hs + hi * dh);
9432                                let sc = std::slice::from_raw_parts_mut(
9433                                    (sc_addr as *mut f32).add(idx * ss),
9434                                    ss,
9435                                );
9436
9437                                // LDA = qrs, LDB = krs (parent row strides
9438                                // when fused; hs otherwise).
9439                                crate::blas::sgemm_general(
9440                                    q_start,
9441                                    k_start,
9442                                    sc.as_mut_ptr(),
9443                                    q_s,
9444                                    k_s,
9445                                    dh,
9446                                    scale,
9447                                    0.0,
9448                                    qrs,
9449                                    krs,
9450                                    k_s,
9451                                    false,
9452                                    true,
9453                                );
9454
9455                                match mask_kind {
9456                                    rlx_ir::op::MaskKind::Custom => {
9457                                        let mask_bi = std::slice::from_raw_parts(
9458                                            (m_addr as *const f32).add(bi * k_s),
9459                                            k_s,
9460                                        );
9461                                        for ki in 0..k_s {
9462                                            if mask_bi[ki] < mask_thr {
9463                                                for qi in 0..q_s {
9464                                                    sc[qi * k_s + ki] = mask_neg;
9465                                                }
9466                                            }
9467                                        }
9468                                    }
9469                                    rlx_ir::op::MaskKind::Bias => {
9470                                        // Per-head additive bias slice.
9471                                        let bias = std::slice::from_raw_parts(
9472                                            (m_addr as *const f32).add((bi * nh + hi) * q_s * k_s),
9473                                            q_s * k_s,
9474                                        );
9475                                        for i in 0..q_s * k_s {
9476                                            sc[i] += bias[i];
9477                                        }
9478                                    }
9479                                    _ => apply_synthetic_mask(sc, q_s, k_s, *mask_kind),
9480                                }
9481
9482                                crate::kernels::neon_softmax(sc, q_s, k_s);
9483
9484                                // LDB = vrs (parent row stride when
9485                                // fused; hs otherwise). LDC stays hs —
9486                                // output is its own contiguous buffer.
9487                                crate::blas::sgemm_general(
9488                                    sc.as_ptr(),
9489                                    v_start,
9490                                    o_start,
9491                                    q_s,
9492                                    dh,
9493                                    k_s,
9494                                    1.0,
9495                                    0.0,
9496                                    k_s,
9497                                    vrs,
9498                                    hs,
9499                                    false,
9500                                    false,
9501                                );
9502                            }
9503                        });
9504                    }
9505                }
9506            }
9507
9508            Thunk::AttentionBackward {
9509                q,
9510                k,
9511                v,
9512                dy,
9513                mask,
9514                out,
9515                batch,
9516                seq,
9517                kv_seq,
9518                heads,
9519                head_dim,
9520                mask_kind,
9521                wrt,
9522                bhsd,
9523            } => {
9524                let (b, q_s, k_s, nh, dh) = (
9525                    *batch as usize,
9526                    *seq as usize,
9527                    *kv_seq as usize,
9528                    *heads as usize,
9529                    *head_dim as usize,
9530                );
9531                unsafe {
9532                    let q_len = if *bhsd {
9533                        b * nh * q_s * dh
9534                    } else {
9535                        b * q_s * nh * dh
9536                    };
9537                    let k_len = if *bhsd {
9538                        b * nh * k_s * dh
9539                    } else {
9540                        b * k_s * nh * dh
9541                    };
9542                    let out_len = match wrt {
9543                        rlx_ir::op::AttentionBwdWrt::Key | rlx_ir::op::AttentionBwdWrt::Value => {
9544                            k_len
9545                        }
9546                        rlx_ir::op::AttentionBwdWrt::Query => q_len,
9547                    };
9548                    let q_data = sl(*q, base, q_len);
9549                    let k_data = sl(*k, base, k_len);
9550                    let v_data = sl(*v, base, k_len);
9551                    let dy_data = sl(*dy, base, q_len);
9552                    let out_data = sl_mut(*out, base, out_len);
9553                    let mask_data: &[f32] = if *mask != 0 {
9554                        let ml = match mask_kind {
9555                            rlx_ir::op::MaskKind::Custom => b * k_s,
9556                            rlx_ir::op::MaskKind::Bias => b * nh * q_s * k_s,
9557                            _ => 0,
9558                        };
9559                        sl(*mask, base, ml)
9560                    } else {
9561                        &[]
9562                    };
9563                    crate::attention_bwd::attention_backward(
9564                        *wrt, q_data, k_data, v_data, dy_data, out_data, b, nh, q_s, k_s, dh,
9565                        *mask_kind, mask_data, *bhsd,
9566                    );
9567                }
9568            }
9569
9570            Thunk::ActivationInPlace { data, len, act } => {
9571                let len = *len as usize;
9572                unsafe {
9573                    let d = sl_mut(*data, base, len);
9574                    match act {
9575                        Activation::Gelu => crate::kernels::par_gelu_inplace(d),
9576                        Activation::GeluApprox => crate::kernels::par_gelu_approx_inplace(d),
9577                        Activation::Silu => crate::kernels::par_silu_inplace(d),
9578                        Activation::Relu => {
9579                            for v in d.iter_mut() {
9580                                *v = v.max(0.0);
9581                            }
9582                        }
9583                        Activation::Sigmoid => {
9584                            for v in d.iter_mut() {
9585                                *v = 1.0 / (1.0 + (-*v).exp());
9586                            }
9587                        }
9588                        Activation::Tanh => {
9589                            for v in d.iter_mut() {
9590                                *v = v.tanh();
9591                            }
9592                        }
9593                        Activation::Exp => {
9594                            for v in d.iter_mut() {
9595                                *v = v.exp();
9596                            }
9597                        }
9598                        Activation::Log => {
9599                            for v in d.iter_mut() {
9600                                *v = v.ln();
9601                            }
9602                        }
9603                        Activation::Sqrt => {
9604                            for v in d.iter_mut() {
9605                                *v = v.sqrt();
9606                            }
9607                        }
9608                        Activation::Rsqrt => {
9609                            for v in d.iter_mut() {
9610                                *v = 1.0 / v.sqrt();
9611                            }
9612                        }
9613                        Activation::Neg => {
9614                            for v in d.iter_mut() {
9615                                *v = -*v;
9616                            }
9617                        }
9618                        Activation::Abs => {
9619                            for v in d.iter_mut() {
9620                                *v = v.abs();
9621                            }
9622                        }
9623                        Activation::Round => {
9624                            for v in d.iter_mut() {
9625                                *v = v.round();
9626                            }
9627                        }
9628                        Activation::Sin => {
9629                            for v in d.iter_mut() {
9630                                *v = v.sin();
9631                            }
9632                        }
9633                        Activation::Cos => {
9634                            for v in d.iter_mut() {
9635                                *v = v.cos();
9636                            }
9637                        }
9638                        Activation::Tan => {
9639                            for v in d.iter_mut() {
9640                                *v = v.tan();
9641                            }
9642                        }
9643                        Activation::Atan => {
9644                            for v in d.iter_mut() {
9645                                *v = v.atan();
9646                            }
9647                        }
9648                    }
9649                }
9650            }
9651
9652            Thunk::FusedAttnBlock {
9653                hidden,
9654                qkv_w,
9655                out_w,
9656                mask,
9657                out,
9658                qkv_b,
9659                out_b,
9660                cos,
9661                sin,
9662                cos_len,
9663                batch,
9664                seq,
9665                hs,
9666                nh,
9667                dh,
9668                has_bias,
9669                has_rope,
9670            } => {
9671                let (b, s) = (*batch as usize, *seq as usize);
9672                let (h, n_h, d_h) = (*hs as usize, *nh as usize, *dh as usize);
9673                let m = b * s;
9674                let scale = (d_h as f32).powf(-0.5);
9675                let half = d_h / 2;
9676                unsafe {
9677                    let inp = sl(*hidden, base, m * h);
9678                    let wq = sl(*qkv_w, base, h * 3 * h);
9679                    let wo = sl(*out_w, base, h * h);
9680                    let mk = sl(*mask, base, b * s);
9681                    let dst = sl_mut(*out, base, m * h);
9682
9683                    // Stack-allocated intermediates — all fit in L1 cache for small batch
9684                    let mut qkv = vec![0f32; m * 3 * h];
9685                    let mut attn_out = vec![0f32; m * h];
9686                    let mut scores_buf = vec![0f32; s * s]; // one head at a time
9687
9688                    // 1. QKV projection: [m, h] @ [h, 3h] → [m, 3h]
9689                    crate::blas::sgemm(inp, wq, &mut qkv, m, h, 3 * h);
9690                    if *has_bias {
9691                        let bias = sl(*qkv_b, base, 3 * h);
9692                        crate::blas::bias_add(&mut qkv, bias, m, 3 * h);
9693                    }
9694
9695                    // 2. Multi-head SDPA (Q/K/V are views into qkv at offsets 0, h, 2h)
9696                    //    Process heads sequentially with inline RoPE — zero copy.
9697                    #[cfg(target_arch = "aarch64")]
9698                    let neon_chunks = d_h / 4;
9699                    #[cfg(target_arch = "aarch64")]
9700                    let _rope_chunks = half / 4;
9701
9702                    for bi in 0..b {
9703                        for hi in 0..n_h {
9704                            // For each (query_pos, key_pos): compute Q@K^T with inline RoPE
9705                            for qi in 0..s {
9706                                let q_base = bi * s * 3 * h + qi * 3 * h + hi * d_h;
9707                                for ki in 0..s {
9708                                    let k_base = bi * s * 3 * h + ki * 3 * h + h + hi * d_h;
9709                                    let mut dot = 0f32;
9710
9711                                    if *has_rope {
9712                                        // Apply RoPE inline during dot product
9713                                        let q_cos = qi * half;
9714                                        let k_cos = ki * half;
9715                                        let cos_tab = sl(*cos, base, *cos_len as usize);
9716                                        let sin_tab = sl(*sin, base, *cos_len as usize);
9717                                        // First half: (q1*c - q2*s) * (k1*c - k2*s)
9718                                        // Second half: (q2*c + q1*s) * (k2*c + k1*s)
9719                                        for i in 0..half {
9720                                            let q1 = qkv[q_base + i];
9721                                            let q2 = qkv[q_base + half + i];
9722                                            let k1 = qkv[k_base + i];
9723                                            let k2 = qkv[k_base + half + i];
9724                                            let c_q = cos_tab[q_cos + i];
9725                                            let s_q = sin_tab[q_cos + i];
9726                                            let c_k = cos_tab[k_cos + i];
9727                                            let s_k = sin_tab[k_cos + i];
9728                                            let qr1 = q1 * c_q - q2 * s_q;
9729                                            let kr1 = k1 * c_k - k2 * s_k;
9730                                            let qr2 = q2 * c_q + q1 * s_q;
9731                                            let kr2 = k2 * c_k + k1 * s_k;
9732                                            dot += qr1 * kr1 + qr2 * kr2;
9733                                        }
9734                                    } else {
9735                                        // Standard dot product
9736                                        #[cfg(target_arch = "aarch64")]
9737                                        {
9738                                            use std::arch::aarch64::*;
9739                                            let mut acc = vdupq_n_f32(0.0);
9740                                            for c in 0..neon_chunks {
9741                                                let vq =
9742                                                    vld1q_f32(qkv.as_ptr().add(q_base + c * 4));
9743                                                let vk =
9744                                                    vld1q_f32(qkv.as_ptr().add(k_base + c * 4));
9745                                                acc = vfmaq_f32(acc, vq, vk);
9746                                            }
9747                                            dot = vaddvq_f32(acc);
9748                                            for d in (neon_chunks * 4)..d_h {
9749                                                dot += qkv[q_base + d] * qkv[k_base + d];
9750                                            }
9751                                        }
9752                                        #[cfg(not(target_arch = "aarch64"))]
9753                                        for d in 0..d_h {
9754                                            dot += qkv[q_base + d] * qkv[k_base + d];
9755                                        }
9756                                    }
9757
9758                                    scores_buf[qi * s + ki] = dot * scale;
9759                                    if mk[bi * s + ki] < mask_thr {
9760                                        scores_buf[qi * s + ki] = mask_neg;
9761                                    }
9762                                }
9763                            }
9764
9765                            // Softmax
9766                            crate::kernels::neon_softmax(&mut scores_buf[..s * s], s, s);
9767
9768                            // Score @ V accumulation (V at offset 2h in QKV)
9769                            for qi in 0..s {
9770                                let o_base = bi * s * h + qi * h + hi * d_h;
9771                                for d in 0..d_h {
9772                                    attn_out[o_base + d] = 0.0;
9773                                }
9774                                for ki in 0..s {
9775                                    let sc = scores_buf[qi * s + ki];
9776                                    if sc > score_thr {
9777                                        let v_base = bi * s * 3 * h + ki * 3 * h + 2 * h + hi * d_h;
9778                                        #[cfg(target_arch = "aarch64")]
9779                                        {
9780                                            use std::arch::aarch64::*;
9781                                            let vsc = vdupq_n_f32(sc);
9782                                            for c in 0..neon_chunks {
9783                                                let off = c * 4;
9784                                                let vo =
9785                                                    vld1q_f32(attn_out.as_ptr().add(o_base + off));
9786                                                let vv = vld1q_f32(qkv.as_ptr().add(v_base + off));
9787                                                vst1q_f32(
9788                                                    attn_out.as_mut_ptr().add(o_base + off),
9789                                                    vfmaq_f32(vo, vsc, vv),
9790                                                );
9791                                            }
9792                                        }
9793                                        #[cfg(not(target_arch = "aarch64"))]
9794                                        for d in 0..d_h {
9795                                            attn_out[o_base + d] += sc * qkv[v_base + d];
9796                                        }
9797                                    }
9798                                }
9799                            }
9800                        }
9801                    }
9802
9803                    // 3. Output projection: [m, h] @ [h, h] → dst
9804                    crate::blas::sgemm(&attn_out, wo, dst, m, h, h);
9805                    if *has_bias {
9806                        let bias = sl(*out_b, base, h);
9807                        crate::blas::bias_add(dst, bias, m, h);
9808                    }
9809                }
9810            }
9811
9812            Thunk::Rope {
9813                src,
9814                cos,
9815                sin,
9816                dst,
9817                batch,
9818                seq,
9819                hidden,
9820                head_dim,
9821                n_rot,
9822                cos_len,
9823                src_row_stride,
9824            } => {
9825                let (b, s, hs, dh, nr) = (
9826                    *batch as usize,
9827                    *seq as usize,
9828                    *hidden as usize,
9829                    *head_dim as usize,
9830                    *n_rot as usize,
9831                );
9832                let tab_half = dh / 2;
9833                let rot_half = nr / 2;
9834                let nh = hs / dh;
9835                let cl = *cos_len as usize;
9836                let src_rs = *src_row_stride as usize;
9837                unsafe {
9838                    let x = sl(*src, base, b * s * src_rs);
9839                    let cos_tab = sl(*cos, base, cl);
9840                    let sin_tab = sl(*sin, base, cl);
9841                    let out = sl_mut(*dst, base, b * s * hs);
9842
9843                    let total = b * s;
9844                    let x_ptr = x.as_ptr() as usize;
9845                    let o_ptr = out.as_mut_ptr() as usize;
9846                    let c_ptr = cos_tab.as_ptr() as usize;
9847                    let s_ptr = sin_tab.as_ptr() as usize;
9848
9849                    crate::pool::par_for(total, 4, &|off, cnt| {
9850                        for idx in off..off + cnt {
9851                            let bi = idx / s;
9852                            let si = idx % s;
9853                            let tab_off = si * tab_half;
9854
9855                            for hi in 0..nh {
9856                                let src_base = bi * s * src_rs + si * src_rs + hi * dh;
9857                                let dst_base = bi * s * hs + si * hs + hi * dh;
9858                                let xp = (x_ptr as *const f32).add(src_base);
9859                                let op = (o_ptr as *mut f32).add(dst_base);
9860                                let cp = (c_ptr as *const f32).add(tab_off);
9861                                let sp = (s_ptr as *const f32).add(tab_off);
9862
9863                                for i in 0..rot_half {
9864                                    let x1 = *xp.add(i);
9865                                    let x2 = *xp.add(rot_half + i);
9866                                    let cv = *cp.add(i);
9867                                    let sv = *sp.add(i);
9868                                    *op.add(i) = x1 * cv - x2 * sv;
9869                                    *op.add(rot_half + i) = x2 * cv + x1 * sv;
9870                                }
9871                                for j in nr..dh {
9872                                    *op.add(j) = *xp.add(j);
9873                                }
9874                            }
9875                        }
9876                    });
9877                }
9878            }
9879            Thunk::FusedBertLayer {
9880                hidden,
9881                qkv_w,
9882                qkv_b,
9883                out_w,
9884                out_b,
9885                mask,
9886                ln1_g,
9887                ln1_b,
9888                eps1,
9889                fc1_w,
9890                fc1_b,
9891                fc2_w,
9892                fc2_b,
9893                ln2_g,
9894                ln2_b,
9895                eps2,
9896                out,
9897                batch,
9898                seq,
9899                hs,
9900                nh,
9901                dh,
9902                int_dim,
9903            } => {
9904                let (b, s, h, n_h, d_h) = (
9905                    *batch as usize,
9906                    *seq as usize,
9907                    *hs as usize,
9908                    *nh as usize,
9909                    *dh as usize,
9910                );
9911                let m = b * s;
9912                let id = *int_dim as usize;
9913                let scale = (d_h as f32).powf(-0.5);
9914                let _half = d_h / 2;
9915                #[cfg(target_arch = "aarch64")]
9916                let neon_chunks = d_h / 4;
9917                unsafe {
9918                    let inp = sl(*hidden, base, m * h);
9919                    let dst = sl_mut(*out, base, m * h);
9920                    let mk = sl(*mask, base, b * s);
9921
9922                    // Pre-allocated buffers (zero malloc per layer — allocated once before thunk loop)
9923                    let qkv = std::slice::from_raw_parts_mut(fl_qkv.as_mut_ptr(), m * 3 * h);
9924                    let attn = std::slice::from_raw_parts_mut(fl_attn.as_mut_ptr(), m * h);
9925                    let res = std::slice::from_raw_parts_mut(fl_res.as_mut_ptr(), m * h);
9926                    let normed = std::slice::from_raw_parts_mut(fl_normed.as_mut_ptr(), m * h);
9927                    let ffn = std::slice::from_raw_parts_mut(fl_ffn.as_mut_ptr(), m * id);
9928                    let sc = std::slice::from_raw_parts_mut(fl_sc.as_mut_ptr(), s * s);
9929
9930                    // QKV (parallelized across cores — multiple AMX coprocessors)
9931                    crate::blas::par_sgemm_bias(
9932                        inp,
9933                        sl(*qkv_w, base, h * 3 * h),
9934                        sl(*qkv_b, base, 3 * h),
9935                        qkv,
9936                        m,
9937                        h,
9938                        3 * h,
9939                    );
9940
9941                    // SDPA per head (sequential NEON, inline — zero overhead)
9942                    for bi in 0..b {
9943                        for hi in 0..n_h {
9944                            for qi in 0..s {
9945                                for ki in 0..s {
9946                                    let q_base = bi * s * 3 * h + qi * 3 * h + hi * d_h;
9947                                    let k_base = bi * s * 3 * h + ki * 3 * h + h + hi * d_h;
9948                                    #[cfg(target_arch = "aarch64")]
9949                                    let dot;
9950                                    #[cfg(not(target_arch = "aarch64"))]
9951                                    let mut dot = 0f32;
9952                                    #[cfg(target_arch = "aarch64")]
9953                                    {
9954                                        use std::arch::aarch64::*;
9955                                        let mut acc = vdupq_n_f32(0.0);
9956                                        for c in 0..neon_chunks {
9957                                            acc = vfmaq_f32(
9958                                                acc,
9959                                                vld1q_f32(qkv.as_ptr().add(q_base + c * 4)),
9960                                                vld1q_f32(qkv.as_ptr().add(k_base + c * 4)),
9961                                            );
9962                                        }
9963                                        dot = vaddvq_f32(acc);
9964                                    }
9965                                    #[cfg(not(target_arch = "aarch64"))]
9966                                    for d in 0..d_h {
9967                                        dot += qkv[q_base + d] * qkv[k_base + d];
9968                                    }
9969                                    sc[qi * s + ki] = dot * scale;
9970                                    if mk[bi * s + ki] < mask_thr {
9971                                        sc[qi * s + ki] = mask_neg;
9972                                    }
9973                                }
9974                            }
9975                            crate::kernels::neon_softmax(&mut sc[..s * s], s, s);
9976                            for qi in 0..s {
9977                                let o = bi * s * h + qi * h + hi * d_h;
9978                                for d in 0..d_h {
9979                                    attn[o + d] = 0.0;
9980                                }
9981                                for ki in 0..s {
9982                                    let w = sc[qi * s + ki];
9983                                    if w > score_thr {
9984                                        let v = bi * s * 3 * h + ki * 3 * h + 2 * h + hi * d_h;
9985                                        #[cfg(target_arch = "aarch64")]
9986                                        {
9987                                            use std::arch::aarch64::*;
9988                                            let vw = vdupq_n_f32(w);
9989                                            for c in 0..neon_chunks {
9990                                                let off = c * 4;
9991                                                vst1q_f32(
9992                                                    attn.as_mut_ptr().add(o + off),
9993                                                    vfmaq_f32(
9994                                                        vld1q_f32(attn.as_ptr().add(o + off)),
9995                                                        vw,
9996                                                        vld1q_f32(qkv.as_ptr().add(v + off)),
9997                                                    ),
9998                                                );
9999                                            }
10000                                        }
10001                                        #[cfg(not(target_arch = "aarch64"))]
10002                                        for d in 0..d_h {
10003                                            attn[o + d] += w * qkv[v + d];
10004                                        }
10005                                    }
10006                                }
10007                            }
10008                        }
10009                    }
10010
10011                    // Out proj (sgemm + bias fused) + residual add with NEON
10012                    crate::blas::sgemm_bias(
10013                        attn,
10014                        sl(*out_w, base, h * h),
10015                        sl(*out_b, base, h),
10016                        res,
10017                        m,
10018                        h,
10019                        h,
10020                    );
10021                    #[cfg(target_arch = "aarch64")]
10022                    {
10023                        use std::arch::aarch64::*;
10024                        let chunks_h = (m * h) / 4;
10025                        for c in 0..chunks_h {
10026                            let off = c * 4;
10027                            vst1q_f32(
10028                                res.as_mut_ptr().add(off),
10029                                vaddq_f32(
10030                                    vld1q_f32(res.as_ptr().add(off)),
10031                                    vld1q_f32(inp.as_ptr().add(off)),
10032                                ),
10033                            );
10034                        }
10035                        for i in (chunks_h * 4)..(m * h) {
10036                            res[i] += inp[i];
10037                        }
10038                    }
10039                    #[cfg(not(target_arch = "aarch64"))]
10040                    for i in 0..m * h {
10041                        res[i] += inp[i];
10042                    }
10043
10044                    // LN1 (fused residual already done above — just normalize)
10045                    let g1 = sl(*ln1_g, base, h);
10046                    let b1 = sl(*ln1_b, base, h);
10047                    for r in 0..m {
10048                        crate::kernels::layer_norm_row(
10049                            &res[r * h..(r + 1) * h],
10050                            g1,
10051                            b1,
10052                            &mut normed[r * h..(r + 1) * h],
10053                            h,
10054                            *eps1,
10055                        );
10056                    }
10057
10058                    // FFN: fc1 (parallel across cores) + GELU
10059                    crate::blas::par_sgemm_bias(
10060                        normed,
10061                        sl(*fc1_w, base, h * id),
10062                        sl(*fc1_b, base, id),
10063                        ffn,
10064                        m,
10065                        h,
10066                        id,
10067                    );
10068                    crate::kernels::par_gelu_inplace(ffn);
10069
10070                    // fc2 + bias (parallel across cores) + residual with NEON
10071                    crate::blas::par_sgemm_bias(
10072                        ffn,
10073                        sl(*fc2_w, base, id * h),
10074                        sl(*fc2_b, base, h),
10075                        res,
10076                        m,
10077                        id,
10078                        h,
10079                    );
10080                    #[cfg(target_arch = "aarch64")]
10081                    {
10082                        use std::arch::aarch64::*;
10083                        let chunks_h = (m * h) / 4;
10084                        for c in 0..chunks_h {
10085                            let off = c * 4;
10086                            vst1q_f32(
10087                                res.as_mut_ptr().add(off),
10088                                vaddq_f32(
10089                                    vld1q_f32(res.as_ptr().add(off)),
10090                                    vld1q_f32(normed.as_ptr().add(off)),
10091                                ),
10092                            );
10093                        }
10094                        for i in (chunks_h * 4)..(m * h) {
10095                            res[i] += normed[i];
10096                        }
10097                    }
10098                    #[cfg(not(target_arch = "aarch64"))]
10099                    for i in 0..m * h {
10100                        res[i] += normed[i];
10101                    }
10102
10103                    // LN2 → output
10104                    let g2 = sl(*ln2_g, base, h);
10105                    let b2 = sl(*ln2_b, base, h);
10106                    for r in 0..m {
10107                        crate::kernels::layer_norm_row(
10108                            &res[r * h..(r + 1) * h],
10109                            g2,
10110                            b2,
10111                            &mut dst[r * h..(r + 1) * h],
10112                            h,
10113                            *eps2,
10114                        );
10115                    }
10116                }
10117            }
10118
10119            Thunk::FusedNomicLayer {
10120                hidden,
10121                qkv_w,
10122                out_w,
10123                mask,
10124                cos,
10125                sin,
10126                cos_len,
10127                ln1_g,
10128                ln1_b,
10129                eps1,
10130                fc11_w,
10131                fc12_w: _,
10132                fc2_w,
10133                ln2_g,
10134                ln2_b,
10135                eps2,
10136                out,
10137                batch,
10138                seq,
10139                hs,
10140                nh,
10141                dh,
10142                int_dim,
10143            } => {
10144                let (b, s, h, n_h, d_h) = (
10145                    *batch as usize,
10146                    *seq as usize,
10147                    *hs as usize,
10148                    *nh as usize,
10149                    *dh as usize,
10150                );
10151                let m = b * s;
10152                let id = *int_dim as usize;
10153                let scale = (d_h as f32).powf(-0.5);
10154                let half_dh = d_h / 2;
10155                #[cfg(target_arch = "aarch64")]
10156                let neon_chunks = d_h / 4;
10157                unsafe {
10158                    let inp = sl(*hidden, base, m * h);
10159                    let dst = sl_mut(*out, base, m * h);
10160                    let mk = sl(*mask, base, b * s);
10161                    let cos_tab = sl(*cos, base, *cos_len as usize);
10162                    let sin_tab = sl(*sin, base, *cos_len as usize);
10163                    // fc11_w is the fused [h, 2*int_dim] weight (fc11 || fc12 concatenated)
10164                    let fused_fc_w = sl(*fc11_w, base, h * 2 * id);
10165
10166                    let mut qkv = vec![0f32; m * 3 * h];
10167                    let mut attn = vec![0f32; m * h];
10168                    let mut res = vec![0f32; m * h];
10169                    let mut normed = vec![0f32; m * h];
10170                    let mut ffn_concat = vec![0f32; m * 2 * id]; // fc11||fc12 output
10171                    let mut sc = vec![0f32; s * s];
10172
10173                    // QKV (no bias)
10174                    crate::blas::sgemm(inp, sl(*qkv_w, base, h * 3 * h), &mut qkv, m, h, 3 * h);
10175
10176                    // SDPA with inline RoPE
10177                    for bi in 0..b {
10178                        for hi in 0..n_h {
10179                            for qi in 0..s {
10180                                for ki in 0..s {
10181                                    let q_base = bi * s * 3 * h + qi * 3 * h + hi * d_h;
10182                                    let k_base = bi * s * 3 * h + ki * 3 * h + h + hi * d_h;
10183                                    let mut dot = 0f32;
10184                                    for i in 0..half_dh {
10185                                        let q1 = qkv[q_base + i];
10186                                        let q2 = qkv[q_base + half_dh + i];
10187                                        let k1 = qkv[k_base + i];
10188                                        let k2 = qkv[k_base + half_dh + i];
10189                                        let cq = cos_tab[qi * half_dh + i];
10190                                        let sq = sin_tab[qi * half_dh + i];
10191                                        let ck = cos_tab[ki * half_dh + i];
10192                                        let sk = sin_tab[ki * half_dh + i];
10193                                        dot += (q1 * cq - q2 * sq) * (k1 * ck - k2 * sk)
10194                                            + (q2 * cq + q1 * sq) * (k2 * ck + k1 * sk);
10195                                    }
10196                                    sc[qi * s + ki] = dot * scale;
10197                                    if mk[bi * s + ki] < mask_thr {
10198                                        sc[qi * s + ki] = mask_neg;
10199                                    }
10200                                }
10201                            }
10202                            crate::kernels::neon_softmax(&mut sc[..s * s], s, s);
10203                            for qi in 0..s {
10204                                let o = bi * s * h + qi * h + hi * d_h;
10205                                for d in 0..d_h {
10206                                    attn[o + d] = 0.0;
10207                                }
10208                                for ki in 0..s {
10209                                    let w = sc[qi * s + ki];
10210                                    if w > score_thr {
10211                                        let v = bi * s * 3 * h + ki * 3 * h + 2 * h + hi * d_h;
10212                                        #[cfg(target_arch = "aarch64")]
10213                                        {
10214                                            use std::arch::aarch64::*;
10215                                            let vw = vdupq_n_f32(w);
10216                                            for c in 0..neon_chunks {
10217                                                let off = c * 4;
10218                                                vst1q_f32(
10219                                                    attn.as_mut_ptr().add(o + off),
10220                                                    vfmaq_f32(
10221                                                        vld1q_f32(attn.as_ptr().add(o + off)),
10222                                                        vw,
10223                                                        vld1q_f32(qkv.as_ptr().add(v + off)),
10224                                                    ),
10225                                                );
10226                                            }
10227                                        }
10228                                        #[cfg(not(target_arch = "aarch64"))]
10229                                        for d in 0..d_h {
10230                                            attn[o + d] += w * qkv[v + d];
10231                                        }
10232                                    }
10233                                }
10234                            }
10235                        }
10236                    }
10237
10238                    // Out proj (no bias) + residual
10239                    crate::blas::sgemm(&attn, sl(*out_w, base, h * h), &mut res, m, h, h);
10240                    for i in 0..m * h {
10241                        res[i] += inp[i];
10242                    }
10243
10244                    // LN1
10245                    let g1 = sl(*ln1_g, base, h);
10246                    let b1 = sl(*ln1_b, base, h);
10247                    for r in 0..m {
10248                        crate::kernels::layer_norm_row(
10249                            &res[r * h..(r + 1) * h],
10250                            g1,
10251                            b1,
10252                            &mut normed[r * h..(r + 1) * h],
10253                            h,
10254                            *eps1,
10255                        );
10256                    }
10257
10258                    // SwiGLU: fused fc11+fc12 sgemm, then split, silu, mul
10259                    crate::blas::sgemm(&normed, fused_fc_w, &mut ffn_concat, m, h, 2 * id);
10260                    // Split: first id cols = fc11 (up), second id cols = fc12 (gate)
10261                    // SiLU on gate, then multiply up * gate → store in up region
10262                    for row in 0..m {
10263                        let bo = row * 2 * id;
10264                        // SiLU in-place on gate portion
10265                        for j in 0..id {
10266                            let x = ffn_concat[bo + id + j];
10267                            ffn_concat[bo + id + j] = x / (1.0 + (-x).exp());
10268                        }
10269                        // Multiply: up[j] *= gate[j]
10270                        for j in 0..id {
10271                            ffn_concat[bo + j] *= ffn_concat[bo + id + j];
10272                        }
10273                    }
10274
10275                    // fc2 (no bias) + residual  — read from first id cols of ffn_concat
10276                    // Need contiguous [m, id] for sgemm. Copy or use strided sgemm.
10277                    // The up*gate result is at ffn_concat[row * 2*id .. row * 2*id + id]
10278                    // Stride = 2*id. Use sgemm_general with lda = 2*id.
10279                    crate::blas::sgemm_general(
10280                        ffn_concat.as_ptr(),
10281                        sl(*fc2_w, base, id * h).as_ptr(),
10282                        res.as_mut_ptr(),
10283                        m,
10284                        h,
10285                        id,
10286                        1.0,
10287                        0.0,
10288                        2 * id,
10289                        h,
10290                        h,
10291                        false,
10292                        false,
10293                    );
10294                    for i in 0..m * h {
10295                        res[i] += normed[i];
10296                    }
10297
10298                    // LN2 → output
10299                    let g2 = sl(*ln2_g, base, h);
10300                    let b2 = sl(*ln2_b, base, h);
10301                    for r in 0..m {
10302                        crate::kernels::layer_norm_row(
10303                            &res[r * h..(r + 1) * h],
10304                            g2,
10305                            b2,
10306                            &mut dst[r * h..(r + 1) * h],
10307                            h,
10308                            *eps2,
10309                        );
10310                    }
10311                }
10312            }
10313
10314            Thunk::FusedSwiGLU {
10315                src,
10316                dst,
10317                n_half,
10318                total,
10319                gate_first,
10320            } => {
10321                let n = *n_half as usize;
10322                let t = *total as usize;
10323                let outer = t / n;
10324                let in_total = outer * 2 * n;
10325                let gate_first = *gate_first;
10326                unsafe {
10327                    let inp = sl(*src, base, in_total);
10328                    let out = sl_mut(*dst, base, t);
10329                    for o in 0..outer {
10330                        let in_row = &inp[o * 2 * n..(o + 1) * 2 * n];
10331                        let out_row = &mut out[o * n..(o + 1) * n];
10332                        for i in 0..n {
10333                            let (up, gate) = if gate_first {
10334                                (in_row[n + i], in_row[i])
10335                            } else {
10336                                (in_row[i], in_row[n + i])
10337                            };
10338                            out_row[i] = up * (gate / (1.0 + (-gate).exp()));
10339                        }
10340                    }
10341                }
10342            }
10343
10344            Thunk::Concat {
10345                dst,
10346                outer,
10347                inner,
10348                total_axis,
10349                inputs,
10350            } => {
10351                let outer = *outer as usize;
10352                let inner = *inner as usize;
10353                let total_axis = *total_axis as usize;
10354                let row_stride = total_axis * inner;
10355                let out_total = outer * row_stride;
10356                unsafe {
10357                    let out = sl_mut(*dst, base, out_total);
10358                    let mut cum: usize = 0;
10359                    for (src_off, in_axis) in inputs {
10360                        let in_axis = *in_axis as usize;
10361                        let copy_per_row = in_axis * inner;
10362                        let dst_col_off = cum * inner;
10363                        let in_total = outer * copy_per_row;
10364                        let inp = sl(*src_off, base, in_total);
10365                        for o in 0..outer {
10366                            let dst_row_start = o * row_stride + dst_col_off;
10367                            let src_row_start = o * copy_per_row;
10368                            out[dst_row_start..dst_row_start + copy_per_row]
10369                                .copy_from_slice(&inp[src_row_start..src_row_start + copy_per_row]);
10370                        }
10371                        cum += in_axis;
10372                    }
10373                }
10374            }
10375
10376            Thunk::ConcatF64 {
10377                dst,
10378                outer,
10379                inner,
10380                total_axis,
10381                inputs,
10382            } => {
10383                let outer = *outer as usize;
10384                let inner = *inner as usize;
10385                let total_axis = *total_axis as usize;
10386                let row_stride = total_axis * inner;
10387                let out_total = outer * row_stride;
10388                unsafe {
10389                    let out = sl_mut_f64(*dst, base, out_total);
10390                    let mut cum: usize = 0;
10391                    for (src_off, in_axis) in inputs {
10392                        let in_axis = *in_axis as usize;
10393                        let copy_per_row = in_axis * inner;
10394                        let dst_col_off = cum * inner;
10395                        let in_total = outer * copy_per_row;
10396                        let inp = sl_f64(*src_off, base, in_total);
10397                        for o in 0..outer {
10398                            let dst_row_start = o * row_stride + dst_col_off;
10399                            let src_row_start = o * copy_per_row;
10400                            out[dst_row_start..dst_row_start + copy_per_row]
10401                                .copy_from_slice(&inp[src_row_start..src_row_start + copy_per_row]);
10402                        }
10403                        cum += in_axis;
10404                    }
10405                }
10406            }
10407
10408            Thunk::Compare {
10409                lhs,
10410                rhs,
10411                dst,
10412                len,
10413                op,
10414            } => {
10415                let len = *len as usize;
10416                unsafe {
10417                    let l = sl(*lhs, base, len);
10418                    let r = sl(*rhs, base, len);
10419                    let o = sl_mut(*dst, base, len);
10420                    for i in 0..len {
10421                        o[i] = match op {
10422                            CmpOp::Eq => (l[i] == r[i]) as u32 as f32,
10423                            CmpOp::Ne => (l[i] != r[i]) as u32 as f32,
10424                            CmpOp::Lt => (l[i] < r[i]) as u32 as f32,
10425                            CmpOp::Le => (l[i] <= r[i]) as u32 as f32,
10426                            CmpOp::Gt => (l[i] > r[i]) as u32 as f32,
10427                            CmpOp::Ge => (l[i] >= r[i]) as u32 as f32,
10428                        };
10429                    }
10430                }
10431            }
10432
10433            Thunk::Where {
10434                cond,
10435                on_true,
10436                on_false,
10437                dst,
10438                len,
10439            } => {
10440                let len = *len as usize;
10441                unsafe {
10442                    let c = sl(*cond, base, len);
10443                    let t = sl(*on_true, base, len);
10444                    let e = sl(*on_false, base, len);
10445                    let o = sl_mut(*dst, base, len);
10446                    for i in 0..len {
10447                        // Treat cond as boolean: any non-zero → true.
10448                        o[i] = if c[i] != 0.0 { t[i] } else { e[i] };
10449                    }
10450                }
10451            }
10452
10453            Thunk::ScatterAdd {
10454                updates,
10455                indices,
10456                dst,
10457                num_updates,
10458                out_dim,
10459                trailing,
10460            } => {
10461                let num_updates = *num_updates as usize;
10462                let out_dim = *out_dim as usize;
10463                let trailing = *trailing as usize;
10464                unsafe {
10465                    let upd = sl(*updates, base, num_updates * trailing);
10466                    let ids = sl(*indices, base, num_updates);
10467                    let out = sl_mut(*dst, base, out_dim * trailing);
10468                    // Zero the output first — semantics are accumulate-into-zeros.
10469                    for v in out.iter_mut() {
10470                        *v = 0.0;
10471                    }
10472                    for i in 0..num_updates {
10473                        let row = ids[i] as usize;
10474                        debug_assert!(row < out_dim, "ScatterAdd index out of range");
10475                        let src_off = i * trailing;
10476                        let dst_off = row * trailing;
10477                        for j in 0..trailing {
10478                            out[dst_off + j] += upd[src_off + j];
10479                        }
10480                    }
10481                }
10482            }
10483
10484            Thunk::GroupedMatMul {
10485                input,
10486                weight,
10487                expert_idx,
10488                dst,
10489                m,
10490                k_dim,
10491                n,
10492                num_experts,
10493            } => {
10494                let m = *m as usize;
10495                let k_dim = *k_dim as usize;
10496                let n = *n as usize;
10497                let num_experts = *num_experts as usize;
10498                unsafe {
10499                    let inp = sl(*input, base, m * k_dim);
10500                    let wt = sl(*weight, base, num_experts * k_dim * n);
10501                    let ids = sl(*expert_idx, base, m);
10502                    let out = sl_mut(*dst, base, m * n);
10503
10504                    // Counting-sort tokens by their assigned expert.
10505                    // counts[e] = how many tokens routed to expert e.
10506                    let mut counts = vec![0usize; num_experts];
10507                    for i in 0..m {
10508                        let e = ids[i] as usize;
10509                        debug_assert!(
10510                            e < num_experts,
10511                            "expert_idx out of range: {e} >= {num_experts}"
10512                        );
10513                        counts[e] += 1;
10514                    }
10515                    // Cumulative offsets into the packed buffer.
10516                    let mut offsets = vec![0usize; num_experts + 1];
10517                    for e in 0..num_experts {
10518                        offsets[e + 1] = offsets[e] + counts[e];
10519                    }
10520                    // Pack: each expert's rows land contiguously in `packed_in`.
10521                    // `original_pos[packed_idx] = original_token_idx` for the
10522                    // unpermute step at the end.
10523                    let mut packed_in = vec![0f32; m * k_dim];
10524                    let mut original_pos = vec![0usize; m];
10525                    let mut write_idx = vec![0usize; num_experts];
10526                    for i in 0..m {
10527                        let e = ids[i] as usize;
10528                        let dst_row = offsets[e] + write_idx[e];
10529                        packed_in[dst_row * k_dim..(dst_row + 1) * k_dim]
10530                            .copy_from_slice(&inp[i * k_dim..(i + 1) * k_dim]);
10531                        original_pos[dst_row] = i;
10532                        write_idx[e] += 1;
10533                    }
10534
10535                    // One BLAS sgemm per expert. Skip experts with no
10536                    // tokens — common at the tail when M is much smaller
10537                    // than num_experts × k.
10538                    let mut packed_out = vec![0f32; m * n];
10539                    let expert_stride = k_dim * n;
10540                    let gmm_ord = crate::moe_residency::next_gmm_ord();
10541                    let moe_layer = gmm_ord / 3;
10542                    for e in 0..num_experts {
10543                        let count = counts[e];
10544                        if count == 0 {
10545                            continue;
10546                        }
10547                        crate::moe_residency::record_expert_tokens(moe_layer, e, count);
10548                        let in_start = offsets[e];
10549                        let in_slice = &packed_in[in_start * k_dim..(in_start + count) * k_dim];
10550                        let w_slab: &[f32] =
10551                            if !crate::moe_residency::expert_on_device_for_layer(moe_layer, e) {
10552                                if let Some(ptr) =
10553                                    crate::moe_residency::host_expert_weight_ptr(gmm_ord, e)
10554                                {
10555                                    std::slice::from_raw_parts(ptr, expert_stride)
10556                                } else {
10557                                    &wt[e * expert_stride..(e + 1) * expert_stride]
10558                                }
10559                            } else {
10560                                &wt[e * expert_stride..(e + 1) * expert_stride]
10561                            };
10562                        let out_slice = &mut packed_out[in_start * n..(in_start + count) * n];
10563                        crate::blas::sgemm(in_slice, w_slab, out_slice, count, k_dim, n);
10564                    }
10565
10566                    // Unpermute back to original token order.
10567                    for packed_idx in 0..m {
10568                        let i = original_pos[packed_idx];
10569                        out[i * n..(i + 1) * n]
10570                            .copy_from_slice(&packed_out[packed_idx * n..(packed_idx + 1) * n]);
10571                    }
10572                }
10573            }
10574
10575            Thunk::DequantGroupedMatMulGguf {
10576                input,
10577                w_q,
10578                expert_idx,
10579                dst,
10580                m,
10581                k_dim,
10582                n,
10583                num_experts,
10584                scheme,
10585            } => {
10586                let m = *m as usize;
10587                let k_dim = *k_dim as usize;
10588                let n = *n as usize;
10589                let num_experts = *num_experts as usize;
10590                let block_elems = scheme.gguf_block_size() as usize;
10591                let block_bytes = scheme.gguf_block_bytes() as usize;
10592                let slab_bytes = (k_dim * n) / block_elems * block_bytes;
10593                unsafe {
10594                    let inp = sl(*input, base, m * k_dim);
10595                    let wt = std::slice::from_raw_parts(
10596                        base.add(*w_q) as *const u8,
10597                        num_experts * slab_bytes,
10598                    );
10599                    let ids = sl(*expert_idx, base, m);
10600                    let out = sl_mut(*dst, base, m * n);
10601                    crate::gguf_matmul::gguf_grouped_matmul_bt(
10602                        inp,
10603                        wt,
10604                        ids,
10605                        out,
10606                        m,
10607                        k_dim,
10608                        n,
10609                        num_experts,
10610                        *scheme,
10611                    );
10612                }
10613            }
10614
10615            Thunk::DequantMoEWeightsGguf {
10616                w_q,
10617                dst,
10618                k_dim,
10619                n,
10620                num_experts,
10621                scheme,
10622            } => {
10623                let k_dim = *k_dim as usize;
10624                let n = *n as usize;
10625                let num_experts = *num_experts as usize;
10626                let block_elems = scheme.gguf_block_size() as usize;
10627                let block_bytes = scheme.gguf_block_bytes() as usize;
10628                let slab_bytes = (k_dim * n) / block_elems * block_bytes;
10629                unsafe {
10630                    let wt = std::slice::from_raw_parts(
10631                        base.add(*w_q) as *const u8,
10632                        num_experts * slab_bytes,
10633                    );
10634                    let out = sl_mut(*dst, base, num_experts * k_dim * n);
10635                    crate::gguf_matmul::dequant_moe_weights_to_grouped_f32(
10636                        wt,
10637                        out,
10638                        num_experts,
10639                        k_dim,
10640                        n,
10641                        *scheme,
10642                    );
10643                }
10644            }
10645
10646            Thunk::TopK {
10647                src,
10648                dst,
10649                outer,
10650                axis_dim,
10651                k,
10652            } => {
10653                let outer = *outer as usize;
10654                let axis_dim = *axis_dim as usize;
10655                let k = *k as usize;
10656                unsafe {
10657                    let inp = sl(*src, base, outer * axis_dim);
10658                    let out = sl_mut(*dst, base, outer * k);
10659                    // Repeated argmax with masking. O(k * axis_dim) per row;
10660                    // good enough for small k (MoE typical k=2–8). For larger
10661                    // k a partial heap would win.
10662                    let mut row_buf: Vec<f32> = vec![0.0; axis_dim];
10663                    for o in 0..outer {
10664                        row_buf.copy_from_slice(&inp[o * axis_dim..(o + 1) * axis_dim]);
10665                        for ki in 0..k {
10666                            // Find argmax with tie-break to smaller index.
10667                            let mut best_i = 0usize;
10668                            let mut best_v = row_buf[0];
10669                            for i in 1..axis_dim {
10670                                let v = row_buf[i];
10671                                if v > best_v {
10672                                    best_v = v;
10673                                    best_i = i;
10674                                }
10675                            }
10676                            out[o * k + ki] = best_i as f32;
10677                            // Mask the chosen index so the next pass picks
10678                            // the next-largest instead.
10679                            row_buf[best_i] = f32::NEG_INFINITY;
10680                        }
10681                    }
10682                    if let Some(cap) = schedule.moe_topk_capture.as_ref() {
10683                        cap.push_topk_f32(&out[..outer * k], axis_dim);
10684                    }
10685                }
10686            }
10687
10688            Thunk::Reduce {
10689                src,
10690                dst,
10691                outer,
10692                reduced,
10693                inner,
10694                op,
10695            } => {
10696                let outer = *outer as usize;
10697                let reduced = *reduced as usize;
10698                let inner = *inner as usize;
10699                let in_total = outer * reduced * inner;
10700                let out_total = outer * inner;
10701                unsafe {
10702                    let inp = sl(*src, base, in_total);
10703                    let out = sl_mut(*dst, base, out_total);
10704                    for o in 0..outer {
10705                        for i in 0..inner {
10706                            let mut acc = match op {
10707                                ReduceOp::Max => f32::NEG_INFINITY,
10708                                ReduceOp::Min => f32::INFINITY,
10709                                ReduceOp::Prod => 1.0f32,
10710                                _ => 0.0f32, // Sum / Mean
10711                            };
10712                            // Walk the reduced axis with stride `inner`.
10713                            for r in 0..reduced {
10714                                let v = inp[o * reduced * inner + r * inner + i];
10715                                acc = match op {
10716                                    ReduceOp::Sum | ReduceOp::Mean => acc + v,
10717                                    ReduceOp::Max => acc.max(v),
10718                                    ReduceOp::Min => acc.min(v),
10719                                    ReduceOp::Prod => acc * v,
10720                                };
10721                            }
10722                            if matches!(op, ReduceOp::Mean) {
10723                                acc /= reduced as f32;
10724                            }
10725                            out[o * inner + i] = acc;
10726                        }
10727                    }
10728                }
10729            }
10730
10731            Thunk::Conv2D1x1 {
10732                src,
10733                weight,
10734                dst,
10735                n,
10736                c_in,
10737                c_out,
10738                hw,
10739            } => {
10740                let n = *n as usize;
10741                let c_in = *c_in as usize;
10742                let c_out = *c_out as usize;
10743                let hw = *hw as usize;
10744                unsafe {
10745                    let inp = sl(*src, base, n * c_in * hw);
10746                    let wt = sl(*weight, base, c_out * c_in);
10747                    let out = sl_mut(*dst, base, n * c_out * hw);
10748                    // Per-batch sgemm: weight [c_out, c_in] @ input
10749                    // [c_in, hw] = output [c_out, hw]. The weight is
10750                    // shared across batches, so we get to dispatch
10751                    // BLAS once per N (typically 1).
10752                    for ni in 0..n {
10753                        let in_off = ni * c_in * hw;
10754                        let out_off = ni * c_out * hw;
10755                        crate::blas::sgemm(
10756                            wt,
10757                            &inp[in_off..in_off + c_in * hw],
10758                            &mut out[out_off..out_off + c_out * hw],
10759                            c_out,
10760                            c_in,
10761                            hw,
10762                        );
10763                    }
10764                }
10765            }
10766
10767            Thunk::Conv2D {
10768                src,
10769                weight,
10770                dst,
10771                n,
10772                c_in,
10773                h,
10774                w,
10775                c_out,
10776                h_out,
10777                w_out,
10778                kh,
10779                kw,
10780                sh,
10781                sw,
10782                ph,
10783                pw,
10784                dh,
10785                dw,
10786                groups,
10787            } => {
10788                let n = *n as usize;
10789                let c_in = *c_in as usize;
10790                let h = *h as usize;
10791                let w = *w as usize;
10792                let c_out = *c_out as usize;
10793                let h_out = *h_out as usize;
10794                let w_out = *w_out as usize;
10795                let kh = *kh as usize;
10796                let kw = *kw as usize;
10797                let sh = *sh as usize;
10798                let sw = *sw as usize;
10799                let ph = *ph as usize;
10800                let pw = *pw as usize;
10801                let dh = *dh as usize;
10802                let dw = *dw as usize;
10803                let groups = *groups as usize;
10804                let c_in_per_g = c_in / groups;
10805                let c_out_per_g = c_out / groups;
10806                unsafe {
10807                    let inp = sl(*src, base, n * c_in * h * w);
10808                    let wt = sl(*weight, base, c_out * c_in_per_g * kh * kw);
10809                    let out = sl_mut(*dst, base, n * c_out * h_out * w_out);
10810                    for ni in 0..n {
10811                        for co in 0..c_out {
10812                            let g = co / c_out_per_g;
10813                            let ci_start = g * c_in_per_g;
10814                            for ho in 0..h_out {
10815                                for wo in 0..w_out {
10816                                    let mut acc = 0f32;
10817                                    for ci_off in 0..c_in_per_g {
10818                                        let ci = ci_start + ci_off;
10819                                        let in_chan = ((ni * c_in) + ci) * h * w;
10820                                        let wt_chan = ((co * c_in_per_g) + ci_off) * kh * kw;
10821                                        for ki in 0..kh {
10822                                            for kj in 0..kw {
10823                                                let hi = ho * sh + ki * dh;
10824                                                let wi = wo * sw + kj * dw;
10825                                                if hi < ph || wi < pw {
10826                                                    continue;
10827                                                }
10828                                                let hi = hi - ph;
10829                                                let wi = wi - pw;
10830                                                if hi >= h || wi >= w {
10831                                                    continue;
10832                                                }
10833                                                acc += inp[in_chan + hi * w + wi]
10834                                                    * wt[wt_chan + ki * kw + kj];
10835                                            }
10836                                        }
10837                                    }
10838                                    out[((ni * c_out) + co) * h_out * w_out + ho * w_out + wo] =
10839                                        acc;
10840                                }
10841                            }
10842                        }
10843                    }
10844                }
10845            }
10846
10847            Thunk::Pool2D {
10848                src,
10849                dst,
10850                n,
10851                c,
10852                h,
10853                w,
10854                h_out,
10855                w_out,
10856                kh,
10857                kw,
10858                sh,
10859                sw,
10860                ph,
10861                pw,
10862                kind,
10863            } => {
10864                let n = *n as usize;
10865                let c = *c as usize;
10866                let h = *h as usize;
10867                let w = *w as usize;
10868                let h_out = *h_out as usize;
10869                let w_out = *w_out as usize;
10870                let kh = *kh as usize;
10871                let kw = *kw as usize;
10872                let sh = *sh as usize;
10873                let sw = *sw as usize;
10874                let ph = *ph as usize;
10875                let pw = *pw as usize;
10876                let kernel_area = (kh * kw) as f32;
10877                unsafe {
10878                    let inp = sl(*src, base, n * c * h * w);
10879                    let out = sl_mut(*dst, base, n * c * h_out * w_out);
10880                    for ni in 0..n {
10881                        for ci in 0..c {
10882                            let in_chan = ni * c * h * w + ci * h * w;
10883                            let out_chan = ni * c * h_out * w_out + ci * h_out * w_out;
10884                            for ho in 0..h_out {
10885                                for wo in 0..w_out {
10886                                    let mut acc = match kind {
10887                                        ReduceOp::Max => f32::NEG_INFINITY,
10888                                        _ => 0f32, // Mean (and Sum/Min/Prod fall back here)
10889                                    };
10890                                    for ki in 0..kh {
10891                                        for kj in 0..kw {
10892                                            let hi = ho * sh + ki;
10893                                            let wi = wo * sw + kj;
10894                                            // Padded-zero region.
10895                                            if hi < ph || wi < pw {
10896                                                continue;
10897                                            }
10898                                            let hi = hi - ph;
10899                                            let wi = wi - pw;
10900                                            if hi >= h || wi >= w {
10901                                                continue;
10902                                            }
10903                                            let v = inp[in_chan + hi * w + wi];
10904                                            match kind {
10905                                                ReduceOp::Max => acc = acc.max(v),
10906                                                _ => acc += v,
10907                                            }
10908                                        }
10909                                    }
10910                                    if matches!(kind, ReduceOp::Mean) {
10911                                        acc /= kernel_area;
10912                                    }
10913                                    out[out_chan + ho * w_out + wo] = acc;
10914                                }
10915                            }
10916                        }
10917                    }
10918                }
10919            }
10920
10921            Thunk::ReluBackward { x, dy, dx, len } => {
10922                let len = *len as usize;
10923                unsafe {
10924                    let xs = sl(*x, base, len);
10925                    let dys = sl(*dy, base, len);
10926                    let out = sl_mut(*dx, base, len);
10927                    for i in 0..len {
10928                        out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
10929                    }
10930                }
10931            }
10932
10933            Thunk::ReluBackwardF64 { x, dy, dx, len } => {
10934                let len = *len as usize;
10935                unsafe {
10936                    let xs = sl_f64(*x, base, len);
10937                    let dys = sl_f64(*dy, base, len);
10938                    let out = sl_mut_f64(*dx, base, len);
10939                    for i in 0..len {
10940                        out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
10941                    }
10942                }
10943            }
10944
10945            Thunk::QMatMul {
10946                x,
10947                w,
10948                bias,
10949                out,
10950                m,
10951                k,
10952                n,
10953                x_zp,
10954                w_zp,
10955                out_zp,
10956                mult,
10957            } => {
10958                let m = *m as usize;
10959                let k = *k as usize;
10960                let n = *n as usize;
10961                unsafe {
10962                    let x_ptr = base.add(*x) as *const i8;
10963                    let w_ptr = base.add(*w) as *const i8;
10964                    let bias_ptr = base.add(*bias) as *const i32;
10965                    let out_ptr = base.add(*out) as *mut i8;
10966                    for mi in 0..m {
10967                        for ni in 0..n {
10968                            let mut acc: i32 = *bias_ptr.add(ni);
10969                            for ki in 0..k {
10970                                let xv = *x_ptr.add(mi * k + ki) as i32 - *x_zp;
10971                                let wv = *w_ptr.add(ki * n + ni) as i32 - *w_zp;
10972                                acc += xv * wv;
10973                            }
10974                            // Requantize: round(acc · mult) + out_zp,
10975                            // clamped to i8.
10976                            let r = (acc as f32 * *mult).round() as i32 + *out_zp;
10977                            let r = r.clamp(-128, 127) as i8;
10978                            *out_ptr.add(mi * n + ni) = r;
10979                        }
10980                    }
10981                }
10982            }
10983
10984            Thunk::QConv2d {
10985                x,
10986                w,
10987                bias,
10988                out,
10989                n,
10990                c_in,
10991                h,
10992                w_in,
10993                c_out,
10994                h_out,
10995                w_out,
10996                kh,
10997                kw,
10998                sh,
10999                sw,
11000                ph,
11001                pw,
11002                dh,
11003                dw,
11004                groups,
11005                x_zp,
11006                w_zp,
11007                out_zp,
11008                mult,
11009            } => {
11010                let n = *n as usize;
11011                let c_in = *c_in as usize;
11012                let h = *h as usize;
11013                let w_in = *w_in as usize;
11014                let c_out = *c_out as usize;
11015                let h_out = *h_out as usize;
11016                let w_out = *w_out as usize;
11017                let kh = *kh as usize;
11018                let kw = *kw as usize;
11019                let sh = *sh as usize;
11020                let sw = *sw as usize;
11021                let ph = *ph as usize;
11022                let pw = *pw as usize;
11023                let dh = *dh as usize;
11024                let dw = *dw as usize;
11025                let groups = *groups as usize;
11026                let c_in_per_g = c_in / groups;
11027                let c_out_per_g = c_out / groups;
11028                unsafe {
11029                    let x_ptr = base.add(*x) as *const i8;
11030                    let w_ptr = base.add(*w) as *const i8;
11031                    let bias_ptr = base.add(*bias) as *const i32;
11032                    let out_ptr = base.add(*out) as *mut i8;
11033                    for ni in 0..n {
11034                        for co in 0..c_out {
11035                            let g = co / c_out_per_g;
11036                            let ci_start = g * c_in_per_g;
11037                            for ho in 0..h_out {
11038                                for wo in 0..w_out {
11039                                    let mut acc: i32 = *bias_ptr.add(co);
11040                                    for ci_off in 0..c_in_per_g {
11041                                        let ci = ci_start + ci_off;
11042                                        let in_chan = ((ni * c_in) + ci) * h * w_in;
11043                                        let wt_chan = ((co * c_in_per_g) + ci_off) * kh * kw;
11044                                        for ki in 0..kh {
11045                                            for kj in 0..kw {
11046                                                let hi = ho * sh + ki * dh;
11047                                                let wi = wo * sw + kj * dw;
11048                                                if hi < ph || wi < pw {
11049                                                    continue;
11050                                                }
11051                                                let hi = hi - ph;
11052                                                let wi = wi - pw;
11053                                                if hi >= h || wi >= w_in {
11054                                                    continue;
11055                                                }
11056                                                let xv = *x_ptr.add(in_chan + hi * w_in + wi)
11057                                                    as i32
11058                                                    - *x_zp;
11059                                                let wv = *w_ptr.add(wt_chan + ki * kw + kj) as i32
11060                                                    - *w_zp;
11061                                                acc += xv * wv;
11062                                            }
11063                                        }
11064                                    }
11065                                    let r = (acc as f32 * *mult).round() as i32 + *out_zp;
11066                                    let r = r.clamp(-128, 127) as i8;
11067                                    let dst = ((ni * c_out) + co) * h_out * w_out + ho * w_out + wo;
11068                                    *out_ptr.add(dst) = r;
11069                                }
11070                            }
11071                        }
11072                    }
11073                }
11074            }
11075
11076            Thunk::Quantize {
11077                x,
11078                q,
11079                len,
11080                chan_axis: _,
11081                chan_dim,
11082                inner,
11083                scales,
11084                zero_points,
11085            } => {
11086                let len = *len as usize;
11087                let chan_dim = *chan_dim as usize;
11088                let inner = *inner as usize;
11089                unsafe {
11090                    let xs = sl(*x, base, len);
11091                    let q_ptr = base.add(*q) as *mut i8;
11092                    for i in 0..len {
11093                        let c = if chan_dim == 1 {
11094                            0
11095                        } else {
11096                            (i / inner) % chan_dim
11097                        };
11098                        let inv_scale = 1.0 / scales[c];
11099                        let zp = zero_points[c];
11100                        let v = (xs[i] * inv_scale).round() as i32 + zp;
11101                        *q_ptr.add(i) = v.clamp(-128, 127) as i8;
11102                    }
11103                }
11104            }
11105
11106            Thunk::Dequantize {
11107                q,
11108                x,
11109                len,
11110                chan_axis: _,
11111                chan_dim,
11112                inner,
11113                scales,
11114                zero_points,
11115            } => {
11116                let len = *len as usize;
11117                let chan_dim = *chan_dim as usize;
11118                let inner = *inner as usize;
11119                unsafe {
11120                    let q_ptr = base.add(*q) as *const i8;
11121                    let out = sl_mut(*x, base, len);
11122                    for i in 0..len {
11123                        let c = if chan_dim == 1 {
11124                            0
11125                        } else {
11126                            (i / inner) % chan_dim
11127                        };
11128                        let scale = scales[c];
11129                        let zp = zero_points[c];
11130                        let qv = *q_ptr.add(i) as i32;
11131                        out[i] = (qv - zp) as f32 * scale;
11132                    }
11133                }
11134            }
11135
11136            Thunk::FakeQuantize {
11137                x,
11138                out,
11139                len,
11140                chan_axis: _,
11141                chan_dim,
11142                inner,
11143                bits,
11144                ste: _,
11145                scale_mode,
11146                state_off,
11147            } => {
11148                use rlx_ir::op::ScaleMode;
11149                let len = *len as usize;
11150                let chan_dim = *chan_dim as usize;
11151                let inner = *inner as usize;
11152                let q_max: f32 = match *bits {
11153                    8 => 127.0,
11154                    4 => 7.0,
11155                    2 => 1.0,
11156                    n => panic!("FakeQuantize: unsupported bits {n}"),
11157                };
11158                unsafe {
11159                    let xs = sl(*x, base, len);
11160                    let outs = sl_mut(*out, base, len);
11161
11162                    let mut scale = vec![0f32; chan_dim];
11163                    match scale_mode {
11164                        ScaleMode::PerBatch => {
11165                            let mut max_abs = vec![0f32; chan_dim];
11166                            for i in 0..len {
11167                                let c = if chan_dim == 1 {
11168                                    0
11169                                } else {
11170                                    (i / inner) % chan_dim
11171                                };
11172                                let a = xs[i].abs();
11173                                if a > max_abs[c] {
11174                                    max_abs[c] = a;
11175                                }
11176                            }
11177                            for c in 0..chan_dim {
11178                                scale[c] = (max_abs[c] / q_max).max(1e-12);
11179                            }
11180                        }
11181                        ScaleMode::EMA { decay } => {
11182                            // Per-channel current max-abs, then blend
11183                            // into the running state in place.
11184                            let mut max_abs = vec![0f32; chan_dim];
11185                            for i in 0..len {
11186                                let c = if chan_dim == 1 {
11187                                    0
11188                                } else {
11189                                    (i / inner) % chan_dim
11190                                };
11191                                let a = xs[i].abs();
11192                                if a > max_abs[c] {
11193                                    max_abs[c] = a;
11194                                }
11195                            }
11196                            let state =
11197                                sl_mut(state_off.expect("EMA needs state_off"), base, chan_dim);
11198                            for c in 0..chan_dim {
11199                                let cur = (max_abs[c] / q_max).max(1e-12);
11200                                // Cold-start: state==0 → seed directly.
11201                                let blended = if state[c] <= 0.0 {
11202                                    cur
11203                                } else {
11204                                    *decay * state[c] + (1.0 - *decay) * cur
11205                                };
11206                                state[c] = blended;
11207                                scale[c] = blended;
11208                            }
11209                        }
11210                        ScaleMode::Fixed => {
11211                            let state =
11212                                sl(state_off.expect("Fixed needs state_off"), base, chan_dim);
11213                            for c in 0..chan_dim {
11214                                scale[c] = state[c].max(1e-12);
11215                            }
11216                        }
11217                    }
11218
11219                    for i in 0..len {
11220                        let c = if chan_dim == 1 {
11221                            0
11222                        } else {
11223                            (i / inner) % chan_dim
11224                        };
11225                        let s = scale[c];
11226                        let qv = (xs[i] / s).round().clamp(-q_max, q_max);
11227                        outs[i] = qv * s;
11228                    }
11229                }
11230            }
11231
11232            Thunk::ActivationBackward {
11233                x,
11234                dy,
11235                dx,
11236                len,
11237                kind,
11238            } => {
11239                let len = *len as usize;
11240                unsafe {
11241                    let xs = sl(*x, base, len);
11242                    let dys = sl(*dy, base, len);
11243                    let out = sl_mut(*dx, base, len);
11244                    activation_backward_kernel(*kind, xs, dys, out);
11245                }
11246            }
11247
11248            Thunk::ActivationBackwardF64 {
11249                x,
11250                dy,
11251                dx,
11252                len,
11253                kind,
11254            } => {
11255                let len = *len as usize;
11256                unsafe {
11257                    let xs = sl_f64(*x, base, len);
11258                    let dys = sl_f64(*dy, base, len);
11259                    let out = sl_mut_f64(*dx, base, len);
11260                    activation_backward_kernel_f64(*kind, xs, dys, out);
11261                }
11262            }
11263
11264            Thunk::FakeQuantizeLSQ {
11265                x,
11266                scale_off,
11267                out,
11268                len,
11269                chan_axis: _,
11270                chan_dim,
11271                inner,
11272                bits,
11273            } => {
11274                let len = *len as usize;
11275                let chan_dim = *chan_dim as usize;
11276                let inner = *inner as usize;
11277                let q_max: f32 = match *bits {
11278                    8 => 127.0,
11279                    4 => 7.0,
11280                    2 => 1.0,
11281                    n => panic!("FakeQuantizeLSQ: bad bits {n}"),
11282                };
11283                unsafe {
11284                    let xs = sl(*x, base, len);
11285                    let scale = sl(*scale_off, base, chan_dim);
11286                    let outs = sl_mut(*out, base, len);
11287                    for i in 0..len {
11288                        let c = if chan_dim == 1 {
11289                            0
11290                        } else {
11291                            (i / inner) % chan_dim
11292                        };
11293                        let s = scale[c].max(1e-12);
11294                        let qv = (xs[i] / s).round().clamp(-q_max, q_max);
11295                        outs[i] = qv * s;
11296                    }
11297                }
11298            }
11299
11300            Thunk::FakeQuantizeLSQBackwardX {
11301                x,
11302                scale_off,
11303                dy,
11304                dx,
11305                len,
11306                chan_axis: _,
11307                chan_dim,
11308                inner,
11309                bits,
11310            } => {
11311                let len = *len as usize;
11312                let chan_dim = *chan_dim as usize;
11313                let inner = *inner as usize;
11314                let q_max: f32 = match *bits {
11315                    8 => 127.0,
11316                    4 => 7.0,
11317                    2 => 1.0,
11318                    n => panic!("FakeQuantizeLSQBackwardX: bad bits {n}"),
11319                };
11320                unsafe {
11321                    let xs = sl(*x, base, len);
11322                    let scale = sl(*scale_off, base, chan_dim);
11323                    let dys = sl(*dy, base, len);
11324                    let outs = sl_mut(*dx, base, len);
11325                    // STE-clipped: dx = dy when |x/s| ≤ q_max, else 0.
11326                    for i in 0..len {
11327                        let c = if chan_dim == 1 {
11328                            0
11329                        } else {
11330                            (i / inner) % chan_dim
11331                        };
11332                        let z = xs[i] / scale[c].max(1e-12);
11333                        outs[i] = if z.abs() <= q_max { dys[i] } else { 0.0 };
11334                    }
11335                }
11336            }
11337
11338            Thunk::FakeQuantizeLSQBackwardScale {
11339                x,
11340                scale_off,
11341                dy,
11342                dscale,
11343                len,
11344                chan_axis: _,
11345                chan_dim,
11346                inner,
11347                bits,
11348            } => {
11349                let len = *len as usize;
11350                let chan_dim = *chan_dim as usize;
11351                let inner = *inner as usize;
11352                let q_max: f32 = match *bits {
11353                    8 => 127.0,
11354                    4 => 7.0,
11355                    2 => 1.0,
11356                    n => panic!("FakeQuantizeLSQBackwardScale: bad bits {n}"),
11357                };
11358                unsafe {
11359                    let xs = sl(*x, base, len);
11360                    let scale = sl(*scale_off, base, chan_dim);
11361                    let dys = sl(*dy, base, len);
11362                    let outs = sl_mut(*dscale, base, chan_dim);
11363                    for v in outs.iter_mut() {
11364                        *v = 0.0;
11365                    }
11366                    // ψ(z) = -z + round(z) inside range, sign(z)·q_max outside.
11367                    // dscale[c] = sum_i ψ(x_i/s[c]) * upstream[i].
11368                    for i in 0..len {
11369                        let c = if chan_dim == 1 {
11370                            0
11371                        } else {
11372                            (i / inner) % chan_dim
11373                        };
11374                        let s = scale[c].max(1e-12);
11375                        let z = xs[i] / s;
11376                        let psi = if z.abs() <= q_max {
11377                            -z + z.round()
11378                        } else if z > 0.0 {
11379                            q_max
11380                        } else {
11381                            -q_max
11382                        };
11383                        outs[c] += psi * dys[i];
11384                    }
11385                }
11386            }
11387
11388            Thunk::FakeQuantizeBackward {
11389                x,
11390                dy,
11391                dx,
11392                len,
11393                chan_axis: _,
11394                chan_dim,
11395                inner,
11396                bits,
11397                ste,
11398            } => {
11399                use rlx_ir::op::SteKind;
11400                let len = *len as usize;
11401                let chan_dim = *chan_dim as usize;
11402                let inner = *inner as usize;
11403                let q_max: f32 = match *bits {
11404                    8 => 127.0,
11405                    4 => 7.0,
11406                    2 => 1.0,
11407                    n => panic!("FakeQuantizeBackward: bad bits {n}"),
11408                };
11409                unsafe {
11410                    let xs = sl(*x, base, len);
11411                    let dys = sl(*dy, base, len);
11412                    let outs = sl_mut(*dx, base, len);
11413
11414                    // Per-channel max-abs → scale, same as forward.
11415                    let mut max_abs = vec![0f32; chan_dim];
11416                    for i in 0..len {
11417                        let c = if chan_dim == 1 {
11418                            0
11419                        } else {
11420                            (i / inner) % chan_dim
11421                        };
11422                        let a = xs[i].abs();
11423                        if a > max_abs[c] {
11424                            max_abs[c] = a;
11425                        }
11426                    }
11427                    let mut scale = vec![0f32; chan_dim];
11428                    for c in 0..chan_dim {
11429                        scale[c] = (max_abs[c] / q_max).max(1e-12);
11430                    }
11431
11432                    match *ste {
11433                        SteKind::Identity => {
11434                            // dx = dy unchanged.
11435                            outs.copy_from_slice(dys);
11436                        }
11437                        SteKind::ClippedIdentity => {
11438                            // dx = dy * (|x| <= q_max·s); zero if the
11439                            // forward saturated.
11440                            for i in 0..len {
11441                                let c = if chan_dim == 1 {
11442                                    0
11443                                } else {
11444                                    (i / inner) % chan_dim
11445                                };
11446                                let bound = q_max * scale[c];
11447                                outs[i] = if xs[i].abs() <= bound { dys[i] } else { 0.0 };
11448                            }
11449                        }
11450                        SteKind::Tanh => {
11451                            // dx = dy * (1 - tanh²(x/s)).
11452                            for i in 0..len {
11453                                let c = if chan_dim == 1 {
11454                                    0
11455                                } else {
11456                                    (i / inner) % chan_dim
11457                                };
11458                                let t = (xs[i] / scale[c]).tanh();
11459                                outs[i] = dys[i] * (1.0 - t * t);
11460                            }
11461                        }
11462                        SteKind::HardTanh => {
11463                            // dx = dy * max(0, 1 - |x/(q_max·s)|).
11464                            for i in 0..len {
11465                                let c = if chan_dim == 1 {
11466                                    0
11467                                } else {
11468                                    (i / inner) % chan_dim
11469                                };
11470                                let bound = q_max * scale[c];
11471                                let attenuation = (1.0 - (xs[i] / bound).abs()).max(0.0);
11472                                outs[i] = dys[i] * attenuation;
11473                            }
11474                        }
11475                    }
11476                }
11477            }
11478
11479            Thunk::LayerNormBackwardInput {
11480                x,
11481                gamma,
11482                dy,
11483                dx,
11484                rows,
11485                h,
11486                eps,
11487            } => {
11488                let rows = *rows as usize;
11489                let h = *h as usize;
11490                let eps = *eps;
11491                unsafe {
11492                    let xs = sl(*x, base, rows * h);
11493                    let g = sl(*gamma, base, h);
11494                    let dys = sl(*dy, base, rows * h);
11495                    let out = sl_mut(*dx, base, rows * h);
11496                    let n_inv = 1.0 / h as f32;
11497                    for r in 0..rows {
11498                        let xr = &xs[r * h..(r + 1) * h];
11499                        let dyr = &dys[r * h..(r + 1) * h];
11500                        // Per-row mean and inv_std (recompute — no saved
11501                        // tensor from the forward pass).
11502                        let mut sum = 0f32;
11503                        for &v in xr {
11504                            sum += v;
11505                        }
11506                        let mean = sum * n_inv;
11507                        let mut var = 0f32;
11508                        for &v in xr {
11509                            let d = v - mean;
11510                            var += d * d;
11511                        }
11512                        let inv_std = 1.0 / (var * n_inv + eps).sqrt();
11513
11514                        // sums needed for the closed-form:
11515                        //   mean(dy·γ) and mean(dy·γ·x̂)
11516                        let mut s_sy = 0f32;
11517                        let mut s_sxh = 0f32;
11518                        for d in 0..h {
11519                            let xh = (xr[d] - mean) * inv_std;
11520                            let sy = dyr[d] * g[d];
11521                            s_sy += sy;
11522                            s_sxh += sy * xh;
11523                        }
11524                        let m_sy = s_sy * n_inv;
11525                        let m_sxh = s_sxh * n_inv;
11526
11527                        for d in 0..h {
11528                            let xh = (xr[d] - mean) * inv_std;
11529                            let sy = dyr[d] * g[d];
11530                            out[r * h + d] = inv_std * (sy - m_sy - xh * m_sxh);
11531                        }
11532                    }
11533                }
11534            }
11535
11536            Thunk::LayerNormBackwardGamma {
11537                x,
11538                dy,
11539                dgamma,
11540                rows,
11541                h,
11542                eps,
11543            } => {
11544                let rows = *rows as usize;
11545                let h = *h as usize;
11546                let eps = *eps;
11547                unsafe {
11548                    let xs = sl(*x, base, rows * h);
11549                    let dys = sl(*dy, base, rows * h);
11550                    let out = sl_mut(*dgamma, base, h);
11551                    for v in out.iter_mut() {
11552                        *v = 0.0;
11553                    }
11554                    let n_inv = 1.0 / h as f32;
11555                    for r in 0..rows {
11556                        let xr = &xs[r * h..(r + 1) * h];
11557                        let dyr = &dys[r * h..(r + 1) * h];
11558                        let mut sum = 0f32;
11559                        for &v in xr {
11560                            sum += v;
11561                        }
11562                        let mean = sum * n_inv;
11563                        let mut var = 0f32;
11564                        for &v in xr {
11565                            let d = v - mean;
11566                            var += d * d;
11567                        }
11568                        let inv_std = 1.0 / (var * n_inv + eps).sqrt();
11569                        for d in 0..h {
11570                            let xh = (xr[d] - mean) * inv_std;
11571                            out[d] += dyr[d] * xh;
11572                        }
11573                    }
11574                }
11575            }
11576
11577            Thunk::RmsNormBackwardInput {
11578                x,
11579                gamma,
11580                beta,
11581                dy,
11582                dx,
11583                rows,
11584                h,
11585                eps,
11586            } => {
11587                let (rows, h) = (*rows as usize, *h as usize);
11588                unsafe {
11589                    let xs = sl(*x, base, rows * h);
11590                    let g = sl(*gamma, base, h);
11591                    let b = sl(*beta, base, h);
11592                    let dys = sl(*dy, base, rows * h);
11593                    let out = sl_mut(*dx, base, rows * h);
11594                    let mut dg = vec![0f32; h];
11595                    let mut db = vec![0f32; h];
11596                    for r in 0..rows {
11597                        crate::training_bwd::rms_norm_backward_row(
11598                            &xs[r * h..(r + 1) * h],
11599                            g,
11600                            b,
11601                            &dys[r * h..(r + 1) * h],
11602                            &mut out[r * h..(r + 1) * h],
11603                            &mut dg,
11604                            &mut db,
11605                            *eps,
11606                        );
11607                    }
11608                }
11609            }
11610
11611            Thunk::RmsNormBackwardGamma {
11612                x,
11613                gamma,
11614                beta,
11615                dy,
11616                dgamma,
11617                rows,
11618                h,
11619                eps,
11620            } => {
11621                let (rows, h) = (*rows as usize, *h as usize);
11622                unsafe {
11623                    let xs = sl(*x, base, rows * h);
11624                    let g = sl(*gamma, base, h);
11625                    let b = sl(*beta, base, h);
11626                    let dys = sl(*dy, base, rows * h);
11627                    let out = sl_mut(*dgamma, base, h);
11628                    for v in out.iter_mut() {
11629                        *v = 0.0;
11630                    }
11631                    let mut dx = vec![0f32; h];
11632                    let mut db = vec![0f32; h];
11633                    for r in 0..rows {
11634                        crate::training_bwd::rms_norm_backward_row(
11635                            &xs[r * h..(r + 1) * h],
11636                            g,
11637                            b,
11638                            &dys[r * h..(r + 1) * h],
11639                            &mut dx,
11640                            &mut *out,
11641                            &mut db,
11642                            *eps,
11643                        );
11644                    }
11645                }
11646            }
11647
11648            Thunk::RmsNormBackwardBeta {
11649                x,
11650                gamma,
11651                beta,
11652                dy,
11653                dbeta,
11654                rows,
11655                h,
11656                eps,
11657            } => {
11658                let (rows, h) = (*rows as usize, *h as usize);
11659                unsafe {
11660                    let xs = sl(*x, base, rows * h);
11661                    let g = sl(*gamma, base, h);
11662                    let b = sl(*beta, base, h);
11663                    let dys = sl(*dy, base, rows * h);
11664                    let out = sl_mut(*dbeta, base, h);
11665                    for v in out.iter_mut() {
11666                        *v = 0.0;
11667                    }
11668                    let mut dx = vec![0f32; h];
11669                    let mut dg = vec![0f32; h];
11670                    for r in 0..rows {
11671                        crate::training_bwd::rms_norm_backward_row(
11672                            &xs[r * h..(r + 1) * h],
11673                            g,
11674                            b,
11675                            &dys[r * h..(r + 1) * h],
11676                            &mut dx,
11677                            &mut dg,
11678                            &mut *out,
11679                            *eps,
11680                        );
11681                    }
11682                }
11683            }
11684
11685            Thunk::RopeBackward {
11686                dy,
11687                cos,
11688                sin,
11689                dx,
11690                batch,
11691                seq,
11692                hidden,
11693                head_dim,
11694                n_rot,
11695                cos_len,
11696            } => {
11697                let (b, s, hs, dh, nr, cl) = (
11698                    *batch as usize,
11699                    *seq as usize,
11700                    *hidden as usize,
11701                    *head_dim as usize,
11702                    *n_rot as usize,
11703                    *cos_len as usize,
11704                );
11705                let nh = hs / dh;
11706                let tab_half = dh / 2;
11707                unsafe {
11708                    let dys = sl(*dy, base, b * s * hs);
11709                    let cos_tab = sl(*cos, base, cl);
11710                    let sin_tab = sl(*sin, base, cl);
11711                    let out = sl_mut(*dx, base, b * s * hs);
11712                    for bi in 0..b {
11713                        for si in 0..s {
11714                            let tab_off = si.saturating_mul(tab_half) % cl.max(1);
11715                            let cp = &cos_tab[tab_off..tab_off + tab_half.min(cl)];
11716                            let sp = &sin_tab[tab_off..tab_off + tab_half.min(cl)];
11717                            for hi in 0..nh {
11718                                let base_idx = bi * s * hs + si * hs + hi * dh;
11719                                crate::training_bwd::rope_backward_row(
11720                                    &dys[base_idx..base_idx + dh],
11721                                    cp,
11722                                    sp,
11723                                    &mut out[base_idx..base_idx + dh],
11724                                    dh,
11725                                    nr,
11726                                );
11727                            }
11728                        }
11729                    }
11730                }
11731            }
11732
11733            Thunk::CumsumBackward {
11734                dy,
11735                dx,
11736                rows,
11737                cols,
11738                exclusive,
11739            } => {
11740                let (rows, cols) = (*rows as usize, *cols as usize);
11741                unsafe {
11742                    let dys = sl(*dy, base, rows * cols);
11743                    let out = sl_mut(*dx, base, rows * cols);
11744                    for r in 0..rows {
11745                        crate::training_bwd::cumsum_backward_row(
11746                            &dys[r * cols..(r + 1) * cols],
11747                            &mut out[r * cols..(r + 1) * cols],
11748                            *exclusive,
11749                        );
11750                    }
11751                }
11752            }
11753
11754            Thunk::GroupNormBackwardInput {
11755                x,
11756                gamma,
11757                beta: _beta,
11758                dy,
11759                dx,
11760                n,
11761                c,
11762                h,
11763                w,
11764                num_groups,
11765                eps,
11766            } => {
11767                let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
11768                let plane = c * h * w;
11769                unsafe {
11770                    let xs = sl(*x, base, n * plane);
11771                    let g = sl(*gamma, base, c);
11772                    let dys = sl(*dy, base, n * plane);
11773                    let out = sl_mut(*dx, base, n * plane);
11774                    crate::training_bwd::group_norm_backward_input_nchw(
11775                        xs,
11776                        g,
11777                        dys,
11778                        out,
11779                        n,
11780                        c,
11781                        h,
11782                        w,
11783                        *num_groups as usize,
11784                        *eps,
11785                    );
11786                }
11787            }
11788
11789            Thunk::GroupNormBackwardGamma {
11790                x,
11791                dy,
11792                dgamma,
11793                n,
11794                c,
11795                h,
11796                w,
11797                num_groups,
11798                eps,
11799            } => {
11800                let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
11801                let plane = c * h * w;
11802                unsafe {
11803                    let xs = sl(*x, base, n * plane);
11804                    let dys = sl(*dy, base, n * plane);
11805                    let out = sl_mut(*dgamma, base, c);
11806                    crate::training_bwd::group_norm_backward_gamma_nchw(
11807                        xs,
11808                        dys,
11809                        out,
11810                        n,
11811                        c,
11812                        h,
11813                        w,
11814                        *num_groups as usize,
11815                        *eps,
11816                    );
11817                }
11818            }
11819
11820            Thunk::GroupNormBackwardBeta {
11821                dy,
11822                dbeta,
11823                n,
11824                c,
11825                h,
11826                w,
11827            } => {
11828                let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
11829                let plane = c * h * w;
11830                unsafe {
11831                    let dys = sl(*dy, base, n * plane);
11832                    let out = sl_mut(*dbeta, base, c);
11833                    crate::training_bwd::group_norm_backward_beta_nchw(dys, out, n, c, h, w);
11834                }
11835            }
11836
11837            Thunk::GatherBackward {
11838                dy,
11839                indices,
11840                dst,
11841                outer,
11842                axis_dim,
11843                num_idx,
11844                trailing,
11845            } => {
11846                let (outer, axis_dim, num_idx, trailing) = (
11847                    *outer as usize,
11848                    *axis_dim as usize,
11849                    *num_idx as usize,
11850                    *trailing as usize,
11851                );
11852                unsafe {
11853                    let dys = sl(*dy, base, outer * num_idx * trailing);
11854                    let ids = sl(*indices, base, num_idx);
11855                    let out = sl_mut(*dst, base, outer * axis_dim * trailing);
11856                    for v in out.iter_mut() {
11857                        *v = 0.0;
11858                    }
11859                    crate::training_bwd::gather_axis_backward(
11860                        dys, ids, out, outer, axis_dim, num_idx, trailing,
11861                    );
11862                }
11863            }
11864
11865            Thunk::MaxPool2dBackward {
11866                x,
11867                dy,
11868                dx,
11869                n,
11870                c,
11871                h,
11872                w,
11873                h_out,
11874                w_out,
11875                kh,
11876                kw,
11877                sh,
11878                sw,
11879                ph,
11880                pw,
11881            } => {
11882                let n = *n as usize;
11883                let c = *c as usize;
11884                let h = *h as usize;
11885                let w = *w as usize;
11886                let h_out = *h_out as usize;
11887                let w_out = *w_out as usize;
11888                let kh = *kh as usize;
11889                let kw = *kw as usize;
11890                let sh = *sh as usize;
11891                let sw = *sw as usize;
11892                let ph = *ph as usize;
11893                let pw = *pw as usize;
11894                unsafe {
11895                    let xs = sl(*x, base, n * c * h * w);
11896                    let dys = sl(*dy, base, n * c * h_out * w_out);
11897                    let dxs = sl_mut(*dx, base, n * c * h * w);
11898                    // Zero before scatter — multiple windows can write
11899                    // to the same input position when stride < kernel.
11900                    for v in dxs.iter_mut() {
11901                        *v = 0.0;
11902                    }
11903                    for ni in 0..n {
11904                        for ci in 0..c {
11905                            let in_chan = (ni * c + ci) * h * w;
11906                            let out_chan = (ni * c + ci) * h_out * w_out;
11907                            for ho in 0..h_out {
11908                                for wo in 0..w_out {
11909                                    // Recompute argmax inside this window.
11910                                    let mut best_v = f32::NEG_INFINITY;
11911                                    let mut best_idx: Option<usize> = None;
11912                                    for ki in 0..kh {
11913                                        for kj in 0..kw {
11914                                            let hi = ho * sh + ki;
11915                                            let wi = wo * sw + kj;
11916                                            if hi < ph || wi < pw {
11917                                                continue;
11918                                            }
11919                                            let hi = hi - ph;
11920                                            let wi = wi - pw;
11921                                            if hi >= h || wi >= w {
11922                                                continue;
11923                                            }
11924                                            let idx = in_chan + hi * w + wi;
11925                                            let v = xs[idx];
11926                                            // Tie-break: keep first hit
11927                                            // (matches forward's `acc.max(v)`
11928                                            // — strict greater-than wins).
11929                                            if v > best_v {
11930                                                best_v = v;
11931                                                best_idx = Some(idx);
11932                                            }
11933                                        }
11934                                    }
11935                                    if let Some(idx) = best_idx {
11936                                        dxs[idx] += dys[out_chan + ho * w_out + wo];
11937                                    }
11938                                }
11939                            }
11940                        }
11941                    }
11942                }
11943            }
11944
11945            Thunk::Conv2dBackwardInput {
11946                dy,
11947                w,
11948                dx,
11949                n,
11950                c_in,
11951                h,
11952                w_in,
11953                c_out,
11954                h_out,
11955                w_out,
11956                kh,
11957                kw,
11958                sh,
11959                sw,
11960                ph,
11961                pw,
11962                dh,
11963                dw,
11964                groups,
11965            } => {
11966                // Per-group GEMM + col2im. Two orders of magnitude faster
11967                // than the naive 6-deep nested loop on training shapes.
11968                //
11969                //   dcol_n_g = w_g^T  @  dy_n_g            (sgemm)
11970                //   dx_n_g  += col2im(dcol_n_g)            (scatter-add)
11971                //
11972                // Layouts (all row-major):
11973                //   w_g       [c_out_per_g, c_in_per_g · kh · kw]
11974                //   dy_n_g    [c_out_per_g, h_out · w_out]
11975                //   dcol_n_g  [c_in_per_g · kh · kw, h_out · w_out]
11976                //   dx_n_g    [c_in_per_g, h · w_in]
11977                let n = *n as usize;
11978                let c_in = *c_in as usize;
11979                let h = *h as usize;
11980                let w_in = *w_in as usize;
11981                let c_out = *c_out as usize;
11982                let h_out = *h_out as usize;
11983                let w_out = *w_out as usize;
11984                let kh = *kh as usize;
11985                let kw = *kw as usize;
11986                let sh = *sh as usize;
11987                let sw = *sw as usize;
11988                let ph = *ph as usize;
11989                let pw = *pw as usize;
11990                let dh = *dh as usize;
11991                let dw = *dw as usize;
11992                let groups = *groups as usize;
11993                let c_in_per_g = c_in / groups;
11994                let c_out_per_g = c_out / groups;
11995
11996                let m_dim = c_in_per_g * kh * kw;
11997                let n_dim = h_out * w_out;
11998                let k_dim = c_out_per_g;
11999
12000                let dy_stride_n = c_out * h_out * w_out;
12001                let dy_stride_g = c_out_per_g * h_out * w_out;
12002                let w_stride_g = c_out_per_g * c_in_per_g * kh * kw;
12003                let dx_stride_n = c_in * h * w_in;
12004                let dx_stride_g = c_in_per_g * h * w_in;
12005
12006                unsafe {
12007                    let dys = sl(*dy, base, n * c_out * h_out * w_out);
12008                    let ws = sl(*w, base, c_out * c_in_per_g * kh * kw);
12009                    let dxs = sl_mut(*dx, base, n * c_in * h * w_in);
12010                    for v in dxs.iter_mut() {
12011                        *v = 0.0;
12012                    }
12013
12014                    // Reused scratch buffer for the [m_dim, n_dim] dcol.
12015                    let mut dcol = vec![0f32; m_dim * n_dim];
12016
12017                    for ni in 0..n {
12018                        for g in 0..groups {
12019                            let w_g_off = g * w_stride_g;
12020                            let dy_n_g_off = ni * dy_stride_n + g * dy_stride_g;
12021                            let dx_n_g_off = ni * dx_stride_n + g * dx_stride_g;
12022
12023                            // dcol = w_g^T @ dy_n_g
12024                            // w_g  is stored as [k_dim rows, m_dim cols] row-major
12025                            // (i.e. K×M storage with lda = M = m_dim — exactly what
12026                            // sgemm_general wants for trans_a=true).
12027                            crate::blas::sgemm_general(
12028                                ws.as_ptr().add(w_g_off),
12029                                dys.as_ptr().add(dy_n_g_off),
12030                                dcol.as_mut_ptr(),
12031                                m_dim,
12032                                n_dim,
12033                                k_dim,
12034                                1.0,
12035                                0.0,
12036                                /*lda=*/ m_dim,
12037                                /*ldb=*/ n_dim,
12038                                /*ldc=*/ n_dim,
12039                                /*trans_a=*/ true,
12040                                /*trans_b=*/ false,
12041                            );
12042
12043                            // dx_n_g += col2im(dcol)
12044                            col2im(
12045                                &dcol,
12046                                &mut dxs[dx_n_g_off..dx_n_g_off + dx_stride_g],
12047                                c_in_per_g,
12048                                h,
12049                                w_in,
12050                                h_out,
12051                                w_out,
12052                                kh,
12053                                kw,
12054                                sh,
12055                                sw,
12056                                ph,
12057                                pw,
12058                                dh,
12059                                dw,
12060                            );
12061                        }
12062                    }
12063                }
12064            }
12065
12066            Thunk::Conv2dBackwardWeight {
12067                x,
12068                dy,
12069                dw,
12070                n,
12071                c_in,
12072                h,
12073                w,
12074                c_out,
12075                h_out,
12076                w_out,
12077                kh,
12078                kw,
12079                sh,
12080                sw,
12081                ph,
12082                pw,
12083                dh,
12084                dw_dil,
12085                groups,
12086            } => {
12087                let n = *n as usize;
12088                let c_in = *c_in as usize;
12089                let h = *h as usize;
12090                let w = *w as usize;
12091                // Per-group im2col + GEMM, summed across batch.
12092                //
12093                //   col_n_g  = im2col(x_n_g)               (gather)
12094                //   dw_g    += dy_n_g  @  col_n_g^T        (sgemm, β=1)
12095                //
12096                // Layouts:
12097                //   x_n_g     [c_in_per_g, h · w]
12098                //   col_n_g   [c_in_per_g · kh · kw, h_out · w_out]
12099                //   dy_n_g    [c_out_per_g, h_out · w_out]
12100                //   dw_g      [c_out_per_g, c_in_per_g · kh · kw]
12101                let c_out = *c_out as usize;
12102                let h_out = *h_out as usize;
12103                let w_out = *w_out as usize;
12104                let kh = *kh as usize;
12105                let kw = *kw as usize;
12106                let sh = *sh as usize;
12107                let sw = *sw as usize;
12108                let ph = *ph as usize;
12109                let pw = *pw as usize;
12110                let dh = *dh as usize;
12111                let dw_dil = *dw_dil as usize;
12112                let groups = *groups as usize;
12113                let c_in_per_g = c_in / groups;
12114                let c_out_per_g = c_out / groups;
12115
12116                let m_dim = c_out_per_g;
12117                let n_dim = c_in_per_g * kh * kw;
12118                let k_dim = h_out * w_out;
12119
12120                let x_stride_n = c_in * h * w;
12121                let x_stride_g = c_in_per_g * h * w;
12122                let dy_stride_n = c_out * h_out * w_out;
12123                let dy_stride_g = c_out_per_g * h_out * w_out;
12124                let dw_stride_g = c_out_per_g * c_in_per_g * kh * kw;
12125
12126                unsafe {
12127                    let xs = sl(*x, base, n * c_in * h * w);
12128                    let dys = sl(*dy, base, n * c_out * h_out * w_out);
12129                    let dws = sl_mut(*dw, base, c_out * c_in_per_g * kh * kw);
12130                    for v in dws.iter_mut() {
12131                        *v = 0.0;
12132                    }
12133
12134                    let mut col = vec![0f32; n_dim * k_dim];
12135
12136                    for ni in 0..n {
12137                        for g in 0..groups {
12138                            let x_n_g_off = ni * x_stride_n + g * x_stride_g;
12139                            im2col(
12140                                &xs[x_n_g_off..x_n_g_off + x_stride_g],
12141                                &mut col,
12142                                c_in_per_g,
12143                                h,
12144                                w,
12145                                h_out,
12146                                w_out,
12147                                kh,
12148                                kw,
12149                                sh,
12150                                sw,
12151                                ph,
12152                                pw,
12153                                dh,
12154                                dw_dil,
12155                            );
12156
12157                            let dy_n_g_off = ni * dy_stride_n + g * dy_stride_g;
12158                            let dw_g_off = g * dw_stride_g;
12159
12160                            // dw_g += dy_n_g @ col^T
12161                            //
12162                            // Output shape m × n_out = c_out_per_g × (c_in_per_g·kh·kw).
12163                            // dy_n_g is stored M×K row-major (lda = K = k_dim).
12164                            // col is stored as N×K row-major; with trans_b=true,
12165                            // sgemm_general uses ldb = K = k_dim and treats it as
12166                            // transposed. β=1 accumulates across the batch loop.
12167                            crate::blas::sgemm_general(
12168                                dys.as_ptr().add(dy_n_g_off),
12169                                col.as_ptr(),
12170                                dws.as_mut_ptr().add(dw_g_off),
12171                                m_dim,
12172                                n_dim,
12173                                k_dim,
12174                                1.0,
12175                                1.0,
12176                                /*lda=*/ k_dim,
12177                                /*ldb=*/ k_dim,
12178                                /*ldc=*/ n_dim,
12179                                /*trans_a=*/ false,
12180                                /*trans_b=*/ true,
12181                            );
12182                        }
12183                    }
12184                }
12185            }
12186
12187            Thunk::SoftmaxCrossEntropy {
12188                logits,
12189                labels,
12190                dst,
12191                n,
12192                c,
12193            } => {
12194                let n = *n as usize;
12195                let c = *c as usize;
12196                unsafe {
12197                    let lg = sl(*logits, base, n * c);
12198                    let lb = sl(*labels, base, n);
12199                    let out = sl_mut(*dst, base, n);
12200                    for ni in 0..n {
12201                        let row = &lg[ni * c..(ni + 1) * c];
12202                        // log-sum-exp: max-subtract for stability.
12203                        let mut m = f32::NEG_INFINITY;
12204                        for &v in row {
12205                            if v > m {
12206                                m = v;
12207                            }
12208                        }
12209                        let mut sum = 0f32;
12210                        for &v in row {
12211                            sum += (v - m).exp();
12212                        }
12213                        let lse = m + sum.ln();
12214                        let label_idx = lb[ni] as usize;
12215                        // loss = -(logits[label] - lse) = lse - logits[label].
12216                        out[ni] = lse - row[label_idx];
12217                    }
12218                }
12219            }
12220
12221            Thunk::SoftmaxCrossEntropyBackward {
12222                logits,
12223                labels,
12224                d_loss,
12225                dlogits,
12226                n,
12227                c,
12228            } => {
12229                let n = *n as usize;
12230                let c = *c as usize;
12231                unsafe {
12232                    let lg = sl(*logits, base, n * c);
12233                    let lb = sl(*labels, base, n);
12234                    let dl = sl(*d_loss, base, n);
12235                    let out = sl_mut(*dlogits, base, n * c);
12236                    for ni in 0..n {
12237                        let row = &lg[ni * c..(ni + 1) * c];
12238                        let label_idx = lb[ni] as usize;
12239                        let scale = dl[ni];
12240                        let mut m = f32::NEG_INFINITY;
12241                        for &v in row {
12242                            if v > m {
12243                                m = v;
12244                            }
12245                        }
12246                        let mut sum = 0f32;
12247                        for &v in row {
12248                            sum += (v - m).exp();
12249                        }
12250                        let inv_sum = 1.0 / sum;
12251                        let dst_row = &mut out[ni * c..(ni + 1) * c];
12252                        for k in 0..c {
12253                            let p = (row[k] - m).exp() * inv_sum;
12254                            let one_hot = if k == label_idx { 1.0 } else { 0.0 };
12255                            dst_row[k] = (p - one_hot) * scale;
12256                        }
12257                    }
12258                }
12259            }
12260
12261            Thunk::GatherAxis {
12262                table,
12263                idx,
12264                dst,
12265                outer,
12266                axis_dim,
12267                num_idx,
12268                trailing,
12269            } => {
12270                let outer = *outer as usize;
12271                let axis_dim = *axis_dim as usize;
12272                let num_idx = *num_idx as usize;
12273                let trailing = *trailing as usize;
12274                unsafe {
12275                    let tab = sl(*table, base, outer * axis_dim * trailing);
12276                    let ids = sl(*idx, base, num_idx);
12277                    let out = sl_mut(*dst, base, outer * num_idx * trailing);
12278                    for o in 0..outer {
12279                        let tab_outer = o * axis_dim * trailing;
12280                        let out_outer = o * num_idx * trailing;
12281                        for k in 0..num_idx {
12282                            let row = ids[k] as usize;
12283                            let tab_row = tab_outer + row * trailing;
12284                            let out_row = out_outer + k * trailing;
12285                            out[out_row..out_row + trailing]
12286                                .copy_from_slice(&tab[tab_row..tab_row + trailing]);
12287                        }
12288                    }
12289                }
12290            }
12291
12292            Thunk::Transpose {
12293                src,
12294                dst,
12295                in_total,
12296                out_dims,
12297                in_strides,
12298            } => {
12299                // N-D index walk: for each output flat index, decompose into
12300                // multi-dim coords using out_dims, then dot with in_strides
12301                // to find the source flat index. Stride 0 = broadcast (read
12302                // the same input element repeatedly along that dim).
12303                let rank = out_dims.len();
12304                let total: usize = out_dims.iter().map(|&d| d as usize).product();
12305                let in_total = *in_total as usize;
12306                unsafe {
12307                    let inp = sl(*src, base, in_total);
12308                    let out = sl_mut(*dst, base, total);
12309                    let mut idx = vec![0usize; rank];
12310                    for o in 0..total {
12311                        let mut src_idx = 0usize;
12312                        for d in 0..rank {
12313                            src_idx += idx[d] * in_strides[d] as usize;
12314                        }
12315                        out[o] = inp[src_idx];
12316                        // Increment multi-index (innermost dim first).
12317                        for d in (0..rank).rev() {
12318                            idx[d] += 1;
12319                            if idx[d] < out_dims[d] as usize {
12320                                break;
12321                            }
12322                            idx[d] = 0;
12323                        }
12324                    }
12325                }
12326            }
12327
12328            // (Thunk::DenseSolveF64 / Thunk::ScanBackward had panic
12329            // stubs here as placeholders during the wire-up; both
12330            // are now reached by the real implementations earlier in
12331            // this same match — the stubs were dead code shadowed by
12332            // the specific-pattern arms above. Removed.)
12333            Thunk::CustomOp {
12334                kernel,
12335                inputs,
12336                output,
12337                attrs,
12338            } => {
12339                let (out_off, out_len, out_shape) = output;
12340                unsafe {
12341                    dispatch_custom_op(
12342                        &**kernel, inputs, *out_off, *out_len, out_shape, attrs, base,
12343                    );
12344                }
12345            }
12346        }
12347    }
12348}
12349
12350/// Griewank treeverse: process backward iterations `[t_lo..=t_hi]` (with
12351/// the carry entering iteration `t_lo` supplied as `anchor_carry`) by
12352/// recursive binary subdivision. Total work `O((t_hi-t_lo+1) · log)`,
12353/// auxiliary memory `O(log · carry_bytes)` for the recursion stack.
12354///
12355/// Compared to the iterative segment-cached scheme, this trades extra
12356/// recompute for less working memory — each level of recursion holds
12357/// one `cb`-sized intermediate carry on the stack but never the whole
12358/// segment at once. With K saved outer checkpoints, the outer driver
12359/// invokes this helper once per segment.
12360///
12361/// `process_iter(t, carry_at_t)` is the per-iteration leaf action: it
12362/// runs `body_vjp` at iteration `t` with the supplied carry, threads
12363/// `dcarry` backward, and (for ScanBackwardXs) writes `dxs[t]`.
12364#[allow(clippy::too_many_arguments)]
12365unsafe fn griewank_process_segment(
12366    t_lo: usize,
12367    t_hi: usize,
12368    anchor_carry: &[u8],
12369    cb: usize,
12370    fwd_sched: &ThunkSchedule,
12371    fwd_init: &[u8],
12372    fwd_carry_in_off: usize,
12373    fwd_output_off: usize,
12374    fwd_x_offs: &[usize],
12375    base: *mut u8,
12376    outer_xs_offs: &[(usize, u32)],
12377    fwd_buf: &mut Vec<u8>,
12378    leaf_threshold: usize,
12379    process_iter: &mut dyn FnMut(usize, &[u8]),
12380) {
12381    unsafe {
12382        let size = t_hi - t_lo + 1;
12383        if size == 1 {
12384            process_iter(t_lo, anchor_carry);
12385            return;
12386        }
12387        if size <= leaf_threshold {
12388            // Walk forward, cache each carry, run backward in reverse.
12389            let mut cache: Vec<u8> = Vec::with_capacity(size * cb);
12390            cache.extend_from_slice(anchor_carry);
12391            fwd_buf.copy_from_slice(fwd_init);
12392            std::ptr::copy_nonoverlapping(
12393                anchor_carry.as_ptr(),
12394                fwd_buf.as_mut_ptr().add(fwd_carry_in_off),
12395                cb,
12396            );
12397            for i in 1..size {
12398                let cur_iter = t_lo + i - 1;
12399                for (idx, fb_x_off) in fwd_x_offs.iter().enumerate() {
12400                    let (outer_xs_off, x_psb) = outer_xs_offs[idx];
12401                    let xb = x_psb as usize;
12402                    std::ptr::copy_nonoverlapping(
12403                        base.add(outer_xs_off + cur_iter * xb),
12404                        fwd_buf.as_mut_ptr().add(*fb_x_off),
12405                        xb,
12406                    );
12407                }
12408                execute_thunks(fwd_sched, fwd_buf);
12409                if fwd_output_off != fwd_carry_in_off {
12410                    fwd_buf.copy_within(fwd_output_off..fwd_output_off + cb, fwd_carry_in_off);
12411                }
12412                cache.extend_from_slice(&fwd_buf[fwd_carry_in_off..fwd_carry_in_off + cb]);
12413            }
12414            // Process backward.
12415            for t in (t_lo..=t_hi).rev() {
12416                let idx = t - t_lo;
12417                let carry = &cache[idx * cb..(idx + 1) * cb];
12418                process_iter(t, carry);
12419            }
12420            return;
12421        }
12422
12423        // Split: walk forward from anchor to compute carry entering `mid`.
12424        // (We need `mid - t_lo` body executions: one per iteration in
12425        // [t_lo, mid).)
12426        let mid = t_lo + size / 2;
12427        fwd_buf.copy_from_slice(fwd_init);
12428        std::ptr::copy_nonoverlapping(
12429            anchor_carry.as_ptr(),
12430            fwd_buf.as_mut_ptr().add(fwd_carry_in_off),
12431            cb,
12432        );
12433        for cur_iter in t_lo..mid {
12434            for (idx, fb_x_off) in fwd_x_offs.iter().enumerate() {
12435                let (outer_xs_off, x_psb) = outer_xs_offs[idx];
12436                let xb = x_psb as usize;
12437                std::ptr::copy_nonoverlapping(
12438                    base.add(outer_xs_off + cur_iter * xb),
12439                    fwd_buf.as_mut_ptr().add(*fb_x_off),
12440                    xb,
12441                );
12442            }
12443            execute_thunks(fwd_sched, fwd_buf);
12444            if fwd_output_off != fwd_carry_in_off {
12445                fwd_buf.copy_within(fwd_output_off..fwd_output_off + cb, fwd_carry_in_off);
12446            }
12447        }
12448        let mid_carry: Vec<u8> = fwd_buf[fwd_carry_in_off..fwd_carry_in_off + cb].to_vec();
12449
12450        // Right half first (higher t values processed first to match the
12451        // canonical reverse-mode iteration order: dcarry threads from
12452        // t=length-1 down to t=0).
12453        griewank_process_segment(
12454            mid,
12455            t_hi,
12456            &mid_carry,
12457            cb,
12458            fwd_sched,
12459            fwd_init,
12460            fwd_carry_in_off,
12461            fwd_output_off,
12462            fwd_x_offs,
12463            base,
12464            outer_xs_offs,
12465            fwd_buf,
12466            leaf_threshold,
12467            process_iter,
12468        );
12469        // Then left half with original anchor.
12470        griewank_process_segment(
12471            t_lo,
12472            mid - 1,
12473            anchor_carry,
12474            cb,
12475            fwd_sched,
12476            fwd_init,
12477            fwd_carry_in_off,
12478            fwd_output_off,
12479            fwd_x_offs,
12480            base,
12481            outer_xs_offs,
12482            fwd_buf,
12483            leaf_threshold,
12484            process_iter,
12485        );
12486    }
12487}
12488
12489/// Execute a batched 1D FFT in the f64 2N-real-block layout.
12490/// Each "row" is `2N` f64 elements: first `N` real, then `N` imag.
12491/// The `outer` rows are independent and processed sequentially.
12492///
12493/// Both forward and inverse use the same Cooley-Tukey radix-2 DIT
12494/// kernel — only the twiddle-factor sign differs. Power-of-2 only
12495/// (the IR builder rejects non-power-of-2 sizes at graph-build time).
12496/// Batched 1D FFT on the f64 2N-real-block layout. Public so other
12497/// backend crates can invoke this as a host fallback against a
12498/// unified-memory arena (e.g. rlx-metal: sync the command buffer,
12499/// pass the Metal `Buffer::contents()` pointer as `base`, restart the
12500/// command buffer). Self-contained — no rlx-cpu state required.
12501///
12502/// Safety: `base + src` and `base + dst` must be valid for the
12503/// `outer * 2 * n_complex * sizeof::<f64>()` byte range and stay
12504/// alive for the duration of the call.
12505pub unsafe fn execute_fft1d_f64(
12506    src: usize,
12507    dst: usize,
12508    outer: usize,
12509    n_complex: usize,
12510    inverse: bool,
12511    base: *mut u8,
12512) {
12513    let row_elems = 2 * n_complex;
12514    let mut re = vec![0f64; n_complex];
12515    let mut im = vec![0f64; n_complex];
12516    // Scratch reused across rows for the Bluestein path. Empty when
12517    // we're on the radix-2 fast path.
12518    let mut scratch = if n_complex.is_power_of_two() {
12519        BluesteinScratchF64::empty()
12520    } else {
12521        BluesteinScratchF64::build(n_complex, inverse)
12522    };
12523    for o in 0..outer {
12524        let row_offset = src + o * row_elems * std::mem::size_of::<f64>();
12525        let s = unsafe { sl_f64(row_offset, base, row_elems) };
12526        re.copy_from_slice(&s[..n_complex]);
12527        im.copy_from_slice(&s[n_complex..]);
12528        if n_complex.is_power_of_two() {
12529            fft_radix2_inplace_f64(&mut re, &mut im, inverse);
12530        } else {
12531            fft_bluestein_inplace_f64(&mut re, &mut im, inverse, &mut scratch);
12532        }
12533        let dst_offset = dst + o * row_elems * std::mem::size_of::<f64>();
12534        let d = unsafe { sl_mut_f64(dst_offset, base, row_elems) };
12535        d[..n_complex].copy_from_slice(&re);
12536        d[n_complex..].copy_from_slice(&im);
12537    }
12538}
12539
12540/// f32 counterpart of `execute_fft1d_f64`. Same 2N-real-block layout
12541/// (first N real, second N imag per row), same unnormalized
12542/// convention; only the element width differs. Twiddle factors are
12543/// computed in f64 and cast to f32 to keep large-N error closer to
12544/// the f64 path (the savings from f32 are in memory bandwidth, not in
12545/// twiddle precision).
12546/// Host-fallback entry for `Op::GatedDeltaNet` (Metal / unified memory).
12547/// When `state == 0`, uses a zero-initialized scratch state per batch item.
12548pub unsafe fn execute_gated_delta_net_f32(
12549    q: usize,
12550    k: usize,
12551    v: usize,
12552    g: usize,
12553    beta: usize,
12554    state: usize,
12555    dst: usize,
12556    batch: usize,
12557    seq: usize,
12558    heads: usize,
12559    state_size: usize,
12560    base: *mut u8,
12561) {
12562    use rayon::prelude::*;
12563
12564    #[derive(Copy, Clone)]
12565    struct ArenaPtr(usize);
12566    unsafe impl Send for ArenaPtr {}
12567    unsafe impl Sync for ArenaPtr {}
12568    impl ArenaPtr {
12569        #[inline]
12570        fn get(self) -> *mut u8 {
12571            self.0 as *mut u8
12572        }
12573    }
12574
12575    unsafe {
12576        let arena = ArenaPtr(base as usize);
12577        let (b, s, h, n) = (batch, seq, heads, state_size);
12578        let scale = 1.0f32 / (n as f32).sqrt();
12579        let use_external = state != 0;
12580        let mut owned_state = vec![0f32; h * n * n];
12581
12582        crate::pool::num_threads();
12583
12584        assert!(
12585            n <= crate::gdn::GDN_MAX_STATE,
12586            "GatedDeltaNet state_size={n} exceeds stack scratch ({})",
12587            crate::gdn::GDN_MAX_STATE
12588        );
12589
12590        let qs = sl(q, arena.get(), b * s * h * n);
12591        let ks = sl(k, arena.get(), b * s * h * n);
12592        let vs = sl(v, arena.get(), b * s * h * n);
12593        let gs = sl(g, arena.get(), b * s * h);
12594        let betas = sl(beta, arena.get(), b * s * h);
12595        let _out = sl_mut(dst, arena.get(), b * s * h * n);
12596        let hs_n = h * n;
12597
12598        let run_head = |bi: usize, hi: usize, s_mat: &mut [f32], sk: &mut [f32]| {
12599            for ti in 0..s {
12600                let qkv_step = bi * s * hs_n + ti * hs_n + hi * n;
12601                let gb_step = bi * s * h + ti * h + hi;
12602                let out_row = sl_mut(dst + qkv_step * std::mem::size_of::<f32>(), arena.get(), n);
12603                crate::gdn::gdn_step_blas(
12604                    s_mat,
12605                    &qs[qkv_step..qkv_step + n],
12606                    &ks[qkv_step..qkv_step + n],
12607                    &vs[qkv_step..qkv_step + n],
12608                    gs[gb_step],
12609                    betas[gb_step],
12610                    out_row,
12611                    sk,
12612                    n,
12613                    scale,
12614                );
12615            }
12616        };
12617
12618        // Prefill (seq>1, ephemeral state): time-outer, parallel over heads —
12619        // better occupancy than head-outer when prompt length dominates.
12620        if !use_external && s > 1 {
12621            for bi in 0..b {
12622                (0..h).into_par_iter().for_each(|hi| {
12623                    let mut sk_buf = [0f32; crate::gdn::GDN_MAX_STATE];
12624                    let sk = &mut sk_buf[..n];
12625                    let mut local_state =
12626                        [0f32; crate::gdn::GDN_MAX_STATE * crate::gdn::GDN_MAX_STATE];
12627                    let s_mat = &mut local_state[..n * n];
12628                    s_mat.fill(0.0);
12629                    run_head(bi, hi, s_mat, sk);
12630                });
12631            }
12632            return;
12633        }
12634
12635        if use_external {
12636            let state_bytes = state;
12637            (0..b * h).into_par_iter().for_each(|bhi| {
12638                let bi = bhi / h;
12639                let hi = bhi % h;
12640                let elem_off = bi * h * n * n + hi * n * n;
12641                let s_mat = sl_mut(
12642                    state_bytes + elem_off * std::mem::size_of::<f32>(),
12643                    arena.get(),
12644                    n * n,
12645                );
12646                let mut sk_buf = [0f32; crate::gdn::GDN_MAX_STATE];
12647                run_head(bi, hi, s_mat, &mut sk_buf[..n]);
12648            });
12649        } else {
12650            for bi in 0..b {
12651                owned_state.fill(0.0);
12652                owned_state
12653                    .par_chunks_mut(n * n)
12654                    .enumerate()
12655                    .for_each(|(hi, s_mat)| {
12656                        let mut sk_buf = [0f32; crate::gdn::GDN_MAX_STATE];
12657                        run_head(bi, hi, s_mat, &mut sk_buf[..n]);
12658                    });
12659            }
12660        }
12661    }
12662}
12663
12664/// Host-fallback: `Op::RmsNormBackwardInput` (GPU unified-memory / D2H arenas).
12665pub unsafe fn execute_rms_norm_backward_input_f32(
12666    x: usize,
12667    gamma: usize,
12668    beta: usize,
12669    dy: usize,
12670    dx: usize,
12671    rows: u32,
12672    h: u32,
12673    eps: f32,
12674    base: *mut u8,
12675) {
12676    let (rows, h) = (rows as usize, h as usize);
12677    let mut dg = vec![0f32; h];
12678    let mut db = vec![0f32; h];
12679    let xs = sl(x, base, rows * h);
12680    let dys = sl(dy, base, rows * h);
12681    let g = sl(gamma, base, h);
12682    let b = sl(beta, base, h);
12683    let out = sl_mut(dx, base, rows * h);
12684    for r in 0..rows {
12685        crate::training_bwd::rms_norm_backward_row(
12686            &xs[r * h..(r + 1) * h],
12687            g,
12688            b,
12689            &dys[r * h..(r + 1) * h],
12690            &mut out[r * h..(r + 1) * h],
12691            &mut dg,
12692            &mut db,
12693            eps,
12694        );
12695    }
12696}
12697
12698pub unsafe fn execute_rms_norm_backward_gamma_f32(
12699    x: usize,
12700    gamma: usize,
12701    beta: usize,
12702    dy: usize,
12703    dgamma: usize,
12704    rows: u32,
12705    h: u32,
12706    eps: f32,
12707    base: *mut u8,
12708) {
12709    let (rows, h) = (rows as usize, h as usize);
12710    let out = sl_mut(dgamma, base, h);
12711    out.fill(0.0);
12712    let mut dx = vec![0f32; h];
12713    let mut db = vec![0f32; h];
12714    let xs = sl(x, base, rows * h);
12715    let dys = sl(dy, base, rows * h);
12716    let g = sl(gamma, base, h);
12717    let b = sl(beta, base, h);
12718    for r in 0..rows {
12719        crate::training_bwd::rms_norm_backward_row(
12720            &xs[r * h..(r + 1) * h],
12721            g,
12722            b,
12723            &dys[r * h..(r + 1) * h],
12724            &mut dx,
12725            out,
12726            &mut db,
12727            eps,
12728        );
12729    }
12730}
12731
12732pub unsafe fn execute_rms_norm_backward_beta_f32(
12733    x: usize,
12734    gamma: usize,
12735    beta: usize,
12736    dy: usize,
12737    dbeta: usize,
12738    rows: u32,
12739    h: u32,
12740    eps: f32,
12741    base: *mut u8,
12742) {
12743    let (rows, h) = (rows as usize, h as usize);
12744    let out = sl_mut(dbeta, base, h);
12745    out.fill(0.0);
12746    let mut dx = vec![0f32; h];
12747    let mut dg = vec![0f32; h];
12748    let xs = sl(x, base, rows * h);
12749    let dys = sl(dy, base, rows * h);
12750    let g = sl(gamma, base, h);
12751    let b = sl(beta, base, h);
12752    for r in 0..rows {
12753        crate::training_bwd::rms_norm_backward_row(
12754            &xs[r * h..(r + 1) * h],
12755            g,
12756            b,
12757            &dys[r * h..(r + 1) * h],
12758            &mut dx,
12759            &mut dg,
12760            out,
12761            eps,
12762        );
12763    }
12764}
12765
12766pub unsafe fn execute_rope_backward_f32(
12767    dy: usize,
12768    cos: usize,
12769    sin: usize,
12770    dx: usize,
12771    batch: u32,
12772    seq: u32,
12773    hidden: u32,
12774    head_dim: u32,
12775    n_rot: u32,
12776    cos_len: u32,
12777    base: *mut u8,
12778) {
12779    let (b, s, hs, dh, nr, cl) = (
12780        batch as usize,
12781        seq as usize,
12782        hidden as usize,
12783        head_dim as usize,
12784        n_rot as usize,
12785        cos_len as usize,
12786    );
12787    let nh = hs / dh;
12788    let tab_half = dh / 2;
12789    let dys = sl(dy, base, b * s * hs);
12790    let cos_tab = sl(cos, base, cl);
12791    let sin_tab = sl(sin, base, cl);
12792    let out = sl_mut(dx, base, b * s * hs);
12793    for bi in 0..b {
12794        for si in 0..s {
12795            let tab_off = si.saturating_mul(tab_half) % cl.max(1);
12796            let cp = &cos_tab[tab_off..tab_off + tab_half.min(cl)];
12797            let sp = &sin_tab[tab_off..tab_off + tab_half.min(cl)];
12798            for hi in 0..nh {
12799                let base_idx = bi * s * hs + si * hs + hi * dh;
12800                crate::training_bwd::rope_backward_row(
12801                    &dys[base_idx..base_idx + dh],
12802                    cp,
12803                    sp,
12804                    &mut out[base_idx..base_idx + dh],
12805                    dh,
12806                    nr,
12807                );
12808            }
12809        }
12810    }
12811}
12812
12813pub unsafe fn execute_cumsum_backward_f32(
12814    dy: usize,
12815    dx: usize,
12816    rows: u32,
12817    cols: u32,
12818    exclusive: bool,
12819    base: *mut u8,
12820) {
12821    let (rows, cols) = (rows as usize, cols as usize);
12822    let dys = sl(dy, base, rows * cols);
12823    let out = sl_mut(dx, base, rows * cols);
12824    for r in 0..rows {
12825        crate::training_bwd::cumsum_backward_row(
12826            &dys[r * cols..(r + 1) * cols],
12827            &mut out[r * cols..(r + 1) * cols],
12828            exclusive,
12829        );
12830    }
12831}
12832
12833pub unsafe fn execute_gather_backward_f32(
12834    dy: usize,
12835    indices: usize,
12836    dst: usize,
12837    outer: u32,
12838    axis_dim: u32,
12839    num_idx: u32,
12840    trailing: u32,
12841    base: *mut u8,
12842) {
12843    let (outer, axis_dim, num_idx, trailing) = (
12844        outer as usize,
12845        axis_dim as usize,
12846        num_idx as usize,
12847        trailing as usize,
12848    );
12849    let out = sl_mut(dst, base, outer * axis_dim * trailing);
12850    out.fill(0.0);
12851    crate::training_bwd::gather_axis_backward(
12852        sl(dy, base, outer * num_idx * trailing),
12853        sl(indices, base, num_idx),
12854        out,
12855        outer,
12856        axis_dim,
12857        num_idx,
12858        trailing,
12859    );
12860}
12861
12862/// Host-fallback entry for GGUF `Op::DequantMatMul` (Metal unified memory).
12863pub unsafe fn execute_dequant_matmul_gguf_f32(
12864    x: usize,
12865    w_q: usize,
12866    dst: usize,
12867    m: usize,
12868    k: usize,
12869    n: usize,
12870    scheme: rlx_ir::quant::QuantScheme,
12871    base: *mut u8,
12872) {
12873    unsafe {
12874        let block_bytes = scheme.gguf_block_bytes() as usize;
12875        let block_elems = scheme.gguf_block_size() as usize;
12876        let total_bytes = (k * n) / block_elems * block_bytes;
12877        let xs = sl(x, base, m * k);
12878        let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, total_bytes);
12879        let out = sl_mut(dst, base, m * n);
12880        crate::gguf_matmul::gguf_matmul_bt(xs, w_bytes, out, m, k, n, scheme);
12881    }
12882}
12883
12884/// Host-fallback entry for GGUF `Op::DequantGroupedMatMul` (MoE expert stack).
12885pub unsafe fn execute_dequant_grouped_matmul_gguf_f32(
12886    input: usize,
12887    w_q: usize,
12888    expert_idx: usize,
12889    dst: usize,
12890    m: usize,
12891    k: usize,
12892    n: usize,
12893    num_experts: usize,
12894    scheme: rlx_ir::quant::QuantScheme,
12895    base: *mut u8,
12896) {
12897    unsafe {
12898        let block_bytes = scheme.gguf_block_bytes() as usize;
12899        let block_elems = scheme.gguf_block_size() as usize;
12900        let slab_bytes = (k * n) / block_elems * block_bytes;
12901        let xs = sl(input, base, m * k);
12902        let w_bytes =
12903            std::slice::from_raw_parts(base.add(w_q) as *const u8, num_experts * slab_bytes);
12904        let ids = sl(expert_idx, base, m);
12905        let out = sl_mut(dst, base, m * n);
12906        crate::gguf_matmul::gguf_grouped_matmul_bt(
12907            xs,
12908            w_bytes,
12909            ids,
12910            out,
12911            m,
12912            k,
12913            n,
12914            num_experts,
12915            scheme,
12916        );
12917    }
12918}
12919
12920/// Host-fallback entry for Int4 `Op::DequantMatMul` (Metal unified memory).
12921pub unsafe fn execute_dequant_matmul_int4_f32(
12922    x: usize,
12923    w_q: usize,
12924    scale: usize,
12925    zp: usize,
12926    dst: usize,
12927    m: usize,
12928    k: usize,
12929    n: usize,
12930    block_size: u32,
12931    is_asymmetric: bool,
12932    base: *mut u8,
12933) {
12934    let bs = block_size as usize;
12935    let n_blocks = k.div_ceil(bs);
12936    unsafe {
12937        let xs = sl(x, base, m * k);
12938        let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, (k * n).div_ceil(2));
12939        let scales = sl(scale, base, n_blocks * n);
12940        let zps = if is_asymmetric {
12941            sl(zp, base, n_blocks * n)
12942        } else {
12943            &[][..]
12944        };
12945        let out = sl_mut(dst, base, m * n);
12946        dequant_matmul_int4(xs, w_bytes, scales, zps, out, m, k, n, bs, is_asymmetric);
12947    }
12948}
12949
12950/// Host-fallback entry for FP8 `Op::DequantMatMul` (Metal unified memory).
12951pub unsafe fn execute_dequant_matmul_fp8_f32(
12952    x: usize,
12953    w_q: usize,
12954    scale: usize,
12955    dst: usize,
12956    m: usize,
12957    k: usize,
12958    n: usize,
12959    e5m2: bool,
12960    base: *mut u8,
12961) {
12962    unsafe {
12963        let xs = sl(x, base, m * k);
12964        let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, k * n);
12965        let scales = sl(scale, base, n);
12966        let out = sl_mut(dst, base, m * n);
12967        dequant_matmul_fp8(xs, w_bytes, scales, out, m, k, n, e5m2);
12968    }
12969}
12970
12971/// Host-fallback entry for NVFP4 `Op::DequantMatMul` (Metal unified memory).
12972pub unsafe fn execute_dequant_matmul_nvfp4_f32(
12973    x: usize,
12974    w_q: usize,
12975    scale: usize,
12976    global_scale: usize,
12977    dst: usize,
12978    m: usize,
12979    k: usize,
12980    n: usize,
12981    base: *mut u8,
12982) {
12983    let n_scale = k.div_ceil(rlx_ir::NVFP4_GROUP_SIZE) * n;
12984    unsafe {
12985        let xs = sl(x, base, m * k);
12986        let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, (k * n).div_ceil(2));
12987        let scale_bytes = std::slice::from_raw_parts(base.add(scale) as *const u8, n_scale);
12988        let gs = sl(global_scale, base, 1)[0];
12989        let out = sl_mut(dst, base, m * n);
12990        dequant_matmul_nvfp4(xs, w_bytes, scale_bytes, gs, out, m, k, n);
12991    }
12992}
12993
12994/// Host-fallback entry for f16 `Op::GatedDeltaNet` tensors on Metal.
12995pub unsafe fn execute_gated_delta_net_f16(
12996    q: usize,
12997    k: usize,
12998    v: usize,
12999    g: usize,
13000    beta: usize,
13001    state: usize,
13002    dst: usize,
13003    batch: usize,
13004    seq: usize,
13005    heads: usize,
13006    state_size: usize,
13007    base: *mut u8,
13008) {
13009    use half::f16;
13010    unsafe {
13011        let read_f16 = |off: usize, len: usize| -> Vec<f32> {
13012            let raw = std::slice::from_raw_parts(base.add(off) as *const u8, len * 2);
13013            raw.chunks_exact(2)
13014                .map(|c| f16::from_le_bytes([c[0], c[1]]).to_f32())
13015                .collect()
13016        };
13017        let write_f16 = |off: usize, data: &[f32]| {
13018            let out = std::slice::from_raw_parts_mut(base.add(off), data.len() * 2);
13019            for (i, &v) in data.iter().enumerate() {
13020                let le = f16::from_f32(v).to_le_bytes();
13021                out[i * 2] = le[0];
13022                out[i * 2 + 1] = le[1];
13023            }
13024        };
13025
13026        let (b, s, h, n) = (batch, seq, heads, state_size);
13027        let q_f = read_f16(q, b * s * h * n);
13028        let k_f = read_f16(k, b * s * h * n);
13029        let v_f = read_f16(v, b * s * h * n);
13030        let g_f = read_f16(g, b * s * h);
13031        let b_f = read_f16(beta, b * s * h);
13032        let mut state_f = if state != 0 {
13033            read_f16(state, b * h * n * n)
13034        } else {
13035            vec![0f32; b * h * n * n]
13036        };
13037        let mut out_f = vec![0f32; b * s * h * n];
13038        let scale = 1.0f32 / (n as f32).sqrt();
13039        let mut sk_buf = vec![0f32; n];
13040        let mut owned_state = vec![0f32; h * n * n];
13041
13042        for bi in 0..b {
13043            let state_slice: &mut [f32] = if state != 0 {
13044                let start = bi * h * n * n;
13045                &mut state_f[start..start + h * n * n]
13046            } else {
13047                owned_state.fill(0.0);
13048                &mut owned_state
13049            };
13050
13051            for ti in 0..s {
13052                let qkv_step_base = bi * s * h * n + ti * h * n;
13053                let gb_step_base = bi * s * h + ti * h;
13054
13055                for hi in 0..h {
13056                    let q_row = &q_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
13057                    let k_row = &k_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
13058                    let v_row = &v_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
13059                    let g_t = g_f[gb_step_base + hi];
13060                    let beta_t = b_f[gb_step_base + hi];
13061
13062                    let s_base = hi * n * n;
13063                    let s_mat = &mut state_slice[s_base..s_base + n * n];
13064
13065                    let g_exp = g_t.exp();
13066                    for st in s_mat.iter_mut() {
13067                        *st *= g_exp;
13068                    }
13069
13070                    for j in 0..n {
13071                        let mut acc = 0f32;
13072                        for i in 0..n {
13073                            acc += s_mat[i * n + j] * k_row[i];
13074                        }
13075                        sk_buf[j] = acc;
13076                    }
13077
13078                    for j in 0..n {
13079                        sk_buf[j] = (v_row[j] - sk_buf[j]) * beta_t;
13080                    }
13081
13082                    for i in 0..n {
13083                        let ki = k_row[i];
13084                        for j in 0..n {
13085                            s_mat[i * n + j] += ki * sk_buf[j];
13086                        }
13087                    }
13088
13089                    let out_row = &mut out_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
13090                    for j in 0..n {
13091                        let mut acc = 0f32;
13092                        for i in 0..n {
13093                            acc += s_mat[i * n + j] * q_row[i];
13094                        }
13095                        out_row[j] = acc * scale;
13096                    }
13097                }
13098            }
13099        }
13100
13101        write_f16(dst, &out_f);
13102        if state != 0 {
13103            write_f16(state, &state_f);
13104        }
13105    }
13106}
13107
13108/// Host fallback for NCHW group norm (Metal unified-memory arena).
13109pub unsafe fn execute_group_norm_nchw_f32(
13110    src: usize,
13111    g: usize,
13112    b: usize,
13113    dst: usize,
13114    n: usize,
13115    c: usize,
13116    h: usize,
13117    w: usize,
13118    num_groups: usize,
13119    eps: f32,
13120    base: *mut u8,
13121) {
13122    let plane = c * h * w;
13123    for ni in 0..n {
13124        let input = unsafe { sl(src + ni * plane * std::mem::size_of::<f32>(), base, plane) };
13125        let gamma = unsafe { sl(g, base, c) };
13126        let beta = unsafe { sl(b, base, c) };
13127        let output = unsafe { sl_mut(dst + ni * plane * std::mem::size_of::<f32>(), base, plane) };
13128        crate::kernels::group_norm_nchw(input, gamma, beta, output, 1, c, h, w, num_groups, eps);
13129    }
13130}
13131
13132/// Host fallback for NCHW LayerNorm2d (SAM / candle semantics).
13133pub unsafe fn execute_layer_norm2d_nchw_f32(
13134    src: usize,
13135    g: usize,
13136    b: usize,
13137    dst: usize,
13138    n: usize,
13139    c: usize,
13140    h: usize,
13141    w: usize,
13142    eps: f32,
13143    base: *mut u8,
13144) {
13145    let plane = c * h * w;
13146    unsafe {
13147        let input = sl(src, base, n * plane);
13148        let gamma = sl(g, base, c);
13149        let beta = sl(b, base, c);
13150        let output = sl_mut(dst, base, n * plane);
13151        crate::kernels::layer_norm2d_nchw(input, gamma, beta, output, n, c, h, w, eps);
13152    }
13153}
13154
13155/// Host fallback for NCHW ConvTranspose2d.
13156pub unsafe fn execute_conv_transpose2d_nchw_f32(
13157    src: usize,
13158    weight: usize,
13159    dst: usize,
13160    n: usize,
13161    c_in: usize,
13162    h: usize,
13163    w_in: usize,
13164    c_out: usize,
13165    h_out: usize,
13166    w_out: usize,
13167    kh: usize,
13168    kw: usize,
13169    sh: usize,
13170    sw: usize,
13171    ph: usize,
13172    pw: usize,
13173    dh: usize,
13174    dw: usize,
13175    groups: usize,
13176    base: *mut u8,
13177) {
13178    let in_elems = n * c_in * h * w_in;
13179    let w_elems = c_in * (c_out / groups) * kh * kw;
13180    let out_elems = n * c_out * h_out * w_out;
13181    unsafe {
13182        let input = sl(src, base, in_elems);
13183        let wt = sl(weight, base, w_elems);
13184        let output = sl_mut(dst, base, out_elems);
13185        crate::kernels::conv_transpose2d_nchw(
13186            input, wt, output, n, c_in, h, w_in, c_out, h_out, w_out, kh, kw, sh, sw, ph, pw, dh,
13187            dw, groups,
13188        );
13189    }
13190}
13191
13192/// Host fallback for nearest 2× upsample on NCHW.
13193pub unsafe fn execute_resize_nearest_2x_f32(
13194    src: usize,
13195    dst: usize,
13196    n: usize,
13197    c: usize,
13198    h: usize,
13199    w: usize,
13200    base: *mut u8,
13201) {
13202    let in_plane = c * h * w;
13203    let out_plane = c * h * 2 * w * 2;
13204    for ni in 0..n {
13205        let input = unsafe {
13206            sl(
13207                src + ni * in_plane * std::mem::size_of::<f32>(),
13208                base,
13209                in_plane,
13210            )
13211        };
13212        let output = unsafe {
13213            sl_mut(
13214                dst + ni * out_plane * std::mem::size_of::<f32>(),
13215                base,
13216                out_plane,
13217            )
13218        };
13219        crate::kernels::resize_nearest_2x_nchw(input, output, c, h, w);
13220    }
13221}
13222
13223/// Host axial 2-D RoPE for Metal (and other) fallbacks on unified memory.
13224pub unsafe fn execute_axial_rope2d_f32(
13225    src: usize,
13226    dst: usize,
13227    batch: usize,
13228    seq: usize,
13229    hidden: usize,
13230    end_x: usize,
13231    end_y: usize,
13232    head_dim: usize,
13233    num_heads: usize,
13234    theta: f32,
13235    repeat_factor: usize,
13236    base: *mut u8,
13237) {
13238    let plane = seq * hidden;
13239    let plane_bytes = plane * std::mem::size_of::<f32>();
13240    for bi in 0..batch {
13241        let in_off = src + bi * plane_bytes;
13242        let input = unsafe { sl(in_off, base, plane) };
13243        let rotated = rlx_ir::ops::axial_rope2d::apply_axial_rope2d(
13244            input,
13245            num_heads,
13246            seq,
13247            head_dim,
13248            end_x,
13249            end_y,
13250            theta,
13251            repeat_factor,
13252        );
13253        let out_off = dst + bi * plane_bytes;
13254        let output = unsafe { sl_mut(out_off, base, plane) };
13255        output.copy_from_slice(&rotated);
13256    }
13257}
13258
13259/// f32 mirror of `execute_fft1d_f64`. Same public-host-fallback role.
13260pub unsafe fn execute_fft1d_f32(
13261    src: usize,
13262    dst: usize,
13263    outer: usize,
13264    n_complex: usize,
13265    inverse: bool,
13266    base: *mut u8,
13267) {
13268    let row_elems = 2 * n_complex;
13269    let mut re = vec![0f32; n_complex];
13270    let mut im = vec![0f32; n_complex];
13271    let mut scratch = if n_complex.is_power_of_two() {
13272        BluesteinScratchF32::empty()
13273    } else {
13274        BluesteinScratchF32::build(n_complex, inverse)
13275    };
13276    for o in 0..outer {
13277        let row_offset = src + o * row_elems * std::mem::size_of::<f32>();
13278        let s = unsafe { sl(row_offset, base, row_elems) };
13279        re.copy_from_slice(&s[..n_complex]);
13280        im.copy_from_slice(&s[n_complex..]);
13281        if n_complex.is_power_of_two() {
13282            fft_radix2_inplace_f32(&mut re, &mut im, inverse);
13283        } else {
13284            fft_bluestein_inplace_f32(&mut re, &mut im, inverse, &mut scratch);
13285        }
13286        let dst_offset = dst + o * row_elems * std::mem::size_of::<f32>();
13287        let d = unsafe { sl_mut(dst_offset, base, row_elems) };
13288        d[..n_complex].copy_from_slice(&re);
13289        d[n_complex..].copy_from_slice(&im);
13290    }
13291}
13292
13293/// f32 in-place radix-2 DIT Cooley-Tukey. Structurally identical to
13294/// the f64 path; twiddle recurrence is kept in f64 so accumulated
13295/// rotation drift doesn't dominate the per-stage error budget at
13296/// larger N.
13297fn fft_radix2_inplace_f32(re: &mut [f32], im: &mut [f32], inverse: bool) {
13298    let n = re.len();
13299    debug_assert_eq!(im.len(), n);
13300    debug_assert!(
13301        n.is_power_of_two(),
13302        "fft_radix2_f32: n={n} must be a power of two"
13303    );
13304    if n <= 1 {
13305        return;
13306    }
13307
13308    let mut j = 0usize;
13309    for i in 1..n {
13310        let mut bit = n >> 1;
13311        while j & bit != 0 {
13312            j ^= bit;
13313            bit >>= 1;
13314        }
13315        j ^= bit;
13316        if i < j {
13317            re.swap(i, j);
13318            im.swap(i, j);
13319        }
13320    }
13321
13322    let sign = if inverse { 1.0_f64 } else { -1.0_f64 };
13323    let mut len = 2usize;
13324    while len <= n {
13325        let half = len / 2;
13326        let theta = sign * 2.0 * std::f64::consts::PI / (len as f64);
13327        let w_re_step = theta.cos();
13328        let w_im_step = theta.sin();
13329        let mut i = 0usize;
13330        while i < n {
13331            let mut wre = 1.0_f64;
13332            let mut wim = 0.0_f64;
13333            for k in 0..half {
13334                let wre_f = wre as f32;
13335                let wim_f = wim as f32;
13336                let t_re = wre_f * re[i + k + half] - wim_f * im[i + k + half];
13337                let t_im = wre_f * im[i + k + half] + wim_f * re[i + k + half];
13338                let u_re = re[i + k];
13339                let u_im = im[i + k];
13340                re[i + k] = u_re + t_re;
13341                im[i + k] = u_im + t_im;
13342                re[i + k + half] = u_re - t_re;
13343                im[i + k + half] = u_im - t_im;
13344                let new_wre = wre * w_re_step - wim * w_im_step;
13345                let new_wim = wre * w_im_step + wim * w_re_step;
13346                wre = new_wre;
13347                wim = new_wim;
13348            }
13349            i += len;
13350        }
13351        len <<= 1;
13352    }
13353}
13354
13355/// In-place radix-2 DIT Cooley-Tukey FFT on split (real, imag) f64
13356/// arrays. `n = re.len() = im.len()` must be a power of two. Forward
13357/// uses ω = exp(-2πi/n); inverse uses ω = exp(+2πi/n) (no 1/N scale).
13358fn fft_radix2_inplace_f64(re: &mut [f64], im: &mut [f64], inverse: bool) {
13359    let n = re.len();
13360    debug_assert_eq!(im.len(), n);
13361    debug_assert!(
13362        n.is_power_of_two(),
13363        "fft_radix2: n={n} must be a power of two"
13364    );
13365    if n <= 1 {
13366        return;
13367    }
13368
13369    // Bit-reverse permutation.
13370    let mut j = 0usize;
13371    for i in 1..n {
13372        let mut bit = n >> 1;
13373        while j & bit != 0 {
13374            j ^= bit;
13375            bit >>= 1;
13376        }
13377        j ^= bit;
13378        if i < j {
13379            re.swap(i, j);
13380            im.swap(i, j);
13381        }
13382    }
13383
13384    // Cooley-Tukey butterflies: ω_len = exp(±2πi/len).
13385    let sign = if inverse { 1.0 } else { -1.0 };
13386    let mut len = 2usize;
13387    while len <= n {
13388        let half = len / 2;
13389        let theta = sign * 2.0 * std::f64::consts::PI / (len as f64);
13390        let w_re_step = theta.cos();
13391        let w_im_step = theta.sin();
13392        let mut i = 0usize;
13393        while i < n {
13394            // Twiddle starts at 1+0i for each segment.
13395            let mut wre = 1.0_f64;
13396            let mut wim = 0.0_f64;
13397            for k in 0..half {
13398                let t_re = wre * re[i + k + half] - wim * im[i + k + half];
13399                let t_im = wre * im[i + k + half] + wim * re[i + k + half];
13400                let u_re = re[i + k];
13401                let u_im = im[i + k];
13402                re[i + k] = u_re + t_re;
13403                im[i + k] = u_im + t_im;
13404                re[i + k + half] = u_re - t_re;
13405                im[i + k + half] = u_im - t_im;
13406                let new_wre = wre * w_re_step - wim * w_im_step;
13407                let new_wim = wre * w_im_step + wim * w_re_step;
13408                wre = new_wre;
13409                wim = new_wim;
13410            }
13411            i += len;
13412        }
13413        len <<= 1;
13414    }
13415}
13416
13417/// Pre-computed chirp + filter-spectrum for one (N, direction) pair.
13418/// Built once per call to `execute_fft1d_f64` and reused across rows
13419/// when `outer > 1` — the chirp and FFT(b) don't depend on the input.
13420struct BluesteinScratchF64 {
13421    /// Power-of-two convolution length, ≥ 2N - 1.
13422    m: usize,
13423    /// `w[k] = exp(sign · iπ · k² / N)` for k=0..N, where sign matches
13424    /// the requested direction. Forward chirp on the way in, output
13425    /// chirp on the way out.
13426    w_re: Vec<f64>,
13427    w_im: Vec<f64>,
13428    /// FFT of the embedded filter `b[k] = conj(w[|k|])` in length-M.
13429    /// Doesn't depend on the input — precomputed once.
13430    bf_re: Vec<f64>,
13431    bf_im: Vec<f64>,
13432    /// Workspace reused per row (avoids per-row allocation).
13433    ar: Vec<f64>,
13434    ai: Vec<f64>,
13435}
13436
13437impl BluesteinScratchF64 {
13438    fn empty() -> Self {
13439        Self {
13440            m: 0,
13441            w_re: Vec::new(),
13442            w_im: Vec::new(),
13443            bf_re: Vec::new(),
13444            bf_im: Vec::new(),
13445            ar: Vec::new(),
13446            ai: Vec::new(),
13447        }
13448    }
13449
13450    fn build(n: usize, inverse: bool) -> Self {
13451        // M = next power of two ≥ 2N - 1 keeps the inner FFT on the
13452        // fast radix-2 path. For N=1 fall back to M=1 (no-op convolution).
13453        let m = if n <= 1 {
13454            1
13455        } else {
13456            (2 * n - 1).next_power_of_two()
13457        };
13458
13459        // Chirp arg reduced via k² mod 2N — without this, large N
13460        // bleeds precision into the trig call (n² grows quadratically).
13461        let mod_2n = (2 * n) as u64;
13462        let sign = if inverse { 1.0_f64 } else { -1.0_f64 };
13463        let mut w_re = vec![0.0_f64; n];
13464        let mut w_im = vec![0.0_f64; n];
13465        for k in 0..n {
13466            let k2 = (k as u64).wrapping_mul(k as u64) % mod_2n;
13467            let theta = sign * std::f64::consts::PI * (k2 as f64) / (n as f64);
13468            w_re[k] = theta.cos();
13469            w_im[k] = theta.sin();
13470        }
13471
13472        // Embed b[k] = conj(w[|k|]) into length M with the negative
13473        // indices wrapping to the tail: b[-j] → B[M-j] for j=1..N-1.
13474        let mut bf_re = vec![0.0_f64; m];
13475        let mut bf_im = vec![0.0_f64; m];
13476        if n > 0 {
13477            bf_re[0] = w_re[0];
13478            bf_im[0] = -w_im[0];
13479            for k in 1..n {
13480                bf_re[k] = w_re[k];
13481                bf_im[k] = -w_im[k];
13482                bf_re[m - k] = w_re[k];
13483                bf_im[m - k] = -w_im[k];
13484            }
13485        }
13486        if m > 1 {
13487            fft_radix2_inplace_f64(&mut bf_re, &mut bf_im, false);
13488        }
13489
13490        Self {
13491            m,
13492            w_re,
13493            w_im,
13494            bf_re,
13495            bf_im,
13496            ar: vec![0.0_f64; m],
13497            ai: vec![0.0_f64; m],
13498        }
13499    }
13500}
13501
13502/// Bluestein (chirp-z) FFT for arbitrary N. Identity used:
13503///   `n·k = (n² + k² - (k-n)²) / 2`
13504/// which lets the DFT be written as a linear convolution sandwiched
13505/// between two chirp multiplies:
13506///   `X[k] = w[k] · ((x·w) ⊛ conj(w))[k]`   where `w[n] = exp(±iπ·n²/N)`.
13507/// The convolution is computed via a length-M radix-2 FFT (M ≥ 2N-1).
13508/// Both directions stay unnormalized to match the radix-2 path, so the
13509/// chain rule keeps working without scaling.
13510fn fft_bluestein_inplace_f64(
13511    re: &mut [f64],
13512    im: &mut [f64],
13513    _inverse: bool,
13514    s: &mut BluesteinScratchF64,
13515) {
13516    let n = re.len();
13517    debug_assert_eq!(im.len(), n);
13518    debug_assert_eq!(s.w_re.len(), n);
13519    if n <= 1 {
13520        return;
13521    }
13522    let m = s.m;
13523
13524    // Pre-chirp: a[k] = x[k] · w[k], zero-padded to M.
13525    for k in 0..m {
13526        s.ar[k] = 0.0;
13527        s.ai[k] = 0.0;
13528    }
13529    for k in 0..n {
13530        s.ar[k] = re[k] * s.w_re[k] - im[k] * s.w_im[k];
13531        s.ai[k] = re[k] * s.w_im[k] + im[k] * s.w_re[k];
13532    }
13533
13534    // Length-M forward FFT of the padded chirped input.
13535    fft_radix2_inplace_f64(&mut s.ar, &mut s.ai, false);
13536
13537    // Pointwise product with FFT(b). Stored back into (ar, ai).
13538    for k in 0..m {
13539        let ar = s.ar[k];
13540        let ai = s.ai[k];
13541        let br = s.bf_re[k];
13542        let bi = s.bf_im[k];
13543        s.ar[k] = ar * br - ai * bi;
13544        s.ai[k] = ar * bi + ai * br;
13545    }
13546
13547    // Inverse FFT — radix-2 here is the unnormalized inverse, so we
13548    // divide by M to recover the true circular convolution.
13549    fft_radix2_inplace_f64(&mut s.ar, &mut s.ai, true);
13550    let inv_m = 1.0 / (m as f64);
13551
13552    // Post-chirp: X[k] = w[k] · Y[k] / M for k = 0..N.
13553    for k in 0..n {
13554        let yr = s.ar[k] * inv_m;
13555        let yi = s.ai[k] * inv_m;
13556        re[k] = yr * s.w_re[k] - yi * s.w_im[k];
13557        im[k] = yr * s.w_im[k] + yi * s.w_re[k];
13558    }
13559}
13560
13561/// f32 mirror of `BluesteinScratchF64`. Chirp is computed in f64 for
13562/// precision (same justification as the radix-2 f32 path: twiddles in
13563/// f64, butterflies in f32). The actual conv buffers are f32.
13564struct BluesteinScratchF32 {
13565    m: usize,
13566    w_re: Vec<f32>,
13567    w_im: Vec<f32>,
13568    bf_re: Vec<f32>,
13569    bf_im: Vec<f32>,
13570    ar: Vec<f32>,
13571    ai: Vec<f32>,
13572}
13573
13574impl BluesteinScratchF32 {
13575    fn empty() -> Self {
13576        Self {
13577            m: 0,
13578            w_re: Vec::new(),
13579            w_im: Vec::new(),
13580            bf_re: Vec::new(),
13581            bf_im: Vec::new(),
13582            ar: Vec::new(),
13583            ai: Vec::new(),
13584        }
13585    }
13586
13587    fn build(n: usize, inverse: bool) -> Self {
13588        let m = if n <= 1 {
13589            1
13590        } else {
13591            (2 * n - 1).next_power_of_two()
13592        };
13593
13594        let mod_2n = (2 * n) as u64;
13595        let sign = if inverse { 1.0_f64 } else { -1.0_f64 };
13596        let mut w_re = vec![0.0_f32; n];
13597        let mut w_im = vec![0.0_f32; n];
13598        for k in 0..n {
13599            let k2 = (k as u64).wrapping_mul(k as u64) % mod_2n;
13600            let theta = sign * std::f64::consts::PI * (k2 as f64) / (n as f64);
13601            w_re[k] = theta.cos() as f32;
13602            w_im[k] = theta.sin() as f32;
13603        }
13604
13605        let mut bf_re = vec![0.0_f32; m];
13606        let mut bf_im = vec![0.0_f32; m];
13607        if n > 0 {
13608            bf_re[0] = w_re[0];
13609            bf_im[0] = -w_im[0];
13610            for k in 1..n {
13611                bf_re[k] = w_re[k];
13612                bf_im[k] = -w_im[k];
13613                bf_re[m - k] = w_re[k];
13614                bf_im[m - k] = -w_im[k];
13615            }
13616        }
13617        if m > 1 {
13618            fft_radix2_inplace_f32(&mut bf_re, &mut bf_im, false);
13619        }
13620
13621        Self {
13622            m,
13623            w_re,
13624            w_im,
13625            bf_re,
13626            bf_im,
13627            ar: vec![0.0_f32; m],
13628            ai: vec![0.0_f32; m],
13629        }
13630    }
13631}
13632
13633fn fft_bluestein_inplace_f32(
13634    re: &mut [f32],
13635    im: &mut [f32],
13636    _inverse: bool,
13637    s: &mut BluesteinScratchF32,
13638) {
13639    let n = re.len();
13640    debug_assert_eq!(im.len(), n);
13641    debug_assert_eq!(s.w_re.len(), n);
13642    if n <= 1 {
13643        return;
13644    }
13645    let m = s.m;
13646
13647    for k in 0..m {
13648        s.ar[k] = 0.0;
13649        s.ai[k] = 0.0;
13650    }
13651    for k in 0..n {
13652        s.ar[k] = re[k] * s.w_re[k] - im[k] * s.w_im[k];
13653        s.ai[k] = re[k] * s.w_im[k] + im[k] * s.w_re[k];
13654    }
13655
13656    fft_radix2_inplace_f32(&mut s.ar, &mut s.ai, false);
13657
13658    for k in 0..m {
13659        let ar = s.ar[k];
13660        let ai = s.ai[k];
13661        let br = s.bf_re[k];
13662        let bi = s.bf_im[k];
13663        s.ar[k] = ar * br - ai * bi;
13664        s.ai[k] = ar * bi + ai * br;
13665    }
13666
13667    fft_radix2_inplace_f32(&mut s.ar, &mut s.ai, true);
13668    let inv_m = 1.0_f32 / (m as f32);
13669
13670    for k in 0..n {
13671        let yr = s.ar[k] * inv_m;
13672        let yi = s.ai[k] * inv_m;
13673        re[k] = yr * s.w_re[k] - yi * s.w_im[k];
13674        im[k] = yr * s.w_im[k] + yi * s.w_re[k];
13675    }
13676}
13677
13678/// Shared dispatch path for `Thunk::CustomOp`. Builds a typed
13679/// [`CpuTensorRef`] for each input *at that input's declared dtype*
13680/// (so a sparse-LU op with mixed F64/I32 inputs gets the right
13681/// typed slices) and a [`CpuTensorMut`] for the output, then calls
13682/// the kernel's single `execute` method.
13683unsafe fn dispatch_custom_op(
13684    kernel: &dyn crate::op_registry::CpuKernel,
13685    inputs: &[(usize, u32, Shape)],
13686    out_off: usize,
13687    out_len: u32,
13688    out_shape: &Shape,
13689    attrs: &[u8],
13690    base: *mut u8,
13691) {
13692    use crate::op_registry::{CpuTensorMut, CpuTensorRef};
13693    use rlx_ir::DType;
13694
13695    // One arm per `DType` variant — single source of truth for
13696    // "which dtypes the CPU custom-op dispatcher wires." If a new
13697    // DType lands in `rlx-ir`, the compiler flags this match as
13698    // non-exhaustive and the gap gets named at the right place.
13699    macro_rules! build_in_view {
13700        ($shape:expr, $off:expr, $n:expr, $variant:ident, $rust_ty:ty) => {
13701            CpuTensorRef::$variant {
13702                data: unsafe { sl_typed::<$rust_ty>($off, base, $n) },
13703                shape: $shape,
13704            }
13705        };
13706    }
13707    macro_rules! build_out_view {
13708        ($variant:ident, $rust_ty:ty) => {
13709            CpuTensorMut::$variant {
13710                data: unsafe { sl_mut_typed::<$rust_ty>(out_off, base, out_len as usize) },
13711                shape: out_shape,
13712            }
13713        };
13714    }
13715
13716    let in_views: Vec<CpuTensorRef<'_>> = inputs
13717        .iter()
13718        .map(|(off, len, shape)| {
13719            let n = *len as usize;
13720            let off = *off;
13721            match shape.dtype() {
13722                DType::F32 => build_in_view!(shape, off, n, F32, f32),
13723                DType::F64 => build_in_view!(shape, off, n, F64, f64),
13724                DType::F16 => build_in_view!(shape, off, n, F16, half::f16),
13725                DType::BF16 => build_in_view!(shape, off, n, BF16, half::bf16),
13726                DType::I8 => build_in_view!(shape, off, n, I8, i8),
13727                DType::I16 => build_in_view!(shape, off, n, I16, i16),
13728                DType::I32 => build_in_view!(shape, off, n, I32, i32),
13729                DType::I64 => build_in_view!(shape, off, n, I64, i64),
13730                DType::U8 => build_in_view!(shape, off, n, U8, u8),
13731                DType::U32 => build_in_view!(shape, off, n, U32, u32),
13732                DType::Bool => build_in_view!(shape, off, n, Bool, u8),
13733                // C64 isn't a CpuTensor variant today; the user-registered
13734                // op_registry path doesn't see complex inputs (those are
13735                // handled by built-in ops with dedicated kernels).
13736                DType::C64 => panic!(
13737                    "Op::Custom kernel input has DType::C64 — built-in \
13738                 complex ops handle their own kernels; user-registered \
13739                 ops don't yet see complex tensors"
13740                ),
13741            }
13742        })
13743        .collect();
13744
13745    let result = match out_shape.dtype() {
13746        DType::F32 => kernel.execute(&in_views, build_out_view!(F32, f32), attrs),
13747        DType::F64 => kernel.execute(&in_views, build_out_view!(F64, f64), attrs),
13748        DType::F16 => kernel.execute(&in_views, build_out_view!(F16, half::f16), attrs),
13749        DType::BF16 => kernel.execute(&in_views, build_out_view!(BF16, half::bf16), attrs),
13750        DType::I8 => kernel.execute(&in_views, build_out_view!(I8, i8), attrs),
13751        DType::I16 => kernel.execute(&in_views, build_out_view!(I16, i16), attrs),
13752        DType::I32 => kernel.execute(&in_views, build_out_view!(I32, i32), attrs),
13753        DType::I64 => kernel.execute(&in_views, build_out_view!(I64, i64), attrs),
13754        DType::U8 => kernel.execute(&in_views, build_out_view!(U8, u8), attrs),
13755        DType::U32 => kernel.execute(&in_views, build_out_view!(U32, u32), attrs),
13756        DType::Bool => kernel.execute(&in_views, build_out_view!(Bool, u8), attrs),
13757        DType::C64 => panic!("Op::Custom output DType::C64 not supported"),
13758    };
13759    if let Err(e) = result {
13760        panic!("Op::Custom('{}') CPU kernel failed: {e}", kernel.name());
13761    }
13762}
13763
13764/// Generic raw-cast slice helper. The existing per-dtype `sl_*` /
13765/// `sl_mut_*` helpers stay in place for the rest of `thunk.rs` (which
13766/// uses them at call sites with concrete dtypes); the custom-op
13767/// dispatcher uses these to enumerate every `DType` uniformly without
13768/// listing one helper per dtype.
13769#[inline(always)]
13770unsafe fn sl_typed<T>(offset: usize, base: *mut u8, len: usize) -> &'static [T] {
13771    if offset == usize::MAX {
13772        return &[];
13773    }
13774    unsafe { std::slice::from_raw_parts(base.add(offset) as *const T, len) }
13775}
13776
13777#[inline(always)]
13778unsafe fn sl_mut_typed<T>(offset: usize, base: *mut u8, len: usize) -> &'static mut [T] {
13779    unsafe { std::slice::from_raw_parts_mut(base.add(offset) as *mut T, len) }
13780}
13781
13782// Unsafe helpers to create slices from arena base + offset
13783#[inline(always)]
13784/// In-place per-element activation. Mirrors the dispatch in
13785/// `Thunk::ActivationInPlace`. Used by `Thunk::FusedMmBiasAct` to
13786/// apply the activation after `bias_add` for all non-Gelu cases.
13787fn apply_activation_inplace(d: &mut [f32], act: rlx_ir::op::Activation) {
13788    use rlx_ir::op::Activation;
13789    match act {
13790        Activation::Gelu => crate::kernels::par_gelu_inplace(d),
13791        Activation::GeluApprox => crate::kernels::par_gelu_approx_inplace(d),
13792        Activation::Silu => crate::kernels::par_silu_inplace(d),
13793        Activation::Relu => {
13794            for v in d.iter_mut() {
13795                *v = v.max(0.0);
13796            }
13797        }
13798        Activation::Sigmoid => {
13799            for v in d.iter_mut() {
13800                *v = 1.0 / (1.0 + (-*v).exp());
13801            }
13802        }
13803        Activation::Tanh => {
13804            for v in d.iter_mut() {
13805                *v = v.tanh();
13806            }
13807        }
13808        Activation::Exp => {
13809            for v in d.iter_mut() {
13810                *v = v.exp();
13811            }
13812        }
13813        Activation::Log => {
13814            for v in d.iter_mut() {
13815                *v = v.ln();
13816            }
13817        }
13818        Activation::Sqrt => {
13819            for v in d.iter_mut() {
13820                *v = v.sqrt();
13821            }
13822        }
13823        Activation::Rsqrt => {
13824            for v in d.iter_mut() {
13825                *v = 1.0 / v.sqrt();
13826            }
13827        }
13828        Activation::Neg => {
13829            for v in d.iter_mut() {
13830                *v = -*v;
13831            }
13832        }
13833        Activation::Abs => {
13834            for v in d.iter_mut() {
13835                *v = v.abs();
13836            }
13837        }
13838        Activation::Round => {
13839            for v in d.iter_mut() {
13840                *v = v.round();
13841            }
13842        }
13843        Activation::Sin => {
13844            for v in d.iter_mut() {
13845                *v = v.sin();
13846            }
13847        }
13848        Activation::Cos => {
13849            for v in d.iter_mut() {
13850                *v = v.cos();
13851            }
13852        }
13853        Activation::Tan => {
13854            for v in d.iter_mut() {
13855                *v = v.tan();
13856            }
13857        }
13858        Activation::Atan => {
13859            for v in d.iter_mut() {
13860                *v = v.atan();
13861            }
13862        }
13863    }
13864}
13865
13866/// im2col for one image (single batch + group slice).
13867///
13868/// Source `x` is `[c_in, H, W]` row-major. Destination `col` is
13869/// `[c_in · kH · kW, H_out · W_out]` row-major. Out-of-bounds positions
13870/// (in the padded region) are written as 0.
13871///
13872/// `col[(ci · kH · kW + ki · kW + kj) · n_dim + ho · W_out + wo] =
13873///    x[ci, ho·sh + ki·dh − ph, wo·sw + kj·dw_dil − pw]`
13874#[allow(clippy::too_many_arguments)]
13875fn im2col(
13876    x: &[f32],
13877    col: &mut [f32],
13878    c_in: usize,
13879    h: usize,
13880    w: usize,
13881    h_out: usize,
13882    w_out: usize,
13883    kh: usize,
13884    kw: usize,
13885    sh: usize,
13886    sw: usize,
13887    ph: usize,
13888    pw: usize,
13889    dh: usize,
13890    dw_dil: usize,
13891) {
13892    let n_dim = h_out * w_out;
13893    debug_assert_eq!(col.len(), c_in * kh * kw * n_dim);
13894    debug_assert_eq!(x.len(), c_in * h * w);
13895    let h_isz = h as isize;
13896    let w_isz = w as isize;
13897    let ph_isz = ph as isize;
13898    let pw_isz = pw as isize;
13899    for ci in 0..c_in {
13900        for ki in 0..kh {
13901            for kj in 0..kw {
13902                let row = ((ci * kh) + ki) * kw + kj;
13903                let row_off = row * n_dim;
13904                for ho in 0..h_out {
13905                    let hi = (ho * sh + ki * dh) as isize - ph_isz;
13906                    if hi < 0 || hi >= h_isz {
13907                        for wo in 0..w_out {
13908                            col[row_off + ho * w_out + wo] = 0.0;
13909                        }
13910                        continue;
13911                    }
13912                    let hi = hi as usize;
13913                    let in_row_off = (ci * h + hi) * w;
13914                    for wo in 0..w_out {
13915                        let wi = (wo * sw + kj * dw_dil) as isize - pw_isz;
13916                        col[row_off + ho * w_out + wo] = if wi < 0 || wi >= w_isz {
13917                            0.0
13918                        } else {
13919                            x[in_row_off + wi as usize]
13920                        };
13921                    }
13922                }
13923            }
13924        }
13925    }
13926}
13927
13928/// col2im — inverse of `im2col` with scatter-accumulation. The caller
13929/// is responsible for zeroing `x` if it doesn't already start zero
13930/// (the conv-input-grad path zeros once before the batch loop).
13931///
13932/// `x[ci, hi, wi] += col[(ci · kH · kW + ki · kW + kj) · n_dim + ho · W_out + wo]`
13933/// for all `(ki, kj, ho, wo)` whose `(hi, wi)` lands in `[0, H) × [0, W)`.
13934#[allow(clippy::too_many_arguments)]
13935fn col2im(
13936    col: &[f32],
13937    x: &mut [f32],
13938    c_in: usize,
13939    h: usize,
13940    w: usize,
13941    h_out: usize,
13942    w_out: usize,
13943    kh: usize,
13944    kw: usize,
13945    sh: usize,
13946    sw: usize,
13947    ph: usize,
13948    pw: usize,
13949    dh: usize,
13950    dw_dil: usize,
13951) {
13952    let n_dim = h_out * w_out;
13953    debug_assert_eq!(col.len(), c_in * kh * kw * n_dim);
13954    debug_assert_eq!(x.len(), c_in * h * w);
13955    let h_isz = h as isize;
13956    let w_isz = w as isize;
13957    let ph_isz = ph as isize;
13958    let pw_isz = pw as isize;
13959    for ci in 0..c_in {
13960        for ki in 0..kh {
13961            for kj in 0..kw {
13962                let row = ((ci * kh) + ki) * kw + kj;
13963                let row_off = row * n_dim;
13964                for ho in 0..h_out {
13965                    let hi = (ho * sh + ki * dh) as isize - ph_isz;
13966                    if hi < 0 || hi >= h_isz {
13967                        continue;
13968                    }
13969                    let hi = hi as usize;
13970                    let in_row_off = (ci * h + hi) * w;
13971                    for wo in 0..w_out {
13972                        let wi = (wo * sw + kj * dw_dil) as isize - pw_isz;
13973                        if wi < 0 || wi >= w_isz {
13974                            continue;
13975                        }
13976                        x[in_row_off + wi as usize] += col[row_off + ho * w_out + wo];
13977                    }
13978                }
13979            }
13980        }
13981    }
13982}
13983
13984/// Element-wise backward for `Op::Activation`. `xs` is the original
13985/// input to the forward activation; `dys` is the upstream gradient.
13986/// Writes `out[i] = (d/dx act(xs[i])) * dys[i]`.
13987/// Decompose a per-channel quantization shape into the
13988/// `(chan_axis, chan_dim, inner)` triplet the kernel needs to map a
13989/// flat output index to a channel index. Per-tensor (`axis = None`)
13990/// degenerates to `chan_dim = 1, inner = len`, which makes the
13991/// kernel's `(i / inner) % chan_dim` always 0 — same fast path the
13992/// scalar version used.
13993fn quant_layout(shape: &rlx_ir::Shape, axis: Option<usize>) -> (usize, usize, usize) {
13994    match axis {
13995        None => (0, 1, shape.num_elements().unwrap_or(0).max(1)),
13996        Some(d) => {
13997            let chan_dim = shape.dim(d).unwrap_static();
13998            let inner: usize = (d + 1..shape.rank())
13999                .map(|i| shape.dim(i).unwrap_static())
14000                .product::<usize>()
14001                .max(1);
14002            (d, chan_dim, inner)
14003        }
14004    }
14005}
14006
14007fn activation_backward_kernel(
14008    act: rlx_ir::op::Activation,
14009    xs: &[f32],
14010    dys: &[f32],
14011    out: &mut [f32],
14012) {
14013    use rlx_ir::op::Activation;
14014    let n = xs.len();
14015    debug_assert_eq!(dys.len(), n);
14016    debug_assert_eq!(out.len(), n);
14017    match act {
14018        Activation::Relu => {
14019            for i in 0..n {
14020                out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
14021            }
14022        }
14023        Activation::Sigmoid => {
14024            for i in 0..n {
14025                let s = 1.0 / (1.0 + (-xs[i]).exp());
14026                out[i] = s * (1.0 - s) * dys[i];
14027            }
14028        }
14029        Activation::Tanh => {
14030            for i in 0..n {
14031                let t = xs[i].tanh();
14032                out[i] = (1.0 - t * t) * dys[i];
14033            }
14034        }
14035        Activation::Silu => {
14036            // y = x * σ(x);  dy/dx = σ(x) * (1 + x * (1 - σ(x))).
14037            for i in 0..n {
14038                let s = 1.0 / (1.0 + (-xs[i]).exp());
14039                out[i] = s * (1.0 + xs[i] * (1.0 - s)) * dys[i];
14040            }
14041        }
14042        Activation::Gelu => {
14043            // Exact erf-based GELU:  y = 0.5 x (1 + erf(x / √2)).
14044            //   dy/dx = 0.5 (1 + erf(x/√2)) + (x / √(2π)) · exp(-x²/2)
14045            const INV_SQRT2: f32 = 0.707_106_77;
14046            const INV_SQRT_2PI: f32 = 0.398_942_3;
14047            for i in 0..n {
14048                let x = xs[i];
14049                let phi = 0.5 * (1.0 + erf_f32(x * INV_SQRT2));
14050                let pdf = INV_SQRT_2PI * (-(x * x) * 0.5).exp();
14051                out[i] = (phi + x * pdf) * dys[i];
14052            }
14053        }
14054        Activation::GeluApprox => {
14055            // Tanh-approximation:
14056            //   y = 0.5 x (1 + tanh(c · (x + 0.044715 x³))) where c = √(2/π).
14057            const C: f32 = 0.797_884_6; // √(2/π)
14058            const A: f32 = 0.044_715;
14059            for i in 0..n {
14060                let x = xs[i];
14061                let inner = C * (x + A * x * x * x);
14062                let t = inner.tanh();
14063                let dinner = C * (1.0 + 3.0 * A * x * x);
14064                let d = 0.5 * (1.0 + t) + 0.5 * x * (1.0 - t * t) * dinner;
14065                out[i] = d * dys[i];
14066            }
14067        }
14068        Activation::Exp => {
14069            for i in 0..n {
14070                out[i] = xs[i].exp() * dys[i];
14071            }
14072        }
14073        Activation::Log => {
14074            for i in 0..n {
14075                out[i] = dys[i] / xs[i];
14076            }
14077        }
14078        Activation::Sqrt => {
14079            // d/dx √x = 0.5 / √x — undefined at x=0; clamp to 0.
14080            for i in 0..n {
14081                let s = xs[i].sqrt();
14082                out[i] = if s > 0.0 { 0.5 * dys[i] / s } else { 0.0 };
14083            }
14084        }
14085        Activation::Rsqrt => {
14086            // d/dx (1/√x) = -0.5 · x^(-3/2).
14087            for i in 0..n {
14088                let s = xs[i].sqrt();
14089                out[i] = if s > 0.0 {
14090                    -0.5 * dys[i] / (xs[i] * s)
14091                } else {
14092                    0.0
14093                };
14094            }
14095        }
14096        Activation::Neg => {
14097            for i in 0..n {
14098                out[i] = -dys[i];
14099            }
14100        }
14101        Activation::Abs => {
14102            // sign(x); 0 at x=0.
14103            for i in 0..n {
14104                let x = xs[i];
14105                let s = if x > 0.0 {
14106                    1.0
14107                } else if x < 0.0 {
14108                    -1.0
14109                } else {
14110                    0.0
14111                };
14112                out[i] = s * dys[i];
14113            }
14114        }
14115        Activation::Round => {
14116            // STE: pretend the round was identity in the backward
14117            // pass. The round step has zero gradient almost
14118            // everywhere, so without this trick the optimizer can't
14119            // learn through it.
14120            out.copy_from_slice(dys);
14121        }
14122        Activation::Sin => {
14123            // d/dx sin(x) = cos(x).
14124            for i in 0..n {
14125                out[i] = xs[i].cos() * dys[i];
14126            }
14127        }
14128        Activation::Cos => {
14129            for i in 0..n {
14130                out[i] = -xs[i].sin() * dys[i];
14131            }
14132        }
14133        Activation::Tan => {
14134            // d/dx tan(x) = sec²(x) = 1 + tan²(x)
14135            for i in 0..n {
14136                let t = xs[i].tan();
14137                out[i] = (1.0 + t * t) * dys[i];
14138            }
14139        }
14140        Activation::Atan => {
14141            // d/dx atan(x) = 1 / (1 + x²)
14142            for i in 0..n {
14143                let x = xs[i];
14144                out[i] = dys[i] / (1.0 + x * x);
14145            }
14146        }
14147    }
14148}
14149
14150/// f64 sibling of `activation_backward_kernel`. Same math, twice the
14151/// precision — used by f64 graphs where the f32 kernel reading bytes
14152/// as `&[f32]` would silently discard half of every f64 value.
14153fn activation_backward_kernel_f64(
14154    act: rlx_ir::op::Activation,
14155    xs: &[f64],
14156    dys: &[f64],
14157    out: &mut [f64],
14158) {
14159    use rlx_ir::op::Activation;
14160    let n = xs.len();
14161    debug_assert_eq!(dys.len(), n);
14162    debug_assert_eq!(out.len(), n);
14163    match act {
14164        Activation::Relu => {
14165            for i in 0..n {
14166                out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
14167            }
14168        }
14169        Activation::Sigmoid => {
14170            for i in 0..n {
14171                let s = 1.0 / (1.0 + (-xs[i]).exp());
14172                out[i] = s * (1.0 - s) * dys[i];
14173            }
14174        }
14175        Activation::Tanh => {
14176            for i in 0..n {
14177                let t = xs[i].tanh();
14178                out[i] = (1.0 - t * t) * dys[i];
14179            }
14180        }
14181        Activation::Silu => {
14182            for i in 0..n {
14183                let s = 1.0 / (1.0 + (-xs[i]).exp());
14184                out[i] = s * (1.0 + xs[i] * (1.0 - s)) * dys[i];
14185            }
14186        }
14187        Activation::Gelu | Activation::GeluApprox => {
14188            // Both rare on f64 paths; use the high-quality libm erf.
14189            const INV_SQRT2: f64 = std::f64::consts::FRAC_1_SQRT_2;
14190            const INV_SQRT_2PI: f64 = 0.398_942_280_401_432_7;
14191            for i in 0..n {
14192                let x = xs[i];
14193                let phi = 0.5 * (1.0 + erf_f64(x * INV_SQRT2));
14194                let pdf = INV_SQRT_2PI * (-(x * x) * 0.5).exp();
14195                out[i] = (phi + x * pdf) * dys[i];
14196            }
14197        }
14198        Activation::Exp => {
14199            for i in 0..n {
14200                out[i] = xs[i].exp() * dys[i];
14201            }
14202        }
14203        Activation::Log => {
14204            for i in 0..n {
14205                out[i] = dys[i] / xs[i];
14206            }
14207        }
14208        Activation::Sqrt => {
14209            for i in 0..n {
14210                let s = xs[i].sqrt();
14211                out[i] = if s > 0.0 { 0.5 * dys[i] / s } else { 0.0 };
14212            }
14213        }
14214        Activation::Rsqrt => {
14215            for i in 0..n {
14216                let s = xs[i].sqrt();
14217                out[i] = if s > 0.0 {
14218                    -0.5 * dys[i] / (xs[i] * s)
14219                } else {
14220                    0.0
14221                };
14222            }
14223        }
14224        Activation::Neg => {
14225            for i in 0..n {
14226                out[i] = -dys[i];
14227            }
14228        }
14229        Activation::Abs => {
14230            for i in 0..n {
14231                let x = xs[i];
14232                let s = if x > 0.0 {
14233                    1.0
14234                } else if x < 0.0 {
14235                    -1.0
14236                } else {
14237                    0.0
14238                };
14239                out[i] = s * dys[i];
14240            }
14241        }
14242        Activation::Round => {
14243            out.copy_from_slice(dys);
14244        }
14245        Activation::Sin => {
14246            for i in 0..n {
14247                out[i] = xs[i].cos() * dys[i];
14248            }
14249        }
14250        Activation::Cos => {
14251            for i in 0..n {
14252                out[i] = -xs[i].sin() * dys[i];
14253            }
14254        }
14255        Activation::Tan => {
14256            for i in 0..n {
14257                let t = xs[i].tan();
14258                out[i] = (1.0 + t * t) * dys[i];
14259            }
14260        }
14261        Activation::Atan => {
14262            for i in 0..n {
14263                let x = xs[i];
14264                out[i] = dys[i] / (1.0 + x * x);
14265            }
14266        }
14267    }
14268}
14269
14270/// f64 erf via A&S 7.1.26 — same coefficients as `erf_f32`, computed
14271/// at f64 width. Max error ~1.5e-7 (limited by the polynomial, not the
14272/// arithmetic). Adequate for gradient kernels; if higher precision is
14273/// needed, swap in a libm dependency.
14274#[inline(always)]
14275fn erf_f64(x: f64) -> f64 {
14276    let s = x.signum();
14277    let x = x.abs();
14278    let t = 1.0 / (1.0 + 0.327_591_1 * x);
14279    let y = 1.0
14280        - (((((1.061_405_43 * t - 1.453_152_03) * t) + 1.421_413_75) * t - 0.284_496_74) * t
14281            + 0.254_829_59)
14282            * t
14283            * (-x * x).exp();
14284    s * y
14285}
14286
14287/// Cheap erf approximation (Abramowitz & Stegun 7.1.26, max error ~1.5e-7
14288/// over all of ℝ — plenty for f32 gradient kernels).
14289#[inline(always)]
14290fn erf_f32(x: f32) -> f32 {
14291    let s = x.signum();
14292    let x = x.abs();
14293    let t = 1.0 / (1.0 + 0.327_591_1 * x);
14294    let y = 1.0
14295        - (((((1.061_405_4 * t - 1.453_152_1) * t) + 1.421_413_8) * t - 0.284_496_74) * t
14296            + 0.254_829_6)
14297            * t
14298            * (-x * x).exp();
14299    s * y
14300}
14301
14302fn narrow_thunk_closure(
14303    src: usize,
14304    dst: usize,
14305    outer: u32,
14306    src_stride: u32,
14307    dst_stride: u32,
14308    inner: u32,
14309    elem_bytes: u8,
14310) -> Arc<dyn Fn(*mut u8) + Send + Sync> {
14311    let (outer, ss, ds, inner) = (
14312        outer as usize,
14313        src_stride as usize,
14314        dst_stride as usize,
14315        inner as usize,
14316    );
14317    if elem_bytes == 8 {
14318        Arc::new(move |base: *mut u8| unsafe {
14319            let s = sl_f64(src, base, outer * ss);
14320            let d = sl_mut_f64(dst, base, outer * ds);
14321            for o in 0..outer {
14322                d[o * ds..o * ds + inner].copy_from_slice(&s[o * ss..o * ss + inner]);
14323            }
14324        })
14325    } else {
14326        Arc::new(move |base: *mut u8| unsafe {
14327            let s = sl(src, base, outer * ss);
14328            let d = sl_mut(dst, base, outer * ds);
14329            for o in 0..outer {
14330                d[o * ds..o * ds + inner].copy_from_slice(&s[o * ss..o * ss + inner]);
14331            }
14332        })
14333    }
14334}
14335
14336unsafe fn sl(offset: usize, base: *mut u8, len: usize) -> &'static [f32] {
14337    if offset == usize::MAX {
14338        return &[];
14339    }
14340    unsafe { std::slice::from_raw_parts(base.add(offset) as *const f32, len) }
14341}
14342
14343#[inline(always)]
14344unsafe fn sl_mut(offset: usize, base: *mut u8, len: usize) -> &'static mut [f32] {
14345    unsafe { std::slice::from_raw_parts_mut(base.add(offset) as *mut f32, len) }
14346}
14347
14348#[inline(always)]
14349unsafe fn sl_f64(offset: usize, base: *mut u8, len: usize) -> &'static [f64] {
14350    if offset == usize::MAX {
14351        return &[];
14352    }
14353    unsafe { std::slice::from_raw_parts(base.add(offset) as *const f64, len) }
14354}
14355
14356#[inline(always)]
14357unsafe fn sl_mut_f64(offset: usize, base: *mut u8, len: usize) -> &'static mut [f64] {
14358    unsafe { std::slice::from_raw_parts_mut(base.add(offset) as *mut f64, len) }
14359}
14360
14361// i32 / i64 typed slice helpers — siblings of sl_f32/sl_f64. Kept for
14362// integer-tensor thunks that haven't landed yet (Sample, Gather index
14363// buffers); deleting them now would force re-deriving the unsafe
14364// boilerplate when the next int-typed thunk lands.
14365#[allow(dead_code)]
14366#[inline(always)]
14367unsafe fn sl_i32(offset: usize, base: *mut u8, len: usize) -> &'static [i32] {
14368    if offset == usize::MAX {
14369        return &[];
14370    }
14371    unsafe { std::slice::from_raw_parts(base.add(offset) as *const i32, len) }
14372}
14373
14374#[allow(dead_code)]
14375#[inline(always)]
14376unsafe fn sl_mut_i32(offset: usize, base: *mut u8, len: usize) -> &'static mut [i32] {
14377    unsafe { std::slice::from_raw_parts_mut(base.add(offset) as *mut i32, len) }
14378}
14379
14380#[allow(dead_code)]
14381#[inline(always)]
14382unsafe fn sl_i64(offset: usize, base: *mut u8, len: usize) -> &'static [i64] {
14383    if offset == usize::MAX {
14384        return &[];
14385    }
14386    unsafe { std::slice::from_raw_parts(base.add(offset) as *const i64, len) }
14387}
14388
14389#[allow(dead_code)]
14390#[inline(always)]
14391unsafe fn sl_mut_i64(offset: usize, base: *mut u8, len: usize) -> &'static mut [i64] {
14392    unsafe { std::slice::from_raw_parts_mut(base.add(offset) as *mut i64, len) }
14393}
14394
14395/// f64 N-D index walk used by Transpose and Expand. `out_dims` gives
14396/// the output shape; `in_strides` gives the source stride for each
14397/// output dim (broadcast axes have stride 0).
14398fn transpose_walk_f64(inp: &[f64], out: &mut [f64], out_dims: &[u32], in_strides: &[u32]) {
14399    let rank = out_dims.len();
14400    let mut idx = vec![0u32; rank];
14401    for o in 0..out.len() {
14402        let mut src_off = 0usize;
14403        for d in 0..rank {
14404            src_off += idx[d] as usize * in_strides[d] as usize;
14405        }
14406        out[o] = inp[src_off];
14407        // Increment index — last dim varies fastest.
14408        for d in (0..rank).rev() {
14409            idx[d] += 1;
14410            if idx[d] < out_dims[d] {
14411                break;
14412            }
14413            idx[d] = 0;
14414        }
14415    }
14416}
14417
14418/// f64 elementwise activation. Reads `inp`, writes `out`. For now
14419/// covers what the autodiff-emitted gradient graph needs (Neg, Exp,
14420/// Log, Sqrt, Rsqrt, Abs, Tanh, Sigmoid, Relu — the
14421/// transcendental-free subset). Approximate Gelu/Silu deferred until a
14422/// workload demands them at f64.
14423fn apply_activation_f64(inp: &[f64], out: &mut [f64], kind: Activation) {
14424    match kind {
14425        Activation::Neg => {
14426            for (o, &v) in out.iter_mut().zip(inp) {
14427                *o = -v;
14428            }
14429        }
14430        Activation::Exp => {
14431            for (o, &v) in out.iter_mut().zip(inp) {
14432                *o = v.exp();
14433            }
14434        }
14435        Activation::Log => {
14436            for (o, &v) in out.iter_mut().zip(inp) {
14437                *o = v.ln();
14438            }
14439        }
14440        Activation::Sqrt => {
14441            for (o, &v) in out.iter_mut().zip(inp) {
14442                *o = v.sqrt();
14443            }
14444        }
14445        Activation::Rsqrt => {
14446            for (o, &v) in out.iter_mut().zip(inp) {
14447                *o = 1.0 / v.sqrt();
14448            }
14449        }
14450        Activation::Abs => {
14451            for (o, &v) in out.iter_mut().zip(inp) {
14452                *o = v.abs();
14453            }
14454        }
14455        Activation::Tanh => {
14456            for (o, &v) in out.iter_mut().zip(inp) {
14457                *o = v.tanh();
14458            }
14459        }
14460        Activation::Sigmoid => {
14461            for (o, &v) in out.iter_mut().zip(inp) {
14462                *o = 1.0 / (1.0 + (-v).exp());
14463            }
14464        }
14465        Activation::Relu => {
14466            for (o, &v) in out.iter_mut().zip(inp) {
14467                *o = v.max(0.0);
14468            }
14469        }
14470        Activation::Round => {
14471            for (o, &v) in out.iter_mut().zip(inp) {
14472                *o = v.round_ties_even();
14473            }
14474        }
14475        Activation::Sin => {
14476            for (o, &v) in out.iter_mut().zip(inp) {
14477                *o = v.sin();
14478            }
14479        }
14480        Activation::Cos => {
14481            for (o, &v) in out.iter_mut().zip(inp) {
14482                *o = v.cos();
14483            }
14484        }
14485        Activation::Tan => {
14486            for (o, &v) in out.iter_mut().zip(inp) {
14487                *o = v.tan();
14488            }
14489        }
14490        Activation::Atan => {
14491            for (o, &v) in out.iter_mut().zip(inp) {
14492                *o = v.atan();
14493            }
14494        }
14495        Activation::Gelu | Activation::GeluApprox | Activation::Silu => {
14496            panic!(
14497                "apply_activation_f64: {kind:?} not yet implemented at f64. \
14498                    Add when a workload needs it."
14499            );
14500        }
14501    }
14502}
14503
14504#[inline]
14505fn binary_op_f64(op: BinaryOp, a: f64, b: f64) -> f64 {
14506    match op {
14507        BinaryOp::Add => a + b,
14508        BinaryOp::Sub => a - b,
14509        BinaryOp::Mul => a * b,
14510        BinaryOp::Div => a / b,
14511        BinaryOp::Max => a.max(b),
14512        BinaryOp::Min => a.min(b),
14513        BinaryOp::Pow => a.powf(b),
14514    }
14515}
14516
14517/// f64 sum reduction over a contiguous middle range.
14518/// Layout: input is `[outer, reduced, inner]`, output is `[outer, inner]`.
14519fn reduce_sum_f64(inp: &[f64], out: &mut [f64], outer: usize, reduced: usize, inner: usize) {
14520    for o in 0..outer {
14521        for n in 0..inner {
14522            let mut acc = 0.0_f64;
14523            for r in 0..reduced {
14524                acc += inp[o * reduced * inner + r * inner + n];
14525            }
14526            out[o * inner + n] = acc;
14527        }
14528    }
14529}
14530
14531#[cfg(test)]
14532mod tests {
14533    use super::*;
14534    use rlx_ir::*;
14535
14536    /// Plan #45: when a Narrow's only consumer is a Rope, the thunk
14537    /// fusion pass collapses them — the Narrow becomes Nop, and the
14538    /// Rope reads from the parent buffer with its row stride. This
14539    /// test runs the unfused path (batch*seq > FusedAttnBlock
14540    /// threshold) and asserts the rewrite happened.
14541    #[test]
14542    fn narrow_rope_fuses_in_unfused_path() {
14543        let f = DType::F32;
14544        let mut g = Graph::new("nr_fuse");
14545        // Force batch*seq > 64 so FusedAttnBlock doesn't pre-empt us.
14546        let qkv = g.input("qkv", Shape::new(&[16, 8, 192], f)); // 16*8=128 > 64
14547        let cos = g.input("cos", Shape::new(&[16], f));
14548        let sin = g.input("sin", Shape::new(&[16], f));
14549        // Last-axis narrow: Q = qkv[..., 0..64]
14550        let q = g.narrow_(qkv, 2, 0, 64);
14551        let q_rope = g.rope(q, cos, sin, 16);
14552        g.set_outputs(vec![q_rope]);
14553
14554        let plan = rlx_opt::memory::plan_memory(&g);
14555        let arena = crate::arena::Arena::from_plan(plan);
14556        let sched = compile_thunks(&g, &arena);
14557
14558        let mut narrow_count = 0;
14559        let mut rope_with_stride: Option<u32> = None;
14560        for t in &sched.thunks {
14561            match t {
14562                Thunk::Narrow { .. } => narrow_count += 1,
14563                Thunk::Rope { src_row_stride, .. } => rope_with_stride = Some(*src_row_stride),
14564                _ => {}
14565            }
14566        }
14567        // After fusion the Narrow is gone; only the Rope remains, and
14568        // it now walks with the parent QKV's row stride (3 * 64 = 192).
14569        assert_eq!(
14570            narrow_count, 0,
14571            "Narrow→Rope fusion should leave zero Narrow thunks; saw {narrow_count}"
14572        );
14573        assert_eq!(
14574            rope_with_stride,
14575            Some(192),
14576            "Rope's src_row_stride should be 192 (parent qkv axis), saw {rope_with_stride:?}"
14577        );
14578    }
14579
14580    /// Plan #15: SSM selective scan matches a naive Python-style
14581    /// Python-style sequential reference.
14582    #[test]
14583    fn ssm_selective_scan_matches_reference() {
14584        use rlx_ir::Philox4x32;
14585        let bch = 1usize;
14586        let s = 4usize;
14587        let h = 3usize;
14588        let n = 2usize;
14589
14590        let mut rng = Philox4x32::new(13);
14591        let mut x = vec![0f32; bch * s * h];
14592        rng.fill_normal(&mut x);
14593        let mut delta = vec![0f32; bch * s * h];
14594        // Keep Δ small so exp(Δ·A) doesn't blow up.
14595        for v in delta.iter_mut() {
14596            *v = (rng.next_f32() - 0.5) * 0.1;
14597        }
14598        let mut a = vec![0f32; h * n];
14599        for v in a.iter_mut() {
14600            *v = -(rng.next_f32() * 0.5 + 0.1);
14601        } // negative for stability
14602        let mut b = vec![0f32; bch * s * n];
14603        rng.fill_normal(&mut b);
14604        let mut c = vec![0f32; bch * s * n];
14605        rng.fill_normal(&mut c);
14606
14607        // Reference scan.
14608        let mut expected = vec![0f32; bch * s * h];
14609        for bi in 0..bch {
14610            let mut state = vec![0f32; h * n];
14611            for si in 0..s {
14612                for ci in 0..h {
14613                    let d = delta[bi * s * h + si * h + ci];
14614                    let xv = x[bi * s * h + si * h + ci];
14615                    let mut acc = 0f32;
14616                    for ni in 0..n {
14617                        let da = (d * a[ci * n + ni]).exp();
14618                        state[ci * n + ni] =
14619                            da * state[ci * n + ni] + d * b[bi * s * n + si * n + ni] * xv;
14620                        acc += c[bi * s * n + si * n + ni] * state[ci * n + ni];
14621                    }
14622                    expected[bi * s * h + si * h + ci] = acc;
14623                }
14624            }
14625        }
14626
14627        // RLX path.
14628        let f = DType::F32;
14629        let mut g = Graph::new("ssm");
14630        let xn = g.input("x", Shape::new(&[bch, s, h], f));
14631        let dn = g.input("delta", Shape::new(&[bch, s, h], f));
14632        let an = g.param("a", Shape::new(&[h, n], f));
14633        let bn = g.param("b", Shape::new(&[bch, s, n], f));
14634        let cn = g.param("c", Shape::new(&[bch, s, n], f));
14635        let yn = g.selective_scan(xn, dn, an, bn, cn, n, Shape::new(&[bch, s, h], f));
14636        g.set_outputs(vec![yn]);
14637
14638        let plan = rlx_opt::memory::plan_memory(&g);
14639        let mut arena = crate::arena::Arena::from_plan(plan);
14640        let sched = compile_thunks(&g, &arena);
14641
14642        let xn_off = arena.byte_offset(xn);
14643        let dn_off = arena.byte_offset(dn);
14644        let an_off = arena.byte_offset(an);
14645        let bn_off = arena.byte_offset(bn);
14646        let cn_off = arena.byte_offset(cn);
14647        let yn_off = arena.byte_offset(yn);
14648        let buf = arena.raw_buf_mut();
14649        unsafe {
14650            let copy = |dst: *mut f32, data: &[f32]| {
14651                for (i, &v) in data.iter().enumerate() {
14652                    *dst.add(i) = v;
14653                }
14654            };
14655            copy(buf.as_mut_ptr().add(xn_off) as *mut f32, &x);
14656            copy(buf.as_mut_ptr().add(dn_off) as *mut f32, &delta);
14657            copy(buf.as_mut_ptr().add(an_off) as *mut f32, &a);
14658            copy(buf.as_mut_ptr().add(bn_off) as *mut f32, &b);
14659            copy(buf.as_mut_ptr().add(cn_off) as *mut f32, &c);
14660        }
14661        execute_thunks(&sched, arena.raw_buf_mut());
14662
14663        let actual: Vec<f32> = unsafe {
14664            let p = arena.raw_buf().as_ptr().add(yn_off) as *const f32;
14665            (0..bch * s * h).map(|i| *p.add(i)).collect()
14666        };
14667
14668        for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
14669            assert!(
14670                (e - a).abs() < 1e-3,
14671                "mismatch at {i}: expected {e}, got {a}"
14672            );
14673        }
14674    }
14675
14676    /// Plan #26: 1×1 conv lowers to per-batch sgemm and matches the
14677    /// scalar 7-loop reference.
14678    #[test]
14679    fn conv_1x1_fast_path_matches_scalar() {
14680        use rlx_ir::Philox4x32;
14681        // [N=2, C_in=4, H=3, W=3]
14682        let n = 2usize;
14683        let c_in = 4usize;
14684        let h = 3usize;
14685        let w = 3usize;
14686        let c_out = 5usize;
14687        let mut rng = Philox4x32::new(31);
14688        let mut x = vec![0f32; n * c_in * h * w];
14689        rng.fill_normal(&mut x);
14690        let mut weight = vec![0f32; c_out * c_in];
14691        rng.fill_normal(&mut weight);
14692
14693        // Reference: scalar 1×1 conv = per-batch matmul
14694        // out[ni, co, hi, wi] = sum_ci weight[co, ci] * x[ni, ci, hi, wi]
14695        let mut expected = vec![0f32; n * c_out * h * w];
14696        for ni in 0..n {
14697            for co in 0..c_out {
14698                for hi in 0..h {
14699                    for wi in 0..w {
14700                        let mut acc = 0f32;
14701                        for ci in 0..c_in {
14702                            acc += weight[co * c_in + ci]
14703                                * x[((ni * c_in) + ci) * h * w + hi * w + wi];
14704                        }
14705                        expected[((ni * c_out) + co) * h * w + hi * w + wi] = acc;
14706                    }
14707                }
14708            }
14709        }
14710
14711        // RLX path: build a graph with Op::Conv (kernel=[1,1], stride=[1,1], etc).
14712        let f = DType::F32;
14713        let mut g = Graph::new("conv1x1");
14714        let xn = g.input("x", Shape::new(&[n, c_in, h, w], f));
14715        let wn = g.param("w", Shape::new(&[c_out, c_in, 1, 1], f));
14716        // Manually add Op::Conv since there's no `g.conv()` helper.
14717        let cn = g.add_node(
14718            rlx_ir::Op::Conv {
14719                kernel_size: vec![1, 1],
14720                stride: vec![1, 1],
14721                padding: vec![0, 0],
14722                dilation: vec![1, 1],
14723                groups: 1,
14724            },
14725            vec![xn, wn],
14726            Shape::new(&[n, c_out, h, w], f),
14727        );
14728        g.set_outputs(vec![cn]);
14729
14730        let plan = rlx_opt::memory::plan_memory(&g);
14731        let mut arena = crate::arena::Arena::from_plan(plan);
14732        let sched = compile_thunks(&g, &arena);
14733
14734        // Verify the fast path was selected.
14735        let saw_fast = sched
14736            .thunks
14737            .iter()
14738            .any(|t| matches!(t, Thunk::Conv2D1x1 { .. }));
14739        let saw_slow = sched
14740            .thunks
14741            .iter()
14742            .any(|t| matches!(t, Thunk::Conv2D { .. }));
14743        assert!(saw_fast, "1×1 conv should emit Conv2D1x1");
14744        assert!(!saw_slow, "1×1 conv must not fall through to scalar Conv2D");
14745
14746        let xn_off = arena.byte_offset(xn);
14747        let wn_off = arena.byte_offset(wn);
14748        let cn_off = arena.byte_offset(cn);
14749        let buf = arena.raw_buf_mut();
14750        unsafe {
14751            let xp = buf.as_mut_ptr().add(xn_off) as *mut f32;
14752            for (i, &v) in x.iter().enumerate() {
14753                *xp.add(i) = v;
14754            }
14755            let wp = buf.as_mut_ptr().add(wn_off) as *mut f32;
14756            for (i, &v) in weight.iter().enumerate() {
14757                *wp.add(i) = v;
14758            }
14759        }
14760        execute_thunks(&sched, arena.raw_buf_mut());
14761
14762        let actual: Vec<f32> = unsafe {
14763            let p = arena.raw_buf().as_ptr().add(cn_off) as *const f32;
14764            (0..(n * c_out * h * w)).map(|i| *p.add(i)).collect()
14765        };
14766
14767        for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
14768            assert!(
14769                (e - a).abs() < 1e-3,
14770                "mismatch at {i}: expected {e}, got {a}"
14771            );
14772        }
14773    }
14774
14775    /// Plan #5: fused dequant matmul matches the dequant-then-matmul
14776    /// reference (i.e. `(scale * (q - z)) @ x` materialized).
14777    #[test]
14778    fn dequant_matmul_int8_sym_matches_reference() {
14779        use rlx_ir::Philox4x32;
14780        use rlx_ir::quant::QuantScheme;
14781
14782        let m = 3usize;
14783        let k = 8usize;
14784        let n = 4usize;
14785        let block_size = 4usize; // 2 blocks per column
14786        let blocks_per_col = k / block_size;
14787
14788        // Random inputs: x f32, w_q i8, scales f32. Symmetric → no zp.
14789        let mut rng = Philox4x32::new(99);
14790        let mut x = vec![0f32; m * k];
14791        rng.fill_normal(&mut x);
14792        let w_q: Vec<i8> = (0..(k * n))
14793            .map(|i| ((i as i32 * 13 + 7) % 127 - 63) as i8)
14794            .collect();
14795        let scales: Vec<f32> = (0..(blocks_per_col * n))
14796            .map(|i| 0.01 + 0.001 * i as f32)
14797            .collect();
14798
14799        // Reference: build f32 weights from (q * scale) per block.
14800        let mut w_f32 = vec![0f32; k * n];
14801        for p in 0..k {
14802            let block = p / block_size;
14803            for j in 0..n {
14804                let s = scales[block * n + j];
14805                w_f32[p * n + j] = w_q[p * n + j] as f32 * s;
14806            }
14807        }
14808        let mut expected = vec![0f32; m * n];
14809        for i in 0..m {
14810            for j in 0..n {
14811                let mut acc = 0f32;
14812                for p in 0..k {
14813                    acc += x[i * k + p] * w_f32[p * n + j];
14814                }
14815                expected[i * n + j] = acc;
14816            }
14817        }
14818
14819        // RLX path.
14820        let f = DType::F32;
14821        let mut g = Graph::new("dq");
14822        let xn = g.input("x", Shape::new(&[m, k], f));
14823        let wn = g.param("w", Shape::new(&[k, n], DType::I8));
14824        let sn = g.param("scale", Shape::new(&[blocks_per_col, n], f));
14825        let zn = g.param("zp", Shape::new(&[blocks_per_col, n], f)); // unused (sym)
14826        let dq = g.dequant_matmul(
14827            xn,
14828            wn,
14829            sn,
14830            zn,
14831            QuantScheme::Int8Block {
14832                block_size: block_size as u32,
14833            },
14834            Shape::new(&[m, n], f),
14835        );
14836        g.set_outputs(vec![dq]);
14837
14838        let plan = rlx_opt::memory::plan_memory(&g);
14839        let mut arena = crate::arena::Arena::from_plan(plan);
14840        let sched = compile_thunks(&g, &arena);
14841
14842        let xn_off = arena.byte_offset(xn);
14843        let wn_off = arena.byte_offset(wn);
14844        let sn_off = arena.byte_offset(sn);
14845        let zn_off = arena.byte_offset(zn);
14846        let dq_off = arena.byte_offset(dq);
14847        let buf = arena.raw_buf_mut();
14848        unsafe {
14849            // Seed f32 inputs.
14850            let xp = buf.as_mut_ptr().add(xn_off) as *mut f32;
14851            for (i, &v) in x.iter().enumerate() {
14852                *xp.add(i) = v;
14853            }
14854            let sp = buf.as_mut_ptr().add(sn_off) as *mut f32;
14855            for (i, &v) in scales.iter().enumerate() {
14856                *sp.add(i) = v;
14857            }
14858            let zp = buf.as_mut_ptr().add(zn_off) as *mut f32;
14859            for i in 0..(blocks_per_col * n) {
14860                *zp.add(i) = 0.0;
14861            }
14862            // Seed i8 weights byte-by-byte.
14863            let wp = buf.as_mut_ptr().add(wn_off) as *mut i8;
14864            for (i, &v) in w_q.iter().enumerate() {
14865                *wp.add(i) = v;
14866            }
14867        }
14868        execute_thunks(&sched, arena.raw_buf_mut());
14869
14870        let actual: Vec<f32> = unsafe {
14871            let p = arena.raw_buf().as_ptr().add(dq_off) as *const f32;
14872            (0..m * n).map(|i| *p.add(i)).collect()
14873        };
14874
14875        for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
14876            assert!(
14877                (e - a).abs() < 1e-3,
14878                "mismatch at {i}: expected {e}, got {a}"
14879            );
14880        }
14881    }
14882
14883    /// Plan #9: LoRA matmul matches the unfused 3-matmul reference.
14884    #[test]
14885    fn lora_matmul_matches_unfused_reference() {
14886        use rlx_ir::Philox4x32;
14887
14888        let m = 4usize;
14889        let k = 8usize;
14890        let n = 6usize;
14891        let r = 2usize;
14892        let scale = 0.5f32;
14893
14894        // Random inputs (deterministic via Philox).
14895        let mut rng = Philox4x32::new(42);
14896        let mut x = vec![0f32; m * k];
14897        rng.fill_normal(&mut x);
14898        let mut w = vec![0f32; k * n];
14899        rng.fill_normal(&mut w);
14900        let mut a = vec![0f32; k * r];
14901        rng.fill_normal(&mut a);
14902        let mut b = vec![0f32; r * n];
14903        rng.fill_normal(&mut b);
14904
14905        // Reference: out = x·W + scale * x·A·B. Naive triple-loop.
14906        let naive = |a_buf: &[f32], b_buf: &[f32], rows: usize, inner: usize, cols: usize| {
14907            let mut o = vec![0f32; rows * cols];
14908            for i in 0..rows {
14909                for j in 0..cols {
14910                    let mut acc = 0f32;
14911                    for p in 0..inner {
14912                        acc += a_buf[i * inner + p] * b_buf[p * cols + j];
14913                    }
14914                    o[i * cols + j] = acc;
14915                }
14916            }
14917            o
14918        };
14919        let xw = naive(&x, &w, m, k, n);
14920        let xa = naive(&x, &a, m, k, r);
14921        let xab = naive(&xa, &b, m, r, n);
14922        let mut expected = xw;
14923        for i in 0..(m * n) {
14924            expected[i] += scale * xab[i];
14925        }
14926
14927        // RLX path: build a graph with one LoraMatMul.
14928        let f = DType::F32;
14929        let mut g = Graph::new("lora");
14930        let xn = g.input("x", Shape::new(&[m, k], f));
14931        let wn = g.param("w", Shape::new(&[k, n], f));
14932        let an = g.param("a", Shape::new(&[k, r], f));
14933        let bn = g.param("b", Shape::new(&[r, n], f));
14934        let lm = g.lora_matmul(xn, wn, an, bn, scale, Shape::new(&[m, n], f));
14935        g.set_outputs(vec![lm]);
14936
14937        let plan = rlx_opt::memory::plan_memory(&g);
14938        let mut arena = crate::arena::Arena::from_plan(plan);
14939        let sched = compile_thunks(&g, &arena);
14940
14941        let xn_off = arena.byte_offset(xn);
14942        let wn_off = arena.byte_offset(wn);
14943        let an_off = arena.byte_offset(an);
14944        let bn_off = arena.byte_offset(bn);
14945        let lm_off = arena.byte_offset(lm);
14946        let buf = arena.raw_buf_mut();
14947        unsafe {
14948            let copy = |dst: *mut f32, data: &[f32]| {
14949                for (i, &v) in data.iter().enumerate() {
14950                    *dst.add(i) = v;
14951                }
14952            };
14953            copy(buf.as_mut_ptr().add(xn_off) as *mut f32, &x);
14954            copy(buf.as_mut_ptr().add(wn_off) as *mut f32, &w);
14955            copy(buf.as_mut_ptr().add(an_off) as *mut f32, &a);
14956            copy(buf.as_mut_ptr().add(bn_off) as *mut f32, &b);
14957        }
14958        execute_thunks(&sched, arena.raw_buf_mut());
14959
14960        let actual: Vec<f32> = unsafe {
14961            let p = arena.raw_buf().as_ptr().add(lm_off) as *const f32;
14962            (0..m * n).map(|i| *p.add(i)).collect()
14963        };
14964
14965        for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
14966            assert!(
14967                (e - a).abs() < 1e-3,
14968                "mismatch at {i}: expected {e}, got {a}"
14969            );
14970        }
14971    }
14972
14973    /// Plan #42: fused sampling kernel determinism + greedy fallback.
14974    #[test]
14975    fn sample_temperature_zero_is_argmax() {
14976        // Very low temperature → distribution collapses on argmax.
14977        // Same seed → same output bit-for-bit.
14978        let f = DType::F32;
14979        let mut g = Graph::new("samp");
14980        let logits = g.input("logits", Shape::new(&[1, 8], f));
14981        let s = g.sample(logits, 0, 1.0, 1e-3, 42, Shape::new(&[1], f));
14982        g.set_outputs(vec![s]);
14983        let plan = rlx_opt::memory::plan_memory(&g);
14984        let mut arena = crate::arena::Arena::from_plan(plan);
14985        let sched = compile_thunks(&g, &arena);
14986
14987        let logits_off = arena.byte_offset(logits);
14988        let s_off = arena.byte_offset(s);
14989        let buf = arena.raw_buf_mut();
14990        unsafe {
14991            let p = buf.as_mut_ptr().add(logits_off) as *mut f32;
14992            // argmax = index 5 (value 9.0).
14993            let inputs = [0.1f32, 0.2, 0.3, 0.4, 0.5, 9.0, 0.7, 0.8];
14994            for (i, &v) in inputs.iter().enumerate() {
14995                *p.add(i) = v;
14996            }
14997        }
14998        execute_thunks(&sched, arena.raw_buf_mut());
14999
15000        let token = unsafe {
15001            let p = arena.raw_buf().as_ptr().add(s_off) as *const f32;
15002            *p as usize
15003        };
15004        assert_eq!(token, 5, "low-temp sampling should pick the argmax");
15005    }
15006
15007    #[test]
15008    fn sample_top_k_one_is_deterministic() {
15009        // top_k=1 forces only the argmax to have nonzero probability.
15010        let f = DType::F32;
15011        let mut g = Graph::new("samp_k1");
15012        let logits = g.input("logits", Shape::new(&[1, 4], f));
15013        let s = g.sample(logits, 1, 1.0, 1.0, 7, Shape::new(&[1], f));
15014        g.set_outputs(vec![s]);
15015        let plan = rlx_opt::memory::plan_memory(&g);
15016        let mut arena = crate::arena::Arena::from_plan(plan);
15017        let sched = compile_thunks(&g, &arena);
15018
15019        let logits_off = arena.byte_offset(logits);
15020        let s_off = arena.byte_offset(s);
15021        let buf = arena.raw_buf_mut();
15022        unsafe {
15023            let p = buf.as_mut_ptr().add(logits_off) as *mut f32;
15024            let inputs = [0.1f32, 5.0, 0.3, 0.4]; // argmax = 1
15025            for (i, &v) in inputs.iter().enumerate() {
15026                *p.add(i) = v;
15027            }
15028        }
15029        execute_thunks(&sched, arena.raw_buf_mut());
15030        let token = unsafe {
15031            let p = arena.raw_buf().as_ptr().add(s_off) as *const f32;
15032            *p as usize
15033        };
15034        assert_eq!(token, 1);
15035    }
15036
15037    /// Plan #44: cumsum primitive parity vs. naive scan.
15038    #[test]
15039    fn cumsum_inclusive_matches_naive() {
15040        let f = DType::F32;
15041        let mut g = Graph::new("cumsum");
15042        let x = g.input("x", Shape::new(&[2, 4], f));
15043        let cs = g.cumsum(x, -1, false, Shape::new(&[2, 4], f));
15044        g.set_outputs(vec![cs]);
15045        let plan = rlx_opt::memory::plan_memory(&g);
15046        let mut arena = crate::arena::Arena::from_plan(plan);
15047        let sched = compile_thunks(&g, &arena);
15048
15049        // Cache offsets up-front so we can drop the immutable borrow.
15050        let x_off = arena.byte_offset(x);
15051        let out_off = arena.byte_offset(cs);
15052        let buf = arena.raw_buf_mut();
15053        unsafe {
15054            let p = buf.as_mut_ptr().add(x_off) as *mut f32;
15055            let inputs = [1.0f32, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0];
15056            for (i, &v) in inputs.iter().enumerate() {
15057                *p.add(i) = v;
15058            }
15059        }
15060        execute_thunks(&sched, arena.raw_buf_mut());
15061
15062        let out: Vec<f32> = unsafe {
15063            let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
15064            (0..8).map(|i| *p.add(i)).collect()
15065        };
15066        assert_eq!(out, vec![1.0, 3.0, 6.0, 10.0, 10.0, 30.0, 60.0, 100.0]);
15067    }
15068
15069    /// Plan #46 deep: Narrow×3 → Attention fusion. The three QKV
15070    /// narrows that BERT/Nomic emit on the unfused (batch*seq > 64)
15071    /// path collapse into a single strided-Attention thunk.
15072    #[test]
15073    fn narrow_attention_fuses_in_unfused_path() {
15074        let f = DType::F32;
15075        let mut g = Graph::new("nattn_fuse");
15076        // batch*seq = 8*16 = 128 > 64 so FusedAttnBlock skips.
15077        let qkv = g.input("qkv", Shape::new(&[8, 16, 192], f)); // 3*64 = 192
15078        let mask = g.input("mask", Shape::new(&[8, 16], f));
15079        let q = g.narrow_(qkv, 2, 0, 64);
15080        let k = g.narrow_(qkv, 2, 64, 64);
15081        let v = g.narrow_(qkv, 2, 128, 64);
15082        let attn = g.attention(q, k, v, mask, 4, 16, Shape::new(&[8, 16, 64], f));
15083        g.set_outputs(vec![attn]);
15084
15085        let plan = rlx_opt::memory::plan_memory(&g);
15086        let arena = crate::arena::Arena::from_plan(plan);
15087        let sched = compile_thunks(&g, &arena);
15088
15089        let mut narrow_count = 0;
15090        let mut attn_strides: Option<(u32, u32, u32)> = None;
15091        for t in &sched.thunks {
15092            match t {
15093                Thunk::Narrow { .. } => narrow_count += 1,
15094                Thunk::Attention {
15095                    q_row_stride,
15096                    k_row_stride,
15097                    v_row_stride,
15098                    ..
15099                } => attn_strides = Some((*q_row_stride, *k_row_stride, *v_row_stride)),
15100                _ => {}
15101            }
15102        }
15103        // After fusion the 3 narrows are gone; Attention now walks the
15104        // QKV with parent row stride = 192 (3 × 64) on all three inputs.
15105        assert_eq!(
15106            narrow_count, 0,
15107            "Narrow×3→Attention fusion should eliminate all 3 narrows; saw {narrow_count}"
15108        );
15109        assert_eq!(
15110            attn_strides,
15111            Some((192, 192, 192)),
15112            "Attention should walk Q/K/V with parent row stride 192"
15113        );
15114    }
15115
15116    // ── Backward / training op parity tests ────────────────────
15117    //
15118    // Strategy: build a graph that contains exactly the backward op
15119    // under test (plus its inputs as graph Inputs), execute, and
15120    // compare against a hand-rolled scalar reference. For
15121    // Conv2dBackwardInput we additionally check against the numerical
15122    // gradient of the forward Conv2D — that's the gold-standard test
15123    // that validates the math, not just consistency between two
15124    // implementations of the same formula.
15125
15126    fn run_graph(
15127        g: &Graph,
15128        inputs: &[(NodeId, &[f32])],
15129        out_id: NodeId,
15130        out_len: usize,
15131    ) -> Vec<f32> {
15132        let plan = rlx_opt::memory::plan_memory(g);
15133        let mut arena = crate::arena::Arena::from_plan(plan);
15134        let sched = compile_thunks(g, &arena);
15135        for &(id, data) in inputs {
15136            let off = arena.byte_offset(id);
15137            let buf = arena.raw_buf_mut();
15138            unsafe {
15139                let p = buf.as_mut_ptr().add(off) as *mut f32;
15140                for (i, &v) in data.iter().enumerate() {
15141                    *p.add(i) = v;
15142                }
15143            }
15144        }
15145        execute_thunks(&sched, arena.raw_buf_mut());
15146        let off = arena.byte_offset(out_id);
15147        unsafe {
15148            let p = arena.raw_buf().as_ptr().add(off) as *const f32;
15149            (0..out_len).map(|i| *p.add(i)).collect()
15150        }
15151    }
15152
15153    #[test]
15154    fn relu_backward_matches_mask() {
15155        let f = DType::F32;
15156        let len = 7usize;
15157        let x: Vec<f32> = vec![-2.0, -0.1, 0.0, 0.1, 1.0, 3.0, -5.0];
15158        let dy: Vec<f32> = vec![0.5, 1.5, 2.5, -0.7, 4.0, -1.0, 9.0];
15159
15160        let mut g = Graph::new("relu_bw");
15161        let xn = g.input("x", Shape::new(&[len], f));
15162        let dyn_ = g.input("dy", Shape::new(&[len], f));
15163        let dx = g.relu_backward(xn, dyn_);
15164        g.set_outputs(vec![dx]);
15165
15166        let actual = run_graph(&g, &[(xn, &x), (dyn_, &dy)], dx, len);
15167        // Reference: gradient is dy where x>0 strictly, else 0.
15168        // (zero is not "positive" — the forward applied max(0, x), and at
15169        // x=0 the subgradient could be anything in [0, dy]; we pick 0.)
15170        let expected: Vec<f32> = x
15171            .iter()
15172            .zip(&dy)
15173            .map(|(&xi, &dyi)| if xi > 0.0 { dyi } else { 0.0 })
15174            .collect();
15175        for (a, e) in actual.iter().zip(&expected) {
15176            assert!((a - e).abs() < 1e-6, "relu_bw mismatch: {a} vs {e}");
15177        }
15178    }
15179
15180    #[test]
15181    fn maxpool2d_backward_routes_to_argmax() {
15182        let f = DType::F32;
15183        // [N=1, C=1, H=4, W=4] → 2x2 max-pool stride 2 → [1,1,2,2].
15184        let x: Vec<f32> = vec![
15185            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
15186        ];
15187        // Argmax of each 2x2 window:
15188        //   (0,0)→6 (idx 5), (0,1)→8 (idx 7),
15189        //   (1,0)→14(idx 13),(1,1)→16(idx 15).
15190        let dy: Vec<f32> = vec![0.5, 1.0, 2.0, 4.0];
15191
15192        let mut g = Graph::new("maxpool_bw");
15193        let xn = g.input("x", Shape::new(&[1, 1, 4, 4], f));
15194        let dyn_ = g.input("dy", Shape::new(&[1, 1, 2, 2], f));
15195        let dx = g.maxpool2d_backward(xn, dyn_, vec![2, 2], vec![2, 2], vec![0, 0]);
15196        g.set_outputs(vec![dx]);
15197
15198        let actual = run_graph(&g, &[(xn, &x), (dyn_, &dy)], dx, 16);
15199        let mut expected = vec![0f32; 16];
15200        expected[5] = 0.5;
15201        expected[7] = 1.0;
15202        expected[13] = 2.0;
15203        expected[15] = 4.0;
15204        for (i, (a, e)) in actual.iter().zip(&expected).enumerate() {
15205            assert!((a - e).abs() < 1e-6, "maxpool_bw[{i}] mismatch: {a} vs {e}");
15206        }
15207    }
15208
15209    #[test]
15210    fn conv2d_backward_input_matches_numerical_gradient() {
15211        use rlx_ir::Philox4x32;
15212        // Small enough to numerically differentiate exhaustively but
15213        // big enough to exercise stride/padding edge cases.
15214        let n = 1usize;
15215        let c_in = 2usize;
15216        let h = 4usize;
15217        let w = 4usize;
15218        let c_out = 3usize;
15219        let kh = 3usize;
15220        let kw = 3usize;
15221        let ph = 1usize;
15222        let pw = 1usize;
15223        let sh = 1usize;
15224        let sw = 1usize;
15225        // Output dims with padding=1, stride=1: same as input.
15226        let h_out = (h + 2 * ph - kh) / sh + 1;
15227        let w_out = (w + 2 * pw - kw) / sw + 1;
15228        assert_eq!(h_out, 4);
15229        assert_eq!(w_out, 4);
15230
15231        let mut rng = Philox4x32::new(7);
15232        let mut x = vec![0f32; n * c_in * h * w];
15233        rng.fill_normal(&mut x);
15234        let mut wt = vec![0f32; c_out * c_in * kh * kw];
15235        rng.fill_normal(&mut wt);
15236        let mut dy = vec![0f32; n * c_out * h_out * w_out];
15237        rng.fill_normal(&mut dy);
15238
15239        // Analytical: Conv2dBackwardInput on (dy, w).
15240        let f = DType::F32;
15241        let mut g = Graph::new("conv_bwi");
15242        let dy_in = g.input("dy", Shape::new(&[n, c_out, h_out, w_out], f));
15243        let w_in = g.input("w", Shape::new(&[c_out, c_in, kh, kw], f));
15244        let dx = g.conv2d_backward_input(
15245            dy_in,
15246            w_in,
15247            Shape::new(&[n, c_in, h, w], f),
15248            vec![kh, kw],
15249            vec![sh, sw],
15250            vec![ph, pw],
15251            vec![1, 1],
15252            1,
15253        );
15254        g.set_outputs(vec![dx]);
15255        let analytical = run_graph(&g, &[(dy_in, &dy), (w_in, &wt)], dx, n * c_in * h * w);
15256
15257        // Numerical: for each x[i], finite-difference forward conv twice.
15258        // Forward: y[j] = sum over filter window of w * x ; dot(dy, y) is
15259        // the scalar we differentiate. Then dx[i] = ∂(dot(dy, y))/∂x[i].
15260        let forward = |x: &[f32]| -> Vec<f32> {
15261            let mut out = vec![0f32; n * c_out * h_out * w_out];
15262            for ni in 0..n {
15263                for co in 0..c_out {
15264                    for ho in 0..h_out {
15265                        for wo in 0..w_out {
15266                            let mut acc = 0f32;
15267                            for ci in 0..c_in {
15268                                for ki in 0..kh {
15269                                    for kj in 0..kw {
15270                                        let hi = ho * sh + ki;
15271                                        let wi = wo * sw + kj;
15272                                        if hi < ph || wi < pw {
15273                                            continue;
15274                                        }
15275                                        let hi = hi - ph;
15276                                        let wi = wi - pw;
15277                                        if hi >= h || wi >= w {
15278                                            continue;
15279                                        }
15280                                        let xv = x[((ni * c_in) + ci) * h * w + hi * w + wi];
15281                                        let wv = wt[((co * c_in) + ci) * kh * kw + ki * kw + kj];
15282                                        acc += xv * wv;
15283                                    }
15284                                }
15285                            }
15286                            out[((ni * c_out) + co) * h_out * w_out + ho * w_out + wo] = acc;
15287                        }
15288                    }
15289                }
15290            }
15291            out
15292        };
15293        let dot = |a: &[f32], b: &[f32]| -> f32 { a.iter().zip(b).map(|(&u, &v)| u * v).sum() };
15294        let eps = 1e-3f32;
15295        let mut numerical = vec![0f32; x.len()];
15296        for i in 0..x.len() {
15297            let saved = x[i];
15298            x[i] = saved + eps;
15299            let plus = dot(&forward(&x), &dy);
15300            x[i] = saved - eps;
15301            let minus = dot(&forward(&x), &dy);
15302            x[i] = saved;
15303            numerical[i] = (plus - minus) / (2.0 * eps);
15304        }
15305        for (i, (a, n)) in analytical.iter().zip(&numerical).enumerate() {
15306            // f32 + eps=1e-3 numerical grad → ~1e-3 absolute is realistic.
15307            assert!(
15308                (a - n).abs() < 5e-3,
15309                "conv_bw_input[{i}]: analytical {a} vs numerical {n}"
15310            );
15311        }
15312    }
15313
15314    #[test]
15315    fn conv2d_backward_weight_matches_numerical_gradient() {
15316        use rlx_ir::Philox4x32;
15317        let n = 2usize;
15318        let c_in = 2usize;
15319        let h = 4usize;
15320        let w = 4usize;
15321        let c_out = 2usize;
15322        let kh = 3usize;
15323        let kw = 3usize;
15324        let ph = 0usize;
15325        let pw = 0usize;
15326        let sh = 1usize;
15327        let sw = 1usize;
15328        let h_out = (h + 2 * ph - kh) / sh + 1;
15329        let w_out = (w + 2 * pw - kw) / sw + 1;
15330
15331        let mut rng = Philox4x32::new(11);
15332        let mut x = vec![0f32; n * c_in * h * w];
15333        rng.fill_normal(&mut x);
15334        let mut wt = vec![0f32; c_out * c_in * kh * kw];
15335        rng.fill_normal(&mut wt);
15336        let mut dy = vec![0f32; n * c_out * h_out * w_out];
15337        rng.fill_normal(&mut dy);
15338
15339        let f = DType::F32;
15340        let mut g = Graph::new("conv_bww");
15341        let xn = g.input("x", Shape::new(&[n, c_in, h, w], f));
15342        let dyn_ = g.input("dy", Shape::new(&[n, c_out, h_out, w_out], f));
15343        let dwn = g.conv2d_backward_weight(
15344            xn,
15345            dyn_,
15346            Shape::new(&[c_out, c_in, kh, kw], f),
15347            vec![kh, kw],
15348            vec![sh, sw],
15349            vec![ph, pw],
15350            vec![1, 1],
15351            1,
15352        );
15353        g.set_outputs(vec![dwn]);
15354        let analytical = run_graph(&g, &[(xn, &x), (dyn_, &dy)], dwn, c_out * c_in * kh * kw);
15355
15356        let forward = |wt: &[f32]| -> Vec<f32> {
15357            let mut out = vec![0f32; n * c_out * h_out * w_out];
15358            for ni in 0..n {
15359                for co in 0..c_out {
15360                    for ho in 0..h_out {
15361                        for wo in 0..w_out {
15362                            let mut acc = 0f32;
15363                            for ci in 0..c_in {
15364                                for ki in 0..kh {
15365                                    for kj in 0..kw {
15366                                        let hi = ho + ki;
15367                                        let wi = wo + kj;
15368                                        let xv = x[((ni * c_in) + ci) * h * w + hi * w + wi];
15369                                        let wv = wt[((co * c_in) + ci) * kh * kw + ki * kw + kj];
15370                                        acc += xv * wv;
15371                                    }
15372                                }
15373                            }
15374                            out[((ni * c_out) + co) * h_out * w_out + ho * w_out + wo] = acc;
15375                        }
15376                    }
15377                }
15378            }
15379            out
15380        };
15381        let dot = |a: &[f32], b: &[f32]| -> f32 { a.iter().zip(b).map(|(&u, &v)| u * v).sum() };
15382        let eps = 1e-3f32;
15383        let mut numerical = vec![0f32; wt.len()];
15384        for i in 0..wt.len() {
15385            let saved = wt[i];
15386            wt[i] = saved + eps;
15387            let plus = dot(&forward(&wt), &dy);
15388            wt[i] = saved - eps;
15389            let minus = dot(&forward(&wt), &dy);
15390            wt[i] = saved;
15391            numerical[i] = (plus - minus) / (2.0 * eps);
15392        }
15393        for (i, (a, n)) in analytical.iter().zip(&numerical).enumerate() {
15394            assert!(
15395                (a - n).abs() < 5e-3,
15396                "conv_bw_weight[{i}]: analytical {a} vs numerical {n}"
15397            );
15398        }
15399    }
15400
15401    #[test]
15402    fn softmax_cross_entropy_matches_reference() {
15403        let f = DType::F32;
15404        let logits: Vec<f32> = vec![
15405            1.0, 2.0, 3.0, // row 0: max=3 (idx 2)
15406            -1.0, 0.0, 4.0, // row 1: max=4 (idx 2)
15407            5.0, 5.0, 5.0, // row 2: uniform
15408        ];
15409        let labels: Vec<f32> = vec![2.0, 0.0, 1.0];
15410
15411        let mut g = Graph::new("sce");
15412        let lg = g.input("logits", Shape::new(&[3, 3], f));
15413        let lb = g.input("labels", Shape::new(&[3], f));
15414        let loss = g.softmax_cross_entropy_with_logits(lg, lb);
15415        g.set_outputs(vec![loss]);
15416        let actual = run_graph(&g, &[(lg, &logits), (lb, &labels)], loss, 3);
15417
15418        // Reference per-row: -log(softmax(row)[label]).
15419        let mut expected = vec![0f32; 3];
15420        for ni in 0..3 {
15421            let row = &logits[ni * 3..(ni + 1) * 3];
15422            let m = row.iter().fold(f32::NEG_INFINITY, |a, &v| a.max(v));
15423            let sum: f32 = row.iter().map(|&v| (v - m).exp()).sum();
15424            let lse = m + sum.ln();
15425            let label_idx = labels[ni] as usize;
15426            expected[ni] = lse - row[label_idx];
15427        }
15428        for (i, (a, e)) in actual.iter().zip(&expected).enumerate() {
15429            assert!((a - e).abs() < 1e-5, "sce loss[{i}]: {a} vs {e}");
15430        }
15431    }
15432
15433    #[test]
15434    fn softmax_cross_entropy_backward_matches_numerical_gradient() {
15435        use rlx_ir::Philox4x32;
15436        let n = 4usize;
15437        let c = 5usize;
15438        let mut rng = Philox4x32::new(23);
15439        let mut logits = vec![0f32; n * c];
15440        rng.fill_normal(&mut logits);
15441        let labels: Vec<f32> = (0..n).map(|i| (i % c) as f32).collect();
15442        let mut d_loss = vec![0f32; n];
15443        rng.fill_normal(&mut d_loss);
15444
15445        let f = DType::F32;
15446        let mut g = Graph::new("sce_bw");
15447        let lg = g.input("logits", Shape::new(&[n, c], f));
15448        let lb = g.input("labels", Shape::new(&[n], f));
15449        let dl = g.input("d_loss", Shape::new(&[n], f));
15450        let dlogits = g.softmax_cross_entropy_backward(lg, lb, dl);
15451        g.set_outputs(vec![dlogits]);
15452        let analytical = run_graph(
15453            &g,
15454            &[(lg, &logits), (lb, &labels), (dl, &d_loss)],
15455            dlogits,
15456            n * c,
15457        );
15458
15459        // Numerical: differentiate dot(d_loss, sce_loss(logits)) w.r.t. each logit.
15460        let sce_loss = |logits: &[f32]| -> Vec<f32> {
15461            let mut out = vec![0f32; n];
15462            for ni in 0..n {
15463                let row = &logits[ni * c..(ni + 1) * c];
15464                let m = row.iter().fold(f32::NEG_INFINITY, |a, &v| a.max(v));
15465                let sum: f32 = row.iter().map(|&v| (v - m).exp()).sum();
15466                out[ni] = (m + sum.ln()) - row[labels[ni] as usize];
15467            }
15468            out
15469        };
15470        let dot = |a: &[f32], b: &[f32]| a.iter().zip(b).map(|(&u, &v)| u * v).sum::<f32>();
15471        let eps = 1e-3f32;
15472        let mut numerical = vec![0f32; logits.len()];
15473        for i in 0..logits.len() {
15474            let saved = logits[i];
15475            logits[i] = saved + eps;
15476            let plus = dot(&sce_loss(&logits), &d_loss);
15477            logits[i] = saved - eps;
15478            let minus = dot(&sce_loss(&logits), &d_loss);
15479            logits[i] = saved;
15480            numerical[i] = (plus - minus) / (2.0 * eps);
15481        }
15482        for (i, (a, num)) in analytical.iter().zip(&numerical).enumerate() {
15483            assert!(
15484                (a - num).abs() < 5e-3,
15485                "sce_bw[{i}]: analytical {a} vs numerical {num}"
15486            );
15487        }
15488    }
15489
15490    // ── End-to-end autodiff parity tests ──────────────────────
15491    //
15492    // Build a forward graph, run `grad_with_loss` to produce a graph
15493    // that emits [loss, gradients...], execute it through rlx-cpu,
15494    // and compare each gradient to a finite-difference estimate
15495    // produced by re-running the forward graph with each parameter
15496    // entry perturbed. f32 + ε=1e-3 puts the tolerance floor around
15497    // 5e-3 absolute error.
15498
15499    /// Initialize Op::Constant slots in the arena with their literal
15500    /// data. Mirrors the loop in rlx_runtime::backend (which serves
15501    /// the same role for production runs).
15502    fn fill_constants_into_arena(graph: &Graph, arena: &mut crate::arena::Arena) {
15503        for node in graph.nodes() {
15504            if let Op::Constant { data } = &node.op
15505                && arena.has_buffer(node.id)
15506                && !data.is_empty()
15507            {
15508                let buf = arena.slice_mut(node.id);
15509                let n_floats = data.len() / 4;
15510                let n = buf.len().min(n_floats);
15511                for i in 0..n {
15512                    let bytes = [
15513                        data[i * 4],
15514                        data[i * 4 + 1],
15515                        data[i * 4 + 2],
15516                        data[i * 4 + 3],
15517                    ];
15518                    buf[i] = f32::from_le_bytes(bytes);
15519                }
15520            }
15521        }
15522    }
15523
15524    /// Compile + arena-prep helper for these tests. Returns the
15525    /// schedule and a populated arena. `seed_inputs` writes f32 input
15526    /// data into the arena slot for each (NodeId, &[f32]) pair.
15527    fn prepare(
15528        graph: &Graph,
15529        seed_inputs: &[(NodeId, &[f32])],
15530    ) -> (ThunkSchedule, crate::arena::Arena) {
15531        let plan = rlx_opt::memory::plan_memory(graph);
15532        let mut arena = crate::arena::Arena::from_plan(plan);
15533        let sched = compile_thunks(graph, &arena);
15534        fill_constants_into_arena(graph, &mut arena);
15535        for &(id, data) in seed_inputs {
15536            let off = arena.byte_offset(id);
15537            let buf = arena.raw_buf_mut();
15538            unsafe {
15539                let p = buf.as_mut_ptr().add(off) as *mut f32;
15540                for (i, &v) in data.iter().enumerate() {
15541                    *p.add(i) = v;
15542                }
15543            }
15544        }
15545        (sched, arena)
15546    }
15547
15548    fn read_arena(arena: &crate::arena::Arena, id: NodeId, len: usize) -> Vec<f32> {
15549        let off = arena.byte_offset(id);
15550        unsafe {
15551            let p = arena.raw_buf().as_ptr().add(off) as *const f32;
15552            (0..len).map(|i| *p.add(i)).collect()
15553        }
15554    }
15555
15556    fn write_arena(arena: &mut crate::arena::Arena, id: NodeId, data: &[f32]) {
15557        let off = arena.byte_offset(id);
15558        let buf = arena.raw_buf_mut();
15559        unsafe {
15560            let p = buf.as_mut_ptr().add(off) as *mut f32;
15561            for (i, &v) in data.iter().enumerate() {
15562                *p.add(i) = v;
15563            }
15564        }
15565    }
15566
15567    /// f64 sibling of `prepare`. Writes f64 input data into the arena.
15568    fn prepare_f64(
15569        graph: &Graph,
15570        seed_inputs: &[(NodeId, &[f64])],
15571    ) -> (ThunkSchedule, crate::arena::Arena) {
15572        let plan = rlx_opt::memory::plan_memory(graph);
15573        let mut arena = crate::arena::Arena::from_plan(plan);
15574        let sched = compile_thunks(graph, &arena);
15575        fill_constants_into_arena(graph, &mut arena);
15576        for &(id, data) in seed_inputs {
15577            let off = arena.byte_offset(id);
15578            let buf = arena.raw_buf_mut();
15579            unsafe {
15580                let p = buf.as_mut_ptr().add(off) as *mut f64;
15581                for (i, &v) in data.iter().enumerate() {
15582                    *p.add(i) = v;
15583                }
15584            }
15585        }
15586        (sched, arena)
15587    }
15588
15589    fn read_arena_f64(arena: &crate::arena::Arena, id: NodeId, len: usize) -> Vec<f64> {
15590        let off = arena.byte_offset(id);
15591        unsafe {
15592            let p = arena.raw_buf().as_ptr().add(off) as *const f64;
15593            (0..len).map(|i| *p.add(i)).collect()
15594        }
15595    }
15596
15597    /// End-to-end f64 DenseSolve through the full compile + execute
15598    /// path. Validates: IR shape inference, memory planner f64 sizing,
15599    /// arena f64 accessors, Thunk::DenseSolveF64 lowering, executor
15600    /// dispatch, Accelerate dgesv FFI.
15601    ///
15602    /// System:
15603    ///   A = [[2, 1],
15604    ///        [1, 3]]   b = [5, 10]
15605    ///   ⇒  x = [1, 3]   (verified by hand)
15606    #[test]
15607    fn dense_solve_f64_end_to_end() {
15608        let mut g = Graph::new("solve_e2e");
15609        let a = g.input("A", Shape::new(&[2, 2], DType::F64));
15610        let b = g.input("b", Shape::new(&[2], DType::F64));
15611        let x = g.dense_solve(a, b, Shape::new(&[2], DType::F64));
15612        g.set_outputs(vec![x]);
15613
15614        let a_data = [2.0, 1.0, 1.0, 3.0_f64];
15615        let b_data = [5.0, 10.0_f64];
15616        let (sched, mut arena) = prepare_f64(&g, &[(a, &a_data), (b, &b_data)]);
15617        execute_thunks(&sched, arena.raw_buf_mut());
15618
15619        let got = read_arena_f64(&arena, x, 2);
15620        let want = [1.0, 3.0_f64];
15621        for i in 0..2 {
15622            assert!(
15623                (got[i] - want[i]).abs() < 1e-12,
15624                "x[{i}] = {} (expected {})",
15625                got[i],
15626                want[i]
15627            );
15628        }
15629    }
15630
15631    /// Scaled-up f64 DenseSolve — tridiagonal Laplacian-shape (typical
15632    /// MNA structure for a passive RC mesh in Circulax). Validates
15633    /// that the solve scales beyond the trivial 2×2 and that the
15634    /// row-major ↔ col-major dance in `dgesv` is correct for the
15635    /// general case.
15636    #[test]
15637    fn dense_solve_f64_5x5_laplacian() {
15638        let n = 5usize;
15639        let mut g = Graph::new("solve_5x5");
15640        let a = g.input("A", Shape::new(&[n, n], DType::F64));
15641        let b = g.input("b", Shape::new(&[n], DType::F64));
15642        let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
15643        g.set_outputs(vec![x]);
15644
15645        // 1-D Laplacian: 2 on diagonal, -1 on off-diagonals, 0 elsewhere.
15646        let mut a_data = vec![0.0_f64; n * n];
15647        for i in 0..n {
15648            a_data[i * n + i] = 2.0;
15649            if i > 0 {
15650                a_data[i * n + (i - 1)] = -1.0;
15651            }
15652            if i + 1 < n {
15653                a_data[i * n + (i + 1)] = -1.0;
15654            }
15655        }
15656        let b_data: Vec<f64> = (0..n).map(|i| (i + 1) as f64).collect();
15657        let (sched, mut arena) = prepare_f64(&g, &[(a, &a_data), (b, &b_data)]);
15658        execute_thunks(&sched, arena.raw_buf_mut());
15659
15660        let got = read_arena_f64(&arena, x, n);
15661        // Verify A·x ≈ b by computing the residual.
15662        let mut residual = vec![0.0_f64; n];
15663        for i in 0..n {
15664            for j in 0..n {
15665                residual[i] += a_data[i * n + j] * got[j];
15666            }
15667        }
15668        for i in 0..n {
15669            assert!(
15670                (residual[i] - b_data[i]).abs() < 1e-10,
15671                "row {i}: residual {} vs b {}",
15672                residual[i],
15673                b_data[i]
15674            );
15675        }
15676    }
15677
15678    /// Hello Resistor: end-to-end f64 gradient through a dense solve.
15679    ///
15680    /// Forward:
15681    ///   A      : Param  [N, N]   f64
15682    ///   b      : Input  [N]      f64
15683    ///   x      = solve(A, b)            (DenseSolve)
15684    ///   loss   = sum(x)                 (Reduce::Sum)
15685    ///
15686    /// Backward (via grad_with_loss):
15687    ///   ones [N] = expand(d_output, [N])      (Reduce::Sum VJP)
15688    ///   dx_int   = solve(Aᵀ, ones)             (DenseSolve VJP step 1)
15689    ///   dA       = -outer(dx_int, x)           (DenseSolve VJP step 2)
15690    ///   db       = dx_int                       (DenseSolve VJP step 3)
15691    ///
15692    /// Closed form: with loss = sum(solve(A, b)) = ones·x and
15693    /// implicit-function calculus, db = (Aᵀ)⁻¹·ones, dA = -db ⊗ x.
15694    /// We verify this against the autodiff-emitted graph's output and
15695    /// against a finite-difference baseline.
15696    #[test]
15697    fn hello_resistor_gradient_end_to_end() {
15698        use rlx_opt::autodiff::grad_with_loss;
15699        let n = 3usize;
15700
15701        // ── Build forward graph ──
15702        let mut g = Graph::new("hello_resistor");
15703        let a = g.param("A", Shape::new(&[n, n], DType::F64));
15704        let b = g.input("b", Shape::new(&[n], DType::F64));
15705        let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
15706        let loss = g.reduce(
15707            x,
15708            ReduceOp::Sum,
15709            vec![0],
15710            false,
15711            Shape::new(&[1], DType::F64),
15712        );
15713        g.set_outputs(vec![loss]);
15714
15715        // ── Run reverse-mode AD ──
15716        let bwd = grad_with_loss(&g, &[a, b]);
15717        assert_eq!(bwd.outputs.len(), 3, "expect [loss, dA, db]");
15718
15719        // ── Locate the inputs the bwd graph still needs from us ──
15720        // grad_with_loss copies forward nodes into bwd, so A/b/d_output
15721        // appear under their original names. Find them by name.
15722        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
15723            for node in graph.nodes() {
15724                let name = match &node.op {
15725                    rlx_ir::Op::Input { name } => Some(name.as_str()),
15726                    rlx_ir::Op::Param { name } => Some(name.as_str()),
15727                    _ => None,
15728                };
15729                if name == Some(want) {
15730                    return node.id;
15731                }
15732            }
15733            panic!("no node named {want:?} in bwd graph");
15734        };
15735        let a_bwd = find_by_name(&bwd, "A");
15736        let b_bwd = find_by_name(&bwd, "b");
15737        let d_out_bwd = find_by_name(&bwd, "d_output");
15738
15739        // ── Test data ──
15740        // A = [[2,1,0],[1,3,1],[0,1,2]]   (SPD tridiagonal, well-conditioned)
15741        // b = [1,2,3]
15742        let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
15743        let b_data = [1.0, 2.0, 3.0_f64];
15744        let d_output = [1.0_f64]; // ∂loss/∂loss
15745
15746        // ── Compile + execute backward graph ──
15747        let (sched, mut arena) = prepare_f64(
15748            &bwd,
15749            &[(a_bwd, &a_data), (b_bwd, &b_data), (d_out_bwd, &d_output)],
15750        );
15751        execute_thunks(&sched, arena.raw_buf_mut());
15752
15753        let loss_out = read_arena_f64(&arena, bwd.outputs[0], 1);
15754        let da_out = read_arena_f64(&arena, bwd.outputs[1], n * n);
15755        let db_out = read_arena_f64(&arena, bwd.outputs[2], n);
15756
15757        // ── Closed-form reference ──
15758        // x = A⁻¹ b ; loss = sum(x).
15759        let x_ref = {
15760            let mut a = a_data;
15761            let mut b = b_data;
15762            let info = crate::blas::dgesv(&mut a, &mut b, n, 1);
15763            assert_eq!(info, 0);
15764            b
15765        };
15766        let loss_ref: f64 = x_ref.iter().sum();
15767        // db = (Aᵀ)⁻¹ · 1
15768        let db_ref = {
15769            let mut at = [0.0_f64; 9];
15770            for i in 0..n {
15771                for j in 0..n {
15772                    at[i * n + j] = a_data[j * n + i];
15773                }
15774            }
15775            let mut ones = [1.0_f64; 3];
15776            let info = crate::blas::dgesv(&mut at, &mut ones, n, 1);
15777            assert_eq!(info, 0);
15778            ones
15779        };
15780        // dA = -outer(db, x) ; dA[i,j] = -db[i] * x[j]
15781        let mut da_ref = [0.0_f64; 9];
15782        for i in 0..n {
15783            for j in 0..n {
15784                da_ref[i * n + j] = -db_ref[i] * x_ref[j];
15785            }
15786        }
15787
15788        // ── Assertions vs analytic answer ──
15789        assert!(
15790            (loss_out[0] - loss_ref).abs() < 1e-10,
15791            "loss: got {}, want {}",
15792            loss_out[0],
15793            loss_ref
15794        );
15795        for i in 0..n {
15796            assert!(
15797                (db_out[i] - db_ref[i]).abs() < 1e-10,
15798                "db[{i}]: got {}, want {}",
15799                db_out[i],
15800                db_ref[i]
15801            );
15802        }
15803        for i in 0..n * n {
15804            assert!(
15805                (da_out[i] - da_ref[i]).abs() < 1e-10,
15806                "dA[{i}]: got {}, want {}",
15807                da_out[i],
15808                da_ref[i]
15809            );
15810        }
15811
15812        // ── Cross-check vs finite differences on db (a few entries) ──
15813        // ∂loss/∂b[k] ≈ (loss(b + h·e_k) - loss(b - h·e_k)) / (2h).
15814        let h = 1e-6_f64;
15815        for k in 0..n {
15816            let mut bp = b_data;
15817            bp[k] += h;
15818            let mut bm = b_data;
15819            bm[k] -= h;
15820            let lp = {
15821                let mut ac = a_data;
15822                let info = crate::blas::dgesv(&mut ac, &mut bp, n, 1);
15823                assert_eq!(info, 0);
15824                bp.iter().sum::<f64>()
15825            };
15826            let lm = {
15827                let mut ac = a_data;
15828                let info = crate::blas::dgesv(&mut ac, &mut bm, n, 1);
15829                assert_eq!(info, 0);
15830                bm.iter().sum::<f64>()
15831            };
15832            let fd = (lp - lm) / (2.0 * h);
15833            assert!(
15834                (db_out[k] - fd).abs() < 1e-7,
15835                "FD mismatch on db[{k}]: AD={} FD={}",
15836                db_out[k],
15837                fd
15838            );
15839        }
15840    }
15841
15842    /// Smallest possible Op::Scan basic test: geometric growth.
15843    /// init = [1, 1, 1] f64, body = (x → x + 0.1·x) = (x → 1.1·x),
15844    /// length = 10. Final carry must equal init·(1.1)^10 ≈ 2.5937…
15845    /// to f64 precision.
15846    #[test]
15847    fn scan_geometric_growth_f64() {
15848        let n = 3usize;
15849        let length = 10u32;
15850
15851        // Body: (x) → x + 0.1·x. One Input, one output, same shape/dtype.
15852        let mut body = Graph::new("scan_body");
15853        let x = body.input("carry", Shape::new(&[n], DType::F64));
15854        let scale_bytes: Vec<u8> = (0..n).flat_map(|_| 0.1_f64.to_le_bytes()).collect();
15855        let scale = body.add_node(
15856            Op::Constant { data: scale_bytes },
15857            vec![],
15858            Shape::new(&[n], DType::F64),
15859        );
15860        let scaled = body.binary(BinaryOp::Mul, x, scale, Shape::new(&[n], DType::F64));
15861        let next = body.binary(BinaryOp::Add, x, scaled, Shape::new(&[n], DType::F64));
15862        body.set_outputs(vec![next]);
15863
15864        // Outer graph: scan(init, body, length).
15865        let mut g = Graph::new("scan_outer");
15866        let init = g.input("init", Shape::new(&[n], DType::F64));
15867        let final_carry = g.scan(init, body, length);
15868        g.set_outputs(vec![final_carry]);
15869
15870        let init_data = vec![1.0_f64; n];
15871        let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data)]);
15872        execute_thunks(&sched, arena.raw_buf_mut());
15873        let got = read_arena_f64(&arena, final_carry, n);
15874        let want: f64 = 1.1_f64.powi(length as i32);
15875        for i in 0..n {
15876            assert!(
15877                (got[i] - want).abs() < 1e-12,
15878                "got[{i}] = {} want {}",
15879                got[i],
15880                want
15881            );
15882        }
15883    }
15884
15885    /// Per-step xs scan: cumulative-sum.
15886    ///   carry_0 = init
15887    ///   carry_{t+1} = carry_t + xs\[t\]
15888    ///   final = sum_{t<length} xs\[t\] + init
15889    /// Body has 2 inputs (carry, x_t) in that NodeId order; one output
15890    /// (next carry). Validates the per-step-input plumbing end-to-end.
15891    #[test]
15892    fn scan_with_xs_cumulative_sum() {
15893        let n = 3usize;
15894        let length = 4u32;
15895
15896        let mut body = Graph::new("cumsum_body");
15897        // carry must come first in NodeId order — declare it first.
15898        let carry = body.input("carry", Shape::new(&[n], DType::F64));
15899        let x_t = body.input("x_t", Shape::new(&[n], DType::F64));
15900        let next = body.binary(BinaryOp::Add, carry, x_t, Shape::new(&[n], DType::F64));
15901        body.set_outputs(vec![next]);
15902
15903        let mut g = Graph::new("cumsum_outer");
15904        let init = g.input("init", Shape::new(&[n], DType::F64));
15905        let xs = g.input("xs", Shape::new(&[length as usize, n], DType::F64));
15906        let final_carry = g.scan_with_xs(init, &[xs], body, length);
15907        g.set_outputs(vec![final_carry]);
15908
15909        let init_data = vec![0.0_f64; n];
15910        let xs_data: Vec<f64> = (0..length as usize * n).map(|i| (i + 1) as f64).collect(); // 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12
15911        let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data), (xs, &xs_data)]);
15912        execute_thunks(&sched, arena.raw_buf_mut());
15913        let got = read_arena_f64(&arena, final_carry, n);
15914
15915        // Reference: column-wise sum of xs rows + init. With our row-major
15916        // layout, column j of xs is xs_data[j], xs_data[n+j], xs_data[2n+j], ...
15917        // (per-step row at offset t*n contributes element j to slot j).
15918        let mut want = init_data.clone();
15919        for t in 0..length as usize {
15920            for j in 0..n {
15921                want[j] += xs_data[t * n + j];
15922            }
15923        }
15924        for i in 0..n {
15925            assert!(
15926                (got[i] - want[i]).abs() < 1e-12,
15927                "got[{i}] = {} want {}",
15928                got[i],
15929                want[i]
15930            );
15931        }
15932    }
15933
15934    /// Per-step xs scan composing with DenseSolve — Circulax-shaped:
15935    ///   carry_{t+1} = solve(M, carry_t + xs\[t\])
15936    /// Models a Backward-Euler step driven by a time-varying source.
15937    #[test]
15938    fn scan_with_xs_be_with_drive() {
15939        let n = 3usize;
15940        let length = 4u32;
15941        let dt = 0.1_f64;
15942
15943        let mut m_data = vec![0.0_f64; n * n];
15944        for i in 0..n {
15945            m_data[i * n + i] = 1.0 + dt * 2.0;
15946            if i > 0 {
15947                m_data[i * n + (i - 1)] = -dt;
15948            }
15949            if i + 1 < n {
15950                m_data[i * n + (i + 1)] = -dt;
15951            }
15952        }
15953        let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
15954
15955        let mut body = Graph::new("be_drive_body");
15956        let carry = body.input("carry", Shape::new(&[n], DType::F64));
15957        let drive = body.input("drive", Shape::new(&[n], DType::F64));
15958        let m = body.add_node(
15959            Op::Constant { data: m_bytes },
15960            vec![],
15961            Shape::new(&[n, n], DType::F64),
15962        );
15963        let driven = body.binary(BinaryOp::Add, carry, drive, Shape::new(&[n], DType::F64));
15964        let next = body.dense_solve(m, driven, Shape::new(&[n], DType::F64));
15965        body.set_outputs(vec![next]);
15966
15967        let mut g = Graph::new("be_drive_outer");
15968        let init = g.input("init", Shape::new(&[n], DType::F64));
15969        let xs = g.input("xs", Shape::new(&[length as usize, n], DType::F64));
15970        let final_carry = g.scan_with_xs(init, &[xs], body, length);
15971        g.set_outputs(vec![final_carry]);
15972
15973        let init_data = vec![0.0_f64; n];
15974        // Drive the system with a unit pulse on element 0 at t=0,
15975        // zeros after.
15976        let mut xs_data = vec![0.0_f64; length as usize * n];
15977        xs_data[0] = 1.0;
15978
15979        let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data), (xs, &xs_data)]);
15980        execute_thunks(&sched, arena.raw_buf_mut());
15981        let got = read_arena_f64(&arena, final_carry, n);
15982
15983        // Reference: per-step in pure Rust.
15984        let mut x = init_data.clone();
15985        for t in 0..length as usize {
15986            for j in 0..n {
15987                x[j] += xs_data[t * n + j];
15988            }
15989            let mut a_copy = m_data.clone();
15990            crate::blas::dgesv(&mut a_copy, &mut x, n, 1);
15991        }
15992        for i in 0..n {
15993            assert!(
15994                (got[i] - x[i]).abs() < 1e-12,
15995                "got[{i}] = {} ref {}",
15996                got[i],
15997                x[i]
15998            );
15999        }
16000    }
16001
16002    /// Reverse-mode AD through Op::BatchedDenseSolve. Forward solves
16003    /// `[B, N, N] · x = [B, N]`; loss = sum of all entries. Closed
16004    /// form: dB = (Aᵀ)⁻¹·1, dA = -(Aᵀ)⁻¹·1 ⊗ x. Verified analytically
16005    /// per batch (each slice matches what the unbatched DenseSolve VJP
16006    /// would compute).
16007    #[test]
16008    fn batched_dense_solve_gradient_matches_per_batch_analytic() {
16009        use rlx_opt::autodiff::grad_with_loss;
16010        let n = 3usize;
16011        let batch = 4usize;
16012
16013        let mut g = Graph::new("bds_grad");
16014        let a = g.param("A", Shape::new(&[batch, n, n], DType::F64));
16015        let b = g.input("b", Shape::new(&[batch, n], DType::F64));
16016        let x = g.batched_dense_solve(a, b, Shape::new(&[batch, n], DType::F64));
16017        let loss = g.reduce(
16018            x,
16019            ReduceOp::Sum,
16020            vec![0, 1],
16021            false,
16022            Shape::new(&[1], DType::F64),
16023        );
16024        g.set_outputs(vec![loss]);
16025
16026        let bwd = grad_with_loss(&g, &[a, b]);
16027
16028        let find = |graph: &Graph, want: &str| -> NodeId {
16029            for node in graph.nodes() {
16030                let name = match &node.op {
16031                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
16032                    _ => None,
16033                };
16034                if name == Some(want) {
16035                    return node.id;
16036                }
16037            }
16038            panic!("no node named {want}");
16039        };
16040        let a_id = find(&bwd, "A");
16041        let b_id = find(&bwd, "b");
16042        let d_out_id = find(&bwd, "d_output");
16043
16044        let mut rng = rlx_ir::Philox4x32::new(0x57e1_u64);
16045        let mut a_data = vec![0.0_f64; batch * n * n];
16046        let mut b_data = vec![0.0_f64; batch * n];
16047        for bi in 0..batch {
16048            for i in 0..n {
16049                for j in 0..n {
16050                    a_data[bi * n * n + i * n + j] = rng.next_f32() as f64 * 0.1;
16051                }
16052                a_data[bi * n * n + i * n + i] += 1.0 + n as f64;
16053            }
16054            for i in 0..n {
16055                b_data[bi * n + i] = rng.next_f32() as f64;
16056            }
16057        }
16058        let d_seed = [1.0_f64];
16059
16060        let (sched, mut arena) = prepare_f64(
16061            &bwd,
16062            &[(a_id, &a_data), (b_id, &b_data), (d_out_id, &d_seed)],
16063        );
16064        execute_thunks(&sched, arena.raw_buf_mut());
16065        let da_out = read_arena_f64(&arena, bwd.outputs[1], batch * n * n);
16066        let db_out = read_arena_f64(&arena, bwd.outputs[2], batch * n);
16067
16068        // Reference: per-batch analytic solve. dB_i = (A_iᵀ)⁻¹ · 1,
16069        // dA_i = -dB_i ⊗ x_i.
16070        for bi in 0..batch {
16071            let a_slice: Vec<f64> = a_data[bi * n * n..(bi + 1) * n * n].to_vec();
16072            let mut b_slice: Vec<f64> = b_data[bi * n..(bi + 1) * n].to_vec();
16073            let mut a_copy = a_slice.clone();
16074            crate::blas::dgesv(&mut a_copy, &mut b_slice, n, 1);
16075            let x_ref = b_slice.clone();
16076            // dB: solve(A^T, ones)
16077            let mut at = vec![0.0_f64; n * n];
16078            for i in 0..n {
16079                for j in 0..n {
16080                    at[i * n + j] = a_slice[j * n + i];
16081                }
16082            }
16083            let mut ones = vec![1.0_f64; n];
16084            crate::blas::dgesv(&mut at, &mut ones, n, 1);
16085            let db_ref = ones;
16086            for i in 0..n {
16087                let got = db_out[bi * n + i];
16088                assert!(
16089                    (got - db_ref[i]).abs() < 1e-10,
16090                    "batch {bi}, db[{i}]: got {got} ref {}",
16091                    db_ref[i]
16092                );
16093            }
16094            // dA: -outer(db, x)
16095            for i in 0..n {
16096                for j in 0..n {
16097                    let got = da_out[bi * n * n + i * n + j];
16098                    let want = -db_ref[i] * x_ref[j];
16099                    assert!(
16100                        (got - want).abs() < 1e-10,
16101                        "batch {bi}, dA[{i},{j}]: got {got} ref {want}"
16102                    );
16103                }
16104            }
16105        }
16106    }
16107
16108    /// AD knob: gradient through `scan_checkpointed` automatically
16109    /// uses the recompute backward path. Compares dinit from a plain
16110    /// scan against the same forward written with `scan_checkpointed`,
16111    /// both run through `grad_with_loss`. They must match to f64.
16112    #[test]
16113    fn scan_checkpointed_grad_matches_plain_scan_grad() {
16114        use rlx_opt::autodiff::grad_with_loss;
16115        let n = 2usize;
16116        let length = 6u32;
16117
16118        let make_body = || {
16119            let mut body = Graph::new("ck_body");
16120            let carry = body.input("carry", Shape::new(&[n], DType::F64));
16121            let scale_bytes: Vec<u8> = (0..n).flat_map(|_| 1.05_f64.to_le_bytes()).collect();
16122            let scale = body.add_node(
16123                Op::Constant { data: scale_bytes },
16124                vec![],
16125                Shape::new(&[n], DType::F64),
16126            );
16127            let next = body.binary(BinaryOp::Mul, carry, scale, Shape::new(&[n], DType::F64));
16128            body.set_outputs(vec![next]);
16129            body
16130        };
16131
16132        // Plain scan path.
16133        let mut g_plain = Graph::new("ck_plain");
16134        let init_p = g_plain.input("init", Shape::new(&[n], DType::F64));
16135        let final_p = g_plain.scan(init_p, make_body(), length);
16136        let loss_p = g_plain.reduce(
16137            final_p,
16138            ReduceOp::Sum,
16139            vec![0],
16140            false,
16141            Shape::new(&[1], DType::F64),
16142        );
16143        g_plain.set_outputs(vec![loss_p]);
16144        let bwd_p = grad_with_loss(&g_plain, &[init_p]);
16145
16146        // Checkpointed scan path with K=2 (length=6).
16147        let mut g_ck = Graph::new("ck_ckpt");
16148        let init_c = g_ck.input("init", Shape::new(&[n], DType::F64));
16149        let final_c = g_ck.scan_checkpointed(init_c, make_body(), length, 2);
16150        let loss_c = g_ck.reduce(
16151            final_c,
16152            ReduceOp::Sum,
16153            vec![0],
16154            false,
16155            Shape::new(&[1], DType::F64),
16156        );
16157        g_ck.set_outputs(vec![loss_c]);
16158        let bwd_c = grad_with_loss(&g_ck, &[init_c]);
16159
16160        let find = |graph: &Graph, want: &str| -> NodeId {
16161            for node in graph.nodes() {
16162                let name = match &node.op {
16163                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
16164                    _ => None,
16165                };
16166                if name == Some(want) {
16167                    return node.id;
16168                }
16169            }
16170            panic!("no {want}");
16171        };
16172
16173        let init_data = vec![0.5_f64, -0.5];
16174        let d_seed = [1.0_f64];
16175
16176        let (s_p, mut a_p) = prepare_f64(
16177            &bwd_p,
16178            &[
16179                (find(&bwd_p, "init"), &init_data),
16180                (find(&bwd_p, "d_output"), &d_seed),
16181            ],
16182        );
16183        execute_thunks(&s_p, a_p.raw_buf_mut());
16184        let dinit_p = read_arena_f64(&a_p, bwd_p.outputs[1], n);
16185
16186        let (s_c, mut a_c) = prepare_f64(
16187            &bwd_c,
16188            &[
16189                (find(&bwd_c, "init"), &init_data),
16190                (find(&bwd_c, "d_output"), &d_seed),
16191            ],
16192        );
16193        execute_thunks(&s_c, a_c.raw_buf_mut());
16194        let dinit_c = read_arena_f64(&a_c, bwd_c.outputs[1], n);
16195
16196        for i in 0..n {
16197            assert!(
16198                (dinit_p[i] - dinit_c[i]).abs() < 1e-12,
16199                "dinit[{i}]: plain={} checkpointed={}",
16200                dinit_p[i],
16201                dinit_c[i]
16202            );
16203        }
16204    }
16205
16206    /// Recursive checkpointing end-to-end: build a ScanBackward
16207    /// configured with K=2 checkpoints (for length=4), and compare
16208    /// dinit against the same backward graph with full trajectory
16209    /// (K=0). Forward computes a cumulative-sum-style scan; loss = sum.
16210    /// Both paths must agree to f64 precision.
16211    #[test]
16212    fn recursive_checkpointing_matches_full_trajectory() {
16213        let n = 2usize;
16214        let length = 4u32;
16215
16216        // Body: carry + ones (deterministic, no xs)
16217        let build_body = || -> Graph {
16218            let mut body = Graph::new("rc_body");
16219            let carry = body.input("carry", Shape::new(&[n], DType::F64));
16220            let ones_bytes: Vec<u8> = (0..n).flat_map(|_| 1.0_f64.to_le_bytes()).collect();
16221            let ones = body.add_node(
16222                Op::Constant { data: ones_bytes },
16223                vec![],
16224                Shape::new(&[n], DType::F64),
16225            );
16226            let next = body.binary(BinaryOp::Add, carry, ones, Shape::new(&[n], DType::F64));
16227            body.set_outputs(vec![next]);
16228            body
16229        };
16230
16231        // body_vjp: same body + d_output, output dcarry. body_vjp is
16232        // used by ScanBackward to walk the chain rule per step.
16233        let body_vjp_for = || -> Graph {
16234            use rlx_opt::autodiff::grad;
16235            let body = build_body();
16236            // grad(body, [carry_id]) → graph with dcarry as the output.
16237            let carry_id = body
16238                .nodes()
16239                .iter()
16240                .find(|n| matches!(n.op, Op::Input { .. }))
16241                .map(|n| n.id)
16242                .unwrap();
16243            grad(&body, &[carry_id])
16244        };
16245
16246        // ── Forward (All-strategy): scan with full trajectory ──
16247        let mut g_full = Graph::new("rc_outer_full");
16248        let init_full = g_full.input("init", Shape::new(&[n], DType::F64));
16249        let traj_full_id = g_full.scan_trajectory(init_full, build_body(), length);
16250        // Hand-build a ScanBackward node that reads the full trajectory.
16251        let upstream_full = g_full.input("upstream", Shape::new(&[length as usize, n], DType::F64));
16252        let dinit_full_id = g_full.scan_backward(
16253            init_full,
16254            traj_full_id,
16255            upstream_full,
16256            &[],
16257            body_vjp_for(),
16258            length,
16259            true,
16260            Shape::new(&[n], DType::F64),
16261        );
16262        g_full.set_outputs(vec![dinit_full_id]);
16263
16264        // ── Forward (Recursive-2): scan saves only K=2 rows ──
16265        // Build the trajectory shape [K, *carry] = [2, 2].
16266        let k = 2u32;
16267        let mut g_rec = Graph::new("rc_outer_rec");
16268        let init_rec = g_rec.input("init", Shape::new(&[n], DType::F64));
16269        let traj_rec_id = g_rec.add_node(
16270            Op::Scan {
16271                body: Box::new(build_body()),
16272                length,
16273                save_trajectory: true,
16274                num_bcast: 0,
16275                num_xs: 0,
16276                num_checkpoints: k,
16277            },
16278            vec![init_rec],
16279            Shape::new(&[k as usize, n], DType::F64),
16280        );
16281        // Same upstream shape as the full version (the upstream is per
16282        // *forward step*, length rows — independent of K).
16283        let upstream_rec = g_rec.input("upstream", Shape::new(&[length as usize, n], DType::F64));
16284        let dinit_rec_id = g_rec.add_node(
16285            Op::ScanBackward {
16286                body_vjp: Box::new(body_vjp_for()),
16287                length,
16288                save_trajectory: true,
16289                num_xs: 0,
16290                num_checkpoints: k,
16291                forward_body: Some(Box::new(build_body())),
16292            },
16293            vec![init_rec, traj_rec_id, upstream_rec],
16294            Shape::new(&[n], DType::F64),
16295        );
16296        g_rec.set_outputs(vec![dinit_rec_id]);
16297
16298        // ── Run both, same inputs ──
16299        let init_data = vec![0.5_f64, -0.5];
16300        let upstream_data: Vec<f64> = (0..length as usize * n).map(|i| (i as f64) * 0.1).collect();
16301
16302        let find = |graph: &Graph, want: &str| -> NodeId {
16303            for node in graph.nodes() {
16304                if let Op::Input { name } = &node.op
16305                    && name == want
16306                {
16307                    return node.id;
16308                }
16309            }
16310            panic!("no input {want}");
16311        };
16312
16313        let (s_full, mut a_full) = prepare_f64(
16314            &g_full,
16315            &[
16316                (find(&g_full, "init"), &init_data),
16317                (find(&g_full, "upstream"), &upstream_data),
16318            ],
16319        );
16320        execute_thunks(&s_full, a_full.raw_buf_mut());
16321        let dinit_full = read_arena_f64(&a_full, g_full.outputs[0], n);
16322
16323        let (s_rec, mut a_rec) = prepare_f64(
16324            &g_rec,
16325            &[
16326                (find(&g_rec, "init"), &init_data),
16327                (find(&g_rec, "upstream"), &upstream_data),
16328            ],
16329        );
16330        execute_thunks(&s_rec, a_rec.raw_buf_mut());
16331        let dinit_rec = read_arena_f64(&a_rec, g_rec.outputs[0], n);
16332
16333        for i in 0..n {
16334            assert!(
16335                (dinit_full[i] - dinit_rec[i]).abs() < 1e-12,
16336                "i={i}: full={} rec={}",
16337                dinit_full[i],
16338                dinit_rec[i]
16339            );
16340        }
16341    }
16342
16343    /// vmap-of-grad: gradient through Scan, vmap'd over init.
16344    /// Forward (per row):
16345    ///   carry_{t+1} = carry_t + ones    (body adds a constant)
16346    ///   loss = sum(carry_length) = sum(init) + length·n
16347    /// Closed form: dloss/dinit_i = 1 for every i. vmap over init at
16348    /// batch=3 → dinit_batched is all-ones [3, n]. Cross-checks
16349    /// against per-row grad_with_loss runs. Validates the vmap rule
16350    /// for Op::ScanBackward.
16351    #[test]
16352    fn vmap_of_grad_scan_matches_per_row_runs() {
16353        use rlx_opt::autodiff::grad_with_loss;
16354        use rlx_opt::vmap::vmap;
16355        let n = 2usize;
16356        let length = 3u32;
16357        let batch = 3usize;
16358
16359        let mut body = Graph::new("scan_grad_body");
16360        let carry = body.input("carry", Shape::new(&[n], DType::F64));
16361        let ones_bytes: Vec<u8> = (0..n).flat_map(|_| 1.0_f64.to_le_bytes()).collect();
16362        let ones = body.add_node(
16363            Op::Constant { data: ones_bytes },
16364            vec![],
16365            Shape::new(&[n], DType::F64),
16366        );
16367        let next = body.binary(BinaryOp::Add, carry, ones, Shape::new(&[n], DType::F64));
16368        body.set_outputs(vec![next]);
16369
16370        let mut g = Graph::new("scan_grad_outer");
16371        let init = g.input("init", Shape::new(&[n], DType::F64));
16372        let final_x = g.scan(init, body, length);
16373        let loss = g.reduce(
16374            final_x,
16375            ReduceOp::Sum,
16376            vec![0],
16377            false,
16378            Shape::new(&[1], DType::F64),
16379        );
16380        g.set_outputs(vec![loss]);
16381
16382        let bwd = grad_with_loss(&g, &[init]);
16383        let bg = vmap(&bwd, &["init"], batch);
16384
16385        let find = |graph: &Graph, want: &str| -> NodeId {
16386            for node in graph.nodes() {
16387                let name = match &node.op {
16388                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
16389                    _ => None,
16390                };
16391                if name == Some(want) {
16392                    return node.id;
16393                }
16394            }
16395            panic!("no node named {want}");
16396        };
16397        let init_b = find(&bg, "init");
16398        let d_out_b = find(&bg, "d_output");
16399
16400        let init_data: Vec<f64> = (0..batch * n).map(|i| (i as f64) * 0.5).collect();
16401        let d_seed = [1.0_f64];
16402
16403        let (sched, mut arena) = prepare_f64(&bg, &[(init_b, &init_data), (d_out_b, &d_seed)]);
16404        execute_thunks(&sched, arena.raw_buf_mut());
16405        let dinit_b = read_arena_f64(&arena, bg.outputs[1], batch * n);
16406
16407        for i in 0..batch * n {
16408            assert!(
16409                (dinit_b[i] - 1.0).abs() < 1e-12,
16410                "dinit[{i}] = {} (expected 1.0)",
16411                dinit_b[i]
16412            );
16413        }
16414
16415        // Cross-check vs per-row grad_with_loss.
16416        for bi in 0..batch {
16417            let row = &init_data[bi * n..(bi + 1) * n];
16418            let mut g2 = Graph::new("per_row_grad");
16419            let init2 = g2.input("init", Shape::new(&[n], DType::F64));
16420            let mut body2 = Graph::new("per_row_body");
16421            let c2 = body2.input("carry", Shape::new(&[n], DType::F64));
16422            let ones2_bytes: Vec<u8> = (0..n).flat_map(|_| 1.0_f64.to_le_bytes()).collect();
16423            let ones2 = body2.add_node(
16424                Op::Constant { data: ones2_bytes },
16425                vec![],
16426                Shape::new(&[n], DType::F64),
16427            );
16428            let next2 = body2.binary(BinaryOp::Add, c2, ones2, Shape::new(&[n], DType::F64));
16429            body2.set_outputs(vec![next2]);
16430            let final2 = g2.scan(init2, body2, length);
16431            let loss2 = g2.reduce(
16432                final2,
16433                ReduceOp::Sum,
16434                vec![0],
16435                false,
16436                Shape::new(&[1], DType::F64),
16437            );
16438            g2.set_outputs(vec![loss2]);
16439            let bwd2 = grad_with_loss(&g2, &[init2]);
16440            let init2_id = find(&bwd2, "init");
16441            let d_out2_id = find(&bwd2, "d_output");
16442            let (s2, mut a2) = prepare_f64(&bwd2, &[(init2_id, row), (d_out2_id, &d_seed)]);
16443            execute_thunks(&s2, a2.raw_buf_mut());
16444            let row_dinit = read_arena_f64(&a2, bwd2.outputs[1], n);
16445            for j in 0..n {
16446                let got = dinit_b[bi * n + j];
16447                let want = row_dinit[j];
16448                assert!(
16449                    (got - want).abs() < 1e-12,
16450                    "row {bi}, j {j}: vmap'd={got} per-row={want}"
16451                );
16452            }
16453        }
16454    }
16455
16456    /// vmap of Op::Scan: batched cumulative-sum. Forward
16457    ///   carry_{t+1} = carry_t + xs\[t\]
16458    ///   final = init + sum(xs)
16459    /// vmap over both init and xs at batch=3. Each batch row should
16460    /// equal the scalar run of the same body+xs subset.
16461    #[test]
16462    fn vmap_scan_cumulative_sum_matches_scalar_runs() {
16463        use rlx_opt::vmap::vmap;
16464        let n = 2usize;
16465        let length = 4u32;
16466        let batch = 3usize;
16467
16468        // Body: (carry, x_t) → carry + x_t
16469        let mut body = Graph::new("scan_body_cumsum");
16470        let carry = body.input("carry", Shape::new(&[n], DType::F64));
16471        let x_t = body.input("x_t", Shape::new(&[n], DType::F64));
16472        let next = body.binary(BinaryOp::Add, carry, x_t, Shape::new(&[n], DType::F64));
16473        body.set_outputs(vec![next]);
16474
16475        let mut g = Graph::new("scan_outer_cumsum");
16476        let init = g.input("init", Shape::new(&[n], DType::F64));
16477        let xs = g.input("xs", Shape::new(&[length as usize, n], DType::F64));
16478        let final_carry = g.scan_with_xs(init, &[xs], body, length);
16479        g.set_outputs(vec![final_carry]);
16480
16481        // vmap over both init and xs.
16482        let bg = vmap(&g, &["init", "xs"], batch);
16483
16484        // Test data — distinct per-batch rows.
16485        let init_data: Vec<f64> = (0..batch * n).map(|i| (i + 1) as f64).collect();
16486        // xs has shape [B, length, n] after vmap (the outer's xs is
16487        // [length, n]; vmap lifts it to [B, length, n]).
16488        let xs_data: Vec<f64> = (0..batch * length as usize * n)
16489            .map(|i| 0.1 * (i as f64))
16490            .collect();
16491
16492        let find = |graph: &Graph, want: &str| -> NodeId {
16493            for node in graph.nodes() {
16494                if let Op::Input { name } = &node.op
16495                    && name == want
16496                {
16497                    return node.id;
16498                }
16499            }
16500            panic!("no input {want}");
16501        };
16502        let init_b = find(&bg, "init");
16503        let xs_b = find(&bg, "xs");
16504        let (sched, mut arena) = prepare_f64(&bg, &[(init_b, &init_data), (xs_b, &xs_data)]);
16505        execute_thunks(&sched, arena.raw_buf_mut());
16506        let batched_out = read_arena_f64(&arena, bg.outputs[0], batch * n);
16507
16508        // Reference: per-batch scalar Scan.
16509        for bi in 0..batch {
16510            let init_slice = &init_data[bi * n..(bi + 1) * n];
16511            let mut x = init_slice.to_vec();
16512            for t in 0..length as usize {
16513                for j in 0..n {
16514                    x[j] += xs_data[bi * length as usize * n + t * n + j];
16515                }
16516            }
16517
16518            for i in 0..n {
16519                let got = batched_out[bi * n + i];
16520                assert!(
16521                    (got - x[i]).abs() < 1e-12,
16522                    "row {bi}, i {i}: got {got} ref {}",
16523                    x[i]
16524                );
16525            }
16526        }
16527    }
16528
16529    /// vmap of dense solve — Circulax-shaped batched parameter sweep.
16530    /// Forward: x = solve(A, b). vmap over both A (batched [B,N,N])
16531    /// and b (batched [B,N]). Run on CPU and compare each batch row
16532    /// against an independent scalar dgesv.
16533    #[test]
16534    fn vmap_dense_solve_matches_scalar_runs() {
16535        use rlx_opt::vmap::vmap;
16536        let n = 3usize;
16537        let batch = 4usize;
16538
16539        let mut g = Graph::new("solve_forward");
16540        let a = g.input("A", Shape::new(&[n, n], DType::F64));
16541        let b = g.input("b", Shape::new(&[n], DType::F64));
16542        let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
16543        g.set_outputs(vec![x]);
16544
16545        // vmap both A and b across the batch.
16546        let bg = vmap(&g, &["A", "b"], batch);
16547
16548        // Independent A and b per batch row.
16549        let mut rng = rlx_ir::Philox4x32::new(0xb47c_u64);
16550        let mut a_data = vec![0.0_f64; batch * n * n];
16551        let mut b_data = vec![0.0_f64; batch * n];
16552        for bi in 0..batch {
16553            // Diagonally dominant A — guaranteed non-singular.
16554            for i in 0..n {
16555                for j in 0..n {
16556                    a_data[bi * n * n + i * n + j] = rng.next_f32() as f64 * 0.1;
16557                }
16558                a_data[bi * n * n + i * n + i] += 1.0 + n as f64;
16559            }
16560            for i in 0..n {
16561                b_data[bi * n + i] = rng.next_f32() as f64;
16562            }
16563        }
16564
16565        let find = |graph: &Graph, want: &str| -> NodeId {
16566            for node in graph.nodes() {
16567                if let Op::Input { name } = &node.op
16568                    && name == want
16569                {
16570                    return node.id;
16571                }
16572            }
16573            panic!("no input named {want}");
16574        };
16575        let ba = find(&bg, "A");
16576        let bb = find(&bg, "b");
16577        let (sched, mut arena) = prepare_f64(&bg, &[(ba, &a_data), (bb, &b_data)]);
16578        execute_thunks(&sched, arena.raw_buf_mut());
16579        let batched_x = read_arena_f64(&arena, bg.outputs[0], batch * n);
16580
16581        // Reference: per-batch dgesv.
16582        for bi in 0..batch {
16583            let mut a_slice: Vec<f64> = a_data[bi * n * n..(bi + 1) * n * n].to_vec();
16584            let mut b_slice: Vec<f64> = b_data[bi * n..(bi + 1) * n].to_vec();
16585            crate::blas::dgesv(&mut a_slice, &mut b_slice, n, 1);
16586            for i in 0..n {
16587                let got = batched_x[bi * n + i];
16588                let want = b_slice[i];
16589                assert!(
16590                    (got - want).abs() < 1e-12,
16591                    "row {bi}, i {i}: got {got} want {want}"
16592                );
16593            }
16594        }
16595    }
16596
16597    /// vmap end-to-end: build a graph that computes y = MatMul(x, w) + b
16598    /// and reduces to a per-element loss. vmap over x with batch=4.
16599    /// Run the batched graph and compare each output row against an
16600    /// independent scalar run of the original graph. Validates the
16601    /// structural lift + the runtime path for batched MatMul +
16602    /// batched Binary + batched Reduce.
16603    #[test]
16604    fn vmap_matmul_add_reduce_matches_scalar_runs() {
16605        use rlx_opt::vmap::vmap;
16606        let n = 3usize;
16607        let batch = 4usize;
16608
16609        // Forward graph: y = MatMul(reshape(x, [1,n]), w) + b ; loss = sum(y).
16610        let mut g = Graph::new("vmap_e2e_forward");
16611        let x = g.input("x", Shape::new(&[n], DType::F64));
16612        let w = g.input("w", Shape::new(&[n, n], DType::F64));
16613        let b = g.input("b", Shape::new(&[n], DType::F64));
16614        let x_row = g.add_node(
16615            Op::Reshape {
16616                new_shape: vec![1, n as i64],
16617            },
16618            vec![x],
16619            Shape::new(&[1, n], DType::F64),
16620        );
16621        let mm = g.matmul(x_row, w, Shape::new(&[1, n], DType::F64));
16622        let mm_flat = g.add_node(
16623            Op::Reshape {
16624                new_shape: vec![n as i64],
16625            },
16626            vec![mm],
16627            Shape::new(&[n], DType::F64),
16628        );
16629        let yv = g.binary(BinaryOp::Add, mm_flat, b, Shape::new(&[n], DType::F64));
16630        let loss = g.reduce(
16631            yv,
16632            ReduceOp::Sum,
16633            vec![0],
16634            false,
16635            Shape::new(&[1], DType::F64),
16636        );
16637        g.set_outputs(vec![loss]);
16638
16639        // Build the vmap'd version (batch over x; w and b shared).
16640        let bg = vmap(&g, &["x"], batch);
16641
16642        // Test data — distinct rows so we can verify the per-row dispatch.
16643        let mut rng = rlx_ir::Philox4x32::new(0xc1c0_u64);
16644        let n_w = n * n;
16645        let w_data: Vec<f64> = (0..n_w).map(|_| rng.next_f32() as f64).collect();
16646        let b_data: Vec<f64> = (0..n).map(|_| rng.next_f32() as f64).collect();
16647        let mut x_data_batched: Vec<f64> = Vec::with_capacity(batch * n);
16648        for _ in 0..batch * n {
16649            x_data_batched.push(rng.next_f32() as f64);
16650        }
16651
16652        // Run the batched graph.
16653        let find = |graph: &Graph, want: &str| -> NodeId {
16654            for node in graph.nodes() {
16655                if let Op::Input { name } = &node.op
16656                    && name == want
16657                {
16658                    return node.id;
16659                }
16660            }
16661            panic!("no input named {want}");
16662        };
16663        let bx = find(&bg, "x");
16664        let bw = find(&bg, "w");
16665        let bb = find(&bg, "b");
16666        let (sched, mut arena) =
16667            prepare_f64(&bg, &[(bx, &x_data_batched), (bw, &w_data), (bb, &b_data)]);
16668        execute_thunks(&sched, arena.raw_buf_mut());
16669        // Reduce::Sum on shifted axis 1 with keep_dim=false → output [B, 1]
16670        // (it preserves the leading batch axis but reduces what was [n] to [].
16671        // Since the original output was [1] f64 and the reduce was over
16672        // axis 0, after vmap the leading-axis-shifted reduce keeps the
16673        // leading 1 from the original output's [1] shape.)
16674        let batched_out = read_arena_f64(&arena, bg.outputs[0], batch);
16675
16676        // Reference: run the original (un-batched) graph once per batch row.
16677        for bi in 0..batch {
16678            let xs_slice = &x_data_batched[bi * n..(bi + 1) * n];
16679            let mut g2 = Graph::new("scalar_run");
16680            let x2 = g2.input("x", Shape::new(&[n], DType::F64));
16681            let w2 = g2.input("w", Shape::new(&[n, n], DType::F64));
16682            let b2 = g2.input("b", Shape::new(&[n], DType::F64));
16683            let xr = g2.add_node(
16684                Op::Reshape {
16685                    new_shape: vec![1, n as i64],
16686                },
16687                vec![x2],
16688                Shape::new(&[1, n], DType::F64),
16689            );
16690            let m = g2.matmul(xr, w2, Shape::new(&[1, n], DType::F64));
16691            let mf = g2.add_node(
16692                Op::Reshape {
16693                    new_shape: vec![n as i64],
16694                },
16695                vec![m],
16696                Shape::new(&[n], DType::F64),
16697            );
16698            let yv2 = g2.binary(BinaryOp::Add, mf, b2, Shape::new(&[n], DType::F64));
16699            let l2 = g2.reduce(
16700                yv2,
16701                ReduceOp::Sum,
16702                vec![0],
16703                false,
16704                Shape::new(&[1], DType::F64),
16705            );
16706            g2.set_outputs(vec![l2]);
16707            let (s2, mut a2) = prepare_f64(&g2, &[(x2, xs_slice), (w2, &w_data), (b2, &b_data)]);
16708            execute_thunks(&s2, a2.raw_buf_mut());
16709            let scalar_out = read_arena_f64(&a2, l2, 1);
16710            assert!(
16711                (batched_out[bi] - scalar_out[0]).abs() < 1e-12,
16712                "row {bi}: batched={} scalar={}",
16713                batched_out[bi],
16714                scalar_out[0]
16715            );
16716        }
16717    }
16718
16719    /// Full gradient through scan-with-xs: dinit AND dxs both checked
16720    /// against finite differences. Forward
16721    ///   carry_{t+1} = solve(M, carry_t + xs\[t\])
16722    ///   loss        = sum(carry_length)
16723    /// Verifies that grad_with_loss returns gradients w.r.t. both
16724    /// `init` and `xs` and that dxs matches per-element FD.
16725    #[test]
16726    fn scan_with_xs_dxs_matches_fd() {
16727        use rlx_opt::autodiff::grad_with_loss;
16728        let n = 3usize;
16729        let length = 3u32;
16730        let dt = 0.1_f64;
16731
16732        let mut m_data = vec![0.0_f64; n * n];
16733        for i in 0..n {
16734            m_data[i * n + i] = 1.0 + dt * 2.0;
16735            if i > 0 {
16736                m_data[i * n + (i - 1)] = -dt;
16737            }
16738            if i + 1 < n {
16739                m_data[i * n + (i + 1)] = -dt;
16740            }
16741        }
16742        let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
16743
16744        let mut body = Graph::new("be_dxs_body");
16745        let carry = body.input("carry", Shape::new(&[n], DType::F64));
16746        let drive = body.input("drive", Shape::new(&[n], DType::F64));
16747        let m = body.add_node(
16748            Op::Constant { data: m_bytes },
16749            vec![],
16750            Shape::new(&[n, n], DType::F64),
16751        );
16752        let driven = body.binary(BinaryOp::Add, carry, drive, Shape::new(&[n], DType::F64));
16753        let next = body.dense_solve(m, driven, Shape::new(&[n], DType::F64));
16754        body.set_outputs(vec![next]);
16755
16756        let mut g = Graph::new("be_dxs_outer");
16757        let init = g.input("init", Shape::new(&[n], DType::F64));
16758        let xs = g.input("xs", Shape::new(&[length as usize, n], DType::F64));
16759        let final_carry = g.scan_with_xs(init, &[xs], body, length);
16760        let loss = g.reduce(
16761            final_carry,
16762            ReduceOp::Sum,
16763            vec![0],
16764            false,
16765            Shape::new(&[1], DType::F64),
16766        );
16767        g.set_outputs(vec![loss]);
16768
16769        // wrt = [init, xs] — get both gradients back.
16770        let bwd = grad_with_loss(&g, &[init, xs]);
16771        assert_eq!(bwd.outputs.len(), 3, "[loss, dinit, dxs]");
16772
16773        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
16774            for node in graph.nodes() {
16775                let name = match &node.op {
16776                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
16777                    _ => None,
16778                };
16779                if name == Some(want) {
16780                    return node.id;
16781                }
16782            }
16783            panic!("no node named {want:?}");
16784        };
16785        let init_bwd = find_by_name(&bwd, "init");
16786        let xs_bwd = find_by_name(&bwd, "xs");
16787        let d_out_bwd = find_by_name(&bwd, "d_output");
16788
16789        let init_data = vec![0.5_f64, 0.0, -0.5];
16790        let xs_data: Vec<f64> = (0..length as usize * n)
16791            .map(|i| 0.1_f64 * ((i as f64) - 4.0))
16792            .collect();
16793        let d_seed = [1.0_f64];
16794
16795        let (sched, mut arena) = prepare_f64(
16796            &bwd,
16797            &[
16798                (init_bwd, &init_data),
16799                (xs_bwd, &xs_data),
16800                (d_out_bwd, &d_seed),
16801            ],
16802        );
16803        execute_thunks(&sched, arena.raw_buf_mut());
16804        let dinit = read_arena_f64(&arena, bwd.outputs[1], n);
16805        let dxs = read_arena_f64(&arena, bwd.outputs[2], length as usize * n);
16806
16807        let h = 1e-6;
16808        let loss_at = |x0: &[f64], xs_in: &[f64]| -> f64 {
16809            let mut acc = x0.to_vec();
16810            for t in 0..length as usize {
16811                for j in 0..n {
16812                    acc[j] += xs_in[t * n + j];
16813                }
16814                let mut a_copy = m_data.clone();
16815                crate::blas::dgesv(&mut a_copy, &mut acc, n, 1);
16816            }
16817            acc.iter().sum()
16818        };
16819
16820        // FD on dinit (sanity).
16821        for i in 0..n {
16822            let mut ip = init_data.to_vec();
16823            ip[i] += h;
16824            let mut im = init_data.to_vec();
16825            im[i] -= h;
16826            let fd = (loss_at(&ip, &xs_data) - loss_at(&im, &xs_data)) / (2.0 * h);
16827            assert!(
16828                (dinit[i] - fd).abs() < 1e-7,
16829                "FD dinit[{i}]: AD={} FD={}",
16830                dinit[i],
16831                fd
16832            );
16833        }
16834
16835        // FD on every dxs entry — full per-step gradient check.
16836        for t in 0..length as usize {
16837            for j in 0..n {
16838                let idx = t * n + j;
16839                let mut xp = xs_data.clone();
16840                xp[idx] += h;
16841                let mut xm = xs_data.clone();
16842                xm[idx] -= h;
16843                let fd = (loss_at(&init_data, &xp) - loss_at(&init_data, &xm)) / (2.0 * h);
16844                assert!(
16845                    (dxs[idx] - fd).abs() < 1e-7,
16846                    "FD dxs[t={t},j={j}]: AD={} FD={}",
16847                    dxs[idx],
16848                    fd
16849                );
16850            }
16851        }
16852    }
16853
16854    /// Gradient through a scan with per-step xs (Circulax-shaped).
16855    /// Forward:
16856    ///   carry_{t+1} = solve(M, carry_t + xs\[t\])
16857    ///   loss = sum(carry_length)
16858    /// dxs is out of MVP (asserted in the VJP rule's body_vjp `wrt`),
16859    /// but `dinit` flows correctly through the body's reverse Jacobian
16860    /// even with xs in the chain. Verify dinit against finite differences.
16861    #[test]
16862    fn scan_with_xs_gradient_dinit_matches_fd() {
16863        use rlx_opt::autodiff::grad_with_loss;
16864        let n = 3usize;
16865        let length = 3u32;
16866        let dt = 0.1_f64;
16867
16868        let mut m_data = vec![0.0_f64; n * n];
16869        for i in 0..n {
16870            m_data[i * n + i] = 1.0 + dt * 2.0;
16871            if i > 0 {
16872                m_data[i * n + (i - 1)] = -dt;
16873            }
16874            if i + 1 < n {
16875                m_data[i * n + (i + 1)] = -dt;
16876            }
16877        }
16878        let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
16879
16880        let mut body = Graph::new("be_xs_grad_body");
16881        let carry = body.input("carry", Shape::new(&[n], DType::F64));
16882        let drive = body.input("drive", Shape::new(&[n], DType::F64));
16883        let m = body.add_node(
16884            Op::Constant { data: m_bytes },
16885            vec![],
16886            Shape::new(&[n, n], DType::F64),
16887        );
16888        let driven = body.binary(BinaryOp::Add, carry, drive, Shape::new(&[n], DType::F64));
16889        let next = body.dense_solve(m, driven, Shape::new(&[n], DType::F64));
16890        body.set_outputs(vec![next]);
16891
16892        let mut g = Graph::new("be_xs_grad_outer");
16893        let init = g.input("init", Shape::new(&[n], DType::F64));
16894        let xs = g.input("xs", Shape::new(&[length as usize, n], DType::F64));
16895        let final_carry = g.scan_with_xs(init, &[xs], body, length);
16896        let loss = g.reduce(
16897            final_carry,
16898            ReduceOp::Sum,
16899            vec![0],
16900            false,
16901            Shape::new(&[1], DType::F64),
16902        );
16903        g.set_outputs(vec![loss]);
16904
16905        let bwd = grad_with_loss(&g, &[init]);
16906
16907        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
16908            for node in graph.nodes() {
16909                let name = match &node.op {
16910                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
16911                    _ => None,
16912                };
16913                if name == Some(want) {
16914                    return node.id;
16915                }
16916            }
16917            panic!("no node named {want:?}");
16918        };
16919        let init_bwd = find_by_name(&bwd, "init");
16920        let xs_bwd = find_by_name(&bwd, "xs");
16921        let d_out_bwd = find_by_name(&bwd, "d_output");
16922
16923        let init_data = vec![0.5_f64, 0.0, -0.5];
16924        // Drive: small per-step pulse, varying per element.
16925        let xs_data: Vec<f64> = (0..length as usize * n)
16926            .map(|i| 0.1_f64 * ((i as f64) - 4.0))
16927            .collect();
16928        let d_seed = [1.0_f64];
16929
16930        let (sched, mut arena) = prepare_f64(
16931            &bwd,
16932            &[
16933                (init_bwd, &init_data),
16934                (xs_bwd, &xs_data),
16935                (d_out_bwd, &d_seed),
16936            ],
16937        );
16938        execute_thunks(&sched, arena.raw_buf_mut());
16939        let dinit = read_arena_f64(&arena, bwd.outputs[1], n);
16940
16941        let h = 1e-6;
16942        let loss_at = |x0: &[f64]| -> f64 {
16943            let mut acc = x0.to_vec();
16944            for t in 0..length as usize {
16945                for j in 0..n {
16946                    acc[j] += xs_data[t * n + j];
16947                }
16948                let mut a_copy = m_data.clone();
16949                crate::blas::dgesv(&mut a_copy, &mut acc, n, 1);
16950            }
16951            acc.iter().sum()
16952        };
16953        for i in 0..n {
16954            let mut ip = init_data.to_vec();
16955            ip[i] += h;
16956            let mut im = init_data.to_vec();
16957            im[i] -= h;
16958            let fd = (loss_at(&ip) - loss_at(&im)) / (2.0 * h);
16959            assert!(
16960                (dinit[i] - fd).abs() < 1e-7,
16961                "FD dinit[{i}]: AD={} FD={}",
16962                dinit[i],
16963                fd
16964            );
16965        }
16966    }
16967
16968    /// Gradient through a geometric-growth scan: forward
16969    ///   x_{t+1} = 1.1 · x_t,    x_0 = init
16970    ///   final   = x_length     = init · 1.1^length
16971    ///   loss    = sum(final)
16972    /// closed-form ∂loss/∂init\[i\] = 1.1^length for every i.
16973    /// Validates the VJP path: AD pre-pass rewrites save_trajectory=false
16974    /// to true, autodiff emits Op::ScanBackward, executor walks t back.
16975    #[test]
16976    fn scan_gradient_geometric_matches_closed_form() {
16977        use rlx_opt::autodiff::grad_with_loss;
16978        let n = 3usize;
16979        let length = 5u32;
16980
16981        let mut body = Graph::new("scan_grad_body");
16982        let x = body.input("carry", Shape::new(&[n], DType::F64));
16983        let scale_bytes: Vec<u8> = (0..n).flat_map(|_| 1.1_f64.to_le_bytes()).collect();
16984        let scale = body.add_node(
16985            Op::Constant { data: scale_bytes },
16986            vec![],
16987            Shape::new(&[n], DType::F64),
16988        );
16989        let next = body.binary(BinaryOp::Mul, x, scale, Shape::new(&[n], DType::F64));
16990        body.set_outputs(vec![next]);
16991
16992        let mut g = Graph::new("scan_grad_outer");
16993        let init = g.input("init", Shape::new(&[n], DType::F64));
16994        let final_x = g.scan(init, body, length);
16995        let loss = g.reduce(
16996            final_x,
16997            ReduceOp::Sum,
16998            vec![0],
16999            false,
17000            Shape::new(&[1], DType::F64),
17001        );
17002        g.set_outputs(vec![loss]);
17003
17004        let bwd = grad_with_loss(&g, &[init]);
17005        assert_eq!(bwd.outputs.len(), 2);
17006
17007        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
17008            for node in graph.nodes() {
17009                let name = match &node.op {
17010                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
17011                    _ => None,
17012                };
17013                if name == Some(want) {
17014                    return node.id;
17015                }
17016            }
17017            panic!("no node named {want:?}");
17018        };
17019        let init_bwd = find_by_name(&bwd, "init");
17020        let d_out_bwd = find_by_name(&bwd, "d_output");
17021
17022        let init_data = vec![1.0_f64; n];
17023        let d_seed = [1.0_f64];
17024        let (sched, mut arena) = prepare_f64(&bwd, &[(init_bwd, &init_data), (d_out_bwd, &d_seed)]);
17025        execute_thunks(&sched, arena.raw_buf_mut());
17026        let dinit = read_arena_f64(&arena, bwd.outputs[1], n);
17027
17028        let want = 1.1_f64.powi(length as i32);
17029        for i in 0..n {
17030            assert!(
17031                (dinit[i] - want).abs() < 1e-12,
17032                "dinit[{i}] = {} want {}",
17033                dinit[i],
17034                want
17035            );
17036        }
17037
17038        // Finite-difference cross-check on init[0].
17039        let h = 1e-6;
17040        let loss_at = |x: &[f64]| -> f64 {
17041            let mut acc = x.to_vec();
17042            for _ in 0..length {
17043                for v in acc.iter_mut() {
17044                    *v *= 1.1;
17045                }
17046            }
17047            acc.iter().sum()
17048        };
17049        let mut ip = init_data.clone();
17050        ip[0] += h;
17051        let mut im = init_data.clone();
17052        im[0] -= h;
17053        let fd = (loss_at(&ip) - loss_at(&im)) / (2.0 * h);
17054        assert!(
17055            (dinit[0] - fd).abs() < 1e-7,
17056            "FD dinit[0]: AD={} FD={}",
17057            dinit[0],
17058            fd
17059        );
17060    }
17061
17062    /// Gradient through Backward Euler scan composing with DenseSolve.
17063    /// Asserts dinit matches finite-difference per coordinate.
17064    #[test]
17065    fn scan_gradient_backward_euler_matches_fd() {
17066        use rlx_opt::autodiff::grad_with_loss;
17067        let n = 4usize;
17068        let length = 3u32;
17069        let dt = 0.05_f64;
17070
17071        let mut m_data = vec![0.0_f64; n * n];
17072        for i in 0..n {
17073            m_data[i * n + i] = 1.0 + dt * 2.0;
17074            if i > 0 {
17075                m_data[i * n + (i - 1)] = -dt;
17076            }
17077            if i + 1 < n {
17078                m_data[i * n + (i + 1)] = -dt;
17079            }
17080        }
17081        let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
17082
17083        let mut body = Graph::new("be_grad_body");
17084        let x = body.input("x", Shape::new(&[n], DType::F64));
17085        let m = body.add_node(
17086            Op::Constant { data: m_bytes },
17087            vec![],
17088            Shape::new(&[n, n], DType::F64),
17089        );
17090        let next = body.dense_solve(m, x, Shape::new(&[n], DType::F64));
17091        body.set_outputs(vec![next]);
17092
17093        let mut g = Graph::new("be_grad_outer");
17094        let init = g.input("x0", Shape::new(&[n], DType::F64));
17095        let final_x = g.scan(init, body, length);
17096        let loss = g.reduce(
17097            final_x,
17098            ReduceOp::Sum,
17099            vec![0],
17100            false,
17101            Shape::new(&[1], DType::F64),
17102        );
17103        g.set_outputs(vec![loss]);
17104
17105        let bwd = grad_with_loss(&g, &[init]);
17106
17107        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
17108            for node in graph.nodes() {
17109                let name = match &node.op {
17110                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
17111                    _ => None,
17112                };
17113                if name == Some(want) {
17114                    return node.id;
17115                }
17116            }
17117            panic!("no node named {want:?}");
17118        };
17119        let init_bwd = find_by_name(&bwd, "x0");
17120        let d_out_bwd = find_by_name(&bwd, "d_output");
17121
17122        let init_data: [f64; 4] = [0.0, 1.0, 0.0, 0.0];
17123        let d_seed = [1.0_f64];
17124        let (sched, mut arena) = prepare_f64(&bwd, &[(init_bwd, &init_data), (d_out_bwd, &d_seed)]);
17125        execute_thunks(&sched, arena.raw_buf_mut());
17126        let dinit = read_arena_f64(&arena, bwd.outputs[1], n);
17127
17128        let h = 1e-6;
17129        let loss_at = |x0: &[f64]| -> f64 {
17130            let mut acc = x0.to_vec();
17131            for _ in 0..length {
17132                let mut a_copy = m_data.clone();
17133                crate::blas::dgesv(&mut a_copy, &mut acc, n, 1);
17134            }
17135            acc.iter().sum()
17136        };
17137        for i in 0..n {
17138            let mut ip = init_data.to_vec();
17139            ip[i] += h;
17140            let mut im = init_data.to_vec();
17141            im[i] -= h;
17142            let fd = (loss_at(&ip) - loss_at(&im)) / (2.0 * h);
17143            assert!(
17144                (dinit[i] - fd).abs() < 1e-7,
17145                "FD dinit[{i}]: AD={} FD={}",
17146                dinit[i],
17147                fd
17148            );
17149        }
17150    }
17151
17152    /// Trajectory-mode scan: same Backward Euler body, but record the
17153    /// carry at every step. Output is `[length, n]` — row `t` is the
17154    /// state after step `t+1`. Validates the SaveAt-style waveform
17155    /// recording end-to-end, including that the last row equals what
17156    /// the no-trajectory variant would have returned.
17157    #[test]
17158    fn scan_trajectory_backward_euler_records_waveform() {
17159        let n = 4usize;
17160        let length = 5u32;
17161        let dt = 0.05_f64;
17162
17163        let mut m_data = vec![0.0_f64; n * n];
17164        for i in 0..n {
17165            m_data[i * n + i] = 1.0 + dt * 2.0;
17166            if i > 0 {
17167                m_data[i * n + (i - 1)] = -dt;
17168            }
17169            if i + 1 < n {
17170                m_data[i * n + (i + 1)] = -dt;
17171            }
17172        }
17173        let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
17174
17175        let mut body = Graph::new("be_traj_body");
17176        let x = body.input("x", Shape::new(&[n], DType::F64));
17177        let m = body.add_node(
17178            Op::Constant { data: m_bytes },
17179            vec![],
17180            Shape::new(&[n, n], DType::F64),
17181        );
17182        let next = body.dense_solve(m, x, Shape::new(&[n], DType::F64));
17183        body.set_outputs(vec![next]);
17184
17185        let mut g = Graph::new("be_traj_outer");
17186        let init = g.input("x0", Shape::new(&[n], DType::F64));
17187        let traj = g.scan_trajectory(init, body, length);
17188        g.set_outputs(vec![traj]);
17189
17190        let init_data: [f64; 4] = [0.0, 1.0, 0.0, 0.0];
17191        let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data)]);
17192        execute_thunks(&sched, arena.raw_buf_mut());
17193        let got = read_arena_f64(&arena, traj, length as usize * n);
17194
17195        // Reference: each step's solve, recorded.
17196        let mut want = Vec::<f64>::with_capacity(length as usize * n);
17197        let mut x_ref = init_data.to_vec();
17198        for _ in 0..length {
17199            let mut a_copy = m_data.clone();
17200            crate::blas::dgesv(&mut a_copy, &mut x_ref, n, 1);
17201            want.extend_from_slice(&x_ref);
17202        }
17203        for i in 0..length as usize * n {
17204            assert!(
17205                (got[i] - want[i]).abs() < 1e-12,
17206                "got[{i}] = {} ref {}",
17207                got[i],
17208                want[i]
17209            );
17210        }
17211
17212        // Sanity: trajectory rows are monotone-decreasing in mass
17213        // (Backward Euler diffuses; boundary leak removes mass).
17214        for t in 1..length as usize {
17215            let prev: f64 = got[(t - 1) * n..t * n].iter().sum();
17216            let curr: f64 = got[t * n..(t + 1) * n].iter().sum();
17217            assert!(
17218                curr <= prev + 1e-15,
17219                "mass should decay: row {} sum {prev}, row {t} sum {curr}",
17220                t - 1
17221            );
17222        }
17223
17224        // Last row of the trajectory equals what a non-trajectory
17225        // scan returns — verify by running the same forward through
17226        // the simpler API and comparing.
17227        let mut body2 = Graph::new("be_final_body");
17228        let x2 = body2.input("x", Shape::new(&[n], DType::F64));
17229        let m_bytes2: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
17230        let m2 = body2.add_node(
17231            Op::Constant { data: m_bytes2 },
17232            vec![],
17233            Shape::new(&[n, n], DType::F64),
17234        );
17235        let next2 = body2.dense_solve(m2, x2, Shape::new(&[n], DType::F64));
17236        body2.set_outputs(vec![next2]);
17237
17238        let mut g2 = Graph::new("be_final_outer");
17239        let init2 = g2.input("x0", Shape::new(&[n], DType::F64));
17240        let final_x = g2.scan(init2, body2, length);
17241        g2.set_outputs(vec![final_x]);
17242        let (sched2, mut arena2) = prepare_f64(&g2, &[(init2, &init_data)]);
17243        execute_thunks(&sched2, arena2.raw_buf_mut());
17244        let final_got = read_arena_f64(&arena2, final_x, n);
17245
17246        let last_row = &got[(length as usize - 1) * n..length as usize * n];
17247        for i in 0..n {
17248            assert!(
17249                (last_row[i] - final_got[i]).abs() < 1e-15,
17250                "last trajectory row[{i}] = {} vs final-scan = {}",
17251                last_row[i],
17252                final_got[i]
17253            );
17254        }
17255    }
17256
17257    /// Op::Scan composing with Op::DenseSolve — the Circulax-shaped
17258    /// pattern for Backward Euler.
17259    /// Body: x_{t+1} = solve(I + dt·A, x_t).
17260    /// 1-D heat-equation Laplacian A; analytic ground truth from
17261    /// composing the same per-step solve in Rust.
17262    #[test]
17263    fn scan_backward_euler_heat_f64() {
17264        let n = 4usize;
17265        let length = 5u32;
17266        let dt = 0.05_f64;
17267
17268        // Construct M = I + dt · L  where L is the Laplacian (-1, 2, -1).
17269        // M is constant across iterations; embed it in the body via Op::Constant.
17270        let mut m_data = vec![0.0_f64; n * n];
17271        for i in 0..n {
17272            m_data[i * n + i] = 1.0 + dt * 2.0;
17273            if i > 0 {
17274                m_data[i * n + (i - 1)] = -dt;
17275            }
17276            if i + 1 < n {
17277                m_data[i * n + (i + 1)] = -dt;
17278            }
17279        }
17280        let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
17281
17282        let mut body = Graph::new("be_body");
17283        let x = body.input("x", Shape::new(&[n], DType::F64));
17284        let m = body.add_node(
17285            Op::Constant { data: m_bytes },
17286            vec![],
17287            Shape::new(&[n, n], DType::F64),
17288        );
17289        let next = body.dense_solve(m, x, Shape::new(&[n], DType::F64));
17290        body.set_outputs(vec![next]);
17291
17292        let mut g = Graph::new("be_outer");
17293        let init = g.input("x0", Shape::new(&[n], DType::F64));
17294        let final_x = g.scan(init, body, length);
17295        g.set_outputs(vec![final_x]);
17296
17297        // Initial: a sharp pulse at index 1.
17298        let init_data: [f64; 4] = [0.0, 1.0, 0.0, 0.0];
17299        let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data)]);
17300        execute_thunks(&sched, arena.raw_buf_mut());
17301        let got = read_arena_f64(&arena, final_x, n);
17302
17303        // Reference: apply the same M-solve `length` times in pure Rust.
17304        let mut ref_x = init_data.to_vec();
17305        for _ in 0..length {
17306            let mut a_copy = m_data.clone();
17307            crate::blas::dgesv(&mut a_copy, &mut ref_x, n, 1);
17308        }
17309        for i in 0..n {
17310            assert!(
17311                (got[i] - ref_x[i]).abs() < 1e-12,
17312                "got[{i}] = {} ref {}",
17313                got[i],
17314                ref_x[i]
17315            );
17316        }
17317        // Sanity: pulse should diffuse, mass should be conserved-ish
17318        // (Backward Euler is mass-conserving for this stencil with
17319        // zero-flux boundaries — but our boundaries leak, so check
17320        // that mass strictly decreases instead).
17321        let mass: f64 = got.iter().sum();
17322        assert!(mass > 0.0 && mass < 1.0, "diffusion mass: {mass}");
17323    }
17324
17325    /// Multi-RHS forward DenseSolve: X = solve(A, B) with B [N, K]
17326    /// stays correct end-to-end. Verifies the executor/lowering and
17327    /// the LAPACK column-major dance both honour `nrhs > 1`.
17328    #[test]
17329    fn dense_solve_f64_multi_rhs_forward() {
17330        let n = 3usize;
17331        let k = 2usize;
17332        let mut g = Graph::new("solve_multi_rhs");
17333        let a = g.input("A", Shape::new(&[n, n], DType::F64));
17334        let b = g.input("B", Shape::new(&[n, k], DType::F64));
17335        let x = g.dense_solve(a, b, Shape::new(&[n, k], DType::F64));
17336        g.set_outputs(vec![x]);
17337
17338        let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
17339        let b_data = [1.0, 4.0, 2.0, -1.0, 3.0, 2.0_f64];
17340        let (sched, mut arena) = prepare_f64(&g, &[(a, &a_data), (b, &b_data)]);
17341        execute_thunks(&sched, arena.raw_buf_mut());
17342        let x_got = read_arena_f64(&arena, x, n * k);
17343        for c in 0..k {
17344            for i in 0..n {
17345                let mut acc = 0.0_f64;
17346                for j in 0..n {
17347                    acc += a_data[i * n + j] * x_got[j * k + c];
17348                }
17349                let want = b_data[i * k + c];
17350                assert!(
17351                    (acc - want).abs() < 1e-10,
17352                    "col {c} row {i}: got {acc} want {want}"
17353                );
17354            }
17355        }
17356    }
17357
17358    /// Multi-RHS reverse-mode VJP: dB = (Aᵀ)⁻¹·1, dA = -dB · Xᵀ.
17359    /// Verified analytically + finite differences on dB[0,0].
17360    #[test]
17361    fn dense_solve_f64_multi_rhs_gradient() {
17362        use rlx_opt::autodiff::grad_with_loss;
17363        let n = 3usize;
17364        let k = 2usize;
17365        let mut g = Graph::new("solve_mrhs_grad");
17366        let a = g.param("A", Shape::new(&[n, n], DType::F64));
17367        let b = g.input("B", Shape::new(&[n, k], DType::F64));
17368        let x = g.dense_solve(a, b, Shape::new(&[n, k], DType::F64));
17369        let loss = g.reduce(
17370            x,
17371            ReduceOp::Sum,
17372            vec![0, 1],
17373            false,
17374            Shape::new(&[1], DType::F64),
17375        );
17376        g.set_outputs(vec![loss]);
17377
17378        let bwd = grad_with_loss(&g, &[a, b]);
17379        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
17380            for node in graph.nodes() {
17381                let name = match &node.op {
17382                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
17383                    _ => None,
17384                };
17385                if name == Some(want) {
17386                    return node.id;
17387                }
17388            }
17389            panic!("no node named {want:?}");
17390        };
17391        let a_bwd = find_by_name(&bwd, "A");
17392        let b_bwd = find_by_name(&bwd, "B");
17393        let d_out = find_by_name(&bwd, "d_output");
17394
17395        let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
17396        let b_data = [1.0, 4.0, 2.0, -1.0, 3.0, 2.0_f64];
17397        let d_seed = [1.0_f64];
17398
17399        let (sched, mut arena) = prepare_f64(
17400            &bwd,
17401            &[(a_bwd, &a_data), (b_bwd, &b_data), (d_out, &d_seed)],
17402        );
17403        execute_thunks(&sched, arena.raw_buf_mut());
17404        let da_got = read_arena_f64(&arena, bwd.outputs[1], n * n);
17405        let db_got = read_arena_f64(&arena, bwd.outputs[2], n * k);
17406
17407        // Reference.
17408        let mut x_ref = b_data;
17409        {
17410            let mut a_copy = a_data;
17411            crate::blas::dgesv(&mut a_copy, &mut x_ref, n, k);
17412        }
17413        let mut at = [0.0_f64; 9];
17414        for i in 0..n {
17415            for j in 0..n {
17416                at[i * n + j] = a_data[j * n + i];
17417            }
17418        }
17419        let mut ones_nk = vec![1.0_f64; n * k];
17420        crate::blas::dgesv(&mut at, &mut ones_nk, n, k);
17421        let db_ref = ones_nk;
17422        let mut da_ref = [0.0_f64; 9];
17423        for i in 0..n {
17424            for j in 0..n {
17425                let mut acc = 0.0_f64;
17426                for c in 0..k {
17427                    acc += db_ref[i * k + c] * x_ref[j * k + c];
17428                }
17429                da_ref[i * n + j] = -acc;
17430            }
17431        }
17432        for i in 0..n * k {
17433            assert!(
17434                (db_got[i] - db_ref[i]).abs() < 1e-10,
17435                "dB[{i}]: got {} want {}",
17436                db_got[i],
17437                db_ref[i]
17438            );
17439        }
17440        for i in 0..n * n {
17441            assert!(
17442                (da_got[i] - da_ref[i]).abs() < 1e-10,
17443                "dA[{i}]: got {} want {}",
17444                da_got[i],
17445                da_ref[i]
17446            );
17447        }
17448
17449        // FD on dB[0,0].
17450        let h = 1e-6;
17451        let mut bp = b_data;
17452        bp[0] += h;
17453        let mut bm = b_data;
17454        bm[0] -= h;
17455        let xp = {
17456            let mut a_copy = a_data;
17457            crate::blas::dgesv(&mut a_copy, &mut bp, n, k);
17458            bp
17459        };
17460        let xm = {
17461            let mut a_copy = a_data;
17462            crate::blas::dgesv(&mut a_copy, &mut bm, n, k);
17463            bm
17464        };
17465        let lp: f64 = xp.iter().sum();
17466        let lm: f64 = xm.iter().sum();
17467        let fd = (lp - lm) / (2.0 * h);
17468        assert!(
17469            (db_got[0] - fd).abs() < 1e-7,
17470            "FD dB[0,0]: AD={} FD={}",
17471            db_got[0],
17472            fd
17473        );
17474    }
17475
17476    /// Multi-RHS forward-mode JVP w.r.t. B. Closed form: t_X = solve(A, t_B).
17477    #[test]
17478    fn dense_solve_f64_multi_rhs_jvp() {
17479        use rlx_opt::autodiff_fwd::jvp;
17480        let n = 3usize;
17481        let k = 2usize;
17482        let mut g = Graph::new("solve_mrhs_jvp");
17483        let a = g.input("A", Shape::new(&[n, n], DType::F64));
17484        let b = g.input("B", Shape::new(&[n, k], DType::F64));
17485        let x = g.dense_solve(a, b, Shape::new(&[n, k], DType::F64));
17486        g.set_outputs(vec![x]);
17487
17488        let jg = jvp(&g, &[b]);
17489        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
17490            for node in graph.nodes() {
17491                let name = match &node.op {
17492                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
17493                    _ => None,
17494                };
17495                if name == Some(want) {
17496                    return node.id;
17497                }
17498            }
17499            panic!("no node named {want:?}");
17500        };
17501        let a_id = find_by_name(&jg, "A");
17502        let b_id = find_by_name(&jg, "B");
17503        let tb_id = find_by_name(&jg, "tangent_B");
17504
17505        let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
17506        let b_data = [1.0, 4.0, 2.0, -1.0, 3.0, 2.0_f64];
17507        let tb_data = [0.5, 0.0, -0.25, 1.0, 1.0, -0.5_f64];
17508
17509        let (sched, mut arena) =
17510            prepare_f64(&jg, &[(a_id, &a_data), (b_id, &b_data), (tb_id, &tb_data)]);
17511        execute_thunks(&sched, arena.raw_buf_mut());
17512        let tangent_x = read_arena_f64(&arena, jg.outputs[1], n * k);
17513
17514        let mut a_copy = a_data;
17515        let mut tb_copy = tb_data;
17516        crate::blas::dgesv(&mut a_copy, &mut tb_copy, n, k);
17517        for i in 0..n * k {
17518            assert!(
17519                (tangent_x[i] - tb_copy[i]).abs() < 1e-10,
17520                "t_X[{i}]: AD={} ref={}",
17521                tangent_x[i],
17522                tb_copy[i]
17523            );
17524        }
17525
17526        let h = 1e-6;
17527        let mut bp = b_data;
17528        let mut bm = b_data;
17529        for i in 0..n * k {
17530            bp[i] += h * tb_data[i];
17531            bm[i] -= h * tb_data[i];
17532        }
17533        let xp = {
17534            let mut a_copy = a_data;
17535            crate::blas::dgesv(&mut a_copy, &mut bp, n, k);
17536            bp
17537        };
17538        let xm = {
17539            let mut a_copy = a_data;
17540            crate::blas::dgesv(&mut a_copy, &mut bm, n, k);
17541            bm
17542        };
17543        for i in 0..n * k {
17544            let fd = (xp[i] - xm[i]) / (2.0 * h);
17545            assert!(
17546                (tangent_x[i] - fd).abs() < 1e-7,
17547                "FD t_X[{i}]: AD={} FD={}",
17548                tangent_x[i],
17549                fd
17550            );
17551        }
17552    }
17553
17554    /// Forward-mode JVP through DenseSolve, end-to-end at f64.
17555    ///
17556    /// Build forward x = solve(A, b), call `jvp(forward, [b])`,
17557    /// compile + run, and check the tangent output matches the
17558    /// closed form `t_x = solve(A, t_b)` plus a finite-difference
17559    /// cross-check `(solve(A, b + h·t_b) − solve(A, b − h·t_b)) / 2h`.
17560    #[test]
17561    fn jvp_dense_solve_b_runs_and_matches_fd() {
17562        use rlx_opt::autodiff_fwd::jvp;
17563        let n = 3usize;
17564
17565        // Forward.
17566        let mut g = Graph::new("jvp_b_e2e");
17567        let a = g.input("A", Shape::new(&[n, n], DType::F64));
17568        let b = g.input("b", Shape::new(&[n], DType::F64));
17569        let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
17570        g.set_outputs(vec![x]);
17571
17572        // JVP graph perturbing b only.
17573        let jg = jvp(&g, &[b]);
17574        // The JVP graph holds a fresh "tangent_b" Input on top of A and b.
17575        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
17576            for node in graph.nodes() {
17577                let name = match &node.op {
17578                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
17579                    _ => None,
17580                };
17581                if name == Some(want) {
17582                    return node.id;
17583                }
17584            }
17585            panic!("no node named {want:?}");
17586        };
17587        let a_id = find_by_name(&jg, "A");
17588        let b_id = find_by_name(&jg, "b");
17589        let tb_id = find_by_name(&jg, "tangent_b");
17590
17591        let a_data: [f64; 9] = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0];
17592        let b_data: [f64; 3] = [1.0, 2.0, 3.0];
17593        // Pick an arbitrary perturbation direction.
17594        let tb_data: [f64; 3] = [0.5, -0.25, 1.0];
17595
17596        let (sched, mut arena) =
17597            prepare_f64(&jg, &[(a_id, &a_data), (b_id, &b_data), (tb_id, &tb_data)]);
17598        execute_thunks(&sched, arena.raw_buf_mut());
17599
17600        // Outputs: [primal_x, tangent_x].
17601        let primal_x = read_arena_f64(&arena, jg.outputs[0], n);
17602        let tangent_x = read_arena_f64(&arena, jg.outputs[1], n);
17603
17604        // Closed form: t_x = solve(A, t_b).
17605        let t_x_ref = {
17606            let mut a = a_data;
17607            let mut tb = tb_data;
17608            let info = crate::blas::dgesv(&mut a, &mut tb, n, 1);
17609            assert_eq!(info, 0);
17610            tb
17611        };
17612        for i in 0..n {
17613            assert!(
17614                (tangent_x[i] - t_x_ref[i]).abs() < 1e-10,
17615                "t_x[{i}]: got {} want {}",
17616                tangent_x[i],
17617                t_x_ref[i]
17618            );
17619        }
17620
17621        // FD: x(b + h·tb) − x(b − h·tb)) / 2h
17622        let h = 1e-6;
17623        let mut bp = b_data;
17624        let mut bm = b_data;
17625        for i in 0..n {
17626            bp[i] += h * tb_data[i];
17627            bm[i] -= h * tb_data[i];
17628        }
17629        let xp = {
17630            let mut a = a_data;
17631            let info = crate::blas::dgesv(&mut a, &mut bp, n, 1);
17632            assert_eq!(info, 0);
17633            bp
17634        };
17635        let xm = {
17636            let mut a = a_data;
17637            let info = crate::blas::dgesv(&mut a, &mut bm, n, 1);
17638            assert_eq!(info, 0);
17639            bm
17640        };
17641        let fd: Vec<f64> = (0..n).map(|i| (xp[i] - xm[i]) / (2.0 * h)).collect();
17642        for i in 0..n {
17643            assert!(
17644                (tangent_x[i] - fd[i]).abs() < 1e-7,
17645                "FD mismatch t_x[{i}]: AD={} FD={}",
17646                tangent_x[i],
17647                fd[i]
17648            );
17649        }
17650        // Sanity: primal output is the actual solve.
17651        let primal_ref = {
17652            let mut a = a_data;
17653            let mut b = b_data;
17654            crate::blas::dgesv(&mut a, &mut b, n, 1);
17655            b
17656        };
17657        for i in 0..n {
17658            assert!((primal_x[i] - primal_ref[i]).abs() < 1e-10);
17659        }
17660    }
17661
17662    /// Forward-mode JVP through DenseSolve perturbing A. The tangent
17663    /// path includes the −t_A·x correction term.
17664    /// `t_x = −solve(A, t_A · x)` should match a finite-difference
17665    /// directional derivative of `solve(A, b)` w.r.t. A in the
17666    /// `t_A` direction.
17667    #[test]
17668    fn jvp_dense_solve_a_runs_and_matches_fd() {
17669        use rlx_opt::autodiff_fwd::jvp;
17670        let n = 3usize;
17671
17672        let mut g = Graph::new("jvp_a_e2e");
17673        let a = g.input("A", Shape::new(&[n, n], DType::F64));
17674        let b = g.input("b", Shape::new(&[n], DType::F64));
17675        let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
17676        g.set_outputs(vec![x]);
17677
17678        let jg = jvp(&g, &[a]);
17679        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
17680            for node in graph.nodes() {
17681                let name = match &node.op {
17682                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
17683                    _ => None,
17684                };
17685                if name == Some(want) {
17686                    return node.id;
17687                }
17688            }
17689            panic!("no node named {want:?}");
17690        };
17691        let a_id = find_by_name(&jg, "A");
17692        let b_id = find_by_name(&jg, "b");
17693        let ta_id = find_by_name(&jg, "tangent_A");
17694
17695        let a_data: [f64; 9] = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0];
17696        let b_data: [f64; 3] = [1.0, 2.0, 3.0];
17697        // Asymmetric perturbation direction for A.
17698        let ta_data: [f64; 9] = [0.10, -0.05, 0.02, 0.03, 0.20, -0.04, -0.01, 0.07, 0.15];
17699
17700        let (sched, mut arena) =
17701            prepare_f64(&jg, &[(a_id, &a_data), (b_id, &b_data), (ta_id, &ta_data)]);
17702        execute_thunks(&sched, arena.raw_buf_mut());
17703
17704        let tangent_x = read_arena_f64(&arena, jg.outputs[1], n);
17705
17706        // Closed form: x = solve(A, b); t_x = −solve(A, t_A · x).
17707        let x_ref = {
17708            let mut a = a_data;
17709            let mut b = b_data;
17710            crate::blas::dgesv(&mut a, &mut b, n, 1);
17711            b
17712        };
17713        let mut prod = [0.0_f64; 3];
17714        for i in 0..n {
17715            for j in 0..n {
17716                prod[i] += ta_data[i * n + j] * x_ref[j];
17717            }
17718        }
17719        let t_x_ref = {
17720            let mut a = a_data;
17721            let mut p = prod;
17722            crate::blas::dgesv(&mut a, &mut p, n, 1);
17723            [-p[0], -p[1], -p[2]]
17724        };
17725        for i in 0..n {
17726            assert!(
17727                (tangent_x[i] - t_x_ref[i]).abs() < 1e-10,
17728                "closed-form t_x[{i}]: AD={} ref={}",
17729                tangent_x[i],
17730                t_x_ref[i]
17731            );
17732        }
17733
17734        // FD: solve(A + h·t_A, b) and solve(A − h·t_A, b).
17735        let h = 1e-6;
17736        let mut ap = a_data;
17737        let mut am = a_data;
17738        for i in 0..n * n {
17739            ap[i] += h * ta_data[i];
17740            am[i] -= h * ta_data[i];
17741        }
17742        let xp = {
17743            let mut a = ap;
17744            let mut b = b_data;
17745            crate::blas::dgesv(&mut a, &mut b, n, 1);
17746            b
17747        };
17748        let xm = {
17749            let mut a = am;
17750            let mut b = b_data;
17751            crate::blas::dgesv(&mut a, &mut b, n, 1);
17752            b
17753        };
17754        for i in 0..n {
17755            let fd = (xp[i] - xm[i]) / (2.0 * h);
17756            assert!(
17757                (tangent_x[i] - fd).abs() < 1e-7,
17758                "FD t_x[{i}]: AD={} FD={}",
17759                tangent_x[i],
17760                fd
17761            );
17762        }
17763    }
17764
17765    /// Real INT8 conv2d parity. Same setup as QMatMul: pre-quantize
17766    /// f32 inputs to i8, run `Op::QConv2d`, compare against an
17767    /// in-test reference loop that does the same i32 accumulation
17768    /// and requantize math. Symmetric quant (zp=0) to keep the math
17769    /// head-to-head.
17770    #[test]
17771    fn q_conv2d_matches_reference() {
17772        use rlx_ir::Philox4x32;
17773        // Small NCHW shape — enough to exercise stride/padding edges.
17774        let n = 1usize;
17775        let c_in = 2usize;
17776        let h = 5usize;
17777        let w_in = 5usize;
17778        let c_out = 3usize;
17779        let kh = 3usize;
17780        let kw = 3usize;
17781        let ph = 1usize;
17782        let pw = 1usize;
17783        let sh = 1usize;
17784        let sw = 1usize;
17785        let h_out = (h + 2 * ph - kh) / sh + 1;
17786        let w_out = (w_in + 2 * pw - kw) / sw + 1;
17787
17788        let x_scale = 0.04f32;
17789        let w_scale = 0.02f32;
17790        let out_scale = 0.5f32;
17791        let mult = x_scale * w_scale / out_scale;
17792
17793        let mut rng = Philox4x32::new(2099);
17794        let mut xf = vec![0f32; n * c_in * h * w_in];
17795        rng.fill_normal(&mut xf);
17796        let mut wf = vec![0f32; c_out * c_in * kh * kw];
17797        rng.fill_normal(&mut wf);
17798        let xq: Vec<i8> = xf
17799            .iter()
17800            .map(|&v| ((v / x_scale).round() as i32).clamp(-128, 127) as i8)
17801            .collect();
17802        let wq: Vec<i8> = wf
17803            .iter()
17804            .map(|&v| ((v / w_scale).round() as i32).clamp(-128, 127) as i8)
17805            .collect();
17806        let bias: Vec<i32> = vec![0i32; c_out];
17807
17808        let mut g = Graph::new("qconv");
17809        let xn = g.input("x", Shape::new(&[n, c_in, h, w_in], DType::I8));
17810        let wn = g.input("w", Shape::new(&[c_out, c_in, kh, kw], DType::I8));
17811        let bn = g.input("b", Shape::new(&[c_out], DType::I32));
17812        let out = g.q_conv2d(
17813            xn,
17814            wn,
17815            bn,
17816            vec![kh, kw],
17817            vec![sh, sw],
17818            vec![ph, pw],
17819            vec![1, 1],
17820            1,
17821            0,
17822            0,
17823            0,
17824            mult,
17825            Shape::new(&[n, c_out, h_out, w_out], DType::I8),
17826        );
17827        g.set_outputs(vec![out]);
17828
17829        let plan = rlx_opt::memory::plan_memory(&g);
17830        let mut arena = crate::arena::Arena::from_plan(plan);
17831        let sched = compile_thunks(&g, &arena);
17832        // Capture offsets before borrowing the buf mutably (avoids
17833        // overlap between &mut and the &arena.byte_offset reads).
17834        let xn_off = arena.byte_offset(xn);
17835        let wn_off = arena.byte_offset(wn);
17836        let bn_off = arena.byte_offset(bn);
17837        let out_off = arena.byte_offset(out);
17838        let buf = arena.raw_buf_mut();
17839        unsafe {
17840            let p = buf.as_mut_ptr().add(xn_off) as *mut i8;
17841            for (i, &v) in xq.iter().enumerate() {
17842                *p.add(i) = v;
17843            }
17844            let p = buf.as_mut_ptr().add(wn_off) as *mut i8;
17845            for (i, &v) in wq.iter().enumerate() {
17846                *p.add(i) = v;
17847            }
17848            let p = buf.as_mut_ptr().add(bn_off) as *mut i32;
17849            for (i, &v) in bias.iter().enumerate() {
17850                *p.add(i) = v;
17851            }
17852        }
17853        execute_thunks(&sched, arena.raw_buf_mut());
17854        let out_q: Vec<i8> = unsafe {
17855            let p = arena.raw_buf().as_ptr().add(out_off) as *const i8;
17856            (0..n * c_out * h_out * w_out).map(|i| *p.add(i)).collect()
17857        };
17858
17859        // Reference: scalar loop in NCHW with the same requantize.
17860        let mut out_ref = vec![0i8; n * c_out * h_out * w_out];
17861        for ni in 0..n {
17862            for co in 0..c_out {
17863                for ho in 0..h_out {
17864                    for wo in 0..w_out {
17865                        let mut acc: i32 = 0;
17866                        for ci in 0..c_in {
17867                            for ki in 0..kh {
17868                                for kj in 0..kw {
17869                                    let hi = ho * sh + ki;
17870                                    let wi = wo * sw + kj;
17871                                    if hi < ph || wi < pw {
17872                                        continue;
17873                                    }
17874                                    let hi = hi - ph;
17875                                    let wi = wi - pw;
17876                                    if hi >= h || wi >= w_in {
17877                                        continue;
17878                                    }
17879                                    let xv =
17880                                        xq[((ni * c_in) + ci) * h * w_in + hi * w_in + wi] as i32;
17881                                    let wv = wq[((co * c_in) + ci) * kh * kw + ki * kw + kj] as i32;
17882                                    acc += xv * wv;
17883                                }
17884                            }
17885                        }
17886                        let r = (acc as f32 * mult).round() as i32;
17887                        let r = r.clamp(-128, 127) as i8;
17888                        out_ref[((ni * c_out) + co) * h_out * w_out + ho * w_out + wo] = r;
17889                    }
17890                }
17891            }
17892        }
17893
17894        for (i, (a, r)) in out_q.iter().zip(&out_ref).enumerate() {
17895            assert_eq!(a, r, "q_conv2d[{i}]: kernel {a} vs reference {r}");
17896        }
17897    }
17898
17899    /// Real INT8 matmul parity: compare `Op::QMatMul` against the
17900    /// fake-quant reference `Dequantize → MatMul → Quantize` that
17901    /// would produce the same output if we round-tripped through
17902    /// f32. Both should agree element-for-element (or within ±1 i8
17903    /// step, since rounding in the requantize uses different code
17904    /// paths). Symmetric quantization (zp=0) for both paths to keep
17905    /// the math head-to-head.
17906    #[test]
17907    fn q_matmul_matches_fake_quant_reference() {
17908        use rlx_ir::Philox4x32;
17909        let m = 3usize;
17910        let k = 8usize;
17911        let n = 5usize;
17912        let mut rng = Philox4x32::new(2031);
17913
17914        // Pick scales and quantize random f32 inputs to i8.
17915        let x_scale = 0.05f32;
17916        let w_scale = 0.03f32;
17917        let out_scale = 0.4f32;
17918        let mult = x_scale * w_scale / out_scale;
17919        let mut xf = vec![0f32; m * k];
17920        rng.fill_normal(&mut xf);
17921        let mut wf = vec![0f32; k * n];
17922        rng.fill_normal(&mut wf);
17923        let xq: Vec<i8> = xf
17924            .iter()
17925            .map(|&v| ((v / x_scale).round() as i32).clamp(-128, 127) as i8)
17926            .collect();
17927        let wq: Vec<i8> = wf
17928            .iter()
17929            .map(|&v| ((v / w_scale).round() as i32).clamp(-128, 127) as i8)
17930            .collect();
17931        let bias: Vec<i32> = vec![0i32; n];
17932
17933        // ── Direct INT8 path ──
17934        let _f = DType::F32;
17935        let mut g_q = Graph::new("qmm_direct");
17936        let xn = g_q.input("x", Shape::new(&[m, k], DType::I8));
17937        let wn = g_q.input("w", Shape::new(&[k, n], DType::I8));
17938        let bn = g_q.input("b", Shape::new(&[n], DType::I32));
17939        let out = g_q.q_matmul(xn, wn, bn, 0, 0, 0, mult, Shape::new(&[m, n], DType::I8));
17940        g_q.set_outputs(vec![out]);
17941        let plan = rlx_opt::memory::plan_memory(&g_q);
17942        let mut arena = crate::arena::Arena::from_plan(plan);
17943        let sched = compile_thunks(&g_q, &arena);
17944
17945        // Fill inputs.
17946        let xn_off = arena.byte_offset(xn);
17947        let wn_off = arena.byte_offset(wn);
17948        let bn_off = arena.byte_offset(bn);
17949        let out_off = arena.byte_offset(out);
17950        let buf = arena.raw_buf_mut();
17951        unsafe {
17952            let p = buf.as_mut_ptr().add(xn_off) as *mut i8;
17953            for (i, &v) in xq.iter().enumerate() {
17954                *p.add(i) = v;
17955            }
17956            let p = buf.as_mut_ptr().add(wn_off) as *mut i8;
17957            for (i, &v) in wq.iter().enumerate() {
17958                *p.add(i) = v;
17959            }
17960            let p = buf.as_mut_ptr().add(bn_off) as *mut i32;
17961            for (i, &v) in bias.iter().enumerate() {
17962                *p.add(i) = v;
17963            }
17964        }
17965        execute_thunks(&sched, arena.raw_buf_mut());
17966        let out_q: Vec<i8> = unsafe {
17967            let p = arena.raw_buf().as_ptr().add(out_off) as *const i8;
17968            (0..m * n).map(|i| *p.add(i)).collect()
17969        };
17970
17971        // ── Fake-quant reference: scalar emulation in plain Rust ──
17972        // Same arithmetic the kernel does, but in a verifier loop:
17973        //   acc = Σ (x[m,k]) · (w[k,n]),  // zps are 0
17974        //   out[m,n] = saturate_i8(round(acc · mult) + 0)
17975        let mut out_ref = vec![0i8; m * n];
17976        for mi in 0..m {
17977            for ni in 0..n {
17978                let mut acc: i32 = 0;
17979                for ki in 0..k {
17980                    acc += (xq[mi * k + ki] as i32) * (wq[ki * n + ni] as i32);
17981                }
17982                let r = (acc as f32 * mult).round() as i32;
17983                out_ref[mi * n + ni] = r.clamp(-128, 127) as i8;
17984            }
17985        }
17986
17987        for (i, (a, r)) in out_q.iter().zip(&out_ref).enumerate() {
17988            assert_eq!(a, r, "q_matmul[{i}]: kernel {a} vs reference {r}");
17989        }
17990    }
17991
17992    /// Quantize/Dequantize round-trip — quantize an f32 tensor, then
17993    /// dequantize back, and confirm the result tracks the input
17994    /// within the per-element scale (the inevitable rounding error).
17995    /// Also pins the kernel's saturation behavior at the i8 limits.
17996    #[test]
17997    fn quantize_dequantize_round_trip() {
17998        use rlx_ir::Philox4x32;
17999        let len = 64;
18000        let mut rng = Philox4x32::new(2027);
18001        let mut x = vec![0f32; len];
18002        rng.fill_normal(&mut x);
18003        // Stretch a couple values past the +/- saturation cliff so
18004        // the saturate_i8 path is exercised.
18005        x[0] = 999.0;
18006        x[1] = -999.0;
18007
18008        let scale = 0.05f32;
18009        let zp = 3i32;
18010
18011        let f = DType::F32;
18012        let mut g = Graph::new("qdq");
18013        let xn = g.input("x", Shape::new(&[len], f));
18014        let q = g.quantize(xn, scale, zp);
18015        let dq = g.dequantize(q, scale, zp);
18016        g.set_outputs(vec![dq]);
18017
18018        let plan = rlx_opt::memory::plan_memory(&g);
18019        let mut arena = crate::arena::Arena::from_plan(plan);
18020        let sched = compile_thunks(&g, &arena);
18021        let xn_off = arena.byte_offset(xn);
18022        let dq_off = arena.byte_offset(dq);
18023        let buf = arena.raw_buf_mut();
18024        unsafe {
18025            let p = buf.as_mut_ptr().add(xn_off) as *mut f32;
18026            for (i, &v) in x.iter().enumerate() {
18027                *p.add(i) = v;
18028            }
18029        }
18030        execute_thunks(&sched, arena.raw_buf_mut());
18031        let out: Vec<f32> = unsafe {
18032            let p = arena.raw_buf().as_ptr().add(dq_off) as *const f32;
18033            (0..len).map(|i| *p.add(i)).collect()
18034        };
18035
18036        // Saturated values at i=0,1 should clamp to ±127's dequant
18037        // range (= (±127 - zp) · scale).
18038        let sat_pos = (127 - zp) as f32 * scale;
18039        let sat_neg = (-128 - zp) as f32 * scale;
18040        assert!((out[0] - sat_pos).abs() < 1e-6, "+sat: {}", out[0]);
18041        assert!((out[1] - sat_neg).abs() < 1e-6, "-sat: {}", out[1]);
18042
18043        // Everything else should round-trip within `scale` (one quant
18044        // step = the worst-case rounding error).
18045        for i in 2..len {
18046            assert!(
18047                (out[i] - x[i]).abs() <= scale + 1e-5,
18048                "qdq[{i}]: {} → {}, scale={scale}",
18049                x[i],
18050                out[i]
18051            );
18052        }
18053    }
18054
18055    /// Per-channel quantize / dequantize: independent scale and zp
18056    /// per slice along an axis. Verifies (a) each channel uses its
18057    /// own scale (not a shared one), (b) saturation still respects
18058    /// the i8 range, (c) channel data layout decomposition is
18059    /// correct (no cross-channel leakage).
18060    #[test]
18061    fn quantize_per_channel_round_trip() {
18062        let c = 4usize;
18063        let inner = 5usize;
18064        // Different magnitudes per channel — proves the per-channel
18065        // scale is actually being read for each row.
18066        let mags = [0.01f32, 0.5, 5.0, 50.0];
18067        let mut x = vec![0f32; c * inner];
18068        for ci in 0..c {
18069            for ii in 0..inner {
18070                // Sweep through values that span [-max_abs, +max_abs]
18071                // for each channel, plus one value past the cliff to
18072                // trigger saturation.
18073                x[ci * inner + ii] = match ii {
18074                    0 => -mags[ci],
18075                    1 => 0.0,
18076                    2 => mags[ci],
18077                    3 => mags[ci] * 1000.0,  // saturates +
18078                    _ => -mags[ci] * 1000.0, // saturates -
18079                };
18080            }
18081        }
18082        let scales: Vec<f32> = mags.iter().map(|&m| m / 127.0).collect();
18083        let zps: Vec<i32> = vec![0, 0, 0, 0];
18084
18085        let f = DType::F32;
18086        let mut g = Graph::new("qdq_pc");
18087        let xn = g.input("x", Shape::new(&[c, inner], f));
18088        let q = g.quantize_per_channel(xn, 0, scales.clone(), zps.clone());
18089        let dq = g.dequantize_per_channel(q, 0, scales.clone(), zps);
18090        g.set_outputs(vec![dq]);
18091
18092        let plan = rlx_opt::memory::plan_memory(&g);
18093        let mut arena = crate::arena::Arena::from_plan(plan);
18094        let sched = compile_thunks(&g, &arena);
18095        let xn_off = arena.byte_offset(xn);
18096        let dq_off = arena.byte_offset(dq);
18097        let buf = arena.raw_buf_mut();
18098        unsafe {
18099            let p = buf.as_mut_ptr().add(xn_off) as *mut f32;
18100            for (i, &v) in x.iter().enumerate() {
18101                *p.add(i) = v;
18102            }
18103        }
18104        execute_thunks(&sched, arena.raw_buf_mut());
18105        let out: Vec<f32> = unsafe {
18106            let p = arena.raw_buf().as_ptr().add(dq_off) as *const f32;
18107            (0..c * inner).map(|i| *p.add(i)).collect()
18108        };
18109
18110        for ci in 0..c {
18111            // Within-range entries (positions 0, 1, 2) must round-trip
18112            // within one quant step of *that channel's* scale.
18113            for ii in 0..3 {
18114                let idx = ci * inner + ii;
18115                assert!(
18116                    (out[idx] - x[idx]).abs() <= scales[ci] + 1e-5,
18117                    "ch {ci} idx {ii}: {} vs {}",
18118                    x[idx],
18119                    out[idx]
18120                );
18121            }
18122            // Saturated positions clamp to ±127 · scale[ci].
18123            let sat_pos = 127.0 * scales[ci];
18124            let sat_neg = -128.0 * scales[ci];
18125            assert!(
18126                (out[ci * inner + 3] - sat_pos).abs() < 1e-5,
18127                "ch {ci} +sat: {}",
18128                out[ci * inner + 3]
18129            );
18130            assert!(
18131                (out[ci * inner + 4] - sat_neg).abs() < 1e-5,
18132                "ch {ci} -sat: {}",
18133                out[ci * inner + 4]
18134            );
18135        }
18136    }
18137
18138    /// `Op::ActivationBackward` parity for every supported kind.
18139    /// Builds a single-op graph `dx = activation_backward(x, dy)` and
18140    /// compares each `dx[i]` to the central-difference `(act(x+ε) -
18141    /// act(x-ε)) / (2ε) · dy\[i\]`. Sweeps the closed-form covered by
18142    /// the kernel.
18143    #[test]
18144    fn activation_backward_matches_numerical_per_kind() {
18145        use rlx_ir::Philox4x32;
18146        use rlx_ir::op::Activation;
18147        let mut rng = Philox4x32::new(91);
18148        let len = 32;
18149        // x sampled away from kink/branch points: shifted positive
18150        // (exp/sqrt/log domain) for the unary-positive activations;
18151        // wide range otherwise. Two parallel tests would be cleaner
18152        // but this is concise enough.
18153        let mut x_pos = vec![0f32; len];
18154        rng.fill_normal(&mut x_pos);
18155        for v in x_pos.iter_mut() {
18156            *v = v.abs() + 0.5;
18157        }
18158        let mut x_any = vec![0f32; len];
18159        rng.fill_normal(&mut x_any);
18160        let mut dy = vec![0f32; len];
18161        rng.fill_normal(&mut dy);
18162
18163        for &(kind, x_data, eps, tol) in &[
18164            (Activation::Sigmoid, &x_any[..], 1e-3, 5e-3),
18165            (Activation::Tanh, &x_any[..], 1e-3, 5e-3),
18166            (Activation::Silu, &x_any[..], 1e-3, 5e-3),
18167            (Activation::Gelu, &x_any[..], 1e-3, 5e-3),
18168            (Activation::GeluApprox, &x_any[..], 1e-3, 5e-3),
18169            (Activation::Exp, &x_any[..], 1e-4, 5e-3),
18170            (Activation::Log, &x_pos[..], 1e-4, 5e-3),
18171            (Activation::Sqrt, &x_pos[..], 1e-4, 5e-3),
18172            (Activation::Rsqrt, &x_pos[..], 1e-4, 5e-3),
18173            (Activation::Neg, &x_any[..], 1e-3, 5e-4),
18174        ] {
18175            let f = DType::F32;
18176            let mut g = Graph::new("act_bw");
18177            let xn = g.input("x", Shape::new(&[len], f));
18178            let dyn_ = g.input("dy", Shape::new(&[len], f));
18179            let dx = g.activation_backward(kind, xn, dyn_);
18180            g.set_outputs(vec![dx]);
18181
18182            let plan = rlx_opt::memory::plan_memory(&g);
18183            let mut arena = crate::arena::Arena::from_plan(plan);
18184            let sched = compile_thunks(&g, &arena);
18185
18186            let xn_off = arena.byte_offset(xn);
18187            let dyn_off = arena.byte_offset(dyn_);
18188            let dx_off = arena.byte_offset(dx);
18189            let buf = arena.raw_buf_mut();
18190            unsafe {
18191                let p = buf.as_mut_ptr().add(xn_off) as *mut f32;
18192                for (i, &v) in x_data.iter().enumerate() {
18193                    *p.add(i) = v;
18194                }
18195                let p = buf.as_mut_ptr().add(dyn_off) as *mut f32;
18196                for (i, &v) in dy.iter().enumerate() {
18197                    *p.add(i) = v;
18198                }
18199            }
18200            execute_thunks(&sched, arena.raw_buf_mut());
18201            let analytical: Vec<f32> = unsafe {
18202                let p = arena.raw_buf().as_ptr().add(dx_off) as *const f32;
18203                (0..len).map(|i| *p.add(i)).collect()
18204            };
18205
18206            // Apply the forward activation manually; finite-difference
18207            // each element.
18208            let act_apply = |kind: Activation, x: f32| -> f32 {
18209                match kind {
18210                    Activation::Sigmoid => 1.0 / (1.0 + (-x).exp()),
18211                    Activation::Tanh => x.tanh(),
18212                    Activation::Silu => x / (1.0 + (-x).exp()),
18213                    Activation::Gelu => {
18214                        // Match the kernel's exact erf form.
18215                        const INV_SQRT2: f32 = 0.707_106_77;
18216                        0.5 * x * (1.0 + erf_f32(x * INV_SQRT2))
18217                    }
18218                    Activation::GeluApprox => {
18219                        const C: f32 = 0.797_884_6;
18220                        const A: f32 = 0.044_715;
18221                        let inner = C * (x + A * x * x * x);
18222                        0.5 * x * (1.0 + inner.tanh())
18223                    }
18224                    Activation::Exp => x.exp(),
18225                    Activation::Log => x.ln(),
18226                    Activation::Sqrt => x.sqrt(),
18227                    Activation::Rsqrt => 1.0 / x.sqrt(),
18228                    Activation::Neg => -x,
18229                    Activation::Relu => x.max(0.0),
18230                    Activation::Abs => x.abs(),
18231                    Activation::Round => x.round(),
18232                    Activation::Sin => x.sin(),
18233                    Activation::Cos => x.cos(),
18234                    Activation::Tan => x.tan(),
18235                    Activation::Atan => x.atan(),
18236                }
18237            };
18238            for i in 0..len {
18239                let xv = x_data[i];
18240                let plus = act_apply(kind, xv + eps);
18241                let minus = act_apply(kind, xv - eps);
18242                let num = (plus - minus) / (2.0 * eps) * dy[i];
18243                assert!(
18244                    (analytical[i] - num).abs() < tol,
18245                    "{kind:?}[{i}]: analytical {} vs numerical {num}",
18246                    analytical[i]
18247                );
18248            }
18249        }
18250    }
18251
18252    /// Batched 3-D MatMul VJP — the transformer-attention shape
18253    /// `[B, M, K] @ [B, K, N] = [B, M, N]`. Both gradients flow through
18254    /// `Op::Transpose` with a perm that swaps the last two dims.
18255    #[test]
18256    fn matmul_3d_gradient_matches_numerical() {
18257        use rlx_ir::Philox4x32;
18258        let batch = 2usize;
18259        let m = 3usize;
18260        let k = 4usize;
18261        let n = 5usize;
18262        let mut rng = Philox4x32::new(101);
18263        let mut a_data = vec![0f32; batch * m * k];
18264        rng.fill_normal(&mut a_data);
18265        let mut b_data = vec![0f32; batch * k * n];
18266        rng.fill_normal(&mut b_data);
18267
18268        let f = DType::F32;
18269        let mut fwd = Graph::new("matmul_3d");
18270        let an = fwd.input("a", Shape::new(&[batch, m, k], f));
18271        let bp = fwd.param("b", Shape::new(&[batch, k, n], f));
18272        let mm = fwd.matmul(an, bp, Shape::new(&[batch, m, n], f));
18273        let loss = fwd.add_node(
18274            Op::Reduce {
18275                op: ReduceOp::Sum,
18276                axes: vec![0, 1, 2],
18277                keep_dim: false,
18278            },
18279            vec![mm],
18280            Shape::from_dims(&[], f),
18281        );
18282        fwd.set_outputs(vec![loss]);
18283
18284        let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[bp]);
18285        let d_out = bwd_graph
18286            .nodes()
18287            .iter()
18288            .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
18289            .map(|n| n.id)
18290            .unwrap();
18291
18292        let plan = rlx_opt::memory::plan_memory(&bwd_graph);
18293        let mut arena = crate::arena::Arena::from_plan(plan);
18294        let sched = compile_thunks(&bwd_graph, &arena);
18295        for &(id, data) in &[(an, &a_data), (bp, &b_data), (d_out, &vec![1.0f32])] {
18296            let off = arena.byte_offset(id);
18297            let buf = arena.raw_buf_mut();
18298            unsafe {
18299                let p = buf.as_mut_ptr().add(off) as *mut f32;
18300                for (i, &v) in data.iter().enumerate() {
18301                    *p.add(i) = v;
18302                }
18303            }
18304        }
18305        execute_thunks(&sched, arena.raw_buf_mut());
18306        let gb_id = bwd_graph.outputs[1];
18307        let g_b: Vec<f32> = unsafe {
18308            let p = arena.raw_buf().as_ptr().add(arena.byte_offset(gb_id)) as *const f32;
18309            (0..batch * k * n).map(|i| *p.add(i)).collect()
18310        };
18311
18312        // Numerical gradient: differentiate sum(a @ b) w.r.t. each b entry.
18313        let forward_loss = |b_vals: &[f32]| -> f32 {
18314            let mut out = vec![0f32; batch * m * n];
18315            for bi in 0..batch {
18316                for mi in 0..m {
18317                    for ni in 0..n {
18318                        let mut acc = 0f32;
18319                        for ki in 0..k {
18320                            acc +=
18321                                a_data[bi * m * k + mi * k + ki] * b_vals[bi * k * n + ki * n + ni];
18322                        }
18323                        out[bi * m * n + mi * n + ni] = acc;
18324                    }
18325                }
18326            }
18327            out.iter().sum()
18328        };
18329        let eps = 1e-3f32;
18330        let mut bp_p = b_data.clone();
18331        let mut g_b_num = vec![0f32; b_data.len()];
18332        for i in 0..b_data.len() {
18333            let s = bp_p[i];
18334            bp_p[i] = s + eps;
18335            let lp = forward_loss(&bp_p);
18336            bp_p[i] = s - eps;
18337            let lm = forward_loss(&bp_p);
18338            bp_p[i] = s;
18339            g_b_num[i] = (lp - lm) / (2.0 * eps);
18340        }
18341        for (i, (a, n)) in g_b.iter().zip(&g_b_num).enumerate() {
18342            assert!(
18343                (a - n).abs() < 5e-3,
18344                "matmul_3d g_b[{i}]: analytical {a} vs numerical {n}"
18345            );
18346        }
18347    }
18348
18349    /// Composed `Op::Softmax` VJP — the gradient is built from
18350    /// `mul + reduce_sum + expand + sub + mul`, no dedicated
18351    /// SoftmaxBackward kernel. Verifies the closed-form
18352    /// `dx = y · (g - Σ y·g)` matches the FD gradient over a small
18353    /// 2-D logits tensor.
18354    #[test]
18355    fn softmax_gradient_matches_numerical() {
18356        use rlx_ir::Philox4x32;
18357        let n = 3usize;
18358        let c = 5usize;
18359        let mut rng = Philox4x32::new(57);
18360        let mut x_data = vec![0f32; n * c];
18361        rng.fill_normal(&mut x_data);
18362
18363        let f = DType::F32;
18364        let mut fwd = Graph::new("softmax_only");
18365        let xn = fwd.input("x", Shape::new(&[n, c], f));
18366        let sm = fwd.add_node(Op::Softmax { axis: -1 }, vec![xn], Shape::new(&[n, c], f));
18367        // Loss = sum(softmax · target) for some random fixed target —
18368        // any linear loss will do; sum-of-all is the simplest and gives
18369        // a uniform gradient flow into the softmax.
18370        let loss = fwd.add_node(
18371            Op::Reduce {
18372                op: ReduceOp::Sum,
18373                axes: vec![0, 1],
18374                keep_dim: false,
18375            },
18376            vec![sm],
18377            Shape::from_dims(&[], f),
18378        );
18379        fwd.set_outputs(vec![loss]);
18380
18381        // `wrt = [xn]` — autodiff exposes the gradient w.r.t. the
18382        // input so we can compare it directly. The forward NodeId for
18383        // `xn` doubles as its bwd-graph mirror.
18384        let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[xn]);
18385        let d_out = bwd_graph
18386            .nodes()
18387            .iter()
18388            .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
18389            .map(|n| n.id)
18390            .unwrap();
18391
18392        let plan = rlx_opt::memory::plan_memory(&bwd_graph);
18393        let mut arena = crate::arena::Arena::from_plan(plan);
18394        let sched = compile_thunks(&bwd_graph, &arena);
18395        for &(id, data) in &[(xn, &x_data), (d_out, &vec![1.0f32])] {
18396            let off = arena.byte_offset(id);
18397            let buf = arena.raw_buf_mut();
18398            unsafe {
18399                let p = buf.as_mut_ptr().add(off) as *mut f32;
18400                for (i, &v) in data.iter().enumerate() {
18401                    *p.add(i) = v;
18402                }
18403            }
18404        }
18405        execute_thunks(&sched, arena.raw_buf_mut());
18406        let g_x_id = bwd_graph.outputs[1];
18407        let g_x: Vec<f32> = unsafe {
18408            let p = arena.raw_buf().as_ptr().add(arena.byte_offset(g_x_id)) as *const f32;
18409            (0..n * c).map(|i| *p.add(i)).collect()
18410        };
18411
18412        // Loss derivative: softmax sums to 1 per row → d/dx_i sum(softmax) = 0
18413        // analytically. So expect g_x ≈ 0 within FD precision. (This
18414        // doubles as a strong sanity check for the composition.)
18415        let forward_loss = |x: &[f32]| -> f32 {
18416            let mut total = 0f32;
18417            for ni in 0..n {
18418                let row = &x[ni * c..(ni + 1) * c];
18419                let m = row.iter().fold(f32::NEG_INFINITY, |a, &v| a.max(v));
18420                let denom: f32 = row.iter().map(|&v| (v - m).exp()).sum();
18421                for &v in row {
18422                    total += (v - m).exp() / denom;
18423                }
18424            }
18425            total
18426        };
18427        let eps = 1e-3f32;
18428        let mut p = x_data.clone();
18429        for i in 0..x_data.len() {
18430            let s = p[i];
18431            p[i] = s + eps;
18432            let lp = forward_loss(&p);
18433            p[i] = s - eps;
18434            let lm = forward_loss(&p);
18435            p[i] = s;
18436            let num = (lp - lm) / (2.0 * eps);
18437            assert!(
18438                (g_x[i] - num).abs() < 5e-3,
18439                "softmax g_x[{i}]: analytical {} vs numerical {num}",
18440                g_x[i]
18441            );
18442        }
18443    }
18444
18445    /// LayerNorm VJP — three gradients in one pass:
18446    ///   d_x via `LayerNormBackwardInput`,
18447    ///   d_gamma via `LayerNormBackwardGamma`,
18448    ///   d_beta = `unbroadcast(upstream)` to gamma's shape.
18449    #[test]
18450    fn layer_norm_gradient_matches_numerical() {
18451        use rlx_ir::Philox4x32;
18452        let rows = 3usize;
18453        let h = 6usize;
18454        let mut rng = Philox4x32::new(1009);
18455        let mut x_data = vec![0f32; rows * h];
18456        rng.fill_normal(&mut x_data);
18457        let mut g_data = vec![0f32; h];
18458        rng.fill_normal(&mut g_data);
18459        for v in g_data.iter_mut() {
18460            *v = v.abs() + 0.5;
18461        }
18462        let mut b_data = vec![0f32; h];
18463        rng.fill_normal(&mut b_data);
18464        let eps = 1e-5f32;
18465
18466        let f = DType::F32;
18467        let mut fwd = Graph::new("ln_only");
18468        let xn = fwd.input("x", Shape::new(&[rows, h], f));
18469        let gp = fwd.param("gamma", Shape::new(&[h], f));
18470        let bp = fwd.param("beta", Shape::new(&[h], f));
18471        let ln = fwd.add_node(
18472            Op::LayerNorm { axis: -1, eps },
18473            vec![xn, gp, bp],
18474            Shape::new(&[rows, h], f),
18475        );
18476        let loss = fwd.add_node(
18477            Op::Reduce {
18478                op: ReduceOp::Sum,
18479                axes: vec![0, 1],
18480                keep_dim: false,
18481            },
18482            vec![ln],
18483            Shape::from_dims(&[], f),
18484        );
18485        fwd.set_outputs(vec![loss]);
18486
18487        let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[xn, gp, bp]);
18488        let d_out = bwd_graph
18489            .nodes()
18490            .iter()
18491            .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
18492            .map(|n| n.id)
18493            .unwrap();
18494
18495        let plan = rlx_opt::memory::plan_memory(&bwd_graph);
18496        let mut arena = crate::arena::Arena::from_plan(plan);
18497        let sched = compile_thunks(&bwd_graph, &arena);
18498        for &(id, data) in &[
18499            (xn, &x_data),
18500            (gp, &g_data),
18501            (bp, &b_data),
18502            (d_out, &vec![1.0f32]),
18503        ] {
18504            let off = arena.byte_offset(id);
18505            let buf = arena.raw_buf_mut();
18506            unsafe {
18507                let p = buf.as_mut_ptr().add(off) as *mut f32;
18508                for (i, &v) in data.iter().enumerate() {
18509                    *p.add(i) = v;
18510                }
18511            }
18512        }
18513        execute_thunks(&sched, arena.raw_buf_mut());
18514        let read = |id: NodeId, n: usize| -> Vec<f32> {
18515            let off = arena.byte_offset(id);
18516            unsafe {
18517                let p = arena.raw_buf().as_ptr().add(off) as *const f32;
18518                (0..n).map(|i| *p.add(i)).collect()
18519            }
18520        };
18521        let dx_a = read(bwd_graph.outputs[1], rows * h);
18522        let dg_a = read(bwd_graph.outputs[2], h);
18523        let db_a = read(bwd_graph.outputs[3], h);
18524
18525        let forward_loss = |x: &[f32], g: &[f32], b: &[f32]| -> f32 {
18526            let mut total = 0f32;
18527            for r in 0..rows {
18528                let row = &x[r * h..(r + 1) * h];
18529                let mean = row.iter().sum::<f32>() / h as f32;
18530                let var = row.iter().map(|&v| (v - mean) * (v - mean)).sum::<f32>() / h as f32;
18531                let inv_std = 1.0 / (var + eps).sqrt();
18532                for d in 0..h {
18533                    total += ((row[d] - mean) * inv_std) * g[d] + b[d];
18534                }
18535            }
18536            total
18537        };
18538        let h_eps = 1e-3f32;
18539
18540        let mut x_p = x_data.clone();
18541        for i in 0..x_p.len() {
18542            let s = x_p[i];
18543            x_p[i] = s + h_eps;
18544            let lp = forward_loss(&x_p, &g_data, &b_data);
18545            x_p[i] = s - h_eps;
18546            let lm = forward_loss(&x_p, &g_data, &b_data);
18547            x_p[i] = s;
18548            let num = (lp - lm) / (2.0 * h_eps);
18549            assert!(
18550                (dx_a[i] - num).abs() < 5e-3,
18551                "ln dx[{i}]: analytical {} vs numerical {num}",
18552                dx_a[i]
18553            );
18554        }
18555        let mut g_p = g_data.clone();
18556        for i in 0..g_p.len() {
18557            let s = g_p[i];
18558            g_p[i] = s + h_eps;
18559            let lp = forward_loss(&x_data, &g_p, &b_data);
18560            g_p[i] = s - h_eps;
18561            let lm = forward_loss(&x_data, &g_p, &b_data);
18562            g_p[i] = s;
18563            let num = (lp - lm) / (2.0 * h_eps);
18564            assert!(
18565                (dg_a[i] - num).abs() < 5e-3,
18566                "ln dg[{i}]: analytical {} vs numerical {num}",
18567                dg_a[i]
18568            );
18569        }
18570        let mut b_p = b_data.clone();
18571        for i in 0..b_p.len() {
18572            let s = b_p[i];
18573            b_p[i] = s + h_eps;
18574            let lp = forward_loss(&x_data, &g_data, &b_p);
18575            b_p[i] = s - h_eps;
18576            let lm = forward_loss(&x_data, &g_data, &b_p);
18577            b_p[i] = s;
18578            let num = (lp - lm) / (2.0 * h_eps);
18579            assert!(
18580                (db_a[i] - num).abs() < 5e-3,
18581                "ln db[{i}]: analytical {} vs numerical {num}",
18582                db_a[i]
18583            );
18584        }
18585    }
18586
18587    /// Single dense layer + softmax-cross-entropy + mean reduce —
18588    /// the simplest non-trivial training graph. Validates MatMul,
18589    /// broadcast Add, SCE, Reduce(Mean) VJPs and the grad_with_loss
18590    /// plumbing all at once.
18591    #[test]
18592    fn dense_sce_mean_gradient_matches_numerical() {
18593        use rlx_ir::Philox4x32;
18594        let bs = 4usize;
18595        let k_in = 3usize;
18596        let c = 5usize;
18597        let mut rng = Philox4x32::new(7);
18598        let mut x = vec![0f32; bs * k_in];
18599        rng.fill_normal(&mut x);
18600        let mut w_init = vec![0f32; k_in * c];
18601        rng.fill_normal(&mut w_init);
18602        let mut b_init = vec![0f32; c];
18603        rng.fill_normal(&mut b_init);
18604        let labels: Vec<f32> = (0..bs).map(|i| (i % c) as f32).collect();
18605
18606        // ── Forward graph: loss = mean(sce(x @ w + b, labels)) ──
18607        let f = DType::F32;
18608        let mut fwd = Graph::new("dense_sce");
18609        let xn = fwd.input("x", Shape::new(&[bs, k_in], f));
18610        let lb = fwd.input("labels", Shape::new(&[bs], f));
18611        let wp = fwd.param("w", Shape::new(&[k_in, c], f));
18612        let bp = fwd.param("b", Shape::new(&[c], f));
18613        let mm = fwd.matmul(xn, wp, Shape::new(&[bs, c], f));
18614        let logits = fwd.binary(BinaryOp::Add, mm, bp, Shape::new(&[bs, c], f));
18615        let loss_per = fwd.softmax_cross_entropy_with_logits(logits, lb);
18616        let loss = fwd.add_node(
18617            Op::Reduce {
18618                op: ReduceOp::Sum,
18619                axes: vec![0],
18620                keep_dim: false,
18621            },
18622            vec![loss_per],
18623            // Reduce sum of [bs] with axes=[0] keep_dim=false → scalar [].
18624            Shape::from_dims(&[], f),
18625        );
18626        // Use Sum + manual /bs scalar mul — also exercises BinaryOp::Mul VJP path
18627        // less aggressively than Mean would, and gives us a closed-form
18628        // reference for the loss we expect.
18629        // For simplicity though, switch to Mean which the tests should also cover.
18630        // (Re-using `loss` with Sum here for now; the mean factor cancels in
18631        // the gradient comparison since both analytical and numerical use the
18632        // same forward.)
18633        fwd.set_outputs(vec![loss]);
18634
18635        // ── Backward graph ──
18636        let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[wp, bp]);
18637        // Outputs: [loss, grad_w, grad_b]. NodeIds for x/labels/w/b/loss
18638        // in bwd_graph match their fwd ids (the mirror keeps order).
18639        let d_out = bwd_graph
18640            .nodes()
18641            .iter()
18642            .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
18643            .map(|n| n.id)
18644            .expect("d_output input");
18645
18646        let (sched, mut arena) = prepare(
18647            &bwd_graph,
18648            &[
18649                (xn, &x),
18650                (lb, &labels),
18651                (wp, &w_init),
18652                (bp, &b_init),
18653                (d_out, &[1.0]),
18654            ],
18655        );
18656        execute_thunks(&sched, arena.raw_buf_mut());
18657
18658        let outs = &bwd_graph.outputs;
18659        let loss_id = outs[0];
18660        let gw_id = outs[1];
18661        let gb_id = outs[2];
18662        let loss_actual = read_arena(&arena, loss_id, 1)[0];
18663        let gw_actual = read_arena(&arena, gw_id, k_in * c);
18664        let gb_actual = read_arena(&arena, gb_id, c);
18665
18666        // ── Forward-only graph for finite differences ──
18667        // Re-use the same `fwd` graph; set up its own arena and rerun
18668        // for each perturbed parameter.
18669        let plan = rlx_opt::memory::plan_memory(&fwd);
18670        let mut fwd_arena = crate::arena::Arena::from_plan(plan);
18671        let fwd_sched = compile_thunks(&fwd, &fwd_arena);
18672        write_arena(&mut fwd_arena, xn, &x);
18673        write_arena(&mut fwd_arena, lb, &labels);
18674
18675        let run_loss = |arena: &mut crate::arena::Arena, w: &[f32], b: &[f32]| -> f32 {
18676            write_arena(arena, wp, w);
18677            write_arena(arena, bp, b);
18678            execute_thunks(&fwd_sched, arena.raw_buf_mut());
18679            read_arena(arena, loss, 1)[0]
18680        };
18681
18682        // Sanity: the loss reported by the bwd graph matches the
18683        // forward-only graph on the unperturbed inputs.
18684        let loss_check = run_loss(&mut fwd_arena, &w_init, &b_init);
18685        assert!(
18686            (loss_actual - loss_check).abs() < 1e-4,
18687            "loss mismatch: bwd graph {loss_actual} vs fwd-only {loss_check}"
18688        );
18689
18690        let eps = 1e-3f32;
18691        let mut w_perturbed = w_init.clone();
18692        let mut gw_numerical = vec![0f32; w_init.len()];
18693        for i in 0..w_init.len() {
18694            let saved = w_perturbed[i];
18695            w_perturbed[i] = saved + eps;
18696            let lp = run_loss(&mut fwd_arena, &w_perturbed, &b_init);
18697            w_perturbed[i] = saved - eps;
18698            let lm = run_loss(&mut fwd_arena, &w_perturbed, &b_init);
18699            w_perturbed[i] = saved;
18700            gw_numerical[i] = (lp - lm) / (2.0 * eps);
18701        }
18702        for (i, (a, n)) in gw_actual.iter().zip(&gw_numerical).enumerate() {
18703            assert!(
18704                (a - n).abs() < 5e-3,
18705                "grad_w[{i}]: analytical {a} vs numerical {n}"
18706            );
18707        }
18708
18709        let mut b_perturbed = b_init.clone();
18710        let mut gb_numerical = vec![0f32; b_init.len()];
18711        for i in 0..b_init.len() {
18712            let saved = b_perturbed[i];
18713            b_perturbed[i] = saved + eps;
18714            let lp = run_loss(&mut fwd_arena, &w_init, &b_perturbed);
18715            b_perturbed[i] = saved - eps;
18716            let lm = run_loss(&mut fwd_arena, &w_init, &b_perturbed);
18717            b_perturbed[i] = saved;
18718            gb_numerical[i] = (lp - lm) / (2.0 * eps);
18719        }
18720        for (i, (a, n)) in gb_actual.iter().zip(&gb_numerical).enumerate() {
18721            assert!(
18722                (a - n).abs() < 5e-3,
18723                "grad_b[{i}]: analytical {a} vs numerical {n}"
18724            );
18725        }
18726    }
18727
18728    /// Reduce::Mean specifically — verifies the 1/N scaling in the VJP.
18729    /// The same dense+SCE graph but with Mean instead of Sum on the loss.
18730    #[test]
18731    fn dense_sce_mean_reduce_gradient_matches_numerical() {
18732        use rlx_ir::Philox4x32;
18733        let bs = 3usize;
18734        let k_in = 2usize;
18735        let c = 4usize;
18736        let mut rng = Philox4x32::new(13);
18737        let mut x = vec![0f32; bs * k_in];
18738        rng.fill_normal(&mut x);
18739        let mut w_init = vec![0f32; k_in * c];
18740        rng.fill_normal(&mut w_init);
18741        let labels: Vec<f32> = (0..bs).map(|i| (i % c) as f32).collect();
18742
18743        let f = DType::F32;
18744        let mut fwd = Graph::new("dense_sce_mean");
18745        let xn = fwd.input("x", Shape::new(&[bs, k_in], f));
18746        let lb = fwd.input("labels", Shape::new(&[bs], f));
18747        let wp = fwd.param("w", Shape::new(&[k_in, c], f));
18748        let mm = fwd.matmul(xn, wp, Shape::new(&[bs, c], f));
18749        let loss_per = fwd.softmax_cross_entropy_with_logits(mm, lb);
18750        let loss = fwd.add_node(
18751            Op::Reduce {
18752                op: ReduceOp::Mean,
18753                axes: vec![0],
18754                keep_dim: false,
18755            },
18756            vec![loss_per],
18757            Shape::from_dims(&[], f),
18758        );
18759        fwd.set_outputs(vec![loss]);
18760
18761        let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[wp]);
18762        let d_out = bwd_graph
18763            .nodes()
18764            .iter()
18765            .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
18766            .map(|n| n.id)
18767            .unwrap();
18768
18769        let (sched, mut arena) = prepare(
18770            &bwd_graph,
18771            &[(xn, &x), (lb, &labels), (wp, &w_init), (d_out, &[1.0])],
18772        );
18773        execute_thunks(&sched, arena.raw_buf_mut());
18774
18775        let outs = &bwd_graph.outputs;
18776        let loss_id = outs[0];
18777        let gw_id = outs[1];
18778        let _ = read_arena(&arena, loss_id, 1)[0];
18779        let gw_actual = read_arena(&arena, gw_id, k_in * c);
18780
18781        let plan = rlx_opt::memory::plan_memory(&fwd);
18782        let mut fwd_arena = crate::arena::Arena::from_plan(plan);
18783        let fwd_sched = compile_thunks(&fwd, &fwd_arena);
18784        write_arena(&mut fwd_arena, xn, &x);
18785        write_arena(&mut fwd_arena, lb, &labels);
18786
18787        let run_loss = |arena: &mut crate::arena::Arena, w: &[f32]| -> f32 {
18788            write_arena(arena, wp, w);
18789            execute_thunks(&fwd_sched, arena.raw_buf_mut());
18790            read_arena(arena, loss, 1)[0]
18791        };
18792
18793        let eps = 1e-3f32;
18794        let mut wp_p = w_init.clone();
18795        let mut gw_num = vec![0f32; w_init.len()];
18796        for i in 0..w_init.len() {
18797            let s = wp_p[i];
18798            wp_p[i] = s + eps;
18799            let lp = run_loss(&mut fwd_arena, &wp_p);
18800            wp_p[i] = s - eps;
18801            let lm = run_loss(&mut fwd_arena, &wp_p);
18802            wp_p[i] = s;
18803            gw_num[i] = (lp - lm) / (2.0 * eps);
18804        }
18805        for (i, (a, n)) in gw_actual.iter().zip(&gw_num).enumerate() {
18806            assert!((a - n).abs() < 5e-3, "mean reduce grad_w[{i}]: {a} vs {n}");
18807        }
18808    }
18809    /// The full TinyConv-MNIST forward path (downsized) plumbed
18810    /// through grad_with_loss. Validates that Conv, Pool(Max), ReLU,
18811    /// Reshape, MatMul, Add (broadcast), SCE, Reduce(Mean) VJPs all
18812    /// compose into a graph that produces correct gradients.
18813    #[test]
18814    fn tinyconv_full_gradient_matches_numerical() {
18815        use rlx_ir::Philox4x32;
18816        // Tiny shapes so finite differences finish in <1s.
18817        let n = 1usize;
18818        let c_in = 1usize;
18819        let h = 6usize;
18820        let w_in = 6usize;
18821        let c_mid = 2usize; // first conv output channels
18822        let kh = 3;
18823        let kw = 3;
18824        let h1 = h - kh + 1; // 4
18825        let w1 = w_in - kw + 1; // 4
18826        let h2 = h1 / 2;
18827        let w2 = w1 / 2; // 2 × 2 after 2× pool
18828        let flat = c_mid * h2 * w2; // 8
18829        let num_classes = 3usize;
18830
18831        let mut rng = Philox4x32::new(31);
18832        let mut x = vec![0f32; n * c_in * h * w_in];
18833        rng.fill_normal(&mut x);
18834        let mut wc = vec![0f32; c_mid * c_in * kh * kw];
18835        rng.fill_normal(&mut wc);
18836        for v in wc.iter_mut() {
18837            *v *= 0.2;
18838        }
18839        // Shift conv-bias well away from the ReLU zero-boundary. Without
18840        // this, an ε-perturbation of bc[c] can flip the ReLU mask on a
18841        // pre-activation that happened to land near zero — making the
18842        // central-difference numerical gradient discontinuous and
18843        // diverge from the analytical (which assumes local smoothness).
18844        // +5.0 keeps every pre-activation positive for any random init
18845        // produced by Philox seed 31 with the wc/x scales used here, so
18846        // ReLU acts as an identity and finite differences are exact.
18847        let bc: Vec<f32> = (0..c_mid).map(|i| 5.0 + 0.1 * i as f32).collect();
18848        let mut wfc = vec![0f32; flat * num_classes];
18849        rng.fill_normal(&mut wfc);
18850        for v in wfc.iter_mut() {
18851            *v *= 0.5;
18852        }
18853        let mut bfc = vec![0f32; num_classes];
18854        rng.fill_normal(&mut bfc);
18855        let labels: Vec<f32> = vec![1.0]; // batch=1
18856
18857        let f = DType::F32;
18858        let mut fwd = Graph::new("tinyconv");
18859        let xn = fwd.input("x", Shape::new(&[n, c_in, h, w_in], f));
18860        let lb = fwd.input("labels", Shape::new(&[n], f));
18861        let wcp = fwd.param("wc", Shape::new(&[c_mid, c_in, kh, kw], f));
18862        let bcp = fwd.param("bc", Shape::new(&[c_mid], f));
18863        let wfp = fwd.param("wfc", Shape::new(&[flat, num_classes], f));
18864        let bfp = fwd.param("bfc", Shape::new(&[num_classes], f));
18865
18866        // conv: [n, c_in, h, w] → [n, c_mid, h1, w1]
18867        let conv = fwd.add_node(
18868            Op::Conv {
18869                kernel_size: vec![kh, kw],
18870                stride: vec![1, 1],
18871                padding: vec![0, 0],
18872                dilation: vec![1, 1],
18873                groups: 1,
18874            },
18875            vec![xn, wcp],
18876            Shape::new(&[n, c_mid, h1, w1], f),
18877        );
18878        // Bias add: expand bc[c_mid] up to the full [n, c_mid, h1, w1]
18879        // shape so the Add becomes a plain element-wise op. Going through
18880        // an explicit Reshape→Expand instead of relying on the Add to
18881        // broadcast `[1, C, 1, 1]` → `[N, C, H, W]` works around a known
18882        // limitation of `rlx-cpu`'s `Op::Binary` lowering: it dispatches
18883        // on `out_len % rhs_len == 0` and treats `rhs` as a last-axis
18884        // bias, which produces `bc[0], bc[1], bc[0], bc[1], …` alternating
18885        // across all positions instead of channel-broadcasting. Going
18886        // through Expand (a real broadcast thunk) avoids that path
18887        // entirely. The autodiff still exercises `unbroadcast` because
18888        // `Op::Expand`'s VJP reduces over the broadcast axes.
18889        let bc_4d = fwd.add_node(
18890            Op::Reshape {
18891                new_shape: vec![1, c_mid as i64, 1, 1],
18892            },
18893            vec![bcp],
18894            Shape::new(&[1, c_mid, 1, 1], f),
18895        );
18896        let bc_expanded = fwd.add_node(
18897            Op::Expand {
18898                target_shape: vec![n as i64, c_mid as i64, h1 as i64, w1 as i64],
18899            },
18900            vec![bc_4d],
18901            Shape::new(&[n, c_mid, h1, w1], f),
18902        );
18903        let conv_b = fwd.binary(
18904            BinaryOp::Add,
18905            conv,
18906            bc_expanded,
18907            Shape::new(&[n, c_mid, h1, w1], f),
18908        );
18909        let relu = fwd.activation(Activation::Relu, conv_b, Shape::new(&[n, c_mid, h1, w1], f));
18910        let pool = fwd.add_node(
18911            Op::Pool {
18912                kind: ReduceOp::Max,
18913                kernel_size: vec![2, 2],
18914                stride: vec![2, 2],
18915                padding: vec![0, 0],
18916            },
18917            vec![relu],
18918            Shape::new(&[n, c_mid, h2, w2], f),
18919        );
18920        let flatn = fwd.add_node(
18921            Op::Reshape {
18922                new_shape: vec![n as i64, flat as i64],
18923            },
18924            vec![pool],
18925            Shape::new(&[n, flat], f),
18926        );
18927        let mm = fwd.matmul(flatn, wfp, Shape::new(&[n, num_classes], f));
18928        let logits = fwd.binary(BinaryOp::Add, mm, bfp, Shape::new(&[n, num_classes], f));
18929        let loss_per = fwd.softmax_cross_entropy_with_logits(logits, lb);
18930        let loss = fwd.add_node(
18931            Op::Reduce {
18932                op: ReduceOp::Mean,
18933                axes: vec![0],
18934                keep_dim: false,
18935            },
18936            vec![loss_per],
18937            Shape::from_dims(&[], f),
18938        );
18939        fwd.set_outputs(vec![loss]);
18940
18941        let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[wcp, bcp, wfp, bfp]);
18942        let d_out = bwd_graph
18943            .nodes()
18944            .iter()
18945            .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
18946            .map(|n| n.id)
18947            .unwrap();
18948
18949        let (sched, mut arena) = prepare(
18950            &bwd_graph,
18951            &[
18952                (xn, &x),
18953                (lb, &labels),
18954                (wcp, &wc),
18955                (bcp, &bc),
18956                (wfp, &wfc),
18957                (bfp, &bfc),
18958                (d_out, &[1.0]),
18959            ],
18960        );
18961        execute_thunks(&sched, arena.raw_buf_mut());
18962
18963        let outs = bwd_graph.outputs.clone();
18964        let loss_id = outs[0];
18965        let g_wc_id = outs[1];
18966        let g_bc_id = outs[2];
18967        let g_wfc_id = outs[3];
18968        let g_bfc_id = outs[4];
18969        let loss_actual = read_arena(&arena, loss_id, 1)[0];
18970        let g_wc = read_arena(&arena, g_wc_id, wc.len());
18971        let g_bc = read_arena(&arena, g_bc_id, bc.len());
18972        let g_wfc = read_arena(&arena, g_wfc_id, wfc.len());
18973        let g_bfc = read_arena(&arena, g_bfc_id, bfc.len());
18974
18975        // Forward-only arena for finite differences.
18976        let plan = rlx_opt::memory::plan_memory(&fwd);
18977        let mut fwd_arena = crate::arena::Arena::from_plan(plan);
18978        let fwd_sched = compile_thunks(&fwd, &fwd_arena);
18979        write_arena(&mut fwd_arena, xn, &x);
18980        write_arena(&mut fwd_arena, lb, &labels);
18981
18982        // Closure variant: we need to set all four params each call so
18983        // perturbations to one don't leak between sweeps.
18984        let run_loss = |arena: &mut crate::arena::Arena,
18985                        wc: &[f32],
18986                        bc: &[f32],
18987                        wfc: &[f32],
18988                        bfc: &[f32]|
18989         -> f32 {
18990            write_arena(arena, wcp, wc);
18991            write_arena(arena, bcp, bc);
18992            write_arena(arena, wfp, wfc);
18993            write_arena(arena, bfp, bfc);
18994            execute_thunks(&fwd_sched, arena.raw_buf_mut());
18995            read_arena(arena, loss, 1)[0]
18996        };
18997
18998        let loss_check = run_loss(&mut fwd_arena, &wc, &bc, &wfc, &bfc);
18999        assert!(
19000            (loss_actual - loss_check).abs() < 1e-4,
19001            "tinyconv loss mismatch: bwd {loss_actual} vs fwd {loss_check}"
19002        );
19003
19004        let eps = 1e-3f32;
19005        let check_grad = |arena: &mut crate::arena::Arena,
19006                          name: &str,
19007                          analytical: &[f32],
19008                          mut perturb: Box<
19009            dyn FnMut(&mut [f32], usize, f32, &mut crate::arena::Arena) -> f32 + '_,
19010        >,
19011                          n: usize| {
19012            for i in 0..n {
19013                let lp = perturb(&mut analytical.to_vec(), i, eps, arena);
19014                let lm = perturb(&mut analytical.to_vec(), i, -eps, arena);
19015                let num = (lp - lm) / (2.0 * eps);
19016                assert!(
19017                    (analytical[i] - num).abs() < 5e-3,
19018                    "{name}[{i}]: analytical {} vs numerical {num}",
19019                    analytical[i]
19020                );
19021            }
19022        };
19023
19024        // Helper to perturb one param and run forward. Kept as a
19025        // reference for the explicit per-param sweep pattern below.
19026        #[allow(unused_macros)]
19027        macro_rules! sweep {
19028            ($name:expr, $base:expr, $analytical:expr, $set_param:ident) => {{
19029                let n = $base.len();
19030                for i in 0..n {
19031                    let mut p = $base.clone();
19032                    let s = p[i];
19033                    p[i] = s + eps;
19034                    let lp = {
19035                        let $set_param = &p;
19036                        run_loss(&mut fwd_arena, &wc, &bc, &wfc, &bfc).max(f32::NEG_INFINITY);
19037                        // Reset others, set the one being swept, run.
19038                        // (the macro receives one of the four params via $set_param)
19039                        let _ = $set_param;
19040                        // Fall through to the explicit per-param helper:
19041                        0.0_f32
19042                    };
19043                    let _ = lp;
19044                }
19045            }};
19046        }
19047        let _ = check_grad; // silence unused (sweep! macro is intentionally\n        // unused — kept as reference for the per-param sweep pattern below)
19048
19049        // Per-param sweeps (explicit, not macro — clearer).
19050        for i in 0..wc.len() {
19051            let mut p = wc.clone();
19052            let s = p[i];
19053            p[i] = s + eps;
19054            let lp = run_loss(&mut fwd_arena, &p, &bc, &wfc, &bfc);
19055            p[i] = s - eps;
19056            let lm = run_loss(&mut fwd_arena, &p, &bc, &wfc, &bfc);
19057            let num = (lp - lm) / (2.0 * eps);
19058            assert!(
19059                (g_wc[i] - num).abs() < 5e-3,
19060                "g_wc[{i}]: {} vs {num}",
19061                g_wc[i]
19062            );
19063        }
19064        for i in 0..bc.len() {
19065            let mut p = bc.clone();
19066            let s = p[i];
19067            p[i] = s + eps;
19068            let lp = run_loss(&mut fwd_arena, &wc, &p, &wfc, &bfc);
19069            p[i] = s - eps;
19070            let lm = run_loss(&mut fwd_arena, &wc, &p, &wfc, &bfc);
19071            let num = (lp - lm) / (2.0 * eps);
19072            assert!(
19073                (g_bc[i] - num).abs() < 5e-3,
19074                "g_bc[{i}]: {} vs {num}",
19075                g_bc[i]
19076            );
19077        }
19078        for i in 0..wfc.len() {
19079            let mut p = wfc.clone();
19080            let s = p[i];
19081            p[i] = s + eps;
19082            let lp = run_loss(&mut fwd_arena, &wc, &bc, &p, &bfc);
19083            p[i] = s - eps;
19084            let lm = run_loss(&mut fwd_arena, &wc, &bc, &p, &bfc);
19085            let num = (lp - lm) / (2.0 * eps);
19086            assert!(
19087                (g_wfc[i] - num).abs() < 5e-3,
19088                "g_wfc[{i}]: {} vs {num}",
19089                g_wfc[i]
19090            );
19091        }
19092        for i in 0..bfc.len() {
19093            let mut p = bfc.clone();
19094            let s = p[i];
19095            p[i] = s + eps;
19096            let lp = run_loss(&mut fwd_arena, &wc, &bc, &wfc, &p);
19097            p[i] = s - eps;
19098            let lm = run_loss(&mut fwd_arena, &wc, &bc, &wfc, &p);
19099            let num = (lp - lm) / (2.0 * eps);
19100            assert!(
19101                (g_bfc[i] - num).abs() < 5e-3,
19102                "g_bfc[{i}]: {} vs {num}",
19103                g_bfc[i]
19104            );
19105        }
19106    }
19107
19108    /// Negative case: a Narrow whose output has multiple consumers
19109    /// must NOT be fused (we can't elide its write — something else
19110    /// reads it).
19111    #[test]
19112    fn narrow_rope_skips_when_narrow_has_multiple_consumers() {
19113        let f = DType::F32;
19114        let mut g = Graph::new("nr_skip");
19115        let qkv = g.input("qkv", Shape::new(&[16, 8, 192], f));
19116        let cos = g.input("cos", Shape::new(&[16], f));
19117        let sin = g.input("sin", Shape::new(&[16], f));
19118        let q = g.narrow_(qkv, 2, 0, 64);
19119        let q_rope = g.rope(q, cos, sin, 16);
19120        // Second consumer of `q` blocks the fusion.
19121        let q_dup = g.activation(rlx_ir::op::Activation::Relu, q, Shape::new(&[16, 8, 64], f));
19122        g.set_outputs(vec![q_rope, q_dup]);
19123
19124        let plan = rlx_opt::memory::plan_memory(&g);
19125        let arena = crate::arena::Arena::from_plan(plan);
19126        let sched = compile_thunks(&g, &arena);
19127
19128        let narrow_count = sched
19129            .thunks
19130            .iter()
19131            .filter(|t| matches!(t, Thunk::Narrow { .. }))
19132            .count();
19133        assert!(
19134            narrow_count >= 1,
19135            "Narrow with multiple consumers must NOT be fused away"
19136        );
19137    }
19138
19139    // ── Op::CustomFn (custom_vjp / custom_jvp) tests ──
19140    //
19141    // Validates: forward execution inlines fwd_body; VJP rule inlines
19142    // vjp_body in place of recursing into fwd_body; JVP rule inlines
19143    // jvp_body. Each test deliberately picks a body whose AD-via-tracing
19144    // would yield a *different* gradient than the override, so we know
19145    // the override actually fired.
19146
19147    /// Forward only: CustomFn wrapping `f(x) = x + c` (c=1 inside body)
19148    /// without override AD bodies. Verifies the body is compiled,
19149    /// constants in the body fill correctly, and the output lands at
19150    /// the outer node's slot.
19151    #[test]
19152    fn custom_fn_forward_inlines_body() {
19153        let s = Shape::new(&[3], DType::F32);
19154
19155        // Body: f(x) = x + 1
19156        let mut body = Graph::new("addone_body");
19157        let x = body.input("x", s.clone());
19158        let one_data: Vec<u8> = (0..3).flat_map(|_| 1.0_f32.to_le_bytes()).collect();
19159        let one = body.add_node(Op::Constant { data: one_data }, vec![], s.clone());
19160        let y = body.binary(BinaryOp::Add, x, one, s.clone());
19161        body.set_outputs(vec![y]);
19162
19163        let mut g = Graph::new("custom_fn_outer");
19164        let xin = g.input("x_in", s.clone());
19165        let cf = g.custom_fn(vec![xin], body, None, None);
19166        g.set_outputs(vec![cf]);
19167
19168        let xs = vec![10.0_f32, 20.0, 30.0];
19169        let (sched, mut arena) = prepare(&g, &[(xin, &xs)]);
19170        execute_thunks(&sched, arena.raw_buf_mut());
19171        let got = read_arena(&arena, cf, 3);
19172        assert_eq!(got, vec![11.0, 21.0, 31.0]);
19173    }
19174
19175    /// Locate an Op::Input or Op::Param by name in a graph.
19176    fn find_named(graph: &Graph, want: &str) -> NodeId {
19177        for n in graph.nodes() {
19178            let name = match &n.op {
19179                Op::Input { name } | Op::Param { name } => Some(name.as_str()),
19180                _ => None,
19181            };
19182            if name == Some(want) {
19183                return n.id;
19184            }
19185        }
19186        panic!("no node named {want:?} in graph");
19187    }
19188
19189    /// VJP override: f(x) = x but vjp_body returns 2 * d_output, so the
19190    /// reported gradient should be 2 — different from the natural 1
19191    /// you'd get by recursing into the identity body.
19192    #[test]
19193    fn custom_fn_vjp_overrides_natural_gradient() {
19194        use rlx_opt::autodiff::grad_with_loss;
19195        let s = Shape::new(&[1], DType::F32);
19196
19197        let mut fwd = Graph::new("id_fwd");
19198        let x = fwd.input("x", s.clone());
19199        fwd.set_outputs(vec![x]);
19200
19201        let mut vjp_g = Graph::new("id_vjp");
19202        let _x_p = vjp_g.input("x", s.clone());
19203        let _y_p = vjp_g.input("primal_output", s.clone());
19204        let dy = vjp_g.input("d_output", s.clone());
19205        let two_data: Vec<u8> = 2.0_f32.to_le_bytes().to_vec();
19206        let two = vjp_g.add_node(Op::Constant { data: two_data }, vec![], s.clone());
19207        let dx = vjp_g.binary(BinaryOp::Mul, dy, two, s.clone());
19208        vjp_g.set_outputs(vec![dx]);
19209
19210        let mut g = Graph::new("outer");
19211        let xp = g.param("x", s.clone());
19212        let cf = g.custom_fn(vec![xp], fwd, Some(vjp_g), None);
19213        g.set_outputs(vec![cf]);
19214
19215        let bwd = grad_with_loss(&g, &[xp]);
19216        assert_eq!(bwd.outputs.len(), 2, "expect [loss, dx]");
19217
19218        let xb = find_named(&bwd, "x");
19219        let dout = find_named(&bwd, "d_output");
19220        let (sched, mut arena) = prepare(&bwd, &[(xb, &[7.0]), (dout, &[1.0])]);
19221        execute_thunks(&sched, arena.raw_buf_mut());
19222        let loss = read_arena(&arena, bwd.outputs[0], 1);
19223        let dx_v = read_arena(&arena, bwd.outputs[1], 1);
19224        assert!((loss[0] - 7.0).abs() < 1e-6, "loss should be 7.0");
19225        assert!(
19226            (dx_v[0] - 2.0).abs() < 1e-6,
19227            "vjp override should yield dx=2.0, got {} (natural autodiff would give 1.0)",
19228            dx_v[0]
19229        );
19230    }
19231
19232    /// VJP override: f(a, b) = a*b with vjp_body returning
19233    /// (b * d_output, a * d_output). Validates routing of multiple
19234    /// primals + d_output through the override; matches the natural
19235    /// autodiff-of-Mul gradient (b, a).
19236    #[test]
19237    fn custom_fn_vjp_two_inputs_matches_mul_autodiff() {
19238        use rlx_opt::autodiff::grad_with_loss;
19239        let s = Shape::new(&[1], DType::F32);
19240
19241        let mut fwd = Graph::new("mul_fwd");
19242        let a_f = fwd.input("a", s.clone());
19243        let b_f = fwd.input("b", s.clone());
19244        let y_f = fwd.binary(BinaryOp::Mul, a_f, b_f, s.clone());
19245        fwd.set_outputs(vec![y_f]);
19246
19247        let mut vjp_g = Graph::new("mul_vjp");
19248        let a_v = vjp_g.input("a", s.clone());
19249        let b_v = vjp_g.input("b", s.clone());
19250        let _y_v = vjp_g.input("primal_output", s.clone());
19251        let dy_v = vjp_g.input("d_output", s.clone());
19252        let da = vjp_g.binary(BinaryOp::Mul, b_v, dy_v, s.clone());
19253        let db = vjp_g.binary(BinaryOp::Mul, a_v, dy_v, s.clone());
19254        vjp_g.set_outputs(vec![da, db]);
19255
19256        let mut g = Graph::new("outer");
19257        let ap = g.param("a", s.clone());
19258        let bp = g.param("b", s.clone());
19259        let cf = g.custom_fn(vec![ap, bp], fwd, Some(vjp_g), None);
19260        g.set_outputs(vec![cf]);
19261
19262        let bwd = grad_with_loss(&g, &[ap, bp]);
19263        assert_eq!(bwd.outputs.len(), 3, "expect [loss, da, db]");
19264
19265        let ab = find_named(&bwd, "a");
19266        let bb = find_named(&bwd, "b");
19267        let dout = find_named(&bwd, "d_output");
19268        let (sched, mut arena) = prepare(&bwd, &[(ab, &[3.0]), (bb, &[5.0]), (dout, &[1.0])]);
19269        execute_thunks(&sched, arena.raw_buf_mut());
19270        let loss = read_arena(&arena, bwd.outputs[0], 1);
19271        let da_v = read_arena(&arena, bwd.outputs[1], 1);
19272        let db_v = read_arena(&arena, bwd.outputs[2], 1);
19273        assert!((loss[0] - 15.0).abs() < 1e-5);
19274        assert!(
19275            (da_v[0] - 5.0).abs() < 1e-5,
19276            "da should be b=5.0, got {}",
19277            da_v[0]
19278        );
19279        assert!(
19280            (db_v[0] - 3.0).abs() < 1e-5,
19281            "db should be a=3.0, got {}",
19282            db_v[0]
19283        );
19284    }
19285
19286    /// JVP override: f(x) = x but jvp_body returns 2 * tangent_0.
19287    /// Forward-mode tangent should be 2x the seed (1.0) → 2.0.
19288    #[test]
19289    fn custom_fn_jvp_overrides_natural_tangent() {
19290        use rlx_opt::autodiff_fwd::jvp;
19291        let s = Shape::new(&[1], DType::F32);
19292
19293        let mut fwd = Graph::new("id_fwd");
19294        let x = fwd.input("x", s.clone());
19295        fwd.set_outputs(vec![x]);
19296
19297        let mut jvp_g = Graph::new("id_jvp");
19298        let _x_p = jvp_g.input("x", s.clone());
19299        let tx = jvp_g.input("tangent_0", s.clone());
19300        let two_data: Vec<u8> = 2.0_f32.to_le_bytes().to_vec();
19301        let two = jvp_g.add_node(Op::Constant { data: two_data }, vec![], s.clone());
19302        let ty = jvp_g.binary(BinaryOp::Mul, tx, two, s.clone());
19303        jvp_g.set_outputs(vec![ty]);
19304
19305        let mut g = Graph::new("outer");
19306        let xin = g.input("x_in", s.clone());
19307        let cf = g.custom_fn(vec![xin], fwd, None, Some(jvp_g));
19308        g.set_outputs(vec![cf]);
19309
19310        let fwd_g = jvp(&g, &[xin]);
19311        assert_eq!(fwd_g.outputs.len(), 2, "expect [primal_y, tangent_y]");
19312
19313        let xb = find_named(&fwd_g, "x_in");
19314        let tan = find_named(&fwd_g, "tangent_x_in");
19315        let (sched, mut arena) = prepare(&fwd_g, &[(xb, &[7.0]), (tan, &[1.0])]);
19316        execute_thunks(&sched, arena.raw_buf_mut());
19317        let y = read_arena(&arena, fwd_g.outputs[0], 1);
19318        let ty_v = read_arena(&arena, fwd_g.outputs[1], 1);
19319        assert!((y[0] - 7.0).abs() < 1e-6);
19320        assert!(
19321            (ty_v[0] - 2.0).abs() < 1e-6,
19322            "jvp override should yield t_y=2.0 (natural autodiff would give 1.0), got {}",
19323            ty_v[0]
19324        );
19325    }
19326
19327    /// IR-level basic test: `DType::C64` is wired through the dtype
19328    /// table — `size_bytes() == 8`, `is_complex()` reports true, and
19329    /// a `[2]`-shaped C64 buffer in the arena occupies the expected
19330    /// 16 bytes.
19331    #[test]
19332    fn c64_dtype_storage_layout() {
19333        assert_eq!(
19334            DType::C64.size_bytes(),
19335            8,
19336            "C64 should be 8 bytes (f32 real + f32 imag)"
19337        );
19338        assert!(DType::C64.is_complex());
19339        assert!(!DType::C64.is_float());
19340
19341        // A length-2 C64 buffer should have shape size_bytes = 16.
19342        let s = Shape::new(&[2], DType::C64);
19343        assert_eq!(s.size_bytes().unwrap(), 16);
19344    }
19345
19346    // ── C64 element-wise binary kernel witnesses (2026-05-17) ──────
19347    //
19348    // Build a tiny graph: Input `a` + Input `b` (both C64 [2]),
19349    // output = a OP b. Run through CompileResult and compare against
19350    // the closed-form complex arithmetic on the four chosen pairs.
19351
19352    fn run_c64_binary(op: BinaryOp, a: &[(f32, f32)], b: &[(f32, f32)]) -> Vec<(f32, f32)> {
19353        let n = a.len();
19354        let s = Shape::new(&[n], DType::C64);
19355        let mut g = Graph::new("c64_bin");
19356        let in_a = g.input("a", s.clone());
19357        let in_b = g.input("b", s.clone());
19358        let out = g.binary(op, in_a, in_b, s.clone());
19359        g.set_outputs(vec![out]);
19360
19361        let plan = rlx_opt::memory::plan_memory(&g);
19362        let mut arena = crate::arena::Arena::from_plan(plan);
19363        let sched = compile_thunks(&g, &arena);
19364
19365        let a_off = arena.byte_offset(in_a);
19366        let b_off = arena.byte_offset(in_b);
19367        let out_off = arena.byte_offset(out);
19368        // Interleave [re_0, im_0, re_1, im_1, ...] in the f32 buffer.
19369        let buf = arena.raw_buf_mut();
19370        unsafe {
19371            let pa = buf.as_mut_ptr().add(a_off) as *mut f32;
19372            let pb = buf.as_mut_ptr().add(b_off) as *mut f32;
19373            for (i, &(re, im)) in a.iter().enumerate() {
19374                *pa.add(2 * i) = re;
19375                *pa.add(2 * i + 1) = im;
19376            }
19377            for (i, &(re, im)) in b.iter().enumerate() {
19378                *pb.add(2 * i) = re;
19379                *pb.add(2 * i + 1) = im;
19380            }
19381        }
19382        execute_thunks(&sched, arena.raw_buf_mut());
19383        let raw_out: Vec<f32> = unsafe {
19384            let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
19385            (0..(2 * n)).map(|i| *p.add(i)).collect()
19386        };
19387        (0..n)
19388            .map(|i| (raw_out[2 * i], raw_out[2 * i + 1]))
19389            .collect()
19390    }
19391
19392    #[track_caller]
19393    fn assert_close_c(got: (f32, f32), expected: (f32, f32), tol: f32, label: &str) {
19394        let dr = (got.0 - expected.0).abs();
19395        let di = (got.1 - expected.1).abs();
19396        assert!(
19397            dr < tol && di < tol,
19398            "[{label}] got ({:+.4}, {:+.4}), expected ({:+.4}, {:+.4})",
19399            got.0,
19400            got.1,
19401            expected.0,
19402            expected.1
19403        );
19404    }
19405
19406    #[test]
19407    fn c64_binary_add_matches_complex_arithmetic() {
19408        let a = [(1.0_f32, 2.0_f32), (3.0_f32, -1.0_f32)];
19409        let b = [(4.0_f32, -1.0_f32), (0.5_f32, 0.5_f32)];
19410        let out = run_c64_binary(BinaryOp::Add, &a, &b);
19411        assert_close_c(out[0], (5.0, 1.0), 1e-6, "add[0]");
19412        assert_close_c(out[1], (3.5, -0.5), 1e-6, "add[1]");
19413    }
19414
19415    #[test]
19416    fn c64_binary_sub_matches_complex_arithmetic() {
19417        let a = [(5.0_f32, 1.0_f32)];
19418        let b = [(2.0_f32, 3.0_f32)];
19419        let out = run_c64_binary(BinaryOp::Sub, &a, &b);
19420        assert_close_c(out[0], (3.0, -2.0), 1e-6, "sub");
19421    }
19422
19423    #[test]
19424    fn c64_binary_mul_matches_complex_arithmetic() {
19425        // (1 + 2i)(3 + 4i) = 3 + 4i + 6i + 8i² = -5 + 10i.
19426        let a = [(1.0_f32, 2.0_f32)];
19427        let b = [(3.0_f32, 4.0_f32)];
19428        let out = run_c64_binary(BinaryOp::Mul, &a, &b);
19429        assert_close_c(out[0], (-5.0, 10.0), 1e-5, "mul");
19430    }
19431
19432    #[test]
19433    fn c64_binary_div_matches_complex_arithmetic() {
19434        // (1 + 2i) / (3 + 4i) = ((1·3 + 2·4) + (2·3 − 1·4)i) / 25
19435        //                     = (11 + 2i) / 25
19436        //                     = 0.44 + 0.08i
19437        let a = [(1.0_f32, 2.0_f32)];
19438        let b = [(3.0_f32, 4.0_f32)];
19439        let out = run_c64_binary(BinaryOp::Div, &a, &b);
19440        assert_close_c(out[0], (0.44, 0.08), 1e-5, "div");
19441    }
19442
19443    #[test]
19444    fn c64_binary_mul_identity_one_is_no_op() {
19445        // (a + bi) · (1 + 0i) = a + bi.
19446        let a = [(3.5_f32, -1.25_f32), (-2.0_f32, 7.0_f32)];
19447        let b = [(1.0_f32, 0.0_f32), (1.0_f32, 0.0_f32)];
19448        let out = run_c64_binary(BinaryOp::Mul, &a, &b);
19449        assert_close_c(out[0], a[0], 1e-6, "mul·1[0]");
19450        assert_close_c(out[1], a[1], 1e-6, "mul·1[1]");
19451    }
19452
19453    #[test]
19454    fn c64_binary_mul_by_i_rotates_90_degrees() {
19455        // (a + bi) · i = (a + bi)(0 + i) = -b + ai. 90° CCW rotation.
19456        let a = [(1.0_f32, 0.0_f32)];
19457        let b = [(0.0_f32, 1.0_f32)];
19458        let out = run_c64_binary(BinaryOp::Mul, &a, &b);
19459        assert_close_c(out[0], (0.0, 1.0), 1e-6, "1·i");
19460    }
19461
19462    #[test]
19463    fn c64_binary_div_by_self_gives_unity() {
19464        let a = [(2.5_f32, -1.5_f32), (-0.7_f32, 4.2_f32)];
19465        let out = run_c64_binary(BinaryOp::Div, &a, &a);
19466        assert_close_c(out[0], (1.0, 0.0), 1e-5, "div_self[0]");
19467        assert_close_c(out[1], (1.0, 0.0), 1e-5, "div_self[1]");
19468    }
19469
19470    #[test]
19471    #[should_panic(expected = "C64: complex max/min/pow")]
19472    fn c64_binary_max_is_rejected_at_lowering() {
19473        run_c64_binary(BinaryOp::Max, &[(1.0_f32, 2.0_f32)], &[(3.0_f32, 4.0_f32)]);
19474    }
19475
19476    fn run_c64_activation(act: Activation, a: &[(f32, f32)]) -> Vec<(f32, f32)> {
19477        let n = a.len();
19478        let s = Shape::new(&[n], DType::C64);
19479        let mut g = Graph::new("c64_act");
19480        let in_a = g.input("a", s.clone());
19481        let out = g.activation(act, in_a, s.clone());
19482        g.set_outputs(vec![out]);
19483        let plan = rlx_opt::memory::plan_memory(&g);
19484        let mut arena = crate::arena::Arena::from_plan(plan);
19485        let sched = compile_thunks(&g, &arena);
19486        let a_off = arena.byte_offset(in_a);
19487        let out_off = arena.byte_offset(out);
19488        let buf = arena.raw_buf_mut();
19489        unsafe {
19490            let pa = buf.as_mut_ptr().add(a_off) as *mut f32;
19491            for (i, &(re, im)) in a.iter().enumerate() {
19492                *pa.add(2 * i) = re;
19493                *pa.add(2 * i + 1) = im;
19494            }
19495        }
19496        execute_thunks(&sched, arena.raw_buf_mut());
19497        let raw: Vec<f32> = unsafe {
19498            let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
19499            (0..(2 * n)).map(|i| *p.add(i)).collect()
19500        };
19501        (0..n).map(|i| (raw[2 * i], raw[2 * i + 1])).collect()
19502    }
19503
19504    #[test]
19505    fn c64_activation_neg_negates_both_components() {
19506        let inp = [(3.5_f32, -1.25_f32), (-2.0_f32, 0.0_f32)];
19507        let out = run_c64_activation(Activation::Neg, &inp);
19508        assert_close_c(out[0], (-3.5, 1.25), 1e-6, "neg[0]");
19509        assert_close_c(out[1], (2.0, 0.0), 1e-6, "neg[1]");
19510    }
19511
19512    #[test]
19513    fn c64_activation_exp_matches_euler() {
19514        // exp(0 + i·π) = -1 + 0i.
19515        // exp(1 + 0i) = e ≈ 2.71828.
19516        let inp = [(0.0_f32, std::f32::consts::PI), (1.0_f32, 0.0_f32)];
19517        let out = run_c64_activation(Activation::Exp, &inp);
19518        assert_close_c(out[0], (-1.0, 0.0), 1e-5, "exp(iπ)");
19519        assert_close_c(out[1], (std::f32::consts::E, 0.0), 1e-5, "exp(1)");
19520    }
19521
19522    #[test]
19523    fn c64_activation_log_matches_principal_branch() {
19524        // log(1 + 0i) = 0.
19525        // log(0 + i) = log(1) + i·π/2 = 0 + i·π/2.
19526        // log(-1 + 0i) = 0 + i·π.
19527        let inp = [(1.0_f32, 0.0_f32), (0.0_f32, 1.0_f32), (-1.0_f32, 0.0_f32)];
19528        let out = run_c64_activation(Activation::Log, &inp);
19529        assert_close_c(out[0], (0.0, 0.0), 1e-5, "log(1)");
19530        assert_close_c(out[1], (0.0, std::f32::consts::FRAC_PI_2), 1e-5, "log(i)");
19531        assert_close_c(out[2], (0.0, std::f32::consts::PI), 1e-5, "log(-1)");
19532    }
19533
19534    #[test]
19535    fn c64_activation_sqrt_squared_recovers_input() {
19536        // For positive-real-part inputs, sqrt(z)² should equal z exactly
19537        // to f32 noise.
19538        let inp = [(4.0_f32, 0.0_f32), (3.0_f32, 4.0_f32)];
19539        let roots = run_c64_activation(Activation::Sqrt, &inp);
19540        // sqrt(4) = 2 + 0i; sqrt(3+4i) = 2 + i (since (2+i)² = 4+4i-1 = 3+4i).
19541        assert_close_c(roots[0], (2.0, 0.0), 1e-5, "sqrt(4)");
19542        assert_close_c(roots[1], (2.0, 1.0), 1e-5, "sqrt(3+4i)");
19543    }
19544
19545    #[test]
19546    #[should_panic(expected = "no natural complex extension")]
19547    fn c64_activation_relu_is_rejected_at_lowering() {
19548        run_c64_activation(Activation::Relu, &[(1.0_f32, 2.0_f32)]);
19549    }
19550
19551    // ── ComplexNormSq + Wirtinger backward witnesses ───────────────
19552
19553    /// Forward `|z|²`: returns `[n]` f32.
19554    fn run_complex_norm_sq(z: &[(f32, f32)]) -> Vec<f32> {
19555        let n = z.len();
19556        let mut g = Graph::new("cns_fwd");
19557        let in_z = g.input("z", Shape::new(&[n], DType::C64));
19558        let out = g.complex_norm_sq(in_z);
19559        g.set_outputs(vec![out]);
19560        let plan = rlx_opt::memory::plan_memory(&g);
19561        let mut arena = crate::arena::Arena::from_plan(plan);
19562        let sched = compile_thunks(&g, &arena);
19563        let z_off = arena.byte_offset(in_z);
19564        let out_off = arena.byte_offset(out);
19565        let buf = arena.raw_buf_mut();
19566        unsafe {
19567            let pz = buf.as_mut_ptr().add(z_off) as *mut f32;
19568            for (i, &(re, im)) in z.iter().enumerate() {
19569                *pz.add(2 * i) = re;
19570                *pz.add(2 * i + 1) = im;
19571            }
19572        }
19573        execute_thunks(&sched, arena.raw_buf_mut());
19574        unsafe {
19575            let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
19576            (0..n).map(|i| *p.add(i)).collect()
19577        }
19578    }
19579
19580    /// Backward: given z and upstream g, return dz = g·z element-wise (C64).
19581    fn run_complex_norm_sq_bwd(z: &[(f32, f32)], g: &[f32]) -> Vec<(f32, f32)> {
19582        let n = z.len();
19583        let mut gr = Graph::new("cns_bwd");
19584        let in_z = gr.input("z", Shape::new(&[n], DType::C64));
19585        let in_g = gr.input("g", Shape::new(&[n], DType::F32));
19586        let out = gr.complex_norm_sq_backward(in_z, in_g);
19587        gr.set_outputs(vec![out]);
19588        let plan = rlx_opt::memory::plan_memory(&gr);
19589        let mut arena = crate::arena::Arena::from_plan(plan);
19590        let sched = compile_thunks(&gr, &arena);
19591        let z_off = arena.byte_offset(in_z);
19592        let g_off = arena.byte_offset(in_g);
19593        let out_off = arena.byte_offset(out);
19594        let buf = arena.raw_buf_mut();
19595        unsafe {
19596            let pz = buf.as_mut_ptr().add(z_off) as *mut f32;
19597            let pg = buf.as_mut_ptr().add(g_off) as *mut f32;
19598            for (i, &(re, im)) in z.iter().enumerate() {
19599                *pz.add(2 * i) = re;
19600                *pz.add(2 * i + 1) = im;
19601            }
19602            for (i, &v) in g.iter().enumerate() {
19603                *pg.add(i) = v;
19604            }
19605        }
19606        execute_thunks(&sched, arena.raw_buf_mut());
19607        unsafe {
19608            let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
19609            (0..n).map(|i| (*p.add(2 * i), *p.add(2 * i + 1))).collect()
19610        }
19611    }
19612
19613    #[test]
19614    fn complex_norm_sq_matches_textbook() {
19615        // |3 + 4i|² = 9 + 16 = 25.
19616        // |1 + 0i|² = 1.
19617        // |0 + 0i|² = 0.
19618        let z = [(3.0_f32, 4.0_f32), (1.0_f32, 0.0_f32), (0.0_f32, 0.0_f32)];
19619        let out = run_complex_norm_sq(&z);
19620        assert!((out[0] - 25.0).abs() < 1e-5);
19621        assert!((out[1] - 1.0).abs() < 1e-6);
19622        assert!(out[2].abs() < 1e-6);
19623    }
19624
19625    #[test]
19626    fn complex_norm_sq_backward_matches_wirtinger_formula() {
19627        // Wirtinger: ∂|z|²/∂z̄ = z. With upstream g = 1, dz = z.
19628        let z = [(3.0_f32, 4.0_f32), (1.5_f32, -2.5_f32)];
19629        let g = [1.0_f32, 1.0_f32];
19630        let dz = run_complex_norm_sq_bwd(&z, &g);
19631        assert_close_c(dz[0], z[0], 1e-6, "dz[0] = g·z[0]");
19632        assert_close_c(dz[1], z[1], 1e-6, "dz[1] = g·z[1]");
19633    }
19634
19635    #[test]
19636    fn complex_norm_sq_backward_scales_with_upstream() {
19637        // With upstream g[i] ≠ 1: dz[i] = g[i]·z[i].
19638        let z = [(2.0_f32, 1.0_f32), (-1.0_f32, 3.0_f32)];
19639        let g = [0.5_f32, -2.0_f32];
19640        let dz = run_complex_norm_sq_bwd(&z, &g);
19641        assert_close_c(dz[0], (1.0, 0.5), 1e-6, "g=0.5 · (2,1)");
19642        assert_close_c(dz[1], (2.0, -6.0), 1e-6, "g=-2 · (-1,3)");
19643    }
19644
19645    /// Multi-output Op::CustomFn via the concat-with-Narrow design
19646    /// (rlx-ir::Graph::custom_fn_multi). Build a custom_fn whose
19647    /// fwd_body returns two outputs (x², 2x), then materialize each
19648    /// via the MultiOutputHandle and verify both numerically.
19649    #[test]
19650    fn custom_fn_multi_extracts_each_subgraph_output() {
19651        use rlx_ir::ops::special::MultiOutputHandle;
19652
19653        let _ = MultiOutputHandle {
19654            source: NodeId(0),
19655            sub_shapes: vec![],
19656            offsets: vec![],
19657        }; // import sanity
19658
19659        // Inner body: input x [3] f32, outputs (x², 2x) both [3] f32.
19660        let mut body = Graph::new("multi_body");
19661        let s3 = Shape::new(&[3], DType::F32);
19662        let x = body.input("x", s3.clone());
19663        let x_sq = body.binary(BinaryOp::Mul, x, x, s3.clone());
19664        let two = body.add_node(
19665            Op::Constant {
19666                data: vec![
19667                    2.0_f32.to_le_bytes(),
19668                    2.0_f32.to_le_bytes(),
19669                    2.0_f32.to_le_bytes(),
19670                ]
19671                .into_iter()
19672                .flatten()
19673                .collect(),
19674            },
19675            vec![],
19676            s3.clone(),
19677        );
19678        let two_x = body.binary(BinaryOp::Mul, two, x, s3.clone());
19679        body.set_outputs(vec![x_sq, two_x]);
19680
19681        // Outer graph: feed in_x → custom_fn_multi → handle.output(0/1).
19682        let mut outer = Graph::new("multi_outer");
19683        let in_x = outer.input("xin", s3.clone());
19684        let handle = outer.custom_fn_multi(vec![in_x], body);
19685        assert_eq!(handle.n_outputs(), 2);
19686        let out0 = handle.output(&mut outer, 0); // x²
19687        let out1 = handle.output(&mut outer, 1); // 2x
19688        outer.set_outputs(vec![out0, out1]);
19689
19690        let plan = rlx_opt::memory::plan_memory(&outer);
19691        let mut arena = crate::arena::Arena::from_plan(plan);
19692        let sched = compile_thunks(&outer, &arena);
19693        let xin_off = arena.byte_offset(in_x);
19694        let out0_off = arena.byte_offset(out0);
19695        let out1_off = arena.byte_offset(out1);
19696        let xs = [1.0_f32, 2.0, 3.0];
19697        unsafe {
19698            let p = arena.raw_buf_mut().as_mut_ptr().add(xin_off) as *mut f32;
19699            for (i, &v) in xs.iter().enumerate() {
19700                *p.add(i) = v;
19701            }
19702        }
19703        execute_thunks(&sched, arena.raw_buf_mut());
19704        let out0_v: Vec<f32> = unsafe {
19705            let p = arena.raw_buf().as_ptr().add(out0_off) as *const f32;
19706            (0..3).map(|i| *p.add(i)).collect()
19707        };
19708        let out1_v: Vec<f32> = unsafe {
19709            let p = arena.raw_buf().as_ptr().add(out1_off) as *const f32;
19710            (0..3).map(|i| *p.add(i)).collect()
19711        };
19712        // x² = [1, 4, 9]; 2x = [2, 4, 6].
19713        for i in 0..3 {
19714            assert!(
19715                (out0_v[i] - xs[i] * xs[i]).abs() < 1e-5,
19716                "out0[{i}] = {} != x² = {}",
19717                out0_v[i],
19718                xs[i] * xs[i]
19719            );
19720            assert!(
19721                (out1_v[i] - 2.0 * xs[i]).abs() < 1e-5,
19722                "out1[{i}] = {} != 2x = {}",
19723                out1_v[i],
19724                2.0 * xs[i]
19725            );
19726        }
19727    }
19728
19729    #[test]
19730    fn complex_norm_sq_gradient_matches_finite_difference() {
19731        // Numerical sanity: perturb z[0].re by ε, observe Δ|z|² ≈ 2·re·ε.
19732        let z = [(3.0_f32, 4.0_f32)];
19733        let eps = 1e-3_f32;
19734        let v0 = run_complex_norm_sq(&z)[0];
19735        let z_pert = [(3.0_f32 + eps, 4.0_f32)];
19736        let v1 = run_complex_norm_sq(&z_pert)[0];
19737        let fd_re = (v1 - v0) / eps;
19738        let analytic_re = 2.0 * z[0].0;
19739        assert!((fd_re - analytic_re).abs() < 1e-2);
19740
19741        // ∂/∂im at z = (3, 4) is 2·im = 8.
19742        let z_pert_im = [(3.0_f32, 4.0_f32 + eps)];
19743        let v2 = run_complex_norm_sq(&z_pert_im)[0];
19744        let fd_im = (v2 - v0) / eps;
19745        let analytic_im = 2.0 * z[0].1;
19746        assert!((fd_im - analytic_im).abs() < 1e-2);
19747
19748        // Compare with the Wirtinger backward at upstream g = 1.
19749        // Wirtinger ∂/∂z̄ = z gives dz = (re, im). The "real
19750        // gradient" wrt (re, im) is 2·(re, im), i.e. 2·dz = (2·re,
19751        // 2·im) — that's the factor 2 difference between Wirtinger
19752        // ∂/∂z̄ and the real-vector gradient on (re, im).
19753        let dz = run_complex_norm_sq_bwd(&z, &[1.0_f32]);
19754        assert!((2.0 * dz[0].0 - analytic_re).abs() < 1e-5);
19755        assert!((2.0 * dz[0].1 - analytic_im).abs() < 1e-5);
19756    }
19757
19758    /// Direct regression test for the 5-D mid-shape singleton broadcast
19759    /// (SAM rel_pos pattern: `[bh, h, w, 1, w] + [bh, h, w, h, w]`).
19760    /// The SAM port worked around this by `concat`-tiling the rhs; this
19761    /// test verifies the in-graph broadcast path is bit-correct.
19762    #[test]
19763    fn binary_full_5d_mid_singleton_broadcast() {
19764        let bh = 2usize;
19765        let h = 3;
19766        let w = 4;
19767        let f = DType::F32;
19768
19769        let mut g = Graph::new("bcast_5d");
19770        let lhs = g.input("lhs", Shape::new(&[bh, h, w, h, w], f));
19771        // rhs shape with size-1 at axis 3 (mid-shape singleton).
19772        let rhs = g.input("rhs", Shape::new(&[bh, h, w, 1, w], f));
19773        let out = g.binary(BinaryOp::Add, lhs, rhs, Shape::new(&[bh, h, w, h, w], f));
19774        g.set_outputs(vec![out]);
19775
19776        // Deterministic data.
19777        let lhs_data: Vec<f32> = (0..bh * h * w * h * w).map(|i| i as f32 * 0.01).collect();
19778        let rhs_data: Vec<f32> = (0..bh * h * w * w)
19779            .map(|i| (i as f32 + 100.0) * 0.01)
19780            .collect();
19781
19782        // Compute expected output by hand.
19783        let mut expected = vec![0f32; bh * h * w * h * w];
19784        for b_ in 0..bh {
19785            for hq in 0..h {
19786                for wq in 0..w {
19787                    for hk in 0..h {
19788                        for wk in 0..w {
19789                            let li = (((b_ * h + hq) * w + wq) * h + hk) * w + wk;
19790                            // rhs has hk dim = 1, so it's always index 0 there.
19791                            let ri = ((b_ * h + hq) * w + wq) * w + wk;
19792                            expected[li] = lhs_data[li] + rhs_data[ri];
19793                        }
19794                    }
19795                }
19796            }
19797        }
19798
19799        let plan = rlx_opt::memory::plan_memory(&g);
19800        let mut arena = crate::arena::Arena::from_plan(plan);
19801        let sched = compile_thunks(&g, &arena);
19802        let lhs_off = arena.byte_offset(lhs);
19803        let rhs_off = arena.byte_offset(rhs);
19804        let out_off = arena.byte_offset(out);
19805        let buf = arena.raw_buf_mut();
19806        unsafe {
19807            let p = buf.as_mut_ptr().add(lhs_off) as *mut f32;
19808            for (i, &v) in lhs_data.iter().enumerate() {
19809                *p.add(i) = v;
19810            }
19811            let p = buf.as_mut_ptr().add(rhs_off) as *mut f32;
19812            for (i, &v) in rhs_data.iter().enumerate() {
19813                *p.add(i) = v;
19814            }
19815        }
19816        execute_thunks(&sched, arena.raw_buf_mut());
19817        let actual: Vec<f32> = unsafe {
19818            let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
19819            (0..bh * h * w * h * w).map(|i| *p.add(i)).collect()
19820        };
19821
19822        // Bit-exact check.
19823        let mut max_diff = 0f32;
19824        let mut max_idx = 0;
19825        for i in 0..actual.len() {
19826            let d = (actual[i] - expected[i]).abs();
19827            if d > max_diff {
19828                max_diff = d;
19829                max_idx = i;
19830            }
19831        }
19832        assert!(
19833            max_diff < 1e-6,
19834            "5D mid-shape singleton broadcast wrong: max |Δ| = {max_diff} at idx {max_idx} \
19835             (actual={}, expected={})",
19836            actual[max_idx],
19837            expected[max_idx]
19838        );
19839    }
19840
19841    #[test]
19842    fn layer_norm2d_and_conv_transpose2d_kernels() {
19843        let mut out = vec![0f32; 8];
19844        crate::kernels::layer_norm2d_nchw(
19845            &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
19846            &[1.0, 1.0],
19847            &[0.0, 0.0],
19848            &mut out,
19849            1,
19850            2,
19851            2,
19852            2,
19853            1e-5,
19854        );
19855        let mean0: f32 = (1.0 + 3.0) / 2.0;
19856        assert!((out[0] - mean0).abs() > 0.1);
19857
19858        let mut up = vec![0f32; 4];
19859        crate::kernels::conv_transpose2d_nchw(
19860            &[2.0],
19861            &[1.0, 0.0, 0.0, 1.0],
19862            &mut up,
19863            1,
19864            1,
19865            1,
19866            1,
19867            1,
19868            2,
19869            2,
19870            2,
19871            2,
19872            2,
19873            2,
19874            0,
19875            0,
19876            1,
19877            1,
19878            1,
19879        );
19880        assert!((up[0] - 2.0).abs() < 1e-5);
19881        assert!((up[3] - 2.0).abs() < 1e-5);
19882    }
19883}