Skip to main content

rlx_cpu/
thunk.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Thunks — pre-compiled kernel dispatch with zero per-call overhead.
17//!
18//! At compile time, the graph is lowered into a flat `Vec<Thunk>` where each
19//! thunk holds pre-computed arena offsets, dimensions, and kernel type.
20//! At runtime, the executor just iterates thunks and calls kernels directly.
21
22// Edition 2024: bodies of `unsafe fn` are safe by default; `sl`/`sl_mut` stay `unsafe fn`.
23#![allow(unsafe_op_in_unsafe_fn)]
24//! No match dispatch, no HashMap lookup, no dimension computation.
25
26use crate::arena::Arena;
27use crate::op_registry::CpuKernel;
28use rlx_ir::op::{Activation, BinaryOp, CmpOp, ReduceOp};
29use rlx_ir::{Graph, NodeId, Op, Shape};
30use std::collections::HashMap;
31use std::sync::Arc;
32
33/// A pre-compiled kernel call with all args resolved to arena offsets.
34#[derive(Clone)]
35pub enum Thunk {
36    /// Skip (Input/Param already in arena)
37    Nop,
38    /// C = A @ B (BLAS sgemm)
39    Sgemm {
40        a: usize,
41        b: usize,
42        c: usize,
43        m: u32,
44        k: u32,
45        n: u32,
46    },
47    /// f64 dense solve `x = A⁻¹·b` via LAPACK dgesv.
48    /// `a`, `b`, `x` are byte-offsets into the arena. `n` is the matrix
49    /// dimension; `nrhs` is 1 for a vector RHS or >1 for multi-RHS.
50    /// The kernel materializes scratch copies of A and b internally
51    /// (LAPACK overwrites both with LU factors and solution).
52    DenseSolveF64 {
53        a: usize,
54        b: usize,
55        x: usize,
56        n: u32,
57        nrhs: u32,
58    },
59    /// f32 twin of `DenseSolveF64`. Calls LAPACK `sgesv` (or the
60    /// no-blas Rust fallback). Same arena byte-offset contract.
61    DenseSolveF32 {
62        a: usize,
63        b: usize,
64        x: usize,
65        n: u32,
66        nrhs: u32,
67    },
68    /// Batched f64 dense solve. `a`, `b`, `x` are byte-offsets to
69    /// the leading slice; `batch` is the number of independent
70    /// systems. Per slice the kernel calls `dgesv(A_i, b_i, n, nrhs)`
71    /// — LAPACK has no batched dgesv on Accelerate, so we loop.
72    BatchedDenseSolveF64 {
73        a: usize,
74        b: usize,
75        x: usize,
76        batch: u32,
77        n: u32,
78        nrhs: u32,
79    },
80    /// Batched f32 dense solve — loop of `sgesv` per batch slice.
81    BatchedDenseSolveF32 {
82        a: usize,
83        b: usize,
84        x: usize,
85        batch: u32,
86        n: u32,
87        nrhs: u32,
88    },
89    /// Batched f64 matmul. Both inputs and output have a leading
90    /// batch axis of size `batch`. Per-batch independent dgemm:
91    /// `C[i] = A[i] @ B[i]` for `i in 0..batch`. Used by VJP rules
92    /// that emit per-batch outer products (e.g., BatchedDenseSolve
93    /// VJP). The unbatched `Dgemm` thunk handles the rank-2 case.
94    BatchedDgemmF64 {
95        a: usize,
96        b: usize,
97        c: usize,
98        batch: u32,
99        m: u32,
100        k: u32,
101        n: u32,
102    },
103    /// Batched f32 matmul — same loop-per-batch shape as
104    /// `BatchedDgemmF64` but calling `sgemm`. Needed for attention
105    /// patterns where both operands carry a batch dim (e.g. q@k^T
106    /// and attn@v in decomposed self-attention). The 2-D `Sgemm`
107    /// flatten trick is wrong in that case because it treats `b` as
108    /// a single shared RHS across every batch.
109    BatchedSgemm {
110        a: usize,
111        b: usize,
112        c: usize,
113        batch: u32,
114        m: u32,
115        k: u32,
116        n: u32,
117    },
118    /// C = A @ B via Accelerate cblas_dgemm. Mirror of `Sgemm` at f64.
119    Dgemm {
120        a: usize,
121        b: usize,
122        c: usize,
123        m: u32,
124        k: u32,
125        n: u32,
126    },
127    /// f64 N-D index walk used for both `Op::Transpose` and `Op::Expand`.
128    /// `in_strides` carries 0s on broadcast axes (Expand) or permuted
129    /// strides (Transpose). Mirror of `Thunk::Transpose` at f64.
130    TransposeF64 {
131        src: usize,
132        dst: usize,
133        in_total: u32,
134        out_dims: Vec<u32>,
135        in_strides: Vec<u32>,
136    },
137    /// f64 element-wise activation. Single-input, single-output. The
138    /// kernel always reads from `src` and writes to `dst`, so it works
139    /// whether or not the planner aliased the two slots.
140    ActivationF64 {
141        src: usize,
142        dst: usize,
143        len: u32,
144        kind: Activation,
145    },
146    /// Element-wise complex squared-magnitude: `|z|² = re² + im²`.
147    /// Reads the C64 input at `src` as `2·len` f32 ([re,im] pairs),
148    /// writes `len` f32 to `dst`.
149    ComplexNormSqF32 {
150        src: usize,
151        dst: usize,
152        /// Logical element count (number of complex values).
153        len: u32,
154    },
155    /// Wirtinger backward for [`Thunk::ComplexNormSqF32`]: `dz = g · z` as
156    /// C64. Reads `z` at `2·len` f32 + `g` at `len` f32; writes
157    /// `2·len` f32 to `dz`.
158    ComplexNormSqBackwardF32 {
159        z: usize,
160        g: usize,
161        dz: usize,
162        len: u32,
163    },
164    /// Element-wise C64 conjugate: writes `[re_i, -im_i]` per element.
165    /// Layout matches the rest of C64 here ([re,im] interleaved f32).
166    ConjugateC64 { src: usize, dst: usize, len: u32 },
167    /// C64 element-wise activation. Only kinds with well-defined
168    /// complex extensions are supported: Neg, Exp, Log, Sqrt.
169    /// Everything else (Sigmoid, Tanh, Relu, Abs, Sin/Cos/Tan/Atan,
170    /// Round, GeLU family) is rejected at lowering — those don't have
171    /// single natural complex definitions. `len` is the **complex
172    /// element count** (the f32 buffer holds `2·len` floats).
173    ActivationC64 {
174        src: usize,
175        dst: usize,
176        len: u32,
177        kind: Activation,
178    },
179    /// f64 contiguous reduction along a single axis range. Layout
180    /// `[outer, reduced, inner]` in memory; output is `[outer, inner]`.
181    /// Sum only for now (Mean composes via 1/N multiply post-pass).
182    ReduceSumF64 {
183        src: usize,
184        dst: usize,
185        outer: u32,
186        reduced: u32,
187        inner: u32,
188    },
189    /// f64 plain copy (Reshape / Cast at the same dtype). Mirrors `Copy`
190    /// but at 8 bytes per element.
191    CopyF64 { src: usize, dst: usize, len: u32 },
192    /// f64 element-wise binary with broadcast. `len`/`lhs_len`/`rhs_len`
193    /// are element counts; kernel does `out[i] = lhs[i % lhs_len] OP rhs[i % rhs_len]`.
194    /// Mirror of `BinaryFull` at 8 bytes per element.
195    BinaryFullF64 {
196        lhs: usize,
197        rhs: usize,
198        dst: usize,
199        len: u32,
200        lhs_len: u32,
201        rhs_len: u32,
202        op: BinaryOp,
203        /// Output shape dims (row-major). Empty in the fast path. See
204        /// `BinaryFull` doc for the broadcast convention.
205        out_dims_bcast: Vec<u32>,
206        bcast_lhs_strides: Vec<u32>,
207        bcast_rhs_strides: Vec<u32>,
208    },
209    /// f64 concat — byte-for-byte mirror of `Concat` but copies
210    /// 8 bytes per element. Element-counted offsets/strides match
211    /// the f32 variant; the executor scales by elem_size internally.
212    ConcatF64 {
213        dst: usize,
214        outer: u32,
215        inner: u32,
216        total_axis: u32,
217        inputs: Vec<(usize, u32)>,
218    },
219    /// C64 element-wise binary with broadcast. Same `len` /
220    /// `lhs_len` / `rhs_len` semantics as `BinaryFull` but each
221    /// "element" is one complex value (8 bytes = `[re, im]` as two
222    /// f32s). The executor reads the underlying f32 buffer at
223    /// `2·len` floats and walks element pairs. Supports Add / Sub /
224    /// Mul / Div; Max / Min / Pow have no single natural complex
225    /// definition and panic at lowering.
226    BinaryFullC64 {
227        lhs: usize,
228        rhs: usize,
229        dst: usize,
230        /// Complex element count (NOT f32 count). f32 buffer length
231        /// is `2·len`.
232        len: u32,
233        lhs_len: u32,
234        rhs_len: u32,
235        op: BinaryOp,
236        out_dims_bcast: Vec<u32>,
237        bcast_lhs_strides: Vec<u32>,
238        bcast_rhs_strides: Vec<u32>,
239    },
240    /// Bounded scan. Holds a recursively-compiled body schedule + a
241    /// pre-initialized body arena snapshot (constants filled). Each
242    /// outer execution clones the snapshot, copies the carry-in slot
243    /// from the outer arena, runs the body schedule `length` times,
244    /// then writes the final carry to the outer arena.
245    ///
246    /// Single-carry MVP — body has exactly one carry Input and one output,
247    /// both same shape and dtype (`xs_inputs` / `bcast_inputs` add further Inputs).
248    Scan {
249        body: Arc<ThunkSchedule>,
250        body_init: Arc<Vec<u8>>, // pristine body arena bytes
251        body_input_off: usize,   // byte offset of the body's carry-Input slot
252        body_output_off: usize,  // byte offset of the body's output slot
253        outer_init_off: usize,   // outer-arena offset of the initial carry
254        outer_final_off: usize,  // outer-arena offset of the final carry / trajectory base
255        length: u32,
256        carry_bytes: u32, // carry size in bytes
257        /// When true, write each step's carry to the outer arena at
258        /// offset `outer_final_off + t * carry_bytes`, producing a
259        /// `[length, *carry]` stacked trajectory. When false, only the
260        /// final carry lands at `outer_final_off`.
261        save_trajectory: bool,
262        /// Per-step `xs` inputs. For each: (body_x_input_off,
263        /// outer_xs_base_off, per_step_bytes). Per iteration `t`, the
264        /// executor copies `outer_xs_base_off + t * per_step_bytes`
265        /// into `body_x_input_off`. Empty when the scan has no xs.
266        xs_inputs: Arc<Vec<(usize, usize, u32)>>,
267        /// Broadcast inputs — values constant across iterations. For
268        /// each: (body_bcast_input_off, outer_bcast_off, total_bytes).
269        /// Filled into `body_buf` ONCE before the scan loop starts
270        /// (xs in contrast are re-filled every iteration). Empty when
271        /// the scan has no bcasts.
272        bcast_inputs: Arc<Vec<(usize, usize, u32)>>,
273        /// Number of trajectory checkpoints (when `save_trajectory`).
274        /// `0` or `length` ⇒ save every iteration. Otherwise save only
275        /// `K` rows at indices `floor((k+1) * length / K) - 1` for
276        /// `k in 0..K`. Last index is always `length-1` so the final
277        /// carry is always cached.
278        num_checkpoints: u32,
279    },
280
281    /// Reverse-mode AD companion to `Thunk::Scan`. Walks `t = length-1
282    /// .. 0`, threading `dcarry` through the body's VJP. Per iteration:
283    /// writes `carry_t` (from outer init or trajectory), each `xs_i[t]`
284    /// slice, and the current `dcarry` into the body_vjp's Input
285    /// slots, runs body_vjp, reads new `dcarry` from its single output.
286    /// f32 or f64 carry — the upstream-accumulation step in trajectory
287    /// mode does an element-wise add dispatched on `carry_elem_size`.
288    ScanBackward {
289        body_vjp: Arc<ThunkSchedule>,
290        body_init: Arc<Vec<u8>>,
291        body_carry_in_off: usize, // body_vjp's mirrored body-carry-input slot
292        body_x_offs: Arc<Vec<usize>>, // body_vjp's mirrored x_t_i Input slots, in xs order
293        body_d_output_off: usize, // body_vjp's "d_output" Input slot
294        body_dcarry_out_off: usize, // body_vjp's gradient output
295        outer_init_off: usize,    // original init carry
296        outer_traj_off: usize,    // [length-or-K, *carry] trajectory base
297        outer_upstream_off: usize, // upstream gradient (carry shape, or [length, *carry])
298        /// Per-xs entries: (outer_xs_base_off, per_step_bytes). Read
299        /// `xs_i[t]` from `outer_xs_base_off + t * per_step_bytes`.
300        outer_xs_offs: Arc<Vec<(usize, u32)>>,
301        outer_dinit_off: usize, // output: dinit
302        length: u32,
303        carry_bytes: u32,
304        /// Bytes per element in the carry tensor: 4 for f32, 8 for f64.
305        /// Used to dispatch the trajectory-mode upstream accumulation
306        /// kernel (the dcarry += upstream\[t\] add must use the right
307        /// floating-point type — a hard-coded f64 add silently does
308        /// nothing for an f32 carry whose `cb` isn't divisible by 8).
309        carry_elem_size: u32,
310        save_trajectory: bool, // true → upstream is per-step; false → just final
311        /// Recursive checkpointing config. `0` or `length` ⇒ full
312        /// trajectory cached, no recompute (existing behavior).
313        /// `0 < K < length` ⇒ trajectory has only K rows; the executor
314        /// recomputes intermediate carries via `forward_body` between
315        /// checkpoints. Memory: O(K · carry_bytes); time: O(length).
316        num_checkpoints: u32,
317        /// Forward body schedule (same compiled body as the forward
318        /// Op::Scan), used for recompute when `num_checkpoints` is
319        /// active. `None` for the All strategy.
320        forward_body: Option<Arc<ThunkSchedule>>,
321        /// Pristine forward body arena bytes (constants filled).
322        forward_body_init: Option<Arc<Vec<u8>>>,
323        /// Forward body's carry-Input and output slot offsets — needed
324        /// to seed/read the body during recompute.
325        forward_body_carry_in_off: usize,
326        forward_body_output_off: usize,
327        /// Forward body's per-step xs Input slots (one per outer xs).
328        /// Same indexing convention as `body_x_offs`.
329        forward_body_x_offs: Arc<Vec<usize>>,
330    },
331
332    /// Companion to `ScanBackward` that materializes one stacked
333    /// `dxs_i`. Same backward loop; per iteration, after running
334    /// body_vjp, copies its `body_dxs_out_off` slot into the outer
335    /// arena at `outer_dxs_off + t * per_step_bytes`. dcarry threading
336    /// is identical — we still need it for the body_vjp recurrence
337    /// even though we don't write it back to the outer arena.
338    ScanBackwardXs {
339        body_vjp: Arc<ThunkSchedule>,
340        body_init: Arc<Vec<u8>>,
341        body_carry_in_off: usize,
342        body_x_offs: Arc<Vec<usize>>,
343        body_d_output_off: usize,
344        body_dcarry_out_off: usize,
345        body_dxs_out_off: usize, // the body_vjp output we extract per step
346        outer_init_off: usize,
347        outer_traj_off: usize,
348        outer_upstream_off: usize,
349        outer_xs_offs: Arc<Vec<(usize, u32)>>,
350        outer_dxs_off: usize, // base of the stacked [length, *per_step] output
351        length: u32,
352        carry_bytes: u32,
353        /// Same role as `Thunk::ScanBackward::carry_elem_size`.
354        carry_elem_size: u32,
355        per_step_bytes: u32, // bytes per row of the dxs output
356        save_trajectory: bool,
357        /// Recursive checkpointing config. Same semantics as
358        /// `Thunk::ScanBackward::num_checkpoints` — `0` or `length`
359        /// means "save every step's carry"; `0 < K < length` means
360        /// the trajectory has only K rows and the executor recomputes
361        /// intermediate carries via `forward_body` (which must be
362        /// `Some`). Implemented via segment-cached recompute,
363        /// mirroring the `ScanBackward` path.
364        num_checkpoints: u32,
365        forward_body: Option<Arc<ThunkSchedule>>,
366        forward_body_init: Option<Arc<Vec<u8>>>,
367        forward_body_carry_in_off: usize,
368        forward_body_output_off: usize,
369        forward_body_x_offs: Arc<Vec<usize>>,
370    },
371    /// User-defined sub-graph (`Op::CustomFn`) — runs `fwd_body` once.
372    /// Per execution: clone `body_init`, copy each primal input from the
373    /// outer arena into its body Input slot, run the body schedule,
374    /// copy the body's single output back to the outer arena.
375    CustomFn {
376        body: Arc<ThunkSchedule>,
377        body_init: Arc<Vec<u8>>,
378        /// Per primal input: (body_input_off, outer_input_off, bytes).
379        inputs: Arc<Vec<(usize, usize, u32)>>,
380        body_output_off: usize,
381        outer_output_off: usize,
382        out_bytes: u32,
383    },
384    /// C = A @ B; C += bias; C = act(C)
385    FusedMmBiasAct {
386        a: usize,
387        w: usize,
388        bias: usize,
389        c: usize,
390        m: u32,
391        k: u32,
392        n: u32,
393        act: Option<Activation>,
394    },
395    /// out = LN(x + residual + bias, gamma, beta)
396    FusedResidualLN {
397        x: usize,
398        res: usize,
399        bias: usize,
400        g: usize,
401        b: usize,
402        out: usize,
403        rows: u32,
404        h: u32,
405        eps: f32,
406        has_bias: bool,
407    },
408    /// out = RmsNorm(x + residual + bias, gamma, beta)
409    FusedResidualRmsNorm {
410        x: usize,
411        res: usize,
412        bias: usize,
413        g: usize,
414        b: usize,
415        out: usize,
416        rows: u32,
417        h: u32,
418        eps: f32,
419        has_bias: bool,
420    },
421    /// out = bias_add(data, bias, m, n) for Binary::Add with broadcast
422    BiasAdd {
423        src: usize,
424        bias: usize,
425        dst: usize,
426        m: u32,
427        n: u32,
428    },
429    /// Element-wise binary op with NumPy-style broadcast.
430    ///
431    /// Fast path (`lhs_len == rhs_len == len`): plain element-wise loop,
432    /// SIMD-vectorized on aarch64 for `Add`/`Mul`. `bcast_*` fields
433    /// are unused.
434    ///
435    /// Broadcast path: uses `out_dims_bcast` + `bcast_lhs_strides` +
436    /// `bcast_rhs_strides` to compute per-cell indices into each
437    /// operand. The strides are precomputed at thunk-construction
438    /// time from the operands' true shapes (with stride 0 on any axis
439    /// where the operand has size 1). This is the only correct way
440    /// to handle bidirectional broadcasts like `[N, 1] op [1, S]
441    /// → [N, S]`, which simple `i % lhs_len` modulo indexing maps to
442    /// wrong cells.
443    BinaryFull {
444        lhs: usize,
445        rhs: usize,
446        dst: usize,
447        len: u32,
448        lhs_len: u32,
449        rhs_len: u32,
450        op: BinaryOp,
451        /// Output shape dims (row-major). Empty in the fast path.
452        out_dims_bcast: Vec<u32>,
453        /// Per-dim stride into `lhs` (0 where lhs broadcasts).
454        bcast_lhs_strides: Vec<u32>,
455        /// Per-dim stride into `rhs`.
456        bcast_rhs_strides: Vec<u32>,
457    },
458    /// Activation in-place
459    ActivationInPlace {
460        data: usize,
461        len: u32,
462        act: Activation,
463    },
464    /// Gather axis=0: table\[idx\] → out
465    Gather {
466        table: usize,
467        table_len: u32,
468        idx: usize,
469        dst: usize,
470        num_idx: u32,
471        trailing: u32,
472    },
473    /// Narrow: copy slice (`elem_bytes` = source element size: 4 for f32, 8 for f64).
474    Narrow {
475        src: usize,
476        dst: usize,
477        outer: u32,
478        src_stride: u32,
479        dst_stride: u32,
480        inner: u32,
481        elem_bytes: u8,
482    },
483    /// Copy (reshape, expand)
484    Copy { src: usize, dst: usize, len: u32 },
485    /// LayerNorm standalone
486    LayerNorm {
487        src: usize,
488        g: usize,
489        b: usize,
490        dst: usize,
491        rows: u32,
492        h: u32,
493        eps: f32,
494    },
495    /// GroupNorm on NCHW `[N,C,H,W]`.
496    GroupNorm {
497        src: usize,
498        g: usize,
499        b: usize,
500        dst: usize,
501        n: u32,
502        c: u32,
503        h: u32,
504        w: u32,
505        num_groups: u32,
506        eps: f32,
507    },
508    /// LayerNorm2d on NCHW (SAM / candle semantics).
509    LayerNorm2d {
510        src: usize,
511        g: usize,
512        b: usize,
513        dst: usize,
514        n: u32,
515        c: u32,
516        h: u32,
517        w: u32,
518        eps: f32,
519    },
520    /// ConvTranspose2d on NCHW.
521    ConvTranspose2d {
522        src: usize,
523        weight: usize,
524        dst: usize,
525        n: u32,
526        c_in: u32,
527        h: u32,
528        w_in: u32,
529        c_out: u32,
530        h_out: u32,
531        w_out: u32,
532        kh: u32,
533        kw: u32,
534        sh: u32,
535        sw: u32,
536        ph: u32,
537        pw: u32,
538        dh: u32,
539        dw: u32,
540        groups: u32,
541    },
542    /// Nearest 2× upsample on NCHW (per-batch slice).
543    ResizeNearest2x {
544        src: usize,
545        dst: usize,
546        n: u32,
547        c: u32,
548        h: u32,
549        w: u32,
550    },
551    /// SAM2 axial 2-D RoPE on `[batch, seq, num_heads * head_dim]`.
552    AxialRope2d {
553        src: usize,
554        dst: usize,
555        batch: u32,
556        seq: u32,
557        hidden: u32,
558        end_x: u32,
559        end_y: u32,
560        head_dim: u32,
561        num_heads: u32,
562        theta: f32,
563        repeat_factor: u32,
564    },
565    /// RMSNorm: out = (x / sqrt(mean(x^2) + eps)) * gamma + beta. No mean
566    /// subtraction, hence cheaper than LayerNorm. Used by Llama-class models.
567    RmsNorm {
568        src: usize,
569        g: usize,
570        b: usize,
571        dst: usize,
572        rows: u32,
573        h: u32,
574        eps: f32,
575    },
576    /// Softmax
577    Softmax { data: usize, rows: u32, cols: u32 },
578    /// Inclusive (or exclusive) cumulative sum along the last axis
579    /// (callers pre-flatten higher-dim cumsums via reshape views).
580    Cumsum {
581        src: usize,
582        dst: usize,
583        rows: u32,
584        cols: u32,
585        exclusive: bool,
586    },
587    /// Mamba-style selective scan (plan #15).
588    /// Inputs: x, delta \[b,s,h\], a \[h,n\], b \[b,s,n\], c \[b,s,n\].
589    /// Output: y \[b,s,h\]. State h carries through the seq.
590    SelectiveScan {
591        x: usize,
592        delta: usize,
593        a: usize,
594        b: usize,
595        c: usize,
596        dst: usize,
597        batch: u32,
598        seq: u32,
599        hidden: u32,
600        state_size: u32,
601    },
602
603    /// Gated DeltaNet linear-attention scan (Qwen3.5/3.6 trunk).
604    /// Inputs: q, k, v `[b, s, h, n]`; g, beta `[b, s, h]`. Output:
605    /// `[b, s, h, n]`. See `Op::GatedDeltaNet` for math.
606    GatedDeltaNet {
607        q: usize,
608        k: usize,
609        v: usize,
610        g: usize,
611        beta: usize,
612        /// When non-zero, load initial `[b, h, n, n]` state and write
613        /// the final state back in place after the scan.
614        state: usize,
615        dst: usize,
616        batch: u32,
617        seq: u32,
618        heads: u32,
619        state_size: u32,
620    },
621
622    /// 1×1 conv fast path (plan #26). The general Conv2D thunk
623    /// runs the textbook 7-deep loop; a 1×1 stride-1 padding-0
624    /// groups-1 conv is mathematically a per-batch matmul, and
625    /// dispatching it through BLAS is 3-10× faster than the
626    /// scalar nest. Common case: ViT patch-projection follow-on,
627    /// transformer "expert" reductions in some MoE designs.
628    ///
629    /// Per batch: weight `[c_out, c_in]` × input `[c_in, h*w]`
630    ///         = output `[c_out, h*w]`.
631    Conv2D1x1 {
632        src: usize,
633        weight: usize,
634        dst: usize,
635        n: u32,
636        c_in: u32,
637        c_out: u32,
638        hw: u32,
639    },
640
641    /// Fused dequant + matmul (plan #5). Today supports
642    /// `QuantScheme::Int8Block` (symmetric); other schemes panic
643    /// at lowering time with a clear message until kernels are added.
644    DequantMatMul {
645        x: usize,
646        w_q: usize,   // packed i8 bytes for Int8 schemes
647        scale: usize, // [k/block, n] f32 scale
648        zp: usize,    // [k/block, n] f32 zero-point (0 for sym)
649        dst: usize,
650        m: u32,
651        k: u32,
652        n: u32,
653        block_size: u32,
654        is_asymmetric: bool,
655    },
656
657    /// GGUF-format dequant + matmul. Weight is a packed byte tensor
658    /// in one of the K-quant super-block layouts (Q4_K, Q5_K, Q6_K,
659    /// Q8_K). Scales / mins live inside the packed bytes — no
660    /// side-channel scale tensor.
661    ///
662    /// Today this is a "dequant-to-scratch then sgemm" kernel — it
663    /// keeps the *arena* memory footprint down (weights stay packed)
664    /// but the dequant itself happens per matmul. A future fully
665    /// fused tile-streaming kernel would close the compute gap.
666    DequantMatMulGguf {
667        x: usize,   // f32 activations [m, k]
668        w_q: usize, // packed weight bytes (k*n elements packed)
669        dst: usize, // f32 output [m, n]
670        m: u32,
671        k: u32,
672        n: u32,
673        scheme: rlx_ir::quant::QuantScheme,
674    },
675
676    /// Int4 block dequant + matmul (packed nibbles, side scale/zp).
677    DequantMatMulInt4 {
678        x: usize,
679        w_q: usize,
680        scale: usize,
681        zp: usize,
682        dst: usize,
683        m: u32,
684        k: u32,
685        n: u32,
686        block_size: u32,
687        is_asymmetric: bool,
688    },
689
690    /// FP8 dequant + matmul (per-tensor or per-column scale).
691    DequantMatMulFp8 {
692        x: usize,
693        w_q: usize,
694        scale: usize,
695        dst: usize,
696        m: u32,
697        k: u32,
698        n: u32,
699        e5m2: bool,
700    },
701
702    /// NVFP4 (E2M1) block dequant + matmul — 16-wide groups, FP8 scales.
703    DequantMatMulNvfp4 {
704        x: usize,
705        w_q: usize,
706        scale: usize,
707        global_scale: usize,
708        dst: usize,
709        m: u32,
710        k: u32,
711        n: u32,
712    },
713
714    /// Fused LoRA matmul (plan #9): out = x·W + scale * (x·A)·B.
715    /// `r` is the LoRA rank (typically 4-64) — the rank-r
716    /// intermediate `x·A` lives in scratch, never on the arena.
717    LoraMatMul {
718        x: usize,
719        w: usize,
720        a: usize,
721        b: usize,
722        dst: usize,
723        m: u32,
724        k: u32,
725        n: u32,
726        r: u32,
727        scale: f32,
728    },
729    /// Fused sample: logits [batch, vocab] → token ids \[batch\].
730    /// See Op::Sample. Output values are f32-encoded usize indices
731    /// (matches the rest of the IR's "ids as f32" convention).
732    Sample {
733        logits: usize,
734        dst: usize,
735        batch: u32,
736        vocab: u32,
737        top_k: u32,       // 0 = disabled
738        top_p: f32,       // 1.0 = disabled
739        temperature: f32, // 1.0 = neutral
740        seed: u64,
741    },
742    /// Attention SDPA. `mask` is the offset of the optional mask tensor
743    /// (only meaningful when `mask_kind == MaskKind::Custom`); other
744    /// kinds synthesize the mask in-kernel.
745    ///
746    /// Q/K/V each carry a `_row_stride` (elements per source row).
747    /// Defaults to `heads * head_dim` — matches the standalone
748    /// "Q/K/V are their own contiguous buffers" case. The Narrow→
749    /// Attention fusion below rewrites these to the parent QKV stride
750    /// (typically `3 * heads * head_dim`) so the kernel reads QKV
751    /// directly without materializing the per-head buffers (plan #46).
752    Attention {
753        q: usize,
754        k: usize,
755        v: usize,
756        mask: usize,
757        out: usize,
758        batch: u32,
759        /// Query sequence length.
760        seq: u32,
761        /// Key/value sequence length. Differs from `seq` during cached decode.
762        kv_seq: u32,
763        heads: u32,
764        head_dim: u32,
765        mask_kind: rlx_ir::op::MaskKind,
766        q_row_stride: u32,
767        k_row_stride: u32,
768        v_row_stride: u32,
769        /// Memory layout flag. `false` (the historical default) →
770        /// `[B, S, H, D]` row-major: per-head offset is
771        /// `bi*S*H*D + si*H*D + hi*D`. `true` → `[B, H, S, D]`
772        /// (head-major), matching the convention used by rlx-cuda /
773        /// rlx-rocm / rlx-tpu: per-head offset is
774        /// `bi*H*S*D + hi*S*D + si*D`. Detected at lowering time
775        /// from the input shape vs `num_heads` / `head_dim`.
776        bhsd: bool,
777    },
778    /// [`Op::AttentionBackward`] — emits dQ, dK, or dV (see `wrt`).
779    AttentionBackward {
780        q: usize,
781        k: usize,
782        v: usize,
783        dy: usize,
784        mask: usize,
785        out: usize,
786        batch: u32,
787        seq: u32,
788        kv_seq: u32,
789        heads: u32,
790        head_dim: u32,
791        mask_kind: rlx_ir::op::MaskKind,
792        wrt: rlx_ir::op::AttentionBwdWrt,
793        bhsd: bool,
794    },
795    /// RoPE (rotary position embeddings).
796    /// `src_row_stride` is elements per source row (defaults to `hidden`
797    /// for the standalone case; set to `qkv_axis * inner` when the
798    /// thunk fusion pass below rewires Rope to read directly from the
799    /// fused QKV buffer — plan #45).
800    Rope {
801        src: usize,
802        cos: usize,
803        sin: usize,
804        dst: usize,
805        batch: u32,
806        seq: u32,
807        hidden: u32,
808        head_dim: u32,
809        n_rot: u32,
810        cos_len: u32,
811        src_row_stride: u32,
812    },
813    /// Fused attention block: QKV proj → split → \[RoPE\] → SDPA → output proj.
814    /// All intermediates stay in L1 cache. Zero arena writes between ops.
815    FusedAttnBlock {
816        hidden: usize,
817        qkv_w: usize,
818        out_w: usize,
819        mask: usize,
820        out: usize,
821        qkv_b: usize,
822        out_b: usize, // 0 = no bias
823        cos: usize,
824        sin: usize,
825        cos_len: u32, // 0 = no RoPE
826        batch: u32,
827        seq: u32,
828        hs: u32,
829        nh: u32,
830        dh: u32,
831        has_bias: bool,
832        has_rope: bool,
833    },
834    /// Fused ENTIRE transformer layer: attention + residual + LN + FFN + residual + LN.
835    /// Combines ~10 thunks into 1. All intermediates on stack. Zero arena traffic.
836    FusedBertLayer {
837        // attention
838        hidden: usize,
839        qkv_w: usize,
840        qkv_b: usize,
841        out_w: usize,
842        out_b: usize,
843        mask: usize,
844        // LN1
845        ln1_g: usize,
846        ln1_b: usize,
847        eps1: f32,
848        // FFN (GELU)
849        fc1_w: usize,
850        fc1_b: usize,
851        fc2_w: usize,
852        fc2_b: usize,
853        // LN2
854        ln2_g: usize,
855        ln2_b: usize,
856        eps2: f32,
857        // output
858        out: usize,
859        // dims
860        batch: u32,
861        seq: u32,
862        hs: u32,
863        nh: u32,
864        dh: u32,
865        int_dim: u32,
866    },
867    /// Fused Nomic transformer layer: attention+RoPE + residual + LN + SwiGLU FFN + residual + LN.
868    FusedNomicLayer {
869        hidden: usize,
870        qkv_w: usize,
871        out_w: usize,
872        mask: usize,
873        cos: usize,
874        sin: usize,
875        cos_len: u32,
876        ln1_g: usize,
877        ln1_b: usize,
878        eps1: f32,
879        fc11_w: usize,
880        fc12_w: usize,
881        fc2_w: usize,
882        ln2_g: usize,
883        ln2_b: usize,
884        eps2: f32,
885        out: usize,
886        batch: u32,
887        seq: u32,
888        hs: u32,
889        nh: u32,
890        dh: u32,
891        int_dim: u32,
892    },
893    /// Fused SwiGLU: out\[r,i\] = x\[r,i\] * silu(x[r, n_half+i]).
894    /// Input: [outer, 2*n_half] — concatenated up||gate per row.
895    /// Output: [outer, n_half].
896    FusedSwiGLU {
897        src: usize,
898        dst: usize,
899        n_half: u32,
900        total: u32,
901        gate_first: bool,
902    },
903    /// Concat along an axis: output[outer, axis, inner] = inputs concatenated.
904    /// Each entry of `inputs` is (src_offset, axis_len_for_that_input) in u32
905    /// elements. `outer`, `inner`, and `total_axis_len` are pre-computed
906    /// at compile time to avoid per-run shape work.
907    Concat {
908        dst: usize,
909        outer: u32,
910        inner: u32,
911        total_axis: u32,
912        inputs: Vec<(usize, u32)>,
913    },
914    /// Element-wise comparison: out = (lhs CMP rhs) ? 1.0 : 0.0
915    Compare {
916        lhs: usize,
917        rhs: usize,
918        dst: usize,
919        len: u32,
920        op: CmpOp,
921    },
922    /// Reduction along a contiguous range of axes. Input layout (after
923    /// shape decomposition) is `[outer, reduced, inner]`; output is
924    /// `[outer, inner]`. The single-axis cases (axis=0 → outer=1;
925    /// axis=last → inner=1) and contiguous multi-axis (e.g. reduce over
926    /// [0, 1] of an [N, C, H, W] tensor → outer=1, reduced=N*C, inner=H*W)
927    /// all map onto this triplet. Non-contiguous axes are not supported
928    /// and bail to Nop in the compile pass.
929    Reduce {
930        src: usize,
931        dst: usize,
932        outer: u32,
933        reduced: u32,
934        inner: u32,
935        op: ReduceOp,
936    },
937    /// Top-K **indices** along the last axis. Input shape `[outer, axis_dim]`,
938    /// output `[outer, k]` of f32-encoded i64 indices. Ties broken by
939    /// smaller index. Used by MoE gating + beam search.
940    TopK {
941        src: usize,
942        dst: usize,
943        outer: u32,
944        axis_dim: u32,
945        k: u32,
946    },
947    /// Indexed batched matmul: out\[i\] = input\[i\] @ weight[expert_idx\[i\]].
948    /// Naive impl per token; for real MoE workloads, sort-by-expert + run
949    /// segmented GEMM would amortize. Done when there's a workload.
950    GroupedMatMul {
951        input: usize,
952        weight: usize,
953        expert_idx: usize,
954        dst: usize,
955        m: u32,
956        k_dim: u32,
957        n: u32,
958        num_experts: u32,
959    },
960    /// GGUF K-quant packed expert stack + grouped matmul (MoE FFN).
961    DequantGroupedMatMulGguf {
962        input: usize,
963        w_q: usize,
964        expert_idx: usize,
965        dst: usize,
966        m: u32,
967        k_dim: u32,
968        n: u32,
969        num_experts: u32,
970        scheme: rlx_ir::quant::QuantScheme,
971    },
972    /// Materialize packed MoE weights to F32 `[E, K, N]` (autodiff helper).
973    DequantMoEWeightsGguf {
974        w_q: usize,
975        dst: usize,
976        k_dim: u32,
977        n: u32,
978        num_experts: u32,
979        scheme: rlx_ir::quant::QuantScheme,
980    },
981    /// Scatter-add: dst[indices\[i\] * trailing + j] += updates[i * trailing + j].
982    /// Output is zeroed first; multiple updates to the same row accumulate.
983    ScatterAdd {
984        updates: usize,
985        indices: usize,
986        dst: usize,
987        num_updates: u32,
988        out_dim: u32,
989        trailing: u32,
990    },
991    /// Ternary select: out = cond != 0 ? on_true : on_false
992    Where {
993        cond: usize,
994        on_true: usize,
995        on_false: usize,
996        dst: usize,
997        len: u32,
998    },
999    /// General N-D transpose / broadcast. `out_dims[i]` is the output's dim
1000    /// i length; `in_strides[i]` is the input stride (in elements) used to
1001    /// index that dim — 0 for broadcast dims (Expand). `in_total` is the
1002    /// total element count in the source buffer (≤ output total when
1003    /// broadcasting). Strides are pre-computed at compile time.
1004    Transpose {
1005        src: usize,
1006        dst: usize,
1007        in_total: u32,
1008        out_dims: Vec<u32>,
1009        in_strides: Vec<u32>,
1010    },
1011    /// Gather along an arbitrary axis. `outer = product(dims[..axis])`,
1012    /// `trailing = product(dims[axis+1..])`, `axis_dim` = the dimension
1013    /// being indexed into. Output: outer × num_idx × trailing.
1014    /// (axis=0 still routes to the simpler Thunk::Gather fast path.)
1015    GatherAxis {
1016        table: usize,
1017        idx: usize,
1018        dst: usize,
1019        outer: u32,
1020        axis_dim: u32,
1021        num_idx: u32,
1022        trailing: u32,
1023    },
1024    /// 2D pooling (Max or Mean). Input layout [N, C, H, W], output
1025    /// [N, C, H_out, W_out]. Padding is implicit-zero; Mean divides by
1026    /// the full kernel area (matches torch's `count_include_pad=True`).
1027    Pool2D {
1028        src: usize,
1029        dst: usize,
1030        n: u32,
1031        c: u32,
1032        h: u32,
1033        w: u32,
1034        h_out: u32,
1035        w_out: u32,
1036        kh: u32,
1037        kw: u32,
1038        sh: u32,
1039        sw: u32,
1040        ph: u32,
1041        pw: u32,
1042        kind: ReduceOp,
1043    },
1044    /// 2D convolution. Input [N, C_in, H, W], weight [C_out, C_in_per_group, kH, kW],
1045    /// output [N, C_out, H_out, W_out]. Bias is a separate Op::Binary::Add
1046    /// after the conv (matching the IR's input layout — Op::Conv has 2 inputs).
1047    /// Naive direct convolution; sufficient for correctness, not optimised.
1048    Conv2D {
1049        src: usize,
1050        weight: usize,
1051        dst: usize,
1052        n: u32,
1053        c_in: u32,
1054        h: u32,
1055        w: u32,
1056        c_out: u32,
1057        h_out: u32,
1058        w_out: u32,
1059        kh: u32,
1060        kw: u32,
1061        sh: u32,
1062        sw: u32,
1063        ph: u32,
1064        pw: u32,
1065        dh: u32,
1066        dw: u32,
1067        groups: u32,
1068    },
1069
1070    // ── Backward / training kernels ─────────────────────────────
1071    /// Real INT8 matmul with i32 accumulation.
1072    ///   `out[m, n] = requantize(bias[n] + Σₖ (x[m,k]-x_zp)·(w[k,n]-w_zp), mult, out_zp)`
1073    /// Reads `x` and `w` as i8, `bias` as i32; writes `out` as i8.
1074    /// Same kernel shape as `rlx_cortexm::dense::dense_i8` — promoted
1075    /// to a desktop thunk so a quantized graph compiled here doesn't
1076    /// have to round-trip through fake-quant.
1077    QMatMul {
1078        x: usize,
1079        w: usize,
1080        bias: usize,
1081        out: usize,
1082        m: u32,
1083        k: u32,
1084        n: u32,
1085        x_zp: i32,
1086        w_zp: i32,
1087        out_zp: i32,
1088        mult: f32,
1089    },
1090
1091    /// Real INT8 conv2d, NCHW layout. Same loop shape as `Thunk::Conv2D`
1092    /// but with i8 reads, i32 accumulation, and per-output requantize
1093    /// to i8. Bias is i32 in the accumulator scale.
1094    QConv2d {
1095        x: usize,
1096        w: usize,
1097        bias: usize,
1098        out: usize,
1099        n: u32,
1100        c_in: u32,
1101        h: u32,
1102        w_in: u32,
1103        c_out: u32,
1104        h_out: u32,
1105        w_out: u32,
1106        kh: u32,
1107        kw: u32,
1108        sh: u32,
1109        sw: u32,
1110        ph: u32,
1111        pw: u32,
1112        dh: u32,
1113        dw: u32,
1114        groups: u32,
1115        x_zp: i32,
1116        w_zp: i32,
1117        out_zp: i32,
1118        mult: f32,
1119    },
1120
1121    /// INT8 quantize. Reads `x` as f32, writes `q` as i8.
1122    /// `chan = (i / inner) % chan_dim` selects the per-channel
1123    /// scale/zp; `chan_axis` is informational only (the kernel uses
1124    /// `chan_dim` and `inner` directly).
1125    /// For per-tensor, `chan_dim = 1` and `inner = len` so `chan` is
1126    /// always 0.
1127    Quantize {
1128        x: usize,
1129        q: usize,
1130        len: u32,
1131        chan_axis: u32,
1132        chan_dim: u32,
1133        inner: u32,
1134        scales: Vec<f32>,
1135        zero_points: Vec<i32>,
1136    },
1137
1138    /// INT8 dequantize — inverse of `Thunk::Quantize`.
1139    Dequantize {
1140        q: usize,
1141        x: usize,
1142        len: u32,
1143        chan_axis: u32,
1144        chan_dim: u32,
1145        inner: u32,
1146        scales: Vec<f32>,
1147        zero_points: Vec<i32>,
1148    },
1149
1150    /// QAT fake-quantize. Per-channel (or per-tensor) symmetric
1151    /// quantize-then-dequantize on the fly. Computes
1152    ///   `s[c] = max(|x[..., c, ...]|) / q_max`
1153    /// then
1154    ///   `out[i] = clamp(round(x[i]/s[c]), -q_max, q_max) * s[c]`
1155    /// with `q_max = {127, 7, 1}` for `bits = {8, 4, 2}`. Same
1156    /// channel-layout convention as `Thunk::Quantize`: every
1157    /// element's channel is `(i / inner) % chan_dim`. The kernel
1158    /// does two passes — one to scan max-abs per channel, one to
1159    /// quant-dequant per element.
1160    FakeQuantize {
1161        x: usize,
1162        out: usize,
1163        len: u32,
1164        chan_axis: u32,
1165        chan_dim: u32,
1166        inner: u32,
1167        bits: u8,
1168        /// STE variant — informational on the forward side (output is
1169        /// the same regardless), kernel-relevant in the matching
1170        /// `FakeQuantizeBackward` thunk.
1171        ste: rlx_ir::op::SteKind,
1172        /// Scale-tracking strategy. `PerBatch` recomputes
1173        /// `max_abs/q_max` every call (the original path). `EMA{decay}`
1174        /// blends per-batch max-abs into the `state_off` buffer; `Fixed`
1175        /// reads `state_off` and never updates it.
1176        scale_mode: rlx_ir::op::ScaleMode,
1177        /// `Some(off)` for `EMA` and `Fixed`; `None` for `PerBatch`.
1178        /// Points at a `[chan_dim]` f32 buffer holding the running scale
1179        /// per channel.
1180        state_off: Option<usize>,
1181    },
1182
1183    /// Backward pass for `Op::FakeQuantize` under one of four STE
1184    /// variants. Computes `dx[i]` from the f32 forward input `x` and
1185    /// the upstream gradient `dy`, using the same per-channel scale
1186    /// scheme as the forward.
1187    FakeQuantizeBackward {
1188        x: usize,
1189        dy: usize,
1190        dx: usize,
1191        len: u32,
1192        chan_axis: u32,
1193        chan_dim: u32,
1194        inner: u32,
1195        bits: u8,
1196        ste: rlx_ir::op::SteKind,
1197    },
1198
1199    /// LSQ forward — same kernel shape as `FakeQuantize` Fixed mode.
1200    /// Reads scale from `scale_off` (a `[chan_dim]` Param tensor).
1201    FakeQuantizeLSQ {
1202        x: usize,
1203        scale_off: usize,
1204        out: usize,
1205        len: u32,
1206        chan_axis: u32,
1207        chan_dim: u32,
1208        inner: u32,
1209        bits: u8,
1210    },
1211
1212    /// LSQ backward, x-gradient. STE-clipped: passes upstream
1213    /// through inside the quantization range, zeros outside.
1214    FakeQuantizeLSQBackwardX {
1215        x: usize,
1216        scale_off: usize,
1217        dy: usize,
1218        dx: usize,
1219        len: u32,
1220        chan_axis: u32,
1221        chan_dim: u32,
1222        inner: u32,
1223        bits: u8,
1224    },
1225
1226    /// LSQ backward, scale-gradient. Per-channel:
1227    ///   `dscale[c] = sum_i ψ(x[i]/s[c]) · upstream[i]`
1228    /// where `ψ(z) = -z + round(z)` if `|z| ≤ q_max` else
1229    /// `sign(z) · q_max`. Output shape: `[chan_dim]`.
1230    FakeQuantizeLSQBackwardScale {
1231        x: usize,
1232        scale_off: usize,
1233        dy: usize,
1234        dscale: usize,
1235        len: u32,
1236        chan_axis: u32,
1237        chan_dim: u32,
1238        inner: u32,
1239        bits: u8,
1240    },
1241
1242    /// ReLU backward: `dx[i] = dy[i] if x[i] > 0 else 0`.
1243    ReluBackward {
1244        x: usize,
1245        dy: usize,
1246        dx: usize,
1247        len: u32,
1248    },
1249    /// f64 sibling of `ReluBackward` — same shape as the f32 variant
1250    /// but reads/writes 8 bytes per element. Required because
1251    /// `ReluBackward`'s `&[f32]` slot view returns half of every f64
1252    /// otherwise → backward silently produces 0 gradients on an f64
1253    /// graph. Mirrors the `ActivationBackwardF64` split.
1254    ReluBackwardF64 {
1255        x: usize,
1256        dy: usize,
1257        dx: usize,
1258        len: u32,
1259    },
1260
1261    /// Generic element-wise activation backward.
1262    /// `dx[i] = (d/dx act(x))[i] · dy[i]`. The closure dispatch is
1263    /// per-element; expensive activations (Gelu) recompute internals
1264    /// inline rather than threading an extra "saved y" tensor through.
1265    ActivationBackward {
1266        x: usize,
1267        dy: usize,
1268        dx: usize,
1269        len: u32,
1270        kind: Activation,
1271    },
1272    /// f64 sibling of `ActivationBackward` — slot offsets, len in
1273    /// elements; kernel reads/writes 8 bytes per element. Required
1274    /// because `ActivationBackward`'s `&[f32]` slot view silently
1275    /// returns garbage on an f64 graph (cb % 4 still works but every
1276    /// loaded value is half of an f64 → wrong gradient).
1277    ActivationBackwardF64 {
1278        x: usize,
1279        dy: usize,
1280        dx: usize,
1281        len: u32,
1282        kind: Activation,
1283    },
1284
1285    /// LayerNorm backward — input gradient. Recomputes mean/var/x̂ from
1286    /// `x` and emits the closed-form `d_x` per row.
1287    LayerNormBackwardInput {
1288        x: usize,
1289        gamma: usize,
1290        dy: usize,
1291        dx: usize,
1292        rows: u32,
1293        h: u32,
1294        eps: f32,
1295    },
1296
1297    /// LayerNorm backward — gamma gradient. `d_gamma[d] = Σ_row dy·x̂`.
1298    LayerNormBackwardGamma {
1299        x: usize,
1300        dy: usize,
1301        dgamma: usize,
1302        rows: u32,
1303        h: u32,
1304        eps: f32,
1305    },
1306
1307    RmsNormBackwardInput {
1308        x: usize,
1309        gamma: usize,
1310        beta: usize,
1311        dy: usize,
1312        dx: usize,
1313        rows: u32,
1314        h: u32,
1315        eps: f32,
1316    },
1317    RmsNormBackwardGamma {
1318        x: usize,
1319        gamma: usize,
1320        beta: usize,
1321        dy: usize,
1322        dgamma: usize,
1323        rows: u32,
1324        h: u32,
1325        eps: f32,
1326    },
1327    RmsNormBackwardBeta {
1328        x: usize,
1329        gamma: usize,
1330        beta: usize,
1331        dy: usize,
1332        dbeta: usize,
1333        rows: u32,
1334        h: u32,
1335        eps: f32,
1336    },
1337    RopeBackward {
1338        dy: usize,
1339        cos: usize,
1340        sin: usize,
1341        dx: usize,
1342        batch: u32,
1343        seq: u32,
1344        hidden: u32,
1345        head_dim: u32,
1346        n_rot: u32,
1347        cos_len: u32,
1348    },
1349    CumsumBackward {
1350        dy: usize,
1351        dx: usize,
1352        rows: u32,
1353        cols: u32,
1354        exclusive: bool,
1355    },
1356    GatherBackward {
1357        dy: usize,
1358        indices: usize,
1359        dst: usize,
1360        outer: u32,
1361        axis_dim: u32,
1362        num_idx: u32,
1363        trailing: u32,
1364    },
1365
1366    GroupNormBackwardInput {
1367        x: usize,
1368        gamma: usize,
1369        beta: usize,
1370        dy: usize,
1371        dx: usize,
1372        n: u32,
1373        c: u32,
1374        h: u32,
1375        w: u32,
1376        num_groups: u32,
1377        eps: f32,
1378    },
1379    GroupNormBackwardGamma {
1380        x: usize,
1381        dy: usize,
1382        dgamma: usize,
1383        n: u32,
1384        c: u32,
1385        h: u32,
1386        w: u32,
1387        num_groups: u32,
1388        eps: f32,
1389    },
1390    GroupNormBackwardBeta {
1391        dy: usize,
1392        dbeta: usize,
1393        n: u32,
1394        c: u32,
1395        h: u32,
1396        w: u32,
1397    },
1398
1399    /// 2D max-pool backward (NCHW). Recomputes the argmax position
1400    /// inside each window and accumulates `dy` into `dx` at that
1401    /// position. Output is zeroed first; ties resolve to the first
1402    /// hit (lowest (kh,kw) index), matching what the forward kernel
1403    /// does with `acc.max(v)`.
1404    MaxPool2dBackward {
1405        x: usize,
1406        dy: usize,
1407        dx: usize,
1408        n: u32,
1409        c: u32,
1410        h: u32,
1411        w: u32,
1412        h_out: u32,
1413        w_out: u32,
1414        kh: u32,
1415        kw: u32,
1416        sh: u32,
1417        sw: u32,
1418        ph: u32,
1419        pw: u32,
1420    },
1421
1422    /// 2D conv backward w.r.t. input (`dx = conv_transpose(dy, w)`).
1423    /// `dy [N, C_out, H_out, W_out]`, `w [C_out, C_in_per_group, kH, kW]`,
1424    /// `dx [N, C_in, H, W]`.
1425    Conv2dBackwardInput {
1426        dy: usize,
1427        w: usize,
1428        dx: usize,
1429        n: u32,
1430        c_in: u32,
1431        h: u32,
1432        w_in: u32,
1433        c_out: u32,
1434        h_out: u32,
1435        w_out: u32,
1436        kh: u32,
1437        kw: u32,
1438        sh: u32,
1439        sw: u32,
1440        ph: u32,
1441        pw: u32,
1442        dh: u32,
1443        dw: u32,
1444        groups: u32,
1445    },
1446
1447    /// 2D conv backward w.r.t. weight. `x [N, C_in, H, W]`,
1448    /// `dy [N, C_out, H_out, W_out]`, `dw [C_out, C_in_per_group, kH, kW]`.
1449    /// `dw` is zeroed before accumulation.
1450    Conv2dBackwardWeight {
1451        x: usize,
1452        dy: usize,
1453        dw: usize,
1454        n: u32,
1455        c_in: u32,
1456        h: u32,
1457        w: u32,
1458        c_out: u32,
1459        h_out: u32,
1460        w_out: u32,
1461        kh: u32,
1462        kw: u32,
1463        sh: u32,
1464        sw: u32,
1465        ph: u32,
1466        pw: u32,
1467        dh: u32,
1468        dw_dil: u32,
1469        groups: u32,
1470    },
1471
1472    /// Fused softmax + cross-entropy loss with f32-encoded integer
1473    /// labels. `logits [N, C]`, `labels [N]`, output `[N]` per-row loss.
1474    /// Numerically stable (max-subtract before exp).
1475    SoftmaxCrossEntropy {
1476        logits: usize,
1477        labels: usize,
1478        dst: usize,
1479        n: u32,
1480        c: u32,
1481    },
1482
1483    /// Backward of the fused loss above.
1484    /// `dlogits[n, k] = (softmax(logits[n])[k] - one_hot(labels[n])[k]) * d_loss[n]`.
1485    SoftmaxCrossEntropyBackward {
1486        logits: usize,
1487        labels: usize,
1488        d_loss: usize,
1489        dlogits: usize,
1490        n: u32,
1491        c: u32,
1492    },
1493
1494    /// User-registered custom op (CPU side). Lowered from `Op::Custom`.
1495    /// `kernel` is resolved against the global CPU kernel registry at
1496    /// compile time and stored as `Arc<dyn CpuKernel>` so execution
1497    /// avoids per-call lookups. v1: f32 contiguous only — see
1498    /// `op_registry::CpuKernel::execute_f32`.
1499    CustomOp {
1500        kernel: Arc<dyn CpuKernel>,
1501        inputs: Vec<(usize, u32, Shape)>, // (offset, len_elements, shape)
1502        output: (usize, u32, Shape),      // (offset, len_elements, shape)
1503        attrs: Vec<u8>,
1504    },
1505
    // NOTE: the paragraph below describes `Fft1d` (declared at the end of
    // this enum), not the variant that follows. It is kept as a plain
    // comment so rustdoc does not attach it to the wrong variant.
    //
    // 1D FFT along the last axis. Input/output are `[..., 2N]`
    // real-block layout (first N real, second N imag along the
    // transformed axis). `outer` is the product of all leading axes;
    // `n_complex` is N (the number of complex points). Both halves
    // of the real-block layout are read together by the kernel.
    // `dtype` selects the f32 or f64 path; the two share structure
    // but not buffers, so a flag at compile time avoids per-row
    // dispatch.
1514    /// CPU reference 3D Gaussian splat render ([`rlx_ir::Op::GaussianSplatRender`]).
1515    GaussianSplatRender {
1516        positions_off: usize,
1517        positions_len: usize,
1518        scales_off: usize,
1519        scales_len: usize,
1520        rotations_off: usize,
1521        rotations_len: usize,
1522        opacities_off: usize,
1523        opacities_len: usize,
1524        colors_off: usize,
1525        colors_len: usize,
1526        sh_coeffs_off: usize,
1527        sh_coeffs_len: usize,
1528        meta_off: usize,
1529        dst_off: usize,
1530        dst_len: usize,
1531        width: u32,
1532        height: u32,
1533        tile_size: u32,
1534        radius_scale: f32,
1535        alpha_cutoff: f32,
1536        max_splat_steps: u32,
1537        transmittance_threshold: f32,
1538        max_list_entries: u32,
1539    },
1540    GaussianSplatRenderBackward {
1541        positions_off: usize,
1542        positions_len: usize,
1543        scales_off: usize,
1544        scales_len: usize,
1545        rotations_off: usize,
1546        rotations_len: usize,
1547        opacities_off: usize,
1548        opacities_len: usize,
1549        colors_off: usize,
1550        colors_len: usize,
1551        sh_coeffs_off: usize,
1552        sh_coeffs_len: usize,
1553        meta_off: usize,
1554        d_loss_off: usize,
1555        d_loss_len: usize,
1556        packed_off: usize,
1557        packed_len: usize,
1558        width: u32,
1559        height: u32,
1560        tile_size: u32,
1561        radius_scale: f32,
1562        alpha_cutoff: f32,
1563        max_splat_steps: u32,
1564        transmittance_threshold: f32,
1565        max_list_entries: u32,
1566        loss_grad_clip: f32,
1567        sh_band: u32,
1568        max_anisotropy: f32,
1569    },
1570    /// Strict IR stage 1 — project + bin + sort + rays ([`Op::GaussianSplatPrepare`]).
1571    GaussianSplatPrepare {
1572        positions_off: usize,
1573        positions_len: usize,
1574        scales_off: usize,
1575        scales_len: usize,
1576        rotations_off: usize,
1577        rotations_len: usize,
1578        opacities_off: usize,
1579        opacities_len: usize,
1580        colors_off: usize,
1581        colors_len: usize,
1582        sh_coeffs_off: usize,
1583        sh_coeffs_len: usize,
1584        meta_off: usize,
1585        meta_len: usize,
1586        prep_off: usize,
1587        prep_len: usize,
1588        width: u32,
1589        height: u32,
1590        tile_size: u32,
1591        radius_scale: f32,
1592        alpha_cutoff: f32,
1593        max_splat_steps: u32,
1594        transmittance_threshold: f32,
1595        max_list_entries: u32,
1596    },
1597    /// Strict IR stage 2 — tile raster from prepare buffer ([`Op::GaussianSplatRasterize`]).
1598    GaussianSplatRasterize {
1599        prep_off: usize,
1600        prep_len: usize,
1601        meta_off: usize,
1602        meta_len: usize,
1603        dst_off: usize,
1604        dst_len: usize,
1605        count: usize,
1606        width: u32,
1607        height: u32,
1608        tile_size: u32,
1609        alpha_cutoff: f32,
1610        max_splat_steps: u32,
1611        transmittance_threshold: f32,
1612        max_list_entries: u32,
1613    },
1614    Fft1d {
1615        src: usize,
1616        dst: usize,
1617        outer: u32,
1618        n_complex: u32,
1619        inverse: bool,
1620        norm_tag: u32,
1621        dtype: rlx_ir::DType,
1622    },
1623}
1624
1625/// Compiled thunk schedule — the runtime hot path.
1626/// Nop thunks are filtered out at compile time for zero iteration overhead.
1627#[derive(Clone)]
1628pub struct ThunkSchedule {
1629    pub thunks: Vec<Thunk>,
1630    /// TIDE merged placement mask (union across layers).
1631    pub moe_resident: Option<std::sync::Arc<[bool]>>,
1632    /// Per MoE layer placement (`layer[e]`); preferred when set.
1633    pub moe_resident_layers: Option<std::sync::Arc<Vec<std::sync::Arc<[bool]>>>>,
1634    /// MoE router TopK capture (per-layer refresh).
1635    pub moe_topk_capture: Option<std::sync::Arc<crate::moe_topk_capture::MoeTopkCapture>>,
1636    /// Cached config values.
1637    pub mask_threshold: f32,
1638    pub mask_neg_inf: f32,
1639    pub score_skip: f32,
1640    /// Pre-compiled closure dispatch (zero match overhead). `Arc` (not
1641    /// `Box`) so the schedule can be `Clone` — multiple parallel
1642    /// executors share the same compiled closures (they're read-only
1643    /// `Fn(*mut u8)` so concurrent dispatch is safe; the arena pointer
1644    /// they receive is the only mutable state and is per-executor).
1645    pub compiled_fns: Vec<Arc<dyn Fn(*mut u8) + Send + Sync>>,
1646}
1647
1648impl ThunkSchedule {
1649    pub fn strip_nops(&mut self) {
1650        self.thunks.retain(|t| !matches!(t, Thunk::Nop));
1651        // compiled_fns must be rebuilt after stripping — caller should
1652        // call strip_nops() before compile_closures().
1653        self.compiled_fns.clear();
1654    }
1655}
1656
1657/// Get the arena byte offset for a node.
1658fn node_offset(arena: &Arena, id: NodeId) -> usize {
1659    if arena.has_buffer(id) {
1660        arena.byte_offset(id)
1661    } else {
1662        usize::MAX
1663    }
1664}
1665
1666/// Every byte-offset that a thunk reads from. Used by the Narrow→Rope
1667/// fusion (#45) to verify a Narrow's dst has exactly one consumer
1668/// before eliding it. Conservative: when in doubt about reads (an op
1669/// not yet listed here), the fusion will skip — correctness over
1670/// completeness.
1671fn thunk_read_offsets(t: &Thunk) -> Vec<usize> {
1672    match t {
1673        Thunk::Sgemm { a, b, .. } => vec![*a, *b],
1674        Thunk::DenseSolveF64 { a, b, .. } => vec![*a, *b],
1675        Thunk::DenseSolveF32 { a, b, .. } => vec![*a, *b],
1676        Thunk::BatchedDenseSolveF64 { a, b, .. } => vec![*a, *b],
1677        Thunk::BatchedDgemmF64 { a, b, .. } => vec![*a, *b],
1678        Thunk::BatchedSgemm { a, b, .. } => vec![*a, *b],
1679        Thunk::FusedMmBiasAct { a, w, bias, .. } => vec![*a, *w, *bias],
1680        Thunk::BiasAdd { src, bias, .. } => vec![*src, *bias],
1681        Thunk::BinaryFull { lhs, rhs, .. } => vec![*lhs, *rhs],
1682        Thunk::BinaryFullF64 { lhs, rhs, .. } => vec![*lhs, *rhs],
1683        Thunk::BinaryFullC64 { lhs, rhs, .. } => vec![*lhs, *rhs],
1684        Thunk::ComplexNormSqF32 { src, .. } => vec![*src],
1685        Thunk::ComplexNormSqBackwardF32 { z, g, .. } => vec![*z, *g],
1686        Thunk::ConjugateC64 { src, .. } => vec![*src],
1687        Thunk::Scan {
1688            outer_init_off,
1689            xs_inputs,
1690            ..
1691        } => {
1692            let mut v = vec![*outer_init_off];
1693            for (_, outer_xs_off, _) in xs_inputs.iter() {
1694                v.push(*outer_xs_off);
1695            }
1696            v
1697        }
1698        Thunk::ScanBackward {
1699            outer_init_off,
1700            outer_traj_off,
1701            outer_upstream_off,
1702            outer_xs_offs,
1703            ..
1704        } => {
1705            let mut v = vec![*outer_init_off, *outer_traj_off, *outer_upstream_off];
1706            for (off, _) in outer_xs_offs.iter() {
1707                v.push(*off);
1708            }
1709            v
1710        }
1711        Thunk::ScanBackwardXs {
1712            outer_init_off,
1713            outer_traj_off,
1714            outer_upstream_off,
1715            outer_xs_offs,
1716            ..
1717        } => {
1718            let mut v = vec![*outer_init_off, *outer_traj_off, *outer_upstream_off];
1719            for (off, _) in outer_xs_offs.iter() {
1720                v.push(*off);
1721            }
1722            v
1723        }
1724        Thunk::CustomFn { inputs, .. } => {
1725            inputs.iter().map(|(_, outer_off, _)| *outer_off).collect()
1726        }
1727        Thunk::ActivationInPlace { data, .. } => vec![*data],
1728        Thunk::LayerNorm { src, g, b, .. } | Thunk::GroupNorm { src, g, b, .. } => {
1729            vec![*src, *g, *b]
1730        }
1731        Thunk::ResizeNearest2x { src, .. } => vec![*src],
1732        Thunk::AxialRope2d { src, .. } => vec![*src],
1733        Thunk::FusedResidualLN {
1734            x, res, bias, g, b, ..
1735        } => vec![*x, *res, *bias, *g, *b],
1736        Thunk::FusedResidualRmsNorm {
1737            x, res, bias, g, b, ..
1738        } => vec![*x, *res, *bias, *g, *b],
1739        Thunk::RmsNorm { src, g, b, .. } => vec![*src, *g, *b],
1740        Thunk::Softmax { data, .. } => vec![*data],
1741        Thunk::Cumsum { src, .. } => vec![*src],
1742        Thunk::Sample { logits, .. } => vec![*logits],
1743        Thunk::LoraMatMul { x, w, a, b, .. } => vec![*x, *w, *a, *b],
1744        Thunk::DequantMatMul {
1745            x, w_q, scale, zp, ..
1746        } => vec![*x, *w_q, *scale, *zp],
1747        Thunk::DequantMatMulGguf { x, w_q, .. } => vec![*x, *w_q],
1748        Thunk::DequantMatMulInt4 {
1749            x, w_q, scale, zp, ..
1750        } => vec![*x, *w_q, *scale, *zp],
1751        Thunk::DequantMatMulFp8 { x, w_q, scale, .. } => vec![*x, *w_q, *scale],
1752        Thunk::DequantMatMulNvfp4 {
1753            x,
1754            w_q,
1755            scale,
1756            global_scale,
1757            ..
1758        } => vec![*x, *w_q, *scale, *global_scale],
1759        Thunk::Conv2D1x1 { src, weight, .. } => vec![*src, *weight],
1760        Thunk::SelectiveScan {
1761            x, delta, a, b, c, ..
1762        } => vec![*x, *delta, *a, *b, *c],
1763        Thunk::GatedDeltaNet {
1764            q,
1765            k,
1766            v,
1767            g,
1768            beta,
1769            state,
1770            ..
1771        } => {
1772            let mut v = vec![*q, *k, *v, *g, *beta];
1773            if *state != 0 {
1774                v.push(*state);
1775            }
1776            v
1777        }
1778        Thunk::Attention { q, k, v, mask, .. } => vec![*q, *k, *v, *mask],
1779        Thunk::AttentionBackward {
1780            q, k, v, dy, mask, ..
1781        } => {
1782            let mut v = vec![*q, *k, *v, *dy];
1783            if *mask != 0 {
1784                v.push(*mask);
1785            }
1786            v
1787        }
1788        Thunk::Rope { src, cos, sin, .. } => vec![*src, *cos, *sin],
1789        Thunk::FusedAttnBlock {
1790            hidden,
1791            qkv_w,
1792            out_w,
1793            mask,
1794            qkv_b,
1795            out_b,
1796            cos,
1797            sin,
1798            ..
1799        } => vec![*hidden, *qkv_w, *out_w, *mask, *qkv_b, *out_b, *cos, *sin],
1800        Thunk::FusedSwiGLU { src, .. } => vec![*src],
1801        Thunk::Concat { inputs, .. } => inputs.iter().map(|(off, _)| *off).collect(),
1802        Thunk::ConcatF64 { inputs, .. } => inputs.iter().map(|(off, _)| *off).collect(),
1803        Thunk::Narrow { src, .. } => vec![*src],
1804        Thunk::Copy { src, .. } => vec![*src],
1805        Thunk::Gather { table, idx, .. } => vec![*table, *idx],
1806        // Anything not enumerated → return the dst as a "read" too,
1807        // forcing the fusion to bail (read_count >= 2 → skip). Keeps
1808        // this list safe to be incomplete.
1809        _ => vec![],
1810    }
1811}
1812
/// Fused dequant + matmul (plan #5) for int8-blockwise weights.
///
/// Computes `out = x · dequant(w)`, where each run of `block_size`
/// consecutive elements of a weight column shares one f32 scale (and,
/// when `asym` is set, one f32 zero-point):
/// `dequant(w)[p, j] = (w[p, j] - zp) * scale`. The dequant happens
/// inside the inner accumulate so the f32 weight is never materialized.
///
/// * `x` — row-major activations `[m, k]`.
/// * `w_bytes` — row-major i8 weight matrix `[k, n]`.
/// * `scales` — `[ceil(k / block_size), n]`, indexed `block * n + j`.
/// * `zps` — same layout as `scales`. Only read when `asym` is true, so
///   it may be empty otherwise.
/// * `out` — row-major result `[m, n]`; every element is overwritten.
///
/// # Panics
///
/// Panics if `block_size` is zero, or if a slice is too short for an
/// index the loops actually touch.
///
/// Today this is the reference scalar implementation — the win is
/// memory bandwidth, not flops, since LLM weights dominate the
/// working set. A NEON SIMD path that loads 16 i8 → splat-scale →
/// fused-multiply-add is the natural follow-on.
#[allow(clippy::too_many_arguments)]
fn dequant_matmul_int8(
    x: &[f32],       // [m, k]
    w_bytes: &[i8],  // [k, n]
    scales: &[f32],  // [ceil(k/block), n]
    zps: &[f32],     // [ceil(k/block), n] or empty
    out: &mut [f32], // [m, n]
    m: usize,
    k: usize,
    n: usize,
    block_size: usize,
    asym: bool,
) {
    // State the precondition up front instead of relying on an
    // incidental divide-by-zero somewhere inside the loops.
    assert!(
        block_size > 0,
        "dequant_matmul_int8: block_size must be non-zero"
    );
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0f32;
            for p in 0..k {
                // Scale and zero-point share one index: the block that
                // row `p` falls into, at column `j`.
                let group = (p / block_size) * n + j;
                let s = scales[group];
                let z = if asym { zps[group] } else { 0.0 };
                let dequantized = (f32::from(w_bytes[p * n + j]) - z) * s;
                acc += x[i * k + p] * dequantized;
            }
            out[i * n + j] = acc;
        }
    }
}
1856
/// Fused dequant + matmul for 4-bit block-quantized weights.
///
/// Same contract as the int8 variant, except that the `[k, n]` weight
/// matrix is packed two codes per byte: the element at flat row-major
/// index `p * n + j` lives in byte `(p * n + j) / 2`, in the low nibble
/// when that index is even and in the high nibble when it is odd. Codes
/// are read as unsigned values and dequantized as `(code - zp) * scale`.
///
/// * `x` — row-major activations `[m, k]`.
/// * `scales` / `zps` — indexed `(p / block_size) * n + j`; `zps` is
///   only read when `asym` is true.
/// * `out` — row-major result `[m, n]`; every element is overwritten.
#[allow(clippy::too_many_arguments)]
fn dequant_matmul_int4(
    x: &[f32],
    w_bytes: &[u8],
    scales: &[f32],
    zps: &[f32],
    out: &mut [f32],
    m: usize,
    k: usize,
    n: usize,
    block_size: usize,
    asym: bool,
) {
    for row in 0..m {
        let x_base = row * k;
        let out_base = row * n;
        for col in 0..n {
            let mut sum = 0f32;
            for p in 0..k {
                let group = (p / block_size) * n + col;
                let scale = scales[group];
                let zero_point = match asym {
                    true => zps[group],
                    false => 0.0,
                };
                // Two codes per byte: even flat index → low nibble,
                // odd flat index → high nibble.
                let flat = p * n + col;
                let packed = w_bytes[flat / 2];
                let code = if flat % 2 == 1 {
                    packed >> 4
                } else {
                    packed & 0x0F
                };
                let weight = (f32::from(code) - zero_point) * scale;
                sum += x[x_base + p] * weight;
            }
            out[out_base + col] = sum;
        }
    }
}
1890
/// Decode one FP8 E4M3 byte (1 sign, 4 exponent, 3 mantissa bits,
/// exponent bias 7) into an `f32`.
///
/// Follows the OCP OFP8 E4M3 encoding (often called `e4m3fn`): the
/// format has no infinities, the all-ones exponent is an ordinary
/// binade (so the largest finite magnitude is 448), and only the bit
/// pattern `S.1111.111` is NaN. Treating exponent `0xF` as IEEE-style
/// inf/NaN would turn every weight with magnitude in `256..=448` into
/// inf or NaN.
///
/// Both zero encodings (`0x00` and `0x80`) decode to `0.0`.
fn fp8_e4m3_to_f32(b: u8) -> f32 {
    let sign = if b & 0x80 != 0 { -1.0 } else { 1.0 };
    let exp = (b >> 3) & 0x0F;
    let mant = b & 0x07;
    // The single NaN encoding per sign: exponent and mantissa all ones.
    if exp == 0x0F && mant == 0x07 {
        return f32::NAN;
    }
    if exp == 0 {
        if mant == 0 {
            return 0.0;
        }
        // Subnormal: (mant / 8) * 2^(1 - bias) == mant * 2^-9.
        return sign * f32::from(mant) * 2f32.powi(-9);
    }
    // Normal, including the top binade (exp == 0xF, mant < 7).
    sign * (1.0 + f32::from(mant) / 8.0) * 2f32.powi(i32::from(exp) - 7)
}
1910
/// Decode one FP8 E5M2 byte (1 sign, 5 exponent, 2 mantissa bits,
/// exponent bias 15) into an `f32`.
///
/// The all-ones exponent is reserved: mantissa zero is ±infinity, any
/// other mantissa is NaN. Exponent zero encodes zero (always returned
/// as `0.0`, whatever the sign bit) or a subnormal.
fn fp8_e5m2_to_f32(b: u8) -> f32 {
    let signum: f32 = if b >> 7 == 1 { -1.0 } else { 1.0 };
    let exponent = (b >> 2) & 0x1F;
    let fraction = b & 0x03;
    match (exponent, fraction) {
        (0, 0) => 0.0,
        // Subnormal: (fraction / 4) * 2^(1 - bias) == fraction * 2^-16.
        (0, frac) => signum * f32::from(frac) * 2f32.powi(-16),
        (0x1F, 0) => signum * f32::INFINITY,
        (0x1F, _) => f32::NAN,
        (e, frac) => signum * (1.0 + f32::from(frac) / 4.0) * 2f32.powi(i32::from(e) - 15),
    }
}
1930
1931#[allow(clippy::too_many_arguments)]
1932fn dequant_matmul_fp8(
1933    x: &[f32],
1934    w_bytes: &[u8],
1935    scales: &[f32],
1936    out: &mut [f32],
1937    m: usize,
1938    k: usize,
1939    n: usize,
1940    e5m2: bool,
1941) {
1942    let dequant = if e5m2 {
1943        fp8_e5m2_to_f32
1944    } else {
1945        fp8_e4m3_to_f32
1946    };
1947    for i in 0..m {
1948        for j in 0..n {
1949            let mut acc = 0f32;
1950            for p in 0..k {
1951                let w = dequant(w_bytes[p * n + j]);
1952                let s = scales.get(j).copied().unwrap_or(1.0);
1953                acc += x[i * k + p] * w * s;
1954            }
1955            out[i * n + j] = acc;
1956        }
1957    }
1958}
1959
1960#[allow(clippy::too_many_arguments)]
1961pub fn dequant_matmul_nvfp4(
1962    x: &[f32],
1963    w_bytes: &[u8],
1964    scale_bytes: &[u8],
1965    global_scale: f32,
1966    out: &mut [f32],
1967    m: usize,
1968    k: usize,
1969    n: usize,
1970) {
1971    use rlx_ir::{NVFP4_GROUP_SIZE, fp4_e2m1_to_f32, fp8_e4m3_scale_to_f32};
1972    let gs = NVFP4_GROUP_SIZE;
1973    for i in 0..m {
1974        for j in 0..n {
1975            let mut acc = 0f32;
1976            for p in 0..k {
1977                let byte_idx = (p * n + j) / 2;
1978                let nibble = if (p * n + j) & 1 == 0 {
1979                    w_bytes[byte_idx] & 0x0F
1980                } else {
1981                    w_bytes[byte_idx] >> 4
1982                };
1983                let block = p / gs;
1984                let scale = fp8_e4m3_scale_to_f32(scale_bytes[block * n + j]);
1985                let w = fp4_e2m1_to_f32(nibble) * scale * global_scale;
1986                acc += x[i * k + p] * w;
1987            }
1988            out[i * n + j] = acc;
1989        }
1990    }
1991}
1992
1993/// Fused sampling step: logits → top-k filter → top-p truncation
1994/// → softmax → multinomial sample. Operates on one row of length
1995/// `vocab` and returns the sampled index. Plan #42.
1996///
1997/// Internal scratch is on the stack via SmallVec-style fallback —
1998/// for `vocab > 8192` we heap-allocate a working buffer; below
1999/// that we keep things in a fixed array. (TODO: thread the
2000/// scratch through ThunkSchedule like sdpa_scores does.)
2001fn sample_row(
2002    logits: &[f32],
2003    top_k: usize,
2004    top_p: f32,
2005    temperature: f32,
2006    rng: &mut rlx_ir::Philox4x32,
2007) -> usize {
2008    let v = logits.len();
2009    if v == 0 {
2010        return 0;
2011    }
2012    let temp = temperature.max(1e-6);
2013    // Copy + temperature-scale into a working buffer.
2014    let mut scaled: Vec<f32> = logits.iter().map(|&x| x / temp).collect();
2015
2016    // Top-k: zero out everything but the k largest by setting to -inf.
2017    if top_k > 0 && top_k < v {
2018        // Partial selection: find k-th largest then mask below.
2019        let mut indexed: Vec<(usize, f32)> = scaled.iter().copied().enumerate().collect();
2020        // Sort descending; partial would be O(n log k), full sort is fine
2021        // for typical vocab sizes (32k-128k) — single-row work.
2022        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
2023        let cutoff = indexed[top_k - 1].1;
2024        for x in scaled.iter_mut() {
2025            if *x < cutoff {
2026                *x = f32::NEG_INFINITY;
2027            }
2028        }
2029    }
2030
2031    // Stable softmax.
2032    let mut max_l = f32::NEG_INFINITY;
2033    for &x in &scaled {
2034        if x > max_l {
2035            max_l = x;
2036        }
2037    }
2038    let mut sum = 0.0f32;
2039    for x in scaled.iter_mut() {
2040        *x = (*x - max_l).exp();
2041        sum += *x;
2042    }
2043    let inv = 1.0 / sum.max(f32::MIN_POSITIVE);
2044    for x in scaled.iter_mut() {
2045        *x *= inv;
2046    }
2047
2048    // Top-p: keep the smallest set of tokens whose cumulative
2049    // probability exceeds top_p (after sorting descending).
2050    if top_p < 1.0 {
2051        let mut indexed: Vec<(usize, f32)> = scaled.iter().copied().enumerate().collect();
2052        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
2053        let mut cum = 0.0f32;
2054        let mut keep = vec![false; v];
2055        for (idx, p) in indexed.iter() {
2056            keep[*idx] = true;
2057            cum += *p;
2058            if cum >= top_p {
2059                break;
2060            }
2061        }
2062        let mut new_sum = 0.0f32;
2063        for (i, x) in scaled.iter_mut().enumerate() {
2064            if !keep[i] {
2065                *x = 0.0;
2066            }
2067            new_sum += *x;
2068        }
2069        let inv = 1.0 / new_sum.max(f32::MIN_POSITIVE);
2070        for x in scaled.iter_mut() {
2071            *x *= inv;
2072        }
2073    }
2074
2075    // Multinomial sample via inverse-CDF.
2076    let r = rng.next_f32();
2077    let mut acc = 0.0f32;
2078    for (i, &p) in scaled.iter().enumerate() {
2079        acc += p;
2080        if r <= acc {
2081            return i;
2082        }
2083    }
2084    v - 1 // floating-point edge case fallback
2085}
2086
2087/// Apply a synthetic (kernel-generated) attention mask to a `[q_seq, k_seq]`
2088/// scores matrix. Custom masks are read from a tensor and not handled here.
2089/// `None` is a no-op so callers don't need to special-case it.
2090#[inline]
2091fn apply_synthetic_mask(
2092    scores: &mut [f32],
2093    q_seq: usize,
2094    k_seq: usize,
2095    kind: rlx_ir::op::MaskKind,
2096) {
2097    let neg = crate::config::RuntimeConfig::global().attn_mask_neg_inf;
2098    let q_offset = k_seq.saturating_sub(q_seq);
2099    match kind {
2100        rlx_ir::op::MaskKind::None | rlx_ir::op::MaskKind::Custom | rlx_ir::op::MaskKind::Bias => {}
2101        rlx_ir::op::MaskKind::Causal => {
2102            for qi in 0..q_seq {
2103                let abs_q = q_offset + qi;
2104                for ki in (abs_q + 1)..k_seq {
2105                    scores[qi * k_seq + ki] = neg;
2106                }
2107            }
2108        }
2109        rlx_ir::op::MaskKind::SlidingWindow(w) => {
2110            for qi in 0..q_seq {
2111                let abs_q = q_offset + qi;
2112                let lo = abs_q.saturating_sub(w);
2113                for ki in 0..k_seq {
2114                    if ki < lo || ki > abs_q {
2115                        scores[qi * k_seq + ki] = neg;
2116                    }
2117                }
2118            }
2119        }
2120    }
2121}
2122
2123/// Compile graph into thunk schedule.
2124pub fn compile_thunks(graph: &Graph, arena: &Arena) -> ThunkSchedule {
2125    let mut thunks = Vec::with_capacity(graph.len());
2126
2127    for node in graph.nodes() {
2128        // View ops (Reshape / same-dtype Cast / axis-0 Narrow) are aliased
2129        // to their parent's slot by the memory planner — no copy needed.
2130        // Plan #46.
2131        if rlx_opt::is_pure_view(graph, node) {
2132            thunks.push(Thunk::Nop);
2133            continue;
2134        }
2135        let t = match &node.op {
2136            Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => Thunk::Nop,
2137
2138            Op::FusedMatMulBiasAct { activation } => {
2139                let shape = &node.shape;
2140                let n = shape.dim(shape.rank() - 1).unwrap_static();
2141                let total = shape.num_elements().unwrap();
2142                let m = total / n;
2143                let a_len = get_len(graph, node.inputs[0]);
2144                let k = a_len / m;
2145                Thunk::FusedMmBiasAct {
2146                    a: node_offset(arena, node.inputs[0]),
2147                    w: node_offset(arena, node.inputs[1]),
2148                    bias: node_offset(arena, node.inputs[2]),
2149                    c: node_offset(arena, node.id),
2150                    m: m as u32,
2151                    k: k as u32,
2152                    n: n as u32,
2153                    act: *activation,
2154                }
2155            }
2156
2157            Op::FusedResidualLN { has_bias, eps } => {
2158                let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
2159                let total = node.shape.num_elements().unwrap();
2160                let rows = total / h;
2161                let (g_idx, b_idx) = if *has_bias { (3, 4) } else { (2, 3) };
2162                Thunk::FusedResidualLN {
2163                    x: node_offset(arena, node.inputs[0]),
2164                    res: node_offset(arena, node.inputs[1]),
2165                    bias: if *has_bias {
2166                        node_offset(arena, node.inputs[2])
2167                    } else {
2168                        0
2169                    },
2170                    g: node_offset(arena, node.inputs[g_idx]),
2171                    b: node_offset(arena, node.inputs[b_idx]),
2172                    out: node_offset(arena, node.id),
2173                    rows: rows as u32,
2174                    h: h as u32,
2175                    eps: *eps,
2176                    has_bias: *has_bias,
2177                }
2178            }
2179
2180            Op::FusedResidualRmsNorm { has_bias, eps } => {
2181                let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
2182                let total = node.shape.num_elements().unwrap();
2183                let rows = total / h;
2184                let (g_idx, b_idx) = if *has_bias { (3, 4) } else { (2, 3) };
2185                Thunk::FusedResidualRmsNorm {
2186                    x: node_offset(arena, node.inputs[0]),
2187                    res: node_offset(arena, node.inputs[1]),
2188                    bias: if *has_bias {
2189                        node_offset(arena, node.inputs[2])
2190                    } else {
2191                        0
2192                    },
2193                    g: node_offset(arena, node.inputs[g_idx]),
2194                    b: node_offset(arena, node.inputs[b_idx]),
2195                    out: node_offset(arena, node.id),
2196                    rows: rows as u32,
2197                    h: h as u32,
2198                    eps: *eps,
2199                    has_bias: *has_bias,
2200                }
2201            }
2202
2203            Op::MatMul => {
2204                let shape = &node.shape;
2205                let a_shape = &graph.node(node.inputs[0]).shape;
2206                let b_shape = &graph.node(node.inputs[1]).shape;
2207                let n = shape.dim(shape.rank() - 1).unwrap_static();
2208
2209                // Detect batched matmul: any rank where both inputs
2210                // and output share the same leading batch dims and
2211                // the last 2 dims form an [M, K] @ [K, N] = [M, N].
2212                // The 2-D MatMul lowering's flatten-and-call-dgemm trick
2213                // is wrong when both operands carry independent batch
2214                // dims (per-batch K dimension differs).
2215                let batched_3d = a_shape.rank() >= 3
2216                    && b_shape.rank() == a_shape.rank()
2217                    && shape.rank() == a_shape.rank()
2218                    && {
2219                        // All leading dims (everything except last 2) match.
2220                        let mut ok = true;
2221                        for d in 0..a_shape.rank() - 2 {
2222                            if a_shape.dim(d) != b_shape.dim(d) || a_shape.dim(d) != shape.dim(d) {
2223                                ok = false;
2224                                break;
2225                            }
2226                        }
2227                        ok
2228                    };
2229                if batched_3d && shape.dtype() == rlx_ir::DType::F64 {
2230                    // Batch is the product of all leading dims (every
2231                    // dim except the last 2); m/k/n are the inner
2232                    // matmul dims. Works for any rank >= 3.
2233                    let r = shape.rank();
2234                    let mut batch_prod = 1usize;
2235                    for d in 0..r - 2 {
2236                        batch_prod *= shape.dim(d).unwrap_static();
2237                    }
2238                    let m_dim = shape.dim(r - 2).unwrap_static();
2239                    let k_dim = a_shape.dim(r - 1).unwrap_static();
2240                    debug_assert_eq!(k_dim, b_shape.dim(r - 2).unwrap_static());
2241                    Thunk::BatchedDgemmF64 {
2242                        a: node_offset(arena, node.inputs[0]),
2243                        b: node_offset(arena, node.inputs[1]),
2244                        c: node_offset(arena, node.id),
2245                        batch: batch_prod as u32,
2246                        m: m_dim as u32,
2247                        k: k_dim as u32,
2248                        n: n as u32,
2249                    }
2250                } else if batched_3d && shape.dtype() == rlx_ir::DType::F32 {
2251                    // f32 batched matmul for any rank >= 3 (collapse all
2252                    // leading batch dims into a single batch count).
2253                    let r = shape.rank();
2254                    let mut batch_prod = 1usize;
2255                    for d in 0..r - 2 {
2256                        batch_prod *= shape.dim(d).unwrap_static();
2257                    }
2258                    let m_dim = shape.dim(r - 2).unwrap_static();
2259                    let k_dim = a_shape.dim(r - 1).unwrap_static();
2260                    debug_assert_eq!(k_dim, b_shape.dim(r - 2).unwrap_static());
2261                    Thunk::BatchedSgemm {
2262                        a: node_offset(arena, node.inputs[0]),
2263                        b: node_offset(arena, node.inputs[1]),
2264                        c: node_offset(arena, node.id),
2265                        batch: batch_prod as u32,
2266                        m: m_dim as u32,
2267                        k: k_dim as u32,
2268                        n: n as u32,
2269                    }
2270                } else {
2271                    let total = shape.num_elements().unwrap();
2272                    let m = total / n;
2273                    let a_len = get_len(graph, node.inputs[0]);
2274                    let k = a_len / m;
2275                    match shape.dtype() {
2276                        rlx_ir::DType::F64 => Thunk::Dgemm {
2277                            a: node_offset(arena, node.inputs[0]),
2278                            b: node_offset(arena, node.inputs[1]),
2279                            c: node_offset(arena, node.id),
2280                            m: m as u32,
2281                            k: k as u32,
2282                            n: n as u32,
2283                        },
2284                        _ => Thunk::Sgemm {
2285                            a: node_offset(arena, node.inputs[0]),
2286                            b: node_offset(arena, node.inputs[1]),
2287                            c: node_offset(arena, node.id),
2288                            m: m as u32,
2289                            k: k as u32,
2290                            n: n as u32,
2291                        },
2292                    }
2293                }
2294            }
2295
2296            Op::Binary(op) => {
2297                let lhs_len = get_len(graph, node.inputs[0]);
2298                let rhs_len = get_len(graph, node.inputs[1]);
2299                let out_len = node.shape.num_elements().unwrap();
2300                if node.shape.dtype() == rlx_ir::DType::C64 {
2301                    // Native C64 element-wise. Add/Sub/Mul/Div lower
2302                    // to `BinaryFullC64`; the rest don't have a
2303                    // single natural complex definition.
2304                    match op {
2305                        BinaryOp::Add | BinaryOp::Sub | BinaryOp::Mul | BinaryOp::Div => {}
2306                        BinaryOp::Max | BinaryOp::Min | BinaryOp::Pow => panic!(
2307                            "Op::Binary({op:?}) on DType::C64: complex \
2308                             max/min/pow have no single natural definition \
2309                             — caller should drop to 2N-real-block (see \
2310                             spike-ac) and pick a convention there"
2311                        ),
2312                    }
2313                }
2314                // Compute broadcast strides for the slow path. Empty
2315                // vectors when no broadcast is needed (the fast-path
2316                // kernel ignores them anyway).
2317                let (out_dims_bcast, bcast_lhs_strides, bcast_rhs_strides) =
2318                    if lhs_len == out_len && rhs_len == out_len {
2319                        (Vec::new(), Vec::new(), Vec::new())
2320                    } else {
2321                        let lhs_dims = get_static_dims(graph, node.inputs[0]);
2322                        let rhs_dims = get_static_dims(graph, node.inputs[1]);
2323                        let out_dims_v = get_static_dims(graph, node.id);
2324                        if lhs_dims.is_empty() || rhs_dims.is_empty() || out_dims_v.is_empty() {
2325                            // Dynamic shape — fall back to the legacy
2326                            // modulo path (correct for scalar / last-
2327                            // axis broadcast, which is the only
2328                            // dynamic case in practice).
2329                            (Vec::new(), Vec::new(), Vec::new())
2330                        } else {
2331                            let ls = broadcast_strides(&lhs_dims, &out_dims_v);
2332                            let rs = broadcast_strides(&rhs_dims, &out_dims_v);
2333                            let od: Vec<u32> = out_dims_v.iter().map(|x| *x as u32).collect();
2334                            (od, ls, rs)
2335                        }
2336                    };
2337                if node.shape.dtype() == rlx_ir::DType::C64 {
2338                    Thunk::BinaryFullC64 {
2339                        lhs: node_offset(arena, node.inputs[0]),
2340                        rhs: node_offset(arena, node.inputs[1]),
2341                        dst: node_offset(arena, node.id),
2342                        len: out_len as u32,
2343                        lhs_len: lhs_len as u32,
2344                        rhs_len: rhs_len as u32,
2345                        op: *op,
2346                        out_dims_bcast,
2347                        bcast_lhs_strides,
2348                        bcast_rhs_strides,
2349                    }
2350                } else if node.shape.dtype() == rlx_ir::DType::F64 {
2351                    // f64 path — no BiasAdd fast-path (yet); use the
2352                    // general binary-with-broadcast kernel.
2353                    Thunk::BinaryFullF64 {
2354                        lhs: node_offset(arena, node.inputs[0]),
2355                        rhs: node_offset(arena, node.inputs[1]),
2356                        dst: node_offset(arena, node.id),
2357                        len: out_len as u32,
2358                        lhs_len: lhs_len as u32,
2359                        rhs_len: rhs_len as u32,
2360                        op: *op,
2361                        out_dims_bcast,
2362                        bcast_lhs_strides,
2363                        bcast_rhs_strides,
2364                    }
2365                } else if matches!(op, BinaryOp::Add)
2366                    && rhs_len < out_len
2367                    && out_len % rhs_len == 0
2368                    && is_trailing_bias_broadcast(
2369                        graph.node(node.inputs[1]).shape.dims(),
2370                        graph.node(node.id).shape.dims(),
2371                    )
2372                {
2373                    // `BiasAdd` is only correct when the bias is a
2374                    // *trailing* broadcast — rhs dims match the right-
2375                    // hand side of the output dims (with size-1 only
2376                    // allowed in left-padded outer positions).
2377                    // SAM's rel-pos `[bh, h, w, 1, w] + [bh, h, w, h, w]`
2378                    // has rhs_len divide out_len cleanly but is a
2379                    // mid-shape singleton, NOT a trailing broadcast.
2380                    // Routing it through BiasAdd silently treats it as
2381                    // last-`rhs_len`-cols repeated — wrong values.
2382                    Thunk::BiasAdd {
2383                        src: node_offset(arena, node.inputs[0]),
2384                        bias: node_offset(arena, node.inputs[1]),
2385                        dst: node_offset(arena, node.id),
2386                        m: (out_len / rhs_len) as u32,
2387                        n: rhs_len as u32,
2388                    }
2389                } else {
2390                    let lhs_len = get_len(graph, node.inputs[0]);
2391                    Thunk::BinaryFull {
2392                        lhs: node_offset(arena, node.inputs[0]),
2393                        rhs: node_offset(arena, node.inputs[1]),
2394                        dst: node_offset(arena, node.id),
2395                        len: out_len as u32,
2396                        lhs_len: lhs_len as u32,
2397                        rhs_len: rhs_len as u32,
2398                        op: *op,
2399                        out_dims_bcast,
2400                        bcast_lhs_strides,
2401                        bcast_rhs_strides,
2402                    }
2403                }
2404            }
2405
2406            Op::Activation(act) => {
2407                let len = node.shape.num_elements().unwrap();
2408                let in_off = node_offset(arena, node.inputs[0]);
2409                let out_off = node_offset(arena, node.id);
2410                if node.shape.dtype() == rlx_ir::DType::C64 {
2411                    // Only Neg/Exp/Log/Sqrt have natural complex
2412                    // extensions used in signal-processing graphs.
2413                    // Everything else (Sigmoid, Tanh, Relu, Abs,
2414                    // Sin/Cos/Tan/Atan, Round, GeLU family) is rejected.
2415                    match act {
2416                        Activation::Neg | Activation::Exp | Activation::Log | Activation::Sqrt => {}
2417                        other => panic!(
2418                            "Op::Activation({other:?}) on DType::C64: no \
2419                             natural complex extension — supported on C64: \
2420                             Neg, Exp, Log, Sqrt"
2421                        ),
2422                    }
2423                    Thunk::ActivationC64 {
2424                        src: in_off,
2425                        dst: out_off,
2426                        len: len as u32,
2427                        kind: *act,
2428                    }
2429                } else if node.shape.dtype() == rlx_ir::DType::F64 {
2430                    Thunk::ActivationF64 {
2431                        src: in_off,
2432                        dst: out_off,
2433                        len: len as u32,
2434                        kind: *act,
2435                    }
2436                } else if in_off == out_off {
2437                    // ActivationInPlace operates on a single buffer. When the
2438                    // planner has assigned input and output the same slot
2439                    // (typical post-fusion case), we just run on that slot.
2440                    Thunk::ActivationInPlace {
2441                        data: out_off,
2442                        len: len as u32,
2443                        act: *act,
2444                    }
2445                } else {
2446                    // Two-step: copy input → output, then activate output in place.
2447                    // The schedule executes them in this order; downstream
2448                    // thunks see the activated output at out_off.
2449                    thunks.push(Thunk::Copy {
2450                        src: in_off,
2451                        dst: out_off,
2452                        len: len as u32,
2453                    });
2454                    Thunk::ActivationInPlace {
2455                        data: out_off,
2456                        len: len as u32,
2457                        act: *act,
2458                    }
2459                }
2460            }
2461
2462            Op::Gather { axis } if *axis == 0 => {
2463                let table_shape = &graph.node(node.inputs[0]).shape;
2464                let table_total = table_shape.num_elements().unwrap();
2465                let trailing: usize = (1..table_shape.rank())
2466                    .map(|i| table_shape.dim(i).unwrap_static())
2467                    .product();
2468                let idx_len = get_len(graph, node.inputs[1]);
2469                Thunk::Gather {
2470                    table: node_offset(arena, node.inputs[0]),
2471                    table_len: table_total as u32,
2472                    idx: node_offset(arena, node.inputs[1]),
2473                    dst: node_offset(arena, node.id),
2474                    num_idx: idx_len as u32,
2475                    trailing: trailing as u32,
2476                }
2477            }
2478
2479            Op::Gather { axis } => {
2480                // Non-zero axis: outer × num_idx × trailing layout.
2481                let table_shape = &graph.node(node.inputs[0]).shape;
2482                let rank = table_shape.rank();
2483                let outer: usize = (0..*axis)
2484                    .map(|i| table_shape.dim(i).unwrap_static())
2485                    .product::<usize>()
2486                    .max(1);
2487                let trailing: usize = (*axis + 1..rank)
2488                    .map(|i| table_shape.dim(i).unwrap_static())
2489                    .product::<usize>()
2490                    .max(1);
2491                let axis_dim = table_shape.dim(*axis).unwrap_static();
2492                let idx_len = get_len(graph, node.inputs[1]);
2493                Thunk::GatherAxis {
2494                    table: node_offset(arena, node.inputs[0]),
2495                    idx: node_offset(arena, node.inputs[1]),
2496                    dst: node_offset(arena, node.id),
2497                    outer: outer as u32,
2498                    axis_dim: axis_dim as u32,
2499                    num_idx: idx_len as u32,
2500                    trailing: trailing as u32,
2501                }
2502            }
2503
2504            Op::Narrow { axis, start, len } => {
2505                let in_shape = &graph.node(node.inputs[0]).shape;
2506                let elem_bytes = in_shape.dtype().size_bytes() as u8;
2507                let rank = in_shape.rank();
2508                let outer: usize = (0..*axis)
2509                    .map(|i| in_shape.dim(i).unwrap_static())
2510                    .product::<usize>()
2511                    .max(1);
2512                let inner: usize = (*axis + 1..rank)
2513                    .map(|i| in_shape.dim(i).unwrap_static())
2514                    .product::<usize>()
2515                    .max(1);
2516                let in_axis = in_shape.dim(*axis).unwrap_static();
2517                let src_byte_offset =
2518                    node_offset(arena, node.inputs[0]) + start * inner * elem_bytes as usize;
2519                Thunk::Narrow {
2520                    src: src_byte_offset,
2521                    dst: node_offset(arena, node.id),
2522                    outer: outer as u32,
2523                    src_stride: (in_axis * inner) as u32, // elements per outer step in source
2524                    dst_stride: (*len * inner) as u32,    // elements per outer step in dest
2525                    inner: (*len * inner) as u32,         // elements to copy per outer step
2526                    elem_bytes,
2527                }
2528            }
2529
2530            Op::Reshape { .. } | Op::Cast { .. } => {
2531                // Pure layout/dtype change: same total element count, plain copy.
2532                let len = node.shape.num_elements().unwrap();
2533                let src = node_offset(arena, node.inputs[0]);
2534                let dst = node_offset(arena, node.id);
2535                match node.shape.dtype() {
2536                    rlx_ir::DType::F64 => Thunk::CopyF64 {
2537                        src,
2538                        dst,
2539                        len: len as u32,
2540                    },
2541                    _ => Thunk::Copy {
2542                        src,
2543                        dst,
2544                        len: len as u32,
2545                    },
2546                }
2547            }
2548
2549            Op::Quantize {
2550                axis,
2551                scales,
2552                zero_points,
2553            } => {
2554                let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2555                Thunk::Quantize {
2556                    x: node_offset(arena, node.inputs[0]),
2557                    q: node_offset(arena, node.id),
2558                    len: node.shape.num_elements().unwrap() as u32,
2559                    chan_axis: chan_axis as u32,
2560                    chan_dim: chan_dim as u32,
2561                    inner: inner as u32,
2562                    scales: scales.clone(),
2563                    zero_points: zero_points.clone(),
2564                }
2565            }
2566
2567            Op::FakeQuantize {
2568                bits,
2569                axis,
2570                ste,
2571                scale_mode,
2572            } => {
2573                let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2574                let state_off = match scale_mode {
2575                    rlx_ir::op::ScaleMode::PerBatch => None,
2576                    rlx_ir::op::ScaleMode::EMA { .. } | rlx_ir::op::ScaleMode::Fixed => {
2577                        // Second input carries the [chan_dim] scale state.
2578                        debug_assert_eq!(
2579                            node.inputs.len(),
2580                            2,
2581                            "EMA/Fixed FakeQuantize needs a state input"
2582                        );
2583                        Some(node_offset(arena, node.inputs[1]))
2584                    }
2585                };
2586                Thunk::FakeQuantize {
2587                    x: node_offset(arena, node.inputs[0]),
2588                    out: node_offset(arena, node.id),
2589                    len: node.shape.num_elements().unwrap() as u32,
2590                    chan_axis: chan_axis as u32,
2591                    chan_dim: chan_dim as u32,
2592                    inner: inner as u32,
2593                    bits: *bits,
2594                    ste: *ste,
2595                    scale_mode: *scale_mode,
2596                    state_off,
2597                }
2598            }
2599
2600            Op::FakeQuantizeLSQ { bits, axis } => {
2601                let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2602                Thunk::FakeQuantizeLSQ {
2603                    x: node_offset(arena, node.inputs[0]),
2604                    scale_off: node_offset(arena, node.inputs[1]),
2605                    out: node_offset(arena, node.id),
2606                    len: node.shape.num_elements().unwrap() as u32,
2607                    chan_axis: chan_axis as u32,
2608                    chan_dim: chan_dim as u32,
2609                    inner: inner as u32,
2610                    bits: *bits,
2611                }
2612            }
2613
2614            Op::FakeQuantizeLSQBackwardX { bits, axis } => {
2615                let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2616                Thunk::FakeQuantizeLSQBackwardX {
2617                    x: node_offset(arena, node.inputs[0]),
2618                    scale_off: node_offset(arena, node.inputs[1]),
2619                    dy: node_offset(arena, node.inputs[2]),
2620                    dx: node_offset(arena, node.id),
2621                    len: node.shape.num_elements().unwrap() as u32,
2622                    chan_axis: chan_axis as u32,
2623                    chan_dim: chan_dim as u32,
2624                    inner: inner as u32,
2625                    bits: *bits,
2626                }
2627            }
2628
2629            Op::FakeQuantizeLSQBackwardScale { bits, axis } => {
2630                // Output shape is [chan_dim] — node.shape doesn't
2631                // describe the input data layout, but inputs[0] does.
2632                let in_shape = &graph.node(node.inputs[0]).shape;
2633                let (chan_axis, chan_dim, inner) = quant_layout(in_shape, *axis);
2634                Thunk::FakeQuantizeLSQBackwardScale {
2635                    x: node_offset(arena, node.inputs[0]),
2636                    scale_off: node_offset(arena, node.inputs[1]),
2637                    dy: node_offset(arena, node.inputs[2]),
2638                    dscale: node_offset(arena, node.id),
2639                    len: in_shape.num_elements().unwrap() as u32,
2640                    chan_axis: chan_axis as u32,
2641                    chan_dim: chan_dim as u32,
2642                    inner: inner as u32,
2643                    bits: *bits,
2644                }
2645            }
2646
2647            Op::FakeQuantizeBackward { bits, axis, ste } => {
2648                let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2649                Thunk::FakeQuantizeBackward {
2650                    x: node_offset(arena, node.inputs[0]),
2651                    dy: node_offset(arena, node.inputs[1]),
2652                    dx: node_offset(arena, node.id),
2653                    len: node.shape.num_elements().unwrap() as u32,
2654                    chan_axis: chan_axis as u32,
2655                    chan_dim: chan_dim as u32,
2656                    inner: inner as u32,
2657                    bits: *bits,
2658                    ste: *ste,
2659                }
2660            }
2661
2662            Op::Dequantize {
2663                axis,
2664                scales,
2665                zero_points,
2666            } => {
2667                let (chan_axis, chan_dim, inner) = quant_layout(&node.shape, *axis);
2668                Thunk::Dequantize {
2669                    q: node_offset(arena, node.inputs[0]),
2670                    x: node_offset(arena, node.id),
2671                    len: node.shape.num_elements().unwrap() as u32,
2672                    chan_axis: chan_axis as u32,
2673                    chan_dim: chan_dim as u32,
2674                    inner: inner as u32,
2675                    scales: scales.clone(),
2676                    zero_points: zero_points.clone(),
2677                }
2678            }
2679
2680            Op::Expand { .. } => {
2681                // Broadcast: build per-output-dim strides where any input dim
2682                // of size 1 has stride 0 (read the same element repeatedly).
2683                // Reuses the Thunk::Transpose runtime — N-D walk with strides
2684                // is identical; only the strides differ.
2685                let in_shape = &graph.node(node.inputs[0]).shape;
2686                let out_shape = &node.shape;
2687                let in_rank = in_shape.rank();
2688                let out_rank = out_shape.rank();
2689                // Implicit leading 1s if input has lower rank.
2690                let pad = out_rank.saturating_sub(in_rank);
2691                let in_dims: Vec<usize> = (0..out_rank)
2692                    .map(|i| {
2693                        if i < pad {
2694                            1
2695                        } else {
2696                            in_shape.dim(i - pad).unwrap_static()
2697                        }
2698                    })
2699                    .collect();
2700                // Row-major input strides (over the padded shape).
2701                let mut in_strides_full = vec![1usize; out_rank];
2702                for d in (0..out_rank.saturating_sub(1)).rev() {
2703                    in_strides_full[d] = in_strides_full[d + 1] * in_dims[d + 1];
2704                }
2705                let out_dims: Vec<u32> = (0..out_rank)
2706                    .map(|i| out_shape.dim(i).unwrap_static() as u32)
2707                    .collect();
2708                // Stride is 0 for broadcast dims (in_dim == 1 && out_dim > 1).
2709                let in_strides: Vec<u32> = (0..out_rank)
2710                    .map(|i| {
2711                        if in_dims[i] == 1 && (out_dims[i] as usize) > 1 {
2712                            0
2713                        } else {
2714                            in_strides_full[i] as u32
2715                        }
2716                    })
2717                    .collect();
2718                let in_total = in_dims.iter().product::<usize>() as u32;
2719                let src = node_offset(arena, node.inputs[0]);
2720                let dst = node_offset(arena, node.id);
2721                match node.shape.dtype() {
2722                    rlx_ir::DType::F64 => Thunk::TransposeF64 {
2723                        src,
2724                        dst,
2725                        in_total,
2726                        out_dims,
2727                        in_strides,
2728                    },
2729                    _ => Thunk::Transpose {
2730                        src,
2731                        dst,
2732                        in_total,
2733                        out_dims,
2734                        in_strides,
2735                    },
2736                }
2737            }
2738
2739            Op::RmsNorm { eps, .. } => {
2740                let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
2741                let total = node.shape.num_elements().unwrap();
2742                Thunk::RmsNorm {
2743                    src: node_offset(arena, node.inputs[0]),
2744                    g: node_offset(arena, node.inputs[1]),
2745                    b: node_offset(arena, node.inputs[2]),
2746                    dst: node_offset(arena, node.id),
2747                    rows: (total / h) as u32,
2748                    h: h as u32,
2749                    eps: *eps,
2750                }
2751            }
2752
2753            Op::LayerNorm { eps, .. } => {
2754                let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
2755                let total = node.shape.num_elements().unwrap();
2756                Thunk::LayerNorm {
2757                    src: node_offset(arena, node.inputs[0]),
2758                    g: node_offset(arena, node.inputs[1]),
2759                    b: node_offset(arena, node.inputs[2]),
2760                    dst: node_offset(arena, node.id),
2761                    rows: (total / h) as u32,
2762                    h: h as u32,
2763                    eps: *eps,
2764                }
2765            }
2766
2767            Op::GroupNorm { num_groups, eps } => {
2768                let in_shape = &graph.node(node.inputs[0]).shape;
2769                Thunk::GroupNorm {
2770                    src: node_offset(arena, node.inputs[0]),
2771                    g: node_offset(arena, node.inputs[1]),
2772                    b: node_offset(arena, node.inputs[2]),
2773                    dst: node_offset(arena, node.id),
2774                    n: in_shape.dim(0).unwrap_static() as u32,
2775                    c: in_shape.dim(1).unwrap_static() as u32,
2776                    h: in_shape.dim(2).unwrap_static() as u32,
2777                    w: in_shape.dim(3).unwrap_static() as u32,
2778                    num_groups: *num_groups as u32,
2779                    eps: *eps,
2780                }
2781            }
2782
2783            Op::LayerNorm2d { eps } => {
2784                let in_shape = &graph.node(node.inputs[0]).shape;
2785                Thunk::LayerNorm2d {
2786                    src: node_offset(arena, node.inputs[0]),
2787                    g: node_offset(arena, node.inputs[1]),
2788                    b: node_offset(arena, node.inputs[2]),
2789                    dst: node_offset(arena, node.id),
2790                    n: in_shape.dim(0).unwrap_static() as u32,
2791                    c: in_shape.dim(1).unwrap_static() as u32,
2792                    h: in_shape.dim(2).unwrap_static() as u32,
2793                    w: in_shape.dim(3).unwrap_static() as u32,
2794                    eps: *eps,
2795                }
2796            }
2797
2798            Op::ConvTranspose2d {
2799                kernel_size,
2800                stride,
2801                padding,
2802                dilation,
2803                output_padding: _,
2804                groups,
2805            } => {
2806                let in_shape = &graph.node(node.inputs[0]).shape;
2807                let out_shape = &node.shape;
2808                Thunk::ConvTranspose2d {
2809                    src: node_offset(arena, node.inputs[0]),
2810                    weight: node_offset(arena, node.inputs[1]),
2811                    dst: node_offset(arena, node.id),
2812                    n: in_shape.dim(0).unwrap_static() as u32,
2813                    c_in: in_shape.dim(1).unwrap_static() as u32,
2814                    h: in_shape.dim(2).unwrap_static() as u32,
2815                    w_in: in_shape.dim(3).unwrap_static() as u32,
2816                    c_out: out_shape.dim(1).unwrap_static() as u32,
2817                    h_out: out_shape.dim(2).unwrap_static() as u32,
2818                    w_out: out_shape.dim(3).unwrap_static() as u32,
2819                    kh: kernel_size[0] as u32,
2820                    kw: kernel_size[1] as u32,
2821                    sh: stride.first().copied().unwrap_or(1) as u32,
2822                    sw: stride.get(1).copied().unwrap_or(1) as u32,
2823                    ph: padding.first().copied().unwrap_or(0) as u32,
2824                    pw: padding.get(1).copied().unwrap_or(0) as u32,
2825                    dh: dilation.first().copied().unwrap_or(1) as u32,
2826                    dw: dilation.get(1).copied().unwrap_or(1) as u32,
2827                    groups: *groups as u32,
2828                }
2829            }
2830
2831            Op::ResizeNearest2x => {
2832                let in_shape = &graph.node(node.inputs[0]).shape;
2833                Thunk::ResizeNearest2x {
2834                    src: node_offset(arena, node.inputs[0]),
2835                    dst: node_offset(arena, node.id),
2836                    n: in_shape.dim(0).unwrap_static() as u32,
2837                    c: in_shape.dim(1).unwrap_static() as u32,
2838                    h: in_shape.dim(2).unwrap_static() as u32,
2839                    w: in_shape.dim(3).unwrap_static() as u32,
2840                }
2841            }
2842
2843            Op::AxialRope2d {
2844                end_x,
2845                end_y,
2846                head_dim,
2847                num_heads,
2848                theta,
2849                repeat_factor,
2850            } => {
2851                let in_shape = &graph.node(node.inputs[0]).shape;
2852                let batch = in_shape.dim(0).unwrap_static() as u32;
2853                let seq = in_shape.dim(1).unwrap_static() as u32;
2854                let hidden = in_shape.dim(2).unwrap_static() as u32;
2855                Thunk::AxialRope2d {
2856                    src: node_offset(arena, node.inputs[0]),
2857                    dst: node_offset(arena, node.id),
2858                    batch,
2859                    seq,
2860                    hidden,
2861                    end_x: *end_x as u32,
2862                    end_y: *end_y as u32,
2863                    head_dim: *head_dim as u32,
2864                    num_heads: *num_heads as u32,
2865                    theta: *theta,
2866                    repeat_factor: *repeat_factor as u32,
2867                }
2868            }
2869
2870            Op::Softmax { axis } => {
2871                let rank = node.shape.rank();
2872                let ax = if *axis < 0 {
2873                    (rank as i32 + axis) as usize
2874                } else {
2875                    *axis as usize
2876                };
2877                let cols = node.shape.dim(ax).unwrap_static();
2878                let total = node.shape.num_elements().unwrap();
2879                let in_off = node_offset(arena, node.inputs[0]);
2880                let out_off = node_offset(arena, node.id);
2881                // Softmax kernel runs in-place on its data buffer. If the
2882                // planner gave input and output separate slots (their live
2883                // ranges overlap, so no aliasing), the output starts
2884                // uninitialized — emit a Copy first so the data is there.
2885                // Same pattern as Op::Activation.
2886                if in_off != out_off {
2887                    thunks.push(Thunk::Copy {
2888                        src: in_off,
2889                        dst: out_off,
2890                        len: total as u32,
2891                    });
2892                }
2893                Thunk::Softmax {
2894                    data: out_off,
2895                    rows: (total / cols) as u32,
2896                    cols: cols as u32,
2897                }
2898            }
2899
2900            Op::SelectiveScan { state_size } => {
2901                let in_shape = &graph.node(node.inputs[0]).shape;
2902                let (batch, seq, hidden) = (
2903                    in_shape.dim(0).unwrap_static(),
2904                    in_shape.dim(1).unwrap_static(),
2905                    in_shape.dim(2).unwrap_static(),
2906                );
2907                Thunk::SelectiveScan {
2908                    x: node_offset(arena, node.inputs[0]),
2909                    delta: node_offset(arena, node.inputs[1]),
2910                    a: node_offset(arena, node.inputs[2]),
2911                    b: node_offset(arena, node.inputs[3]),
2912                    c: node_offset(arena, node.inputs[4]),
2913                    dst: node_offset(arena, node.id),
2914                    batch: batch as u32,
2915                    seq: seq as u32,
2916                    hidden: hidden as u32,
2917                    state_size: *state_size as u32,
2918                }
2919            }
2920
2921            Op::GatedDeltaNet {
2922                state_size,
2923                carry_state,
2924            } => {
2925                let q_shape = &graph.node(node.inputs[0]).shape;
2926                let (batch, seq, heads) = (
2927                    q_shape.dim(0).unwrap_static(),
2928                    q_shape.dim(1).unwrap_static(),
2929                    q_shape.dim(2).unwrap_static(),
2930                );
2931                let state_off = if *carry_state {
2932                    node_offset(arena, node.inputs[5])
2933                } else {
2934                    0
2935                };
2936                Thunk::GatedDeltaNet {
2937                    q: node_offset(arena, node.inputs[0]),
2938                    k: node_offset(arena, node.inputs[1]),
2939                    v: node_offset(arena, node.inputs[2]),
2940                    g: node_offset(arena, node.inputs[3]),
2941                    beta: node_offset(arena, node.inputs[4]),
2942                    state: state_off,
2943                    dst: node_offset(arena, node.id),
2944                    batch: batch as u32,
2945                    seq: seq as u32,
2946                    heads: heads as u32,
2947                    state_size: *state_size as u32,
2948                }
2949            }
2950
2951            Op::QMatMul {
2952                x_zp,
2953                w_zp,
2954                out_zp,
2955                mult,
2956            } => {
2957                let x_shape = &graph.node(node.inputs[0]).shape;
2958                let w_shape = &graph.node(node.inputs[1]).shape;
2959                let m = x_shape.dim(0).unwrap_static();
2960                let k = x_shape.dim(1).unwrap_static();
2961                let n = w_shape.dim(1).unwrap_static();
2962                Thunk::QMatMul {
2963                    x: node_offset(arena, node.inputs[0]),
2964                    w: node_offset(arena, node.inputs[1]),
2965                    bias: node_offset(arena, node.inputs[2]),
2966                    out: node_offset(arena, node.id),
2967                    m: m as u32,
2968                    k: k as u32,
2969                    n: n as u32,
2970                    x_zp: *x_zp,
2971                    w_zp: *w_zp,
2972                    out_zp: *out_zp,
2973                    mult: *mult,
2974                }
2975            }
2976
2977            Op::QConv2d {
2978                kernel_size,
2979                stride,
2980                padding,
2981                dilation,
2982                groups,
2983                x_zp,
2984                w_zp,
2985                out_zp,
2986                mult,
2987            } => {
2988                let in_shape = &graph.node(node.inputs[0]).shape;
2989                let w_shape = &graph.node(node.inputs[1]).shape;
2990                let out_shape = &node.shape;
2991                if kernel_size.len() == 2
2992                    && in_shape.rank() == 4
2993                    && w_shape.rank() == 4
2994                    && out_shape.rank() == 4
2995                {
2996                    Thunk::QConv2d {
2997                        x: node_offset(arena, node.inputs[0]),
2998                        w: node_offset(arena, node.inputs[1]),
2999                        bias: node_offset(arena, node.inputs[2]),
3000                        out: node_offset(arena, node.id),
3001                        n: in_shape.dim(0).unwrap_static() as u32,
3002                        c_in: in_shape.dim(1).unwrap_static() as u32,
3003                        h: in_shape.dim(2).unwrap_static() as u32,
3004                        w_in: in_shape.dim(3).unwrap_static() as u32,
3005                        c_out: out_shape.dim(1).unwrap_static() as u32,
3006                        h_out: out_shape.dim(2).unwrap_static() as u32,
3007                        w_out: out_shape.dim(3).unwrap_static() as u32,
3008                        kh: kernel_size[0] as u32,
3009                        kw: kernel_size[1] as u32,
3010                        sh: stride.first().copied().unwrap_or(1) as u32,
3011                        sw: stride.get(1).copied().unwrap_or(1) as u32,
3012                        ph: padding.first().copied().unwrap_or(0) as u32,
3013                        pw: padding.get(1).copied().unwrap_or(0) as u32,
3014                        dh: dilation.first().copied().unwrap_or(1) as u32,
3015                        dw: dilation.get(1).copied().unwrap_or(1) as u32,
3016                        groups: *groups as u32,
3017                        x_zp: *x_zp,
3018                        w_zp: *w_zp,
3019                        out_zp: *out_zp,
3020                        mult: *mult,
3021                    }
3022                } else {
3023                    Thunk::Nop
3024                }
3025            }
3026
3027            Op::DequantMatMul { scheme } => {
3028                use rlx_ir::quant::QuantScheme;
3029                let n = node.shape.dim(node.shape.rank() - 1).unwrap_static();
3030                let total = node.shape.num_elements().unwrap();
3031                let m = total / n.max(1);
3032                let x_total = graph.node(node.inputs[0]).shape.num_elements().unwrap();
3033                let k = x_total / m.max(1);
3034                if scheme.is_gguf() {
3035                    Thunk::DequantMatMulGguf {
3036                        x: node_offset(arena, node.inputs[0]),
3037                        w_q: node_offset(arena, node.inputs[1]),
3038                        dst: node_offset(arena, node.id),
3039                        m: m as u32,
3040                        k: k as u32,
3041                        n: n as u32,
3042                        scheme: *scheme,
3043                    }
3044                } else {
3045                    match scheme {
3046                        QuantScheme::Nvfp4Block => Thunk::DequantMatMulNvfp4 {
3047                            x: node_offset(arena, node.inputs[0]),
3048                            w_q: node_offset(arena, node.inputs[1]),
3049                            scale: node_offset(arena, node.inputs[2]),
3050                            global_scale: node_offset(arena, node.inputs[3]),
3051                            dst: node_offset(arena, node.id),
3052                            m: m as u32,
3053                            k: k as u32,
3054                            n: n as u32,
3055                        },
3056                        QuantScheme::Int4Block { block_size } => Thunk::DequantMatMulInt4 {
3057                            x: node_offset(arena, node.inputs[0]),
3058                            w_q: node_offset(arena, node.inputs[1]),
3059                            scale: node_offset(arena, node.inputs[2]),
3060                            zp: node_offset(arena, node.inputs[3]),
3061                            dst: node_offset(arena, node.id),
3062                            m: m as u32,
3063                            k: k as u32,
3064                            n: n as u32,
3065                            block_size: *block_size,
3066                            is_asymmetric: false,
3067                        },
3068                        QuantScheme::Fp8E4m3 => Thunk::DequantMatMulFp8 {
3069                            x: node_offset(arena, node.inputs[0]),
3070                            w_q: node_offset(arena, node.inputs[1]),
3071                            scale: node_offset(arena, node.inputs[2]),
3072                            dst: node_offset(arena, node.id),
3073                            m: m as u32,
3074                            k: k as u32,
3075                            n: n as u32,
3076                            e5m2: false,
3077                        },
3078                        QuantScheme::Fp8E5m2 => Thunk::DequantMatMulFp8 {
3079                            x: node_offset(arena, node.inputs[0]),
3080                            w_q: node_offset(arena, node.inputs[1]),
3081                            scale: node_offset(arena, node.inputs[2]),
3082                            dst: node_offset(arena, node.id),
3083                            m: m as u32,
3084                            k: k as u32,
3085                            n: n as u32,
3086                            e5m2: true,
3087                        },
3088                        QuantScheme::Int8Block { block_size } => Thunk::DequantMatMul {
3089                            x: node_offset(arena, node.inputs[0]),
3090                            w_q: node_offset(arena, node.inputs[1]),
3091                            scale: node_offset(arena, node.inputs[2]),
3092                            zp: node_offset(arena, node.inputs[3]),
3093                            dst: node_offset(arena, node.id),
3094                            m: m as u32,
3095                            k: k as u32,
3096                            n: n as u32,
3097                            block_size: *block_size,
3098                            is_asymmetric: false,
3099                        },
3100                        QuantScheme::Int8BlockAsym { block_size } => Thunk::DequantMatMul {
3101                            x: node_offset(arena, node.inputs[0]),
3102                            w_q: node_offset(arena, node.inputs[1]),
3103                            scale: node_offset(arena, node.inputs[2]),
3104                            zp: node_offset(arena, node.inputs[3]),
3105                            dst: node_offset(arena, node.id),
3106                            m: m as u32,
3107                            k: k as u32,
3108                            n: n as u32,
3109                            block_size: *block_size,
3110                            is_asymmetric: true,
3111                        },
3112                        other => panic!(
3113                            "DequantMatMul on CPU supports Int8/Int4/FP8/NVFP4 legacy or GGUF schemes; got {other}"
3114                        ),
3115                    }
3116                }
3117            }
3118
3119            Op::LoraMatMul { scale } => {
3120                // x [m, k], w [k, n], a [k, r], b [r, n].
3121                let n = node.shape.dim(node.shape.rank() - 1).unwrap_static();
3122                let total = node.shape.num_elements().unwrap();
3123                let m = total / n.max(1);
3124                let x_total = graph.node(node.inputs[0]).shape.num_elements().unwrap();
3125                let k = x_total / m.max(1);
3126                let a_total = graph.node(node.inputs[2]).shape.num_elements().unwrap();
3127                let r = a_total / k.max(1);
3128                Thunk::LoraMatMul {
3129                    x: node_offset(arena, node.inputs[0]),
3130                    w: node_offset(arena, node.inputs[1]),
3131                    a: node_offset(arena, node.inputs[2]),
3132                    b: node_offset(arena, node.inputs[3]),
3133                    dst: node_offset(arena, node.id),
3134                    m: m as u32,
3135                    k: k as u32,
3136                    n: n as u32,
3137                    r: r as u32,
3138                    scale: *scale,
3139                }
3140            }
3141
3142            Op::Sample {
3143                top_k,
3144                top_p,
3145                temperature,
3146                seed,
3147            } => {
3148                let in_shape = &graph.node(node.inputs[0]).shape;
3149                // Logits are [batch, vocab] (or [vocab] → batch=1).
3150                let (batch, vocab) = if in_shape.rank() >= 2 {
3151                    (
3152                        in_shape.dim(0).unwrap_static(),
3153                        in_shape.dim(in_shape.rank() - 1).unwrap_static(),
3154                    )
3155                } else {
3156                    (1, in_shape.num_elements().unwrap_or(0))
3157                };
3158                Thunk::Sample {
3159                    logits: node_offset(arena, node.inputs[0]),
3160                    dst: node_offset(arena, node.id),
3161                    batch: batch as u32,
3162                    vocab: vocab as u32,
3163                    top_k: *top_k as u32,
3164                    top_p: *top_p,
3165                    temperature: *temperature,
3166                    seed: *seed,
3167                }
3168            }
3169
3170            Op::Cumsum { axis, exclusive } => {
3171                // For now CPU only supports last-axis cumsum (the
3172                // common case for sampling / ragged offsets).
3173                // Other axes can lower via Transpose → Cumsum →
3174                // Transpose; not on the hot path today.
3175                let rank = node.shape.rank();
3176                let ax = if *axis < 0 {
3177                    (rank as i32 + axis) as usize
3178                } else {
3179                    *axis as usize
3180                };
3181                assert_eq!(
3182                    ax,
3183                    rank - 1,
3184                    "Cumsum only supports the last axis on CPU today"
3185                );
3186                let cols = node.shape.dim(ax).unwrap_static();
3187                let total = node.shape.num_elements().unwrap();
3188                Thunk::Cumsum {
3189                    src: node_offset(arena, node.inputs[0]),
3190                    dst: node_offset(arena, node.id),
3191                    rows: (total / cols) as u32,
3192                    cols: cols as u32,
3193                    exclusive: *exclusive,
3194                }
3195            }
3196
3197            Op::Attention {
3198                num_heads,
3199                head_dim,
3200                mask_kind,
3201                score_scale: _,
3202                attn_logit_softcap: _,
3203            } => {
3204                // Layout dispatch: rank-4 input could be either
3205                // `[B, S, H, D]` (CPU's historical convention) or
3206                // `[B, H, S, D]` (the convention the GPU/TPU backends
3207                // share). Disambiguate by which axis matches
3208                // `num_heads`. Rank-3 is always `[B, S, H*D]`.
3209                let q_shape = &graph.node(node.inputs[0]).shape;
3210                let k_shape = &graph.node(node.inputs[1]).shape;
3211                let rank = q_shape.rank();
3212                let (batch, seq, kv_seq, bhsd) = if rank == 4 {
3213                    let d1 = q_shape.dim(1).unwrap_static();
3214                    let d2 = q_shape.dim(2).unwrap_static();
3215                    if d1 == *num_heads {
3216                        // [B, H, S, D]
3217                        (
3218                            q_shape.dim(0).unwrap_static(),
3219                            d2,
3220                            k_shape.dim(2).unwrap_static(),
3221                            true,
3222                        )
3223                    } else {
3224                        // [B, S, H, D]
3225                        (
3226                            q_shape.dim(0).unwrap_static(),
3227                            d1,
3228                            k_shape.dim(1).unwrap_static(),
3229                            false,
3230                        )
3231                    }
3232                } else if rank >= 3 {
3233                    (
3234                        q_shape.dim(0).unwrap_static(),
3235                        q_shape.dim(1).unwrap_static(),
3236                        k_shape.dim(1).unwrap_static(),
3237                        false,
3238                    )
3239                } else {
3240                    (
3241                        1,
3242                        q_shape.dim(0).unwrap_static(),
3243                        k_shape.dim(0).unwrap_static(),
3244                        false,
3245                    )
3246                };
3247                let mask_off = if matches!(
3248                    mask_kind,
3249                    rlx_ir::op::MaskKind::Custom | rlx_ir::op::MaskKind::Bias
3250                ) {
3251                    node_offset(arena, node.inputs[3])
3252                } else {
3253                    0
3254                };
3255                let hs = (*num_heads * *head_dim) as u32;
3256                Thunk::Attention {
3257                    q: node_offset(arena, node.inputs[0]),
3258                    k: node_offset(arena, node.inputs[1]),
3259                    v: node_offset(arena, node.inputs[2]),
3260                    mask: mask_off,
3261                    out: node_offset(arena, node.id),
3262                    batch: batch as u32,
3263                    seq: seq as u32,
3264                    kv_seq: kv_seq as u32,
3265                    heads: *num_heads as u32,
3266                    head_dim: *head_dim as u32,
3267                    mask_kind: *mask_kind,
3268                    // Defaults: each input is its own contiguous buffer
3269                    // with row stride = hidden. Rewritten by the
3270                    // Narrow→Attention fusion when applicable.
3271                    q_row_stride: hs,
3272                    k_row_stride: hs,
3273                    v_row_stride: hs,
3274                    bhsd,
3275                }
3276            }
3277
3278            Op::AttentionBackward {
3279                num_heads,
3280                head_dim,
3281                mask_kind,
3282                wrt,
3283            } => {
3284                let q_shape = &graph.node(node.inputs[0]).shape;
3285                let k_shape = &graph.node(node.inputs[1]).shape;
3286                let rank = q_shape.rank();
3287                let (batch, seq, kv_seq, bhsd) = if rank == 4 {
3288                    let d1 = q_shape.dim(1).unwrap_static();
3289                    let d2 = q_shape.dim(2).unwrap_static();
3290                    if d1 == *num_heads {
3291                        (
3292                            q_shape.dim(0).unwrap_static(),
3293                            d2,
3294                            k_shape.dim(2).unwrap_static(),
3295                            true,
3296                        )
3297                    } else {
3298                        (
3299                            q_shape.dim(0).unwrap_static(),
3300                            d1,
3301                            k_shape.dim(1).unwrap_static(),
3302                            false,
3303                        )
3304                    }
3305                } else if rank >= 3 {
3306                    (
3307                        q_shape.dim(0).unwrap_static(),
3308                        q_shape.dim(1).unwrap_static(),
3309                        k_shape.dim(1).unwrap_static(),
3310                        false,
3311                    )
3312                } else {
3313                    (
3314                        1,
3315                        q_shape.dim(0).unwrap_static(),
3316                        k_shape.dim(0).unwrap_static(),
3317                        false,
3318                    )
3319                };
3320                let mask_off = if matches!(
3321                    mask_kind,
3322                    rlx_ir::op::MaskKind::Custom | rlx_ir::op::MaskKind::Bias
3323                ) {
3324                    node_offset(arena, node.inputs[4])
3325                } else {
3326                    0
3327                };
3328                Thunk::AttentionBackward {
3329                    q: node_offset(arena, node.inputs[0]),
3330                    k: node_offset(arena, node.inputs[1]),
3331                    v: node_offset(arena, node.inputs[2]),
3332                    dy: node_offset(arena, node.inputs[3]),
3333                    mask: mask_off,
3334                    out: node_offset(arena, node.id),
3335                    batch: batch as u32,
3336                    seq: seq as u32,
3337                    kv_seq: kv_seq as u32,
3338                    heads: *num_heads as u32,
3339                    head_dim: *head_dim as u32,
3340                    mask_kind: *mask_kind,
3341                    wrt: *wrt,
3342                    bhsd,
3343                }
3344            }
3345
3346            Op::FusedAttentionBlock {
3347                num_heads,
3348                head_dim,
3349                has_bias,
3350                has_rope,
3351            } => {
3352                let x_shape = &graph.node(node.inputs[0]).shape;
3353                let (batch, seq) = if x_shape.rank() >= 3 {
3354                    (
3355                        x_shape.dim(0).unwrap_static(),
3356                        x_shape.dim(1).unwrap_static(),
3357                    )
3358                } else {
3359                    let total = x_shape.num_elements().unwrap();
3360                    let s = x_shape.dim(x_shape.rank() - 2).unwrap_static();
3361                    (total / (s * num_heads * head_dim), s)
3362                };
3363                let hs = (*num_heads * *head_dim) as u32;
3364                // Inputs: hidden, qkv_w, out_w, mask, [qkv_b, out_b], [cos, sin]
3365                let mut idx = 4;
3366                let (qkv_b_off, out_b_off) = if *has_bias {
3367                    let qb = node_offset(arena, node.inputs[idx]);
3368                    let ob = node_offset(arena, node.inputs[idx + 1]);
3369                    idx += 2;
3370                    (qb, ob)
3371                } else {
3372                    (0, 0)
3373                };
3374                let (cos_off, sin_off, cl) = if *has_rope {
3375                    let c = node_offset(arena, node.inputs[idx]);
3376                    let s = node_offset(arena, node.inputs[idx + 1]);
3377                    let clen = get_len(graph, node.inputs[idx]);
3378                    (c, s, clen as u32)
3379                } else {
3380                    (0, 0, 0)
3381                };
3382
3383                Thunk::FusedAttnBlock {
3384                    hidden: node_offset(arena, node.inputs[0]),
3385                    qkv_w: node_offset(arena, node.inputs[1]),
3386                    out_w: node_offset(arena, node.inputs[2]),
3387                    mask: node_offset(arena, node.inputs[3]),
3388                    out: node_offset(arena, node.id),
3389                    qkv_b: qkv_b_off,
3390                    out_b: out_b_off,
3391                    cos: cos_off,
3392                    sin: sin_off,
3393                    cos_len: cl,
3394                    batch: batch as u32,
3395                    seq: seq as u32,
3396                    hs,
3397                    nh: *num_heads as u32,
3398                    dh: *head_dim as u32,
3399                    has_bias: *has_bias,
3400                    has_rope: *has_rope,
3401                }
3402            }
3403
3404            Op::Rope { head_dim, n_rot } => {
3405                let x_shape = &graph.node(node.inputs[0]).shape;
3406                let (batch, seq, hidden) = if x_shape.rank() >= 3 {
3407                    (
3408                        x_shape.dim(0).unwrap_static(),
3409                        x_shape.dim(1).unwrap_static(),
3410                        x_shape.dim(2).unwrap_static(),
3411                    )
3412                } else {
3413                    let total = x_shape.num_elements().unwrap();
3414                    (
3415                        1,
3416                        x_shape.dim(0).unwrap_static(),
3417                        total / x_shape.dim(0).unwrap_static(),
3418                    )
3419                };
3420                let cos_len = get_len(graph, node.inputs[1]);
3421                Thunk::Rope {
3422                    src: node_offset(arena, node.inputs[0]),
3423                    cos: node_offset(arena, node.inputs[1]),
3424                    sin: node_offset(arena, node.inputs[2]),
3425                    dst: node_offset(arena, node.id),
3426                    batch: batch as u32,
3427                    seq: seq as u32,
3428                    hidden: hidden as u32,
3429                    head_dim: *head_dim as u32,
3430                    n_rot: *n_rot as u32,
3431                    cos_len: cos_len as u32,
3432                    // Default: source rows are tightly packed (rewritten
3433                    // by the Narrow→Rope fusion pass below if Rope ends
3434                    // up reading from a wider parent like QKV).
3435                    src_row_stride: hidden as u32,
3436                }
3437            }
3438
3439            Op::FusedSwiGLU {
3440                cast_to: _,
3441                gate_first,
3442            } => {
3443                let n_half = node.shape.dim(node.shape.rank() - 1).unwrap_static();
3444                let total = node.shape.num_elements().unwrap();
3445                Thunk::FusedSwiGLU {
3446                    src: node_offset(arena, node.inputs[0]),
3447                    dst: node_offset(arena, node.id),
3448                    n_half: n_half as u32,
3449                    total: total as u32,
3450                    gate_first: *gate_first,
3451                }
3452            }
3453
3454            Op::Conv {
3455                kernel_size,
3456                stride,
3457                padding,
3458                dilation,
3459                groups,
3460            } => {
3461                let in_shape = &graph.node(node.inputs[0]).shape;
3462                let w_shape = &graph.node(node.inputs[1]).shape;
3463                let out_shape = &node.shape;
3464                // 1×1 fast path (plan #26): kH=kW=1, stride=1,
3465                // padding=0, dilation=1, groups=1. Emits a single
3466                // Conv2D1x1 thunk that BLAS-dispatches per batch.
3467                let is_1x1_simple = kernel_size.len() == 2
3468                    && kernel_size[0] == 1
3469                    && kernel_size[1] == 1
3470                    && stride.iter().all(|&s| s == 1)
3471                    && padding.iter().all(|&p| p == 0)
3472                    && dilation.iter().all(|&d| d == 1)
3473                    && *groups == 1;
3474                if is_1x1_simple && in_shape.rank() == 4 && out_shape.rank() == 4 {
3475                    let n = in_shape.dim(0).unwrap_static();
3476                    let c_in = in_shape.dim(1).unwrap_static();
3477                    let c_out = out_shape.dim(1).unwrap_static();
3478                    let h = in_shape.dim(2).unwrap_static();
3479                    let w = in_shape.dim(3).unwrap_static();
3480                    Thunk::Conv2D1x1 {
3481                        src: node_offset(arena, node.inputs[0]),
3482                        weight: node_offset(arena, node.inputs[1]),
3483                        dst: node_offset(arena, node.id),
3484                        n: n as u32,
3485                        c_in: c_in as u32,
3486                        c_out: c_out as u32,
3487                        hw: (h * w) as u32,
3488                    }
3489                } else if kernel_size.len() == 2
3490                    && in_shape.rank() == 4
3491                    && w_shape.rank() == 4
3492                    && out_shape.rank() == 4
3493                {
3494                    Thunk::Conv2D {
3495                        src: node_offset(arena, node.inputs[0]),
3496                        weight: node_offset(arena, node.inputs[1]),
3497                        dst: node_offset(arena, node.id),
3498                        n: in_shape.dim(0).unwrap_static() as u32,
3499                        c_in: in_shape.dim(1).unwrap_static() as u32,
3500                        h: in_shape.dim(2).unwrap_static() as u32,
3501                        w: in_shape.dim(3).unwrap_static() as u32,
3502                        c_out: out_shape.dim(1).unwrap_static() as u32,
3503                        h_out: out_shape.dim(2).unwrap_static() as u32,
3504                        w_out: out_shape.dim(3).unwrap_static() as u32,
3505                        kh: kernel_size[0] as u32,
3506                        kw: kernel_size[1] as u32,
3507                        sh: stride.first().copied().unwrap_or(1) as u32,
3508                        sw: stride.get(1).copied().unwrap_or(1) as u32,
3509                        ph: padding.first().copied().unwrap_or(0) as u32,
3510                        pw: padding.get(1).copied().unwrap_or(0) as u32,
3511                        dh: dilation.first().copied().unwrap_or(1) as u32,
3512                        dw: dilation.get(1).copied().unwrap_or(1) as u32,
3513                        groups: *groups as u32,
3514                    }
3515                } else {
3516                    Thunk::Nop
3517                }
3518            }
3519
3520            Op::Pool {
3521                kind,
3522                kernel_size,
3523                stride,
3524                padding,
3525            } => {
3526                // Currently support 2D pooling on rank-4 NCHW tensors.
3527                let in_shape = &graph.node(node.inputs[0]).shape;
3528                let out_shape = &node.shape;
3529                if kernel_size.len() == 2 && in_shape.rank() == 4 && out_shape.rank() == 4 {
3530                    Thunk::Pool2D {
3531                        src: node_offset(arena, node.inputs[0]),
3532                        dst: node_offset(arena, node.id),
3533                        n: in_shape.dim(0).unwrap_static() as u32,
3534                        c: in_shape.dim(1).unwrap_static() as u32,
3535                        h: in_shape.dim(2).unwrap_static() as u32,
3536                        w: in_shape.dim(3).unwrap_static() as u32,
3537                        h_out: out_shape.dim(2).unwrap_static() as u32,
3538                        w_out: out_shape.dim(3).unwrap_static() as u32,
3539                        kh: kernel_size[0] as u32,
3540                        kw: kernel_size[1] as u32,
3541                        sh: stride.first().copied().unwrap_or(1) as u32,
3542                        sw: stride.get(1).copied().unwrap_or(1) as u32,
3543                        ph: padding.first().copied().unwrap_or(0) as u32,
3544                        pw: padding.get(1).copied().unwrap_or(0) as u32,
3545                        kind: *kind,
3546                    }
3547                } else {
3548                    Thunk::Nop
3549                }
3550            }
3551
3552            Op::Transpose { perm } => {
3553                // Pre-compute (out_dims, in_strides_for_each_out_dim) so the
3554                // runtime loop is just an N-D index walk + scatter.
3555                let in_shape = &graph.node(node.inputs[0]).shape;
3556                let in_rank = in_shape.rank();
3557                let in_dims: Vec<usize> = (0..in_rank)
3558                    .map(|i| in_shape.dim(i).unwrap_static())
3559                    .collect();
3560                // Row-major input strides: stride[d] = product of dims[d+1..].
3561                let mut in_strides_full = vec![1usize; in_rank];
3562                for d in (0..in_rank.saturating_sub(1)).rev() {
3563                    in_strides_full[d] = in_strides_full[d + 1] * in_dims[d + 1];
3564                }
3565                let out_dims: Vec<u32> = perm.iter().map(|&p| in_dims[p] as u32).collect();
3566                let in_strides: Vec<u32> =
3567                    perm.iter().map(|&p| in_strides_full[p] as u32).collect();
3568                let in_total = in_dims.iter().product::<usize>() as u32;
3569                let src = node_offset(arena, node.inputs[0]);
3570                let dst = node_offset(arena, node.id);
3571                match node.shape.dtype() {
3572                    rlx_ir::DType::F64 => Thunk::TransposeF64 {
3573                        src,
3574                        dst,
3575                        in_total,
3576                        out_dims,
3577                        in_strides,
3578                    },
3579                    _ => Thunk::Transpose {
3580                        src,
3581                        dst,
3582                        in_total,
3583                        out_dims,
3584                        in_strides,
3585                    },
3586                }
3587            }
3588
3589            Op::ScatterAdd => {
3590                // updates: [num_updates, ...trailing], indices: [num_updates],
3591                // output: [out_dim, ...trailing]
3592                let upd_shape = &graph.node(node.inputs[0]).shape;
3593                let out_shape = &node.shape;
3594                let num_updates = upd_shape.dim(0).unwrap_static();
3595                let out_dim = out_shape.dim(0).unwrap_static();
3596                let trailing: usize = (1..out_shape.rank())
3597                    .map(|i| out_shape.dim(i).unwrap_static())
3598                    .product::<usize>()
3599                    .max(1);
3600                Thunk::ScatterAdd {
3601                    updates: node_offset(arena, node.inputs[0]),
3602                    indices: node_offset(arena, node.inputs[1]),
3603                    dst: node_offset(arena, node.id),
3604                    num_updates: num_updates as u32,
3605                    out_dim: out_dim as u32,
3606                    trailing: trailing as u32,
3607                }
3608            }
3609
3610            Op::GroupedMatMul => {
3611                // Inputs: [input(M, K), weight(E, K, N), expert_idx(M)]
3612                let in_shape = &graph.node(node.inputs[0]).shape;
3613                let w_shape = &graph.node(node.inputs[1]).shape;
3614                let m = in_shape.dim(in_shape.rank() - 2).unwrap_static();
3615                let k_dim = in_shape.dim(in_shape.rank() - 1).unwrap_static();
3616                let num_experts = w_shape.dim(0).unwrap_static();
3617                let n = w_shape.dim(2).unwrap_static();
3618                Thunk::GroupedMatMul {
3619                    input: node_offset(arena, node.inputs[0]),
3620                    weight: node_offset(arena, node.inputs[1]),
3621                    expert_idx: node_offset(arena, node.inputs[2]),
3622                    dst: node_offset(arena, node.id),
3623                    m: m as u32,
3624                    k_dim: k_dim as u32,
3625                    n: n as u32,
3626                    num_experts: num_experts as u32,
3627                }
3628            }
3629
3630            Op::DequantGroupedMatMul { scheme } => {
3631                let in_shape = &graph.node(node.inputs[0]).shape;
3632                let w_shape = &graph.node(node.inputs[1]).shape;
3633                let m = in_shape.dim(in_shape.rank() - 2).unwrap_static();
3634                let k_dim = in_shape.dim(in_shape.rank() - 1).unwrap_static();
3635                let out_shape = &node.shape;
3636                let n = out_shape.dim(out_shape.rank() - 1).unwrap_static();
3637                let block_elems = scheme.gguf_block_size() as usize;
3638                let block_bytes = scheme.gguf_block_bytes() as usize;
3639                let slab_bytes = (k_dim * n) / block_elems * block_bytes;
3640                let total_bytes = w_shape.num_elements().unwrap();
3641                let num_experts = total_bytes / slab_bytes.max(1);
3642                Thunk::DequantGroupedMatMulGguf {
3643                    input: node_offset(arena, node.inputs[0]),
3644                    w_q: node_offset(arena, node.inputs[1]),
3645                    expert_idx: node_offset(arena, node.inputs[2]),
3646                    dst: node_offset(arena, node.id),
3647                    m: m as u32,
3648                    k_dim: k_dim as u32,
3649                    n: n as u32,
3650                    num_experts: num_experts as u32,
3651                    scheme: *scheme,
3652                }
3653            }
3654
3655            Op::DequantMoEWeights { scheme } => {
3656                let w_shape = &graph.node(node.inputs[0]).shape;
3657                let out_shape = &node.shape;
3658                let num_experts = out_shape.dim(0).unwrap_static();
3659                let k_dim = out_shape.dim(1).unwrap_static();
3660                let n = out_shape.dim(2).unwrap_static();
3661                let block_elems = scheme.gguf_block_size() as usize;
3662                let block_bytes = scheme.gguf_block_bytes() as usize;
3663                let slab_bytes = (k_dim * n) / block_elems * block_bytes;
3664                let total_bytes = w_shape.num_elements().unwrap();
3665                assert_eq!(
3666                    total_bytes,
3667                    num_experts * slab_bytes,
3668                    "DequantMoEWeights packed bytes mismatch"
3669                );
3670                Thunk::DequantMoEWeightsGguf {
3671                    w_q: node_offset(arena, node.inputs[0]),
3672                    dst: node_offset(arena, node.id),
3673                    k_dim: k_dim as u32,
3674                    n: n as u32,
3675                    num_experts: num_experts as u32,
3676                    scheme: *scheme,
3677                }
3678            }
3679
3680            Op::TopK { k } => {
3681                let in_shape = &graph.node(node.inputs[0]).shape;
3682                let rank = in_shape.rank();
3683                let axis_dim = in_shape.dim(rank - 1).unwrap_static();
3684                let outer = in_shape.num_elements().unwrap() / axis_dim;
3685                Thunk::TopK {
3686                    src: node_offset(arena, node.inputs[0]),
3687                    dst: node_offset(arena, node.id),
3688                    outer: outer as u32,
3689                    axis_dim: axis_dim as u32,
3690                    k: *k as u32,
3691                }
3692            }
3693
3694            Op::Reduce {
3695                op,
3696                axes,
3697                keep_dim: _,
3698            } => {
3699                // Decompose the input shape into [outer, reduced, inner]
3700                // around the reduced axis range. Non-contiguous reduced
3701                // axes aren't supported here — caller must transpose them
3702                // contiguous first (the coverage tool would surface the
3703                // gap if a model needs it).
3704                let in_shape = &graph.node(node.inputs[0]).shape;
3705                let rank = in_shape.rank();
3706                let mut sorted = axes.clone();
3707                sorted.sort();
3708                sorted.dedup();
3709                let contiguous = sorted.windows(2).all(|w| w[1] == w[0] + 1)
3710                    && !sorted.is_empty()
3711                    && *sorted.last().unwrap() < rank;
3712                if !contiguous {
3713                    Thunk::Nop
3714                } else {
3715                    let first = sorted[0];
3716                    let last = *sorted.last().unwrap();
3717                    let outer: usize = (0..first)
3718                        .map(|i| in_shape.dim(i).unwrap_static())
3719                        .product::<usize>()
3720                        .max(1);
3721                    let reduced: usize = (first..=last)
3722                        .map(|i| in_shape.dim(i).unwrap_static())
3723                        .product();
3724                    let inner: usize = (last + 1..rank)
3725                        .map(|i| in_shape.dim(i).unwrap_static())
3726                        .product::<usize>()
3727                        .max(1);
3728                    let src = node_offset(arena, node.inputs[0]);
3729                    let dst = node_offset(arena, node.id);
3730                    if node.shape.dtype() == rlx_ir::DType::F64 && matches!(op, ReduceOp::Sum) {
3731                        Thunk::ReduceSumF64 {
3732                            src,
3733                            dst,
3734                            outer: outer as u32,
3735                            reduced: reduced as u32,
3736                            inner: inner as u32,
3737                        }
3738                    } else {
3739                        Thunk::Reduce {
3740                            src,
3741                            dst,
3742                            outer: outer as u32,
3743                            reduced: reduced as u32,
3744                            inner: inner as u32,
3745                            op: *op,
3746                        }
3747                    }
3748                }
3749            }
3750
3751            Op::Compare(cmp) => {
3752                let len = node.shape.num_elements().unwrap();
3753                Thunk::Compare {
3754                    lhs: node_offset(arena, node.inputs[0]),
3755                    rhs: node_offset(arena, node.inputs[1]),
3756                    dst: node_offset(arena, node.id),
3757                    len: len as u32,
3758                    op: *cmp,
3759                }
3760            }
3761
3762            Op::Where => {
3763                let len = node.shape.num_elements().unwrap();
3764                Thunk::Where {
3765                    cond: node_offset(arena, node.inputs[0]),
3766                    on_true: node_offset(arena, node.inputs[1]),
3767                    on_false: node_offset(arena, node.inputs[2]),
3768                    dst: node_offset(arena, node.id),
3769                    len: len as u32,
3770                }
3771            }
3772
3773            Op::ReluBackward => {
3774                let len: usize = (0..node.shape.rank())
3775                    .map(|i| node.shape.dim(i).unwrap_static())
3776                    .product();
3777                let x = node_offset(arena, node.inputs[0]);
3778                let dy = node_offset(arena, node.inputs[1]);
3779                let dx = node_offset(arena, node.id);
3780                match node.shape.dtype() {
3781                    rlx_ir::DType::F64 => Thunk::ReluBackwardF64 {
3782                        x,
3783                        dy,
3784                        dx,
3785                        len: len as u32,
3786                    },
3787                    _ => Thunk::ReluBackward {
3788                        x,
3789                        dy,
3790                        dx,
3791                        len: len as u32,
3792                    },
3793                }
3794            }
3795
3796            Op::ComplexNormSq => {
3797                let len: usize = (0..node.shape.rank())
3798                    .map(|i| node.shape.dim(i).unwrap_static())
3799                    .product();
3800                let src = node_offset(arena, node.inputs[0]);
3801                let dst = node_offset(arena, node.id);
3802                Thunk::ComplexNormSqF32 {
3803                    src,
3804                    dst,
3805                    len: len as u32,
3806                }
3807            }
3808
3809            Op::ComplexNormSqBackward => {
3810                let len: usize = (0..node.shape.rank())
3811                    .map(|i| node.shape.dim(i).unwrap_static())
3812                    .product();
3813                let z = node_offset(arena, node.inputs[0]);
3814                let g = node_offset(arena, node.inputs[1]);
3815                let dz = node_offset(arena, node.id);
3816                Thunk::ComplexNormSqBackwardF32 {
3817                    z,
3818                    g,
3819                    dz,
3820                    len: len as u32,
3821                }
3822            }
3823
3824            Op::Conjugate => {
3825                let len: usize = (0..node.shape.rank())
3826                    .map(|i| node.shape.dim(i).unwrap_static())
3827                    .product();
3828                Thunk::ConjugateC64 {
3829                    src: node_offset(arena, node.inputs[0]),
3830                    dst: node_offset(arena, node.id),
3831                    len: len as u32,
3832                }
3833            }
3834
3835            Op::ActivationBackward { kind } => {
3836                let len: usize = (0..node.shape.rank())
3837                    .map(|i| node.shape.dim(i).unwrap_static())
3838                    .product();
3839                let x = node_offset(arena, node.inputs[0]);
3840                let dy = node_offset(arena, node.inputs[1]);
3841                let dx = node_offset(arena, node.id);
3842                match node.shape.dtype() {
3843                    rlx_ir::DType::F64 => Thunk::ActivationBackwardF64 {
3844                        x,
3845                        dy,
3846                        dx,
3847                        len: len as u32,
3848                        kind: *kind,
3849                    },
3850                    _ => Thunk::ActivationBackward {
3851                        x,
3852                        dy,
3853                        dx,
3854                        len: len as u32,
3855                        kind: *kind,
3856                    },
3857                }
3858            }
3859
3860            Op::LayerNormBackwardInput { eps, .. } => {
3861                // axis = -1 only (matches forward LayerNorm thunk).
3862                let h = node.shape.dim(node.shape.rank() - 1).unwrap_static();
3863                let total = node.shape.num_elements().unwrap();
3864                Thunk::LayerNormBackwardInput {
3865                    x: node_offset(arena, node.inputs[0]),
3866                    gamma: node_offset(arena, node.inputs[1]),
3867                    dy: node_offset(arena, node.inputs[2]),
3868                    dx: node_offset(arena, node.id),
3869                    rows: (total / h) as u32,
3870                    h: h as u32,
3871                    eps: *eps,
3872                }
3873            }
3874
3875            Op::LayerNormBackwardGamma { eps, .. } => {
3876                let x_shape = &graph.node(node.inputs[0]).shape;
3877                let h = x_shape.dim(x_shape.rank() - 1).unwrap_static();
3878                let x_total = x_shape.num_elements().unwrap();
3879                Thunk::LayerNormBackwardGamma {
3880                    x: node_offset(arena, node.inputs[0]),
3881                    dy: node_offset(arena, node.inputs[1]),
3882                    dgamma: node_offset(arena, node.id),
3883                    rows: (x_total / h) as u32,
3884                    h: h as u32,
3885                    eps: *eps,
3886                }
3887            }
3888
3889            Op::RmsNormBackwardInput { eps, .. }
3890            | Op::RmsNormBackwardGamma { eps, .. }
3891            | Op::RmsNormBackwardBeta { eps, .. } => {
3892                let x_shape = &graph.node(node.inputs[0]).shape;
3893                let h = x_shape.dim(x_shape.rank() - 1).unwrap_static();
3894                let rows = (x_shape.num_elements().unwrap() / h) as u32;
3895                let off = |i: usize| node_offset(arena, node.inputs[i]);
3896                let common = (off(0), off(1), off(2), off(3), rows, h as u32, *eps);
3897                match &node.op {
3898                    Op::RmsNormBackwardInput { .. } => Thunk::RmsNormBackwardInput {
3899                        x: common.0,
3900                        gamma: common.1,
3901                        beta: common.2,
3902                        dy: common.3,
3903                        dx: node_offset(arena, node.id),
3904                        rows: common.4,
3905                        h: common.5,
3906                        eps: common.6,
3907                    },
3908                    Op::RmsNormBackwardGamma { .. } => Thunk::RmsNormBackwardGamma {
3909                        x: common.0,
3910                        gamma: common.1,
3911                        beta: common.2,
3912                        dy: common.3,
3913                        dgamma: node_offset(arena, node.id),
3914                        rows: common.4,
3915                        h: common.5,
3916                        eps: common.6,
3917                    },
3918                    Op::RmsNormBackwardBeta { .. } => Thunk::RmsNormBackwardBeta {
3919                        x: common.0,
3920                        gamma: common.1,
3921                        beta: common.2,
3922                        dy: common.3,
3923                        dbeta: node_offset(arena, node.id),
3924                        rows: common.4,
3925                        h: common.5,
3926                        eps: common.6,
3927                    },
3928                    _ => unreachable!(),
3929                }
3930            }
3931
3932            Op::RopeBackward { head_dim, n_rot } => {
3933                let dy_shape = &graph.node(node.inputs[0]).shape;
3934                let (batch, seq, hidden) = if dy_shape.rank() >= 3 {
3935                    (
3936                        dy_shape.dim(0).unwrap_static(),
3937                        dy_shape.dim(1).unwrap_static(),
3938                        dy_shape.dim(2).unwrap_static(),
3939                    )
3940                } else {
3941                    (
3942                        1,
3943                        dy_shape.dim(0).unwrap_static(),
3944                        dy_shape.dim(1).unwrap_static(),
3945                    )
3946                };
3947                let cos_shape = &graph.node(node.inputs[1]).shape;
3948                let cos_len = cos_shape.num_elements().unwrap();
3949                Thunk::RopeBackward {
3950                    dy: node_offset(arena, node.inputs[0]),
3951                    cos: node_offset(arena, node.inputs[1]),
3952                    sin: node_offset(arena, node.inputs[2]),
3953                    dx: node_offset(arena, node.id),
3954                    batch: batch as u32,
3955                    seq: seq as u32,
3956                    hidden: hidden as u32,
3957                    head_dim: *head_dim as u32,
3958                    n_rot: *n_rot as u32,
3959                    cos_len: cos_len as u32,
3960                }
3961            }
3962
3963            Op::CumsumBackward { exclusive, .. } => {
3964                let dy_shape = &graph.node(node.inputs[0]).shape;
3965                let rank = dy_shape.rank();
3966                let cols = dy_shape.dim(rank - 1).unwrap_static();
3967                let rows = dy_shape.num_elements().unwrap() / cols;
3968                Thunk::CumsumBackward {
3969                    dy: node_offset(arena, node.inputs[0]),
3970                    dx: node_offset(arena, node.id),
3971                    rows: rows as u32,
3972                    cols: cols as u32,
3973                    exclusive: *exclusive,
3974                }
3975            }
3976
3977            Op::GatherBackward { .. } => {
3978                let dy_shape = &graph.node(node.inputs[0]).shape;
3979                let idx_shape = &graph.node(node.inputs[1]).shape;
3980                let out_shape = &node.shape;
3981                let rank = out_shape.rank();
3982                let axis = match &node.op {
3983                    Op::GatherBackward { axis } => *axis,
3984                    _ => 0,
3985                };
3986                let axis_u = if axis < 0 {
3987                    (rank as i32 + axis) as usize
3988                } else {
3989                    axis as usize
3990                };
3991                let outer: usize = (0..axis_u)
3992                    .map(|i| dy_shape.dim(i).unwrap_static())
3993                    .product::<usize>()
3994                    .max(1);
3995                let num_idx = idx_shape.dim(axis_u).unwrap_static();
3996                let trailing: usize = (axis_u + 1..dy_shape.rank())
3997                    .map(|i| dy_shape.dim(i).unwrap_static())
3998                    .product::<usize>()
3999                    .max(1);
4000                let axis_dim = out_shape.dim(axis_u).unwrap_static();
4001                Thunk::GatherBackward {
4002                    dy: node_offset(arena, node.inputs[0]),
4003                    indices: node_offset(arena, node.inputs[1]),
4004                    dst: node_offset(arena, node.id),
4005                    outer: outer as u32,
4006                    axis_dim: axis_dim as u32,
4007                    num_idx: num_idx as u32,
4008                    trailing: trailing as u32,
4009                }
4010            }
4011
4012            Op::GroupNormBackwardInput { num_groups, eps }
4013            | Op::GroupNormBackwardGamma { num_groups, eps }
4014            | Op::GroupNormBackwardBeta { num_groups, eps } => {
4015                let x_shape = &graph.node(node.inputs[0]).shape;
4016                let n = x_shape.dim(0).unwrap_static() as u32;
4017                let c = x_shape.dim(1).unwrap_static() as u32;
4018                let h = x_shape.dim(2).unwrap_static() as u32;
4019                let w = x_shape.dim(3).unwrap_static() as u32;
4020                match &node.op {
4021                    Op::GroupNormBackwardInput { .. } => Thunk::GroupNormBackwardInput {
4022                        x: node_offset(arena, node.inputs[0]),
4023                        gamma: node_offset(arena, node.inputs[1]),
4024                        beta: node_offset(arena, node.inputs[2]),
4025                        dy: node_offset(arena, node.inputs[3]),
4026                        dx: node_offset(arena, node.id),
4027                        n,
4028                        c,
4029                        h,
4030                        w,
4031                        num_groups: *num_groups as u32,
4032                        eps: *eps,
4033                    },
4034                    Op::GroupNormBackwardGamma { .. } => Thunk::GroupNormBackwardGamma {
4035                        x: node_offset(arena, node.inputs[0]),
4036                        dy: node_offset(arena, node.inputs[1]),
4037                        dgamma: node_offset(arena, node.id),
4038                        n,
4039                        c,
4040                        h,
4041                        w,
4042                        num_groups: *num_groups as u32,
4043                        eps: *eps,
4044                    },
4045                    Op::GroupNormBackwardBeta { .. } => Thunk::GroupNormBackwardBeta {
4046                        dy: node_offset(arena, node.inputs[1]),
4047                        dbeta: node_offset(arena, node.id),
4048                        n,
4049                        c,
4050                        h,
4051                        w,
4052                    },
4053                    _ => unreachable!(),
4054                }
4055            }
4056
4057            Op::MaxPool2dBackward {
4058                kernel_size,
4059                stride,
4060                padding,
4061            } => {
4062                let x_shape = &graph.node(node.inputs[0]).shape;
4063                let dy_shape = &graph.node(node.inputs[1]).shape;
4064                if kernel_size.len() == 2 && x_shape.rank() == 4 && dy_shape.rank() == 4 {
4065                    Thunk::MaxPool2dBackward {
4066                        x: node_offset(arena, node.inputs[0]),
4067                        dy: node_offset(arena, node.inputs[1]),
4068                        dx: node_offset(arena, node.id),
4069                        n: x_shape.dim(0).unwrap_static() as u32,
4070                        c: x_shape.dim(1).unwrap_static() as u32,
4071                        h: x_shape.dim(2).unwrap_static() as u32,
4072                        w: x_shape.dim(3).unwrap_static() as u32,
4073                        h_out: dy_shape.dim(2).unwrap_static() as u32,
4074                        w_out: dy_shape.dim(3).unwrap_static() as u32,
4075                        kh: kernel_size[0] as u32,
4076                        kw: kernel_size[1] as u32,
4077                        sh: stride.first().copied().unwrap_or(1) as u32,
4078                        sw: stride.get(1).copied().unwrap_or(1) as u32,
4079                        ph: padding.first().copied().unwrap_or(0) as u32,
4080                        pw: padding.get(1).copied().unwrap_or(0) as u32,
4081                    }
4082                } else {
4083                    Thunk::Nop
4084                }
4085            }
4086
4087            Op::Conv2dBackwardInput {
4088                kernel_size,
4089                stride,
4090                padding,
4091                dilation,
4092                groups,
4093            } => {
4094                let dy_shape = &graph.node(node.inputs[0]).shape;
4095                let w_shape = &graph.node(node.inputs[1]).shape;
4096                let out_shape = &node.shape;
4097                if kernel_size.len() == 2
4098                    && dy_shape.rank() == 4
4099                    && w_shape.rank() == 4
4100                    && out_shape.rank() == 4
4101                {
4102                    Thunk::Conv2dBackwardInput {
4103                        dy: node_offset(arena, node.inputs[0]),
4104                        w: node_offset(arena, node.inputs[1]),
4105                        dx: node_offset(arena, node.id),
4106                        n: out_shape.dim(0).unwrap_static() as u32,
4107                        c_in: out_shape.dim(1).unwrap_static() as u32,
4108                        h: out_shape.dim(2).unwrap_static() as u32,
4109                        w_in: out_shape.dim(3).unwrap_static() as u32,
4110                        c_out: dy_shape.dim(1).unwrap_static() as u32,
4111                        h_out: dy_shape.dim(2).unwrap_static() as u32,
4112                        w_out: dy_shape.dim(3).unwrap_static() as u32,
4113                        kh: kernel_size[0] as u32,
4114                        kw: kernel_size[1] as u32,
4115                        sh: stride.first().copied().unwrap_or(1) as u32,
4116                        sw: stride.get(1).copied().unwrap_or(1) as u32,
4117                        ph: padding.first().copied().unwrap_or(0) as u32,
4118                        pw: padding.get(1).copied().unwrap_or(0) as u32,
4119                        dh: dilation.first().copied().unwrap_or(1) as u32,
4120                        dw: dilation.get(1).copied().unwrap_or(1) as u32,
4121                        groups: *groups as u32,
4122                    }
4123                } else {
4124                    Thunk::Nop
4125                }
4126            }
4127
4128            Op::Conv2dBackwardWeight {
4129                kernel_size,
4130                stride,
4131                padding,
4132                dilation,
4133                groups,
4134            } => {
4135                let x_shape = &graph.node(node.inputs[0]).shape;
4136                let dy_shape = &graph.node(node.inputs[1]).shape;
4137                let dw_shape = &node.shape;
4138                if kernel_size.len() == 2
4139                    && x_shape.rank() == 4
4140                    && dy_shape.rank() == 4
4141                    && dw_shape.rank() == 4
4142                {
4143                    Thunk::Conv2dBackwardWeight {
4144                        x: node_offset(arena, node.inputs[0]),
4145                        dy: node_offset(arena, node.inputs[1]),
4146                        dw: node_offset(arena, node.id),
4147                        n: x_shape.dim(0).unwrap_static() as u32,
4148                        c_in: x_shape.dim(1).unwrap_static() as u32,
4149                        h: x_shape.dim(2).unwrap_static() as u32,
4150                        w: x_shape.dim(3).unwrap_static() as u32,
4151                        c_out: dy_shape.dim(1).unwrap_static() as u32,
4152                        h_out: dy_shape.dim(2).unwrap_static() as u32,
4153                        w_out: dy_shape.dim(3).unwrap_static() as u32,
4154                        kh: kernel_size[0] as u32,
4155                        kw: kernel_size[1] as u32,
4156                        sh: stride.first().copied().unwrap_or(1) as u32,
4157                        sw: stride.get(1).copied().unwrap_or(1) as u32,
4158                        ph: padding.first().copied().unwrap_or(0) as u32,
4159                        pw: padding.get(1).copied().unwrap_or(0) as u32,
4160                        dh: dilation.first().copied().unwrap_or(1) as u32,
4161                        dw_dil: dilation.get(1).copied().unwrap_or(1) as u32,
4162                        groups: *groups as u32,
4163                    }
4164                } else {
4165                    Thunk::Nop
4166                }
4167            }
4168
4169            Op::SoftmaxCrossEntropyWithLogits => {
4170                let logits_shape = &graph.node(node.inputs[0]).shape;
4171                if logits_shape.rank() == 2 {
4172                    Thunk::SoftmaxCrossEntropy {
4173                        logits: node_offset(arena, node.inputs[0]),
4174                        labels: node_offset(arena, node.inputs[1]),
4175                        dst: node_offset(arena, node.id),
4176                        n: logits_shape.dim(0).unwrap_static() as u32,
4177                        c: logits_shape.dim(1).unwrap_static() as u32,
4178                    }
4179                } else {
4180                    Thunk::Nop
4181                }
4182            }
4183
4184            Op::SoftmaxCrossEntropyBackward => {
4185                let logits_shape = &graph.node(node.inputs[0]).shape;
4186                if logits_shape.rank() == 2 {
4187                    Thunk::SoftmaxCrossEntropyBackward {
4188                        logits: node_offset(arena, node.inputs[0]),
4189                        labels: node_offset(arena, node.inputs[1]),
4190                        d_loss: node_offset(arena, node.inputs[2]),
4191                        dlogits: node_offset(arena, node.id),
4192                        n: logits_shape.dim(0).unwrap_static() as u32,
4193                        c: logits_shape.dim(1).unwrap_static() as u32,
4194                    }
4195                } else {
4196                    Thunk::Nop
4197                }
4198            }
4199
4200            Op::DenseSolve => {
4201                // A: [n, n], b: [n] or [n, nrhs]. Output matches b.
4202                let a_shape = &graph.node(node.inputs[0]).shape;
4203                let n = a_shape.dim(0).unwrap_static();
4204                debug_assert_eq!(
4205                    n,
4206                    a_shape.dim(1).unwrap_static(),
4207                    "DenseSolve: A must be square"
4208                );
4209                let b_elems = node.shape.num_elements().unwrap();
4210                let nrhs = b_elems / n;
4211                match node.shape.dtype() {
4212                    rlx_ir::DType::F64 => Thunk::DenseSolveF64 {
4213                        a: node_offset(arena, node.inputs[0]),
4214                        b: node_offset(arena, node.inputs[1]),
4215                        x: node_offset(arena, node.id),
4216                        n: n as u32,
4217                        nrhs: nrhs as u32,
4218                    },
4219                    rlx_ir::DType::F32 => Thunk::DenseSolveF32 {
4220                        a: node_offset(arena, node.inputs[0]),
4221                        b: node_offset(arena, node.inputs[1]),
4222                        x: node_offset(arena, node.id),
4223                        n: n as u32,
4224                        nrhs: nrhs as u32,
4225                    },
4226                    other => panic!(
4227                        "DenseSolve: F32 + F64 lowered; got {other:?}. \
4228                         Add another variant when needed."
4229                    ),
4230                }
4231            }
4232
4233            Op::BatchedDenseSolve => {
4234                // A: [B, N, N], b: [B, N] or [B, N, K]. Output matches b.
4235                let a_shape = &graph.node(node.inputs[0]).shape;
4236                assert_eq!(a_shape.rank(), 3, "BatchedDenseSolve: A rank must be 3");
4237                let batch = a_shape.dim(0).unwrap_static();
4238                let n = a_shape.dim(1).unwrap_static();
4239                debug_assert_eq!(
4240                    n,
4241                    a_shape.dim(2).unwrap_static(),
4242                    "BatchedDenseSolve: A's last two dims must match"
4243                );
4244                let total = node.shape.num_elements().unwrap();
4245                let nrhs = total / (batch * n);
4246                match node.shape.dtype() {
4247                    rlx_ir::DType::F32 => Thunk::BatchedDenseSolveF32 {
4248                        a: node_offset(arena, node.inputs[0]),
4249                        b: node_offset(arena, node.inputs[1]),
4250                        x: node_offset(arena, node.id),
4251                        batch: batch as u32,
4252                        n: n as u32,
4253                        nrhs: nrhs as u32,
4254                    },
4255                    rlx_ir::DType::F64 => Thunk::BatchedDenseSolveF64 {
4256                        a: node_offset(arena, node.inputs[0]),
4257                        b: node_offset(arena, node.inputs[1]),
4258                        x: node_offset(arena, node.id),
4259                        batch: batch as u32,
4260                        n: n as u32,
4261                        nrhs: nrhs as u32,
4262                    },
4263                    other => panic!("BatchedDenseSolve: F32 + F64 only, got {other:?}"),
4264                }
4265            }
4266
4267            Op::Scan {
4268                body,
4269                length,
4270                save_trajectory,
4271                num_bcast,
4272                num_xs,
4273                num_checkpoints,
4274            } => {
4275                assert!(
4276                    *num_checkpoints == 0 || *num_checkpoints <= *length,
4277                    "Op::Scan: num_checkpoints={} must be 0 or ≤ length={}",
4278                    *num_checkpoints,
4279                    *length
4280                );
4281                if *num_checkpoints != 0 && *num_checkpoints != *length {
4282                    assert!(
4283                        *save_trajectory,
4284                        "Op::Scan: num_checkpoints<length only meaningful when save_trajectory=true"
4285                    );
4286                }
4287                // Plan + compile the body sub-graph standalone. The body
4288                // gets its own Arena; per execution we clone its
4289                // pristine bytes, copy the outer carry (and per-step xs
4290                // slices, if any) into the body's Input slots, run the
4291                // body schedule N times, then copy the body's output
4292                // back to the outer arena.
4293                //
4294                // Body invariants: 1 + num_xs Op::Inputs in NodeId order
4295                // — first declared is the carry, rest are x_t_i. Single
4296                // graph output (the next carry), same shape as carry.
4297                let body_plan = rlx_opt::memory::plan_memory(body);
4298                let _body_arena_size = body_plan.arena_size;
4299                // Snapshot per-input byte offsets before plan_memory
4300                // moves into the Arena below.
4301                let body_offsets: HashMap<NodeId, usize> = body_plan
4302                    .assignments
4303                    .iter()
4304                    .map(|(id, slot)| (*id, slot.offset))
4305                    .collect();
4306
4307                // Collect body Input nodes in NodeId order; first is
4308                // carry, rest are per-step xs in matching order.
4309                let mut body_inputs: Vec<NodeId> = body
4310                    .nodes()
4311                    .iter()
4312                    .filter(|n| matches!(n.op, Op::Input { .. }))
4313                    .map(|n| n.id)
4314                    .collect();
4315                body_inputs.sort();
4316                let n_body_inputs = body_inputs.len();
4317                let expected = 1 + *num_bcast as usize + *num_xs as usize;
4318                if n_body_inputs != expected {
4319                    let names: Vec<String> = body
4320                        .nodes()
4321                        .iter()
4322                        .filter_map(|n| match &n.op {
4323                            Op::Input { name } => Some(format!("{}={}", n.id, name)),
4324                            _ => None,
4325                        })
4326                        .collect();
4327                    panic!(
4328                        "Op::Scan body has {} Op::Input nodes; expected {} \
4329                            (1 carry + {} bcast + {} xs). Inputs by NodeId: [{}]",
4330                        n_body_inputs,
4331                        expected,
4332                        *num_bcast,
4333                        *num_xs,
4334                        names.join(", ")
4335                    );
4336                }
4337
4338                let body_input_id = body_inputs[0];
4339                let body_input_off = body_offsets[&body_input_id];
4340                let body_output_id = body
4341                    .outputs
4342                    .first()
4343                    .copied()
4344                    .expect("Op::Scan body must declare one output");
4345                let body_output_off = body_offsets[&body_output_id];
4346
4347                let mut body_arena = crate::arena::Arena::from_plan(body_plan);
4348                // Fill body Constant nodes — mirror the outer-graph logic
4349                // in rlx-runtime/src/backend.rs (dtype-aware).
4350                for n in body.nodes() {
4351                    if let Op::Constant { data } = &n.op
4352                        && body_arena.has_buffer(n.id)
4353                        && !data.is_empty()
4354                    {
4355                        match n.shape.dtype() {
4356                            rlx_ir::DType::F64 => {
4357                                let off = body_arena.byte_offset(n.id);
4358                                let buf = body_arena.raw_buf_mut();
4359                                let nbytes = (buf.len() - off).min(data.len());
4360                                buf[off..off + nbytes].copy_from_slice(&data[..nbytes]);
4361                            }
4362                            _ => {
4363                                let buf = body_arena.slice_mut(n.id);
4364                                let n_floats = data.len() / 4;
4365                                let n_lim = buf.len().min(n_floats);
4366                                for i in 0..n_lim {
4367                                    let bytes = [
4368                                        data[i * 4],
4369                                        data[i * 4 + 1],
4370                                        data[i * 4 + 2],
4371                                        data[i * 4 + 3],
4372                                    ];
4373                                    buf[i] = f32::from_le_bytes(bytes);
4374                                }
4375                            }
4376                        }
4377                    }
4378                }
4379                let body_init = body_arena.raw_buf().to_vec();
4380                let body_schedule = compile_thunks(body, &body_arena);
4381
4382                // Carry bytes — for trajectory mode, the outer node's
4383                // shape is [length, *carry_shape], so dividing by length
4384                // gives one row's bytes; the body's input slot still
4385                // holds carry_shape bytes.
4386                let carry_bytes = if *save_trajectory {
4387                    let total = node
4388                        .shape
4389                        .size_bytes()
4390                        .expect("Op::Scan trajectory output must have static shape");
4391                    total / *length as usize
4392                } else {
4393                    node.shape
4394                        .size_bytes()
4395                        .expect("Op::Scan carry must have static shape")
4396                };
4397
4398                // Bcast inputs occupy body_inputs[1..1+num_bcast] and
4399                // outer node.inputs[1..1+num_bcast]. They keep their
4400                // natural shape (no [length, ...] prefix) and are
4401                // copied into body_buf ONCE before the scan loop.
4402                let mut bcast_inputs: Vec<(usize, usize, u32)> =
4403                    Vec::with_capacity(*num_bcast as usize);
4404                for i in 0..*num_bcast as usize {
4405                    let body_b_id = body_inputs[1 + i];
4406                    let body_b_off = body_offsets[&body_b_id];
4407                    let outer_b_id = node.inputs[1 + i];
4408                    let outer_b_off = node_offset(arena, outer_b_id);
4409                    let outer_b_shape = &graph.node(outer_b_id).shape;
4410                    let total = outer_b_shape
4411                        .size_bytes()
4412                        .expect("Op::Scan bcast must have static shape");
4413                    bcast_inputs.push((body_b_off, outer_b_off, total as u32));
4414                }
4415
4416                // xs occupy body_inputs[1+num_bcast..] and node.inputs
4417                // [1+num_bcast..]. Each has shape [length, *per_step];
4418                // per-step bytes = total / length.
4419                let mut xs_inputs: Vec<(usize, usize, u32)> = Vec::with_capacity(*num_xs as usize);
4420                let xs_base = 1 + *num_bcast as usize;
4421                for i in 0..*num_xs as usize {
4422                    let body_x_id = body_inputs[xs_base + i];
4423                    let body_x_off = body_offsets[&body_x_id];
4424                    let outer_xs_id = node.inputs[xs_base + i];
4425                    let outer_xs_off = node_offset(arena, outer_xs_id);
4426                    let outer_xs_shape = &graph.node(outer_xs_id).shape;
4427                    let total = outer_xs_shape
4428                        .size_bytes()
4429                        .expect("Op::Scan xs must have static shape");
4430                    let per_step = total / *length as usize;
4431                    xs_inputs.push((body_x_off, outer_xs_off, per_step as u32));
4432                }
4433
4434                Thunk::Scan {
4435                    body: Arc::new(body_schedule),
4436                    body_init: Arc::new(body_init),
4437                    body_input_off,
4438                    body_output_off,
4439                    outer_init_off: node_offset(arena, node.inputs[0]),
4440                    outer_final_off: node_offset(arena, node.id),
4441                    length: *length,
4442                    carry_bytes: carry_bytes as u32,
4443                    save_trajectory: *save_trajectory,
4444                    xs_inputs: Arc::new(xs_inputs),
4445                    bcast_inputs: Arc::new(bcast_inputs),
4446                    num_checkpoints: *num_checkpoints,
4447                }
4448            }
4449
4450            Op::ScanBackward {
4451                body_vjp,
4452                length,
4453                save_trajectory,
4454                num_xs,
4455                num_checkpoints,
4456                forward_body,
4457            } => {
4458                let is_recursive = *num_checkpoints != 0 && *num_checkpoints != *length;
4459                if is_recursive {
4460                    assert!(
4461                        forward_body.is_some(),
4462                        "Op::ScanBackward with num_checkpoints<length requires forward_body"
4463                    );
4464                }
4465                // body_vjp has signature
4466                //   (carry, x_t_0, ..., x_t_{num_xs-1}, d_output) → dcarry
4467                // Identify slots:
4468                //   * "d_output" by exact name (AD-introduced seed Input).
4469                //   * Remaining Inputs sorted by NodeId — first is the
4470                //     carry mirror, rest are x_t_i mirrors in body's
4471                //     original Op::Input declaration order.
4472                let body_plan = rlx_opt::memory::plan_memory(body_vjp);
4473                let body_offsets: HashMap<NodeId, usize> = body_plan
4474                    .assignments
4475                    .iter()
4476                    .map(|(id, slot)| (*id, slot.offset))
4477                    .collect();
4478                let mut body_d_output_off: Option<usize> = None;
4479                let mut body_other_inputs: Vec<(NodeId, usize)> = Vec::new();
4480                for n in body_vjp.nodes() {
4481                    if let Op::Input { name } = &n.op {
4482                        let off = body_offsets[&n.id];
4483                        if name == "d_output" {
4484                            body_d_output_off = Some(off);
4485                        } else {
4486                            body_other_inputs.push((n.id, off));
4487                        }
4488                    }
4489                }
4490                body_other_inputs.sort_by_key(|(id, _)| *id);
4491                let body_d_output_off =
4492                    body_d_output_off.expect("ScanBackward body_vjp missing 'd_output' Input");
4493                let expected_others = 1 + *num_xs as usize;
4494                assert_eq!(
4495                    body_other_inputs.len(),
4496                    expected_others,
4497                    "ScanBackward body_vjp has {} non-d_output Inputs; \
4498                     expected {} (1 carry + {} xs)",
4499                    body_other_inputs.len(),
4500                    expected_others,
4501                    num_xs
4502                );
4503                let body_carry_in_off = body_other_inputs[0].1;
4504                let body_x_offs: Vec<usize> = body_other_inputs
4505                    .iter()
4506                    .skip(1)
4507                    .map(|(_, off)| *off)
4508                    .collect();
4509                let body_dcarry_out_off = body_offsets[&body_vjp.outputs[0]];
4510
4511                let mut body_arena = crate::arena::Arena::from_plan(body_plan);
4512                // Fill body_vjp's Constants (mirrors the Scan lowering).
4513                for n in body_vjp.nodes() {
4514                    if let Op::Constant { data } = &n.op
4515                        && body_arena.has_buffer(n.id)
4516                        && !data.is_empty()
4517                    {
4518                        match n.shape.dtype() {
4519                            rlx_ir::DType::F64 => {
4520                                let off = body_arena.byte_offset(n.id);
4521                                let buf = body_arena.raw_buf_mut();
4522                                let nb = (buf.len() - off).min(data.len());
4523                                buf[off..off + nb].copy_from_slice(&data[..nb]);
4524                            }
4525                            _ => {
4526                                let buf = body_arena.slice_mut(n.id);
4527                                let nf = data.len() / 4;
4528                                let nl = buf.len().min(nf);
4529                                for i in 0..nl {
4530                                    let bytes = [
4531                                        data[i * 4],
4532                                        data[i * 4 + 1],
4533                                        data[i * 4 + 2],
4534                                        data[i * 4 + 3],
4535                                    ];
4536                                    buf[i] = f32::from_le_bytes(bytes);
4537                                }
4538                            }
4539                        }
4540                    }
4541                }
4542                let body_init = body_arena.raw_buf().to_vec();
4543                let body_schedule = compile_thunks(body_vjp, &body_arena);
4544
4545                // Carry bytes from the dcarry output node (== carry shape).
4546                let carry_bytes = body_vjp
4547                    .node(body_vjp.outputs[0])
4548                    .shape
4549                    .size_bytes()
4550                    .expect("ScanBackward dcarry must be statically shaped");
4551                let carry_elem_size = body_vjp
4552                    .node(body_vjp.outputs[0])
4553                    .shape
4554                    .dtype()
4555                    .size_bytes() as u32;
4556
4557                // For each xs input on the outer node:
4558                // (outer_xs_base, per_step_bytes).
4559                let mut outer_xs_offs: Vec<(usize, u32)> = Vec::with_capacity(*num_xs as usize);
4560                for i in 0..*num_xs as usize {
4561                    let outer_xs_id = node.inputs[3 + i];
4562                    let outer_xs_off = node_offset(arena, outer_xs_id);
4563                    let outer_xs_shape = &graph.node(outer_xs_id).shape;
4564                    let total = outer_xs_shape
4565                        .size_bytes()
4566                        .expect("ScanBackward xs must have static shape");
4567                    let per_step = total / *length as usize;
4568                    outer_xs_offs.push((outer_xs_off, per_step as u32));
4569                }
4570
4571                // If recursive checkpointing is active, we also compile
4572                // the forward body so the executor can recompute
4573                // intermediate carries. The forward body is supplied
4574                // by the AD pass via `forward_body: Some(_)`.
4575                let (fb_schedule, fb_init, fb_carry_in_off, fb_output_off, fb_x_offs) =
4576                    if is_recursive {
4577                        let fb = forward_body.as_ref().unwrap();
4578                        let fb_plan = rlx_opt::memory::plan_memory(fb);
4579                        let fb_offsets: HashMap<NodeId, usize> = fb_plan
4580                            .assignments
4581                            .iter()
4582                            .map(|(id, slot)| (*id, slot.offset))
4583                            .collect();
4584                        let mut fb_inputs: Vec<NodeId> = fb
4585                            .nodes()
4586                            .iter()
4587                            .filter(|n| matches!(n.op, Op::Input { .. }))
4588                            .map(|n| n.id)
4589                            .collect();
4590                        fb_inputs.sort();
4591                        let fb_carry = fb_offsets[&fb_inputs[0]];
4592                        let fb_xs: Vec<usize> = (1..fb_inputs.len())
4593                            .map(|i| fb_offsets[&fb_inputs[i]])
4594                            .collect();
4595                        let fb_out = fb_offsets[&fb.outputs[0]];
4596                        let mut fb_arena = crate::arena::Arena::from_plan(fb_plan);
4597                        for n in fb.nodes() {
4598                            if let Op::Constant { data } = &n.op
4599                                && fb_arena.has_buffer(n.id)
4600                                && !data.is_empty()
4601                            {
4602                                // Byte-copy works for any
4603                                // numeric dtype as long as the
4604                                // arena slot is sized to hold
4605                                // it — the Constant's `data`
4606                                // already encodes the right
4607                                // bytes per element.
4608                                let off = fb_arena.byte_offset(n.id);
4609                                let buf = fb_arena.raw_buf_mut();
4610                                let nb = (buf.len() - off).min(data.len());
4611                                buf[off..off + nb].copy_from_slice(&data[..nb]);
4612                            }
4613                        }
4614                        let fb_init_bytes = fb_arena.raw_buf().to_vec();
4615                        let fb_sched = compile_thunks(fb, &fb_arena);
4616                        (
4617                            Some(Arc::new(fb_sched)),
4618                            Some(Arc::new(fb_init_bytes)),
4619                            fb_carry,
4620                            fb_out,
4621                            fb_xs,
4622                        )
4623                    } else {
4624                        (None, None, 0, 0, Vec::new())
4625                    };
4626
4627                Thunk::ScanBackward {
4628                    body_vjp: Arc::new(body_schedule),
4629                    body_init: Arc::new(body_init),
4630                    body_carry_in_off,
4631                    body_x_offs: Arc::new(body_x_offs),
4632                    body_d_output_off,
4633                    body_dcarry_out_off,
4634                    outer_init_off: node_offset(arena, node.inputs[0]),
4635                    outer_traj_off: node_offset(arena, node.inputs[1]),
4636                    outer_upstream_off: node_offset(arena, node.inputs[2]),
4637                    outer_xs_offs: Arc::new(outer_xs_offs),
4638                    outer_dinit_off: node_offset(arena, node.id),
4639                    length: *length,
4640                    carry_bytes: carry_bytes as u32,
4641                    carry_elem_size,
4642                    save_trajectory: *save_trajectory,
4643                    num_checkpoints: *num_checkpoints,
4644                    forward_body: fb_schedule,
4645                    forward_body_init: fb_init,
4646                    forward_body_carry_in_off: fb_carry_in_off,
4647                    forward_body_output_off: fb_output_off,
4648                    forward_body_x_offs: Arc::new(fb_x_offs),
4649                }
4650            }
4651
4652            Op::ScanBackwardXs {
4653                body_vjp,
4654                length,
4655                save_trajectory,
4656                num_xs,
4657                xs_idx,
4658                num_checkpoints,
4659                forward_body,
4660            } => {
4661                assert!(
4662                    *num_checkpoints == 0 || *num_checkpoints <= *length,
4663                    "Op::ScanBackwardXs: num_checkpoints={} must be 0 or ≤ length={}",
4664                    *num_checkpoints,
4665                    *length
4666                );
4667                let is_recursive = *num_checkpoints != 0 && *num_checkpoints != *length;
4668                if is_recursive {
4669                    assert!(
4670                        forward_body.is_some(),
4671                        "Op::ScanBackwardXs with num_checkpoints<length \
4672                         requires forward_body"
4673                    );
4674                }
4675                // Mirror ScanBackward's body_vjp slot identification +
4676                // arena prep, then add: per-iteration extraction of the
4677                // body_vjp output that corresponds to the chosen xs.
4678                //
4679                // body_vjp's outputs (from `grad(body, [carry, xs_0, ..., xs_{num_xs-1}])`):
4680                //   outputs[0]      = dcarry
4681                //   outputs[1 + i]  = dx_t_i
4682                let body_plan = rlx_opt::memory::plan_memory(body_vjp);
4683                let body_offsets: HashMap<NodeId, usize> = body_plan
4684                    .assignments
4685                    .iter()
4686                    .map(|(id, slot)| (*id, slot.offset))
4687                    .collect();
4688                let mut body_d_output_off: Option<usize> = None;
4689                let mut body_other_inputs: Vec<(NodeId, usize)> = Vec::new();
4690                for n in body_vjp.nodes() {
4691                    if let Op::Input { name } = &n.op {
4692                        let off = body_offsets[&n.id];
4693                        if name == "d_output" {
4694                            body_d_output_off = Some(off);
4695                        } else {
4696                            body_other_inputs.push((n.id, off));
4697                        }
4698                    }
4699                }
4700                body_other_inputs.sort_by_key(|(id, _)| *id);
4701                let body_d_output_off =
4702                    body_d_output_off.expect("ScanBackwardXs body_vjp missing 'd_output' Input");
4703                let expected_others = 1 + *num_xs as usize;
4704                assert_eq!(
4705                    body_other_inputs.len(),
4706                    expected_others,
4707                    "ScanBackwardXs body_vjp has {} non-d_output Inputs; expected {}",
4708                    body_other_inputs.len(),
4709                    expected_others
4710                );
4711                let body_carry_in_off = body_other_inputs[0].1;
4712                let body_x_offs: Vec<usize> = body_other_inputs
4713                    .iter()
4714                    .skip(1)
4715                    .map(|(_, off)| *off)
4716                    .collect();
4717                let body_dcarry_out_off = body_offsets[&body_vjp.outputs[0]];
4718                let dxs_out_node = body_vjp.outputs[1 + *xs_idx as usize];
4719                let body_dxs_out_off = body_offsets[&dxs_out_node];
4720
4721                let mut body_arena = crate::arena::Arena::from_plan(body_plan);
4722                for n in body_vjp.nodes() {
4723                    if let Op::Constant { data } = &n.op
4724                        && body_arena.has_buffer(n.id)
4725                        && !data.is_empty()
4726                    {
4727                        match n.shape.dtype() {
4728                            rlx_ir::DType::F64 => {
4729                                let off = body_arena.byte_offset(n.id);
4730                                let buf = body_arena.raw_buf_mut();
4731                                let nb = (buf.len() - off).min(data.len());
4732                                buf[off..off + nb].copy_from_slice(&data[..nb]);
4733                            }
4734                            _ => {
4735                                let buf = body_arena.slice_mut(n.id);
4736                                let nf = data.len() / 4;
4737                                let nl = buf.len().min(nf);
4738                                for i in 0..nl {
4739                                    let bytes = [
4740                                        data[i * 4],
4741                                        data[i * 4 + 1],
4742                                        data[i * 4 + 2],
4743                                        data[i * 4 + 3],
4744                                    ];
4745                                    buf[i] = f32::from_le_bytes(bytes);
4746                                }
4747                            }
4748                        }
4749                    }
4750                }
4751                let body_init = body_arena.raw_buf().to_vec();
4752                let body_schedule = compile_thunks(body_vjp, &body_arena);
4753
4754                let carry_bytes = body_vjp
4755                    .node(body_vjp.outputs[0])
4756                    .shape
4757                    .size_bytes()
4758                    .expect("ScanBackwardXs dcarry must be statically shaped");
4759                let carry_elem_size = body_vjp
4760                    .node(body_vjp.outputs[0])
4761                    .shape
4762                    .dtype()
4763                    .size_bytes() as u32;
4764                let per_step_bytes = body_vjp
4765                    .node(dxs_out_node)
4766                    .shape
4767                    .size_bytes()
4768                    .expect("ScanBackwardXs dxs body output must be statically shaped");
4769
4770                let mut outer_xs_offs: Vec<(usize, u32)> = Vec::with_capacity(*num_xs as usize);
4771                for i in 0..*num_xs as usize {
4772                    let outer_xs_id = node.inputs[3 + i];
4773                    let outer_xs_off = node_offset(arena, outer_xs_id);
4774                    let outer_xs_shape = &graph.node(outer_xs_id).shape;
4775                    let total = outer_xs_shape
4776                        .size_bytes()
4777                        .expect("ScanBackwardXs xs must have static shape");
4778                    let per_step = total / *length as usize;
4779                    outer_xs_offs.push((outer_xs_off, per_step as u32));
4780                }
4781
4782                // Compile forward_body for recompute when checkpointed.
4783                // Mirrors the same code path in the ScanBackward arm.
4784                let (fb_schedule, fb_init, fb_carry_in_off, fb_output_off, fb_x_offs) =
4785                    if is_recursive {
4786                        let fb = forward_body.as_ref().unwrap();
4787                        let fb_plan = rlx_opt::memory::plan_memory(fb);
4788                        let fb_offsets: HashMap<NodeId, usize> = fb_plan
4789                            .assignments
4790                            .iter()
4791                            .map(|(id, slot)| (*id, slot.offset))
4792                            .collect();
4793                        let mut fb_inputs: Vec<NodeId> = fb
4794                            .nodes()
4795                            .iter()
4796                            .filter(|n| matches!(n.op, Op::Input { .. }))
4797                            .map(|n| n.id)
4798                            .collect();
4799                        fb_inputs.sort();
4800                        let fb_carry = fb_offsets[&fb_inputs[0]];
4801                        let fb_xs: Vec<usize> = (1..fb_inputs.len())
4802                            .map(|i| fb_offsets[&fb_inputs[i]])
4803                            .collect();
4804                        let fb_out = fb_offsets[&fb.outputs[0]];
4805                        let mut fb_arena = crate::arena::Arena::from_plan(fb_plan);
4806                        for n in fb.nodes() {
4807                            if let Op::Constant { data } = &n.op
4808                                && fb_arena.has_buffer(n.id)
4809                                && !data.is_empty()
4810                            {
4811                                // Byte-copy works for any
4812                                // numeric dtype as long as the
4813                                // arena slot is sized to hold
4814                                // it — the Constant's `data`
4815                                // already encodes the right
4816                                // bytes per element.
4817                                let off = fb_arena.byte_offset(n.id);
4818                                let buf = fb_arena.raw_buf_mut();
4819                                let nb = (buf.len() - off).min(data.len());
4820                                buf[off..off + nb].copy_from_slice(&data[..nb]);
4821                            }
4822                        }
4823                        let fb_init_bytes = fb_arena.raw_buf().to_vec();
4824                        let fb_sched = compile_thunks(fb, &fb_arena);
4825                        (
4826                            Some(Arc::new(fb_sched)),
4827                            Some(Arc::new(fb_init_bytes)),
4828                            fb_carry,
4829                            fb_out,
4830                            fb_xs,
4831                        )
4832                    } else {
4833                        (None, None, 0, 0, Vec::new())
4834                    };
4835
4836                Thunk::ScanBackwardXs {
4837                    body_vjp: Arc::new(body_schedule),
4838                    body_init: Arc::new(body_init),
4839                    body_carry_in_off,
4840                    body_x_offs: Arc::new(body_x_offs),
4841                    body_d_output_off,
4842                    body_dcarry_out_off,
4843                    body_dxs_out_off,
4844                    outer_init_off: node_offset(arena, node.inputs[0]),
4845                    outer_traj_off: node_offset(arena, node.inputs[1]),
4846                    outer_upstream_off: node_offset(arena, node.inputs[2]),
4847                    outer_xs_offs: Arc::new(outer_xs_offs),
4848                    outer_dxs_off: node_offset(arena, node.id),
4849                    length: *length,
4850                    carry_bytes: carry_bytes as u32,
4851                    carry_elem_size,
4852                    per_step_bytes: per_step_bytes as u32,
4853                    save_trajectory: *save_trajectory,
4854                    num_checkpoints: *num_checkpoints,
4855                    forward_body: fb_schedule,
4856                    forward_body_init: fb_init,
4857                    forward_body_carry_in_off: fb_carry_in_off,
4858                    forward_body_output_off: fb_output_off,
4859                    forward_body_x_offs: Arc::new(fb_x_offs),
4860                }
4861            }
4862
4863            Op::Concat { axis } => {
4864                // Compute outer/inner from the OUTPUT shape: all inputs share
4865                // the same shape except along `axis`. The output's leading
4866                // and trailing dims match.
4867                let out_shape = &node.shape;
4868                let rank = out_shape.rank();
4869                let outer: usize = (0..*axis)
4870                    .map(|i| out_shape.dim(i).unwrap_static())
4871                    .product::<usize>()
4872                    .max(1);
4873                let inner: usize = (*axis + 1..rank)
4874                    .map(|i| out_shape.dim(i).unwrap_static())
4875                    .product::<usize>()
4876                    .max(1);
4877                let total_axis = out_shape.dim(*axis).unwrap_static();
4878                let inputs: Vec<(usize, u32)> = node
4879                    .inputs
4880                    .iter()
4881                    .map(|&in_id| {
4882                        let in_shape = &graph.node(in_id).shape;
4883                        let in_axis = in_shape.dim(*axis).unwrap_static();
4884                        (node_offset(arena, in_id), in_axis as u32)
4885                    })
4886                    .collect();
4887                let dst = node_offset(arena, node.id);
4888                match out_shape.dtype() {
4889                    rlx_ir::DType::F64 => Thunk::ConcatF64 {
4890                        dst,
4891                        outer: outer as u32,
4892                        inner: inner as u32,
4893                        total_axis: total_axis as u32,
4894                        inputs,
4895                    },
4896                    _ => Thunk::Concat {
4897                        dst,
4898                        outer: outer as u32,
4899                        inner: inner as u32,
4900                        total_axis: total_axis as u32,
4901                        inputs,
4902                    },
4903                }
4904            }
4905
4906            Op::GaussianSplatRender {
4907                width,
4908                height,
4909                tile_size,
4910                radius_scale,
4911                alpha_cutoff,
4912                max_splat_steps,
4913                transmittance_threshold,
4914                max_list_entries,
4915            } => {
4916                let elem_len =
4917                    |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
4918                Thunk::GaussianSplatRender {
4919                    positions_off: node_offset(arena, node.inputs[0]),
4920                    positions_len: elem_len(node.inputs[0]),
4921                    scales_off: node_offset(arena, node.inputs[1]),
4922                    scales_len: elem_len(node.inputs[1]),
4923                    rotations_off: node_offset(arena, node.inputs[2]),
4924                    rotations_len: elem_len(node.inputs[2]),
4925                    opacities_off: node_offset(arena, node.inputs[3]),
4926                    opacities_len: elem_len(node.inputs[3]),
4927                    colors_off: node_offset(arena, node.inputs[4]),
4928                    colors_len: elem_len(node.inputs[4]),
4929                    sh_coeffs_off: node_offset(arena, node.inputs[5]),
4930                    sh_coeffs_len: elem_len(node.inputs[5]),
4931                    meta_off: node_offset(arena, node.inputs[6]),
4932                    dst_off: node_offset(arena, node.id),
4933                    dst_len: node.shape.num_elements().unwrap_or(0),
4934                    width: *width,
4935                    height: *height,
4936                    tile_size: *tile_size,
4937                    radius_scale: *radius_scale,
4938                    alpha_cutoff: *alpha_cutoff,
4939                    max_splat_steps: *max_splat_steps,
4940                    transmittance_threshold: *transmittance_threshold,
4941                    max_list_entries: *max_list_entries,
4942                }
4943            }
4944
4945            Op::GaussianSplatRenderBackward {
4946                width,
4947                height,
4948                tile_size,
4949                radius_scale,
4950                alpha_cutoff,
4951                max_splat_steps,
4952                transmittance_threshold,
4953                max_list_entries,
4954                loss_grad_clip,
4955                sh_band,
4956                max_anisotropy,
4957            } => {
4958                let elem_len =
4959                    |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
4960                Thunk::GaussianSplatRenderBackward {
4961                    positions_off: node_offset(arena, node.inputs[0]),
4962                    positions_len: elem_len(node.inputs[0]),
4963                    scales_off: node_offset(arena, node.inputs[1]),
4964                    scales_len: elem_len(node.inputs[1]),
4965                    rotations_off: node_offset(arena, node.inputs[2]),
4966                    rotations_len: elem_len(node.inputs[2]),
4967                    opacities_off: node_offset(arena, node.inputs[3]),
4968                    opacities_len: elem_len(node.inputs[3]),
4969                    colors_off: node_offset(arena, node.inputs[4]),
4970                    colors_len: elem_len(node.inputs[4]),
4971                    sh_coeffs_off: node_offset(arena, node.inputs[5]),
4972                    sh_coeffs_len: elem_len(node.inputs[5]),
4973                    meta_off: node_offset(arena, node.inputs[6]),
4974                    d_loss_off: node_offset(arena, node.inputs[7]),
4975                    d_loss_len: elem_len(node.inputs[7]),
4976                    packed_off: node_offset(arena, node.id),
4977                    packed_len: node.shape.num_elements().unwrap_or(0),
4978                    width: *width,
4979                    height: *height,
4980                    tile_size: *tile_size,
4981                    radius_scale: *radius_scale,
4982                    alpha_cutoff: *alpha_cutoff,
4983                    max_splat_steps: *max_splat_steps,
4984                    transmittance_threshold: *transmittance_threshold,
4985                    max_list_entries: *max_list_entries,
4986                    loss_grad_clip: *loss_grad_clip,
4987                    sh_band: *sh_band,
4988                    max_anisotropy: *max_anisotropy,
4989                }
4990            }
4991
4992            Op::GaussianSplatPrepare {
4993                width,
4994                height,
4995                tile_size,
4996                radius_scale,
4997                alpha_cutoff,
4998                max_splat_steps,
4999                transmittance_threshold,
5000                max_list_entries,
5001            } => {
5002                let elem_len =
5003                    |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
5004                Thunk::GaussianSplatPrepare {
5005                    positions_off: node_offset(arena, node.inputs[0]),
5006                    positions_len: elem_len(node.inputs[0]),
5007                    scales_off: node_offset(arena, node.inputs[1]),
5008                    scales_len: elem_len(node.inputs[1]),
5009                    rotations_off: node_offset(arena, node.inputs[2]),
5010                    rotations_len: elem_len(node.inputs[2]),
5011                    opacities_off: node_offset(arena, node.inputs[3]),
5012                    opacities_len: elem_len(node.inputs[3]),
5013                    colors_off: node_offset(arena, node.inputs[4]),
5014                    colors_len: elem_len(node.inputs[4]),
5015                    sh_coeffs_off: node_offset(arena, node.inputs[5]),
5016                    sh_coeffs_len: elem_len(node.inputs[5]),
5017                    meta_off: node_offset(arena, node.inputs[6]),
5018                    meta_len: elem_len(node.inputs[6]),
5019                    prep_off: node_offset(arena, node.id),
5020                    prep_len: node.shape.num_elements().unwrap_or(0),
5021                    width: *width,
5022                    height: *height,
5023                    tile_size: *tile_size,
5024                    radius_scale: *radius_scale,
5025                    alpha_cutoff: *alpha_cutoff,
5026                    max_splat_steps: *max_splat_steps,
5027                    transmittance_threshold: *transmittance_threshold,
5028                    max_list_entries: *max_list_entries,
5029                }
5030            }
5031
5032            Op::GaussianSplatRasterize {
5033                width,
5034                height,
5035                tile_size,
5036                alpha_cutoff,
5037                max_splat_steps,
5038                transmittance_threshold,
5039                max_list_entries,
5040            } => {
5041                let elem_len =
5042                    |id: NodeId| -> usize { graph.node(id).shape.num_elements().unwrap_or(0) };
5043                let prep_id = node.inputs[0];
5044                let count = match &graph.node(prep_id).op {
5045                    rlx_ir::Op::GaussianSplatPrepare { .. } => {
5046                        elem_len(graph.node(prep_id).inputs[0]) / 3
5047                    }
5048                    _ => 1,
5049                };
5050                Thunk::GaussianSplatRasterize {
5051                    prep_off: node_offset(arena, prep_id),
5052                    prep_len: elem_len(prep_id),
5053                    meta_off: node_offset(arena, node.inputs[1]),
5054                    meta_len: elem_len(node.inputs[1]),
5055                    dst_off: node_offset(arena, node.id),
5056                    dst_len: node.shape.num_elements().unwrap_or(0),
5057                    count,
5058                    width: *width,
5059                    height: *height,
5060                    tile_size: *tile_size,
5061                    alpha_cutoff: *alpha_cutoff,
5062                    max_splat_steps: *max_splat_steps,
5063                    transmittance_threshold: *transmittance_threshold,
5064                    max_list_entries: *max_list_entries,
5065                }
5066            }
5067
5068            Op::Custom { name, attrs, .. } => {
5069                let kernel = crate::op_registry::lookup_cpu_kernel(name).unwrap_or_else(|| {
5070                    panic!(
5071                        "compile_thunks: no CPU kernel registered for \
5072                         Op::Custom('{name}'). Register one via \
5073                         rlx_cpu::op_registry::register_cpu_kernel \
5074                         before compiling on the CPU backend."
5075                    )
5076                });
5077                let inputs_v: Vec<(usize, u32, Shape)> = node
5078                    .inputs
5079                    .iter()
5080                    .map(|&in_id| {
5081                        let s = graph.node(in_id).shape.clone();
5082                        let len = s.num_elements().unwrap_or(0) as u32;
5083                        (node_offset(arena, in_id), len, s)
5084                    })
5085                    .collect();
5086                let out_len = node.shape.num_elements().unwrap_or(0) as u32;
5087                Thunk::CustomOp {
5088                    kernel,
5089                    inputs: inputs_v,
5090                    output: (node_offset(arena, node.id), out_len, node.shape.clone()),
5091                    attrs: attrs.clone(),
5092                }
5093            }
5094
5095            Op::Fft { inverse, norm } => {
5096                let shape = &node.shape;
5097                let meta = rlx_ir::fft::fft_meta(shape);
5098                let dtype = shape.dtype();
5099                assert!(
5100                    matches!(
5101                        dtype,
5102                        rlx_ir::DType::F32 | rlx_ir::DType::F64 | rlx_ir::DType::C64
5103                    ),
5104                    "Op::Fft on CPU requires F32, F64, or C64, got {dtype:?}"
5105                );
5106                Thunk::Fft1d {
5107                    src: node_offset(arena, node.inputs[0]),
5108                    dst: node_offset(arena, node.id),
5109                    outer: meta.outer as u32,
5110                    n_complex: meta.n_complex as u32,
5111                    inverse: *inverse,
5112                    norm_tag: norm.tag(),
5113                    dtype,
5114                }
5115            }
5116
5117            Op::CustomFn {
5118                fwd_body,
5119                num_inputs,
5120                ..
5121            } => {
5122                // Plan + compile the body sub-graph standalone, fill its
5123                // Constants (mirrors the Op::Scan body lowering), then
5124                // capture per-input copy specs and the output spec.
5125                // Body Inputs in NodeId order match the outer node's
5126                // operand vector by position.
5127                let body_plan = rlx_opt::memory::plan_memory(fwd_body);
5128                let body_offsets: HashMap<NodeId, usize> = body_plan
5129                    .assignments
5130                    .iter()
5131                    .map(|(id, slot)| (*id, slot.offset))
5132                    .collect();
5133
5134                let mut body_input_ids: Vec<NodeId> = fwd_body
5135                    .nodes()
5136                    .iter()
5137                    .filter(|n| matches!(n.op, Op::Input { .. }))
5138                    .map(|n| n.id)
5139                    .collect();
5140                body_input_ids.sort();
5141                assert_eq!(
5142                    body_input_ids.len(),
5143                    *num_inputs as usize,
5144                    "Op::CustomFn fwd_body has {} Op::Input(s); declared num_inputs={}",
5145                    body_input_ids.len(),
5146                    *num_inputs,
5147                );
5148
5149                let mut body_arena = crate::arena::Arena::from_plan(body_plan);
5150                for n in fwd_body.nodes() {
5151                    if let Op::Constant { data } = &n.op
5152                        && body_arena.has_buffer(n.id)
5153                        && !data.is_empty()
5154                    {
5155                        match n.shape.dtype() {
5156                            rlx_ir::DType::F64 => {
5157                                let off = body_arena.byte_offset(n.id);
5158                                let buf = body_arena.raw_buf_mut();
5159                                let nb = (buf.len() - off).min(data.len());
5160                                buf[off..off + nb].copy_from_slice(&data[..nb]);
5161                            }
5162                            _ => {
5163                                let buf = body_arena.slice_mut(n.id);
5164                                let nf = data.len() / 4;
5165                                let nl = buf.len().min(nf);
5166                                for i in 0..nl {
5167                                    let bytes = [
5168                                        data[i * 4],
5169                                        data[i * 4 + 1],
5170                                        data[i * 4 + 2],
5171                                        data[i * 4 + 3],
5172                                    ];
5173                                    buf[i] = f32::from_le_bytes(bytes);
5174                                }
5175                            }
5176                        }
5177                    }
5178                }
5179                let body_init = body_arena.raw_buf().to_vec();
5180                let body_schedule = compile_thunks(fwd_body, &body_arena);
5181
5182                // Per primal input: (body_input_off, outer_input_off, bytes).
5183                let inputs_v: Vec<(usize, usize, u32)> = (0..*num_inputs as usize)
5184                    .map(|i| {
5185                        let body_in = body_input_ids[i];
5186                        let body_off = body_offsets[&body_in];
5187                        let outer_in = node.inputs[i];
5188                        let outer_off = node_offset(arena, outer_in);
5189                        let bytes = graph
5190                            .node(outer_in)
5191                            .shape
5192                            .size_bytes()
5193                            .expect("Op::CustomFn primal input must have static shape");
5194                        (body_off, outer_off, bytes as u32)
5195                    })
5196                    .collect();
5197
5198                let body_output_id = fwd_body
5199                    .outputs
5200                    .first()
5201                    .copied()
5202                    .expect("Op::CustomFn fwd_body must declare exactly one output");
5203                let body_output_off = body_offsets[&body_output_id];
5204                let out_bytes = node
5205                    .shape
5206                    .size_bytes()
5207                    .expect("Op::CustomFn output must have static shape");
5208
5209                Thunk::CustomFn {
5210                    body: Arc::new(body_schedule),
5211                    body_init: Arc::new(body_init),
5212                    inputs: Arc::new(inputs_v),
5213                    body_output_off,
5214                    outer_output_off: node_offset(arena, node.id),
5215                    out_bytes: out_bytes as u32,
5216                }
5217            }
5218
5219            _ => Thunk::Nop,
5220        };
5221        thunks.push(t);
5222    }
5223
5224    let cfg = crate::config::RuntimeConfig::global();
5225    let mask_thr = cfg.mask_binary_threshold;
5226    let mask_neg = cfg.attn_mask_neg_inf;
5227    let score_skip = cfg.score_skip_threshold;
5228
5229    // Pre-compile closures (skip Nops — they're filtered out)
5230    let compiled_fns: Vec<Arc<dyn Fn(*mut u8) + Send + Sync>> = thunks
5231        .iter()
5232        .filter(|t| !matches!(t, Thunk::Nop))
5233        .map(|thunk| {
5234            match thunk.clone() {
5235                Thunk::Nop => Arc::new(|_: *mut u8| {}) as Arc<dyn Fn(*mut u8) + Send + Sync>,
5236
5237                Thunk::Sgemm { a, b, c, m, k, n } => {
5238                    let (m, k, n) = (m as usize, k as usize, n as usize);
5239                    Arc::new(move |base: *mut u8| unsafe {
5240                        crate::blas::sgemm(
5241                            sl(a, base, m * k),
5242                            sl(b, base, k * n),
5243                            sl_mut(c, base, m * n),
5244                            m,
5245                            k,
5246                            n,
5247                        );
5248                    })
5249                }
5250
5251                Thunk::DenseSolveF64 { a, b, x, n, nrhs } => {
5252                    let (n_, nrhs_) = (n as usize, nrhs as usize);
5253                    Arc::new(move |base: *mut u8| unsafe {
5254                        let a_src = sl_f64(a, base, n_ * n_);
5255                        let b_src = sl_f64(b, base, n_ * nrhs_);
5256                        let mut a_scratch: Vec<f64> = a_src.to_vec();
5257                        let mut x_buf: Vec<f64> = b_src.to_vec();
5258                        let info = crate::blas::dgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
5259                        if info != 0 {
5260                            panic!("DenseSolveF64: singular (info={info})");
5261                        }
5262                        sl_mut_f64(x, base, n_ * nrhs_).copy_from_slice(&x_buf);
5263                    })
5264                }
5265
5266                Thunk::DenseSolveF32 { a, b, x, n, nrhs } => {
5267                    let (n_, nrhs_) = (n as usize, nrhs as usize);
5268                    Arc::new(move |base: *mut u8| unsafe {
5269                        let a_src = sl(a, base, n_ * n_);
5270                        let b_src = sl(b, base, n_ * nrhs_);
5271                        let mut a_scratch: Vec<f32> = a_src.to_vec();
5272                        let mut x_buf: Vec<f32> = b_src.to_vec();
5273                        let info = crate::blas::sgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
5274                        if info != 0 {
5275                            panic!("DenseSolveF32: singular (info={info})");
5276                        }
5277                        sl_mut(x, base, n_ * nrhs_).copy_from_slice(&x_buf);
5278                    })
5279                }
5280
5281                Thunk::FusedMmBiasAct {
5282                    a,
5283                    w,
5284                    bias,
5285                    c,
5286                    m,
5287                    k,
5288                    n,
5289                    act,
5290                } => {
5291                    let (m, k, n) = (m as usize, k as usize, n as usize);
5292                    Arc::new(move |base: *mut u8| unsafe {
5293                        let out = sl_mut(c, base, m * n);
5294                        crate::blas::sgemm(sl(a, base, m * k), sl(w, base, k * n), out, m, k, n);
5295                        // Bias + activation epilogue. Gelu uses the fused
5296                        // `par_bias_gelu` kernel (bias add + Gelu in one
5297                        // pass). For everything else, do the bias add first
5298                        // and then apply the activation per-element. The
5299                        // pre-fix code dispatched `_ => bias_add` and dropped
5300                        // the activation entirely — silent correctness bug
5301                        // for Silu/Relu/Sigmoid/etc.
5302                        match act {
5303                            Some(Activation::Gelu) => {
5304                                crate::kernels::par_bias_gelu(out, sl(bias, base, n), m, n)
5305                            }
5306                            Some(other) => {
5307                                crate::blas::bias_add(out, sl(bias, base, n), m, n);
5308                                apply_activation_inplace(out, other);
5309                            }
5310                            None => crate::blas::bias_add(out, sl(bias, base, n), m, n),
5311                        }
5312                    })
5313                }
5314
5315                Thunk::FusedResidualLN {
5316                    x,
5317                    res,
5318                    bias,
5319                    g,
5320                    b,
5321                    out,
5322                    rows,
5323                    h,
5324                    eps,
5325                    has_bias,
5326                } => {
5327                    let (rows, h) = (rows as usize, h as usize);
5328                    Arc::new(move |base: *mut u8| unsafe {
5329                        let zero = vec![0f32; h]; // closure only — not hot path
5330                        let bi = if has_bias { sl(bias, base, h) } else { &zero };
5331                        let xp = sl(x, base, rows * h).as_ptr() as usize;
5332                        let rp = sl(res, base, rows * h).as_ptr() as usize;
5333                        let op = sl_mut(out, base, rows * h).as_mut_ptr() as usize;
5334                        let bp = bi.as_ptr() as usize;
5335                        let gp = sl(g, base, h).as_ptr() as usize;
5336                        let bbp = sl(b, base, h).as_ptr() as usize;
5337                        crate::pool::par_for(rows, 4, &|off, cnt| {
5338                            let xs = std::slice::from_raw_parts(
5339                                (xp as *const f32).add(off * h),
5340                                cnt * h,
5341                            );
5342                            let rs = std::slice::from_raw_parts(
5343                                (rp as *const f32).add(off * h),
5344                                cnt * h,
5345                            );
5346                            let os = std::slice::from_raw_parts_mut(
5347                                (op as *mut f32).add(off * h),
5348                                cnt * h,
5349                            );
5350                            let bi = std::slice::from_raw_parts(bp as *const f32, h);
5351                            let g = std::slice::from_raw_parts(gp as *const f32, h);
5352                            let b = std::slice::from_raw_parts(bbp as *const f32, h);
5353                            crate::kernels::residual_bias_layer_norm(
5354                                xs, rs, bi, g, b, os, cnt, h, eps,
5355                            );
5356                        });
5357                    })
5358                }
5359
5360                Thunk::BiasAdd {
5361                    src,
5362                    bias,
5363                    dst,
5364                    m,
5365                    n,
5366                } => {
5367                    let (m, n) = (m as usize, n as usize);
5368                    Arc::new(move |base: *mut u8| unsafe {
5369                        let out = sl_mut(dst, base, m * n);
5370                        out.copy_from_slice(sl(src, base, m * n));
5371                        crate::blas::bias_add(out, sl(bias, base, n), m, n);
5372                    })
5373                }
5374
5375                Thunk::Gather {
5376                    table,
5377                    table_len,
5378                    idx,
5379                    dst,
5380                    num_idx,
5381                    trailing,
5382                } => {
5383                    let (ni, tr, tl) = (num_idx as usize, trailing as usize, table_len as usize);
5384                    Arc::new(move |base: *mut u8| unsafe {
5385                        let tab = sl(table, base, tl);
5386                        let ids = sl(idx, base, ni);
5387                        let out = sl_mut(dst, base, ni * tr);
5388                        for i in 0..ni {
5389                            let row = ids[i] as usize;
5390                            out[i * tr..(i + 1) * tr]
5391                                .copy_from_slice(&tab[row * tr..(row + 1) * tr]);
5392                        }
5393                    })
5394                }
5395
                Thunk::Narrow {
                    src,
                    dst,
                    outer,
                    src_stride,
                    dst_stride,
                    inner,
                    elem_bytes,
                } => {
                    // Closure construction is delegated to
                    // `narrow_thunk_closure`; every field is forwarded as-is.
                    narrow_thunk_closure(src, dst, outer, src_stride, dst_stride, inner, elem_bytes)
                }
5407
5408                Thunk::Copy { src, dst, len } => {
5409                    let len = len as usize;
5410                    Arc::new(move |base: *mut u8| unsafe {
5411                        sl_mut(dst, base, len).copy_from_slice(sl(src, base, len));
5412                    })
5413                }
5414
5415                Thunk::Softmax { data, rows, cols } => {
5416                    let (rows, cols) = (rows as usize, cols as usize);
5417                    Arc::new(move |base: *mut u8| unsafe {
5418                        crate::naive::softmax(sl_mut(data, base, rows * cols), rows, cols);
5419                    })
5420                }
5421
5422                Thunk::Cumsum {
5423                    src,
5424                    dst,
5425                    rows,
5426                    cols,
5427                    exclusive,
5428                } => {
5429                    let (rows, cols) = (rows as usize, cols as usize);
5430                    Arc::new(move |base: *mut u8| unsafe {
5431                        let s = sl(src, base, rows * cols);
5432                        let d = sl_mut(dst, base, rows * cols);
5433                        if exclusive {
5434                            for r in 0..rows {
5435                                let mut acc = 0.0f32;
5436                                for c in 0..cols {
5437                                    d[r * cols + c] = acc;
5438                                    acc += s[r * cols + c];
5439                                }
5440                            }
5441                        } else {
5442                            for r in 0..rows {
5443                                let mut acc = 0.0f32;
5444                                for c in 0..cols {
5445                                    acc += s[r * cols + c];
5446                                    d[r * cols + c] = acc;
5447                                }
5448                            }
5449                        }
5450                    })
5451                }
5452
5453                Thunk::Sample {
5454                    logits,
5455                    dst,
5456                    batch,
5457                    vocab,
5458                    top_k,
5459                    top_p,
5460                    temperature,
5461                    seed,
5462                } => {
5463                    let (b, v) = (batch as usize, vocab as usize);
5464                    let k = (top_k as usize).min(v);
5465                    Arc::new(move |base: *mut u8| unsafe {
5466                        let lg = sl(logits, base, b * v);
5467                        let out = sl_mut(dst, base, b);
5468                        let mut rng =
5469                            rlx_ir::Philox4x32::new(if seed == 0 { 0xDEADBEEF } else { seed });
5470                        for bi in 0..b {
5471                            let row = &lg[bi * v..(bi + 1) * v];
5472                            out[bi] = sample_row(row, k, top_p, temperature, &mut rng) as f32;
5473                        }
5474                    })
5475                }
5476
5477                Thunk::DequantMatMul {
5478                    x,
5479                    w_q,
5480                    scale,
5481                    zp,
5482                    dst,
5483                    m,
5484                    k,
5485                    n,
5486                    block_size,
5487                    is_asymmetric,
5488                } => {
5489                    let (m, k, n, bs) = (m as usize, k as usize, n as usize, block_size as usize);
5490                    let n_blocks_per_col = k.div_ceil(bs);
5491                    Arc::new(move |base: *mut u8| unsafe {
5492                        let xs = sl(x, base, m * k);
5493                        // w_q is packed i8 — use raw byte slice + reinterpret.
5494                        let raw = base.add(w_q);
5495                        let w_bytes = std::slice::from_raw_parts(raw as *const i8, k * n);
5496                        let scales = sl(scale, base, n_blocks_per_col * n);
5497                        let zps = if is_asymmetric {
5498                            sl(zp, base, n_blocks_per_col * n)
5499                        } else {
5500                            &[][..]
5501                        };
5502                        let out = sl_mut(dst, base, m * n);
5503                        dequant_matmul_int8(
5504                            xs,
5505                            w_bytes,
5506                            scales,
5507                            zps,
5508                            out,
5509                            m,
5510                            k,
5511                            n,
5512                            bs,
5513                            is_asymmetric,
5514                        );
5515                    })
5516                }
5517
5518                Thunk::DequantMatMulGguf {
5519                    x,
5520                    w_q,
5521                    dst,
5522                    m,
5523                    k,
5524                    n,
5525                    scheme,
5526                } => {
5527                    let (m, k, n) = (m as usize, k as usize, n as usize);
5528                    let block_bytes = scheme.gguf_block_bytes() as usize;
5529                    let block_elems = scheme.gguf_block_size() as usize;
5530                    let total_bytes = (k * n) / block_elems * block_bytes;
5531                    Arc::new(move |base: *mut u8| unsafe {
5532                        let xs = sl(x, base, m * k);
5533                        let w_bytes =
5534                            std::slice::from_raw_parts(base.add(w_q) as *const u8, total_bytes);
5535                        let out = sl_mut(dst, base, m * n);
5536                        crate::gguf_matmul::gguf_matmul_bt(xs, w_bytes, out, m, k, n, scheme);
5537                    })
5538                }
5539
5540                Thunk::DequantMatMulInt4 {
5541                    x,
5542                    w_q,
5543                    scale,
5544                    zp,
5545                    dst,
5546                    m,
5547                    k,
5548                    n,
5549                    block_size,
5550                    is_asymmetric,
5551                } => {
5552                    let (m, k, n, bs) = (m as usize, k as usize, n as usize, block_size as usize);
5553                    let n_blocks = k.div_ceil(bs);
5554                    Arc::new(move |base: *mut u8| unsafe {
5555                        let xs = sl(x, base, m * k);
5556                        let w_bytes = std::slice::from_raw_parts(
5557                            base.add(w_q) as *const u8,
5558                            (k * n).div_ceil(2),
5559                        );
5560                        let scales = sl(scale, base, n_blocks * n);
5561                        let zps = if is_asymmetric {
5562                            sl(zp, base, n_blocks * n)
5563                        } else {
5564                            &[][..]
5565                        };
5566                        let out = sl_mut(dst, base, m * n);
5567                        dequant_matmul_int4(
5568                            xs,
5569                            w_bytes,
5570                            scales,
5571                            zps,
5572                            out,
5573                            m,
5574                            k,
5575                            n,
5576                            bs,
5577                            is_asymmetric,
5578                        );
5579                    })
5580                }
5581
5582                Thunk::DequantMatMulFp8 {
5583                    x,
5584                    w_q,
5585                    scale,
5586                    dst,
5587                    m,
5588                    k,
5589                    n,
5590                    e5m2,
5591                } => {
5592                    let (m, k, n) = (m as usize, k as usize, n as usize);
5593                    Arc::new(move |base: *mut u8| unsafe {
5594                        let xs = sl(x, base, m * k);
5595                        let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, k * n);
5596                        let scales = sl(scale, base, n);
5597                        let out = sl_mut(dst, base, m * n);
5598                        dequant_matmul_fp8(xs, w_bytes, scales, out, m, k, n, e5m2);
5599                    })
5600                }
5601
5602                Thunk::DequantMatMulNvfp4 {
5603                    x,
5604                    w_q,
5605                    scale,
5606                    global_scale,
5607                    dst,
5608                    m,
5609                    k,
5610                    n,
5611                } => {
5612                    let (m, k, n) = (m as usize, k as usize, n as usize);
5613                    let n_scale = k.div_ceil(rlx_ir::NVFP4_GROUP_SIZE) * n;
5614                    Arc::new(move |base: *mut u8| unsafe {
5615                        let xs = sl(x, base, m * k);
5616                        let w_bytes = std::slice::from_raw_parts(
5617                            base.add(w_q) as *const u8,
5618                            (k * n).div_ceil(2),
5619                        );
5620                        let scale_bytes =
5621                            std::slice::from_raw_parts(base.add(scale) as *const u8, n_scale);
5622                        let gs = sl(global_scale, base, 1)[0];
5623                        let out = sl_mut(dst, base, m * n);
5624                        dequant_matmul_nvfp4(xs, w_bytes, scale_bytes, gs, out, m, k, n);
5625                    })
5626                }
5627
5628                Thunk::LoraMatMul {
5629                    x,
5630                    w,
5631                    a,
5632                    b,
5633                    dst,
5634                    m,
5635                    k,
5636                    n,
5637                    r,
5638                    scale,
5639                } => {
5640                    let (m, k, n, r) = (m as usize, k as usize, n as usize, r as usize);
5641                    Arc::new(move |base: *mut u8| unsafe {
5642                        let xs = sl(x, base, m * k);
5643                        let ws = sl(w, base, k * n);
5644                        let a_s = sl(a, base, k * r);
5645                        let bs = sl(b, base, r * n);
5646                        let out = sl_mut(dst, base, m * n);
5647                        // Step 1: out = x · W.
5648                        crate::blas::sgemm(xs, ws, out, m, k, n);
5649                        // Step 2: tmp = x · A (rank-r intermediate; tiny).
5650                        let mut tmp = vec![0f32; m * r];
5651                        crate::blas::sgemm(xs, a_s, &mut tmp, m, k, r);
5652                        // Step 3: out += scale * (tmp · B).
5653                        // sgemm_accumulate uses alpha=1.0 internally, so
5654                        // scale tmp first.
5655                        if scale != 1.0 {
5656                            for v in tmp.iter_mut() {
5657                                *v *= scale;
5658                            }
5659                        }
5660                        crate::blas::sgemm_accumulate(&tmp, bs, out, m, r, n);
5661                    })
5662                }
5663
5664                Thunk::LayerNorm {
5665                    src,
5666                    g,
5667                    b,
5668                    dst,
5669                    rows,
5670                    h,
5671                    eps,
5672                } => {
5673                    let (rows, h) = (rows as usize, h as usize);
5674                    Arc::new(move |base: *mut u8| unsafe {
5675                        let inp = sl(src, base, rows * h);
5676                        let gamma = sl(g, base, h);
5677                        let beta = sl(b, base, h);
5678                        let out = sl_mut(dst, base, rows * h);
5679                        for row in 0..rows {
5680                            crate::kernels::layer_norm_row(
5681                                &inp[row * h..(row + 1) * h],
5682                                gamma,
5683                                beta,
5684                                &mut out[row * h..(row + 1) * h],
5685                                h,
5686                                eps,
5687                            );
5688                        }
5689                    })
5690                }
5691
5692                Thunk::Attention {
5693                    q,
5694                    k,
5695                    v,
5696                    mask,
5697                    out,
5698                    batch,
5699                    seq,
5700                    kv_seq: _,
5701                    heads,
5702                    head_dim,
5703                    mask_kind,
5704                    q_row_stride,
5705                    k_row_stride,
5706                    v_row_stride,
5707                    bhsd,
5708                } => {
5709                    let (b, s, nh, dh) = (
5710                        batch as usize,
5711                        seq as usize,
5712                        heads as usize,
5713                        head_dim as usize,
5714                    );
5715                    let hs = nh * dh;
5716                    let qrs = q_row_stride as usize;
5717                    let krs = k_row_stride as usize;
5718                    let vrs = v_row_stride as usize;
5719                    let scale = (dh as f32).powf(-0.5);
5720                    Arc::new(move |base: *mut u8| unsafe {
5721                        // Slice lengths use the source's row stride so the
5722                        // compiler-emitted bounds checks cover the whole
5723                        // strided span (the kernel walks with q/k/v_rs).
5724                        // For [B, H, S, D] the buffer is dense B*H*S*D.
5725                        let (q_len, k_len, v_len, o_len) = if bhsd {
5726                            let n = b * nh * s * dh;
5727                            (n, n, n, n)
5728                        } else {
5729                            (b * s * qrs, b * s * krs, b * s * vrs, b * s * hs)
5730                        };
5731                        let q_d = sl(q, base, q_len);
5732                        let k_d = sl(k, base, k_len);
5733                        let v_d = sl(v, base, v_len);
5734                        let m_d: &[f32] = match mask_kind {
5735                            rlx_ir::op::MaskKind::Custom => sl(mask, base, b * s),
5736                            rlx_ir::op::MaskKind::Bias => sl(mask, base, b * nh * s * s),
5737                            _ => &[],
5738                        };
5739                        let o_d = sl_mut(out, base, o_len);
5740                        let sdh = s * dh;
5741                        let mut qh = vec![0f32; sdh];
5742                        let mut kh = vec![0f32; sdh];
5743                        let mut vh = vec![0f32; sdh];
5744                        let mut sc = vec![0f32; s * s];
5745                        let mut oh = vec![0f32; sdh];
5746                        for bi in 0..b {
5747                            for hi in 0..nh {
5748                                for si in 0..s {
5749                                    // Two layouts:
5750                                    //   bhsd=false: [B, S, H, D] (default) →
5751                                    //     off = bi*S*RS + si*RS + hi*D
5752                                    //   bhsd=true:  [B, H, S, D] (GPU/TPU
5753                                    //     convention) →
5754                                    //     off = bi*H*S*D + hi*S*D + si*D
5755                                    // The thunk-fusion pass below sets row
5756                                    // strides, but only for the [B, S, H, D]
5757                                    // case. For bhsd we always use the dense
5758                                    // contiguous stride (qrs == krs == vrs ==
5759                                    // H*D from compile_thunks).
5760                                    let (q_off, k_off, v_off) = if bhsd {
5761                                        (
5762                                            bi * nh * s * dh + hi * s * dh + si * dh,
5763                                            bi * nh * s * dh + hi * s * dh + si * dh,
5764                                            bi * nh * s * dh + hi * s * dh + si * dh,
5765                                        )
5766                                    } else {
5767                                        (
5768                                            bi * s * qrs + si * qrs + hi * dh,
5769                                            bi * s * krs + si * krs + hi * dh,
5770                                            bi * s * vrs + si * vrs + hi * dh,
5771                                        )
5772                                    };
5773                                    qh[si * dh..(si + 1) * dh]
5774                                        .copy_from_slice(&q_d[q_off..q_off + dh]);
5775                                    kh[si * dh..(si + 1) * dh]
5776                                        .copy_from_slice(&k_d[k_off..k_off + dh]);
5777                                    vh[si * dh..(si + 1) * dh]
5778                                        .copy_from_slice(&v_d[v_off..v_off + dh]);
5779                                }
5780                                for qi in 0..s {
5781                                    for ki in 0..s {
5782                                        let mut dot = 0f32;
5783                                        for d in 0..dh {
5784                                            dot += qh[qi * dh + d] * kh[ki * dh + d];
5785                                        }
5786                                        sc[qi * s + ki] = dot * scale;
5787                                    }
5788                                }
5789                                // Apply mask kind — None skips entirely, Causal /
5790                                // SlidingWindow synthesize, Custom reads m_d.
5791                                match mask_kind {
5792                                    rlx_ir::op::MaskKind::None => {}
5793                                    rlx_ir::op::MaskKind::Causal => {
5794                                        for qi in 0..s {
5795                                            for ki in (qi + 1)..s {
5796                                                sc[qi * s + ki] = mask_neg;
5797                                            }
5798                                        }
5799                                    }
5800                                    rlx_ir::op::MaskKind::SlidingWindow(w) => {
5801                                        for qi in 0..s {
5802                                            let lo = qi.saturating_sub(w);
5803                                            for ki in 0..s {
5804                                                if ki < lo || ki > qi {
5805                                                    sc[qi * s + ki] = mask_neg;
5806                                                }
5807                                            }
5808                                        }
5809                                    }
5810                                    rlx_ir::op::MaskKind::Custom => {
5811                                        for qi in 0..s {
5812                                            for ki in 0..s {
5813                                                if m_d[bi * s + ki] < mask_thr {
5814                                                    sc[qi * s + ki] = mask_neg;
5815                                                }
5816                                            }
5817                                        }
5818                                    }
5819                                    rlx_ir::op::MaskKind::Bias => {
5820                                        let per_bh = s * s;
5821                                        let off = (bi * nh + hi) * per_bh;
5822                                        for i in 0..per_bh {
5823                                            sc[i] += m_d[off + i];
5824                                        }
5825                                    }
5826                                }
5827                                crate::naive::softmax(&mut sc, s, s);
5828                                oh.fill(0.0);
5829                                for qi in 0..s {
5830                                    for ki in 0..s {
5831                                        let w = sc[qi * s + ki];
5832                                        if w > score_skip {
5833                                            for d in 0..dh {
5834                                                oh[qi * dh + d] += w * vh[ki * dh + d];
5835                                            }
5836                                        }
5837                                    }
5838                                }
5839                                for si in 0..s {
5840                                    let off = if bhsd {
5841                                        bi * nh * s * dh + hi * s * dh + si * dh
5842                                    } else {
5843                                        bi * s * hs + si * hs + hi * dh
5844                                    };
5845                                    o_d[off..off + dh].copy_from_slice(&oh[si * dh..(si + 1) * dh]);
5846                                }
5847                            }
5848                        }
5849                    })
5850                }
5851
5852                Thunk::FusedSwiGLU {
5853                    src,
5854                    dst,
5855                    n_half,
5856                    total,
5857                    gate_first,
5858                } => {
5859                    let n = n_half as usize;
5860                    let t = total as usize;
5861                    let outer = t / n;
5862                    let in_total = outer * 2 * n;
5863                    Arc::new(move |base: *mut u8| unsafe {
5864                        let inp = sl(src, base, in_total);
5865                        let out = sl_mut(dst, base, t);
5866                        for o in 0..outer {
5867                            let in_row = &inp[o * 2 * n..(o + 1) * 2 * n];
5868                            let out_row = &mut out[o * n..(o + 1) * n];
5869                            for i in 0..n {
5870                                let (up, gate) = if gate_first {
5871                                    (in_row[n + i], in_row[i])
5872                                } else {
5873                                    (in_row[i], in_row[n + i])
5874                                };
5875                                out_row[i] = up * (gate / (1.0 + (-gate).exp()));
5876                            }
5877                        }
5878                    })
5879                }
5880
5881                Thunk::Concat {
5882                    dst,
5883                    outer,
5884                    inner,
5885                    total_axis,
5886                    inputs,
5887                } => {
5888                    let outer = outer as usize;
5889                    let inner = inner as usize;
5890                    let total_axis = total_axis as usize;
5891                    let out_total = outer * total_axis * inner;
5892                    // Pre-compute the destination row offset for each input
5893                    // (cumulative axis offsets times inner).
5894                    let mut layout: Vec<(usize, usize, usize)> = Vec::with_capacity(inputs.len());
5895                    let mut cum: usize = 0;
5896                    for (src_off, in_axis) in &inputs {
5897                        let in_axis = *in_axis as usize;
5898                        layout.push((*src_off, cum * inner, in_axis * inner));
5899                        cum += in_axis;
5900                    }
5901                    Arc::new(move |base: *mut u8| unsafe {
5902                        let out = sl_mut(dst, base, out_total);
5903                        let row_stride = total_axis * inner;
5904                        for (src_off, dst_col_off, copy_per_row) in &layout {
5905                            let in_total = outer * *copy_per_row;
5906                            let inp = sl(*src_off, base, in_total);
5907                            for o in 0..outer {
5908                                let dst_row_start = o * row_stride + *dst_col_off;
5909                                let src_row_start = o * *copy_per_row;
5910                                out[dst_row_start..dst_row_start + *copy_per_row].copy_from_slice(
5911                                    &inp[src_row_start..src_row_start + *copy_per_row],
5912                                );
5913                            }
5914                        }
5915                    })
5916                }
5917
                // Custom-op thunk: the closure forwards the kernel, its input
                // descriptors, the destructured `output` triple (offset, length,
                // shape) and the attributes to `dispatch_custom_op`, together
                // with the `base` pointer supplied at call time.
                Thunk::CustomOp {
                    kernel,
                    inputs,
                    output,
                    attrs,
                } => {
                    // Capture-by-move: clone the Arc and Vecs once into the
                    // closure. Dispatch by output dtype each call (the
                    // dtype is fixed at compile time but it's cheaper to
                    // branch once per execution than to monomorphize a
                    // dozen closure variants).
                    let kernel = kernel.clone();
                    let attrs = attrs.clone();
                    let inputs = inputs.clone();
                    let (out_off, out_len, out_shape) = output.clone();
                    Arc::new(move |base: *mut u8| unsafe {
                        dispatch_custom_op(
                            &*kernel, &inputs, out_off, out_len, &out_shape, &attrs, base,
                        );
                    })
                }
5939
                // Gaussian-splat forward render: no work is done here beyond
                // capturing the thunk's fields. The closure passes every
                // offset/length pair and render parameter through, in field
                // order, to `crate::splat::execute_gaussian_splat_render`,
                // followed by the `base` pointer supplied at call time.
                Thunk::GaussianSplatRender {
                    positions_off,
                    positions_len,
                    scales_off,
                    scales_len,
                    rotations_off,
                    rotations_len,
                    opacities_off,
                    opacities_len,
                    colors_off,
                    colors_len,
                    sh_coeffs_off,
                    sh_coeffs_len,
                    meta_off,
                    dst_off,
                    dst_len,
                    width,
                    height,
                    tile_size,
                    radius_scale,
                    alpha_cutoff,
                    max_splat_steps,
                    transmittance_threshold,
                    max_list_entries,
                } => Arc::new(move |base: *mut u8| unsafe {
                    crate::splat::execute_gaussian_splat_render(
                        positions_off,
                        positions_len,
                        scales_off,
                        scales_len,
                        rotations_off,
                        rotations_len,
                        opacities_off,
                        opacities_len,
                        colors_off,
                        colors_len,
                        sh_coeffs_off,
                        sh_coeffs_len,
                        meta_off,
                        dst_off,
                        dst_len,
                        width,
                        height,
                        tile_size,
                        radius_scale,
                        alpha_cutoff,
                        max_splat_steps,
                        transmittance_threshold,
                        max_list_entries,
                        base,
                    );
                }),
5992
                // Gaussian-splat backward pass: same forwarding shape as the
                // forward render arm. The closure hands the input offset/length
                // pairs, the `d_loss_*` and `packed_*` buffers, and the render /
                // gradient parameters to
                // `crate::splat::execute_gaussian_splat_render_backward`,
                // followed by the `base` pointer supplied at call time.
                Thunk::GaussianSplatRenderBackward {
                    positions_off,
                    positions_len,
                    scales_off,
                    scales_len,
                    rotations_off,
                    rotations_len,
                    opacities_off,
                    opacities_len,
                    colors_off,
                    colors_len,
                    sh_coeffs_off,
                    sh_coeffs_len,
                    meta_off,
                    d_loss_off,
                    d_loss_len,
                    packed_off,
                    packed_len,
                    width,
                    height,
                    tile_size,
                    radius_scale,
                    alpha_cutoff,
                    max_splat_steps,
                    transmittance_threshold,
                    max_list_entries,
                    loss_grad_clip,
                    sh_band,
                    max_anisotropy,
                } => Arc::new(move |base: *mut u8| unsafe {
                    crate::splat::execute_gaussian_splat_render_backward(
                        positions_off,
                        positions_len,
                        scales_off,
                        scales_len,
                        rotations_off,
                        rotations_len,
                        opacities_off,
                        opacities_len,
                        colors_off,
                        colors_len,
                        sh_coeffs_off,
                        sh_coeffs_len,
                        meta_off,
                        d_loss_off,
                        d_loss_len,
                        packed_off,
                        packed_len,
                        width,
                        height,
                        tile_size,
                        radius_scale,
                        alpha_cutoff,
                        max_splat_steps,
                        transmittance_threshold,
                        max_list_entries,
                        loss_grad_clip,
                        sh_band,
                        max_anisotropy,
                        base,
                    );
                }),
6055
                // Gaussian-splat prepare stage: the closure forwards the input
                // offset/length pairs, the `meta_*` and `prep_*` buffers, and
                // the render parameters to
                // `crate::splat::execute_gaussian_splat_prepare`, followed by
                // the `base` pointer supplied at call time.
                Thunk::GaussianSplatPrepare {
                    positions_off,
                    positions_len,
                    scales_off,
                    scales_len,
                    rotations_off,
                    rotations_len,
                    opacities_off,
                    opacities_len,
                    colors_off,
                    colors_len,
                    sh_coeffs_off,
                    sh_coeffs_len,
                    meta_off,
                    meta_len,
                    prep_off,
                    prep_len,
                    width,
                    height,
                    tile_size,
                    radius_scale,
                    alpha_cutoff,
                    max_splat_steps,
                    transmittance_threshold,
                    max_list_entries,
                } => Arc::new(move |base: *mut u8| unsafe {
                    crate::splat::execute_gaussian_splat_prepare(
                        positions_off,
                        positions_len,
                        scales_off,
                        scales_len,
                        rotations_off,
                        rotations_len,
                        opacities_off,
                        opacities_len,
                        colors_off,
                        colors_len,
                        sh_coeffs_off,
                        sh_coeffs_len,
                        meta_off,
                        meta_len,
                        prep_off,
                        prep_len,
                        width,
                        height,
                        tile_size,
                        radius_scale,
                        alpha_cutoff,
                        max_splat_steps,
                        transmittance_threshold,
                        max_list_entries,
                        base,
                    );
                }),
6110
                // Gaussian-splat rasterize stage: the closure forwards the
                // `prep_*`, `meta_*` and `dst_*` buffers, `count`, and the
                // raster parameters to
                // `crate::splat::execute_gaussian_splat_rasterize`, followed by
                // the `base` pointer supplied at call time.
                Thunk::GaussianSplatRasterize {
                    prep_off,
                    prep_len,
                    meta_off,
                    meta_len,
                    dst_off,
                    dst_len,
                    count,
                    width,
                    height,
                    tile_size,
                    alpha_cutoff,
                    max_splat_steps,
                    transmittance_threshold,
                    max_list_entries,
                } => Arc::new(move |base: *mut u8| unsafe {
                    crate::splat::execute_gaussian_splat_rasterize(
                        prep_off,
                        prep_len,
                        meta_off,
                        meta_len,
                        dst_off,
                        dst_len,
                        count,
                        width,
                        height,
                        tile_size,
                        alpha_cutoff,
                        max_splat_steps,
                        transmittance_threshold,
                        max_list_entries,
                        base,
                    );
                }),
6145
6146                Thunk::Fft1d {
6147                    src,
6148                    dst,
6149                    outer,
6150                    n_complex,
6151                    inverse,
6152                    norm_tag,
6153                    dtype,
6154                } => {
6155                    let f: Arc<dyn Fn(*mut u8) + Send + Sync> = match dtype {
6156                        rlx_ir::DType::F64 => Arc::new(move |base: *mut u8| unsafe {
6157                            execute_fft1d_f64(
6158                                src,
6159                                dst,
6160                                outer as usize,
6161                                n_complex as usize,
6162                                inverse,
6163                                norm_tag,
6164                                base,
6165                            );
6166                        }),
6167                        rlx_ir::DType::F32 => Arc::new(move |base: *mut u8| unsafe {
6168                            execute_fft1d_f32(
6169                                src,
6170                                dst,
6171                                outer as usize,
6172                                n_complex as usize,
6173                                inverse,
6174                                norm_tag,
6175                                base,
6176                            );
6177                        }),
6178                        rlx_ir::DType::C64 => Arc::new(move |base: *mut u8| unsafe {
6179                            execute_fft1d_c64(
6180                                src,
6181                                dst,
6182                                outer as usize,
6183                                n_complex as usize,
6184                                inverse,
6185                                norm_tag,
6186                                base,
6187                            );
6188                        }),
6189                        other => panic!("Op::Fft on CPU requires F32/F64/C64, got {other:?}"),
6190                    };
6191                    f
6192                }
6193
6194                _ => Arc::new(|_: *mut u8| {}),
6195            }
6196        })
6197        .collect();
6198
6199    // ── Thunk-level attention fusion ──────────────────────
6200    // For small batch*seq, fuse QKV→Narrow×3→[Rope×2]→Attention→OutProj
6201    // into a single FusedAttnBlock. Auto-detects from Attention thunks.
6202    let fuse_threshold: usize = rlx_ir::env::var("RLX_FUSE_ATTN_THRESHOLD")
6203        .and_then(|v| v.parse().ok())
6204        .unwrap_or(64);
6205    let should_fuse = thunks.iter().any(|t| match t {
6206        Thunk::Attention { batch, seq, .. } => {
6207            (*batch as usize) * (*seq as usize) <= fuse_threshold
6208        }
6209        _ => false,
6210    });
6211
6212    if should_fuse {
6213        // Build non-Nop index for pattern matching across Nop gaps
6214        let active: Vec<usize> = thunks
6215            .iter()
6216            .enumerate()
6217            .filter(|(_, t)| !matches!(t, Thunk::Nop))
6218            .map(|(i, _)| i)
6219            .collect();
6220
6221        let mut kill = vec![false; thunks.len()]; // mark thunks to remove
6222        let mut insertions: Vec<(usize, Thunk)> = Vec::new(); // (position, replacement)
6223
6224        let mut ai = 0;
6225        while ai < active.len() {
6226            // Helper: get active thunk at offset from current
6227            let a = |off: usize| -> Option<(usize, &Thunk)> {
6228                active.get(ai + off).map(|&idx| (idx, &thunks[idx]))
6229            };
6230
6231            // Try BERT pattern: FusedMmBiasAct(QKV) → Narrow×3 → Attention → FusedMmBiasAct(out)
6232            let matched = (|| {
6233                let (_i0, t0) = a(0)?;
6234                let (_, t1) = a(1)?;
6235                let (_, t2) = a(2)?;
6236                let (_, t3) = a(3)?;
6237
6238                // a[0] must be FusedMmBiasAct or Sgemm (QKV projection)
6239                let (hidden, qkv_w, qkv_b, has_b) = match t0 {
6240                    Thunk::FusedMmBiasAct {
6241                        a,
6242                        w,
6243                        bias,
6244                        n: _,
6245                        act: None,
6246                        ..
6247                    } => (*a, *w, *bias, true),
6248                    Thunk::Sgemm { a, b, n: _, .. } => (*a, *b, 0, false),
6249                    _ => return None,
6250                };
6251
6252                // a[1..3] must be Narrows
6253                if !matches!(t1, Thunk::Narrow { .. }) {
6254                    return None;
6255                }
6256                if !matches!(t2, Thunk::Narrow { .. }) {
6257                    return None;
6258                }
6259                if !matches!(t3, Thunk::Narrow { .. }) {
6260                    return None;
6261                }
6262
6263                // Look for optional Rope×2 then Attention
6264                let (has_rope, attn_ai, cos_off, sin_off, cl) = if let Some((
6265                    _,
6266                    Thunk::Rope {
6267                        cos, sin, cos_len, ..
6268                    },
6269                )) = a(4)
6270                {
6271                    if matches!(a(5).map(|x| x.1), Some(Thunk::Rope { .. })) {
6272                        if matches!(a(6).map(|x| x.1), Some(Thunk::Attention { .. })) {
6273                            (true, 6, *cos, *sin, *cos_len)
6274                        } else {
6275                            return None;
6276                        }
6277                    } else {
6278                        return None;
6279                    }
6280                } else if matches!(a(4).map(|x| x.1), Some(Thunk::Attention { .. })) {
6281                    (false, 4, 0, 0, 0)
6282                } else {
6283                    return None;
6284                };
6285
6286                let (_attn_real_idx, attn_t) = a(attn_ai)?;
6287                let (batch, seq, heads, head_dim, mask) = match attn_t {
6288                    Thunk::Attention {
6289                        batch,
6290                        seq,
6291                        heads,
6292                        head_dim,
6293                        mask,
6294                        ..
6295                    } => (*batch, *seq, *heads, *head_dim, *mask),
6296                    _ => return None,
6297                };
6298
6299                // Next active must be out projection (FusedMmBiasAct or Sgemm)
6300                let (_out_real_idx, out_t) = a(attn_ai + 1)?;
6301                let (out_w, out_b, out_dst) = match out_t {
6302                    Thunk::FusedMmBiasAct {
6303                        w,
6304                        bias,
6305                        c,
6306                        act: None,
6307                        ..
6308                    } => (*w, *bias, *c),
6309                    Thunk::Sgemm { b: w, c, .. } => (*w, 0, *c),
6310                    _ => return None,
6311                };
6312
6313                let hs = heads * head_dim;
6314                let total_active = attn_ai + 2; // number of active thunks consumed
6315
6316                Some((
6317                    total_active,
6318                    Thunk::FusedAttnBlock {
6319                        hidden,
6320                        qkv_w,
6321                        out_w,
6322                        mask,
6323                        out: out_dst,
6324                        qkv_b: if has_b { qkv_b } else { 0 },
6325                        out_b: if has_b { out_b } else { 0 },
6326                        cos: cos_off,
6327                        sin: sin_off,
6328                        cos_len: cl,
6329                        batch,
6330                        seq,
6331                        hs,
6332                        nh: heads,
6333                        dh: head_dim,
6334                        has_bias: has_b,
6335                        has_rope,
6336                    },
6337                ))
6338            })();
6339
6340            if let Some((count, fused_thunk)) = matched {
6341                // Mark consumed thunks for removal
6342                for off in 0..count {
6343                    if let Some(&idx) = active.get(ai + off) {
6344                        kill[idx] = true;
6345                    }
6346                }
6347                // Insert replacement at position of the QKV thunk
6348                insertions.push((active[ai], fused_thunk));
6349                ai += count;
6350            } else {
6351                ai += 1;
6352            }
6353        }
6354
6355        // Rebuild thunk list: keep non-killed, insert fused at right positions
6356        if !insertions.is_empty() {
6357            let mut new_thunks = Vec::with_capacity(thunks.len());
6358            let mut insert_idx = 0;
6359            for (i, t) in thunks.into_iter().enumerate() {
6360                if insert_idx < insertions.len() && insertions[insert_idx].0 == i {
6361                    new_thunks.push(insertions[insert_idx].1.clone());
6362                    insert_idx += 1;
6363                }
6364                if !kill[i] {
6365                    new_thunks.push(t);
6366                }
6367            }
6368            if cfg.verbose >= 1 {
6369                eprintln!(
6370                    "[rlx] fused_attention: {} attention blocks fused",
6371                    insertions.len()
6372                );
6373            }
6374            thunks = new_thunks;
6375        }
6376    }
6377
6378    // ── Full layer fusion ──────────────────────────────────
6379    // After attention blocks are fused, scan for full layer patterns:
6380    // BERT:  FusedAttnBlock → FusedResidualLN → FusedMmBiasAct(gelu) → Sgemm → BiasAdd → FusedResidualLN
6381    // Nomic: FusedAttnBlock → BinaryFull(add) → LayerNorm → Sgemm → [Narrow×2 → Silu → BinaryFull(mul)] → Sgemm → BinaryFull(add) → LayerNorm
6382    if should_fuse {
6383        let active: Vec<usize> = thunks
6384            .iter()
6385            .enumerate()
6386            .filter(|(_, t)| !matches!(t, Thunk::Nop))
6387            .map(|(i, _)| i)
6388            .collect();
6389
6390        let mut kill = vec![false; thunks.len()];
6391        let mut insertions: Vec<(usize, Thunk)> = Vec::new();
6392
6393        let a = |ai: usize| -> Option<&Thunk> { active.get(ai).map(|&i| &thunks[i]) };
6394
6395        let mut ai = 0;
6396        while ai < active.len() {
6397            // BERT pattern: FusedAttnBlock → FusedResidualLN → FusedMmBiasAct(gelu) → FusedMmBiasAct(none) → FusedResidualLN
6398            let bert_match = (|| -> Option<usize> {
6399                let fab = a(ai)?;
6400                let rln1 = a(ai + 1)?;
6401                let ffn1 = a(ai + 2)?;
6402                let ffn2 = a(ai + 3)?;
6403                let rln2 = a(ai + 4)?;
6404
6405                let (hidden, qkv_w, qkv_b, out_w, out_b, mask, batch, seq, hs, nh, dh) = match fab {
6406                    Thunk::FusedAttnBlock {
6407                        hidden,
6408                        qkv_w,
6409                        qkv_b,
6410                        out_w,
6411                        out_b,
6412                        mask,
6413                        batch,
6414                        seq,
6415                        hs,
6416                        nh,
6417                        dh,
6418                        has_bias: true,
6419                        has_rope: false,
6420                        ..
6421                    } => (
6422                        *hidden, *qkv_w, *qkv_b, *out_w, *out_b, *mask, *batch, *seq, *hs, *nh, *dh,
6423                    ),
6424                    _ => return None,
6425                };
6426                let (ln1_g, ln1_b, eps1) = match rln1 {
6427                    Thunk::FusedResidualLN { g, b, eps, .. } => (*g, *b, *eps),
6428                    _ => return None,
6429                };
6430                let (fc1_w, fc1_b, int_dim) = match ffn1 {
6431                    Thunk::FusedMmBiasAct {
6432                        w,
6433                        bias,
6434                        n,
6435                        act: Some(Activation::Gelu),
6436                        ..
6437                    } => (*w, *bias, *n),
6438                    _ => return None,
6439                };
6440                let (fc2_w, fc2_b) = match ffn2 {
6441                    Thunk::FusedMmBiasAct {
6442                        w, bias, act: None, ..
6443                    } => (*w, *bias),
6444                    _ => return None,
6445                };
6446                let (ln2_g, ln2_b, eps2, out) = match rln2 {
6447                    Thunk::FusedResidualLN { g, b, eps, out, .. } => (*g, *b, *eps, *out),
6448                    _ => return None,
6449                };
6450
6451                for off in 0..5 {
6452                    kill[active[ai + off]] = true;
6453                }
6454                insertions.push((
6455                    active[ai],
6456                    Thunk::FusedBertLayer {
6457                        hidden,
6458                        qkv_w,
6459                        qkv_b,
6460                        out_w,
6461                        out_b,
6462                        mask,
6463                        ln1_g,
6464                        ln1_b,
6465                        eps1,
6466                        fc1_w,
6467                        fc1_b,
6468                        fc2_w,
6469                        fc2_b,
6470                        ln2_g,
6471                        ln2_b,
6472                        eps2,
6473                        out,
6474                        batch,
6475                        seq,
6476                        hs,
6477                        nh,
6478                        dh,
6479                        int_dim,
6480                    },
6481                ));
6482                Some(5)
6483            })();
6484            if let Some(n) = bert_match {
6485                ai += n;
6486                continue;
6487            }
6488
6489            // Nomic full layer fusion — disabled pending SwiGLU stride debugging.
6490            // Nomic still benefits from FusedAttnBlock (attention-level fusion).
6491            // The body below is kept as reference for when the stride bug is fixed.
6492            #[allow(unreachable_code)]
6493            let nomic_match = (|| -> Option<usize> {
6494                return None; // TODO: fix SwiGLU strided fc2 output mismatch
6495                let fab = a(ai)?;
6496                let (hidden, qkv_w, out_w, mask, cos, sin, cos_len, batch, seq, hs, nh, dh) =
6497                    match fab {
6498                        Thunk::FusedAttnBlock {
6499                            hidden,
6500                            qkv_w,
6501                            out_w,
6502                            mask,
6503                            cos,
6504                            sin,
6505                            cos_len,
6506                            batch,
6507                            seq,
6508                            hs,
6509                            nh,
6510                            dh,
6511                            has_bias: false,
6512                            has_rope: true,
6513                            ..
6514                        } => (
6515                            *hidden, *qkv_w, *out_w, *mask, *cos, *sin, *cos_len, *batch, *seq,
6516                            *hs, *nh, *dh,
6517                        ),
6518                        _ => return None,
6519                    };
6520                // FusedResidualLN for LN1
6521                let (ln1_g, ln1_b, eps1) = match a(ai + 1)? {
6522                    Thunk::FusedResidualLN { g, b, eps, .. } => (*g, *b, *eps),
6523                    _ => return None,
6524                };
6525                // Sgemm (fused fc11+fc12)
6526                let fused_fc_w = match a(ai + 2)? {
6527                    Thunk::Sgemm { b: w, .. } => *w,
6528                    _ => return None,
6529                };
6530                // Narrow×2 for split
6531                if !matches!(a(ai + 3)?, Thunk::Narrow { .. }) {
6532                    return None;
6533                }
6534                if !matches!(a(ai + 4)?, Thunk::Narrow { .. }) {
6535                    return None;
6536                }
6537                // SiLU
6538                if !matches!(
6539                    a(ai + 5)?,
6540                    Thunk::ActivationInPlace {
6541                        act: Activation::Silu,
6542                        ..
6543                    }
6544                ) {
6545                    return None;
6546                }
6547                // BinaryFull(Mul) for gate
6548                if !matches!(
6549                    a(ai + 6)?,
6550                    Thunk::BinaryFull {
6551                        op: BinaryOp::Mul,
6552                        ..
6553                    }
6554                ) {
6555                    return None;
6556                }
6557                // Sgemm (fc2)
6558                let fc2_w = match a(ai + 7)? {
6559                    Thunk::Sgemm { b: w, .. } => *w,
6560                    _ => return None,
6561                };
6562                // Get int_dim from the Narrow (inner = int_dim for last-axis narrow)
6563                let int_dim = match a(ai + 3)? {
6564                    Thunk::Narrow { inner, .. } => *inner,
6565                    _ => return None,
6566                };
6567                // FusedResidualLN for LN2
6568                let (ln2_g, ln2_b, eps2, out) = match a(ai + 8)? {
6569                    Thunk::FusedResidualLN { g, b, eps, out, .. } => (*g, *b, *eps, *out),
6570                    _ => return None,
6571                };
6572
6573                for off in 0..9 {
6574                    kill[active[ai + off]] = true;
6575                }
6576                insertions.push((
6577                    active[ai],
6578                    Thunk::FusedNomicLayer {
6579                        hidden,
6580                        qkv_w,
6581                        out_w,
6582                        mask,
6583                        cos,
6584                        sin,
6585                        cos_len,
6586                        ln1_g,
6587                        ln1_b,
6588                        eps1,
6589                        fc11_w: fused_fc_w,
6590                        fc12_w: 0,
6591                        fc2_w,
6592                        ln2_g,
6593                        ln2_b,
6594                        eps2,
6595                        out,
6596                        batch,
6597                        seq,
6598                        hs,
6599                        nh,
6600                        dh,
6601                        int_dim,
6602                    },
6603                ));
6604                Some(9)
6605            })();
6606            if let Some(n) = nomic_match {
6607                ai += n;
6608                continue;
6609            }
6610
6611            ai += 1;
6612        }
6613
6614        if !insertions.is_empty() {
6615            let mut new_thunks = Vec::with_capacity(thunks.len());
6616            let mut ins_idx = 0;
6617            for (i, t) in thunks.into_iter().enumerate() {
6618                if ins_idx < insertions.len() && insertions[ins_idx].0 == i {
6619                    new_thunks.push(insertions[ins_idx].1.clone());
6620                    ins_idx += 1;
6621                }
6622                if !kill[i] {
6623                    new_thunks.push(t);
6624                }
6625            }
6626            if cfg.verbose >= 1 {
6627                eprintln!(
6628                    "[rlx] fused_layer: {} full transformer layers fused",
6629                    insertions.len()
6630                );
6631            }
6632            thunks = new_thunks;
6633        }
6634    }
6635
6636    // ── Narrow → Rope thunk fusion (plan #45) ──────────────
6637    // Runs *after* FusedAttnBlock fusion so it only catches the medium-
6638    // batch path (batch*seq > 64) where the bigger fusion didn't fire.
6639    // Pattern: a Rope thunk whose `src` is the dst of an immediately-
6640    // preceding Narrow whose dst has no other consumer in this schedule.
6641    // Rewrite Rope to read directly from the parent buffer with the
6642    // parent's row stride; the Narrow becomes a Nop.
6643    //
6644    // Skipping the Narrow's write saves one full pass over Q/K (B*S*hs
6645    // f32) per Rope. For Nomic h=768 / batch=8 / seq=15 / 12 layers
6646    // that's 2 ropes/layer × 369 KB = ~8.9 MB of write traffic gone.
6647    {
6648        // Collect every byte-offset that's read as a thunk's `src` so
6649        // we know whether a Narrow's dst has consumers other than Rope.
6650        let mut read_offsets: HashMap<usize, usize> = HashMap::new();
6651        for t in &thunks {
6652            for off in thunk_read_offsets(t) {
6653                *read_offsets.entry(off).or_insert(0) += 1;
6654            }
6655        }
6656
6657        let mut fused_count = 0usize;
6658        for i in 0..thunks.len().saturating_sub(1) {
6659            // Look for Rope at i+1 reading from Narrow at i (skip Nops
6660            // between them since the planner left them in place).
6661            let narrow = match &thunks[i] {
6662                Thunk::Narrow { .. } => i,
6663                _ => continue,
6664            };
6665            // Find the next non-Nop thunk
6666            let mut j = narrow + 1;
6667            while j < thunks.len() && matches!(thunks[j], Thunk::Nop) {
6668                j += 1;
6669            }
6670            if j >= thunks.len() {
6671                continue;
6672            }
6673            // Must be Rope reading Narrow's dst
6674            let (n_src, n_dst, n_src_stride) = match &thunks[narrow] {
6675                Thunk::Narrow {
6676                    src,
6677                    dst,
6678                    src_stride,
6679                    ..
6680                } => (*src, *dst, *src_stride),
6681                _ => continue,
6682            };
6683            let rope_reads_narrow = matches!(&thunks[j],
6684                Thunk::Rope { src, .. } if *src == n_dst);
6685            if !rope_reads_narrow {
6686                continue;
6687            }
6688            // Conservatively require that the Narrow's dst has exactly
6689            // one reader (the Rope). Anything else and rewriting would
6690            // skip a needed write.
6691            if read_offsets.get(&n_dst).copied().unwrap_or(0) != 1 {
6692                continue;
6693            }
6694
6695            // Rewire: Rope reads from Narrow's adjusted source with the
6696            // parent buffer's row stride.
6697            if let Thunk::Rope {
6698                src,
6699                src_row_stride,
6700                ..
6701            } = &mut thunks[j]
6702            {
6703                *src = n_src;
6704                *src_row_stride = n_src_stride;
6705            }
6706            thunks[narrow] = Thunk::Nop;
6707            fused_count += 1;
6708        }
6709
6710        if fused_count > 0 && cfg.verbose >= 1 {
6711            eprintln!(
6712                "[rlx] fused_qk_rope: {} Narrow→Rope pairs collapsed",
6713                fused_count
6714            );
6715        }
6716    }
6717
6718    // ── Narrow×3 → Attention thunk fusion (plan #46 deep) ────
6719    // For each Attention thunk in the schedule, look up the producers
6720    // of its q/k/v inputs. If each is a Narrow whose dst has exactly
6721    // one consumer (the Attention), rewire Attention to read directly
6722    // from the parent buffer with the parent's row stride. The three
6723    // Narrows become Nops.
6724    //
6725    // This catches the BERT/Nomic QKV split path that FusedAttnBlock
6726    // misses (batch*seq > 64) — eliminates Q/K/V copies entirely.
6727    // For minilm6 batch=32 seq=16 hs=384: 3 × 32*16*384*4 = 2.3 MB
6728    // per layer × 6 layers = ~14 MB of write traffic gone.
6729    {
6730        let mut read_counts: HashMap<usize, usize> = HashMap::new();
6731        for t in &thunks {
6732            for off in thunk_read_offsets(t) {
6733                *read_counts.entry(off).or_insert(0) += 1;
6734            }
6735        }
6736        // Build dst→index map for fast producer lookup.
6737        let mut dst_to_idx: HashMap<usize, usize> = HashMap::new();
6738        for (i, t) in thunks.iter().enumerate() {
6739            if let Thunk::Narrow { dst, .. } = t {
6740                dst_to_idx.insert(*dst, i);
6741            }
6742        }
6743
6744        let mut fused_count = 0usize;
6745        for i in 0..thunks.len() {
6746            let (q_off, k_off, v_off) = match &thunks[i] {
6747                Thunk::Attention { q, k, v, .. } => (*q, *k, *v),
6748                _ => continue,
6749            };
6750            // All three inputs must come from Narrows.
6751            let q_n = match dst_to_idx.get(&q_off).copied() {
6752                Some(x) => x,
6753                None => continue,
6754            };
6755            let k_n = match dst_to_idx.get(&k_off).copied() {
6756                Some(x) => x,
6757                None => continue,
6758            };
6759            let v_n = match dst_to_idx.get(&v_off).copied() {
6760                Some(x) => x,
6761                None => continue,
6762            };
6763            // Each Narrow's dst must have exactly one reader (this Attn).
6764            if read_counts.get(&q_off).copied().unwrap_or(0) != 1 {
6765                continue;
6766            }
6767            if read_counts.get(&k_off).copied().unwrap_or(0) != 1 {
6768                continue;
6769            }
6770            if read_counts.get(&v_off).copied().unwrap_or(0) != 1 {
6771                continue;
6772            }
6773
6774            let (q_src, q_stride) = match &thunks[q_n] {
6775                Thunk::Narrow {
6776                    src, src_stride, ..
6777                } => (*src, *src_stride),
6778                _ => continue,
6779            };
6780            let (k_src, k_stride) = match &thunks[k_n] {
6781                Thunk::Narrow {
6782                    src, src_stride, ..
6783                } => (*src, *src_stride),
6784                _ => continue,
6785            };
6786            let (v_src, v_stride) = match &thunks[v_n] {
6787                Thunk::Narrow {
6788                    src, src_stride, ..
6789                } => (*src, *src_stride),
6790                _ => continue,
6791            };
6792
6793            if let Thunk::Attention {
6794                q,
6795                k,
6796                v,
6797                q_row_stride,
6798                k_row_stride,
6799                v_row_stride,
6800                ..
6801            } = &mut thunks[i]
6802            {
6803                *q = q_src;
6804                *k = k_src;
6805                *v = v_src;
6806                *q_row_stride = q_stride;
6807                *k_row_stride = k_stride;
6808                *v_row_stride = v_stride;
6809            }
6810            thunks[q_n] = Thunk::Nop;
6811            thunks[k_n] = Thunk::Nop;
6812            thunks[v_n] = Thunk::Nop;
6813            fused_count += 1;
6814        }
6815
6816        if fused_count > 0 && cfg.verbose >= 1 {
6817            eprintln!(
6818                "[rlx] fused_strided_attn: {} Narrow×3→Attention rewrites",
6819                fused_count
6820            );
6821        }
6822    }
6823
6824    ThunkSchedule {
6825        thunks,
6826        moe_resident: None,
6827        moe_resident_layers: None,
6828        moe_topk_capture: None,
6829        mask_threshold: cfg.mask_binary_threshold,
6830        mask_neg_inf: cfg.attn_mask_neg_inf,
6831        score_skip: cfg.score_skip_threshold,
6832        compiled_fns,
6833    }
6834}
6835
6836fn get_len(graph: &Graph, id: NodeId) -> usize {
6837    graph.node(id).shape.num_elements().unwrap_or(0)
6838}
6839
6840/// Static `usize` dims of a node's shape, or empty if any dim is dynamic.
6841fn get_static_dims(graph: &Graph, id: NodeId) -> Vec<usize> {
6842    let dims = graph.node(id).shape.dims();
6843    let mut out = Vec::with_capacity(dims.len());
6844    for d in dims {
6845        if let Some(s) = match d {
6846            rlx_ir::Dim::Static(s) => Some(*s),
6847            _ => None,
6848        } {
6849            out.push(s);
6850        } else {
6851            return Vec::new();
6852        }
6853    }
6854    out
6855}
6856
// Documentation for `broadcast_strides` (defined further below); kept as a
// plain comment so rustdoc does not attach it to `is_trailing_bias_broadcast`.
//
// NumPy-style broadcast strides for one operand into the flat output
// buffer. Returns a length-`out_dims.len()` `Vec<u32>` where entry
// `d` is `0` if the input is size-1 (broadcast) at output dim `d`
// (after left-padding with size-1 to match ranks), otherwise the
// natural row-major stride into the *input* buffer.
//
// Caller iterates output flat index `i` → output coords (row-major)
// → input flat index = dot(coords, strides). The result is correct
// for any broadcast pattern (scalar, last-axis, middle-axis,
// bidirectional).
6867/// True when `rhs_dims` describes a *trailing* broadcast of `out_dims`
6868/// — i.e. every rhs dim either equals the corresponding output dim
6869/// (counting from the right) or rhs is shorter (left-padded with 1s).
6870/// Mid-shape singletons (e.g. rhs `[a, b, 1, d]` into out `[a, b, c, d]`
6871/// where `c > 1`) are NOT trailing broadcasts and require the
6872/// shape-aware `BinaryFull` slow path — `BiasAdd`'s linear bias-replicated
6873/// kernel silently miscomputes them.
6874fn is_trailing_bias_broadcast(rhs_dims: &[rlx_ir::Dim], out_dims: &[rlx_ir::Dim]) -> bool {
6875    if rhs_dims.len() > out_dims.len() {
6876        return false;
6877    }
6878    let off = out_dims.len() - rhs_dims.len();
6879    for i in 0..rhs_dims.len() {
6880        let r = match rhs_dims[i] {
6881            rlx_ir::Dim::Static(n) => n,
6882            _ => return false,
6883        };
6884        let o = match out_dims[off + i] {
6885            rlx_ir::Dim::Static(n) => n,
6886            _ => return false,
6887        };
6888        if r != o {
6889            return false;
6890        }
6891    }
6892    true
6893}
6894
/// NumPy-style broadcast strides for one operand into the flat output buffer.
///
/// `in_dims` is right-aligned against `out_dims` (implicitly left-padded with
/// size-1 axes). The result has one entry per output axis: `0` where the
/// input is size-1 at that axis (broadcast), otherwise the natural row-major
/// stride into the *input* buffer.
///
/// A caller maps an output flat index to output coordinates (row-major) and
/// then to an input flat index via `dot(coords, strides)`.
///
/// # Panics
///
/// - if `in_dims` has more axes than `out_dims`;
/// - if a non-1 input dim differs from the output dim it is aligned with;
/// - if a required stride does not fit in `u32` (previously this was
///   truncated silently, which could turn a real axis into a stride-0
///   "broadcast" axis).
fn broadcast_strides(in_dims: &[usize], out_dims: &[usize]) -> Vec<u32> {
    let r_out = out_dims.len();
    let r_in = in_dims.len();
    assert!(
        r_in <= r_out,
        "broadcast: input rank {r_in} > output rank {r_out}"
    );
    let pad = r_out - r_in;
    let mut strides = vec![0u32; r_out];
    // Running product of the non-broadcast input dims seen so far, i.e. the
    // row-major stride of the axis currently being visited.
    let mut acc: usize = 1;
    for d in (0..r_out).rev() {
        let in_size = if d < pad { 1 } else { in_dims[d - pad] };
        if in_size == 1 {
            // Broadcast axis: stride stays 0 and the running product is unchanged.
            continue;
        }
        assert_eq!(
            in_size, out_dims[d],
            "broadcast: input dim {in_size} doesn't match output dim {} at axis {d}",
            out_dims[d]
        );
        assert!(
            acc <= u32::MAX as usize,
            "broadcast: stride {acc} at axis {d} does not fit in u32"
        );
        strides[d] = acc as u32;
        // Saturate instead of wrapping so an oversized product can only ever
        // trip the range check above, never alias a small stride.
        acc = acc.saturating_mul(in_size);
    }
    strides
}
6921
6922/// Execute a thunk schedule on a raw arena buffer.
6923/// Fastest executor: call pre-compiled closures sequentially.
6924/// Zero match dispatch — each closure is a direct kernel call.
6925pub fn execute_compiled(schedule: &ThunkSchedule, arena_buf: &mut [u8]) {
6926    let base = arena_buf.as_mut_ptr();
6927    for f in &schedule.compiled_fns {
6928        f(base);
6929    }
6930}
6931
6932/// Active-extent execution stub. The runtime calls this when it has an
6933/// active-extent hint set. CPU doesn't implement per-thunk active-extent
6934/// scaling yet — return false so the caller falls back to the full
6935/// `execute_thunks` path.
6936pub fn execute_thunks_active(
6937    schedule: &ThunkSchedule,
6938    _arena_buf: &mut [u8],
6939    _actual: usize,
6940    _upper: usize,
6941) -> bool {
6942    let _ = schedule;
6943    false
6944}
6945
6946/// Match-based executor (fallback, used by tests).
6947struct MoeResidencyGuard;
6948impl Drop for MoeResidencyGuard {
6949    fn drop(&mut self) {
6950        if let Some(stats) = crate::moe_residency::take_stats() {
6951            crate::moe_residency::stash_last_forward_stats(stats);
6952        } else {
6953            crate::moe_residency::clear_mask();
6954        }
6955    }
6956}
6957
6958pub fn execute_thunks(schedule: &ThunkSchedule, arena_buf: &mut [u8]) {
6959    crate::moe_residency::reset_gmm_counters();
6960    if let Some(layers) = schedule.moe_resident_layers.clone() {
6961        crate::moe_residency::set_per_layer_masks(Some(layers));
6962    } else {
6963        crate::moe_residency::set_mask(schedule.moe_resident.clone());
6964    }
6965    if let Some(cap) = schedule.moe_topk_capture.as_ref() {
6966        cap.clear();
6967    }
6968    let _moe_guard = MoeResidencyGuard;
6969    let base = arena_buf.as_mut_ptr();
6970    let mask_thr = schedule.mask_threshold;
6971    let mask_neg = schedule.mask_neg_inf;
6972    let score_thr = schedule.score_skip;
6973    let thunks = &schedule.thunks;
6974    let len = thunks.len();
6975
6976    // Pre-allocate ALL reusable buffers once (zero per-call allocation)
6977    let max_h = thunks
6978        .iter()
6979        .filter_map(|t| match t {
6980            Thunk::FusedResidualLN { h, .. }
6981            | Thunk::FusedResidualRmsNorm { h, .. }
6982            | Thunk::LayerNorm { h, .. } => Some(*h as usize),
6983            _ => None,
6984        })
6985        .max()
6986        .unwrap_or(0);
6987    let zero_bias = vec![0f32; max_h];
6988
6989    // Pre-allocate per-(batch,head) score buffers for parallel SDPA.
6990    // Q/K/V/out are accessed via strided BLAS — no deinterleave copy needed.
6991    let max_sdpa = thunks
6992        .iter()
6993        .filter_map(|t| match t {
6994            Thunk::Attention {
6995                batch,
6996                seq,
6997                kv_seq,
6998                heads,
6999                head_dim,
7000                ..
7001            } => Some((
7002                *batch as usize,
7003                (*seq as usize).max(*kv_seq as usize),
7004                *heads as usize,
7005                *head_dim as usize,
7006            )),
7007            _ => None,
7008        })
7009        .fold((0, 0, 0, 0), |(mb, ms, mh, md), (b, s, h, d)| {
7010            (mb.max(b), ms.max(s), mh.max(h), md.max(d))
7011        });
7012    let (max_batch, max_seq, max_heads, _max_dh) = max_sdpa;
7013    let max_units = max_batch * max_heads;
7014    let mut sdpa_scores = vec![0f32; max_units * max_seq * max_seq];
7015
7016    // Pre-allocate fused layer buffers (reused across all 12+ layers — zero malloc per layer)
7017    let fl = thunks
7018        .iter()
7019        .filter_map(|t| match t {
7020            Thunk::FusedBertLayer {
7021                batch,
7022                seq,
7023                hs,
7024                int_dim,
7025                ..
7026            } => {
7027                let m = (*batch as usize) * (*seq as usize);
7028                let h = *hs as usize;
7029                let id = *int_dim as usize;
7030                Some((m, h, id, m * (*seq as usize)))
7031            }
7032            Thunk::FusedNomicLayer {
7033                batch,
7034                seq,
7035                hs,
7036                int_dim,
7037                ..
7038            } => {
7039                let m = (*batch as usize) * (*seq as usize);
7040                let h = *hs as usize;
7041                let id = *int_dim as usize;
7042                Some((m, h, id, m * (*seq as usize)))
7043            }
7044            _ => None,
7045        })
7046        .fold((0, 0, 0, 0), |(mm, mh, mi, ms), (m, h, id, ss)| {
7047            (mm.max(m), mh.max(h), mi.max(id), ms.max(ss))
7048        });
7049    let (fl_m, fl_h, fl_int, fl_ss) = fl;
7050    let mut fl_qkv = vec![0f32; fl_m * 3 * fl_h];
7051    let mut fl_attn = vec![0f32; fl_m * fl_h];
7052    let mut fl_res = vec![0f32; fl_m * fl_h];
7053    let mut fl_normed = vec![0f32; fl_m * fl_h];
7054    let mut fl_ffn = vec![0f32; fl_m * fl_int.max(2 * fl_int)]; // Nomic needs 2×int for fused fc11+fc12
7055    let mut fl_sc = vec![0f32; fl_ss.max(1)];
7056
7057    for i in 0..len {
7058        let thunk = unsafe { thunks.get_unchecked(i) };
7059        match thunk {
7060            Thunk::Nop => {}
7061
7062            Thunk::GaussianSplatRender {
7063                positions_off,
7064                positions_len,
7065                scales_off,
7066                scales_len,
7067                rotations_off,
7068                rotations_len,
7069                opacities_off,
7070                opacities_len,
7071                colors_off,
7072                colors_len,
7073                sh_coeffs_off,
7074                sh_coeffs_len,
7075                meta_off,
7076                dst_off,
7077                dst_len,
7078                width,
7079                height,
7080                tile_size,
7081                radius_scale,
7082                alpha_cutoff,
7083                max_splat_steps,
7084                transmittance_threshold,
7085                max_list_entries,
7086            } => unsafe {
7087                crate::splat::execute_gaussian_splat_render(
7088                    *positions_off,
7089                    *positions_len,
7090                    *scales_off,
7091                    *scales_len,
7092                    *rotations_off,
7093                    *rotations_len,
7094                    *opacities_off,
7095                    *opacities_len,
7096                    *colors_off,
7097                    *colors_len,
7098                    *sh_coeffs_off,
7099                    *sh_coeffs_len,
7100                    *meta_off,
7101                    *dst_off,
7102                    *dst_len,
7103                    *width,
7104                    *height,
7105                    *tile_size,
7106                    *radius_scale,
7107                    *alpha_cutoff,
7108                    *max_splat_steps,
7109                    *transmittance_threshold,
7110                    *max_list_entries,
7111                    base,
7112                );
7113            },
7114
7115            Thunk::GaussianSplatRenderBackward {
7116                positions_off,
7117                positions_len,
7118                scales_off,
7119                scales_len,
7120                rotations_off,
7121                rotations_len,
7122                opacities_off,
7123                opacities_len,
7124                colors_off,
7125                colors_len,
7126                sh_coeffs_off,
7127                sh_coeffs_len,
7128                meta_off,
7129                d_loss_off,
7130                d_loss_len,
7131                packed_off,
7132                packed_len,
7133                width,
7134                height,
7135                tile_size,
7136                radius_scale,
7137                alpha_cutoff,
7138                max_splat_steps,
7139                transmittance_threshold,
7140                max_list_entries,
7141                loss_grad_clip,
7142                sh_band,
7143                max_anisotropy,
7144            } => unsafe {
7145                crate::splat::execute_gaussian_splat_render_backward(
7146                    *positions_off,
7147                    *positions_len,
7148                    *scales_off,
7149                    *scales_len,
7150                    *rotations_off,
7151                    *rotations_len,
7152                    *opacities_off,
7153                    *opacities_len,
7154                    *colors_off,
7155                    *colors_len,
7156                    *sh_coeffs_off,
7157                    *sh_coeffs_len,
7158                    *meta_off,
7159                    *d_loss_off,
7160                    *d_loss_len,
7161                    *packed_off,
7162                    *packed_len,
7163                    *width,
7164                    *height,
7165                    *tile_size,
7166                    *radius_scale,
7167                    *alpha_cutoff,
7168                    *max_splat_steps,
7169                    *transmittance_threshold,
7170                    *max_list_entries,
7171                    *loss_grad_clip,
7172                    *sh_band,
7173                    *max_anisotropy,
7174                    base,
7175                );
7176            },
7177
7178            Thunk::GaussianSplatPrepare {
7179                positions_off,
7180                positions_len,
7181                scales_off,
7182                scales_len,
7183                rotations_off,
7184                rotations_len,
7185                opacities_off,
7186                opacities_len,
7187                colors_off,
7188                colors_len,
7189                sh_coeffs_off,
7190                sh_coeffs_len,
7191                meta_off,
7192                meta_len,
7193                prep_off,
7194                prep_len,
7195                width,
7196                height,
7197                tile_size,
7198                radius_scale,
7199                alpha_cutoff,
7200                max_splat_steps,
7201                transmittance_threshold,
7202                max_list_entries,
7203            } => unsafe {
7204                crate::splat::execute_gaussian_splat_prepare(
7205                    *positions_off,
7206                    *positions_len,
7207                    *scales_off,
7208                    *scales_len,
7209                    *rotations_off,
7210                    *rotations_len,
7211                    *opacities_off,
7212                    *opacities_len,
7213                    *colors_off,
7214                    *colors_len,
7215                    *sh_coeffs_off,
7216                    *sh_coeffs_len,
7217                    *meta_off,
7218                    *meta_len,
7219                    *prep_off,
7220                    *prep_len,
7221                    *width,
7222                    *height,
7223                    *tile_size,
7224                    *radius_scale,
7225                    *alpha_cutoff,
7226                    *max_splat_steps,
7227                    *transmittance_threshold,
7228                    *max_list_entries,
7229                    base,
7230                );
7231            },
7232
7233            Thunk::GaussianSplatRasterize {
7234                prep_off,
7235                prep_len,
7236                meta_off,
7237                meta_len,
7238                dst_off,
7239                dst_len,
7240                count,
7241                width,
7242                height,
7243                tile_size,
7244                alpha_cutoff,
7245                max_splat_steps,
7246                transmittance_threshold,
7247                max_list_entries,
7248            } => unsafe {
7249                crate::splat::execute_gaussian_splat_rasterize(
7250                    *prep_off,
7251                    *prep_len,
7252                    *meta_off,
7253                    *meta_len,
7254                    *dst_off,
7255                    *dst_len,
7256                    *count,
7257                    *width,
7258                    *height,
7259                    *tile_size,
7260                    *alpha_cutoff,
7261                    *max_splat_steps,
7262                    *transmittance_threshold,
7263                    *max_list_entries,
7264                    base,
7265                );
7266            },
7267
7268            Thunk::Fft1d {
7269                src,
7270                dst,
7271                outer,
7272                n_complex,
7273                inverse,
7274                norm_tag,
7275                dtype,
7276            } => unsafe {
7277                match dtype {
7278                    rlx_ir::DType::F64 => execute_fft1d_f64(
7279                        *src,
7280                        *dst,
7281                        *outer as usize,
7282                        *n_complex as usize,
7283                        *inverse,
7284                        *norm_tag,
7285                        base,
7286                    ),
7287                    rlx_ir::DType::F32 => execute_fft1d_f32(
7288                        *src,
7289                        *dst,
7290                        *outer as usize,
7291                        *n_complex as usize,
7292                        *inverse,
7293                        *norm_tag,
7294                        base,
7295                    ),
7296                    rlx_ir::DType::C64 => execute_fft1d_c64(
7297                        *src,
7298                        *dst,
7299                        *outer as usize,
7300                        *n_complex as usize,
7301                        *inverse,
7302                        *norm_tag,
7303                        base,
7304                    ),
7305                    other => panic!("Op::Fft on CPU requires F32/F64/C64, got {other:?}"),
7306                }
7307            },
7308
7309            // CustomFn dispatch (interpreted path). Mirrors the
7310            // pre-compiled-closure variant elsewhere in this file.
7311            // Patched by rlx-eda.
7312            Thunk::CustomFn {
7313                body,
7314                body_init,
7315                inputs,
7316                body_output_off,
7317                outer_output_off,
7318                out_bytes,
7319            } => {
7320                let mut body_buf: Vec<u8> = (**body_init).clone();
7321                unsafe {
7322                    for (body_in_off, outer_in_off, n_bytes) in inputs.iter() {
7323                        let src = (base as *const u8).add(*outer_in_off);
7324                        let dst = body_buf.as_mut_ptr().add(*body_in_off);
7325                        std::ptr::copy_nonoverlapping(src, dst, *n_bytes as usize);
7326                    }
7327                }
7328                execute_thunks(body, &mut body_buf);
7329                unsafe {
7330                    let src = body_buf.as_ptr().add(*body_output_off);
7331                    let dst = base.add(*outer_output_off);
7332                    std::ptr::copy_nonoverlapping(src, dst, *out_bytes as usize);
7333                }
7334            }
7335
7336            Thunk::Sgemm { a, b, c, m, k, n } => {
7337                let (m, k, n) = (*m as usize, *k as usize, *n as usize);
7338                unsafe {
7339                    crate::blas::sgemm_auto(
7340                        sl(*a, base, m * k),
7341                        sl(*b, base, k * n),
7342                        sl_mut(*c, base, m * n),
7343                        m,
7344                        k,
7345                        n,
7346                    );
7347                }
7348            }
7349
7350            Thunk::DenseSolveF64 { a, b, x, n, nrhs } => {
7351                let (n_, nrhs_) = (*n as usize, *nrhs as usize);
7352                // LAPACK overwrites both A and B; clone into scratch
7353                // each call. Caller's A and b must be preserved for
7354                // VJP recompute. (Eventually: swap to a factor-once /
7355                // solve-many scheme; that's the symbolic-reuse story
7356                // and lives with the sparse path.)
7357                unsafe {
7358                    let a_src = sl_f64(*a, base, n_ * n_);
7359                    let b_src = sl_f64(*b, base, n_ * nrhs_);
7360                    let mut a_scratch: Vec<f64> = a_src.to_vec();
7361                    let mut x_buf: Vec<f64> = b_src.to_vec();
7362                    let info = crate::blas::dgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
7363                    if info != 0 {
7364                        panic!(
7365                            "DenseSolveF64: dgesv reported singular matrix \
7366                                (info={info}, n={n_}, nrhs={nrhs_})"
7367                        );
7368                    }
7369                    let dst = sl_mut_f64(*x, base, n_ * nrhs_);
7370                    dst.copy_from_slice(&x_buf);
7371                }
7372            }
7373
7374            Thunk::DenseSolveF32 { a, b, x, n, nrhs } => {
7375                let (n_, nrhs_) = (*n as usize, *nrhs as usize);
7376                unsafe {
7377                    let a_src = sl(*a, base, n_ * n_);
7378                    let b_src = sl(*b, base, n_ * nrhs_);
7379                    let mut a_scratch: Vec<f32> = a_src.to_vec();
7380                    let mut x_buf: Vec<f32> = b_src.to_vec();
7381                    let info = crate::blas::sgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
7382                    if info != 0 {
7383                        panic!(
7384                            "DenseSolveF32: sgesv reported singular matrix \
7385                             (info={info}, n={n_}, nrhs={nrhs_})"
7386                        );
7387                    }
7388                    let dst = sl_mut(*x, base, n_ * nrhs_);
7389                    dst.copy_from_slice(&x_buf);
7390                }
7391            }
7392
7393            Thunk::BatchedDenseSolveF64 {
7394                a,
7395                b,
7396                x,
7397                batch,
7398                n,
7399                nrhs,
7400            } => {
7401                // Per slice: extract A_i and b_i, dgesv, write x_i.
7402                // LAPACK has no batched dgesv on Accelerate, so this
7403                // is a serial loop over the batch axis. cuSOLVER /
7404                // hipSOLVER expose `getrfBatched` / `getrsBatched` for
7405                // the GPU path — we'll wire that in rlx-cuda when
7406                // someone needs Linux+CUDA.
7407                let (b_, n_, nrhs_) = (*batch as usize, *n as usize, *nrhs as usize);
7408                let a_stride = n_ * n_;
7409                let b_stride = n_ * nrhs_;
7410                unsafe {
7411                    let a_full = sl_f64(*a, base, b_ * a_stride);
7412                    let b_full = sl_f64(*b, base, b_ * b_stride);
7413                    let x_full = sl_mut_f64(*x, base, b_ * b_stride);
7414                    for bi in 0..b_ {
7415                        let mut a_scratch: Vec<f64> =
7416                            a_full[bi * a_stride..(bi + 1) * a_stride].to_vec();
7417                        let mut x_buf: Vec<f64> =
7418                            b_full[bi * b_stride..(bi + 1) * b_stride].to_vec();
7419                        let info = crate::blas::dgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
7420                        if info != 0 {
7421                            panic!(
7422                                "BatchedDenseSolveF64: slice {bi} \
7423                                    singular (info={info}, n={n_}, nrhs={nrhs_})"
7424                            );
7425                        }
7426                        x_full[bi * b_stride..(bi + 1) * b_stride].copy_from_slice(&x_buf);
7427                    }
7428                }
7429            }
7430
7431            Thunk::BatchedDenseSolveF32 {
7432                a,
7433                b,
7434                x,
7435                batch,
7436                n,
7437                nrhs,
7438            } => {
7439                let (b_, n_, nrhs_) = (*batch as usize, *n as usize, *nrhs as usize);
7440                let a_stride = n_ * n_;
7441                let b_stride = n_ * nrhs_;
7442                unsafe {
7443                    let a_full = sl(*a, base, b_ * a_stride);
7444                    let b_full = sl(*b, base, b_ * b_stride);
7445                    let x_full = sl_mut(*x, base, b_ * b_stride);
7446                    for bi in 0..b_ {
7447                        let mut a_scratch = a_full[bi * a_stride..(bi + 1) * a_stride].to_vec();
7448                        let mut x_buf = b_full[bi * b_stride..(bi + 1) * b_stride].to_vec();
7449                        let info = crate::blas::sgesv(&mut a_scratch, &mut x_buf, n_, nrhs_);
7450                        if info != 0 {
7451                            panic!("BatchedDenseSolveF32: slice {bi} singular (info={info})");
7452                        }
7453                        x_full[bi * b_stride..(bi + 1) * b_stride].copy_from_slice(&x_buf);
7454                    }
7455                }
7456            }
7457
7458            Thunk::BatchedDgemmF64 {
7459                a,
7460                b,
7461                c,
7462                batch,
7463                m,
7464                k,
7465                n,
7466            } => {
7467                let (b_, m_, k_, n_) = (*batch as usize, *m as usize, *k as usize, *n as usize);
7468                let a_stride = m_ * k_;
7469                let b_stride = k_ * n_;
7470                let c_stride = m_ * n_;
7471                unsafe {
7472                    let a_full = sl_f64(*a, base, b_ * a_stride);
7473                    let b_full = sl_f64(*b, base, b_ * b_stride);
7474                    let c_full = sl_mut_f64(*c, base, b_ * c_stride);
7475                    for bi in 0..b_ {
7476                        let a_slice = &a_full[bi * a_stride..(bi + 1) * a_stride];
7477                        let b_slice = &b_full[bi * b_stride..(bi + 1) * b_stride];
7478                        let c_slice = &mut c_full[bi * c_stride..(bi + 1) * c_stride];
7479                        crate::blas::dgemm(a_slice, b_slice, c_slice, m_, k_, n_);
7480                    }
7481                }
7482            }
7483
7484            Thunk::BatchedSgemm {
7485                a,
7486                b,
7487                c,
7488                batch,
7489                m,
7490                k,
7491                n,
7492            } => {
7493                let (b_, m_, k_, n_) = (*batch as usize, *m as usize, *k as usize, *n as usize);
7494                let a_stride = m_ * k_;
7495                let b_stride = k_ * n_;
7496                let c_stride = m_ * n_;
7497                unsafe {
7498                    let a_full = sl(*a, base, b_ * a_stride);
7499                    let b_full = sl(*b, base, b_ * b_stride);
7500                    let c_full = sl_mut(*c, base, b_ * c_stride);
7501                    for bi in 0..b_ {
7502                        let a_slice = &a_full[bi * a_stride..(bi + 1) * a_stride];
7503                        let b_slice = &b_full[bi * b_stride..(bi + 1) * b_stride];
7504                        let c_slice = &mut c_full[bi * c_stride..(bi + 1) * c_stride];
7505                        crate::blas::sgemm_auto(a_slice, b_slice, c_slice, m_, k_, n_);
7506                    }
7507                }
7508            }
7509
7510            Thunk::Dgemm { a, b, c, m, k, n } => {
7511                let (m, k, n) = (*m as usize, *k as usize, *n as usize);
7512                unsafe {
7513                    crate::blas::dgemm(
7514                        sl_f64(*a, base, m * k),
7515                        sl_f64(*b, base, k * n),
7516                        sl_mut_f64(*c, base, m * n),
7517                        m,
7518                        k,
7519                        n,
7520                    );
7521                }
7522            }
7523
7524            Thunk::TransposeF64 {
7525                src,
7526                dst,
7527                in_total,
7528                out_dims,
7529                in_strides,
7530            } => unsafe {
7531                let inp = sl_f64(*src, base, *in_total as usize);
7532                let out_total: usize = out_dims.iter().map(|d| *d as usize).product();
7533                let out = sl_mut_f64(*dst, base, out_total);
7534                transpose_walk_f64(inp, out, out_dims, in_strides);
7535            },
7536
7537            Thunk::ActivationF64 {
7538                src,
7539                dst,
7540                len,
7541                kind,
7542            } => {
7543                let len = *len as usize;
7544                unsafe {
7545                    let inp = sl_f64(*src, base, len);
7546                    let out = sl_mut_f64(*dst, base, len);
7547                    apply_activation_f64(inp, out, *kind);
7548                }
7549            }
7550
7551            Thunk::ReduceSumF64 {
7552                src,
7553                dst,
7554                outer,
7555                reduced,
7556                inner,
7557            } => {
7558                let (o, r, n) = (*outer as usize, *reduced as usize, *inner as usize);
7559                unsafe {
7560                    let inp = sl_f64(*src, base, o * r * n);
7561                    let out = sl_mut_f64(*dst, base, o * n);
7562                    reduce_sum_f64(inp, out, o, r, n);
7563                }
7564            }
7565
7566            Thunk::CopyF64 { src, dst, len } => {
7567                let len = *len as usize;
7568                if *src == *dst { /* aliased, no copy needed */
7569                } else {
7570                    unsafe {
7571                        let s = sl_f64(*src, base, len);
7572                        let d = sl_mut_f64(*dst, base, len);
7573                        d.copy_from_slice(s);
7574                    }
7575                }
7576            }
7577
7578            Thunk::BinaryFullF64 {
7579                lhs,
7580                rhs,
7581                dst,
7582                len,
7583                lhs_len,
7584                rhs_len,
7585                op,
7586                out_dims_bcast,
7587                bcast_lhs_strides,
7588                bcast_rhs_strides,
7589            } => {
7590                let len = *len as usize;
7591                let lhs_len = *lhs_len as usize;
7592                let rhs_len = *rhs_len as usize;
7593                unsafe {
7594                    let l = sl_f64(*lhs, base, lhs_len);
7595                    let r = sl_f64(*rhs, base, rhs_len);
7596                    let d = sl_mut_f64(*dst, base, len);
7597                    if lhs_len == len && rhs_len == len {
7598                        for i in 0..len {
7599                            d[i] = binary_op_f64(*op, l[i], r[i]);
7600                        }
7601                    } else if !out_dims_bcast.is_empty() {
7602                        // Shape-aware broadcast path: correct for
7603                        // arbitrary NumPy-style broadcasts including
7604                        // bidirectional `[N,1] op [1,S]`.
7605                        let rank = out_dims_bcast.len();
7606                        let mut coords = vec![0u32; rank];
7607                        for i in 0..len {
7608                            let mut rem = i;
7609                            for ax in (0..rank).rev() {
7610                                let sz = out_dims_bcast[ax] as usize;
7611                                coords[ax] = (rem % sz) as u32;
7612                                rem /= sz;
7613                            }
7614                            let mut li: usize = 0;
7615                            let mut ri: usize = 0;
7616                            for ax in 0..rank {
7617                                li += coords[ax] as usize * bcast_lhs_strides[ax] as usize;
7618                                ri += coords[ax] as usize * bcast_rhs_strides[ax] as usize;
7619                            }
7620                            d[i] = binary_op_f64(*op, l[li], r[ri]);
7621                        }
7622                    } else {
7623                        // Fallback: legacy modulo path (preserved for
7624                        // dynamic-shape graphs where strides can't be
7625                        // precomputed). Only correct for scalar /
7626                        // last-axis broadcast.
7627                        for i in 0..len {
7628                            d[i] = binary_op_f64(*op, l[i % lhs_len], r[i % rhs_len]);
7629                        }
7630                    }
7631                }
7632            }
7633
7634            Thunk::BinaryFullC64 {
7635                lhs,
7636                rhs,
7637                dst,
7638                len,
7639                lhs_len,
7640                rhs_len,
7641                op,
7642                out_dims_bcast,
7643                bcast_lhs_strides,
7644                bcast_rhs_strides,
7645            } => {
7646                // Complex element layout: [re_0, im_0, re_1, im_1, ...]
7647                // Underlying f32 buffer length is 2·N (N = complex
7648                // element count). All offsets are byte offsets; the
7649                // `sl` helper reads as f32 starting at the byte
7650                // offset, so f32-length = 2·complex-len.
7651                let n_out = *len as usize;
7652                let n_l = *lhs_len as usize;
7653                let n_r = *rhs_len as usize;
7654                unsafe {
7655                    let l = sl(*lhs, base, 2 * n_l);
7656                    let r = sl(*rhs, base, 2 * n_r);
7657                    let d = sl_mut(*dst, base, 2 * n_out);
7658                    let do_c64 = |a_re: f32, a_im: f32, b_re: f32, b_im: f32| -> (f32, f32) {
7659                        match op {
7660                            BinaryOp::Add => (a_re + b_re, a_im + b_im),
7661                            BinaryOp::Sub => (a_re - b_re, a_im - b_im),
7662                            BinaryOp::Mul => (a_re * b_re - a_im * b_im, a_re * b_im + a_im * b_re),
7663                            BinaryOp::Div => {
7664                                let denom = b_re * b_re + b_im * b_im;
7665                                (
7666                                    (a_re * b_re + a_im * b_im) / denom,
7667                                    (a_im * b_re - a_re * b_im) / denom,
7668                                )
7669                            }
7670                            BinaryOp::Max | BinaryOp::Min | BinaryOp::Pow => {
7671                                unreachable!("C64 max/min/pow rejected at lowering")
7672                            }
7673                        }
7674                    };
7675                    if n_l == n_out && n_r == n_out {
7676                        for i in 0..n_out {
7677                            let (re, im) = do_c64(l[2 * i], l[2 * i + 1], r[2 * i], r[2 * i + 1]);
7678                            d[2 * i] = re;
7679                            d[2 * i + 1] = im;
7680                        }
7681                    } else if !out_dims_bcast.is_empty() {
7682                        // Strided complex broadcast: strides are in
7683                        // *complex element* units; multiply by 2 when
7684                        // indexing into the f32 buffer.
7685                        let rank = out_dims_bcast.len();
7686                        let mut coords = vec![0u32; rank];
7687                        for i in 0..n_out {
7688                            let mut rem = i;
7689                            for ax in (0..rank).rev() {
7690                                let sz = out_dims_bcast[ax] as usize;
7691                                coords[ax] = (rem % sz) as u32;
7692                                rem /= sz;
7693                            }
7694                            let mut li: usize = 0;
7695                            let mut ri: usize = 0;
7696                            for ax in 0..rank {
7697                                li += coords[ax] as usize * bcast_lhs_strides[ax] as usize;
7698                                ri += coords[ax] as usize * bcast_rhs_strides[ax] as usize;
7699                            }
7700                            let (re, im) =
7701                                do_c64(l[2 * li], l[2 * li + 1], r[2 * ri], r[2 * ri + 1]);
7702                            d[2 * i] = re;
7703                            d[2 * i + 1] = im;
7704                        }
7705                    } else {
7706                        // Modulo fallback (scalar / last-axis broadcast).
7707                        for i in 0..n_out {
7708                            let li = if n_l == 1 { 0 } else { i % n_l };
7709                            let ri = if n_r == 1 { 0 } else { i % n_r };
7710                            let (re, im) =
7711                                do_c64(l[2 * li], l[2 * li + 1], r[2 * ri], r[2 * ri + 1]);
7712                            d[2 * i] = re;
7713                            d[2 * i + 1] = im;
7714                        }
7715                    }
7716                }
7717            }
7718
7719            Thunk::ComplexNormSqF32 { src, dst, len } => {
7720                let n = *len as usize;
7721                unsafe {
7722                    let s = sl(*src, base, 2 * n);
7723                    let d = sl_mut(*dst, base, n);
7724                    for i in 0..n {
7725                        let re = s[2 * i];
7726                        let im = s[2 * i + 1];
7727                        d[i] = re * re + im * im;
7728                    }
7729                }
7730            }
7731
7732            Thunk::ComplexNormSqBackwardF32 { z, g, dz, len } => {
7733                // Wirtinger: dz = g · z, element-wise complex
7734                // (g is real, z is complex).
7735                let n = *len as usize;
7736                unsafe {
7737                    let zb = sl(*z, base, 2 * n);
7738                    let gb = sl(*g, base, n);
7739                    let db = sl_mut(*dz, base, 2 * n);
7740                    for i in 0..n {
7741                        let re = zb[2 * i];
7742                        let im = zb[2 * i + 1];
7743                        let gv = gb[i];
7744                        db[2 * i] = gv * re;
7745                        db[2 * i + 1] = gv * im;
7746                    }
7747                }
7748            }
7749
7750            Thunk::ConjugateC64 { src, dst, len } => {
7751                let n = *len as usize;
7752                unsafe {
7753                    let s = sl(*src, base, 2 * n);
7754                    let d = sl_mut(*dst, base, 2 * n);
7755                    for i in 0..n {
7756                        d[2 * i] = s[2 * i];
7757                        d[2 * i + 1] = -s[2 * i + 1];
7758                    }
7759                }
7760            }
7761
7762            Thunk::ActivationC64 {
7763                src,
7764                dst,
7765                len,
7766                kind,
7767            } => {
7768                let n = *len as usize;
7769                unsafe {
7770                    let s = sl(*src, base, 2 * n);
7771                    let d = sl_mut(*dst, base, 2 * n);
7772                    for i in 0..n {
7773                        let a = s[2 * i];
7774                        let b = s[2 * i + 1];
7775                        let (re, im) = match kind {
7776                            Activation::Neg => (-a, -b),
7777                            Activation::Exp => {
7778                                // exp(a + bi) = e^a · (cos b + i·sin b)
7779                                let ea = a.exp();
7780                                (ea * b.cos(), ea * b.sin())
7781                            }
7782                            Activation::Log => {
7783                                // log(z) = log|z| + i·arg(z), principal branch
7784                                let r = (a * a + b * b).sqrt();
7785                                (r.ln(), b.atan2(a))
7786                            }
7787                            Activation::Sqrt => {
7788                                // sqrt(a+bi) = sqrt((|z|+a)/2) + sign(b)·i·sqrt((|z|-a)/2)
7789                                // Principal branch; for b == 0 and a < 0 returns +i·sqrt(|a|).
7790                                let r = (a * a + b * b).sqrt();
7791                                let re = ((r + a) * 0.5).max(0.0).sqrt();
7792                                let im_mag = ((r - a) * 0.5).max(0.0).sqrt();
7793                                let im = if b >= 0.0 { im_mag } else { -im_mag };
7794                                (re, im)
7795                            }
7796                            _ => unreachable!("non-C64 activation kind survived lowering"),
7797                        };
7798                        d[2 * i] = re;
7799                        d[2 * i + 1] = im;
7800                    }
7801                }
7802            }
7803
7804            Thunk::Scan {
7805                body,
7806                body_init,
7807                body_input_off,
7808                body_output_off,
7809                outer_init_off,
7810                outer_final_off,
7811                length,
7812                carry_bytes,
7813                save_trajectory,
7814                xs_inputs,
7815                bcast_inputs,
7816                num_checkpoints,
7817            } => {
7818                let cb = *carry_bytes as usize;
7819                let n_steps = *length as usize;
7820                // Checkpoint mode: when 0 < K < length, save trajectory[k]
7821                // only when t == c_k = ceil((k+1) * length / K) - 1.
7822                // The last index c_{K-1} = length - 1 always.
7823                let k_total = if *num_checkpoints == 0 || *num_checkpoints == *length {
7824                    n_steps // save every step
7825                } else {
7826                    *num_checkpoints as usize
7827                };
7828                let checkpoint_t_for_k = |k: usize| -> usize {
7829                    if k_total == n_steps {
7830                        k
7831                    } else {
7832                        ((k + 1) * n_steps)
7833                            .div_ceil(k_total)
7834                            .saturating_sub(1)
7835                            .min(n_steps - 1)
7836                    }
7837                };
7838                let mut next_k = 0usize;
7839
7840                let mut body_buf: Vec<u8> = (**body_init).clone();
7841                unsafe {
7842                    std::ptr::copy_nonoverlapping(
7843                        base.add(*outer_init_off),
7844                        body_buf.as_mut_ptr().add(*body_input_off),
7845                        cb,
7846                    );
7847                    // Broadcast inputs: copy each one into the body's
7848                    // input slot ONCE. They aren't touched in the
7849                    // iteration loop below (in contrast to xs).
7850                    for (body_b_off, outer_b_off, total_bytes) in bcast_inputs.iter() {
7851                        std::ptr::copy_nonoverlapping(
7852                            base.add(*outer_b_off),
7853                            body_buf.as_mut_ptr().add(*body_b_off),
7854                            *total_bytes as usize,
7855                        );
7856                    }
7857                }
7858                for t in 0..n_steps {
7859                    for (body_x_off, outer_xs_off, per_step_bytes) in xs_inputs.iter() {
7860                        let psb = *per_step_bytes as usize;
7861                        unsafe {
7862                            std::ptr::copy_nonoverlapping(
7863                                base.add(*outer_xs_off + t * psb),
7864                                body_buf.as_mut_ptr().add(*body_x_off),
7865                                psb,
7866                            );
7867                        }
7868                    }
7869
7870                    execute_thunks(body, &mut body_buf);
7871
7872                    if *save_trajectory && next_k < k_total && t == checkpoint_t_for_k(next_k) {
7873                        unsafe {
7874                            std::ptr::copy_nonoverlapping(
7875                                body_buf.as_ptr().add(*body_output_off),
7876                                base.add(*outer_final_off + next_k * cb),
7877                                cb,
7878                            );
7879                        }
7880                        next_k += 1;
7881                    }
7882
7883                    if *body_output_off != *body_input_off {
7884                        body_buf
7885                            .copy_within(*body_output_off..*body_output_off + cb, *body_input_off);
7886                    }
7887                }
7888
7889                if !*save_trajectory {
7890                    // Single final-carry write.
7891                    unsafe {
7892                        std::ptr::copy_nonoverlapping(
7893                            body_buf.as_ptr().add(*body_output_off),
7894                            base.add(*outer_final_off),
7895                            cb,
7896                        );
7897                    }
7898                }
7899            }
7900
7901            Thunk::ScanBackward {
7902                body_vjp,
7903                body_init,
7904                body_carry_in_off,
7905                body_x_offs,
7906                body_d_output_off,
7907                body_dcarry_out_off,
7908                outer_init_off,
7909                outer_traj_off,
7910                outer_upstream_off,
7911                outer_xs_offs,
7912                outer_dinit_off,
7913                length,
7914                carry_bytes,
7915                save_trajectory,
7916                num_checkpoints,
7917                forward_body,
7918                forward_body_init,
7919                forward_body_carry_in_off,
7920                forward_body_output_off,
7921                forward_body_x_offs,
7922                carry_elem_size,
7923            } => {
7924                // Two backward paths share the same per-iteration body
7925                // (body_vjp run + dcarry threading). The "All" path
7926                // reads the carry directly from the saved trajectory
7927                // each step. The "Recursive checkpointing" path stores
7928                // only K saved checkpoints and reconstructs intermediate
7929                // carries via Griewank-style recursive subdivision —
7930                // see [`griewank_process_segment`]. Auxiliary memory
7931                // is `O(log(segment_size) · carry_bytes)` for the
7932                // recursion stack, vs the old segment-cache scheme's
7933                // `O(segment_size · carry_bytes)`. Total recompute work
7934                // grows from `O(length)` to `O(length · log)`, which
7935                // is the canonical Griewank trade.
7936                let cb = *carry_bytes as usize;
7937                let n_steps = *length as usize;
7938                let k_total = *num_checkpoints as usize;
7939                let is_recursive = k_total != 0 && k_total != n_steps;
7940                let checkpoint_t_for_k = |k: usize| -> usize {
7941                    ((k + 1) * n_steps)
7942                        .div_ceil(k_total)
7943                        .saturating_sub(1)
7944                        .min(n_steps - 1)
7945                };
7946
7947                let mut fwd_buf: Vec<u8> = if is_recursive {
7948                    (**forward_body_init.as_ref().unwrap()).clone()
7949                } else {
7950                    Vec::new()
7951                };
7952
7953                let mut dcarry: Vec<u8> = vec![0u8; cb];
7954                if !*save_trajectory {
7955                    unsafe {
7956                        std::ptr::copy_nonoverlapping(
7957                            base.add(*outer_upstream_off),
7958                            dcarry.as_mut_ptr(),
7959                            cb,
7960                        );
7961                    }
7962                }
7963
7964                let mut body_buf: Vec<u8> = (**body_init).clone();
7965
7966                // Per-iteration backward action — shared between the
7967                // direct-trajectory (All) and Griewank (Recursive) paths.
7968                // Both feed the same body_vjp run with carry-at-t,
7969                // x_t_i, and d_output, then thread dcarry backward.
7970                let process_iter =
7971                    |t: usize, carry_in: &[u8], dcarry: &mut Vec<u8>, body_buf: &mut Vec<u8>| {
7972                        if *save_trajectory {
7973                            unsafe {
7974                                let up_off = *outer_upstream_off + t * cb;
7975                                match *carry_elem_size {
7976                                    4 => {
7977                                        let up_ptr = base.add(up_off) as *const f32;
7978                                        let dc_ptr = dcarry.as_mut_ptr() as *mut f32;
7979                                        let n_elems = cb / 4;
7980                                        for i in 0..n_elems {
7981                                            *dc_ptr.add(i) += *up_ptr.add(i);
7982                                        }
7983                                    }
7984                                    8 => {
7985                                        let up_ptr = base.add(up_off) as *const f64;
7986                                        let dc_ptr = dcarry.as_mut_ptr() as *mut f64;
7987                                        let n_elems = cb / 8;
7988                                        for i in 0..n_elems {
7989                                            *dc_ptr.add(i) += *up_ptr.add(i);
7990                                        }
7991                                    }
7992                                    other => panic!(
7993                                        "ScanBackward: unsupported carry elem size {other} \
7994                                     (only f32/f64 carries are supported today)"
7995                                    ),
7996                                }
7997                            }
7998                        }
7999                        body_buf[*body_carry_in_off..*body_carry_in_off + cb]
8000                            .copy_from_slice(carry_in);
8001                        unsafe {
8002                            for (i, body_x_off) in body_x_offs.iter().enumerate() {
8003                                let (outer_xs_off, per_step_bytes) = outer_xs_offs[i];
8004                                let psb = per_step_bytes as usize;
8005                                std::ptr::copy_nonoverlapping(
8006                                    base.add(outer_xs_off + t * psb),
8007                                    body_buf.as_mut_ptr().add(*body_x_off),
8008                                    psb,
8009                                );
8010                            }
8011                            std::ptr::copy_nonoverlapping(
8012                                dcarry.as_ptr(),
8013                                body_buf.as_mut_ptr().add(*body_d_output_off),
8014                                cb,
8015                            );
8016                        }
8017                        execute_thunks(body_vjp, body_buf);
8018                        unsafe {
8019                            std::ptr::copy_nonoverlapping(
8020                                body_buf.as_ptr().add(*body_dcarry_out_off),
8021                                dcarry.as_mut_ptr(),
8022                                cb,
8023                            );
8024                        }
8025                    };
8026
8027                if is_recursive {
8028                    // Griewank treeverse path. Process saved-checkpoint
8029                    // segments from highest-t to lowest-t; within each,
8030                    // recursive binary subdivision via
8031                    // `griewank_process_segment`. Auxiliary memory:
8032                    // O(log(seg_size) · cb) for the recursion stack
8033                    // (vs O(seg_size · cb) for the older segment-cache
8034                    // scheme); recompute work: O(seg_size · log).
8035                    let leaf_threshold = 4usize;
8036                    let fb_sched = forward_body.as_ref().unwrap();
8037                    let fb_init = forward_body_init.as_ref().unwrap().as_slice();
8038                    let mut segment_end = n_steps - 1;
8039                    for seg_k in (0..k_total).rev() {
8040                        let segment_start = if seg_k == 0 {
8041                            0
8042                        } else {
8043                            checkpoint_t_for_k(seg_k - 1) + 1
8044                        };
8045                        let mut anchor: Vec<u8> = vec![0u8; cb];
8046                        unsafe {
8047                            let src = if seg_k == 0 {
8048                                base.add(*outer_init_off)
8049                            } else {
8050                                base.add(*outer_traj_off + (seg_k - 1) * cb)
8051                            };
8052                            std::ptr::copy_nonoverlapping(src, anchor.as_mut_ptr(), cb);
8053                        }
8054                        // Closure adapter for the helper's signature
8055                        // (mutably re-borrows dcarry / body_buf each call).
8056                        let mut leaf_action = |t: usize, carry_in: &[u8]| {
8057                            process_iter(t, carry_in, &mut dcarry, &mut body_buf);
8058                        };
8059                        unsafe {
8060                            griewank_process_segment(
8061                                segment_start,
8062                                segment_end,
8063                                &anchor,
8064                                cb,
8065                                fb_sched,
8066                                fb_init,
8067                                *forward_body_carry_in_off,
8068                                *forward_body_output_off,
8069                                forward_body_x_offs,
8070                                base,
8071                                outer_xs_offs,
8072                                &mut fwd_buf,
8073                                leaf_threshold,
8074                                &mut leaf_action,
8075                            );
8076                        }
8077                        if seg_k == 0 {
8078                            break;
8079                        }
8080                        segment_end = segment_start - 1;
8081                    }
8082                } else {
8083                    // All-trajectory path: read each carry directly
8084                    // from the saved trajectory buffer.
8085                    let mut carry_buf: Vec<u8> = vec![0u8; cb];
8086                    for t in (0..n_steps).rev() {
8087                        unsafe {
8088                            let src = if t == 0 {
8089                                base.add(*outer_init_off)
8090                            } else {
8091                                base.add(*outer_traj_off + (t - 1) * cb)
8092                            };
8093                            std::ptr::copy_nonoverlapping(src, carry_buf.as_mut_ptr(), cb);
8094                        }
8095                        process_iter(t, &carry_buf, &mut dcarry, &mut body_buf);
8096                    }
8097                }
8098
8099                unsafe {
8100                    std::ptr::copy_nonoverlapping(dcarry.as_ptr(), base.add(*outer_dinit_off), cb);
8101                }
8102            }
8103
8104            Thunk::ScanBackwardXs {
8105                body_vjp,
8106                body_init,
8107                body_carry_in_off,
8108                body_x_offs,
8109                body_d_output_off,
8110                body_dcarry_out_off,
8111                body_dxs_out_off,
8112                outer_init_off,
8113                outer_traj_off,
8114                outer_upstream_off,
8115                outer_xs_offs,
8116                outer_dxs_off,
8117                length,
8118                carry_bytes,
8119                carry_elem_size,
8120                per_step_bytes,
8121                save_trajectory,
8122                num_checkpoints,
8123                forward_body,
8124                forward_body_init,
8125                forward_body_carry_in_off,
8126                forward_body_output_off,
8127                forward_body_x_offs,
8128            } => {
                    // Backward pass of a scan that emits per-step input gradients.
                    // Walks t = length-1 .. 0; each step runs `body_vjp` on a private
                    // copy of `body_init` and copies the body's dxs output into row `t`
                    // of the arena region at `outer_dxs_off`. This arm does not write
                    // the final `dcarry` back into the arena.
8129                let cb = *carry_bytes as usize;
8130                let psb = *per_step_bytes as usize;
8131                let n_steps = *length as usize;
8132                let k_total = *num_checkpoints as usize;
                    // Recompute mode: checkpoints exist, but not one per step. In every
                    // other case the carry for step `t` is read straight from the arena.
8133                let is_recursive = k_total != 0 && k_total != n_steps;
                    // Step index of checkpoint `k`: ceil((k+1)·n_steps / k_total) − 1,
                    // clamped to the last step. Only invoked on the recompute path, where
                    // `k_total` is non-zero and at least one step exists.
8134                let checkpoint_t_for_k = |k: usize| -> usize {
8135                    ((k + 1) * n_steps)
8136                        .div_ceil(k_total)
8137                        .saturating_sub(1)
8138                        .min(n_steps - 1)
8139                };
8140
8141                // Forward-body recompute scratch + segment cache. With ≈√length
8142                // checkpoints, total recompute work is O(length).
8143                // NOTE(review): ScanBackward's recursive path uses Griewank treeverse, not this cache.
8144                let mut fwd_buf: Vec<u8> = if is_recursive {
8145                    (**forward_body_init.as_ref().unwrap()).clone()
8146                } else {
8147                    Vec::new()
8148                };
                    // Most recently recomputed segment: row `i` of `seg_cache` holds the
                    // carry entering step `seg_start_t + i`; `usize::MAX` means "empty".
8149                let mut seg_cache: Vec<u8> = Vec::new();
8150                let mut seg_start_t: usize = usize::MAX;
8151                let mut seg_count: usize = 0;
                    // Writes the carry that enters step `t` into `dst` (`cb` bytes). The
                    // mutable scratch state is passed in as arguments rather than captured.
8152                let recompute_carry_t =
8153                    |t: usize,
8154                     dst: &mut [u8],
8155                     fwd_buf: &mut Vec<u8>,
8156                     seg_cache: &mut Vec<u8>,
8157                     seg_start_t: &mut usize,
8158                     seg_count: &mut usize| {
                            // No recompute needed: the carry is the initial carry for
                            // t == 0, otherwise trajectory row `t - 1`.
8159                        if !is_recursive {
8160                            unsafe {
8161                                let src = if t == 0 {
8162                                    base.add(*outer_init_off)
8163                                } else {
8164                                    base.add(*outer_traj_off + (t - 1) * cb)
8165                                };
8166                                std::ptr::copy_nonoverlapping(src, dst.as_mut_ptr(), cb);
8167                            }
8168                            return;
8169                        }
                            // Cache hit: `t` falls inside the segment recomputed last time.
8170                        if *seg_start_t != usize::MAX
8171                            && t >= *seg_start_t
8172                            && t < *seg_start_t + *seg_count
8173                        {
8174                            let off = (t - *seg_start_t) * cb;
8175                            dst.copy_from_slice(&seg_cache[off..off + cb]);
8176                            return;
8177                        }
                            // Cache miss: pick the first checkpoint at or after `t`
                            // (falling back to the last one).
8178                        let seg_k = (0..k_total)
8179                            .find(|&k| t <= checkpoint_t_for_k(k))
8180                            .unwrap_or(k_total - 1);
                            // The segment is anchored on the initial carry (segment 0) or
                            // on the previous checkpoint's saved carry, stored at
                            // trajectory row `seg_k - 1`.
8181                        let (anchor_t, anchor_ptr): (usize, *const u8) = if seg_k == 0 {
8182                            (0, unsafe { base.add(*outer_init_off) as *const u8 })
8183                        } else {
8184                            let prev_ck = checkpoint_t_for_k(seg_k - 1);
8185                            (prev_ck + 1, unsafe {
8186                                base.add(*outer_traj_off + (seg_k - 1) * cb) as *const u8
8187                            })
8188                        };
8189                        let seg_end_t = checkpoint_t_for_k(seg_k);
8190                        let seg_size = seg_end_t - anchor_t + 1;
8191
                            // Reset the forward-body scratch and seed its carry input with
                            // the anchor carry, which also becomes cache row 0.
8192                        fwd_buf.copy_from_slice(forward_body_init.as_ref().unwrap());
8193                        unsafe {
8194                            std::ptr::copy_nonoverlapping(
8195                                anchor_ptr,
8196                                fwd_buf.as_mut_ptr().add(*forward_body_carry_in_off),
8197                                cb,
8198                            );
8199                        }
8200                        seg_cache.resize(seg_size * cb, 0u8);
8201                        seg_cache[0..cb].copy_from_slice(
8202                            &fwd_buf[*forward_body_carry_in_off..*forward_body_carry_in_off + cb],
8203                        );
8204                        let fb_sched = forward_body.as_ref().unwrap();
                            // Replay the forward body over the segment; iteration `i`
                            // executes step `anchor_t + i - 1` and caches the resulting carry.
8205                        for i in 1..seg_size {
8206                            let cur_iter = anchor_t + i - 1;
                                // Load this step's slice of every xs input into the scratch.
8207                            for (idx, fb_x_off) in forward_body_x_offs.iter().enumerate() {
8208                                let (outer_xs_off, x_psb) = outer_xs_offs[idx];
8209                                let xb = x_psb as usize;
8210                                unsafe {
8211                                    std::ptr::copy_nonoverlapping(
8212                                        base.add(outer_xs_off + cur_iter * xb),
8213                                        fwd_buf.as_mut_ptr().add(*fb_x_off),
8214                                        xb,
8215                                    );
8216                                }
8217                            }
8218                            execute_thunks(fb_sched, fwd_buf);
                                // Feed the body's output carry back as the next step's input.
8219                            if *forward_body_output_off != *forward_body_carry_in_off {
8220                                fwd_buf.copy_within(
8221                                    *forward_body_output_off..*forward_body_output_off + cb,
8222                                    *forward_body_carry_in_off,
8223                                );
8224                            }
8225                            let cache_off = i * cb;
8226                            seg_cache[cache_off..cache_off + cb].copy_from_slice(
8227                                &fwd_buf
8228                                    [*forward_body_carry_in_off..*forward_body_carry_in_off + cb],
8229                            );
8230                        }
8231                        *seg_start_t = anchor_t;
8232                        *seg_count = seg_size;
8233
8234                        let off = (t - anchor_t) * cb;
8235                        dst.copy_from_slice(&seg_cache[off..off + cb]);
8236                    };
8237
                    // Running carry gradient. When `save_trajectory` is false it is
                    // seeded from the single `cb`-byte block at `outer_upstream_off`;
                    // when true it starts at zero and upstream row `t` is added per step.
8238                let mut dcarry: Vec<u8> = vec![0u8; cb];
8239                if !*save_trajectory {
8240                    unsafe {
8241                        std::ptr::copy_nonoverlapping(
8242                            base.add(*outer_upstream_off),
8243                            dcarry.as_mut_ptr(),
8244                            cb,
8245                        );
8246                    }
8247                }
8248
                    // Private scratch arena for `body_vjp`, reused across all steps.
8249                let mut body_buf: Vec<u8> = (**body_init).clone();
8250
8251                for t in (0..n_steps).rev() {
8252                    if *save_trajectory {
                            // dcarry += upstream[t], element-wise in the carry's dtype.
                            // NOTE(review): `dcarry` is a `Vec<u8>` reinterpreted as
                            // f32/f64, so this relies on the allocation and `up_off`
                            // being suitably aligned — confirm.
8253                        unsafe {
8254                            let up_off = *outer_upstream_off + t * cb;
8255                            match *carry_elem_size {
8256                                4 => {
8257                                    let up_ptr = base.add(up_off) as *const f32;
8258                                    let dc_ptr = dcarry.as_mut_ptr() as *mut f32;
8259                                    let n_elems = cb / 4;
8260                                    for i in 0..n_elems {
8261                                        *dc_ptr.add(i) += *up_ptr.add(i);
8262                                    }
8263                                }
8264                                8 => {
8265                                    let up_ptr = base.add(up_off) as *const f64;
8266                                    let dc_ptr = dcarry.as_mut_ptr() as *mut f64;
8267                                    let n_elems = cb / 8;
8268                                    for i in 0..n_elems {
8269                                        *dc_ptr.add(i) += *up_ptr.add(i);
8270                                    }
8271                                }
8272                                other => panic!(
8273                                    "ScanBackwardXs: unsupported carry elem size {other} \
8274                                     (only f32/f64 carries are supported today)"
8275                                ),
8276                            }
8277                        }
8278                    }
8279
8280                    // Seed body_vjp's carry input via the recompute
8281                    // helper (works for both All and Recursive modes),
8282                    // then x_t_i + d_output.
8283                    let carry_dst_start = *body_carry_in_off;
8284                    {
8285                        let carry_slice = &mut body_buf[carry_dst_start..carry_dst_start + cb];
8286                        recompute_carry_t(
8287                            t,
8288                            carry_slice,
8289                            &mut fwd_buf,
8290                            &mut seg_cache,
8291                            &mut seg_start_t,
8292                            &mut seg_count,
8293                        );
8294                    }
8295                    unsafe {
                            // Step `t`'s slice of each xs input, at its own per-step width.
8296                        for (i, body_x_off) in body_x_offs.iter().enumerate() {
8297                            let (outer_xs_off, x_psb) = outer_xs_offs[i];
8298                            let xb = x_psb as usize;
8299                            std::ptr::copy_nonoverlapping(
8300                                base.add(outer_xs_off + t * xb),
8301                                body_buf.as_mut_ptr().add(*body_x_off),
8302                                xb,
8303                            );
8304                        }
                            // The current carry gradient is the body's d_output seed.
8305                        std::ptr::copy_nonoverlapping(
8306                            dcarry.as_ptr(),
8307                            body_buf.as_mut_ptr().add(*body_d_output_off),
8308                            cb,
8309                        );
8310                    }
8311
8312                    execute_thunks(body_vjp, &mut body_buf);
8313
8314                    // Stash this step's dxs into row `t` of the outer
8315                    // [length, *per_step_xs] output.
8316                    unsafe {
8317                        std::ptr::copy_nonoverlapping(
8318                            body_buf.as_ptr().add(*body_dxs_out_off),
8319                            base.add(*outer_dxs_off + t * psb),
8320                            psb,
8321                        );
8322                    }
8323
8324                    // Update dcarry for next backward iteration.
8325                    unsafe {
8326                        std::ptr::copy_nonoverlapping(
8327                            body_buf.as_ptr().add(*body_dcarry_out_off),
8328                            dcarry.as_mut_ptr(),
8329                            cb,
8330                        );
8331                    }
8332                }
8333            }
8334
8335            Thunk::FusedMmBiasAct {
8336                a,
8337                w,
8338                bias,
8339                c,
8340                m,
8341                k,
8342                n,
8343                act,
8344            } => {
                    // Fused dense layer: GEMM of A[m×k] by W[k×n] into `c`, followed by
                    // a bias add (bias has `n` elements) and an optional activation,
                    // both applied to the GEMM output in place.
8345                let (m, k, n) = (*m as usize, *k as usize, *n as usize);
8346                unsafe {
8347                    let out = sl_mut(*c, base, m * n);
8348                    crate::blas::sgemm_auto(sl(*a, base, m * k), sl(*w, base, k * n), out, m, k, n);
8349                    match act {
                            // GELU goes through a dedicated kernel that receives the bias
                            // (presumably fusing bias add and GELU — confirm in kernels).
8350                        Some(Activation::Gelu) => {
8351                            crate::kernels::par_bias_gelu(out, sl(*bias, base, n), m, n)
8352                        }
                            // Any other activation: add the bias first, then activate in place.
8353                        Some(other) => {
8354                            crate::blas::bias_add(out, sl(*bias, base, n), m, n);
8355                            apply_activation_inplace(out, *other);
8356                        }
                            // No activation requested: bias add only.
8357                        None => crate::blas::bias_add(out, sl(*bias, base, n), m, n),
8358                    }
8359                }
8360            }
8361
8362            Thunk::FusedResidualLN {
8363                x,
8364                res,
8365                bias,
8366                g,
8367                b,
8368                out,
8369                rows,
8370                h,
8371                eps,
8372                has_bias,
8373            } => {
                    // Fused residual + bias + layer norm over `rows` rows of width `h`.
                    // The math lives in `residual_bias_layer_norm`; this arm resolves
                    // the arena slices and fans the rows out over the worker pool.
8374                let (rows, h) = (*rows as usize, *h as usize);
8375                unsafe {
8376                    // Slice `zero_bias` only when it is used, so a real bias never depends on its length.
8377                    let bi = if *has_bias { sl(*bias, base, h) } else { &zero_bias[..h] };
                        // Addresses travel into the closure as plain integers —
                        // presumably to satisfy `par_for`'s closure bounds; confirm.
8378                    let x_ptr = sl(*x, base, rows * h).as_ptr() as usize;
8379                    let r_ptr = sl(*res, base, rows * h).as_ptr() as usize;
8380                    let o_ptr = sl_mut(*out, base, rows * h).as_mut_ptr() as usize;
8381                    let bi_ptr = bi.as_ptr() as usize;
8382                    let g_ptr = sl(*g, base, h).as_ptr() as usize;
8383                    let b_ptr = sl(*b, base, h).as_ptr() as usize;
8384                    let e = *eps;
                        // Each callback handles `cnt` consecutive rows starting at row
                        // `off`; x/res/out are re-sliced to that window, while bias,
                        // gamma and beta are shared `h`-wide vectors.
8385                    crate::pool::par_for(rows, 4, &|off, cnt| {
8386                        let xs =
8387                            std::slice::from_raw_parts((x_ptr as *const f32).add(off * h), cnt * h);
8388                        let rs =
8389                            std::slice::from_raw_parts((r_ptr as *const f32).add(off * h), cnt * h);
8390                        let os = std::slice::from_raw_parts_mut(
8391                            (o_ptr as *mut f32).add(off * h),
8392                            cnt * h,
8393                        );
8394                        let bi = std::slice::from_raw_parts(bi_ptr as *const f32, h);
8395                        let g = std::slice::from_raw_parts(g_ptr as *const f32, h);
8396                        let b = std::slice::from_raw_parts(b_ptr as *const f32, h);
8397                        crate::kernels::residual_bias_layer_norm(xs, rs, bi, g, b, os, cnt, h, e);
8398                    });
8399                }
8400            }
8401
8402            Thunk::FusedResidualRmsNorm {
8403                x,
8404                res,
8405                bias,
8406                g,
8407                b,
8408                out,
8409                rows,
8410                h,
8411                eps,
8412                has_bias,
8413            } => {
                    // Fused residual + bias + RMS norm over `rows` rows of width `h`.
                    // The math lives in `residual_bias_rms_norm`; this arm resolves the
                    // arena slices and fans the rows out over the worker pool.
8414                let (rows, h) = (*rows as usize, *h as usize);
8415                unsafe {
8416                    // Slice `zero_bias` only when it is used, so a real bias never depends on its length.
8417                    let bi = if *has_bias { sl(*bias, base, h) } else { &zero_bias[..h] };
                        // Addresses travel into the closure as plain integers —
                        // presumably to satisfy `par_for`'s closure bounds; confirm.
8418                    let x_ptr = sl(*x, base, rows * h).as_ptr() as usize;
8419                    let r_ptr = sl(*res, base, rows * h).as_ptr() as usize;
8420                    let o_ptr = sl_mut(*out, base, rows * h).as_mut_ptr() as usize;
8421                    let bi_ptr = bi.as_ptr() as usize;
8422                    let g_ptr = sl(*g, base, h).as_ptr() as usize;
8423                    let b_ptr = sl(*b, base, h).as_ptr() as usize;
8424                    let e = *eps;
                        // Each callback handles `cnt` consecutive rows starting at row
                        // `off`; x/res/out are re-sliced to that window, while bias,
                        // `g` and `b` are shared `h`-wide vectors.
8425                    crate::pool::par_for(rows, 4, &|off, cnt| {
8426                        let xs =
8427                            std::slice::from_raw_parts((x_ptr as *const f32).add(off * h), cnt * h);
8428                        let rs =
8429                            std::slice::from_raw_parts((r_ptr as *const f32).add(off * h), cnt * h);
8430                        let os = std::slice::from_raw_parts_mut(
8431                            (o_ptr as *mut f32).add(off * h),
8432                            cnt * h,
8433                        );
8434                        let bi = std::slice::from_raw_parts(bi_ptr as *const f32, h);
8435                        let g = std::slice::from_raw_parts(g_ptr as *const f32, h);
8436                        let b = std::slice::from_raw_parts(b_ptr as *const f32, h);
8437                        crate::kernels::residual_bias_rms_norm(xs, rs, bi, g, b, os, cnt, h, e);
8438                    });
8439                }
8440            }
8441
8442            Thunk::BiasAdd {
8443                src,
8444                bias,
8445                dst,
8446                m,
8447                n,
8448            } => {
                    // dst[m×n] = src[m×n] with the `n`-element bias added by
                    // `blas::bias_add`: the source is copied into `dst` first, then the
                    // bias is applied to `dst` in place.
                    // NOTE(review): `copy_from_slice` assumes `src` and `dst` do not
                    // overlap in the arena — confirm the planner guarantees this.
8449                let (m, n) = (*m as usize, *n as usize);
8450                unsafe {
8451                    let out = sl_mut(*dst, base, m * n);
8452                    out.copy_from_slice(sl(*src, base, m * n));
8453                    crate::blas::bias_add(out, sl(*bias, base, n), m, n);
8454                }
8455            }
8456
8457            Thunk::BinaryFull {
8458                lhs,
8459                rhs,
8460                dst,
8461                len,
8462                lhs_len,
8463                rhs_len,
8464                op,
8465                out_dims_bcast,
8466                bcast_lhs_strides,
8467                bcast_rhs_strides,
8468            } => {
                    // Element-wise binary op with broadcasting, producing `len` outputs.
                    // Three paths, tried in order:
                    //   1. aarch64 only: equal-length Add/Mul via NEON, then `continue`;
                    //   2. shape-aware broadcast when `out_dims_bcast` is non-empty;
                    //   3. legacy modulo indexing otherwise.
                    // On other targets (or other ops) the equal-length check is a no-op
                    // and execution falls through to paths 2/3.
8469                let len = *len as usize;
                    // Operand lengths are clamped to at least 1 element.
8470                let ll = (*lhs_len as usize).max(1);
8471                let rl = (*rhs_len as usize).max(1);
8472                unsafe {
8473                    let l = sl(*lhs, base, ll);
8474                    let r = sl(*rhs, base, rl);
8475                    let o = sl_mut(*dst, base, len);
8476                    // Fast path: shapes match exactly → NEON-vectorized loop.
8477                    if ll == len && rl == len {
8478                        #[cfg(target_arch = "aarch64")]
8479                        if matches!(op, BinaryOp::Add | BinaryOp::Mul) {
8480                            use std::arch::aarch64::*;
                                // Four f32 lanes per iteration; the scalar loop below
                                // handles the `len % 4` tail.
8481                            let chunks = len / 4;
8482                            for c in 0..chunks {
8483                                let off = c * 4;
8484                                let vl = vld1q_f32(l.as_ptr().add(off));
8485                                let vr = vld1q_f32(r.as_ptr().add(off));
8486                                let res = match op {
8487                                    BinaryOp::Add => vaddq_f32(vl, vr),
8488                                    BinaryOp::Mul => vmulq_f32(vl, vr),
8489                                    _ => unreachable!(),
8490                                };
8491                                vst1q_f32(o.as_mut_ptr().add(off), res);
8492                            }
8493                            for i in (chunks * 4)..len {
8494                                o[i] = match op {
8495                                    BinaryOp::Add => l[i] + r[i],
8496                                    BinaryOp::Mul => l[i] * r[i],
8497                                    _ => unreachable!(),
8498                                };
8499                            }
8500                            // `continue` to next thunk in the schedule — a
8501                            // bare `return` here used to exit execute_thunks
8502                            // entirely, silently dropping every thunk after
8503                            // the first BinaryFull (catastrophic for chained
8504                            // adds in BERT embedding stage).
8505                            continue;
8506                        }
8507                    }
8508                    if !out_dims_bcast.is_empty() {
8509                        // Shape-aware broadcast path: correct for
8510                        // bidirectional `[N,1] op [1,S]` etc.
8511                        let rank = out_dims_bcast.len();
8512                        let mut coords = vec![0u32; rank];
8513                        for i in 0..len {
                                // Unravel the flat output index into per-axis coordinates,
                                // last axis varying fastest.
8514                            let mut rem = i;
8515                            for ax in (0..rank).rev() {
8516                                let sz = out_dims_bcast[ax] as usize;
8517                                coords[ax] = (rem % sz) as u32;
8518                                rem /= sz;
8519                            }
                                // Dot the coordinates with each operand's strides to get
                                // its element index (presumably stride 0 on broadcast
                                // axes — set where the thunk is built; confirm).
8520                            let mut li: usize = 0;
8521                            let mut ri: usize = 0;
8522                            for ax in 0..rank {
8523                                li += coords[ax] as usize * bcast_lhs_strides[ax] as usize;
8524                                ri += coords[ax] as usize * bcast_rhs_strides[ax] as usize;
8525                            }
8526                            o[i] = match op {
8527                                BinaryOp::Add => l[li] + r[ri],
8528                                BinaryOp::Sub => l[li] - r[ri],
8529                                BinaryOp::Mul => l[li] * r[ri],
8530                                BinaryOp::Div => l[li] / r[ri],
8531                                BinaryOp::Max => l[li].max(r[ri]),
8532                                BinaryOp::Min => l[li].min(r[ri]),
8533                                BinaryOp::Pow => l[li].powf(r[ri]),
8534                            };
8535                        }
8536                    } else {
8537                        // Fallback: legacy modulo path (dynamic shapes only).
8538                        for i in 0..len {
                                // A length-1 operand is a scalar; anything else wraps
                                // around its own length.
8539                            let li = if ll == 1 { 0 } else { i % ll };
8540                            let ri = if rl == 1 { 0 } else { i % rl };
8541                            o[i] = match op {
8542                                BinaryOp::Add => l[li] + r[ri],
8543                                BinaryOp::Sub => l[li] - r[ri],
8544                                BinaryOp::Mul => l[li] * r[ri],
8545                                BinaryOp::Div => l[li] / r[ri],
8546                                BinaryOp::Max => l[li].max(r[ri]),
8547                                BinaryOp::Min => l[li].min(r[ri]),
8548                                BinaryOp::Pow => l[li].powf(r[ri]),
8549                            };
8550                        }
8551                    }
8552                }
8553            }
8554
8555            Thunk::Gather {
8556                table,
8557                table_len,
8558                idx,
8559                dst,
8560                num_idx,
8561                trailing,
8562            } => {
                    // Row gather: for each of the `num_idx` indices, copy one
                    // `trailing`-element row of the table into the output.
8563                let (ni, tr) = (*num_idx as usize, *trailing as usize);
8564                unsafe {
8565                    let tab = sl(*table, base, *table_len as usize);
8566                    let ids = sl(*idx, base, ni);
8567                    let out = sl_mut(*dst, base, ni * tr);
8568                    for i in 0..ni {
                            // Indices are read through `sl` and cast to usize (presumably
                            // stored as f32 like other `sl` buffers — confirm). A row past
                            // the end of `tab` panics on the range slice below.
8569                        let row = ids[i] as usize;
8570                        out[i * tr..(i + 1) * tr].copy_from_slice(&tab[row * tr..(row + 1) * tr]);
8571                    }
8572                }
8573            }
8574
8575            Thunk::Narrow {
8576                src,
8577                dst,
8578                outer,
8579                src_stride,
8580                dst_stride,
8581                inner,
8582                elem_bytes,
8583            } => {
                    // Delegates to `narrow_thunk_closure`: the arm's parameters are
                    // turned into a callable, which is then run once against the arena
                    // base pointer. The copy semantics live in that helper (not visible
                    // here). Note the callable is rebuilt on every execution.
8584                let f = narrow_thunk_closure(
8585                    *src,
8586                    *dst,
8587                    *outer,
8588                    *src_stride,
8589                    *dst_stride,
8590                    *inner,
8591                    *elem_bytes,
8592                );
8593                f(base);
8594            }
8595
8596            Thunk::Copy { src, dst, len } => {
                    // Copies `len` elements from the `src` slice to the `dst` slice.
                    // NOTE(review): `copy_from_slice` assumes the two ranges do not
                    // overlap — confirm the planner never emits an aliasing Copy.
8597                let len = *len as usize;
8598                unsafe {
8599                    let s = sl(*src, base, len);
8600                    let d = sl_mut(*dst, base, len);
8601                    d.copy_from_slice(s);
8602                }
8603            }
8604
8605            Thunk::LayerNorm {
8606                src,
8607                g,
8608                b,
8609                dst,
8610                rows,
8611                h,
8612                eps,
8613            } => {
                    // Row-wise layer norm: each of the `rows` rows of width `h` is
                    // passed to `layer_norm_row` with the shared `h`-wide gamma/beta.
                    // Large inputs are split across the worker pool; small ones run
                    // serially on the calling thread.
8614                let (rows, h) = (*rows as usize, *h as usize);
8615                unsafe {
8616                    let input = sl(*src, base, rows * h);
8617                    let gamma = sl(*g, base, h);
8618                    let beta = sl(*b, base, h);
8619                    let output = sl_mut(*dst, base, rows * h);
8620                    // Parallelize across rows (same pattern as FusedResidualLN)
                        // Threshold: at least 4 rows and 30 000 elements in total.
8621                    if rows >= 4 && rows * h >= 30_000 {
                            // Addresses travel into the closure as plain integers —
                            // presumably to satisfy `par_for`'s closure bounds; confirm.
8622                        let i_ptr = input.as_ptr() as usize;
8623                        let o_ptr = output.as_mut_ptr() as usize;
8624                        let g_ptr = gamma.as_ptr() as usize;
8625                        let b_ptr = beta.as_ptr() as usize;
8626                        let e = *eps;
                            // Each callback normalises `cnt` consecutive rows starting at
                            // row `off`.
8627                        crate::pool::par_for(rows, 4, &|off, cnt| {
8628                            let inp = std::slice::from_raw_parts(
8629                                (i_ptr as *const f32).add(off * h),
8630                                cnt * h,
8631                            );
8632                            let out = std::slice::from_raw_parts_mut(
8633                                (o_ptr as *mut f32).add(off * h),
8634                                cnt * h,
8635                            );
8636                            let g = std::slice::from_raw_parts(g_ptr as *const f32, h);
8637                            let b = std::slice::from_raw_parts(b_ptr as *const f32, h);
8638                            for row in 0..cnt {
8639                                crate::kernels::layer_norm_row(
8640                                    &inp[row * h..(row + 1) * h],
8641                                    g,
8642                                    b,
8643                                    &mut out[row * h..(row + 1) * h],
8644                                    h,
8645                                    e,
8646                                );
8647                            }
8648                        });
8649                    } else {
                            // Serial path for inputs below the threshold.
8650                        for row in 0..rows {
8651                            crate::kernels::layer_norm_row(
8652                                &input[row * h..(row + 1) * h],
8653                                gamma,
8654                                beta,
8655                                &mut output[row * h..(row + 1) * h],
8656                                h,
8657                                *eps,
8658                            );
8659                        }
8660                    }
8661                }
8662            }
8663
8664            Thunk::GroupNorm {
8665                src,
8666                g,
8667                b,
8668                dst,
8669                n,
8670                c,
8671                h,
8672                w,
8673                num_groups,
8674                eps,
8675            } => {
                    // Group norm over an [n, c, h, w] tensor. The kernel is invoked once
                    // per batch sample with a batch size of 1; gamma and beta are
                    // `c`-element vectors shared by every sample.
8676                let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
8677                let plane = c * h * w;
8678                unsafe {
8679                    for ni in 0..n {
                            // Address sample `ni` by element range within the full
                            // tensor. `base` is a byte pointer, so offsetting it by
                            // `ni * plane` would step in bytes rather than elements and
                            // mis-address every sample after the first.
                            let sample = ni * plane..(ni + 1) * plane;
8680                        let input = &sl(*src, base, n * plane)[sample.clone()];
8681                        let gamma = sl(*g, base, c);
8682                        let beta = sl(*b, base, c);
8683                        let output = &mut sl_mut(*dst, base, n * plane)[sample];
8684                        crate::kernels::group_norm_nchw(
8685                            input,
8686                            gamma,
8687                            beta,
8688                            output,
8689                            1,
8690                            c,
8691                            h,
8692                            w,
8693                            *num_groups as usize,
8694                            *eps,
8695                        );
8696                    }
8697                }
8698            }
8699
8700            Thunk::LayerNorm2d {
8701                src,
8702                g,
8703                b,
8704                dst,
8705                n,
8706                c,
8707                h,
8708                w,
8709                eps,
8710            } => {
8711                let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
8712                let plane = c * h * w;
8713                unsafe {
8714                    let input = sl(*src, base, n * plane);
8715                    let gamma = sl(*g, base, c);
8716                    let beta = sl(*b, base, c);
8717                    let output = sl_mut(*dst, base, n * plane);
8718                    crate::kernels::layer_norm2d_nchw(input, gamma, beta, output, n, c, h, w, *eps);
8719                }
8720            }
8721
8722            Thunk::ConvTranspose2d {
8723                src,
8724                weight,
8725                dst,
8726                n,
8727                c_in,
8728                h,
8729                w_in,
8730                c_out,
8731                h_out,
8732                w_out,
8733                kh,
8734                kw,
8735                sh,
8736                sw,
8737                ph,
8738                pw,
8739                dh,
8740                dw,
8741                groups,
8742            } => {
8743                let n = *n as usize;
8744                let c_in = *c_in as usize;
8745                let h = *h as usize;
8746                let w_in = *w_in as usize;
8747                let c_out = *c_out as usize;
8748                let h_out = *h_out as usize;
8749                let w_out = *w_out as usize;
8750                unsafe {
8751                    let inp = sl(*src, base, n * c_in * h * w_in);
8752                    let wt = sl(
8753                        *weight,
8754                        base,
8755                        c_in * (c_out / *groups as usize) * (*kh as usize) * (*kw as usize),
8756                    );
8757                    let out = sl_mut(*dst, base, n * c_out * h_out * w_out);
8758                    crate::kernels::conv_transpose2d_nchw(
8759                        inp,
8760                        wt,
8761                        out,
8762                        n,
8763                        c_in,
8764                        h,
8765                        w_in,
8766                        c_out,
8767                        h_out,
8768                        w_out,
8769                        *kh as usize,
8770                        *kw as usize,
8771                        *sh as usize,
8772                        *sw as usize,
8773                        *ph as usize,
8774                        *pw as usize,
8775                        *dh as usize,
8776                        *dw as usize,
8777                        *groups as usize,
8778                    );
8779                }
8780            }
8781
8782            Thunk::ResizeNearest2x {
8783                src,
8784                dst,
8785                n,
8786                c,
8787                h,
8788                w,
8789            } => {
8790                let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
8791                let in_plane = c * h * w;
8792                let out_plane = c * h * 2 * w * 2;
8793                unsafe {
8794                    for ni in 0..n {
8795                        let input = sl(*src, base.add(ni * in_plane), in_plane);
8796                        let output = sl_mut(*dst, base.add(ni * out_plane), out_plane);
8797                        crate::kernels::resize_nearest_2x_nchw(input, output, c, h, w);
8798                    }
8799                }
8800            }
8801
8802            Thunk::AxialRope2d {
8803                src,
8804                dst,
8805                batch,
8806                seq,
8807                hidden,
8808                end_x,
8809                end_y,
8810                head_dim,
8811                num_heads,
8812                theta,
8813                repeat_factor,
8814            } => {
8815                let b = *batch as usize;
8816                let s = *seq as usize;
8817                let hdim = *head_dim as usize;
8818                let nh = *num_heads as usize;
8819                let plane = s * (*hidden as usize);
8820                unsafe {
8821                    for bi in 0..b {
8822                        let input = sl(*src, base.add(bi * plane), plane);
8823                        let output = sl_mut(*dst, base.add(bi * plane), plane);
8824                        let rotated = rlx_ir::ops::axial_rope2d::apply_axial_rope2d(
8825                            input,
8826                            nh,
8827                            s,
8828                            hdim,
8829                            *end_x as usize,
8830                            *end_y as usize,
8831                            *theta,
8832                            *repeat_factor as usize,
8833                        );
8834                        output.copy_from_slice(&rotated);
8835                    }
8836                }
8837            }
8838
8839            Thunk::RmsNorm {
8840                src,
8841                g,
8842                b,
8843                dst,
8844                rows,
8845                h,
8846                eps,
8847            } => {
8848                let (rows, h) = (*rows as usize, *h as usize);
8849                unsafe {
8850                    let input = sl(*src, base, rows * h);
8851                    let gamma = sl(*g, base, h);
8852                    let beta = sl(*b, base, h);
8853                    let output = sl_mut(*dst, base, rows * h);
8854                    let inv_h = 1.0 / h as f32;
8855                    for row in 0..rows {
8856                        let in_row = &input[row * h..(row + 1) * h];
8857                        let out_row = &mut output[row * h..(row + 1) * h];
8858                        // RMS = sqrt(mean(x^2) + eps); scale = 1/RMS.
8859                        let mut sumsq = 0f32;
8860                        for &v in in_row {
8861                            sumsq += v * v;
8862                        }
8863                        let inv_rms = (sumsq * inv_h + *eps).sqrt().recip();
8864                        for i in 0..h {
8865                            out_row[i] = in_row[i] * inv_rms * gamma[i] + beta[i];
8866                        }
8867                    }
8868                }
8869            }
8870
8871            Thunk::Softmax { data, rows, cols } => {
8872                let (rows, cols) = (*rows as usize, *cols as usize);
8873                unsafe {
8874                    crate::kernels::neon_softmax(sl_mut(*data, base, rows * cols), rows, cols);
8875                }
8876            }
8877
8878            Thunk::Cumsum {
8879                src,
8880                dst,
8881                rows,
8882                cols,
8883                exclusive,
8884            } => {
8885                let (rows, cols) = (*rows as usize, *cols as usize);
8886                unsafe {
8887                    let s = sl(*src, base, rows * cols);
8888                    let d = sl_mut(*dst, base, rows * cols);
8889                    if *exclusive {
8890                        for r in 0..rows {
8891                            let mut acc = 0.0f32;
8892                            for c in 0..cols {
8893                                d[r * cols + c] = acc;
8894                                acc += s[r * cols + c];
8895                            }
8896                        }
8897                    } else {
8898                        for r in 0..rows {
8899                            let mut acc = 0.0f32;
8900                            for c in 0..cols {
8901                                acc += s[r * cols + c];
8902                                d[r * cols + c] = acc;
8903                            }
8904                        }
8905                    }
8906                }
8907            }
8908
8909            Thunk::Sample {
8910                logits,
8911                dst,
8912                batch,
8913                vocab,
8914                top_k,
8915                top_p,
8916                temperature,
8917                seed,
8918            } => {
8919                let (b, v) = (*batch as usize, *vocab as usize);
8920                let k = (*top_k as usize).min(v);
8921                unsafe {
8922                    let lg = sl(*logits, base, b * v);
8923                    let out = sl_mut(*dst, base, b);
8924                    let mut rng =
8925                        rlx_ir::Philox4x32::new(if *seed == 0 { 0xDEADBEEF } else { *seed });
8926                    for bi in 0..b {
8927                        let row = &lg[bi * v..(bi + 1) * v];
8928                        out[bi] = sample_row(row, k, *top_p, *temperature, &mut rng) as f32;
8929                    }
8930                }
8931            }
8932
8933            Thunk::GatedDeltaNet {
8934                q,
8935                k,
8936                v,
8937                g,
8938                beta,
8939                state,
8940                dst,
8941                batch,
8942                seq,
8943                heads,
8944                state_size,
8945            } => unsafe {
8946                execute_gated_delta_net_f32(
8947                    *q,
8948                    *k,
8949                    *v,
8950                    *g,
8951                    *beta,
8952                    *state,
8953                    *dst,
8954                    *batch as usize,
8955                    *seq as usize,
8956                    *heads as usize,
8957                    *state_size as usize,
8958                    base,
8959                );
8960            },
8961
8962            Thunk::SelectiveScan {
8963                x,
8964                delta,
8965                a,
8966                b: bp,
8967                c: cp,
8968                dst,
8969                batch,
8970                seq,
8971                hidden,
8972                state_size,
8973            } => {
8974                let (b, s, h, n) = (
8975                    *batch as usize,
8976                    *seq as usize,
8977                    *hidden as usize,
8978                    *state_size as usize,
8979                );
8980                unsafe {
8981                    let xs = sl(*x, base, b * s * h);
8982                    let dt = sl(*delta, base, b * s * h);
8983                    let am = sl(*a, base, h * n);
8984                    let bm = sl(*bp, base, b * s * n);
8985                    let cm = sl(*cp, base, b * s * n);
8986                    let out = sl_mut(*dst, base, b * s * h);
8987
8988                    // State buffer per-batch: h channels × n state.
8989                    // Sequential along the seq dimension; could
8990                    // parallelize over batch+channel later.
8991                    let mut state = vec![0f32; h * n];
8992                    for bi in 0..b {
8993                        // Reset state at the start of each batch row.
8994                        for v in state.iter_mut() {
8995                            *v = 0.0;
8996                        }
8997                        for si in 0..s {
8998                            let x_row = &xs[bi * s * h + si * h..bi * s * h + (si + 1) * h];
8999                            let dt_row = &dt[bi * s * h + si * h..bi * s * h + (si + 1) * h];
9000                            let b_row = &bm[bi * s * n + si * n..bi * s * n + (si + 1) * n];
9001                            let c_row = &cm[bi * s * n + si * n..bi * s * n + (si + 1) * n];
9002                            let out_row = &mut out[bi * s * h + si * h..bi * s * h + (si + 1) * h];
9003
9004                            for ci in 0..h {
9005                                let d = dt_row[ci];
9006                                let xv = x_row[ci];
9007                                let mut acc = 0f32;
9008                                for ni in 0..n {
9009                                    // Discretize: exp(d * a) and d * b.
9010                                    let da = (d * am[ci * n + ni]).exp();
9011                                    state[ci * n + ni] =
9012                                        da * state[ci * n + ni] + d * b_row[ni] * xv;
9013                                    acc += c_row[ni] * state[ci * n + ni];
9014                                }
9015                                out_row[ci] = acc;
9016                            }
9017                        }
9018                    }
9019                }
9020            }
9021
9022            Thunk::DequantMatMul {
9023                x,
9024                w_q,
9025                scale,
9026                zp,
9027                dst,
9028                m,
9029                k,
9030                n,
9031                block_size,
9032                is_asymmetric,
9033            } => {
9034                let (m, k, n, bs) = (*m as usize, *k as usize, *n as usize, *block_size as usize);
9035                let n_blocks = k.div_ceil(bs);
9036                unsafe {
9037                    let xs = sl(*x, base, m * k);
9038                    let w_bytes = std::slice::from_raw_parts(base.add(*w_q) as *const i8, k * n);
9039                    let scales = sl(*scale, base, n_blocks * n);
9040                    let zps = if *is_asymmetric {
9041                        sl(*zp, base, n_blocks * n)
9042                    } else {
9043                        &[][..]
9044                    };
9045                    let out = sl_mut(*dst, base, m * n);
9046                    dequant_matmul_int8(xs, w_bytes, scales, zps, out, m, k, n, bs, *is_asymmetric);
9047                }
9048            }
9049
9050            Thunk::DequantMatMulGguf {
9051                x,
9052                w_q,
9053                dst,
9054                m,
9055                k,
9056                n,
9057                scheme,
9058            } => {
9059                let (m, k, n) = (*m as usize, *k as usize, *n as usize);
9060                let block_bytes = scheme.gguf_block_bytes() as usize;
9061                let block_elems = scheme.gguf_block_size() as usize;
9062                debug_assert!(
9063                    block_bytes > 0 && block_elems > 0,
9064                    "non-GGUF scheme in GGUF arm"
9065                );
9066                debug_assert!(
9067                    (k * n).is_multiple_of(block_elems),
9068                    "k*n={} not aligned to GGUF block size {}",
9069                    k * n,
9070                    block_elems
9071                );
9072                let total_bytes = (k * n) / block_elems * block_bytes;
9073                unsafe {
9074                    let xs = sl(*x, base, m * k);
9075                    let w_bytes_ptr = base.add(*w_q) as *const u8;
9076                    let w_bytes = std::slice::from_raw_parts(w_bytes_ptr, total_bytes);
9077                    let out = sl_mut(*dst, base, m * n);
9078                    crate::gguf_matmul::gguf_matmul_bt(xs, w_bytes, out, m, k, n, *scheme);
9079                }
9080            }
9081
9082            Thunk::DequantMatMulInt4 {
9083                x,
9084                w_q,
9085                scale,
9086                zp,
9087                dst,
9088                m,
9089                k,
9090                n,
9091                block_size,
9092                is_asymmetric,
9093            } => {
9094                let (m, k, n, bs) = (*m as usize, *k as usize, *n as usize, *block_size as usize);
9095                let n_blocks = k.div_ceil(bs);
9096                unsafe {
9097                    let xs = sl(*x, base, m * k);
9098                    let w_bytes = std::slice::from_raw_parts(
9099                        base.add(*w_q) as *const u8,
9100                        (k * n).div_ceil(2),
9101                    );
9102                    let scales = sl(*scale, base, n_blocks * n);
9103                    let zps = if *is_asymmetric {
9104                        sl(*zp, base, n_blocks * n)
9105                    } else {
9106                        &[][..]
9107                    };
9108                    let out = sl_mut(*dst, base, m * n);
9109                    dequant_matmul_int4(xs, w_bytes, scales, zps, out, m, k, n, bs, *is_asymmetric);
9110                }
9111            }
9112
9113            Thunk::DequantMatMulFp8 {
9114                x,
9115                w_q,
9116                scale,
9117                dst,
9118                m,
9119                k,
9120                n,
9121                e5m2,
9122            } => {
9123                let (m, k, n) = (*m as usize, *k as usize, *n as usize);
9124                unsafe {
9125                    let xs = sl(*x, base, m * k);
9126                    let w_bytes = std::slice::from_raw_parts(base.add(*w_q) as *const u8, k * n);
9127                    let scales = sl(*scale, base, n);
9128                    let out = sl_mut(*dst, base, m * n);
9129                    dequant_matmul_fp8(xs, w_bytes, scales, out, m, k, n, *e5m2);
9130                }
9131            }
9132
9133            Thunk::DequantMatMulNvfp4 {
9134                x,
9135                w_q,
9136                scale,
9137                global_scale,
9138                dst,
9139                m,
9140                k,
9141                n,
9142            } => {
9143                let (m, k, n) = (*m as usize, *k as usize, *n as usize);
9144                let n_scale = k.div_ceil(rlx_ir::NVFP4_GROUP_SIZE) * n;
9145                unsafe {
9146                    let xs = sl(*x, base, m * k);
9147                    let w_bytes = std::slice::from_raw_parts(
9148                        base.add(*w_q) as *const u8,
9149                        (k * n).div_ceil(2),
9150                    );
9151                    let scale_bytes =
9152                        std::slice::from_raw_parts(base.add(*scale) as *const u8, n_scale);
9153                    let gs = sl(*global_scale, base, 1)[0];
9154                    let out = sl_mut(*dst, base, m * n);
9155                    dequant_matmul_nvfp4(xs, w_bytes, scale_bytes, gs, out, m, k, n);
9156                }
9157            }
9158
9159            Thunk::LoraMatMul {
9160                x,
9161                w,
9162                a,
9163                b,
9164                dst,
9165                m,
9166                k,
9167                n,
9168                r,
9169                scale,
9170            } => {
9171                let (m, k, n, r) = (*m as usize, *k as usize, *n as usize, *r as usize);
9172                unsafe {
9173                    let xs = sl(*x, base, m * k);
9174                    let ws = sl(*w, base, k * n);
9175                    let a_s = sl(*a, base, k * r);
9176                    let bs = sl(*b, base, r * n);
9177                    let out = sl_mut(*dst, base, m * n);
9178                    crate::blas::sgemm(xs, ws, out, m, k, n);
9179                    let mut tmp = vec![0f32; m * r];
9180                    crate::blas::sgemm(xs, a_s, &mut tmp, m, k, r);
9181                    if *scale != 1.0 {
9182                        for v in tmp.iter_mut() {
9183                            *v *= *scale;
9184                        }
9185                    }
9186                    crate::blas::sgemm_accumulate(&tmp, bs, out, m, r, n);
9187                }
9188            }
9189
9190            Thunk::Attention {
9191                q,
9192                k,
9193                v,
9194                mask,
9195                out,
9196                batch,
9197                seq,
9198                kv_seq,
9199                heads,
9200                head_dim,
9201                mask_kind,
9202                q_row_stride,
9203                k_row_stride,
9204                v_row_stride,
9205                bhsd,
9206            } => {
9207                let (b, q_s, k_s, nh, dh) = (
9208                    *batch as usize,
9209                    *seq as usize,
9210                    *kv_seq as usize,
9211                    *heads as usize,
9212                    *head_dim as usize,
9213                );
9214                let hs = nh * dh;
9215                // For [B, H, S, D] layout each (b, h) tile is dense
9216                // contiguous; the qrs/krs/vrs strides are not used.
9217                let (qrs, krs, vrs) = if *bhsd {
9218                    (dh, dh, dh)
9219                } else {
9220                    (
9221                        *q_row_stride as usize,
9222                        *k_row_stride as usize,
9223                        *v_row_stride as usize,
9224                    )
9225                };
9226                let bhsd = *bhsd;
9227                let _ = (q_row_stride, k_row_stride, v_row_stride);
9228                let scale = (dh as f32).powf(-0.5);
9229                let ss = q_s * k_s;
9230                let cfg = crate::config::RuntimeConfig::global();
9231                unsafe {
9232                    // Slice lengths cover the strided span. When Q/K/V
9233                    // alias the parent QKV (post-#46-fusion), the same
9234                    // bytes back all three slices — compiler bounds
9235                    // checks see the right size. For [B, H, S, D] the
9236                    // buffer is densely B*H*S*D elements; the row
9237                    // strides aren't used.
9238                    let q_len = if bhsd {
9239                        b * nh * q_s * dh
9240                    } else {
9241                        b * q_s * qrs
9242                    };
9243                    let k_len = if bhsd {
9244                        b * nh * k_s * dh
9245                    } else {
9246                        b * k_s * krs
9247                    };
9248                    let v_len = if bhsd {
9249                        b * nh * k_s * dh
9250                    } else {
9251                        b * k_s * vrs
9252                    };
9253                    let q_data = sl(*q, base, q_len);
9254                    let k_data = sl(*k, base, k_len);
9255                    let v_data = sl(*v, base, v_len);
9256                    let mask_data: &[f32] = match mask_kind {
9257                        rlx_ir::op::MaskKind::Custom => sl(*mask, base, b * k_s),
9258                        rlx_ir::op::MaskKind::Bias => sl(*mask, base, b * nh * q_s * k_s),
9259                        _ => &[],
9260                    };
9261                    let out_len = if bhsd {
9262                        b * nh * q_s * dh
9263                    } else {
9264                        b * q_s * hs
9265                    };
9266                    let out_data = sl_mut(*out, base, out_len);
9267
9268                    // ── [B, H, S, D] fallback ──────────────────────
9269                    // The NEON / strided-BLAS specializations below
9270                    // are written for the [B, S, H, D] layout. When
9271                    // the input is head-major ([B, H, S, D] —
9272                    // matching rlx-cuda / rlx-rocm / rlx-tpu), bypass
9273                    // them and run a simple (correct but slower)
9274                    // scalar implementation. Production-CPU inference
9275                    // graphs use [B, S, H, D] so they still hit the
9276                    // hot path; cross-backend parity tests use
9277                    // [B, H, S, D] and land here.
9278                    if bhsd {
9279                        let scores = &mut sdpa_scores[..ss];
9280                        for bi in 0..b {
9281                            for hi in 0..nh {
9282                                let q_head_base = bi * nh * q_s * dh + hi * q_s * dh;
9283                                let k_head_base = bi * nh * k_s * dh + hi * k_s * dh;
9284                                // Q@K^T
9285                                for qi in 0..q_s {
9286                                    let q_base = q_head_base + qi * dh;
9287                                    for ki in 0..k_s {
9288                                        let k_base = k_head_base + ki * dh;
9289                                        let mut dot = 0f32;
9290                                        for d in 0..dh {
9291                                            dot += q_data[q_base + d] * k_data[k_base + d];
9292                                        }
9293                                        scores[qi * k_s + ki] = dot * scale;
9294                                        if matches!(mask_kind, rlx_ir::op::MaskKind::Custom)
9295                                            && !mask_data.is_empty()
9296                                            && mask_data[bi * k_s + ki] < mask_thr
9297                                        {
9298                                            scores[qi * k_s + ki] = mask_neg;
9299                                        }
9300                                    }
9301                                }
9302                                if matches!(mask_kind, rlx_ir::op::MaskKind::Bias) {
9303                                    let off = (bi * nh + hi) * q_s * k_s;
9304                                    for i in 0..q_s * k_s {
9305                                        scores[i] += mask_data[off + i];
9306                                    }
9307                                }
9308                                apply_synthetic_mask(scores, q_s, k_s, *mask_kind);
9309                                crate::kernels::neon_softmax(scores, q_s, k_s);
9310                                // score @ V
9311                                for qi in 0..q_s {
9312                                    let o_base = q_head_base + qi * dh;
9313                                    for d in 0..dh {
9314                                        out_data[o_base + d] = 0.0;
9315                                    }
9316                                    for ki in 0..k_s {
9317                                        let sc = scores[qi * k_s + ki];
9318                                        if sc > score_thr {
9319                                            let v_base = k_head_base + ki * dh;
9320                                            for d in 0..dh {
9321                                                out_data[o_base + d] += sc * v_data[v_base + d];
9322                                            }
9323                                        }
9324                                    }
9325                                }
9326                            }
9327                        }
9328                        continue;
9329                    }
9330
9331                    // ── Auto-select kernel: NEON dots vs strided BLAS ───
9332                    // For tiny inputs (batch=1, short seq), per-head BLAS call
9333                    // overhead (~0.5µs × 2 calls × num_heads × num_layers)
9334                    // exceeds the NEON compute cost. Use direct strided NEON
9335                    // with zero dispatch overhead.
9336                    // For batch≥2: always BLAS + par_for (parallelism wins).
9337                    if b == 1 && q_s.max(k_s) <= cfg.sdpa_seq_threshold {
9338                        // ── Sequential NEON path (zero overhead) ──
9339                        let scores = &mut sdpa_scores[..ss];
9340                        #[cfg(target_arch = "aarch64")]
9341                        let neon_chunks = dh / 4;
9342
9343                        for bi in 0..b {
9344                            for hi in 0..nh {
9345                                // Q@K^T via strided NEON dot products
9346                                for qi in 0..q_s {
9347                                    let q_off = bi * q_s * qrs + qi * qrs + hi * dh;
9348                                    for ki in 0..k_s {
9349                                        let k_off = bi * k_s * krs + ki * krs + hi * dh;
9350                                        #[cfg(target_arch = "aarch64")]
9351                                        let mut dot;
9352                                        #[cfg(not(target_arch = "aarch64"))]
9353                                        let mut dot = 0f32;
9354                                        #[cfg(target_arch = "aarch64")]
9355                                        {
9356                                            use std::arch::aarch64::*;
9357                                            let mut acc = vdupq_n_f32(0.0);
9358                                            for c in 0..neon_chunks {
9359                                                let vq =
9360                                                    vld1q_f32(q_data.as_ptr().add(q_off + c * 4));
9361                                                let vk =
9362                                                    vld1q_f32(k_data.as_ptr().add(k_off + c * 4));
9363                                                acc = vfmaq_f32(acc, vq, vk);
9364                                            }
9365                                            dot = vaddvq_f32(acc);
9366                                            for d in (neon_chunks * 4)..dh {
9367                                                dot += q_data[q_off + d] * k_data[k_off + d];
9368                                            }
9369                                        }
9370                                        #[cfg(not(target_arch = "aarch64"))]
9371                                        for d in 0..dh {
9372                                            dot += q_data[q_off + d] * k_data[k_off + d];
9373                                        }
9374                                        scores[qi * k_s + ki] = dot * scale;
9375                                        // Inner-loop Custom mask check —
9376                                        // Causal / SlidingWindow / None
9377                                        // apply outside the loop below.
9378                                        // Skip for Bias — that mask is a
9379                                        // per-head additive tensor, not a
9380                                        // 0/1 key-padding mask.
9381                                        if matches!(mask_kind, rlx_ir::op::MaskKind::Custom)
9382                                            && !mask_data.is_empty()
9383                                            && mask_data[bi * k_s + ki] < mask_thr
9384                                        {
9385                                            scores[qi * k_s + ki] = mask_neg;
9386                                        }
9387                                    }
9388                                }
9389
9390                                if matches!(mask_kind, rlx_ir::op::MaskKind::Bias) {
9391                                    let off = (bi * nh + hi) * q_s * k_s;
9392                                    for i in 0..q_s * k_s {
9393                                        scores[i] += mask_data[off + i];
9394                                    }
9395                                }
9396                                apply_synthetic_mask(scores, q_s, k_s, *mask_kind);
9397                                crate::kernels::neon_softmax(scores, q_s, k_s);
9398
9399                                // Score@V via strided NEON accumulation (zero-copy)
9400                                for qi in 0..q_s {
9401                                    let o_off = bi * q_s * hs + qi * hs + hi * dh;
9402                                    // Zero output for this head position
9403                                    for d in 0..dh {
9404                                        out_data[o_off + d] = 0.0;
9405                                    }
9406                                    for ki in 0..k_s {
9407                                        let sc = scores[qi * k_s + ki];
9408                                        if sc > score_thr {
9409                                            let v_off = bi * k_s * vrs + ki * vrs + hi * dh;
9410                                            #[cfg(target_arch = "aarch64")]
9411                                            {
9412                                                use std::arch::aarch64::*;
9413                                                let vsc = vdupq_n_f32(sc);
9414                                                for c in 0..neon_chunks {
9415                                                    let off = c * 4;
9416                                                    let vo = vld1q_f32(
9417                                                        out_data.as_ptr().add(o_off + off),
9418                                                    );
9419                                                    let vv =
9420                                                        vld1q_f32(v_data.as_ptr().add(v_off + off));
9421                                                    vst1q_f32(
9422                                                        out_data.as_mut_ptr().add(o_off + off),
9423                                                        vfmaq_f32(vo, vsc, vv),
9424                                                    );
9425                                                }
9426                                            }
9427                                            #[cfg(not(target_arch = "aarch64"))]
9428                                            for d in 0..dh {
9429                                                out_data[o_off + d] += sc * v_data[v_off + d];
9430                                            }
9431                                        }
9432                                    }
9433                                }
9434                            }
9435                        }
9436                    } else {
9437                        // ── Parallel strided BLAS path (high throughput) ──
9438                        let total_work = b * nh;
9439                        let q_addr = q_data.as_ptr() as usize;
9440                        let k_addr = k_data.as_ptr() as usize;
9441                        let v_addr = v_data.as_ptr() as usize;
9442                        let m_addr = mask_data.as_ptr() as usize;
9443                        let o_addr = out_data.as_mut_ptr() as usize;
9444                        let sc_addr = sdpa_scores.as_mut_ptr() as usize;
9445
9446                        crate::pool::par_for(total_work, 1, &|off, cnt| {
9447                            for idx in off..off + cnt {
9448                                let bi = idx / nh;
9449                                let hi = idx % nh;
9450
9451                                let q_start = (q_addr as *const f32).add(bi * q_s * qrs + hi * dh);
9452                                let k_start = (k_addr as *const f32).add(bi * k_s * krs + hi * dh);
9453                                let v_start = (v_addr as *const f32).add(bi * k_s * vrs + hi * dh);
9454                                let o_start = (o_addr as *mut f32).add(bi * q_s * hs + hi * dh);
9455                                let sc = std::slice::from_raw_parts_mut(
9456                                    (sc_addr as *mut f32).add(idx * ss),
9457                                    ss,
9458                                );
9459
9460                                // LDA = qrs, LDB = krs (parent row strides
9461                                // when fused; hs otherwise).
9462                                crate::blas::sgemm_general(
9463                                    q_start,
9464                                    k_start,
9465                                    sc.as_mut_ptr(),
9466                                    q_s,
9467                                    k_s,
9468                                    dh,
9469                                    scale,
9470                                    0.0,
9471                                    qrs,
9472                                    krs,
9473                                    k_s,
9474                                    false,
9475                                    true,
9476                                );
9477
9478                                match mask_kind {
9479                                    rlx_ir::op::MaskKind::Custom => {
9480                                        let mask_bi = std::slice::from_raw_parts(
9481                                            (m_addr as *const f32).add(bi * k_s),
9482                                            k_s,
9483                                        );
9484                                        for ki in 0..k_s {
9485                                            if mask_bi[ki] < mask_thr {
9486                                                for qi in 0..q_s {
9487                                                    sc[qi * k_s + ki] = mask_neg;
9488                                                }
9489                                            }
9490                                        }
9491                                    }
9492                                    rlx_ir::op::MaskKind::Bias => {
9493                                        // Per-head additive bias slice.
9494                                        let bias = std::slice::from_raw_parts(
9495                                            (m_addr as *const f32).add((bi * nh + hi) * q_s * k_s),
9496                                            q_s * k_s,
9497                                        );
9498                                        for i in 0..q_s * k_s {
9499                                            sc[i] += bias[i];
9500                                        }
9501                                    }
9502                                    _ => apply_synthetic_mask(sc, q_s, k_s, *mask_kind),
9503                                }
9504
9505                                crate::kernels::neon_softmax(sc, q_s, k_s);
9506
9507                                // LDB = vrs (parent row stride when
9508                                // fused; hs otherwise). LDC stays hs —
9509                                // output is its own contiguous buffer.
9510                                crate::blas::sgemm_general(
9511                                    sc.as_ptr(),
9512                                    v_start,
9513                                    o_start,
9514                                    q_s,
9515                                    dh,
9516                                    k_s,
9517                                    1.0,
9518                                    0.0,
9519                                    k_s,
9520                                    vrs,
9521                                    hs,
9522                                    false,
9523                                    false,
9524                                );
9525                            }
9526                        });
9527                    }
9528                }
9529            }
9530
9531            Thunk::AttentionBackward {
9532                q,
9533                k,
9534                v,
9535                dy,
9536                mask,
9537                out,
9538                batch,
9539                seq,
9540                kv_seq,
9541                heads,
9542                head_dim,
9543                mask_kind,
9544                wrt,
9545                bhsd,
9546            } => {
9547                let (b, q_s, k_s, nh, dh) = (
9548                    *batch as usize,
9549                    *seq as usize,
9550                    *kv_seq as usize,
9551                    *heads as usize,
9552                    *head_dim as usize,
9553                );
9554                unsafe {
9555                    let q_len = if *bhsd {
9556                        b * nh * q_s * dh
9557                    } else {
9558                        b * q_s * nh * dh
9559                    };
9560                    let k_len = if *bhsd {
9561                        b * nh * k_s * dh
9562                    } else {
9563                        b * k_s * nh * dh
9564                    };
9565                    let out_len = match wrt {
9566                        rlx_ir::op::AttentionBwdWrt::Key | rlx_ir::op::AttentionBwdWrt::Value => {
9567                            k_len
9568                        }
9569                        rlx_ir::op::AttentionBwdWrt::Query => q_len,
9570                    };
9571                    let q_data = sl(*q, base, q_len);
9572                    let k_data = sl(*k, base, k_len);
9573                    let v_data = sl(*v, base, k_len);
9574                    let dy_data = sl(*dy, base, q_len);
9575                    let out_data = sl_mut(*out, base, out_len);
9576                    let mask_data: &[f32] = if *mask != 0 {
9577                        let ml = match mask_kind {
9578                            rlx_ir::op::MaskKind::Custom => b * k_s,
9579                            rlx_ir::op::MaskKind::Bias => b * nh * q_s * k_s,
9580                            _ => 0,
9581                        };
9582                        sl(*mask, base, ml)
9583                    } else {
9584                        &[]
9585                    };
9586                    crate::attention_bwd::attention_backward(
9587                        *wrt, q_data, k_data, v_data, dy_data, out_data, b, nh, q_s, k_s, dh,
9588                        *mask_kind, mask_data, *bhsd,
9589                    );
9590                }
9591            }
9592
9593            Thunk::ActivationInPlace { data, len, act } => {
9594                let len = *len as usize;
9595                unsafe {
9596                    let d = sl_mut(*data, base, len);
9597                    match act {
9598                        Activation::Gelu => crate::kernels::par_gelu_inplace(d),
9599                        Activation::GeluApprox => crate::kernels::par_gelu_approx_inplace(d),
9600                        Activation::Silu => crate::kernels::par_silu_inplace(d),
9601                        Activation::Relu => {
9602                            for v in d.iter_mut() {
9603                                *v = v.max(0.0);
9604                            }
9605                        }
9606                        Activation::Sigmoid => {
9607                            for v in d.iter_mut() {
9608                                *v = 1.0 / (1.0 + (-*v).exp());
9609                            }
9610                        }
9611                        Activation::Tanh => {
9612                            for v in d.iter_mut() {
9613                                *v = v.tanh();
9614                            }
9615                        }
9616                        Activation::Exp => {
9617                            for v in d.iter_mut() {
9618                                *v = v.exp();
9619                            }
9620                        }
9621                        Activation::Log => {
9622                            for v in d.iter_mut() {
9623                                *v = v.ln();
9624                            }
9625                        }
9626                        Activation::Sqrt => {
9627                            for v in d.iter_mut() {
9628                                *v = v.sqrt();
9629                            }
9630                        }
9631                        Activation::Rsqrt => {
9632                            for v in d.iter_mut() {
9633                                *v = 1.0 / v.sqrt();
9634                            }
9635                        }
9636                        Activation::Neg => {
9637                            for v in d.iter_mut() {
9638                                *v = -*v;
9639                            }
9640                        }
9641                        Activation::Abs => {
9642                            for v in d.iter_mut() {
9643                                *v = v.abs();
9644                            }
9645                        }
9646                        Activation::Round => {
9647                            for v in d.iter_mut() {
9648                                *v = v.round();
9649                            }
9650                        }
9651                        Activation::Sin => {
9652                            for v in d.iter_mut() {
9653                                *v = v.sin();
9654                            }
9655                        }
9656                        Activation::Cos => {
9657                            for v in d.iter_mut() {
9658                                *v = v.cos();
9659                            }
9660                        }
9661                        Activation::Tan => {
9662                            for v in d.iter_mut() {
9663                                *v = v.tan();
9664                            }
9665                        }
9666                        Activation::Atan => {
9667                            for v in d.iter_mut() {
9668                                *v = v.atan();
9669                            }
9670                        }
9671                    }
9672                }
9673            }
9674
9675            Thunk::FusedAttnBlock {
9676                hidden,
9677                qkv_w,
9678                out_w,
9679                mask,
9680                out,
9681                qkv_b,
9682                out_b,
9683                cos,
9684                sin,
9685                cos_len,
9686                batch,
9687                seq,
9688                hs,
9689                nh,
9690                dh,
9691                has_bias,
9692                has_rope,
9693            } => {
9694                let (b, s) = (*batch as usize, *seq as usize);
9695                let (h, n_h, d_h) = (*hs as usize, *nh as usize, *dh as usize);
9696                let m = b * s;
9697                let scale = (d_h as f32).powf(-0.5);
9698                let half = d_h / 2;
9699                unsafe {
9700                    let inp = sl(*hidden, base, m * h);
9701                    let wq = sl(*qkv_w, base, h * 3 * h);
9702                    let wo = sl(*out_w, base, h * h);
9703                    let mk = sl(*mask, base, b * s);
9704                    let dst = sl_mut(*out, base, m * h);
9705
9706                    // Stack-allocated intermediates — all fit in L1 cache for small batch
9707                    let mut qkv = vec![0f32; m * 3 * h];
9708                    let mut attn_out = vec![0f32; m * h];
9709                    let mut scores_buf = vec![0f32; s * s]; // one head at a time
9710
9711                    // 1. QKV projection: [m, h] @ [h, 3h] → [m, 3h]
9712                    crate::blas::sgemm(inp, wq, &mut qkv, m, h, 3 * h);
9713                    if *has_bias {
9714                        let bias = sl(*qkv_b, base, 3 * h);
9715                        crate::blas::bias_add(&mut qkv, bias, m, 3 * h);
9716                    }
9717
9718                    // 2. Multi-head SDPA (Q/K/V are views into qkv at offsets 0, h, 2h)
9719                    //    Process heads sequentially with inline RoPE — zero copy.
9720                    #[cfg(target_arch = "aarch64")]
9721                    let neon_chunks = d_h / 4;
9722                    #[cfg(target_arch = "aarch64")]
9723                    let _rope_chunks = half / 4;
9724
9725                    for bi in 0..b {
9726                        for hi in 0..n_h {
9727                            // For each (query_pos, key_pos): compute Q@K^T with inline RoPE
9728                            for qi in 0..s {
9729                                let q_base = bi * s * 3 * h + qi * 3 * h + hi * d_h;
9730                                for ki in 0..s {
9731                                    let k_base = bi * s * 3 * h + ki * 3 * h + h + hi * d_h;
9732                                    let mut dot = 0f32;
9733
9734                                    if *has_rope {
9735                                        // Apply RoPE inline during dot product
9736                                        let q_cos = qi * half;
9737                                        let k_cos = ki * half;
9738                                        let cos_tab = sl(*cos, base, *cos_len as usize);
9739                                        let sin_tab = sl(*sin, base, *cos_len as usize);
9740                                        // First half: (q1*c - q2*s) * (k1*c - k2*s)
9741                                        // Second half: (q2*c + q1*s) * (k2*c + k1*s)
9742                                        for i in 0..half {
9743                                            let q1 = qkv[q_base + i];
9744                                            let q2 = qkv[q_base + half + i];
9745                                            let k1 = qkv[k_base + i];
9746                                            let k2 = qkv[k_base + half + i];
9747                                            let c_q = cos_tab[q_cos + i];
9748                                            let s_q = sin_tab[q_cos + i];
9749                                            let c_k = cos_tab[k_cos + i];
9750                                            let s_k = sin_tab[k_cos + i];
9751                                            let qr1 = q1 * c_q - q2 * s_q;
9752                                            let kr1 = k1 * c_k - k2 * s_k;
9753                                            let qr2 = q2 * c_q + q1 * s_q;
9754                                            let kr2 = k2 * c_k + k1 * s_k;
9755                                            dot += qr1 * kr1 + qr2 * kr2;
9756                                        }
9757                                    } else {
9758                                        // Standard dot product
9759                                        #[cfg(target_arch = "aarch64")]
9760                                        {
9761                                            use std::arch::aarch64::*;
9762                                            let mut acc = vdupq_n_f32(0.0);
9763                                            for c in 0..neon_chunks {
9764                                                let vq =
9765                                                    vld1q_f32(qkv.as_ptr().add(q_base + c * 4));
9766                                                let vk =
9767                                                    vld1q_f32(qkv.as_ptr().add(k_base + c * 4));
9768                                                acc = vfmaq_f32(acc, vq, vk);
9769                                            }
9770                                            dot = vaddvq_f32(acc);
9771                                            for d in (neon_chunks * 4)..d_h {
9772                                                dot += qkv[q_base + d] * qkv[k_base + d];
9773                                            }
9774                                        }
9775                                        #[cfg(not(target_arch = "aarch64"))]
9776                                        for d in 0..d_h {
9777                                            dot += qkv[q_base + d] * qkv[k_base + d];
9778                                        }
9779                                    }
9780
9781                                    scores_buf[qi * s + ki] = dot * scale;
9782                                    if mk[bi * s + ki] < mask_thr {
9783                                        scores_buf[qi * s + ki] = mask_neg;
9784                                    }
9785                                }
9786                            }
9787
9788                            // Softmax
9789                            crate::kernels::neon_softmax(&mut scores_buf[..s * s], s, s);
9790
9791                            // Score @ V accumulation (V at offset 2h in QKV)
9792                            for qi in 0..s {
9793                                let o_base = bi * s * h + qi * h + hi * d_h;
9794                                for d in 0..d_h {
9795                                    attn_out[o_base + d] = 0.0;
9796                                }
9797                                for ki in 0..s {
9798                                    let sc = scores_buf[qi * s + ki];
9799                                    if sc > score_thr {
9800                                        let v_base = bi * s * 3 * h + ki * 3 * h + 2 * h + hi * d_h;
9801                                        #[cfg(target_arch = "aarch64")]
9802                                        {
9803                                            use std::arch::aarch64::*;
9804                                            let vsc = vdupq_n_f32(sc);
9805                                            for c in 0..neon_chunks {
9806                                                let off = c * 4;
9807                                                let vo =
9808                                                    vld1q_f32(attn_out.as_ptr().add(o_base + off));
9809                                                let vv = vld1q_f32(qkv.as_ptr().add(v_base + off));
9810                                                vst1q_f32(
9811                                                    attn_out.as_mut_ptr().add(o_base + off),
9812                                                    vfmaq_f32(vo, vsc, vv),
9813                                                );
9814                                            }
9815                                        }
9816                                        #[cfg(not(target_arch = "aarch64"))]
9817                                        for d in 0..d_h {
9818                                            attn_out[o_base + d] += sc * qkv[v_base + d];
9819                                        }
9820                                    }
9821                                }
9822                            }
9823                        }
9824                    }
9825
9826                    // 3. Output projection: [m, h] @ [h, h] → dst
9827                    crate::blas::sgemm(&attn_out, wo, dst, m, h, h);
9828                    if *has_bias {
9829                        let bias = sl(*out_b, base, h);
9830                        crate::blas::bias_add(dst, bias, m, h);
9831                    }
9832                }
9833            }
9834
9835            Thunk::Rope {
9836                src,
9837                cos,
9838                sin,
9839                dst,
9840                batch,
9841                seq,
9842                hidden,
9843                head_dim,
9844                n_rot,
9845                cos_len,
9846                src_row_stride,
9847            } => {
9848                let (b, s, hs, dh, nr) = (
9849                    *batch as usize,
9850                    *seq as usize,
9851                    *hidden as usize,
9852                    *head_dim as usize,
9853                    *n_rot as usize,
9854                );
9855                let tab_half = dh / 2;
9856                let rot_half = nr / 2;
9857                let nh = hs / dh;
9858                let cl = *cos_len as usize;
9859                let src_rs = *src_row_stride as usize;
9860                unsafe {
9861                    let x = sl(*src, base, b * s * src_rs);
9862                    let cos_tab = sl(*cos, base, cl);
9863                    let sin_tab = sl(*sin, base, cl);
9864                    let out = sl_mut(*dst, base, b * s * hs);
9865
9866                    let total = b * s;
9867                    let x_ptr = x.as_ptr() as usize;
9868                    let o_ptr = out.as_mut_ptr() as usize;
9869                    let c_ptr = cos_tab.as_ptr() as usize;
9870                    let s_ptr = sin_tab.as_ptr() as usize;
9871
9872                    crate::pool::par_for(total, 4, &|off, cnt| {
9873                        for idx in off..off + cnt {
9874                            let bi = idx / s;
9875                            let si = idx % s;
9876                            let tab_off = si * tab_half;
9877
9878                            for hi in 0..nh {
9879                                let src_base = bi * s * src_rs + si * src_rs + hi * dh;
9880                                let dst_base = bi * s * hs + si * hs + hi * dh;
9881                                let xp = (x_ptr as *const f32).add(src_base);
9882                                let op = (o_ptr as *mut f32).add(dst_base);
9883                                let cp = (c_ptr as *const f32).add(tab_off);
9884                                let sp = (s_ptr as *const f32).add(tab_off);
9885
9886                                for i in 0..rot_half {
9887                                    let x1 = *xp.add(i);
9888                                    let x2 = *xp.add(rot_half + i);
9889                                    let cv = *cp.add(i);
9890                                    let sv = *sp.add(i);
9891                                    *op.add(i) = x1 * cv - x2 * sv;
9892                                    *op.add(rot_half + i) = x2 * cv + x1 * sv;
9893                                }
9894                                for j in nr..dh {
9895                                    *op.add(j) = *xp.add(j);
9896                                }
9897                            }
9898                        }
9899                    });
9900                }
9901            }
9902            Thunk::FusedBertLayer {
9903                hidden,
9904                qkv_w,
9905                qkv_b,
9906                out_w,
9907                out_b,
9908                mask,
9909                ln1_g,
9910                ln1_b,
9911                eps1,
9912                fc1_w,
9913                fc1_b,
9914                fc2_w,
9915                fc2_b,
9916                ln2_g,
9917                ln2_b,
9918                eps2,
9919                out,
9920                batch,
9921                seq,
9922                hs,
9923                nh,
9924                dh,
9925                int_dim,
9926            } => {
9927                let (b, s, h, n_h, d_h) = (
9928                    *batch as usize,
9929                    *seq as usize,
9930                    *hs as usize,
9931                    *nh as usize,
9932                    *dh as usize,
9933                );
9934                let m = b * s;
9935                let id = *int_dim as usize;
9936                let scale = (d_h as f32).powf(-0.5);
9937                let _half = d_h / 2;
9938                #[cfg(target_arch = "aarch64")]
9939                let neon_chunks = d_h / 4;
9940                unsafe {
9941                    let inp = sl(*hidden, base, m * h);
9942                    let dst = sl_mut(*out, base, m * h);
9943                    let mk = sl(*mask, base, b * s);
9944
9945                    // Pre-allocated buffers (zero malloc per layer — allocated once before thunk loop)
9946                    let qkv = std::slice::from_raw_parts_mut(fl_qkv.as_mut_ptr(), m * 3 * h);
9947                    let attn = std::slice::from_raw_parts_mut(fl_attn.as_mut_ptr(), m * h);
9948                    let res = std::slice::from_raw_parts_mut(fl_res.as_mut_ptr(), m * h);
9949                    let normed = std::slice::from_raw_parts_mut(fl_normed.as_mut_ptr(), m * h);
9950                    let ffn = std::slice::from_raw_parts_mut(fl_ffn.as_mut_ptr(), m * id);
9951                    let sc = std::slice::from_raw_parts_mut(fl_sc.as_mut_ptr(), s * s);
9952
9953                    // QKV (parallelized across cores — multiple AMX coprocessors)
9954                    crate::blas::par_sgemm_bias(
9955                        inp,
9956                        sl(*qkv_w, base, h * 3 * h),
9957                        sl(*qkv_b, base, 3 * h),
9958                        qkv,
9959                        m,
9960                        h,
9961                        3 * h,
9962                    );
9963
9964                    // SDPA per head (sequential NEON, inline — zero overhead)
9965                    for bi in 0..b {
9966                        for hi in 0..n_h {
9967                            for qi in 0..s {
9968                                for ki in 0..s {
9969                                    let q_base = bi * s * 3 * h + qi * 3 * h + hi * d_h;
9970                                    let k_base = bi * s * 3 * h + ki * 3 * h + h + hi * d_h;
9971                                    #[cfg(target_arch = "aarch64")]
9972                                    let dot;
9973                                    #[cfg(not(target_arch = "aarch64"))]
9974                                    let mut dot = 0f32;
9975                                    #[cfg(target_arch = "aarch64")]
9976                                    {
9977                                        use std::arch::aarch64::*;
9978                                        let mut acc = vdupq_n_f32(0.0);
9979                                        for c in 0..neon_chunks {
9980                                            acc = vfmaq_f32(
9981                                                acc,
9982                                                vld1q_f32(qkv.as_ptr().add(q_base + c * 4)),
9983                                                vld1q_f32(qkv.as_ptr().add(k_base + c * 4)),
9984                                            );
9985                                        }
9986                                        dot = vaddvq_f32(acc);
9987                                    }
9988                                    #[cfg(not(target_arch = "aarch64"))]
9989                                    for d in 0..d_h {
9990                                        dot += qkv[q_base + d] * qkv[k_base + d];
9991                                    }
9992                                    sc[qi * s + ki] = dot * scale;
9993                                    if mk[bi * s + ki] < mask_thr {
9994                                        sc[qi * s + ki] = mask_neg;
9995                                    }
9996                                }
9997                            }
9998                            crate::kernels::neon_softmax(&mut sc[..s * s], s, s);
9999                            for qi in 0..s {
10000                                let o = bi * s * h + qi * h + hi * d_h;
10001                                for d in 0..d_h {
10002                                    attn[o + d] = 0.0;
10003                                }
10004                                for ki in 0..s {
10005                                    let w = sc[qi * s + ki];
10006                                    if w > score_thr {
10007                                        let v = bi * s * 3 * h + ki * 3 * h + 2 * h + hi * d_h;
10008                                        #[cfg(target_arch = "aarch64")]
10009                                        {
10010                                            use std::arch::aarch64::*;
10011                                            let vw = vdupq_n_f32(w);
10012                                            for c in 0..neon_chunks {
10013                                                let off = c * 4;
10014                                                vst1q_f32(
10015                                                    attn.as_mut_ptr().add(o + off),
10016                                                    vfmaq_f32(
10017                                                        vld1q_f32(attn.as_ptr().add(o + off)),
10018                                                        vw,
10019                                                        vld1q_f32(qkv.as_ptr().add(v + off)),
10020                                                    ),
10021                                                );
10022                                            }
10023                                        }
10024                                        #[cfg(not(target_arch = "aarch64"))]
10025                                        for d in 0..d_h {
10026                                            attn[o + d] += w * qkv[v + d];
10027                                        }
10028                                    }
10029                                }
10030                            }
10031                        }
10032                    }
10033
10034                    // Out proj (sgemm + bias fused) + residual add with NEON
10035                    crate::blas::sgemm_bias(
10036                        attn,
10037                        sl(*out_w, base, h * h),
10038                        sl(*out_b, base, h),
10039                        res,
10040                        m,
10041                        h,
10042                        h,
10043                    );
10044                    #[cfg(target_arch = "aarch64")]
10045                    {
10046                        use std::arch::aarch64::*;
10047                        let chunks_h = (m * h) / 4;
10048                        for c in 0..chunks_h {
10049                            let off = c * 4;
10050                            vst1q_f32(
10051                                res.as_mut_ptr().add(off),
10052                                vaddq_f32(
10053                                    vld1q_f32(res.as_ptr().add(off)),
10054                                    vld1q_f32(inp.as_ptr().add(off)),
10055                                ),
10056                            );
10057                        }
10058                        for i in (chunks_h * 4)..(m * h) {
10059                            res[i] += inp[i];
10060                        }
10061                    }
10062                    #[cfg(not(target_arch = "aarch64"))]
10063                    for i in 0..m * h {
10064                        res[i] += inp[i];
10065                    }
10066
10067                    // LN1 (fused residual already done above — just normalize)
10068                    let g1 = sl(*ln1_g, base, h);
10069                    let b1 = sl(*ln1_b, base, h);
10070                    for r in 0..m {
10071                        crate::kernels::layer_norm_row(
10072                            &res[r * h..(r + 1) * h],
10073                            g1,
10074                            b1,
10075                            &mut normed[r * h..(r + 1) * h],
10076                            h,
10077                            *eps1,
10078                        );
10079                    }
10080
10081                    // FFN: fc1 (parallel across cores) + GELU
10082                    crate::blas::par_sgemm_bias(
10083                        normed,
10084                        sl(*fc1_w, base, h * id),
10085                        sl(*fc1_b, base, id),
10086                        ffn,
10087                        m,
10088                        h,
10089                        id,
10090                    );
10091                    crate::kernels::par_gelu_inplace(ffn);
10092
10093                    // fc2 + bias (parallel across cores) + residual with NEON
10094                    crate::blas::par_sgemm_bias(
10095                        ffn,
10096                        sl(*fc2_w, base, id * h),
10097                        sl(*fc2_b, base, h),
10098                        res,
10099                        m,
10100                        id,
10101                        h,
10102                    );
10103                    #[cfg(target_arch = "aarch64")]
10104                    {
10105                        use std::arch::aarch64::*;
10106                        let chunks_h = (m * h) / 4;
10107                        for c in 0..chunks_h {
10108                            let off = c * 4;
10109                            vst1q_f32(
10110                                res.as_mut_ptr().add(off),
10111                                vaddq_f32(
10112                                    vld1q_f32(res.as_ptr().add(off)),
10113                                    vld1q_f32(normed.as_ptr().add(off)),
10114                                ),
10115                            );
10116                        }
10117                        for i in (chunks_h * 4)..(m * h) {
10118                            res[i] += normed[i];
10119                        }
10120                    }
10121                    #[cfg(not(target_arch = "aarch64"))]
10122                    for i in 0..m * h {
10123                        res[i] += normed[i];
10124                    }
10125
10126                    // LN2 → output
10127                    let g2 = sl(*ln2_g, base, h);
10128                    let b2 = sl(*ln2_b, base, h);
10129                    for r in 0..m {
10130                        crate::kernels::layer_norm_row(
10131                            &res[r * h..(r + 1) * h],
10132                            g2,
10133                            b2,
10134                            &mut dst[r * h..(r + 1) * h],
10135                            h,
10136                            *eps2,
10137                        );
10138                    }
10139                }
10140            }
10141
10142            Thunk::FusedNomicLayer {
10143                hidden,
10144                qkv_w,
10145                out_w,
10146                mask,
10147                cos,
10148                sin,
10149                cos_len,
10150                ln1_g,
10151                ln1_b,
10152                eps1,
10153                fc11_w,
10154                fc12_w: _,
10155                fc2_w,
10156                ln2_g,
10157                ln2_b,
10158                eps2,
10159                out,
10160                batch,
10161                seq,
10162                hs,
10163                nh,
10164                dh,
10165                int_dim,
10166            } => {
10167                let (b, s, h, n_h, d_h) = (
10168                    *batch as usize,
10169                    *seq as usize,
10170                    *hs as usize,
10171                    *nh as usize,
10172                    *dh as usize,
10173                );
10174                let m = b * s;
10175                let id = *int_dim as usize;
10176                let scale = (d_h as f32).powf(-0.5);
10177                let half_dh = d_h / 2;
10178                #[cfg(target_arch = "aarch64")]
10179                let neon_chunks = d_h / 4;
10180                unsafe {
10181                    let inp = sl(*hidden, base, m * h);
10182                    let dst = sl_mut(*out, base, m * h);
10183                    let mk = sl(*mask, base, b * s);
10184                    let cos_tab = sl(*cos, base, *cos_len as usize);
10185                    let sin_tab = sl(*sin, base, *cos_len as usize);
10186                    // fc11_w is the fused [h, 2*int_dim] weight (fc11 || fc12 concatenated)
10187                    let fused_fc_w = sl(*fc11_w, base, h * 2 * id);
10188
10189                    let mut qkv = vec![0f32; m * 3 * h];
10190                    let mut attn = vec![0f32; m * h];
10191                    let mut res = vec![0f32; m * h];
10192                    let mut normed = vec![0f32; m * h];
10193                    let mut ffn_concat = vec![0f32; m * 2 * id]; // fc11||fc12 output
10194                    let mut sc = vec![0f32; s * s];
10195
10196                    // QKV (no bias)
10197                    crate::blas::sgemm(inp, sl(*qkv_w, base, h * 3 * h), &mut qkv, m, h, 3 * h);
10198
10199                    // SDPA with inline RoPE
10200                    for bi in 0..b {
10201                        for hi in 0..n_h {
10202                            for qi in 0..s {
10203                                for ki in 0..s {
10204                                    let q_base = bi * s * 3 * h + qi * 3 * h + hi * d_h;
10205                                    let k_base = bi * s * 3 * h + ki * 3 * h + h + hi * d_h;
10206                                    let mut dot = 0f32;
10207                                    for i in 0..half_dh {
10208                                        let q1 = qkv[q_base + i];
10209                                        let q2 = qkv[q_base + half_dh + i];
10210                                        let k1 = qkv[k_base + i];
10211                                        let k2 = qkv[k_base + half_dh + i];
10212                                        let cq = cos_tab[qi * half_dh + i];
10213                                        let sq = sin_tab[qi * half_dh + i];
10214                                        let ck = cos_tab[ki * half_dh + i];
10215                                        let sk = sin_tab[ki * half_dh + i];
10216                                        dot += (q1 * cq - q2 * sq) * (k1 * ck - k2 * sk)
10217                                            + (q2 * cq + q1 * sq) * (k2 * ck + k1 * sk);
10218                                    }
10219                                    sc[qi * s + ki] = dot * scale;
10220                                    if mk[bi * s + ki] < mask_thr {
10221                                        sc[qi * s + ki] = mask_neg;
10222                                    }
10223                                }
10224                            }
10225                            crate::kernels::neon_softmax(&mut sc[..s * s], s, s);
10226                            for qi in 0..s {
10227                                let o = bi * s * h + qi * h + hi * d_h;
10228                                for d in 0..d_h {
10229                                    attn[o + d] = 0.0;
10230                                }
10231                                for ki in 0..s {
10232                                    let w = sc[qi * s + ki];
10233                                    if w > score_thr {
10234                                        let v = bi * s * 3 * h + ki * 3 * h + 2 * h + hi * d_h;
10235                                        #[cfg(target_arch = "aarch64")]
10236                                        {
10237                                            use std::arch::aarch64::*;
10238                                            let vw = vdupq_n_f32(w);
10239                                            for c in 0..neon_chunks {
10240                                                let off = c * 4;
10241                                                vst1q_f32(
10242                                                    attn.as_mut_ptr().add(o + off),
10243                                                    vfmaq_f32(
10244                                                        vld1q_f32(attn.as_ptr().add(o + off)),
10245                                                        vw,
10246                                                        vld1q_f32(qkv.as_ptr().add(v + off)),
10247                                                    ),
10248                                                );
10249                                            }
10250                                        }
10251                                        #[cfg(not(target_arch = "aarch64"))]
10252                                        for d in 0..d_h {
10253                                            attn[o + d] += w * qkv[v + d];
10254                                        }
10255                                    }
10256                                }
10257                            }
10258                        }
10259                    }
10260
10261                    // Out proj (no bias) + residual
10262                    crate::blas::sgemm(&attn, sl(*out_w, base, h * h), &mut res, m, h, h);
10263                    for i in 0..m * h {
10264                        res[i] += inp[i];
10265                    }
10266
10267                    // LN1
10268                    let g1 = sl(*ln1_g, base, h);
10269                    let b1 = sl(*ln1_b, base, h);
10270                    for r in 0..m {
10271                        crate::kernels::layer_norm_row(
10272                            &res[r * h..(r + 1) * h],
10273                            g1,
10274                            b1,
10275                            &mut normed[r * h..(r + 1) * h],
10276                            h,
10277                            *eps1,
10278                        );
10279                    }
10280
10281                    // SwiGLU: fused fc11+fc12 sgemm, then split, silu, mul
10282                    crate::blas::sgemm(&normed, fused_fc_w, &mut ffn_concat, m, h, 2 * id);
10283                    // Split: first id cols = fc11 (up), second id cols = fc12 (gate)
10284                    // SiLU on gate, then multiply up * gate → store in up region
10285                    for row in 0..m {
10286                        let bo = row * 2 * id;
10287                        // SiLU in-place on gate portion
10288                        for j in 0..id {
10289                            let x = ffn_concat[bo + id + j];
10290                            ffn_concat[bo + id + j] = x / (1.0 + (-x).exp());
10291                        }
10292                        // Multiply: up[j] *= gate[j]
10293                        for j in 0..id {
10294                            ffn_concat[bo + j] *= ffn_concat[bo + id + j];
10295                        }
10296                    }
10297
10298                    // fc2 (no bias) + residual  — read from first id cols of ffn_concat
10299                    // Need contiguous [m, id] for sgemm. Copy or use strided sgemm.
10300                    // The up*gate result is at ffn_concat[row * 2*id .. row * 2*id + id]
10301                    // Stride = 2*id. Use sgemm_general with lda = 2*id.
10302                    crate::blas::sgemm_general(
10303                        ffn_concat.as_ptr(),
10304                        sl(*fc2_w, base, id * h).as_ptr(),
10305                        res.as_mut_ptr(),
10306                        m,
10307                        h,
10308                        id,
10309                        1.0,
10310                        0.0,
10311                        2 * id,
10312                        h,
10313                        h,
10314                        false,
10315                        false,
10316                    );
10317                    for i in 0..m * h {
10318                        res[i] += normed[i];
10319                    }
10320
10321                    // LN2 → output
10322                    let g2 = sl(*ln2_g, base, h);
10323                    let b2 = sl(*ln2_b, base, h);
10324                    for r in 0..m {
10325                        crate::kernels::layer_norm_row(
10326                            &res[r * h..(r + 1) * h],
10327                            g2,
10328                            b2,
10329                            &mut dst[r * h..(r + 1) * h],
10330                            h,
10331                            *eps2,
10332                        );
10333                    }
10334                }
10335            }
10336
10337            Thunk::FusedSwiGLU {
10338                src,
10339                dst,
10340                n_half,
10341                total,
10342                gate_first,
10343            } => {
10344                let n = *n_half as usize;
10345                let t = *total as usize;
10346                let outer = t / n;
10347                let in_total = outer * 2 * n;
10348                let gate_first = *gate_first;
10349                unsafe {
10350                    let inp = sl(*src, base, in_total);
10351                    let out = sl_mut(*dst, base, t);
10352                    for o in 0..outer {
10353                        let in_row = &inp[o * 2 * n..(o + 1) * 2 * n];
10354                        let out_row = &mut out[o * n..(o + 1) * n];
10355                        for i in 0..n {
10356                            let (up, gate) = if gate_first {
10357                                (in_row[n + i], in_row[i])
10358                            } else {
10359                                (in_row[i], in_row[n + i])
10360                            };
10361                            out_row[i] = up * (gate / (1.0 + (-gate).exp()));
10362                        }
10363                    }
10364                }
10365            }
10366
10367            Thunk::Concat {
10368                dst,
10369                outer,
10370                inner,
10371                total_axis,
10372                inputs,
10373            } => {
10374                let outer = *outer as usize;
10375                let inner = *inner as usize;
10376                let total_axis = *total_axis as usize;
10377                let row_stride = total_axis * inner;
10378                let out_total = outer * row_stride;
10379                unsafe {
10380                    let out = sl_mut(*dst, base, out_total);
10381                    let mut cum: usize = 0;
10382                    for (src_off, in_axis) in inputs {
10383                        let in_axis = *in_axis as usize;
10384                        let copy_per_row = in_axis * inner;
10385                        let dst_col_off = cum * inner;
10386                        let in_total = outer * copy_per_row;
10387                        let inp = sl(*src_off, base, in_total);
10388                        for o in 0..outer {
10389                            let dst_row_start = o * row_stride + dst_col_off;
10390                            let src_row_start = o * copy_per_row;
10391                            out[dst_row_start..dst_row_start + copy_per_row]
10392                                .copy_from_slice(&inp[src_row_start..src_row_start + copy_per_row]);
10393                        }
10394                        cum += in_axis;
10395                    }
10396                }
10397            }
10398
10399            Thunk::ConcatF64 {
10400                dst,
10401                outer,
10402                inner,
10403                total_axis,
10404                inputs,
10405            } => {
10406                let outer = *outer as usize;
10407                let inner = *inner as usize;
10408                let total_axis = *total_axis as usize;
10409                let row_stride = total_axis * inner;
10410                let out_total = outer * row_stride;
10411                unsafe {
10412                    let out = sl_mut_f64(*dst, base, out_total);
10413                    let mut cum: usize = 0;
10414                    for (src_off, in_axis) in inputs {
10415                        let in_axis = *in_axis as usize;
10416                        let copy_per_row = in_axis * inner;
10417                        let dst_col_off = cum * inner;
10418                        let in_total = outer * copy_per_row;
10419                        let inp = sl_f64(*src_off, base, in_total);
10420                        for o in 0..outer {
10421                            let dst_row_start = o * row_stride + dst_col_off;
10422                            let src_row_start = o * copy_per_row;
10423                            out[dst_row_start..dst_row_start + copy_per_row]
10424                                .copy_from_slice(&inp[src_row_start..src_row_start + copy_per_row]);
10425                        }
10426                        cum += in_axis;
10427                    }
10428                }
10429            }
10430
10431            Thunk::Compare {
10432                lhs,
10433                rhs,
10434                dst,
10435                len,
10436                op,
10437            } => {
10438                let len = *len as usize;
10439                unsafe {
10440                    let l = sl(*lhs, base, len);
10441                    let r = sl(*rhs, base, len);
10442                    let o = sl_mut(*dst, base, len);
10443                    for i in 0..len {
10444                        o[i] = match op {
10445                            CmpOp::Eq => (l[i] == r[i]) as u32 as f32,
10446                            CmpOp::Ne => (l[i] != r[i]) as u32 as f32,
10447                            CmpOp::Lt => (l[i] < r[i]) as u32 as f32,
10448                            CmpOp::Le => (l[i] <= r[i]) as u32 as f32,
10449                            CmpOp::Gt => (l[i] > r[i]) as u32 as f32,
10450                            CmpOp::Ge => (l[i] >= r[i]) as u32 as f32,
10451                        };
10452                    }
10453                }
10454            }
10455
10456            Thunk::Where {
10457                cond,
10458                on_true,
10459                on_false,
10460                dst,
10461                len,
10462            } => {
10463                let len = *len as usize;
10464                unsafe {
10465                    let c = sl(*cond, base, len);
10466                    let t = sl(*on_true, base, len);
10467                    let e = sl(*on_false, base, len);
10468                    let o = sl_mut(*dst, base, len);
10469                    for i in 0..len {
10470                        // Treat cond as boolean: any non-zero → true.
10471                        o[i] = if c[i] != 0.0 { t[i] } else { e[i] };
10472                    }
10473                }
10474            }
10475
10476            Thunk::ScatterAdd {
10477                updates,
10478                indices,
10479                dst,
10480                num_updates,
10481                out_dim,
10482                trailing,
10483            } => {
10484                let num_updates = *num_updates as usize;
10485                let out_dim = *out_dim as usize;
10486                let trailing = *trailing as usize;
10487                unsafe {
10488                    let upd = sl(*updates, base, num_updates * trailing);
10489                    let ids = sl(*indices, base, num_updates);
10490                    let out = sl_mut(*dst, base, out_dim * trailing);
10491                    // Zero the output first — semantics are accumulate-into-zeros.
10492                    for v in out.iter_mut() {
10493                        *v = 0.0;
10494                    }
10495                    for i in 0..num_updates {
10496                        let row = ids[i] as usize;
10497                        debug_assert!(row < out_dim, "ScatterAdd index out of range");
10498                        let src_off = i * trailing;
10499                        let dst_off = row * trailing;
10500                        for j in 0..trailing {
10501                            out[dst_off + j] += upd[src_off + j];
10502                        }
10503                    }
10504                }
10505            }
10506
10507            Thunk::GroupedMatMul {
10508                input,
10509                weight,
10510                expert_idx,
10511                dst,
10512                m,
10513                k_dim,
10514                n,
10515                num_experts,
10516            } => {
10517                let m = *m as usize;
10518                let k_dim = *k_dim as usize;
10519                let n = *n as usize;
10520                let num_experts = *num_experts as usize;
10521                unsafe {
10522                    let inp = sl(*input, base, m * k_dim);
10523                    let wt = sl(*weight, base, num_experts * k_dim * n);
10524                    let ids = sl(*expert_idx, base, m);
10525                    let out = sl_mut(*dst, base, m * n);
10526
10527                    // Counting-sort tokens by their assigned expert.
10528                    // counts[e] = how many tokens routed to expert e.
10529                    let mut counts = vec![0usize; num_experts];
10530                    for i in 0..m {
10531                        let e = ids[i] as usize;
10532                        debug_assert!(
10533                            e < num_experts,
10534                            "expert_idx out of range: {e} >= {num_experts}"
10535                        );
10536                        counts[e] += 1;
10537                    }
10538                    // Cumulative offsets into the packed buffer.
10539                    let mut offsets = vec![0usize; num_experts + 1];
10540                    for e in 0..num_experts {
10541                        offsets[e + 1] = offsets[e] + counts[e];
10542                    }
10543                    // Pack: each expert's rows land contiguously in `packed_in`.
10544                    // `original_pos[packed_idx] = original_token_idx` for the
10545                    // unpermute step at the end.
10546                    let mut packed_in = vec![0f32; m * k_dim];
10547                    let mut original_pos = vec![0usize; m];
10548                    let mut write_idx = vec![0usize; num_experts];
10549                    for i in 0..m {
10550                        let e = ids[i] as usize;
10551                        let dst_row = offsets[e] + write_idx[e];
10552                        packed_in[dst_row * k_dim..(dst_row + 1) * k_dim]
10553                            .copy_from_slice(&inp[i * k_dim..(i + 1) * k_dim]);
10554                        original_pos[dst_row] = i;
10555                        write_idx[e] += 1;
10556                    }
10557
10558                    // One BLAS sgemm per expert. Skip experts with no
10559                    // tokens — common at the tail when M is much smaller
10560                    // than num_experts × k.
10561                    let mut packed_out = vec![0f32; m * n];
10562                    let expert_stride = k_dim * n;
10563                    let gmm_ord = crate::moe_residency::next_gmm_ord();
10564                    let moe_layer = gmm_ord / 3;
10565                    for e in 0..num_experts {
10566                        let count = counts[e];
10567                        if count == 0 {
10568                            continue;
10569                        }
10570                        crate::moe_residency::record_expert_tokens(moe_layer, e, count);
10571                        let in_start = offsets[e];
10572                        let in_slice = &packed_in[in_start * k_dim..(in_start + count) * k_dim];
10573                        let w_slab: &[f32] =
10574                            if !crate::moe_residency::expert_on_device_for_layer(moe_layer, e) {
10575                                if let Some(ptr) =
10576                                    crate::moe_residency::host_expert_weight_ptr(gmm_ord, e)
10577                                {
10578                                    std::slice::from_raw_parts(ptr, expert_stride)
10579                                } else {
10580                                    &wt[e * expert_stride..(e + 1) * expert_stride]
10581                                }
10582                            } else {
10583                                &wt[e * expert_stride..(e + 1) * expert_stride]
10584                            };
10585                        let out_slice = &mut packed_out[in_start * n..(in_start + count) * n];
10586                        crate::blas::sgemm(in_slice, w_slab, out_slice, count, k_dim, n);
10587                    }
10588
10589                    // Unpermute back to original token order.
10590                    for packed_idx in 0..m {
10591                        let i = original_pos[packed_idx];
10592                        out[i * n..(i + 1) * n]
10593                            .copy_from_slice(&packed_out[packed_idx * n..(packed_idx + 1) * n]);
10594                    }
10595                }
10596            }
10597
10598            Thunk::DequantGroupedMatMulGguf {
10599                input,
10600                w_q,
10601                expert_idx,
10602                dst,
10603                m,
10604                k_dim,
10605                n,
10606                num_experts,
10607                scheme,
10608            } => {
10609                let m = *m as usize;
10610                let k_dim = *k_dim as usize;
10611                let n = *n as usize;
10612                let num_experts = *num_experts as usize;
10613                let block_elems = scheme.gguf_block_size() as usize;
10614                let block_bytes = scheme.gguf_block_bytes() as usize;
10615                let slab_bytes = (k_dim * n) / block_elems * block_bytes;
10616                unsafe {
10617                    let inp = sl(*input, base, m * k_dim);
10618                    let wt = std::slice::from_raw_parts(
10619                        base.add(*w_q) as *const u8,
10620                        num_experts * slab_bytes,
10621                    );
10622                    let ids = sl(*expert_idx, base, m);
10623                    let out = sl_mut(*dst, base, m * n);
10624                    crate::gguf_matmul::gguf_grouped_matmul_bt(
10625                        inp,
10626                        wt,
10627                        ids,
10628                        out,
10629                        m,
10630                        k_dim,
10631                        n,
10632                        num_experts,
10633                        *scheme,
10634                    );
10635                }
10636            }
10637
10638            Thunk::DequantMoEWeightsGguf {
10639                w_q,
10640                dst,
10641                k_dim,
10642                n,
10643                num_experts,
10644                scheme,
10645            } => {
10646                let k_dim = *k_dim as usize;
10647                let n = *n as usize;
10648                let num_experts = *num_experts as usize;
10649                let block_elems = scheme.gguf_block_size() as usize;
10650                let block_bytes = scheme.gguf_block_bytes() as usize;
10651                let slab_bytes = (k_dim * n) / block_elems * block_bytes;
10652                unsafe {
10653                    let wt = std::slice::from_raw_parts(
10654                        base.add(*w_q) as *const u8,
10655                        num_experts * slab_bytes,
10656                    );
10657                    let out = sl_mut(*dst, base, num_experts * k_dim * n);
10658                    crate::gguf_matmul::dequant_moe_weights_to_grouped_f32(
10659                        wt,
10660                        out,
10661                        num_experts,
10662                        k_dim,
10663                        n,
10664                        *scheme,
10665                    );
10666                }
10667            }
10668
10669            Thunk::TopK {
10670                src,
10671                dst,
10672                outer,
10673                axis_dim,
10674                k,
10675            } => {
10676                let outer = *outer as usize;
10677                let axis_dim = *axis_dim as usize;
10678                let k = *k as usize;
10679                unsafe {
10680                    let inp = sl(*src, base, outer * axis_dim);
10681                    let out = sl_mut(*dst, base, outer * k);
10682                    // Repeated argmax with masking. O(k * axis_dim) per row;
10683                    // good enough for small k (MoE typical k=2–8). For larger
10684                    // k a partial heap would win.
10685                    let mut row_buf: Vec<f32> = vec![0.0; axis_dim];
10686                    for o in 0..outer {
10687                        row_buf.copy_from_slice(&inp[o * axis_dim..(o + 1) * axis_dim]);
10688                        for ki in 0..k {
10689                            // Find argmax with tie-break to smaller index.
10690                            let mut best_i = 0usize;
10691                            let mut best_v = row_buf[0];
10692                            for i in 1..axis_dim {
10693                                let v = row_buf[i];
10694                                if v > best_v {
10695                                    best_v = v;
10696                                    best_i = i;
10697                                }
10698                            }
10699                            out[o * k + ki] = best_i as f32;
10700                            // Mask the chosen index so the next pass picks
10701                            // the next-largest instead.
10702                            row_buf[best_i] = f32::NEG_INFINITY;
10703                        }
10704                    }
10705                    if let Some(cap) = schedule.moe_topk_capture.as_ref() {
10706                        cap.push_topk_f32(&out[..outer * k], axis_dim);
10707                    }
10708                }
10709            }
10710
10711            Thunk::Reduce {
10712                src,
10713                dst,
10714                outer,
10715                reduced,
10716                inner,
10717                op,
10718            } => {
10719                let outer = *outer as usize;
10720                let reduced = *reduced as usize;
10721                let inner = *inner as usize;
10722                let in_total = outer * reduced * inner;
10723                let out_total = outer * inner;
10724                unsafe {
10725                    let inp = sl(*src, base, in_total);
10726                    let out = sl_mut(*dst, base, out_total);
10727                    for o in 0..outer {
10728                        for i in 0..inner {
10729                            let mut acc = match op {
10730                                ReduceOp::Max => f32::NEG_INFINITY,
10731                                ReduceOp::Min => f32::INFINITY,
10732                                ReduceOp::Prod => 1.0f32,
10733                                _ => 0.0f32, // Sum / Mean
10734                            };
10735                            // Walk the reduced axis with stride `inner`.
10736                            for r in 0..reduced {
10737                                let v = inp[o * reduced * inner + r * inner + i];
10738                                acc = match op {
10739                                    ReduceOp::Sum | ReduceOp::Mean => acc + v,
10740                                    ReduceOp::Max => acc.max(v),
10741                                    ReduceOp::Min => acc.min(v),
10742                                    ReduceOp::Prod => acc * v,
10743                                };
10744                            }
10745                            if matches!(op, ReduceOp::Mean) {
10746                                acc /= reduced as f32;
10747                            }
10748                            out[o * inner + i] = acc;
10749                        }
10750                    }
10751                }
10752            }
10753
10754            Thunk::Conv2D1x1 {
10755                src,
10756                weight,
10757                dst,
10758                n,
10759                c_in,
10760                c_out,
10761                hw,
10762            } => {
10763                let n = *n as usize;
10764                let c_in = *c_in as usize;
10765                let c_out = *c_out as usize;
10766                let hw = *hw as usize;
10767                unsafe {
10768                    let inp = sl(*src, base, n * c_in * hw);
10769                    let wt = sl(*weight, base, c_out * c_in);
10770                    let out = sl_mut(*dst, base, n * c_out * hw);
10771                    // Per-batch sgemm: weight [c_out, c_in] @ input
10772                    // [c_in, hw] = output [c_out, hw]. The weight is
10773                    // shared across batches, so we get to dispatch
10774                    // BLAS once per N (typically 1).
10775                    for ni in 0..n {
10776                        let in_off = ni * c_in * hw;
10777                        let out_off = ni * c_out * hw;
10778                        crate::blas::sgemm(
10779                            wt,
10780                            &inp[in_off..in_off + c_in * hw],
10781                            &mut out[out_off..out_off + c_out * hw],
10782                            c_out,
10783                            c_in,
10784                            hw,
10785                        );
10786                    }
10787                }
10788            }
10789
10790            Thunk::Conv2D {
10791                src,
10792                weight,
10793                dst,
10794                n,
10795                c_in,
10796                h,
10797                w,
10798                c_out,
10799                h_out,
10800                w_out,
10801                kh,
10802                kw,
10803                sh,
10804                sw,
10805                ph,
10806                pw,
10807                dh,
10808                dw,
10809                groups,
10810            } => {
10811                let n = *n as usize;
10812                let c_in = *c_in as usize;
10813                let h = *h as usize;
10814                let w = *w as usize;
10815                let c_out = *c_out as usize;
10816                let h_out = *h_out as usize;
10817                let w_out = *w_out as usize;
10818                let kh = *kh as usize;
10819                let kw = *kw as usize;
10820                let sh = *sh as usize;
10821                let sw = *sw as usize;
10822                let ph = *ph as usize;
10823                let pw = *pw as usize;
10824                let dh = *dh as usize;
10825                let dw = *dw as usize;
10826                let groups = *groups as usize;
10827                let c_in_per_g = c_in / groups;
10828                let c_out_per_g = c_out / groups;
10829                unsafe {
10830                    let inp = sl(*src, base, n * c_in * h * w);
10831                    let wt = sl(*weight, base, c_out * c_in_per_g * kh * kw);
10832                    let out = sl_mut(*dst, base, n * c_out * h_out * w_out);
10833                    for ni in 0..n {
10834                        for co in 0..c_out {
10835                            let g = co / c_out_per_g;
10836                            let ci_start = g * c_in_per_g;
10837                            for ho in 0..h_out {
10838                                for wo in 0..w_out {
10839                                    let mut acc = 0f32;
10840                                    for ci_off in 0..c_in_per_g {
10841                                        let ci = ci_start + ci_off;
10842                                        let in_chan = ((ni * c_in) + ci) * h * w;
10843                                        let wt_chan = ((co * c_in_per_g) + ci_off) * kh * kw;
10844                                        for ki in 0..kh {
10845                                            for kj in 0..kw {
10846                                                let hi = ho * sh + ki * dh;
10847                                                let wi = wo * sw + kj * dw;
10848                                                if hi < ph || wi < pw {
10849                                                    continue;
10850                                                }
10851                                                let hi = hi - ph;
10852                                                let wi = wi - pw;
10853                                                if hi >= h || wi >= w {
10854                                                    continue;
10855                                                }
10856                                                acc += inp[in_chan + hi * w + wi]
10857                                                    * wt[wt_chan + ki * kw + kj];
10858                                            }
10859                                        }
10860                                    }
10861                                    out[((ni * c_out) + co) * h_out * w_out + ho * w_out + wo] =
10862                                        acc;
10863                                }
10864                            }
10865                        }
10866                    }
10867                }
10868            }
10869
10870            Thunk::Pool2D {
10871                src,
10872                dst,
10873                n,
10874                c,
10875                h,
10876                w,
10877                h_out,
10878                w_out,
10879                kh,
10880                kw,
10881                sh,
10882                sw,
10883                ph,
10884                pw,
10885                kind,
10886            } => {
10887                let n = *n as usize;
10888                let c = *c as usize;
10889                let h = *h as usize;
10890                let w = *w as usize;
10891                let h_out = *h_out as usize;
10892                let w_out = *w_out as usize;
10893                let kh = *kh as usize;
10894                let kw = *kw as usize;
10895                let sh = *sh as usize;
10896                let sw = *sw as usize;
10897                let ph = *ph as usize;
10898                let pw = *pw as usize;
10899                let kernel_area = (kh * kw) as f32;
10900                unsafe {
10901                    let inp = sl(*src, base, n * c * h * w);
10902                    let out = sl_mut(*dst, base, n * c * h_out * w_out);
10903                    for ni in 0..n {
10904                        for ci in 0..c {
10905                            let in_chan = ni * c * h * w + ci * h * w;
10906                            let out_chan = ni * c * h_out * w_out + ci * h_out * w_out;
10907                            for ho in 0..h_out {
10908                                for wo in 0..w_out {
10909                                    let mut acc = match kind {
10910                                        ReduceOp::Max => f32::NEG_INFINITY,
10911                                        _ => 0f32, // Mean (and Sum/Min/Prod fall back here)
10912                                    };
10913                                    for ki in 0..kh {
10914                                        for kj in 0..kw {
10915                                            let hi = ho * sh + ki;
10916                                            let wi = wo * sw + kj;
10917                                            // Padded-zero region.
10918                                            if hi < ph || wi < pw {
10919                                                continue;
10920                                            }
10921                                            let hi = hi - ph;
10922                                            let wi = wi - pw;
10923                                            if hi >= h || wi >= w {
10924                                                continue;
10925                                            }
10926                                            let v = inp[in_chan + hi * w + wi];
10927                                            match kind {
10928                                                ReduceOp::Max => acc = acc.max(v),
10929                                                _ => acc += v,
10930                                            }
10931                                        }
10932                                    }
10933                                    if matches!(kind, ReduceOp::Mean) {
10934                                        acc /= kernel_area;
10935                                    }
10936                                    out[out_chan + ho * w_out + wo] = acc;
10937                                }
10938                            }
10939                        }
10940                    }
10941                }
10942            }
10943
10944            Thunk::ReluBackward { x, dy, dx, len } => {
10945                let len = *len as usize;
10946                unsafe {
10947                    let xs = sl(*x, base, len);
10948                    let dys = sl(*dy, base, len);
10949                    let out = sl_mut(*dx, base, len);
10950                    for i in 0..len {
10951                        out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
10952                    }
10953                }
10954            }
10955
10956            Thunk::ReluBackwardF64 { x, dy, dx, len } => {
10957                let len = *len as usize;
10958                unsafe {
10959                    let xs = sl_f64(*x, base, len);
10960                    let dys = sl_f64(*dy, base, len);
10961                    let out = sl_mut_f64(*dx, base, len);
10962                    for i in 0..len {
10963                        out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
10964                    }
10965                }
10966            }
10967
10968            Thunk::QMatMul {
10969                x,
10970                w,
10971                bias,
10972                out,
10973                m,
10974                k,
10975                n,
10976                x_zp,
10977                w_zp,
10978                out_zp,
10979                mult,
10980            } => {
10981                let m = *m as usize;
10982                let k = *k as usize;
10983                let n = *n as usize;
10984                unsafe {
10985                    let x_ptr = base.add(*x) as *const i8;
10986                    let w_ptr = base.add(*w) as *const i8;
10987                    let bias_ptr = base.add(*bias) as *const i32;
10988                    let out_ptr = base.add(*out) as *mut i8;
10989                    for mi in 0..m {
10990                        for ni in 0..n {
10991                            let mut acc: i32 = *bias_ptr.add(ni);
10992                            for ki in 0..k {
10993                                let xv = *x_ptr.add(mi * k + ki) as i32 - *x_zp;
10994                                let wv = *w_ptr.add(ki * n + ni) as i32 - *w_zp;
10995                                acc += xv * wv;
10996                            }
10997                            // Requantize: round(acc · mult) + out_zp,
10998                            // clamped to i8.
10999                            let r = (acc as f32 * *mult).round() as i32 + *out_zp;
11000                            let r = r.clamp(-128, 127) as i8;
11001                            *out_ptr.add(mi * n + ni) = r;
11002                        }
11003                    }
11004                }
11005            }
11006
11007            Thunk::QConv2d {
11008                x,
11009                w,
11010                bias,
11011                out,
11012                n,
11013                c_in,
11014                h,
11015                w_in,
11016                c_out,
11017                h_out,
11018                w_out,
11019                kh,
11020                kw,
11021                sh,
11022                sw,
11023                ph,
11024                pw,
11025                dh,
11026                dw,
11027                groups,
11028                x_zp,
11029                w_zp,
11030                out_zp,
11031                mult,
11032            } => {
11033                let n = *n as usize;
11034                let c_in = *c_in as usize;
11035                let h = *h as usize;
11036                let w_in = *w_in as usize;
11037                let c_out = *c_out as usize;
11038                let h_out = *h_out as usize;
11039                let w_out = *w_out as usize;
11040                let kh = *kh as usize;
11041                let kw = *kw as usize;
11042                let sh = *sh as usize;
11043                let sw = *sw as usize;
11044                let ph = *ph as usize;
11045                let pw = *pw as usize;
11046                let dh = *dh as usize;
11047                let dw = *dw as usize;
11048                let groups = *groups as usize;
11049                let c_in_per_g = c_in / groups;
11050                let c_out_per_g = c_out / groups;
11051                unsafe {
11052                    let x_ptr = base.add(*x) as *const i8;
11053                    let w_ptr = base.add(*w) as *const i8;
11054                    let bias_ptr = base.add(*bias) as *const i32;
11055                    let out_ptr = base.add(*out) as *mut i8;
11056                    for ni in 0..n {
11057                        for co in 0..c_out {
11058                            let g = co / c_out_per_g;
11059                            let ci_start = g * c_in_per_g;
11060                            for ho in 0..h_out {
11061                                for wo in 0..w_out {
11062                                    let mut acc: i32 = *bias_ptr.add(co);
11063                                    for ci_off in 0..c_in_per_g {
11064                                        let ci = ci_start + ci_off;
11065                                        let in_chan = ((ni * c_in) + ci) * h * w_in;
11066                                        let wt_chan = ((co * c_in_per_g) + ci_off) * kh * kw;
11067                                        for ki in 0..kh {
11068                                            for kj in 0..kw {
11069                                                let hi = ho * sh + ki * dh;
11070                                                let wi = wo * sw + kj * dw;
11071                                                if hi < ph || wi < pw {
11072                                                    continue;
11073                                                }
11074                                                let hi = hi - ph;
11075                                                let wi = wi - pw;
11076                                                if hi >= h || wi >= w_in {
11077                                                    continue;
11078                                                }
11079                                                let xv = *x_ptr.add(in_chan + hi * w_in + wi)
11080                                                    as i32
11081                                                    - *x_zp;
11082                                                let wv = *w_ptr.add(wt_chan + ki * kw + kj) as i32
11083                                                    - *w_zp;
11084                                                acc += xv * wv;
11085                                            }
11086                                        }
11087                                    }
11088                                    let r = (acc as f32 * *mult).round() as i32 + *out_zp;
11089                                    let r = r.clamp(-128, 127) as i8;
11090                                    let dst = ((ni * c_out) + co) * h_out * w_out + ho * w_out + wo;
11091                                    *out_ptr.add(dst) = r;
11092                                }
11093                            }
11094                        }
11095                    }
11096                }
11097            }
11098
11099            Thunk::Quantize {
11100                x,
11101                q,
11102                len,
11103                chan_axis: _,
11104                chan_dim,
11105                inner,
11106                scales,
11107                zero_points,
11108            } => {
11109                let len = *len as usize;
11110                let chan_dim = *chan_dim as usize;
11111                let inner = *inner as usize;
11112                unsafe {
11113                    let xs = sl(*x, base, len);
11114                    let q_ptr = base.add(*q) as *mut i8;
11115                    for i in 0..len {
11116                        let c = if chan_dim == 1 {
11117                            0
11118                        } else {
11119                            (i / inner) % chan_dim
11120                        };
11121                        let inv_scale = 1.0 / scales[c];
11122                        let zp = zero_points[c];
11123                        let v = (xs[i] * inv_scale).round() as i32 + zp;
11124                        *q_ptr.add(i) = v.clamp(-128, 127) as i8;
11125                    }
11126                }
11127            }
11128
11129            Thunk::Dequantize {
11130                q,
11131                x,
11132                len,
11133                chan_axis: _,
11134                chan_dim,
11135                inner,
11136                scales,
11137                zero_points,
11138            } => {
11139                let len = *len as usize;
11140                let chan_dim = *chan_dim as usize;
11141                let inner = *inner as usize;
11142                unsafe {
11143                    let q_ptr = base.add(*q) as *const i8;
11144                    let out = sl_mut(*x, base, len);
11145                    for i in 0..len {
11146                        let c = if chan_dim == 1 {
11147                            0
11148                        } else {
11149                            (i / inner) % chan_dim
11150                        };
11151                        let scale = scales[c];
11152                        let zp = zero_points[c];
11153                        let qv = *q_ptr.add(i) as i32;
11154                        out[i] = (qv - zp) as f32 * scale;
11155                    }
11156                }
11157            }
11158
11159            Thunk::FakeQuantize {
11160                x,
11161                out,
11162                len,
11163                chan_axis: _,
11164                chan_dim,
11165                inner,
11166                bits,
11167                ste: _,
11168                scale_mode,
11169                state_off,
11170            } => {
11171                use rlx_ir::op::ScaleMode;
11172                let len = *len as usize;
11173                let chan_dim = *chan_dim as usize;
11174                let inner = *inner as usize;
11175                let q_max: f32 = match *bits {
11176                    8 => 127.0,
11177                    4 => 7.0,
11178                    2 => 1.0,
11179                    n => panic!("FakeQuantize: unsupported bits {n}"),
11180                };
11181                unsafe {
11182                    let xs = sl(*x, base, len);
11183                    let outs = sl_mut(*out, base, len);
11184
11185                    let mut scale = vec![0f32; chan_dim];
11186                    match scale_mode {
11187                        ScaleMode::PerBatch => {
11188                            let mut max_abs = vec![0f32; chan_dim];
11189                            for i in 0..len {
11190                                let c = if chan_dim == 1 {
11191                                    0
11192                                } else {
11193                                    (i / inner) % chan_dim
11194                                };
11195                                let a = xs[i].abs();
11196                                if a > max_abs[c] {
11197                                    max_abs[c] = a;
11198                                }
11199                            }
11200                            for c in 0..chan_dim {
11201                                scale[c] = (max_abs[c] / q_max).max(1e-12);
11202                            }
11203                        }
11204                        ScaleMode::EMA { decay } => {
11205                            // Per-channel current max-abs, then blend
11206                            // into the running state in place.
11207                            let mut max_abs = vec![0f32; chan_dim];
11208                            for i in 0..len {
11209                                let c = if chan_dim == 1 {
11210                                    0
11211                                } else {
11212                                    (i / inner) % chan_dim
11213                                };
11214                                let a = xs[i].abs();
11215                                if a > max_abs[c] {
11216                                    max_abs[c] = a;
11217                                }
11218                            }
11219                            let state =
11220                                sl_mut(state_off.expect("EMA needs state_off"), base, chan_dim);
11221                            for c in 0..chan_dim {
11222                                let cur = (max_abs[c] / q_max).max(1e-12);
11223                                // Cold-start: state==0 → seed directly.
11224                                let blended = if state[c] <= 0.0 {
11225                                    cur
11226                                } else {
11227                                    *decay * state[c] + (1.0 - *decay) * cur
11228                                };
11229                                state[c] = blended;
11230                                scale[c] = blended;
11231                            }
11232                        }
11233                        ScaleMode::Fixed => {
11234                            let state =
11235                                sl(state_off.expect("Fixed needs state_off"), base, chan_dim);
11236                            for c in 0..chan_dim {
11237                                scale[c] = state[c].max(1e-12);
11238                            }
11239                        }
11240                    }
11241
11242                    for i in 0..len {
11243                        let c = if chan_dim == 1 {
11244                            0
11245                        } else {
11246                            (i / inner) % chan_dim
11247                        };
11248                        let s = scale[c];
11249                        let qv = (xs[i] / s).round().clamp(-q_max, q_max);
11250                        outs[i] = qv * s;
11251                    }
11252                }
11253            }
11254
11255            Thunk::ActivationBackward {
11256                x,
11257                dy,
11258                dx,
11259                len,
11260                kind,
11261            } => {
11262                let len = *len as usize;
11263                unsafe {
11264                    let xs = sl(*x, base, len);
11265                    let dys = sl(*dy, base, len);
11266                    let out = sl_mut(*dx, base, len);
11267                    activation_backward_kernel(*kind, xs, dys, out);
11268                }
11269            }
11270
11271            Thunk::ActivationBackwardF64 {
11272                x,
11273                dy,
11274                dx,
11275                len,
11276                kind,
11277            } => {
11278                let len = *len as usize;
11279                unsafe {
11280                    let xs = sl_f64(*x, base, len);
11281                    let dys = sl_f64(*dy, base, len);
11282                    let out = sl_mut_f64(*dx, base, len);
11283                    activation_backward_kernel_f64(*kind, xs, dys, out);
11284                }
11285            }
11286
11287            Thunk::FakeQuantizeLSQ {
11288                x,
11289                scale_off,
11290                out,
11291                len,
11292                chan_axis: _,
11293                chan_dim,
11294                inner,
11295                bits,
11296            } => {
11297                let len = *len as usize;
11298                let chan_dim = *chan_dim as usize;
11299                let inner = *inner as usize;
11300                let q_max: f32 = match *bits {
11301                    8 => 127.0,
11302                    4 => 7.0,
11303                    2 => 1.0,
11304                    n => panic!("FakeQuantizeLSQ: bad bits {n}"),
11305                };
11306                unsafe {
11307                    let xs = sl(*x, base, len);
11308                    let scale = sl(*scale_off, base, chan_dim);
11309                    let outs = sl_mut(*out, base, len);
11310                    for i in 0..len {
11311                        let c = if chan_dim == 1 {
11312                            0
11313                        } else {
11314                            (i / inner) % chan_dim
11315                        };
11316                        let s = scale[c].max(1e-12);
11317                        let qv = (xs[i] / s).round().clamp(-q_max, q_max);
11318                        outs[i] = qv * s;
11319                    }
11320                }
11321            }
11322
11323            Thunk::FakeQuantizeLSQBackwardX {
11324                x,
11325                scale_off,
11326                dy,
11327                dx,
11328                len,
11329                chan_axis: _,
11330                chan_dim,
11331                inner,
11332                bits,
11333            } => {
11334                let len = *len as usize;
11335                let chan_dim = *chan_dim as usize;
11336                let inner = *inner as usize;
11337                let q_max: f32 = match *bits {
11338                    8 => 127.0,
11339                    4 => 7.0,
11340                    2 => 1.0,
11341                    n => panic!("FakeQuantizeLSQBackwardX: bad bits {n}"),
11342                };
11343                unsafe {
11344                    let xs = sl(*x, base, len);
11345                    let scale = sl(*scale_off, base, chan_dim);
11346                    let dys = sl(*dy, base, len);
11347                    let outs = sl_mut(*dx, base, len);
11348                    // STE-clipped: dx = dy when |x/s| ≤ q_max, else 0.
11349                    for i in 0..len {
11350                        let c = if chan_dim == 1 {
11351                            0
11352                        } else {
11353                            (i / inner) % chan_dim
11354                        };
11355                        let z = xs[i] / scale[c].max(1e-12);
11356                        outs[i] = if z.abs() <= q_max { dys[i] } else { 0.0 };
11357                    }
11358                }
11359            }
11360
11361            Thunk::FakeQuantizeLSQBackwardScale {
11362                x,
11363                scale_off,
11364                dy,
11365                dscale,
11366                len,
11367                chan_axis: _,
11368                chan_dim,
11369                inner,
11370                bits,
11371            } => {
11372                let len = *len as usize;
11373                let chan_dim = *chan_dim as usize;
11374                let inner = *inner as usize;
11375                let q_max: f32 = match *bits {
11376                    8 => 127.0,
11377                    4 => 7.0,
11378                    2 => 1.0,
11379                    n => panic!("FakeQuantizeLSQBackwardScale: bad bits {n}"),
11380                };
11381                unsafe {
11382                    let xs = sl(*x, base, len);
11383                    let scale = sl(*scale_off, base, chan_dim);
11384                    let dys = sl(*dy, base, len);
11385                    let outs = sl_mut(*dscale, base, chan_dim);
11386                    for v in outs.iter_mut() {
11387                        *v = 0.0;
11388                    }
11389                    // ψ(z) = -z + round(z) inside range, sign(z)·q_max outside.
11390                    // dscale[c] = sum_i ψ(x_i/s[c]) * upstream[i].
11391                    for i in 0..len {
11392                        let c = if chan_dim == 1 {
11393                            0
11394                        } else {
11395                            (i / inner) % chan_dim
11396                        };
11397                        let s = scale[c].max(1e-12);
11398                        let z = xs[i] / s;
11399                        let psi = if z.abs() <= q_max {
11400                            -z + z.round()
11401                        } else if z > 0.0 {
11402                            q_max
11403                        } else {
11404                            -q_max
11405                        };
11406                        outs[c] += psi * dys[i];
11407                    }
11408                }
11409            }
11410
11411            Thunk::FakeQuantizeBackward {
11412                x,
11413                dy,
11414                dx,
11415                len,
11416                chan_axis: _,
11417                chan_dim,
11418                inner,
11419                bits,
11420                ste,
11421            } => {
11422                use rlx_ir::op::SteKind;
11423                let len = *len as usize;
11424                let chan_dim = *chan_dim as usize;
11425                let inner = *inner as usize;
11426                let q_max: f32 = match *bits {
11427                    8 => 127.0,
11428                    4 => 7.0,
11429                    2 => 1.0,
11430                    n => panic!("FakeQuantizeBackward: bad bits {n}"),
11431                };
11432                unsafe {
11433                    let xs = sl(*x, base, len);
11434                    let dys = sl(*dy, base, len);
11435                    let outs = sl_mut(*dx, base, len);
11436
11437                    // Per-channel max-abs → scale, same as forward.
11438                    let mut max_abs = vec![0f32; chan_dim];
11439                    for i in 0..len {
11440                        let c = if chan_dim == 1 {
11441                            0
11442                        } else {
11443                            (i / inner) % chan_dim
11444                        };
11445                        let a = xs[i].abs();
11446                        if a > max_abs[c] {
11447                            max_abs[c] = a;
11448                        }
11449                    }
11450                    let mut scale = vec![0f32; chan_dim];
11451                    for c in 0..chan_dim {
11452                        scale[c] = (max_abs[c] / q_max).max(1e-12);
11453                    }
11454
11455                    match *ste {
11456                        SteKind::Identity => {
11457                            // dx = dy unchanged.
11458                            outs.copy_from_slice(dys);
11459                        }
11460                        SteKind::ClippedIdentity => {
11461                            // dx = dy * (|x| <= q_max·s); zero if the
11462                            // forward saturated.
11463                            for i in 0..len {
11464                                let c = if chan_dim == 1 {
11465                                    0
11466                                } else {
11467                                    (i / inner) % chan_dim
11468                                };
11469                                let bound = q_max * scale[c];
11470                                outs[i] = if xs[i].abs() <= bound { dys[i] } else { 0.0 };
11471                            }
11472                        }
11473                        SteKind::Tanh => {
11474                            // dx = dy * (1 - tanh²(x/s)).
11475                            for i in 0..len {
11476                                let c = if chan_dim == 1 {
11477                                    0
11478                                } else {
11479                                    (i / inner) % chan_dim
11480                                };
11481                                let t = (xs[i] / scale[c]).tanh();
11482                                outs[i] = dys[i] * (1.0 - t * t);
11483                            }
11484                        }
11485                        SteKind::HardTanh => {
11486                            // dx = dy * max(0, 1 - |x/(q_max·s)|).
11487                            for i in 0..len {
11488                                let c = if chan_dim == 1 {
11489                                    0
11490                                } else {
11491                                    (i / inner) % chan_dim
11492                                };
11493                                let bound = q_max * scale[c];
11494                                let attenuation = (1.0 - (xs[i] / bound).abs()).max(0.0);
11495                                outs[i] = dys[i] * attenuation;
11496                            }
11497                        }
11498                    }
11499                }
11500            }
11501
11502            Thunk::LayerNormBackwardInput {
11503                x,
11504                gamma,
11505                dy,
11506                dx,
11507                rows,
11508                h,
11509                eps,
11510            } => {
11511                let rows = *rows as usize;
11512                let h = *h as usize;
11513                let eps = *eps;
11514                unsafe {
11515                    let xs = sl(*x, base, rows * h);
11516                    let g = sl(*gamma, base, h);
11517                    let dys = sl(*dy, base, rows * h);
11518                    let out = sl_mut(*dx, base, rows * h);
11519                    let n_inv = 1.0 / h as f32;
11520                    for r in 0..rows {
11521                        let xr = &xs[r * h..(r + 1) * h];
11522                        let dyr = &dys[r * h..(r + 1) * h];
11523                        // Per-row mean and inv_std (recompute — no saved
11524                        // tensor from the forward pass).
11525                        let mut sum = 0f32;
11526                        for &v in xr {
11527                            sum += v;
11528                        }
11529                        let mean = sum * n_inv;
11530                        let mut var = 0f32;
11531                        for &v in xr {
11532                            let d = v - mean;
11533                            var += d * d;
11534                        }
11535                        let inv_std = 1.0 / (var * n_inv + eps).sqrt();
11536
11537                        // sums needed for the closed-form:
11538                        //   mean(dy·γ) and mean(dy·γ·x̂)
11539                        let mut s_sy = 0f32;
11540                        let mut s_sxh = 0f32;
11541                        for d in 0..h {
11542                            let xh = (xr[d] - mean) * inv_std;
11543                            let sy = dyr[d] * g[d];
11544                            s_sy += sy;
11545                            s_sxh += sy * xh;
11546                        }
11547                        let m_sy = s_sy * n_inv;
11548                        let m_sxh = s_sxh * n_inv;
11549
11550                        for d in 0..h {
11551                            let xh = (xr[d] - mean) * inv_std;
11552                            let sy = dyr[d] * g[d];
11553                            out[r * h + d] = inv_std * (sy - m_sy - xh * m_sxh);
11554                        }
11555                    }
11556                }
11557            }
11558
11559            Thunk::LayerNormBackwardGamma {
11560                x,
11561                dy,
11562                dgamma,
11563                rows,
11564                h,
11565                eps,
11566            } => {
11567                let rows = *rows as usize;
11568                let h = *h as usize;
11569                let eps = *eps;
11570                unsafe {
11571                    let xs = sl(*x, base, rows * h);
11572                    let dys = sl(*dy, base, rows * h);
11573                    let out = sl_mut(*dgamma, base, h);
11574                    for v in out.iter_mut() {
11575                        *v = 0.0;
11576                    }
11577                    let n_inv = 1.0 / h as f32;
11578                    for r in 0..rows {
11579                        let xr = &xs[r * h..(r + 1) * h];
11580                        let dyr = &dys[r * h..(r + 1) * h];
11581                        let mut sum = 0f32;
11582                        for &v in xr {
11583                            sum += v;
11584                        }
11585                        let mean = sum * n_inv;
11586                        let mut var = 0f32;
11587                        for &v in xr {
11588                            let d = v - mean;
11589                            var += d * d;
11590                        }
11591                        let inv_std = 1.0 / (var * n_inv + eps).sqrt();
11592                        for d in 0..h {
11593                            let xh = (xr[d] - mean) * inv_std;
11594                            out[d] += dyr[d] * xh;
11595                        }
11596                    }
11597                }
11598            }
11599
11600            Thunk::RmsNormBackwardInput {
11601                x,
11602                gamma,
11603                beta,
11604                dy,
11605                dx,
11606                rows,
11607                h,
11608                eps,
11609            } => {
11610                let (rows, h) = (*rows as usize, *h as usize);
11611                unsafe {
11612                    let xs = sl(*x, base, rows * h);
11613                    let g = sl(*gamma, base, h);
11614                    let b = sl(*beta, base, h);
11615                    let dys = sl(*dy, base, rows * h);
11616                    let out = sl_mut(*dx, base, rows * h);
11617                    let mut dg = vec![0f32; h];
11618                    let mut db = vec![0f32; h];
11619                    for r in 0..rows {
11620                        crate::training_bwd::rms_norm_backward_row(
11621                            &xs[r * h..(r + 1) * h],
11622                            g,
11623                            b,
11624                            &dys[r * h..(r + 1) * h],
11625                            &mut out[r * h..(r + 1) * h],
11626                            &mut dg,
11627                            &mut db,
11628                            *eps,
11629                        );
11630                    }
11631                }
11632            }
11633
11634            Thunk::RmsNormBackwardGamma {
11635                x,
11636                gamma,
11637                beta,
11638                dy,
11639                dgamma,
11640                rows,
11641                h,
11642                eps,
11643            } => {
11644                let (rows, h) = (*rows as usize, *h as usize);
11645                unsafe {
11646                    let xs = sl(*x, base, rows * h);
11647                    let g = sl(*gamma, base, h);
11648                    let b = sl(*beta, base, h);
11649                    let dys = sl(*dy, base, rows * h);
11650                    let out = sl_mut(*dgamma, base, h);
11651                    for v in out.iter_mut() {
11652                        *v = 0.0;
11653                    }
11654                    let mut dx = vec![0f32; h];
11655                    let mut db = vec![0f32; h];
11656                    for r in 0..rows {
11657                        crate::training_bwd::rms_norm_backward_row(
11658                            &xs[r * h..(r + 1) * h],
11659                            g,
11660                            b,
11661                            &dys[r * h..(r + 1) * h],
11662                            &mut dx,
11663                            &mut *out,
11664                            &mut db,
11665                            *eps,
11666                        );
11667                    }
11668                }
11669            }
11670
11671            Thunk::RmsNormBackwardBeta {
11672                x,
11673                gamma,
11674                beta,
11675                dy,
11676                dbeta,
11677                rows,
11678                h,
11679                eps,
11680            } => {
11681                let (rows, h) = (*rows as usize, *h as usize);
11682                unsafe {
11683                    let xs = sl(*x, base, rows * h);
11684                    let g = sl(*gamma, base, h);
11685                    let b = sl(*beta, base, h);
11686                    let dys = sl(*dy, base, rows * h);
11687                    let out = sl_mut(*dbeta, base, h);
11688                    for v in out.iter_mut() {
11689                        *v = 0.0;
11690                    }
11691                    let mut dx = vec![0f32; h];
11692                    let mut dg = vec![0f32; h];
11693                    for r in 0..rows {
11694                        crate::training_bwd::rms_norm_backward_row(
11695                            &xs[r * h..(r + 1) * h],
11696                            g,
11697                            b,
11698                            &dys[r * h..(r + 1) * h],
11699                            &mut dx,
11700                            &mut dg,
11701                            &mut *out,
11702                            *eps,
11703                        );
11704                    }
11705                }
11706            }
11707
11708            Thunk::RopeBackward {
11709                dy,
11710                cos,
11711                sin,
11712                dx,
11713                batch,
11714                seq,
11715                hidden,
11716                head_dim,
11717                n_rot,
11718                cos_len,
11719            } => {
11720                let (b, s, hs, dh, nr, cl) = (
11721                    *batch as usize,
11722                    *seq as usize,
11723                    *hidden as usize,
11724                    *head_dim as usize,
11725                    *n_rot as usize,
11726                    *cos_len as usize,
11727                );
11728                let nh = hs / dh;
11729                let tab_half = dh / 2;
11730                unsafe {
11731                    let dys = sl(*dy, base, b * s * hs);
11732                    let cos_tab = sl(*cos, base, cl);
11733                    let sin_tab = sl(*sin, base, cl);
11734                    let out = sl_mut(*dx, base, b * s * hs);
11735                    for bi in 0..b {
11736                        for si in 0..s {
11737                            let tab_off = si.saturating_mul(tab_half) % cl.max(1);
11738                            let cp = &cos_tab[tab_off..tab_off + tab_half.min(cl)];
11739                            let sp = &sin_tab[tab_off..tab_off + tab_half.min(cl)];
11740                            for hi in 0..nh {
11741                                let base_idx = bi * s * hs + si * hs + hi * dh;
11742                                crate::training_bwd::rope_backward_row(
11743                                    &dys[base_idx..base_idx + dh],
11744                                    cp,
11745                                    sp,
11746                                    &mut out[base_idx..base_idx + dh],
11747                                    dh,
11748                                    nr,
11749                                );
11750                            }
11751                        }
11752                    }
11753                }
11754            }
11755
11756            Thunk::CumsumBackward {
11757                dy,
11758                dx,
11759                rows,
11760                cols,
11761                exclusive,
11762            } => {
11763                let (rows, cols) = (*rows as usize, *cols as usize);
11764                unsafe {
11765                    let dys = sl(*dy, base, rows * cols);
11766                    let out = sl_mut(*dx, base, rows * cols);
11767                    for r in 0..rows {
11768                        crate::training_bwd::cumsum_backward_row(
11769                            &dys[r * cols..(r + 1) * cols],
11770                            &mut out[r * cols..(r + 1) * cols],
11771                            *exclusive,
11772                        );
11773                    }
11774                }
11775            }
11776
11777            Thunk::GroupNormBackwardInput {
11778                x,
11779                gamma,
11780                beta: _beta,
11781                dy,
11782                dx,
11783                n,
11784                c,
11785                h,
11786                w,
11787                num_groups,
11788                eps,
11789            } => {
11790                let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
11791                let plane = c * h * w;
11792                unsafe {
11793                    let xs = sl(*x, base, n * plane);
11794                    let g = sl(*gamma, base, c);
11795                    let dys = sl(*dy, base, n * plane);
11796                    let out = sl_mut(*dx, base, n * plane);
11797                    crate::training_bwd::group_norm_backward_input_nchw(
11798                        xs,
11799                        g,
11800                        dys,
11801                        out,
11802                        n,
11803                        c,
11804                        h,
11805                        w,
11806                        *num_groups as usize,
11807                        *eps,
11808                    );
11809                }
11810            }
11811
11812            Thunk::GroupNormBackwardGamma {
11813                x,
11814                dy,
11815                dgamma,
11816                n,
11817                c,
11818                h,
11819                w,
11820                num_groups,
11821                eps,
11822            } => {
11823                let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
11824                let plane = c * h * w;
11825                unsafe {
11826                    let xs = sl(*x, base, n * plane);
11827                    let dys = sl(*dy, base, n * plane);
11828                    let out = sl_mut(*dgamma, base, c);
11829                    crate::training_bwd::group_norm_backward_gamma_nchw(
11830                        xs,
11831                        dys,
11832                        out,
11833                        n,
11834                        c,
11835                        h,
11836                        w,
11837                        *num_groups as usize,
11838                        *eps,
11839                    );
11840                }
11841            }
11842
11843            Thunk::GroupNormBackwardBeta {
11844                dy,
11845                dbeta,
11846                n,
11847                c,
11848                h,
11849                w,
11850            } => {
11851                let (n, c, h, w) = (*n as usize, *c as usize, *h as usize, *w as usize);
11852                let plane = c * h * w;
11853                unsafe {
11854                    let dys = sl(*dy, base, n * plane);
11855                    let out = sl_mut(*dbeta, base, c);
11856                    crate::training_bwd::group_norm_backward_beta_nchw(dys, out, n, c, h, w);
11857                }
11858            }
11859
11860            Thunk::GatherBackward {
11861                dy,
11862                indices,
11863                dst,
11864                outer,
11865                axis_dim,
11866                num_idx,
11867                trailing,
11868            } => {
11869                let (outer, axis_dim, num_idx, trailing) = (
11870                    *outer as usize,
11871                    *axis_dim as usize,
11872                    *num_idx as usize,
11873                    *trailing as usize,
11874                );
11875                unsafe {
11876                    let dys = sl(*dy, base, outer * num_idx * trailing);
11877                    let ids = sl(*indices, base, num_idx);
11878                    let out = sl_mut(*dst, base, outer * axis_dim * trailing);
11879                    for v in out.iter_mut() {
11880                        *v = 0.0;
11881                    }
11882                    crate::training_bwd::gather_axis_backward(
11883                        dys, ids, out, outer, axis_dim, num_idx, trailing,
11884                    );
11885                }
11886            }
11887
11888            Thunk::MaxPool2dBackward {
11889                x,
11890                dy,
11891                dx,
11892                n,
11893                c,
11894                h,
11895                w,
11896                h_out,
11897                w_out,
11898                kh,
11899                kw,
11900                sh,
11901                sw,
11902                ph,
11903                pw,
11904            } => {
11905                let n = *n as usize;
11906                let c = *c as usize;
11907                let h = *h as usize;
11908                let w = *w as usize;
11909                let h_out = *h_out as usize;
11910                let w_out = *w_out as usize;
11911                let kh = *kh as usize;
11912                let kw = *kw as usize;
11913                let sh = *sh as usize;
11914                let sw = *sw as usize;
11915                let ph = *ph as usize;
11916                let pw = *pw as usize;
11917                unsafe {
11918                    let xs = sl(*x, base, n * c * h * w);
11919                    let dys = sl(*dy, base, n * c * h_out * w_out);
11920                    let dxs = sl_mut(*dx, base, n * c * h * w);
11921                    // Zero before scatter — multiple windows can write
11922                    // to the same input position when stride < kernel.
11923                    for v in dxs.iter_mut() {
11924                        *v = 0.0;
11925                    }
11926                    for ni in 0..n {
11927                        for ci in 0..c {
11928                            let in_chan = (ni * c + ci) * h * w;
11929                            let out_chan = (ni * c + ci) * h_out * w_out;
11930                            for ho in 0..h_out {
11931                                for wo in 0..w_out {
11932                                    // Recompute argmax inside this window.
11933                                    let mut best_v = f32::NEG_INFINITY;
11934                                    let mut best_idx: Option<usize> = None;
11935                                    for ki in 0..kh {
11936                                        for kj in 0..kw {
11937                                            let hi = ho * sh + ki;
11938                                            let wi = wo * sw + kj;
11939                                            if hi < ph || wi < pw {
11940                                                continue;
11941                                            }
11942                                            let hi = hi - ph;
11943                                            let wi = wi - pw;
11944                                            if hi >= h || wi >= w {
11945                                                continue;
11946                                            }
11947                                            let idx = in_chan + hi * w + wi;
11948                                            let v = xs[idx];
11949                                            // Tie-break: keep first hit
11950                                            // (matches forward's `acc.max(v)`
11951                                            // — strict greater-than wins).
11952                                            if v > best_v {
11953                                                best_v = v;
11954                                                best_idx = Some(idx);
11955                                            }
11956                                        }
11957                                    }
11958                                    if let Some(idx) = best_idx {
11959                                        dxs[idx] += dys[out_chan + ho * w_out + wo];
11960                                    }
11961                                }
11962                            }
11963                        }
11964                    }
11965                }
11966            }
11967
11968            Thunk::Conv2dBackwardInput {
11969                dy,
11970                w,
11971                dx,
11972                n,
11973                c_in,
11974                h,
11975                w_in,
11976                c_out,
11977                h_out,
11978                w_out,
11979                kh,
11980                kw,
11981                sh,
11982                sw,
11983                ph,
11984                pw,
11985                dh,
11986                dw,
11987                groups,
11988            } => {
11989                // Per-group GEMM + col2im. Two orders of magnitude faster
11990                // than the naive 6-deep nested loop on training shapes.
11991                //
11992                //   dcol_n_g = w_g^T  @  dy_n_g            (sgemm)
11993                //   dx_n_g  += col2im(dcol_n_g)            (scatter-add)
11994                //
11995                // Layouts (all row-major):
11996                //   w_g       [c_out_per_g, c_in_per_g · kh · kw]
11997                //   dy_n_g    [c_out_per_g, h_out · w_out]
11998                //   dcol_n_g  [c_in_per_g · kh · kw, h_out · w_out]
11999                //   dx_n_g    [c_in_per_g, h · w_in]
12000                let n = *n as usize;
12001                let c_in = *c_in as usize;
12002                let h = *h as usize;
12003                let w_in = *w_in as usize;
12004                let c_out = *c_out as usize;
12005                let h_out = *h_out as usize;
12006                let w_out = *w_out as usize;
12007                let kh = *kh as usize;
12008                let kw = *kw as usize;
12009                let sh = *sh as usize;
12010                let sw = *sw as usize;
12011                let ph = *ph as usize;
12012                let pw = *pw as usize;
12013                let dh = *dh as usize;
12014                let dw = *dw as usize;
12015                let groups = *groups as usize;
12016                let c_in_per_g = c_in / groups;
12017                let c_out_per_g = c_out / groups;
12018
12019                let m_dim = c_in_per_g * kh * kw;
12020                let n_dim = h_out * w_out;
12021                let k_dim = c_out_per_g;
12022
12023                let dy_stride_n = c_out * h_out * w_out;
12024                let dy_stride_g = c_out_per_g * h_out * w_out;
12025                let w_stride_g = c_out_per_g * c_in_per_g * kh * kw;
12026                let dx_stride_n = c_in * h * w_in;
12027                let dx_stride_g = c_in_per_g * h * w_in;
12028
12029                unsafe {
12030                    let dys = sl(*dy, base, n * c_out * h_out * w_out);
12031                    let ws = sl(*w, base, c_out * c_in_per_g * kh * kw);
12032                    let dxs = sl_mut(*dx, base, n * c_in * h * w_in);
12033                    for v in dxs.iter_mut() {
12034                        *v = 0.0;
12035                    }
12036
12037                    // Reused scratch buffer for the [m_dim, n_dim] dcol.
12038                    let mut dcol = vec![0f32; m_dim * n_dim];
12039
12040                    for ni in 0..n {
12041                        for g in 0..groups {
12042                            let w_g_off = g * w_stride_g;
12043                            let dy_n_g_off = ni * dy_stride_n + g * dy_stride_g;
12044                            let dx_n_g_off = ni * dx_stride_n + g * dx_stride_g;
12045
12046                            // dcol = w_g^T @ dy_n_g
12047                            // w_g  is stored as [k_dim rows, m_dim cols] row-major
12048                            // (i.e. K×M storage with lda = M = m_dim — exactly what
12049                            // sgemm_general wants for trans_a=true).
12050                            crate::blas::sgemm_general(
12051                                ws.as_ptr().add(w_g_off),
12052                                dys.as_ptr().add(dy_n_g_off),
12053                                dcol.as_mut_ptr(),
12054                                m_dim,
12055                                n_dim,
12056                                k_dim,
12057                                1.0,
12058                                0.0,
12059                                /*lda=*/ m_dim,
12060                                /*ldb=*/ n_dim,
12061                                /*ldc=*/ n_dim,
12062                                /*trans_a=*/ true,
12063                                /*trans_b=*/ false,
12064                            );
12065
12066                            // dx_n_g += col2im(dcol)
12067                            col2im(
12068                                &dcol,
12069                                &mut dxs[dx_n_g_off..dx_n_g_off + dx_stride_g],
12070                                c_in_per_g,
12071                                h,
12072                                w_in,
12073                                h_out,
12074                                w_out,
12075                                kh,
12076                                kw,
12077                                sh,
12078                                sw,
12079                                ph,
12080                                pw,
12081                                dh,
12082                                dw,
12083                            );
12084                        }
12085                    }
12086                }
12087            }
12088
12089            Thunk::Conv2dBackwardWeight {
12090                x,
12091                dy,
12092                dw,
12093                n,
12094                c_in,
12095                h,
12096                w,
12097                c_out,
12098                h_out,
12099                w_out,
12100                kh,
12101                kw,
12102                sh,
12103                sw,
12104                ph,
12105                pw,
12106                dh,
12107                dw_dil,
12108                groups,
12109            } => {
12110                let n = *n as usize;
12111                let c_in = *c_in as usize;
12112                let h = *h as usize;
12113                let w = *w as usize;
12114                // Per-group im2col + GEMM, summed across batch.
12115                //
12116                //   col_n_g  = im2col(x_n_g)               (gather)
12117                //   dw_g    += dy_n_g  @  col_n_g^T        (sgemm, β=1)
12118                //
12119                // Layouts:
12120                //   x_n_g     [c_in_per_g, h · w]
12121                //   col_n_g   [c_in_per_g · kh · kw, h_out · w_out]
12122                //   dy_n_g    [c_out_per_g, h_out · w_out]
12123                //   dw_g      [c_out_per_g, c_in_per_g · kh · kw]
12124                let c_out = *c_out as usize;
12125                let h_out = *h_out as usize;
12126                let w_out = *w_out as usize;
12127                let kh = *kh as usize;
12128                let kw = *kw as usize;
12129                let sh = *sh as usize;
12130                let sw = *sw as usize;
12131                let ph = *ph as usize;
12132                let pw = *pw as usize;
12133                let dh = *dh as usize;
12134                let dw_dil = *dw_dil as usize;
12135                let groups = *groups as usize;
12136                let c_in_per_g = c_in / groups;
12137                let c_out_per_g = c_out / groups;
12138
12139                let m_dim = c_out_per_g;
12140                let n_dim = c_in_per_g * kh * kw;
12141                let k_dim = h_out * w_out;
12142
12143                let x_stride_n = c_in * h * w;
12144                let x_stride_g = c_in_per_g * h * w;
12145                let dy_stride_n = c_out * h_out * w_out;
12146                let dy_stride_g = c_out_per_g * h_out * w_out;
12147                let dw_stride_g = c_out_per_g * c_in_per_g * kh * kw;
12148
12149                unsafe {
12150                    let xs = sl(*x, base, n * c_in * h * w);
12151                    let dys = sl(*dy, base, n * c_out * h_out * w_out);
12152                    let dws = sl_mut(*dw, base, c_out * c_in_per_g * kh * kw);
12153                    for v in dws.iter_mut() {
12154                        *v = 0.0;
12155                    }
12156
12157                    let mut col = vec![0f32; n_dim * k_dim];
12158
12159                    for ni in 0..n {
12160                        for g in 0..groups {
12161                            let x_n_g_off = ni * x_stride_n + g * x_stride_g;
12162                            im2col(
12163                                &xs[x_n_g_off..x_n_g_off + x_stride_g],
12164                                &mut col,
12165                                c_in_per_g,
12166                                h,
12167                                w,
12168                                h_out,
12169                                w_out,
12170                                kh,
12171                                kw,
12172                                sh,
12173                                sw,
12174                                ph,
12175                                pw,
12176                                dh,
12177                                dw_dil,
12178                            );
12179
12180                            let dy_n_g_off = ni * dy_stride_n + g * dy_stride_g;
12181                            let dw_g_off = g * dw_stride_g;
12182
12183                            // dw_g += dy_n_g @ col^T
12184                            //
12185                            // Output shape m × n_out = c_out_per_g × (c_in_per_g·kh·kw).
12186                            // dy_n_g is stored M×K row-major (lda = K = k_dim).
12187                            // col is stored as N×K row-major; with trans_b=true,
12188                            // sgemm_general uses ldb = K = k_dim and treats it as
12189                            // transposed. β=1 accumulates across the batch loop.
12190                            crate::blas::sgemm_general(
12191                                dys.as_ptr().add(dy_n_g_off),
12192                                col.as_ptr(),
12193                                dws.as_mut_ptr().add(dw_g_off),
12194                                m_dim,
12195                                n_dim,
12196                                k_dim,
12197                                1.0,
12198                                1.0,
12199                                /*lda=*/ k_dim,
12200                                /*ldb=*/ k_dim,
12201                                /*ldc=*/ n_dim,
12202                                /*trans_a=*/ false,
12203                                /*trans_b=*/ true,
12204                            );
12205                        }
12206                    }
12207                }
12208            }
12209
12210            Thunk::SoftmaxCrossEntropy {
12211                logits,
12212                labels,
12213                dst,
12214                n,
12215                c,
12216            } => {
12217                let n = *n as usize;
12218                let c = *c as usize;
12219                unsafe {
12220                    let lg = sl(*logits, base, n * c);
12221                    let lb = sl(*labels, base, n);
12222                    let out = sl_mut(*dst, base, n);
12223                    for ni in 0..n {
12224                        let row = &lg[ni * c..(ni + 1) * c];
12225                        // log-sum-exp: max-subtract for stability.
12226                        let mut m = f32::NEG_INFINITY;
12227                        for &v in row {
12228                            if v > m {
12229                                m = v;
12230                            }
12231                        }
12232                        let mut sum = 0f32;
12233                        for &v in row {
12234                            sum += (v - m).exp();
12235                        }
12236                        let lse = m + sum.ln();
12237                        let label_idx = lb[ni] as usize;
12238                        // loss = -(logits[label] - lse) = lse - logits[label].
12239                        out[ni] = lse - row[label_idx];
12240                    }
12241                }
12242            }
12243
12244            Thunk::SoftmaxCrossEntropyBackward {
12245                logits,
12246                labels,
12247                d_loss,
12248                dlogits,
12249                n,
12250                c,
12251            } => {
12252                let n = *n as usize;
12253                let c = *c as usize;
12254                unsafe {
12255                    let lg = sl(*logits, base, n * c);
12256                    let lb = sl(*labels, base, n);
12257                    let dl = sl(*d_loss, base, n);
12258                    let out = sl_mut(*dlogits, base, n * c);
12259                    for ni in 0..n {
12260                        let row = &lg[ni * c..(ni + 1) * c];
12261                        let label_idx = lb[ni] as usize;
12262                        let scale = dl[ni];
12263                        let mut m = f32::NEG_INFINITY;
12264                        for &v in row {
12265                            if v > m {
12266                                m = v;
12267                            }
12268                        }
12269                        let mut sum = 0f32;
12270                        for &v in row {
12271                            sum += (v - m).exp();
12272                        }
12273                        let inv_sum = 1.0 / sum;
12274                        let dst_row = &mut out[ni * c..(ni + 1) * c];
12275                        for k in 0..c {
12276                            let p = (row[k] - m).exp() * inv_sum;
12277                            let one_hot = if k == label_idx { 1.0 } else { 0.0 };
12278                            dst_row[k] = (p - one_hot) * scale;
12279                        }
12280                    }
12281                }
12282            }
12283
12284            Thunk::GatherAxis {
12285                table,
12286                idx,
12287                dst,
12288                outer,
12289                axis_dim,
12290                num_idx,
12291                trailing,
12292            } => {
12293                let outer = *outer as usize;
12294                let axis_dim = *axis_dim as usize;
12295                let num_idx = *num_idx as usize;
12296                let trailing = *trailing as usize;
12297                unsafe {
12298                    let tab = sl(*table, base, outer * axis_dim * trailing);
12299                    let ids = sl(*idx, base, num_idx);
12300                    let out = sl_mut(*dst, base, outer * num_idx * trailing);
12301                    for o in 0..outer {
12302                        let tab_outer = o * axis_dim * trailing;
12303                        let out_outer = o * num_idx * trailing;
12304                        for k in 0..num_idx {
12305                            let row = ids[k] as usize;
12306                            let tab_row = tab_outer + row * trailing;
12307                            let out_row = out_outer + k * trailing;
12308                            out[out_row..out_row + trailing]
12309                                .copy_from_slice(&tab[tab_row..tab_row + trailing]);
12310                        }
12311                    }
12312                }
12313            }
12314
12315            Thunk::Transpose {
12316                src,
12317                dst,
12318                in_total,
12319                out_dims,
12320                in_strides,
12321            } => {
12322                // N-D index walk: for each output flat index, decompose into
12323                // multi-dim coords using out_dims, then dot with in_strides
12324                // to find the source flat index. Stride 0 = broadcast (read
12325                // the same input element repeatedly along that dim).
12326                let rank = out_dims.len();
12327                let total: usize = out_dims.iter().map(|&d| d as usize).product();
12328                let in_total = *in_total as usize;
12329                unsafe {
12330                    let inp = sl(*src, base, in_total);
12331                    let out = sl_mut(*dst, base, total);
12332                    let mut idx = vec![0usize; rank];
12333                    for o in 0..total {
12334                        let mut src_idx = 0usize;
12335                        for d in 0..rank {
12336                            src_idx += idx[d] * in_strides[d] as usize;
12337                        }
12338                        out[o] = inp[src_idx];
12339                        // Increment multi-index (innermost dim first).
12340                        for d in (0..rank).rev() {
12341                            idx[d] += 1;
12342                            if idx[d] < out_dims[d] as usize {
12343                                break;
12344                            }
12345                            idx[d] = 0;
12346                        }
12347                    }
12348                }
12349            }
12350
12351            // (Thunk::DenseSolveF64 / Thunk::ScanBackward had panic
12352            // stubs here as placeholders during the wire-up; both
12353            // are now reached by the real implementations earlier in
12354            // this same match — the stubs were dead code shadowed by
12355            // the specific-pattern arms above. Removed.)
12356            Thunk::CustomOp {
12357                kernel,
12358                inputs,
12359                output,
12360                attrs,
12361            } => {
12362                let (out_off, out_len, out_shape) = output;
12363                unsafe {
12364                    dispatch_custom_op(
12365                        &**kernel, inputs, *out_off, *out_len, out_shape, attrs, base,
12366                    );
12367                }
12368            }
12369        }
12370    }
12371}
12372
/// Griewank treeverse: process backward iterations `[t_lo..=t_hi]` (with
/// the carry entering iteration `t_lo` supplied as `anchor_carry`) by
/// recursive binary subdivision. Total work `O((t_hi-t_lo+1) · log)`,
/// auxiliary memory `O(log · carry_bytes)` for the recursion stack.
///
/// Compared to the iterative segment-cached scheme, this trades extra
/// recompute for less working memory — each level of recursion holds
/// one `cb`-sized intermediate carry on the stack but never the whole
/// segment at once. With K saved outer checkpoints, the outer driver
/// invokes this helper once per segment.
///
/// `process_iter(t, carry_at_t)` is the per-iteration leaf action: it
/// runs `body_vjp` at iteration `t` with the supplied carry, threads
/// `dcarry` backward, and (for ScanBackwardXs) writes `dxs[t]`.
///
/// # Arguments
///
/// * `t_lo`, `t_hi` — inclusive iteration range. The segment size is
///   computed as `t_hi - t_lo + 1` in `usize`, so callers must pass
///   `t_hi >= t_lo`.
/// * `anchor_carry` — carry bytes entering iteration `t_lo`. The leaf path
///   pushes the whole slice into its cache and then indexes the cache in
///   `cb`-sized strides, so the slice is expected to be exactly `cb` bytes.
/// * `cb` — carry size in bytes.
/// * `fwd_sched` — forward-body schedule, run through `execute_thunks`
///   once per replayed iteration.
/// * `fwd_init` — image copied over `fwd_buf` before every replay
///   (`copy_from_slice`, so both must have the same length or it panics).
/// * `fwd_carry_in_off`, `fwd_output_off` — byte offsets inside `fwd_buf`
///   of the carry input slot and of the body's carry output. After each
///   body run the output is copied back to the input slot when they differ.
/// * `fwd_x_offs` — byte offsets inside `fwd_buf` of the per-step inputs;
///   entry `i` is paired with `outer_xs_offs[i]`.
/// * `base` — base pointer the per-step inputs are read from.
/// * `outer_xs_offs` — `(byte offset from base, bytes per step)` for each
///   per-step input; iteration `t` reads at `offset + t * bytes_per_step`.
/// * `fwd_buf` — scratch buffer the forward body executes in. It is
///   overwritten by every replay, including those done by recursive calls.
/// * `leaf_threshold` — segments of at most this many iterations are
///   handled by caching every carry (`size * cb` bytes) instead of
///   subdividing further.
/// * `process_iter` — invoked exactly once per iteration, in strictly
///   decreasing `t` order across the whole segment.
///
/// # Safety
///
/// The function performs raw `copy_nonoverlapping` calls, so the caller
/// must guarantee that:
/// * `anchor_carry` is readable for `cb` bytes;
/// * `fwd_buf` is writable for `cb` bytes at `fwd_carry_in_off` and for
///   each per-step size at the matching `fwd_x_offs` entry;
/// * `base + outer_xs_off + t * xb` is readable for `xb` bytes for every
///   replayed `t` and every entry of `outer_xs_offs`, and does not overlap
///   `fwd_buf`.
#[allow(clippy::too_many_arguments)]
unsafe fn griewank_process_segment(
    t_lo: usize,
    t_hi: usize,
    anchor_carry: &[u8],
    cb: usize,
    fwd_sched: &ThunkSchedule,
    fwd_init: &[u8],
    fwd_carry_in_off: usize,
    fwd_output_off: usize,
    fwd_x_offs: &[usize],
    base: *mut u8,
    outer_xs_offs: &[(usize, u32)],
    fwd_buf: &mut Vec<u8>,
    leaf_threshold: usize,
    process_iter: &mut dyn FnMut(usize, &[u8]),
) {
    unsafe {
        let size = t_hi - t_lo + 1;
        // Base case: a single iteration already has its entering carry.
        if size == 1 {
            process_iter(t_lo, anchor_carry);
            return;
        }
        if size <= leaf_threshold {
            // Walk forward, cache each carry, run backward in reverse.
            // cache[i] holds the carry entering iteration `t_lo + i`.
            let mut cache: Vec<u8> = Vec::with_capacity(size * cb);
            cache.extend_from_slice(anchor_carry);
            // Reset the forward scratch, then seed its carry slot with the anchor.
            fwd_buf.copy_from_slice(fwd_init);
            std::ptr::copy_nonoverlapping(
                anchor_carry.as_ptr(),
                fwd_buf.as_mut_ptr().add(fwd_carry_in_off),
                cb,
            );
            // Only `size - 1` body runs are needed: the carry leaving the
            // last iteration is never consumed by this segment.
            for i in 1..size {
                let cur_iter = t_lo + i - 1;
                // Stage this iteration's per-step inputs into the scratch.
                for (idx, fb_x_off) in fwd_x_offs.iter().enumerate() {
                    let (outer_xs_off, x_psb) = outer_xs_offs[idx];
                    let xb = x_psb as usize;
                    std::ptr::copy_nonoverlapping(
                        base.add(outer_xs_off + cur_iter * xb),
                        fwd_buf.as_mut_ptr().add(*fb_x_off),
                        xb,
                    );
                }
                execute_thunks(fwd_sched, fwd_buf);
                // Feed the body's output back as the next iteration's carry.
                if fwd_output_off != fwd_carry_in_off {
                    fwd_buf.copy_within(fwd_output_off..fwd_output_off + cb, fwd_carry_in_off);
                }
                cache.extend_from_slice(&fwd_buf[fwd_carry_in_off..fwd_carry_in_off + cb]);
            }
            // Process backward.
            for t in (t_lo..=t_hi).rev() {
                let idx = t - t_lo;
                let carry = &cache[idx * cb..(idx + 1) * cb];
                process_iter(t, carry);
            }
            return;
        }

        // Split: walk forward from anchor to compute carry entering `mid`.
        // (We need `mid - t_lo` body executions: one per iteration in
        // [t_lo, mid).)
        let mid = t_lo + size / 2;
        fwd_buf.copy_from_slice(fwd_init);
        std::ptr::copy_nonoverlapping(
            anchor_carry.as_ptr(),
            fwd_buf.as_mut_ptr().add(fwd_carry_in_off),
            cb,
        );
        for cur_iter in t_lo..mid {
            // Stage this iteration's per-step inputs into the scratch.
            for (idx, fb_x_off) in fwd_x_offs.iter().enumerate() {
                let (outer_xs_off, x_psb) = outer_xs_offs[idx];
                let xb = x_psb as usize;
                std::ptr::copy_nonoverlapping(
                    base.add(outer_xs_off + cur_iter * xb),
                    fwd_buf.as_mut_ptr().add(*fb_x_off),
                    xb,
                );
            }
            execute_thunks(fwd_sched, fwd_buf);
            if fwd_output_off != fwd_carry_in_off {
                fwd_buf.copy_within(fwd_output_off..fwd_output_off + cb, fwd_carry_in_off);
            }
        }
        // Snapshot the carry entering `mid`: the recursive calls below reuse
        // (and overwrite) `fwd_buf`, so it must be copied out first.
        let mid_carry: Vec<u8> = fwd_buf[fwd_carry_in_off..fwd_carry_in_off + cb].to_vec();

        // Right half first (higher t values processed first to match the
        // canonical reverse-mode iteration order: dcarry threads from
        // t=length-1 down to t=0).
        griewank_process_segment(
            mid,
            t_hi,
            &mid_carry,
            cb,
            fwd_sched,
            fwd_init,
            fwd_carry_in_off,
            fwd_output_off,
            fwd_x_offs,
            base,
            outer_xs_offs,
            fwd_buf,
            leaf_threshold,
            process_iter,
        );
        // Then left half with original anchor.
        griewank_process_segment(
            t_lo,
            mid - 1,
            anchor_carry,
            cb,
            fwd_sched,
            fwd_init,
            fwd_carry_in_off,
            fwd_output_off,
            fwd_x_offs,
            base,
            outer_xs_offs,
            fwd_buf,
            leaf_threshold,
            process_iter,
        );
    }
}
12511
12512/// Execute a batched 1D FFT in the f64 2N-real-block layout.
12513/// Each "row" is `2N` f64 elements: first `N` real, then `N` imag.
12514/// The `outer` rows are independent and processed sequentially.
12515///
12516/// Both forward and inverse use the same Cooley-Tukey radix-2 DIT
12517/// kernel — only the twiddle-factor sign differs. Power-of-2 only
12518/// (the IR builder rejects non-power-of-2 sizes at graph-build time).
12519/// Batched 1D FFT on the f64 2N-real-block layout. Public so other
12520/// backend crates can invoke this as a host fallback against a
12521/// unified-memory arena (e.g. rlx-metal: sync the command buffer,
12522/// pass the Metal `Buffer::contents()` pointer as `base`, restart the
12523/// command buffer). Self-contained — no rlx-cpu state required.
12524///
12525/// Safety: `base + src` and `base + dst` must be valid for the
12526/// `outer * 2 * n_complex * sizeof::<f64>()` byte range and stay
12527/// alive for the duration of the call.
12528pub unsafe fn execute_fft1d_f64(
12529    src: usize,
12530    dst: usize,
12531    outer: usize,
12532    n_complex: usize,
12533    inverse: bool,
12534    norm_tag: u32,
12535    base: *mut u8,
12536) {
12537    let row_elems = 2 * n_complex;
12538    let mut re = vec![0f64; n_complex];
12539    let mut im = vec![0f64; n_complex];
12540    let norm = rlx_ir::fft::FftNorm::from_tag(norm_tag);
12541    let scale = norm.output_scale(n_complex, inverse);
12542    // Scratch reused across rows for the Bluestein path. Empty when
12543    // we're on the radix-2 fast path.
12544    let mut scratch = if n_complex.is_power_of_two() || n_complex <= 16 {
12545        BluesteinScratchF64::empty()
12546    } else {
12547        BluesteinScratchF64::build(n_complex, inverse)
12548    };
12549    for o in 0..outer {
12550        let row_offset = src + o * row_elems * std::mem::size_of::<f64>();
12551        let s = unsafe { sl_f64(row_offset, base, row_elems) };
12552        re.copy_from_slice(&s[..n_complex]);
12553        im.copy_from_slice(&s[n_complex..]);
12554        if n_complex.is_power_of_two() {
12555            fft_radix2_inplace_f64(&mut re, &mut im, inverse);
12556        } else if n_complex <= 16 {
12557            fft_naive_inplace_f64(&mut re, &mut im, inverse);
12558        } else {
12559            fft_bluestein_inplace_f64(&mut re, &mut im, inverse, &mut scratch);
12560        }
12561        if scale != 1.0 {
12562            re.iter_mut().for_each(|v| *v *= scale);
12563            im.iter_mut().for_each(|v| *v *= scale);
12564        }
12565        let dst_offset = dst + o * row_elems * std::mem::size_of::<f64>();
12566        let d = unsafe { sl_mut_f64(dst_offset, base, row_elems) };
12567        d[..n_complex].copy_from_slice(&re);
12568        d[n_complex..].copy_from_slice(&im);
12569    }
12570}
12571
12572/// f32 counterpart of `execute_fft1d_f64`. Same 2N-real-block layout
12573/// (first N real, second N imag per row), same unnormalized
12574/// convention; only the element width differs. Twiddle factors are
12575/// computed in f64 and cast to f32 to keep large-N error closer to
12576/// the f64 path (the savings from f32 are in memory bandwidth, not in
12577/// twiddle precision).
12578/// Host-fallback entry for `Op::GatedDeltaNet` (Metal / unified memory).
12579/// When `state == 0`, uses a zero-initialized scratch state per batch item.
12580pub unsafe fn execute_gated_delta_net_f32(
12581    q: usize,
12582    k: usize,
12583    v: usize,
12584    g: usize,
12585    beta: usize,
12586    state: usize,
12587    dst: usize,
12588    batch: usize,
12589    seq: usize,
12590    heads: usize,
12591    state_size: usize,
12592    base: *mut u8,
12593) {
12594    use rayon::prelude::*;
12595
12596    #[derive(Copy, Clone)]
12597    struct ArenaPtr(usize);
12598    unsafe impl Send for ArenaPtr {}
12599    unsafe impl Sync for ArenaPtr {}
12600    impl ArenaPtr {
12601        #[inline]
12602        fn get(self) -> *mut u8 {
12603            self.0 as *mut u8
12604        }
12605    }
12606
12607    unsafe {
12608        let arena = ArenaPtr(base as usize);
12609        let (b, s, h, n) = (batch, seq, heads, state_size);
12610        let scale = 1.0f32 / (n as f32).sqrt();
12611        let use_external = state != 0;
12612        let mut owned_state = vec![0f32; h * n * n];
12613
12614        crate::pool::num_threads();
12615
12616        assert!(
12617            n <= crate::gdn::GDN_MAX_STATE,
12618            "GatedDeltaNet state_size={n} exceeds stack scratch ({})",
12619            crate::gdn::GDN_MAX_STATE
12620        );
12621
12622        let qs = sl(q, arena.get(), b * s * h * n);
12623        let ks = sl(k, arena.get(), b * s * h * n);
12624        let vs = sl(v, arena.get(), b * s * h * n);
12625        let gs = sl(g, arena.get(), b * s * h);
12626        let betas = sl(beta, arena.get(), b * s * h);
12627        let _out = sl_mut(dst, arena.get(), b * s * h * n);
12628        let hs_n = h * n;
12629
12630        let run_head = |bi: usize, hi: usize, s_mat: &mut [f32], sk: &mut [f32]| {
12631            for ti in 0..s {
12632                let qkv_step = bi * s * hs_n + ti * hs_n + hi * n;
12633                let gb_step = bi * s * h + ti * h + hi;
12634                let out_row = sl_mut(dst + qkv_step * std::mem::size_of::<f32>(), arena.get(), n);
12635                crate::gdn::gdn_step_blas(
12636                    s_mat,
12637                    &qs[qkv_step..qkv_step + n],
12638                    &ks[qkv_step..qkv_step + n],
12639                    &vs[qkv_step..qkv_step + n],
12640                    gs[gb_step],
12641                    betas[gb_step],
12642                    out_row,
12643                    sk,
12644                    n,
12645                    scale,
12646                );
12647            }
12648        };
12649
12650        // Prefill (seq>1, ephemeral state): time-outer, parallel over heads —
12651        // better occupancy than head-outer when prompt length dominates.
12652        if !use_external && s > 1 {
12653            for bi in 0..b {
12654                (0..h).into_par_iter().for_each(|hi| {
12655                    let mut sk_buf = [0f32; crate::gdn::GDN_MAX_STATE];
12656                    let sk = &mut sk_buf[..n];
12657                    let mut local_state =
12658                        [0f32; crate::gdn::GDN_MAX_STATE * crate::gdn::GDN_MAX_STATE];
12659                    let s_mat = &mut local_state[..n * n];
12660                    s_mat.fill(0.0);
12661                    run_head(bi, hi, s_mat, sk);
12662                });
12663            }
12664            return;
12665        }
12666
12667        if use_external {
12668            let state_bytes = state;
12669            (0..b * h).into_par_iter().for_each(|bhi| {
12670                let bi = bhi / h;
12671                let hi = bhi % h;
12672                let elem_off = bi * h * n * n + hi * n * n;
12673                let s_mat = sl_mut(
12674                    state_bytes + elem_off * std::mem::size_of::<f32>(),
12675                    arena.get(),
12676                    n * n,
12677                );
12678                let mut sk_buf = [0f32; crate::gdn::GDN_MAX_STATE];
12679                run_head(bi, hi, s_mat, &mut sk_buf[..n]);
12680            });
12681        } else {
12682            for bi in 0..b {
12683                owned_state.fill(0.0);
12684                owned_state
12685                    .par_chunks_mut(n * n)
12686                    .enumerate()
12687                    .for_each(|(hi, s_mat)| {
12688                        let mut sk_buf = [0f32; crate::gdn::GDN_MAX_STATE];
12689                        run_head(bi, hi, s_mat, &mut sk_buf[..n]);
12690                    });
12691            }
12692        }
12693    }
12694}
12695
12696/// Host-fallback: `Op::RmsNormBackwardInput` (GPU unified-memory / D2H arenas).
12697pub unsafe fn execute_rms_norm_backward_input_f32(
12698    x: usize,
12699    gamma: usize,
12700    beta: usize,
12701    dy: usize,
12702    dx: usize,
12703    rows: u32,
12704    h: u32,
12705    eps: f32,
12706    base: *mut u8,
12707) {
12708    let (rows, h) = (rows as usize, h as usize);
12709    let mut dg = vec![0f32; h];
12710    let mut db = vec![0f32; h];
12711    let xs = sl(x, base, rows * h);
12712    let dys = sl(dy, base, rows * h);
12713    let g = sl(gamma, base, h);
12714    let b = sl(beta, base, h);
12715    let out = sl_mut(dx, base, rows * h);
12716    for r in 0..rows {
12717        crate::training_bwd::rms_norm_backward_row(
12718            &xs[r * h..(r + 1) * h],
12719            g,
12720            b,
12721            &dys[r * h..(r + 1) * h],
12722            &mut out[r * h..(r + 1) * h],
12723            &mut dg,
12724            &mut db,
12725            eps,
12726        );
12727    }
12728}
12729
12730pub unsafe fn execute_rms_norm_backward_gamma_f32(
12731    x: usize,
12732    gamma: usize,
12733    beta: usize,
12734    dy: usize,
12735    dgamma: usize,
12736    rows: u32,
12737    h: u32,
12738    eps: f32,
12739    base: *mut u8,
12740) {
12741    let (rows, h) = (rows as usize, h as usize);
12742    let out = sl_mut(dgamma, base, h);
12743    out.fill(0.0);
12744    let mut dx = vec![0f32; h];
12745    let mut db = vec![0f32; h];
12746    let xs = sl(x, base, rows * h);
12747    let dys = sl(dy, base, rows * h);
12748    let g = sl(gamma, base, h);
12749    let b = sl(beta, base, h);
12750    for r in 0..rows {
12751        crate::training_bwd::rms_norm_backward_row(
12752            &xs[r * h..(r + 1) * h],
12753            g,
12754            b,
12755            &dys[r * h..(r + 1) * h],
12756            &mut dx,
12757            out,
12758            &mut db,
12759            eps,
12760        );
12761    }
12762}
12763
12764pub unsafe fn execute_rms_norm_backward_beta_f32(
12765    x: usize,
12766    gamma: usize,
12767    beta: usize,
12768    dy: usize,
12769    dbeta: usize,
12770    rows: u32,
12771    h: u32,
12772    eps: f32,
12773    base: *mut u8,
12774) {
12775    let (rows, h) = (rows as usize, h as usize);
12776    let out = sl_mut(dbeta, base, h);
12777    out.fill(0.0);
12778    let mut dx = vec![0f32; h];
12779    let mut dg = vec![0f32; h];
12780    let xs = sl(x, base, rows * h);
12781    let dys = sl(dy, base, rows * h);
12782    let g = sl(gamma, base, h);
12783    let b = sl(beta, base, h);
12784    for r in 0..rows {
12785        crate::training_bwd::rms_norm_backward_row(
12786            &xs[r * h..(r + 1) * h],
12787            g,
12788            b,
12789            &dys[r * h..(r + 1) * h],
12790            &mut dx,
12791            &mut dg,
12792            out,
12793            eps,
12794        );
12795    }
12796}
12797
12798pub unsafe fn execute_rope_backward_f32(
12799    dy: usize,
12800    cos: usize,
12801    sin: usize,
12802    dx: usize,
12803    batch: u32,
12804    seq: u32,
12805    hidden: u32,
12806    head_dim: u32,
12807    n_rot: u32,
12808    cos_len: u32,
12809    base: *mut u8,
12810) {
12811    let (b, s, hs, dh, nr, cl) = (
12812        batch as usize,
12813        seq as usize,
12814        hidden as usize,
12815        head_dim as usize,
12816        n_rot as usize,
12817        cos_len as usize,
12818    );
12819    let nh = hs / dh;
12820    let tab_half = dh / 2;
12821    let dys = sl(dy, base, b * s * hs);
12822    let cos_tab = sl(cos, base, cl);
12823    let sin_tab = sl(sin, base, cl);
12824    let out = sl_mut(dx, base, b * s * hs);
12825    for bi in 0..b {
12826        for si in 0..s {
12827            let tab_off = si.saturating_mul(tab_half) % cl.max(1);
12828            let cp = &cos_tab[tab_off..tab_off + tab_half.min(cl)];
12829            let sp = &sin_tab[tab_off..tab_off + tab_half.min(cl)];
12830            for hi in 0..nh {
12831                let base_idx = bi * s * hs + si * hs + hi * dh;
12832                crate::training_bwd::rope_backward_row(
12833                    &dys[base_idx..base_idx + dh],
12834                    cp,
12835                    sp,
12836                    &mut out[base_idx..base_idx + dh],
12837                    dh,
12838                    nr,
12839                );
12840            }
12841        }
12842    }
12843}
12844
12845pub unsafe fn execute_cumsum_backward_f32(
12846    dy: usize,
12847    dx: usize,
12848    rows: u32,
12849    cols: u32,
12850    exclusive: bool,
12851    base: *mut u8,
12852) {
12853    let (rows, cols) = (rows as usize, cols as usize);
12854    let dys = sl(dy, base, rows * cols);
12855    let out = sl_mut(dx, base, rows * cols);
12856    for r in 0..rows {
12857        crate::training_bwd::cumsum_backward_row(
12858            &dys[r * cols..(r + 1) * cols],
12859            &mut out[r * cols..(r + 1) * cols],
12860            exclusive,
12861        );
12862    }
12863}
12864
12865pub unsafe fn execute_gather_backward_f32(
12866    dy: usize,
12867    indices: usize,
12868    dst: usize,
12869    outer: u32,
12870    axis_dim: u32,
12871    num_idx: u32,
12872    trailing: u32,
12873    base: *mut u8,
12874) {
12875    let (outer, axis_dim, num_idx, trailing) = (
12876        outer as usize,
12877        axis_dim as usize,
12878        num_idx as usize,
12879        trailing as usize,
12880    );
12881    let out = sl_mut(dst, base, outer * axis_dim * trailing);
12882    out.fill(0.0);
12883    crate::training_bwd::gather_axis_backward(
12884        sl(dy, base, outer * num_idx * trailing),
12885        sl(indices, base, num_idx),
12886        out,
12887        outer,
12888        axis_dim,
12889        num_idx,
12890        trailing,
12891    );
12892}
12893
12894/// Host-fallback entry for GGUF `Op::DequantMatMul` (Metal unified memory).
12895pub unsafe fn execute_dequant_matmul_gguf_f32(
12896    x: usize,
12897    w_q: usize,
12898    dst: usize,
12899    m: usize,
12900    k: usize,
12901    n: usize,
12902    scheme: rlx_ir::quant::QuantScheme,
12903    base: *mut u8,
12904) {
12905    unsafe {
12906        let block_bytes = scheme.gguf_block_bytes() as usize;
12907        let block_elems = scheme.gguf_block_size() as usize;
12908        let total_bytes = (k * n) / block_elems * block_bytes;
12909        let xs = sl(x, base, m * k);
12910        let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, total_bytes);
12911        let out = sl_mut(dst, base, m * n);
12912        crate::gguf_matmul::gguf_matmul_bt(xs, w_bytes, out, m, k, n, scheme);
12913    }
12914}
12915
12916/// Host-fallback entry for GGUF `Op::DequantGroupedMatMul` (MoE expert stack).
12917pub unsafe fn execute_dequant_grouped_matmul_gguf_f32(
12918    input: usize,
12919    w_q: usize,
12920    expert_idx: usize,
12921    dst: usize,
12922    m: usize,
12923    k: usize,
12924    n: usize,
12925    num_experts: usize,
12926    scheme: rlx_ir::quant::QuantScheme,
12927    base: *mut u8,
12928) {
12929    unsafe {
12930        let block_bytes = scheme.gguf_block_bytes() as usize;
12931        let block_elems = scheme.gguf_block_size() as usize;
12932        let slab_bytes = (k * n) / block_elems * block_bytes;
12933        let xs = sl(input, base, m * k);
12934        let w_bytes =
12935            std::slice::from_raw_parts(base.add(w_q) as *const u8, num_experts * slab_bytes);
12936        let ids = sl(expert_idx, base, m);
12937        let out = sl_mut(dst, base, m * n);
12938        crate::gguf_matmul::gguf_grouped_matmul_bt(
12939            xs,
12940            w_bytes,
12941            ids,
12942            out,
12943            m,
12944            k,
12945            n,
12946            num_experts,
12947            scheme,
12948        );
12949    }
12950}
12951
12952/// Host-fallback entry for Int4 `Op::DequantMatMul` (Metal unified memory).
12953pub unsafe fn execute_dequant_matmul_int4_f32(
12954    x: usize,
12955    w_q: usize,
12956    scale: usize,
12957    zp: usize,
12958    dst: usize,
12959    m: usize,
12960    k: usize,
12961    n: usize,
12962    block_size: u32,
12963    is_asymmetric: bool,
12964    base: *mut u8,
12965) {
12966    let bs = block_size as usize;
12967    let n_blocks = k.div_ceil(bs);
12968    unsafe {
12969        let xs = sl(x, base, m * k);
12970        let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, (k * n).div_ceil(2));
12971        let scales = sl(scale, base, n_blocks * n);
12972        let zps = if is_asymmetric {
12973            sl(zp, base, n_blocks * n)
12974        } else {
12975            &[][..]
12976        };
12977        let out = sl_mut(dst, base, m * n);
12978        dequant_matmul_int4(xs, w_bytes, scales, zps, out, m, k, n, bs, is_asymmetric);
12979    }
12980}
12981
12982/// Host-fallback entry for FP8 `Op::DequantMatMul` (Metal unified memory).
12983pub unsafe fn execute_dequant_matmul_fp8_f32(
12984    x: usize,
12985    w_q: usize,
12986    scale: usize,
12987    dst: usize,
12988    m: usize,
12989    k: usize,
12990    n: usize,
12991    e5m2: bool,
12992    base: *mut u8,
12993) {
12994    unsafe {
12995        let xs = sl(x, base, m * k);
12996        let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, k * n);
12997        let scales = sl(scale, base, n);
12998        let out = sl_mut(dst, base, m * n);
12999        dequant_matmul_fp8(xs, w_bytes, scales, out, m, k, n, e5m2);
13000    }
13001}
13002
13003/// Host-fallback entry for NVFP4 `Op::DequantMatMul` (Metal unified memory).
13004pub unsafe fn execute_dequant_matmul_nvfp4_f32(
13005    x: usize,
13006    w_q: usize,
13007    scale: usize,
13008    global_scale: usize,
13009    dst: usize,
13010    m: usize,
13011    k: usize,
13012    n: usize,
13013    base: *mut u8,
13014) {
13015    let n_scale = k.div_ceil(rlx_ir::NVFP4_GROUP_SIZE) * n;
13016    unsafe {
13017        let xs = sl(x, base, m * k);
13018        let w_bytes = std::slice::from_raw_parts(base.add(w_q) as *const u8, (k * n).div_ceil(2));
13019        let scale_bytes = std::slice::from_raw_parts(base.add(scale) as *const u8, n_scale);
13020        let gs = sl(global_scale, base, 1)[0];
13021        let out = sl_mut(dst, base, m * n);
13022        dequant_matmul_nvfp4(xs, w_bytes, scale_bytes, gs, out, m, k, n);
13023    }
13024}
13025
13026/// Host-fallback entry for f16 `Op::GatedDeltaNet` tensors on Metal.
13027pub unsafe fn execute_gated_delta_net_f16(
13028    q: usize,
13029    k: usize,
13030    v: usize,
13031    g: usize,
13032    beta: usize,
13033    state: usize,
13034    dst: usize,
13035    batch: usize,
13036    seq: usize,
13037    heads: usize,
13038    state_size: usize,
13039    base: *mut u8,
13040) {
13041    use half::f16;
13042    unsafe {
13043        let read_f16 = |off: usize, len: usize| -> Vec<f32> {
13044            let raw = std::slice::from_raw_parts(base.add(off) as *const u8, len * 2);
13045            raw.chunks_exact(2)
13046                .map(|c| f16::from_le_bytes([c[0], c[1]]).to_f32())
13047                .collect()
13048        };
13049        let write_f16 = |off: usize, data: &[f32]| {
13050            let out = std::slice::from_raw_parts_mut(base.add(off), data.len() * 2);
13051            for (i, &v) in data.iter().enumerate() {
13052                let le = f16::from_f32(v).to_le_bytes();
13053                out[i * 2] = le[0];
13054                out[i * 2 + 1] = le[1];
13055            }
13056        };
13057
13058        let (b, s, h, n) = (batch, seq, heads, state_size);
13059        let q_f = read_f16(q, b * s * h * n);
13060        let k_f = read_f16(k, b * s * h * n);
13061        let v_f = read_f16(v, b * s * h * n);
13062        let g_f = read_f16(g, b * s * h);
13063        let b_f = read_f16(beta, b * s * h);
13064        let mut state_f = if state != 0 {
13065            read_f16(state, b * h * n * n)
13066        } else {
13067            vec![0f32; b * h * n * n]
13068        };
13069        let mut out_f = vec![0f32; b * s * h * n];
13070        let scale = 1.0f32 / (n as f32).sqrt();
13071        let mut sk_buf = vec![0f32; n];
13072        let mut owned_state = vec![0f32; h * n * n];
13073
13074        for bi in 0..b {
13075            let state_slice: &mut [f32] = if state != 0 {
13076                let start = bi * h * n * n;
13077                &mut state_f[start..start + h * n * n]
13078            } else {
13079                owned_state.fill(0.0);
13080                &mut owned_state
13081            };
13082
13083            for ti in 0..s {
13084                let qkv_step_base = bi * s * h * n + ti * h * n;
13085                let gb_step_base = bi * s * h + ti * h;
13086
13087                for hi in 0..h {
13088                    let q_row = &q_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
13089                    let k_row = &k_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
13090                    let v_row = &v_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
13091                    let g_t = g_f[gb_step_base + hi];
13092                    let beta_t = b_f[gb_step_base + hi];
13093
13094                    let s_base = hi * n * n;
13095                    let s_mat = &mut state_slice[s_base..s_base + n * n];
13096
13097                    let g_exp = g_t.exp();
13098                    for st in s_mat.iter_mut() {
13099                        *st *= g_exp;
13100                    }
13101
13102                    for j in 0..n {
13103                        let mut acc = 0f32;
13104                        for i in 0..n {
13105                            acc += s_mat[i * n + j] * k_row[i];
13106                        }
13107                        sk_buf[j] = acc;
13108                    }
13109
13110                    for j in 0..n {
13111                        sk_buf[j] = (v_row[j] - sk_buf[j]) * beta_t;
13112                    }
13113
13114                    for i in 0..n {
13115                        let ki = k_row[i];
13116                        for j in 0..n {
13117                            s_mat[i * n + j] += ki * sk_buf[j];
13118                        }
13119                    }
13120
13121                    let out_row = &mut out_f[qkv_step_base + hi * n..qkv_step_base + (hi + 1) * n];
13122                    for j in 0..n {
13123                        let mut acc = 0f32;
13124                        for i in 0..n {
13125                            acc += s_mat[i * n + j] * q_row[i];
13126                        }
13127                        out_row[j] = acc * scale;
13128                    }
13129                }
13130            }
13131        }
13132
13133        write_f16(dst, &out_f);
13134        if state != 0 {
13135            write_f16(state, &state_f);
13136        }
13137    }
13138}
13139
13140/// Host fallback for NCHW group norm (Metal unified-memory arena).
13141pub unsafe fn execute_group_norm_nchw_f32(
13142    src: usize,
13143    g: usize,
13144    b: usize,
13145    dst: usize,
13146    n: usize,
13147    c: usize,
13148    h: usize,
13149    w: usize,
13150    num_groups: usize,
13151    eps: f32,
13152    base: *mut u8,
13153) {
13154    let plane = c * h * w;
13155    for ni in 0..n {
13156        let input = unsafe { sl(src + ni * plane * std::mem::size_of::<f32>(), base, plane) };
13157        let gamma = unsafe { sl(g, base, c) };
13158        let beta = unsafe { sl(b, base, c) };
13159        let output = unsafe { sl_mut(dst + ni * plane * std::mem::size_of::<f32>(), base, plane) };
13160        crate::kernels::group_norm_nchw(input, gamma, beta, output, 1, c, h, w, num_groups, eps);
13161    }
13162}
13163
13164/// Host fallback for NCHW LayerNorm2d (SAM / candle semantics).
13165pub unsafe fn execute_layer_norm2d_nchw_f32(
13166    src: usize,
13167    g: usize,
13168    b: usize,
13169    dst: usize,
13170    n: usize,
13171    c: usize,
13172    h: usize,
13173    w: usize,
13174    eps: f32,
13175    base: *mut u8,
13176) {
13177    let plane = c * h * w;
13178    unsafe {
13179        let input = sl(src, base, n * plane);
13180        let gamma = sl(g, base, c);
13181        let beta = sl(b, base, c);
13182        let output = sl_mut(dst, base, n * plane);
13183        crate::kernels::layer_norm2d_nchw(input, gamma, beta, output, n, c, h, w, eps);
13184    }
13185}
13186
13187/// Host fallback for NCHW ConvTranspose2d.
13188pub unsafe fn execute_conv_transpose2d_nchw_f32(
13189    src: usize,
13190    weight: usize,
13191    dst: usize,
13192    n: usize,
13193    c_in: usize,
13194    h: usize,
13195    w_in: usize,
13196    c_out: usize,
13197    h_out: usize,
13198    w_out: usize,
13199    kh: usize,
13200    kw: usize,
13201    sh: usize,
13202    sw: usize,
13203    ph: usize,
13204    pw: usize,
13205    dh: usize,
13206    dw: usize,
13207    groups: usize,
13208    base: *mut u8,
13209) {
13210    let in_elems = n * c_in * h * w_in;
13211    let w_elems = c_in * (c_out / groups) * kh * kw;
13212    let out_elems = n * c_out * h_out * w_out;
13213    unsafe {
13214        let input = sl(src, base, in_elems);
13215        let wt = sl(weight, base, w_elems);
13216        let output = sl_mut(dst, base, out_elems);
13217        crate::kernels::conv_transpose2d_nchw(
13218            input, wt, output, n, c_in, h, w_in, c_out, h_out, w_out, kh, kw, sh, sw, ph, pw, dh,
13219            dw, groups,
13220        );
13221    }
13222}
13223
13224/// Host fallback for nearest 2× upsample on NCHW.
13225pub unsafe fn execute_resize_nearest_2x_f32(
13226    src: usize,
13227    dst: usize,
13228    n: usize,
13229    c: usize,
13230    h: usize,
13231    w: usize,
13232    base: *mut u8,
13233) {
13234    let in_plane = c * h * w;
13235    let out_plane = c * h * 2 * w * 2;
13236    for ni in 0..n {
13237        let input = unsafe {
13238            sl(
13239                src + ni * in_plane * std::mem::size_of::<f32>(),
13240                base,
13241                in_plane,
13242            )
13243        };
13244        let output = unsafe {
13245            sl_mut(
13246                dst + ni * out_plane * std::mem::size_of::<f32>(),
13247                base,
13248                out_plane,
13249            )
13250        };
13251        crate::kernels::resize_nearest_2x_nchw(input, output, c, h, w);
13252    }
13253}
13254
13255/// Host axial 2-D RoPE for Metal (and other) fallbacks on unified memory.
13256pub unsafe fn execute_axial_rope2d_f32(
13257    src: usize,
13258    dst: usize,
13259    batch: usize,
13260    seq: usize,
13261    hidden: usize,
13262    end_x: usize,
13263    end_y: usize,
13264    head_dim: usize,
13265    num_heads: usize,
13266    theta: f32,
13267    repeat_factor: usize,
13268    base: *mut u8,
13269) {
13270    let plane = seq * hidden;
13271    let plane_bytes = plane * std::mem::size_of::<f32>();
13272    for bi in 0..batch {
13273        let in_off = src + bi * plane_bytes;
13274        let input = unsafe { sl(in_off, base, plane) };
13275        let rotated = rlx_ir::ops::axial_rope2d::apply_axial_rope2d(
13276            input,
13277            num_heads,
13278            seq,
13279            head_dim,
13280            end_x,
13281            end_y,
13282            theta,
13283            repeat_factor,
13284        );
13285        let out_off = dst + bi * plane_bytes;
13286        let output = unsafe { sl_mut(out_off, base, plane) };
13287        output.copy_from_slice(&rotated);
13288    }
13289}
13290
13291/// f32 mirror of `execute_fft1d_f64`. Same public-host-fallback role.
13292pub unsafe fn execute_fft1d_f32(
13293    src: usize,
13294    dst: usize,
13295    outer: usize,
13296    n_complex: usize,
13297    inverse: bool,
13298    norm_tag: u32,
13299    base: *mut u8,
13300) {
13301    let row_elems = 2 * n_complex;
13302    let mut re = vec![0f32; n_complex];
13303    let mut im = vec![0f32; n_complex];
13304    let norm = rlx_ir::fft::FftNorm::from_tag(norm_tag);
13305    let scale = norm.output_scale(n_complex, inverse) as f32;
13306    let mut scratch = if n_complex.is_power_of_two() || n_complex <= 16 {
13307        BluesteinScratchF32::empty()
13308    } else {
13309        BluesteinScratchF32::build(n_complex, inverse)
13310    };
13311    for o in 0..outer {
13312        let row_offset = src + o * row_elems * std::mem::size_of::<f32>();
13313        let s = unsafe { sl(row_offset, base, row_elems) };
13314        re.copy_from_slice(&s[..n_complex]);
13315        im.copy_from_slice(&s[n_complex..]);
13316        if n_complex.is_power_of_two() {
13317            fft_radix2_inplace_f32(&mut re, &mut im, inverse);
13318        } else if n_complex <= 16 {
13319            fft_naive_inplace_f32(&mut re, &mut im, inverse);
13320        } else {
13321            fft_bluestein_inplace_f32(&mut re, &mut im, inverse, &mut scratch);
13322        }
13323        if scale != 1.0 {
13324            re.iter_mut().for_each(|v| *v *= scale);
13325            im.iter_mut().for_each(|v| *v *= scale);
13326        }
13327        let dst_offset = dst + o * row_elems * std::mem::size_of::<f32>();
13328        let d = unsafe { sl_mut(dst_offset, base, row_elems) };
13329        d[..n_complex].copy_from_slice(&re);
13330        d[n_complex..].copy_from_slice(&im);
13331    }
13332}
13333
13334/// C64 interleaved layout: each complex element is `[re: f32, im: f32]`.
13335pub unsafe fn execute_fft1d_c64(
13336    src: usize,
13337    dst: usize,
13338    outer: usize,
13339    n_complex: usize,
13340    inverse: bool,
13341    norm_tag: u32,
13342    base: *mut u8,
13343) {
13344    let row_bytes = n_complex * 8;
13345    let mut re = vec![0f32; n_complex];
13346    let mut im = vec![0f32; n_complex];
13347    let norm = rlx_ir::fft::FftNorm::from_tag(norm_tag);
13348    let scale = norm.output_scale(n_complex, inverse) as f32;
13349    let mut scratch = if n_complex.is_power_of_two() || n_complex <= 16 {
13350        BluesteinScratchF32::empty()
13351    } else {
13352        BluesteinScratchF32::build(n_complex, inverse)
13353    };
13354    for o in 0..outer {
13355        let row_offset = src + o * row_bytes;
13356        for i in 0..n_complex {
13357            let elem_off = row_offset + i * 8;
13358            re[i] = f32::from_le_bytes([
13359                *base.add(elem_off),
13360                *base.add(elem_off + 1),
13361                *base.add(elem_off + 2),
13362                *base.add(elem_off + 3),
13363            ]);
13364            im[i] = f32::from_le_bytes([
13365                *base.add(elem_off + 4),
13366                *base.add(elem_off + 5),
13367                *base.add(elem_off + 6),
13368                *base.add(elem_off + 7),
13369            ]);
13370        }
13371        if n_complex.is_power_of_two() {
13372            fft_radix2_inplace_f32(&mut re, &mut im, inverse);
13373        } else if n_complex <= 16 {
13374            fft_naive_inplace_f32(&mut re, &mut im, inverse);
13375        } else {
13376            fft_bluestein_inplace_f32(&mut re, &mut im, inverse, &mut scratch);
13377        }
13378        if scale != 1.0 {
13379            re.iter_mut().for_each(|v| *v *= scale);
13380            im.iter_mut().for_each(|v| *v *= scale);
13381        }
13382        let dst_row = dst + o * row_bytes;
13383        for i in 0..n_complex {
13384            let elem_off = dst_row + i * 8;
13385            let re_b = re[i].to_le_bytes();
13386            let im_b = im[i].to_le_bytes();
13387            for j in 0..4 {
13388                *base.add(elem_off + j) = re_b[j];
13389                *base.add(elem_off + 4 + j) = im_b[j];
13390            }
13391        }
13392    }
13393}
13394
13395/// Dtype-dispatching host entry for `Op::Fft` (shared by GPU host fallbacks).
13396pub unsafe fn execute_fft1d(
13397    src: usize,
13398    dst: usize,
13399    outer: usize,
13400    n_complex: usize,
13401    inverse: bool,
13402    norm_tag: u32,
13403    dtype: rlx_ir::DType,
13404    base: *mut u8,
13405) {
13406    match dtype {
13407        rlx_ir::DType::F32 => {
13408            execute_fft1d_f32(src, dst, outer, n_complex, inverse, norm_tag, base)
13409        }
13410        rlx_ir::DType::F64 => {
13411            execute_fft1d_f64(src, dst, outer, n_complex, inverse, norm_tag, base)
13412        }
13413        rlx_ir::DType::C64 => {
13414            execute_fft1d_c64(src, dst, outer, n_complex, inverse, norm_tag, base)
13415        }
13416        other => panic!("execute_fft1d: unsupported dtype {other:?}"),
13417    }
13418}
13419
/// f32 in-place radix-2 DIT Cooley-Tukey on split (real, imag) arrays.
///
/// `re.len() == im.len()` must be a power of two. Forward uses
/// ω = exp(-2πi/n), inverse uses ω = exp(+2πi/n); neither direction applies a
/// 1/N scale. The twiddle recurrence runs in f64 and is rounded to f32 only
/// at the point of use, so rotation drift stays below the f32 butterfly error.
fn fft_radix2_inplace_f32(re: &mut [f32], im: &mut [f32], inverse: bool) {
    let n = re.len();
    debug_assert_eq!(im.len(), n);
    debug_assert!(
        n.is_power_of_two(),
        "fft_radix2_f32: n={n} must be a power of two"
    );
    if n <= 1 {
        return;
    }

    // Bit-reversal permutation: swap each index with its log2(n)-bit mirror.
    let unused_bits = usize::BITS - n.trailing_zeros();
    for i in 1..n {
        let mirrored = i.reverse_bits() >> unused_bits;
        if i < mirrored {
            re.swap(i, mirrored);
            im.swap(i, mirrored);
        }
    }

    // Butterfly passes over blocks of 2, 4, ..., n elements.
    let sign = if inverse { 1.0_f64 } else { -1.0_f64 };
    let mut span = 2usize;
    while span <= n {
        let half = span / 2;
        let angle = sign * 2.0 * std::f64::consts::PI / (span as f64);
        let (step_re, step_im) = (angle.cos(), angle.sin());

        for (blk_re, blk_im) in re.chunks_exact_mut(span).zip(im.chunks_exact_mut(span)) {
            let (lo_re, hi_re) = blk_re.split_at_mut(half);
            let (lo_im, hi_im) = blk_im.split_at_mut(half);
            // The twiddle restarts at 1+0i for every block.
            let (mut tw_re, mut tw_im) = (1.0_f64, 0.0_f64);
            for k in 0..half {
                let (c, s) = (tw_re as f32, tw_im as f32);
                let rot_re = c * hi_re[k] - s * hi_im[k];
                let rot_im = c * hi_im[k] + s * hi_re[k];
                let (a_re, a_im) = (lo_re[k], lo_im[k]);
                lo_re[k] = a_re + rot_re;
                lo_im[k] = a_im + rot_im;
                hi_re[k] = a_re - rot_re;
                hi_im[k] = a_im - rot_im;
                (tw_re, tw_im) = (
                    tw_re * step_re - tw_im * step_im,
                    tw_re * step_im + tw_im * step_re,
                );
            }
        }
        span <<= 1;
    }
}
13481
/// In-place radix-2 DIT Cooley-Tukey FFT on split (real, imag) f64
/// arrays. `n = re.len() = im.len()` must be a power of two. Forward
/// uses ω = exp(-2πi/n); inverse uses ω = exp(+2πi/n) (no 1/N scale).
fn fft_radix2_inplace_f64(re: &mut [f64], im: &mut [f64], inverse: bool) {
    let n = re.len();
    debug_assert_eq!(im.len(), n);
    debug_assert!(
        n.is_power_of_two(),
        "fft_radix2: n={n} must be a power of two"
    );
    if n <= 1 {
        return;
    }

    // Bit-reversal permutation: swap each index with its log2(n)-bit mirror.
    let unused_bits = usize::BITS - n.trailing_zeros();
    for i in 1..n {
        let mirrored = i.reverse_bits() >> unused_bits;
        if i < mirrored {
            re.swap(i, mirrored);
            im.swap(i, mirrored);
        }
    }

    // Butterfly passes with ω_span = exp(±2πi/span) for span = 2, 4, ..., n.
    let sign = if inverse { 1.0_f64 } else { -1.0_f64 };
    let mut span = 2usize;
    while span <= n {
        let half = span / 2;
        let angle = sign * 2.0 * std::f64::consts::PI / (span as f64);
        let (step_re, step_im) = (angle.cos(), angle.sin());

        for (blk_re, blk_im) in re.chunks_exact_mut(span).zip(im.chunks_exact_mut(span)) {
            let (lo_re, hi_re) = blk_re.split_at_mut(half);
            let (lo_im, hi_im) = blk_im.split_at_mut(half);
            // The twiddle restarts at 1+0i for every block.
            let (mut tw_re, mut tw_im) = (1.0_f64, 0.0_f64);
            for k in 0..half {
                let rot_re = tw_re * hi_re[k] - tw_im * hi_im[k];
                let rot_im = tw_re * hi_im[k] + tw_im * hi_re[k];
                let (a_re, a_im) = (lo_re[k], lo_im[k]);
                lo_re[k] = a_re + rot_re;
                lo_im[k] = a_im + rot_im;
                hi_re[k] = a_re - rot_re;
                hi_im[k] = a_im - rot_im;
                (tw_re, tw_im) = (
                    tw_re * step_re - tw_im * step_im,
                    tw_re * step_im + tw_im * step_re,
                );
            }
        }
        span <<= 1;
    }
}
13543
13544/// Pre-computed chirp + filter-spectrum for one (N, direction) pair.
13545/// Built once per call to `execute_fft1d_f64` and reused across rows
13546/// when `outer > 1` — the chirp and FFT(b) don't depend on the input.
13547struct BluesteinScratchF64 {
13548    /// Power-of-two convolution length, ≥ 2N - 1.
13549    m: usize,
13550    /// `w[k] = exp(sign · iπ · k² / N)` for k=0..N, where sign matches
13551    /// the requested direction. Forward chirp on the way in, output
13552    /// chirp on the way out.
13553    w_re: Vec<f64>,
13554    w_im: Vec<f64>,
13555    /// FFT of the embedded filter `b[k] = conj(w[|k|])` in length-M.
13556    /// Doesn't depend on the input — precomputed once.
13557    bf_re: Vec<f64>,
13558    bf_im: Vec<f64>,
13559    /// Workspace reused per row (avoids per-row allocation).
13560    ar: Vec<f64>,
13561    ai: Vec<f64>,
13562}
13563
13564impl BluesteinScratchF64 {
13565    fn empty() -> Self {
13566        Self {
13567            m: 0,
13568            w_re: Vec::new(),
13569            w_im: Vec::new(),
13570            bf_re: Vec::new(),
13571            bf_im: Vec::new(),
13572            ar: Vec::new(),
13573            ai: Vec::new(),
13574        }
13575    }
13576
13577    fn build(n: usize, inverse: bool) -> Self {
13578        // M = next power of two ≥ 2N - 1 keeps the inner FFT on the
13579        // fast radix-2 path. For N=1 fall back to M=1 (no-op convolution).
13580        let m = if n <= 1 {
13581            1
13582        } else {
13583            (2 * n - 1).next_power_of_two()
13584        };
13585
13586        // Chirp arg reduced via k² mod 2N — without this, large N
13587        // bleeds precision into the trig call (n² grows quadratically).
13588        let mod_2n = (2 * n) as u64;
13589        let sign = if inverse { 1.0_f64 } else { -1.0_f64 };
13590        let mut w_re = vec![0.0_f64; n];
13591        let mut w_im = vec![0.0_f64; n];
13592        for k in 0..n {
13593            let k2 = (k as u64).wrapping_mul(k as u64) % mod_2n;
13594            let theta = sign * std::f64::consts::PI * (k2 as f64) / (n as f64);
13595            w_re[k] = theta.cos();
13596            w_im[k] = theta.sin();
13597        }
13598
13599        // Embed b[k] = conj(w[|k|]) into length M with the negative
13600        // indices wrapping to the tail: b[-j] → B[M-j] for j=1..N-1.
13601        let mut bf_re = vec![0.0_f64; m];
13602        let mut bf_im = vec![0.0_f64; m];
13603        if n > 0 {
13604            bf_re[0] = w_re[0];
13605            bf_im[0] = -w_im[0];
13606            for k in 1..n {
13607                bf_re[k] = w_re[k];
13608                bf_im[k] = -w_im[k];
13609                bf_re[m - k] = w_re[k];
13610                bf_im[m - k] = -w_im[k];
13611            }
13612        }
13613        if m > 1 {
13614            fft_radix2_inplace_f64(&mut bf_re, &mut bf_im, false);
13615        }
13616
13617        Self {
13618            m,
13619            w_re,
13620            w_im,
13621            bf_re,
13622            bf_im,
13623            ar: vec![0.0_f64; m],
13624            ai: vec![0.0_f64; m],
13625        }
13626    }
13627}
13628
/// Direct O(N²) DFT for small non-pow2 N (faster than Bluestein setup).
///
/// Works on split (real, imag) arrays of equal length and overwrites them
/// with the unnormalized transform. Forward uses exp(-2πi·nk/N), inverse
/// uses exp(+2πi·nk/N). Lengths 0 and 1 are left untouched.
fn fft_naive_inplace_f64(re: &mut [f64], im: &mut [f64], inverse: bool) {
    let n = re.len();
    if n <= 1 {
        return;
    }
    let sign = if inverse { 1.0_f64 } else { -1.0_f64 };

    // Each output bin is accumulated separately; the inputs are only
    // overwritten once every bin has been computed.
    let (out_re, out_im): (Vec<f64>, Vec<f64>) = (0..n)
        .map(|k| {
            let mut acc_re = 0.0_f64;
            let mut acc_im = 0.0_f64;
            for (nn, &x_re) in re.iter().enumerate() {
                let x_im = im[nn];
                let theta =
                    sign * 2.0 * std::f64::consts::PI * (nn as f64) * (k as f64) / (n as f64);
                let c = theta.cos();
                let s = theta.sin();
                acc_re += x_re * c - x_im * s;
                acc_im += x_re * s + x_im * c;
            }
            (acc_re, acc_im)
        })
        .unzip();

    re.copy_from_slice(&out_re);
    im.copy_from_slice(&out_im);
}
13650
/// Direct O(N²) DFT on split (real, imag) f32 arrays, for small non-pow2 N.
///
/// Overwrites `re`/`im` with the unnormalized transform. Forward uses
/// exp(-2πi·nk/N), inverse uses exp(+2πi·nk/N). Lengths 0 and 1 are left
/// untouched.
///
/// Twiddles are computed in f64 and rounded to f32 only at the point of use,
/// matching the radix-2 and Bluestein f32 paths. Because exp(2πi·nk/N) depends
/// only on `nk mod N`, a table of N twiddles indexed by `(n·k) % N` covers
/// every pair. This keeps the trig argument below 2π and needs N trig
/// evaluations instead of N².
fn fft_naive_inplace_f32(re: &mut [f32], im: &mut [f32], inverse: bool) {
    let n = re.len();
    debug_assert_eq!(im.len(), n);
    if n <= 1 {
        return;
    }

    let sign = if inverse { 1.0_f64 } else { -1.0_f64 };
    let step = sign * 2.0 * std::f64::consts::PI / (n as f64);
    // twiddles[j] = exp(sign · 2πi · j / n), as (cos, sin) rounded from f64.
    let twiddles: Vec<(f32, f32)> = (0..n)
        .map(|j| {
            let theta = step * (j as f64);
            (theta.cos() as f32, theta.sin() as f32)
        })
        .collect();

    let mut out_re = vec![0.0_f32; n];
    let mut out_im = vec![0.0_f32; n];
    for k in 0..n {
        let mut acc_re = 0.0_f32;
        let mut acc_im = 0.0_f32;
        for nn in 0..n {
            // nn, k < n, so nn * k < n²; this path is only meant for small n.
            let (c, s) = twiddles[(nn * k) % n];
            acc_re += re[nn] * c - im[nn] * s;
            acc_im += re[nn] * s + im[nn] * c;
        }
        out_re[k] = acc_re;
        out_im[k] = acc_im;
    }
    re.copy_from_slice(&out_re);
    im.copy_from_slice(&out_im);
}
13671
13672/// Bluestein (chirp-z) FFT for arbitrary N. Identity used:
13673///   `n·k = (n² + k² - (k-n)²) / 2`
13674/// which lets the DFT be written as a linear convolution sandwiched
13675/// between two chirp multiplies:
13676///   `X[k] = w[k] · ((x·w) ⊛ conj(w))[k]`   where `w[n] = exp(±iπ·n²/N)`.
13677/// The convolution is computed via a length-M radix-2 FFT (M ≥ 2N-1).
13678/// Both directions stay unnormalized to match the radix-2 path, so the
13679/// chain rule keeps working without scaling.
13680fn fft_bluestein_inplace_f64(
13681    re: &mut [f64],
13682    im: &mut [f64],
13683    _inverse: bool,
13684    s: &mut BluesteinScratchF64,
13685) {
13686    let n = re.len();
13687    debug_assert_eq!(im.len(), n);
13688    debug_assert_eq!(s.w_re.len(), n);
13689    if n <= 1 {
13690        return;
13691    }
13692    let m = s.m;
13693
13694    // Pre-chirp: a[k] = x[k] · w[k], zero-padded to M.
13695    for k in 0..m {
13696        s.ar[k] = 0.0;
13697        s.ai[k] = 0.0;
13698    }
13699    for k in 0..n {
13700        s.ar[k] = re[k] * s.w_re[k] - im[k] * s.w_im[k];
13701        s.ai[k] = re[k] * s.w_im[k] + im[k] * s.w_re[k];
13702    }
13703
13704    // Length-M forward FFT of the padded chirped input.
13705    fft_radix2_inplace_f64(&mut s.ar, &mut s.ai, false);
13706
13707    // Pointwise product with FFT(b). Stored back into (ar, ai).
13708    for k in 0..m {
13709        let ar = s.ar[k];
13710        let ai = s.ai[k];
13711        let br = s.bf_re[k];
13712        let bi = s.bf_im[k];
13713        s.ar[k] = ar * br - ai * bi;
13714        s.ai[k] = ar * bi + ai * br;
13715    }
13716
13717    // Inverse FFT — radix-2 here is the unnormalized inverse, so we
13718    // divide by M to recover the true circular convolution.
13719    fft_radix2_inplace_f64(&mut s.ar, &mut s.ai, true);
13720    let inv_m = 1.0 / (m as f64);
13721
13722    // Post-chirp: X[k] = w[k] · Y[k] / M for k = 0..N.
13723    for k in 0..n {
13724        let yr = s.ar[k] * inv_m;
13725        let yi = s.ai[k] * inv_m;
13726        re[k] = yr * s.w_re[k] - yi * s.w_im[k];
13727        im[k] = yr * s.w_im[k] + yi * s.w_re[k];
13728    }
13729}
13730
13731/// f32 mirror of `BluesteinScratchF64`. Chirp is computed in f64 for
13732/// precision (same justification as the radix-2 f32 path: twiddles in
13733/// f64, butterflies in f32). The actual conv buffers are f32.
13734struct BluesteinScratchF32 {
13735    m: usize,
13736    w_re: Vec<f32>,
13737    w_im: Vec<f32>,
13738    bf_re: Vec<f32>,
13739    bf_im: Vec<f32>,
13740    ar: Vec<f32>,
13741    ai: Vec<f32>,
13742}
13743
13744impl BluesteinScratchF32 {
13745    fn empty() -> Self {
13746        Self {
13747            m: 0,
13748            w_re: Vec::new(),
13749            w_im: Vec::new(),
13750            bf_re: Vec::new(),
13751            bf_im: Vec::new(),
13752            ar: Vec::new(),
13753            ai: Vec::new(),
13754        }
13755    }
13756
13757    fn build(n: usize, inverse: bool) -> Self {
13758        let m = if n <= 1 {
13759            1
13760        } else {
13761            (2 * n - 1).next_power_of_two()
13762        };
13763
13764        let mod_2n = (2 * n) as u64;
13765        let sign = if inverse { 1.0_f64 } else { -1.0_f64 };
13766        let mut w_re = vec![0.0_f32; n];
13767        let mut w_im = vec![0.0_f32; n];
13768        for k in 0..n {
13769            let k2 = (k as u64).wrapping_mul(k as u64) % mod_2n;
13770            let theta = sign * std::f64::consts::PI * (k2 as f64) / (n as f64);
13771            w_re[k] = theta.cos() as f32;
13772            w_im[k] = theta.sin() as f32;
13773        }
13774
13775        let mut bf_re = vec![0.0_f32; m];
13776        let mut bf_im = vec![0.0_f32; m];
13777        if n > 0 {
13778            bf_re[0] = w_re[0];
13779            bf_im[0] = -w_im[0];
13780            for k in 1..n {
13781                bf_re[k] = w_re[k];
13782                bf_im[k] = -w_im[k];
13783                bf_re[m - k] = w_re[k];
13784                bf_im[m - k] = -w_im[k];
13785            }
13786        }
13787        if m > 1 {
13788            fft_radix2_inplace_f32(&mut bf_re, &mut bf_im, false);
13789        }
13790
13791        Self {
13792            m,
13793            w_re,
13794            w_im,
13795            bf_re,
13796            bf_im,
13797            ar: vec![0.0_f32; m],
13798            ai: vec![0.0_f32; m],
13799        }
13800    }
13801}
13802
13803fn fft_bluestein_inplace_f32(
13804    re: &mut [f32],
13805    im: &mut [f32],
13806    _inverse: bool,
13807    s: &mut BluesteinScratchF32,
13808) {
13809    let n = re.len();
13810    debug_assert_eq!(im.len(), n);
13811    debug_assert_eq!(s.w_re.len(), n);
13812    if n <= 1 {
13813        return;
13814    }
13815    let m = s.m;
13816
13817    for k in 0..m {
13818        s.ar[k] = 0.0;
13819        s.ai[k] = 0.0;
13820    }
13821    for k in 0..n {
13822        s.ar[k] = re[k] * s.w_re[k] - im[k] * s.w_im[k];
13823        s.ai[k] = re[k] * s.w_im[k] + im[k] * s.w_re[k];
13824    }
13825
13826    fft_radix2_inplace_f32(&mut s.ar, &mut s.ai, false);
13827
13828    for k in 0..m {
13829        let ar = s.ar[k];
13830        let ai = s.ai[k];
13831        let br = s.bf_re[k];
13832        let bi = s.bf_im[k];
13833        s.ar[k] = ar * br - ai * bi;
13834        s.ai[k] = ar * bi + ai * br;
13835    }
13836
13837    fft_radix2_inplace_f32(&mut s.ar, &mut s.ai, true);
13838    let inv_m = 1.0_f32 / (m as f32);
13839
13840    for k in 0..n {
13841        let yr = s.ar[k] * inv_m;
13842        let yi = s.ai[k] * inv_m;
13843        re[k] = yr * s.w_re[k] - yi * s.w_im[k];
13844        im[k] = yr * s.w_im[k] + yi * s.w_re[k];
13845    }
13846}
13847
13848/// Shared dispatch path for `Thunk::CustomOp`. Builds a typed
13849/// [`CpuTensorRef`] for each input *at that input's declared dtype*
13850/// (so a sparse-LU op with mixed F64/I32 inputs gets the right
13851/// typed slices) and a [`CpuTensorMut`] for the output, then calls
13852/// the kernel's single `execute` method.
13853unsafe fn dispatch_custom_op(
13854    kernel: &dyn crate::op_registry::CpuKernel,
13855    inputs: &[(usize, u32, Shape)],
13856    out_off: usize,
13857    out_len: u32,
13858    out_shape: &Shape,
13859    attrs: &[u8],
13860    base: *mut u8,
13861) {
13862    use crate::op_registry::{CpuTensorMut, CpuTensorRef};
13863    use rlx_ir::DType;
13864
13865    // One arm per `DType` variant — single source of truth for
13866    // "which dtypes the CPU custom-op dispatcher wires." If a new
13867    // DType lands in `rlx-ir`, the compiler flags this match as
13868    // non-exhaustive and the gap gets named at the right place.
13869    macro_rules! build_in_view {
13870        ($shape:expr, $off:expr, $n:expr, $variant:ident, $rust_ty:ty) => {
13871            CpuTensorRef::$variant {
13872                data: unsafe { sl_typed::<$rust_ty>($off, base, $n) },
13873                shape: $shape,
13874            }
13875        };
13876    }
13877    macro_rules! build_out_view {
13878        ($variant:ident, $rust_ty:ty) => {
13879            CpuTensorMut::$variant {
13880                data: unsafe { sl_mut_typed::<$rust_ty>(out_off, base, out_len as usize) },
13881                shape: out_shape,
13882            }
13883        };
13884    }
13885
13886    let in_views: Vec<CpuTensorRef<'_>> = inputs
13887        .iter()
13888        .map(|(off, len, shape)| {
13889            let n = *len as usize;
13890            let off = *off;
13891            match shape.dtype() {
13892                DType::F32 => build_in_view!(shape, off, n, F32, f32),
13893                DType::F64 => build_in_view!(shape, off, n, F64, f64),
13894                DType::F16 => build_in_view!(shape, off, n, F16, half::f16),
13895                DType::BF16 => build_in_view!(shape, off, n, BF16, half::bf16),
13896                DType::I8 => build_in_view!(shape, off, n, I8, i8),
13897                DType::I16 => build_in_view!(shape, off, n, I16, i16),
13898                DType::I32 => build_in_view!(shape, off, n, I32, i32),
13899                DType::I64 => build_in_view!(shape, off, n, I64, i64),
13900                DType::U8 => build_in_view!(shape, off, n, U8, u8),
13901                DType::U32 => build_in_view!(shape, off, n, U32, u32),
13902                DType::Bool => build_in_view!(shape, off, n, Bool, u8),
13903                // C64 isn't a CpuTensor variant today; the user-registered
13904                // op_registry path doesn't see complex inputs (those are
13905                // handled by built-in ops with dedicated kernels).
13906                DType::C64 => panic!(
13907                    "Op::Custom kernel input has DType::C64 — built-in \
13908                 complex ops handle their own kernels; user-registered \
13909                 ops don't yet see complex tensors"
13910                ),
13911            }
13912        })
13913        .collect();
13914
13915    let result = match out_shape.dtype() {
13916        DType::F32 => kernel.execute(&in_views, build_out_view!(F32, f32), attrs),
13917        DType::F64 => kernel.execute(&in_views, build_out_view!(F64, f64), attrs),
13918        DType::F16 => kernel.execute(&in_views, build_out_view!(F16, half::f16), attrs),
13919        DType::BF16 => kernel.execute(&in_views, build_out_view!(BF16, half::bf16), attrs),
13920        DType::I8 => kernel.execute(&in_views, build_out_view!(I8, i8), attrs),
13921        DType::I16 => kernel.execute(&in_views, build_out_view!(I16, i16), attrs),
13922        DType::I32 => kernel.execute(&in_views, build_out_view!(I32, i32), attrs),
13923        DType::I64 => kernel.execute(&in_views, build_out_view!(I64, i64), attrs),
13924        DType::U8 => kernel.execute(&in_views, build_out_view!(U8, u8), attrs),
13925        DType::U32 => kernel.execute(&in_views, build_out_view!(U32, u32), attrs),
13926        DType::Bool => kernel.execute(&in_views, build_out_view!(Bool, u8), attrs),
13927        DType::C64 => panic!("Op::Custom output DType::C64 not supported"),
13928    };
13929    if let Err(e) = result {
13930        panic!("Op::Custom('{}') CPU kernel failed: {e}", kernel.name());
13931    }
13932}
13933
/// Dtype-generic read-only view: `len` elements of `T` starting `offset`
/// bytes past `base`. An `offset` of `usize::MAX` yields an empty slice
/// without touching `base`.
///
/// The concrete `sl_*` / `sl_mut_*` helpers remain for call sites in
/// `thunk.rs` that already know their dtype; this generic form lets the
/// custom-op dispatcher cover every `DType` without one helper per type.
///
/// # Safety
///
/// For any other `offset`, `base + offset` must be aligned for `T` and
/// point at `len` initialised, readable `T`s. The `'static` lifetime is
/// not checked by the compiler; the caller must not use the slice after
/// the underlying memory goes away or is mutably aliased.
#[inline(always)]
unsafe fn sl_typed<T>(offset: usize, base: *mut u8, len: usize) -> &'static [T] {
    match offset {
        usize::MAX => &[],
        off => {
            let first = unsafe { base.add(off) }.cast::<T>();
            unsafe { std::slice::from_raw_parts(first, len) }
        }
    }
}
13946
/// Dtype-generic mutable view: `len` elements of `T` starting `offset`
/// bytes past `base`. Unlike `sl_typed`, there is no `usize::MAX`
/// sentinel handling here.
///
/// # Safety
///
/// `base + offset` must be aligned for `T`, point at `len` valid `T`s,
/// and not be aliased by any other live reference. The `'static`
/// lifetime is not checked by the compiler.
#[inline(always)]
unsafe fn sl_mut_typed<T>(offset: usize, base: *mut u8, len: usize) -> &'static mut [T] {
    let first = unsafe { base.add(offset) }.cast::<T>();
    unsafe { std::slice::from_raw_parts_mut(first, len) }
}
13951
13952// Unsafe helpers to create slices from arena base + offset
13953#[inline(always)]
13954/// In-place per-element activation. Mirrors the dispatch in
13955/// `Thunk::ActivationInPlace`. Used by `Thunk::FusedMmBiasAct` to
13956/// apply the activation after `bias_add` for all non-Gelu cases.
13957fn apply_activation_inplace(d: &mut [f32], act: rlx_ir::op::Activation) {
13958    use rlx_ir::op::Activation;
13959    match act {
13960        Activation::Gelu => crate::kernels::par_gelu_inplace(d),
13961        Activation::GeluApprox => crate::kernels::par_gelu_approx_inplace(d),
13962        Activation::Silu => crate::kernels::par_silu_inplace(d),
13963        Activation::Relu => {
13964            for v in d.iter_mut() {
13965                *v = v.max(0.0);
13966            }
13967        }
13968        Activation::Sigmoid => {
13969            for v in d.iter_mut() {
13970                *v = 1.0 / (1.0 + (-*v).exp());
13971            }
13972        }
13973        Activation::Tanh => {
13974            for v in d.iter_mut() {
13975                *v = v.tanh();
13976            }
13977        }
13978        Activation::Exp => {
13979            for v in d.iter_mut() {
13980                *v = v.exp();
13981            }
13982        }
13983        Activation::Log => {
13984            for v in d.iter_mut() {
13985                *v = v.ln();
13986            }
13987        }
13988        Activation::Sqrt => {
13989            for v in d.iter_mut() {
13990                *v = v.sqrt();
13991            }
13992        }
13993        Activation::Rsqrt => {
13994            for v in d.iter_mut() {
13995                *v = 1.0 / v.sqrt();
13996            }
13997        }
13998        Activation::Neg => {
13999            for v in d.iter_mut() {
14000                *v = -*v;
14001            }
14002        }
14003        Activation::Abs => {
14004            for v in d.iter_mut() {
14005                *v = v.abs();
14006            }
14007        }
14008        Activation::Round => {
14009            for v in d.iter_mut() {
14010                *v = v.round();
14011            }
14012        }
14013        Activation::Sin => {
14014            for v in d.iter_mut() {
14015                *v = v.sin();
14016            }
14017        }
14018        Activation::Cos => {
14019            for v in d.iter_mut() {
14020                *v = v.cos();
14021            }
14022        }
14023        Activation::Tan => {
14024            for v in d.iter_mut() {
14025                *v = v.tan();
14026            }
14027        }
14028        Activation::Atan => {
14029            for v in d.iter_mut() {
14030                *v = v.atan();
14031            }
14032        }
14033    }
14034}
14035
/// Unfold one image (a single batch + group slice) into its im2col
/// matrix.
///
/// `x` is the `[c_in, H, W]` row-major source. `col` is the
/// `[c_in · kH · kW, H_out · W_out]` row-major destination; every entry
/// of it is written, with taps that fall in the padded border set to 0.
///
/// `col[(ci · kH · kW + ki · kW + kj) · n_dim + ho · W_out + wo] =
///    x[ci, ho·sh + ki·dh − ph, wo·sw + kj·dw_dil − pw]`
#[allow(clippy::too_many_arguments)]
fn im2col(
    x: &[f32],
    col: &mut [f32],
    c_in: usize,
    h: usize,
    w: usize,
    h_out: usize,
    w_out: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    dh: usize,
    dw_dil: usize,
) {
    let n_dim = h_out * w_out;
    debug_assert_eq!(col.len(), c_in * kh * kw * n_dim);
    debug_assert_eq!(x.len(), c_in * h * w);

    // Signed copies so that taps in the padded border show up as
    // negative or too-large coordinates instead of wrapping.
    let (height, width) = (h as isize, w as isize);
    let (pad_top, pad_left) = (ph as isize, pw as isize);

    for ci in 0..c_in {
        for ki in 0..kh {
            for kj in 0..kw {
                // One kernel tap (ci, ki, kj) fills one row of `col`.
                let tap_base = (((ci * kh) + ki) * kw + kj) * n_dim;
                for ho in 0..h_out {
                    let line_start = tap_base + ho * w_out;
                    let line = &mut col[line_start..line_start + w_out];

                    let src_h = (ho * sh + ki * dh) as isize - pad_top;
                    if !(0..height).contains(&src_h) {
                        // The whole output row reads from vertical padding.
                        line.fill(0.0);
                        continue;
                    }

                    let src_row = (ci * h + src_h as usize) * w;
                    for (wo, slot) in line.iter_mut().enumerate() {
                        let src_w = (wo * sw + kj * dw_dil) as isize - pad_left;
                        *slot = if (0..width).contains(&src_w) {
                            x[src_row + src_w as usize]
                        } else {
                            0.0
                        };
                    }
                }
            }
        }
    }
}
14097
/// Scatter-accumulating inverse of `im2col`.
///
/// Entries of `col` are *added* into `x`, so the caller must zero `x`
/// beforehand if it does not already start at zero (the conv-input-grad
/// path zeros once before the batch loop).
///
/// `x[ci, hi, wi] += col[(ci · kH · kW + ki · kW + kj) · n_dim + ho · W_out + wo]`
/// for all `(ki, kj, ho, wo)` whose `(hi, wi)` lands in `[0, H) × [0, W)`;
/// taps that land in the padded border are dropped.
#[allow(clippy::too_many_arguments)]
fn col2im(
    col: &[f32],
    x: &mut [f32],
    c_in: usize,
    h: usize,
    w: usize,
    h_out: usize,
    w_out: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    dh: usize,
    dw_dil: usize,
) {
    let n_dim = h_out * w_out;
    debug_assert_eq!(col.len(), c_in * kh * kw * n_dim);
    debug_assert_eq!(x.len(), c_in * h * w);

    // Signed copies so that border taps become out-of-range coordinates.
    let (height, width) = (h as isize, w as isize);
    let (pad_top, pad_left) = (ph as isize, pw as isize);

    for ci in 0..c_in {
        for ki in 0..kh {
            for kj in 0..kw {
                let tap_base = (((ci * kh) + ki) * kw + kj) * n_dim;
                for ho in 0..h_out {
                    let dst_h = (ho * sh + ki * dh) as isize - pad_top;
                    if !(0..height).contains(&dst_h) {
                        continue;
                    }
                    let dst_row = (ci * h + dst_h as usize) * w;
                    let line_start = tap_base + ho * w_out;
                    for wo in 0..w_out {
                        let dst_w = (wo * sw + kj * dw_dil) as isize - pad_left;
                        if (0..width).contains(&dst_w) {
                            x[dst_row + dst_w as usize] += col[line_start + wo];
                        }
                    }
                }
            }
        }
    }
}
14153
14154/// Element-wise backward for `Op::Activation`. `xs` is the original
14155/// input to the forward activation; `dys` is the upstream gradient.
14156/// Writes `out[i] = (d/dx act(xs[i])) * dys[i]`.
14157/// Decompose a per-channel quantization shape into the
14158/// `(chan_axis, chan_dim, inner)` triplet the kernel needs to map a
14159/// flat output index to a channel index. Per-tensor (`axis = None`)
14160/// degenerates to `chan_dim = 1, inner = len`, which makes the
14161/// kernel's `(i / inner) % chan_dim` always 0 — same fast path the
14162/// scalar version used.
14163fn quant_layout(shape: &rlx_ir::Shape, axis: Option<usize>) -> (usize, usize, usize) {
14164    match axis {
14165        None => (0, 1, shape.num_elements().unwrap_or(0).max(1)),
14166        Some(d) => {
14167            let chan_dim = shape.dim(d).unwrap_static();
14168            let inner: usize = (d + 1..shape.rank())
14169                .map(|i| shape.dim(i).unwrap_static())
14170                .product::<usize>()
14171                .max(1);
14172            (d, chan_dim, inner)
14173        }
14174    }
14175}
14176
14177fn activation_backward_kernel(
14178    act: rlx_ir::op::Activation,
14179    xs: &[f32],
14180    dys: &[f32],
14181    out: &mut [f32],
14182) {
14183    use rlx_ir::op::Activation;
14184    let n = xs.len();
14185    debug_assert_eq!(dys.len(), n);
14186    debug_assert_eq!(out.len(), n);
14187    match act {
14188        Activation::Relu => {
14189            for i in 0..n {
14190                out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
14191            }
14192        }
14193        Activation::Sigmoid => {
14194            for i in 0..n {
14195                let s = 1.0 / (1.0 + (-xs[i]).exp());
14196                out[i] = s * (1.0 - s) * dys[i];
14197            }
14198        }
14199        Activation::Tanh => {
14200            for i in 0..n {
14201                let t = xs[i].tanh();
14202                out[i] = (1.0 - t * t) * dys[i];
14203            }
14204        }
14205        Activation::Silu => {
14206            // y = x * σ(x);  dy/dx = σ(x) * (1 + x * (1 - σ(x))).
14207            for i in 0..n {
14208                let s = 1.0 / (1.0 + (-xs[i]).exp());
14209                out[i] = s * (1.0 + xs[i] * (1.0 - s)) * dys[i];
14210            }
14211        }
14212        Activation::Gelu => {
14213            // Exact erf-based GELU:  y = 0.5 x (1 + erf(x / √2)).
14214            //   dy/dx = 0.5 (1 + erf(x/√2)) + (x / √(2π)) · exp(-x²/2)
14215            const INV_SQRT2: f32 = 0.707_106_77;
14216            const INV_SQRT_2PI: f32 = 0.398_942_3;
14217            for i in 0..n {
14218                let x = xs[i];
14219                let phi = 0.5 * (1.0 + erf_f32(x * INV_SQRT2));
14220                let pdf = INV_SQRT_2PI * (-(x * x) * 0.5).exp();
14221                out[i] = (phi + x * pdf) * dys[i];
14222            }
14223        }
14224        Activation::GeluApprox => {
14225            // Tanh-approximation:
14226            //   y = 0.5 x (1 + tanh(c · (x + 0.044715 x³))) where c = √(2/π).
14227            const C: f32 = 0.797_884_6; // √(2/π)
14228            const A: f32 = 0.044_715;
14229            for i in 0..n {
14230                let x = xs[i];
14231                let inner = C * (x + A * x * x * x);
14232                let t = inner.tanh();
14233                let dinner = C * (1.0 + 3.0 * A * x * x);
14234                let d = 0.5 * (1.0 + t) + 0.5 * x * (1.0 - t * t) * dinner;
14235                out[i] = d * dys[i];
14236            }
14237        }
14238        Activation::Exp => {
14239            for i in 0..n {
14240                out[i] = xs[i].exp() * dys[i];
14241            }
14242        }
14243        Activation::Log => {
14244            for i in 0..n {
14245                out[i] = dys[i] / xs[i];
14246            }
14247        }
14248        Activation::Sqrt => {
14249            // d/dx √x = 0.5 / √x — undefined at x=0; clamp to 0.
14250            for i in 0..n {
14251                let s = xs[i].sqrt();
14252                out[i] = if s > 0.0 { 0.5 * dys[i] / s } else { 0.0 };
14253            }
14254        }
14255        Activation::Rsqrt => {
14256            // d/dx (1/√x) = -0.5 · x^(-3/2).
14257            for i in 0..n {
14258                let s = xs[i].sqrt();
14259                out[i] = if s > 0.0 {
14260                    -0.5 * dys[i] / (xs[i] * s)
14261                } else {
14262                    0.0
14263                };
14264            }
14265        }
14266        Activation::Neg => {
14267            for i in 0..n {
14268                out[i] = -dys[i];
14269            }
14270        }
14271        Activation::Abs => {
14272            // sign(x); 0 at x=0.
14273            for i in 0..n {
14274                let x = xs[i];
14275                let s = if x > 0.0 {
14276                    1.0
14277                } else if x < 0.0 {
14278                    -1.0
14279                } else {
14280                    0.0
14281                };
14282                out[i] = s * dys[i];
14283            }
14284        }
14285        Activation::Round => {
14286            // STE: pretend the round was identity in the backward
14287            // pass. The round step has zero gradient almost
14288            // everywhere, so without this trick the optimizer can't
14289            // learn through it.
14290            out.copy_from_slice(dys);
14291        }
14292        Activation::Sin => {
14293            // d/dx sin(x) = cos(x).
14294            for i in 0..n {
14295                out[i] = xs[i].cos() * dys[i];
14296            }
14297        }
14298        Activation::Cos => {
14299            for i in 0..n {
14300                out[i] = -xs[i].sin() * dys[i];
14301            }
14302        }
14303        Activation::Tan => {
14304            // d/dx tan(x) = sec²(x) = 1 + tan²(x)
14305            for i in 0..n {
14306                let t = xs[i].tan();
14307                out[i] = (1.0 + t * t) * dys[i];
14308            }
14309        }
14310        Activation::Atan => {
14311            // d/dx atan(x) = 1 / (1 + x²)
14312            for i in 0..n {
14313                let x = xs[i];
14314                out[i] = dys[i] / (1.0 + x * x);
14315            }
14316        }
14317    }
14318}
14319
14320/// f64 sibling of `activation_backward_kernel`. Same math, twice the
14321/// precision — used by f64 graphs where the f32 kernel reading bytes
14322/// as `&[f32]` would silently discard half of every f64 value.
14323fn activation_backward_kernel_f64(
14324    act: rlx_ir::op::Activation,
14325    xs: &[f64],
14326    dys: &[f64],
14327    out: &mut [f64],
14328) {
14329    use rlx_ir::op::Activation;
14330    let n = xs.len();
14331    debug_assert_eq!(dys.len(), n);
14332    debug_assert_eq!(out.len(), n);
14333    match act {
14334        Activation::Relu => {
14335            for i in 0..n {
14336                out[i] = if xs[i] > 0.0 { dys[i] } else { 0.0 };
14337            }
14338        }
14339        Activation::Sigmoid => {
14340            for i in 0..n {
14341                let s = 1.0 / (1.0 + (-xs[i]).exp());
14342                out[i] = s * (1.0 - s) * dys[i];
14343            }
14344        }
14345        Activation::Tanh => {
14346            for i in 0..n {
14347                let t = xs[i].tanh();
14348                out[i] = (1.0 - t * t) * dys[i];
14349            }
14350        }
14351        Activation::Silu => {
14352            for i in 0..n {
14353                let s = 1.0 / (1.0 + (-xs[i]).exp());
14354                out[i] = s * (1.0 + xs[i] * (1.0 - s)) * dys[i];
14355            }
14356        }
14357        Activation::Gelu | Activation::GeluApprox => {
14358            // Both rare on f64 paths; use the high-quality libm erf.
14359            const INV_SQRT2: f64 = std::f64::consts::FRAC_1_SQRT_2;
14360            const INV_SQRT_2PI: f64 = 0.398_942_280_401_432_7;
14361            for i in 0..n {
14362                let x = xs[i];
14363                let phi = 0.5 * (1.0 + erf_f64(x * INV_SQRT2));
14364                let pdf = INV_SQRT_2PI * (-(x * x) * 0.5).exp();
14365                out[i] = (phi + x * pdf) * dys[i];
14366            }
14367        }
14368        Activation::Exp => {
14369            for i in 0..n {
14370                out[i] = xs[i].exp() * dys[i];
14371            }
14372        }
14373        Activation::Log => {
14374            for i in 0..n {
14375                out[i] = dys[i] / xs[i];
14376            }
14377        }
14378        Activation::Sqrt => {
14379            for i in 0..n {
14380                let s = xs[i].sqrt();
14381                out[i] = if s > 0.0 { 0.5 * dys[i] / s } else { 0.0 };
14382            }
14383        }
14384        Activation::Rsqrt => {
14385            for i in 0..n {
14386                let s = xs[i].sqrt();
14387                out[i] = if s > 0.0 {
14388                    -0.5 * dys[i] / (xs[i] * s)
14389                } else {
14390                    0.0
14391                };
14392            }
14393        }
14394        Activation::Neg => {
14395            for i in 0..n {
14396                out[i] = -dys[i];
14397            }
14398        }
14399        Activation::Abs => {
14400            for i in 0..n {
14401                let x = xs[i];
14402                let s = if x > 0.0 {
14403                    1.0
14404                } else if x < 0.0 {
14405                    -1.0
14406                } else {
14407                    0.0
14408                };
14409                out[i] = s * dys[i];
14410            }
14411        }
14412        Activation::Round => {
14413            out.copy_from_slice(dys);
14414        }
14415        Activation::Sin => {
14416            for i in 0..n {
14417                out[i] = xs[i].cos() * dys[i];
14418            }
14419        }
14420        Activation::Cos => {
14421            for i in 0..n {
14422                out[i] = -xs[i].sin() * dys[i];
14423            }
14424        }
14425        Activation::Tan => {
14426            for i in 0..n {
14427                let t = xs[i].tan();
14428                out[i] = (1.0 + t * t) * dys[i];
14429            }
14430        }
14431        Activation::Atan => {
14432            for i in 0..n {
14433                let x = xs[i];
14434                out[i] = dys[i] / (1.0 + x * x);
14435            }
14436        }
14437    }
14438}
14439
/// Error function at f64 width using the Abramowitz & Stegun 7.1.26
/// rational/polynomial fit — the same formula as `erf_f32`.
///
/// Accuracy is bounded by the polynomial (max error ~1.5e-7), not by the
/// arithmetic. That is adequate for gradient kernels; swap in a libm
/// dependency if more precision is ever required.
#[inline(always)]
fn erf_f64(x: f64) -> f64 {
    const P: f64 = 0.327_591_1;
    // Polynomial coefficients, highest degree first, with signs folded in
    // so the evaluation is a plain Horner recurrence.
    const LEAD: f64 = 1.061_405_43;
    const REST: [f64; 4] = [-1.453_152_03, 1.421_413_75, -0.284_496_74, 0.254_829_59];

    let sign = x.signum();
    let mag = x.abs();
    let t = 1.0 / (1.0 + P * mag);
    let poly = REST.iter().fold(LEAD, |acc, &c| acc * t + c);
    // erf is odd: evaluate on |x| and restore the sign afterwards.
    sign * (1.0 - poly * t * (-mag * mag).exp())
}
14456
/// Inexpensive f32 error function (Abramowitz & Stegun 7.1.26). The
/// maximum error is ~1.5e-7 over all of ℝ, which is ample for f32
/// gradient kernels.
#[inline(always)]
fn erf_f32(x: f32) -> f32 {
    const P: f32 = 0.327_591_1;
    // Polynomial coefficients, highest degree first, with signs folded in
    // so the evaluation is a plain Horner recurrence.
    const LEAD: f32 = 1.061_405_4;
    const REST: [f32; 4] = [-1.453_152_1, 1.421_413_8, -0.284_496_74, 0.254_829_6];

    let sign = x.signum();
    let mag = x.abs();
    let t = 1.0 / (1.0 + P * mag);
    let poly = REST.iter().fold(LEAD, |acc, &c| acc * t + c);
    // erf is odd: evaluate on |x| and restore the sign afterwards.
    sign * (1.0 - poly * t * (-mag * mag).exp())
}
14471
14472fn narrow_thunk_closure(
14473    src: usize,
14474    dst: usize,
14475    outer: u32,
14476    src_stride: u32,
14477    dst_stride: u32,
14478    inner: u32,
14479    elem_bytes: u8,
14480) -> Arc<dyn Fn(*mut u8) + Send + Sync> {
14481    let (outer, ss, ds, inner) = (
14482        outer as usize,
14483        src_stride as usize,
14484        dst_stride as usize,
14485        inner as usize,
14486    );
14487    if elem_bytes == 8 {
14488        Arc::new(move |base: *mut u8| unsafe {
14489            let s = sl_f64(src, base, outer * ss);
14490            let d = sl_mut_f64(dst, base, outer * ds);
14491            for o in 0..outer {
14492                d[o * ds..o * ds + inner].copy_from_slice(&s[o * ss..o * ss + inner]);
14493            }
14494        })
14495    } else {
14496        Arc::new(move |base: *mut u8| unsafe {
14497            let s = sl(src, base, outer * ss);
14498            let d = sl_mut(dst, base, outer * ds);
14499            for o in 0..outer {
14500                d[o * ds..o * ds + inner].copy_from_slice(&s[o * ss..o * ss + inner]);
14501            }
14502        })
14503    }
14504}
14505
// Unsafe helpers to create slices from arena base + offset.

/// Read-only `f32` view of `len` elements starting `offset` bytes past
/// `base`. An `offset` of `usize::MAX` yields an empty slice without
/// touching `base`.
///
/// # Safety
///
/// For any other `offset`, `base + offset` must be 4-byte aligned and
/// point at `len` initialised, readable `f32`s. The `'static` lifetime
/// is not checked by the compiler; the caller must not use the slice
/// after the underlying memory goes away or is mutably aliased.
#[inline(always)]
unsafe fn sl(offset: usize, base: *mut u8, len: usize) -> &'static [f32] {
    if offset == usize::MAX {
        return &[];
    }
    unsafe { std::slice::from_raw_parts(base.add(offset) as *const f32, len) }
}
14512
/// Mutable `f32` view of `len` elements starting `offset` bytes past
/// `base`. There is no `usize::MAX` sentinel handling on this path.
///
/// # Safety
///
/// `base + offset` must be 4-byte aligned, point at `len` valid `f32`s,
/// and not be aliased by any other live reference.
#[inline(always)]
unsafe fn sl_mut(offset: usize, base: *mut u8, len: usize) -> &'static mut [f32] {
    let first = unsafe { base.add(offset) }.cast::<f32>();
    unsafe { std::slice::from_raw_parts_mut(first, len) }
}
14517
/// Read-only `f64` view of `len` elements starting `offset` bytes past
/// `base`; `offset == usize::MAX` yields an empty slice.
///
/// # Safety
///
/// For any other `offset`, `base + offset` must be 8-byte aligned and
/// point at `len` initialised, readable `f64`s.
#[inline(always)]
unsafe fn sl_f64(offset: usize, base: *mut u8, len: usize) -> &'static [f64] {
    if offset != usize::MAX {
        let first = unsafe { base.add(offset) }.cast::<f64>();
        unsafe { std::slice::from_raw_parts(first, len) }
    } else {
        &[]
    }
}
14525
/// Mutable `f64` view of `len` elements starting `offset` bytes past
/// `base`. There is no `usize::MAX` sentinel handling on this path.
///
/// # Safety
///
/// `base + offset` must be 8-byte aligned, point at `len` valid `f64`s,
/// and not be aliased by any other live reference.
#[inline(always)]
unsafe fn sl_mut_f64(offset: usize, base: *mut u8, len: usize) -> &'static mut [f64] {
    let first = unsafe { base.add(offset) }.cast::<f64>();
    unsafe { std::slice::from_raw_parts_mut(first, len) }
}
14530
// i32 / i64 typed slice helpers — siblings of `sl` / `sl_f64`. They are
// kept for integer-tensor thunks that haven't landed yet (Sample, Gather
// index buffers); deleting them now would mean re-deriving the unsafe
// boilerplate when the next int-typed thunk arrives.

/// Read-only `i32` view of `len` elements starting `offset` bytes past
/// `base`; `offset == usize::MAX` yields an empty slice.
///
/// # Safety
///
/// For any other `offset`, `base + offset` must be 4-byte aligned and
/// point at `len` initialised, readable `i32`s.
#[allow(dead_code)]
#[inline(always)]
unsafe fn sl_i32(offset: usize, base: *mut u8, len: usize) -> &'static [i32] {
    match offset {
        usize::MAX => &[],
        off => {
            let first = unsafe { base.add(off) }.cast::<i32>();
            unsafe { std::slice::from_raw_parts(first, len) }
        }
    }
}
14543
/// Mutable `i32` view of `len` elements starting `offset` bytes past
/// `base`. There is no `usize::MAX` sentinel handling on this path.
///
/// # Safety
///
/// `base + offset` must be 4-byte aligned, point at `len` valid `i32`s,
/// and not be aliased by any other live reference.
#[allow(dead_code)]
#[inline(always)]
unsafe fn sl_mut_i32(offset: usize, base: *mut u8, len: usize) -> &'static mut [i32] {
    let first = unsafe { base.add(offset) }.cast::<i32>();
    unsafe { std::slice::from_raw_parts_mut(first, len) }
}
14549
/// Read-only `i64` view of `len` elements starting `offset` bytes past
/// `base`; `offset == usize::MAX` yields an empty slice.
///
/// # Safety
///
/// For any other `offset`, `base + offset` must be 8-byte aligned and
/// point at `len` initialised, readable `i64`s.
#[allow(dead_code)]
#[inline(always)]
unsafe fn sl_i64(offset: usize, base: *mut u8, len: usize) -> &'static [i64] {
    if offset != usize::MAX {
        let first = unsafe { base.add(offset) }.cast::<i64>();
        unsafe { std::slice::from_raw_parts(first, len) }
    } else {
        &[]
    }
}
14558
/// Mutable `i64` view of `len` elements starting `offset` bytes past
/// `base`. There is no `usize::MAX` sentinel handling on this path.
///
/// # Safety
///
/// `base + offset` must be 8-byte aligned, point at `len` valid `i64`s,
/// and not be aliased by any other live reference.
#[allow(dead_code)]
#[inline(always)]
unsafe fn sl_mut_i64(offset: usize, base: *mut u8, len: usize) -> &'static mut [i64] {
    let first = unsafe { base.add(offset) }.cast::<i64>();
    unsafe { std::slice::from_raw_parts_mut(first, len) }
}
14564
/// Strided N-D gather at f64 width, shared by Transpose and Expand.
///
/// `out` is filled in row-major order over `out_dims`. `in_strides[d]`
/// is the source stride (in elements) that goes with output dimension
/// `d`; broadcast axes use stride 0 so every index along them reads the
/// same source element.
fn transpose_walk_f64(inp: &[f64], out: &mut [f64], out_dims: &[u32], in_strides: &[u32]) {
    let rank = out_dims.len();
    // Multi-index of the output element currently being written.
    let mut coord = vec![0u32; rank];
    for slot in out.iter_mut() {
        let src: usize = coord
            .iter()
            .zip(&in_strides[..rank])
            .map(|(&c, &stride)| c as usize * stride as usize)
            .sum();
        *slot = inp[src];

        // Odometer step: bump the fastest (last) dimension and carry
        // into slower ones whenever a digit wraps.
        for (c, &extent) in coord.iter_mut().zip(out_dims).rev() {
            *c += 1;
            if *c < extent {
                break;
            }
            *c = 0;
        }
    }
}
14587
14588/// f64 elementwise activation. Reads `inp`, writes `out`. For now
14589/// covers what the autodiff-emitted gradient graph needs (Neg, Exp,
14590/// Log, Sqrt, Rsqrt, Abs, Tanh, Sigmoid, Relu — the
14591/// transcendental-free subset). Approximate Gelu/Silu deferred until a
14592/// workload demands them at f64.
14593fn apply_activation_f64(inp: &[f64], out: &mut [f64], kind: Activation) {
14594    match kind {
14595        Activation::Neg => {
14596            for (o, &v) in out.iter_mut().zip(inp) {
14597                *o = -v;
14598            }
14599        }
14600        Activation::Exp => {
14601            for (o, &v) in out.iter_mut().zip(inp) {
14602                *o = v.exp();
14603            }
14604        }
14605        Activation::Log => {
14606            for (o, &v) in out.iter_mut().zip(inp) {
14607                *o = v.ln();
14608            }
14609        }
14610        Activation::Sqrt => {
14611            for (o, &v) in out.iter_mut().zip(inp) {
14612                *o = v.sqrt();
14613            }
14614        }
14615        Activation::Rsqrt => {
14616            for (o, &v) in out.iter_mut().zip(inp) {
14617                *o = 1.0 / v.sqrt();
14618            }
14619        }
14620        Activation::Abs => {
14621            for (o, &v) in out.iter_mut().zip(inp) {
14622                *o = v.abs();
14623            }
14624        }
14625        Activation::Tanh => {
14626            for (o, &v) in out.iter_mut().zip(inp) {
14627                *o = v.tanh();
14628            }
14629        }
14630        Activation::Sigmoid => {
14631            for (o, &v) in out.iter_mut().zip(inp) {
14632                *o = 1.0 / (1.0 + (-v).exp());
14633            }
14634        }
14635        Activation::Relu => {
14636            for (o, &v) in out.iter_mut().zip(inp) {
14637                *o = v.max(0.0);
14638            }
14639        }
14640        Activation::Round => {
14641            for (o, &v) in out.iter_mut().zip(inp) {
14642                *o = v.round_ties_even();
14643            }
14644        }
14645        Activation::Sin => {
14646            for (o, &v) in out.iter_mut().zip(inp) {
14647                *o = v.sin();
14648            }
14649        }
14650        Activation::Cos => {
14651            for (o, &v) in out.iter_mut().zip(inp) {
14652                *o = v.cos();
14653            }
14654        }
14655        Activation::Tan => {
14656            for (o, &v) in out.iter_mut().zip(inp) {
14657                *o = v.tan();
14658            }
14659        }
14660        Activation::Atan => {
14661            for (o, &v) in out.iter_mut().zip(inp) {
14662                *o = v.atan();
14663            }
14664        }
14665        Activation::Gelu | Activation::GeluApprox | Activation::Silu => {
14666            panic!(
14667                "apply_activation_f64: {kind:?} not yet implemented at f64. \
14668                    Add when a workload needs it."
14669            );
14670        }
14671    }
14672}
14673
14674#[inline]
14675fn binary_op_f64(op: BinaryOp, a: f64, b: f64) -> f64 {
14676    match op {
14677        BinaryOp::Add => a + b,
14678        BinaryOp::Sub => a - b,
14679        BinaryOp::Mul => a * b,
14680        BinaryOp::Div => a / b,
14681        BinaryOp::Max => a.max(b),
14682        BinaryOp::Min => a.min(b),
14683        BinaryOp::Pow => a.powf(b),
14684    }
14685}
14686
/// Sums an f64 tensor over its middle axis.
///
/// `inp` is indexed as a row-major `[outer, reduced, inner]` tensor and
/// `out` as `[outer, inner]`. Each output element is the sum of the
/// `reduced` inputs that share its outer/inner coordinates, accumulated
/// in increasing `reduced` order starting from `0.0`. Elements of `out`
/// beyond `outer * inner` are left untouched; indexing panics if either
/// slice is too short for the given extents.
fn reduce_sum_f64(inp: &[f64], out: &mut [f64], outer: usize, reduced: usize, inner: usize) {
    for outer_idx in 0..outer {
        for inner_idx in 0..inner {
            // Explicit 0.0 seed keeps the result identical to a manual
            // `acc = 0.0; acc += ...` loop, including for an empty axis.
            out[outer_idx * inner + inner_idx] = (0..reduced)
                .map(|red_idx| inp[outer_idx * reduced * inner + red_idx * inner + inner_idx])
                .fold(0.0_f64, |acc, term| acc + term);
        }
    }
}
14700
14701#[cfg(test)]
14702mod tests {
14703    use super::*;
14704    use rlx_ir::*;
14705
14706    /// Plan #45: when a Narrow's only consumer is a Rope, the thunk
14707    /// fusion pass collapses them — the Narrow becomes Nop, and the
14708    /// Rope reads from the parent buffer with its row stride. This
14709    /// test runs the unfused path (batch*seq > FusedAttnBlock
14710    /// threshold) and asserts the rewrite happened.
14711    #[test]
14712    fn narrow_rope_fuses_in_unfused_path() {
14713        let f = DType::F32;
14714        let mut g = Graph::new("nr_fuse");
14715        // Force batch*seq > 64 so FusedAttnBlock doesn't pre-empt us.
14716        let qkv = g.input("qkv", Shape::new(&[16, 8, 192], f)); // 16*8=128 > 64
14717        let cos = g.input("cos", Shape::new(&[16], f));
14718        let sin = g.input("sin", Shape::new(&[16], f));
14719        // Last-axis narrow: Q = qkv[..., 0..64]
14720        let q = g.narrow_(qkv, 2, 0, 64);
14721        let q_rope = g.rope(q, cos, sin, 16);
14722        g.set_outputs(vec![q_rope]);
14723
14724        let plan = rlx_opt::memory::plan_memory(&g);
14725        let arena = crate::arena::Arena::from_plan(plan);
14726        let sched = compile_thunks(&g, &arena);
14727
14728        let mut narrow_count = 0;
14729        let mut rope_with_stride: Option<u32> = None;
14730        for t in &sched.thunks {
14731            match t {
14732                Thunk::Narrow { .. } => narrow_count += 1,
14733                Thunk::Rope { src_row_stride, .. } => rope_with_stride = Some(*src_row_stride),
14734                _ => {}
14735            }
14736        }
14737        // After fusion the Narrow is gone; only the Rope remains, and
14738        // it now walks with the parent QKV's row stride (3 * 64 = 192).
14739        assert_eq!(
14740            narrow_count, 0,
14741            "Narrow→Rope fusion should leave zero Narrow thunks; saw {narrow_count}"
14742        );
14743        assert_eq!(
14744            rope_with_stride,
14745            Some(192),
14746            "Rope's src_row_stride should be 192 (parent qkv axis), saw {rope_with_stride:?}"
14747        );
14748    }
14749
14750    /// Plan #15: SSM selective scan matches a naive Python-style
14751    /// Python-style sequential reference.
14752    #[test]
14753    fn ssm_selective_scan_matches_reference() {
14754        use rlx_ir::Philox4x32;
14755        let bch = 1usize;
14756        let s = 4usize;
14757        let h = 3usize;
14758        let n = 2usize;
14759
14760        let mut rng = Philox4x32::new(13);
14761        let mut x = vec![0f32; bch * s * h];
14762        rng.fill_normal(&mut x);
14763        let mut delta = vec![0f32; bch * s * h];
14764        // Keep Δ small so exp(Δ·A) doesn't blow up.
14765        for v in delta.iter_mut() {
14766            *v = (rng.next_f32() - 0.5) * 0.1;
14767        }
14768        let mut a = vec![0f32; h * n];
14769        for v in a.iter_mut() {
14770            *v = -(rng.next_f32() * 0.5 + 0.1);
14771        } // negative for stability
14772        let mut b = vec![0f32; bch * s * n];
14773        rng.fill_normal(&mut b);
14774        let mut c = vec![0f32; bch * s * n];
14775        rng.fill_normal(&mut c);
14776
14777        // Reference scan.
14778        let mut expected = vec![0f32; bch * s * h];
14779        for bi in 0..bch {
14780            let mut state = vec![0f32; h * n];
14781            for si in 0..s {
14782                for ci in 0..h {
14783                    let d = delta[bi * s * h + si * h + ci];
14784                    let xv = x[bi * s * h + si * h + ci];
14785                    let mut acc = 0f32;
14786                    for ni in 0..n {
14787                        let da = (d * a[ci * n + ni]).exp();
14788                        state[ci * n + ni] =
14789                            da * state[ci * n + ni] + d * b[bi * s * n + si * n + ni] * xv;
14790                        acc += c[bi * s * n + si * n + ni] * state[ci * n + ni];
14791                    }
14792                    expected[bi * s * h + si * h + ci] = acc;
14793                }
14794            }
14795        }
14796
14797        // RLX path.
14798        let f = DType::F32;
14799        let mut g = Graph::new("ssm");
14800        let xn = g.input("x", Shape::new(&[bch, s, h], f));
14801        let dn = g.input("delta", Shape::new(&[bch, s, h], f));
14802        let an = g.param("a", Shape::new(&[h, n], f));
14803        let bn = g.param("b", Shape::new(&[bch, s, n], f));
14804        let cn = g.param("c", Shape::new(&[bch, s, n], f));
14805        let yn = g.selective_scan(xn, dn, an, bn, cn, n, Shape::new(&[bch, s, h], f));
14806        g.set_outputs(vec![yn]);
14807
14808        let plan = rlx_opt::memory::plan_memory(&g);
14809        let mut arena = crate::arena::Arena::from_plan(plan);
14810        let sched = compile_thunks(&g, &arena);
14811
14812        let xn_off = arena.byte_offset(xn);
14813        let dn_off = arena.byte_offset(dn);
14814        let an_off = arena.byte_offset(an);
14815        let bn_off = arena.byte_offset(bn);
14816        let cn_off = arena.byte_offset(cn);
14817        let yn_off = arena.byte_offset(yn);
14818        let buf = arena.raw_buf_mut();
14819        unsafe {
14820            let copy = |dst: *mut f32, data: &[f32]| {
14821                for (i, &v) in data.iter().enumerate() {
14822                    *dst.add(i) = v;
14823                }
14824            };
14825            copy(buf.as_mut_ptr().add(xn_off) as *mut f32, &x);
14826            copy(buf.as_mut_ptr().add(dn_off) as *mut f32, &delta);
14827            copy(buf.as_mut_ptr().add(an_off) as *mut f32, &a);
14828            copy(buf.as_mut_ptr().add(bn_off) as *mut f32, &b);
14829            copy(buf.as_mut_ptr().add(cn_off) as *mut f32, &c);
14830        }
14831        execute_thunks(&sched, arena.raw_buf_mut());
14832
14833        let actual: Vec<f32> = unsafe {
14834            let p = arena.raw_buf().as_ptr().add(yn_off) as *const f32;
14835            (0..bch * s * h).map(|i| *p.add(i)).collect()
14836        };
14837
14838        for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
14839            assert!(
14840                (e - a).abs() < 1e-3,
14841                "mismatch at {i}: expected {e}, got {a}"
14842            );
14843        }
14844    }
14845
14846    /// Plan #26: 1×1 conv lowers to per-batch sgemm and matches the
14847    /// scalar 7-loop reference.
14848    #[test]
14849    fn conv_1x1_fast_path_matches_scalar() {
14850        use rlx_ir::Philox4x32;
14851        // [N=2, C_in=4, H=3, W=3]
14852        let n = 2usize;
14853        let c_in = 4usize;
14854        let h = 3usize;
14855        let w = 3usize;
14856        let c_out = 5usize;
14857        let mut rng = Philox4x32::new(31);
14858        let mut x = vec![0f32; n * c_in * h * w];
14859        rng.fill_normal(&mut x);
14860        let mut weight = vec![0f32; c_out * c_in];
14861        rng.fill_normal(&mut weight);
14862
14863        // Reference: scalar 1×1 conv = per-batch matmul
14864        // out[ni, co, hi, wi] = sum_ci weight[co, ci] * x[ni, ci, hi, wi]
14865        let mut expected = vec![0f32; n * c_out * h * w];
14866        for ni in 0..n {
14867            for co in 0..c_out {
14868                for hi in 0..h {
14869                    for wi in 0..w {
14870                        let mut acc = 0f32;
14871                        for ci in 0..c_in {
14872                            acc += weight[co * c_in + ci]
14873                                * x[((ni * c_in) + ci) * h * w + hi * w + wi];
14874                        }
14875                        expected[((ni * c_out) + co) * h * w + hi * w + wi] = acc;
14876                    }
14877                }
14878            }
14879        }
14880
14881        // RLX path: build a graph with Op::Conv (kernel=[1,1], stride=[1,1], etc).
14882        let f = DType::F32;
14883        let mut g = Graph::new("conv1x1");
14884        let xn = g.input("x", Shape::new(&[n, c_in, h, w], f));
14885        let wn = g.param("w", Shape::new(&[c_out, c_in, 1, 1], f));
14886        // Manually add Op::Conv since there's no `g.conv()` helper.
14887        let cn = g.add_node(
14888            rlx_ir::Op::Conv {
14889                kernel_size: vec![1, 1],
14890                stride: vec![1, 1],
14891                padding: vec![0, 0],
14892                dilation: vec![1, 1],
14893                groups: 1,
14894            },
14895            vec![xn, wn],
14896            Shape::new(&[n, c_out, h, w], f),
14897        );
14898        g.set_outputs(vec![cn]);
14899
14900        let plan = rlx_opt::memory::plan_memory(&g);
14901        let mut arena = crate::arena::Arena::from_plan(plan);
14902        let sched = compile_thunks(&g, &arena);
14903
14904        // Verify the fast path was selected.
14905        let saw_fast = sched
14906            .thunks
14907            .iter()
14908            .any(|t| matches!(t, Thunk::Conv2D1x1 { .. }));
14909        let saw_slow = sched
14910            .thunks
14911            .iter()
14912            .any(|t| matches!(t, Thunk::Conv2D { .. }));
14913        assert!(saw_fast, "1×1 conv should emit Conv2D1x1");
14914        assert!(!saw_slow, "1×1 conv must not fall through to scalar Conv2D");
14915
14916        let xn_off = arena.byte_offset(xn);
14917        let wn_off = arena.byte_offset(wn);
14918        let cn_off = arena.byte_offset(cn);
14919        let buf = arena.raw_buf_mut();
14920        unsafe {
14921            let xp = buf.as_mut_ptr().add(xn_off) as *mut f32;
14922            for (i, &v) in x.iter().enumerate() {
14923                *xp.add(i) = v;
14924            }
14925            let wp = buf.as_mut_ptr().add(wn_off) as *mut f32;
14926            for (i, &v) in weight.iter().enumerate() {
14927                *wp.add(i) = v;
14928            }
14929        }
14930        execute_thunks(&sched, arena.raw_buf_mut());
14931
14932        let actual: Vec<f32> = unsafe {
14933            let p = arena.raw_buf().as_ptr().add(cn_off) as *const f32;
14934            (0..(n * c_out * h * w)).map(|i| *p.add(i)).collect()
14935        };
14936
14937        for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
14938            assert!(
14939                (e - a).abs() < 1e-3,
14940                "mismatch at {i}: expected {e}, got {a}"
14941            );
14942        }
14943    }
14944
14945    /// Plan #5: fused dequant matmul matches the dequant-then-matmul
14946    /// reference (i.e. `(scale * (q - z)) @ x` materialized).
14947    #[test]
14948    fn dequant_matmul_int8_sym_matches_reference() {
14949        use rlx_ir::Philox4x32;
14950        use rlx_ir::quant::QuantScheme;
14951
14952        let m = 3usize;
14953        let k = 8usize;
14954        let n = 4usize;
14955        let block_size = 4usize; // 2 blocks per column
14956        let blocks_per_col = k / block_size;
14957
14958        // Random inputs: x f32, w_q i8, scales f32. Symmetric → no zp.
14959        let mut rng = Philox4x32::new(99);
14960        let mut x = vec![0f32; m * k];
14961        rng.fill_normal(&mut x);
14962        let w_q: Vec<i8> = (0..(k * n))
14963            .map(|i| ((i as i32 * 13 + 7) % 127 - 63) as i8)
14964            .collect();
14965        let scales: Vec<f32> = (0..(blocks_per_col * n))
14966            .map(|i| 0.01 + 0.001 * i as f32)
14967            .collect();
14968
14969        // Reference: build f32 weights from (q * scale) per block.
14970        let mut w_f32 = vec![0f32; k * n];
14971        for p in 0..k {
14972            let block = p / block_size;
14973            for j in 0..n {
14974                let s = scales[block * n + j];
14975                w_f32[p * n + j] = w_q[p * n + j] as f32 * s;
14976            }
14977        }
14978        let mut expected = vec![0f32; m * n];
14979        for i in 0..m {
14980            for j in 0..n {
14981                let mut acc = 0f32;
14982                for p in 0..k {
14983                    acc += x[i * k + p] * w_f32[p * n + j];
14984                }
14985                expected[i * n + j] = acc;
14986            }
14987        }
14988
14989        // RLX path.
14990        let f = DType::F32;
14991        let mut g = Graph::new("dq");
14992        let xn = g.input("x", Shape::new(&[m, k], f));
14993        let wn = g.param("w", Shape::new(&[k, n], DType::I8));
14994        let sn = g.param("scale", Shape::new(&[blocks_per_col, n], f));
14995        let zn = g.param("zp", Shape::new(&[blocks_per_col, n], f)); // unused (sym)
14996        let dq = g.dequant_matmul(
14997            xn,
14998            wn,
14999            sn,
15000            zn,
15001            QuantScheme::Int8Block {
15002                block_size: block_size as u32,
15003            },
15004            Shape::new(&[m, n], f),
15005        );
15006        g.set_outputs(vec![dq]);
15007
15008        let plan = rlx_opt::memory::plan_memory(&g);
15009        let mut arena = crate::arena::Arena::from_plan(plan);
15010        let sched = compile_thunks(&g, &arena);
15011
15012        let xn_off = arena.byte_offset(xn);
15013        let wn_off = arena.byte_offset(wn);
15014        let sn_off = arena.byte_offset(sn);
15015        let zn_off = arena.byte_offset(zn);
15016        let dq_off = arena.byte_offset(dq);
15017        let buf = arena.raw_buf_mut();
15018        unsafe {
15019            // Seed f32 inputs.
15020            let xp = buf.as_mut_ptr().add(xn_off) as *mut f32;
15021            for (i, &v) in x.iter().enumerate() {
15022                *xp.add(i) = v;
15023            }
15024            let sp = buf.as_mut_ptr().add(sn_off) as *mut f32;
15025            for (i, &v) in scales.iter().enumerate() {
15026                *sp.add(i) = v;
15027            }
15028            let zp = buf.as_mut_ptr().add(zn_off) as *mut f32;
15029            for i in 0..(blocks_per_col * n) {
15030                *zp.add(i) = 0.0;
15031            }
15032            // Seed i8 weights byte-by-byte.
15033            let wp = buf.as_mut_ptr().add(wn_off) as *mut i8;
15034            for (i, &v) in w_q.iter().enumerate() {
15035                *wp.add(i) = v;
15036            }
15037        }
15038        execute_thunks(&sched, arena.raw_buf_mut());
15039
15040        let actual: Vec<f32> = unsafe {
15041            let p = arena.raw_buf().as_ptr().add(dq_off) as *const f32;
15042            (0..m * n).map(|i| *p.add(i)).collect()
15043        };
15044
15045        for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
15046            assert!(
15047                (e - a).abs() < 1e-3,
15048                "mismatch at {i}: expected {e}, got {a}"
15049            );
15050        }
15051    }
15052
15053    /// Plan #9: LoRA matmul matches the unfused 3-matmul reference.
15054    #[test]
15055    fn lora_matmul_matches_unfused_reference() {
15056        use rlx_ir::Philox4x32;
15057
15058        let m = 4usize;
15059        let k = 8usize;
15060        let n = 6usize;
15061        let r = 2usize;
15062        let scale = 0.5f32;
15063
15064        // Random inputs (deterministic via Philox).
15065        let mut rng = Philox4x32::new(42);
15066        let mut x = vec![0f32; m * k];
15067        rng.fill_normal(&mut x);
15068        let mut w = vec![0f32; k * n];
15069        rng.fill_normal(&mut w);
15070        let mut a = vec![0f32; k * r];
15071        rng.fill_normal(&mut a);
15072        let mut b = vec![0f32; r * n];
15073        rng.fill_normal(&mut b);
15074
15075        // Reference: out = x·W + scale * x·A·B. Naive triple-loop.
15076        let naive = |a_buf: &[f32], b_buf: &[f32], rows: usize, inner: usize, cols: usize| {
15077            let mut o = vec![0f32; rows * cols];
15078            for i in 0..rows {
15079                for j in 0..cols {
15080                    let mut acc = 0f32;
15081                    for p in 0..inner {
15082                        acc += a_buf[i * inner + p] * b_buf[p * cols + j];
15083                    }
15084                    o[i * cols + j] = acc;
15085                }
15086            }
15087            o
15088        };
15089        let xw = naive(&x, &w, m, k, n);
15090        let xa = naive(&x, &a, m, k, r);
15091        let xab = naive(&xa, &b, m, r, n);
15092        let mut expected = xw;
15093        for i in 0..(m * n) {
15094            expected[i] += scale * xab[i];
15095        }
15096
15097        // RLX path: build a graph with one LoraMatMul.
15098        let f = DType::F32;
15099        let mut g = Graph::new("lora");
15100        let xn = g.input("x", Shape::new(&[m, k], f));
15101        let wn = g.param("w", Shape::new(&[k, n], f));
15102        let an = g.param("a", Shape::new(&[k, r], f));
15103        let bn = g.param("b", Shape::new(&[r, n], f));
15104        let lm = g.lora_matmul(xn, wn, an, bn, scale, Shape::new(&[m, n], f));
15105        g.set_outputs(vec![lm]);
15106
15107        let plan = rlx_opt::memory::plan_memory(&g);
15108        let mut arena = crate::arena::Arena::from_plan(plan);
15109        let sched = compile_thunks(&g, &arena);
15110
15111        let xn_off = arena.byte_offset(xn);
15112        let wn_off = arena.byte_offset(wn);
15113        let an_off = arena.byte_offset(an);
15114        let bn_off = arena.byte_offset(bn);
15115        let lm_off = arena.byte_offset(lm);
15116        let buf = arena.raw_buf_mut();
15117        unsafe {
15118            let copy = |dst: *mut f32, data: &[f32]| {
15119                for (i, &v) in data.iter().enumerate() {
15120                    *dst.add(i) = v;
15121                }
15122            };
15123            copy(buf.as_mut_ptr().add(xn_off) as *mut f32, &x);
15124            copy(buf.as_mut_ptr().add(wn_off) as *mut f32, &w);
15125            copy(buf.as_mut_ptr().add(an_off) as *mut f32, &a);
15126            copy(buf.as_mut_ptr().add(bn_off) as *mut f32, &b);
15127        }
15128        execute_thunks(&sched, arena.raw_buf_mut());
15129
15130        let actual: Vec<f32> = unsafe {
15131            let p = arena.raw_buf().as_ptr().add(lm_off) as *const f32;
15132            (0..m * n).map(|i| *p.add(i)).collect()
15133        };
15134
15135        for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
15136            assert!(
15137                (e - a).abs() < 1e-3,
15138                "mismatch at {i}: expected {e}, got {a}"
15139            );
15140        }
15141    }
15142
15143    /// Plan #42: fused sampling kernel determinism + greedy fallback.
15144    #[test]
15145    fn sample_temperature_zero_is_argmax() {
15146        // Very low temperature → distribution collapses on argmax.
15147        // Same seed → same output bit-for-bit.
15148        let f = DType::F32;
15149        let mut g = Graph::new("samp");
15150        let logits = g.input("logits", Shape::new(&[1, 8], f));
15151        let s = g.sample(logits, 0, 1.0, 1e-3, 42, Shape::new(&[1], f));
15152        g.set_outputs(vec![s]);
15153        let plan = rlx_opt::memory::plan_memory(&g);
15154        let mut arena = crate::arena::Arena::from_plan(plan);
15155        let sched = compile_thunks(&g, &arena);
15156
15157        let logits_off = arena.byte_offset(logits);
15158        let s_off = arena.byte_offset(s);
15159        let buf = arena.raw_buf_mut();
15160        unsafe {
15161            let p = buf.as_mut_ptr().add(logits_off) as *mut f32;
15162            // argmax = index 5 (value 9.0).
15163            let inputs = [0.1f32, 0.2, 0.3, 0.4, 0.5, 9.0, 0.7, 0.8];
15164            for (i, &v) in inputs.iter().enumerate() {
15165                *p.add(i) = v;
15166            }
15167        }
15168        execute_thunks(&sched, arena.raw_buf_mut());
15169
15170        let token = unsafe {
15171            let p = arena.raw_buf().as_ptr().add(s_off) as *const f32;
15172            *p as usize
15173        };
15174        assert_eq!(token, 5, "low-temp sampling should pick the argmax");
15175    }
15176
15177    #[test]
15178    fn sample_top_k_one_is_deterministic() {
15179        // top_k=1 forces only the argmax to have nonzero probability.
15180        let f = DType::F32;
15181        let mut g = Graph::new("samp_k1");
15182        let logits = g.input("logits", Shape::new(&[1, 4], f));
15183        let s = g.sample(logits, 1, 1.0, 1.0, 7, Shape::new(&[1], f));
15184        g.set_outputs(vec![s]);
15185        let plan = rlx_opt::memory::plan_memory(&g);
15186        let mut arena = crate::arena::Arena::from_plan(plan);
15187        let sched = compile_thunks(&g, &arena);
15188
15189        let logits_off = arena.byte_offset(logits);
15190        let s_off = arena.byte_offset(s);
15191        let buf = arena.raw_buf_mut();
15192        unsafe {
15193            let p = buf.as_mut_ptr().add(logits_off) as *mut f32;
15194            let inputs = [0.1f32, 5.0, 0.3, 0.4]; // argmax = 1
15195            for (i, &v) in inputs.iter().enumerate() {
15196                *p.add(i) = v;
15197            }
15198        }
15199        execute_thunks(&sched, arena.raw_buf_mut());
15200        let token = unsafe {
15201            let p = arena.raw_buf().as_ptr().add(s_off) as *const f32;
15202            *p as usize
15203        };
15204        assert_eq!(token, 1);
15205    }
15206
15207    /// Plan #44: cumsum primitive parity vs. naive scan.
15208    #[test]
15209    fn cumsum_inclusive_matches_naive() {
15210        let f = DType::F32;
15211        let mut g = Graph::new("cumsum");
15212        let x = g.input("x", Shape::new(&[2, 4], f));
15213        let cs = g.cumsum(x, -1, false, Shape::new(&[2, 4], f));
15214        g.set_outputs(vec![cs]);
15215        let plan = rlx_opt::memory::plan_memory(&g);
15216        let mut arena = crate::arena::Arena::from_plan(plan);
15217        let sched = compile_thunks(&g, &arena);
15218
15219        // Cache offsets up-front so we can drop the immutable borrow.
15220        let x_off = arena.byte_offset(x);
15221        let out_off = arena.byte_offset(cs);
15222        let buf = arena.raw_buf_mut();
15223        unsafe {
15224            let p = buf.as_mut_ptr().add(x_off) as *mut f32;
15225            let inputs = [1.0f32, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0];
15226            for (i, &v) in inputs.iter().enumerate() {
15227                *p.add(i) = v;
15228            }
15229        }
15230        execute_thunks(&sched, arena.raw_buf_mut());
15231
15232        let out: Vec<f32> = unsafe {
15233            let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
15234            (0..8).map(|i| *p.add(i)).collect()
15235        };
15236        assert_eq!(out, vec![1.0, 3.0, 6.0, 10.0, 10.0, 30.0, 60.0, 100.0]);
15237    }
15238
15239    /// Plan #46 deep: Narrow×3 → Attention fusion. The three QKV
15240    /// narrows that BERT/Nomic emit on the unfused (batch*seq > 64)
15241    /// path collapse into a single strided-Attention thunk.
15242    #[test]
15243    fn narrow_attention_fuses_in_unfused_path() {
15244        let f = DType::F32;
15245        let mut g = Graph::new("nattn_fuse");
15246        // batch*seq = 8*16 = 128 > 64 so FusedAttnBlock skips.
15247        let qkv = g.input("qkv", Shape::new(&[8, 16, 192], f)); // 3*64 = 192
15248        let mask = g.input("mask", Shape::new(&[8, 16], f));
15249        let q = g.narrow_(qkv, 2, 0, 64);
15250        let k = g.narrow_(qkv, 2, 64, 64);
15251        let v = g.narrow_(qkv, 2, 128, 64);
15252        let attn = g.attention(q, k, v, mask, 4, 16, Shape::new(&[8, 16, 64], f));
15253        g.set_outputs(vec![attn]);
15254
15255        let plan = rlx_opt::memory::plan_memory(&g);
15256        let arena = crate::arena::Arena::from_plan(plan);
15257        let sched = compile_thunks(&g, &arena);
15258
15259        let mut narrow_count = 0;
15260        let mut attn_strides: Option<(u32, u32, u32)> = None;
15261        for t in &sched.thunks {
15262            match t {
15263                Thunk::Narrow { .. } => narrow_count += 1,
15264                Thunk::Attention {
15265                    q_row_stride,
15266                    k_row_stride,
15267                    v_row_stride,
15268                    ..
15269                } => attn_strides = Some((*q_row_stride, *k_row_stride, *v_row_stride)),
15270                _ => {}
15271            }
15272        }
15273        // After fusion the 3 narrows are gone; Attention now walks the
15274        // QKV with parent row stride = 192 (3 × 64) on all three inputs.
15275        assert_eq!(
15276            narrow_count, 0,
15277            "Narrow×3→Attention fusion should eliminate all 3 narrows; saw {narrow_count}"
15278        );
15279        assert_eq!(
15280            attn_strides,
15281            Some((192, 192, 192)),
15282            "Attention should walk Q/K/V with parent row stride 192"
15283        );
15284    }
15285
15286    // ── Backward / training op parity tests ────────────────────
15287    //
15288    // Strategy: build a graph that contains exactly the backward op
15289    // under test (plus its inputs as graph Inputs), execute, and
15290    // compare against a hand-rolled scalar reference. For
15291    // Conv2dBackwardInput we additionally check against the numerical
15292    // gradient of the forward Conv2D — that's the gold-standard test
15293    // that validates the math, not just consistency between two
15294    // implementations of the same formula.
15295
15296    fn run_graph(
15297        g: &Graph,
15298        inputs: &[(NodeId, &[f32])],
15299        out_id: NodeId,
15300        out_len: usize,
15301    ) -> Vec<f32> {
15302        let plan = rlx_opt::memory::plan_memory(g);
15303        let mut arena = crate::arena::Arena::from_plan(plan);
15304        let sched = compile_thunks(g, &arena);
15305        for &(id, data) in inputs {
15306            let off = arena.byte_offset(id);
15307            let buf = arena.raw_buf_mut();
15308            unsafe {
15309                let p = buf.as_mut_ptr().add(off) as *mut f32;
15310                for (i, &v) in data.iter().enumerate() {
15311                    *p.add(i) = v;
15312                }
15313            }
15314        }
15315        execute_thunks(&sched, arena.raw_buf_mut());
15316        let off = arena.byte_offset(out_id);
15317        unsafe {
15318            let p = arena.raw_buf().as_ptr().add(off) as *const f32;
15319            (0..out_len).map(|i| *p.add(i)).collect()
15320        }
15321    }
15322
15323    #[test]
15324    fn relu_backward_matches_mask() {
15325        let f = DType::F32;
15326        let len = 7usize;
15327        let x: Vec<f32> = vec![-2.0, -0.1, 0.0, 0.1, 1.0, 3.0, -5.0];
15328        let dy: Vec<f32> = vec![0.5, 1.5, 2.5, -0.7, 4.0, -1.0, 9.0];
15329
15330        let mut g = Graph::new("relu_bw");
15331        let xn = g.input("x", Shape::new(&[len], f));
15332        let dyn_ = g.input("dy", Shape::new(&[len], f));
15333        let dx = g.relu_backward(xn, dyn_);
15334        g.set_outputs(vec![dx]);
15335
15336        let actual = run_graph(&g, &[(xn, &x), (dyn_, &dy)], dx, len);
15337        // Reference: gradient is dy where x>0 strictly, else 0.
15338        // (zero is not "positive" — the forward applied max(0, x), and at
15339        // x=0 the subgradient could be anything in [0, dy]; we pick 0.)
15340        let expected: Vec<f32> = x
15341            .iter()
15342            .zip(&dy)
15343            .map(|(&xi, &dyi)| if xi > 0.0 { dyi } else { 0.0 })
15344            .collect();
15345        for (a, e) in actual.iter().zip(&expected) {
15346            assert!((a - e).abs() < 1e-6, "relu_bw mismatch: {a} vs {e}");
15347        }
15348    }
15349
15350    #[test]
15351    fn maxpool2d_backward_routes_to_argmax() {
15352        let f = DType::F32;
15353        // [N=1, C=1, H=4, W=4] → 2x2 max-pool stride 2 → [1,1,2,2].
15354        let x: Vec<f32> = vec![
15355            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
15356        ];
15357        // Argmax of each 2x2 window:
15358        //   (0,0)→6 (idx 5), (0,1)→8 (idx 7),
15359        //   (1,0)→14(idx 13),(1,1)→16(idx 15).
15360        let dy: Vec<f32> = vec![0.5, 1.0, 2.0, 4.0];
15361
15362        let mut g = Graph::new("maxpool_bw");
15363        let xn = g.input("x", Shape::new(&[1, 1, 4, 4], f));
15364        let dyn_ = g.input("dy", Shape::new(&[1, 1, 2, 2], f));
15365        let dx = g.maxpool2d_backward(xn, dyn_, vec![2, 2], vec![2, 2], vec![0, 0]);
15366        g.set_outputs(vec![dx]);
15367
15368        let actual = run_graph(&g, &[(xn, &x), (dyn_, &dy)], dx, 16);
15369        let mut expected = vec![0f32; 16];
15370        expected[5] = 0.5;
15371        expected[7] = 1.0;
15372        expected[13] = 2.0;
15373        expected[15] = 4.0;
15374        for (i, (a, e)) in actual.iter().zip(&expected).enumerate() {
15375            assert!((a - e).abs() < 1e-6, "maxpool_bw[{i}] mismatch: {a} vs {e}");
15376        }
15377    }
15378
15379    #[test]
15380    fn conv2d_backward_input_matches_numerical_gradient() {
15381        use rlx_ir::Philox4x32;
15382        // Small enough to numerically differentiate exhaustively but
15383        // big enough to exercise stride/padding edge cases.
15384        let n = 1usize;
15385        let c_in = 2usize;
15386        let h = 4usize;
15387        let w = 4usize;
15388        let c_out = 3usize;
15389        let kh = 3usize;
15390        let kw = 3usize;
15391        let ph = 1usize;
15392        let pw = 1usize;
15393        let sh = 1usize;
15394        let sw = 1usize;
15395        // Output dims with padding=1, stride=1: same as input.
15396        let h_out = (h + 2 * ph - kh) / sh + 1;
15397        let w_out = (w + 2 * pw - kw) / sw + 1;
15398        assert_eq!(h_out, 4);
15399        assert_eq!(w_out, 4);
15400
15401        let mut rng = Philox4x32::new(7);
15402        let mut x = vec![0f32; n * c_in * h * w];
15403        rng.fill_normal(&mut x);
15404        let mut wt = vec![0f32; c_out * c_in * kh * kw];
15405        rng.fill_normal(&mut wt);
15406        let mut dy = vec![0f32; n * c_out * h_out * w_out];
15407        rng.fill_normal(&mut dy);
15408
15409        // Analytical: Conv2dBackwardInput on (dy, w).
15410        let f = DType::F32;
15411        let mut g = Graph::new("conv_bwi");
15412        let dy_in = g.input("dy", Shape::new(&[n, c_out, h_out, w_out], f));
15413        let w_in = g.input("w", Shape::new(&[c_out, c_in, kh, kw], f));
15414        let dx = g.conv2d_backward_input(
15415            dy_in,
15416            w_in,
15417            Shape::new(&[n, c_in, h, w], f),
15418            vec![kh, kw],
15419            vec![sh, sw],
15420            vec![ph, pw],
15421            vec![1, 1],
15422            1,
15423        );
15424        g.set_outputs(vec![dx]);
15425        let analytical = run_graph(&g, &[(dy_in, &dy), (w_in, &wt)], dx, n * c_in * h * w);
15426
15427        // Numerical: for each x[i], finite-difference forward conv twice.
15428        // Forward: y[j] = sum over filter window of w * x ; dot(dy, y) is
15429        // the scalar we differentiate. Then dx[i] = ∂(dot(dy, y))/∂x[i].
15430        let forward = |x: &[f32]| -> Vec<f32> {
15431            let mut out = vec![0f32; n * c_out * h_out * w_out];
15432            for ni in 0..n {
15433                for co in 0..c_out {
15434                    for ho in 0..h_out {
15435                        for wo in 0..w_out {
15436                            let mut acc = 0f32;
15437                            for ci in 0..c_in {
15438                                for ki in 0..kh {
15439                                    for kj in 0..kw {
15440                                        let hi = ho * sh + ki;
15441                                        let wi = wo * sw + kj;
15442                                        if hi < ph || wi < pw {
15443                                            continue;
15444                                        }
15445                                        let hi = hi - ph;
15446                                        let wi = wi - pw;
15447                                        if hi >= h || wi >= w {
15448                                            continue;
15449                                        }
15450                                        let xv = x[((ni * c_in) + ci) * h * w + hi * w + wi];
15451                                        let wv = wt[((co * c_in) + ci) * kh * kw + ki * kw + kj];
15452                                        acc += xv * wv;
15453                                    }
15454                                }
15455                            }
15456                            out[((ni * c_out) + co) * h_out * w_out + ho * w_out + wo] = acc;
15457                        }
15458                    }
15459                }
15460            }
15461            out
15462        };
15463        let dot = |a: &[f32], b: &[f32]| -> f32 { a.iter().zip(b).map(|(&u, &v)| u * v).sum() };
15464        let eps = 1e-3f32;
15465        let mut numerical = vec![0f32; x.len()];
15466        for i in 0..x.len() {
15467            let saved = x[i];
15468            x[i] = saved + eps;
15469            let plus = dot(&forward(&x), &dy);
15470            x[i] = saved - eps;
15471            let minus = dot(&forward(&x), &dy);
15472            x[i] = saved;
15473            numerical[i] = (plus - minus) / (2.0 * eps);
15474        }
15475        for (i, (a, n)) in analytical.iter().zip(&numerical).enumerate() {
15476            // f32 + eps=1e-3 numerical grad → ~1e-3 absolute is realistic.
15477            assert!(
15478                (a - n).abs() < 5e-3,
15479                "conv_bw_input[{i}]: analytical {a} vs numerical {n}"
15480            );
15481        }
15482    }
15483
15484    #[test]
15485    fn conv2d_backward_weight_matches_numerical_gradient() {
15486        use rlx_ir::Philox4x32;
15487        let n = 2usize;
15488        let c_in = 2usize;
15489        let h = 4usize;
15490        let w = 4usize;
15491        let c_out = 2usize;
15492        let kh = 3usize;
15493        let kw = 3usize;
15494        let ph = 0usize;
15495        let pw = 0usize;
15496        let sh = 1usize;
15497        let sw = 1usize;
15498        let h_out = (h + 2 * ph - kh) / sh + 1;
15499        let w_out = (w + 2 * pw - kw) / sw + 1;
15500
15501        let mut rng = Philox4x32::new(11);
15502        let mut x = vec![0f32; n * c_in * h * w];
15503        rng.fill_normal(&mut x);
15504        let mut wt = vec![0f32; c_out * c_in * kh * kw];
15505        rng.fill_normal(&mut wt);
15506        let mut dy = vec![0f32; n * c_out * h_out * w_out];
15507        rng.fill_normal(&mut dy);
15508
15509        let f = DType::F32;
15510        let mut g = Graph::new("conv_bww");
15511        let xn = g.input("x", Shape::new(&[n, c_in, h, w], f));
15512        let dyn_ = g.input("dy", Shape::new(&[n, c_out, h_out, w_out], f));
15513        let dwn = g.conv2d_backward_weight(
15514            xn,
15515            dyn_,
15516            Shape::new(&[c_out, c_in, kh, kw], f),
15517            vec![kh, kw],
15518            vec![sh, sw],
15519            vec![ph, pw],
15520            vec![1, 1],
15521            1,
15522        );
15523        g.set_outputs(vec![dwn]);
15524        let analytical = run_graph(&g, &[(xn, &x), (dyn_, &dy)], dwn, c_out * c_in * kh * kw);
15525
15526        let forward = |wt: &[f32]| -> Vec<f32> {
15527            let mut out = vec![0f32; n * c_out * h_out * w_out];
15528            for ni in 0..n {
15529                for co in 0..c_out {
15530                    for ho in 0..h_out {
15531                        for wo in 0..w_out {
15532                            let mut acc = 0f32;
15533                            for ci in 0..c_in {
15534                                for ki in 0..kh {
15535                                    for kj in 0..kw {
15536                                        let hi = ho + ki;
15537                                        let wi = wo + kj;
15538                                        let xv = x[((ni * c_in) + ci) * h * w + hi * w + wi];
15539                                        let wv = wt[((co * c_in) + ci) * kh * kw + ki * kw + kj];
15540                                        acc += xv * wv;
15541                                    }
15542                                }
15543                            }
15544                            out[((ni * c_out) + co) * h_out * w_out + ho * w_out + wo] = acc;
15545                        }
15546                    }
15547                }
15548            }
15549            out
15550        };
15551        let dot = |a: &[f32], b: &[f32]| -> f32 { a.iter().zip(b).map(|(&u, &v)| u * v).sum() };
15552        let eps = 1e-3f32;
15553        let mut numerical = vec![0f32; wt.len()];
15554        for i in 0..wt.len() {
15555            let saved = wt[i];
15556            wt[i] = saved + eps;
15557            let plus = dot(&forward(&wt), &dy);
15558            wt[i] = saved - eps;
15559            let minus = dot(&forward(&wt), &dy);
15560            wt[i] = saved;
15561            numerical[i] = (plus - minus) / (2.0 * eps);
15562        }
15563        for (i, (a, n)) in analytical.iter().zip(&numerical).enumerate() {
15564            assert!(
15565                (a - n).abs() < 5e-3,
15566                "conv_bw_weight[{i}]: analytical {a} vs numerical {n}"
15567            );
15568        }
15569    }
15570
15571    #[test]
15572    fn softmax_cross_entropy_matches_reference() {
15573        let f = DType::F32;
15574        let logits: Vec<f32> = vec![
15575            1.0, 2.0, 3.0, // row 0: max=3 (idx 2)
15576            -1.0, 0.0, 4.0, // row 1: max=4 (idx 2)
15577            5.0, 5.0, 5.0, // row 2: uniform
15578        ];
15579        let labels: Vec<f32> = vec![2.0, 0.0, 1.0];
15580
15581        let mut g = Graph::new("sce");
15582        let lg = g.input("logits", Shape::new(&[3, 3], f));
15583        let lb = g.input("labels", Shape::new(&[3], f));
15584        let loss = g.softmax_cross_entropy_with_logits(lg, lb);
15585        g.set_outputs(vec![loss]);
15586        let actual = run_graph(&g, &[(lg, &logits), (lb, &labels)], loss, 3);
15587
15588        // Reference per-row: -log(softmax(row)[label]).
15589        let mut expected = vec![0f32; 3];
15590        for ni in 0..3 {
15591            let row = &logits[ni * 3..(ni + 1) * 3];
15592            let m = row.iter().fold(f32::NEG_INFINITY, |a, &v| a.max(v));
15593            let sum: f32 = row.iter().map(|&v| (v - m).exp()).sum();
15594            let lse = m + sum.ln();
15595            let label_idx = labels[ni] as usize;
15596            expected[ni] = lse - row[label_idx];
15597        }
15598        for (i, (a, e)) in actual.iter().zip(&expected).enumerate() {
15599            assert!((a - e).abs() < 1e-5, "sce loss[{i}]: {a} vs {e}");
15600        }
15601    }
15602
15603    #[test]
15604    fn softmax_cross_entropy_backward_matches_numerical_gradient() {
15605        use rlx_ir::Philox4x32;
15606        let n = 4usize;
15607        let c = 5usize;
15608        let mut rng = Philox4x32::new(23);
15609        let mut logits = vec![0f32; n * c];
15610        rng.fill_normal(&mut logits);
15611        let labels: Vec<f32> = (0..n).map(|i| (i % c) as f32).collect();
15612        let mut d_loss = vec![0f32; n];
15613        rng.fill_normal(&mut d_loss);
15614
15615        let f = DType::F32;
15616        let mut g = Graph::new("sce_bw");
15617        let lg = g.input("logits", Shape::new(&[n, c], f));
15618        let lb = g.input("labels", Shape::new(&[n], f));
15619        let dl = g.input("d_loss", Shape::new(&[n], f));
15620        let dlogits = g.softmax_cross_entropy_backward(lg, lb, dl);
15621        g.set_outputs(vec![dlogits]);
15622        let analytical = run_graph(
15623            &g,
15624            &[(lg, &logits), (lb, &labels), (dl, &d_loss)],
15625            dlogits,
15626            n * c,
15627        );
15628
15629        // Numerical: differentiate dot(d_loss, sce_loss(logits)) w.r.t. each logit.
15630        let sce_loss = |logits: &[f32]| -> Vec<f32> {
15631            let mut out = vec![0f32; n];
15632            for ni in 0..n {
15633                let row = &logits[ni * c..(ni + 1) * c];
15634                let m = row.iter().fold(f32::NEG_INFINITY, |a, &v| a.max(v));
15635                let sum: f32 = row.iter().map(|&v| (v - m).exp()).sum();
15636                out[ni] = (m + sum.ln()) - row[labels[ni] as usize];
15637            }
15638            out
15639        };
15640        let dot = |a: &[f32], b: &[f32]| a.iter().zip(b).map(|(&u, &v)| u * v).sum::<f32>();
15641        let eps = 1e-3f32;
15642        let mut numerical = vec![0f32; logits.len()];
15643        for i in 0..logits.len() {
15644            let saved = logits[i];
15645            logits[i] = saved + eps;
15646            let plus = dot(&sce_loss(&logits), &d_loss);
15647            logits[i] = saved - eps;
15648            let minus = dot(&sce_loss(&logits), &d_loss);
15649            logits[i] = saved;
15650            numerical[i] = (plus - minus) / (2.0 * eps);
15651        }
15652        for (i, (a, num)) in analytical.iter().zip(&numerical).enumerate() {
15653            assert!(
15654                (a - num).abs() < 5e-3,
15655                "sce_bw[{i}]: analytical {a} vs numerical {num}"
15656            );
15657        }
15658    }
15659
15660    // ── End-to-end autodiff parity tests ──────────────────────
15661    //
15662    // Build a forward graph, run `grad_with_loss` to produce a graph
15663    // that emits [loss, gradients...], execute it through rlx-cpu,
15664    // and compare each gradient to a finite-difference estimate
15665    // produced by re-running the forward graph with each parameter
15666    // entry perturbed. f32 + ε=1e-3 puts the tolerance floor around
15667    // 5e-3 absolute error.
15668
15669    /// Initialize Op::Constant slots in the arena with their literal
15670    /// data. Mirrors the loop in rlx_runtime::backend (which serves
15671    /// the same role for production runs).
15672    fn fill_constants_into_arena(graph: &Graph, arena: &mut crate::arena::Arena) {
15673        for node in graph.nodes() {
15674            if let Op::Constant { data } = &node.op
15675                && arena.has_buffer(node.id)
15676                && !data.is_empty()
15677            {
15678                let buf = arena.slice_mut(node.id);
15679                let n_floats = data.len() / 4;
15680                let n = buf.len().min(n_floats);
15681                for i in 0..n {
15682                    let bytes = [
15683                        data[i * 4],
15684                        data[i * 4 + 1],
15685                        data[i * 4 + 2],
15686                        data[i * 4 + 3],
15687                    ];
15688                    buf[i] = f32::from_le_bytes(bytes);
15689                }
15690            }
15691        }
15692    }
15693
15694    /// Compile + arena-prep helper for these tests. Returns the
15695    /// schedule and a populated arena. `seed_inputs` writes f32 input
15696    /// data into the arena slot for each (NodeId, &[f32]) pair.
15697    fn prepare(
15698        graph: &Graph,
15699        seed_inputs: &[(NodeId, &[f32])],
15700    ) -> (ThunkSchedule, crate::arena::Arena) {
15701        let plan = rlx_opt::memory::plan_memory(graph);
15702        let mut arena = crate::arena::Arena::from_plan(plan);
15703        let sched = compile_thunks(graph, &arena);
15704        fill_constants_into_arena(graph, &mut arena);
15705        for &(id, data) in seed_inputs {
15706            let off = arena.byte_offset(id);
15707            let buf = arena.raw_buf_mut();
15708            unsafe {
15709                let p = buf.as_mut_ptr().add(off) as *mut f32;
15710                for (i, &v) in data.iter().enumerate() {
15711                    *p.add(i) = v;
15712                }
15713            }
15714        }
15715        (sched, arena)
15716    }
15717
15718    fn read_arena(arena: &crate::arena::Arena, id: NodeId, len: usize) -> Vec<f32> {
15719        let off = arena.byte_offset(id);
15720        unsafe {
15721            let p = arena.raw_buf().as_ptr().add(off) as *const f32;
15722            (0..len).map(|i| *p.add(i)).collect()
15723        }
15724    }
15725
15726    fn write_arena(arena: &mut crate::arena::Arena, id: NodeId, data: &[f32]) {
15727        let off = arena.byte_offset(id);
15728        let buf = arena.raw_buf_mut();
15729        unsafe {
15730            let p = buf.as_mut_ptr().add(off) as *mut f32;
15731            for (i, &v) in data.iter().enumerate() {
15732                *p.add(i) = v;
15733            }
15734        }
15735    }
15736
15737    /// f64 sibling of `prepare`. Writes f64 input data into the arena.
15738    fn prepare_f64(
15739        graph: &Graph,
15740        seed_inputs: &[(NodeId, &[f64])],
15741    ) -> (ThunkSchedule, crate::arena::Arena) {
15742        let plan = rlx_opt::memory::plan_memory(graph);
15743        let mut arena = crate::arena::Arena::from_plan(plan);
15744        let sched = compile_thunks(graph, &arena);
15745        fill_constants_into_arena(graph, &mut arena);
15746        for &(id, data) in seed_inputs {
15747            let off = arena.byte_offset(id);
15748            let buf = arena.raw_buf_mut();
15749            unsafe {
15750                let p = buf.as_mut_ptr().add(off) as *mut f64;
15751                for (i, &v) in data.iter().enumerate() {
15752                    *p.add(i) = v;
15753                }
15754            }
15755        }
15756        (sched, arena)
15757    }
15758
15759    fn read_arena_f64(arena: &crate::arena::Arena, id: NodeId, len: usize) -> Vec<f64> {
15760        let off = arena.byte_offset(id);
15761        unsafe {
15762            let p = arena.raw_buf().as_ptr().add(off) as *const f64;
15763            (0..len).map(|i| *p.add(i)).collect()
15764        }
15765    }
15766
15767    /// End-to-end f64 DenseSolve through the full compile + execute
15768    /// path. Validates: IR shape inference, memory planner f64 sizing,
15769    /// arena f64 accessors, Thunk::DenseSolveF64 lowering, executor
15770    /// dispatch, Accelerate dgesv FFI.
15771    ///
15772    /// System:
15773    ///   A = [[2, 1],
15774    ///        [1, 3]]   b = [5, 10]
15775    ///   ⇒  x = [1, 3]   (verified by hand)
15776    #[test]
15777    fn dense_solve_f64_end_to_end() {
15778        let mut g = Graph::new("solve_e2e");
15779        let a = g.input("A", Shape::new(&[2, 2], DType::F64));
15780        let b = g.input("b", Shape::new(&[2], DType::F64));
15781        let x = g.dense_solve(a, b, Shape::new(&[2], DType::F64));
15782        g.set_outputs(vec![x]);
15783
15784        let a_data = [2.0, 1.0, 1.0, 3.0_f64];
15785        let b_data = [5.0, 10.0_f64];
15786        let (sched, mut arena) = prepare_f64(&g, &[(a, &a_data), (b, &b_data)]);
15787        execute_thunks(&sched, arena.raw_buf_mut());
15788
15789        let got = read_arena_f64(&arena, x, 2);
15790        let want = [1.0, 3.0_f64];
15791        for i in 0..2 {
15792            assert!(
15793                (got[i] - want[i]).abs() < 1e-12,
15794                "x[{i}] = {} (expected {})",
15795                got[i],
15796                want[i]
15797            );
15798        }
15799    }
15800
15801    /// Scaled-up f64 DenseSolve — tridiagonal Laplacian-shape (typical
15802    /// MNA structure for a passive RC mesh in Circulax). Validates
15803    /// that the solve scales beyond the trivial 2×2 and that the
15804    /// row-major ↔ col-major dance in `dgesv` is correct for the
15805    /// general case.
15806    #[test]
15807    fn dense_solve_f64_5x5_laplacian() {
15808        let n = 5usize;
15809        let mut g = Graph::new("solve_5x5");
15810        let a = g.input("A", Shape::new(&[n, n], DType::F64));
15811        let b = g.input("b", Shape::new(&[n], DType::F64));
15812        let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
15813        g.set_outputs(vec![x]);
15814
15815        // 1-D Laplacian: 2 on diagonal, -1 on off-diagonals, 0 elsewhere.
15816        let mut a_data = vec![0.0_f64; n * n];
15817        for i in 0..n {
15818            a_data[i * n + i] = 2.0;
15819            if i > 0 {
15820                a_data[i * n + (i - 1)] = -1.0;
15821            }
15822            if i + 1 < n {
15823                a_data[i * n + (i + 1)] = -1.0;
15824            }
15825        }
15826        let b_data: Vec<f64> = (0..n).map(|i| (i + 1) as f64).collect();
15827        let (sched, mut arena) = prepare_f64(&g, &[(a, &a_data), (b, &b_data)]);
15828        execute_thunks(&sched, arena.raw_buf_mut());
15829
15830        let got = read_arena_f64(&arena, x, n);
15831        // Verify A·x ≈ b by computing the residual.
15832        let mut residual = vec![0.0_f64; n];
15833        for i in 0..n {
15834            for j in 0..n {
15835                residual[i] += a_data[i * n + j] * got[j];
15836            }
15837        }
15838        for i in 0..n {
15839            assert!(
15840                (residual[i] - b_data[i]).abs() < 1e-10,
15841                "row {i}: residual {} vs b {}",
15842                residual[i],
15843                b_data[i]
15844            );
15845        }
15846    }
15847
15848    /// Hello Resistor: end-to-end f64 gradient through a dense solve.
15849    ///
15850    /// Forward:
15851    ///   A      : Param  [N, N]   f64
15852    ///   b      : Input  [N]      f64
15853    ///   x      = solve(A, b)            (DenseSolve)
15854    ///   loss   = sum(x)                 (Reduce::Sum)
15855    ///
15856    /// Backward (via grad_with_loss):
15857    ///   ones [N] = expand(d_output, [N])      (Reduce::Sum VJP)
15858    ///   dx_int   = solve(Aᵀ, ones)             (DenseSolve VJP step 1)
15859    ///   dA       = -outer(dx_int, x)           (DenseSolve VJP step 2)
15860    ///   db       = dx_int                       (DenseSolve VJP step 3)
15861    ///
15862    /// Closed form: with loss = sum(solve(A, b)) = ones·x and
15863    /// implicit-function calculus, db = (Aᵀ)⁻¹·ones, dA = -db ⊗ x.
15864    /// We verify this against the autodiff-emitted graph's output and
15865    /// against a finite-difference baseline.
15866    #[test]
15867    fn hello_resistor_gradient_end_to_end() {
15868        use rlx_opt::autodiff::grad_with_loss;
15869        let n = 3usize;
15870
15871        // ── Build forward graph ──
15872        let mut g = Graph::new("hello_resistor");
15873        let a = g.param("A", Shape::new(&[n, n], DType::F64));
15874        let b = g.input("b", Shape::new(&[n], DType::F64));
15875        let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
15876        let loss = g.reduce(
15877            x,
15878            ReduceOp::Sum,
15879            vec![0],
15880            false,
15881            Shape::new(&[1], DType::F64),
15882        );
15883        g.set_outputs(vec![loss]);
15884
15885        // ── Run reverse-mode AD ──
15886        let bwd = grad_with_loss(&g, &[a, b]);
15887        assert_eq!(bwd.outputs.len(), 3, "expect [loss, dA, db]");
15888
15889        // ── Locate the inputs the bwd graph still needs from us ──
15890        // grad_with_loss copies forward nodes into bwd, so A/b/d_output
15891        // appear under their original names. Find them by name.
15892        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
15893            for node in graph.nodes() {
15894                let name = match &node.op {
15895                    rlx_ir::Op::Input { name } => Some(name.as_str()),
15896                    rlx_ir::Op::Param { name } => Some(name.as_str()),
15897                    _ => None,
15898                };
15899                if name == Some(want) {
15900                    return node.id;
15901                }
15902            }
15903            panic!("no node named {want:?} in bwd graph");
15904        };
15905        let a_bwd = find_by_name(&bwd, "A");
15906        let b_bwd = find_by_name(&bwd, "b");
15907        let d_out_bwd = find_by_name(&bwd, "d_output");
15908
15909        // ── Test data ──
15910        // A = [[2,1,0],[1,3,1],[0,1,2]]   (SPD tridiagonal, well-conditioned)
15911        // b = [1,2,3]
15912        let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
15913        let b_data = [1.0, 2.0, 3.0_f64];
15914        let d_output = [1.0_f64]; // ∂loss/∂loss
15915
15916        // ── Compile + execute backward graph ──
15917        let (sched, mut arena) = prepare_f64(
15918            &bwd,
15919            &[(a_bwd, &a_data), (b_bwd, &b_data), (d_out_bwd, &d_output)],
15920        );
15921        execute_thunks(&sched, arena.raw_buf_mut());
15922
15923        let loss_out = read_arena_f64(&arena, bwd.outputs[0], 1);
15924        let da_out = read_arena_f64(&arena, bwd.outputs[1], n * n);
15925        let db_out = read_arena_f64(&arena, bwd.outputs[2], n);
15926
15927        // ── Closed-form reference ──
15928        // x = A⁻¹ b ; loss = sum(x).
15929        let x_ref = {
15930            let mut a = a_data;
15931            let mut b = b_data;
15932            let info = crate::blas::dgesv(&mut a, &mut b, n, 1);
15933            assert_eq!(info, 0);
15934            b
15935        };
15936        let loss_ref: f64 = x_ref.iter().sum();
15937        // db = (Aᵀ)⁻¹ · 1
15938        let db_ref = {
15939            let mut at = [0.0_f64; 9];
15940            for i in 0..n {
15941                for j in 0..n {
15942                    at[i * n + j] = a_data[j * n + i];
15943                }
15944            }
15945            let mut ones = [1.0_f64; 3];
15946            let info = crate::blas::dgesv(&mut at, &mut ones, n, 1);
15947            assert_eq!(info, 0);
15948            ones
15949        };
15950        // dA = -outer(db, x) ; dA[i,j] = -db[i] * x[j]
15951        let mut da_ref = [0.0_f64; 9];
15952        for i in 0..n {
15953            for j in 0..n {
15954                da_ref[i * n + j] = -db_ref[i] * x_ref[j];
15955            }
15956        }
15957
15958        // ── Assertions vs analytic answer ──
15959        assert!(
15960            (loss_out[0] - loss_ref).abs() < 1e-10,
15961            "loss: got {}, want {}",
15962            loss_out[0],
15963            loss_ref
15964        );
15965        for i in 0..n {
15966            assert!(
15967                (db_out[i] - db_ref[i]).abs() < 1e-10,
15968                "db[{i}]: got {}, want {}",
15969                db_out[i],
15970                db_ref[i]
15971            );
15972        }
15973        for i in 0..n * n {
15974            assert!(
15975                (da_out[i] - da_ref[i]).abs() < 1e-10,
15976                "dA[{i}]: got {}, want {}",
15977                da_out[i],
15978                da_ref[i]
15979            );
15980        }
15981
15982        // ── Cross-check vs finite differences on db (a few entries) ──
15983        // ∂loss/∂b[k] ≈ (loss(b + h·e_k) - loss(b - h·e_k)) / (2h).
15984        let h = 1e-6_f64;
15985        for k in 0..n {
15986            let mut bp = b_data;
15987            bp[k] += h;
15988            let mut bm = b_data;
15989            bm[k] -= h;
15990            let lp = {
15991                let mut ac = a_data;
15992                let info = crate::blas::dgesv(&mut ac, &mut bp, n, 1);
15993                assert_eq!(info, 0);
15994                bp.iter().sum::<f64>()
15995            };
15996            let lm = {
15997                let mut ac = a_data;
15998                let info = crate::blas::dgesv(&mut ac, &mut bm, n, 1);
15999                assert_eq!(info, 0);
16000                bm.iter().sum::<f64>()
16001            };
16002            let fd = (lp - lm) / (2.0 * h);
16003            assert!(
16004                (db_out[k] - fd).abs() < 1e-7,
16005                "FD mismatch on db[{k}]: AD={} FD={}",
16006                db_out[k],
16007                fd
16008            );
16009        }
16010    }
16011
16012    /// Smallest possible Op::Scan basic test: geometric growth.
16013    /// init = [1, 1, 1] f64, body = (x → x + 0.1·x) = (x → 1.1·x),
16014    /// length = 10. Final carry must equal init·(1.1)^10 ≈ 2.5937…
16015    /// to f64 precision.
16016    #[test]
16017    fn scan_geometric_growth_f64() {
16018        let n = 3usize;
16019        let length = 10u32;
16020
16021        // Body: (x) → x + 0.1·x. One Input, one output, same shape/dtype.
16022        let mut body = Graph::new("scan_body");
16023        let x = body.input("carry", Shape::new(&[n], DType::F64));
16024        let scale_bytes: Vec<u8> = (0..n).flat_map(|_| 0.1_f64.to_le_bytes()).collect();
16025        let scale = body.add_node(
16026            Op::Constant { data: scale_bytes },
16027            vec![],
16028            Shape::new(&[n], DType::F64),
16029        );
16030        let scaled = body.binary(BinaryOp::Mul, x, scale, Shape::new(&[n], DType::F64));
16031        let next = body.binary(BinaryOp::Add, x, scaled, Shape::new(&[n], DType::F64));
16032        body.set_outputs(vec![next]);
16033
16034        // Outer graph: scan(init, body, length).
16035        let mut g = Graph::new("scan_outer");
16036        let init = g.input("init", Shape::new(&[n], DType::F64));
16037        let final_carry = g.scan(init, body, length);
16038        g.set_outputs(vec![final_carry]);
16039
16040        let init_data = vec![1.0_f64; n];
16041        let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data)]);
16042        execute_thunks(&sched, arena.raw_buf_mut());
16043        let got = read_arena_f64(&arena, final_carry, n);
16044        let want: f64 = 1.1_f64.powi(length as i32);
16045        for i in 0..n {
16046            assert!(
16047                (got[i] - want).abs() < 1e-12,
16048                "got[{i}] = {} want {}",
16049                got[i],
16050                want
16051            );
16052        }
16053    }
16054
16055    /// Per-step xs scan: cumulative-sum.
16056    ///   carry_0 = init
16057    ///   carry_{t+1} = carry_t + xs\[t\]
16058    ///   final = sum_{t<length} xs\[t\] + init
16059    /// Body has 2 inputs (carry, x_t) in that NodeId order; one output
16060    /// (next carry). Validates the per-step-input plumbing end-to-end.
16061    #[test]
16062    fn scan_with_xs_cumulative_sum() {
16063        let n = 3usize;
16064        let length = 4u32;
16065
16066        let mut body = Graph::new("cumsum_body");
16067        // carry must come first in NodeId order — declare it first.
16068        let carry = body.input("carry", Shape::new(&[n], DType::F64));
16069        let x_t = body.input("x_t", Shape::new(&[n], DType::F64));
16070        let next = body.binary(BinaryOp::Add, carry, x_t, Shape::new(&[n], DType::F64));
16071        body.set_outputs(vec![next]);
16072
16073        let mut g = Graph::new("cumsum_outer");
16074        let init = g.input("init", Shape::new(&[n], DType::F64));
16075        let xs = g.input("xs", Shape::new(&[length as usize, n], DType::F64));
16076        let final_carry = g.scan_with_xs(init, &[xs], body, length);
16077        g.set_outputs(vec![final_carry]);
16078
16079        let init_data = vec![0.0_f64; n];
16080        let xs_data: Vec<f64> = (0..length as usize * n).map(|i| (i + 1) as f64).collect(); // 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12
16081        let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data), (xs, &xs_data)]);
16082        execute_thunks(&sched, arena.raw_buf_mut());
16083        let got = read_arena_f64(&arena, final_carry, n);
16084
16085        // Reference: column-wise sum of xs rows + init. With our row-major
16086        // layout, column j of xs is xs_data[j], xs_data[n+j], xs_data[2n+j], ...
16087        // (per-step row at offset t*n contributes element j to slot j).
16088        let mut want = init_data.clone();
16089        for t in 0..length as usize {
16090            for j in 0..n {
16091                want[j] += xs_data[t * n + j];
16092            }
16093        }
16094        for i in 0..n {
16095            assert!(
16096                (got[i] - want[i]).abs() < 1e-12,
16097                "got[{i}] = {} want {}",
16098                got[i],
16099                want[i]
16100            );
16101        }
16102    }
16103
16104    /// Per-step xs scan composing with DenseSolve — Circulax-shaped:
16105    ///   carry_{t+1} = solve(M, carry_t + xs\[t\])
16106    /// Models a Backward-Euler step driven by a time-varying source.
16107    #[test]
16108    fn scan_with_xs_be_with_drive() {
16109        let n = 3usize;
16110        let length = 4u32;
16111        let dt = 0.1_f64;
16112
16113        let mut m_data = vec![0.0_f64; n * n];
16114        for i in 0..n {
16115            m_data[i * n + i] = 1.0 + dt * 2.0;
16116            if i > 0 {
16117                m_data[i * n + (i - 1)] = -dt;
16118            }
16119            if i + 1 < n {
16120                m_data[i * n + (i + 1)] = -dt;
16121            }
16122        }
16123        let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
16124
16125        let mut body = Graph::new("be_drive_body");
16126        let carry = body.input("carry", Shape::new(&[n], DType::F64));
16127        let drive = body.input("drive", Shape::new(&[n], DType::F64));
16128        let m = body.add_node(
16129            Op::Constant { data: m_bytes },
16130            vec![],
16131            Shape::new(&[n, n], DType::F64),
16132        );
16133        let driven = body.binary(BinaryOp::Add, carry, drive, Shape::new(&[n], DType::F64));
16134        let next = body.dense_solve(m, driven, Shape::new(&[n], DType::F64));
16135        body.set_outputs(vec![next]);
16136
16137        let mut g = Graph::new("be_drive_outer");
16138        let init = g.input("init", Shape::new(&[n], DType::F64));
16139        let xs = g.input("xs", Shape::new(&[length as usize, n], DType::F64));
16140        let final_carry = g.scan_with_xs(init, &[xs], body, length);
16141        g.set_outputs(vec![final_carry]);
16142
16143        let init_data = vec![0.0_f64; n];
16144        // Drive the system with a unit pulse on element 0 at t=0,
16145        // zeros after.
16146        let mut xs_data = vec![0.0_f64; length as usize * n];
16147        xs_data[0] = 1.0;
16148
16149        let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data), (xs, &xs_data)]);
16150        execute_thunks(&sched, arena.raw_buf_mut());
16151        let got = read_arena_f64(&arena, final_carry, n);
16152
16153        // Reference: per-step in pure Rust.
16154        let mut x = init_data.clone();
16155        for t in 0..length as usize {
16156            for j in 0..n {
16157                x[j] += xs_data[t * n + j];
16158            }
16159            let mut a_copy = m_data.clone();
16160            crate::blas::dgesv(&mut a_copy, &mut x, n, 1);
16161        }
16162        for i in 0..n {
16163            assert!(
16164                (got[i] - x[i]).abs() < 1e-12,
16165                "got[{i}] = {} ref {}",
16166                got[i],
16167                x[i]
16168            );
16169        }
16170    }
16171
16172    /// Reverse-mode AD through Op::BatchedDenseSolve. Forward solves
16173    /// `[B, N, N] · x = [B, N]`; loss = sum of all entries. Closed
16174    /// form: dB = (Aᵀ)⁻¹·1, dA = -(Aᵀ)⁻¹·1 ⊗ x. Verified analytically
16175    /// per batch (each slice matches what the unbatched DenseSolve VJP
16176    /// would compute).
16177    #[test]
16178    fn batched_dense_solve_gradient_matches_per_batch_analytic() {
16179        use rlx_opt::autodiff::grad_with_loss;
16180        let n = 3usize;
16181        let batch = 4usize;
16182
16183        let mut g = Graph::new("bds_grad");
16184        let a = g.param("A", Shape::new(&[batch, n, n], DType::F64));
16185        let b = g.input("b", Shape::new(&[batch, n], DType::F64));
16186        let x = g.batched_dense_solve(a, b, Shape::new(&[batch, n], DType::F64));
16187        let loss = g.reduce(
16188            x,
16189            ReduceOp::Sum,
16190            vec![0, 1],
16191            false,
16192            Shape::new(&[1], DType::F64),
16193        );
16194        g.set_outputs(vec![loss]);
16195
16196        let bwd = grad_with_loss(&g, &[a, b]);
16197
16198        let find = |graph: &Graph, want: &str| -> NodeId {
16199            for node in graph.nodes() {
16200                let name = match &node.op {
16201                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
16202                    _ => None,
16203                };
16204                if name == Some(want) {
16205                    return node.id;
16206                }
16207            }
16208            panic!("no node named {want}");
16209        };
16210        let a_id = find(&bwd, "A");
16211        let b_id = find(&bwd, "b");
16212        let d_out_id = find(&bwd, "d_output");
16213
16214        let mut rng = rlx_ir::Philox4x32::new(0x57e1_u64);
16215        let mut a_data = vec![0.0_f64; batch * n * n];
16216        let mut b_data = vec![0.0_f64; batch * n];
16217        for bi in 0..batch {
16218            for i in 0..n {
16219                for j in 0..n {
16220                    a_data[bi * n * n + i * n + j] = rng.next_f32() as f64 * 0.1;
16221                }
16222                a_data[bi * n * n + i * n + i] += 1.0 + n as f64;
16223            }
16224            for i in 0..n {
16225                b_data[bi * n + i] = rng.next_f32() as f64;
16226            }
16227        }
16228        let d_seed = [1.0_f64];
16229
16230        let (sched, mut arena) = prepare_f64(
16231            &bwd,
16232            &[(a_id, &a_data), (b_id, &b_data), (d_out_id, &d_seed)],
16233        );
16234        execute_thunks(&sched, arena.raw_buf_mut());
16235        let da_out = read_arena_f64(&arena, bwd.outputs[1], batch * n * n);
16236        let db_out = read_arena_f64(&arena, bwd.outputs[2], batch * n);
16237
16238        // Reference: per-batch analytic solve. dB_i = (A_iᵀ)⁻¹ · 1,
16239        // dA_i = -dB_i ⊗ x_i.
16240        for bi in 0..batch {
16241            let a_slice: Vec<f64> = a_data[bi * n * n..(bi + 1) * n * n].to_vec();
16242            let mut b_slice: Vec<f64> = b_data[bi * n..(bi + 1) * n].to_vec();
16243            let mut a_copy = a_slice.clone();
16244            crate::blas::dgesv(&mut a_copy, &mut b_slice, n, 1);
16245            let x_ref = b_slice.clone();
16246            // dB: solve(A^T, ones)
16247            let mut at = vec![0.0_f64; n * n];
16248            for i in 0..n {
16249                for j in 0..n {
16250                    at[i * n + j] = a_slice[j * n + i];
16251                }
16252            }
16253            let mut ones = vec![1.0_f64; n];
16254            crate::blas::dgesv(&mut at, &mut ones, n, 1);
16255            let db_ref = ones;
16256            for i in 0..n {
16257                let got = db_out[bi * n + i];
16258                assert!(
16259                    (got - db_ref[i]).abs() < 1e-10,
16260                    "batch {bi}, db[{i}]: got {got} ref {}",
16261                    db_ref[i]
16262                );
16263            }
16264            // dA: -outer(db, x)
16265            for i in 0..n {
16266                for j in 0..n {
16267                    let got = da_out[bi * n * n + i * n + j];
16268                    let want = -db_ref[i] * x_ref[j];
16269                    assert!(
16270                        (got - want).abs() < 1e-10,
16271                        "batch {bi}, dA[{i},{j}]: got {got} ref {want}"
16272                    );
16273                }
16274            }
16275        }
16276    }
16277
16278    /// AD knob: gradient through `scan_checkpointed` automatically
16279    /// uses the recompute backward path. Compares dinit from a plain
16280    /// scan against the same forward written with `scan_checkpointed`,
16281    /// both run through `grad_with_loss`. They must match to f64.
16282    #[test]
16283    fn scan_checkpointed_grad_matches_plain_scan_grad() {
16284        use rlx_opt::autodiff::grad_with_loss;
16285        let n = 2usize;
16286        let length = 6u32;
16287
16288        let make_body = || {
16289            let mut body = Graph::new("ck_body");
16290            let carry = body.input("carry", Shape::new(&[n], DType::F64));
16291            let scale_bytes: Vec<u8> = (0..n).flat_map(|_| 1.05_f64.to_le_bytes()).collect();
16292            let scale = body.add_node(
16293                Op::Constant { data: scale_bytes },
16294                vec![],
16295                Shape::new(&[n], DType::F64),
16296            );
16297            let next = body.binary(BinaryOp::Mul, carry, scale, Shape::new(&[n], DType::F64));
16298            body.set_outputs(vec![next]);
16299            body
16300        };
16301
16302        // Plain scan path.
16303        let mut g_plain = Graph::new("ck_plain");
16304        let init_p = g_plain.input("init", Shape::new(&[n], DType::F64));
16305        let final_p = g_plain.scan(init_p, make_body(), length);
16306        let loss_p = g_plain.reduce(
16307            final_p,
16308            ReduceOp::Sum,
16309            vec![0],
16310            false,
16311            Shape::new(&[1], DType::F64),
16312        );
16313        g_plain.set_outputs(vec![loss_p]);
16314        let bwd_p = grad_with_loss(&g_plain, &[init_p]);
16315
16316        // Checkpointed scan path with K=2 (length=6).
16317        let mut g_ck = Graph::new("ck_ckpt");
16318        let init_c = g_ck.input("init", Shape::new(&[n], DType::F64));
16319        let final_c = g_ck.scan_checkpointed(init_c, make_body(), length, 2);
16320        let loss_c = g_ck.reduce(
16321            final_c,
16322            ReduceOp::Sum,
16323            vec![0],
16324            false,
16325            Shape::new(&[1], DType::F64),
16326        );
16327        g_ck.set_outputs(vec![loss_c]);
16328        let bwd_c = grad_with_loss(&g_ck, &[init_c]);
16329
16330        let find = |graph: &Graph, want: &str| -> NodeId {
16331            for node in graph.nodes() {
16332                let name = match &node.op {
16333                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
16334                    _ => None,
16335                };
16336                if name == Some(want) {
16337                    return node.id;
16338                }
16339            }
16340            panic!("no {want}");
16341        };
16342
16343        let init_data = vec![0.5_f64, -0.5];
16344        let d_seed = [1.0_f64];
16345
16346        let (s_p, mut a_p) = prepare_f64(
16347            &bwd_p,
16348            &[
16349                (find(&bwd_p, "init"), &init_data),
16350                (find(&bwd_p, "d_output"), &d_seed),
16351            ],
16352        );
16353        execute_thunks(&s_p, a_p.raw_buf_mut());
16354        let dinit_p = read_arena_f64(&a_p, bwd_p.outputs[1], n);
16355
16356        let (s_c, mut a_c) = prepare_f64(
16357            &bwd_c,
16358            &[
16359                (find(&bwd_c, "init"), &init_data),
16360                (find(&bwd_c, "d_output"), &d_seed),
16361            ],
16362        );
16363        execute_thunks(&s_c, a_c.raw_buf_mut());
16364        let dinit_c = read_arena_f64(&a_c, bwd_c.outputs[1], n);
16365
16366        for i in 0..n {
16367            assert!(
16368                (dinit_p[i] - dinit_c[i]).abs() < 1e-12,
16369                "dinit[{i}]: plain={} checkpointed={}",
16370                dinit_p[i],
16371                dinit_c[i]
16372            );
16373        }
16374    }
16375
16376    /// Recursive checkpointing end-to-end: build a ScanBackward
16377    /// configured with K=2 checkpoints (for length=4), and compare
16378    /// dinit against the same backward graph with full trajectory
16379    /// (K=0). Forward computes a cumulative-sum-style scan; loss = sum.
16380    /// Both paths must agree to f64 precision.
16381    #[test]
16382    fn recursive_checkpointing_matches_full_trajectory() {
16383        let n = 2usize;
16384        let length = 4u32;
16385
16386        // Body: carry + ones (deterministic, no xs)
16387        let build_body = || -> Graph {
16388            let mut body = Graph::new("rc_body");
16389            let carry = body.input("carry", Shape::new(&[n], DType::F64));
16390            let ones_bytes: Vec<u8> = (0..n).flat_map(|_| 1.0_f64.to_le_bytes()).collect();
16391            let ones = body.add_node(
16392                Op::Constant { data: ones_bytes },
16393                vec![],
16394                Shape::new(&[n], DType::F64),
16395            );
16396            let next = body.binary(BinaryOp::Add, carry, ones, Shape::new(&[n], DType::F64));
16397            body.set_outputs(vec![next]);
16398            body
16399        };
16400
16401        // body_vjp: same body + d_output, output dcarry. body_vjp is
16402        // used by ScanBackward to walk the chain rule per step.
16403        let body_vjp_for = || -> Graph {
16404            use rlx_opt::autodiff::grad;
16405            let body = build_body();
16406            // grad(body, [carry_id]) → graph with dcarry as the output.
16407            let carry_id = body
16408                .nodes()
16409                .iter()
16410                .find(|n| matches!(n.op, Op::Input { .. }))
16411                .map(|n| n.id)
16412                .unwrap();
16413            grad(&body, &[carry_id])
16414        };
16415
16416        // ── Forward (All-strategy): scan with full trajectory ──
16417        let mut g_full = Graph::new("rc_outer_full");
16418        let init_full = g_full.input("init", Shape::new(&[n], DType::F64));
16419        let traj_full_id = g_full.scan_trajectory(init_full, build_body(), length);
16420        // Hand-build a ScanBackward node that reads the full trajectory.
16421        let upstream_full = g_full.input("upstream", Shape::new(&[length as usize, n], DType::F64));
16422        let dinit_full_id = g_full.scan_backward(
16423            init_full,
16424            traj_full_id,
16425            upstream_full,
16426            &[],
16427            body_vjp_for(),
16428            length,
16429            true,
16430            Shape::new(&[n], DType::F64),
16431        );
16432        g_full.set_outputs(vec![dinit_full_id]);
16433
16434        // ── Forward (Recursive-2): scan saves only K=2 rows ──
16435        // Build the trajectory shape [K, *carry] = [2, 2].
16436        let k = 2u32;
16437        let mut g_rec = Graph::new("rc_outer_rec");
16438        let init_rec = g_rec.input("init", Shape::new(&[n], DType::F64));
16439        let traj_rec_id = g_rec.add_node(
16440            Op::Scan {
16441                body: Box::new(build_body()),
16442                length,
16443                save_trajectory: true,
16444                num_bcast: 0,
16445                num_xs: 0,
16446                num_checkpoints: k,
16447            },
16448            vec![init_rec],
16449            Shape::new(&[k as usize, n], DType::F64),
16450        );
16451        // Same upstream shape as the full version (the upstream is per
16452        // *forward step*, length rows — independent of K).
16453        let upstream_rec = g_rec.input("upstream", Shape::new(&[length as usize, n], DType::F64));
16454        let dinit_rec_id = g_rec.add_node(
16455            Op::ScanBackward {
16456                body_vjp: Box::new(body_vjp_for()),
16457                length,
16458                save_trajectory: true,
16459                num_xs: 0,
16460                num_checkpoints: k,
16461                forward_body: Some(Box::new(build_body())),
16462            },
16463            vec![init_rec, traj_rec_id, upstream_rec],
16464            Shape::new(&[n], DType::F64),
16465        );
16466        g_rec.set_outputs(vec![dinit_rec_id]);
16467
16468        // ── Run both, same inputs ──
16469        let init_data = vec![0.5_f64, -0.5];
16470        let upstream_data: Vec<f64> = (0..length as usize * n).map(|i| (i as f64) * 0.1).collect();
16471
16472        let find = |graph: &Graph, want: &str| -> NodeId {
16473            for node in graph.nodes() {
16474                if let Op::Input { name } = &node.op
16475                    && name == want
16476                {
16477                    return node.id;
16478                }
16479            }
16480            panic!("no input {want}");
16481        };
16482
16483        let (s_full, mut a_full) = prepare_f64(
16484            &g_full,
16485            &[
16486                (find(&g_full, "init"), &init_data),
16487                (find(&g_full, "upstream"), &upstream_data),
16488            ],
16489        );
16490        execute_thunks(&s_full, a_full.raw_buf_mut());
16491        let dinit_full = read_arena_f64(&a_full, g_full.outputs[0], n);
16492
16493        let (s_rec, mut a_rec) = prepare_f64(
16494            &g_rec,
16495            &[
16496                (find(&g_rec, "init"), &init_data),
16497                (find(&g_rec, "upstream"), &upstream_data),
16498            ],
16499        );
16500        execute_thunks(&s_rec, a_rec.raw_buf_mut());
16501        let dinit_rec = read_arena_f64(&a_rec, g_rec.outputs[0], n);
16502
16503        for i in 0..n {
16504            assert!(
16505                (dinit_full[i] - dinit_rec[i]).abs() < 1e-12,
16506                "i={i}: full={} rec={}",
16507                dinit_full[i],
16508                dinit_rec[i]
16509            );
16510        }
16511    }
16512
16513    /// vmap-of-grad: gradient through Scan, vmap'd over init.
16514    /// Forward (per row):
16515    ///   carry_{t+1} = carry_t + ones    (body adds a constant)
16516    ///   loss = sum(carry_length) = sum(init) + length·n
16517    /// Closed form: dloss/dinit_i = 1 for every i. vmap over init at
16518    /// batch=3 → dinit_batched is all-ones [3, n]. Cross-checks
16519    /// against per-row grad_with_loss runs. Validates the vmap rule
16520    /// for Op::ScanBackward.
16521    #[test]
16522    fn vmap_of_grad_scan_matches_per_row_runs() {
16523        use rlx_opt::autodiff::grad_with_loss;
16524        use rlx_opt::vmap::vmap;
16525        let n = 2usize;
16526        let length = 3u32;
16527        let batch = 3usize;
16528
16529        let mut body = Graph::new("scan_grad_body");
16530        let carry = body.input("carry", Shape::new(&[n], DType::F64));
16531        let ones_bytes: Vec<u8> = (0..n).flat_map(|_| 1.0_f64.to_le_bytes()).collect();
16532        let ones = body.add_node(
16533            Op::Constant { data: ones_bytes },
16534            vec![],
16535            Shape::new(&[n], DType::F64),
16536        );
16537        let next = body.binary(BinaryOp::Add, carry, ones, Shape::new(&[n], DType::F64));
16538        body.set_outputs(vec![next]);
16539
16540        let mut g = Graph::new("scan_grad_outer");
16541        let init = g.input("init", Shape::new(&[n], DType::F64));
16542        let final_x = g.scan(init, body, length);
16543        let loss = g.reduce(
16544            final_x,
16545            ReduceOp::Sum,
16546            vec![0],
16547            false,
16548            Shape::new(&[1], DType::F64),
16549        );
16550        g.set_outputs(vec![loss]);
16551
16552        let bwd = grad_with_loss(&g, &[init]);
16553        let bg = vmap(&bwd, &["init"], batch);
16554
16555        let find = |graph: &Graph, want: &str| -> NodeId {
16556            for node in graph.nodes() {
16557                let name = match &node.op {
16558                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
16559                    _ => None,
16560                };
16561                if name == Some(want) {
16562                    return node.id;
16563                }
16564            }
16565            panic!("no node named {want}");
16566        };
16567        let init_b = find(&bg, "init");
16568        let d_out_b = find(&bg, "d_output");
16569
16570        let init_data: Vec<f64> = (0..batch * n).map(|i| (i as f64) * 0.5).collect();
16571        let d_seed = [1.0_f64];
16572
16573        let (sched, mut arena) = prepare_f64(&bg, &[(init_b, &init_data), (d_out_b, &d_seed)]);
16574        execute_thunks(&sched, arena.raw_buf_mut());
16575        let dinit_b = read_arena_f64(&arena, bg.outputs[1], batch * n);
16576
16577        for i in 0..batch * n {
16578            assert!(
16579                (dinit_b[i] - 1.0).abs() < 1e-12,
16580                "dinit[{i}] = {} (expected 1.0)",
16581                dinit_b[i]
16582            );
16583        }
16584
16585        // Cross-check vs per-row grad_with_loss.
16586        for bi in 0..batch {
16587            let row = &init_data[bi * n..(bi + 1) * n];
16588            let mut g2 = Graph::new("per_row_grad");
16589            let init2 = g2.input("init", Shape::new(&[n], DType::F64));
16590            let mut body2 = Graph::new("per_row_body");
16591            let c2 = body2.input("carry", Shape::new(&[n], DType::F64));
16592            let ones2_bytes: Vec<u8> = (0..n).flat_map(|_| 1.0_f64.to_le_bytes()).collect();
16593            let ones2 = body2.add_node(
16594                Op::Constant { data: ones2_bytes },
16595                vec![],
16596                Shape::new(&[n], DType::F64),
16597            );
16598            let next2 = body2.binary(BinaryOp::Add, c2, ones2, Shape::new(&[n], DType::F64));
16599            body2.set_outputs(vec![next2]);
16600            let final2 = g2.scan(init2, body2, length);
16601            let loss2 = g2.reduce(
16602                final2,
16603                ReduceOp::Sum,
16604                vec![0],
16605                false,
16606                Shape::new(&[1], DType::F64),
16607            );
16608            g2.set_outputs(vec![loss2]);
16609            let bwd2 = grad_with_loss(&g2, &[init2]);
16610            let init2_id = find(&bwd2, "init");
16611            let d_out2_id = find(&bwd2, "d_output");
16612            let (s2, mut a2) = prepare_f64(&bwd2, &[(init2_id, row), (d_out2_id, &d_seed)]);
16613            execute_thunks(&s2, a2.raw_buf_mut());
16614            let row_dinit = read_arena_f64(&a2, bwd2.outputs[1], n);
16615            for j in 0..n {
16616                let got = dinit_b[bi * n + j];
16617                let want = row_dinit[j];
16618                assert!(
16619                    (got - want).abs() < 1e-12,
16620                    "row {bi}, j {j}: vmap'd={got} per-row={want}"
16621                );
16622            }
16623        }
16624    }
16625
16626    /// vmap of Op::Scan: batched cumulative-sum. Forward
16627    ///   carry_{t+1} = carry_t + xs\[t\]
16628    ///   final = init + sum(xs)
16629    /// vmap over both init and xs at batch=3. Each batch row should
16630    /// equal the scalar run of the same body+xs subset.
16631    #[test]
16632    fn vmap_scan_cumulative_sum_matches_scalar_runs() {
16633        use rlx_opt::vmap::vmap;
16634        let n = 2usize;
16635        let length = 4u32;
16636        let batch = 3usize;
16637
16638        // Body: (carry, x_t) → carry + x_t
16639        let mut body = Graph::new("scan_body_cumsum");
16640        let carry = body.input("carry", Shape::new(&[n], DType::F64));
16641        let x_t = body.input("x_t", Shape::new(&[n], DType::F64));
16642        let next = body.binary(BinaryOp::Add, carry, x_t, Shape::new(&[n], DType::F64));
16643        body.set_outputs(vec![next]);
16644
16645        let mut g = Graph::new("scan_outer_cumsum");
16646        let init = g.input("init", Shape::new(&[n], DType::F64));
16647        let xs = g.input("xs", Shape::new(&[length as usize, n], DType::F64));
16648        let final_carry = g.scan_with_xs(init, &[xs], body, length);
16649        g.set_outputs(vec![final_carry]);
16650
16651        // vmap over both init and xs.
16652        let bg = vmap(&g, &["init", "xs"], batch);
16653
16654        // Test data — distinct per-batch rows.
16655        let init_data: Vec<f64> = (0..batch * n).map(|i| (i + 1) as f64).collect();
16656        // xs has shape [B, length, n] after vmap (the outer's xs is
16657        // [length, n]; vmap lifts it to [B, length, n]).
16658        let xs_data: Vec<f64> = (0..batch * length as usize * n)
16659            .map(|i| 0.1 * (i as f64))
16660            .collect();
16661
16662        let find = |graph: &Graph, want: &str| -> NodeId {
16663            for node in graph.nodes() {
16664                if let Op::Input { name } = &node.op
16665                    && name == want
16666                {
16667                    return node.id;
16668                }
16669            }
16670            panic!("no input {want}");
16671        };
16672        let init_b = find(&bg, "init");
16673        let xs_b = find(&bg, "xs");
16674        let (sched, mut arena) = prepare_f64(&bg, &[(init_b, &init_data), (xs_b, &xs_data)]);
16675        execute_thunks(&sched, arena.raw_buf_mut());
16676        let batched_out = read_arena_f64(&arena, bg.outputs[0], batch * n);
16677
16678        // Reference: per-batch scalar Scan.
16679        for bi in 0..batch {
16680            let init_slice = &init_data[bi * n..(bi + 1) * n];
16681            let mut x = init_slice.to_vec();
16682            for t in 0..length as usize {
16683                for j in 0..n {
16684                    x[j] += xs_data[bi * length as usize * n + t * n + j];
16685                }
16686            }
16687
16688            for i in 0..n {
16689                let got = batched_out[bi * n + i];
16690                assert!(
16691                    (got - x[i]).abs() < 1e-12,
16692                    "row {bi}, i {i}: got {got} ref {}",
16693                    x[i]
16694                );
16695            }
16696        }
16697    }
16698
16699    /// vmap of dense solve — Circulax-shaped batched parameter sweep.
16700    /// Forward: x = solve(A, b). vmap over both A (batched [B,N,N])
16701    /// and b (batched [B,N]). Run on CPU and compare each batch row
16702    /// against an independent scalar dgesv.
16703    #[test]
16704    fn vmap_dense_solve_matches_scalar_runs() {
16705        use rlx_opt::vmap::vmap;
16706        let n = 3usize;
16707        let batch = 4usize;
16708
16709        let mut g = Graph::new("solve_forward");
16710        let a = g.input("A", Shape::new(&[n, n], DType::F64));
16711        let b = g.input("b", Shape::new(&[n], DType::F64));
16712        let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
16713        g.set_outputs(vec![x]);
16714
16715        // vmap both A and b across the batch.
16716        let bg = vmap(&g, &["A", "b"], batch);
16717
16718        // Independent A and b per batch row.
16719        let mut rng = rlx_ir::Philox4x32::new(0xb47c_u64);
16720        let mut a_data = vec![0.0_f64; batch * n * n];
16721        let mut b_data = vec![0.0_f64; batch * n];
16722        for bi in 0..batch {
16723            // Diagonally dominant A — guaranteed non-singular.
16724            for i in 0..n {
16725                for j in 0..n {
16726                    a_data[bi * n * n + i * n + j] = rng.next_f32() as f64 * 0.1;
16727                }
16728                a_data[bi * n * n + i * n + i] += 1.0 + n as f64;
16729            }
16730            for i in 0..n {
16731                b_data[bi * n + i] = rng.next_f32() as f64;
16732            }
16733        }
16734
16735        let find = |graph: &Graph, want: &str| -> NodeId {
16736            for node in graph.nodes() {
16737                if let Op::Input { name } = &node.op
16738                    && name == want
16739                {
16740                    return node.id;
16741                }
16742            }
16743            panic!("no input named {want}");
16744        };
16745        let ba = find(&bg, "A");
16746        let bb = find(&bg, "b");
16747        let (sched, mut arena) = prepare_f64(&bg, &[(ba, &a_data), (bb, &b_data)]);
16748        execute_thunks(&sched, arena.raw_buf_mut());
16749        let batched_x = read_arena_f64(&arena, bg.outputs[0], batch * n);
16750
16751        // Reference: per-batch dgesv.
16752        for bi in 0..batch {
16753            let mut a_slice: Vec<f64> = a_data[bi * n * n..(bi + 1) * n * n].to_vec();
16754            let mut b_slice: Vec<f64> = b_data[bi * n..(bi + 1) * n].to_vec();
16755            crate::blas::dgesv(&mut a_slice, &mut b_slice, n, 1);
16756            for i in 0..n {
16757                let got = batched_x[bi * n + i];
16758                let want = b_slice[i];
16759                assert!(
16760                    (got - want).abs() < 1e-12,
16761                    "row {bi}, i {i}: got {got} want {want}"
16762                );
16763            }
16764        }
16765    }
16766
16767    /// vmap end-to-end: build a graph that computes y = MatMul(x, w) + b
16768    /// and reduces to a per-element loss. vmap over x with batch=4.
16769    /// Run the batched graph and compare each output row against an
16770    /// independent scalar run of the original graph. Validates the
16771    /// structural lift + the runtime path for batched MatMul +
16772    /// batched Binary + batched Reduce.
16773    #[test]
16774    fn vmap_matmul_add_reduce_matches_scalar_runs() {
16775        use rlx_opt::vmap::vmap;
16776        let n = 3usize;
16777        let batch = 4usize;
16778
16779        // Forward graph: y = MatMul(reshape(x, [1,n]), w) + b ; loss = sum(y).
16780        let mut g = Graph::new("vmap_e2e_forward");
16781        let x = g.input("x", Shape::new(&[n], DType::F64));
16782        let w = g.input("w", Shape::new(&[n, n], DType::F64));
16783        let b = g.input("b", Shape::new(&[n], DType::F64));
16784        let x_row = g.add_node(
16785            Op::Reshape {
16786                new_shape: vec![1, n as i64],
16787            },
16788            vec![x],
16789            Shape::new(&[1, n], DType::F64),
16790        );
16791        let mm = g.matmul(x_row, w, Shape::new(&[1, n], DType::F64));
16792        let mm_flat = g.add_node(
16793            Op::Reshape {
16794                new_shape: vec![n as i64],
16795            },
16796            vec![mm],
16797            Shape::new(&[n], DType::F64),
16798        );
16799        let yv = g.binary(BinaryOp::Add, mm_flat, b, Shape::new(&[n], DType::F64));
16800        let loss = g.reduce(
16801            yv,
16802            ReduceOp::Sum,
16803            vec![0],
16804            false,
16805            Shape::new(&[1], DType::F64),
16806        );
16807        g.set_outputs(vec![loss]);
16808
16809        // Build the vmap'd version (batch over x; w and b shared).
16810        let bg = vmap(&g, &["x"], batch);
16811
16812        // Test data — distinct rows so we can verify the per-row dispatch.
16813        let mut rng = rlx_ir::Philox4x32::new(0xc1c0_u64);
16814        let n_w = n * n;
16815        let w_data: Vec<f64> = (0..n_w).map(|_| rng.next_f32() as f64).collect();
16816        let b_data: Vec<f64> = (0..n).map(|_| rng.next_f32() as f64).collect();
16817        let mut x_data_batched: Vec<f64> = Vec::with_capacity(batch * n);
16818        for _ in 0..batch * n {
16819            x_data_batched.push(rng.next_f32() as f64);
16820        }
16821
16822        // Run the batched graph.
16823        let find = |graph: &Graph, want: &str| -> NodeId {
16824            for node in graph.nodes() {
16825                if let Op::Input { name } = &node.op
16826                    && name == want
16827                {
16828                    return node.id;
16829                }
16830            }
16831            panic!("no input named {want}");
16832        };
16833        let bx = find(&bg, "x");
16834        let bw = find(&bg, "w");
16835        let bb = find(&bg, "b");
16836        let (sched, mut arena) =
16837            prepare_f64(&bg, &[(bx, &x_data_batched), (bw, &w_data), (bb, &b_data)]);
16838        execute_thunks(&sched, arena.raw_buf_mut());
16839        // Reduce::Sum on shifted axis 1 with keep_dim=false → output [B, 1]
16840        // (it preserves the leading batch axis but reduces what was [n] to [].
16841        // Since the original output was [1] f64 and the reduce was over
16842        // axis 0, after vmap the leading-axis-shifted reduce keeps the
16843        // leading 1 from the original output's [1] shape.)
16844        let batched_out = read_arena_f64(&arena, bg.outputs[0], batch);
16845
16846        // Reference: run the original (un-batched) graph once per batch row.
16847        for bi in 0..batch {
16848            let xs_slice = &x_data_batched[bi * n..(bi + 1) * n];
16849            let mut g2 = Graph::new("scalar_run");
16850            let x2 = g2.input("x", Shape::new(&[n], DType::F64));
16851            let w2 = g2.input("w", Shape::new(&[n, n], DType::F64));
16852            let b2 = g2.input("b", Shape::new(&[n], DType::F64));
16853            let xr = g2.add_node(
16854                Op::Reshape {
16855                    new_shape: vec![1, n as i64],
16856                },
16857                vec![x2],
16858                Shape::new(&[1, n], DType::F64),
16859            );
16860            let m = g2.matmul(xr, w2, Shape::new(&[1, n], DType::F64));
16861            let mf = g2.add_node(
16862                Op::Reshape {
16863                    new_shape: vec![n as i64],
16864                },
16865                vec![m],
16866                Shape::new(&[n], DType::F64),
16867            );
16868            let yv2 = g2.binary(BinaryOp::Add, mf, b2, Shape::new(&[n], DType::F64));
16869            let l2 = g2.reduce(
16870                yv2,
16871                ReduceOp::Sum,
16872                vec![0],
16873                false,
16874                Shape::new(&[1], DType::F64),
16875            );
16876            g2.set_outputs(vec![l2]);
16877            let (s2, mut a2) = prepare_f64(&g2, &[(x2, xs_slice), (w2, &w_data), (b2, &b_data)]);
16878            execute_thunks(&s2, a2.raw_buf_mut());
16879            let scalar_out = read_arena_f64(&a2, l2, 1);
16880            assert!(
16881                (batched_out[bi] - scalar_out[0]).abs() < 1e-12,
16882                "row {bi}: batched={} scalar={}",
16883                batched_out[bi],
16884                scalar_out[0]
16885            );
16886        }
16887    }
16888
16889    /// Full gradient through scan-with-xs: dinit AND dxs both checked
16890    /// against finite differences. Forward
16891    ///   carry_{t+1} = solve(M, carry_t + xs\[t\])
16892    ///   loss        = sum(carry_length)
16893    /// Verifies that grad_with_loss returns gradients w.r.t. both
16894    /// `init` and `xs` and that dxs matches per-element FD.
16895    #[test]
16896    fn scan_with_xs_dxs_matches_fd() {
16897        use rlx_opt::autodiff::grad_with_loss;
16898        let n = 3usize;
16899        let length = 3u32;
16900        let dt = 0.1_f64;
16901
16902        let mut m_data = vec![0.0_f64; n * n];
16903        for i in 0..n {
16904            m_data[i * n + i] = 1.0 + dt * 2.0;
16905            if i > 0 {
16906                m_data[i * n + (i - 1)] = -dt;
16907            }
16908            if i + 1 < n {
16909                m_data[i * n + (i + 1)] = -dt;
16910            }
16911        }
16912        let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
16913
16914        let mut body = Graph::new("be_dxs_body");
16915        let carry = body.input("carry", Shape::new(&[n], DType::F64));
16916        let drive = body.input("drive", Shape::new(&[n], DType::F64));
16917        let m = body.add_node(
16918            Op::Constant { data: m_bytes },
16919            vec![],
16920            Shape::new(&[n, n], DType::F64),
16921        );
16922        let driven = body.binary(BinaryOp::Add, carry, drive, Shape::new(&[n], DType::F64));
16923        let next = body.dense_solve(m, driven, Shape::new(&[n], DType::F64));
16924        body.set_outputs(vec![next]);
16925
16926        let mut g = Graph::new("be_dxs_outer");
16927        let init = g.input("init", Shape::new(&[n], DType::F64));
16928        let xs = g.input("xs", Shape::new(&[length as usize, n], DType::F64));
16929        let final_carry = g.scan_with_xs(init, &[xs], body, length);
16930        let loss = g.reduce(
16931            final_carry,
16932            ReduceOp::Sum,
16933            vec![0],
16934            false,
16935            Shape::new(&[1], DType::F64),
16936        );
16937        g.set_outputs(vec![loss]);
16938
16939        // wrt = [init, xs] — get both gradients back.
16940        let bwd = grad_with_loss(&g, &[init, xs]);
16941        assert_eq!(bwd.outputs.len(), 3, "[loss, dinit, dxs]");
16942
16943        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
16944            for node in graph.nodes() {
16945                let name = match &node.op {
16946                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
16947                    _ => None,
16948                };
16949                if name == Some(want) {
16950                    return node.id;
16951                }
16952            }
16953            panic!("no node named {want:?}");
16954        };
16955        let init_bwd = find_by_name(&bwd, "init");
16956        let xs_bwd = find_by_name(&bwd, "xs");
16957        let d_out_bwd = find_by_name(&bwd, "d_output");
16958
16959        let init_data = vec![0.5_f64, 0.0, -0.5];
16960        let xs_data: Vec<f64> = (0..length as usize * n)
16961            .map(|i| 0.1_f64 * ((i as f64) - 4.0))
16962            .collect();
16963        let d_seed = [1.0_f64];
16964
16965        let (sched, mut arena) = prepare_f64(
16966            &bwd,
16967            &[
16968                (init_bwd, &init_data),
16969                (xs_bwd, &xs_data),
16970                (d_out_bwd, &d_seed),
16971            ],
16972        );
16973        execute_thunks(&sched, arena.raw_buf_mut());
16974        let dinit = read_arena_f64(&arena, bwd.outputs[1], n);
16975        let dxs = read_arena_f64(&arena, bwd.outputs[2], length as usize * n);
16976
16977        let h = 1e-6;
16978        let loss_at = |x0: &[f64], xs_in: &[f64]| -> f64 {
16979            let mut acc = x0.to_vec();
16980            for t in 0..length as usize {
16981                for j in 0..n {
16982                    acc[j] += xs_in[t * n + j];
16983                }
16984                let mut a_copy = m_data.clone();
16985                crate::blas::dgesv(&mut a_copy, &mut acc, n, 1);
16986            }
16987            acc.iter().sum()
16988        };
16989
16990        // FD on dinit (sanity).
16991        for i in 0..n {
16992            let mut ip = init_data.to_vec();
16993            ip[i] += h;
16994            let mut im = init_data.to_vec();
16995            im[i] -= h;
16996            let fd = (loss_at(&ip, &xs_data) - loss_at(&im, &xs_data)) / (2.0 * h);
16997            assert!(
16998                (dinit[i] - fd).abs() < 1e-7,
16999                "FD dinit[{i}]: AD={} FD={}",
17000                dinit[i],
17001                fd
17002            );
17003        }
17004
17005        // FD on every dxs entry — full per-step gradient check.
17006        for t in 0..length as usize {
17007            for j in 0..n {
17008                let idx = t * n + j;
17009                let mut xp = xs_data.clone();
17010                xp[idx] += h;
17011                let mut xm = xs_data.clone();
17012                xm[idx] -= h;
17013                let fd = (loss_at(&init_data, &xp) - loss_at(&init_data, &xm)) / (2.0 * h);
17014                assert!(
17015                    (dxs[idx] - fd).abs() < 1e-7,
17016                    "FD dxs[t={t},j={j}]: AD={} FD={}",
17017                    dxs[idx],
17018                    fd
17019                );
17020            }
17021        }
17022    }
17023
17024    /// Gradient through a scan with per-step xs (Circulax-shaped).
17025    /// Forward:
17026    ///   carry_{t+1} = solve(M, carry_t + xs\[t\])
17027    ///   loss = sum(carry_length)
17028    /// dxs is out of MVP (asserted in the VJP rule's body_vjp `wrt`),
17029    /// but `dinit` flows correctly through the body's reverse Jacobian
17030    /// even with xs in the chain. Verify dinit against finite differences.
17031    #[test]
17032    fn scan_with_xs_gradient_dinit_matches_fd() {
17033        use rlx_opt::autodiff::grad_with_loss;
17034        let n = 3usize;
17035        let length = 3u32;
17036        let dt = 0.1_f64;
17037
17038        let mut m_data = vec![0.0_f64; n * n];
17039        for i in 0..n {
17040            m_data[i * n + i] = 1.0 + dt * 2.0;
17041            if i > 0 {
17042                m_data[i * n + (i - 1)] = -dt;
17043            }
17044            if i + 1 < n {
17045                m_data[i * n + (i + 1)] = -dt;
17046            }
17047        }
17048        let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
17049
17050        let mut body = Graph::new("be_xs_grad_body");
17051        let carry = body.input("carry", Shape::new(&[n], DType::F64));
17052        let drive = body.input("drive", Shape::new(&[n], DType::F64));
17053        let m = body.add_node(
17054            Op::Constant { data: m_bytes },
17055            vec![],
17056            Shape::new(&[n, n], DType::F64),
17057        );
17058        let driven = body.binary(BinaryOp::Add, carry, drive, Shape::new(&[n], DType::F64));
17059        let next = body.dense_solve(m, driven, Shape::new(&[n], DType::F64));
17060        body.set_outputs(vec![next]);
17061
17062        let mut g = Graph::new("be_xs_grad_outer");
17063        let init = g.input("init", Shape::new(&[n], DType::F64));
17064        let xs = g.input("xs", Shape::new(&[length as usize, n], DType::F64));
17065        let final_carry = g.scan_with_xs(init, &[xs], body, length);
17066        let loss = g.reduce(
17067            final_carry,
17068            ReduceOp::Sum,
17069            vec![0],
17070            false,
17071            Shape::new(&[1], DType::F64),
17072        );
17073        g.set_outputs(vec![loss]);
17074
17075        let bwd = grad_with_loss(&g, &[init]);
17076
17077        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
17078            for node in graph.nodes() {
17079                let name = match &node.op {
17080                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
17081                    _ => None,
17082                };
17083                if name == Some(want) {
17084                    return node.id;
17085                }
17086            }
17087            panic!("no node named {want:?}");
17088        };
17089        let init_bwd = find_by_name(&bwd, "init");
17090        let xs_bwd = find_by_name(&bwd, "xs");
17091        let d_out_bwd = find_by_name(&bwd, "d_output");
17092
17093        let init_data = vec![0.5_f64, 0.0, -0.5];
17094        // Drive: small per-step pulse, varying per element.
17095        let xs_data: Vec<f64> = (0..length as usize * n)
17096            .map(|i| 0.1_f64 * ((i as f64) - 4.0))
17097            .collect();
17098        let d_seed = [1.0_f64];
17099
17100        let (sched, mut arena) = prepare_f64(
17101            &bwd,
17102            &[
17103                (init_bwd, &init_data),
17104                (xs_bwd, &xs_data),
17105                (d_out_bwd, &d_seed),
17106            ],
17107        );
17108        execute_thunks(&sched, arena.raw_buf_mut());
17109        let dinit = read_arena_f64(&arena, bwd.outputs[1], n);
17110
17111        let h = 1e-6;
17112        let loss_at = |x0: &[f64]| -> f64 {
17113            let mut acc = x0.to_vec();
17114            for t in 0..length as usize {
17115                for j in 0..n {
17116                    acc[j] += xs_data[t * n + j];
17117                }
17118                let mut a_copy = m_data.clone();
17119                crate::blas::dgesv(&mut a_copy, &mut acc, n, 1);
17120            }
17121            acc.iter().sum()
17122        };
17123        for i in 0..n {
17124            let mut ip = init_data.to_vec();
17125            ip[i] += h;
17126            let mut im = init_data.to_vec();
17127            im[i] -= h;
17128            let fd = (loss_at(&ip) - loss_at(&im)) / (2.0 * h);
17129            assert!(
17130                (dinit[i] - fd).abs() < 1e-7,
17131                "FD dinit[{i}]: AD={} FD={}",
17132                dinit[i],
17133                fd
17134            );
17135        }
17136    }
17137
17138    /// Gradient through a geometric-growth scan: forward
17139    ///   x_{t+1} = 1.1 · x_t,    x_0 = init
17140    ///   final   = x_length     = init · 1.1^length
17141    ///   loss    = sum(final)
17142    /// closed-form ∂loss/∂init\[i\] = 1.1^length for every i.
17143    /// Validates the VJP path: AD pre-pass rewrites save_trajectory=false
17144    /// to true, autodiff emits Op::ScanBackward, executor walks t back.
17145    #[test]
17146    fn scan_gradient_geometric_matches_closed_form() {
17147        use rlx_opt::autodiff::grad_with_loss;
17148        let n = 3usize;
17149        let length = 5u32;
17150
17151        let mut body = Graph::new("scan_grad_body");
17152        let x = body.input("carry", Shape::new(&[n], DType::F64));
17153        let scale_bytes: Vec<u8> = (0..n).flat_map(|_| 1.1_f64.to_le_bytes()).collect();
17154        let scale = body.add_node(
17155            Op::Constant { data: scale_bytes },
17156            vec![],
17157            Shape::new(&[n], DType::F64),
17158        );
17159        let next = body.binary(BinaryOp::Mul, x, scale, Shape::new(&[n], DType::F64));
17160        body.set_outputs(vec![next]);
17161
17162        let mut g = Graph::new("scan_grad_outer");
17163        let init = g.input("init", Shape::new(&[n], DType::F64));
17164        let final_x = g.scan(init, body, length);
17165        let loss = g.reduce(
17166            final_x,
17167            ReduceOp::Sum,
17168            vec![0],
17169            false,
17170            Shape::new(&[1], DType::F64),
17171        );
17172        g.set_outputs(vec![loss]);
17173
17174        let bwd = grad_with_loss(&g, &[init]);
17175        assert_eq!(bwd.outputs.len(), 2);
17176
17177        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
17178            for node in graph.nodes() {
17179                let name = match &node.op {
17180                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
17181                    _ => None,
17182                };
17183                if name == Some(want) {
17184                    return node.id;
17185                }
17186            }
17187            panic!("no node named {want:?}");
17188        };
17189        let init_bwd = find_by_name(&bwd, "init");
17190        let d_out_bwd = find_by_name(&bwd, "d_output");
17191
17192        let init_data = vec![1.0_f64; n];
17193        let d_seed = [1.0_f64];
17194        let (sched, mut arena) = prepare_f64(&bwd, &[(init_bwd, &init_data), (d_out_bwd, &d_seed)]);
17195        execute_thunks(&sched, arena.raw_buf_mut());
17196        let dinit = read_arena_f64(&arena, bwd.outputs[1], n);
17197
17198        let want = 1.1_f64.powi(length as i32);
17199        for i in 0..n {
17200            assert!(
17201                (dinit[i] - want).abs() < 1e-12,
17202                "dinit[{i}] = {} want {}",
17203                dinit[i],
17204                want
17205            );
17206        }
17207
17208        // Finite-difference cross-check on init[0].
17209        let h = 1e-6;
17210        let loss_at = |x: &[f64]| -> f64 {
17211            let mut acc = x.to_vec();
17212            for _ in 0..length {
17213                for v in acc.iter_mut() {
17214                    *v *= 1.1;
17215                }
17216            }
17217            acc.iter().sum()
17218        };
17219        let mut ip = init_data.clone();
17220        ip[0] += h;
17221        let mut im = init_data.clone();
17222        im[0] -= h;
17223        let fd = (loss_at(&ip) - loss_at(&im)) / (2.0 * h);
17224        assert!(
17225            (dinit[0] - fd).abs() < 1e-7,
17226            "FD dinit[0]: AD={} FD={}",
17227            dinit[0],
17228            fd
17229        );
17230    }
17231
17232    /// Gradient through Backward Euler scan composing with DenseSolve.
17233    /// Asserts dinit matches finite-difference per coordinate.
17234    #[test]
17235    fn scan_gradient_backward_euler_matches_fd() {
17236        use rlx_opt::autodiff::grad_with_loss;
17237        let n = 4usize;
17238        let length = 3u32;
17239        let dt = 0.05_f64;
17240
17241        let mut m_data = vec![0.0_f64; n * n];
17242        for i in 0..n {
17243            m_data[i * n + i] = 1.0 + dt * 2.0;
17244            if i > 0 {
17245                m_data[i * n + (i - 1)] = -dt;
17246            }
17247            if i + 1 < n {
17248                m_data[i * n + (i + 1)] = -dt;
17249            }
17250        }
17251        let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
17252
17253        let mut body = Graph::new("be_grad_body");
17254        let x = body.input("x", Shape::new(&[n], DType::F64));
17255        let m = body.add_node(
17256            Op::Constant { data: m_bytes },
17257            vec![],
17258            Shape::new(&[n, n], DType::F64),
17259        );
17260        let next = body.dense_solve(m, x, Shape::new(&[n], DType::F64));
17261        body.set_outputs(vec![next]);
17262
17263        let mut g = Graph::new("be_grad_outer");
17264        let init = g.input("x0", Shape::new(&[n], DType::F64));
17265        let final_x = g.scan(init, body, length);
17266        let loss = g.reduce(
17267            final_x,
17268            ReduceOp::Sum,
17269            vec![0],
17270            false,
17271            Shape::new(&[1], DType::F64),
17272        );
17273        g.set_outputs(vec![loss]);
17274
17275        let bwd = grad_with_loss(&g, &[init]);
17276
17277        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
17278            for node in graph.nodes() {
17279                let name = match &node.op {
17280                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
17281                    _ => None,
17282                };
17283                if name == Some(want) {
17284                    return node.id;
17285                }
17286            }
17287            panic!("no node named {want:?}");
17288        };
17289        let init_bwd = find_by_name(&bwd, "x0");
17290        let d_out_bwd = find_by_name(&bwd, "d_output");
17291
17292        let init_data: [f64; 4] = [0.0, 1.0, 0.0, 0.0];
17293        let d_seed = [1.0_f64];
17294        let (sched, mut arena) = prepare_f64(&bwd, &[(init_bwd, &init_data), (d_out_bwd, &d_seed)]);
17295        execute_thunks(&sched, arena.raw_buf_mut());
17296        let dinit = read_arena_f64(&arena, bwd.outputs[1], n);
17297
17298        let h = 1e-6;
17299        let loss_at = |x0: &[f64]| -> f64 {
17300            let mut acc = x0.to_vec();
17301            for _ in 0..length {
17302                let mut a_copy = m_data.clone();
17303                crate::blas::dgesv(&mut a_copy, &mut acc, n, 1);
17304            }
17305            acc.iter().sum()
17306        };
17307        for i in 0..n {
17308            let mut ip = init_data.to_vec();
17309            ip[i] += h;
17310            let mut im = init_data.to_vec();
17311            im[i] -= h;
17312            let fd = (loss_at(&ip) - loss_at(&im)) / (2.0 * h);
17313            assert!(
17314                (dinit[i] - fd).abs() < 1e-7,
17315                "FD dinit[{i}]: AD={} FD={}",
17316                dinit[i],
17317                fd
17318            );
17319        }
17320    }
17321
17322    /// Trajectory-mode scan: same Backward Euler body, but record the
17323    /// carry at every step. Output is `[length, n]` — row `t` is the
17324    /// state after step `t+1`. Validates the SaveAt-style waveform
17325    /// recording end-to-end, including that the last row equals what
17326    /// the no-trajectory variant would have returned.
17327    #[test]
17328    fn scan_trajectory_backward_euler_records_waveform() {
17329        let n = 4usize;
17330        let length = 5u32;
17331        let dt = 0.05_f64;
17332
17333        let mut m_data = vec![0.0_f64; n * n];
17334        for i in 0..n {
17335            m_data[i * n + i] = 1.0 + dt * 2.0;
17336            if i > 0 {
17337                m_data[i * n + (i - 1)] = -dt;
17338            }
17339            if i + 1 < n {
17340                m_data[i * n + (i + 1)] = -dt;
17341            }
17342        }
17343        let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
17344
17345        let mut body = Graph::new("be_traj_body");
17346        let x = body.input("x", Shape::new(&[n], DType::F64));
17347        let m = body.add_node(
17348            Op::Constant { data: m_bytes },
17349            vec![],
17350            Shape::new(&[n, n], DType::F64),
17351        );
17352        let next = body.dense_solve(m, x, Shape::new(&[n], DType::F64));
17353        body.set_outputs(vec![next]);
17354
17355        let mut g = Graph::new("be_traj_outer");
17356        let init = g.input("x0", Shape::new(&[n], DType::F64));
17357        let traj = g.scan_trajectory(init, body, length);
17358        g.set_outputs(vec![traj]);
17359
17360        let init_data: [f64; 4] = [0.0, 1.0, 0.0, 0.0];
17361        let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data)]);
17362        execute_thunks(&sched, arena.raw_buf_mut());
17363        let got = read_arena_f64(&arena, traj, length as usize * n);
17364
17365        // Reference: each step's solve, recorded.
17366        let mut want = Vec::<f64>::with_capacity(length as usize * n);
17367        let mut x_ref = init_data.to_vec();
17368        for _ in 0..length {
17369            let mut a_copy = m_data.clone();
17370            crate::blas::dgesv(&mut a_copy, &mut x_ref, n, 1);
17371            want.extend_from_slice(&x_ref);
17372        }
17373        for i in 0..length as usize * n {
17374            assert!(
17375                (got[i] - want[i]).abs() < 1e-12,
17376                "got[{i}] = {} ref {}",
17377                got[i],
17378                want[i]
17379            );
17380        }
17381
17382        // Sanity: trajectory rows are monotone-decreasing in mass
17383        // (Backward Euler diffuses; boundary leak removes mass).
17384        for t in 1..length as usize {
17385            let prev: f64 = got[(t - 1) * n..t * n].iter().sum();
17386            let curr: f64 = got[t * n..(t + 1) * n].iter().sum();
17387            assert!(
17388                curr <= prev + 1e-15,
17389                "mass should decay: row {} sum {prev}, row {t} sum {curr}",
17390                t - 1
17391            );
17392        }
17393
17394        // Last row of the trajectory equals what a non-trajectory
17395        // scan returns — verify by running the same forward through
17396        // the simpler API and comparing.
17397        let mut body2 = Graph::new("be_final_body");
17398        let x2 = body2.input("x", Shape::new(&[n], DType::F64));
17399        let m_bytes2: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
17400        let m2 = body2.add_node(
17401            Op::Constant { data: m_bytes2 },
17402            vec![],
17403            Shape::new(&[n, n], DType::F64),
17404        );
17405        let next2 = body2.dense_solve(m2, x2, Shape::new(&[n], DType::F64));
17406        body2.set_outputs(vec![next2]);
17407
17408        let mut g2 = Graph::new("be_final_outer");
17409        let init2 = g2.input("x0", Shape::new(&[n], DType::F64));
17410        let final_x = g2.scan(init2, body2, length);
17411        g2.set_outputs(vec![final_x]);
17412        let (sched2, mut arena2) = prepare_f64(&g2, &[(init2, &init_data)]);
17413        execute_thunks(&sched2, arena2.raw_buf_mut());
17414        let final_got = read_arena_f64(&arena2, final_x, n);
17415
17416        let last_row = &got[(length as usize - 1) * n..length as usize * n];
17417        for i in 0..n {
17418            assert!(
17419                (last_row[i] - final_got[i]).abs() < 1e-15,
17420                "last trajectory row[{i}] = {} vs final-scan = {}",
17421                last_row[i],
17422                final_got[i]
17423            );
17424        }
17425    }
17426
17427    /// Op::Scan composing with Op::DenseSolve — the Circulax-shaped
17428    /// pattern for Backward Euler.
17429    /// Body: x_{t+1} = solve(I + dt·A, x_t).
17430    /// 1-D heat-equation Laplacian A; analytic ground truth from
17431    /// composing the same per-step solve in Rust.
17432    #[test]
17433    fn scan_backward_euler_heat_f64() {
17434        let n = 4usize;
17435        let length = 5u32;
17436        let dt = 0.05_f64;
17437
17438        // Construct M = I + dt · L  where L is the Laplacian (-1, 2, -1).
17439        // M is constant across iterations; embed it in the body via Op::Constant.
17440        let mut m_data = vec![0.0_f64; n * n];
17441        for i in 0..n {
17442            m_data[i * n + i] = 1.0 + dt * 2.0;
17443            if i > 0 {
17444                m_data[i * n + (i - 1)] = -dt;
17445            }
17446            if i + 1 < n {
17447                m_data[i * n + (i + 1)] = -dt;
17448            }
17449        }
17450        let m_bytes: Vec<u8> = m_data.iter().flat_map(|x| x.to_le_bytes()).collect();
17451
17452        let mut body = Graph::new("be_body");
17453        let x = body.input("x", Shape::new(&[n], DType::F64));
17454        let m = body.add_node(
17455            Op::Constant { data: m_bytes },
17456            vec![],
17457            Shape::new(&[n, n], DType::F64),
17458        );
17459        let next = body.dense_solve(m, x, Shape::new(&[n], DType::F64));
17460        body.set_outputs(vec![next]);
17461
17462        let mut g = Graph::new("be_outer");
17463        let init = g.input("x0", Shape::new(&[n], DType::F64));
17464        let final_x = g.scan(init, body, length);
17465        g.set_outputs(vec![final_x]);
17466
17467        // Initial: a sharp pulse at index 1.
17468        let init_data: [f64; 4] = [0.0, 1.0, 0.0, 0.0];
17469        let (sched, mut arena) = prepare_f64(&g, &[(init, &init_data)]);
17470        execute_thunks(&sched, arena.raw_buf_mut());
17471        let got = read_arena_f64(&arena, final_x, n);
17472
17473        // Reference: apply the same M-solve `length` times in pure Rust.
17474        let mut ref_x = init_data.to_vec();
17475        for _ in 0..length {
17476            let mut a_copy = m_data.clone();
17477            crate::blas::dgesv(&mut a_copy, &mut ref_x, n, 1);
17478        }
17479        for i in 0..n {
17480            assert!(
17481                (got[i] - ref_x[i]).abs() < 1e-12,
17482                "got[{i}] = {} ref {}",
17483                got[i],
17484                ref_x[i]
17485            );
17486        }
17487        // Sanity: pulse should diffuse, mass should be conserved-ish
17488        // (Backward Euler is mass-conserving for this stencil with
17489        // zero-flux boundaries — but our boundaries leak, so check
17490        // that mass strictly decreases instead).
17491        let mass: f64 = got.iter().sum();
17492        assert!(mass > 0.0 && mass < 1.0, "diffusion mass: {mass}");
17493    }
17494
17495    /// Multi-RHS forward DenseSolve: X = solve(A, B) with B [N, K]
17496    /// stays correct end-to-end. Verifies the executor/lowering and
17497    /// the LAPACK column-major dance both honour `nrhs > 1`.
17498    #[test]
17499    fn dense_solve_f64_multi_rhs_forward() {
17500        let n = 3usize;
17501        let k = 2usize;
17502        let mut g = Graph::new("solve_multi_rhs");
17503        let a = g.input("A", Shape::new(&[n, n], DType::F64));
17504        let b = g.input("B", Shape::new(&[n, k], DType::F64));
17505        let x = g.dense_solve(a, b, Shape::new(&[n, k], DType::F64));
17506        g.set_outputs(vec![x]);
17507
17508        let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
17509        let b_data = [1.0, 4.0, 2.0, -1.0, 3.0, 2.0_f64];
17510        let (sched, mut arena) = prepare_f64(&g, &[(a, &a_data), (b, &b_data)]);
17511        execute_thunks(&sched, arena.raw_buf_mut());
17512        let x_got = read_arena_f64(&arena, x, n * k);
17513        for c in 0..k {
17514            for i in 0..n {
17515                let mut acc = 0.0_f64;
17516                for j in 0..n {
17517                    acc += a_data[i * n + j] * x_got[j * k + c];
17518                }
17519                let want = b_data[i * k + c];
17520                assert!(
17521                    (acc - want).abs() < 1e-10,
17522                    "col {c} row {i}: got {acc} want {want}"
17523                );
17524            }
17525        }
17526    }
17527
17528    /// Multi-RHS reverse-mode VJP: dB = (Aᵀ)⁻¹·1, dA = -dB · Xᵀ.
17529    /// Verified analytically + finite differences on dB[0,0].
17530    #[test]
17531    fn dense_solve_f64_multi_rhs_gradient() {
17532        use rlx_opt::autodiff::grad_with_loss;
17533        let n = 3usize;
17534        let k = 2usize;
17535        let mut g = Graph::new("solve_mrhs_grad");
17536        let a = g.param("A", Shape::new(&[n, n], DType::F64));
17537        let b = g.input("B", Shape::new(&[n, k], DType::F64));
17538        let x = g.dense_solve(a, b, Shape::new(&[n, k], DType::F64));
17539        let loss = g.reduce(
17540            x,
17541            ReduceOp::Sum,
17542            vec![0, 1],
17543            false,
17544            Shape::new(&[1], DType::F64),
17545        );
17546        g.set_outputs(vec![loss]);
17547
17548        let bwd = grad_with_loss(&g, &[a, b]);
17549        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
17550            for node in graph.nodes() {
17551                let name = match &node.op {
17552                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
17553                    _ => None,
17554                };
17555                if name == Some(want) {
17556                    return node.id;
17557                }
17558            }
17559            panic!("no node named {want:?}");
17560        };
17561        let a_bwd = find_by_name(&bwd, "A");
17562        let b_bwd = find_by_name(&bwd, "B");
17563        let d_out = find_by_name(&bwd, "d_output");
17564
17565        let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
17566        let b_data = [1.0, 4.0, 2.0, -1.0, 3.0, 2.0_f64];
17567        let d_seed = [1.0_f64];
17568
17569        let (sched, mut arena) = prepare_f64(
17570            &bwd,
17571            &[(a_bwd, &a_data), (b_bwd, &b_data), (d_out, &d_seed)],
17572        );
17573        execute_thunks(&sched, arena.raw_buf_mut());
17574        let da_got = read_arena_f64(&arena, bwd.outputs[1], n * n);
17575        let db_got = read_arena_f64(&arena, bwd.outputs[2], n * k);
17576
17577        // Reference.
17578        let mut x_ref = b_data;
17579        {
17580            let mut a_copy = a_data;
17581            crate::blas::dgesv(&mut a_copy, &mut x_ref, n, k);
17582        }
17583        let mut at = [0.0_f64; 9];
17584        for i in 0..n {
17585            for j in 0..n {
17586                at[i * n + j] = a_data[j * n + i];
17587            }
17588        }
17589        let mut ones_nk = vec![1.0_f64; n * k];
17590        crate::blas::dgesv(&mut at, &mut ones_nk, n, k);
17591        let db_ref = ones_nk;
17592        let mut da_ref = [0.0_f64; 9];
17593        for i in 0..n {
17594            for j in 0..n {
17595                let mut acc = 0.0_f64;
17596                for c in 0..k {
17597                    acc += db_ref[i * k + c] * x_ref[j * k + c];
17598                }
17599                da_ref[i * n + j] = -acc;
17600            }
17601        }
17602        for i in 0..n * k {
17603            assert!(
17604                (db_got[i] - db_ref[i]).abs() < 1e-10,
17605                "dB[{i}]: got {} want {}",
17606                db_got[i],
17607                db_ref[i]
17608            );
17609        }
17610        for i in 0..n * n {
17611            assert!(
17612                (da_got[i] - da_ref[i]).abs() < 1e-10,
17613                "dA[{i}]: got {} want {}",
17614                da_got[i],
17615                da_ref[i]
17616            );
17617        }
17618
17619        // FD on dB[0,0].
17620        let h = 1e-6;
17621        let mut bp = b_data;
17622        bp[0] += h;
17623        let mut bm = b_data;
17624        bm[0] -= h;
17625        let xp = {
17626            let mut a_copy = a_data;
17627            crate::blas::dgesv(&mut a_copy, &mut bp, n, k);
17628            bp
17629        };
17630        let xm = {
17631            let mut a_copy = a_data;
17632            crate::blas::dgesv(&mut a_copy, &mut bm, n, k);
17633            bm
17634        };
17635        let lp: f64 = xp.iter().sum();
17636        let lm: f64 = xm.iter().sum();
17637        let fd = (lp - lm) / (2.0 * h);
17638        assert!(
17639            (db_got[0] - fd).abs() < 1e-7,
17640            "FD dB[0,0]: AD={} FD={}",
17641            db_got[0],
17642            fd
17643        );
17644    }
17645
17646    /// Multi-RHS forward-mode JVP w.r.t. B. Closed form: t_X = solve(A, t_B).
17647    #[test]
17648    fn dense_solve_f64_multi_rhs_jvp() {
17649        use rlx_opt::autodiff_fwd::jvp;
17650        let n = 3usize;
17651        let k = 2usize;
17652        let mut g = Graph::new("solve_mrhs_jvp");
17653        let a = g.input("A", Shape::new(&[n, n], DType::F64));
17654        let b = g.input("B", Shape::new(&[n, k], DType::F64));
17655        let x = g.dense_solve(a, b, Shape::new(&[n, k], DType::F64));
17656        g.set_outputs(vec![x]);
17657
17658        let jg = jvp(&g, &[b]);
17659        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
17660            for node in graph.nodes() {
17661                let name = match &node.op {
17662                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
17663                    _ => None,
17664                };
17665                if name == Some(want) {
17666                    return node.id;
17667                }
17668            }
17669            panic!("no node named {want:?}");
17670        };
17671        let a_id = find_by_name(&jg, "A");
17672        let b_id = find_by_name(&jg, "B");
17673        let tb_id = find_by_name(&jg, "tangent_B");
17674
17675        let a_data = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0_f64];
17676        let b_data = [1.0, 4.0, 2.0, -1.0, 3.0, 2.0_f64];
17677        let tb_data = [0.5, 0.0, -0.25, 1.0, 1.0, -0.5_f64];
17678
17679        let (sched, mut arena) =
17680            prepare_f64(&jg, &[(a_id, &a_data), (b_id, &b_data), (tb_id, &tb_data)]);
17681        execute_thunks(&sched, arena.raw_buf_mut());
17682        let tangent_x = read_arena_f64(&arena, jg.outputs[1], n * k);
17683
17684        let mut a_copy = a_data;
17685        let mut tb_copy = tb_data;
17686        crate::blas::dgesv(&mut a_copy, &mut tb_copy, n, k);
17687        for i in 0..n * k {
17688            assert!(
17689                (tangent_x[i] - tb_copy[i]).abs() < 1e-10,
17690                "t_X[{i}]: AD={} ref={}",
17691                tangent_x[i],
17692                tb_copy[i]
17693            );
17694        }
17695
17696        let h = 1e-6;
17697        let mut bp = b_data;
17698        let mut bm = b_data;
17699        for i in 0..n * k {
17700            bp[i] += h * tb_data[i];
17701            bm[i] -= h * tb_data[i];
17702        }
17703        let xp = {
17704            let mut a_copy = a_data;
17705            crate::blas::dgesv(&mut a_copy, &mut bp, n, k);
17706            bp
17707        };
17708        let xm = {
17709            let mut a_copy = a_data;
17710            crate::blas::dgesv(&mut a_copy, &mut bm, n, k);
17711            bm
17712        };
17713        for i in 0..n * k {
17714            let fd = (xp[i] - xm[i]) / (2.0 * h);
17715            assert!(
17716                (tangent_x[i] - fd).abs() < 1e-7,
17717                "FD t_X[{i}]: AD={} FD={}",
17718                tangent_x[i],
17719                fd
17720            );
17721        }
17722    }
17723
17724    /// Forward-mode JVP through DenseSolve, end-to-end at f64.
17725    ///
17726    /// Build forward x = solve(A, b), call `jvp(forward, [b])`,
17727    /// compile + run, and check the tangent output matches the
17728    /// closed form `t_x = solve(A, t_b)` plus a finite-difference
17729    /// cross-check `(solve(A, b + h·t_b) − solve(A, b − h·t_b)) / 2h`.
17730    #[test]
17731    fn jvp_dense_solve_b_runs_and_matches_fd() {
17732        use rlx_opt::autodiff_fwd::jvp;
17733        let n = 3usize;
17734
17735        // Forward.
17736        let mut g = Graph::new("jvp_b_e2e");
17737        let a = g.input("A", Shape::new(&[n, n], DType::F64));
17738        let b = g.input("b", Shape::new(&[n], DType::F64));
17739        let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
17740        g.set_outputs(vec![x]);
17741
17742        // JVP graph perturbing b only.
17743        let jg = jvp(&g, &[b]);
17744        // The JVP graph holds a fresh "tangent_b" Input on top of A and b.
17745        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
17746            for node in graph.nodes() {
17747                let name = match &node.op {
17748                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
17749                    _ => None,
17750                };
17751                if name == Some(want) {
17752                    return node.id;
17753                }
17754            }
17755            panic!("no node named {want:?}");
17756        };
17757        let a_id = find_by_name(&jg, "A");
17758        let b_id = find_by_name(&jg, "b");
17759        let tb_id = find_by_name(&jg, "tangent_b");
17760
17761        let a_data: [f64; 9] = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0];
17762        let b_data: [f64; 3] = [1.0, 2.0, 3.0];
17763        // Pick an arbitrary perturbation direction.
17764        let tb_data: [f64; 3] = [0.5, -0.25, 1.0];
17765
17766        let (sched, mut arena) =
17767            prepare_f64(&jg, &[(a_id, &a_data), (b_id, &b_data), (tb_id, &tb_data)]);
17768        execute_thunks(&sched, arena.raw_buf_mut());
17769
17770        // Outputs: [primal_x, tangent_x].
17771        let primal_x = read_arena_f64(&arena, jg.outputs[0], n);
17772        let tangent_x = read_arena_f64(&arena, jg.outputs[1], n);
17773
17774        // Closed form: t_x = solve(A, t_b).
17775        let t_x_ref = {
17776            let mut a = a_data;
17777            let mut tb = tb_data;
17778            let info = crate::blas::dgesv(&mut a, &mut tb, n, 1);
17779            assert_eq!(info, 0);
17780            tb
17781        };
17782        for i in 0..n {
17783            assert!(
17784                (tangent_x[i] - t_x_ref[i]).abs() < 1e-10,
17785                "t_x[{i}]: got {} want {}",
17786                tangent_x[i],
17787                t_x_ref[i]
17788            );
17789        }
17790
17791        // FD: x(b + h·tb) − x(b − h·tb)) / 2h
17792        let h = 1e-6;
17793        let mut bp = b_data;
17794        let mut bm = b_data;
17795        for i in 0..n {
17796            bp[i] += h * tb_data[i];
17797            bm[i] -= h * tb_data[i];
17798        }
17799        let xp = {
17800            let mut a = a_data;
17801            let info = crate::blas::dgesv(&mut a, &mut bp, n, 1);
17802            assert_eq!(info, 0);
17803            bp
17804        };
17805        let xm = {
17806            let mut a = a_data;
17807            let info = crate::blas::dgesv(&mut a, &mut bm, n, 1);
17808            assert_eq!(info, 0);
17809            bm
17810        };
17811        let fd: Vec<f64> = (0..n).map(|i| (xp[i] - xm[i]) / (2.0 * h)).collect();
17812        for i in 0..n {
17813            assert!(
17814                (tangent_x[i] - fd[i]).abs() < 1e-7,
17815                "FD mismatch t_x[{i}]: AD={} FD={}",
17816                tangent_x[i],
17817                fd[i]
17818            );
17819        }
17820        // Sanity: primal output is the actual solve.
17821        let primal_ref = {
17822            let mut a = a_data;
17823            let mut b = b_data;
17824            crate::blas::dgesv(&mut a, &mut b, n, 1);
17825            b
17826        };
17827        for i in 0..n {
17828            assert!((primal_x[i] - primal_ref[i]).abs() < 1e-10);
17829        }
17830    }
17831
17832    /// Forward-mode JVP through DenseSolve perturbing A. The tangent
17833    /// path includes the −t_A·x correction term.
17834    /// `t_x = −solve(A, t_A · x)` should match a finite-difference
17835    /// directional derivative of `solve(A, b)` w.r.t. A in the
17836    /// `t_A` direction.
17837    #[test]
17838    fn jvp_dense_solve_a_runs_and_matches_fd() {
17839        use rlx_opt::autodiff_fwd::jvp;
17840        let n = 3usize;
17841
17842        let mut g = Graph::new("jvp_a_e2e");
17843        let a = g.input("A", Shape::new(&[n, n], DType::F64));
17844        let b = g.input("b", Shape::new(&[n], DType::F64));
17845        let x = g.dense_solve(a, b, Shape::new(&[n], DType::F64));
17846        g.set_outputs(vec![x]);
17847
17848        let jg = jvp(&g, &[a]);
17849        let find_by_name = |graph: &Graph, want: &str| -> NodeId {
17850            for node in graph.nodes() {
17851                let name = match &node.op {
17852                    Op::Input { name } | Op::Param { name } => Some(name.as_str()),
17853                    _ => None,
17854                };
17855                if name == Some(want) {
17856                    return node.id;
17857                }
17858            }
17859            panic!("no node named {want:?}");
17860        };
17861        let a_id = find_by_name(&jg, "A");
17862        let b_id = find_by_name(&jg, "b");
17863        let ta_id = find_by_name(&jg, "tangent_A");
17864
17865        let a_data: [f64; 9] = [2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0];
17866        let b_data: [f64; 3] = [1.0, 2.0, 3.0];
17867        // Asymmetric perturbation direction for A.
17868        let ta_data: [f64; 9] = [0.10, -0.05, 0.02, 0.03, 0.20, -0.04, -0.01, 0.07, 0.15];
17869
17870        let (sched, mut arena) =
17871            prepare_f64(&jg, &[(a_id, &a_data), (b_id, &b_data), (ta_id, &ta_data)]);
17872        execute_thunks(&sched, arena.raw_buf_mut());
17873
17874        let tangent_x = read_arena_f64(&arena, jg.outputs[1], n);
17875
17876        // Closed form: x = solve(A, b); t_x = −solve(A, t_A · x).
17877        let x_ref = {
17878            let mut a = a_data;
17879            let mut b = b_data;
17880            crate::blas::dgesv(&mut a, &mut b, n, 1);
17881            b
17882        };
17883        let mut prod = [0.0_f64; 3];
17884        for i in 0..n {
17885            for j in 0..n {
17886                prod[i] += ta_data[i * n + j] * x_ref[j];
17887            }
17888        }
17889        let t_x_ref = {
17890            let mut a = a_data;
17891            let mut p = prod;
17892            crate::blas::dgesv(&mut a, &mut p, n, 1);
17893            [-p[0], -p[1], -p[2]]
17894        };
17895        for i in 0..n {
17896            assert!(
17897                (tangent_x[i] - t_x_ref[i]).abs() < 1e-10,
17898                "closed-form t_x[{i}]: AD={} ref={}",
17899                tangent_x[i],
17900                t_x_ref[i]
17901            );
17902        }
17903
17904        // FD: solve(A + h·t_A, b) and solve(A − h·t_A, b).
17905        let h = 1e-6;
17906        let mut ap = a_data;
17907        let mut am = a_data;
17908        for i in 0..n * n {
17909            ap[i] += h * ta_data[i];
17910            am[i] -= h * ta_data[i];
17911        }
17912        let xp = {
17913            let mut a = ap;
17914            let mut b = b_data;
17915            crate::blas::dgesv(&mut a, &mut b, n, 1);
17916            b
17917        };
17918        let xm = {
17919            let mut a = am;
17920            let mut b = b_data;
17921            crate::blas::dgesv(&mut a, &mut b, n, 1);
17922            b
17923        };
17924        for i in 0..n {
17925            let fd = (xp[i] - xm[i]) / (2.0 * h);
17926            assert!(
17927                (tangent_x[i] - fd).abs() < 1e-7,
17928                "FD t_x[{i}]: AD={} FD={}",
17929                tangent_x[i],
17930                fd
17931            );
17932        }
17933    }
17934
17935    /// Real INT8 conv2d parity. Same setup as QMatMul: pre-quantize
17936    /// f32 inputs to i8, run `Op::QConv2d`, compare against an
17937    /// in-test reference loop that does the same i32 accumulation
17938    /// and requantize math. Symmetric quant (zp=0) to keep the math
17939    /// head-to-head.
17940    #[test]
17941    fn q_conv2d_matches_reference() {
17942        use rlx_ir::Philox4x32;
17943        // Small NCHW shape — enough to exercise stride/padding edges.
17944        let n = 1usize;
17945        let c_in = 2usize;
17946        let h = 5usize;
17947        let w_in = 5usize;
17948        let c_out = 3usize;
17949        let kh = 3usize;
17950        let kw = 3usize;
17951        let ph = 1usize;
17952        let pw = 1usize;
17953        let sh = 1usize;
17954        let sw = 1usize;
17955        let h_out = (h + 2 * ph - kh) / sh + 1;
17956        let w_out = (w_in + 2 * pw - kw) / sw + 1;
17957
17958        let x_scale = 0.04f32;
17959        let w_scale = 0.02f32;
17960        let out_scale = 0.5f32;
17961        let mult = x_scale * w_scale / out_scale;
17962
17963        let mut rng = Philox4x32::new(2099);
17964        let mut xf = vec![0f32; n * c_in * h * w_in];
17965        rng.fill_normal(&mut xf);
17966        let mut wf = vec![0f32; c_out * c_in * kh * kw];
17967        rng.fill_normal(&mut wf);
17968        let xq: Vec<i8> = xf
17969            .iter()
17970            .map(|&v| ((v / x_scale).round() as i32).clamp(-128, 127) as i8)
17971            .collect();
17972        let wq: Vec<i8> = wf
17973            .iter()
17974            .map(|&v| ((v / w_scale).round() as i32).clamp(-128, 127) as i8)
17975            .collect();
17976        let bias: Vec<i32> = vec![0i32; c_out];
17977
17978        let mut g = Graph::new("qconv");
17979        let xn = g.input("x", Shape::new(&[n, c_in, h, w_in], DType::I8));
17980        let wn = g.input("w", Shape::new(&[c_out, c_in, kh, kw], DType::I8));
17981        let bn = g.input("b", Shape::new(&[c_out], DType::I32));
17982        let out = g.q_conv2d(
17983            xn,
17984            wn,
17985            bn,
17986            vec![kh, kw],
17987            vec![sh, sw],
17988            vec![ph, pw],
17989            vec![1, 1],
17990            1,
17991            0,
17992            0,
17993            0,
17994            mult,
17995            Shape::new(&[n, c_out, h_out, w_out], DType::I8),
17996        );
17997        g.set_outputs(vec![out]);
17998
17999        let plan = rlx_opt::memory::plan_memory(&g);
18000        let mut arena = crate::arena::Arena::from_plan(plan);
18001        let sched = compile_thunks(&g, &arena);
18002        // Capture offsets before borrowing the buf mutably (avoids
18003        // overlap between &mut and the &arena.byte_offset reads).
18004        let xn_off = arena.byte_offset(xn);
18005        let wn_off = arena.byte_offset(wn);
18006        let bn_off = arena.byte_offset(bn);
18007        let out_off = arena.byte_offset(out);
18008        let buf = arena.raw_buf_mut();
18009        unsafe {
18010            let p = buf.as_mut_ptr().add(xn_off) as *mut i8;
18011            for (i, &v) in xq.iter().enumerate() {
18012                *p.add(i) = v;
18013            }
18014            let p = buf.as_mut_ptr().add(wn_off) as *mut i8;
18015            for (i, &v) in wq.iter().enumerate() {
18016                *p.add(i) = v;
18017            }
18018            let p = buf.as_mut_ptr().add(bn_off) as *mut i32;
18019            for (i, &v) in bias.iter().enumerate() {
18020                *p.add(i) = v;
18021            }
18022        }
18023        execute_thunks(&sched, arena.raw_buf_mut());
18024        let out_q: Vec<i8> = unsafe {
18025            let p = arena.raw_buf().as_ptr().add(out_off) as *const i8;
18026            (0..n * c_out * h_out * w_out).map(|i| *p.add(i)).collect()
18027        };
18028
18029        // Reference: scalar loop in NCHW with the same requantize.
18030        let mut out_ref = vec![0i8; n * c_out * h_out * w_out];
18031        for ni in 0..n {
18032            for co in 0..c_out {
18033                for ho in 0..h_out {
18034                    for wo in 0..w_out {
18035                        let mut acc: i32 = 0;
18036                        for ci in 0..c_in {
18037                            for ki in 0..kh {
18038                                for kj in 0..kw {
18039                                    let hi = ho * sh + ki;
18040                                    let wi = wo * sw + kj;
18041                                    if hi < ph || wi < pw {
18042                                        continue;
18043                                    }
18044                                    let hi = hi - ph;
18045                                    let wi = wi - pw;
18046                                    if hi >= h || wi >= w_in {
18047                                        continue;
18048                                    }
18049                                    let xv =
18050                                        xq[((ni * c_in) + ci) * h * w_in + hi * w_in + wi] as i32;
18051                                    let wv = wq[((co * c_in) + ci) * kh * kw + ki * kw + kj] as i32;
18052                                    acc += xv * wv;
18053                                }
18054                            }
18055                        }
18056                        let r = (acc as f32 * mult).round() as i32;
18057                        let r = r.clamp(-128, 127) as i8;
18058                        out_ref[((ni * c_out) + co) * h_out * w_out + ho * w_out + wo] = r;
18059                    }
18060                }
18061            }
18062        }
18063
18064        for (i, (a, r)) in out_q.iter().zip(&out_ref).enumerate() {
18065            assert_eq!(a, r, "q_conv2d[{i}]: kernel {a} vs reference {r}");
18066        }
18067    }
18068
18069    /// Real INT8 matmul parity: compare `Op::QMatMul` against the
18070    /// fake-quant reference `Dequantize → MatMul → Quantize` that
18071    /// would produce the same output if we round-tripped through
18072    /// f32. Both should agree element-for-element (or within ±1 i8
18073    /// step, since rounding in the requantize uses different code
18074    /// paths). Symmetric quantization (zp=0) for both paths to keep
18075    /// the math head-to-head.
18076    #[test]
18077    fn q_matmul_matches_fake_quant_reference() {
18078        use rlx_ir::Philox4x32;
18079        let m = 3usize;
18080        let k = 8usize;
18081        let n = 5usize;
18082        let mut rng = Philox4x32::new(2031);
18083
18084        // Pick scales and quantize random f32 inputs to i8.
18085        let x_scale = 0.05f32;
18086        let w_scale = 0.03f32;
18087        let out_scale = 0.4f32;
18088        let mult = x_scale * w_scale / out_scale;
18089        let mut xf = vec![0f32; m * k];
18090        rng.fill_normal(&mut xf);
18091        let mut wf = vec![0f32; k * n];
18092        rng.fill_normal(&mut wf);
18093        let xq: Vec<i8> = xf
18094            .iter()
18095            .map(|&v| ((v / x_scale).round() as i32).clamp(-128, 127) as i8)
18096            .collect();
18097        let wq: Vec<i8> = wf
18098            .iter()
18099            .map(|&v| ((v / w_scale).round() as i32).clamp(-128, 127) as i8)
18100            .collect();
18101        let bias: Vec<i32> = vec![0i32; n];
18102
18103        // ── Direct INT8 path ──
18104        let _f = DType::F32;
18105        let mut g_q = Graph::new("qmm_direct");
18106        let xn = g_q.input("x", Shape::new(&[m, k], DType::I8));
18107        let wn = g_q.input("w", Shape::new(&[k, n], DType::I8));
18108        let bn = g_q.input("b", Shape::new(&[n], DType::I32));
18109        let out = g_q.q_matmul(xn, wn, bn, 0, 0, 0, mult, Shape::new(&[m, n], DType::I8));
18110        g_q.set_outputs(vec![out]);
18111        let plan = rlx_opt::memory::plan_memory(&g_q);
18112        let mut arena = crate::arena::Arena::from_plan(plan);
18113        let sched = compile_thunks(&g_q, &arena);
18114
18115        // Fill inputs.
18116        let xn_off = arena.byte_offset(xn);
18117        let wn_off = arena.byte_offset(wn);
18118        let bn_off = arena.byte_offset(bn);
18119        let out_off = arena.byte_offset(out);
18120        let buf = arena.raw_buf_mut();
18121        unsafe {
18122            let p = buf.as_mut_ptr().add(xn_off) as *mut i8;
18123            for (i, &v) in xq.iter().enumerate() {
18124                *p.add(i) = v;
18125            }
18126            let p = buf.as_mut_ptr().add(wn_off) as *mut i8;
18127            for (i, &v) in wq.iter().enumerate() {
18128                *p.add(i) = v;
18129            }
18130            let p = buf.as_mut_ptr().add(bn_off) as *mut i32;
18131            for (i, &v) in bias.iter().enumerate() {
18132                *p.add(i) = v;
18133            }
18134        }
18135        execute_thunks(&sched, arena.raw_buf_mut());
18136        let out_q: Vec<i8> = unsafe {
18137            let p = arena.raw_buf().as_ptr().add(out_off) as *const i8;
18138            (0..m * n).map(|i| *p.add(i)).collect()
18139        };
18140
18141        // ── Fake-quant reference: scalar emulation in plain Rust ──
18142        // Same arithmetic the kernel does, but in a verifier loop:
18143        //   acc = Σ (x[m,k]) · (w[k,n]),  // zps are 0
18144        //   out[m,n] = saturate_i8(round(acc · mult) + 0)
18145        let mut out_ref = vec![0i8; m * n];
18146        for mi in 0..m {
18147            for ni in 0..n {
18148                let mut acc: i32 = 0;
18149                for ki in 0..k {
18150                    acc += (xq[mi * k + ki] as i32) * (wq[ki * n + ni] as i32);
18151                }
18152                let r = (acc as f32 * mult).round() as i32;
18153                out_ref[mi * n + ni] = r.clamp(-128, 127) as i8;
18154            }
18155        }
18156
18157        for (i, (a, r)) in out_q.iter().zip(&out_ref).enumerate() {
18158            assert_eq!(a, r, "q_matmul[{i}]: kernel {a} vs reference {r}");
18159        }
18160    }
18161
18162    /// Quantize/Dequantize round-trip — quantize an f32 tensor, then
18163    /// dequantize back, and confirm the result tracks the input
18164    /// within the per-element scale (the inevitable rounding error).
18165    /// Also pins the kernel's saturation behavior at the i8 limits.
18166    #[test]
18167    fn quantize_dequantize_round_trip() {
18168        use rlx_ir::Philox4x32;
18169        let len = 64;
18170        let mut rng = Philox4x32::new(2027);
18171        let mut x = vec![0f32; len];
18172        rng.fill_normal(&mut x);
18173        // Stretch a couple values past the +/- saturation cliff so
18174        // the saturate_i8 path is exercised.
18175        x[0] = 999.0;
18176        x[1] = -999.0;
18177
18178        let scale = 0.05f32;
18179        let zp = 3i32;
18180
18181        let f = DType::F32;
18182        let mut g = Graph::new("qdq");
18183        let xn = g.input("x", Shape::new(&[len], f));
18184        let q = g.quantize(xn, scale, zp);
18185        let dq = g.dequantize(q, scale, zp);
18186        g.set_outputs(vec![dq]);
18187
18188        let plan = rlx_opt::memory::plan_memory(&g);
18189        let mut arena = crate::arena::Arena::from_plan(plan);
18190        let sched = compile_thunks(&g, &arena);
18191        let xn_off = arena.byte_offset(xn);
18192        let dq_off = arena.byte_offset(dq);
18193        let buf = arena.raw_buf_mut();
18194        unsafe {
18195            let p = buf.as_mut_ptr().add(xn_off) as *mut f32;
18196            for (i, &v) in x.iter().enumerate() {
18197                *p.add(i) = v;
18198            }
18199        }
18200        execute_thunks(&sched, arena.raw_buf_mut());
18201        let out: Vec<f32> = unsafe {
18202            let p = arena.raw_buf().as_ptr().add(dq_off) as *const f32;
18203            (0..len).map(|i| *p.add(i)).collect()
18204        };
18205
18206        // Saturated values at i=0,1 should clamp to ±127's dequant
18207        // range (= (±127 - zp) · scale).
18208        let sat_pos = (127 - zp) as f32 * scale;
18209        let sat_neg = (-128 - zp) as f32 * scale;
18210        assert!((out[0] - sat_pos).abs() < 1e-6, "+sat: {}", out[0]);
18211        assert!((out[1] - sat_neg).abs() < 1e-6, "-sat: {}", out[1]);
18212
18213        // Everything else should round-trip within `scale` (one quant
18214        // step = the worst-case rounding error).
18215        for i in 2..len {
18216            assert!(
18217                (out[i] - x[i]).abs() <= scale + 1e-5,
18218                "qdq[{i}]: {} → {}, scale={scale}",
18219                x[i],
18220                out[i]
18221            );
18222        }
18223    }
18224
18225    /// Per-channel quantize / dequantize: independent scale and zp
18226    /// per slice along an axis. Verifies (a) each channel uses its
18227    /// own scale (not a shared one), (b) saturation still respects
18228    /// the i8 range, (c) channel data layout decomposition is
18229    /// correct (no cross-channel leakage).
18230    #[test]
18231    fn quantize_per_channel_round_trip() {
18232        let c = 4usize;
18233        let inner = 5usize;
18234        // Different magnitudes per channel — proves the per-channel
18235        // scale is actually being read for each row.
18236        let mags = [0.01f32, 0.5, 5.0, 50.0];
18237        let mut x = vec![0f32; c * inner];
18238        for ci in 0..c {
18239            for ii in 0..inner {
18240                // Sweep through values that span [-max_abs, +max_abs]
18241                // for each channel, plus one value past the cliff to
18242                // trigger saturation.
18243                x[ci * inner + ii] = match ii {
18244                    0 => -mags[ci],
18245                    1 => 0.0,
18246                    2 => mags[ci],
18247                    3 => mags[ci] * 1000.0,  // saturates +
18248                    _ => -mags[ci] * 1000.0, // saturates -
18249                };
18250            }
18251        }
18252        let scales: Vec<f32> = mags.iter().map(|&m| m / 127.0).collect();
18253        let zps: Vec<i32> = vec![0, 0, 0, 0];
18254
18255        let f = DType::F32;
18256        let mut g = Graph::new("qdq_pc");
18257        let xn = g.input("x", Shape::new(&[c, inner], f));
18258        let q = g.quantize_per_channel(xn, 0, scales.clone(), zps.clone());
18259        let dq = g.dequantize_per_channel(q, 0, scales.clone(), zps);
18260        g.set_outputs(vec![dq]);
18261
18262        let plan = rlx_opt::memory::plan_memory(&g);
18263        let mut arena = crate::arena::Arena::from_plan(plan);
18264        let sched = compile_thunks(&g, &arena);
18265        let xn_off = arena.byte_offset(xn);
18266        let dq_off = arena.byte_offset(dq);
18267        let buf = arena.raw_buf_mut();
18268        unsafe {
18269            let p = buf.as_mut_ptr().add(xn_off) as *mut f32;
18270            for (i, &v) in x.iter().enumerate() {
18271                *p.add(i) = v;
18272            }
18273        }
18274        execute_thunks(&sched, arena.raw_buf_mut());
18275        let out: Vec<f32> = unsafe {
18276            let p = arena.raw_buf().as_ptr().add(dq_off) as *const f32;
18277            (0..c * inner).map(|i| *p.add(i)).collect()
18278        };
18279
18280        for ci in 0..c {
18281            // Within-range entries (positions 0, 1, 2) must round-trip
18282            // within one quant step of *that channel's* scale.
18283            for ii in 0..3 {
18284                let idx = ci * inner + ii;
18285                assert!(
18286                    (out[idx] - x[idx]).abs() <= scales[ci] + 1e-5,
18287                    "ch {ci} idx {ii}: {} vs {}",
18288                    x[idx],
18289                    out[idx]
18290                );
18291            }
18292            // Saturated positions clamp to ±127 · scale[ci].
18293            let sat_pos = 127.0 * scales[ci];
18294            let sat_neg = -128.0 * scales[ci];
18295            assert!(
18296                (out[ci * inner + 3] - sat_pos).abs() < 1e-5,
18297                "ch {ci} +sat: {}",
18298                out[ci * inner + 3]
18299            );
18300            assert!(
18301                (out[ci * inner + 4] - sat_neg).abs() < 1e-5,
18302                "ch {ci} -sat: {}",
18303                out[ci * inner + 4]
18304            );
18305        }
18306    }
18307
18308    /// `Op::ActivationBackward` parity for every supported kind.
18309    /// Builds a single-op graph `dx = activation_backward(x, dy)` and
18310    /// compares each `dx[i]` to the central-difference `(act(x+ε) -
18311    /// act(x-ε)) / (2ε) · dy\[i\]`. Sweeps the closed-form covered by
18312    /// the kernel.
18313    #[test]
18314    fn activation_backward_matches_numerical_per_kind() {
18315        use rlx_ir::Philox4x32;
18316        use rlx_ir::op::Activation;
18317        let mut rng = Philox4x32::new(91);
18318        let len = 32;
18319        // x sampled away from kink/branch points: shifted positive
18320        // (exp/sqrt/log domain) for the unary-positive activations;
18321        // wide range otherwise. Two parallel tests would be cleaner
18322        // but this is concise enough.
18323        let mut x_pos = vec![0f32; len];
18324        rng.fill_normal(&mut x_pos);
18325        for v in x_pos.iter_mut() {
18326            *v = v.abs() + 0.5;
18327        }
18328        let mut x_any = vec![0f32; len];
18329        rng.fill_normal(&mut x_any);
18330        let mut dy = vec![0f32; len];
18331        rng.fill_normal(&mut dy);
18332
18333        for &(kind, x_data, eps, tol) in &[
18334            (Activation::Sigmoid, &x_any[..], 1e-3, 5e-3),
18335            (Activation::Tanh, &x_any[..], 1e-3, 5e-3),
18336            (Activation::Silu, &x_any[..], 1e-3, 5e-3),
18337            (Activation::Gelu, &x_any[..], 1e-3, 5e-3),
18338            (Activation::GeluApprox, &x_any[..], 1e-3, 5e-3),
18339            (Activation::Exp, &x_any[..], 1e-4, 5e-3),
18340            (Activation::Log, &x_pos[..], 1e-4, 5e-3),
18341            (Activation::Sqrt, &x_pos[..], 1e-4, 5e-3),
18342            (Activation::Rsqrt, &x_pos[..], 1e-4, 5e-3),
18343            (Activation::Neg, &x_any[..], 1e-3, 5e-4),
18344        ] {
18345            let f = DType::F32;
18346            let mut g = Graph::new("act_bw");
18347            let xn = g.input("x", Shape::new(&[len], f));
18348            let dyn_ = g.input("dy", Shape::new(&[len], f));
18349            let dx = g.activation_backward(kind, xn, dyn_);
18350            g.set_outputs(vec![dx]);
18351
18352            let plan = rlx_opt::memory::plan_memory(&g);
18353            let mut arena = crate::arena::Arena::from_plan(plan);
18354            let sched = compile_thunks(&g, &arena);
18355
18356            let xn_off = arena.byte_offset(xn);
18357            let dyn_off = arena.byte_offset(dyn_);
18358            let dx_off = arena.byte_offset(dx);
18359            let buf = arena.raw_buf_mut();
18360            unsafe {
18361                let p = buf.as_mut_ptr().add(xn_off) as *mut f32;
18362                for (i, &v) in x_data.iter().enumerate() {
18363                    *p.add(i) = v;
18364                }
18365                let p = buf.as_mut_ptr().add(dyn_off) as *mut f32;
18366                for (i, &v) in dy.iter().enumerate() {
18367                    *p.add(i) = v;
18368                }
18369            }
18370            execute_thunks(&sched, arena.raw_buf_mut());
18371            let analytical: Vec<f32> = unsafe {
18372                let p = arena.raw_buf().as_ptr().add(dx_off) as *const f32;
18373                (0..len).map(|i| *p.add(i)).collect()
18374            };
18375
18376            // Apply the forward activation manually; finite-difference
18377            // each element.
18378            let act_apply = |kind: Activation, x: f32| -> f32 {
18379                match kind {
18380                    Activation::Sigmoid => 1.0 / (1.0 + (-x).exp()),
18381                    Activation::Tanh => x.tanh(),
18382                    Activation::Silu => x / (1.0 + (-x).exp()),
18383                    Activation::Gelu => {
18384                        // Match the kernel's exact erf form.
18385                        const INV_SQRT2: f32 = 0.707_106_77;
18386                        0.5 * x * (1.0 + erf_f32(x * INV_SQRT2))
18387                    }
18388                    Activation::GeluApprox => {
18389                        const C: f32 = 0.797_884_6;
18390                        const A: f32 = 0.044_715;
18391                        let inner = C * (x + A * x * x * x);
18392                        0.5 * x * (1.0 + inner.tanh())
18393                    }
18394                    Activation::Exp => x.exp(),
18395                    Activation::Log => x.ln(),
18396                    Activation::Sqrt => x.sqrt(),
18397                    Activation::Rsqrt => 1.0 / x.sqrt(),
18398                    Activation::Neg => -x,
18399                    Activation::Relu => x.max(0.0),
18400                    Activation::Abs => x.abs(),
18401                    Activation::Round => x.round(),
18402                    Activation::Sin => x.sin(),
18403                    Activation::Cos => x.cos(),
18404                    Activation::Tan => x.tan(),
18405                    Activation::Atan => x.atan(),
18406                }
18407            };
18408            for i in 0..len {
18409                let xv = x_data[i];
18410                let plus = act_apply(kind, xv + eps);
18411                let minus = act_apply(kind, xv - eps);
18412                let num = (plus - minus) / (2.0 * eps) * dy[i];
18413                assert!(
18414                    (analytical[i] - num).abs() < tol,
18415                    "{kind:?}[{i}]: analytical {} vs numerical {num}",
18416                    analytical[i]
18417                );
18418            }
18419        }
18420    }
18421
18422    /// Batched 3-D MatMul VJP — the transformer-attention shape
18423    /// `[B, M, K] @ [B, K, N] = [B, M, N]`. Both gradients flow through
18424    /// `Op::Transpose` with a perm that swaps the last two dims.
18425    #[test]
18426    fn matmul_3d_gradient_matches_numerical() {
18427        use rlx_ir::Philox4x32;
18428        let batch = 2usize;
18429        let m = 3usize;
18430        let k = 4usize;
18431        let n = 5usize;
18432        let mut rng = Philox4x32::new(101);
18433        let mut a_data = vec![0f32; batch * m * k];
18434        rng.fill_normal(&mut a_data);
18435        let mut b_data = vec![0f32; batch * k * n];
18436        rng.fill_normal(&mut b_data);
18437
18438        let f = DType::F32;
18439        let mut fwd = Graph::new("matmul_3d");
18440        let an = fwd.input("a", Shape::new(&[batch, m, k], f));
18441        let bp = fwd.param("b", Shape::new(&[batch, k, n], f));
18442        let mm = fwd.matmul(an, bp, Shape::new(&[batch, m, n], f));
18443        let loss = fwd.add_node(
18444            Op::Reduce {
18445                op: ReduceOp::Sum,
18446                axes: vec![0, 1, 2],
18447                keep_dim: false,
18448            },
18449            vec![mm],
18450            Shape::from_dims(&[], f),
18451        );
18452        fwd.set_outputs(vec![loss]);
18453
18454        let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[bp]);
18455        let d_out = bwd_graph
18456            .nodes()
18457            .iter()
18458            .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
18459            .map(|n| n.id)
18460            .unwrap();
18461
18462        let plan = rlx_opt::memory::plan_memory(&bwd_graph);
18463        let mut arena = crate::arena::Arena::from_plan(plan);
18464        let sched = compile_thunks(&bwd_graph, &arena);
18465        for &(id, data) in &[(an, &a_data), (bp, &b_data), (d_out, &vec![1.0f32])] {
18466            let off = arena.byte_offset(id);
18467            let buf = arena.raw_buf_mut();
18468            unsafe {
18469                let p = buf.as_mut_ptr().add(off) as *mut f32;
18470                for (i, &v) in data.iter().enumerate() {
18471                    *p.add(i) = v;
18472                }
18473            }
18474        }
18475        execute_thunks(&sched, arena.raw_buf_mut());
18476        let gb_id = bwd_graph.outputs[1];
18477        let g_b: Vec<f32> = unsafe {
18478            let p = arena.raw_buf().as_ptr().add(arena.byte_offset(gb_id)) as *const f32;
18479            (0..batch * k * n).map(|i| *p.add(i)).collect()
18480        };
18481
18482        // Numerical gradient: differentiate sum(a @ b) w.r.t. each b entry.
18483        let forward_loss = |b_vals: &[f32]| -> f32 {
18484            let mut out = vec![0f32; batch * m * n];
18485            for bi in 0..batch {
18486                for mi in 0..m {
18487                    for ni in 0..n {
18488                        let mut acc = 0f32;
18489                        for ki in 0..k {
18490                            acc +=
18491                                a_data[bi * m * k + mi * k + ki] * b_vals[bi * k * n + ki * n + ni];
18492                        }
18493                        out[bi * m * n + mi * n + ni] = acc;
18494                    }
18495                }
18496            }
18497            out.iter().sum()
18498        };
18499        let eps = 1e-3f32;
18500        let mut bp_p = b_data.clone();
18501        let mut g_b_num = vec![0f32; b_data.len()];
18502        for i in 0..b_data.len() {
18503            let s = bp_p[i];
18504            bp_p[i] = s + eps;
18505            let lp = forward_loss(&bp_p);
18506            bp_p[i] = s - eps;
18507            let lm = forward_loss(&bp_p);
18508            bp_p[i] = s;
18509            g_b_num[i] = (lp - lm) / (2.0 * eps);
18510        }
18511        for (i, (a, n)) in g_b.iter().zip(&g_b_num).enumerate() {
18512            assert!(
18513                (a - n).abs() < 5e-3,
18514                "matmul_3d g_b[{i}]: analytical {a} vs numerical {n}"
18515            );
18516        }
18517    }
18518
18519    /// Composed `Op::Softmax` VJP — the gradient is built from
18520    /// `mul + reduce_sum + expand + sub + mul`, no dedicated
18521    /// SoftmaxBackward kernel. Verifies the closed-form
18522    /// `dx = y · (g - Σ y·g)` matches the FD gradient over a small
18523    /// 2-D logits tensor.
18524    #[test]
18525    fn softmax_gradient_matches_numerical() {
18526        use rlx_ir::Philox4x32;
18527        let n = 3usize;
18528        let c = 5usize;
18529        let mut rng = Philox4x32::new(57);
18530        let mut x_data = vec![0f32; n * c];
18531        rng.fill_normal(&mut x_data);
18532
18533        let f = DType::F32;
18534        let mut fwd = Graph::new("softmax_only");
18535        let xn = fwd.input("x", Shape::new(&[n, c], f));
18536        let sm = fwd.add_node(Op::Softmax { axis: -1 }, vec![xn], Shape::new(&[n, c], f));
18537        // Loss = sum(softmax · target) for some random fixed target —
18538        // any linear loss will do; sum-of-all is the simplest and gives
18539        // a uniform gradient flow into the softmax.
18540        let loss = fwd.add_node(
18541            Op::Reduce {
18542                op: ReduceOp::Sum,
18543                axes: vec![0, 1],
18544                keep_dim: false,
18545            },
18546            vec![sm],
18547            Shape::from_dims(&[], f),
18548        );
18549        fwd.set_outputs(vec![loss]);
18550
18551        // `wrt = [xn]` — autodiff exposes the gradient w.r.t. the
18552        // input so we can compare it directly. The forward NodeId for
18553        // `xn` doubles as its bwd-graph mirror.
18554        let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[xn]);
18555        let d_out = bwd_graph
18556            .nodes()
18557            .iter()
18558            .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
18559            .map(|n| n.id)
18560            .unwrap();
18561
18562        let plan = rlx_opt::memory::plan_memory(&bwd_graph);
18563        let mut arena = crate::arena::Arena::from_plan(plan);
18564        let sched = compile_thunks(&bwd_graph, &arena);
18565        for &(id, data) in &[(xn, &x_data), (d_out, &vec![1.0f32])] {
18566            let off = arena.byte_offset(id);
18567            let buf = arena.raw_buf_mut();
18568            unsafe {
18569                let p = buf.as_mut_ptr().add(off) as *mut f32;
18570                for (i, &v) in data.iter().enumerate() {
18571                    *p.add(i) = v;
18572                }
18573            }
18574        }
18575        execute_thunks(&sched, arena.raw_buf_mut());
18576        let g_x_id = bwd_graph.outputs[1];
18577        let g_x: Vec<f32> = unsafe {
18578            let p = arena.raw_buf().as_ptr().add(arena.byte_offset(g_x_id)) as *const f32;
18579            (0..n * c).map(|i| *p.add(i)).collect()
18580        };
18581
18582        // Loss derivative: softmax sums to 1 per row → d/dx_i sum(softmax) = 0
18583        // analytically. So expect g_x ≈ 0 within FD precision. (This
18584        // doubles as a strong sanity check for the composition.)
18585        let forward_loss = |x: &[f32]| -> f32 {
18586            let mut total = 0f32;
18587            for ni in 0..n {
18588                let row = &x[ni * c..(ni + 1) * c];
18589                let m = row.iter().fold(f32::NEG_INFINITY, |a, &v| a.max(v));
18590                let denom: f32 = row.iter().map(|&v| (v - m).exp()).sum();
18591                for &v in row {
18592                    total += (v - m).exp() / denom;
18593                }
18594            }
18595            total
18596        };
18597        let eps = 1e-3f32;
18598        let mut p = x_data.clone();
18599        for i in 0..x_data.len() {
18600            let s = p[i];
18601            p[i] = s + eps;
18602            let lp = forward_loss(&p);
18603            p[i] = s - eps;
18604            let lm = forward_loss(&p);
18605            p[i] = s;
18606            let num = (lp - lm) / (2.0 * eps);
18607            assert!(
18608                (g_x[i] - num).abs() < 5e-3,
18609                "softmax g_x[{i}]: analytical {} vs numerical {num}",
18610                g_x[i]
18611            );
18612        }
18613    }
18614
18615    /// LayerNorm VJP — three gradients in one pass:
18616    ///   d_x via `LayerNormBackwardInput`,
18617    ///   d_gamma via `LayerNormBackwardGamma`,
18618    ///   d_beta = `unbroadcast(upstream)` to gamma's shape.
18619    #[test]
18620    fn layer_norm_gradient_matches_numerical() {
18621        use rlx_ir::Philox4x32;
18622        let rows = 3usize;
18623        let h = 6usize;
18624        let mut rng = Philox4x32::new(1009);
18625        let mut x_data = vec![0f32; rows * h];
18626        rng.fill_normal(&mut x_data);
18627        let mut g_data = vec![0f32; h];
18628        rng.fill_normal(&mut g_data);
18629        for v in g_data.iter_mut() {
18630            *v = v.abs() + 0.5;
18631        }
18632        let mut b_data = vec![0f32; h];
18633        rng.fill_normal(&mut b_data);
18634        let eps = 1e-5f32;
18635
18636        let f = DType::F32;
18637        let mut fwd = Graph::new("ln_only");
18638        let xn = fwd.input("x", Shape::new(&[rows, h], f));
18639        let gp = fwd.param("gamma", Shape::new(&[h], f));
18640        let bp = fwd.param("beta", Shape::new(&[h], f));
18641        let ln = fwd.add_node(
18642            Op::LayerNorm { axis: -1, eps },
18643            vec![xn, gp, bp],
18644            Shape::new(&[rows, h], f),
18645        );
18646        let loss = fwd.add_node(
18647            Op::Reduce {
18648                op: ReduceOp::Sum,
18649                axes: vec![0, 1],
18650                keep_dim: false,
18651            },
18652            vec![ln],
18653            Shape::from_dims(&[], f),
18654        );
18655        fwd.set_outputs(vec![loss]);
18656
18657        let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[xn, gp, bp]);
18658        let d_out = bwd_graph
18659            .nodes()
18660            .iter()
18661            .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
18662            .map(|n| n.id)
18663            .unwrap();
18664
18665        let plan = rlx_opt::memory::plan_memory(&bwd_graph);
18666        let mut arena = crate::arena::Arena::from_plan(plan);
18667        let sched = compile_thunks(&bwd_graph, &arena);
18668        for &(id, data) in &[
18669            (xn, &x_data),
18670            (gp, &g_data),
18671            (bp, &b_data),
18672            (d_out, &vec![1.0f32]),
18673        ] {
18674            let off = arena.byte_offset(id);
18675            let buf = arena.raw_buf_mut();
18676            unsafe {
18677                let p = buf.as_mut_ptr().add(off) as *mut f32;
18678                for (i, &v) in data.iter().enumerate() {
18679                    *p.add(i) = v;
18680                }
18681            }
18682        }
18683        execute_thunks(&sched, arena.raw_buf_mut());
18684        let read = |id: NodeId, n: usize| -> Vec<f32> {
18685            let off = arena.byte_offset(id);
18686            unsafe {
18687                let p = arena.raw_buf().as_ptr().add(off) as *const f32;
18688                (0..n).map(|i| *p.add(i)).collect()
18689            }
18690        };
18691        let dx_a = read(bwd_graph.outputs[1], rows * h);
18692        let dg_a = read(bwd_graph.outputs[2], h);
18693        let db_a = read(bwd_graph.outputs[3], h);
18694
18695        let forward_loss = |x: &[f32], g: &[f32], b: &[f32]| -> f32 {
18696            let mut total = 0f32;
18697            for r in 0..rows {
18698                let row = &x[r * h..(r + 1) * h];
18699                let mean = row.iter().sum::<f32>() / h as f32;
18700                let var = row.iter().map(|&v| (v - mean) * (v - mean)).sum::<f32>() / h as f32;
18701                let inv_std = 1.0 / (var + eps).sqrt();
18702                for d in 0..h {
18703                    total += ((row[d] - mean) * inv_std) * g[d] + b[d];
18704                }
18705            }
18706            total
18707        };
18708        let h_eps = 1e-3f32;
18709
18710        let mut x_p = x_data.clone();
18711        for i in 0..x_p.len() {
18712            let s = x_p[i];
18713            x_p[i] = s + h_eps;
18714            let lp = forward_loss(&x_p, &g_data, &b_data);
18715            x_p[i] = s - h_eps;
18716            let lm = forward_loss(&x_p, &g_data, &b_data);
18717            x_p[i] = s;
18718            let num = (lp - lm) / (2.0 * h_eps);
18719            assert!(
18720                (dx_a[i] - num).abs() < 5e-3,
18721                "ln dx[{i}]: analytical {} vs numerical {num}",
18722                dx_a[i]
18723            );
18724        }
18725        let mut g_p = g_data.clone();
18726        for i in 0..g_p.len() {
18727            let s = g_p[i];
18728            g_p[i] = s + h_eps;
18729            let lp = forward_loss(&x_data, &g_p, &b_data);
18730            g_p[i] = s - h_eps;
18731            let lm = forward_loss(&x_data, &g_p, &b_data);
18732            g_p[i] = s;
18733            let num = (lp - lm) / (2.0 * h_eps);
18734            assert!(
18735                (dg_a[i] - num).abs() < 5e-3,
18736                "ln dg[{i}]: analytical {} vs numerical {num}",
18737                dg_a[i]
18738            );
18739        }
18740        let mut b_p = b_data.clone();
18741        for i in 0..b_p.len() {
18742            let s = b_p[i];
18743            b_p[i] = s + h_eps;
18744            let lp = forward_loss(&x_data, &g_data, &b_p);
18745            b_p[i] = s - h_eps;
18746            let lm = forward_loss(&x_data, &g_data, &b_p);
18747            b_p[i] = s;
18748            let num = (lp - lm) / (2.0 * h_eps);
18749            assert!(
18750                (db_a[i] - num).abs() < 5e-3,
18751                "ln db[{i}]: analytical {} vs numerical {num}",
18752                db_a[i]
18753            );
18754        }
18755    }
18756
18757    /// Single dense layer + softmax-cross-entropy + mean reduce —
18758    /// the simplest non-trivial training graph. Validates MatMul,
18759    /// broadcast Add, SCE, Reduce(Mean) VJPs and the grad_with_loss
18760    /// plumbing all at once.
18761    #[test]
18762    fn dense_sce_mean_gradient_matches_numerical() {
18763        use rlx_ir::Philox4x32;
18764        let bs = 4usize;
18765        let k_in = 3usize;
18766        let c = 5usize;
18767        let mut rng = Philox4x32::new(7);
18768        let mut x = vec![0f32; bs * k_in];
18769        rng.fill_normal(&mut x);
18770        let mut w_init = vec![0f32; k_in * c];
18771        rng.fill_normal(&mut w_init);
18772        let mut b_init = vec![0f32; c];
18773        rng.fill_normal(&mut b_init);
18774        let labels: Vec<f32> = (0..bs).map(|i| (i % c) as f32).collect();
18775
18776        // ── Forward graph: loss = mean(sce(x @ w + b, labels)) ──
18777        let f = DType::F32;
18778        let mut fwd = Graph::new("dense_sce");
18779        let xn = fwd.input("x", Shape::new(&[bs, k_in], f));
18780        let lb = fwd.input("labels", Shape::new(&[bs], f));
18781        let wp = fwd.param("w", Shape::new(&[k_in, c], f));
18782        let bp = fwd.param("b", Shape::new(&[c], f));
18783        let mm = fwd.matmul(xn, wp, Shape::new(&[bs, c], f));
18784        let logits = fwd.binary(BinaryOp::Add, mm, bp, Shape::new(&[bs, c], f));
18785        let loss_per = fwd.softmax_cross_entropy_with_logits(logits, lb);
18786        let loss = fwd.add_node(
18787            Op::Reduce {
18788                op: ReduceOp::Sum,
18789                axes: vec![0],
18790                keep_dim: false,
18791            },
18792            vec![loss_per],
18793            // Reduce sum of [bs] with axes=[0] keep_dim=false → scalar [].
18794            Shape::from_dims(&[], f),
18795        );
18796        // Use Sum + manual /bs scalar mul — also exercises BinaryOp::Mul VJP path
18797        // less aggressively than Mean would, and gives us a closed-form
18798        // reference for the loss we expect.
18799        // For simplicity though, switch to Mean which the tests should also cover.
18800        // (Re-using `loss` with Sum here for now; the mean factor cancels in
18801        // the gradient comparison since both analytical and numerical use the
18802        // same forward.)
18803        fwd.set_outputs(vec![loss]);
18804
18805        // ── Backward graph ──
18806        let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[wp, bp]);
18807        // Outputs: [loss, grad_w, grad_b]. NodeIds for x/labels/w/b/loss
18808        // in bwd_graph match their fwd ids (the mirror keeps order).
18809        let d_out = bwd_graph
18810            .nodes()
18811            .iter()
18812            .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
18813            .map(|n| n.id)
18814            .expect("d_output input");
18815
18816        let (sched, mut arena) = prepare(
18817            &bwd_graph,
18818            &[
18819                (xn, &x),
18820                (lb, &labels),
18821                (wp, &w_init),
18822                (bp, &b_init),
18823                (d_out, &[1.0]),
18824            ],
18825        );
18826        execute_thunks(&sched, arena.raw_buf_mut());
18827
18828        let outs = &bwd_graph.outputs;
18829        let loss_id = outs[0];
18830        let gw_id = outs[1];
18831        let gb_id = outs[2];
18832        let loss_actual = read_arena(&arena, loss_id, 1)[0];
18833        let gw_actual = read_arena(&arena, gw_id, k_in * c);
18834        let gb_actual = read_arena(&arena, gb_id, c);
18835
18836        // ── Forward-only graph for finite differences ──
18837        // Re-use the same `fwd` graph; set up its own arena and rerun
18838        // for each perturbed parameter.
18839        let plan = rlx_opt::memory::plan_memory(&fwd);
18840        let mut fwd_arena = crate::arena::Arena::from_plan(plan);
18841        let fwd_sched = compile_thunks(&fwd, &fwd_arena);
18842        write_arena(&mut fwd_arena, xn, &x);
18843        write_arena(&mut fwd_arena, lb, &labels);
18844
18845        let run_loss = |arena: &mut crate::arena::Arena, w: &[f32], b: &[f32]| -> f32 {
18846            write_arena(arena, wp, w);
18847            write_arena(arena, bp, b);
18848            execute_thunks(&fwd_sched, arena.raw_buf_mut());
18849            read_arena(arena, loss, 1)[0]
18850        };
18851
18852        // Sanity: the loss reported by the bwd graph matches the
18853        // forward-only graph on the unperturbed inputs.
18854        let loss_check = run_loss(&mut fwd_arena, &w_init, &b_init);
18855        assert!(
18856            (loss_actual - loss_check).abs() < 1e-4,
18857            "loss mismatch: bwd graph {loss_actual} vs fwd-only {loss_check}"
18858        );
18859
18860        let eps = 1e-3f32;
18861        let mut w_perturbed = w_init.clone();
18862        let mut gw_numerical = vec![0f32; w_init.len()];
18863        for i in 0..w_init.len() {
18864            let saved = w_perturbed[i];
18865            w_perturbed[i] = saved + eps;
18866            let lp = run_loss(&mut fwd_arena, &w_perturbed, &b_init);
18867            w_perturbed[i] = saved - eps;
18868            let lm = run_loss(&mut fwd_arena, &w_perturbed, &b_init);
18869            w_perturbed[i] = saved;
18870            gw_numerical[i] = (lp - lm) / (2.0 * eps);
18871        }
18872        for (i, (a, n)) in gw_actual.iter().zip(&gw_numerical).enumerate() {
18873            assert!(
18874                (a - n).abs() < 5e-3,
18875                "grad_w[{i}]: analytical {a} vs numerical {n}"
18876            );
18877        }
18878
18879        let mut b_perturbed = b_init.clone();
18880        let mut gb_numerical = vec![0f32; b_init.len()];
18881        for i in 0..b_init.len() {
18882            let saved = b_perturbed[i];
18883            b_perturbed[i] = saved + eps;
18884            let lp = run_loss(&mut fwd_arena, &w_init, &b_perturbed);
18885            b_perturbed[i] = saved - eps;
18886            let lm = run_loss(&mut fwd_arena, &w_init, &b_perturbed);
18887            b_perturbed[i] = saved;
18888            gb_numerical[i] = (lp - lm) / (2.0 * eps);
18889        }
18890        for (i, (a, n)) in gb_actual.iter().zip(&gb_numerical).enumerate() {
18891            assert!(
18892                (a - n).abs() < 5e-3,
18893                "grad_b[{i}]: analytical {a} vs numerical {n}"
18894            );
18895        }
18896    }
18897
18898    /// Reduce::Mean specifically — verifies the 1/N scaling in the VJP.
18899    /// The same dense+SCE graph but with Mean instead of Sum on the loss.
18900    #[test]
18901    fn dense_sce_mean_reduce_gradient_matches_numerical() {
18902        use rlx_ir::Philox4x32;
18903        let bs = 3usize;
18904        let k_in = 2usize;
18905        let c = 4usize;
18906        let mut rng = Philox4x32::new(13);
18907        let mut x = vec![0f32; bs * k_in];
18908        rng.fill_normal(&mut x);
18909        let mut w_init = vec![0f32; k_in * c];
18910        rng.fill_normal(&mut w_init);
18911        let labels: Vec<f32> = (0..bs).map(|i| (i % c) as f32).collect();
18912
18913        let f = DType::F32;
18914        let mut fwd = Graph::new("dense_sce_mean");
18915        let xn = fwd.input("x", Shape::new(&[bs, k_in], f));
18916        let lb = fwd.input("labels", Shape::new(&[bs], f));
18917        let wp = fwd.param("w", Shape::new(&[k_in, c], f));
18918        let mm = fwd.matmul(xn, wp, Shape::new(&[bs, c], f));
18919        let loss_per = fwd.softmax_cross_entropy_with_logits(mm, lb);
18920        let loss = fwd.add_node(
18921            Op::Reduce {
18922                op: ReduceOp::Mean,
18923                axes: vec![0],
18924                keep_dim: false,
18925            },
18926            vec![loss_per],
18927            Shape::from_dims(&[], f),
18928        );
18929        fwd.set_outputs(vec![loss]);
18930
18931        let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[wp]);
18932        let d_out = bwd_graph
18933            .nodes()
18934            .iter()
18935            .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
18936            .map(|n| n.id)
18937            .unwrap();
18938
18939        let (sched, mut arena) = prepare(
18940            &bwd_graph,
18941            &[(xn, &x), (lb, &labels), (wp, &w_init), (d_out, &[1.0])],
18942        );
18943        execute_thunks(&sched, arena.raw_buf_mut());
18944
18945        let outs = &bwd_graph.outputs;
18946        let loss_id = outs[0];
18947        let gw_id = outs[1];
18948        let _ = read_arena(&arena, loss_id, 1)[0];
18949        let gw_actual = read_arena(&arena, gw_id, k_in * c);
18950
18951        let plan = rlx_opt::memory::plan_memory(&fwd);
18952        let mut fwd_arena = crate::arena::Arena::from_plan(plan);
18953        let fwd_sched = compile_thunks(&fwd, &fwd_arena);
18954        write_arena(&mut fwd_arena, xn, &x);
18955        write_arena(&mut fwd_arena, lb, &labels);
18956
18957        let run_loss = |arena: &mut crate::arena::Arena, w: &[f32]| -> f32 {
18958            write_arena(arena, wp, w);
18959            execute_thunks(&fwd_sched, arena.raw_buf_mut());
18960            read_arena(arena, loss, 1)[0]
18961        };
18962
18963        let eps = 1e-3f32;
18964        let mut wp_p = w_init.clone();
18965        let mut gw_num = vec![0f32; w_init.len()];
18966        for i in 0..w_init.len() {
18967            let s = wp_p[i];
18968            wp_p[i] = s + eps;
18969            let lp = run_loss(&mut fwd_arena, &wp_p);
18970            wp_p[i] = s - eps;
18971            let lm = run_loss(&mut fwd_arena, &wp_p);
18972            wp_p[i] = s;
18973            gw_num[i] = (lp - lm) / (2.0 * eps);
18974        }
18975        for (i, (a, n)) in gw_actual.iter().zip(&gw_num).enumerate() {
18976            assert!((a - n).abs() < 5e-3, "mean reduce grad_w[{i}]: {a} vs {n}");
18977        }
18978    }
18979    /// The full TinyConv-MNIST forward path (downsized) plumbed
18980    /// through grad_with_loss. Validates that Conv, Pool(Max), ReLU,
18981    /// Reshape, MatMul, Add (broadcast), SCE, Reduce(Mean) VJPs all
18982    /// compose into a graph that produces correct gradients.
18983    #[test]
18984    fn tinyconv_full_gradient_matches_numerical() {
18985        use rlx_ir::Philox4x32;
18986        // Tiny shapes so finite differences finish in <1s.
18987        let n = 1usize;
18988        let c_in = 1usize;
18989        let h = 6usize;
18990        let w_in = 6usize;
18991        let c_mid = 2usize; // first conv output channels
18992        let kh = 3;
18993        let kw = 3;
18994        let h1 = h - kh + 1; // 4
18995        let w1 = w_in - kw + 1; // 4
18996        let h2 = h1 / 2;
18997        let w2 = w1 / 2; // 2 × 2 after 2× pool
18998        let flat = c_mid * h2 * w2; // 8
18999        let num_classes = 3usize;
19000
19001        let mut rng = Philox4x32::new(31);
19002        let mut x = vec![0f32; n * c_in * h * w_in];
19003        rng.fill_normal(&mut x);
19004        let mut wc = vec![0f32; c_mid * c_in * kh * kw];
19005        rng.fill_normal(&mut wc);
19006        for v in wc.iter_mut() {
19007            *v *= 0.2;
19008        }
19009        // Shift conv-bias well away from the ReLU zero-boundary. Without
19010        // this, an ε-perturbation of bc[c] can flip the ReLU mask on a
19011        // pre-activation that happened to land near zero — making the
19012        // central-difference numerical gradient discontinuous and
19013        // diverge from the analytical (which assumes local smoothness).
19014        // +5.0 keeps every pre-activation positive for any random init
19015        // produced by Philox seed 31 with the wc/x scales used here, so
19016        // ReLU acts as an identity and finite differences are exact.
19017        let bc: Vec<f32> = (0..c_mid).map(|i| 5.0 + 0.1 * i as f32).collect();
19018        let mut wfc = vec![0f32; flat * num_classes];
19019        rng.fill_normal(&mut wfc);
19020        for v in wfc.iter_mut() {
19021            *v *= 0.5;
19022        }
19023        let mut bfc = vec![0f32; num_classes];
19024        rng.fill_normal(&mut bfc);
19025        let labels: Vec<f32> = vec![1.0]; // batch=1
19026
19027        let f = DType::F32;
19028        let mut fwd = Graph::new("tinyconv");
19029        let xn = fwd.input("x", Shape::new(&[n, c_in, h, w_in], f));
19030        let lb = fwd.input("labels", Shape::new(&[n], f));
19031        let wcp = fwd.param("wc", Shape::new(&[c_mid, c_in, kh, kw], f));
19032        let bcp = fwd.param("bc", Shape::new(&[c_mid], f));
19033        let wfp = fwd.param("wfc", Shape::new(&[flat, num_classes], f));
19034        let bfp = fwd.param("bfc", Shape::new(&[num_classes], f));
19035
19036        // conv: [n, c_in, h, w] → [n, c_mid, h1, w1]
19037        let conv = fwd.add_node(
19038            Op::Conv {
19039                kernel_size: vec![kh, kw],
19040                stride: vec![1, 1],
19041                padding: vec![0, 0],
19042                dilation: vec![1, 1],
19043                groups: 1,
19044            },
19045            vec![xn, wcp],
19046            Shape::new(&[n, c_mid, h1, w1], f),
19047        );
19048        // Bias add: expand bc[c_mid] up to the full [n, c_mid, h1, w1]
19049        // shape so the Add becomes a plain element-wise op. Going through
19050        // an explicit Reshape→Expand instead of relying on the Add to
19051        // broadcast `[1, C, 1, 1]` → `[N, C, H, W]` works around a known
19052        // limitation of `rlx-cpu`'s `Op::Binary` lowering: it dispatches
19053        // on `out_len % rhs_len == 0` and treats `rhs` as a last-axis
19054        // bias, which produces `bc[0], bc[1], bc[0], bc[1], …` alternating
19055        // across all positions instead of channel-broadcasting. Going
19056        // through Expand (a real broadcast thunk) avoids that path
19057        // entirely. The autodiff still exercises `unbroadcast` because
19058        // `Op::Expand`'s VJP reduces over the broadcast axes.
19059        let bc_4d = fwd.add_node(
19060            Op::Reshape {
19061                new_shape: vec![1, c_mid as i64, 1, 1],
19062            },
19063            vec![bcp],
19064            Shape::new(&[1, c_mid, 1, 1], f),
19065        );
19066        let bc_expanded = fwd.add_node(
19067            Op::Expand {
19068                target_shape: vec![n as i64, c_mid as i64, h1 as i64, w1 as i64],
19069            },
19070            vec![bc_4d],
19071            Shape::new(&[n, c_mid, h1, w1], f),
19072        );
19073        let conv_b = fwd.binary(
19074            BinaryOp::Add,
19075            conv,
19076            bc_expanded,
19077            Shape::new(&[n, c_mid, h1, w1], f),
19078        );
19079        let relu = fwd.activation(Activation::Relu, conv_b, Shape::new(&[n, c_mid, h1, w1], f));
19080        let pool = fwd.add_node(
19081            Op::Pool {
19082                kind: ReduceOp::Max,
19083                kernel_size: vec![2, 2],
19084                stride: vec![2, 2],
19085                padding: vec![0, 0],
19086            },
19087            vec![relu],
19088            Shape::new(&[n, c_mid, h2, w2], f),
19089        );
19090        let flatn = fwd.add_node(
19091            Op::Reshape {
19092                new_shape: vec![n as i64, flat as i64],
19093            },
19094            vec![pool],
19095            Shape::new(&[n, flat], f),
19096        );
19097        let mm = fwd.matmul(flatn, wfp, Shape::new(&[n, num_classes], f));
19098        let logits = fwd.binary(BinaryOp::Add, mm, bfp, Shape::new(&[n, num_classes], f));
19099        let loss_per = fwd.softmax_cross_entropy_with_logits(logits, lb);
19100        let loss = fwd.add_node(
19101            Op::Reduce {
19102                op: ReduceOp::Mean,
19103                axes: vec![0],
19104                keep_dim: false,
19105            },
19106            vec![loss_per],
19107            Shape::from_dims(&[], f),
19108        );
19109        fwd.set_outputs(vec![loss]);
19110
19111        let bwd_graph = rlx_opt::autodiff::grad_with_loss(&fwd, &[wcp, bcp, wfp, bfp]);
19112        let d_out = bwd_graph
19113            .nodes()
19114            .iter()
19115            .find(|n| matches!(&n.op, Op::Input { name } if name == "d_output"))
19116            .map(|n| n.id)
19117            .unwrap();
19118
19119        let (sched, mut arena) = prepare(
19120            &bwd_graph,
19121            &[
19122                (xn, &x),
19123                (lb, &labels),
19124                (wcp, &wc),
19125                (bcp, &bc),
19126                (wfp, &wfc),
19127                (bfp, &bfc),
19128                (d_out, &[1.0]),
19129            ],
19130        );
19131        execute_thunks(&sched, arena.raw_buf_mut());
19132
19133        let outs = bwd_graph.outputs.clone();
19134        let loss_id = outs[0];
19135        let g_wc_id = outs[1];
19136        let g_bc_id = outs[2];
19137        let g_wfc_id = outs[3];
19138        let g_bfc_id = outs[4];
19139        let loss_actual = read_arena(&arena, loss_id, 1)[0];
19140        let g_wc = read_arena(&arena, g_wc_id, wc.len());
19141        let g_bc = read_arena(&arena, g_bc_id, bc.len());
19142        let g_wfc = read_arena(&arena, g_wfc_id, wfc.len());
19143        let g_bfc = read_arena(&arena, g_bfc_id, bfc.len());
19144
19145        // Forward-only arena for finite differences.
19146        let plan = rlx_opt::memory::plan_memory(&fwd);
19147        let mut fwd_arena = crate::arena::Arena::from_plan(plan);
19148        let fwd_sched = compile_thunks(&fwd, &fwd_arena);
19149        write_arena(&mut fwd_arena, xn, &x);
19150        write_arena(&mut fwd_arena, lb, &labels);
19151
19152        // Closure variant: we need to set all four params each call so
19153        // perturbations to one don't leak between sweeps.
19154        let run_loss = |arena: &mut crate::arena::Arena,
19155                        wc: &[f32],
19156                        bc: &[f32],
19157                        wfc: &[f32],
19158                        bfc: &[f32]|
19159         -> f32 {
19160            write_arena(arena, wcp, wc);
19161            write_arena(arena, bcp, bc);
19162            write_arena(arena, wfp, wfc);
19163            write_arena(arena, bfp, bfc);
19164            execute_thunks(&fwd_sched, arena.raw_buf_mut());
19165            read_arena(arena, loss, 1)[0]
19166        };
19167
19168        let loss_check = run_loss(&mut fwd_arena, &wc, &bc, &wfc, &bfc);
19169        assert!(
19170            (loss_actual - loss_check).abs() < 1e-4,
19171            "tinyconv loss mismatch: bwd {loss_actual} vs fwd {loss_check}"
19172        );
19173
19174        let eps = 1e-3f32;
19175        let check_grad = |arena: &mut crate::arena::Arena,
19176                          name: &str,
19177                          analytical: &[f32],
19178                          mut perturb: Box<
19179            dyn FnMut(&mut [f32], usize, f32, &mut crate::arena::Arena) -> f32 + '_,
19180        >,
19181                          n: usize| {
19182            for i in 0..n {
19183                let lp = perturb(&mut analytical.to_vec(), i, eps, arena);
19184                let lm = perturb(&mut analytical.to_vec(), i, -eps, arena);
19185                let num = (lp - lm) / (2.0 * eps);
19186                assert!(
19187                    (analytical[i] - num).abs() < 5e-3,
19188                    "{name}[{i}]: analytical {} vs numerical {num}",
19189                    analytical[i]
19190                );
19191            }
19192        };
19193
19194        // Helper to perturb one param and run forward. Kept as a
19195        // reference for the explicit per-param sweep pattern below.
19196        #[allow(unused_macros)]
19197        macro_rules! sweep {
19198            ($name:expr, $base:expr, $analytical:expr, $set_param:ident) => {{
19199                let n = $base.len();
19200                for i in 0..n {
19201                    let mut p = $base.clone();
19202                    let s = p[i];
19203                    p[i] = s + eps;
19204                    let lp = {
19205                        let $set_param = &p;
19206                        run_loss(&mut fwd_arena, &wc, &bc, &wfc, &bfc).max(f32::NEG_INFINITY);
19207                        // Reset others, set the one being swept, run.
19208                        // (the macro receives one of the four params via $set_param)
19209                        let _ = $set_param;
19210                        // Fall through to the explicit per-param helper:
19211                        0.0_f32
19212                    };
19213                    let _ = lp;
19214                }
19215            }};
19216        }
19217        let _ = check_grad; // silence unused (sweep! macro is intentionally\n        // unused — kept as reference for the per-param sweep pattern below)
19218
19219        // Per-param sweeps (explicit, not macro — clearer).
19220        for i in 0..wc.len() {
19221            let mut p = wc.clone();
19222            let s = p[i];
19223            p[i] = s + eps;
19224            let lp = run_loss(&mut fwd_arena, &p, &bc, &wfc, &bfc);
19225            p[i] = s - eps;
19226            let lm = run_loss(&mut fwd_arena, &p, &bc, &wfc, &bfc);
19227            let num = (lp - lm) / (2.0 * eps);
19228            assert!(
19229                (g_wc[i] - num).abs() < 5e-3,
19230                "g_wc[{i}]: {} vs {num}",
19231                g_wc[i]
19232            );
19233        }
19234        for i in 0..bc.len() {
19235            let mut p = bc.clone();
19236            let s = p[i];
19237            p[i] = s + eps;
19238            let lp = run_loss(&mut fwd_arena, &wc, &p, &wfc, &bfc);
19239            p[i] = s - eps;
19240            let lm = run_loss(&mut fwd_arena, &wc, &p, &wfc, &bfc);
19241            let num = (lp - lm) / (2.0 * eps);
19242            assert!(
19243                (g_bc[i] - num).abs() < 5e-3,
19244                "g_bc[{i}]: {} vs {num}",
19245                g_bc[i]
19246            );
19247        }
19248        for i in 0..wfc.len() {
19249            let mut p = wfc.clone();
19250            let s = p[i];
19251            p[i] = s + eps;
19252            let lp = run_loss(&mut fwd_arena, &wc, &bc, &p, &bfc);
19253            p[i] = s - eps;
19254            let lm = run_loss(&mut fwd_arena, &wc, &bc, &p, &bfc);
19255            let num = (lp - lm) / (2.0 * eps);
19256            assert!(
19257                (g_wfc[i] - num).abs() < 5e-3,
19258                "g_wfc[{i}]: {} vs {num}",
19259                g_wfc[i]
19260            );
19261        }
19262        for i in 0..bfc.len() {
19263            let mut p = bfc.clone();
19264            let s = p[i];
19265            p[i] = s + eps;
19266            let lp = run_loss(&mut fwd_arena, &wc, &bc, &wfc, &p);
19267            p[i] = s - eps;
19268            let lm = run_loss(&mut fwd_arena, &wc, &bc, &wfc, &p);
19269            let num = (lp - lm) / (2.0 * eps);
19270            assert!(
19271                (g_bfc[i] - num).abs() < 5e-3,
19272                "g_bfc[{i}]: {} vs {num}",
19273                g_bfc[i]
19274            );
19275        }
19276    }
19277
19278    /// Negative case: a Narrow whose output has multiple consumers
19279    /// must NOT be fused (we can't elide its write — something else
19280    /// reads it).
19281    #[test]
19282    fn narrow_rope_skips_when_narrow_has_multiple_consumers() {
19283        let f = DType::F32;
19284        let mut g = Graph::new("nr_skip");
19285        let qkv = g.input("qkv", Shape::new(&[16, 8, 192], f));
19286        let cos = g.input("cos", Shape::new(&[16], f));
19287        let sin = g.input("sin", Shape::new(&[16], f));
19288        let q = g.narrow_(qkv, 2, 0, 64);
19289        let q_rope = g.rope(q, cos, sin, 16);
19290        // Second consumer of `q` blocks the fusion.
19291        let q_dup = g.activation(rlx_ir::op::Activation::Relu, q, Shape::new(&[16, 8, 64], f));
19292        g.set_outputs(vec![q_rope, q_dup]);
19293
19294        let plan = rlx_opt::memory::plan_memory(&g);
19295        let arena = crate::arena::Arena::from_plan(plan);
19296        let sched = compile_thunks(&g, &arena);
19297
19298        let narrow_count = sched
19299            .thunks
19300            .iter()
19301            .filter(|t| matches!(t, Thunk::Narrow { .. }))
19302            .count();
19303        assert!(
19304            narrow_count >= 1,
19305            "Narrow with multiple consumers must NOT be fused away"
19306        );
19307    }
19308
19309    // ── Op::CustomFn (custom_vjp / custom_jvp) tests ──
19310    //
19311    // Validates: forward execution inlines fwd_body; VJP rule inlines
19312    // vjp_body in place of recursing into fwd_body; JVP rule inlines
19313    // jvp_body. Each test deliberately picks a body whose AD-via-tracing
19314    // would yield a *different* gradient than the override, so we know
19315    // the override actually fired.
19316
19317    /// Forward only: CustomFn wrapping `f(x) = x + c` (c=1 inside body)
19318    /// without override AD bodies. Verifies the body is compiled,
19319    /// constants in the body fill correctly, and the output lands at
19320    /// the outer node's slot.
19321    #[test]
19322    fn custom_fn_forward_inlines_body() {
19323        let s = Shape::new(&[3], DType::F32);
19324
19325        // Body: f(x) = x + 1
19326        let mut body = Graph::new("addone_body");
19327        let x = body.input("x", s.clone());
19328        let one_data: Vec<u8> = (0..3).flat_map(|_| 1.0_f32.to_le_bytes()).collect();
19329        let one = body.add_node(Op::Constant { data: one_data }, vec![], s.clone());
19330        let y = body.binary(BinaryOp::Add, x, one, s.clone());
19331        body.set_outputs(vec![y]);
19332
19333        let mut g = Graph::new("custom_fn_outer");
19334        let xin = g.input("x_in", s.clone());
19335        let cf = g.custom_fn(vec![xin], body, None, None);
19336        g.set_outputs(vec![cf]);
19337
19338        let xs = vec![10.0_f32, 20.0, 30.0];
19339        let (sched, mut arena) = prepare(&g, &[(xin, &xs)]);
19340        execute_thunks(&sched, arena.raw_buf_mut());
19341        let got = read_arena(&arena, cf, 3);
19342        assert_eq!(got, vec![11.0, 21.0, 31.0]);
19343    }
19344
19345    /// Locate an Op::Input or Op::Param by name in a graph.
19346    fn find_named(graph: &Graph, want: &str) -> NodeId {
19347        for n in graph.nodes() {
19348            let name = match &n.op {
19349                Op::Input { name } | Op::Param { name } => Some(name.as_str()),
19350                _ => None,
19351            };
19352            if name == Some(want) {
19353                return n.id;
19354            }
19355        }
19356        panic!("no node named {want:?} in graph");
19357    }
19358
19359    /// VJP override: f(x) = x but vjp_body returns 2 * d_output, so the
19360    /// reported gradient should be 2 — different from the natural 1
19361    /// you'd get by recursing into the identity body.
19362    #[test]
19363    fn custom_fn_vjp_overrides_natural_gradient() {
19364        use rlx_opt::autodiff::grad_with_loss;
19365        let s = Shape::new(&[1], DType::F32);
19366
19367        let mut fwd = Graph::new("id_fwd");
19368        let x = fwd.input("x", s.clone());
19369        fwd.set_outputs(vec![x]);
19370
19371        let mut vjp_g = Graph::new("id_vjp");
19372        let _x_p = vjp_g.input("x", s.clone());
19373        let _y_p = vjp_g.input("primal_output", s.clone());
19374        let dy = vjp_g.input("d_output", s.clone());
19375        let two_data: Vec<u8> = 2.0_f32.to_le_bytes().to_vec();
19376        let two = vjp_g.add_node(Op::Constant { data: two_data }, vec![], s.clone());
19377        let dx = vjp_g.binary(BinaryOp::Mul, dy, two, s.clone());
19378        vjp_g.set_outputs(vec![dx]);
19379
19380        let mut g = Graph::new("outer");
19381        let xp = g.param("x", s.clone());
19382        let cf = g.custom_fn(vec![xp], fwd, Some(vjp_g), None);
19383        g.set_outputs(vec![cf]);
19384
19385        let bwd = grad_with_loss(&g, &[xp]);
19386        assert_eq!(bwd.outputs.len(), 2, "expect [loss, dx]");
19387
19388        let xb = find_named(&bwd, "x");
19389        let dout = find_named(&bwd, "d_output");
19390        let (sched, mut arena) = prepare(&bwd, &[(xb, &[7.0]), (dout, &[1.0])]);
19391        execute_thunks(&sched, arena.raw_buf_mut());
19392        let loss = read_arena(&arena, bwd.outputs[0], 1);
19393        let dx_v = read_arena(&arena, bwd.outputs[1], 1);
19394        assert!((loss[0] - 7.0).abs() < 1e-6, "loss should be 7.0");
19395        assert!(
19396            (dx_v[0] - 2.0).abs() < 1e-6,
19397            "vjp override should yield dx=2.0, got {} (natural autodiff would give 1.0)",
19398            dx_v[0]
19399        );
19400    }
19401
19402    /// VJP override: f(a, b) = a*b with vjp_body returning
19403    /// (b * d_output, a * d_output). Validates routing of multiple
19404    /// primals + d_output through the override; matches the natural
19405    /// autodiff-of-Mul gradient (b, a).
19406    #[test]
19407    fn custom_fn_vjp_two_inputs_matches_mul_autodiff() {
19408        use rlx_opt::autodiff::grad_with_loss;
19409        let s = Shape::new(&[1], DType::F32);
19410
19411        let mut fwd = Graph::new("mul_fwd");
19412        let a_f = fwd.input("a", s.clone());
19413        let b_f = fwd.input("b", s.clone());
19414        let y_f = fwd.binary(BinaryOp::Mul, a_f, b_f, s.clone());
19415        fwd.set_outputs(vec![y_f]);
19416
19417        let mut vjp_g = Graph::new("mul_vjp");
19418        let a_v = vjp_g.input("a", s.clone());
19419        let b_v = vjp_g.input("b", s.clone());
19420        let _y_v = vjp_g.input("primal_output", s.clone());
19421        let dy_v = vjp_g.input("d_output", s.clone());
19422        let da = vjp_g.binary(BinaryOp::Mul, b_v, dy_v, s.clone());
19423        let db = vjp_g.binary(BinaryOp::Mul, a_v, dy_v, s.clone());
19424        vjp_g.set_outputs(vec![da, db]);
19425
19426        let mut g = Graph::new("outer");
19427        let ap = g.param("a", s.clone());
19428        let bp = g.param("b", s.clone());
19429        let cf = g.custom_fn(vec![ap, bp], fwd, Some(vjp_g), None);
19430        g.set_outputs(vec![cf]);
19431
19432        let bwd = grad_with_loss(&g, &[ap, bp]);
19433        assert_eq!(bwd.outputs.len(), 3, "expect [loss, da, db]");
19434
19435        let ab = find_named(&bwd, "a");
19436        let bb = find_named(&bwd, "b");
19437        let dout = find_named(&bwd, "d_output");
19438        let (sched, mut arena) = prepare(&bwd, &[(ab, &[3.0]), (bb, &[5.0]), (dout, &[1.0])]);
19439        execute_thunks(&sched, arena.raw_buf_mut());
19440        let loss = read_arena(&arena, bwd.outputs[0], 1);
19441        let da_v = read_arena(&arena, bwd.outputs[1], 1);
19442        let db_v = read_arena(&arena, bwd.outputs[2], 1);
19443        assert!((loss[0] - 15.0).abs() < 1e-5);
19444        assert!(
19445            (da_v[0] - 5.0).abs() < 1e-5,
19446            "da should be b=5.0, got {}",
19447            da_v[0]
19448        );
19449        assert!(
19450            (db_v[0] - 3.0).abs() < 1e-5,
19451            "db should be a=3.0, got {}",
19452            db_v[0]
19453        );
19454    }
19455
19456    /// JVP override: f(x) = x but jvp_body returns 2 * tangent_0.
19457    /// Forward-mode tangent should be 2x the seed (1.0) → 2.0.
19458    #[test]
19459    fn custom_fn_jvp_overrides_natural_tangent() {
19460        use rlx_opt::autodiff_fwd::jvp;
19461        let s = Shape::new(&[1], DType::F32);
19462
19463        let mut fwd = Graph::new("id_fwd");
19464        let x = fwd.input("x", s.clone());
19465        fwd.set_outputs(vec![x]);
19466
19467        let mut jvp_g = Graph::new("id_jvp");
19468        let _x_p = jvp_g.input("x", s.clone());
19469        let tx = jvp_g.input("tangent_0", s.clone());
19470        let two_data: Vec<u8> = 2.0_f32.to_le_bytes().to_vec();
19471        let two = jvp_g.add_node(Op::Constant { data: two_data }, vec![], s.clone());
19472        let ty = jvp_g.binary(BinaryOp::Mul, tx, two, s.clone());
19473        jvp_g.set_outputs(vec![ty]);
19474
19475        let mut g = Graph::new("outer");
19476        let xin = g.input("x_in", s.clone());
19477        let cf = g.custom_fn(vec![xin], fwd, None, Some(jvp_g));
19478        g.set_outputs(vec![cf]);
19479
19480        let fwd_g = jvp(&g, &[xin]);
19481        assert_eq!(fwd_g.outputs.len(), 2, "expect [primal_y, tangent_y]");
19482
19483        let xb = find_named(&fwd_g, "x_in");
19484        let tan = find_named(&fwd_g, "tangent_x_in");
19485        let (sched, mut arena) = prepare(&fwd_g, &[(xb, &[7.0]), (tan, &[1.0])]);
19486        execute_thunks(&sched, arena.raw_buf_mut());
19487        let y = read_arena(&arena, fwd_g.outputs[0], 1);
19488        let ty_v = read_arena(&arena, fwd_g.outputs[1], 1);
19489        assert!((y[0] - 7.0).abs() < 1e-6);
19490        assert!(
19491            (ty_v[0] - 2.0).abs() < 1e-6,
19492            "jvp override should yield t_y=2.0 (natural autodiff would give 1.0), got {}",
19493            ty_v[0]
19494        );
19495    }
19496
19497    /// IR-level basic test: `DType::C64` is wired through the dtype
19498    /// table — `size_bytes() == 8`, `is_complex()` reports true, and
19499    /// a `[2]`-shaped C64 buffer in the arena occupies the expected
19500    /// 16 bytes.
19501    #[test]
19502    fn c64_dtype_storage_layout() {
19503        assert_eq!(
19504            DType::C64.size_bytes(),
19505            8,
19506            "C64 should be 8 bytes (f32 real + f32 imag)"
19507        );
19508        assert!(DType::C64.is_complex());
19509        assert!(!DType::C64.is_float());
19510
19511        // A length-2 C64 buffer should have shape size_bytes = 16.
19512        let s = Shape::new(&[2], DType::C64);
19513        assert_eq!(s.size_bytes().unwrap(), 16);
19514    }
19515
19516    // ── C64 element-wise binary kernel witnesses (2026-05-17) ──────
19517    //
19518    // Build a tiny graph: Input `a` + Input `b` (both C64 [2]),
19519    // output = a OP b. Run through CompileResult and compare against
19520    // the closed-form complex arithmetic on the four chosen pairs.
19521
19522    fn run_c64_binary(op: BinaryOp, a: &[(f32, f32)], b: &[(f32, f32)]) -> Vec<(f32, f32)> {
19523        let n = a.len();
19524        let s = Shape::new(&[n], DType::C64);
19525        let mut g = Graph::new("c64_bin");
19526        let in_a = g.input("a", s.clone());
19527        let in_b = g.input("b", s.clone());
19528        let out = g.binary(op, in_a, in_b, s.clone());
19529        g.set_outputs(vec![out]);
19530
19531        let plan = rlx_opt::memory::plan_memory(&g);
19532        let mut arena = crate::arena::Arena::from_plan(plan);
19533        let sched = compile_thunks(&g, &arena);
19534
19535        let a_off = arena.byte_offset(in_a);
19536        let b_off = arena.byte_offset(in_b);
19537        let out_off = arena.byte_offset(out);
19538        // Interleave [re_0, im_0, re_1, im_1, ...] in the f32 buffer.
19539        let buf = arena.raw_buf_mut();
19540        unsafe {
19541            let pa = buf.as_mut_ptr().add(a_off) as *mut f32;
19542            let pb = buf.as_mut_ptr().add(b_off) as *mut f32;
19543            for (i, &(re, im)) in a.iter().enumerate() {
19544                *pa.add(2 * i) = re;
19545                *pa.add(2 * i + 1) = im;
19546            }
19547            for (i, &(re, im)) in b.iter().enumerate() {
19548                *pb.add(2 * i) = re;
19549                *pb.add(2 * i + 1) = im;
19550            }
19551        }
19552        execute_thunks(&sched, arena.raw_buf_mut());
19553        let raw_out: Vec<f32> = unsafe {
19554            let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
19555            (0..(2 * n)).map(|i| *p.add(i)).collect()
19556        };
19557        (0..n)
19558            .map(|i| (raw_out[2 * i], raw_out[2 * i + 1]))
19559            .collect()
19560    }
19561
19562    #[track_caller]
19563    fn assert_close_c(got: (f32, f32), expected: (f32, f32), tol: f32, label: &str) {
19564        let dr = (got.0 - expected.0).abs();
19565        let di = (got.1 - expected.1).abs();
19566        assert!(
19567            dr < tol && di < tol,
19568            "[{label}] got ({:+.4}, {:+.4}), expected ({:+.4}, {:+.4})",
19569            got.0,
19570            got.1,
19571            expected.0,
19572            expected.1
19573        );
19574    }
19575
19576    #[test]
19577    fn c64_binary_add_matches_complex_arithmetic() {
19578        let a = [(1.0_f32, 2.0_f32), (3.0_f32, -1.0_f32)];
19579        let b = [(4.0_f32, -1.0_f32), (0.5_f32, 0.5_f32)];
19580        let out = run_c64_binary(BinaryOp::Add, &a, &b);
19581        assert_close_c(out[0], (5.0, 1.0), 1e-6, "add[0]");
19582        assert_close_c(out[1], (3.5, -0.5), 1e-6, "add[1]");
19583    }
19584
19585    #[test]
19586    fn c64_binary_sub_matches_complex_arithmetic() {
19587        let a = [(5.0_f32, 1.0_f32)];
19588        let b = [(2.0_f32, 3.0_f32)];
19589        let out = run_c64_binary(BinaryOp::Sub, &a, &b);
19590        assert_close_c(out[0], (3.0, -2.0), 1e-6, "sub");
19591    }
19592
19593    #[test]
19594    fn c64_binary_mul_matches_complex_arithmetic() {
19595        // (1 + 2i)(3 + 4i) = 3 + 4i + 6i + 8i² = -5 + 10i.
19596        let a = [(1.0_f32, 2.0_f32)];
19597        let b = [(3.0_f32, 4.0_f32)];
19598        let out = run_c64_binary(BinaryOp::Mul, &a, &b);
19599        assert_close_c(out[0], (-5.0, 10.0), 1e-5, "mul");
19600    }
19601
19602    #[test]
19603    fn c64_binary_div_matches_complex_arithmetic() {
19604        // (1 + 2i) / (3 + 4i) = ((1·3 + 2·4) + (2·3 − 1·4)i) / 25
19605        //                     = (11 + 2i) / 25
19606        //                     = 0.44 + 0.08i
19607        let a = [(1.0_f32, 2.0_f32)];
19608        let b = [(3.0_f32, 4.0_f32)];
19609        let out = run_c64_binary(BinaryOp::Div, &a, &b);
19610        assert_close_c(out[0], (0.44, 0.08), 1e-5, "div");
19611    }
19612
19613    #[test]
19614    fn c64_binary_mul_identity_one_is_no_op() {
19615        // (a + bi) · (1 + 0i) = a + bi.
19616        let a = [(3.5_f32, -1.25_f32), (-2.0_f32, 7.0_f32)];
19617        let b = [(1.0_f32, 0.0_f32), (1.0_f32, 0.0_f32)];
19618        let out = run_c64_binary(BinaryOp::Mul, &a, &b);
19619        assert_close_c(out[0], a[0], 1e-6, "mul·1[0]");
19620        assert_close_c(out[1], a[1], 1e-6, "mul·1[1]");
19621    }
19622
19623    #[test]
19624    fn c64_binary_mul_by_i_rotates_90_degrees() {
19625        // (a + bi) · i = (a + bi)(0 + i) = -b + ai. 90° CCW rotation.
19626        let a = [(1.0_f32, 0.0_f32)];
19627        let b = [(0.0_f32, 1.0_f32)];
19628        let out = run_c64_binary(BinaryOp::Mul, &a, &b);
19629        assert_close_c(out[0], (0.0, 1.0), 1e-6, "1·i");
19630    }
19631
19632    #[test]
19633    fn c64_binary_div_by_self_gives_unity() {
19634        let a = [(2.5_f32, -1.5_f32), (-0.7_f32, 4.2_f32)];
19635        let out = run_c64_binary(BinaryOp::Div, &a, &a);
19636        assert_close_c(out[0], (1.0, 0.0), 1e-5, "div_self[0]");
19637        assert_close_c(out[1], (1.0, 0.0), 1e-5, "div_self[1]");
19638    }
19639
19640    #[test]
19641    #[should_panic(expected = "C64: complex max/min/pow")]
19642    fn c64_binary_max_is_rejected_at_lowering() {
19643        run_c64_binary(BinaryOp::Max, &[(1.0_f32, 2.0_f32)], &[(3.0_f32, 4.0_f32)]);
19644    }
19645
19646    fn run_c64_activation(act: Activation, a: &[(f32, f32)]) -> Vec<(f32, f32)> {
19647        let n = a.len();
19648        let s = Shape::new(&[n], DType::C64);
19649        let mut g = Graph::new("c64_act");
19650        let in_a = g.input("a", s.clone());
19651        let out = g.activation(act, in_a, s.clone());
19652        g.set_outputs(vec![out]);
19653        let plan = rlx_opt::memory::plan_memory(&g);
19654        let mut arena = crate::arena::Arena::from_plan(plan);
19655        let sched = compile_thunks(&g, &arena);
19656        let a_off = arena.byte_offset(in_a);
19657        let out_off = arena.byte_offset(out);
19658        let buf = arena.raw_buf_mut();
19659        unsafe {
19660            let pa = buf.as_mut_ptr().add(a_off) as *mut f32;
19661            for (i, &(re, im)) in a.iter().enumerate() {
19662                *pa.add(2 * i) = re;
19663                *pa.add(2 * i + 1) = im;
19664            }
19665        }
19666        execute_thunks(&sched, arena.raw_buf_mut());
19667        let raw: Vec<f32> = unsafe {
19668            let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
19669            (0..(2 * n)).map(|i| *p.add(i)).collect()
19670        };
19671        (0..n).map(|i| (raw[2 * i], raw[2 * i + 1])).collect()
19672    }
19673
19674    #[test]
19675    fn c64_activation_neg_negates_both_components() {
19676        let inp = [(3.5_f32, -1.25_f32), (-2.0_f32, 0.0_f32)];
19677        let out = run_c64_activation(Activation::Neg, &inp);
19678        assert_close_c(out[0], (-3.5, 1.25), 1e-6, "neg[0]");
19679        assert_close_c(out[1], (2.0, 0.0), 1e-6, "neg[1]");
19680    }
19681
19682    #[test]
19683    fn c64_activation_exp_matches_euler() {
19684        // exp(0 + i·π) = -1 + 0i.
19685        // exp(1 + 0i) = e ≈ 2.71828.
19686        let inp = [(0.0_f32, std::f32::consts::PI), (1.0_f32, 0.0_f32)];
19687        let out = run_c64_activation(Activation::Exp, &inp);
19688        assert_close_c(out[0], (-1.0, 0.0), 1e-5, "exp(iπ)");
19689        assert_close_c(out[1], (std::f32::consts::E, 0.0), 1e-5, "exp(1)");
19690    }
19691
19692    #[test]
19693    fn c64_activation_log_matches_principal_branch() {
19694        // log(1 + 0i) = 0.
19695        // log(0 + i) = log(1) + i·π/2 = 0 + i·π/2.
19696        // log(-1 + 0i) = 0 + i·π.
19697        let inp = [(1.0_f32, 0.0_f32), (0.0_f32, 1.0_f32), (-1.0_f32, 0.0_f32)];
19698        let out = run_c64_activation(Activation::Log, &inp);
19699        assert_close_c(out[0], (0.0, 0.0), 1e-5, "log(1)");
19700        assert_close_c(out[1], (0.0, std::f32::consts::FRAC_PI_2), 1e-5, "log(i)");
19701        assert_close_c(out[2], (0.0, std::f32::consts::PI), 1e-5, "log(-1)");
19702    }
19703
19704    #[test]
19705    fn c64_activation_sqrt_squared_recovers_input() {
19706        // For positive-real-part inputs, sqrt(z)² should equal z exactly
19707        // to f32 noise.
19708        let inp = [(4.0_f32, 0.0_f32), (3.0_f32, 4.0_f32)];
19709        let roots = run_c64_activation(Activation::Sqrt, &inp);
19710        // sqrt(4) = 2 + 0i; sqrt(3+4i) = 2 + i (since (2+i)² = 4+4i-1 = 3+4i).
19711        assert_close_c(roots[0], (2.0, 0.0), 1e-5, "sqrt(4)");
19712        assert_close_c(roots[1], (2.0, 1.0), 1e-5, "sqrt(3+4i)");
19713    }
19714
19715    #[test]
19716    #[should_panic(expected = "no natural complex extension")]
19717    fn c64_activation_relu_is_rejected_at_lowering() {
19718        run_c64_activation(Activation::Relu, &[(1.0_f32, 2.0_f32)]);
19719    }
19720
19721    // ── ComplexNormSq + Wirtinger backward witnesses ───────────────
19722
19723    /// Forward `|z|²`: returns `[n]` f32.
19724    fn run_complex_norm_sq(z: &[(f32, f32)]) -> Vec<f32> {
19725        let n = z.len();
19726        let mut g = Graph::new("cns_fwd");
19727        let in_z = g.input("z", Shape::new(&[n], DType::C64));
19728        let out = g.complex_norm_sq(in_z);
19729        g.set_outputs(vec![out]);
19730        let plan = rlx_opt::memory::plan_memory(&g);
19731        let mut arena = crate::arena::Arena::from_plan(plan);
19732        let sched = compile_thunks(&g, &arena);
19733        let z_off = arena.byte_offset(in_z);
19734        let out_off = arena.byte_offset(out);
19735        let buf = arena.raw_buf_mut();
19736        unsafe {
19737            let pz = buf.as_mut_ptr().add(z_off) as *mut f32;
19738            for (i, &(re, im)) in z.iter().enumerate() {
19739                *pz.add(2 * i) = re;
19740                *pz.add(2 * i + 1) = im;
19741            }
19742        }
19743        execute_thunks(&sched, arena.raw_buf_mut());
19744        unsafe {
19745            let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
19746            (0..n).map(|i| *p.add(i)).collect()
19747        }
19748    }
19749
19750    /// Backward: given z and upstream g, return dz = g·z element-wise (C64).
19751    fn run_complex_norm_sq_bwd(z: &[(f32, f32)], g: &[f32]) -> Vec<(f32, f32)> {
19752        let n = z.len();
19753        let mut gr = Graph::new("cns_bwd");
19754        let in_z = gr.input("z", Shape::new(&[n], DType::C64));
19755        let in_g = gr.input("g", Shape::new(&[n], DType::F32));
19756        let out = gr.complex_norm_sq_backward(in_z, in_g);
19757        gr.set_outputs(vec![out]);
19758        let plan = rlx_opt::memory::plan_memory(&gr);
19759        let mut arena = crate::arena::Arena::from_plan(plan);
19760        let sched = compile_thunks(&gr, &arena);
19761        let z_off = arena.byte_offset(in_z);
19762        let g_off = arena.byte_offset(in_g);
19763        let out_off = arena.byte_offset(out);
19764        let buf = arena.raw_buf_mut();
19765        unsafe {
19766            let pz = buf.as_mut_ptr().add(z_off) as *mut f32;
19767            let pg = buf.as_mut_ptr().add(g_off) as *mut f32;
19768            for (i, &(re, im)) in z.iter().enumerate() {
19769                *pz.add(2 * i) = re;
19770                *pz.add(2 * i + 1) = im;
19771            }
19772            for (i, &v) in g.iter().enumerate() {
19773                *pg.add(i) = v;
19774            }
19775        }
19776        execute_thunks(&sched, arena.raw_buf_mut());
19777        unsafe {
19778            let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
19779            (0..n).map(|i| (*p.add(2 * i), *p.add(2 * i + 1))).collect()
19780        }
19781    }
19782
19783    #[test]
19784    fn complex_norm_sq_matches_textbook() {
19785        // |3 + 4i|² = 9 + 16 = 25.
19786        // |1 + 0i|² = 1.
19787        // |0 + 0i|² = 0.
19788        let z = [(3.0_f32, 4.0_f32), (1.0_f32, 0.0_f32), (0.0_f32, 0.0_f32)];
19789        let out = run_complex_norm_sq(&z);
19790        assert!((out[0] - 25.0).abs() < 1e-5);
19791        assert!((out[1] - 1.0).abs() < 1e-6);
19792        assert!(out[2].abs() < 1e-6);
19793    }
19794
19795    #[test]
19796    fn complex_norm_sq_backward_matches_wirtinger_formula() {
19797        // Wirtinger: ∂|z|²/∂z̄ = z. With upstream g = 1, dz = z.
19798        let z = [(3.0_f32, 4.0_f32), (1.5_f32, -2.5_f32)];
19799        let g = [1.0_f32, 1.0_f32];
19800        let dz = run_complex_norm_sq_bwd(&z, &g);
19801        assert_close_c(dz[0], z[0], 1e-6, "dz[0] = g·z[0]");
19802        assert_close_c(dz[1], z[1], 1e-6, "dz[1] = g·z[1]");
19803    }
19804
19805    #[test]
19806    fn complex_norm_sq_backward_scales_with_upstream() {
19807        // With upstream g[i] ≠ 1: dz[i] = g[i]·z[i].
19808        let z = [(2.0_f32, 1.0_f32), (-1.0_f32, 3.0_f32)];
19809        let g = [0.5_f32, -2.0_f32];
19810        let dz = run_complex_norm_sq_bwd(&z, &g);
19811        assert_close_c(dz[0], (1.0, 0.5), 1e-6, "g=0.5 · (2,1)");
19812        assert_close_c(dz[1], (2.0, -6.0), 1e-6, "g=-2 · (-1,3)");
19813    }
19814
19815    /// Multi-output Op::CustomFn via the concat-with-Narrow design
19816    /// (rlx-ir::Graph::custom_fn_multi). Build a custom_fn whose
19817    /// fwd_body returns two outputs (x², 2x), then materialize each
19818    /// via the MultiOutputHandle and verify both numerically.
19819    #[test]
19820    fn custom_fn_multi_extracts_each_subgraph_output() {
19821        use rlx_ir::ops::special::MultiOutputHandle;
19822
19823        let _ = MultiOutputHandle {
19824            source: NodeId(0),
19825            sub_shapes: vec![],
19826            offsets: vec![],
19827        }; // import sanity
19828
19829        // Inner body: input x [3] f32, outputs (x², 2x) both [3] f32.
19830        let mut body = Graph::new("multi_body");
19831        let s3 = Shape::new(&[3], DType::F32);
19832        let x = body.input("x", s3.clone());
19833        let x_sq = body.binary(BinaryOp::Mul, x, x, s3.clone());
19834        let two = body.add_node(
19835            Op::Constant {
19836                data: vec![
19837                    2.0_f32.to_le_bytes(),
19838                    2.0_f32.to_le_bytes(),
19839                    2.0_f32.to_le_bytes(),
19840                ]
19841                .into_iter()
19842                .flatten()
19843                .collect(),
19844            },
19845            vec![],
19846            s3.clone(),
19847        );
19848        let two_x = body.binary(BinaryOp::Mul, two, x, s3.clone());
19849        body.set_outputs(vec![x_sq, two_x]);
19850
19851        // Outer graph: feed in_x → custom_fn_multi → handle.output(0/1).
19852        let mut outer = Graph::new("multi_outer");
19853        let in_x = outer.input("xin", s3.clone());
19854        let handle = outer.custom_fn_multi(vec![in_x], body);
19855        assert_eq!(handle.n_outputs(), 2);
19856        let out0 = handle.output(&mut outer, 0); // x²
19857        let out1 = handle.output(&mut outer, 1); // 2x
19858        outer.set_outputs(vec![out0, out1]);
19859
19860        let plan = rlx_opt::memory::plan_memory(&outer);
19861        let mut arena = crate::arena::Arena::from_plan(plan);
19862        let sched = compile_thunks(&outer, &arena);
19863        let xin_off = arena.byte_offset(in_x);
19864        let out0_off = arena.byte_offset(out0);
19865        let out1_off = arena.byte_offset(out1);
19866        let xs = [1.0_f32, 2.0, 3.0];
19867        unsafe {
19868            let p = arena.raw_buf_mut().as_mut_ptr().add(xin_off) as *mut f32;
19869            for (i, &v) in xs.iter().enumerate() {
19870                *p.add(i) = v;
19871            }
19872        }
19873        execute_thunks(&sched, arena.raw_buf_mut());
19874        let out0_v: Vec<f32> = unsafe {
19875            let p = arena.raw_buf().as_ptr().add(out0_off) as *const f32;
19876            (0..3).map(|i| *p.add(i)).collect()
19877        };
19878        let out1_v: Vec<f32> = unsafe {
19879            let p = arena.raw_buf().as_ptr().add(out1_off) as *const f32;
19880            (0..3).map(|i| *p.add(i)).collect()
19881        };
19882        // x² = [1, 4, 9]; 2x = [2, 4, 6].
19883        for i in 0..3 {
19884            assert!(
19885                (out0_v[i] - xs[i] * xs[i]).abs() < 1e-5,
19886                "out0[{i}] = {} != x² = {}",
19887                out0_v[i],
19888                xs[i] * xs[i]
19889            );
19890            assert!(
19891                (out1_v[i] - 2.0 * xs[i]).abs() < 1e-5,
19892                "out1[{i}] = {} != 2x = {}",
19893                out1_v[i],
19894                2.0 * xs[i]
19895            );
19896        }
19897    }
19898
19899    #[test]
19900    fn complex_norm_sq_gradient_matches_finite_difference() {
19901        // Numerical sanity: perturb z[0].re by ε, observe Δ|z|² ≈ 2·re·ε.
19902        let z = [(3.0_f32, 4.0_f32)];
19903        let eps = 1e-3_f32;
19904        let v0 = run_complex_norm_sq(&z)[0];
19905        let z_pert = [(3.0_f32 + eps, 4.0_f32)];
19906        let v1 = run_complex_norm_sq(&z_pert)[0];
19907        let fd_re = (v1 - v0) / eps;
19908        let analytic_re = 2.0 * z[0].0;
19909        assert!((fd_re - analytic_re).abs() < 1e-2);
19910
19911        // ∂/∂im at z = (3, 4) is 2·im = 8.
19912        let z_pert_im = [(3.0_f32, 4.0_f32 + eps)];
19913        let v2 = run_complex_norm_sq(&z_pert_im)[0];
19914        let fd_im = (v2 - v0) / eps;
19915        let analytic_im = 2.0 * z[0].1;
19916        assert!((fd_im - analytic_im).abs() < 1e-2);
19917
19918        // Compare with the Wirtinger backward at upstream g = 1.
19919        // Wirtinger ∂/∂z̄ = z gives dz = (re, im). The "real
19920        // gradient" wrt (re, im) is 2·(re, im), i.e. 2·dz = (2·re,
19921        // 2·im) — that's the factor 2 difference between Wirtinger
19922        // ∂/∂z̄ and the real-vector gradient on (re, im).
19923        let dz = run_complex_norm_sq_bwd(&z, &[1.0_f32]);
19924        assert!((2.0 * dz[0].0 - analytic_re).abs() < 1e-5);
19925        assert!((2.0 * dz[0].1 - analytic_im).abs() < 1e-5);
19926    }
19927
19928    /// Direct regression test for the 5-D mid-shape singleton broadcast
19929    /// (SAM rel_pos pattern: `[bh, h, w, 1, w] + [bh, h, w, h, w]`).
19930    /// The SAM port worked around this by `concat`-tiling the rhs; this
19931    /// test verifies the in-graph broadcast path is bit-correct.
19932    #[test]
19933    fn binary_full_5d_mid_singleton_broadcast() {
19934        let bh = 2usize;
19935        let h = 3;
19936        let w = 4;
19937        let f = DType::F32;
19938
19939        let mut g = Graph::new("bcast_5d");
19940        let lhs = g.input("lhs", Shape::new(&[bh, h, w, h, w], f));
19941        // rhs shape with size-1 at axis 3 (mid-shape singleton).
19942        let rhs = g.input("rhs", Shape::new(&[bh, h, w, 1, w], f));
19943        let out = g.binary(BinaryOp::Add, lhs, rhs, Shape::new(&[bh, h, w, h, w], f));
19944        g.set_outputs(vec![out]);
19945
19946        // Deterministic data.
19947        let lhs_data: Vec<f32> = (0..bh * h * w * h * w).map(|i| i as f32 * 0.01).collect();
19948        let rhs_data: Vec<f32> = (0..bh * h * w * w)
19949            .map(|i| (i as f32 + 100.0) * 0.01)
19950            .collect();
19951
19952        // Compute expected output by hand.
19953        let mut expected = vec![0f32; bh * h * w * h * w];
19954        for b_ in 0..bh {
19955            for hq in 0..h {
19956                for wq in 0..w {
19957                    for hk in 0..h {
19958                        for wk in 0..w {
19959                            let li = (((b_ * h + hq) * w + wq) * h + hk) * w + wk;
19960                            // rhs has hk dim = 1, so it's always index 0 there.
19961                            let ri = ((b_ * h + hq) * w + wq) * w + wk;
19962                            expected[li] = lhs_data[li] + rhs_data[ri];
19963                        }
19964                    }
19965                }
19966            }
19967        }
19968
19969        let plan = rlx_opt::memory::plan_memory(&g);
19970        let mut arena = crate::arena::Arena::from_plan(plan);
19971        let sched = compile_thunks(&g, &arena);
19972        let lhs_off = arena.byte_offset(lhs);
19973        let rhs_off = arena.byte_offset(rhs);
19974        let out_off = arena.byte_offset(out);
19975        let buf = arena.raw_buf_mut();
19976        unsafe {
19977            let p = buf.as_mut_ptr().add(lhs_off) as *mut f32;
19978            for (i, &v) in lhs_data.iter().enumerate() {
19979                *p.add(i) = v;
19980            }
19981            let p = buf.as_mut_ptr().add(rhs_off) as *mut f32;
19982            for (i, &v) in rhs_data.iter().enumerate() {
19983                *p.add(i) = v;
19984            }
19985        }
19986        execute_thunks(&sched, arena.raw_buf_mut());
19987        let actual: Vec<f32> = unsafe {
19988            let p = arena.raw_buf().as_ptr().add(out_off) as *const f32;
19989            (0..bh * h * w * h * w).map(|i| *p.add(i)).collect()
19990        };
19991
19992        // Bit-exact check.
19993        let mut max_diff = 0f32;
19994        let mut max_idx = 0;
19995        for i in 0..actual.len() {
19996            let d = (actual[i] - expected[i]).abs();
19997            if d > max_diff {
19998                max_diff = d;
19999                max_idx = i;
20000            }
20001        }
20002        assert!(
20003            max_diff < 1e-6,
20004            "5D mid-shape singleton broadcast wrong: max |Δ| = {max_diff} at idx {max_idx} \
20005             (actual={}, expected={})",
20006            actual[max_idx],
20007            expected[max_idx]
20008        );
20009    }
20010
20011    #[test]
20012    fn layer_norm2d_and_conv_transpose2d_kernels() {
20013        let mut out = vec![0f32; 8];
20014        crate::kernels::layer_norm2d_nchw(
20015            &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
20016            &[1.0, 1.0],
20017            &[0.0, 0.0],
20018            &mut out,
20019            1,
20020            2,
20021            2,
20022            2,
20023            1e-5,
20024        );
20025        let mean0: f32 = (1.0 + 3.0) / 2.0;
20026        assert!((out[0] - mean0).abs() > 0.1);
20027
20028        let mut up = vec![0f32; 4];
20029        crate::kernels::conv_transpose2d_nchw(
20030            &[2.0],
20031            &[1.0, 0.0, 0.0, 1.0],
20032            &mut up,
20033            1,
20034            1,
20035            1,
20036            1,
20037            1,
20038            2,
20039            2,
20040            2,
20041            2,
20042            2,
20043            2,
20044            0,
20045            0,
20046            1,
20047            1,
20048            1,
20049        );
20050        assert!((up[0] - 2.0).abs() < 1e-5);
20051        assert!((up[3] - 2.0).abs() < 1e-5);
20052    }
20053}